add quic dialer

This commit is contained in:
ginuerzh 2021-12-17 00:23:08 +08:00
parent c7f5da6ac7
commit 965c6846dd
15 changed files with 385 additions and 71 deletions

View File

@ -19,6 +19,7 @@ import (
_ "github.com/go-gost/gost/pkg/dialer/kcp" _ "github.com/go-gost/gost/pkg/dialer/kcp"
_ "github.com/go-gost/gost/pkg/dialer/obfs/http" _ "github.com/go-gost/gost/pkg/dialer/obfs/http"
_ "github.com/go-gost/gost/pkg/dialer/obfs/tls" _ "github.com/go-gost/gost/pkg/dialer/obfs/tls"
_ "github.com/go-gost/gost/pkg/dialer/quic"
_ "github.com/go-gost/gost/pkg/dialer/tcp" _ "github.com/go-gost/gost/pkg/dialer/tcp"
_ "github.com/go-gost/gost/pkg/dialer/udp" _ "github.com/go-gost/gost/pkg/dialer/udp"

View File

@ -17,7 +17,7 @@ var (
// LoadServerConfig loads the certificate from cert & key files and optional client CA file. // LoadServerConfig loads the certificate from cert & key files and optional client CA file.
func LoadServerConfig(certFile, keyFile, caFile string) (*tls.Config, error) { func LoadServerConfig(certFile, keyFile, caFile string) (*tls.Config, error) {
if certFile == "" && keyFile == "" { if certFile == "" && keyFile == "" {
return DefaultConfig, nil return DefaultConfig.Clone(), nil
} }
cert, err := tls.LoadX509KeyPair(certFile, keyFile) cert, err := tls.LoadX509KeyPair(certFile, keyFile)

View File

@ -13,7 +13,7 @@ import (
"time" "time"
"github.com/go-gost/gost/pkg/connector" "github.com/go-gost/gost/pkg/connector"
http2_util "github.com/go-gost/gost/pkg/internal/http2" http2_util "github.com/go-gost/gost/pkg/internal/util/http2"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata" md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry" "github.com/go-gost/gost/pkg/registry"

View File

@ -8,7 +8,7 @@ import (
"time" "time"
"github.com/go-gost/gost/pkg/dialer" "github.com/go-gost/gost/pkg/dialer"
http2_util "github.com/go-gost/gost/pkg/internal/http2" http2_util "github.com/go-gost/gost/pkg/internal/util/http2"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata" md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry" "github.com/go-gost/gost/pkg/registry"

View File

@ -114,6 +114,11 @@ func (d *kcpDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia
} }
session, ok := d.sessions[opts.Addr] session, ok := d.sessions[opts.Addr]
if session != nil && session.conn != conn {
conn.Close()
return nil, errors.New("kcp: unrecognized connection")
}
if !ok || session.session == nil { if !ok || session.session == nil {
s, err := d.initSession(opts.Addr, conn, config) s, err := d.initSession(opts.Addr, conn, config)
if err != nil { if err != nil {
@ -138,7 +143,7 @@ func (d *kcpDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia
func (d *kcpDialer) initSession(addr string, conn net.Conn, config *kcp_util.Config) (*muxSession, error) { func (d *kcpDialer) initSession(addr string, conn net.Conn, config *kcp_util.Config) (*muxSession, error) {
pc, ok := conn.(net.PacketConn) pc, ok := conn.(net.PacketConn)
if !ok { if !ok {
return nil, errors.New("wrong connection type") return nil, errors.New("kcp: wrong connection type")
} }
kcpconn, err := kcp.NewConn(addr, kcpconn, err := kcp.NewConn(addr,

43
pkg/dialer/quic/conn.go Normal file
View File

@ -0,0 +1,43 @@
package quic
import (
"context"
"net"
"github.com/lucas-clemente/quic-go"
)
type quicSession struct {
conn net.Conn
session quic.Session
}
func (session *quicSession) GetConn() (*quicConn, error) {
stream, err := session.session.OpenStreamSync(context.Background())
if err != nil {
return nil, err
}
return &quicConn{
Stream: stream,
laddr: session.session.LocalAddr(),
raddr: session.session.RemoteAddr(),
}, nil
}
func (session *quicSession) Close() error {
return session.session.CloseWithError(quic.ApplicationErrorCode(0), "closed")
}
type quicConn struct {
quic.Stream
laddr net.Addr
raddr net.Addr
}
func (c *quicConn) LocalAddr() net.Addr {
return c.laddr
}
func (c *quicConn) RemoteAddr() net.Addr {
return c.raddr
}

153
pkg/dialer/quic/dialer.go Normal file
View File

@ -0,0 +1,153 @@
package quic
import (
"context"
"errors"
"net"
"sync"
"time"
"github.com/go-gost/gost/pkg/dialer"
quic_util "github.com/go-gost/gost/pkg/internal/util/quic"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
"github.com/lucas-clemente/quic-go"
)
func init() {
registry.RegisterDialer("quic", NewDialer)
}
type quicDialer struct {
sessions map[string]*quicSession
sessionMutex sync.Mutex
logger logger.Logger
md metadata
}
func NewDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
for _, opt := range opts {
opt(options)
}
return &quicDialer{
sessions: make(map[string]*quicSession),
logger: options.Logger,
}
}
func (d *quicDialer) Init(md md.Metadata) (err error) {
if err = d.parseMetadata(md); err != nil {
return
}
return nil
}
// IsMultiplex implements dialer.Multiplexer interface.
func (d *quicDialer) IsMultiplex() bool {
return true
}
func (d *quicDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (conn net.Conn, err error) {
var options dialer.DialOptions
for _, opt := range opts {
opt(&options)
}
d.sessionMutex.Lock()
defer d.sessionMutex.Unlock()
session, ok := d.sessions[addr]
if !ok {
var cc *net.UDPConn
cc, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return
}
conn = cc
if d.md.cipherKey != nil {
conn = quic_util.CipherConn(cc, d.md.cipherKey)
}
session = &quicSession{conn: conn}
d.sessions[addr] = session
}
return session.conn, err
}
// Handshake implements dialer.Handshaker
func (d *quicDialer) Handshake(ctx context.Context, conn net.Conn, options ...dialer.HandshakeOption) (net.Conn, error) {
opts := &dialer.HandshakeOptions{}
for _, option := range options {
option(opts)
}
d.sessionMutex.Lock()
defer d.sessionMutex.Unlock()
if d.md.handshakeTimeout > 0 {
conn.SetDeadline(time.Now().Add(d.md.handshakeTimeout))
defer conn.SetDeadline(time.Time{})
}
session, ok := d.sessions[opts.Addr]
if session != nil && session.conn != conn {
conn.Close()
return nil, errors.New("quic: unrecognized connection")
}
if !ok || session.session == nil {
s, err := d.initSession(ctx, opts.Addr, conn)
if err != nil {
d.logger.Error(err)
conn.Close()
delete(d.sessions, opts.Addr)
return nil, err
}
session = s
d.sessions[opts.Addr] = session
}
cc, err := session.GetConn()
if err != nil {
session.Close()
delete(d.sessions, opts.Addr)
return nil, err
}
return cc, nil
}
func (d *quicDialer) initSession(ctx context.Context, addr string, conn net.Conn) (*quicSession, error) {
pc, ok := conn.(net.PacketConn)
if !ok {
return nil, errors.New("quic: wrong connection type")
}
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
quicConfig := &quic.Config{
KeepAlive: d.md.keepAlive,
HandshakeIdleTimeout: d.md.handshakeTimeout,
MaxIdleTimeout: d.md.maxIdleTimeout,
Versions: []quic.VersionNumber{
quic.Version1,
quic.VersionDraft29,
},
}
tlsCfg := d.md.tlsConfig
tlsCfg.NextProtos = []string{"http/3", "quic/v1"}
session, err := quic.DialContext(ctx, pc, udpAddr, addr, tlsCfg, quicConfig)
if err != nil {
d.logger.Error(err)
return nil, err
}
return &quicSession{conn: conn, session: session}, nil
}

View File

@ -0,0 +1,58 @@
package quic
import (
"crypto/tls"
"net"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
md "github.com/go-gost/gost/pkg/metadata"
)
type metadata struct {
keepAlive bool
maxIdleTimeout time.Duration
handshakeTimeout time.Duration
cipherKey []byte
tlsConfig *tls.Config
}
func (d *quicDialer) parseMetadata(md md.Metadata) (err error) {
const (
keepAlive = "keepAlive"
handshakeTimeout = "handshakeTimeout"
maxIdleTimeout = "maxIdleTimeout"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
secure = "secure"
serverName = "serverName"
cipherKey = "cipherKey"
)
d.md.handshakeTimeout = md.GetDuration(handshakeTimeout)
if key := md.GetString(cipherKey); key != "" {
d.md.cipherKey = []byte(key)
}
sn, _, _ := net.SplitHostPort(md.GetString(serverName))
if sn == "" {
sn = "localhost"
}
d.md.tlsConfig, err = tls_util.LoadClientConfig(
md.GetString(certFile),
md.GetString(keyFile),
md.GetString(caFile),
md.GetBool(secure),
sn,
)
d.md.keepAlive = md.GetBool(keepAlive)
d.md.handshakeTimeout = md.GetDuration(handshakeTimeout)
d.md.maxIdleTimeout = md.GetDuration(maxIdleTimeout)
return
}

View File

@ -18,7 +18,7 @@ import (
"github.com/go-gost/gost/pkg/bypass" "github.com/go-gost/gost/pkg/bypass"
"github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/chain"
"github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/handler"
http2_util "github.com/go-gost/gost/pkg/internal/http2" http2_util "github.com/go-gost/gost/pkg/internal/util/http2"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata" md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry" "github.com/go-gost/gost/pkg/registry"

View File

@ -7,36 +7,29 @@ import (
"errors" "errors"
"io" "io"
"net" "net"
"github.com/lucas-clemente/quic-go"
) )
type quicConn struct { type cipherConn struct {
quic.Session *net.UDPConn
quic.Stream
}
func QUICConn(session quic.Session, stream quic.Stream) net.Conn {
return &quicConn{
Session: session,
Stream: stream,
}
}
type quicCipherConn struct {
net.PacketConn
key []byte key []byte
} }
func QUICCipherConn(conn net.PacketConn, key []byte) net.PacketConn { func CipherConn(conn *net.UDPConn, key []byte) net.Conn {
return &quicCipherConn{ return &cipherConn{
PacketConn: conn, UDPConn: conn,
key: key, key: key,
} }
} }
func (conn *quicCipherConn) ReadFrom(data []byte) (n int, addr net.Addr, err error) { func CipherPacketConn(conn *net.UDPConn, key []byte) net.PacketConn {
n, addr, err = conn.PacketConn.ReadFrom(data) return &cipherConn{
UDPConn: conn,
key: key,
}
}
func (conn *cipherConn) ReadFrom(data []byte) (n int, addr net.Addr, err error) {
n, addr, err = conn.UDPConn.ReadFrom(data)
if err != nil { if err != nil {
return return
} }
@ -50,13 +43,13 @@ func (conn *quicCipherConn) ReadFrom(data []byte) (n int, addr net.Addr, err err
return len(b), addr, nil return len(b), addr, nil
} }
func (conn *quicCipherConn) WriteTo(data []byte, addr net.Addr) (n int, err error) { func (conn *cipherConn) WriteTo(data []byte, addr net.Addr) (n int, err error) {
b, err := conn.encrypt(data) b, err := conn.encrypt(data)
if err != nil { if err != nil {
return return
} }
_, err = conn.PacketConn.WriteTo(b, addr) _, err = conn.UDPConn.WriteTo(b, addr)
if err != nil { if err != nil {
return return
} }
@ -64,7 +57,7 @@ func (conn *quicCipherConn) WriteTo(data []byte, addr net.Addr) (n int, err erro
return len(b), nil return len(b), nil
} }
func (conn *quicCipherConn) encrypt(data []byte) ([]byte, error) { func (conn *cipherConn) encrypt(data []byte) ([]byte, error) {
c, err := aes.NewCipher(conn.key) c, err := aes.NewCipher(conn.key)
if err != nil { if err != nil {
return nil, err return nil, err
@ -83,7 +76,7 @@ func (conn *quicCipherConn) encrypt(data []byte) ([]byte, error) {
return gcm.Seal(nonce, nonce, data, nil), nil return gcm.Seal(nonce, nonce, data, nil), nil
} }
func (conn *quicCipherConn) decrypt(data []byte) ([]byte, error) { func (conn *cipherConn) decrypt(data []byte) ([]byte, error) {
c, err := aes.NewCipher(conn.key) c, err := aes.NewCipher(conn.key)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -6,7 +6,7 @@ import (
"net/http" "net/http"
"github.com/go-gost/gost/pkg/common/util" "github.com/go-gost/gost/pkg/common/util"
http2_util "github.com/go-gost/gost/pkg/internal/http2" http2_util "github.com/go-gost/gost/pkg/internal/util/http2"
"github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/listener"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata" md "github.com/go-gost/gost/pkg/metadata"

21
pkg/listener/quic/conn.go Normal file
View File

@ -0,0 +1,21 @@
package quic
import (
"net"
"github.com/lucas-clemente/quic-go"
)
type quicConn struct {
quic.Stream
laddr net.Addr
raddr net.Addr
}
func (c *quicConn) LocalAddr() net.Addr {
return c.laddr
}
func (c *quicConn) RemoteAddr() net.Addr {
return c.raddr
}

View File

@ -4,7 +4,7 @@ import (
"context" "context"
"net" "net"
quic_util "github.com/go-gost/gost/pkg/common/util/quic" quic_util "github.com/go-gost/gost/pkg/internal/util/quic"
"github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/listener"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata" md "github.com/go-gost/gost/pkg/metadata"
@ -18,11 +18,11 @@ func init() {
type quicListener struct { type quicListener struct {
addr string addr string
md metadata
ln quic.Listener ln quic.Listener
connChan chan net.Conn cqueue chan net.Conn
errChan chan error errChan chan error
logger logger.Logger logger logger.Logger
md metadata
} }
func NewListener(opts ...listener.Option) listener.Listener { func NewListener(opts ...listener.Option) listener.Listener {
@ -46,29 +46,37 @@ func (l *quicListener) Init(md md.Metadata) (err error) {
return return
} }
var conn net.PacketConn uc, err := net.ListenUDP("udp", laddr)
conn, err = net.ListenUDP("udp", laddr)
if err != nil { if err != nil {
return return
} }
var conn net.PacketConn = uc
if l.md.cipherKey != nil { if l.md.cipherKey != nil {
conn = quic_util.QUICCipherConn(conn, l.md.cipherKey) conn = quic_util.CipherPacketConn(uc, l.md.cipherKey)
} }
config := &quic.Config{ config := &quic.Config{
KeepAlive: l.md.keepAlive, KeepAlive: l.md.keepAlive,
HandshakeIdleTimeout: l.md.HandshakeTimeout, HandshakeIdleTimeout: l.md.handshakeTimeout,
MaxIdleTimeout: l.md.MaxIdleTimeout, MaxIdleTimeout: l.md.maxIdleTimeout,
Versions: []quic.VersionNumber{
quic.Version1,
quic.VersionDraft29,
},
} }
ln, err := quic.Listen(conn, l.md.tlsConfig, config) tlsCfg := l.md.tlsConfig
tlsCfg.NextProtos = []string{"http/3", "quic/v1"}
ln, err := quic.Listen(conn, tlsCfg, config)
if err != nil { if err != nil {
return return
} }
l.ln = ln l.ln = ln
l.connChan = make(chan net.Conn, l.md.connQueueSize) l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1) l.errChan = make(chan error, 1)
go l.listenLoop() go l.listenLoop()
@ -79,7 +87,7 @@ func (l *quicListener) Init(md md.Metadata) (err error) {
func (l *quicListener) Accept() (conn net.Conn, err error) { func (l *quicListener) Accept() (conn net.Conn, err error) {
var ok bool var ok bool
select { select {
case conn = <-l.connChan: case conn = <-l.cqueue:
case err, ok = <-l.errChan: case err, ok = <-l.errChan:
if !ok { if !ok {
err = listener.ErrClosed err = listener.ErrClosed
@ -111,7 +119,7 @@ func (l *quicListener) listenLoop() {
} }
func (l *quicListener) mux(ctx context.Context, session quic.Session) { func (l *quicListener) mux(ctx context.Context, session quic.Session) {
defer session.CloseWithError(0, "") defer session.CloseWithError(0, "closed")
for { for {
stream, err := session.AcceptStream(ctx) stream, err := session.AcceptStream(ctx)
@ -120,19 +128,18 @@ func (l *quicListener) mux(ctx context.Context, session quic.Session) {
return return
} }
conn := quic_util.QUICConn(session, stream) conn := &quicConn{
Stream: stream,
laddr: session.LocalAddr(),
raddr: session.RemoteAddr(),
}
select { select {
case l.connChan <- conn: case l.cqueue <- conn:
case <-stream.Context().Done(): case <-stream.Context().Done():
stream.Close() stream.Close()
default: default:
stream.Close() stream.Close()
l.logger.Error("connection queue is full") l.logger.Warnf("connection queue is full, client %s discarded", session.RemoteAddr())
} }
} }
} }
func (l *quicListener) parseMetadata(md md.Metadata) (err error) {
return
}

View File

@ -3,27 +3,60 @@ package quic
import ( import (
"crypto/tls" "crypto/tls"
"time" "time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
md "github.com/go-gost/gost/pkg/metadata"
) )
const ( const (
defaultBacklog = 128
)
type metadata struct {
keepAlive bool
handshakeTimeout time.Duration
maxIdleTimeout time.Duration
tlsConfig *tls.Config
cipherKey []byte
backlog int
}
func (l *quicListener) parseMetadata(md md.Metadata) (err error) {
const (
keepAlive = "keepAlive"
handshakeTimeout = "handshakeTimeout"
maxIdleTimeout = "maxIdleTimeout"
certFile = "certFile" certFile = "certFile"
keyFile = "keyFile" keyFile = "keyFile"
caFile = "caFile" caFile = "caFile"
keepAlive = "keepAlive" backlog = "backlog"
keepAlivePeriod = "keepAlivePeriod" cipherKey = "cipherKey"
) )
const ( l.md.tlsConfig, err = tls_util.LoadServerConfig(
defaultKeepAlivePeriod = 180 * time.Second md.GetString(certFile),
md.GetString(keyFile),
md.GetString(caFile),
) )
type metadata struct { if err != nil {
tlsConfig *tls.Config return
keepAlive bool }
HandshakeTimeout time.Duration l.md.backlog = md.GetInt(backlog)
MaxIdleTimeout time.Duration if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
cipherKey []byte }
connQueueSize int
if key := md.GetString(cipherKey); key != "" {
l.md.cipherKey = []byte(key)
}
l.md.keepAlive = md.GetBool(keepAlive)
l.md.handshakeTimeout = md.GetDuration(handshakeTimeout)
l.md.maxIdleTimeout = md.GetDuration(maxIdleTimeout)
return
} }