diff --git a/api/config.go b/api/config.go index 3cb09ea..e5261a7 100644 --- a/api/config.go +++ b/api/config.go @@ -54,7 +54,7 @@ func getConfig(ctx *gin.Context) { if ok && ss != nil { status := ss.Status() svc.Status = &config.ServiceStatus{ - CreateTime: status.CreateTime().Unix(), + CreateTime: status.CreateTime().UnixNano(), State: string(status.State()), } if st := status.Stats(); st != nil { @@ -69,7 +69,7 @@ func getConfig(ctx *gin.Context) { for _, ev := range status.Events() { if !ev.Time.IsZero() { svc.Status.Events = append(svc.Status.Events, config.ServiceEvent{ - Time: ev.Time.Unix(), + Time: ev.Time.UnixNano(), Msg: ev.Message, }) } diff --git a/config/config.go b/config/config.go index 7c6333f..c511738 100644 --- a/config/config.go +++ b/config/config.go @@ -360,6 +360,7 @@ type ForwardNodeConfig struct { HTTP *HTTPNodeConfig `yaml:",omitempty" json:"http,omitempty"` TLS *TLSNodeConfig `yaml:",omitempty" json:"tls,omitempty"` Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"` + Metadata map[string]any `yaml:",omitempty" json:"metadata,omitempty"` } type HTTPNodeConfig struct { @@ -482,10 +483,10 @@ type NodeConfig struct { Hosts string `yaml:",omitempty" json:"hosts,omitempty"` Connector *ConnectorConfig `yaml:",omitempty" json:"connector,omitempty"` Dialer *DialerConfig `yaml:",omitempty" json:"dialer,omitempty"` - Metadata map[string]any `yaml:",omitempty" json:"metadata,omitempty"` HTTP *HTTPNodeConfig `yaml:",omitempty" json:"http,omitempty"` TLS *TLSNodeConfig `yaml:",omitempty" json:"tls,omitempty"` Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"` + Metadata map[string]any `yaml:",omitempty" json:"metadata,omitempty"` } type Config struct { diff --git a/config/parsing/service/parse.go b/config/parsing/service/parse.go index 09cf5eb..6dfa36e 100644 --- a/config/parsing/service/parse.go +++ b/config/parsing/service/parse.go @@ -234,7 +234,7 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { } if forwarder, ok := h.(handler.Forwarder); ok { - hop, err := parseForwarder(cfg.Forwarder) + hop, err := parseForwarder(cfg.Forwarder, log) if err != nil { return nil, err } @@ -266,7 +266,7 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { return s, nil } -func parseForwarder(cfg *config.ForwarderConfig) (hop.Hop, error) { +func parseForwarder(cfg *config.ForwarderConfig, log logger.Logger) (hop.Hop, error) { if cfg == nil { return nil, nil } @@ -298,12 +298,13 @@ func parseForwarder(cfg *config.ForwarderConfig) (hop.Hop, error) { HTTP: node.HTTP, TLS: node.TLS, Auth: node.Auth, + Metadata: node.Metadata, }) } } } if len(hc.Nodes) > 0 { - return hop_parser.ParseHop(&hc, logger.Default()) + return hop_parser.ParseHop(&hc, log) } return registry.HopRegistry().Get(hc.Name), nil } diff --git a/handler/auto/handler.go b/handler/auto/handler.go index 80c9063..3b251ce 100644 --- a/handler/auto/handler.go +++ b/handler/auto/handler.go @@ -11,6 +11,7 @@ import ( md "github.com/go-gost/core/metadata" "github.com/go-gost/gosocks4" "github.com/go-gost/gosocks5" + ctxvalue "github.com/go-gost/x/internal/ctx" netpkg "github.com/go-gost/x/internal/net" "github.com/go-gost/x/registry" ) @@ -79,6 +80,7 @@ func (h *autoHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler log := h.options.Logger.WithFields(map[string]any{ "remote": conn.RemoteAddr().String(), "local": conn.LocalAddr().String(), + "sid": ctxvalue.SidFromContext(ctx), }) if log.IsLevelEnabled(logger.DebugLevel) { diff --git a/handler/http/handler.go b/handler/http/handler.go index 4ddb4f9..89a190e 100644 --- a/handler/http/handler.go +++ b/handler/http/handler.go @@ -84,6 +84,7 @@ func (h *httpHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler log := h.options.Logger.WithFields(map[string]any{ "remote": conn.RemoteAddr().String(), "local": conn.LocalAddr().String(), + "sid": ctxvalue.SidFromContext(ctx), }) log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) defer func() { diff --git a/hop/hop.go b/hop/hop.go index b528f2e..689bb30 100644 --- a/hop/hop.go +++ b/hop/hop.go @@ -133,80 +133,15 @@ func (p *chainHop) Select(ctx context.Context, opts ...hop.SelectOption) *chain. opt(&options) } - ns := p.Nodes() - if len(ns) == 0 { - return nil - } - // hop level bypass if p.options.bypass != nil && p.options.bypass.Contains(ctx, options.Network, options.Addr, bypass.WithHostOpton(options.Host)) { return nil } - filters := ns - if host := options.Host; host != "" { - filters = nil - if v, _, _ := net.SplitHostPort(host); v != "" { - host = v - } - var nodes []*chain.Node - for _, node := range ns { - if node == nil { - continue - } - vhost := node.Options().Host - if vhost == "" { - nodes = append(nodes, node) - continue - } - if vhost == host || - vhost[0] == '.' && strings.HasSuffix(host, vhost[1:]) { - filters = append(filters, node) - } - } - if len(filters) == 0 { - filters = nodes - } - } - - if protocol := options.Protocol; protocol != "" { - p.options.logger.Debugf("filter by protocol: %s", protocol) - var nodes []*chain.Node - for _, node := range filters { - if node == nil { - continue - } - if node.Options().Protocol == "" { - nodes = append(nodes, node) - continue - } - if node.Options().Protocol == protocol { - nodes = append(nodes, node) - } - } - filters = nodes - } - - // filter by path - if path := options.Path; path != "" { - p.options.logger.Debugf("filter by path: %s", path) - sort.SliceStable(filters, func(i, j int) bool { - return len(filters[i].Options().Path) > len(filters[j].Options().Path) - }) - var nodes []*chain.Node - for _, node := range filters { - if node.Options().Path == "" { - nodes = append(nodes, node) - continue - } - if strings.HasPrefix(path, node.Options().Path) { - nodes = append(nodes, node) - break - } - } - filters = nodes - } + filters := p.filterByHost(options.Host, p.Nodes()...) + filters = p.filterByProtocol(options.Protocol, filters...) + filters = p.filterByPath(options.Path, filters...) var nodes []*chain.Node for _, node := range filters { @@ -231,6 +166,108 @@ func (p *chainHop) Select(ctx context.Context, opts ...hop.SelectOption) *chain. return nodes[0] } +func (p *chainHop) filterByHost(host string, nodes ...*chain.Node) (filters []*chain.Node) { + if host == "" || len(nodes) == 0 { + return nodes + } + + if v, _, _ := net.SplitHostPort(host); v != "" { + host = v + } + p.options.logger.Debugf("filter by host: %s", host) + + found := false + for _, node := range nodes { + if node == nil { + continue + } + vhost := node.Options().Host + if vhost == "" { // backup node + if !found { + filters = append(filters, node) + } + continue + } + + if vhost == host || + vhost[0] == '.' && strings.HasSuffix(host, vhost[1:]) { + if !found { // clear all backup nodes when matched node found + filters = nil + } + filters = append(filters, node) + found = true + continue + } + + } + + return +} + +func (p *chainHop) filterByProtocol(protocol string, nodes ...*chain.Node) (filters []*chain.Node) { + if protocol == "" || len(nodes) == 0 { + return nodes + } + + p.options.logger.Debugf("filter by protocol: %s", protocol) + found := false + for _, node := range nodes { + if node == nil { + continue + } + + if node.Options().Protocol == "" { + if !found { + filters = append(filters, node) + } + continue + } + + if node.Options().Protocol == protocol { + if !found { + filters = nil + } + filters = append(filters, node) + found = true + continue + } + } + + return +} + +func (p *chainHop) filterByPath(path string, nodes ...*chain.Node) (filters []*chain.Node) { + if path == "" || len(nodes) == 0 { + return nodes + } + + p.options.logger.Debugf("filter by path: %s", path) + + sort.SliceStable(nodes, func(i, j int) bool { + return len(nodes[i].Options().Path) > len(nodes[j].Options().Path) + }) + + found := false + for _, node := range nodes { + if node.Options().Path == "" { + if !found { + filters = append(filters, node) + } + continue + } + + if strings.HasPrefix(path, node.Options().Path) { + if !found { + filters = nil + } + filters = append(filters, node) + break + } + } + + return +} + func (p *chainHop) periodReload(ctx context.Context) error { period := p.options.period if period < time.Second { diff --git a/internal/util/ws/ws.go b/internal/util/ws/ws.go index 0dd6612..a6c1512 100644 --- a/internal/util/ws/ws.go +++ b/internal/util/ws/ws.go @@ -49,15 +49,18 @@ func (c *websocketConn) WriteMessage(messageType int, data []byte) error { } func (c *websocketConn) SetDeadline(t time.Time) error { - c.mux.Lock() - defer c.mux.Unlock() - if err := c.SetReadDeadline(t); err != nil { return err } return c.SetWriteDeadline(t) } +func (c *websocketConn) SetReadDeadline(t time.Time) error { + c.mux.Lock() + defer c.mux.Unlock() + return c.Conn.SetReadDeadline(t) +} + func (c *websocketConn) SetWriteDeadline(t time.Time) error { c.mux.Lock() defer c.mux.Unlock()