add Route and Selector interfaces

This commit is contained in:
ginuerzh
2022-09-02 14:59:34 +08:00
parent 201edf2de5
commit 2835a5d44a
7 changed files with 224 additions and 319 deletions

View File

@ -2,21 +2,22 @@ package chain
import (
"github.com/go-gost/core/metadata"
"github.com/go-gost/core/selector"
)
type Chainer interface {
Route(network, address string) *Route
Route(network, address string) Route
}
type SelectableChainer interface {
Chainer
Selectable
selector.Selectable
}
type Chain struct {
name string
groups []*NodeGroup
marker Marker
marker selector.Marker
metadata metadata.Metadata
}
@ -24,7 +25,7 @@ func NewChain(name string, groups ...*NodeGroup) *Chain {
return &Chain{
name: name,
groups: groups,
marker: NewFailMarker(),
marker: selector.NewFailMarker(),
}
}
@ -40,18 +41,16 @@ func (c *Chain) Metadata() metadata.Metadata {
return c.metadata
}
func (c *Chain) Marker() Marker {
func (c *Chain) Marker() selector.Marker {
return c.marker
}
func (c *Chain) Route(network, address string) (r *Route) {
func (c *Chain) Route(network, address string) Route {
if c == nil || len(c.groups) == 0 {
return
return nil
}
r = &Route{
chain: c,
}
rt := newRoute().WithChain(c)
for _, group := range c.groups {
// hop level bypass test
if group.bypass != nil && group.bypass.Contains(address) {
@ -60,27 +59,37 @@ func (c *Chain) Route(network, address string) (r *Route) {
node := group.FilterAddr(address).Next()
if node == nil {
return
return rt
}
if node.transport.Multiplex() {
tr := node.transport.Copy().
WithRoute(r)
tr := node.transport.
Copy().
WithRoute(rt)
node = node.Copy()
node.transport = tr
r = &Route{}
rt = newRoute()
}
r.addNode(node)
rt.addNode(node)
}
return r
return rt
}
type ChainGroup struct {
Chains []SelectableChainer
Selector Selector[SelectableChainer]
chains []SelectableChainer
selector selector.Selector[SelectableChainer]
}
func (p *ChainGroup) Route(network, address string) *Route {
func NewChainGroup(chains ...SelectableChainer) *ChainGroup {
return &ChainGroup{chains: chains}
}
func (p *ChainGroup) WithSelector(s selector.Selector[SelectableChainer]) *ChainGroup {
p.selector = s
return p
}
func (p *ChainGroup) Route(network, address string) Route {
if chain := p.next(); chain != nil {
return chain.Route(network, address)
}
@ -88,13 +97,9 @@ func (p *ChainGroup) Route(network, address string) *Route {
}
func (p *ChainGroup) next() Chainer {
if p == nil || len(p.Chains) == 0 {
if p == nil || len(p.chains) == 0 {
return nil
}
s := p.Selector
if s == nil {
s = DefaultChainSelector
}
return s.Select(p.Chains...)
return p.selector.Select(p.chains...)
}

View File

@ -5,6 +5,7 @@ import (
"github.com/go-gost/core/hosts"
"github.com/go-gost/core/metadata"
"github.com/go-gost/core/resolver"
"github.com/go-gost/core/selector"
)
type Node struct {
@ -14,7 +15,7 @@ type Node struct {
bypass bypass.Bypass
resolver resolver.Resolver
hostMapper hosts.HostMapper
marker Marker
marker selector.Marker
metadata metadata.Metadata
}
@ -22,7 +23,7 @@ func NewNode(name, addr string) *Node {
return &Node{
Name: name,
Addr: addr,
marker: NewFailMarker(),
marker: selector.NewFailMarker(),
}
}
@ -51,7 +52,7 @@ func (node *Node) WithMetadata(md metadata.Metadata) *Node {
return node
}
func (node *Node) Marker() Marker {
func (node *Node) Marker() selector.Marker {
return node.marker
}
@ -67,7 +68,7 @@ func (node *Node) Copy() *Node {
type NodeGroup struct {
nodes []*Node
selector Selector[*Node]
selector selector.Selector[*Node]
bypass bypass.Bypass
}
@ -85,7 +86,7 @@ func (g *NodeGroup) Nodes() []*Node {
return g.nodes
}
func (g *NodeGroup) WithSelector(selector Selector[*Node]) *NodeGroup {
func (g *NodeGroup) WithSelector(selector selector.Selector[*Node]) *NodeGroup {
g.selector = selector
return g
}
@ -114,10 +115,5 @@ func (g *NodeGroup) Next() *Node {
return nil
}
s := g.selector
if s == nil {
s = DefaultNodeSelector
}
return s.Select(g.nodes...)
return g.selector.Select(g.nodes...)
}

View File

@ -18,17 +18,32 @@ var (
ErrEmptyRoute = errors.New("empty route")
)
type Route struct {
chain *Chain
nodes []*Node
logger logger.Logger
type Route interface {
Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error)
Bind(ctx context.Context, network, address string, opts ...BindOption) (net.Listener, error)
Len() int
Path() []*Node
}
func (r *Route) addNode(node *Node) {
type route struct {
chain *Chain
nodes []*Node
}
func newRoute() *route {
return &route{}
}
func (r *route) addNode(node *Node) {
r.nodes = append(r.nodes, node)
}
func (r *Route) Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error) {
func (r *route) WithChain(chain *Chain) *route {
r.chain = chain
return r
}
func (r *route) Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error) {
var options DialOptions
for _, opt := range opts {
opt(&options)
@ -43,13 +58,13 @@ func (r *Route) Dial(ctx context.Context, network, address string, opts ...DialO
netd.Mark = options.SockOpts.Mark
}
if r != nil {
netd.Logger = r.logger
netd.Logger = options.Logger
}
return netd.Dial(ctx, network, address)
}
conn, err := r.connect(ctx)
conn, err := r.connect(ctx, options.Logger)
if err != nil {
return nil, err
}
@ -62,17 +77,29 @@ func (r *Route) Dial(ctx context.Context, network, address string, opts ...DialO
return cc, nil
}
func (r *Route) Bind(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) {
if r.Len() == 0 {
return r.bindLocal(ctx, network, address, opts...)
func (r *route) Bind(ctx context.Context, network, address string, opts ...BindOption) (net.Listener, error) {
var options BindOptions
for _, opt := range opts {
opt(&options)
}
conn, err := r.connect(ctx)
if r.Len() == 0 {
return r.bindLocal(ctx, network, address, &options)
}
conn, err := r.connect(ctx, options.Logger)
if err != nil {
return nil, err
}
ln, err := r.GetNode(r.Len()-1).transport.Bind(ctx, conn, network, address, opts...)
ln, err := r.GetNode(r.Len()-1).transport.Bind(ctx,
conn, network, address,
connector.BacklogBindOption(options.Backlog),
connector.MuxBindOption(options.Mux),
connector.UDPConnTTLBindOption(options.UDPConnTTL),
connector.UDPDataBufferSizeBindOption(options.UDPDataBufferSize),
connector.UDPDataQueueSizeBindOption(options.UDPDataQueueSize),
)
if err != nil {
conn.Close()
return nil, err
@ -81,7 +108,7 @@ func (r *Route) Bind(ctx context.Context, network, address string, opts ...conne
return ln, nil
}
func (r *Route) connect(ctx context.Context) (conn net.Conn, err error) {
func (r *route) connect(ctx context.Context, logger logger.Logger) (conn net.Conn, err error) {
if r.Len() == 0 {
return nil, ErrEmptyRoute
}
@ -109,7 +136,7 @@ func (r *Route) connect(ctx context.Context) (conn net.Conn, err error) {
}
}()
addr, err := resolve(ctx, network, node.Addr, node.resolver, node.hostMapper, r.logger)
addr, err := resolve(ctx, network, node.Addr, node.resolver, node.hostMapper, logger)
marker := node.Marker()
if err != nil {
if marker != nil {
@ -149,7 +176,7 @@ func (r *Route) connect(ctx context.Context) (conn net.Conn, err error) {
preNode := node
for _, node := range r.nodes[1:] {
marker := node.Marker()
addr, err = resolve(ctx, network, node.Addr, node.resolver, node.hostMapper, r.logger)
addr, err = resolve(ctx, network, node.Addr, node.resolver, node.hostMapper, logger)
if err != nil {
cn.Close()
if marker != nil {
@ -185,21 +212,21 @@ func (r *Route) connect(ctx context.Context) (conn net.Conn, err error) {
return
}
func (r *Route) Len() int {
func (r *route) Len() int {
if r == nil {
return 0
}
return len(r.nodes)
}
func (r *Route) GetNode(index int) *Node {
func (r *route) GetNode(index int) *Node {
if r.Len() == 0 || index < 0 || index >= len(r.nodes) {
return nil
}
return r.nodes[index]
}
func (r *Route) Path() (path []*Node) {
func (r *route) Path() (path []*Node) {
if r == nil || len(r.nodes) == 0 {
return nil
}
@ -213,12 +240,7 @@ func (r *Route) Path() (path []*Node) {
return
}
func (r *Route) bindLocal(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) {
options := connector.BindOptions{}
for _, opt := range opts {
opt(&options)
}
func (r *route) bindLocal(ctx context.Context, network, address string, opts *BindOptions) (net.Listener, error) {
switch network {
case "tcp", "tcp4", "tcp6":
addr, err := net.ResolveTCPAddr(network, address)
@ -240,10 +262,10 @@ func (r *Route) bindLocal(ctx context.Context, network, address string, opts ...
"address": address,
})
ln := udp.NewListener(conn, &udp.ListenConfig{
Backlog: options.Backlog,
ReadQueueSize: options.UDPDataQueueSize,
ReadBufferSize: options.UDPDataBufferSize,
TTL: options.UDPConnTTL,
Backlog: opts.Backlog,
ReadQueueSize: opts.UDPDataQueueSize,
ReadBufferSize: opts.UDPDataBufferSize,
TTL: opts.UDPConnTTL,
KeepAlive: true,
Logger: logger,
})
@ -258,6 +280,7 @@ type DialOptions struct {
Timeout time.Duration
Interface string
SockOpts *SockOpts
Logger logger.Logger
}
type DialOption func(opts *DialOptions)
@ -279,3 +302,56 @@ func SockOptsDialOption(so *SockOpts) DialOption {
opts.SockOpts = so
}
}
func LoggerDialOption(logger logger.Logger) DialOption {
return func(opts *DialOptions) {
opts.Logger = logger
}
}
type BindOptions struct {
Mux bool
Backlog int
UDPDataQueueSize int
UDPDataBufferSize int
UDPConnTTL time.Duration
Logger logger.Logger
}
type BindOption func(opts *BindOptions)
func MuxBindOption(mux bool) BindOption {
return func(opts *BindOptions) {
opts.Mux = mux
}
}
func BacklogBindOption(backlog int) BindOption {
return func(opts *BindOptions) {
opts.Backlog = backlog
}
}
func UDPDataQueueSizeBindOption(size int) BindOption {
return func(opts *BindOptions) {
opts.UDPDataQueueSize = size
}
}
func UDPDataBufferSizeBindOption(size int) BindOption {
return func(opts *BindOptions) {
opts.UDPDataBufferSize = size
}
}
func UDPConnTTLBindOption(ttl time.Duration) BindOption {
return func(opts *BindOptions) {
opts.UDPConnTTL = ttl
}
}
func LoggerBindOption(logger logger.Logger) BindOption {
return func(opts *BindOptions) {
opts.Logger = logger
}
}

View File

@ -7,7 +7,6 @@ import (
"net"
"time"
"github.com/go-gost/core/connector"
"github.com/go-gost/core/hosts"
"github.com/go-gost/core/logger"
"github.com/go-gost/core/recorder"
@ -128,7 +127,7 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co
r.logger.Debugf("dial %s/%s", address, network)
for i := 0; i < count; i++ {
var route *Route
var route Route
if r.chain != nil {
route = r.chain.Route(network, address)
}
@ -149,12 +148,12 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co
}
if route == nil {
route = &Route{}
route = newRoute()
}
route.logger = r.logger
conn, err = route.Dial(ctx, network, address,
InterfaceDialOption(r.ifceName),
SockOptsDialOption(r.sockOpts),
LoggerDialOption(r.logger),
)
if err == nil {
break
@ -165,7 +164,7 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co
return
}
func (r *Router) Bind(ctx context.Context, network, address string, opts ...connector.BindOption) (ln net.Listener, err error) {
func (r *Router) Bind(ctx context.Context, network, address string, opts ...BindOption) (ln net.Listener, err error) {
count := r.retries + 1
if count <= 0 {
count = 1
@ -173,7 +172,7 @@ func (r *Router) Bind(ctx context.Context, network, address string, opts ...conn
r.logger.Debugf("bind on %s/%s", address, network)
for i := 0; i < count; i++ {
var route *Route
var route Route
if r.chain != nil {
route = r.chain.Route(network, address)
if route.Len() == 0 {

View File

@ -1,245 +0,0 @@
package chain
import (
"math/rand"
"sync"
"sync/atomic"
"time"
"github.com/go-gost/core/metadata"
mdutil "github.com/go-gost/core/metadata/util"
)
// default options for FailFilter
const (
DefaultFailTimeout = 30 * time.Second
)
var (
DefaultNodeSelector = NewSelector(
RoundRobinStrategy[*Node](),
// FailFilter[*Node](1, DefaultFailTimeout),
)
DefaultChainSelector = NewSelector(
RoundRobinStrategy[SelectableChainer](),
// FailFilter[SelectableChainer](1, DefaultFailTimeout),
)
)
type Selectable interface {
Marker() Marker
Metadata() metadata.Metadata
}
type Selector[T any] interface {
Select(...T) T
}
type selector[T Selectable] struct {
strategy Strategy[T]
filters []Filter[T]
}
func NewSelector[T Selectable](strategy Strategy[T], filters ...Filter[T]) Selector[T] {
return &selector[T]{
filters: filters,
strategy: strategy,
}
}
func (s *selector[T]) Select(vs ...T) (v T) {
for _, filter := range s.filters {
vs = filter.Filter(vs...)
}
if len(vs) == 0 {
return
}
return s.strategy.Apply(vs...)
}
type Strategy[T Selectable] interface {
Apply(...T) T
}
type roundRobinStrategy[T Selectable] struct {
counter uint64
}
// RoundRobinStrategy is a strategy for node selector.
// The node will be selected by round-robin algorithm.
func RoundRobinStrategy[T Selectable]() Strategy[T] {
return &roundRobinStrategy[T]{}
}
func (s *roundRobinStrategy[T]) Apply(vs ...T) (v T) {
if len(vs) == 0 {
return
}
n := atomic.AddUint64(&s.counter, 1) - 1
return vs[int(n%uint64(len(vs)))]
}
type randomStrategy[T Selectable] struct {
rand *rand.Rand
mux sync.Mutex
}
// RandomStrategy is a strategy for node selector.
// The node will be selected randomly.
func RandomStrategy[T Selectable]() Strategy[T] {
return &randomStrategy[T]{
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}
func (s *randomStrategy[T]) Apply(vs ...T) (v T) {
if len(vs) == 0 {
return
}
s.mux.Lock()
defer s.mux.Unlock()
r := s.rand.Int()
return vs[r%len(vs)]
}
type fifoStrategy[T Selectable] struct{}
// FIFOStrategy is a strategy for node selector.
// The node will be selected from first to last,
// and will stick to the selected node until it is failed.
func FIFOStrategy[T Selectable]() Strategy[T] {
return &fifoStrategy[T]{}
}
// Apply applies the fifo strategy for the nodes.
func (s *fifoStrategy[T]) Apply(vs ...T) (v T) {
if len(vs) == 0 {
return
}
return vs[0]
}
type Filter[T Selectable] interface {
Filter(...T) []T
}
type failFilter[T Selectable] struct {
maxFails int
failTimeout time.Duration
}
// FailFilter filters the dead objects.
// An object is marked as dead if its failed count is greater than MaxFails.
func FailFilter[T Selectable](maxFails int, timeout time.Duration) Filter[T] {
return &failFilter[T]{
maxFails: maxFails,
failTimeout: timeout,
}
}
// Filter filters dead objects.
func (f *failFilter[T]) Filter(vs ...T) []T {
maxFails := f.maxFails
failTimeout := f.failTimeout
if failTimeout == 0 {
failTimeout = DefaultFailTimeout
}
if len(vs) <= 1 || maxFails <= 0 {
return vs
}
var l []T
for _, v := range vs {
if marker := v.Marker(); marker != nil {
if marker.Count() < int64(maxFails) ||
time.Since(marker.Time()) >= failTimeout {
l = append(l, v)
}
} else {
l = append(l, v)
}
}
return l
}
type backupFilter[T Selectable] struct{}
// BackupFilter filters the backup objects.
// An object is marked as backup if its metadata has backup flag.
func BackupFilter[T Selectable]() Filter[T] {
return &backupFilter[T]{}
}
// Filter filters backup objects.
func (f *backupFilter[T]) Filter(vs ...T) []T {
if len(vs) <= 1 {
return vs
}
var l, backups []T
for _, v := range vs {
if mdutil.GetBool(v.Metadata(), "backup") {
backups = append(backups, v)
} else {
l = append(l, v)
}
}
if len(l) == 0 {
return backups
}
return l
}
type Marker interface {
Time() time.Time
Count() int64
Mark()
Reset()
}
type failMarker struct {
failTime int64
failCount int64
}
func NewFailMarker() Marker {
return &failMarker{}
}
func (m *failMarker) Time() time.Time {
if m == nil {
return time.Time{}
}
return time.Unix(atomic.LoadInt64(&m.failTime), 0)
}
func (m *failMarker) Count() int64 {
if m == nil {
return 0
}
return atomic.LoadInt64(&m.failCount)
}
func (m *failMarker) Mark() {
if m == nil {
return
}
atomic.AddInt64(&m.failCount, 1)
atomic.StoreInt64(&m.failTime, time.Now().Unix())
}
func (m *failMarker) Reset() {
if m == nil {
return
}
atomic.StoreInt64(&m.failCount, 0)
}

View File

@ -14,7 +14,7 @@ type Transport struct {
addr string
ifceName string
sockOpts *SockOpts
route *Route
route Route
dialer dialer.Dialer
connector connector.Connector
}
@ -53,7 +53,7 @@ func (tr *Transport) Dial(ctx context.Context, addr string) (net.Conn, error) {
if tr.sockOpts != nil {
netd.Mark = tr.sockOpts.Mark
}
if tr.route.Len() > 0 {
if tr.route != nil && tr.route.Len() > 0 {
netd.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) {
return tr.route.Dial(ctx, network, addr)
}
@ -98,7 +98,7 @@ func (tr *Transport) Multiplex() bool {
return false
}
func (tr *Transport) WithRoute(r *Route) *Transport {
func (tr *Transport) WithRoute(r Route) *Transport {
tr.route = r
return tr
}