add Route and Selector interfaces
This commit is contained in:
@ -2,21 +2,22 @@ package chain
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/go-gost/core/metadata"
|
"github.com/go-gost/core/metadata"
|
||||||
|
"github.com/go-gost/core/selector"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Chainer interface {
|
type Chainer interface {
|
||||||
Route(network, address string) *Route
|
Route(network, address string) Route
|
||||||
}
|
}
|
||||||
|
|
||||||
type SelectableChainer interface {
|
type SelectableChainer interface {
|
||||||
Chainer
|
Chainer
|
||||||
Selectable
|
selector.Selectable
|
||||||
}
|
}
|
||||||
|
|
||||||
type Chain struct {
|
type Chain struct {
|
||||||
name string
|
name string
|
||||||
groups []*NodeGroup
|
groups []*NodeGroup
|
||||||
marker Marker
|
marker selector.Marker
|
||||||
metadata metadata.Metadata
|
metadata metadata.Metadata
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -24,7 +25,7 @@ func NewChain(name string, groups ...*NodeGroup) *Chain {
|
|||||||
return &Chain{
|
return &Chain{
|
||||||
name: name,
|
name: name,
|
||||||
groups: groups,
|
groups: groups,
|
||||||
marker: NewFailMarker(),
|
marker: selector.NewFailMarker(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -40,18 +41,16 @@ func (c *Chain) Metadata() metadata.Metadata {
|
|||||||
return c.metadata
|
return c.metadata
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Chain) Marker() Marker {
|
func (c *Chain) Marker() selector.Marker {
|
||||||
return c.marker
|
return c.marker
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Chain) Route(network, address string) (r *Route) {
|
func (c *Chain) Route(network, address string) Route {
|
||||||
if c == nil || len(c.groups) == 0 {
|
if c == nil || len(c.groups) == 0 {
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
r = &Route{
|
rt := newRoute().WithChain(c)
|
||||||
chain: c,
|
|
||||||
}
|
|
||||||
for _, group := range c.groups {
|
for _, group := range c.groups {
|
||||||
// hop level bypass test
|
// hop level bypass test
|
||||||
if group.bypass != nil && group.bypass.Contains(address) {
|
if group.bypass != nil && group.bypass.Contains(address) {
|
||||||
@ -60,27 +59,37 @@ func (c *Chain) Route(network, address string) (r *Route) {
|
|||||||
|
|
||||||
node := group.FilterAddr(address).Next()
|
node := group.FilterAddr(address).Next()
|
||||||
if node == nil {
|
if node == nil {
|
||||||
return
|
return rt
|
||||||
}
|
}
|
||||||
if node.transport.Multiplex() {
|
if node.transport.Multiplex() {
|
||||||
tr := node.transport.Copy().
|
tr := node.transport.
|
||||||
WithRoute(r)
|
Copy().
|
||||||
|
WithRoute(rt)
|
||||||
node = node.Copy()
|
node = node.Copy()
|
||||||
node.transport = tr
|
node.transport = tr
|
||||||
r = &Route{}
|
rt = newRoute()
|
||||||
}
|
}
|
||||||
|
|
||||||
r.addNode(node)
|
rt.addNode(node)
|
||||||
}
|
}
|
||||||
return r
|
return rt
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChainGroup struct {
|
type ChainGroup struct {
|
||||||
Chains []SelectableChainer
|
chains []SelectableChainer
|
||||||
Selector Selector[SelectableChainer]
|
selector selector.Selector[SelectableChainer]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ChainGroup) Route(network, address string) *Route {
|
func NewChainGroup(chains ...SelectableChainer) *ChainGroup {
|
||||||
|
return &ChainGroup{chains: chains}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ChainGroup) WithSelector(s selector.Selector[SelectableChainer]) *ChainGroup {
|
||||||
|
p.selector = s
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ChainGroup) Route(network, address string) Route {
|
||||||
if chain := p.next(); chain != nil {
|
if chain := p.next(); chain != nil {
|
||||||
return chain.Route(network, address)
|
return chain.Route(network, address)
|
||||||
}
|
}
|
||||||
@ -88,13 +97,9 @@ func (p *ChainGroup) Route(network, address string) *Route {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *ChainGroup) next() Chainer {
|
func (p *ChainGroup) next() Chainer {
|
||||||
if p == nil || len(p.Chains) == 0 {
|
if p == nil || len(p.chains) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
s := p.Selector
|
return p.selector.Select(p.chains...)
|
||||||
if s == nil {
|
|
||||||
s = DefaultChainSelector
|
|
||||||
}
|
|
||||||
return s.Select(p.Chains...)
|
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"github.com/go-gost/core/hosts"
|
"github.com/go-gost/core/hosts"
|
||||||
"github.com/go-gost/core/metadata"
|
"github.com/go-gost/core/metadata"
|
||||||
"github.com/go-gost/core/resolver"
|
"github.com/go-gost/core/resolver"
|
||||||
|
"github.com/go-gost/core/selector"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Node struct {
|
type Node struct {
|
||||||
@ -14,7 +15,7 @@ type Node struct {
|
|||||||
bypass bypass.Bypass
|
bypass bypass.Bypass
|
||||||
resolver resolver.Resolver
|
resolver resolver.Resolver
|
||||||
hostMapper hosts.HostMapper
|
hostMapper hosts.HostMapper
|
||||||
marker Marker
|
marker selector.Marker
|
||||||
metadata metadata.Metadata
|
metadata metadata.Metadata
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -22,7 +23,7 @@ func NewNode(name, addr string) *Node {
|
|||||||
return &Node{
|
return &Node{
|
||||||
Name: name,
|
Name: name,
|
||||||
Addr: addr,
|
Addr: addr,
|
||||||
marker: NewFailMarker(),
|
marker: selector.NewFailMarker(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -51,7 +52,7 @@ func (node *Node) WithMetadata(md metadata.Metadata) *Node {
|
|||||||
return node
|
return node
|
||||||
}
|
}
|
||||||
|
|
||||||
func (node *Node) Marker() Marker {
|
func (node *Node) Marker() selector.Marker {
|
||||||
return node.marker
|
return node.marker
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -67,7 +68,7 @@ func (node *Node) Copy() *Node {
|
|||||||
|
|
||||||
type NodeGroup struct {
|
type NodeGroup struct {
|
||||||
nodes []*Node
|
nodes []*Node
|
||||||
selector Selector[*Node]
|
selector selector.Selector[*Node]
|
||||||
bypass bypass.Bypass
|
bypass bypass.Bypass
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -85,7 +86,7 @@ func (g *NodeGroup) Nodes() []*Node {
|
|||||||
return g.nodes
|
return g.nodes
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *NodeGroup) WithSelector(selector Selector[*Node]) *NodeGroup {
|
func (g *NodeGroup) WithSelector(selector selector.Selector[*Node]) *NodeGroup {
|
||||||
g.selector = selector
|
g.selector = selector
|
||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
@ -114,10 +115,5 @@ func (g *NodeGroup) Next() *Node {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
s := g.selector
|
return g.selector.Select(g.nodes...)
|
||||||
if s == nil {
|
|
||||||
s = DefaultNodeSelector
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.Select(g.nodes...)
|
|
||||||
}
|
}
|
||||||
|
134
chain/route.go
134
chain/route.go
@ -18,17 +18,32 @@ var (
|
|||||||
ErrEmptyRoute = errors.New("empty route")
|
ErrEmptyRoute = errors.New("empty route")
|
||||||
)
|
)
|
||||||
|
|
||||||
type Route struct {
|
type Route interface {
|
||||||
chain *Chain
|
Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error)
|
||||||
nodes []*Node
|
Bind(ctx context.Context, network, address string, opts ...BindOption) (net.Listener, error)
|
||||||
logger logger.Logger
|
Len() int
|
||||||
|
Path() []*Node
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Route) addNode(node *Node) {
|
type route struct {
|
||||||
|
chain *Chain
|
||||||
|
nodes []*Node
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRoute() *route {
|
||||||
|
return &route{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *route) addNode(node *Node) {
|
||||||
r.nodes = append(r.nodes, node)
|
r.nodes = append(r.nodes, node)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Route) Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error) {
|
func (r *route) WithChain(chain *Chain) *route {
|
||||||
|
r.chain = chain
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *route) Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error) {
|
||||||
var options DialOptions
|
var options DialOptions
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt(&options)
|
opt(&options)
|
||||||
@ -43,13 +58,13 @@ func (r *Route) Dial(ctx context.Context, network, address string, opts ...DialO
|
|||||||
netd.Mark = options.SockOpts.Mark
|
netd.Mark = options.SockOpts.Mark
|
||||||
}
|
}
|
||||||
if r != nil {
|
if r != nil {
|
||||||
netd.Logger = r.logger
|
netd.Logger = options.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
return netd.Dial(ctx, network, address)
|
return netd.Dial(ctx, network, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := r.connect(ctx)
|
conn, err := r.connect(ctx, options.Logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -62,17 +77,29 @@ func (r *Route) Dial(ctx context.Context, network, address string, opts ...DialO
|
|||||||
return cc, nil
|
return cc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Route) Bind(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) {
|
func (r *route) Bind(ctx context.Context, network, address string, opts ...BindOption) (net.Listener, error) {
|
||||||
if r.Len() == 0 {
|
var options BindOptions
|
||||||
return r.bindLocal(ctx, network, address, opts...)
|
for _, opt := range opts {
|
||||||
|
opt(&options)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := r.connect(ctx)
|
if r.Len() == 0 {
|
||||||
|
return r.bindLocal(ctx, network, address, &options)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := r.connect(ctx, options.Logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ln, err := r.GetNode(r.Len()-1).transport.Bind(ctx, conn, network, address, opts...)
|
ln, err := r.GetNode(r.Len()-1).transport.Bind(ctx,
|
||||||
|
conn, network, address,
|
||||||
|
connector.BacklogBindOption(options.Backlog),
|
||||||
|
connector.MuxBindOption(options.Mux),
|
||||||
|
connector.UDPConnTTLBindOption(options.UDPConnTTL),
|
||||||
|
connector.UDPDataBufferSizeBindOption(options.UDPDataBufferSize),
|
||||||
|
connector.UDPDataQueueSizeBindOption(options.UDPDataQueueSize),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -81,7 +108,7 @@ func (r *Route) Bind(ctx context.Context, network, address string, opts ...conne
|
|||||||
return ln, nil
|
return ln, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Route) connect(ctx context.Context) (conn net.Conn, err error) {
|
func (r *route) connect(ctx context.Context, logger logger.Logger) (conn net.Conn, err error) {
|
||||||
if r.Len() == 0 {
|
if r.Len() == 0 {
|
||||||
return nil, ErrEmptyRoute
|
return nil, ErrEmptyRoute
|
||||||
}
|
}
|
||||||
@ -109,7 +136,7 @@ func (r *Route) connect(ctx context.Context) (conn net.Conn, err error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
addr, err := resolve(ctx, network, node.Addr, node.resolver, node.hostMapper, r.logger)
|
addr, err := resolve(ctx, network, node.Addr, node.resolver, node.hostMapper, logger)
|
||||||
marker := node.Marker()
|
marker := node.Marker()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if marker != nil {
|
if marker != nil {
|
||||||
@ -149,7 +176,7 @@ func (r *Route) connect(ctx context.Context) (conn net.Conn, err error) {
|
|||||||
preNode := node
|
preNode := node
|
||||||
for _, node := range r.nodes[1:] {
|
for _, node := range r.nodes[1:] {
|
||||||
marker := node.Marker()
|
marker := node.Marker()
|
||||||
addr, err = resolve(ctx, network, node.Addr, node.resolver, node.hostMapper, r.logger)
|
addr, err = resolve(ctx, network, node.Addr, node.resolver, node.hostMapper, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cn.Close()
|
cn.Close()
|
||||||
if marker != nil {
|
if marker != nil {
|
||||||
@ -185,21 +212,21 @@ func (r *Route) connect(ctx context.Context) (conn net.Conn, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Route) Len() int {
|
func (r *route) Len() int {
|
||||||
if r == nil {
|
if r == nil {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
return len(r.nodes)
|
return len(r.nodes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Route) GetNode(index int) *Node {
|
func (r *route) GetNode(index int) *Node {
|
||||||
if r.Len() == 0 || index < 0 || index >= len(r.nodes) {
|
if r.Len() == 0 || index < 0 || index >= len(r.nodes) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return r.nodes[index]
|
return r.nodes[index]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Route) Path() (path []*Node) {
|
func (r *route) Path() (path []*Node) {
|
||||||
if r == nil || len(r.nodes) == 0 {
|
if r == nil || len(r.nodes) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -213,12 +240,7 @@ func (r *Route) Path() (path []*Node) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Route) bindLocal(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) {
|
func (r *route) bindLocal(ctx context.Context, network, address string, opts *BindOptions) (net.Listener, error) {
|
||||||
options := connector.BindOptions{}
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(&options)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch network {
|
switch network {
|
||||||
case "tcp", "tcp4", "tcp6":
|
case "tcp", "tcp4", "tcp6":
|
||||||
addr, err := net.ResolveTCPAddr(network, address)
|
addr, err := net.ResolveTCPAddr(network, address)
|
||||||
@ -240,10 +262,10 @@ func (r *Route) bindLocal(ctx context.Context, network, address string, opts ...
|
|||||||
"address": address,
|
"address": address,
|
||||||
})
|
})
|
||||||
ln := udp.NewListener(conn, &udp.ListenConfig{
|
ln := udp.NewListener(conn, &udp.ListenConfig{
|
||||||
Backlog: options.Backlog,
|
Backlog: opts.Backlog,
|
||||||
ReadQueueSize: options.UDPDataQueueSize,
|
ReadQueueSize: opts.UDPDataQueueSize,
|
||||||
ReadBufferSize: options.UDPDataBufferSize,
|
ReadBufferSize: opts.UDPDataBufferSize,
|
||||||
TTL: options.UDPConnTTL,
|
TTL: opts.UDPConnTTL,
|
||||||
KeepAlive: true,
|
KeepAlive: true,
|
||||||
Logger: logger,
|
Logger: logger,
|
||||||
})
|
})
|
||||||
@ -258,6 +280,7 @@ type DialOptions struct {
|
|||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
Interface string
|
Interface string
|
||||||
SockOpts *SockOpts
|
SockOpts *SockOpts
|
||||||
|
Logger logger.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type DialOption func(opts *DialOptions)
|
type DialOption func(opts *DialOptions)
|
||||||
@ -279,3 +302,56 @@ func SockOptsDialOption(so *SockOpts) DialOption {
|
|||||||
opts.SockOpts = so
|
opts.SockOpts = so
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func LoggerDialOption(logger logger.Logger) DialOption {
|
||||||
|
return func(opts *DialOptions) {
|
||||||
|
opts.Logger = logger
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type BindOptions struct {
|
||||||
|
Mux bool
|
||||||
|
Backlog int
|
||||||
|
UDPDataQueueSize int
|
||||||
|
UDPDataBufferSize int
|
||||||
|
UDPConnTTL time.Duration
|
||||||
|
Logger logger.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
type BindOption func(opts *BindOptions)
|
||||||
|
|
||||||
|
func MuxBindOption(mux bool) BindOption {
|
||||||
|
return func(opts *BindOptions) {
|
||||||
|
opts.Mux = mux
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BacklogBindOption(backlog int) BindOption {
|
||||||
|
return func(opts *BindOptions) {
|
||||||
|
opts.Backlog = backlog
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func UDPDataQueueSizeBindOption(size int) BindOption {
|
||||||
|
return func(opts *BindOptions) {
|
||||||
|
opts.UDPDataQueueSize = size
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func UDPDataBufferSizeBindOption(size int) BindOption {
|
||||||
|
return func(opts *BindOptions) {
|
||||||
|
opts.UDPDataBufferSize = size
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func UDPConnTTLBindOption(ttl time.Duration) BindOption {
|
||||||
|
return func(opts *BindOptions) {
|
||||||
|
opts.UDPConnTTL = ttl
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoggerBindOption(logger logger.Logger) BindOption {
|
||||||
|
return func(opts *BindOptions) {
|
||||||
|
opts.Logger = logger
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -7,7 +7,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-gost/core/connector"
|
|
||||||
"github.com/go-gost/core/hosts"
|
"github.com/go-gost/core/hosts"
|
||||||
"github.com/go-gost/core/logger"
|
"github.com/go-gost/core/logger"
|
||||||
"github.com/go-gost/core/recorder"
|
"github.com/go-gost/core/recorder"
|
||||||
@ -128,7 +127,7 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co
|
|||||||
r.logger.Debugf("dial %s/%s", address, network)
|
r.logger.Debugf("dial %s/%s", address, network)
|
||||||
|
|
||||||
for i := 0; i < count; i++ {
|
for i := 0; i < count; i++ {
|
||||||
var route *Route
|
var route Route
|
||||||
if r.chain != nil {
|
if r.chain != nil {
|
||||||
route = r.chain.Route(network, address)
|
route = r.chain.Route(network, address)
|
||||||
}
|
}
|
||||||
@ -149,12 +148,12 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co
|
|||||||
}
|
}
|
||||||
|
|
||||||
if route == nil {
|
if route == nil {
|
||||||
route = &Route{}
|
route = newRoute()
|
||||||
}
|
}
|
||||||
route.logger = r.logger
|
|
||||||
conn, err = route.Dial(ctx, network, address,
|
conn, err = route.Dial(ctx, network, address,
|
||||||
InterfaceDialOption(r.ifceName),
|
InterfaceDialOption(r.ifceName),
|
||||||
SockOptsDialOption(r.sockOpts),
|
SockOptsDialOption(r.sockOpts),
|
||||||
|
LoggerDialOption(r.logger),
|
||||||
)
|
)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
break
|
break
|
||||||
@ -165,7 +164,7 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Router) Bind(ctx context.Context, network, address string, opts ...connector.BindOption) (ln net.Listener, err error) {
|
func (r *Router) Bind(ctx context.Context, network, address string, opts ...BindOption) (ln net.Listener, err error) {
|
||||||
count := r.retries + 1
|
count := r.retries + 1
|
||||||
if count <= 0 {
|
if count <= 0 {
|
||||||
count = 1
|
count = 1
|
||||||
@ -173,7 +172,7 @@ func (r *Router) Bind(ctx context.Context, network, address string, opts ...conn
|
|||||||
r.logger.Debugf("bind on %s/%s", address, network)
|
r.logger.Debugf("bind on %s/%s", address, network)
|
||||||
|
|
||||||
for i := 0; i < count; i++ {
|
for i := 0; i < count; i++ {
|
||||||
var route *Route
|
var route Route
|
||||||
if r.chain != nil {
|
if r.chain != nil {
|
||||||
route = r.chain.Route(network, address)
|
route = r.chain.Route(network, address)
|
||||||
if route.Len() == 0 {
|
if route.Len() == 0 {
|
||||||
|
@ -1,245 +0,0 @@
|
|||||||
package chain
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math/rand"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/go-gost/core/metadata"
|
|
||||||
mdutil "github.com/go-gost/core/metadata/util"
|
|
||||||
)
|
|
||||||
|
|
||||||
// default options for FailFilter
|
|
||||||
const (
|
|
||||||
DefaultFailTimeout = 30 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
DefaultNodeSelector = NewSelector(
|
|
||||||
RoundRobinStrategy[*Node](),
|
|
||||||
// FailFilter[*Node](1, DefaultFailTimeout),
|
|
||||||
)
|
|
||||||
DefaultChainSelector = NewSelector(
|
|
||||||
RoundRobinStrategy[SelectableChainer](),
|
|
||||||
// FailFilter[SelectableChainer](1, DefaultFailTimeout),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
type Selectable interface {
|
|
||||||
Marker() Marker
|
|
||||||
Metadata() metadata.Metadata
|
|
||||||
}
|
|
||||||
|
|
||||||
type Selector[T any] interface {
|
|
||||||
Select(...T) T
|
|
||||||
}
|
|
||||||
|
|
||||||
type selector[T Selectable] struct {
|
|
||||||
strategy Strategy[T]
|
|
||||||
filters []Filter[T]
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSelector[T Selectable](strategy Strategy[T], filters ...Filter[T]) Selector[T] {
|
|
||||||
return &selector[T]{
|
|
||||||
filters: filters,
|
|
||||||
strategy: strategy,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *selector[T]) Select(vs ...T) (v T) {
|
|
||||||
for _, filter := range s.filters {
|
|
||||||
vs = filter.Filter(vs...)
|
|
||||||
}
|
|
||||||
if len(vs) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return s.strategy.Apply(vs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Strategy[T Selectable] interface {
|
|
||||||
Apply(...T) T
|
|
||||||
}
|
|
||||||
|
|
||||||
type roundRobinStrategy[T Selectable] struct {
|
|
||||||
counter uint64
|
|
||||||
}
|
|
||||||
|
|
||||||
// RoundRobinStrategy is a strategy for node selector.
|
|
||||||
// The node will be selected by round-robin algorithm.
|
|
||||||
func RoundRobinStrategy[T Selectable]() Strategy[T] {
|
|
||||||
return &roundRobinStrategy[T]{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *roundRobinStrategy[T]) Apply(vs ...T) (v T) {
|
|
||||||
if len(vs) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
n := atomic.AddUint64(&s.counter, 1) - 1
|
|
||||||
return vs[int(n%uint64(len(vs)))]
|
|
||||||
}
|
|
||||||
|
|
||||||
type randomStrategy[T Selectable] struct {
|
|
||||||
rand *rand.Rand
|
|
||||||
mux sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// RandomStrategy is a strategy for node selector.
|
|
||||||
// The node will be selected randomly.
|
|
||||||
func RandomStrategy[T Selectable]() Strategy[T] {
|
|
||||||
return &randomStrategy[T]{
|
|
||||||
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *randomStrategy[T]) Apply(vs ...T) (v T) {
|
|
||||||
if len(vs) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
r := s.rand.Int()
|
|
||||||
|
|
||||||
return vs[r%len(vs)]
|
|
||||||
}
|
|
||||||
|
|
||||||
type fifoStrategy[T Selectable] struct{}
|
|
||||||
|
|
||||||
// FIFOStrategy is a strategy for node selector.
|
|
||||||
// The node will be selected from first to last,
|
|
||||||
// and will stick to the selected node until it is failed.
|
|
||||||
func FIFOStrategy[T Selectable]() Strategy[T] {
|
|
||||||
return &fifoStrategy[T]{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply applies the fifo strategy for the nodes.
|
|
||||||
func (s *fifoStrategy[T]) Apply(vs ...T) (v T) {
|
|
||||||
if len(vs) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return vs[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
type Filter[T Selectable] interface {
|
|
||||||
Filter(...T) []T
|
|
||||||
}
|
|
||||||
|
|
||||||
type failFilter[T Selectable] struct {
|
|
||||||
maxFails int
|
|
||||||
failTimeout time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// FailFilter filters the dead objects.
|
|
||||||
// An object is marked as dead if its failed count is greater than MaxFails.
|
|
||||||
func FailFilter[T Selectable](maxFails int, timeout time.Duration) Filter[T] {
|
|
||||||
return &failFilter[T]{
|
|
||||||
maxFails: maxFails,
|
|
||||||
failTimeout: timeout,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Filter filters dead objects.
|
|
||||||
func (f *failFilter[T]) Filter(vs ...T) []T {
|
|
||||||
maxFails := f.maxFails
|
|
||||||
failTimeout := f.failTimeout
|
|
||||||
if failTimeout == 0 {
|
|
||||||
failTimeout = DefaultFailTimeout
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(vs) <= 1 || maxFails <= 0 {
|
|
||||||
return vs
|
|
||||||
}
|
|
||||||
var l []T
|
|
||||||
for _, v := range vs {
|
|
||||||
if marker := v.Marker(); marker != nil {
|
|
||||||
if marker.Count() < int64(maxFails) ||
|
|
||||||
time.Since(marker.Time()) >= failTimeout {
|
|
||||||
l = append(l, v)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
l = append(l, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return l
|
|
||||||
}
|
|
||||||
|
|
||||||
type backupFilter[T Selectable] struct{}
|
|
||||||
|
|
||||||
// BackupFilter filters the backup objects.
|
|
||||||
// An object is marked as backup if its metadata has backup flag.
|
|
||||||
func BackupFilter[T Selectable]() Filter[T] {
|
|
||||||
return &backupFilter[T]{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Filter filters backup objects.
|
|
||||||
func (f *backupFilter[T]) Filter(vs ...T) []T {
|
|
||||||
if len(vs) <= 1 {
|
|
||||||
return vs
|
|
||||||
}
|
|
||||||
|
|
||||||
var l, backups []T
|
|
||||||
for _, v := range vs {
|
|
||||||
if mdutil.GetBool(v.Metadata(), "backup") {
|
|
||||||
backups = append(backups, v)
|
|
||||||
} else {
|
|
||||||
l = append(l, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(l) == 0 {
|
|
||||||
return backups
|
|
||||||
}
|
|
||||||
return l
|
|
||||||
}
|
|
||||||
|
|
||||||
type Marker interface {
|
|
||||||
Time() time.Time
|
|
||||||
Count() int64
|
|
||||||
Mark()
|
|
||||||
Reset()
|
|
||||||
}
|
|
||||||
|
|
||||||
type failMarker struct {
|
|
||||||
failTime int64
|
|
||||||
failCount int64
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewFailMarker() Marker {
|
|
||||||
return &failMarker{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *failMarker) Time() time.Time {
|
|
||||||
if m == nil {
|
|
||||||
return time.Time{}
|
|
||||||
}
|
|
||||||
|
|
||||||
return time.Unix(atomic.LoadInt64(&m.failTime), 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *failMarker) Count() int64 {
|
|
||||||
if m == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
return atomic.LoadInt64(&m.failCount)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *failMarker) Mark() {
|
|
||||||
if m == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
atomic.AddInt64(&m.failCount, 1)
|
|
||||||
atomic.StoreInt64(&m.failTime, time.Now().Unix())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *failMarker) Reset() {
|
|
||||||
if m == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
atomic.StoreInt64(&m.failCount, 0)
|
|
||||||
}
|
|
@ -14,7 +14,7 @@ type Transport struct {
|
|||||||
addr string
|
addr string
|
||||||
ifceName string
|
ifceName string
|
||||||
sockOpts *SockOpts
|
sockOpts *SockOpts
|
||||||
route *Route
|
route Route
|
||||||
dialer dialer.Dialer
|
dialer dialer.Dialer
|
||||||
connector connector.Connector
|
connector connector.Connector
|
||||||
}
|
}
|
||||||
@ -53,7 +53,7 @@ func (tr *Transport) Dial(ctx context.Context, addr string) (net.Conn, error) {
|
|||||||
if tr.sockOpts != nil {
|
if tr.sockOpts != nil {
|
||||||
netd.Mark = tr.sockOpts.Mark
|
netd.Mark = tr.sockOpts.Mark
|
||||||
}
|
}
|
||||||
if tr.route.Len() > 0 {
|
if tr.route != nil && tr.route.Len() > 0 {
|
||||||
netd.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
netd.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
return tr.route.Dial(ctx, network, addr)
|
return tr.route.Dial(ctx, network, addr)
|
||||||
}
|
}
|
||||||
@ -98,7 +98,7 @@ func (tr *Transport) Multiplex() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tr *Transport) WithRoute(r *Route) *Transport {
|
func (tr *Transport) WithRoute(r Route) *Transport {
|
||||||
tr.route = r
|
tr.route = r
|
||||||
return tr
|
return tr
|
||||||
}
|
}
|
||||||
|
74
selector/selector.go
Normal file
74
selector/selector.go
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
package selector
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-gost/core/metadata"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Selectable interface {
|
||||||
|
Marker() Marker
|
||||||
|
Metadata() metadata.Metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
type Selector[T any] interface {
|
||||||
|
Select(...T) T
|
||||||
|
}
|
||||||
|
|
||||||
|
type Strategy[T Selectable] interface {
|
||||||
|
Apply(...T) T
|
||||||
|
}
|
||||||
|
|
||||||
|
type Filter[T Selectable] interface {
|
||||||
|
Filter(...T) []T
|
||||||
|
}
|
||||||
|
|
||||||
|
type Marker interface {
|
||||||
|
Time() time.Time
|
||||||
|
Count() int64
|
||||||
|
Mark()
|
||||||
|
Reset()
|
||||||
|
}
|
||||||
|
|
||||||
|
type failMarker struct {
|
||||||
|
failTime int64
|
||||||
|
failCount int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFailMarker() Marker {
|
||||||
|
return &failMarker{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *failMarker) Time() time.Time {
|
||||||
|
if m == nil {
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return time.Unix(atomic.LoadInt64(&m.failTime), 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *failMarker) Count() int64 {
|
||||||
|
if m == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return atomic.LoadInt64(&m.failCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *failMarker) Mark() {
|
||||||
|
if m == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
atomic.AddInt64(&m.failCount, 1)
|
||||||
|
atomic.StoreInt64(&m.failTime, time.Now().Unix())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *failMarker) Reset() {
|
||||||
|
if m == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
atomic.StoreInt64(&m.failCount, 0)
|
||||||
|
}
|
Reference in New Issue
Block a user