tun: support multiple IPs

This commit is contained in:
ginuerzh 2022-12-22 17:44:30 +08:00
parent 67bbdbf5a3
commit fb29d5c80e
9 changed files with 128 additions and 99 deletions

View File

@ -16,8 +16,8 @@ import (
) )
const ( const (
// 4-byte magic header followed by 16-byte IP address followed by 16-byte key. // 4-byte magic header followed by 16-byte key.
keepAliveDataLength = 36 keepAliveHeaderLength = 20
) )
var ( var (
@ -25,9 +25,12 @@ var (
) )
func (h *tunHandler) handleClient(ctx context.Context, conn net.Conn, raddr string, config *tun_util.Config, log logger.Logger) error { func (h *tunHandler) handleClient(ctx context.Context, conn net.Conn, raddr string, config *tun_util.Config, log logger.Logger) error {
ip, _, err := net.ParseCIDR(config.Net) var ips []net.IP
if err != nil { for _, net := range config.Net {
return err ips = append(ips, net.IP)
}
if len(ips) == 0 {
return ErrInvalidNet
} }
for { for {
@ -41,9 +44,9 @@ func (h *tunHandler) handleClient(ctx context.Context, conn net.Conn, raddr stri
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
go h.keepAlive(ctx, cc, ip) go h.keepAlive(ctx, cc, ips)
return h.transportClient(conn, cc, config, log) return h.transportClient(conn, cc, log)
}() }()
if err == ErrTun { if err == ErrTun {
return err return err
@ -54,13 +57,19 @@ func (h *tunHandler) handleClient(ctx context.Context, conn net.Conn, raddr stri
} }
} }
func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ip net.IP) { func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ips []net.IP) {
// handshake // handshake
var keepAliveData [keepAliveDataLength]byte keepAliveData := bufpool.Get(keepAliveHeaderLength + len(ips)*net.IPv6len)
copy(keepAliveData[:4], magicHeader) // magic header defer bufpool.Put(keepAliveData)
copy(keepAliveData[4:20], ip.To16())
copy(keepAliveData[20:36], []byte(h.md.passphrase)) copy((*keepAliveData)[:4], magicHeader) // magic header
if _, err := conn.Write(keepAliveData[:]); err != nil { copy((*keepAliveData)[4:20], []byte(h.md.passphrase))
pos := 20
for _, ip := range ips {
copy((*keepAliveData)[pos:pos+net.IPv6len], ip.To16())
pos += net.IPv6len
}
if _, err := conn.Write((*keepAliveData)); err != nil {
return return
} }
@ -75,7 +84,7 @@ func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ip net.IP) {
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
if _, err := conn.Write(keepAliveData[:]); err != nil { if _, err := conn.Write((*keepAliveData)); err != nil {
return return
} }
h.options.Logger.Debugf("keepalive sended") h.options.Logger.Debugf("keepalive sended")
@ -85,7 +94,7 @@ func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ip net.IP) {
} }
} }
func (h *tunHandler) transportClient(tun io.ReadWriter, conn net.Conn, config *tun_util.Config, log logger.Logger) error { func (h *tunHandler) transportClient(tun io.ReadWriter, conn net.Conn, log logger.Logger) error {
errc := make(chan error, 1) errc := make(chan error, 1)
go func() { go func() {
@ -147,7 +156,7 @@ func (h *tunHandler) transportClient(tun io.ReadWriter, conn net.Conn, config *t
return err return err
} }
if n == keepAliveDataLength && bytes.Equal((*b)[:4], magicHeader) { if n == keepAliveHeaderLength && bytes.Equal((*b)[:4], magicHeader) {
ip := net.IP((*b)[4:20]) ip := net.IP((*b)[4:20])
log.Debugf("keepalive received at %v", ip) log.Debugf("keepalive received at %v", ip)

View File

@ -17,7 +17,8 @@ import (
) )
var ( var (
ErrTun = errors.New("tun device error") ErrTun = errors.New("tun device error")
ErrInvalidNet = errors.New("invalid net IP")
) )
func init() { func init() {

View File

@ -37,11 +37,6 @@ func (h *tunHandler) handleServer(ctx context.Context, conn net.Conn, config *tu
} }
func (h *tunHandler) transportServer(tun io.ReadWriter, conn net.PacketConn, config *tun_util.Config, log logger.Logger) error { func (h *tunHandler) transportServer(tun io.ReadWriter, conn net.PacketConn, config *tun_util.Config, log logger.Logger) error {
tunIP, _, err := net.ParseCIDR(config.Net)
if err != nil {
return err
}
errc := make(chan error, 1) errc := make(chan error, 1)
go func() { go func() {
@ -115,33 +110,49 @@ func (h *tunHandler) transportServer(tun io.ReadWriter, conn net.PacketConn, con
if err != nil { if err != nil {
return err return err
} }
if n == keepAliveDataLength && bytes.Equal((*b)[:4], magicHeader) { if n > keepAliveHeaderLength && bytes.Equal((*b)[:4], magicHeader) {
peerIP := net.IP((*b)[4:20]) var peerIPs []net.IP
key := bytes.TrimRight((*b)[20:36], "\x00") data := (*b)[keepAliveHeaderLength:n]
if len(data)%net.IPv6len == 0 {
if peerIP.Equal(tunIP.To16()) { for len(data) > 0 {
peerIPs = append(peerIPs, net.IP(data[:net.IPv6len]))
data = data[net.IPv6len:]
}
}
if len(peerIPs) == 0 {
return nil return nil
} }
if auther := h.options.Auther; auther != nil { for _, net := range config.Net {
ip := peerIP for _, ip := range peerIPs {
if v := peerIP.To4(); ip != nil { if ip.Equal(net.IP.To16()) {
ip = v return nil
}
} }
if !auther.Authenticate(ip.String(), string(key)) { }
log.Debugf("keepalive from %v => %v, auth FAILED", addr, peerIP)
if auther := h.options.Auther; auther != nil {
ok := true
key := bytes.TrimRight((*b)[4:20], "\x00")
for _, ip := range peerIPs {
if ok = auther.Authenticate(ip.String(), string(key)); !ok {
break
}
}
if !ok {
log.Debugf("keepalive from %v => %v, auth FAILED", addr, peerIPs)
return nil return nil
} }
} }
log.Debugf("keepalive from %v => %v", addr, peerIP) log.Debugf("keepalive from %v => %v", addr, peerIPs)
addrPort, err := netip.ParseAddrPort(addr.String()) addrPort, err := netip.ParseAddrPort(addr.String())
if err != nil { if err != nil {
log.Warnf("keepalive from %v: %v", addr, err) log.Warnf("keepalive from %v: %v", addr, err)
return nil return nil
} }
var keepAliveData [keepAliveDataLength]byte var keepAliveData [keepAliveHeaderLength]byte
copy(keepAliveData[:4], magicHeader) // magic header copy(keepAliveData[:4], magicHeader) // magic header
a16 := addrPort.Addr().As16() a16 := addrPort.Addr().As16()
copy(keepAliveData[4:], a16[:]) copy(keepAliveData[4:], a16[:])
@ -151,7 +162,9 @@ func (h *tunHandler) transportServer(tun io.ReadWriter, conn net.PacketConn, con
return nil return nil
} }
h.updateRoute(peerIP, addr, log) for _, ip := range peerIPs {
h.updateRoute(ip, addr, log)
}
return nil return nil
} }
@ -204,7 +217,7 @@ func (h *tunHandler) transportServer(tun io.ReadWriter, conn net.PacketConn, con
} }
}() }()
err = <-errc err := <-errc
if err != nil && err == io.EOF { if err != nil && err == io.EOF {
err = nil err = nil
} }

View File

@ -10,10 +10,10 @@ type Route struct {
type Config struct { type Config struct {
Name string Name string
Net string Net []net.IPNet
// peer addr of point-to-point on MacOS // peer addr of point-to-point on MacOS
Peer string Peer string
MTU int MTU int
Gateway string Gateway net.IP
Routes []Route Routes []Route
} }

View File

@ -29,17 +29,30 @@ func (l *tunListener) parseMetadata(md mdata.Metadata) (err error) {
) )
config := &tun_util.Config{ config := &tun_util.Config{
Name: mdutil.GetString(md, name), Name: mdutil.GetString(md, name),
Net: mdutil.GetString(md, netKey), Peer: mdutil.GetString(md, peer),
Peer: mdutil.GetString(md, peer), MTU: mdutil.GetInt(md, mtu),
MTU: mdutil.GetInt(md, mtu),
Gateway: mdutil.GetString(md, gateway),
} }
if config.MTU <= 0 { if config.MTU <= 0 {
config.MTU = DefaultMTU config.MTU = DefaultMTU
} }
if gw := mdutil.GetString(md, gateway); gw != "" {
config.Gateway = net.ParseIP(gw)
}
gw := net.ParseIP(config.Gateway) for _, s := range strings.Split(mdutil.GetString(md, netKey), ",") {
if s = strings.TrimSpace(s); s == "" {
continue
}
ip, ipNet, err := net.ParseCIDR(s)
if err != nil {
continue
}
config.Net = append(config.Net, net.IPNet{
IP: ip,
Mask: ipNet.Mask,
})
}
for _, s := range strings.Split(mdutil.GetString(md, route), ",") { for _, s := range strings.Split(mdutil.GetString(md, route), ",") {
var route tun_util.Route var route tun_util.Route
@ -48,7 +61,7 @@ func (l *tunListener) parseMetadata(md mdata.Metadata) (err error) {
continue continue
} }
route.Net = *ipNet route.Net = *ipNet
route.Gateway = gw route.Gateway = config.Gateway
config.Routes = append(config.Routes, route) config.Routes = append(config.Routes, route)
} }
@ -64,7 +77,7 @@ func (l *tunListener) parseMetadata(md mdata.Metadata) (err error) {
route.Net = *ipNet route.Net = *ipNet
route.Gateway = net.ParseIP(ss[1]) route.Gateway = net.ParseIP(ss[1])
if route.Gateway == nil { if route.Gateway == nil {
route.Gateway = gw route.Gateway = config.Gateway
} }
config.Routes = append(config.Routes, route) config.Routes = append(config.Routes, route)

View File

@ -15,11 +15,6 @@ const (
) )
func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net.IP, err error) { func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net.IP, err error) {
ip, _, err = net.ParseCIDR(l.md.config.Net)
if err != nil {
return
}
if l.md.config.Name == "" { if l.md.config.Name == "" {
l.md.config.Name = defaultTunName l.md.config.Name = defaultTunName
} }
@ -32,13 +27,15 @@ func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net.
if peer == "" { if peer == "" {
peer = ip.String() peer = ip.String()
} }
cmd := fmt.Sprintf("ifconfig %s inet %s %s mtu %d up", if len(l.md.config.Net) > 0 {
name, l.md.config.Net, l.md.config.Peer, l.md.config.MTU) cmd := fmt.Sprintf("ifconfig %s inet %s %s mtu %d up",
l.logger.Debug(cmd) name, l.md.config.Net[0].String(), l.md.config.Peer, l.md.config.MTU)
l.logger.Debug(cmd)
args := strings.Split(cmd, " ") args := strings.Split(cmd, " ")
if err = exec.Command(args[0], args[1:]...).Run(); err != nil { if err = exec.Command(args[0], args[1:]...).Run(); err != nil {
return return
}
ip = l.md.config.Net[0].IP
} }
if err = l.addRoutes(name, l.md.config.Routes...); err != nil { if err = l.addRoutes(name, l.md.config.Routes...); err != nil {

View File

@ -11,11 +11,6 @@ import (
) )
func (l *tunListener) createTun() (dev io.ReadWriteCloser, name string, ip net.IP, err error) { func (l *tunListener) createTun() (dev io.ReadWriteCloser, name string, ip net.IP, err error) {
ip, ipNet, err := net.ParseCIDR(l.md.config.Net)
if err != nil {
return
}
dev, name, err = l.createTunDevice() dev, name, err = l.createTunDevice()
if err != nil { if err != nil {
return return
@ -31,14 +26,18 @@ func (l *tunListener) createTun() (dev io.ReadWriteCloser, name string, ip net.I
return return
} }
if err = netlink.AddrAdd(link, &netlink.Addr{ for _, net := range l.md.config.Net {
IPNet: &net.IPNet{ if err = netlink.AddrAdd(link, &netlink.Addr{
IP: ip, IPNet: &net,
Mask: ipNet.Mask, }); err != nil {
}, l.logger.Error(err)
}); err != nil { continue
return }
} }
if len(l.md.config.Net) > 0 {
ip = l.md.config.Net[0].IP
}
if err = netlink.LinkSetUp(link); err != nil { if err = netlink.LinkSetUp(link); err != nil {
return return
} }

View File

@ -17,11 +17,6 @@ const (
) )
func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net.IP, err error) { func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net.IP, err error) {
ip, _, err = net.ParseCIDR(l.md.config.Net)
if err != nil {
return
}
if l.md.config.Name == "" { if l.md.config.Name == "" {
l.md.config.Name = defaultTunName l.md.config.Name = defaultTunName
} }
@ -30,14 +25,17 @@ func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net.
return return
} }
cmd := fmt.Sprintf("ifconfig %s inet %s mtu %d up", if len(l.md.config.Net) > 0 {
name, l.md.config.Net, l.md.config.MTU) cmd := fmt.Sprintf("ifconfig %s inet %s mtu %d up",
l.logger.Debug(cmd) name, l.md.config.Net[0].String(), l.md.config.MTU)
l.logger.Debug(cmd)
args := strings.Split(cmd, " ") args := strings.Split(cmd, " ")
if er := exec.Command(args[0], args[1:]...).Run(); er != nil { if er := exec.Command(args[0], args[1:]...).Run(); er != nil {
err = fmt.Errorf("%s: %v", cmd, er) err = fmt.Errorf("%s: %v", cmd, er)
return return
}
ip = l.md.config.Net[0].IP
} }
if err = l.addRoutes(name, l.md.config.Routes...); err != nil { if err = l.addRoutes(name, l.md.config.Routes...); err != nil {

View File

@ -15,11 +15,6 @@ const (
) )
func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net.IP, err error) { func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net.IP, err error) {
ip, ipNet, err := net.ParseCIDR(l.md.config.Net)
if err != nil {
return
}
if l.md.config.Name == "" { if l.md.config.Name == "" {
l.md.config.Name = defaultTunName l.md.config.Name = defaultTunName
} }
@ -28,15 +23,19 @@ func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net.
return return
} }
cmd := fmt.Sprintf("netsh interface ip set address name=%s "+ if len(l.md.config.Net) > 0 {
"source=static addr=%s mask=%s gateway=none", ipNet := l.md.config.Net[0]
name, ip.String(), ipMask(ipNet.Mask)) cmd := fmt.Sprintf("netsh interface ip set address name=%s "+
l.logger.Debug(cmd) "source=static addr=%s mask=%s gateway=none",
name, ipNet.IP.String(), ipMask(ipNet.Mask))
l.logger.Debug(cmd)
args := strings.Split(cmd, " ") args := strings.Split(cmd, " ")
if er := exec.Command(args[0], args[1:]...).Run(); er != nil { if er := exec.Command(args[0], args[1:]...).Run(); er != nil {
err = fmt.Errorf("%s: %v", cmd, er) err = fmt.Errorf("%s: %v", cmd, er)
return return
}
ip = ipNet.IP
} }
if err = l.addRoutes(name, l.md.config.Gateway, l.md.config.Routes...); err != nil { if err = l.addRoutes(name, l.md.config.Gateway, l.md.config.Routes...); err != nil {
@ -46,14 +45,14 @@ func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net.
return return
} }
func (l *tunListener) addRoutes(ifName string, gw string, routes ...tun_util.Route) error { func (l *tunListener) addRoutes(ifName string, gw net.IP, routes ...tun_util.Route) error {
for _, route := range routes { for _, route := range routes {
l.deleteRoute(ifName, route.Net.String()) l.deleteRoute(ifName, route.Net.String())
cmd := fmt.Sprintf("netsh interface ip add route prefix=%s interface=%s store=active", cmd := fmt.Sprintf("netsh interface ip add route prefix=%s interface=%s store=active",
route.Net.String(), ifName) route.Net.String(), ifName)
if gw != "" { if gw != nil {
cmd += " nexthop=" + gw cmd += " nexthop=" + gw.String()
} }
l.logger.Debug(cmd) l.logger.Debug(cmd)
args := strings.Split(cmd, " ") args := strings.Split(cmd, " ")