add chain hop

This commit is contained in:
ginuerzh 2022-09-20 11:48:30 +08:00
parent 50d443049f
commit 41ff9835a6
8 changed files with 293 additions and 527 deletions

View File

@ -2,103 +2,8 @@ package chain
import ( import (
"context" "context"
"github.com/go-gost/core/metadata"
"github.com/go-gost/core/selector"
) )
type Chainer interface { type Chainer interface {
Route(ctx context.Context, network, address string) Route Route(ctx context.Context, network, address string) Route
} }
type Chain struct {
name string
groups []*NodeGroup
marker selector.Marker
metadata metadata.Metadata
}
func NewChain(name string, groups ...*NodeGroup) *Chain {
return &Chain{
name: name,
groups: groups,
marker: selector.NewFailMarker(),
}
}
func (c *Chain) AddNodeGroup(group *NodeGroup) {
c.groups = append(c.groups, group)
}
func (c *Chain) WithMetadata(md metadata.Metadata) {
c.metadata = md
}
// Metadata implements metadata.Metadatable interface.
func (c *Chain) Metadata() metadata.Metadata {
return c.metadata
}
// Marker implements selector.Markable interface.
func (c *Chain) Marker() selector.Marker {
return c.marker
}
func (c *Chain) Route(ctx context.Context, network, address string) Route {
if c == nil || len(c.groups) == 0 {
return nil
}
rt := newRoute().WithChain(c)
for _, group := range c.groups {
// hop level bypass test
if group.bypass != nil && group.bypass.Contains(address) {
break
}
node := group.FilterAddr(address).Next(ctx)
if node == nil {
return rt
}
if node.transport.Multiplex() {
tr := node.transport.
Copy().
WithRoute(rt)
node = node.Copy()
node.transport = tr
rt = newRoute()
}
rt.addNode(node)
}
return rt
}
type ChainGroup struct {
chains []Chainer
selector selector.Selector[Chainer]
}
func NewChainGroup(chains ...Chainer) *ChainGroup {
return &ChainGroup{chains: chains}
}
func (p *ChainGroup) WithSelector(s selector.Selector[Chainer]) *ChainGroup {
p.selector = s
return p
}
func (p *ChainGroup) Route(ctx context.Context, network, address string) Route {
if chain := p.next(ctx); chain != nil {
return chain.Route(ctx, network, address)
}
return nil
}
func (p *ChainGroup) next(ctx context.Context) Chainer {
if p == nil || len(p.chains) == 0 {
return nil
}
return p.selector.Select(ctx, p.chains...)
}

20
chain/hop.go Normal file
View File

@ -0,0 +1,20 @@
package chain
import "context"
type SelectOptions struct {
Addr string
}
type SelectOption func(*SelectOptions)
func AddrSelectOption(addr string) SelectOption {
return func(o *SelectOptions) {
o.Addr = addr
}
}
type Hop interface {
Nodes() []*Node
Select(ctx context.Context, opts ...SelectOption) *Node
}

View File

@ -1,8 +1,6 @@
package chain package chain
import ( import (
"context"
"github.com/go-gost/core/bypass" "github.com/go-gost/core/bypass"
"github.com/go-gost/core/hosts" "github.com/go-gost/core/hosts"
"github.com/go-gost/core/metadata" "github.com/go-gost/core/metadata"
@ -10,48 +8,76 @@ import (
"github.com/go-gost/core/selector" "github.com/go-gost/core/selector"
) )
type NodeOptions struct {
Transport *Transport
Bypass bypass.Bypass
Resolver resolver.Resolver
HostMapper hosts.HostMapper
Metadata metadata.Metadata
}
type NodeOption func(*NodeOptions)
func TransportNodeOption(tr *Transport) NodeOption {
return func(o *NodeOptions) {
o.Transport = tr
}
}
func BypassNodeOption(bp bypass.Bypass) NodeOption {
return func(o *NodeOptions) {
o.Bypass = bp
}
}
func ResoloverNodeOption(resolver resolver.Resolver) NodeOption {
return func(o *NodeOptions) {
o.Resolver = resolver
}
}
func HostMapperNodeOption(m hosts.HostMapper) NodeOption {
return func(o *NodeOptions) {
o.HostMapper = m
}
}
func MetadataNodeOption(md metadata.Metadata) NodeOption {
return func(o *NodeOptions) {
o.Metadata = md
}
}
type Node struct { type Node struct {
Name string Name string
Addr string Addr string
transport *Transport
bypass bypass.Bypass
resolver resolver.Resolver
hostMapper hosts.HostMapper
marker selector.Marker marker selector.Marker
metadata metadata.Metadata options NodeOptions
} }
func NewNode(name, addr string) *Node { func NewNode(name string, addr string, opts ...NodeOption) *Node {
var options NodeOptions
for _, opt := range opts {
if opt != nil {
opt(&options)
}
}
return &Node{ return &Node{
Name: name, Name: name,
Addr: addr, Addr: addr,
marker: selector.NewFailMarker(), marker: selector.NewFailMarker(),
options: options,
} }
} }
func (node *Node) WithTransport(tr *Transport) *Node { func (node *Node) Options() *NodeOptions {
node.transport = tr return &node.options
return node
} }
func (node *Node) WithBypass(bypass bypass.Bypass) *Node { // Metadata implements metadadta.Metadatable interface.
node.bypass = bypass func (node *Node) Metadata() metadata.Metadata {
return node return node.options.Metadata
}
func (node *Node) WithResolver(reslv resolver.Resolver) *Node {
node.resolver = reslv
return node
}
func (node *Node) WithHostMapper(m hosts.HostMapper) *Node {
node.hostMapper = m
return node
}
func (node *Node) WithMetadata(md metadata.Metadata) *Node {
node.metadata = md
return node
} }
// Marker implements selector.Markable interface. // Marker implements selector.Markable interface.
@ -59,65 +85,8 @@ func (node *Node) Marker() selector.Marker {
return node.marker return node.marker
} }
// Metadata implements metadadta.Metadatable interface.
func (node *Node) Metadata() metadata.Metadata {
return node.metadata
}
func (node *Node) Copy() *Node { func (node *Node) Copy() *Node {
n := &Node{} n := &Node{}
*n = *node *n = *node
return n return n
} }
type NodeGroup struct {
nodes []*Node
selector selector.Selector[*Node]
bypass bypass.Bypass
}
func NewNodeGroup(nodes ...*Node) *NodeGroup {
return &NodeGroup{
nodes: nodes,
}
}
func (g *NodeGroup) AddNode(node *Node) {
g.nodes = append(g.nodes, node)
}
func (g *NodeGroup) Nodes() []*Node {
return g.nodes
}
func (g *NodeGroup) WithSelector(selector selector.Selector[*Node]) *NodeGroup {
g.selector = selector
return g
}
func (g *NodeGroup) WithBypass(bypass bypass.Bypass) *NodeGroup {
g.bypass = bypass
return g
}
func (g *NodeGroup) FilterAddr(addr string) *NodeGroup {
var nodes []*Node
for _, node := range g.nodes {
if node.bypass == nil || !node.bypass.Contains(addr) {
nodes = append(nodes, node)
}
}
return &NodeGroup{
nodes: nodes,
selector: g.selector,
bypass: g.bypass,
}
}
func (g *NodeGroup) Next(ctx context.Context) *Node {
if g == nil || len(g.nodes) == 0 {
return nil
}
return g.selector.Select(ctx, g.nodes...)
}

View File

@ -10,7 +10,7 @@ import (
"github.com/go-gost/core/resolver" "github.com/go-gost/core/resolver"
) )
func resolve(ctx context.Context, network, addr string, r resolver.Resolver, hosts hosts.HostMapper, log logger.Logger) (string, error) { func Resolve(ctx context.Context, network, addr string, r resolver.Resolver, hosts hosts.HostMapper, log logger.Logger) (string, error) {
if addr == "" { if addr == "" {
return addr, nil return addr, nil
} }

View File

@ -9,240 +9,49 @@ import (
"github.com/go-gost/core/common/net/dialer" "github.com/go-gost/core/common/net/dialer"
"github.com/go-gost/core/common/net/udp" "github.com/go-gost/core/common/net/udp"
"github.com/go-gost/core/connector"
"github.com/go-gost/core/logger" "github.com/go-gost/core/logger"
"github.com/go-gost/core/metrics"
) )
var ( var (
ErrEmptyRoute = errors.New("empty route") ErrEmptyRoute = errors.New("empty route")
) )
var (
DefaultRoute Route = &route{}
)
type Route interface { type Route interface {
Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error) Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error)
Bind(ctx context.Context, network, address string, opts ...BindOption) (net.Listener, error) Bind(ctx context.Context, network, address string, opts ...BindOption) (net.Listener, error)
Len() int Nodes() []*Node
Path() []*Node
} }
type route struct { type route struct{}
chain *Chain
nodes []*Node
}
func newRoute() *route { func (*route) Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error) {
return &route{}
}
func (r *route) addNode(node *Node) {
r.nodes = append(r.nodes, node)
}
func (r *route) WithChain(chain *Chain) *route {
r.chain = chain
return r
}
func (r *route) Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error) {
var options DialOptions var options DialOptions
for _, opt := range opts { for _, opt := range opts {
opt(&options) opt(&options)
} }
if r.Len() == 0 {
netd := dialer.NetDialer{ netd := dialer.NetDialer{
Timeout: options.Timeout, Timeout: options.Timeout,
Interface: options.Interface, Interface: options.Interface,
Logger: options.Logger,
} }
if options.SockOpts != nil { if options.SockOpts != nil {
netd.Mark = options.SockOpts.Mark netd.Mark = options.SockOpts.Mark
} }
if r != nil {
netd.Logger = options.Logger
}
return netd.Dial(ctx, network, address) return netd.Dial(ctx, network, address)
}
conn, err := r.connect(ctx, options.Logger)
if err != nil {
return nil, err
}
cc, err := r.GetNode(r.Len()-1).transport.Connect(ctx, conn, network, address)
if err != nil {
if conn != nil {
conn.Close()
}
return nil, err
}
return cc, nil
} }
func (r *route) Bind(ctx context.Context, network, address string, opts ...BindOption) (net.Listener, error) { func (*route) Bind(ctx context.Context, network, address string, opts ...BindOption) (net.Listener, error) {
var options BindOptions var options BindOptions
for _, opt := range opts { for _, opt := range opts {
opt(&options) opt(&options)
} }
if r.Len() == 0 {
return r.bindLocal(ctx, network, address, &options)
}
conn, err := r.connect(ctx, options.Logger)
if err != nil {
return nil, err
}
ln, err := r.GetNode(r.Len()-1).transport.Bind(ctx,
conn, network, address,
connector.BacklogBindOption(options.Backlog),
connector.MuxBindOption(options.Mux),
connector.UDPConnTTLBindOption(options.UDPConnTTL),
connector.UDPDataBufferSizeBindOption(options.UDPDataBufferSize),
connector.UDPDataQueueSizeBindOption(options.UDPDataQueueSize),
)
if err != nil {
conn.Close()
return nil, err
}
return ln, nil
}
func (r *route) connect(ctx context.Context, logger logger.Logger) (conn net.Conn, err error) {
if r.Len() == 0 {
return nil, ErrEmptyRoute
}
network := "ip"
node := r.nodes[0]
defer func() {
if r.chain != nil {
marker := r.chain.Marker()
// chain error
if err != nil {
if marker != nil {
marker.Mark()
}
if v := metrics.GetCounter(metrics.MetricChainErrorsCounter,
metrics.Labels{"chain": r.chain.name, "node": node.Name}); v != nil {
v.Inc()
}
} else {
if marker != nil {
marker.Reset()
}
}
}
}()
addr, err := resolve(ctx, network, node.Addr, node.resolver, node.hostMapper, logger)
marker := node.Marker()
if err != nil {
if marker != nil {
marker.Mark()
}
return
}
start := time.Now()
cc, err := node.transport.Dial(ctx, addr)
if err != nil {
if marker != nil {
marker.Mark()
}
return
}
cn, err := node.transport.Handshake(ctx, cc)
if err != nil {
cc.Close()
if marker != nil {
marker.Mark()
}
return
}
if marker != nil {
marker.Reset()
}
if r.chain != nil {
if v := metrics.GetObserver(metrics.MetricNodeConnectDurationObserver,
metrics.Labels{"chain": r.chain.name, "node": node.Name}); v != nil {
v.Observe(time.Since(start).Seconds())
}
}
preNode := node
for _, node := range r.nodes[1:] {
marker := node.Marker()
addr, err = resolve(ctx, network, node.Addr, node.resolver, node.hostMapper, logger)
if err != nil {
cn.Close()
if marker != nil {
marker.Mark()
}
return
}
cc, err = preNode.transport.Connect(ctx, cn, "tcp", addr)
if err != nil {
cn.Close()
if marker != nil {
marker.Mark()
}
return
}
cc, err = node.transport.Handshake(ctx, cc)
if err != nil {
cn.Close()
if marker != nil {
marker.Mark()
}
return
}
if marker != nil {
marker.Reset()
}
cn = cc
preNode = node
}
conn = cn
return
}
func (r *route) Len() int {
if r == nil {
return 0
}
return len(r.nodes)
}
func (r *route) GetNode(index int) *Node {
if r.Len() == 0 || index < 0 || index >= len(r.nodes) {
return nil
}
return r.nodes[index]
}
func (r *route) Path() (path []*Node) {
if r == nil || len(r.nodes) == 0 {
return nil
}
for _, node := range r.nodes {
if node.transport != nil && node.transport.route != nil {
path = append(path, node.transport.route.Path()...)
}
path = append(path, node)
}
return
}
func (r *route) bindLocal(ctx context.Context, network, address string, opts *BindOptions) (net.Listener, error) {
switch network { switch network {
case "tcp", "tcp4", "tcp6": case "tcp", "tcp4", "tcp6":
addr, err := net.ResolveTCPAddr(network, address) addr, err := net.ResolveTCPAddr(network, address)
@ -264,10 +73,10 @@ func (r *route) bindLocal(ctx context.Context, network, address string, opts *Bi
"address": address, "address": address,
}) })
ln := udp.NewListener(conn, &udp.ListenConfig{ ln := udp.NewListener(conn, &udp.ListenConfig{
Backlog: opts.Backlog, Backlog: options.Backlog,
ReadQueueSize: opts.UDPDataQueueSize, ReadQueueSize: options.UDPDataQueueSize,
ReadBufferSize: opts.UDPDataBufferSize, ReadBufferSize: options.UDPDataBufferSize,
TTL: opts.UDPConnTTL, TTL: options.UDPConnTTL,
KeepAlive: true, KeepAlive: true,
Logger: logger, Logger: logger,
}) })
@ -278,6 +87,10 @@ func (r *route) bindLocal(ctx context.Context, network, address string, opts *Bi
} }
} }
func (r *route) Nodes() []*Node {
return nil
}
type DialOptions struct { type DialOptions struct {
Timeout time.Duration Timeout time.Duration
Interface string Interface string

View File

@ -17,68 +17,96 @@ type SockOpts struct {
Mark int Mark int
} }
type Router struct { type RouterOptions struct {
ifceName string IfceName string
sockOpts *SockOpts SockOpts *SockOpts
timeout time.Duration Timeout time.Duration
retries int Retries int
chain Chainer Chain Chainer
resolver resolver.Resolver Resolver resolver.Resolver
hosts hosts.HostMapper HostMapper hosts.HostMapper
recorders []recorder.RecorderObject Recorders []recorder.RecorderObject
logger logger.Logger Logger logger.Logger
} }
func (r *Router) WithTimeout(timeout time.Duration) *Router { type RouterOption func(*RouterOptions)
r.timeout = timeout
return r
}
func (r *Router) WithRetries(retries int) *Router { func InterfaceRouterOption(ifceName string) RouterOption {
r.retries = retries return func(o *RouterOptions) {
return r o.IfceName = ifceName
}
func (r *Router) WithInterface(ifceName string) *Router {
r.ifceName = ifceName
return r
}
func (r *Router) WithSockOpts(so *SockOpts) *Router {
r.sockOpts = so
return r
}
func (r *Router) WithChain(chain Chainer) *Router {
r.chain = chain
return r
}
func (r *Router) WithResolver(resolver resolver.Resolver) *Router {
r.resolver = resolver
return r
}
func (r *Router) WithHosts(hosts hosts.HostMapper) *Router {
r.hosts = hosts
return r
}
func (r *Router) Hosts() hosts.HostMapper {
if r != nil {
return r.hosts
} }
}
func SockOptsRouterOption(so *SockOpts) RouterOption {
return func(o *RouterOptions) {
o.SockOpts = so
}
}
func TimeoutRouterOption(timeout time.Duration) RouterOption {
return func(o *RouterOptions) {
o.Timeout = timeout
}
}
func RetriesRouterOption(retries int) RouterOption {
return func(o *RouterOptions) {
o.Retries = retries
}
}
func ChainRouterOption(chain Chainer) RouterOption {
return func(o *RouterOptions) {
o.Chain = chain
}
}
func ResolverRouterOption(resolver resolver.Resolver) RouterOption {
return func(o *RouterOptions) {
o.Resolver = resolver
}
}
func HostMapperRouterOption(m hosts.HostMapper) RouterOption {
return func(o *RouterOptions) {
o.HostMapper = m
}
}
func RecordersRouterOption(recorders ...recorder.RecorderObject) RouterOption {
return func(o *RouterOptions) {
o.Recorders = recorders
}
}
func LoggerRouterOption(logger logger.Logger) RouterOption {
return func(o *RouterOptions) {
o.Logger = logger
}
}
type Router struct {
options RouterOptions
}
func NewRouter(opts ...RouterOption) *Router {
r := &Router{}
for _, opt := range opts {
if opt != nil {
opt(&r.options)
}
}
if r.options.Logger == nil {
r.options.Logger = logger.Default().WithFields(map[string]any{"kind": "router"})
}
return r
}
func (r *Router) Options() *RouterOptions {
if r == nil {
return nil return nil
} }
return &r.options
func (r *Router) WithRecorder(recorders ...recorder.RecorderObject) *Router {
r.recorders = recorders
return r
}
func (r *Router) WithLogger(logger logger.Logger) *Router {
r.logger = logger
return r
} }
func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) { func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
@ -107,11 +135,11 @@ func (r *Router) record(ctx context.Context, name string, data []byte) error {
return nil return nil
} }
for _, rec := range r.recorders { for _, rec := range r.options.Recorders {
if rec.Record == name { if rec.Record == name {
err := rec.Recorder.Record(ctx, data) err := rec.Recorder.Record(ctx, data)
if err != nil { if err != nil {
r.logger.Errorf("record %s: %v", name, err) r.options.Logger.Errorf("record %s: %v", name, err)
} }
return err return err
} }
@ -120,90 +148,99 @@ func (r *Router) record(ctx context.Context, name string, data []byte) error {
} }
func (r *Router) dial(ctx context.Context, network, address string) (conn net.Conn, err error) { func (r *Router) dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
count := r.retries + 1 count := r.options.Retries + 1
if count <= 0 { if count <= 0 {
count = 1 count = 1
} }
r.logger.Debugf("dial %s/%s", address, network) r.options.Logger.Debugf("dial %s/%s", address, network)
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
var route Route var route Route
if r.chain != nil { if r.options.Chain != nil {
route = r.chain.Route(ctx, network, address) route = r.options.Chain.Route(ctx, network, address)
} }
if r.logger.IsLevelEnabled(logger.DebugLevel) { if r.options.Logger.IsLevelEnabled(logger.DebugLevel) {
buf := bytes.Buffer{} buf := bytes.Buffer{}
var path []*Node for _, node := range routePath(route) {
if route != nil {
path = route.Path()
}
for _, node := range path {
fmt.Fprintf(&buf, "%s@%s > ", node.Name, node.Addr) fmt.Fprintf(&buf, "%s@%s > ", node.Name, node.Addr)
} }
fmt.Fprintf(&buf, "%s", address) fmt.Fprintf(&buf, "%s", address)
r.logger.Debugf("route(retry=%d) %s", i, buf.String()) r.options.Logger.Debugf("route(retry=%d) %s", i, buf.String())
} }
address, err = resolve(ctx, "ip", address, r.resolver, r.hosts, r.logger) address, err = Resolve(ctx, "ip", address, r.options.Resolver, r.options.HostMapper, r.options.Logger)
if err != nil { if err != nil {
r.logger.Error(err) r.options.Logger.Error(err)
break break
} }
if route == nil { if route == nil {
route = newRoute() route = DefaultRoute
} }
conn, err = route.Dial(ctx, network, address, conn, err = route.Dial(ctx, network, address,
InterfaceDialOption(r.ifceName), InterfaceDialOption(r.options.IfceName),
SockOptsDialOption(r.sockOpts), SockOptsDialOption(r.options.SockOpts),
LoggerDialOption(r.logger), LoggerDialOption(r.options.Logger),
) )
if err == nil { if err == nil {
break break
} }
r.logger.Errorf("route(retry=%d) %s", i, err) r.options.Logger.Errorf("route(retry=%d) %s", i, err)
} }
return return
} }
func (r *Router) Bind(ctx context.Context, network, address string, opts ...BindOption) (ln net.Listener, err error) { func (r *Router) Bind(ctx context.Context, network, address string, opts ...BindOption) (ln net.Listener, err error) {
count := r.retries + 1 count := r.options.Retries + 1
if count <= 0 { if count <= 0 {
count = 1 count = 1
} }
r.logger.Debugf("bind on %s/%s", address, network) r.options.Logger.Debugf("bind on %s/%s", address, network)
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
var route Route var route Route
if r.chain != nil { if r.options.Chain != nil {
route = r.chain.Route(ctx, network, address) route = r.options.Chain.Route(ctx, network, address)
if route.Len() == 0 { if len(route.Nodes()) == 0 {
err = ErrEmptyRoute err = ErrEmptyRoute
return return
} }
} }
if r.logger.IsLevelEnabled(logger.DebugLevel) { if r.options.Logger.IsLevelEnabled(logger.DebugLevel) {
buf := bytes.Buffer{} buf := bytes.Buffer{}
for _, node := range route.Path() { for _, node := range routePath(route) {
fmt.Fprintf(&buf, "%s@%s > ", node.Name, node.Addr) fmt.Fprintf(&buf, "%s@%s > ", node.Name, node.Addr)
} }
fmt.Fprintf(&buf, "%s", address) fmt.Fprintf(&buf, "%s", address)
r.logger.Debugf("route(retry=%d) %s", i, buf.String()) r.options.Logger.Debugf("route(retry=%d) %s", i, buf.String())
} }
ln, err = route.Bind(ctx, network, address, opts...) ln, err = route.Bind(ctx, network, address, opts...)
if err == nil { if err == nil {
break break
} }
r.logger.Errorf("route(retry=%d) %s", i, err) r.options.Logger.Errorf("route(retry=%d) %s", i, err)
} }
return return
} }
func routePath(route Route) (path []*Node) {
if route == nil {
return
}
for _, node := range route.Nodes() {
if tr := node.Options().Transport; tr != nil {
path = append(path, routePath(tr.Options().Route)...)
}
path = append(path, node)
}
return
}
type packetConn struct { type packetConn struct {
net.Conn net.Conn
} }

View File

@ -10,62 +10,81 @@ import (
"github.com/go-gost/core/dialer" "github.com/go-gost/core/dialer"
) )
type TransportOptions struct {
Addr string
IfceName string
SockOpts *SockOpts
Route Route
Timeout time.Duration
}
type TransportOption func(*TransportOptions)
func AddrTransportOption(addr string) TransportOption {
return func(o *TransportOptions) {
o.Addr = addr
}
}
func InterfaceTransportOption(ifceName string) TransportOption {
return func(o *TransportOptions) {
o.IfceName = ifceName
}
}
func SockOptsTransportOption(so *SockOpts) TransportOption {
return func(o *TransportOptions) {
o.SockOpts = so
}
}
func RouteTransportOption(route Route) TransportOption {
return func(o *TransportOptions) {
o.Route = route
}
}
func TimeoutTransportOption(timeout time.Duration) TransportOption {
return func(o *TransportOptions) {
o.Timeout = timeout
}
}
type Transport struct { type Transport struct {
addr string
ifceName string
sockOpts *SockOpts
route Route
dialer dialer.Dialer dialer dialer.Dialer
connector connector.Connector connector connector.Connector
timeout time.Duration options TransportOptions
} }
func (tr *Transport) Copy() *Transport { func NewTransport(d dialer.Dialer, c connector.Connector, opts ...TransportOption) *Transport {
tr2 := &Transport{} tr := &Transport{
*tr2 = *tr dialer: d,
return tr connector: c,
} }
for _, opt := range opts {
if opt != nil {
opt(&tr.options)
}
}
func (tr *Transport) WithInterface(ifceName string) *Transport {
tr.ifceName = ifceName
return tr
}
func (tr *Transport) WithSockOpts(so *SockOpts) *Transport {
tr.sockOpts = so
return tr
}
func (tr *Transport) WithDialer(dialer dialer.Dialer) *Transport {
tr.dialer = dialer
return tr
}
func (tr *Transport) WithConnector(connector connector.Connector) *Transport {
tr.connector = connector
return tr
}
func (tr *Transport) WithTimeout(d time.Duration) *Transport {
tr.timeout = d
return tr return tr
} }
func (tr *Transport) Dial(ctx context.Context, addr string) (net.Conn, error) { func (tr *Transport) Dial(ctx context.Context, addr string) (net.Conn, error) {
netd := &net_dialer.NetDialer{ netd := &net_dialer.NetDialer{
Interface: tr.ifceName, Interface: tr.options.IfceName,
Timeout: tr.timeout, Timeout: tr.options.Timeout,
} }
if tr.sockOpts != nil { if tr.options.SockOpts != nil {
netd.Mark = tr.sockOpts.Mark netd.Mark = tr.options.SockOpts.Mark
} }
if tr.route != nil && tr.route.Len() > 0 { if tr.options.Route != nil && len(tr.options.Route.Nodes()) > 0 {
netd.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) { netd.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) {
return tr.route.Dial(ctx, network, addr) return tr.options.Route.Dial(ctx, network, addr)
} }
} }
opts := []dialer.DialOption{ opts := []dialer.DialOption{
dialer.HostDialOption(tr.addr), dialer.HostDialOption(tr.options.Addr),
dialer.NetDialerDialOption(netd), dialer.NetDialerDialOption(netd),
} }
return tr.dialer.Dial(ctx, addr, opts...) return tr.dialer.Dial(ctx, addr, opts...)
@ -75,7 +94,7 @@ func (tr *Transport) Handshake(ctx context.Context, conn net.Conn) (net.Conn, er
var err error var err error
if hs, ok := tr.dialer.(dialer.Handshaker); ok { if hs, ok := tr.dialer.(dialer.Handshaker); ok {
conn, err = hs.Handshake(ctx, conn, conn, err = hs.Handshake(ctx, conn,
dialer.AddrHandshakeOption(tr.addr)) dialer.AddrHandshakeOption(tr.options.Addr))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -88,11 +107,11 @@ func (tr *Transport) Handshake(ctx context.Context, conn net.Conn) (net.Conn, er
func (tr *Transport) Connect(ctx context.Context, conn net.Conn, network, address string) (net.Conn, error) { func (tr *Transport) Connect(ctx context.Context, conn net.Conn, network, address string) (net.Conn, error) {
netd := &net_dialer.NetDialer{ netd := &net_dialer.NetDialer{
Interface: tr.ifceName, Interface: tr.options.IfceName,
Timeout: tr.timeout, Timeout: tr.options.Timeout,
} }
if tr.sockOpts != nil { if tr.options.SockOpts != nil {
netd.Mark = tr.sockOpts.Mark netd.Mark = tr.options.SockOpts.Mark
} }
return tr.connector.Connect(ctx, conn, network, address, return tr.connector.Connect(ctx, conn, network, address,
connector.NetDialerConnectOption(netd), connector.NetDialerConnectOption(netd),
@ -113,12 +132,15 @@ func (tr *Transport) Multiplex() bool {
return false return false
} }
func (tr *Transport) WithRoute(r Route) *Transport { func (tr *Transport) Options() *TransportOptions {
tr.route = r if tr != nil {
return tr return &tr.options
}
return nil
} }
func (tr *Transport) WithAddr(addr string) *Transport { func (tr *Transport) Copy() *Transport {
tr.addr = addr tr2 := &Transport{}
*tr2 = *tr
return tr return tr
} }

View File

@ -14,5 +14,5 @@ type Handler interface {
} }
type Forwarder interface { type Forwarder interface {
Forward(*chain.NodeGroup) Forward(chain.Hop)
} }