diff --git a/client/client.go b/client/client.go index 445548c..868296d 100644 --- a/client/client.go +++ b/client/client.go @@ -2,10 +2,10 @@ package client import ( "github.com/go-gost/gost/client/connector" - "github.com/go-gost/gost/client/transporter" + "github.com/go-gost/gost/client/dialer" ) type Client struct { - Connector connector.Connector - Transporter transporter.Transporter + Connector connector.Connector + Dialer dialer.Dialer } diff --git a/client/dialer/dialer.go b/client/dialer/dialer.go new file mode 100644 index 0000000..86f373f --- /dev/null +++ b/client/dialer/dialer.go @@ -0,0 +1,20 @@ +package dialer + +import ( + "context" + "net" +) + +// Dialer dials to target server. +type Dialer interface { + Init(md Metadata) error + Dial(ctx context.Context, addr string) (net.Conn, error) +} + +type Handshaker interface { + Handshake(ctx context.Context, conn net.Conn) (net.Conn, error) +} + +type Multiplexer interface { + Multiplexed() bool +} diff --git a/client/dialer/metadata.go b/client/dialer/metadata.go new file mode 100644 index 0000000..755bb8a --- /dev/null +++ b/client/dialer/metadata.go @@ -0,0 +1,3 @@ +package dialer + +type Metadata map[string]string diff --git a/client/dialer/option.go b/client/dialer/option.go new file mode 100644 index 0000000..f059dc9 --- /dev/null +++ b/client/dialer/option.go @@ -0,0 +1,17 @@ +package dialer + +import ( + "github.com/go-gost/gost/logger" +) + +type Options struct { + Logger logger.Logger +} + +type Option func(opts *Options) + +func LoggerOption(logger logger.Logger) Option { + return func(opts *Options) { + opts.Logger = logger + } +} diff --git a/client/dialer/tcp/dialer.go b/client/dialer/tcp/dialer.go new file mode 100644 index 0000000..fd4bc12 --- /dev/null +++ b/client/dialer/tcp/dialer.go @@ -0,0 +1,41 @@ +package tcp + +import ( + "context" + "net" + + "github.com/go-gost/gost/client/dialer" + "github.com/go-gost/gost/logger" +) + +type Dialer struct { + md metadata + logger logger.Logger +} + +func NewDialer(opts ...dialer.Option) *Dialer { + options := &dialer.Options{} + for _, opt := range opts { + opt(options) + } + + return &Dialer{ + logger: options.Logger, + } +} + +func (d *Dialer) Init(md dialer.Metadata) (err error) { + d.md, err = d.parseMetadata(md) + if err != nil { + return + } + return nil +} + +func (d *Dialer) Dial(ctx context.Context, addr string) (net.Conn, error) { + return nil, nil +} + +func (d *Dialer) parseMetadata(md dialer.Metadata) (m metadata, err error) { + return +} diff --git a/client/dialer/tcp/metadata.go b/client/dialer/tcp/metadata.go new file mode 100644 index 0000000..9afea73 --- /dev/null +++ b/client/dialer/tcp/metadata.go @@ -0,0 +1,15 @@ +package tcp + +import "time" + +const ( + dialTimeout = "dialTimeout" +) + +const ( + defaultDialTimeout = 5 * time.Second +) + +type metadata struct { + dialTimeout time.Duration +} diff --git a/client/transporter/transporter.go b/client/transporter/transporter.go deleted file mode 100644 index 1d3f8e0..0000000 --- a/client/transporter/transporter.go +++ /dev/null @@ -1,14 +0,0 @@ -package transporter - -import ( - "context" - "net" -) - -// Transporter is responsible for handshaking with server. -type Transporter interface { - Dial(ctx context.Context, addr string) (net.Conn, error) - Handshake(ctx context.Context, conn net.Conn) (net.Conn, error) - // Indicate that the Transporter supports multiplex - Multiplex() bool -} diff --git a/go.mod b/go.mod index 2a94684..ca2524e 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.16 require ( github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect github.com/go-gost/gosocks5 v0.3.0 + github.com/gorilla/websocket v1.4.2 github.com/shadowsocks/go-shadowsocks2 v0.1.4 github.com/shadowsocks/shadowsocks-go v0.0.0-20200409064450-3e585ff90601 github.com/sirupsen/logrus v1.8.1 diff --git a/go.sum b/go.sum index 9638f00..dfe6401 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-gost/gosocks5 v0.3.0 h1:Hkmp9YDRBSCJd7xywW6dBPT6B9aQTkuWd+3WCheJiJA= github.com/go-gost/gosocks5 v0.3.0/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 h1:f/FNXud6gA3MNr8meMVVGxhp+QBTqY91tM8HjEuMjGg= diff --git a/server/listener/option.go b/server/listener/option.go new file mode 100644 index 0000000..362b964 --- /dev/null +++ b/server/listener/option.go @@ -0,0 +1,17 @@ +package listener + +import ( + "github.com/go-gost/gost/logger" +) + +type Options struct { + Logger logger.Logger +} + +type Option func(opts *Options) + +func LoggerOption(logger logger.Logger) Option { + return func(opts *Options) { + opts.Logger = logger + } +} diff --git a/server/listener/tcp/tcp.go b/server/listener/tcp/listener.go similarity index 62% rename from server/listener/tcp/tcp.go rename to server/listener/tcp/listener.go index 61c8feb..be2344d 100644 --- a/server/listener/tcp/tcp.go +++ b/server/listener/tcp/listener.go @@ -6,16 +6,29 @@ import ( "strconv" "time" + "github.com/go-gost/gost/logger" "github.com/go-gost/gost/server/listener" + "github.com/go-gost/gost/utils" +) + +var ( + _ listener.Listener = (*Listener)(nil) ) type Listener struct { md metadata net.Listener + logger logger.Logger } -func NewTCPListener() *Listener { - return &Listener{} +func NewListener(opts ...listener.Option) *Listener { + options := &listener.Options{} + for _, opt := range opts { + opt(options) + } + return &Listener{ + logger: options.Logger, + } } func (l *Listener) Init(md listener.Metadata) (err error) { @@ -34,9 +47,9 @@ func (l *Listener) Init(md listener.Metadata) (err error) { } if l.md.keepAlive { - l.Listener = &keepAliveListener{ + l.Listener = &utils.TCPKeepAliveListener{ TCPListener: ln, - keepAlivePeriod: l.md.keepAlivePeriod, + KeepAlivePeriod: l.md.keepAlivePeriod, } return } @@ -49,7 +62,7 @@ func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) { if val, ok := md[addr]; ok { m.addr = val } else { - err = errors.New("tcp listener: missing address") + err = errors.New("missing address") return } @@ -61,26 +74,6 @@ func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) { if val, ok := md[keepAlivePeriod]; ok { m.keepAlivePeriod, _ = time.ParseDuration(val) } - if m.keepAlivePeriod <= 0 { - m.keepAlivePeriod = defaultKeepAlivePeriod - } return } - -type keepAliveListener struct { - keepAlivePeriod time.Duration - *net.TCPListener -} - -func (l *keepAliveListener) Accept() (c net.Conn, err error) { - tc, err := l.AcceptTCP() - if err != nil { - return - } - - tc.SetKeepAlive(true) - tc.SetKeepAlivePeriod(l.keepAlivePeriod) - - return tc, nil -} diff --git a/server/listener/tls/listener.go b/server/listener/tls/listener.go new file mode 100644 index 0000000..510209c --- /dev/null +++ b/server/listener/tls/listener.go @@ -0,0 +1,75 @@ +package tls + +import ( + "crypto/tls" + "errors" + "net" + "time" + + "github.com/go-gost/gost/logger" + "github.com/go-gost/gost/server/listener" + "github.com/go-gost/gost/utils" +) + +var ( + _ listener.Listener = (*Listener)(nil) +) + +type Listener struct { + md metadata + net.Listener + logger logger.Logger +} + +func NewListener(opts ...listener.Option) *Listener { + options := &listener.Options{} + for _, opt := range opts { + opt(options) + } + return &Listener{ + logger: options.Logger, + } +} + +func (l *Listener) Init(md listener.Metadata) (err error) { + l.md, err = l.parseMetadata(md) + if err != nil { + return + } + + ln, err := net.Listen("tcp", l.md.addr) + if err != nil { + return + } + + ln = tls.NewListener( + &utils.TCPKeepAliveListener{ + TCPListener: ln.(*net.TCPListener), + KeepAlivePeriod: l.md.keepAlivePeriod, + }, + l.md.tlsConfig, + ) + + l.Listener = ln + return +} + +func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) { + if val, ok := md[addr]; ok { + m.addr = val + } else { + err = errors.New("missing address") + return + } + + m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile]) + if err != nil { + return + } + + if val, ok := md[keepAlivePeriod]; ok { + m.keepAlivePeriod, _ = time.ParseDuration(val) + } + + return +} diff --git a/server/listener/tls/metadata.go b/server/listener/tls/metadata.go new file mode 100644 index 0000000..6b9d92d --- /dev/null +++ b/server/listener/tls/metadata.go @@ -0,0 +1,20 @@ +package tls + +import ( + "crypto/tls" + "time" +) + +const ( + addr = "addr" + certFile = "certFile" + keyFile = "keyFile" + caFile = "caFile" + keepAlivePeriod = "keepAlivePeriod" +) + +type metadata struct { + addr string + tlsConfig *tls.Config + keepAlivePeriod time.Duration +} diff --git a/server/listener/ws/listener.go b/server/listener/ws/listener.go new file mode 100644 index 0000000..e7d5ce8 --- /dev/null +++ b/server/listener/ws/listener.go @@ -0,0 +1,173 @@ +package tcp + +import ( + "crypto/tls" + "errors" + "net" + "net/http" + "time" + + "github.com/go-gost/gost/logger" + "github.com/go-gost/gost/server/listener" + "github.com/go-gost/gost/utils" + "github.com/gorilla/websocket" +) + +var ( + _ listener.Listener = (*Listener)(nil) +) + +type Listener struct { + md metadata + addr net.Addr + upgrader *websocket.Upgrader + srv *http.Server + connChan chan net.Conn + errChan chan error + logger logger.Logger +} + +func NewListener(opts ...listener.Option) *Listener { + options := &listener.Options{} + for _, opt := range opts { + opt(options) + } + return &Listener{ + logger: options.Logger, + } +} + +func (l *Listener) Init(md listener.Metadata) (err error) { + l.md, err = l.parseMetadata(md) + if err != nil { + return + } + + l.upgrader = &websocket.Upgrader{ + HandshakeTimeout: l.md.handshakeTimeout, + ReadBufferSize: l.md.readBufferSize, + WriteBufferSize: l.md.writeBufferSize, + CheckOrigin: func(r *http.Request) bool { return true }, + EnableCompression: l.md.enableCompression, + } + + path := l.md.path + if path == "" { + path = defaultPath + } + mux := http.NewServeMux() + mux.Handle(path, http.HandlerFunc(l.upgrade)) + l.srv = &http.Server{ + Addr: l.md.addr, + TLSConfig: l.md.tlsConfig, + Handler: mux, + ReadHeaderTimeout: l.md.readHeaderTimeout, + } + + queueSize := l.md.connQueueSize + if queueSize <= 0 { + queueSize = defaultQueueSize + } + l.connChan = make(chan net.Conn, queueSize) + l.errChan = make(chan error, 1) + + ln, err := net.Listen("tcp", l.md.addr) + if err != nil { + return + } + if l.md.tlsConfig != nil { + ln = tls.NewListener(ln, l.md.tlsConfig) + } + + l.addr = ln.Addr() + + go func() { + err := l.srv.Serve(ln) + if err != nil { + l.errChan <- err + } + close(l.errChan) + }() + + select { + case err = <-l.errChan: + return + case <-time.After(100 * time.Millisecond): + } + + return +} + +func (l *Listener) Accept() (conn net.Conn, err error) { + select { + case conn = <-l.connChan: + case err = <-l.errChan: + } + return +} + +func (l *Listener) Close() error { + return l.srv.Close() +} + +func (l *Listener) Addr() net.Addr { + return l.addr +} + +func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) { + if val, ok := md[addr]; ok { + m.addr = val + } else { + err = errors.New("missing address") + return + } + + m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile]) + if err != nil { + return + } + + return +} + +func (l *Listener) upgrade(w http.ResponseWriter, r *http.Request) { + conn, err := l.upgrader.Upgrade(w, r, l.md.responseHeader) + if err != nil { + l.logger.Error(err) + return + } + + select { + case l.connChan <- &websocketConn{Conn: conn}: + default: + conn.Close() + l.logger.Warn("connection queue is full") + } +} + +type websocketConn struct { + *websocket.Conn + rb []byte +} + +func (c *websocketConn) Read(b []byte) (n int, err error) { + if len(c.rb) == 0 { + _, c.rb, err = c.ReadMessage() + } + n = copy(b, c.rb) + c.rb = c.rb[n:] + return +} + +func (c *websocketConn) Write(b []byte) (n int, err error) { + err = c.WriteMessage(websocket.BinaryMessage, b) + n = len(b) + return +} + +func (c *websocketConn) SetDeadline(t time.Time) error { + if err := c.SetReadDeadline(t); err != nil { + return err + } + return c.SetWriteDeadline(t) +} diff --git a/server/listener/ws/metadata.go b/server/listener/ws/metadata.go new file mode 100644 index 0000000..257364a --- /dev/null +++ b/server/listener/ws/metadata.go @@ -0,0 +1,40 @@ +package tcp + +import ( + "crypto/tls" + "net/http" + "time" +) + +const ( + addr = "addr" + path = "path" + certFile = "certFile" + keyFile = "keyFile" + caFile = "caFile" + handshakeTimeout = "handshakeTimeout" + readHeaderTimeout = "readHeaderTimeout" + readBufferSize = "readBufferSize" + writeBufferSize = "writeBufferSize" + enableCompression = "enableCompression" + responseHeader = "responseHeader" + connQueueSize = "connQueueSize" +) + +const ( + defaultPath = "/ws" + defaultQueueSize = 128 +) + +type metadata struct { + addr string + path string + tlsConfig *tls.Config + handshakeTimeout time.Duration + readHeaderTimeout time.Duration + readBufferSize int + writeBufferSize int + enableCompression bool + responseHeader http.Header + connQueueSize int +} diff --git a/utils/tcp.go b/utils/tcp.go new file mode 100644 index 0000000..140d228 --- /dev/null +++ b/utils/tcp.go @@ -0,0 +1,32 @@ +package utils + +import ( + "net" + "time" +) + +const ( + defaultKeepAlivePeriod = 180 * time.Second +) + +// TCPKeepAliveListener is a TCP listener with keep alive enabled. +type TCPKeepAliveListener struct { + KeepAlivePeriod time.Duration + *net.TCPListener +} + +func (l *TCPKeepAliveListener) Accept() (c net.Conn, err error) { + tc, err := l.AcceptTCP() + if err != nil { + return + } + + tc.SetKeepAlive(true) + period := l.KeepAlivePeriod + if period <= 0 { + period = defaultKeepAlivePeriod + } + tc.SetKeepAlivePeriod(period) + + return tc, nil +} diff --git a/utils/tls.go b/utils/tls.go new file mode 100644 index 0000000..eeea1f5 --- /dev/null +++ b/utils/tls.go @@ -0,0 +1,40 @@ +package utils + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "io/ioutil" +) + +// LoadTLSConfig loads the certificate from cert & key files and optional client CA file. +func LoadTLSConfig(certFile, keyFile, caFile string) (*tls.Config, error) { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + + cfg := &tls.Config{Certificates: []tls.Certificate{cert}} + + if pool, _ := loadCA(caFile); pool != nil { + cfg.ClientCAs = pool + cfg.ClientAuth = tls.RequireAndVerifyClientCert + } + + return cfg, nil +} + +func loadCA(caFile string) (cp *x509.CertPool, err error) { + if caFile == "" { + return + } + cp = x509.NewCertPool() + data, err := ioutil.ReadFile(caFile) + if err != nil { + return nil, err + } + if !cp.AppendCertsFromPEM(data) { + return nil, errors.New("AppendCertsFromPEM failed") + } + return +} diff --git a/utils/ws.go b/utils/ws.go new file mode 100644 index 0000000..d4b585b --- /dev/null +++ b/utils/ws.go @@ -0,0 +1 @@ +package utils