diff --git a/admission/admission.go b/admission/admission.go new file mode 100644 index 0000000..a821a55 --- /dev/null +++ b/admission/admission.go @@ -0,0 +1,86 @@ +package admission + +import ( + "net" + "strconv" + + admission_pkg "github.com/go-gost/core/admission" + "github.com/go-gost/core/logger" + "github.com/go-gost/x/internal/util/matcher" +) + +type options struct { + logger logger.Logger +} + +type Option func(opts *options) + +func LoggerOption(logger logger.Logger) Option { + return func(opts *options) { + opts.logger = logger + } +} + +type admission struct { + matchers []matcher.Matcher + reversed bool + options options +} + +// NewAdmission creates and initializes a new Admission using matchers as its match rules. +// The rules will be reversed if the reversed is true. +func NewAdmission(reversed bool, matchers []matcher.Matcher, opts ...Option) admission_pkg.Admission { + options := options{} + for _, opt := range opts { + opt(&options) + } + return &admission{ + matchers: matchers, + reversed: reversed, + options: options, + } +} + +// NewAdmissionPatterns creates and initializes a new Admission using matcher patterns as its match rules. +// The rules will be reversed if the reverse is true. +func NewAdmissionPatterns(reversed bool, patterns []string, opts ...Option) admission_pkg.Admission { + var matchers []matcher.Matcher + for _, pattern := range patterns { + if m := matcher.NewMatcher(pattern); m != nil { + matchers = append(matchers, m) + } + } + return NewAdmission(reversed, matchers, opts...) +} + +func (p *admission) Admit(addr string) bool { + if addr == "" || p == nil || len(p.matchers) == 0 { + p.options.logger.Debugf("admission: %v is denied", addr) + return false + } + + // try to strip the port + if host, port, _ := net.SplitHostPort(addr); host != "" && port != "" { + if p, _ := strconv.Atoi(port); p > 0 { // port is valid + addr = host + } + } + + var matched bool + for _, matcher := range p.matchers { + if matcher == nil { + continue + } + if matcher.Match(addr) { + matched = true + break + } + } + + b := !p.reversed && matched || + p.reversed && !matched + if !b { + p.options.logger.Debugf("admission: %v is denied", addr) + } + return b +} diff --git a/api/config_admission.go b/api/config_admission.go index d2b1ccf..31fce1e 100644 --- a/api/config_admission.go +++ b/api/config_admission.go @@ -4,9 +4,9 @@ import ( "net/http" "github.com/gin-gonic/gin" - "github.com/go-gost/core/registry" "github.com/go-gost/x/config" "github.com/go-gost/x/config/parsing" + "github.com/go-gost/x/registry" ) // swagger:parameters createAdmissionRequest diff --git a/api/config_auther.go b/api/config_auther.go index b622089..e17601e 100644 --- a/api/config_auther.go +++ b/api/config_auther.go @@ -4,9 +4,9 @@ import ( "net/http" "github.com/gin-gonic/gin" - "github.com/go-gost/core/registry" "github.com/go-gost/x/config" "github.com/go-gost/x/config/parsing" + "github.com/go-gost/x/registry" ) // swagger:parameters createAutherRequest diff --git a/api/config_bypass.go b/api/config_bypass.go index 2acb8f6..4defc98 100644 --- a/api/config_bypass.go +++ b/api/config_bypass.go @@ -4,9 +4,9 @@ import ( "net/http" "github.com/gin-gonic/gin" - "github.com/go-gost/core/registry" "github.com/go-gost/x/config" "github.com/go-gost/x/config/parsing" + "github.com/go-gost/x/registry" ) // swagger:parameters createBypassRequest diff --git a/api/config_chain.go b/api/config_chain.go index b8dbb38..741411a 100644 --- a/api/config_chain.go +++ b/api/config_chain.go @@ -4,9 +4,9 @@ import ( "net/http" "github.com/gin-gonic/gin" - "github.com/go-gost/core/registry" "github.com/go-gost/x/config" "github.com/go-gost/x/config/parsing" + "github.com/go-gost/x/registry" ) // swagger:parameters createChainRequest diff --git a/api/config_hosts.go b/api/config_hosts.go index a2dc4cd..4d5c78d 100644 --- a/api/config_hosts.go +++ b/api/config_hosts.go @@ -4,9 +4,9 @@ import ( "net/http" "github.com/gin-gonic/gin" - "github.com/go-gost/core/registry" "github.com/go-gost/x/config" "github.com/go-gost/x/config/parsing" + "github.com/go-gost/x/registry" ) // swagger:parameters createHostsRequest diff --git a/api/config_resolver.go b/api/config_resolver.go index d97bce1..80bf999 100644 --- a/api/config_resolver.go +++ b/api/config_resolver.go @@ -4,9 +4,9 @@ import ( "net/http" "github.com/gin-gonic/gin" - "github.com/go-gost/core/registry" "github.com/go-gost/x/config" "github.com/go-gost/x/config/parsing" + "github.com/go-gost/x/registry" ) // swagger:parameters createResolverRequest diff --git a/api/config_service.go b/api/config_service.go index 5df6820..9d80e82 100644 --- a/api/config_service.go +++ b/api/config_service.go @@ -4,9 +4,9 @@ import ( "net/http" "github.com/gin-gonic/gin" - "github.com/go-gost/core/registry" "github.com/go-gost/x/config" "github.com/go-gost/x/config/parsing" + "github.com/go-gost/x/registry" ) // swagger:parameters createServiceRequest diff --git a/auth/auth.go b/auth/auth.go new file mode 100644 index 0000000..e95f287 --- /dev/null +++ b/auth/auth.go @@ -0,0 +1,27 @@ +package auth + +import ( + "github.com/go-gost/core/auth" +) + +// authenticator is an Authenticator that authenticates client by key-value pairs. +type authenticator struct { + kvs map[string]string +} + +// NewAuthenticator creates an Authenticator that authenticates client by pre-defined user mapping. +func NewAuthenticator(kvs map[string]string) auth.Authenticator { + return &authenticator{ + kvs: kvs, + } +} + +// Authenticate checks the validity of the provided user-password pair. +func (au *authenticator) Authenticate(user, password string) bool { + if au == nil || len(au.kvs) == 0 { + return true + } + + v, ok := au.kvs[user] + return ok && (v == "" || password == v) +} diff --git a/bypass/bypass.go b/bypass/bypass.go new file mode 100644 index 0000000..9824b92 --- /dev/null +++ b/bypass/bypass.go @@ -0,0 +1,85 @@ +package bypass + +import ( + "net" + "strconv" + + bypass_pkg "github.com/go-gost/core/bypass" + "github.com/go-gost/core/logger" + "github.com/go-gost/x/internal/util/matcher" +) + +type options struct { + logger logger.Logger +} + +type Option func(opts *options) + +func LoggerOption(logger logger.Logger) Option { + return func(opts *options) { + opts.logger = logger + } +} + +type bypass struct { + matchers []matcher.Matcher + reversed bool + options options +} + +// NewBypass creates and initializes a new Bypass using matchers as its match rules. +// The rules will be reversed if the reversed is true. +func NewBypass(reversed bool, matchers []matcher.Matcher, opts ...Option) bypass_pkg.Bypass { + options := options{} + for _, opt := range opts { + opt(&options) + } + return &bypass{ + matchers: matchers, + reversed: reversed, + options: options, + } +} + +// NewBypassPatterns creates and initializes a new Bypass using matcher patterns as its match rules. +// The rules will be reversed if the reverse is true. +func NewBypassPatterns(reversed bool, patterns []string, opts ...Option) bypass_pkg.Bypass { + var matchers []matcher.Matcher + for _, pattern := range patterns { + if m := matcher.NewMatcher(pattern); m != nil { + matchers = append(matchers, m) + } + } + return NewBypass(reversed, matchers, opts...) +} + +func (bp *bypass) Contains(addr string) bool { + if addr == "" || bp == nil || len(bp.matchers) == 0 { + return false + } + + // try to strip the port + if host, port, _ := net.SplitHostPort(addr); host != "" && port != "" { + if p, _ := strconv.Atoi(port); p > 0 { // port is valid + addr = host + } + } + + var matched bool + for _, matcher := range bp.matchers { + if matcher == nil { + continue + } + if matcher.Match(addr) { + matched = true + break + } + } + + b := !bp.reversed && matched || + bp.reversed && !matched + if b { + bp.options.logger.Debugf("bypass: %s", addr) + } + return b +} diff --git a/config/parsing/chain.go b/config/parsing/chain.go index d0fa667..4bbe3cb 100644 --- a/config/parsing/chain.go +++ b/config/parsing/chain.go @@ -2,13 +2,13 @@ package parsing import ( "github.com/go-gost/core/chain" - tls_util "github.com/go-gost/core/common/util/tls" "github.com/go-gost/core/connector" "github.com/go-gost/core/dialer" "github.com/go-gost/core/logger" - "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" "github.com/go-gost/x/config" + tls_util "github.com/go-gost/x/internal/util/tls" + "github.com/go-gost/x/metadata" + "github.com/go-gost/x/registry" ) func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { @@ -58,7 +58,7 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { if v.Connector.Metadata == nil { v.Connector.Metadata = make(map[string]any) } - if err := cr.Init(metadata.MapMetadata(v.Connector.Metadata)); err != nil { + if err := cr.Init(metadata.NewMetadata(v.Connector.Metadata)); err != nil { connectorLogger.Error("init: ", err) return nil, err } @@ -88,7 +88,7 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { if v.Dialer.Metadata == nil { v.Dialer.Metadata = make(map[string]any) } - if err := d.Init(metadata.MapMetadata(v.Dialer.Metadata)); err != nil { + if err := d.Init(metadata.NewMetadata(v.Dialer.Metadata)); err != nil { dialerLogger.Error("init: ", err) return nil, err } diff --git a/config/parsing/parse.go b/config/parsing/parse.go index ef4b56c..6fded76 100644 --- a/config/parsing/parse.go +++ b/config/parsing/parse.go @@ -8,12 +8,16 @@ import ( "github.com/go-gost/core/auth" "github.com/go-gost/core/bypass" "github.com/go-gost/core/chain" - hostspkg "github.com/go-gost/core/hosts" + "github.com/go-gost/core/hosts" "github.com/go-gost/core/logger" - "github.com/go-gost/core/registry" "github.com/go-gost/core/resolver" - resolver_impl "github.com/go-gost/core/resolver/impl" + admission_impl "github.com/go-gost/x/admission" + auth_impl "github.com/go-gost/x/auth" + bypass_impl "github.com/go-gost/x/bypass" "github.com/go-gost/x/config" + hosts_impl "github.com/go-gost/x/hosts" + "github.com/go-gost/x/registry" + resolver_impl "github.com/go-gost/x/resolver" ) func ParseAuther(cfg *config.AutherConfig) auth.Authenticator { @@ -33,14 +37,14 @@ func ParseAuther(cfg *config.AutherConfig) auth.Authenticator { if len(m) == 0 { return nil } - return auth.NewAuthenticator(m) + return auth_impl.NewAuthenticator(m) } func ParseAutherFromAuth(au *config.AuthConfig) auth.Authenticator { if au == nil || au.Username == "" { return nil } - return auth.NewAuthenticator(map[string]string{ + return auth_impl.NewAuthenticator(map[string]string{ au.Username: au.Password, }) } @@ -84,10 +88,10 @@ func ParseAdmission(cfg *config.AdmissionConfig) admission.Admission { if cfg == nil { return nil } - return admission.NewAdmissionPatterns( + return admission_impl.NewAdmissionPatterns( cfg.Reverse, cfg.Matchers, - admission.LoggerOption(logger.Default().WithFields(map[string]any{ + admission_impl.LoggerOption(logger.Default().WithFields(map[string]any{ "kind": "admission", "admission": cfg.Name, })), @@ -98,10 +102,10 @@ func ParseBypass(cfg *config.BypassConfig) bypass.Bypass { if cfg == nil { return nil } - return bypass.NewBypassPatterns( + return bypass_impl.NewBypassPatterns( cfg.Reverse, cfg.Matchers, - bypass.LoggerOption(logger.Default().WithFields(map[string]any{ + bypass_impl.LoggerOption(logger.Default().WithFields(map[string]any{ "kind": "bypass", "bypass": cfg.Name, })), @@ -136,11 +140,11 @@ func ParseResolver(cfg *config.ResolverConfig) (resolver.Resolver, error) { ) } -func ParseHosts(cfg *config.HostsConfig) hostspkg.HostMapper { +func ParseHosts(cfg *config.HostsConfig) hosts.HostMapper { if cfg == nil || len(cfg.Mappings) == 0 { return nil } - hosts := hostspkg.NewHosts() + hosts := hosts_impl.NewHosts() hosts.Logger = logger.Default().WithFields(map[string]any{ "kind": "hosts", "hosts": cfg.Name, diff --git a/config/parsing/service.go b/config/parsing/service.go index 6a65d58..5f0186d 100644 --- a/config/parsing/service.go +++ b/config/parsing/service.go @@ -4,14 +4,14 @@ import ( "strings" "github.com/go-gost/core/chain" - tls_util "github.com/go-gost/core/common/util/tls" "github.com/go-gost/core/handler" "github.com/go-gost/core/listener" "github.com/go-gost/core/logger" - "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" "github.com/go-gost/core/service" "github.com/go-gost/x/config" + tls_util "github.com/go-gost/x/internal/util/tls" + "github.com/go-gost/x/metadata" + "github.com/go-gost/x/registry" ) func ParseService(cfg *config.ServiceConfig) (service.Service, error) { @@ -46,6 +46,9 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { listenerLogger.Error(err) return nil, err } + if tlsConfig == nil { + tlsConfig = defaultTLSConfig.Clone() + } auther := ParseAutherFromAuth(cfg.Listener.Auth) if cfg.Listener.Auther != "" { @@ -66,7 +69,7 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { if cfg.Listener.Metadata == nil { cfg.Listener.Metadata = make(map[string]any) } - if err := ln.Init(metadata.MapMetadata(cfg.Listener.Metadata)); err != nil { + if err := ln.Init(metadata.NewMetadata(cfg.Listener.Metadata)); err != nil { listenerLogger.Error("init: ", err) return nil, err } @@ -85,6 +88,9 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { handlerLogger.Error(err) return nil, err } + if tlsConfig == nil { + tlsConfig = defaultTLSConfig.Clone() + } auther = ParseAutherFromAuth(cfg.Handler.Auth) if cfg.Handler.Auther != "" { @@ -124,7 +130,7 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { if cfg.Handler.Metadata == nil { cfg.Handler.Metadata = make(map[string]any) } - if err := h.Init(metadata.MapMetadata(cfg.Handler.Metadata)); err != nil { + if err := h.Init(metadata.NewMetadata(cfg.Handler.Metadata)); err != nil { handlerLogger.Error("init: ", err) return nil, err } diff --git a/config/parsing/tls.go b/config/parsing/tls.go new file mode 100644 index 0000000..a576b4c --- /dev/null +++ b/config/parsing/tls.go @@ -0,0 +1,113 @@ +package parsing + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "time" + + "github.com/go-gost/core/logger" + "github.com/go-gost/x/config" +) + +var ( + defaultTLSConfig *tls.Config +) + +func BuildDefaultTLSConfig(cfg *config.TLSConfig) { + log := logger.Default() + + if cfg == nil { + cfg = &config.TLSConfig{ + CertFile: "cert.pem", + KeyFile: "key.pem", + } + } + + tlsConfig, err := loadConfig(cfg.CertFile, cfg.KeyFile) + if err != nil { + // generate random self-signed certificate. + cert, err := genCertificate() + if err != nil { + log.Fatal(err) + } + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + log.Warn("load TLS certificate files failed, use random generated certificate") + } else { + log.Info("load TLS certificate files OK") + } + defaultTLSConfig = tlsConfig +} + +func loadConfig(certFile, keyFile string) (*tls.Config, error) { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + + cfg := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + return cfg, nil +} + +func genCertificate() (cert tls.Certificate, err error) { + rawCert, rawKey, err := generateKeyPair() + if err != nil { + return + } + return tls.X509KeyPair(rawCert, rawKey) +} + +func generateKeyPair() (rawCert, rawKey []byte, err error) { + // Create private key and self-signed certificate + // Adapted from https://golang.org/src/crypto/tls/generate_cert.go + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return + } + validFor := time.Hour * 24 * 365 * 10 // ten years + notBefore := time.Now() + notAfter := notBefore.Add(validFor) + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"gost"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + template.DNSNames = append(template.DNSNames, "gost.run") + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return + } + + rawCert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + privBytes, err := x509.MarshalPKCS8PrivateKey(priv) + if err != nil { + return + } + rawKey = pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}) + + return +} diff --git a/connector/forward/connector.go b/connector/forward/connector.go new file mode 100644 index 0000000..24915b6 --- /dev/null +++ b/connector/forward/connector.go @@ -0,0 +1,45 @@ +package forward + +import ( + "context" + "net" + + "github.com/go-gost/core/connector" + md "github.com/go-gost/core/metadata" + "github.com/go-gost/x/registry" +) + +func init() { + registry.ConnectorRegistry().Register("forward", NewConnector) +} + +type forwardConnector struct { + options connector.Options +} + +func NewConnector(opts ...connector.Option) connector.Connector { + options := connector.Options{} + for _, opt := range opts { + opt(&options) + } + + return &forwardConnector{ + options: options, + } +} + +func (c *forwardConnector) Init(md md.Metadata) (err error) { + return nil +} + +func (c *forwardConnector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { + log := c.options.Logger.WithFields(map[string]any{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + "network": network, + "address": address, + }) + log.Infof("connect %s/%s", address, network) + + return conn, nil +} diff --git a/connector/http/connector.go b/connector/http/connector.go new file mode 100644 index 0000000..d8d8715 --- /dev/null +++ b/connector/http/connector.go @@ -0,0 +1,129 @@ +package http + +import ( + "bufio" + "context" + "encoding/base64" + "fmt" + "net" + "net/http" + "net/http/httputil" + "net/url" + "time" + + "github.com/go-gost/core/connector" + "github.com/go-gost/core/logger" + md "github.com/go-gost/core/metadata" + "github.com/go-gost/x/internal/util/socks" + "github.com/go-gost/x/registry" +) + +func init() { + registry.ConnectorRegistry().Register("http", NewConnector) +} + +type httpConnector struct { + md metadata + options connector.Options +} + +func NewConnector(opts ...connector.Option) connector.Connector { + options := connector.Options{} + for _, opt := range opts { + opt(&options) + } + + return &httpConnector{ + options: options, + } +} + +func (c *httpConnector) Init(md md.Metadata) (err error) { + return c.parseMetadata(md) +} + +func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { + log := c.options.Logger.WithFields(map[string]any{ + "local": conn.LocalAddr().String(), + "remote": conn.RemoteAddr().String(), + "network": network, + "address": address, + }) + log.Infof("connect %s/%s", address, network) + + req := &http.Request{ + Method: http.MethodConnect, + URL: &url.URL{Host: address}, + Host: address, + ProtoMajor: 1, + ProtoMinor: 1, + Header: c.md.header, + } + + if req.Header == nil { + req.Header = http.Header{} + } + req.Header.Set("Proxy-Connection", "keep-alive") + + if user := c.options.Auth; user != nil { + u := user.Username() + p, _ := user.Password() + req.Header.Set("Proxy-Authorization", + "Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p))) + } + + switch network { + case "tcp", "tcp4", "tcp6": + if _, ok := conn.(net.PacketConn); ok { + err := fmt.Errorf("tcp over udp is unsupported") + log.Error(err) + return nil, err + } + case "udp", "udp4", "udp6": + req.Header.Set("X-Gost-Protocol", "udp") + default: + err := fmt.Errorf("network %s is unsupported", network) + log.Error(err) + return nil, err + } + + if log.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpRequest(req, false) + log.Debug(string(dump)) + } + + if c.md.connectTimeout > 0 { + conn.SetDeadline(time.Now().Add(c.md.connectTimeout)) + defer conn.SetDeadline(time.Time{}) + } + + req = req.WithContext(ctx) + if err := req.Write(conn); err != nil { + return nil, err + } + + resp, err := http.ReadResponse(bufio.NewReader(conn), req) + if err != nil { + return nil, err + } + // NOTE: the server may return `Transfer-Encoding: chunked` header, + // then the Content-Length of response will be unknown (-1), + // in this case, close body will be blocked, so we leave it untouched. + // defer resp.Body.Close() + + if log.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(resp, false) + log.Debug(string(dump)) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%s", resp.Status) + } + + if network == "udp" { + addr, _ := net.ResolveUDPAddr(network, address) + return socks.UDPTunClientConn(conn, addr), nil + } + + return conn, nil +} diff --git a/connector/http/metadata.go b/connector/http/metadata.go new file mode 100644 index 0000000..761765f --- /dev/null +++ b/connector/http/metadata.go @@ -0,0 +1,33 @@ +package http + +import ( + "net/http" + "time" + + mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" +) + +type metadata struct { + connectTimeout time.Duration + header http.Header +} + +func (c *httpConnector) parseMetadata(md mdata.Metadata) (err error) { + const ( + connectTimeout = "timeout" + header = "header" + ) + + c.md.connectTimeout = mdx.GetDuration(md, connectTimeout) + + if mm := mdx.GetStringMapString(md, header); len(mm) > 0 { + hd := http.Header{} + for k, v := range mm { + hd.Add(k, v) + } + c.md.header = hd + } + + return +} diff --git a/connector/http2/connector.go b/connector/http2/connector.go index 2b4776a..2393486 100644 --- a/connector/http2/connector.go +++ b/connector/http2/connector.go @@ -15,7 +15,7 @@ import ( "github.com/go-gost/core/connector" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" + "github.com/go-gost/x/registry" ) func init() { diff --git a/connector/http2/metadata.go b/connector/http2/metadata.go index cdce3d6..dd7f172 100644 --- a/connector/http2/metadata.go +++ b/connector/http2/metadata.go @@ -5,6 +5,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) type metadata struct { @@ -18,8 +19,8 @@ func (c *http2Connector) parseMetadata(md mdata.Metadata) (err error) { header = "header" ) - c.md.connectTimeout = mdata.GetDuration(md, connectTimeout) - if mm := mdata.GetStringMapString(md, header); len(mm) > 0 { + c.md.connectTimeout = mdx.GetDuration(md, connectTimeout) + if mm := mdx.GetStringMapString(md, header); len(mm) > 0 { hd := http.Header{} for k, v := range mm { hd.Add(k, v) diff --git a/connector/relay/connector.go b/connector/relay/connector.go index 351f726..69bf481 100644 --- a/connector/relay/connector.go +++ b/connector/relay/connector.go @@ -8,9 +8,9 @@ import ( "github.com/go-gost/core/connector" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" "github.com/go-gost/relay" relay_util "github.com/go-gost/x/internal/util/relay" + "github.com/go-gost/x/registry" ) func init() { diff --git a/connector/relay/metadata.go b/connector/relay/metadata.go index 0982a95..46b55f5 100644 --- a/connector/relay/metadata.go +++ b/connector/relay/metadata.go @@ -4,6 +4,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) type metadata struct { @@ -17,8 +18,8 @@ func (c *relayConnector) parseMetadata(md mdata.Metadata) (err error) { noDelay = "nodelay" ) - c.md.connectTimeout = mdata.GetDuration(md, connectTimeout) - c.md.noDelay = mdata.GetBool(md, noDelay) + c.md.connectTimeout = mdx.GetDuration(md, connectTimeout) + c.md.noDelay = mdx.GetBool(md, noDelay) return } diff --git a/connector/sni/connector.go b/connector/sni/connector.go index 636724f..a3a6b79 100644 --- a/connector/sni/connector.go +++ b/connector/sni/connector.go @@ -6,7 +6,7 @@ import ( "github.com/go-gost/core/connector" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" + "github.com/go-gost/x/registry" ) func init() { diff --git a/connector/sni/metadata.go b/connector/sni/metadata.go index 5c49c1a..d2b2412 100644 --- a/connector/sni/metadata.go +++ b/connector/sni/metadata.go @@ -4,6 +4,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) type metadata struct { @@ -17,8 +18,8 @@ func (c *sniConnector) parseMetadata(md mdata.Metadata) (err error) { connectTimeout = "timeout" ) - c.md.host = mdata.GetString(md, host) - c.md.connectTimeout = mdata.GetDuration(md, connectTimeout) + c.md.host = mdx.GetString(md, host) + c.md.connectTimeout = mdx.GetDuration(md, connectTimeout) return } diff --git a/connector/socks/v4/connector.go b/connector/socks/v4/connector.go new file mode 100644 index 0000000..7d8bd4e --- /dev/null +++ b/connector/socks/v4/connector.go @@ -0,0 +1,123 @@ +package v4 + +import ( + "context" + "errors" + "fmt" + "net" + "strconv" + "time" + + "github.com/go-gost/core/connector" + md "github.com/go-gost/core/metadata" + "github.com/go-gost/gosocks4" + "github.com/go-gost/x/registry" +) + +func init() { + registry.ConnectorRegistry().Register("socks4", NewConnector) + registry.ConnectorRegistry().Register("socks4a", NewConnector) +} + +type socks4Connector struct { + md metadata + options connector.Options +} + +func NewConnector(opts ...connector.Option) connector.Connector { + options := connector.Options{} + for _, opt := range opts { + opt(&options) + } + + return &socks4Connector{ + options: options, + } +} + +func (c *socks4Connector) Init(md md.Metadata) (err error) { + return c.parseMetadata(md) +} + +func (c *socks4Connector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { + log := c.options.Logger.WithFields(map[string]any{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + "network": network, + "address": address, + }) + log.Infof("connect %s/%s", address, network) + + switch network { + case "tcp", "tcp4", "tcp6": + if _, ok := conn.(net.PacketConn); ok { + err := fmt.Errorf("tcp over udp is unsupported") + log.Error(err) + return nil, err + } + default: + err := fmt.Errorf("network %s is unsupported", network) + log.Error(err) + return nil, err + } + + var addr *gosocks4.Addr + + if c.md.disable4a { + taddr, err := net.ResolveTCPAddr("tcp4", address) + if err != nil { + log.Error("resolve: ", err) + return nil, err + } + if len(taddr.IP) == 0 { + taddr.IP = net.IPv4zero + } + addr = &gosocks4.Addr{ + Type: gosocks4.AddrIPv4, + Host: taddr.IP.String(), + Port: uint16(taddr.Port), + } + } else { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + p, _ := strconv.Atoi(port) + addr = &gosocks4.Addr{ + Type: gosocks4.AddrDomain, + Host: host, + Port: uint16(p), + } + } + + if c.md.connectTimeout > 0 { + conn.SetDeadline(time.Now().Add(c.md.connectTimeout)) + defer conn.SetDeadline(time.Time{}) + } + + var userid []byte + if c.options.Auth != nil { + userid = []byte(c.options.Auth.Username()) + } + req := gosocks4.NewRequest(gosocks4.CmdConnect, addr, userid) + if err := req.Write(conn); err != nil { + log.Error(err) + return nil, err + } + log.Debug(req) + + reply, err := gosocks4.ReadReply(conn) + if err != nil { + log.Error(err) + return nil, err + } + log.Debug(reply) + + if reply.Code != gosocks4.Granted { + err = errors.New("host unreachable") + log.Error(err) + return nil, err + } + + return conn, nil +} diff --git a/connector/socks/v4/metadata.go b/connector/socks/v4/metadata.go new file mode 100644 index 0000000..d6bef12 --- /dev/null +++ b/connector/socks/v4/metadata.go @@ -0,0 +1,25 @@ +package v4 + +import ( + "time" + + mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" +) + +type metadata struct { + connectTimeout time.Duration + disable4a bool +} + +func (c *socks4Connector) parseMetadata(md mdata.Metadata) (err error) { + const ( + connectTimeout = "timeout" + disable4a = "disable4a" + ) + + c.md.connectTimeout = mdx.GetDuration(md, connectTimeout) + c.md.disable4a = mdx.GetBool(md, disable4a) + + return +} diff --git a/connector/socks/v5/bind.go b/connector/socks/v5/bind.go new file mode 100644 index 0000000..86de7d6 --- /dev/null +++ b/connector/socks/v5/bind.go @@ -0,0 +1,133 @@ +package v5 + +import ( + "context" + "fmt" + "net" + + "github.com/go-gost/core/common/net/udp" + "github.com/go-gost/core/connector" + "github.com/go-gost/core/logger" + "github.com/go-gost/gosocks5" + "github.com/go-gost/x/internal/util/mux" + "github.com/go-gost/x/internal/util/socks" +) + +// Bind implements connector.Binder. +func (c *socks5Connector) Bind(ctx context.Context, conn net.Conn, network, address string, opts ...connector.BindOption) (net.Listener, error) { + log := c.options.Logger.WithFields(map[string]any{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + "network": network, + "address": address, + }) + log.Infof("bind on %s/%s", address, network) + + options := connector.BindOptions{} + for _, opt := range opts { + opt(&options) + } + + switch network { + case "tcp", "tcp4", "tcp6": + if options.Mux { + return c.muxBindTCP(ctx, conn, network, address, log) + } + return c.bindTCP(ctx, conn, network, address, log) + case "udp", "udp4", "udp6": + return c.bindUDP(ctx, conn, network, address, &options, log) + default: + err := fmt.Errorf("network %s is unsupported", network) + log.Error(err) + return nil, err + } +} + +func (c *socks5Connector) bindTCP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) (net.Listener, error) { + laddr, err := c.bind(conn, gosocks5.CmdBind, network, address, log) + if err != nil { + return nil, err + } + + return &tcpListener{ + addr: laddr, + conn: conn, + logger: log, + }, nil +} + +func (c *socks5Connector) muxBindTCP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) (net.Listener, error) { + laddr, err := c.bind(conn, socks.CmdMuxBind, network, address, log) + if err != nil { + return nil, err + } + + session, err := mux.ServerSession(conn) + if err != nil { + return nil, err + } + + return &tcpMuxListener{ + addr: laddr, + session: session, + logger: log, + }, nil +} + +func (c *socks5Connector) bindUDP(ctx context.Context, conn net.Conn, network, address string, opts *connector.BindOptions, log logger.Logger) (net.Listener, error) { + laddr, err := c.bind(conn, socks.CmdUDPTun, network, address, log) + if err != nil { + return nil, err + } + + ln := udp.NewListener(socks.UDPTunClientPacketConn(conn), + &udp.ListenConfig{ + Addr: laddr, + Backlog: opts.Backlog, + ReadQueueSize: opts.UDPDataQueueSize, + ReadBufferSize: opts.UDPDataBufferSize, + TTL: opts.UDPConnTTL, + KeepAlive: true, + Logger: log, + }) + + return ln, nil +} + +func (l *socks5Connector) bind(conn net.Conn, cmd uint8, network, address string, log logger.Logger) (net.Addr, error) { + addr := gosocks5.Addr{} + addr.ParseFrom(address) + req := gosocks5.NewRequest(cmd, &addr) + if err := req.Write(conn); err != nil { + return nil, err + } + log.Debug(req) + + // first reply, bind status + reply, err := gosocks5.ReadReply(conn) + if err != nil { + return nil, err + } + + log.Debug(reply) + + if reply.Rep != gosocks5.Succeeded { + return nil, fmt.Errorf("bind on %s/%s failed", address, network) + } + + var baddr net.Addr + switch network { + case "tcp", "tcp4", "tcp6": + baddr, err = net.ResolveTCPAddr(network, reply.Addr.String()) + case "udp", "udp4", "udp6": + baddr, err = net.ResolveUDPAddr(network, reply.Addr.String()) + default: + err = fmt.Errorf("unknown network %s", network) + } + if err != nil { + return nil, err + } + log.Debugf("bind on %s/%s OK", baddr, baddr.Network()) + + return baddr, nil +} diff --git a/connector/socks/v5/conn.go b/connector/socks/v5/conn.go new file mode 100644 index 0000000..d11f3b5 --- /dev/null +++ b/connector/socks/v5/conn.go @@ -0,0 +1,17 @@ +package v5 + +import "net" + +type bindConn struct { + net.Conn + localAddr net.Addr + remoteAddr net.Addr +} + +func (c *bindConn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *bindConn) RemoteAddr() net.Addr { + return c.remoteAddr +} diff --git a/connector/socks/v5/connector.go b/connector/socks/v5/connector.go new file mode 100644 index 0000000..377ad97 --- /dev/null +++ b/connector/socks/v5/connector.go @@ -0,0 +1,173 @@ +package v5 + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net" + "time" + + "github.com/go-gost/core/connector" + "github.com/go-gost/core/logger" + md "github.com/go-gost/core/metadata" + "github.com/go-gost/gosocks5" + "github.com/go-gost/x/internal/util/socks" + "github.com/go-gost/x/registry" +) + +func init() { + registry.ConnectorRegistry().Register("socks5", NewConnector) + registry.ConnectorRegistry().Register("socks", NewConnector) +} + +type socks5Connector struct { + selector gosocks5.Selector + md metadata + options connector.Options +} + +func NewConnector(opts ...connector.Option) connector.Connector { + options := connector.Options{} + for _, opt := range opts { + opt(&options) + } + + return &socks5Connector{ + options: options, + } +} + +func (c *socks5Connector) Init(md md.Metadata) (err error) { + if err = c.parseMetadata(md); err != nil { + return + } + + selector := &clientSelector{ + methods: []uint8{ + gosocks5.MethodNoAuth, + gosocks5.MethodUserPass, + }, + User: c.options.Auth, + TLSConfig: c.options.TLSConfig, + logger: c.options.Logger, + } + if !c.md.noTLS { + selector.methods = append(selector.methods, socks.MethodTLS) + if selector.TLSConfig == nil { + selector.TLSConfig = &tls.Config{ + InsecureSkipVerify: true, + } + } + } + c.selector = selector + + return +} + +// Handshake implements connector.Handshaker. +func (c *socks5Connector) Handshake(ctx context.Context, conn net.Conn) (net.Conn, error) { + log := c.options.Logger.WithFields(map[string]any{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }) + + if c.md.connectTimeout > 0 { + conn.SetDeadline(time.Now().Add(c.md.connectTimeout)) + defer conn.SetDeadline(time.Time{}) + } + + cc := gosocks5.ClientConn(conn, c.selector) + if err := cc.Handleshake(); err != nil { + log.Error(err) + return nil, err + } + + return cc, nil +} + +func (c *socks5Connector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { + log := c.options.Logger.WithFields(map[string]any{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + "network": network, + "address": address, + }) + log.Infof("connect %s/%s", address, network) + + if c.md.connectTimeout > 0 { + conn.SetDeadline(time.Now().Add(c.md.connectTimeout)) + defer conn.SetDeadline(time.Time{}) + } + + switch network { + case "udp", "udp4", "udp6": + return c.connectUDP(ctx, conn, network, address, log) + case "tcp", "tcp4", "tcp6": + if _, ok := conn.(net.PacketConn); ok { + err := fmt.Errorf("tcp over udp is unsupported") + log.Error(err) + return nil, err + } + default: + err := fmt.Errorf("network %s is unsupported", network) + log.Error(err) + return nil, err + } + + addr := gosocks5.Addr{} + if err := addr.ParseFrom(address); err != nil { + log.Error(err) + return nil, err + } + + req := gosocks5.NewRequest(gosocks5.CmdConnect, &addr) + if err := req.Write(conn); err != nil { + log.Error(err) + return nil, err + } + log.Debug(req) + + reply, err := gosocks5.ReadReply(conn) + if err != nil { + log.Error(err) + return nil, err + } + log.Debug(reply) + + if reply.Rep != gosocks5.Succeeded { + err = errors.New("host unreachable") + log.Error(err) + return nil, err + } + + return conn, nil +} + +func (c *socks5Connector) connectUDP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) (net.Conn, error) { + addr, err := net.ResolveUDPAddr(network, address) + if err != nil { + log.Error(err) + return nil, err + } + + req := gosocks5.NewRequest(socks.CmdUDPTun, nil) + if err := req.Write(conn); err != nil { + log.Error(err) + return nil, err + } + log.Debug(req) + + reply, err := gosocks5.ReadReply(conn) + if err != nil { + log.Error(err) + return nil, err + } + log.Debug(reply) + + if reply.Rep != gosocks5.Succeeded { + return nil, errors.New("get socks5 UDP tunnel failure") + } + + return socks.UDPTunClientConn(conn, addr), nil +} diff --git a/connector/socks/v5/listener.go b/connector/socks/v5/listener.go new file mode 100644 index 0000000..2659ba7 --- /dev/null +++ b/connector/socks/v5/listener.go @@ -0,0 +1,102 @@ +package v5 + +import ( + "fmt" + "net" + + "github.com/go-gost/core/logger" + "github.com/go-gost/gosocks5" + "github.com/go-gost/x/internal/util/mux" +) + +type tcpListener struct { + addr net.Addr + conn net.Conn + logger logger.Logger +} + +func (p *tcpListener) Accept() (net.Conn, error) { + // second reply, peer connected + rep, err := gosocks5.ReadReply(p.conn) + if err != nil { + return nil, err + } + p.logger.Debug(rep) + + if rep.Rep != gosocks5.Succeeded { + return nil, fmt.Errorf("peer connect failed") + } + + raddr, err := net.ResolveTCPAddr("tcp", rep.Addr.String()) + if err != nil { + return nil, err + } + + return &bindConn{ + Conn: p.conn, + localAddr: p.addr, + remoteAddr: raddr, + }, nil +} + +func (p *tcpListener) Addr() net.Addr { + return p.addr +} + +func (p *tcpListener) Close() error { + return p.conn.Close() +} + +type tcpMuxListener struct { + addr net.Addr + session *mux.Session + logger logger.Logger +} + +func (p *tcpMuxListener) Accept() (net.Conn, error) { + cc, err := p.session.Accept() + if err != nil { + return nil, err + } + + conn, err := p.getPeerConn(cc) + if err != nil { + cc.Close() + return nil, err + } + + return conn, nil +} + +func (p *tcpMuxListener) getPeerConn(conn net.Conn) (net.Conn, error) { + // second reply, peer connected + rep, err := gosocks5.ReadReply(conn) + if err != nil { + return nil, err + } + p.logger.Debug(rep) + + if rep.Rep != gosocks5.Succeeded { + err = fmt.Errorf("peer connect failed") + return nil, err + } + + raddr, err := net.ResolveTCPAddr("tcp", rep.Addr.String()) + if err != nil { + return nil, err + } + + return &bindConn{ + Conn: conn, + localAddr: p.addr, + remoteAddr: raddr, + }, nil +} + +func (p *tcpMuxListener) Addr() net.Addr { + return p.addr +} + +func (p *tcpMuxListener) Close() error { + return p.session.Close() +} diff --git a/connector/socks/v5/metadata.go b/connector/socks/v5/metadata.go new file mode 100644 index 0000000..e02f06b --- /dev/null +++ b/connector/socks/v5/metadata.go @@ -0,0 +1,25 @@ +package v5 + +import ( + "time" + + mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" +) + +type metadata struct { + connectTimeout time.Duration + noTLS bool +} + +func (c *socks5Connector) parseMetadata(md mdata.Metadata) (err error) { + const ( + connectTimeout = "timeout" + noTLS = "notls" + ) + + c.md.connectTimeout = mdx.GetDuration(md, connectTimeout) + c.md.noTLS = mdx.GetBool(md, noTLS) + + return +} diff --git a/connector/socks/v5/selector.go b/connector/socks/v5/selector.go new file mode 100644 index 0000000..c6f47ef --- /dev/null +++ b/connector/socks/v5/selector.go @@ -0,0 +1,73 @@ +package v5 + +import ( + "crypto/tls" + "net" + "net/url" + + "github.com/go-gost/core/logger" + "github.com/go-gost/gosocks5" + "github.com/go-gost/x/internal/util/socks" +) + +type clientSelector struct { + methods []uint8 + User *url.Userinfo + TLSConfig *tls.Config + logger logger.Logger +} + +func (s *clientSelector) Methods() []uint8 { + s.logger.Debug("methods: ", s.methods) + return s.methods +} + +func (s *clientSelector) AddMethod(methods ...uint8) { + s.methods = append(s.methods, methods...) +} + +func (s *clientSelector) Select(methods ...uint8) (method uint8) { + return +} + +func (s *clientSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, error) { + s.logger.Debug("method selected: ", method) + + switch method { + case socks.MethodTLS: + conn = tls.Client(conn, s.TLSConfig) + + case gosocks5.MethodUserPass, socks.MethodTLSAuth: + if method == socks.MethodTLSAuth { + conn = tls.Client(conn, s.TLSConfig) + } + + var username, password string + if s.User != nil { + username = s.User.Username() + password, _ = s.User.Password() + } + + req := gosocks5.NewUserPassRequest(gosocks5.UserPassVer, username, password) + if err := req.Write(conn); err != nil { + s.logger.Error(err) + return nil, err + } + s.logger.Debug(req) + + resp, err := gosocks5.ReadUserPassResponse(conn) + if err != nil { + s.logger.Error(err) + return nil, err + } + s.logger.Debug(resp) + + if resp.Status != gosocks5.Succeeded { + return nil, gosocks5.ErrAuthFailure + } + case gosocks5.MethodNoAcceptable: + return nil, gosocks5.ErrBadMethod + } + + return conn, nil +} diff --git a/connector/ss/connector.go b/connector/ss/connector.go index a3cfade..559dc19 100644 --- a/connector/ss/connector.go +++ b/connector/ss/connector.go @@ -9,9 +9,9 @@ import ( "github.com/go-gost/core/common/bufpool" "github.com/go-gost/core/connector" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" "github.com/go-gost/gosocks5" "github.com/go-gost/x/internal/util/ss" + "github.com/go-gost/x/registry" "github.com/shadowsocks/go-shadowsocks2/core" ) diff --git a/connector/ss/metadata.go b/connector/ss/metadata.go index 7b0adb6..caaadd2 100644 --- a/connector/ss/metadata.go +++ b/connector/ss/metadata.go @@ -4,6 +4,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) type metadata struct { @@ -19,9 +20,9 @@ func (c *ssConnector) parseMetadata(md mdata.Metadata) (err error) { noDelay = "nodelay" ) - c.md.key = mdata.GetString(md, key) - c.md.connectTimeout = mdata.GetDuration(md, connectTimeout) - c.md.noDelay = mdata.GetBool(md, noDelay) + c.md.key = mdx.GetString(md, key) + c.md.connectTimeout = mdx.GetDuration(md, connectTimeout) + c.md.noDelay = mdx.GetBool(md, noDelay) return } diff --git a/connector/ss/udp/connector.go b/connector/ss/udp/connector.go index f36e00d..a9bb1bd 100644 --- a/connector/ss/udp/connector.go +++ b/connector/ss/udp/connector.go @@ -8,9 +8,9 @@ import ( "github.com/go-gost/core/connector" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" "github.com/go-gost/x/internal/util/relay" "github.com/go-gost/x/internal/util/ss" + "github.com/go-gost/x/registry" "github.com/shadowsocks/go-shadowsocks2/core" ) diff --git a/connector/ss/udp/metadata.go b/connector/ss/udp/metadata.go index 5202c43..ca255fb 100644 --- a/connector/ss/udp/metadata.go +++ b/connector/ss/udp/metadata.go @@ -5,6 +5,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) type metadata struct { @@ -20,10 +21,10 @@ func (c *ssuConnector) parseMetadata(md mdata.Metadata) (err error) { bufferSize = "bufferSize" // udp buffer size ) - c.md.key = mdata.GetString(md, key) - c.md.connectTimeout = mdata.GetDuration(md, connectTimeout) + c.md.key = mdx.GetString(md, key) + c.md.connectTimeout = mdx.GetDuration(md, connectTimeout) - if bs := mdata.GetInt(md, bufferSize); bs > 0 { + if bs := mdx.GetInt(md, bufferSize); bs > 0 { c.md.bufferSize = int(math.Min(math.Max(float64(bs), 512), 64*1024)) } else { c.md.bufferSize = 1500 diff --git a/connector/sshd/connector.go b/connector/sshd/connector.go index b8ffe79..66a2b31 100644 --- a/connector/sshd/connector.go +++ b/connector/sshd/connector.go @@ -7,8 +7,8 @@ import ( "github.com/go-gost/core/connector" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" ssh_util "github.com/go-gost/x/internal/util/ssh" + "github.com/go-gost/x/registry" ) func init() { diff --git a/dialer/ftcp/dialer.go b/dialer/ftcp/dialer.go index c4f31df..41b1f6b 100644 --- a/dialer/ftcp/dialer.go +++ b/dialer/ftcp/dialer.go @@ -7,7 +7,7 @@ import ( "github.com/go-gost/core/dialer" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" + "github.com/go-gost/x/registry" "github.com/xtaci/tcpraw" ) diff --git a/dialer/grpc/dialer.go b/dialer/grpc/dialer.go index acdafd0..53ffcc3 100644 --- a/dialer/grpc/dialer.go +++ b/dialer/grpc/dialer.go @@ -8,8 +8,8 @@ import ( "github.com/go-gost/core/dialer" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" pb "github.com/go-gost/x/internal/util/grpc/proto" + "github.com/go-gost/x/registry" "google.golang.org/grpc" "google.golang.org/grpc/backoff" "google.golang.org/grpc/credentials" diff --git a/dialer/grpc/metadata.go b/dialer/grpc/metadata.go index bf7f4ec..25ad5bc 100644 --- a/dialer/grpc/metadata.go +++ b/dialer/grpc/metadata.go @@ -2,6 +2,7 @@ package grpc import ( mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) type metadata struct { @@ -15,8 +16,8 @@ func (d *grpcDialer) parseMetadata(md mdata.Metadata) (err error) { host = "host" ) - d.md.insecure = mdata.GetBool(md, insecure) - d.md.host = mdata.GetString(md, host) + d.md.insecure = mdx.GetBool(md, insecure) + d.md.host = mdx.GetString(md, host) return } diff --git a/dialer/http2/dialer.go b/dialer/http2/dialer.go index eb73be6..fbee078 100644 --- a/dialer/http2/dialer.go +++ b/dialer/http2/dialer.go @@ -11,7 +11,8 @@ import ( "github.com/go-gost/core/dialer" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" + mdx "github.com/go-gost/x/metadata" + "github.com/go-gost/x/registry" ) func init() { @@ -98,7 +99,7 @@ func (d *http2Dialer) Dial(ctx context.Context, address string, opts ...dialer.D defer d.clientMutex.Unlock() delete(d.clients, address) }, - md: md.MapMetadata{"client": client}, + md: mdx.NewMetadata(map[string]any{"client": client}), } return c, nil diff --git a/dialer/http2/h2/dialer.go b/dialer/http2/h2/dialer.go index da0944c..c9e2cfc 100644 --- a/dialer/http2/h2/dialer.go +++ b/dialer/http2/h2/dialer.go @@ -15,7 +15,7 @@ import ( "github.com/go-gost/core/dialer" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" + "github.com/go-gost/x/registry" "golang.org/x/net/http2" ) diff --git a/dialer/http2/h2/metadata.go b/dialer/http2/h2/metadata.go index 1cba21c..96617a3 100644 --- a/dialer/http2/h2/metadata.go +++ b/dialer/http2/h2/metadata.go @@ -2,6 +2,7 @@ package h2 import ( mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) type metadata struct { @@ -15,8 +16,8 @@ func (d *h2Dialer) parseMetadata(md mdata.Metadata) (err error) { path = "path" ) - d.md.host = mdata.GetString(md, host) - d.md.path = mdata.GetString(md, path) + d.md.host = mdx.GetString(md, host) + d.md.path = mdx.GetString(md, path) return } diff --git a/dialer/http3/dialer.go b/dialer/http3/dialer.go index 3274c6b..ca5be9d 100644 --- a/dialer/http3/dialer.go +++ b/dialer/http3/dialer.go @@ -9,8 +9,8 @@ import ( "github.com/go-gost/core/dialer" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" pht_util "github.com/go-gost/x/internal/util/pht" + "github.com/go-gost/x/registry" "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/http3" ) diff --git a/dialer/http3/metadata.go b/dialer/http3/metadata.go index 00d7360..b776909 100644 --- a/dialer/http3/metadata.go +++ b/dialer/http3/metadata.go @@ -5,6 +5,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) const ( @@ -34,19 +35,19 @@ func (d *http3Dialer) parseMetadata(md mdata.Metadata) (err error) { host = "host" ) - d.md.authorizePath = mdata.GetString(md, authorizePath) + d.md.authorizePath = mdx.GetString(md, authorizePath) if !strings.HasPrefix(d.md.authorizePath, "/") { d.md.authorizePath = defaultAuthorizePath } - d.md.pushPath = mdata.GetString(md, pushPath) + d.md.pushPath = mdx.GetString(md, pushPath) if !strings.HasPrefix(d.md.pushPath, "/") { d.md.pushPath = defaultPushPath } - d.md.pullPath = mdata.GetString(md, pullPath) + d.md.pullPath = mdx.GetString(md, pullPath) if !strings.HasPrefix(d.md.pullPath, "/") { d.md.pullPath = defaultPullPath } - d.md.host = mdata.GetString(md, host) + d.md.host = mdx.GetString(md, host) return } diff --git a/dialer/icmp/dialer.go b/dialer/icmp/dialer.go index 7946457..24f8b39 100644 --- a/dialer/icmp/dialer.go +++ b/dialer/icmp/dialer.go @@ -11,8 +11,8 @@ import ( "github.com/go-gost/core/dialer" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" icmp_pkg "github.com/go-gost/x/internal/util/icmp" + "github.com/go-gost/x/registry" "github.com/lucas-clemente/quic-go" "golang.org/x/net/icmp" ) diff --git a/dialer/icmp/metadata.go b/dialer/icmp/metadata.go index f728e8d..dab4e59 100644 --- a/dialer/icmp/metadata.go +++ b/dialer/icmp/metadata.go @@ -4,6 +4,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) type metadata struct { @@ -19,11 +20,11 @@ func (d *icmpDialer) parseMetadata(md mdata.Metadata) (err error) { maxIdleTimeout = "maxIdleTimeout" ) - d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) + d.md.handshakeTimeout = mdx.GetDuration(md, handshakeTimeout) - d.md.keepAlive = mdata.GetBool(md, keepAlive) - d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) - d.md.maxIdleTimeout = mdata.GetDuration(md, maxIdleTimeout) + d.md.keepAlive = mdx.GetBool(md, keepAlive) + d.md.handshakeTimeout = mdx.GetDuration(md, handshakeTimeout) + d.md.maxIdleTimeout = mdx.GetDuration(md, maxIdleTimeout) return } diff --git a/dialer/kcp/dialer.go b/dialer/kcp/dialer.go index 2a64da7..3678b2b 100644 --- a/dialer/kcp/dialer.go +++ b/dialer/kcp/dialer.go @@ -10,8 +10,8 @@ import ( "github.com/go-gost/core/dialer" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" kcp_util "github.com/go-gost/x/internal/util/kcp" + "github.com/go-gost/x/registry" "github.com/xtaci/kcp-go/v5" "github.com/xtaci/smux" "github.com/xtaci/tcpraw" diff --git a/dialer/kcp/metadata.go b/dialer/kcp/metadata.go index 1d37190..23108fd 100644 --- a/dialer/kcp/metadata.go +++ b/dialer/kcp/metadata.go @@ -6,6 +6,7 @@ import ( mdata "github.com/go-gost/core/metadata" kcp_util "github.com/go-gost/x/internal/util/kcp" + mdx "github.com/go-gost/x/metadata" ) type metadata struct { @@ -19,7 +20,7 @@ func (d *kcpDialer) parseMetadata(md mdata.Metadata) (err error) { handshakeTimeout = "handshakeTimeout" ) - if m := mdata.GetStringMap(md, config); len(m) > 0 { + if m := mdx.GetStringMap(md, config); len(m) > 0 { b, err := json.Marshal(m) if err != nil { return err @@ -34,6 +35,6 @@ func (d *kcpDialer) parseMetadata(md mdata.Metadata) (err error) { d.md.config = kcp_util.DefaultConfig } - d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) + d.md.handshakeTimeout = mdx.GetDuration(md, handshakeTimeout) return } diff --git a/dialer/mtls/dialer.go b/dialer/mtls/dialer.go index 8fe73e3..6835fea 100644 --- a/dialer/mtls/dialer.go +++ b/dialer/mtls/dialer.go @@ -11,7 +11,7 @@ import ( "github.com/go-gost/core/dialer" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" + "github.com/go-gost/x/registry" "github.com/xtaci/smux" ) diff --git a/dialer/mtls/metadata.go b/dialer/mtls/metadata.go index 1eb2be7..17e3e3c 100644 --- a/dialer/mtls/metadata.go +++ b/dialer/mtls/metadata.go @@ -4,6 +4,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) type metadata struct { @@ -29,14 +30,14 @@ func (d *mtlsDialer) parseMetadata(md mdata.Metadata) (err error) { muxMaxStreamBuffer = "muxMaxStreamBuffer" ) - d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) + d.md.handshakeTimeout = mdx.GetDuration(md, handshakeTimeout) - d.md.muxKeepAliveDisabled = mdata.GetBool(md, muxKeepAliveDisabled) - d.md.muxKeepAliveInterval = mdata.GetDuration(md, muxKeepAliveInterval) - d.md.muxKeepAliveTimeout = mdata.GetDuration(md, muxKeepAliveTimeout) - d.md.muxMaxFrameSize = mdata.GetInt(md, muxMaxFrameSize) - d.md.muxMaxReceiveBuffer = mdata.GetInt(md, muxMaxReceiveBuffer) - d.md.muxMaxStreamBuffer = mdata.GetInt(md, muxMaxStreamBuffer) + d.md.muxKeepAliveDisabled = mdx.GetBool(md, muxKeepAliveDisabled) + d.md.muxKeepAliveInterval = mdx.GetDuration(md, muxKeepAliveInterval) + d.md.muxKeepAliveTimeout = mdx.GetDuration(md, muxKeepAliveTimeout) + d.md.muxMaxFrameSize = mdx.GetInt(md, muxMaxFrameSize) + d.md.muxMaxReceiveBuffer = mdx.GetInt(md, muxMaxReceiveBuffer) + d.md.muxMaxStreamBuffer = mdx.GetInt(md, muxMaxStreamBuffer) return } diff --git a/dialer/mws/dialer.go b/dialer/mws/dialer.go index dd77b94..4bf1f79 100644 --- a/dialer/mws/dialer.go +++ b/dialer/mws/dialer.go @@ -10,8 +10,8 @@ import ( "github.com/go-gost/core/dialer" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" ws_util "github.com/go-gost/x/internal/util/ws" + "github.com/go-gost/x/registry" "github.com/gorilla/websocket" "github.com/xtaci/smux" ) diff --git a/dialer/mws/metadata.go b/dialer/mws/metadata.go index b6f36fb..48f5da5 100644 --- a/dialer/mws/metadata.go +++ b/dialer/mws/metadata.go @@ -5,6 +5,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) const ( @@ -54,34 +55,34 @@ func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) { muxMaxStreamBuffer = "muxMaxStreamBuffer" ) - d.md.host = mdata.GetString(md, host) + d.md.host = mdx.GetString(md, host) - d.md.path = mdata.GetString(md, path) + d.md.path = mdx.GetString(md, path) if d.md.path == "" { d.md.path = defaultPath } - d.md.muxKeepAliveDisabled = mdata.GetBool(md, muxKeepAliveDisabled) - d.md.muxKeepAliveInterval = mdata.GetDuration(md, muxKeepAliveInterval) - d.md.muxKeepAliveTimeout = mdata.GetDuration(md, muxKeepAliveTimeout) - d.md.muxMaxFrameSize = mdata.GetInt(md, muxMaxFrameSize) - d.md.muxMaxReceiveBuffer = mdata.GetInt(md, muxMaxReceiveBuffer) - d.md.muxMaxStreamBuffer = mdata.GetInt(md, muxMaxStreamBuffer) + d.md.muxKeepAliveDisabled = mdx.GetBool(md, muxKeepAliveDisabled) + d.md.muxKeepAliveInterval = mdx.GetDuration(md, muxKeepAliveInterval) + d.md.muxKeepAliveTimeout = mdx.GetDuration(md, muxKeepAliveTimeout) + d.md.muxMaxFrameSize = mdx.GetInt(md, muxMaxFrameSize) + d.md.muxMaxReceiveBuffer = mdx.GetInt(md, muxMaxReceiveBuffer) + d.md.muxMaxStreamBuffer = mdx.GetInt(md, muxMaxStreamBuffer) - d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) - d.md.readHeaderTimeout = mdata.GetDuration(md, readHeaderTimeout) - d.md.readBufferSize = mdata.GetInt(md, readBufferSize) - d.md.writeBufferSize = mdata.GetInt(md, writeBufferSize) - d.md.enableCompression = mdata.GetBool(md, enableCompression) + d.md.handshakeTimeout = mdx.GetDuration(md, handshakeTimeout) + d.md.readHeaderTimeout = mdx.GetDuration(md, readHeaderTimeout) + d.md.readBufferSize = mdx.GetInt(md, readBufferSize) + d.md.writeBufferSize = mdx.GetInt(md, writeBufferSize) + d.md.enableCompression = mdx.GetBool(md, enableCompression) - if m := mdata.GetStringMapString(md, header); len(m) > 0 { + if m := mdx.GetStringMapString(md, header); len(m) > 0 { h := http.Header{} for k, v := range m { h.Add(k, v) } d.md.header = h } - d.md.keepAlive = mdata.GetDuration(md, keepAlive) + d.md.keepAlive = mdx.GetDuration(md, keepAlive) return } diff --git a/dialer/obfs/http/dialer.go b/dialer/obfs/http/dialer.go index f277807..653102d 100644 --- a/dialer/obfs/http/dialer.go +++ b/dialer/obfs/http/dialer.go @@ -7,7 +7,7 @@ import ( "github.com/go-gost/core/dialer" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" + "github.com/go-gost/x/registry" ) func init() { diff --git a/dialer/obfs/http/metadata.go b/dialer/obfs/http/metadata.go index a2eb889..7df1a9c 100644 --- a/dialer/obfs/http/metadata.go +++ b/dialer/obfs/http/metadata.go @@ -4,6 +4,7 @@ import ( "net/http" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) type metadata struct { @@ -17,13 +18,13 @@ func (d *obfsHTTPDialer) parseMetadata(md mdata.Metadata) (err error) { host = "host" ) - if m := mdata.GetStringMapString(md, header); len(m) > 0 { + if m := mdx.GetStringMapString(md, header); len(m) > 0 { h := http.Header{} for k, v := range m { h.Add(k, v) } d.md.header = h } - d.md.host = mdata.GetString(md, host) + d.md.host = mdx.GetString(md, host) return } diff --git a/dialer/obfs/tls/dialer.go b/dialer/obfs/tls/dialer.go index 3fa4767..d7b5b1a 100644 --- a/dialer/obfs/tls/dialer.go +++ b/dialer/obfs/tls/dialer.go @@ -7,7 +7,7 @@ import ( "github.com/go-gost/core/dialer" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" + "github.com/go-gost/x/registry" ) func init() { diff --git a/dialer/obfs/tls/metadata.go b/dialer/obfs/tls/metadata.go index 720922a..4c18221 100644 --- a/dialer/obfs/tls/metadata.go +++ b/dialer/obfs/tls/metadata.go @@ -2,6 +2,7 @@ package tls import ( mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) type metadata struct { @@ -13,6 +14,6 @@ func (d *obfsTLSDialer) parseMetadata(md mdata.Metadata) (err error) { host = "host" ) - d.md.host = mdata.GetString(md, host) + d.md.host = mdx.GetString(md, host) return } diff --git a/dialer/pht/dialer.go b/dialer/pht/dialer.go index 95a98ea..856e5f7 100644 --- a/dialer/pht/dialer.go +++ b/dialer/pht/dialer.go @@ -10,8 +10,8 @@ import ( "github.com/go-gost/core/dialer" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" pht_util "github.com/go-gost/x/internal/util/pht" + "github.com/go-gost/x/registry" ) func init() { diff --git a/dialer/pht/metadata.go b/dialer/pht/metadata.go index f5ec5ed..7a3acad 100644 --- a/dialer/pht/metadata.go +++ b/dialer/pht/metadata.go @@ -5,6 +5,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) const ( @@ -34,19 +35,19 @@ func (d *phtDialer) parseMetadata(md mdata.Metadata) (err error) { host = "host" ) - d.md.authorizePath = mdata.GetString(md, authorizePath) + d.md.authorizePath = mdx.GetString(md, authorizePath) if !strings.HasPrefix(d.md.authorizePath, "/") { d.md.authorizePath = defaultAuthorizePath } - d.md.pushPath = mdata.GetString(md, pushPath) + d.md.pushPath = mdx.GetString(md, pushPath) if !strings.HasPrefix(d.md.pushPath, "/") { d.md.pushPath = defaultPushPath } - d.md.pullPath = mdata.GetString(md, pullPath) + d.md.pullPath = mdx.GetString(md, pullPath) if !strings.HasPrefix(d.md.pullPath, "/") { d.md.pullPath = defaultPullPath } - d.md.host = mdata.GetString(md, host) + d.md.host = mdx.GetString(md, host) return } diff --git a/dialer/quic/dialer.go b/dialer/quic/dialer.go index eaf86e4..707da06 100644 --- a/dialer/quic/dialer.go +++ b/dialer/quic/dialer.go @@ -9,8 +9,8 @@ import ( "github.com/go-gost/core/dialer" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" quic_util "github.com/go-gost/x/internal/util/quic" + "github.com/go-gost/x/registry" "github.com/lucas-clemente/quic-go" ) diff --git a/dialer/quic/metadata.go b/dialer/quic/metadata.go index b941e8b..c587f7c 100644 --- a/dialer/quic/metadata.go +++ b/dialer/quic/metadata.go @@ -4,6 +4,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) type metadata struct { @@ -25,16 +26,16 @@ func (d *quicDialer) parseMetadata(md mdata.Metadata) (err error) { host = "host" ) - d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) + d.md.handshakeTimeout = mdx.GetDuration(md, handshakeTimeout) - if key := mdata.GetString(md, cipherKey); key != "" { + if key := mdx.GetString(md, cipherKey); key != "" { d.md.cipherKey = []byte(key) } - d.md.keepAlive = mdata.GetBool(md, keepAlive) - d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) - d.md.maxIdleTimeout = mdata.GetDuration(md, maxIdleTimeout) + d.md.keepAlive = mdx.GetBool(md, keepAlive) + d.md.handshakeTimeout = mdx.GetDuration(md, handshakeTimeout) + d.md.maxIdleTimeout = mdx.GetDuration(md, maxIdleTimeout) - d.md.host = mdata.GetString(md, host) + d.md.host = mdx.GetString(md, host) return } diff --git a/dialer/ssh/dialer.go b/dialer/ssh/dialer.go index 2f1f5fc..95a60d3 100644 --- a/dialer/ssh/dialer.go +++ b/dialer/ssh/dialer.go @@ -10,8 +10,8 @@ import ( "github.com/go-gost/core/dialer" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" ssh_util "github.com/go-gost/x/internal/util/ssh" + "github.com/go-gost/x/registry" "golang.org/x/crypto/ssh" ) diff --git a/dialer/ssh/metadata.go b/dialer/ssh/metadata.go index a8fefa5..5ce1018 100644 --- a/dialer/ssh/metadata.go +++ b/dialer/ssh/metadata.go @@ -7,6 +7,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" "golang.org/x/crypto/ssh" ) @@ -24,7 +25,7 @@ func (d *sshDialer) parseMetadata(md mdata.Metadata) (err error) { passphrase = "passphrase" ) - if v := mdata.GetString(md, user); v != "" { + if v := mdx.GetString(md, user); v != "" { ss := strings.SplitN(v, ":", 2) if len(ss) == 1 { d.md.user = url.User(ss[0]) @@ -33,13 +34,13 @@ func (d *sshDialer) parseMetadata(md mdata.Metadata) (err error) { } } - if key := mdata.GetString(md, privateKeyFile); key != "" { + if key := mdx.GetString(md, privateKeyFile); key != "" { data, err := ioutil.ReadFile(key) if err != nil { return err } - pp := mdata.GetString(md, passphrase) + pp := mdx.GetString(md, passphrase) if pp == "" { d.md.signer, err = ssh.ParsePrivateKey(data) } else { @@ -50,7 +51,7 @@ func (d *sshDialer) parseMetadata(md mdata.Metadata) (err error) { } } - d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) + d.md.handshakeTimeout = mdx.GetDuration(md, handshakeTimeout) return } diff --git a/dialer/sshd/dialer.go b/dialer/sshd/dialer.go index e619f55..2282007 100644 --- a/dialer/sshd/dialer.go +++ b/dialer/sshd/dialer.go @@ -9,8 +9,8 @@ import ( "github.com/go-gost/core/dialer" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" ssh_util "github.com/go-gost/x/internal/util/ssh" + "github.com/go-gost/x/registry" "golang.org/x/crypto/ssh" ) diff --git a/dialer/sshd/metadata.go b/dialer/sshd/metadata.go index 1961334..d7841ea 100644 --- a/dialer/sshd/metadata.go +++ b/dialer/sshd/metadata.go @@ -5,6 +5,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" "golang.org/x/crypto/ssh" ) @@ -20,13 +21,13 @@ func (d *sshdDialer) parseMetadata(md mdata.Metadata) (err error) { passphrase = "passphrase" ) - if key := mdata.GetString(md, privateKeyFile); key != "" { + if key := mdx.GetString(md, privateKeyFile); key != "" { data, err := ioutil.ReadFile(key) if err != nil { return err } - pp := mdata.GetString(md, passphrase) + pp := mdx.GetString(md, passphrase) if pp == "" { d.md.signer, err = ssh.ParsePrivateKey(data) } else { @@ -37,7 +38,7 @@ func (d *sshdDialer) parseMetadata(md mdata.Metadata) (err error) { } } - d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) + d.md.handshakeTimeout = mdx.GetDuration(md, handshakeTimeout) return } diff --git a/dialer/tcp/dialer.go b/dialer/tcp/dialer.go new file mode 100644 index 0000000..16f19f3 --- /dev/null +++ b/dialer/tcp/dialer.go @@ -0,0 +1,48 @@ +package tcp + +import ( + "context" + "net" + + "github.com/go-gost/core/dialer" + "github.com/go-gost/core/logger" + md "github.com/go-gost/core/metadata" + "github.com/go-gost/x/registry" +) + +func init() { + registry.DialerRegistry().Register("tcp", NewDialer) +} + +type tcpDialer struct { + md metadata + logger logger.Logger +} + +func NewDialer(opts ...dialer.Option) dialer.Dialer { + options := &dialer.Options{} + for _, opt := range opts { + opt(options) + } + + return &tcpDialer{ + logger: options.Logger, + } +} + +func (d *tcpDialer) Init(md md.Metadata) (err error) { + return d.parseMetadata(md) +} + +func (d *tcpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) { + var options dialer.DialOptions + for _, opt := range opts { + opt(&options) + } + + conn, err := options.NetDialer.Dial(ctx, "tcp", addr) + if err != nil { + d.logger.Error(err) + } + return conn, err +} diff --git a/dialer/tcp/metadata.go b/dialer/tcp/metadata.go new file mode 100644 index 0000000..34e3789 --- /dev/null +++ b/dialer/tcp/metadata.go @@ -0,0 +1,23 @@ +package tcp + +import ( + "time" + + md "github.com/go-gost/core/metadata" +) + +const ( + dialTimeout = "dialTimeout" +) + +const ( + defaultDialTimeout = 5 * time.Second +) + +type metadata struct { + dialTimeout time.Duration +} + +func (d *tcpDialer) parseMetadata(md md.Metadata) (err error) { + return +} diff --git a/dialer/tls/dialer.go b/dialer/tls/dialer.go new file mode 100644 index 0000000..fd9967c --- /dev/null +++ b/dialer/tls/dialer.go @@ -0,0 +1,68 @@ +package tls + +import ( + "context" + "crypto/tls" + "net" + "time" + + "github.com/go-gost/core/dialer" + "github.com/go-gost/core/logger" + md "github.com/go-gost/core/metadata" + "github.com/go-gost/x/registry" +) + +func init() { + registry.DialerRegistry().Register("tls", NewDialer) +} + +type tlsDialer struct { + md metadata + logger logger.Logger + options dialer.Options +} + +func NewDialer(opts ...dialer.Option) dialer.Dialer { + options := dialer.Options{} + for _, opt := range opts { + opt(&options) + } + + return &tlsDialer{ + logger: options.Logger, + options: options, + } +} + +func (d *tlsDialer) Init(md md.Metadata) (err error) { + return d.parseMetadata(md) +} + +func (d *tlsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) { + var options dialer.DialOptions + for _, opt := range opts { + opt(&options) + } + + conn, err := options.NetDialer.Dial(ctx, "tcp", addr) + if err != nil { + d.logger.Error(err) + } + return conn, err +} + +// Handshake implements dialer.Handshaker +func (d *tlsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dialer.HandshakeOption) (net.Conn, error) { + if d.md.handshakeTimeout > 0 { + conn.SetDeadline(time.Now().Add(d.md.handshakeTimeout)) + defer conn.SetDeadline(time.Time{}) + } + + tlsConn := tls.Client(conn, d.options.TLSConfig) + if err := tlsConn.HandshakeContext(ctx); err != nil { + conn.Close() + return nil, err + } + + return tlsConn, nil +} diff --git a/dialer/tls/metadata.go b/dialer/tls/metadata.go new file mode 100644 index 0000000..21c8689 --- /dev/null +++ b/dialer/tls/metadata.go @@ -0,0 +1,22 @@ +package tls + +import ( + "time" + + mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" +) + +type metadata struct { + handshakeTimeout time.Duration +} + +func (d *tlsDialer) parseMetadata(md mdata.Metadata) (err error) { + const ( + handshakeTimeout = "handshakeTimeout" + ) + + d.md.handshakeTimeout = mdx.GetDuration(md, handshakeTimeout) + + return +} diff --git a/dialer/udp/conn.go b/dialer/udp/conn.go new file mode 100644 index 0000000..33e962d --- /dev/null +++ b/dialer/udp/conn.go @@ -0,0 +1,17 @@ +package udp + +import "net" + +type conn struct { + *net.UDPConn +} + +func (c *conn) WriteTo(b []byte, addr net.Addr) (int, error) { + return c.UDPConn.Write(b) +} + +func (c *conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + n, err = c.UDPConn.Read(b) + addr = c.RemoteAddr() + return +} diff --git a/dialer/udp/dialer.go b/dialer/udp/dialer.go new file mode 100644 index 0000000..42a8261 --- /dev/null +++ b/dialer/udp/dialer.go @@ -0,0 +1,50 @@ +package udp + +import ( + "context" + "net" + + "github.com/go-gost/core/dialer" + "github.com/go-gost/core/logger" + md "github.com/go-gost/core/metadata" + "github.com/go-gost/x/registry" +) + +func init() { + registry.DialerRegistry().Register("udp", NewDialer) +} + +type udpDialer struct { + md metadata + logger logger.Logger +} + +func NewDialer(opts ...dialer.Option) dialer.Dialer { + options := &dialer.Options{} + for _, opt := range opts { + opt(options) + } + + return &udpDialer{ + logger: options.Logger, + } +} + +func (d *udpDialer) Init(md md.Metadata) (err error) { + return d.parseMetadata(md) +} + +func (d *udpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) { + var options dialer.DialOptions + for _, opt := range opts { + opt(&options) + } + + c, err := options.NetDialer.Dial(ctx, "udp", addr) + if err != nil { + return nil, err + } + return &conn{ + UDPConn: c.(*net.UDPConn), + }, nil +} diff --git a/dialer/udp/metadata.go b/dialer/udp/metadata.go new file mode 100644 index 0000000..23bd5e2 --- /dev/null +++ b/dialer/udp/metadata.go @@ -0,0 +1,23 @@ +package udp + +import ( + "time" + + md "github.com/go-gost/core/metadata" +) + +const ( + dialTimeout = "dialTimeout" +) + +const ( + defaultDialTimeout = 5 * time.Second +) + +type metadata struct { + dialTimeout time.Duration +} + +func (d *udpDialer) parseMetadata(md md.Metadata) (err error) { + return +} diff --git a/dialer/ws/dialer.go b/dialer/ws/dialer.go index 60669ec..a29e696 100644 --- a/dialer/ws/dialer.go +++ b/dialer/ws/dialer.go @@ -8,8 +8,8 @@ import ( "github.com/go-gost/core/dialer" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" ws_util "github.com/go-gost/x/internal/util/ws" + "github.com/go-gost/x/registry" "github.com/gorilla/websocket" ) diff --git a/dialer/ws/metadata.go b/dialer/ws/metadata.go index b28a35c..0e04a86 100644 --- a/dialer/ws/metadata.go +++ b/dialer/ws/metadata.go @@ -5,6 +5,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) const ( @@ -40,27 +41,27 @@ func (d *wsDialer) parseMetadata(md mdata.Metadata) (err error) { keepAlive = "keepAlive" ) - d.md.host = mdata.GetString(md, host) + d.md.host = mdx.GetString(md, host) - d.md.path = mdata.GetString(md, path) + d.md.path = mdx.GetString(md, path) if d.md.path == "" { d.md.path = defaultPath } - d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) - d.md.readHeaderTimeout = mdata.GetDuration(md, readHeaderTimeout) - d.md.readBufferSize = mdata.GetInt(md, readBufferSize) - d.md.writeBufferSize = mdata.GetInt(md, writeBufferSize) - d.md.enableCompression = mdata.GetBool(md, enableCompression) + d.md.handshakeTimeout = mdx.GetDuration(md, handshakeTimeout) + d.md.readHeaderTimeout = mdx.GetDuration(md, readHeaderTimeout) + d.md.readBufferSize = mdx.GetInt(md, readBufferSize) + d.md.writeBufferSize = mdx.GetInt(md, writeBufferSize) + d.md.enableCompression = mdx.GetBool(md, enableCompression) - if m := mdata.GetStringMapString(md, header); len(m) > 0 { + if m := mdx.GetStringMapString(md, header); len(m) > 0 { h := http.Header{} for k, v := range m { h.Add(k, v) } d.md.header = h } - d.md.keepAlive = mdata.GetDuration(md, keepAlive) + d.md.keepAlive = mdx.GetDuration(md, keepAlive) return } diff --git a/go.mod b/go.mod index 11fabdf..1497557 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/docker/libcontainer v2.2.1+incompatible github.com/gin-contrib/cors v1.3.1 github.com/gin-gonic/gin v1.7.7 - github.com/go-gost/core v0.0.0-20220403142327-6340d5198f83 + github.com/go-gost/core v0.0.0-20220404033031-04f6ed470873 github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7 github.com/go-gost/tls-dissector v0.0.2-0.20211125135007-2b5d5bd9c07e diff --git a/go.sum b/go.sum index 9cd5509..c774851 100644 --- a/go.sum +++ b/go.sum @@ -119,6 +119,8 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2 github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gost/core v0.0.0-20220403142327-6340d5198f83 h1:Tt11K5yA/qnxSY8SDH774PSE/VP4Jqes+ab79g/N4Uw= github.com/go-gost/core v0.0.0-20220403142327-6340d5198f83/go.mod h1:oga1T7DJPJM+DpiQaZvTES9P9jvybRSgR/V5j+sEDpg= +github.com/go-gost/core v0.0.0-20220404033031-04f6ed470873 h1:u+g28xvN00bW5ivbhb2GGo0R+JIBy5arxy5R+rKesqk= +github.com/go-gost/core v0.0.0-20220404033031-04f6ed470873/go.mod h1:/LzdiQ+0+3FMhyqw0phjFjXFdOa1fcQR5/bL/7ripCs= github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgOuJkQ7QLfG0Hs5peWIUSysCDNz4pfe04= github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7 h1:itaaJhQJ19kUXEB4Igb0EbY8m+1Py2AaNNSBds/9gk4= diff --git a/handler/auto/handler.go b/handler/auto/handler.go new file mode 100644 index 0000000..573a705 --- /dev/null +++ b/handler/auto/handler.go @@ -0,0 +1,115 @@ +package auto + +import ( + "bufio" + "context" + "net" + "time" + + netpkg "github.com/go-gost/core/common/net" + "github.com/go-gost/core/handler" + md "github.com/go-gost/core/metadata" + "github.com/go-gost/gosocks4" + "github.com/go-gost/gosocks5" + "github.com/go-gost/x/registry" +) + +func init() { + registry.HandlerRegistry().Register("auto", NewHandler) +} + +type autoHandler struct { + httpHandler handler.Handler + socks4Handler handler.Handler + socks5Handler handler.Handler + options handler.Options +} + +func NewHandler(opts ...handler.Option) handler.Handler { + options := handler.Options{} + for _, opt := range opts { + opt(&options) + } + + h := &autoHandler{ + options: options, + } + + if f := registry.HandlerRegistry().Get("http"); f != nil { + v := append(opts, + handler.LoggerOption(options.Logger.WithFields(map[string]any{"type": "http"}))) + h.httpHandler = f(v...) + } + if f := registry.HandlerRegistry().Get("socks4"); f != nil { + v := append(opts, + handler.LoggerOption(options.Logger.WithFields(map[string]any{"type": "socks4"}))) + h.socks4Handler = f(v...) + } + if f := registry.HandlerRegistry().Get("socks5"); f != nil { + v := append(opts, + handler.LoggerOption(options.Logger.WithFields(map[string]any{"type": "socks5"}))) + h.socks5Handler = f(v...) + } + + return h +} + +func (h *autoHandler) Init(md md.Metadata) error { + if h.httpHandler != nil { + if err := h.httpHandler.Init(md); err != nil { + return err + } + } + if h.socks4Handler != nil { + if err := h.socks4Handler.Init(md); err != nil { + return err + } + } + if h.socks5Handler != nil { + if err := h.socks5Handler.Init(md); err != nil { + return err + } + } + + return nil +} + +func (h *autoHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error { + log := h.options.Logger.WithFields(map[string]any{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }) + + start := time.Now() + log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + defer func() { + log.WithFields(map[string]any{ + "duration": time.Since(start), + }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) + }() + + br := bufio.NewReader(conn) + b, err := br.Peek(1) + if err != nil { + log.Error(err) + conn.Close() + return err + } + + conn = netpkg.NewBufferReaderConn(conn, br) + switch b[0] { + case gosocks4.Ver4: // socks4 + if h.socks4Handler != nil { + return h.socks4Handler.Handle(ctx, conn) + } + case gosocks5.Ver5: // socks5 + if h.socks5Handler != nil { + return h.socks5Handler.Handle(ctx, conn) + } + default: // http + if h.httpHandler != nil { + return h.httpHandler.Handle(ctx, conn) + } + } + return nil +} diff --git a/handler/dns/handler.go b/handler/dns/handler.go index d0ecd3b..9d75afd 100644 --- a/handler/dns/handler.go +++ b/handler/dns/handler.go @@ -12,13 +12,13 @@ import ( "github.com/go-gost/core/chain" "github.com/go-gost/core/common/bufpool" - resolver_util "github.com/go-gost/core/common/util/resolver" "github.com/go-gost/core/handler" "github.com/go-gost/core/hosts" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" - "github.com/go-gost/core/resolver/exchanger" + resolver_util "github.com/go-gost/x/internal/util/resolver" + "github.com/go-gost/x/registry" + "github.com/go-gost/x/resolver/exchanger" "github.com/miekg/dns" ) diff --git a/handler/dns/metadata.go b/handler/dns/metadata.go index e69b105..75eed2b 100644 --- a/handler/dns/metadata.go +++ b/handler/dns/metadata.go @@ -5,9 +5,11 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) const ( + defaultTimeout = 5 * time.Second defaultBufferSize = 1024 ) @@ -29,17 +31,17 @@ func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) { dns = "dns" ) - h.md.readTimeout = mdata.GetDuration(md, readTimeout) - h.md.ttl = mdata.GetDuration(md, ttl) - h.md.timeout = mdata.GetDuration(md, timeout) + h.md.readTimeout = mdx.GetDuration(md, readTimeout) + h.md.ttl = mdx.GetDuration(md, ttl) + h.md.timeout = mdx.GetDuration(md, timeout) if h.md.timeout <= 0 { - h.md.timeout = 5 * time.Second + h.md.timeout = defaultTimeout } - sip := mdata.GetString(md, clientIP) + sip := mdx.GetString(md, clientIP) if sip != "" { h.md.clientIP = net.ParseIP(sip) } - h.md.dns = mdata.GetStrings(md, dns) + h.md.dns = mdx.GetStrings(md, dns) return } diff --git a/handler/forward/local/handler.go b/handler/forward/local/handler.go new file mode 100644 index 0000000..47d8ba8 --- /dev/null +++ b/handler/forward/local/handler.go @@ -0,0 +1,117 @@ +package local + +import ( + "context" + "errors" + "fmt" + "net" + "time" + + "github.com/go-gost/core/chain" + netpkg "github.com/go-gost/core/common/net" + "github.com/go-gost/core/handler" + md "github.com/go-gost/core/metadata" + "github.com/go-gost/x/registry" +) + +func init() { + registry.HandlerRegistry().Register("tcp", NewHandler) + registry.HandlerRegistry().Register("udp", NewHandler) + registry.HandlerRegistry().Register("forward", NewHandler) +} + +type forwardHandler struct { + group *chain.NodeGroup + router *chain.Router + md metadata + options handler.Options +} + +func NewHandler(opts ...handler.Option) handler.Handler { + options := handler.Options{} + for _, opt := range opts { + opt(&options) + } + + return &forwardHandler{ + options: options, + } +} + +func (h *forwardHandler) Init(md md.Metadata) (err error) { + if err = h.parseMetadata(md); err != nil { + return + } + + if h.group == nil { + // dummy node used by relay connector. + h.group = chain.NewNodeGroup(&chain.Node{Name: "dummy", Addr: ":0"}) + } + + h.router = h.options.Router + if h.router == nil { + h.router = (&chain.Router{}).WithLogger(h.options.Logger) + } + + return +} + +// Forward implements handler.Forwarder. +func (h *forwardHandler) Forward(group *chain.NodeGroup) { + h.group = group +} + +func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error { + defer conn.Close() + + start := time.Now() + log := h.options.Logger.WithFields(map[string]any{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }) + + log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + defer func() { + log.WithFields(map[string]any{ + "duration": time.Since(start), + }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) + }() + + target := h.group.Next() + if target == nil { + err := errors.New("target not available") + log.Error(err) + return err + } + + network := "tcp" + if _, ok := conn.(net.PacketConn); ok { + network = "udp" + } + + log = log.WithFields(map[string]any{ + "dst": fmt.Sprintf("%s/%s", target.Addr, network), + }) + + log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr) + + cc, err := h.router.Dial(ctx, network, target.Addr) + if err != nil { + log.Error(err) + // TODO: the router itself may be failed due to the failed node in the router, + // the dead marker may be a wrong operation. + target.Marker.Mark() + return err + } + defer cc.Close() + target.Marker.Reset() + + t := time.Now() + log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr) + netpkg.Transport(conn, cc) + log.WithFields(map[string]any{ + "duration": time.Since(t), + }).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr) + + return nil +} diff --git a/handler/forward/local/metadata.go b/handler/forward/local/metadata.go new file mode 100644 index 0000000..6f19074 --- /dev/null +++ b/handler/forward/local/metadata.go @@ -0,0 +1,21 @@ +package local + +import ( + "time" + + mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" +) + +type metadata struct { + readTimeout time.Duration +} + +func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) { + const ( + readTimeout = "readTimeout" + ) + + h.md.readTimeout = mdx.GetDuration(md, readTimeout) + return +} diff --git a/handler/forward/remote/handler.go b/handler/forward/remote/handler.go new file mode 100644 index 0000000..14d2760 --- /dev/null +++ b/handler/forward/remote/handler.go @@ -0,0 +1,111 @@ +package remote + +import ( + "context" + "errors" + "fmt" + "net" + "time" + + "github.com/go-gost/core/chain" + netpkg "github.com/go-gost/core/common/net" + "github.com/go-gost/core/handler" + md "github.com/go-gost/core/metadata" + "github.com/go-gost/x/registry" +) + +func init() { + registry.HandlerRegistry().Register("rtcp", NewHandler) + registry.HandlerRegistry().Register("rudp", NewHandler) +} + +type forwardHandler struct { + group *chain.NodeGroup + router *chain.Router + md metadata + options handler.Options +} + +func NewHandler(opts ...handler.Option) handler.Handler { + options := handler.Options{} + for _, opt := range opts { + opt(&options) + } + + return &forwardHandler{ + options: options, + } +} + +func (h *forwardHandler) Init(md md.Metadata) (err error) { + if err = h.parseMetadata(md); err != nil { + return + } + + h.router = h.options.Router + if h.router == nil { + h.router = (&chain.Router{}).WithLogger(h.options.Logger) + } + + return +} + +// Forward implements handler.Forwarder. +func (h *forwardHandler) Forward(group *chain.NodeGroup) { + h.group = group +} + +func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error { + defer conn.Close() + + start := time.Now() + log := h.options.Logger.WithFields(map[string]any{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }) + + log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + defer func() { + log.WithFields(map[string]any{ + "duration": time.Since(start), + }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) + }() + + target := h.group.Next() + if target == nil { + err := errors.New("target not available") + log.Error(err) + return err + } + + network := "tcp" + if _, ok := conn.(net.PacketConn); ok { + network = "udp" + } + + log = log.WithFields(map[string]any{ + "dst": fmt.Sprintf("%s/%s", target.Addr, network), + }) + + log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr) + + cc, err := h.router.Dial(ctx, network, target.Addr) + if err != nil { + log.Error(err) + // TODO: the router itself may be failed due to the failed node in the router, + // the dead marker may be a wrong operation. + target.Marker.Mark() + return err + } + defer cc.Close() + target.Marker.Reset() + + t := time.Now() + log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr) + netpkg.Transport(conn, cc) + log.WithFields(map[string]any{ + "duration": time.Since(t), + }).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr) + + return nil +} diff --git a/handler/forward/remote/metadata.go b/handler/forward/remote/metadata.go new file mode 100644 index 0000000..97eadd3 --- /dev/null +++ b/handler/forward/remote/metadata.go @@ -0,0 +1,21 @@ +package remote + +import ( + "time" + + mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" +) + +type metadata struct { + readTimeout time.Duration +} + +func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) { + const ( + readTimeout = "readTimeout" + ) + + h.md.readTimeout = mdx.GetDuration(md, readTimeout) + return +} diff --git a/handler/http/handler.go b/handler/http/handler.go new file mode 100644 index 0000000..2745e5a --- /dev/null +++ b/handler/http/handler.go @@ -0,0 +1,337 @@ +package http + +import ( + "bufio" + "context" + "encoding/base64" + "encoding/binary" + "errors" + "hash/crc32" + "net" + "net/http" + "net/http/httputil" + "os" + "strconv" + "strings" + "time" + + "github.com/asaskevich/govalidator" + "github.com/go-gost/core/chain" + netpkg "github.com/go-gost/core/common/net" + "github.com/go-gost/core/handler" + "github.com/go-gost/core/logger" + md "github.com/go-gost/core/metadata" + "github.com/go-gost/x/registry" +) + +func init() { + registry.HandlerRegistry().Register("http", NewHandler) +} + +type httpHandler struct { + router *chain.Router + md metadata + options handler.Options +} + +func NewHandler(opts ...handler.Option) handler.Handler { + options := handler.Options{} + for _, opt := range opts { + opt(&options) + } + + return &httpHandler{ + options: options, + } +} + +func (h *httpHandler) Init(md md.Metadata) error { + if err := h.parseMetadata(md); err != nil { + return err + } + + h.router = h.options.Router + if h.router == nil { + h.router = (&chain.Router{}).WithLogger(h.options.Logger) + } + + return nil +} + +func (h *httpHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error { + defer conn.Close() + + start := time.Now() + log := h.options.Logger.WithFields(map[string]any{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }) + log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + defer func() { + log.WithFields(map[string]any{ + "duration": time.Since(start), + }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) + }() + + req, err := http.ReadRequest(bufio.NewReader(conn)) + if err != nil { + log.Error(err) + return err + } + defer req.Body.Close() + + return h.handleRequest(ctx, conn, req, log) +} + +func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *http.Request, log logger.Logger) error { + if h.md.sni && !req.URL.IsAbs() && govalidator.IsDNSName(req.Host) { + req.URL.Scheme = "http" + } + + network := req.Header.Get("X-Gost-Protocol") + if network != "udp" { + network = "tcp" + } + + // Try to get the actual host. + // Compatible with GOST 2.x. + if v := req.Header.Get("Gost-Target"); v != "" { + if h, err := h.decodeServerName(v); err == nil { + req.Host = h + } + } + req.Header.Del("Gost-Target") + + if v := req.Header.Get("X-Gost-Target"); v != "" { + if h, err := h.decodeServerName(v); err == nil { + req.Host = h + } + } + req.Header.Del("X-Gost-Target") + + addr := req.Host + if _, port, _ := net.SplitHostPort(addr); port == "" { + addr = net.JoinHostPort(addr, "80") + } + + fields := map[string]any{ + "dst": addr, + } + if u, _, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization"), log); u != "" { + fields["user"] = u + } + log = log.WithFields(fields) + + if log.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpRequest(req, false) + log.Debug(string(dump)) + } + log.Infof("%s >> %s", conn.RemoteAddr(), addr) + + resp := &http.Response{ + ProtoMajor: 1, + ProtoMinor: 1, + Header: h.md.header, + } + if resp.Header == nil { + resp.Header = http.Header{} + } + + if h.options.Bypass != nil && h.options.Bypass.Contains(addr) { + resp.StatusCode = http.StatusForbidden + + if log.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(resp, false) + log.Debug(string(dump)) + } + log.Info("bypass: ", addr) + + return resp.Write(conn) + } + + if !h.authenticate(conn, req, resp, log) { + return nil + } + + if network == "udp" { + return h.handleUDP(ctx, conn, log) + } + + if req.Method == "PRI" || + (req.Method != http.MethodConnect && req.URL.Scheme != "http") { + resp.StatusCode = http.StatusBadRequest + + if log.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(resp, false) + log.Debug(string(dump)) + } + + return resp.Write(conn) + } + + req.Header.Del("Proxy-Authorization") + + cc, err := h.router.Dial(ctx, network, addr) + if err != nil { + resp.StatusCode = http.StatusServiceUnavailable + + if log.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(resp, false) + log.Debug(string(dump)) + } + resp.Write(conn) + return err + } + defer cc.Close() + + if req.Method == http.MethodConnect { + resp.StatusCode = http.StatusOK + resp.Status = "200 Connection established" + + if log.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(resp, false) + log.Debug(string(dump)) + } + if err = resp.Write(conn); err != nil { + log.Error(err) + return err + } + } else { + req.Header.Del("Proxy-Connection") + if err = req.Write(cc); err != nil { + log.Error(err) + return err + } + } + + start := time.Now() + log.Infof("%s <-> %s", conn.RemoteAddr(), addr) + netpkg.Transport(conn, cc) + log.WithFields(map[string]any{ + "duration": time.Since(start), + }).Infof("%s >-< %s", conn.RemoteAddr(), addr) + + return nil +} + +func (h *httpHandler) decodeServerName(s string) (string, error) { + b, err := base64.RawURLEncoding.DecodeString(s) + if err != nil { + return "", err + } + if len(b) < 4 { + return "", errors.New("invalid name") + } + v, err := base64.RawURLEncoding.DecodeString(string(b[4:])) + if err != nil { + return "", err + } + if crc32.ChecksumIEEE(v) != binary.BigEndian.Uint32(b[:4]) { + return "", errors.New("invalid name") + } + return string(v), nil +} + +func (h *httpHandler) basicProxyAuth(proxyAuth string, log logger.Logger) (username, password string, ok bool) { + if proxyAuth == "" { + return + } + + if !strings.HasPrefix(proxyAuth, "Basic ") { + return + } + c, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(proxyAuth, "Basic ")) + if err != nil { + return + } + cs := string(c) + s := strings.IndexByte(cs, ':') + if s < 0 { + return + } + + return cs[:s], cs[s+1:], true +} + +func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.Response, log logger.Logger) (ok bool) { + u, p, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization"), log) + if h.options.Auther == nil || h.options.Auther.Authenticate(u, p) { + return true + } + + pr := h.md.probeResistance + // probing resistance is enabled, and knocking host is mismatch. + if pr != nil && (pr.Knock == "" || !strings.EqualFold(req.URL.Hostname(), pr.Knock)) { + resp.StatusCode = http.StatusServiceUnavailable // default status code + + switch pr.Type { + case "code": + resp.StatusCode, _ = strconv.Atoi(pr.Value) + case "web": + url := pr.Value + if !strings.HasPrefix(url, "http") { + url = "http://" + url + } + r, err := http.Get(url) + if err != nil { + log.Error(err) + break + } + resp = r + defer resp.Body.Close() + case "host": + cc, err := net.Dial("tcp", pr.Value) + if err != nil { + log.Error(err) + break + } + defer cc.Close() + + req.Write(cc) + netpkg.Transport(conn, cc) + return + case "file": + f, _ := os.Open(pr.Value) + if f != nil { + defer f.Close() + + resp.StatusCode = http.StatusOK + if finfo, _ := f.Stat(); finfo != nil { + resp.ContentLength = finfo.Size() + } + resp.Header.Set("Content-Type", "text/html") + resp.Body = f + } + } + } + + if resp.Header == nil { + resp.Header = http.Header{} + } + if resp.StatusCode == 0 { + resp.StatusCode = http.StatusProxyAuthRequired + resp.Header.Add("Proxy-Authenticate", "Basic realm=\"gost\"") + if strings.ToLower(req.Header.Get("Proxy-Connection")) == "keep-alive" { + // XXX libcurl will keep sending auth request in same conn + // which we don't supported yet. + resp.Header.Add("Connection", "close") + resp.Header.Add("Proxy-Connection", "close") + } + + log.Info("proxy authentication required") + } else { + resp.Header.Set("Server", "nginx/1.20.1") + resp.Header.Set("Date", time.Now().Format(http.TimeFormat)) + if resp.StatusCode == http.StatusOK { + resp.Header.Set("Connection", "keep-alive") + } + } + + if log.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(resp, false) + log.Debug(string(dump)) + } + + resp.Write(conn) + return +} diff --git a/handler/http/metadata.go b/handler/http/metadata.go new file mode 100644 index 0000000..1509796 --- /dev/null +++ b/handler/http/metadata.go @@ -0,0 +1,54 @@ +package http + +import ( + "net/http" + "strings" + + mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" +) + +type metadata struct { + probeResistance *probeResistance + sni bool + enableUDP bool + header http.Header +} + +func (h *httpHandler) parseMetadata(md mdata.Metadata) error { + const ( + header = "header" + probeResistKey = "probeResistance" + knock = "knock" + sni = "sni" + enableUDP = "udp" + ) + + if m := mdx.GetStringMapString(md, header); len(m) > 0 { + hd := http.Header{} + for k, v := range m { + hd.Add(k, v) + } + h.md.header = hd + } + + if v := mdx.GetString(md, probeResistKey); v != "" { + if ss := strings.SplitN(v, ":", 2); len(ss) == 2 { + h.md.probeResistance = &probeResistance{ + Type: ss[0], + Value: ss[1], + Knock: mdx.GetString(md, knock), + } + } + } + h.md.sni = mdx.GetBool(md, sni) + h.md.enableUDP = mdx.GetBool(md, enableUDP) + + return nil +} + +type probeResistance struct { + Type string + Value string + Knock string +} diff --git a/handler/http/udp.go b/handler/http/udp.go new file mode 100644 index 0000000..d02413c --- /dev/null +++ b/handler/http/udp.go @@ -0,0 +1,80 @@ +package http + +import ( + "context" + "errors" + "net" + "net/http" + "net/http/httputil" + "time" + + "github.com/go-gost/core/common/net/relay" + "github.com/go-gost/core/logger" + "github.com/go-gost/x/internal/util/socks" +) + +func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, log logger.Logger) error { + log = log.WithFields(map[string]any{ + "cmd": "udp", + }) + + resp := &http.Response{ + ProtoMajor: 1, + ProtoMinor: 1, + Header: h.md.header, + } + if resp.Header == nil { + resp.Header = http.Header{} + } + + if !h.md.enableUDP { + resp.StatusCode = http.StatusForbidden + + if log.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(resp, false) + log.Debug(string(dump)) + } + + log.Error("http: UDP relay is disabled") + + return resp.Write(conn) + } + + resp.StatusCode = http.StatusOK + if log.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(resp, false) + log.Debug(string(dump)) + } + if err := resp.Write(conn); err != nil { + log.Error(err) + return err + } + + // obtain a udp connection + c, err := h.router.Dial(ctx, "udp", "") // UDP association + if err != nil { + log.Error(err) + return err + } + defer c.Close() + + pc, ok := c.(net.PacketConn) + if !ok { + err = errors.New("wrong connection type") + log.Error(err) + return err + } + + relay := relay.NewUDPRelay(socks.UDPTunServerConn(conn), pc). + WithBypass(h.options.Bypass). + WithLogger(log) + + t := time.Now() + log.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr()) + relay.Run() + log.WithFields(map[string]any{ + "duration": time.Since(t), + }).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) + + return nil +} diff --git a/handler/http2/handler.go b/handler/http2/handler.go index 217cb4a..11dc348 100644 --- a/handler/http2/handler.go +++ b/handler/http2/handler.go @@ -23,7 +23,7 @@ import ( "github.com/go-gost/core/handler" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" + "github.com/go-gost/x/registry" ) func init() { diff --git a/handler/http2/metadata.go b/handler/http2/metadata.go index 3679514..92e1fe1 100644 --- a/handler/http2/metadata.go +++ b/handler/http2/metadata.go @@ -5,6 +5,7 @@ import ( "strings" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) type metadata struct { @@ -19,7 +20,7 @@ func (h *http2Handler) parseMetadata(md mdata.Metadata) error { knock = "knock" ) - if m := mdata.GetStringMapString(md, header); len(m) > 0 { + if m := mdx.GetStringMapString(md, header); len(m) > 0 { hd := http.Header{} for k, v := range m { hd.Add(k, v) @@ -27,12 +28,12 @@ func (h *http2Handler) parseMetadata(md mdata.Metadata) error { h.md.header = hd } - if v := mdata.GetString(md, probeResistKey); v != "" { + if v := mdx.GetString(md, probeResistKey); v != "" { if ss := strings.SplitN(v, ":", 2); len(ss) == 2 { h.md.probeResistance = &probeResistance{ Type: ss[0], Value: ss[1], - Knock: mdata.GetString(md, knock), + Knock: mdx.GetString(md, knock), } } } diff --git a/handler/redirect/tcp/handler.go b/handler/redirect/tcp/handler.go index e1b959c..2cabbc5 100644 --- a/handler/redirect/tcp/handler.go +++ b/handler/redirect/tcp/handler.go @@ -18,8 +18,8 @@ import ( "github.com/go-gost/core/handler" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" dissector "github.com/go-gost/tls-dissector" + "github.com/go-gost/x/registry" ) func init() { diff --git a/handler/redirect/tcp/metadata.go b/handler/redirect/tcp/metadata.go index 6c707f3..0e6f8da 100644 --- a/handler/redirect/tcp/metadata.go +++ b/handler/redirect/tcp/metadata.go @@ -2,6 +2,7 @@ package redirect import ( mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) type metadata struct { @@ -14,7 +15,7 @@ func (h *redirectHandler) parseMetadata(md mdata.Metadata) (err error) { sniffing = "sniffing" tproxy = "tproxy" ) - h.md.sniffing = mdata.GetBool(md, sniffing) - h.md.tproxy = mdata.GetBool(md, tproxy) + h.md.sniffing = mdx.GetBool(md, sniffing) + h.md.tproxy = mdx.GetBool(md, tproxy) return } diff --git a/handler/redirect/udp/handler.go b/handler/redirect/udp/handler.go index f92b9b6..f9fe7b5 100644 --- a/handler/redirect/udp/handler.go +++ b/handler/redirect/udp/handler.go @@ -10,7 +10,7 @@ import ( netpkg "github.com/go-gost/core/common/net" "github.com/go-gost/core/handler" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" + "github.com/go-gost/x/registry" ) func init() { diff --git a/handler/relay/handler.go b/handler/relay/handler.go index 51db3d0..d063f39 100644 --- a/handler/relay/handler.go +++ b/handler/relay/handler.go @@ -10,8 +10,8 @@ import ( "github.com/go-gost/core/chain" "github.com/go-gost/core/handler" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" "github.com/go-gost/relay" + "github.com/go-gost/x/registry" ) var ( diff --git a/handler/relay/metadata.go b/handler/relay/metadata.go index c34c9c2..c3b5fb1 100644 --- a/handler/relay/metadata.go +++ b/handler/relay/metadata.go @@ -5,6 +5,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) type metadata struct { @@ -22,11 +23,11 @@ func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) { noDelay = "nodelay" ) - h.md.readTimeout = mdata.GetDuration(md, readTimeout) - h.md.enableBind = mdata.GetBool(md, enableBind) - h.md.noDelay = mdata.GetBool(md, noDelay) + h.md.readTimeout = mdx.GetDuration(md, readTimeout) + h.md.enableBind = mdx.GetBool(md, enableBind) + h.md.noDelay = mdx.GetBool(md, noDelay) - if bs := mdata.GetInt(md, udpBufferSize); bs > 0 { + if bs := mdx.GetInt(md, udpBufferSize); bs > 0 { h.md.udpBufferSize = int(math.Min(math.Max(float64(bs), 512), 64*1024)) } else { h.md.udpBufferSize = 1500 diff --git a/handler/sni/handler.go b/handler/sni/handler.go index 3fc344e..da1c2e9 100644 --- a/handler/sni/handler.go +++ b/handler/sni/handler.go @@ -16,8 +16,8 @@ import ( netpkg "github.com/go-gost/core/common/net" "github.com/go-gost/core/handler" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" dissector "github.com/go-gost/tls-dissector" + "github.com/go-gost/x/registry" ) func init() { diff --git a/handler/sni/metadata.go b/handler/sni/metadata.go index 3759d7d..486f529 100644 --- a/handler/sni/metadata.go +++ b/handler/sni/metadata.go @@ -4,6 +4,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) type metadata struct { @@ -15,6 +16,6 @@ func (h *sniHandler) parseMetadata(md mdata.Metadata) (err error) { readTimeout = "readTimeout" ) - h.md.readTimeout = mdata.GetDuration(md, readTimeout) + h.md.readTimeout = mdx.GetDuration(md, readTimeout) return } diff --git a/handler/socks/v4/handler.go b/handler/socks/v4/handler.go new file mode 100644 index 0000000..e9be353 --- /dev/null +++ b/handler/socks/v4/handler.go @@ -0,0 +1,152 @@ +package v4 + +import ( + "context" + "errors" + "net" + "time" + + "github.com/go-gost/core/chain" + netpkg "github.com/go-gost/core/common/net" + "github.com/go-gost/core/handler" + "github.com/go-gost/core/logger" + md "github.com/go-gost/core/metadata" + "github.com/go-gost/gosocks4" + "github.com/go-gost/x/registry" +) + +var ( + ErrUnknownCmd = errors.New("socks4: unknown command") + ErrUnimplemented = errors.New("socks4: unimplemented") +) + +func init() { + registry.HandlerRegistry().Register("socks4", NewHandler) + registry.HandlerRegistry().Register("socks4a", NewHandler) +} + +type socks4Handler struct { + router *chain.Router + md metadata + options handler.Options +} + +func NewHandler(opts ...handler.Option) handler.Handler { + options := handler.Options{} + for _, opt := range opts { + opt(&options) + } + + return &socks4Handler{ + options: options, + } +} + +func (h *socks4Handler) Init(md md.Metadata) (err error) { + if err := h.parseMetadata(md); err != nil { + return err + } + + h.router = h.options.Router + if h.router == nil { + h.router = (&chain.Router{}).WithLogger(h.options.Logger) + } + + return nil +} + +func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error { + defer conn.Close() + + start := time.Now() + + log := h.options.Logger.WithFields(map[string]any{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }) + + log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + defer func() { + log.WithFields(map[string]any{ + "duration": time.Since(start), + }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) + }() + + if h.md.readTimeout > 0 { + conn.SetReadDeadline(time.Now().Add(h.md.readTimeout)) + } + + req, err := gosocks4.ReadRequest(conn) + if err != nil { + log.Error(err) + return err + } + log.Debug(req) + + conn.SetReadDeadline(time.Time{}) + + if h.options.Auther != nil && + !h.options.Auther.Authenticate(string(req.Userid), "") { + resp := gosocks4.NewReply(gosocks4.RejectedUserid, nil) + log.Debug(resp) + return resp.Write(conn) + } + + switch req.Cmd { + case gosocks4.CmdConnect: + return h.handleConnect(ctx, conn, req, log) + case gosocks4.CmdBind: + return h.handleBind(ctx, conn, req) + default: + err = ErrUnknownCmd + log.Error(err) + return err + } +} + +func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *gosocks4.Request, log logger.Logger) error { + addr := req.Addr.String() + + log = log.WithFields(map[string]any{ + "dst": addr, + }) + log.Infof("%s >> %s", conn.RemoteAddr(), addr) + + if h.options.Bypass != nil && h.options.Bypass.Contains(addr) { + resp := gosocks4.NewReply(gosocks4.Rejected, nil) + log.Debug(resp) + log.Info("bypass: ", addr) + return resp.Write(conn) + } + + cc, err := h.router.Dial(ctx, "tcp", addr) + if err != nil { + resp := gosocks4.NewReply(gosocks4.Failed, nil) + resp.Write(conn) + log.Debug(resp) + return err + } + + defer cc.Close() + + resp := gosocks4.NewReply(gosocks4.Granted, nil) + if err := resp.Write(conn); err != nil { + log.Error(err) + return err + } + log.Debug(resp) + + t := time.Now() + log.Infof("%s <-> %s", conn.RemoteAddr(), addr) + netpkg.Transport(conn, cc) + log.WithFields(map[string]any{ + "duration": time.Since(t), + }).Infof("%s >-< %s", conn.RemoteAddr(), addr) + + return nil +} + +func (h *socks4Handler) handleBind(ctx context.Context, conn net.Conn, req *gosocks4.Request) error { + // TODO: bind + return ErrUnimplemented +} diff --git a/handler/socks/v4/metadata.go b/handler/socks/v4/metadata.go new file mode 100644 index 0000000..b36d582 --- /dev/null +++ b/handler/socks/v4/metadata.go @@ -0,0 +1,21 @@ +package v4 + +import ( + "time" + + mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" +) + +type metadata struct { + readTimeout time.Duration +} + +func (h *socks4Handler) parseMetadata(md mdata.Metadata) (err error) { + const ( + readTimeout = "readTimeout" + ) + + h.md.readTimeout = mdx.GetDuration(md, readTimeout) + return +} diff --git a/handler/socks/v5/bind.go b/handler/socks/v5/bind.go new file mode 100644 index 0000000..89299fc --- /dev/null +++ b/handler/socks/v5/bind.go @@ -0,0 +1,149 @@ +package v5 + +import ( + "context" + "fmt" + "net" + "time" + + netpkg "github.com/go-gost/core/common/net" + "github.com/go-gost/core/logger" + "github.com/go-gost/gosocks5" +) + +func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error { + log = log.WithFields(map[string]any{ + "dst": fmt.Sprintf("%s/%s", address, network), + "cmd": "bind", + }) + + log.Infof("%s >> %s", conn.RemoteAddr(), address) + + if !h.md.enableBind { + reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) + log.Debug(reply) + log.Error("socks5: BIND is disabled") + return reply.Write(conn) + } + + // BIND does not support chain. + return h.bindLocal(ctx, conn, network, address, log) +} + +func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error { + ln, err := net.Listen(network, address) // strict mode: if the port already in use, it will return error + if err != nil { + log.Error(err) + reply := gosocks5.NewReply(gosocks5.Failure, nil) + if err := reply.Write(conn); err != nil { + log.Error(err) + } + log.Debug(reply) + return err + } + + socksAddr := gosocks5.Addr{} + if err := socksAddr.ParseFrom(ln.Addr().String()); err != nil { + log.Warn(err) + } + + // Issue: may not reachable when host has multi-interface + socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) + socksAddr.Type = 0 + reply := gosocks5.NewReply(gosocks5.Succeeded, &socksAddr) + if err := reply.Write(conn); err != nil { + log.Error(err) + ln.Close() + return err + } + log.Debug(reply) + + log = log.WithFields(map[string]any{ + "bind": fmt.Sprintf("%s/%s", ln.Addr(), ln.Addr().Network()), + }) + + log.Debugf("bind on %s OK", ln.Addr()) + + h.serveBind(ctx, conn, ln, log) + return nil +} + +func (h *socks5Handler) serveBind(ctx context.Context, conn net.Conn, ln net.Listener, log logger.Logger) { + var rc net.Conn + accept := func() <-chan error { + errc := make(chan error, 1) + + go func() { + defer close(errc) + defer ln.Close() + + c, err := ln.Accept() + if err != nil { + errc <- err + } + rc = c + }() + + return errc + } + + pc1, pc2 := net.Pipe() + pipe := func() <-chan error { + errc := make(chan error, 1) + + go func() { + defer close(errc) + defer pc1.Close() + + errc <- netpkg.Transport(conn, pc1) + }() + + return errc + } + + defer pc2.Close() + + select { + case err := <-accept(): + if err != nil { + log.Error(err) + + reply := gosocks5.NewReply(gosocks5.Failure, nil) + if err := reply.Write(pc2); err != nil { + log.Error(err) + } + log.Debug(reply) + + return + } + defer rc.Close() + + log.Debugf("peer %s accepted", rc.RemoteAddr()) + + log = log.WithFields(map[string]any{ + "local": rc.LocalAddr().String(), + "remote": rc.RemoteAddr().String(), + }) + + raddr := gosocks5.Addr{} + raddr.ParseFrom(rc.RemoteAddr().String()) + reply := gosocks5.NewReply(gosocks5.Succeeded, &raddr) + if err := reply.Write(pc2); err != nil { + log.Error(err) + } + log.Debug(reply) + + start := time.Now() + log.Infof("%s <-> %s", rc.LocalAddr(), rc.RemoteAddr()) + netpkg.Transport(pc2, rc) + log.WithFields(map[string]any{"duration": time.Since(start)}). + Infof("%s >-< %s", rc.LocalAddr(), rc.RemoteAddr()) + + case err := <-pipe(): + if err != nil { + log.Error(err) + } + ln.Close() + return + } +} diff --git a/handler/socks/v5/connect.go b/handler/socks/v5/connect.go new file mode 100644 index 0000000..88a8deb --- /dev/null +++ b/handler/socks/v5/connect.go @@ -0,0 +1,53 @@ +package v5 + +import ( + "context" + "fmt" + "net" + "time" + + netpkg "github.com/go-gost/core/common/net" + "github.com/go-gost/core/logger" + "github.com/go-gost/gosocks5" +) + +func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error { + log = log.WithFields(map[string]any{ + "dst": fmt.Sprintf("%s/%s", address, network), + "cmd": "connect", + }) + log.Infof("%s >> %s", conn.RemoteAddr(), address) + + if h.options.Bypass != nil && h.options.Bypass.Contains(address) { + resp := gosocks5.NewReply(gosocks5.NotAllowed, nil) + log.Debug(resp) + log.Info("bypass: ", address) + return resp.Write(conn) + } + + cc, err := h.router.Dial(ctx, network, address) + if err != nil { + resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil) + log.Debug(resp) + resp.Write(conn) + return err + } + + defer cc.Close() + + resp := gosocks5.NewReply(gosocks5.Succeeded, nil) + if err := resp.Write(conn); err != nil { + log.Error(err) + return err + } + log.Debug(resp) + + t := time.Now() + log.Infof("%s <-> %s", conn.RemoteAddr(), address) + netpkg.Transport(conn, cc) + log.WithFields(map[string]any{ + "duration": time.Since(t), + }).Infof("%s >-< %s", conn.RemoteAddr(), address) + + return nil +} diff --git a/handler/socks/v5/handler.go b/handler/socks/v5/handler.go new file mode 100644 index 0000000..12db1d8 --- /dev/null +++ b/handler/socks/v5/handler.go @@ -0,0 +1,115 @@ +package v5 + +import ( + "context" + "errors" + "net" + "time" + + "github.com/go-gost/core/chain" + "github.com/go-gost/core/handler" + md "github.com/go-gost/core/metadata" + "github.com/go-gost/gosocks5" + "github.com/go-gost/x/internal/util/socks" + "github.com/go-gost/x/registry" +) + +var ( + ErrUnknownCmd = errors.New("socks5: unknown command") +) + +func init() { + registry.HandlerRegistry().Register("socks5", NewHandler) + registry.HandlerRegistry().Register("socks", NewHandler) +} + +type socks5Handler struct { + selector gosocks5.Selector + router *chain.Router + md metadata + options handler.Options +} + +func NewHandler(opts ...handler.Option) handler.Handler { + options := handler.Options{} + for _, opt := range opts { + opt(&options) + } + + return &socks5Handler{ + options: options, + } +} + +func (h *socks5Handler) Init(md md.Metadata) (err error) { + if err = h.parseMetadata(md); err != nil { + return + } + + h.router = h.options.Router + if h.router == nil { + h.router = (&chain.Router{}).WithLogger(h.options.Logger) + } + + h.selector = &serverSelector{ + Authenticator: h.options.Auther, + TLSConfig: h.options.TLSConfig, + logger: h.options.Logger, + noTLS: h.md.noTLS, + } + + return +} + +func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error { + defer conn.Close() + + start := time.Now() + + log := h.options.Logger.WithFields(map[string]any{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }) + + log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + defer func() { + log.WithFields(map[string]any{ + "duration": time.Since(start), + }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) + }() + + if h.md.readTimeout > 0 { + conn.SetReadDeadline(time.Now().Add(h.md.readTimeout)) + } + + conn = gosocks5.ServerConn(conn, h.selector) + req, err := gosocks5.ReadRequest(conn) + if err != nil { + log.Error(err) + return err + } + log.Debug(req) + conn.SetReadDeadline(time.Time{}) + + address := req.Addr.String() + + switch req.Cmd { + case gosocks5.CmdConnect: + return h.handleConnect(ctx, conn, "tcp", address, log) + case gosocks5.CmdBind: + return h.handleBind(ctx, conn, "tcp", address, log) + case socks.CmdMuxBind: + return h.handleMuxBind(ctx, conn, "tcp", address, log) + case gosocks5.CmdUdp: + return h.handleUDP(ctx, conn, log) + case socks.CmdUDPTun: + return h.handleUDPTun(ctx, conn, "udp", address, log) + default: + err = ErrUnknownCmd + log.Error(err) + resp := gosocks5.NewReply(gosocks5.CmdUnsupported, nil) + resp.Write(conn) + log.Debug(resp) + return err + } +} diff --git a/handler/socks/v5/mbind.go b/handler/socks/v5/mbind.go new file mode 100644 index 0000000..5a99259 --- /dev/null +++ b/handler/socks/v5/mbind.go @@ -0,0 +1,133 @@ +package v5 + +import ( + "context" + "fmt" + "net" + "time" + + netpkg "github.com/go-gost/core/common/net" + "github.com/go-gost/core/logger" + "github.com/go-gost/gosocks5" + "github.com/go-gost/x/internal/util/mux" +) + +func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error { + log = log.WithFields(map[string]any{ + "dst": fmt.Sprintf("%s/%s", address, network), + "cmd": "mbind", + }) + + log.Infof("%s >> %s", conn.RemoteAddr(), address) + + if !h.md.enableBind { + reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) + log.Debug(reply) + log.Error("socks5: BIND is disabled") + return reply.Write(conn) + } + + return h.muxBindLocal(ctx, conn, network, address, log) +} + +func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error { + ln, err := net.Listen(network, address) // strict mode: if the port already in use, it will return error + if err != nil { + log.Error(err) + reply := gosocks5.NewReply(gosocks5.Failure, nil) + if err := reply.Write(conn); err != nil { + log.Error(err) + } + log.Debug(reply) + return err + } + + socksAddr := gosocks5.Addr{} + err = socksAddr.ParseFrom(ln.Addr().String()) + if err != nil { + log.Warn(err) + } + + // Issue: may not reachable when host has multi-interface + socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) + socksAddr.Type = 0 + reply := gosocks5.NewReply(gosocks5.Succeeded, &socksAddr) + if err := reply.Write(conn); err != nil { + log.Error(err) + ln.Close() + return err + } + log.Debug(reply) + + log = log.WithFields(map[string]any{ + "bind": fmt.Sprintf("%s/%s", ln.Addr(), ln.Addr().Network()), + }) + + log.Debugf("bind on %s OK", ln.Addr()) + + return h.serveMuxBind(ctx, conn, ln, log) +} + +func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net.Listener, log logger.Logger) error { + // Upgrade connection to multiplex stream. + session, err := mux.ClientSession(conn) + if err != nil { + log.Error(err) + return err + } + defer session.Close() + + go func() { + defer ln.Close() + for { + conn, err := session.Accept() + if err != nil { + log.Error(err) + return + } + conn.Close() // we do not handle incoming connections. + } + }() + + for { + rc, err := ln.Accept() + if err != nil { + log.Error(err) + return err + } + log.Debugf("peer %s accepted", rc.RemoteAddr()) + + go func(c net.Conn) { + defer c.Close() + + log = log.WithFields(map[string]any{ + "local": rc.LocalAddr().String(), + "remote": rc.RemoteAddr().String(), + }) + sc, err := session.GetConn() + if err != nil { + log.Error(err) + return + } + defer sc.Close() + + // incompatible with GOST v2.x + if !h.md.compatibilityMode { + addr := gosocks5.Addr{} + addr.ParseFrom(c.RemoteAddr().String()) + reply := gosocks5.NewReply(gosocks5.Succeeded, &addr) + if err := reply.Write(sc); err != nil { + log.Error(err) + return + } + log.Debug(reply) + } + + t := time.Now() + log.Infof("%s <-> %s", c.LocalAddr(), c.RemoteAddr()) + netpkg.Transport(sc, c) + log.WithFields(map[string]any{"duration": time.Since(t)}). + Infof("%s >-< %s", c.LocalAddr(), c.RemoteAddr()) + }(rc) + } +} diff --git a/handler/socks/v5/metadata.go b/handler/socks/v5/metadata.go new file mode 100644 index 0000000..f5766a3 --- /dev/null +++ b/handler/socks/v5/metadata.go @@ -0,0 +1,44 @@ +package v5 + +import ( + "math" + "time" + + mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" +) + +type metadata struct { + readTimeout time.Duration + noTLS bool + enableBind bool + enableUDP bool + udpBufferSize int + compatibilityMode bool +} + +func (h *socks5Handler) parseMetadata(md mdata.Metadata) (err error) { + const ( + readTimeout = "readTimeout" + noTLS = "notls" + enableBind = "bind" + enableUDP = "udp" + udpBufferSize = "udpBufferSize" + compatibilityMode = "comp" + ) + + h.md.readTimeout = mdx.GetDuration(md, readTimeout) + h.md.noTLS = mdx.GetBool(md, noTLS) + h.md.enableBind = mdx.GetBool(md, enableBind) + h.md.enableUDP = mdx.GetBool(md, enableUDP) + + if bs := mdx.GetInt(md, udpBufferSize); bs > 0 { + h.md.udpBufferSize = int(math.Min(math.Max(float64(bs), 512), 64*1024)) + } else { + h.md.udpBufferSize = 1500 + } + + h.md.compatibilityMode = mdx.GetBool(md, compatibilityMode) + + return nil +} diff --git a/handler/socks/v5/selector.go b/handler/socks/v5/selector.go new file mode 100644 index 0000000..66d2826 --- /dev/null +++ b/handler/socks/v5/selector.go @@ -0,0 +1,90 @@ +package v5 + +import ( + "crypto/tls" + "net" + + "github.com/go-gost/core/auth" + "github.com/go-gost/core/logger" + "github.com/go-gost/gosocks5" + "github.com/go-gost/x/internal/util/socks" +) + +type serverSelector struct { + methods []uint8 + Authenticator auth.Authenticator + TLSConfig *tls.Config + logger logger.Logger + noTLS bool +} + +func (selector *serverSelector) Methods() []uint8 { + return selector.methods +} + +func (s *serverSelector) Select(methods ...uint8) (method uint8) { + s.logger.Debugf("%d %d %v", gosocks5.Ver5, len(methods), methods) + method = gosocks5.MethodNoAuth + for _, m := range methods { + if m == socks.MethodTLS && !s.noTLS { + method = m + break + } + } + + // when Authenticator is set, auth is mandatory + if s.Authenticator != nil { + if method == gosocks5.MethodNoAuth { + method = gosocks5.MethodUserPass + } + if method == socks.MethodTLS && !s.noTLS { + method = socks.MethodTLSAuth + } + } + + return +} + +func (s *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, error) { + s.logger.Debugf("%d %d", gosocks5.Ver5, method) + switch method { + case socks.MethodTLS: + conn = tls.Server(conn, s.TLSConfig) + + case gosocks5.MethodUserPass, socks.MethodTLSAuth: + if method == socks.MethodTLSAuth { + conn = tls.Server(conn, s.TLSConfig) + } + + req, err := gosocks5.ReadUserPassRequest(conn) + if err != nil { + s.logger.Error(err) + return nil, err + } + s.logger.Debug(req) + + if s.Authenticator != nil && + !s.Authenticator.Authenticate(req.Username, req.Password) { + resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Failure) + if err := resp.Write(conn); err != nil { + s.logger.Error(err) + return nil, err + } + s.logger.Info(resp) + + return nil, gosocks5.ErrAuthFailure + } + + resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Succeeded) + if err := resp.Write(conn); err != nil { + s.logger.Error(err) + return nil, err + } + s.logger.Debug(resp) + + case gosocks5.MethodNoAcceptable: + return nil, gosocks5.ErrBadMethod + } + + return conn, nil +} diff --git a/handler/socks/v5/udp.go b/handler/socks/v5/udp.go new file mode 100644 index 0000000..d56b4e3 --- /dev/null +++ b/handler/socks/v5/udp.go @@ -0,0 +1,85 @@ +package v5 + +import ( + "context" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "time" + + "github.com/go-gost/core/common/net/relay" + "github.com/go-gost/core/logger" + "github.com/go-gost/gosocks5" + "github.com/go-gost/x/internal/util/socks" +) + +func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, log logger.Logger) error { + log = log.WithFields(map[string]any{ + "cmd": "udp", + }) + + if !h.md.enableUDP { + reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) + log.Debug(reply) + log.Error("socks5: UDP relay is disabled") + return reply.Write(conn) + } + + cc, err := net.ListenUDP("udp", nil) + if err != nil { + log.Error(err) + reply := gosocks5.NewReply(gosocks5.Failure, nil) + reply.Write(conn) + log.Debug(reply) + return err + } + defer cc.Close() + + saddr := gosocks5.Addr{} + saddr.ParseFrom(cc.LocalAddr().String()) + saddr.Type = 0 + saddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) // replace the IP to the out-going interface's + reply := gosocks5.NewReply(gosocks5.Succeeded, &saddr) + if err := reply.Write(conn); err != nil { + log.Error(err) + return err + } + log.Debug(reply) + + log = log.WithFields(map[string]any{ + "bind": fmt.Sprintf("%s/%s", cc.LocalAddr(), cc.LocalAddr().Network()), + }) + log.Debugf("bind on %s OK", cc.LocalAddr()) + + // obtain a udp connection + c, err := h.router.Dial(ctx, "udp", "") // UDP association + if err != nil { + log.Error(err) + return err + } + defer c.Close() + + pc, ok := c.(net.PacketConn) + if !ok { + err := errors.New("socks5: wrong connection type") + log.Error(err) + return err + } + + r := relay.NewUDPRelay(socks.UDPConn(cc, h.md.udpBufferSize), pc). + WithBypass(h.options.Bypass). + WithLogger(log) + r.SetBufferSize(h.md.udpBufferSize) + + go r.Run() + + t := time.Now() + log.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr()) + io.Copy(ioutil.Discard, conn) + log.WithFields(map[string]any{"duration": time.Since(t)}). + Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr()) + + return nil +} diff --git a/handler/socks/v5/udp_tun.go b/handler/socks/v5/udp_tun.go new file mode 100644 index 0000000..deb4028 --- /dev/null +++ b/handler/socks/v5/udp_tun.go @@ -0,0 +1,72 @@ +package v5 + +import ( + "context" + "net" + "time" + + "github.com/go-gost/core/common/net/relay" + "github.com/go-gost/core/logger" + "github.com/go-gost/gosocks5" + "github.com/go-gost/x/internal/util/socks" +) + +func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error { + log = log.WithFields(map[string]any{ + "cmd": "udp-tun", + }) + + bindAddr, _ := net.ResolveUDPAddr(network, address) + if bindAddr == nil { + bindAddr = &net.UDPAddr{} + } + + if bindAddr.Port == 0 { + // relay mode + if !h.md.enableUDP { + reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) + log.Debug(reply) + log.Error("socks5: UDP relay is disabled") + return reply.Write(conn) + } + } else { + // BIND mode + if !h.md.enableBind { + reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) + log.Debug(reply) + log.Error("socks5: BIND is disabled") + return reply.Write(conn) + } + } + + pc, err := net.ListenUDP(network, bindAddr) + if err != nil { + log.Error(err) + return err + } + defer pc.Close() + + saddr := gosocks5.Addr{} + saddr.ParseFrom(pc.LocalAddr().String()) + reply := gosocks5.NewReply(gosocks5.Succeeded, &saddr) + if err := reply.Write(conn); err != nil { + log.Error(err) + return err + } + log.Debug(reply) + log.Debugf("bind on %s OK", pc.LocalAddr()) + + r := relay.NewUDPRelay(socks.UDPTunServerConn(conn), pc). + WithBypass(h.options.Bypass). + WithLogger(log) + r.SetBufferSize(h.md.udpBufferSize) + + t := time.Now() + log.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr()) + r.Run() + log.WithFields(map[string]any{ + "duration": time.Since(t), + }).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) + + return nil +} diff --git a/handler/ss/handler.go b/handler/ss/handler.go index 0c72477..df7bb52 100644 --- a/handler/ss/handler.go +++ b/handler/ss/handler.go @@ -11,9 +11,9 @@ import ( netpkg "github.com/go-gost/core/common/net" "github.com/go-gost/core/handler" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" "github.com/go-gost/gosocks5" "github.com/go-gost/x/internal/util/ss" + "github.com/go-gost/x/registry" "github.com/shadowsocks/go-shadowsocks2/core" ) diff --git a/handler/ss/metadata.go b/handler/ss/metadata.go index 035e85e..b31c455 100644 --- a/handler/ss/metadata.go +++ b/handler/ss/metadata.go @@ -4,6 +4,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) type metadata struct { @@ -17,8 +18,8 @@ func (h *ssHandler) parseMetadata(md mdata.Metadata) (err error) { readTimeout = "readTimeout" ) - h.md.key = mdata.GetString(md, key) - h.md.readTimeout = mdata.GetDuration(md, readTimeout) + h.md.key = mdx.GetString(md, key) + h.md.readTimeout = mdx.GetDuration(md, readTimeout) return } diff --git a/handler/ss/udp/handler.go b/handler/ss/udp/handler.go index 1a755ee..a9f2da2 100644 --- a/handler/ss/udp/handler.go +++ b/handler/ss/udp/handler.go @@ -11,9 +11,9 @@ import ( "github.com/go-gost/core/handler" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" "github.com/go-gost/x/internal/util/relay" "github.com/go-gost/x/internal/util/ss" + "github.com/go-gost/x/registry" "github.com/shadowsocks/go-shadowsocks2/core" ) diff --git a/handler/ss/udp/metadata.go b/handler/ss/udp/metadata.go index 258264a..a687d31 100644 --- a/handler/ss/udp/metadata.go +++ b/handler/ss/udp/metadata.go @@ -5,6 +5,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) type metadata struct { @@ -20,10 +21,10 @@ func (h *ssuHandler) parseMetadata(md mdata.Metadata) (err error) { bufferSize = "bufferSize" ) - h.md.key = mdata.GetString(md, key) - h.md.readTimeout = mdata.GetDuration(md, readTimeout) + h.md.key = mdx.GetString(md, key) + h.md.readTimeout = mdx.GetDuration(md, readTimeout) - if bs := mdata.GetInt(md, bufferSize); bs > 0 { + if bs := mdx.GetInt(md, bufferSize); bs > 0 { h.md.bufferSize = int(math.Min(math.Max(float64(bs), 512), 64*1024)) } else { h.md.bufferSize = 1500 diff --git a/handler/sshd/handler.go b/handler/sshd/handler.go index 6dfa0c0..c55ca2f 100644 --- a/handler/sshd/handler.go +++ b/handler/sshd/handler.go @@ -14,8 +14,8 @@ import ( "github.com/go-gost/core/handler" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" sshd_util "github.com/go-gost/x/internal/util/sshd" + "github.com/go-gost/x/registry" "golang.org/x/crypto/ssh" ) diff --git a/handler/tap/handler.go b/handler/tap/handler.go index a54b8af..ac8cd26 100644 --- a/handler/tap/handler.go +++ b/handler/tap/handler.go @@ -15,9 +15,9 @@ import ( "github.com/go-gost/core/handler" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" "github.com/go-gost/x/internal/util/ss" tap_util "github.com/go-gost/x/internal/util/tap" + "github.com/go-gost/x/registry" "github.com/shadowsocks/go-shadowsocks2/core" "github.com/shadowsocks/go-shadowsocks2/shadowaead" "github.com/songgao/water/waterutil" diff --git a/handler/tap/metadata.go b/handler/tap/metadata.go index 9410964..edf98f5 100644 --- a/handler/tap/metadata.go +++ b/handler/tap/metadata.go @@ -2,6 +2,7 @@ package tap import ( mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) type metadata struct { @@ -15,8 +16,8 @@ func (h *tapHandler) parseMetadata(md mdata.Metadata) (err error) { bufferSize = "bufferSize" ) - h.md.key = mdata.GetString(md, key) - h.md.bufferSize = mdata.GetInt(md, bufferSize) + h.md.key = mdx.GetString(md, key) + h.md.bufferSize = mdx.GetInt(md, bufferSize) if h.md.bufferSize <= 0 { h.md.bufferSize = 1500 } diff --git a/handler/tun/handler.go b/handler/tun/handler.go index 0266106..e605af1 100644 --- a/handler/tun/handler.go +++ b/handler/tun/handler.go @@ -15,9 +15,9 @@ import ( "github.com/go-gost/core/handler" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/registry" "github.com/go-gost/x/internal/util/ss" tun_util "github.com/go-gost/x/internal/util/tun" + "github.com/go-gost/x/registry" "github.com/shadowsocks/go-shadowsocks2/core" "github.com/shadowsocks/go-shadowsocks2/shadowaead" "github.com/songgao/water/waterutil" diff --git a/handler/tun/metadata.go b/handler/tun/metadata.go index 8afb4dd..386603c 100644 --- a/handler/tun/metadata.go +++ b/handler/tun/metadata.go @@ -2,6 +2,7 @@ package tun import ( mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) type metadata struct { @@ -15,8 +16,8 @@ func (h *tunHandler) parseMetadata(md mdata.Metadata) (err error) { bufferSize = "bufferSize" ) - h.md.key = mdata.GetString(md, key) - h.md.bufferSize = mdata.GetInt(md, bufferSize) + h.md.key = mdx.GetString(md, key) + h.md.bufferSize = mdx.GetInt(md, bufferSize) if h.md.bufferSize <= 0 { h.md.bufferSize = 1500 } diff --git a/hosts/hosts.go b/hosts/hosts.go new file mode 100644 index 0000000..a314528 --- /dev/null +++ b/hosts/hosts.go @@ -0,0 +1,127 @@ +package hosts + +import ( + "net" + "strings" + "sync" + + "github.com/go-gost/core/logger" +) + +type hostMapping struct { + IPs []net.IP + Hostname string +} + +// Hosts is a static table lookup for hostnames. +// For each host a single line should be present with the following information: +// IP_address canonical_hostname [aliases...] +// Fields of the entry are separated by any number of blanks and/or tab characters. +// Text from a "#" character until the end of the line is a comment, and is ignored. +type Hosts struct { + mappings sync.Map + Logger logger.Logger +} + +func NewHosts() *Hosts { + return &Hosts{} +} + +// Map maps ip to hostname or aliases. +func (h *Hosts) Map(ip net.IP, hostname string, aliases ...string) { + if hostname == "" { + return + } + + v, _ := h.mappings.Load(hostname) + m, _ := v.(*hostMapping) + if m == nil { + m = &hostMapping{ + IPs: []net.IP{ip}, + Hostname: hostname, + } + } else { + m.IPs = append(m.IPs, ip) + } + h.mappings.Store(hostname, m) + + for _, alias := range aliases { + // indirect mapping from alias to hostname + if alias != "" { + h.mappings.Store(alias, &hostMapping{ + Hostname: hostname, + }) + } + } +} + +// Lookup searches the IP address corresponds to the given network and host from the host table. +// The network should be 'ip', 'ip4' or 'ip6', default network is 'ip'. +// the host should be a hostname (example.org) or a hostname with dot prefix (.example.org). +func (h *Hosts) Lookup(network, host string) (ips []net.IP, ok bool) { + m := h.lookup(host) + if m == nil { + m = h.lookup("." + host) + } + if m == nil { + s := host + for { + if index := strings.IndexByte(s, '.'); index > 0 { + m = h.lookup(s[index:]) + s = s[index+1:] + if m == nil { + continue + } + } + break + } + } + + if m == nil { + return + } + + // hostname alias + if !strings.HasPrefix(m.Hostname, ".") && host != m.Hostname { + m = h.lookup(m.Hostname) + if m == nil { + return + } + } + + switch network { + case "ip4": + for _, ip := range m.IPs { + if ip = ip.To4(); ip != nil { + ips = append(ips, ip) + } + } + case "ip6": + for _, ip := range m.IPs { + if ip.To4() == nil { + ips = append(ips, ip) + } + } + default: + ips = m.IPs + } + + if len(ips) > 0 { + h.Logger.Debugf("host mapper: %s -> %s", host, ips) + } + + return +} + +func (h *Hosts) lookup(host string) *hostMapping { + if h == nil || host == "" { + return nil + } + + v, ok := h.mappings.Load(host) + if !ok { + return nil + } + m, _ := v.(*hostMapping) + return m +} diff --git a/internal/util/matcher/matcher.go b/internal/util/matcher/matcher.go new file mode 100644 index 0000000..81442d3 --- /dev/null +++ b/internal/util/matcher/matcher.go @@ -0,0 +1,99 @@ +package matcher + +import ( + "net" + "strings" + + "github.com/gobwas/glob" +) + +// Matcher is a generic pattern matcher, +// it gives the match result of the given pattern for specific v. +type Matcher interface { + Match(v string) bool +} + +// NewMatcher creates a Matcher for the given pattern. +// The acutal Matcher depends on the pattern: +// IP Matcher if pattern is a valid IP address. +// CIDR Matcher if pattern is a valid CIDR address. +// Domain Matcher if both of the above are not. +func NewMatcher(pattern string) Matcher { + if pattern == "" { + return nil + } + if ip := net.ParseIP(pattern); ip != nil { + return IPMatcher(ip) + } + if _, inet, err := net.ParseCIDR(pattern); err == nil { + return CIDRMatcher(inet) + } + return DomainMatcher(pattern) +} + +type ipMatcher struct { + ip net.IP +} + +// IPMatcher creates a Matcher for a specific IP address. +func IPMatcher(ip net.IP) Matcher { + return &ipMatcher{ + ip: ip, + } +} + +func (m *ipMatcher) Match(ip string) bool { + if m == nil { + return false + } + return m.ip.Equal(net.ParseIP(ip)) +} + +type cidrMatcher struct { + ipNet *net.IPNet +} + +// CIDRMatcher creates a Matcher for a specific CIDR notation IP address. +func CIDRMatcher(inet *net.IPNet) Matcher { + return &cidrMatcher{ + ipNet: inet, + } +} + +func (m *cidrMatcher) Match(ip string) bool { + if m == nil || m.ipNet == nil { + return false + } + return m.ipNet.Contains(net.ParseIP(ip)) +} + +type domainMatcher struct { + pattern string + glob glob.Glob +} + +// DomainMatcher creates a Matcher for a specific domain pattern, +// the pattern can be a plain domain such as 'example.com', +// a wildcard such as '*.exmaple.com' or a special wildcard '.example.com'. +func DomainMatcher(pattern string) Matcher { + p := pattern + if strings.HasPrefix(pattern, ".") { + p = pattern[1:] // trim the prefix '.' + pattern = "*" + p + } + return &domainMatcher{ + pattern: p, + glob: glob.MustCompile(pattern), + } +} + +func (m *domainMatcher) Match(domain string) bool { + if m == nil || m.glob == nil { + return false + } + + if domain == m.pattern { + return true + } + return m.glob.Match(domain) +} diff --git a/internal/util/resolver/cache.go b/internal/util/resolver/cache.go new file mode 100644 index 0000000..177fe2f --- /dev/null +++ b/internal/util/resolver/cache.go @@ -0,0 +1,88 @@ +package resolver + +import ( + "fmt" + "sync" + "time" + + "github.com/go-gost/core/logger" + "github.com/miekg/dns" +) + +type CacheKey string + +// NewCacheKey generates resolver cache key from question of dns query. +func NewCacheKey(q *dns.Question) CacheKey { + if q == nil { + return "" + } + key := fmt.Sprintf("%s%s.%s", q.Name, dns.Class(q.Qclass).String(), dns.Type(q.Qtype).String()) + return CacheKey(key) +} + +type cacheItem struct { + msg *dns.Msg + ts time.Time + ttl time.Duration +} + +type Cache struct { + m sync.Map + logger logger.Logger +} + +func NewCache() *Cache { + return &Cache{} +} + +func (c *Cache) WithLogger(logger logger.Logger) *Cache { + c.logger = logger + return c +} + +func (c *Cache) Load(key CacheKey) *dns.Msg { + v, ok := c.m.Load(key) + if !ok { + return nil + } + + item, ok := v.(*cacheItem) + if !ok { + return nil + } + + if time.Since(item.ts) > item.ttl { + c.m.Delete(key) + return nil + } + + c.logger.Debugf("hit resolver cache: %s", key) + + return item.msg.Copy() +} + +func (c *Cache) Store(key CacheKey, mr *dns.Msg, ttl time.Duration) { + if key == "" || mr == nil || ttl < 0 { + return + } + + if ttl == 0 { + for _, answer := range mr.Answer { + v := time.Duration(answer.Header().Ttl) * time.Second + if ttl == 0 || ttl > v { + ttl = v + } + } + } + if ttl == 0 { + ttl = 30 * time.Second + } + + c.m.Store(key, &cacheItem{ + msg: mr.Copy(), + ts: time.Now(), + ttl: ttl, + }) + + c.logger.Debugf("resolver cache store: %s, ttl: %v", key, ttl) +} diff --git a/internal/util/resolver/resolver.go b/internal/util/resolver/resolver.go new file mode 100644 index 0000000..74ec536 --- /dev/null +++ b/internal/util/resolver/resolver.go @@ -0,0 +1,30 @@ +package resolver + +import ( + "net" + + "github.com/miekg/dns" +) + +func AddSubnetOpt(m *dns.Msg, ip net.IP) { + if m == nil || ip == nil { + return + } + + opt := new(dns.OPT) + opt.Hdr.Name = "." + opt.Hdr.Rrtype = dns.TypeOPT + e := new(dns.EDNS0_SUBNET) + e.Code = dns.EDNS0SUBNET + if ip := ip.To4(); ip != nil { + e.Family = 1 + e.SourceNetmask = 24 + e.Address = ip + } else { + e.Family = 2 + e.SourceNetmask = 128 + e.Address = ip.To16() + } + opt.Option = append(opt.Option, e) + m.Extra = append(m.Extra, opt) +} diff --git a/internal/util/socks/conn.go b/internal/util/socks/conn.go new file mode 100644 index 0000000..69316e8 --- /dev/null +++ b/internal/util/socks/conn.go @@ -0,0 +1,172 @@ +package socks + +import ( + "bytes" + "net" + + "github.com/go-gost/core/common/bufpool" + "github.com/go-gost/gosocks5" +) + +type udpTunConn struct { + net.Conn + taddr net.Addr +} + +func UDPTunClientConn(c net.Conn, targetAddr net.Addr) net.Conn { + return &udpTunConn{ + Conn: c, + taddr: targetAddr, + } +} + +func UDPTunClientPacketConn(c net.Conn) net.PacketConn { + return &udpTunConn{ + Conn: c, + } +} + +func UDPTunServerConn(c net.Conn) net.PacketConn { + return &udpTunConn{ + Conn: c, + } +} + +func (c *udpTunConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + socksAddr := gosocks5.Addr{} + header := gosocks5.UDPHeader{ + Addr: &socksAddr, + } + dgram := gosocks5.UDPDatagram{ + Header: &header, + Data: b, + } + _, err = dgram.ReadFrom(c.Conn) + if err != nil { + return + } + + n = len(dgram.Data) + if n > len(b) { + n = copy(b, dgram.Data) + } + addr, err = net.ResolveUDPAddr("udp", socksAddr.String()) + + return +} + +func (c *udpTunConn) Read(b []byte) (n int, err error) { + n, _, err = c.ReadFrom(b) + return +} + +func (c *udpTunConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + socksAddr := gosocks5.Addr{} + if err = socksAddr.ParseFrom(addr.String()); err != nil { + return + } + + header := gosocks5.UDPHeader{ + Addr: &socksAddr, + } + dgram := gosocks5.UDPDatagram{ + Header: &header, + Data: b, + } + dgram.Header.Rsv = uint16(len(dgram.Data)) + dgram.Header.Frag = 0xff // UDP tun relay flag, used by shadowsocks + _, err = dgram.WriteTo(c.Conn) + n = len(b) + + return +} + +func (c *udpTunConn) Write(b []byte) (n int, err error) { + return c.WriteTo(b, c.taddr) +} + +var ( + DefaultBufferSize = 4096 +) + +type udpConn struct { + net.PacketConn + raddr net.Addr + taddr net.Addr + bufferSize int +} + +func UDPConn(c net.PacketConn, bufferSize int) net.PacketConn { + return &udpConn{ + PacketConn: c, + bufferSize: bufferSize, + } +} + +// ReadFrom reads an UDP datagram. +// NOTE: for server side, +// the returned addr is the target address the client want to relay to. +func (c *udpConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + rbuf := bufpool.Get(c.bufferSize) + defer bufpool.Put(rbuf) + + n, c.raddr, err = c.PacketConn.ReadFrom(*rbuf) + if err != nil { + return + } + + socksAddr := gosocks5.Addr{} + header := gosocks5.UDPHeader{ + Addr: &socksAddr, + } + hlen, err := header.ReadFrom(bytes.NewReader((*rbuf)[:n])) + if err != nil { + return + } + n = copy(b, (*rbuf)[hlen:n]) + + addr, err = net.ResolveUDPAddr("udp", socksAddr.String()) + return +} + +func (c *udpConn) Read(b []byte) (n int, err error) { + n, _, err = c.ReadFrom(b) + return +} + +func (c *udpConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + wbuf := bufpool.Get(c.bufferSize) + defer bufpool.Put(wbuf) + + socksAddr := gosocks5.Addr{} + if err = socksAddr.ParseFrom(addr.String()); err != nil { + return + } + + header := gosocks5.UDPHeader{ + Addr: &socksAddr, + } + dgram := gosocks5.UDPDatagram{ + Header: &header, + Data: b, + } + + buf := bytes.NewBuffer((*wbuf)[:0]) + _, err = dgram.WriteTo(buf) + if err != nil { + return + } + + _, err = c.PacketConn.WriteTo(buf.Bytes(), c.raddr) + n = len(b) + + return +} + +func (c *udpConn) Write(b []byte) (n int, err error) { + return c.WriteTo(b, c.taddr) +} + +func (c *udpConn) RemoteAddr() net.Addr { + return c.raddr +} diff --git a/internal/util/socks/socks.go b/internal/util/socks/socks.go new file mode 100644 index 0000000..91d241d --- /dev/null +++ b/internal/util/socks/socks.go @@ -0,0 +1,18 @@ +package socks + +const ( + // MethodTLS is an extended SOCKS5 method with tls encryption support. + MethodTLS uint8 = 0x80 + // MethodTLSAuth is an extended SOCKS5 method with tls encryption and authentication support. + MethodTLSAuth uint8 = 0x82 + // MethodMux is an extended SOCKS5 method for stream multiplexing. + MethodMux = 0x88 +) + +const ( + // CmdMuxBind is an extended SOCKS5 request CMD for + // multiplexing transport with the binding server. + CmdMuxBind uint8 = 0xF2 + // CmdUDPTun is an extended SOCKS5 request CMD for UDP over TCP. + CmdUDPTun uint8 = 0xF3 +) diff --git a/internal/util/tls/tls.go b/internal/util/tls/tls.go new file mode 100644 index 0000000..50a7bbb --- /dev/null +++ b/internal/util/tls/tls.go @@ -0,0 +1,173 @@ +package tls + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "io/ioutil" + "net" + "time" +) + +// LoadServerConfig loads the certificate from cert & key files and optional client CA file. +func LoadServerConfig(certFile, keyFile, caFile string) (*tls.Config, error) { + if certFile == "" && keyFile == "" { + return nil, nil + } + + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + + cfg := &tls.Config{Certificates: []tls.Certificate{cert}} + + pool, err := loadCA(caFile) + if err != nil { + return nil, err + } + if pool != nil { + cfg.ClientCAs = pool + cfg.ClientAuth = tls.RequireAndVerifyClientCert + } + + return cfg, nil +} + +// LoadClientConfig loads the certificate from cert & key files and optional CA file. +func LoadClientConfig(certFile, keyFile, caFile string, verify bool, serverName string) (*tls.Config, error) { + var cfg *tls.Config + + if certFile == "" && keyFile == "" { + cfg = &tls.Config{} + } else { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + + cfg = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + } + + rootCAs, err := loadCA(caFile) + if err != nil { + return nil, err + } + + cfg.RootCAs = rootCAs + cfg.ServerName = serverName + cfg.InsecureSkipVerify = !verify + + // If the root ca is given, but skip verify, we verify the certificate manually. + if cfg.RootCAs != nil && !verify { + cfg.VerifyConnection = func(state tls.ConnectionState) error { + opts := x509.VerifyOptions{ + Roots: cfg.RootCAs, + CurrentTime: time.Now(), + DNSName: "", + Intermediates: x509.NewCertPool(), + } + + certs := state.PeerCertificates + for i, cert := range certs { + if i == 0 { + continue + } + opts.Intermediates.AddCert(cert) + } + + _, err := certs[0].Verify(opts) + return err + } + } + + return cfg, nil +} + +func loadCA(caFile string) (cp *x509.CertPool, err error) { + if caFile == "" { + return + } + cp = x509.NewCertPool() + data, err := ioutil.ReadFile(caFile) + if err != nil { + return nil, err + } + if !cp.AppendCertsFromPEM(data) { + return nil, errors.New("AppendCertsFromPEM failed") + } + return +} + +// Wrap a net.Conn into a client tls connection, performing any +// additional verification as needed. +// +// As of go 1.3, crypto/tls only supports either doing no certificate +// verification, or doing full verification including of the peer's +// DNS name. For consul, we want to validate that the certificate is +// signed by a known CA, but because consul doesn't use DNS names for +// node names, we don't verify the certificate DNS names. Since go 1.3 +// no longer supports this mode of operation, we have to do it +// manually. +// +// This code is taken from consul: +// https://github.com/hashicorp/consul/blob/master/tlsutil/config.go +func WrapTLSClient(conn net.Conn, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) { + var err error + var tlsConn *tls.Conn + + if timeout > 0 { + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + } + + tlsConn = tls.Client(conn, tlsConfig) + + // Otherwise perform handshake, but don't verify the domain + // + // The following is lightly-modified from the doFullHandshake + // method in https://golang.org/src/crypto/tls/handshake_client.go + if err = tlsConn.Handshake(); err != nil { + tlsConn.Close() + return nil, err + } + + // We can do this in `tls.Config.VerifyConnection`, which effective for + // other TLS protocols such as WebSocket. See `route.go:parseChainNode` + /* + // If crypto/tls is doing verification, there's no need to do our own. + if tlsConfig.InsecureSkipVerify == false { + return tlsConn, nil + } + + // Similarly if we use host's CA, we can do full handshake + if tlsConfig.RootCAs == nil { + return tlsConn, nil + } + + opts := x509.VerifyOptions{ + Roots: tlsConfig.RootCAs, + CurrentTime: time.Now(), + DNSName: "", + Intermediates: x509.NewCertPool(), + } + + certs := tlsConn.ConnectionState().PeerCertificates + for i, cert := range certs { + if i == 0 { + continue + } + opts.Intermediates.AddCert(cert) + } + + _, err = certs[0].Verify(opts) + if err != nil { + tlsConn.Close() + return nil, err + } + */ + + return tlsConn, err +} diff --git a/listener/dns/listener.go b/listener/dns/listener.go index 4894b3f..35a53b7 100644 --- a/listener/dns/listener.go +++ b/listener/dns/listener.go @@ -13,7 +13,7 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" metrics "github.com/go-gost/core/metrics/wrapper" - "github.com/go-gost/core/registry" + "github.com/go-gost/x/registry" "github.com/miekg/dns" ) diff --git a/listener/dns/metadata.go b/listener/dns/metadata.go index c61c064..68fd7d4 100644 --- a/listener/dns/metadata.go +++ b/listener/dns/metadata.go @@ -4,6 +4,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) const ( @@ -27,12 +28,12 @@ func (l *dnsListener) parseMetadata(md mdata.Metadata) (err error) { writeTimeout = "writeTimeout" ) - l.md.mode = mdata.GetString(md, mode) - l.md.readBufferSize = mdata.GetInt(md, readBufferSize) - l.md.readTimeout = mdata.GetDuration(md, readTimeout) - l.md.writeTimeout = mdata.GetDuration(md, writeTimeout) + l.md.mode = mdx.GetString(md, mode) + l.md.readBufferSize = mdx.GetInt(md, readBufferSize) + l.md.readTimeout = mdx.GetDuration(md, readTimeout) + l.md.writeTimeout = mdx.GetDuration(md, writeTimeout) - l.md.backlog = mdata.GetInt(md, backlog) + l.md.backlog = mdx.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } diff --git a/listener/ftcp/listener.go b/listener/ftcp/listener.go index eab7b4f..9037424 100644 --- a/listener/ftcp/listener.go +++ b/listener/ftcp/listener.go @@ -8,7 +8,7 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" metrics "github.com/go-gost/core/metrics/wrapper" - "github.com/go-gost/core/registry" + "github.com/go-gost/x/registry" "github.com/xtaci/tcpraw" ) diff --git a/listener/ftcp/metadata.go b/listener/ftcp/metadata.go index b247386..7cec80f 100644 --- a/listener/ftcp/metadata.go +++ b/listener/ftcp/metadata.go @@ -4,6 +4,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) const ( @@ -29,21 +30,21 @@ func (l *ftcpListener) parseMetadata(md mdata.Metadata) (err error) { backlog = "backlog" ) - l.md.ttl = mdata.GetDuration(md, ttl) + l.md.ttl = mdx.GetDuration(md, ttl) if l.md.ttl <= 0 { l.md.ttl = defaultTTL } - l.md.readBufferSize = mdata.GetInt(md, readBufferSize) + l.md.readBufferSize = mdx.GetInt(md, readBufferSize) if l.md.readBufferSize <= 0 { l.md.readBufferSize = defaultReadBufferSize } - l.md.readQueueSize = mdata.GetInt(md, readQueueSize) + l.md.readQueueSize = mdx.GetInt(md, readQueueSize) if l.md.readQueueSize <= 0 { l.md.readQueueSize = defaultReadQueueSize } - l.md.backlog = mdata.GetInt(md, backlog) + l.md.backlog = mdx.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } diff --git a/listener/grpc/listener.go b/listener/grpc/listener.go index 35fa371..4b65c5b 100644 --- a/listener/grpc/listener.go +++ b/listener/grpc/listener.go @@ -8,8 +8,8 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" metrics "github.com/go-gost/core/metrics/wrapper" - "github.com/go-gost/core/registry" pb "github.com/go-gost/x/internal/util/grpc/proto" + "github.com/go-gost/x/registry" "google.golang.org/grpc" "google.golang.org/grpc/credentials" ) diff --git a/listener/grpc/metadata.go b/listener/grpc/metadata.go index 4164c5e..6b9e796 100644 --- a/listener/grpc/metadata.go +++ b/listener/grpc/metadata.go @@ -2,6 +2,7 @@ package grpc import ( mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) const ( @@ -19,11 +20,11 @@ func (l *grpcListener) parseMetadata(md mdata.Metadata) (err error) { insecure = "grpcInsecure" ) - l.md.backlog = mdata.GetInt(md, backlog) + l.md.backlog = mdx.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } - l.md.insecure = mdata.GetBool(md, insecure) + l.md.insecure = mdx.GetBool(md, insecure) return } diff --git a/listener/http2/h2/listener.go b/listener/http2/h2/listener.go index b9133d7..f4dcb2d 100644 --- a/listener/http2/h2/listener.go +++ b/listener/http2/h2/listener.go @@ -12,7 +12,7 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" metrics "github.com/go-gost/core/metrics/wrapper" - "github.com/go-gost/core/registry" + "github.com/go-gost/x/registry" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" ) diff --git a/listener/http2/h2/metadata.go b/listener/http2/h2/metadata.go index 33f32df..994fdf1 100644 --- a/listener/http2/h2/metadata.go +++ b/listener/http2/h2/metadata.go @@ -2,6 +2,7 @@ package h2 import ( mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) const ( @@ -19,11 +20,11 @@ func (l *h2Listener) parseMetadata(md mdata.Metadata) (err error) { backlog = "backlog" ) - l.md.backlog = mdata.GetInt(md, backlog) + l.md.backlog = mdx.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } - l.md.path = mdata.GetString(md, path) + l.md.path = mdx.GetString(md, path) return } diff --git a/listener/http2/listener.go b/listener/http2/listener.go index 0568237..40a4a7a 100644 --- a/listener/http2/listener.go +++ b/listener/http2/listener.go @@ -10,7 +10,8 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" metrics "github.com/go-gost/core/metrics/wrapper" - "github.com/go-gost/core/registry" + mdx "github.com/go-gost/x/metadata" + "github.com/go-gost/x/registry" "golang.org/x/net/http2" ) @@ -111,10 +112,10 @@ func (l *http2Listener) handleFunc(w http.ResponseWriter, r *http.Request) { laddr: l.addr, raddr: raddr, closed: make(chan struct{}), - md: md.MapMetadata{ + md: mdx.NewMetadata(map[string]any{ "r": r, "w": w, - }, + }), } select { case l.cqueue <- conn: diff --git a/listener/http2/metadata.go b/listener/http2/metadata.go index 095d7b7..dfa7793 100644 --- a/listener/http2/metadata.go +++ b/listener/http2/metadata.go @@ -2,6 +2,7 @@ package http2 import ( mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) const ( @@ -17,7 +18,7 @@ func (l *http2Listener) parseMetadata(md mdata.Metadata) (err error) { backlog = "backlog" ) - l.md.backlog = mdata.GetInt(md, backlog) + l.md.backlog = mdx.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } diff --git a/listener/http3/listener.go b/listener/http3/listener.go index f410d46..d54abdf 100644 --- a/listener/http3/listener.go +++ b/listener/http3/listener.go @@ -7,8 +7,8 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" metrics "github.com/go-gost/core/metrics/wrapper" - "github.com/go-gost/core/registry" pht_util "github.com/go-gost/x/internal/util/pht" + "github.com/go-gost/x/registry" "github.com/lucas-clemente/quic-go" ) diff --git a/listener/http3/metadata.go b/listener/http3/metadata.go index ed04874..9f23c81 100644 --- a/listener/http3/metadata.go +++ b/listener/http3/metadata.go @@ -4,6 +4,7 @@ import ( "strings" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) const ( @@ -29,20 +30,20 @@ func (l *http3Listener) parseMetadata(md mdata.Metadata) (err error) { backlog = "backlog" ) - l.md.authorizePath = mdata.GetString(md, authorizePath) + l.md.authorizePath = mdx.GetString(md, authorizePath) if !strings.HasPrefix(l.md.authorizePath, "/") { l.md.authorizePath = defaultAuthorizePath } - l.md.pushPath = mdata.GetString(md, pushPath) + l.md.pushPath = mdx.GetString(md, pushPath) if !strings.HasPrefix(l.md.pushPath, "/") { l.md.pushPath = defaultPushPath } - l.md.pullPath = mdata.GetString(md, pullPath) + l.md.pullPath = mdx.GetString(md, pullPath) if !strings.HasPrefix(l.md.pullPath, "/") { l.md.pullPath = defaultPullPath } - l.md.backlog = mdata.GetInt(md, backlog) + l.md.backlog = mdx.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } diff --git a/listener/icmp/listener.go b/listener/icmp/listener.go index 9b7904d..a38aa6b 100644 --- a/listener/icmp/listener.go +++ b/listener/icmp/listener.go @@ -9,8 +9,8 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" metrics "github.com/go-gost/core/metrics/wrapper" - "github.com/go-gost/core/registry" icmp_pkg "github.com/go-gost/x/internal/util/icmp" + "github.com/go-gost/x/registry" "github.com/lucas-clemente/quic-go" "golang.org/x/net/icmp" ) diff --git a/listener/icmp/metadata.go b/listener/icmp/metadata.go index 7e53cf5..c731b9c 100644 --- a/listener/icmp/metadata.go +++ b/listener/icmp/metadata.go @@ -4,6 +4,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) const ( @@ -28,14 +29,14 @@ func (l *icmpListener) parseMetadata(md mdata.Metadata) (err error) { backlog = "backlog" ) - l.md.backlog = mdata.GetInt(md, backlog) + l.md.backlog = mdx.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } - l.md.keepAlive = mdata.GetBool(md, keepAlive) - l.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) - l.md.maxIdleTimeout = mdata.GetDuration(md, maxIdleTimeout) + l.md.keepAlive = mdx.GetBool(md, keepAlive) + l.md.handshakeTimeout = mdx.GetDuration(md, handshakeTimeout) + l.md.maxIdleTimeout = mdx.GetDuration(md, maxIdleTimeout) return } diff --git a/listener/kcp/listener.go b/listener/kcp/listener.go index e018d83..901b101 100644 --- a/listener/kcp/listener.go +++ b/listener/kcp/listener.go @@ -9,8 +9,8 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" metrics "github.com/go-gost/core/metrics/wrapper" - "github.com/go-gost/core/registry" kcp_util "github.com/go-gost/x/internal/util/kcp" + "github.com/go-gost/x/registry" "github.com/xtaci/kcp-go/v5" "github.com/xtaci/smux" "github.com/xtaci/tcpraw" diff --git a/listener/kcp/metadata.go b/listener/kcp/metadata.go index 454a696..cdc7e46 100644 --- a/listener/kcp/metadata.go +++ b/listener/kcp/metadata.go @@ -5,6 +5,7 @@ import ( mdata "github.com/go-gost/core/metadata" kcp_util "github.com/go-gost/x/internal/util/kcp" + mdx "github.com/go-gost/x/metadata" ) const ( @@ -22,7 +23,7 @@ func (l *kcpListener) parseMetadata(md mdata.Metadata) (err error) { config = "config" ) - if m := mdata.GetStringMap(md, config); len(m) > 0 { + if m := mdx.GetStringMap(md, config); len(m) > 0 { b, err := json.Marshal(m) if err != nil { return err @@ -38,7 +39,7 @@ func (l *kcpListener) parseMetadata(md mdata.Metadata) (err error) { l.md.config = kcp_util.DefaultConfig } - l.md.backlog = mdata.GetInt(md, backlog) + l.md.backlog = mdx.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } diff --git a/listener/mtls/listener.go b/listener/mtls/listener.go index ffc1bda..23ed687 100644 --- a/listener/mtls/listener.go +++ b/listener/mtls/listener.go @@ -9,7 +9,7 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" metrics "github.com/go-gost/core/metrics/wrapper" - "github.com/go-gost/core/registry" + "github.com/go-gost/x/registry" "github.com/xtaci/smux" ) diff --git a/listener/mtls/metadata.go b/listener/mtls/metadata.go index 72bd80d..f7d9366 100644 --- a/listener/mtls/metadata.go +++ b/listener/mtls/metadata.go @@ -4,6 +4,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) const ( @@ -33,17 +34,17 @@ func (l *mtlsListener) parseMetadata(md mdata.Metadata) (err error) { muxMaxStreamBuffer = "muxMaxStreamBuffer" ) - l.md.backlog = mdata.GetInt(md, backlog) + l.md.backlog = mdx.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } - l.md.muxKeepAliveDisabled = mdata.GetBool(md, muxKeepAliveDisabled) - l.md.muxKeepAliveInterval = mdata.GetDuration(md, muxKeepAliveInterval) - l.md.muxKeepAliveTimeout = mdata.GetDuration(md, muxKeepAliveTimeout) - l.md.muxMaxFrameSize = mdata.GetInt(md, muxMaxFrameSize) - l.md.muxMaxReceiveBuffer = mdata.GetInt(md, muxMaxReceiveBuffer) - l.md.muxMaxStreamBuffer = mdata.GetInt(md, muxMaxStreamBuffer) + l.md.muxKeepAliveDisabled = mdx.GetBool(md, muxKeepAliveDisabled) + l.md.muxKeepAliveInterval = mdx.GetDuration(md, muxKeepAliveInterval) + l.md.muxKeepAliveTimeout = mdx.GetDuration(md, muxKeepAliveTimeout) + l.md.muxMaxFrameSize = mdx.GetInt(md, muxMaxFrameSize) + l.md.muxMaxReceiveBuffer = mdx.GetInt(md, muxMaxReceiveBuffer) + l.md.muxMaxStreamBuffer = mdx.GetInt(md, muxMaxStreamBuffer) return } diff --git a/listener/mws/listener.go b/listener/mws/listener.go index 74f43bf..004c68a 100644 --- a/listener/mws/listener.go +++ b/listener/mws/listener.go @@ -11,8 +11,8 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" metrics "github.com/go-gost/core/metrics/wrapper" - "github.com/go-gost/core/registry" ws_util "github.com/go-gost/x/internal/util/ws" + "github.com/go-gost/x/registry" "github.com/gorilla/websocket" "github.com/xtaci/smux" ) diff --git a/listener/mws/metadata.go b/listener/mws/metadata.go index 084d7bd..871393b 100644 --- a/listener/mws/metadata.go +++ b/listener/mws/metadata.go @@ -5,6 +5,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) const ( @@ -51,30 +52,30 @@ func (l *mwsListener) parseMetadata(md mdata.Metadata) (err error) { muxMaxStreamBuffer = "muxMaxStreamBuffer" ) - l.md.path = mdata.GetString(md, path) + l.md.path = mdx.GetString(md, path) if l.md.path == "" { l.md.path = defaultPath } - l.md.backlog = mdata.GetInt(md, backlog) + l.md.backlog = mdx.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } - l.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) - l.md.readHeaderTimeout = mdata.GetDuration(md, readHeaderTimeout) - l.md.readBufferSize = mdata.GetInt(md, readBufferSize) - l.md.writeBufferSize = mdata.GetInt(md, writeBufferSize) - l.md.enableCompression = mdata.GetBool(md, enableCompression) + l.md.handshakeTimeout = mdx.GetDuration(md, handshakeTimeout) + l.md.readHeaderTimeout = mdx.GetDuration(md, readHeaderTimeout) + l.md.readBufferSize = mdx.GetInt(md, readBufferSize) + l.md.writeBufferSize = mdx.GetInt(md, writeBufferSize) + l.md.enableCompression = mdx.GetBool(md, enableCompression) - l.md.muxKeepAliveDisabled = mdata.GetBool(md, muxKeepAliveDisabled) - l.md.muxKeepAliveInterval = mdata.GetDuration(md, muxKeepAliveInterval) - l.md.muxKeepAliveTimeout = mdata.GetDuration(md, muxKeepAliveTimeout) - l.md.muxMaxFrameSize = mdata.GetInt(md, muxMaxFrameSize) - l.md.muxMaxReceiveBuffer = mdata.GetInt(md, muxMaxReceiveBuffer) - l.md.muxMaxStreamBuffer = mdata.GetInt(md, muxMaxStreamBuffer) + l.md.muxKeepAliveDisabled = mdx.GetBool(md, muxKeepAliveDisabled) + l.md.muxKeepAliveInterval = mdx.GetDuration(md, muxKeepAliveInterval) + l.md.muxKeepAliveTimeout = mdx.GetDuration(md, muxKeepAliveTimeout) + l.md.muxMaxFrameSize = mdx.GetInt(md, muxMaxFrameSize) + l.md.muxMaxReceiveBuffer = mdx.GetInt(md, muxMaxReceiveBuffer) + l.md.muxMaxStreamBuffer = mdx.GetInt(md, muxMaxStreamBuffer) - if mm := mdata.GetStringMapString(md, header); len(mm) > 0 { + if mm := mdx.GetStringMapString(md, header); len(mm) > 0 { hd := http.Header{} for k, v := range mm { hd.Add(k, v) diff --git a/listener/obfs/http/listener.go b/listener/obfs/http/listener.go index 0c5b2b1..544f3bd 100644 --- a/listener/obfs/http/listener.go +++ b/listener/obfs/http/listener.go @@ -8,7 +8,7 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" metrics "github.com/go-gost/core/metrics/wrapper" - "github.com/go-gost/core/registry" + "github.com/go-gost/x/registry" ) func init() { diff --git a/listener/obfs/http/metadata.go b/listener/obfs/http/metadata.go index 67900eb..8795d35 100644 --- a/listener/obfs/http/metadata.go +++ b/listener/obfs/http/metadata.go @@ -4,6 +4,7 @@ import ( "net/http" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) type metadata struct { @@ -15,7 +16,7 @@ func (l *obfsListener) parseMetadata(md mdata.Metadata) (err error) { header = "header" ) - if mm := mdata.GetStringMapString(md, header); len(mm) > 0 { + if mm := mdx.GetStringMapString(md, header); len(mm) > 0 { hd := http.Header{} for k, v := range mm { hd.Add(k, v) diff --git a/listener/obfs/tls/listener.go b/listener/obfs/tls/listener.go index d240768..f12488c 100644 --- a/listener/obfs/tls/listener.go +++ b/listener/obfs/tls/listener.go @@ -8,7 +8,7 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" metrics "github.com/go-gost/core/metrics/wrapper" - "github.com/go-gost/core/registry" + "github.com/go-gost/x/registry" ) func init() { diff --git a/listener/pht/listener.go b/listener/pht/listener.go index d7273d3..52ca012 100644 --- a/listener/pht/listener.go +++ b/listener/pht/listener.go @@ -9,8 +9,8 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" metrics "github.com/go-gost/core/metrics/wrapper" - "github.com/go-gost/core/registry" pht_util "github.com/go-gost/x/internal/util/pht" + "github.com/go-gost/x/registry" ) func init() { diff --git a/listener/pht/metadata.go b/listener/pht/metadata.go index 98c82dd..15906b0 100644 --- a/listener/pht/metadata.go +++ b/listener/pht/metadata.go @@ -4,6 +4,7 @@ import ( "strings" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) const ( @@ -29,20 +30,20 @@ func (l *phtListener) parseMetadata(md mdata.Metadata) (err error) { backlog = "backlog" ) - l.md.authorizePath = mdata.GetString(md, authorizePath) + l.md.authorizePath = mdx.GetString(md, authorizePath) if !strings.HasPrefix(l.md.authorizePath, "/") { l.md.authorizePath = defaultAuthorizePath } - l.md.pushPath = mdata.GetString(md, pushPath) + l.md.pushPath = mdx.GetString(md, pushPath) if !strings.HasPrefix(l.md.pushPath, "/") { l.md.pushPath = defaultPushPath } - l.md.pullPath = mdata.GetString(md, pullPath) + l.md.pullPath = mdx.GetString(md, pullPath) if !strings.HasPrefix(l.md.pullPath, "/") { l.md.pullPath = defaultPullPath } - l.md.backlog = mdata.GetInt(md, backlog) + l.md.backlog = mdx.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } diff --git a/listener/quic/listener.go b/listener/quic/listener.go index 811020a..621b45b 100644 --- a/listener/quic/listener.go +++ b/listener/quic/listener.go @@ -8,8 +8,8 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" metrics "github.com/go-gost/core/metrics/wrapper" - "github.com/go-gost/core/registry" quic_util "github.com/go-gost/x/internal/util/quic" + "github.com/go-gost/x/registry" "github.com/lucas-clemente/quic-go" ) diff --git a/listener/quic/metadata.go b/listener/quic/metadata.go index 8342e17..7306b38 100644 --- a/listener/quic/metadata.go +++ b/listener/quic/metadata.go @@ -4,6 +4,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) const ( @@ -29,18 +30,18 @@ func (l *quicListener) parseMetadata(md mdata.Metadata) (err error) { cipherKey = "cipherKey" ) - l.md.backlog = mdata.GetInt(md, backlog) + l.md.backlog = mdx.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } - if key := mdata.GetString(md, cipherKey); key != "" { + if key := mdx.GetString(md, cipherKey); key != "" { l.md.cipherKey = []byte(key) } - l.md.keepAlive = mdata.GetBool(md, keepAlive) - l.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) - l.md.maxIdleTimeout = mdata.GetDuration(md, maxIdleTimeout) + l.md.keepAlive = mdx.GetBool(md, keepAlive) + l.md.handshakeTimeout = mdx.GetDuration(md, handshakeTimeout) + l.md.maxIdleTimeout = mdx.GetDuration(md, maxIdleTimeout) return } diff --git a/listener/redirect/tcp/listener.go b/listener/redirect/tcp/listener.go index c59a196..1678a9e 100644 --- a/listener/redirect/tcp/listener.go +++ b/listener/redirect/tcp/listener.go @@ -8,7 +8,7 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" metrics "github.com/go-gost/core/metrics/wrapper" - "github.com/go-gost/core/registry" + "github.com/go-gost/x/registry" ) func init() { diff --git a/listener/redirect/tcp/metadata.go b/listener/redirect/tcp/metadata.go index 032768c..ddc5294 100644 --- a/listener/redirect/tcp/metadata.go +++ b/listener/redirect/tcp/metadata.go @@ -2,6 +2,7 @@ package tcp import ( mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) type metadata struct { @@ -12,6 +13,6 @@ func (l *redirectListener) parseMetadata(md mdata.Metadata) (err error) { const ( tproxy = "tproxy" ) - l.md.tproxy = mdata.GetBool(md, tproxy) + l.md.tproxy = mdx.GetBool(md, tproxy) return } diff --git a/listener/redirect/udp/listener.go b/listener/redirect/udp/listener.go index 3a65fd7..255921b 100644 --- a/listener/redirect/udp/listener.go +++ b/listener/redirect/udp/listener.go @@ -7,7 +7,7 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" metrics "github.com/go-gost/core/metrics/wrapper" - "github.com/go-gost/core/registry" + "github.com/go-gost/x/registry" ) func init() { diff --git a/listener/redirect/udp/metadata.go b/listener/redirect/udp/metadata.go index 4173cf5..8c68c11 100644 --- a/listener/redirect/udp/metadata.go +++ b/listener/redirect/udp/metadata.go @@ -4,6 +4,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) const ( @@ -22,12 +23,12 @@ func (l *redirectListener) parseMetadata(md mdata.Metadata) (err error) { readBufferSize = "readBufferSize" ) - l.md.ttl = mdata.GetDuration(md, ttl) + l.md.ttl = mdx.GetDuration(md, ttl) if l.md.ttl <= 0 { l.md.ttl = defaultTTL } - l.md.readBufferSize = mdata.GetInt(md, readBufferSize) + l.md.readBufferSize = mdx.GetInt(md, readBufferSize) if l.md.readBufferSize <= 0 { l.md.readBufferSize = defaultReadBufferSize } diff --git a/listener/rtcp/listener.go b/listener/rtcp/listener.go new file mode 100644 index 0000000..f79c181 --- /dev/null +++ b/listener/rtcp/listener.go @@ -0,0 +1,102 @@ +package rtcp + +import ( + "context" + "net" + + "github.com/go-gost/core/chain" + "github.com/go-gost/core/connector" + "github.com/go-gost/core/listener" + "github.com/go-gost/core/logger" + md "github.com/go-gost/core/metadata" + metrics "github.com/go-gost/core/metrics/wrapper" + "github.com/go-gost/x/registry" +) + +func init() { + registry.ListenerRegistry().Register("rtcp", NewListener) +} + +type rtcpListener struct { + laddr net.Addr + ln net.Listener + md metadata + router *chain.Router + logger logger.Logger + closed chan struct{} + options listener.Options +} + +func NewListener(opts ...listener.Option) listener.Listener { + options := listener.Options{} + for _, opt := range opts { + opt(&options) + } + return &rtcpListener{ + closed: make(chan struct{}), + logger: options.Logger, + options: options, + } +} + +func (l *rtcpListener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { + return + } + + laddr, err := net.ResolveTCPAddr("tcp", l.options.Addr) + if err != nil { + return + } + + l.laddr = laddr + l.router = (&chain.Router{}). + WithChain(l.options.Chain). + WithLogger(l.logger) + + return +} + +func (l *rtcpListener) Accept() (conn net.Conn, err error) { + select { + case <-l.closed: + return nil, net.ErrClosed + default: + } + + if l.ln == nil { + l.ln, err = l.router.Bind( + context.Background(), "tcp", l.laddr.String(), + connector.MuxBindOption(true), + ) + if err != nil { + return nil, listener.NewAcceptError(err) + } + l.ln = metrics.WrapListener(l.options.Service, l.ln) + } + conn, err = l.ln.Accept() + if err != nil { + l.ln.Close() + l.ln = nil + return nil, listener.NewAcceptError(err) + } + return +} + +func (l *rtcpListener) Addr() net.Addr { + return l.laddr +} + +func (l *rtcpListener) Close() error { + select { + case <-l.closed: + default: + close(l.closed) + if l.ln != nil { + l.ln.Close() + l.ln = nil + } + } + + return nil +} diff --git a/listener/rtcp/metadata.go b/listener/rtcp/metadata.go new file mode 100644 index 0000000..42d52d5 --- /dev/null +++ b/listener/rtcp/metadata.go @@ -0,0 +1,11 @@ +package rtcp + +import ( + mdata "github.com/go-gost/core/metadata" +) + +type metadata struct{} + +func (l *rtcpListener) parseMetadata(md mdata.Metadata) (err error) { + return +} diff --git a/listener/rudp/listener.go b/listener/rudp/listener.go new file mode 100644 index 0000000..477e0a0 --- /dev/null +++ b/listener/rudp/listener.go @@ -0,0 +1,109 @@ +package rudp + +import ( + "context" + "net" + + "github.com/go-gost/core/chain" + "github.com/go-gost/core/connector" + "github.com/go-gost/core/listener" + "github.com/go-gost/core/logger" + md "github.com/go-gost/core/metadata" + metrics "github.com/go-gost/core/metrics/wrapper" + "github.com/go-gost/x/registry" +) + +func init() { + registry.ListenerRegistry().Register("rudp", NewListener) +} + +type rudpListener struct { + laddr net.Addr + ln net.Listener + router *chain.Router + closed chan struct{} + logger logger.Logger + md metadata + options listener.Options +} + +func NewListener(opts ...listener.Option) listener.Listener { + options := listener.Options{} + for _, opt := range opts { + opt(&options) + } + return &rudpListener{ + closed: make(chan struct{}), + logger: options.Logger, + options: options, + } +} + +func (l *rudpListener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { + return + } + + laddr, err := net.ResolveUDPAddr("udp", l.options.Addr) + if err != nil { + return + } + + l.laddr = laddr + l.router = (&chain.Router{}). + WithChain(l.options.Chain). + WithLogger(l.logger) + + return +} + +func (l *rudpListener) Accept() (conn net.Conn, err error) { + select { + case <-l.closed: + return nil, net.ErrClosed + default: + } + + if l.ln == nil { + l.ln, err = l.router.Bind( + context.Background(), "udp", l.laddr.String(), + connector.BacklogBindOption(l.md.backlog), + connector.UDPConnTTLBindOption(l.md.ttl), + connector.UDPDataBufferSizeBindOption(l.md.readBufferSize), + connector.UDPDataQueueSizeBindOption(l.md.readQueueSize), + ) + if err != nil { + return nil, listener.NewAcceptError(err) + } + } + conn, err = l.ln.Accept() + if err != nil { + l.ln.Close() + l.ln = nil + return nil, listener.NewAcceptError(err) + } + + if pc, ok := conn.(net.PacketConn); ok { + conn = metrics.WrapUDPConn(l.options.Service, pc) + } + + return +} + +func (l *rudpListener) Addr() net.Addr { + return l.laddr +} + +func (l *rudpListener) Close() error { + select { + case <-l.closed: + default: + close(l.closed) + if l.ln != nil { + l.ln.Close() + l.ln = nil + } + } + + return nil +} diff --git a/listener/rudp/metadata.go b/listener/rudp/metadata.go new file mode 100644 index 0000000..0bf08ca --- /dev/null +++ b/listener/rudp/metadata.go @@ -0,0 +1,52 @@ +package rudp + +import ( + "time" + + mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" +) + +const ( + defaultTTL = 5 * time.Second + defaultReadBufferSize = 1500 + defaultReadQueueSize = 128 + defaultBacklog = 128 +) + +type metadata struct { + ttl time.Duration + readBufferSize int + readQueueSize int + backlog int +} + +func (l *rudpListener) parseMetadata(md mdata.Metadata) (err error) { + const ( + ttl = "ttl" + readBufferSize = "readBufferSize" + readQueueSize = "readQueueSize" + backlog = "backlog" + ) + + l.md.ttl = mdx.GetDuration(md, ttl) + if l.md.ttl <= 0 { + l.md.ttl = defaultTTL + } + l.md.readBufferSize = mdx.GetInt(md, readBufferSize) + if l.md.readBufferSize <= 0 { + l.md.readBufferSize = defaultReadBufferSize + } + + l.md.readQueueSize = mdx.GetInt(md, readQueueSize) + if l.md.readQueueSize <= 0 { + l.md.readQueueSize = defaultReadQueueSize + } + + l.md.backlog = mdx.GetInt(md, backlog) + if l.md.backlog <= 0 { + l.md.backlog = defaultBacklog + } + + return +} diff --git a/listener/ssh/listener.go b/listener/ssh/listener.go index 42a366c..1d5405a 100644 --- a/listener/ssh/listener.go +++ b/listener/ssh/listener.go @@ -10,8 +10,8 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" metrics "github.com/go-gost/core/metrics/wrapper" - "github.com/go-gost/core/registry" ssh_util "github.com/go-gost/x/internal/util/ssh" + "github.com/go-gost/x/registry" "golang.org/x/crypto/ssh" ) diff --git a/listener/ssh/metadata.go b/listener/ssh/metadata.go index 50a5b5a..ef92cc0 100644 --- a/listener/ssh/metadata.go +++ b/listener/ssh/metadata.go @@ -3,9 +3,9 @@ package ssh import ( "io/ioutil" - tls_util "github.com/go-gost/core/common/util/tls" mdata "github.com/go-gost/core/metadata" ssh_util "github.com/go-gost/x/internal/util/ssh" + mdx "github.com/go-gost/x/metadata" "golang.org/x/crypto/ssh" ) @@ -27,13 +27,13 @@ func (l *sshListener) parseMetadata(md mdata.Metadata) (err error) { backlog = "backlog" ) - if key := mdata.GetString(md, privateKeyFile); key != "" { + if key := mdx.GetString(md, privateKeyFile); key != "" { data, err := ioutil.ReadFile(key) if err != nil { return err } - pp := mdata.GetString(md, passphrase) + pp := mdx.GetString(md, passphrase) if pp == "" { l.md.signer, err = ssh.ParsePrivateKey(data) } else { @@ -44,14 +44,14 @@ func (l *sshListener) parseMetadata(md mdata.Metadata) (err error) { } } if l.md.signer == nil { - signer, err := ssh.NewSignerFromKey(tls_util.DefaultConfig.Clone().Certificates[0].PrivateKey) + signer, err := ssh.NewSignerFromKey(l.options.TLSConfig.Certificates[0].PrivateKey) if err != nil { return err } l.md.signer = signer } - if name := mdata.GetString(md, authorizedKeys); name != "" { + if name := mdx.GetString(md, authorizedKeys); name != "" { m, err := ssh_util.ParseAuthorizedKeysFile(name) if err != nil { return err @@ -59,7 +59,7 @@ func (l *sshListener) parseMetadata(md mdata.Metadata) (err error) { l.md.authorizedKeys = m } - l.md.backlog = mdata.GetInt(md, backlog) + l.md.backlog = mdx.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } diff --git a/listener/sshd/listener.go b/listener/sshd/listener.go index 5a344ee..64d9365 100644 --- a/listener/sshd/listener.go +++ b/listener/sshd/listener.go @@ -12,9 +12,9 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" metrics "github.com/go-gost/core/metrics/wrapper" - "github.com/go-gost/core/registry" ssh_util "github.com/go-gost/x/internal/util/ssh" sshd_util "github.com/go-gost/x/internal/util/sshd" + "github.com/go-gost/x/registry" "golang.org/x/crypto/ssh" ) diff --git a/listener/sshd/metadata.go b/listener/sshd/metadata.go index bb7ef38..d477607 100644 --- a/listener/sshd/metadata.go +++ b/listener/sshd/metadata.go @@ -3,9 +3,9 @@ package ssh import ( "io/ioutil" - tls_util "github.com/go-gost/core/common/util/tls" mdata "github.com/go-gost/core/metadata" ssh_util "github.com/go-gost/x/internal/util/ssh" + mdx "github.com/go-gost/x/metadata" "golang.org/x/crypto/ssh" ) @@ -27,13 +27,13 @@ func (l *sshdListener) parseMetadata(md mdata.Metadata) (err error) { backlog = "backlog" ) - if key := mdata.GetString(md, privateKeyFile); key != "" { + if key := mdx.GetString(md, privateKeyFile); key != "" { data, err := ioutil.ReadFile(key) if err != nil { return err } - pp := mdata.GetString(md, passphrase) + pp := mdx.GetString(md, passphrase) if pp == "" { l.md.signer, err = ssh.ParsePrivateKey(data) } else { @@ -44,14 +44,14 @@ func (l *sshdListener) parseMetadata(md mdata.Metadata) (err error) { } } if l.md.signer == nil { - signer, err := ssh.NewSignerFromKey(tls_util.DefaultConfig.Clone().Certificates[0].PrivateKey) + signer, err := ssh.NewSignerFromKey(l.options.TLSConfig.Certificates[0].PrivateKey) if err != nil { return err } l.md.signer = signer } - if name := mdata.GetString(md, authorizedKeys); name != "" { + if name := mdx.GetString(md, authorizedKeys); name != "" { m, err := ssh_util.ParseAuthorizedKeysFile(name) if err != nil { return err @@ -59,7 +59,7 @@ func (l *sshdListener) parseMetadata(md mdata.Metadata) (err error) { l.md.authorizedKeys = m } - l.md.backlog = mdata.GetInt(md, backlog) + l.md.backlog = mdx.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } diff --git a/listener/tap/listener.go b/listener/tap/listener.go index c60695f..2c2a613 100644 --- a/listener/tap/listener.go +++ b/listener/tap/listener.go @@ -7,7 +7,8 @@ import ( "github.com/go-gost/core/logger" mdata "github.com/go-gost/core/metadata" metrics "github.com/go-gost/core/metrics/wrapper" - "github.com/go-gost/core/registry" + mdx "github.com/go-gost/x/metadata" + "github.com/go-gost/x/registry" ) func init() { @@ -73,9 +74,9 @@ func (l *tapListener) Init(md mdata.Metadata) (err error) { raddr: &net.IPAddr{IP: ip}, } c = metrics.WrapConn(l.options.Service, c) - c = withMetadata(mdata.MapMetadata{ + c = withMetadata(mdx.NewMetadata(map[string]any{ "config": l.md.config, - }, c) + }), c) l.cqueue <- c diff --git a/listener/tap/metadata.go b/listener/tap/metadata.go index 6dd04c5..c26b4f6 100644 --- a/listener/tap/metadata.go +++ b/listener/tap/metadata.go @@ -3,6 +3,7 @@ package tap import ( mdata "github.com/go-gost/core/metadata" tap_util "github.com/go-gost/x/internal/util/tap" + mdx "github.com/go-gost/x/metadata" ) const ( @@ -23,16 +24,16 @@ func (l *tapListener) parseMetadata(md mdata.Metadata) (err error) { ) config := &tap_util.Config{ - Name: mdata.GetString(md, name), - Net: mdata.GetString(md, netKey), - MTU: mdata.GetInt(md, mtu), - Gateway: mdata.GetString(md, gateway), + Name: mdx.GetString(md, name), + Net: mdx.GetString(md, netKey), + MTU: mdx.GetInt(md, mtu), + Gateway: mdx.GetString(md, gateway), } if config.MTU <= 0 { config.MTU = DefaultMTU } - for _, s := range mdata.GetStrings(md, routes) { + for _, s := range mdx.GetStrings(md, routes) { if s != "" { config.Routes = append(config.Routes, s) } diff --git a/listener/tcp/listener.go b/listener/tcp/listener.go new file mode 100644 index 0000000..9d9f25a --- /dev/null +++ b/listener/tcp/listener.go @@ -0,0 +1,60 @@ +package tcp + +import ( + "net" + + "github.com/go-gost/core/listener" + "github.com/go-gost/core/logger" + md "github.com/go-gost/core/metadata" + metrics "github.com/go-gost/core/metrics/wrapper" + "github.com/go-gost/x/registry" +) + +func init() { + registry.ListenerRegistry().Register("tcp", NewListener) +} + +type tcpListener struct { + ln net.Listener + logger logger.Logger + md metadata + options listener.Options +} + +func NewListener(opts ...listener.Option) listener.Listener { + options := listener.Options{} + for _, opt := range opts { + opt(&options) + } + return &tcpListener{ + logger: options.Logger, + options: options, + } +} + +func (l *tcpListener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { + return + } + + ln, err := net.Listen("tcp", l.options.Addr) + if err != nil { + return + } + + l.ln = metrics.WrapListener(l.options.Service, ln) + + return +} + +func (l *tcpListener) Accept() (conn net.Conn, err error) { + return l.ln.Accept() +} + +func (l *tcpListener) Addr() net.Addr { + return l.ln.Addr() +} + +func (l *tcpListener) Close() error { + return l.ln.Close() +} diff --git a/listener/tcp/metadata.go b/listener/tcp/metadata.go new file mode 100644 index 0000000..e1d286a --- /dev/null +++ b/listener/tcp/metadata.go @@ -0,0 +1,12 @@ +package tcp + +import ( + md "github.com/go-gost/core/metadata" +) + +type metadata struct { +} + +func (l *tcpListener) parseMetadata(md md.Metadata) (err error) { + return +} diff --git a/listener/tls/listener.go b/listener/tls/listener.go new file mode 100644 index 0000000..b76422d --- /dev/null +++ b/listener/tls/listener.go @@ -0,0 +1,64 @@ +package tls + +import ( + "crypto/tls" + "net" + + admission "github.com/go-gost/core/admission/wrapper" + "github.com/go-gost/core/listener" + "github.com/go-gost/core/logger" + md "github.com/go-gost/core/metadata" + metrics "github.com/go-gost/core/metrics/wrapper" + "github.com/go-gost/x/registry" +) + +func init() { + registry.ListenerRegistry().Register("tls", NewListener) +} + +type tlsListener struct { + ln net.Listener + logger logger.Logger + md metadata + options listener.Options +} + +func NewListener(opts ...listener.Option) listener.Listener { + options := listener.Options{} + for _, opt := range opts { + opt(&options) + } + return &tlsListener{ + logger: options.Logger, + options: options, + } +} + +func (l *tlsListener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { + return + } + + ln, err := net.Listen("tcp", l.options.Addr) + if err != nil { + return + } + ln = metrics.WrapListener(l.options.Service, ln) + ln = admission.WrapListener(l.options.Admission, ln) + + l.ln = tls.NewListener(ln, l.options.TLSConfig) + + return +} + +func (l *tlsListener) Accept() (conn net.Conn, err error) { + return l.ln.Accept() +} + +func (l *tlsListener) Addr() net.Addr { + return l.ln.Addr() +} + +func (l *tlsListener) Close() error { + return l.ln.Close() +} diff --git a/listener/tls/metadata.go b/listener/tls/metadata.go new file mode 100644 index 0000000..d515844 --- /dev/null +++ b/listener/tls/metadata.go @@ -0,0 +1,12 @@ +package tls + +import ( + mdata "github.com/go-gost/core/metadata" +) + +type metadata struct { +} + +func (l *tlsListener) parseMetadata(md mdata.Metadata) (err error) { + return +} diff --git a/listener/tun/listener.go b/listener/tun/listener.go index bb5faf5..4d5aef1 100644 --- a/listener/tun/listener.go +++ b/listener/tun/listener.go @@ -7,7 +7,8 @@ import ( "github.com/go-gost/core/logger" mdata "github.com/go-gost/core/metadata" metrics "github.com/go-gost/core/metrics/wrapper" - "github.com/go-gost/core/registry" + mdx "github.com/go-gost/x/metadata" + "github.com/go-gost/x/registry" ) func init() { @@ -71,9 +72,9 @@ func (l *tunListener) Init(md mdata.Metadata) (err error) { raddr: &net.IPAddr{IP: ip}, } c = metrics.WrapConn(l.options.Service, c) - c = withMetadata(mdata.MapMetadata{ + c = withMetadata(mdx.NewMetadata(map[string]any{ "config": l.md.config, - }, c) + }), c) l.cqueue <- c diff --git a/listener/tun/metadata.go b/listener/tun/metadata.go index 6a4c64a..055128a 100644 --- a/listener/tun/metadata.go +++ b/listener/tun/metadata.go @@ -6,6 +6,7 @@ import ( mdata "github.com/go-gost/core/metadata" tun_util "github.com/go-gost/x/internal/util/tun" + mdx "github.com/go-gost/x/metadata" ) const ( @@ -27,11 +28,11 @@ func (l *tunListener) parseMetadata(md mdata.Metadata) (err error) { ) config := &tun_util.Config{ - Name: mdata.GetString(md, name), - Net: mdata.GetString(md, netKey), - Peer: mdata.GetString(md, peer), - MTU: mdata.GetInt(md, mtu), - Gateway: mdata.GetString(md, gateway), + Name: mdx.GetString(md, name), + Net: mdx.GetString(md, netKey), + Peer: mdx.GetString(md, peer), + MTU: mdx.GetInt(md, mtu), + Gateway: mdx.GetString(md, gateway), } if config.MTU <= 0 { config.MTU = DefaultMTU @@ -39,7 +40,7 @@ func (l *tunListener) parseMetadata(md mdata.Metadata) (err error) { gw := net.ParseIP(config.Gateway) - for _, s := range mdata.GetStrings(md, routes) { + for _, s := range mdx.GetStrings(md, routes) { ss := strings.SplitN(s, " ", 2) if len(ss) == 2 { var route tun_util.Route diff --git a/listener/udp/listener.go b/listener/udp/listener.go new file mode 100644 index 0000000..c167d90 --- /dev/null +++ b/listener/udp/listener.go @@ -0,0 +1,74 @@ +package udp + +import ( + "net" + + "github.com/go-gost/core/common/net/udp" + "github.com/go-gost/core/listener" + "github.com/go-gost/core/logger" + md "github.com/go-gost/core/metadata" + metrics "github.com/go-gost/core/metrics/wrapper" + "github.com/go-gost/x/registry" +) + +func init() { + registry.ListenerRegistry().Register("udp", NewListener) +} + +type udpListener struct { + ln net.Listener + logger logger.Logger + md metadata + options listener.Options +} + +func NewListener(opts ...listener.Option) listener.Listener { + options := listener.Options{} + for _, opt := range opts { + opt(&options) + } + return &udpListener{ + logger: options.Logger, + options: options, + } +} + +func (l *udpListener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { + return + } + + laddr, err := net.ResolveUDPAddr("udp", l.options.Addr) + if err != nil { + return + } + + var conn net.PacketConn + conn, err = net.ListenUDP("udp", laddr) + if err != nil { + return + } + conn = metrics.WrapPacketConn(l.options.Service, conn) + + l.ln = udp.NewListener(conn, &udp.ListenConfig{ + Backlog: l.md.backlog, + ReadQueueSize: l.md.readQueueSize, + ReadBufferSize: l.md.readBufferSize, + KeepAlive: l.md.keepalive, + TTL: l.md.ttl, + Logger: l.logger, + }) + return +} + +func (l *udpListener) Accept() (conn net.Conn, err error) { + return l.ln.Accept() +} + +func (l *udpListener) Addr() net.Addr { + return l.ln.Addr() +} + +func (l *udpListener) Close() error { + return l.ln.Close() +} diff --git a/listener/udp/metadata.go b/listener/udp/metadata.go new file mode 100644 index 0000000..750555e --- /dev/null +++ b/listener/udp/metadata.go @@ -0,0 +1,55 @@ +package udp + +import ( + "time" + + mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" +) + +const ( + defaultTTL = 5 * time.Second + defaultReadBufferSize = 1500 + defaultReadQueueSize = 128 + defaultBacklog = 128 +) + +type metadata struct { + readBufferSize int + readQueueSize int + backlog int + keepalive bool + ttl time.Duration +} + +func (l *udpListener) parseMetadata(md mdata.Metadata) (err error) { + const ( + readBufferSize = "readBufferSize" + readQueueSize = "readQueueSize" + backlog = "backlog" + keepAlive = "keepAlive" + ttl = "ttl" + ) + + l.md.ttl = mdx.GetDuration(md, ttl) + if l.md.ttl <= 0 { + l.md.ttl = defaultTTL + } + l.md.readBufferSize = mdx.GetInt(md, readBufferSize) + if l.md.readBufferSize <= 0 { + l.md.readBufferSize = defaultReadBufferSize + } + + l.md.readQueueSize = mdx.GetInt(md, readQueueSize) + if l.md.readQueueSize <= 0 { + l.md.readQueueSize = defaultReadQueueSize + } + + l.md.backlog = mdx.GetInt(md, backlog) + if l.md.backlog <= 0 { + l.md.backlog = defaultBacklog + } + l.md.keepalive = mdx.GetBool(md, keepAlive) + + return +} diff --git a/listener/ws/listener.go b/listener/ws/listener.go index 13f175f..1f0e52c 100644 --- a/listener/ws/listener.go +++ b/listener/ws/listener.go @@ -11,8 +11,8 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" metrics "github.com/go-gost/core/metrics/wrapper" - "github.com/go-gost/core/registry" ws_util "github.com/go-gost/x/internal/util/ws" + "github.com/go-gost/x/registry" "github.com/gorilla/websocket" ) diff --git a/listener/ws/metadata.go b/listener/ws/metadata.go index f4deb29..2d6bb72 100644 --- a/listener/ws/metadata.go +++ b/listener/ws/metadata.go @@ -5,6 +5,7 @@ import ( "time" mdata "github.com/go-gost/core/metadata" + mdx "github.com/go-gost/x/metadata" ) const ( @@ -39,23 +40,23 @@ func (l *wsListener) parseMetadata(md mdata.Metadata) (err error) { header = "header" ) - l.md.path = mdata.GetString(md, path) + l.md.path = mdx.GetString(md, path) if l.md.path == "" { l.md.path = defaultPath } - l.md.backlog = mdata.GetInt(md, backlog) + l.md.backlog = mdx.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } - l.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) - l.md.readHeaderTimeout = mdata.GetDuration(md, readHeaderTimeout) - l.md.readBufferSize = mdata.GetInt(md, readBufferSize) - l.md.writeBufferSize = mdata.GetInt(md, writeBufferSize) - l.md.enableCompression = mdata.GetBool(md, enableCompression) + l.md.handshakeTimeout = mdx.GetDuration(md, handshakeTimeout) + l.md.readHeaderTimeout = mdx.GetDuration(md, readHeaderTimeout) + l.md.readBufferSize = mdx.GetInt(md, readBufferSize) + l.md.writeBufferSize = mdx.GetInt(md, writeBufferSize) + l.md.enableCompression = mdx.GetBool(md, enableCompression) - if mm := mdata.GetStringMapString(md, header); len(mm) > 0 { + if mm := mdx.GetStringMapString(md, header); len(mm) > 0 { hd := http.Header{} for k, v := range mm { hd.Add(k, v) diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 0000000..3438ff5 --- /dev/null +++ b/logger/logger.go @@ -0,0 +1,176 @@ +package logger + +import ( + "fmt" + "io" + "path/filepath" + "runtime" + + "github.com/go-gost/core/logger" + + "github.com/sirupsen/logrus" +) + +type LoggerOptions struct { + Output io.Writer + Format logger.LogFormat + Level logger.LogLevel +} + +type LoggerOption func(opts *LoggerOptions) + +func OutputLoggerOption(out io.Writer) LoggerOption { + return func(opts *LoggerOptions) { + opts.Output = out + } +} + +func FormatLoggerOption(format logger.LogFormat) LoggerOption { + return func(opts *LoggerOptions) { + opts.Format = format + } +} + +func LevelLoggerOption(level logger.LogLevel) LoggerOption { + return func(opts *LoggerOptions) { + opts.Level = level + } +} + +type logrusLogger struct { + logger *logrus.Entry +} + +func NewLogger(opts ...LoggerOption) logger.Logger { + var options LoggerOptions + for _, opt := range opts { + opt(&options) + } + + log := logrus.New() + if options.Output != nil { + log.SetOutput(options.Output) + } + + switch options.Format { + case logger.TextFormat: + log.SetFormatter(&logrus.TextFormatter{ + FullTimestamp: true, + }) + default: + log.SetFormatter(&logrus.JSONFormatter{ + DisableHTMLEscape: true, + // PrettyPrint: true, + }) + } + + switch options.Level { + case logger.DebugLevel, + logger.InfoLevel, + logger.WarnLevel, + logger.ErrorLevel, + logger.FatalLevel: + lvl, _ := logrus.ParseLevel(string(options.Level)) + log.SetLevel(lvl) + default: + log.SetLevel(logrus.InfoLevel) + } + + return &logrusLogger{ + logger: logrus.NewEntry(log), + } +} + +// WithFields adds new fields to log. +func (l *logrusLogger) WithFields(fields map[string]any) logger.Logger { + return &logrusLogger{ + logger: l.logger.WithFields(logrus.Fields(fields)), + } +} + +// Debug logs a message at level Debug. +func (l *logrusLogger) Debug(args ...any) { + l.log(logrus.DebugLevel, args...) +} + +// Debugf logs a message at level Debug. +func (l *logrusLogger) Debugf(format string, args ...any) { + l.logf(logrus.DebugLevel, format, args...) +} + +// Info logs a message at level Info. +func (l *logrusLogger) Info(args ...any) { + l.log(logrus.InfoLevel, args...) +} + +// Infof logs a message at level Info. +func (l *logrusLogger) Infof(format string, args ...any) { + l.logf(logrus.InfoLevel, format, args...) +} + +// Warn logs a message at level Warn. +func (l *logrusLogger) Warn(args ...any) { + l.log(logrus.WarnLevel, args...) +} + +// Warnf logs a message at level Warn. +func (l *logrusLogger) Warnf(format string, args ...any) { + l.logf(logrus.WarnLevel, format, args...) +} + +// Error logs a message at level Error. +func (l *logrusLogger) Error(args ...any) { + l.log(logrus.ErrorLevel, args...) +} + +// Errorf logs a message at level Error. +func (l *logrusLogger) Errorf(format string, args ...any) { + l.logf(logrus.ErrorLevel, format, args...) +} + +// Fatal logs a message at level Fatal then the process will exit with status set to 1. +func (l *logrusLogger) Fatal(args ...any) { + l.log(logrus.FatalLevel, args...) + l.logger.Logger.Exit(1) +} + +// Fatalf logs a message at level Fatal then the process will exit with status set to 1. +func (l *logrusLogger) Fatalf(format string, args ...any) { + l.logf(logrus.FatalLevel, format, args...) + l.logger.Logger.Exit(1) +} + +func (l *logrusLogger) GetLevel() logger.LogLevel { + return logger.LogLevel(l.logger.Logger.GetLevel().String()) +} + +func (l *logrusLogger) IsLevelEnabled(level logger.LogLevel) bool { + lvl, _ := logrus.ParseLevel(string(level)) + return l.logger.Logger.IsLevelEnabled(lvl) +} + +func (l *logrusLogger) log(level logrus.Level, args ...any) { + lg := l.logger + if l.logger.Logger.IsLevelEnabled(logrus.DebugLevel) { + lg = lg.WithField("caller", l.caller(3)) + } + lg.Log(level, args...) +} + +func (l *logrusLogger) logf(level logrus.Level, format string, args ...any) { + lg := l.logger + if l.logger.Logger.IsLevelEnabled(logrus.DebugLevel) { + lg = lg.WithField("caller", l.caller(3)) + } + lg.Logf(level, format, args...) +} + +func (l *logrusLogger) caller(skip int) string { + _, file, line, ok := runtime.Caller(skip) + if !ok { + file = "" + } else { + file = filepath.Join(filepath.Base(filepath.Dir(file)), filepath.Base(file)) + } + return fmt.Sprintf("%s:%d", file, line) +} diff --git a/logger/nop_logger.go b/logger/nop_logger.go new file mode 100644 index 0000000..2960eef --- /dev/null +++ b/logger/nop_logger.go @@ -0,0 +1,57 @@ +package logger + +import ( + "github.com/go-gost/core/logger" +) + +var ( + nop = &nopLogger{} +) + +func Nop() logger.Logger { + return nop +} + +type nopLogger struct{} + +func (l *nopLogger) WithFields(fields map[string]any) logger.Logger { + return l +} + +func (l *nopLogger) Debug(args ...any) { +} + +func (l *nopLogger) Debugf(format string, args ...any) { +} + +func (l *nopLogger) Info(args ...any) { +} + +func (l *nopLogger) Infof(format string, args ...any) { +} + +func (l *nopLogger) Warn(args ...any) { +} + +func (l *nopLogger) Warnf(format string, args ...any) { +} + +func (l *nopLogger) Error(args ...any) { +} + +func (l *nopLogger) Errorf(format string, args ...any) { +} + +func (l *nopLogger) Fatal(args ...any) { +} + +func (l *nopLogger) Fatalf(format string, args ...any) { +} + +func (l *nopLogger) GetLevel() logger.LogLevel { + return "" +} + +func (l *nopLogger) IsLevelEnabled(level logger.LogLevel) bool { + return false +} diff --git a/metadata/metadata.go b/metadata/metadata.go new file mode 100644 index 0000000..77fb5f3 --- /dev/null +++ b/metadata/metadata.go @@ -0,0 +1,148 @@ +package metadata + +import ( + "fmt" + "strconv" + "time" + + "github.com/go-gost/core/metadata" +) + +type mapMetadata map[string]any + +func NewMetadata(m map[string]any) metadata.Metadata { + return mapMetadata(m) +} + +func (m mapMetadata) IsExists(key string) bool { + _, ok := m[key] + return ok +} + +func (m mapMetadata) Set(key string, value any) { + m[key] = value +} + +func (m mapMetadata) Get(key string) any { + if m != nil { + return m[key] + } + return nil +} + +func GetBool(md metadata.Metadata, key string) (v bool) { + if md == nil || !md.IsExists(key) { + return + } + switch vv := md.Get(key).(type) { + case bool: + return vv + case int: + return vv != 0 + case string: + v, _ = strconv.ParseBool(vv) + return + } + return +} + +func GetInt(md metadata.Metadata, key string) (v int) { + if md == nil { + return + } + + switch vv := md.Get(key).(type) { + case bool: + if vv { + v = 1 + } + case int: + return vv + case string: + v, _ = strconv.Atoi(vv) + return + } + return +} + +func GetFloat(md metadata.Metadata, key string) (v float64) { + if md == nil { + return + } + + switch vv := md.Get(key).(type) { + case int: + return float64(vv) + case string: + v, _ = strconv.ParseFloat(vv, 64) + return + } + return +} + +func GetDuration(md metadata.Metadata, key string) (v time.Duration) { + if md == nil { + return + } + switch vv := md.Get(key).(type) { + case int: + return time.Duration(vv) * time.Second + case string: + v, _ = time.ParseDuration(vv) + if v == 0 { + n, _ := strconv.Atoi(vv) + v = time.Duration(n) * time.Second + } + } + return +} + +func GetString(md metadata.Metadata, key string) (v string) { + if md != nil { + v, _ = md.Get(key).(string) + } + return +} + +func GetStrings(md metadata.Metadata, key string) (ss []string) { + switch v := md.Get(key).(type) { + case []string: + ss = v + case []any: + for _, vv := range v { + if s, ok := vv.(string); ok { + ss = append(ss, s) + } + } + } + return +} + +func GetStringMap(md metadata.Metadata, key string) (m map[string]any) { + switch vv := md.Get(key).(type) { + case map[string]any: + return vv + case map[any]any: + m = make(map[string]any) + for k, v := range vv { + m[fmt.Sprintf("%v", k)] = v + } + } + return +} + +func GetStringMapString(md metadata.Metadata, key string) (m map[string]string) { + switch vv := md.Get(key).(type) { + case map[string]any: + m = make(map[string]string) + for k, v := range vv { + m[k] = fmt.Sprintf("%v", v) + } + case map[any]any: + m = make(map[string]string) + for k, v := range vv { + m[fmt.Sprintf("%v", k)] = fmt.Sprintf("%v", v) + } + } + return +} diff --git a/registry/admission.go b/registry/admission.go new file mode 100644 index 0000000..3fffb28 --- /dev/null +++ b/registry/admission.go @@ -0,0 +1,40 @@ +package registry + +import ( + "github.com/go-gost/core/admission" +) + +type admissionRegistry struct { + registry +} + +func (r *admissionRegistry) Register(name string, v admission.Admission) error { + return r.registry.Register(name, v) +} + +func (r *admissionRegistry) Get(name string) admission.Admission { + if name != "" { + return &admissionWrapper{name: name, r: r} + } + return nil +} + +func (r *admissionRegistry) get(name string) admission.Admission { + if v := r.registry.Get(name); v != nil { + return v.(admission.Admission) + } + return nil +} + +type admissionWrapper struct { + name string + r *admissionRegistry +} + +func (w *admissionWrapper) Admit(addr string) bool { + p := w.r.get(w.name) + if p == nil { + return false + } + return p.Admit(addr) +} diff --git a/registry/auther.go b/registry/auther.go new file mode 100644 index 0000000..ecc2d36 --- /dev/null +++ b/registry/auther.go @@ -0,0 +1,40 @@ +package registry + +import ( + "github.com/go-gost/core/auth" +) + +type autherRegistry struct { + registry +} + +func (r *autherRegistry) Register(name string, v auth.Authenticator) error { + return r.registry.Register(name, v) +} + +func (r *autherRegistry) Get(name string) auth.Authenticator { + if name != "" { + return &autherWrapper{name: name, r: r} + } + return nil +} + +func (r *autherRegistry) get(name string) auth.Authenticator { + if v := r.registry.Get(name); v != nil { + return v.(auth.Authenticator) + } + return nil +} + +type autherWrapper struct { + name string + r *autherRegistry +} + +func (w *autherWrapper) Authenticate(user, password string) bool { + v := w.r.get(w.name) + if v == nil { + return true + } + return v.Authenticate(user, password) +} diff --git a/registry/bypass.go b/registry/bypass.go new file mode 100644 index 0000000..8a77c67 --- /dev/null +++ b/registry/bypass.go @@ -0,0 +1,40 @@ +package registry + +import ( + "github.com/go-gost/core/bypass" +) + +type bypassRegistry struct { + registry +} + +func (r *bypassRegistry) Register(name string, v bypass.Bypass) error { + return r.registry.Register(name, v) +} + +func (r *bypassRegistry) Get(name string) bypass.Bypass { + if name != "" { + return &bypassWrapper{name: name, r: r} + } + return nil +} + +func (r *bypassRegistry) get(name string) bypass.Bypass { + if v := r.registry.Get(name); v != nil { + return v.(bypass.Bypass) + } + return nil +} + +type bypassWrapper struct { + name string + r *bypassRegistry +} + +func (w *bypassWrapper) Contains(addr string) bool { + bp := w.r.get(w.name) + if bp == nil { + return false + } + return bp.Contains(addr) +} diff --git a/registry/chain.go b/registry/chain.go new file mode 100644 index 0000000..36d2465 --- /dev/null +++ b/registry/chain.go @@ -0,0 +1,40 @@ +package registry + +import ( + "github.com/go-gost/core/chain" +) + +type chainRegistry struct { + registry +} + +func (r *chainRegistry) Register(name string, v chain.Chainer) error { + return r.registry.Register(name, v) +} + +func (r *chainRegistry) Get(name string) chain.Chainer { + if name != "" { + return &chainWrapper{name: name, r: r} + } + return nil +} + +func (r *chainRegistry) get(name string) chain.Chainer { + if v := r.registry.Get(name); v != nil { + return v.(chain.Chainer) + } + return nil +} + +type chainWrapper struct { + name string + r *chainRegistry +} + +func (w *chainWrapper) Route(network, address string) *chain.Route { + v := w.r.get(w.name) + if v == nil { + return nil + } + return v.Route(network, address) +} diff --git a/registry/connector.go b/registry/connector.go new file mode 100644 index 0000000..955b49b --- /dev/null +++ b/registry/connector.go @@ -0,0 +1,26 @@ +package registry + +import ( + "github.com/go-gost/core/connector" + "github.com/go-gost/core/logger" +) + +type NewConnector func(opts ...connector.Option) connector.Connector + +type connectorRegistry struct { + registry +} + +func (r *connectorRegistry) Register(name string, v NewConnector) error { + if err := r.registry.Register(name, v); err != nil { + logger.Default().Fatal(err) + } + return nil +} + +func (r *connectorRegistry) Get(name string) NewConnector { + if v := r.registry.Get(name); v != nil { + return v.(NewConnector) + } + return nil +} diff --git a/registry/dialer.go b/registry/dialer.go new file mode 100644 index 0000000..2db3b5a --- /dev/null +++ b/registry/dialer.go @@ -0,0 +1,26 @@ +package registry + +import ( + "github.com/go-gost/core/dialer" + "github.com/go-gost/core/logger" +) + +type NewDialer func(opts ...dialer.Option) dialer.Dialer + +type dialerRegistry struct { + registry +} + +func (r *dialerRegistry) Register(name string, v NewDialer) error { + if err := r.registry.Register(name, v); err != nil { + logger.Default().Fatal(err) + } + return nil +} + +func (r *dialerRegistry) Get(name string) NewDialer { + if v := r.registry.Get(name); v != nil { + return v.(NewDialer) + } + return nil +} diff --git a/registry/handler.go b/registry/handler.go new file mode 100644 index 0000000..c063206 --- /dev/null +++ b/registry/handler.go @@ -0,0 +1,26 @@ +package registry + +import ( + "github.com/go-gost/core/handler" + "github.com/go-gost/core/logger" +) + +type NewHandler func(opts ...handler.Option) handler.Handler + +type handlerRegistry struct { + registry +} + +func (r *handlerRegistry) Register(name string, v NewHandler) error { + if err := r.registry.Register(name, v); err != nil { + logger.Default().Fatal(err) + } + return nil +} + +func (r *handlerRegistry) Get(name string) NewHandler { + if v := r.registry.Get(name); v != nil { + return v.(NewHandler) + } + return nil +} diff --git a/registry/hosts.go b/registry/hosts.go new file mode 100644 index 0000000..16dacad --- /dev/null +++ b/registry/hosts.go @@ -0,0 +1,42 @@ +package registry + +import ( + "net" + + "github.com/go-gost/core/hosts" +) + +type hostsRegistry struct { + registry +} + +func (r *hostsRegistry) Register(name string, v hosts.HostMapper) error { + return r.registry.Register(name, v) +} + +func (r *hostsRegistry) Get(name string) hosts.HostMapper { + if name != "" { + return &hostsWrapper{name: name, r: r} + } + return nil +} + +func (r *hostsRegistry) get(name string) hosts.HostMapper { + if v := r.registry.Get(name); v != nil { + return v.(hosts.HostMapper) + } + return nil +} + +type hostsWrapper struct { + name string + r *hostsRegistry +} + +func (w *hostsWrapper) Lookup(network, host string) ([]net.IP, bool) { + v := w.r.get(w.name) + if v == nil { + return nil, false + } + return v.Lookup(network, host) +} diff --git a/registry/listener.go b/registry/listener.go new file mode 100644 index 0000000..8d55ce3 --- /dev/null +++ b/registry/listener.go @@ -0,0 +1,26 @@ +package registry + +import ( + "github.com/go-gost/core/listener" + "github.com/go-gost/core/logger" +) + +type NewListener func(opts ...listener.Option) listener.Listener + +type listenerRegistry struct { + registry +} + +func (r *listenerRegistry) Register(name string, v NewListener) error { + if err := r.registry.Register(name, v); err != nil { + logger.Default().Fatal(err) + } + return nil +} + +func (r *listenerRegistry) Get(name string) NewListener { + if v := r.registry.Get(name); v != nil { + return v.(NewListener) + } + return nil +} diff --git a/registry/registry.go b/registry/registry.go new file mode 100644 index 0000000..affd3d2 --- /dev/null +++ b/registry/registry.go @@ -0,0 +1,116 @@ +package registry + +import ( + "errors" + "sync" + + "github.com/go-gost/core/admission" + "github.com/go-gost/core/auth" + "github.com/go-gost/core/bypass" + "github.com/go-gost/core/chain" + "github.com/go-gost/core/hosts" + "github.com/go-gost/core/resolver" + "github.com/go-gost/core/service" +) + +var ( + ErrDup = errors.New("registry: duplicate object") +) + +var ( + listenerReg Registry[NewListener] = &listenerRegistry{} + handlerReg Registry[NewHandler] = &handlerRegistry{} + dialerReg Registry[NewDialer] = &dialerRegistry{} + connectorReg Registry[NewConnector] = &connectorRegistry{} + + serviceReg Registry[service.Service] = &serviceRegistry{} + chainReg Registry[chain.Chainer] = &chainRegistry{} + autherReg Registry[auth.Authenticator] = &autherRegistry{} + admissionReg Registry[admission.Admission] = &admissionRegistry{} + bypassReg Registry[bypass.Bypass] = &bypassRegistry{} + resolverReg Registry[resolver.Resolver] = &resolverRegistry{} + hostsReg Registry[hosts.HostMapper] = &hostsRegistry{} +) + +type Registry[T any] interface { + Register(name string, v T) error + Unregister(name string) + IsRegistered(name string) bool + Get(name string) T +} + +type registry struct { + m sync.Map +} + +func (r *registry) Register(name string, v any) error { + if name == "" || v == nil { + return nil + } + if _, loaded := r.m.LoadOrStore(name, v); loaded { + return ErrDup + } + + return nil +} + +func (r *registry) Unregister(name string) { + r.m.Delete(name) +} + +func (r *registry) IsRegistered(name string) bool { + _, ok := r.m.Load(name) + return ok +} + +func (r *registry) Get(name string) any { + if name == "" { + return nil + } + v, _ := r.m.Load(name) + return v +} + +func ListenerRegistry() Registry[NewListener] { + return listenerReg +} + +func HandlerRegistry() Registry[NewHandler] { + return handlerReg +} + +func DialerRegistry() Registry[NewDialer] { + return dialerReg +} + +func ConnectorRegistry() Registry[NewConnector] { + return connectorReg +} + +func ServiceRegistry() Registry[service.Service] { + return serviceReg +} + +func ChainRegistry() Registry[chain.Chainer] { + return chainReg +} + +func AutherRegistry() Registry[auth.Authenticator] { + return autherReg +} + +func AdmissionRegistry() Registry[admission.Admission] { + return admissionReg +} + +func BypassRegistry() Registry[bypass.Bypass] { + return bypassReg +} + +func ResolverRegistry() Registry[resolver.Resolver] { + return resolverReg +} + +func HostsRegistry() Registry[hosts.HostMapper] { + return hostsReg +} diff --git a/registry/resolver.go b/registry/resolver.go new file mode 100644 index 0000000..bd950fa --- /dev/null +++ b/registry/resolver.go @@ -0,0 +1,43 @@ +package registry + +import ( + "context" + "net" + + "github.com/go-gost/core/resolver" +) + +type resolverRegistry struct { + registry +} + +func (r *resolverRegistry) Register(name string, v resolver.Resolver) error { + return r.registry.Register(name, v) +} + +func (r *resolverRegistry) Get(name string) resolver.Resolver { + if name != "" { + return &resolverWrapper{name: name, r: r} + } + return nil +} + +func (r *resolverRegistry) get(name string) resolver.Resolver { + if v := r.registry.Get(name); v != nil { + return v.(resolver.Resolver) + } + return nil +} + +type resolverWrapper struct { + name string + r *resolverRegistry +} + +func (w *resolverWrapper) Resolve(ctx context.Context, network, host string) ([]net.IP, error) { + r := w.r.get(w.name) + if r == nil { + return nil, resolver.ErrInvalid + } + return r.Resolve(ctx, network, host) +} diff --git a/registry/service.go b/registry/service.go new file mode 100644 index 0000000..97bc25a --- /dev/null +++ b/registry/service.go @@ -0,0 +1,20 @@ +package registry + +import ( + "github.com/go-gost/core/service" +) + +type serviceRegistry struct { + registry +} + +func (r *serviceRegistry) Register(name string, v service.Service) error { + return r.registry.Register(name, v) +} + +func (r *serviceRegistry) Get(name string) service.Service { + if v := r.registry.Get(name); v != nil { + return v.(service.Service) + } + return nil +} diff --git a/resolver/exchanger/exchanger.go b/resolver/exchanger/exchanger.go new file mode 100644 index 0000000..b109f95 --- /dev/null +++ b/resolver/exchanger/exchanger.go @@ -0,0 +1,226 @@ +package exchanger + +import ( + "bytes" + "context" + "crypto/tls" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/url" + "strings" + "time" + + "github.com/go-gost/core/chain" + "github.com/go-gost/core/logger" + "github.com/miekg/dns" +) + +type Options struct { + router *chain.Router + tlsConfig *tls.Config + timeout time.Duration + logger logger.Logger +} + +// Option allows a common way to set Exchanger options. +type Option func(opts *Options) + +// RouterOption sets the router for Exchanger. +func RouterOption(router *chain.Router) Option { + return func(opts *Options) { + opts.router = router + } +} + +// TLSConfigOption sets the TLS config for Exchanger. +func TLSConfigOption(cfg *tls.Config) Option { + return func(opts *Options) { + opts.tlsConfig = cfg + } +} + +// LoggerOption sets the logger for Exchanger. +func LoggerOption(logger logger.Logger) Option { + return func(opts *Options) { + opts.logger = logger + } +} + +// TimeoutOption sets the timeout for Exchanger. +func TimeoutOption(timeout time.Duration) Option { + return func(opts *Options) { + opts.timeout = timeout + } +} + +// Exchanger is an interface for DNS synchronous query. +type Exchanger interface { + Exchange(ctx context.Context, msg []byte) ([]byte, error) + String() string +} + +type exchanger struct { + network string + addr string + rawAddr string + router *chain.Router + client *http.Client + options Options +} + +// NewExchanger create an Exchanger. +// The addr should be URL-like format, +// e.g. udp://1.1.1.1:53, tls://1.1.1.1:853, https://1.0.0.1/dns-query +func NewExchanger(addr string, opts ...Option) (Exchanger, error) { + var options Options + for _, opt := range opts { + opt(&options) + } + + if !strings.Contains(addr, "://") { + addr = "udp://" + addr + } + u, err := url.Parse(addr) + if err != nil { + return nil, err + } + + if options.timeout <= 0 { + options.timeout = 5 * time.Second + } + + ex := &exchanger{ + network: u.Scheme, + addr: u.Host, + rawAddr: addr, + router: options.router, + options: options, + } + if _, port, _ := net.SplitHostPort(ex.addr); port == "" { + ex.addr = net.JoinHostPort(ex.addr, "53") + } + if ex.router == nil { + ex.router = (&chain.Router{}).WithLogger(options.logger) + } + + switch ex.network { + case "tcp": + case "dot", "tls": + if ex.options.tlsConfig == nil { + ex.options.tlsConfig = &tls.Config{ + InsecureSkipVerify: true, + } + } + ex.network = "tcp" + case "https": + ex.addr = addr + if ex.options.tlsConfig == nil { + ex.options.tlsConfig = &tls.Config{ + InsecureSkipVerify: true, + } + } + ex.client = &http.Client{ + Timeout: options.timeout, + Transport: &http.Transport{ + TLSClientConfig: options.tlsConfig, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: options.timeout, + ExpectContinueTimeout: 1 * time.Second, + DialContext: ex.dial, + }, + } + default: + ex.network = "udp" + } + + return ex, nil +} + +func (ex *exchanger) Exchange(ctx context.Context, msg []byte) ([]byte, error) { + if ex.network == "https" { + return ex.dohExchange(ctx, msg) + } + return ex.exchange(ctx, msg) +} + +func (ex *exchanger) dohExchange(ctx context.Context, msg []byte) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, "POST", ex.addr, bytes.NewBuffer(msg)) + if err != nil { + return nil, fmt.Errorf("failed to create an HTTPS request: %w", err) + } + + // req.Header.Add("Content-Type", "application/dns-udpwireformat") + req.Header.Add("Content-Type", "application/dns-message") + + client := ex.client + if client == nil { + client = http.DefaultClient + } + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to perform an HTTPS request: %w", err) + } + + // Check response status code + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("returned status code %d", resp.StatusCode) + } + + // Read wireformat response from the body + buf, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read the response body: %w", err) + } + + return buf, nil +} + +func (ex *exchanger) exchange(ctx context.Context, msg []byte) ([]byte, error) { + if ex.options.timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, ex.options.timeout) + defer cancel() + } + + c, err := ex.dial(ctx, ex.network, ex.addr) + if err != nil { + return nil, err + } + defer c.Close() + + if ex.options.tlsConfig != nil { + c = tls.Client(c, ex.options.tlsConfig) + } + if ex.options.timeout > 0 { + c.SetDeadline(time.Now().Add(ex.options.timeout)) + } + + conn := &dns.Conn{ + UDPSize: 1024, + Conn: c, + } + + if _, err = conn.Write(msg); err != nil { + return nil, err + } + + mr, err := conn.ReadMsg() + if err != nil { + return nil, err + } + + return mr.Pack() +} + +func (ex *exchanger) dial(ctx context.Context, network, address string) (net.Conn, error) { + return ex.router.Dial(ctx, network, address) +} + +func (ex *exchanger) String() string { + return ex.rawAddr +} diff --git a/resolver/resolver.go b/resolver/resolver.go new file mode 100644 index 0000000..b49ea69 --- /dev/null +++ b/resolver/resolver.go @@ -0,0 +1,178 @@ +package resolver + +import ( + "context" + "net" + "strings" + "time" + + "github.com/go-gost/core/chain" + "github.com/go-gost/core/logger" + resolverpkg "github.com/go-gost/core/resolver" + resolver_util "github.com/go-gost/x/internal/util/resolver" + "github.com/go-gost/x/resolver/exchanger" + "github.com/miekg/dns" +) + +type NameServer struct { + Addr string + Chain chain.Chainer + TTL time.Duration + Timeout time.Duration + ClientIP net.IP + Prefer string + Hostname string // for TLS handshake verification + exchanger exchanger.Exchanger +} + +type resolverOptions struct { + domain string + logger logger.Logger +} + +type ResolverOption func(opts *resolverOptions) + +func DomainResolverOption(domain string) ResolverOption { + return func(opts *resolverOptions) { + opts.domain = domain + } +} + +func LoggerResolverOption(logger logger.Logger) ResolverOption { + return func(opts *resolverOptions) { + opts.logger = logger + } +} + +type resolver struct { + servers []NameServer + cache *resolver_util.Cache + options resolverOptions +} + +func NewResolver(nameservers []NameServer, opts ...ResolverOption) (resolverpkg.Resolver, error) { + options := resolverOptions{} + for _, opt := range opts { + opt(&options) + } + + var servers []NameServer + for _, server := range nameservers { + addr := strings.TrimSpace(server.Addr) + if addr == "" { + continue + } + ex, err := exchanger.NewExchanger( + addr, + exchanger.RouterOption( + (&chain.Router{}). + WithChain(server.Chain). + WithLogger(options.logger), + ), + exchanger.TimeoutOption(server.Timeout), + exchanger.LoggerOption(options.logger), + ) + if err != nil { + options.logger.Warnf("parse %s: %v", server, err) + continue + } + + server.exchanger = ex + servers = append(servers, server) + } + cache := resolver_util.NewCache(). + WithLogger(options.logger) + + return &resolver{ + servers: servers, + cache: cache, + options: options, + }, nil +} + +func (r *resolver) Resolve(ctx context.Context, network, host string) (ips []net.IP, err error) { + if ip := net.ParseIP(host); ip != nil { + return []net.IP{ip}, nil + } + + if r.options.domain != "" && + !strings.Contains(host, ".") { + host = host + "." + r.options.domain + } + + for _, server := range r.servers { + ips, err = r.resolve(ctx, &server, host) + if err != nil { + r.options.logger.Error(err) + continue + } + + r.options.logger.Debugf("resolve %s via %s: %v", host, server.exchanger.String(), ips) + + if len(ips) > 0 { + break + } + } + + return +} + +func (r *resolver) resolve(ctx context.Context, server *NameServer, host string) (ips []net.IP, err error) { + if server == nil { + return + } + + if server.Prefer == "ipv6" { // prefer ipv6 + mq := dns.Msg{} + mq.SetQuestion(dns.Fqdn(host), dns.TypeAAAA) + ips, err = r.resolveIPs(ctx, server, &mq) + if err != nil || len(ips) > 0 { + return + } + } + + // fallback to ipv4 + mq := dns.Msg{} + mq.SetQuestion(dns.Fqdn(host), dns.TypeA) + return r.resolveIPs(ctx, server, &mq) +} + +func (r *resolver) resolveIPs(ctx context.Context, server *NameServer, mq *dns.Msg) (ips []net.IP, err error) { + key := resolver_util.NewCacheKey(&mq.Question[0]) + mr := r.cache.Load(key) + if mr == nil { + resolver_util.AddSubnetOpt(mq, server.ClientIP) + mr, err = r.exchange(ctx, server.exchanger, mq) + if err != nil { + return + } + r.cache.Store(key, mr, server.TTL) + } + + for _, ans := range mr.Answer { + if ar, _ := ans.(*dns.AAAA); ar != nil { + ips = append(ips, ar.AAAA) + } + if ar, _ := ans.(*dns.A); ar != nil { + ips = append(ips, ar.A) + } + } + + return +} + +func (r *resolver) exchange(ctx context.Context, ex exchanger.Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) { + query, err := mq.Pack() + if err != nil { + return + } + reply, err := ex.Exchange(ctx, query) + if err != nil { + return + } + + mr = &dns.Msg{} + err = mr.Unpack(reply) + + return +}