From 15f9aa091b827ad561f7bd8b82a79cf8c96d1c5a Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Wed, 1 Dec 2021 21:23:19 +0800 Subject: [PATCH] add udp relay support for http handler --- .gitignore | 1 + pkg/connector/http/connector.go | 34 +++++--- pkg/connector/socks/v5/connector.go | 6 -- pkg/handler/http/handler.go | 16 +++- pkg/handler/http/metadata.go | 3 + pkg/handler/http/udp.go | 83 ++++++++++++++++++ pkg/handler/relay.go | 126 ++++++++++++++++++++++++++++ pkg/handler/relay/handler.go | 1 - pkg/handler/socks/v5/metadata.go | 25 +++--- pkg/handler/socks/v5/udp.go | 115 +++++++------------------ pkg/handler/socks/v5/udp_tun.go | 120 ++++++-------------------- pkg/listener/ws/listener.go | 45 +++++----- pkg/listener/ws/metadata.go | 61 +++++++++++--- 13 files changed, 386 insertions(+), 250 deletions(-) create mode 100644 pkg/handler/http/udp.go create mode 100644 pkg/handler/relay.go diff --git a/.gitignore b/.gitignore index 0cd5cca..2072667 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ _test release debian bin +.vscode # Architecture specific extensions/prefixes *.[568vq] diff --git a/pkg/connector/http/connector.go b/pkg/connector/http/connector.go index 14e645b..7b6492e 100644 --- a/pkg/connector/http/connector.go +++ b/pkg/connector/http/connector.go @@ -11,6 +11,7 @@ import ( "net/url" "time" + "github.com/go-gost/gost/pkg/common/util/socks" "github.com/go-gost/gost/pkg/connector" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" @@ -50,19 +51,6 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add }) c.logger.Infof("connect %s/%s", address, network) - switch network { - case "tcp", "tcp4", "tcp6": - if _, ok := conn.(net.PacketConn); ok { - err := fmt.Errorf("tcp over udp is unsupported") - c.logger.Error(err) - return nil, err - } - default: - err := fmt.Errorf("network %s is unsupported", network) - c.logger.Error(err) - return nil, err - } - req := &http.Request{ Method: http.MethodConnect, URL: &url.URL{Host: address}, @@ -83,6 +71,21 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add "Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p))) } + switch network { + case "tcp", "tcp4", "tcp6": + if _, ok := conn.(net.PacketConn); ok { + err := fmt.Errorf("tcp over udp is unsupported") + c.logger.Error(err) + return nil, err + } + case "udp", "udp4", "udp6": + req.Header.Set("X-Gost-Protocol", "udp") + default: + err := fmt.Errorf("network %s is unsupported", network) + c.logger.Error(err) + return nil, err + } + if c.logger.IsLevelEnabled(logger.DebugLevel) { dump, _ := httputil.DumpRequest(req, false) c.logger.Debug(string(dump)) @@ -113,5 +116,10 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add return nil, fmt.Errorf("%s", resp.Status) } + if network == "udp" { + addr, _ := net.ResolveUDPAddr(network, address) + return socks.UDPTunClientConn(conn, addr), nil + } + return conn, nil } diff --git a/pkg/connector/socks/v5/connector.go b/pkg/connector/socks/v5/connector.go index b12e651..de5b5b2 100644 --- a/pkg/connector/socks/v5/connector.go +++ b/pkg/connector/socks/v5/connector.go @@ -167,11 +167,5 @@ func (c *socks5Connector) connectUDP(ctx context.Context, conn net.Conn, network return nil, errors.New("get socks5 UDP tunnel failure") } - baddr, err := net.ResolveUDPAddr("udp", reply.Addr.String()) - if err != nil { - return nil, err - } - c.logger.Debugf("associate on %s OK", baddr) - return socks.UDPTunClientConn(conn, addr), nil } diff --git a/pkg/handler/http/handler.go b/pkg/handler/http/handler.go index affaf81..2565b61 100644 --- a/pkg/handler/http/handler.go +++ b/pkg/handler/http/handler.go @@ -90,6 +90,11 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt req.URL.Scheme = "http" } + network := req.Header.Get("X-Gost-Protocol") + if network != "udp" { + network = "tcp" + } + // Try to get the actual host. // Compatible with GOST 2.x. if v := req.Header.Get("Gost-Target"); v != "" { @@ -168,6 +173,11 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt return } + if network == "udp" { + h.handleUDP(ctx, conn, network, req.Host) + return + } + if req.Method == "PRI" || (req.Method != http.MethodConnect && req.URL.Scheme != "http") { resp.StatusCode = http.StatusBadRequest @@ -187,7 +197,7 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt WithChain(h.chain). WithRetry(h.md.retryCount). WithLogger(h.logger) - cc, err := r.Dial(ctx, "tcp", addr) + cc, err := r.Dial(ctx, network, addr) if err != nil { resp.StatusCode = http.StatusServiceUnavailable resp.Write(conn) @@ -209,13 +219,13 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt h.logger.Debug(string(dump)) } if err = resp.Write(conn); err != nil { - h.logger.Warn(err) + h.logger.Error(err) return } } else { req.Header.Del("Proxy-Connection") if err = req.Write(cc); err != nil { - h.logger.Warn(err) + h.logger.Error(err) return } } diff --git a/pkg/handler/http/metadata.go b/pkg/handler/http/metadata.go index b8d8cfa..7fa7944 100644 --- a/pkg/handler/http/metadata.go +++ b/pkg/handler/http/metadata.go @@ -13,6 +13,7 @@ type metadata struct { retryCount int probeResist *probeResist sni bool + enableUDP bool } func (h *httpHandler) parseMetadata(md md.Metadata) error { @@ -23,6 +24,7 @@ func (h *httpHandler) parseMetadata(md md.Metadata) error { knock = "knock" retryCount = "retry" sni = "sni" + enableUDP = "udp" ) h.md.proxyAgent = md.GetString(proxyAgent) @@ -53,6 +55,7 @@ func (h *httpHandler) parseMetadata(md md.Metadata) error { } h.md.retryCount = md.GetInt(retryCount) h.md.sni = md.GetBool(sni) + h.md.enableUDP = md.GetBool(enableUDP) return nil } diff --git a/pkg/handler/http/udp.go b/pkg/handler/http/udp.go new file mode 100644 index 0000000..3dbeb3c --- /dev/null +++ b/pkg/handler/http/udp.go @@ -0,0 +1,83 @@ +package http + +import ( + "context" + "net" + "net/http" + "net/http/httputil" + "time" + + "github.com/go-gost/gost/pkg/chain" + "github.com/go-gost/gost/pkg/common/util/socks" + "github.com/go-gost/gost/pkg/handler" + "github.com/go-gost/gost/pkg/logger" +) + +func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, address string) { + h.logger = h.logger.WithFields(map[string]interface{}{ + "cmd": "udp", + }) + + resp := &http.Response{ + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{}, + } + if h.md.proxyAgent != "" { + resp.Header.Add("Proxy-Agent", h.md.proxyAgent) + } + + if !h.md.enableUDP { + resp.StatusCode = http.StatusForbidden + resp.Write(conn) + + if h.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(resp, false) + h.logger.Debug(string(dump)) + } + h.logger.Error("UDP relay is diabled") + + return + } + + resp.StatusCode = http.StatusOK + if h.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(resp, false) + h.logger.Debug(string(dump)) + } + if err := resp.Write(conn); err != nil { + h.logger.Error(err) + return + } + + // obtain a udp connection + r := (&chain.Router{}). + WithChain(h.chain). + WithRetry(h.md.retryCount). + WithLogger(h.logger) + c, err := r.Dial(ctx, "udp", "") // UDP association + if err != nil { + h.logger.Error(err) + return + } + defer c.Close() + + pc, ok := c.(net.PacketConn) + if !ok { + h.logger.Errorf("wrong connection type") + return + } + + relay := handler.NewUDPRelay(socks.UDPTunServerConn(conn), pc). + WithBypass(h.bypass). + WithLogger(h.logger) + + t := time.Now() + h.logger.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr()) + relay.Run() + h.logger. + WithFields(map[string]interface{}{ + "duration": time.Since(t), + }). + Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) +} diff --git a/pkg/handler/relay.go b/pkg/handler/relay.go new file mode 100644 index 0000000..9cb82c4 --- /dev/null +++ b/pkg/handler/relay.go @@ -0,0 +1,126 @@ +package handler + +import ( + "net" + + "github.com/go-gost/gost/pkg/bypass" + "github.com/go-gost/gost/pkg/common/bufpool" + "github.com/go-gost/gost/pkg/logger" +) + +type UDPRelay struct { + pc1 net.PacketConn + pc2 net.PacketConn + + bypass bypass.Bypass + bufferSize int + logger logger.Logger +} + +func NewUDPRelay(pc1, pc2 net.PacketConn) *UDPRelay { + return &UDPRelay{ + pc1: pc1, + pc2: pc2, + } +} + +func (r *UDPRelay) WithBypass(bp bypass.Bypass) *UDPRelay { + r.bypass = bp + return r +} + +func (r *UDPRelay) WithLogger(logger logger.Logger) *UDPRelay { + r.logger = logger + return r +} + +func (r *UDPRelay) SetBufferSize(n int) { + r.bufferSize = n +} + +func (r *UDPRelay) Run() (err error) { + bufSize := r.bufferSize + if bufSize <= 0 { + bufSize = 1024 + } + + errc := make(chan error, 2) + + go func() { + for { + err := func() error { + b := bufpool.Get(bufSize) + defer bufpool.Put(b) + + n, raddr, err := r.pc1.ReadFrom(b) + if err != nil { + return err + } + + if r.bypass != nil && r.bypass.Contains(raddr.String()) { + if r.logger != nil { + r.logger.Warn("bypass: ", raddr) + } + return nil + } + + if _, err := r.pc2.WriteTo(b[:n], raddr); err != nil { + return err + } + + if r.logger != nil { + r.logger.Debugf("%s >>> %s data: %d", + r.pc2.LocalAddr(), raddr, n) + + } + + return nil + }() + + if err != nil { + errc <- err + return + } + } + }() + + go func() { + for { + err := func() error { + b := bufpool.Get(bufSize) + defer bufpool.Put(b) + + n, raddr, err := r.pc2.ReadFrom(b) + if err != nil { + return err + } + + if r.bypass != nil && r.bypass.Contains(raddr.String()) { + if r.logger != nil { + r.logger.Warn("bypass: ", raddr) + } + return nil + } + + if _, err := r.pc1.WriteTo(b[:n], raddr); err != nil { + return err + } + + if r.logger != nil { + r.logger.Debugf("%s <<< %s data: %d", + r.pc2.LocalAddr(), raddr, n) + + } + + return nil + }() + + if err != nil { + errc <- err + return + } + } + }() + + return <-errc +} diff --git a/pkg/handler/relay/handler.go b/pkg/handler/relay/handler.go index 1ad1152..03dac36 100644 --- a/pkg/handler/relay/handler.go +++ b/pkg/handler/relay/handler.go @@ -136,6 +136,5 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) { h.handleConnect(ctx, conn, network, address) case relay.BIND: h.handleBind(ctx, conn, network, address) - case relay.ASSOCIATE: } } diff --git a/pkg/handler/socks/v5/metadata.go b/pkg/handler/socks/v5/metadata.go index 82343d0..6d10a7a 100644 --- a/pkg/handler/socks/v5/metadata.go +++ b/pkg/handler/socks/v5/metadata.go @@ -6,7 +6,7 @@ import ( "time" "github.com/go-gost/gost/pkg/auth" - util_tls "github.com/go-gost/gost/pkg/common/util/tls" + tls_util "github.com/go-gost/gost/pkg/common/util/tls" md "github.com/go-gost/gost/pkg/metadata" ) @@ -23,7 +23,7 @@ type metadata struct { compatibilityMode bool } -func (h *socks5Handler) parseMetadata(md md.Metadata) error { +func (h *socks5Handler) parseMetadata(md md.Metadata) (err error) { const ( certFile = "certFile" keyFile = "keyFile" @@ -39,14 +39,19 @@ func (h *socks5Handler) parseMetadata(md md.Metadata) error { compatibilityMode = "comp" ) - var err error - h.md.tlsConfig, err = util_tls.LoadTLSConfig( - md.GetString(certFile), - md.GetString(keyFile), - md.GetString(caFile), - ) - if err != nil { - h.logger.Warn("parse tls config: ", err) + if md.GetString(certFile) != "" || + md.GetString(keyFile) != "" || + md.GetString(caFile) != "" { + h.md.tlsConfig, err = tls_util.LoadTLSConfig( + md.GetString(certFile), + md.GetString(keyFile), + md.GetString(caFile), + ) + if err != nil { + return + } + } else { + h.md.tlsConfig = tls_util.DefaultConfig } if v, _ := md.Get(users).([]interface{}); len(v) > 0 { diff --git a/pkg/handler/socks/v5/udp.go b/pkg/handler/socks/v5/udp.go index 409fee3..3a155c6 100644 --- a/pkg/handler/socks/v5/udp.go +++ b/pkg/handler/socks/v5/udp.go @@ -9,8 +9,9 @@ import ( "time" "github.com/go-gost/gosocks5" - "github.com/go-gost/gost/pkg/common/bufpool" + "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/common/util/socks" + "github.com/go-gost/gost/pkg/handler" ) func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) { @@ -26,7 +27,7 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) { return } - relay, err := net.ListenUDP("udp", nil) + cc, err := net.ListenUDP("udp", nil) if err != nil { h.logger.Error(err) reply := gosocks5.NewReply(gosocks5.Failure, nil) @@ -34,10 +35,10 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) { h.logger.Debug(reply) return } - defer relay.Close() + defer cc.Close() saddr := gosocks5.Addr{} - saddr.ParseFrom(relay.LocalAddr().String()) + saddr.ParseFrom(cc.LocalAddr().String()) saddr.Type = 0 saddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) // replace the IP to the out-going interface's reply := gosocks5.NewReply(gosocks5.Succeeded, &saddr) @@ -48,99 +49,39 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) { h.logger.Debug(reply) h.logger = h.logger.WithFields(map[string]interface{}{ - "bind": fmt.Sprintf("%s/%s", relay.LocalAddr(), relay.LocalAddr().Network()), + "bind": fmt.Sprintf("%s/%s", cc.LocalAddr(), cc.LocalAddr().Network()), }) - h.logger.Debugf("bind on %s OK", relay.LocalAddr()) + h.logger.Debugf("bind on %s OK", cc.LocalAddr()) - peer, err := net.ListenUDP("udp", nil) + // obtain a udp connection + r := (&chain.Router{}). + WithChain(h.chain). + WithRetry(h.md.retryCount). + WithLogger(h.logger) + c, err := r.Dial(ctx, "udp", "") // UDP association if err != nil { h.logger.Error(err) return } - defer peer.Close() + defer c.Close() - go h.relayUDP( - socks.UDPConn(relay, h.md.udpBufferSize), - peer, - ) + pc, ok := c.(net.PacketConn) + if !ok { + h.logger.Errorf("wrong connection type") + return + } + + relay := handler.NewUDPRelay(socks.UDPConn(cc, h.md.udpBufferSize), pc). + WithBypass(h.bypass). + WithLogger(h.logger) + relay.SetBufferSize(h.md.udpBufferSize) + + go relay.Run() t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), relay.LocalAddr()) + h.logger.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr()) io.Copy(ioutil.Discard, conn) h.logger. WithFields(map[string]interface{}{"duration": time.Since(t)}). - Infof("%s >-< %s", conn.RemoteAddr(), relay.LocalAddr()) -} - -func (h *socks5Handler) relayUDP(c, peer net.PacketConn) (err error) { - bufSize := h.md.udpBufferSize - errc := make(chan error, 2) - - go func() { - for { - err := func() error { - b := bufpool.Get(bufSize) - defer bufpool.Put(b) - - n, raddr, err := c.ReadFrom(b) - if err != nil { - return err - } - - if h.bypass != nil && h.bypass.Contains(raddr.String()) { - h.logger.Warn("bypass: ", raddr) - return nil - } - - if _, err := peer.WriteTo(b[:n], raddr); err != nil { - return err - } - - h.logger.Debugf("%s >>> %s data: %d", - peer.LocalAddr(), raddr, n) - - return nil - }() - - if err != nil { - errc <- err - return - } - } - }() - - go func() { - for { - err := func() error { - b := bufpool.Get(bufSize) - defer bufpool.Put(b) - - n, raddr, err := peer.ReadFrom(b) - if err != nil { - return err - } - - if h.bypass != nil && h.bypass.Contains(raddr.String()) { - h.logger.Warn("bypass: ", raddr) - return nil - } - - if _, err := c.WriteTo(b[:n], raddr); err != nil { - return err - } - - h.logger.Debugf("%s <<< %s data: %d", - peer.LocalAddr(), raddr, n) - - return nil - }() - - if err != nil { - errc <- err - return - } - } - }() - - return <-errc + Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr()) } diff --git a/pkg/handler/socks/v5/udp_tun.go b/pkg/handler/socks/v5/udp_tun.go index ce0a72c..01ba438 100644 --- a/pkg/handler/socks/v5/udp_tun.go +++ b/pkg/handler/socks/v5/udp_tun.go @@ -2,13 +2,13 @@ package v5 import ( "context" - "fmt" "net" "time" "github.com/go-gost/gosocks5" - "github.com/go-gost/gost/pkg/common/bufpool" + "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/common/util/socks" + "github.com/go-gost/gost/pkg/handler" ) func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network, address string) { @@ -24,111 +24,43 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network return } - bindAddr, _ := net.ResolveUDPAddr(network, address) - pc, err := net.ListenUDP(network, bindAddr) - if err != nil { - h.logger.Error(err) - return - } - defer pc.Close() - - saddr, _ := gosocks5.NewAddr(pc.LocalAddr().String()) - saddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) - saddr.Type = 0 - reply := gosocks5.NewReply(gosocks5.Succeeded, saddr) + // dummy bind + reply := gosocks5.NewReply(gosocks5.Succeeded, nil) if err := reply.Write(conn); err != nil { h.logger.Error(err) return } h.logger.Debug(reply) - h.logger = h.logger.WithFields(map[string]interface{}{ - "bind": fmt.Sprintf("%s/%s", pc.LocalAddr(), pc.LocalAddr().Network()), - }) + // obtain a udp connection + r := (&chain.Router{}). + WithChain(h.chain). + WithRetry(h.md.retryCount). + WithLogger(h.logger) + c, err := r.Dial(ctx, "udp", "") // UDP association + if err != nil { + h.logger.Error(err) + return + } + defer c.Close() - h.logger.Debugf("bind on %s OK", pc.LocalAddr()) + pc, ok := c.(net.PacketConn) + if !ok { + h.logger.Errorf("wrong connection type") + return + } + + relay := handler.NewUDPRelay(socks.UDPTunServerConn(conn), pc). + WithBypass(h.bypass). + WithLogger(h.logger) + relay.SetBufferSize(h.md.udpBufferSize) t := time.Now() h.logger.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr()) - h.tunnelServerUDP( - socks.UDPTunServerConn(conn), - pc, - ) + relay.Run() h.logger. WithFields(map[string]interface{}{ "duration": time.Since(t), }). Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) } - -func (h *socks5Handler) tunnelServerUDP(tunnel, c net.PacketConn) (err error) { - bufSize := h.md.udpBufferSize - errc := make(chan error, 2) - - go func() { - for { - err := func() error { - b := bufpool.Get(bufSize) - defer bufpool.Put(b) - - n, raddr, err := tunnel.ReadFrom(b) - if err != nil { - return err - } - - if h.bypass != nil && h.bypass.Contains(raddr.String()) { - h.logger.Warn("bypass: ", raddr) - return nil - } - - if _, err := c.WriteTo(b[:n], raddr); err != nil { - return err - } - - h.logger.Debugf("%s >>> %s data: %d", - c.LocalAddr(), raddr, n) - - return nil - }() - - if err != nil { - errc <- err - return - } - } - }() - - go func() { - for { - err := func() error { - b := bufpool.Get(bufSize) - defer bufpool.Put(b) - - n, raddr, err := c.ReadFrom(b) - if err != nil { - return err - } - - if h.bypass != nil && h.bypass.Contains(raddr.String()) { - h.logger.Warn("bypass: ", raddr) - return nil - } - - if _, err := tunnel.WriteTo(b[:n], raddr); err != nil { - return err - } - h.logger.Debugf("%s <<< %s data: %d", - c.LocalAddr(), raddr, n) - - return nil - }() - - if err != nil { - errc <- err - return - } - } - }() - - return <-errc -} diff --git a/pkg/listener/ws/listener.go b/pkg/listener/ws/listener.go index d8270f7..8d43250 100644 --- a/pkg/listener/ws/listener.go +++ b/pkg/listener/ws/listener.go @@ -5,7 +5,6 @@ import ( "net" "net/http" - tls_util "github.com/go-gost/gost/pkg/common/util/tls" ws_util "github.com/go-gost/gost/pkg/common/util/ws" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" @@ -16,18 +15,19 @@ import ( func init() { registry.RegisterListener("ws", NewListener) - registry.RegisterListener("wss", NewListener) + registry.RegisterListener("wss", NewTLSListener) } type wsListener struct { - saddr string - md metadata - addr net.Addr - upgrader *websocket.Upgrader - srv *http.Server - connChan chan net.Conn - errChan chan error - logger logger.Logger + saddr string + md metadata + addr net.Addr + upgrader *websocket.Upgrader + srv *http.Server + tlsEnabled bool + connChan chan net.Conn + errChan chan error + logger logger.Logger } func NewListener(opts ...listener.Option) listener.Listener { @@ -41,6 +41,18 @@ func NewListener(opts ...listener.Option) listener.Listener { } } +func NewTLSListener(opts ...listener.Option) listener.Listener { + options := &listener.Options{} + for _, opt := range opts { + opt(options) + } + return &wsListener{ + saddr: options.Addr, + tlsEnabled: true, + logger: options.Logger, + } +} + func (l *wsListener) Init(md md.Metadata) (err error) { if err = l.parseMetadata(md); err != nil { return @@ -115,19 +127,6 @@ func (l *wsListener) Addr() net.Addr { return l.addr } -func (l *wsListener) parseMetadata(md md.Metadata) (err error) { - l.md.tlsConfig, err = tls_util.LoadTLSConfig( - md.GetString(certFile), - md.GetString(keyFile), - md.GetString(caFile), - ) - if err != nil { - return - } - - return -} - func (l *wsListener) upgrade(w http.ResponseWriter, r *http.Request) { conn, err := l.upgrader.Upgrade(w, r, l.md.responseHeader) if err != nil { diff --git a/pkg/listener/ws/metadata.go b/pkg/listener/ws/metadata.go index 95f81e1..a2165cb 100644 --- a/pkg/listener/ws/metadata.go +++ b/pkg/listener/ws/metadata.go @@ -4,20 +4,9 @@ import ( "crypto/tls" "net/http" "time" -) -const ( - path = "path" - certFile = "certFile" - keyFile = "keyFile" - caFile = "caFile" - handshakeTimeout = "handshakeTimeout" - readHeaderTimeout = "readHeaderTimeout" - readBufferSize = "readBufferSize" - writeBufferSize = "writeBufferSize" - enableCompression = "enableCompression" - responseHeader = "responseHeader" - connQueueSize = "connQueueSize" + tls_util "github.com/go-gost/gost/pkg/common/util/tls" + md "github.com/go-gost/gost/pkg/metadata" ) const ( @@ -36,3 +25,49 @@ type metadata struct { responseHeader http.Header connQueueSize int } + +func (l *wsListener) parseMetadata(md md.Metadata) (err error) { + const ( + path = "path" + certFile = "certFile" + keyFile = "keyFile" + caFile = "caFile" + handshakeTimeout = "handshakeTimeout" + readHeaderTimeout = "readHeaderTimeout" + readBufferSize = "readBufferSize" + writeBufferSize = "writeBufferSize" + enableCompression = "enableCompression" + responseHeader = "responseHeader" + connQueueSize = "connQueueSize" + ) + + if l.tlsEnabled { + if md.GetString(certFile) != "" || + md.GetString(keyFile) != "" || + md.GetString(caFile) != "" { + l.md.tlsConfig, err = tls_util.LoadTLSConfig( + md.GetString(certFile), + md.GetString(keyFile), + md.GetString(caFile), + ) + if err != nil { + return + } + } else { + l.md.tlsConfig = tls_util.DefaultConfig + } + } + + l.md.path = md.GetString(path) + l.md.connQueueSize = md.GetInt(connQueueSize) + if l.md.connQueueSize <= 0 { + l.md.connQueueSize = defaultQueueSize + } + l.md.enableCompression = md.GetBool(enableCompression) + l.md.readBufferSize = md.GetInt(readBufferSize) + l.md.writeBufferSize = md.GetInt(writeBufferSize) + l.md.handshakeTimeout = md.GetDuration(handshakeTimeout) + l.md.readHeaderTimeout = md.GetDuration(readHeaderTimeout) + + return +}