From bfc1f8472cc28a801d9022bdfb04b1d0d0c049f0 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sun, 13 Mar 2022 00:08:16 +0800 Subject: [PATCH] add admission option for listener --- pkg/admission/admission.go | 4 + pkg/common/admission/conn.go | 223 +++++++++++++++++++++++++++++ pkg/common/admission/listener.go | 37 +++++ pkg/common/net/udp.go | 41 ++++++ pkg/config/parsing/service.go | 3 +- pkg/listener/grpc/listener.go | 2 + pkg/listener/http2/h2/listener.go | 2 + pkg/listener/http2/listener.go | 2 + pkg/listener/icmp/listener.go | 2 + pkg/listener/kcp/listener.go | 2 + pkg/listener/obfs/http/listener.go | 2 + pkg/listener/obfs/tls/listener.go | 2 + pkg/listener/option.go | 8 ++ pkg/listener/ssh/listener.go | 2 + pkg/listener/sshd/listener.go | 2 + pkg/listener/tcp/listener.go | 4 +- pkg/listener/tls/listener.go | 2 + pkg/listener/tls/mux/listener.go | 2 + pkg/listener/udp/listener.go | 6 +- pkg/listener/ws/listener.go | 2 + pkg/listener/ws/mux/listener.go | 2 + pkg/service/service.go | 1 - 22 files changed, 348 insertions(+), 5 deletions(-) create mode 100644 pkg/common/admission/conn.go create mode 100644 pkg/common/admission/listener.go create mode 100644 pkg/common/net/udp.go diff --git a/pkg/admission/admission.go b/pkg/admission/admission.go index b5ee368..1b16de9 100644 --- a/pkg/admission/admission.go +++ b/pkg/admission/admission.go @@ -58,6 +58,7 @@ func NewAdmissionPatterns(reversed bool, patterns []string, opts ...Option) Admi func (p *admission) Admit(addr string) bool { if addr == "" || p == nil || len(p.matchers) == 0 { + p.options.logger.Debugf("admission: %v is denied", addr) return false } @@ -81,5 +82,8 @@ func (p *admission) Admit(addr string) bool { b := !p.reversed && matched || p.reversed && !matched + if !b { + p.options.logger.Debugf("admission: %v is denied", addr) + } return b } diff --git a/pkg/common/admission/conn.go b/pkg/common/admission/conn.go new file mode 100644 index 0000000..261cd8a --- /dev/null +++ b/pkg/common/admission/conn.go @@ -0,0 +1,223 @@ +package admission + +import ( + "errors" + "io" + "net" + "syscall" + + "github.com/go-gost/gost/pkg/admission" +) + +var ( + errUnsupport = errors.New("unsupported operation") +) + +type packetConn struct { + net.PacketConn + admission admission.Admission +} + +func WrapPacketConn(admission admission.Admission, pc net.PacketConn) net.PacketConn { + if admission == nil { + return pc + } + return &packetConn{ + PacketConn: pc, + admission: admission, + } +} + +func (c *packetConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + for { + n, addr, err = c.PacketConn.ReadFrom(p) + if err != nil { + return + } + + if c.admission != nil && + !c.admission.Admit(addr.String()) { + continue + } + + return + } +} + +type udpConn struct { + net.PacketConn + admission admission.Admission +} + +func WrapUDPConn(admission admission.Admission, pc net.PacketConn) UDPConn { + return &udpConn{ + PacketConn: pc, + admission: admission, + } +} + +func (c *udpConn) RemoteAddr() net.Addr { + if nc, ok := c.PacketConn.(remoteAddr); ok { + return nc.RemoteAddr() + } + return nil +} + +func (c *udpConn) SetReadBuffer(n int) error { + if nc, ok := c.PacketConn.(setBuffer); ok { + return nc.SetReadBuffer(n) + } + return errUnsupport +} + +func (c *udpConn) SetWriteBuffer(n int) error { + if nc, ok := c.PacketConn.(setBuffer); ok { + return nc.SetWriteBuffer(n) + } + return errUnsupport +} + +func (c *udpConn) Read(b []byte) (n int, err error) { + if nc, ok := c.PacketConn.(io.Reader); ok { + n, err = nc.Read(b) + return + } + err = errUnsupport + return +} + +func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + for { + n, addr, err = c.PacketConn.ReadFrom(p) + if err != nil { + return + } + if c.admission != nil && + !c.admission.Admit(addr.String()) { + continue + } + return + } +} + +func (c *udpConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) { + if nc, ok := c.PacketConn.(readUDP); ok { + for { + n, addr, err = nc.ReadFromUDP(b) + if err != nil { + return + } + if c.admission != nil && + !c.admission.Admit(addr.String()) { + continue + } + return + } + } + err = errUnsupport + return +} + +func (c *udpConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) { + if nc, ok := c.PacketConn.(readUDP); ok { + for { + n, oobn, flags, addr, err = nc.ReadMsgUDP(b, oob) + if err != nil { + return + } + if c.admission != nil && + !c.admission.Admit(addr.String()) { + continue + } + return + } + } + err = errUnsupport + return +} + +func (c *udpConn) Write(b []byte) (n int, err error) { + if nc, ok := c.PacketConn.(io.Writer); ok { + n, err = nc.Write(b) + return + } + err = errUnsupport + return +} + +func (c *udpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + n, err = c.PacketConn.WriteTo(p, addr) + return +} + +func (c *udpConn) WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) { + if nc, ok := c.PacketConn.(writeUDP); ok { + n, err = nc.WriteToUDP(b, addr) + return + } + err = errUnsupport + return +} + +func (c *udpConn) WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) { + if nc, ok := c.PacketConn.(writeUDP); ok { + n, oobn, err = nc.WriteMsgUDP(b, oob, addr) + return + } + err = errUnsupport + return +} + +func (c *udpConn) SyscallConn() (rc syscall.RawConn, err error) { + if nc, ok := c.PacketConn.(syscallConn); ok { + return nc.SyscallConn() + } + err = errUnsupport + return +} + +func (c *udpConn) SetDSCP(n int) error { + if nc, ok := c.PacketConn.(setDSCP); ok { + return nc.SetDSCP(n) + } + return nil +} + +type UDPConn interface { + net.PacketConn + io.Reader + io.Writer + readUDP + writeUDP + setBuffer + syscallConn + remoteAddr +} + +type setBuffer interface { + SetReadBuffer(bytes int) error + SetWriteBuffer(bytes int) error +} + +type readUDP interface { + ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) + ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) +} + +type writeUDP interface { + WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) + WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) +} + +type syscallConn interface { + SyscallConn() (syscall.RawConn, error) +} + +type remoteAddr interface { + RemoteAddr() net.Addr +} + +// tcpraw.TCPConn +type setDSCP interface { + SetDSCP(int) error +} diff --git a/pkg/common/admission/listener.go b/pkg/common/admission/listener.go new file mode 100644 index 0000000..e55f2d3 --- /dev/null +++ b/pkg/common/admission/listener.go @@ -0,0 +1,37 @@ +package admission + +import ( + "net" + + "github.com/go-gost/gost/pkg/admission" +) + +type listener struct { + net.Listener + admission admission.Admission +} + +func WrapListener(admission admission.Admission, ln net.Listener) net.Listener { + if admission == nil { + return ln + } + return &listener{ + Listener: ln, + admission: admission, + } +} + +func (ln *listener) Accept() (net.Conn, error) { + for { + c, err := ln.Listener.Accept() + if err != nil { + return nil, err + } + if ln.admission != nil && + !ln.admission.Admit(c.RemoteAddr().String()) { + c.Close() + continue + } + return c, err + } +} diff --git a/pkg/common/net/udp.go b/pkg/common/net/udp.go new file mode 100644 index 0000000..515060f --- /dev/null +++ b/pkg/common/net/udp.go @@ -0,0 +1,41 @@ +package net + +import ( + "io" + "net" + "syscall" +) + +type UDPConn interface { + net.PacketConn + io.Reader + io.Writer + readUDP + writeUDP + setBuffer + syscallConn + remoteAddr +} + +type setBuffer interface { + SetReadBuffer(bytes int) error + SetWriteBuffer(bytes int) error +} + +type readUDP interface { + ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) + ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) +} + +type writeUDP interface { + WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) + WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) +} + +type syscallConn interface { + SyscallConn() (syscall.RawConn, error) +} + +type remoteAddr interface { + RemoteAddr() net.Addr +} diff --git a/pkg/config/parsing/service.go b/pkg/config/parsing/service.go index 8977ad7..ab80df7 100644 --- a/pkg/config/parsing/service.go +++ b/pkg/config/parsing/service.go @@ -54,10 +54,11 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { ln := registry.ListenerRegistry().Get(cfg.Listener.Type)( listener.AddrOption(cfg.Addr), - listener.ChainOption(registry.ChainRegistry().Get(cfg.Listener.Chain)), listener.AutherOption(auther), listener.AuthOption(parseAuth(cfg.Listener.Auth)), listener.TLSConfigOption(tlsConfig), + listener.AdmissionOption(registry.AdmissionRegistry().Get(cfg.Admission)), + listener.ChainOption(registry.ChainRegistry().Get(cfg.Listener.Chain)), listener.LoggerOption(listenerLogger), listener.ServiceOption(cfg.Name), ) diff --git a/pkg/listener/grpc/listener.go b/pkg/listener/grpc/listener.go index 10847ae..8ba93fa 100644 --- a/pkg/listener/grpc/listener.go +++ b/pkg/listener/grpc/listener.go @@ -3,6 +3,7 @@ package grpc import ( "net" + "github.com/go-gost/gost/pkg/common/admission" "github.com/go-gost/gost/pkg/common/metrics" pb "github.com/go-gost/gost/pkg/common/util/grpc/proto" "github.com/go-gost/gost/pkg/listener" @@ -48,6 +49,7 @@ func (l *grpcListener) Init(md md.Metadata) (err error) { return } ln = metrics.WrapListener(l.options.Service, ln) + ln = admission.WrapListener(l.options.Admission, ln) var opts []grpc.ServerOption if !l.md.insecure { diff --git a/pkg/listener/http2/h2/listener.go b/pkg/listener/http2/h2/listener.go index 50210b2..072dfbb 100644 --- a/pkg/listener/http2/h2/listener.go +++ b/pkg/listener/http2/h2/listener.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httputil" + "github.com/go-gost/gost/pkg/common/admission" "github.com/go-gost/gost/pkg/common/metrics" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" @@ -70,6 +71,7 @@ func (l *h2Listener) Init(md md.Metadata) (err error) { } l.addr = ln.Addr() ln = metrics.WrapListener(l.options.Service, ln) + ln = admission.WrapListener(l.options.Admission, ln) if l.h2c { l.server.Handler = h2c.NewHandler( diff --git a/pkg/listener/http2/listener.go b/pkg/listener/http2/listener.go index 9c25fa5..1bf0525 100644 --- a/pkg/listener/http2/listener.go +++ b/pkg/listener/http2/listener.go @@ -5,6 +5,7 @@ import ( "net" "net/http" + "github.com/go-gost/gost/pkg/common/admission" "github.com/go-gost/gost/pkg/common/metrics" http2_util "github.com/go-gost/gost/pkg/internal/util/http2" "github.com/go-gost/gost/pkg/listener" @@ -59,6 +60,7 @@ func (l *http2Listener) Init(md md.Metadata) (err error) { } l.addr = ln.Addr() ln = metrics.WrapListener(l.options.Service, ln) + ln = admission.WrapListener(l.options.Admission, ln) ln = tls.NewListener( ln, diff --git a/pkg/listener/icmp/listener.go b/pkg/listener/icmp/listener.go index a978076..e737d70 100644 --- a/pkg/listener/icmp/listener.go +++ b/pkg/listener/icmp/listener.go @@ -4,6 +4,7 @@ import ( "context" "net" + "github.com/go-gost/gost/pkg/common/admission" "github.com/go-gost/gost/pkg/common/metrics" icmp_pkg "github.com/go-gost/gost/pkg/internal/util/icmp" "github.com/go-gost/gost/pkg/listener" @@ -55,6 +56,7 @@ func (l *icmpListener) Init(md md.Metadata) (err error) { } conn = icmp_pkg.ServerConn(conn) conn = metrics.WrapPacketConn(l.options.Service, conn) + conn = admission.WrapPacketConn(l.options.Admission, conn) config := &quic.Config{ KeepAlive: l.md.keepAlive, diff --git a/pkg/listener/kcp/listener.go b/pkg/listener/kcp/listener.go index 1ea3bd0..07c2e04 100644 --- a/pkg/listener/kcp/listener.go +++ b/pkg/listener/kcp/listener.go @@ -4,6 +4,7 @@ import ( "net" "time" + "github.com/go-gost/gost/pkg/common/admission" "github.com/go-gost/gost/pkg/common/metrics" kcp_util "github.com/go-gost/gost/pkg/common/util/kcp" "github.com/go-gost/gost/pkg/listener" @@ -64,6 +65,7 @@ func (l *kcpListener) Init(md md.Metadata) (err error) { } conn = metrics.WrapUDPConn(l.options.Service, conn) + conn = admission.WrapUDPConn(l.options.Admission, conn) ln, err := kcp.ServeConn( kcp_util.BlockCrypt(config.Key, config.Crypt, kcp_util.DefaultSalt), diff --git a/pkg/listener/obfs/http/listener.go b/pkg/listener/obfs/http/listener.go index 1867099..cc8900d 100644 --- a/pkg/listener/obfs/http/listener.go +++ b/pkg/listener/obfs/http/listener.go @@ -3,6 +3,7 @@ package http import ( "net" + "github.com/go-gost/gost/pkg/common/admission" "github.com/go-gost/gost/pkg/common/metrics" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" @@ -42,6 +43,7 @@ func (l *obfsListener) Init(md md.Metadata) (err error) { return } ln = metrics.WrapListener(l.options.Service, ln) + ln = admission.WrapListener(l.options.Admission, ln) l.Listener = ln return diff --git a/pkg/listener/obfs/tls/listener.go b/pkg/listener/obfs/tls/listener.go index b8860a7..e696de7 100644 --- a/pkg/listener/obfs/tls/listener.go +++ b/pkg/listener/obfs/tls/listener.go @@ -3,6 +3,7 @@ package tls import ( "net" + "github.com/go-gost/gost/pkg/common/admission" "github.com/go-gost/gost/pkg/common/metrics" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" @@ -42,6 +43,7 @@ func (l *obfsListener) Init(md md.Metadata) (err error) { return } ln = metrics.WrapListener(l.options.Service, ln) + ln = admission.WrapListener(l.options.Admission, ln) l.Listener = ln return diff --git a/pkg/listener/option.go b/pkg/listener/option.go index 8cea96a..5a99518 100644 --- a/pkg/listener/option.go +++ b/pkg/listener/option.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "net/url" + "github.com/go-gost/gost/pkg/admission" "github.com/go-gost/gost/pkg/auth" "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/logger" @@ -14,6 +15,7 @@ type Options struct { Auther auth.Authenticator Auth *url.Userinfo TLSConfig *tls.Config + Admission admission.Admission Chain chain.Chainer Logger logger.Logger Service string @@ -45,6 +47,12 @@ func TLSConfigOption(tlsConfig *tls.Config) Option { } } +func AdmissionOption(admission admission.Admission) Option { + return func(opts *Options) { + opts.Admission = admission + } +} + func ChainOption(chain chain.Chainer) Option { return func(opts *Options) { opts.Chain = chain diff --git a/pkg/listener/ssh/listener.go b/pkg/listener/ssh/listener.go index afbea97..1cea3b7 100644 --- a/pkg/listener/ssh/listener.go +++ b/pkg/listener/ssh/listener.go @@ -5,6 +5,7 @@ import ( "net" "time" + "github.com/go-gost/gost/pkg/common/admission" "github.com/go-gost/gost/pkg/common/metrics" ssh_util "github.com/go-gost/gost/pkg/internal/util/ssh" "github.com/go-gost/gost/pkg/listener" @@ -50,6 +51,7 @@ func (l *sshListener) Init(md md.Metadata) (err error) { } ln = metrics.WrapListener(l.options.Service, ln) + ln = admission.WrapListener(l.options.Admission, ln) l.Listener = ln config := &ssh.ServerConfig{ diff --git a/pkg/listener/sshd/listener.go b/pkg/listener/sshd/listener.go index 6163f6e..10aac39 100644 --- a/pkg/listener/sshd/listener.go +++ b/pkg/listener/sshd/listener.go @@ -7,6 +7,7 @@ import ( "strconv" "time" + "github.com/go-gost/gost/pkg/common/admission" "github.com/go-gost/gost/pkg/common/metrics" ssh_util "github.com/go-gost/gost/pkg/internal/util/ssh" sshd_util "github.com/go-gost/gost/pkg/internal/util/sshd" @@ -59,6 +60,7 @@ func (l *sshdListener) Init(md md.Metadata) (err error) { } ln = metrics.WrapListener(l.options.Service, ln) + ln = admission.WrapListener(l.options.Admission, ln) l.Listener = ln config := &ssh.ServerConfig{ diff --git a/pkg/listener/tcp/listener.go b/pkg/listener/tcp/listener.go index 84cc164..99806b9 100644 --- a/pkg/listener/tcp/listener.go +++ b/pkg/listener/tcp/listener.go @@ -41,7 +41,9 @@ func (l *tcpListener) Init(md md.Metadata) (err error) { if err != nil { return } - l.Listener = metrics.WrapListener(l.options.Service, ln) + + ln = metrics.WrapListener(l.options.Service, ln) + l.Listener = ln return } diff --git a/pkg/listener/tls/listener.go b/pkg/listener/tls/listener.go index d62c144..06bf319 100644 --- a/pkg/listener/tls/listener.go +++ b/pkg/listener/tls/listener.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "net" + "github.com/go-gost/gost/pkg/common/admission" "github.com/go-gost/gost/pkg/common/metrics" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" @@ -43,6 +44,7 @@ func (l *tlsListener) Init(md md.Metadata) (err error) { return } ln = metrics.WrapListener(l.options.Service, ln) + ln = admission.WrapListener(l.options.Admission, ln) l.Listener = tls.NewListener(ln, l.options.TLSConfig) diff --git a/pkg/listener/tls/mux/listener.go b/pkg/listener/tls/mux/listener.go index 8346cba..e9c99b9 100644 --- a/pkg/listener/tls/mux/listener.go +++ b/pkg/listener/tls/mux/listener.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "net" + "github.com/go-gost/gost/pkg/common/admission" "github.com/go-gost/gost/pkg/common/metrics" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" @@ -47,6 +48,7 @@ func (l *mtlsListener) Init(md md.Metadata) (err error) { } ln = metrics.WrapListener(l.options.Service, ln) + ln = admission.WrapListener(l.options.Admission, ln) l.Listener = tls.NewListener(ln, l.options.TLSConfig) l.cqueue = make(chan net.Conn, l.md.backlog) diff --git a/pkg/listener/udp/listener.go b/pkg/listener/udp/listener.go index 2a0ab86..c22c9d8 100644 --- a/pkg/listener/udp/listener.go +++ b/pkg/listener/udp/listener.go @@ -43,13 +43,15 @@ func (l *udpListener) Init(md md.Metadata) (err error) { return } - conn, err := net.ListenUDP("udp", laddr) + var conn net.PacketConn + conn, err = net.ListenUDP("udp", laddr) if err != nil { return } + conn = metrics.WrapPacketConn(l.options.Service, conn) l.Listener = udp.NewListener( - metrics.WrapPacketConn(l.options.Service, conn), + conn, laddr, l.md.backlog, l.md.readQueueSize, l.md.readBufferSize, diff --git a/pkg/listener/ws/listener.go b/pkg/listener/ws/listener.go index 532c241..90649c1 100644 --- a/pkg/listener/ws/listener.go +++ b/pkg/listener/ws/listener.go @@ -6,6 +6,7 @@ import ( "net/http" "net/http/httputil" + "github.com/go-gost/gost/pkg/common/admission" "github.com/go-gost/gost/pkg/common/metrics" ws_util "github.com/go-gost/gost/pkg/internal/util/ws" "github.com/go-gost/gost/pkg/listener" @@ -84,6 +85,7 @@ func (l *wsListener) Init(md md.Metadata) (err error) { return } ln = metrics.WrapListener(l.options.Service, ln) + ln = admission.WrapListener(l.options.Admission, ln) if l.tlsEnabled { ln = tls.NewListener(ln, l.options.TLSConfig) diff --git a/pkg/listener/ws/mux/listener.go b/pkg/listener/ws/mux/listener.go index aeebbea..62e753e 100644 --- a/pkg/listener/ws/mux/listener.go +++ b/pkg/listener/ws/mux/listener.go @@ -6,6 +6,7 @@ import ( "net/http" "net/http/httputil" + "github.com/go-gost/gost/pkg/common/admission" "github.com/go-gost/gost/pkg/common/metrics" ws_util "github.com/go-gost/gost/pkg/internal/util/ws" "github.com/go-gost/gost/pkg/listener" @@ -89,6 +90,7 @@ func (l *mwsListener) Init(md md.Metadata) (err error) { return } ln = metrics.WrapListener(l.options.Service, ln) + ln = admission.WrapListener(l.options.Admission, ln) if l.tlsEnabled { ln = tls.NewListener(ln, l.options.TLSConfig) diff --git a/pkg/service/service.go b/pkg/service/service.go index be3b8f4..b4f5cd1 100644 --- a/pkg/service/service.go +++ b/pkg/service/service.go @@ -93,7 +93,6 @@ func (s *service) Serve() error { if s.options.admission != nil && !s.options.admission.Admit(conn.RemoteAddr().String()) { - s.options.logger.Infof("admission: %s is denied", conn.RemoteAddr()) conn.Close() continue }