fix pht connection

This commit is contained in:
ginuerzh 2022-11-02 18:11:50 +08:00
parent 669e80b780
commit dc2fe32a2a
3 changed files with 132 additions and 38 deletions

View File

@ -77,6 +77,8 @@ func (c *Client) authorize(ctx context.Context, addr string) (token string, err
if c.Logger.IsLevelEnabled(logger.TraceLevel) { if c.Logger.IsLevelEnabled(logger.TraceLevel) {
dump, _ := httputil.DumpRequest(r, false) dump, _ := httputil.DumpRequest(r, false)
c.Logger.Trace(string(dump)) c.Logger.Trace(string(dump))
} else if c.Logger.IsLevelEnabled(logger.DebugLevel) {
c.Logger.Debugf("%s %s", r.Method, r.URL)
} }
resp, err := c.Client.Do(r) resp, err := c.Client.Do(r)

View File

@ -5,9 +5,11 @@ import (
"bytes" "bytes"
"encoding/base64" "encoding/base64"
"errors" "errors"
"io"
"net" "net"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"sync"
"time" "time"
"github.com/go-gost/core/logger" "github.com/go-gost/core/logger"
@ -20,6 +22,7 @@ type clientConn struct {
buf []byte buf []byte
rxc chan []byte rxc chan []byte
closed chan struct{} closed chan struct{}
mu sync.Mutex
localAddr net.Addr localAddr net.Addr
remoteAddr net.Addr remoteAddr net.Addr
logger logger.Logger logger logger.Logger
@ -30,7 +33,7 @@ func (c *clientConn) Read(b []byte) (n int, err error) {
select { select {
case c.buf = <-c.rxc: case c.buf = <-c.rxc:
case <-c.closed: case <-c.closed:
err = net.ErrClosed err = io.ErrClosedPipe
return return
} }
} }
@ -45,21 +48,35 @@ func (c *clientConn) Write(b []byte) (n int, err error) {
if len(b) == 0 { if len(b) == 0 {
return return
} }
return c.write(b)
}
func (c *clientConn) write(b []byte) (n int, err error) {
if c.isClosed() {
err = io.ErrClosedPipe
return
}
var r io.Reader
if len(b) > 0 {
buf := bytes.NewBufferString(base64.StdEncoding.EncodeToString(b)) buf := bytes.NewBufferString(base64.StdEncoding.EncodeToString(b))
buf.WriteByte('\n') buf.WriteByte('\n')
r = buf
}
r, err := http.NewRequest(http.MethodPost, c.pushURL, buf) req, err := http.NewRequest(http.MethodPost, c.pushURL, r)
if err != nil { if err != nil {
return return
} }
if c.logger.IsLevelEnabled(logger.TraceLevel) { if c.logger.IsLevelEnabled(logger.TraceLevel) {
dump, _ := httputil.DumpRequest(r, false) dump, _ := httputil.DumpRequest(req, false)
c.logger.Trace(string(dump)) c.logger.Trace(string(dump))
} else if c.logger.IsLevelEnabled(logger.DebugLevel) {
c.logger.Debugf("%s %s", req.Method, req.URL)
} }
resp, err := c.client.Do(r) resp, err := c.client.Do(req)
if err != nil { if err != nil {
return return
} }
@ -80,9 +97,12 @@ func (c *clientConn) Write(b []byte) (n int, err error) {
} }
func (c *clientConn) readLoop() { func (c *clientConn) readLoop() {
defer c.Close()
for { for {
if c.isClosed() {
return
}
done := true
err := func() error { err := func() error {
r, err := http.NewRequest(http.MethodGet, c.pullURL, nil) r, err := http.NewRequest(http.MethodGet, c.pullURL, nil)
if err != nil { if err != nil {
@ -92,6 +112,8 @@ func (c *clientConn) readLoop() {
if c.logger.IsLevelEnabled(logger.TraceLevel) { if c.logger.IsLevelEnabled(logger.TraceLevel) {
dump, _ := httputil.DumpRequest(r, false) dump, _ := httputil.DumpRequest(r, false)
c.logger.Trace(string(dump)) c.logger.Trace(string(dump))
} else if c.logger.IsLevelEnabled(logger.DebugLevel) {
c.logger.Debugf("%s %s", r.Method, r.URL)
} }
resp, err := c.client.Do(r) resp, err := c.client.Do(r)
@ -111,6 +133,11 @@ func (c *clientConn) readLoop() {
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() { for scanner.Scan() {
done = false
if scanner.Text() == "" {
continue
}
b, err := base64.StdEncoding.DecodeString(scanner.Text()) b, err := base64.StdEncoding.DecodeString(scanner.Text())
if err != nil { if err != nil {
return err return err
@ -121,12 +148,15 @@ func (c *clientConn) readLoop() {
return net.ErrClosed return net.ErrClosed
} }
} }
return scanner.Err() return scanner.Err()
}() }()
if err != nil { if err != nil {
c.logger.Error(err) c.Close()
return
}
if done { // server connection closed
return return
} }
} }
@ -141,12 +171,32 @@ func (c *clientConn) RemoteAddr() net.Addr {
} }
func (c *clientConn) Close() error { func (c *clientConn) Close() error {
c.mu.Lock()
select { select {
case <-c.closed: case <-c.closed:
c.mu.Unlock()
return nil
default: default:
close(c.closed) close(c.closed)
} }
return nil
c.mu.Unlock()
_, err := c.write(nil)
return err
}
func (c *clientConn) isClosed() bool {
c.mu.Lock()
defer c.mu.Unlock()
select {
case <-c.closed:
return true
default:
}
return false
} }
func (c *clientConn) SetReadDeadline(t time.Time) error { func (c *clientConn) SetReadDeadline(t time.Time) error {

View File

@ -6,6 +6,7 @@ import (
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
@ -24,6 +25,8 @@ import (
const ( const (
defaultBacklog = 128 defaultBacklog = 128
defaultReadBufferSize = 32 * 1024
defaultReadTimeout = 10 * time.Second
) )
type serverOptions struct { type serverOptions struct {
@ -33,6 +36,8 @@ type serverOptions struct {
backlog int backlog int
tlsEnabled bool tlsEnabled bool
tlsConfig *tls.Config tlsConfig *tls.Config
readBufferSize int
readTimeout time.Duration
logger logger.Logger logger logger.Logger
} }
@ -64,6 +69,18 @@ func EnableTLSServerOption(enable bool) ServerOption {
} }
} }
func ReadBufferSizeServerOption(n int) ServerOption {
return func(opts *serverOptions) {
opts.readBufferSize = n
}
}
func ReadTimeoutServerOption(timeout time.Duration) ServerOption {
return func(opts *serverOptions) {
opts.readTimeout = timeout
}
}
func LoggerServerOption(logger logger.Logger) ServerOption { func LoggerServerOption(logger logger.Logger) ServerOption {
return func(opts *serverOptions) { return func(opts *serverOptions) {
opts.logger = logger opts.logger = logger
@ -90,6 +107,12 @@ func NewServer(addr string, opts ...ServerOption) *Server {
if options.backlog <= 0 { if options.backlog <= 0 {
options.backlog = defaultBacklog options.backlog = defaultBacklog
} }
if options.readBufferSize <= 0 {
options.readBufferSize = defaultReadBufferSize
}
if options.readTimeout <= 0 {
options.readTimeout = defaultReadTimeout
}
s := &Server{ s := &Server{
httpServer: &http.Server{ httpServer: &http.Server{
@ -118,6 +141,12 @@ func NewHTTP3Server(addr string, quicConfig *quic.Config, opts ...ServerOption)
if options.backlog <= 0 { if options.backlog <= 0 {
options.backlog = defaultBacklog options.backlog = defaultBacklog
} }
if options.readBufferSize <= 0 {
options.readBufferSize = defaultReadBufferSize
}
if options.readTimeout <= 0 {
options.readTimeout = defaultReadTimeout
}
s := &Server{ s := &Server{
http3Server: &http3.Server{ http3Server: &http3.Server{
@ -200,6 +229,8 @@ func (s *Server) handleAuthorize(w http.ResponseWriter, r *http.Request) {
if s.options.logger.IsLevelEnabled(logger.TraceLevel) { if s.options.logger.IsLevelEnabled(logger.TraceLevel) {
dump, _ := httputil.DumpRequest(r, false) dump, _ := httputil.DumpRequest(r, false)
s.options.logger.Trace(string(dump)) s.options.logger.Trace(string(dump))
} else if s.options.logger.IsLevelEnabled(logger.DebugLevel) {
s.options.logger.Debugf("%s %s", r.Method, r.RequestURI)
} }
raddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr) raddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
@ -234,6 +265,8 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) {
if s.options.logger.IsLevelEnabled(logger.TraceLevel) { if s.options.logger.IsLevelEnabled(logger.TraceLevel) {
dump, _ := httputil.DumpRequest(r, false) dump, _ := httputil.DumpRequest(r, false)
s.options.logger.Trace(string(dump)) s.options.logger.Trace(string(dump))
} else if s.options.logger.IsLevelEnabled(logger.DebugLevel) {
s.options.logger.Debugf("%s %s", r.Method, r.RequestURI)
} }
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
@ -257,10 +290,12 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) {
br := bufio.NewReader(r.Body) br := bufio.NewReader(r.Body)
data, err := br.ReadString('\n') data, err := br.ReadString('\n')
if err != nil { if err != nil {
if err != io.EOF {
s.options.logger.Error(err) s.options.logger.Error(err)
w.WriteHeader(http.StatusPartialContent)
}
conn.Close() conn.Close()
s.conns.Delete(cid) s.conns.Delete(cid)
w.WriteHeader(http.StatusBadRequest)
return return
} }
@ -293,6 +328,8 @@ func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) {
if s.options.logger.IsLevelEnabled(logger.TraceLevel) { if s.options.logger.IsLevelEnabled(logger.TraceLevel) {
dump, _ := httputil.DumpRequest(r, false) dump, _ := httputil.DumpRequest(r, false)
s.options.logger.Trace(string(dump)) s.options.logger.Trace(string(dump))
} else if s.options.logger.IsLevelEnabled(logger.DebugLevel) {
s.options.logger.Debugf("%s %s", r.Method, r.RequestURI)
} }
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
@ -319,24 +356,13 @@ func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) {
fw.Flush() fw.Flush()
} }
b := bufpool.Get(4096) b := bufpool.Get(s.options.readBufferSize)
defer bufpool.Put(b) defer bufpool.Put(b)
for { for {
conn.SetReadDeadline(time.Now().Add(10 * time.Second)) conn.SetReadDeadline(time.Now().Add(s.options.readTimeout))
n, err := conn.Read(*b) n, err := conn.Read(*b)
if err != nil { if n > 0 {
if !errors.Is(err, os.ErrDeadlineExceeded) {
s.options.logger.Error(err)
s.conns.Delete(cid)
conn.Close()
} else {
(*b)[0] = '\n'
w.Write((*b)[:1])
}
return
}
bw := bufio.NewWriter(w) bw := bufio.NewWriter(w)
bw.WriteString(base64.StdEncoding.EncodeToString((*b)[:n])) bw.WriteString(base64.StdEncoding.EncodeToString((*b)[:n]))
bw.WriteString("\n") bw.WriteString("\n")
@ -347,4 +373,20 @@ func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) {
fw.Flush() fw.Flush()
} }
} }
if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
(*b)[0] = '\n' // no data
w.Write((*b)[:1])
} else if errors.Is(err, io.EOF) {
// server connection closed
} else {
if !errors.Is(err, io.ErrClosedPipe) {
s.options.logger.Error(err)
}
s.conns.Delete(cid)
conn.Close()
}
return
}
}
} }