initial commit

This commit is contained in:
ginuerzh
2022-03-16 19:40:29 +08:00
commit 7db81fcfeb
109 changed files with 8782 additions and 0 deletions

113
common/bufpool/pool.go Normal file
View File

@ -0,0 +1,113 @@
package bufpool
import "sync"
var (
pools = []struct {
size int
pool sync.Pool
}{
{
size: 128,
pool: sync.Pool{
New: func() any {
b := make([]byte, 128)
return &b
},
},
},
{
size: 512,
pool: sync.Pool{
New: func() any {
b := make([]byte, 512)
return &b
},
},
},
{
size: 1024,
pool: sync.Pool{
New: func() any {
b := make([]byte, 1024)
return &b
},
},
},
{
size: 4096,
pool: sync.Pool{
New: func() any {
b := make([]byte, 4096)
return &b
},
},
},
{
size: 8192,
pool: sync.Pool{
New: func() any {
b := make([]byte, 8192)
return &b
},
},
},
{
size: 16 * 1024,
pool: sync.Pool{
New: func() any {
b := make([]byte, 16*1024)
return &b
},
},
},
{
size: 32 * 1024,
pool: sync.Pool{
New: func() any {
b := make([]byte, 32*1024)
return &b
},
},
},
{
size: 64 * 1024,
pool: sync.Pool{
New: func() any {
b := make([]byte, 64*1024)
return &b
},
},
},
{
size: 65 * 1024,
pool: sync.Pool{
New: func() any {
b := make([]byte, 65*1024)
return &b
},
},
},
}
)
// Get returns a buffer of specified size.
func Get(size int) *[]byte {
for i := range pools {
if size <= pools[i].size {
b := pools[i].pool.Get().(*[]byte)
*b = (*b)[:size]
return b
}
}
b := make([]byte, size)
return &b
}
func Put(b *[]byte) {
for i := range pools {
if cap(*b) == pools[i].size {
pools[i].pool.Put(b)
}
}
}

99
common/matcher/matcher.go Normal file
View File

@ -0,0 +1,99 @@
package matcher
import (
"net"
"strings"
"github.com/gobwas/glob"
)
// Matcher is a generic pattern matcher,
// it gives the match result of the given pattern for specific v.
type Matcher interface {
Match(v string) bool
}
// NewMatcher creates a Matcher for the given pattern.
// The acutal Matcher depends on the pattern:
// IP Matcher if pattern is a valid IP address.
// CIDR Matcher if pattern is a valid CIDR address.
// Domain Matcher if both of the above are not.
func NewMatcher(pattern string) Matcher {
if pattern == "" {
return nil
}
if ip := net.ParseIP(pattern); ip != nil {
return IPMatcher(ip)
}
if _, inet, err := net.ParseCIDR(pattern); err == nil {
return CIDRMatcher(inet)
}
return DomainMatcher(pattern)
}
type ipMatcher struct {
ip net.IP
}
// IPMatcher creates a Matcher for a specific IP address.
func IPMatcher(ip net.IP) Matcher {
return &ipMatcher{
ip: ip,
}
}
func (m *ipMatcher) Match(ip string) bool {
if m == nil {
return false
}
return m.ip.Equal(net.ParseIP(ip))
}
type cidrMatcher struct {
ipNet *net.IPNet
}
// CIDRMatcher creates a Matcher for a specific CIDR notation IP address.
func CIDRMatcher(inet *net.IPNet) Matcher {
return &cidrMatcher{
ipNet: inet,
}
}
func (m *cidrMatcher) Match(ip string) bool {
if m == nil || m.ipNet == nil {
return false
}
return m.ipNet.Contains(net.ParseIP(ip))
}
type domainMatcher struct {
pattern string
glob glob.Glob
}
// DomainMatcher creates a Matcher for a specific domain pattern,
// the pattern can be a plain domain such as 'example.com',
// a wildcard such as '*.exmaple.com' or a special wildcard '.example.com'.
func DomainMatcher(pattern string) Matcher {
p := pattern
if strings.HasPrefix(pattern, ".") {
p = pattern[1:] // trim the prefix '.'
pattern = "*" + p
}
return &domainMatcher{
pattern: p,
glob: glob.MustCompile(pattern),
}
}
func (m *domainMatcher) Match(domain string) bool {
if m == nil || m.glob == nil {
return false
}
if domain == m.pattern {
return true
}
return m.glob.Match(domain)
}

144
common/net/dialer/dialer.go Normal file
View File

@ -0,0 +1,144 @@
package dialer
import (
"context"
"fmt"
"net"
"syscall"
"time"
"github.com/go-gost/core/logger"
)
var (
DefaultNetDialer = &NetDialer{
Timeout: 30 * time.Second,
}
)
type NetDialer struct {
Interface string
Timeout time.Duration
DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
Logger logger.Logger
}
func (d *NetDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
if d == nil {
d = DefaultNetDialer
}
log := d.Logger
if log == nil {
log = logger.Default()
}
ifceName, ifAddr, err := parseInterfaceAddr(d.Interface, network)
if err != nil {
return nil, err
}
if d.DialFunc != nil {
return d.DialFunc(ctx, network, addr)
}
logger.Default().Infof("interface: %s %v/%s", ifceName, ifAddr, network)
switch network {
case "udp", "udp4", "udp6":
if addr == "" {
var laddr *net.UDPAddr
if ifAddr != nil {
laddr, _ = ifAddr.(*net.UDPAddr)
}
return net.ListenUDP(network, laddr)
}
case "tcp", "tcp4", "tcp6":
default:
return nil, fmt.Errorf("dial: unsupported network %s", network)
}
netd := net.Dialer{
Timeout: d.Timeout,
LocalAddr: ifAddr,
Control: func(network, address string, c syscall.RawConn) error {
var cerr error
err := c.Control(func(fd uintptr) {
cerr = bindDevice(fd, ifceName)
})
if err != nil {
return err
}
if cerr != nil {
return cerr
}
return nil
},
}
return netd.DialContext(ctx, network, addr)
}
func parseInterfaceAddr(ifceName, network string) (ifce string, addr net.Addr, err error) {
if ifceName == "" {
return
}
ip := net.ParseIP(ifceName)
if ip == nil {
var ife *net.Interface
ife, err = net.InterfaceByName(ifceName)
if err != nil {
return
}
var addrs []net.Addr
addrs, err = ife.Addrs()
if err != nil {
return
}
if len(addrs) == 0 {
err = fmt.Errorf("addr not found for interface %s", ifceName)
return
}
ip = addrs[0].(*net.IPNet).IP
ifce = ifceName
} else {
ifce, err = findInterfaceByIP(ip)
if err != nil {
return
}
}
port := 0
switch network {
case "tcp", "tcp4", "tcp6":
addr = &net.TCPAddr{IP: ip, Port: port}
return
case "udp", "udp4", "udp6":
addr = &net.UDPAddr{IP: ip, Port: port}
return
default:
addr = &net.IPAddr{IP: ip}
return
}
}
func findInterfaceByIP(ip net.IP) (string, error) {
ifces, err := net.Interfaces()
if err != nil {
return "", err
}
for _, ifce := range ifces {
addrs, _ := ifce.Addrs()
if len(addrs) == 0 {
continue
}
for _, addr := range addrs {
ipAddr, _ := addr.(*net.IPNet)
if ipAddr == nil {
continue
}
// logger.Default().Infof("%s-%s", ipAddr, ip)
if ipAddr.IP.Equal(ip) {
return ifce.Name, nil
}
}
}
return "", nil
}

View File

@ -0,0 +1,14 @@
package dialer
import (
"golang.org/x/sys/unix"
)
func bindDevice(fd uintptr, ifceName string) error {
// unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1)
// unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
if ifceName == "" {
return nil
}
return unix.BindToDevice(int(fd), ifceName)
}

View File

@ -0,0 +1,7 @@
//go:build !linux
package dialer
func bindDevice(fd uintptr, ifceName string) error {
return nil
}

126
common/net/relay/relay.go Normal file
View File

@ -0,0 +1,126 @@
package relay
import (
"net"
"github.com/go-gost/core/bypass"
"github.com/go-gost/core/common/bufpool"
"github.com/go-gost/core/logger"
)
type UDPRelay struct {
pc1 net.PacketConn
pc2 net.PacketConn
bypass bypass.Bypass
bufferSize int
logger logger.Logger
}
func NewUDPRelay(pc1, pc2 net.PacketConn) *UDPRelay {
return &UDPRelay{
pc1: pc1,
pc2: pc2,
}
}
func (r *UDPRelay) WithBypass(bp bypass.Bypass) *UDPRelay {
r.bypass = bp
return r
}
func (r *UDPRelay) WithLogger(logger logger.Logger) *UDPRelay {
r.logger = logger
return r
}
func (r *UDPRelay) SetBufferSize(n int) {
r.bufferSize = n
}
func (r *UDPRelay) Run() (err error) {
bufSize := r.bufferSize
if bufSize <= 0 {
bufSize = 1024
}
errc := make(chan error, 2)
go func() {
for {
err := func() error {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
n, raddr, err := r.pc1.ReadFrom(*b)
if err != nil {
return err
}
if r.bypass != nil && r.bypass.Contains(raddr.String()) {
if r.logger != nil {
r.logger.Warn("bypass: ", raddr)
}
return nil
}
if _, err := r.pc2.WriteTo((*b)[:n], raddr); err != nil {
return err
}
if r.logger != nil {
r.logger.Debugf("%s >>> %s data: %d",
r.pc2.LocalAddr(), raddr, n)
}
return nil
}()
if err != nil {
errc <- err
return
}
}
}()
go func() {
for {
err := func() error {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
n, raddr, err := r.pc2.ReadFrom(*b)
if err != nil {
return err
}
if r.bypass != nil && r.bypass.Contains(raddr.String()) {
if r.logger != nil {
r.logger.Warn("bypass: ", raddr)
}
return nil
}
if _, err := r.pc1.WriteTo((*b)[:n], raddr); err != nil {
return err
}
if r.logger != nil {
r.logger.Debugf("%s <<< %s data: %d",
r.pc2.LocalAddr(), raddr, n)
}
return nil
}()
if err != nil {
errc <- err
return
}
}
}()
return <-errc
}

50
common/net/transport.go Normal file
View File

@ -0,0 +1,50 @@
package net
import (
"bufio"
"io"
"net"
"github.com/go-gost/core/common/bufpool"
)
func Transport(rw1, rw2 io.ReadWriter) error {
errc := make(chan error, 1)
go func() {
errc <- copyBuffer(rw1, rw2)
}()
go func() {
errc <- copyBuffer(rw2, rw1)
}()
err := <-errc
if err != nil && err == io.EOF {
err = nil
}
return err
}
func copyBuffer(dst io.Writer, src io.Reader) error {
buf := bufpool.Get(16 * 1024)
defer bufpool.Put(buf)
_, err := io.CopyBuffer(dst, src, *buf)
return err
}
type bufferReaderConn struct {
net.Conn
br *bufio.Reader
}
func NewBufferReaderConn(conn net.Conn, br *bufio.Reader) net.Conn {
return &bufferReaderConn{
Conn: conn,
br: br,
}
}
func (c *bufferReaderConn) Read(b []byte) (int, error) {
return c.br.Read(b)
}

41
common/net/udp.go Normal file
View File

@ -0,0 +1,41 @@
package net
import (
"io"
"net"
"syscall"
)
type UDPConn interface {
net.PacketConn
io.Reader
io.Writer
readUDP
writeUDP
setBuffer
syscallConn
remoteAddr
}
type setBuffer interface {
SetReadBuffer(bytes int) error
SetWriteBuffer(bytes int) error
}
type readUDP interface {
ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error)
}
type writeUDP interface {
WriteToUDP(b []byte, addr *net.UDPAddr) (int, error)
WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error)
}
type syscallConn interface {
SyscallConn() (syscall.RawConn, error)
}
type remoteAddr interface {
RemoteAddr() net.Addr
}

View File

@ -0,0 +1,88 @@
package resolver
import (
"fmt"
"sync"
"time"
"github.com/go-gost/core/logger"
"github.com/miekg/dns"
)
type CacheKey string
// NewCacheKey generates resolver cache key from question of dns query.
func NewCacheKey(q *dns.Question) CacheKey {
if q == nil {
return ""
}
key := fmt.Sprintf("%s%s.%s", q.Name, dns.Class(q.Qclass).String(), dns.Type(q.Qtype).String())
return CacheKey(key)
}
type cacheItem struct {
msg *dns.Msg
ts time.Time
ttl time.Duration
}
type Cache struct {
m sync.Map
logger logger.Logger
}
func NewCache() *Cache {
return &Cache{}
}
func (c *Cache) WithLogger(logger logger.Logger) *Cache {
c.logger = logger
return c
}
func (c *Cache) Load(key CacheKey) *dns.Msg {
v, ok := c.m.Load(key)
if !ok {
return nil
}
item, ok := v.(*cacheItem)
if !ok {
return nil
}
if time.Since(item.ts) > item.ttl {
c.m.Delete(key)
return nil
}
c.logger.Debugf("hit resolver cache: %s", key)
return item.msg.Copy()
}
func (c *Cache) Store(key CacheKey, mr *dns.Msg, ttl time.Duration) {
if key == "" || mr == nil || ttl < 0 {
return
}
if ttl == 0 {
for _, answer := range mr.Answer {
v := time.Duration(answer.Header().Ttl) * time.Second
if ttl == 0 || ttl > v {
ttl = v
}
}
}
if ttl == 0 {
ttl = 30 * time.Second
}
c.m.Store(key, &cacheItem{
msg: mr.Copy(),
ts: time.Now(),
ttl: ttl,
})
c.logger.Debugf("resolver cache store: %s, ttl: %v", key, ttl)
}

View File

@ -0,0 +1,30 @@
package resolver
import (
"net"
"github.com/miekg/dns"
)
func AddSubnetOpt(m *dns.Msg, ip net.IP) {
if m == nil || ip == nil {
return
}
opt := new(dns.OPT)
opt.Hdr.Name = "."
opt.Hdr.Rrtype = dns.TypeOPT
e := new(dns.EDNS0_SUBNET)
e.Code = dns.EDNS0SUBNET
if ip := ip.To4(); ip != nil {
e.Family = 1
e.SourceNetmask = 24
e.Address = ip
} else {
e.Family = 2
e.SourceNetmask = 128
e.Address = ip.To16()
}
opt.Option = append(opt.Option, e)
m.Extra = append(m.Extra, opt)
}

178
common/util/tls/tls.go Normal file
View File

@ -0,0 +1,178 @@
package tls
import (
"crypto/tls"
"crypto/x509"
"errors"
"io/ioutil"
"net"
"time"
)
var (
// DefaultConfig is a default TLS config for global use.
DefaultConfig *tls.Config
)
// LoadServerConfig loads the certificate from cert & key files and optional client CA file.
func LoadServerConfig(certFile, keyFile, caFile string) (*tls.Config, error) {
if certFile == "" && keyFile == "" {
return DefaultConfig.Clone(), nil
}
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
cfg := &tls.Config{Certificates: []tls.Certificate{cert}}
pool, err := loadCA(caFile)
if err != nil {
return nil, err
}
if pool != nil {
cfg.ClientCAs = pool
cfg.ClientAuth = tls.RequireAndVerifyClientCert
}
return cfg, nil
}
// LoadClientConfig loads the certificate from cert & key files and optional CA file.
func LoadClientConfig(certFile, keyFile, caFile string, verify bool, serverName string) (*tls.Config, error) {
var cfg *tls.Config
if certFile == "" && keyFile == "" {
cfg = &tls.Config{}
} else {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
cfg = &tls.Config{
Certificates: []tls.Certificate{cert},
}
}
rootCAs, err := loadCA(caFile)
if err != nil {
return nil, err
}
cfg.RootCAs = rootCAs
cfg.ServerName = serverName
cfg.InsecureSkipVerify = !verify
// If the root ca is given, but skip verify, we verify the certificate manually.
if cfg.RootCAs != nil && !verify {
cfg.VerifyConnection = func(state tls.ConnectionState) error {
opts := x509.VerifyOptions{
Roots: cfg.RootCAs,
CurrentTime: time.Now(),
DNSName: "",
Intermediates: x509.NewCertPool(),
}
certs := state.PeerCertificates
for i, cert := range certs {
if i == 0 {
continue
}
opts.Intermediates.AddCert(cert)
}
_, err := certs[0].Verify(opts)
return err
}
}
return cfg, nil
}
func loadCA(caFile string) (cp *x509.CertPool, err error) {
if caFile == "" {
return
}
cp = x509.NewCertPool()
data, err := ioutil.ReadFile(caFile)
if err != nil {
return nil, err
}
if !cp.AppendCertsFromPEM(data) {
return nil, errors.New("AppendCertsFromPEM failed")
}
return
}
// Wrap a net.Conn into a client tls connection, performing any
// additional verification as needed.
//
// As of go 1.3, crypto/tls only supports either doing no certificate
// verification, or doing full verification including of the peer's
// DNS name. For consul, we want to validate that the certificate is
// signed by a known CA, but because consul doesn't use DNS names for
// node names, we don't verify the certificate DNS names. Since go 1.3
// no longer supports this mode of operation, we have to do it
// manually.
//
// This code is taken from consul:
// https://github.com/hashicorp/consul/blob/master/tlsutil/config.go
func WrapTLSClient(conn net.Conn, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) {
var err error
var tlsConn *tls.Conn
if timeout > 0 {
conn.SetDeadline(time.Now().Add(timeout))
defer conn.SetDeadline(time.Time{})
}
tlsConn = tls.Client(conn, tlsConfig)
// Otherwise perform handshake, but don't verify the domain
//
// The following is lightly-modified from the doFullHandshake
// method in https://golang.org/src/crypto/tls/handshake_client.go
if err = tlsConn.Handshake(); err != nil {
tlsConn.Close()
return nil, err
}
// We can do this in `tls.Config.VerifyConnection`, which effective for
// other TLS protocols such as WebSocket. See `route.go:parseChainNode`
/*
// If crypto/tls is doing verification, there's no need to do our own.
if tlsConfig.InsecureSkipVerify == false {
return tlsConn, nil
}
// Similarly if we use host's CA, we can do full handshake
if tlsConfig.RootCAs == nil {
return tlsConn, nil
}
opts := x509.VerifyOptions{
Roots: tlsConfig.RootCAs,
CurrentTime: time.Now(),
DNSName: "",
Intermediates: x509.NewCertPool(),
}
certs := tlsConn.ConnectionState().PeerCertificates
for i, cert := range certs {
if i == 0 {
continue
}
opts.Intermediates.AddCert(cert)
}
_, err = certs[0].Verify(opts)
if err != nil {
tlsConn.Close()
return nil, err
}
*/
return tlsConn, err
}

102
common/util/udp/conn.go Normal file
View File

@ -0,0 +1,102 @@
package udp
import (
"errors"
"net"
"sync"
"sync/atomic"
"github.com/go-gost/core/common/bufpool"
)
// Conn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
type Conn struct {
net.PacketConn
localAddr net.Addr
remoteAddr net.Addr
rc chan []byte // data receive queue
idle int32 // indicate the connection is idle
closed chan struct{}
closeMutex sync.Mutex
}
func NewConn(c net.PacketConn, localAddr, remoteAddr net.Addr, queueSize int) *Conn {
return &Conn{
PacketConn: c,
localAddr: localAddr,
remoteAddr: remoteAddr,
rc: make(chan []byte, queueSize),
closed: make(chan struct{}),
}
}
func (c *Conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
select {
case bb := <-c.rc:
n = copy(b, bb)
c.SetIdle(false)
bufpool.Put(&bb)
case <-c.closed:
err = net.ErrClosed
return
}
addr = c.remoteAddr
return
}
func (c *Conn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *Conn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.remoteAddr)
}
func (c *Conn) Close() error {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()
select {
case <-c.closed:
default:
close(c.closed)
}
return nil
}
func (c *Conn) LocalAddr() net.Addr {
return c.localAddr
}
func (c *Conn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *Conn) IsIdle() bool {
return atomic.LoadInt32(&c.idle) > 0
}
func (c *Conn) SetIdle(idle bool) {
v := int32(0)
if idle {
v = 1
}
atomic.StoreInt32(&c.idle, v)
}
func (c *Conn) WriteQueue(b []byte) error {
select {
case c.rc <- b:
return nil
case <-c.closed:
return net.ErrClosed
default:
return errors.New("recv queue is full")
}
}

120
common/util/udp/listener.go Normal file
View File

@ -0,0 +1,120 @@
package udp
import (
"net"
"sync"
"time"
"github.com/go-gost/core/common/bufpool"
"github.com/go-gost/core/logger"
)
type listener struct {
addr net.Addr
conn net.PacketConn
cqueue chan net.Conn
readQueueSize int
readBufferSize int
connPool *ConnPool
mux sync.Mutex
closed chan struct{}
errChan chan error
logger logger.Logger
}
func NewListener(conn net.PacketConn, addr net.Addr, backlog, dataQueueSize, dataBufferSize int, ttl time.Duration, logger logger.Logger) net.Listener {
ln := &listener{
conn: conn,
addr: addr,
cqueue: make(chan net.Conn, backlog),
connPool: NewConnPool(ttl).WithLogger(logger),
readQueueSize: dataQueueSize,
readBufferSize: dataBufferSize,
closed: make(chan struct{}),
errChan: make(chan error, 1),
logger: logger,
}
go ln.listenLoop()
return ln
}
func (ln *listener) Accept() (conn net.Conn, err error) {
select {
case conn = <-ln.cqueue:
return
case <-ln.closed:
return nil, net.ErrClosed
case err = <-ln.errChan:
if err == nil {
err = net.ErrClosed
}
return
}
}
func (ln *listener) listenLoop() {
for {
select {
case <-ln.closed:
return
default:
}
b := bufpool.Get(ln.readBufferSize)
n, raddr, err := ln.conn.ReadFrom(*b)
if err != nil {
ln.errChan <- err
close(ln.errChan)
return
}
c := ln.getConn(raddr)
if c == nil {
bufpool.Put(b)
continue
}
if err := c.WriteQueue((*b)[:n]); err != nil {
ln.logger.Warn("data discarded: ", err)
}
}
}
func (ln *listener) Addr() net.Addr {
return ln.addr
}
func (ln *listener) Close() error {
select {
case <-ln.closed:
default:
close(ln.closed)
ln.conn.Close()
ln.connPool.Close()
}
return nil
}
func (ln *listener) getConn(raddr net.Addr) *Conn {
ln.mux.Lock()
defer ln.mux.Unlock()
c, ok := ln.connPool.Get(raddr.String())
if ok {
return c
}
c = NewConn(ln.conn, ln.addr, raddr, ln.readQueueSize)
select {
case ln.cqueue <- c:
ln.connPool.Set(raddr.String(), c)
return c
default:
c.Close()
ln.logger.Warnf("connection queue is full, client %s discarded", raddr)
return nil
}
}

100
common/util/udp/pool.go Normal file
View File

@ -0,0 +1,100 @@
package udp
import (
"sync"
"time"
"github.com/go-gost/core/logger"
)
type ConnPool struct {
m sync.Map
ttl time.Duration
closed chan struct{}
logger logger.Logger
}
func NewConnPool(ttl time.Duration) *ConnPool {
p := &ConnPool{
ttl: ttl,
closed: make(chan struct{}),
}
go p.idleCheck()
return p
}
func (p *ConnPool) WithLogger(logger logger.Logger) *ConnPool {
p.logger = logger
return p
}
func (p *ConnPool) Get(key any) (c *Conn, ok bool) {
v, ok := p.m.Load(key)
if ok {
c, ok = v.(*Conn)
}
return
}
func (p *ConnPool) Set(key any, c *Conn) {
p.m.Store(key, c)
}
func (p *ConnPool) Delete(key any) {
p.m.Delete(key)
}
func (p *ConnPool) Close() {
select {
case <-p.closed:
return
default:
}
close(p.closed)
p.m.Range(func(k, v any) bool {
if c, ok := v.(*Conn); ok && c != nil {
c.Close()
}
return true
})
}
func (p *ConnPool) idleCheck() {
ticker := time.NewTicker(p.ttl)
defer ticker.Stop()
for {
select {
case <-ticker.C:
size := 0
idles := 0
p.m.Range(func(key, value any) bool {
c, ok := value.(*Conn)
if !ok || c == nil {
p.Delete(key)
return true
}
size++
if c.IsIdle() {
idles++
p.Delete(key)
c.Close()
return true
}
c.SetIdle(true)
return true
})
if idles > 0 {
p.logger.Debugf("connection pool: size=%d, idle=%d", size, idles)
}
case <-p.closed:
return
}
}
}