Compare commits

..

19 Commits

Author SHA1 Message Date
2697697552 add rewrite body http node setting 2024-07-19 20:42:33 +08:00
48d070d345 add router interface 2024-07-08 22:28:21 +08:00
30cc928705 add observer/stats 2024-07-04 23:03:22 +08:00
4e831b95e8 fix timeout for router 2024-06-25 20:37:08 +08:00
ca340b1bf1 add netns option for handler and listener 2024-06-24 21:13:23 +08:00
5aede9a2b3 add support for linux network namespace 2024-06-21 23:34:12 +08:00
8d554ddcf7 add filter node option 2024-05-08 21:20:29 +08:00
5d6c2115fa update go.mod 2024-04-24 23:31:55 +08:00
f06f3bb46a Merge pull request #6 from cgroschupp/feature/whitelist
feat: change whitelist bypass behaviour
2024-04-24 22:49:14 +08:00
ed7f2dcdfc feat: change whitelist bypass behaviour 2024-04-16 20:58:50 +02:00
fea73cf682 Update recorder.go
Changed functions name to MetadataRecordOption
2024-03-09 20:45:15 +08:00
a06608ccaf added http url rewrite setting for forwarder node 2024-01-31 23:17:24 +08:00
04314fa084 added auther option for node http settings 2024-01-27 21:06:04 +08:00
5a427b4eaf add observer 2024-01-03 20:53:00 +08:00
6b5c04b5e4 added logger group 2023-12-19 21:23:06 +08:00
abc73f2ca2 update router interface 2023-11-19 16:14:03 +08:00
6b01698ea9 add router interface 2023-11-19 14:20:35 +08:00
486f2cee61 add traffic limiter option for handler 2023-11-18 18:25:40 +08:00
a916f04016 update ingress interface 2023-11-13 20:38:50 +08:00
28 changed files with 452 additions and 1133 deletions

View File

@ -1,6 +1,9 @@
package bypass
import "context"
import (
"context"
"slices"
)
type Options struct {
Host string
@ -24,6 +27,7 @@ func WithPathOption(path string) Option {
// Bypass is a filter of address (IP or domain).
type Bypass interface {
// Contains reports whether the bypass includes addr.
IsWhitelist() bool
Contains(ctx context.Context, network, addr string, opts ...Option) bool
}
@ -38,10 +42,33 @@ func BypassGroup(bypasses ...Bypass) Bypass {
}
func (p *bypassGroup) Contains(ctx context.Context, network, addr string, opts ...Option) bool {
var whitelist, blacklist []bool
for _, bypass := range p.bypasses {
if bypass != nil && bypass.Contains(ctx, network, addr, opts...) {
return true
result := bypass.Contains(ctx, network, addr, opts...)
if bypass.IsWhitelist() {
whitelist = append(whitelist, result)
} else {
blacklist = append(blacklist, result)
}
}
status := false
if len(whitelist) > 0 {
if slices.Contains(whitelist, false) {
status = false
} else {
status = true
}
}
if !status && len(blacklist) > 0 {
if slices.Contains(blacklist, true) {
status = true
} else {
status = false
}
}
return status
}
func (p *bypassGroup) IsWhitelist() bool {
return false
}

View File

@ -1,6 +1,8 @@
package chain
import (
"regexp"
"github.com/go-gost/core/auth"
"github.com/go-gost/core/bypass"
"github.com/go-gost/core/hosts"
@ -9,9 +11,29 @@ import (
"github.com/go-gost/core/selector"
)
type NodeFilterSettings struct {
Protocol string
Host string
Path string
}
type HTTPURLRewriteSetting struct {
Pattern *regexp.Regexp
Replacement string
}
type HTTPBodyRewriteSettings struct {
Type string
Pattern *regexp.Regexp
Replacement []byte
}
type HTTPNodeSettings struct {
Host string
Header map[string]string
Host string
Header map[string]string
Auther auth.Authenticator
RewriteURL []HTTPURLRewriteSetting
RewriteBody []HTTPBodyRewriteSettings
}
type TLSNodeSettings struct {
@ -25,23 +47,20 @@ type TLSNodeSettings struct {
}
type NodeOptions struct {
Transport *Transport
Network string
Transport Transporter
Bypass bypass.Bypass
Resolver resolver.Resolver
HostMapper hosts.HostMapper
Metadata metadata.Metadata
Host string
Network string
Protocol string
Path string
Filter *NodeFilterSettings
HTTP *HTTPNodeSettings
TLS *TLSNodeSettings
Auther auth.Authenticator
Metadata metadata.Metadata
}
type NodeOption func(*NodeOptions)
func TransportNodeOption(tr *Transport) NodeOption {
func TransportNodeOption(tr Transporter) NodeOption {
return func(o *NodeOptions) {
o.Transport = tr
}
@ -65,33 +84,15 @@ func HostMapperNodeOption(m hosts.HostMapper) NodeOption {
}
}
func HostNodeOption(host string) NodeOption {
return func(o *NodeOptions) {
o.Host = host
}
}
func NetworkNodeOption(network string) NodeOption {
return func(o *NodeOptions) {
o.Network = network
}
}
func ProtocolNodeOption(protocol string) NodeOption {
func NodeFilterOption(filter *NodeFilterSettings) NodeOption {
return func(o *NodeOptions) {
o.Protocol = protocol
}
}
func PathNodeOption(path string) NodeOption {
return func(o *NodeOptions) {
o.Path = path
}
}
func MetadataNodeOption(md metadata.Metadata) NodeOption {
return func(o *NodeOptions) {
o.Metadata = md
o.Filter = filter
}
}
@ -107,9 +108,9 @@ func TLSNodeOption(tlsSettings *TLSNodeSettings) NodeOption {
}
}
func AutherNodeOption(auther auth.Authenticator) NodeOption {
func MetadataNodeOption(md metadata.Metadata) NodeOption {
return func(o *NodeOptions) {
o.Auther = auther
o.Metadata = md
}
}

View File

@ -1,44 +0,0 @@
package chain
import (
"context"
"fmt"
"net"
"github.com/go-gost/core/hosts"
"github.com/go-gost/core/logger"
"github.com/go-gost/core/resolver"
)
func Resolve(ctx context.Context, network, addr string, r resolver.Resolver, hosts hosts.HostMapper, log logger.Logger) (string, error) {
if addr == "" {
return addr, nil
}
host, port, _ := net.SplitHostPort(addr)
if host == "" {
return addr, nil
}
if hosts != nil {
if ips, _ := hosts.Lookup(ctx, network, host); len(ips) > 0 {
log.Debugf("hit host mapper: %s -> %s", host, ips)
return net.JoinHostPort(ips[0].String(), port), nil
}
}
if r != nil {
ips, err := r.Resolve(ctx, network, host)
if err != nil {
if err == resolver.ErrInvalid {
return addr, nil
}
log.Error(err)
}
if len(ips) == 0 {
return "", fmt.Errorf("resolver: domain %s does not exist", host)
}
return net.JoinHostPort(ips[0].String(), port), nil
}
return addr, nil
}

View File

@ -2,117 +2,39 @@ package chain
import (
"context"
"errors"
"fmt"
"net"
"time"
"github.com/go-gost/core/common/net/dialer"
"github.com/go-gost/core/common/net/udp"
"github.com/go-gost/core/logger"
)
var (
ErrEmptyRoute = errors.New("empty route")
)
var (
DefaultRoute Route = &route{}
)
type Route interface {
Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error)
Bind(ctx context.Context, network, address string, opts ...BindOption) (net.Listener, error)
Nodes() []*Node
}
// route is a Route without nodes.
type route struct{}
func (*route) Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error) {
var options DialOptions
for _, opt := range opts {
opt(&options)
}
netd := dialer.NetDialer{
Timeout: options.Timeout,
Interface: options.Interface,
Logger: options.Logger,
}
if options.SockOpts != nil {
netd.Mark = options.SockOpts.Mark
}
return netd.Dial(ctx, network, address)
}
func (*route) Bind(ctx context.Context, network, address string, opts ...BindOption) (net.Listener, error) {
var options BindOptions
for _, opt := range opts {
opt(&options)
}
switch network {
case "tcp", "tcp4", "tcp6":
addr, err := net.ResolveTCPAddr(network, address)
if err != nil {
return nil, err
}
return net.ListenTCP(network, addr)
case "udp", "udp4", "udp6":
addr, err := net.ResolveUDPAddr(network, address)
if err != nil {
return nil, err
}
conn, err := net.ListenUDP(network, addr)
if err != nil {
return nil, err
}
logger := logger.Default().WithFields(map[string]any{
"network": network,
"address": address,
})
ln := udp.NewListener(conn, &udp.ListenConfig{
Backlog: options.Backlog,
ReadQueueSize: options.UDPDataQueueSize,
ReadBufferSize: options.UDPDataBufferSize,
TTL: options.UDPConnTTL,
KeepAlive: true,
Logger: logger,
})
return ln, err
default:
err := fmt.Errorf("network %s unsupported", network)
return nil, err
}
}
func (r *route) Nodes() []*Node {
return nil
}
type DialOptions struct {
Timeout time.Duration
Interface string
Netns string
SockOpts *SockOpts
Logger logger.Logger
}
type DialOption func(opts *DialOptions)
func TimeoutDialOption(d time.Duration) DialOption {
return func(opts *DialOptions) {
opts.Timeout = d
}
}
func InterfaceDialOption(ifName string) DialOption {
return func(opts *DialOptions) {
opts.Interface = ifName
}
}
func NetnsDialOption(netns string) DialOption {
return func(opts *DialOptions) {
opts.Netns = netns
}
}
func SockOptsDialOption(so *SockOpts) DialOption {
return func(opts *DialOptions) {
opts.SockOpts = so

View File

@ -1,9 +1,7 @@
package chain
import (
"bytes"
"context"
"fmt"
"net"
"time"
@ -21,6 +19,7 @@ type RouterOptions struct {
Retries int
Timeout time.Duration
IfceName string
Netns string
SockOpts *SockOpts
Chain Chainer
Resolver resolver.Resolver
@ -37,6 +36,12 @@ func InterfaceRouterOption(ifceName string) RouterOption {
}
}
func NetnsRouterOption(netns string) RouterOption {
return func(o *RouterOptions) {
o.Netns = netns
}
}
func SockOptsRouterOption(so *SockOpts) RouterOption {
return func(o *RouterOptions) {
o.SockOpts = so
@ -85,177 +90,8 @@ func LoggerRouterOption(logger logger.Logger) RouterOption {
}
}
type Router struct {
options RouterOptions
}
func NewRouter(opts ...RouterOption) *Router {
r := &Router{}
for _, opt := range opts {
if opt != nil {
opt(&r.options)
}
}
if r.options.Logger == nil {
r.options.Logger = logger.Default().WithFields(map[string]any{"kind": "router"})
}
return r
}
func (r *Router) Options() *RouterOptions {
if r == nil {
return nil
}
return &r.options
}
func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
host := address
if h, _, _ := net.SplitHostPort(address); h != "" {
host = h
}
r.record(ctx, recorder.RecorderServiceRouterDialAddress, []byte(host))
conn, err = r.dial(ctx, network, address)
if err != nil {
r.record(ctx, recorder.RecorderServiceRouterDialAddressError, []byte(host))
return
}
if network == "udp" || network == "udp4" || network == "udp6" {
if _, ok := conn.(net.PacketConn); !ok {
return &packetConn{conn}, nil
}
}
return
}
func (r *Router) record(ctx context.Context, name string, data []byte) error {
if len(data) == 0 {
return nil
}
for _, rec := range r.options.Recorders {
if rec.Record == name {
err := rec.Recorder.Record(ctx, data)
if err != nil {
r.options.Logger.Errorf("record %s: %v", name, err)
}
return err
}
}
return nil
}
func (r *Router) dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
count := r.options.Retries + 1
if count <= 0 {
count = 1
}
r.options.Logger.Debugf("dial %s/%s", address, network)
for i := 0; i < count; i++ {
var ipAddr string
ipAddr, err = Resolve(ctx, "ip", address, r.options.Resolver, r.options.HostMapper, r.options.Logger)
if err != nil {
r.options.Logger.Error(err)
break
}
var route Route
if r.options.Chain != nil {
route = r.options.Chain.Route(ctx, network, ipAddr, WithHostRouteOption(address))
}
if r.options.Logger.IsLevelEnabled(logger.DebugLevel) {
buf := bytes.Buffer{}
for _, node := range routePath(route) {
fmt.Fprintf(&buf, "%s@%s > ", node.Name, node.Addr)
}
fmt.Fprintf(&buf, "%s", ipAddr)
r.options.Logger.Debugf("route(retry=%d) %s", i, buf.String())
}
if route == nil {
route = DefaultRoute
}
conn, err = route.Dial(ctx, network, ipAddr,
InterfaceDialOption(r.options.IfceName),
SockOptsDialOption(r.options.SockOpts),
LoggerDialOption(r.options.Logger),
TimeoutDialOption(r.options.Timeout),
)
if err == nil {
break
}
r.options.Logger.Errorf("route(retry=%d) %s", i, err)
}
return
}
func (r *Router) Bind(ctx context.Context, network, address string, opts ...BindOption) (ln net.Listener, err error) {
count := r.options.Retries + 1
if count <= 0 {
count = 1
}
r.options.Logger.Debugf("bind on %s/%s", address, network)
for i := 0; i < count; i++ {
var route Route
if r.options.Chain != nil {
route = r.options.Chain.Route(ctx, network, address)
if route == nil || len(route.Nodes()) == 0 {
err = ErrEmptyRoute
return
}
}
if r.options.Logger.IsLevelEnabled(logger.DebugLevel) {
buf := bytes.Buffer{}
for _, node := range routePath(route) {
fmt.Fprintf(&buf, "%s@%s > ", node.Name, node.Addr)
}
fmt.Fprintf(&buf, "%s", address)
r.options.Logger.Debugf("route(retry=%d) %s", i, buf.String())
}
if route == nil {
route = DefaultRoute
}
ln, err = route.Bind(ctx, network, address, opts...)
if err == nil {
break
}
r.options.Logger.Errorf("route(retry=%d) %s", i, err)
}
return
}
func routePath(route Route) (path []*Node) {
if route == nil {
return
}
for _, node := range route.Nodes() {
if tr := node.Options().Transport; tr != nil {
path = append(path, routePath(tr.Options().Route)...)
}
path = append(path, node)
}
return
}
type packetConn struct {
net.Conn
}
func (c *packetConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
n, err = c.Read(b)
addr = c.Conn.RemoteAddr()
return
}
func (c *packetConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return c.Write(b)
type Router interface {
Options() *RouterOptions
Dial(ctx context.Context, network, address string) (net.Conn, error)
Bind(ctx context.Context, network, address string, opts ...BindOption) (net.Listener, error)
}

View File

@ -3,19 +3,16 @@ package chain
import (
"context"
"net"
"time"
net_dialer "github.com/go-gost/core/common/net/dialer"
"github.com/go-gost/core/connector"
"github.com/go-gost/core/dialer"
)
type TransportOptions struct {
Addr string
IfceName string
Netns string
SockOpts *SockOpts
Route Route
Timeout time.Duration
}
type TransportOption func(*TransportOptions)
@ -32,6 +29,12 @@ func InterfaceTransportOption(ifceName string) TransportOption {
}
}
func NetnsTransportOption(netns string) TransportOption {
return func(o *TransportOptions) {
o.Netns = netns
}
}
func SockOptsTransportOption(so *SockOpts) TransportOption {
return func(o *TransportOptions) {
o.SockOpts = so
@ -44,103 +47,12 @@ func RouteTransportOption(route Route) TransportOption {
}
}
func TimeoutTransportOption(timeout time.Duration) TransportOption {
return func(o *TransportOptions) {
o.Timeout = timeout
}
}
type Transport struct {
dialer dialer.Dialer
connector connector.Connector
options TransportOptions
}
func NewTransport(d dialer.Dialer, c connector.Connector, opts ...TransportOption) *Transport {
tr := &Transport{
dialer: d,
connector: c,
}
for _, opt := range opts {
if opt != nil {
opt(&tr.options)
}
}
return tr
}
func (tr *Transport) Dial(ctx context.Context, addr string) (net.Conn, error) {
netd := &net_dialer.NetDialer{
Interface: tr.options.IfceName,
Timeout: tr.options.Timeout,
}
if tr.options.SockOpts != nil {
netd.Mark = tr.options.SockOpts.Mark
}
if tr.options.Route != nil && len(tr.options.Route.Nodes()) > 0 {
netd.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) {
return tr.options.Route.Dial(ctx, network, addr)
}
}
opts := []dialer.DialOption{
dialer.HostDialOption(tr.options.Addr),
dialer.NetDialerDialOption(netd),
}
return tr.dialer.Dial(ctx, addr, opts...)
}
func (tr *Transport) Handshake(ctx context.Context, conn net.Conn) (net.Conn, error) {
var err error
if hs, ok := tr.dialer.(dialer.Handshaker); ok {
conn, err = hs.Handshake(ctx, conn,
dialer.AddrHandshakeOption(tr.options.Addr))
if err != nil {
return nil, err
}
}
if hs, ok := tr.connector.(connector.Handshaker); ok {
return hs.Handshake(ctx, conn)
}
return conn, nil
}
func (tr *Transport) Connect(ctx context.Context, conn net.Conn, network, address string) (net.Conn, error) {
netd := &net_dialer.NetDialer{
Interface: tr.options.IfceName,
Timeout: tr.options.Timeout,
}
if tr.options.SockOpts != nil {
netd.Mark = tr.options.SockOpts.Mark
}
return tr.connector.Connect(ctx, conn, network, address,
connector.NetDialerConnectOption(netd),
)
}
func (tr *Transport) Bind(ctx context.Context, conn net.Conn, network, address string, opts ...connector.BindOption) (net.Listener, error) {
if binder, ok := tr.connector.(connector.Binder); ok {
return binder.Bind(ctx, conn, network, address, opts...)
}
return nil, connector.ErrBindUnsupported
}
func (tr *Transport) Multiplex() bool {
if mux, ok := tr.dialer.(dialer.Multiplexer); ok {
return mux.Multiplex()
}
return false
}
func (tr *Transport) Options() *TransportOptions {
if tr != nil {
return &tr.options
}
return nil
}
func (tr *Transport) Copy() *Transport {
tr2 := &Transport{}
*tr2 = *tr
return tr
type Transporter interface {
Dial(ctx context.Context, addr string) (net.Conn, error)
Handshake(ctx context.Context, conn net.Conn) (net.Conn, error)
Connect(ctx context.Context, conn net.Conn, network, address string) (net.Conn, error)
Bind(ctx context.Context, conn net.Conn, network, address string, opts ...connector.BindOption) (net.Listener, error)
Multiplex() bool
Options() *TransportOptions
Copy() Transporter
}

View File

@ -1,84 +0,0 @@
package net
import (
"fmt"
"net"
)
func ParseInterfaceAddr(ifceName, network string) (ifce string, addr []net.Addr, err error) {
if ifceName == "" {
addr = append(addr, nil)
return
}
ip := net.ParseIP(ifceName)
if ip == nil {
var ife *net.Interface
ife, err = net.InterfaceByName(ifceName)
if err != nil {
return
}
var addrs []net.Addr
addrs, err = ife.Addrs()
if err != nil {
return
}
if len(addrs) == 0 {
err = fmt.Errorf("addr not found for interface %s", ifceName)
return
}
ifce = ifceName
for _, addr_ := range addrs {
if ipNet, ok := addr_.(*net.IPNet); ok {
addr = append(addr, ipToAddr(ipNet.IP, network))
}
}
} else {
ifce, err = findInterfaceByIP(ip)
if err != nil {
return
}
addr = []net.Addr{ipToAddr(ip, network)}
}
return
}
func ipToAddr(ip net.IP, network string) (addr net.Addr) {
port := 0
switch network {
case "tcp", "tcp4", "tcp6":
addr = &net.TCPAddr{IP: ip, Port: port}
return
case "udp", "udp4", "udp6":
addr = &net.UDPAddr{IP: ip, Port: port}
return
default:
addr = &net.IPAddr{IP: ip}
return
}
}
func findInterfaceByIP(ip net.IP) (string, error) {
ifces, err := net.Interfaces()
if err != nil {
return "", err
}
for _, ifce := range ifces {
addrs, _ := ifce.Addrs()
if len(addrs) == 0 {
continue
}
for _, addr := range addrs {
ipAddr, _ := addr.(*net.IPNet)
if ipAddr == nil {
continue
}
// logger.Default().Infof("%s-%s", ipAddr, ip)
if ipAddr.IP.Equal(ip) {
return ifce.Name, nil
}
}
}
return "", nil
}

10
common/net/dialer.go Normal file
View File

@ -0,0 +1,10 @@
package net
import (
"context"
"net"
)
type Dialer interface {
Dial(ctx context.Context, network, addr string) (net.Conn, error)
}

View File

@ -1,154 +0,0 @@
package dialer
import (
"context"
"fmt"
"net"
"strings"
"syscall"
"time"
xnet "github.com/go-gost/core/common/net"
"github.com/go-gost/core/logger"
)
const (
DefaultTimeout = 10 * time.Second
)
var (
DefaultNetDialer = &NetDialer{}
)
type NetDialer struct {
Interface string
Mark int
Timeout time.Duration
DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
Logger logger.Logger
}
func (d *NetDialer) Dial(ctx context.Context, network, addr string) (conn net.Conn, err error) {
if d == nil {
d = DefaultNetDialer
}
timeout := d.Timeout
if timeout <= 0 {
timeout = DefaultTimeout
}
if d.DialFunc != nil {
return d.DialFunc(ctx, network, addr)
}
log := d.Logger
if log == nil {
log = logger.Default()
}
switch network {
case "unix":
netd := net.Dialer{}
return netd.DialContext(ctx, network, addr)
default:
}
deadline := time.Now().Add(timeout)
ifces := strings.Split(d.Interface, ",")
for _, ifce := range ifces {
strict := strings.HasSuffix(ifce, "!")
ifce = strings.TrimSuffix(ifce, "!")
var ifceName string
var ifAddrs []net.Addr
ifceName, ifAddrs, err = xnet.ParseInterfaceAddr(ifce, network)
if err != nil && strict {
return
}
for _, ifAddr := range ifAddrs {
conn, err = d.dialOnce(ctx, network, addr, ifceName, ifAddr, deadline, log)
if err == nil {
return
}
log.Debugf("dial %s %v@%s failed: %s", network, ifAddr, ifceName, err)
if strict &&
!strings.Contains(err.Error(), "no suitable address found") &&
!strings.Contains(err.Error(), "mismatched local address type") {
return
}
if time.Until(deadline) < 0 {
return
}
}
}
return
}
func (d *NetDialer) dialOnce(ctx context.Context, network, addr, ifceName string, ifAddr net.Addr, deadline time.Time, log logger.Logger) (net.Conn, error) {
if ifceName != "" {
log.Debugf("interface: %s %v/%s", ifceName, ifAddr, network)
}
switch network {
case "udp", "udp4", "udp6":
if addr == "" {
var laddr *net.UDPAddr
if ifAddr != nil {
laddr, _ = ifAddr.(*net.UDPAddr)
}
c, err := net.ListenUDP(network, laddr)
if err != nil {
return nil, err
}
sc, err := c.SyscallConn()
if err != nil {
log.Error(err)
return nil, err
}
err = sc.Control(func(fd uintptr) {
if ifceName != "" {
if err := bindDevice(fd, ifceName); err != nil {
log.Warnf("bind device: %v", err)
}
}
if d.Mark != 0 {
if err := setMark(fd, d.Mark); err != nil {
log.Warnf("set mark: %v", err)
}
}
})
if err != nil {
log.Error(err)
}
return c, nil
}
case "tcp", "tcp4", "tcp6":
default:
return nil, fmt.Errorf("dial: unsupported network %s", network)
}
netd := net.Dialer{
Deadline: deadline,
LocalAddr: ifAddr,
Control: func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
if ifceName != "" {
if err := bindDevice(fd, ifceName); err != nil {
log.Warnf("bind device: %v", err)
}
}
if d.Mark != 0 {
if err := setMark(fd, d.Mark); err != nil {
log.Warnf("set mark: %v", err)
}
}
})
},
}
return netd.DialContext(ctx, network, addr)
}

View File

@ -1,19 +0,0 @@
package dialer
import (
"golang.org/x/sys/unix"
)
func bindDevice(fd uintptr, ifceName string) error {
if ifceName == "" {
return nil
}
return unix.BindToDevice(int(fd), ifceName)
}
func setMark(fd uintptr, mark int) error {
if mark == 0 {
return nil
}
return unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, mark)
}

View File

@ -1,11 +0,0 @@
//go:build !linux
package dialer
func bindDevice(fd uintptr, ifceName string) error {
return nil
}
func setMark(fd uintptr, mark int) error {
return nil
}

View File

@ -1,108 +0,0 @@
package udp
import (
"errors"
"net"
"sync"
"sync/atomic"
"github.com/go-gost/core/common/bufpool"
)
// conn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
type conn struct {
net.PacketConn
localAddr net.Addr
remoteAddr net.Addr
rc chan []byte // data receive queue
idle int32 // indicate the connection is idle
closed chan struct{}
closeMutex sync.Mutex
keepAlive bool
}
func newConn(c net.PacketConn, laddr, remoteAddr net.Addr, queueSize int, keepAlive bool) *conn {
return &conn{
PacketConn: c,
localAddr: laddr,
remoteAddr: remoteAddr,
rc: make(chan []byte, queueSize),
closed: make(chan struct{}),
keepAlive: keepAlive,
}
}
func (c *conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
select {
case bb := <-c.rc:
n = copy(b, bb)
c.SetIdle(false)
bufpool.Put(bb)
case <-c.closed:
err = net.ErrClosed
return
}
addr = c.remoteAddr
return
}
func (c *conn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *conn) Write(b []byte) (n int, err error) {
n, err = c.WriteTo(b, c.remoteAddr)
if !c.keepAlive {
c.Close()
}
return
}
func (c *conn) Close() error {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()
select {
case <-c.closed:
default:
close(c.closed)
}
return nil
}
func (c *conn) LocalAddr() net.Addr {
return c.localAddr
}
func (c *conn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *conn) IsIdle() bool {
return atomic.LoadInt32(&c.idle) > 0
}
func (c *conn) SetIdle(idle bool) {
v := int32(0)
if idle {
v = 1
}
atomic.StoreInt32(&c.idle, v)
}
func (c *conn) WriteQueue(b []byte) error {
select {
case c.rc <- b:
return nil
case <-c.closed:
return net.ErrClosed
default:
return errors.New("recv queue is full")
}
}

View File

@ -1,131 +0,0 @@
package udp
import (
"net"
"time"
"github.com/go-gost/core/common/bufpool"
"github.com/go-gost/core/logger"
)
type ListenConfig struct {
Addr net.Addr
Backlog int
ReadQueueSize int
ReadBufferSize int
TTL time.Duration
KeepAlive bool
Logger logger.Logger
}
type listener struct {
conn net.PacketConn
cqueue chan net.Conn
connPool *connPool
// mux sync.Mutex
closed chan struct{}
errChan chan error
config *ListenConfig
}
func NewListener(conn net.PacketConn, cfg *ListenConfig) net.Listener {
if cfg == nil {
cfg = &ListenConfig{}
}
ln := &listener{
conn: conn,
cqueue: make(chan net.Conn, cfg.Backlog),
closed: make(chan struct{}),
errChan: make(chan error, 1),
config: cfg,
}
if cfg.KeepAlive {
ln.connPool = newConnPool(cfg.TTL).WithLogger(cfg.Logger)
}
go ln.listenLoop()
return ln
}
func (ln *listener) Accept() (conn net.Conn, err error) {
select {
case conn = <-ln.cqueue:
return
case <-ln.closed:
return nil, net.ErrClosed
case err = <-ln.errChan:
if err == nil {
err = net.ErrClosed
}
return
}
}
func (ln *listener) listenLoop() {
for {
select {
case <-ln.closed:
return
default:
}
b := bufpool.Get(ln.config.ReadBufferSize)
n, raddr, err := ln.conn.ReadFrom(b)
if err != nil {
ln.errChan <- err
close(ln.errChan)
return
}
c := ln.getConn(raddr)
if c == nil {
bufpool.Put(b)
continue
}
if err := c.WriteQueue(b[:n]); err != nil {
ln.config.Logger.Warn("data discarded: ", err)
}
}
}
func (ln *listener) Addr() net.Addr {
if ln.config.Addr != nil {
return ln.config.Addr
}
return ln.conn.LocalAddr()
}
func (ln *listener) Close() error {
select {
case <-ln.closed:
default:
close(ln.closed)
ln.conn.Close()
ln.connPool.Close()
}
return nil
}
func (ln *listener) getConn(raddr net.Addr) *conn {
// ln.mux.Lock()
// defer ln.mux.Unlock()
c, ok := ln.connPool.Get(raddr.String())
if ok {
return c
}
c = newConn(ln.conn, ln.Addr(), raddr, ln.config.ReadQueueSize, ln.config.KeepAlive)
select {
case ln.cqueue <- c:
ln.connPool.Set(raddr.String(), c)
return c
default:
c.Close()
ln.config.Logger.Warnf("connection queue is full, client %s discarded", raddr)
return nil
}
}

View File

@ -1,115 +0,0 @@
package udp
import (
"sync"
"time"
"github.com/go-gost/core/logger"
)
type connPool struct {
m sync.Map
ttl time.Duration
closed chan struct{}
logger logger.Logger
}
func newConnPool(ttl time.Duration) *connPool {
p := &connPool{
ttl: ttl,
closed: make(chan struct{}),
}
go p.idleCheck()
return p
}
func (p *connPool) WithLogger(logger logger.Logger) *connPool {
p.logger = logger
return p
}
func (p *connPool) Get(key any) (c *conn, ok bool) {
if p == nil {
return
}
v, ok := p.m.Load(key)
if ok {
c, ok = v.(*conn)
}
return
}
func (p *connPool) Set(key any, c *conn) {
if p == nil {
return
}
p.m.Store(key, c)
}
func (p *connPool) Delete(key any) {
if p == nil {
return
}
p.m.Delete(key)
}
func (p *connPool) Close() {
if p == nil {
return
}
select {
case <-p.closed:
return
default:
}
close(p.closed)
p.m.Range(func(k, v any) bool {
if c, ok := v.(*conn); ok && c != nil {
c.Close()
}
return true
})
}
func (p *connPool) idleCheck() {
ticker := time.NewTicker(p.ttl)
defer ticker.Stop()
for {
select {
case <-ticker.C:
size := 0
idles := 0
p.m.Range(func(key, value any) bool {
c, ok := value.(*conn)
if !ok || c == nil {
p.Delete(key)
return true
}
size++
if c.IsIdle() {
idles++
p.Delete(key)
c.Close()
return true
}
c.SetIdle(true)
return true
})
if idles > 0 {
p.logger.Debugf("connection pool: size=%d, idle=%d", size, idles)
}
case <-p.closed:
return
}
}
}

View File

@ -1,41 +0,0 @@
package udp
import (
"io"
"net"
"syscall"
)
type Conn interface {
net.PacketConn
io.Reader
io.Writer
readUDP
writeUDP
setBuffer
syscallConn
remoteAddr
}
type setBuffer interface {
SetReadBuffer(bytes int) error
SetWriteBuffer(bytes int) error
}
type readUDP interface {
ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error)
}
type writeUDP interface {
WriteToUDP(b []byte, addr *net.UDPAddr) (int, error)
WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error)
}
type syscallConn interface {
SyscallConn() (syscall.RawConn, error)
}
type remoteAddr interface {
RemoteAddr() net.Addr
}

View File

@ -5,7 +5,7 @@ import (
"net/url"
"time"
"github.com/go-gost/core/common/net/dialer"
xnet "github.com/go-gost/core/common/net"
"github.com/go-gost/core/logger"
)
@ -36,14 +36,14 @@ func LoggerOption(logger logger.Logger) Option {
}
type ConnectOptions struct {
NetDialer *dialer.NetDialer
Dialer xnet.Dialer
}
type ConnectOption func(opts *ConnectOptions)
func NetDialerConnectOption(netd *dialer.NetDialer) ConnectOption {
func DialerConnectOption(dialer xnet.Dialer) ConnectOption {
return func(opts *ConnectOptions) {
opts.NetDialer = netd
opts.Dialer = dialer
}
}

View File

@ -4,7 +4,7 @@ import (
"crypto/tls"
"net/url"
"github.com/go-gost/core/common/net/dialer"
xnet "github.com/go-gost/core/common/net"
"github.com/go-gost/core/logger"
)
@ -42,8 +42,8 @@ func ProxyProtocolOption(ppv int) Option {
}
type DialOptions struct {
Host string
NetDialer *dialer.NetDialer
Host string
Dialer xnet.Dialer
}
type DialOption func(opts *DialOptions)
@ -54,9 +54,9 @@ func HostDialOption(host string) DialOption {
}
}
func NetDialerDialOption(netd *dialer.NetDialer) DialOption {
func NetDialerDialOption(dialer xnet.Dialer) DialOption {
return func(opts *DialOptions) {
opts.NetDialer = netd
opts.Dialer = dialer
}
}

4
go.mod
View File

@ -1,5 +1,5 @@
module github.com/go-gost/core
go 1.18
go 1.22
require golang.org/x/sys v0.12.0
toolchain go1.22.2

2
go.sum
View File

@ -1,2 +0,0 @@
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@ -8,19 +8,24 @@ import (
"github.com/go-gost/core/bypass"
"github.com/go-gost/core/chain"
"github.com/go-gost/core/limiter/rate"
"github.com/go-gost/core/limiter/traffic"
"github.com/go-gost/core/logger"
"github.com/go-gost/core/metadata"
"github.com/go-gost/core/observer"
)
type Options struct {
Bypass bypass.Bypass
Router *chain.Router
Router chain.Router
Auth *url.Userinfo
Auther auth.Authenticator
RateLimiter rate.RateLimiter
Limiter traffic.TrafficLimiter
TLSConfig *tls.Config
Logger logger.Logger
Observer observer.Observer
Service string
Netns string
}
type Option func(opts *Options)
@ -31,7 +36,7 @@ func BypassOption(bypass bypass.Bypass) Option {
}
}
func RouterOption(router *chain.Router) Option {
func RouterOption(router chain.Router) Option {
return func(opts *Options) {
opts.Router = router
}
@ -55,6 +60,12 @@ func RateLimiterOption(limiter rate.RateLimiter) Option {
}
}
func TrafficLimiterOption(limiter traffic.TrafficLimiter) Option {
return func(opts *Options) {
opts.Limiter = limiter
}
}
func TLSConfigOption(tlsConfig *tls.Config) Option {
return func(opts *Options) {
opts.TLSConfig = tlsConfig
@ -67,12 +78,24 @@ func LoggerOption(logger logger.Logger) Option {
}
}
func ObserverOption(observer observer.Observer) Option {
return func(opts *Options) {
opts.Observer = observer
}
}
func ServiceOption(service string) Option {
return func(opts *Options) {
opts.Service = service
}
}
func NetnsOption(netns string) Option {
return func(opts *Options) {
opts.Netns = netns
}
}
type HandleOptions struct {
Metadata metadata.Metadata
}

View File

@ -2,15 +2,20 @@ package ingress
import "context"
type GetOptions struct{}
type Options struct{}
type GetOption func(opts *GetOptions)
type Option func(opts *Options)
type SetOptions struct{}
type SetOption func(opts *SetOptions)
type Rule struct {
// Hostname is the hostname match pattern, e.g. example.com, *.example.org or .example.com.
Hostname string
// Endpoint is the tunnel ID for the hostname.
Endpoint string
}
type Ingress interface {
Get(ctx context.Context, host string, opts ...GetOption) string
Set(ctx context.Context, host, endpoint string, opts ...SetOption)
// SetRule adds or updates a rule for the ingress.
SetRule(ctx context.Context, rule *Rule, opts ...Option) bool
// GetRule queries a rule by host.
GetRule(ctx context.Context, host string, opts ...Option) *Rule
}

View File

@ -10,7 +10,40 @@ type Limiter interface {
Set(n int)
}
type TrafficLimiter interface {
In(key string) Limiter
Out(key string) Limiter
type Options struct {
Network string
Addr string
Client string
Src string
}
type Option func(opts *Options)
func NetworkOption(network string) Option {
return func(opts *Options) {
opts.Network = network
}
}
func AddrOption(addr string) Option {
return func(opts *Options) {
opts.Addr = addr
}
}
func ClientOption(client string) Option {
return func(opts *Options) {
opts.Client = client
}
}
func SrcOption(src string) Option {
return func(opts *Options) {
opts.Src = src
}
}
type TrafficLimiter interface {
In(ctx context.Context, key string, opts ...Option) Limiter
Out(ctx context.Context, key string, opts ...Option) Limiter
}

View File

@ -10,6 +10,7 @@ import (
"github.com/go-gost/core/limiter/conn"
"github.com/go-gost/core/limiter/traffic"
"github.com/go-gost/core/logger"
"github.com/go-gost/core/observer/stats"
)
type Options struct {
@ -21,9 +22,12 @@ type Options struct {
TrafficLimiter traffic.TrafficLimiter
ConnLimiter conn.ConnLimiter
Chain chain.Chainer
Stats *stats.Stats
Logger logger.Logger
Service string
ProxyProtocol int
Netns string
Router chain.Router
}
type Option func(opts *Options)
@ -70,9 +74,9 @@ func ConnLimiterOption(limiter conn.ConnLimiter) Option {
}
}
func ChainOption(chain chain.Chainer) Option {
func StatsOption(stats *stats.Stats) Option {
return func(opts *Options) {
opts.Chain = chain
opts.Stats = stats
}
}
@ -93,3 +97,15 @@ func ProxyProtocolOption(ppv int) Option {
opts.ProxyProtocol = ppv
}
}
func NetnsOption(netns string) Option {
return func(opts *Options) {
opts.Netns = netns
}
}
func RouterOption(router chain.Router) Option {
return func(opts *Options) {
opts.Router = router
}
}

View File

@ -55,3 +55,109 @@ func Default() Logger {
func SetDefault(logger Logger) {
defaultLogger = logger
}
type loggerGroup struct {
loggers []Logger
}
func LoggerGroup(loggers ...Logger) Logger {
return &loggerGroup{
loggers: loggers,
}
}
func (l *loggerGroup) WithFields(m map[string]any) Logger {
lg := &loggerGroup{}
for i := range l.loggers {
lg.loggers = append(lg.loggers, l.loggers[i].WithFields(m))
}
return lg
}
func (l *loggerGroup) Trace(args ...any) {
for _, lg := range l.loggers {
lg.Trace(args...)
}
}
func (l *loggerGroup) Tracef(format string, args ...any) {
for _, lg := range l.loggers {
lg.Tracef(format, args...)
}
}
func (l *loggerGroup) Debug(args ...any) {
for _, lg := range l.loggers {
lg.Debug(args...)
}
}
func (l *loggerGroup) Debugf(format string, args ...any) {
for _, lg := range l.loggers {
lg.Debugf(format, args...)
}
}
func (l *loggerGroup) Info(args ...any) {
for _, lg := range l.loggers {
lg.Info(args...)
}
}
func (l *loggerGroup) Infof(format string, args ...any) {
for _, lg := range l.loggers {
lg.Infof(format, args...)
}
}
func (l *loggerGroup) Warn(args ...any) {
for _, lg := range l.loggers {
lg.Warn(args...)
}
}
func (l *loggerGroup) Warnf(format string, args ...any) {
for _, lg := range l.loggers {
lg.Warnf(format, args...)
}
}
func (l *loggerGroup) Error(args ...any) {
for _, lg := range l.loggers {
lg.Error(args...)
}
}
func (l *loggerGroup) Errorf(format string, args ...any) {
for _, lg := range l.loggers {
lg.Errorf(format, args...)
}
}
func (l *loggerGroup) Fatal(args ...any) {
for _, lg := range l.loggers {
lg.Fatal(args...)
}
}
func (l *loggerGroup) Fatalf(format string, args ...any) {
for _, lg := range l.loggers {
lg.Fatalf(format, args...)
}
}
func (l *loggerGroup) GetLevel() LogLevel {
for _, lg := range l.loggers {
return lg.GetLevel()
}
return InfoLevel
}
func (l *loggerGroup) IsLevelEnabled(level LogLevel) bool {
for _, lg := range l.loggers {
if lg.IsLevelEnabled(level) {
return true
}
}
return false
}

22
observer/observer.go Normal file
View File

@ -0,0 +1,22 @@
package observer
import "context"
type Options struct{}
type Option func(opts *Options)
type Observer interface {
Observe(ctx context.Context, events []Event, opts ...Option) error
}
type EventType string
const (
EventStatus EventType = "status"
EventStats EventType = "stats"
)
type Event interface {
Type() EventType
}

93
observer/stats/stats.go Normal file
View File

@ -0,0 +1,93 @@
package stats
import (
"sync/atomic"
"github.com/go-gost/core/observer"
)
type Kind int
const (
KindTotalConns Kind = 1
KindCurrentConns Kind = 2
KindInputBytes Kind = 3
KindOutputBytes Kind = 4
KindTotalErrs Kind = 5
)
type Stats struct {
updated atomic.Bool
totalConns atomic.Uint64
currentConns atomic.Int64
inputBytes atomic.Uint64
outputBytes atomic.Uint64
totalErrs atomic.Uint64
}
func (s *Stats) Add(kind Kind, n int64) {
if s == nil {
return
}
switch kind {
case KindTotalConns:
if n > 0 {
s.totalConns.Add(uint64(n))
}
case KindCurrentConns:
s.currentConns.Add(n)
case KindInputBytes:
if n > 0 {
s.inputBytes.Add(uint64(n))
}
case KindOutputBytes:
if n > 0 {
s.outputBytes.Add(uint64(n))
}
case KindTotalErrs:
if n > 0 {
s.totalErrs.Add(uint64(n))
}
}
s.updated.Store(true)
}
func (s *Stats) Get(kind Kind) uint64 {
if s == nil {
return 0
}
switch kind {
case KindTotalConns:
return s.totalConns.Load()
case KindCurrentConns:
return uint64(s.currentConns.Load())
case KindInputBytes:
return s.inputBytes.Load()
case KindOutputBytes:
return s.outputBytes.Load()
case KindTotalErrs:
return s.totalErrs.Load()
}
return 0
}
func (s *Stats) IsUpdated() bool {
return s.updated.Swap(false)
}
type StatsEvent struct {
Kind string
Service string
Client string
TotalConns uint64
CurrentConns uint64
InputBytes uint64
OutputBytes uint64
TotalErrs uint64
}
func (StatsEvent) Type() observer.EventType {
return observer.EventStats
}

View File

@ -10,7 +10,7 @@ type RecordOptions struct {
type RecordOption func(opts *RecordOptions)
func MetadataReocrdOption(md any) RecordOption {
func MetadataRecordOption(md any) RecordOption {
return func(opts *RecordOptions) {
opts.Metadata = md
}

22
router/router.go Normal file
View File

@ -0,0 +1,22 @@
package router
import (
"context"
"net"
)
type Options struct{}
type Option func(opts *Options)
type Route struct {
// Net is the destination network, e.g. 192.168.0.0/16, 172.10.10.0/24.
Net *net.IPNet
// Gateway is the gateway for the destination network.
Gateway net.IP
}
type Router interface {
// GetRoute queries a route by destination IP address.
GetRoute(ctx context.Context, dst net.IP, opts ...Option) *Route
}