gost/pkg/dialer/net.go
2022-03-01 21:48:50 +08:00

85 lines
1.7 KiB
Go

package dialer
import (
"context"
"fmt"
"net"
"time"
"github.com/go-gost/gost/pkg/logger"
)
// DefaultNetDialer is a ready-to-use NetDialer with a 30-second dial timeout,
// no local-address binding and no custom dial function.
var DefaultNetDialer = &NetDialer{Timeout: 30 * time.Second}

// NetDialer establishes connections with the standard library net package.
type NetDialer struct {
	// Interface optionally selects the local (source) address of outgoing
	// connections. It is either an IP literal or a network interface name;
	// the empty string leaves the choice to the operating system.
	Interface string

	// Timeout is handed to net.Dialer as its connection timeout.
	Timeout time.Duration

	// DialFunc, when non-nil, is called by Dial instead of dialing with the
	// net package.
	DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
}
// Dial connects to addr on the named network, optionally binding the local
// side of the connection to the address derived from d.Interface.
//
// If d.DialFunc is set, the call is delegated to it unchanged and d.Interface
// is not consulted. Otherwise only "tcp", "tcp4", "tcp6", "udp", "udp4" and
// "udp6" are accepted; any other network yields an error. For the udp
// networks an empty addr means "do not connect": an unconnected UDP socket,
// bound to the interface address if one is configured, is returned from
// net.ListenUDP.
func (d *NetDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
	// DialFunc replaces the built-in dialing entirely and never receives the
	// interface address, so resolving d.Interface first would only cost an
	// interface lookup and could fail a dial DialFunc would have handled.
	if d.DialFunc != nil {
		return d.DialFunc(ctx, network, addr)
	}

	ifAddr, err := parseInterfaceAddr(d.Interface, network)
	if err != nil {
		return nil, err
	}
	if ifAddr != nil {
		// Per-dial detail: keep it out of the Info stream and only emit it
		// when a local address is actually being bound.
		logger.Default().Debugf("interface: %s %s %v", d.Interface, network, ifAddr)
	}

	switch network {
	case "udp", "udp4", "udp6":
		if addr == "" {
			var laddr *net.UDPAddr
			if ifAddr != nil {
				// parseInterfaceAddr yields *net.UDPAddr for udp networks;
				// the comma-ok form just keeps laddr nil on anything else.
				laddr, _ = ifAddr.(*net.UDPAddr)
			}
			conn, err := net.ListenUDP(network, laddr)
			if err != nil {
				// Return a literal nil: returning the nil *net.UDPConn
				// directly would produce a non-nil net.Conn interface.
				return nil, err
			}
			return conn, nil
		}
	case "tcp", "tcp4", "tcp6":
	default:
		return nil, fmt.Errorf("dial: unsupported network %s", network)
	}

	netd := net.Dialer{
		Timeout:   d.Timeout,
		LocalAddr: ifAddr,
	}
	return netd.DialContext(ctx, network, addr)
}
// parseInterfaceAddr turns ifceName into a local address usable as the source
// address for dialing or listening on network.
//
// ifceName may be:
//   - "" — no binding is requested and (nil, nil) is returned;
//   - an IP literal such as "192.168.1.2" or "::1" — used as-is;
//   - a network interface name — the first address of that interface whose
//     IP family fits network is used (see selectInterfaceIP).
//
// The result has port 0 and is a *net.TCPAddr for "tcp", "tcp4" and "tcp6",
// a *net.UDPAddr for "udp", "udp4" and "udp6", and a *net.IPAddr otherwise.
// An error is returned when the interface cannot be found, its addresses
// cannot be listed, or none of them fits network.
func parseInterfaceAddr(ifceName, network string) (net.Addr, error) {
	if ifceName == "" {
		return nil, nil
	}

	ip := net.ParseIP(ifceName)
	if ip == nil {
		ifce, err := net.InterfaceByName(ifceName)
		if err != nil {
			return nil, err
		}
		addrs, err := ifce.Addrs()
		if err != nil {
			return nil, err
		}
		// Don't blindly take addrs[0]: asserting it to *net.IPNet panics on
		// any other address type, and its family may not match the network
		// (e.g. an IPv6 address for "tcp4"), which then fails when dialing.
		if ip = selectInterfaceIP(addrs, network); ip == nil {
			return nil, fmt.Errorf("addr not found for interface %s", ifceName)
		}
	}

	switch network {
	case "tcp", "tcp4", "tcp6":
		return &net.TCPAddr{IP: ip}, nil
	case "udp", "udp4", "udp6":
		return &net.UDPAddr{IP: ip}, nil
	default:
		return &net.IPAddr{IP: ip}, nil
	}
}

// selectInterfaceIP returns the IP of the first entry in addrs that is usable
// with network, or nil if there is none. Entries that are not *net.IPNet or
// *net.IPAddr are skipped. "tcp4", "udp4" and "ip4" accept only IPv4
// addresses, "tcp6", "udp6" and "ip6" only IPv6; any other network accepts
// either family.
func selectInterfaceIP(addrs []net.Addr, network string) net.IP {
	for _, addr := range addrs {
		var ip net.IP
		switch a := addr.(type) {
		case *net.IPNet:
			ip = a.IP
		case *net.IPAddr:
			ip = a.IP
		}
		if ip == nil {
			continue
		}

		is4 := ip.To4() != nil
		switch network {
		case "tcp4", "udp4", "ip4":
			if !is4 {
				continue
			}
		case "tcp6", "udp6", "ip6":
			if is4 {
				continue
			}
		}
		return ip
	}
	return nil
}