initial commit

This commit is contained in:
ginuerzh
2022-03-16 19:40:29 +08:00
commit 7db81fcfeb
109 changed files with 8782 additions and 0 deletions

View File

@ -0,0 +1,88 @@
package resolver
import (
"fmt"
"sync"
"time"
"github.com/go-gost/core/logger"
"github.com/miekg/dns"
)
type CacheKey string
// NewCacheKey generates resolver cache key from question of dns query.
func NewCacheKey(q *dns.Question) CacheKey {
if q == nil {
return ""
}
key := fmt.Sprintf("%s%s.%s", q.Name, dns.Class(q.Qclass).String(), dns.Type(q.Qtype).String())
return CacheKey(key)
}
type cacheItem struct {
msg *dns.Msg
ts time.Time
ttl time.Duration
}
type Cache struct {
m sync.Map
logger logger.Logger
}
func NewCache() *Cache {
return &Cache{}
}
func (c *Cache) WithLogger(logger logger.Logger) *Cache {
c.logger = logger
return c
}
func (c *Cache) Load(key CacheKey) *dns.Msg {
v, ok := c.m.Load(key)
if !ok {
return nil
}
item, ok := v.(*cacheItem)
if !ok {
return nil
}
if time.Since(item.ts) > item.ttl {
c.m.Delete(key)
return nil
}
c.logger.Debugf("hit resolver cache: %s", key)
return item.msg.Copy()
}
func (c *Cache) Store(key CacheKey, mr *dns.Msg, ttl time.Duration) {
if key == "" || mr == nil || ttl < 0 {
return
}
if ttl == 0 {
for _, answer := range mr.Answer {
v := time.Duration(answer.Header().Ttl) * time.Second
if ttl == 0 || ttl > v {
ttl = v
}
}
}
if ttl == 0 {
ttl = 30 * time.Second
}
c.m.Store(key, &cacheItem{
msg: mr.Copy(),
ts: time.Now(),
ttl: ttl,
})
c.logger.Debugf("resolver cache store: %s, ttl: %v", key, ttl)
}

View File

@ -0,0 +1,30 @@
package resolver
import (
"net"
"github.com/miekg/dns"
)
func AddSubnetOpt(m *dns.Msg, ip net.IP) {
if m == nil || ip == nil {
return
}
opt := new(dns.OPT)
opt.Hdr.Name = "."
opt.Hdr.Rrtype = dns.TypeOPT
e := new(dns.EDNS0_SUBNET)
e.Code = dns.EDNS0SUBNET
if ip := ip.To4(); ip != nil {
e.Family = 1
e.SourceNetmask = 24
e.Address = ip
} else {
e.Family = 2
e.SourceNetmask = 128
e.Address = ip.To16()
}
opt.Option = append(opt.Option, e)
m.Extra = append(m.Extra, opt)
}

178
common/util/tls/tls.go Normal file
View File

@ -0,0 +1,178 @@
package tls
import (
"crypto/tls"
"crypto/x509"
"errors"
"io/ioutil"
"net"
"time"
)
var (
// DefaultConfig is a default TLS config for global use.
DefaultConfig *tls.Config
)
// LoadServerConfig loads the certificate from cert & key files and optional client CA file.
func LoadServerConfig(certFile, keyFile, caFile string) (*tls.Config, error) {
if certFile == "" && keyFile == "" {
return DefaultConfig.Clone(), nil
}
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
cfg := &tls.Config{Certificates: []tls.Certificate{cert}}
pool, err := loadCA(caFile)
if err != nil {
return nil, err
}
if pool != nil {
cfg.ClientCAs = pool
cfg.ClientAuth = tls.RequireAndVerifyClientCert
}
return cfg, nil
}
// LoadClientConfig loads the certificate from cert & key files and optional CA file.
func LoadClientConfig(certFile, keyFile, caFile string, verify bool, serverName string) (*tls.Config, error) {
var cfg *tls.Config
if certFile == "" && keyFile == "" {
cfg = &tls.Config{}
} else {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
cfg = &tls.Config{
Certificates: []tls.Certificate{cert},
}
}
rootCAs, err := loadCA(caFile)
if err != nil {
return nil, err
}
cfg.RootCAs = rootCAs
cfg.ServerName = serverName
cfg.InsecureSkipVerify = !verify
// If the root ca is given, but skip verify, we verify the certificate manually.
if cfg.RootCAs != nil && !verify {
cfg.VerifyConnection = func(state tls.ConnectionState) error {
opts := x509.VerifyOptions{
Roots: cfg.RootCAs,
CurrentTime: time.Now(),
DNSName: "",
Intermediates: x509.NewCertPool(),
}
certs := state.PeerCertificates
for i, cert := range certs {
if i == 0 {
continue
}
opts.Intermediates.AddCert(cert)
}
_, err := certs[0].Verify(opts)
return err
}
}
return cfg, nil
}
func loadCA(caFile string) (cp *x509.CertPool, err error) {
if caFile == "" {
return
}
cp = x509.NewCertPool()
data, err := ioutil.ReadFile(caFile)
if err != nil {
return nil, err
}
if !cp.AppendCertsFromPEM(data) {
return nil, errors.New("AppendCertsFromPEM failed")
}
return
}
// Wrap a net.Conn into a client tls connection, performing any
// additional verification as needed.
//
// As of go 1.3, crypto/tls only supports either doing no certificate
// verification, or doing full verification including of the peer's
// DNS name. For consul, we want to validate that the certificate is
// signed by a known CA, but because consul doesn't use DNS names for
// node names, we don't verify the certificate DNS names. Since go 1.3
// no longer supports this mode of operation, we have to do it
// manually.
//
// This code is taken from consul:
// https://github.com/hashicorp/consul/blob/master/tlsutil/config.go
func WrapTLSClient(conn net.Conn, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) {
var err error
var tlsConn *tls.Conn
if timeout > 0 {
conn.SetDeadline(time.Now().Add(timeout))
defer conn.SetDeadline(time.Time{})
}
tlsConn = tls.Client(conn, tlsConfig)
// Otherwise perform handshake, but don't verify the domain
//
// The following is lightly-modified from the doFullHandshake
// method in https://golang.org/src/crypto/tls/handshake_client.go
if err = tlsConn.Handshake(); err != nil {
tlsConn.Close()
return nil, err
}
// We can do this in `tls.Config.VerifyConnection`, which effective for
// other TLS protocols such as WebSocket. See `route.go:parseChainNode`
/*
// If crypto/tls is doing verification, there's no need to do our own.
if tlsConfig.InsecureSkipVerify == false {
return tlsConn, nil
}
// Similarly if we use host's CA, we can do full handshake
if tlsConfig.RootCAs == nil {
return tlsConn, nil
}
opts := x509.VerifyOptions{
Roots: tlsConfig.RootCAs,
CurrentTime: time.Now(),
DNSName: "",
Intermediates: x509.NewCertPool(),
}
certs := tlsConn.ConnectionState().PeerCertificates
for i, cert := range certs {
if i == 0 {
continue
}
opts.Intermediates.AddCert(cert)
}
_, err = certs[0].Verify(opts)
if err != nil {
tlsConn.Close()
return nil, err
}
*/
return tlsConn, err
}

102
common/util/udp/conn.go Normal file
View File

@ -0,0 +1,102 @@
package udp
import (
"errors"
"net"
"sync"
"sync/atomic"
"github.com/go-gost/core/common/bufpool"
)
// Conn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
type Conn struct {
net.PacketConn
localAddr net.Addr
remoteAddr net.Addr
rc chan []byte // data receive queue
idle int32 // indicate the connection is idle
closed chan struct{}
closeMutex sync.Mutex
}
func NewConn(c net.PacketConn, localAddr, remoteAddr net.Addr, queueSize int) *Conn {
return &Conn{
PacketConn: c,
localAddr: localAddr,
remoteAddr: remoteAddr,
rc: make(chan []byte, queueSize),
closed: make(chan struct{}),
}
}
func (c *Conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
select {
case bb := <-c.rc:
n = copy(b, bb)
c.SetIdle(false)
bufpool.Put(&bb)
case <-c.closed:
err = net.ErrClosed
return
}
addr = c.remoteAddr
return
}
func (c *Conn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *Conn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.remoteAddr)
}
func (c *Conn) Close() error {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()
select {
case <-c.closed:
default:
close(c.closed)
}
return nil
}
func (c *Conn) LocalAddr() net.Addr {
return c.localAddr
}
func (c *Conn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *Conn) IsIdle() bool {
return atomic.LoadInt32(&c.idle) > 0
}
func (c *Conn) SetIdle(idle bool) {
v := int32(0)
if idle {
v = 1
}
atomic.StoreInt32(&c.idle, v)
}
func (c *Conn) WriteQueue(b []byte) error {
select {
case c.rc <- b:
return nil
case <-c.closed:
return net.ErrClosed
default:
return errors.New("recv queue is full")
}
}

120
common/util/udp/listener.go Normal file
View File

@ -0,0 +1,120 @@
package udp
import (
"net"
"sync"
"time"
"github.com/go-gost/core/common/bufpool"
"github.com/go-gost/core/logger"
)
type listener struct {
addr net.Addr
conn net.PacketConn
cqueue chan net.Conn
readQueueSize int
readBufferSize int
connPool *ConnPool
mux sync.Mutex
closed chan struct{}
errChan chan error
logger logger.Logger
}
func NewListener(conn net.PacketConn, addr net.Addr, backlog, dataQueueSize, dataBufferSize int, ttl time.Duration, logger logger.Logger) net.Listener {
ln := &listener{
conn: conn,
addr: addr,
cqueue: make(chan net.Conn, backlog),
connPool: NewConnPool(ttl).WithLogger(logger),
readQueueSize: dataQueueSize,
readBufferSize: dataBufferSize,
closed: make(chan struct{}),
errChan: make(chan error, 1),
logger: logger,
}
go ln.listenLoop()
return ln
}
func (ln *listener) Accept() (conn net.Conn, err error) {
select {
case conn = <-ln.cqueue:
return
case <-ln.closed:
return nil, net.ErrClosed
case err = <-ln.errChan:
if err == nil {
err = net.ErrClosed
}
return
}
}
func (ln *listener) listenLoop() {
for {
select {
case <-ln.closed:
return
default:
}
b := bufpool.Get(ln.readBufferSize)
n, raddr, err := ln.conn.ReadFrom(*b)
if err != nil {
ln.errChan <- err
close(ln.errChan)
return
}
c := ln.getConn(raddr)
if c == nil {
bufpool.Put(b)
continue
}
if err := c.WriteQueue((*b)[:n]); err != nil {
ln.logger.Warn("data discarded: ", err)
}
}
}
func (ln *listener) Addr() net.Addr {
return ln.addr
}
func (ln *listener) Close() error {
select {
case <-ln.closed:
default:
close(ln.closed)
ln.conn.Close()
ln.connPool.Close()
}
return nil
}
func (ln *listener) getConn(raddr net.Addr) *Conn {
ln.mux.Lock()
defer ln.mux.Unlock()
c, ok := ln.connPool.Get(raddr.String())
if ok {
return c
}
c = NewConn(ln.conn, ln.addr, raddr, ln.readQueueSize)
select {
case ln.cqueue <- c:
ln.connPool.Set(raddr.String(), c)
return c
default:
c.Close()
ln.logger.Warnf("connection queue is full, client %s discarded", raddr)
return nil
}
}

100
common/util/udp/pool.go Normal file
View File

@ -0,0 +1,100 @@
package udp
import (
"sync"
"time"
"github.com/go-gost/core/logger"
)
type ConnPool struct {
m sync.Map
ttl time.Duration
closed chan struct{}
logger logger.Logger
}
func NewConnPool(ttl time.Duration) *ConnPool {
p := &ConnPool{
ttl: ttl,
closed: make(chan struct{}),
}
go p.idleCheck()
return p
}
func (p *ConnPool) WithLogger(logger logger.Logger) *ConnPool {
p.logger = logger
return p
}
func (p *ConnPool) Get(key any) (c *Conn, ok bool) {
v, ok := p.m.Load(key)
if ok {
c, ok = v.(*Conn)
}
return
}
func (p *ConnPool) Set(key any, c *Conn) {
p.m.Store(key, c)
}
func (p *ConnPool) Delete(key any) {
p.m.Delete(key)
}
func (p *ConnPool) Close() {
select {
case <-p.closed:
return
default:
}
close(p.closed)
p.m.Range(func(k, v any) bool {
if c, ok := v.(*Conn); ok && c != nil {
c.Close()
}
return true
})
}
func (p *ConnPool) idleCheck() {
ticker := time.NewTicker(p.ttl)
defer ticker.Stop()
for {
select {
case <-ticker.C:
size := 0
idles := 0
p.m.Range(func(key, value any) bool {
c, ok := value.(*Conn)
if !ok || c == nil {
p.Delete(key)
return true
}
size++
if c.IsIdle() {
idles++
p.Delete(key)
c.Close()
return true
}
c.SetIdle(true)
return true
})
if idles > 0 {
p.logger.Debugf("connection pool: size=%d, idle=%d", size, idles)
}
case <-p.closed:
return
}
}
}