add router interface

This commit is contained in:
ginuerzh 2024-07-08 22:28:21 +08:00
parent 30cc928705
commit 48d070d345
20 changed files with 37 additions and 1112 deletions

View File

@ -41,7 +41,7 @@ type TLSNodeSettings struct {
type NodeOptions struct {
Network string
Transport *Transport
Transport Transporter
Bypass bypass.Bypass
Resolver resolver.Resolver
HostMapper hosts.HostMapper
@ -53,7 +53,7 @@ type NodeOptions struct {
type NodeOption func(*NodeOptions)
func TransportNodeOption(tr *Transport) NodeOption {
func TransportNodeOption(tr Transporter) NodeOption {
return func(o *NodeOptions) {
o.Transport = tr
}

View File

@ -1,44 +0,0 @@
package chain
import (
"context"
"fmt"
"net"
"github.com/go-gost/core/hosts"
"github.com/go-gost/core/logger"
"github.com/go-gost/core/resolver"
)
func Resolve(ctx context.Context, network, addr string, r resolver.Resolver, hosts hosts.HostMapper, log logger.Logger) (string, error) {
if addr == "" {
return addr, nil
}
host, port, _ := net.SplitHostPort(addr)
if host == "" {
return addr, nil
}
if hosts != nil {
if ips, _ := hosts.Lookup(ctx, network, host); len(ips) > 0 {
log.Debugf("hit host mapper: %s -> %s", host, ips)
return net.JoinHostPort(ips[0].String(), port), nil
}
}
if r != nil {
ips, err := r.Resolve(ctx, network, host)
if err != nil {
if err == resolver.ErrInvalid {
return addr, nil
}
log.Error(err)
}
if len(ips) == 0 {
return "", fmt.Errorf("resolver: domain %s does not exist", host)
}
return net.JoinHostPort(ips[0].String(), port), nil
}
return addr, nil
}

View File

@ -2,96 +2,18 @@ package chain
import (
"context"
"errors"
"fmt"
"net"
"time"
"github.com/go-gost/core/common/net/dialer"
"github.com/go-gost/core/common/net/udp"
"github.com/go-gost/core/logger"
)
var (
ErrEmptyRoute = errors.New("empty route")
)
var (
DefaultRoute Route = &route{}
)
type Route interface {
Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error)
Bind(ctx context.Context, network, address string, opts ...BindOption) (net.Listener, error)
Nodes() []*Node
}
// route is a Route without nodes.
type route struct{}
func (*route) Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error) {
var options DialOptions
for _, opt := range opts {
opt(&options)
}
netd := dialer.NetDialer{
Interface: options.Interface,
Netns: options.Netns,
Logger: options.Logger,
}
if options.SockOpts != nil {
netd.Mark = options.SockOpts.Mark
}
return netd.Dial(ctx, network, address)
}
func (*route) Bind(ctx context.Context, network, address string, opts ...BindOption) (net.Listener, error) {
var options BindOptions
for _, opt := range opts {
opt(&options)
}
switch network {
case "tcp", "tcp4", "tcp6":
addr, err := net.ResolveTCPAddr(network, address)
if err != nil {
return nil, err
}
return net.ListenTCP(network, addr)
case "udp", "udp4", "udp6":
addr, err := net.ResolveUDPAddr(network, address)
if err != nil {
return nil, err
}
conn, err := net.ListenUDP(network, addr)
if err != nil {
return nil, err
}
logger := logger.Default().WithFields(map[string]any{
"network": network,
"address": address,
})
ln := udp.NewListener(conn, &udp.ListenConfig{
Backlog: options.Backlog,
ReadQueueSize: options.UDPDataQueueSize,
ReadBufferSize: options.UDPDataBufferSize,
TTL: options.UDPConnTTL,
KeepAlive: true,
Logger: logger,
})
return ln, err
default:
err := fmt.Errorf("network %s unsupported", network)
return nil, err
}
}
func (r *route) Nodes() []*Node {
return nil
}
type DialOptions struct {
Interface string
Netns string

View File

@ -1,9 +1,7 @@
package chain
import (
"bytes"
"context"
"fmt"
"net"
"time"
@ -92,193 +90,8 @@ func LoggerRouterOption(logger logger.Logger) RouterOption {
}
}
type Router struct {
options RouterOptions
}
func NewRouter(opts ...RouterOption) *Router {
r := &Router{}
for _, opt := range opts {
if opt != nil {
opt(&r.options)
}
}
if r.options.Timeout == 0 {
r.options.Timeout = 15 * time.Second
}
if r.options.Logger == nil {
r.options.Logger = logger.Default().WithFields(map[string]any{"kind": "router"})
}
return r
}
func (r *Router) Options() *RouterOptions {
if r == nil {
return nil
}
return &r.options
}
func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
if r.options.Timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, r.options.Timeout)
defer cancel()
}
host := address
if h, _, _ := net.SplitHostPort(address); h != "" {
host = h
}
r.record(ctx, recorder.RecorderServiceRouterDialAddress, []byte(host))
conn, err = r.dial(ctx, network, address)
if err != nil {
r.record(ctx, recorder.RecorderServiceRouterDialAddressError, []byte(host))
return
}
if network == "udp" || network == "udp4" || network == "udp6" {
if _, ok := conn.(net.PacketConn); !ok {
return &packetConn{conn}, nil
}
}
return
}
func (r *Router) record(ctx context.Context, name string, data []byte) error {
if len(data) == 0 {
return nil
}
for _, rec := range r.options.Recorders {
if rec.Record == name {
err := rec.Recorder.Record(ctx, data)
if err != nil {
r.options.Logger.Errorf("record %s: %v", name, err)
}
return err
}
}
return nil
}
func (r *Router) dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
count := r.options.Retries + 1
if count <= 0 {
count = 1
}
r.options.Logger.Debugf("dial %s/%s", address, network)
for i := 0; i < count; i++ {
var ipAddr string
ipAddr, err = Resolve(ctx, "ip", address, r.options.Resolver, r.options.HostMapper, r.options.Logger)
if err != nil {
r.options.Logger.Error(err)
break
}
var route Route
if r.options.Chain != nil {
route = r.options.Chain.Route(ctx, network, ipAddr, WithHostRouteOption(address))
}
if r.options.Logger.IsLevelEnabled(logger.DebugLevel) {
buf := bytes.Buffer{}
for _, node := range routePath(route) {
fmt.Fprintf(&buf, "%s@%s > ", node.Name, node.Addr)
}
fmt.Fprintf(&buf, "%s", ipAddr)
r.options.Logger.Debugf("route(retry=%d) %s", i, buf.String())
}
if route == nil {
route = DefaultRoute
}
conn, err = route.Dial(ctx, network, ipAddr,
InterfaceDialOption(r.options.IfceName),
NetnsDialOption(r.options.Netns),
SockOptsDialOption(r.options.SockOpts),
LoggerDialOption(r.options.Logger),
)
if err == nil {
break
}
r.options.Logger.Errorf("route(retry=%d) %s", i, err)
}
return
}
func (r *Router) Bind(ctx context.Context, network, address string, opts ...BindOption) (ln net.Listener, err error) {
if r.options.Timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, r.options.Timeout)
defer cancel()
}
count := r.options.Retries + 1
if count <= 0 {
count = 1
}
r.options.Logger.Debugf("bind on %s/%s", address, network)
for i := 0; i < count; i++ {
var route Route
if r.options.Chain != nil {
route = r.options.Chain.Route(ctx, network, address)
if route == nil || len(route.Nodes()) == 0 {
err = ErrEmptyRoute
return
}
}
if r.options.Logger.IsLevelEnabled(logger.DebugLevel) {
buf := bytes.Buffer{}
for _, node := range routePath(route) {
fmt.Fprintf(&buf, "%s@%s > ", node.Name, node.Addr)
}
fmt.Fprintf(&buf, "%s", address)
r.options.Logger.Debugf("route(retry=%d) %s", i, buf.String())
}
if route == nil {
route = DefaultRoute
}
ln, err = route.Bind(ctx, network, address, opts...)
if err == nil {
break
}
r.options.Logger.Errorf("route(retry=%d) %s", i, err)
}
return
}
func routePath(route Route) (path []*Node) {
if route == nil {
return
}
for _, node := range route.Nodes() {
if tr := node.Options().Transport; tr != nil {
path = append(path, routePath(tr.Options().Route)...)
}
path = append(path, node)
}
return
}
type packetConn struct {
net.Conn
}
func (c *packetConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
n, err = c.Read(b)
addr = c.Conn.RemoteAddr()
return
}
func (c *packetConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return c.Write(b)
type Router interface {
Options() *RouterOptions
Dial(ctx context.Context, network, address string) (net.Conn, error)
Bind(ctx context.Context, network, address string, opts ...BindOption) (net.Listener, error)
}

View File

@ -4,9 +4,7 @@ import (
"context"
"net"
net_dialer "github.com/go-gost/core/common/net/dialer"
"github.com/go-gost/core/connector"
"github.com/go-gost/core/dialer"
)
type TransportOptions struct {
@ -49,97 +47,12 @@ func RouteTransportOption(route Route) TransportOption {
}
}
type Transport struct {
dialer dialer.Dialer
connector connector.Connector
options TransportOptions
}
func NewTransport(d dialer.Dialer, c connector.Connector, opts ...TransportOption) *Transport {
tr := &Transport{
dialer: d,
connector: c,
}
for _, opt := range opts {
if opt != nil {
opt(&tr.options)
}
}
return tr
}
func (tr *Transport) Dial(ctx context.Context, addr string) (net.Conn, error) {
netd := &net_dialer.NetDialer{
Interface: tr.options.IfceName,
Netns: tr.options.Netns,
}
if tr.options.SockOpts != nil {
netd.Mark = tr.options.SockOpts.Mark
}
if tr.options.Route != nil && len(tr.options.Route.Nodes()) > 0 {
netd.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) {
return tr.options.Route.Dial(ctx, network, addr)
}
}
opts := []dialer.DialOption{
dialer.HostDialOption(tr.options.Addr),
dialer.NetDialerDialOption(netd),
}
return tr.dialer.Dial(ctx, addr, opts...)
}
func (tr *Transport) Handshake(ctx context.Context, conn net.Conn) (net.Conn, error) {
var err error
if hs, ok := tr.dialer.(dialer.Handshaker); ok {
conn, err = hs.Handshake(ctx, conn,
dialer.AddrHandshakeOption(tr.options.Addr))
if err != nil {
return nil, err
}
}
if hs, ok := tr.connector.(connector.Handshaker); ok {
return hs.Handshake(ctx, conn)
}
return conn, nil
}
func (tr *Transport) Connect(ctx context.Context, conn net.Conn, network, address string) (net.Conn, error) {
netd := &net_dialer.NetDialer{
Interface: tr.options.IfceName,
Netns: tr.options.Netns,
}
if tr.options.SockOpts != nil {
netd.Mark = tr.options.SockOpts.Mark
}
return tr.connector.Connect(ctx, conn, network, address,
connector.NetDialerConnectOption(netd),
)
}
func (tr *Transport) Bind(ctx context.Context, conn net.Conn, network, address string, opts ...connector.BindOption) (net.Listener, error) {
if binder, ok := tr.connector.(connector.Binder); ok {
return binder.Bind(ctx, conn, network, address, opts...)
}
return nil, connector.ErrBindUnsupported
}
func (tr *Transport) Multiplex() bool {
if mux, ok := tr.dialer.(dialer.Multiplexer); ok {
return mux.Multiplex()
}
return false
}
func (tr *Transport) Options() *TransportOptions {
if tr != nil {
return &tr.options
}
return nil
}
func (tr *Transport) Copy() *Transport {
tr2 := &Transport{}
*tr2 = *tr
return tr
type Transporter interface {
Dial(ctx context.Context, addr string) (net.Conn, error)
Handshake(ctx context.Context, conn net.Conn) (net.Conn, error)
Connect(ctx context.Context, conn net.Conn, network, address string) (net.Conn, error)
Bind(ctx context.Context, conn net.Conn, network, address string, opts ...connector.BindOption) (net.Listener, error)
Multiplex() bool
Options() *TransportOptions
Copy() Transporter
}

View File

@ -1,84 +0,0 @@
package net
import (
"fmt"
"net"
)
func ParseInterfaceAddr(ifceName, network string) (ifce string, addr []net.Addr, err error) {
if ifceName == "" {
addr = append(addr, nil)
return
}
ip := net.ParseIP(ifceName)
if ip == nil {
var ife *net.Interface
ife, err = net.InterfaceByName(ifceName)
if err != nil {
return
}
var addrs []net.Addr
addrs, err = ife.Addrs()
if err != nil {
return
}
if len(addrs) == 0 {
err = fmt.Errorf("addr not found for interface %s", ifceName)
return
}
ifce = ifceName
for _, addr_ := range addrs {
if ipNet, ok := addr_.(*net.IPNet); ok {
addr = append(addr, ipToAddr(ipNet.IP, network))
}
}
} else {
ifce, err = findInterfaceByIP(ip)
if err != nil {
return
}
addr = []net.Addr{ipToAddr(ip, network)}
}
return
}
func ipToAddr(ip net.IP, network string) (addr net.Addr) {
port := 0
switch network {
case "tcp", "tcp4", "tcp6":
addr = &net.TCPAddr{IP: ip, Port: port}
return
case "udp", "udp4", "udp6":
addr = &net.UDPAddr{IP: ip, Port: port}
return
default:
addr = &net.IPAddr{IP: ip}
return
}
}
func findInterfaceByIP(ip net.IP) (string, error) {
ifces, err := net.Interfaces()
if err != nil {
return "", err
}
for _, ifce := range ifces {
addrs, _ := ifce.Addrs()
if len(addrs) == 0 {
continue
}
for _, addr := range addrs {
ipAddr, _ := addr.(*net.IPNet)
if ipAddr == nil {
continue
}
// logger.Default().Infof("%s-%s", ipAddr, ip)
if ipAddr.IP.Equal(ip) {
return ifce.Name, nil
}
}
}
return "", nil
}

10
common/net/dialer.go Normal file
View File

@ -0,0 +1,10 @@
package net
import (
"context"
"net"
)
type Dialer interface {
Dial(ctx context.Context, network, addr string) (net.Conn, error)
}

View File

@ -1,171 +0,0 @@
package dialer
import (
"context"
"fmt"
"net"
"runtime"
"strings"
"syscall"
"time"
xnet "github.com/go-gost/core/common/net"
"github.com/go-gost/core/logger"
"github.com/vishvananda/netns"
)
const (
DefaultTimeout = 10 * time.Second
)
var (
DefaultNetDialer = &NetDialer{}
)
type NetDialer struct {
Interface string
Netns string
Mark int
DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
Logger logger.Logger
}
func (d *NetDialer) Dial(ctx context.Context, network, addr string) (conn net.Conn, err error) {
if d == nil {
d = DefaultNetDialer
}
log := d.Logger
if log == nil {
log = logger.Default()
}
if d.Netns != "" {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
originNs, err := netns.Get()
if err != nil {
return nil, fmt.Errorf("netns.Get(): %v", err)
}
defer netns.Set(originNs)
ns, err := netns.GetFromName(d.Netns)
if err != nil {
return nil, fmt.Errorf("netns.GetFromName(%s): %v", d.Netns, err)
}
defer ns.Close()
if err := netns.Set(ns); err != nil {
return nil, fmt.Errorf("netns.Set(%s): %v", d.Netns, err)
}
}
if d.DialFunc != nil {
return d.DialFunc(ctx, network, addr)
}
switch network {
case "unix":
netd := net.Dialer{}
return netd.DialContext(ctx, network, addr)
default:
}
ifces := strings.Split(d.Interface, ",")
for _, ifce := range ifces {
strict := strings.HasSuffix(ifce, "!")
ifce = strings.TrimSuffix(ifce, "!")
var ifceName string
var ifAddrs []net.Addr
ifceName, ifAddrs, err = xnet.ParseInterfaceAddr(ifce, network)
if err != nil && strict {
return
}
for _, ifAddr := range ifAddrs {
conn, err = d.dialOnce(ctx, network, addr, ifceName, ifAddr, log)
if err == nil {
return
}
log.Debugf("dial %s %v@%s failed: %s", network, ifAddr, ifceName, err)
if strict &&
!strings.Contains(err.Error(), "no suitable address found") &&
!strings.Contains(err.Error(), "mismatched local address type") {
return
}
}
}
return
}
func (d *NetDialer) dialOnce(ctx context.Context, network, addr, ifceName string, ifAddr net.Addr, log logger.Logger) (net.Conn, error) {
if ifceName != "" {
log.Debugf("interface: %s %v/%s", ifceName, ifAddr, network)
}
switch network {
case "udp", "udp4", "udp6":
if addr == "" {
var laddr *net.UDPAddr
if ifAddr != nil {
laddr, _ = ifAddr.(*net.UDPAddr)
}
c, err := net.ListenUDP(network, laddr)
if err != nil {
return nil, err
}
sc, err := c.SyscallConn()
if err != nil {
log.Error(err)
return nil, err
}
err = sc.Control(func(fd uintptr) {
if ifceName != "" {
if err := bindDevice(fd, ifceName); err != nil {
log.Warnf("bind device: %v", err)
}
}
if d.Mark != 0 {
if err := setMark(fd, d.Mark); err != nil {
log.Warnf("set mark: %v", err)
}
}
})
if err != nil {
log.Error(err)
}
return c, nil
}
case "tcp", "tcp4", "tcp6":
default:
return nil, fmt.Errorf("dial: unsupported network %s", network)
}
netd := net.Dialer{
LocalAddr: ifAddr,
Control: func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
if ifceName != "" {
if err := bindDevice(fd, ifceName); err != nil {
log.Warnf("bind device: %v", err)
}
}
if d.Mark != 0 {
if err := setMark(fd, d.Mark); err != nil {
log.Warnf("set mark: %v", err)
}
}
})
},
}
if d.Netns != "" {
// https://github.com/golang/go/issues/44922#issuecomment-796645858
netd.FallbackDelay = -1
}
return netd.DialContext(ctx, network, addr)
}

View File

@ -1,19 +0,0 @@
package dialer
import (
"golang.org/x/sys/unix"
)
func bindDevice(fd uintptr, ifceName string) error {
if ifceName == "" {
return nil
}
return unix.BindToDevice(int(fd), ifceName)
}
func setMark(fd uintptr, mark int) error {
if mark == 0 {
return nil
}
return unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, mark)
}

View File

@ -1,11 +0,0 @@
//go:build !linux
package dialer
func bindDevice(fd uintptr, ifceName string) error {
return nil
}
func setMark(fd uintptr, mark int) error {
return nil
}

View File

@ -1,108 +0,0 @@
package udp
import (
"errors"
"net"
"sync"
"sync/atomic"
"github.com/go-gost/core/common/bufpool"
)
// conn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
type conn struct {
net.PacketConn
localAddr net.Addr
remoteAddr net.Addr
rc chan []byte // data receive queue
idle int32 // indicate the connection is idle
closed chan struct{}
closeMutex sync.Mutex
keepAlive bool
}
func newConn(c net.PacketConn, laddr, remoteAddr net.Addr, queueSize int, keepAlive bool) *conn {
return &conn{
PacketConn: c,
localAddr: laddr,
remoteAddr: remoteAddr,
rc: make(chan []byte, queueSize),
closed: make(chan struct{}),
keepAlive: keepAlive,
}
}
func (c *conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
select {
case bb := <-c.rc:
n = copy(b, bb)
c.SetIdle(false)
bufpool.Put(bb)
case <-c.closed:
err = net.ErrClosed
return
}
addr = c.remoteAddr
return
}
func (c *conn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *conn) Write(b []byte) (n int, err error) {
n, err = c.WriteTo(b, c.remoteAddr)
if !c.keepAlive {
c.Close()
}
return
}
func (c *conn) Close() error {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()
select {
case <-c.closed:
default:
close(c.closed)
}
return nil
}
func (c *conn) LocalAddr() net.Addr {
return c.localAddr
}
func (c *conn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *conn) IsIdle() bool {
return atomic.LoadInt32(&c.idle) > 0
}
func (c *conn) SetIdle(idle bool) {
v := int32(0)
if idle {
v = 1
}
atomic.StoreInt32(&c.idle, v)
}
func (c *conn) WriteQueue(b []byte) error {
select {
case c.rc <- b:
return nil
case <-c.closed:
return net.ErrClosed
default:
return errors.New("recv queue is full")
}
}

View File

@ -1,131 +0,0 @@
package udp
import (
"net"
"time"
"github.com/go-gost/core/common/bufpool"
"github.com/go-gost/core/logger"
)
type ListenConfig struct {
Addr net.Addr
Backlog int
ReadQueueSize int
ReadBufferSize int
TTL time.Duration
KeepAlive bool
Logger logger.Logger
}
type listener struct {
conn net.PacketConn
cqueue chan net.Conn
connPool *connPool
// mux sync.Mutex
closed chan struct{}
errChan chan error
config *ListenConfig
}
func NewListener(conn net.PacketConn, cfg *ListenConfig) net.Listener {
if cfg == nil {
cfg = &ListenConfig{}
}
ln := &listener{
conn: conn,
cqueue: make(chan net.Conn, cfg.Backlog),
closed: make(chan struct{}),
errChan: make(chan error, 1),
config: cfg,
}
if cfg.KeepAlive {
ln.connPool = newConnPool(cfg.TTL).WithLogger(cfg.Logger)
}
go ln.listenLoop()
return ln
}
func (ln *listener) Accept() (conn net.Conn, err error) {
select {
case conn = <-ln.cqueue:
return
case <-ln.closed:
return nil, net.ErrClosed
case err = <-ln.errChan:
if err == nil {
err = net.ErrClosed
}
return
}
}
func (ln *listener) listenLoop() {
for {
select {
case <-ln.closed:
return
default:
}
b := bufpool.Get(ln.config.ReadBufferSize)
n, raddr, err := ln.conn.ReadFrom(b)
if err != nil {
ln.errChan <- err
close(ln.errChan)
return
}
c := ln.getConn(raddr)
if c == nil {
bufpool.Put(b)
continue
}
if err := c.WriteQueue(b[:n]); err != nil {
ln.config.Logger.Warn("data discarded: ", err)
}
}
}
func (ln *listener) Addr() net.Addr {
if ln.config.Addr != nil {
return ln.config.Addr
}
return ln.conn.LocalAddr()
}
func (ln *listener) Close() error {
select {
case <-ln.closed:
default:
close(ln.closed)
ln.conn.Close()
ln.connPool.Close()
}
return nil
}
func (ln *listener) getConn(raddr net.Addr) *conn {
// ln.mux.Lock()
// defer ln.mux.Unlock()
c, ok := ln.connPool.Get(raddr.String())
if ok {
return c
}
c = newConn(ln.conn, ln.Addr(), raddr, ln.config.ReadQueueSize, ln.config.KeepAlive)
select {
case ln.cqueue <- c:
ln.connPool.Set(raddr.String(), c)
return c
default:
c.Close()
ln.config.Logger.Warnf("connection queue is full, client %s discarded", raddr)
return nil
}
}

View File

@ -1,115 +0,0 @@
package udp
import (
"sync"
"time"
"github.com/go-gost/core/logger"
)
type connPool struct {
m sync.Map
ttl time.Duration
closed chan struct{}
logger logger.Logger
}
func newConnPool(ttl time.Duration) *connPool {
p := &connPool{
ttl: ttl,
closed: make(chan struct{}),
}
go p.idleCheck()
return p
}
func (p *connPool) WithLogger(logger logger.Logger) *connPool {
p.logger = logger
return p
}
func (p *connPool) Get(key any) (c *conn, ok bool) {
if p == nil {
return
}
v, ok := p.m.Load(key)
if ok {
c, ok = v.(*conn)
}
return
}
func (p *connPool) Set(key any, c *conn) {
if p == nil {
return
}
p.m.Store(key, c)
}
func (p *connPool) Delete(key any) {
if p == nil {
return
}
p.m.Delete(key)
}
func (p *connPool) Close() {
if p == nil {
return
}
select {
case <-p.closed:
return
default:
}
close(p.closed)
p.m.Range(func(k, v any) bool {
if c, ok := v.(*conn); ok && c != nil {
c.Close()
}
return true
})
}
func (p *connPool) idleCheck() {
ticker := time.NewTicker(p.ttl)
defer ticker.Stop()
for {
select {
case <-ticker.C:
size := 0
idles := 0
p.m.Range(func(key, value any) bool {
c, ok := value.(*conn)
if !ok || c == nil {
p.Delete(key)
return true
}
size++
if c.IsIdle() {
idles++
p.Delete(key)
c.Close()
return true
}
c.SetIdle(true)
return true
})
if idles > 0 {
p.logger.Debugf("connection pool: size=%d, idle=%d", size, idles)
}
case <-p.closed:
return
}
}
}

View File

@ -1,41 +0,0 @@
package udp
import (
"io"
"net"
"syscall"
)
type Conn interface {
net.PacketConn
io.Reader
io.Writer
readUDP
writeUDP
setBuffer
syscallConn
remoteAddr
}
type setBuffer interface {
SetReadBuffer(bytes int) error
SetWriteBuffer(bytes int) error
}
type readUDP interface {
ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error)
}
type writeUDP interface {
WriteToUDP(b []byte, addr *net.UDPAddr) (int, error)
WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error)
}
type syscallConn interface {
SyscallConn() (syscall.RawConn, error)
}
type remoteAddr interface {
RemoteAddr() net.Addr
}

View File

@ -5,7 +5,7 @@ import (
"net/url"
"time"
"github.com/go-gost/core/common/net/dialer"
xnet "github.com/go-gost/core/common/net"
"github.com/go-gost/core/logger"
)
@ -36,14 +36,14 @@ func LoggerOption(logger logger.Logger) Option {
}
type ConnectOptions struct {
NetDialer *dialer.NetDialer
Dialer xnet.Dialer
}
type ConnectOption func(opts *ConnectOptions)
func NetDialerConnectOption(netd *dialer.NetDialer) ConnectOption {
func DialerConnectOption(dialer xnet.Dialer) ConnectOption {
return func(opts *ConnectOptions) {
opts.NetDialer = netd
opts.Dialer = dialer
}
}

View File

@ -4,7 +4,7 @@ import (
"crypto/tls"
"net/url"
"github.com/go-gost/core/common/net/dialer"
xnet "github.com/go-gost/core/common/net"
"github.com/go-gost/core/logger"
)
@ -42,8 +42,8 @@ func ProxyProtocolOption(ppv int) Option {
}
type DialOptions struct {
Host string
NetDialer *dialer.NetDialer
Host string
Dialer xnet.Dialer
}
type DialOption func(opts *DialOptions)
@ -54,9 +54,9 @@ func HostDialOption(host string) DialOption {
}
}
func NetDialerDialOption(netd *dialer.NetDialer) DialOption {
func NetDialerDialOption(dialer xnet.Dialer) DialOption {
return func(opts *DialOptions) {
opts.NetDialer = netd
opts.Dialer = dialer
}
}

5
go.mod
View File

@ -3,8 +3,3 @@ module github.com/go-gost/core
go 1.22
toolchain go1.22.2
require (
github.com/vishvananda/netns v0.0.4
golang.org/x/sys v0.21.0
)

4
go.sum
View File

@ -1,4 +0,0 @@
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=

View File

@ -16,7 +16,7 @@ import (
type Options struct {
Bypass bypass.Bypass
Router *chain.Router
Router chain.Router
Auth *url.Userinfo
Auther auth.Authenticator
RateLimiter rate.RateLimiter
@ -36,7 +36,7 @@ func BypassOption(bypass bypass.Bypass) Option {
}
}
func RouterOption(router *chain.Router) Option {
func RouterOption(router chain.Router) Option {
return func(opts *Options) {
opts.Router = router
}

View File

@ -27,7 +27,7 @@ type Options struct {
Service string
ProxyProtocol int
Netns string
Router *chain.Router
Router chain.Router
}
type Option func(opts *Options)
@ -104,7 +104,7 @@ func NetnsOption(netns string) Option {
}
}
func RouterOption(router *chain.Router) Option {
func RouterOption(router chain.Router) Option {
return func(opts *Options) {
opts.Router = router
}