diff --git a/cmd/gost/config.go b/cmd/gost/config.go index e617f35..d20d8a2 100644 --- a/cmd/gost/config.go +++ b/cmd/gost/config.go @@ -4,6 +4,7 @@ import ( "io" "os" + "github.com/go-gost/gost/pkg/bypass" "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/components/connector" "github.com/go-gost/gost/pkg/components/dialer" @@ -16,12 +17,23 @@ import ( "github.com/go-gost/gost/pkg/service" ) +var ( + chains = make(map[string]*chain.Chain) + bypasses = make(map[string]bypass.Bypass) +) + func buildService(cfg *config.Config) (services []*service.Service) { if cfg == nil || len(cfg.Services) == 0 { return } - chains := buildChain(cfg) + for _, bypassCfg := range cfg.Bypasses { + bypasses[bypassCfg.Name] = bypassFromConfig(&bypassCfg) + } + + for _, chainCfg := range cfg.Chains { + chains[chainCfg.Name] = chainFromConfig(&chainCfg) + } for _, svc := range cfg.Services { listenerLogger := log.WithFields(map[string]interface{}{ @@ -37,21 +49,15 @@ func buildService(cfg *config.Config) (services []*service.Service) { listenerLogger.Fatal("init:", err) } - var chain *chain.Chain - for _, ch := range chains { - if svc.Chain == ch.Name { - chain = ch - break - } - } - handlerLogger := log.WithFields(map[string]interface{}{ "kind": "handler", "type": svc.Handler.Type, "service": svc.Name, }) + h := registry.GetHandler(svc.Handler.Type)( - handler.ChainOption(chain), + handler.ChainOption(chains[svc.Chain]), + handler.BypassOption(bypasses[svc.Bypass]), handler.LoggerOption(handlerLogger), ) if err := h.Init(metadata.MapMetadata(svc.Handler.Metadata)); err != nil { @@ -67,68 +73,63 @@ func buildService(cfg *config.Config) (services []*service.Service) { return } -func buildChain(cfg *config.Config) (chains []*chain.Chain) { - if cfg == nil || len(cfg.Chains) == 0 { +func chainFromConfig(cfg *config.ChainConfig) *chain.Chain { + if cfg == nil { return nil } - for _, ch := range cfg.Chains { - c := &chain.Chain{ - Name: ch.Name, - } + c := &chain.Chain{} - selector := selectorFromConfig(ch.LB) - for _, hop := range ch.Hops { - group := &chain.NodeGroup{} - for _, v := range hop.Nodes { - node := chain.NewNode(v.Name, v.Addr) + selector := selectorFromConfig(cfg.LB) + for _, hop := range cfg.Hops { + group := &chain.NodeGroup{} + for _, v := range hop.Nodes { - connectorLogger := log.WithFields(map[string]interface{}{ - "kind": "connector", - "type": v.Connector.Type, - "hop": hop.Name, - "node": node.Name(), - }) - cr := registry.GetConnector(v.Connector.Type)( - connector.LoggerOption(connectorLogger), - ) - if err := cr.Init(metadata.MapMetadata(v.Connector.Metadata)); err != nil { - connectorLogger.Fatal("init:", err) - } - - dialerLogger := log.WithFields(map[string]interface{}{ - "kind": "dialer", - "type": v.Dialer.Type, - "hop": hop.Name, - "node": node.Name(), - }) - d := registry.GetDialer(v.Dialer.Type)( - dialer.LoggerOption(dialerLogger), - ) - if err := d.Init(metadata.MapMetadata(v.Dialer.Metadata)); err != nil { - dialerLogger.Fatal("init:", err) - } - - tr := (&chain.Transport{}). - WithConnector(cr). - WithDialer(d) - - node.WithTransport(tr) - group.AddNode(node) + connectorLogger := log.WithFields(map[string]interface{}{ + "kind": "connector", + "type": v.Connector.Type, + "hop": hop.Name, + "node": v.Name, + }) + cr := registry.GetConnector(v.Connector.Type)( + connector.LoggerOption(connectorLogger), + ) + if err := cr.Init(metadata.MapMetadata(v.Connector.Metadata)); err != nil { + connectorLogger.Fatal("init:", err) } - sel := selector - if s := selectorFromConfig(hop.LB); s != nil { - sel = s + dialerLogger := log.WithFields(map[string]interface{}{ + "kind": "dialer", + "type": v.Dialer.Type, + "hop": hop.Name, + "node": v.Name, + }) + d := registry.GetDialer(v.Dialer.Type)( + dialer.LoggerOption(dialerLogger), + ) + if err := d.Init(metadata.MapMetadata(v.Dialer.Metadata)); err != nil { + dialerLogger.Fatal("init:", err) } - group.WithSelector(sel) - c.AddNodeGroup(group) + + tr := (&chain.Transport{}). + WithConnector(cr). + WithDialer(d) + + node := chain.NewNode(v.Name, v.Addr). + WithTransport(tr). + WithBypass(bypasses[v.Bypass]) + group.AddNode(node) } - chains = append(chains, c) + sel := selector + if s := selectorFromConfig(hop.LB); s != nil { + sel = s + } + group.WithSelector(sel) + c.AddNodeGroup(group) } - return + return c } func logFromConfig(cfg *config.LogConfig) logger.Logger { @@ -182,3 +183,11 @@ func selectorFromConfig(cfg *config.LoadbalancingConfig) chain.Selector { }, ) } + +func bypassFromConfig(cfg *config.BypassConfig) bypass.Bypass { + if cfg == nil { + return nil + } + + return bypass.NewBypassPatterns(cfg.Reverse, cfg.Matchers...) +} diff --git a/cmd/gost/gost.yml b/cmd/gost/gost.yml index b30cab3..f245e4f 100644 --- a/cmd/gost/gost.yml +++ b/cmd/gost/gost.yml @@ -22,6 +22,7 @@ services: metadata: keepAlive: 15s chain: chain01 + # bypass: bypass01 chains: - name: chain01 @@ -41,6 +42,7 @@ chains: - name: node01 addr: ":8081" url: "http://gost:gost@:8081" + # bypass: bypass01 connector: type: http metadata: @@ -52,11 +54,12 @@ chains: - name: node02 addr: ":8082" url: "http://gost:gost@:8082" + # bypass: bypass01 connector: type: http metadata: userAgent: "gost/3.0" - auth: user1:pass1 + auth: user2:pass2 dialer: type: tcp metadata: {} @@ -70,11 +73,42 @@ chains: - name: node03 addr: ":8083" url: "http://gost:gost@:8083" + # bypass: bypass01 connector: type: http metadata: userAgent: "gost/3.0" - auth: user1:pass1 + auth: user3:pass3 dialer: type: tcp - metadata: {} \ No newline at end of file + metadata: {} + +bypasses: +- name: bypass01 + reverse: false + matchers: + - .baidu.com + - "*.example.com" # domain wildcard + - .example.org # will match example.org and *.example.org + + # From IANA IPv4 Special-Purpose Address Registry + # http://www.iana.org/assignments/iana-ipv4-special-registry/iana-ipv4-special-registry.xhtml + - 0.0.0.0/8 # RFC1122: "This host on this network" + - 10.0.0.0/8 # RFC1918: Private-Use + - 100.64.0.0/10 # RFC6598: Shared Address Space + - 127.0.0.0/8 # RFC1122: Loopback + - 169.254.0.0/16 # RFC3927: Link Local + - 172.16.0.0/12 # RFC1918: Private-Use + - 192.0.0.0/24 # RFC6890: IETF Protocol Assignments + - 192.0.2.0/24 # RFC5737: Documentation (TEST-NET-1) + - 192.88.99.0/24 # RFC3068: 6to4 Relay Anycast + - 192.168.0.0/16 # RFC1918: Private-Use + - 198.18.0.0/15 # RFC2544: Benchmarking + - 198.51.100.0/24 # RFC5737: Documentation (TEST-NET-2) + - 203.0.113.0/24 # RFC5737: Documentation (TEST-NET-3) + - 240.0.0.0/4 # RFC1112: Reserved + - 255.255.255.255/32 # RFC0919: Limited Broadcast + + # From IANA Multicast Address Space Registry + # http://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml + - 224.0.0.0/4 # RFC5771: Multicast/Reserved \ No newline at end of file diff --git a/go.mod b/go.mod index d4e07a1..0b2a31c 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/coreos/go-iptables v0.5.0 // indirect github.com/ginuerzh/tls-dissector v0.0.2-0.20201202075250-98fa925912da github.com/go-gost/gosocks5 v0.3.0 + github.com/gobwas/glob v0.2.3 github.com/golang/snappy v0.0.3 github.com/google/gopacket v1.1.19 // indirect github.com/gorilla/websocket v1.4.2 diff --git a/go.sum b/go.sum index 7df51d9..cd51b87 100644 --- a/go.sum +++ b/go.sum @@ -113,6 +113,8 @@ github.com/go-gost/gosocks5 v0.3.0 h1:Hkmp9YDRBSCJd7xywW6dBPT6B9aQTkuWd+3WCheJiJ github.com/go-gost/gosocks5 v0.3.0/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= +github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= +github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= diff --git a/pkg/bypass/bypass.go b/pkg/bypass/bypass.go new file mode 100644 index 0000000..719c85d --- /dev/null +++ b/pkg/bypass/bypass.go @@ -0,0 +1,158 @@ +package bypass + +import ( + "net" + "strconv" + "strings" + + glob "github.com/gobwas/glob" +) + +// Matcher is a generic pattern matcher, +// it gives the match result of the given pattern for specific v. +type Matcher interface { + Match(v string) bool +} + +// NewMatcher creates a Matcher for the given pattern. +// The acutal Matcher depends on the pattern: +// IP Matcher if pattern is a valid IP address. +// CIDR Matcher if pattern is a valid CIDR address. +// Domain Matcher if both of the above are not. +func NewMatcher(pattern string) Matcher { + if pattern == "" { + return nil + } + if ip := net.ParseIP(pattern); ip != nil { + return IPMatcher(ip) + } + if _, inet, err := net.ParseCIDR(pattern); err == nil { + return CIDRMatcher(inet) + } + return DomainMatcher(pattern) +} + +type ipMatcher struct { + ip net.IP +} + +// IPMatcher creates a Matcher for a specific IP address. +func IPMatcher(ip net.IP) Matcher { + return &ipMatcher{ + ip: ip, + } +} + +func (m *ipMatcher) Match(ip string) bool { + if m == nil { + return false + } + return m.ip.Equal(net.ParseIP(ip)) +} + +type cidrMatcher struct { + ipNet *net.IPNet +} + +// CIDRMatcher creates a Matcher for a specific CIDR notation IP address. +func CIDRMatcher(inet *net.IPNet) Matcher { + return &cidrMatcher{ + ipNet: inet, + } +} + +func (m *cidrMatcher) Match(ip string) bool { + if m == nil || m.ipNet == nil { + return false + } + return m.ipNet.Contains(net.ParseIP(ip)) +} + +type domainMatcher struct { + pattern string + glob glob.Glob +} + +// DomainMatcher creates a Matcher for a specific domain pattern, +// the pattern can be a plain domain such as 'example.com', +// a wildcard such as '*.exmaple.com' or a special wildcard '.example.com'. +func DomainMatcher(pattern string) Matcher { + p := pattern + if strings.HasPrefix(pattern, ".") { + p = pattern[1:] // trim the prefix '.' + pattern = "*" + p + } + return &domainMatcher{ + pattern: p, + glob: glob.MustCompile(pattern), + } +} + +func (m *domainMatcher) Match(domain string) bool { + if m == nil || m.glob == nil { + return false + } + + if domain == m.pattern { + return true + } + return m.glob.Match(domain) +} + +// Bypass is a filter of address (IP or domain). +type Bypass interface { + // Contains reports whether the bypass includes addr. + Contains(addr string) bool +} + +type bypass struct { + matchers []Matcher + reversed bool +} + +// NewBypass creates and initializes a new Bypass using matchers as its match rules. +// The rules will be reversed if the reversed is true. +func NewBypass(reversed bool, matchers ...Matcher) Bypass { + return &bypass{ + matchers: matchers, + reversed: reversed, + } +} + +// NewBypassPatterns creates and initializes a new Bypass using matcher patterns as its match rules. +// The rules will be reversed if the reverse is true. +func NewBypassPatterns(reversed bool, patterns ...string) Bypass { + var matchers []Matcher + for _, pattern := range patterns { + if m := NewMatcher(pattern); m != nil { + matchers = append(matchers, m) + } + } + return NewBypass(reversed, matchers...) +} + +func (bp *bypass) Contains(addr string) bool { + if addr == "" || bp == nil || len(bp.matchers) == 0 { + return false + } + + // try to strip the port + if host, port, _ := net.SplitHostPort(addr); host != "" && port != "" { + if p, _ := strconv.Atoi(port); p > 0 { // port is valid + addr = host + } + } + + var matched bool + for _, matcher := range bp.matchers { + if matcher == nil { + continue + } + if matcher.Match(addr) { + matched = true + break + } + } + return !bp.reversed && matched || + bp.reversed && !matched +} diff --git a/pkg/chain/chain.go b/pkg/chain/chain.go index 28ed459..b6a07c3 100644 --- a/pkg/chain/chain.go +++ b/pkg/chain/chain.go @@ -1,7 +1,6 @@ package chain type Chain struct { - Name string groups []*NodeGroup } @@ -9,7 +8,7 @@ func (c *Chain) AddNodeGroup(group *NodeGroup) { c.groups = append(c.groups, group) } -func (c *Chain) GetRoute() (r *Route) { +func (c *Chain) GetRouteFor(addr string) (r *Route) { if c == nil || len(c.groups) == 0 { return } @@ -20,11 +19,13 @@ func (c *Chain) GetRoute() (r *Route) { if node == nil { return } - // TODO: bypass + if node.bypass != nil && node.bypass.Contains(addr) { + break + } - if node.Transport().IsMultiplex() { - tr := node.Transport().WithRoute(r) - node = node.WithTransport(tr) + if node.transport.IsMultiplex() { + tr := node.transport.Copy().WithRoute(r) + node = node.Copy().WithTransport(tr) r = &Route{} } diff --git a/pkg/chain/node.go b/pkg/chain/node.go index c6ae7aa..ecb5002 100644 --- a/pkg/chain/node.go +++ b/pkg/chain/node.go @@ -3,12 +3,15 @@ package chain import ( "sync" "time" + + "github.com/go-gost/gost/pkg/bypass" ) type Node struct { name string addr string transport *Transport + bypass bypass.Bypass marker *failMarker } @@ -28,15 +31,22 @@ func (node *Node) Addr() string { return node.addr } -func (node *Node) Transport() *Transport { - return node.transport -} - func (node *Node) WithTransport(tr *Transport) *Node { node.transport = tr return node } +func (node *Node) WithBypass(bp bypass.Bypass) *Node { + node.bypass = bp + return node +} + +func (node *Node) Copy() *Node { + n := &Node{} + *n = *node + return n +} + type NodeGroup struct { nodes []*Node selector Selector diff --git a/pkg/chain/route.go b/pkg/chain/route.go index 1fad1cc..dbaa3d0 100644 --- a/pkg/chain/route.go +++ b/pkg/chain/route.go @@ -20,13 +20,13 @@ func (r *Route) Connect(ctx context.Context) (conn net.Conn, err error) { } node := r.nodes[0] - cc, err := node.Transport().Dial(ctx, r.nodes[0].Addr()) + cc, err := node.transport.Dial(ctx, r.nodes[0].Addr()) if err != nil { node.marker.Mark() return } - cn, err := node.Transport().Handshake(ctx, cc) + cn, err := node.transport.Handshake(ctx, cc) if err != nil { cc.Close() node.marker.Mark() @@ -36,7 +36,7 @@ func (r *Route) Connect(ctx context.Context) (conn net.Conn, err error) { preNode := node for _, node := range r.nodes[1:] { - cc, err = preNode.Transport().Connect(ctx, cn, "tcp", node.Addr()) + cc, err = preNode.transport.Connect(ctx, cn, "tcp", node.Addr()) if err != nil { cn.Close() node.marker.Mark() @@ -68,7 +68,7 @@ func (r *Route) Dial(ctx context.Context, network, address string) (net.Conn, er return nil, err } - cc, err := r.Last().Transport().Connect(ctx, conn, network, address) + cc, err := r.Last().transport.Connect(ctx, conn, network, address) if err != nil { conn.Close() return nil, err @@ -93,9 +93,19 @@ func (r *Route) IsEmpty() bool { return r == nil || len(r.nodes) == 0 } -func (r Route) Last() *Node { +func (r *Route) Last() *Node { if r.IsEmpty() { return nil } return r.nodes[len(r.nodes)-1] } + +func (r *Route) Path() (path []*Node) { + for _, node := range r.nodes { + if node.transport != nil && node.transport.route != nil { + path = append(path, node.transport.route.Path()...) + } + path = append(path, node) + } + return +} diff --git a/pkg/chain/transport.go b/pkg/chain/transport.go index 83c4081..305d516 100644 --- a/pkg/chain/transport.go +++ b/pkg/chain/transport.go @@ -14,6 +14,12 @@ type Transport struct { connector connector.Connector } +func (tr *Transport) Copy() *Transport { + tr2 := &Transport{} + *tr2 = *tr + return tr +} + func (tr *Transport) WithDialer(dialer dialer.Dialer) *Transport { tr.dialer = dialer return tr diff --git a/pkg/components/connector/http/connector.go b/pkg/components/connector/http/connector.go index a26d4a7..3eff3c8 100644 --- a/pkg/components/connector/http/connector.go +++ b/pkg/components/connector/http/connector.go @@ -21,7 +21,7 @@ func init() { registry.RegiserConnector("http", NewConnector) } -type Connector struct { +type httpConnector struct { md metadata logger logger.Logger } @@ -32,16 +32,16 @@ func NewConnector(opts ...connector.Option) connector.Connector { opt(options) } - return &Connector{ + return &httpConnector{ logger: options.Logger, } } -func (c *Connector) Init(md md.Metadata) (err error) { +func (c *httpConnector) Init(md md.Metadata) (err error) { return c.parseMetadata(md) } -func (c *Connector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { +func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { req := &http.Request{ Method: http.MethodConnect, URL: &url.URL{Host: address}, @@ -95,7 +95,7 @@ func (c *Connector) Connect(ctx context.Context, conn net.Conn, network, address return conn, nil } -func (c *Connector) parseMetadata(md md.Metadata) (err error) { +func (c *httpConnector) parseMetadata(md md.Metadata) (err error) { c.md.UserAgent, _ = md.Get(userAgent).(string) if c.md.UserAgent == "" { c.md.UserAgent = defaultUserAgent diff --git a/pkg/components/dialer/tcp/dialer.go b/pkg/components/dialer/tcp/dialer.go index 9d0dc02..cdc146b 100644 --- a/pkg/components/dialer/tcp/dialer.go +++ b/pkg/components/dialer/tcp/dialer.go @@ -14,7 +14,7 @@ func init() { registry.RegisterDialer("tcp", NewDialer) } -type Dialer struct { +type tcpDialer struct { md metadata logger logger.Logger } @@ -25,16 +25,16 @@ func NewDialer(opts ...dialer.Option) dialer.Dialer { opt(options) } - return &Dialer{ + return &tcpDialer{ logger: options.Logger, } } -func (d *Dialer) Init(md md.Metadata) (err error) { +func (d *tcpDialer) Init(md md.Metadata) (err error) { return d.parseMetadata(md) } -func (d *Dialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) { +func (d *tcpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) { var options dialer.DialOptions for _, opt := range opts { opt(&options) @@ -71,6 +71,6 @@ func (d *Dialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOptio return conn, err } -func (d *Dialer) parseMetadata(md md.Metadata) (err error) { +func (d *tcpDialer) parseMetadata(md md.Metadata) (err error) { return } diff --git a/pkg/components/handler/http/handler.go b/pkg/components/handler/http/handler.go index da40661..85e50b4 100644 --- a/pkg/components/handler/http/handler.go +++ b/pkg/components/handler/http/handler.go @@ -2,10 +2,12 @@ package http import ( "bufio" + "bytes" "context" "encoding/base64" "encoding/binary" "errors" + "fmt" "hash/crc32" "net" "net/http" @@ -16,6 +18,7 @@ import ( "time" "github.com/go-gost/gost/pkg/auth" + "github.com/go-gost/gost/pkg/bypass" "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/components/handler" md "github.com/go-gost/gost/pkg/components/metadata" @@ -27,8 +30,9 @@ func init() { registry.RegisterHandler("http", NewHandler) } -type Handler struct { +type httpHandler struct { chain *chain.Chain + bypass bypass.Bypass logger logger.Logger md metadata } @@ -39,17 +43,18 @@ func NewHandler(opts ...handler.Option) handler.Handler { opt(options) } - return &Handler{ + return &httpHandler{ chain: options.Chain, + bypass: options.Bypass, logger: options.Logger, } } -func (h *Handler) Init(md md.Metadata) error { +func (h *httpHandler) Init(md md.Metadata) error { return h.parseMetadata(md) } -func (h *Handler) parseMetadata(md md.Metadata) error { +func (h *httpHandler) parseMetadata(md md.Metadata) error { h.md.proxyAgent = md.GetString(proxyAgentKey) if v, _ := md.Get(authsKey).([]interface{}); len(v) > 0 { @@ -81,7 +86,7 @@ func (h *Handler) parseMetadata(md md.Metadata) error { return nil } -func (h *Handler) Handle(ctx context.Context, conn net.Conn) { +func (h *httpHandler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() h.logger = h.logger.WithFields(map[string]interface{}{ @@ -99,7 +104,7 @@ func (h *Handler) Handle(ctx context.Context, conn net.Conn) { h.handleRequest(ctx, conn, req) } -func (h *Handler) handleRequest(ctx context.Context, conn net.Conn, req *http.Request) { +func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *http.Request) { if req == nil { return } @@ -156,21 +161,18 @@ func (h *Handler) handleRequest(ctx context.Context, conn net.Conn, req *http.Re } */ - /* - if h.options.Bypass.Contains(host) { - resp.StatusCode = http.StatusForbidden + if h.bypass != nil && h.bypass.Contains(host) { + resp.StatusCode = http.StatusForbidden - log.Logf("[http] %s - %s bypass %s", - conn.RemoteAddr(), conn.LocalAddr(), host) - if Debug { - dump, _ := httputil.DumpResponse(resp, false) - log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), conn.LocalAddr(), string(dump)) - } - - resp.Write(conn) - return + if h.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(resp, false) + h.logger.Debug(string(dump)) } - */ + h.logger.Info("bypass: ", host) + + resp.Write(conn) + return + } if !h.authenticate(conn, req, resp) { return @@ -200,6 +202,7 @@ func (h *Handler) handleRequest(ctx context.Context, conn net.Conn, req *http.Re dump, _ := httputil.DumpResponse(resp, false) h.logger.Debug(string(dump)) } + h.logger.Error(err) return } defer cc.Close() @@ -227,25 +230,21 @@ func (h *Handler) handleRequest(ctx context.Context, conn net.Conn, req *http.Re handler.Transport(conn, cc) } -func (h *Handler) dial(ctx context.Context, addr string) (conn net.Conn, err error) { +func (h *httpHandler) dial(ctx context.Context, addr string) (conn net.Conn, err error) { count := h.md.retryCount + 1 if count <= 0 { count = 1 } for i := 0; i < count; i++ { - route := h.chain.GetRoute() + route := h.chain.GetRouteFor(addr) - /* - buf := bytes.Buffer{} - fmt.Fprintf(&buf, "%s -> %s -> ", - conn.RemoteAddr(), h.options.Node.String()) - for _, nd := range route.route { - fmt.Fprintf(&buf, "%d@%s -> ", nd.ID, nd.String()) - } - fmt.Fprintf(&buf, "%s", host) - log.Log("[route]", buf.String()) - */ + buf := bytes.Buffer{} + for _, node := range route.Path() { + fmt.Fprintf(&buf, "%s@%s -> ", node.Name(), node.Addr()) + } + fmt.Fprintf(&buf, "%s", addr) + h.logger.Infof("route(retry=%d): %s", i, buf.String()) /* // forward http request @@ -270,7 +269,7 @@ func (h *Handler) dial(ctx context.Context, addr string) (conn net.Conn, err err return } -func (h *Handler) decodeServerName(s string) (string, error) { +func (h *httpHandler) decodeServerName(s string) (string, error) { b, err := base64.RawURLEncoding.DecodeString(s) if err != nil { return "", err @@ -288,7 +287,7 @@ func (h *Handler) decodeServerName(s string) (string, error) { return string(v), nil } -func (h *Handler) basicProxyAuth(proxyAuth string) (username, password string, ok bool) { +func (h *httpHandler) basicProxyAuth(proxyAuth string) (username, password string, ok bool) { if proxyAuth == "" { return } @@ -309,7 +308,7 @@ func (h *Handler) basicProxyAuth(proxyAuth string) (username, password string, o return cs[:s], cs[s+1:], true } -func (h *Handler) authenticate(conn net.Conn, req *http.Request, resp *http.Response) (ok bool) { +func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.Response) (ok bool) { u, p, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization")) if h.md.authenticator == nil || h.md.authenticator.Authenticate(u, p) { return true diff --git a/pkg/components/handler/http/middleware.go b/pkg/components/handler/http/middleware.go deleted file mode 100644 index d02cfda..0000000 --- a/pkg/components/handler/http/middleware.go +++ /dev/null @@ -1 +0,0 @@ -package http diff --git a/pkg/components/handler/option.go b/pkg/components/handler/option.go index 42478c4..5333767 100644 --- a/pkg/components/handler/option.go +++ b/pkg/components/handler/option.go @@ -1,12 +1,14 @@ package handler import ( + "github.com/go-gost/gost/pkg/bypass" "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/logger" ) type Options struct { Chain *chain.Chain + Bypass bypass.Bypass Logger logger.Logger } @@ -23,3 +25,9 @@ func ChainOption(chain *chain.Chain) Option { opts.Chain = chain } } + +func BypassOption(bypass bypass.Bypass) Option { + return func(opts *Options) { + opts.Bypass = bypass + } +} diff --git a/pkg/components/handler/ss/handler.go b/pkg/components/handler/ss/handler.go index bc4cd28..b30d7bd 100644 --- a/pkg/components/handler/ss/handler.go +++ b/pkg/components/handler/ss/handler.go @@ -7,6 +7,8 @@ import ( "time" "github.com/go-gost/gosocks5" + "github.com/go-gost/gost/pkg/bypass" + "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/components/handler" md "github.com/go-gost/gost/pkg/components/metadata" "github.com/go-gost/gost/pkg/logger" @@ -19,7 +21,9 @@ func init() { registry.RegisterHandler("ss", NewHandler) } -type Handler struct { +type ssHandler struct { + chain *chain.Chain + bypass bypass.Bypass logger logger.Logger md metadata } @@ -30,18 +34,25 @@ func NewHandler(opts ...handler.Option) handler.Handler { opt(options) } - return &Handler{ + return &ssHandler{ + chain: options.Chain, + bypass: options.Bypass, logger: options.Logger, } } -func (h *Handler) Init(md md.Metadata) (err error) { +func (h *ssHandler) Init(md md.Metadata) (err error) { return h.parseMetadata(md) } -func (h *Handler) Handle(ctx context.Context, conn net.Conn) { +func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() + h.logger = h.logger.WithFields(map[string]interface{}{ + "src": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }) + if h.md.cipher != nil { conn = &shadowConn{ Conn: h.md.cipher.StreamConn(conn), @@ -61,9 +72,18 @@ func (h *Handler) Handle(ctx context.Context, conn net.Conn) { conn.SetReadDeadline(time.Time{}) - host := addr.String() - cc, err := net.Dial("tcp", host) + h.logger = h.logger.WithFields(map[string]interface{}{ + "dst": addr.String(), + }) + + if h.bypass != nil && h.bypass.Contains(addr.String()) { + h.logger.Info("bypass: ", addr.String()) + return + } + + cc, err := h.dial(ctx, addr.String()) if err != nil { + h.logger.Error(err) return } defer cc.Close() @@ -71,7 +91,7 @@ func (h *Handler) Handle(ctx context.Context, conn net.Conn) { handler.Transport(conn, cc) } -func (h *Handler) parseMetadata(md md.Metadata) (err error) { +func (h *ssHandler) parseMetadata(md md.Metadata) (err error) { h.md.cipher, err = h.initCipher( md.GetString(method), md.GetString(password), @@ -82,10 +102,41 @@ func (h *Handler) parseMetadata(md md.Metadata) (err error) { } h.md.readTimeout = md.GetDuration(readTimeout) + h.md.retryCount = md.GetInt(retryCount) return } -func (h *Handler) initCipher(method, password string, key string) (core.Cipher, error) { +func (h *ssHandler) dial(ctx context.Context, addr string) (conn net.Conn, err error) { + count := h.md.retryCount + 1 + if count <= 0 { + count = 1 + } + + for i := 0; i < count; i++ { + route := h.chain.GetRouteFor(addr) + + /* + buf := bytes.Buffer{} + fmt.Fprintf(&buf, "%s -> %s -> ", + conn.RemoteAddr(), h.options.Node.String()) + for _, nd := range route.route { + fmt.Fprintf(&buf, "%d@%s -> ", nd.ID, nd.String()) + } + fmt.Fprintf(&buf, "%s", host) + log.Log("[route]", buf.String()) + */ + + conn, err = route.Dial(ctx, "tcp", addr) + if err == nil { + break + } + h.logger.Errorf("route(retry=%d): %s", i, err) + } + + return +} + +func (h *ssHandler) initCipher(method, password string, key string) (core.Cipher, error) { if method == "" && password == "" { return nil, nil } diff --git a/pkg/components/handler/ss/metadata.go b/pkg/components/handler/ss/metadata.go index 76d49f0..a3e8a1e 100644 --- a/pkg/components/handler/ss/metadata.go +++ b/pkg/components/handler/ss/metadata.go @@ -11,9 +11,11 @@ const ( password = "password" key = "key" readTimeout = "readTimeout" + retryCount = "retry" ) type metadata struct { cipher core.Cipher readTimeout time.Duration + retryCount int } diff --git a/pkg/config/config.go b/pkg/config/config.go index aef01b4..8a70a45 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -30,6 +30,11 @@ type LoadbalancingConfig struct { FailTimeout time.Duration } +type BypassConfig struct { + Name string + Reverse bool + Matchers []string +} type ListenerConfig struct { Type string Metadata map[string]interface{} @@ -57,6 +62,7 @@ type ServiceConfig struct { Listener *ListenerConfig Handler *HandlerConfig Chain string + Bypass string } type ChainConfig struct { @@ -77,12 +83,14 @@ type NodeConfig struct { Addr string Dialer *DialerConfig Connector *ConnectorConfig + Bypass string } type Config struct { Log *LogConfig Services []ServiceConfig Chains []ChainConfig + Bypasses []BypassConfig } func (c *Config) Load() error { diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 3a44f2b..d63294a 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -85,7 +85,10 @@ func NewLogger(opts ...LoggerOption) Logger { switch options.Format { case JSONFormat: - log.SetFormatter(&logrus.JSONFormatter{}) + log.SetFormatter(&logrus.JSONFormatter{ + DisableHTMLEscape: true, + // PrettyPrint: true, + }) default: log.SetFormatter(&logrus.TextFormatter{ FullTimestamp: true,