fix host parsing
This commit is contained in:
		| @ -110,7 +110,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if _, _, err := net.SplitHostPort(host); err != nil { | 	if _, _, err := net.SplitHostPort(host); err != nil { | ||||||
| 		host = net.JoinHostPort(host, "0") | 		host = net.JoinHostPort(strings.Trim(host, "[]"), "0") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	var target *chain.Node | 	var target *chain.Node | ||||||
| @ -202,7 +202,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot | |||||||
|  |  | ||||||
| 			host := req.Host | 			host := req.Host | ||||||
| 			if _, _, err := net.SplitHostPort(host); err != nil { | 			if _, _, err := net.SplitHostPort(host); err != nil { | ||||||
| 				host = net.JoinHostPort(host, "80") | 				host = net.JoinHostPort(strings.Trim(host, "[]"), "80") | ||||||
| 			} | 			} | ||||||
| 			if bp := h.options.Bypass; bp != nil && bp.Contains(ctx, "tcp", host, bypass.WithPathOption(req.RequestURI)) { | 			if bp := h.options.Bypass; bp != nil && bp.Contains(ctx, "tcp", host, bypass.WithPathOption(req.RequestURI)) { | ||||||
| 				log.Debugf("bypass: %s %s", host, req.RequestURI) | 				log.Debugf("bypass: %s %s", host, req.RequestURI) | ||||||
|  | |||||||
| @ -203,7 +203,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot | |||||||
|  |  | ||||||
| 			host := req.Host | 			host := req.Host | ||||||
| 			if _, _, err := net.SplitHostPort(host); err != nil { | 			if _, _, err := net.SplitHostPort(host); err != nil { | ||||||
| 				host = net.JoinHostPort(host, "80") | 				host = net.JoinHostPort(strings.Trim(host, "[]"), "80") | ||||||
| 			} | 			} | ||||||
| 			if bp := h.options.Bypass; bp != nil && bp.Contains(ctx, "tcp", host, bypass.WithPathOption(req.RequestURI)) { | 			if bp := h.options.Bypass; bp != nil && bp.Contains(ctx, "tcp", host, bypass.WithPathOption(req.RequestURI)) { | ||||||
| 				log.Debugf("bypass: %s %s", host, req.RequestURI) | 				log.Debugf("bypass: %s %s", host, req.RequestURI) | ||||||
|  | |||||||
| @ -136,7 +136,7 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt | |||||||
|  |  | ||||||
| 	addr := req.Host | 	addr := req.Host | ||||||
| 	if _, port, _ := net.SplitHostPort(addr); port == "" { | 	if _, port, _ := net.SplitHostPort(addr); port == "" { | ||||||
| 		addr = net.JoinHostPort(addr, "80") | 		addr = net.JoinHostPort(strings.Trim(addr, "[]"), "80") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	fields := map[string]any{ | 	fields := map[string]any{ | ||||||
|  | |||||||
| @ -132,7 +132,7 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req | |||||||
|  |  | ||||||
| 	addr := req.Host | 	addr := req.Host | ||||||
| 	if _, port, _ := net.SplitHostPort(addr); port == "" { | 	if _, port, _ := net.SplitHostPort(addr); port == "" { | ||||||
| 		addr = net.JoinHostPort(addr, "80") | 		addr = net.JoinHostPort(strings.Trim(addr, "[]"), "80") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	fields := map[string]any{ | 	fields := map[string]any{ | ||||||
|  | |||||||
| @ -7,6 +7,7 @@ import ( | |||||||
| 	"net" | 	"net" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httputil" | 	"net/http/httputil" | ||||||
|  | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/go-gost/core/chain" | 	"github.com/go-gost/core/chain" | ||||||
| @ -88,7 +89,7 @@ func (h *http3Handler) Handle(ctx context.Context, conn net.Conn, opts ...handle | |||||||
| func (h *http3Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req *http.Request, log logger.Logger) error { | func (h *http3Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req *http.Request, log logger.Logger) error { | ||||||
| 	addr := req.Host | 	addr := req.Host | ||||||
| 	if _, port, _ := net.SplitHostPort(addr); port == "" { | 	if _, port, _ := net.SplitHostPort(addr); port == "" { | ||||||
| 		addr = net.JoinHostPort(addr, "80") | 		addr = net.JoinHostPort(strings.Trim(addr, "[]"), "80") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if log.IsLevelEnabled(logger.TraceLevel) { | 	if log.IsLevelEnabled(logger.TraceLevel) { | ||||||
|  | |||||||
| @ -152,7 +152,7 @@ func (h *redirectHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, radd | |||||||
|  |  | ||||||
| 	host := req.Host | 	host := req.Host | ||||||
| 	if _, _, err := net.SplitHostPort(host); err != nil { | 	if _, _, err := net.SplitHostPort(host); err != nil { | ||||||
| 		host = net.JoinHostPort(host, "80") | 		host = net.JoinHostPort(strings.Trim(host, "[]"), "80") | ||||||
| 	} | 	} | ||||||
| 	log = log.WithFields(map[string]any{ | 	log = log.WithFields(map[string]any{ | ||||||
| 		"host": host, | 		"host": host, | ||||||
| @ -227,7 +227,7 @@ func (h *redirectHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, rad | |||||||
| 			if port == "" { | 			if port == "" { | ||||||
| 				port = "443" | 				port = "443" | ||||||
| 			} | 			} | ||||||
| 			host = net.JoinHostPort(host, port) | 			host = net.JoinHostPort(strings.Trim(host, "[]"), port) | ||||||
| 		} | 		} | ||||||
| 		log = log.WithFields(map[string]any{ | 		log = log.WithFields(map[string]any{ | ||||||
| 			"host": host, | 			"host": host, | ||||||
| @ -263,7 +263,7 @@ func (h *redirectHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, rad | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (h *redirectHandler) getServerName(ctx context.Context, r io.Reader) (host string, err error) { | func (h *redirectHandler) getServerName(_ context.Context, r io.Reader) (host string, err error) { | ||||||
| 	record, err := dissector.ReadRecord(r) | 	record, err := dissector.ReadRecord(r) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return | ||||||
|  | |||||||
| @ -13,6 +13,7 @@ import ( | |||||||
| 	"net" | 	"net" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httputil" | 	"net/http/httputil" | ||||||
|  | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/go-gost/core/bypass" | 	"github.com/go-gost/core/bypass" | ||||||
| @ -105,7 +106,7 @@ func (h *sniHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, raddr net | |||||||
|  |  | ||||||
| 	host := req.Host | 	host := req.Host | ||||||
| 	if _, _, err := net.SplitHostPort(host); err != nil { | 	if _, _, err := net.SplitHostPort(host); err != nil { | ||||||
| 		host = net.JoinHostPort(host, "80") | 		host = net.JoinHostPort(strings.Trim(host, "[]"), "80") | ||||||
| 	} | 	} | ||||||
| 	log = log.WithFields(map[string]any{ | 	log = log.WithFields(map[string]any{ | ||||||
| 		"host": host, | 		"host": host, | ||||||
| @ -171,7 +172,7 @@ func (h *sniHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, raddr ne | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if _, _, err := net.SplitHostPort(host); err != nil { | 	if _, _, err := net.SplitHostPort(host); err != nil { | ||||||
| 		host = net.JoinHostPort(host, "443") | 		host = net.JoinHostPort(strings.Trim(host, "[]"), "443") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log = log.WithFields(map[string]any{ | 	log = log.WithFields(map[string]any{ | ||||||
|  | |||||||
| @ -132,7 +132,7 @@ func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error { | |||||||
|  |  | ||||||
| 			host := req.Host | 			host := req.Host | ||||||
| 			if h, _, _ := net.SplitHostPort(host); h == "" { | 			if h, _, _ := net.SplitHostPort(host); h == "" { | ||||||
| 				host = net.JoinHostPort(host, "80") | 				host = net.JoinHostPort(strings.Trim(host, "[]"), "80") | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			if node == ep.node { | 			if node == ep.node { | ||||||
|  | |||||||
| @ -3,6 +3,7 @@ package quic | |||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"net" | 	"net" | ||||||
|  | 	"strings" | ||||||
|  |  | ||||||
| 	"github.com/go-gost/core/listener" | 	"github.com/go-gost/core/listener" | ||||||
| 	"github.com/go-gost/core/logger" | 	"github.com/go-gost/core/logger" | ||||||
| @ -48,7 +49,7 @@ func (l *quicListener) Init(md md.Metadata) (err error) { | |||||||
|  |  | ||||||
| 	addr := l.options.Addr | 	addr := l.options.Addr | ||||||
| 	if _, _, err := net.SplitHostPort(addr); err != nil { | 	if _, _, err := net.SplitHostPort(addr); err != nil { | ||||||
| 		addr = net.JoinHostPort(addr, "0") | 		addr = net.JoinHostPort(strings.Trim(addr, "[]"), "0") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	network := "udp" | 	network := "udp" | ||||||
|  | |||||||
| @ -16,16 +16,16 @@ type redirConn struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (c *redirConn) Read(b []byte) (n int, err error) { | func (c *redirConn) Read(b []byte) (n int, err error) { | ||||||
| 	if c.ttl > 0 { |  | ||||||
| 		c.SetReadDeadline(time.Now().Add(c.ttl)) |  | ||||||
| 		defer c.SetReadDeadline(time.Time{}) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	c.once.Do(func() { | 	c.once.Do(func() { | ||||||
| 		n = copy(b, c.buf) | 		n = copy(b, c.buf) | ||||||
| 		bufpool.Put(c.buf) | 		bufpool.Put(c.buf) | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
|  | 	if c.ttl > 0 { | ||||||
|  | 		c.SetReadDeadline(time.Now().Add(c.ttl)) | ||||||
|  | 		defer c.SetReadDeadline(time.Time{}) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	if n == 0 { | 	if n == 0 { | ||||||
| 		n, err = c.Conn.Read(b) | 		n, err = c.Conn.Read(b) | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -19,15 +19,30 @@ import ( | |||||||
| 	"golang.org/x/sys/unix" | 	"golang.org/x/sys/unix" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | // https://github.com/KatelynHaworth/go-tproxy | ||||||
| func (l *redirectListener) listenUDP(addr string) (*net.UDPConn, error) { | func (l *redirectListener) listenUDP(addr string) (*net.UDPConn, error) { | ||||||
|  | 	laddr, err := net.ResolveUDPAddr("udp", addr) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	lc := net.ListenConfig{ | 	lc := net.ListenConfig{ | ||||||
| 		Control: func(network, address string, c syscall.RawConn) error { | 		Control: func(network, address string, c syscall.RawConn) error { | ||||||
| 			return c.Control(func(fd uintptr) { | 			return c.Control(func(fd uintptr) { | ||||||
| 				if err := unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_TRANSPARENT, 1); err != nil { | 				if laddr.IP.To4() != nil { | ||||||
| 					l.logger.Errorf("SetsockoptInt(SOL_IP, IP_TRANSPARENT, 1): %v", err) | 					if err := unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_TRANSPARENT, 1); err != nil { | ||||||
| 				} | 						l.logger.Errorf("SetsockoptInt(SOL_IP, IP_TRANSPARENT, 1): %v", err) | ||||||
| 				if err := unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_RECVORIGDSTADDR, 1); err != nil { | 					} | ||||||
| 					l.logger.Errorf("SetsockoptInt(SOL_IP, IP_RECVORIGDSTADDR, 1): %v", err) | 					if err := unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_RECVORIGDSTADDR, 1); err != nil { | ||||||
|  | 						l.logger.Errorf("SetsockoptInt(SOL_IP, IP_RECVORIGDSTADDR, 1): %v", err) | ||||||
|  | 					} | ||||||
|  | 				} else { | ||||||
|  | 					if err := unix.SetsockoptInt(int(fd), unix.SOL_IPV6, unix.IPV6_TRANSPARENT, 1); err != nil { | ||||||
|  | 						l.logger.Errorf("SetsockoptInt(SOL_IPV6, IPV6_TRANSPARENT, 1): %v", err) | ||||||
|  | 					} | ||||||
|  | 					if err := unix.SetsockoptInt(int(fd), unix.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1); err != nil { | ||||||
|  | 						l.logger.Errorf("SetsockoptInt(SOL_IPV6, IPV6_RECVORIGDSTADDR, 1): %v", err) | ||||||
|  | 					} | ||||||
| 				} | 				} | ||||||
| 			}) | 			}) | ||||||
| 		}, | 		}, | ||||||
| @ -83,11 +98,7 @@ func (l *redirectListener) accept() (conn net.Conn, err error) { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	network := "udp" | 	c, err := dialUDP("udp", dstAddr, raddr) | ||||||
| 	if xnet.IsIPv4(l.options.Addr) { |  | ||||||
| 		network = "udp4" |  | ||||||
| 	} |  | ||||||
| 	c, err := dialUDP(network, dstAddr, raddr) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		l.logger.Error(err) | 		l.logger.Error(err) | ||||||
| 		return | 		return | ||||||
| @ -128,28 +139,24 @@ func readFromUDP(conn *net.UDPConn, b []byte) (n int, remoteAddr *net.UDPAddr, d | |||||||
| 				return 0, nil, nil, fmt.Errorf("reading original destination address: %v", err) | 				return 0, nil, nil, fmt.Errorf("reading original destination address: %v", err) | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			switch originalDstRaw.Family { | 			pp := (*unix.RawSockaddrInet4)(unsafe.Pointer(originalDstRaw)) | ||||||
| 			case unix.AF_INET: | 			p := (*[2]byte)(unsafe.Pointer(&pp.Port)) | ||||||
| 				pp := (*unix.RawSockaddrInet4)(unsafe.Pointer(originalDstRaw)) | 			dstAddr = &net.UDPAddr{ | ||||||
| 				p := (*[2]byte)(unsafe.Pointer(&pp.Port)) | 				IP:   net.IPv4(pp.Addr[0], pp.Addr[1], pp.Addr[2], pp.Addr[3]), | ||||||
| 				dstAddr = &net.UDPAddr{ | 				Port: int(p[0])<<8 + int(p[1]), | ||||||
| 					IP:   net.IPv4(pp.Addr[0], pp.Addr[1], pp.Addr[2], pp.Addr[3]), | 			} | ||||||
| 					Port: int(p[0])<<8 + int(p[1]), | 		} else if msg.Header.Level == unix.SOL_IPV6 && msg.Header.Type == unix.IPV6_RECVORIGDSTADDR { | ||||||
| 				} | 			inet6 := &unix.RawSockaddrInet6{} | ||||||
|  | 			if err = binary.Read(bytes.NewReader(msg.Data), binary.LittleEndian, inet6); err != nil { | ||||||
| 			case unix.AF_INET6: | 				return 0, nil, nil, fmt.Errorf("reading original destination address: %v", err) | ||||||
| 				pp := (*unix.RawSockaddrInet6)(unsafe.Pointer(originalDstRaw)) | 			} | ||||||
| 				p := (*[2]byte)(unsafe.Pointer(&pp.Port)) |  | ||||||
| 				dstAddr = &net.UDPAddr{ | 			p := (*[2]byte)(unsafe.Pointer(&inet6.Port)) | ||||||
| 					IP:   net.IP(pp.Addr[:]), | 			dstAddr = &net.UDPAddr{ | ||||||
| 					Port: int(p[0])<<8 + int(p[1]), | 				IP:   net.IP(inet6.Addr[:]), | ||||||
| 					Zone: strconv.Itoa(int(pp.Scope_id)), | 				Port: int(p[0])<<8 + int(p[1]), | ||||||
| 				} | 				Zone: strconv.Itoa(int(inet6.Scope_id)), | ||||||
|  |  | ||||||
| 			default: |  | ||||||
| 				return 0, nil, nil, fmt.Errorf("original destination is an unsupported network family") |  | ||||||
| 			} | 			} | ||||||
| 			break |  | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @ -179,9 +186,16 @@ func dialUDP(network string, laddr *net.UDPAddr, raddr *net.UDPAddr) (net.Conn, | |||||||
| 		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("socket open: %v", err)} | 		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("socket open: %v", err)} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if err = unix.SetsockoptInt(fileDescriptor, unix.SOL_IP, unix.IP_TRANSPARENT, 1); err != nil { | 	if laddr.IP.To4() != nil { | ||||||
| 		unix.Close(fileDescriptor) | 		if err = unix.SetsockoptInt(fileDescriptor, unix.SOL_IP, unix.IP_TRANSPARENT, 1); err != nil { | ||||||
| 		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("set socket option: IP_TRANSPARENT: %v", err)} | 			unix.Close(fileDescriptor) | ||||||
|  | 			return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("set socket option: IP_TRANSPARENT: %v", err)} | ||||||
|  | 		} | ||||||
|  | 	} else { | ||||||
|  | 		if err = unix.SetsockoptInt(fileDescriptor, unix.SOL_IPV6, unix.IPV6_TRANSPARENT, 1); err != nil { | ||||||
|  | 			unix.Close(fileDescriptor) | ||||||
|  | 			return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("set socket option: IPV6_TRANSPARENT: %v", err)} | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if err = unix.SetsockoptInt(fileDescriptor, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1); err != nil { | 	if err = unix.SetsockoptInt(fileDescriptor, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1); err != nil { | ||||||
| @ -230,9 +244,12 @@ func udpAddrToSocketAddr(addr *net.UDPAddr) (unix.Sockaddr, error) { | |||||||
| 		ip := [16]byte{} | 		ip := [16]byte{} | ||||||
| 		copy(ip[:], addr.IP.To16()) | 		copy(ip[:], addr.IP.To16()) | ||||||
|  |  | ||||||
| 		zoneID, err := strconv.ParseUint(addr.Zone, 10, 32) | 		var zoneID uint64 | ||||||
| 		if err != nil { | 		if addr.Zone != "" { | ||||||
| 			return nil, err | 			zoneID, _ = strconv.ParseUint(addr.Zone, 10, 32) | ||||||
|  | 			if zoneID == 0 { | ||||||
|  | 				zoneID = 2 | ||||||
|  | 			} | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		return &unix.SockaddrInet6{Addr: ip, Port: addr.Port, ZoneId: uint32(zoneID)}, nil | 		return &unix.SockaddrInet6{Addr: ip, Port: addr.Port, ZoneId: uint32(zoneID)}, nil | ||||||
| @ -251,7 +268,7 @@ func udpAddrFamily(net string, laddr, raddr *net.UDPAddr) int { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if (laddr == nil || laddr.IP.To4() != nil) && | 	if (laddr == nil || laddr.IP.To4() != nil) && | ||||||
| 		(raddr == nil || laddr.IP.To4() != nil) { | 		(raddr == nil || raddr.IP.To4() != nil) { | ||||||
| 		return unix.AF_INET | 		return unix.AF_INET | ||||||
| 	} | 	} | ||||||
| 	return unix.AF_INET6 | 	return unix.AF_INET6 | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user