add router interface

This commit is contained in:
ginuerzh
2024-07-08 22:28:21 +08:00
parent 30cc928705
commit 48d070d345
20 changed files with 37 additions and 1112 deletions

View File

@ -41,7 +41,7 @@ type TLSNodeSettings struct {
type NodeOptions struct {
Network string
Transport *Transport
Transport Transporter
Bypass bypass.Bypass
Resolver resolver.Resolver
HostMapper hosts.HostMapper
@ -53,7 +53,7 @@ type NodeOptions struct {
type NodeOption func(*NodeOptions)
func TransportNodeOption(tr *Transport) NodeOption {
func TransportNodeOption(tr Transporter) NodeOption {
return func(o *NodeOptions) {
o.Transport = tr
}

View File

@ -1,44 +0,0 @@
package chain
import (
"context"
"fmt"
"net"
"github.com/go-gost/core/hosts"
"github.com/go-gost/core/logger"
"github.com/go-gost/core/resolver"
)
func Resolve(ctx context.Context, network, addr string, r resolver.Resolver, hosts hosts.HostMapper, log logger.Logger) (string, error) {
if addr == "" {
return addr, nil
}
host, port, _ := net.SplitHostPort(addr)
if host == "" {
return addr, nil
}
if hosts != nil {
if ips, _ := hosts.Lookup(ctx, network, host); len(ips) > 0 {
log.Debugf("hit host mapper: %s -> %s", host, ips)
return net.JoinHostPort(ips[0].String(), port), nil
}
}
if r != nil {
ips, err := r.Resolve(ctx, network, host)
if err != nil {
if err == resolver.ErrInvalid {
return addr, nil
}
log.Error(err)
}
if len(ips) == 0 {
return "", fmt.Errorf("resolver: domain %s does not exist", host)
}
return net.JoinHostPort(ips[0].String(), port), nil
}
return addr, nil
}

View File

@ -2,96 +2,18 @@ package chain
import (
"context"
"errors"
"fmt"
"net"
"time"
"github.com/go-gost/core/common/net/dialer"
"github.com/go-gost/core/common/net/udp"
"github.com/go-gost/core/logger"
)
var (
ErrEmptyRoute = errors.New("empty route")
)
var (
DefaultRoute Route = &route{}
)
type Route interface {
Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error)
Bind(ctx context.Context, network, address string, opts ...BindOption) (net.Listener, error)
Nodes() []*Node
}
// route is a Route without nodes.
type route struct{}
func (*route) Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error) {
var options DialOptions
for _, opt := range opts {
opt(&options)
}
netd := dialer.NetDialer{
Interface: options.Interface,
Netns: options.Netns,
Logger: options.Logger,
}
if options.SockOpts != nil {
netd.Mark = options.SockOpts.Mark
}
return netd.Dial(ctx, network, address)
}
func (*route) Bind(ctx context.Context, network, address string, opts ...BindOption) (net.Listener, error) {
var options BindOptions
for _, opt := range opts {
opt(&options)
}
switch network {
case "tcp", "tcp4", "tcp6":
addr, err := net.ResolveTCPAddr(network, address)
if err != nil {
return nil, err
}
return net.ListenTCP(network, addr)
case "udp", "udp4", "udp6":
addr, err := net.ResolveUDPAddr(network, address)
if err != nil {
return nil, err
}
conn, err := net.ListenUDP(network, addr)
if err != nil {
return nil, err
}
logger := logger.Default().WithFields(map[string]any{
"network": network,
"address": address,
})
ln := udp.NewListener(conn, &udp.ListenConfig{
Backlog: options.Backlog,
ReadQueueSize: options.UDPDataQueueSize,
ReadBufferSize: options.UDPDataBufferSize,
TTL: options.UDPConnTTL,
KeepAlive: true,
Logger: logger,
})
return ln, err
default:
err := fmt.Errorf("network %s unsupported", network)
return nil, err
}
}
func (r *route) Nodes() []*Node {
return nil
}
type DialOptions struct {
Interface string
Netns string

View File

@ -1,9 +1,7 @@
package chain
import (
"bytes"
"context"
"fmt"
"net"
"time"
@ -92,193 +90,8 @@ func LoggerRouterOption(logger logger.Logger) RouterOption {
}
}
type Router struct {
options RouterOptions
}
func NewRouter(opts ...RouterOption) *Router {
r := &Router{}
for _, opt := range opts {
if opt != nil {
opt(&r.options)
}
}
if r.options.Timeout == 0 {
r.options.Timeout = 15 * time.Second
}
if r.options.Logger == nil {
r.options.Logger = logger.Default().WithFields(map[string]any{"kind": "router"})
}
return r
}
func (r *Router) Options() *RouterOptions {
if r == nil {
return nil
}
return &r.options
}
func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
if r.options.Timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, r.options.Timeout)
defer cancel()
}
host := address
if h, _, _ := net.SplitHostPort(address); h != "" {
host = h
}
r.record(ctx, recorder.RecorderServiceRouterDialAddress, []byte(host))
conn, err = r.dial(ctx, network, address)
if err != nil {
r.record(ctx, recorder.RecorderServiceRouterDialAddressError, []byte(host))
return
}
if network == "udp" || network == "udp4" || network == "udp6" {
if _, ok := conn.(net.PacketConn); !ok {
return &packetConn{conn}, nil
}
}
return
}
func (r *Router) record(ctx context.Context, name string, data []byte) error {
if len(data) == 0 {
return nil
}
for _, rec := range r.options.Recorders {
if rec.Record == name {
err := rec.Recorder.Record(ctx, data)
if err != nil {
r.options.Logger.Errorf("record %s: %v", name, err)
}
return err
}
}
return nil
}
func (r *Router) dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
count := r.options.Retries + 1
if count <= 0 {
count = 1
}
r.options.Logger.Debugf("dial %s/%s", address, network)
for i := 0; i < count; i++ {
var ipAddr string
ipAddr, err = Resolve(ctx, "ip", address, r.options.Resolver, r.options.HostMapper, r.options.Logger)
if err != nil {
r.options.Logger.Error(err)
break
}
var route Route
if r.options.Chain != nil {
route = r.options.Chain.Route(ctx, network, ipAddr, WithHostRouteOption(address))
}
if r.options.Logger.IsLevelEnabled(logger.DebugLevel) {
buf := bytes.Buffer{}
for _, node := range routePath(route) {
fmt.Fprintf(&buf, "%s@%s > ", node.Name, node.Addr)
}
fmt.Fprintf(&buf, "%s", ipAddr)
r.options.Logger.Debugf("route(retry=%d) %s", i, buf.String())
}
if route == nil {
route = DefaultRoute
}
conn, err = route.Dial(ctx, network, ipAddr,
InterfaceDialOption(r.options.IfceName),
NetnsDialOption(r.options.Netns),
SockOptsDialOption(r.options.SockOpts),
LoggerDialOption(r.options.Logger),
)
if err == nil {
break
}
r.options.Logger.Errorf("route(retry=%d) %s", i, err)
}
return
}
func (r *Router) Bind(ctx context.Context, network, address string, opts ...BindOption) (ln net.Listener, err error) {
if r.options.Timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, r.options.Timeout)
defer cancel()
}
count := r.options.Retries + 1
if count <= 0 {
count = 1
}
r.options.Logger.Debugf("bind on %s/%s", address, network)
for i := 0; i < count; i++ {
var route Route
if r.options.Chain != nil {
route = r.options.Chain.Route(ctx, network, address)
if route == nil || len(route.Nodes()) == 0 {
err = ErrEmptyRoute
return
}
}
if r.options.Logger.IsLevelEnabled(logger.DebugLevel) {
buf := bytes.Buffer{}
for _, node := range routePath(route) {
fmt.Fprintf(&buf, "%s@%s > ", node.Name, node.Addr)
}
fmt.Fprintf(&buf, "%s", address)
r.options.Logger.Debugf("route(retry=%d) %s", i, buf.String())
}
if route == nil {
route = DefaultRoute
}
ln, err = route.Bind(ctx, network, address, opts...)
if err == nil {
break
}
r.options.Logger.Errorf("route(retry=%d) %s", i, err)
}
return
}
func routePath(route Route) (path []*Node) {
if route == nil {
return
}
for _, node := range route.Nodes() {
if tr := node.Options().Transport; tr != nil {
path = append(path, routePath(tr.Options().Route)...)
}
path = append(path, node)
}
return
}
type packetConn struct {
net.Conn
}
func (c *packetConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
n, err = c.Read(b)
addr = c.Conn.RemoteAddr()
return
}
func (c *packetConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return c.Write(b)
type Router interface {
Options() *RouterOptions
Dial(ctx context.Context, network, address string) (net.Conn, error)
Bind(ctx context.Context, network, address string, opts ...BindOption) (net.Listener, error)
}

View File

@ -4,9 +4,7 @@ import (
"context"
"net"
net_dialer "github.com/go-gost/core/common/net/dialer"
"github.com/go-gost/core/connector"
"github.com/go-gost/core/dialer"
)
type TransportOptions struct {
@ -49,97 +47,12 @@ func RouteTransportOption(route Route) TransportOption {
}
}
type Transport struct {
dialer dialer.Dialer
connector connector.Connector
options TransportOptions
}
func NewTransport(d dialer.Dialer, c connector.Connector, opts ...TransportOption) *Transport {
tr := &Transport{
dialer: d,
connector: c,
}
for _, opt := range opts {
if opt != nil {
opt(&tr.options)
}
}
return tr
}
func (tr *Transport) Dial(ctx context.Context, addr string) (net.Conn, error) {
netd := &net_dialer.NetDialer{
Interface: tr.options.IfceName,
Netns: tr.options.Netns,
}
if tr.options.SockOpts != nil {
netd.Mark = tr.options.SockOpts.Mark
}
if tr.options.Route != nil && len(tr.options.Route.Nodes()) > 0 {
netd.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) {
return tr.options.Route.Dial(ctx, network, addr)
}
}
opts := []dialer.DialOption{
dialer.HostDialOption(tr.options.Addr),
dialer.NetDialerDialOption(netd),
}
return tr.dialer.Dial(ctx, addr, opts...)
}
func (tr *Transport) Handshake(ctx context.Context, conn net.Conn) (net.Conn, error) {
var err error
if hs, ok := tr.dialer.(dialer.Handshaker); ok {
conn, err = hs.Handshake(ctx, conn,
dialer.AddrHandshakeOption(tr.options.Addr))
if err != nil {
return nil, err
}
}
if hs, ok := tr.connector.(connector.Handshaker); ok {
return hs.Handshake(ctx, conn)
}
return conn, nil
}
func (tr *Transport) Connect(ctx context.Context, conn net.Conn, network, address string) (net.Conn, error) {
netd := &net_dialer.NetDialer{
Interface: tr.options.IfceName,
Netns: tr.options.Netns,
}
if tr.options.SockOpts != nil {
netd.Mark = tr.options.SockOpts.Mark
}
return tr.connector.Connect(ctx, conn, network, address,
connector.NetDialerConnectOption(netd),
)
}
func (tr *Transport) Bind(ctx context.Context, conn net.Conn, network, address string, opts ...connector.BindOption) (net.Listener, error) {
if binder, ok := tr.connector.(connector.Binder); ok {
return binder.Bind(ctx, conn, network, address, opts...)
}
return nil, connector.ErrBindUnsupported
}
func (tr *Transport) Multiplex() bool {
if mux, ok := tr.dialer.(dialer.Multiplexer); ok {
return mux.Multiplex()
}
return false
}
func (tr *Transport) Options() *TransportOptions {
if tr != nil {
return &tr.options
}
return nil
}
func (tr *Transport) Copy() *Transport {
tr2 := &Transport{}
*tr2 = *tr
return tr
type Transporter interface {
Dial(ctx context.Context, addr string) (net.Conn, error)
Handshake(ctx context.Context, conn net.Conn) (net.Conn, error)
Connect(ctx context.Context, conn net.Conn, network, address string) (net.Conn, error)
Bind(ctx context.Context, conn net.Conn, network, address string, opts ...connector.BindOption) (net.Listener, error)
Multiplex() bool
Options() *TransportOptions
Copy() Transporter
}