add recorder

This commit is contained in:
ginuerzh 2022-04-11 22:53:02 +08:00
parent a117222cde
commit 03988fee0b
3 changed files with 58 additions and 8 deletions

View File

@ -10,6 +10,7 @@ import (
"github.com/go-gost/core/connector" "github.com/go-gost/core/connector"
"github.com/go-gost/core/hosts" "github.com/go-gost/core/hosts"
"github.com/go-gost/core/logger" "github.com/go-gost/core/logger"
"github.com/go-gost/core/recorder"
"github.com/go-gost/core/resolver" "github.com/go-gost/core/resolver"
) )
@ -18,14 +19,15 @@ type SockOpts struct {
} }
type Router struct { type Router struct {
ifceName string ifceName string
sockOpts *SockOpts sockOpts *SockOpts
timeout time.Duration timeout time.Duration
retries int retries int
chain Chainer chain Chainer
resolver resolver.Resolver resolver resolver.Resolver
hosts hosts.HostMapper hosts hosts.HostMapper
logger logger.Logger recorders []recorder.RecorderObject
logger logger.Logger
} }
func (r *Router) WithTimeout(timeout time.Duration) *Router { func (r *Router) WithTimeout(timeout time.Duration) *Router {
@ -70,16 +72,29 @@ func (r *Router) Hosts() hosts.HostMapper {
return nil return nil
} }
func (r *Router) WithRecorder(recorders ...recorder.RecorderObject) *Router {
r.recorders = recorders
return r
}
func (r *Router) WithLogger(logger logger.Logger) *Router { func (r *Router) WithLogger(logger logger.Logger) *Router {
r.logger = logger r.logger = logger
return r return r
} }
func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) { func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
host := address
if h, _, _ := net.SplitHostPort(address); h != "" {
host = h
}
r.record(ctx, recorder.RecorderServiceRouterDialAddress, []byte(host))
conn, err = r.dial(ctx, network, address) conn, err = r.dial(ctx, network, address)
if err != nil { if err != nil {
r.record(ctx, recorder.RecorderServiceRouterDialAddressError, []byte(host))
return return
} }
if network == "udp" || network == "udp4" || network == "udp6" { if network == "udp" || network == "udp4" || network == "udp6" {
if _, ok := conn.(net.PacketConn); !ok { if _, ok := conn.(net.PacketConn); !ok {
return &packetConn{conn}, nil return &packetConn{conn}, nil
@ -88,6 +103,23 @@ func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Co
return return
} }
func (r *Router) record(ctx context.Context, name string, data []byte) error {
if len(data) == 0 {
return nil
}
for _, rec := range r.recorders {
if rec.Record == name {
err := rec.Recorder.Record(ctx, data)
if err != nil {
r.logger.Errorf("record %s: %v", name, err)
}
return err
}
}
return nil
}
func (r *Router) dial(ctx context.Context, network, address string) (conn net.Conn, err error) { func (r *Router) dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
count := r.retries + 1 count := r.retries + 1
if count <= 0 { if count <= 0 {

17
recorder/recorder.go Normal file
View File

@ -0,0 +1,17 @@
package recorder
import "context"
type Recorder interface {
Record(ctx context.Context, b []byte) error
}
type RecorderObject struct {
Recorder Recorder
Record string
}
const (
RecorderServiceRouterDialAddress = "recorder.service.router.dial.address"
RecorderServiceRouterDialAddressError = "recorder.service.router.dial.address.error"
)

View File

@ -98,6 +98,7 @@ func (s *service) Serve() error {
if s.options.admission != nil && if s.options.admission != nil &&
!s.options.admission.Admit(conn.RemoteAddr().String()) { !s.options.admission.Admit(conn.RemoteAddr().String()) {
conn.Close() conn.Close()
s.options.logger.Debugf("admission: %s is denied", conn.RemoteAddr())
continue continue
} }