update forward handler

This commit is contained in:
ginuerzh
2023-10-16 23:16:47 +08:00
parent 5ab729b166
commit 5dfbb59f8a
17 changed files with 253 additions and 174 deletions

View File

@ -17,13 +17,15 @@ import (
type tcpConn struct {
net.Conn
wbuf bytes.Buffer
wbuf *bytes.Buffer
once sync.Once
}
func (c *tcpConn) Read(b []byte) (n int, err error) {
c.once.Do(func() {
err = readResponse(c.Conn)
if c.wbuf != nil {
err = readResponse(c.Conn)
}
})
if err != nil {
@ -34,7 +36,7 @@ func (c *tcpConn) Read(b []byte) (n int, err error) {
func (c *tcpConn) Write(b []byte) (n int, err error) {
n = len(b) // force byte length consistent
if c.wbuf.Len() > 0 {
if c.wbuf != nil && c.wbuf.Len() > 0 {
c.wbuf.Write(b) // append the data to the cached header
_, err = c.Conn.Write(c.wbuf.Bytes())
c.wbuf.Reset()
@ -46,13 +48,15 @@ func (c *tcpConn) Write(b []byte) (n int, err error) {
type udpConn struct {
net.Conn
wbuf bytes.Buffer
wbuf *bytes.Buffer
once sync.Once
}
func (c *udpConn) Read(b []byte) (n int, err error) {
c.once.Do(func() {
err = readResponse(c.Conn)
if c.wbuf != nil {
err = readResponse(c.Conn)
}
})
if err != nil {
return
@ -84,7 +88,7 @@ func (c *udpConn) Write(b []byte) (n int, err error) {
}
n = len(b)
if c.wbuf.Len() > 0 {
if c.wbuf != nil && c.wbuf.Len() > 0 {
var bb [2]byte
binary.BigEndian.PutUint16(bb[:], uint16(len(b)))
c.wbuf.Write(bb[:])

View File

@ -1,6 +1,7 @@
package relay
import (
"bytes"
"context"
"fmt"
"net"
@ -121,8 +122,9 @@ func (c *relayConnector) Connect(ctx context.Context, conn net.Conn, network, ad
if !c.md.noDelay {
cc := &tcpConn{
Conn: conn,
wbuf: &bytes.Buffer{},
}
if _, err := req.WriteTo(&cc.wbuf); err != nil {
if _, err := req.WriteTo(cc.wbuf); err != nil {
return nil, err
}
conn = cc
@ -132,7 +134,8 @@ func (c *relayConnector) Connect(ctx context.Context, conn net.Conn, network, ad
Conn: conn,
}
if !c.md.noDelay {
if _, err := req.WriteTo(&cc.wbuf); err != nil {
cc.wbuf = &bytes.Buffer{}
if _, err := req.WriteTo(cc.wbuf); err != nil {
return nil, err
}
}

View File

@ -53,13 +53,16 @@ func (c *tunnelConnector) initTunnel(conn net.Conn, network, address string) (ad
}
af := &relay.AddrFeature{}
af.ParseFrom(address)
af.ParseFrom(conn.LocalAddr().String()) // src address
req.Features = append(req.Features, af)
req.Features = append(req.Features, af,
&relay.TunnelFeature{
ID: c.md.tunnelID.ID(),
},
)
af = &relay.AddrFeature{}
af.ParseFrom(address)
req.Features = append(req.Features, af) // dst address
req.Features = append(req.Features, &relay.TunnelFeature{
ID: c.md.tunnelID.ID(),
})
if _, err = req.WriteTo(conn); err != nil {
return
}

View File

@ -17,13 +17,15 @@ import (
type tcpConn struct {
net.Conn
wbuf bytes.Buffer
wbuf *bytes.Buffer
once sync.Once
}
func (c *tcpConn) Read(b []byte) (n int, err error) {
c.once.Do(func() {
err = readResponse(c.Conn)
if c.wbuf != nil {
err = readResponse(c.Conn)
}
})
if err != nil {
@ -34,7 +36,7 @@ func (c *tcpConn) Read(b []byte) (n int, err error) {
func (c *tcpConn) Write(b []byte) (n int, err error) {
n = len(b) // force byte length consistent
if c.wbuf.Len() > 0 {
if c.wbuf != nil && c.wbuf.Len() > 0 {
c.wbuf.Write(b) // append the data to the cached header
_, err = c.Conn.Write(c.wbuf.Bytes())
c.wbuf.Reset()
@ -46,13 +48,15 @@ func (c *tcpConn) Write(b []byte) (n int, err error) {
type udpConn struct {
net.Conn
wbuf bytes.Buffer
wbuf *bytes.Buffer
once sync.Once
}
func (c *udpConn) Read(b []byte) (n int, err error) {
c.once.Do(func() {
err = readResponse(c.Conn)
if c.wbuf != nil {
err = readResponse(c.Conn)
}
})
if err != nil {
return
@ -84,7 +88,7 @@ func (c *udpConn) Write(b []byte) (n int, err error) {
}
n = len(b)
if c.wbuf.Len() > 0 {
if c.wbuf != nil && c.wbuf.Len() > 0 {
var bb [2]byte
binary.BigEndian.PutUint16(bb[:], uint16(len(b)))
c.wbuf.Write(bb[:])

View File

@ -1,6 +1,7 @@
package tunnel
import (
"bytes"
"context"
"fmt"
"net"
@ -9,6 +10,7 @@ import (
"github.com/go-gost/core/connector"
md "github.com/go-gost/core/metadata"
"github.com/go-gost/relay"
auth_util "github.com/go-gost/x/internal/util/auth"
"github.com/go-gost/x/registry"
)
@ -55,6 +57,14 @@ func (c *tunnelConnector) Connect(ctx context.Context, conn net.Conn, network, a
Cmd: relay.CmdConnect,
}
switch network {
case "udp", "udp4", "udp6":
req.Cmd |= relay.FUDP
req.Features = append(req.Features, &relay.NetworkFeature{
Network: relay.NetworkUDP,
})
}
if c.options.Auth != nil {
pwd, _ := c.options.Auth.Password()
req.Features = append(req.Features, &relay.UserAuthFeature{
@ -63,33 +73,54 @@ func (c *tunnelConnector) Connect(ctx context.Context, conn net.Conn, network, a
})
}
if address != "" {
af := &relay.AddrFeature{}
if err := af.ParseFrom(address); err != nil {
return nil, err
}
req.Features = append(req.Features, af)
srcAddr := conn.LocalAddr().String()
if v := auth_util.ClientAddrFromContext(ctx); v != "" {
srcAddr = string(v)
}
af := &relay.AddrFeature{}
af.ParseFrom(srcAddr)
req.Features = append(req.Features, af) // src address
af = &relay.AddrFeature{}
af.ParseFrom(address)
req.Features = append(req.Features, af) // dst address
req.Features = append(req.Features, &relay.TunnelFeature{
ID: c.md.tunnelID.ID(),
})
switch network {
case "tcp", "tcp4", "tcp6", "unix", "serial":
cc := &tcpConn{
Conn: conn,
}
if _, err := req.WriteTo(&cc.wbuf); err != nil {
if c.md.noDelay {
if _, err := req.WriteTo(conn); err != nil {
return nil, err
}
conn = cc
// drain the response
if err := readResponse(conn); err != nil {
return nil, err
}
}
switch network {
case "tcp", "tcp4", "tcp6":
if !c.md.noDelay {
cc := &tcpConn{
Conn: conn,
wbuf: &bytes.Buffer{},
}
if _, err := req.WriteTo(cc.wbuf); err != nil {
return nil, err
}
conn = cc
}
case "udp", "udp4", "udp6":
cc := &udpConn{
Conn: conn,
}
if _, err := req.WriteTo(&cc.wbuf); err != nil {
return nil, err
if !c.md.noDelay {
cc.wbuf = &bytes.Buffer{}
if _, err := req.WriteTo(cc.wbuf); err != nil {
return nil, err
}
}
conn = cc
default:

View File

@ -17,15 +17,16 @@ var (
type metadata struct {
connectTimeout time.Duration
tunnelID relay.TunnelID
noDelay bool
}
func (c *tunnelConnector) parseMetadata(md mdata.Metadata) (err error) {
const (
connectTimeout = "connectTimeout"
noDelay = "nodelay"
)
c.md.connectTimeout = mdutil.GetDuration(md, connectTimeout)
c.md.noDelay = mdutil.GetBool(md, "nodelay")
if s := mdutil.GetString(md, "tunnelID", "tunnel.id"); s != "" {
uuid, err := uuid.Parse(s)