diff --git a/handler/forward/local/handler.go b/handler/forward/local/handler.go index 957d55a..f076d2b 100644 --- a/handler/forward/local/handler.go +++ b/handler/forward/local/handler.go @@ -110,7 +110,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand } if _, _, err := net.SplitHostPort(host); err != nil { - host = net.JoinHostPort(host, "0") + host = net.JoinHostPort(strings.Trim(host, "[]"), "0") } var target *chain.Node @@ -202,7 +202,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot host := req.Host if _, _, err := net.SplitHostPort(host); err != nil { - host = net.JoinHostPort(host, "80") + host = net.JoinHostPort(strings.Trim(host, "[]"), "80") } if bp := h.options.Bypass; bp != nil && bp.Contains(ctx, "tcp", host, bypass.WithPathOption(req.RequestURI)) { log.Debugf("bypass: %s %s", host, req.RequestURI) diff --git a/handler/forward/remote/handler.go b/handler/forward/remote/handler.go index dfd2ba2..82fad11 100644 --- a/handler/forward/remote/handler.go +++ b/handler/forward/remote/handler.go @@ -203,7 +203,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot host := req.Host if _, _, err := net.SplitHostPort(host); err != nil { - host = net.JoinHostPort(host, "80") + host = net.JoinHostPort(strings.Trim(host, "[]"), "80") } if bp := h.options.Bypass; bp != nil && bp.Contains(ctx, "tcp", host, bypass.WithPathOption(req.RequestURI)) { log.Debugf("bypass: %s %s", host, req.RequestURI) diff --git a/handler/http/handler.go b/handler/http/handler.go index 392544f..0829274 100644 --- a/handler/http/handler.go +++ b/handler/http/handler.go @@ -136,7 +136,7 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt addr := req.Host if _, port, _ := net.SplitHostPort(addr); port == "" { - addr = net.JoinHostPort(addr, "80") + addr = net.JoinHostPort(strings.Trim(addr, "[]"), "80") } fields := map[string]any{ diff --git a/handler/http2/handler.go b/handler/http2/handler.go index 5aa6e46..def7ad1 100644 --- a/handler/http2/handler.go +++ b/handler/http2/handler.go @@ -132,7 +132,7 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req addr := req.Host if _, port, _ := net.SplitHostPort(addr); port == "" { - addr = net.JoinHostPort(addr, "80") + addr = net.JoinHostPort(strings.Trim(addr, "[]"), "80") } fields := map[string]any{ diff --git a/handler/http3/handler.go b/handler/http3/handler.go index 933b12a..8be8c93 100644 --- a/handler/http3/handler.go +++ b/handler/http3/handler.go @@ -7,6 +7,7 @@ import ( "net" "net/http" "net/http/httputil" + "strings" "time" "github.com/go-gost/core/chain" @@ -88,7 +89,7 @@ func (h *http3Handler) Handle(ctx context.Context, conn net.Conn, opts ...handle func (h *http3Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req *http.Request, log logger.Logger) error { addr := req.Host if _, port, _ := net.SplitHostPort(addr); port == "" { - addr = net.JoinHostPort(addr, "80") + addr = net.JoinHostPort(strings.Trim(addr, "[]"), "80") } if log.IsLevelEnabled(logger.TraceLevel) { diff --git a/handler/redirect/tcp/handler.go b/handler/redirect/tcp/handler.go index 56d5e9b..bd17b64 100644 --- a/handler/redirect/tcp/handler.go +++ b/handler/redirect/tcp/handler.go @@ -152,7 +152,7 @@ func (h *redirectHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, radd host := req.Host if _, _, err := net.SplitHostPort(host); err != nil { - host = net.JoinHostPort(host, "80") + host = net.JoinHostPort(strings.Trim(host, "[]"), "80") } log = log.WithFields(map[string]any{ "host": host, @@ -227,7 +227,7 @@ func (h *redirectHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, rad if port == "" { port = "443" } - host = net.JoinHostPort(host, port) + host = net.JoinHostPort(strings.Trim(host, "[]"), port) } log = log.WithFields(map[string]any{ "host": host, @@ -263,7 +263,7 @@ func (h *redirectHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, rad return nil } -func (h *redirectHandler) getServerName(ctx context.Context, r io.Reader) (host string, err error) { +func (h *redirectHandler) getServerName(_ context.Context, r io.Reader) (host string, err error) { record, err := dissector.ReadRecord(r) if err != nil { return diff --git a/handler/sni/handler.go b/handler/sni/handler.go index d70bf2d..8875ec4 100644 --- a/handler/sni/handler.go +++ b/handler/sni/handler.go @@ -13,6 +13,7 @@ import ( "net" "net/http" "net/http/httputil" + "strings" "time" "github.com/go-gost/core/bypass" @@ -105,7 +106,7 @@ func (h *sniHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, raddr net host := req.Host if _, _, err := net.SplitHostPort(host); err != nil { - host = net.JoinHostPort(host, "80") + host = net.JoinHostPort(strings.Trim(host, "[]"), "80") } log = log.WithFields(map[string]any{ "host": host, @@ -171,7 +172,7 @@ func (h *sniHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, raddr ne } if _, _, err := net.SplitHostPort(host); err != nil { - host = net.JoinHostPort(host, "443") + host = net.JoinHostPort(strings.Trim(host, "[]"), "443") } log = log.WithFields(map[string]any{ diff --git a/handler/tunnel/entrypoint.go b/handler/tunnel/entrypoint.go index a8f3126..b582431 100644 --- a/handler/tunnel/entrypoint.go +++ b/handler/tunnel/entrypoint.go @@ -132,7 +132,7 @@ func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error { host := req.Host if h, _, _ := net.SplitHostPort(host); h == "" { - host = net.JoinHostPort(host, "80") + host = net.JoinHostPort(strings.Trim(host, "[]"), "80") } if node == ep.node { diff --git a/listener/quic/listener.go b/listener/quic/listener.go index 1c8a070..dff8bfa 100644 --- a/listener/quic/listener.go +++ b/listener/quic/listener.go @@ -3,6 +3,7 @@ package quic import ( "context" "net" + "strings" "github.com/go-gost/core/listener" "github.com/go-gost/core/logger" @@ -48,7 +49,7 @@ func (l *quicListener) Init(md md.Metadata) (err error) { addr := l.options.Addr if _, _, err := net.SplitHostPort(addr); err != nil { - addr = net.JoinHostPort(addr, "0") + addr = net.JoinHostPort(strings.Trim(addr, "[]"), "0") } network := "udp" diff --git a/listener/redirect/udp/conn.go b/listener/redirect/udp/conn.go index cae9dad..1111bad 100644 --- a/listener/redirect/udp/conn.go +++ b/listener/redirect/udp/conn.go @@ -16,16 +16,16 @@ type redirConn struct { } func (c *redirConn) Read(b []byte) (n int, err error) { - if c.ttl > 0 { - c.SetReadDeadline(time.Now().Add(c.ttl)) - defer c.SetReadDeadline(time.Time{}) - } - c.once.Do(func() { n = copy(b, c.buf) bufpool.Put(c.buf) }) + if c.ttl > 0 { + c.SetReadDeadline(time.Now().Add(c.ttl)) + defer c.SetReadDeadline(time.Time{}) + } + if n == 0 { n, err = c.Conn.Read(b) } diff --git a/listener/redirect/udp/listener_linux.go b/listener/redirect/udp/listener_linux.go index 3b0cfce..db2504a 100644 --- a/listener/redirect/udp/listener_linux.go +++ b/listener/redirect/udp/listener_linux.go @@ -19,15 +19,30 @@ import ( "golang.org/x/sys/unix" ) +// https://github.com/KatelynHaworth/go-tproxy func (l *redirectListener) listenUDP(addr string) (*net.UDPConn, error) { + laddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + lc := net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { return c.Control(func(fd uintptr) { - if err := unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_TRANSPARENT, 1); err != nil { - l.logger.Errorf("SetsockoptInt(SOL_IP, IP_TRANSPARENT, 1): %v", err) - } - if err := unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_RECVORIGDSTADDR, 1); err != nil { - l.logger.Errorf("SetsockoptInt(SOL_IP, IP_RECVORIGDSTADDR, 1): %v", err) + if laddr.IP.To4() != nil { + if err := unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_TRANSPARENT, 1); err != nil { + l.logger.Errorf("SetsockoptInt(SOL_IP, IP_TRANSPARENT, 1): %v", err) + } + if err := unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_RECVORIGDSTADDR, 1); err != nil { + l.logger.Errorf("SetsockoptInt(SOL_IP, IP_RECVORIGDSTADDR, 1): %v", err) + } + } else { + if err := unix.SetsockoptInt(int(fd), unix.SOL_IPV6, unix.IPV6_TRANSPARENT, 1); err != nil { + l.logger.Errorf("SetsockoptInt(SOL_IPV6, IPV6_TRANSPARENT, 1): %v", err) + } + if err := unix.SetsockoptInt(int(fd), unix.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1); err != nil { + l.logger.Errorf("SetsockoptInt(SOL_IPV6, IPV6_RECVORIGDSTADDR, 1): %v", err) + } } }) }, @@ -83,11 +98,7 @@ func (l *redirectListener) accept() (conn net.Conn, err error) { } } - network := "udp" - if xnet.IsIPv4(l.options.Addr) { - network = "udp4" - } - c, err := dialUDP(network, dstAddr, raddr) + c, err := dialUDP("udp", dstAddr, raddr) if err != nil { l.logger.Error(err) return @@ -128,28 +139,24 @@ func readFromUDP(conn *net.UDPConn, b []byte) (n int, remoteAddr *net.UDPAddr, d return 0, nil, nil, fmt.Errorf("reading original destination address: %v", err) } - switch originalDstRaw.Family { - case unix.AF_INET: - pp := (*unix.RawSockaddrInet4)(unsafe.Pointer(originalDstRaw)) - p := (*[2]byte)(unsafe.Pointer(&pp.Port)) - dstAddr = &net.UDPAddr{ - IP: net.IPv4(pp.Addr[0], pp.Addr[1], pp.Addr[2], pp.Addr[3]), - Port: int(p[0])<<8 + int(p[1]), - } - - case unix.AF_INET6: - pp := (*unix.RawSockaddrInet6)(unsafe.Pointer(originalDstRaw)) - p := (*[2]byte)(unsafe.Pointer(&pp.Port)) - dstAddr = &net.UDPAddr{ - IP: net.IP(pp.Addr[:]), - Port: int(p[0])<<8 + int(p[1]), - Zone: strconv.Itoa(int(pp.Scope_id)), - } - - default: - return 0, nil, nil, fmt.Errorf("original destination is an unsupported network family") + pp := (*unix.RawSockaddrInet4)(unsafe.Pointer(originalDstRaw)) + p := (*[2]byte)(unsafe.Pointer(&pp.Port)) + dstAddr = &net.UDPAddr{ + IP: net.IPv4(pp.Addr[0], pp.Addr[1], pp.Addr[2], pp.Addr[3]), + Port: int(p[0])<<8 + int(p[1]), + } + } else if msg.Header.Level == unix.SOL_IPV6 && msg.Header.Type == unix.IPV6_RECVORIGDSTADDR { + inet6 := &unix.RawSockaddrInet6{} + if err = binary.Read(bytes.NewReader(msg.Data), binary.LittleEndian, inet6); err != nil { + return 0, nil, nil, fmt.Errorf("reading original destination address: %v", err) + } + + p := (*[2]byte)(unsafe.Pointer(&inet6.Port)) + dstAddr = &net.UDPAddr{ + IP: net.IP(inet6.Addr[:]), + Port: int(p[0])<<8 + int(p[1]), + Zone: strconv.Itoa(int(inet6.Scope_id)), } - break } } @@ -179,9 +186,16 @@ func dialUDP(network string, laddr *net.UDPAddr, raddr *net.UDPAddr) (net.Conn, return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("socket open: %v", err)} } - if err = unix.SetsockoptInt(fileDescriptor, unix.SOL_IP, unix.IP_TRANSPARENT, 1); err != nil { - unix.Close(fileDescriptor) - return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("set socket option: IP_TRANSPARENT: %v", err)} + if laddr.IP.To4() != nil { + if err = unix.SetsockoptInt(fileDescriptor, unix.SOL_IP, unix.IP_TRANSPARENT, 1); err != nil { + unix.Close(fileDescriptor) + return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("set socket option: IP_TRANSPARENT: %v", err)} + } + } else { + if err = unix.SetsockoptInt(fileDescriptor, unix.SOL_IPV6, unix.IPV6_TRANSPARENT, 1); err != nil { + unix.Close(fileDescriptor) + return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("set socket option: IPV6_TRANSPARENT: %v", err)} + } } if err = unix.SetsockoptInt(fileDescriptor, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1); err != nil { @@ -230,9 +244,12 @@ func udpAddrToSocketAddr(addr *net.UDPAddr) (unix.Sockaddr, error) { ip := [16]byte{} copy(ip[:], addr.IP.To16()) - zoneID, err := strconv.ParseUint(addr.Zone, 10, 32) - if err != nil { - return nil, err + var zoneID uint64 + if addr.Zone != "" { + zoneID, _ = strconv.ParseUint(addr.Zone, 10, 32) + if zoneID == 0 { + zoneID = 2 + } } return &unix.SockaddrInet6{Addr: ip, Port: addr.Port, ZoneId: uint32(zoneID)}, nil @@ -251,7 +268,7 @@ func udpAddrFamily(net string, laddr, raddr *net.UDPAddr) int { } if (laddr == nil || laddr.IP.To4() != nil) && - (raddr == nil || laddr.IP.To4() != nil) { + (raddr == nil || raddr.IP.To4() != nil) { return unix.AF_INET } return unix.AF_INET6