add handler error metrics

This commit is contained in:
ginuerzh 2022-03-05 16:37:45 +08:00
parent e587b4df7c
commit ee72cea036
31 changed files with 404 additions and 293 deletions

View File

@ -28,13 +28,13 @@ func WrapConn(service string, c net.Conn) net.Conn {
func (c *serverConn) Read(b []byte) (n int, err error) { func (c *serverConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b) n, err = c.Conn.Read(b)
metrics.RequestInputBytes(c.service).Add(float64(n)) metrics.InputBytes(c.service).Add(float64(n))
return return
} }
func (c *serverConn) Write(b []byte) (n int, err error) { func (c *serverConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b) n, err = c.Conn.Write(b)
metrics.RequestOutputBytes(c.service).Add(float64(n)) metrics.OutputBytes(c.service).Add(float64(n))
return return
} }
@ -52,13 +52,13 @@ func WrapPacketConn(service string, pc net.PacketConn) net.PacketConn {
func (c *packetConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { func (c *packetConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, addr, err = c.PacketConn.ReadFrom(p) n, addr, err = c.PacketConn.ReadFrom(p)
metrics.RequestInputBytes(c.service).Add(float64(n)) metrics.InputBytes(c.service).Add(float64(n))
return return
} }
func (c *packetConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { func (c *packetConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
n, err = c.PacketConn.WriteTo(p, addr) n, err = c.PacketConn.WriteTo(p, addr)
metrics.RequestOutputBytes(c.service).Add(float64(n)) metrics.OutputBytes(c.service).Add(float64(n))
return return
} }
@ -98,7 +98,7 @@ func (c *udpConn) SetWriteBuffer(n int) error {
func (c *udpConn) Read(b []byte) (n int, err error) { func (c *udpConn) Read(b []byte) (n int, err error) {
if nc, ok := c.PacketConn.(io.Reader); ok { if nc, ok := c.PacketConn.(io.Reader); ok {
n, err = nc.Read(b) n, err = nc.Read(b)
metrics.RequestInputBytes(c.service).Add(float64(n)) metrics.InputBytes(c.service).Add(float64(n))
return return
} }
err = errUnsupport err = errUnsupport
@ -107,14 +107,14 @@ func (c *udpConn) Read(b []byte) (n int, err error) {
func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, addr, err = c.PacketConn.ReadFrom(p) n, addr, err = c.PacketConn.ReadFrom(p)
metrics.RequestInputBytes(c.service).Add(float64(n)) metrics.InputBytes(c.service).Add(float64(n))
return return
} }
func (c *udpConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) { func (c *udpConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) {
if nc, ok := c.PacketConn.(readUDP); ok { if nc, ok := c.PacketConn.(readUDP); ok {
n, addr, err = nc.ReadFromUDP(b) n, addr, err = nc.ReadFromUDP(b)
metrics.RequestInputBytes(c.service).Add(float64(n)) metrics.InputBytes(c.service).Add(float64(n))
return return
} }
err = errUnsupport err = errUnsupport
@ -124,7 +124,7 @@ func (c *udpConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) {
func (c *udpConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) { func (c *udpConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) {
if nc, ok := c.PacketConn.(readUDP); ok { if nc, ok := c.PacketConn.(readUDP); ok {
n, oobn, flags, addr, err = nc.ReadMsgUDP(b, oob) n, oobn, flags, addr, err = nc.ReadMsgUDP(b, oob)
metrics.RequestInputBytes(c.service).Add(float64(n + oobn)) metrics.InputBytes(c.service).Add(float64(n + oobn))
return return
} }
err = errUnsupport err = errUnsupport
@ -134,7 +134,7 @@ func (c *udpConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAd
func (c *udpConn) Write(b []byte) (n int, err error) { func (c *udpConn) Write(b []byte) (n int, err error) {
if nc, ok := c.PacketConn.(io.Writer); ok { if nc, ok := c.PacketConn.(io.Writer); ok {
n, err = nc.Write(b) n, err = nc.Write(b)
metrics.RequestOutputBytes(c.service).Add(float64(n)) metrics.OutputBytes(c.service).Add(float64(n))
return return
} }
err = errUnsupport err = errUnsupport
@ -143,14 +143,14 @@ func (c *udpConn) Write(b []byte) (n int, err error) {
func (c *udpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { func (c *udpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
n, err = c.PacketConn.WriteTo(p, addr) n, err = c.PacketConn.WriteTo(p, addr)
metrics.RequestOutputBytes(c.service).Add(float64(n)) metrics.OutputBytes(c.service).Add(float64(n))
return return
} }
func (c *udpConn) WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) { func (c *udpConn) WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) {
if nc, ok := c.PacketConn.(writeUDP); ok { if nc, ok := c.PacketConn.(writeUDP); ok {
n, err = nc.WriteToUDP(b, addr) n, err = nc.WriteToUDP(b, addr)
metrics.RequestOutputBytes(c.service).Add(float64(n)) metrics.OutputBytes(c.service).Add(float64(n))
return return
} }
err = errUnsupport err = errUnsupport
@ -160,7 +160,7 @@ func (c *udpConn) WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) {
func (c *udpConn) WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) { func (c *udpConn) WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) {
if nc, ok := c.PacketConn.(writeUDP); ok { if nc, ok := c.PacketConn.(writeUDP); ok {
n, oobn, err = nc.WriteMsgUDP(b, oob, addr) n, oobn, err = nc.WriteMsgUDP(b, oob, addr)
metrics.RequestOutputBytes(c.service).Add(float64(n + oobn)) metrics.OutputBytes(c.service).Add(float64(n + oobn))
return return
} }
err = errUnsupport err = errUnsupport

View File

@ -1,4 +1,4 @@
package handler package relay
import ( import (
"net" "net"

View File

@ -1,4 +1,4 @@
package handler package net
import ( import (
"bufio" "bufio"

View File

@ -8,6 +8,7 @@ import (
"github.com/go-gost/gosocks4" "github.com/go-gost/gosocks4"
"github.com/go-gost/gosocks5" "github.com/go-gost/gosocks5"
netpkg "github.com/go-gost/gost/pkg/common/net"
"github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/handler"
md "github.com/go-gost/gost/pkg/metadata" md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry" "github.com/go-gost/gost/pkg/registry"
@ -85,7 +86,7 @@ func (h *autoHandler) Init(md md.Metadata) error {
return nil return nil
} }
func (h *autoHandler) Handle(ctx context.Context, conn net.Conn) { func (h *autoHandler) Handle(ctx context.Context, conn net.Conn) error {
log := h.options.Logger.WithFields(map[string]any{ log := h.options.Logger.WithFields(map[string]any{
"remote": conn.RemoteAddr().String(), "remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(), "local": conn.LocalAddr().String(),
@ -104,26 +105,27 @@ func (h *autoHandler) Handle(ctx context.Context, conn net.Conn) {
if err != nil { if err != nil {
log.Error(err) log.Error(err)
conn.Close() conn.Close()
return return err
} }
conn = handler.NewBufferReaderConn(conn, br) conn = netpkg.NewBufferReaderConn(conn, br)
switch b[0] { switch b[0] {
case gosocks4.Ver4: // socks4 case gosocks4.Ver4: // socks4
if h.socks4Handler != nil { if h.socks4Handler != nil {
h.socks4Handler.Handle(ctx, conn) return h.socks4Handler.Handle(ctx, conn)
} }
case gosocks5.Ver5: // socks5 case gosocks5.Ver5: // socks5
if h.socks5Handler != nil { if h.socks5Handler != nil {
h.socks5Handler.Handle(ctx, conn) return h.socks5Handler.Handle(ctx, conn)
} }
case relay.Version1: // relay case relay.Version1: // relay
if h.relayHandler != nil { if h.relayHandler != nil {
h.relayHandler.Handle(ctx, conn) return h.relayHandler.Handle(ctx, conn)
} }
default: // http default: // http
if h.httpHandler != nil { if h.httpHandler != nil {
h.httpHandler.Handle(ctx, conn) return h.httpHandler.Handle(ctx, conn)
} }
} }
return nil
} }

View File

@ -98,7 +98,7 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) {
return return
} }
func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) { func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) error {
defer conn.Close() defer conn.Close()
start := time.Now() start := time.Now()
@ -120,18 +120,20 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) {
n, err := conn.Read(*b) n, err := conn.Read(*b)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return err
} }
reply, err := h.exchange(ctx, (*b)[:n], log) reply, err := h.exchange(ctx, (*b)[:n], log)
if err != nil { if err != nil {
return return err
} }
defer bufpool.Put(&reply) defer bufpool.Put(&reply)
if _, err = conn.Write(reply); err != nil { if _, err = conn.Write(reply); err != nil {
log.Error(err) log.Error(err)
return err
} }
return nil
} }
func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger) ([]byte, error) { func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger) ([]byte, error) {

View File

@ -2,11 +2,13 @@ package local
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"time" "time"
"github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/chain"
netpkg "github.com/go-gost/gost/pkg/common/net"
"github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/handler"
md "github.com/go-gost/gost/pkg/metadata" md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry" "github.com/go-gost/gost/pkg/registry"
@ -59,7 +61,7 @@ func (h *forwardHandler) Forward(group *chain.NodeGroup) {
h.group = group h.group = group
} }
func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) error {
defer conn.Close() defer conn.Close()
start := time.Now() start := time.Now()
@ -77,8 +79,9 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) {
target := h.group.Next() target := h.group.Next()
if target == nil { if target == nil {
log.Error("no target available") err := errors.New("target not available")
return log.Error(err)
return err
} }
network := "tcp" network := "tcp"
@ -98,15 +101,17 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) {
// TODO: the router itself may be failed due to the failed node in the router, // TODO: the router itself may be failed due to the failed node in the router,
// the dead marker may be a wrong operation. // the dead marker may be a wrong operation.
target.Marker.Mark() target.Marker.Mark()
return return err
} }
defer cc.Close() defer cc.Close()
target.Marker.Reset() target.Marker.Reset()
t := time.Now() t := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr) log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr)
handler.Transport(conn, cc) netpkg.Transport(conn, cc)
log.WithFields(map[string]any{ log.WithFields(map[string]any{
"duration": time.Since(t), "duration": time.Since(t),
}).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr) }).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr)
return nil
} }

View File

@ -2,11 +2,13 @@ package remote
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"time" "time"
"github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/chain"
netpkg "github.com/go-gost/gost/pkg/common/net"
"github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/handler"
md "github.com/go-gost/gost/pkg/metadata" md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry" "github.com/go-gost/gost/pkg/registry"
@ -53,7 +55,7 @@ func (h *forwardHandler) Forward(group *chain.NodeGroup) {
h.group = group h.group = group
} }
func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) error {
defer conn.Close() defer conn.Close()
start := time.Now() start := time.Now()
@ -71,8 +73,9 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) {
target := h.group.Next() target := h.group.Next()
if target == nil { if target == nil {
log.Error("no target available") err := errors.New("target not available")
return log.Error(err)
return err
} }
network := "tcp" network := "tcp"
@ -92,15 +95,17 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) {
// TODO: the router itself may be failed due to the failed node in the router, // TODO: the router itself may be failed due to the failed node in the router,
// the dead marker may be a wrong operation. // the dead marker may be a wrong operation.
target.Marker.Mark() target.Marker.Mark()
return return err
} }
defer cc.Close() defer cc.Close()
target.Marker.Reset() target.Marker.Reset()
t := time.Now() t := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr) log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr)
handler.Transport(conn, cc) netpkg.Transport(conn, cc)
log.WithFields(map[string]any{ log.WithFields(map[string]any{
"duration": time.Since(t), "duration": time.Since(t),
}).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr) }).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr)
return nil
} }

View File

@ -10,7 +10,7 @@ import (
type Handler interface { type Handler interface {
Init(metadata.Metadata) error Init(metadata.Metadata) error
Handle(context.Context, net.Conn) Handle(context.Context, net.Conn) error
} }
type Forwarder interface { type Forwarder interface {

View File

@ -17,6 +17,7 @@ import (
"github.com/asaskevich/govalidator" "github.com/asaskevich/govalidator"
"github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/chain"
netpkg "github.com/go-gost/gost/pkg/common/net"
"github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata" md "github.com/go-gost/gost/pkg/metadata"
@ -57,7 +58,7 @@ func (h *httpHandler) Init(md md.Metadata) error {
return nil return nil
} }
func (h *httpHandler) Handle(ctx context.Context, conn net.Conn) { func (h *httpHandler) Handle(ctx context.Context, conn net.Conn) error {
defer conn.Close() defer conn.Close()
start := time.Now() start := time.Now()
@ -75,18 +76,14 @@ func (h *httpHandler) Handle(ctx context.Context, conn net.Conn) {
req, err := http.ReadRequest(bufio.NewReader(conn)) req, err := http.ReadRequest(bufio.NewReader(conn))
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return err
} }
defer req.Body.Close() defer req.Body.Close()
h.handleRequest(ctx, conn, req, log) return h.handleRequest(ctx, conn, req, log)
}
func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *http.Request, log logger.Logger) {
if req == nil {
return
} }
func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *http.Request, log logger.Logger) error {
if h.md.sni && !req.URL.IsAbs() && govalidator.IsDNSName(req.Host) { if h.md.sni && !req.URL.IsAbs() && govalidator.IsDNSName(req.Host) {
req.URL.Scheme = "http" req.URL.Scheme = "http"
} }
@ -149,30 +146,27 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
} }
log.Info("bypass: ", addr) log.Info("bypass: ", addr)
resp.Write(conn) return resp.Write(conn)
return
} }
if !h.authenticate(conn, req, resp, log) { if !h.authenticate(conn, req, resp, log) {
return return nil
} }
if network == "udp" { if network == "udp" {
h.handleUDP(ctx, conn, network, req.Host, log) return h.handleUDP(ctx, conn, network, req.Host, log)
return
} }
if req.Method == "PRI" || if req.Method == "PRI" ||
(req.Method != http.MethodConnect && req.URL.Scheme != "http") { (req.Method != http.MethodConnect && req.URL.Scheme != "http") {
resp.StatusCode = http.StatusBadRequest resp.StatusCode = http.StatusBadRequest
resp.Write(conn)
if log.IsLevelEnabled(logger.DebugLevel) { if log.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpResponse(resp, false) dump, _ := httputil.DumpResponse(resp, false)
log.Debug(string(dump)) log.Debug(string(dump))
} }
return return resp.Write(conn)
} }
req.Header.Del("Proxy-Authorization") req.Header.Del("Proxy-Authorization")
@ -180,13 +174,12 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
cc, err := h.router.Dial(ctx, network, addr) cc, err := h.router.Dial(ctx, network, addr)
if err != nil { if err != nil {
resp.StatusCode = http.StatusServiceUnavailable resp.StatusCode = http.StatusServiceUnavailable
resp.Write(conn)
if log.IsLevelEnabled(logger.DebugLevel) { if log.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpResponse(resp, false) dump, _ := httputil.DumpResponse(resp, false)
log.Debug(string(dump)) log.Debug(string(dump))
} }
return return resp.Write(conn)
} }
defer cc.Close() defer cc.Close()
@ -200,22 +193,24 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
} }
if err = resp.Write(conn); err != nil { if err = resp.Write(conn); err != nil {
log.Error(err) log.Error(err)
return return err
} }
} else { } else {
req.Header.Del("Proxy-Connection") req.Header.Del("Proxy-Connection")
if err = req.Write(cc); err != nil { if err = req.Write(cc); err != nil {
log.Error(err) log.Error(err)
return return err
} }
} }
start := time.Now() start := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), addr) log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
handler.Transport(conn, cc) netpkg.Transport(conn, cc)
log.WithFields(map[string]any{ log.WithFields(map[string]any{
"duration": time.Since(start), "duration": time.Since(start),
}).Infof("%s >-< %s", conn.RemoteAddr(), addr) }).Infof("%s >-< %s", conn.RemoteAddr(), addr)
return nil
} }
func (h *httpHandler) decodeServerName(s string) (string, error) { func (h *httpHandler) decodeServerName(s string) (string, error) {
@ -292,7 +287,7 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.
defer cc.Close() defer cc.Close()
req.Write(cc) req.Write(cc)
handler.Transport(conn, cc) netpkg.Transport(conn, cc)
return return
case "file": case "file":
f, _ := os.Open(pr.Value) f, _ := os.Open(pr.Value)

View File

@ -2,17 +2,18 @@ package http
import ( import (
"context" "context"
"errors"
"net" "net"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"time" "time"
"github.com/go-gost/gost/pkg/common/net/relay"
"github.com/go-gost/gost/pkg/common/util/socks" "github.com/go-gost/gost/pkg/common/util/socks"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
) )
func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error {
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{
"cmd": "udp", "cmd": "udp",
}) })
@ -28,15 +29,15 @@ func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, add
if !h.md.enableUDP { if !h.md.enableUDP {
resp.StatusCode = http.StatusForbidden resp.StatusCode = http.StatusForbidden
resp.Write(conn)
if log.IsLevelEnabled(logger.DebugLevel) { if log.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpResponse(resp, false) dump, _ := httputil.DumpResponse(resp, false)
log.Debug(string(dump)) log.Debug(string(dump))
} }
log.Error("UDP relay is diabled")
return log.Error("http: UDP relay is disabled")
return resp.Write(conn)
} }
resp.StatusCode = http.StatusOK resp.StatusCode = http.StatusOK
@ -46,24 +47,25 @@ func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, add
} }
if err := resp.Write(conn); err != nil { if err := resp.Write(conn); err != nil {
log.Error(err) log.Error(err)
return return err
} }
// obtain a udp connection // obtain a udp connection
c, err := h.router.Dial(ctx, "udp", "") // UDP association c, err := h.router.Dial(ctx, "udp", "") // UDP association
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return err
} }
defer c.Close() defer c.Close()
pc, ok := c.(net.PacketConn) pc, ok := c.(net.PacketConn)
if !ok { if !ok {
log.Errorf("wrong connection type") err = errors.New("wrong connection type")
return log.Error(err)
return err
} }
relay := handler.NewUDPRelay(socks.UDPTunServerConn(conn), pc). relay := relay.NewUDPRelay(socks.UDPTunServerConn(conn), pc).
WithBypass(h.options.Bypass). WithBypass(h.options.Bypass).
WithLogger(log) WithLogger(log)
@ -73,4 +75,6 @@ func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, add
log.WithFields(map[string]any{ log.WithFields(map[string]any{
"duration": time.Since(t), "duration": time.Since(t),
}).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) }).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
return nil
} }

View File

@ -19,6 +19,7 @@ import (
"time" "time"
"github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/chain"
netpkg "github.com/go-gost/gost/pkg/common/net"
"github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/handler"
http2_util "github.com/go-gost/gost/pkg/internal/util/http2" http2_util "github.com/go-gost/gost/pkg/internal/util/http2"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
@ -60,7 +61,7 @@ func (h *http2Handler) Init(md md.Metadata) error {
return nil return nil
} }
func (h *http2Handler) Handle(ctx context.Context, conn net.Conn) { func (h *http2Handler) Handle(ctx context.Context, conn net.Conn) error {
defer conn.Close() defer conn.Close()
start := time.Now() start := time.Now()
@ -77,16 +78,17 @@ func (h *http2Handler) Handle(ctx context.Context, conn net.Conn) {
cc, ok := conn.(*http2_util.ServerConn) cc, ok := conn.(*http2_util.ServerConn)
if !ok { if !ok {
log.Error("wrong connection type") err := errors.New("wrong connection type")
return log.Error(err)
return err
} }
h.roundTrip(ctx, cc.Writer(), cc.Request(), log) return h.roundTrip(ctx, cc.Writer(), cc.Request(), log)
} }
// NOTE: there is an issue (golang/go#43989) will cause the client hangs // NOTE: there is an issue (golang/go#43989) will cause the client hangs
// when server returns an non-200 status code, // when server returns an non-200 status code,
// May be fixed in go1.18. // May be fixed in go1.18.
func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req *http.Request, log logger.Logger) { func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req *http.Request, log logger.Logger) error {
// Try to get the actual host. // Try to get the actual host.
// Compatible with GOST 2.x. // Compatible with GOST 2.x.
if v := req.Header.Get("Gost-Target"); v != "" { if v := req.Header.Get("Gost-Target"); v != "" {
@ -129,7 +131,7 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req
if h.options.Bypass != nil && h.options.Bypass.Contains(addr) { if h.options.Bypass != nil && h.options.Bypass.Contains(addr) {
w.WriteHeader(http.StatusForbidden) w.WriteHeader(http.StatusForbidden)
log.Info("bypass: ", addr) log.Info("bypass: ", addr)
return return nil
} }
resp := &http.Response{ resp := &http.Response{
@ -140,7 +142,7 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req
} }
if !h.authenticate(w, req, resp, log) { if !h.authenticate(w, req, resp, log) {
return return nil
} }
// delete the proxy related headers. // delete the proxy related headers.
@ -151,7 +153,7 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req
if err != nil { if err != nil {
log.Error(err) log.Error(err)
w.WriteHeader(http.StatusServiceUnavailable) w.WriteHeader(http.StatusServiceUnavailable)
return return err
} }
defer cc.Close() defer cc.Close()
@ -168,28 +170,31 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req
if err != nil { if err != nil {
log.Error(err) log.Error(err)
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return return err
} }
defer conn.Close() defer conn.Close()
start := time.Now() start := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), addr) log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
handler.Transport(conn, cc) netpkg.Transport(conn, cc)
log.WithFields(map[string]any{ log.WithFields(map[string]any{
"duration": time.Since(start), "duration": time.Since(start),
}).Infof("%s >-< %s", conn.RemoteAddr(), addr) }).Infof("%s >-< %s", conn.RemoteAddr(), addr)
return return nil
} }
start := time.Now() start := time.Now()
log.Infof("%s <-> %s", req.RemoteAddr, addr) log.Infof("%s <-> %s", req.RemoteAddr, addr)
handler.Transport(&readWriter{r: req.Body, w: flushWriter{w}}, cc) netpkg.Transport(&readWriter{r: req.Body, w: flushWriter{w}}, cc)
log.WithFields(map[string]any{ log.WithFields(map[string]any{
"duration": time.Since(start), "duration": time.Since(start),
}).Infof("%s >-< %s", req.RemoteAddr, addr) }).Infof("%s >-< %s", req.RemoteAddr, addr)
return return nil
} }
// TODO: forward request
return nil
} }
func (h *http2Handler) decodeServerName(s string) (string, error) { func (h *http2Handler) decodeServerName(s string) (string, error) {

View File

@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/chain"
netpkg "github.com/go-gost/gost/pkg/common/net"
"github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/handler"
md "github.com/go-gost/gost/pkg/metadata" md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry" "github.com/go-gost/gost/pkg/registry"
@ -49,7 +50,7 @@ func (h *redirectHandler) Init(md md.Metadata) (err error) {
return return
} }
func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn) { func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn) error {
defer conn.Close() defer conn.Close()
start := time.Now() start := time.Now()
@ -78,7 +79,7 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn) {
dstAddr, conn, err = h.getOriginalDstAddr(conn) dstAddr, conn, err = h.getOriginalDstAddr(conn)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return err
} }
} }
@ -90,20 +91,22 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn) {
if h.options.Bypass != nil && h.options.Bypass.Contains(dstAddr.String()) { if h.options.Bypass != nil && h.options.Bypass.Contains(dstAddr.String()) {
log.Info("bypass: ", dstAddr) log.Info("bypass: ", dstAddr)
return return nil
} }
cc, err := h.router.Dial(ctx, network, dstAddr.String()) cc, err := h.router.Dial(ctx, network, dstAddr.String())
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return err
} }
defer cc.Close() defer cc.Close()
t := time.Now() t := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), dstAddr) log.Infof("%s <-> %s", conn.RemoteAddr(), dstAddr)
handler.Transport(conn, cc) netpkg.Transport(conn, cc)
log.WithFields(map[string]any{ log.WithFields(map[string]any{
"duration": time.Since(t), "duration": time.Since(t),
}).Infof("%s >-< %s", conn.RemoteAddr(), dstAddr) }).Infof("%s >-< %s", conn.RemoteAddr(), dstAddr)
return nil
} }

View File

@ -6,14 +6,15 @@ import (
"net" "net"
"time" "time"
netpkg "github.com/go-gost/gost/pkg/common/net"
net_relay "github.com/go-gost/gost/pkg/common/net/relay"
"github.com/go-gost/gost/pkg/common/util/mux" "github.com/go-gost/gost/pkg/common/util/mux"
"github.com/go-gost/gost/pkg/common/util/socks" "github.com/go-gost/gost/pkg/common/util/socks"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
"github.com/go-gost/relay" "github.com/go-gost/relay"
) )
func (h *relayHandler) handleBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { func (h *relayHandler) handleBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error {
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{
"dst": fmt.Sprintf("%s/%s", address, network), "dst": fmt.Sprintf("%s/%s", address, network),
"cmd": "bind", "cmd": "bind",
@ -28,19 +29,19 @@ func (h *relayHandler) handleBind(ctx context.Context, conn net.Conn, network, a
if !h.md.enableBind { if !h.md.enableBind {
resp.Status = relay.StatusForbidden resp.Status = relay.StatusForbidden
resp.WriteTo(conn) log.Error("relay: BIND is disabled")
log.Error("BIND is diabled") _, err := resp.WriteTo(conn)
return return err
} }
if network == "tcp" { if network == "tcp" {
h.bindTCP(ctx, conn, network, address, log) return h.bindTCP(ctx, conn, network, address, log)
} else { } else {
h.bindUDP(ctx, conn, network, address, log) return h.bindUDP(ctx, conn, network, address, log)
} }
} }
func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error {
resp := relay.Response{ resp := relay.Response{
Version: relay.Version1, Version: relay.Version1,
Status: relay.StatusOK, Status: relay.StatusOK,
@ -51,7 +52,7 @@ func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, addr
log.Error(err) log.Error(err)
resp.Status = relay.StatusServiceUnavailable resp.Status = relay.StatusServiceUnavailable
resp.WriteTo(conn) resp.WriteTo(conn)
return return err
} }
af := &relay.AddrFeature{} af := &relay.AddrFeature{}
@ -67,7 +68,7 @@ func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, addr
if _, err := resp.WriteTo(conn); err != nil { if _, err := resp.WriteTo(conn); err != nil {
log.Error(err) log.Error(err)
ln.Close() ln.Close()
return return err
} }
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{
@ -75,10 +76,10 @@ func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, addr
}) })
log.Debugf("bind on %s OK", ln.Addr()) log.Debugf("bind on %s OK", ln.Addr())
h.serveTCPBind(ctx, conn, ln, log) return h.serveTCPBind(ctx, conn, ln, log)
} }
func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error {
resp := relay.Response{ resp := relay.Response{
Version: relay.Version1, Version: relay.Version1,
Status: relay.StatusOK, Status: relay.StatusOK,
@ -88,7 +89,7 @@ func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, addr
pc, err := net.ListenUDP(network, bindAddr) pc, err := net.ListenUDP(network, bindAddr)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return err
} }
defer pc.Close() defer pc.Close()
@ -104,7 +105,7 @@ func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, addr
resp.Features = append(resp.Features, af) resp.Features = append(resp.Features, af)
if _, err := resp.WriteTo(conn); err != nil { if _, err := resp.WriteTo(conn); err != nil {
log.Error(err) log.Error(err)
return return err
} }
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{
@ -112,25 +113,26 @@ func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, addr
}) })
log.Debugf("bind on %s OK", pc.LocalAddr()) log.Debugf("bind on %s OK", pc.LocalAddr())
relay := handler.NewUDPRelay(socks.UDPTunServerConn(conn), pc). r := net_relay.NewUDPRelay(socks.UDPTunServerConn(conn), pc).
WithBypass(h.options.Bypass). WithBypass(h.options.Bypass).
WithLogger(log) WithLogger(log)
relay.SetBufferSize(h.md.udpBufferSize) r.SetBufferSize(h.md.udpBufferSize)
t := time.Now() t := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr()) log.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
relay.Run() r.Run()
log.WithFields(map[string]any{ log.WithFields(map[string]any{
"duration": time.Since(t), "duration": time.Since(t),
}).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) }).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
return nil
} }
func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.Listener, log logger.Logger) { func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.Listener, log logger.Logger) error {
// Upgrade connection to multiplex stream. // Upgrade connection to multiplex stream.
session, err := mux.ClientSession(conn) session, err := mux.ClientSession(conn)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return err
} }
defer session.Close() defer session.Close()
@ -150,7 +152,7 @@ func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.L
rc, err := ln.Accept() rc, err := ln.Accept()
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return err
} }
log.Debugf("peer %s accepted", rc.RemoteAddr()) log.Debugf("peer %s accepted", rc.RemoteAddr())
@ -183,7 +185,7 @@ func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.L
t := time.Now() t := time.Now()
log.Infof("%s <-> %s", c.LocalAddr(), c.RemoteAddr()) log.Infof("%s <-> %s", c.LocalAddr(), c.RemoteAddr())
handler.Transport(sc, c) netpkg.Transport(sc, c)
log.WithFields(map[string]any{"duration": time.Since(t)}). log.WithFields(map[string]any{"duration": time.Since(t)}).
Infof("%s >-< %s", c.LocalAddr(), c.RemoteAddr()) Infof("%s >-< %s", c.LocalAddr(), c.RemoteAddr())
}(rc) }(rc)

View File

@ -2,16 +2,17 @@ package relay
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"time" "time"
"github.com/go-gost/gost/pkg/handler" netpkg "github.com/go-gost/gost/pkg/common/net"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
"github.com/go-gost/relay" "github.com/go-gost/relay"
) )
func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error {
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{
"dst": fmt.Sprintf("%s/%s", address, network), "dst": fmt.Sprintf("%s/%s", address, network),
"cmd": "connect", "cmd": "connect",
@ -27,29 +28,30 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network
if address == "" { if address == "" {
resp.Status = relay.StatusBadRequest resp.Status = relay.StatusBadRequest
resp.WriteTo(conn) resp.WriteTo(conn)
log.Error("target not specified") err := errors.New("target not specified")
return log.Error(err)
return err
} }
if h.options.Bypass != nil && h.options.Bypass.Contains(address) { if h.options.Bypass != nil && h.options.Bypass.Contains(address) {
log.Info("bypass: ", address) log.Info("bypass: ", address)
resp.Status = relay.StatusForbidden resp.Status = relay.StatusForbidden
resp.WriteTo(conn) _, err := resp.WriteTo(conn)
return return err
} }
cc, err := h.router.Dial(ctx, network, address) cc, err := h.router.Dial(ctx, network, address)
if err != nil { if err != nil {
resp.Status = relay.StatusNetworkUnreachable resp.Status = relay.StatusNetworkUnreachable
resp.WriteTo(conn) resp.WriteTo(conn)
return return err
} }
defer cc.Close() defer cc.Close()
if h.md.noDelay { if h.md.noDelay {
if _, err := resp.WriteTo(conn); err != nil { if _, err := resp.WriteTo(conn); err != nil {
log.Error(err) log.Error(err)
return return err
} }
} }
@ -61,7 +63,7 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network
if !h.md.noDelay { if !h.md.noDelay {
// cache the header // cache the header
if _, err := resp.WriteTo(&rc.wbuf); err != nil { if _, err := resp.WriteTo(&rc.wbuf); err != nil {
return return err
} }
} }
conn = rc conn = rc
@ -72,7 +74,7 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network
if !h.md.noDelay { if !h.md.noDelay {
// cache the header // cache the header
if _, err := resp.WriteTo(&rc.wbuf); err != nil { if _, err := resp.WriteTo(&rc.wbuf); err != nil {
return return err
} }
} }
conn = rc conn = rc
@ -80,8 +82,10 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network
t := time.Now() t := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), address) log.Infof("%s <-> %s", conn.RemoteAddr(), address)
handler.Transport(conn, cc) netpkg.Transport(conn, cc)
log.WithFields(map[string]any{ log.WithFields(map[string]any{
"duration": time.Since(t), "duration": time.Since(t),
}).Infof("%s >-< %s", conn.RemoteAddr(), address) }).Infof("%s >-< %s", conn.RemoteAddr(), address)
return nil
} }

View File

@ -2,16 +2,17 @@ package relay
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"time" "time"
"github.com/go-gost/gost/pkg/handler" netpkg "github.com/go-gost/gost/pkg/common/net"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
"github.com/go-gost/relay" "github.com/go-gost/relay"
) )
func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network string, log logger.Logger) { func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network string, log logger.Logger) error {
resp := relay.Response{ resp := relay.Response{
Version: relay.Version1, Version: relay.Version1,
Status: relay.StatusOK, Status: relay.StatusOK,
@ -20,8 +21,9 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network
if target == nil { if target == nil {
resp.Status = relay.StatusServiceUnavailable resp.Status = relay.StatusServiceUnavailable
resp.WriteTo(conn) resp.WriteTo(conn)
log.Error("no target available") err := errors.New("target not available")
return log.Error(err)
return err
} }
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{
@ -41,7 +43,7 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network
resp.WriteTo(conn) resp.WriteTo(conn)
log.Error(err) log.Error(err)
return return err
} }
defer cc.Close() defer cc.Close()
target.Marker.Reset() target.Marker.Reset()
@ -49,7 +51,7 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network
if h.md.noDelay { if h.md.noDelay {
if _, err := resp.WriteTo(conn); err != nil { if _, err := resp.WriteTo(conn); err != nil {
log.Error(err) log.Error(err)
return return err
} }
} }
@ -61,7 +63,7 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network
if !h.md.noDelay { if !h.md.noDelay {
// cache the header // cache the header
if _, err := resp.WriteTo(&rc.wbuf); err != nil { if _, err := resp.WriteTo(&rc.wbuf); err != nil {
return return err
} }
} }
conn = rc conn = rc
@ -72,7 +74,7 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network
if !h.md.noDelay { if !h.md.noDelay {
// cache the header // cache the header
if _, err := resp.WriteTo(&rc.wbuf); err != nil { if _, err := resp.WriteTo(&rc.wbuf); err != nil {
return return err
} }
} }
conn = rc conn = rc
@ -80,8 +82,10 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network
t := time.Now() t := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr) log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr)
handler.Transport(conn, cc) netpkg.Transport(conn, cc)
log.WithFields(map[string]any{ log.WithFields(map[string]any{
"duration": time.Since(t), "duration": time.Since(t),
}).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr) }).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr)
return nil
} }

View File

@ -2,6 +2,7 @@ package relay
import ( import (
"context" "context"
"errors"
"net" "net"
"strconv" "strconv"
"time" "time"
@ -13,6 +14,11 @@ import (
"github.com/go-gost/relay" "github.com/go-gost/relay"
) )
var (
ErrBadVersion = errors.New("relay: bad version")
ErrUnknownCmd = errors.New("relay: unknown command")
)
func init() { func init() {
registry.HandlerRegistry().Register("relay", NewHandler) registry.HandlerRegistry().Register("relay", NewHandler)
} }
@ -53,7 +59,7 @@ func (h *relayHandler) Forward(group *chain.NodeGroup) {
h.group = group h.group = group
} }
func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) { func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) error {
defer conn.Close() defer conn.Close()
start := time.Now() start := time.Now()
@ -76,14 +82,15 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) {
req := relay.Request{} req := relay.Request{}
if _, err := req.ReadFrom(conn); err != nil { if _, err := req.ReadFrom(conn); err != nil {
log.Error(err) log.Error(err)
return return err
} }
conn.SetReadDeadline(time.Time{}) conn.SetReadDeadline(time.Time{})
if req.Version != relay.Version1 { if req.Version != relay.Version1 {
log.Error("bad version") err := ErrBadVersion
return log.Error(err)
return err
} }
var user, pass string var user, pass string
@ -109,9 +116,9 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) {
} }
if h.options.Auther != nil && !h.options.Auther.Authenticate(user, pass) { if h.options.Auther != nil && !h.options.Auther.Authenticate(user, pass) {
resp.Status = relay.StatusUnauthorized resp.Status = relay.StatusUnauthorized
resp.WriteTo(conn)
log.Error("unauthorized") log.Error("unauthorized")
return _, err := resp.WriteTo(conn)
return err
} }
network := "tcp" network := "tcp"
@ -122,19 +129,19 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) {
if h.group != nil { if h.group != nil {
if address != "" { if address != "" {
resp.Status = relay.StatusForbidden resp.Status = relay.StatusForbidden
resp.WriteTo(conn)
log.Error("forward mode, connect is forbidden") log.Error("forward mode, connect is forbidden")
return _, err := resp.WriteTo(conn)
return err
} }
// forward mode // forward mode
h.handleForward(ctx, conn, network, log) return h.handleForward(ctx, conn, network, log)
return
} }
switch req.Flags & relay.CmdMask { switch req.Flags & relay.CmdMask {
case 0, relay.CONNECT: case 0, relay.CONNECT:
h.handleConnect(ctx, conn, network, address, log) return h.handleConnect(ctx, conn, network, address, log)
case relay.BIND: case relay.BIND:
h.handleBind(ctx, conn, network, address, log) return h.handleBind(ctx, conn, network, address, log)
} }
return ErrUnknownCmd
} }

View File

@ -13,6 +13,7 @@ import (
"github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/chain"
"github.com/go-gost/gost/pkg/common/bufpool" "github.com/go-gost/gost/pkg/common/bufpool"
netpkg "github.com/go-gost/gost/pkg/common/net"
"github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/handler"
md "github.com/go-gost/gost/pkg/metadata" md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry" "github.com/go-gost/gost/pkg/registry"
@ -70,7 +71,7 @@ func (h *sniHandler) Init(md md.Metadata) (err error) {
return nil return nil
} }
func (h *sniHandler) Handle(ctx context.Context, conn net.Conn) { func (h *sniHandler) Handle(ctx context.Context, conn net.Conn) error {
defer conn.Close() defer conn.Close()
start := time.Now() start := time.Now()
@ -89,7 +90,7 @@ func (h *sniHandler) Handle(ctx context.Context, conn net.Conn) {
var hdr [dissector.RecordHeaderLen]byte var hdr [dissector.RecordHeaderLen]byte
if _, err := io.ReadFull(conn, hdr[:]); err != nil { if _, err := io.ReadFull(conn, hdr[:]); err != nil {
log.Error(err) log.Error(err)
return return err
} }
if hdr[0] != dissector.Handshake { if hdr[0] != dissector.Handshake {
@ -100,9 +101,9 @@ func (h *sniHandler) Handle(ctx context.Context, conn net.Conn) {
} }
if h.httpHandler != nil { if h.httpHandler != nil {
h.httpHandler.Handle(ctx, conn) return h.httpHandler.Handle(ctx, conn)
} }
return return nil
} }
length := binary.BigEndian.Uint16(hdr[3:5]) length := binary.BigEndian.Uint16(hdr[3:5])
@ -111,14 +112,14 @@ func (h *sniHandler) Handle(ctx context.Context, conn net.Conn) {
defer bufpool.Put(buf) defer bufpool.Put(buf)
if _, err := io.ReadFull(conn, (*buf)[dissector.RecordHeaderLen:]); err != nil { if _, err := io.ReadFull(conn, (*buf)[dissector.RecordHeaderLen:]); err != nil {
log.Error(err) log.Error(err)
return return err
} }
copy(*buf, hdr[:]) copy(*buf, hdr[:])
opaque, host, err := h.decodeHost(bytes.NewReader(*buf)) opaque, host, err := h.decodeHost(bytes.NewReader(*buf))
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return err
} }
target := net.JoinHostPort(host, "443") target := net.JoinHostPort(host, "443")
@ -129,26 +130,29 @@ func (h *sniHandler) Handle(ctx context.Context, conn net.Conn) {
if h.options.Bypass != nil && h.options.Bypass.Contains(target) { if h.options.Bypass != nil && h.options.Bypass.Contains(target) {
log.Info("bypass: ", target) log.Info("bypass: ", target)
return return nil
} }
cc, err := h.router.Dial(ctx, "tcp", target) cc, err := h.router.Dial(ctx, "tcp", target)
if err != nil { if err != nil {
return log.Error(err)
return err
} }
defer cc.Close() defer cc.Close()
if _, err := cc.Write(opaque); err != nil { if _, err := cc.Write(opaque); err != nil {
log.Error(err) log.Error(err)
return return err
} }
t := time.Now() t := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), target) log.Infof("%s <-> %s", conn.RemoteAddr(), target)
handler.Transport(conn, cc) netpkg.Transport(conn, cc)
log.WithFields(map[string]any{ log.WithFields(map[string]any{
"duration": time.Since(t), "duration": time.Since(t),
}).Infof("%s >-< %s", conn.RemoteAddr(), target) }).Infof("%s >-< %s", conn.RemoteAddr(), target)
return nil
} }
func (h *sniHandler) decodeHost(r io.Reader) (opaque []byte, host string, err error) { func (h *sniHandler) decodeHost(r io.Reader) (opaque []byte, host string, err error) {

View File

@ -2,17 +2,24 @@ package v4
import ( import (
"context" "context"
"errors"
"net" "net"
"time" "time"
"github.com/go-gost/gosocks4" "github.com/go-gost/gosocks4"
"github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/chain"
netpkg "github.com/go-gost/gost/pkg/common/net"
"github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata" md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry" "github.com/go-gost/gost/pkg/registry"
) )
var (
ErrUnknownCmd = errors.New("socks4: unknown command")
ErrUnimplemented = errors.New("socks4: unimplemented")
)
func init() { func init() {
registry.HandlerRegistry().Register("socks4", NewHandler) registry.HandlerRegistry().Register("socks4", NewHandler)
registry.HandlerRegistry().Register("socks4a", NewHandler) registry.HandlerRegistry().Register("socks4a", NewHandler)
@ -48,7 +55,7 @@ func (h *socks4Handler) Init(md md.Metadata) (err error) {
return nil return nil
} }
func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn) { func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn) error {
defer conn.Close() defer conn.Close()
start := time.Now() start := time.Now()
@ -72,7 +79,7 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn) {
req, err := gosocks4.ReadRequest(conn) req, err := gosocks4.ReadRequest(conn)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return err
} }
log.Debug(req) log.Debug(req)
@ -81,22 +88,23 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn) {
if h.options.Auther != nil && if h.options.Auther != nil &&
!h.options.Auther.Authenticate(string(req.Userid), "") { !h.options.Auther.Authenticate(string(req.Userid), "") {
resp := gosocks4.NewReply(gosocks4.RejectedUserid, nil) resp := gosocks4.NewReply(gosocks4.RejectedUserid, nil)
resp.Write(conn)
log.Debug(resp) log.Debug(resp)
return return resp.Write(conn)
} }
switch req.Cmd { switch req.Cmd {
case gosocks4.CmdConnect: case gosocks4.CmdConnect:
h.handleConnect(ctx, conn, req, log) return h.handleConnect(ctx, conn, req, log)
case gosocks4.CmdBind: case gosocks4.CmdBind:
h.handleBind(ctx, conn, req) return h.handleBind(ctx, conn, req)
default: default:
log.Errorf("unknown cmd: %d", req.Cmd) err = ErrUnknownCmd
log.Error(err)
return err
} }
} }
func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *gosocks4.Request, log logger.Logger) { func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *gosocks4.Request, log logger.Logger) error {
addr := req.Addr.String() addr := req.Addr.String()
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{
@ -106,10 +114,9 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g
if h.options.Bypass != nil && h.options.Bypass.Contains(addr) { if h.options.Bypass != nil && h.options.Bypass.Contains(addr) {
resp := gosocks4.NewReply(gosocks4.Rejected, nil) resp := gosocks4.NewReply(gosocks4.Rejected, nil)
resp.Write(conn)
log.Debug(resp) log.Debug(resp)
log.Info("bypass: ", addr) log.Info("bypass: ", addr)
return return resp.Write(conn)
} }
cc, err := h.router.Dial(ctx, "tcp", addr) cc, err := h.router.Dial(ctx, "tcp", addr)
@ -117,7 +124,7 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g
resp := gosocks4.NewReply(gosocks4.Failed, nil) resp := gosocks4.NewReply(gosocks4.Failed, nil)
resp.Write(conn) resp.Write(conn)
log.Debug(resp) log.Debug(resp)
return return err
} }
defer cc.Close() defer cc.Close()
@ -125,18 +132,21 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g
resp := gosocks4.NewReply(gosocks4.Granted, nil) resp := gosocks4.NewReply(gosocks4.Granted, nil)
if err := resp.Write(conn); err != nil { if err := resp.Write(conn); err != nil {
log.Error(err) log.Error(err)
return return err
} }
log.Debug(resp) log.Debug(resp)
t := time.Now() t := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), addr) log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
handler.Transport(conn, cc) netpkg.Transport(conn, cc)
log.WithFields(map[string]any{ log.WithFields(map[string]any{
"duration": time.Since(t), "duration": time.Since(t),
}).Infof("%s >-< %s", conn.RemoteAddr(), addr) }).Infof("%s >-< %s", conn.RemoteAddr(), addr)
return nil
} }
func (h *socks4Handler) handleBind(ctx context.Context, conn net.Conn, req *gosocks4.Request) { func (h *socks4Handler) handleBind(ctx context.Context, conn net.Conn, req *gosocks4.Request) error {
// TODO: bind // TODO: bind
return ErrUnimplemented
} }

View File

@ -7,11 +7,11 @@ import (
"time" "time"
"github.com/go-gost/gosocks5" "github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/handler" netpkg "github.com/go-gost/gost/pkg/common/net"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
) )
func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error {
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{
"dst": fmt.Sprintf("%s/%s", address, network), "dst": fmt.Sprintf("%s/%s", address, network),
"cmd": "bind", "cmd": "bind",
@ -21,17 +21,16 @@ func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, network,
if !h.md.enableBind { if !h.md.enableBind {
reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) reply := gosocks5.NewReply(gosocks5.NotAllowed, nil)
reply.Write(conn)
log.Debug(reply) log.Debug(reply)
log.Error("BIND is diabled") log.Error("socks5: BIND is disabled")
return return reply.Write(conn)
} }
// BIND does not support chain. // BIND does not support chain.
h.bindLocal(ctx, conn, network, address, log) return h.bindLocal(ctx, conn, network, address, log)
} }
func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error {
ln, err := net.Listen(network, address) // strict mode: if the port already in use, it will return error ln, err := net.Listen(network, address) // strict mode: if the port already in use, it will return error
if err != nil { if err != nil {
log.Error(err) log.Error(err)
@ -40,7 +39,7 @@ func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, network, a
log.Error(err) log.Error(err)
} }
log.Debug(reply) log.Debug(reply)
return return err
} }
socksAddr := gosocks5.Addr{} socksAddr := gosocks5.Addr{}
@ -55,7 +54,7 @@ func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, network, a
if err := reply.Write(conn); err != nil { if err := reply.Write(conn); err != nil {
log.Error(err) log.Error(err)
ln.Close() ln.Close()
return return err
} }
log.Debug(reply) log.Debug(reply)
@ -66,6 +65,7 @@ func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, network, a
log.Debugf("bind on %s OK", ln.Addr()) log.Debugf("bind on %s OK", ln.Addr())
h.serveBind(ctx, conn, ln, log) h.serveBind(ctx, conn, ln, log)
return nil
} }
func (h *socks5Handler) serveBind(ctx context.Context, conn net.Conn, ln net.Listener, log logger.Logger) { func (h *socks5Handler) serveBind(ctx context.Context, conn net.Conn, ln net.Listener, log logger.Logger) {
@ -95,7 +95,7 @@ func (h *socks5Handler) serveBind(ctx context.Context, conn net.Conn, ln net.Lis
defer close(errc) defer close(errc)
defer pc1.Close() defer pc1.Close()
errc <- handler.Transport(conn, pc1) errc <- netpkg.Transport(conn, pc1)
}() }()
return errc return errc
@ -135,7 +135,7 @@ func (h *socks5Handler) serveBind(ctx context.Context, conn net.Conn, ln net.Lis
start := time.Now() start := time.Now()
log.Infof("%s <-> %s", rc.LocalAddr(), rc.RemoteAddr()) log.Infof("%s <-> %s", rc.LocalAddr(), rc.RemoteAddr())
handler.Transport(pc2, rc) netpkg.Transport(pc2, rc)
log.WithFields(map[string]any{"duration": time.Since(start)}). log.WithFields(map[string]any{"duration": time.Since(start)}).
Infof("%s >-< %s", rc.LocalAddr(), rc.RemoteAddr()) Infof("%s >-< %s", rc.LocalAddr(), rc.RemoteAddr())

View File

@ -7,11 +7,11 @@ import (
"time" "time"
"github.com/go-gost/gosocks5" "github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/handler" netpkg "github.com/go-gost/gost/pkg/common/net"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
) )
func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error {
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{
"dst": fmt.Sprintf("%s/%s", address, network), "dst": fmt.Sprintf("%s/%s", address, network),
"cmd": "connect", "cmd": "connect",
@ -20,18 +20,17 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, networ
if h.options.Bypass != nil && h.options.Bypass.Contains(address) { if h.options.Bypass != nil && h.options.Bypass.Contains(address) {
resp := gosocks5.NewReply(gosocks5.NotAllowed, nil) resp := gosocks5.NewReply(gosocks5.NotAllowed, nil)
resp.Write(conn)
log.Debug(resp) log.Debug(resp)
log.Info("bypass: ", address) log.Info("bypass: ", address)
return return resp.Write(conn)
} }
cc, err := h.router.Dial(ctx, network, address) cc, err := h.router.Dial(ctx, network, address)
if err != nil { if err != nil {
resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil) resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil)
resp.Write(conn)
log.Debug(resp) log.Debug(resp)
return resp.Write(conn)
return err
} }
defer cc.Close() defer cc.Close()
@ -39,14 +38,16 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, networ
resp := gosocks5.NewReply(gosocks5.Succeeded, nil) resp := gosocks5.NewReply(gosocks5.Succeeded, nil)
if err := resp.Write(conn); err != nil { if err := resp.Write(conn); err != nil {
log.Error(err) log.Error(err)
return return err
} }
log.Debug(resp) log.Debug(resp)
t := time.Now() t := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), address) log.Infof("%s <-> %s", conn.RemoteAddr(), address)
handler.Transport(conn, cc) netpkg.Transport(conn, cc)
log.WithFields(map[string]any{ log.WithFields(map[string]any{
"duration": time.Since(t), "duration": time.Since(t),
}).Infof("%s >-< %s", conn.RemoteAddr(), address) }).Infof("%s >-< %s", conn.RemoteAddr(), address)
return nil
} }

View File

@ -2,6 +2,7 @@ package v5
import ( import (
"context" "context"
"errors"
"net" "net"
"time" "time"
@ -13,6 +14,10 @@ import (
"github.com/go-gost/gost/pkg/registry" "github.com/go-gost/gost/pkg/registry"
) )
var (
ErrUnknownCmd = errors.New("socks5: unknown command")
)
func init() { func init() {
registry.HandlerRegistry().Register("socks5", NewHandler) registry.HandlerRegistry().Register("socks5", NewHandler)
registry.HandlerRegistry().Register("socks", NewHandler) registry.HandlerRegistry().Register("socks", NewHandler)
@ -56,7 +61,7 @@ func (h *socks5Handler) Init(md md.Metadata) (err error) {
return return
} }
func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn) { func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn) error {
defer conn.Close() defer conn.Close()
start := time.Now() start := time.Now()
@ -81,7 +86,7 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn) {
req, err := gosocks5.ReadRequest(conn) req, err := gosocks5.ReadRequest(conn)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return err
} }
log.Debug(req) log.Debug(req)
conn.SetReadDeadline(time.Time{}) conn.SetReadDeadline(time.Time{})
@ -90,20 +95,21 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn) {
switch req.Cmd { switch req.Cmd {
case gosocks5.CmdConnect: case gosocks5.CmdConnect:
h.handleConnect(ctx, conn, "tcp", address, log) return h.handleConnect(ctx, conn, "tcp", address, log)
case gosocks5.CmdBind: case gosocks5.CmdBind:
h.handleBind(ctx, conn, "tcp", address, log) return h.handleBind(ctx, conn, "tcp", address, log)
case socks.CmdMuxBind: case socks.CmdMuxBind:
h.handleMuxBind(ctx, conn, "tcp", address, log) return h.handleMuxBind(ctx, conn, "tcp", address, log)
case gosocks5.CmdUdp: case gosocks5.CmdUdp:
h.handleUDP(ctx, conn, log) return h.handleUDP(ctx, conn, log)
case socks.CmdUDPTun: case socks.CmdUDPTun:
h.handleUDPTun(ctx, conn, "udp", address, log) return h.handleUDPTun(ctx, conn, "udp", address, log)
default: default:
log.Errorf("unknown cmd: %d", req.Cmd) err = ErrUnknownCmd
log.Error(err)
resp := gosocks5.NewReply(gosocks5.CmdUnsupported, nil) resp := gosocks5.NewReply(gosocks5.CmdUnsupported, nil)
resp.Write(conn) resp.Write(conn)
log.Debug(resp) log.Debug(resp)
return return err
} }
} }

View File

@ -7,12 +7,12 @@ import (
"time" "time"
"github.com/go-gost/gosocks5" "github.com/go-gost/gosocks5"
netpkg "github.com/go-gost/gost/pkg/common/net"
"github.com/go-gost/gost/pkg/common/util/mux" "github.com/go-gost/gost/pkg/common/util/mux"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
) )
func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error {
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{
"dst": fmt.Sprintf("%s/%s", address, network), "dst": fmt.Sprintf("%s/%s", address, network),
"cmd": "mbind", "cmd": "mbind",
@ -22,16 +22,15 @@ func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, networ
if !h.md.enableBind { if !h.md.enableBind {
reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) reply := gosocks5.NewReply(gosocks5.NotAllowed, nil)
reply.Write(conn)
log.Debug(reply) log.Debug(reply)
log.Error("BIND is diabled") log.Error("socks5: BIND is disabled")
return return reply.Write(conn)
} }
h.muxBindLocal(ctx, conn, network, address, log) return h.muxBindLocal(ctx, conn, network, address, log)
} }
func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error {
ln, err := net.Listen(network, address) // strict mode: if the port already in use, it will return error ln, err := net.Listen(network, address) // strict mode: if the port already in use, it will return error
if err != nil { if err != nil {
log.Error(err) log.Error(err)
@ -40,7 +39,7 @@ func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, network
log.Error(err) log.Error(err)
} }
log.Debug(reply) log.Debug(reply)
return return err
} }
socksAddr := gosocks5.Addr{} socksAddr := gosocks5.Addr{}
@ -56,7 +55,7 @@ func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, network
if err := reply.Write(conn); err != nil { if err := reply.Write(conn); err != nil {
log.Error(err) log.Error(err)
ln.Close() ln.Close()
return return err
} }
log.Debug(reply) log.Debug(reply)
@ -66,15 +65,15 @@ func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, network
log.Debugf("bind on %s OK", ln.Addr()) log.Debugf("bind on %s OK", ln.Addr())
h.serveMuxBind(ctx, conn, ln, log) return h.serveMuxBind(ctx, conn, ln, log)
} }
func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net.Listener, log logger.Logger) { func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net.Listener, log logger.Logger) error {
// Upgrade connection to multiplex stream. // Upgrade connection to multiplex stream.
session, err := mux.ClientSession(conn) session, err := mux.ClientSession(conn)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return err
} }
defer session.Close() defer session.Close()
@ -94,7 +93,7 @@ func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net.
rc, err := ln.Accept() rc, err := ln.Accept()
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return err
} }
log.Debugf("peer %s accepted", rc.RemoteAddr()) log.Debugf("peer %s accepted", rc.RemoteAddr())
@ -126,7 +125,7 @@ func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net.
t := time.Now() t := time.Now()
log.Infof("%s <-> %s", c.LocalAddr(), c.RemoteAddr()) log.Infof("%s <-> %s", c.LocalAddr(), c.RemoteAddr())
handler.Transport(sc, c) netpkg.Transport(sc, c)
log.WithFields(map[string]any{"duration": time.Since(t)}). log.WithFields(map[string]any{"duration": time.Since(t)}).
Infof("%s >-< %s", c.LocalAddr(), c.RemoteAddr()) Infof("%s >-< %s", c.LocalAddr(), c.RemoteAddr())
}(rc) }(rc)

View File

@ -2,6 +2,7 @@ package v5
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -9,22 +10,21 @@ import (
"time" "time"
"github.com/go-gost/gosocks5" "github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/common/net/relay"
"github.com/go-gost/gost/pkg/common/util/socks" "github.com/go-gost/gost/pkg/common/util/socks"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
) )
func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, log logger.Logger) { func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, log logger.Logger) error {
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{
"cmd": "udp", "cmd": "udp",
}) })
if !h.md.enableUDP { if !h.md.enableUDP {
reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) reply := gosocks5.NewReply(gosocks5.NotAllowed, nil)
reply.Write(conn)
log.Debug(reply) log.Debug(reply)
log.Error("UDP relay is diabled") log.Error("socks5: UDP relay is disabled")
return return reply.Write(conn)
} }
cc, err := net.ListenUDP("udp", nil) cc, err := net.ListenUDP("udp", nil)
@ -33,7 +33,7 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, log logger
reply := gosocks5.NewReply(gosocks5.Failure, nil) reply := gosocks5.NewReply(gosocks5.Failure, nil)
reply.Write(conn) reply.Write(conn)
log.Debug(reply) log.Debug(reply)
return return err
} }
defer cc.Close() defer cc.Close()
@ -44,7 +44,7 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, log logger
reply := gosocks5.NewReply(gosocks5.Succeeded, &saddr) reply := gosocks5.NewReply(gosocks5.Succeeded, &saddr)
if err := reply.Write(conn); err != nil { if err := reply.Write(conn); err != nil {
log.Error(err) log.Error(err)
return return err
} }
log.Debug(reply) log.Debug(reply)
@ -57,26 +57,29 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, log logger
c, err := h.router.Dial(ctx, "udp", "") // UDP association c, err := h.router.Dial(ctx, "udp", "") // UDP association
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return err
} }
defer c.Close() defer c.Close()
pc, ok := c.(net.PacketConn) pc, ok := c.(net.PacketConn)
if !ok { if !ok {
log.Errorf("wrong connection type") err := errors.New("socks5: wrong connection type")
return log.Error(err)
return err
} }
relay := handler.NewUDPRelay(socks.UDPConn(cc, h.md.udpBufferSize), pc). r := relay.NewUDPRelay(socks.UDPConn(cc, h.md.udpBufferSize), pc).
WithBypass(h.options.Bypass). WithBypass(h.options.Bypass).
WithLogger(log) WithLogger(log)
relay.SetBufferSize(h.md.udpBufferSize) r.SetBufferSize(h.md.udpBufferSize)
go relay.Run() go r.Run()
t := time.Now() t := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr()) log.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr())
io.Copy(ioutil.Discard, conn) io.Copy(ioutil.Discard, conn)
log.WithFields(map[string]any{"duration": time.Since(t)}). log.WithFields(map[string]any{"duration": time.Since(t)}).
Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr()) Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr())
return nil
} }

View File

@ -6,12 +6,12 @@ import (
"time" "time"
"github.com/go-gost/gosocks5" "github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/common/net/relay"
"github.com/go-gost/gost/pkg/common/util/socks" "github.com/go-gost/gost/pkg/common/util/socks"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
) )
func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) { func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error {
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{
"cmd": "udp-tun", "cmd": "udp-tun",
}) })
@ -25,26 +25,24 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network
// relay mode // relay mode
if !h.md.enableUDP { if !h.md.enableUDP {
reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) reply := gosocks5.NewReply(gosocks5.NotAllowed, nil)
reply.Write(conn)
log.Debug(reply) log.Debug(reply)
log.Error("UDP relay is diabled") log.Error("socks5: UDP relay is disabled")
return return reply.Write(conn)
} }
} else { } else {
// BIND mode // BIND mode
if !h.md.enableBind { if !h.md.enableBind {
reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) reply := gosocks5.NewReply(gosocks5.NotAllowed, nil)
reply.Write(conn)
log.Debug(reply) log.Debug(reply)
log.Error("BIND is diabled") log.Error("socks5: BIND is disabled")
return return reply.Write(conn)
} }
} }
pc, err := net.ListenUDP(network, bindAddr) pc, err := net.ListenUDP(network, bindAddr)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return err
} }
defer pc.Close() defer pc.Close()
@ -53,20 +51,22 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network
reply := gosocks5.NewReply(gosocks5.Succeeded, &saddr) reply := gosocks5.NewReply(gosocks5.Succeeded, &saddr)
if err := reply.Write(conn); err != nil { if err := reply.Write(conn); err != nil {
log.Error(err) log.Error(err)
return return err
} }
log.Debug(reply) log.Debug(reply)
log.Debugf("bind on %s OK", pc.LocalAddr()) log.Debugf("bind on %s OK", pc.LocalAddr())
relay := handler.NewUDPRelay(socks.UDPTunServerConn(conn), pc). r := relay.NewUDPRelay(socks.UDPTunServerConn(conn), pc).
WithBypass(h.options.Bypass). WithBypass(h.options.Bypass).
WithLogger(log) WithLogger(log)
relay.SetBufferSize(h.md.udpBufferSize) r.SetBufferSize(h.md.udpBufferSize)
t := time.Now() t := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr()) log.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
relay.Run() r.Run()
log.WithFields(map[string]any{ log.WithFields(map[string]any{
"duration": time.Since(t), "duration": time.Since(t),
}).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) }).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
return nil
} }

View File

@ -9,6 +9,7 @@ import (
"github.com/go-gost/gosocks5" "github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/chain"
netpkg "github.com/go-gost/gost/pkg/common/net"
"github.com/go-gost/gost/pkg/common/util/ss" "github.com/go-gost/gost/pkg/common/util/ss"
"github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/handler"
md "github.com/go-gost/gost/pkg/metadata" md "github.com/go-gost/gost/pkg/metadata"
@ -59,7 +60,7 @@ func (h *ssHandler) Init(md md.Metadata) (err error) {
return return
} }
func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) { func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) error {
defer conn.Close() defer conn.Close()
start := time.Now() start := time.Now()
@ -87,7 +88,7 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) {
if _, err := addr.ReadFrom(conn); err != nil { if _, err := addr.ReadFrom(conn); err != nil {
log.Error(err) log.Error(err)
io.Copy(ioutil.Discard, conn) io.Copy(ioutil.Discard, conn)
return return err
} }
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{
@ -98,19 +99,21 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) {
if h.options.Bypass != nil && h.options.Bypass.Contains(addr.String()) { if h.options.Bypass != nil && h.options.Bypass.Contains(addr.String()) {
log.Info("bypass: ", addr.String()) log.Info("bypass: ", addr.String())
return return nil
} }
cc, err := h.router.Dial(ctx, "tcp", addr.String()) cc, err := h.router.Dial(ctx, "tcp", addr.String())
if err != nil { if err != nil {
return return err
} }
defer cc.Close() defer cc.Close()
t := time.Now() t := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), addr) log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
handler.Transport(conn, cc) netpkg.Transport(conn, cc)
log.WithFields(map[string]any{ log.WithFields(map[string]any{
"duration": time.Since(t), "duration": time.Since(t),
}).Infof("%s >-< %s", conn.RemoteAddr(), addr) }).Infof("%s >-< %s", conn.RemoteAddr(), addr)
return nil
} }

View File

@ -2,6 +2,7 @@ package ss
import ( import (
"context" "context"
"errors"
"net" "net"
"time" "time"
@ -60,7 +61,7 @@ func (h *ssuHandler) Init(md md.Metadata) (err error) {
return return
} }
func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) { func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) error {
defer conn.Close() defer conn.Close()
start := time.Now() start := time.Now()
@ -95,14 +96,15 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) {
c, err := h.router.Dial(ctx, "udp", "") // UDP association c, err := h.router.Dial(ctx, "udp", "") // UDP association
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return err
} }
defer c.Close() defer c.Close()
cc, ok := c.(net.PacketConn) cc, ok := c.(net.PacketConn)
if !ok { if !ok {
log.Errorf("wrong connection type") err := errors.New("ss: wrong connection type")
return log.Error(err)
return err
} }
t := time.Now() t := time.Now()
@ -110,6 +112,8 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) {
h.relayPacket(pc, cc, log) h.relayPacket(pc, cc, log)
log.WithFields(map[string]any{"duration": time.Since(t)}). log.WithFields(map[string]any{"duration": time.Since(t)}).
Infof("%s >-< %s", conn.LocalAddr(), cc.LocalAddr()) Infof("%s >-< %s", conn.LocalAddr(), cc.LocalAddr())
return nil
} }
func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn, log logger.Logger) (err error) { func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn, log logger.Logger) (err error) {

View File

@ -3,12 +3,14 @@ package ssh
import ( import (
"context" "context"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
"time" "time"
"github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/chain"
netpkg "github.com/go-gost/gost/pkg/common/net"
"github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/handler"
sshd_util "github.com/go-gost/gost/pkg/internal/util/sshd" sshd_util "github.com/go-gost/gost/pkg/internal/util/sshd"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
@ -56,7 +58,7 @@ func (h *forwardHandler) Init(md md.Metadata) (err error) {
return nil return nil
} }
func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) error {
defer conn.Close() defer conn.Close()
log := h.options.Logger.WithFields(map[string]any{ log := h.options.Logger.WithFields(map[string]any{
@ -66,16 +68,17 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) {
switch cc := conn.(type) { switch cc := conn.(type) {
case *sshd_util.DirectForwardConn: case *sshd_util.DirectForwardConn:
h.handleDirectForward(ctx, cc, log) return h.handleDirectForward(ctx, cc, log)
case *sshd_util.RemoteForwardConn: case *sshd_util.RemoteForwardConn:
h.handleRemoteForward(ctx, cc, log) return h.handleRemoteForward(ctx, cc, log)
default: default:
log.Error("wrong connection type") err := errors.New("sshd: wrong connection type")
return log.Error(err)
return err
} }
} }
func (h *forwardHandler) handleDirectForward(ctx context.Context, conn *sshd_util.DirectForwardConn, log logger.Logger) { func (h *forwardHandler) handleDirectForward(ctx context.Context, conn *sshd_util.DirectForwardConn, log logger.Logger) error {
targetAddr := conn.DstAddr() targetAddr := conn.DstAddr()
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{
@ -87,28 +90,33 @@ func (h *forwardHandler) handleDirectForward(ctx context.Context, conn *sshd_uti
if h.options.Bypass != nil && h.options.Bypass.Contains(targetAddr) { if h.options.Bypass != nil && h.options.Bypass.Contains(targetAddr) {
log.Infof("bypass %s", targetAddr) log.Infof("bypass %s", targetAddr)
return return nil
} }
cc, err := h.router.Dial(ctx, "tcp", targetAddr) cc, err := h.router.Dial(ctx, "tcp", targetAddr)
if err != nil { if err != nil {
return return err
} }
defer cc.Close() defer cc.Close()
t := time.Now() t := time.Now()
log.Infof("%s <-> %s", cc.LocalAddr(), targetAddr) log.Infof("%s <-> %s", cc.LocalAddr(), targetAddr)
handler.Transport(conn, cc) netpkg.Transport(conn, cc)
log.WithFields(map[string]any{ log.WithFields(map[string]any{
"duration": time.Since(t), "duration": time.Since(t),
}).Infof("%s >-< %s", cc.LocalAddr(), targetAddr) }).Infof("%s >-< %s", cc.LocalAddr(), targetAddr)
return nil
} }
func (h *forwardHandler) handleRemoteForward(ctx context.Context, conn *sshd_util.RemoteForwardConn, log logger.Logger) { func (h *forwardHandler) handleRemoteForward(ctx context.Context, conn *sshd_util.RemoteForwardConn, log logger.Logger) error {
req := conn.Request() req := conn.Request()
t := tcpipForward{} t := tcpipForward{}
ssh.Unmarshal(req.Payload, &t) if err := ssh.Unmarshal(req.Payload, &t); err != nil {
log.Error(err)
return err
}
network := "tcp" network := "tcp"
addr := net.JoinHostPort(t.Host, strconv.Itoa(int(t.Port))) addr := net.JoinHostPort(t.Host, strconv.Itoa(int(t.Port)))
@ -125,7 +133,7 @@ func (h *forwardHandler) handleRemoteForward(ctx context.Context, conn *sshd_uti
if err != nil { if err != nil {
log.Error(err) log.Error(err)
req.Reply(false, nil) req.Reply(false, nil)
return return err
} }
defer ln.Close() defer ln.Close()
@ -149,7 +157,7 @@ func (h *forwardHandler) handleRemoteForward(ctx context.Context, conn *sshd_uti
}() }()
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return err
} }
sshConn := conn.Conn() sshConn := conn.Conn()
@ -191,7 +199,7 @@ func (h *forwardHandler) handleRemoteForward(ctx context.Context, conn *sshd_uti
t := time.Now() t := time.Now()
log.Infof("%s <-> %s", conn.LocalAddr(), conn.RemoteAddr()) log.Infof("%s <-> %s", conn.LocalAddr(), conn.RemoteAddr())
handler.Transport(ch, conn) netpkg.Transport(ch, conn)
log.WithFields(map[string]any{ log.WithFields(map[string]any{
"duration": time.Since(t), "duration": time.Since(t),
}).Infof("%s >-< %s", conn.LocalAddr(), conn.RemoteAddr()) }).Infof("%s >-< %s", conn.LocalAddr(), conn.RemoteAddr())
@ -205,6 +213,8 @@ func (h *forwardHandler) handleRemoteForward(ctx context.Context, conn *sshd_uti
log.WithFields(map[string]any{ log.WithFields(map[string]any{
"duration": time.Since(tm), "duration": time.Since(tm),
}).Infof("%s >-< %s", conn.RemoteAddr(), addr) }).Infof("%s >-< %s", conn.RemoteAddr(), addr)
return nil
} }
func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) { func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) {

View File

@ -76,15 +76,16 @@ func (h *tapHandler) Forward(group *chain.NodeGroup) {
h.group = group h.group = group
} }
func (h *tapHandler) Handle(ctx context.Context, conn net.Conn) { func (h *tapHandler) Handle(ctx context.Context, conn net.Conn) error {
defer os.Exit(0) defer os.Exit(0)
defer conn.Close() defer conn.Close()
log := h.options.Logger log := h.options.Logger
cc, ok := conn.(*tap_util.Conn) cc, ok := conn.(*tap_util.Conn)
if !ok || cc.Config() == nil { if !ok || cc.Config() == nil {
log.Error("invalid connection") err := errors.New("tap: wrong connection type")
return log.Error(err)
return err
} }
start := time.Now() start := time.Now()
@ -109,7 +110,7 @@ func (h *tapHandler) Handle(ctx context.Context, conn net.Conn) {
raddr, err = net.ResolveUDPAddr(network, target.Addr) raddr, err = net.ResolveUDPAddr(network, target.Addr)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return err
} }
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{
"dst": fmt.Sprintf("%s/%s", raddr.String(), raddr.Network()), "dst": fmt.Sprintf("%s/%s", raddr.String(), raddr.Network()),
@ -118,6 +119,7 @@ func (h *tapHandler) Handle(ctx context.Context, conn net.Conn) {
} }
h.handleLoop(ctx, conn, raddr, cc.Config(), log) h.handleLoop(ctx, conn, raddr, cc.Config(), log)
return nil
} }
func (h *tapHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Addr, config *tap_util.Config, log logger.Logger) { func (h *tapHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Addr, config *tap_util.Config, log logger.Logger) {

View File

@ -78,7 +78,7 @@ func (h *tunHandler) Forward(group *chain.NodeGroup) {
h.group = group h.group = group
} }
func (h *tunHandler) Handle(ctx context.Context, conn net.Conn) { func (h *tunHandler) Handle(ctx context.Context, conn net.Conn) error {
defer os.Exit(0) defer os.Exit(0)
defer conn.Close() defer conn.Close()
@ -86,8 +86,9 @@ func (h *tunHandler) Handle(ctx context.Context, conn net.Conn) {
cc, ok := conn.(*tun_util.Conn) cc, ok := conn.(*tun_util.Conn)
if !ok || cc.Config() == nil { if !ok || cc.Config() == nil {
log.Error("invalid connection") err := errors.New("tun: wrong connection type")
return log.Error(err)
return err
} }
start := time.Now() start := time.Now()
@ -112,7 +113,7 @@ func (h *tunHandler) Handle(ctx context.Context, conn net.Conn) {
raddr, err = net.ResolveUDPAddr(network, target.Addr) raddr, err = net.ResolveUDPAddr(network, target.Addr)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return err
} }
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{
"dst": fmt.Sprintf("%s/%s", raddr.String(), raddr.Network()), "dst": fmt.Sprintf("%s/%s", raddr.String(), raddr.Network()),
@ -121,6 +122,7 @@ func (h *tunHandler) Handle(ctx context.Context, conn net.Conn) {
} }
h.handleLoop(ctx, conn, raddr, cc.Config(), log) h.handleLoop(ctx, conn, raddr, cc.Config(), log)
return nil
} }
func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Addr, config *tun_util.Config, log logger.Logger) { func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Addr, config *tun_util.Config, log logger.Logger) {

View File

@ -5,16 +5,33 @@ import (
) )
var ( var (
global = newMetrics() metrics = newMetrics()
) )
type Gauge interface {
Inc()
Dec()
Add(float64)
Set(float64)
}
type Counter interface {
Inc()
Add(float64)
}
type Observer interface {
Observe(float64)
}
type Metrics struct { type Metrics struct {
services prometheus.Gauge services prometheus.Gauge
requests *prometheus.CounterVec requests *prometheus.CounterVec
requestsInFlight *prometheus.GaugeVec requestsInFlight *prometheus.GaugeVec
requestSeconds *prometheus.HistogramVec requestSeconds *prometheus.HistogramVec
requestInputBytes *prometheus.CounterVec inputBytes *prometheus.CounterVec
requestOutputBytes *prometheus.CounterVec outputBytes *prometheus.CounterVec
handlerErrors *prometheus.CounterVec
} }
func newMetrics() *Metrics { func newMetrics() *Metrics {
@ -44,20 +61,26 @@ func newMetrics() *Metrics {
Name: "gost_service_request_duration_seconds", Name: "gost_service_request_duration_seconds",
Help: "Distribution of request latencies", Help: "Distribution of request latencies",
Buckets: []float64{ Buckets: []float64{
.005, .01, .025, .05, .1, .25, .5, 1, 2.5, 5, 10, 15, 20, 30, .005, .01, .025, .05, .1, .25, .5, 1, 2.5, 5, 10, 15, 30, 60,
}, },
}, },
[]string{"service"}), []string{"service"}),
requestInputBytes: prometheus.NewCounterVec( inputBytes: prometheus.NewCounterVec(
prometheus.CounterOpts{ prometheus.CounterOpts{
Name: "gost_service_request_transfer_input_bytes_total", Name: "gost_service_transfer_input_bytes_total",
Help: "Total request input data transfer size in bytes", Help: "Total service input data transfer size in bytes",
}, },
[]string{"service"}), []string{"service"}),
requestOutputBytes: prometheus.NewCounterVec( outputBytes: prometheus.NewCounterVec(
prometheus.CounterOpts{ prometheus.CounterOpts{
Name: "gost_service_request_transfer_output_bytes_total", Name: "gost_service_transfer_output_bytes_total",
Help: "Total request output data transfer size in bytes", Help: "Total service output data transfer size in bytes",
},
[]string{"service"}),
handlerErrors: prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "gost_service_handler_errors_total",
Help: "Total service handler errors",
}, },
[]string{"service"}), []string{"service"}),
} }
@ -65,31 +88,35 @@ func newMetrics() *Metrics {
prometheus.MustRegister(m.requests) prometheus.MustRegister(m.requests)
prometheus.MustRegister(m.requestsInFlight) prometheus.MustRegister(m.requestsInFlight)
prometheus.MustRegister(m.requestSeconds) prometheus.MustRegister(m.requestSeconds)
prometheus.MustRegister(m.requestInputBytes) prometheus.MustRegister(m.inputBytes)
prometheus.MustRegister(m.requestOutputBytes) prometheus.MustRegister(m.outputBytes)
return m return m
} }
func Services() prometheus.Gauge { func Services() Gauge {
return global.services return metrics.services
} }
func Requests(service string) prometheus.Counter { func Requests(service string) Counter {
return global.requests.With(prometheus.Labels{"service": service}) return metrics.requests.With(prometheus.Labels{"service": service})
} }
func RequestsInFlight(service string) prometheus.Gauge { func RequestsInFlight(service string) Gauge {
return global.requestsInFlight.With(prometheus.Labels{"service": service}) return metrics.requestsInFlight.With(prometheus.Labels{"service": service})
} }
func RequestSeconds(service string) prometheus.Observer { func RequestSeconds(service string) Observer {
return global.requestSeconds.With(prometheus.Labels{"service": service}) return metrics.requestSeconds.With(prometheus.Labels{"service": service})
} }
func RequestInputBytes(service string) prometheus.Counter { func InputBytes(service string) Counter {
return global.requestInputBytes.With(prometheus.Labels{"service": service}) return metrics.inputBytes.With(prometheus.Labels{"service": service})
} }
func RequestOutputBytes(service string) prometheus.Counter { func OutputBytes(service string) Counter {
return global.requestOutputBytes.With(prometheus.Labels{"service": service}) return metrics.outputBytes.With(prometheus.Labels{"service": service})
}
func HandlerErrors(service string) Counter {
return metrics.handlerErrors.With(prometheus.Labels{"service": service})
} }

View File

@ -10,7 +10,6 @@ import (
"github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/listener"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
"github.com/go-gost/gost/pkg/metrics" "github.com/go-gost/gost/pkg/metrics"
"github.com/prometheus/client_golang/prometheus"
) )
type options struct { type options struct {
@ -105,11 +104,14 @@ func (s *service) Serve() error {
metrics.RequestsInFlight(s.name).Inc() metrics.RequestsInFlight(s.name).Inc()
defer metrics.RequestsInFlight(s.name).Dec() defer metrics.RequestsInFlight(s.name).Dec()
timer := prometheus.NewTimer( start := time.Now()
metrics.RequestSeconds(s.name)) defer func() {
defer timer.ObserveDuration() metrics.RequestSeconds(s.name).Observe(time.Since(start).Seconds())
}()
s.handler.Handle(context.Background(), conn) if err := s.handler.Handle(context.Background(), conn); err != nil {
metrics.HandlerErrors(s.name).Inc()
}
}() }()
} }
} }