add udp relay support for http handler

This commit is contained in:
ginuerzh 2021-12-01 21:23:19 +08:00
parent f3411832a8
commit 15f9aa091b
13 changed files with 386 additions and 250 deletions

1
.gitignore vendored
View File

@ -9,6 +9,7 @@ _test
release release
debian debian
bin bin
.vscode
# Architecture specific extensions/prefixes # Architecture specific extensions/prefixes
*.[568vq] *.[568vq]

View File

@ -11,6 +11,7 @@ import (
"net/url" "net/url"
"time" "time"
"github.com/go-gost/gost/pkg/common/util/socks"
"github.com/go-gost/gost/pkg/connector" "github.com/go-gost/gost/pkg/connector"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata" md "github.com/go-gost/gost/pkg/metadata"
@ -50,19 +51,6 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add
}) })
c.logger.Infof("connect %s/%s", address, network) c.logger.Infof("connect %s/%s", address, network)
switch network {
case "tcp", "tcp4", "tcp6":
if _, ok := conn.(net.PacketConn); ok {
err := fmt.Errorf("tcp over udp is unsupported")
c.logger.Error(err)
return nil, err
}
default:
err := fmt.Errorf("network %s is unsupported", network)
c.logger.Error(err)
return nil, err
}
req := &http.Request{ req := &http.Request{
Method: http.MethodConnect, Method: http.MethodConnect,
URL: &url.URL{Host: address}, URL: &url.URL{Host: address},
@ -83,6 +71,21 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add
"Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p))) "Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p)))
} }
switch network {
case "tcp", "tcp4", "tcp6":
if _, ok := conn.(net.PacketConn); ok {
err := fmt.Errorf("tcp over udp is unsupported")
c.logger.Error(err)
return nil, err
}
case "udp", "udp4", "udp6":
req.Header.Set("X-Gost-Protocol", "udp")
default:
err := fmt.Errorf("network %s is unsupported", network)
c.logger.Error(err)
return nil, err
}
if c.logger.IsLevelEnabled(logger.DebugLevel) { if c.logger.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpRequest(req, false) dump, _ := httputil.DumpRequest(req, false)
c.logger.Debug(string(dump)) c.logger.Debug(string(dump))
@ -113,5 +116,10 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add
return nil, fmt.Errorf("%s", resp.Status) return nil, fmt.Errorf("%s", resp.Status)
} }
if network == "udp" {
addr, _ := net.ResolveUDPAddr(network, address)
return socks.UDPTunClientConn(conn, addr), nil
}
return conn, nil return conn, nil
} }

View File

@ -167,11 +167,5 @@ func (c *socks5Connector) connectUDP(ctx context.Context, conn net.Conn, network
return nil, errors.New("get socks5 UDP tunnel failure") return nil, errors.New("get socks5 UDP tunnel failure")
} }
baddr, err := net.ResolveUDPAddr("udp", reply.Addr.String())
if err != nil {
return nil, err
}
c.logger.Debugf("associate on %s OK", baddr)
return socks.UDPTunClientConn(conn, addr), nil return socks.UDPTunClientConn(conn, addr), nil
} }

View File

@ -90,6 +90,11 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
req.URL.Scheme = "http" req.URL.Scheme = "http"
} }
network := req.Header.Get("X-Gost-Protocol")
if network != "udp" {
network = "tcp"
}
// Try to get the actual host. // Try to get the actual host.
// Compatible with GOST 2.x. // Compatible with GOST 2.x.
if v := req.Header.Get("Gost-Target"); v != "" { if v := req.Header.Get("Gost-Target"); v != "" {
@ -168,6 +173,11 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
return return
} }
if network == "udp" {
h.handleUDP(ctx, conn, network, req.Host)
return
}
if req.Method == "PRI" || if req.Method == "PRI" ||
(req.Method != http.MethodConnect && req.URL.Scheme != "http") { (req.Method != http.MethodConnect && req.URL.Scheme != "http") {
resp.StatusCode = http.StatusBadRequest resp.StatusCode = http.StatusBadRequest
@ -187,7 +197,7 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
WithChain(h.chain). WithChain(h.chain).
WithRetry(h.md.retryCount). WithRetry(h.md.retryCount).
WithLogger(h.logger) WithLogger(h.logger)
cc, err := r.Dial(ctx, "tcp", addr) cc, err := r.Dial(ctx, network, addr)
if err != nil { if err != nil {
resp.StatusCode = http.StatusServiceUnavailable resp.StatusCode = http.StatusServiceUnavailable
resp.Write(conn) resp.Write(conn)
@ -209,13 +219,13 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
h.logger.Debug(string(dump)) h.logger.Debug(string(dump))
} }
if err = resp.Write(conn); err != nil { if err = resp.Write(conn); err != nil {
h.logger.Warn(err) h.logger.Error(err)
return return
} }
} else { } else {
req.Header.Del("Proxy-Connection") req.Header.Del("Proxy-Connection")
if err = req.Write(cc); err != nil { if err = req.Write(cc); err != nil {
h.logger.Warn(err) h.logger.Error(err)
return return
} }
} }

View File

@ -13,6 +13,7 @@ type metadata struct {
retryCount int retryCount int
probeResist *probeResist probeResist *probeResist
sni bool sni bool
enableUDP bool
} }
func (h *httpHandler) parseMetadata(md md.Metadata) error { func (h *httpHandler) parseMetadata(md md.Metadata) error {
@ -23,6 +24,7 @@ func (h *httpHandler) parseMetadata(md md.Metadata) error {
knock = "knock" knock = "knock"
retryCount = "retry" retryCount = "retry"
sni = "sni" sni = "sni"
enableUDP = "udp"
) )
h.md.proxyAgent = md.GetString(proxyAgent) h.md.proxyAgent = md.GetString(proxyAgent)
@ -53,6 +55,7 @@ func (h *httpHandler) parseMetadata(md md.Metadata) error {
} }
h.md.retryCount = md.GetInt(retryCount) h.md.retryCount = md.GetInt(retryCount)
h.md.sni = md.GetBool(sni) h.md.sni = md.GetBool(sni)
h.md.enableUDP = md.GetBool(enableUDP)
return nil return nil
} }

83
pkg/handler/http/udp.go Normal file
View File

@ -0,0 +1,83 @@
package http
import (
"context"
"net"
"net/http"
"net/http/httputil"
"time"
"github.com/go-gost/gost/pkg/chain"
"github.com/go-gost/gost/pkg/common/util/socks"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger"
)
func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, address string) {
h.logger = h.logger.WithFields(map[string]interface{}{
"cmd": "udp",
})
resp := &http.Response{
ProtoMajor: 1,
ProtoMinor: 1,
Header: http.Header{},
}
if h.md.proxyAgent != "" {
resp.Header.Add("Proxy-Agent", h.md.proxyAgent)
}
if !h.md.enableUDP {
resp.StatusCode = http.StatusForbidden
resp.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpResponse(resp, false)
h.logger.Debug(string(dump))
}
h.logger.Error("UDP relay is diabled")
return
}
resp.StatusCode = http.StatusOK
if h.logger.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpResponse(resp, false)
h.logger.Debug(string(dump))
}
if err := resp.Write(conn); err != nil {
h.logger.Error(err)
return
}
// obtain a udp connection
r := (&chain.Router{}).
WithChain(h.chain).
WithRetry(h.md.retryCount).
WithLogger(h.logger)
c, err := r.Dial(ctx, "udp", "") // UDP association
if err != nil {
h.logger.Error(err)
return
}
defer c.Close()
pc, ok := c.(net.PacketConn)
if !ok {
h.logger.Errorf("wrong connection type")
return
}
relay := handler.NewUDPRelay(socks.UDPTunServerConn(conn), pc).
WithBypass(h.bypass).
WithLogger(h.logger)
t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
relay.Run()
h.logger.
WithFields(map[string]interface{}{
"duration": time.Since(t),
}).
Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
}

126
pkg/handler/relay.go Normal file
View File

@ -0,0 +1,126 @@
package handler
import (
"net"
"github.com/go-gost/gost/pkg/bypass"
"github.com/go-gost/gost/pkg/common/bufpool"
"github.com/go-gost/gost/pkg/logger"
)
type UDPRelay struct {
pc1 net.PacketConn
pc2 net.PacketConn
bypass bypass.Bypass
bufferSize int
logger logger.Logger
}
func NewUDPRelay(pc1, pc2 net.PacketConn) *UDPRelay {
return &UDPRelay{
pc1: pc1,
pc2: pc2,
}
}
func (r *UDPRelay) WithBypass(bp bypass.Bypass) *UDPRelay {
r.bypass = bp
return r
}
func (r *UDPRelay) WithLogger(logger logger.Logger) *UDPRelay {
r.logger = logger
return r
}
func (r *UDPRelay) SetBufferSize(n int) {
r.bufferSize = n
}
func (r *UDPRelay) Run() (err error) {
bufSize := r.bufferSize
if bufSize <= 0 {
bufSize = 1024
}
errc := make(chan error, 2)
go func() {
for {
err := func() error {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
n, raddr, err := r.pc1.ReadFrom(b)
if err != nil {
return err
}
if r.bypass != nil && r.bypass.Contains(raddr.String()) {
if r.logger != nil {
r.logger.Warn("bypass: ", raddr)
}
return nil
}
if _, err := r.pc2.WriteTo(b[:n], raddr); err != nil {
return err
}
if r.logger != nil {
r.logger.Debugf("%s >>> %s data: %d",
r.pc2.LocalAddr(), raddr, n)
}
return nil
}()
if err != nil {
errc <- err
return
}
}
}()
go func() {
for {
err := func() error {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
n, raddr, err := r.pc2.ReadFrom(b)
if err != nil {
return err
}
if r.bypass != nil && r.bypass.Contains(raddr.String()) {
if r.logger != nil {
r.logger.Warn("bypass: ", raddr)
}
return nil
}
if _, err := r.pc1.WriteTo(b[:n], raddr); err != nil {
return err
}
if r.logger != nil {
r.logger.Debugf("%s <<< %s data: %d",
r.pc2.LocalAddr(), raddr, n)
}
return nil
}()
if err != nil {
errc <- err
return
}
}
}()
return <-errc
}

View File

@ -136,6 +136,5 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) {
h.handleConnect(ctx, conn, network, address) h.handleConnect(ctx, conn, network, address)
case relay.BIND: case relay.BIND:
h.handleBind(ctx, conn, network, address) h.handleBind(ctx, conn, network, address)
case relay.ASSOCIATE:
} }
} }

View File

@ -6,7 +6,7 @@ import (
"time" "time"
"github.com/go-gost/gost/pkg/auth" "github.com/go-gost/gost/pkg/auth"
util_tls "github.com/go-gost/gost/pkg/common/util/tls" tls_util "github.com/go-gost/gost/pkg/common/util/tls"
md "github.com/go-gost/gost/pkg/metadata" md "github.com/go-gost/gost/pkg/metadata"
) )
@ -23,7 +23,7 @@ type metadata struct {
compatibilityMode bool compatibilityMode bool
} }
func (h *socks5Handler) parseMetadata(md md.Metadata) error { func (h *socks5Handler) parseMetadata(md md.Metadata) (err error) {
const ( const (
certFile = "certFile" certFile = "certFile"
keyFile = "keyFile" keyFile = "keyFile"
@ -39,14 +39,19 @@ func (h *socks5Handler) parseMetadata(md md.Metadata) error {
compatibilityMode = "comp" compatibilityMode = "comp"
) )
var err error if md.GetString(certFile) != "" ||
h.md.tlsConfig, err = util_tls.LoadTLSConfig( md.GetString(keyFile) != "" ||
md.GetString(caFile) != "" {
h.md.tlsConfig, err = tls_util.LoadTLSConfig(
md.GetString(certFile), md.GetString(certFile),
md.GetString(keyFile), md.GetString(keyFile),
md.GetString(caFile), md.GetString(caFile),
) )
if err != nil { if err != nil {
h.logger.Warn("parse tls config: ", err) return
}
} else {
h.md.tlsConfig = tls_util.DefaultConfig
} }
if v, _ := md.Get(users).([]interface{}); len(v) > 0 { if v, _ := md.Get(users).([]interface{}); len(v) > 0 {

View File

@ -9,8 +9,9 @@ import (
"time" "time"
"github.com/go-gost/gosocks5" "github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/common/bufpool" "github.com/go-gost/gost/pkg/chain"
"github.com/go-gost/gost/pkg/common/util/socks" "github.com/go-gost/gost/pkg/common/util/socks"
"github.com/go-gost/gost/pkg/handler"
) )
func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) { func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) {
@ -26,7 +27,7 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) {
return return
} }
relay, err := net.ListenUDP("udp", nil) cc, err := net.ListenUDP("udp", nil)
if err != nil { if err != nil {
h.logger.Error(err) h.logger.Error(err)
reply := gosocks5.NewReply(gosocks5.Failure, nil) reply := gosocks5.NewReply(gosocks5.Failure, nil)
@ -34,10 +35,10 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) {
h.logger.Debug(reply) h.logger.Debug(reply)
return return
} }
defer relay.Close() defer cc.Close()
saddr := gosocks5.Addr{} saddr := gosocks5.Addr{}
saddr.ParseFrom(relay.LocalAddr().String()) saddr.ParseFrom(cc.LocalAddr().String())
saddr.Type = 0 saddr.Type = 0
saddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) // replace the IP to the out-going interface's saddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) // replace the IP to the out-going interface's
reply := gosocks5.NewReply(gosocks5.Succeeded, &saddr) reply := gosocks5.NewReply(gosocks5.Succeeded, &saddr)
@ -48,99 +49,39 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) {
h.logger.Debug(reply) h.logger.Debug(reply)
h.logger = h.logger.WithFields(map[string]interface{}{ h.logger = h.logger.WithFields(map[string]interface{}{
"bind": fmt.Sprintf("%s/%s", relay.LocalAddr(), relay.LocalAddr().Network()), "bind": fmt.Sprintf("%s/%s", cc.LocalAddr(), cc.LocalAddr().Network()),
}) })
h.logger.Debugf("bind on %s OK", relay.LocalAddr()) h.logger.Debugf("bind on %s OK", cc.LocalAddr())
peer, err := net.ListenUDP("udp", nil) // obtain a udp connection
r := (&chain.Router{}).
WithChain(h.chain).
WithRetry(h.md.retryCount).
WithLogger(h.logger)
c, err := r.Dial(ctx, "udp", "") // UDP association
if err != nil { if err != nil {
h.logger.Error(err) h.logger.Error(err)
return return
} }
defer peer.Close() defer c.Close()
go h.relayUDP( pc, ok := c.(net.PacketConn)
socks.UDPConn(relay, h.md.udpBufferSize), if !ok {
peer, h.logger.Errorf("wrong connection type")
) return
}
relay := handler.NewUDPRelay(socks.UDPConn(cc, h.md.udpBufferSize), pc).
WithBypass(h.bypass).
WithLogger(h.logger)
relay.SetBufferSize(h.md.udpBufferSize)
go relay.Run()
t := time.Now() t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), relay.LocalAddr()) h.logger.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr())
io.Copy(ioutil.Discard, conn) io.Copy(ioutil.Discard, conn)
h.logger. h.logger.
WithFields(map[string]interface{}{"duration": time.Since(t)}). WithFields(map[string]interface{}{"duration": time.Since(t)}).
Infof("%s >-< %s", conn.RemoteAddr(), relay.LocalAddr()) Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr())
}
func (h *socks5Handler) relayUDP(c, peer net.PacketConn) (err error) {
bufSize := h.md.udpBufferSize
errc := make(chan error, 2)
go func() {
for {
err := func() error {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
n, raddr, err := c.ReadFrom(b)
if err != nil {
return err
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
return nil
}
if _, err := peer.WriteTo(b[:n], raddr); err != nil {
return err
}
h.logger.Debugf("%s >>> %s data: %d",
peer.LocalAddr(), raddr, n)
return nil
}()
if err != nil {
errc <- err
return
}
}
}()
go func() {
for {
err := func() error {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
n, raddr, err := peer.ReadFrom(b)
if err != nil {
return err
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
return nil
}
if _, err := c.WriteTo(b[:n], raddr); err != nil {
return err
}
h.logger.Debugf("%s <<< %s data: %d",
peer.LocalAddr(), raddr, n)
return nil
}()
if err != nil {
errc <- err
return
}
}
}()
return <-errc
} }

View File

@ -2,13 +2,13 @@ package v5
import ( import (
"context" "context"
"fmt"
"net" "net"
"time" "time"
"github.com/go-gost/gosocks5" "github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/common/bufpool" "github.com/go-gost/gost/pkg/chain"
"github.com/go-gost/gost/pkg/common/util/socks" "github.com/go-gost/gost/pkg/common/util/socks"
"github.com/go-gost/gost/pkg/handler"
) )
func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network, address string) { func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network, address string) {
@ -24,111 +24,43 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network
return return
} }
bindAddr, _ := net.ResolveUDPAddr(network, address) // dummy bind
pc, err := net.ListenUDP(network, bindAddr) reply := gosocks5.NewReply(gosocks5.Succeeded, nil)
if err != nil {
h.logger.Error(err)
return
}
defer pc.Close()
saddr, _ := gosocks5.NewAddr(pc.LocalAddr().String())
saddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String())
saddr.Type = 0
reply := gosocks5.NewReply(gosocks5.Succeeded, saddr)
if err := reply.Write(conn); err != nil { if err := reply.Write(conn); err != nil {
h.logger.Error(err) h.logger.Error(err)
return return
} }
h.logger.Debug(reply) h.logger.Debug(reply)
h.logger = h.logger.WithFields(map[string]interface{}{ // obtain a udp connection
"bind": fmt.Sprintf("%s/%s", pc.LocalAddr(), pc.LocalAddr().Network()), r := (&chain.Router{}).
}) WithChain(h.chain).
WithRetry(h.md.retryCount).
WithLogger(h.logger)
c, err := r.Dial(ctx, "udp", "") // UDP association
if err != nil {
h.logger.Error(err)
return
}
defer c.Close()
h.logger.Debugf("bind on %s OK", pc.LocalAddr()) pc, ok := c.(net.PacketConn)
if !ok {
h.logger.Errorf("wrong connection type")
return
}
relay := handler.NewUDPRelay(socks.UDPTunServerConn(conn), pc).
WithBypass(h.bypass).
WithLogger(h.logger)
relay.SetBufferSize(h.md.udpBufferSize)
t := time.Now() t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr()) h.logger.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
h.tunnelServerUDP( relay.Run()
socks.UDPTunServerConn(conn),
pc,
)
h.logger. h.logger.
WithFields(map[string]interface{}{ WithFields(map[string]interface{}{
"duration": time.Since(t), "duration": time.Since(t),
}). }).
Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
} }
func (h *socks5Handler) tunnelServerUDP(tunnel, c net.PacketConn) (err error) {
bufSize := h.md.udpBufferSize
errc := make(chan error, 2)
go func() {
for {
err := func() error {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
n, raddr, err := tunnel.ReadFrom(b)
if err != nil {
return err
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
return nil
}
if _, err := c.WriteTo(b[:n], raddr); err != nil {
return err
}
h.logger.Debugf("%s >>> %s data: %d",
c.LocalAddr(), raddr, n)
return nil
}()
if err != nil {
errc <- err
return
}
}
}()
go func() {
for {
err := func() error {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
n, raddr, err := c.ReadFrom(b)
if err != nil {
return err
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
return nil
}
if _, err := tunnel.WriteTo(b[:n], raddr); err != nil {
return err
}
h.logger.Debugf("%s <<< %s data: %d",
c.LocalAddr(), raddr, n)
return nil
}()
if err != nil {
errc <- err
return
}
}
}()
return <-errc
}

View File

@ -5,7 +5,6 @@ import (
"net" "net"
"net/http" "net/http"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
ws_util "github.com/go-gost/gost/pkg/common/util/ws" ws_util "github.com/go-gost/gost/pkg/common/util/ws"
"github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/listener"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
@ -16,7 +15,7 @@ import (
func init() { func init() {
registry.RegisterListener("ws", NewListener) registry.RegisterListener("ws", NewListener)
registry.RegisterListener("wss", NewListener) registry.RegisterListener("wss", NewTLSListener)
} }
type wsListener struct { type wsListener struct {
@ -25,6 +24,7 @@ type wsListener struct {
addr net.Addr addr net.Addr
upgrader *websocket.Upgrader upgrader *websocket.Upgrader
srv *http.Server srv *http.Server
tlsEnabled bool
connChan chan net.Conn connChan chan net.Conn
errChan chan error errChan chan error
logger logger.Logger logger logger.Logger
@ -41,6 +41,18 @@ func NewListener(opts ...listener.Option) listener.Listener {
} }
} }
func NewTLSListener(opts ...listener.Option) listener.Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &wsListener{
saddr: options.Addr,
tlsEnabled: true,
logger: options.Logger,
}
}
func (l *wsListener) Init(md md.Metadata) (err error) { func (l *wsListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil { if err = l.parseMetadata(md); err != nil {
return return
@ -115,19 +127,6 @@ func (l *wsListener) Addr() net.Addr {
return l.addr return l.addr
} }
func (l *wsListener) parseMetadata(md md.Metadata) (err error) {
l.md.tlsConfig, err = tls_util.LoadTLSConfig(
md.GetString(certFile),
md.GetString(keyFile),
md.GetString(caFile),
)
if err != nil {
return
}
return
}
func (l *wsListener) upgrade(w http.ResponseWriter, r *http.Request) { func (l *wsListener) upgrade(w http.ResponseWriter, r *http.Request) {
conn, err := l.upgrader.Upgrade(w, r, l.md.responseHeader) conn, err := l.upgrader.Upgrade(w, r, l.md.responseHeader)
if err != nil { if err != nil {

View File

@ -4,20 +4,9 @@ import (
"crypto/tls" "crypto/tls"
"net/http" "net/http"
"time" "time"
)
const ( tls_util "github.com/go-gost/gost/pkg/common/util/tls"
path = "path" md "github.com/go-gost/gost/pkg/metadata"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
enableCompression = "enableCompression"
responseHeader = "responseHeader"
connQueueSize = "connQueueSize"
) )
const ( const (
@ -36,3 +25,49 @@ type metadata struct {
responseHeader http.Header responseHeader http.Header
connQueueSize int connQueueSize int
} }
func (l *wsListener) parseMetadata(md md.Metadata) (err error) {
const (
path = "path"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
enableCompression = "enableCompression"
responseHeader = "responseHeader"
connQueueSize = "connQueueSize"
)
if l.tlsEnabled {
if md.GetString(certFile) != "" ||
md.GetString(keyFile) != "" ||
md.GetString(caFile) != "" {
l.md.tlsConfig, err = tls_util.LoadTLSConfig(
md.GetString(certFile),
md.GetString(keyFile),
md.GetString(caFile),
)
if err != nil {
return
}
} else {
l.md.tlsConfig = tls_util.DefaultConfig
}
}
l.md.path = md.GetString(path)
l.md.connQueueSize = md.GetInt(connQueueSize)
if l.md.connQueueSize <= 0 {
l.md.connQueueSize = defaultQueueSize
}
l.md.enableCompression = md.GetBool(enableCompression)
l.md.readBufferSize = md.GetInt(readBufferSize)
l.md.writeBufferSize = md.GetInt(writeBufferSize)
l.md.handshakeTimeout = md.GetDuration(handshakeTimeout)
l.md.readHeaderTimeout = md.GetDuration(readHeaderTimeout)
return
}