From 64736585ee415c31006e53250241a0903e3f6056 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sun, 31 Oct 2021 12:41:53 +0800 Subject: [PATCH] improve http handler --- cmd/gost/config.go | 226 ++++++++++++--------- cmd/gost/gost.yml | 57 ++++-- cmd/gost/main.go | 3 + pkg/chain/node.go | 72 ++++++- pkg/chain/route.go | 8 + pkg/chain/selector.go | 137 ++++++++++++- pkg/components/connector/http/connector.go | 18 +- pkg/components/dialer/tcp/dialer.go | 26 ++- pkg/components/handler/http/handler.go | 7 +- pkg/components/handler/http/metadata.go | 1 + pkg/config/config.go | 7 +- 11 files changed, 435 insertions(+), 127 deletions(-) diff --git a/cmd/gost/config.go b/cmd/gost/config.go index 04545b7..e617f35 100644 --- a/cmd/gost/config.go +++ b/cmd/gost/config.go @@ -16,6 +16,121 @@ import ( "github.com/go-gost/gost/pkg/service" ) +func buildService(cfg *config.Config) (services []*service.Service) { + if cfg == nil || len(cfg.Services) == 0 { + return + } + + chains := buildChain(cfg) + + for _, svc := range cfg.Services { + listenerLogger := log.WithFields(map[string]interface{}{ + "kind": "listener", + "type": svc.Listener.Type, + "service": svc.Name, + }) + ln := registry.GetListener(svc.Listener.Type)( + listener.AddrOption(svc.Addr), + listener.LoggerOption(listenerLogger), + ) + if err := ln.Init(metadata.MapMetadata(svc.Listener.Metadata)); err != nil { + listenerLogger.Fatal("init:", err) + } + + var chain *chain.Chain + for _, ch := range chains { + if svc.Chain == ch.Name { + chain = ch + break + } + } + + handlerLogger := log.WithFields(map[string]interface{}{ + "kind": "handler", + "type": svc.Handler.Type, + "service": svc.Name, + }) + h := registry.GetHandler(svc.Handler.Type)( + handler.ChainOption(chain), + handler.LoggerOption(handlerLogger), + ) + if err := h.Init(metadata.MapMetadata(svc.Handler.Metadata)); err != nil { + handlerLogger.Fatal("init:", err) + } + + s := (&service.Service{}). + WithListener(ln). + WithHandler(h) + services = append(services, s) + } + + return +} + +func buildChain(cfg *config.Config) (chains []*chain.Chain) { + if cfg == nil || len(cfg.Chains) == 0 { + return nil + } + + for _, ch := range cfg.Chains { + c := &chain.Chain{ + Name: ch.Name, + } + + selector := selectorFromConfig(ch.LB) + for _, hop := range ch.Hops { + group := &chain.NodeGroup{} + for _, v := range hop.Nodes { + node := chain.NewNode(v.Name, v.Addr) + + connectorLogger := log.WithFields(map[string]interface{}{ + "kind": "connector", + "type": v.Connector.Type, + "hop": hop.Name, + "node": node.Name(), + }) + cr := registry.GetConnector(v.Connector.Type)( + connector.LoggerOption(connectorLogger), + ) + if err := cr.Init(metadata.MapMetadata(v.Connector.Metadata)); err != nil { + connectorLogger.Fatal("init:", err) + } + + dialerLogger := log.WithFields(map[string]interface{}{ + "kind": "dialer", + "type": v.Dialer.Type, + "hop": hop.Name, + "node": node.Name(), + }) + d := registry.GetDialer(v.Dialer.Type)( + dialer.LoggerOption(dialerLogger), + ) + if err := d.Init(metadata.MapMetadata(v.Dialer.Metadata)); err != nil { + dialerLogger.Fatal("init:", err) + } + + tr := (&chain.Transport{}). + WithConnector(cr). + WithDialer(d) + + node.WithTransport(tr) + group.AddNode(node) + } + + sel := selector + if s := selectorFromConfig(hop.LB); s != nil { + sel = s + } + group.WithSelector(sel) + c.AddNodeGroup(group) + } + + chains = append(chains, c) + } + + return +} + func logFromConfig(cfg *config.LogConfig) logger.Logger { opts := []logger.LoggerOption{ logger.FormatLoggerOption(logger.LogFormat(cfg.Format)), @@ -41,100 +156,29 @@ func logFromConfig(cfg *config.LogConfig) logger.Logger { return logger.NewLogger(opts...) } -func buildService(cfg *config.Config) (services []*service.Service) { - if cfg == nil || len(cfg.Services) == 0 { - return - } - - chains := buildChain(cfg) - - for _, svc := range cfg.Services { - s := &service.Service{} - - ln := registry.GetListener(svc.Listener.Type)( - listener.AddrOption(svc.Addr), - listener.LoggerOption( - log.WithFields(map[string]interface{}{ - "kind": "listener", - "type": svc.Listener.Type, - }), - ), - ) - ln.Init(metadata.MapMetadata(svc.Listener.Metadata)) - s.WithListener(ln) - - var chain *chain.Chain - for _, ch := range chains { - if svc.Chain == ch.Name { - chain = ch - break - } - } - h := registry.GetHandler(svc.Handler.Type)( - handler.ChainOption(chain), - handler.LoggerOption( - log.WithFields(map[string]interface{}{ - "kind": "handler", - "type": svc.Handler.Type, - }), - ), - ) - h.Init(metadata.MapMetadata(svc.Handler.Metadata)) - s.WithHandler(h) - - services = append(services, s) - } - - return -} - -func buildChain(cfg *config.Config) (chains []*chain.Chain) { - if cfg == nil || len(cfg.Chains) == 0 { +func selectorFromConfig(cfg *config.LoadbalancingConfig) chain.Selector { + if cfg == nil { return nil } - for _, ch := range cfg.Chains { - c := &chain.Chain{ - Name: ch.Name, - } - for _, hop := range ch.Hops { - group := &chain.NodeGroup{} - for _, v := range hop.Nodes { - node := chain.NewNode(v.Name, v.Addr) - - tr := &chain.Transport{} - - cr := registry.GetConnector(v.Connector.Type)( - connector.LoggerOption( - log.WithFields(map[string]interface{}{ - "kind": "connector", - "type": v.Connector.Type, - }), - ), - ) - cr.Init(metadata.MapMetadata(v.Connector.Metadata)) - tr.WithConnector(cr) - - d := registry.GetDialer(v.Dialer.Type)( - dialer.LoggerOption( - log.WithFields(map[string]interface{}{ - "kind": "dialer", - "type": v.Dialer.Type, - }), - ), - ) - d.Init(metadata.MapMetadata(v.Dialer.Metadata)) - tr.WithDialer(d) - - node.WithTransport(tr) - - group.AddNode(node) - } - c.AddNodeGroup(group) - } - - chains = append(chains, c) + var strategy chain.Strategy + switch cfg.Strategy { + case "round": + strategy = &chain.RoundRobinStrategy{} + case "random": + strategy = &chain.RandomStrategy{} + case "fifio": + strategy = &chain.FIFOStrategy{} + default: + strategy = &chain.RoundRobinStrategy{} } - return + return chain.NewSelector( + strategy, + &chain.InvalidFilter{}, + &chain.FailFilter{ + MaxFails: cfg.MaxFails, + FailTimeout: cfg.FailTimeout, + }, + ) } diff --git a/cmd/gost/gost.yml b/cmd/gost/gost.yml index 12f7bac..b30cab3 100644 --- a/cmd/gost/gost.yml +++ b/cmd/gost/gost.yml @@ -4,46 +4,77 @@ log: format: json # text, json services: -- url: "http://gost:gost@:8000" +- name: http+tcp + url: "http://gost:gost@:8000" addr: ":8000" handler: type: http metadata: proxyAgent: "gost/3.0" + retry: 3 auths: - user1:pass1 - user2:pass2 + # probeResist: code:404 # code, web, host, file + # knock: example.com listener: type: tcp metadata: keepAlive: 15s - # chain: chain01 + chain: chain01 chains: - name: chain01 # chain level load balancing lb: strategy: round - filters: - - filter1 + maxFails: 1 + failTimeout: 30s hops: - - name: level01 + - name: hop01 # hop level load balancing lb: - strategy: rand - filters: - - filter1 + strategy: round + maxFails: 1 + failTimeout: 30s nodes: - name: node01 - addr: ":8080" - url: "http://gost:gost@:8080" + addr: ":8081" + url: "http://gost:gost@:8081" connector: type: http metadata: userAgent: "gost/3.0" - auth: username:password + auth: user1:pass1 dialer: type: tcp metadata: {} - - + - name: node02 + addr: ":8082" + url: "http://gost:gost@:8082" + connector: + type: http + metadata: + userAgent: "gost/3.0" + auth: user1:pass1 + dialer: + type: tcp + metadata: {} + - name: hop02 + # hop level load balancing + lb: + strategy: round + maxFails: 1 + failTimeout: 30s + nodes: + - name: node03 + addr: ":8083" + url: "http://gost:gost@:8083" + connector: + type: http + metadata: + userAgent: "gost/3.0" + auth: user1:pass1 + dialer: + type: tcp + metadata: {} \ No newline at end of file diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 0716122..12f5dfe 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -1,6 +1,8 @@ package main import ( + stdlog "log" + "github.com/go-gost/gost/pkg/config" "github.com/go-gost/gost/pkg/logger" ) @@ -10,6 +12,7 @@ var ( ) func main() { + stdlog.SetFlags(stdlog.LstdFlags | stdlog.Lshortfile) cfg := &config.Config{} if err := cfg.Load(); err != nil { log.Fatal(err) diff --git a/pkg/chain/node.go b/pkg/chain/node.go index e47dcc0..c6ae7aa 100644 --- a/pkg/chain/node.go +++ b/pkg/chain/node.go @@ -1,15 +1,22 @@ package chain +import ( + "sync" + "time" +) + type Node struct { name string addr string transport *Transport + marker *failMarker } func NewNode(name, addr string) *Node { return &Node{ - name: name, - addr: addr, + name: name, + addr: addr, + marker: &failMarker{}, } } @@ -45,15 +52,72 @@ func (g *NodeGroup) AddNode(node *Node) { g.nodes = append(g.nodes, node) } -func (g *NodeGroup) WithSelector(selector Selector) { +func (g *NodeGroup) WithSelector(selector Selector) *NodeGroup { g.selector = selector + return g } func (g *NodeGroup) Next() *Node { + if g == nil || len(g.nodes) == 0 { + return nil + } + selector := g.selector if selector == nil { - // selector = defaultSelector return g.nodes[0] } + return selector.Select(g.nodes...) } + +type failMarker struct { + failTime int64 + failCount uint32 + mux sync.RWMutex +} + +func (m *failMarker) FailTime() int64 { + if m == nil { + return 0 + } + + m.mux.RLock() + defer m.mux.RUnlock() + + return m.failTime +} + +func (m *failMarker) FailCount() uint32 { + if m == nil { + return 0 + } + + m.mux.RLock() + defer m.mux.RUnlock() + + return m.failCount +} + +func (m *failMarker) Mark() { + if m == nil { + return + } + + m.mux.Lock() + defer m.mux.Unlock() + + m.failTime = time.Now().Unix() + m.failCount++ +} + +func (m *failMarker) Reset() { + if m == nil { + return + } + + m.mux.Lock() + defer m.mux.Unlock() + + m.failTime = 0 + m.failCount = 0 +} diff --git a/pkg/chain/route.go b/pkg/chain/route.go index 8536918..1fad1cc 100644 --- a/pkg/chain/route.go +++ b/pkg/chain/route.go @@ -22,26 +22,34 @@ func (r *Route) Connect(ctx context.Context) (conn net.Conn, err error) { node := r.nodes[0] cc, err := node.Transport().Dial(ctx, r.nodes[0].Addr()) if err != nil { + node.marker.Mark() return } cn, err := node.Transport().Handshake(ctx, cc) if err != nil { cc.Close() + node.marker.Mark() return } + node.marker.Reset() preNode := node for _, node := range r.nodes[1:] { cc, err = preNode.Transport().Connect(ctx, cn, "tcp", node.Addr()) if err != nil { cn.Close() + node.marker.Mark() return } cc, err = node.transport.Handshake(ctx, cc) if err != nil { cn.Close() + node.marker.Mark() + return } + node.marker.Reset() + cn = cc preNode = node } diff --git a/pkg/chain/selector.go b/pkg/chain/selector.go index cc70e75..bb39b9e 100644 --- a/pkg/chain/selector.go +++ b/pkg/chain/selector.go @@ -1,19 +1,24 @@ package chain +import ( + "math/rand" + "net" + "strconv" + "sync" + "sync/atomic" + "time" +) + +// default options for FailFilter +const ( + DefaultMaxFails = 1 + DefaultFailTimeout = 30 * time.Second +) + var ( defaultSelector Selector = NewSelector(nil) ) -type Filter interface { - Filter(nodes ...*Node) []*Node - String() string -} - -type Strategy interface { - Apply(nodes ...*Node) *Node - String() string -} - type Selector interface { Select(nodes ...*Node) *Node } @@ -39,3 +44,115 @@ func (s *selector) Select(nodes ...*Node) *Node { } return s.strategy.Apply(nodes...) } + +type Strategy interface { + Apply(nodes ...*Node) *Node +} + +// RoundStrategy is a strategy for node selector. +// The node will be selected by round-robin algorithm. +type RoundRobinStrategy struct { + counter uint64 +} + +func (s *RoundRobinStrategy) Apply(nodes ...*Node) *Node { + if len(nodes) == 0 { + return nil + } + + n := atomic.AddUint64(&s.counter, 1) - 1 + return nodes[int(n%uint64(len(nodes)))] +} + +// RandomStrategy is a strategy for node selector. +// The node will be selected randomly. +type RandomStrategy struct { + Seed int64 + rand *rand.Rand + once sync.Once + mux sync.Mutex +} + +func (s *RandomStrategy) Apply(nodes ...*Node) *Node { + s.once.Do(func() { + seed := s.Seed + if seed == 0 { + seed = time.Now().UnixNano() + } + s.rand = rand.New(rand.NewSource(seed)) + }) + if len(nodes) == 0 { + return nil + } + + s.mux.Lock() + defer s.mux.Unlock() + + r := s.rand.Int() + + return nodes[r%len(nodes)] +} + +// FIFOStrategy is a strategy for node selector. +// The node will be selected from first to last, +// and will stick to the selected node until it is failed. +type FIFOStrategy struct{} + +// Apply applies the fifo strategy for the nodes. +func (s *FIFOStrategy) Apply(nodes ...*Node) *Node { + if len(nodes) == 0 { + return nil + } + return nodes[0] +} + +type Filter interface { + Filter(nodes ...*Node) []*Node +} + +// FailFilter filters the dead node. +// A node is marked as dead if its failed count is greater than MaxFails. +type FailFilter struct { + MaxFails int + FailTimeout time.Duration +} + +// Filter filters dead nodes. +func (f *FailFilter) Filter(nodes ...*Node) []*Node { + maxFails := f.MaxFails + if maxFails == 0 { + maxFails = DefaultMaxFails + } + failTimeout := f.FailTimeout + if failTimeout == 0 { + failTimeout = DefaultFailTimeout + } + + if len(nodes) <= 1 || maxFails < 0 { + return nodes + } + var nl []*Node + for _, node := range nodes { + if node.marker.FailCount() < uint32(maxFails) || + time.Since(time.Unix(node.marker.FailTime(), 0)) >= failTimeout { + nl = append(nl, node) + } + } + return nl +} + +// InvalidFilter filters the invalid node. +// A node is invalid if its port is invalid (negative or zero value). +type InvalidFilter struct{} + +// Filter filters invalid nodes. +func (f *InvalidFilter) Filter(nodes ...*Node) []*Node { + var nl []*Node + for _, node := range nodes { + _, sport, _ := net.SplitHostPort(node.Addr()) + if port, _ := strconv.Atoi(sport); port > 0 { + nl = append(nl, node) + } + } + return nl +} diff --git a/pkg/components/connector/http/connector.go b/pkg/components/connector/http/connector.go index 040a8eb..a26d4a7 100644 --- a/pkg/components/connector/http/connector.go +++ b/pkg/components/connector/http/connector.go @@ -5,9 +5,9 @@ import ( "context" "encoding/base64" "fmt" - "log" "net" "net/http" + "net/http/httputil" "net/url" "strings" @@ -51,11 +51,15 @@ func (c *Connector) Connect(ctx context.Context, conn net.Conn, network, address Header: make(http.Header), } if c.md.UserAgent != "" { - log.Println(c.md.UserAgent) req.Header.Set("User-Agent", c.md.UserAgent) } req.Header.Set("Proxy-Connection", "keep-alive") + c.logger = c.logger.WithFields(map[string]interface{}{ + "src": conn.LocalAddr().String(), + "dst": conn.RemoteAddr().String(), + }) + if user := c.md.User; user != nil { u := user.Username() p, _ := user.Password() @@ -63,6 +67,11 @@ func (c *Connector) Connect(ctx context.Context, conn net.Conn, network, address "Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p))) } + if c.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpRequest(req, false) + c.logger.Debug(string(dump)) + } + req = req.WithContext(ctx) if err := req.Write(conn); err != nil { return nil, err @@ -74,6 +83,11 @@ func (c *Connector) Connect(ctx context.Context, conn net.Conn, network, address } defer resp.Body.Close() + if c.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(resp, false) + c.logger.Debug(string(dump)) + } + if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("%s", resp.Status) } diff --git a/pkg/components/dialer/tcp/dialer.go b/pkg/components/dialer/tcp/dialer.go index 6845319..9d0dc02 100644 --- a/pkg/components/dialer/tcp/dialer.go +++ b/pkg/components/dialer/tcp/dialer.go @@ -42,11 +42,33 @@ func (d *Dialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOptio dial := options.DialFunc if dial != nil { - return dial(ctx, addr) + conn, err := dial(ctx, addr) + if err != nil { + d.logger.Error(err) + } else { + if d.logger.IsLevelEnabled(logger.DebugLevel) { + d.logger.WithFields(map[string]interface{}{ + "src": conn.LocalAddr().String(), + "dst": addr, + }).Debug("dial with dial func") + } + } + return conn, err } var netd net.Dialer - return netd.DialContext(ctx, "tcp", addr) + conn, err := netd.DialContext(ctx, "tcp", addr) + if err != nil { + d.logger.Error(err) + } else { + if d.logger.IsLevelEnabled(logger.DebugLevel) { + d.logger.WithFields(map[string]interface{}{ + "src": conn.LocalAddr().String(), + "dst": addr, + }).Debug("dial direct") + } + } + return conn, err } func (d *Dialer) parseMetadata(md md.Metadata) (err error) { diff --git a/pkg/components/handler/http/handler.go b/pkg/components/handler/http/handler.go index 94912fe..da40661 100644 --- a/pkg/components/handler/http/handler.go +++ b/pkg/components/handler/http/handler.go @@ -76,6 +76,7 @@ func (h *Handler) parseMetadata(md md.Metadata) error { } } } + h.md.retryCount = md.GetInt(retryCount) return nil } @@ -260,10 +261,10 @@ func (h *Handler) dial(ctx context.Context, addr string) (conn net.Conn, err err */ conn, err = route.Dial(ctx, "tcp", addr) - if err != nil { - h.logger.Warn("retry:", err) - continue + if err == nil { + break } + h.logger.Errorf("route(retry=%d): %s", i, err) } return diff --git a/pkg/components/handler/http/metadata.go b/pkg/components/handler/http/metadata.go index 92e8ada..d28806e 100644 --- a/pkg/components/handler/http/metadata.go +++ b/pkg/components/handler/http/metadata.go @@ -8,6 +8,7 @@ const ( authsKey = "auths" probeResistKey = "probeResist" knockKey = "knock" + retryCount = "retry" ) type metadata struct { diff --git a/pkg/config/config.go b/pkg/config/config.go index 6602f83..aef01b4 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -2,6 +2,7 @@ package config import ( "io" + "time" "github.com/spf13/viper" ) @@ -24,8 +25,9 @@ type LogConfig struct { } type LoadbalancingConfig struct { - Strategy string - Filters []string + Strategy string + MaxFails int + FailTimeout time.Duration } type ListenerConfig struct { @@ -49,6 +51,7 @@ type ConnectorConfig struct { } type ServiceConfig struct { + Name string URL string Addr string Listener *ListenerConfig