From f2ff1aa45a4ead6645eea4983f8b09961d438650 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sat, 16 Sep 2023 23:14:12 +0800 Subject: [PATCH] add unix domain socket --- handler/relay/connect.go | 2 + handler/serial/handler.go | 7 +- handler/unix/handler.go | 135 +++++++++++++++++++++++++++++++++ handler/unix/metadata.go | 12 +++ internal/util/serial/serial.go | 20 ++++- listener/unix/listener.go | 67 ++++++++++++++++ listener/unix/metadata.go | 12 +++ 7 files changed, 251 insertions(+), 4 deletions(-) create mode 100644 handler/unix/handler.go create mode 100644 handler/unix/metadata.go create mode 100644 listener/unix/listener.go create mode 100644 listener/unix/metadata.go diff --git a/handler/relay/connect.go b/handler/relay/connect.go index 9eab3d4..75f7da9 100644 --- a/handler/relay/connect.go +++ b/handler/relay/connect.go @@ -59,6 +59,8 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network var cc io.ReadWriteCloser switch network { + case "unix": + cc, err = (&net.Dialer{}).DialContext(ctx, "unix", address) case "serial": cc, err = goserial.OpenPort(serial_util.ParseConfigFromAddr(address)) default: diff --git a/handler/serial/handler.go b/handler/serial/handler.go index d18b21a..3d5dc0e 100644 --- a/handler/serial/handler.go +++ b/handler/serial/handler.go @@ -19,7 +19,6 @@ import ( func init() { registry.HandlerRegistry().Register("serial", NewHandler) - registry.HandlerRegistry().Register("com", NewHandler) } type serialHandler struct { @@ -117,10 +116,12 @@ func (h *serialHandler) Handle(ctx context.Context, conn net.Conn, opts ...handl func (h *serialHandler) forwardSerial(ctx context.Context, conn net.Conn, target *chain.Node, log logger.Logger) (err error) { var port io.ReadWriteCloser + cfg := serial_util.ParseConfigFromAddr(conn.LocalAddr().String()) + cfg.Name = target.Addr + if opts := h.router.Options(); opts != nil && opts.Chain != nil { - port, err = h.router.Dial(ctx, "serial", target.Addr) + port, err = h.router.Dial(ctx, "serial", serial_util.AddrFromConfig(cfg)) } else { - cfg := serial_util.ParseConfigFromAddr(target.Addr) cfg.ReadTimeout = h.md.timeout port, err = goserial.OpenPort(cfg) } diff --git a/handler/unix/handler.go b/handler/unix/handler.go new file mode 100644 index 0000000..b9a3cc1 --- /dev/null +++ b/handler/unix/handler.go @@ -0,0 +1,135 @@ +package unix + +import ( + "context" + "errors" + "io" + "net" + "time" + + "github.com/go-gost/core/chain" + "github.com/go-gost/core/handler" + "github.com/go-gost/core/logger" + md "github.com/go-gost/core/metadata" + xnet "github.com/go-gost/x/internal/net" + "github.com/go-gost/x/registry" +) + +func init() { + registry.HandlerRegistry().Register("unix", NewHandler) +} + +type unixHandler struct { + hop chain.Hop + router *chain.Router + md metadata + options handler.Options +} + +func NewHandler(opts ...handler.Option) handler.Handler { + options := handler.Options{} + for _, opt := range opts { + opt(&options) + } + + return &unixHandler{ + options: options, + } +} + +func (h *unixHandler) Init(md md.Metadata) (err error) { + if err = h.parseMetadata(md); err != nil { + return + } + + h.router = h.options.Router + if h.router == nil { + h.router = chain.NewRouter(chain.LoggerRouterOption(h.options.Logger)) + } + + return +} + +// Forward implements handler.Forwarder. +func (h *unixHandler) Forward(hop chain.Hop) { + h.hop = hop +} + +func (h *unixHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error { + defer conn.Close() + + log := h.options.Logger + + log = log.WithFields(map[string]any{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }) + + var target *chain.Node + if h.hop != nil { + target = h.hop.Select(ctx) + } + + if target == nil { + err := errors.New("target not available") + log.Error(err) + return err + } + + log = log.WithFields(map[string]any{ + "node": target.Name, + "dst": target.Addr, + }) + + log.Debugf("%s >> %s", conn.LocalAddr(), target.Addr) + + if _, _, err := net.SplitHostPort(target.Addr); err != nil { + return h.forwardUnix(ctx, conn, target, log) + } + + cc, err := h.router.Dial(ctx, "tcp", target.Addr) + if err != nil { + log.Error(err) + if marker := target.Marker(); marker != nil { + marker.Mark() + } + return err + } + defer cc.Close() + if marker := target.Marker(); marker != nil { + marker.Reset() + } + + t := time.Now() + log.Infof("%s <-> %s", conn.LocalAddr(), target.Addr) + xnet.Transport(conn, cc) + log.WithFields(map[string]any{ + "duration": time.Since(t), + }).Infof("%s >-< %s", conn.LocalAddr(), target.Addr) + + return nil +} + +func (h *unixHandler) forwardUnix(ctx context.Context, conn net.Conn, target *chain.Node, log logger.Logger) (err error) { + var cc io.ReadWriteCloser + + if opts := h.router.Options(); opts != nil && opts.Chain != nil { + cc, err = h.router.Dial(ctx, "unix", target.Addr) + } else { + cc, err = (&net.Dialer{}).DialContext(ctx, "unix", target.Addr) + } + if err != nil { + log.Error(err) + return err + } + defer cc.Close() + + t := time.Now() + log.Infof("%s <-> %s", conn.LocalAddr(), target.Addr) + xnet.Transport(conn, cc) + log.WithFields(map[string]any{ + "duration": time.Since(t), + }).Infof("%s >-< %s", conn.LocalAddr(), target.Addr) + + return nil +} diff --git a/handler/unix/metadata.go b/handler/unix/metadata.go new file mode 100644 index 0000000..776c802 --- /dev/null +++ b/handler/unix/metadata.go @@ -0,0 +1,12 @@ +package unix + +import ( + mdata "github.com/go-gost/core/metadata" +) + +type metadata struct { +} + +func (h *unixHandler) parseMetadata(md mdata.Metadata) (err error) { + return +} diff --git a/internal/util/serial/serial.go b/internal/util/serial/serial.go index 17dc5b3..975295a 100644 --- a/internal/util/serial/serial.go +++ b/internal/util/serial/serial.go @@ -10,7 +10,6 @@ import ( const ( DefaultPort = "COM1" DefaultBaudRate = 9600 - DefaultParity = "none" ) // COM1,9600,odd @@ -34,6 +33,25 @@ func ParseConfigFromAddr(addr string) *goserial.Config { return cfg } +func AddrFromConfig(cfg *goserial.Config) string { + ss := []string{ + cfg.Name, + strconv.Itoa(cfg.Baud), + } + + switch cfg.Parity { + case goserial.ParityEven: + ss = append(ss, "even") + case goserial.ParityOdd: + ss = append(ss, "odd") + case goserial.ParityMark: + ss = append(ss, "mark") + case goserial.ParitySpace: + ss = append(ss, "space") + } + return strings.Join(ss, ",") +} + func parseParity(s string) goserial.Parity { switch strings.ToLower(s) { case "o", "odd": diff --git a/listener/unix/listener.go b/listener/unix/listener.go new file mode 100644 index 0000000..98fe4c8 --- /dev/null +++ b/listener/unix/listener.go @@ -0,0 +1,67 @@ +package unix + +import ( + "net" + + "github.com/go-gost/core/listener" + "github.com/go-gost/core/logger" + md "github.com/go-gost/core/metadata" + admission "github.com/go-gost/x/admission/wrapper" + climiter "github.com/go-gost/x/limiter/conn/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" + metrics "github.com/go-gost/x/metrics/wrapper" + "github.com/go-gost/x/registry" +) + +func init() { + registry.ListenerRegistry().Register("unix", NewListener) +} + +type unixListener struct { + ln net.Listener + logger logger.Logger + md metadata + options listener.Options +} + +func NewListener(opts ...listener.Option) listener.Listener { + options := listener.Options{} + for _, opt := range opts { + opt(&options) + } + return &unixListener{ + logger: options.Logger, + options: options, + } +} + +func (l *unixListener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { + return + } + + ln, err := net.Listen("unix", l.options.Addr) + if err != nil { + return + } + + ln = metrics.WrapListener(l.options.Service, ln) + ln = admission.WrapListener(l.options.Admission, ln) + ln = limiter.WrapListener(l.options.TrafficLimiter, ln) + ln = climiter.WrapListener(l.options.ConnLimiter, ln) + l.ln = ln + + return +} + +func (l *unixListener) Accept() (conn net.Conn, err error) { + return l.ln.Accept() +} + +func (l *unixListener) Addr() net.Addr { + return l.ln.Addr() +} + +func (l *unixListener) Close() error { + return l.ln.Close() +} diff --git a/listener/unix/metadata.go b/listener/unix/metadata.go new file mode 100644 index 0000000..2f9cd99 --- /dev/null +++ b/listener/unix/metadata.go @@ -0,0 +1,12 @@ +package unix + +import ( + md "github.com/go-gost/core/metadata" +) + +type metadata struct { +} + +func (l *unixListener) parseMetadata(md md.Metadata) (err error) { + return +}