fix transparent proxy

This commit is contained in:
ginuerzh
2022-03-29 23:02:32 +08:00
parent 6a6367b8d1
commit 303f46f843
20 changed files with 493 additions and 68 deletions

View File

@ -176,10 +176,15 @@ type ConnectorConfig struct {
Metadata map[string]any `yaml:",omitempty" json:"metadata,omitempty"`
}
type SockOptsConfig struct {
Mark int `yaml:",omitempty" json:"mark,omitempty"`
}
type ServiceConfig struct {
Name string `json:"name"`
Addr string `yaml:",omitempty" json:"addr,omitempty"`
Interface string `yaml:",omitempty" json:"interface,omitempty"`
SockOpts *SockOptsConfig `yaml:"sockopts,omitempty" json:"sockopts,omitempty"`
Admission string `yaml:",omitempty" json:"admission,omitempty"`
Bypass string `yaml:",omitempty" json:"bypass,omitempty"`
Resolver string `yaml:",omitempty" json:"resolver,omitempty"`
@ -198,6 +203,7 @@ type ChainConfig struct {
type HopConfig struct {
Name string `json:"name"`
Interface string `yaml:",omitempty" json:"interface,omitempty"`
SockOpts *SockOptsConfig `yaml:"sockopts,omitempty" json:"sockopts,omitempty"`
Selector *SelectorConfig `yaml:",omitempty" json:"selector,omitempty"`
Bypass string `yaml:",omitempty" json:"bypass,omitempty"`
Resolver string `yaml:",omitempty" json:"resolver,omitempty"`
@ -209,6 +215,7 @@ type NodeConfig struct {
Name string `json:"name"`
Addr string `yaml:",omitempty" json:"addr,omitempty"`
Interface string `yaml:",omitempty" json:"interface,omitempty"`
SockOpts *SockOptsConfig `yaml:"sockopts,omitempty" json:"sockopts,omitempty"`
Bypass string `yaml:",omitempty" json:"bypass,omitempty"`
Resolver string `yaml:",omitempty" json:"resolver,omitempty"`
Hosts string `yaml:",omitempty" json:"hosts,omitempty"`

View File

@ -105,12 +105,23 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) {
if v.Interface == "" {
v.Interface = hop.Interface
}
if v.SockOpts == nil {
v.SockOpts = hop.SockOpts
}
var sockOpts *chain.SockOpts
if v.SockOpts != nil {
sockOpts = &chain.SockOpts{
Mark: v.SockOpts.Mark,
}
}
tr := (&chain.Transport{}).
WithConnector(cr).
WithDialer(d).
WithAddr(v.Addr).
WithInterface(v.Interface)
WithInterface(v.Interface).
WithSockOpts(sockOpts)
node := &chain.Node{
Name: v.Name,

View File

@ -91,10 +91,18 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) {
auther = registry.AutherRegistry().Get(cfg.Handler.Auther)
}
var sockOpts *chain.SockOpts
if cfg.SockOpts != nil {
sockOpts = &chain.SockOpts{
Mark: cfg.SockOpts.Mark,
}
}
router := (&chain.Router{}).
WithRetries(cfg.Handler.Retries).
// WithTimeout(timeout time.Duration).
WithInterface(cfg.Interface).
WithSockOpts(sockOpts).
WithChain(registry.ChainRegistry().Get(cfg.Handler.Chain)).
WithResolver(registry.ResolverRegistry().Get(cfg.Resolver)).
WithHosts(registry.HostsRegistry().Get(cfg.Hosts)).

View File

@ -26,7 +26,7 @@ func (c *ssuConnector) parseMetadata(md mdata.Metadata) (err error) {
if bs := mdata.GetInt(md, bufferSize); bs > 0 {
c.md.bufferSize = int(math.Min(math.Max(float64(bs), 512), 64*1024))
} else {
c.md.bufferSize = 1024
c.md.bufferSize = 1500
}
return

View File

@ -1,40 +0,0 @@
package redirect
import (
"errors"
"fmt"
"net"
"syscall"
)
func (h *redirectHandler) getOriginalDstAddr(conn net.Conn) (addr net.Addr, c net.Conn, err error) {
defer conn.Close()
tc, ok := conn.(*net.TCPConn)
if !ok {
err = errors.New("wrong connection type, must be TCP")
return
}
fc, err := tc.File()
if err != nil {
return
}
defer fc.Close()
mreq, err := syscall.GetsockoptIPv6Mreq(int(fc.Fd()), syscall.IPPROTO_IP, 80)
if err != nil {
return
}
// only ipv4 support
ip := net.IPv4(mreq.Multiaddr[4], mreq.Multiaddr[5], mreq.Multiaddr[6], mreq.Multiaddr[7])
port := uint16(mreq.Multiaddr[2])<<8 + uint16(mreq.Multiaddr[3])
addr, err = net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", ip.String(), port))
if err != nil {
return
}
c, err = net.FileConn(fc)
return
}

View File

@ -0,0 +1,8 @@
package redirect
import "io"
type readWriter struct {
io.Reader
io.Writer
}

View File

@ -0,0 +1,261 @@
package redirect
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"encoding/binary"
"fmt"
"io"
"net"
"net/http"
"net/http/httputil"
"time"
"github.com/go-gost/core/chain"
netpkg "github.com/go-gost/core/common/net"
"github.com/go-gost/core/handler"
"github.com/go-gost/core/logger"
md "github.com/go-gost/core/metadata"
"github.com/go-gost/core/registry"
dissector "github.com/go-gost/tls-dissector"
)
func init() {
registry.HandlerRegistry().Register("red", NewHandler)
registry.HandlerRegistry().Register("redir", NewHandler)
registry.HandlerRegistry().Register("redirect", NewHandler)
}
type redirectHandler struct {
router *chain.Router
md metadata
options handler.Options
}
func NewHandler(opts ...handler.Option) handler.Handler {
options := handler.Options{}
for _, opt := range opts {
opt(&options)
}
return &redirectHandler{
options: options,
}
}
func (h *redirectHandler) Init(md md.Metadata) (err error) {
if err = h.parseMetadata(md); err != nil {
return
}
h.router = h.options.Router
if h.router == nil {
h.router = (&chain.Router{}).WithLogger(h.options.Logger)
}
return
}
func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
defer conn.Close()
start := time.Now()
log := h.options.Logger.WithFields(map[string]any{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
log.WithFields(map[string]any{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
network := "tcp"
dstAddr, err := h.getOriginalDstAddr(conn)
if err != nil {
log.Error(err)
return err
}
log = log.WithFields(map[string]any{
"dst": fmt.Sprintf("%s/%s", dstAddr, network),
})
var rw io.ReadWriter = conn
if h.md.sniffing {
// try to sniff TLS traffic
var hdr [dissector.RecordHeaderLen]byte
_, err := io.ReadFull(rw, hdr[:])
rw = &readWriter{
Reader: io.MultiReader(bytes.NewReader(hdr[:]), rw),
Writer: rw,
}
if err == nil &&
hdr[0] == dissector.Handshake &&
binary.BigEndian.Uint16(hdr[1:3]) == tls.VersionTLS10 {
return h.handleHTTPS(ctx, rw, conn.RemoteAddr(), dstAddr, log)
}
// try to sniff HTTP traffic
buf := new(bytes.Buffer)
_, err = http.ReadRequest(bufio.NewReader(io.TeeReader(rw, buf)))
rw = &readWriter{
Reader: io.MultiReader(buf, rw),
Writer: rw,
}
if err == nil {
return h.handleHTTP(ctx, rw, conn.RemoteAddr(), log)
}
}
log.Infof("%s >> %s", conn.RemoteAddr(), dstAddr)
if h.options.Bypass != nil && h.options.Bypass.Contains(dstAddr.String()) {
log.Info("bypass: ", dstAddr)
return nil
}
cc, err := h.router.Dial(ctx, network, dstAddr.String())
if err != nil {
log.Error(err)
return err
}
defer cc.Close()
t := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), dstAddr)
netpkg.Transport(rw, cc)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Infof("%s >-< %s", conn.RemoteAddr(), dstAddr)
return nil
}
func (h *redirectHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, raddr net.Addr, log logger.Logger) error {
req, err := http.ReadRequest(bufio.NewReader(rw))
if err != nil {
return err
}
if log.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpRequest(req, false)
log.Debug(string(dump))
}
host := req.Host
if _, _, err := net.SplitHostPort(host); err != nil {
host = net.JoinHostPort(host, "80")
}
log = log.WithFields(map[string]any{
"host": host,
})
if h.options.Bypass != nil && h.options.Bypass.Contains(host) {
log.Info("bypass: ", host)
return nil
}
cc, err := h.router.Dial(ctx, "tcp", host)
if err != nil {
log.Error(err)
return err
}
defer cc.Close()
t := time.Now()
log.Infof("%s <-> %s", raddr, host)
defer func() {
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Infof("%s >-< %s", raddr, host)
}()
if err := req.Write(cc); err != nil {
log.Error(err)
return err
}
resp, err := http.ReadResponse(bufio.NewReader(cc), req)
if err != nil {
log.Error(err)
return err
}
defer resp.Body.Close()
if log.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpResponse(resp, false)
log.Debug(string(dump))
}
return resp.Write(rw)
}
func (h *redirectHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, raddr, dstAddr net.Addr, log logger.Logger) error {
buf := new(bytes.Buffer)
host, err := h.getServerName(ctx, io.TeeReader(rw, buf))
if err != nil {
log.Error(err)
return err
}
if host == "" {
host = dstAddr.String()
} else {
host = net.JoinHostPort(host, "443")
}
log = log.WithFields(map[string]any{
"host": host,
})
if h.options.Bypass != nil && h.options.Bypass.Contains(host) {
log.Info("bypass: ", host)
return nil
}
cc, err := h.router.Dial(ctx, "tcp", host)
if err != nil {
log.Error(err)
return err
}
defer cc.Close()
t := time.Now()
log.Infof("%s <-> %s", raddr, host)
netpkg.Transport(&readWriter{
Reader: io.MultiReader(buf, rw),
Writer: rw,
}, cc)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Infof("%s >-< %s", raddr, host)
return nil
}
func (h *redirectHandler) getServerName(ctx context.Context, r io.Reader) (host string, err error) {
record, err := dissector.ReadRecord(r)
if err != nil {
return
}
clientHello := dissector.ClientHelloMsg{}
if err = clientHello.Decode(record.Opaque); err != nil {
return
}
for _, ext := range clientHello.Extensions {
if ext.Type() == dissector.ExtServerName {
snExtension := ext.(*dissector.ServerNameExtension)
host = snExtension.Name
break
}
}
return
}

View File

@ -0,0 +1,62 @@
package redirect
import (
"errors"
"net"
"syscall"
"golang.org/x/sys/unix"
)
func (h *redirectHandler) getOriginalDstAddr(conn net.Conn) (addr net.Addr, err error) {
tcpAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
if !ok {
err = errors.New("wrong connection type, must be TCP Conn")
return
}
sc, ok := conn.(syscall.Conn)
if !ok {
err = errors.New("wrong connection type, must be syscall.Conn")
return
}
rc, err := sc.SyscallConn()
if err != nil {
return
}
var cerr error
err = rc.Control(func(fd uintptr) {
if tcpAddr.IP.To4() != nil {
mreq, err := unix.GetsockoptIPv6Mreq(int(fd), unix.IPPROTO_IP, unix.SO_ORIGINAL_DST)
if err != nil {
cerr = err
return
}
addr = &net.TCPAddr{
IP: net.IP(mreq.Multiaddr[4:8]),
Port: int(mreq.Multiaddr[2])<<8 + int(mreq.Multiaddr[3]),
}
} else {
info, err := unix.GetsockoptIPv6MTUInfo(int(fd), unix.IPPROTO_IPV6, unix.SO_ORIGINAL_DST)
if err != nil {
cerr = err
return
}
addr = &net.TCPAddr{
IP: net.IP(info.Addr.Addr[:]),
Port: int(info.Addr.Port),
}
}
})
if err != nil {
return
}
if cerr != nil {
return nil, cerr
}
return
}

View File

@ -7,9 +7,7 @@ import (
"net"
)
func (h *redirectHandler) getOriginalDstAddr(conn net.Conn) (addr net.Addr, c net.Conn, err error) {
defer conn.Close()
func (h *redirectHandler) getOriginalDstAddr(conn net.Conn) (addr net.Addr, err error) {
err = errors.New("TCP redirect is not available on non-linux platform")
return
}

View File

@ -0,0 +1,17 @@
package redirect
import (
mdata "github.com/go-gost/core/metadata"
)
type metadata struct {
sniffing bool
}
func (h *redirectHandler) parseMetadata(md mdata.Metadata) (err error) {
const (
sniffing = "sniffing"
)
h.md.sniffing = mdata.GetBool(md, sniffing)
return
}

View File

@ -14,10 +14,7 @@ import (
)
func init() {
registry.HandlerRegistry().Register("red", NewHandler)
registry.HandlerRegistry().Register("redu", NewHandler)
registry.HandlerRegistry().Register("redir", NewHandler)
registry.HandlerRegistry().Register("redirect", NewHandler)
}
type redirectHandler struct {
@ -66,22 +63,8 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
network := "tcp"
var dstAddr net.Addr
var err error
if _, ok := conn.(net.PacketConn); ok {
network = "udp"
dstAddr = conn.LocalAddr()
}
if network == "tcp" {
dstAddr, conn, err = h.getOriginalDstAddr(conn)
if err != nil {
log.Error(err)
return err
}
}
network := "udp"
dstAddr := conn.LocalAddr()
log = log.WithFields(map[string]any{
"dst": fmt.Sprintf("%s/%s", dstAddr, network),

View File

@ -29,7 +29,7 @@ func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) {
if bs := mdata.GetInt(md, udpBufferSize); bs > 0 {
h.md.udpBufferSize = int(math.Min(math.Max(float64(bs), 512), 64*1024))
} else {
h.md.udpBufferSize = 1024
h.md.udpBufferSize = 1500
}
return
}

View File

@ -26,7 +26,7 @@ func (h *ssuHandler) parseMetadata(md mdata.Metadata) (err error) {
if bs := mdata.GetInt(md, bufferSize); bs > 0 {
h.md.bufferSize = int(math.Min(math.Max(float64(bs), 512), 64*1024))
} else {
h.md.bufferSize = 1024
h.md.bufferSize = 1500
}
return
}

View File

@ -8,7 +8,7 @@ import (
const (
defaultTTL = 60 * time.Second
defaultReadBufferSize = 1024
defaultReadBufferSize = 1500
defaultReadQueueSize = 128
defaultBacklog = 128
)

View File

@ -0,0 +1,66 @@
package tcp
import (
"context"
"net"
"github.com/go-gost/core/listener"
"github.com/go-gost/core/logger"
md "github.com/go-gost/core/metadata"
metrics "github.com/go-gost/core/metrics/wrapper"
"github.com/go-gost/core/registry"
)
func init() {
registry.ListenerRegistry().Register("red", NewListener)
registry.ListenerRegistry().Register("redir", NewListener)
registry.ListenerRegistry().Register("redirect", NewListener)
}
type redirectListener struct {
ln net.Listener
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &redirectListener{
logger: options.Logger,
options: options,
}
}
func (l *redirectListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
}
lc := net.ListenConfig{}
if l.md.tproxy {
lc.Control = l.control
}
ln, err := lc.Listen(context.Background(), "tcp", l.options.Addr)
if err != nil {
return err
}
l.ln = metrics.WrapListener(l.options.Service, ln)
return
}
func (l *redirectListener) Accept() (conn net.Conn, err error) {
return l.ln.Accept()
}
func (l *redirectListener) Addr() net.Addr {
return l.ln.Addr()
}
func (l *redirectListener) Close() error {
return l.ln.Close()
}

View File

@ -0,0 +1,15 @@
package tcp
import (
"syscall"
"golang.org/x/sys/unix"
)
func (l *redirectListener) control(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
if err := unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_TRANSPARENT, 1); err != nil {
l.logger.Errorf("set sockopt: %v", err)
}
})
}

View File

@ -0,0 +1,12 @@
//go:build !linux
package tcp
import (
"errors"
"syscall"
)
func (l *redirectListener) control(network, address string, c syscall.RawConn) error {
return errors.New("TProxy is not available on non-linux platform")
}

View File

@ -0,0 +1,17 @@
package tcp
import (
mdata "github.com/go-gost/core/metadata"
)
type metadata struct {
tproxy bool
}
func (l *redirectListener) parseMetadata(md mdata.Metadata) (err error) {
const (
tproxy = "tproxy"
)
l.md.tproxy = mdata.GetBool(md, tproxy)
return
}

View File

@ -8,7 +8,7 @@ import (
const (
defaultTTL = 60 * time.Second
defaultReadBufferSize = 1024
defaultReadBufferSize = 1500
)
type metadata struct {