Merge branch 'refs/heads/master' into dev
This commit is contained in:
		| @ -78,7 +78,7 @@ func (*defaultRoute) Bind(ctx context.Context, network, address string, opts ... | |||||||
| 			ReadQueueSize:  options.UDPDataQueueSize, | 			ReadQueueSize:  options.UDPDataQueueSize, | ||||||
| 			ReadBufferSize: options.UDPDataBufferSize, | 			ReadBufferSize: options.UDPDataBufferSize, | ||||||
| 			TTL:            options.UDPConnTTL, | 			TTL:            options.UDPConnTTL, | ||||||
| 			KeepAlive:      true, | 			Keepalive:      true, | ||||||
| 			Logger:         logger, | 			Logger:         logger, | ||||||
| 		}) | 		}) | ||||||
| 		return ln, err | 		return ln, err | ||||||
|  | |||||||
| @ -42,12 +42,6 @@ func (r *Router) Options() *chain.RouterOptions { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) { | func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) { | ||||||
| 	if r.options.Timeout > 0 { |  | ||||||
| 		var cancel context.CancelFunc |  | ||||||
| 		ctx, cancel = context.WithTimeout(ctx, r.options.Timeout) |  | ||||||
| 		defer cancel() |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	host := address | 	host := address | ||||||
| 	if h, _, _ := net.SplitHostPort(address); h != "" { | 	if h, _, _ := net.SplitHostPort(address); h != "" { | ||||||
| 		host = h | 		host = h | ||||||
| @ -93,6 +87,13 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co | |||||||
| 	r.options.Logger.Debugf("dial %s/%s", address, network) | 	r.options.Logger.Debugf("dial %s/%s", address, network) | ||||||
|  |  | ||||||
| 	for i := 0; i < count; i++ { | 	for i := 0; i < count; i++ { | ||||||
|  | 		ctx := ctx | ||||||
|  | 		if r.options.Timeout > 0 { | ||||||
|  | 			var cancel context.CancelFunc | ||||||
|  | 			ctx, cancel = context.WithTimeout(ctx, r.options.Timeout) | ||||||
|  | 			defer cancel() | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		var ipAddr string | 		var ipAddr string | ||||||
| 		ipAddr, err = xnet.Resolve(ctx, "ip", address, r.options.Resolver, r.options.HostMapper, r.options.Logger) | 		ipAddr, err = xnet.Resolve(ctx, "ip", address, r.options.Resolver, r.options.HostMapper, r.options.Logger) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| @ -133,12 +134,6 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co | |||||||
| } | } | ||||||
|  |  | ||||||
| func (r *Router) Bind(ctx context.Context, network, address string, opts ...chain.BindOption) (ln net.Listener, err error) { | func (r *Router) Bind(ctx context.Context, network, address string, opts ...chain.BindOption) (ln net.Listener, err error) { | ||||||
| 	if r.options.Timeout > 0 { |  | ||||||
| 		var cancel context.CancelFunc |  | ||||||
| 		ctx, cancel = context.WithTimeout(ctx, r.options.Timeout) |  | ||||||
| 		defer cancel() |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	count := r.options.Retries + 1 | 	count := r.options.Retries + 1 | ||||||
| 	if count <= 0 { | 	if count <= 0 { | ||||||
| 		count = 1 | 		count = 1 | ||||||
| @ -146,6 +141,13 @@ func (r *Router) Bind(ctx context.Context, network, address string, opts ...chai | |||||||
| 	r.options.Logger.Debugf("bind on %s/%s", address, network) | 	r.options.Logger.Debugf("bind on %s/%s", address, network) | ||||||
|  |  | ||||||
| 	for i := 0; i < count; i++ { | 	for i := 0; i < count; i++ { | ||||||
|  | 		ctx := ctx | ||||||
|  | 		if r.options.Timeout > 0 { | ||||||
|  | 			var cancel context.CancelFunc | ||||||
|  | 			ctx, cancel = context.WithTimeout(ctx, r.options.Timeout) | ||||||
|  | 			defer cancel() | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		var route chain.Route | 		var route chain.Route | ||||||
| 		if r.options.Chain != nil { | 		if r.options.Chain != nil { | ||||||
| 			route = r.options.Chain.Route(ctx, network, address) | 			route = r.options.Chain.Route(ctx, network, address) | ||||||
|  | |||||||
| @ -73,7 +73,7 @@ func (c *relayConnector) bindUDP(ctx context.Context, conn net.Conn, network, ad | |||||||
| 			ReadQueueSize:  opts.UDPDataQueueSize, | 			ReadQueueSize:  opts.UDPDataQueueSize, | ||||||
| 			ReadBufferSize: opts.UDPDataBufferSize, | 			ReadBufferSize: opts.UDPDataBufferSize, | ||||||
| 			TTL:            opts.UDPConnTTL, | 			TTL:            opts.UDPConnTTL, | ||||||
| 			KeepAlive:      true, | 			Keepalive:      true, | ||||||
| 			Logger:         log, | 			Logger:         log, | ||||||
| 		}) | 		}) | ||||||
|  |  | ||||||
|  | |||||||
| @ -87,7 +87,7 @@ func (c *socks5Connector) bindUDP(ctx context.Context, conn net.Conn, network, a | |||||||
| 			ReadQueueSize:  opts.UDPDataQueueSize, | 			ReadQueueSize:  opts.UDPDataQueueSize, | ||||||
| 			ReadBufferSize: opts.UDPDataBufferSize, | 			ReadBufferSize: opts.UDPDataBufferSize, | ||||||
| 			TTL:            opts.UDPConnTTL, | 			TTL:            opts.UDPConnTTL, | ||||||
| 			KeepAlive:      true, | 			Keepalive:      true, | ||||||
| 			Logger:         log, | 			Logger:         log, | ||||||
| 		}) | 		}) | ||||||
|  |  | ||||||
|  | |||||||
| @ -24,6 +24,7 @@ import ( | |||||||
| 	md "github.com/go-gost/core/metadata" | 	md "github.com/go-gost/core/metadata" | ||||||
| 	"github.com/go-gost/core/observer/stats" | 	"github.com/go-gost/core/observer/stats" | ||||||
| 	ctxvalue "github.com/go-gost/x/ctx" | 	ctxvalue "github.com/go-gost/x/ctx" | ||||||
|  | 	xio "github.com/go-gost/x/internal/io" | ||||||
| 	netpkg "github.com/go-gost/x/internal/net" | 	netpkg "github.com/go-gost/x/internal/net" | ||||||
| 	stats_util "github.com/go-gost/x/internal/util/stats" | 	stats_util "github.com/go-gost/x/internal/util/stats" | ||||||
| 	traffic_wrapper "github.com/go-gost/x/limiter/traffic/wrapper" | 	traffic_wrapper "github.com/go-gost/x/limiter/traffic/wrapper" | ||||||
| @ -236,7 +237,7 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if req.Method != http.MethodConnect { | 	if req.Method != http.MethodConnect { | ||||||
| 		return h.handleProxy(rw, cc, req, log) | 		return h.handleProxy(xio.NewReadWriteCloser(rw, rw, conn), cc, req, log) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	resp.StatusCode = http.StatusOK | 	resp.StatusCode = http.StatusOK | ||||||
| @ -261,25 +262,78 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (h *httpHandler) handleProxy(rw, cc io.ReadWriter, req *http.Request, log logger.Logger) (err error) { | func (h *httpHandler) handleProxy(rw io.ReadWriteCloser, cc io.ReadWriter, req *http.Request, log logger.Logger) (err error) { | ||||||
|  | 	roundTrip := func(req *http.Request) error { | ||||||
|  | 		if req == nil { | ||||||
|  | 			return nil | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		resp := &http.Response{ | ||||||
|  | 			ProtoMajor: req.ProtoMajor, | ||||||
|  | 			ProtoMinor: req.ProtoMinor, | ||||||
|  | 			Header:     http.Header{}, | ||||||
|  | 			StatusCode: http.StatusServiceUnavailable, | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// HTTP/1.0 | ||||||
|  | 		if req.ProtoMajor == 1 && req.ProtoMinor == 0 { | ||||||
|  | 			if strings.ToLower(req.Header.Get("Connection")) == "keep-alive" { | ||||||
|  | 				req.Header.Del("Connection") | ||||||
|  | 			} else { | ||||||
|  | 				req.Header.Set("Connection", "close") | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		req.Header.Del("Proxy-Connection") | 		req.Header.Del("Proxy-Connection") | ||||||
|  |  | ||||||
| 		if err = req.Write(cc); err != nil { | 		if err = req.Write(cc); err != nil { | ||||||
| 		log.Error(err) | 			resp.Write(rw) | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 	ch := make(chan error, 1) |  | ||||||
|  |  | ||||||
| 		go func() { | 		go func() { | ||||||
| 		ch <- netpkg.CopyBuffer(rw, cc, 32*1024) | 			res, err := http.ReadResponse(bufio.NewReader(cc), req) | ||||||
|  | 			if err != nil { | ||||||
|  | 				h.options.Logger.Errorf("read response: %v", err) | ||||||
|  | 				resp.Write(rw) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			if log.IsLevelEnabled(logger.TraceLevel) { | ||||||
|  | 				dump, _ := httputil.DumpResponse(res, false) | ||||||
|  | 				log.Trace(string(dump)) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			if res.Close { | ||||||
|  | 				defer rw.Close() | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// HTTP/1.0 | ||||||
|  | 			if req.ProtoMajor == 1 && req.ProtoMinor == 0 { | ||||||
|  | 				if !res.Close { | ||||||
|  | 					res.Header.Set("Connection", "keep-alive") | ||||||
|  | 				} | ||||||
|  | 				res.ProtoMajor = req.ProtoMajor | ||||||
|  | 				res.ProtoMinor = req.ProtoMinor | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			if err = res.Write(rw); err != nil { | ||||||
|  | 				rw.Close() | ||||||
|  | 				log.Errorf("write response: %v", err) | ||||||
|  | 			} | ||||||
| 		}() | 		}() | ||||||
|  |  | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err = roundTrip(req); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	for { | 	for { | ||||||
| 		err := func() error { |  | ||||||
| 		req, err := http.ReadRequest(bufio.NewReader(rw)) | 		req, err := http.ReadRequest(bufio.NewReader(rw)) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 				if err == io.EOF { | 			if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { | ||||||
| 				return nil | 				return nil | ||||||
| 			} | 			} | ||||||
| 			return err | 			return err | ||||||
| @ -290,23 +344,12 @@ func (h *httpHandler) handleProxy(rw, cc io.ReadWriter, req *http.Request, log l | |||||||
| 			log.Trace(string(dump)) | 			log.Trace(string(dump)) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 			req.Header.Del("Proxy-Connection") | 		if err = roundTrip(req); err != nil { | ||||||
|  |  | ||||||
| 			if err = req.Write(cc); err != nil { |  | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 			return nil |  | ||||||
| 		}() |  | ||||||
| 		ch <- err |  | ||||||
|  |  | ||||||
| 		if err != nil { |  | ||||||
| 			break |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| 	return <-ch |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (h *httpHandler) decodeServerName(s string) (string, error) { | func (h *httpHandler) decodeServerName(s string) (string, error) { | ||||||
| 	b, err := base64.RawURLEncoding.DecodeString(s) | 	b, err := base64.RawURLEncoding.DecodeString(s) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | |||||||
| @ -83,6 +83,9 @@ func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error { | |||||||
| 				log.Trace(string(dump)) | 				log.Trace(string(dump)) | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
|  | 			resp.ProtoMajor = req.ProtoMajor | ||||||
|  | 			resp.ProtoMinor = req.ProtoMinor | ||||||
|  |  | ||||||
| 			var tunnelID relay.TunnelID | 			var tunnelID relay.TunnelID | ||||||
| 			if ep.ingress != nil { | 			if ep.ingress != nil { | ||||||
| 				if rule := ep.ingress.GetRule(ctx, req.Host); rule != nil { | 				if rule := ep.ingress.GetRule(ctx, req.Host); rule != nil { | ||||||
|  | |||||||
| @ -17,14 +17,13 @@ type ListenConfig struct { | |||||||
| 	ReadQueueSize  int | 	ReadQueueSize  int | ||||||
| 	ReadBufferSize int | 	ReadBufferSize int | ||||||
| 	TTL            time.Duration | 	TTL            time.Duration | ||||||
| 	KeepAlive      bool | 	Keepalive      bool | ||||||
| 	Logger         logger.Logger | 	Logger         logger.Logger | ||||||
| } | } | ||||||
| type listener struct { | type listener struct { | ||||||
| 	conn     net.PacketConn | 	conn     net.PacketConn | ||||||
| 	cqueue   chan net.Conn | 	cqueue   chan net.Conn | ||||||
| 	connPool *connPool | 	connPool *connPool | ||||||
| 	// mux      sync.Mutex |  | ||||||
| 	closed   chan struct{} | 	closed   chan struct{} | ||||||
| 	errChan  chan error | 	errChan  chan error | ||||||
| 	config   *ListenConfig | 	config   *ListenConfig | ||||||
| @ -42,9 +41,7 @@ func NewListener(conn net.PacketConn, cfg *ListenConfig) net.Listener { | |||||||
| 		errChan: make(chan error, 1), | 		errChan: make(chan error, 1), | ||||||
| 		config:  cfg, | 		config:  cfg, | ||||||
| 	} | 	} | ||||||
| 	if cfg.KeepAlive { |  | ||||||
| 	ln.connPool = newConnPool(cfg.TTL).WithLogger(cfg.Logger) | 	ln.connPool = newConnPool(cfg.TTL).WithLogger(cfg.Logger) | ||||||
| 	} |  | ||||||
| 	go ln.listenLoop() | 	go ln.listenLoop() | ||||||
|  |  | ||||||
| 	return ln | 	return ln | ||||||
| @ -113,15 +110,12 @@ func (ln *listener) Close() error { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (ln *listener) getConn(raddr net.Addr) *conn { | func (ln *listener) getConn(raddr net.Addr) *conn { | ||||||
| 	// ln.mux.Lock() |  | ||||||
| 	// defer ln.mux.Unlock() |  | ||||||
|  |  | ||||||
| 	c, ok := ln.connPool.Get(raddr.String()) | 	c, ok := ln.connPool.Get(raddr.String()) | ||||||
| 	if ok { | 	if ok && !c.isClosed() { | ||||||
| 		return c | 		return c | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	c = newConn(ln.conn, ln.Addr(), raddr, ln.config.ReadQueueSize, ln.config.KeepAlive) | 	c = newConn(ln.conn, ln.Addr(), raddr, ln.config.ReadQueueSize, ln.config.Keepalive) | ||||||
| 	select { | 	select { | ||||||
| 	case ln.cqueue <- c: | 	case ln.cqueue <- c: | ||||||
| 		ln.connPool.Set(raddr.String(), c) | 		ln.connPool.Set(raddr.String(), c) | ||||||
| @ -142,17 +136,17 @@ type conn struct { | |||||||
| 	idle       int32       // indicate the connection is idle | 	idle       int32       // indicate the connection is idle | ||||||
| 	closed     chan struct{} | 	closed     chan struct{} | ||||||
| 	closeMutex sync.Mutex | 	closeMutex sync.Mutex | ||||||
| 	keepAlive  bool | 	keepalive  bool | ||||||
| } | } | ||||||
|  |  | ||||||
| func newConn(c net.PacketConn, laddr, remoteAddr net.Addr, queueSize int, keepAlive bool) *conn { | func newConn(c net.PacketConn, laddr, remoteAddr net.Addr, queueSize int, keepalive bool) *conn { | ||||||
| 	return &conn{ | 	return &conn{ | ||||||
| 		PacketConn: c, | 		PacketConn: c, | ||||||
| 		localAddr:  laddr, | 		localAddr:  laddr, | ||||||
| 		remoteAddr: remoteAddr, | 		remoteAddr: remoteAddr, | ||||||
| 		rc:         make(chan []byte, queueSize), | 		rc:         make(chan []byte, queueSize), | ||||||
| 		closed:     make(chan struct{}), | 		closed:     make(chan struct{}), | ||||||
| 		keepAlive:  keepAlive, | 		keepalive:  keepalive, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @ -179,7 +173,7 @@ func (c *conn) Read(b []byte) (n int, err error) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (c *conn) WriteTo(b []byte, addr net.Addr) (n int, err error) { | func (c *conn) WriteTo(b []byte, addr net.Addr) (n int, err error) { | ||||||
| 	if !c.keepAlive { | 	if !c.keepalive { | ||||||
| 		defer c.Close() | 		defer c.Close() | ||||||
| 	} | 	} | ||||||
| 	return c.PacketConn.WriteTo(b, addr) | 	return c.PacketConn.WriteTo(b, addr) | ||||||
| @ -201,6 +195,15 @@ func (c *conn) Close() error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (c *conn) isClosed() bool { | ||||||
|  | 	select { | ||||||
|  | 	case <-c.closed: | ||||||
|  | 		return true | ||||||
|  | 	default: | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func (c *conn) LocalAddr() net.Addr { | func (c *conn) LocalAddr() net.Addr { | ||||||
| 	return c.localAddr | 	return c.localAddr | ||||||
| } | } | ||||||
|  | |||||||
| @ -64,7 +64,7 @@ func (l *ftcpListener) Init(md md.Metadata) (err error) { | |||||||
| 			ReadQueueSize:  l.md.readQueueSize, | 			ReadQueueSize:  l.md.readQueueSize, | ||||||
| 			ReadBufferSize: l.md.readBufferSize, | 			ReadBufferSize: l.md.readBufferSize, | ||||||
| 			TTL:            l.md.ttl, | 			TTL:            l.md.ttl, | ||||||
| 			KeepAlive:      true, | 			Keepalive:      true, | ||||||
| 			Logger:         l.logger, | 			Logger:         l.logger, | ||||||
| 		}) | 		}) | ||||||
| 	return | 	return | ||||||
|  | |||||||
| @ -9,7 +9,7 @@ import ( | |||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	defaultTTL            = 5 * time.Second | 	defaultTTL            = 5 * time.Second | ||||||
| 	defaultReadBufferSize = 1024 | 	defaultReadBufferSize = 8192 | ||||||
| 	defaultReadQueueSize  = 1024 | 	defaultReadQueueSize  = 1024 | ||||||
| 	defaultBacklog        = 128 | 	defaultBacklog        = 128 | ||||||
| ) | ) | ||||||
|  | |||||||
| @ -65,7 +65,7 @@ func (l *udpListener) Init(md md.Metadata) (err error) { | |||||||
| 		Backlog:        l.md.backlog, | 		Backlog:        l.md.backlog, | ||||||
| 		ReadQueueSize:  l.md.readQueueSize, | 		ReadQueueSize:  l.md.readQueueSize, | ||||||
| 		ReadBufferSize: l.md.readBufferSize, | 		ReadBufferSize: l.md.readBufferSize, | ||||||
| 		KeepAlive:      l.md.keepalive, | 		Keepalive:      l.md.keepalive, | ||||||
| 		TTL:            l.md.ttl, | 		TTL:            l.md.ttl, | ||||||
| 		Logger:         l.logger, | 		Logger:         l.logger, | ||||||
| 	}) | 	}) | ||||||
|  | |||||||
| @ -9,7 +9,7 @@ import ( | |||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	defaultTTL            = 5 * time.Second | 	defaultTTL            = 5 * time.Second | ||||||
| 	defaultReadBufferSize = 1024 | 	defaultReadBufferSize = 8192 | ||||||
| 	defaultReadQueueSize  = 128 | 	defaultReadQueueSize  = 128 | ||||||
| 	defaultBacklog        = 128 | 	defaultBacklog        = 128 | ||||||
| ) | ) | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user