tun: support multiple IPs
This commit is contained in:
parent
67bbdbf5a3
commit
fb29d5c80e
@ -16,8 +16,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// 4-byte magic header followed by 16-byte IP address followed by 16-byte key.
|
||||
keepAliveDataLength = 36
|
||||
// 4-byte magic header followed by 16-byte key.
|
||||
keepAliveHeaderLength = 20
|
||||
)
|
||||
|
||||
var (
|
||||
@ -25,9 +25,12 @@ var (
|
||||
)
|
||||
|
||||
func (h *tunHandler) handleClient(ctx context.Context, conn net.Conn, raddr string, config *tun_util.Config, log logger.Logger) error {
|
||||
ip, _, err := net.ParseCIDR(config.Net)
|
||||
if err != nil {
|
||||
return err
|
||||
var ips []net.IP
|
||||
for _, net := range config.Net {
|
||||
ips = append(ips, net.IP)
|
||||
}
|
||||
if len(ips) == 0 {
|
||||
return ErrInvalidNet
|
||||
}
|
||||
|
||||
for {
|
||||
@ -41,9 +44,9 @@ func (h *tunHandler) handleClient(ctx context.Context, conn net.Conn, raddr stri
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go h.keepAlive(ctx, cc, ip)
|
||||
go h.keepAlive(ctx, cc, ips)
|
||||
|
||||
return h.transportClient(conn, cc, config, log)
|
||||
return h.transportClient(conn, cc, log)
|
||||
}()
|
||||
if err == ErrTun {
|
||||
return err
|
||||
@ -54,13 +57,19 @@ func (h *tunHandler) handleClient(ctx context.Context, conn net.Conn, raddr stri
|
||||
}
|
||||
}
|
||||
|
||||
func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ip net.IP) {
|
||||
func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ips []net.IP) {
|
||||
// handshake
|
||||
var keepAliveData [keepAliveDataLength]byte
|
||||
copy(keepAliveData[:4], magicHeader) // magic header
|
||||
copy(keepAliveData[4:20], ip.To16())
|
||||
copy(keepAliveData[20:36], []byte(h.md.passphrase))
|
||||
if _, err := conn.Write(keepAliveData[:]); err != nil {
|
||||
keepAliveData := bufpool.Get(keepAliveHeaderLength + len(ips)*net.IPv6len)
|
||||
defer bufpool.Put(keepAliveData)
|
||||
|
||||
copy((*keepAliveData)[:4], magicHeader) // magic header
|
||||
copy((*keepAliveData)[4:20], []byte(h.md.passphrase))
|
||||
pos := 20
|
||||
for _, ip := range ips {
|
||||
copy((*keepAliveData)[pos:pos+net.IPv6len], ip.To16())
|
||||
pos += net.IPv6len
|
||||
}
|
||||
if _, err := conn.Write((*keepAliveData)); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -75,7 +84,7 @@ func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ip net.IP) {
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if _, err := conn.Write(keepAliveData[:]); err != nil {
|
||||
if _, err := conn.Write((*keepAliveData)); err != nil {
|
||||
return
|
||||
}
|
||||
h.options.Logger.Debugf("keepalive sended")
|
||||
@ -85,7 +94,7 @@ func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ip net.IP) {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *tunHandler) transportClient(tun io.ReadWriter, conn net.Conn, config *tun_util.Config, log logger.Logger) error {
|
||||
func (h *tunHandler) transportClient(tun io.ReadWriter, conn net.Conn, log logger.Logger) error {
|
||||
errc := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
@ -147,7 +156,7 @@ func (h *tunHandler) transportClient(tun io.ReadWriter, conn net.Conn, config *t
|
||||
return err
|
||||
}
|
||||
|
||||
if n == keepAliveDataLength && bytes.Equal((*b)[:4], magicHeader) {
|
||||
if n == keepAliveHeaderLength && bytes.Equal((*b)[:4], magicHeader) {
|
||||
ip := net.IP((*b)[4:20])
|
||||
log.Debugf("keepalive received at %v", ip)
|
||||
|
||||
|
@ -18,6 +18,7 @@ import (
|
||||
|
||||
var (
|
||||
ErrTun = errors.New("tun device error")
|
||||
ErrInvalidNet = errors.New("invalid net IP")
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
@ -37,11 +37,6 @@ func (h *tunHandler) handleServer(ctx context.Context, conn net.Conn, config *tu
|
||||
}
|
||||
|
||||
func (h *tunHandler) transportServer(tun io.ReadWriter, conn net.PacketConn, config *tun_util.Config, log logger.Logger) error {
|
||||
tunIP, _, err := net.ParseCIDR(config.Net)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
errc := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
@ -115,33 +110,49 @@ func (h *tunHandler) transportServer(tun io.ReadWriter, conn net.PacketConn, con
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n == keepAliveDataLength && bytes.Equal((*b)[:4], magicHeader) {
|
||||
peerIP := net.IP((*b)[4:20])
|
||||
key := bytes.TrimRight((*b)[20:36], "\x00")
|
||||
|
||||
if peerIP.Equal(tunIP.To16()) {
|
||||
if n > keepAliveHeaderLength && bytes.Equal((*b)[:4], magicHeader) {
|
||||
var peerIPs []net.IP
|
||||
data := (*b)[keepAliveHeaderLength:n]
|
||||
if len(data)%net.IPv6len == 0 {
|
||||
for len(data) > 0 {
|
||||
peerIPs = append(peerIPs, net.IP(data[:net.IPv6len]))
|
||||
data = data[net.IPv6len:]
|
||||
}
|
||||
}
|
||||
if len(peerIPs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, net := range config.Net {
|
||||
for _, ip := range peerIPs {
|
||||
if ip.Equal(net.IP.To16()) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if auther := h.options.Auther; auther != nil {
|
||||
ip := peerIP
|
||||
if v := peerIP.To4(); ip != nil {
|
||||
ip = v
|
||||
ok := true
|
||||
key := bytes.TrimRight((*b)[4:20], "\x00")
|
||||
for _, ip := range peerIPs {
|
||||
if ok = auther.Authenticate(ip.String(), string(key)); !ok {
|
||||
break
|
||||
}
|
||||
if !auther.Authenticate(ip.String(), string(key)) {
|
||||
log.Debugf("keepalive from %v => %v, auth FAILED", addr, peerIP)
|
||||
}
|
||||
if !ok {
|
||||
log.Debugf("keepalive from %v => %v, auth FAILED", addr, peerIPs)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("keepalive from %v => %v", addr, peerIP)
|
||||
log.Debugf("keepalive from %v => %v", addr, peerIPs)
|
||||
|
||||
addrPort, err := netip.ParseAddrPort(addr.String())
|
||||
if err != nil {
|
||||
log.Warnf("keepalive from %v: %v", addr, err)
|
||||
return nil
|
||||
}
|
||||
var keepAliveData [keepAliveDataLength]byte
|
||||
var keepAliveData [keepAliveHeaderLength]byte
|
||||
copy(keepAliveData[:4], magicHeader) // magic header
|
||||
a16 := addrPort.Addr().As16()
|
||||
copy(keepAliveData[4:], a16[:])
|
||||
@ -151,7 +162,9 @@ func (h *tunHandler) transportServer(tun io.ReadWriter, conn net.PacketConn, con
|
||||
return nil
|
||||
}
|
||||
|
||||
h.updateRoute(peerIP, addr, log)
|
||||
for _, ip := range peerIPs {
|
||||
h.updateRoute(ip, addr, log)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -204,7 +217,7 @@ func (h *tunHandler) transportServer(tun io.ReadWriter, conn net.PacketConn, con
|
||||
}
|
||||
}()
|
||||
|
||||
err = <-errc
|
||||
err := <-errc
|
||||
if err != nil && err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
|
@ -10,10 +10,10 @@ type Route struct {
|
||||
|
||||
type Config struct {
|
||||
Name string
|
||||
Net string
|
||||
Net []net.IPNet
|
||||
// peer addr of point-to-point on MacOS
|
||||
Peer string
|
||||
MTU int
|
||||
Gateway string
|
||||
Gateway net.IP
|
||||
Routes []Route
|
||||
}
|
||||
|
@ -30,16 +30,29 @@ func (l *tunListener) parseMetadata(md mdata.Metadata) (err error) {
|
||||
|
||||
config := &tun_util.Config{
|
||||
Name: mdutil.GetString(md, name),
|
||||
Net: mdutil.GetString(md, netKey),
|
||||
Peer: mdutil.GetString(md, peer),
|
||||
MTU: mdutil.GetInt(md, mtu),
|
||||
Gateway: mdutil.GetString(md, gateway),
|
||||
}
|
||||
if config.MTU <= 0 {
|
||||
config.MTU = DefaultMTU
|
||||
}
|
||||
if gw := mdutil.GetString(md, gateway); gw != "" {
|
||||
config.Gateway = net.ParseIP(gw)
|
||||
}
|
||||
|
||||
gw := net.ParseIP(config.Gateway)
|
||||
for _, s := range strings.Split(mdutil.GetString(md, netKey), ",") {
|
||||
if s = strings.TrimSpace(s); s == "" {
|
||||
continue
|
||||
}
|
||||
ip, ipNet, err := net.ParseCIDR(s)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
config.Net = append(config.Net, net.IPNet{
|
||||
IP: ip,
|
||||
Mask: ipNet.Mask,
|
||||
})
|
||||
}
|
||||
|
||||
for _, s := range strings.Split(mdutil.GetString(md, route), ",") {
|
||||
var route tun_util.Route
|
||||
@ -48,7 +61,7 @@ func (l *tunListener) parseMetadata(md mdata.Metadata) (err error) {
|
||||
continue
|
||||
}
|
||||
route.Net = *ipNet
|
||||
route.Gateway = gw
|
||||
route.Gateway = config.Gateway
|
||||
|
||||
config.Routes = append(config.Routes, route)
|
||||
}
|
||||
@ -64,7 +77,7 @@ func (l *tunListener) parseMetadata(md mdata.Metadata) (err error) {
|
||||
route.Net = *ipNet
|
||||
route.Gateway = net.ParseIP(ss[1])
|
||||
if route.Gateway == nil {
|
||||
route.Gateway = gw
|
||||
route.Gateway = config.Gateway
|
||||
}
|
||||
|
||||
config.Routes = append(config.Routes, route)
|
||||
|
@ -15,11 +15,6 @@ const (
|
||||
)
|
||||
|
||||
func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net.IP, err error) {
|
||||
ip, _, err = net.ParseCIDR(l.md.config.Net)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if l.md.config.Name == "" {
|
||||
l.md.config.Name = defaultTunName
|
||||
}
|
||||
@ -32,14 +27,16 @@ func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net.
|
||||
if peer == "" {
|
||||
peer = ip.String()
|
||||
}
|
||||
if len(l.md.config.Net) > 0 {
|
||||
cmd := fmt.Sprintf("ifconfig %s inet %s %s mtu %d up",
|
||||
name, l.md.config.Net, l.md.config.Peer, l.md.config.MTU)
|
||||
name, l.md.config.Net[0].String(), l.md.config.Peer, l.md.config.MTU)
|
||||
l.logger.Debug(cmd)
|
||||
|
||||
args := strings.Split(cmd, " ")
|
||||
if err = exec.Command(args[0], args[1:]...).Run(); err != nil {
|
||||
return
|
||||
}
|
||||
ip = l.md.config.Net[0].IP
|
||||
}
|
||||
|
||||
if err = l.addRoutes(name, l.md.config.Routes...); err != nil {
|
||||
return
|
||||
|
@ -11,11 +11,6 @@ import (
|
||||
)
|
||||
|
||||
func (l *tunListener) createTun() (dev io.ReadWriteCloser, name string, ip net.IP, err error) {
|
||||
ip, ipNet, err := net.ParseCIDR(l.md.config.Net)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
dev, name, err = l.createTunDevice()
|
||||
if err != nil {
|
||||
return
|
||||
@ -31,14 +26,18 @@ func (l *tunListener) createTun() (dev io.ReadWriteCloser, name string, ip net.I
|
||||
return
|
||||
}
|
||||
|
||||
for _, net := range l.md.config.Net {
|
||||
if err = netlink.AddrAdd(link, &netlink.Addr{
|
||||
IPNet: &net.IPNet{
|
||||
IP: ip,
|
||||
Mask: ipNet.Mask,
|
||||
},
|
||||
IPNet: &net,
|
||||
}); err != nil {
|
||||
return
|
||||
l.logger.Error(err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if len(l.md.config.Net) > 0 {
|
||||
ip = l.md.config.Net[0].IP
|
||||
}
|
||||
|
||||
if err = netlink.LinkSetUp(link); err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -17,11 +17,6 @@ const (
|
||||
)
|
||||
|
||||
func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net.IP, err error) {
|
||||
ip, _, err = net.ParseCIDR(l.md.config.Net)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if l.md.config.Name == "" {
|
||||
l.md.config.Name = defaultTunName
|
||||
}
|
||||
@ -30,8 +25,9 @@ func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net.
|
||||
return
|
||||
}
|
||||
|
||||
if len(l.md.config.Net) > 0 {
|
||||
cmd := fmt.Sprintf("ifconfig %s inet %s mtu %d up",
|
||||
name, l.md.config.Net, l.md.config.MTU)
|
||||
name, l.md.config.Net[0].String(), l.md.config.MTU)
|
||||
l.logger.Debug(cmd)
|
||||
|
||||
args := strings.Split(cmd, " ")
|
||||
@ -39,6 +35,8 @@ func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net.
|
||||
err = fmt.Errorf("%s: %v", cmd, er)
|
||||
return
|
||||
}
|
||||
ip = l.md.config.Net[0].IP
|
||||
}
|
||||
|
||||
if err = l.addRoutes(name, l.md.config.Routes...); err != nil {
|
||||
return
|
||||
|
@ -15,11 +15,6 @@ const (
|
||||
)
|
||||
|
||||
func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net.IP, err error) {
|
||||
ip, ipNet, err := net.ParseCIDR(l.md.config.Net)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if l.md.config.Name == "" {
|
||||
l.md.config.Name = defaultTunName
|
||||
}
|
||||
@ -28,9 +23,11 @@ func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net.
|
||||
return
|
||||
}
|
||||
|
||||
if len(l.md.config.Net) > 0 {
|
||||
ipNet := l.md.config.Net[0]
|
||||
cmd := fmt.Sprintf("netsh interface ip set address name=%s "+
|
||||
"source=static addr=%s mask=%s gateway=none",
|
||||
name, ip.String(), ipMask(ipNet.Mask))
|
||||
name, ipNet.IP.String(), ipMask(ipNet.Mask))
|
||||
l.logger.Debug(cmd)
|
||||
|
||||
args := strings.Split(cmd, " ")
|
||||
@ -38,6 +35,8 @@ func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net.
|
||||
err = fmt.Errorf("%s: %v", cmd, er)
|
||||
return
|
||||
}
|
||||
ip = ipNet.IP
|
||||
}
|
||||
|
||||
if err = l.addRoutes(name, l.md.config.Gateway, l.md.config.Routes...); err != nil {
|
||||
return
|
||||
@ -46,14 +45,14 @@ func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net.
|
||||
return
|
||||
}
|
||||
|
||||
func (l *tunListener) addRoutes(ifName string, gw string, routes ...tun_util.Route) error {
|
||||
func (l *tunListener) addRoutes(ifName string, gw net.IP, routes ...tun_util.Route) error {
|
||||
for _, route := range routes {
|
||||
l.deleteRoute(ifName, route.Net.String())
|
||||
|
||||
cmd := fmt.Sprintf("netsh interface ip add route prefix=%s interface=%s store=active",
|
||||
route.Net.String(), ifName)
|
||||
if gw != "" {
|
||||
cmd += " nexthop=" + gw
|
||||
if gw != nil {
|
||||
cmd += " nexthop=" + gw.String()
|
||||
}
|
||||
l.logger.Debug(cmd)
|
||||
args := strings.Split(cmd, " ")
|
||||
|
Loading…
Reference in New Issue
Block a user