diff --git a/pkg/dialer/http3/conn.go b/pkg/dialer/http3/conn.go deleted file mode 100644 index 6343e64..0000000 --- a/pkg/dialer/http3/conn.go +++ /dev/null @@ -1,148 +0,0 @@ -package http3 - -import ( - "bufio" - "bytes" - "encoding/base64" - "errors" - "fmt" - "net" - "net/http" - "time" - - "github.com/go-gost/gost/pkg/logger" -) - -type conn struct { - cid string - addr string - client *http.Client - buf []byte - rxc chan []byte - closed chan struct{} - md metadata - logger logger.Logger -} - -func (c *conn) Read(b []byte) (n int, err error) { - if len(c.buf) == 0 { - select { - case c.buf = <-c.rxc: - case <-c.closed: - err = net.ErrClosed - return - } - } - - n = copy(b, c.buf) - c.buf = c.buf[n:] - - return -} - -func (c *conn) Write(b []byte) (n int, err error) { - if len(b) == 0 { - return - } - - buf := bytes.NewBufferString(base64.StdEncoding.EncodeToString(b)) - buf.WriteByte('\n') - - url := fmt.Sprintf("https://%s%s?token=%s", c.addr, c.md.pushPath, c.cid) - r, err := http.NewRequest(http.MethodPost, url, buf) - if err != nil { - return - } - - resp, err := c.client.Do(r) - if err != nil { - return - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - err = errors.New(resp.Status) - return - } - - n = len(b) - return -} - -func (c *conn) readLoop() { - defer c.Close() - - url := fmt.Sprintf("https://%s%s?token=%s", c.addr, c.md.pullPath, c.cid) - for { - err := func() error { - r, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return err - } - - resp, err := c.client.Do(r) - if err != nil { - return err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return errors.New(resp.Status) - } - - scanner := bufio.NewScanner(resp.Body) - for scanner.Scan() { - b, err := base64.StdEncoding.DecodeString(scanner.Text()) - if err != nil { - return err - } - select { - case c.rxc <- b: - case <-c.closed: - return net.ErrClosed - } - } - - return scanner.Err() - }() - - if err != nil { - c.logger.Error(err) - return - } - } -} - -func (c *conn) LocalAddr() net.Addr { - return &net.TCPAddr{} -} - -func (c *conn) RemoteAddr() net.Addr { - addr, _ := net.ResolveTCPAddr("tcp", c.addr) - if addr == nil { - addr = &net.TCPAddr{} - } - - return addr -} - -func (c *conn) Close() error { - select { - case <-c.closed: - default: - close(c.closed) - } - return nil -} - -func (c *conn) SetReadDeadline(t time.Time) error { - return nil -} - -func (c *conn) SetWriteDeadline(t time.Time) error { - return nil -} - -func (c *conn) SetDeadline(t time.Time) error { - return nil -} diff --git a/pkg/dialer/http3/dialer.go b/pkg/dialer/http3/dialer.go index 0303638..601131c 100644 --- a/pkg/dialer/http3/dialer.go +++ b/pkg/dialer/http3/dialer.go @@ -2,16 +2,12 @@ package http3 import ( "context" - "errors" - "fmt" - "io" "net" "net/http" - "net/http/httputil" - "strings" "time" "github.com/go-gost/gost/pkg/dialer" + pht_util "github.com/go-gost/gost/pkg/internal/util/pht" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" @@ -23,7 +19,7 @@ func init() { } type http3Dialer struct { - client *http.Client + client *pht_util.Client md metadata logger logger.Logger options dialer.Options @@ -35,78 +31,34 @@ func NewDialer(opts ...dialer.Option) dialer.Dialer { opt(&options) } - tr := &http3.RoundTripper{ - TLSClientConfig: options.TLSConfig, - } - client := &http.Client{ - Timeout: 60 * time.Second, - Transport: tr, - } return &http3Dialer{ - client: client, logger: options.Logger, options: options, } } func (d *http3Dialer) Init(md md.Metadata) (err error) { - return d.parseMetadata(md) + if err = d.parseMetadata(md); err != nil { + return + } + + tr := &http3.RoundTripper{ + TLSClientConfig: d.options.TLSConfig, + } + d.client = &pht_util.Client{ + Client: &http.Client{ + Timeout: 60 * time.Second, + Transport: tr, + }, + AuthorizePath: d.md.authorizePath, + PushPath: d.md.pushPath, + PullPath: d.md.pullPath, + TLSEnabled: true, + Logger: d.options.Logger, + } + return nil } func (d *http3Dialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) { - token, err := d.authorize(ctx, addr) - if err != nil { - d.logger.Error(err) - return nil, err - } - - c := &conn{ - cid: token, - addr: addr, - client: d.client, - rxc: make(chan []byte, 128), - closed: make(chan struct{}), - md: d.md, - logger: d.logger, - } - go c.readLoop() - - return c, nil -} - -func (d *http3Dialer) authorize(ctx context.Context, addr string) (token string, err error) { - url := fmt.Sprintf("https://%s%s", addr, d.md.authorizePath) - r, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - if err != nil { - return - } - - if d.logger.IsLevelEnabled(logger.DebugLevel) { - dump, _ := httputil.DumpRequest(r, false) - d.logger.Debug(string(dump)) - } - - resp, err := d.client.Do(r) - if err != nil { - return - } - defer resp.Body.Close() - - if d.logger.IsLevelEnabled(logger.DebugLevel) { - dump, _ := httputil.DumpResponse(resp, false) - d.logger.Debug(string(dump)) - } - - data, err := io.ReadAll(resp.Body) - if err != nil { - return - } - - if strings.HasPrefix(string(data), "token=") { - token = strings.TrimPrefix(string(data), "token=") - } - if token == "" { - err = errors.New("authorize failed") - } - return + return d.client.Dial(ctx, addr) } diff --git a/pkg/dialer/pht/dialer.go b/pkg/dialer/pht/dialer.go index af9bf36..74e9f39 100644 --- a/pkg/dialer/pht/dialer.go +++ b/pkg/dialer/pht/dialer.go @@ -2,16 +2,12 @@ package pht import ( "context" - "errors" - "fmt" - "io" "net" "net/http" - "net/http/httputil" - "strings" "time" "github.com/go-gost/gost/pkg/dialer" + pht_util "github.com/go-gost/gost/pkg/internal/util/pht" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" @@ -24,7 +20,7 @@ func init() { type phtDialer struct { tlsEnabled bool - client *http.Client + client *pht_util.Client md metadata logger logger.Logger options dialer.Options @@ -76,73 +72,20 @@ func (d *phtDialer) Init(md md.Metadata) (err error) { tr.TLSClientConfig = d.options.TLSConfig } - d.client = &http.Client{ - Timeout: 60 * time.Second, - Transport: tr, + d.client = &pht_util.Client{ + Client: &http.Client{ + Timeout: 60 * time.Second, + Transport: tr, + }, + AuthorizePath: d.md.authorizePath, + PushPath: d.md.pushPath, + PullPath: d.md.pullPath, + TLSEnabled: d.tlsEnabled, + Logger: d.options.Logger, } return nil } func (d *phtDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) { - token, err := d.authorize(ctx, addr) - if err != nil { - d.logger.Error(err) - return nil, err - } - - c := &conn{ - cid: token, - addr: addr, - client: d.client, - tlsEnabled: d.tlsEnabled, - rxc: make(chan []byte, 128), - closed: make(chan struct{}), - md: d.md, - logger: d.logger, - } - go c.readLoop() - - return c, nil -} - -func (d *phtDialer) authorize(ctx context.Context, addr string) (token string, err error) { - var url string - if d.tlsEnabled { - url = fmt.Sprintf("https://%s%s", addr, d.md.authorizePath) - } else { - url = fmt.Sprintf("http://%s%s", addr, d.md.authorizePath) - } - r, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - if err != nil { - return - } - - if d.logger.IsLevelEnabled(logger.DebugLevel) { - dump, _ := httputil.DumpRequest(r, false) - d.logger.Debug(string(dump)) - } - - resp, err := d.client.Do(r) - if err != nil { - return - } - defer resp.Body.Close() - - if d.logger.IsLevelEnabled(logger.DebugLevel) { - dump, _ := httputil.DumpResponse(resp, false) - d.logger.Debug(string(dump)) - } - - data, err := io.ReadAll(resp.Body) - if err != nil { - return - } - - if strings.HasPrefix(string(data), "token=") { - token = strings.TrimPrefix(string(data), "token=") - } - if token == "" { - err = errors.New("authorize failed") - } - return + return d.client.Dial(ctx, addr) } diff --git a/pkg/internal/util/pht/client.go b/pkg/internal/util/pht/client.go new file mode 100644 index 0000000..b708235 --- /dev/null +++ b/pkg/internal/util/pht/client.go @@ -0,0 +1,93 @@ +package pht + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/httputil" + "strings" + + "github.com/go-gost/gost/pkg/logger" +) + +type Client struct { + Client *http.Client + AuthorizePath string + PushPath string + PullPath string + TLSEnabled bool + Logger logger.Logger +} + +func (c *Client) Dial(ctx context.Context, addr string) (net.Conn, error) { + token, err := c.authorize(ctx, addr) + if err != nil { + c.Logger.Error(err) + return nil, err + } + + cn := &clientConn{ + client: c.Client, + rxc: make(chan []byte, 128), + closed: make(chan struct{}), + localAddr: &net.TCPAddr{}, + logger: c.Logger, + } + cn.remoteAddr, _ = net.ResolveTCPAddr("tcp", addr) + + scheme := "http" + if c.TLSEnabled { + scheme = "https" + } + cn.pushURL = fmt.Sprintf("%s://%s%s?token=%s", scheme, addr, c.PushPath, token) + cn.pullURL = fmt.Sprintf("%s://%s%s?token=%s", scheme, addr, c.PullPath, token) + + go cn.readLoop() + + return cn, nil +} + +func (c *Client) authorize(ctx context.Context, addr string) (token string, err error) { + var url string + if c.TLSEnabled { + url = fmt.Sprintf("https://%s%s", addr, c.AuthorizePath) + } else { + url = fmt.Sprintf("http://%s%s", addr, c.AuthorizePath) + } + r, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return + } + + if c.Logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpRequest(r, false) + c.Logger.Debug(string(dump)) + } + + resp, err := c.Client.Do(r) + if err != nil { + return + } + defer resp.Body.Close() + + if c.Logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(resp, false) + c.Logger.Debug(string(dump)) + } + + data, err := io.ReadAll(resp.Body) + if err != nil { + return + } + + if strings.HasPrefix(string(data), "token=") { + token = strings.TrimPrefix(string(data), "token=") + } + if token == "" { + err = errors.New("authorize failed") + } + return +} diff --git a/pkg/dialer/pht/conn.go b/pkg/internal/util/pht/conn.go similarity index 54% rename from pkg/dialer/pht/conn.go rename to pkg/internal/util/pht/conn.go index dadbb65..9e8d7da 100644 --- a/pkg/dialer/pht/conn.go +++ b/pkg/internal/util/pht/conn.go @@ -5,27 +5,27 @@ import ( "bytes" "encoding/base64" "errors" - "fmt" "net" "net/http" + "net/http/httputil" "time" "github.com/go-gost/gost/pkg/logger" ) -type conn struct { - cid string - addr string +type clientConn struct { client *http.Client - tlsEnabled bool + pushURL string + pullURL string buf []byte rxc chan []byte closed chan struct{} - md metadata + localAddr net.Addr + remoteAddr net.Addr logger logger.Logger } -func (c *conn) Read(b []byte) (n int, err error) { +func (c *clientConn) Read(b []byte) (n int, err error) { if len(c.buf) == 0 { select { case c.buf = <-c.rxc: @@ -41,7 +41,7 @@ func (c *conn) Read(b []byte) (n int, err error) { return } -func (c *conn) Write(b []byte) (n int, err error) { +func (c *clientConn) Write(b []byte) (n int, err error) { if len(b) == 0 { return } @@ -49,16 +49,14 @@ func (c *conn) Write(b []byte) (n int, err error) { buf := bytes.NewBufferString(base64.StdEncoding.EncodeToString(b)) buf.WriteByte('\n') - var url string - if c.tlsEnabled { - url = fmt.Sprintf("https://%s%s?token=%s", c.addr, c.md.pushPath, c.cid) - } else { - url = fmt.Sprintf("http://%s%s?token=%s", c.addr, c.md.pushPath, c.cid) - } - r, err := http.NewRequest(http.MethodPost, url, buf) + r, err := http.NewRequest(http.MethodPost, c.pushURL, buf) if err != nil { return } + if c.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpRequest(r, false) + c.logger.Debug(string(dump)) + } resp, err := c.client.Do(r) if err != nil { @@ -66,6 +64,11 @@ func (c *conn) Write(b []byte) (n int, err error) { } defer resp.Body.Close() + if c.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(resp, false) + c.logger.Debug(string(dump)) + } + if resp.StatusCode != http.StatusOK { err = errors.New(resp.Status) return @@ -75,21 +78,19 @@ func (c *conn) Write(b []byte) (n int, err error) { return } -func (c *conn) readLoop() { +func (c *clientConn) readLoop() { defer c.Close() - var url string - if c.tlsEnabled { - url = fmt.Sprintf("https://%s%s?token=%s", c.addr, c.md.pullPath, c.cid) - } else { - url = fmt.Sprintf("http://%s%s?token=%s", c.addr, c.md.pullPath, c.cid) - } for { err := func() error { - r, err := http.NewRequest(http.MethodGet, url, nil) + r, err := http.NewRequest(http.MethodGet, c.pullURL, nil) if err != nil { return err } + if c.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpRequest(r, false) + c.logger.Debug(string(dump)) + } resp, err := c.client.Do(r) if err != nil { @@ -97,6 +98,11 @@ func (c *conn) readLoop() { } defer resp.Body.Close() + if c.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(resp, false) + c.logger.Debug(string(dump)) + } + if resp.StatusCode != http.StatusOK { return errors.New(resp.Status) } @@ -124,20 +130,15 @@ func (c *conn) readLoop() { } } -func (c *conn) LocalAddr() net.Addr { - return &net.TCPAddr{} +func (c *clientConn) LocalAddr() net.Addr { + return c.localAddr } -func (c *conn) RemoteAddr() net.Addr { - addr, _ := net.ResolveTCPAddr("tcp", c.addr) - if addr == nil { - addr = &net.TCPAddr{} - } - - return addr +func (c *clientConn) RemoteAddr() net.Addr { + return c.remoteAddr } -func (c *conn) Close() error { +func (c *clientConn) Close() error { select { case <-c.closed: default: @@ -146,14 +147,14 @@ func (c *conn) Close() error { return nil } -func (c *conn) SetReadDeadline(t time.Time) error { +func (c *clientConn) SetReadDeadline(t time.Time) error { return nil } -func (c *conn) SetWriteDeadline(t time.Time) error { +func (c *clientConn) SetWriteDeadline(t time.Time) error { return nil } -func (c *conn) SetDeadline(t time.Time) error { +func (c *clientConn) SetDeadline(t time.Time) error { return nil }