tun: support multiple IPs

This commit is contained in:
ginuerzh
2022-12-22 17:44:30 +08:00
parent 67bbdbf5a3
commit fb29d5c80e
9 changed files with 128 additions and 99 deletions

View File

@ -16,8 +16,8 @@ import (
)
const (
// 4-byte magic header followed by 16-byte IP address followed by 16-byte key.
keepAliveDataLength = 36
// 4-byte magic header followed by 16-byte key.
keepAliveHeaderLength = 20
)
var (
@ -25,9 +25,12 @@ var (
)
func (h *tunHandler) handleClient(ctx context.Context, conn net.Conn, raddr string, config *tun_util.Config, log logger.Logger) error {
ip, _, err := net.ParseCIDR(config.Net)
if err != nil {
return err
var ips []net.IP
for _, net := range config.Net {
ips = append(ips, net.IP)
}
if len(ips) == 0 {
return ErrInvalidNet
}
for {
@ -41,9 +44,9 @@ func (h *tunHandler) handleClient(ctx context.Context, conn net.Conn, raddr stri
ctx, cancel := context.WithCancel(ctx)
defer cancel()
go h.keepAlive(ctx, cc, ip)
go h.keepAlive(ctx, cc, ips)
return h.transportClient(conn, cc, config, log)
return h.transportClient(conn, cc, log)
}()
if err == ErrTun {
return err
@ -54,13 +57,19 @@ func (h *tunHandler) handleClient(ctx context.Context, conn net.Conn, raddr stri
}
}
func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ip net.IP) {
func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ips []net.IP) {
// handshake
var keepAliveData [keepAliveDataLength]byte
copy(keepAliveData[:4], magicHeader) // magic header
copy(keepAliveData[4:20], ip.To16())
copy(keepAliveData[20:36], []byte(h.md.passphrase))
if _, err := conn.Write(keepAliveData[:]); err != nil {
keepAliveData := bufpool.Get(keepAliveHeaderLength + len(ips)*net.IPv6len)
defer bufpool.Put(keepAliveData)
copy((*keepAliveData)[:4], magicHeader) // magic header
copy((*keepAliveData)[4:20], []byte(h.md.passphrase))
pos := 20
for _, ip := range ips {
copy((*keepAliveData)[pos:pos+net.IPv6len], ip.To16())
pos += net.IPv6len
}
if _, err := conn.Write((*keepAliveData)); err != nil {
return
}
@ -75,7 +84,7 @@ func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ip net.IP) {
for {
select {
case <-ticker.C:
if _, err := conn.Write(keepAliveData[:]); err != nil {
if _, err := conn.Write((*keepAliveData)); err != nil {
return
}
h.options.Logger.Debugf("keepalive sended")
@ -85,7 +94,7 @@ func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ip net.IP) {
}
}
func (h *tunHandler) transportClient(tun io.ReadWriter, conn net.Conn, config *tun_util.Config, log logger.Logger) error {
func (h *tunHandler) transportClient(tun io.ReadWriter, conn net.Conn, log logger.Logger) error {
errc := make(chan error, 1)
go func() {
@ -147,7 +156,7 @@ func (h *tunHandler) transportClient(tun io.ReadWriter, conn net.Conn, config *t
return err
}
if n == keepAliveDataLength && bytes.Equal((*b)[:4], magicHeader) {
if n == keepAliveHeaderLength && bytes.Equal((*b)[:4], magicHeader) {
ip := net.IP((*b)[4:20])
log.Debugf("keepalive received at %v", ip)

View File

@ -17,7 +17,8 @@ import (
)
var (
ErrTun = errors.New("tun device error")
ErrTun = errors.New("tun device error")
ErrInvalidNet = errors.New("invalid net IP")
)
func init() {

View File

@ -37,11 +37,6 @@ func (h *tunHandler) handleServer(ctx context.Context, conn net.Conn, config *tu
}
func (h *tunHandler) transportServer(tun io.ReadWriter, conn net.PacketConn, config *tun_util.Config, log logger.Logger) error {
tunIP, _, err := net.ParseCIDR(config.Net)
if err != nil {
return err
}
errc := make(chan error, 1)
go func() {
@ -115,33 +110,49 @@ func (h *tunHandler) transportServer(tun io.ReadWriter, conn net.PacketConn, con
if err != nil {
return err
}
if n == keepAliveDataLength && bytes.Equal((*b)[:4], magicHeader) {
peerIP := net.IP((*b)[4:20])
key := bytes.TrimRight((*b)[20:36], "\x00")
if peerIP.Equal(tunIP.To16()) {
if n > keepAliveHeaderLength && bytes.Equal((*b)[:4], magicHeader) {
var peerIPs []net.IP
data := (*b)[keepAliveHeaderLength:n]
if len(data)%net.IPv6len == 0 {
for len(data) > 0 {
peerIPs = append(peerIPs, net.IP(data[:net.IPv6len]))
data = data[net.IPv6len:]
}
}
if len(peerIPs) == 0 {
return nil
}
if auther := h.options.Auther; auther != nil {
ip := peerIP
if v := peerIP.To4(); ip != nil {
ip = v
for _, net := range config.Net {
for _, ip := range peerIPs {
if ip.Equal(net.IP.To16()) {
return nil
}
}
if !auther.Authenticate(ip.String(), string(key)) {
log.Debugf("keepalive from %v => %v, auth FAILED", addr, peerIP)
}
if auther := h.options.Auther; auther != nil {
ok := true
key := bytes.TrimRight((*b)[4:20], "\x00")
for _, ip := range peerIPs {
if ok = auther.Authenticate(ip.String(), string(key)); !ok {
break
}
}
if !ok {
log.Debugf("keepalive from %v => %v, auth FAILED", addr, peerIPs)
return nil
}
}
log.Debugf("keepalive from %v => %v", addr, peerIP)
log.Debugf("keepalive from %v => %v", addr, peerIPs)
addrPort, err := netip.ParseAddrPort(addr.String())
if err != nil {
log.Warnf("keepalive from %v: %v", addr, err)
return nil
}
var keepAliveData [keepAliveDataLength]byte
var keepAliveData [keepAliveHeaderLength]byte
copy(keepAliveData[:4], magicHeader) // magic header
a16 := addrPort.Addr().As16()
copy(keepAliveData[4:], a16[:])
@ -151,7 +162,9 @@ func (h *tunHandler) transportServer(tun io.ReadWriter, conn net.PacketConn, con
return nil
}
h.updateRoute(peerIP, addr, log)
for _, ip := range peerIPs {
h.updateRoute(ip, addr, log)
}
return nil
}
@ -204,7 +217,7 @@ func (h *tunHandler) transportServer(tun io.ReadWriter, conn net.PacketConn, con
}
}()
err = <-errc
err := <-errc
if err != nil && err == io.EOF {
err = nil
}