diff --git a/cmd/gost/register.go b/cmd/gost/register.go index d92366c..0e8c6ed 100644 --- a/cmd/gost/register.go +++ b/cmd/gost/register.go @@ -55,6 +55,7 @@ import ( _ "github.com/go-gost/gost/pkg/listener/ftcp" _ "github.com/go-gost/gost/pkg/listener/http2" _ "github.com/go-gost/gost/pkg/listener/http2/h2" + _ "github.com/go-gost/gost/pkg/listener/http3" _ "github.com/go-gost/gost/pkg/listener/kcp" _ "github.com/go-gost/gost/pkg/listener/obfs/http" _ "github.com/go-gost/gost/pkg/listener/obfs/tls" diff --git a/pkg/internal/util/pht/conn.go b/pkg/internal/util/pht/conn.go index 9e8d7da..4247e05 100644 --- a/pkg/internal/util/pht/conn.go +++ b/pkg/internal/util/pht/conn.go @@ -158,3 +158,17 @@ func (c *clientConn) SetWriteDeadline(t time.Time) error { func (c *clientConn) SetDeadline(t time.Time) error { return nil } + +type serverConn struct { + net.Conn + remoteAddr net.Addr + localAddr net.Addr +} + +func (c *serverConn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *serverConn) RemoteAddr() net.Addr { + return c.remoteAddr +} diff --git a/pkg/internal/util/pht/server.go b/pkg/internal/util/pht/server.go new file mode 100644 index 0000000..0efa227 --- /dev/null +++ b/pkg/internal/util/pht/server.go @@ -0,0 +1,343 @@ +package pht + +import ( + "bufio" + "crypto/tls" + "encoding/base64" + "errors" + "fmt" + "net" + "net/http" + "net/http/httputil" + "os" + "strings" + "sync" + "time" + + "github.com/go-gost/gost/pkg/common/bufpool" + "github.com/go-gost/gost/pkg/logger" + "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/http3" + "github.com/rs/xid" +) + +const ( + defaultBacklog = 128 +) + +type serverOptions struct { + authorizePath string + pushPath string + pullPath string + backlog int + tlsEnabled bool + tlsConfig *tls.Config + logger logger.Logger +} + +type ServerOption func(opts *serverOptions) + +func PathServerOption(authorizePath, pushPath, pullPath string) ServerOption { + return func(opts *serverOptions) { + opts.authorizePath = authorizePath + opts.pullPath = pullPath + opts.pushPath = pushPath + } +} + +func BacklogServerOption(backlog int) ServerOption { + return func(opts *serverOptions) { + opts.backlog = backlog + } +} + +func TLSConfigServerOption(tlsConfig *tls.Config) ServerOption { + return func(opts *serverOptions) { + opts.tlsConfig = tlsConfig + } +} + +func EnableTLSServerOption(enable bool) ServerOption { + return func(opts *serverOptions) { + opts.tlsEnabled = enable + } +} + +func LoggerServerOption(logger logger.Logger) ServerOption { + return func(opts *serverOptions) { + opts.logger = logger + } +} + +type Server struct { + addr net.Addr + httpServer *http.Server + http3Server *http3.Server + cqueue chan net.Conn + conns sync.Map + closed chan struct{} + + options serverOptions +} + +func NewServer(addr string, opts ...ServerOption) *Server { + var options serverOptions + for _, opt := range opts { + opt(&options) + } + if options.backlog <= 0 { + options.backlog = defaultBacklog + } + + s := &Server{ + httpServer: &http.Server{ + Addr: addr, + ReadHeaderTimeout: 30 * time.Second, + }, + cqueue: make(chan net.Conn, options.backlog), + closed: make(chan struct{}), + options: options, + } + + mux := http.NewServeMux() + mux.HandleFunc(options.authorizePath, s.handleAuthorize) + mux.HandleFunc(options.pushPath, s.handlePush) + mux.HandleFunc(options.pullPath, s.handlePull) + s.httpServer.Handler = mux + + return s +} + +func NewHTTP3Server(addr string, quicConfig *quic.Config, opts ...ServerOption) *Server { + var options serverOptions + for _, opt := range opts { + opt(&options) + } + if options.backlog <= 0 { + options.backlog = defaultBacklog + } + + s := &Server{ + http3Server: &http3.Server{ + Server: &http.Server{ + Addr: addr, + TLSConfig: options.tlsConfig, + ReadHeaderTimeout: 30 * time.Second, + }, + QuicConfig: quicConfig, + }, + cqueue: make(chan net.Conn, options.backlog), + closed: make(chan struct{}), + options: options, + } + + mux := http.NewServeMux() + mux.HandleFunc(options.authorizePath, s.handleAuthorize) + mux.HandleFunc(options.pushPath, s.handlePush) + mux.HandleFunc(options.pullPath, s.handlePull) + s.http3Server.Handler = mux + + return s +} + +func (s *Server) ListenAndServe() error { + if s.http3Server != nil { + addr, err := net.ResolveUDPAddr("udp", s.http3Server.Addr) + if err != nil { + return err + } + + s.addr = addr + return s.http3Server.ListenAndServe() + } + + ln, err := net.Listen("tcp", s.httpServer.Addr) + if err != nil { + s.options.logger.Error(err) + return err + } + + s.addr = ln.Addr() + if s.options.tlsEnabled { + s.httpServer.TLSConfig = s.options.tlsConfig + ln = tls.NewListener(ln, s.options.tlsConfig) + } + + return s.httpServer.Serve(ln) +} + +func (s *Server) Accept() (conn net.Conn, err error) { + select { + case conn = <-s.cqueue: + case <-s.closed: + err = http.ErrServerClosed + } + return +} + +func (s *Server) Close() error { + select { + case <-s.closed: + return http.ErrServerClosed + default: + close(s.closed) + + if s.http3Server != nil { + return s.http3Server.Close() + } + return s.httpServer.Close() + } +} + +func (s *Server) handleAuthorize(w http.ResponseWriter, r *http.Request) { + if s.options.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpRequest(r, false) + s.options.logger.Debug(string(dump)) + } + + raddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr) + if raddr == nil { + raddr = &net.TCPAddr{} + } + + // connection id + cid := xid.New().String() + + c1, c2 := net.Pipe() + c := &serverConn{ + Conn: c1, + localAddr: s.addr, + remoteAddr: raddr, + } + + select { + case s.cqueue <- c: + default: + c.Close() + s.options.logger.Warnf("connection queue is full, client %s discarded", r.RemoteAddr) + w.WriteHeader(http.StatusTooManyRequests) + return + } + + w.Write([]byte(fmt.Sprintf("token=%s", cid))) + s.conns.Store(cid, c2) +} + +func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) { + if s.options.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpRequest(r, false) + s.options.logger.Debug(string(dump)) + } + + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusBadRequest) + return + } + + if err := r.ParseForm(); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + cid := r.Form.Get("token") + v, ok := s.conns.Load(cid) + if !ok { + w.WriteHeader(http.StatusForbidden) + return + } + conn := v.(net.Conn) + + br := bufio.NewReader(r.Body) + data, err := br.ReadString('\n') + if err != nil { + s.options.logger.Error(err) + conn.Close() + s.conns.Delete(cid) + w.WriteHeader(http.StatusBadRequest) + return + } + + data = strings.TrimSuffix(data, "\n") + if len(data) == 0 { + return + } + + b, err := base64.StdEncoding.DecodeString(data) + if err != nil { + s.options.logger.Error(err) + s.conns.Delete(cid) + conn.Close() + w.WriteHeader(http.StatusBadRequest) + return + } + + conn.SetWriteDeadline(time.Now().Add(30 * time.Second)) + defer conn.SetWriteDeadline(time.Time{}) + + if _, err := conn.Write(b); err != nil { + s.options.logger.Error(err) + s.conns.Delete(cid) + conn.Close() + w.WriteHeader(http.StatusGone) + } +} + +func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) { + if s.options.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpRequest(r, false) + s.options.logger.Debug(string(dump)) + } + + if r.Method != http.MethodGet { + w.WriteHeader(http.StatusBadRequest) + return + } + + if err := r.ParseForm(); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + cid := r.Form.Get("token") + v, ok := s.conns.Load(cid) + if !ok { + w.WriteHeader(http.StatusForbidden) + return + } + + conn := v.(net.Conn) + + w.WriteHeader(http.StatusOK) + if fw, ok := w.(http.Flusher); ok { + fw.Flush() + } + + b := bufpool.Get(4096) + defer bufpool.Put(b) + + for { + conn.SetReadDeadline(time.Now().Add(10 * time.Second)) + n, err := conn.Read(*b) + if err != nil { + if !errors.Is(err, os.ErrDeadlineExceeded) { + s.options.logger.Error(err) + s.conns.Delete(cid) + conn.Close() + } else { + (*b)[0] = '\n' + w.Write((*b)[:1]) + } + return + } + + bw := bufio.NewWriter(w) + bw.WriteString(base64.StdEncoding.EncodeToString((*b)[:n])) + bw.WriteString("\n") + if err := bw.Flush(); err != nil { + return + } + if fw, ok := w.(http.Flusher); ok { + fw.Flush() + } + } +} diff --git a/pkg/listener/http3/listener.go b/pkg/listener/http3/listener.go new file mode 100644 index 0000000..b8636ba --- /dev/null +++ b/pkg/listener/http3/listener.go @@ -0,0 +1,77 @@ +// plain http tunnel + +package pht + +import ( + "net" + + pht_util "github.com/go-gost/gost/pkg/internal/util/pht" + "github.com/go-gost/gost/pkg/listener" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" + "github.com/lucas-clemente/quic-go" +) + +func init() { + registry.RegisterListener("http3", NewListener) +} + +type phtListener struct { + addr net.Addr + server *pht_util.Server + logger logger.Logger + md metadata + options listener.Options +} + +func NewListener(opts ...listener.Option) listener.Listener { + options := listener.Options{} + for _, opt := range opts { + opt(&options) + } + return &phtListener{ + logger: options.Logger, + options: options, + } +} + +func (l *phtListener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { + return + } + + l.addr, err = net.ResolveUDPAddr("udp", l.options.Addr) + if err != nil { + return + } + + l.server = pht_util.NewHTTP3Server( + l.options.Addr, + &quic.Config{}, + pht_util.TLSConfigServerOption(l.options.TLSConfig), + pht_util.BacklogServerOption(l.md.backlog), + pht_util.PathServerOption(l.md.authorizePath, l.md.pushPath, l.md.pullPath), + pht_util.LoggerServerOption(l.options.Logger), + ) + + go func() { + if err := l.server.ListenAndServe(); err != nil { + l.logger.Error(err) + } + }() + + return +} + +func (l *phtListener) Accept() (conn net.Conn, err error) { + return l.server.Accept() +} + +func (l *phtListener) Addr() net.Addr { + return l.addr +} + +func (l *phtListener) Close() (err error) { + return l.server.Close() +} diff --git a/pkg/listener/http3/metadata.go b/pkg/listener/http3/metadata.go new file mode 100644 index 0000000..239077d --- /dev/null +++ b/pkg/listener/http3/metadata.go @@ -0,0 +1,51 @@ +package pht + +import ( + "strings" + + mdata "github.com/go-gost/gost/pkg/metadata" +) + +const ( + defaultAuthorizePath = "/authorize" + defaultPushPath = "/push" + defaultPullPath = "/pull" + defaultBacklog = 128 +) + +type metadata struct { + authorizePath string + pushPath string + pullPath string + backlog int +} + +func (l *phtListener) parseMetadata(md mdata.Metadata) (err error) { + const ( + authorizePath = "authorizePath" + pushPath = "pushPath" + pullPath = "pullPath" + + backlog = "backlog" + ) + + l.md.authorizePath = mdata.GetString(md, authorizePath) + if !strings.HasPrefix(l.md.authorizePath, "/") { + l.md.authorizePath = defaultAuthorizePath + } + l.md.pushPath = mdata.GetString(md, pushPath) + if !strings.HasPrefix(l.md.pushPath, "/") { + l.md.pushPath = defaultPushPath + } + l.md.pullPath = mdata.GetString(md, pullPath) + if !strings.HasPrefix(l.md.pullPath, "/") { + l.md.pullPath = defaultPullPath + } + + l.md.backlog = mdata.GetInt(md, backlog) + if l.md.backlog <= 0 { + l.md.backlog = defaultBacklog + } + + return +} diff --git a/pkg/listener/pht/conn.go b/pkg/listener/pht/conn.go deleted file mode 100644 index 8c0e1e0..0000000 --- a/pkg/listener/pht/conn.go +++ /dev/null @@ -1,20 +0,0 @@ -package pht - -import ( - "net" -) - -// pht connection, wrapped up just like a net.Conn -type conn struct { - net.Conn - remoteAddr net.Addr - localAddr net.Addr -} - -func (c *conn) LocalAddr() net.Addr { - return c.localAddr -} - -func (c *conn) RemoteAddr() net.Addr { - return c.remoteAddr -} diff --git a/pkg/listener/pht/listener.go b/pkg/listener/pht/listener.go index 163ad0e..ce61114 100644 --- a/pkg/listener/pht/listener.go +++ b/pkg/listener/pht/listener.go @@ -3,25 +3,13 @@ package pht import ( - "bufio" - "crypto/tls" - "encoding/base64" - "errors" - "fmt" "net" - "net/http" - "net/http/httputil" - "os" - "strings" - "sync" - "time" - "github.com/go-gost/gost/pkg/common/bufpool" + pht_util "github.com/go-gost/gost/pkg/internal/util/pht" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" - "github.com/rs/xid" ) func init() { @@ -30,12 +18,9 @@ func init() { } type phtListener struct { - tlsEnabled bool - server *http.Server addr net.Addr - conns sync.Map - cqueue chan net.Conn - errChan chan error + tlsEnabled bool + server *pht_util.Server logger logger.Logger md metadata options listener.Options @@ -69,31 +54,22 @@ func (l *phtListener) Init(md md.Metadata) (err error) { return } - ln, err := net.Listen("tcp", l.options.Addr) + l.addr, err = net.ResolveTCPAddr("tcp", l.options.Addr) if err != nil { - return err - } - l.addr = ln.Addr() - - mux := http.NewServeMux() - mux.HandleFunc(l.md.authorizePath, l.handleAuthorize) - mux.HandleFunc(l.md.pushPath, l.handlePush) - mux.HandleFunc(l.md.pullPath, l.handlePull) - - l.server = &http.Server{ - Addr: l.options.Addr, - Handler: mux, - } - if l.tlsEnabled { - l.server.TLSConfig = l.options.TLSConfig - ln = tls.NewListener(ln, l.options.TLSConfig) + return } - l.cqueue = make(chan net.Conn, l.md.backlog) - l.errChan = make(chan error, 1) + l.server = pht_util.NewServer( + l.options.Addr, + pht_util.TLSConfigServerOption(l.options.TLSConfig), + pht_util.EnableTLSServerOption(l.tlsEnabled), + pht_util.BacklogServerOption(l.md.backlog), + pht_util.PathServerOption(l.md.authorizePath, l.md.pushPath, l.md.pullPath), + pht_util.LoggerServerOption(l.options.Logger), + ) go func() { - if err := l.server.Serve(ln); err != nil { + if err := l.server.ListenAndServe(); err != nil { l.logger.Error(err) } }() @@ -102,15 +78,7 @@ func (l *phtListener) Init(md md.Metadata) (err error) { } func (l *phtListener) Accept() (conn net.Conn, err error) { - var ok bool - select { - case conn = <-l.cqueue: - case err, ok = <-l.errChan: - if !ok { - err = listener.ErrClosed - } - } - return + return l.server.Accept() } func (l *phtListener) Addr() net.Addr { @@ -118,165 +86,5 @@ func (l *phtListener) Addr() net.Addr { } func (l *phtListener) Close() (err error) { - select { - case <-l.errChan: - default: - err = l.server.Close() - l.errChan <- err - close(l.errChan) - } - return nil -} - -func (l *phtListener) handleAuthorize(w http.ResponseWriter, r *http.Request) { - if l.logger.IsLevelEnabled(logger.DebugLevel) { - dump, _ := httputil.DumpRequest(r, false) - l.logger.Debug(string(dump)) - } - - raddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr) - if raddr == nil { - raddr = &net.TCPAddr{} - } - - // connection id - cid := xid.New().String() - - c1, c2 := net.Pipe() - c := &conn{ - Conn: c1, - localAddr: l.addr, - remoteAddr: raddr, - } - - select { - case l.cqueue <- c: - default: - c.Close() - l.logger.Warnf("connection queue is full, client %s discarded", r.RemoteAddr) - w.WriteHeader(http.StatusTooManyRequests) - return - } - - w.Write([]byte(fmt.Sprintf("token=%s", cid))) - l.conns.Store(cid, c2) -} - -func (l *phtListener) handlePush(w http.ResponseWriter, r *http.Request) { - if l.logger.IsLevelEnabled(logger.DebugLevel) { - dump, _ := httputil.DumpRequest(r, false) - l.logger.Debug(string(dump)) - } - - if r.Method != http.MethodPost { - w.WriteHeader(http.StatusBadRequest) - return - } - - if err := r.ParseForm(); err != nil { - w.WriteHeader(http.StatusBadRequest) - return - } - - cid := r.Form.Get("token") - v, ok := l.conns.Load(cid) - if !ok { - w.WriteHeader(http.StatusForbidden) - return - } - conn := v.(net.Conn) - - br := bufio.NewReader(r.Body) - data, err := br.ReadString('\n') - if err != nil { - l.logger.Error(err) - conn.Close() - l.conns.Delete(cid) - w.WriteHeader(http.StatusBadRequest) - return - } - - data = strings.TrimSuffix(data, "\n") - if len(data) == 0 { - return - } - - b, err := base64.StdEncoding.DecodeString(data) - if err != nil { - l.logger.Error(err) - l.conns.Delete(cid) - conn.Close() - w.WriteHeader(http.StatusBadRequest) - return - } - - conn.SetWriteDeadline(time.Now().Add(30 * time.Second)) - defer conn.SetWriteDeadline(time.Time{}) - - if _, err := conn.Write(b); err != nil { - l.logger.Error(err) - l.conns.Delete(cid) - conn.Close() - w.WriteHeader(http.StatusGone) - } -} - -func (l *phtListener) handlePull(w http.ResponseWriter, r *http.Request) { - if l.logger.IsLevelEnabled(logger.DebugLevel) { - dump, _ := httputil.DumpRequest(r, false) - l.logger.Debug(string(dump)) - } - - if r.Method != http.MethodGet { - w.WriteHeader(http.StatusBadRequest) - return - } - - if err := r.ParseForm(); err != nil { - w.WriteHeader(http.StatusBadRequest) - return - } - - cid := r.Form.Get("token") - v, ok := l.conns.Load(cid) - if !ok { - w.WriteHeader(http.StatusForbidden) - return - } - - conn := v.(net.Conn) - - w.WriteHeader(http.StatusOK) - if fw, ok := w.(http.Flusher); ok { - fw.Flush() - } - - b := bufpool.Get(4096) - defer bufpool.Put(b) - - for { - conn.SetReadDeadline(time.Now().Add(10 * time.Second)) - n, err := conn.Read(*b) - if err != nil { - if !errors.Is(err, os.ErrDeadlineExceeded) { - l.logger.Error(err) - l.conns.Delete(cid) - conn.Close() - } else { - (*b)[0] = '\n' - w.Write((*b)[:1]) - } - return - } - - bw := bufio.NewWriter(w) - bw.WriteString(base64.StdEncoding.EncodeToString((*b)[:n])) - bw.WriteString("\n") - if err := bw.Flush(); err != nil { - return - } - if fw, ok := w.(http.Flusher); ok { - fw.Flush() - } - } + return l.server.Close() }