diff --git a/chain/route.go b/chain/route.go index 8e1de36..2ac3876 100644 --- a/chain/route.go +++ b/chain/route.go @@ -78,7 +78,7 @@ func (*defaultRoute) Bind(ctx context.Context, network, address string, opts ... ReadQueueSize: options.UDPDataQueueSize, ReadBufferSize: options.UDPDataBufferSize, TTL: options.UDPConnTTL, - KeepAlive: true, + Keepalive: true, Logger: logger, }) return ln, err diff --git a/chain/router.go b/chain/router.go index 1c166f3..b6b5aaf 100644 --- a/chain/router.go +++ b/chain/router.go @@ -42,12 +42,6 @@ func (r *Router) Options() *chain.RouterOptions { } func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) { - if r.options.Timeout > 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, r.options.Timeout) - defer cancel() - } - host := address if h, _, _ := net.SplitHostPort(address); h != "" { host = h @@ -93,6 +87,13 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co r.options.Logger.Debugf("dial %s/%s", address, network) for i := 0; i < count; i++ { + ctx := ctx + if r.options.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, r.options.Timeout) + defer cancel() + } + var ipAddr string ipAddr, err = xnet.Resolve(ctx, "ip", address, r.options.Resolver, r.options.HostMapper, r.options.Logger) if err != nil { @@ -133,12 +134,6 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co } func (r *Router) Bind(ctx context.Context, network, address string, opts ...chain.BindOption) (ln net.Listener, err error) { - if r.options.Timeout > 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, r.options.Timeout) - defer cancel() - } - count := r.options.Retries + 1 if count <= 0 { count = 1 @@ -146,6 +141,13 @@ func (r *Router) Bind(ctx context.Context, network, address string, opts ...chai r.options.Logger.Debugf("bind on %s/%s", address, network) for i := 0; i < count; i++ { + ctx := ctx + if r.options.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, r.options.Timeout) + defer cancel() + } + var route chain.Route if r.options.Chain != nil { route = r.options.Chain.Route(ctx, network, address) diff --git a/connector/relay/bind.go b/connector/relay/bind.go index 09a7513..248a0ab 100644 --- a/connector/relay/bind.go +++ b/connector/relay/bind.go @@ -73,7 +73,7 @@ func (c *relayConnector) bindUDP(ctx context.Context, conn net.Conn, network, ad ReadQueueSize: opts.UDPDataQueueSize, ReadBufferSize: opts.UDPDataBufferSize, TTL: opts.UDPConnTTL, - KeepAlive: true, + Keepalive: true, Logger: log, }) diff --git a/connector/socks/v5/bind.go b/connector/socks/v5/bind.go index b2e6d39..3c15898 100644 --- a/connector/socks/v5/bind.go +++ b/connector/socks/v5/bind.go @@ -87,7 +87,7 @@ func (c *socks5Connector) bindUDP(ctx context.Context, conn net.Conn, network, a ReadQueueSize: opts.UDPDataQueueSize, ReadBufferSize: opts.UDPDataBufferSize, TTL: opts.UDPConnTTL, - KeepAlive: true, + Keepalive: true, Logger: log, }) diff --git a/handler/http/handler.go b/handler/http/handler.go index acc6c85..9117ef3 100644 --- a/handler/http/handler.go +++ b/handler/http/handler.go @@ -24,6 +24,7 @@ import ( md "github.com/go-gost/core/metadata" "github.com/go-gost/core/observer/stats" ctxvalue "github.com/go-gost/x/ctx" + xio "github.com/go-gost/x/internal/io" netpkg "github.com/go-gost/x/internal/net" stats_util "github.com/go-gost/x/internal/util/stats" traffic_wrapper "github.com/go-gost/x/limiter/traffic/wrapper" @@ -236,7 +237,7 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt } if req.Method != http.MethodConnect { - return h.handleProxy(rw, cc, req, log) + return h.handleProxy(xio.NewReadWriteCloser(rw, rw, conn), cc, req, log) } resp.StatusCode = http.StatusOK @@ -261,50 +262,92 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt return nil } -func (h *httpHandler) handleProxy(rw, cc io.ReadWriter, req *http.Request, log logger.Logger) (err error) { - req.Header.Del("Proxy-Connection") +func (h *httpHandler) handleProxy(rw io.ReadWriteCloser, cc io.ReadWriter, req *http.Request, log logger.Logger) (err error) { + roundTrip := func(req *http.Request) error { + if req == nil { + return nil + } - if err = req.Write(cc); err != nil { - log.Error(err) - return err - } + resp := &http.Response{ + ProtoMajor: req.ProtoMajor, + ProtoMinor: req.ProtoMinor, + Header: http.Header{}, + StatusCode: http.StatusServiceUnavailable, + } - ch := make(chan error, 1) + // HTTP/1.0 + if req.ProtoMajor == 1 && req.ProtoMinor == 0 { + if strings.ToLower(req.Header.Get("Connection")) == "keep-alive" { + req.Header.Del("Connection") + } else { + req.Header.Set("Connection", "close") + } + } - go func() { - ch <- netpkg.CopyBuffer(rw, cc, 32*1024) - }() + req.Header.Del("Proxy-Connection") - for { - err := func() error { - req, err := http.ReadRequest(bufio.NewReader(rw)) + if err = req.Write(cc); err != nil { + resp.Write(rw) + return err + } + + go func() { + res, err := http.ReadResponse(bufio.NewReader(cc), req) if err != nil { - if err == io.EOF { - return nil - } - return err + h.options.Logger.Errorf("read response: %v", err) + resp.Write(rw) + return } if log.IsLevelEnabled(logger.TraceLevel) { - dump, _ := httputil.DumpRequest(req, false) + dump, _ := httputil.DumpResponse(res, false) log.Trace(string(dump)) } - req.Header.Del("Proxy-Connection") - - if err = req.Write(cc); err != nil { - return err + if res.Close { + defer rw.Close() } - return nil - }() - ch <- err - if err != nil { - break - } + // HTTP/1.0 + if req.ProtoMajor == 1 && req.ProtoMinor == 0 { + if !res.Close { + res.Header.Set("Connection", "keep-alive") + } + res.ProtoMajor = req.ProtoMajor + res.ProtoMinor = req.ProtoMinor + } + + if err = res.Write(rw); err != nil { + rw.Close() + log.Errorf("write response: %v", err) + } + }() + + return nil } - return <-ch + if err = roundTrip(req); err != nil { + return err + } + + for { + req, err := http.ReadRequest(bufio.NewReader(rw)) + if err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { + return nil + } + return err + } + + if log.IsLevelEnabled(logger.TraceLevel) { + dump, _ := httputil.DumpRequest(req, false) + log.Trace(string(dump)) + } + + if err = roundTrip(req); err != nil { + return err + } + } } func (h *httpHandler) decodeServerName(s string) (string, error) { diff --git a/handler/tunnel/entrypoint.go b/handler/tunnel/entrypoint.go index 75cd9ed..17aec06 100644 --- a/handler/tunnel/entrypoint.go +++ b/handler/tunnel/entrypoint.go @@ -83,6 +83,9 @@ func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error { log.Trace(string(dump)) } + resp.ProtoMajor = req.ProtoMajor + resp.ProtoMinor = req.ProtoMinor + var tunnelID relay.TunnelID if ep.ingress != nil { if rule := ep.ingress.GetRule(ctx, req.Host); rule != nil { diff --git a/internal/net/udp/listener.go b/internal/net/udp/listener.go index affeaa4..b7c34f4 100644 --- a/internal/net/udp/listener.go +++ b/internal/net/udp/listener.go @@ -17,17 +17,16 @@ type ListenConfig struct { ReadQueueSize int ReadBufferSize int TTL time.Duration - KeepAlive bool + Keepalive bool Logger logger.Logger } type listener struct { conn net.PacketConn cqueue chan net.Conn connPool *connPool - // mux sync.Mutex - closed chan struct{} - errChan chan error - config *ListenConfig + closed chan struct{} + errChan chan error + config *ListenConfig } func NewListener(conn net.PacketConn, cfg *ListenConfig) net.Listener { @@ -42,9 +41,7 @@ func NewListener(conn net.PacketConn, cfg *ListenConfig) net.Listener { errChan: make(chan error, 1), config: cfg, } - if cfg.KeepAlive { - ln.connPool = newConnPool(cfg.TTL).WithLogger(cfg.Logger) - } + ln.connPool = newConnPool(cfg.TTL).WithLogger(cfg.Logger) go ln.listenLoop() return ln @@ -113,15 +110,12 @@ func (ln *listener) Close() error { } func (ln *listener) getConn(raddr net.Addr) *conn { - // ln.mux.Lock() - // defer ln.mux.Unlock() - c, ok := ln.connPool.Get(raddr.String()) - if ok { + if ok && !c.isClosed() { return c } - c = newConn(ln.conn, ln.Addr(), raddr, ln.config.ReadQueueSize, ln.config.KeepAlive) + c = newConn(ln.conn, ln.Addr(), raddr, ln.config.ReadQueueSize, ln.config.Keepalive) select { case ln.cqueue <- c: ln.connPool.Set(raddr.String(), c) @@ -142,17 +136,17 @@ type conn struct { idle int32 // indicate the connection is idle closed chan struct{} closeMutex sync.Mutex - keepAlive bool + keepalive bool } -func newConn(c net.PacketConn, laddr, remoteAddr net.Addr, queueSize int, keepAlive bool) *conn { +func newConn(c net.PacketConn, laddr, remoteAddr net.Addr, queueSize int, keepalive bool) *conn { return &conn{ PacketConn: c, localAddr: laddr, remoteAddr: remoteAddr, rc: make(chan []byte, queueSize), closed: make(chan struct{}), - keepAlive: keepAlive, + keepalive: keepalive, } } @@ -179,7 +173,7 @@ func (c *conn) Read(b []byte) (n int, err error) { } func (c *conn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - if !c.keepAlive { + if !c.keepalive { defer c.Close() } return c.PacketConn.WriteTo(b, addr) @@ -201,6 +195,15 @@ func (c *conn) Close() error { return nil } +func (c *conn) isClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } +} + func (c *conn) LocalAddr() net.Addr { return c.localAddr } diff --git a/listener/ftcp/listener.go b/listener/ftcp/listener.go index 58dc3c4..6e60e2c 100644 --- a/listener/ftcp/listener.go +++ b/listener/ftcp/listener.go @@ -64,7 +64,7 @@ func (l *ftcpListener) Init(md md.Metadata) (err error) { ReadQueueSize: l.md.readQueueSize, ReadBufferSize: l.md.readBufferSize, TTL: l.md.ttl, - KeepAlive: true, + Keepalive: true, Logger: l.logger, }) return diff --git a/listener/rudp/metadata.go b/listener/rudp/metadata.go index b4b7af0..bdc2f83 100644 --- a/listener/rudp/metadata.go +++ b/listener/rudp/metadata.go @@ -9,7 +9,7 @@ import ( const ( defaultTTL = 5 * time.Second - defaultReadBufferSize = 1024 + defaultReadBufferSize = 8192 defaultReadQueueSize = 1024 defaultBacklog = 128 ) diff --git a/listener/udp/listener.go b/listener/udp/listener.go index d74c7a1..e69a3b6 100644 --- a/listener/udp/listener.go +++ b/listener/udp/listener.go @@ -65,7 +65,7 @@ func (l *udpListener) Init(md md.Metadata) (err error) { Backlog: l.md.backlog, ReadQueueSize: l.md.readQueueSize, ReadBufferSize: l.md.readBufferSize, - KeepAlive: l.md.keepalive, + Keepalive: l.md.keepalive, TTL: l.md.ttl, Logger: l.logger, }) diff --git a/listener/udp/metadata.go b/listener/udp/metadata.go index 6cb7b30..494c9ec 100644 --- a/listener/udp/metadata.go +++ b/listener/udp/metadata.go @@ -9,7 +9,7 @@ import ( const ( defaultTTL = 5 * time.Second - defaultReadBufferSize = 1024 + defaultReadBufferSize = 8192 defaultReadQueueSize = 128 defaultBacklog = 128 )