diff --git a/pkg/admission/wrapper/conn.go b/pkg/admission/wrapper/conn.go new file mode 100644 index 0000000..673ee3e --- /dev/null +++ b/pkg/admission/wrapper/conn.go @@ -0,0 +1,223 @@ +package wrapper + +import ( + "errors" + "io" + "net" + "syscall" + + "github.com/go-gost/gost/v3/pkg/admission" +) + +var ( + errUnsupport = errors.New("unsupported operation") +) + +type packetConn struct { + net.PacketConn + admission admission.Admission +} + +func WrapPacketConn(admission admission.Admission, pc net.PacketConn) net.PacketConn { + if admission == nil { + return pc + } + return &packetConn{ + PacketConn: pc, + admission: admission, + } +} + +func (c *packetConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + for { + n, addr, err = c.PacketConn.ReadFrom(p) + if err != nil { + return + } + + if c.admission != nil && + !c.admission.Admit(addr.String()) { + continue + } + + return + } +} + +type udpConn struct { + net.PacketConn + admission admission.Admission +} + +func WrapUDPConn(admission admission.Admission, pc net.PacketConn) UDPConn { + return &udpConn{ + PacketConn: pc, + admission: admission, + } +} + +func (c *udpConn) RemoteAddr() net.Addr { + if nc, ok := c.PacketConn.(remoteAddr); ok { + return nc.RemoteAddr() + } + return nil +} + +func (c *udpConn) SetReadBuffer(n int) error { + if nc, ok := c.PacketConn.(setBuffer); ok { + return nc.SetReadBuffer(n) + } + return errUnsupport +} + +func (c *udpConn) SetWriteBuffer(n int) error { + if nc, ok := c.PacketConn.(setBuffer); ok { + return nc.SetWriteBuffer(n) + } + return errUnsupport +} + +func (c *udpConn) Read(b []byte) (n int, err error) { + if nc, ok := c.PacketConn.(io.Reader); ok { + n, err = nc.Read(b) + return + } + err = errUnsupport + return +} + +func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + for { + n, addr, err = c.PacketConn.ReadFrom(p) + if err != nil { + return + } + if c.admission != nil && + !c.admission.Admit(addr.String()) { + continue + } + return + } +} + +func (c *udpConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) { + if nc, ok := c.PacketConn.(readUDP); ok { + for { + n, addr, err = nc.ReadFromUDP(b) + if err != nil { + return + } + if c.admission != nil && + !c.admission.Admit(addr.String()) { + continue + } + return + } + } + err = errUnsupport + return +} + +func (c *udpConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) { + if nc, ok := c.PacketConn.(readUDP); ok { + for { + n, oobn, flags, addr, err = nc.ReadMsgUDP(b, oob) + if err != nil { + return + } + if c.admission != nil && + !c.admission.Admit(addr.String()) { + continue + } + return + } + } + err = errUnsupport + return +} + +func (c *udpConn) Write(b []byte) (n int, err error) { + if nc, ok := c.PacketConn.(io.Writer); ok { + n, err = nc.Write(b) + return + } + err = errUnsupport + return +} + +func (c *udpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + n, err = c.PacketConn.WriteTo(p, addr) + return +} + +func (c *udpConn) WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) { + if nc, ok := c.PacketConn.(writeUDP); ok { + n, err = nc.WriteToUDP(b, addr) + return + } + err = errUnsupport + return +} + +func (c *udpConn) WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) { + if nc, ok := c.PacketConn.(writeUDP); ok { + n, oobn, err = nc.WriteMsgUDP(b, oob, addr) + return + } + err = errUnsupport + return +} + +func (c *udpConn) SyscallConn() (rc syscall.RawConn, err error) { + if nc, ok := c.PacketConn.(syscallConn); ok { + return nc.SyscallConn() + } + err = errUnsupport + return +} + +func (c *udpConn) SetDSCP(n int) error { + if nc, ok := c.PacketConn.(setDSCP); ok { + return nc.SetDSCP(n) + } + return nil +} + +type UDPConn interface { + net.PacketConn + io.Reader + io.Writer + readUDP + writeUDP + setBuffer + syscallConn + remoteAddr +} + +type setBuffer interface { + SetReadBuffer(bytes int) error + SetWriteBuffer(bytes int) error +} + +type readUDP interface { + ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) + ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) +} + +type writeUDP interface { + WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) + WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) +} + +type syscallConn interface { + SyscallConn() (syscall.RawConn, error) +} + +type remoteAddr interface { + RemoteAddr() net.Addr +} + +// tcpraw.TCPConn +type setDSCP interface { + SetDSCP(int) error +} diff --git a/pkg/admission/wrapper/listener.go b/pkg/admission/wrapper/listener.go new file mode 100644 index 0000000..eb411bb --- /dev/null +++ b/pkg/admission/wrapper/listener.go @@ -0,0 +1,37 @@ +package wrapper + +import ( + "net" + + "github.com/go-gost/gost/v3/pkg/admission" +) + +type listener struct { + net.Listener + admission admission.Admission +} + +func WrapListener(admission admission.Admission, ln net.Listener) net.Listener { + if admission == nil { + return ln + } + return &listener{ + Listener: ln, + admission: admission, + } +} + +func (ln *listener) Accept() (net.Conn, error) { + for { + c, err := ln.Listener.Accept() + if err != nil { + return nil, err + } + if ln.admission != nil && + !ln.admission.Admit(c.RemoteAddr().String()) { + c.Close() + continue + } + return c, err + } +} diff --git a/pkg/listener/tls/listener.go b/pkg/listener/tls/listener.go index bd01e0d..be3c468 100644 --- a/pkg/listener/tls/listener.go +++ b/pkg/listener/tls/listener.go @@ -4,7 +4,7 @@ import ( "crypto/tls" "net" - "github.com/go-gost/gost/v3/pkg/common/admission" + admission "github.com/go-gost/gost/v3/pkg/admission/wrapper" "github.com/go-gost/gost/v3/pkg/listener" "github.com/go-gost/gost/v3/pkg/logger" md "github.com/go-gost/gost/v3/pkg/metadata"