add udp relay support for http handler
This commit is contained in:
parent
f3411832a8
commit
15f9aa091b
1
.gitignore
vendored
1
.gitignore
vendored
@ -9,6 +9,7 @@ _test
|
||||
release
|
||||
debian
|
||||
bin
|
||||
.vscode
|
||||
|
||||
# Architecture specific extensions/prefixes
|
||||
*.[568vq]
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/pkg/common/util/socks"
|
||||
"github.com/go-gost/gost/pkg/connector"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
md "github.com/go-gost/gost/pkg/metadata"
|
||||
@ -50,19 +51,6 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add
|
||||
})
|
||||
c.logger.Infof("connect %s/%s", address, network)
|
||||
|
||||
switch network {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
if _, ok := conn.(net.PacketConn); ok {
|
||||
err := fmt.Errorf("tcp over udp is unsupported")
|
||||
c.logger.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
err := fmt.Errorf("network %s is unsupported", network)
|
||||
c.logger.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := &http.Request{
|
||||
Method: http.MethodConnect,
|
||||
URL: &url.URL{Host: address},
|
||||
@ -83,6 +71,21 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add
|
||||
"Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p)))
|
||||
}
|
||||
|
||||
switch network {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
if _, ok := conn.(net.PacketConn); ok {
|
||||
err := fmt.Errorf("tcp over udp is unsupported")
|
||||
c.logger.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
case "udp", "udp4", "udp6":
|
||||
req.Header.Set("X-Gost-Protocol", "udp")
|
||||
default:
|
||||
err := fmt.Errorf("network %s is unsupported", network)
|
||||
c.logger.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if c.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpRequest(req, false)
|
||||
c.logger.Debug(string(dump))
|
||||
@ -113,5 +116,10 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add
|
||||
return nil, fmt.Errorf("%s", resp.Status)
|
||||
}
|
||||
|
||||
if network == "udp" {
|
||||
addr, _ := net.ResolveUDPAddr(network, address)
|
||||
return socks.UDPTunClientConn(conn, addr), nil
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
@ -167,11 +167,5 @@ func (c *socks5Connector) connectUDP(ctx context.Context, conn net.Conn, network
|
||||
return nil, errors.New("get socks5 UDP tunnel failure")
|
||||
}
|
||||
|
||||
baddr, err := net.ResolveUDPAddr("udp", reply.Addr.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.logger.Debugf("associate on %s OK", baddr)
|
||||
|
||||
return socks.UDPTunClientConn(conn, addr), nil
|
||||
}
|
||||
|
@ -90,6 +90,11 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
|
||||
req.URL.Scheme = "http"
|
||||
}
|
||||
|
||||
network := req.Header.Get("X-Gost-Protocol")
|
||||
if network != "udp" {
|
||||
network = "tcp"
|
||||
}
|
||||
|
||||
// Try to get the actual host.
|
||||
// Compatible with GOST 2.x.
|
||||
if v := req.Header.Get("Gost-Target"); v != "" {
|
||||
@ -168,6 +173,11 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
|
||||
return
|
||||
}
|
||||
|
||||
if network == "udp" {
|
||||
h.handleUDP(ctx, conn, network, req.Host)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Method == "PRI" ||
|
||||
(req.Method != http.MethodConnect && req.URL.Scheme != "http") {
|
||||
resp.StatusCode = http.StatusBadRequest
|
||||
@ -187,7 +197,7 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
|
||||
WithChain(h.chain).
|
||||
WithRetry(h.md.retryCount).
|
||||
WithLogger(h.logger)
|
||||
cc, err := r.Dial(ctx, "tcp", addr)
|
||||
cc, err := r.Dial(ctx, network, addr)
|
||||
if err != nil {
|
||||
resp.StatusCode = http.StatusServiceUnavailable
|
||||
resp.Write(conn)
|
||||
@ -209,13 +219,13 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
|
||||
h.logger.Debug(string(dump))
|
||||
}
|
||||
if err = resp.Write(conn); err != nil {
|
||||
h.logger.Warn(err)
|
||||
h.logger.Error(err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
req.Header.Del("Proxy-Connection")
|
||||
if err = req.Write(cc); err != nil {
|
||||
h.logger.Warn(err)
|
||||
h.logger.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -13,6 +13,7 @@ type metadata struct {
|
||||
retryCount int
|
||||
probeResist *probeResist
|
||||
sni bool
|
||||
enableUDP bool
|
||||
}
|
||||
|
||||
func (h *httpHandler) parseMetadata(md md.Metadata) error {
|
||||
@ -23,6 +24,7 @@ func (h *httpHandler) parseMetadata(md md.Metadata) error {
|
||||
knock = "knock"
|
||||
retryCount = "retry"
|
||||
sni = "sni"
|
||||
enableUDP = "udp"
|
||||
)
|
||||
|
||||
h.md.proxyAgent = md.GetString(proxyAgent)
|
||||
@ -53,6 +55,7 @@ func (h *httpHandler) parseMetadata(md md.Metadata) error {
|
||||
}
|
||||
h.md.retryCount = md.GetInt(retryCount)
|
||||
h.md.sni = md.GetBool(sni)
|
||||
h.md.enableUDP = md.GetBool(enableUDP)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
83
pkg/handler/http/udp.go
Normal file
83
pkg/handler/http/udp.go
Normal file
@ -0,0 +1,83 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/pkg/chain"
|
||||
"github.com/go-gost/gost/pkg/common/util/socks"
|
||||
"github.com/go-gost/gost/pkg/handler"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
)
|
||||
|
||||
func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, address string) {
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
"cmd": "udp",
|
||||
})
|
||||
|
||||
resp := &http.Response{
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Header: http.Header{},
|
||||
}
|
||||
if h.md.proxyAgent != "" {
|
||||
resp.Header.Add("Proxy-Agent", h.md.proxyAgent)
|
||||
}
|
||||
|
||||
if !h.md.enableUDP {
|
||||
resp.StatusCode = http.StatusForbidden
|
||||
resp.Write(conn)
|
||||
|
||||
if h.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpResponse(resp, false)
|
||||
h.logger.Debug(string(dump))
|
||||
}
|
||||
h.logger.Error("UDP relay is diabled")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resp.StatusCode = http.StatusOK
|
||||
if h.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpResponse(resp, false)
|
||||
h.logger.Debug(string(dump))
|
||||
}
|
||||
if err := resp.Write(conn); err != nil {
|
||||
h.logger.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
// obtain a udp connection
|
||||
r := (&chain.Router{}).
|
||||
WithChain(h.chain).
|
||||
WithRetry(h.md.retryCount).
|
||||
WithLogger(h.logger)
|
||||
c, err := r.Dial(ctx, "udp", "") // UDP association
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
pc, ok := c.(net.PacketConn)
|
||||
if !ok {
|
||||
h.logger.Errorf("wrong connection type")
|
||||
return
|
||||
}
|
||||
|
||||
relay := handler.NewUDPRelay(socks.UDPTunServerConn(conn), pc).
|
||||
WithBypass(h.bypass).
|
||||
WithLogger(h.logger)
|
||||
|
||||
t := time.Now()
|
||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||
relay.Run()
|
||||
h.logger.
|
||||
WithFields(map[string]interface{}{
|
||||
"duration": time.Since(t),
|
||||
}).
|
||||
Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||
}
|
126
pkg/handler/relay.go
Normal file
126
pkg/handler/relay.go
Normal file
@ -0,0 +1,126 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/go-gost/gost/pkg/bypass"
|
||||
"github.com/go-gost/gost/pkg/common/bufpool"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
)
|
||||
|
||||
type UDPRelay struct {
|
||||
pc1 net.PacketConn
|
||||
pc2 net.PacketConn
|
||||
|
||||
bypass bypass.Bypass
|
||||
bufferSize int
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewUDPRelay(pc1, pc2 net.PacketConn) *UDPRelay {
|
||||
return &UDPRelay{
|
||||
pc1: pc1,
|
||||
pc2: pc2,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *UDPRelay) WithBypass(bp bypass.Bypass) *UDPRelay {
|
||||
r.bypass = bp
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *UDPRelay) WithLogger(logger logger.Logger) *UDPRelay {
|
||||
r.logger = logger
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *UDPRelay) SetBufferSize(n int) {
|
||||
r.bufferSize = n
|
||||
}
|
||||
|
||||
func (r *UDPRelay) Run() (err error) {
|
||||
bufSize := r.bufferSize
|
||||
if bufSize <= 0 {
|
||||
bufSize = 1024
|
||||
}
|
||||
|
||||
errc := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
err := func() error {
|
||||
b := bufpool.Get(bufSize)
|
||||
defer bufpool.Put(b)
|
||||
|
||||
n, raddr, err := r.pc1.ReadFrom(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.bypass != nil && r.bypass.Contains(raddr.String()) {
|
||||
if r.logger != nil {
|
||||
r.logger.Warn("bypass: ", raddr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := r.pc2.WriteTo(b[:n], raddr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.logger != nil {
|
||||
r.logger.Debugf("%s >>> %s data: %d",
|
||||
r.pc2.LocalAddr(), raddr, n)
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
errc <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
err := func() error {
|
||||
b := bufpool.Get(bufSize)
|
||||
defer bufpool.Put(b)
|
||||
|
||||
n, raddr, err := r.pc2.ReadFrom(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.bypass != nil && r.bypass.Contains(raddr.String()) {
|
||||
if r.logger != nil {
|
||||
r.logger.Warn("bypass: ", raddr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := r.pc1.WriteTo(b[:n], raddr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.logger != nil {
|
||||
r.logger.Debugf("%s <<< %s data: %d",
|
||||
r.pc2.LocalAddr(), raddr, n)
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
errc <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return <-errc
|
||||
}
|
@ -136,6 +136,5 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
h.handleConnect(ctx, conn, network, address)
|
||||
case relay.BIND:
|
||||
h.handleBind(ctx, conn, network, address)
|
||||
case relay.ASSOCIATE:
|
||||
}
|
||||
}
|
||||
|
@ -6,7 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/pkg/auth"
|
||||
util_tls "github.com/go-gost/gost/pkg/common/util/tls"
|
||||
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
|
||||
md "github.com/go-gost/gost/pkg/metadata"
|
||||
)
|
||||
|
||||
@ -23,7 +23,7 @@ type metadata struct {
|
||||
compatibilityMode bool
|
||||
}
|
||||
|
||||
func (h *socks5Handler) parseMetadata(md md.Metadata) error {
|
||||
func (h *socks5Handler) parseMetadata(md md.Metadata) (err error) {
|
||||
const (
|
||||
certFile = "certFile"
|
||||
keyFile = "keyFile"
|
||||
@ -39,14 +39,19 @@ func (h *socks5Handler) parseMetadata(md md.Metadata) error {
|
||||
compatibilityMode = "comp"
|
||||
)
|
||||
|
||||
var err error
|
||||
h.md.tlsConfig, err = util_tls.LoadTLSConfig(
|
||||
md.GetString(certFile),
|
||||
md.GetString(keyFile),
|
||||
md.GetString(caFile),
|
||||
)
|
||||
if err != nil {
|
||||
h.logger.Warn("parse tls config: ", err)
|
||||
if md.GetString(certFile) != "" ||
|
||||
md.GetString(keyFile) != "" ||
|
||||
md.GetString(caFile) != "" {
|
||||
h.md.tlsConfig, err = tls_util.LoadTLSConfig(
|
||||
md.GetString(certFile),
|
||||
md.GetString(keyFile),
|
||||
md.GetString(caFile),
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
h.md.tlsConfig = tls_util.DefaultConfig
|
||||
}
|
||||
|
||||
if v, _ := md.Get(users).([]interface{}); len(v) > 0 {
|
||||
|
@ -9,8 +9,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gosocks5"
|
||||
"github.com/go-gost/gost/pkg/common/bufpool"
|
||||
"github.com/go-gost/gost/pkg/chain"
|
||||
"github.com/go-gost/gost/pkg/common/util/socks"
|
||||
"github.com/go-gost/gost/pkg/handler"
|
||||
)
|
||||
|
||||
func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) {
|
||||
@ -26,7 +27,7 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
relay, err := net.ListenUDP("udp", nil)
|
||||
cc, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
reply := gosocks5.NewReply(gosocks5.Failure, nil)
|
||||
@ -34,10 +35,10 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) {
|
||||
h.logger.Debug(reply)
|
||||
return
|
||||
}
|
||||
defer relay.Close()
|
||||
defer cc.Close()
|
||||
|
||||
saddr := gosocks5.Addr{}
|
||||
saddr.ParseFrom(relay.LocalAddr().String())
|
||||
saddr.ParseFrom(cc.LocalAddr().String())
|
||||
saddr.Type = 0
|
||||
saddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) // replace the IP to the out-going interface's
|
||||
reply := gosocks5.NewReply(gosocks5.Succeeded, &saddr)
|
||||
@ -48,99 +49,39 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) {
|
||||
h.logger.Debug(reply)
|
||||
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
"bind": fmt.Sprintf("%s/%s", relay.LocalAddr(), relay.LocalAddr().Network()),
|
||||
"bind": fmt.Sprintf("%s/%s", cc.LocalAddr(), cc.LocalAddr().Network()),
|
||||
})
|
||||
h.logger.Debugf("bind on %s OK", relay.LocalAddr())
|
||||
h.logger.Debugf("bind on %s OK", cc.LocalAddr())
|
||||
|
||||
peer, err := net.ListenUDP("udp", nil)
|
||||
// obtain a udp connection
|
||||
r := (&chain.Router{}).
|
||||
WithChain(h.chain).
|
||||
WithRetry(h.md.retryCount).
|
||||
WithLogger(h.logger)
|
||||
c, err := r.Dial(ctx, "udp", "") // UDP association
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
return
|
||||
}
|
||||
defer peer.Close()
|
||||
defer c.Close()
|
||||
|
||||
go h.relayUDP(
|
||||
socks.UDPConn(relay, h.md.udpBufferSize),
|
||||
peer,
|
||||
)
|
||||
pc, ok := c.(net.PacketConn)
|
||||
if !ok {
|
||||
h.logger.Errorf("wrong connection type")
|
||||
return
|
||||
}
|
||||
|
||||
relay := handler.NewUDPRelay(socks.UDPConn(cc, h.md.udpBufferSize), pc).
|
||||
WithBypass(h.bypass).
|
||||
WithLogger(h.logger)
|
||||
relay.SetBufferSize(h.md.udpBufferSize)
|
||||
|
||||
go relay.Run()
|
||||
|
||||
t := time.Now()
|
||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), relay.LocalAddr())
|
||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr())
|
||||
io.Copy(ioutil.Discard, conn)
|
||||
h.logger.
|
||||
WithFields(map[string]interface{}{"duration": time.Since(t)}).
|
||||
Infof("%s >-< %s", conn.RemoteAddr(), relay.LocalAddr())
|
||||
}
|
||||
|
||||
func (h *socks5Handler) relayUDP(c, peer net.PacketConn) (err error) {
|
||||
bufSize := h.md.udpBufferSize
|
||||
errc := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
err := func() error {
|
||||
b := bufpool.Get(bufSize)
|
||||
defer bufpool.Put(b)
|
||||
|
||||
n, raddr, err := c.ReadFrom(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
|
||||
h.logger.Warn("bypass: ", raddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := peer.WriteTo(b[:n], raddr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.logger.Debugf("%s >>> %s data: %d",
|
||||
peer.LocalAddr(), raddr, n)
|
||||
|
||||
return nil
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
errc <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
err := func() error {
|
||||
b := bufpool.Get(bufSize)
|
||||
defer bufpool.Put(b)
|
||||
|
||||
n, raddr, err := peer.ReadFrom(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
|
||||
h.logger.Warn("bypass: ", raddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := c.WriteTo(b[:n], raddr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.logger.Debugf("%s <<< %s data: %d",
|
||||
peer.LocalAddr(), raddr, n)
|
||||
|
||||
return nil
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
errc <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return <-errc
|
||||
Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr())
|
||||
}
|
||||
|
@ -2,13 +2,13 @@ package v5
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gosocks5"
|
||||
"github.com/go-gost/gost/pkg/common/bufpool"
|
||||
"github.com/go-gost/gost/pkg/chain"
|
||||
"github.com/go-gost/gost/pkg/common/util/socks"
|
||||
"github.com/go-gost/gost/pkg/handler"
|
||||
)
|
||||
|
||||
func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network, address string) {
|
||||
@ -24,111 +24,43 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network
|
||||
return
|
||||
}
|
||||
|
||||
bindAddr, _ := net.ResolveUDPAddr(network, address)
|
||||
pc, err := net.ListenUDP(network, bindAddr)
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
return
|
||||
}
|
||||
defer pc.Close()
|
||||
|
||||
saddr, _ := gosocks5.NewAddr(pc.LocalAddr().String())
|
||||
saddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String())
|
||||
saddr.Type = 0
|
||||
reply := gosocks5.NewReply(gosocks5.Succeeded, saddr)
|
||||
// dummy bind
|
||||
reply := gosocks5.NewReply(gosocks5.Succeeded, nil)
|
||||
if err := reply.Write(conn); err != nil {
|
||||
h.logger.Error(err)
|
||||
return
|
||||
}
|
||||
h.logger.Debug(reply)
|
||||
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
"bind": fmt.Sprintf("%s/%s", pc.LocalAddr(), pc.LocalAddr().Network()),
|
||||
})
|
||||
// obtain a udp connection
|
||||
r := (&chain.Router{}).
|
||||
WithChain(h.chain).
|
||||
WithRetry(h.md.retryCount).
|
||||
WithLogger(h.logger)
|
||||
c, err := r.Dial(ctx, "udp", "") // UDP association
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
h.logger.Debugf("bind on %s OK", pc.LocalAddr())
|
||||
pc, ok := c.(net.PacketConn)
|
||||
if !ok {
|
||||
h.logger.Errorf("wrong connection type")
|
||||
return
|
||||
}
|
||||
|
||||
relay := handler.NewUDPRelay(socks.UDPTunServerConn(conn), pc).
|
||||
WithBypass(h.bypass).
|
||||
WithLogger(h.logger)
|
||||
relay.SetBufferSize(h.md.udpBufferSize)
|
||||
|
||||
t := time.Now()
|
||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||
h.tunnelServerUDP(
|
||||
socks.UDPTunServerConn(conn),
|
||||
pc,
|
||||
)
|
||||
relay.Run()
|
||||
h.logger.
|
||||
WithFields(map[string]interface{}{
|
||||
"duration": time.Since(t),
|
||||
}).
|
||||
Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||
}
|
||||
|
||||
func (h *socks5Handler) tunnelServerUDP(tunnel, c net.PacketConn) (err error) {
|
||||
bufSize := h.md.udpBufferSize
|
||||
errc := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
err := func() error {
|
||||
b := bufpool.Get(bufSize)
|
||||
defer bufpool.Put(b)
|
||||
|
||||
n, raddr, err := tunnel.ReadFrom(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
|
||||
h.logger.Warn("bypass: ", raddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := c.WriteTo(b[:n], raddr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.logger.Debugf("%s >>> %s data: %d",
|
||||
c.LocalAddr(), raddr, n)
|
||||
|
||||
return nil
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
errc <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
err := func() error {
|
||||
b := bufpool.Get(bufSize)
|
||||
defer bufpool.Put(b)
|
||||
|
||||
n, raddr, err := c.ReadFrom(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
|
||||
h.logger.Warn("bypass: ", raddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := tunnel.WriteTo(b[:n], raddr); err != nil {
|
||||
return err
|
||||
}
|
||||
h.logger.Debugf("%s <<< %s data: %d",
|
||||
c.LocalAddr(), raddr, n)
|
||||
|
||||
return nil
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
errc <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return <-errc
|
||||
}
|
||||
|
@ -5,7 +5,6 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
|
||||
ws_util "github.com/go-gost/gost/pkg/common/util/ws"
|
||||
"github.com/go-gost/gost/pkg/listener"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
@ -16,18 +15,19 @@ import (
|
||||
|
||||
func init() {
|
||||
registry.RegisterListener("ws", NewListener)
|
||||
registry.RegisterListener("wss", NewListener)
|
||||
registry.RegisterListener("wss", NewTLSListener)
|
||||
}
|
||||
|
||||
type wsListener struct {
|
||||
saddr string
|
||||
md metadata
|
||||
addr net.Addr
|
||||
upgrader *websocket.Upgrader
|
||||
srv *http.Server
|
||||
connChan chan net.Conn
|
||||
errChan chan error
|
||||
logger logger.Logger
|
||||
saddr string
|
||||
md metadata
|
||||
addr net.Addr
|
||||
upgrader *websocket.Upgrader
|
||||
srv *http.Server
|
||||
tlsEnabled bool
|
||||
connChan chan net.Conn
|
||||
errChan chan error
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
@ -41,6 +41,18 @@ func NewListener(opts ...listener.Option) listener.Listener {
|
||||
}
|
||||
}
|
||||
|
||||
func NewTLSListener(opts ...listener.Option) listener.Listener {
|
||||
options := &listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
return &wsListener{
|
||||
saddr: options.Addr,
|
||||
tlsEnabled: true,
|
||||
logger: options.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *wsListener) Init(md md.Metadata) (err error) {
|
||||
if err = l.parseMetadata(md); err != nil {
|
||||
return
|
||||
@ -115,19 +127,6 @@ func (l *wsListener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
||||
|
||||
func (l *wsListener) parseMetadata(md md.Metadata) (err error) {
|
||||
l.md.tlsConfig, err = tls_util.LoadTLSConfig(
|
||||
md.GetString(certFile),
|
||||
md.GetString(keyFile),
|
||||
md.GetString(caFile),
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *wsListener) upgrade(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := l.upgrader.Upgrade(w, r, l.md.responseHeader)
|
||||
if err != nil {
|
||||
|
@ -4,20 +4,9 @@ import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
path = "path"
|
||||
certFile = "certFile"
|
||||
keyFile = "keyFile"
|
||||
caFile = "caFile"
|
||||
handshakeTimeout = "handshakeTimeout"
|
||||
readHeaderTimeout = "readHeaderTimeout"
|
||||
readBufferSize = "readBufferSize"
|
||||
writeBufferSize = "writeBufferSize"
|
||||
enableCompression = "enableCompression"
|
||||
responseHeader = "responseHeader"
|
||||
connQueueSize = "connQueueSize"
|
||||
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
|
||||
md "github.com/go-gost/gost/pkg/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -36,3 +25,49 @@ type metadata struct {
|
||||
responseHeader http.Header
|
||||
connQueueSize int
|
||||
}
|
||||
|
||||
func (l *wsListener) parseMetadata(md md.Metadata) (err error) {
|
||||
const (
|
||||
path = "path"
|
||||
certFile = "certFile"
|
||||
keyFile = "keyFile"
|
||||
caFile = "caFile"
|
||||
handshakeTimeout = "handshakeTimeout"
|
||||
readHeaderTimeout = "readHeaderTimeout"
|
||||
readBufferSize = "readBufferSize"
|
||||
writeBufferSize = "writeBufferSize"
|
||||
enableCompression = "enableCompression"
|
||||
responseHeader = "responseHeader"
|
||||
connQueueSize = "connQueueSize"
|
||||
)
|
||||
|
||||
if l.tlsEnabled {
|
||||
if md.GetString(certFile) != "" ||
|
||||
md.GetString(keyFile) != "" ||
|
||||
md.GetString(caFile) != "" {
|
||||
l.md.tlsConfig, err = tls_util.LoadTLSConfig(
|
||||
md.GetString(certFile),
|
||||
md.GetString(keyFile),
|
||||
md.GetString(caFile),
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
l.md.tlsConfig = tls_util.DefaultConfig
|
||||
}
|
||||
}
|
||||
|
||||
l.md.path = md.GetString(path)
|
||||
l.md.connQueueSize = md.GetInt(connQueueSize)
|
||||
if l.md.connQueueSize <= 0 {
|
||||
l.md.connQueueSize = defaultQueueSize
|
||||
}
|
||||
l.md.enableCompression = md.GetBool(enableCompression)
|
||||
l.md.readBufferSize = md.GetInt(readBufferSize)
|
||||
l.md.writeBufferSize = md.GetInt(writeBufferSize)
|
||||
l.md.handshakeTimeout = md.GetDuration(handshakeTimeout)
|
||||
l.md.readHeaderTimeout = md.GetDuration(readHeaderTimeout)
|
||||
|
||||
return
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user