diff --git a/cmd/gost/config.go b/cmd/gost/config.go new file mode 100644 index 0000000..f5f5190 --- /dev/null +++ b/cmd/gost/config.go @@ -0,0 +1,222 @@ +package main + +import ( + "io" + "os" + "strings" + + "github.com/go-gost/gost/pkg/bypass" + "github.com/go-gost/gost/pkg/chain" + "github.com/go-gost/gost/pkg/config" + "github.com/go-gost/gost/pkg/connector" + "github.com/go-gost/gost/pkg/dialer" + "github.com/go-gost/gost/pkg/handler" + "github.com/go-gost/gost/pkg/listener" + "github.com/go-gost/gost/pkg/logger" + "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" + "github.com/go-gost/gost/pkg/service" +) + +var ( + chains = make(map[string]*chain.Chain) + bypasses = make(map[string]bypass.Bypass) +) + +func buildService(cfg *config.Config) (services []*service.Service) { + if cfg == nil || len(cfg.Services) == 0 { + return + } + + for _, bypassCfg := range cfg.Bypasses { + bypasses[bypassCfg.Name] = bypassFromConfig(bypassCfg) + } + + for _, chainCfg := range cfg.Chains { + chains[chainCfg.Name] = chainFromConfig(chainCfg) + } + + for _, svc := range cfg.Services { + serviceLogger := log.WithFields(map[string]interface{}{ + "service": svc.Name, + }) + + listenerLogger := serviceLogger.WithFields(map[string]interface{}{ + "kind": "listener", + "type": svc.Listener.Type, + }) + ln := registry.GetListener(svc.Listener.Type)( + listener.AddrOption(svc.Addr), + listener.LoggerOption(listenerLogger), + ) + + if chainable, ok := ln.(listener.Chainable); ok { + chainable.Chain(chains[svc.Chain]) + } + + if err := ln.Init(metadata.MapMetadata(svc.Listener.Metadata)); err != nil { + listenerLogger.Fatal("init: ", err) + } + + handlerLogger := serviceLogger.WithFields(map[string]interface{}{ + "kind": "handler", + "type": svc.Handler.Type, + }) + + h := registry.GetHandler(svc.Handler.Type)( + handler.ChainOption(chains[svc.Chain]), + handler.BypassOption(bypasses[svc.Bypass]), + handler.LoggerOption(handlerLogger), + ) + + if forwarder, ok := h.(handler.Forwarder); ok { + forwarder.Forward(forwarderFromConfig(svc.Forwarder)) + } + + if err := h.Init(metadata.MapMetadata(svc.Handler.Metadata)); err != nil { + handlerLogger.Fatal("init: ", err) + } + + s := (&service.Service{}). + WithListener(ln). + WithHandler(h). + WithLogger(serviceLogger) + services = append(services, s) + + serviceLogger.Infof("listening on: %s/%s", s.Addr().String(), s.Addr().Network()) + } + + return +} + +func chainFromConfig(cfg *config.ChainConfig) *chain.Chain { + if cfg == nil { + return nil + } + + c := &chain.Chain{} + + selector := selectorFromConfig(cfg.Selector) + for _, hop := range cfg.Hops { + group := &chain.NodeGroup{} + for _, v := range hop.Nodes { + + connectorLogger := log.WithFields(map[string]interface{}{ + "kind": "connector", + "type": v.Connector.Type, + "hop": hop.Name, + "node": v.Name, + }) + cr := registry.GetConnector(v.Connector.Type)( + connector.LoggerOption(connectorLogger), + ) + if err := cr.Init(metadata.MapMetadata(v.Connector.Metadata)); err != nil { + connectorLogger.Fatal("init: ", err) + } + + dialerLogger := log.WithFields(map[string]interface{}{ + "kind": "dialer", + "type": v.Dialer.Type, + "hop": hop.Name, + "node": v.Name, + }) + d := registry.GetDialer(v.Dialer.Type)( + dialer.LoggerOption(dialerLogger), + ) + if err := d.Init(metadata.MapMetadata(v.Dialer.Metadata)); err != nil { + dialerLogger.Fatal("init: ", err) + } + + tr := (&chain.Transport{}). + WithConnector(cr). + WithDialer(d) + + node := chain.NewNode(v.Name, v.Addr). + WithTransport(tr). + WithBypass(bypasses[v.Bypass]) + group.AddNode(node) + } + + sel := selector + if s := selectorFromConfig(hop.Selector); s != nil { + sel = s + } + group.WithSelector(sel) + c.AddNodeGroup(group) + } + + return c +} + +func logFromConfig(cfg *config.LogConfig) logger.Logger { + if cfg == nil { + cfg = &config.LogConfig{} + } + opts := []logger.LoggerOption{ + logger.FormatLoggerOption(logger.LogFormat(cfg.Format)), + logger.LevelLoggerOption(logger.LogLevel(cfg.Level)), + } + + var out io.Writer = os.Stderr + switch cfg.Output { + case "stdout", "": + out = os.Stdout + case "stderr": + out = os.Stderr + default: + f, err := os.OpenFile(cfg.Output, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + if err != nil { + log.Warnf("log", err) + } else { + out = f + } + } + opts = append(opts, logger.OutputLoggerOption(out)) + + return logger.NewLogger(opts...) +} + +func selectorFromConfig(cfg *config.SelectorConfig) chain.Selector { + if cfg == nil { + return nil + } + + var strategy chain.Strategy + switch cfg.Strategy { + case "round": + strategy = chain.RoundRobinStrategy() + case "random": + strategy = chain.RandomStrategy() + case "fifo": + strategy = chain.FIFOStrategy() + default: + strategy = chain.RoundRobinStrategy() + } + + return chain.NewSelector( + strategy, + chain.InvalidFilter(), + chain.FailFilter(cfg.MaxFails, cfg.FailTimeout), + ) +} + +func bypassFromConfig(cfg *config.BypassConfig) bypass.Bypass { + if cfg == nil { + return nil + } + return bypass.NewBypassPatterns(cfg.Reverse, cfg.Matchers...) +} + +func forwarderFromConfig(cfg *config.ForwarderConfig) *chain.NodeGroup { + if cfg == nil || len(cfg.Targets) == 0 { + return nil + } + + group := &chain.NodeGroup{} + for _, target := range cfg.Targets { + if v := strings.TrimSpace(target); v != "" { + group.AddNode(chain.NewNode(target, target)) + } + } + return group.WithSelector(selectorFromConfig(cfg.Selector)) +} diff --git a/cmd/gost/out.yml b/cmd/gost/out.yml new file mode 100644 index 0000000..dc1e6df --- /dev/null +++ b/cmd/gost/out.yml @@ -0,0 +1,47 @@ +services: +- name: service-0 + url: tcp://:8080/:8081?abc=def&true=true&n=123 + addr: :8080 + chain: chain-0 + listener: + type: tcp + metadata: + abc: def + "n": "123" + "true": "true" + handler: + type: tcp + metadata: + abc: def + "n": "123" + "true": "true" + forwarder: + targets: + - :8081 +chains: +- name: chain-0 + hops: + - name: hop-0 + nodes: + - name: node-0 + url: auto://:1081?n=123t + addr: :1081 + dialer: + type: auto + metadata: + "n": 123t + connector: + type: auto + metadata: + "n": 123t + - name: hop-1 + nodes: + - name: node-0 + url: auto://:1082 + addr: :1082 + dialer: + type: auto + metadata: {} + connector: + type: auto + metadata: {} diff --git a/cmd/gost/register.go b/cmd/gost/register.go index 4ecde3c..13c6bd0 100644 --- a/cmd/gost/register.go +++ b/cmd/gost/register.go @@ -2,6 +2,7 @@ package main import ( // Register connectors + _ "github.com/go-gost/gost/pkg/connector/forward" _ "github.com/go-gost/gost/pkg/connector/http" _ "github.com/go-gost/gost/pkg/connector/socks/v4" _ "github.com/go-gost/gost/pkg/connector/socks/v5" @@ -13,7 +14,8 @@ import ( _ "github.com/go-gost/gost/pkg/dialer/udp" // Register handlers - _ "github.com/go-gost/gost/pkg/handler/forward" + _ "github.com/go-gost/gost/pkg/handler/forward/local" + _ "github.com/go-gost/gost/pkg/handler/forward/remote" _ "github.com/go-gost/gost/pkg/handler/http" _ "github.com/go-gost/gost/pkg/handler/relay" _ "github.com/go-gost/gost/pkg/handler/socks/v4" diff --git a/cmd/gost/relay.yml b/cmd/gost/relay.yml new file mode 100644 index 0000000..15d54e2 --- /dev/null +++ b/cmd/gost/relay.yml @@ -0,0 +1,115 @@ +log: + output: stderr # stderr, stdout, /path/to/file + level: debug # debug, info, warn, error, fatal + format: json # text, json + +profiling: + addr: ":6060" + enabled: true + +services: +- name: socks5 + addr: ":21080" + handler: + type: socks5 + metadata: + readTimeout: 5s + retry: 3 + udp: true + bufferSize: 4096 + listener: + type: tcp + metadata: + keepAlive: 15s +- name: ss + addr: ":28338" + handler: + type: ss + metadata: + method: chacha20-ietf + password: gost + readTimeout: 5s + retry: 3 + udp: true + bufferSize: 4096 + listener: + type: udp + metadata: + keepAlive: 15s +- name: relay-proxy + addr: ":28080" + chain: chain-socks5 + handler: + type: relay + metadata: + readTimeout: 5s + listener: + type: tcp + metadata: + keepAlive: 15s + +chains: +- name: chain-socks5 + hops: + - name: hop01 + nodes: + - name: node01 + addr: ":21080" + connector: + type: socks5 + metadata: + readTimeout: 5s + bufferSize: 4096 + notls: true + dialer: + type: tcp + metadata: {} + +- name: chain-ss + hops: + - name: hop01 + nodes: + - name: node01 + addr: ":28338" + connector: + type: ss + metadata: + method: chacha20-ietf + password: gost + readTimeout: 5s + nodelay: true + udp: true + bufferSize: 4096 + dialer: + type: udp + metadata: {} + +bypasses: +- name: bypass01 + reverse: false + matchers: + - .baidu.com + - "*.example.com" # domain wildcard + - .example.org # will match example.org and *.example.org + + # From IANA IPv4 Special-Purpose Address Registry + # http://www.iana.org/assignments/iana-ipv4-special-registry/iana-ipv4-special-registry.xhtml + - 0.0.0.0/8 # RFC1122: "This host on this network" + - 10.0.0.0/8 # RFC1918: Private-Use + - 100.64.0.0/10 # RFC6598: Shared Address Space + - 127.0.0.0/8 # RFC1122: Loopback + - 169.254.0.0/16 # RFC3927: Link Local + - 172.16.0.0/12 # RFC1918: Private-Use + - 192.0.0.0/24 # RFC6890: IETF Protocol Assignments + - 192.0.2.0/24 # RFC5737: Documentation (TEST-NET-1) + - 192.88.99.0/24 # RFC3068: 6to4 Relay Anycast + - 192.168.0.0/16 # RFC1918: Private-Use + - 198.18.0.0/15 # RFC2544: Benchmarking + - 198.51.100.0/24 # RFC5737: Documentation (TEST-NET-2) + - 203.0.113.0/24 # RFC5737: Documentation (TEST-NET-3) + - 240.0.0.0/4 # RFC1112: Reserved + - 255.255.255.255/32 # RFC0919: Limited Broadcast + + # From IANA Multicast Address Space Registry + # http://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml + - 224.0.0.0/4 # RFC5771: Multicast/Reserved \ No newline at end of file diff --git a/pkg/chain/route.go b/pkg/chain/route.go index db4cc19..70d8951 100644 --- a/pkg/chain/route.go +++ b/pkg/chain/route.go @@ -3,7 +3,10 @@ package chain import ( "context" "errors" + "fmt" "net" + + "github.com/go-gost/gost/pkg/connector" ) var ( @@ -93,6 +96,41 @@ func (r *Route) dialDirect(ctx context.Context, network, address string) (net.Co return d.DialContext(ctx, network, address) } +func (r *Route) Bind(ctx context.Context, network, address string) (connector.Accepter, error) { + if r.IsEmpty() { + return r.bindLocal(ctx, network, address) + } + + conn, err := r.Connect(ctx) + if err != nil { + return nil, err + } + + accepter, err := r.Last().transport.Bind(ctx, conn, network, address) + if err != nil { + conn.Close() + return nil, err + } + + return accepter, nil +} + +func (r *Route) bindLocal(ctx context.Context, network, address string) (connector.Accepter, error) { + switch network { + case "tcp", "tcp4", "tcp6": + addr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + return net.ListenTCP(network, addr) + case "udp", "udp4", "udp6": + return nil, nil + default: + err := fmt.Errorf("network %s unsupported", network) + return nil, err + } +} + func (r *Route) IsEmpty() bool { return r == nil || len(r.nodes) == 0 } diff --git a/pkg/chain/router.go b/pkg/chain/router.go index 49dfc55..acbff74 100644 --- a/pkg/chain/router.go +++ b/pkg/chain/router.go @@ -6,6 +6,7 @@ import ( "fmt" "net" + "github.com/go-gost/gost/pkg/connector" "github.com/go-gost/gost/pkg/logger" ) @@ -59,6 +60,35 @@ func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Co return } +func (r *Router) Bind(ctx context.Context, network, address string) (accepter connector.Accepter, err error) { + count := r.retries + 1 + if count <= 0 { + count = 1 + } + r.logger.Debugf("bind: %s/%s", address, network) + + for i := 0; i < count; i++ { + route := r.chain.GetRouteFor(network, address) + + if r.logger.IsLevelEnabled(logger.DebugLevel) { + buf := bytes.Buffer{} + for _, node := range route.Path() { + fmt.Fprintf(&buf, "%s@%s > ", node.Name(), node.Addr()) + } + fmt.Fprintf(&buf, "%s", address) + r.logger.Debugf("route(retry=%d): %s", i, buf.String()) + } + + accepter, err = route.Bind(ctx, network, address) + if err == nil { + break + } + r.logger.Errorf("route(retry=%d): %s", i, err) + } + + return +} + func (r *Router) Connect(ctx context.Context) (conn net.Conn, err error) { count := r.retries + 1 if count <= 0 { diff --git a/pkg/chain/transport.go b/pkg/chain/transport.go index 3397bdf..14b6920 100644 --- a/pkg/chain/transport.go +++ b/pkg/chain/transport.go @@ -62,6 +62,13 @@ func (tr *Transport) Connect(ctx context.Context, conn net.Conn, network, addres return tr.connector.Connect(ctx, conn, network, address) } +func (tr *Transport) Bind(ctx context.Context, conn net.Conn, network, address string) (connector.Accepter, error) { + if binder, ok := tr.connector.(connector.Binder); ok { + return binder.Bind(ctx, conn, network, address, connector.MuxBindOption(true)) + } + return nil, connector.ErrBindUnsupported +} + func (tr *Transport) IsMultiplex() bool { if mux, ok := tr.dialer.(dialer.Multiplexer); ok { return mux.IsMultiplex() diff --git a/pkg/connector/bind.go b/pkg/connector/bind.go new file mode 100644 index 0000000..4888b18 --- /dev/null +++ b/pkg/connector/bind.go @@ -0,0 +1,45 @@ +package connector + +import ( + "context" + "errors" + "net" +) + +var ( + ErrBindUnsupported = errors.New("bind unsupported") +) + +type Accepter interface { + Accept() (net.Conn, error) + Addr() net.Addr + Close() error +} + +type Binder interface { + Bind(ctx context.Context, conn net.Conn, network, address string, opts ...BindOption) (Accepter, error) +} + +type AcceptError struct { + err error +} + +func NewAcceptError(err error) error { + return &AcceptError{err: err} +} + +func (e *AcceptError) Error() string { + return e.err.Error() +} + +func (e *AcceptError) Timeout() bool { + return false +} + +func (e *AcceptError) Temporary() bool { + return true +} + +func (e *AcceptError) Unwrap() error { + return e.err +} diff --git a/pkg/connector/forward/connector.go b/pkg/connector/forward/connector.go new file mode 100644 index 0000000..6b2f3b5 --- /dev/null +++ b/pkg/connector/forward/connector.go @@ -0,0 +1,46 @@ +package forward + +import ( + "context" + "net" + + "github.com/go-gost/gost/pkg/connector" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" +) + +func init() { + registry.RegiserConnector("forward", NewConnector) +} + +type forwardConnector struct { + logger logger.Logger +} + +func NewConnector(opts ...connector.Option) connector.Connector { + options := &connector.Options{} + for _, opt := range opts { + opt(options) + } + + return &forwardConnector{ + logger: options.Logger, + } +} + +func (c *forwardConnector) Init(md md.Metadata) (err error) { + return nil +} + +func (c *forwardConnector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { + c.logger = c.logger.WithFields(map[string]interface{}{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + "network": network, + "address": address, + }) + c.logger.Infof("connect: %s/%s", address, network) + + return conn, nil +} diff --git a/pkg/connector/option.go b/pkg/connector/option.go index ca2889c..88f634e 100644 --- a/pkg/connector/option.go +++ b/pkg/connector/option.go @@ -20,3 +20,15 @@ type ConnectOptions struct { } type ConnectOption func(opts *ConnectOptions) + +type BindOptions struct { + Mux bool +} + +type BindOption func(opts *BindOptions) + +func MuxBindOption(mux bool) BindOption { + return func(opts *BindOptions) { + opts.Mux = mux + } +} diff --git a/pkg/connector/relay/conn.go b/pkg/connector/relay/conn.go new file mode 100644 index 0000000..6388171 --- /dev/null +++ b/pkg/connector/relay/conn.go @@ -0,0 +1,117 @@ +package relay + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "sync" + + "github.com/go-gost/gost/pkg/logger" + "github.com/go-gost/relay" +) + +type conn struct { + net.Conn + udp bool + wbuf bytes.Buffer + once sync.Once + headerSent bool + logger logger.Logger +} + +func (c *conn) Read(b []byte) (n int, err error) { + c.once.Do(func() { + resp := relay.Response{} + _, err = resp.ReadFrom(c.Conn) + if err != nil { + return + } + if resp.Version != relay.Version1 { + err = relay.ErrBadVersion + return + } + if resp.Status != relay.StatusOK { + err = fmt.Errorf("status %d", resp.Status) + return + } + }) + + if err != nil { + return + } + + if !c.udp { + return c.Conn.Read(b) + } + + var bb [2]byte + _, err = io.ReadFull(c.Conn, bb[:]) + if err != nil { + return + } + dlen := int(binary.BigEndian.Uint16(bb[:])) + if len(b) >= dlen { + return io.ReadFull(c.Conn, b[:dlen]) + } + buf := make([]byte, dlen) + _, err = io.ReadFull(c.Conn, buf) + n = copy(b, buf) + return +} + +func (c *conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + n, err = c.Read(b) + addr = c.Conn.RemoteAddr() + return +} + +func (c *conn) Write(b []byte) (n int, err error) { + if len(b) > 0xFFFF { + err = errors.New("write: data maximum exceeded") + return + } + n = len(b) // force byte length consistent + if c.wbuf.Len() > 0 { + if c.udp { + var bb [2]byte + binary.BigEndian.PutUint16(bb[:2], uint16(len(b))) + c.wbuf.Write(bb[:]) + c.headerSent = true + } + c.wbuf.Write(b) // append the data to the cached header + // _, err = c.Conn.Write(c.wbuf.Bytes()) + // c.wbuf.Reset() + _, err = c.wbuf.WriteTo(c.Conn) + return + } + + if !c.udp { + return c.Conn.Write(b) + } + if !c.headerSent { + c.headerSent = true + b2 := make([]byte, len(b)+2) + copy(b2, b) + _, err = c.Conn.Write(b2) + return + } + nsize := 2 + len(b) + var buf []byte + if nsize <= mediumBufferSize { + buf = mPool.Get().([]byte) + defer mPool.Put(buf) + } else { + buf = make([]byte, nsize) + } + binary.BigEndian.PutUint16(buf[:2], uint16(len(b))) + n = copy(buf[2:], b) + _, err = c.Conn.Write(buf[:nsize]) + return +} + +func (c *relayConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + return c.Write(b) +} diff --git a/pkg/connector/relay/connector.go b/pkg/connector/relay/connector.go new file mode 100644 index 0000000..bbdeae2 --- /dev/null +++ b/pkg/connector/relay/connector.go @@ -0,0 +1,101 @@ +package relay + +import ( + "context" + "net" + "strconv" + "time" + + "github.com/go-gost/gost/pkg/connector" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" + "github.com/go-gost/relay" +) + +func init() { + registry.RegiserConnector("relay", NewConnector) +} + +type relayConnector struct { + logger logger.Logger + md metadata +} + +func NewConnector(opts ...connector.Option) connector.Connector { + options := &connector.Options{} + for _, opt := range opts { + opt(options) + } + + return &relayConnector{ + logger: options.Logger, + } +} + +func (c *relayConnector) Init(md md.Metadata) (err error) { + return c.parseMetadata(md) +} + +func (c *relayConnector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { + c.logger = c.logger.WithFields(map[string]interface{}{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + "network": network, + "address": address, + }) + c.logger.Infof("connect: %s/%s", address, network) + + if c.md.connectTimeout > 0 { + conn.SetDeadline(time.Now().Add(c.md.connectTimeout)) + defer conn.SetDeadline(time.Time{}) + } + + var udpMode bool + if network == "udp" || network == "udp4" || network == "udp6" { + udpMode = true + } + + req := relay.Request{ + Version: relay.Version1, + } + if udpMode { + req.Flags |= relay.FUDP + } + + if c.md.user != nil { + pwd, _ := c.md.user.Password() + req.Features = append(req.Features, &relay.UserAuthFeature{ + Username: c.md.user.Username(), + Password: pwd, + }) + } + + if address != "" { + host, port, _ := net.SplitHostPort(address) + nport, _ := strconv.ParseUint(port, 10, 16) + if host == "" { + host = net.IPv4zero.String() + } + + if nport > 0 { + var atype uint8 + ip := net.ParseIP(host) + if ip == nil { + atype = relay.AddrDomain + } else if ip.To4() == nil { + atype = relay.AddrIPv6 + } else { + atype = relay.AddrIPv4 + } + + req.Features = append(req.Features, &relay.TargetAddrFeature{ + AType: atype, + Host: host, + Port: uint16(nport), + }) + } + } + + return conn, nil +} diff --git a/pkg/connector/relay/metadata.go b/pkg/connector/relay/metadata.go new file mode 100644 index 0000000..0fecfee --- /dev/null +++ b/pkg/connector/relay/metadata.go @@ -0,0 +1,36 @@ +package relay + +import ( + "net/url" + "strings" + "time" + + md "github.com/go-gost/gost/pkg/metadata" +) + +type metadata struct { + connectTimeout time.Duration + user *url.Userinfo + nodelay bool +} + +func (c *relayConnector) parseMetadata(md md.Metadata) (err error) { + const ( + auth = "auth" + connectTimeout = "connectTimeout" + nodelay = "nodelay" + ) + + if v := md.GetString(auth); v != "" { + ss := strings.SplitN(v, ":", 2) + if len(ss) == 1 { + c.md.user = url.User(ss[0]) + } else { + c.md.user = url.UserPassword(ss[0], ss[1]) + } + } + c.md.connectTimeout = md.GetDuration(connectTimeout) + c.md.nodelay = md.GetBool(nodelay) + + return +} diff --git a/pkg/connector/socks/v5/accepter.go b/pkg/connector/socks/v5/accepter.go new file mode 100644 index 0000000..7a23de5 --- /dev/null +++ b/pkg/connector/socks/v5/accepter.go @@ -0,0 +1,191 @@ +package v5 + +import ( + "fmt" + "io" + "net" + + "github.com/go-gost/gosocks5" + "github.com/go-gost/gost/pkg/common/bufpool" + "github.com/go-gost/gost/pkg/common/util/mux" + "github.com/go-gost/gost/pkg/common/util/udp" + "github.com/go-gost/gost/pkg/logger" +) + +type tcpAccepter struct { + addr net.Addr + conn net.Conn + logger logger.Logger + done chan struct{} +} + +func (p *tcpAccepter) Accept() (net.Conn, error) { + select { + case <-p.done: + return nil, io.EOF + default: + close(p.done) + } + + // second reply, peer connected + rep, err := gosocks5.ReadReply(p.conn) + if err != nil { + return nil, err + } + p.logger.Debug(rep) + + if rep.Rep != gosocks5.Succeeded { + return nil, fmt.Errorf("peer connect failed") + } + + raddr, err := net.ResolveTCPAddr("tcp", rep.Addr.String()) + if err != nil { + return nil, err + } + + return &bindConn{ + Conn: p.conn, + localAddr: p.addr, + remoteAddr: raddr, + }, nil +} + +func (p *tcpAccepter) Addr() net.Addr { + return p.addr +} + +func (p *tcpAccepter) Close() error { + return p.conn.Close() +} + +type tcpMuxAccepter struct { + addr net.Addr + session *mux.Session + logger logger.Logger +} + +func (p *tcpMuxAccepter) Accept() (net.Conn, error) { + cc, err := p.session.Accept() + if err != nil { + return nil, err + } + + conn, err := p.getPeerConn(cc) + if err != nil { + cc.Close() + return nil, err + } + + return conn, nil +} + +func (p *tcpMuxAccepter) getPeerConn(conn net.Conn) (net.Conn, error) { + // second reply, peer connected + rep, err := gosocks5.ReadReply(conn) + if err != nil { + return nil, err + } + p.logger.Debug(rep) + + if rep.Rep != gosocks5.Succeeded { + err = fmt.Errorf("peer connect failed") + return nil, err + } + + raddr, err := net.ResolveTCPAddr("tcp", rep.Addr.String()) + if err != nil { + return nil, err + } + + return &bindConn{ + Conn: conn, + localAddr: p.addr, + remoteAddr: raddr, + }, nil +} + +func (p *tcpMuxAccepter) Addr() net.Addr { + return p.addr +} + +func (p *tcpMuxAccepter) Close() error { + return p.session.Close() +} + +type udpAccepter struct { + addr net.Addr + conn net.PacketConn + cqueue chan net.Conn + connPool *udp.ConnPool + readQueueSize int + readBufferSize int + closed chan struct{} + logger logger.Logger +} + +func (p *udpAccepter) Accept() (conn net.Conn, err error) { + select { + case conn = <-p.cqueue: + return + case <-p.closed: + return nil, net.ErrClosed + } +} + +func (p *udpAccepter) acceptLoop() { + for { + select { + case <-p.closed: + return + default: + } + + b := bufpool.Get(p.readBufferSize) + + n, raddr, err := p.conn.ReadFrom(b) + if err != nil { + return + } + + c := p.getConn(raddr) + if c == nil { + bufpool.Put(b) + continue + } + + if err := c.WriteQueue(b[:n]); err != nil { + p.logger.Warn("data discarded: ", err) + } + } +} + +func (p *udpAccepter) Addr() net.Addr { + return p.addr +} + +func (p *udpAccepter) Close() error { + select { + case <-p.closed: + default: + close(p.closed) + p.connPool.Close() + } + + return nil +} + +func (p *udpAccepter) getConn(raddr net.Addr) *udp.Conn { + c, ok := p.connPool.Get(raddr.String()) + if !ok { + c = udp.NewConn(p.conn, p.addr, raddr, p.readQueueSize) + select { + case p.cqueue <- c: + p.connPool.Set(raddr.String(), c) + default: + c.Close() + p.logger.Warnf("connection queue is full, client %s discarded", raddr) + return nil + } + } + return c +} diff --git a/pkg/connector/socks/v5/bind.go b/pkg/connector/socks/v5/bind.go new file mode 100644 index 0000000..6d945f8 --- /dev/null +++ b/pkg/connector/socks/v5/bind.go @@ -0,0 +1,137 @@ +package v5 + +import ( + "context" + "fmt" + "net" + + "github.com/go-gost/gosocks5" + "github.com/go-gost/gost/pkg/common/util/mux" + "github.com/go-gost/gost/pkg/common/util/socks" + "github.com/go-gost/gost/pkg/common/util/udp" + "github.com/go-gost/gost/pkg/connector" +) + +// Bind implements connector.Binder. +func (c *socks5Connector) Bind(ctx context.Context, conn net.Conn, network, address string, opts ...connector.BindOption) (connector.Accepter, error) { + c.logger = c.logger.WithFields(map[string]interface{}{ + "network": network, + "address": address, + }) + c.logger.Infof("bind: %s/%s", address, network) + + options := connector.BindOptions{} + for _, opt := range opts { + opt(&options) + } + + switch network { + case "tcp", "tcp4", "tcp6": + if options.Mux { + return c.muxBindTCP(ctx, conn, network, address) + } + return c.bindTCP(ctx, conn, network, address) + case "udp", "udp4", "udp6": + return c.bindUDP(ctx, conn, network, address) + default: + err := fmt.Errorf("network %s is unsupported", network) + c.logger.Error(err) + return nil, err + } +} + +func (c *socks5Connector) bindTCP(ctx context.Context, conn net.Conn, network, address string) (connector.Accepter, error) { + laddr, err := c.bind(conn, gosocks5.CmdBind, network, address) + if err != nil { + return nil, err + } + + return &tcpAccepter{ + addr: laddr, + conn: conn, + logger: c.logger, + done: make(chan struct{}), + }, nil +} + +func (c *socks5Connector) muxBindTCP(ctx context.Context, conn net.Conn, network, address string) (connector.Accepter, error) { + laddr, err := c.bind(conn, socks.CmdMuxBind, network, address) + if err != nil { + return nil, err + } + + session, err := mux.ServerSession(conn) + if err != nil { + return nil, err + } + + return &tcpMuxAccepter{ + addr: laddr, + session: session, + logger: c.logger, + }, nil +} + +func (c *socks5Connector) bindUDP(ctx context.Context, conn net.Conn, network, address string) (connector.Accepter, error) { + laddr, err := c.bind(conn, socks.CmdUDPTun, network, address) + if err != nil { + return nil, err + } + + accepter := &udpAccepter{ + addr: laddr, + conn: socks.UDPTunClientPacketConn(conn), + cqueue: make(chan net.Conn, c.md.backlog), + connPool: udp.NewConnPool(c.md.ttl).WithLogger(c.logger), + readQueueSize: c.md.readQueueSize, + readBufferSize: c.md.readBufferSize, + closed: make(chan struct{}), + logger: c.logger, + } + go accepter.acceptLoop() + + return accepter, nil +} + +func (l *socks5Connector) bind(conn net.Conn, cmd uint8, network, address string) (net.Addr, error) { + laddr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + + addr := gosocks5.Addr{} + addr.ParseFrom(laddr.String()) + req := gosocks5.NewRequest(cmd, &addr) + if err := req.Write(conn); err != nil { + return nil, err + } + l.logger.Debug(req) + + // first reply, bind status + reply, err := gosocks5.ReadReply(conn) + if err != nil { + return nil, err + } + + l.logger.Debug(reply) + + if reply.Rep != gosocks5.Succeeded { + return nil, fmt.Errorf("bind on %s/%s failed", laddr, laddr.Network()) + } + + var baddr net.Addr + switch network { + case "tcp", "tcp4", "tcp6": + baddr, err = net.ResolveTCPAddr(network, reply.Addr.String()) + case "udp", "udp4", "udp6": + baddr, err = net.ResolveUDPAddr(network, reply.Addr.String()) + default: + err = fmt.Errorf("unknown network %s", network) + } + if err != nil { + return nil, err + } + l.logger.Debugf("bind on %s/%s OK", baddr, baddr.Network()) + + return laddr, nil +} diff --git a/pkg/connector/socks/v5/conn.go b/pkg/connector/socks/v5/conn.go new file mode 100644 index 0000000..d11f3b5 --- /dev/null +++ b/pkg/connector/socks/v5/conn.go @@ -0,0 +1,17 @@ +package v5 + +import "net" + +type bindConn struct { + net.Conn + localAddr net.Addr + remoteAddr net.Addr +} + +func (c *bindConn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *bindConn) RemoteAddr() net.Addr { + return c.remoteAddr +} diff --git a/pkg/connector/socks/v5/metadata.go b/pkg/connector/socks/v5/metadata.go index 336130f..e2919b3 100644 --- a/pkg/connector/socks/v5/metadata.go +++ b/pkg/connector/socks/v5/metadata.go @@ -9,11 +9,23 @@ import ( md "github.com/go-gost/gost/pkg/metadata" ) +const ( + defaultTTL = 60 * time.Second + defaultReadBufferSize = 4096 + defaultReadQueueSize = 128 + defaultBacklog = 128 +) + type metadata struct { connectTimeout time.Duration User *url.Userinfo tlsConfig *tls.Config noTLS bool + + ttl time.Duration + readBufferSize int + readQueueSize int + backlog int } func (c *socks5Connector) parseMetadata(md md.Metadata) (err error) { @@ -21,6 +33,11 @@ func (c *socks5Connector) parseMetadata(md md.Metadata) (err error) { connectTimeout = "timeout" auth = "auth" noTLS = "notls" + + ttl = "ttl" + readBufferSize = "readBufferSize" + readQueueSize = "readQueueSize" + backlog = "backlog" ) if v := md.GetString(auth); v != "" { @@ -35,5 +52,23 @@ func (c *socks5Connector) parseMetadata(md md.Metadata) (err error) { c.md.connectTimeout = md.GetDuration(connectTimeout) c.md.noTLS = md.GetBool(noTLS) + c.md.ttl = md.GetDuration(ttl) + if c.md.ttl <= 0 { + c.md.ttl = defaultTTL + } + c.md.readBufferSize = md.GetInt(readBufferSize) + if c.md.readBufferSize <= 0 { + c.md.readBufferSize = defaultReadBufferSize + } + + c.md.readQueueSize = md.GetInt(readQueueSize) + if c.md.readQueueSize <= 0 { + c.md.readQueueSize = defaultReadQueueSize + } + + c.md.backlog = md.GetInt(backlog) + if c.md.backlog <= 0 { + c.md.backlog = defaultBacklog + } return } diff --git a/pkg/handler/forward/handler.go b/pkg/handler/forward/local/handler.go similarity index 93% rename from pkg/handler/forward/handler.go rename to pkg/handler/forward/local/handler.go index 4f4bc52..614095f 100644 --- a/pkg/handler/forward/handler.go +++ b/pkg/handler/forward/local/handler.go @@ -14,7 +14,8 @@ import ( ) func init() { - registry.RegisterHandler("forward", NewHandler) + registry.RegisterHandler("tcp", NewHandler) + registry.RegisterHandler("udp", NewHandler) } type forwardHandler struct { @@ -43,9 +44,8 @@ func (h *forwardHandler) Init(md md.Metadata) (err error) { } // Forward implements handler.Forwarder. -func (h *forwardHandler) Forward(group *chain.NodeGroup, chain *chain.Chain) { +func (h *forwardHandler) Forward(group *chain.NodeGroup) { h.group = group - h.chain = chain } func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { diff --git a/pkg/handler/forward/metadata.go b/pkg/handler/forward/local/metadata.go similarity index 100% rename from pkg/handler/forward/metadata.go rename to pkg/handler/forward/local/metadata.go diff --git a/pkg/handler/forward/remote/handler.go b/pkg/handler/forward/remote/handler.go new file mode 100644 index 0000000..f8a07b1 --- /dev/null +++ b/pkg/handler/forward/remote/handler.go @@ -0,0 +1,106 @@ +package forward + +import ( + "context" + "net" + "time" + + "github.com/go-gost/gost/pkg/bypass" + "github.com/go-gost/gost/pkg/chain" + "github.com/go-gost/gost/pkg/handler" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" +) + +func init() { + registry.RegisterHandler("rtcp", NewHandler) + registry.RegisterHandler("rudp", NewHandler) +} + +type forwardHandler struct { + group *chain.NodeGroup + bypass bypass.Bypass + logger logger.Logger + md metadata +} + +func NewHandler(opts ...handler.Option) handler.Handler { + options := &handler.Options{} + for _, opt := range opts { + opt(options) + } + + return &forwardHandler{ + bypass: options.Bypass, + logger: options.Logger, + } +} + +func (h *forwardHandler) Init(md md.Metadata) (err error) { + return h.parseMetadata(md) +} + +// Forward implements handler.Forwarder. +func (h *forwardHandler) Forward(group *chain.NodeGroup) { + h.group = group +} + +func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { + defer conn.Close() + + start := time.Now() + h.logger = h.logger.WithFields(map[string]interface{}{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }) + + h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + defer func() { + h.logger.WithFields(map[string]interface{}{ + "duration": time.Since(start), + }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) + }() + + target := h.group.Next() + if target == nil { + h.logger.Error("no target available") + return + } + + h.logger = h.logger.WithFields(map[string]interface{}{ + "dst": target.Addr(), + }) + + h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) + + // without chain + r := (&chain.Router{}). + WithRetry(h.md.retryCount). + WithLogger(h.logger) + + network := "tcp" + if _, ok := conn.(net.PacketConn); ok { + network = "udp" + } + + cc, err := r.Dial(ctx, network, target.Addr()) + if err != nil { + h.logger.Error(err) + // TODO: the router itself may be failed due to the failed node in the router, + // the dead marker may be a wrong operation. + target.Marker().Mark() + return + } + defer cc.Close() + target.Marker().Reset() + + t := time.Now() + h.logger.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr()) + handler.Transport(conn, cc) + h.logger. + WithFields(map[string]interface{}{ + "duration": time.Since(t), + }). + Infof("%s >-< %s", conn.RemoteAddr(), target.Addr()) +} diff --git a/pkg/handler/forward/remote/metadata.go b/pkg/handler/forward/remote/metadata.go new file mode 100644 index 0000000..9bf7df0 --- /dev/null +++ b/pkg/handler/forward/remote/metadata.go @@ -0,0 +1,23 @@ +package forward + +import ( + "time" + + md "github.com/go-gost/gost/pkg/metadata" +) + +type metadata struct { + readTimeout time.Duration + retryCount int +} + +func (h *forwardHandler) parseMetadata(md md.Metadata) (err error) { + const ( + readTimeout = "readTimeout" + retryCount = "retry" + ) + + h.md.readTimeout = md.GetDuration(readTimeout) + h.md.retryCount = md.GetInt(retryCount) + return +} diff --git a/pkg/handler/handler.go b/pkg/handler/handler.go index 80c4d10..1016d91 100644 --- a/pkg/handler/handler.go +++ b/pkg/handler/handler.go @@ -14,5 +14,5 @@ type Handler interface { } type Forwarder interface { - Forward(*chain.NodeGroup, *chain.Chain) + Forward(*chain.NodeGroup) } diff --git a/pkg/handler/relay/handler.go b/pkg/handler/relay/handler.go index 7f44508..9bc4af4 100644 --- a/pkg/handler/relay/handler.go +++ b/pkg/handler/relay/handler.go @@ -45,9 +45,8 @@ func (h *relayHandler) Init(md md.Metadata) (err error) { } // Forward implements handler.Forwarder. -func (h *relayHandler) Forward(group *chain.NodeGroup, chain *chain.Chain) { +func (h *relayHandler) Forward(group *chain.NodeGroup) { h.group = group - h.chain = chain } func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) { @@ -134,7 +133,7 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) { if target == "" { resp.Status = relay.StatusBadRequest resp.WriteTo(conn) - h.logger.Error("bad request") + h.logger.Error("target not specified") return } diff --git a/pkg/listener/rtcp/conn.go b/pkg/listener/rtcp/conn.go deleted file mode 100644 index 1bc27c4..0000000 --- a/pkg/listener/rtcp/conn.go +++ /dev/null @@ -1,17 +0,0 @@ -package rtcp - -import "net" - -type peerConn struct { - net.Conn - localAddr net.Addr - remoteAddr net.Addr -} - -func (c *peerConn) LocalAddr() net.Addr { - return c.localAddr -} - -func (c *peerConn) RemoteAddr() net.Addr { - return c.remoteAddr -} diff --git a/pkg/listener/rtcp/listener.go b/pkg/listener/rtcp/listener.go index 2e030f3..ec31849 100644 --- a/pkg/listener/rtcp/listener.go +++ b/pkg/listener/rtcp/listener.go @@ -2,14 +2,10 @@ package rtcp import ( "context" - "fmt" "net" - "sync" - "time" - "github.com/go-gost/gosocks5" "github.com/go-gost/gost/pkg/chain" - "github.com/go-gost/gost/pkg/common/util/mux" + "github.com/go-gost/gost/pkg/connector" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" @@ -21,16 +17,13 @@ func init() { } type rtcpListener struct { - addr string - laddr net.Addr - chain *chain.Chain - md metadata - ln net.Listener - cqueue chan net.Conn - session *mux.Session - sessionMux sync.Mutex - logger logger.Logger - closed chan struct{} + addr string + laddr net.Addr + chain *chain.Chain + accepter connector.Accepter + md metadata + logger logger.Logger + closed chan struct{} } func NewListener(opts ...listener.Option) listener.Listener { @@ -61,14 +54,6 @@ func (l *rtcpListener) Init(md md.Metadata) (err error) { } l.laddr = laddr - l.cqueue = make(chan net.Conn, l.md.backlog) - - if l.chain.IsEmpty() { - l.ln, err = net.ListenTCP("tcp", laddr) - return err - } - - go l.listenLoop() return } @@ -78,10 +63,6 @@ func (l *rtcpListener) Addr() net.Addr { } func (l *rtcpListener) Close() error { - if l.ln != nil { - return l.ln.Close() - } - select { case <-l.closed: default: @@ -92,120 +73,21 @@ func (l *rtcpListener) Close() error { } func (l *rtcpListener) Accept() (conn net.Conn, err error) { - if l.ln != nil { - return l.ln.Accept() + if l.accepter == nil { + r := (&chain.Router{}). + WithChain(l.chain). + WithRetry(l.md.retryCount). + WithLogger(l.logger) + l.accepter, err = r.Bind(context.Background(), "tcp", l.laddr.String()) + if err != nil { + return nil, connector.NewAcceptError(err) + } } - - select { - case conn = <-l.cqueue: - case <-l.closed: - err = net.ErrClosed + conn, err = l.accepter.Accept() + if err != nil { + l.accepter.Close() + l.accepter = nil + return nil, connector.NewAcceptError(err) } - return } - -func (l *rtcpListener) listenLoop() { - var tempDelay time.Duration - - for { - select { - case <-l.closed: - return - default: - } - - conn, err := l.accept() - - if err != nil { - if tempDelay == 0 { - tempDelay = 1000 * time.Millisecond - } else { - tempDelay *= 2 - } - if max := 6 * time.Second; tempDelay > max { - tempDelay = max - } - l.logger.Warnf("accept: %v, retrying in %v", err, tempDelay) - time.Sleep(tempDelay) - continue - } - - tempDelay = 0 - - select { - case l.cqueue <- conn: - default: - conn.Close() - l.logger.Warnf("connection queue is full, client %s discarded", conn.RemoteAddr().String()) - } - } -} - -func (l *rtcpListener) accept() (net.Conn, error) { - if l.md.enableMux { - return l.muxAccept() - } - - r := (&chain.Router{}). - WithChain(l.chain). - WithRetry(l.md.retryCount). - WithLogger(l.logger) - cc, err := r.Connect(context.Background()) - if err != nil { - return nil, err - } - - conn, err := l.waitPeer(cc) - if err != nil { - l.logger.Error(err) - cc.Close() - return nil, err - } - - l.logger.Debugf("peer %s accepted", conn.RemoteAddr()) - - return conn, nil -} - -func (l *rtcpListener) waitPeer(conn net.Conn) (net.Conn, error) { - addr := gosocks5.Addr{} - addr.ParseFrom(l.addr) - req := gosocks5.NewRequest(gosocks5.CmdBind, &addr) - if err := req.Write(conn); err != nil { - return nil, err - } - - // first reply, bind status - rep, err := gosocks5.ReadReply(conn) - if err != nil { - return nil, err - } - - l.logger.Debug(rep) - - if rep.Rep != gosocks5.Succeeded { - return nil, fmt.Errorf("bind on %s failed", l.addr) - } - l.logger.Debugf("bind on %s OK", rep.Addr) - - // second reply, peer connected - rep, err = gosocks5.ReadReply(conn) - if err != nil { - return nil, err - } - if rep.Rep != gosocks5.Succeeded { - return nil, fmt.Errorf("peer connect failed") - } - - raddr, err := net.ResolveTCPAddr("tcp", rep.Addr.String()) - if err != nil { - return nil, err - } - - return &peerConn{ - Conn: conn, - localAddr: l.laddr, - remoteAddr: raddr, - }, nil -} diff --git a/pkg/listener/rtcp/mux.go b/pkg/listener/rtcp/mux.go deleted file mode 100644 index 8040b15..0000000 --- a/pkg/listener/rtcp/mux.go +++ /dev/null @@ -1,109 +0,0 @@ -package rtcp - -import ( - "context" - "fmt" - "net" - - "github.com/go-gost/gosocks5" - "github.com/go-gost/gost/pkg/chain" - "github.com/go-gost/gost/pkg/common/util/mux" - "github.com/go-gost/gost/pkg/common/util/socks" -) - -func (l *rtcpListener) muxAccept() (net.Conn, error) { - session, err := l.getSession() - if err != nil { - l.logger.Error(err) - return nil, err - } - - cc, err := session.Accept() - if err != nil { - session.Close() - return nil, err - } - - conn, err := l.getPeerConn(cc) - if err != nil { - l.logger.Error(err) - cc.Close() - return nil, err - } - - l.logger.Debugf("peer %s accepted", conn.RemoteAddr()) - - return conn, nil -} - -func (l *rtcpListener) getPeerConn(conn net.Conn) (net.Conn, error) { - // second reply, peer connected - rep, err := gosocks5.ReadReply(conn) - if err != nil { - return nil, err - } - if rep.Rep != gosocks5.Succeeded { - err = fmt.Errorf("peer connect failed") - return nil, err - } - - raddr, err := net.ResolveTCPAddr("tcp", rep.Addr.String()) - if err != nil { - return nil, err - } - - return &peerConn{ - Conn: conn, - localAddr: l.laddr, - remoteAddr: raddr, - }, nil -} - -func (l *rtcpListener) getSession() (s *mux.Session, err error) { - l.sessionMux.Lock() - defer l.sessionMux.Unlock() - - if l.session != nil && !l.session.IsClosed() { - return l.session, nil - } - - r := (&chain.Router{}). - WithChain(l.chain). - WithRetry(l.md.retryCount). - WithLogger(l.logger) - conn, err := r.Connect(context.Background()) - if err != nil { - return nil, err - } - - l.session, err = l.initSession(conn) - if err != nil { - conn.Close() - return - } - - return l.session, nil -} - -func (l *rtcpListener) initSession(conn net.Conn) (*mux.Session, error) { - addr := gosocks5.Addr{} - addr.ParseFrom(l.addr) - req := gosocks5.NewRequest(socks.CmdMuxBind, &addr) - if err := req.Write(conn); err != nil { - return nil, err - } - - // first reply, bind status - rep, err := gosocks5.ReadReply(conn) - if err != nil { - return nil, err - } - - if rep.Rep != gosocks5.Succeeded { - err = fmt.Errorf("bind on %s failed", l.addr) - return nil, err - } - l.logger.Debugf("bind on %s OK", rep.Addr) - - return mux.ServerSession(conn) -} diff --git a/pkg/listener/rudp/metadata.go b/pkg/listener/rudp/metadata.go index 6986813..986b051 100644 --- a/pkg/listener/rudp/metadata.go +++ b/pkg/listener/rudp/metadata.go @@ -14,8 +14,7 @@ const ( ) type metadata struct { - ttl time.Duration - + ttl time.Duration readBufferSize int readQueueSize int backlog int diff --git a/pkg/metadata/metadata.go b/pkg/metadata/metadata.go index 698f36f..300abae 100644 --- a/pkg/metadata/metadata.go +++ b/pkg/metadata/metadata.go @@ -1,8 +1,12 @@ package metadata -import "time" +import ( + "strconv" + "time" +) type Metadata interface { + IsExists(key string) bool Get(key string) interface{} GetBool(key string) bool GetInt(key string) int @@ -13,6 +17,11 @@ type Metadata interface { type MapMetadata map[string]interface{} +func (m MapMetadata) IsExists(key string) bool { + _, ok := m[key] + return ok +} + func (m MapMetadata) Get(key string) interface{} { if m != nil { return m[key] @@ -21,22 +30,43 @@ func (m MapMetadata) Get(key string) interface{} { } func (m MapMetadata) GetBool(key string) (v bool) { - if m != nil { - v, _ = m[key].(bool) + if m == nil || !m.IsExists(key) { + return + } + switch vv := m[key].(type) { + case bool: + return vv + case int: + return vv != 0 + case string: + v, _ = strconv.ParseBool(vv) + return } return } func (m MapMetadata) GetInt(key string) (v int) { - if m != nil { - v, _ = m[key].(int) + switch vv := m[key].(type) { + case bool: + if vv { + v = 1 + } + case int: + return vv + case string: + v, _ = strconv.Atoi(vv) + return } return } func (m MapMetadata) GetFloat(key string) (v float64) { - if m != nil { - v, _ = m[key].(float64) + switch vv := m[key].(type) { + case int: + return float64(vv) + case string: + v, _ = strconv.ParseFloat(vv, 64) + return } return } diff --git a/pkg/service/service.go b/pkg/service/service.go index 2e0ef00..76c72e0 100644 --- a/pkg/service/service.go +++ b/pkg/service/service.go @@ -50,7 +50,7 @@ func (s *Service) serve() error { if e != nil { if ne, ok := e.(net.Error); ok && ne.Temporary() { if tempDelay == 0 { - tempDelay = 5 * time.Millisecond + tempDelay = 100 * time.Millisecond } else { tempDelay *= 2 } @@ -61,6 +61,7 @@ func (s *Service) serve() error { time.Sleep(tempDelay) continue } + s.logger.Errorf("accept: %v", e) return e } tempDelay = 0