From c7f5da6ac7be5f50a49e2e73d9b1306109edc8f6 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Thu, 16 Dec 2021 22:26:13 +0800 Subject: [PATCH] add obfs dialer --- cmd/gost/register.go | 2 + go.mod | 1 - go.sum | 2 - pkg/connector/http/connector.go | 7 +- pkg/connector/http/metadata.go | 21 +-- pkg/dialer/obfs/http/conn.go | 144 ++++++++++++++++ pkg/dialer/obfs/http/dialer.go | 64 +++++++ pkg/dialer/obfs/http/metadata.go | 29 ++++ pkg/dialer/obfs/tls/conn.go | 266 +++++++++++++++++++++++++++++ pkg/dialer/obfs/tls/dialer.go | 63 +++++++ pkg/dialer/obfs/tls/metadata.go | 18 ++ pkg/handler/http/handler.go | 4 +- pkg/handler/http/metadata.go | 17 +- pkg/handler/http/udp.go | 4 +- pkg/listener/obfs/http/conn.go | 80 ++++----- pkg/listener/obfs/http/listener.go | 21 +-- pkg/listener/obfs/http/metadata.go | 14 +- pkg/listener/obfs/tls/conn.go | 171 ++----------------- pkg/listener/obfs/tls/listener.go | 20 +-- pkg/listener/obfs/tls/metadata.go | 17 +- 20 files changed, 691 insertions(+), 274 deletions(-) create mode 100644 pkg/dialer/obfs/http/conn.go create mode 100644 pkg/dialer/obfs/http/dialer.go create mode 100644 pkg/dialer/obfs/http/metadata.go create mode 100644 pkg/dialer/obfs/tls/conn.go create mode 100644 pkg/dialer/obfs/tls/dialer.go create mode 100644 pkg/dialer/obfs/tls/metadata.go diff --git a/cmd/gost/register.go b/cmd/gost/register.go index 5068cfd..08d5b76 100644 --- a/cmd/gost/register.go +++ b/cmd/gost/register.go @@ -17,6 +17,8 @@ import ( _ "github.com/go-gost/gost/pkg/dialer/http2" _ "github.com/go-gost/gost/pkg/dialer/http2/h2" _ "github.com/go-gost/gost/pkg/dialer/kcp" + _ "github.com/go-gost/gost/pkg/dialer/obfs/http" + _ "github.com/go-gost/gost/pkg/dialer/obfs/tls" _ "github.com/go-gost/gost/pkg/dialer/tcp" _ "github.com/go-gost/gost/pkg/dialer/udp" diff --git a/go.mod b/go.mod index 1e73590..561bcaa 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,6 @@ require ( github.com/cheekybits/genny v1.0.0 // indirect github.com/coreos/go-iptables v0.5.0 // indirect github.com/fsnotify/fsnotify v1.5.1 // indirect - github.com/ginuerzh/tls-dissector v0.0.2-0.20201202075250-98fa925912da github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7 diff --git a/go.sum b/go.sum index 83b69fb..7f1a624 100644 --- a/go.sum +++ b/go.sum @@ -104,8 +104,6 @@ github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4 github.com/fsnotify/fsnotify v1.5.1 h1:mZcQUHVQUQWoPXXtuf9yuEXKudkV2sx1E06UadKWpgI= github.com/fsnotify/fsnotify v1.5.1/go.mod h1:T3375wBYaZdLLcVNkcVbzGHY7f1l/uK5T5Ai1i3InKU= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= -github.com/ginuerzh/tls-dissector v0.0.2-0.20201202075250-98fa925912da h1:CPNdzkS5TMPghHVTYJp08SUdSneNVSwJSePAPGDuYgY= -github.com/ginuerzh/tls-dissector v0.0.2-0.20201202075250-98fa925912da/go.mod h1:YyzP8PQrGwDH/XsfHJXwqdHLwWvBYxu77YVKm0+68f0= github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= diff --git a/pkg/connector/http/connector.go b/pkg/connector/http/connector.go index 7b6492e..fdb1b3a 100644 --- a/pkg/connector/http/connector.go +++ b/pkg/connector/http/connector.go @@ -59,9 +59,6 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add ProtoMinor: 1, Header: make(http.Header), } - if c.md.UserAgent != "" { - req.Header.Set("User-Agent", c.md.UserAgent) - } req.Header.Set("Proxy-Connection", "keep-alive") if user := c.md.User; user != nil { @@ -71,6 +68,10 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add "Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p))) } + for k, v := range c.md.headers { + req.Header.Set(k, v) + } + switch network { case "tcp", "tcp4", "tcp6": if _, ok := conn.(net.PacketConn); ok { diff --git a/pkg/connector/http/metadata.go b/pkg/connector/http/metadata.go index 3b86cb2..e284275 100644 --- a/pkg/connector/http/metadata.go +++ b/pkg/connector/http/metadata.go @@ -1,6 +1,7 @@ package http import ( + "fmt" "net/url" "strings" "time" @@ -8,28 +9,20 @@ import ( md "github.com/go-gost/gost/pkg/metadata" ) -const ( - defaultUserAgent = "Chrome/78.0.3904.106" -) - type metadata struct { connectTimeout time.Duration - UserAgent string User *url.Userinfo + headers map[string]string } func (c *httpConnector) parseMetadata(md md.Metadata) (err error) { const ( connectTimeout = "timeout" - userAgent = "userAgent" user = "user" + headers = "headers" ) c.md.connectTimeout = md.GetDuration(connectTimeout) - c.md.UserAgent, _ = md.Get(userAgent).(string) - if c.md.UserAgent == "" { - c.md.UserAgent = defaultUserAgent - } if v := md.GetString(user); v != "" { ss := strings.SplitN(v, ":", 2) @@ -40,5 +33,13 @@ func (c *httpConnector) parseMetadata(md md.Metadata) (err error) { } } + if mm, _ := md.Get(headers).(map[interface{}]interface{}); len(mm) > 0 { + m := make(map[string]string) + for k, v := range mm { + m[fmt.Sprintf("%v", k)] = fmt.Sprintf("%v", v) + } + c.md.headers = m + } + return } diff --git a/pkg/dialer/obfs/http/conn.go b/pkg/dialer/obfs/http/conn.go new file mode 100644 index 0000000..0690f5b --- /dev/null +++ b/pkg/dialer/obfs/http/conn.go @@ -0,0 +1,144 @@ +package http + +import ( + "bufio" + "bytes" + "crypto/rand" + "encoding/base64" + "io" + "net" + "net/http" + "net/http/httputil" + "net/url" + "sync" + + "github.com/go-gost/gost/pkg/logger" +) + +type obfsHTTPConn struct { + net.Conn + host string + rbuf bytes.Buffer + wbuf bytes.Buffer + headerDrained bool + handshaked bool + handshakeMutex sync.Mutex + headers map[string]string + logger logger.Logger +} + +func (c *obfsHTTPConn) Handshake() (err error) { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + if c.handshaked { + return nil + } + + err = c.handshake() + if err != nil { + return + } + + c.handshaked = true + return nil +} + +func (c *obfsHTTPConn) handshake() (err error) { + r := &http.Request{ + Method: http.MethodGet, + ProtoMajor: 1, + ProtoMinor: 1, + URL: &url.URL{Scheme: "http", Host: c.host}, + Header: make(http.Header), + } + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + key, _ := c.generateChallengeKey() + r.Header.Set("Sec-WebSocket-Key", key) + for k, v := range c.headers { + r.Header.Set(k, v) + } + + // cache the request header + if err = r.Write(&c.wbuf); err != nil { + return + } + + if c.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpRequest(r, false) + c.logger.Debug(string(dump)) + } + + return nil +} + +func (c *obfsHTTPConn) Read(b []byte) (n int, err error) { + if err = c.Handshake(); err != nil { + return + } + + if err = c.drainHeader(); err != nil { + return + } + + if c.rbuf.Len() > 0 { + return c.rbuf.Read(b) + } + return c.Conn.Read(b) +} + +func (c *obfsHTTPConn) drainHeader() (err error) { + if c.headerDrained { + return + } + c.headerDrained = true + + br := bufio.NewReader(c.Conn) + // drain and discard the response header + var line string + var buf bytes.Buffer + for { + line, err = br.ReadString('\n') + if err != nil { + return + } + buf.WriteString(line) + if line == "\r\n" { + break + } + } + + if c.logger.IsLevelEnabled(logger.DebugLevel) { + c.logger.Debug(buf.String()) + } + + // cache the extra data for next read. + var b []byte + b, err = br.Peek(br.Buffered()) + if len(b) > 0 { + _, err = c.rbuf.Write(b) + } + return +} + +func (c *obfsHTTPConn) Write(b []byte) (n int, err error) { + if err = c.Handshake(); err != nil { + return + } + if c.wbuf.Len() > 0 { + c.wbuf.Write(b) // append the data to the cached header + _, err = c.wbuf.WriteTo(c.Conn) + n = len(b) // exclude the header length + return + } + return c.Conn.Write(b) +} + +func (c *obfsHTTPConn) generateChallengeKey() (string, error) { + p := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, p); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(p), nil +} diff --git a/pkg/dialer/obfs/http/dialer.go b/pkg/dialer/obfs/http/dialer.go new file mode 100644 index 0000000..ed0adf3 --- /dev/null +++ b/pkg/dialer/obfs/http/dialer.go @@ -0,0 +1,64 @@ +package http + +import ( + "context" + "net" + + "github.com/go-gost/gost/pkg/dialer" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" +) + +func init() { + registry.RegisterDialer("ohttp", NewDialer) +} + +type obfsHTTPDialer struct { + md metadata + logger logger.Logger +} + +func NewDialer(opts ...dialer.Option) dialer.Dialer { + options := &dialer.Options{} + for _, opt := range opts { + opt(options) + } + + return &obfsHTTPDialer{ + logger: options.Logger, + } +} + +func (d *obfsHTTPDialer) Init(md md.Metadata) (err error) { + return d.parseMetadata(md) +} + +func (d *obfsHTTPDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) { + var netd net.Dialer + conn, err := netd.DialContext(ctx, "tcp", addr) + if err != nil { + d.logger.Error(err) + } + return conn, err +} + +// Handshake implements dialer.Handshaker +func (d *obfsHTTPDialer) Handshake(ctx context.Context, conn net.Conn, options ...dialer.HandshakeOption) (net.Conn, error) { + opts := &dialer.HandshakeOptions{} + for _, option := range options { + option(opts) + } + + host := d.md.host + if host == "" { + host = opts.Addr + } + + return &obfsHTTPConn{ + Conn: conn, + host: host, + headers: d.md.headers, + logger: d.logger, + }, nil +} diff --git a/pkg/dialer/obfs/http/metadata.go b/pkg/dialer/obfs/http/metadata.go new file mode 100644 index 0000000..1237e65 --- /dev/null +++ b/pkg/dialer/obfs/http/metadata.go @@ -0,0 +1,29 @@ +package http + +import ( + "fmt" + + md "github.com/go-gost/gost/pkg/metadata" +) + +type metadata struct { + host string + headers map[string]string +} + +func (d *obfsHTTPDialer) parseMetadata(md md.Metadata) (err error) { + const ( + headers = "headers" + host = "host" + ) + + if mm, _ := md.Get(headers).(map[interface{}]interface{}); len(mm) > 0 { + m := make(map[string]string) + for k, v := range mm { + m[fmt.Sprintf("%v", k)] = fmt.Sprintf("%v", v) + } + d.md.headers = m + } + d.md.host = md.GetString(host) + return +} diff --git a/pkg/dialer/obfs/tls/conn.go b/pkg/dialer/obfs/tls/conn.go new file mode 100644 index 0000000..8425d35 --- /dev/null +++ b/pkg/dialer/obfs/tls/conn.go @@ -0,0 +1,266 @@ +package tls + +import ( + "bytes" + "crypto/rand" + "crypto/tls" + "errors" + "net" + "sync" + "time" + + dissector "github.com/go-gost/tls-dissector" +) + +const ( + maxTLSDataLen = 16384 +) + +var ( + cipherSuites = []uint16{ + 0xc02c, 0xc030, 0x009f, 0xcca9, 0xcca8, 0xccaa, 0xc02b, 0xc02f, + 0x009e, 0xc024, 0xc028, 0x006b, 0xc023, 0xc027, 0x0067, 0xc00a, + 0xc014, 0x0039, 0xc009, 0xc013, 0x0033, 0x009d, 0x009c, 0x003d, + 0x003c, 0x0035, 0x002f, 0x00ff, + } + + compressionMethods = []uint8{0x00} + + algorithms = []uint16{ + 0x0601, 0x0602, 0x0603, 0x0501, 0x0502, 0x0503, 0x0401, 0x0402, + 0x0403, 0x0301, 0x0302, 0x0303, 0x0201, 0x0202, 0x0203, + } + + tlsRecordTypes = []uint8{0x16, 0x14, 0x16, 0x17} + tlsVersionMinors = []uint8{0x01, 0x03, 0x03, 0x03} + + ErrBadType = errors.New("bad type") + ErrBadMajorVersion = errors.New("bad major version") + ErrBadMinorVersion = errors.New("bad minor version") + ErrMaxDataLen = errors.New("bad tls data len") +) + +const ( + tlsRecordStateType = iota + tlsRecordStateVersion0 + tlsRecordStateVersion1 + tlsRecordStateLength0 + tlsRecordStateLength1 + tlsRecordStateData +) + +type obfsTLSParser struct { + step uint8 + state uint8 + length uint16 +} + +func (r *obfsTLSParser) Parse(b []byte) (int, error) { + i := 0 + last := 0 + length := len(b) + + for i < length { + ch := b[i] + switch r.state { + case tlsRecordStateType: + if tlsRecordTypes[r.step] != ch { + return 0, ErrBadType + } + r.state = tlsRecordStateVersion0 + i++ + case tlsRecordStateVersion0: + if ch != 0x03 { + return 0, ErrBadMajorVersion + } + r.state = tlsRecordStateVersion1 + i++ + case tlsRecordStateVersion1: + if ch != tlsVersionMinors[r.step] { + return 0, ErrBadMinorVersion + } + r.state = tlsRecordStateLength0 + i++ + case tlsRecordStateLength0: + r.length = uint16(ch) << 8 + r.state = tlsRecordStateLength1 + i++ + case tlsRecordStateLength1: + r.length |= uint16(ch) + if r.step == 0 { + r.length = 91 + } else if r.step == 1 { + r.length = 1 + } else if r.length > maxTLSDataLen { + return 0, ErrMaxDataLen + } + if r.length > 0 { + r.state = tlsRecordStateData + } else { + r.state = tlsRecordStateType + r.step++ + } + i++ + case tlsRecordStateData: + left := uint16(length - i) + if left > r.length { + left = r.length + } + if r.step >= 2 { + skip := i - last + copy(b[last:], b[i:length]) + length -= int(skip) + last += int(left) + i = last + } else { + i += int(left) + } + r.length -= left + if r.length == 0 { + if r.step < 3 { + r.step++ + } + r.state = tlsRecordStateType + } + } + } + + if last == 0 { + return 0, nil + } else if last < length { + length -= last + } + + return length, nil +} + +type obfsTLSConn struct { + net.Conn + wbuf bytes.Buffer + host string + handshaked chan struct{} + parser obfsTLSParser + handshakeMutex sync.Mutex +} + +func (c *obfsTLSConn) Handshaked() bool { + select { + case <-c.handshaked: + return true + default: + return false + } +} + +func (c *obfsTLSConn) Handshake(payload []byte) (err error) { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + if c.Handshaked() { + return + } + + if err = c.handshake(payload); err != nil { + return + } + + close(c.handshaked) + return nil +} + +func (c *obfsTLSConn) handshake(payload []byte) error { + clientMsg := &dissector.ClientHelloMsg{ + Version: tls.VersionTLS12, + SessionID: make([]byte, 32), + CipherSuites: cipherSuites, + CompressionMethods: compressionMethods, + Extensions: []dissector.Extension{ + &dissector.SessionTicketExtension{ + Data: payload, + }, + &dissector.ServerNameExtension{ + Name: c.host, + }, + &dissector.ECPointFormatsExtension{ + Formats: []uint8{0x01, 0x00, 0x02}, + }, + &dissector.SupportedGroupsExtension{ + Groups: []uint16{0x001d, 0x0017, 0x0019, 0x0018}, + }, + &dissector.SignatureAlgorithmsExtension{ + Algorithms: algorithms, + }, + &dissector.EncryptThenMacExtension{}, + &dissector.ExtendedMasterSecretExtension{}, + }, + } + clientMsg.Random.Time = uint32(time.Now().Unix()) + rand.Read(clientMsg.Random.Opaque[:]) + rand.Read(clientMsg.SessionID) + b, err := clientMsg.Encode() + if err != nil { + return err + } + + record := &dissector.Record{ + Type: dissector.Handshake, + Version: tls.VersionTLS10, + Opaque: b, + } + if _, err := record.WriteTo(c.Conn); err != nil { + return err + } + return err +} + +func (c *obfsTLSConn) Read(b []byte) (n int, err error) { + <-c.handshaked + + n, err = c.Conn.Read(b) + if err != nil { + return + } + if n > 0 { + n, err = c.parser.Parse(b[:n]) + } + + return +} + +func (c *obfsTLSConn) Write(b []byte) (n int, err error) { + n = len(b) + + if !c.Handshaked() { + if err = c.Handshake(b); err != nil { + return + } + return + } + + for len(b) > 0 { + data := b + if len(b) > maxTLSDataLen { + data = b[:maxTLSDataLen] + b = b[maxTLSDataLen:] + } else { + b = b[:0] + } + record := &dissector.Record{ + Type: dissector.AppData, + Version: tls.VersionTLS12, + Opaque: data, + } + + if c.wbuf.Len() > 0 { + record.Type = dissector.Handshake + record.WriteTo(&c.wbuf) + _, err = c.wbuf.WriteTo(c.Conn) + return + } + + if _, err = record.WriteTo(c.Conn); err != nil { + return + } + } + return +} diff --git a/pkg/dialer/obfs/tls/dialer.go b/pkg/dialer/obfs/tls/dialer.go new file mode 100644 index 0000000..44665a5 --- /dev/null +++ b/pkg/dialer/obfs/tls/dialer.go @@ -0,0 +1,63 @@ +package tls + +import ( + "context" + "net" + + "github.com/go-gost/gost/pkg/dialer" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" +) + +func init() { + registry.RegisterDialer("otls", NewDialer) +} + +type obfsTLSDialer struct { + md metadata + logger logger.Logger +} + +func NewDialer(opts ...dialer.Option) dialer.Dialer { + options := &dialer.Options{} + for _, opt := range opts { + opt(options) + } + + return &obfsTLSDialer{ + logger: options.Logger, + } +} + +func (d *obfsTLSDialer) Init(md md.Metadata) (err error) { + return d.parseMetadata(md) +} + +func (d *obfsTLSDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) { + var netd net.Dialer + conn, err := netd.DialContext(ctx, "tcp", addr) + if err != nil { + d.logger.Error(err) + } + return conn, err +} + +// Handshake implements dialer.Handshaker +func (d *obfsTLSDialer) Handshake(ctx context.Context, conn net.Conn, options ...dialer.HandshakeOption) (net.Conn, error) { + opts := &dialer.HandshakeOptions{} + for _, option := range options { + option(opts) + } + + host := d.md.host + if host == "" { + host = opts.Addr + } + + return &obfsTLSConn{ + Conn: conn, + host: host, + handshaked: make(chan struct{}), + }, nil +} diff --git a/pkg/dialer/obfs/tls/metadata.go b/pkg/dialer/obfs/tls/metadata.go new file mode 100644 index 0000000..f387a20 --- /dev/null +++ b/pkg/dialer/obfs/tls/metadata.go @@ -0,0 +1,18 @@ +package tls + +import ( + md "github.com/go-gost/gost/pkg/metadata" +) + +type metadata struct { + host string +} + +func (d *obfsTLSDialer) parseMetadata(md md.Metadata) (err error) { + const ( + host = "host" + ) + + d.md.host = md.GetString(host) + return +} diff --git a/pkg/handler/http/handler.go b/pkg/handler/http/handler.go index 2565b61..81c33dc 100644 --- a/pkg/handler/http/handler.go +++ b/pkg/handler/http/handler.go @@ -136,8 +136,8 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt Header: http.Header{}, } - if h.md.proxyAgent != "" { - resp.Header.Add("Proxy-Agent", h.md.proxyAgent) + for k, v := range h.md.headers { + resp.Header.Set(k, v) } /* diff --git a/pkg/handler/http/metadata.go b/pkg/handler/http/metadata.go index 7fa7944..5f7a1fb 100644 --- a/pkg/handler/http/metadata.go +++ b/pkg/handler/http/metadata.go @@ -1,6 +1,7 @@ package http import ( + "fmt" "strings" "github.com/go-gost/gost/pkg/auth" @@ -8,17 +9,17 @@ import ( ) type metadata struct { - authenticator auth.Authenticator - proxyAgent string retryCount int + authenticator auth.Authenticator probeResist *probeResist sni bool enableUDP bool + headers map[string]string } func (h *httpHandler) parseMetadata(md md.Metadata) error { const ( - proxyAgent = "proxyAgent" + headers = "headers" users = "users" probeResistKey = "probeResist" knock = "knock" @@ -27,8 +28,6 @@ func (h *httpHandler) parseMetadata(md md.Metadata) error { enableUDP = "udp" ) - h.md.proxyAgent = md.GetString(proxyAgent) - if v, _ := md.Get(users).([]interface{}); len(v) > 0 { authenticator := auth.NewLocalAuthenticator(nil) for _, auth := range v { @@ -44,6 +43,14 @@ func (h *httpHandler) parseMetadata(md md.Metadata) error { h.md.authenticator = authenticator } + if mm, _ := md.Get(headers).(map[interface{}]interface{}); len(mm) > 0 { + m := make(map[string]string) + for k, v := range mm { + m[fmt.Sprintf("%v", k)] = fmt.Sprintf("%v", v) + } + h.md.headers = m + } + if v := md.GetString(probeResistKey); v != "" { if ss := strings.SplitN(v, ":", 2); len(ss) == 2 { h.md.probeResist = &probeResist{ diff --git a/pkg/handler/http/udp.go b/pkg/handler/http/udp.go index 3dbeb3c..0598cdd 100644 --- a/pkg/handler/http/udp.go +++ b/pkg/handler/http/udp.go @@ -23,8 +23,8 @@ func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, add ProtoMinor: 1, Header: http.Header{}, } - if h.md.proxyAgent != "" { - resp.Header.Add("Proxy-Agent", h.md.proxyAgent) + for k, v := range h.md.headers { + resp.Header.Set(k, v) } if !h.md.enableUDP { diff --git a/pkg/listener/obfs/http/conn.go b/pkg/listener/obfs/http/conn.go index f006a97..c89851e 100644 --- a/pkg/listener/obfs/http/conn.go +++ b/pkg/listener/obfs/http/conn.go @@ -6,23 +6,26 @@ import ( "crypto/sha1" "encoding/base64" "errors" - "fmt" "io" "net" "net/http" + "net/http/httputil" "sync" "time" + + "github.com/go-gost/gost/pkg/logger" ) -type conn struct { +type obfsHTTPConn struct { net.Conn rbuf bytes.Buffer wbuf bytes.Buffer handshaked bool handshakeMutex sync.Mutex + logger logger.Logger } -func (c *conn) Handshake() (err error) { +func (c *obfsHTTPConn) Handshake() (err error) { c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() @@ -38,18 +41,17 @@ func (c *conn) Handshake() (err error) { return nil } -func (c *conn) handshake() (err error) { +func (c *obfsHTTPConn) handshake() (err error) { br := bufio.NewReader(c.Conn) r, err := http.ReadRequest(br) if err != nil { return } - /* - if Debug { - dump, _ := httputil.DumpRequest(r, false) - log.Logf("[ohttp] %s -> %s\n%s", c.RemoteAddr(), c.LocalAddr(), string(dump)) - } - */ + + if c.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpRequest(r, false) + c.logger.Debug(string(dump)) + } if r.ContentLength > 0 { _, err = io.Copy(&c.rbuf, r.Body) @@ -61,52 +63,52 @@ func (c *conn) handshake() (err error) { } } if err != nil { - // log.Logf("[ohttp] %s -> %s : %v", c.Conn.RemoteAddr(), c.Conn.LocalAddr(), err) + c.logger.Error(err) return } - b := bytes.Buffer{} + resp := http.Response{ + StatusCode: http.StatusOK, + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + } + resp.Header.Set("Server", "nginx/1.18.0") + resp.Header.Set("Date", time.Now().Format(time.RFC1123)) if r.Method != http.MethodGet || r.Header.Get("Upgrade") != "websocket" { - b.WriteString("HTTP/1.1 503 Service Unavailable\r\n") - b.WriteString("Content-Length: 0\r\n") - b.WriteString("Date: " + time.Now().Format(time.RFC1123) + "\r\n") - b.WriteString("\r\n") + resp.StatusCode = http.StatusBadRequest - /* - if Debug { - log.Logf("[ohttp] %s <- %s\n%s", c.RemoteAddr(), c.LocalAddr(), b.String()) - } - */ + if c.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(&resp, false) + c.logger.Debug(string(dump)) + } - b.WriteTo(c.Conn) + resp.Write(c.Conn) return errors.New("bad request") } - b.WriteString("HTTP/1.1 101 Switching Protocols\r\n") - b.WriteString("Server: nginx/1.10.0\r\n") - b.WriteString("Date: " + time.Now().Format(time.RFC1123) + "\r\n") - b.WriteString("Connection: Upgrade\r\n") - b.WriteString("Upgrade: websocket\r\n") - b.WriteString(fmt.Sprintf("Sec-WebSocket-Accept: %s\r\n", computeAcceptKey(r.Header.Get("Sec-WebSocket-Key")))) - b.WriteString("\r\n") + resp.StatusCode = http.StatusSwitchingProtocols + resp.Header.Set("Connection", "Upgrade") + resp.Header.Set("Upgrade", "websocket") + resp.Header.Set("Sec-WebSocket-Accept", c.computeAcceptKey(r.Header.Get("Sec-WebSocket-Key"))) - /* - if Debug { - log.Logf("[ohttp] %s <- %s\n%s", c.RemoteAddr(), c.LocalAddr(), b.String()) - } - */ + if c.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(&resp, false) + c.logger.Debug(string(dump)) + } if c.rbuf.Len() > 0 { - c.wbuf = b // cache the response header if there are extra data in the request body. + // cache the response header if there are extra data in the request body. + resp.Write(&c.wbuf) return } - _, err = b.WriteTo(c.Conn) + err = resp.Write(c.Conn) return } -func (c *conn) Read(b []byte) (n int, err error) { +func (c *obfsHTTPConn) Read(b []byte) (n int, err error) { if err = c.Handshake(); err != nil { return } @@ -117,7 +119,7 @@ func (c *conn) Read(b []byte) (n int, err error) { return c.Conn.Read(b) } -func (c *conn) Write(b []byte) (n int, err error) { +func (c *obfsHTTPConn) Write(b []byte) (n int, err error) { if err = c.Handshake(); err != nil { return } @@ -132,7 +134,7 @@ func (c *conn) Write(b []byte) (n int, err error) { var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") -func computeAcceptKey(challengeKey string) string { +func (c *obfsHTTPConn) computeAcceptKey(challengeKey string) string { h := sha1.New() h.Write([]byte(challengeKey)) h.Write(keyGUID) diff --git a/pkg/listener/obfs/http/listener.go b/pkg/listener/obfs/http/listener.go index f80054f..e45778f 100644 --- a/pkg/listener/obfs/http/listener.go +++ b/pkg/listener/obfs/http/listener.go @@ -3,7 +3,6 @@ package http import ( "net" - "github.com/go-gost/gost/pkg/common/util" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" @@ -46,14 +45,6 @@ func (l *obfsListener) Init(md md.Metadata) (err error) { return } - if l.md.keepAlive { - l.Listener = &util.TCPKeepAliveListener{ - TCPListener: ln, - KeepAlivePeriod: l.md.keepAlivePeriod, - } - return - } - l.Listener = ln return } @@ -64,12 +55,8 @@ func (l *obfsListener) Accept() (net.Conn, error) { return nil, err } - return &conn{Conn: c}, nil -} - -func (l *obfsListener) parseMetadata(md md.Metadata) (err error) { - l.md.keepAlive = md.GetBool(keepAlive) - l.md.keepAlivePeriod = md.GetDuration(keepAlivePeriod) - - return + return &obfsHTTPConn{ + Conn: c, + logger: l.logger, + }, nil } diff --git a/pkg/listener/obfs/http/metadata.go b/pkg/listener/obfs/http/metadata.go index 76079f6..70ec3c7 100644 --- a/pkg/listener/obfs/http/metadata.go +++ b/pkg/listener/obfs/http/metadata.go @@ -1,17 +1,17 @@ package http -import "time" +import ( + md "github.com/go-gost/gost/pkg/metadata" +) const ( keepAlive = "keepAlive" keepAlivePeriod = "keepAlivePeriod" ) -const ( - defaultKeepAlivePeriod = 180 * time.Second -) - type metadata struct { - keepAlive bool - keepAlivePeriod time.Duration +} + +func (l *obfsListener) parseMetadata(md md.Metadata) (err error) { + return } diff --git a/pkg/listener/obfs/tls/conn.go b/pkg/listener/obfs/tls/conn.go index ae6c529..70786bc 100644 --- a/pkg/listener/obfs/tls/conn.go +++ b/pkg/listener/obfs/tls/conn.go @@ -4,169 +4,30 @@ import ( "bytes" "crypto/rand" "crypto/tls" - "errors" "net" "sync" "time" - dissector "github.com/ginuerzh/tls-dissector" + dissector "github.com/go-gost/tls-dissector" ) const ( maxTLSDataLen = 16384 ) -var ( - cipherSuites = []uint16{ - 0xc02c, 0xc030, 0x009f, 0xcca9, 0xcca8, 0xccaa, 0xc02b, 0xc02f, - 0x009e, 0xc024, 0xc028, 0x006b, 0xc023, 0xc027, 0x0067, 0xc00a, - 0xc014, 0x0039, 0xc009, 0xc013, 0x0033, 0x009d, 0x009c, 0x003d, - 0x003c, 0x0035, 0x002f, 0x00ff, - } - - compressionMethods = []uint8{0x00} - - algorithms = []uint16{ - 0x0601, 0x0602, 0x0603, 0x0501, 0x0502, 0x0503, 0x0401, 0x0402, - 0x0403, 0x0301, 0x0302, 0x0303, 0x0201, 0x0202, 0x0203, - } - - tlsRecordTypes = []uint8{0x16, 0x14, 0x16, 0x17} - tlsVersionMinors = []uint8{0x01, 0x03, 0x03, 0x03} - - ErrBadType = errors.New("bad type") - ErrBadMajorVersion = errors.New("bad major version") - ErrBadMinorVersion = errors.New("bad minor version") - ErrMaxDataLen = errors.New("bad tls data len") -) - -const ( - tlsRecordStateType = iota - tlsRecordStateVersion0 - tlsRecordStateVersion1 - tlsRecordStateLength0 - tlsRecordStateLength1 - tlsRecordStateData -) - -type obfsTLSParser struct { - step uint8 - state uint8 - length uint16 -} - -func (r *obfsTLSParser) Parse(b []byte) (int, error) { - i := 0 - last := 0 - length := len(b) - - for i < length { - ch := b[i] - switch r.state { - case tlsRecordStateType: - if tlsRecordTypes[r.step] != ch { - return 0, ErrBadType - } - r.state = tlsRecordStateVersion0 - i++ - case tlsRecordStateVersion0: - if ch != 0x03 { - return 0, ErrBadMajorVersion - } - r.state = tlsRecordStateVersion1 - i++ - case tlsRecordStateVersion1: - if ch != tlsVersionMinors[r.step] { - return 0, ErrBadMinorVersion - } - r.state = tlsRecordStateLength0 - i++ - case tlsRecordStateLength0: - r.length = uint16(ch) << 8 - r.state = tlsRecordStateLength1 - i++ - case tlsRecordStateLength1: - r.length |= uint16(ch) - if r.step == 0 { - r.length = 91 - } else if r.step == 1 { - r.length = 1 - } else if r.length > maxTLSDataLen { - return 0, ErrMaxDataLen - } - if r.length > 0 { - r.state = tlsRecordStateData - } else { - r.state = tlsRecordStateType - r.step++ - } - i++ - case tlsRecordStateData: - left := uint16(length - i) - if left > r.length { - left = r.length - } - if r.step >= 2 { - skip := i - last - copy(b[last:], b[i:length]) - length -= int(skip) - last += int(left) - i = last - } else { - i += int(left) - } - r.length -= left - if r.length == 0 { - if r.step < 3 { - r.step++ - } - r.state = tlsRecordStateType - } - } - } - - if last == 0 { - return 0, nil - } else if last < length { - length -= last - } - - return length, nil -} - -type conn struct { +type obfsTLSConn struct { net.Conn rbuf bytes.Buffer wbuf bytes.Buffer - host string - handshaked chan struct{} - parser *obfsTLSParser + handshaked bool handshakeMutex sync.Mutex } -// newConn creates a connection for obfs-tls server. -func newConn(c net.Conn, host string) net.Conn { - return &conn{ - Conn: c, - host: host, - handshaked: make(chan struct{}), - } -} - -func (c *conn) Handshaked() bool { - select { - case <-c.handshaked: - return true - default: - return false - } -} - -func (c *conn) Handshake(payload []byte) (err error) { +func (c *obfsTLSConn) Handshake() (err error) { c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() - if c.Handshaked() { + if c.handshaked { return } @@ -174,11 +35,11 @@ func (c *conn) Handshake(payload []byte) (err error) { return } - close(c.handshaked) + c.handshaked = true return nil } -func (c *conn) handshake() error { +func (c *obfsTLSConn) handshake() error { record := &dissector.Record{} if _, err := record.ReadFrom(c.Conn); err != nil { // log.Log(err) @@ -248,15 +109,11 @@ func (c *conn) handshake() error { return nil } -func (c *conn) Read(b []byte) (n int, err error) { - if err = c.Handshake(nil); err != nil { +func (c *obfsTLSConn) Read(b []byte) (n int, err error) { + if err = c.Handshake(); err != nil { return } - select { - case <-c.handshaked: - } - if c.rbuf.Len() > 0 { return c.rbuf.Read(b) } @@ -269,13 +126,11 @@ func (c *conn) Read(b []byte) (n int, err error) { return } -func (c *conn) Write(b []byte) (n int, err error) { - n = len(b) - if !c.Handshaked() { - if err = c.Handshake(b); err != nil { - return - } +func (c *obfsTLSConn) Write(b []byte) (n int, err error) { + if err = c.Handshake(); err != nil { + return } + n = len(b) for len(b) > 0 { data := b diff --git a/pkg/listener/obfs/tls/listener.go b/pkg/listener/obfs/tls/listener.go index 29d7308..cb4056a 100644 --- a/pkg/listener/obfs/tls/listener.go +++ b/pkg/listener/obfs/tls/listener.go @@ -3,7 +3,6 @@ package tls import ( "net" - "github.com/go-gost/gost/pkg/common/util" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" @@ -46,14 +45,6 @@ func (l *obfsListener) Init(md md.Metadata) (err error) { return } - if l.md.keepAlive { - l.Listener = &util.TCPKeepAliveListener{ - TCPListener: ln, - KeepAlivePeriod: l.md.keepAlivePeriod, - } - return - } - l.Listener = ln return } @@ -64,12 +55,7 @@ func (l *obfsListener) Accept() (net.Conn, error) { return nil, err } - return &conn{Conn: c}, nil -} - -func (l *obfsListener) parseMetadata(md md.Metadata) (err error) { - l.md.keepAlive = md.GetBool(keepAlive) - l.md.keepAlivePeriod = md.GetDuration(keepAlivePeriod) - - return + return &obfsTLSConn{ + Conn: c, + }, nil } diff --git a/pkg/listener/obfs/tls/metadata.go b/pkg/listener/obfs/tls/metadata.go index 74d5b44..046933c 100644 --- a/pkg/listener/obfs/tls/metadata.go +++ b/pkg/listener/obfs/tls/metadata.go @@ -1,17 +1,12 @@ package tls -import "time" - -const ( - keepAlive = "keepAlive" - keepAlivePeriod = "keepAlivePeriod" -) - -const ( - defaultKeepAlivePeriod = 180 * time.Second +import ( + md "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { - keepAlive bool - keepAlivePeriod time.Duration +} + +func (l *obfsListener) parseMetadata(md md.Metadata) (err error) { + return }