From ca414f655dbbd557b40725af7c589a73dd5d726c Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sun, 21 Aug 2022 13:54:52 +0800 Subject: [PATCH] update tun handler --- handler/tun/client.go | 130 ++++++++++++++++++++ handler/tun/handler.go | 253 +-------------------------------------- handler/tun/metadata.go | 3 - handler/tun/server.go | 166 +++++++++++++++++++++++++ listener/tun/conn.go | 11 +- listener/tun/listener.go | 88 +++++++++----- 6 files changed, 366 insertions(+), 285 deletions(-) create mode 100644 handler/tun/client.go create mode 100644 handler/tun/server.go diff --git a/handler/tun/client.go b/handler/tun/client.go new file mode 100644 index 0000000..782d614 --- /dev/null +++ b/handler/tun/client.go @@ -0,0 +1,130 @@ +package tun + +import ( + "context" + "io" + "net" + + "github.com/go-gost/core/common/bufpool" + "github.com/go-gost/core/logger" + tun_util "github.com/go-gost/x/internal/util/tun" + "github.com/songgao/water/waterutil" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +func (h *tunHandler) handleClient(ctx context.Context, conn net.Conn, addr net.Addr, config *tun_util.Config, log logger.Logger) error { + cc, err := h.router.Dial(ctx, addr.Network(), addr.String()) + if err != nil { + return err + } + defer cc.Close() + + return h.transportClient(conn, cc, config, log) +} + +func (h *tunHandler) transportClient(tun net.Conn, conn net.Conn, config *tun_util.Config, log logger.Logger) error { + errc := make(chan error, 1) + + go func() { + for { + err := func() error { + b := bufpool.Get(h.md.bufferSize) + defer bufpool.Put(b) + + n, err := tun.Read(*b) + if err != nil { + return err + } + + if waterutil.IsIPv4((*b)[:n]) { + header, err := ipv4.ParseHeader((*b)[:n]) + if err != nil { + log.Warn(err) + return nil + } + + log.Tracef("%s >> %s %-4s %d/%-4d %-4x %d", + header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])), + header.Len, header.TotalLen, header.ID, header.Flags) + } else if waterutil.IsIPv6((*b)[:n]) { + header, err := ipv6.ParseHeader((*b)[:n]) + if err != nil { + log.Warn(err) + return nil + } + + log.Tracef("%s >> %s %s %d %d", + header.Src, header.Dst, + ipProtocol(waterutil.IPProtocol(header.NextHeader)), + header.PayloadLen, header.TrafficClass) + } else { + log.Warn("unknown packet, discarded") + return nil + } + + _, err = conn.Write((*b)[:n]) + return err + }() + + if err != nil { + errc <- err + return + } + } + }() + + go func() { + for { + err := func() error { + b := bufpool.Get(h.md.bufferSize) + defer bufpool.Put(b) + + n, err := conn.Read(*b) + if err != nil { + return err + } + + if waterutil.IsIPv4((*b)[:n]) { + header, err := ipv4.ParseHeader((*b)[:n]) + if err != nil { + log.Warn(err) + return nil + } + + log.Tracef("%s >> %s %-4s %d/%-4d %-4x %d", + header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])), + header.Len, header.TotalLen, header.ID, header.Flags) + } else if waterutil.IsIPv6((*b)[:n]) { + header, err := ipv6.ParseHeader((*b)[:n]) + if err != nil { + log.Warn(err) + return nil + } + + log.Tracef("%s > %s %s %d %d", + header.Src, header.Dst, + ipProtocol(waterutil.IPProtocol(header.NextHeader)), + header.PayloadLen, header.TrafficClass) + } else { + log.Warn("unknown packet, discarded") + return nil + } + + _, err = tun.Write((*b)[:n]) + return err + }() + + if err != nil { + errc <- err + return + } + } + }() + + err := <-errc + if err != nil && err == io.EOF { + err = nil + } + return err +} diff --git a/handler/tun/handler.go b/handler/tun/handler.go index d9efe52..b3b344a 100644 --- a/handler/tun/handler.go +++ b/handler/tun/handler.go @@ -4,25 +4,16 @@ import ( "context" "errors" "fmt" - "io" "net" - "os" "sync" "time" "github.com/go-gost/core/chain" - "github.com/go-gost/core/common/bufpool" "github.com/go-gost/core/handler" - "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" - "github.com/go-gost/x/internal/util/ss" tun_util "github.com/go-gost/x/internal/util/tun" "github.com/go-gost/x/registry" - "github.com/shadowsocks/go-shadowsocks2/core" - "github.com/shadowsocks/go-shadowsocks2/shadowaead" "github.com/songgao/water/waterutil" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" ) func init() { @@ -33,7 +24,6 @@ type tunHandler struct { group *chain.NodeGroup routes sync.Map exit chan struct{} - cipher core.Cipher router *chain.Router md metadata options handler.Options @@ -56,15 +46,6 @@ func (h *tunHandler) Init(md md.Metadata) (err error) { return } - if h.options.Auth != nil { - method := h.options.Auth.Username() - password, _ := h.options.Auth.Password() - h.cipher, err = ss.ShadowCipher(method, password, h.md.key) - if err != nil { - return - } - } - h.router = h.options.Router if h.router == nil { h.router = (&chain.Router{}).WithLogger(h.options.Logger) @@ -79,7 +60,6 @@ func (h *tunHandler) Forward(group *chain.NodeGroup) { } func (h *tunHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error { - defer os.Exit(0) defer conn.Close() log := h.options.Logger @@ -90,6 +70,7 @@ func (h *tunHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler. log.Error(err) return err } + config := v.GetMetadata().Get("config").(*tun_util.Config) start := time.Now() log = log.WithFields(map[string]any{ @@ -119,236 +100,12 @@ func (h *tunHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler. "dst": fmt.Sprintf("%s/%s", raddr.String(), raddr.Network()), }) log.Debugf("%s >> %s", conn.RemoteAddr(), target.Addr) + + h.handleClient(ctx, conn, raddr, config, log) + return nil } - config := v.GetMetadata().Get("config").(*tun_util.Config) - h.handleLoop(ctx, conn, raddr, config, log) - return nil -} - -func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Addr, config *tun_util.Config, log logger.Logger) { - var tempDelay time.Duration - for { - err := func() error { - var err error - var pc net.PacketConn - if addr != nil { - cc, err := h.router.Dial(ctx, addr.Network(), "") - if err != nil { - return err - } - - var ok bool - pc, ok = cc.(net.PacketConn) - if !ok { - cc.Close() - return errors.New("wrong connection type") - } - } else { - laddr, _ := net.ResolveUDPAddr("udp", conn.LocalAddr().String()) - pc, err = net.ListenUDP("udp", laddr) - } - if err != nil { - return err - } - - if h.cipher != nil { - pc = h.cipher.PacketConn(pc) - } - defer pc.Close() - - return h.transport(conn, pc, addr, config, log) - }() - if err != nil { - log.Error(err) - } - - select { - case <-h.exit: - return - default: - } - - if err != nil { - if tempDelay == 0 { - tempDelay = 1000 * time.Millisecond - } else { - tempDelay *= 2 - } - if max := 6 * time.Second; tempDelay > max { - tempDelay = max - } - time.Sleep(tempDelay) - continue - } - tempDelay = 0 - } - -} - -func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr, config *tun_util.Config, log logger.Logger) error { - errc := make(chan error, 1) - - go func() { - for { - err := func() error { - b := bufpool.Get(h.md.bufferSize) - defer bufpool.Put(b) - - n, err := tun.Read(*b) - if err != nil { - select { - case h.exit <- struct{}{}: - default: - } - return err - } - - var src, dst net.IP - if waterutil.IsIPv4((*b)[:n]) { - header, err := ipv4.ParseHeader((*b)[:n]) - if err != nil { - log.Error(err) - return nil - } - log.Tracef("%s >> %s %-4s %d/%-4d %-4x %d", - header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])), - header.Len, header.TotalLen, header.ID, header.Flags) - - src, dst = header.Src, header.Dst - } else if waterutil.IsIPv6((*b)[:n]) { - header, err := ipv6.ParseHeader((*b)[:n]) - if err != nil { - log.Warn(err) - return nil - } - log.Tracef("%s >> %s %s %d %d", - header.Src, header.Dst, - ipProtocol(waterutil.IPProtocol(header.NextHeader)), - header.PayloadLen, header.TrafficClass) - - src, dst = header.Src, header.Dst - } else { - log.Warn("unknown packet, discarded") - return nil - } - - // client side, deliver packet directly. - if raddr != nil { - _, err := conn.WriteTo((*b)[:n], raddr) - return err - } - - addr := h.findRouteFor(dst, config.Routes...) - if addr == nil { - log.Debugf("no route for %s -> %s", src, dst) - return nil - } - - log.Debugf("find route: %s -> %s", dst, addr) - - if _, err := conn.WriteTo((*b)[:n], addr); err != nil { - return err - } - return nil - }() - - if err != nil { - errc <- err - return - } - } - }() - - go func() { - for { - err := func() error { - b := bufpool.Get(h.md.bufferSize) - defer bufpool.Put(b) - - n, addr, err := conn.ReadFrom(*b) - if err != nil && - err != shadowaead.ErrShortPacket { - return err - } - - var src, dst net.IP - if waterutil.IsIPv4((*b)[:n]) { - header, err := ipv4.ParseHeader((*b)[:n]) - if err != nil { - log.Warn(err) - return nil - } - - log.Tracef("%s >> %s %-4s %d/%-4d %-4x %d", - header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])), - header.Len, header.TotalLen, header.ID, header.Flags) - - src, dst = header.Src, header.Dst - } else if waterutil.IsIPv6((*b)[:n]) { - header, err := ipv6.ParseHeader((*b)[:n]) - if err != nil { - log.Warn(err) - return nil - } - - log.Tracef("%s > %s %s %d %d", - header.Src, header.Dst, - ipProtocol(waterutil.IPProtocol(header.NextHeader)), - header.PayloadLen, header.TrafficClass) - - src, dst = header.Src, header.Dst - } else { - log.Warn("unknown packet, discarded") - return nil - } - - // client side, deliver packet to tun device. - if raddr != nil { - _, err := tun.Write((*b)[:n]) - return err - } - - rkey := ipToTunRouteKey(src) - if actual, loaded := h.routes.LoadOrStore(rkey, addr); loaded { - if actual.(net.Addr).String() != addr.String() { - log.Debugf("update route: %s -> %s (old %s)", - src, addr, actual.(net.Addr)) - h.routes.Store(rkey, addr) - } - } else { - log.Debugf("no route for %s -> %s", src, addr) - } - - if addr := h.findRouteFor(dst, config.Routes...); addr != nil { - log.Debugf("find route: %s -> %s", dst, addr) - - _, err := conn.WriteTo((*b)[:n], addr) - return err - } - - if _, err := tun.Write((*b)[:n]); err != nil { - select { - case h.exit <- struct{}{}: - default: - } - return err - } - return nil - }() - - if err != nil { - errc <- err - return - } - } - }() - - err := <-errc - if err != nil && err == io.EOF { - err = nil - } - return err + return h.handleServer(ctx, conn, config, log) } func (h *tunHandler) findRouteFor(dst net.IP, routes ...tun_util.Route) net.Addr { diff --git a/handler/tun/metadata.go b/handler/tun/metadata.go index 386603c..72eeff2 100644 --- a/handler/tun/metadata.go +++ b/handler/tun/metadata.go @@ -6,17 +6,14 @@ import ( ) type metadata struct { - key string bufferSize int } func (h *tunHandler) parseMetadata(md mdata.Metadata) (err error) { const ( - key = "key" bufferSize = "bufferSize" ) - h.md.key = mdx.GetString(md, key) h.md.bufferSize = mdx.GetInt(md, bufferSize) if h.md.bufferSize <= 0 { h.md.bufferSize = 1500 diff --git a/handler/tun/server.go b/handler/tun/server.go new file mode 100644 index 0000000..ee21ab4 --- /dev/null +++ b/handler/tun/server.go @@ -0,0 +1,166 @@ +package tun + +import ( + "context" + "io" + "net" + + "github.com/go-gost/core/common/bufpool" + "github.com/go-gost/core/logger" + tun_util "github.com/go-gost/x/internal/util/tun" + "github.com/songgao/water/waterutil" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +func (h *tunHandler) handleServer(ctx context.Context, conn net.Conn, config *tun_util.Config, log logger.Logger) error { + pc, err := net.ListenPacket(conn.LocalAddr().Network(), conn.LocalAddr().String()) + if err != nil { + return err + } + defer pc.Close() + + return h.transportServer(conn, pc, config, log) +} + +func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config *tun_util.Config, log logger.Logger) error { + errc := make(chan error, 1) + + go func() { + for { + err := func() error { + b := bufpool.Get(h.md.bufferSize) + defer bufpool.Put(b) + + n, err := tun.Read(*b) + if err != nil { + return err + } + + var src, dst net.IP + if waterutil.IsIPv4((*b)[:n]) { + header, err := ipv4.ParseHeader((*b)[:n]) + if err != nil { + log.Warn(err) + return nil + } + src, dst = header.Src, header.Dst + + log.Tracef("%s >> %s %-4s %d/%-4d %-4x %d", + header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])), + header.Len, header.TotalLen, header.ID, header.Flags) + } else if waterutil.IsIPv6((*b)[:n]) { + header, err := ipv6.ParseHeader((*b)[:n]) + if err != nil { + log.Warn(err) + return nil + } + src, dst = header.Src, header.Dst + + log.Tracef("%s >> %s %s %d %d", + header.Src, header.Dst, + ipProtocol(waterutil.IPProtocol(header.NextHeader)), + header.PayloadLen, header.TrafficClass) + } else { + log.Warn("unknown packet, discarded") + return nil + } + + addr := h.findRouteFor(dst, config.Routes...) + if addr == nil { + log.Debugf("no route for %s -> %s, packet discarded", src, dst) + return nil + } + + log.Debugf("find route: %s -> %s", dst, addr) + + if _, err := conn.WriteTo((*b)[:n], addr); err != nil { + return err + } + return nil + }() + + if err != nil { + errc <- err + return + } + } + }() + + go func() { + for { + err := func() error { + b := bufpool.Get(h.md.bufferSize) + defer bufpool.Put(b) + + n, addr, err := conn.ReadFrom(*b) + if err != nil { + return err + } + + var src, dst net.IP + if waterutil.IsIPv4((*b)[:n]) { + header, err := ipv4.ParseHeader((*b)[:n]) + if err != nil { + log.Warn(err) + return nil + } + src, dst = header.Src, header.Dst + + log.Tracef("%s >> %s %-4s %d/%-4d %-4x %d", + header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])), + header.Len, header.TotalLen, header.ID, header.Flags) + } else if waterutil.IsIPv6((*b)[:n]) { + header, err := ipv6.ParseHeader((*b)[:n]) + if err != nil { + log.Warn(err) + return nil + } + src, dst = header.Src, header.Dst + + log.Tracef("%s > %s %s %d %d", + header.Src, header.Dst, + ipProtocol(waterutil.IPProtocol(header.NextHeader)), + header.PayloadLen, header.TrafficClass) + } else { + log.Warn("unknown packet, discarded") + return nil + } + + rkey := ipToTunRouteKey(src) + if actual, loaded := h.routes.LoadOrStore(rkey, addr); loaded { + if actual.(net.Addr).String() != addr.String() { + h.routes.Store(rkey, addr) + log.Debugf("update route: %s -> %s (old %s)", + src, addr, actual.(net.Addr)) + } + } else { + log.Debugf("new route: %s -> %s", src, addr) + } + + if addr := h.findRouteFor(dst, config.Routes...); addr != nil { + log.Debugf("find route: %s -> %s", dst, addr) + + _, err := conn.WriteTo((*b)[:n], addr) + return err + } + + if _, err := tun.Write((*b)[:n]); err != nil { + return err + } + return nil + }() + + if err != nil { + errc <- err + return + } + } + }() + + err := <-errc + if err != nil && err == io.EOF { + err = nil + } + return err +} diff --git a/listener/tun/conn.go b/listener/tun/conn.go index 3471121..7bed81b 100644 --- a/listener/tun/conn.go +++ b/listener/tun/conn.go @@ -1,6 +1,7 @@ package tun import ( + "context" "errors" "io" "net" @@ -10,9 +11,10 @@ import ( ) type conn struct { - ifce io.ReadWriteCloser - laddr net.Addr - raddr net.Addr + ifce io.ReadWriteCloser + laddr net.Addr + raddr net.Addr + cancel context.CancelFunc } func (c *conn) Read(b []byte) (n int, err error) { @@ -44,6 +46,9 @@ func (c *conn) SetWriteDeadline(t time.Time) error { } func (c *conn) Close() (err error) { + if c.cancel != nil { + c.cancel() + } return c.ifce.Close() } diff --git a/listener/tun/listener.go b/listener/tun/listener.go index 8e7aef0..3f9cbeb 100644 --- a/listener/tun/listener.go +++ b/listener/tun/listener.go @@ -1,7 +1,9 @@ package tun import ( + "context" "net" + "time" "github.com/go-gost/core/listener" "github.com/go-gost/core/logger" @@ -49,43 +51,67 @@ func (l *tunListener) Init(md mdata.Metadata) (err error) { if err != nil { return } - - ifce, name, ip, err := l.createTun() - if err != nil { - if ifce != nil { - ifce.Close() - } - return - } - - itf, err := net.InterfaceByName(name) - if err != nil { - return - } - - addrs, _ := itf.Addrs() - l.logger.Infof("name: %s, net: %s, mtu: %d, addrs: %s", - itf.Name, ip, itf.MTU, addrs) - - l.cqueue = make(chan net.Conn, 1) + l.cqueue = make(chan net.Conn) l.closed = make(chan struct{}) - var c net.Conn - c = &conn{ - ifce: ifce, - laddr: l.addr, - raddr: &net.IPAddr{IP: ip}, - } - c = metrics.WrapConn(l.options.Service, c) - c = withMetadata(mdx.NewMetadata(map[string]any{ - "config": l.md.config, - }), c) - - l.cqueue <- c + go l.listenLoop() return } +func (l *tunListener) listenLoop() { + for { + ctx, cancel := context.WithCancel(context.Background()) + err := func() error { + ifce, name, ip, err := l.createTun() + if err != nil { + if ifce != nil { + ifce.Close() + } + return err + } + + itf, err := net.InterfaceByName(name) + if err != nil { + return err + } + + addrs, _ := itf.Addrs() + l.logger.Infof("name: %s, net: %s, mtu: %d, addrs: %s", + itf.Name, ip, itf.MTU, addrs) + + var c net.Conn + c = &conn{ + ifce: ifce, + laddr: l.addr, + raddr: &net.IPAddr{IP: ip}, + cancel: cancel, + } + c = metrics.WrapConn(l.options.Service, c) + c = withMetadata(mdx.NewMetadata(map[string]any{ + "config": l.md.config, + }), c) + + l.cqueue <- c + + return nil + }() + if err != nil { + l.logger.Error(err) + cancel() + continue + } + + select { + case <-ctx.Done(): + case <-l.closed: + return + } + + time.Sleep(time.Second) + } +} + func (l *tunListener) Accept() (net.Conn, error) { select { case conn := <-l.cqueue: