diff --git a/cmd/gost/register.go b/cmd/gost/register.go index 08d5b76..07cc7b8 100644 --- a/cmd/gost/register.go +++ b/cmd/gost/register.go @@ -19,6 +19,7 @@ import ( _ "github.com/go-gost/gost/pkg/dialer/kcp" _ "github.com/go-gost/gost/pkg/dialer/obfs/http" _ "github.com/go-gost/gost/pkg/dialer/obfs/tls" + _ "github.com/go-gost/gost/pkg/dialer/quic" _ "github.com/go-gost/gost/pkg/dialer/tcp" _ "github.com/go-gost/gost/pkg/dialer/udp" diff --git a/pkg/common/util/tls/tls.go b/pkg/common/util/tls/tls.go index c1e6c14..8fa5e3e 100644 --- a/pkg/common/util/tls/tls.go +++ b/pkg/common/util/tls/tls.go @@ -17,7 +17,7 @@ var ( // LoadServerConfig loads the certificate from cert & key files and optional client CA file. func LoadServerConfig(certFile, keyFile, caFile string) (*tls.Config, error) { if certFile == "" && keyFile == "" { - return DefaultConfig, nil + return DefaultConfig.Clone(), nil } cert, err := tls.LoadX509KeyPair(certFile, keyFile) diff --git a/pkg/connector/http2/connector.go b/pkg/connector/http2/connector.go index a72461e..162c6d4 100644 --- a/pkg/connector/http2/connector.go +++ b/pkg/connector/http2/connector.go @@ -13,7 +13,7 @@ import ( "time" "github.com/go-gost/gost/pkg/connector" - http2_util "github.com/go-gost/gost/pkg/internal/http2" + http2_util "github.com/go-gost/gost/pkg/internal/util/http2" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" diff --git a/pkg/dialer/http2/dialer.go b/pkg/dialer/http2/dialer.go index b7398e0..52a2021 100644 --- a/pkg/dialer/http2/dialer.go +++ b/pkg/dialer/http2/dialer.go @@ -8,7 +8,7 @@ import ( "time" "github.com/go-gost/gost/pkg/dialer" - http2_util "github.com/go-gost/gost/pkg/internal/http2" + http2_util "github.com/go-gost/gost/pkg/internal/util/http2" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" diff --git a/pkg/dialer/kcp/dialer.go b/pkg/dialer/kcp/dialer.go index b0b2270..1fcd160 100644 --- a/pkg/dialer/kcp/dialer.go +++ b/pkg/dialer/kcp/dialer.go @@ -114,6 +114,11 @@ func (d *kcpDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia } session, ok := d.sessions[opts.Addr] + if session != nil && session.conn != conn { + conn.Close() + return nil, errors.New("kcp: unrecognized connection") + } + if !ok || session.session == nil { s, err := d.initSession(opts.Addr, conn, config) if err != nil { @@ -138,7 +143,7 @@ func (d *kcpDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia func (d *kcpDialer) initSession(addr string, conn net.Conn, config *kcp_util.Config) (*muxSession, error) { pc, ok := conn.(net.PacketConn) if !ok { - return nil, errors.New("wrong connection type") + return nil, errors.New("kcp: wrong connection type") } kcpconn, err := kcp.NewConn(addr, diff --git a/pkg/dialer/quic/conn.go b/pkg/dialer/quic/conn.go new file mode 100644 index 0000000..e1af828 --- /dev/null +++ b/pkg/dialer/quic/conn.go @@ -0,0 +1,43 @@ +package quic + +import ( + "context" + "net" + + "github.com/lucas-clemente/quic-go" +) + +type quicSession struct { + conn net.Conn + session quic.Session +} + +func (session *quicSession) GetConn() (*quicConn, error) { + stream, err := session.session.OpenStreamSync(context.Background()) + if err != nil { + return nil, err + } + return &quicConn{ + Stream: stream, + laddr: session.session.LocalAddr(), + raddr: session.session.RemoteAddr(), + }, nil +} + +func (session *quicSession) Close() error { + return session.session.CloseWithError(quic.ApplicationErrorCode(0), "closed") +} + +type quicConn struct { + quic.Stream + laddr net.Addr + raddr net.Addr +} + +func (c *quicConn) LocalAddr() net.Addr { + return c.laddr +} + +func (c *quicConn) RemoteAddr() net.Addr { + return c.raddr +} diff --git a/pkg/dialer/quic/dialer.go b/pkg/dialer/quic/dialer.go new file mode 100644 index 0000000..0d4d761 --- /dev/null +++ b/pkg/dialer/quic/dialer.go @@ -0,0 +1,153 @@ +package quic + +import ( + "context" + "errors" + "net" + "sync" + "time" + + "github.com/go-gost/gost/pkg/dialer" + quic_util "github.com/go-gost/gost/pkg/internal/util/quic" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" + "github.com/lucas-clemente/quic-go" +) + +func init() { + registry.RegisterDialer("quic", NewDialer) +} + +type quicDialer struct { + sessions map[string]*quicSession + sessionMutex sync.Mutex + logger logger.Logger + md metadata +} + +func NewDialer(opts ...dialer.Option) dialer.Dialer { + options := &dialer.Options{} + for _, opt := range opts { + opt(options) + } + + return &quicDialer{ + sessions: make(map[string]*quicSession), + logger: options.Logger, + } +} + +func (d *quicDialer) Init(md md.Metadata) (err error) { + if err = d.parseMetadata(md); err != nil { + return + } + + return nil +} + +// IsMultiplex implements dialer.Multiplexer interface. +func (d *quicDialer) IsMultiplex() bool { + return true +} + +func (d *quicDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (conn net.Conn, err error) { + var options dialer.DialOptions + for _, opt := range opts { + opt(&options) + } + + d.sessionMutex.Lock() + defer d.sessionMutex.Unlock() + + session, ok := d.sessions[addr] + if !ok { + var cc *net.UDPConn + cc, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + return + } + conn = cc + + if d.md.cipherKey != nil { + conn = quic_util.CipherConn(cc, d.md.cipherKey) + } + + session = &quicSession{conn: conn} + d.sessions[addr] = session + } + + return session.conn, err +} + +// Handshake implements dialer.Handshaker +func (d *quicDialer) Handshake(ctx context.Context, conn net.Conn, options ...dialer.HandshakeOption) (net.Conn, error) { + opts := &dialer.HandshakeOptions{} + for _, option := range options { + option(opts) + } + + d.sessionMutex.Lock() + defer d.sessionMutex.Unlock() + + if d.md.handshakeTimeout > 0 { + conn.SetDeadline(time.Now().Add(d.md.handshakeTimeout)) + defer conn.SetDeadline(time.Time{}) + } + + session, ok := d.sessions[opts.Addr] + if session != nil && session.conn != conn { + conn.Close() + return nil, errors.New("quic: unrecognized connection") + } + if !ok || session.session == nil { + s, err := d.initSession(ctx, opts.Addr, conn) + if err != nil { + d.logger.Error(err) + conn.Close() + delete(d.sessions, opts.Addr) + return nil, err + } + session = s + d.sessions[opts.Addr] = session + } + cc, err := session.GetConn() + if err != nil { + session.Close() + delete(d.sessions, opts.Addr) + return nil, err + } + + return cc, nil +} + +func (d *quicDialer) initSession(ctx context.Context, addr string, conn net.Conn) (*quicSession, error) { + pc, ok := conn.(net.PacketConn) + if !ok { + return nil, errors.New("quic: wrong connection type") + } + + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + quicConfig := &quic.Config{ + KeepAlive: d.md.keepAlive, + HandshakeIdleTimeout: d.md.handshakeTimeout, + MaxIdleTimeout: d.md.maxIdleTimeout, + Versions: []quic.VersionNumber{ + quic.Version1, + quic.VersionDraft29, + }, + } + + tlsCfg := d.md.tlsConfig + tlsCfg.NextProtos = []string{"http/3", "quic/v1"} + + session, err := quic.DialContext(ctx, pc, udpAddr, addr, tlsCfg, quicConfig) + if err != nil { + d.logger.Error(err) + return nil, err + } + return &quicSession{conn: conn, session: session}, nil +} diff --git a/pkg/dialer/quic/metadata.go b/pkg/dialer/quic/metadata.go new file mode 100644 index 0000000..18f1bfa --- /dev/null +++ b/pkg/dialer/quic/metadata.go @@ -0,0 +1,58 @@ +package quic + +import ( + "crypto/tls" + "net" + "time" + + tls_util "github.com/go-gost/gost/pkg/common/util/tls" + md "github.com/go-gost/gost/pkg/metadata" +) + +type metadata struct { + keepAlive bool + maxIdleTimeout time.Duration + handshakeTimeout time.Duration + + cipherKey []byte + tlsConfig *tls.Config +} + +func (d *quicDialer) parseMetadata(md md.Metadata) (err error) { + const ( + keepAlive = "keepAlive" + handshakeTimeout = "handshakeTimeout" + maxIdleTimeout = "maxIdleTimeout" + + certFile = "certFile" + keyFile = "keyFile" + caFile = "caFile" + secure = "secure" + serverName = "serverName" + + cipherKey = "cipherKey" + ) + + d.md.handshakeTimeout = md.GetDuration(handshakeTimeout) + + if key := md.GetString(cipherKey); key != "" { + d.md.cipherKey = []byte(key) + } + + sn, _, _ := net.SplitHostPort(md.GetString(serverName)) + if sn == "" { + sn = "localhost" + } + d.md.tlsConfig, err = tls_util.LoadClientConfig( + md.GetString(certFile), + md.GetString(keyFile), + md.GetString(caFile), + md.GetBool(secure), + sn, + ) + + d.md.keepAlive = md.GetBool(keepAlive) + d.md.handshakeTimeout = md.GetDuration(handshakeTimeout) + d.md.maxIdleTimeout = md.GetDuration(maxIdleTimeout) + return +} diff --git a/pkg/handler/http2/handler.go b/pkg/handler/http2/handler.go index 27bdfc5..1a4deac 100644 --- a/pkg/handler/http2/handler.go +++ b/pkg/handler/http2/handler.go @@ -18,7 +18,7 @@ import ( "github.com/go-gost/gost/pkg/bypass" "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/handler" - http2_util "github.com/go-gost/gost/pkg/internal/http2" + http2_util "github.com/go-gost/gost/pkg/internal/util/http2" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" diff --git a/pkg/internal/http2/conn.go b/pkg/internal/util/http2/conn.go similarity index 100% rename from pkg/internal/http2/conn.go rename to pkg/internal/util/http2/conn.go diff --git a/pkg/common/util/quic/quic.go b/pkg/internal/util/quic/conn.go similarity index 57% rename from pkg/common/util/quic/quic.go rename to pkg/internal/util/quic/conn.go index 12d8a45..7d49a5c 100644 --- a/pkg/common/util/quic/quic.go +++ b/pkg/internal/util/quic/conn.go @@ -7,36 +7,29 @@ import ( "errors" "io" "net" - - "github.com/lucas-clemente/quic-go" ) -type quicConn struct { - quic.Session - quic.Stream -} - -func QUICConn(session quic.Session, stream quic.Stream) net.Conn { - return &quicConn{ - Session: session, - Stream: stream, - } -} - -type quicCipherConn struct { - net.PacketConn +type cipherConn struct { + *net.UDPConn key []byte } -func QUICCipherConn(conn net.PacketConn, key []byte) net.PacketConn { - return &quicCipherConn{ - PacketConn: conn, - key: key, +func CipherConn(conn *net.UDPConn, key []byte) net.Conn { + return &cipherConn{ + UDPConn: conn, + key: key, } } -func (conn *quicCipherConn) ReadFrom(data []byte) (n int, addr net.Addr, err error) { - n, addr, err = conn.PacketConn.ReadFrom(data) +func CipherPacketConn(conn *net.UDPConn, key []byte) net.PacketConn { + return &cipherConn{ + UDPConn: conn, + key: key, + } +} + +func (conn *cipherConn) ReadFrom(data []byte) (n int, addr net.Addr, err error) { + n, addr, err = conn.UDPConn.ReadFrom(data) if err != nil { return } @@ -50,13 +43,13 @@ func (conn *quicCipherConn) ReadFrom(data []byte) (n int, addr net.Addr, err err return len(b), addr, nil } -func (conn *quicCipherConn) WriteTo(data []byte, addr net.Addr) (n int, err error) { +func (conn *cipherConn) WriteTo(data []byte, addr net.Addr) (n int, err error) { b, err := conn.encrypt(data) if err != nil { return } - _, err = conn.PacketConn.WriteTo(b, addr) + _, err = conn.UDPConn.WriteTo(b, addr) if err != nil { return } @@ -64,7 +57,7 @@ func (conn *quicCipherConn) WriteTo(data []byte, addr net.Addr) (n int, err erro return len(b), nil } -func (conn *quicCipherConn) encrypt(data []byte) ([]byte, error) { +func (conn *cipherConn) encrypt(data []byte) ([]byte, error) { c, err := aes.NewCipher(conn.key) if err != nil { return nil, err @@ -83,7 +76,7 @@ func (conn *quicCipherConn) encrypt(data []byte) ([]byte, error) { return gcm.Seal(nonce, nonce, data, nil), nil } -func (conn *quicCipherConn) decrypt(data []byte) ([]byte, error) { +func (conn *cipherConn) decrypt(data []byte) ([]byte, error) { c, err := aes.NewCipher(conn.key) if err != nil { return nil, err diff --git a/pkg/listener/http2/listener.go b/pkg/listener/http2/listener.go index 0804f76..f5f81ac 100644 --- a/pkg/listener/http2/listener.go +++ b/pkg/listener/http2/listener.go @@ -6,7 +6,7 @@ import ( "net/http" "github.com/go-gost/gost/pkg/common/util" - http2_util "github.com/go-gost/gost/pkg/internal/http2" + http2_util "github.com/go-gost/gost/pkg/internal/util/http2" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" diff --git a/pkg/listener/quic/conn.go b/pkg/listener/quic/conn.go new file mode 100644 index 0000000..ee1a26c --- /dev/null +++ b/pkg/listener/quic/conn.go @@ -0,0 +1,21 @@ +package quic + +import ( + "net" + + "github.com/lucas-clemente/quic-go" +) + +type quicConn struct { + quic.Stream + laddr net.Addr + raddr net.Addr +} + +func (c *quicConn) LocalAddr() net.Addr { + return c.laddr +} + +func (c *quicConn) RemoteAddr() net.Addr { + return c.raddr +} diff --git a/pkg/listener/quic/listener.go b/pkg/listener/quic/listener.go index 84cbe81..741c9a5 100644 --- a/pkg/listener/quic/listener.go +++ b/pkg/listener/quic/listener.go @@ -4,7 +4,7 @@ import ( "context" "net" - quic_util "github.com/go-gost/gost/pkg/common/util/quic" + quic_util "github.com/go-gost/gost/pkg/internal/util/quic" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" @@ -17,12 +17,12 @@ func init() { } type quicListener struct { - addr string - md metadata - ln quic.Listener - connChan chan net.Conn - errChan chan error - logger logger.Logger + addr string + ln quic.Listener + cqueue chan net.Conn + errChan chan error + logger logger.Logger + md metadata } func NewListener(opts ...listener.Option) listener.Listener { @@ -46,29 +46,37 @@ func (l *quicListener) Init(md md.Metadata) (err error) { return } - var conn net.PacketConn - conn, err = net.ListenUDP("udp", laddr) + uc, err := net.ListenUDP("udp", laddr) if err != nil { return } + var conn net.PacketConn = uc + if l.md.cipherKey != nil { - conn = quic_util.QUICCipherConn(conn, l.md.cipherKey) + conn = quic_util.CipherPacketConn(uc, l.md.cipherKey) } config := &quic.Config{ KeepAlive: l.md.keepAlive, - HandshakeIdleTimeout: l.md.HandshakeTimeout, - MaxIdleTimeout: l.md.MaxIdleTimeout, + HandshakeIdleTimeout: l.md.handshakeTimeout, + MaxIdleTimeout: l.md.maxIdleTimeout, + Versions: []quic.VersionNumber{ + quic.Version1, + quic.VersionDraft29, + }, } - ln, err := quic.Listen(conn, l.md.tlsConfig, config) + tlsCfg := l.md.tlsConfig + tlsCfg.NextProtos = []string{"http/3", "quic/v1"} + + ln, err := quic.Listen(conn, tlsCfg, config) if err != nil { return } l.ln = ln - l.connChan = make(chan net.Conn, l.md.connQueueSize) + l.cqueue = make(chan net.Conn, l.md.backlog) l.errChan = make(chan error, 1) go l.listenLoop() @@ -79,7 +87,7 @@ func (l *quicListener) Init(md md.Metadata) (err error) { func (l *quicListener) Accept() (conn net.Conn, err error) { var ok bool select { - case conn = <-l.connChan: + case conn = <-l.cqueue: case err, ok = <-l.errChan: if !ok { err = listener.ErrClosed @@ -111,7 +119,7 @@ func (l *quicListener) listenLoop() { } func (l *quicListener) mux(ctx context.Context, session quic.Session) { - defer session.CloseWithError(0, "") + defer session.CloseWithError(0, "closed") for { stream, err := session.AcceptStream(ctx) @@ -120,19 +128,18 @@ func (l *quicListener) mux(ctx context.Context, session quic.Session) { return } - conn := quic_util.QUICConn(session, stream) + conn := &quicConn{ + Stream: stream, + laddr: session.LocalAddr(), + raddr: session.RemoteAddr(), + } select { - case l.connChan <- conn: + case l.cqueue <- conn: case <-stream.Context().Done(): stream.Close() default: stream.Close() - l.logger.Error("connection queue is full") + l.logger.Warnf("connection queue is full, client %s discarded", session.RemoteAddr()) } } } - -func (l *quicListener) parseMetadata(md md.Metadata) (err error) { - - return -} diff --git a/pkg/listener/quic/metadata.go b/pkg/listener/quic/metadata.go index 3b84e80..79fb4ff 100644 --- a/pkg/listener/quic/metadata.go +++ b/pkg/listener/quic/metadata.go @@ -3,27 +3,60 @@ package quic import ( "crypto/tls" "time" + + tls_util "github.com/go-gost/gost/pkg/common/util/tls" + md "github.com/go-gost/gost/pkg/metadata" ) const ( - certFile = "certFile" - keyFile = "keyFile" - caFile = "caFile" - - keepAlive = "keepAlive" - keepAlivePeriod = "keepAlivePeriod" -) - -const ( - defaultKeepAlivePeriod = 180 * time.Second + defaultBacklog = 128 ) type metadata struct { - tlsConfig *tls.Config keepAlive bool - HandshakeTimeout time.Duration - MaxIdleTimeout time.Duration + handshakeTimeout time.Duration + maxIdleTimeout time.Duration - cipherKey []byte - connQueueSize int + tlsConfig *tls.Config + cipherKey []byte + backlog int +} + +func (l *quicListener) parseMetadata(md md.Metadata) (err error) { + const ( + keepAlive = "keepAlive" + handshakeTimeout = "handshakeTimeout" + maxIdleTimeout = "maxIdleTimeout" + + certFile = "certFile" + keyFile = "keyFile" + caFile = "caFile" + + backlog = "backlog" + cipherKey = "cipherKey" + ) + + l.md.tlsConfig, err = tls_util.LoadServerConfig( + md.GetString(certFile), + md.GetString(keyFile), + md.GetString(caFile), + ) + + if err != nil { + return + } + l.md.backlog = md.GetInt(backlog) + if l.md.backlog <= 0 { + l.md.backlog = defaultBacklog + } + + if key := md.GetString(cipherKey); key != "" { + l.md.cipherKey = []byte(key) + } + + l.md.keepAlive = md.GetBool(keepAlive) + l.md.handshakeTimeout = md.GetDuration(handshakeTimeout) + l.md.maxIdleTimeout = md.GetDuration(maxIdleTimeout) + + return }