add udp relay support for http handler
This commit is contained in:
parent
f3411832a8
commit
15f9aa091b
1
.gitignore
vendored
1
.gitignore
vendored
@ -9,6 +9,7 @@ _test
|
|||||||
release
|
release
|
||||||
debian
|
debian
|
||||||
bin
|
bin
|
||||||
|
.vscode
|
||||||
|
|
||||||
# Architecture specific extensions/prefixes
|
# Architecture specific extensions/prefixes
|
||||||
*.[568vq]
|
*.[568vq]
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-gost/gost/pkg/common/util/socks"
|
||||||
"github.com/go-gost/gost/pkg/connector"
|
"github.com/go-gost/gost/pkg/connector"
|
||||||
"github.com/go-gost/gost/pkg/logger"
|
"github.com/go-gost/gost/pkg/logger"
|
||||||
md "github.com/go-gost/gost/pkg/metadata"
|
md "github.com/go-gost/gost/pkg/metadata"
|
||||||
@ -50,19 +51,6 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add
|
|||||||
})
|
})
|
||||||
c.logger.Infof("connect %s/%s", address, network)
|
c.logger.Infof("connect %s/%s", address, network)
|
||||||
|
|
||||||
switch network {
|
|
||||||
case "tcp", "tcp4", "tcp6":
|
|
||||||
if _, ok := conn.(net.PacketConn); ok {
|
|
||||||
err := fmt.Errorf("tcp over udp is unsupported")
|
|
||||||
c.logger.Error(err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
err := fmt.Errorf("network %s is unsupported", network)
|
|
||||||
c.logger.Error(err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
req := &http.Request{
|
req := &http.Request{
|
||||||
Method: http.MethodConnect,
|
Method: http.MethodConnect,
|
||||||
URL: &url.URL{Host: address},
|
URL: &url.URL{Host: address},
|
||||||
@ -83,6 +71,21 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add
|
|||||||
"Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p)))
|
"Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
switch network {
|
||||||
|
case "tcp", "tcp4", "tcp6":
|
||||||
|
if _, ok := conn.(net.PacketConn); ok {
|
||||||
|
err := fmt.Errorf("tcp over udp is unsupported")
|
||||||
|
c.logger.Error(err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
case "udp", "udp4", "udp6":
|
||||||
|
req.Header.Set("X-Gost-Protocol", "udp")
|
||||||
|
default:
|
||||||
|
err := fmt.Errorf("network %s is unsupported", network)
|
||||||
|
c.logger.Error(err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
if c.logger.IsLevelEnabled(logger.DebugLevel) {
|
if c.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||||
dump, _ := httputil.DumpRequest(req, false)
|
dump, _ := httputil.DumpRequest(req, false)
|
||||||
c.logger.Debug(string(dump))
|
c.logger.Debug(string(dump))
|
||||||
@ -113,5 +116,10 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add
|
|||||||
return nil, fmt.Errorf("%s", resp.Status)
|
return nil, fmt.Errorf("%s", resp.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if network == "udp" {
|
||||||
|
addr, _ := net.ResolveUDPAddr(network, address)
|
||||||
|
return socks.UDPTunClientConn(conn, addr), nil
|
||||||
|
}
|
||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
@ -167,11 +167,5 @@ func (c *socks5Connector) connectUDP(ctx context.Context, conn net.Conn, network
|
|||||||
return nil, errors.New("get socks5 UDP tunnel failure")
|
return nil, errors.New("get socks5 UDP tunnel failure")
|
||||||
}
|
}
|
||||||
|
|
||||||
baddr, err := net.ResolveUDPAddr("udp", reply.Addr.String())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c.logger.Debugf("associate on %s OK", baddr)
|
|
||||||
|
|
||||||
return socks.UDPTunClientConn(conn, addr), nil
|
return socks.UDPTunClientConn(conn, addr), nil
|
||||||
}
|
}
|
||||||
|
@ -90,6 +90,11 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
|
|||||||
req.URL.Scheme = "http"
|
req.URL.Scheme = "http"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
network := req.Header.Get("X-Gost-Protocol")
|
||||||
|
if network != "udp" {
|
||||||
|
network = "tcp"
|
||||||
|
}
|
||||||
|
|
||||||
// Try to get the actual host.
|
// Try to get the actual host.
|
||||||
// Compatible with GOST 2.x.
|
// Compatible with GOST 2.x.
|
||||||
if v := req.Header.Get("Gost-Target"); v != "" {
|
if v := req.Header.Get("Gost-Target"); v != "" {
|
||||||
@ -168,6 +173,11 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if network == "udp" {
|
||||||
|
h.handleUDP(ctx, conn, network, req.Host)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if req.Method == "PRI" ||
|
if req.Method == "PRI" ||
|
||||||
(req.Method != http.MethodConnect && req.URL.Scheme != "http") {
|
(req.Method != http.MethodConnect && req.URL.Scheme != "http") {
|
||||||
resp.StatusCode = http.StatusBadRequest
|
resp.StatusCode = http.StatusBadRequest
|
||||||
@ -187,7 +197,7 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
|
|||||||
WithChain(h.chain).
|
WithChain(h.chain).
|
||||||
WithRetry(h.md.retryCount).
|
WithRetry(h.md.retryCount).
|
||||||
WithLogger(h.logger)
|
WithLogger(h.logger)
|
||||||
cc, err := r.Dial(ctx, "tcp", addr)
|
cc, err := r.Dial(ctx, network, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.StatusCode = http.StatusServiceUnavailable
|
resp.StatusCode = http.StatusServiceUnavailable
|
||||||
resp.Write(conn)
|
resp.Write(conn)
|
||||||
@ -209,13 +219,13 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
|
|||||||
h.logger.Debug(string(dump))
|
h.logger.Debug(string(dump))
|
||||||
}
|
}
|
||||||
if err = resp.Write(conn); err != nil {
|
if err = resp.Write(conn); err != nil {
|
||||||
h.logger.Warn(err)
|
h.logger.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
req.Header.Del("Proxy-Connection")
|
req.Header.Del("Proxy-Connection")
|
||||||
if err = req.Write(cc); err != nil {
|
if err = req.Write(cc); err != nil {
|
||||||
h.logger.Warn(err)
|
h.logger.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -13,6 +13,7 @@ type metadata struct {
|
|||||||
retryCount int
|
retryCount int
|
||||||
probeResist *probeResist
|
probeResist *probeResist
|
||||||
sni bool
|
sni bool
|
||||||
|
enableUDP bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *httpHandler) parseMetadata(md md.Metadata) error {
|
func (h *httpHandler) parseMetadata(md md.Metadata) error {
|
||||||
@ -23,6 +24,7 @@ func (h *httpHandler) parseMetadata(md md.Metadata) error {
|
|||||||
knock = "knock"
|
knock = "knock"
|
||||||
retryCount = "retry"
|
retryCount = "retry"
|
||||||
sni = "sni"
|
sni = "sni"
|
||||||
|
enableUDP = "udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
h.md.proxyAgent = md.GetString(proxyAgent)
|
h.md.proxyAgent = md.GetString(proxyAgent)
|
||||||
@ -53,6 +55,7 @@ func (h *httpHandler) parseMetadata(md md.Metadata) error {
|
|||||||
}
|
}
|
||||||
h.md.retryCount = md.GetInt(retryCount)
|
h.md.retryCount = md.GetInt(retryCount)
|
||||||
h.md.sni = md.GetBool(sni)
|
h.md.sni = md.GetBool(sni)
|
||||||
|
h.md.enableUDP = md.GetBool(enableUDP)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
83
pkg/handler/http/udp.go
Normal file
83
pkg/handler/http/udp.go
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
package http
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-gost/gost/pkg/chain"
|
||||||
|
"github.com/go-gost/gost/pkg/common/util/socks"
|
||||||
|
"github.com/go-gost/gost/pkg/handler"
|
||||||
|
"github.com/go-gost/gost/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, address string) {
|
||||||
|
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||||
|
"cmd": "udp",
|
||||||
|
})
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
ProtoMajor: 1,
|
||||||
|
ProtoMinor: 1,
|
||||||
|
Header: http.Header{},
|
||||||
|
}
|
||||||
|
if h.md.proxyAgent != "" {
|
||||||
|
resp.Header.Add("Proxy-Agent", h.md.proxyAgent)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h.md.enableUDP {
|
||||||
|
resp.StatusCode = http.StatusForbidden
|
||||||
|
resp.Write(conn)
|
||||||
|
|
||||||
|
if h.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||||
|
dump, _ := httputil.DumpResponse(resp, false)
|
||||||
|
h.logger.Debug(string(dump))
|
||||||
|
}
|
||||||
|
h.logger.Error("UDP relay is diabled")
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.StatusCode = http.StatusOK
|
||||||
|
if h.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||||
|
dump, _ := httputil.DumpResponse(resp, false)
|
||||||
|
h.logger.Debug(string(dump))
|
||||||
|
}
|
||||||
|
if err := resp.Write(conn); err != nil {
|
||||||
|
h.logger.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// obtain a udp connection
|
||||||
|
r := (&chain.Router{}).
|
||||||
|
WithChain(h.chain).
|
||||||
|
WithRetry(h.md.retryCount).
|
||||||
|
WithLogger(h.logger)
|
||||||
|
c, err := r.Dial(ctx, "udp", "") // UDP association
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
pc, ok := c.(net.PacketConn)
|
||||||
|
if !ok {
|
||||||
|
h.logger.Errorf("wrong connection type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
relay := handler.NewUDPRelay(socks.UDPTunServerConn(conn), pc).
|
||||||
|
WithBypass(h.bypass).
|
||||||
|
WithLogger(h.logger)
|
||||||
|
|
||||||
|
t := time.Now()
|
||||||
|
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||||
|
relay.Run()
|
||||||
|
h.logger.
|
||||||
|
WithFields(map[string]interface{}{
|
||||||
|
"duration": time.Since(t),
|
||||||
|
}).
|
||||||
|
Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||||
|
}
|
126
pkg/handler/relay.go
Normal file
126
pkg/handler/relay.go
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/go-gost/gost/pkg/bypass"
|
||||||
|
"github.com/go-gost/gost/pkg/common/bufpool"
|
||||||
|
"github.com/go-gost/gost/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UDPRelay struct {
|
||||||
|
pc1 net.PacketConn
|
||||||
|
pc2 net.PacketConn
|
||||||
|
|
||||||
|
bypass bypass.Bypass
|
||||||
|
bufferSize int
|
||||||
|
logger logger.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUDPRelay(pc1, pc2 net.PacketConn) *UDPRelay {
|
||||||
|
return &UDPRelay{
|
||||||
|
pc1: pc1,
|
||||||
|
pc2: pc2,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *UDPRelay) WithBypass(bp bypass.Bypass) *UDPRelay {
|
||||||
|
r.bypass = bp
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *UDPRelay) WithLogger(logger logger.Logger) *UDPRelay {
|
||||||
|
r.logger = logger
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *UDPRelay) SetBufferSize(n int) {
|
||||||
|
r.bufferSize = n
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *UDPRelay) Run() (err error) {
|
||||||
|
bufSize := r.bufferSize
|
||||||
|
if bufSize <= 0 {
|
||||||
|
bufSize = 1024
|
||||||
|
}
|
||||||
|
|
||||||
|
errc := make(chan error, 2)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
err := func() error {
|
||||||
|
b := bufpool.Get(bufSize)
|
||||||
|
defer bufpool.Put(b)
|
||||||
|
|
||||||
|
n, raddr, err := r.pc1.ReadFrom(b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.bypass != nil && r.bypass.Contains(raddr.String()) {
|
||||||
|
if r.logger != nil {
|
||||||
|
r.logger.Warn("bypass: ", raddr)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := r.pc2.WriteTo(b[:n], raddr); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.logger != nil {
|
||||||
|
r.logger.Debugf("%s >>> %s data: %d",
|
||||||
|
r.pc2.LocalAddr(), raddr, n)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
errc <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
err := func() error {
|
||||||
|
b := bufpool.Get(bufSize)
|
||||||
|
defer bufpool.Put(b)
|
||||||
|
|
||||||
|
n, raddr, err := r.pc2.ReadFrom(b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.bypass != nil && r.bypass.Contains(raddr.String()) {
|
||||||
|
if r.logger != nil {
|
||||||
|
r.logger.Warn("bypass: ", raddr)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := r.pc1.WriteTo(b[:n], raddr); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.logger != nil {
|
||||||
|
r.logger.Debugf("%s <<< %s data: %d",
|
||||||
|
r.pc2.LocalAddr(), raddr, n)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
errc <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return <-errc
|
||||||
|
}
|
@ -136,6 +136,5 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) {
|
|||||||
h.handleConnect(ctx, conn, network, address)
|
h.handleConnect(ctx, conn, network, address)
|
||||||
case relay.BIND:
|
case relay.BIND:
|
||||||
h.handleBind(ctx, conn, network, address)
|
h.handleBind(ctx, conn, network, address)
|
||||||
case relay.ASSOCIATE:
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-gost/gost/pkg/auth"
|
"github.com/go-gost/gost/pkg/auth"
|
||||||
util_tls "github.com/go-gost/gost/pkg/common/util/tls"
|
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
|
||||||
md "github.com/go-gost/gost/pkg/metadata"
|
md "github.com/go-gost/gost/pkg/metadata"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -23,7 +23,7 @@ type metadata struct {
|
|||||||
compatibilityMode bool
|
compatibilityMode bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *socks5Handler) parseMetadata(md md.Metadata) error {
|
func (h *socks5Handler) parseMetadata(md md.Metadata) (err error) {
|
||||||
const (
|
const (
|
||||||
certFile = "certFile"
|
certFile = "certFile"
|
||||||
keyFile = "keyFile"
|
keyFile = "keyFile"
|
||||||
@ -39,14 +39,19 @@ func (h *socks5Handler) parseMetadata(md md.Metadata) error {
|
|||||||
compatibilityMode = "comp"
|
compatibilityMode = "comp"
|
||||||
)
|
)
|
||||||
|
|
||||||
var err error
|
if md.GetString(certFile) != "" ||
|
||||||
h.md.tlsConfig, err = util_tls.LoadTLSConfig(
|
md.GetString(keyFile) != "" ||
|
||||||
md.GetString(certFile),
|
md.GetString(caFile) != "" {
|
||||||
md.GetString(keyFile),
|
h.md.tlsConfig, err = tls_util.LoadTLSConfig(
|
||||||
md.GetString(caFile),
|
md.GetString(certFile),
|
||||||
)
|
md.GetString(keyFile),
|
||||||
if err != nil {
|
md.GetString(caFile),
|
||||||
h.logger.Warn("parse tls config: ", err)
|
)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
h.md.tlsConfig = tls_util.DefaultConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
if v, _ := md.Get(users).([]interface{}); len(v) > 0 {
|
if v, _ := md.Get(users).([]interface{}); len(v) > 0 {
|
||||||
|
@ -9,8 +9,9 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-gost/gosocks5"
|
"github.com/go-gost/gosocks5"
|
||||||
"github.com/go-gost/gost/pkg/common/bufpool"
|
"github.com/go-gost/gost/pkg/chain"
|
||||||
"github.com/go-gost/gost/pkg/common/util/socks"
|
"github.com/go-gost/gost/pkg/common/util/socks"
|
||||||
|
"github.com/go-gost/gost/pkg/handler"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) {
|
func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) {
|
||||||
@ -26,7 +27,7 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
relay, err := net.ListenUDP("udp", nil)
|
cc, err := net.ListenUDP("udp", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error(err)
|
h.logger.Error(err)
|
||||||
reply := gosocks5.NewReply(gosocks5.Failure, nil)
|
reply := gosocks5.NewReply(gosocks5.Failure, nil)
|
||||||
@ -34,10 +35,10 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) {
|
|||||||
h.logger.Debug(reply)
|
h.logger.Debug(reply)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer relay.Close()
|
defer cc.Close()
|
||||||
|
|
||||||
saddr := gosocks5.Addr{}
|
saddr := gosocks5.Addr{}
|
||||||
saddr.ParseFrom(relay.LocalAddr().String())
|
saddr.ParseFrom(cc.LocalAddr().String())
|
||||||
saddr.Type = 0
|
saddr.Type = 0
|
||||||
saddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) // replace the IP to the out-going interface's
|
saddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) // replace the IP to the out-going interface's
|
||||||
reply := gosocks5.NewReply(gosocks5.Succeeded, &saddr)
|
reply := gosocks5.NewReply(gosocks5.Succeeded, &saddr)
|
||||||
@ -48,99 +49,39 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) {
|
|||||||
h.logger.Debug(reply)
|
h.logger.Debug(reply)
|
||||||
|
|
||||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||||
"bind": fmt.Sprintf("%s/%s", relay.LocalAddr(), relay.LocalAddr().Network()),
|
"bind": fmt.Sprintf("%s/%s", cc.LocalAddr(), cc.LocalAddr().Network()),
|
||||||
})
|
})
|
||||||
h.logger.Debugf("bind on %s OK", relay.LocalAddr())
|
h.logger.Debugf("bind on %s OK", cc.LocalAddr())
|
||||||
|
|
||||||
peer, err := net.ListenUDP("udp", nil)
|
// obtain a udp connection
|
||||||
|
r := (&chain.Router{}).
|
||||||
|
WithChain(h.chain).
|
||||||
|
WithRetry(h.md.retryCount).
|
||||||
|
WithLogger(h.logger)
|
||||||
|
c, err := r.Dial(ctx, "udp", "") // UDP association
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error(err)
|
h.logger.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer peer.Close()
|
defer c.Close()
|
||||||
|
|
||||||
go h.relayUDP(
|
pc, ok := c.(net.PacketConn)
|
||||||
socks.UDPConn(relay, h.md.udpBufferSize),
|
if !ok {
|
||||||
peer,
|
h.logger.Errorf("wrong connection type")
|
||||||
)
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
relay := handler.NewUDPRelay(socks.UDPConn(cc, h.md.udpBufferSize), pc).
|
||||||
|
WithBypass(h.bypass).
|
||||||
|
WithLogger(h.logger)
|
||||||
|
relay.SetBufferSize(h.md.udpBufferSize)
|
||||||
|
|
||||||
|
go relay.Run()
|
||||||
|
|
||||||
t := time.Now()
|
t := time.Now()
|
||||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), relay.LocalAddr())
|
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr())
|
||||||
io.Copy(ioutil.Discard, conn)
|
io.Copy(ioutil.Discard, conn)
|
||||||
h.logger.
|
h.logger.
|
||||||
WithFields(map[string]interface{}{"duration": time.Since(t)}).
|
WithFields(map[string]interface{}{"duration": time.Since(t)}).
|
||||||
Infof("%s >-< %s", conn.RemoteAddr(), relay.LocalAddr())
|
Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr())
|
||||||
}
|
|
||||||
|
|
||||||
func (h *socks5Handler) relayUDP(c, peer net.PacketConn) (err error) {
|
|
||||||
bufSize := h.md.udpBufferSize
|
|
||||||
errc := make(chan error, 2)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
err := func() error {
|
|
||||||
b := bufpool.Get(bufSize)
|
|
||||||
defer bufpool.Put(b)
|
|
||||||
|
|
||||||
n, raddr, err := c.ReadFrom(b)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
|
|
||||||
h.logger.Warn("bypass: ", raddr)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := peer.WriteTo(b[:n], raddr); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
h.logger.Debugf("%s >>> %s data: %d",
|
|
||||||
peer.LocalAddr(), raddr, n)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
errc <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
err := func() error {
|
|
||||||
b := bufpool.Get(bufSize)
|
|
||||||
defer bufpool.Put(b)
|
|
||||||
|
|
||||||
n, raddr, err := peer.ReadFrom(b)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
|
|
||||||
h.logger.Warn("bypass: ", raddr)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := c.WriteTo(b[:n], raddr); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
h.logger.Debugf("%s <<< %s data: %d",
|
|
||||||
peer.LocalAddr(), raddr, n)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
errc <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return <-errc
|
|
||||||
}
|
}
|
||||||
|
@ -2,13 +2,13 @@ package v5
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-gost/gosocks5"
|
"github.com/go-gost/gosocks5"
|
||||||
"github.com/go-gost/gost/pkg/common/bufpool"
|
"github.com/go-gost/gost/pkg/chain"
|
||||||
"github.com/go-gost/gost/pkg/common/util/socks"
|
"github.com/go-gost/gost/pkg/common/util/socks"
|
||||||
|
"github.com/go-gost/gost/pkg/handler"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network, address string) {
|
func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network, address string) {
|
||||||
@ -24,111 +24,43 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
bindAddr, _ := net.ResolveUDPAddr(network, address)
|
// dummy bind
|
||||||
pc, err := net.ListenUDP(network, bindAddr)
|
reply := gosocks5.NewReply(gosocks5.Succeeded, nil)
|
||||||
if err != nil {
|
|
||||||
h.logger.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer pc.Close()
|
|
||||||
|
|
||||||
saddr, _ := gosocks5.NewAddr(pc.LocalAddr().String())
|
|
||||||
saddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String())
|
|
||||||
saddr.Type = 0
|
|
||||||
reply := gosocks5.NewReply(gosocks5.Succeeded, saddr)
|
|
||||||
if err := reply.Write(conn); err != nil {
|
if err := reply.Write(conn); err != nil {
|
||||||
h.logger.Error(err)
|
h.logger.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.logger.Debug(reply)
|
h.logger.Debug(reply)
|
||||||
|
|
||||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
// obtain a udp connection
|
||||||
"bind": fmt.Sprintf("%s/%s", pc.LocalAddr(), pc.LocalAddr().Network()),
|
r := (&chain.Router{}).
|
||||||
})
|
WithChain(h.chain).
|
||||||
|
WithRetry(h.md.retryCount).
|
||||||
|
WithLogger(h.logger)
|
||||||
|
c, err := r.Dial(ctx, "udp", "") // UDP association
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
h.logger.Debugf("bind on %s OK", pc.LocalAddr())
|
pc, ok := c.(net.PacketConn)
|
||||||
|
if !ok {
|
||||||
|
h.logger.Errorf("wrong connection type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
relay := handler.NewUDPRelay(socks.UDPTunServerConn(conn), pc).
|
||||||
|
WithBypass(h.bypass).
|
||||||
|
WithLogger(h.logger)
|
||||||
|
relay.SetBufferSize(h.md.udpBufferSize)
|
||||||
|
|
||||||
t := time.Now()
|
t := time.Now()
|
||||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
|
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||||
h.tunnelServerUDP(
|
relay.Run()
|
||||||
socks.UDPTunServerConn(conn),
|
|
||||||
pc,
|
|
||||||
)
|
|
||||||
h.logger.
|
h.logger.
|
||||||
WithFields(map[string]interface{}{
|
WithFields(map[string]interface{}{
|
||||||
"duration": time.Since(t),
|
"duration": time.Since(t),
|
||||||
}).
|
}).
|
||||||
Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
|
Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *socks5Handler) tunnelServerUDP(tunnel, c net.PacketConn) (err error) {
|
|
||||||
bufSize := h.md.udpBufferSize
|
|
||||||
errc := make(chan error, 2)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
err := func() error {
|
|
||||||
b := bufpool.Get(bufSize)
|
|
||||||
defer bufpool.Put(b)
|
|
||||||
|
|
||||||
n, raddr, err := tunnel.ReadFrom(b)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
|
|
||||||
h.logger.Warn("bypass: ", raddr)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := c.WriteTo(b[:n], raddr); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
h.logger.Debugf("%s >>> %s data: %d",
|
|
||||||
c.LocalAddr(), raddr, n)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
errc <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
err := func() error {
|
|
||||||
b := bufpool.Get(bufSize)
|
|
||||||
defer bufpool.Put(b)
|
|
||||||
|
|
||||||
n, raddr, err := c.ReadFrom(b)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
|
|
||||||
h.logger.Warn("bypass: ", raddr)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := tunnel.WriteTo(b[:n], raddr); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
h.logger.Debugf("%s <<< %s data: %d",
|
|
||||||
c.LocalAddr(), raddr, n)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
errc <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return <-errc
|
|
||||||
}
|
|
||||||
|
@ -5,7 +5,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
|
|
||||||
ws_util "github.com/go-gost/gost/pkg/common/util/ws"
|
ws_util "github.com/go-gost/gost/pkg/common/util/ws"
|
||||||
"github.com/go-gost/gost/pkg/listener"
|
"github.com/go-gost/gost/pkg/listener"
|
||||||
"github.com/go-gost/gost/pkg/logger"
|
"github.com/go-gost/gost/pkg/logger"
|
||||||
@ -16,18 +15,19 @@ import (
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
registry.RegisterListener("ws", NewListener)
|
registry.RegisterListener("ws", NewListener)
|
||||||
registry.RegisterListener("wss", NewListener)
|
registry.RegisterListener("wss", NewTLSListener)
|
||||||
}
|
}
|
||||||
|
|
||||||
type wsListener struct {
|
type wsListener struct {
|
||||||
saddr string
|
saddr string
|
||||||
md metadata
|
md metadata
|
||||||
addr net.Addr
|
addr net.Addr
|
||||||
upgrader *websocket.Upgrader
|
upgrader *websocket.Upgrader
|
||||||
srv *http.Server
|
srv *http.Server
|
||||||
connChan chan net.Conn
|
tlsEnabled bool
|
||||||
errChan chan error
|
connChan chan net.Conn
|
||||||
logger logger.Logger
|
errChan chan error
|
||||||
|
logger logger.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewListener(opts ...listener.Option) listener.Listener {
|
func NewListener(opts ...listener.Option) listener.Listener {
|
||||||
@ -41,6 +41,18 @@ func NewListener(opts ...listener.Option) listener.Listener {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewTLSListener(opts ...listener.Option) listener.Listener {
|
||||||
|
options := &listener.Options{}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(options)
|
||||||
|
}
|
||||||
|
return &wsListener{
|
||||||
|
saddr: options.Addr,
|
||||||
|
tlsEnabled: true,
|
||||||
|
logger: options.Logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (l *wsListener) Init(md md.Metadata) (err error) {
|
func (l *wsListener) Init(md md.Metadata) (err error) {
|
||||||
if err = l.parseMetadata(md); err != nil {
|
if err = l.parseMetadata(md); err != nil {
|
||||||
return
|
return
|
||||||
@ -115,19 +127,6 @@ func (l *wsListener) Addr() net.Addr {
|
|||||||
return l.addr
|
return l.addr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *wsListener) parseMetadata(md md.Metadata) (err error) {
|
|
||||||
l.md.tlsConfig, err = tls_util.LoadTLSConfig(
|
|
||||||
md.GetString(certFile),
|
|
||||||
md.GetString(keyFile),
|
|
||||||
md.GetString(caFile),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *wsListener) upgrade(w http.ResponseWriter, r *http.Request) {
|
func (l *wsListener) upgrade(w http.ResponseWriter, r *http.Request) {
|
||||||
conn, err := l.upgrader.Upgrade(w, r, l.md.responseHeader)
|
conn, err := l.upgrader.Upgrade(w, r, l.md.responseHeader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -4,20 +4,9 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
|
||||||
path = "path"
|
md "github.com/go-gost/gost/pkg/metadata"
|
||||||
certFile = "certFile"
|
|
||||||
keyFile = "keyFile"
|
|
||||||
caFile = "caFile"
|
|
||||||
handshakeTimeout = "handshakeTimeout"
|
|
||||||
readHeaderTimeout = "readHeaderTimeout"
|
|
||||||
readBufferSize = "readBufferSize"
|
|
||||||
writeBufferSize = "writeBufferSize"
|
|
||||||
enableCompression = "enableCompression"
|
|
||||||
responseHeader = "responseHeader"
|
|
||||||
connQueueSize = "connQueueSize"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -36,3 +25,49 @@ type metadata struct {
|
|||||||
responseHeader http.Header
|
responseHeader http.Header
|
||||||
connQueueSize int
|
connQueueSize int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *wsListener) parseMetadata(md md.Metadata) (err error) {
|
||||||
|
const (
|
||||||
|
path = "path"
|
||||||
|
certFile = "certFile"
|
||||||
|
keyFile = "keyFile"
|
||||||
|
caFile = "caFile"
|
||||||
|
handshakeTimeout = "handshakeTimeout"
|
||||||
|
readHeaderTimeout = "readHeaderTimeout"
|
||||||
|
readBufferSize = "readBufferSize"
|
||||||
|
writeBufferSize = "writeBufferSize"
|
||||||
|
enableCompression = "enableCompression"
|
||||||
|
responseHeader = "responseHeader"
|
||||||
|
connQueueSize = "connQueueSize"
|
||||||
|
)
|
||||||
|
|
||||||
|
if l.tlsEnabled {
|
||||||
|
if md.GetString(certFile) != "" ||
|
||||||
|
md.GetString(keyFile) != "" ||
|
||||||
|
md.GetString(caFile) != "" {
|
||||||
|
l.md.tlsConfig, err = tls_util.LoadTLSConfig(
|
||||||
|
md.GetString(certFile),
|
||||||
|
md.GetString(keyFile),
|
||||||
|
md.GetString(caFile),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
l.md.tlsConfig = tls_util.DefaultConfig
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
l.md.path = md.GetString(path)
|
||||||
|
l.md.connQueueSize = md.GetInt(connQueueSize)
|
||||||
|
if l.md.connQueueSize <= 0 {
|
||||||
|
l.md.connQueueSize = defaultQueueSize
|
||||||
|
}
|
||||||
|
l.md.enableCompression = md.GetBool(enableCompression)
|
||||||
|
l.md.readBufferSize = md.GetInt(readBufferSize)
|
||||||
|
l.md.writeBufferSize = md.GetInt(writeBufferSize)
|
||||||
|
l.md.handshakeTimeout = md.GetDuration(handshakeTimeout)
|
||||||
|
l.md.readHeaderTimeout = md.GetDuration(readHeaderTimeout)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user