fix transparent proxy
This commit is contained in:
@ -1,40 +0,0 @@
|
||||
package redirect
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func (h *redirectHandler) getOriginalDstAddr(conn net.Conn) (addr net.Addr, c net.Conn, err error) {
|
||||
defer conn.Close()
|
||||
|
||||
tc, ok := conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
err = errors.New("wrong connection type, must be TCP")
|
||||
return
|
||||
}
|
||||
|
||||
fc, err := tc.File()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer fc.Close()
|
||||
|
||||
mreq, err := syscall.GetsockoptIPv6Mreq(int(fc.Fd()), syscall.IPPROTO_IP, 80)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// only ipv4 support
|
||||
ip := net.IPv4(mreq.Multiaddr[4], mreq.Multiaddr[5], mreq.Multiaddr[6], mreq.Multiaddr[7])
|
||||
port := uint16(mreq.Multiaddr[2])<<8 + uint16(mreq.Multiaddr[3])
|
||||
addr, err = net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", ip.String(), port))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c, err = net.FileConn(fc)
|
||||
return
|
||||
}
|
8
handler/redirect/tcp/conn.go
Normal file
8
handler/redirect/tcp/conn.go
Normal file
@ -0,0 +1,8 @@
|
||||
package redirect
|
||||
|
||||
import "io"
|
||||
|
||||
type readWriter struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
}
|
261
handler/redirect/tcp/handler.go
Normal file
261
handler/redirect/tcp/handler.go
Normal file
@ -0,0 +1,261 @@
|
||||
package redirect
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/core/chain"
|
||||
netpkg "github.com/go-gost/core/common/net"
|
||||
"github.com/go-gost/core/handler"
|
||||
"github.com/go-gost/core/logger"
|
||||
md "github.com/go-gost/core/metadata"
|
||||
"github.com/go-gost/core/registry"
|
||||
dissector "github.com/go-gost/tls-dissector"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.HandlerRegistry().Register("red", NewHandler)
|
||||
registry.HandlerRegistry().Register("redir", NewHandler)
|
||||
registry.HandlerRegistry().Register("redirect", NewHandler)
|
||||
}
|
||||
|
||||
type redirectHandler struct {
|
||||
router *chain.Router
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
|
||||
func NewHandler(opts ...handler.Option) handler.Handler {
|
||||
options := handler.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
return &redirectHandler{
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *redirectHandler) Init(md md.Metadata) (err error) {
|
||||
if err = h.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
h.router = h.options.Router
|
||||
if h.router == nil {
|
||||
h.router = (&chain.Router{}).WithLogger(h.options.Logger)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
|
||||
defer conn.Close()
|
||||
|
||||
start := time.Now()
|
||||
log := h.options.Logger.WithFields(map[string]any{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
|
||||
network := "tcp"
|
||||
|
||||
dstAddr, err := h.getOriginalDstAddr(conn)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
log = log.WithFields(map[string]any{
|
||||
"dst": fmt.Sprintf("%s/%s", dstAddr, network),
|
||||
})
|
||||
|
||||
var rw io.ReadWriter = conn
|
||||
if h.md.sniffing {
|
||||
// try to sniff TLS traffic
|
||||
var hdr [dissector.RecordHeaderLen]byte
|
||||
_, err := io.ReadFull(rw, hdr[:])
|
||||
rw = &readWriter{
|
||||
Reader: io.MultiReader(bytes.NewReader(hdr[:]), rw),
|
||||
Writer: rw,
|
||||
}
|
||||
if err == nil &&
|
||||
hdr[0] == dissector.Handshake &&
|
||||
binary.BigEndian.Uint16(hdr[1:3]) == tls.VersionTLS10 {
|
||||
return h.handleHTTPS(ctx, rw, conn.RemoteAddr(), dstAddr, log)
|
||||
}
|
||||
|
||||
// try to sniff HTTP traffic
|
||||
buf := new(bytes.Buffer)
|
||||
_, err = http.ReadRequest(bufio.NewReader(io.TeeReader(rw, buf)))
|
||||
rw = &readWriter{
|
||||
Reader: io.MultiReader(buf, rw),
|
||||
Writer: rw,
|
||||
}
|
||||
if err == nil {
|
||||
return h.handleHTTP(ctx, rw, conn.RemoteAddr(), log)
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), dstAddr)
|
||||
|
||||
if h.options.Bypass != nil && h.options.Bypass.Contains(dstAddr.String()) {
|
||||
log.Info("bypass: ", dstAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
cc, err := h.router.Dial(ctx, network, dstAddr.String())
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
defer cc.Close()
|
||||
|
||||
t := time.Now()
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), dstAddr)
|
||||
netpkg.Transport(rw, cc)
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(t),
|
||||
}).Infof("%s >-< %s", conn.RemoteAddr(), dstAddr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *redirectHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, raddr net.Addr, log logger.Logger) error {
|
||||
req, err := http.ReadRequest(bufio.NewReader(rw))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if log.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpRequest(req, false)
|
||||
log.Debug(string(dump))
|
||||
}
|
||||
|
||||
host := req.Host
|
||||
if _, _, err := net.SplitHostPort(host); err != nil {
|
||||
host = net.JoinHostPort(host, "80")
|
||||
}
|
||||
log = log.WithFields(map[string]any{
|
||||
"host": host,
|
||||
})
|
||||
|
||||
if h.options.Bypass != nil && h.options.Bypass.Contains(host) {
|
||||
log.Info("bypass: ", host)
|
||||
return nil
|
||||
}
|
||||
|
||||
cc, err := h.router.Dial(ctx, "tcp", host)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
defer cc.Close()
|
||||
|
||||
t := time.Now()
|
||||
log.Infof("%s <-> %s", raddr, host)
|
||||
defer func() {
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(t),
|
||||
}).Infof("%s >-< %s", raddr, host)
|
||||
}()
|
||||
|
||||
if err := req.Write(cc); err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := http.ReadResponse(bufio.NewReader(cc), req)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if log.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpResponse(resp, false)
|
||||
log.Debug(string(dump))
|
||||
}
|
||||
|
||||
return resp.Write(rw)
|
||||
}
|
||||
|
||||
func (h *redirectHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, raddr, dstAddr net.Addr, log logger.Logger) error {
|
||||
buf := new(bytes.Buffer)
|
||||
host, err := h.getServerName(ctx, io.TeeReader(rw, buf))
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
if host == "" {
|
||||
host = dstAddr.String()
|
||||
} else {
|
||||
host = net.JoinHostPort(host, "443")
|
||||
}
|
||||
|
||||
log = log.WithFields(map[string]any{
|
||||
"host": host,
|
||||
})
|
||||
|
||||
if h.options.Bypass != nil && h.options.Bypass.Contains(host) {
|
||||
log.Info("bypass: ", host)
|
||||
return nil
|
||||
}
|
||||
|
||||
cc, err := h.router.Dial(ctx, "tcp", host)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
defer cc.Close()
|
||||
|
||||
t := time.Now()
|
||||
log.Infof("%s <-> %s", raddr, host)
|
||||
netpkg.Transport(&readWriter{
|
||||
Reader: io.MultiReader(buf, rw),
|
||||
Writer: rw,
|
||||
}, cc)
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(t),
|
||||
}).Infof("%s >-< %s", raddr, host)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *redirectHandler) getServerName(ctx context.Context, r io.Reader) (host string, err error) {
|
||||
record, err := dissector.ReadRecord(r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
clientHello := dissector.ClientHelloMsg{}
|
||||
if err = clientHello.Decode(record.Opaque); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, ext := range clientHello.Extensions {
|
||||
if ext.Type() == dissector.ExtServerName {
|
||||
snExtension := ext.(*dissector.ServerNameExtension)
|
||||
host = snExtension.Name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
62
handler/redirect/tcp/handler_linux.go
Normal file
62
handler/redirect/tcp/handler_linux.go
Normal file
@ -0,0 +1,62 @@
|
||||
package redirect
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func (h *redirectHandler) getOriginalDstAddr(conn net.Conn) (addr net.Addr, err error) {
|
||||
tcpAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
err = errors.New("wrong connection type, must be TCP Conn")
|
||||
return
|
||||
}
|
||||
|
||||
sc, ok := conn.(syscall.Conn)
|
||||
if !ok {
|
||||
err = errors.New("wrong connection type, must be syscall.Conn")
|
||||
return
|
||||
}
|
||||
|
||||
rc, err := sc.SyscallConn()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var cerr error
|
||||
err = rc.Control(func(fd uintptr) {
|
||||
if tcpAddr.IP.To4() != nil {
|
||||
mreq, err := unix.GetsockoptIPv6Mreq(int(fd), unix.IPPROTO_IP, unix.SO_ORIGINAL_DST)
|
||||
if err != nil {
|
||||
cerr = err
|
||||
return
|
||||
}
|
||||
|
||||
addr = &net.TCPAddr{
|
||||
IP: net.IP(mreq.Multiaddr[4:8]),
|
||||
Port: int(mreq.Multiaddr[2])<<8 + int(mreq.Multiaddr[3]),
|
||||
}
|
||||
} else {
|
||||
info, err := unix.GetsockoptIPv6MTUInfo(int(fd), unix.IPPROTO_IPV6, unix.SO_ORIGINAL_DST)
|
||||
if err != nil {
|
||||
cerr = err
|
||||
return
|
||||
}
|
||||
addr = &net.TCPAddr{
|
||||
IP: net.IP(info.Addr.Addr[:]),
|
||||
Port: int(info.Addr.Port),
|
||||
}
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if cerr != nil {
|
||||
return nil, cerr
|
||||
}
|
||||
|
||||
return
|
||||
}
|
@ -7,9 +7,7 @@ import (
|
||||
"net"
|
||||
)
|
||||
|
||||
func (h *redirectHandler) getOriginalDstAddr(conn net.Conn) (addr net.Addr, c net.Conn, err error) {
|
||||
defer conn.Close()
|
||||
|
||||
func (h *redirectHandler) getOriginalDstAddr(conn net.Conn) (addr net.Addr, err error) {
|
||||
err = errors.New("TCP redirect is not available on non-linux platform")
|
||||
return
|
||||
}
|
17
handler/redirect/tcp/metadata.go
Normal file
17
handler/redirect/tcp/metadata.go
Normal file
@ -0,0 +1,17 @@
|
||||
package redirect
|
||||
|
||||
import (
|
||||
mdata "github.com/go-gost/core/metadata"
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
sniffing bool
|
||||
}
|
||||
|
||||
func (h *redirectHandler) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
sniffing = "sniffing"
|
||||
)
|
||||
h.md.sniffing = mdata.GetBool(md, sniffing)
|
||||
return
|
||||
}
|
@ -14,10 +14,7 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.HandlerRegistry().Register("red", NewHandler)
|
||||
registry.HandlerRegistry().Register("redu", NewHandler)
|
||||
registry.HandlerRegistry().Register("redir", NewHandler)
|
||||
registry.HandlerRegistry().Register("redirect", NewHandler)
|
||||
}
|
||||
|
||||
type redirectHandler struct {
|
||||
@ -66,22 +63,8 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
|
||||
network := "tcp"
|
||||
var dstAddr net.Addr
|
||||
var err error
|
||||
|
||||
if _, ok := conn.(net.PacketConn); ok {
|
||||
network = "udp"
|
||||
dstAddr = conn.LocalAddr()
|
||||
}
|
||||
|
||||
if network == "tcp" {
|
||||
dstAddr, conn, err = h.getOriginalDstAddr(conn)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
network := "udp"
|
||||
dstAddr := conn.LocalAddr()
|
||||
|
||||
log = log.WithFields(map[string]any{
|
||||
"dst": fmt.Sprintf("%s/%s", dstAddr, network),
|
@ -29,7 +29,7 @@ func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) {
|
||||
if bs := mdata.GetInt(md, udpBufferSize); bs > 0 {
|
||||
h.md.udpBufferSize = int(math.Min(math.Max(float64(bs), 512), 64*1024))
|
||||
} else {
|
||||
h.md.udpBufferSize = 1024
|
||||
h.md.udpBufferSize = 1500
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -26,7 +26,7 @@ func (h *ssuHandler) parseMetadata(md mdata.Metadata) (err error) {
|
||||
if bs := mdata.GetInt(md, bufferSize); bs > 0 {
|
||||
h.md.bufferSize = int(math.Min(math.Max(float64(bs), 512), 64*1024))
|
||||
} else {
|
||||
h.md.bufferSize = 1024
|
||||
h.md.bufferSize = 1500
|
||||
}
|
||||
return
|
||||
}
|
||||
|
Reference in New Issue
Block a user