add bind for relay

This commit is contained in:
ginuerzh
2021-11-25 17:29:54 +08:00
parent 98ef6c7492
commit 6daf0a4d0f
29 changed files with 600 additions and 352 deletions

81
pkg/handler/relay/conn.go Normal file
View File

@ -0,0 +1,81 @@
package relay
import (
"bytes"
"encoding/binary"
"errors"
"io"
"math"
"net"
)
type tcpConn struct {
net.Conn
wbuf bytes.Buffer
}
func (c *tcpConn) Read(b []byte) (n int, err error) {
if err != nil {
return
}
return c.Conn.Read(b)
}
func (c *tcpConn) Write(b []byte) (n int, err error) {
n = len(b) // force byte length consistent
if c.wbuf.Len() > 0 {
c.wbuf.Write(b) // append the data to the cached header
_, err = c.wbuf.WriteTo(c.Conn)
return
}
_, err = c.Conn.Write(b)
return
}
type udpConn struct {
net.Conn
wbuf bytes.Buffer
}
func (c *udpConn) Read(b []byte) (n int, err error) {
var bb [2]byte
_, err = io.ReadFull(c.Conn, bb[:])
if err != nil {
return
}
dlen := int(binary.BigEndian.Uint16(bb[:]))
if len(b) >= dlen {
return io.ReadFull(c.Conn, b[:dlen])
}
buf := make([]byte, dlen)
_, err = io.ReadFull(c.Conn, buf)
n = copy(b, buf)
return
}
func (c *udpConn) Write(b []byte) (n int, err error) {
if len(b) > math.MaxUint16 {
err = errors.New("write: data maximum exceeded")
return
}
n = len(b)
if c.wbuf.Len() > 0 {
var bb [2]byte
binary.BigEndian.PutUint16(bb[:], uint16(len(b)))
c.wbuf.Write(bb[:])
c.wbuf.Write(b) // append the data to the cached header
_, err = c.wbuf.WriteTo(c.Conn)
return
}
var bb [2]byte
binary.BigEndian.PutUint16(bb[:], uint16(len(b)))
_, err = c.Conn.Write(bb[:])
if err != nil {
return
}
return c.Conn.Write(b)
}

View File

@ -7,7 +7,6 @@ import (
"time"
"github.com/go-gost/gost/pkg/chain"
util_relay "github.com/go-gost/gost/pkg/common/util/relay"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/relay"
)
@ -51,12 +50,36 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network
}
defer cc.Close()
if _, err := resp.WriteTo(conn); err != nil {
h.logger.Error(err)
if h.md.noDelay {
if _, err := resp.WriteTo(conn); err != nil {
h.logger.Error(err)
return
}
}
if network == "udp" {
conn = util_relay.UDPTunConn(conn)
switch network {
case "udp", "udp4", "udp6":
rc := &udpConn{
Conn: conn,
}
if !h.md.noDelay {
// cache the header
if _, err := resp.WriteTo(&rc.wbuf); err != nil {
return
}
}
conn = rc
default:
rc := &tcpConn{
Conn: conn,
}
if !h.md.noDelay {
// cache the header
if _, err := resp.WriteTo(&rc.wbuf); err != nil {
return
}
}
conn = rc
}
t := time.Now()

View File

@ -8,11 +8,18 @@ import (
"github.com/go-gost/gost/pkg/chain"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/relay"
)
func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network string) {
resp := relay.Response{
Version: relay.Version1,
Status: relay.StatusOK,
}
target := h.group.Next()
if target == nil {
resp.Status = relay.StatusServiceUnavailable
resp.WriteTo(conn)
h.logger.Error("no target available")
return
}
@ -30,15 +37,51 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network
cc, err := r.Dial(ctx, network, target.Addr())
if err != nil {
h.logger.Error(err)
// TODO: the router itself may be failed due to the failed node in the router,
// the dead marker may be a wrong operation.
target.Marker().Mark()
resp.Status = relay.StatusHostUnreachable
resp.WriteTo(conn)
h.logger.Error(err)
return
}
defer cc.Close()
target.Marker().Reset()
if h.md.noDelay {
if _, err := resp.WriteTo(conn); err != nil {
h.logger.Error(err)
return
}
}
switch network {
case "udp", "udp4", "udp6":
rc := &udpConn{
Conn: conn,
}
if !h.md.noDelay {
// cache the header
if _, err := resp.WriteTo(&rc.wbuf); err != nil {
return
}
}
conn = rc
default:
rc := &tcpConn{
Conn: conn,
}
if !h.md.noDelay {
// cache the header
if _, err := resp.WriteTo(&rc.wbuf); err != nil {
return
}
}
conn = rc
}
t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr())
handler.Transport(conn, cc)

View File

@ -123,7 +123,7 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) {
if address != "" {
resp.Status = relay.StatusForbidden
resp.WriteTo(conn)
h.logger.Error("forbidden")
h.logger.Error("forward mode, connect is forbidden")
return
}
// forward mode
@ -132,7 +132,7 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) {
}
switch req.Flags & relay.CmdMask {
case relay.CONNECT:
case 0, relay.CONNECT:
h.handleConnect(ctx, conn, network, address)
case relay.BIND:
h.handleBind(ctx, conn, network, address)

View File

@ -14,6 +14,7 @@ type metadata struct {
retryCount int
enableBind bool
udpBufferSize int
noDelay bool
}
func (h *relayHandler) parseMetadata(md md.Metadata) (err error) {
@ -23,6 +24,7 @@ func (h *relayHandler) parseMetadata(md md.Metadata) (err error) {
retryCount = "retry"
enableBind = "bind"
udpBufferSize = "udpBufferSize"
noDelay = "nodelay"
)
if v, _ := md.Get(users).([]interface{}); len(v) > 0 {
@ -42,6 +44,7 @@ func (h *relayHandler) parseMetadata(md md.Metadata) (err error) {
h.md.readTimeout = md.GetDuration(readTimeout)
h.md.retryCount = md.GetInt(retryCount)
h.md.enableBind = md.GetBool(enableBind)
h.md.noDelay = md.GetBool(noDelay)
h.md.udpBufferSize = md.GetInt(udpBufferSize)
if h.md.udpBufferSize > 0 {
if h.md.udpBufferSize < 512 {