diff --git a/cmd/gost/cmd.go b/cmd/gost/cmd.go index cfc4016..01523cb 100644 --- a/cmd/gost/cmd.go +++ b/cmd/gost/cmd.go @@ -274,6 +274,13 @@ func buildServiceConfig(url *url.URL) (*config.ServiceConfig, error) { Metadata: md, } + if svc.Handler.Type == "sshd" { + svc.Handler.Auths = nil + } + if svc.Listener.Type == "sshd" { + svc.Listener.Auths = auths + } + return svc, nil } @@ -354,6 +361,13 @@ func buildNodeConfig(url *url.URL) (*config.NodeConfig, error) { Metadata: md, } + if node.Connector.Type == "sshd" { + node.Connector.Auth = nil + } + if node.Dialer.Type == "sshd" { + node.Dialer.Auth = auth + } + return node, nil } diff --git a/cmd/gost/config.go b/cmd/gost/config.go index d8fdaa6..15e5559 100644 --- a/cmd/gost/config.go +++ b/cmd/gost/config.go @@ -91,6 +91,7 @@ func buildService(cfg *config.Config) (services []*service.Service) { ln := registry.GetListener(svc.Listener.Type)( listener.AddrOption(svc.Addr), + listener.ChainOption(chains[svc.Listener.Chain]), listener.AuthsOption(authsFromConfig(svc.Listener.Auths...)...), listener.TLSConfigOption(tlsConfig), listener.LoggerOption(listenerLogger), diff --git a/cmd/gost/register.go b/cmd/gost/register.go index 0e8c6ed..c473470 100644 --- a/cmd/gost/register.go +++ b/cmd/gost/register.go @@ -37,7 +37,6 @@ import ( _ "github.com/go-gost/gost/pkg/handler/dns" _ "github.com/go-gost/gost/pkg/handler/forward/local" _ "github.com/go-gost/gost/pkg/handler/forward/remote" - _ "github.com/go-gost/gost/pkg/handler/forward/ssh" _ "github.com/go-gost/gost/pkg/handler/http" _ "github.com/go-gost/gost/pkg/handler/http2" _ "github.com/go-gost/gost/pkg/handler/redirect" @@ -47,6 +46,7 @@ import ( _ "github.com/go-gost/gost/pkg/handler/socks/v5" _ "github.com/go-gost/gost/pkg/handler/ss" _ "github.com/go-gost/gost/pkg/handler/ss/udp" + _ "github.com/go-gost/gost/pkg/handler/sshd" _ "github.com/go-gost/gost/pkg/handler/tap" _ "github.com/go-gost/gost/pkg/handler/tun" @@ -65,6 +65,7 @@ import ( _ "github.com/go-gost/gost/pkg/listener/rtcp" _ "github.com/go-gost/gost/pkg/listener/rudp" _ "github.com/go-gost/gost/pkg/listener/ssh" + _ "github.com/go-gost/gost/pkg/listener/sshd" _ "github.com/go-gost/gost/pkg/listener/tap" _ "github.com/go-gost/gost/pkg/listener/tcp" _ "github.com/go-gost/gost/pkg/listener/tls" diff --git a/pkg/dialer/http3/dialer.go b/pkg/dialer/http3/dialer.go index ffa18bd..0694e23 100644 --- a/pkg/dialer/http3/dialer.go +++ b/pkg/dialer/http3/dialer.go @@ -15,6 +15,7 @@ import ( func init() { registry.RegisterDialer("http3", NewDialer) + registry.RegisterDialer("h3", NewDialer) } type http3Dialer struct { diff --git a/pkg/dialer/ssh/dialer.go b/pkg/dialer/ssh/dialer.go index 23d21b0..847764d 100644 --- a/pkg/dialer/ssh/dialer.go +++ b/pkg/dialer/ssh/dialer.go @@ -164,7 +164,7 @@ func (d *sshDialer) dial(ctx context.Context, network, addr string, opts *dialer func (d *sshDialer) initSession(ctx context.Context, addr string, conn net.Conn) (*sshSession, error) { config := ssh.ClientConfig{ - // Timeout: timeout, + Timeout: 30 * time.Second, HostKeyCallback: ssh.InsecureIgnoreHostKey(), } if d.md.user != nil { diff --git a/pkg/handler/auto/handler.go b/pkg/handler/auto/handler.go index 5b5b6a0..ac995ec 100644 --- a/pkg/handler/auto/handler.go +++ b/pkg/handler/auto/handler.go @@ -9,7 +9,6 @@ import ( "github.com/go-gost/gosocks4" "github.com/go-gost/gosocks5" "github.com/go-gost/gost/pkg/handler" - "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" "github.com/go-gost/relay" @@ -24,42 +23,37 @@ type autoHandler struct { socks4Handler handler.Handler socks5Handler handler.Handler relayHandler handler.Handler - log logger.Logger + options handler.Options } func NewHandler(opts ...handler.Option) handler.Handler { - options := &handler.Options{} + options := handler.Options{} for _, opt := range opts { - opt(options) - } - - log := options.Logger - if log == nil { - log = logger.Default() + opt(&options) } h := &autoHandler{ - log: log, + options: options, } if f := registry.GetHandler("http"); f != nil { v := append(opts, - handler.LoggerOption(log.WithFields(map[string]interface{}{"type": "http"}))) + handler.LoggerOption(options.Logger.WithFields(map[string]interface{}{"type": "http"}))) h.httpHandler = f(v...) } if f := registry.GetHandler("socks4"); f != nil { v := append(opts, - handler.LoggerOption(log.WithFields(map[string]interface{}{"type": "socks4"}))) + handler.LoggerOption(options.Logger.WithFields(map[string]interface{}{"type": "socks4"}))) h.socks4Handler = f(v...) } if f := registry.GetHandler("socks5"); f != nil { v := append(opts, - handler.LoggerOption(log.WithFields(map[string]interface{}{"type": "socks5"}))) + handler.LoggerOption(options.Logger.WithFields(map[string]interface{}{"type": "socks5"}))) h.socks5Handler = f(v...) } if f := registry.GetHandler("relay"); f != nil { v := append(opts, - handler.LoggerOption(log.WithFields(map[string]interface{}{"type": "relay"}))) + handler.LoggerOption(options.Logger.WithFields(map[string]interface{}{"type": "relay"}))) h.relayHandler = f(v...) } @@ -92,15 +86,15 @@ func (h *autoHandler) Init(md md.Metadata) error { } func (h *autoHandler) Handle(ctx context.Context, conn net.Conn) { - h.log = h.log.WithFields(map[string]interface{}{ + log := h.options.Logger.WithFields(map[string]interface{}{ "remote": conn.RemoteAddr().String(), "local": conn.LocalAddr().String(), }) start := time.Now() - h.log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) defer func() { - h.log.WithFields(map[string]interface{}{ + log.WithFields(map[string]interface{}{ "duration": time.Since(start), }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() @@ -108,7 +102,7 @@ func (h *autoHandler) Handle(ctx context.Context, conn net.Conn) { br := bufio.NewReader(conn) b, err := br.Peek(1) if err != nil { - h.log.Error(err) + log.Error(err) conn.Close() return } @@ -132,5 +126,4 @@ func (h *autoHandler) Handle(ctx context.Context, conn net.Conn) { h.httpHandler.Handle(ctx, conn) } } - } diff --git a/pkg/handler/dns/handler.go b/pkg/handler/dns/handler.go index 8d3e09b..ddc4fb3 100644 --- a/pkg/handler/dns/handler.go +++ b/pkg/handler/dns/handler.go @@ -33,7 +33,6 @@ type dnsHandler struct { exchangers []exchanger.Exchanger cache *resolver_util.Cache router *chain.Router - logger logger.Logger md metadata options handler.Options } @@ -50,19 +49,18 @@ func NewHandler(opts ...handler.Option) handler.Handler { } func (h *dnsHandler) Init(md md.Metadata) (err error) { - h.logger = h.options.Logger - if err = h.parseMetadata(md); err != nil { return } + log := h.options.Logger - h.cache = resolver_util.NewCache().WithLogger(h.options.Logger) + h.cache = resolver_util.NewCache().WithLogger(log) h.router = &chain.Router{ Retries: h.options.Retries, Chain: h.options.Chain, Resolver: h.options.Resolver, // Hosts: h.options.Hosts, - Logger: h.options.Logger, + Logger: log, } for _, server := range h.md.dns { @@ -74,10 +72,10 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) { server, exchanger.RouterOption(h.router), exchanger.TimeoutOption(h.md.timeout), - exchanger.LoggerOption(h.logger), + exchanger.LoggerOption(log), ) if err != nil { - h.logger.Warnf("parse %s: %v", server, err) + log.Warnf("parse %s: %v", server, err) continue } h.exchangers = append(h.exchangers, ex) @@ -87,9 +85,9 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) { defaultNameserver, exchanger.RouterOption(h.router), exchanger.TimeoutOption(h.md.timeout), - exchanger.LoggerOption(h.logger), + exchanger.LoggerOption(log), ) - h.logger.Warnf("resolver not found, default to %s", defaultNameserver) + log.Warnf("resolver not found, default to %s", defaultNameserver) if err != nil { return err } @@ -103,14 +101,14 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() start := time.Now() - h.logger = h.logger.WithFields(map[string]interface{}{ + log := h.options.Logger.WithFields(map[string]interface{}{ "remote": conn.RemoteAddr().String(), "local": conn.LocalAddr().String(), }) - h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) defer func() { - h.logger.WithFields(map[string]interface{}{ + log.WithFields(map[string]interface{}{ "duration": time.Since(start), }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() @@ -120,26 +118,25 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) { n, err := conn.Read(*b) if err != nil { - h.logger.Error(err) + log.Error(err) return } - h.logger.Info("read data: ", n) - reply, err := h.exchange(ctx, (*b)[:n]) + reply, err := h.exchange(ctx, (*b)[:n], log) if err != nil { return } defer bufpool.Put(&reply) if _, err = conn.Write(reply); err != nil { - h.logger.Error(err) + log.Error(err) } } -func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) { +func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger) ([]byte, error) { mq := dns.Msg{} if err := mq.Unpack(msg); err != nil { - h.logger.Error(err) + log.Error(err) return nil, err } @@ -149,23 +146,23 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) { resolver_util.AddSubnetOpt(&mq, h.md.clientIP) - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(mq.String()) + if log.IsLevelEnabled(logger.DebugLevel) { + log.Debug(mq.String()) } else { - h.logger.Info(h.dumpMsgHeader(&mq)) + log.Info(h.dumpMsgHeader(&mq)) } var mr *dns.Msg - if h.logger.IsLevelEnabled(logger.DebugLevel) { + if log.IsLevelEnabled(logger.DebugLevel) { defer func() { if mr != nil { - h.logger.Debug(mr.String()) + log.Debug(mr.String()) } }() } - mr = h.lookupHosts(&mq) + mr = h.lookupHosts(&mq, log) if mr != nil { b := bufpool.Get(4096) return mr.PackBuffer(*b) @@ -176,7 +173,7 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) { key := resolver_util.NewCacheKey(&mq.Question[0]) mr = h.cache.Load(key) if mr != nil { - h.logger.Debugf("exchange message %d (cached): %s", mq.Id, mq.Question[0].String()) + log.Debugf("exchange message %d (cached): %s", mq.Id, mq.Question[0].String()) mr.Id = mq.Id b := bufpool.Get(4096) @@ -195,18 +192,18 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) { query, err := mq.PackBuffer(*b) if err != nil { - h.logger.Error(err) + log.Error(err) return nil, err } var reply []byte for _, ex := range h.exchangers { - h.logger.Infof("exchange message %d via %s: %s", mq.Id, ex.String(), mq.Question[0].String()) + log.Infof("exchange message %d via %s: %s", mq.Id, ex.String(), mq.Question[0].String()) reply, err = ex.Exchange(ctx, query) if err == nil { break } - h.logger.Error(err) + log.Error(err) } if err != nil { return nil, err @@ -214,21 +211,21 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) { mr = &dns.Msg{} if err = mr.Unpack(reply); err != nil { - h.logger.Error(err) + log.Error(err) return nil, err } - if h.logger.IsLevelEnabled(logger.DebugLevel) { - h.logger.Debug(mr.String()) + if log.IsLevelEnabled(logger.DebugLevel) { + log.Debug(mr.String()) } else { - h.logger.Info(h.dumpMsgHeader(mr)) + log.Info(h.dumpMsgHeader(mr)) } return reply, nil } // lookup host mapper -func (h *dnsHandler) lookupHosts(r *dns.Msg) (m *dns.Msg) { +func (h *dnsHandler) lookupHosts(r *dns.Msg, log logger.Logger) (m *dns.Msg) { if h.options.Hosts == nil || r.Question[0].Qclass != dns.ClassINET || (r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA) { @@ -246,12 +243,12 @@ func (h *dnsHandler) lookupHosts(r *dns.Msg) (m *dns.Msg) { if len(ips) == 0 { return nil } - h.logger.Debugf("hit host mapper: %s -> %s", host, ips) + log.Debugf("hit host mapper: %s -> %s", host, ips) for _, ip := range ips { rr, err := dns.NewRR(fmt.Sprintf("%s IN A %s\n", r.Question[0].Name, ip.String())) if err != nil { - h.logger.Error(err) + log.Error(err) return nil } m.Answer = append(m.Answer, rr) @@ -262,12 +259,12 @@ func (h *dnsHandler) lookupHosts(r *dns.Msg) (m *dns.Msg) { if len(ips) == 0 { return nil } - h.logger.Debugf("hit host mapper: %s -> %s", host, ips) + log.Debugf("hit host mapper: %s -> %s", host, ips) for _, ip := range ips { rr, err := dns.NewRR(fmt.Sprintf("%s IN AAAA %s\n", r.Question[0].Name, ip.String())) if err != nil { - h.logger.Error(err) + log.Error(err) return nil } m.Answer = append(m.Answer, rr) diff --git a/pkg/handler/forward/local/handler.go b/pkg/handler/forward/local/handler.go index 05295e4..657df1a 100644 --- a/pkg/handler/forward/local/handler.go +++ b/pkg/handler/forward/local/handler.go @@ -8,7 +8,6 @@ import ( "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/handler" - "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" ) @@ -22,7 +21,6 @@ func init() { type forwardHandler struct { group *chain.NodeGroup router *chain.Router - logger logger.Logger md metadata options handler.Options } @@ -55,7 +53,6 @@ func (h *forwardHandler) Init(md md.Metadata) (err error) { Hosts: h.options.Hosts, Logger: h.options.Logger, } - h.logger = h.options.Logger return } @@ -69,21 +66,21 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() start := time.Now() - h.logger = h.logger.WithFields(map[string]interface{}{ + log := h.options.Logger.WithFields(map[string]interface{}{ "remote": conn.RemoteAddr().String(), "local": conn.LocalAddr().String(), }) - h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) defer func() { - h.logger.WithFields(map[string]interface{}{ + log.WithFields(map[string]interface{}{ "duration": time.Since(start), }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() target := h.group.Next() if target == nil { - h.logger.Error("no target available") + log.Error("no target available") return } @@ -92,15 +89,15 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { network = "udp" } - h.logger = h.logger.WithFields(map[string]interface{}{ + log = log.WithFields(map[string]interface{}{ "dst": fmt.Sprintf("%s/%s", target.Addr(), network), }) - h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) + log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) cc, err := h.router.Dial(ctx, network, target.Addr()) if err != nil { - h.logger.Error(err) + log.Error(err) // TODO: the router itself may be failed due to the failed node in the router, // the dead marker may be a wrong operation. target.Marker().Mark() @@ -110,11 +107,9 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { target.Marker().Reset() t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr()) + log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr()) handler.Transport(conn, cc) - h.logger. - WithFields(map[string]interface{}{ - "duration": time.Since(t), - }). - Infof("%s >-< %s", conn.RemoteAddr(), target.Addr()) + log.WithFields(map[string]interface{}{ + "duration": time.Since(t), + }).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr()) } diff --git a/pkg/handler/forward/remote/handler.go b/pkg/handler/forward/remote/handler.go index 2fc0699..80c22c4 100644 --- a/pkg/handler/forward/remote/handler.go +++ b/pkg/handler/forward/remote/handler.go @@ -8,7 +8,6 @@ import ( "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/handler" - "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" ) @@ -21,7 +20,6 @@ func init() { type forwardHandler struct { group *chain.NodeGroup router *chain.Router - logger logger.Logger md metadata options handler.Options } @@ -49,7 +47,6 @@ func (h *forwardHandler) Init(md md.Metadata) (err error) { Hosts: h.options.Hosts, Logger: h.options.Logger, } - h.logger = h.options.Logger return } @@ -63,21 +60,21 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() start := time.Now() - h.logger = h.logger.WithFields(map[string]interface{}{ + log := h.options.Logger.WithFields(map[string]interface{}{ "remote": conn.RemoteAddr().String(), "local": conn.LocalAddr().String(), }) - h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) defer func() { - h.logger.WithFields(map[string]interface{}{ + log.WithFields(map[string]interface{}{ "duration": time.Since(start), }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() target := h.group.Next() if target == nil { - h.logger.Error("no target available") + log.Error("no target available") return } @@ -86,15 +83,15 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { network = "udp" } - h.logger = h.logger.WithFields(map[string]interface{}{ + log = log.WithFields(map[string]interface{}{ "dst": fmt.Sprintf("%s/%s", target.Addr(), network), }) - h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) + log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) cc, err := h.router.Dial(ctx, network, target.Addr()) if err != nil { - h.logger.Error(err) + log.Error(err) // TODO: the router itself may be failed due to the failed node in the router, // the dead marker may be a wrong operation. target.Marker().Mark() @@ -104,11 +101,9 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { target.Marker().Reset() t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr()) + log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr()) handler.Transport(conn, cc) - h.logger. - WithFields(map[string]interface{}{ - "duration": time.Since(t), - }). - Infof("%s >-< %s", conn.RemoteAddr(), target.Addr()) + log.WithFields(map[string]interface{}{ + "duration": time.Since(t), + }).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr()) } diff --git a/pkg/handler/forward/ssh/handler.go b/pkg/handler/forward/ssh/handler.go deleted file mode 100644 index 0791912..0000000 --- a/pkg/handler/forward/ssh/handler.go +++ /dev/null @@ -1,300 +0,0 @@ -package ssh - -import ( - "context" - "encoding/binary" - "fmt" - "net" - "strconv" - "time" - - "github.com/go-gost/gost/pkg/chain" - auth_util "github.com/go-gost/gost/pkg/common/util/auth" - "github.com/go-gost/gost/pkg/handler" - ssh_util "github.com/go-gost/gost/pkg/internal/util/ssh" - "github.com/go-gost/gost/pkg/logger" - md "github.com/go-gost/gost/pkg/metadata" - "github.com/go-gost/gost/pkg/registry" - "golang.org/x/crypto/ssh" -) - -// Applicable SSH Request types for Port Forwarding - RFC 4254 7.X -const ( - DirectForwardRequest = "direct-tcpip" // RFC 4254 7.2 - RemoteForwardRequest = "tcpip-forward" // RFC 4254 7.1 - ForwardedTCPReturnRequest = "forwarded-tcpip" // RFC 4254 7.2 - CancelRemoteForwardRequest = "cancel-tcpip-forward" // RFC 4254 7.1 -) - -func init() { - registry.RegisterHandler("sshd", NewHandler) -} - -type forwardHandler struct { - config *ssh.ServerConfig - router *chain.Router - logger logger.Logger - md metadata - options handler.Options -} - -func NewHandler(opts ...handler.Option) handler.Handler { - options := handler.Options{} - for _, opt := range opts { - opt(&options) - } - - return &forwardHandler{ - options: options, - } -} - -func (h *forwardHandler) Init(md md.Metadata) (err error) { - if err = h.parseMetadata(md); err != nil { - return - } - - authenticator := auth_util.AuthFromUsers(h.options.Auths...) - - config := &ssh.ServerConfig{ - PasswordCallback: ssh_util.PasswordCallback(authenticator), - PublicKeyCallback: ssh_util.PublicKeyCallback(h.md.authorizedKeys), - } - - config.AddHostKey(h.md.signer) - - if authenticator == nil && len(h.md.authorizedKeys) == 0 { - config.NoClientAuth = true - } - - h.config = config - h.router = &chain.Router{ - Retries: h.options.Retries, - Chain: h.options.Chain, - Resolver: h.options.Resolver, - Hosts: h.options.Hosts, - Logger: h.options.Logger, - } - h.logger = h.options.Logger - - return nil -} - -func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { - defer conn.Close() - - start := time.Now() - h.logger = h.logger.WithFields(map[string]interface{}{ - "remote": conn.RemoteAddr().String(), - "local": conn.LocalAddr().String(), - }) - - h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) - defer func() { - h.logger.WithFields(map[string]interface{}{ - "duration": time.Since(start), - }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) - }() - - sshConn, chans, reqs, err := ssh.NewServerConn(conn, h.config) - if err != nil { - h.logger.Error(err) - return - } - - h.handleForward(ctx, sshConn, chans, reqs) -} - -func (h *forwardHandler) handleForward(ctx context.Context, conn ssh.Conn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) { - quit := make(chan struct{}) - defer close(quit) // quit signal - - go func() { - for req := range reqs { - switch req.Type { - case RemoteForwardRequest: - go h.tcpipForwardRequest(conn, req, quit) - default: - h.logger.Warnf("unsupported request type: %s, want reply: %v", req.Type, req.WantReply) - if req.WantReply { - req.Reply(false, nil) - } - } - } - }() - - go func() { - for newChannel := range chans { - // Check the type of channel - t := newChannel.ChannelType() - switch t { - case DirectForwardRequest: - channel, requests, err := newChannel.Accept() - if err != nil { - h.logger.Warnf("could not accept channel: %s", err.Error()) - continue - } - p := directForward{} - ssh.Unmarshal(newChannel.ExtraData(), &p) - - h.logger.Debug(p.String()) - - if p.Host1 == "" { - p.Host1 = "" - } - - go ssh.DiscardRequests(requests) - go h.directPortForwardChannel(ctx, channel, net.JoinHostPort(p.Host1, strconv.Itoa(int(p.Port1)))) - default: - h.logger.Warnf("unsupported channel type: %s", t) - newChannel.Reject(ssh.Prohibited, fmt.Sprintf("unsupported channel type: %s", t)) - } - } - }() - - conn.Wait() -} - -func (h *forwardHandler) directPortForwardChannel(ctx context.Context, channel ssh.Channel, raddr string) { - defer channel.Close() - - // log.Logf("[ssh-tcp] %s - %s", h.options.Node.Addr, raddr) - - /* - if !Can("tcp", raddr, h.options.Whitelist, h.options.Blacklist) { - log.Logf("[ssh-tcp] Unauthorized to tcp connect to %s", raddr) - return - } - */ - - if h.options.Bypass != nil && h.options.Bypass.Contains(raddr) { - h.logger.Infof("bypass %s", raddr) - return - } - - conn, err := h.router.Dial(ctx, "tcp", raddr) - if err != nil { - return - } - defer conn.Close() - - t := time.Now() - h.logger.Infof("%s <-> %s", conn.LocalAddr(), conn.RemoteAddr()) - handler.Transport(conn, channel) - h.logger.WithFields(map[string]interface{}{ - "duration": time.Since(t), - }).Infof("%s >-< %s", conn.LocalAddr(), conn.RemoteAddr()) -} - -// directForward is structure for RFC 4254 7.2 - can be used for "forwarded-tcpip" and "direct-tcpip" -type directForward struct { - Host1 string - Port1 uint32 - Host2 string - Port2 uint32 -} - -func (p directForward) String() string { - return fmt.Sprintf("%s:%d -> %s:%d", p.Host2, p.Port2, p.Host1, p.Port1) -} - -func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) { - host, portString, err := net.SplitHostPort(addr.String()) - if err != nil { - return - } - port, err = strconv.Atoi(portString) - return -} - -// tcpipForward is structure for RFC 4254 7.1 "tcpip-forward" request -type tcpipForward struct { - Host string - Port uint32 -} - -func (h *forwardHandler) tcpipForwardRequest(sshConn ssh.Conn, req *ssh.Request, quit <-chan struct{}) { - t := tcpipForward{} - ssh.Unmarshal(req.Payload, &t) - - addr := net.JoinHostPort(t.Host, strconv.Itoa(int(t.Port))) - - /* - if !Can("rtcp", addr, h.options.Whitelist, h.options.Blacklist) { - log.Logf("[ssh-rtcp] Unauthorized to tcp bind to %s", addr) - req.Reply(false, nil) - return - } - */ - - // tie to the client connection - ln, err := net.Listen("tcp", addr) - if err != nil { - h.logger.Error(err) - req.Reply(false, nil) - return - } - defer ln.Close() - - h.logger.Debugf("bind on %s OK", ln.Addr()) - - err = func() error { - if t.Port == 0 && req.WantReply { // Client sent port 0. let them know which port is actually being used - _, port, err := getHostPortFromAddr(ln.Addr()) - if err != nil { - return err - } - var b [4]byte - binary.BigEndian.PutUint32(b[:], uint32(port)) - t.Port = uint32(port) - return req.Reply(true, b[:]) - } - return req.Reply(true, nil) - }() - if err != nil { - h.logger.Error(err) - return - } - - go func() { - for { - conn, err := ln.Accept() - if err != nil { // Unable to accept new connection - listener is likely closed - return - } - - go func(conn net.Conn) { - defer conn.Close() - - p := directForward{} - var err error - - var portnum int - p.Host1 = t.Host - p.Port1 = t.Port - p.Host2, portnum, err = getHostPortFromAddr(conn.RemoteAddr()) - if err != nil { - return - } - - p.Port2 = uint32(portnum) - ch, reqs, err := sshConn.OpenChannel(ForwardedTCPReturnRequest, ssh.Marshal(p)) - if err != nil { - h.logger.Error("open forwarded channel: ", err) - return - } - defer ch.Close() - go ssh.DiscardRequests(reqs) - - t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), conn.LocalAddr()) - handler.Transport(ch, conn) - h.logger.WithFields(map[string]interface{}{ - "duration": time.Since(t), - }).Infof("%s >-< %s", conn.RemoteAddr(), conn.LocalAddr()) - }(conn) - } - }() - - <-quit -} diff --git a/pkg/handler/http/handler.go b/pkg/handler/http/handler.go index 6d216a9..c3bbb49 100644 --- a/pkg/handler/http/handler.go +++ b/pkg/handler/http/handler.go @@ -32,7 +32,6 @@ func init() { type httpHandler struct { router *chain.Router authenticator auth.Authenticator - logger logger.Logger md metadata options handler.Options } @@ -61,7 +60,6 @@ func (h *httpHandler) Init(md md.Metadata) error { Hosts: h.options.Hosts, Logger: h.options.Logger, } - h.logger = h.options.Logger return nil } @@ -70,28 +68,28 @@ func (h *httpHandler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() start := time.Now() - h.logger = h.logger.WithFields(map[string]interface{}{ + log := h.options.Logger.WithFields(map[string]interface{}{ "remote": conn.RemoteAddr().String(), "local": conn.LocalAddr().String(), }) - h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) defer func() { - h.logger.WithFields(map[string]interface{}{ + log.WithFields(map[string]interface{}{ "duration": time.Since(start), }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() req, err := http.ReadRequest(bufio.NewReader(conn)) if err != nil { - h.logger.Error(err) + log.Error(err) return } defer req.Body.Close() - h.handleRequest(ctx, conn, req) + h.handleRequest(ctx, conn, req, log) } -func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *http.Request) { +func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *http.Request, log logger.Logger) { if req == nil { return } @@ -129,16 +127,16 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt fields := map[string]interface{}{ "dst": addr, } - if u, _, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization")); u != "" { + if u, _, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization"), log); u != "" { fields["user"] = u } - h.logger = h.logger.WithFields(fields) + log = log.WithFields(fields) - if h.logger.IsLevelEnabled(logger.DebugLevel) { + if log.IsLevelEnabled(logger.DebugLevel) { dump, _ := httputil.DumpRequest(req, false) - h.logger.Debug(string(dump)) + log.Debug(string(dump)) } - h.logger.Infof("%s >> %s", conn.RemoteAddr(), addr) + log.Infof("%s >> %s", conn.RemoteAddr(), addr) resp := &http.Response{ ProtoMajor: 1, @@ -152,22 +150,22 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt if h.options.Bypass != nil && h.options.Bypass.Contains(addr) { resp.StatusCode = http.StatusForbidden - if h.logger.IsLevelEnabled(logger.DebugLevel) { + if log.IsLevelEnabled(logger.DebugLevel) { dump, _ := httputil.DumpResponse(resp, false) - h.logger.Debug(string(dump)) + log.Debug(string(dump)) } - h.logger.Info("bypass: ", addr) + log.Info("bypass: ", addr) resp.Write(conn) return } - if !h.authenticate(conn, req, resp) { + if !h.authenticate(conn, req, resp, log) { return } if network == "udp" { - h.handleUDP(ctx, conn, network, req.Host) + h.handleUDP(ctx, conn, network, req.Host, log) return } @@ -176,9 +174,9 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt resp.StatusCode = http.StatusBadRequest resp.Write(conn) - if h.logger.IsLevelEnabled(logger.DebugLevel) { + if log.IsLevelEnabled(logger.DebugLevel) { dump, _ := httputil.DumpResponse(resp, false) - h.logger.Debug(string(dump)) + log.Debug(string(dump)) } return @@ -191,9 +189,9 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt resp.StatusCode = http.StatusServiceUnavailable resp.Write(conn) - if h.logger.IsLevelEnabled(logger.DebugLevel) { + if log.IsLevelEnabled(logger.DebugLevel) { dump, _ := httputil.DumpResponse(resp, false) - h.logger.Debug(string(dump)) + log.Debug(string(dump)) } return } @@ -203,30 +201,28 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt resp.StatusCode = http.StatusOK resp.Status = "200 Connection established" - if h.logger.IsLevelEnabled(logger.DebugLevel) { + if log.IsLevelEnabled(logger.DebugLevel) { dump, _ := httputil.DumpResponse(resp, false) - h.logger.Debug(string(dump)) + log.Debug(string(dump)) } if err = resp.Write(conn); err != nil { - h.logger.Error(err) + log.Error(err) return } } else { req.Header.Del("Proxy-Connection") if err = req.Write(cc); err != nil { - h.logger.Error(err) + log.Error(err) return } } start := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr) + log.Infof("%s <-> %s", conn.RemoteAddr(), addr) handler.Transport(conn, cc) - h.logger. - WithFields(map[string]interface{}{ - "duration": time.Since(start), - }). - Infof("%s >-< %s", conn.RemoteAddr(), addr) + log.WithFields(map[string]interface{}{ + "duration": time.Since(start), + }).Infof("%s >-< %s", conn.RemoteAddr(), addr) } func (h *httpHandler) decodeServerName(s string) (string, error) { @@ -247,7 +243,7 @@ func (h *httpHandler) decodeServerName(s string) (string, error) { return string(v), nil } -func (h *httpHandler) basicProxyAuth(proxyAuth string) (username, password string, ok bool) { +func (h *httpHandler) basicProxyAuth(proxyAuth string, log logger.Logger) (username, password string, ok bool) { if proxyAuth == "" { return } @@ -268,8 +264,8 @@ func (h *httpHandler) basicProxyAuth(proxyAuth string) (username, password strin return cs[:s], cs[s+1:], true } -func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.Response) (ok bool) { - u, p, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization")) +func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.Response, log logger.Logger) (ok bool) { + u, p, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization"), log) if h.authenticator == nil || h.authenticator.Authenticate(u, p) { return true } @@ -289,7 +285,7 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http. } r, err := http.Get(url) if err != nil { - h.logger.Error(err) + log.Error(err) break } resp = r @@ -297,7 +293,7 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http. case "host": cc, err := net.Dial("tcp", pr.Value) if err != nil { - h.logger.Error(err) + log.Error(err) break } defer cc.Close() @@ -333,7 +329,7 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http. resp.Header.Add("Proxy-Connection", "close") } - h.logger.Info("proxy authentication required") + log.Info("proxy authentication required") } else { resp.Header.Set("Server", "nginx/1.20.1") resp.Header.Set("Date", time.Now().Format(http.TimeFormat)) @@ -342,9 +338,9 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http. } } - if h.logger.IsLevelEnabled(logger.DebugLevel) { + if log.IsLevelEnabled(logger.DebugLevel) { dump, _ := httputil.DumpResponse(resp, false) - h.logger.Debug(string(dump)) + log.Debug(string(dump)) } resp.Write(conn) diff --git a/pkg/handler/http/udp.go b/pkg/handler/http/udp.go index c35991d..11f910d 100644 --- a/pkg/handler/http/udp.go +++ b/pkg/handler/http/udp.go @@ -12,8 +12,8 @@ import ( "github.com/go-gost/gost/pkg/logger" ) -func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, address string) { - h.logger = h.logger.WithFields(map[string]interface{}{ +func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { + log = log.WithFields(map[string]interface{}{ "cmd": "udp", }) @@ -30,49 +30,47 @@ func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, add resp.StatusCode = http.StatusForbidden resp.Write(conn) - if h.logger.IsLevelEnabled(logger.DebugLevel) { + if log.IsLevelEnabled(logger.DebugLevel) { dump, _ := httputil.DumpResponse(resp, false) - h.logger.Debug(string(dump)) + log.Debug(string(dump)) } - h.logger.Error("UDP relay is diabled") + log.Error("UDP relay is diabled") return } resp.StatusCode = http.StatusOK - if h.logger.IsLevelEnabled(logger.DebugLevel) { + if log.IsLevelEnabled(logger.DebugLevel) { dump, _ := httputil.DumpResponse(resp, false) - h.logger.Debug(string(dump)) + log.Debug(string(dump)) } if err := resp.Write(conn); err != nil { - h.logger.Error(err) + log.Error(err) return } // obtain a udp connection c, err := h.router.Dial(ctx, "udp", "") // UDP association if err != nil { - h.logger.Error(err) + log.Error(err) return } defer c.Close() pc, ok := c.(net.PacketConn) if !ok { - h.logger.Errorf("wrong connection type") + log.Errorf("wrong connection type") return } relay := handler.NewUDPRelay(socks.UDPTunServerConn(conn), pc). WithBypass(h.options.Bypass). - WithLogger(h.logger) + WithLogger(log) t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr()) + log.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr()) relay.Run() - h.logger. - WithFields(map[string]interface{}{ - "duration": time.Since(t), - }). - Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) + log.WithFields(map[string]interface{}{ + "duration": time.Since(t), + }).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) } diff --git a/pkg/handler/http2/handler.go b/pkg/handler/http2/handler.go index 0e78bcb..6600272 100644 --- a/pkg/handler/http2/handler.go +++ b/pkg/handler/http2/handler.go @@ -35,7 +35,6 @@ func init() { type http2Handler struct { router *chain.Router authenticator auth.Authenticator - logger logger.Logger md metadata options handler.Options } @@ -64,7 +63,6 @@ func (h *http2Handler) Init(md md.Metadata) error { Hosts: h.options.Hosts, Logger: h.options.Logger, } - h.logger = h.options.Logger return nil } @@ -72,29 +70,29 @@ func (h *http2Handler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() start := time.Now() - h.logger = h.logger.WithFields(map[string]interface{}{ + log := h.options.Logger.WithFields(map[string]interface{}{ "remote": conn.RemoteAddr().String(), "local": conn.LocalAddr().String(), }) - h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) defer func() { - h.logger.WithFields(map[string]interface{}{ + log.WithFields(map[string]interface{}{ "duration": time.Since(start), }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() cc, ok := conn.(*http2_util.ServerConn) if !ok { - h.logger.Error("wrong connection type") + log.Error("wrong connection type") return } - h.roundTrip(ctx, cc.Writer(), cc.Request()) + h.roundTrip(ctx, cc.Writer(), cc.Request(), log) } // NOTE: there is an issue (golang/go#43989) will cause the client hangs // when server returns an non-200 status code, // May be fixed in go1.18. -func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req *http.Request) { +func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req *http.Request, log logger.Logger) { // Try to get the actual host. // Compatible with GOST 2.x. if v := req.Header.Get("Gost-Target"); v != "" { @@ -122,21 +120,21 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req if u, _, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization")); u != "" { fields["user"] = u } - h.logger = h.logger.WithFields(fields) + log = log.WithFields(fields) - if h.logger.IsLevelEnabled(logger.DebugLevel) { + if log.IsLevelEnabled(logger.DebugLevel) { dump, _ := httputil.DumpRequest(req, false) - h.logger.Debug(string(dump)) + log.Debug(string(dump)) } - h.logger.Infof("%s >> %s", req.RemoteAddr, addr) + log.Infof("%s >> %s", req.RemoteAddr, addr) - if h.md.proxyAgent != "" { - w.Header().Set("Proxy-Agent", h.md.proxyAgent) + for k := range h.md.header { + w.Header().Set(k, h.md.header.Get(k)) } if h.options.Bypass != nil && h.options.Bypass.Contains(addr) { w.WriteHeader(http.StatusForbidden) - h.logger.Info("bypass: ", addr) + log.Info("bypass: ", addr) return } @@ -147,7 +145,7 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req Body: ioutil.NopCloser(bytes.NewReader([]byte{})), } - if !h.authenticate(w, req, resp) { + if !h.authenticate(w, req, resp, log) { return } @@ -157,7 +155,7 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req cc, err := h.router.Dial(ctx, "tcp", addr) if err != nil { - h.logger.Error(err) + log.Error(err) w.WriteHeader(http.StatusServiceUnavailable) return } @@ -174,30 +172,28 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req // we take over the underly connection conn, _, err := hj.Hijack() if err != nil { - h.logger.Error(err) + log.Error(err) w.WriteHeader(http.StatusInternalServerError) return } defer conn.Close() start := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr) + log.Infof("%s <-> %s", conn.RemoteAddr(), addr) handler.Transport(conn, cc) - h.logger. - WithFields(map[string]interface{}{ - "duration": time.Since(start), - }). - Infof("%s >-< %s", conn.RemoteAddr(), addr) + log.WithFields(map[string]interface{}{ + "duration": time.Since(start), + }).Infof("%s >-< %s", conn.RemoteAddr(), addr) + + return } start := time.Now() - h.logger.Infof("%s <-> %s", req.RemoteAddr, addr) + log.Infof("%s <-> %s", req.RemoteAddr, addr) handler.Transport(&readWriter{r: req.Body, w: flushWriter{w}}, cc) - h.logger. - WithFields(map[string]interface{}{ - "duration": time.Since(start), - }). - Infof("%s >-< %s", req.RemoteAddr, addr) + log.WithFields(map[string]interface{}{ + "duration": time.Since(start), + }).Infof("%s >-< %s", req.RemoteAddr, addr) return } } @@ -241,7 +237,7 @@ func (h *http2Handler) basicProxyAuth(proxyAuth string) (username, password stri return cs[:s], cs[s+1:], true } -func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp *http.Response) (ok bool) { +func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp *http.Response, log logger.Logger) (ok bool) { u, p, _ := h.basicProxyAuth(r.Header.Get("Proxy-Authorization")) if h.authenticator == nil || h.authenticator.Authenticate(u, p) { return true @@ -261,7 +257,7 @@ func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp } r, err := http.Get(url) if err != nil { - h.logger.Error(err) + log.Error(err) break } resp = r @@ -269,13 +265,13 @@ func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp case "host": cc, err := net.Dial("tcp", pr.Value) if err != nil { - h.logger.Error(err) + log.Error(err) break } defer cc.Close() if err := h.forwardRequest(w, r, cc); err != nil { - h.logger.Error(err) + log.Error(err) } return case "file": @@ -303,7 +299,7 @@ func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp resp.Header.Add("Proxy-Connection", "close") } - h.logger.Info("proxy authentication required") + log.Info("proxy authentication required") } else { resp.Header = http.Header{} resp.Header.Set("Server", "nginx/1.20.1") @@ -313,9 +309,9 @@ func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp } } - if h.logger.IsLevelEnabled(logger.DebugLevel) { + if log.IsLevelEnabled(logger.DebugLevel) { dump, _ := httputil.DumpResponse(resp, false) - h.logger.Debug(string(dump)) + log.Debug(string(dump)) } h.writeResponse(w, resp) diff --git a/pkg/handler/http2/metadata.go b/pkg/handler/http2/metadata.go index 7610b79..b5305ec 100644 --- a/pkg/handler/http2/metadata.go +++ b/pkg/handler/http2/metadata.go @@ -1,28 +1,31 @@ package http2 import ( + "net/http" "strings" mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { - proxyAgent string probeResistance *probeResistance - sni bool - enableUDP bool + header http.Header } func (h *http2Handler) parseMetadata(md mdata.Metadata) error { const ( - proxyAgent = "proxyAgent" + header = "header" probeResistKey = "probeResistance" knock = "knock" - sni = "sni" - enableUDP = "udp" ) - h.md.proxyAgent = mdata.GetString(md, proxyAgent) + if m := mdata.GetStringMapString(md, header); len(m) > 0 { + hd := http.Header{} + for k, v := range m { + hd.Add(k, v) + } + h.md.header = hd + } if v := mdata.GetString(md, probeResistKey); v != "" { if ss := strings.SplitN(v, ":", 2); len(ss) == 2 { @@ -33,8 +36,6 @@ func (h *http2Handler) parseMetadata(md mdata.Metadata) error { } } } - h.md.sni = mdata.GetBool(md, sni) - h.md.enableUDP = mdata.GetBool(md, enableUDP) return nil } diff --git a/pkg/handler/redirect/handler.go b/pkg/handler/redirect/handler.go index 73878e4..cd1b880 100644 --- a/pkg/handler/redirect/handler.go +++ b/pkg/handler/redirect/handler.go @@ -8,7 +8,6 @@ import ( "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/handler" - "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" ) @@ -22,7 +21,6 @@ func init() { type redirectHandler struct { router *chain.Router - logger logger.Logger md metadata options handler.Options } @@ -50,7 +48,6 @@ func (h *redirectHandler) Init(md md.Metadata) (err error) { Hosts: h.options.Hosts, Logger: h.options.Logger, } - h.logger = h.options.Logger return } @@ -59,14 +56,14 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() start := time.Now() - h.logger = h.logger.WithFields(map[string]interface{}{ + log := h.options.Logger.WithFields(map[string]interface{}{ "remote": conn.RemoteAddr().String(), "local": conn.LocalAddr().String(), }) - h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) defer func() { - h.logger.WithFields(map[string]interface{}{ + log.WithFields(map[string]interface{}{ "duration": time.Since(start), }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() @@ -83,35 +80,33 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn) { if network == "tcp" { dstAddr, conn, err = h.getOriginalDstAddr(conn) if err != nil { - h.logger.Error(err) + log.Error(err) return } } - h.logger = h.logger.WithFields(map[string]interface{}{ + log = log.WithFields(map[string]interface{}{ "dst": fmt.Sprintf("%s/%s", dstAddr, network), }) - h.logger.Infof("%s >> %s", conn.RemoteAddr(), dstAddr) + log.Infof("%s >> %s", conn.RemoteAddr(), dstAddr) if h.options.Bypass != nil && h.options.Bypass.Contains(dstAddr.String()) { - h.logger.Info("bypass: ", dstAddr) + log.Info("bypass: ", dstAddr) return } cc, err := h.router.Dial(ctx, network, dstAddr.String()) if err != nil { - h.logger.Error(err) + log.Error(err) return } defer cc.Close() t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), dstAddr) + log.Infof("%s <-> %s", conn.RemoteAddr(), dstAddr) handler.Transport(conn, cc) - h.logger. - WithFields(map[string]interface{}{ - "duration": time.Since(t), - }). - Infof("%s >-< %s", conn.RemoteAddr(), dstAddr) + log.WithFields(map[string]interface{}{ + "duration": time.Since(t), + }).Infof("%s >-< %s", conn.RemoteAddr(), dstAddr) } diff --git a/pkg/handler/redirect/handler_linux.go b/pkg/handler/redirect/handler_linux.go index 925bc80..fb9ee42 100644 --- a/pkg/handler/redirect/handler_linux.go +++ b/pkg/handler/redirect/handler_linux.go @@ -13,7 +13,6 @@ func (h *redirectHandler) getOriginalDstAddr(conn net.Conn) (addr net.Addr, c ne tc, ok := conn.(*net.TCPConn) if !ok { err = errors.New("wrong connection type, must be TCP") - h.logger.Error(err) return } diff --git a/pkg/handler/relay/bind.go b/pkg/handler/relay/bind.go index 4df850a..0f97e78 100644 --- a/pkg/handler/relay/bind.go +++ b/pkg/handler/relay/bind.go @@ -10,16 +10,17 @@ import ( "github.com/go-gost/gost/pkg/common/util/mux" "github.com/go-gost/gost/pkg/common/util/socks" "github.com/go-gost/gost/pkg/handler" + "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/relay" ) -func (h *relayHandler) handleBind(ctx context.Context, conn net.Conn, network, address string) { - h.logger = h.logger.WithFields(map[string]interface{}{ +func (h *relayHandler) handleBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { + log = log.WithFields(map[string]interface{}{ "dst": fmt.Sprintf("%s/%s", address, network), "cmd": "bind", }) - h.logger.Infof("%s >> %s", conn.RemoteAddr(), address) + log.Infof("%s >> %s", conn.RemoteAddr(), address) resp := relay.Response{ Version: relay.Version1, @@ -29,18 +30,18 @@ func (h *relayHandler) handleBind(ctx context.Context, conn net.Conn, network, a if !h.md.enableBind { resp.Status = relay.StatusForbidden resp.WriteTo(conn) - h.logger.Error("BIND is diabled") + log.Error("BIND is diabled") return } if network == "tcp" { - h.bindTCP(ctx, conn, network, address) + h.bindTCP(ctx, conn, network, address, log) } else { - h.bindUDP(ctx, conn, network, address) + h.bindUDP(ctx, conn, network, address, log) } } -func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, address string) { +func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { resp := relay.Response{ Version: relay.Version1, Status: relay.StatusOK, @@ -48,7 +49,7 @@ func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, addr ln, err := net.Listen(network, address) // strict mode: if the port already in use, it will return error if err != nil { - h.logger.Error(err) + log.Error(err) resp.Status = relay.StatusServiceUnavailable resp.WriteTo(conn) return @@ -57,7 +58,7 @@ func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, addr af := &relay.AddrFeature{} err = af.ParseFrom(ln.Addr().String()) if err != nil { - h.logger.Warn(err) + log.Warn(err) } // Issue: may not reachable when host has multi-interface @@ -65,20 +66,20 @@ func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, addr af.AType = relay.AddrIPv4 resp.Features = append(resp.Features, af) if _, err := resp.WriteTo(conn); err != nil { - h.logger.Error(err) + log.Error(err) ln.Close() return } - h.logger = h.logger.WithFields(map[string]interface{}{ + log = log.WithFields(map[string]interface{}{ "bind": fmt.Sprintf("%s/%s", ln.Addr(), ln.Addr().Network()), }) - h.logger.Debugf("bind on %s OK", ln.Addr()) + log.Debugf("bind on %s OK", ln.Addr()) - h.serveTCPBind(ctx, conn, ln) + h.serveTCPBind(ctx, conn, ln, log) } -func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, address string) { +func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { resp := relay.Response{ Version: relay.Version1, Status: relay.StatusOK, @@ -87,7 +88,7 @@ func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, addr bindAddr, _ := net.ResolveUDPAddr(network, address) pc, err := net.ListenUDP(network, bindAddr) if err != nil { - h.logger.Error(err) + log.Error(err) return } defer pc.Close() @@ -95,7 +96,7 @@ func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, addr af := &relay.AddrFeature{} err = af.ParseFrom(pc.LocalAddr().String()) if err != nil { - h.logger.Warn(err) + log.Warn(err) } // Issue: may not reachable when host has multi-interface @@ -103,33 +104,32 @@ func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, addr af.AType = relay.AddrIPv4 resp.Features = append(resp.Features, af) if _, err := resp.WriteTo(conn); err != nil { - h.logger.Error(err) + log.Error(err) return } - h.logger = h.logger.WithFields(map[string]interface{}{ + log = log.WithFields(map[string]interface{}{ "bind": pc.LocalAddr().String(), }) - h.logger.Debugf("bind on %s OK", pc.LocalAddr()) + log.Debugf("bind on %s OK", pc.LocalAddr()) t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr()) + log.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr()) h.tunnelServerUDP( socks.UDPTunServerConn(conn), pc, + log, ) - h.logger. - WithFields(map[string]interface{}{ - "duration": time.Since(t), - }). - Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) + log.WithFields(map[string]interface{}{ + "duration": time.Since(t), + }).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) } -func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.Listener) { +func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.Listener, log logger.Logger) { // Upgrade connection to multiplex stream. session, err := mux.ClientSession(conn) if err != nil { - h.logger.Error(err) + log.Error(err) return } defer session.Close() @@ -139,7 +139,7 @@ func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.L for { conn, err := session.Accept() if err != nil { - h.logger.Error(err) + log.Error(err) return } conn.Close() // we do not handle incoming connections. @@ -149,17 +149,22 @@ func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.L for { rc, err := ln.Accept() if err != nil { - h.logger.Error(err) + log.Error(err) return } - h.logger.Debugf("peer %s accepted", rc.RemoteAddr()) + log.Debugf("peer %s accepted", rc.RemoteAddr()) go func(c net.Conn) { defer c.Close() + log = log.WithFields(map[string]interface{}{ + "local": ln.Addr().String(), + "remote": c.RemoteAddr().String(), + }) + sc, err := session.GetConn() if err != nil { - h.logger.Error(err) + log.Error(err) return } defer sc.Close() @@ -172,21 +177,20 @@ func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.L Features: []relay.Feature{af}, } if _, err := resp.WriteTo(sc); err != nil { - h.logger.Error(err) + log.Error(err) return } t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), c.RemoteAddr().String()) + log.Infof("%s <-> %s", c.LocalAddr(), c.RemoteAddr()) handler.Transport(sc, c) - h.logger. - WithFields(map[string]interface{}{"duration": time.Since(t)}). - Infof("%s >-< %s", conn.RemoteAddr(), c.RemoteAddr().String()) + log.WithFields(map[string]interface{}{"duration": time.Since(t)}). + Infof("%s >-< %s", c.LocalAddr(), c.RemoteAddr()) }(rc) } } -func (h *relayHandler) tunnelServerUDP(tunnel, c net.PacketConn) (err error) { +func (h *relayHandler) tunnelServerUDP(tunnel, c net.PacketConn, log logger.Logger) (err error) { bufSize := h.md.udpBufferSize errc := make(chan error, 2) @@ -202,7 +206,7 @@ func (h *relayHandler) tunnelServerUDP(tunnel, c net.PacketConn) (err error) { } if h.options.Bypass != nil && h.options.Bypass.Contains(raddr.String()) { - h.logger.Warn("bypass: ", raddr) + log.Warn("bypass: ", raddr) return nil } @@ -210,7 +214,7 @@ func (h *relayHandler) tunnelServerUDP(tunnel, c net.PacketConn) (err error) { return err } - h.logger.Debugf("%s >>> %s data: %d", + log.Debugf("%s >>> %s data: %d", c.LocalAddr(), raddr, n) return nil @@ -235,14 +239,14 @@ func (h *relayHandler) tunnelServerUDP(tunnel, c net.PacketConn) (err error) { } if h.options.Bypass != nil && h.options.Bypass.Contains(raddr.String()) { - h.logger.Warn("bypass: ", raddr) + log.Warn("bypass: ", raddr) return nil } if _, err := tunnel.WriteTo((*b)[:n], raddr); err != nil { return err } - h.logger.Debugf("%s <<< %s data: %d", + log.Debugf("%s <<< %s data: %d", c.LocalAddr(), raddr, n) return nil diff --git a/pkg/handler/relay/connect.go b/pkg/handler/relay/connect.go index 5ce5e71..ca116d3 100644 --- a/pkg/handler/relay/connect.go +++ b/pkg/handler/relay/connect.go @@ -7,16 +7,17 @@ import ( "time" "github.com/go-gost/gost/pkg/handler" + "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/relay" ) -func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network, address string) { - h.logger = h.logger.WithFields(map[string]interface{}{ +func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { + log = log.WithFields(map[string]interface{}{ "dst": fmt.Sprintf("%s/%s", address, network), "cmd": "connect", }) - h.logger.Infof("%s >> %s", conn.RemoteAddr(), address) + log.Infof("%s >> %s", conn.RemoteAddr(), address) resp := relay.Response{ Version: relay.Version1, @@ -26,12 +27,12 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network if address == "" { resp.Status = relay.StatusBadRequest resp.WriteTo(conn) - h.logger.Error("target not specified") + log.Error("target not specified") return } if h.options.Bypass != nil && h.options.Bypass.Contains(address) { - h.logger.Info("bypass: ", address) + log.Info("bypass: ", address) resp.Status = relay.StatusForbidden resp.WriteTo(conn) return @@ -47,7 +48,7 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network if h.md.noDelay { if _, err := resp.WriteTo(conn); err != nil { - h.logger.Error(err) + log.Error(err) return } } @@ -78,11 +79,9 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network } t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), address) + log.Infof("%s <-> %s", conn.RemoteAddr(), address) handler.Transport(conn, cc) - h.logger. - WithFields(map[string]interface{}{ - "duration": time.Since(t), - }). - Infof("%s >-< %s", conn.RemoteAddr(), address) + log.WithFields(map[string]interface{}{ + "duration": time.Since(t), + }).Infof("%s >-< %s", conn.RemoteAddr(), address) } diff --git a/pkg/handler/relay/forward.go b/pkg/handler/relay/forward.go index aa44e84..11f0b3d 100644 --- a/pkg/handler/relay/forward.go +++ b/pkg/handler/relay/forward.go @@ -7,10 +7,11 @@ import ( "time" "github.com/go-gost/gost/pkg/handler" + "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/relay" ) -func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network string) { +func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network string, log logger.Logger) { resp := relay.Response{ Version: relay.Version1, Status: relay.StatusOK, @@ -19,15 +20,16 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network if target == nil { resp.Status = relay.StatusServiceUnavailable resp.WriteTo(conn) - h.logger.Error("no target available") + log.Error("no target available") return } - h.logger = h.logger.WithFields(map[string]interface{}{ + log = log.WithFields(map[string]interface{}{ "dst": fmt.Sprintf("%s/%s", target.Addr(), network), + "cmd": "forward", }) - h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) + log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) cc, err := h.router.Dial(ctx, network, target.Addr()) if err != nil { @@ -37,7 +39,7 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network resp.Status = relay.StatusHostUnreachable resp.WriteTo(conn) - h.logger.Error(err) + log.Error(err) return } @@ -46,7 +48,7 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network if h.md.noDelay { if _, err := resp.WriteTo(conn); err != nil { - h.logger.Error(err) + log.Error(err) return } } @@ -77,11 +79,9 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network } t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr()) + log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr()) handler.Transport(conn, cc) - h.logger. - WithFields(map[string]interface{}{ - "duration": time.Since(t), - }). - Infof("%s >-< %s", conn.RemoteAddr(), target.Addr()) + log.WithFields(map[string]interface{}{ + "duration": time.Since(t), + }).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr()) } diff --git a/pkg/handler/relay/handler.go b/pkg/handler/relay/handler.go index 9bb943f..7c8a158 100644 --- a/pkg/handler/relay/handler.go +++ b/pkg/handler/relay/handler.go @@ -10,7 +10,6 @@ import ( "github.com/go-gost/gost/pkg/chain" auth_util "github.com/go-gost/gost/pkg/common/util/auth" "github.com/go-gost/gost/pkg/handler" - "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" "github.com/go-gost/relay" @@ -24,7 +23,6 @@ type relayHandler struct { group *chain.NodeGroup router *chain.Router authenticator auth.Authenticator - logger logger.Logger md metadata options handler.Options } @@ -53,7 +51,6 @@ func (h *relayHandler) Init(md md.Metadata) (err error) { Hosts: h.options.Hosts, Logger: h.options.Logger, } - h.logger = h.options.Logger return nil } @@ -66,14 +63,14 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() start := time.Now() - h.logger = h.logger.WithFields(map[string]interface{}{ + log := h.options.Logger.WithFields(map[string]interface{}{ "remote": conn.RemoteAddr().String(), "local": conn.LocalAddr().String(), }) - h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) defer func() { - h.logger.WithFields(map[string]interface{}{ + log.WithFields(map[string]interface{}{ "duration": time.Since(start), }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() @@ -84,14 +81,14 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) { req := relay.Request{} if _, err := req.ReadFrom(conn); err != nil { - h.logger.Error(err) + log.Error(err) return } conn.SetReadDeadline(time.Time{}) if req.Version != relay.Version1 { - h.logger.Error("bad version") + log.Error("bad version") return } @@ -109,7 +106,7 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) { } if user != "" { - h.logger = h.logger.WithFields(map[string]interface{}{"user": user}) + log = log.WithFields(map[string]interface{}{"user": user}) } resp := relay.Response{ @@ -119,7 +116,7 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) { if h.authenticator != nil && !h.authenticator.Authenticate(user, pass) { resp.Status = relay.StatusUnauthorized resp.WriteTo(conn) - h.logger.Error("unauthorized") + log.Error("unauthorized") return } @@ -132,18 +129,18 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) { if address != "" { resp.Status = relay.StatusForbidden resp.WriteTo(conn) - h.logger.Error("forward mode, connect is forbidden") + log.Error("forward mode, connect is forbidden") return } // forward mode - h.handleForward(ctx, conn, network) + h.handleForward(ctx, conn, network, log) return } switch req.Flags & relay.CmdMask { case 0, relay.CONNECT: - h.handleConnect(ctx, conn, network, address) + h.handleConnect(ctx, conn, network, address, log) case relay.BIND: - h.handleBind(ctx, conn, network, address) + h.handleBind(ctx, conn, network, address, log) } } diff --git a/pkg/handler/sni/handler.go b/pkg/handler/sni/handler.go index 5181446..1d3754a 100644 --- a/pkg/handler/sni/handler.go +++ b/pkg/handler/sni/handler.go @@ -14,7 +14,6 @@ import ( "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/common/bufpool" "github.com/go-gost/gost/pkg/handler" - "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" dissector "github.com/go-gost/tls-dissector" @@ -27,7 +26,6 @@ func init() { type sniHandler struct { httpHandler handler.Handler router *chain.Router - logger logger.Logger md metadata options handler.Options } @@ -38,19 +36,13 @@ func NewHandler(opts ...handler.Option) handler.Handler { opt(&options) } - log := options.Logger - if log == nil { - log = logger.Default() - } - h := &sniHandler{ options: options, - logger: log, } if f := registry.GetHandler("http"); f != nil { v := append(opts, - handler.LoggerOption(log.WithFields(map[string]interface{}{"type": "http"}))) + handler.LoggerOption(h.options.Logger.WithFields(map[string]interface{}{"type": "http"}))) h.httpHandler = f(v...) } @@ -85,21 +77,21 @@ func (h *sniHandler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() start := time.Now() - h.logger = h.logger.WithFields(map[string]interface{}{ + log := h.options.Logger.WithFields(map[string]interface{}{ "remote": conn.RemoteAddr().String(), "local": conn.LocalAddr().String(), }) - h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) defer func() { - h.logger.WithFields(map[string]interface{}{ + log.WithFields(map[string]interface{}{ "duration": time.Since(start), }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() var hdr [dissector.RecordHeaderLen]byte if _, err := io.ReadFull(conn, hdr[:]); err != nil { - h.logger.Error(err) + log.Error(err) return } @@ -121,25 +113,25 @@ func (h *sniHandler) Handle(ctx context.Context, conn net.Conn) { buf := bufpool.Get(int(length) + dissector.RecordHeaderLen) defer bufpool.Put(buf) if _, err := io.ReadFull(conn, (*buf)[dissector.RecordHeaderLen:]); err != nil { - h.logger.Error(err) + log.Error(err) return } copy(*buf, hdr[:]) opaque, host, err := h.decodeHost(bytes.NewReader(*buf)) if err != nil { - h.logger.Error(err) + log.Error(err) return } target := net.JoinHostPort(host, "443") - h.logger = h.logger.WithFields(map[string]interface{}{ + log = log.WithFields(map[string]interface{}{ "dst": target, }) - h.logger.Infof("%s >> %s", conn.RemoteAddr(), target) + log.Infof("%s >> %s", conn.RemoteAddr(), target) if h.options.Bypass != nil && h.options.Bypass.Contains(target) { - h.logger.Info("bypass: ", target) + log.Info("bypass: ", target) return } @@ -150,18 +142,16 @@ func (h *sniHandler) Handle(ctx context.Context, conn net.Conn) { defer cc.Close() if _, err := cc.Write(opaque); err != nil { - h.logger.Error(err) + log.Error(err) return } t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), target) + log.Infof("%s <-> %s", conn.RemoteAddr(), target) handler.Transport(conn, cc) - h.logger. - WithFields(map[string]interface{}{ - "duration": time.Since(t), - }). - Infof("%s >-< %s", conn.RemoteAddr(), target) + log.WithFields(map[string]interface{}{ + "duration": time.Since(t), + }).Infof("%s >-< %s", conn.RemoteAddr(), target) } func (h *sniHandler) decodeHost(r io.Reader) (opaque []byte, host string, err error) { diff --git a/pkg/handler/socks/v4/handler.go b/pkg/handler/socks/v4/handler.go index d6eda4e..6bd6e61 100644 --- a/pkg/handler/socks/v4/handler.go +++ b/pkg/handler/socks/v4/handler.go @@ -23,7 +23,6 @@ func init() { type socks4Handler struct { router *chain.Router authenticator auth.Authenticator - logger logger.Logger md metadata options handler.Options } @@ -52,7 +51,6 @@ func (h *socks4Handler) Init(md md.Metadata) (err error) { Hosts: h.options.Hosts, Logger: h.options.Logger, } - h.logger = h.options.Logger return nil } @@ -62,14 +60,14 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn) { start := time.Now() - h.logger = h.logger.WithFields(map[string]interface{}{ + log := h.options.Logger.WithFields(map[string]interface{}{ "remote": conn.RemoteAddr().String(), "local": conn.LocalAddr().String(), }) - h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) defer func() { - h.logger.WithFields(map[string]interface{}{ + log.WithFields(map[string]interface{}{ "duration": time.Since(start), }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() @@ -80,10 +78,10 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn) { req, err := gosocks4.ReadRequest(conn) if err != nil { - h.logger.Error(err) + log.Error(err) return } - h.logger.Debug(req) + log.Debug(req) conn.SetReadDeadline(time.Time{}) @@ -91,33 +89,33 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn) { !h.authenticator.Authenticate(string(req.Userid), "") { resp := gosocks4.NewReply(gosocks4.RejectedUserid, nil) resp.Write(conn) - h.logger.Debug(resp) + log.Debug(resp) return } switch req.Cmd { case gosocks4.CmdConnect: - h.handleConnect(ctx, conn, req) + h.handleConnect(ctx, conn, req, log) case gosocks4.CmdBind: h.handleBind(ctx, conn, req) default: - h.logger.Errorf("unknown cmd: %d", req.Cmd) + log.Errorf("unknown cmd: %d", req.Cmd) } } -func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *gosocks4.Request) { +func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *gosocks4.Request, log logger.Logger) { addr := req.Addr.String() - h.logger = h.logger.WithFields(map[string]interface{}{ + log = log.WithFields(map[string]interface{}{ "dst": addr, }) - h.logger.Infof("%s >> %s", conn.RemoteAddr(), addr) + log.Infof("%s >> %s", conn.RemoteAddr(), addr) if h.options.Bypass != nil && h.options.Bypass.Contains(addr) { resp := gosocks4.NewReply(gosocks4.Rejected, nil) resp.Write(conn) - h.logger.Debug(resp) - h.logger.Info("bypass: ", addr) + log.Debug(resp) + log.Info("bypass: ", addr) return } @@ -125,7 +123,7 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g if err != nil { resp := gosocks4.NewReply(gosocks4.Failed, nil) resp.Write(conn) - h.logger.Debug(resp) + log.Debug(resp) return } @@ -133,19 +131,17 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g resp := gosocks4.NewReply(gosocks4.Granted, nil) if err := resp.Write(conn); err != nil { - h.logger.Error(err) + log.Error(err) return } - h.logger.Debug(resp) + log.Debug(resp) t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr) + log.Infof("%s <-> %s", conn.RemoteAddr(), addr) handler.Transport(conn, cc) - h.logger. - WithFields(map[string]interface{}{ - "duration": time.Since(t), - }). - Infof("%s >-< %s", conn.RemoteAddr(), addr) + log.WithFields(map[string]interface{}{ + "duration": time.Since(t), + }).Infof("%s >-< %s", conn.RemoteAddr(), addr) } func (h *socks4Handler) handleBind(ctx context.Context, conn net.Conn, req *gosocks4.Request) { diff --git a/pkg/handler/socks/v5/bind.go b/pkg/handler/socks/v5/bind.go index 3eed1f6..b5a6bf9 100644 --- a/pkg/handler/socks/v5/bind.go +++ b/pkg/handler/socks/v5/bind.go @@ -8,43 +8,44 @@ import ( "github.com/go-gost/gosocks5" "github.com/go-gost/gost/pkg/handler" + "github.com/go-gost/gost/pkg/logger" ) -func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, network, address string) { - h.logger = h.logger.WithFields(map[string]interface{}{ +func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { + log = log.WithFields(map[string]interface{}{ "dst": fmt.Sprintf("%s/%s", address, network), "cmd": "bind", }) - h.logger.Infof("%s >> %s", conn.RemoteAddr(), address) + log.Infof("%s >> %s", conn.RemoteAddr(), address) if !h.md.enableBind { reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) reply.Write(conn) - h.logger.Debug(reply) - h.logger.Error("BIND is diabled") + log.Debug(reply) + log.Error("BIND is diabled") return } // BIND does not support chain. - h.bindLocal(ctx, conn, network, address) + h.bindLocal(ctx, conn, network, address, log) } -func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, network, address string) { +func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { ln, err := net.Listen(network, address) // strict mode: if the port already in use, it will return error if err != nil { - h.logger.Error(err) + log.Error(err) reply := gosocks5.NewReply(gosocks5.Failure, nil) if err := reply.Write(conn); err != nil { - h.logger.Error(err) + log.Error(err) } - h.logger.Debug(reply) + log.Debug(reply) return } socksAddr := gosocks5.Addr{} if err := socksAddr.ParseFrom(ln.Addr().String()); err != nil { - h.logger.Warn(err) + log.Warn(err) } // Issue: may not reachable when host has multi-interface @@ -52,22 +53,22 @@ func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, network, a socksAddr.Type = 0 reply := gosocks5.NewReply(gosocks5.Succeeded, &socksAddr) if err := reply.Write(conn); err != nil { - h.logger.Error(err) + log.Error(err) ln.Close() return } - h.logger.Debug(reply) + log.Debug(reply) - h.logger = h.logger.WithFields(map[string]interface{}{ + log = log.WithFields(map[string]interface{}{ "bind": fmt.Sprintf("%s/%s", ln.Addr(), ln.Addr().Network()), }) - h.logger.Debugf("bind on %s OK", ln.Addr()) + log.Debugf("bind on %s OK", ln.Addr()) - h.serveBind(ctx, conn, ln) + h.serveBind(ctx, conn, ln, log) } -func (h *socks5Handler) serveBind(ctx context.Context, conn net.Conn, ln net.Listener) { +func (h *socks5Handler) serveBind(ctx context.Context, conn net.Conn, ln net.Listener, log logger.Logger) { var rc net.Conn accept := func() <-chan error { errc := make(chan error, 1) @@ -105,38 +106,42 @@ func (h *socks5Handler) serveBind(ctx context.Context, conn net.Conn, ln net.Lis select { case err := <-accept(): if err != nil { - h.logger.Error(err) + log.Error(err) reply := gosocks5.NewReply(gosocks5.Failure, nil) if err := reply.Write(pc2); err != nil { - h.logger.Error(err) + log.Error(err) } - h.logger.Debug(reply) + log.Debug(reply) return } defer rc.Close() - h.logger.Debugf("peer %s accepted", rc.RemoteAddr()) + log.Debugf("peer %s accepted", rc.RemoteAddr()) + + log = log.WithFields(map[string]interface{}{ + "local": rc.LocalAddr().String(), + "remote": rc.RemoteAddr().String(), + }) raddr := gosocks5.Addr{} raddr.ParseFrom(rc.RemoteAddr().String()) reply := gosocks5.NewReply(gosocks5.Succeeded, &raddr) if err := reply.Write(pc2); err != nil { - h.logger.Error(err) + log.Error(err) } - h.logger.Debug(reply) + log.Debug(reply) start := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), raddr.String()) + log.Infof("%s <-> %s", rc.LocalAddr(), rc.RemoteAddr()) handler.Transport(pc2, rc) - h.logger. - WithFields(map[string]interface{}{"duration": time.Since(start)}). - Infof("%s >-< %s", conn.RemoteAddr(), raddr.String()) + log.WithFields(map[string]interface{}{"duration": time.Since(start)}). + Infof("%s >-< %s", rc.LocalAddr(), rc.RemoteAddr()) case err := <-pipe(): if err != nil { - h.logger.Error(err) + log.Error(err) } ln.Close() return diff --git a/pkg/handler/socks/v5/connect.go b/pkg/handler/socks/v5/connect.go index 672abc7..8caf1a7 100644 --- a/pkg/handler/socks/v5/connect.go +++ b/pkg/handler/socks/v5/connect.go @@ -8,20 +8,21 @@ import ( "github.com/go-gost/gosocks5" "github.com/go-gost/gost/pkg/handler" + "github.com/go-gost/gost/pkg/logger" ) -func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, network, address string) { - h.logger = h.logger.WithFields(map[string]interface{}{ +func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { + log = log.WithFields(map[string]interface{}{ "dst": fmt.Sprintf("%s/%s", address, network), "cmd": "connect", }) - h.logger.Infof("%s >> %s", conn.RemoteAddr(), address) + log.Infof("%s >> %s", conn.RemoteAddr(), address) if h.options.Bypass != nil && h.options.Bypass.Contains(address) { resp := gosocks5.NewReply(gosocks5.NotAllowed, nil) resp.Write(conn) - h.logger.Debug(resp) - h.logger.Info("bypass: ", address) + log.Debug(resp) + log.Info("bypass: ", address) return } @@ -29,7 +30,7 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, networ if err != nil { resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil) resp.Write(conn) - h.logger.Debug(resp) + log.Debug(resp) return } @@ -37,17 +38,15 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, networ resp := gosocks5.NewReply(gosocks5.Succeeded, nil) if err := resp.Write(conn); err != nil { - h.logger.Error(err) + log.Error(err) return } - h.logger.Debug(resp) + log.Debug(resp) t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), address) + log.Infof("%s <-> %s", conn.RemoteAddr(), address) handler.Transport(conn, cc) - h.logger. - WithFields(map[string]interface{}{ - "duration": time.Since(t), - }). - Infof("%s >-< %s", conn.RemoteAddr(), address) + log.WithFields(map[string]interface{}{ + "duration": time.Since(t), + }).Infof("%s >-< %s", conn.RemoteAddr(), address) } diff --git a/pkg/handler/socks/v5/handler.go b/pkg/handler/socks/v5/handler.go index fb3e78e..34d2f08 100644 --- a/pkg/handler/socks/v5/handler.go +++ b/pkg/handler/socks/v5/handler.go @@ -10,7 +10,6 @@ import ( auth_util "github.com/go-gost/gost/pkg/common/util/auth" "github.com/go-gost/gost/pkg/common/util/socks" "github.com/go-gost/gost/pkg/handler" - "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" ) @@ -23,7 +22,6 @@ func init() { type socks5Handler struct { selector gosocks5.Selector router *chain.Router - logger logger.Logger md metadata options handler.Options } @@ -44,7 +42,6 @@ func (h *socks5Handler) Init(md md.Metadata) (err error) { return } - h.logger = h.options.Logger h.router = &chain.Router{ Retries: h.options.Retries, Chain: h.options.Chain, @@ -56,7 +53,7 @@ func (h *socks5Handler) Init(md md.Metadata) (err error) { h.selector = &serverSelector{ Authenticator: auth_util.AuthFromUsers(h.options.Auths...), TLSConfig: h.options.TLSConfig, - logger: h.logger, + logger: h.options.Logger, noTLS: h.md.noTLS, } @@ -68,14 +65,14 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn) { start := time.Now() - h.logger = h.logger.WithFields(map[string]interface{}{ + log := h.options.Logger.WithFields(map[string]interface{}{ "remote": conn.RemoteAddr().String(), "local": conn.LocalAddr().String(), }) - h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) defer func() { - h.logger.WithFields(map[string]interface{}{ + log.WithFields(map[string]interface{}{ "duration": time.Since(start), }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() @@ -87,30 +84,30 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn) { conn = gosocks5.ServerConn(conn, h.selector) req, err := gosocks5.ReadRequest(conn) if err != nil { - h.logger.Error(err) + log.Error(err) return } - h.logger.Debug(req) + log.Debug(req) conn.SetReadDeadline(time.Time{}) address := req.Addr.String() switch req.Cmd { case gosocks5.CmdConnect: - h.handleConnect(ctx, conn, "tcp", address) + h.handleConnect(ctx, conn, "tcp", address, log) case gosocks5.CmdBind: - h.handleBind(ctx, conn, "tcp", address) + h.handleBind(ctx, conn, "tcp", address, log) case socks.CmdMuxBind: - h.handleMuxBind(ctx, conn, "tcp", address) + h.handleMuxBind(ctx, conn, "tcp", address, log) case gosocks5.CmdUdp: - h.handleUDP(ctx, conn) + h.handleUDP(ctx, conn, log) case socks.CmdUDPTun: - h.handleUDPTun(ctx, conn, "udp", address) + h.handleUDPTun(ctx, conn, "udp", address, log) default: - h.logger.Errorf("unknown cmd: %d", req.Cmd) + log.Errorf("unknown cmd: %d", req.Cmd) resp := gosocks5.NewReply(gosocks5.CmdUnsupported, nil) resp.Write(conn) - h.logger.Debug(resp) + log.Debug(resp) return } } diff --git a/pkg/handler/socks/v5/mbind.go b/pkg/handler/socks/v5/mbind.go index bb1fd98..109474a 100644 --- a/pkg/handler/socks/v5/mbind.go +++ b/pkg/handler/socks/v5/mbind.go @@ -9,43 +9,44 @@ import ( "github.com/go-gost/gosocks5" "github.com/go-gost/gost/pkg/common/util/mux" "github.com/go-gost/gost/pkg/handler" + "github.com/go-gost/gost/pkg/logger" ) -func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, network, address string) { - h.logger = h.logger.WithFields(map[string]interface{}{ +func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { + log = log.WithFields(map[string]interface{}{ "dst": fmt.Sprintf("%s/%s", address, network), "cmd": "mbind", }) - h.logger.Infof("%s >> %s", conn.RemoteAddr(), address) + log.Infof("%s >> %s", conn.RemoteAddr(), address) if !h.md.enableBind { reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) reply.Write(conn) - h.logger.Debug(reply) - h.logger.Error("BIND is diabled") + log.Debug(reply) + log.Error("BIND is diabled") return } - h.muxBindLocal(ctx, conn, network, address) + h.muxBindLocal(ctx, conn, network, address, log) } -func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, network, address string) { +func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { ln, err := net.Listen(network, address) // strict mode: if the port already in use, it will return error if err != nil { - h.logger.Error(err) + log.Error(err) reply := gosocks5.NewReply(gosocks5.Failure, nil) if err := reply.Write(conn); err != nil { - h.logger.Error(err) + log.Error(err) } - h.logger.Debug(reply) + log.Debug(reply) return } socksAddr := gosocks5.Addr{} err = socksAddr.ParseFrom(ln.Addr().String()) if err != nil { - h.logger.Warn(err) + log.Warn(err) } // Issue: may not reachable when host has multi-interface @@ -53,26 +54,26 @@ func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, network socksAddr.Type = 0 reply := gosocks5.NewReply(gosocks5.Succeeded, &socksAddr) if err := reply.Write(conn); err != nil { - h.logger.Error(err) + log.Error(err) ln.Close() return } - h.logger.Debug(reply) + log.Debug(reply) - h.logger = h.logger.WithFields(map[string]interface{}{ + log = log.WithFields(map[string]interface{}{ "bind": fmt.Sprintf("%s/%s", ln.Addr(), ln.Addr().Network()), }) - h.logger.Debugf("bind on %s OK", ln.Addr()) + log.Debugf("bind on %s OK", ln.Addr()) - h.serveMuxBind(ctx, conn, ln) + h.serveMuxBind(ctx, conn, ln, log) } -func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net.Listener) { +func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net.Listener, log logger.Logger) { // Upgrade connection to multiplex stream. session, err := mux.ClientSession(conn) if err != nil { - h.logger.Error(err) + log.Error(err) return } defer session.Close() @@ -82,7 +83,7 @@ func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net. for { conn, err := session.Accept() if err != nil { - h.logger.Error(err) + log.Error(err) return } conn.Close() // we do not handle incoming connections. @@ -92,17 +93,21 @@ func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net. for { rc, err := ln.Accept() if err != nil { - h.logger.Error(err) + log.Error(err) return } - h.logger.Debugf("peer %s accepted", rc.RemoteAddr()) + log.Debugf("peer %s accepted", rc.RemoteAddr()) go func(c net.Conn) { defer c.Close() + log = log.WithFields(map[string]interface{}{ + "local": rc.LocalAddr().String(), + "remote": rc.RemoteAddr().String(), + }) sc, err := session.GetConn() if err != nil { - h.logger.Error(err) + log.Error(err) return } defer sc.Close() @@ -113,18 +118,17 @@ func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net. addr.ParseFrom(c.RemoteAddr().String()) reply := gosocks5.NewReply(gosocks5.Succeeded, &addr) if err := reply.Write(sc); err != nil { - h.logger.Error(err) + log.Error(err) return } - h.logger.Debug(reply) + log.Debug(reply) } t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), c.RemoteAddr().String()) + log.Infof("%s <-> %s", c.LocalAddr(), c.RemoteAddr()) handler.Transport(sc, c) - h.logger. - WithFields(map[string]interface{}{"duration": time.Since(t)}). - Infof("%s >-< %s", conn.RemoteAddr(), c.RemoteAddr().String()) + log.WithFields(map[string]interface{}{"duration": time.Since(t)}). + Infof("%s >-< %s", c.LocalAddr(), c.RemoteAddr()) }(rc) } } diff --git a/pkg/handler/socks/v5/udp.go b/pkg/handler/socks/v5/udp.go index 07a77ff..619a92f 100644 --- a/pkg/handler/socks/v5/udp.go +++ b/pkg/handler/socks/v5/udp.go @@ -11,27 +11,28 @@ import ( "github.com/go-gost/gosocks5" "github.com/go-gost/gost/pkg/common/util/socks" "github.com/go-gost/gost/pkg/handler" + "github.com/go-gost/gost/pkg/logger" ) -func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) { - h.logger = h.logger.WithFields(map[string]interface{}{ +func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, log logger.Logger) { + log = log.WithFields(map[string]interface{}{ "cmd": "udp", }) if !h.md.enableUDP { reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) reply.Write(conn) - h.logger.Debug(reply) - h.logger.Error("UDP relay is diabled") + log.Debug(reply) + log.Error("UDP relay is diabled") return } cc, err := net.ListenUDP("udp", nil) if err != nil { - h.logger.Error(err) + log.Error(err) reply := gosocks5.NewReply(gosocks5.Failure, nil) reply.Write(conn) - h.logger.Debug(reply) + log.Debug(reply) return } defer cc.Close() @@ -42,41 +43,40 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) { saddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) // replace the IP to the out-going interface's reply := gosocks5.NewReply(gosocks5.Succeeded, &saddr) if err := reply.Write(conn); err != nil { - h.logger.Error(err) + log.Error(err) return } - h.logger.Debug(reply) + log.Debug(reply) - h.logger = h.logger.WithFields(map[string]interface{}{ + log = log.WithFields(map[string]interface{}{ "bind": fmt.Sprintf("%s/%s", cc.LocalAddr(), cc.LocalAddr().Network()), }) - h.logger.Debugf("bind on %s OK", cc.LocalAddr()) + log.Debugf("bind on %s OK", cc.LocalAddr()) // obtain a udp connection c, err := h.router.Dial(ctx, "udp", "") // UDP association if err != nil { - h.logger.Error(err) + log.Error(err) return } defer c.Close() pc, ok := c.(net.PacketConn) if !ok { - h.logger.Errorf("wrong connection type") + log.Errorf("wrong connection type") return } relay := handler.NewUDPRelay(socks.UDPConn(cc, h.md.udpBufferSize), pc). WithBypass(h.options.Bypass). - WithLogger(h.logger) + WithLogger(log) relay.SetBufferSize(h.md.udpBufferSize) go relay.Run() t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr()) + log.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr()) io.Copy(ioutil.Discard, conn) - h.logger. - WithFields(map[string]interface{}{"duration": time.Since(t)}). + log.WithFields(map[string]interface{}{"duration": time.Since(t)}). Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr()) } diff --git a/pkg/handler/socks/v5/udp_tun.go b/pkg/handler/socks/v5/udp_tun.go index 339f368..2fcb816 100644 --- a/pkg/handler/socks/v5/udp_tun.go +++ b/pkg/handler/socks/v5/udp_tun.go @@ -8,54 +8,53 @@ import ( "github.com/go-gost/gosocks5" "github.com/go-gost/gost/pkg/common/util/socks" "github.com/go-gost/gost/pkg/handler" + "github.com/go-gost/gost/pkg/logger" ) -func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network, address string) { - h.logger = h.logger.WithFields(map[string]interface{}{ +func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { + log = log.WithFields(map[string]interface{}{ "cmd": "udp-tun", }) if !h.md.enableUDP { reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) reply.Write(conn) - h.logger.Debug(reply) - h.logger.Error("UDP relay is diabled") + log.Debug(reply) + log.Error("UDP relay is diabled") return } // dummy bind reply := gosocks5.NewReply(gosocks5.Succeeded, nil) if err := reply.Write(conn); err != nil { - h.logger.Error(err) + log.Error(err) return } - h.logger.Debug(reply) + log.Debug(reply) // obtain a udp connection c, err := h.router.Dial(ctx, "udp", "") // UDP association if err != nil { - h.logger.Error(err) + log.Error(err) return } defer c.Close() pc, ok := c.(net.PacketConn) if !ok { - h.logger.Errorf("wrong connection type") + log.Errorf("wrong connection type") return } relay := handler.NewUDPRelay(socks.UDPTunServerConn(conn), pc). WithBypass(h.options.Bypass). - WithLogger(h.logger) + WithLogger(log) relay.SetBufferSize(h.md.udpBufferSize) t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr()) + log.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr()) relay.Run() - h.logger. - WithFields(map[string]interface{}{ - "duration": time.Since(t), - }). - Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) + log.WithFields(map[string]interface{}{ + "duration": time.Since(t), + }).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) } diff --git a/pkg/handler/ss/handler.go b/pkg/handler/ss/handler.go index 2b8ad77..337d920 100644 --- a/pkg/handler/ss/handler.go +++ b/pkg/handler/ss/handler.go @@ -11,7 +11,6 @@ import ( "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/common/util/ss" "github.com/go-gost/gost/pkg/handler" - "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" "github.com/shadowsocks/go-shadowsocks2/core" @@ -24,7 +23,6 @@ func init() { type ssHandler struct { cipher core.Cipher router *chain.Router - logger logger.Logger md metadata options handler.Options } @@ -60,7 +58,6 @@ func (h *ssHandler) Init(md md.Metadata) (err error) { Hosts: h.options.Hosts, Logger: h.options.Logger, } - h.logger = h.options.Logger return } @@ -69,14 +66,14 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() start := time.Now() - h.logger = h.logger.WithFields(map[string]interface{}{ + log := h.options.Logger.WithFields(map[string]interface{}{ "remote": conn.RemoteAddr().String(), "local": conn.LocalAddr().String(), }) - h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) defer func() { - h.logger.WithFields(map[string]interface{}{ + log.WithFields(map[string]interface{}{ "duration": time.Since(start), }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() @@ -91,19 +88,19 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) { addr := &gosocks5.Addr{} if _, err := addr.ReadFrom(conn); err != nil { - h.logger.Error(err) + log.Error(err) io.Copy(ioutil.Discard, conn) return } - h.logger = h.logger.WithFields(map[string]interface{}{ + log = log.WithFields(map[string]interface{}{ "dst": addr.String(), }) - h.logger.Infof("%s >> %s", conn.RemoteAddr(), addr) + log.Infof("%s >> %s", conn.RemoteAddr(), addr) if h.options.Bypass != nil && h.options.Bypass.Contains(addr.String()) { - h.logger.Info("bypass: ", addr.String()) + log.Info("bypass: ", addr.String()) return } @@ -114,11 +111,9 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) { defer cc.Close() t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr) + log.Infof("%s <-> %s", conn.RemoteAddr(), addr) handler.Transport(conn, cc) - h.logger. - WithFields(map[string]interface{}{ - "duration": time.Since(t), - }). - Infof("%s >-< %s", conn.RemoteAddr(), addr) + log.WithFields(map[string]interface{}{ + "duration": time.Since(t), + }).Infof("%s >-< %s", conn.RemoteAddr(), addr) } diff --git a/pkg/handler/ss/udp/handler.go b/pkg/handler/ss/udp/handler.go index 0ecc956..bed446d 100644 --- a/pkg/handler/ss/udp/handler.go +++ b/pkg/handler/ss/udp/handler.go @@ -23,7 +23,6 @@ func init() { type ssuHandler struct { cipher core.Cipher router *chain.Router - logger logger.Logger md metadata options handler.Options } @@ -60,7 +59,6 @@ func (h *ssuHandler) Init(md md.Metadata) (err error) { Hosts: h.options.Hosts, Logger: h.options.Logger, } - h.logger = h.options.Logger return } @@ -69,14 +67,14 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() start := time.Now() - h.logger = h.logger.WithFields(map[string]interface{}{ + log := h.options.Logger.WithFields(map[string]interface{}{ "remote": conn.RemoteAddr().String(), "local": conn.LocalAddr().String(), }) - h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) defer func() { - h.logger.WithFields(map[string]interface{}{ + log.WithFields(map[string]interface{}{ "duration": time.Since(start), }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() @@ -99,26 +97,25 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) { // obtain a udp connection c, err := h.router.Dial(ctx, "udp", "") // UDP association if err != nil { - h.logger.Error(err) + log.Error(err) return } defer c.Close() cc, ok := c.(net.PacketConn) if !ok { - h.logger.Errorf("wrong connection type") + log.Errorf("wrong connection type") return } t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr()) - h.relayPacket(pc, cc) - h.logger. - WithFields(map[string]interface{}{"duration": time.Since(t)}). + log.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr()) + h.relayPacket(pc, cc, log) + log.WithFields(map[string]interface{}{"duration": time.Since(t)}). Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr()) } -func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) { +func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn, log logger.Logger) (err error) { bufSize := h.md.bufferSize errc := make(chan error, 2) @@ -134,7 +131,7 @@ func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) { } if h.options.Bypass != nil && h.options.Bypass.Contains(addr.String()) { - h.logger.Warn("bypass: ", addr) + log.Warn("bypass: ", addr) return nil } @@ -142,7 +139,7 @@ func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) { return err } - h.logger.Debugf("%s >>> %s data: %d", + log.Debugf("%s >>> %s data: %d", pc2.LocalAddr(), addr, n) return nil }() @@ -166,7 +163,7 @@ func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) { } if h.options.Bypass != nil && h.options.Bypass.Contains(raddr.String()) { - h.logger.Warn("bypass: ", raddr) + log.Warn("bypass: ", raddr) return nil } @@ -174,7 +171,7 @@ func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) { return err } - h.logger.Debugf("%s <<< %s data: %d", + log.Debugf("%s <<< %s data: %d", pc2.LocalAddr(), raddr, n) return nil }() diff --git a/pkg/handler/sshd/handler.go b/pkg/handler/sshd/handler.go new file mode 100644 index 0000000..ee7d4ce --- /dev/null +++ b/pkg/handler/sshd/handler.go @@ -0,0 +1,238 @@ +package ssh + +import ( + "context" + "encoding/binary" + "fmt" + "net" + "strconv" + "time" + + "github.com/go-gost/gost/pkg/chain" + "github.com/go-gost/gost/pkg/handler" + sshd_util "github.com/go-gost/gost/pkg/internal/util/sshd" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" + "golang.org/x/crypto/ssh" +) + +// Applicable SSH Request types for Port Forwarding - RFC 4254 7.X +const ( + ForwardedTCPReturnRequest = "forwarded-tcpip" // RFC 4254 7.2 +) + +func init() { + registry.RegisterHandler("sshd", NewHandler) +} + +type forwardHandler struct { + router *chain.Router + md metadata + options handler.Options +} + +func NewHandler(opts ...handler.Option) handler.Handler { + options := handler.Options{} + for _, opt := range opts { + opt(&options) + } + + return &forwardHandler{ + options: options, + } +} + +func (h *forwardHandler) Init(md md.Metadata) (err error) { + if err = h.parseMetadata(md); err != nil { + return + } + + h.router = &chain.Router{ + Retries: h.options.Retries, + Chain: h.options.Chain, + Resolver: h.options.Resolver, + Hosts: h.options.Hosts, + Logger: h.options.Logger, + } + + return nil +} + +func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { + defer conn.Close() + + log := h.options.Logger.WithFields(map[string]interface{}{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }) + + switch cc := conn.(type) { + case *sshd_util.DirectForwardConn: + h.handleDirectForward(ctx, cc, log) + case *sshd_util.RemoteForwardConn: + h.handleRemoteForward(ctx, cc, log) + default: + log.Error("wrong connection type") + return + } +} + +func (h *forwardHandler) handleDirectForward(ctx context.Context, conn *sshd_util.DirectForwardConn, log logger.Logger) { + targetAddr := conn.DstAddr() + + log = log.WithFields(map[string]interface{}{ + "dst": fmt.Sprintf("%s/%s", targetAddr, "tcp"), + "cmd": "connect", + }) + + log.Infof("%s >> %s", conn.RemoteAddr(), targetAddr) + + if h.options.Bypass != nil && h.options.Bypass.Contains(targetAddr) { + log.Infof("bypass %s", targetAddr) + return + } + + cc, err := h.router.Dial(ctx, "tcp", targetAddr) + if err != nil { + return + } + defer cc.Close() + + t := time.Now() + log.Infof("%s <-> %s", cc.LocalAddr(), targetAddr) + handler.Transport(conn, cc) + log.WithFields(map[string]interface{}{ + "duration": time.Since(t), + }).Infof("%s >-< %s", cc.LocalAddr(), targetAddr) +} + +func (h *forwardHandler) handleRemoteForward(ctx context.Context, conn *sshd_util.RemoteForwardConn, log logger.Logger) { + req := conn.Request() + + t := tcpipForward{} + ssh.Unmarshal(req.Payload, &t) + + network := "tcp" + addr := net.JoinHostPort(t.Host, strconv.Itoa(int(t.Port))) + + log = log.WithFields(map[string]interface{}{ + "dst": fmt.Sprintf("%s/%s", addr, network), + "cmd": "bind", + }) + + log.Infof("%s >> %s", conn.RemoteAddr(), addr) + + // tie to the client connection + ln, err := net.Listen(network, addr) + if err != nil { + log.Error(err) + req.Reply(false, nil) + return + } + defer ln.Close() + + log = log.WithFields(map[string]interface{}{ + "bind": fmt.Sprintf("%s/%s", ln.Addr(), ln.Addr().Network()), + }) + log.Debugf("bind on %s OK", ln.Addr()) + + err = func() error { + if t.Port == 0 && req.WantReply { // Client sent port 0. let them know which port is actually being used + _, port, err := getHostPortFromAddr(ln.Addr()) + if err != nil { + return err + } + var b [4]byte + binary.BigEndian.PutUint32(b[:], uint32(port)) + t.Port = uint32(port) + return req.Reply(true, b[:]) + } + return req.Reply(true, nil) + }() + if err != nil { + log.Error(err) + return + } + + sshConn := conn.Conn() + + go func() { + for { + cc, err := ln.Accept() + if err != nil { // Unable to accept new connection - listener is likely closed + return + } + + go func(conn net.Conn) { + defer conn.Close() + + log := log.WithFields(map[string]interface{}{ + "local": conn.LocalAddr().String(), + "remote": conn.RemoteAddr().String(), + }) + + p := directForward{} + var err error + + var portnum int + p.Host1 = t.Host + p.Port1 = t.Port + p.Host2, portnum, err = getHostPortFromAddr(conn.RemoteAddr()) + if err != nil { + return + } + + p.Port2 = uint32(portnum) + ch, reqs, err := sshConn.OpenChannel(ForwardedTCPReturnRequest, ssh.Marshal(p)) + if err != nil { + log.Error("open forwarded channel: ", err) + return + } + defer ch.Close() + go ssh.DiscardRequests(reqs) + + t := time.Now() + log.Infof("%s <-> %s", conn.LocalAddr(), conn.RemoteAddr()) + handler.Transport(ch, conn) + log.WithFields(map[string]interface{}{ + "duration": time.Since(t), + }).Infof("%s >-< %s", conn.LocalAddr(), conn.RemoteAddr()) + }(cc) + } + }() + + tm := time.Now() + log.Infof("%s <-> %s", conn.RemoteAddr(), addr) + <-conn.Done() + log.WithFields(map[string]interface{}{ + "duration": time.Since(tm), + }).Infof("%s >-< %s", conn.RemoteAddr(), addr) +} + +func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) { + host, portString, err := net.SplitHostPort(addr.String()) + if err != nil { + return + } + port, err = strconv.Atoi(portString) + return +} + +// directForward is structure for RFC 4254 7.2 - can be used for "forwarded-tcpip" and "direct-tcpip" +type directForward struct { + Host1 string + Port1 uint32 + Host2 string + Port2 uint32 +} + +func (p directForward) String() string { + return fmt.Sprintf("%s:%d -> %s:%d", p.Host2, p.Port2, p.Host1, p.Port1) +} + +// tcpipForward is structure for RFC 4254 7.1 "tcpip-forward" request +type tcpipForward struct { + Host string + Port uint32 +} diff --git a/pkg/handler/sshd/metadata.go b/pkg/handler/sshd/metadata.go new file mode 100644 index 0000000..725282f --- /dev/null +++ b/pkg/handler/sshd/metadata.go @@ -0,0 +1,12 @@ +package ssh + +import ( + mdata "github.com/go-gost/gost/pkg/metadata" +) + +type metadata struct { +} + +func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) { + return +} diff --git a/pkg/handler/tap/handler.go b/pkg/handler/tap/handler.go index 8a5d223..44f9068 100644 --- a/pkg/handler/tap/handler.go +++ b/pkg/handler/tap/handler.go @@ -33,7 +33,6 @@ type tapHandler struct { exit chan struct{} cipher core.Cipher router *chain.Router - logger logger.Logger md metadata options handler.Options } @@ -71,7 +70,6 @@ func (h *tapHandler) Init(md md.Metadata) (err error) { Hosts: h.options.Hosts, Logger: h.options.Logger, } - h.logger = h.options.Logger return } @@ -85,21 +83,22 @@ func (h *tapHandler) Handle(ctx context.Context, conn net.Conn) { defer os.Exit(0) defer conn.Close() + log := h.options.Logger cc, ok := conn.(*tap_util.Conn) if !ok || cc.Config() == nil { - h.logger.Error("invalid connection") + log.Error("invalid connection") return } start := time.Now() - h.logger = h.logger.WithFields(map[string]interface{}{ + log = log.WithFields(map[string]interface{}{ "remote": conn.RemoteAddr().String(), "local": conn.LocalAddr().String(), }) - h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) defer func() { - h.logger.WithFields(map[string]interface{}{ + log.WithFields(map[string]interface{}{ "duration": time.Since(start), }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() @@ -112,19 +111,19 @@ func (h *tapHandler) Handle(ctx context.Context, conn net.Conn) { if target != nil { raddr, err = net.ResolveUDPAddr(network, target.Addr()) if err != nil { - h.logger.Error(err) + log.Error(err) return } - h.logger = h.logger.WithFields(map[string]interface{}{ + log = log.WithFields(map[string]interface{}{ "dst": fmt.Sprintf("%s/%s", raddr.String(), raddr.Network()), }) - h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) + log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) } - h.handleLoop(ctx, conn, raddr, cc.Config()) + h.handleLoop(ctx, conn, raddr, cc.Config(), log) } -func (h *tapHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Addr, config *tap_util.Config) { +func (h *tapHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Addr, config *tap_util.Config, log logger.Logger) { var tempDelay time.Duration for { err := func() error { @@ -154,10 +153,10 @@ func (h *tapHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add pc = h.cipher.PacketConn(pc) } - return h.transport(conn, pc, addr) + return h.transport(conn, pc, addr, log) }() if err != nil { - h.logger.Error(err) + log.Error(err) } select { @@ -183,7 +182,7 @@ func (h *tapHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add } -func (h *tapHandler) transport(tap net.Conn, conn net.PacketConn, raddr net.Addr) error { +func (h *tapHandler) transport(tap net.Conn, conn net.PacketConn, raddr net.Addr, log logger.Logger) error { errc := make(chan error, 1) go func() { @@ -205,7 +204,7 @@ func (h *tapHandler) transport(tap net.Conn, conn net.PacketConn, raddr net.Addr dst := waterutil.MACDestination((*b)[:n]) eType := etherType(waterutil.MACEthertype((*b)[:n])) - h.logger.Debugf("%s >> %s %s %d", src, dst, eType, n) + log.Debugf("%s >> %s %s %d", src, dst, eType, n) // client side, deliver frame directly. if raddr != nil { @@ -227,7 +226,7 @@ func (h *tapHandler) transport(tap net.Conn, conn net.PacketConn, raddr net.Addr addr = v.(net.Addr) } if addr == nil { - h.logger.Warnf("no route for %s -> %s %s %d", src, dst, eType, n) + log.Warnf("no route for %s -> %s %s %d", src, dst, eType, n) return nil } @@ -261,7 +260,7 @@ func (h *tapHandler) transport(tap net.Conn, conn net.PacketConn, raddr net.Addr dst := waterutil.MACDestination((*b)[:n]) eType := etherType(waterutil.MACEthertype((*b)[:n])) - h.logger.Debugf("%s >> %s %s %d", src, dst, eType, n) + log.Debugf("%s >> %s %s %d", src, dst, eType, n) // client side, deliver frame to tap device. if raddr != nil { @@ -273,12 +272,12 @@ func (h *tapHandler) transport(tap net.Conn, conn net.PacketConn, raddr net.Addr rkey := hwAddrToTapRouteKey(src) if actual, loaded := h.routes.LoadOrStore(rkey, addr); loaded { if actual.(net.Addr).String() != addr.String() { - h.logger.Debugf("update route: %s -> %s (old %s)", + log.Debugf("update route: %s -> %s (old %s)", src, addr, actual.(net.Addr)) h.routes.Store(rkey, addr) } } else { - h.logger.Debugf("new route: %s -> %s", src, addr) + log.Debugf("new route: %s -> %s", src, addr) } if waterutil.IsBroadcast(dst) { @@ -291,7 +290,7 @@ func (h *tapHandler) transport(tap net.Conn, conn net.PacketConn, raddr net.Addr } if v, ok := h.routes.Load(hwAddrToTapRouteKey(dst)); ok { - h.logger.Debugf("find route: %s -> %s", dst, v) + log.Debugf("find route: %s -> %s", dst, v) _, err := conn.WriteTo((*b)[:n], v.(net.Addr)) return err } diff --git a/pkg/handler/tun/handler.go b/pkg/handler/tun/handler.go index eb0af13..c87e0b0 100644 --- a/pkg/handler/tun/handler.go +++ b/pkg/handler/tun/handler.go @@ -35,7 +35,6 @@ type tunHandler struct { exit chan struct{} cipher core.Cipher router *chain.Router - logger logger.Logger md metadata options handler.Options } @@ -73,7 +72,6 @@ func (h *tunHandler) Init(md md.Metadata) (err error) { Hosts: h.options.Hosts, Logger: h.options.Logger, } - h.logger = h.options.Logger return } @@ -87,21 +85,23 @@ func (h *tunHandler) Handle(ctx context.Context, conn net.Conn) { defer os.Exit(0) defer conn.Close() + log := h.options.Logger + cc, ok := conn.(*tun_util.Conn) if !ok || cc.Config() == nil { - h.logger.Error("invalid connection") + log.Error("invalid connection") return } start := time.Now() - h.logger = h.logger.WithFields(map[string]interface{}{ + log = log.WithFields(map[string]interface{}{ "remote": conn.RemoteAddr().String(), "local": conn.LocalAddr().String(), }) - h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) defer func() { - h.logger.WithFields(map[string]interface{}{ + log.WithFields(map[string]interface{}{ "duration": time.Since(start), }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() @@ -114,19 +114,19 @@ func (h *tunHandler) Handle(ctx context.Context, conn net.Conn) { if target != nil { raddr, err = net.ResolveUDPAddr(network, target.Addr()) if err != nil { - h.logger.Error(err) + log.Error(err) return } - h.logger = h.logger.WithFields(map[string]interface{}{ + log = log.WithFields(map[string]interface{}{ "dst": fmt.Sprintf("%s/%s", raddr.String(), raddr.Network()), }) - h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) + log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) } - h.handleLoop(ctx, conn, raddr, cc.Config()) + h.handleLoop(ctx, conn, raddr, cc.Config(), log) } -func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Addr, config *tun_util.Config) { +func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Addr, config *tun_util.Config, log logger.Logger) { var tempDelay time.Duration for { err := func() error { @@ -155,10 +155,10 @@ func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add pc = h.cipher.PacketConn(pc) } - return h.transport(conn, pc, addr) + return h.transport(conn, pc, addr, log) }() if err != nil { - h.logger.Error(err) + log.Error(err) } select { @@ -184,7 +184,7 @@ func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add } -func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr) error { +func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr, log logger.Logger) error { errc := make(chan error, 1) go func() { @@ -206,10 +206,10 @@ func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr if waterutil.IsIPv4((*b)[:n]) { header, err := ipv4.ParseHeader((*b)[:n]) if err != nil { - h.logger.Error(err) + log.Error(err) return nil } - h.logger.Debugf("%s >> %s %-4s %d/%-4d %-4x %d", + log.Debugf("%s >> %s %-4s %d/%-4d %-4x %d", header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])), header.Len, header.TotalLen, header.ID, header.Flags) @@ -217,17 +217,17 @@ func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr } else if waterutil.IsIPv6((*b)[:n]) { header, err := ipv6.ParseHeader((*b)[:n]) if err != nil { - h.logger.Warn(err) + log.Warn(err) return nil } - h.logger.Debugf("%s >> %s %s %d %d", + log.Debugf("%s >> %s %s %d %d", header.Src, header.Dst, ipProtocol(waterutil.IPProtocol(header.NextHeader)), header.PayloadLen, header.TrafficClass) src, dst = header.Src, header.Dst } else { - h.logger.Warn("unknown packet, discarded") + log.Warn("unknown packet, discarded") return nil } @@ -239,11 +239,11 @@ func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr addr := h.findRouteFor(dst) if addr == nil { - h.logger.Warnf("no route for %s -> %s", src, dst) + log.Warnf("no route for %s -> %s", src, dst) return nil } - h.logger.Debugf("find route: %s -> %s", dst, addr) + log.Debugf("find route: %s -> %s", dst, addr) if _, err := conn.WriteTo((*b)[:n], addr); err != nil { return err @@ -274,11 +274,11 @@ func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr if waterutil.IsIPv4((*b)[:n]) { header, err := ipv4.ParseHeader((*b)[:n]) if err != nil { - h.logger.Warn(err) + log.Warn(err) return nil } - h.logger.Debugf("%s >> %s %-4s %d/%-4d %-4x %d", + log.Debugf("%s >> %s %-4s %d/%-4d %-4x %d", header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])), header.Len, header.TotalLen, header.ID, header.Flags) @@ -286,18 +286,18 @@ func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr } else if waterutil.IsIPv6((*b)[:n]) { header, err := ipv6.ParseHeader((*b)[:n]) if err != nil { - h.logger.Warn(err) + log.Warn(err) return nil } - h.logger.Debugf("%s > %s %s %d %d", + log.Debugf("%s > %s %s %d %d", header.Src, header.Dst, ipProtocol(waterutil.IPProtocol(header.NextHeader)), header.PayloadLen, header.TrafficClass) src, dst = header.Src, header.Dst } else { - h.logger.Warn("unknown packet, discarded") + log.Warn("unknown packet, discarded") return nil } @@ -310,16 +310,16 @@ func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr rkey := ipToTunRouteKey(src) if actual, loaded := h.routes.LoadOrStore(rkey, addr); loaded { if actual.(net.Addr).String() != addr.String() { - h.logger.Debugf("update route: %s -> %s (old %s)", + log.Debugf("update route: %s -> %s (old %s)", src, addr, actual.(net.Addr)) h.routes.Store(rkey, addr) } } else { - h.logger.Warnf("no route for %s -> %s", src, addr) + log.Warnf("no route for %s -> %s", src, addr) } if addr := h.findRouteFor(dst); addr != nil { - h.logger.Debugf("find route: %s -> %s", dst, addr) + log.Debugf("find route: %s -> %s", dst, addr) _, err := conn.WriteTo((*b)[:n], addr) return err diff --git a/pkg/internal/util/sshd/conn.go b/pkg/internal/util/sshd/conn.go new file mode 100644 index 0000000..592a961 --- /dev/null +++ b/pkg/internal/util/sshd/conn.go @@ -0,0 +1,118 @@ +package sshd + +import ( + "context" + "errors" + "net" + "time" + + "golang.org/x/crypto/ssh" +) + +type DirectForwardConn struct { + conn ssh.Conn + channel ssh.Channel + dstAddr string +} + +func NewDirectForwardConn(conn ssh.Conn, channel ssh.Channel, dstAddr string) net.Conn { + return &DirectForwardConn{ + conn: conn, + channel: channel, + dstAddr: dstAddr, + } +} + +func (c *DirectForwardConn) Read(b []byte) (n int, err error) { + return c.channel.Read(b) +} + +func (c *DirectForwardConn) Write(b []byte) (n int, err error) { + return c.channel.Write(b) +} + +func (c *DirectForwardConn) Close() error { + return c.channel.Close() +} + +func (c *DirectForwardConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *DirectForwardConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *DirectForwardConn) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *DirectForwardConn) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *DirectForwardConn) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *DirectForwardConn) DstAddr() string { + return c.dstAddr +} + +type RemoteForwardConn struct { + ctx context.Context + conn ssh.Conn + req *ssh.Request +} + +func NewRemoteForwardConn(ctx context.Context, conn ssh.Conn, req *ssh.Request) net.Conn { + return &RemoteForwardConn{ + ctx: ctx, + conn: conn, + req: req, + } +} + +func (c *RemoteForwardConn) Conn() ssh.Conn { + return c.conn +} + +func (c *RemoteForwardConn) Request() *ssh.Request { + return c.req +} + +func (c *RemoteForwardConn) Read(b []byte) (n int, err error) { + return 0, &net.OpError{Op: "read", Net: "nop", Source: nil, Addr: nil, Err: errors.New("read not supported")} +} + +func (c *RemoteForwardConn) Write(b []byte) (n int, err error) { + return 0, &net.OpError{Op: "write", Net: "nop", Source: nil, Addr: nil, Err: errors.New("write not supported")} +} + +func (c *RemoteForwardConn) Close() error { + return &net.OpError{Op: "close", Net: "nop", Source: nil, Addr: nil, Err: errors.New("close not supported")} +} + +func (c *RemoteForwardConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *RemoteForwardConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *RemoteForwardConn) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *RemoteForwardConn) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *RemoteForwardConn) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *RemoteForwardConn) Done() <-chan struct{} { + return c.ctx.Done() +} diff --git a/pkg/listener/http3/listener.go b/pkg/listener/http3/listener.go index b8636ba..3156bd6 100644 --- a/pkg/listener/http3/listener.go +++ b/pkg/listener/http3/listener.go @@ -15,6 +15,7 @@ import ( func init() { registry.RegisterListener("http3", NewListener) + registry.RegisterListener("h3", NewListener) } type phtListener struct { diff --git a/pkg/listener/ssh/listener.go b/pkg/listener/ssh/listener.go index 88403ab..2ddc352 100644 --- a/pkg/listener/ssh/listener.go +++ b/pkg/listener/ssh/listener.go @@ -3,6 +3,7 @@ package ssh import ( "fmt" "net" + "time" auth_util "github.com/go-gost/gost/pkg/common/util/auth" ssh_util "github.com/go-gost/gost/pkg/internal/util/ssh" @@ -29,13 +30,14 @@ type sshListener struct { } func NewListener(opts ...listener.Option) listener.Listener { - options := &listener.Options{} + options := listener.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &sshListener{ - addr: options.Addr, - logger: options.Logger, + addr: options.Addr, + logger: options.Logger, + options: options, } } @@ -96,6 +98,14 @@ func (l *sshListener) listenLoop() { } func (l *sshListener) serveConn(conn net.Conn) { + start := time.Now() + l.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + defer func() { + l.logger.WithFields(map[string]interface{}{ + "duration": time.Since(start), + }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) + }() + sc, chans, reqs, err := ssh.NewServerConn(conn, l.config) if err != nil { l.logger.Error(err) @@ -122,8 +132,9 @@ func (l *sshListener) serveConn(conn net.Conn) { select { case l.cqueue <- cc: default: - cc.Close() l.logger.Warnf("connection queue is full, client %s discarded", conn.RemoteAddr()) + newChannel.Reject(ssh.ResourceShortage, "connection queue is full") + cc.Close() } default: diff --git a/pkg/listener/sshd/listener.go b/pkg/listener/sshd/listener.go new file mode 100644 index 0000000..90f9aba --- /dev/null +++ b/pkg/listener/sshd/listener.go @@ -0,0 +1,199 @@ +package ssh + +import ( + "context" + "fmt" + "net" + "strconv" + "time" + + auth_util "github.com/go-gost/gost/pkg/common/util/auth" + ssh_util "github.com/go-gost/gost/pkg/internal/util/ssh" + sshd_util "github.com/go-gost/gost/pkg/internal/util/sshd" + "github.com/go-gost/gost/pkg/listener" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" + "golang.org/x/crypto/ssh" +) + +// Applicable SSH Request types for Port Forwarding - RFC 4254 7.X +const ( + DirectForwardRequest = "direct-tcpip" // RFC 4254 7.2 + RemoteForwardRequest = "tcpip-forward" // RFC 4254 7.1 +) + +func init() { + registry.RegisterListener("sshd", NewListener) +} + +type sshdListener struct { + addr string + net.Listener + config *ssh.ServerConfig + cqueue chan net.Conn + errChan chan error + logger logger.Logger + md metadata + options listener.Options +} + +func NewListener(opts ...listener.Option) listener.Listener { + options := listener.Options{} + for _, opt := range opts { + opt(&options) + } + return &sshdListener{ + addr: options.Addr, + logger: options.Logger, + options: options, + } +} + +func (l *sshdListener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { + return + } + + ln, err := net.Listen("tcp", l.addr) + if err != nil { + return err + } + + l.Listener = ln + + authenticator := auth_util.AuthFromUsers(l.options.Auths...) + config := &ssh.ServerConfig{ + PasswordCallback: ssh_util.PasswordCallback(authenticator), + PublicKeyCallback: ssh_util.PublicKeyCallback(l.md.authorizedKeys), + } + config.AddHostKey(l.md.signer) + if authenticator == nil && len(l.md.authorizedKeys) == 0 { + config.NoClientAuth = true + } + + l.config = config + l.cqueue = make(chan net.Conn, l.md.backlog) + l.errChan = make(chan error, 1) + + go l.listenLoop() + + return +} + +func (l *sshdListener) Accept() (conn net.Conn, err error) { + var ok bool + select { + case conn = <-l.cqueue: + case err, ok = <-l.errChan: + if !ok { + err = listener.ErrClosed + } + } + return +} + +func (l *sshdListener) listenLoop() { + for { + conn, err := l.Listener.Accept() + if err != nil { + l.logger.Error("accept:", err) + l.errChan <- err + close(l.errChan) + return + } + go l.serveConn(conn) + } +} + +func (l *sshdListener) serveConn(conn net.Conn) { + start := time.Now() + l.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + defer func() { + l.logger.WithFields(map[string]interface{}{ + "duration": time.Since(start), + }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) + }() + + sc, chans, reqs, err := ssh.NewServerConn(conn, l.config) + if err != nil { + l.logger.Error(err) + conn.Close() + return + } + defer sc.Close() + + go func() { + for newChannel := range chans { + // Check the type of channel + t := newChannel.ChannelType() + switch t { + case DirectForwardRequest: + channel, requests, err := newChannel.Accept() + if err != nil { + l.logger.Warnf("could not accept channel: %s", err.Error()) + continue + } + p := directForward{} + ssh.Unmarshal(newChannel.ExtraData(), &p) + + l.logger.Debug(p.String()) + + if p.Host1 == "" { + p.Host1 = "" + } + + go ssh.DiscardRequests(requests) + cc := sshd_util.NewDirectForwardConn(sc, channel, net.JoinHostPort(p.Host1, strconv.Itoa(int(p.Port1)))) + + select { + case l.cqueue <- cc: + default: + l.logger.Warnf("connection queue is full, client %s discarded", conn.RemoteAddr()) + newChannel.Reject(ssh.ResourceShortage, "connection queue is full") + cc.Close() + } + + default: + l.logger.Warnf("unsupported channel type: %s", t) + newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unsupported channel type: %s", t)) + } + } + }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + for req := range reqs { + switch req.Type { + case RemoteForwardRequest: + cc := sshd_util.NewRemoteForwardConn(ctx, sc, req) + + select { + case l.cqueue <- cc: + default: + l.logger.Warnf("connection queue is full, client %s discarded", conn.RemoteAddr()) + req.Reply(false, []byte("connection queue is full")) + cc.Close() + } + default: + l.logger.Warnf("unsupported request type: %s, want reply: %v", req.Type, req.WantReply) + req.Reply(false, nil) + } + } + }() + sc.Wait() +} + +// directForward is structure for RFC 4254 7.2 - can be used for "forwarded-tcpip" and "direct-tcpip" +type directForward struct { + Host1 string + Port1 uint32 + Host2 string + Port2 uint32 +} + +func (p directForward) String() string { + return fmt.Sprintf("%s:%d -> %s:%d", p.Host2, p.Port2, p.Host1, p.Port1) +} diff --git a/pkg/handler/forward/ssh/metadata.go b/pkg/listener/sshd/metadata.go similarity index 69% rename from pkg/handler/forward/ssh/metadata.go rename to pkg/listener/sshd/metadata.go index 9f8a4bd..312e4e6 100644 --- a/pkg/handler/forward/ssh/metadata.go +++ b/pkg/listener/sshd/metadata.go @@ -9,16 +9,22 @@ import ( "golang.org/x/crypto/ssh" ) +const ( + defaultBacklog = 128 +) + type metadata struct { signer ssh.Signer authorizedKeys map[string]bool + backlog int } -func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) { +func (l *sshdListener) parseMetadata(md mdata.Metadata) (err error) { const ( authorizedKeys = "authorizedKeys" privateKeyFile = "privateKeyFile" passphrase = "passphrase" + backlog = "backlog" ) if key := mdata.GetString(md, privateKeyFile); key != "" { @@ -29,20 +35,20 @@ func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) { pp := mdata.GetString(md, passphrase) if pp == "" { - h.md.signer, err = ssh.ParsePrivateKey(data) + l.md.signer, err = ssh.ParsePrivateKey(data) } else { - h.md.signer, err = ssh.ParsePrivateKeyWithPassphrase(data, []byte(pp)) + l.md.signer, err = ssh.ParsePrivateKeyWithPassphrase(data, []byte(pp)) } if err != nil { return err } } - if h.md.signer == nil { + if l.md.signer == nil { signer, err := ssh.NewSignerFromKey(tls_util.DefaultConfig.Clone().Certificates[0].PrivateKey) if err != nil { return err } - h.md.signer = signer + l.md.signer = signer } if name := mdata.GetString(md, authorizedKeys); name != "" { @@ -50,7 +56,12 @@ func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) { if err != nil { return err } - h.md.authorizedKeys = m + l.md.authorizedKeys = m + } + + l.md.backlog = mdata.GetInt(md, backlog) + if l.md.backlog <= 0 { + l.md.backlog = defaultBacklog } return