add tls and ws listener

This commit is contained in:
ginuerzh 2021-03-31 23:24:51 +08:00
parent b74e4cc8a4
commit 4e04e7ed86
18 changed files with 518 additions and 42 deletions

View File

@ -2,10 +2,10 @@ package client
import ( import (
"github.com/go-gost/gost/client/connector" "github.com/go-gost/gost/client/connector"
"github.com/go-gost/gost/client/transporter" "github.com/go-gost/gost/client/dialer"
) )
type Client struct { type Client struct {
Connector connector.Connector Connector connector.Connector
Transporter transporter.Transporter Dialer dialer.Dialer
} }

20
client/dialer/dialer.go Normal file
View File

@ -0,0 +1,20 @@
package dialer
import (
"context"
"net"
)
// Dialer dials to target server.
type Dialer interface {
Init(md Metadata) error
Dial(ctx context.Context, addr string) (net.Conn, error)
}
type Handshaker interface {
Handshake(ctx context.Context, conn net.Conn) (net.Conn, error)
}
type Multiplexer interface {
Multiplexed() bool
}

View File

@ -0,0 +1,3 @@
package dialer
type Metadata map[string]string

17
client/dialer/option.go Normal file
View File

@ -0,0 +1,17 @@
package dialer
import (
"github.com/go-gost/gost/logger"
)
type Options struct {
Logger logger.Logger
}
type Option func(opts *Options)
func LoggerOption(logger logger.Logger) Option {
return func(opts *Options) {
opts.Logger = logger
}
}

View File

@ -0,0 +1,41 @@
package tcp
import (
"context"
"net"
"github.com/go-gost/gost/client/dialer"
"github.com/go-gost/gost/logger"
)
type Dialer struct {
md metadata
logger logger.Logger
}
func NewDialer(opts ...dialer.Option) *Dialer {
options := &dialer.Options{}
for _, opt := range opts {
opt(options)
}
return &Dialer{
logger: options.Logger,
}
}
func (d *Dialer) Init(md dialer.Metadata) (err error) {
d.md, err = d.parseMetadata(md)
if err != nil {
return
}
return nil
}
func (d *Dialer) Dial(ctx context.Context, addr string) (net.Conn, error) {
return nil, nil
}
func (d *Dialer) parseMetadata(md dialer.Metadata) (m metadata, err error) {
return
}

View File

@ -0,0 +1,15 @@
package tcp
import "time"
const (
dialTimeout = "dialTimeout"
)
const (
defaultDialTimeout = 5 * time.Second
)
type metadata struct {
dialTimeout time.Duration
}

View File

@ -1,14 +0,0 @@
package transporter
import (
"context"
"net"
)
// Transporter is responsible for handshaking with server.
type Transporter interface {
Dial(ctx context.Context, addr string) (net.Conn, error)
Handshake(ctx context.Context, conn net.Conn) (net.Conn, error)
// Indicate that the Transporter supports multiplex
Multiplex() bool
}

1
go.mod
View File

@ -5,6 +5,7 @@ go 1.16
require ( require (
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect
github.com/go-gost/gosocks5 v0.3.0 github.com/go-gost/gosocks5 v0.3.0
github.com/gorilla/websocket v1.4.2
github.com/shadowsocks/go-shadowsocks2 v0.1.4 github.com/shadowsocks/go-shadowsocks2 v0.1.4
github.com/shadowsocks/shadowsocks-go v0.0.0-20200409064450-3e585ff90601 github.com/shadowsocks/shadowsocks-go v0.0.0-20200409064450-3e585ff90601
github.com/sirupsen/logrus v1.8.1 github.com/sirupsen/logrus v1.8.1

2
go.sum
View File

@ -4,6 +4,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-gost/gosocks5 v0.3.0 h1:Hkmp9YDRBSCJd7xywW6dBPT6B9aQTkuWd+3WCheJiJA= github.com/go-gost/gosocks5 v0.3.0 h1:Hkmp9YDRBSCJd7xywW6dBPT6B9aQTkuWd+3WCheJiJA=
github.com/go-gost/gosocks5 v0.3.0/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= github.com/go-gost/gosocks5 v0.3.0/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 h1:f/FNXud6gA3MNr8meMVVGxhp+QBTqY91tM8HjEuMjGg= github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 h1:f/FNXud6gA3MNr8meMVVGxhp+QBTqY91tM8HjEuMjGg=

17
server/listener/option.go Normal file
View File

@ -0,0 +1,17 @@
package listener
import (
"github.com/go-gost/gost/logger"
)
type Options struct {
Logger logger.Logger
}
type Option func(opts *Options)
func LoggerOption(logger logger.Logger) Option {
return func(opts *Options) {
opts.Logger = logger
}
}

View File

@ -6,16 +6,29 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/go-gost/gost/logger"
"github.com/go-gost/gost/server/listener" "github.com/go-gost/gost/server/listener"
"github.com/go-gost/gost/utils"
)
var (
_ listener.Listener = (*Listener)(nil)
) )
type Listener struct { type Listener struct {
md metadata md metadata
net.Listener net.Listener
logger logger.Logger
} }
func NewTCPListener() *Listener { func NewListener(opts ...listener.Option) *Listener {
return &Listener{} options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
} }
func (l *Listener) Init(md listener.Metadata) (err error) { func (l *Listener) Init(md listener.Metadata) (err error) {
@ -34,9 +47,9 @@ func (l *Listener) Init(md listener.Metadata) (err error) {
} }
if l.md.keepAlive { if l.md.keepAlive {
l.Listener = &keepAliveListener{ l.Listener = &utils.TCPKeepAliveListener{
TCPListener: ln, TCPListener: ln,
keepAlivePeriod: l.md.keepAlivePeriod, KeepAlivePeriod: l.md.keepAlivePeriod,
} }
return return
} }
@ -49,7 +62,7 @@ func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok { if val, ok := md[addr]; ok {
m.addr = val m.addr = val
} else { } else {
err = errors.New("tcp listener: missing address") err = errors.New("missing address")
return return
} }
@ -61,26 +74,6 @@ func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[keepAlivePeriod]; ok { if val, ok := md[keepAlivePeriod]; ok {
m.keepAlivePeriod, _ = time.ParseDuration(val) m.keepAlivePeriod, _ = time.ParseDuration(val)
} }
if m.keepAlivePeriod <= 0 {
m.keepAlivePeriod = defaultKeepAlivePeriod
}
return return
} }
type keepAliveListener struct {
keepAlivePeriod time.Duration
*net.TCPListener
}
func (l *keepAliveListener) Accept() (c net.Conn, err error) {
tc, err := l.AcceptTCP()
if err != nil {
return
}
tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(l.keepAlivePeriod)
return tc, nil
}

View File

@ -0,0 +1,75 @@
package tls
import (
"crypto/tls"
"errors"
"net"
"time"
"github.com/go-gost/gost/logger"
"github.com/go-gost/gost/server/listener"
"github.com/go-gost/gost/utils"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
net.Listener
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
ln, err := net.Listen("tcp", l.md.addr)
if err != nil {
return
}
ln = tls.NewListener(
&utils.TCPKeepAliveListener{
TCPListener: ln.(*net.TCPListener),
KeepAlivePeriod: l.md.keepAlivePeriod,
},
l.md.tlsConfig,
)
l.Listener = ln
return
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
if err != nil {
return
}
if val, ok := md[keepAlivePeriod]; ok {
m.keepAlivePeriod, _ = time.ParseDuration(val)
}
return
}

View File

@ -0,0 +1,20 @@
package tls
import (
"crypto/tls"
"time"
)
const (
addr = "addr"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
keepAlivePeriod = "keepAlivePeriod"
)
type metadata struct {
addr string
tlsConfig *tls.Config
keepAlivePeriod time.Duration
}

View File

@ -0,0 +1,173 @@
package tcp
import (
"crypto/tls"
"errors"
"net"
"net/http"
"time"
"github.com/go-gost/gost/logger"
"github.com/go-gost/gost/server/listener"
"github.com/go-gost/gost/utils"
"github.com/gorilla/websocket"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
addr net.Addr
upgrader *websocket.Upgrader
srv *http.Server
connChan chan net.Conn
errChan chan error
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
l.upgrader = &websocket.Upgrader{
HandshakeTimeout: l.md.handshakeTimeout,
ReadBufferSize: l.md.readBufferSize,
WriteBufferSize: l.md.writeBufferSize,
CheckOrigin: func(r *http.Request) bool { return true },
EnableCompression: l.md.enableCompression,
}
path := l.md.path
if path == "" {
path = defaultPath
}
mux := http.NewServeMux()
mux.Handle(path, http.HandlerFunc(l.upgrade))
l.srv = &http.Server{
Addr: l.md.addr,
TLSConfig: l.md.tlsConfig,
Handler: mux,
ReadHeaderTimeout: l.md.readHeaderTimeout,
}
queueSize := l.md.connQueueSize
if queueSize <= 0 {
queueSize = defaultQueueSize
}
l.connChan = make(chan net.Conn, queueSize)
l.errChan = make(chan error, 1)
ln, err := net.Listen("tcp", l.md.addr)
if err != nil {
return
}
if l.md.tlsConfig != nil {
ln = tls.NewListener(ln, l.md.tlsConfig)
}
l.addr = ln.Addr()
go func() {
err := l.srv.Serve(ln)
if err != nil {
l.errChan <- err
}
close(l.errChan)
}()
select {
case err = <-l.errChan:
return
case <-time.After(100 * time.Millisecond):
}
return
}
func (l *Listener) Accept() (conn net.Conn, err error) {
select {
case conn = <-l.connChan:
case err = <-l.errChan:
}
return
}
func (l *Listener) Close() error {
return l.srv.Close()
}
func (l *Listener) Addr() net.Addr {
return l.addr
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
if err != nil {
return
}
return
}
func (l *Listener) upgrade(w http.ResponseWriter, r *http.Request) {
conn, err := l.upgrader.Upgrade(w, r, l.md.responseHeader)
if err != nil {
l.logger.Error(err)
return
}
select {
case l.connChan <- &websocketConn{Conn: conn}:
default:
conn.Close()
l.logger.Warn("connection queue is full")
}
}
type websocketConn struct {
*websocket.Conn
rb []byte
}
func (c *websocketConn) Read(b []byte) (n int, err error) {
if len(c.rb) == 0 {
_, c.rb, err = c.ReadMessage()
}
n = copy(b, c.rb)
c.rb = c.rb[n:]
return
}
func (c *websocketConn) Write(b []byte) (n int, err error) {
err = c.WriteMessage(websocket.BinaryMessage, b)
n = len(b)
return
}
func (c *websocketConn) SetDeadline(t time.Time) error {
if err := c.SetReadDeadline(t); err != nil {
return err
}
return c.SetWriteDeadline(t)
}

View File

@ -0,0 +1,40 @@
package tcp
import (
"crypto/tls"
"net/http"
"time"
)
const (
addr = "addr"
path = "path"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
enableCompression = "enableCompression"
responseHeader = "responseHeader"
connQueueSize = "connQueueSize"
)
const (
defaultPath = "/ws"
defaultQueueSize = 128
)
type metadata struct {
addr string
path string
tlsConfig *tls.Config
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
readBufferSize int
writeBufferSize int
enableCompression bool
responseHeader http.Header
connQueueSize int
}

32
utils/tcp.go Normal file
View File

@ -0,0 +1,32 @@
package utils
import (
"net"
"time"
)
const (
defaultKeepAlivePeriod = 180 * time.Second
)
// TCPKeepAliveListener is a TCP listener with keep alive enabled.
type TCPKeepAliveListener struct {
KeepAlivePeriod time.Duration
*net.TCPListener
}
func (l *TCPKeepAliveListener) Accept() (c net.Conn, err error) {
tc, err := l.AcceptTCP()
if err != nil {
return
}
tc.SetKeepAlive(true)
period := l.KeepAlivePeriod
if period <= 0 {
period = defaultKeepAlivePeriod
}
tc.SetKeepAlivePeriod(period)
return tc, nil
}

40
utils/tls.go Normal file
View File

@ -0,0 +1,40 @@
package utils
import (
"crypto/tls"
"crypto/x509"
"errors"
"io/ioutil"
)
// LoadTLSConfig loads the certificate from cert & key files and optional client CA file.
func LoadTLSConfig(certFile, keyFile, caFile string) (*tls.Config, error) {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
cfg := &tls.Config{Certificates: []tls.Certificate{cert}}
if pool, _ := loadCA(caFile); pool != nil {
cfg.ClientCAs = pool
cfg.ClientAuth = tls.RequireAndVerifyClientCert
}
return cfg, nil
}
func loadCA(caFile string) (cp *x509.CertPool, err error) {
if caFile == "" {
return
}
cp = x509.NewCertPool()
data, err := ioutil.ReadFile(caFile)
if err != nil {
return nil, err
}
if !cp.AppendCertsFromPEM(data) {
return nil, errors.New("AppendCertsFromPEM failed")
}
return
}

1
utils/ws.go Normal file
View File

@ -0,0 +1 @@
package utils