improve http handler

This commit is contained in:
ginuerzh
2021-10-31 12:41:53 +08:00
parent 248f7e4318
commit 64736585ee
11 changed files with 435 additions and 127 deletions

View File

@ -1,15 +1,22 @@
package chain
import (
"sync"
"time"
)
type Node struct {
name string
addr string
transport *Transport
marker *failMarker
}
func NewNode(name, addr string) *Node {
return &Node{
name: name,
addr: addr,
name: name,
addr: addr,
marker: &failMarker{},
}
}
@ -45,15 +52,72 @@ func (g *NodeGroup) AddNode(node *Node) {
g.nodes = append(g.nodes, node)
}
func (g *NodeGroup) WithSelector(selector Selector) {
func (g *NodeGroup) WithSelector(selector Selector) *NodeGroup {
g.selector = selector
return g
}
func (g *NodeGroup) Next() *Node {
if g == nil || len(g.nodes) == 0 {
return nil
}
selector := g.selector
if selector == nil {
// selector = defaultSelector
return g.nodes[0]
}
return selector.Select(g.nodes...)
}
type failMarker struct {
failTime int64
failCount uint32
mux sync.RWMutex
}
func (m *failMarker) FailTime() int64 {
if m == nil {
return 0
}
m.mux.RLock()
defer m.mux.RUnlock()
return m.failTime
}
func (m *failMarker) FailCount() uint32 {
if m == nil {
return 0
}
m.mux.RLock()
defer m.mux.RUnlock()
return m.failCount
}
func (m *failMarker) Mark() {
if m == nil {
return
}
m.mux.Lock()
defer m.mux.Unlock()
m.failTime = time.Now().Unix()
m.failCount++
}
func (m *failMarker) Reset() {
if m == nil {
return
}
m.mux.Lock()
defer m.mux.Unlock()
m.failTime = 0
m.failCount = 0
}

View File

@ -22,26 +22,34 @@ func (r *Route) Connect(ctx context.Context) (conn net.Conn, err error) {
node := r.nodes[0]
cc, err := node.Transport().Dial(ctx, r.nodes[0].Addr())
if err != nil {
node.marker.Mark()
return
}
cn, err := node.Transport().Handshake(ctx, cc)
if err != nil {
cc.Close()
node.marker.Mark()
return
}
node.marker.Reset()
preNode := node
for _, node := range r.nodes[1:] {
cc, err = preNode.Transport().Connect(ctx, cn, "tcp", node.Addr())
if err != nil {
cn.Close()
node.marker.Mark()
return
}
cc, err = node.transport.Handshake(ctx, cc)
if err != nil {
cn.Close()
node.marker.Mark()
return
}
node.marker.Reset()
cn = cc
preNode = node
}

View File

@ -1,19 +1,24 @@
package chain
import (
"math/rand"
"net"
"strconv"
"sync"
"sync/atomic"
"time"
)
// default options for FailFilter
const (
DefaultMaxFails = 1
DefaultFailTimeout = 30 * time.Second
)
var (
defaultSelector Selector = NewSelector(nil)
)
type Filter interface {
Filter(nodes ...*Node) []*Node
String() string
}
type Strategy interface {
Apply(nodes ...*Node) *Node
String() string
}
type Selector interface {
Select(nodes ...*Node) *Node
}
@ -39,3 +44,115 @@ func (s *selector) Select(nodes ...*Node) *Node {
}
return s.strategy.Apply(nodes...)
}
type Strategy interface {
Apply(nodes ...*Node) *Node
}
// RoundStrategy is a strategy for node selector.
// The node will be selected by round-robin algorithm.
type RoundRobinStrategy struct {
counter uint64
}
func (s *RoundRobinStrategy) Apply(nodes ...*Node) *Node {
if len(nodes) == 0 {
return nil
}
n := atomic.AddUint64(&s.counter, 1) - 1
return nodes[int(n%uint64(len(nodes)))]
}
// RandomStrategy is a strategy for node selector.
// The node will be selected randomly.
type RandomStrategy struct {
Seed int64
rand *rand.Rand
once sync.Once
mux sync.Mutex
}
func (s *RandomStrategy) Apply(nodes ...*Node) *Node {
s.once.Do(func() {
seed := s.Seed
if seed == 0 {
seed = time.Now().UnixNano()
}
s.rand = rand.New(rand.NewSource(seed))
})
if len(nodes) == 0 {
return nil
}
s.mux.Lock()
defer s.mux.Unlock()
r := s.rand.Int()
return nodes[r%len(nodes)]
}
// FIFOStrategy is a strategy for node selector.
// The node will be selected from first to last,
// and will stick to the selected node until it is failed.
type FIFOStrategy struct{}
// Apply applies the fifo strategy for the nodes.
func (s *FIFOStrategy) Apply(nodes ...*Node) *Node {
if len(nodes) == 0 {
return nil
}
return nodes[0]
}
type Filter interface {
Filter(nodes ...*Node) []*Node
}
// FailFilter filters the dead node.
// A node is marked as dead if its failed count is greater than MaxFails.
type FailFilter struct {
MaxFails int
FailTimeout time.Duration
}
// Filter filters dead nodes.
func (f *FailFilter) Filter(nodes ...*Node) []*Node {
maxFails := f.MaxFails
if maxFails == 0 {
maxFails = DefaultMaxFails
}
failTimeout := f.FailTimeout
if failTimeout == 0 {
failTimeout = DefaultFailTimeout
}
if len(nodes) <= 1 || maxFails < 0 {
return nodes
}
var nl []*Node
for _, node := range nodes {
if node.marker.FailCount() < uint32(maxFails) ||
time.Since(time.Unix(node.marker.FailTime(), 0)) >= failTimeout {
nl = append(nl, node)
}
}
return nl
}
// InvalidFilter filters the invalid node.
// A node is invalid if its port is invalid (negative or zero value).
type InvalidFilter struct{}
// Filter filters invalid nodes.
func (f *InvalidFilter) Filter(nodes ...*Node) []*Node {
var nl []*Node
for _, node := range nodes {
_, sport, _ := net.SplitHostPort(node.Addr())
if port, _ := strconv.Atoi(sport); port > 0 {
nl = append(nl, node)
}
}
return nl
}

View File

@ -5,9 +5,9 @@ import (
"context"
"encoding/base64"
"fmt"
"log"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strings"
@ -51,11 +51,15 @@ func (c *Connector) Connect(ctx context.Context, conn net.Conn, network, address
Header: make(http.Header),
}
if c.md.UserAgent != "" {
log.Println(c.md.UserAgent)
req.Header.Set("User-Agent", c.md.UserAgent)
}
req.Header.Set("Proxy-Connection", "keep-alive")
c.logger = c.logger.WithFields(map[string]interface{}{
"src": conn.LocalAddr().String(),
"dst": conn.RemoteAddr().String(),
})
if user := c.md.User; user != nil {
u := user.Username()
p, _ := user.Password()
@ -63,6 +67,11 @@ func (c *Connector) Connect(ctx context.Context, conn net.Conn, network, address
"Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p)))
}
if c.logger.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpRequest(req, false)
c.logger.Debug(string(dump))
}
req = req.WithContext(ctx)
if err := req.Write(conn); err != nil {
return nil, err
@ -74,6 +83,11 @@ func (c *Connector) Connect(ctx context.Context, conn net.Conn, network, address
}
defer resp.Body.Close()
if c.logger.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpResponse(resp, false)
c.logger.Debug(string(dump))
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%s", resp.Status)
}

View File

@ -42,11 +42,33 @@ func (d *Dialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOptio
dial := options.DialFunc
if dial != nil {
return dial(ctx, addr)
conn, err := dial(ctx, addr)
if err != nil {
d.logger.Error(err)
} else {
if d.logger.IsLevelEnabled(logger.DebugLevel) {
d.logger.WithFields(map[string]interface{}{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debug("dial with dial func")
}
}
return conn, err
}
var netd net.Dialer
return netd.DialContext(ctx, "tcp", addr)
conn, err := netd.DialContext(ctx, "tcp", addr)
if err != nil {
d.logger.Error(err)
} else {
if d.logger.IsLevelEnabled(logger.DebugLevel) {
d.logger.WithFields(map[string]interface{}{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debug("dial direct")
}
}
return conn, err
}
func (d *Dialer) parseMetadata(md md.Metadata) (err error) {

View File

@ -76,6 +76,7 @@ func (h *Handler) parseMetadata(md md.Metadata) error {
}
}
}
h.md.retryCount = md.GetInt(retryCount)
return nil
}
@ -260,10 +261,10 @@ func (h *Handler) dial(ctx context.Context, addr string) (conn net.Conn, err err
*/
conn, err = route.Dial(ctx, "tcp", addr)
if err != nil {
h.logger.Warn("retry:", err)
continue
if err == nil {
break
}
h.logger.Errorf("route(retry=%d): %s", i, err)
}
return

View File

@ -8,6 +8,7 @@ const (
authsKey = "auths"
probeResistKey = "probeResist"
knockKey = "knock"
retryCount = "retry"
)
type metadata struct {

View File

@ -2,6 +2,7 @@ package config
import (
"io"
"time"
"github.com/spf13/viper"
)
@ -24,8 +25,9 @@ type LogConfig struct {
}
type LoadbalancingConfig struct {
Strategy string
Filters []string
Strategy string
MaxFails int
FailTimeout time.Duration
}
type ListenerConfig struct {
@ -49,6 +51,7 @@ type ConnectorConfig struct {
}
type ServiceConfig struct {
Name string
URL string
Addr string
Listener *ListenerConfig