diff --git a/internal/util/pht/client.go b/internal/util/pht/client.go index 2e590f0..c201475 100644 --- a/internal/util/pht/client.go +++ b/internal/util/pht/client.go @@ -77,6 +77,8 @@ func (c *Client) authorize(ctx context.Context, addr string) (token string, err if c.Logger.IsLevelEnabled(logger.TraceLevel) { dump, _ := httputil.DumpRequest(r, false) c.Logger.Trace(string(dump)) + } else if c.Logger.IsLevelEnabled(logger.DebugLevel) { + c.Logger.Debugf("%s %s", r.Method, r.URL) } resp, err := c.Client.Do(r) diff --git a/internal/util/pht/conn.go b/internal/util/pht/conn.go index fdd2502..7202850 100644 --- a/internal/util/pht/conn.go +++ b/internal/util/pht/conn.go @@ -5,9 +5,11 @@ import ( "bytes" "encoding/base64" "errors" + "io" "net" "net/http" "net/http/httputil" + "sync" "time" "github.com/go-gost/core/logger" @@ -20,6 +22,7 @@ type clientConn struct { buf []byte rxc chan []byte closed chan struct{} + mu sync.Mutex localAddr net.Addr remoteAddr net.Addr logger logger.Logger @@ -30,7 +33,7 @@ func (c *clientConn) Read(b []byte) (n int, err error) { select { case c.buf = <-c.rxc: case <-c.closed: - err = net.ErrClosed + err = io.ErrClosedPipe return } } @@ -45,21 +48,35 @@ func (c *clientConn) Write(b []byte) (n int, err error) { if len(b) == 0 { return } + return c.write(b) +} - buf := bytes.NewBufferString(base64.StdEncoding.EncodeToString(b)) - buf.WriteByte('\n') +func (c *clientConn) write(b []byte) (n int, err error) { + if c.isClosed() { + err = io.ErrClosedPipe + return + } - r, err := http.NewRequest(http.MethodPost, c.pushURL, buf) + var r io.Reader + if len(b) > 0 { + buf := bytes.NewBufferString(base64.StdEncoding.EncodeToString(b)) + buf.WriteByte('\n') + r = buf + } + + req, err := http.NewRequest(http.MethodPost, c.pushURL, r) if err != nil { return } if c.logger.IsLevelEnabled(logger.TraceLevel) { - dump, _ := httputil.DumpRequest(r, false) + dump, _ := httputil.DumpRequest(req, false) c.logger.Trace(string(dump)) + } else if c.logger.IsLevelEnabled(logger.DebugLevel) { + c.logger.Debugf("%s %s", req.Method, req.URL) } - resp, err := c.client.Do(r) + resp, err := c.client.Do(req) if err != nil { return } @@ -80,9 +97,12 @@ func (c *clientConn) Write(b []byte) (n int, err error) { } func (c *clientConn) readLoop() { - defer c.Close() - for { + if c.isClosed() { + return + } + + done := true err := func() error { r, err := http.NewRequest(http.MethodGet, c.pullURL, nil) if err != nil { @@ -92,6 +112,8 @@ func (c *clientConn) readLoop() { if c.logger.IsLevelEnabled(logger.TraceLevel) { dump, _ := httputil.DumpRequest(r, false) c.logger.Trace(string(dump)) + } else if c.logger.IsLevelEnabled(logger.DebugLevel) { + c.logger.Debugf("%s %s", r.Method, r.URL) } resp, err := c.client.Do(r) @@ -111,6 +133,11 @@ func (c *clientConn) readLoop() { scanner := bufio.NewScanner(resp.Body) for scanner.Scan() { + done = false + if scanner.Text() == "" { + continue + } + b, err := base64.StdEncoding.DecodeString(scanner.Text()) if err != nil { return err @@ -121,12 +148,15 @@ func (c *clientConn) readLoop() { return net.ErrClosed } } - return scanner.Err() }() if err != nil { - c.logger.Error(err) + c.Close() + return + } + + if done { // server connection closed return } } @@ -141,12 +171,32 @@ func (c *clientConn) RemoteAddr() net.Addr { } func (c *clientConn) Close() error { + c.mu.Lock() + select { case <-c.closed: + c.mu.Unlock() + return nil default: close(c.closed) } - return nil + + c.mu.Unlock() + + _, err := c.write(nil) + + return err +} + +func (c *clientConn) isClosed() bool { + c.mu.Lock() + defer c.mu.Unlock() + select { + case <-c.closed: + return true + default: + } + return false } func (c *clientConn) SetReadDeadline(t time.Time) error { diff --git a/internal/util/pht/server.go b/internal/util/pht/server.go index edde7ce..3394a69 100644 --- a/internal/util/pht/server.go +++ b/internal/util/pht/server.go @@ -6,6 +6,7 @@ import ( "encoding/base64" "errors" "fmt" + "io" "net" "net/http" "net/http/httputil" @@ -23,17 +24,21 @@ import ( ) const ( - defaultBacklog = 128 + defaultBacklog = 128 + defaultReadBufferSize = 32 * 1024 + defaultReadTimeout = 10 * time.Second ) type serverOptions struct { - authorizePath string - pushPath string - pullPath string - backlog int - tlsEnabled bool - tlsConfig *tls.Config - logger logger.Logger + authorizePath string + pushPath string + pullPath string + backlog int + tlsEnabled bool + tlsConfig *tls.Config + readBufferSize int + readTimeout time.Duration + logger logger.Logger } type ServerOption func(opts *serverOptions) @@ -64,6 +69,18 @@ func EnableTLSServerOption(enable bool) ServerOption { } } +func ReadBufferSizeServerOption(n int) ServerOption { + return func(opts *serverOptions) { + opts.readBufferSize = n + } +} + +func ReadTimeoutServerOption(timeout time.Duration) ServerOption { + return func(opts *serverOptions) { + opts.readTimeout = timeout + } +} + func LoggerServerOption(logger logger.Logger) ServerOption { return func(opts *serverOptions) { opts.logger = logger @@ -90,6 +107,12 @@ func NewServer(addr string, opts ...ServerOption) *Server { if options.backlog <= 0 { options.backlog = defaultBacklog } + if options.readBufferSize <= 0 { + options.readBufferSize = defaultReadBufferSize + } + if options.readTimeout <= 0 { + options.readTimeout = defaultReadTimeout + } s := &Server{ httpServer: &http.Server{ @@ -118,6 +141,12 @@ func NewHTTP3Server(addr string, quicConfig *quic.Config, opts ...ServerOption) if options.backlog <= 0 { options.backlog = defaultBacklog } + if options.readBufferSize <= 0 { + options.readBufferSize = defaultReadBufferSize + } + if options.readTimeout <= 0 { + options.readTimeout = defaultReadTimeout + } s := &Server{ http3Server: &http3.Server{ @@ -200,6 +229,8 @@ func (s *Server) handleAuthorize(w http.ResponseWriter, r *http.Request) { if s.options.logger.IsLevelEnabled(logger.TraceLevel) { dump, _ := httputil.DumpRequest(r, false) s.options.logger.Trace(string(dump)) + } else if s.options.logger.IsLevelEnabled(logger.DebugLevel) { + s.options.logger.Debugf("%s %s", r.Method, r.RequestURI) } raddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr) @@ -234,6 +265,8 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) { if s.options.logger.IsLevelEnabled(logger.TraceLevel) { dump, _ := httputil.DumpRequest(r, false) s.options.logger.Trace(string(dump)) + } else if s.options.logger.IsLevelEnabled(logger.DebugLevel) { + s.options.logger.Debugf("%s %s", r.Method, r.RequestURI) } if r.Method != http.MethodPost { @@ -257,10 +290,12 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) { br := bufio.NewReader(r.Body) data, err := br.ReadString('\n') if err != nil { - s.options.logger.Error(err) + if err != io.EOF { + s.options.logger.Error(err) + w.WriteHeader(http.StatusPartialContent) + } conn.Close() s.conns.Delete(cid) - w.WriteHeader(http.StatusBadRequest) return } @@ -293,6 +328,8 @@ func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) { if s.options.logger.IsLevelEnabled(logger.TraceLevel) { dump, _ := httputil.DumpRequest(r, false) s.options.logger.Trace(string(dump)) + } else if s.options.logger.IsLevelEnabled(logger.DebugLevel) { + s.options.logger.Debugf("%s %s", r.Method, r.RequestURI) } if r.Method != http.MethodGet { @@ -319,32 +356,37 @@ func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) { fw.Flush() } - b := bufpool.Get(4096) + b := bufpool.Get(s.options.readBufferSize) defer bufpool.Put(b) for { - conn.SetReadDeadline(time.Now().Add(10 * time.Second)) + conn.SetReadDeadline(time.Now().Add(s.options.readTimeout)) n, err := conn.Read(*b) + if n > 0 { + bw := bufio.NewWriter(w) + bw.WriteString(base64.StdEncoding.EncodeToString((*b)[:n])) + bw.WriteString("\n") + if err := bw.Flush(); err != nil { + return + } + if fw, ok := w.(http.Flusher); ok { + fw.Flush() + } + } if err != nil { - if !errors.Is(err, os.ErrDeadlineExceeded) { - s.options.logger.Error(err) + if errors.Is(err, os.ErrDeadlineExceeded) { + (*b)[0] = '\n' // no data + w.Write((*b)[:1]) + } else if errors.Is(err, io.EOF) { + // server connection closed + } else { + if !errors.Is(err, io.ErrClosedPipe) { + s.options.logger.Error(err) + } s.conns.Delete(cid) conn.Close() - } else { - (*b)[0] = '\n' - w.Write((*b)[:1]) } return } - - bw := bufio.NewWriter(w) - bw.WriteString(base64.StdEncoding.EncodeToString((*b)[:n])) - bw.WriteString("\n") - if err := bw.Flush(); err != nil { - return - } - if fw, ok := w.(http.Flusher); ok { - fw.Flush() - } } }