initial commit
This commit is contained in:
41
chain/chain.go
Normal file
41
chain/chain.go
Normal file
@ -0,0 +1,41 @@
|
||||
package chain
|
||||
|
||||
type Chainer interface {
|
||||
Route(network, address string) *Route
|
||||
}
|
||||
|
||||
type Chain struct {
|
||||
groups []*NodeGroup
|
||||
}
|
||||
|
||||
func (c *Chain) AddNodeGroup(group *NodeGroup) {
|
||||
c.groups = append(c.groups, group)
|
||||
}
|
||||
|
||||
func (c *Chain) Route(network, address string) (r *Route) {
|
||||
if c == nil || len(c.groups) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
r = &Route{}
|
||||
for _, group := range c.groups {
|
||||
node := group.Next()
|
||||
if node == nil {
|
||||
return
|
||||
}
|
||||
if node.Bypass != nil && node.Bypass.Contains(address) {
|
||||
break
|
||||
}
|
||||
|
||||
if node.Transport.Multiplex() {
|
||||
tr := node.Transport.Copy().
|
||||
WithRoute(r)
|
||||
node = node.Copy()
|
||||
node.Transport = tr
|
||||
r = &Route{}
|
||||
}
|
||||
|
||||
r.addNode(node)
|
||||
}
|
||||
return r
|
||||
}
|
97
chain/node.go
Normal file
97
chain/node.go
Normal file
@ -0,0 +1,97 @@
|
||||
package chain
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/core/bypass"
|
||||
"github.com/go-gost/core/hosts"
|
||||
"github.com/go-gost/core/resolver"
|
||||
)
|
||||
|
||||
type Node struct {
|
||||
Name string
|
||||
Addr string
|
||||
Transport *Transport
|
||||
Bypass bypass.Bypass
|
||||
Resolver resolver.Resolver
|
||||
Hosts hosts.HostMapper
|
||||
Marker *FailMarker
|
||||
}
|
||||
|
||||
func (node *Node) Copy() *Node {
|
||||
n := &Node{}
|
||||
*n = *node
|
||||
return n
|
||||
}
|
||||
|
||||
type NodeGroup struct {
|
||||
nodes []*Node
|
||||
selector Selector
|
||||
}
|
||||
|
||||
func NewNodeGroup(nodes ...*Node) *NodeGroup {
|
||||
return &NodeGroup{
|
||||
nodes: nodes,
|
||||
}
|
||||
}
|
||||
|
||||
func (g *NodeGroup) AddNode(node *Node) {
|
||||
g.nodes = append(g.nodes, node)
|
||||
}
|
||||
|
||||
func (g *NodeGroup) WithSelector(selector Selector) *NodeGroup {
|
||||
g.selector = selector
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *NodeGroup) Next() *Node {
|
||||
if g == nil || len(g.nodes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
s := g.selector
|
||||
if s == nil {
|
||||
s = DefaultSelector
|
||||
}
|
||||
|
||||
return s.Select(g.nodes...)
|
||||
}
|
||||
|
||||
type FailMarker struct {
|
||||
failTime int64
|
||||
failCount int64
|
||||
}
|
||||
|
||||
func (m *FailMarker) FailTime() int64 {
|
||||
if m == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return atomic.LoadInt64(&m.failTime)
|
||||
}
|
||||
|
||||
func (m *FailMarker) FailCount() int64 {
|
||||
if m == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return atomic.LoadInt64(&m.failCount)
|
||||
}
|
||||
|
||||
func (m *FailMarker) Mark() {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
|
||||
atomic.AddInt64(&m.failCount, 1)
|
||||
atomic.StoreInt64(&m.failTime, time.Now().Unix())
|
||||
}
|
||||
|
||||
func (m *FailMarker) Reset() {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
|
||||
atomic.StoreInt64(&m.failCount, 0)
|
||||
}
|
47
chain/resovle.go
Normal file
47
chain/resovle.go
Normal file
@ -0,0 +1,47 @@
|
||||
package chain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/go-gost/core/hosts"
|
||||
"github.com/go-gost/core/logger"
|
||||
"github.com/go-gost/core/resolver"
|
||||
)
|
||||
|
||||
func resolve(ctx context.Context, network, addr string, r resolver.Resolver, hosts hosts.HostMapper, log logger.Logger) (string, error) {
|
||||
if addr == "" {
|
||||
return addr, nil
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if host == "" {
|
||||
return addr, nil
|
||||
}
|
||||
|
||||
if hosts != nil {
|
||||
if ips, _ := hosts.Lookup(network, host); len(ips) > 0 {
|
||||
log.Debugf("hit host mapper: %s -> %s", host, ips)
|
||||
return net.JoinHostPort(ips[0].String(), port), nil
|
||||
}
|
||||
}
|
||||
|
||||
if r != nil {
|
||||
ips, err := r.Resolve(ctx, network, host)
|
||||
if err != nil {
|
||||
if err == resolver.ErrInvalid {
|
||||
return addr, nil
|
||||
}
|
||||
log.Error(err)
|
||||
}
|
||||
if len(ips) == 0 {
|
||||
return "", fmt.Errorf("resolver: domain %s does not exists", host)
|
||||
}
|
||||
return net.JoinHostPort(ips[0].String(), port), nil
|
||||
}
|
||||
return addr, nil
|
||||
}
|
192
chain/route.go
Normal file
192
chain/route.go
Normal file
@ -0,0 +1,192 @@
|
||||
package chain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/core/common/net/dialer"
|
||||
"github.com/go-gost/core/common/util/udp"
|
||||
"github.com/go-gost/core/connector"
|
||||
"github.com/go-gost/core/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEmptyRoute = errors.New("empty route")
|
||||
)
|
||||
|
||||
type Route struct {
|
||||
nodes []*Node
|
||||
ifceName string
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func (r *Route) addNode(node *Node) {
|
||||
r.nodes = append(r.nodes, node)
|
||||
}
|
||||
|
||||
func (r *Route) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
if r.Len() == 0 {
|
||||
netd := dialer.NetDialer{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
if r != nil {
|
||||
netd.Interface = r.ifceName
|
||||
}
|
||||
return netd.Dial(ctx, network, address)
|
||||
}
|
||||
|
||||
conn, err := r.connect(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cc, err := r.GetNode(r.Len()-1).Transport.Connect(ctx, conn, network, address)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return cc, nil
|
||||
}
|
||||
|
||||
func (r *Route) Bind(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) {
|
||||
if r.Len() == 0 {
|
||||
return r.bindLocal(ctx, network, address, opts...)
|
||||
}
|
||||
|
||||
conn, err := r.connect(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ln, err := r.GetNode(r.Len()-1).Transport.Bind(ctx, conn, network, address, opts...)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ln, nil
|
||||
}
|
||||
|
||||
func (r *Route) connect(ctx context.Context) (conn net.Conn, err error) {
|
||||
if r.Len() == 0 {
|
||||
return nil, ErrEmptyRoute
|
||||
}
|
||||
|
||||
network := "ip"
|
||||
node := r.nodes[0]
|
||||
|
||||
addr, err := resolve(ctx, network, node.Addr, node.Resolver, node.Hosts, r.logger)
|
||||
if err != nil {
|
||||
node.Marker.Mark()
|
||||
return
|
||||
}
|
||||
cc, err := node.Transport.Dial(ctx, addr)
|
||||
if err != nil {
|
||||
node.Marker.Mark()
|
||||
return
|
||||
}
|
||||
|
||||
cn, err := node.Transport.Handshake(ctx, cc)
|
||||
if err != nil {
|
||||
cc.Close()
|
||||
node.Marker.Mark()
|
||||
return
|
||||
}
|
||||
node.Marker.Reset()
|
||||
|
||||
preNode := node
|
||||
for _, node := range r.nodes[1:] {
|
||||
addr, err = resolve(ctx, network, node.Addr, node.Resolver, node.Hosts, r.logger)
|
||||
if err != nil {
|
||||
cn.Close()
|
||||
node.Marker.Mark()
|
||||
return
|
||||
}
|
||||
cc, err = preNode.Transport.Connect(ctx, cn, "tcp", addr)
|
||||
if err != nil {
|
||||
cn.Close()
|
||||
node.Marker.Mark()
|
||||
return
|
||||
}
|
||||
cc, err = node.Transport.Handshake(ctx, cc)
|
||||
if err != nil {
|
||||
cn.Close()
|
||||
node.Marker.Mark()
|
||||
return
|
||||
}
|
||||
node.Marker.Reset()
|
||||
|
||||
cn = cc
|
||||
preNode = node
|
||||
}
|
||||
|
||||
conn = cn
|
||||
return
|
||||
}
|
||||
|
||||
func (r *Route) Len() int {
|
||||
if r == nil {
|
||||
return 0
|
||||
}
|
||||
return len(r.nodes)
|
||||
}
|
||||
|
||||
func (r *Route) GetNode(index int) *Node {
|
||||
if r.Len() == 0 || index < 0 || index >= len(r.nodes) {
|
||||
return nil
|
||||
}
|
||||
return r.nodes[index]
|
||||
}
|
||||
|
||||
func (r *Route) Path() (path []*Node) {
|
||||
if r == nil || len(r.nodes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, node := range r.nodes {
|
||||
if node.Transport != nil && node.Transport.route != nil {
|
||||
path = append(path, node.Transport.route.Path()...)
|
||||
}
|
||||
path = append(path, node)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *Route) bindLocal(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) {
|
||||
options := connector.BindOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
switch network {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
addr, err := net.ResolveTCPAddr(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return net.ListenTCP(network, addr)
|
||||
case "udp", "udp4", "udp6":
|
||||
addr, err := net.ResolveUDPAddr(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := net.ListenUDP(network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logger := logger.Default().WithFields(map[string]any{
|
||||
"network": network,
|
||||
"address": address,
|
||||
})
|
||||
ln := udp.NewListener(conn, addr,
|
||||
options.Backlog, options.UDPDataQueueSize, options.UDPDataBufferSize,
|
||||
options.UDPConnTTL, logger)
|
||||
return ln, err
|
||||
default:
|
||||
err := fmt.Errorf("network %s unsupported", network)
|
||||
return nil, err
|
||||
}
|
||||
}
|
169
chain/router.go
Normal file
169
chain/router.go
Normal file
@ -0,0 +1,169 @@
|
||||
package chain
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/core/connector"
|
||||
"github.com/go-gost/core/hosts"
|
||||
"github.com/go-gost/core/logger"
|
||||
"github.com/go-gost/core/resolver"
|
||||
)
|
||||
|
||||
type Router struct {
|
||||
timeout time.Duration
|
||||
retries int
|
||||
ifceName string
|
||||
chain Chainer
|
||||
resolver resolver.Resolver
|
||||
hosts hosts.HostMapper
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func (r *Router) WithTimeout(timeout time.Duration) *Router {
|
||||
r.timeout = timeout
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Router) WithRetries(retries int) *Router {
|
||||
r.retries = retries
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Router) WithInterface(ifceName string) *Router {
|
||||
r.ifceName = ifceName
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Router) WithChain(chain Chainer) *Router {
|
||||
r.chain = chain
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Router) WithResolver(resolver resolver.Resolver) *Router {
|
||||
r.resolver = resolver
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Router) WithHosts(hosts hosts.HostMapper) *Router {
|
||||
r.hosts = hosts
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Router) Hosts() hosts.HostMapper {
|
||||
if r != nil {
|
||||
return r.hosts
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Router) WithLogger(logger logger.Logger) *Router {
|
||||
r.logger = logger
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
|
||||
conn, err = r.dial(ctx, network, address)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if network == "udp" || network == "udp4" || network == "udp6" {
|
||||
if _, ok := conn.(net.PacketConn); !ok {
|
||||
return &packetConn{conn}, nil
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *Router) dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
|
||||
count := r.retries + 1
|
||||
if count <= 0 {
|
||||
count = 1
|
||||
}
|
||||
r.logger.Debugf("dial %s/%s", address, network)
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
var route *Route
|
||||
if r.chain != nil {
|
||||
route = r.chain.Route(network, address)
|
||||
}
|
||||
|
||||
if r.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
buf := bytes.Buffer{}
|
||||
for _, node := range route.Path() {
|
||||
fmt.Fprintf(&buf, "%s@%s > ", node.Name, node.Addr)
|
||||
}
|
||||
fmt.Fprintf(&buf, "%s", address)
|
||||
r.logger.Debugf("route(retry=%d) %s", i, buf.String())
|
||||
}
|
||||
|
||||
address, err = resolve(ctx, "ip", address, r.resolver, r.hosts, r.logger)
|
||||
if err != nil {
|
||||
r.logger.Error(err)
|
||||
break
|
||||
}
|
||||
|
||||
if route == nil {
|
||||
route = &Route{}
|
||||
}
|
||||
route.ifceName = r.ifceName
|
||||
route.logger = r.logger
|
||||
|
||||
conn, err = route.Dial(ctx, network, address)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
r.logger.Errorf("route(retry=%d) %s", i, err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (r *Router) Bind(ctx context.Context, network, address string, opts ...connector.BindOption) (ln net.Listener, err error) {
|
||||
count := r.retries + 1
|
||||
if count <= 0 {
|
||||
count = 1
|
||||
}
|
||||
r.logger.Debugf("bind on %s/%s", address, network)
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
var route *Route
|
||||
if r.chain != nil {
|
||||
route = r.chain.Route(network, address)
|
||||
}
|
||||
|
||||
if r.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
buf := bytes.Buffer{}
|
||||
for _, node := range route.Path() {
|
||||
fmt.Fprintf(&buf, "%s@%s > ", node.Name, node.Addr)
|
||||
}
|
||||
fmt.Fprintf(&buf, "%s", address)
|
||||
r.logger.Debugf("route(retry=%d) %s", i, buf.String())
|
||||
}
|
||||
|
||||
ln, err = route.Bind(ctx, network, address, opts...)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
r.logger.Errorf("route(retry=%d) %s", i, err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type packetConn struct {
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *packetConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||
n, err = c.Read(b)
|
||||
addr = c.Conn.RemoteAddr()
|
||||
return
|
||||
}
|
||||
|
||||
func (c *packetConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
return c.Write(b)
|
||||
}
|
170
chain/selector.go
Normal file
170
chain/selector.go
Normal file
@ -0,0 +1,170 @@
|
||||
package chain
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// default options for FailFilter
|
||||
const (
|
||||
DefaultFailTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultSelector = NewSelector(RoundRobinStrategy())
|
||||
)
|
||||
|
||||
type Selector interface {
|
||||
Select(nodes ...*Node) *Node
|
||||
}
|
||||
|
||||
type selector struct {
|
||||
strategy Strategy
|
||||
filters []Filter
|
||||
}
|
||||
|
||||
func NewSelector(strategy Strategy, filters ...Filter) Selector {
|
||||
return &selector{
|
||||
filters: filters,
|
||||
strategy: strategy,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *selector) Select(nodes ...*Node) *Node {
|
||||
for _, filter := range s.filters {
|
||||
nodes = filter.Filter(nodes...)
|
||||
}
|
||||
if len(nodes) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.strategy.Apply(nodes...)
|
||||
}
|
||||
|
||||
type Strategy interface {
|
||||
Apply(nodes ...*Node) *Node
|
||||
}
|
||||
|
||||
type roundRobinStrategy struct {
|
||||
counter uint64
|
||||
}
|
||||
|
||||
// RoundRobinStrategy is a strategy for node selector.
|
||||
// The node will be selected by round-robin algorithm.
|
||||
func RoundRobinStrategy() Strategy {
|
||||
return &roundRobinStrategy{}
|
||||
}
|
||||
|
||||
func (s *roundRobinStrategy) Apply(nodes ...*Node) *Node {
|
||||
if len(nodes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
n := atomic.AddUint64(&s.counter, 1) - 1
|
||||
return nodes[int(n%uint64(len(nodes)))]
|
||||
}
|
||||
|
||||
type randomStrategy struct {
|
||||
rand *rand.Rand
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
// RandomStrategy is a strategy for node selector.
|
||||
// The node will be selected randomly.
|
||||
func RandomStrategy() Strategy {
|
||||
return &randomStrategy{
|
||||
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *randomStrategy) Apply(nodes ...*Node) *Node {
|
||||
if len(nodes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
r := s.rand.Int()
|
||||
|
||||
return nodes[r%len(nodes)]
|
||||
}
|
||||
|
||||
type fifoStrategy struct{}
|
||||
|
||||
// FIFOStrategy is a strategy for node selector.
|
||||
// The node will be selected from first to last,
|
||||
// and will stick to the selected node until it is failed.
|
||||
func FIFOStrategy() Strategy {
|
||||
return &fifoStrategy{}
|
||||
}
|
||||
|
||||
// Apply applies the fifo strategy for the nodes.
|
||||
func (s *fifoStrategy) Apply(nodes ...*Node) *Node {
|
||||
if len(nodes) == 0 {
|
||||
return nil
|
||||
}
|
||||
return nodes[0]
|
||||
}
|
||||
|
||||
type Filter interface {
|
||||
Filter(nodes ...*Node) []*Node
|
||||
}
|
||||
|
||||
type failFilter struct {
|
||||
maxFails int
|
||||
failTimeout time.Duration
|
||||
}
|
||||
|
||||
// FailFilter filters the dead node.
|
||||
// A node is marked as dead if its failed count is greater than MaxFails.
|
||||
func FailFilter(maxFails int, timeout time.Duration) Filter {
|
||||
return &failFilter{
|
||||
maxFails: maxFails,
|
||||
failTimeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
// Filter filters dead nodes.
|
||||
func (f *failFilter) Filter(nodes ...*Node) []*Node {
|
||||
maxFails := f.maxFails
|
||||
failTimeout := f.failTimeout
|
||||
if failTimeout == 0 {
|
||||
failTimeout = DefaultFailTimeout
|
||||
}
|
||||
|
||||
if len(nodes) <= 1 || maxFails <= 0 {
|
||||
return nodes
|
||||
}
|
||||
var nl []*Node
|
||||
for _, node := range nodes {
|
||||
if node.Marker.FailCount() < int64(maxFails) ||
|
||||
time.Since(time.Unix(node.Marker.FailTime(), 0)) >= failTimeout {
|
||||
nl = append(nl, node)
|
||||
}
|
||||
}
|
||||
return nl
|
||||
}
|
||||
|
||||
type invalidFilter struct{}
|
||||
|
||||
// InvalidFilter filters the invalid node.
|
||||
// A node is invalid if its port is invalid (negative or zero value).
|
||||
func InvalidFilter() Filter {
|
||||
return &invalidFilter{}
|
||||
}
|
||||
|
||||
// Filter filters invalid nodes.
|
||||
func (f *invalidFilter) Filter(nodes ...*Node) []*Node {
|
||||
var nl []*Node
|
||||
for _, node := range nodes {
|
||||
_, sport, _ := net.SplitHostPort(node.Addr)
|
||||
if port, _ := strconv.Atoi(sport); port > 0 {
|
||||
nl = append(nl, node)
|
||||
}
|
||||
}
|
||||
return nl
|
||||
}
|
100
chain/transport.go
Normal file
100
chain/transport.go
Normal file
@ -0,0 +1,100 @@
|
||||
package chain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
net_dialer "github.com/go-gost/core/common/net/dialer"
|
||||
"github.com/go-gost/core/connector"
|
||||
"github.com/go-gost/core/dialer"
|
||||
)
|
||||
|
||||
type Transport struct {
|
||||
addr string
|
||||
ifceName string
|
||||
route *Route
|
||||
dialer dialer.Dialer
|
||||
connector connector.Connector
|
||||
}
|
||||
|
||||
func (tr *Transport) Copy() *Transport {
|
||||
tr2 := &Transport{}
|
||||
*tr2 = *tr
|
||||
return tr
|
||||
}
|
||||
|
||||
func (tr *Transport) WithInterface(ifceName string) *Transport {
|
||||
tr.ifceName = ifceName
|
||||
return tr
|
||||
}
|
||||
|
||||
func (tr *Transport) WithDialer(dialer dialer.Dialer) *Transport {
|
||||
tr.dialer = dialer
|
||||
return tr
|
||||
}
|
||||
|
||||
func (tr *Transport) WithConnector(connector connector.Connector) *Transport {
|
||||
tr.connector = connector
|
||||
return tr
|
||||
}
|
||||
|
||||
func (tr *Transport) Dial(ctx context.Context, addr string) (net.Conn, error) {
|
||||
netd := &net_dialer.NetDialer{
|
||||
Interface: tr.ifceName,
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
if tr.route.Len() > 0 {
|
||||
netd.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return tr.route.Dial(ctx, network, addr)
|
||||
}
|
||||
}
|
||||
opts := []dialer.DialOption{
|
||||
dialer.HostDialOption(tr.addr),
|
||||
dialer.NetDialerDialOption(netd),
|
||||
}
|
||||
return tr.dialer.Dial(ctx, addr, opts...)
|
||||
}
|
||||
|
||||
func (tr *Transport) Handshake(ctx context.Context, conn net.Conn) (net.Conn, error) {
|
||||
var err error
|
||||
if hs, ok := tr.dialer.(dialer.Handshaker); ok {
|
||||
conn, err = hs.Handshake(ctx, conn,
|
||||
dialer.AddrHandshakeOption(tr.addr))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if hs, ok := tr.connector.(connector.Handshaker); ok {
|
||||
return hs.Handshake(ctx, conn)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (tr *Transport) Connect(ctx context.Context, conn net.Conn, network, address string) (net.Conn, error) {
|
||||
return tr.connector.Connect(ctx, conn, network, address)
|
||||
}
|
||||
|
||||
func (tr *Transport) Bind(ctx context.Context, conn net.Conn, network, address string, opts ...connector.BindOption) (net.Listener, error) {
|
||||
if binder, ok := tr.connector.(connector.Binder); ok {
|
||||
return binder.Bind(ctx, conn, network, address, opts...)
|
||||
}
|
||||
return nil, connector.ErrBindUnsupported
|
||||
}
|
||||
|
||||
func (tr *Transport) Multiplex() bool {
|
||||
if mux, ok := tr.dialer.(dialer.Multiplexer); ok {
|
||||
return mux.Multiplex()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (tr *Transport) WithRoute(r *Route) *Transport {
|
||||
tr.route = r
|
||||
return tr
|
||||
}
|
||||
|
||||
func (tr *Transport) WithAddr(addr string) *Transport {
|
||||
tr.addr = addr
|
||||
return tr
|
||||
}
|
Reference in New Issue
Block a user