diff --git a/handler/tun/client.go b/handler/tun/client.go index eb29bec..3eab336 100644 --- a/handler/tun/client.go +++ b/handler/tun/client.go @@ -16,8 +16,8 @@ import ( ) const ( - // 4-byte magic header followed by 16-byte IP address followed by 16-byte key. - keepAliveDataLength = 36 + // 4-byte magic header followed by 16-byte key. + keepAliveHeaderLength = 20 ) var ( @@ -25,9 +25,12 @@ var ( ) func (h *tunHandler) handleClient(ctx context.Context, conn net.Conn, raddr string, config *tun_util.Config, log logger.Logger) error { - ip, _, err := net.ParseCIDR(config.Net) - if err != nil { - return err + var ips []net.IP + for _, net := range config.Net { + ips = append(ips, net.IP) + } + if len(ips) == 0 { + return ErrInvalidNet } for { @@ -41,9 +44,9 @@ func (h *tunHandler) handleClient(ctx context.Context, conn net.Conn, raddr stri ctx, cancel := context.WithCancel(ctx) defer cancel() - go h.keepAlive(ctx, cc, ip) + go h.keepAlive(ctx, cc, ips) - return h.transportClient(conn, cc, config, log) + return h.transportClient(conn, cc, log) }() if err == ErrTun { return err @@ -54,13 +57,19 @@ func (h *tunHandler) handleClient(ctx context.Context, conn net.Conn, raddr stri } } -func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ip net.IP) { +func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ips []net.IP) { // handshake - var keepAliveData [keepAliveDataLength]byte - copy(keepAliveData[:4], magicHeader) // magic header - copy(keepAliveData[4:20], ip.To16()) - copy(keepAliveData[20:36], []byte(h.md.passphrase)) - if _, err := conn.Write(keepAliveData[:]); err != nil { + keepAliveData := bufpool.Get(keepAliveHeaderLength + len(ips)*net.IPv6len) + defer bufpool.Put(keepAliveData) + + copy((*keepAliveData)[:4], magicHeader) // magic header + copy((*keepAliveData)[4:20], []byte(h.md.passphrase)) + pos := 20 + for _, ip := range ips { + copy((*keepAliveData)[pos:pos+net.IPv6len], ip.To16()) + pos += net.IPv6len + } + if _, err := conn.Write((*keepAliveData)); err != nil { return } @@ -75,7 +84,7 @@ func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ip net.IP) { for { select { case <-ticker.C: - if _, err := conn.Write(keepAliveData[:]); err != nil { + if _, err := conn.Write((*keepAliveData)); err != nil { return } h.options.Logger.Debugf("keepalive sended") @@ -85,7 +94,7 @@ func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ip net.IP) { } } -func (h *tunHandler) transportClient(tun io.ReadWriter, conn net.Conn, config *tun_util.Config, log logger.Logger) error { +func (h *tunHandler) transportClient(tun io.ReadWriter, conn net.Conn, log logger.Logger) error { errc := make(chan error, 1) go func() { @@ -147,7 +156,7 @@ func (h *tunHandler) transportClient(tun io.ReadWriter, conn net.Conn, config *t return err } - if n == keepAliveDataLength && bytes.Equal((*b)[:4], magicHeader) { + if n == keepAliveHeaderLength && bytes.Equal((*b)[:4], magicHeader) { ip := net.IP((*b)[4:20]) log.Debugf("keepalive received at %v", ip) diff --git a/handler/tun/handler.go b/handler/tun/handler.go index abfc333..cc67b09 100644 --- a/handler/tun/handler.go +++ b/handler/tun/handler.go @@ -17,7 +17,8 @@ import ( ) var ( - ErrTun = errors.New("tun device error") + ErrTun = errors.New("tun device error") + ErrInvalidNet = errors.New("invalid net IP") ) func init() { diff --git a/handler/tun/server.go b/handler/tun/server.go index a7bca79..e158064 100644 --- a/handler/tun/server.go +++ b/handler/tun/server.go @@ -37,11 +37,6 @@ func (h *tunHandler) handleServer(ctx context.Context, conn net.Conn, config *tu } func (h *tunHandler) transportServer(tun io.ReadWriter, conn net.PacketConn, config *tun_util.Config, log logger.Logger) error { - tunIP, _, err := net.ParseCIDR(config.Net) - if err != nil { - return err - } - errc := make(chan error, 1) go func() { @@ -115,33 +110,49 @@ func (h *tunHandler) transportServer(tun io.ReadWriter, conn net.PacketConn, con if err != nil { return err } - if n == keepAliveDataLength && bytes.Equal((*b)[:4], magicHeader) { - peerIP := net.IP((*b)[4:20]) - key := bytes.TrimRight((*b)[20:36], "\x00") - - if peerIP.Equal(tunIP.To16()) { + if n > keepAliveHeaderLength && bytes.Equal((*b)[:4], magicHeader) { + var peerIPs []net.IP + data := (*b)[keepAliveHeaderLength:n] + if len(data)%net.IPv6len == 0 { + for len(data) > 0 { + peerIPs = append(peerIPs, net.IP(data[:net.IPv6len])) + data = data[net.IPv6len:] + } + } + if len(peerIPs) == 0 { return nil } - if auther := h.options.Auther; auther != nil { - ip := peerIP - if v := peerIP.To4(); ip != nil { - ip = v + for _, net := range config.Net { + for _, ip := range peerIPs { + if ip.Equal(net.IP.To16()) { + return nil + } } - if !auther.Authenticate(ip.String(), string(key)) { - log.Debugf("keepalive from %v => %v, auth FAILED", addr, peerIP) + } + + if auther := h.options.Auther; auther != nil { + ok := true + key := bytes.TrimRight((*b)[4:20], "\x00") + for _, ip := range peerIPs { + if ok = auther.Authenticate(ip.String(), string(key)); !ok { + break + } + } + if !ok { + log.Debugf("keepalive from %v => %v, auth FAILED", addr, peerIPs) return nil } } - log.Debugf("keepalive from %v => %v", addr, peerIP) + log.Debugf("keepalive from %v => %v", addr, peerIPs) addrPort, err := netip.ParseAddrPort(addr.String()) if err != nil { log.Warnf("keepalive from %v: %v", addr, err) return nil } - var keepAliveData [keepAliveDataLength]byte + var keepAliveData [keepAliveHeaderLength]byte copy(keepAliveData[:4], magicHeader) // magic header a16 := addrPort.Addr().As16() copy(keepAliveData[4:], a16[:]) @@ -151,7 +162,9 @@ func (h *tunHandler) transportServer(tun io.ReadWriter, conn net.PacketConn, con return nil } - h.updateRoute(peerIP, addr, log) + for _, ip := range peerIPs { + h.updateRoute(ip, addr, log) + } return nil } @@ -204,7 +217,7 @@ func (h *tunHandler) transportServer(tun io.ReadWriter, conn net.PacketConn, con } }() - err = <-errc + err := <-errc if err != nil && err == io.EOF { err = nil } diff --git a/internal/util/tun/config.go b/internal/util/tun/config.go index 6448cae..899deb5 100644 --- a/internal/util/tun/config.go +++ b/internal/util/tun/config.go @@ -10,10 +10,10 @@ type Route struct { type Config struct { Name string - Net string + Net []net.IPNet // peer addr of point-to-point on MacOS Peer string MTU int - Gateway string + Gateway net.IP Routes []Route } diff --git a/listener/tun/metadata.go b/listener/tun/metadata.go index 75172e0..ec3d45d 100644 --- a/listener/tun/metadata.go +++ b/listener/tun/metadata.go @@ -29,17 +29,30 @@ func (l *tunListener) parseMetadata(md mdata.Metadata) (err error) { ) config := &tun_util.Config{ - Name: mdutil.GetString(md, name), - Net: mdutil.GetString(md, netKey), - Peer: mdutil.GetString(md, peer), - MTU: mdutil.GetInt(md, mtu), - Gateway: mdutil.GetString(md, gateway), + Name: mdutil.GetString(md, name), + Peer: mdutil.GetString(md, peer), + MTU: mdutil.GetInt(md, mtu), } if config.MTU <= 0 { config.MTU = DefaultMTU } + if gw := mdutil.GetString(md, gateway); gw != "" { + config.Gateway = net.ParseIP(gw) + } - gw := net.ParseIP(config.Gateway) + for _, s := range strings.Split(mdutil.GetString(md, netKey), ",") { + if s = strings.TrimSpace(s); s == "" { + continue + } + ip, ipNet, err := net.ParseCIDR(s) + if err != nil { + continue + } + config.Net = append(config.Net, net.IPNet{ + IP: ip, + Mask: ipNet.Mask, + }) + } for _, s := range strings.Split(mdutil.GetString(md, route), ",") { var route tun_util.Route @@ -48,7 +61,7 @@ func (l *tunListener) parseMetadata(md mdata.Metadata) (err error) { continue } route.Net = *ipNet - route.Gateway = gw + route.Gateway = config.Gateway config.Routes = append(config.Routes, route) } @@ -64,7 +77,7 @@ func (l *tunListener) parseMetadata(md mdata.Metadata) (err error) { route.Net = *ipNet route.Gateway = net.ParseIP(ss[1]) if route.Gateway == nil { - route.Gateway = gw + route.Gateway = config.Gateway } config.Routes = append(config.Routes, route) diff --git a/listener/tun/tun_darwin.go b/listener/tun/tun_darwin.go index 8c22d9e..9cb4eba 100644 --- a/listener/tun/tun_darwin.go +++ b/listener/tun/tun_darwin.go @@ -15,11 +15,6 @@ const ( ) func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net.IP, err error) { - ip, _, err = net.ParseCIDR(l.md.config.Net) - if err != nil { - return - } - if l.md.config.Name == "" { l.md.config.Name = defaultTunName } @@ -32,13 +27,15 @@ func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net. if peer == "" { peer = ip.String() } - cmd := fmt.Sprintf("ifconfig %s inet %s %s mtu %d up", - name, l.md.config.Net, l.md.config.Peer, l.md.config.MTU) - l.logger.Debug(cmd) - - args := strings.Split(cmd, " ") - if err = exec.Command(args[0], args[1:]...).Run(); err != nil { - return + if len(l.md.config.Net) > 0 { + cmd := fmt.Sprintf("ifconfig %s inet %s %s mtu %d up", + name, l.md.config.Net[0].String(), l.md.config.Peer, l.md.config.MTU) + l.logger.Debug(cmd) + args := strings.Split(cmd, " ") + if err = exec.Command(args[0], args[1:]...).Run(); err != nil { + return + } + ip = l.md.config.Net[0].IP } if err = l.addRoutes(name, l.md.config.Routes...); err != nil { diff --git a/listener/tun/tun_linux.go b/listener/tun/tun_linux.go index d79dfe6..be66f2c 100644 --- a/listener/tun/tun_linux.go +++ b/listener/tun/tun_linux.go @@ -11,11 +11,6 @@ import ( ) func (l *tunListener) createTun() (dev io.ReadWriteCloser, name string, ip net.IP, err error) { - ip, ipNet, err := net.ParseCIDR(l.md.config.Net) - if err != nil { - return - } - dev, name, err = l.createTunDevice() if err != nil { return @@ -31,14 +26,18 @@ func (l *tunListener) createTun() (dev io.ReadWriteCloser, name string, ip net.I return } - if err = netlink.AddrAdd(link, &netlink.Addr{ - IPNet: &net.IPNet{ - IP: ip, - Mask: ipNet.Mask, - }, - }); err != nil { - return + for _, net := range l.md.config.Net { + if err = netlink.AddrAdd(link, &netlink.Addr{ + IPNet: &net, + }); err != nil { + l.logger.Error(err) + continue + } } + if len(l.md.config.Net) > 0 { + ip = l.md.config.Net[0].IP + } + if err = netlink.LinkSetUp(link); err != nil { return } diff --git a/listener/tun/tun_unix.go b/listener/tun/tun_unix.go index 83d434d..191607b 100644 --- a/listener/tun/tun_unix.go +++ b/listener/tun/tun_unix.go @@ -17,11 +17,6 @@ const ( ) func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net.IP, err error) { - ip, _, err = net.ParseCIDR(l.md.config.Net) - if err != nil { - return - } - if l.md.config.Name == "" { l.md.config.Name = defaultTunName } @@ -30,14 +25,17 @@ func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net. return } - cmd := fmt.Sprintf("ifconfig %s inet %s mtu %d up", - name, l.md.config.Net, l.md.config.MTU) - l.logger.Debug(cmd) + if len(l.md.config.Net) > 0 { + cmd := fmt.Sprintf("ifconfig %s inet %s mtu %d up", + name, l.md.config.Net[0].String(), l.md.config.MTU) + l.logger.Debug(cmd) - args := strings.Split(cmd, " ") - if er := exec.Command(args[0], args[1:]...).Run(); er != nil { - err = fmt.Errorf("%s: %v", cmd, er) - return + args := strings.Split(cmd, " ") + if er := exec.Command(args[0], args[1:]...).Run(); er != nil { + err = fmt.Errorf("%s: %v", cmd, er) + return + } + ip = l.md.config.Net[0].IP } if err = l.addRoutes(name, l.md.config.Routes...); err != nil { diff --git a/listener/tun/tun_windows.go b/listener/tun/tun_windows.go index 874f010..971dc2c 100644 --- a/listener/tun/tun_windows.go +++ b/listener/tun/tun_windows.go @@ -15,11 +15,6 @@ const ( ) func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net.IP, err error) { - ip, ipNet, err := net.ParseCIDR(l.md.config.Net) - if err != nil { - return - } - if l.md.config.Name == "" { l.md.config.Name = defaultTunName } @@ -28,15 +23,19 @@ func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net. return } - cmd := fmt.Sprintf("netsh interface ip set address name=%s "+ - "source=static addr=%s mask=%s gateway=none", - name, ip.String(), ipMask(ipNet.Mask)) - l.logger.Debug(cmd) + if len(l.md.config.Net) > 0 { + ipNet := l.md.config.Net[0] + cmd := fmt.Sprintf("netsh interface ip set address name=%s "+ + "source=static addr=%s mask=%s gateway=none", + name, ipNet.IP.String(), ipMask(ipNet.Mask)) + l.logger.Debug(cmd) - args := strings.Split(cmd, " ") - if er := exec.Command(args[0], args[1:]...).Run(); er != nil { - err = fmt.Errorf("%s: %v", cmd, er) - return + args := strings.Split(cmd, " ") + if er := exec.Command(args[0], args[1:]...).Run(); er != nil { + err = fmt.Errorf("%s: %v", cmd, er) + return + } + ip = ipNet.IP } if err = l.addRoutes(name, l.md.config.Gateway, l.md.config.Routes...); err != nil { @@ -46,14 +45,14 @@ func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net. return } -func (l *tunListener) addRoutes(ifName string, gw string, routes ...tun_util.Route) error { +func (l *tunListener) addRoutes(ifName string, gw net.IP, routes ...tun_util.Route) error { for _, route := range routes { l.deleteRoute(ifName, route.Net.String()) cmd := fmt.Sprintf("netsh interface ip add route prefix=%s interface=%s store=active", route.Net.String(), ifName) - if gw != "" { - cmd += " nexthop=" + gw + if gw != nil { + cmd += " nexthop=" + gw.String() } l.logger.Debug(cmd) args := strings.Split(cmd, " ")