fix host parsing

This commit is contained in:
ginuerzh 2024-07-10 22:57:49 +08:00
parent f2e32080e4
commit 4a4c64cc66
11 changed files with 76 additions and 56 deletions

View File

@ -110,7 +110,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
} }
if _, _, err := net.SplitHostPort(host); err != nil { if _, _, err := net.SplitHostPort(host); err != nil {
host = net.JoinHostPort(host, "0") host = net.JoinHostPort(strings.Trim(host, "[]"), "0")
} }
var target *chain.Node var target *chain.Node
@ -202,7 +202,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
host := req.Host host := req.Host
if _, _, err := net.SplitHostPort(host); err != nil { if _, _, err := net.SplitHostPort(host); err != nil {
host = net.JoinHostPort(host, "80") host = net.JoinHostPort(strings.Trim(host, "[]"), "80")
} }
if bp := h.options.Bypass; bp != nil && bp.Contains(ctx, "tcp", host, bypass.WithPathOption(req.RequestURI)) { if bp := h.options.Bypass; bp != nil && bp.Contains(ctx, "tcp", host, bypass.WithPathOption(req.RequestURI)) {
log.Debugf("bypass: %s %s", host, req.RequestURI) log.Debugf("bypass: %s %s", host, req.RequestURI)

View File

@ -203,7 +203,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
host := req.Host host := req.Host
if _, _, err := net.SplitHostPort(host); err != nil { if _, _, err := net.SplitHostPort(host); err != nil {
host = net.JoinHostPort(host, "80") host = net.JoinHostPort(strings.Trim(host, "[]"), "80")
} }
if bp := h.options.Bypass; bp != nil && bp.Contains(ctx, "tcp", host, bypass.WithPathOption(req.RequestURI)) { if bp := h.options.Bypass; bp != nil && bp.Contains(ctx, "tcp", host, bypass.WithPathOption(req.RequestURI)) {
log.Debugf("bypass: %s %s", host, req.RequestURI) log.Debugf("bypass: %s %s", host, req.RequestURI)

View File

@ -136,7 +136,7 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
addr := req.Host addr := req.Host
if _, port, _ := net.SplitHostPort(addr); port == "" { if _, port, _ := net.SplitHostPort(addr); port == "" {
addr = net.JoinHostPort(addr, "80") addr = net.JoinHostPort(strings.Trim(addr, "[]"), "80")
} }
fields := map[string]any{ fields := map[string]any{

View File

@ -132,7 +132,7 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req
addr := req.Host addr := req.Host
if _, port, _ := net.SplitHostPort(addr); port == "" { if _, port, _ := net.SplitHostPort(addr); port == "" {
addr = net.JoinHostPort(addr, "80") addr = net.JoinHostPort(strings.Trim(addr, "[]"), "80")
} }
fields := map[string]any{ fields := map[string]any{

View File

@ -7,6 +7,7 @@ import (
"net" "net"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"strings"
"time" "time"
"github.com/go-gost/core/chain" "github.com/go-gost/core/chain"
@ -88,7 +89,7 @@ func (h *http3Handler) Handle(ctx context.Context, conn net.Conn, opts ...handle
func (h *http3Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req *http.Request, log logger.Logger) error { func (h *http3Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req *http.Request, log logger.Logger) error {
addr := req.Host addr := req.Host
if _, port, _ := net.SplitHostPort(addr); port == "" { if _, port, _ := net.SplitHostPort(addr); port == "" {
addr = net.JoinHostPort(addr, "80") addr = net.JoinHostPort(strings.Trim(addr, "[]"), "80")
} }
if log.IsLevelEnabled(logger.TraceLevel) { if log.IsLevelEnabled(logger.TraceLevel) {

View File

@ -152,7 +152,7 @@ func (h *redirectHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, radd
host := req.Host host := req.Host
if _, _, err := net.SplitHostPort(host); err != nil { if _, _, err := net.SplitHostPort(host); err != nil {
host = net.JoinHostPort(host, "80") host = net.JoinHostPort(strings.Trim(host, "[]"), "80")
} }
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{
"host": host, "host": host,
@ -227,7 +227,7 @@ func (h *redirectHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, rad
if port == "" { if port == "" {
port = "443" port = "443"
} }
host = net.JoinHostPort(host, port) host = net.JoinHostPort(strings.Trim(host, "[]"), port)
} }
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{
"host": host, "host": host,
@ -263,7 +263,7 @@ func (h *redirectHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, rad
return nil return nil
} }
func (h *redirectHandler) getServerName(ctx context.Context, r io.Reader) (host string, err error) { func (h *redirectHandler) getServerName(_ context.Context, r io.Reader) (host string, err error) {
record, err := dissector.ReadRecord(r) record, err := dissector.ReadRecord(r)
if err != nil { if err != nil {
return return

View File

@ -13,6 +13,7 @@ import (
"net" "net"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"strings"
"time" "time"
"github.com/go-gost/core/bypass" "github.com/go-gost/core/bypass"
@ -105,7 +106,7 @@ func (h *sniHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, raddr net
host := req.Host host := req.Host
if _, _, err := net.SplitHostPort(host); err != nil { if _, _, err := net.SplitHostPort(host); err != nil {
host = net.JoinHostPort(host, "80") host = net.JoinHostPort(strings.Trim(host, "[]"), "80")
} }
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{
"host": host, "host": host,
@ -171,7 +172,7 @@ func (h *sniHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, raddr ne
} }
if _, _, err := net.SplitHostPort(host); err != nil { if _, _, err := net.SplitHostPort(host); err != nil {
host = net.JoinHostPort(host, "443") host = net.JoinHostPort(strings.Trim(host, "[]"), "443")
} }
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{

View File

@ -132,7 +132,7 @@ func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error {
host := req.Host host := req.Host
if h, _, _ := net.SplitHostPort(host); h == "" { if h, _, _ := net.SplitHostPort(host); h == "" {
host = net.JoinHostPort(host, "80") host = net.JoinHostPort(strings.Trim(host, "[]"), "80")
} }
if node == ep.node { if node == ep.node {

View File

@ -3,6 +3,7 @@ package quic
import ( import (
"context" "context"
"net" "net"
"strings"
"github.com/go-gost/core/listener" "github.com/go-gost/core/listener"
"github.com/go-gost/core/logger" "github.com/go-gost/core/logger"
@ -48,7 +49,7 @@ func (l *quicListener) Init(md md.Metadata) (err error) {
addr := l.options.Addr addr := l.options.Addr
if _, _, err := net.SplitHostPort(addr); err != nil { if _, _, err := net.SplitHostPort(addr); err != nil {
addr = net.JoinHostPort(addr, "0") addr = net.JoinHostPort(strings.Trim(addr, "[]"), "0")
} }
network := "udp" network := "udp"

View File

@ -16,16 +16,16 @@ type redirConn struct {
} }
func (c *redirConn) Read(b []byte) (n int, err error) { func (c *redirConn) Read(b []byte) (n int, err error) {
if c.ttl > 0 {
c.SetReadDeadline(time.Now().Add(c.ttl))
defer c.SetReadDeadline(time.Time{})
}
c.once.Do(func() { c.once.Do(func() {
n = copy(b, c.buf) n = copy(b, c.buf)
bufpool.Put(c.buf) bufpool.Put(c.buf)
}) })
if c.ttl > 0 {
c.SetReadDeadline(time.Now().Add(c.ttl))
defer c.SetReadDeadline(time.Time{})
}
if n == 0 { if n == 0 {
n, err = c.Conn.Read(b) n, err = c.Conn.Read(b)
} }

View File

@ -19,15 +19,30 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// https://github.com/KatelynHaworth/go-tproxy
func (l *redirectListener) listenUDP(addr string) (*net.UDPConn, error) { func (l *redirectListener) listenUDP(addr string) (*net.UDPConn, error) {
laddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
lc := net.ListenConfig{ lc := net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error { Control: func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) { return c.Control(func(fd uintptr) {
if err := unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_TRANSPARENT, 1); err != nil { if laddr.IP.To4() != nil {
l.logger.Errorf("SetsockoptInt(SOL_IP, IP_TRANSPARENT, 1): %v", err) if err := unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_TRANSPARENT, 1); err != nil {
} l.logger.Errorf("SetsockoptInt(SOL_IP, IP_TRANSPARENT, 1): %v", err)
if err := unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_RECVORIGDSTADDR, 1); err != nil { }
l.logger.Errorf("SetsockoptInt(SOL_IP, IP_RECVORIGDSTADDR, 1): %v", err) if err := unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_RECVORIGDSTADDR, 1); err != nil {
l.logger.Errorf("SetsockoptInt(SOL_IP, IP_RECVORIGDSTADDR, 1): %v", err)
}
} else {
if err := unix.SetsockoptInt(int(fd), unix.SOL_IPV6, unix.IPV6_TRANSPARENT, 1); err != nil {
l.logger.Errorf("SetsockoptInt(SOL_IPV6, IPV6_TRANSPARENT, 1): %v", err)
}
if err := unix.SetsockoptInt(int(fd), unix.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1); err != nil {
l.logger.Errorf("SetsockoptInt(SOL_IPV6, IPV6_RECVORIGDSTADDR, 1): %v", err)
}
} }
}) })
}, },
@ -83,11 +98,7 @@ func (l *redirectListener) accept() (conn net.Conn, err error) {
} }
} }
network := "udp" c, err := dialUDP("udp", dstAddr, raddr)
if xnet.IsIPv4(l.options.Addr) {
network = "udp4"
}
c, err := dialUDP(network, dstAddr, raddr)
if err != nil { if err != nil {
l.logger.Error(err) l.logger.Error(err)
return return
@ -128,28 +139,24 @@ func readFromUDP(conn *net.UDPConn, b []byte) (n int, remoteAddr *net.UDPAddr, d
return 0, nil, nil, fmt.Errorf("reading original destination address: %v", err) return 0, nil, nil, fmt.Errorf("reading original destination address: %v", err)
} }
switch originalDstRaw.Family { pp := (*unix.RawSockaddrInet4)(unsafe.Pointer(originalDstRaw))
case unix.AF_INET: p := (*[2]byte)(unsafe.Pointer(&pp.Port))
pp := (*unix.RawSockaddrInet4)(unsafe.Pointer(originalDstRaw)) dstAddr = &net.UDPAddr{
p := (*[2]byte)(unsafe.Pointer(&pp.Port)) IP: net.IPv4(pp.Addr[0], pp.Addr[1], pp.Addr[2], pp.Addr[3]),
dstAddr = &net.UDPAddr{ Port: int(p[0])<<8 + int(p[1]),
IP: net.IPv4(pp.Addr[0], pp.Addr[1], pp.Addr[2], pp.Addr[3]), }
Port: int(p[0])<<8 + int(p[1]), } else if msg.Header.Level == unix.SOL_IPV6 && msg.Header.Type == unix.IPV6_RECVORIGDSTADDR {
} inet6 := &unix.RawSockaddrInet6{}
if err = binary.Read(bytes.NewReader(msg.Data), binary.LittleEndian, inet6); err != nil {
case unix.AF_INET6: return 0, nil, nil, fmt.Errorf("reading original destination address: %v", err)
pp := (*unix.RawSockaddrInet6)(unsafe.Pointer(originalDstRaw)) }
p := (*[2]byte)(unsafe.Pointer(&pp.Port))
dstAddr = &net.UDPAddr{ p := (*[2]byte)(unsafe.Pointer(&inet6.Port))
IP: net.IP(pp.Addr[:]), dstAddr = &net.UDPAddr{
Port: int(p[0])<<8 + int(p[1]), IP: net.IP(inet6.Addr[:]),
Zone: strconv.Itoa(int(pp.Scope_id)), Port: int(p[0])<<8 + int(p[1]),
} Zone: strconv.Itoa(int(inet6.Scope_id)),
default:
return 0, nil, nil, fmt.Errorf("original destination is an unsupported network family")
} }
break
} }
} }
@ -179,9 +186,16 @@ func dialUDP(network string, laddr *net.UDPAddr, raddr *net.UDPAddr) (net.Conn,
return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("socket open: %v", err)} return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("socket open: %v", err)}
} }
if err = unix.SetsockoptInt(fileDescriptor, unix.SOL_IP, unix.IP_TRANSPARENT, 1); err != nil { if laddr.IP.To4() != nil {
unix.Close(fileDescriptor) if err = unix.SetsockoptInt(fileDescriptor, unix.SOL_IP, unix.IP_TRANSPARENT, 1); err != nil {
return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("set socket option: IP_TRANSPARENT: %v", err)} unix.Close(fileDescriptor)
return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("set socket option: IP_TRANSPARENT: %v", err)}
}
} else {
if err = unix.SetsockoptInt(fileDescriptor, unix.SOL_IPV6, unix.IPV6_TRANSPARENT, 1); err != nil {
unix.Close(fileDescriptor)
return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("set socket option: IPV6_TRANSPARENT: %v", err)}
}
} }
if err = unix.SetsockoptInt(fileDescriptor, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1); err != nil { if err = unix.SetsockoptInt(fileDescriptor, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1); err != nil {
@ -230,9 +244,12 @@ func udpAddrToSocketAddr(addr *net.UDPAddr) (unix.Sockaddr, error) {
ip := [16]byte{} ip := [16]byte{}
copy(ip[:], addr.IP.To16()) copy(ip[:], addr.IP.To16())
zoneID, err := strconv.ParseUint(addr.Zone, 10, 32) var zoneID uint64
if err != nil { if addr.Zone != "" {
return nil, err zoneID, _ = strconv.ParseUint(addr.Zone, 10, 32)
if zoneID == 0 {
zoneID = 2
}
} }
return &unix.SockaddrInet6{Addr: ip, Port: addr.Port, ZoneId: uint32(zoneID)}, nil return &unix.SockaddrInet6{Addr: ip, Port: addr.Port, ZoneId: uint32(zoneID)}, nil
@ -251,7 +268,7 @@ func udpAddrFamily(net string, laddr, raddr *net.UDPAddr) int {
} }
if (laddr == nil || laddr.IP.To4() != nil) && if (laddr == nil || laddr.IP.To4() != nil) &&
(raddr == nil || laddr.IP.To4() != nil) { (raddr == nil || raddr.IP.To4() != nil) {
return unix.AF_INET return unix.AF_INET
} }
return unix.AF_INET6 return unix.AF_INET6