add http3 listener

This commit is contained in:
ginuerzh 2022-01-21 23:54:05 +08:00
parent 412ed7f218
commit a134026e76
7 changed files with 502 additions and 228 deletions

View File

@ -55,6 +55,7 @@ import (
_ "github.com/go-gost/gost/pkg/listener/ftcp" _ "github.com/go-gost/gost/pkg/listener/ftcp"
_ "github.com/go-gost/gost/pkg/listener/http2" _ "github.com/go-gost/gost/pkg/listener/http2"
_ "github.com/go-gost/gost/pkg/listener/http2/h2" _ "github.com/go-gost/gost/pkg/listener/http2/h2"
_ "github.com/go-gost/gost/pkg/listener/http3"
_ "github.com/go-gost/gost/pkg/listener/kcp" _ "github.com/go-gost/gost/pkg/listener/kcp"
_ "github.com/go-gost/gost/pkg/listener/obfs/http" _ "github.com/go-gost/gost/pkg/listener/obfs/http"
_ "github.com/go-gost/gost/pkg/listener/obfs/tls" _ "github.com/go-gost/gost/pkg/listener/obfs/tls"

View File

@ -158,3 +158,17 @@ func (c *clientConn) SetWriteDeadline(t time.Time) error {
func (c *clientConn) SetDeadline(t time.Time) error { func (c *clientConn) SetDeadline(t time.Time) error {
return nil return nil
} }
type serverConn struct {
net.Conn
remoteAddr net.Addr
localAddr net.Addr
}
func (c *serverConn) LocalAddr() net.Addr {
return c.localAddr
}
func (c *serverConn) RemoteAddr() net.Addr {
return c.remoteAddr
}

View File

@ -0,0 +1,343 @@
package pht
import (
"bufio"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"net"
"net/http"
"net/http/httputil"
"os"
"strings"
"sync"
"time"
"github.com/go-gost/gost/pkg/common/bufpool"
"github.com/go-gost/gost/pkg/logger"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/http3"
"github.com/rs/xid"
)
const (
defaultBacklog = 128
)
type serverOptions struct {
authorizePath string
pushPath string
pullPath string
backlog int
tlsEnabled bool
tlsConfig *tls.Config
logger logger.Logger
}
type ServerOption func(opts *serverOptions)
func PathServerOption(authorizePath, pushPath, pullPath string) ServerOption {
return func(opts *serverOptions) {
opts.authorizePath = authorizePath
opts.pullPath = pullPath
opts.pushPath = pushPath
}
}
func BacklogServerOption(backlog int) ServerOption {
return func(opts *serverOptions) {
opts.backlog = backlog
}
}
func TLSConfigServerOption(tlsConfig *tls.Config) ServerOption {
return func(opts *serverOptions) {
opts.tlsConfig = tlsConfig
}
}
func EnableTLSServerOption(enable bool) ServerOption {
return func(opts *serverOptions) {
opts.tlsEnabled = enable
}
}
func LoggerServerOption(logger logger.Logger) ServerOption {
return func(opts *serverOptions) {
opts.logger = logger
}
}
type Server struct {
addr net.Addr
httpServer *http.Server
http3Server *http3.Server
cqueue chan net.Conn
conns sync.Map
closed chan struct{}
options serverOptions
}
func NewServer(addr string, opts ...ServerOption) *Server {
var options serverOptions
for _, opt := range opts {
opt(&options)
}
if options.backlog <= 0 {
options.backlog = defaultBacklog
}
s := &Server{
httpServer: &http.Server{
Addr: addr,
ReadHeaderTimeout: 30 * time.Second,
},
cqueue: make(chan net.Conn, options.backlog),
closed: make(chan struct{}),
options: options,
}
mux := http.NewServeMux()
mux.HandleFunc(options.authorizePath, s.handleAuthorize)
mux.HandleFunc(options.pushPath, s.handlePush)
mux.HandleFunc(options.pullPath, s.handlePull)
s.httpServer.Handler = mux
return s
}
func NewHTTP3Server(addr string, quicConfig *quic.Config, opts ...ServerOption) *Server {
var options serverOptions
for _, opt := range opts {
opt(&options)
}
if options.backlog <= 0 {
options.backlog = defaultBacklog
}
s := &Server{
http3Server: &http3.Server{
Server: &http.Server{
Addr: addr,
TLSConfig: options.tlsConfig,
ReadHeaderTimeout: 30 * time.Second,
},
QuicConfig: quicConfig,
},
cqueue: make(chan net.Conn, options.backlog),
closed: make(chan struct{}),
options: options,
}
mux := http.NewServeMux()
mux.HandleFunc(options.authorizePath, s.handleAuthorize)
mux.HandleFunc(options.pushPath, s.handlePush)
mux.HandleFunc(options.pullPath, s.handlePull)
s.http3Server.Handler = mux
return s
}
func (s *Server) ListenAndServe() error {
if s.http3Server != nil {
addr, err := net.ResolveUDPAddr("udp", s.http3Server.Addr)
if err != nil {
return err
}
s.addr = addr
return s.http3Server.ListenAndServe()
}
ln, err := net.Listen("tcp", s.httpServer.Addr)
if err != nil {
s.options.logger.Error(err)
return err
}
s.addr = ln.Addr()
if s.options.tlsEnabled {
s.httpServer.TLSConfig = s.options.tlsConfig
ln = tls.NewListener(ln, s.options.tlsConfig)
}
return s.httpServer.Serve(ln)
}
func (s *Server) Accept() (conn net.Conn, err error) {
select {
case conn = <-s.cqueue:
case <-s.closed:
err = http.ErrServerClosed
}
return
}
func (s *Server) Close() error {
select {
case <-s.closed:
return http.ErrServerClosed
default:
close(s.closed)
if s.http3Server != nil {
return s.http3Server.Close()
}
return s.httpServer.Close()
}
}
func (s *Server) handleAuthorize(w http.ResponseWriter, r *http.Request) {
if s.options.logger.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpRequest(r, false)
s.options.logger.Debug(string(dump))
}
raddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
if raddr == nil {
raddr = &net.TCPAddr{}
}
// connection id
cid := xid.New().String()
c1, c2 := net.Pipe()
c := &serverConn{
Conn: c1,
localAddr: s.addr,
remoteAddr: raddr,
}
select {
case s.cqueue <- c:
default:
c.Close()
s.options.logger.Warnf("connection queue is full, client %s discarded", r.RemoteAddr)
w.WriteHeader(http.StatusTooManyRequests)
return
}
w.Write([]byte(fmt.Sprintf("token=%s", cid)))
s.conns.Store(cid, c2)
}
func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) {
if s.options.logger.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpRequest(r, false)
s.options.logger.Debug(string(dump))
}
if r.Method != http.MethodPost {
w.WriteHeader(http.StatusBadRequest)
return
}
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
cid := r.Form.Get("token")
v, ok := s.conns.Load(cid)
if !ok {
w.WriteHeader(http.StatusForbidden)
return
}
conn := v.(net.Conn)
br := bufio.NewReader(r.Body)
data, err := br.ReadString('\n')
if err != nil {
s.options.logger.Error(err)
conn.Close()
s.conns.Delete(cid)
w.WriteHeader(http.StatusBadRequest)
return
}
data = strings.TrimSuffix(data, "\n")
if len(data) == 0 {
return
}
b, err := base64.StdEncoding.DecodeString(data)
if err != nil {
s.options.logger.Error(err)
s.conns.Delete(cid)
conn.Close()
w.WriteHeader(http.StatusBadRequest)
return
}
conn.SetWriteDeadline(time.Now().Add(30 * time.Second))
defer conn.SetWriteDeadline(time.Time{})
if _, err := conn.Write(b); err != nil {
s.options.logger.Error(err)
s.conns.Delete(cid)
conn.Close()
w.WriteHeader(http.StatusGone)
}
}
func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) {
if s.options.logger.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpRequest(r, false)
s.options.logger.Debug(string(dump))
}
if r.Method != http.MethodGet {
w.WriteHeader(http.StatusBadRequest)
return
}
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
cid := r.Form.Get("token")
v, ok := s.conns.Load(cid)
if !ok {
w.WriteHeader(http.StatusForbidden)
return
}
conn := v.(net.Conn)
w.WriteHeader(http.StatusOK)
if fw, ok := w.(http.Flusher); ok {
fw.Flush()
}
b := bufpool.Get(4096)
defer bufpool.Put(b)
for {
conn.SetReadDeadline(time.Now().Add(10 * time.Second))
n, err := conn.Read(*b)
if err != nil {
if !errors.Is(err, os.ErrDeadlineExceeded) {
s.options.logger.Error(err)
s.conns.Delete(cid)
conn.Close()
} else {
(*b)[0] = '\n'
w.Write((*b)[:1])
}
return
}
bw := bufio.NewWriter(w)
bw.WriteString(base64.StdEncoding.EncodeToString((*b)[:n]))
bw.WriteString("\n")
if err := bw.Flush(); err != nil {
return
}
if fw, ok := w.(http.Flusher); ok {
fw.Flush()
}
}
}

View File

@ -0,0 +1,77 @@
// plain http tunnel
package pht
import (
"net"
pht_util "github.com/go-gost/gost/pkg/internal/util/pht"
"github.com/go-gost/gost/pkg/listener"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
"github.com/lucas-clemente/quic-go"
)
func init() {
registry.RegisterListener("http3", NewListener)
}
type phtListener struct {
addr net.Addr
server *pht_util.Server
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &phtListener{
logger: options.Logger,
options: options,
}
}
func (l *phtListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
}
l.addr, err = net.ResolveUDPAddr("udp", l.options.Addr)
if err != nil {
return
}
l.server = pht_util.NewHTTP3Server(
l.options.Addr,
&quic.Config{},
pht_util.TLSConfigServerOption(l.options.TLSConfig),
pht_util.BacklogServerOption(l.md.backlog),
pht_util.PathServerOption(l.md.authorizePath, l.md.pushPath, l.md.pullPath),
pht_util.LoggerServerOption(l.options.Logger),
)
go func() {
if err := l.server.ListenAndServe(); err != nil {
l.logger.Error(err)
}
}()
return
}
func (l *phtListener) Accept() (conn net.Conn, err error) {
return l.server.Accept()
}
func (l *phtListener) Addr() net.Addr {
return l.addr
}
func (l *phtListener) Close() (err error) {
return l.server.Close()
}

View File

@ -0,0 +1,51 @@
package pht
import (
"strings"
mdata "github.com/go-gost/gost/pkg/metadata"
)
const (
defaultAuthorizePath = "/authorize"
defaultPushPath = "/push"
defaultPullPath = "/pull"
defaultBacklog = 128
)
type metadata struct {
authorizePath string
pushPath string
pullPath string
backlog int
}
func (l *phtListener) parseMetadata(md mdata.Metadata) (err error) {
const (
authorizePath = "authorizePath"
pushPath = "pushPath"
pullPath = "pullPath"
backlog = "backlog"
)
l.md.authorizePath = mdata.GetString(md, authorizePath)
if !strings.HasPrefix(l.md.authorizePath, "/") {
l.md.authorizePath = defaultAuthorizePath
}
l.md.pushPath = mdata.GetString(md, pushPath)
if !strings.HasPrefix(l.md.pushPath, "/") {
l.md.pushPath = defaultPushPath
}
l.md.pullPath = mdata.GetString(md, pullPath)
if !strings.HasPrefix(l.md.pullPath, "/") {
l.md.pullPath = defaultPullPath
}
l.md.backlog = mdata.GetInt(md, backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
return
}

View File

@ -1,20 +0,0 @@
package pht
import (
"net"
)
// pht connection, wrapped up just like a net.Conn
type conn struct {
net.Conn
remoteAddr net.Addr
localAddr net.Addr
}
func (c *conn) LocalAddr() net.Addr {
return c.localAddr
}
func (c *conn) RemoteAddr() net.Addr {
return c.remoteAddr
}

View File

@ -3,25 +3,13 @@
package pht package pht
import ( import (
"bufio"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"net" "net"
"net/http"
"net/http/httputil"
"os"
"strings"
"sync"
"time"
"github.com/go-gost/gost/pkg/common/bufpool" pht_util "github.com/go-gost/gost/pkg/internal/util/pht"
"github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/listener"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata" md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry" "github.com/go-gost/gost/pkg/registry"
"github.com/rs/xid"
) )
func init() { func init() {
@ -30,12 +18,9 @@ func init() {
} }
type phtListener struct { type phtListener struct {
tlsEnabled bool
server *http.Server
addr net.Addr addr net.Addr
conns sync.Map tlsEnabled bool
cqueue chan net.Conn server *pht_util.Server
errChan chan error
logger logger.Logger logger logger.Logger
md metadata md metadata
options listener.Options options listener.Options
@ -69,31 +54,22 @@ func (l *phtListener) Init(md md.Metadata) (err error) {
return return
} }
ln, err := net.Listen("tcp", l.options.Addr) l.addr, err = net.ResolveTCPAddr("tcp", l.options.Addr)
if err != nil { if err != nil {
return err return
}
l.addr = ln.Addr()
mux := http.NewServeMux()
mux.HandleFunc(l.md.authorizePath, l.handleAuthorize)
mux.HandleFunc(l.md.pushPath, l.handlePush)
mux.HandleFunc(l.md.pullPath, l.handlePull)
l.server = &http.Server{
Addr: l.options.Addr,
Handler: mux,
}
if l.tlsEnabled {
l.server.TLSConfig = l.options.TLSConfig
ln = tls.NewListener(ln, l.options.TLSConfig)
} }
l.cqueue = make(chan net.Conn, l.md.backlog) l.server = pht_util.NewServer(
l.errChan = make(chan error, 1) l.options.Addr,
pht_util.TLSConfigServerOption(l.options.TLSConfig),
pht_util.EnableTLSServerOption(l.tlsEnabled),
pht_util.BacklogServerOption(l.md.backlog),
pht_util.PathServerOption(l.md.authorizePath, l.md.pushPath, l.md.pullPath),
pht_util.LoggerServerOption(l.options.Logger),
)
go func() { go func() {
if err := l.server.Serve(ln); err != nil { if err := l.server.ListenAndServe(); err != nil {
l.logger.Error(err) l.logger.Error(err)
} }
}() }()
@ -102,15 +78,7 @@ func (l *phtListener) Init(md md.Metadata) (err error) {
} }
func (l *phtListener) Accept() (conn net.Conn, err error) { func (l *phtListener) Accept() (conn net.Conn, err error) {
var ok bool return l.server.Accept()
select {
case conn = <-l.cqueue:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
} }
func (l *phtListener) Addr() net.Addr { func (l *phtListener) Addr() net.Addr {
@ -118,165 +86,5 @@ func (l *phtListener) Addr() net.Addr {
} }
func (l *phtListener) Close() (err error) { func (l *phtListener) Close() (err error) {
select { return l.server.Close()
case <-l.errChan:
default:
err = l.server.Close()
l.errChan <- err
close(l.errChan)
}
return nil
}
func (l *phtListener) handleAuthorize(w http.ResponseWriter, r *http.Request) {
if l.logger.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpRequest(r, false)
l.logger.Debug(string(dump))
}
raddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
if raddr == nil {
raddr = &net.TCPAddr{}
}
// connection id
cid := xid.New().String()
c1, c2 := net.Pipe()
c := &conn{
Conn: c1,
localAddr: l.addr,
remoteAddr: raddr,
}
select {
case l.cqueue <- c:
default:
c.Close()
l.logger.Warnf("connection queue is full, client %s discarded", r.RemoteAddr)
w.WriteHeader(http.StatusTooManyRequests)
return
}
w.Write([]byte(fmt.Sprintf("token=%s", cid)))
l.conns.Store(cid, c2)
}
func (l *phtListener) handlePush(w http.ResponseWriter, r *http.Request) {
if l.logger.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpRequest(r, false)
l.logger.Debug(string(dump))
}
if r.Method != http.MethodPost {
w.WriteHeader(http.StatusBadRequest)
return
}
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
cid := r.Form.Get("token")
v, ok := l.conns.Load(cid)
if !ok {
w.WriteHeader(http.StatusForbidden)
return
}
conn := v.(net.Conn)
br := bufio.NewReader(r.Body)
data, err := br.ReadString('\n')
if err != nil {
l.logger.Error(err)
conn.Close()
l.conns.Delete(cid)
w.WriteHeader(http.StatusBadRequest)
return
}
data = strings.TrimSuffix(data, "\n")
if len(data) == 0 {
return
}
b, err := base64.StdEncoding.DecodeString(data)
if err != nil {
l.logger.Error(err)
l.conns.Delete(cid)
conn.Close()
w.WriteHeader(http.StatusBadRequest)
return
}
conn.SetWriteDeadline(time.Now().Add(30 * time.Second))
defer conn.SetWriteDeadline(time.Time{})
if _, err := conn.Write(b); err != nil {
l.logger.Error(err)
l.conns.Delete(cid)
conn.Close()
w.WriteHeader(http.StatusGone)
}
}
func (l *phtListener) handlePull(w http.ResponseWriter, r *http.Request) {
if l.logger.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpRequest(r, false)
l.logger.Debug(string(dump))
}
if r.Method != http.MethodGet {
w.WriteHeader(http.StatusBadRequest)
return
}
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
cid := r.Form.Get("token")
v, ok := l.conns.Load(cid)
if !ok {
w.WriteHeader(http.StatusForbidden)
return
}
conn := v.(net.Conn)
w.WriteHeader(http.StatusOK)
if fw, ok := w.(http.Flusher); ok {
fw.Flush()
}
b := bufpool.Get(4096)
defer bufpool.Put(b)
for {
conn.SetReadDeadline(time.Now().Add(10 * time.Second))
n, err := conn.Read(*b)
if err != nil {
if !errors.Is(err, os.ErrDeadlineExceeded) {
l.logger.Error(err)
l.conns.Delete(cid)
conn.Close()
} else {
(*b)[0] = '\n'
w.Write((*b)[:1])
}
return
}
bw := bufio.NewWriter(w)
bw.WriteString(base64.StdEncoding.EncodeToString((*b)[:n]))
bw.WriteString("\n")
if err := bw.Flush(); err != nil {
return
}
if fw, ok := w.(http.Flusher); ok {
fw.Flush()
}
}
} }