diff --git a/chain/route.go b/chain/route.go index b20cb47..6c4c309 100644 --- a/chain/route.go +++ b/chain/route.go @@ -19,24 +19,33 @@ var ( ) type Route struct { - chain *Chain - ifceName string - nodes []*Node - logger logger.Logger + chain *Chain + nodes []*Node + logger logger.Logger } func (r *Route) addNode(node *Node) { r.nodes = append(r.nodes, node) } -func (r *Route) Dial(ctx context.Context, network, address string) (net.Conn, error) { +func (r *Route) Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error) { + var options DialOptions + for _, opt := range opts { + opt(&options) + } + if r.Len() == 0 { netd := dialer.NetDialer{ - Timeout: 30 * time.Second, + Timeout: options.Timeout, + Interface: options.Interface, + } + if options.SockOpts != nil { + netd.Mark = options.SockOpts.Mark } if r != nil { - netd.Interface = r.ifceName + netd.Logger = r.logger } + return netd.Dial(ctx, network, address) } @@ -198,3 +207,29 @@ func (r *Route) bindLocal(ctx context.Context, network, address string, opts ... return nil, err } } + +type DialOptions struct { + Timeout time.Duration + Interface string + SockOpts *SockOpts +} + +type DialOption func(opts *DialOptions) + +func TimeoutDialOption(d time.Duration) DialOption { + return func(opts *DialOptions) { + opts.Timeout = d + } +} + +func InterfaceDialOption(ifName string) DialOption { + return func(opts *DialOptions) { + opts.Interface = ifName + } +} + +func SockOptsDialOption(so *SockOpts) DialOption { + return func(opts *DialOptions) { + opts.SockOpts = so + } +} diff --git a/chain/router.go b/chain/router.go index 9a39f2f..32b238e 100644 --- a/chain/router.go +++ b/chain/router.go @@ -13,10 +13,15 @@ import ( "github.com/go-gost/core/resolver" ) +type SockOpts struct { + Mark int +} + type Router struct { + ifceName string + sockOpts *SockOpts timeout time.Duration retries int - ifceName string chain Chainer resolver resolver.Resolver hosts hosts.HostMapper @@ -38,6 +43,11 @@ func (r *Router) WithInterface(ifceName string) *Router { return r } +func (r *Router) WithSockOpts(so *SockOpts) *Router { + r.sockOpts = so + return r +} + func (r *Router) WithChain(chain Chainer) *Router { r.chain = chain return r @@ -109,10 +119,11 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co if route == nil { route = &Route{} } - route.ifceName = r.ifceName route.logger = r.logger - - conn, err = route.Dial(ctx, network, address) + conn, err = route.Dial(ctx, network, address, + InterfaceDialOption(r.ifceName), + SockOptsDialOption(r.sockOpts), + ) if err == nil { break } diff --git a/chain/transport.go b/chain/transport.go index d70827a..0efe5cf 100644 --- a/chain/transport.go +++ b/chain/transport.go @@ -13,6 +13,7 @@ import ( type Transport struct { addr string ifceName string + sockOpts *SockOpts route *Route dialer dialer.Dialer connector connector.Connector @@ -29,6 +30,11 @@ func (tr *Transport) WithInterface(ifceName string) *Transport { return tr } +func (tr *Transport) WithSockOpts(so *SockOpts) *Transport { + tr.sockOpts = so + return tr +} + func (tr *Transport) WithDialer(dialer dialer.Dialer) *Transport { tr.dialer = dialer return tr @@ -42,7 +48,10 @@ func (tr *Transport) WithConnector(connector connector.Connector) *Transport { func (tr *Transport) Dial(ctx context.Context, addr string) (net.Conn, error) { netd := &net_dialer.NetDialer{ Interface: tr.ifceName, - Timeout: 30 * time.Second, + Timeout: 15 * time.Second, + } + if tr.sockOpts != nil { + netd.Mark = tr.sockOpts.Mark } if tr.route.Len() > 0 { netd.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) { diff --git a/common/net/dialer/dialer.go b/common/net/dialer/dialer.go index 32a227b..c4db350 100644 --- a/common/net/dialer/dialer.go +++ b/common/net/dialer/dialer.go @@ -10,14 +10,17 @@ import ( "github.com/go-gost/core/logger" ) +const ( + DefaultTimeout = 15 * time.Second +) + var ( - DefaultNetDialer = &NetDialer{ - Timeout: 30 * time.Second, - } + DefaultNetDialer = &NetDialer{} ) type NetDialer struct { Interface string + Mark int Timeout time.Duration DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) Logger logger.Logger @@ -27,6 +30,10 @@ func (d *NetDialer) Dial(ctx context.Context, network, addr string) (net.Conn, e if d == nil { d = DefaultNetDialer } + if d.Timeout <= 0 { + d.Timeout = DefaultTimeout + } + log := d.Logger if log == nil { log = logger.Default() @@ -39,7 +46,7 @@ func (d *NetDialer) Dial(ctx context.Context, network, addr string) (net.Conn, e if d.DialFunc != nil { return d.DialFunc(ctx, network, addr) } - logger.Default().Infof("interface: %s %v/%s", ifceName, ifAddr, network) + log.Infof("interface: %s %v/%s", ifceName, ifAddr, network) switch network { case "udp", "udp4", "udp6": @@ -59,17 +66,18 @@ func (d *NetDialer) Dial(ctx context.Context, network, addr string) (net.Conn, e Timeout: d.Timeout, LocalAddr: ifAddr, Control: func(network, address string, c syscall.RawConn) error { - var cerr error - err := c.Control(func(fd uintptr) { - cerr = bindDevice(fd, ifceName) + return c.Control(func(fd uintptr) { + if ifceName != "" { + if err := bindDevice(fd, ifceName); err != nil { + log.Warnf("bind device: %v", err) + } + } + if d.Mark != 0 { + if err := setMark(fd, d.Mark); err != nil { + log.Warnf("set mark: %v", err) + } + } }) - if err != nil { - return err - } - if cerr != nil { - return cerr - } - return nil }, } return netd.DialContext(ctx, network, addr) diff --git a/common/net/dialer/dialer_linux.go b/common/net/dialer/dialer_linux.go index cf7f0aa..7cb1cde 100644 --- a/common/net/dialer/dialer_linux.go +++ b/common/net/dialer/dialer_linux.go @@ -12,3 +12,10 @@ func bindDevice(fd uintptr, ifceName string) error { } return unix.BindToDevice(int(fd), ifceName) } + +func setMark(fd uintptr, mark int) error { + if mark == 0 { + return nil + } + return unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, mark) +} diff --git a/common/net/dialer/dialer_other.go b/common/net/dialer/dialer_other.go index 6ad6b21..fff35b4 100644 --- a/common/net/dialer/dialer_other.go +++ b/common/net/dialer/dialer_other.go @@ -5,3 +5,7 @@ package dialer func bindDevice(fd uintptr, ifceName string) error { return nil } + +func setMark(fd uintptr, mark int) error { + return nil +} diff --git a/common/net/relay/relay.go b/common/net/relay/relay.go index c5d8493..af4cebe 100644 --- a/common/net/relay/relay.go +++ b/common/net/relay/relay.go @@ -41,7 +41,7 @@ func (r *UDPRelay) SetBufferSize(n int) { func (r *UDPRelay) Run() (err error) { bufSize := r.bufferSize if bufSize <= 0 { - bufSize = 1024 + bufSize = 1500 } errc := make(chan error, 2) diff --git a/handler/http/handler.go b/handler/http/handler.go index 665e3db..a45fd72 100644 --- a/handler/http/handler.go +++ b/handler/http/handler.go @@ -154,7 +154,7 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt } if network == "udp" { - return h.handleUDP(ctx, conn, network, req.Host, log) + return h.handleUDP(ctx, conn, log) } if req.Method == "PRI" || diff --git a/handler/http/udp.go b/handler/http/udp.go index 8f72260..dba9db5 100644 --- a/handler/http/udp.go +++ b/handler/http/udp.go @@ -13,7 +13,7 @@ import ( "github.com/go-gost/core/logger" ) -func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error { +func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, log logger.Logger) error { log = log.WithFields(map[string]any{ "cmd": "udp", }) diff --git a/handler/socks/v5/metadata.go b/handler/socks/v5/metadata.go index 5479919..283133e 100644 --- a/handler/socks/v5/metadata.go +++ b/handler/socks/v5/metadata.go @@ -34,7 +34,7 @@ func (h *socks5Handler) parseMetadata(md mdata.Metadata) (err error) { if bs := mdata.GetInt(md, udpBufferSize); bs > 0 { h.md.udpBufferSize = int(math.Min(math.Max(float64(bs), 512), 64*1024)) } else { - h.md.udpBufferSize = 1024 + h.md.udpBufferSize = 1500 } h.md.compatibilityMode = mdata.GetBool(md, compatibilityMode) diff --git a/metrics/wrapper/conn.go b/metrics/wrapper/conn.go index ad7ae2f..ab036ad 100644 --- a/metrics/wrapper/conn.go +++ b/metrics/wrapper/conn.go @@ -38,6 +38,15 @@ func (c *serverConn) Write(b []byte) (n int, err error) { return } +func (c *serverConn) SyscallConn() (rc syscall.RawConn, err error) { + if sc, ok := c.Conn.(syscall.Conn); ok { + rc, err = sc.SyscallConn() + return + } + err = errUnsupport + return +} + type packetConn struct { net.PacketConn service string @@ -168,7 +177,7 @@ func (c *udpConn) WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, er } func (c *udpConn) SyscallConn() (rc syscall.RawConn, err error) { - if nc, ok := c.PacketConn.(syscallConn); ok { + if nc, ok := c.PacketConn.(syscall.Conn); ok { return nc.SyscallConn() } err = errUnsupport @@ -189,8 +198,8 @@ type UDPConn interface { readUDP writeUDP setBuffer - syscallConn remoteAddr + syscall.Conn } type setBuffer interface { @@ -208,10 +217,6 @@ type writeUDP interface { WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) } -type syscallConn interface { - SyscallConn() (syscall.RawConn, error) -} - type remoteAddr interface { RemoteAddr() net.Addr }