add network feature for relay

This commit is contained in:
ginuerzh
2023-09-16 21:47:40 +08:00
parent ee4f80b68d
commit 92db078642
13 changed files with 193 additions and 174 deletions

View File

@ -1,32 +0,0 @@
package com
import (
"time"
mdata "github.com/go-gost/core/metadata"
mdutil "github.com/go-gost/core/metadata/util"
)
const (
defaultBaudRate = 9600
defaultParity = "odd"
)
type metadata struct {
baudRate int
parity string
timeout time.Duration
}
func (h *comHandler) parseMetadata(md mdata.Metadata) (err error) {
h.md.baudRate = mdutil.GetInt(md, "baud", "com.baud", "handler.com.baud")
if h.md.baudRate <= 0 {
h.md.baudRate = defaultBaudRate
}
h.md.parity = mdutil.GetString(md, "parity", "com.parity", "handler.com.parity")
if h.md.parity == "" {
h.md.parity = defaultParity
}
h.md.timeout = mdutil.GetDuration(md, "timeout", "com.timeout", "handler.com.timeout")
return
}

View File

@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"strconv"
"time"
@ -12,9 +13,17 @@ import (
"github.com/go-gost/relay"
xnet "github.com/go-gost/x/internal/net"
sx "github.com/go-gost/x/internal/util/selector"
serial_util "github.com/go-gost/x/internal/util/serial"
goserial "github.com/tarm/serial"
)
func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error {
func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) (err error) {
if network == "unix" || network == "serial" {
if host, _, _ := net.SplitHostPort(address); host != "" {
address = host
}
}
log = log.WithFields(map[string]any{
"dst": fmt.Sprintf("%s/%s", address, network),
"cmd": "connect",
@ -30,23 +39,31 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network
if address == "" {
resp.Status = relay.StatusBadRequest
resp.WriteTo(conn)
err := errors.New("target not specified")
err = errors.New("target not specified")
log.Error(err)
return err
return
}
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, address) {
log.Debug("bypass: ", address)
resp.Status = relay.StatusForbidden
_, err := resp.WriteTo(conn)
return err
_, err = resp.WriteTo(conn)
return
}
switch h.md.hash {
case "host":
ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: address})
}
cc, err := h.router.Dial(ctx, network, address)
var cc io.ReadWriteCloser
switch network {
case "serial":
cc, err = goserial.OpenPort(serial_util.ParseConfigFromAddr(address))
default:
cc, err = h.router.Dial(ctx, network, address)
}
if err != nil {
resp.Status = relay.StatusNetworkUnreachable
resp.WriteTo(conn)

View File

@ -173,6 +173,7 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle
var user, pass string
var address string
var networkID relay.NetworkID
var tunnelID relay.TunnelID
for _, f := range req.Features {
switch f.Type() {
@ -188,6 +189,10 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle
if feature, _ := f.(*relay.TunnelFeature); feature != nil {
tunnelID = relay.NewTunnelID(feature.ID[:])
}
case relay.FeatureNetwork:
if feature, _ := f.(*relay.NetworkFeature); feature != nil {
networkID = feature.Network
}
}
}
@ -202,7 +207,7 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle
return ErrUnauthorized
}
network := "tcp"
network := networkID.String()
if (req.Cmd & relay.FUDP) == relay.FUDP {
network = "udp"
}

View File

@ -1,10 +1,10 @@
package com
package serial
import (
"context"
"errors"
"io"
"net"
"strings"
"time"
"github.com/go-gost/core/chain"
@ -12,15 +12,17 @@ import (
"github.com/go-gost/core/logger"
md "github.com/go-gost/core/metadata"
xnet "github.com/go-gost/x/internal/net"
serial_util "github.com/go-gost/x/internal/util/serial"
"github.com/go-gost/x/registry"
goserial "github.com/tarm/serial"
)
func init() {
registry.HandlerRegistry().Register("serial", NewHandler)
registry.HandlerRegistry().Register("com", NewHandler)
}
type comHandler struct {
type serialHandler struct {
hop chain.Hop
router *chain.Router
md metadata
@ -33,12 +35,12 @@ func NewHandler(opts ...handler.Option) handler.Handler {
opt(&options)
}
return &comHandler{
return &serialHandler{
options: options,
}
}
func (h *comHandler) Init(md md.Metadata) (err error) {
func (h *serialHandler) Init(md md.Metadata) (err error) {
if err = h.parseMetadata(md); err != nil {
return
}
@ -52,28 +54,20 @@ func (h *comHandler) Init(md md.Metadata) (err error) {
}
// Forward implements handler.Forwarder.
func (h *comHandler) Forward(hop chain.Hop) {
func (h *serialHandler) Forward(hop chain.Hop) {
h.hop = hop
}
func (h *comHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
func (h *serialHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
defer conn.Close()
log := h.options.Logger
start := time.Now()
log = log.WithFields(map[string]any{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
log.WithFields(map[string]any{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
var target *chain.Node
if h.hop != nil {
target = h.hop.Select(ctx)
@ -94,7 +88,7 @@ func (h *comHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.
// serial port
if _, _, err := net.SplitHostPort(target.Addr); err != nil {
return h.forwardCom(ctx, conn, target, log)
return h.forwardSerial(ctx, conn, target, log)
}
cc, err := h.router.Dial(ctx, "tcp", target.Addr)
@ -120,17 +114,20 @@ func (h *comHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.
return nil
}
func (h *comHandler) forwardCom(ctx context.Context, conn net.Conn, target *chain.Node, log logger.Logger) error {
port, err := goserial.OpenPort(&goserial.Config{
Name: target.Addr,
Baud: h.md.baudRate,
Parity: parseParity(h.md.parity),
ReadTimeout: h.md.timeout,
})
func (h *serialHandler) forwardSerial(ctx context.Context, conn net.Conn, target *chain.Node, log logger.Logger) (err error) {
var port io.ReadWriteCloser
if opts := h.router.Options(); opts != nil && opts.Chain != nil {
port, err = h.router.Dial(ctx, "serial", target.Addr)
} else {
cfg := serial_util.ParseConfigFromAddr(target.Addr)
cfg.ReadTimeout = h.md.timeout
port, err = goserial.OpenPort(cfg)
}
if err != nil {
log.Error(err)
return err
}
defer port.Close()
t := time.Now()
@ -142,18 +139,3 @@ func (h *comHandler) forwardCom(ctx context.Context, conn net.Conn, target *chai
return nil
}
func parseParity(s string) goserial.Parity {
switch strings.ToLower(s) {
case "o", "odd":
return goserial.ParityOdd
case "e", "even":
return goserial.ParityEven
case "m", "mark":
return goserial.ParityMark
case "s", "space":
return goserial.ParitySpace
default:
return goserial.ParityNone
}
}

View File

@ -0,0 +1,22 @@
package serial
import (
"time"
mdata "github.com/go-gost/core/metadata"
mdutil "github.com/go-gost/core/metadata/util"
)
const (
defaultPort = "COM1"
defaultBaudRate = 9600
)
type metadata struct {
timeout time.Duration
}
func (h *serialHandler) parseMetadata(md mdata.Metadata) (err error) {
h.md.timeout = mdutil.GetDuration(md, "timeout", "serial.timeout", "handler.serial.timeout")
return
}