fix deadlock in websocket client conn

This commit is contained in:
ginuerzh 2024-01-12 23:46:22 +08:00
parent c04c28e1fd
commit 01168e9846
7 changed files with 122 additions and 77 deletions

View File

@ -54,7 +54,7 @@ func getConfig(ctx *gin.Context) {
if ok && ss != nil { if ok && ss != nil {
status := ss.Status() status := ss.Status()
svc.Status = &config.ServiceStatus{ svc.Status = &config.ServiceStatus{
CreateTime: status.CreateTime().Unix(), CreateTime: status.CreateTime().UnixNano(),
State: string(status.State()), State: string(status.State()),
} }
if st := status.Stats(); st != nil { if st := status.Stats(); st != nil {
@ -69,7 +69,7 @@ func getConfig(ctx *gin.Context) {
for _, ev := range status.Events() { for _, ev := range status.Events() {
if !ev.Time.IsZero() { if !ev.Time.IsZero() {
svc.Status.Events = append(svc.Status.Events, config.ServiceEvent{ svc.Status.Events = append(svc.Status.Events, config.ServiceEvent{
Time: ev.Time.Unix(), Time: ev.Time.UnixNano(),
Msg: ev.Message, Msg: ev.Message,
}) })
} }

View File

@ -360,6 +360,7 @@ type ForwardNodeConfig struct {
HTTP *HTTPNodeConfig `yaml:",omitempty" json:"http,omitempty"` HTTP *HTTPNodeConfig `yaml:",omitempty" json:"http,omitempty"`
TLS *TLSNodeConfig `yaml:",omitempty" json:"tls,omitempty"` TLS *TLSNodeConfig `yaml:",omitempty" json:"tls,omitempty"`
Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"` Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"`
Metadata map[string]any `yaml:",omitempty" json:"metadata,omitempty"`
} }
type HTTPNodeConfig struct { type HTTPNodeConfig struct {
@ -482,10 +483,10 @@ type NodeConfig struct {
Hosts string `yaml:",omitempty" json:"hosts,omitempty"` Hosts string `yaml:",omitempty" json:"hosts,omitempty"`
Connector *ConnectorConfig `yaml:",omitempty" json:"connector,omitempty"` Connector *ConnectorConfig `yaml:",omitempty" json:"connector,omitempty"`
Dialer *DialerConfig `yaml:",omitempty" json:"dialer,omitempty"` Dialer *DialerConfig `yaml:",omitempty" json:"dialer,omitempty"`
Metadata map[string]any `yaml:",omitempty" json:"metadata,omitempty"`
HTTP *HTTPNodeConfig `yaml:",omitempty" json:"http,omitempty"` HTTP *HTTPNodeConfig `yaml:",omitempty" json:"http,omitempty"`
TLS *TLSNodeConfig `yaml:",omitempty" json:"tls,omitempty"` TLS *TLSNodeConfig `yaml:",omitempty" json:"tls,omitempty"`
Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"` Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"`
Metadata map[string]any `yaml:",omitempty" json:"metadata,omitempty"`
} }
type Config struct { type Config struct {

View File

@ -234,7 +234,7 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) {
} }
if forwarder, ok := h.(handler.Forwarder); ok { if forwarder, ok := h.(handler.Forwarder); ok {
hop, err := parseForwarder(cfg.Forwarder) hop, err := parseForwarder(cfg.Forwarder, log)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -266,7 +266,7 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) {
return s, nil return s, nil
} }
func parseForwarder(cfg *config.ForwarderConfig) (hop.Hop, error) { func parseForwarder(cfg *config.ForwarderConfig, log logger.Logger) (hop.Hop, error) {
if cfg == nil { if cfg == nil {
return nil, nil return nil, nil
} }
@ -298,12 +298,13 @@ func parseForwarder(cfg *config.ForwarderConfig) (hop.Hop, error) {
HTTP: node.HTTP, HTTP: node.HTTP,
TLS: node.TLS, TLS: node.TLS,
Auth: node.Auth, Auth: node.Auth,
Metadata: node.Metadata,
}) })
} }
} }
} }
if len(hc.Nodes) > 0 { if len(hc.Nodes) > 0 {
return hop_parser.ParseHop(&hc, logger.Default()) return hop_parser.ParseHop(&hc, log)
} }
return registry.HopRegistry().Get(hc.Name), nil return registry.HopRegistry().Get(hc.Name), nil
} }

View File

@ -11,6 +11,7 @@ import (
md "github.com/go-gost/core/metadata" md "github.com/go-gost/core/metadata"
"github.com/go-gost/gosocks4" "github.com/go-gost/gosocks4"
"github.com/go-gost/gosocks5" "github.com/go-gost/gosocks5"
ctxvalue "github.com/go-gost/x/internal/ctx"
netpkg "github.com/go-gost/x/internal/net" netpkg "github.com/go-gost/x/internal/net"
"github.com/go-gost/x/registry" "github.com/go-gost/x/registry"
) )
@ -79,6 +80,7 @@ func (h *autoHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler
log := h.options.Logger.WithFields(map[string]any{ log := h.options.Logger.WithFields(map[string]any{
"remote": conn.RemoteAddr().String(), "remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(), "local": conn.LocalAddr().String(),
"sid": ctxvalue.SidFromContext(ctx),
}) })
if log.IsLevelEnabled(logger.DebugLevel) { if log.IsLevelEnabled(logger.DebugLevel) {

View File

@ -84,6 +84,7 @@ func (h *httpHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler
log := h.options.Logger.WithFields(map[string]any{ log := h.options.Logger.WithFields(map[string]any{
"remote": conn.RemoteAddr().String(), "remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(), "local": conn.LocalAddr().String(),
"sid": ctxvalue.SidFromContext(ctx),
}) })
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() { defer func() {

View File

@ -133,80 +133,15 @@ func (p *chainHop) Select(ctx context.Context, opts ...hop.SelectOption) *chain.
opt(&options) opt(&options)
} }
ns := p.Nodes()
if len(ns) == 0 {
return nil
}
// hop level bypass // hop level bypass
if p.options.bypass != nil && if p.options.bypass != nil &&
p.options.bypass.Contains(ctx, options.Network, options.Addr, bypass.WithHostOpton(options.Host)) { p.options.bypass.Contains(ctx, options.Network, options.Addr, bypass.WithHostOpton(options.Host)) {
return nil return nil
} }
filters := ns filters := p.filterByHost(options.Host, p.Nodes()...)
if host := options.Host; host != "" { filters = p.filterByProtocol(options.Protocol, filters...)
filters = nil filters = p.filterByPath(options.Path, filters...)
if v, _, _ := net.SplitHostPort(host); v != "" {
host = v
}
var nodes []*chain.Node
for _, node := range ns {
if node == nil {
continue
}
vhost := node.Options().Host
if vhost == "" {
nodes = append(nodes, node)
continue
}
if vhost == host ||
vhost[0] == '.' && strings.HasSuffix(host, vhost[1:]) {
filters = append(filters, node)
}
}
if len(filters) == 0 {
filters = nodes
}
}
if protocol := options.Protocol; protocol != "" {
p.options.logger.Debugf("filter by protocol: %s", protocol)
var nodes []*chain.Node
for _, node := range filters {
if node == nil {
continue
}
if node.Options().Protocol == "" {
nodes = append(nodes, node)
continue
}
if node.Options().Protocol == protocol {
nodes = append(nodes, node)
}
}
filters = nodes
}
// filter by path
if path := options.Path; path != "" {
p.options.logger.Debugf("filter by path: %s", path)
sort.SliceStable(filters, func(i, j int) bool {
return len(filters[i].Options().Path) > len(filters[j].Options().Path)
})
var nodes []*chain.Node
for _, node := range filters {
if node.Options().Path == "" {
nodes = append(nodes, node)
continue
}
if strings.HasPrefix(path, node.Options().Path) {
nodes = append(nodes, node)
break
}
}
filters = nodes
}
var nodes []*chain.Node var nodes []*chain.Node
for _, node := range filters { for _, node := range filters {
@ -231,6 +166,108 @@ func (p *chainHop) Select(ctx context.Context, opts ...hop.SelectOption) *chain.
return nodes[0] return nodes[0]
} }
func (p *chainHop) filterByHost(host string, nodes ...*chain.Node) (filters []*chain.Node) {
if host == "" || len(nodes) == 0 {
return nodes
}
if v, _, _ := net.SplitHostPort(host); v != "" {
host = v
}
p.options.logger.Debugf("filter by host: %s", host)
found := false
for _, node := range nodes {
if node == nil {
continue
}
vhost := node.Options().Host
if vhost == "" { // backup node
if !found {
filters = append(filters, node)
}
continue
}
if vhost == host ||
vhost[0] == '.' && strings.HasSuffix(host, vhost[1:]) {
if !found { // clear all backup nodes when matched node found
filters = nil
}
filters = append(filters, node)
found = true
continue
}
}
return
}
func (p *chainHop) filterByProtocol(protocol string, nodes ...*chain.Node) (filters []*chain.Node) {
if protocol == "" || len(nodes) == 0 {
return nodes
}
p.options.logger.Debugf("filter by protocol: %s", protocol)
found := false
for _, node := range nodes {
if node == nil {
continue
}
if node.Options().Protocol == "" {
if !found {
filters = append(filters, node)
}
continue
}
if node.Options().Protocol == protocol {
if !found {
filters = nil
}
filters = append(filters, node)
found = true
continue
}
}
return
}
func (p *chainHop) filterByPath(path string, nodes ...*chain.Node) (filters []*chain.Node) {
if path == "" || len(nodes) == 0 {
return nodes
}
p.options.logger.Debugf("filter by path: %s", path)
sort.SliceStable(nodes, func(i, j int) bool {
return len(nodes[i].Options().Path) > len(nodes[j].Options().Path)
})
found := false
for _, node := range nodes {
if node.Options().Path == "" {
if !found {
filters = append(filters, node)
}
continue
}
if strings.HasPrefix(path, node.Options().Path) {
if !found {
filters = nil
}
filters = append(filters, node)
break
}
}
return
}
func (p *chainHop) periodReload(ctx context.Context) error { func (p *chainHop) periodReload(ctx context.Context) error {
period := p.options.period period := p.options.period
if period < time.Second { if period < time.Second {

View File

@ -49,15 +49,18 @@ func (c *websocketConn) WriteMessage(messageType int, data []byte) error {
} }
func (c *websocketConn) SetDeadline(t time.Time) error { func (c *websocketConn) SetDeadline(t time.Time) error {
c.mux.Lock()
defer c.mux.Unlock()
if err := c.SetReadDeadline(t); err != nil { if err := c.SetReadDeadline(t); err != nil {
return err return err
} }
return c.SetWriteDeadline(t) return c.SetWriteDeadline(t)
} }
func (c *websocketConn) SetReadDeadline(t time.Time) error {
c.mux.Lock()
defer c.mux.Unlock()
return c.Conn.SetReadDeadline(t)
}
func (c *websocketConn) SetWriteDeadline(t time.Time) error { func (c *websocketConn) SetWriteDeadline(t time.Time) error {
c.mux.Lock() c.mux.Lock()
defer c.mux.Unlock() defer c.mux.Unlock()