diff --git a/handler/tun/handler.go b/handler/tun/handler.go index 77e971f..9ff2110 100644 --- a/handler/tun/handler.go +++ b/handler/tun/handler.go @@ -102,7 +102,9 @@ func (h *tunHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler. }) log.Debugf("%s >> %s", conn.RemoteAddr(), target.Addr) - h.handleClient(ctx, conn, raddr, config, log) + if err := h.handleClient(ctx, conn, raddr, config, log); err != nil { + log.Error(err) + } return nil } diff --git a/handler/tun/server.go b/handler/tun/server.go index e319bae..5f5dd07 100644 --- a/handler/tun/server.go +++ b/handler/tun/server.go @@ -173,8 +173,6 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config * return nil } - // h.updateRoute(src, addr, log) - if addr := h.findRouteFor(dst, config.Routes...); addr != nil { log.Debugf("find route: %s -> %s", dst, addr) diff --git a/internal/util/tap/config.go b/internal/util/tap/config.go index 77895ce..42fafa5 100644 --- a/internal/util/tap/config.go +++ b/internal/util/tap/config.go @@ -1,9 +1,16 @@ package tap +import "net" + +// Route is an IP routing entry +type Route struct { + Net net.IPNet + Gateway net.IP +} type Config struct { Name string Net string MTU int - Routes []string Gateway string + Routes []Route } diff --git a/listener/tap/conn.go b/listener/tap/conn.go index 7faeb31..0874d65 100644 --- a/listener/tap/conn.go +++ b/listener/tap/conn.go @@ -1,18 +1,20 @@ package tap import ( + "context" "errors" + "io" "net" "time" mdata "github.com/go-gost/core/metadata" - "github.com/songgao/water" ) type conn struct { - ifce *water.Interface - laddr net.Addr - raddr net.Addr + ifce io.ReadWriteCloser + laddr net.Addr + raddr net.Addr + cancel context.CancelFunc } func (c *conn) Read(b []byte) (n int, err error) { @@ -44,6 +46,9 @@ func (c *conn) SetWriteDeadline(t time.Time) error { } func (c *conn) Close() (err error) { + if c.cancel != nil { + c.cancel() + } return c.ifce.Close() } diff --git a/listener/tap/listener.go b/listener/tap/listener.go index 311162b..46f453b 100644 --- a/listener/tap/listener.go +++ b/listener/tap/listener.go @@ -1,7 +1,9 @@ package tap import ( + "context" "net" + "time" "github.com/go-gost/core/listener" "github.com/go-gost/core/logger" @@ -18,7 +20,6 @@ func init() { } type tapListener struct { - saddr string addr net.Addr cqueue chan net.Conn closed chan struct{} @@ -33,7 +34,6 @@ func NewListener(opts ...listener.Option) listener.Listener { opt(&options) } return &tapListener{ - saddr: options.Addr, logger: options.Logger, options: options, } @@ -48,48 +48,72 @@ func (l *tapListener) Init(md mdata.Metadata) (err error) { if xnet.IsIPv4(l.options.Addr) { network = "udp4" } - l.addr, err = net.ResolveUDPAddr(network, l.saddr) + l.addr, err = net.ResolveUDPAddr(network, l.options.Addr) if err != nil { return } - ifce, ip, err := l.createTap() - if err != nil { - if ifce != nil { - ifce.Close() - } - return - } - - itf, err := net.InterfaceByName(ifce.Name()) - if err != nil { - return - } - - addrs, _ := itf.Addrs() - l.logger.Infof("name: %s, mac: %s, mtu: %d, addrs: %s", - itf.Name, itf.HardwareAddr, itf.MTU, addrs) - l.cqueue = make(chan net.Conn, 1) l.closed = make(chan struct{}) - var c net.Conn - c = &conn{ - ifce: ifce, - laddr: l.addr, - raddr: &net.IPAddr{IP: ip}, - } - c = metrics.WrapConn(l.options.Service, c) - c = limiter.WrapConn(l.options.TrafficLimiter, c) - c = withMetadata(mdx.NewMetadata(map[string]any{ - "config": l.md.config, - }), c) - - l.cqueue <- c + go l.listenLoop() return } +func (l *tapListener) listenLoop() { + for { + ctx, cancel := context.WithCancel(context.Background()) + err := func() error { + ifce, name, ip, err := l.createTap() + if err != nil { + if ifce != nil { + ifce.Close() + } + return err + } + + itf, err := net.InterfaceByName(name) + if err != nil { + return err + } + + addrs, _ := itf.Addrs() + l.logger.Infof("name: %s, net: %s, mtu: %d, addrs: %s", + itf.Name, ip, itf.MTU, addrs) + + var c net.Conn + c = &conn{ + ifce: ifce, + laddr: l.addr, + raddr: &net.IPAddr{IP: ip}, + cancel: cancel, + } + c = metrics.WrapConn(l.options.Service, c) + c = limiter.WrapConn(l.options.TrafficLimiter, c) + c = withMetadata(mdx.NewMetadata(map[string]any{ + "config": l.md.config, + }), c) + + l.cqueue <- c + + return nil + }() + if err != nil { + l.logger.Error(err) + cancel() + } + + select { + case <-ctx.Done(): + case <-l.closed: + return + } + + time.Sleep(time.Second) + } +} + func (l *tapListener) Accept() (net.Conn, error) { select { case conn := <-l.cqueue: diff --git a/listener/tap/metadata.go b/listener/tap/metadata.go index 2730c8d..d69c6ab 100644 --- a/listener/tap/metadata.go +++ b/listener/tap/metadata.go @@ -1,6 +1,9 @@ package tap import ( + "net" + "strings" + mdata "github.com/go-gost/core/metadata" mdutil "github.com/go-gost/core/metadata/util" tap_util "github.com/go-gost/x/internal/util/tap" @@ -19,6 +22,7 @@ func (l *tapListener) parseMetadata(md mdata.Metadata) (err error) { name = "name" netKey = "net" mtu = "mtu" + route = "route" routes = "routes" gateway = "gw" ) @@ -33,9 +37,35 @@ func (l *tapListener) parseMetadata(md mdata.Metadata) (err error) { config.MTU = DefaultMTU } + gw := net.ParseIP(config.Gateway) + + for _, s := range strings.Split(mdutil.GetString(md, route), ",") { + var route tap_util.Route + _, ipNet, _ := net.ParseCIDR(strings.TrimSpace(s)) + if ipNet == nil { + continue + } + route.Net = *ipNet + route.Gateway = gw + + config.Routes = append(config.Routes, route) + } + for _, s := range mdutil.GetStrings(md, routes) { - if s != "" { - config.Routes = append(config.Routes, s) + ss := strings.SplitN(s, " ", 2) + if len(ss) == 2 { + var route tap_util.Route + _, ipNet, _ := net.ParseCIDR(strings.TrimSpace(ss[0])) + if ipNet == nil { + continue + } + route.Net = *ipNet + route.Gateway = net.ParseIP(ss[1]) + if route.Gateway == nil { + route.Gateway = gw + } + + config.Routes = append(config.Routes, route) } } diff --git a/listener/tap/tap_linux.go b/listener/tap/tap_linux.go index c10055a..094ab32 100644 --- a/listener/tap/tap_linux.go +++ b/listener/tap/tap_linux.go @@ -2,15 +2,16 @@ package tap import ( "fmt" + "io" "net" - "os/exec" - "strings" + tap_util "github.com/go-gost/x/internal/util/tap" "github.com/songgao/water" + "github.com/vishvananda/netlink" ) -func (l *tapListener) createTap() (ifce *water.Interface, ip net.IP, err error) { - ifce, err = water.New(water.Config{ +func (l *tapListener) createTap() (dev io.ReadWriteCloser, name string, ip net.IP, err error) { + tap, err := water.New(water.Config{ DeviceType: water.TAP, PlatformSpecificParams: water.PlatformSpecificParams{ Name: l.md.config.Name, @@ -20,46 +21,61 @@ func (l *tapListener) createTap() (ifce *water.Interface, ip net.IP, err error) return } - if err = l.exeCmd(fmt.Sprintf("ip link set dev %s mtu %d", ifce.Name(), l.md.config.MTU)); err != nil { - l.logger.Warn(err) + dev = tap + name = tap.Name() + + ifce, err := net.InterfaceByName(name) + if err != nil { + return + } + + link, err := netlink.LinkByName(name) + if err != nil { + return + } + + if err = netlink.LinkSetMTU(link, l.md.config.MTU); err != nil { + return } if l.md.config.Net != "" { - if err = l.exeCmd(fmt.Sprintf("ip address add %s dev %s", l.md.config.Net, ifce.Name())); err != nil { - l.logger.Warn(err) + var ipNet *net.IPNet + ip, ipNet, err = net.ParseCIDR(l.md.config.Net) + if err != nil { + return + } + + if err = netlink.AddrAdd(link, &netlink.Addr{ + IPNet: &net.IPNet{ + IP: ip, + Mask: ipNet.Mask, + }, + }); err != nil { + return } } - - if err = l.exeCmd(fmt.Sprintf("ip link set dev %s up", ifce.Name())); err != nil { - l.logger.Warn(err) + if err = netlink.LinkSetUp(link); err != nil { + return } - if err = l.addRoutes(ifce.Name(), l.md.config.Gateway, l.md.config.Routes...); err != nil { + if err = l.addRoutes(ifce, l.md.config.Routes...); err != nil { return } return } -func (l *tapListener) exeCmd(cmd string) error { - l.logger.Debug(cmd) - - args := strings.Split(cmd, " ") - if err := exec.Command(args[0], args[1:]...).Run(); err != nil { - return fmt.Errorf("%s: %v", cmd, err) - } - - return nil -} - -func (l *tapListener) addRoutes(ifName string, gw string, routes ...string) error { +func (l *tapListener) addRoutes(ifce *net.Interface, routes ...tap_util.Route) error { for _, route := range routes { - cmd := fmt.Sprintf("ip route add %s via %s dev %s", route, gw, ifName) - l.logger.Debug(cmd) - - args := strings.Split(cmd, " ") - if er := exec.Command(args[0], args[1:]...).Run(); er != nil { - l.logger.Warnf("%s: %v", cmd, er) + r := netlink.Route{ + Dst: &route.Net, + Gw: route.Gateway, + } + if r.Gw == nil { + r.LinkIndex = ifce.Index + } + if err := netlink.RouteReplace(&r); err != nil { + return fmt.Errorf("add route %v %v: %v", r.Dst, r.Gw, err) } } return nil diff --git a/listener/tap/tap_unix.go b/listener/tap/tap_unix.go index cebd27d..4025495 100644 --- a/listener/tap/tap_unix.go +++ b/listener/tap/tap_unix.go @@ -4,23 +4,28 @@ package tap import ( "fmt" + "io" "net" "os/exec" "strings" + tap_util "github.com/go-gost/x/internal/util/tap" "github.com/songgao/water" ) -func (l *tapListener) createTap() (ifce *water.Interface, ip net.IP, err error) { +func (l *tapListener) createTap() (dev io.ReadWriteCloser, name string, ip net.IP, err error) { ip, _, _ = net.ParseCIDR(l.md.config.Net) - ifce, err = water.New(water.Config{ + ifce, err := water.New(water.Config{ DeviceType: water.TAP, }) if err != nil { return } + dev = ifce + name = ifce.Name() + var cmd string if l.md.config.Net != "" { cmd = fmt.Sprintf("ifconfig %s inet %s mtu %d up", ifce.Name(), l.md.config.Net, l.md.config.MTU) @@ -35,21 +40,18 @@ func (l *tapListener) createTap() (ifce *water.Interface, ip net.IP, err error) return } - if err = l.addRoutes(ifce.Name(), l.md.config.Gateway, l.md.config.Routes...); err != nil { + if err = l.addRoutes(ifce.Name(), l.md.config.Routes...); err != nil { return } return } -func (l *tapListener) addRoutes(ifName string, gw string, routes ...string) error { +func (l *tapListener) addRoutes(ifName string, routes ...tap_util.Route) error { for _, route := range routes { - if route == "" { - continue - } - cmd := fmt.Sprintf("route add -net %s dev %s", route, ifName) - if gw != "" { - cmd += " gw " + gw + cmd := fmt.Sprintf("route add -net %s dev %s", route.Net.String(), ifName) + if route.Gateway != nil { + cmd += " gw " + route.Gateway.String() } l.logger.Debug(cmd) args := strings.Split(cmd, " ") diff --git a/listener/tap/tap_windows.go b/listener/tap/tap_windows.go index e3a8936..779e544 100644 --- a/listener/tap/tap_windows.go +++ b/listener/tap/tap_windows.go @@ -2,17 +2,19 @@ package tap import ( "fmt" + "io" "net" "os/exec" "strings" + tap_util "github.com/go-gost/x/internal/util/tap" "github.com/songgao/water" ) -func (l *tapListener) createTap() (ifce *water.Interface, ip net.IP, err error) { +func (l *tapListener) createTap() (dev io.ReadWriteCloser, name string, ip net.IP, err error) { ip, ipNet, _ := net.ParseCIDR(l.md.config.Net) - ifce, err = water.New(water.Config{ + ifce, err := water.New(water.Config{ DeviceType: water.TAP, PlatformSpecificParams: water.PlatformSpecificParams{ ComponentID: "tap0901", @@ -24,6 +26,9 @@ func (l *tapListener) createTap() (ifce *water.Interface, ip net.IP, err error) return } + dev = ifce + name = ifce.Name() + if ip != nil && ipNet != nil { cmd := fmt.Sprintf("netsh interface ip set address name=%s "+ "source=static addr=%s mask=%s gateway=none", @@ -37,21 +42,21 @@ func (l *tapListener) createTap() (ifce *water.Interface, ip net.IP, err error) } } - if err = l.addRoutes(ifce.Name(), l.md.config.Gateway, l.md.config.Routes...); err != nil { + if err = l.addRoutes(ifce.Name(), l.md.config.Routes...); err != nil { return } return } -func (l *tapListener) addRoutes(ifName string, gw string, routes ...string) error { +func (l *tapListener) addRoutes(ifName string, routes ...tap_util.Route) error { for _, route := range routes { l.deleteRoute(ifName, route) cmd := fmt.Sprintf("netsh interface ip add route prefix=%s interface=%s store=active", - route, ifName) - if gw != "" { - cmd += " nexthop=" + gw + route.Net.String(), ifName) + if route.Gateway != nil { + cmd += " nexthop=" + route.Gateway.String() } l.logger.Debug(cmd) args := strings.Split(cmd, " ") @@ -62,9 +67,9 @@ func (l *tapListener) addRoutes(ifName string, gw string, routes ...string) erro return nil } -func (l *tapListener) deleteRoute(ifName string, route string) error { +func (l *tapListener) deleteRoute(ifName string, route tap_util.Route) error { cmd := fmt.Sprintf("netsh interface ip delete route prefix=%s interface=%s store=active", - route, ifName) + route.Net.String(), ifName) l.logger.Debug(cmd) args := strings.Split(cmd, " ") return exec.Command(args[0], args[1:]...).Run()