fix connection state in tunnel entrypoint
This commit is contained in:
parent
12ef82e41f
commit
1a776dc759
@ -106,7 +106,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
|
|||||||
}
|
}
|
||||||
|
|
||||||
if protocol == forward.ProtoHTTP {
|
if protocol == forward.ProtoHTTP {
|
||||||
h.handleHTTP(ctx, rw, conn.RemoteAddr(), log)
|
h.handleHTTP(ctx, xio.NewReadWriteCloser(rw, rw, conn), conn.RemoteAddr(), log)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -177,7 +177,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remoteAddr net.Addr, log logger.Logger) (err error) {
|
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriteCloser, remoteAddr net.Addr, log logger.Logger) (err error) {
|
||||||
br := bufio.NewReader(rw)
|
br := bufio.NewReader(rw)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@ -334,12 +334,18 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
log.Trace(string(dump))
|
log.Trace(string(dump))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if res.Close {
|
||||||
|
defer rw.Close()
|
||||||
|
}
|
||||||
|
|
||||||
if err := h.rewriteBody(res, bodyRewrites...); err != nil {
|
if err := h.rewriteBody(res, bodyRewrites...); err != nil {
|
||||||
|
rw.Close()
|
||||||
log.Errorf("rewrite body: %v", err)
|
log.Errorf("rewrite body: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = res.Write(rw); err != nil {
|
if err = res.Write(rw); err != nil {
|
||||||
|
rw.Close()
|
||||||
log.Errorf("write response from node %s(%s): %v", target.Name, target.Addr, err)
|
log.Errorf("write response from node %s(%s): %v", target.Name, target.Addr, err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -108,7 +108,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if protocol == forward.ProtoHTTP {
|
if protocol == forward.ProtoHTTP {
|
||||||
h.handleHTTP(ctx, rw, conn.RemoteAddr(), localAddr, log)
|
h.handleHTTP(ctx, xio.NewReadWriteCloser(rw, rw, conn), conn.RemoteAddr(), localAddr, log)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -178,7 +178,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remoteAddr net.Addr, localAddr net.Addr, log logger.Logger) (err error) {
|
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriteCloser, remoteAddr net.Addr, localAddr net.Addr, log logger.Logger) (err error) {
|
||||||
br := bufio.NewReader(rw)
|
br := bufio.NewReader(rw)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@ -335,12 +335,18 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
log.Trace(string(dump))
|
log.Trace(string(dump))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if res.Close {
|
||||||
|
defer rw.Close()
|
||||||
|
}
|
||||||
|
|
||||||
if err := h.rewriteBody(res, bodyRewrites...); err != nil {
|
if err := h.rewriteBody(res, bodyRewrites...); err != nil {
|
||||||
|
rw.Close()
|
||||||
log.Errorf("rewrite body: %v", err)
|
log.Errorf("rewrite body: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = res.Write(rw); err != nil {
|
if err = res.Write(rw); err != nil {
|
||||||
|
rw.Close()
|
||||||
log.Errorf("write response from node %s(%s): %v", target.Name, target.Addr, err)
|
log.Errorf("write response from node %s(%s): %v", target.Name, target.Addr, err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -279,6 +279,9 @@ func (h *httpHandler) handleProxy(rw, cc io.ReadWriter, req *http.Request, log l
|
|||||||
err := func() error {
|
err := func() error {
|
||||||
req, err := http.ReadRequest(bufio.NewReader(rw))
|
req, err := http.ReadRequest(bufio.NewReader(rw))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,6 +101,7 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn, opts ...handl
|
|||||||
|
|
||||||
if clientID := sc.ID(); clientID != "" {
|
if clientID := sc.ID(); clientID != "" {
|
||||||
ctx = ctxvalue.ContextWithClientID(ctx, ctxvalue.ClientID(clientID))
|
ctx = ctxvalue.ContextWithClientID(ctx, ctxvalue.ClientID(clientID))
|
||||||
|
log = log.WithFields(map[string]any{"user": clientID})
|
||||||
}
|
}
|
||||||
|
|
||||||
conn = sc
|
conn = sc
|
||||||
|
@ -123,13 +123,15 @@ func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error {
|
|||||||
timeout: 15 * time.Second,
|
timeout: 15 * time.Second,
|
||||||
log: log,
|
log: log,
|
||||||
}
|
}
|
||||||
cc, node, cid, err := d.Dial(ctx, "tcp", tunnelID.String())
|
c, node, cid, err := d.Dial(ctx, "tcp", tunnelID.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
return resp.Write(conn)
|
return resp.Write(conn)
|
||||||
}
|
}
|
||||||
log.Debugf("new connection to tunnel: %s, connector: %s", tunnelID, cid)
|
log.Debugf("new connection to tunnel: %s, connector: %s", tunnelID, cid)
|
||||||
|
|
||||||
|
cc = c
|
||||||
|
|
||||||
host := req.Host
|
host := req.Host
|
||||||
if h, _, _ := net.SplitHostPort(host); h == "" {
|
if h, _, _ := net.SplitHostPort(host); h == "" {
|
||||||
host = net.JoinHostPort(strings.Trim(host, "[]"), "80")
|
host = net.JoinHostPort(strings.Trim(host, "[]"), "80")
|
||||||
@ -149,17 +151,26 @@ func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error {
|
|||||||
Version: relay.Version1,
|
Version: relay.Version1,
|
||||||
Status: relay.StatusOK,
|
Status: relay.StatusOK,
|
||||||
Features: features,
|
Features: features,
|
||||||
}).WriteTo(cc)
|
}).WriteTo(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := req.Write(cc); err != nil {
|
// HTTP/1.0
|
||||||
cc.Close()
|
if req.ProtoMajor == 1 && req.ProtoMinor == 0 {
|
||||||
|
if strings.ToLower(req.Header.Get("Connection")) == "keep-alive" {
|
||||||
|
req.Header.Del("Connection")
|
||||||
|
} else {
|
||||||
|
req.Header.Set("Connection", "close")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := req.Write(c); err != nil {
|
||||||
|
c.Close()
|
||||||
log.Errorf("send request: %v", err)
|
log.Errorf("send request: %v", err)
|
||||||
return resp.Write(conn)
|
return resp.Write(conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Header.Get("Upgrade") == "websocket" {
|
if req.Header.Get("Upgrade") == "websocket" {
|
||||||
err := xnet.Transport(cc, xio.NewReadWriter(br, conn))
|
err := xnet.Transport(c, xio.NewReadWriter(br, conn))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = io.EOF
|
err = io.EOF
|
||||||
}
|
}
|
||||||
@ -167,7 +178,7 @@ func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer cc.Close()
|
defer c.Close()
|
||||||
|
|
||||||
t := time.Now()
|
t := time.Now()
|
||||||
log.Debugf("%s <-> %s", remoteAddr, host)
|
log.Debugf("%s <-> %s", remoteAddr, host)
|
||||||
@ -178,7 +189,7 @@ func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error {
|
|||||||
}).Debugf("%s >-< %s", remoteAddr, host)
|
}).Debugf("%s >-< %s", remoteAddr, host)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
res, err := http.ReadResponse(bufio.NewReader(cc), req)
|
res, err := http.ReadResponse(bufio.NewReader(c), req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("read response: %v", err)
|
log.Errorf("read response: %v", err)
|
||||||
resp.Write(conn)
|
resp.Write(conn)
|
||||||
@ -190,7 +201,21 @@ func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error {
|
|||||||
log.Trace(string(dump))
|
log.Trace(string(dump))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if res.Close {
|
||||||
|
defer conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTP/1.0
|
||||||
|
if req.ProtoMajor == 1 && req.ProtoMinor == 0 {
|
||||||
|
if !res.Close {
|
||||||
|
res.Header.Set("Connection", "keep-alive")
|
||||||
|
}
|
||||||
|
res.ProtoMajor = req.ProtoMajor
|
||||||
|
res.ProtoMinor = req.ProtoMinor
|
||||||
|
}
|
||||||
|
|
||||||
if err = res.Write(conn); err != nil {
|
if err = res.Write(conn); err != nil {
|
||||||
|
conn.Close()
|
||||||
log.Errorf("write response: %v", err)
|
log.Errorf("write response: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -264,6 +264,10 @@ func (h *tunnelHandler) Handle(ctx context.Context, conn net.Conn, opts ...handl
|
|||||||
|
|
||||||
// Close implements io.Closer interface.
|
// Close implements io.Closer interface.
|
||||||
func (h *tunnelHandler) Close() error {
|
func (h *tunnelHandler) Close() error {
|
||||||
|
if h.epSvc != nil {
|
||||||
|
h.epSvc.Close()
|
||||||
|
}
|
||||||
|
h.pool.Close()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,6 +66,14 @@ func (c *Connector) Session() *mux.Session {
|
|||||||
return c.s
|
return c.s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Connector) Close() error {
|
||||||
|
if c == nil || c.s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.s.Close()
|
||||||
|
}
|
||||||
|
|
||||||
type Tunnel struct {
|
type Tunnel struct {
|
||||||
node string
|
node string
|
||||||
id relay.TunnelID
|
id relay.TunnelID
|
||||||
@ -75,7 +83,7 @@ type Tunnel struct {
|
|||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
sd sd.SD
|
sd sd.SD
|
||||||
ttl time.Duration
|
ttl time.Duration
|
||||||
rw *selector.RandomWeighted[*Connector]
|
// rw *selector.RandomWeighted[*Connector]
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTunnel(node string, tid relay.TunnelID, ttl time.Duration) *Tunnel {
|
func NewTunnel(node string, tid relay.TunnelID, ttl time.Duration) *Tunnel {
|
||||||
@ -85,7 +93,7 @@ func NewTunnel(node string, tid relay.TunnelID, ttl time.Duration) *Tunnel {
|
|||||||
t: time.Now(),
|
t: time.Now(),
|
||||||
close: make(chan struct{}),
|
close: make(chan struct{}),
|
||||||
ttl: ttl,
|
ttl: ttl,
|
||||||
rw: selector.NewRandomWeighted[*Connector](),
|
// rw: selector.NewRandomWeighted[*Connector](),
|
||||||
}
|
}
|
||||||
if t.ttl <= 0 {
|
if t.ttl <= 0 {
|
||||||
t.ttl = defaultTTL
|
t.ttl = defaultTTL
|
||||||
@ -117,8 +125,14 @@ func (t *Tunnel) GetConnector(network string) *Connector {
|
|||||||
t.mu.RLock()
|
t.mu.RLock()
|
||||||
defer t.mu.RUnlock()
|
defer t.mu.RUnlock()
|
||||||
|
|
||||||
rw := t.rw
|
// rw := t.rw
|
||||||
rw.Reset()
|
// rw.Reset()
|
||||||
|
|
||||||
|
if len(t.connectors) == 1 {
|
||||||
|
return t.connectors[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
rw := selector.NewRandomWeighted[*Connector]()
|
||||||
|
|
||||||
found := false
|
found := false
|
||||||
for _, c := range t.connectors {
|
for _, c := range t.connectors {
|
||||||
@ -147,6 +161,22 @@ func (t *Tunnel) GetConnector(network string) *Connector {
|
|||||||
return rw.Next()
|
return rw.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Tunnel) Close() error {
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-t.close:
|
||||||
|
default:
|
||||||
|
for _, c := range t.connectors {
|
||||||
|
c.Close()
|
||||||
|
}
|
||||||
|
close(t.close)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Tunnel) CloseOnIdle() bool {
|
func (t *Tunnel) CloseOnIdle() bool {
|
||||||
t.mu.RLock()
|
t.mu.RLock()
|
||||||
defer t.mu.RUnlock()
|
defer t.mu.RUnlock()
|
||||||
@ -256,6 +286,22 @@ func (p *ConnectorPool) Get(network string, tid string) *Connector {
|
|||||||
return t.GetConnector(network)
|
return t.GetConnector(network)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *ConnectorPool) Close() error {
|
||||||
|
if p == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
|
||||||
|
for k, v := range p.tunnels {
|
||||||
|
v.Close()
|
||||||
|
delete(p.tunnels, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (p *ConnectorPool) closeIdles() {
|
func (p *ConnectorPool) closeIdles() {
|
||||||
ticker := time.NewTicker(1 * time.Hour)
|
ticker := time.NewTicker(1 * time.Hour)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
@ -245,7 +245,7 @@ func (ing *localIngress) GetRule(ctx context.Context, host string, opts ...ingre
|
|||||||
}
|
}
|
||||||
|
|
||||||
if ep != nil {
|
if ep != nil {
|
||||||
ing.options.logger.Debugf("ingress: %s -> %s", host, ep)
|
ing.options.logger.Debugf("ingress: %s -> %s:%s", host, ep.Hostname, ep.Endpoint)
|
||||||
}
|
}
|
||||||
|
|
||||||
return ep
|
return ep
|
||||||
|
@ -13,3 +13,17 @@ func NewReadWriter(r io.Reader, w io.Writer) io.ReadWriter {
|
|||||||
Writer: w,
|
Writer: w,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type readWriteCloser struct {
|
||||||
|
io.Reader
|
||||||
|
io.Writer
|
||||||
|
io.Closer
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewReadWriteCloser(r io.Reader, w io.Writer, c io.Closer) io.ReadWriteCloser {
|
||||||
|
return &readWriteCloser{
|
||||||
|
Reader: r,
|
||||||
|
Writer: w,
|
||||||
|
Closer: c,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user