fix pht connection
This commit is contained in:
@ -6,6 +6,7 @@ import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
@ -23,17 +24,21 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
defaultBacklog = 128
|
||||
defaultBacklog = 128
|
||||
defaultReadBufferSize = 32 * 1024
|
||||
defaultReadTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
type serverOptions struct {
|
||||
authorizePath string
|
||||
pushPath string
|
||||
pullPath string
|
||||
backlog int
|
||||
tlsEnabled bool
|
||||
tlsConfig *tls.Config
|
||||
logger logger.Logger
|
||||
authorizePath string
|
||||
pushPath string
|
||||
pullPath string
|
||||
backlog int
|
||||
tlsEnabled bool
|
||||
tlsConfig *tls.Config
|
||||
readBufferSize int
|
||||
readTimeout time.Duration
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
type ServerOption func(opts *serverOptions)
|
||||
@ -64,6 +69,18 @@ func EnableTLSServerOption(enable bool) ServerOption {
|
||||
}
|
||||
}
|
||||
|
||||
func ReadBufferSizeServerOption(n int) ServerOption {
|
||||
return func(opts *serverOptions) {
|
||||
opts.readBufferSize = n
|
||||
}
|
||||
}
|
||||
|
||||
func ReadTimeoutServerOption(timeout time.Duration) ServerOption {
|
||||
return func(opts *serverOptions) {
|
||||
opts.readTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
func LoggerServerOption(logger logger.Logger) ServerOption {
|
||||
return func(opts *serverOptions) {
|
||||
opts.logger = logger
|
||||
@ -90,6 +107,12 @@ func NewServer(addr string, opts ...ServerOption) *Server {
|
||||
if options.backlog <= 0 {
|
||||
options.backlog = defaultBacklog
|
||||
}
|
||||
if options.readBufferSize <= 0 {
|
||||
options.readBufferSize = defaultReadBufferSize
|
||||
}
|
||||
if options.readTimeout <= 0 {
|
||||
options.readTimeout = defaultReadTimeout
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
httpServer: &http.Server{
|
||||
@ -118,6 +141,12 @@ func NewHTTP3Server(addr string, quicConfig *quic.Config, opts ...ServerOption)
|
||||
if options.backlog <= 0 {
|
||||
options.backlog = defaultBacklog
|
||||
}
|
||||
if options.readBufferSize <= 0 {
|
||||
options.readBufferSize = defaultReadBufferSize
|
||||
}
|
||||
if options.readTimeout <= 0 {
|
||||
options.readTimeout = defaultReadTimeout
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
http3Server: &http3.Server{
|
||||
@ -200,6 +229,8 @@ func (s *Server) handleAuthorize(w http.ResponseWriter, r *http.Request) {
|
||||
if s.options.logger.IsLevelEnabled(logger.TraceLevel) {
|
||||
dump, _ := httputil.DumpRequest(r, false)
|
||||
s.options.logger.Trace(string(dump))
|
||||
} else if s.options.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
s.options.logger.Debugf("%s %s", r.Method, r.RequestURI)
|
||||
}
|
||||
|
||||
raddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
|
||||
@ -234,6 +265,8 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) {
|
||||
if s.options.logger.IsLevelEnabled(logger.TraceLevel) {
|
||||
dump, _ := httputil.DumpRequest(r, false)
|
||||
s.options.logger.Trace(string(dump))
|
||||
} else if s.options.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
s.options.logger.Debugf("%s %s", r.Method, r.RequestURI)
|
||||
}
|
||||
|
||||
if r.Method != http.MethodPost {
|
||||
@ -257,10 +290,12 @@ func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) {
|
||||
br := bufio.NewReader(r.Body)
|
||||
data, err := br.ReadString('\n')
|
||||
if err != nil {
|
||||
s.options.logger.Error(err)
|
||||
if err != io.EOF {
|
||||
s.options.logger.Error(err)
|
||||
w.WriteHeader(http.StatusPartialContent)
|
||||
}
|
||||
conn.Close()
|
||||
s.conns.Delete(cid)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
@ -293,6 +328,8 @@ func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) {
|
||||
if s.options.logger.IsLevelEnabled(logger.TraceLevel) {
|
||||
dump, _ := httputil.DumpRequest(r, false)
|
||||
s.options.logger.Trace(string(dump))
|
||||
} else if s.options.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
s.options.logger.Debugf("%s %s", r.Method, r.RequestURI)
|
||||
}
|
||||
|
||||
if r.Method != http.MethodGet {
|
||||
@ -319,32 +356,37 @@ func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) {
|
||||
fw.Flush()
|
||||
}
|
||||
|
||||
b := bufpool.Get(4096)
|
||||
b := bufpool.Get(s.options.readBufferSize)
|
||||
defer bufpool.Put(b)
|
||||
|
||||
for {
|
||||
conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
conn.SetReadDeadline(time.Now().Add(s.options.readTimeout))
|
||||
n, err := conn.Read(*b)
|
||||
if n > 0 {
|
||||
bw := bufio.NewWriter(w)
|
||||
bw.WriteString(base64.StdEncoding.EncodeToString((*b)[:n]))
|
||||
bw.WriteString("\n")
|
||||
if err := bw.Flush(); err != nil {
|
||||
return
|
||||
}
|
||||
if fw, ok := w.(http.Flusher); ok {
|
||||
fw.Flush()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
s.options.logger.Error(err)
|
||||
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
(*b)[0] = '\n' // no data
|
||||
w.Write((*b)[:1])
|
||||
} else if errors.Is(err, io.EOF) {
|
||||
// server connection closed
|
||||
} else {
|
||||
if !errors.Is(err, io.ErrClosedPipe) {
|
||||
s.options.logger.Error(err)
|
||||
}
|
||||
s.conns.Delete(cid)
|
||||
conn.Close()
|
||||
} else {
|
||||
(*b)[0] = '\n'
|
||||
w.Write((*b)[:1])
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
bw := bufio.NewWriter(w)
|
||||
bw.WriteString(base64.StdEncoding.EncodeToString((*b)[:n]))
|
||||
bw.WriteString("\n")
|
||||
if err := bw.Flush(); err != nil {
|
||||
return
|
||||
}
|
||||
if fw, ok := w.(http.Flusher); ok {
|
||||
fw.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user