From 03988fee0b9a91eb5a081eabb8377946fb6f5863 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Mon, 11 Apr 2022 22:53:02 +0800 Subject: [PATCH] add recorder --- chain/router.go | 48 ++++++++++++++++++++++++++++++++++++-------- recorder/recorder.go | 17 ++++++++++++++++ service/service.go | 1 + 3 files changed, 58 insertions(+), 8 deletions(-) create mode 100644 recorder/recorder.go diff --git a/chain/router.go b/chain/router.go index 32b238e..ba1c81c 100644 --- a/chain/router.go +++ b/chain/router.go @@ -10,6 +10,7 @@ import ( "github.com/go-gost/core/connector" "github.com/go-gost/core/hosts" "github.com/go-gost/core/logger" + "github.com/go-gost/core/recorder" "github.com/go-gost/core/resolver" ) @@ -18,14 +19,15 @@ type SockOpts struct { } type Router struct { - ifceName string - sockOpts *SockOpts - timeout time.Duration - retries int - chain Chainer - resolver resolver.Resolver - hosts hosts.HostMapper - logger logger.Logger + ifceName string + sockOpts *SockOpts + timeout time.Duration + retries int + chain Chainer + resolver resolver.Resolver + hosts hosts.HostMapper + recorders []recorder.RecorderObject + logger logger.Logger } func (r *Router) WithTimeout(timeout time.Duration) *Router { @@ -70,16 +72,29 @@ func (r *Router) Hosts() hosts.HostMapper { return nil } +func (r *Router) WithRecorder(recorders ...recorder.RecorderObject) *Router { + r.recorders = recorders + return r +} + func (r *Router) WithLogger(logger logger.Logger) *Router { r.logger = logger return r } func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) { + host := address + if h, _, _ := net.SplitHostPort(address); h != "" { + host = h + } + r.record(ctx, recorder.RecorderServiceRouterDialAddress, []byte(host)) + conn, err = r.dial(ctx, network, address) if err != nil { + r.record(ctx, recorder.RecorderServiceRouterDialAddressError, []byte(host)) return } + if network == "udp" || network == "udp4" || network == "udp6" { if _, ok := conn.(net.PacketConn); !ok { return &packetConn{conn}, nil @@ -88,6 +103,23 @@ func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Co return } +func (r *Router) record(ctx context.Context, name string, data []byte) error { + if len(data) == 0 { + return nil + } + + for _, rec := range r.recorders { + if rec.Record == name { + err := rec.Recorder.Record(ctx, data) + if err != nil { + r.logger.Errorf("record %s: %v", name, err) + } + return err + } + } + return nil +} + func (r *Router) dial(ctx context.Context, network, address string) (conn net.Conn, err error) { count := r.retries + 1 if count <= 0 { diff --git a/recorder/recorder.go b/recorder/recorder.go new file mode 100644 index 0000000..70b6099 --- /dev/null +++ b/recorder/recorder.go @@ -0,0 +1,17 @@ +package recorder + +import "context" + +type Recorder interface { + Record(ctx context.Context, b []byte) error +} + +type RecorderObject struct { + Recorder Recorder + Record string +} + +const ( + RecorderServiceRouterDialAddress = "recorder.service.router.dial.address" + RecorderServiceRouterDialAddressError = "recorder.service.router.dial.address.error" +) diff --git a/service/service.go b/service/service.go index 3df74ad..b79e25a 100644 --- a/service/service.go +++ b/service/service.go @@ -98,6 +98,7 @@ func (s *service) Serve() error { if s.options.admission != nil && !s.options.admission.Admit(conn.RemoteAddr().String()) { conn.Close() + s.options.logger.Debugf("admission: %s is denied", conn.RemoteAddr()) continue }