diff --git a/pkg/common/metrics/conn.go b/pkg/common/metrics/conn.go index b831b8a..d3fcee7 100644 --- a/pkg/common/metrics/conn.go +++ b/pkg/common/metrics/conn.go @@ -28,13 +28,13 @@ func WrapConn(service string, c net.Conn) net.Conn { func (c *serverConn) Read(b []byte) (n int, err error) { n, err = c.Conn.Read(b) - metrics.RequestInputBytes(c.service).Add(float64(n)) + metrics.InputBytes(c.service).Add(float64(n)) return } func (c *serverConn) Write(b []byte) (n int, err error) { n, err = c.Conn.Write(b) - metrics.RequestOutputBytes(c.service).Add(float64(n)) + metrics.OutputBytes(c.service).Add(float64(n)) return } @@ -52,13 +52,13 @@ func WrapPacketConn(service string, pc net.PacketConn) net.PacketConn { func (c *packetConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { n, addr, err = c.PacketConn.ReadFrom(p) - metrics.RequestInputBytes(c.service).Add(float64(n)) + metrics.InputBytes(c.service).Add(float64(n)) return } func (c *packetConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { n, err = c.PacketConn.WriteTo(p, addr) - metrics.RequestOutputBytes(c.service).Add(float64(n)) + metrics.OutputBytes(c.service).Add(float64(n)) return } @@ -98,7 +98,7 @@ func (c *udpConn) SetWriteBuffer(n int) error { func (c *udpConn) Read(b []byte) (n int, err error) { if nc, ok := c.PacketConn.(io.Reader); ok { n, err = nc.Read(b) - metrics.RequestInputBytes(c.service).Add(float64(n)) + metrics.InputBytes(c.service).Add(float64(n)) return } err = errUnsupport @@ -107,14 +107,14 @@ func (c *udpConn) Read(b []byte) (n int, err error) { func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { n, addr, err = c.PacketConn.ReadFrom(p) - metrics.RequestInputBytes(c.service).Add(float64(n)) + metrics.InputBytes(c.service).Add(float64(n)) return } func (c *udpConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) { if nc, ok := c.PacketConn.(readUDP); ok { n, addr, err = nc.ReadFromUDP(b) - metrics.RequestInputBytes(c.service).Add(float64(n)) + metrics.InputBytes(c.service).Add(float64(n)) return } err = errUnsupport @@ -124,7 +124,7 @@ func (c *udpConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) { func (c *udpConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) { if nc, ok := c.PacketConn.(readUDP); ok { n, oobn, flags, addr, err = nc.ReadMsgUDP(b, oob) - metrics.RequestInputBytes(c.service).Add(float64(n + oobn)) + metrics.InputBytes(c.service).Add(float64(n + oobn)) return } err = errUnsupport @@ -134,7 +134,7 @@ func (c *udpConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAd func (c *udpConn) Write(b []byte) (n int, err error) { if nc, ok := c.PacketConn.(io.Writer); ok { n, err = nc.Write(b) - metrics.RequestOutputBytes(c.service).Add(float64(n)) + metrics.OutputBytes(c.service).Add(float64(n)) return } err = errUnsupport @@ -143,14 +143,14 @@ func (c *udpConn) Write(b []byte) (n int, err error) { func (c *udpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { n, err = c.PacketConn.WriteTo(p, addr) - metrics.RequestOutputBytes(c.service).Add(float64(n)) + metrics.OutputBytes(c.service).Add(float64(n)) return } func (c *udpConn) WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) { if nc, ok := c.PacketConn.(writeUDP); ok { n, err = nc.WriteToUDP(b, addr) - metrics.RequestOutputBytes(c.service).Add(float64(n)) + metrics.OutputBytes(c.service).Add(float64(n)) return } err = errUnsupport @@ -160,7 +160,7 @@ func (c *udpConn) WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) { func (c *udpConn) WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) { if nc, ok := c.PacketConn.(writeUDP); ok { n, oobn, err = nc.WriteMsgUDP(b, oob, addr) - metrics.RequestOutputBytes(c.service).Add(float64(n + oobn)) + metrics.OutputBytes(c.service).Add(float64(n + oobn)) return } err = errUnsupport diff --git a/pkg/handler/relay.go b/pkg/common/net/relay/relay.go similarity index 99% rename from pkg/handler/relay.go rename to pkg/common/net/relay/relay.go index 5dd0af5..b2ea340 100644 --- a/pkg/handler/relay.go +++ b/pkg/common/net/relay/relay.go @@ -1,4 +1,4 @@ -package handler +package relay import ( "net" diff --git a/pkg/handler/transport.go b/pkg/common/net/transport.go similarity index 97% rename from pkg/handler/transport.go rename to pkg/common/net/transport.go index 9ab42f5..6abc26b 100644 --- a/pkg/handler/transport.go +++ b/pkg/common/net/transport.go @@ -1,4 +1,4 @@ -package handler +package net import ( "bufio" diff --git a/pkg/handler/auto/handler.go b/pkg/handler/auto/handler.go index 145ec4a..9fb4d72 100644 --- a/pkg/handler/auto/handler.go +++ b/pkg/handler/auto/handler.go @@ -8,6 +8,7 @@ import ( "github.com/go-gost/gosocks4" "github.com/go-gost/gosocks5" + netpkg "github.com/go-gost/gost/pkg/common/net" "github.com/go-gost/gost/pkg/handler" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" @@ -85,7 +86,7 @@ func (h *autoHandler) Init(md md.Metadata) error { return nil } -func (h *autoHandler) Handle(ctx context.Context, conn net.Conn) { +func (h *autoHandler) Handle(ctx context.Context, conn net.Conn) error { log := h.options.Logger.WithFields(map[string]any{ "remote": conn.RemoteAddr().String(), "local": conn.LocalAddr().String(), @@ -104,26 +105,27 @@ func (h *autoHandler) Handle(ctx context.Context, conn net.Conn) { if err != nil { log.Error(err) conn.Close() - return + return err } - conn = handler.NewBufferReaderConn(conn, br) + conn = netpkg.NewBufferReaderConn(conn, br) switch b[0] { case gosocks4.Ver4: // socks4 if h.socks4Handler != nil { - h.socks4Handler.Handle(ctx, conn) + return h.socks4Handler.Handle(ctx, conn) } case gosocks5.Ver5: // socks5 if h.socks5Handler != nil { - h.socks5Handler.Handle(ctx, conn) + return h.socks5Handler.Handle(ctx, conn) } case relay.Version1: // relay if h.relayHandler != nil { - h.relayHandler.Handle(ctx, conn) + return h.relayHandler.Handle(ctx, conn) } default: // http if h.httpHandler != nil { - h.httpHandler.Handle(ctx, conn) + return h.httpHandler.Handle(ctx, conn) } } + return nil } diff --git a/pkg/handler/dns/handler.go b/pkg/handler/dns/handler.go index c474f9a..222f571 100644 --- a/pkg/handler/dns/handler.go +++ b/pkg/handler/dns/handler.go @@ -98,7 +98,7 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) { return } -func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) { +func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) error { defer conn.Close() start := time.Now() @@ -120,18 +120,20 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) { n, err := conn.Read(*b) if err != nil { log.Error(err) - return + return err } reply, err := h.exchange(ctx, (*b)[:n], log) if err != nil { - return + return err } defer bufpool.Put(&reply) if _, err = conn.Write(reply); err != nil { log.Error(err) + return err } + return nil } func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger) ([]byte, error) { diff --git a/pkg/handler/forward/local/handler.go b/pkg/handler/forward/local/handler.go index edacab9..1e17a13 100644 --- a/pkg/handler/forward/local/handler.go +++ b/pkg/handler/forward/local/handler.go @@ -2,11 +2,13 @@ package local import ( "context" + "errors" "fmt" "net" "time" "github.com/go-gost/gost/pkg/chain" + netpkg "github.com/go-gost/gost/pkg/common/net" "github.com/go-gost/gost/pkg/handler" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" @@ -59,7 +61,7 @@ func (h *forwardHandler) Forward(group *chain.NodeGroup) { h.group = group } -func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { +func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) error { defer conn.Close() start := time.Now() @@ -77,8 +79,9 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { target := h.group.Next() if target == nil { - log.Error("no target available") - return + err := errors.New("target not available") + log.Error(err) + return err } network := "tcp" @@ -98,15 +101,17 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { // TODO: the router itself may be failed due to the failed node in the router, // the dead marker may be a wrong operation. target.Marker.Mark() - return + return err } defer cc.Close() target.Marker.Reset() t := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr) - handler.Transport(conn, cc) + netpkg.Transport(conn, cc) log.WithFields(map[string]any{ "duration": time.Since(t), }).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr) + + return nil } diff --git a/pkg/handler/forward/remote/handler.go b/pkg/handler/forward/remote/handler.go index fead5c8..8847757 100644 --- a/pkg/handler/forward/remote/handler.go +++ b/pkg/handler/forward/remote/handler.go @@ -2,11 +2,13 @@ package remote import ( "context" + "errors" "fmt" "net" "time" "github.com/go-gost/gost/pkg/chain" + netpkg "github.com/go-gost/gost/pkg/common/net" "github.com/go-gost/gost/pkg/handler" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" @@ -53,7 +55,7 @@ func (h *forwardHandler) Forward(group *chain.NodeGroup) { h.group = group } -func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { +func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) error { defer conn.Close() start := time.Now() @@ -71,8 +73,9 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { target := h.group.Next() if target == nil { - log.Error("no target available") - return + err := errors.New("target not available") + log.Error(err) + return err } network := "tcp" @@ -92,15 +95,17 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { // TODO: the router itself may be failed due to the failed node in the router, // the dead marker may be a wrong operation. target.Marker.Mark() - return + return err } defer cc.Close() target.Marker.Reset() t := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr) - handler.Transport(conn, cc) + netpkg.Transport(conn, cc) log.WithFields(map[string]any{ "duration": time.Since(t), }).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr) + + return nil } diff --git a/pkg/handler/handler.go b/pkg/handler/handler.go index 1016d91..61a5d73 100644 --- a/pkg/handler/handler.go +++ b/pkg/handler/handler.go @@ -10,7 +10,7 @@ import ( type Handler interface { Init(metadata.Metadata) error - Handle(context.Context, net.Conn) + Handle(context.Context, net.Conn) error } type Forwarder interface { diff --git a/pkg/handler/http/handler.go b/pkg/handler/http/handler.go index daf3528..01a095d 100644 --- a/pkg/handler/http/handler.go +++ b/pkg/handler/http/handler.go @@ -17,6 +17,7 @@ import ( "github.com/asaskevich/govalidator" "github.com/go-gost/gost/pkg/chain" + netpkg "github.com/go-gost/gost/pkg/common/net" "github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" @@ -57,7 +58,7 @@ func (h *httpHandler) Init(md md.Metadata) error { return nil } -func (h *httpHandler) Handle(ctx context.Context, conn net.Conn) { +func (h *httpHandler) Handle(ctx context.Context, conn net.Conn) error { defer conn.Close() start := time.Now() @@ -75,18 +76,14 @@ func (h *httpHandler) Handle(ctx context.Context, conn net.Conn) { req, err := http.ReadRequest(bufio.NewReader(conn)) if err != nil { log.Error(err) - return + return err } defer req.Body.Close() - h.handleRequest(ctx, conn, req, log) + return h.handleRequest(ctx, conn, req, log) } -func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *http.Request, log logger.Logger) { - if req == nil { - return - } - +func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *http.Request, log logger.Logger) error { if h.md.sni && !req.URL.IsAbs() && govalidator.IsDNSName(req.Host) { req.URL.Scheme = "http" } @@ -149,30 +146,27 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt } log.Info("bypass: ", addr) - resp.Write(conn) - return + return resp.Write(conn) } if !h.authenticate(conn, req, resp, log) { - return + return nil } if network == "udp" { - h.handleUDP(ctx, conn, network, req.Host, log) - return + return h.handleUDP(ctx, conn, network, req.Host, log) } if req.Method == "PRI" || (req.Method != http.MethodConnect && req.URL.Scheme != "http") { resp.StatusCode = http.StatusBadRequest - resp.Write(conn) if log.IsLevelEnabled(logger.DebugLevel) { dump, _ := httputil.DumpResponse(resp, false) log.Debug(string(dump)) } - return + return resp.Write(conn) } req.Header.Del("Proxy-Authorization") @@ -180,13 +174,12 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt cc, err := h.router.Dial(ctx, network, addr) if err != nil { resp.StatusCode = http.StatusServiceUnavailable - resp.Write(conn) if log.IsLevelEnabled(logger.DebugLevel) { dump, _ := httputil.DumpResponse(resp, false) log.Debug(string(dump)) } - return + return resp.Write(conn) } defer cc.Close() @@ -200,22 +193,24 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt } if err = resp.Write(conn); err != nil { log.Error(err) - return + return err } } else { req.Header.Del("Proxy-Connection") if err = req.Write(cc); err != nil { log.Error(err) - return + return err } } start := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), addr) - handler.Transport(conn, cc) + netpkg.Transport(conn, cc) log.WithFields(map[string]any{ "duration": time.Since(start), }).Infof("%s >-< %s", conn.RemoteAddr(), addr) + + return nil } func (h *httpHandler) decodeServerName(s string) (string, error) { @@ -292,7 +287,7 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http. defer cc.Close() req.Write(cc) - handler.Transport(conn, cc) + netpkg.Transport(conn, cc) return case "file": f, _ := os.Open(pr.Value) diff --git a/pkg/handler/http/udp.go b/pkg/handler/http/udp.go index 3343af9..b9e4fc0 100644 --- a/pkg/handler/http/udp.go +++ b/pkg/handler/http/udp.go @@ -2,17 +2,18 @@ package http import ( "context" + "errors" "net" "net/http" "net/http/httputil" "time" + "github.com/go-gost/gost/pkg/common/net/relay" "github.com/go-gost/gost/pkg/common/util/socks" - "github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/logger" ) -func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { +func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error { log = log.WithFields(map[string]any{ "cmd": "udp", }) @@ -28,15 +29,15 @@ func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, add if !h.md.enableUDP { resp.StatusCode = http.StatusForbidden - resp.Write(conn) if log.IsLevelEnabled(logger.DebugLevel) { dump, _ := httputil.DumpResponse(resp, false) log.Debug(string(dump)) } - log.Error("UDP relay is diabled") - return + log.Error("http: UDP relay is disabled") + + return resp.Write(conn) } resp.StatusCode = http.StatusOK @@ -46,24 +47,25 @@ func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, add } if err := resp.Write(conn); err != nil { log.Error(err) - return + return err } // obtain a udp connection c, err := h.router.Dial(ctx, "udp", "") // UDP association if err != nil { log.Error(err) - return + return err } defer c.Close() pc, ok := c.(net.PacketConn) if !ok { - log.Errorf("wrong connection type") - return + err = errors.New("wrong connection type") + log.Error(err) + return err } - relay := handler.NewUDPRelay(socks.UDPTunServerConn(conn), pc). + relay := relay.NewUDPRelay(socks.UDPTunServerConn(conn), pc). WithBypass(h.options.Bypass). WithLogger(log) @@ -73,4 +75,6 @@ func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, add log.WithFields(map[string]any{ "duration": time.Since(t), }).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) + + return nil } diff --git a/pkg/handler/http2/handler.go b/pkg/handler/http2/handler.go index c6bc572..cbecebb 100644 --- a/pkg/handler/http2/handler.go +++ b/pkg/handler/http2/handler.go @@ -19,6 +19,7 @@ import ( "time" "github.com/go-gost/gost/pkg/chain" + netpkg "github.com/go-gost/gost/pkg/common/net" "github.com/go-gost/gost/pkg/handler" http2_util "github.com/go-gost/gost/pkg/internal/util/http2" "github.com/go-gost/gost/pkg/logger" @@ -60,7 +61,7 @@ func (h *http2Handler) Init(md md.Metadata) error { return nil } -func (h *http2Handler) Handle(ctx context.Context, conn net.Conn) { +func (h *http2Handler) Handle(ctx context.Context, conn net.Conn) error { defer conn.Close() start := time.Now() @@ -77,16 +78,17 @@ func (h *http2Handler) Handle(ctx context.Context, conn net.Conn) { cc, ok := conn.(*http2_util.ServerConn) if !ok { - log.Error("wrong connection type") - return + err := errors.New("wrong connection type") + log.Error(err) + return err } - h.roundTrip(ctx, cc.Writer(), cc.Request(), log) + return h.roundTrip(ctx, cc.Writer(), cc.Request(), log) } // NOTE: there is an issue (golang/go#43989) will cause the client hangs // when server returns an non-200 status code, // May be fixed in go1.18. -func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req *http.Request, log logger.Logger) { +func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req *http.Request, log logger.Logger) error { // Try to get the actual host. // Compatible with GOST 2.x. if v := req.Header.Get("Gost-Target"); v != "" { @@ -129,7 +131,7 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req if h.options.Bypass != nil && h.options.Bypass.Contains(addr) { w.WriteHeader(http.StatusForbidden) log.Info("bypass: ", addr) - return + return nil } resp := &http.Response{ @@ -140,7 +142,7 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req } if !h.authenticate(w, req, resp, log) { - return + return nil } // delete the proxy related headers. @@ -151,7 +153,7 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req if err != nil { log.Error(err) w.WriteHeader(http.StatusServiceUnavailable) - return + return err } defer cc.Close() @@ -168,28 +170,31 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req if err != nil { log.Error(err) w.WriteHeader(http.StatusInternalServerError) - return + return err } defer conn.Close() start := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), addr) - handler.Transport(conn, cc) + netpkg.Transport(conn, cc) log.WithFields(map[string]any{ "duration": time.Since(start), }).Infof("%s >-< %s", conn.RemoteAddr(), addr) - return + return nil } start := time.Now() log.Infof("%s <-> %s", req.RemoteAddr, addr) - handler.Transport(&readWriter{r: req.Body, w: flushWriter{w}}, cc) + netpkg.Transport(&readWriter{r: req.Body, w: flushWriter{w}}, cc) log.WithFields(map[string]any{ "duration": time.Since(start), }).Infof("%s >-< %s", req.RemoteAddr, addr) - return + return nil } + + // TODO: forward request + return nil } func (h *http2Handler) decodeServerName(s string) (string, error) { diff --git a/pkg/handler/redirect/handler.go b/pkg/handler/redirect/handler.go index f4fd7b1..0e8db62 100644 --- a/pkg/handler/redirect/handler.go +++ b/pkg/handler/redirect/handler.go @@ -7,6 +7,7 @@ import ( "time" "github.com/go-gost/gost/pkg/chain" + netpkg "github.com/go-gost/gost/pkg/common/net" "github.com/go-gost/gost/pkg/handler" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" @@ -49,7 +50,7 @@ func (h *redirectHandler) Init(md md.Metadata) (err error) { return } -func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn) { +func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn) error { defer conn.Close() start := time.Now() @@ -78,7 +79,7 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn) { dstAddr, conn, err = h.getOriginalDstAddr(conn) if err != nil { log.Error(err) - return + return err } } @@ -90,20 +91,22 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn) { if h.options.Bypass != nil && h.options.Bypass.Contains(dstAddr.String()) { log.Info("bypass: ", dstAddr) - return + return nil } cc, err := h.router.Dial(ctx, network, dstAddr.String()) if err != nil { log.Error(err) - return + return err } defer cc.Close() t := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), dstAddr) - handler.Transport(conn, cc) + netpkg.Transport(conn, cc) log.WithFields(map[string]any{ "duration": time.Since(t), }).Infof("%s >-< %s", conn.RemoteAddr(), dstAddr) + + return nil } diff --git a/pkg/handler/relay/bind.go b/pkg/handler/relay/bind.go index 726cdfa..d8c0a02 100644 --- a/pkg/handler/relay/bind.go +++ b/pkg/handler/relay/bind.go @@ -6,14 +6,15 @@ import ( "net" "time" + netpkg "github.com/go-gost/gost/pkg/common/net" + net_relay "github.com/go-gost/gost/pkg/common/net/relay" "github.com/go-gost/gost/pkg/common/util/mux" "github.com/go-gost/gost/pkg/common/util/socks" - "github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/relay" ) -func (h *relayHandler) handleBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { +func (h *relayHandler) handleBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error { log = log.WithFields(map[string]any{ "dst": fmt.Sprintf("%s/%s", address, network), "cmd": "bind", @@ -28,19 +29,19 @@ func (h *relayHandler) handleBind(ctx context.Context, conn net.Conn, network, a if !h.md.enableBind { resp.Status = relay.StatusForbidden - resp.WriteTo(conn) - log.Error("BIND is diabled") - return + log.Error("relay: BIND is disabled") + _, err := resp.WriteTo(conn) + return err } if network == "tcp" { - h.bindTCP(ctx, conn, network, address, log) + return h.bindTCP(ctx, conn, network, address, log) } else { - h.bindUDP(ctx, conn, network, address, log) + return h.bindUDP(ctx, conn, network, address, log) } } -func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { +func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error { resp := relay.Response{ Version: relay.Version1, Status: relay.StatusOK, @@ -51,7 +52,7 @@ func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, addr log.Error(err) resp.Status = relay.StatusServiceUnavailable resp.WriteTo(conn) - return + return err } af := &relay.AddrFeature{} @@ -67,7 +68,7 @@ func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, addr if _, err := resp.WriteTo(conn); err != nil { log.Error(err) ln.Close() - return + return err } log = log.WithFields(map[string]any{ @@ -75,10 +76,10 @@ func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, addr }) log.Debugf("bind on %s OK", ln.Addr()) - h.serveTCPBind(ctx, conn, ln, log) + return h.serveTCPBind(ctx, conn, ln, log) } -func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { +func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error { resp := relay.Response{ Version: relay.Version1, Status: relay.StatusOK, @@ -88,7 +89,7 @@ func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, addr pc, err := net.ListenUDP(network, bindAddr) if err != nil { log.Error(err) - return + return err } defer pc.Close() @@ -104,7 +105,7 @@ func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, addr resp.Features = append(resp.Features, af) if _, err := resp.WriteTo(conn); err != nil { log.Error(err) - return + return err } log = log.WithFields(map[string]any{ @@ -112,25 +113,26 @@ func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, addr }) log.Debugf("bind on %s OK", pc.LocalAddr()) - relay := handler.NewUDPRelay(socks.UDPTunServerConn(conn), pc). + r := net_relay.NewUDPRelay(socks.UDPTunServerConn(conn), pc). WithBypass(h.options.Bypass). WithLogger(log) - relay.SetBufferSize(h.md.udpBufferSize) + r.SetBufferSize(h.md.udpBufferSize) t := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr()) - relay.Run() + r.Run() log.WithFields(map[string]any{ "duration": time.Since(t), }).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) + return nil } -func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.Listener, log logger.Logger) { +func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.Listener, log logger.Logger) error { // Upgrade connection to multiplex stream. session, err := mux.ClientSession(conn) if err != nil { log.Error(err) - return + return err } defer session.Close() @@ -150,7 +152,7 @@ func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.L rc, err := ln.Accept() if err != nil { log.Error(err) - return + return err } log.Debugf("peer %s accepted", rc.RemoteAddr()) @@ -183,7 +185,7 @@ func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.L t := time.Now() log.Infof("%s <-> %s", c.LocalAddr(), c.RemoteAddr()) - handler.Transport(sc, c) + netpkg.Transport(sc, c) log.WithFields(map[string]any{"duration": time.Since(t)}). Infof("%s >-< %s", c.LocalAddr(), c.RemoteAddr()) }(rc) diff --git a/pkg/handler/relay/connect.go b/pkg/handler/relay/connect.go index 495e275..9705826 100644 --- a/pkg/handler/relay/connect.go +++ b/pkg/handler/relay/connect.go @@ -2,16 +2,17 @@ package relay import ( "context" + "errors" "fmt" "net" "time" - "github.com/go-gost/gost/pkg/handler" + netpkg "github.com/go-gost/gost/pkg/common/net" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/relay" ) -func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { +func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error { log = log.WithFields(map[string]any{ "dst": fmt.Sprintf("%s/%s", address, network), "cmd": "connect", @@ -27,29 +28,30 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network if address == "" { resp.Status = relay.StatusBadRequest resp.WriteTo(conn) - log.Error("target not specified") - return + err := errors.New("target not specified") + log.Error(err) + return err } if h.options.Bypass != nil && h.options.Bypass.Contains(address) { log.Info("bypass: ", address) resp.Status = relay.StatusForbidden - resp.WriteTo(conn) - return + _, err := resp.WriteTo(conn) + return err } cc, err := h.router.Dial(ctx, network, address) if err != nil { resp.Status = relay.StatusNetworkUnreachable resp.WriteTo(conn) - return + return err } defer cc.Close() if h.md.noDelay { if _, err := resp.WriteTo(conn); err != nil { log.Error(err) - return + return err } } @@ -61,7 +63,7 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network if !h.md.noDelay { // cache the header if _, err := resp.WriteTo(&rc.wbuf); err != nil { - return + return err } } conn = rc @@ -72,7 +74,7 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network if !h.md.noDelay { // cache the header if _, err := resp.WriteTo(&rc.wbuf); err != nil { - return + return err } } conn = rc @@ -80,8 +82,10 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network t := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), address) - handler.Transport(conn, cc) + netpkg.Transport(conn, cc) log.WithFields(map[string]any{ "duration": time.Since(t), }).Infof("%s >-< %s", conn.RemoteAddr(), address) + + return nil } diff --git a/pkg/handler/relay/forward.go b/pkg/handler/relay/forward.go index 8ddbf9a..d7ad8e0 100644 --- a/pkg/handler/relay/forward.go +++ b/pkg/handler/relay/forward.go @@ -2,16 +2,17 @@ package relay import ( "context" + "errors" "fmt" "net" "time" - "github.com/go-gost/gost/pkg/handler" + netpkg "github.com/go-gost/gost/pkg/common/net" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/relay" ) -func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network string, log logger.Logger) { +func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network string, log logger.Logger) error { resp := relay.Response{ Version: relay.Version1, Status: relay.StatusOK, @@ -20,8 +21,9 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network if target == nil { resp.Status = relay.StatusServiceUnavailable resp.WriteTo(conn) - log.Error("no target available") - return + err := errors.New("target not available") + log.Error(err) + return err } log = log.WithFields(map[string]any{ @@ -41,7 +43,7 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network resp.WriteTo(conn) log.Error(err) - return + return err } defer cc.Close() target.Marker.Reset() @@ -49,7 +51,7 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network if h.md.noDelay { if _, err := resp.WriteTo(conn); err != nil { log.Error(err) - return + return err } } @@ -61,7 +63,7 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network if !h.md.noDelay { // cache the header if _, err := resp.WriteTo(&rc.wbuf); err != nil { - return + return err } } conn = rc @@ -72,7 +74,7 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network if !h.md.noDelay { // cache the header if _, err := resp.WriteTo(&rc.wbuf); err != nil { - return + return err } } conn = rc @@ -80,8 +82,10 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network t := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr) - handler.Transport(conn, cc) + netpkg.Transport(conn, cc) log.WithFields(map[string]any{ "duration": time.Since(t), }).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr) + + return nil } diff --git a/pkg/handler/relay/handler.go b/pkg/handler/relay/handler.go index 96a75f9..52acd01 100644 --- a/pkg/handler/relay/handler.go +++ b/pkg/handler/relay/handler.go @@ -2,6 +2,7 @@ package relay import ( "context" + "errors" "net" "strconv" "time" @@ -13,6 +14,11 @@ import ( "github.com/go-gost/relay" ) +var ( + ErrBadVersion = errors.New("relay: bad version") + ErrUnknownCmd = errors.New("relay: unknown command") +) + func init() { registry.HandlerRegistry().Register("relay", NewHandler) } @@ -53,7 +59,7 @@ func (h *relayHandler) Forward(group *chain.NodeGroup) { h.group = group } -func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) { +func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) error { defer conn.Close() start := time.Now() @@ -76,14 +82,15 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) { req := relay.Request{} if _, err := req.ReadFrom(conn); err != nil { log.Error(err) - return + return err } conn.SetReadDeadline(time.Time{}) if req.Version != relay.Version1 { - log.Error("bad version") - return + err := ErrBadVersion + log.Error(err) + return err } var user, pass string @@ -109,9 +116,9 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) { } if h.options.Auther != nil && !h.options.Auther.Authenticate(user, pass) { resp.Status = relay.StatusUnauthorized - resp.WriteTo(conn) log.Error("unauthorized") - return + _, err := resp.WriteTo(conn) + return err } network := "tcp" @@ -122,19 +129,19 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) { if h.group != nil { if address != "" { resp.Status = relay.StatusForbidden - resp.WriteTo(conn) log.Error("forward mode, connect is forbidden") - return + _, err := resp.WriteTo(conn) + return err } // forward mode - h.handleForward(ctx, conn, network, log) - return + return h.handleForward(ctx, conn, network, log) } switch req.Flags & relay.CmdMask { case 0, relay.CONNECT: - h.handleConnect(ctx, conn, network, address, log) + return h.handleConnect(ctx, conn, network, address, log) case relay.BIND: - h.handleBind(ctx, conn, network, address, log) + return h.handleBind(ctx, conn, network, address, log) } + return ErrUnknownCmd } diff --git a/pkg/handler/sni/handler.go b/pkg/handler/sni/handler.go index d40be0c..a082bfe 100644 --- a/pkg/handler/sni/handler.go +++ b/pkg/handler/sni/handler.go @@ -13,6 +13,7 @@ import ( "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/common/bufpool" + netpkg "github.com/go-gost/gost/pkg/common/net" "github.com/go-gost/gost/pkg/handler" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" @@ -70,7 +71,7 @@ func (h *sniHandler) Init(md md.Metadata) (err error) { return nil } -func (h *sniHandler) Handle(ctx context.Context, conn net.Conn) { +func (h *sniHandler) Handle(ctx context.Context, conn net.Conn) error { defer conn.Close() start := time.Now() @@ -89,7 +90,7 @@ func (h *sniHandler) Handle(ctx context.Context, conn net.Conn) { var hdr [dissector.RecordHeaderLen]byte if _, err := io.ReadFull(conn, hdr[:]); err != nil { log.Error(err) - return + return err } if hdr[0] != dissector.Handshake { @@ -100,9 +101,9 @@ func (h *sniHandler) Handle(ctx context.Context, conn net.Conn) { } if h.httpHandler != nil { - h.httpHandler.Handle(ctx, conn) + return h.httpHandler.Handle(ctx, conn) } - return + return nil } length := binary.BigEndian.Uint16(hdr[3:5]) @@ -111,14 +112,14 @@ func (h *sniHandler) Handle(ctx context.Context, conn net.Conn) { defer bufpool.Put(buf) if _, err := io.ReadFull(conn, (*buf)[dissector.RecordHeaderLen:]); err != nil { log.Error(err) - return + return err } copy(*buf, hdr[:]) opaque, host, err := h.decodeHost(bytes.NewReader(*buf)) if err != nil { log.Error(err) - return + return err } target := net.JoinHostPort(host, "443") @@ -129,26 +130,29 @@ func (h *sniHandler) Handle(ctx context.Context, conn net.Conn) { if h.options.Bypass != nil && h.options.Bypass.Contains(target) { log.Info("bypass: ", target) - return + return nil } cc, err := h.router.Dial(ctx, "tcp", target) if err != nil { - return + log.Error(err) + return err } defer cc.Close() if _, err := cc.Write(opaque); err != nil { log.Error(err) - return + return err } t := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), target) - handler.Transport(conn, cc) + netpkg.Transport(conn, cc) log.WithFields(map[string]any{ "duration": time.Since(t), }).Infof("%s >-< %s", conn.RemoteAddr(), target) + + return nil } func (h *sniHandler) decodeHost(r io.Reader) (opaque []byte, host string, err error) { diff --git a/pkg/handler/socks/v4/handler.go b/pkg/handler/socks/v4/handler.go index b370c26..0dcad44 100644 --- a/pkg/handler/socks/v4/handler.go +++ b/pkg/handler/socks/v4/handler.go @@ -2,17 +2,24 @@ package v4 import ( "context" + "errors" "net" "time" "github.com/go-gost/gosocks4" "github.com/go-gost/gost/pkg/chain" + netpkg "github.com/go-gost/gost/pkg/common/net" "github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" ) +var ( + ErrUnknownCmd = errors.New("socks4: unknown command") + ErrUnimplemented = errors.New("socks4: unimplemented") +) + func init() { registry.HandlerRegistry().Register("socks4", NewHandler) registry.HandlerRegistry().Register("socks4a", NewHandler) @@ -48,7 +55,7 @@ func (h *socks4Handler) Init(md md.Metadata) (err error) { return nil } -func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn) { +func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn) error { defer conn.Close() start := time.Now() @@ -72,7 +79,7 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn) { req, err := gosocks4.ReadRequest(conn) if err != nil { log.Error(err) - return + return err } log.Debug(req) @@ -81,22 +88,23 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn) { if h.options.Auther != nil && !h.options.Auther.Authenticate(string(req.Userid), "") { resp := gosocks4.NewReply(gosocks4.RejectedUserid, nil) - resp.Write(conn) log.Debug(resp) - return + return resp.Write(conn) } switch req.Cmd { case gosocks4.CmdConnect: - h.handleConnect(ctx, conn, req, log) + return h.handleConnect(ctx, conn, req, log) case gosocks4.CmdBind: - h.handleBind(ctx, conn, req) + return h.handleBind(ctx, conn, req) default: - log.Errorf("unknown cmd: %d", req.Cmd) + err = ErrUnknownCmd + log.Error(err) + return err } } -func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *gosocks4.Request, log logger.Logger) { +func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *gosocks4.Request, log logger.Logger) error { addr := req.Addr.String() log = log.WithFields(map[string]any{ @@ -106,10 +114,9 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g if h.options.Bypass != nil && h.options.Bypass.Contains(addr) { resp := gosocks4.NewReply(gosocks4.Rejected, nil) - resp.Write(conn) log.Debug(resp) log.Info("bypass: ", addr) - return + return resp.Write(conn) } cc, err := h.router.Dial(ctx, "tcp", addr) @@ -117,7 +124,7 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g resp := gosocks4.NewReply(gosocks4.Failed, nil) resp.Write(conn) log.Debug(resp) - return + return err } defer cc.Close() @@ -125,18 +132,21 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g resp := gosocks4.NewReply(gosocks4.Granted, nil) if err := resp.Write(conn); err != nil { log.Error(err) - return + return err } log.Debug(resp) t := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), addr) - handler.Transport(conn, cc) + netpkg.Transport(conn, cc) log.WithFields(map[string]any{ "duration": time.Since(t), }).Infof("%s >-< %s", conn.RemoteAddr(), addr) + + return nil } -func (h *socks4Handler) handleBind(ctx context.Context, conn net.Conn, req *gosocks4.Request) { +func (h *socks4Handler) handleBind(ctx context.Context, conn net.Conn, req *gosocks4.Request) error { // TODO: bind + return ErrUnimplemented } diff --git a/pkg/handler/socks/v5/bind.go b/pkg/handler/socks/v5/bind.go index 7b30635..56c9dd7 100644 --- a/pkg/handler/socks/v5/bind.go +++ b/pkg/handler/socks/v5/bind.go @@ -7,11 +7,11 @@ import ( "time" "github.com/go-gost/gosocks5" - "github.com/go-gost/gost/pkg/handler" + netpkg "github.com/go-gost/gost/pkg/common/net" "github.com/go-gost/gost/pkg/logger" ) -func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { +func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error { log = log.WithFields(map[string]any{ "dst": fmt.Sprintf("%s/%s", address, network), "cmd": "bind", @@ -21,17 +21,16 @@ func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, network, if !h.md.enableBind { reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) - reply.Write(conn) log.Debug(reply) - log.Error("BIND is diabled") - return + log.Error("socks5: BIND is disabled") + return reply.Write(conn) } // BIND does not support chain. - h.bindLocal(ctx, conn, network, address, log) + return h.bindLocal(ctx, conn, network, address, log) } -func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { +func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error { ln, err := net.Listen(network, address) // strict mode: if the port already in use, it will return error if err != nil { log.Error(err) @@ -40,7 +39,7 @@ func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, network, a log.Error(err) } log.Debug(reply) - return + return err } socksAddr := gosocks5.Addr{} @@ -55,7 +54,7 @@ func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, network, a if err := reply.Write(conn); err != nil { log.Error(err) ln.Close() - return + return err } log.Debug(reply) @@ -66,6 +65,7 @@ func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, network, a log.Debugf("bind on %s OK", ln.Addr()) h.serveBind(ctx, conn, ln, log) + return nil } func (h *socks5Handler) serveBind(ctx context.Context, conn net.Conn, ln net.Listener, log logger.Logger) { @@ -95,7 +95,7 @@ func (h *socks5Handler) serveBind(ctx context.Context, conn net.Conn, ln net.Lis defer close(errc) defer pc1.Close() - errc <- handler.Transport(conn, pc1) + errc <- netpkg.Transport(conn, pc1) }() return errc @@ -135,7 +135,7 @@ func (h *socks5Handler) serveBind(ctx context.Context, conn net.Conn, ln net.Lis start := time.Now() log.Infof("%s <-> %s", rc.LocalAddr(), rc.RemoteAddr()) - handler.Transport(pc2, rc) + netpkg.Transport(pc2, rc) log.WithFields(map[string]any{"duration": time.Since(start)}). Infof("%s >-< %s", rc.LocalAddr(), rc.RemoteAddr()) diff --git a/pkg/handler/socks/v5/connect.go b/pkg/handler/socks/v5/connect.go index 51c702b..766e722 100644 --- a/pkg/handler/socks/v5/connect.go +++ b/pkg/handler/socks/v5/connect.go @@ -7,11 +7,11 @@ import ( "time" "github.com/go-gost/gosocks5" - "github.com/go-gost/gost/pkg/handler" + netpkg "github.com/go-gost/gost/pkg/common/net" "github.com/go-gost/gost/pkg/logger" ) -func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { +func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error { log = log.WithFields(map[string]any{ "dst": fmt.Sprintf("%s/%s", address, network), "cmd": "connect", @@ -20,18 +20,17 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, networ if h.options.Bypass != nil && h.options.Bypass.Contains(address) { resp := gosocks5.NewReply(gosocks5.NotAllowed, nil) - resp.Write(conn) log.Debug(resp) log.Info("bypass: ", address) - return + return resp.Write(conn) } cc, err := h.router.Dial(ctx, network, address) if err != nil { resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil) - resp.Write(conn) log.Debug(resp) - return + resp.Write(conn) + return err } defer cc.Close() @@ -39,14 +38,16 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, networ resp := gosocks5.NewReply(gosocks5.Succeeded, nil) if err := resp.Write(conn); err != nil { log.Error(err) - return + return err } log.Debug(resp) t := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), address) - handler.Transport(conn, cc) + netpkg.Transport(conn, cc) log.WithFields(map[string]any{ "duration": time.Since(t), }).Infof("%s >-< %s", conn.RemoteAddr(), address) + + return nil } diff --git a/pkg/handler/socks/v5/handler.go b/pkg/handler/socks/v5/handler.go index 05918c8..2bde241 100644 --- a/pkg/handler/socks/v5/handler.go +++ b/pkg/handler/socks/v5/handler.go @@ -2,6 +2,7 @@ package v5 import ( "context" + "errors" "net" "time" @@ -13,6 +14,10 @@ import ( "github.com/go-gost/gost/pkg/registry" ) +var ( + ErrUnknownCmd = errors.New("socks5: unknown command") +) + func init() { registry.HandlerRegistry().Register("socks5", NewHandler) registry.HandlerRegistry().Register("socks", NewHandler) @@ -56,7 +61,7 @@ func (h *socks5Handler) Init(md md.Metadata) (err error) { return } -func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn) { +func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn) error { defer conn.Close() start := time.Now() @@ -81,7 +86,7 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn) { req, err := gosocks5.ReadRequest(conn) if err != nil { log.Error(err) - return + return err } log.Debug(req) conn.SetReadDeadline(time.Time{}) @@ -90,20 +95,21 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn) { switch req.Cmd { case gosocks5.CmdConnect: - h.handleConnect(ctx, conn, "tcp", address, log) + return h.handleConnect(ctx, conn, "tcp", address, log) case gosocks5.CmdBind: - h.handleBind(ctx, conn, "tcp", address, log) + return h.handleBind(ctx, conn, "tcp", address, log) case socks.CmdMuxBind: - h.handleMuxBind(ctx, conn, "tcp", address, log) + return h.handleMuxBind(ctx, conn, "tcp", address, log) case gosocks5.CmdUdp: - h.handleUDP(ctx, conn, log) + return h.handleUDP(ctx, conn, log) case socks.CmdUDPTun: - h.handleUDPTun(ctx, conn, "udp", address, log) + return h.handleUDPTun(ctx, conn, "udp", address, log) default: - log.Errorf("unknown cmd: %d", req.Cmd) + err = ErrUnknownCmd + log.Error(err) resp := gosocks5.NewReply(gosocks5.CmdUnsupported, nil) resp.Write(conn) log.Debug(resp) - return + return err } } diff --git a/pkg/handler/socks/v5/mbind.go b/pkg/handler/socks/v5/mbind.go index 6f32803..ccf138e 100644 --- a/pkg/handler/socks/v5/mbind.go +++ b/pkg/handler/socks/v5/mbind.go @@ -7,12 +7,12 @@ import ( "time" "github.com/go-gost/gosocks5" + netpkg "github.com/go-gost/gost/pkg/common/net" "github.com/go-gost/gost/pkg/common/util/mux" - "github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/logger" ) -func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { +func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error { log = log.WithFields(map[string]any{ "dst": fmt.Sprintf("%s/%s", address, network), "cmd": "mbind", @@ -22,16 +22,15 @@ func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, networ if !h.md.enableBind { reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) - reply.Write(conn) log.Debug(reply) - log.Error("BIND is diabled") - return + log.Error("socks5: BIND is disabled") + return reply.Write(conn) } - h.muxBindLocal(ctx, conn, network, address, log) + return h.muxBindLocal(ctx, conn, network, address, log) } -func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { +func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error { ln, err := net.Listen(network, address) // strict mode: if the port already in use, it will return error if err != nil { log.Error(err) @@ -40,7 +39,7 @@ func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, network log.Error(err) } log.Debug(reply) - return + return err } socksAddr := gosocks5.Addr{} @@ -56,7 +55,7 @@ func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, network if err := reply.Write(conn); err != nil { log.Error(err) ln.Close() - return + return err } log.Debug(reply) @@ -66,15 +65,15 @@ func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, network log.Debugf("bind on %s OK", ln.Addr()) - h.serveMuxBind(ctx, conn, ln, log) + return h.serveMuxBind(ctx, conn, ln, log) } -func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net.Listener, log logger.Logger) { +func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net.Listener, log logger.Logger) error { // Upgrade connection to multiplex stream. session, err := mux.ClientSession(conn) if err != nil { log.Error(err) - return + return err } defer session.Close() @@ -94,7 +93,7 @@ func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net. rc, err := ln.Accept() if err != nil { log.Error(err) - return + return err } log.Debugf("peer %s accepted", rc.RemoteAddr()) @@ -126,7 +125,7 @@ func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net. t := time.Now() log.Infof("%s <-> %s", c.LocalAddr(), c.RemoteAddr()) - handler.Transport(sc, c) + netpkg.Transport(sc, c) log.WithFields(map[string]any{"duration": time.Since(t)}). Infof("%s >-< %s", c.LocalAddr(), c.RemoteAddr()) }(rc) diff --git a/pkg/handler/socks/v5/udp.go b/pkg/handler/socks/v5/udp.go index 12f4c6a..923d4cd 100644 --- a/pkg/handler/socks/v5/udp.go +++ b/pkg/handler/socks/v5/udp.go @@ -2,6 +2,7 @@ package v5 import ( "context" + "errors" "fmt" "io" "io/ioutil" @@ -9,22 +10,21 @@ import ( "time" "github.com/go-gost/gosocks5" + "github.com/go-gost/gost/pkg/common/net/relay" "github.com/go-gost/gost/pkg/common/util/socks" - "github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/logger" ) -func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, log logger.Logger) { +func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, log logger.Logger) error { log = log.WithFields(map[string]any{ "cmd": "udp", }) if !h.md.enableUDP { reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) - reply.Write(conn) log.Debug(reply) - log.Error("UDP relay is diabled") - return + log.Error("socks5: UDP relay is disabled") + return reply.Write(conn) } cc, err := net.ListenUDP("udp", nil) @@ -33,7 +33,7 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, log logger reply := gosocks5.NewReply(gosocks5.Failure, nil) reply.Write(conn) log.Debug(reply) - return + return err } defer cc.Close() @@ -44,7 +44,7 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, log logger reply := gosocks5.NewReply(gosocks5.Succeeded, &saddr) if err := reply.Write(conn); err != nil { log.Error(err) - return + return err } log.Debug(reply) @@ -57,26 +57,29 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, log logger c, err := h.router.Dial(ctx, "udp", "") // UDP association if err != nil { log.Error(err) - return + return err } defer c.Close() pc, ok := c.(net.PacketConn) if !ok { - log.Errorf("wrong connection type") - return + err := errors.New("socks5: wrong connection type") + log.Error(err) + return err } - relay := handler.NewUDPRelay(socks.UDPConn(cc, h.md.udpBufferSize), pc). + r := relay.NewUDPRelay(socks.UDPConn(cc, h.md.udpBufferSize), pc). WithBypass(h.options.Bypass). WithLogger(log) - relay.SetBufferSize(h.md.udpBufferSize) + r.SetBufferSize(h.md.udpBufferSize) - go relay.Run() + go r.Run() t := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr()) io.Copy(ioutil.Discard, conn) log.WithFields(map[string]any{"duration": time.Since(t)}). Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr()) + + return nil } diff --git a/pkg/handler/socks/v5/udp_tun.go b/pkg/handler/socks/v5/udp_tun.go index 68c6a03..76693e0 100644 --- a/pkg/handler/socks/v5/udp_tun.go +++ b/pkg/handler/socks/v5/udp_tun.go @@ -6,12 +6,12 @@ import ( "time" "github.com/go-gost/gosocks5" + "github.com/go-gost/gost/pkg/common/net/relay" "github.com/go-gost/gost/pkg/common/util/socks" - "github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/logger" ) -func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { +func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error { log = log.WithFields(map[string]any{ "cmd": "udp-tun", }) @@ -25,26 +25,24 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network // relay mode if !h.md.enableUDP { reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) - reply.Write(conn) log.Debug(reply) - log.Error("UDP relay is diabled") - return + log.Error("socks5: UDP relay is disabled") + return reply.Write(conn) } } else { // BIND mode if !h.md.enableBind { reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) - reply.Write(conn) log.Debug(reply) - log.Error("BIND is diabled") - return + log.Error("socks5: BIND is disabled") + return reply.Write(conn) } } pc, err := net.ListenUDP(network, bindAddr) if err != nil { log.Error(err) - return + return err } defer pc.Close() @@ -53,20 +51,22 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network reply := gosocks5.NewReply(gosocks5.Succeeded, &saddr) if err := reply.Write(conn); err != nil { log.Error(err) - return + return err } log.Debug(reply) log.Debugf("bind on %s OK", pc.LocalAddr()) - relay := handler.NewUDPRelay(socks.UDPTunServerConn(conn), pc). + r := relay.NewUDPRelay(socks.UDPTunServerConn(conn), pc). WithBypass(h.options.Bypass). WithLogger(log) - relay.SetBufferSize(h.md.udpBufferSize) + r.SetBufferSize(h.md.udpBufferSize) t := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr()) - relay.Run() + r.Run() log.WithFields(map[string]any{ "duration": time.Since(t), }).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) + + return nil } diff --git a/pkg/handler/ss/handler.go b/pkg/handler/ss/handler.go index 5880f1e..e046b82 100644 --- a/pkg/handler/ss/handler.go +++ b/pkg/handler/ss/handler.go @@ -9,6 +9,7 @@ import ( "github.com/go-gost/gosocks5" "github.com/go-gost/gost/pkg/chain" + netpkg "github.com/go-gost/gost/pkg/common/net" "github.com/go-gost/gost/pkg/common/util/ss" "github.com/go-gost/gost/pkg/handler" md "github.com/go-gost/gost/pkg/metadata" @@ -59,7 +60,7 @@ func (h *ssHandler) Init(md md.Metadata) (err error) { return } -func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) { +func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) error { defer conn.Close() start := time.Now() @@ -87,7 +88,7 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) { if _, err := addr.ReadFrom(conn); err != nil { log.Error(err) io.Copy(ioutil.Discard, conn) - return + return err } log = log.WithFields(map[string]any{ @@ -98,19 +99,21 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) { if h.options.Bypass != nil && h.options.Bypass.Contains(addr.String()) { log.Info("bypass: ", addr.String()) - return + return nil } cc, err := h.router.Dial(ctx, "tcp", addr.String()) if err != nil { - return + return err } defer cc.Close() t := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), addr) - handler.Transport(conn, cc) + netpkg.Transport(conn, cc) log.WithFields(map[string]any{ "duration": time.Since(t), }).Infof("%s >-< %s", conn.RemoteAddr(), addr) + + return nil } diff --git a/pkg/handler/ss/udp/handler.go b/pkg/handler/ss/udp/handler.go index dcedacc..30b5647 100644 --- a/pkg/handler/ss/udp/handler.go +++ b/pkg/handler/ss/udp/handler.go @@ -2,6 +2,7 @@ package ss import ( "context" + "errors" "net" "time" @@ -60,7 +61,7 @@ func (h *ssuHandler) Init(md md.Metadata) (err error) { return } -func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) { +func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) error { defer conn.Close() start := time.Now() @@ -95,14 +96,15 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) { c, err := h.router.Dial(ctx, "udp", "") // UDP association if err != nil { log.Error(err) - return + return err } defer c.Close() cc, ok := c.(net.PacketConn) if !ok { - log.Errorf("wrong connection type") - return + err := errors.New("ss: wrong connection type") + log.Error(err) + return err } t := time.Now() @@ -110,6 +112,8 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) { h.relayPacket(pc, cc, log) log.WithFields(map[string]any{"duration": time.Since(t)}). Infof("%s >-< %s", conn.LocalAddr(), cc.LocalAddr()) + + return nil } func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn, log logger.Logger) (err error) { diff --git a/pkg/handler/sshd/handler.go b/pkg/handler/sshd/handler.go index 2f72db7..ec64dec 100644 --- a/pkg/handler/sshd/handler.go +++ b/pkg/handler/sshd/handler.go @@ -3,12 +3,14 @@ package ssh import ( "context" "encoding/binary" + "errors" "fmt" "net" "strconv" "time" "github.com/go-gost/gost/pkg/chain" + netpkg "github.com/go-gost/gost/pkg/common/net" "github.com/go-gost/gost/pkg/handler" sshd_util "github.com/go-gost/gost/pkg/internal/util/sshd" "github.com/go-gost/gost/pkg/logger" @@ -56,7 +58,7 @@ func (h *forwardHandler) Init(md md.Metadata) (err error) { return nil } -func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { +func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) error { defer conn.Close() log := h.options.Logger.WithFields(map[string]any{ @@ -66,16 +68,17 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { switch cc := conn.(type) { case *sshd_util.DirectForwardConn: - h.handleDirectForward(ctx, cc, log) + return h.handleDirectForward(ctx, cc, log) case *sshd_util.RemoteForwardConn: - h.handleRemoteForward(ctx, cc, log) + return h.handleRemoteForward(ctx, cc, log) default: - log.Error("wrong connection type") - return + err := errors.New("sshd: wrong connection type") + log.Error(err) + return err } } -func (h *forwardHandler) handleDirectForward(ctx context.Context, conn *sshd_util.DirectForwardConn, log logger.Logger) { +func (h *forwardHandler) handleDirectForward(ctx context.Context, conn *sshd_util.DirectForwardConn, log logger.Logger) error { targetAddr := conn.DstAddr() log = log.WithFields(map[string]any{ @@ -87,28 +90,33 @@ func (h *forwardHandler) handleDirectForward(ctx context.Context, conn *sshd_uti if h.options.Bypass != nil && h.options.Bypass.Contains(targetAddr) { log.Infof("bypass %s", targetAddr) - return + return nil } cc, err := h.router.Dial(ctx, "tcp", targetAddr) if err != nil { - return + return err } defer cc.Close() t := time.Now() log.Infof("%s <-> %s", cc.LocalAddr(), targetAddr) - handler.Transport(conn, cc) + netpkg.Transport(conn, cc) log.WithFields(map[string]any{ "duration": time.Since(t), }).Infof("%s >-< %s", cc.LocalAddr(), targetAddr) + + return nil } -func (h *forwardHandler) handleRemoteForward(ctx context.Context, conn *sshd_util.RemoteForwardConn, log logger.Logger) { +func (h *forwardHandler) handleRemoteForward(ctx context.Context, conn *sshd_util.RemoteForwardConn, log logger.Logger) error { req := conn.Request() t := tcpipForward{} - ssh.Unmarshal(req.Payload, &t) + if err := ssh.Unmarshal(req.Payload, &t); err != nil { + log.Error(err) + return err + } network := "tcp" addr := net.JoinHostPort(t.Host, strconv.Itoa(int(t.Port))) @@ -125,7 +133,7 @@ func (h *forwardHandler) handleRemoteForward(ctx context.Context, conn *sshd_uti if err != nil { log.Error(err) req.Reply(false, nil) - return + return err } defer ln.Close() @@ -149,7 +157,7 @@ func (h *forwardHandler) handleRemoteForward(ctx context.Context, conn *sshd_uti }() if err != nil { log.Error(err) - return + return err } sshConn := conn.Conn() @@ -191,7 +199,7 @@ func (h *forwardHandler) handleRemoteForward(ctx context.Context, conn *sshd_uti t := time.Now() log.Infof("%s <-> %s", conn.LocalAddr(), conn.RemoteAddr()) - handler.Transport(ch, conn) + netpkg.Transport(ch, conn) log.WithFields(map[string]any{ "duration": time.Since(t), }).Infof("%s >-< %s", conn.LocalAddr(), conn.RemoteAddr()) @@ -205,6 +213,8 @@ func (h *forwardHandler) handleRemoteForward(ctx context.Context, conn *sshd_uti log.WithFields(map[string]any{ "duration": time.Since(tm), }).Infof("%s >-< %s", conn.RemoteAddr(), addr) + + return nil } func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) { diff --git a/pkg/handler/tap/handler.go b/pkg/handler/tap/handler.go index 26ea1fd..25c2fa8 100644 --- a/pkg/handler/tap/handler.go +++ b/pkg/handler/tap/handler.go @@ -76,15 +76,16 @@ func (h *tapHandler) Forward(group *chain.NodeGroup) { h.group = group } -func (h *tapHandler) Handle(ctx context.Context, conn net.Conn) { +func (h *tapHandler) Handle(ctx context.Context, conn net.Conn) error { defer os.Exit(0) defer conn.Close() log := h.options.Logger cc, ok := conn.(*tap_util.Conn) if !ok || cc.Config() == nil { - log.Error("invalid connection") - return + err := errors.New("tap: wrong connection type") + log.Error(err) + return err } start := time.Now() @@ -109,7 +110,7 @@ func (h *tapHandler) Handle(ctx context.Context, conn net.Conn) { raddr, err = net.ResolveUDPAddr(network, target.Addr) if err != nil { log.Error(err) - return + return err } log = log.WithFields(map[string]any{ "dst": fmt.Sprintf("%s/%s", raddr.String(), raddr.Network()), @@ -118,6 +119,7 @@ func (h *tapHandler) Handle(ctx context.Context, conn net.Conn) { } h.handleLoop(ctx, conn, raddr, cc.Config(), log) + return nil } func (h *tapHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Addr, config *tap_util.Config, log logger.Logger) { diff --git a/pkg/handler/tun/handler.go b/pkg/handler/tun/handler.go index 3dac462..88e5374 100644 --- a/pkg/handler/tun/handler.go +++ b/pkg/handler/tun/handler.go @@ -78,7 +78,7 @@ func (h *tunHandler) Forward(group *chain.NodeGroup) { h.group = group } -func (h *tunHandler) Handle(ctx context.Context, conn net.Conn) { +func (h *tunHandler) Handle(ctx context.Context, conn net.Conn) error { defer os.Exit(0) defer conn.Close() @@ -86,8 +86,9 @@ func (h *tunHandler) Handle(ctx context.Context, conn net.Conn) { cc, ok := conn.(*tun_util.Conn) if !ok || cc.Config() == nil { - log.Error("invalid connection") - return + err := errors.New("tun: wrong connection type") + log.Error(err) + return err } start := time.Now() @@ -112,7 +113,7 @@ func (h *tunHandler) Handle(ctx context.Context, conn net.Conn) { raddr, err = net.ResolveUDPAddr(network, target.Addr) if err != nil { log.Error(err) - return + return err } log = log.WithFields(map[string]any{ "dst": fmt.Sprintf("%s/%s", raddr.String(), raddr.Network()), @@ -121,6 +122,7 @@ func (h *tunHandler) Handle(ctx context.Context, conn net.Conn) { } h.handleLoop(ctx, conn, raddr, cc.Config(), log) + return nil } func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Addr, config *tun_util.Config, log logger.Logger) { diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index 4cea365..bc7d5a8 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -5,16 +5,33 @@ import ( ) var ( - global = newMetrics() + metrics = newMetrics() ) +type Gauge interface { + Inc() + Dec() + Add(float64) + Set(float64) +} + +type Counter interface { + Inc() + Add(float64) +} + +type Observer interface { + Observe(float64) +} + type Metrics struct { - services prometheus.Gauge - requests *prometheus.CounterVec - requestsInFlight *prometheus.GaugeVec - requestSeconds *prometheus.HistogramVec - requestInputBytes *prometheus.CounterVec - requestOutputBytes *prometheus.CounterVec + services prometheus.Gauge + requests *prometheus.CounterVec + requestsInFlight *prometheus.GaugeVec + requestSeconds *prometheus.HistogramVec + inputBytes *prometheus.CounterVec + outputBytes *prometheus.CounterVec + handlerErrors *prometheus.CounterVec } func newMetrics() *Metrics { @@ -44,20 +61,26 @@ func newMetrics() *Metrics { Name: "gost_service_request_duration_seconds", Help: "Distribution of request latencies", Buckets: []float64{ - .005, .01, .025, .05, .1, .25, .5, 1, 2.5, 5, 10, 15, 20, 30, + .005, .01, .025, .05, .1, .25, .5, 1, 2.5, 5, 10, 15, 30, 60, }, }, []string{"service"}), - requestInputBytes: prometheus.NewCounterVec( + inputBytes: prometheus.NewCounterVec( prometheus.CounterOpts{ - Name: "gost_service_request_transfer_input_bytes_total", - Help: "Total request input data transfer size in bytes", + Name: "gost_service_transfer_input_bytes_total", + Help: "Total service input data transfer size in bytes", }, []string{"service"}), - requestOutputBytes: prometheus.NewCounterVec( + outputBytes: prometheus.NewCounterVec( prometheus.CounterOpts{ - Name: "gost_service_request_transfer_output_bytes_total", - Help: "Total request output data transfer size in bytes", + Name: "gost_service_transfer_output_bytes_total", + Help: "Total service output data transfer size in bytes", + }, + []string{"service"}), + handlerErrors: prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "gost_service_handler_errors_total", + Help: "Total service handler errors", }, []string{"service"}), } @@ -65,31 +88,35 @@ func newMetrics() *Metrics { prometheus.MustRegister(m.requests) prometheus.MustRegister(m.requestsInFlight) prometheus.MustRegister(m.requestSeconds) - prometheus.MustRegister(m.requestInputBytes) - prometheus.MustRegister(m.requestOutputBytes) + prometheus.MustRegister(m.inputBytes) + prometheus.MustRegister(m.outputBytes) return m } -func Services() prometheus.Gauge { - return global.services +func Services() Gauge { + return metrics.services } -func Requests(service string) prometheus.Counter { - return global.requests.With(prometheus.Labels{"service": service}) +func Requests(service string) Counter { + return metrics.requests.With(prometheus.Labels{"service": service}) } -func RequestsInFlight(service string) prometheus.Gauge { - return global.requestsInFlight.With(prometheus.Labels{"service": service}) +func RequestsInFlight(service string) Gauge { + return metrics.requestsInFlight.With(prometheus.Labels{"service": service}) } -func RequestSeconds(service string) prometheus.Observer { - return global.requestSeconds.With(prometheus.Labels{"service": service}) +func RequestSeconds(service string) Observer { + return metrics.requestSeconds.With(prometheus.Labels{"service": service}) } -func RequestInputBytes(service string) prometheus.Counter { - return global.requestInputBytes.With(prometheus.Labels{"service": service}) +func InputBytes(service string) Counter { + return metrics.inputBytes.With(prometheus.Labels{"service": service}) } -func RequestOutputBytes(service string) prometheus.Counter { - return global.requestOutputBytes.With(prometheus.Labels{"service": service}) +func OutputBytes(service string) Counter { + return metrics.outputBytes.With(prometheus.Labels{"service": service}) +} + +func HandlerErrors(service string) Counter { + return metrics.handlerErrors.With(prometheus.Labels{"service": service}) } diff --git a/pkg/service/service.go b/pkg/service/service.go index 19cd7c0..f36b6cf 100644 --- a/pkg/service/service.go +++ b/pkg/service/service.go @@ -10,7 +10,6 @@ import ( "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/metrics" - "github.com/prometheus/client_golang/prometheus" ) type options struct { @@ -105,11 +104,14 @@ func (s *service) Serve() error { metrics.RequestsInFlight(s.name).Inc() defer metrics.RequestsInFlight(s.name).Dec() - timer := prometheus.NewTimer( - metrics.RequestSeconds(s.name)) - defer timer.ObserveDuration() + start := time.Now() + defer func() { + metrics.RequestSeconds(s.name).Observe(time.Since(start).Seconds()) + }() - s.handler.Handle(context.Background(), conn) + if err := s.handler.Handle(context.Background(), conn); err != nil { + metrics.HandlerErrors(s.name).Inc() + } }() } }