improve http handler

This commit is contained in:
ginuerzh 2021-10-31 12:41:53 +08:00
parent 248f7e4318
commit 64736585ee
11 changed files with 435 additions and 127 deletions

View File

@ -16,6 +16,121 @@ import (
"github.com/go-gost/gost/pkg/service" "github.com/go-gost/gost/pkg/service"
) )
func buildService(cfg *config.Config) (services []*service.Service) {
if cfg == nil || len(cfg.Services) == 0 {
return
}
chains := buildChain(cfg)
for _, svc := range cfg.Services {
listenerLogger := log.WithFields(map[string]interface{}{
"kind": "listener",
"type": svc.Listener.Type,
"service": svc.Name,
})
ln := registry.GetListener(svc.Listener.Type)(
listener.AddrOption(svc.Addr),
listener.LoggerOption(listenerLogger),
)
if err := ln.Init(metadata.MapMetadata(svc.Listener.Metadata)); err != nil {
listenerLogger.Fatal("init:", err)
}
var chain *chain.Chain
for _, ch := range chains {
if svc.Chain == ch.Name {
chain = ch
break
}
}
handlerLogger := log.WithFields(map[string]interface{}{
"kind": "handler",
"type": svc.Handler.Type,
"service": svc.Name,
})
h := registry.GetHandler(svc.Handler.Type)(
handler.ChainOption(chain),
handler.LoggerOption(handlerLogger),
)
if err := h.Init(metadata.MapMetadata(svc.Handler.Metadata)); err != nil {
handlerLogger.Fatal("init:", err)
}
s := (&service.Service{}).
WithListener(ln).
WithHandler(h)
services = append(services, s)
}
return
}
func buildChain(cfg *config.Config) (chains []*chain.Chain) {
if cfg == nil || len(cfg.Chains) == 0 {
return nil
}
for _, ch := range cfg.Chains {
c := &chain.Chain{
Name: ch.Name,
}
selector := selectorFromConfig(ch.LB)
for _, hop := range ch.Hops {
group := &chain.NodeGroup{}
for _, v := range hop.Nodes {
node := chain.NewNode(v.Name, v.Addr)
connectorLogger := log.WithFields(map[string]interface{}{
"kind": "connector",
"type": v.Connector.Type,
"hop": hop.Name,
"node": node.Name(),
})
cr := registry.GetConnector(v.Connector.Type)(
connector.LoggerOption(connectorLogger),
)
if err := cr.Init(metadata.MapMetadata(v.Connector.Metadata)); err != nil {
connectorLogger.Fatal("init:", err)
}
dialerLogger := log.WithFields(map[string]interface{}{
"kind": "dialer",
"type": v.Dialer.Type,
"hop": hop.Name,
"node": node.Name(),
})
d := registry.GetDialer(v.Dialer.Type)(
dialer.LoggerOption(dialerLogger),
)
if err := d.Init(metadata.MapMetadata(v.Dialer.Metadata)); err != nil {
dialerLogger.Fatal("init:", err)
}
tr := (&chain.Transport{}).
WithConnector(cr).
WithDialer(d)
node.WithTransport(tr)
group.AddNode(node)
}
sel := selector
if s := selectorFromConfig(hop.LB); s != nil {
sel = s
}
group.WithSelector(sel)
c.AddNodeGroup(group)
}
chains = append(chains, c)
}
return
}
func logFromConfig(cfg *config.LogConfig) logger.Logger { func logFromConfig(cfg *config.LogConfig) logger.Logger {
opts := []logger.LoggerOption{ opts := []logger.LoggerOption{
logger.FormatLoggerOption(logger.LogFormat(cfg.Format)), logger.FormatLoggerOption(logger.LogFormat(cfg.Format)),
@ -41,100 +156,29 @@ func logFromConfig(cfg *config.LogConfig) logger.Logger {
return logger.NewLogger(opts...) return logger.NewLogger(opts...)
} }
func buildService(cfg *config.Config) (services []*service.Service) { func selectorFromConfig(cfg *config.LoadbalancingConfig) chain.Selector {
if cfg == nil || len(cfg.Services) == 0 { if cfg == nil {
return
}
chains := buildChain(cfg)
for _, svc := range cfg.Services {
s := &service.Service{}
ln := registry.GetListener(svc.Listener.Type)(
listener.AddrOption(svc.Addr),
listener.LoggerOption(
log.WithFields(map[string]interface{}{
"kind": "listener",
"type": svc.Listener.Type,
}),
),
)
ln.Init(metadata.MapMetadata(svc.Listener.Metadata))
s.WithListener(ln)
var chain *chain.Chain
for _, ch := range chains {
if svc.Chain == ch.Name {
chain = ch
break
}
}
h := registry.GetHandler(svc.Handler.Type)(
handler.ChainOption(chain),
handler.LoggerOption(
log.WithFields(map[string]interface{}{
"kind": "handler",
"type": svc.Handler.Type,
}),
),
)
h.Init(metadata.MapMetadata(svc.Handler.Metadata))
s.WithHandler(h)
services = append(services, s)
}
return
}
func buildChain(cfg *config.Config) (chains []*chain.Chain) {
if cfg == nil || len(cfg.Chains) == 0 {
return nil return nil
} }
for _, ch := range cfg.Chains { var strategy chain.Strategy
c := &chain.Chain{ switch cfg.Strategy {
Name: ch.Name, case "round":
} strategy = &chain.RoundRobinStrategy{}
for _, hop := range ch.Hops { case "random":
group := &chain.NodeGroup{} strategy = &chain.RandomStrategy{}
for _, v := range hop.Nodes { case "fifio":
node := chain.NewNode(v.Name, v.Addr) strategy = &chain.FIFOStrategy{}
default:
tr := &chain.Transport{} strategy = &chain.RoundRobinStrategy{}
cr := registry.GetConnector(v.Connector.Type)(
connector.LoggerOption(
log.WithFields(map[string]interface{}{
"kind": "connector",
"type": v.Connector.Type,
}),
),
)
cr.Init(metadata.MapMetadata(v.Connector.Metadata))
tr.WithConnector(cr)
d := registry.GetDialer(v.Dialer.Type)(
dialer.LoggerOption(
log.WithFields(map[string]interface{}{
"kind": "dialer",
"type": v.Dialer.Type,
}),
),
)
d.Init(metadata.MapMetadata(v.Dialer.Metadata))
tr.WithDialer(d)
node.WithTransport(tr)
group.AddNode(node)
}
c.AddNodeGroup(group)
}
chains = append(chains, c)
} }
return return chain.NewSelector(
strategy,
&chain.InvalidFilter{},
&chain.FailFilter{
MaxFails: cfg.MaxFails,
FailTimeout: cfg.FailTimeout,
},
)
} }

View File

@ -4,46 +4,77 @@ log:
format: json # text, json format: json # text, json
services: services:
- url: "http://gost:gost@:8000" - name: http+tcp
url: "http://gost:gost@:8000"
addr: ":8000" addr: ":8000"
handler: handler:
type: http type: http
metadata: metadata:
proxyAgent: "gost/3.0" proxyAgent: "gost/3.0"
retry: 3
auths: auths:
- user1:pass1 - user1:pass1
- user2:pass2 - user2:pass2
# probeResist: code:404 # code, web, host, file
# knock: example.com
listener: listener:
type: tcp type: tcp
metadata: metadata:
keepAlive: 15s keepAlive: 15s
# chain: chain01 chain: chain01
chains: chains:
- name: chain01 - name: chain01
# chain level load balancing # chain level load balancing
lb: lb:
strategy: round strategy: round
filters: maxFails: 1
- filter1 failTimeout: 30s
hops: hops:
- name: level01 - name: hop01
# hop level load balancing # hop level load balancing
lb: lb:
strategy: rand strategy: round
filters: maxFails: 1
- filter1 failTimeout: 30s
nodes: nodes:
- name: node01 - name: node01
addr: ":8080" addr: ":8081"
url: "http://gost:gost@:8080" url: "http://gost:gost@:8081"
connector: connector:
type: http type: http
metadata: metadata:
userAgent: "gost/3.0" userAgent: "gost/3.0"
auth: username:password auth: user1:pass1
dialer:
type: tcp
metadata: {}
- name: node02
addr: ":8082"
url: "http://gost:gost@:8082"
connector:
type: http
metadata:
userAgent: "gost/3.0"
auth: user1:pass1
dialer:
type: tcp
metadata: {}
- name: hop02
# hop level load balancing
lb:
strategy: round
maxFails: 1
failTimeout: 30s
nodes:
- name: node03
addr: ":8083"
url: "http://gost:gost@:8083"
connector:
type: http
metadata:
userAgent: "gost/3.0"
auth: user1:pass1
dialer: dialer:
type: tcp type: tcp
metadata: {} metadata: {}

View File

@ -1,6 +1,8 @@
package main package main
import ( import (
stdlog "log"
"github.com/go-gost/gost/pkg/config" "github.com/go-gost/gost/pkg/config"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
) )
@ -10,6 +12,7 @@ var (
) )
func main() { func main() {
stdlog.SetFlags(stdlog.LstdFlags | stdlog.Lshortfile)
cfg := &config.Config{} cfg := &config.Config{}
if err := cfg.Load(); err != nil { if err := cfg.Load(); err != nil {
log.Fatal(err) log.Fatal(err)

View File

@ -1,15 +1,22 @@
package chain package chain
import (
"sync"
"time"
)
type Node struct { type Node struct {
name string name string
addr string addr string
transport *Transport transport *Transport
marker *failMarker
} }
func NewNode(name, addr string) *Node { func NewNode(name, addr string) *Node {
return &Node{ return &Node{
name: name, name: name,
addr: addr, addr: addr,
marker: &failMarker{},
} }
} }
@ -45,15 +52,72 @@ func (g *NodeGroup) AddNode(node *Node) {
g.nodes = append(g.nodes, node) g.nodes = append(g.nodes, node)
} }
func (g *NodeGroup) WithSelector(selector Selector) { func (g *NodeGroup) WithSelector(selector Selector) *NodeGroup {
g.selector = selector g.selector = selector
return g
} }
func (g *NodeGroup) Next() *Node { func (g *NodeGroup) Next() *Node {
if g == nil || len(g.nodes) == 0 {
return nil
}
selector := g.selector selector := g.selector
if selector == nil { if selector == nil {
// selector = defaultSelector
return g.nodes[0] return g.nodes[0]
} }
return selector.Select(g.nodes...) return selector.Select(g.nodes...)
} }
type failMarker struct {
failTime int64
failCount uint32
mux sync.RWMutex
}
func (m *failMarker) FailTime() int64 {
if m == nil {
return 0
}
m.mux.RLock()
defer m.mux.RUnlock()
return m.failTime
}
func (m *failMarker) FailCount() uint32 {
if m == nil {
return 0
}
m.mux.RLock()
defer m.mux.RUnlock()
return m.failCount
}
func (m *failMarker) Mark() {
if m == nil {
return
}
m.mux.Lock()
defer m.mux.Unlock()
m.failTime = time.Now().Unix()
m.failCount++
}
func (m *failMarker) Reset() {
if m == nil {
return
}
m.mux.Lock()
defer m.mux.Unlock()
m.failTime = 0
m.failCount = 0
}

View File

@ -22,26 +22,34 @@ func (r *Route) Connect(ctx context.Context) (conn net.Conn, err error) {
node := r.nodes[0] node := r.nodes[0]
cc, err := node.Transport().Dial(ctx, r.nodes[0].Addr()) cc, err := node.Transport().Dial(ctx, r.nodes[0].Addr())
if err != nil { if err != nil {
node.marker.Mark()
return return
} }
cn, err := node.Transport().Handshake(ctx, cc) cn, err := node.Transport().Handshake(ctx, cc)
if err != nil { if err != nil {
cc.Close() cc.Close()
node.marker.Mark()
return return
} }
node.marker.Reset()
preNode := node preNode := node
for _, node := range r.nodes[1:] { for _, node := range r.nodes[1:] {
cc, err = preNode.Transport().Connect(ctx, cn, "tcp", node.Addr()) cc, err = preNode.Transport().Connect(ctx, cn, "tcp", node.Addr())
if err != nil { if err != nil {
cn.Close() cn.Close()
node.marker.Mark()
return return
} }
cc, err = node.transport.Handshake(ctx, cc) cc, err = node.transport.Handshake(ctx, cc)
if err != nil { if err != nil {
cn.Close() cn.Close()
node.marker.Mark()
return
} }
node.marker.Reset()
cn = cc cn = cc
preNode = node preNode = node
} }

View File

@ -1,19 +1,24 @@
package chain package chain
import (
"math/rand"
"net"
"strconv"
"sync"
"sync/atomic"
"time"
)
// default options for FailFilter
const (
DefaultMaxFails = 1
DefaultFailTimeout = 30 * time.Second
)
var ( var (
defaultSelector Selector = NewSelector(nil) defaultSelector Selector = NewSelector(nil)
) )
type Filter interface {
Filter(nodes ...*Node) []*Node
String() string
}
type Strategy interface {
Apply(nodes ...*Node) *Node
String() string
}
type Selector interface { type Selector interface {
Select(nodes ...*Node) *Node Select(nodes ...*Node) *Node
} }
@ -39,3 +44,115 @@ func (s *selector) Select(nodes ...*Node) *Node {
} }
return s.strategy.Apply(nodes...) return s.strategy.Apply(nodes...)
} }
type Strategy interface {
Apply(nodes ...*Node) *Node
}
// RoundStrategy is a strategy for node selector.
// The node will be selected by round-robin algorithm.
type RoundRobinStrategy struct {
counter uint64
}
func (s *RoundRobinStrategy) Apply(nodes ...*Node) *Node {
if len(nodes) == 0 {
return nil
}
n := atomic.AddUint64(&s.counter, 1) - 1
return nodes[int(n%uint64(len(nodes)))]
}
// RandomStrategy is a strategy for node selector.
// The node will be selected randomly.
type RandomStrategy struct {
Seed int64
rand *rand.Rand
once sync.Once
mux sync.Mutex
}
func (s *RandomStrategy) Apply(nodes ...*Node) *Node {
s.once.Do(func() {
seed := s.Seed
if seed == 0 {
seed = time.Now().UnixNano()
}
s.rand = rand.New(rand.NewSource(seed))
})
if len(nodes) == 0 {
return nil
}
s.mux.Lock()
defer s.mux.Unlock()
r := s.rand.Int()
return nodes[r%len(nodes)]
}
// FIFOStrategy is a strategy for node selector.
// The node will be selected from first to last,
// and will stick to the selected node until it is failed.
type FIFOStrategy struct{}
// Apply applies the fifo strategy for the nodes.
func (s *FIFOStrategy) Apply(nodes ...*Node) *Node {
if len(nodes) == 0 {
return nil
}
return nodes[0]
}
type Filter interface {
Filter(nodes ...*Node) []*Node
}
// FailFilter filters the dead node.
// A node is marked as dead if its failed count is greater than MaxFails.
type FailFilter struct {
MaxFails int
FailTimeout time.Duration
}
// Filter filters dead nodes.
func (f *FailFilter) Filter(nodes ...*Node) []*Node {
maxFails := f.MaxFails
if maxFails == 0 {
maxFails = DefaultMaxFails
}
failTimeout := f.FailTimeout
if failTimeout == 0 {
failTimeout = DefaultFailTimeout
}
if len(nodes) <= 1 || maxFails < 0 {
return nodes
}
var nl []*Node
for _, node := range nodes {
if node.marker.FailCount() < uint32(maxFails) ||
time.Since(time.Unix(node.marker.FailTime(), 0)) >= failTimeout {
nl = append(nl, node)
}
}
return nl
}
// InvalidFilter filters the invalid node.
// A node is invalid if its port is invalid (negative or zero value).
type InvalidFilter struct{}
// Filter filters invalid nodes.
func (f *InvalidFilter) Filter(nodes ...*Node) []*Node {
var nl []*Node
for _, node := range nodes {
_, sport, _ := net.SplitHostPort(node.Addr())
if port, _ := strconv.Atoi(sport); port > 0 {
nl = append(nl, node)
}
}
return nl
}

View File

@ -5,9 +5,9 @@ import (
"context" "context"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"log"
"net" "net"
"net/http" "net/http"
"net/http/httputil"
"net/url" "net/url"
"strings" "strings"
@ -51,11 +51,15 @@ func (c *Connector) Connect(ctx context.Context, conn net.Conn, network, address
Header: make(http.Header), Header: make(http.Header),
} }
if c.md.UserAgent != "" { if c.md.UserAgent != "" {
log.Println(c.md.UserAgent)
req.Header.Set("User-Agent", c.md.UserAgent) req.Header.Set("User-Agent", c.md.UserAgent)
} }
req.Header.Set("Proxy-Connection", "keep-alive") req.Header.Set("Proxy-Connection", "keep-alive")
c.logger = c.logger.WithFields(map[string]interface{}{
"src": conn.LocalAddr().String(),
"dst": conn.RemoteAddr().String(),
})
if user := c.md.User; user != nil { if user := c.md.User; user != nil {
u := user.Username() u := user.Username()
p, _ := user.Password() p, _ := user.Password()
@ -63,6 +67,11 @@ func (c *Connector) Connect(ctx context.Context, conn net.Conn, network, address
"Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p))) "Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p)))
} }
if c.logger.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpRequest(req, false)
c.logger.Debug(string(dump))
}
req = req.WithContext(ctx) req = req.WithContext(ctx)
if err := req.Write(conn); err != nil { if err := req.Write(conn); err != nil {
return nil, err return nil, err
@ -74,6 +83,11 @@ func (c *Connector) Connect(ctx context.Context, conn net.Conn, network, address
} }
defer resp.Body.Close() defer resp.Body.Close()
if c.logger.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpResponse(resp, false)
c.logger.Debug(string(dump))
}
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%s", resp.Status) return nil, fmt.Errorf("%s", resp.Status)
} }

View File

@ -42,11 +42,33 @@ func (d *Dialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOptio
dial := options.DialFunc dial := options.DialFunc
if dial != nil { if dial != nil {
return dial(ctx, addr) conn, err := dial(ctx, addr)
if err != nil {
d.logger.Error(err)
} else {
if d.logger.IsLevelEnabled(logger.DebugLevel) {
d.logger.WithFields(map[string]interface{}{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debug("dial with dial func")
}
}
return conn, err
} }
var netd net.Dialer var netd net.Dialer
return netd.DialContext(ctx, "tcp", addr) conn, err := netd.DialContext(ctx, "tcp", addr)
if err != nil {
d.logger.Error(err)
} else {
if d.logger.IsLevelEnabled(logger.DebugLevel) {
d.logger.WithFields(map[string]interface{}{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debug("dial direct")
}
}
return conn, err
} }
func (d *Dialer) parseMetadata(md md.Metadata) (err error) { func (d *Dialer) parseMetadata(md md.Metadata) (err error) {

View File

@ -76,6 +76,7 @@ func (h *Handler) parseMetadata(md md.Metadata) error {
} }
} }
} }
h.md.retryCount = md.GetInt(retryCount)
return nil return nil
} }
@ -260,10 +261,10 @@ func (h *Handler) dial(ctx context.Context, addr string) (conn net.Conn, err err
*/ */
conn, err = route.Dial(ctx, "tcp", addr) conn, err = route.Dial(ctx, "tcp", addr)
if err != nil { if err == nil {
h.logger.Warn("retry:", err) break
continue
} }
h.logger.Errorf("route(retry=%d): %s", i, err)
} }
return return

View File

@ -8,6 +8,7 @@ const (
authsKey = "auths" authsKey = "auths"
probeResistKey = "probeResist" probeResistKey = "probeResist"
knockKey = "knock" knockKey = "knock"
retryCount = "retry"
) )
type metadata struct { type metadata struct {

View File

@ -2,6 +2,7 @@ package config
import ( import (
"io" "io"
"time"
"github.com/spf13/viper" "github.com/spf13/viper"
) )
@ -24,8 +25,9 @@ type LogConfig struct {
} }
type LoadbalancingConfig struct { type LoadbalancingConfig struct {
Strategy string Strategy string
Filters []string MaxFails int
FailTimeout time.Duration
} }
type ListenerConfig struct { type ListenerConfig struct {
@ -49,6 +51,7 @@ type ConnectorConfig struct {
} }
type ServiceConfig struct { type ServiceConfig struct {
Name string
URL string URL string
Addr string Addr string
Listener *ListenerConfig Listener *ListenerConfig