add chain hop
This commit is contained in:
235
chain/route.go
235
chain/route.go
@ -9,240 +9,49 @@ import (
|
||||
|
||||
"github.com/go-gost/core/common/net/dialer"
|
||||
"github.com/go-gost/core/common/net/udp"
|
||||
"github.com/go-gost/core/connector"
|
||||
"github.com/go-gost/core/logger"
|
||||
"github.com/go-gost/core/metrics"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEmptyRoute = errors.New("empty route")
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultRoute Route = &route{}
|
||||
)
|
||||
|
||||
type Route interface {
|
||||
Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error)
|
||||
Bind(ctx context.Context, network, address string, opts ...BindOption) (net.Listener, error)
|
||||
Len() int
|
||||
Path() []*Node
|
||||
Nodes() []*Node
|
||||
}
|
||||
|
||||
type route struct {
|
||||
chain *Chain
|
||||
nodes []*Node
|
||||
}
|
||||
type route struct{}
|
||||
|
||||
func newRoute() *route {
|
||||
return &route{}
|
||||
}
|
||||
|
||||
func (r *route) addNode(node *Node) {
|
||||
r.nodes = append(r.nodes, node)
|
||||
}
|
||||
|
||||
func (r *route) WithChain(chain *Chain) *route {
|
||||
r.chain = chain
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *route) Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error) {
|
||||
func (*route) Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error) {
|
||||
var options DialOptions
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
if r.Len() == 0 {
|
||||
netd := dialer.NetDialer{
|
||||
Timeout: options.Timeout,
|
||||
Interface: options.Interface,
|
||||
}
|
||||
if options.SockOpts != nil {
|
||||
netd.Mark = options.SockOpts.Mark
|
||||
}
|
||||
if r != nil {
|
||||
netd.Logger = options.Logger
|
||||
}
|
||||
|
||||
return netd.Dial(ctx, network, address)
|
||||
netd := dialer.NetDialer{
|
||||
Timeout: options.Timeout,
|
||||
Interface: options.Interface,
|
||||
Logger: options.Logger,
|
||||
}
|
||||
if options.SockOpts != nil {
|
||||
netd.Mark = options.SockOpts.Mark
|
||||
}
|
||||
|
||||
conn, err := r.connect(ctx, options.Logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cc, err := r.GetNode(r.Len()-1).transport.Connect(ctx, conn, network, address)
|
||||
if err != nil {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return cc, nil
|
||||
return netd.Dial(ctx, network, address)
|
||||
}
|
||||
|
||||
func (r *route) Bind(ctx context.Context, network, address string, opts ...BindOption) (net.Listener, error) {
|
||||
func (*route) Bind(ctx context.Context, network, address string, opts ...BindOption) (net.Listener, error) {
|
||||
var options BindOptions
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
if r.Len() == 0 {
|
||||
return r.bindLocal(ctx, network, address, &options)
|
||||
}
|
||||
|
||||
conn, err := r.connect(ctx, options.Logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ln, err := r.GetNode(r.Len()-1).transport.Bind(ctx,
|
||||
conn, network, address,
|
||||
connector.BacklogBindOption(options.Backlog),
|
||||
connector.MuxBindOption(options.Mux),
|
||||
connector.UDPConnTTLBindOption(options.UDPConnTTL),
|
||||
connector.UDPDataBufferSizeBindOption(options.UDPDataBufferSize),
|
||||
connector.UDPDataQueueSizeBindOption(options.UDPDataQueueSize),
|
||||
)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ln, nil
|
||||
}
|
||||
|
||||
func (r *route) connect(ctx context.Context, logger logger.Logger) (conn net.Conn, err error) {
|
||||
if r.Len() == 0 {
|
||||
return nil, ErrEmptyRoute
|
||||
}
|
||||
|
||||
network := "ip"
|
||||
node := r.nodes[0]
|
||||
|
||||
defer func() {
|
||||
if r.chain != nil {
|
||||
marker := r.chain.Marker()
|
||||
// chain error
|
||||
if err != nil {
|
||||
if marker != nil {
|
||||
marker.Mark()
|
||||
}
|
||||
if v := metrics.GetCounter(metrics.MetricChainErrorsCounter,
|
||||
metrics.Labels{"chain": r.chain.name, "node": node.Name}); v != nil {
|
||||
v.Inc()
|
||||
}
|
||||
} else {
|
||||
if marker != nil {
|
||||
marker.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
addr, err := resolve(ctx, network, node.Addr, node.resolver, node.hostMapper, logger)
|
||||
marker := node.Marker()
|
||||
if err != nil {
|
||||
if marker != nil {
|
||||
marker.Mark()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
cc, err := node.transport.Dial(ctx, addr)
|
||||
if err != nil {
|
||||
if marker != nil {
|
||||
marker.Mark()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
cn, err := node.transport.Handshake(ctx, cc)
|
||||
if err != nil {
|
||||
cc.Close()
|
||||
if marker != nil {
|
||||
marker.Mark()
|
||||
}
|
||||
return
|
||||
}
|
||||
if marker != nil {
|
||||
marker.Reset()
|
||||
}
|
||||
|
||||
if r.chain != nil {
|
||||
if v := metrics.GetObserver(metrics.MetricNodeConnectDurationObserver,
|
||||
metrics.Labels{"chain": r.chain.name, "node": node.Name}); v != nil {
|
||||
v.Observe(time.Since(start).Seconds())
|
||||
}
|
||||
}
|
||||
|
||||
preNode := node
|
||||
for _, node := range r.nodes[1:] {
|
||||
marker := node.Marker()
|
||||
addr, err = resolve(ctx, network, node.Addr, node.resolver, node.hostMapper, logger)
|
||||
if err != nil {
|
||||
cn.Close()
|
||||
if marker != nil {
|
||||
marker.Mark()
|
||||
}
|
||||
return
|
||||
}
|
||||
cc, err = preNode.transport.Connect(ctx, cn, "tcp", addr)
|
||||
if err != nil {
|
||||
cn.Close()
|
||||
if marker != nil {
|
||||
marker.Mark()
|
||||
}
|
||||
return
|
||||
}
|
||||
cc, err = node.transport.Handshake(ctx, cc)
|
||||
if err != nil {
|
||||
cn.Close()
|
||||
if marker != nil {
|
||||
marker.Mark()
|
||||
}
|
||||
return
|
||||
}
|
||||
if marker != nil {
|
||||
marker.Reset()
|
||||
}
|
||||
|
||||
cn = cc
|
||||
preNode = node
|
||||
}
|
||||
|
||||
conn = cn
|
||||
return
|
||||
}
|
||||
|
||||
func (r *route) Len() int {
|
||||
if r == nil {
|
||||
return 0
|
||||
}
|
||||
return len(r.nodes)
|
||||
}
|
||||
|
||||
func (r *route) GetNode(index int) *Node {
|
||||
if r.Len() == 0 || index < 0 || index >= len(r.nodes) {
|
||||
return nil
|
||||
}
|
||||
return r.nodes[index]
|
||||
}
|
||||
|
||||
func (r *route) Path() (path []*Node) {
|
||||
if r == nil || len(r.nodes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, node := range r.nodes {
|
||||
if node.transport != nil && node.transport.route != nil {
|
||||
path = append(path, node.transport.route.Path()...)
|
||||
}
|
||||
path = append(path, node)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *route) bindLocal(ctx context.Context, network, address string, opts *BindOptions) (net.Listener, error) {
|
||||
switch network {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
addr, err := net.ResolveTCPAddr(network, address)
|
||||
@ -264,10 +73,10 @@ func (r *route) bindLocal(ctx context.Context, network, address string, opts *Bi
|
||||
"address": address,
|
||||
})
|
||||
ln := udp.NewListener(conn, &udp.ListenConfig{
|
||||
Backlog: opts.Backlog,
|
||||
ReadQueueSize: opts.UDPDataQueueSize,
|
||||
ReadBufferSize: opts.UDPDataBufferSize,
|
||||
TTL: opts.UDPConnTTL,
|
||||
Backlog: options.Backlog,
|
||||
ReadQueueSize: options.UDPDataQueueSize,
|
||||
ReadBufferSize: options.UDPDataBufferSize,
|
||||
TTL: options.UDPConnTTL,
|
||||
KeepAlive: true,
|
||||
Logger: logger,
|
||||
})
|
||||
@ -278,6 +87,10 @@ func (r *route) bindLocal(ctx context.Context, network, address string, opts *Bi
|
||||
}
|
||||
}
|
||||
|
||||
func (r *route) Nodes() []*Node {
|
||||
return nil
|
||||
}
|
||||
|
||||
type DialOptions struct {
|
||||
Timeout time.Duration
|
||||
Interface string
|
||||
|
Reference in New Issue
Block a user