initial commit
This commit is contained in:
113
common/bufpool/pool.go
Normal file
113
common/bufpool/pool.go
Normal file
@ -0,0 +1,113 @@
|
||||
package bufpool
|
||||
|
||||
import "sync"
|
||||
|
||||
var (
|
||||
pools = []struct {
|
||||
size int
|
||||
pool sync.Pool
|
||||
}{
|
||||
{
|
||||
size: 128,
|
||||
pool: sync.Pool{
|
||||
New: func() any {
|
||||
b := make([]byte, 128)
|
||||
return &b
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
size: 512,
|
||||
pool: sync.Pool{
|
||||
New: func() any {
|
||||
b := make([]byte, 512)
|
||||
return &b
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
size: 1024,
|
||||
pool: sync.Pool{
|
||||
New: func() any {
|
||||
b := make([]byte, 1024)
|
||||
return &b
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
size: 4096,
|
||||
pool: sync.Pool{
|
||||
New: func() any {
|
||||
b := make([]byte, 4096)
|
||||
return &b
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
size: 8192,
|
||||
pool: sync.Pool{
|
||||
New: func() any {
|
||||
b := make([]byte, 8192)
|
||||
return &b
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
size: 16 * 1024,
|
||||
pool: sync.Pool{
|
||||
New: func() any {
|
||||
b := make([]byte, 16*1024)
|
||||
return &b
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
size: 32 * 1024,
|
||||
pool: sync.Pool{
|
||||
New: func() any {
|
||||
b := make([]byte, 32*1024)
|
||||
return &b
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
size: 64 * 1024,
|
||||
pool: sync.Pool{
|
||||
New: func() any {
|
||||
b := make([]byte, 64*1024)
|
||||
return &b
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
size: 65 * 1024,
|
||||
pool: sync.Pool{
|
||||
New: func() any {
|
||||
b := make([]byte, 65*1024)
|
||||
return &b
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// Get returns a buffer of specified size.
|
||||
func Get(size int) *[]byte {
|
||||
for i := range pools {
|
||||
if size <= pools[i].size {
|
||||
b := pools[i].pool.Get().(*[]byte)
|
||||
*b = (*b)[:size]
|
||||
return b
|
||||
}
|
||||
}
|
||||
b := make([]byte, size)
|
||||
return &b
|
||||
}
|
||||
|
||||
func Put(b *[]byte) {
|
||||
for i := range pools {
|
||||
if cap(*b) == pools[i].size {
|
||||
pools[i].pool.Put(b)
|
||||
}
|
||||
}
|
||||
}
|
99
common/matcher/matcher.go
Normal file
99
common/matcher/matcher.go
Normal file
@ -0,0 +1,99 @@
|
||||
package matcher
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/gobwas/glob"
|
||||
)
|
||||
|
||||
// Matcher is a generic pattern matcher,
|
||||
// it gives the match result of the given pattern for specific v.
|
||||
type Matcher interface {
|
||||
Match(v string) bool
|
||||
}
|
||||
|
||||
// NewMatcher creates a Matcher for the given pattern.
|
||||
// The acutal Matcher depends on the pattern:
|
||||
// IP Matcher if pattern is a valid IP address.
|
||||
// CIDR Matcher if pattern is a valid CIDR address.
|
||||
// Domain Matcher if both of the above are not.
|
||||
func NewMatcher(pattern string) Matcher {
|
||||
if pattern == "" {
|
||||
return nil
|
||||
}
|
||||
if ip := net.ParseIP(pattern); ip != nil {
|
||||
return IPMatcher(ip)
|
||||
}
|
||||
if _, inet, err := net.ParseCIDR(pattern); err == nil {
|
||||
return CIDRMatcher(inet)
|
||||
}
|
||||
return DomainMatcher(pattern)
|
||||
}
|
||||
|
||||
type ipMatcher struct {
|
||||
ip net.IP
|
||||
}
|
||||
|
||||
// IPMatcher creates a Matcher for a specific IP address.
|
||||
func IPMatcher(ip net.IP) Matcher {
|
||||
return &ipMatcher{
|
||||
ip: ip,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *ipMatcher) Match(ip string) bool {
|
||||
if m == nil {
|
||||
return false
|
||||
}
|
||||
return m.ip.Equal(net.ParseIP(ip))
|
||||
}
|
||||
|
||||
type cidrMatcher struct {
|
||||
ipNet *net.IPNet
|
||||
}
|
||||
|
||||
// CIDRMatcher creates a Matcher for a specific CIDR notation IP address.
|
||||
func CIDRMatcher(inet *net.IPNet) Matcher {
|
||||
return &cidrMatcher{
|
||||
ipNet: inet,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *cidrMatcher) Match(ip string) bool {
|
||||
if m == nil || m.ipNet == nil {
|
||||
return false
|
||||
}
|
||||
return m.ipNet.Contains(net.ParseIP(ip))
|
||||
}
|
||||
|
||||
type domainMatcher struct {
|
||||
pattern string
|
||||
glob glob.Glob
|
||||
}
|
||||
|
||||
// DomainMatcher creates a Matcher for a specific domain pattern,
|
||||
// the pattern can be a plain domain such as 'example.com',
|
||||
// a wildcard such as '*.exmaple.com' or a special wildcard '.example.com'.
|
||||
func DomainMatcher(pattern string) Matcher {
|
||||
p := pattern
|
||||
if strings.HasPrefix(pattern, ".") {
|
||||
p = pattern[1:] // trim the prefix '.'
|
||||
pattern = "*" + p
|
||||
}
|
||||
return &domainMatcher{
|
||||
pattern: p,
|
||||
glob: glob.MustCompile(pattern),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *domainMatcher) Match(domain string) bool {
|
||||
if m == nil || m.glob == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if domain == m.pattern {
|
||||
return true
|
||||
}
|
||||
return m.glob.Match(domain)
|
||||
}
|
144
common/net/dialer/dialer.go
Normal file
144
common/net/dialer/dialer.go
Normal file
@ -0,0 +1,144 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/core/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultNetDialer = &NetDialer{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
)
|
||||
|
||||
type NetDialer struct {
|
||||
Interface string
|
||||
Timeout time.Duration
|
||||
DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
Logger logger.Logger
|
||||
}
|
||||
|
||||
func (d *NetDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
if d == nil {
|
||||
d = DefaultNetDialer
|
||||
}
|
||||
log := d.Logger
|
||||
if log == nil {
|
||||
log = logger.Default()
|
||||
}
|
||||
|
||||
ifceName, ifAddr, err := parseInterfaceAddr(d.Interface, network)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if d.DialFunc != nil {
|
||||
return d.DialFunc(ctx, network, addr)
|
||||
}
|
||||
logger.Default().Infof("interface: %s %v/%s", ifceName, ifAddr, network)
|
||||
|
||||
switch network {
|
||||
case "udp", "udp4", "udp6":
|
||||
if addr == "" {
|
||||
var laddr *net.UDPAddr
|
||||
if ifAddr != nil {
|
||||
laddr, _ = ifAddr.(*net.UDPAddr)
|
||||
}
|
||||
|
||||
return net.ListenUDP(network, laddr)
|
||||
}
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
default:
|
||||
return nil, fmt.Errorf("dial: unsupported network %s", network)
|
||||
}
|
||||
netd := net.Dialer{
|
||||
Timeout: d.Timeout,
|
||||
LocalAddr: ifAddr,
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
var cerr error
|
||||
err := c.Control(func(fd uintptr) {
|
||||
cerr = bindDevice(fd, ifceName)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cerr != nil {
|
||||
return cerr
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
return netd.DialContext(ctx, network, addr)
|
||||
}
|
||||
|
||||
func parseInterfaceAddr(ifceName, network string) (ifce string, addr net.Addr, err error) {
|
||||
if ifceName == "" {
|
||||
return
|
||||
}
|
||||
|
||||
ip := net.ParseIP(ifceName)
|
||||
if ip == nil {
|
||||
var ife *net.Interface
|
||||
ife, err = net.InterfaceByName(ifceName)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var addrs []net.Addr
|
||||
addrs, err = ife.Addrs()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(addrs) == 0 {
|
||||
err = fmt.Errorf("addr not found for interface %s", ifceName)
|
||||
return
|
||||
}
|
||||
ip = addrs[0].(*net.IPNet).IP
|
||||
ifce = ifceName
|
||||
} else {
|
||||
ifce, err = findInterfaceByIP(ip)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
port := 0
|
||||
switch network {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
addr = &net.TCPAddr{IP: ip, Port: port}
|
||||
return
|
||||
case "udp", "udp4", "udp6":
|
||||
addr = &net.UDPAddr{IP: ip, Port: port}
|
||||
return
|
||||
default:
|
||||
addr = &net.IPAddr{IP: ip}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func findInterfaceByIP(ip net.IP) (string, error) {
|
||||
ifces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, ifce := range ifces {
|
||||
addrs, _ := ifce.Addrs()
|
||||
if len(addrs) == 0 {
|
||||
continue
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
ipAddr, _ := addr.(*net.IPNet)
|
||||
if ipAddr == nil {
|
||||
continue
|
||||
}
|
||||
// logger.Default().Infof("%s-%s", ipAddr, ip)
|
||||
if ipAddr.IP.Equal(ip) {
|
||||
return ifce.Name, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
14
common/net/dialer/dialer_linux.go
Normal file
14
common/net/dialer/dialer_linux.go
Normal file
@ -0,0 +1,14 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func bindDevice(fd uintptr, ifceName string) error {
|
||||
// unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1)
|
||||
// unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
|
||||
if ifceName == "" {
|
||||
return nil
|
||||
}
|
||||
return unix.BindToDevice(int(fd), ifceName)
|
||||
}
|
7
common/net/dialer/dialer_other.go
Normal file
7
common/net/dialer/dialer_other.go
Normal file
@ -0,0 +1,7 @@
|
||||
//go:build !linux
|
||||
|
||||
package dialer
|
||||
|
||||
func bindDevice(fd uintptr, ifceName string) error {
|
||||
return nil
|
||||
}
|
126
common/net/relay/relay.go
Normal file
126
common/net/relay/relay.go
Normal file
@ -0,0 +1,126 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/go-gost/core/bypass"
|
||||
"github.com/go-gost/core/common/bufpool"
|
||||
"github.com/go-gost/core/logger"
|
||||
)
|
||||
|
||||
type UDPRelay struct {
|
||||
pc1 net.PacketConn
|
||||
pc2 net.PacketConn
|
||||
|
||||
bypass bypass.Bypass
|
||||
bufferSize int
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewUDPRelay(pc1, pc2 net.PacketConn) *UDPRelay {
|
||||
return &UDPRelay{
|
||||
pc1: pc1,
|
||||
pc2: pc2,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *UDPRelay) WithBypass(bp bypass.Bypass) *UDPRelay {
|
||||
r.bypass = bp
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *UDPRelay) WithLogger(logger logger.Logger) *UDPRelay {
|
||||
r.logger = logger
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *UDPRelay) SetBufferSize(n int) {
|
||||
r.bufferSize = n
|
||||
}
|
||||
|
||||
func (r *UDPRelay) Run() (err error) {
|
||||
bufSize := r.bufferSize
|
||||
if bufSize <= 0 {
|
||||
bufSize = 1024
|
||||
}
|
||||
|
||||
errc := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
err := func() error {
|
||||
b := bufpool.Get(bufSize)
|
||||
defer bufpool.Put(b)
|
||||
|
||||
n, raddr, err := r.pc1.ReadFrom(*b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.bypass != nil && r.bypass.Contains(raddr.String()) {
|
||||
if r.logger != nil {
|
||||
r.logger.Warn("bypass: ", raddr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := r.pc2.WriteTo((*b)[:n], raddr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.logger != nil {
|
||||
r.logger.Debugf("%s >>> %s data: %d",
|
||||
r.pc2.LocalAddr(), raddr, n)
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
errc <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
err := func() error {
|
||||
b := bufpool.Get(bufSize)
|
||||
defer bufpool.Put(b)
|
||||
|
||||
n, raddr, err := r.pc2.ReadFrom(*b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.bypass != nil && r.bypass.Contains(raddr.String()) {
|
||||
if r.logger != nil {
|
||||
r.logger.Warn("bypass: ", raddr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := r.pc1.WriteTo((*b)[:n], raddr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.logger != nil {
|
||||
r.logger.Debugf("%s <<< %s data: %d",
|
||||
r.pc2.LocalAddr(), raddr, n)
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
errc <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return <-errc
|
||||
}
|
50
common/net/transport.go
Normal file
50
common/net/transport.go
Normal file
@ -0,0 +1,50 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/go-gost/core/common/bufpool"
|
||||
)
|
||||
|
||||
func Transport(rw1, rw2 io.ReadWriter) error {
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
errc <- copyBuffer(rw1, rw2)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
errc <- copyBuffer(rw2, rw1)
|
||||
}()
|
||||
|
||||
err := <-errc
|
||||
if err != nil && err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func copyBuffer(dst io.Writer, src io.Reader) error {
|
||||
buf := bufpool.Get(16 * 1024)
|
||||
defer bufpool.Put(buf)
|
||||
|
||||
_, err := io.CopyBuffer(dst, src, *buf)
|
||||
return err
|
||||
}
|
||||
|
||||
type bufferReaderConn struct {
|
||||
net.Conn
|
||||
br *bufio.Reader
|
||||
}
|
||||
|
||||
func NewBufferReaderConn(conn net.Conn, br *bufio.Reader) net.Conn {
|
||||
return &bufferReaderConn{
|
||||
Conn: conn,
|
||||
br: br,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *bufferReaderConn) Read(b []byte) (int, error) {
|
||||
return c.br.Read(b)
|
||||
}
|
41
common/net/udp.go
Normal file
41
common/net/udp.go
Normal file
@ -0,0 +1,41 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
type UDPConn interface {
|
||||
net.PacketConn
|
||||
io.Reader
|
||||
io.Writer
|
||||
readUDP
|
||||
writeUDP
|
||||
setBuffer
|
||||
syscallConn
|
||||
remoteAddr
|
||||
}
|
||||
|
||||
type setBuffer interface {
|
||||
SetReadBuffer(bytes int) error
|
||||
SetWriteBuffer(bytes int) error
|
||||
}
|
||||
|
||||
type readUDP interface {
|
||||
ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
|
||||
ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error)
|
||||
}
|
||||
|
||||
type writeUDP interface {
|
||||
WriteToUDP(b []byte, addr *net.UDPAddr) (int, error)
|
||||
WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error)
|
||||
}
|
||||
|
||||
type syscallConn interface {
|
||||
SyscallConn() (syscall.RawConn, error)
|
||||
}
|
||||
|
||||
type remoteAddr interface {
|
||||
RemoteAddr() net.Addr
|
||||
}
|
88
common/util/resolver/cache.go
Normal file
88
common/util/resolver/cache.go
Normal file
@ -0,0 +1,88 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/core/logger"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type CacheKey string
|
||||
|
||||
// NewCacheKey generates resolver cache key from question of dns query.
|
||||
func NewCacheKey(q *dns.Question) CacheKey {
|
||||
if q == nil {
|
||||
return ""
|
||||
}
|
||||
key := fmt.Sprintf("%s%s.%s", q.Name, dns.Class(q.Qclass).String(), dns.Type(q.Qtype).String())
|
||||
return CacheKey(key)
|
||||
}
|
||||
|
||||
type cacheItem struct {
|
||||
msg *dns.Msg
|
||||
ts time.Time
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
type Cache struct {
|
||||
m sync.Map
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewCache() *Cache {
|
||||
return &Cache{}
|
||||
}
|
||||
|
||||
func (c *Cache) WithLogger(logger logger.Logger) *Cache {
|
||||
c.logger = logger
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Cache) Load(key CacheKey) *dns.Msg {
|
||||
v, ok := c.m.Load(key)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
item, ok := v.(*cacheItem)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if time.Since(item.ts) > item.ttl {
|
||||
c.m.Delete(key)
|
||||
return nil
|
||||
}
|
||||
|
||||
c.logger.Debugf("hit resolver cache: %s", key)
|
||||
|
||||
return item.msg.Copy()
|
||||
}
|
||||
|
||||
func (c *Cache) Store(key CacheKey, mr *dns.Msg, ttl time.Duration) {
|
||||
if key == "" || mr == nil || ttl < 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if ttl == 0 {
|
||||
for _, answer := range mr.Answer {
|
||||
v := time.Duration(answer.Header().Ttl) * time.Second
|
||||
if ttl == 0 || ttl > v {
|
||||
ttl = v
|
||||
}
|
||||
}
|
||||
}
|
||||
if ttl == 0 {
|
||||
ttl = 30 * time.Second
|
||||
}
|
||||
|
||||
c.m.Store(key, &cacheItem{
|
||||
msg: mr.Copy(),
|
||||
ts: time.Now(),
|
||||
ttl: ttl,
|
||||
})
|
||||
|
||||
c.logger.Debugf("resolver cache store: %s, ttl: %v", key, ttl)
|
||||
}
|
30
common/util/resolver/resolver.go
Normal file
30
common/util/resolver/resolver.go
Normal file
@ -0,0 +1,30 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func AddSubnetOpt(m *dns.Msg, ip net.IP) {
|
||||
if m == nil || ip == nil {
|
||||
return
|
||||
}
|
||||
|
||||
opt := new(dns.OPT)
|
||||
opt.Hdr.Name = "."
|
||||
opt.Hdr.Rrtype = dns.TypeOPT
|
||||
e := new(dns.EDNS0_SUBNET)
|
||||
e.Code = dns.EDNS0SUBNET
|
||||
if ip := ip.To4(); ip != nil {
|
||||
e.Family = 1
|
||||
e.SourceNetmask = 24
|
||||
e.Address = ip
|
||||
} else {
|
||||
e.Family = 2
|
||||
e.SourceNetmask = 128
|
||||
e.Address = ip.To16()
|
||||
}
|
||||
opt.Option = append(opt.Option, e)
|
||||
m.Extra = append(m.Extra, opt)
|
||||
}
|
178
common/util/tls/tls.go
Normal file
178
common/util/tls/tls.go
Normal file
@ -0,0 +1,178 @@
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultConfig is a default TLS config for global use.
|
||||
DefaultConfig *tls.Config
|
||||
)
|
||||
|
||||
// LoadServerConfig loads the certificate from cert & key files and optional client CA file.
|
||||
func LoadServerConfig(certFile, keyFile, caFile string) (*tls.Config, error) {
|
||||
if certFile == "" && keyFile == "" {
|
||||
return DefaultConfig.Clone(), nil
|
||||
}
|
||||
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg := &tls.Config{Certificates: []tls.Certificate{cert}}
|
||||
|
||||
pool, err := loadCA(caFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if pool != nil {
|
||||
cfg.ClientCAs = pool
|
||||
cfg.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// LoadClientConfig loads the certificate from cert & key files and optional CA file.
|
||||
func LoadClientConfig(certFile, keyFile, caFile string, verify bool, serverName string) (*tls.Config, error) {
|
||||
var cfg *tls.Config
|
||||
|
||||
if certFile == "" && keyFile == "" {
|
||||
cfg = &tls.Config{}
|
||||
} else {
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg = &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
}
|
||||
|
||||
rootCAs, err := loadCA(caFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg.RootCAs = rootCAs
|
||||
cfg.ServerName = serverName
|
||||
cfg.InsecureSkipVerify = !verify
|
||||
|
||||
// If the root ca is given, but skip verify, we verify the certificate manually.
|
||||
if cfg.RootCAs != nil && !verify {
|
||||
cfg.VerifyConnection = func(state tls.ConnectionState) error {
|
||||
opts := x509.VerifyOptions{
|
||||
Roots: cfg.RootCAs,
|
||||
CurrentTime: time.Now(),
|
||||
DNSName: "",
|
||||
Intermediates: x509.NewCertPool(),
|
||||
}
|
||||
|
||||
certs := state.PeerCertificates
|
||||
for i, cert := range certs {
|
||||
if i == 0 {
|
||||
continue
|
||||
}
|
||||
opts.Intermediates.AddCert(cert)
|
||||
}
|
||||
|
||||
_, err := certs[0].Verify(opts)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func loadCA(caFile string) (cp *x509.CertPool, err error) {
|
||||
if caFile == "" {
|
||||
return
|
||||
}
|
||||
cp = x509.NewCertPool()
|
||||
data, err := ioutil.ReadFile(caFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !cp.AppendCertsFromPEM(data) {
|
||||
return nil, errors.New("AppendCertsFromPEM failed")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Wrap a net.Conn into a client tls connection, performing any
|
||||
// additional verification as needed.
|
||||
//
|
||||
// As of go 1.3, crypto/tls only supports either doing no certificate
|
||||
// verification, or doing full verification including of the peer's
|
||||
// DNS name. For consul, we want to validate that the certificate is
|
||||
// signed by a known CA, but because consul doesn't use DNS names for
|
||||
// node names, we don't verify the certificate DNS names. Since go 1.3
|
||||
// no longer supports this mode of operation, we have to do it
|
||||
// manually.
|
||||
//
|
||||
// This code is taken from consul:
|
||||
// https://github.com/hashicorp/consul/blob/master/tlsutil/config.go
|
||||
func WrapTLSClient(conn net.Conn, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) {
|
||||
var err error
|
||||
var tlsConn *tls.Conn
|
||||
|
||||
if timeout > 0 {
|
||||
conn.SetDeadline(time.Now().Add(timeout))
|
||||
defer conn.SetDeadline(time.Time{})
|
||||
}
|
||||
|
||||
tlsConn = tls.Client(conn, tlsConfig)
|
||||
|
||||
// Otherwise perform handshake, but don't verify the domain
|
||||
//
|
||||
// The following is lightly-modified from the doFullHandshake
|
||||
// method in https://golang.org/src/crypto/tls/handshake_client.go
|
||||
if err = tlsConn.Handshake(); err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// We can do this in `tls.Config.VerifyConnection`, which effective for
|
||||
// other TLS protocols such as WebSocket. See `route.go:parseChainNode`
|
||||
/*
|
||||
// If crypto/tls is doing verification, there's no need to do our own.
|
||||
if tlsConfig.InsecureSkipVerify == false {
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
// Similarly if we use host's CA, we can do full handshake
|
||||
if tlsConfig.RootCAs == nil {
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
opts := x509.VerifyOptions{
|
||||
Roots: tlsConfig.RootCAs,
|
||||
CurrentTime: time.Now(),
|
||||
DNSName: "",
|
||||
Intermediates: x509.NewCertPool(),
|
||||
}
|
||||
|
||||
certs := tlsConn.ConnectionState().PeerCertificates
|
||||
for i, cert := range certs {
|
||||
if i == 0 {
|
||||
continue
|
||||
}
|
||||
opts.Intermediates.AddCert(cert)
|
||||
}
|
||||
|
||||
_, err = certs[0].Verify(opts)
|
||||
if err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
*/
|
||||
|
||||
return tlsConn, err
|
||||
}
|
102
common/util/udp/conn.go
Normal file
102
common/util/udp/conn.go
Normal file
@ -0,0 +1,102 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/go-gost/core/common/bufpool"
|
||||
)
|
||||
|
||||
// Conn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
|
||||
type Conn struct {
|
||||
net.PacketConn
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
rc chan []byte // data receive queue
|
||||
idle int32 // indicate the connection is idle
|
||||
closed chan struct{}
|
||||
closeMutex sync.Mutex
|
||||
}
|
||||
|
||||
func NewConn(c net.PacketConn, localAddr, remoteAddr net.Addr, queueSize int) *Conn {
|
||||
return &Conn{
|
||||
PacketConn: c,
|
||||
localAddr: localAddr,
|
||||
remoteAddr: remoteAddr,
|
||||
rc: make(chan []byte, queueSize),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||
select {
|
||||
case bb := <-c.rc:
|
||||
n = copy(b, bb)
|
||||
c.SetIdle(false)
|
||||
bufpool.Put(&bb)
|
||||
|
||||
case <-c.closed:
|
||||
err = net.ErrClosed
|
||||
return
|
||||
}
|
||||
|
||||
addr = c.remoteAddr
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
n, _, err = c.ReadFrom(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
return c.WriteTo(b, c.remoteAddr)
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
c.closeMutex.Lock()
|
||||
defer c.closeMutex.Unlock()
|
||||
|
||||
select {
|
||||
case <-c.closed:
|
||||
default:
|
||||
close(c.closed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.localAddr
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *Conn) IsIdle() bool {
|
||||
return atomic.LoadInt32(&c.idle) > 0
|
||||
}
|
||||
|
||||
func (c *Conn) SetIdle(idle bool) {
|
||||
v := int32(0)
|
||||
if idle {
|
||||
v = 1
|
||||
}
|
||||
atomic.StoreInt32(&c.idle, v)
|
||||
}
|
||||
|
||||
func (c *Conn) WriteQueue(b []byte) error {
|
||||
select {
|
||||
case c.rc <- b:
|
||||
return nil
|
||||
|
||||
case <-c.closed:
|
||||
return net.ErrClosed
|
||||
|
||||
default:
|
||||
return errors.New("recv queue is full")
|
||||
}
|
||||
}
|
120
common/util/udp/listener.go
Normal file
120
common/util/udp/listener.go
Normal file
@ -0,0 +1,120 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/core/common/bufpool"
|
||||
"github.com/go-gost/core/logger"
|
||||
)
|
||||
|
||||
type listener struct {
|
||||
addr net.Addr
|
||||
conn net.PacketConn
|
||||
cqueue chan net.Conn
|
||||
readQueueSize int
|
||||
readBufferSize int
|
||||
connPool *ConnPool
|
||||
mux sync.Mutex
|
||||
closed chan struct{}
|
||||
errChan chan error
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewListener(conn net.PacketConn, addr net.Addr, backlog, dataQueueSize, dataBufferSize int, ttl time.Duration, logger logger.Logger) net.Listener {
|
||||
ln := &listener{
|
||||
conn: conn,
|
||||
addr: addr,
|
||||
cqueue: make(chan net.Conn, backlog),
|
||||
connPool: NewConnPool(ttl).WithLogger(logger),
|
||||
readQueueSize: dataQueueSize,
|
||||
readBufferSize: dataBufferSize,
|
||||
closed: make(chan struct{}),
|
||||
errChan: make(chan error, 1),
|
||||
logger: logger,
|
||||
}
|
||||
go ln.listenLoop()
|
||||
|
||||
return ln
|
||||
}
|
||||
|
||||
func (ln *listener) Accept() (conn net.Conn, err error) {
|
||||
select {
|
||||
case conn = <-ln.cqueue:
|
||||
return
|
||||
case <-ln.closed:
|
||||
return nil, net.ErrClosed
|
||||
case err = <-ln.errChan:
|
||||
if err == nil {
|
||||
err = net.ErrClosed
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (ln *listener) listenLoop() {
|
||||
for {
|
||||
select {
|
||||
case <-ln.closed:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
b := bufpool.Get(ln.readBufferSize)
|
||||
|
||||
n, raddr, err := ln.conn.ReadFrom(*b)
|
||||
if err != nil {
|
||||
ln.errChan <- err
|
||||
close(ln.errChan)
|
||||
return
|
||||
}
|
||||
|
||||
c := ln.getConn(raddr)
|
||||
if c == nil {
|
||||
bufpool.Put(b)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := c.WriteQueue((*b)[:n]); err != nil {
|
||||
ln.logger.Warn("data discarded: ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ln *listener) Addr() net.Addr {
|
||||
return ln.addr
|
||||
}
|
||||
|
||||
func (ln *listener) Close() error {
|
||||
select {
|
||||
case <-ln.closed:
|
||||
default:
|
||||
close(ln.closed)
|
||||
ln.conn.Close()
|
||||
ln.connPool.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ln *listener) getConn(raddr net.Addr) *Conn {
|
||||
ln.mux.Lock()
|
||||
defer ln.mux.Unlock()
|
||||
|
||||
c, ok := ln.connPool.Get(raddr.String())
|
||||
if ok {
|
||||
return c
|
||||
}
|
||||
|
||||
c = NewConn(ln.conn, ln.addr, raddr, ln.readQueueSize)
|
||||
select {
|
||||
case ln.cqueue <- c:
|
||||
ln.connPool.Set(raddr.String(), c)
|
||||
return c
|
||||
default:
|
||||
c.Close()
|
||||
ln.logger.Warnf("connection queue is full, client %s discarded", raddr)
|
||||
return nil
|
||||
}
|
||||
}
|
100
common/util/udp/pool.go
Normal file
100
common/util/udp/pool.go
Normal file
@ -0,0 +1,100 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/core/logger"
|
||||
)
|
||||
|
||||
type ConnPool struct {
|
||||
m sync.Map
|
||||
ttl time.Duration
|
||||
closed chan struct{}
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewConnPool(ttl time.Duration) *ConnPool {
|
||||
p := &ConnPool{
|
||||
ttl: ttl,
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
go p.idleCheck()
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *ConnPool) WithLogger(logger logger.Logger) *ConnPool {
|
||||
p.logger = logger
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *ConnPool) Get(key any) (c *Conn, ok bool) {
|
||||
v, ok := p.m.Load(key)
|
||||
if ok {
|
||||
c, ok = v.(*Conn)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (p *ConnPool) Set(key any, c *Conn) {
|
||||
p.m.Store(key, c)
|
||||
}
|
||||
|
||||
func (p *ConnPool) Delete(key any) {
|
||||
p.m.Delete(key)
|
||||
}
|
||||
|
||||
func (p *ConnPool) Close() {
|
||||
select {
|
||||
case <-p.closed:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
close(p.closed)
|
||||
|
||||
p.m.Range(func(k, v any) bool {
|
||||
if c, ok := v.(*Conn); ok && c != nil {
|
||||
c.Close()
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (p *ConnPool) idleCheck() {
|
||||
ticker := time.NewTicker(p.ttl)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
size := 0
|
||||
idles := 0
|
||||
p.m.Range(func(key, value any) bool {
|
||||
c, ok := value.(*Conn)
|
||||
if !ok || c == nil {
|
||||
p.Delete(key)
|
||||
return true
|
||||
}
|
||||
size++
|
||||
|
||||
if c.IsIdle() {
|
||||
idles++
|
||||
p.Delete(key)
|
||||
c.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
c.SetIdle(true)
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
if idles > 0 {
|
||||
p.logger.Debugf("connection pool: size=%d, idle=%d", size, idles)
|
||||
}
|
||||
case <-p.closed:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user