remove nodelay option for tunnel

This commit is contained in:
ginuerzh 2023-11-16 20:36:17 +08:00
parent 9584bdbf4c
commit f5a20fd0fc
5 changed files with 35 additions and 96 deletions

View File

@ -13,6 +13,7 @@ import (
"github.com/go-gost/core/common/bufpool" "github.com/go-gost/core/common/bufpool"
mdata "github.com/go-gost/core/metadata" mdata "github.com/go-gost/core/metadata"
"github.com/go-gost/relay" "github.com/go-gost/relay"
xrelay "github.com/go-gost/x/internal/util/relay"
) )
type tcpConn struct { type tcpConn struct {
@ -129,7 +130,7 @@ func readResponse(r io.Reader) (err error) {
} }
if resp.Status != relay.StatusOK { if resp.Status != relay.StatusOK {
err = fmt.Errorf("status %d", resp.Status) err = fmt.Errorf("%d %s", resp.Status, xrelay.StatusText(resp.Status))
return return
} }
return nil return nil
@ -223,16 +224,3 @@ func (c *bindUDPConn) RemoteAddr() net.Addr {
func (c *bindUDPConn) Metadata() mdata.Metadata { func (c *bindUDPConn) Metadata() mdata.Metadata {
return c.md return c.md
} }
type bindAddr struct {
network string
addr string
}
func (p *bindAddr) Network() string {
return p.network
}
func (p *bindAddr) String() string {
return p.addr
}

View File

@ -1,67 +1,24 @@
package tunnel package tunnel
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"math" "math"
"net" "net"
"sync"
"github.com/go-gost/core/common/bufpool" "github.com/go-gost/core/common/bufpool"
mdata "github.com/go-gost/core/metadata" mdata "github.com/go-gost/core/metadata"
"github.com/go-gost/relay" "github.com/go-gost/relay"
xrelay "github.com/go-gost/x/internal/util/relay"
) )
type tcpConn struct {
net.Conn
wbuf *bytes.Buffer
once sync.Once
}
func (c *tcpConn) Read(b []byte) (n int, err error) {
c.once.Do(func() {
if c.wbuf != nil {
err = readResponse(c.Conn)
}
})
if err != nil {
return
}
return c.Conn.Read(b)
}
func (c *tcpConn) Write(b []byte) (n int, err error) {
n = len(b) // force byte length consistent
if c.wbuf != nil && c.wbuf.Len() > 0 {
c.wbuf.Write(b) // append the data to the cached header
_, err = c.Conn.Write(c.wbuf.Bytes())
c.wbuf.Reset()
return
}
_, err = c.Conn.Write(b)
return
}
type udpConn struct { type udpConn struct {
net.Conn net.Conn
wbuf *bytes.Buffer
once sync.Once
} }
func (c *udpConn) Read(b []byte) (n int, err error) { func (c *udpConn) Read(b []byte) (n int, err error) {
c.once.Do(func() {
if c.wbuf != nil {
err = readResponse(c.Conn)
}
})
if err != nil {
return
}
var bb [2]byte var bb [2]byte
_, err = io.ReadFull(c.Conn, bb[:]) _, err = io.ReadFull(c.Conn, bb[:])
if err != nil { if err != nil {
@ -88,14 +45,6 @@ func (c *udpConn) Write(b []byte) (n int, err error) {
} }
n = len(b) n = len(b)
if c.wbuf != nil && c.wbuf.Len() > 0 {
var bb [2]byte
binary.BigEndian.PutUint16(bb[:], uint16(len(b)))
c.wbuf.Write(bb[:])
c.wbuf.Write(b) // append the data to the cached header
_, err = c.wbuf.WriteTo(c.Conn)
return
}
var bb [2]byte var bb [2]byte
binary.BigEndian.PutUint16(bb[:], uint16(len(b))) binary.BigEndian.PutUint16(bb[:], uint16(len(b)))
@ -119,7 +68,7 @@ func readResponse(r io.Reader) (err error) {
} }
if resp.Status != relay.StatusOK { if resp.Status != relay.StatusOK {
err = fmt.Errorf("status %d", resp.Status) err = fmt.Errorf("%d %s", resp.Status, xrelay.StatusText(resp.Status))
return return
} }
return nil return nil

View File

@ -1,7 +1,6 @@
package tunnel package tunnel
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
"net" "net"
@ -90,39 +89,20 @@ func (c *tunnelConnector) Connect(ctx context.Context, conn net.Conn, network, a
ID: c.md.tunnelID.ID(), ID: c.md.tunnelID.ID(),
}) })
if c.md.noDelay { if _, err := req.WriteTo(conn); err != nil {
if _, err := req.WriteTo(conn); err != nil { return nil, err
return nil, err }
} // drain the response
// drain the response if err := readResponse(conn); err != nil {
if err := readResponse(conn); err != nil { return nil, err
return nil, err
}
} }
switch network { switch network {
case "tcp", "tcp4", "tcp6": case "tcp", "tcp4", "tcp6":
if !c.md.noDelay {
cc := &tcpConn{
Conn: conn,
wbuf: &bytes.Buffer{},
}
if _, err := req.WriteTo(cc.wbuf); err != nil {
return nil, err
}
conn = cc
}
case "udp", "udp4", "udp6": case "udp", "udp4", "udp6":
cc := &udpConn{ conn = &udpConn{
Conn: conn, Conn: conn,
} }
if !c.md.noDelay {
cc.wbuf = &bytes.Buffer{}
if _, err := req.WriteTo(cc.wbuf); err != nil {
return nil, err
}
}
conn = cc
default: default:
err := fmt.Errorf("network %s is unsupported", network) err := fmt.Errorf("network %s is unsupported", network)
log.Error(err) log.Error(err)

View File

@ -18,13 +18,11 @@ var (
type metadata struct { type metadata struct {
connectTimeout time.Duration connectTimeout time.Duration
tunnelID relay.TunnelID tunnelID relay.TunnelID
noDelay bool
muxCfg *mux.Config muxCfg *mux.Config
} }
func (c *tunnelConnector) parseMetadata(md mdata.Metadata) (err error) { func (c *tunnelConnector) parseMetadata(md mdata.Metadata) (err error) {
c.md.connectTimeout = mdutil.GetDuration(md, "connectTimeout") c.md.connectTimeout = mdutil.GetDuration(md, "connectTimeout")
c.md.noDelay = mdutil.GetBool(md, "nodelay")
if s := mdutil.GetString(md, "tunnelID", "tunnel.id"); s != "" { if s := mdutil.GetString(md, "tunnelID", "tunnel.id"); s != "" {
uuid, err := uuid.Parse(s) uuid, err := uuid.Parse(s)

View File

@ -6,8 +6,32 @@ import (
"github.com/go-gost/core/common/bufpool" "github.com/go-gost/core/common/bufpool"
"github.com/go-gost/gosocks5" "github.com/go-gost/gosocks5"
"github.com/go-gost/relay"
) )
func StatusText(code uint8) string {
switch code {
case relay.StatusBadRequest:
return "Bad Request"
case relay.StatusForbidden:
return "Forbidden"
case relay.StatusHostUnreachable:
return "Host Unreachable"
case relay.StatusInternalServerError:
return "Internal Server Error"
case relay.StatusNetworkUnreachable:
return "Network Unreachable"
case relay.StatusServiceUnavailable:
return "Service Unavailable"
case relay.StatusTimeout:
return "Timeout"
case relay.StatusUnauthorized:
return "Unauthorized"
default:
return ""
}
}
type udpTunConn struct { type udpTunConn struct {
net.Conn net.Conn
taddr net.Addr taddr net.Addr