add network feature for relay

This commit is contained in:
ginuerzh
2023-09-16 21:47:40 +08:00
parent ee4f80b68d
commit 92db078642
13 changed files with 193 additions and 174 deletions

View File

@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"strconv"
"time"
@ -12,9 +13,17 @@ import (
"github.com/go-gost/relay"
xnet "github.com/go-gost/x/internal/net"
sx "github.com/go-gost/x/internal/util/selector"
serial_util "github.com/go-gost/x/internal/util/serial"
goserial "github.com/tarm/serial"
)
func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error {
func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) (err error) {
if network == "unix" || network == "serial" {
if host, _, _ := net.SplitHostPort(address); host != "" {
address = host
}
}
log = log.WithFields(map[string]any{
"dst": fmt.Sprintf("%s/%s", address, network),
"cmd": "connect",
@ -30,23 +39,31 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network
if address == "" {
resp.Status = relay.StatusBadRequest
resp.WriteTo(conn)
err := errors.New("target not specified")
err = errors.New("target not specified")
log.Error(err)
return err
return
}
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, address) {
log.Debug("bypass: ", address)
resp.Status = relay.StatusForbidden
_, err := resp.WriteTo(conn)
return err
_, err = resp.WriteTo(conn)
return
}
switch h.md.hash {
case "host":
ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: address})
}
cc, err := h.router.Dial(ctx, network, address)
var cc io.ReadWriteCloser
switch network {
case "serial":
cc, err = goserial.OpenPort(serial_util.ParseConfigFromAddr(address))
default:
cc, err = h.router.Dial(ctx, network, address)
}
if err != nil {
resp.Status = relay.StatusNetworkUnreachable
resp.WriteTo(conn)

View File

@ -173,6 +173,7 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle
var user, pass string
var address string
var networkID relay.NetworkID
var tunnelID relay.TunnelID
for _, f := range req.Features {
switch f.Type() {
@ -188,6 +189,10 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle
if feature, _ := f.(*relay.TunnelFeature); feature != nil {
tunnelID = relay.NewTunnelID(feature.ID[:])
}
case relay.FeatureNetwork:
if feature, _ := f.(*relay.NetworkFeature); feature != nil {
networkID = feature.Network
}
}
}
@ -202,7 +207,7 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle
return ErrUnauthorized
}
network := "tcp"
network := networkID.String()
if (req.Cmd & relay.FUDP) == relay.FUDP {
network = "udp"
}