initial commit
88 common/util/resolver/cache.go Normal file
@@ -0,0 +1,88 @@
package resolver

import (
	"fmt"
	"sync"
	"time"

	"github.com/go-gost/core/logger"
	"github.com/miekg/dns"
)

type CacheKey string

// NewCacheKey generates a resolver cache key from the question of a DNS query.
func NewCacheKey(q *dns.Question) CacheKey {
	if q == nil {
		return ""
	}
	key := fmt.Sprintf("%s%s.%s", q.Name, dns.Class(q.Qclass).String(), dns.Type(q.Qtype).String())
	return CacheKey(key)
}

type cacheItem struct {
	msg *dns.Msg
	ts  time.Time
	ttl time.Duration
}

type Cache struct {
	m      sync.Map
	logger logger.Logger
}

func NewCache() *Cache {
	return &Cache{}
}

func (c *Cache) WithLogger(logger logger.Logger) *Cache {
	c.logger = logger
	return c
}

// Load returns a copy of the cached response for key, or nil if the entry
// is missing or has expired.
func (c *Cache) Load(key CacheKey) *dns.Msg {
	v, ok := c.m.Load(key)
	if !ok {
		return nil
	}

	item, ok := v.(*cacheItem)
	if !ok {
		return nil
	}

	if time.Since(item.ts) > item.ttl {
		c.m.Delete(key)
		return nil
	}

	c.logger.Debugf("hit resolver cache: %s", key)

	return item.msg.Copy()
}

// Store caches a copy of the response mr under key for ttl. A zero ttl is
// derived from the smallest answer record TTL, falling back to 30 seconds.
func (c *Cache) Store(key CacheKey, mr *dns.Msg, ttl time.Duration) {
	if key == "" || mr == nil || ttl < 0 {
		return
	}

	if ttl == 0 {
		for _, answer := range mr.Answer {
			v := time.Duration(answer.Header().Ttl) * time.Second
			if ttl == 0 || ttl > v {
				ttl = v
			}
		}
	}
	if ttl == 0 {
		ttl = 30 * time.Second
	}

	c.m.Store(key, &cacheItem{
		msg: mr.Copy(),
		ts:  time.Now(),
		ttl: ttl,
	})

	c.logger.Debugf("resolver cache store: %s, ttl: %v", key, ttl)
}
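A minimal usage sketch (not part of this commit) of how a caller in the same package might wrap an upstream exchange with this cache. The lookup helper and the exchange callback are hypothetical, and the cache is assumed to have been built via NewCache().WithLogger(...), since Load and Store log through the configured logger.

// Hypothetical helper showing the intended cache flow: check the cache,
// fall back to the upstream exchange, then store the answer.
func lookup(cache *Cache, q *dns.Msg, exchange func(*dns.Msg) (*dns.Msg, error)) (*dns.Msg, error) {
	key := NewCacheKey(&q.Question[0])
	if mr := cache.Load(key); mr != nil {
		mr.Id = q.Id // align the cached response with the current query ID
		return mr, nil
	}
	mr, err := exchange(q)
	if err != nil {
		return nil, err
	}
	cache.Store(key, mr, 0) // ttl=0: derive the TTL from the answer records
	return mr, nil
}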
30 common/util/resolver/resolver.go Normal file
@@ -0,0 +1,30 @@
package resolver

import (
	"net"

	"github.com/miekg/dns"
)

// AddSubnetOpt appends an EDNS Client Subnet option for ip to the DNS query m.
func AddSubnetOpt(m *dns.Msg, ip net.IP) {
	if m == nil || ip == nil {
		return
	}

	opt := new(dns.OPT)
	opt.Hdr.Name = "."
	opt.Hdr.Rrtype = dns.TypeOPT
	e := new(dns.EDNS0_SUBNET)
	e.Code = dns.EDNS0SUBNET
	if ip := ip.To4(); ip != nil {
		e.Family = 1
		e.SourceNetmask = 24
		e.Address = ip
	} else {
		e.Family = 2
		e.SourceNetmask = 128
		e.Address = ip.To16()
	}
	opt.Option = append(opt.Option, e)
	m.Extra = append(m.Extra, opt)
}
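A short usage sketch (not part of this commit): building an A query and attaching the client subnet option before sending it upstream. The newQueryWithSubnet helper is hypothetical.

// Hypothetical helper: prepare an A query carrying the client's subnet so the
// upstream resolver can return a geo-localized answer.
func newQueryWithSubnet(name string, clientIP net.IP) *dns.Msg {
	q := new(dns.Msg)
	q.SetQuestion(dns.Fqdn(name), dns.TypeA)
	AddSubnetOpt(q, clientIP)
	return q
}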
178 common/util/tls/tls.go Normal file
@@ -0,0 +1,178 @@
package tls

import (
	"crypto/tls"
	"crypto/x509"
	"errors"
	"io/ioutil"
	"net"
	"time"
)

var (
	// DefaultConfig is a default TLS config for global use.
	DefaultConfig *tls.Config
)

// LoadServerConfig loads the certificate from the cert & key files and the optional client CA file.
func LoadServerConfig(certFile, keyFile, caFile string) (*tls.Config, error) {
	if certFile == "" && keyFile == "" {
		return DefaultConfig.Clone(), nil
	}

	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return nil, err
	}

	cfg := &tls.Config{Certificates: []tls.Certificate{cert}}

	pool, err := loadCA(caFile)
	if err != nil {
		return nil, err
	}
	if pool != nil {
		cfg.ClientCAs = pool
		cfg.ClientAuth = tls.RequireAndVerifyClientCert
	}

	return cfg, nil
}

// LoadClientConfig loads the certificate from the cert & key files and the optional CA file.
func LoadClientConfig(certFile, keyFile, caFile string, verify bool, serverName string) (*tls.Config, error) {
	var cfg *tls.Config

	if certFile == "" && keyFile == "" {
		cfg = &tls.Config{}
	} else {
		cert, err := tls.LoadX509KeyPair(certFile, keyFile)
		if err != nil {
			return nil, err
		}

		cfg = &tls.Config{
			Certificates: []tls.Certificate{cert},
		}
	}

	rootCAs, err := loadCA(caFile)
	if err != nil {
		return nil, err
	}

	cfg.RootCAs = rootCAs
	cfg.ServerName = serverName
	cfg.InsecureSkipVerify = !verify

	// If a root CA is given but verification is skipped, verify the certificate
	// chain manually against that CA (without checking the server name).
	if cfg.RootCAs != nil && !verify {
		cfg.VerifyConnection = func(state tls.ConnectionState) error {
			opts := x509.VerifyOptions{
				Roots:         cfg.RootCAs,
				CurrentTime:   time.Now(),
				DNSName:       "",
				Intermediates: x509.NewCertPool(),
			}

			certs := state.PeerCertificates
			for i, cert := range certs {
				if i == 0 {
					continue
				}
				opts.Intermediates.AddCert(cert)
			}

			_, err := certs[0].Verify(opts)
			return err
		}
	}

	return cfg, nil
}

func loadCA(caFile string) (cp *x509.CertPool, err error) {
	if caFile == "" {
		return
	}
	cp = x509.NewCertPool()
	data, err := ioutil.ReadFile(caFile)
	if err != nil {
		return nil, err
	}
	if !cp.AppendCertsFromPEM(data) {
		return nil, errors.New("AppendCertsFromPEM failed")
	}
	return
}

// WrapTLSClient wraps a net.Conn into a client TLS connection, performing any
// additional verification as needed.
//
// As of go 1.3, crypto/tls only supports either doing no certificate
// verification, or doing full verification including the peer's DNS name.
// For consul, we want to validate that the certificate is signed by a known
// CA, but because consul doesn't use DNS names for node names, we don't
// verify the certificate DNS names. Since go 1.3 no longer supports this
// mode of operation, we have to do it manually.
//
// This code is taken from consul:
// https://github.com/hashicorp/consul/blob/master/tlsutil/config.go
func WrapTLSClient(conn net.Conn, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) {
	var err error
	var tlsConn *tls.Conn

	if timeout > 0 {
		conn.SetDeadline(time.Now().Add(timeout))
		defer conn.SetDeadline(time.Time{})
	}

	tlsConn = tls.Client(conn, tlsConfig)

	// Perform the handshake, but don't verify the domain name here.
	//
	// The following is lightly-modified from the doFullHandshake
	// method in https://golang.org/src/crypto/tls/handshake_client.go
	if err = tlsConn.Handshake(); err != nil {
		tlsConn.Close()
		return nil, err
	}

	// The manual verification below is now done in `tls.Config.VerifyConnection`,
	// which is also effective for other TLS-based protocols such as WebSocket.
	// See `route.go:parseChainNode`.
	/*
		// If crypto/tls is doing verification, there's no need to do our own.
		if tlsConfig.InsecureSkipVerify == false {
			return tlsConn, nil
		}

		// Similarly, if we use the host's CA, we can do the full handshake.
		if tlsConfig.RootCAs == nil {
			return tlsConn, nil
		}

		opts := x509.VerifyOptions{
			Roots:         tlsConfig.RootCAs,
			CurrentTime:   time.Now(),
			DNSName:       "",
			Intermediates: x509.NewCertPool(),
		}

		certs := tlsConn.ConnectionState().PeerCertificates
		for i, cert := range certs {
			if i == 0 {
				continue
			}
			opts.Intermediates.AddCert(cert)
		}

		_, err = certs[0].Verify(opts)
		if err != nil {
			tlsConn.Close()
			return nil, err
		}
	*/

	return tlsConn, err
}
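A usage sketch (not part of this commit) combining LoadClientConfig and WrapTLSClient for an outbound connection. The dialTLS helper and the 10-second timeout are hypothetical; with verify=false and a CA file, the manual VerifyConnection path above still checks the chain against that CA.

// Hypothetical helper: dial TCP, then wrap it as a TLS client using a config
// loaded from the given files. With verify=false and a CA file, the chain is
// still verified against that CA, but the server name is not.
func dialTLS(addr, certFile, keyFile, caFile, serverName string) (net.Conn, error) {
	cfg, err := LoadClientConfig(certFile, keyFile, caFile, false, serverName)
	if err != nil {
		return nil, err
	}
	conn, err := net.DialTimeout("tcp", addr, 10*time.Second)
	if err != nil {
		return nil, err
	}
	return WrapTLSClient(conn, cfg, 10*time.Second)
}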
102 common/util/udp/conn.go Normal file
@@ -0,0 +1,102 @@
package udp

import (
	"errors"
	"net"
	"sync"
	"sync/atomic"

	"github.com/go-gost/core/common/bufpool"
)

// Conn is a server-side connection for a UDP client peer. It implements
// net.Conn and net.PacketConn.
type Conn struct {
	net.PacketConn
	localAddr  net.Addr
	remoteAddr net.Addr
	rc         chan []byte // data receive queue
	idle       int32       // indicates whether the connection is idle
	closed     chan struct{}
	closeMutex sync.Mutex
}

func NewConn(c net.PacketConn, localAddr, remoteAddr net.Addr, queueSize int) *Conn {
	return &Conn{
		PacketConn: c,
		localAddr:  localAddr,
		remoteAddr: remoteAddr,
		rc:         make(chan []byte, queueSize),
		closed:     make(chan struct{}),
	}
}

func (c *Conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
	select {
	case bb := <-c.rc:
		n = copy(b, bb)
		c.SetIdle(false)
		bufpool.Put(&bb)

	case <-c.closed:
		err = net.ErrClosed
		return
	}

	addr = c.remoteAddr

	return
}

func (c *Conn) Read(b []byte) (n int, err error) {
	n, _, err = c.ReadFrom(b)
	return
}

func (c *Conn) Write(b []byte) (n int, err error) {
	return c.WriteTo(b, c.remoteAddr)
}

func (c *Conn) Close() error {
	c.closeMutex.Lock()
	defer c.closeMutex.Unlock()

	select {
	case <-c.closed:
	default:
		close(c.closed)
	}
	return nil
}

func (c *Conn) LocalAddr() net.Addr {
	return c.localAddr
}

func (c *Conn) RemoteAddr() net.Addr {
	return c.remoteAddr
}

func (c *Conn) IsIdle() bool {
	return atomic.LoadInt32(&c.idle) > 0
}

func (c *Conn) SetIdle(idle bool) {
	v := int32(0)
	if idle {
		v = 1
	}
	atomic.StoreInt32(&c.idle, v)
}

// WriteQueue pushes an inbound datagram into the connection's receive queue.
// It returns an error if the queue is full or the connection is closed.
func (c *Conn) WriteQueue(b []byte) error {
	select {
	case c.rc <- b:
		return nil

	case <-c.closed:
		return net.ErrClosed

	default:
		return errors.New("recv queue is full")
	}
}
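A small sketch (not part of this commit) of the read/write path of a virtual connection: reads drain the queue that the listener fills via WriteQueue, and writes go out through the shared PacketConn to the peer address. The echo handler is hypothetical.

// Hypothetical handler: echo every datagram back to the client. Read blocks
// on the receive queue (fed by WriteQueue), and Write targets remoteAddr on
// the underlying PacketConn.
func echo(c *Conn, bufSize int) error {
	b := make([]byte, bufSize)
	for {
		n, err := c.Read(b)
		if err != nil {
			return err
		}
		if _, err := c.Write(b[:n]); err != nil {
			return err
		}
	}
}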
120 common/util/udp/listener.go Normal file
@@ -0,0 +1,120 @@
package udp

import (
	"net"
	"sync"
	"time"

	"github.com/go-gost/core/common/bufpool"
	"github.com/go-gost/core/logger"
)

type listener struct {
	addr           net.Addr
	conn           net.PacketConn
	cqueue         chan net.Conn
	readQueueSize  int
	readBufferSize int
	connPool       *ConnPool
	mux            sync.Mutex
	closed         chan struct{}
	errChan        chan error
	logger         logger.Logger
}

// NewListener wraps a net.PacketConn into a net.Listener that demultiplexes
// UDP datagrams into per-client virtual connections.
func NewListener(conn net.PacketConn, addr net.Addr, backlog, dataQueueSize, dataBufferSize int, ttl time.Duration, logger logger.Logger) net.Listener {
	ln := &listener{
		conn:           conn,
		addr:           addr,
		cqueue:         make(chan net.Conn, backlog),
		connPool:       NewConnPool(ttl).WithLogger(logger),
		readQueueSize:  dataQueueSize,
		readBufferSize: dataBufferSize,
		closed:         make(chan struct{}),
		errChan:        make(chan error, 1),
		logger:         logger,
	}
	go ln.listenLoop()

	return ln
}

func (ln *listener) Accept() (conn net.Conn, err error) {
	select {
	case conn = <-ln.cqueue:
		return
	case <-ln.closed:
		return nil, net.ErrClosed
	case err = <-ln.errChan:
		if err == nil {
			err = net.ErrClosed
		}
		return
	}
}

func (ln *listener) listenLoop() {
	for {
		select {
		case <-ln.closed:
			return
		default:
		}

		b := bufpool.Get(ln.readBufferSize)

		n, raddr, err := ln.conn.ReadFrom(*b)
		if err != nil {
			ln.errChan <- err
			close(ln.errChan)
			return
		}

		c := ln.getConn(raddr)
		if c == nil {
			bufpool.Put(b)
			continue
		}

		if err := c.WriteQueue((*b)[:n]); err != nil {
			ln.logger.Warn("data discarded: ", err)
		}
	}
}

func (ln *listener) Addr() net.Addr {
	return ln.addr
}

func (ln *listener) Close() error {
	select {
	case <-ln.closed:
	default:
		close(ln.closed)
		ln.conn.Close()
		ln.connPool.Close()
	}

	return nil
}

// getConn returns the virtual connection for raddr, creating it and queueing
// it for Accept if it does not exist yet.
func (ln *listener) getConn(raddr net.Addr) *Conn {
	ln.mux.Lock()
	defer ln.mux.Unlock()

	c, ok := ln.connPool.Get(raddr.String())
	if ok {
		return c
	}

	c = NewConn(ln.conn, ln.addr, raddr, ln.readQueueSize)
	select {
	case ln.cqueue <- c:
		ln.connPool.Set(raddr.String(), c)
		return c
	default:
		c.Close()
		ln.logger.Warnf("connection queue is full, client %s discarded", raddr)
		return nil
	}
}
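A usage sketch (not part of this commit): turning a UDP socket into a net.Listener and serving each client address as its own connection. The serveUDP helper, the backlog/queue/buffer sizes, and the 30-second TTL are hypothetical values.

// Hypothetical server loop: each distinct client address surfaces as one
// net.Conn from Accept; its datagrams are queued onto that connection.
func serveUDP(laddr string, log logger.Logger, handle func(net.Conn)) error {
	pc, err := net.ListenPacket("udp", laddr)
	if err != nil {
		return err
	}
	ln := NewListener(pc, pc.LocalAddr(), 128, 128, 4096, 30*time.Second, log)
	defer ln.Close()
	for {
		conn, err := ln.Accept()
		if err != nil {
			return err
		}
		go handle(conn)
	}
}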
100 common/util/udp/pool.go Normal file
@@ -0,0 +1,100 @@
package udp

import (
	"sync"
	"time"

	"github.com/go-gost/core/logger"
)

// ConnPool tracks virtual UDP connections by client address and reaps
// connections that stay idle for a full TTL interval.
type ConnPool struct {
	m      sync.Map
	ttl    time.Duration
	closed chan struct{}
	logger logger.Logger
}

func NewConnPool(ttl time.Duration) *ConnPool {
	p := &ConnPool{
		ttl:    ttl,
		closed: make(chan struct{}),
	}
	go p.idleCheck()
	return p
}

func (p *ConnPool) WithLogger(logger logger.Logger) *ConnPool {
	p.logger = logger
	return p
}

func (p *ConnPool) Get(key any) (c *Conn, ok bool) {
	v, ok := p.m.Load(key)
	if ok {
		c, ok = v.(*Conn)
	}
	return
}

func (p *ConnPool) Set(key any, c *Conn) {
	p.m.Store(key, c)
}

func (p *ConnPool) Delete(key any) {
	p.m.Delete(key)
}

func (p *ConnPool) Close() {
	select {
	case <-p.closed:
		return
	default:
	}

	close(p.closed)

	p.m.Range(func(k, v any) bool {
		if c, ok := v.(*Conn); ok && c != nil {
			c.Close()
		}
		return true
	})
}

// idleCheck runs every ttl tick: a connection already marked idle in the
// previous round is closed and removed; otherwise it is marked idle for the
// next round (delivering data clears the mark via SetIdle(false)).
func (p *ConnPool) idleCheck() {
	ticker := time.NewTicker(p.ttl)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			size := 0
			idles := 0
			p.m.Range(func(key, value any) bool {
				c, ok := value.(*Conn)
				if !ok || c == nil {
					p.Delete(key)
					return true
				}
				size++

				if c.IsIdle() {
					idles++
					p.Delete(key)
					c.Close()
					return true
				}

				c.SetIdle(true)

				return true
			})

			if idles > 0 {
				p.logger.Debugf("connection pool: size=%d, idle=%d", size, idles)
			}
		case <-p.closed:
			return
		}
	}
}