support multiple network interfaces

This commit is contained in:
ginuerzh
2022-03-01 21:48:50 +08:00
parent 07132d8de7
commit ffdf11e89e
44 changed files with 431 additions and 474 deletions

View File

@ -75,7 +75,11 @@ func (d *grpcDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialO
grpcOpts := []grpc.DialOption{
// grpc.WithBlock(),
grpc.WithContextDialer(func(c context.Context, s string) (net.Conn, error) {
return d.dial(ctx, "tcp", s, &options)
netd := options.NetDialer
if netd == nil {
netd = dialer.DefaultNetDialer
}
return netd.Dial(c, "tcp", s)
}),
grpc.WithAuthority(host),
grpc.WithConnectParams(grpc.ConnectParams{
@ -111,31 +115,3 @@ func (d *grpcDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialO
closed: make(chan struct{}),
}, nil
}
func (d *grpcDialer) dial(ctx context.Context, network, addr string, opts *dialer.DialOptions) (net.Conn, error) {
dial := opts.DialFunc
if dial != nil {
conn, err := dial(ctx, addr)
if err != nil {
d.options.Logger.Error(err)
} else {
d.options.Logger.WithFields(map[string]any{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debug("dial with dial func")
}
return conn, err
}
var netd net.Dialer
conn, err := netd.DialContext(ctx, network, addr)
if err != nil {
d.options.Logger.Error(err)
} else {
d.options.Logger.WithFields(map[string]any{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debugf("dial direct %s/%s", addr, network)
}
return conn, err
}

View File

@ -73,7 +73,11 @@ func (d *http2Dialer) Dial(ctx context.Context, address string, opts ...dialer.D
Transport: &http.Transport{
TLSClientConfig: d.options.TLSConfig,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return d.dial(ctx, network, addr, &options)
netd := options.NetDialer
if netd == nil {
netd = dialer.DefaultNetDialer
}
return netd.Dial(ctx, network, addr)
},
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
@ -94,31 +98,3 @@ func (d *http2Dialer) Dial(ctx context.Context, address string, opts ...dialer.D
delete(d.clients, address)
}), nil
}
func (d *http2Dialer) dial(ctx context.Context, network, addr string, opts *dialer.DialOptions) (net.Conn, error) {
dial := opts.DialFunc
if dial != nil {
conn, err := dial(ctx, addr)
if err != nil {
d.logger.Error(err)
} else {
d.logger.WithFields(map[string]any{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debug("dial with dial func")
}
return conn, err
}
var netd net.Dialer
conn, err := netd.DialContext(ctx, network, addr)
if err != nil {
d.logger.Error(err)
} else {
d.logger.WithFields(map[string]any{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debugf("dial direct %s/%s", addr, network)
}
return conn, err
}

View File

@ -93,14 +93,22 @@ func (d *h2Dialer) Dial(ctx context.Context, address string, opts ...dialer.Dial
if d.h2c {
client.Transport = &http2.Transport{
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
return d.dial(ctx, network, addr, options)
netd := options.NetDialer
if netd == nil {
netd = dialer.DefaultNetDialer
}
return netd.Dial(ctx, network, addr)
},
}
} else {
client.Transport = &http.Transport{
TLSClientConfig: d.options.TLSConfig,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return d.dial(ctx, network, addr, options)
netd := options.NetDialer
if netd == nil {
netd = dialer.DefaultNetDialer
}
return netd.Dial(ctx, network, addr)
},
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
@ -163,31 +171,3 @@ func (d *h2Dialer) Dial(ctx context.Context, address string, opts ...dialer.Dial
}
return conn, nil
}
func (d *h2Dialer) dial(ctx context.Context, network, addr string, opts *dialer.DialOptions) (net.Conn, error) {
dial := opts.DialFunc
if dial != nil {
conn, err := dial(ctx, addr)
if err != nil {
d.logger.Error(err)
} else {
d.logger.WithFields(map[string]any{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debug("dial with dial func")
}
return conn, err
}
var netd net.Dialer
conn, err := netd.DialContext(ctx, network, addr)
if err != nil {
d.logger.Error(err)
} else {
d.logger.WithFields(map[string]any{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debugf("dial direct %s/%s", addr, network)
}
return conn, err
}

View File

@ -56,11 +56,6 @@ func (d *kcpDialer) Multiplex() bool {
}
func (d *kcpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (conn net.Conn, err error) {
var options dialer.DialOptions
for _, opt := range opts {
opt(&options)
}
d.sessionMutex.Lock()
defer d.sessionMutex.Unlock()
@ -70,12 +65,17 @@ func (d *kcpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp
ok = false
}
if !ok {
raddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
var options dialer.DialOptions
for _, opt := range opts {
opt(&options)
}
if d.md.config.TCP {
raddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
pc, err := tcpraw.Dial("tcp", addr)
if err != nil {
return nil, err
@ -85,7 +85,11 @@ func (d *kcpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp
PacketConn: pc,
}
} else {
conn, err = net.ListenUDP("udp", nil)
netd := options.NetDialer
if netd == nil {
netd = dialer.DefaultNetDialer
}
conn, err = netd.Dial(ctx, "udp", addr)
if err != nil {
return nil, err
}

84
pkg/dialer/net.go Normal file
View File

@ -0,0 +1,84 @@
package dialer
import (
"context"
"fmt"
"net"
"time"
"github.com/go-gost/gost/pkg/logger"
)
var (
DefaultNetDialer = &NetDialer{
Timeout: 30 * time.Second,
}
)
type NetDialer struct {
Interface string
Timeout time.Duration
DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
}
func (d *NetDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
ifAddr, err := parseInterfaceAddr(d.Interface, network)
if err != nil {
return nil, err
}
if d.DialFunc != nil {
return d.DialFunc(ctx, network, addr)
}
logger.Default().Infof("interface: %s %s %v", d.Interface, network, ifAddr)
switch network {
case "udp", "udp4", "udp6":
if addr == "" {
var laddr *net.UDPAddr
if ifAddr != nil {
laddr, _ = ifAddr.(*net.UDPAddr)
}
return net.ListenUDP(network, laddr)
}
case "tcp", "tcp4", "tcp6":
default:
return nil, fmt.Errorf("dial: unsupported network %s", network)
}
netd := net.Dialer{
Timeout: d.Timeout,
LocalAddr: ifAddr,
}
return netd.DialContext(ctx, network, addr)
}
func parseInterfaceAddr(ifceName, network string) (net.Addr, error) {
if ifceName == "" {
return nil, nil
}
ip := net.ParseIP(ifceName)
if ip == nil {
ifce, err := net.InterfaceByName(ifceName)
if err != nil {
return nil, err
}
addrs, err := ifce.Addrs()
if err != nil {
return nil, err
}
if len(addrs) == 0 {
return nil, fmt.Errorf("addr not found for interface %s", ifceName)
}
ip = addrs[0].(*net.IPNet).IP
}
switch network {
case "tcp", "tcp4", "tcp6":
return &net.TCPAddr{IP: ip}, nil
case "udp", "udp4", "udp6":
return &net.UDPAddr{IP: ip}, nil
default:
return &net.IPAddr{IP: ip}, nil
}
}

View File

@ -35,8 +35,16 @@ func (d *obfsHTTPDialer) Init(md md.Metadata) (err error) {
}
func (d *obfsHTTPDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) {
var netd net.Dialer
conn, err := netd.DialContext(ctx, "tcp", addr)
options := &dialer.DialOptions{}
for _, opt := range opts {
opt(options)
}
netd := options.NetDialer
if netd == nil {
netd = dialer.DefaultNetDialer
}
conn, err := netd.Dial(ctx, "tcp", addr)
if err != nil {
d.logger.Error(err)
}

View File

@ -35,8 +35,16 @@ func (d *obfsTLSDialer) Init(md md.Metadata) (err error) {
}
func (d *obfsTLSDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) {
var netd net.Dialer
conn, err := netd.DialContext(ctx, "tcp", addr)
options := &dialer.DialOptions{}
for _, opt := range opts {
opt(options)
}
netd := options.NetDialer
if netd == nil {
netd = dialer.DefaultNetDialer
}
conn, err := netd.Dial(ctx, "tcp", addr)
if err != nil {
d.logger.Error(err)
}

View File

@ -1,9 +1,7 @@
package dialer
import (
"context"
"crypto/tls"
"net"
"net/url"
"github.com/go-gost/gost/pkg/logger"
@ -36,8 +34,8 @@ func LoggerOption(logger logger.Logger) Option {
}
type DialOptions struct {
Host string
DialFunc func(ctx context.Context, addr string) (net.Conn, error)
Host string
NetDialer *NetDialer
}
type DialOption func(opts *DialOptions)
@ -48,9 +46,9 @@ func HostDialOption(host string) DialOption {
}
}
func DialFuncDialOption(dialf func(ctx context.Context, addr string) (net.Conn, error)) DialOption {
func NetDialerDialOption(netd *NetDialer) DialOption {
return func(opts *DialOptions) {
opts.DialFunc = dialf
opts.NetDialer = netd
}
}

View File

@ -54,25 +54,27 @@ func (d *quicDialer) Multiplex() bool {
}
func (d *quicDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (conn net.Conn, err error) {
var options dialer.DialOptions
for _, opt := range opts {
opt(&options)
}
d.sessionMutex.Lock()
defer d.sessionMutex.Unlock()
session, ok := d.sessions[addr]
if !ok {
var cc *net.UDPConn
cc, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return
options := &dialer.DialOptions{}
for _, opt := range opts {
opt(options)
}
netd := options.NetDialer
if netd == nil {
netd = dialer.DefaultNetDialer
}
conn, err = netd.Dial(ctx, "udp", "")
if err != nil {
return nil, err
}
conn = cc
if d.md.cipherKey != nil {
conn = quic_util.CipherConn(cc, d.md.cipherKey)
conn = quic_util.CipherConn(conn.(*net.UDPConn), d.md.cipherKey)
}
session = &quicSession{conn: conn}

View File

@ -52,11 +52,6 @@ func (d *sshDialer) Multiplex() bool {
}
func (d *sshDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (conn net.Conn, err error) {
var options dialer.DialOptions
for _, opt := range opts {
opt(&options)
}
d.sessionMutex.Lock()
defer d.sessionMutex.Unlock()
@ -66,7 +61,16 @@ func (d *sshDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp
ok = false
}
if !ok {
conn, err = d.dial(ctx, "tcp", addr, &options)
var options dialer.DialOptions
for _, opt := range opts {
opt(&options)
}
netd := options.NetDialer
if netd == nil {
netd = dialer.DefaultNetDialer
}
conn, err = netd.Dial(ctx, "tcp", addr)
if err != nil {
return
}
@ -134,34 +138,6 @@ func (d *sshDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia
return ssh_util.NewConn(conn, channel), nil
}
func (d *sshDialer) dial(ctx context.Context, network, addr string, opts *dialer.DialOptions) (net.Conn, error) {
dial := opts.DialFunc
if dial != nil {
conn, err := dial(ctx, addr)
if err != nil {
d.logger.Error(err)
} else {
d.logger.WithFields(map[string]any{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debug("dial with dial func")
}
return conn, err
}
var netd net.Dialer
conn, err := netd.DialContext(ctx, network, addr)
if err != nil {
d.logger.Error(err)
} else {
d.logger.WithFields(map[string]any{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debugf("dial direct %s/%s", addr, network)
}
return conn, err
}
func (d *sshDialer) initSession(ctx context.Context, addr string, conn net.Conn) (*sshSession, error) {
config := ssh.ClientConfig{
Timeout: 30 * time.Second,

View File

@ -51,11 +51,6 @@ func (d *sshdDialer) Multiplex() bool {
}
func (d *sshdDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (conn net.Conn, err error) {
var options dialer.DialOptions
for _, opt := range opts {
opt(&options)
}
d.sessionMutex.Lock()
defer d.sessionMutex.Unlock()
@ -65,7 +60,16 @@ func (d *sshdDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialO
ok = false
}
if !ok {
conn, err = d.dial(ctx, "tcp", addr, &options)
var options dialer.DialOptions
for _, opt := range opts {
opt(&options)
}
netd := options.NetDialer
if netd == nil {
netd = dialer.DefaultNetDialer
}
conn, err = netd.Dial(ctx, "tcp", addr)
if err != nil {
return
}
@ -129,36 +133,6 @@ func (d *sshdDialer) Handshake(ctx context.Context, conn net.Conn, options ...di
return ssh_util.NewClientConn(session.conn, session.client), nil
}
func (d *sshdDialer) dial(ctx context.Context, network, addr string, opts *dialer.DialOptions) (net.Conn, error) {
log := d.options.Logger
dial := opts.DialFunc
if dial != nil {
conn, err := dial(ctx, addr)
if err != nil {
log.Error(err)
} else {
log.WithFields(map[string]any{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debug("dial with dial func")
}
return conn, err
}
var netd net.Dialer
conn, err := netd.DialContext(ctx, network, addr)
if err != nil {
log.Error(err)
} else {
log.WithFields(map[string]any{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debugf("dial direct %s/%s", addr, network)
}
return conn, err
}
func (d *sshdDialer) initSession(ctx context.Context, addr string, conn net.Conn) (*sshSession, error) {
config := ssh.ClientConfig{
// Timeout: timeout,

View File

@ -40,8 +40,11 @@ func (d *tcpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp
opt(&options)
}
var netd net.Dialer
conn, err := netd.DialContext(ctx, "tcp", addr)
netd := options.NetDialer
if netd == nil {
netd = dialer.DefaultNetDialer
}
conn, err := netd.Dial(ctx, "tcp", addr)
if err != nil {
d.logger.Error(err)
}

View File

@ -44,8 +44,11 @@ func (d *tlsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp
opt(&options)
}
var netd net.Dialer
conn, err := netd.DialContext(ctx, "tcp", addr)
netd := options.NetDialer
if netd == nil {
netd = dialer.DefaultNetDialer
}
conn, err := netd.Dial(ctx, "tcp", addr)
if err != nil {
d.logger.Error(err)
}

View File

@ -54,11 +54,6 @@ func (d *mtlsDialer) Multiplex() bool {
}
func (d *mtlsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (conn net.Conn, err error) {
var options dialer.DialOptions
for _, opt := range opts {
opt(&options)
}
d.sessionMutex.Lock()
defer d.sessionMutex.Unlock()
@ -68,7 +63,16 @@ func (d *mtlsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialO
ok = false
}
if !ok {
conn, err = d.dial(ctx, "tcp", addr, &options)
var options dialer.DialOptions
for _, opt := range opts {
opt(&options)
}
netd := options.NetDialer
if netd == nil {
netd = dialer.DefaultNetDialer
}
conn, err = netd.Dial(ctx, "tcp", addr)
if err != nil {
return
}
@ -122,34 +126,6 @@ func (d *mtlsDialer) Handshake(ctx context.Context, conn net.Conn, options ...di
return cc, nil
}
func (d *mtlsDialer) dial(ctx context.Context, network, addr string, opts *dialer.DialOptions) (net.Conn, error) {
dial := opts.DialFunc
if dial != nil {
conn, err := dial(ctx, addr)
if err != nil {
d.logger.Error(err)
} else {
d.logger.WithFields(map[string]any{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debug("dial with dial func")
}
return conn, err
}
var netd net.Dialer
conn, err := netd.DialContext(ctx, network, addr)
if err != nil {
d.logger.Error(err)
} else {
d.logger.WithFields(map[string]any{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debugf("dial direct %s/%s", addr, network)
}
return conn, err
}
func (d *mtlsDialer) initSession(ctx context.Context, conn net.Conn) (*muxSession, error) {
tlsConn := tls.Client(conn, d.options.TLSConfig)
if err := tlsConn.HandshakeContext(ctx); err != nil {

View File

@ -35,16 +35,20 @@ func (d *udpDialer) Init(md md.Metadata) (err error) {
}
func (d *udpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) {
taddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
var options dialer.DialOptions
for _, opt := range opts {
opt(&options)
}
c, err := net.DialUDP("udp", nil, taddr)
netd := options.NetDialer
if netd == nil {
netd = dialer.DefaultNetDialer
}
c, err := netd.Dial(ctx, "udp", addr)
if err != nil {
return nil, err
}
return &conn{
UDPConn: c,
UDPConn: c.(*net.UDPConn),
}, nil
}

View File

@ -61,8 +61,11 @@ func (d *wsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOpt
opt(&options)
}
var netd net.Dialer
conn, err := netd.DialContext(ctx, "tcp", addr)
netd := options.NetDialer
if netd == nil {
netd = dialer.DefaultNetDialer
}
conn, err := netd.Dial(ctx, "tcp", addr)
if err != nil {
d.logger.Error(err)
}

View File

@ -71,11 +71,6 @@ func (d *mwsDialer) Multiplex() bool {
}
func (d *mwsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (conn net.Conn, err error) {
var options dialer.DialOptions
for _, opt := range opts {
opt(&options)
}
d.sessionMutex.Lock()
defer d.sessionMutex.Unlock()
@ -85,7 +80,16 @@ func (d *mwsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp
ok = false
}
if !ok {
conn, err = d.dial(ctx, "tcp", addr, &options)
var options dialer.DialOptions
for _, opt := range opts {
opt(&options)
}
netd := options.NetDialer
if netd == nil {
netd = dialer.DefaultNetDialer
}
conn, err = netd.Dial(ctx, "tcp", addr)
if err != nil {
return
}
@ -143,34 +147,6 @@ func (d *mwsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia
return cc, nil
}
func (d *mwsDialer) dial(ctx context.Context, network, addr string, opts *dialer.DialOptions) (net.Conn, error) {
dial := opts.DialFunc
if dial != nil {
conn, err := dial(ctx, addr)
if err != nil {
d.logger.Error(err)
} else {
d.logger.WithFields(map[string]any{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debug("dial with dial func")
}
return conn, err
}
var netd net.Dialer
conn, err := netd.DialContext(ctx, network, addr)
if err != nil {
d.logger.Error(err)
} else {
d.logger.WithFields(map[string]any{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debugf("dial direct %s/%s", addr, network)
}
return conn, err
}
func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn) (*muxSession, error) {
dialer := websocket.Dialer{
HandshakeTimeout: d.md.handshakeTimeout,