fix pht connection
This commit is contained in:
parent
669e80b780
commit
dc2fe32a2a
@ -77,6 +77,8 @@ func (c *Client) authorize(ctx context.Context, addr string) (token string, err
|
|||||||
if c.Logger.IsLevelEnabled(logger.TraceLevel) {
|
if c.Logger.IsLevelEnabled(logger.TraceLevel) {
|
||||||
dump, _ := httputil.DumpRequest(r, false)
|
dump, _ := httputil.DumpRequest(r, false)
|
||||||
c.Logger.Trace(string(dump))
|
c.Logger.Trace(string(dump))
|
||||||
|
} else if c.Logger.IsLevelEnabled(logger.DebugLevel) {
|
||||||
|
c.Logger.Debugf("%s %s", r.Method, r.URL)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := c.Client.Do(r)
|
resp, err := c.Client.Do(r)
|
||||||
|
@ -5,9 +5,11 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-gost/core/logger"
|
"github.com/go-gost/core/logger"
|
||||||
@ -20,6 +22,7 @@ type clientConn struct {
|
|||||||
buf []byte
|
buf []byte
|
||||||
rxc chan []byte
|
rxc chan []byte
|
||||||
closed chan struct{}
|
closed chan struct{}
|
||||||
|
mu sync.Mutex
|
||||||
localAddr net.Addr
|
localAddr net.Addr
|
||||||
remoteAddr net.Addr
|
remoteAddr net.Addr
|
||||||
logger logger.Logger
|
logger logger.Logger
|
||||||
@ -30,7 +33,7 @@ func (c *clientConn) Read(b []byte) (n int, err error) {
|
|||||||
select {
|
select {
|
||||||
case c.buf = <-c.rxc:
|
case c.buf = <-c.rxc:
|
||||||
case <-c.closed:
|
case <-c.closed:
|
||||||
err = net.ErrClosed
|
err = io.ErrClosedPipe
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -45,21 +48,35 @@ func (c *clientConn) Write(b []byte) (n int, err error) {
|
|||||||
if len(b) == 0 {
|
if len(b) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
return c.write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientConn) write(b []byte) (n int, err error) {
|
||||||
|
if c.isClosed() {
|
||||||
|
err = io.ErrClosedPipe
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var r io.Reader
|
||||||
|
if len(b) > 0 {
|
||||||
buf := bytes.NewBufferString(base64.StdEncoding.EncodeToString(b))
|
buf := bytes.NewBufferString(base64.StdEncoding.EncodeToString(b))
|
||||||
buf.WriteByte('\n')
|
buf.WriteByte('\n')
|
||||||
|
r = buf
|
||||||
|
}
|
||||||
|
|
||||||
r, err := http.NewRequest(http.MethodPost, c.pushURL, buf)
|
req, err := http.NewRequest(http.MethodPost, c.pushURL, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.logger.IsLevelEnabled(logger.TraceLevel) {
|
if c.logger.IsLevelEnabled(logger.TraceLevel) {
|
||||||
dump, _ := httputil.DumpRequest(r, false)
|
dump, _ := httputil.DumpRequest(req, false)
|
||||||
c.logger.Trace(string(dump))
|
c.logger.Trace(string(dump))
|
||||||
|
} else if c.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||||
|
c.logger.Debugf("%s %s", req.Method, req.URL)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := c.client.Do(r)
|
resp, err := c.client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -80,9 +97,12 @@ func (c *clientConn) Write(b []byte) (n int, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientConn) readLoop() {
|
func (c *clientConn) readLoop() {
|
||||||
defer c.Close()
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
if c.isClosed() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
done := true
|
||||||
err := func() error {
|
err := func() error {
|
||||||
r, err := http.NewRequest(http.MethodGet, c.pullURL, nil)
|
r, err := http.NewRequest(http.MethodGet, c.pullURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -92,6 +112,8 @@ func (c *clientConn) readLoop() {
|
|||||||
if c.logger.IsLevelEnabled(logger.TraceLevel) {
|
if c.logger.IsLevelEnabled(logger.TraceLevel) {
|
||||||
dump, _ := httputil.DumpRequest(r, false)
|
dump, _ := httputil.DumpRequest(r, false)
|
||||||
c.logger.Trace(string(dump))
|
c.logger.Trace(string(dump))
|
||||||
|
} else if c.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||||
|
c.logger.Debugf("%s %s", r.Method, r.URL)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := c.client.Do(r)
|
resp, err := c.client.Do(r)
|
||||||
@ -111,6 +133,11 @@ func (c *clientConn) readLoop() {
|
|||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
|
done = false
|
||||||
|
if scanner.Text() == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
b, err := base64.StdEncoding.DecodeString(scanner.Text())
|
b, err := base64.StdEncoding.DecodeString(scanner.Text())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -121,12 +148,15 @@ func (c *clientConn) readLoop() {
|
|||||||
return net.ErrClosed
|
return net.ErrClosed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return scanner.Err()
|
return scanner.Err()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Error(err)
|
c.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if done { // server connection closed
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -141,12 +171,32 @@ func (c *clientConn) RemoteAddr() net.Addr {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientConn) Close() error {
|
func (c *clientConn) Close() error {
|
||||||
|
c.mu.Lock()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-c.closed:
|
case <-c.closed:
|
||||||
|
c.mu.Unlock()
|
||||||
|
return nil
|
||||||
default:
|
default:
|
||||||
close(c.closed)
|
close(c.closed)
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
_, err := c.write(nil)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientConn) isClosed() bool {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
select {
|
||||||
|
case <-c.closed:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientConn) SetReadDeadline(t time.Time) error {
|
func (c *clientConn) SetReadDeadline(t time.Time) error {
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
@ -24,6 +25,8 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
defaultBacklog = 128
|
defaultBacklog = 128
|
||||||
|
defaultReadBufferSize = 32 * 1024
|
||||||
|
defaultReadTimeout = 10 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
type serverOptions struct {
|
type serverOptions struct {
|
||||||
@ -33,6 +36,8 @@ type serverOptions struct {
|
|||||||
backlog int
|
backlog int
|
||||||
tlsEnabled bool
|
tlsEnabled bool
|
||||||
tlsConfig *tls.Config
|
tlsConfig *tls.Config
|
||||||
|
readBufferSize int
|
||||||
|
readTimeout time.Duration
|
||||||
logger logger.Logger
|
logger logger.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -64,6 +69,18 @@ func EnableTLSServerOption(enable bool) ServerOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ReadBufferSizeServerOption(n int) ServerOption {
|
||||||
|
return func(opts *serverOptions) {
|
||||||
|
opts.readBufferSize = n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReadTimeoutServerOption(timeout time.Duration) ServerOption {
|
||||||
|
return func(opts *serverOptions) {
|
||||||
|
opts.readTimeout = timeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func LoggerServerOption(logger logger.Logger) ServerOption {
|
func LoggerServerOption(logger logger.Logger) ServerOption {
|
||||||
return func(opts *serverOptions) {
|
return func(opts *serverOptions) {
|
||||||
opts.logger = logger
|
opts.logger = logger
|
||||||
@ -90,6 +107,12 @@ func NewServer(addr string, opts ...ServerOption) *Server {
|
|||||||
if options.backlog <= 0 {
|
if options.backlog <= 0 {
|
||||||
options.backlog = defaultBacklog
|
options.backlog = defaultBacklog
|
||||||
}
|
}
|
||||||
|
if options.readBufferSize <= 0 {
|
||||||
|
options.readBufferSize = defaultReadBufferSize
|
||||||
|
}
|
||||||
|
if options.readTimeout <= 0 {
|
||||||
|
options.readTimeout = defaultReadTimeout
|
||||||
|
}
|
||||||
|
|
||||||
s := &Server{
|
s := &Server{
|
||||||
httpServer: &http.Server{
|
httpServer: &http.Server{
|
||||||
@ -118,6 +141,12 @@ func NewHTTP3Server(addr string, quicConfig *quic.Config, opts ...ServerOption)
|
|||||||
if options.backlog <= 0 {
|
if options.backlog <= 0 {
|
||||||
options.backlog = defaultBacklog
|
options.backlog = defaultBacklog
|
||||||
}
|
}
|
||||||
|
if options.readBufferSize <= 0 {
|
||||||
|
options.readBufferSize = defaultReadBufferSize
|
||||||
|
}
|
||||||
|
if options.readTimeout <= 0 {
|
||||||
|
options.readTimeout = defaultReadTimeout
|
||||||
|
}
|
||||||
|
|
||||||
s := &Server{
|
s := &Server{
|
||||||
http3Server: &http3.Server{
|
http3Server: &http3.Server{
|
||||||
@ -200,6 +229,8 @@ func (s *Server) handleAuthorize(w http.ResponseWriter, r *http.Request) {
|
|||||||
if s.options.logger.IsLevelEnabled(logger.TraceLevel) {
|
if s.options.logger.IsLevelEnabled(logger.TraceLevel) {
|
||||||
dump, _ := httputil.DumpRequest(r, false)
|
dump, _ := httputil.DumpRequest(r, false)
|
||||||
s.options.logger.Trace(string(dump))
|
s.options.logger.Trace(string(dump))
|
||||||
|
} else if s.options.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||||
|
s.options.logger.Debugf("%s %s", r.Method, r.RequestURI)
|
||||||
}
|
}
|
||||||
|
|
||||||
raddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
|
raddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
|
||||||
@ -234,6 +265,8 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) {
|
|||||||
if s.options.logger.IsLevelEnabled(logger.TraceLevel) {
|
if s.options.logger.IsLevelEnabled(logger.TraceLevel) {
|
||||||
dump, _ := httputil.DumpRequest(r, false)
|
dump, _ := httputil.DumpRequest(r, false)
|
||||||
s.options.logger.Trace(string(dump))
|
s.options.logger.Trace(string(dump))
|
||||||
|
} else if s.options.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||||
|
s.options.logger.Debugf("%s %s", r.Method, r.RequestURI)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Method != http.MethodPost {
|
if r.Method != http.MethodPost {
|
||||||
@ -257,10 +290,12 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) {
|
|||||||
br := bufio.NewReader(r.Body)
|
br := bufio.NewReader(r.Body)
|
||||||
data, err := br.ReadString('\n')
|
data, err := br.ReadString('\n')
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if err != io.EOF {
|
||||||
s.options.logger.Error(err)
|
s.options.logger.Error(err)
|
||||||
|
w.WriteHeader(http.StatusPartialContent)
|
||||||
|
}
|
||||||
conn.Close()
|
conn.Close()
|
||||||
s.conns.Delete(cid)
|
s.conns.Delete(cid)
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -293,6 +328,8 @@ func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) {
|
|||||||
if s.options.logger.IsLevelEnabled(logger.TraceLevel) {
|
if s.options.logger.IsLevelEnabled(logger.TraceLevel) {
|
||||||
dump, _ := httputil.DumpRequest(r, false)
|
dump, _ := httputil.DumpRequest(r, false)
|
||||||
s.options.logger.Trace(string(dump))
|
s.options.logger.Trace(string(dump))
|
||||||
|
} else if s.options.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||||
|
s.options.logger.Debugf("%s %s", r.Method, r.RequestURI)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Method != http.MethodGet {
|
if r.Method != http.MethodGet {
|
||||||
@ -319,24 +356,13 @@ func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) {
|
|||||||
fw.Flush()
|
fw.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
b := bufpool.Get(4096)
|
b := bufpool.Get(s.options.readBufferSize)
|
||||||
defer bufpool.Put(b)
|
defer bufpool.Put(b)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
conn.SetReadDeadline(time.Now().Add(s.options.readTimeout))
|
||||||
n, err := conn.Read(*b)
|
n, err := conn.Read(*b)
|
||||||
if err != nil {
|
if n > 0 {
|
||||||
if !errors.Is(err, os.ErrDeadlineExceeded) {
|
|
||||||
s.options.logger.Error(err)
|
|
||||||
s.conns.Delete(cid)
|
|
||||||
conn.Close()
|
|
||||||
} else {
|
|
||||||
(*b)[0] = '\n'
|
|
||||||
w.Write((*b)[:1])
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
bw := bufio.NewWriter(w)
|
bw := bufio.NewWriter(w)
|
||||||
bw.WriteString(base64.StdEncoding.EncodeToString((*b)[:n]))
|
bw.WriteString(base64.StdEncoding.EncodeToString((*b)[:n]))
|
||||||
bw.WriteString("\n")
|
bw.WriteString("\n")
|
||||||
@ -347,4 +373,20 @@ func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) {
|
|||||||
fw.Flush()
|
fw.Flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||||
|
(*b)[0] = '\n' // no data
|
||||||
|
w.Write((*b)[:1])
|
||||||
|
} else if errors.Is(err, io.EOF) {
|
||||||
|
// server connection closed
|
||||||
|
} else {
|
||||||
|
if !errors.Is(err, io.ErrClosedPipe) {
|
||||||
|
s.options.logger.Error(err)
|
||||||
|
}
|
||||||
|
s.conns.Delete(cid)
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user