diff --git a/cmd/gost/norm.go b/cmd/gost/norm.go index 3291ed3..0e16f12 100644 --- a/cmd/gost/norm.go +++ b/cmd/gost/norm.go @@ -64,7 +64,8 @@ func normService(svc *config.ServiceConfig) { } if handler != "relay" { if listener == "tcp" || listener == "udp" || - listener == "rtcp" || listener == "rudp" { + listener == "rtcp" || listener == "rudp" || + listener == "tun" || listener == "tap" { handler = listener } else { handler = "tcp" diff --git a/cmd/gost/register.go b/cmd/gost/register.go index 8beeb69..d0e54c1 100644 --- a/cmd/gost/register.go +++ b/cmd/gost/register.go @@ -44,6 +44,8 @@ import ( _ "github.com/go-gost/gost/pkg/handler/socks/v5" _ "github.com/go-gost/gost/pkg/handler/ss" _ "github.com/go-gost/gost/pkg/handler/ss/udp" + _ "github.com/go-gost/gost/pkg/handler/tap" + _ "github.com/go-gost/gost/pkg/handler/tun" // Register listeners _ "github.com/go-gost/gost/pkg/listener/ftcp" @@ -57,6 +59,7 @@ import ( _ "github.com/go-gost/gost/pkg/listener/rtcp" _ "github.com/go-gost/gost/pkg/listener/rudp" _ "github.com/go-gost/gost/pkg/listener/ssh" + _ "github.com/go-gost/gost/pkg/listener/tap" _ "github.com/go-gost/gost/pkg/listener/tcp" _ "github.com/go-gost/gost/pkg/listener/tls" _ "github.com/go-gost/gost/pkg/listener/tls/mux" diff --git a/pkg/connector/http/metadata.go b/pkg/connector/http/metadata.go index d3b3aea..507c4f4 100644 --- a/pkg/connector/http/metadata.go +++ b/pkg/connector/http/metadata.go @@ -22,9 +22,9 @@ func (c *httpConnector) parseMetadata(md mdata.Metadata) (err error) { header = "header" ) - c.md.connectTimeout = md.GetDuration(connectTimeout) + c.md.connectTimeout = mdata.GetDuration(md, connectTimeout) - if v := md.GetString(user); v != "" { + if v := mdata.GetString(md, user); v != "" { ss := strings.SplitN(v, ":", 2) if len(ss) == 1 { c.md.User = url.User(ss[0]) diff --git a/pkg/connector/http2/metadata.go b/pkg/connector/http2/metadata.go index c675eee..494c233 100644 --- a/pkg/connector/http2/metadata.go +++ b/pkg/connector/http2/metadata.go @@ -5,7 +5,7 @@ import ( "strings" "time" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) const ( @@ -18,20 +18,20 @@ type metadata struct { User *url.Userinfo } -func (c *http2Connector) parseMetadata(md md.Metadata) (err error) { +func (c *http2Connector) parseMetadata(md mdata.Metadata) (err error) { const ( connectTimeout = "timeout" userAgent = "userAgent" user = "user" ) - c.md.connectTimeout = md.GetDuration(connectTimeout) - c.md.UserAgent, _ = md.Get(userAgent).(string) + c.md.connectTimeout = mdata.GetDuration(md, connectTimeout) + c.md.UserAgent = mdata.GetString(md, userAgent) if c.md.UserAgent == "" { c.md.UserAgent = defaultUserAgent } - if v := md.GetString(user); v != "" { + if v := mdata.GetString(md, user); v != "" { ss := strings.SplitN(v, ":", 2) if len(ss) == 1 { c.md.User = url.User(ss[0]) diff --git a/pkg/connector/relay/metadata.go b/pkg/connector/relay/metadata.go index 76e64d9..fb66030 100644 --- a/pkg/connector/relay/metadata.go +++ b/pkg/connector/relay/metadata.go @@ -5,7 +5,7 @@ import ( "strings" "time" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { @@ -14,14 +14,14 @@ type metadata struct { noDelay bool } -func (c *relayConnector) parseMetadata(md md.Metadata) (err error) { +func (c *relayConnector) parseMetadata(md mdata.Metadata) (err error) { const ( user = "user" connectTimeout = "connectTimeout" noDelay = "nodelay" ) - if v := md.GetString(user); v != "" { + if v := mdata.GetString(md, user); v != "" { ss := strings.SplitN(v, ":", 2) if len(ss) == 1 { c.md.user = url.User(ss[0]) @@ -29,8 +29,8 @@ func (c *relayConnector) parseMetadata(md md.Metadata) (err error) { c.md.user = url.UserPassword(ss[0], ss[1]) } } - c.md.connectTimeout = md.GetDuration(connectTimeout) - c.md.noDelay = md.GetBool(noDelay) + c.md.connectTimeout = mdata.GetDuration(md, connectTimeout) + c.md.noDelay = mdata.GetBool(md, noDelay) return } diff --git a/pkg/connector/sni/metadata.go b/pkg/connector/sni/metadata.go index f383093..1c8c916 100644 --- a/pkg/connector/sni/metadata.go +++ b/pkg/connector/sni/metadata.go @@ -3,7 +3,7 @@ package sni import ( "time" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { @@ -11,14 +11,14 @@ type metadata struct { connectTimeout time.Duration } -func (c *sniConnector) parseMetadata(md md.Metadata) (err error) { +func (c *sniConnector) parseMetadata(md mdata.Metadata) (err error) { const ( host = "host" connectTimeout = "timeout" ) - c.md.host = md.GetString(host) - c.md.connectTimeout = md.GetDuration(connectTimeout) + c.md.host = mdata.GetString(md, host) + c.md.connectTimeout = mdata.GetDuration(md, connectTimeout) return } diff --git a/pkg/connector/socks/v4/metadata.go b/pkg/connector/socks/v4/metadata.go index 08f8184..54fb242 100644 --- a/pkg/connector/socks/v4/metadata.go +++ b/pkg/connector/socks/v4/metadata.go @@ -4,7 +4,7 @@ import ( "net/url" "time" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { @@ -13,18 +13,18 @@ type metadata struct { disable4a bool } -func (c *socks4Connector) parseMetadata(md md.Metadata) (err error) { +func (c *socks4Connector) parseMetadata(md mdata.Metadata) (err error) { const ( connectTimeout = "timeout" user = "user" disable4a = "disable4a" ) - if v := md.GetString(user); v != "" { + if v := mdata.GetString(md, user); v != "" { c.md.User = url.User(v) } - c.md.connectTimeout = md.GetDuration(connectTimeout) - c.md.disable4a = md.GetBool(disable4a) + c.md.connectTimeout = mdata.GetDuration(md, connectTimeout) + c.md.disable4a = mdata.GetBool(md, disable4a) return } diff --git a/pkg/connector/socks/v5/metadata.go b/pkg/connector/socks/v5/metadata.go index a3ae9cc..259137a 100644 --- a/pkg/connector/socks/v5/metadata.go +++ b/pkg/connector/socks/v5/metadata.go @@ -6,7 +6,7 @@ import ( "strings" "time" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { @@ -16,14 +16,14 @@ type metadata struct { noTLS bool } -func (c *socks5Connector) parseMetadata(md md.Metadata) (err error) { +func (c *socks5Connector) parseMetadata(md mdata.Metadata) (err error) { const ( connectTimeout = "timeout" user = "user" noTLS = "notls" ) - if v := md.GetString(user); v != "" { + if v := mdata.GetString(md, user); v != "" { ss := strings.SplitN(v, ":", 2) if len(ss) == 1 { c.md.User = url.User(ss[0]) @@ -32,8 +32,8 @@ func (c *socks5Connector) parseMetadata(md md.Metadata) (err error) { } } - c.md.connectTimeout = md.GetDuration(connectTimeout) - c.md.noTLS = md.GetBool(noTLS) + c.md.connectTimeout = mdata.GetDuration(md, connectTimeout) + c.md.noTLS = mdata.GetBool(md, noTLS) return } diff --git a/pkg/connector/ss/metadata.go b/pkg/connector/ss/metadata.go index 4e8575e..052f544 100644 --- a/pkg/connector/ss/metadata.go +++ b/pkg/connector/ss/metadata.go @@ -5,7 +5,7 @@ import ( "time" "github.com/go-gost/gost/pkg/common/util/ss" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" "github.com/shadowsocks/go-shadowsocks2/core" ) @@ -15,7 +15,7 @@ type metadata struct { noDelay bool } -func (c *ssConnector) parseMetadata(md md.Metadata) (err error) { +func (c *ssConnector) parseMetadata(md mdata.Metadata) (err error) { const ( user = "user" key = "key" @@ -24,7 +24,7 @@ func (c *ssConnector) parseMetadata(md md.Metadata) (err error) { ) var method, password string - if v := md.GetString(user); v != "" { + if v := mdata.GetString(md, user); v != "" { ss := strings.SplitN(v, ":", 2) if len(ss) == 1 { method = ss[0] @@ -32,13 +32,13 @@ func (c *ssConnector) parseMetadata(md md.Metadata) (err error) { method, password = ss[0], ss[1] } } - c.md.cipher, err = ss.ShadowCipher(method, password, md.GetString(key)) + c.md.cipher, err = ss.ShadowCipher(method, password, mdata.GetString(md, key)) if err != nil { return } - c.md.connectTimeout = md.GetDuration(connectTimeout) - c.md.noDelay = md.GetBool(noDelay) + c.md.connectTimeout = mdata.GetDuration(md, connectTimeout) + c.md.noDelay = mdata.GetBool(md, noDelay) return } diff --git a/pkg/connector/ss/udp/connector.go b/pkg/connector/ss/udp/connector.go index cadc3f2..5d72ecc 100644 --- a/pkg/connector/ss/udp/connector.go +++ b/pkg/connector/ss/udp/connector.go @@ -72,7 +72,7 @@ func (c *ssuConnector) Connect(ctx context.Context, conn net.Conn, network, addr } // standard UDP relay - return ss.UDPClientConn(pc, conn.RemoteAddr(), taddr, c.md.udpBufferSize), nil + return ss.UDPClientConn(pc, conn.RemoteAddr(), taddr, c.md.bufferSize), nil } if c.md.cipher != nil { diff --git a/pkg/connector/ss/udp/metadata.go b/pkg/connector/ss/udp/metadata.go index 16e70e5..7291552 100644 --- a/pkg/connector/ss/udp/metadata.go +++ b/pkg/connector/ss/udp/metadata.go @@ -1,30 +1,31 @@ package ss import ( + "math" "strings" "time" "github.com/go-gost/gost/pkg/common/util/ss" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" "github.com/shadowsocks/go-shadowsocks2/core" ) type metadata struct { cipher core.Cipher connectTimeout time.Duration - udpBufferSize int + bufferSize int } -func (c *ssuConnector) parseMetadata(md md.Metadata) (err error) { +func (c *ssuConnector) parseMetadata(md mdata.Metadata) (err error) { const ( user = "user" key = "key" connectTimeout = "timeout" - udpBufferSize = "udpBufferSize" // udp buffer size + bufferSize = "bufferSize" // udp buffer size ) var method, password string - if v := md.GetString(user); v != "" { + if v := mdata.GetString(md, user); v != "" { ss := strings.SplitN(v, ":", 2) if len(ss) == 1 { method = ss[0] @@ -32,22 +33,17 @@ func (c *ssuConnector) parseMetadata(md md.Metadata) (err error) { method, password = ss[0], ss[1] } } - c.md.cipher, err = ss.ShadowCipher(method, password, md.GetString(key)) + c.md.cipher, err = ss.ShadowCipher(method, password, mdata.GetString(md, key)) if err != nil { return } - c.md.connectTimeout = md.GetDuration(connectTimeout) + c.md.connectTimeout = mdata.GetDuration(md, connectTimeout) - if c.md.udpBufferSize > 0 { - if c.md.udpBufferSize < 512 { - c.md.udpBufferSize = 512 - } - if c.md.udpBufferSize > 65*1024 { - c.md.udpBufferSize = 65 * 1024 - } + if bs := mdata.GetInt(md, bufferSize); bs > 0 { + c.md.bufferSize = int(math.Min(math.Max(float64(bs), 512), 64*1024)) } else { - c.md.udpBufferSize = 4096 + c.md.bufferSize = 1024 } return diff --git a/pkg/dialer/forward/ssh/metadata.go b/pkg/dialer/forward/ssh/metadata.go index d2fcd03..c589f8a 100644 --- a/pkg/dialer/forward/ssh/metadata.go +++ b/pkg/dialer/forward/ssh/metadata.go @@ -6,7 +6,7 @@ import ( "strings" "time" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" "golang.org/x/crypto/ssh" ) @@ -16,7 +16,7 @@ type metadata struct { signer ssh.Signer } -func (d *forwardDialer) parseMetadata(md md.Metadata) (err error) { +func (d *forwardDialer) parseMetadata(md mdata.Metadata) (err error) { const ( handshakeTimeout = "handshakeTimeout" user = "user" @@ -24,7 +24,7 @@ func (d *forwardDialer) parseMetadata(md md.Metadata) (err error) { passphrase = "passphrase" ) - if v := md.GetString(user); v != "" { + if v := mdata.GetString(md, user); v != "" { ss := strings.SplitN(v, ":", 2) if len(ss) == 1 { d.md.user = url.User(ss[0]) @@ -33,13 +33,13 @@ func (d *forwardDialer) parseMetadata(md md.Metadata) (err error) { } } - if key := md.GetString(privateKeyFile); key != "" { + if key := mdata.GetString(md, privateKeyFile); key != "" { data, err := ioutil.ReadFile(key) if err != nil { return err } - pp := md.GetString(passphrase) + pp := mdata.GetString(md, passphrase) if pp == "" { d.md.signer, err = ssh.ParsePrivateKey(data) } else { @@ -50,7 +50,7 @@ func (d *forwardDialer) parseMetadata(md md.Metadata) (err error) { } } - d.md.handshakeTimeout = md.GetDuration(handshakeTimeout) + d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) return } diff --git a/pkg/dialer/http2/h2/metadata.go b/pkg/dialer/http2/h2/metadata.go index bb36c71..9731c4b 100644 --- a/pkg/dialer/http2/h2/metadata.go +++ b/pkg/dialer/http2/h2/metadata.go @@ -5,7 +5,7 @@ import ( "net" tls_util "github.com/go-gost/gost/pkg/common/util/tls" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { @@ -14,7 +14,7 @@ type metadata struct { tlsConfig *tls.Config } -func (d *h2Dialer) parseMetadata(md md.Metadata) (err error) { +func (d *h2Dialer) parseMetadata(md mdata.Metadata) (err error) { const ( certFile = "certFile" keyFile = "keyFile" @@ -24,20 +24,20 @@ func (d *h2Dialer) parseMetadata(md md.Metadata) (err error) { path = "path" ) - d.md.host = md.GetString(serverName) + d.md.host = mdata.GetString(md, serverName) sn, _, _ := net.SplitHostPort(d.md.host) if sn == "" { sn = "localhost" } d.md.tlsConfig, err = tls_util.LoadClientConfig( - md.GetString(certFile), - md.GetString(keyFile), - md.GetString(caFile), - md.GetBool(secure), + mdata.GetString(md, certFile), + mdata.GetString(md, keyFile), + mdata.GetString(md, caFile), + mdata.GetBool(md, secure), sn, ) - d.md.path = md.GetString(path) + d.md.path = mdata.GetString(md, path) return } diff --git a/pkg/dialer/http2/metadata.go b/pkg/dialer/http2/metadata.go index 4befe1d..a26dfb0 100644 --- a/pkg/dialer/http2/metadata.go +++ b/pkg/dialer/http2/metadata.go @@ -5,14 +5,14 @@ import ( "net" tls_util "github.com/go-gost/gost/pkg/common/util/tls" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { tlsConfig *tls.Config } -func (d *http2Dialer) parseMetadata(md md.Metadata) (err error) { +func (d *http2Dialer) parseMetadata(md mdata.Metadata) (err error) { const ( certFile = "certFile" keyFile = "keyFile" @@ -21,15 +21,15 @@ func (d *http2Dialer) parseMetadata(md md.Metadata) (err error) { serverName = "serverName" ) - sn, _, _ := net.SplitHostPort(md.GetString(serverName)) + sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName)) if sn == "" { sn = "localhost" } d.md.tlsConfig, err = tls_util.LoadClientConfig( - md.GetString(certFile), - md.GetString(keyFile), - md.GetString(caFile), - md.GetBool(secure), + mdata.GetString(md, certFile), + mdata.GetString(md, keyFile), + mdata.GetString(md, caFile), + mdata.GetBool(md, secure), sn, ) diff --git a/pkg/dialer/kcp/metadata.go b/pkg/dialer/kcp/metadata.go index 85e3d26..d0cefc9 100644 --- a/pkg/dialer/kcp/metadata.go +++ b/pkg/dialer/kcp/metadata.go @@ -5,7 +5,7 @@ import ( "time" kcp_util "github.com/go-gost/gost/pkg/common/util/kcp" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { @@ -13,19 +13,13 @@ type metadata struct { config *kcp_util.Config } -func (d *kcpDialer) parseMetadata(md md.Metadata) (err error) { +func (d *kcpDialer) parseMetadata(md mdata.Metadata) (err error) { const ( config = "config" handshakeTimeout = "handshakeTimeout" ) - if mm, _ := md.Get(config).(map[interface{}]interface{}); len(mm) > 0 { - m := make(map[string]interface{}) - for k, v := range mm { - if sk, ok := k.(string); ok { - m[sk] = v - } - } + if m := mdata.GetStringMap(md, config); len(m) > 0 { b, err := json.Marshal(m) if err != nil { return err @@ -36,11 +30,10 @@ func (d *kcpDialer) parseMetadata(md md.Metadata) (err error) { } d.md.config = cfg } - if d.md.config == nil { d.md.config = kcp_util.DefaultConfig } - d.md.handshakeTimeout = md.GetDuration(handshakeTimeout) + d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) return } diff --git a/pkg/dialer/obfs/http/metadata.go b/pkg/dialer/obfs/http/metadata.go index 1ef8e98..7993f54 100644 --- a/pkg/dialer/obfs/http/metadata.go +++ b/pkg/dialer/obfs/http/metadata.go @@ -1,10 +1,9 @@ package http import ( - "fmt" "net/http" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { @@ -12,19 +11,19 @@ type metadata struct { header http.Header } -func (d *obfsHTTPDialer) parseMetadata(md md.Metadata) (err error) { +func (d *obfsHTTPDialer) parseMetadata(md mdata.Metadata) (err error) { const ( header = "header" host = "host" ) - if mm, _ := md.Get(header).(map[interface{}]interface{}); len(mm) > 0 { + if m := mdata.GetStringMapString(md, header); len(m) > 0 { h := http.Header{} - for k, v := range mm { - h.Add(fmt.Sprintf("%v", k), fmt.Sprintf("%v", v)) + for k, v := range m { + h.Add(k, v) } d.md.header = h } - d.md.host = md.GetString(host) + d.md.host = mdata.GetString(md, host) return } diff --git a/pkg/dialer/obfs/tls/metadata.go b/pkg/dialer/obfs/tls/metadata.go index f387a20..23204b2 100644 --- a/pkg/dialer/obfs/tls/metadata.go +++ b/pkg/dialer/obfs/tls/metadata.go @@ -1,18 +1,18 @@ package tls import ( - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { host string } -func (d *obfsTLSDialer) parseMetadata(md md.Metadata) (err error) { +func (d *obfsTLSDialer) parseMetadata(md mdata.Metadata) (err error) { const ( host = "host" ) - d.md.host = md.GetString(host) + d.md.host = mdata.GetString(md, host) return } diff --git a/pkg/dialer/quic/metadata.go b/pkg/dialer/quic/metadata.go index 18f1bfa..141754b 100644 --- a/pkg/dialer/quic/metadata.go +++ b/pkg/dialer/quic/metadata.go @@ -6,7 +6,7 @@ import ( "time" tls_util "github.com/go-gost/gost/pkg/common/util/tls" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { @@ -18,7 +18,7 @@ type metadata struct { tlsConfig *tls.Config } -func (d *quicDialer) parseMetadata(md md.Metadata) (err error) { +func (d *quicDialer) parseMetadata(md mdata.Metadata) (err error) { const ( keepAlive = "keepAlive" handshakeTimeout = "handshakeTimeout" @@ -33,26 +33,26 @@ func (d *quicDialer) parseMetadata(md md.Metadata) (err error) { cipherKey = "cipherKey" ) - d.md.handshakeTimeout = md.GetDuration(handshakeTimeout) + d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) - if key := md.GetString(cipherKey); key != "" { + if key := mdata.GetString(md, cipherKey); key != "" { d.md.cipherKey = []byte(key) } - sn, _, _ := net.SplitHostPort(md.GetString(serverName)) + sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName)) if sn == "" { sn = "localhost" } d.md.tlsConfig, err = tls_util.LoadClientConfig( - md.GetString(certFile), - md.GetString(keyFile), - md.GetString(caFile), - md.GetBool(secure), + mdata.GetString(md, certFile), + mdata.GetString(md, keyFile), + mdata.GetString(md, caFile), + mdata.GetBool(md, secure), sn, ) - d.md.keepAlive = md.GetBool(keepAlive) - d.md.handshakeTimeout = md.GetDuration(handshakeTimeout) - d.md.maxIdleTimeout = md.GetDuration(maxIdleTimeout) + d.md.keepAlive = mdata.GetBool(md, keepAlive) + d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) + d.md.maxIdleTimeout = mdata.GetDuration(md, maxIdleTimeout) return } diff --git a/pkg/dialer/ssh/metadata.go b/pkg/dialer/ssh/metadata.go index 986e5d2..598d4d4 100644 --- a/pkg/dialer/ssh/metadata.go +++ b/pkg/dialer/ssh/metadata.go @@ -6,7 +6,7 @@ import ( "strings" "time" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" "golang.org/x/crypto/ssh" ) @@ -16,7 +16,7 @@ type metadata struct { signer ssh.Signer } -func (d *sshDialer) parseMetadata(md md.Metadata) (err error) { +func (d *sshDialer) parseMetadata(md mdata.Metadata) (err error) { const ( handshakeTimeout = "handshakeTimeout" user = "user" @@ -24,7 +24,7 @@ func (d *sshDialer) parseMetadata(md md.Metadata) (err error) { passphrase = "passphrase" ) - if v := md.GetString(user); v != "" { + if v := mdata.GetString(md, user); v != "" { ss := strings.SplitN(v, ":", 2) if len(ss) == 1 { d.md.user = url.User(ss[0]) @@ -33,13 +33,13 @@ func (d *sshDialer) parseMetadata(md md.Metadata) (err error) { } } - if key := md.GetString(privateKeyFile); key != "" { + if key := mdata.GetString(md, privateKeyFile); key != "" { data, err := ioutil.ReadFile(key) if err != nil { return err } - pp := md.GetString(passphrase) + pp := mdata.GetString(md, passphrase) if pp == "" { d.md.signer, err = ssh.ParsePrivateKey(data) } else { @@ -50,7 +50,7 @@ func (d *sshDialer) parseMetadata(md md.Metadata) (err error) { } } - d.md.handshakeTimeout = md.GetDuration(handshakeTimeout) + d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) return } diff --git a/pkg/dialer/tls/metadata.go b/pkg/dialer/tls/metadata.go index b03461a..11ab968 100644 --- a/pkg/dialer/tls/metadata.go +++ b/pkg/dialer/tls/metadata.go @@ -6,7 +6,7 @@ import ( "time" tls_util "github.com/go-gost/gost/pkg/common/util/tls" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { @@ -14,7 +14,7 @@ type metadata struct { handshakeTimeout time.Duration } -func (d *tlsDialer) parseMetadata(md md.Metadata) (err error) { +func (d *tlsDialer) parseMetadata(md mdata.Metadata) (err error) { const ( certFile = "certFile" keyFile = "keyFile" @@ -25,19 +25,19 @@ func (d *tlsDialer) parseMetadata(md md.Metadata) (err error) { handshakeTimeout = "handshakeTimeout" ) - sn, _, _ := net.SplitHostPort(md.GetString(serverName)) + sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName)) if sn == "" { sn = "localhost" } d.md.tlsConfig, err = tls_util.LoadClientConfig( - md.GetString(certFile), - md.GetString(keyFile), - md.GetString(caFile), - md.GetBool(secure), + mdata.GetString(md, certFile), + mdata.GetString(md, keyFile), + mdata.GetString(md, caFile), + mdata.GetBool(md, secure), sn, ) - d.md.handshakeTimeout = md.GetDuration(handshakeTimeout) + d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) return } diff --git a/pkg/dialer/tls/mux/metadata.go b/pkg/dialer/tls/mux/metadata.go index d11cff1..75ac50b 100644 --- a/pkg/dialer/tls/mux/metadata.go +++ b/pkg/dialer/tls/mux/metadata.go @@ -6,7 +6,7 @@ import ( "time" tls_util "github.com/go-gost/gost/pkg/common/util/tls" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { @@ -21,7 +21,7 @@ type metadata struct { muxMaxStreamBuffer int } -func (d *mtlsDialer) parseMetadata(md md.Metadata) (err error) { +func (d *mtlsDialer) parseMetadata(md mdata.Metadata) (err error) { const ( certFile = "certFile" keyFile = "keyFile" @@ -39,25 +39,25 @@ func (d *mtlsDialer) parseMetadata(md md.Metadata) (err error) { muxMaxStreamBuffer = "muxMaxStreamBuffer" ) - sn, _, _ := net.SplitHostPort(md.GetString(serverName)) + sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName)) if sn == "" { sn = "localhost" } d.md.tlsConfig, err = tls_util.LoadClientConfig( - md.GetString(certFile), - md.GetString(keyFile), - md.GetString(caFile), - md.GetBool(secure), + mdata.GetString(md, certFile), + mdata.GetString(md, keyFile), + mdata.GetString(md, caFile), + mdata.GetBool(md, secure), sn, ) - d.md.handshakeTimeout = md.GetDuration(handshakeTimeout) + d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) - d.md.muxKeepAliveDisabled = md.GetBool(muxKeepAliveDisabled) - d.md.muxKeepAliveInterval = md.GetDuration(muxKeepAliveInterval) - d.md.muxKeepAliveTimeout = md.GetDuration(muxKeepAliveTimeout) - d.md.muxMaxFrameSize = md.GetInt(muxMaxFrameSize) - d.md.muxMaxReceiveBuffer = md.GetInt(muxMaxReceiveBuffer) - d.md.muxMaxStreamBuffer = md.GetInt(muxMaxStreamBuffer) + d.md.muxKeepAliveDisabled = mdata.GetBool(md, muxKeepAliveDisabled) + d.md.muxKeepAliveInterval = mdata.GetDuration(md, muxKeepAliveInterval) + d.md.muxKeepAliveTimeout = mdata.GetDuration(md, muxKeepAliveTimeout) + d.md.muxMaxFrameSize = mdata.GetInt(md, muxMaxFrameSize) + d.md.muxMaxReceiveBuffer = mdata.GetInt(md, muxMaxReceiveBuffer) + d.md.muxMaxStreamBuffer = mdata.GetInt(md, muxMaxStreamBuffer) return } diff --git a/pkg/dialer/ws/metadata.go b/pkg/dialer/ws/metadata.go index 0601489..e7e0002 100644 --- a/pkg/dialer/ws/metadata.go +++ b/pkg/dialer/ws/metadata.go @@ -2,13 +2,12 @@ package ws import ( "crypto/tls" - "fmt" "net" "net/http" "time" tls_util "github.com/go-gost/gost/pkg/common/util/tls" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) const ( @@ -29,7 +28,7 @@ type metadata struct { header http.Header } -func (d *wsDialer) parseMetadata(md md.Metadata) (err error) { +func (d *wsDialer) parseMetadata(md mdata.Metadata) (err error) { const ( path = "path" host = "host" @@ -49,35 +48,35 @@ func (d *wsDialer) parseMetadata(md md.Metadata) (err error) { header = "header" ) - d.md.path = md.GetString(path) + d.md.path = mdata.GetString(md, path) if d.md.path == "" { d.md.path = defaultPath } - d.md.host = md.GetString(host) + d.md.host = mdata.GetString(md, host) - sn, _, _ := net.SplitHostPort(md.GetString(serverName)) + sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName)) if sn == "" { sn = "localhost" } d.md.tlsConfig, err = tls_util.LoadClientConfig( - md.GetString(certFile), - md.GetString(keyFile), - md.GetString(caFile), - md.GetBool(secure), + mdata.GetString(md, certFile), + mdata.GetString(md, keyFile), + mdata.GetString(md, caFile), + mdata.GetBool(md, secure), sn, ) - d.md.handshakeTimeout = md.GetDuration(handshakeTimeout) - d.md.readHeaderTimeout = md.GetDuration(readHeaderTimeout) - d.md.readBufferSize = md.GetInt(readBufferSize) - d.md.writeBufferSize = md.GetInt(writeBufferSize) - d.md.enableCompression = md.GetBool(enableCompression) + d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) + d.md.readHeaderTimeout = mdata.GetDuration(md, readHeaderTimeout) + d.md.readBufferSize = mdata.GetInt(md, readBufferSize) + d.md.writeBufferSize = mdata.GetInt(md, writeBufferSize) + d.md.enableCompression = mdata.GetBool(md, enableCompression) - if mm, _ := md.Get(header).(map[interface{}]interface{}); len(mm) > 0 { + if m := mdata.GetStringMapString(md, header); len(m) > 0 { h := http.Header{} - for k, v := range mm { - h.Add(fmt.Sprintf("%v", k), fmt.Sprintf("%v", v)) + for k, v := range m { + h.Add(k, v) } d.md.header = h } diff --git a/pkg/dialer/ws/mux/metadata.go b/pkg/dialer/ws/mux/metadata.go index cd50a89..56922ce 100644 --- a/pkg/dialer/ws/mux/metadata.go +++ b/pkg/dialer/ws/mux/metadata.go @@ -2,13 +2,12 @@ package mux import ( "crypto/tls" - "fmt" "net" "net/http" "time" tls_util "github.com/go-gost/gost/pkg/common/util/tls" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) const ( @@ -36,7 +35,7 @@ type metadata struct { header http.Header } -func (d *mwsDialer) parseMetadata(md md.Metadata) (err error) { +func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) { const ( path = "path" host = "host" @@ -63,42 +62,42 @@ func (d *mwsDialer) parseMetadata(md md.Metadata) (err error) { muxMaxStreamBuffer = "muxMaxStreamBuffer" ) - d.md.path = md.GetString(path) + d.md.path = mdata.GetString(md, path) if d.md.path == "" { d.md.path = defaultPath } - d.md.host = md.GetString(host) + d.md.host = mdata.GetString(md, host) - sn, _, _ := net.SplitHostPort(md.GetString(serverName)) + sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName)) if sn == "" { sn = "localhost" } d.md.tlsConfig, err = tls_util.LoadClientConfig( - md.GetString(certFile), - md.GetString(keyFile), - md.GetString(caFile), - md.GetBool(secure), + mdata.GetString(md, certFile), + mdata.GetString(md, keyFile), + mdata.GetString(md, caFile), + mdata.GetBool(md, secure), sn, ) - d.md.muxKeepAliveDisabled = md.GetBool(muxKeepAliveDisabled) - d.md.muxKeepAliveInterval = md.GetDuration(muxKeepAliveInterval) - d.md.muxKeepAliveTimeout = md.GetDuration(muxKeepAliveTimeout) - d.md.muxMaxFrameSize = md.GetInt(muxMaxFrameSize) - d.md.muxMaxReceiveBuffer = md.GetInt(muxMaxReceiveBuffer) - d.md.muxMaxStreamBuffer = md.GetInt(muxMaxStreamBuffer) + d.md.muxKeepAliveDisabled = mdata.GetBool(md, muxKeepAliveDisabled) + d.md.muxKeepAliveInterval = mdata.GetDuration(md, muxKeepAliveInterval) + d.md.muxKeepAliveTimeout = mdata.GetDuration(md, muxKeepAliveTimeout) + d.md.muxMaxFrameSize = mdata.GetInt(md, muxMaxFrameSize) + d.md.muxMaxReceiveBuffer = mdata.GetInt(md, muxMaxReceiveBuffer) + d.md.muxMaxStreamBuffer = mdata.GetInt(md, muxMaxStreamBuffer) - d.md.handshakeTimeout = md.GetDuration(handshakeTimeout) - d.md.readHeaderTimeout = md.GetDuration(readHeaderTimeout) - d.md.readBufferSize = md.GetInt(readBufferSize) - d.md.writeBufferSize = md.GetInt(writeBufferSize) - d.md.enableCompression = md.GetBool(enableCompression) + d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) + d.md.readHeaderTimeout = mdata.GetDuration(md, readHeaderTimeout) + d.md.readBufferSize = mdata.GetInt(md, readBufferSize) + d.md.writeBufferSize = mdata.GetInt(md, writeBufferSize) + d.md.enableCompression = mdata.GetBool(md, enableCompression) - if mm, _ := md.Get(header).(map[interface{}]interface{}); len(mm) > 0 { + if m := mdata.GetStringMapString(md, header); len(m) > 0 { h := http.Header{} - for k, v := range mm { - h.Add(fmt.Sprintf("%v", k), fmt.Sprintf("%v", v)) + for k, v := range m { + h.Add(k, v) } d.md.header = h } diff --git a/pkg/handler/forward/local/metadata.go b/pkg/handler/forward/local/metadata.go index b54f8eb..f66b1ad 100644 --- a/pkg/handler/forward/local/metadata.go +++ b/pkg/handler/forward/local/metadata.go @@ -3,7 +3,7 @@ package local import ( "time" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { @@ -11,13 +11,13 @@ type metadata struct { retryCount int } -func (h *forwardHandler) parseMetadata(md md.Metadata) (err error) { +func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) { const ( readTimeout = "readTimeout" retryCount = "retry" ) - h.md.readTimeout = md.GetDuration(readTimeout) - h.md.retryCount = md.GetInt(retryCount) + h.md.readTimeout = mdata.GetDuration(md, readTimeout) + h.md.retryCount = mdata.GetInt(md, retryCount) return } diff --git a/pkg/handler/forward/remote/metadata.go b/pkg/handler/forward/remote/metadata.go index 50210b5..26bd723 100644 --- a/pkg/handler/forward/remote/metadata.go +++ b/pkg/handler/forward/remote/metadata.go @@ -3,7 +3,7 @@ package remote import ( "time" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { @@ -11,13 +11,13 @@ type metadata struct { retryCount int } -func (h *forwardHandler) parseMetadata(md md.Metadata) (err error) { +func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) { const ( readTimeout = "readTimeout" retryCount = "retry" ) - h.md.readTimeout = md.GetDuration(readTimeout) - h.md.retryCount = md.GetInt(retryCount) + h.md.readTimeout = mdata.GetDuration(md, readTimeout) + h.md.retryCount = mdata.GetInt(md, retryCount) return } diff --git a/pkg/handler/forward/ssh/metadata.go b/pkg/handler/forward/ssh/metadata.go index c904ae6..bf98f57 100644 --- a/pkg/handler/forward/ssh/metadata.go +++ b/pkg/handler/forward/ssh/metadata.go @@ -7,7 +7,7 @@ import ( "github.com/go-gost/gost/pkg/auth" tls_util "github.com/go-gost/gost/pkg/common/util/tls" ssh_util "github.com/go-gost/gost/pkg/internal/util/ssh" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" "golang.org/x/crypto/ssh" ) @@ -17,7 +17,7 @@ type metadata struct { authorizedKeys map[string]bool } -func (h *forwardHandler) parseMetadata(md md.Metadata) (err error) { +func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) { const ( users = "users" authorizedKeys = "authorizedKeys" @@ -25,28 +25,26 @@ func (h *forwardHandler) parseMetadata(md md.Metadata) (err error) { passphrase = "passphrase" ) - if v, _ := md.Get(users).([]interface{}); len(v) > 0 { + if auths := mdata.GetStrings(md, users); len(auths) > 0 { authenticator := auth.NewLocalAuthenticator(nil) - for _, auth := range v { - if s, _ := auth.(string); s != "" { - ss := strings.SplitN(s, ":", 2) - if len(ss) == 1 { - authenticator.Add(ss[0], "") - } else { - authenticator.Add(ss[0], ss[1]) - } + for _, auth := range auths { + ss := strings.SplitN(auth, ":", 2) + if len(ss) == 1 { + authenticator.Add(ss[0], "") + } else { + authenticator.Add(ss[0], ss[1]) } } h.md.authenticator = authenticator } - if key := md.GetString(privateKeyFile); key != "" { + if key := mdata.GetString(md, privateKeyFile); key != "" { data, err := ioutil.ReadFile(key) if err != nil { return err } - pp := md.GetString(passphrase) + pp := mdata.GetString(md, passphrase) if pp == "" { h.md.signer, err = ssh.ParsePrivateKey(data) } else { @@ -64,7 +62,7 @@ func (h *forwardHandler) parseMetadata(md md.Metadata) (err error) { h.md.signer = signer } - if name := md.GetString(authorizedKeys); name != "" { + if name := mdata.GetString(md, authorizedKeys); name != "" { m, err := ssh_util.ParseAuthorizedKeysFile(name) if err != nil { return err diff --git a/pkg/handler/http/metadata.go b/pkg/handler/http/metadata.go index 43c8c19..c7d38d0 100644 --- a/pkg/handler/http/metadata.go +++ b/pkg/handler/http/metadata.go @@ -28,7 +28,7 @@ func (h *httpHandler) parseMetadata(md mdata.Metadata) error { enableUDP = "udp" ) - if auths := md.GetStrings(users); len(auths) > 0 { + if auths := mdata.GetStrings(md, users); len(auths) > 0 { authenticator := auth.NewLocalAuthenticator(nil) for _, auth := range auths { ss := strings.SplitN(auth, ":", 2) @@ -41,26 +41,26 @@ func (h *httpHandler) parseMetadata(md mdata.Metadata) error { h.md.authenticator = authenticator } - if mm := mdata.GetStringMapString(md, header); len(mm) > 0 { + if m := mdata.GetStringMapString(md, header); len(m) > 0 { hd := http.Header{} - for k, v := range mm { + for k, v := range m { hd.Add(k, v) } h.md.header = hd } - if v := md.GetString(probeResistKey); v != "" { + if v := mdata.GetString(md, probeResistKey); v != "" { if ss := strings.SplitN(v, ":", 2); len(ss) == 2 { h.md.probeResist = &probeResist{ Type: ss[0], Value: ss[1], - Knock: md.GetString(knock), + Knock: mdata.GetString(md, knock), } } } - h.md.retryCount = md.GetInt(retryCount) - h.md.sni = md.GetBool(sni) - h.md.enableUDP = md.GetBool(enableUDP) + h.md.retryCount = mdata.GetInt(md, retryCount) + h.md.sni = mdata.GetBool(md, sni) + h.md.enableUDP = mdata.GetBool(md, enableUDP) return nil } diff --git a/pkg/handler/http2/metadata.go b/pkg/handler/http2/metadata.go index 2455394..5112784 100644 --- a/pkg/handler/http2/metadata.go +++ b/pkg/handler/http2/metadata.go @@ -4,7 +4,7 @@ import ( "strings" "github.com/go-gost/gost/pkg/auth" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { @@ -16,7 +16,7 @@ type metadata struct { enableUDP bool } -func (h *http2Handler) parseMetadata(md md.Metadata) error { +func (h *http2Handler) parseMetadata(md mdata.Metadata) error { const ( proxyAgent = "proxyAgent" users = "users" @@ -27,9 +27,9 @@ func (h *http2Handler) parseMetadata(md md.Metadata) error { enableUDP = "udp" ) - h.md.proxyAgent = md.GetString(proxyAgent) + h.md.proxyAgent = mdata.GetString(md, proxyAgent) - if auths := md.GetStrings(users); len(auths) > 0 { + if auths := mdata.GetStrings(md, users); len(auths) > 0 { authenticator := auth.NewLocalAuthenticator(nil) for _, auth := range auths { ss := strings.SplitN(auth, ":", 2) @@ -42,18 +42,18 @@ func (h *http2Handler) parseMetadata(md md.Metadata) error { h.md.authenticator = authenticator } - if v := md.GetString(probeResistKey); v != "" { + if v := mdata.GetString(md, probeResistKey); v != "" { if ss := strings.SplitN(v, ":", 2); len(ss) == 2 { h.md.probeResist = &probeResist{ Type: ss[0], Value: ss[1], - Knock: md.GetString(knock), + Knock: mdata.GetString(md, knock), } } } - h.md.retryCount = md.GetInt(retryCount) - h.md.sni = md.GetBool(sni) - h.md.enableUDP = md.GetBool(enableUDP) + h.md.retryCount = mdata.GetInt(md, retryCount) + h.md.sni = mdata.GetBool(md, sni) + h.md.enableUDP = mdata.GetBool(md, enableUDP) return nil } diff --git a/pkg/handler/redirect/metadata.go b/pkg/handler/redirect/metadata.go index 7edf818..43b9163 100644 --- a/pkg/handler/redirect/metadata.go +++ b/pkg/handler/redirect/metadata.go @@ -1,18 +1,18 @@ package redirect import ( - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { retryCount int } -func (h *redirectHandler) parseMetadata(md md.Metadata) (err error) { +func (h *redirectHandler) parseMetadata(md mdata.Metadata) (err error) { const ( retryCount = "retry" ) - h.md.retryCount = md.GetInt(retryCount) + h.md.retryCount = mdata.GetInt(md, retryCount) return } diff --git a/pkg/handler/relay/metadata.go b/pkg/handler/relay/metadata.go index 91fde54..c4f3a0f 100644 --- a/pkg/handler/relay/metadata.go +++ b/pkg/handler/relay/metadata.go @@ -1,11 +1,12 @@ package relay import ( + "math" "strings" "time" "github.com/go-gost/gost/pkg/auth" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { @@ -17,7 +18,7 @@ type metadata struct { noDelay bool } -func (h *relayHandler) parseMetadata(md md.Metadata) (err error) { +func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) { const ( users = "users" readTimeout = "readTimeout" @@ -27,7 +28,7 @@ func (h *relayHandler) parseMetadata(md md.Metadata) (err error) { noDelay = "nodelay" ) - if auths := md.GetStrings(users); len(auths) > 0 { + if auths := mdata.GetStrings(md, users); len(auths) > 0 { authenticator := auth.NewLocalAuthenticator(nil) for _, auth := range auths { ss := strings.SplitN(auth, ":", 2) @@ -40,20 +41,15 @@ func (h *relayHandler) parseMetadata(md md.Metadata) (err error) { h.md.authenticator = authenticator } - h.md.readTimeout = md.GetDuration(readTimeout) - h.md.retryCount = md.GetInt(retryCount) - h.md.enableBind = md.GetBool(enableBind) - h.md.noDelay = md.GetBool(noDelay) - h.md.udpBufferSize = md.GetInt(udpBufferSize) - if h.md.udpBufferSize > 0 { - if h.md.udpBufferSize < 512 { - h.md.udpBufferSize = 512 // min buffer size - } - if h.md.udpBufferSize > 65*1024 { - h.md.udpBufferSize = 65 * 1024 // max buffer size - } + h.md.readTimeout = mdata.GetDuration(md, readTimeout) + h.md.retryCount = mdata.GetInt(md, retryCount) + h.md.enableBind = mdata.GetBool(md, enableBind) + h.md.noDelay = mdata.GetBool(md, noDelay) + + if bs := mdata.GetInt(md, udpBufferSize); bs > 0 { + h.md.udpBufferSize = int(math.Min(math.Max(float64(bs), 512), 64*1024)) } else { - h.md.udpBufferSize = 1024 // default buffer size + h.md.udpBufferSize = 1024 } return } diff --git a/pkg/handler/sni/metadata.go b/pkg/handler/sni/metadata.go index c38f0a5..4fdbdd7 100644 --- a/pkg/handler/sni/metadata.go +++ b/pkg/handler/sni/metadata.go @@ -3,7 +3,7 @@ package sni import ( "time" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { @@ -11,13 +11,13 @@ type metadata struct { retryCount int } -func (h *sniHandler) parseMetadata(md md.Metadata) (err error) { +func (h *sniHandler) parseMetadata(md mdata.Metadata) (err error) { const ( readTimeout = "readTimeout" retryCount = "retry" ) - h.md.readTimeout = md.GetDuration(readTimeout) - h.md.retryCount = md.GetInt(retryCount) + h.md.readTimeout = mdata.GetDuration(md, readTimeout) + h.md.retryCount = mdata.GetInt(md, retryCount) return } diff --git a/pkg/handler/socks/v4/metadata.go b/pkg/handler/socks/v4/metadata.go index 8382e70..2842c01 100644 --- a/pkg/handler/socks/v4/metadata.go +++ b/pkg/handler/socks/v4/metadata.go @@ -4,7 +4,7 @@ import ( "time" "github.com/go-gost/gost/pkg/auth" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { @@ -13,24 +13,24 @@ type metadata struct { retryCount int } -func (h *socks4Handler) parseMetadata(md md.Metadata) (err error) { +func (h *socks4Handler) parseMetadata(md mdata.Metadata) (err error) { const ( users = "users" readTimeout = "readTimeout" retryCount = "retry" ) - if v, _ := md.Get(users).([]interface{}); len(v) > 0 { + if auths := mdata.GetStrings(md, users); len(auths) > 0 { authenticator := auth.NewLocalAuthenticator(nil) - for _, auth := range v { - if v, _ := auth.(string); v != "" { - authenticator.Add(v, "") + for _, auth := range auths { + if auth != "" { + authenticator.Add(auth, "") } } h.md.authenticator = authenticator } - h.md.readTimeout = md.GetDuration(readTimeout) - h.md.retryCount = md.GetInt(retryCount) + h.md.readTimeout = mdata.GetDuration(md, readTimeout) + h.md.retryCount = mdata.GetInt(md, retryCount) return } diff --git a/pkg/handler/socks/v5/metadata.go b/pkg/handler/socks/v5/metadata.go index 13aa447..47a13d0 100644 --- a/pkg/handler/socks/v5/metadata.go +++ b/pkg/handler/socks/v5/metadata.go @@ -2,12 +2,13 @@ package v5 import ( "crypto/tls" + "math" "strings" "time" "github.com/go-gost/gost/pkg/auth" tls_util "github.com/go-gost/gost/pkg/common/util/tls" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { @@ -23,7 +24,7 @@ type metadata struct { compatibilityMode bool } -func (h *socks5Handler) parseMetadata(md md.Metadata) (err error) { +func (h *socks5Handler) parseMetadata(md mdata.Metadata) (err error) { const ( certFile = "certFile" keyFile = "keyFile" @@ -40,49 +41,41 @@ func (h *socks5Handler) parseMetadata(md md.Metadata) (err error) { ) h.md.tlsConfig, err = tls_util.LoadServerConfig( - md.GetString(certFile), - md.GetString(keyFile), - md.GetString(caFile), + mdata.GetString(md, certFile), + mdata.GetString(md, keyFile), + mdata.GetString(md, caFile), ) if err != nil { return } - if v, _ := md.Get(users).([]interface{}); len(v) > 0 { + if auths := mdata.GetStrings(md, users); len(auths) > 0 { authenticator := auth.NewLocalAuthenticator(nil) - for _, auth := range v { - if s, _ := auth.(string); s != "" { - ss := strings.SplitN(s, ":", 2) - if len(ss) == 1 { - authenticator.Add(ss[0], "") - } else { - authenticator.Add(ss[0], ss[1]) - } + for _, auth := range auths { + ss := strings.SplitN(auth, ":", 2) + if len(ss) == 1 { + authenticator.Add(ss[0], "") + } else { + authenticator.Add(ss[0], ss[1]) } } h.md.authenticator = authenticator } - h.md.readTimeout = md.GetDuration(readTimeout) - h.md.timeout = md.GetDuration(timeout) - h.md.retryCount = md.GetInt(retryCount) - h.md.noTLS = md.GetBool(noTLS) - h.md.enableBind = md.GetBool(enableBind) - h.md.enableUDP = md.GetBool(enableUDP) + h.md.readTimeout = mdata.GetDuration(md, readTimeout) + h.md.timeout = mdata.GetDuration(md, timeout) + h.md.retryCount = mdata.GetInt(md, retryCount) + h.md.noTLS = mdata.GetBool(md, noTLS) + h.md.enableBind = mdata.GetBool(md, enableBind) + h.md.enableUDP = mdata.GetBool(md, enableUDP) - h.md.udpBufferSize = md.GetInt(udpBufferSize) - if h.md.udpBufferSize > 0 { - if h.md.udpBufferSize < 512 { - h.md.udpBufferSize = 512 // min buffer size - } - if h.md.udpBufferSize > 65*1024 { - h.md.udpBufferSize = 65 * 1024 // max buffer size - } + if bs := mdata.GetInt(md, udpBufferSize); bs > 0 { + h.md.udpBufferSize = int(math.Min(math.Max(float64(bs), 512), 64*1024)) } else { - h.md.udpBufferSize = 1024 // default buffer size + h.md.udpBufferSize = 1024 } - h.md.compatibilityMode = md.GetBool(compatibilityMode) + h.md.compatibilityMode = mdata.GetBool(md, compatibilityMode) return nil } diff --git a/pkg/handler/ss/handler.go b/pkg/handler/ss/handler.go index 44ae2f2..85d688b 100644 --- a/pkg/handler/ss/handler.go +++ b/pkg/handler/ss/handler.go @@ -76,7 +76,7 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) { addr := &gosocks5.Addr{} if _, err := addr.ReadFrom(conn); err != nil { h.logger.Error(err) - h.discard(conn) + io.Copy(ioutil.Discard, conn) return } @@ -110,7 +110,3 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) { }). Infof("%s >-< %s", conn.RemoteAddr(), addr) } - -func (h *ssHandler) discard(conn net.Conn) { - io.Copy(ioutil.Discard, conn) -} diff --git a/pkg/handler/ss/metadata.go b/pkg/handler/ss/metadata.go index 4943c3d..f841cb3 100644 --- a/pkg/handler/ss/metadata.go +++ b/pkg/handler/ss/metadata.go @@ -5,7 +5,7 @@ import ( "time" "github.com/go-gost/gost/pkg/common/util/ss" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" "github.com/shadowsocks/go-shadowsocks2/core" ) @@ -15,7 +15,7 @@ type metadata struct { retryCount int } -func (h *ssHandler) parseMetadata(md md.Metadata) (err error) { +func (h *ssHandler) parseMetadata(md mdata.Metadata) (err error) { const ( users = "users" key = "key" @@ -24,26 +24,22 @@ func (h *ssHandler) parseMetadata(md md.Metadata) (err error) { ) var method, password string - if v, _ := md.Get(users).([]interface{}); len(v) > 0 { - h.logger.Info(v) - for _, auth := range v { - if s, _ := auth.(string); s != "" { - ss := strings.SplitN(s, ":", 2) - if len(ss) == 1 { - method = ss[0] - } else { - method, password = ss[0], ss[1] - } - } + if auths := mdata.GetStrings(md, users); len(auths) > 0 { + auth := auths[0] + ss := strings.SplitN(auth, ":", 2) + if len(ss) == 1 { + method = ss[0] + } else { + method, password = ss[0], ss[1] } } - h.md.cipher, err = ss.ShadowCipher(method, password, md.GetString(key)) + h.md.cipher, err = ss.ShadowCipher(method, password, mdata.GetString(md, key)) if err != nil { return } - h.md.readTimeout = md.GetDuration(readTimeout) - h.md.retryCount = md.GetInt(retryCount) + h.md.readTimeout = mdata.GetDuration(md, readTimeout) + h.md.retryCount = mdata.GetInt(md, retryCount) return } diff --git a/pkg/handler/ss/udp/metadata.go b/pkg/handler/ss/udp/metadata.go index 49c5cf1..24d1908 100644 --- a/pkg/handler/ss/udp/metadata.go +++ b/pkg/handler/ss/udp/metadata.go @@ -1,11 +1,12 @@ package ss import ( + "math" "strings" "time" "github.com/go-gost/gost/pkg/common/util/ss" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" "github.com/shadowsocks/go-shadowsocks2/core" ) @@ -16,7 +17,7 @@ type metadata struct { bufferSize int } -func (h *ssuHandler) parseMetadata(md md.Metadata) (err error) { +func (h *ssuHandler) parseMetadata(md mdata.Metadata) (err error) { const ( users = "users" key = "key" @@ -26,36 +27,27 @@ func (h *ssuHandler) parseMetadata(md md.Metadata) (err error) { ) var method, password string - if v, _ := md.Get(users).([]interface{}); len(v) > 0 { - for _, auth := range v { - if s, _ := auth.(string); s != "" { - ss := strings.SplitN(s, ":", 2) - if len(ss) == 1 { - method = ss[0] - } else { - method, password = ss[0], ss[1] - } - } + if auths := mdata.GetStrings(md, users); len(auths) > 0 { + auth := auths[0] + ss := strings.SplitN(auth, ":", 2) + if len(ss) == 1 { + method = ss[0] + } else { + method, password = ss[0], ss[1] } } - h.md.cipher, err = ss.ShadowCipher(method, password, md.GetString(key)) + h.md.cipher, err = ss.ShadowCipher(method, password, mdata.GetString(md, key)) if err != nil { return } - h.md.readTimeout = md.GetDuration(readTimeout) - h.md.retryCount = md.GetInt(retryCount) + h.md.readTimeout = mdata.GetDuration(md, readTimeout) + h.md.retryCount = mdata.GetInt(md, retryCount) - h.md.bufferSize = md.GetInt(bufferSize) - if h.md.bufferSize > 0 { - if h.md.bufferSize < 512 { - h.md.bufferSize = 512 // min buffer size - } - if h.md.bufferSize > 65*1024 { - h.md.bufferSize = 65 * 1024 // max buffer size - } + if bs := mdata.GetInt(md, bufferSize); bs > 0 { + h.md.bufferSize = int(math.Min(math.Max(float64(bs), 512), 64*1024)) } else { - h.md.bufferSize = 1024 // default buffer size + h.md.bufferSize = 1024 } return } diff --git a/pkg/handler/tap/conn.go b/pkg/handler/tap/conn.go new file mode 100644 index 0000000..19ace7d --- /dev/null +++ b/pkg/handler/tap/conn.go @@ -0,0 +1,17 @@ +package tap + +import "net" + +type packetConn struct { + net.Conn +} + +func (c *packetConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + n, err = c.Read(b) + addr = c.Conn.RemoteAddr() + return +} + +func (c *packetConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + return c.Write(b) +} diff --git a/pkg/handler/tap/handler.go b/pkg/handler/tap/handler.go new file mode 100644 index 0000000..d59474c --- /dev/null +++ b/pkg/handler/tap/handler.go @@ -0,0 +1,331 @@ +package tap + +import ( + "context" + "fmt" + "io" + "net" + "os" + "sync" + "time" + + "github.com/go-gost/gost/pkg/bypass" + "github.com/go-gost/gost/pkg/chain" + "github.com/go-gost/gost/pkg/common/bufpool" + "github.com/go-gost/gost/pkg/handler" + tap_util "github.com/go-gost/gost/pkg/internal/util/tap" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" + "github.com/shadowsocks/go-shadowsocks2/shadowaead" + "github.com/songgao/water/waterutil" + "github.com/xtaci/tcpraw" +) + +func init() { + registry.RegisterHandler("tap", NewHandler) +} + +type tapHandler struct { + group *chain.NodeGroup + chain *chain.Chain + bypass bypass.Bypass + routes sync.Map + exit chan struct{} + logger logger.Logger + md metadata +} + +func NewHandler(opts ...handler.Option) handler.Handler { + options := &handler.Options{} + for _, opt := range opts { + opt(options) + } + + return &tapHandler{ + bypass: options.Bypass, + exit: make(chan struct{}, 1), + logger: options.Logger, + } +} + +func (h *tapHandler) Init(md md.Metadata) (err error) { + return h.parseMetadata(md) +} + +// implements chain.Chainable interface +func (h *tapHandler) WithChain(chain *chain.Chain) { + h.chain = chain +} + +// Forward implements handler.Forwarder. +func (h *tapHandler) Forward(group *chain.NodeGroup) { + h.group = group +} + +func (h *tapHandler) Handle(ctx context.Context, conn net.Conn) { + defer os.Exit(0) + defer conn.Close() + + cc, ok := conn.(*tap_util.Conn) + if !ok || cc.Config() == nil { + h.logger.Error("invalid connection") + return + } + + start := time.Now() + h.logger = h.logger.WithFields(map[string]interface{}{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }) + + h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + defer func() { + h.logger.WithFields(map[string]interface{}{ + "duration": time.Since(start), + }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) + }() + + network := "udp" + var raddr net.Addr + var err error + + target := h.group.Next() + if target != nil { + raddr, err = net.ResolveUDPAddr(network, target.Addr()) + if err != nil { + h.logger.Error(err) + return + } + h.logger = h.logger.WithFields(map[string]interface{}{ + "dst": fmt.Sprintf("%s/%s", raddr.String(), raddr.Network()), + }) + h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) + } + + h.handleLoop(ctx, conn, raddr, cc.Config()) +} + +func (h *tapHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Addr, config *tap_util.Config) { + var tempDelay time.Duration + for { + err := func() error { + var err error + var pc net.PacketConn + // fake tcp mode will be ignored when the client specifies a chain. + if addr != nil && !h.chain.IsEmpty() { + r := (&chain.Router{}). + WithChain(h.chain). + WithRetry(h.md.retryCount). + WithLogger(h.logger) + cc, err := r.Dial(ctx, addr.Network(), addr.String()) + if err != nil { + return err + } + pc = &packetConn{cc} + } else { + if h.md.tcpMode { + if addr != nil { + pc, err = tcpraw.Dial("tcp", addr.String()) + } else { + pc, err = tcpraw.Listen("tcp", conn.LocalAddr().String()) + } + } else { + laddr, _ := net.ResolveUDPAddr("udp", conn.LocalAddr().String()) + pc, err = net.ListenUDP("udp", laddr) + } + } + if err != nil { + return err + } + + if h.md.cipher != nil { + pc = h.md.cipher.PacketConn(pc) + } + + return h.transport(conn, pc, addr) + }() + if err != nil { + h.logger.Error(err) + } + + select { + case <-h.exit: + return + default: + } + + if err != nil { + if tempDelay == 0 { + tempDelay = 1000 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 6 * time.Second; tempDelay > max { + tempDelay = max + } + time.Sleep(tempDelay) + continue + } + tempDelay = 0 + } + +} + +func (h *tapHandler) transport(tap net.Conn, conn net.PacketConn, raddr net.Addr) error { + errc := make(chan error, 1) + + go func() { + for { + err := func() error { + b := bufpool.Get(h.md.bufferSize) + defer bufpool.Put(b) + + n, err := tap.Read(b) + if err != nil { + select { + case h.exit <- struct{}{}: + default: + } + return err + } + + src := waterutil.MACSource(b[:n]) + dst := waterutil.MACDestination(b[:n]) + eType := etherType(waterutil.MACEthertype(b[:n])) + + h.logger.Debugf("%s >> %s %s %d", src, dst, eType, n) + + // client side, deliver frame directly. + if raddr != nil { + _, err := conn.WriteTo(b[:n], raddr) + return err + } + + // server side, broadcast. + if waterutil.IsBroadcast(dst) { + go h.routes.Range(func(k, v interface{}) bool { + conn.WriteTo(b[:n], v.(net.Addr)) + return true + }) + return nil + } + + var addr net.Addr + if v, ok := h.routes.Load(hwAddrToTapRouteKey(dst)); ok { + addr = v.(net.Addr) + } + if addr == nil { + h.logger.Warnf("no route for %s -> %s %s %d", src, dst, eType, n) + return nil + } + + if _, err := conn.WriteTo(b[:n], addr); err != nil { + return err + } + + return nil + }() + + if err != nil { + errc <- err + return + } + } + }() + + go func() { + for { + err := func() error { + b := bufpool.Get(h.md.bufferSize) + defer bufpool.Put(b) + + n, addr, err := conn.ReadFrom(b) + if err != nil && + err != shadowaead.ErrShortPacket { + return err + } + + src := waterutil.MACSource(b[:n]) + dst := waterutil.MACDestination(b[:n]) + eType := etherType(waterutil.MACEthertype(b[:n])) + + h.logger.Debugf("%s >> %s %s %d", src, dst, eType, n) + + // client side, deliver frame to tap device. + if raddr != nil { + _, err := tap.Write(b[:n]) + return err + } + + // server side, record route. + rkey := hwAddrToTapRouteKey(src) + if actual, loaded := h.routes.LoadOrStore(rkey, addr); loaded { + if actual.(net.Addr).String() != addr.String() { + h.logger.Debugf("update route: %s -> %s (old %s)", + src, addr, actual.(net.Addr)) + h.routes.Store(rkey, addr) + } + } else { + h.logger.Debugf("new route: %s -> %s", src, addr) + } + + if waterutil.IsBroadcast(dst) { + go h.routes.Range(func(k, v interface{}) bool { + if k.(tapRouteKey) != rkey { + conn.WriteTo(b[:n], v.(net.Addr)) + } + return true + }) + } + + if v, ok := h.routes.Load(hwAddrToTapRouteKey(dst)); ok { + h.logger.Debugf("find route: %s -> %s", dst, v) + _, err := conn.WriteTo(b[:n], v.(net.Addr)) + return err + } + + if _, err := tap.Write(b[:n]); err != nil { + select { + case h.exit <- struct{}{}: + default: + } + return err + } + return nil + }() + + if err != nil { + errc <- err + return + } + } + }() + + err := <-errc + if err != nil && err == io.EOF { + err = nil + } + return err +} + +var mEtherTypes = map[waterutil.Ethertype]string{ + waterutil.IPv4: "ip", + waterutil.ARP: "arp", + waterutil.RARP: "rarp", + waterutil.IPv6: "ip6", +} + +func etherType(et waterutil.Ethertype) string { + if s, ok := mEtherTypes[et]; ok { + return s + } + return fmt.Sprintf("unknown(%v)", et) +} + +type tapRouteKey [6]byte + +func hwAddrToTapRouteKey(addr net.HardwareAddr) (key tapRouteKey) { + copy(key[:], addr) + return +} diff --git a/pkg/handler/tap/metadata.go b/pkg/handler/tap/metadata.go new file mode 100644 index 0000000..5a97380 --- /dev/null +++ b/pkg/handler/tap/metadata.go @@ -0,0 +1,50 @@ +package tap + +import ( + "strings" + + "github.com/go-gost/gost/pkg/common/util/ss" + mdata "github.com/go-gost/gost/pkg/metadata" + "github.com/shadowsocks/go-shadowsocks2/core" +) + +type metadata struct { + cipher core.Cipher + retryCount int + tcpMode bool + bufferSize int +} + +func (h *tapHandler) parseMetadata(md mdata.Metadata) (err error) { + const ( + users = "users" + key = "key" + readTimeout = "readTimeout" + retryCount = "retry" + tcpMode = "tcp" + bufferSize = "bufferSize" + ) + + var method, password string + if auths := mdata.GetStrings(md, users); len(auths) > 0 { + auth := auths[0] + ss := strings.SplitN(auth, ":", 2) + if len(ss) == 1 { + method = ss[0] + } else { + method, password = ss[0], ss[1] + } + } + h.md.cipher, err = ss.ShadowCipher(method, password, mdata.GetString(md, key)) + if err != nil { + return + } + h.md.retryCount = mdata.GetInt(md, retryCount) + h.md.tcpMode = mdata.GetBool(md, tcpMode) + + h.md.bufferSize = mdata.GetInt(md, bufferSize) + if h.md.bufferSize <= 0 { + h.md.bufferSize = 1024 + } + return +} diff --git a/pkg/handler/tun/conn.go b/pkg/handler/tun/conn.go new file mode 100644 index 0000000..c16bd7f --- /dev/null +++ b/pkg/handler/tun/conn.go @@ -0,0 +1,17 @@ +package tun + +import "net" + +type packetConn struct { + net.Conn +} + +func (c *packetConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + n, err = c.Read(b) + addr = c.Conn.RemoteAddr() + return +} + +func (c *packetConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + return c.Write(b) +} diff --git a/pkg/handler/tun/handler.go b/pkg/handler/tun/handler.go new file mode 100644 index 0000000..f84debb --- /dev/null +++ b/pkg/handler/tun/handler.go @@ -0,0 +1,380 @@ +package tun + +import ( + "context" + "fmt" + "io" + "net" + "os" + "sync" + "time" + + "github.com/go-gost/gost/pkg/bypass" + "github.com/go-gost/gost/pkg/chain" + "github.com/go-gost/gost/pkg/common/bufpool" + "github.com/go-gost/gost/pkg/handler" + tun_util "github.com/go-gost/gost/pkg/internal/util/tun" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" + "github.com/shadowsocks/go-shadowsocks2/shadowaead" + "github.com/songgao/water/waterutil" + "github.com/xtaci/tcpraw" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +func init() { + registry.RegisterHandler("tun", NewHandler) +} + +type tunHandler struct { + group *chain.NodeGroup + chain *chain.Chain + bypass bypass.Bypass + routes sync.Map + exit chan struct{} + logger logger.Logger + md metadata +} + +func NewHandler(opts ...handler.Option) handler.Handler { + options := &handler.Options{} + for _, opt := range opts { + opt(options) + } + + return &tunHandler{ + bypass: options.Bypass, + exit: make(chan struct{}, 1), + logger: options.Logger, + } +} + +func (h *tunHandler) Init(md md.Metadata) (err error) { + return h.parseMetadata(md) +} + +// implements chain.Chainable interface +func (h *tunHandler) WithChain(chain *chain.Chain) { + h.chain = chain +} + +// Forward implements handler.Forwarder. +func (h *tunHandler) Forward(group *chain.NodeGroup) { + h.group = group +} + +func (h *tunHandler) Handle(ctx context.Context, conn net.Conn) { + defer os.Exit(0) + defer conn.Close() + + cc, ok := conn.(*tun_util.Conn) + if !ok || cc.Config() == nil { + h.logger.Error("invalid connection") + return + } + + start := time.Now() + h.logger = h.logger.WithFields(map[string]interface{}{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }) + + h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + defer func() { + h.logger.WithFields(map[string]interface{}{ + "duration": time.Since(start), + }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) + }() + + network := "udp" + var raddr net.Addr + var err error + + target := h.group.Next() + if target != nil { + raddr, err = net.ResolveUDPAddr(network, target.Addr()) + if err != nil { + h.logger.Error(err) + return + } + h.logger = h.logger.WithFields(map[string]interface{}{ + "dst": fmt.Sprintf("%s/%s", raddr.String(), raddr.Network()), + }) + h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) + } + + h.handleLoop(ctx, conn, raddr, cc.Config()) +} + +func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Addr, config *tun_util.Config) { + var tempDelay time.Duration + for { + err := func() error { + var err error + var pc net.PacketConn + // fake tcp mode will be ignored when the client specifies a chain. + if addr != nil && !h.chain.IsEmpty() { + r := (&chain.Router{}). + WithChain(h.chain). + WithRetry(h.md.retryCount). + WithLogger(h.logger) + cc, err := r.Dial(ctx, addr.Network(), addr.String()) + if err != nil { + return err + } + pc = &packetConn{cc} + } else { + if h.md.tcpMode { + if addr != nil { + pc, err = tcpraw.Dial("tcp", addr.String()) + } else { + pc, err = tcpraw.Listen("tcp", conn.LocalAddr().String()) + } + } else { + laddr, _ := net.ResolveUDPAddr("udp", conn.LocalAddr().String()) + pc, err = net.ListenUDP("udp", laddr) + } + } + if err != nil { + return err + } + + if h.md.cipher != nil { + pc = h.md.cipher.PacketConn(pc) + } + + return h.transport(conn, pc, addr) + }() + if err != nil { + h.logger.Error(err) + } + + select { + case <-h.exit: + return + default: + } + + if err != nil { + if tempDelay == 0 { + tempDelay = 1000 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 6 * time.Second; tempDelay > max { + tempDelay = max + } + time.Sleep(tempDelay) + continue + } + tempDelay = 0 + } + +} + +func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr) error { + errc := make(chan error, 1) + + go func() { + for { + err := func() error { + b := bufpool.Get(h.md.bufferSize) + defer bufpool.Put(b) + + n, err := tun.Read(b) + if err != nil { + select { + case h.exit <- struct{}{}: + default: + } + return err + } + + var src, dst net.IP + if waterutil.IsIPv4(b[:n]) { + header, err := ipv4.ParseHeader(b[:n]) + if err != nil { + h.logger.Error(err) + return nil + } + h.logger.Debugf("%s >> %s %-4s %d/%-4d %-4x %d", + header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol(b[:n])), + header.Len, header.TotalLen, header.ID, header.Flags) + + src, dst = header.Src, header.Dst + } else if waterutil.IsIPv6(b[:n]) { + header, err := ipv6.ParseHeader(b[:n]) + if err != nil { + h.logger.Warn(err) + return nil + } + h.logger.Debugf("%s >> %s %s %d %d", + header.Src, header.Dst, + ipProtocol(waterutil.IPProtocol(header.NextHeader)), + header.PayloadLen, header.TrafficClass) + + src, dst = header.Src, header.Dst + } else { + h.logger.Warn("unknown packet, discarded") + return nil + } + + // client side, deliver packet directly. + if raddr != nil { + _, err := conn.WriteTo(b[:n], raddr) + return err + } + + addr := h.findRouteFor(dst) + if addr == nil { + h.logger.Warnf("no route for %s -> %s", src, dst) + return nil + } + + h.logger.Debugf("find route: %s -> %s", dst, addr) + + if _, err := conn.WriteTo(b[:n], addr); err != nil { + return err + } + return nil + }() + + if err != nil { + errc <- err + return + } + } + }() + + go func() { + for { + err := func() error { + b := bufpool.Get(h.md.bufferSize) + defer bufpool.Put(b) + + n, addr, err := conn.ReadFrom(b) + if err != nil && + err != shadowaead.ErrShortPacket { + return err + } + + var src, dst net.IP + if waterutil.IsIPv4(b[:n]) { + header, err := ipv4.ParseHeader(b[:n]) + if err != nil { + h.logger.Warn(err) + return nil + } + + h.logger.Debugf("%s >> %s %-4s %d/%-4d %-4x %d", + header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol(b[:n])), + header.Len, header.TotalLen, header.ID, header.Flags) + + src, dst = header.Src, header.Dst + } else if waterutil.IsIPv6(b[:n]) { + header, err := ipv6.ParseHeader(b[:n]) + if err != nil { + h.logger.Warn(err) + return nil + } + + h.logger.Debugf("%s > %s %s %d %d", + header.Src, header.Dst, + ipProtocol(waterutil.IPProtocol(header.NextHeader)), + header.PayloadLen, header.TrafficClass) + + src, dst = header.Src, header.Dst + } else { + h.logger.Warn("unknown packet, discarded") + return nil + } + + // client side, deliver packet to tun device. + if raddr != nil { + _, err := tun.Write(b[:n]) + return err + } + + rkey := ipToTunRouteKey(src) + if actual, loaded := h.routes.LoadOrStore(rkey, addr); loaded { + if actual.(net.Addr).String() != addr.String() { + h.logger.Debugf("update route: %s -> %s (old %s)", + src, addr, actual.(net.Addr)) + h.routes.Store(rkey, addr) + } + } else { + h.logger.Warnf("no route for %s -> %s", src, addr) + } + + if addr := h.findRouteFor(dst); addr != nil { + h.logger.Debugf("find route: %s -> %s", dst, addr) + + _, err := conn.WriteTo(b[:n], addr) + return err + } + + if _, err := tun.Write(b[:n]); err != nil { + select { + case h.exit <- struct{}{}: + default: + } + return err + } + return nil + }() + + if err != nil { + errc <- err + return + } + } + }() + + err := <-errc + if err != nil && err == io.EOF { + err = nil + } + return err +} + +func (h *tunHandler) findRouteFor(dst net.IP, routes ...tun_util.Route) net.Addr { + if v, ok := h.routes.Load(ipToTunRouteKey(dst)); ok { + return v.(net.Addr) + } + for _, route := range routes { + if route.Net.Contains(dst) && route.Gateway != nil { + if v, ok := h.routes.Load(ipToTunRouteKey(route.Gateway)); ok { + return v.(net.Addr) + } + } + } + return nil +} + +var mIPProts = map[waterutil.IPProtocol]string{ + waterutil.HOPOPT: "HOPOPT", + waterutil.ICMP: "ICMP", + waterutil.IGMP: "IGMP", + waterutil.GGP: "GGP", + waterutil.TCP: "TCP", + waterutil.UDP: "UDP", + waterutil.IPv6_Route: "IPv6-Route", + waterutil.IPv6_Frag: "IPv6-Frag", + waterutil.IPv6_ICMP: "IPv6-ICMP", +} + +func ipProtocol(p waterutil.IPProtocol) string { + if v, ok := mIPProts[p]; ok { + return v + } + return fmt.Sprintf("unknown(%d)", p) +} + +type tunRouteKey [16]byte + +func ipToTunRouteKey(ip net.IP) (key tunRouteKey) { + copy(key[:], ip.To16()) + return +} diff --git a/pkg/handler/tun/metadata.go b/pkg/handler/tun/metadata.go new file mode 100644 index 0000000..2d8437c --- /dev/null +++ b/pkg/handler/tun/metadata.go @@ -0,0 +1,50 @@ +package tun + +import ( + "strings" + + "github.com/go-gost/gost/pkg/common/util/ss" + mdata "github.com/go-gost/gost/pkg/metadata" + "github.com/shadowsocks/go-shadowsocks2/core" +) + +type metadata struct { + cipher core.Cipher + retryCount int + tcpMode bool + bufferSize int +} + +func (h *tunHandler) parseMetadata(md mdata.Metadata) (err error) { + const ( + users = "users" + key = "key" + readTimeout = "readTimeout" + retryCount = "retry" + tcpMode = "tcp" + bufferSize = "bufferSize" + ) + + var method, password string + if auths := mdata.GetStrings(md, users); len(auths) > 0 { + auth := auths[0] + ss := strings.SplitN(auth, ":", 2) + if len(ss) == 1 { + method = ss[0] + } else { + method, password = ss[0], ss[1] + } + } + h.md.cipher, err = ss.ShadowCipher(method, password, mdata.GetString(md, key)) + if err != nil { + return + } + h.md.retryCount = mdata.GetInt(md, retryCount) + h.md.tcpMode = mdata.GetBool(md, tcpMode) + + h.md.bufferSize = mdata.GetInt(md, bufferSize) + if h.md.bufferSize <= 0 { + h.md.bufferSize = 1024 + } + return +} diff --git a/pkg/internal/util/tap/config.go b/pkg/internal/util/tap/config.go new file mode 100644 index 0000000..77895ce --- /dev/null +++ b/pkg/internal/util/tap/config.go @@ -0,0 +1,9 @@ +package tap + +type Config struct { + Name string + Net string + MTU int + Routes []string + Gateway string +} diff --git a/pkg/internal/util/tap/conn.go b/pkg/internal/util/tap/conn.go new file mode 100644 index 0000000..453444f --- /dev/null +++ b/pkg/internal/util/tap/conn.go @@ -0,0 +1,61 @@ +package tap + +import ( + "errors" + "net" + "time" + + "github.com/songgao/water" +) + +type Conn struct { + config *Config + ifce *water.Interface + laddr net.Addr + raddr net.Addr +} + +func NewConn(config *Config, ifce *water.Interface, laddr, raddr net.Addr) *Conn { + return &Conn{ + config: config, + ifce: ifce, + laddr: laddr, + raddr: raddr, + } +} + +func (c *Conn) Config() *Config { + return c.config +} + +func (c *Conn) Read(b []byte) (n int, err error) { + return c.ifce.Read(b) +} + +func (c *Conn) Write(b []byte) (n int, err error) { + return c.ifce.Write(b) +} + +func (c *Conn) Close() (err error) { + return c.ifce.Close() +} + +func (c *Conn) LocalAddr() net.Addr { + return c.laddr +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.raddr +} + +func (c *Conn) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "tuntap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *Conn) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "tuntap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *Conn) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "tuntap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} diff --git a/pkg/internal/util/tun/config.go b/pkg/internal/util/tun/config.go new file mode 100644 index 0000000..6448cae --- /dev/null +++ b/pkg/internal/util/tun/config.go @@ -0,0 +1,19 @@ +package tun + +import "net" + +// Route is an IP routing entry +type Route struct { + Net net.IPNet + Gateway net.IP +} + +type Config struct { + Name string + Net string + // peer addr of point-to-point on MacOS + Peer string + MTU int + Gateway string + Routes []Route +} diff --git a/pkg/internal/util/tun/conn.go b/pkg/internal/util/tun/conn.go new file mode 100644 index 0000000..05477f9 --- /dev/null +++ b/pkg/internal/util/tun/conn.go @@ -0,0 +1,61 @@ +package tun + +import ( + "errors" + "net" + "time" + + "github.com/songgao/water" +) + +type Conn struct { + config *Config + ifce *water.Interface + laddr net.Addr + raddr net.Addr +} + +func NewConn(config *Config, ifce *water.Interface, laddr, raddr net.Addr) *Conn { + return &Conn{ + config: config, + ifce: ifce, + laddr: laddr, + raddr: raddr, + } +} + +func (c *Conn) Config() *Config { + return c.config +} + +func (c *Conn) Read(b []byte) (n int, err error) { + return c.ifce.Read(b) +} + +func (c *Conn) Write(b []byte) (n int, err error) { + return c.ifce.Write(b) +} + +func (c *Conn) Close() (err error) { + return c.ifce.Close() +} + +func (c *Conn) LocalAddr() net.Addr { + return c.laddr +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.raddr +} + +func (c *Conn) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "tuntap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *Conn) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "tuntap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *Conn) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "tuntap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} diff --git a/pkg/listener/http2/h2/metadata.go b/pkg/listener/http2/h2/metadata.go index 80ffc2c..3eca516 100644 --- a/pkg/listener/http2/h2/metadata.go +++ b/pkg/listener/http2/h2/metadata.go @@ -4,7 +4,7 @@ import ( "crypto/tls" tls_util "github.com/go-gost/gost/pkg/common/util/tls" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) const ( @@ -17,7 +17,7 @@ type metadata struct { backlog int } -func (l *h2Listener) parseMetadata(md md.Metadata) (err error) { +func (l *h2Listener) parseMetadata(md mdata.Metadata) (err error) { const ( path = "path" certFile = "certFile" @@ -27,19 +27,19 @@ func (l *h2Listener) parseMetadata(md md.Metadata) (err error) { ) l.md.tlsConfig, err = tls_util.LoadServerConfig( - md.GetString(certFile), - md.GetString(keyFile), - md.GetString(caFile), + mdata.GetString(md, certFile), + mdata.GetString(md, keyFile), + mdata.GetString(md, caFile), ) if err != nil { return } - l.md.backlog = md.GetInt(backlog) + l.md.backlog = mdata.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } - l.md.path = md.GetString(path) + l.md.path = mdata.GetString(md, path) return } diff --git a/pkg/listener/http2/metadata.go b/pkg/listener/http2/metadata.go index 50fca3c..ecfede0 100644 --- a/pkg/listener/http2/metadata.go +++ b/pkg/listener/http2/metadata.go @@ -6,7 +6,7 @@ import ( "time" tls_util "github.com/go-gost/gost/pkg/common/util/tls" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) const ( @@ -25,7 +25,7 @@ type metadata struct { backlog int } -func (l *http2Listener) parseMetadata(md md.Metadata) (err error) { +func (l *http2Listener) parseMetadata(md mdata.Metadata) (err error) { const ( path = "path" certFile = "certFile" @@ -39,15 +39,15 @@ func (l *http2Listener) parseMetadata(md md.Metadata) (err error) { ) l.md.tlsConfig, err = tls_util.LoadServerConfig( - md.GetString(certFile), - md.GetString(keyFile), - md.GetString(caFile), + mdata.GetString(md, certFile), + mdata.GetString(md, keyFile), + mdata.GetString(md, caFile), ) if err != nil { return } - l.md.backlog = md.GetInt(backlog) + l.md.backlog = mdata.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } diff --git a/pkg/listener/kcp/metadata.go b/pkg/listener/kcp/metadata.go index 0e8f9a6..127c65e 100644 --- a/pkg/listener/kcp/metadata.go +++ b/pkg/listener/kcp/metadata.go @@ -4,7 +4,7 @@ import ( "encoding/json" kcp_util "github.com/go-gost/gost/pkg/common/util/kcp" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) const ( @@ -16,19 +16,13 @@ type metadata struct { backlog int } -func (l *kcpListener) parseMetadata(md md.Metadata) (err error) { +func (l *kcpListener) parseMetadata(md mdata.Metadata) (err error) { const ( backlog = "backlog" config = "config" ) - if mm, _ := md.Get(config).(map[interface{}]interface{}); len(mm) > 0 { - m := make(map[string]interface{}) - for k, v := range mm { - if sk, ok := k.(string); ok { - m[sk] = v - } - } + if m := mdata.GetStringMap(md, config); len(m) > 0 { b, err := json.Marshal(m) if err != nil { return err @@ -44,7 +38,7 @@ func (l *kcpListener) parseMetadata(md md.Metadata) (err error) { l.md.config = kcp_util.DefaultConfig } - l.md.backlog = md.GetInt(backlog) + l.md.backlog = mdata.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } diff --git a/pkg/listener/quic/metadata.go b/pkg/listener/quic/metadata.go index 79fb4ff..3d43ede 100644 --- a/pkg/listener/quic/metadata.go +++ b/pkg/listener/quic/metadata.go @@ -5,7 +5,7 @@ import ( "time" tls_util "github.com/go-gost/gost/pkg/common/util/tls" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) const ( @@ -22,7 +22,7 @@ type metadata struct { backlog int } -func (l *quicListener) parseMetadata(md md.Metadata) (err error) { +func (l *quicListener) parseMetadata(md mdata.Metadata) (err error) { const ( keepAlive = "keepAlive" handshakeTimeout = "handshakeTimeout" @@ -37,26 +37,26 @@ func (l *quicListener) parseMetadata(md md.Metadata) (err error) { ) l.md.tlsConfig, err = tls_util.LoadServerConfig( - md.GetString(certFile), - md.GetString(keyFile), - md.GetString(caFile), + mdata.GetString(md, certFile), + mdata.GetString(md, keyFile), + mdata.GetString(md, caFile), ) if err != nil { return } - l.md.backlog = md.GetInt(backlog) + l.md.backlog = mdata.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } - if key := md.GetString(cipherKey); key != "" { + if key := mdata.GetString(md, cipherKey); key != "" { l.md.cipherKey = []byte(key) } - l.md.keepAlive = md.GetBool(keepAlive) - l.md.handshakeTimeout = md.GetDuration(handshakeTimeout) - l.md.maxIdleTimeout = md.GetDuration(maxIdleTimeout) + l.md.keepAlive = mdata.GetBool(md, keepAlive) + l.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) + l.md.maxIdleTimeout = mdata.GetDuration(md, maxIdleTimeout) return } diff --git a/pkg/listener/redirect/udp/metadata.go b/pkg/listener/redirect/udp/metadata.go index 2a513b5..1e3151f 100644 --- a/pkg/listener/redirect/udp/metadata.go +++ b/pkg/listener/redirect/udp/metadata.go @@ -3,7 +3,7 @@ package udp import ( "time" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) const ( @@ -16,18 +16,18 @@ type metadata struct { readBufferSize int } -func (l *redirectListener) parseMetadata(md md.Metadata) (err error) { +func (l *redirectListener) parseMetadata(md mdata.Metadata) (err error) { const ( ttl = "ttl" readBufferSize = "readBufferSize" ) - l.md.ttl = md.GetDuration(ttl) + l.md.ttl = mdata.GetDuration(md, ttl) if l.md.ttl <= 0 { l.md.ttl = defaultTTL } - l.md.readBufferSize = md.GetInt(readBufferSize) + l.md.readBufferSize = mdata.GetInt(md, readBufferSize) if l.md.readBufferSize <= 0 { l.md.readBufferSize = defaultReadBufferSize } diff --git a/pkg/listener/rtcp/metadata.go b/pkg/listener/rtcp/metadata.go index c25e577..4a723ec 100644 --- a/pkg/listener/rtcp/metadata.go +++ b/pkg/listener/rtcp/metadata.go @@ -3,7 +3,7 @@ package rtcp import ( "time" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) const ( @@ -17,17 +17,17 @@ type metadata struct { retryCount int } -func (l *rtcpListener) parseMetadata(md md.Metadata) (err error) { +func (l *rtcpListener) parseMetadata(md mdata.Metadata) (err error) { const ( enableMux = "mux" backlog = "backlog" retryCount = "retry" ) - l.md.enableMux = md.GetBool(enableMux) - l.md.retryCount = md.GetInt(retryCount) + l.md.enableMux = mdata.GetBool(md, enableMux) + l.md.retryCount = mdata.GetInt(md, retryCount) - l.md.backlog = md.GetInt(backlog) + l.md.backlog = mdata.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } diff --git a/pkg/listener/rudp/metadata.go b/pkg/listener/rudp/metadata.go index f389955..6c717c6 100644 --- a/pkg/listener/rudp/metadata.go +++ b/pkg/listener/rudp/metadata.go @@ -3,7 +3,7 @@ package rudp import ( "time" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) const ( @@ -21,7 +21,7 @@ type metadata struct { retryCount int } -func (l *rudpListener) parseMetadata(md md.Metadata) (err error) { +func (l *rudpListener) parseMetadata(md mdata.Metadata) (err error) { const ( ttl = "ttl" readBufferSize = "readBufferSize" @@ -30,25 +30,25 @@ func (l *rudpListener) parseMetadata(md md.Metadata) (err error) { retryCount = "retry" ) - l.md.ttl = md.GetDuration(ttl) + l.md.ttl = mdata.GetDuration(md, ttl) if l.md.ttl <= 0 { l.md.ttl = defaultTTL } - l.md.readBufferSize = md.GetInt(readBufferSize) + l.md.readBufferSize = mdata.GetInt(md, readBufferSize) if l.md.readBufferSize <= 0 { l.md.readBufferSize = defaultReadBufferSize } - l.md.readQueueSize = md.GetInt(readQueueSize) + l.md.readQueueSize = mdata.GetInt(md, readQueueSize) if l.md.readQueueSize <= 0 { l.md.readQueueSize = defaultReadQueueSize } - l.md.backlog = md.GetInt(backlog) + l.md.backlog = mdata.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } - l.md.retryCount = md.GetInt(retryCount) + l.md.retryCount = mdata.GetInt(md, retryCount) return } diff --git a/pkg/listener/ssh/metadata.go b/pkg/listener/ssh/metadata.go index 96e7800..8ecf72f 100644 --- a/pkg/listener/ssh/metadata.go +++ b/pkg/listener/ssh/metadata.go @@ -7,7 +7,7 @@ import ( "github.com/go-gost/gost/pkg/auth" tls_util "github.com/go-gost/gost/pkg/common/util/tls" ssh_util "github.com/go-gost/gost/pkg/internal/util/ssh" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" "golang.org/x/crypto/ssh" ) @@ -22,7 +22,7 @@ type metadata struct { backlog int } -func (l *sshListener) parseMetadata(md md.Metadata) (err error) { +func (l *sshListener) parseMetadata(md mdata.Metadata) (err error) { const ( users = "users" authorizedKeys = "authorizedKeys" @@ -31,28 +31,26 @@ func (l *sshListener) parseMetadata(md md.Metadata) (err error) { backlog = "backlog" ) - if v, _ := md.Get(users).([]interface{}); len(v) > 0 { + if auths := mdata.GetStrings(md, users); len(auths) > 0 { authenticator := auth.NewLocalAuthenticator(nil) - for _, auth := range v { - if s, _ := auth.(string); s != "" { - ss := strings.SplitN(s, ":", 2) - if len(ss) == 1 { - authenticator.Add(ss[0], "") - } else { - authenticator.Add(ss[0], ss[1]) - } + for _, auth := range auths { + ss := strings.SplitN(auth, ":", 2) + if len(ss) == 1 { + authenticator.Add(ss[0], "") + } else { + authenticator.Add(ss[0], ss[1]) } } l.md.authenticator = authenticator } - if key := md.GetString(privateKeyFile); key != "" { + if key := mdata.GetString(md, privateKeyFile); key != "" { data, err := ioutil.ReadFile(key) if err != nil { return err } - pp := md.GetString(passphrase) + pp := mdata.GetString(md, passphrase) if pp == "" { l.md.signer, err = ssh.ParsePrivateKey(data) } else { @@ -70,7 +68,7 @@ func (l *sshListener) parseMetadata(md md.Metadata) (err error) { l.md.signer = signer } - if name := md.GetString(authorizedKeys); name != "" { + if name := mdata.GetString(md, authorizedKeys); name != "" { m, err := ssh_util.ParseAuthorizedKeysFile(name) if err != nil { return err @@ -78,7 +76,7 @@ func (l *sshListener) parseMetadata(md md.Metadata) (err error) { l.md.authorizedKeys = m } - l.md.backlog = md.GetInt(backlog) + l.md.backlog = mdata.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } diff --git a/pkg/listener/tap/listener.go b/pkg/listener/tap/listener.go new file mode 100644 index 0000000..7333099 --- /dev/null +++ b/pkg/listener/tap/listener.go @@ -0,0 +1,96 @@ +package tap + +import ( + "net" + + tap_util "github.com/go-gost/gost/pkg/internal/util/tap" + "github.com/go-gost/gost/pkg/listener" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" +) + +func init() { + registry.RegisterListener("tap", NewListener) +} + +type tapListener struct { + saddr string + addr net.Addr + cqueue chan net.Conn + closed chan struct{} + logger logger.Logger + md metadata +} + +func NewListener(opts ...listener.Option) listener.Listener { + options := &listener.Options{} + for _, opt := range opts { + opt(options) + } + return &tapListener{ + saddr: options.Addr, + logger: options.Logger, + } +} + +func (l *tapListener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { + return + } + + l.addr, err = net.ResolveUDPAddr("udp", l.saddr) + if err != nil { + return + } + + ifce, ip, err := l.createTap() + if err != nil { + if ifce != nil { + ifce.Close() + } + return + } + + itf, err := net.InterfaceByName(ifce.Name()) + if err != nil { + return + } + + addrs, _ := itf.Addrs() + l.logger.Infof("name: %s, mac: %s, mtu: %d, addrs: %s", + itf.Name, itf.HardwareAddr, itf.MTU, addrs) + + l.cqueue = make(chan net.Conn, 1) + l.closed = make(chan struct{}) + + conn := tap_util.NewConn(l.md.config, ifce, l.addr, &net.IPAddr{IP: ip}) + + l.cqueue <- conn + + return +} + +func (l *tapListener) Accept() (net.Conn, error) { + select { + case conn := <-l.cqueue: + return conn, nil + case <-l.closed: + } + + return nil, listener.ErrClosed +} + +func (l *tapListener) Addr() net.Addr { + return l.addr +} + +func (l *tapListener) Close() error { + select { + case <-l.closed: + return net.ErrClosed + default: + close(l.closed) + } + return nil +} diff --git a/pkg/listener/tap/metadata.go b/pkg/listener/tap/metadata.go new file mode 100644 index 0000000..496e4e5 --- /dev/null +++ b/pkg/listener/tap/metadata.go @@ -0,0 +1,44 @@ +package tap + +import ( + tap_util "github.com/go-gost/gost/pkg/internal/util/tap" + mdata "github.com/go-gost/gost/pkg/metadata" +) + +const ( + DefaultMTU = 1350 +) + +type metadata struct { + config *tap_util.Config +} + +func (l *tapListener) parseMetadata(md mdata.Metadata) (err error) { + const ( + name = "name" + netKey = "net" + mtu = "mtu" + routes = "routes" + gateway = "gw" + ) + + config := &tap_util.Config{ + Name: mdata.GetString(md, name), + Net: mdata.GetString(md, netKey), + MTU: mdata.GetInt(md, mtu), + Gateway: mdata.GetString(md, gateway), + } + if config.MTU <= 0 { + config.MTU = DefaultMTU + } + + for _, s := range mdata.GetStrings(md, routes) { + if s != "" { + config.Routes = append(config.Routes, s) + } + } + + l.md.config = config + + return +} diff --git a/pkg/listener/tap/tap_darwin.go b/pkg/listener/tap/tap_darwin.go new file mode 100644 index 0000000..69cfe1a --- /dev/null +++ b/pkg/listener/tap/tap_darwin.go @@ -0,0 +1,13 @@ +package tap + +import ( + "errors" + "net" + + "github.com/songgao/water" +) + +func (l *tapListener) createTap() (ifce *water.Interface, ip net.IP, err error) { + err = errors.New("tap is not supported on darwin") + return +} diff --git a/pkg/listener/tap/tap_linux.go b/pkg/listener/tap/tap_linux.go new file mode 100644 index 0000000..dac7f10 --- /dev/null +++ b/pkg/listener/tap/tap_linux.go @@ -0,0 +1,69 @@ +package tap + +import ( + "net" + + "github.com/docker/libcontainer/netlink" + "github.com/milosgajdos/tenus" + "github.com/songgao/water" +) + +func (l *tapListener) createTap() (ifce *water.Interface, ip net.IP, err error) { + var ipNet *net.IPNet + if l.md.config.Net != "" { + ip, ipNet, err = net.ParseCIDR(l.md.config.Net) + if err != nil { + return + } + } + + ifce, err = water.New(water.Config{ + DeviceType: water.TAP, + PlatformSpecificParams: water.PlatformSpecificParams{ + Name: l.md.config.Name, + }, + }) + if err != nil { + return + } + + link, err := tenus.NewLinkFrom(ifce.Name()) + if err != nil { + return + } + + l.logger.Debugf("ip link set dev %s mtu %d", ifce.Name(), l.md.config.MTU) + + if err = link.SetLinkMTU(l.md.config.MTU); err != nil { + return + } + + if l.md.config.Net != "" { + l.logger.Debugf("ip address add %s dev %s", l.md.config.Net, ifce.Name()) + + if err = link.SetLinkIp(ip, ipNet); err != nil { + return + } + } + + l.logger.Debugf("ip link set dev %s up", ifce.Name()) + if err = link.SetLinkUp(); err != nil { + return + } + + if err = l.addRoutes(ifce.Name(), l.md.config.Gateway, l.md.config.Routes...); err != nil { + return + } + + return +} + +func (l *tapListener) addRoutes(ifName string, gw string, routes ...string) error { + for _, route := range routes { + l.logger.Debugf("ip route add %s via %s dev %s", route, gw, ifName) + if err := netlink.AddRoute(route, "", gw, ifName); err != nil { + return err + } + } + return nil +} diff --git a/pkg/listener/tap/tap_unix.go b/pkg/listener/tap/tap_unix.go new file mode 100644 index 0000000..cebd27d --- /dev/null +++ b/pkg/listener/tap/tap_unix.go @@ -0,0 +1,61 @@ +//go:build !linux && !windows && !darwin + +package tap + +import ( + "fmt" + "net" + "os/exec" + "strings" + + "github.com/songgao/water" +) + +func (l *tapListener) createTap() (ifce *water.Interface, ip net.IP, err error) { + ip, _, _ = net.ParseCIDR(l.md.config.Net) + + ifce, err = water.New(water.Config{ + DeviceType: water.TAP, + }) + if err != nil { + return + } + + var cmd string + if l.md.config.Net != "" { + cmd = fmt.Sprintf("ifconfig %s inet %s mtu %d up", ifce.Name(), l.md.config.Net, l.md.config.MTU) + } else { + cmd = fmt.Sprintf("ifconfig %s mtu %d up", ifce.Name(), l.md.config.MTU) + } + l.logger.Debug(cmd) + + args := strings.Split(cmd, " ") + if er := exec.Command(args[0], args[1:]...).Run(); er != nil { + err = fmt.Errorf("%s: %v", cmd, er) + return + } + + if err = l.addRoutes(ifce.Name(), l.md.config.Gateway, l.md.config.Routes...); err != nil { + return + } + + return +} + +func (l *tapListener) addRoutes(ifName string, gw string, routes ...string) error { + for _, route := range routes { + if route == "" { + continue + } + cmd := fmt.Sprintf("route add -net %s dev %s", route, ifName) + if gw != "" { + cmd += " gw " + gw + } + l.logger.Debug(cmd) + args := strings.Split(cmd, " ") + if er := exec.Command(args[0], args[1:]...).Run(); er != nil { + return fmt.Errorf("%s: %v", cmd, er) + } + } + return nil +} diff --git a/pkg/listener/tap/tap_windows.go b/pkg/listener/tap/tap_windows.go new file mode 100644 index 0000000..e3a8936 --- /dev/null +++ b/pkg/listener/tap/tap_windows.go @@ -0,0 +1,75 @@ +package tap + +import ( + "fmt" + "net" + "os/exec" + "strings" + + "github.com/songgao/water" +) + +func (l *tapListener) createTap() (ifce *water.Interface, ip net.IP, err error) { + ip, ipNet, _ := net.ParseCIDR(l.md.config.Net) + + ifce, err = water.New(water.Config{ + DeviceType: water.TAP, + PlatformSpecificParams: water.PlatformSpecificParams{ + ComponentID: "tap0901", + InterfaceName: l.md.config.Name, + Network: l.md.config.Net, + }, + }) + if err != nil { + return + } + + if ip != nil && ipNet != nil { + cmd := fmt.Sprintf("netsh interface ip set address name=%s "+ + "source=static addr=%s mask=%s gateway=none", + ifce.Name(), ip.String(), ipMask(ipNet.Mask)) + l.logger.Debug(cmd) + + args := strings.Split(cmd, " ") + if er := exec.Command(args[0], args[1:]...).Run(); er != nil { + err = fmt.Errorf("%s: %v", cmd, er) + return + } + } + + if err = l.addRoutes(ifce.Name(), l.md.config.Gateway, l.md.config.Routes...); err != nil { + return + } + + return +} + +func (l *tapListener) addRoutes(ifName string, gw string, routes ...string) error { + for _, route := range routes { + l.deleteRoute(ifName, route) + + cmd := fmt.Sprintf("netsh interface ip add route prefix=%s interface=%s store=active", + route, ifName) + if gw != "" { + cmd += " nexthop=" + gw + } + l.logger.Debug(cmd) + args := strings.Split(cmd, " ") + if er := exec.Command(args[0], args[1:]...).Run(); er != nil { + return fmt.Errorf("%s: %v", cmd, er) + } + } + return nil +} + +func (l *tapListener) deleteRoute(ifName string, route string) error { + cmd := fmt.Sprintf("netsh interface ip delete route prefix=%s interface=%s store=active", + route, ifName) + l.logger.Debug(cmd) + args := strings.Split(cmd, " ") + return exec.Command(args[0], args[1:]...).Run() +} + +func ipMask(mask net.IPMask) string { + return fmt.Sprintf("%d.%d.%d.%d", mask[0], mask[1], mask[2], mask[3]) +} diff --git a/pkg/listener/tls/metadata.go b/pkg/listener/tls/metadata.go index f37733a..d5067c2 100644 --- a/pkg/listener/tls/metadata.go +++ b/pkg/listener/tls/metadata.go @@ -4,14 +4,14 @@ import ( "crypto/tls" tls_util "github.com/go-gost/gost/pkg/common/util/tls" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { tlsConfig *tls.Config } -func (l *tlsListener) parseMetadata(md md.Metadata) (err error) { +func (l *tlsListener) parseMetadata(md mdata.Metadata) (err error) { const ( certFile = "certFile" keyFile = "keyFile" @@ -19,9 +19,9 @@ func (l *tlsListener) parseMetadata(md md.Metadata) (err error) { ) l.md.tlsConfig, err = tls_util.LoadServerConfig( - md.GetString(certFile), - md.GetString(keyFile), - md.GetString(caFile), + mdata.GetString(md, certFile), + mdata.GetString(md, keyFile), + mdata.GetString(md, caFile), ) if err != nil { return diff --git a/pkg/listener/tls/mux/metadata.go b/pkg/listener/tls/mux/metadata.go index c59d429..9a2119a 100644 --- a/pkg/listener/tls/mux/metadata.go +++ b/pkg/listener/tls/mux/metadata.go @@ -5,7 +5,7 @@ import ( "time" tls_util "github.com/go-gost/gost/pkg/common/util/tls" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) const ( @@ -25,7 +25,7 @@ type metadata struct { backlog int } -func (l *mtlsListener) parseMetadata(md md.Metadata) (err error) { +func (l *mtlsListener) parseMetadata(md mdata.Metadata) (err error) { const ( certFile = "certFile" keyFile = "keyFile" @@ -42,25 +42,25 @@ func (l *mtlsListener) parseMetadata(md md.Metadata) (err error) { ) l.md.tlsConfig, err = tls_util.LoadServerConfig( - md.GetString(certFile), - md.GetString(keyFile), - md.GetString(caFile), + mdata.GetString(md, certFile), + mdata.GetString(md, keyFile), + mdata.GetString(md, caFile), ) if err != nil { return } - l.md.backlog = md.GetInt(backlog) + l.md.backlog = mdata.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } - l.md.muxKeepAliveDisabled = md.GetBool(muxKeepAliveDisabled) - l.md.muxKeepAliveInterval = md.GetDuration(muxKeepAliveInterval) - l.md.muxKeepAliveTimeout = md.GetDuration(muxKeepAliveTimeout) - l.md.muxMaxFrameSize = md.GetInt(muxMaxFrameSize) - l.md.muxMaxReceiveBuffer = md.GetInt(muxMaxReceiveBuffer) - l.md.muxMaxStreamBuffer = md.GetInt(muxMaxStreamBuffer) + l.md.muxKeepAliveDisabled = mdata.GetBool(md, muxKeepAliveDisabled) + l.md.muxKeepAliveInterval = mdata.GetDuration(md, muxKeepAliveInterval) + l.md.muxKeepAliveTimeout = mdata.GetDuration(md, muxKeepAliveTimeout) + l.md.muxMaxFrameSize = mdata.GetInt(md, muxMaxFrameSize) + l.md.muxMaxReceiveBuffer = mdata.GetInt(md, muxMaxReceiveBuffer) + l.md.muxMaxStreamBuffer = mdata.GetInt(md, muxMaxStreamBuffer) return } diff --git a/pkg/listener/tun/conn.go b/pkg/listener/tun/conn.go deleted file mode 100644 index 56f3283..0000000 --- a/pkg/listener/tun/conn.go +++ /dev/null @@ -1,46 +0,0 @@ -package tun - -import ( - "errors" - "net" - "time" - - "github.com/songgao/water" -) - -type tunConn struct { - ifce *water.Interface - addr net.Addr -} - -func (c *tunConn) Read(b []byte) (n int, err error) { - return c.ifce.Read(b) -} - -func (c *tunConn) Write(b []byte) (n int, err error) { - return c.ifce.Write(b) -} - -func (c *tunConn) Close() (err error) { - return c.ifce.Close() -} - -func (c *tunConn) LocalAddr() net.Addr { - return c.addr -} - -func (c *tunConn) RemoteAddr() net.Addr { - return &net.IPAddr{} -} - -func (c *tunConn) SetDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "tuntap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} -} - -func (c *tunConn) SetReadDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "tuntap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} -} - -func (c *tunConn) SetWriteDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "tuntap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} -} diff --git a/pkg/listener/tun/listener.go b/pkg/listener/tun/listener.go index 3112d3c..683bb82 100644 --- a/pkg/listener/tun/listener.go +++ b/pkg/listener/tun/listener.go @@ -3,18 +3,13 @@ package tun import ( "net" + tun_util "github.com/go-gost/gost/pkg/internal/util/tun" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" ) -// ipRoute is an IP routing entry -type ipRoute struct { - Dest net.IPNet - Gateway net.IP -} - func init() { registry.RegisterListener("tun", NewListener) } @@ -44,19 +39,33 @@ func (l *tunListener) Init(md md.Metadata) (err error) { return } - conn, ifce, err := l.createTun() + l.addr, err = net.ResolveUDPAddr("udp", l.saddr) if err != nil { return } - addrs, _ := ifce.Addrs() - l.logger.Infof("name: %s, net: %s, mtu: %d, addrs: %s", - ifce.Name, conn.LocalAddr(), ifce.MTU, addrs) + ifce, ip, err := l.createTun() + if err != nil { + if ifce != nil { + ifce.Close() + } + return + } + + itf, err := net.InterfaceByName(ifce.Name()) + if err != nil { + return + } + + addrs, _ := itf.Addrs() + l.logger.Infof("name: %s, net: %s, mtu: %d, addrs: %s", + itf.Name, ip, itf.MTU, addrs) - l.addr = conn.LocalAddr() l.cqueue = make(chan net.Conn, 1) l.closed = make(chan struct{}) + conn := tun_util.NewConn(l.md.config, ifce, l.addr, &net.IPAddr{IP: ip}) + l.cqueue <- conn return diff --git a/pkg/listener/tun/metadata.go b/pkg/listener/tun/metadata.go index 1b269ca..2c4343a 100644 --- a/pkg/listener/tun/metadata.go +++ b/pkg/listener/tun/metadata.go @@ -4,7 +4,8 @@ import ( "net" "strings" - md "github.com/go-gost/gost/pkg/metadata" + tun_util "github.com/go-gost/gost/pkg/internal/util/tun" + mdata "github.com/go-gost/gost/pkg/metadata" ) const ( @@ -12,18 +13,10 @@ const ( ) type metadata struct { - name string - net string - // peer addr of point-to-point on MacOS - peer string - mtu int - routes []ipRoute - // default gateway - gateway string - tcp bool + config *tun_util.Config } -func (l *tunListener) parseMetadata(md md.Metadata) (err error) { +func (l *tunListener) parseMetadata(md mdata.Metadata) (err error) { const ( name = "name" netKey = "net" @@ -31,40 +24,40 @@ func (l *tunListener) parseMetadata(md md.Metadata) (err error) { mtu = "mtu" routes = "routes" gateway = "gw" - tcp = "tcp" ) - l.md.name = md.GetString(name) - l.md.net = md.GetString(netKey) - l.md.peer = md.GetString(peer) - l.md.mtu = md.GetInt(mtu) - - if l.md.mtu <= 0 { - l.md.mtu = DefaultMTU + config := &tun_util.Config{ + Name: mdata.GetString(md, name), + Net: mdata.GetString(md, netKey), + Peer: mdata.GetString(md, peer), + MTU: mdata.GetInt(md, mtu), + Gateway: mdata.GetString(md, gateway), + } + if config.MTU <= 0 { + config.MTU = DefaultMTU } - l.md.gateway = md.GetString(gateway) - l.md.tcp = md.GetBool(tcp) + gw := net.ParseIP(config.Gateway) - gw := net.ParseIP(l.md.gateway) - - for _, s := range md.GetStrings(routes) { + for _, s := range mdata.GetStrings(md, routes) { ss := strings.SplitN(s, " ", 2) if len(ss) == 2 { - var route ipRoute + var route tun_util.Route _, ipNet, _ := net.ParseCIDR(strings.TrimSpace(ss[0])) if ipNet == nil { continue } - route.Dest = *ipNet + route.Net = *ipNet route.Gateway = net.ParseIP(ss[1]) if route.Gateway == nil { route.Gateway = gw } - l.md.routes = append(l.md.routes, route) + config.Routes = append(config.Routes, route) } } + l.md.config = config + return } diff --git a/pkg/listener/tun/tun_darwin.go b/pkg/listener/tun/tun_darwin.go index 30091d9..0e569e9 100644 --- a/pkg/listener/tun/tun_darwin.go +++ b/pkg/listener/tun/tun_darwin.go @@ -6,29 +6,30 @@ import ( "os/exec" "strings" + tun_util "github.com/go-gost/gost/pkg/internal/util/tun" "github.com/songgao/water" ) -func (l *tunListener) createTun() (conn net.Conn, itf *net.Interface, err error) { - ip, _, err := net.ParseCIDR(l.md.net) +func (l *tunListener) createTun() (ifce *water.Interface, ip net.IP, err error) { + ip, _, err = net.ParseCIDR(l.md.config.Net) if err != nil { return } - ifce, err := water.New(water.Config{ + ifce, err = water.New(water.Config{ DeviceType: water.TUN, }) if err != nil { return } - peer := l.md.peer + peer := l.md.config.Peer if peer == "" { peer = ip.String() } cmd := fmt.Sprintf("ifconfig %s inet %s %s mtu %d up", - ifce.Name(), l.md.net, l.md.peer, l.md.mtu) + ifce.Name(), l.md.config.Net, l.md.config.Peer, l.md.config.MTU) l.logger.Debug(cmd) args := strings.Split(cmd, " ") @@ -36,25 +37,16 @@ func (l *tunListener) createTun() (conn net.Conn, itf *net.Interface, err error) return } - if err = l.addRoutes(ifce.Name(), l.md.routes...); err != nil { + if err = l.addRoutes(ifce.Name(), l.md.config.Routes...); err != nil { return } - itf, err = net.InterfaceByName(ifce.Name()) - if err != nil { - return - } - - conn = &tunConn{ - ifce: ifce, - addr: &net.IPAddr{IP: ip}, - } return } -func (l *tunListener) addRoutes(ifName string, routes ...ipRoute) error { +func (l *tunListener) addRoutes(ifName string, routes ...tun_util.Route) error { for _, route := range routes { - cmd := fmt.Sprintf("route add -net %s -interface %s", route.Dest.String(), ifName) + cmd := fmt.Sprintf("route add -net %s -interface %s", route.Net.String(), ifName) l.logger.Debug(cmd) args := strings.Split(cmd, " ") if err := exec.Command(args[0], args[1:]...).Run(); err != nil { diff --git a/pkg/listener/tun/tun_linux.go b/pkg/listener/tun/tun_linux.go index 891b52f..9f9cacf 100644 --- a/pkg/listener/tun/tun_linux.go +++ b/pkg/listener/tun/tun_linux.go @@ -6,20 +6,21 @@ import ( "syscall" "github.com/docker/libcontainer/netlink" + tun_util "github.com/go-gost/gost/pkg/internal/util/tun" "github.com/milosgajdos/tenus" "github.com/songgao/water" ) -func (l *tunListener) createTun() (conn net.Conn, itf *net.Interface, err error) { - ip, ipNet, err := net.ParseCIDR(l.md.net) +func (l *tunListener) createTun() (ifce *water.Interface, ip net.IP, err error) { + ip, ipNet, err := net.ParseCIDR(l.md.config.Net) if err != nil { return } - ifce, err := water.New(water.Config{ + ifce, err = water.New(water.Config{ DeviceType: water.TUN, PlatformSpecificParams: water.PlatformSpecificParams{ - Name: l.md.name, + Name: l.md.config.Name, }, }) if err != nil { @@ -31,13 +32,13 @@ func (l *tunListener) createTun() (conn net.Conn, itf *net.Interface, err error) return } - l.logger.Debugf("ip link set dev %s mtu %d", ifce.Name(), l.md.mtu) + l.logger.Debugf("ip link set dev %s mtu %d", ifce.Name(), l.md.config.MTU) - if err = link.SetLinkMTU(l.md.mtu); err != nil { + if err = link.SetLinkMTU(l.md.config.MTU); err != nil { return } - l.logger.Debugf("ip address add %s dev %s", l.md.net, ifce.Name()) + l.logger.Debugf("ip address add %s dev %s", l.md.config.Net, ifce.Name()) if err = link.SetLinkIp(ip, ipNet); err != nil { return @@ -48,26 +49,17 @@ func (l *tunListener) createTun() (conn net.Conn, itf *net.Interface, err error) return } - if err = l.addRoutes(ifce.Name(), l.md.routes...); err != nil { + if err = l.addRoutes(ifce.Name(), l.md.config.Routes...); err != nil { return } - itf, err = net.InterfaceByName(ifce.Name()) - if err != nil { - return - } - - conn = &tunConn{ - ifce: ifce, - addr: &net.IPAddr{IP: ip}, - } return } -func (l *tunListener) addRoutes(ifName string, routes ...ipRoute) error { +func (l *tunListener) addRoutes(ifName string, routes ...tun_util.Route) error { for _, route := range routes { - l.logger.Debugf("ip route add %s dev %s", route.Dest.String(), ifName) - if err := netlink.AddRoute(route.Dest.String(), "", "", ifName); err != nil && !errors.Is(err, syscall.EEXIST) { + l.logger.Debugf("ip route add %s dev %s", route.Net.String(), ifName) + if err := netlink.AddRoute(route.Net.String(), "", "", ifName); err != nil && !errors.Is(err, syscall.EEXIST) { return err } } diff --git a/pkg/listener/tun/tun_unix.go b/pkg/listener/tun/tun_unix.go index 24a1a1b..cc8614b 100644 --- a/pkg/listener/tun/tun_unix.go +++ b/pkg/listener/tun/tun_unix.go @@ -8,16 +8,17 @@ import ( "os/exec" "strings" + tun_util "github.com/go-gost/gost/pkg/internal/util/tun" "github.com/songgao/water" ) -func (l *tunListener) createTun() (conn net.Conn, itf *net.Interface, err error) { - ip, _, err := net.ParseCIDR(l.md.net) +func (l *tunListener) createTun() (ifce *water.Interface, ip net.IP, err error) { + ip, _, err = net.ParseCIDR(l.md.config.Net) if err != nil { return } - ifce, err := water.New(water.Config{ + ifce, err = water.New(water.Config{ DeviceType: water.TUN, }) if err != nil { @@ -25,33 +26,25 @@ func (l *tunListener) createTun() (conn net.Conn, itf *net.Interface, err error) } cmd := fmt.Sprintf("ifconfig %s inet %s mtu %d up", - ifce.Name(), l.md.net, l.md.mtu) + ifce.Name(), l.md.config.Net, l.md.config.MTU) l.logger.Debug(cmd) + args := strings.Split(cmd, " ") if er := exec.Command(args[0], args[1:]...).Run(); er != nil { err = fmt.Errorf("%s: %v", cmd, er) return } - if err = l.addRoutes(ifce.Name(), l.md.routes...); err != nil { + if err = l.addRoutes(ifce.Name(), l.md.config.Routes...); err != nil { return } - itf, err = net.InterfaceByName(ifce.Name()) - if err != nil { - return - } - - conn = &tunConn{ - ifce: ifce, - addr: &net.IPAddr{IP: ip}, - } return } -func (l *tunListener) addRoutes(ifName string, routes ...ipRoute) error { +func (l *tunListener) addRoutes(ifName string, routes ...tun_util.Route) error { for _, route := range routes { - cmd := fmt.Sprintf("route add -net %s -interface %s", route.Dest.String(), ifName) + cmd := fmt.Sprintf("route add -net %s -interface %s", route.Net.String(), ifName) l.logger.Debug(cmd) args := strings.Split(cmd, " ") if er := exec.Command(args[0], args[1:]...).Run(); er != nil { diff --git a/pkg/listener/tun/tun_windows.go b/pkg/listener/tun/tun_windows.go index 2451847..1b795fa 100644 --- a/pkg/listener/tun/tun_windows.go +++ b/pkg/listener/tun/tun_windows.go @@ -6,21 +6,22 @@ import ( "os/exec" "strings" + tun_util "github.com/go-gost/gost/pkg/internal/util/tun" "github.com/songgao/water" ) -func (l *tunListener) createTun() (conn net.Conn, itf *net.Interface, err error) { - ip, ipNet, err := net.ParseCIDR(l.md.net) +func (l *tunListener) createTun() (ifce *water.Interface, ip net.IP, err error) { + ip, ipNet, err := net.ParseCIDR(l.md.config.Net) if err != nil { return } - ifce, err := water.New(water.Config{ + ifce, err = water.New(water.Config{ DeviceType: water.TUN, PlatformSpecificParams: water.PlatformSpecificParams{ ComponentID: "tap0901", - InterfaceName: l.md.name, - Network: l.md.net, + InterfaceName: l.md.config.Name, + Network: l.md.config.Net, }, }) if err != nil { @@ -38,28 +39,19 @@ func (l *tunListener) createTun() (conn net.Conn, itf *net.Interface, err error) return } - if err = l.addRoutes(ifce.Name(), l.md.gateway, l.md.routes...); err != nil { + if err = l.addRoutes(ifce.Name(), l.md.config.Gateway, l.md.config.Routes...); err != nil { return } - itf, err = net.InterfaceByName(ifce.Name()) - if err != nil { - return - } - - conn = &tunConn{ - ifce: ifce, - addr: &net.IPAddr{IP: ip}, - } return } -func (l *tunListener) addRoutes(ifName string, gw string, routes ...ipRoute) error { +func (l *tunListener) addRoutes(ifName string, gw string, routes ...tun_util.Route) error { for _, route := range routes { - l.deleteRoute(ifName, route.Dest.String()) + l.deleteRoute(ifName, route.Net.String()) cmd := fmt.Sprintf("netsh interface ip add route prefix=%s interface=%s store=active", - route.Dest.String(), ifName) + route.Net.String(), ifName) if gw != "" { cmd += " nexthop=" + gw } diff --git a/pkg/listener/udp/metadata.go b/pkg/listener/udp/metadata.go index e1c2c94..3d99759 100644 --- a/pkg/listener/udp/metadata.go +++ b/pkg/listener/udp/metadata.go @@ -3,7 +3,7 @@ package udp import ( "time" - md "github.com/go-gost/gost/pkg/metadata" + mdata "github.com/go-gost/gost/pkg/metadata" ) const ( @@ -21,7 +21,7 @@ type metadata struct { backlog int } -func (l *udpListener) parseMetadata(md md.Metadata) (err error) { +func (l *udpListener) parseMetadata(md mdata.Metadata) (err error) { const ( ttl = "ttl" readBufferSize = "readBufferSize" @@ -29,21 +29,21 @@ func (l *udpListener) parseMetadata(md md.Metadata) (err error) { backlog = "backlog" ) - l.md.ttl = md.GetDuration(ttl) + l.md.ttl = mdata.GetDuration(md, ttl) if l.md.ttl <= 0 { l.md.ttl = defaultTTL } - l.md.readBufferSize = md.GetInt(readBufferSize) + l.md.readBufferSize = mdata.GetInt(md, readBufferSize) if l.md.readBufferSize <= 0 { l.md.readBufferSize = defaultReadBufferSize } - l.md.readQueueSize = md.GetInt(readQueueSize) + l.md.readQueueSize = mdata.GetInt(md, readQueueSize) if l.md.readQueueSize <= 0 { l.md.readQueueSize = defaultReadQueueSize } - l.md.backlog = md.GetInt(backlog) + l.md.backlog = mdata.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } diff --git a/pkg/listener/ws/metadata.go b/pkg/listener/ws/metadata.go index 51b87a2..f45219a 100644 --- a/pkg/listener/ws/metadata.go +++ b/pkg/listener/ws/metadata.go @@ -47,29 +47,29 @@ func (l *wsListener) parseMetadata(md mdata.Metadata) (err error) { ) l.md.tlsConfig, err = tls_util.LoadServerConfig( - md.GetString(certFile), - md.GetString(keyFile), - md.GetString(caFile), + mdata.GetString(md, certFile), + mdata.GetString(md, keyFile), + mdata.GetString(md, caFile), ) if err != nil { return } - l.md.path = md.GetString(path) + l.md.path = mdata.GetString(md, path) if l.md.path == "" { l.md.path = defaultPath } - l.md.backlog = md.GetInt(backlog) + l.md.backlog = mdata.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } - l.md.handshakeTimeout = md.GetDuration(handshakeTimeout) - l.md.readHeaderTimeout = md.GetDuration(readHeaderTimeout) - l.md.readBufferSize = md.GetInt(readBufferSize) - l.md.writeBufferSize = md.GetInt(writeBufferSize) - l.md.enableCompression = md.GetBool(enableCompression) + l.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) + l.md.readHeaderTimeout = mdata.GetDuration(md, readHeaderTimeout) + l.md.readBufferSize = mdata.GetInt(md, readBufferSize) + l.md.writeBufferSize = mdata.GetInt(md, writeBufferSize) + l.md.enableCompression = mdata.GetBool(md, enableCompression) if mm := mdata.GetStringMapString(md, header); len(mm) > 0 { hd := http.Header{} diff --git a/pkg/listener/ws/mux/metadata.go b/pkg/listener/ws/mux/metadata.go index 3da2b4a..e17f2ab 100644 --- a/pkg/listener/ws/mux/metadata.go +++ b/pkg/listener/ws/mux/metadata.go @@ -59,36 +59,36 @@ func (l *mwsListener) parseMetadata(md mdata.Metadata) (err error) { ) l.md.tlsConfig, err = tls_util.LoadServerConfig( - md.GetString(certFile), - md.GetString(keyFile), - md.GetString(caFile), + mdata.GetString(md, certFile), + mdata.GetString(md, keyFile), + mdata.GetString(md, caFile), ) if err != nil { return } - l.md.path = md.GetString(path) + l.md.path = mdata.GetString(md, path) if l.md.path == "" { l.md.path = defaultPath } - l.md.backlog = md.GetInt(backlog) + l.md.backlog = mdata.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } - l.md.handshakeTimeout = md.GetDuration(handshakeTimeout) - l.md.readHeaderTimeout = md.GetDuration(readHeaderTimeout) - l.md.readBufferSize = md.GetInt(readBufferSize) - l.md.writeBufferSize = md.GetInt(writeBufferSize) - l.md.enableCompression = md.GetBool(enableCompression) + l.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) + l.md.readHeaderTimeout = mdata.GetDuration(md, readHeaderTimeout) + l.md.readBufferSize = mdata.GetInt(md, readBufferSize) + l.md.writeBufferSize = mdata.GetInt(md, writeBufferSize) + l.md.enableCompression = mdata.GetBool(md, enableCompression) - l.md.muxKeepAliveDisabled = md.GetBool(muxKeepAliveDisabled) - l.md.muxKeepAliveInterval = md.GetDuration(muxKeepAliveInterval) - l.md.muxKeepAliveTimeout = md.GetDuration(muxKeepAliveTimeout) - l.md.muxMaxFrameSize = md.GetInt(muxMaxFrameSize) - l.md.muxMaxReceiveBuffer = md.GetInt(muxMaxReceiveBuffer) - l.md.muxMaxStreamBuffer = md.GetInt(muxMaxStreamBuffer) + l.md.muxKeepAliveDisabled = mdata.GetBool(md, muxKeepAliveDisabled) + l.md.muxKeepAliveInterval = mdata.GetDuration(md, muxKeepAliveInterval) + l.md.muxKeepAliveTimeout = mdata.GetDuration(md, muxKeepAliveTimeout) + l.md.muxMaxFrameSize = mdata.GetInt(md, muxMaxFrameSize) + l.md.muxMaxReceiveBuffer = mdata.GetInt(md, muxMaxReceiveBuffer) + l.md.muxMaxStreamBuffer = mdata.GetInt(md, muxMaxStreamBuffer) if mm := mdata.GetStringMapString(md, header); len(mm) > 0 { hd := http.Header{} diff --git a/pkg/metadata/metadata.go b/pkg/metadata/metadata.go index cba9f33..93f9fee 100644 --- a/pkg/metadata/metadata.go +++ b/pkg/metadata/metadata.go @@ -10,12 +10,6 @@ type Metadata interface { IsExists(key string) bool Set(key string, value interface{}) Get(key string) interface{} - GetBool(key string) bool - GetInt(key string) int - GetFloat(key string) float64 - GetDuration(key string) time.Duration - GetString(key string) string - GetStrings(key string) []string } type MapMetadata map[string]interface{} @@ -36,11 +30,11 @@ func (m MapMetadata) Get(key string) interface{} { return nil } -func (m MapMetadata) GetBool(key string) (v bool) { - if m == nil || !m.IsExists(key) { +func GetBool(md Metadata, key string) (v bool) { + if md == nil || !md.IsExists(key) { return } - switch vv := m[key].(type) { + switch vv := md.Get(key).(type) { case bool: return vv case int: @@ -52,8 +46,12 @@ func (m MapMetadata) GetBool(key string) (v bool) { return } -func (m MapMetadata) GetInt(key string) (v int) { - switch vv := m[key].(type) { +func GetInt(md Metadata, key string) (v int) { + if md == nil { + return + } + + switch vv := md.Get(key).(type) { case bool: if vv { v = 1 @@ -67,8 +65,12 @@ func (m MapMetadata) GetInt(key string) (v int) { return } -func (m MapMetadata) GetFloat(key string) (v float64) { - switch vv := m[key].(type) { +func GetFloat(md Metadata, key string) (v float64) { + if md == nil { + return + } + + switch vv := md.Get(key).(type) { case int: return float64(vv) case string: @@ -78,27 +80,28 @@ func (m MapMetadata) GetFloat(key string) (v float64) { return } -func (m MapMetadata) GetDuration(key string) (v time.Duration) { - if m != nil { - switch vv := m[key].(type) { - case int: - return time.Duration(vv) * time.Second - case string: - v, _ = time.ParseDuration(vv) - } +func GetDuration(md Metadata, key string) (v time.Duration) { + if md == nil { + return + } + switch vv := md.Get(key).(type) { + case int: + return time.Duration(vv) * time.Second + case string: + v, _ = time.ParseDuration(vv) } return } -func (m MapMetadata) GetString(key string) (v string) { - if m != nil { - v, _ = m[key].(string) +func GetString(md Metadata, key string) (v string) { + if md != nil { + v, _ = md.Get(key).(string) } return } -func (m MapMetadata) GetStrings(key string) (ss []string) { - if v, _ := m.Get(key).([]interface{}); len(v) > 0 { +func GetStrings(md Metadata, key string) (ss []string) { + if v, _ := md.Get(key).([]interface{}); len(v) > 0 { for _, vv := range v { if s, ok := vv.(string); ok { ss = append(ss, s) @@ -108,10 +111,29 @@ func (m MapMetadata) GetStrings(key string) (ss []string) { return } +func GetStringMap(md Metadata, key string) (m map[string]interface{}) { + switch vv := md.Get(key).(type) { + case map[string]interface{}: + return vv + case map[interface{}]interface{}: + m = make(map[string]interface{}) + for k, v := range vv { + m[fmt.Sprintf("%v", k)] = v + } + } + return +} + func GetStringMapString(md Metadata, key string) (m map[string]string) { - if mm, _ := md.Get(key).(map[interface{}]interface{}); len(mm) > 0 { + switch vv := md.Get(key).(type) { + case map[string]interface{}: m = make(map[string]string) - for k, v := range mm { + for k, v := range vv { + m[k] = fmt.Sprintf("%v", v) + } + case map[interface{}]interface{}: + m = make(map[string]string) + for k, v := range vv { m[fmt.Sprintf("%v", k)] = fmt.Sprintf("%v", v) } }