package quic import ( "context" "net" "strings" "github.com/go-gost/core/listener" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" admission "github.com/go-gost/x/admission/wrapper" xnet "github.com/go-gost/x/internal/net" quic_util "github.com/go-gost/x/internal/util/quic" limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" stats "github.com/go-gost/x/observer/stats/wrapper" "github.com/go-gost/x/registry" "github.com/quic-go/quic-go" ) func init() { registry.ListenerRegistry().Register("quic", NewListener) } type quicListener struct { ln quic.EarlyListener cqueue chan net.Conn errChan chan error logger logger.Logger md metadata options listener.Options } func NewListener(opts ...listener.Option) listener.Listener { options := listener.Options{} for _, opt := range opts { opt(&options) } return &quicListener{ logger: options.Logger, options: options, } } func (l *quicListener) Init(md md.Metadata) (err error) { if err = l.parseMetadata(md); err != nil { return } addr := l.options.Addr if _, _, err := net.SplitHostPort(addr); err != nil { addr = net.JoinHostPort(strings.Trim(addr, "[]"), "0") } network := "udp" if xnet.IsIPv4(addr) { network = "udp4" } laddr, err := net.ResolveUDPAddr(network, addr) if err != nil { return } var conn net.PacketConn conn, err = net.ListenUDP(network, laddr) if err != nil { return } if l.md.cipherKey != nil { conn = quic_util.CipherPacketConn(conn, l.md.cipherKey) } config := &quic.Config{ KeepAlivePeriod: l.md.keepAlivePeriod, HandshakeIdleTimeout: l.md.handshakeTimeout, MaxIdleTimeout: l.md.maxIdleTimeout, Versions: []quic.VersionNumber{ quic.Version1, quic.Version2, }, MaxIncomingStreams: int64(l.md.maxStreams), } tlsCfg := l.options.TLSConfig tlsCfg.NextProtos = []string{"h3", "quic/v1"} ln, err := quic.ListenEarly(conn, tlsCfg, config) if err != nil { return } l.ln = *ln l.cqueue = make(chan net.Conn, l.md.backlog) l.errChan = make(chan error, 1) go l.listenLoop() return } func (l *quicListener) Accept() (conn net.Conn, err error) { var ok bool select { case conn = <-l.cqueue: conn = metrics.WrapConn(l.options.Service, conn) conn = stats.WrapConn(conn, l.options.Stats) conn = admission.WrapConn(l.options.Admission, conn) conn = limiter.WrapConn(l.options.TrafficLimiter, conn) case err, ok = <-l.errChan: if !ok { err = listener.ErrClosed } } return } func (l *quicListener) Close() error { return l.ln.Close() } func (l *quicListener) Addr() net.Addr { return l.ln.Addr() } func (l *quicListener) listenLoop() { for { ctx := context.Background() session, err := l.ln.Accept(ctx) if err != nil { l.logger.Error("accept:", err) l.errChan <- err close(l.errChan) return } go l.mux(ctx, session) } } func (l *quicListener) mux(ctx context.Context, session quic.Connection) { defer session.CloseWithError(0, "closed") for { stream, err := session.AcceptStream(ctx) if err != nil { l.logger.Error("accept stream:", err) return } conn := &quicConn{ Stream: stream, laddr: session.LocalAddr(), raddr: session.RemoteAddr(), } select { case l.cqueue <- conn: case <-stream.Context().Done(): stream.Close() default: stream.Close() l.logger.Warnf("connection queue is full, client %s discarded", session.RemoteAddr()) } } }