add chain

This commit is contained in:
ginuerzh
2021-10-26 21:07:46 +08:00
parent ce13b2a82a
commit 3351aa5974
78 changed files with 917 additions and 185 deletions

View File

@ -0,0 +1,115 @@
package ftcp
import (
"errors"
"net"
"sync"
"sync/atomic"
"time"
)
// serverConn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
type serverConn struct {
net.PacketConn
raddr net.Addr
rc chan []byte // data receive queue
fresh int32
closed chan struct{}
closeMutex sync.Mutex
config *serverConnConfig
}
type serverConnConfig struct {
ttl time.Duration
qsize int
onClose func()
}
func newServerConn(conn net.PacketConn, raddr net.Addr, cfg *serverConnConfig) *serverConn {
if conn == nil || raddr == nil {
return nil
}
if cfg == nil {
cfg = &serverConnConfig{}
}
c := &serverConn{
PacketConn: conn,
raddr: raddr,
rc: make(chan []byte, cfg.qsize),
closed: make(chan struct{}),
config: cfg,
}
go c.ttlWait()
return c
}
func (c *serverConn) send(b []byte) error {
select {
case c.rc <- b:
return nil
default:
return errors.New("queue is full")
}
}
func (c *serverConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *serverConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
select {
case bb := <-c.rc:
n = copy(b, bb)
atomic.StoreInt32(&c.fresh, 1)
case <-c.closed:
err = errors.New("read from closed connection")
return
}
addr = c.raddr
return
}
func (c *serverConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.raddr)
}
func (c *serverConn) Close() error {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()
select {
case <-c.closed:
return errors.New("connection is closed")
default:
if c.config.onClose != nil {
c.config.onClose()
}
close(c.closed)
}
return nil
}
func (c *serverConn) RemoteAddr() net.Addr {
return c.raddr
}
func (c *serverConn) ttlWait() {
ticker := time.NewTicker(c.config.ttl)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if !atomic.CompareAndSwapInt32(&c.fresh, 1, 0) {
c.Close()
return
}
case <-c.closed:
return
}
}
}

View File

@ -0,0 +1,162 @@
package ftcp
import (
"errors"
"net"
"sync"
"sync/atomic"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
"github.com/xtaci/tcpraw"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
conn net.PacketConn
connChan chan net.Conn
errChan chan error
connPool connPool
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
l.conn, err = tcpraw.Listen("tcp", addr)
if err != nil {
return
}
l.connChan = make(chan net.Conn, l.md.connQueueSize)
l.errChan = make(chan error, 1)
go l.listenLoop()
return
}
func (l *Listener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *Listener) Close() error {
err := l.conn.Close()
l.connPool.Range(func(k interface{}, v *serverConn) bool {
v.Close()
return true
})
return err
}
func (l *Listener) Addr() net.Addr {
return l.conn.LocalAddr()
}
func (l *Listener) listenLoop() {
for {
b := make([]byte, l.md.readBufferSize)
n, raddr, err := l.conn.ReadFrom(b)
if err != nil {
l.logger.Error("accept:", err)
l.errChan <- err
close(l.errChan)
return
}
conn, ok := l.connPool.Get(raddr.String())
if !ok {
conn = newServerConn(l.conn, raddr,
&serverConnConfig{
ttl: l.md.ttl,
qsize: l.md.readQueueSize,
onClose: func() {
l.connPool.Delete(raddr.String())
},
})
select {
case l.connChan <- conn:
l.connPool.Set(raddr.String(), conn)
default:
conn.Close()
l.logger.Error("connection queue is full")
}
}
if err := conn.send(b[:n]); err != nil {
l.logger.Warn("data discarded:", err)
}
l.logger.Debug("recv", n)
}
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
return
}
type connPool struct {
size int64
m sync.Map
}
func (p *connPool) Get(key interface{}) (conn *serverConn, ok bool) {
v, ok := p.m.Load(key)
if ok {
conn, ok = v.(*serverConn)
}
return
}
func (p *connPool) Set(key interface{}, conn *serverConn) {
p.m.Store(key, conn)
atomic.AddInt64(&p.size, 1)
}
func (p *connPool) Delete(key interface{}) {
p.m.Delete(key)
atomic.AddInt64(&p.size, -1)
}
func (p *connPool) Range(f func(key interface{}, value *serverConn) bool) {
p.m.Range(func(k, v interface{}) bool {
return f(k, v.(*serverConn))
})
}
func (p *connPool) Size() int64 {
return atomic.LoadInt64(&p.size)
}

View File

@ -0,0 +1,23 @@
package ftcp
import "time"
const (
defaultTTL = 60 * time.Second
defaultReadBufferSize = 1024
defaultReadQueueSize = 128
defaultConnQueueSize = 128
)
const (
addr = "addr"
)
type metadata struct {
addr string
ttl time.Duration
readBufferSize int
readQueueSize int
connQueueSize int
}

View File

@ -0,0 +1,54 @@
package http2
import (
"errors"
"net"
"net/http"
"time"
)
// a dummy HTTP2 server conn used by HTTP2 handler
type conn struct {
r *http.Request
w http.ResponseWriter
closed chan struct{}
}
func (c *conn) Read(b []byte) (n int, err error) {
return 0, &net.OpError{Op: "read", Net: "http2", Source: nil, Addr: nil, Err: errors.New("read not supported")}
}
func (c *conn) Write(b []byte) (n int, err error) {
return 0, &net.OpError{Op: "write", Net: "http2", Source: nil, Addr: nil, Err: errors.New("write not supported")}
}
func (c *conn) Close() error {
select {
case <-c.closed:
default:
close(c.closed)
}
return nil
}
func (c *conn) LocalAddr() net.Addr {
addr, _ := net.ResolveTCPAddr("tcp", c.r.Host)
return addr
}
func (c *conn) RemoteAddr() net.Addr {
addr, _ := net.ResolveTCPAddr("tcp", c.r.RemoteAddr)
return addr
}
func (c *conn) SetDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *conn) SetReadDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *conn) SetWriteDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}

View File

@ -0,0 +1,89 @@
package h2
import (
"errors"
"io"
"net"
"net/http"
"time"
)
// HTTP2 connection, wrapped up just like a net.Conn
type conn struct {
r io.Reader
w io.Writer
remoteAddr net.Addr
localAddr net.Addr
closed chan struct{}
}
func (c *conn) Read(b []byte) (n int, err error) {
return c.r.Read(b)
}
func (c *conn) Write(b []byte) (n int, err error) {
return c.w.Write(b)
}
func (c *conn) Close() (err error) {
select {
case <-c.closed:
return
default:
close(c.closed)
}
if rc, ok := c.r.(io.Closer); ok {
err = rc.Close()
}
if w, ok := c.w.(io.Closer); ok {
err = w.Close()
}
return
}
func (c *conn) LocalAddr() net.Addr {
return c.localAddr
}
func (c *conn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *conn) SetDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *conn) SetReadDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *conn) SetWriteDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
type flushWriter struct {
w io.Writer
}
func (fw flushWriter) Write(p []byte) (n int, err error) {
defer func() {
if r := recover(); r != nil {
if s, ok := r.(string); ok {
err = errors.New(s)
// log.Log("[http2]", err)
return
}
err = r.(error)
}
}()
n, err = fw.w.Write(p)
if err != nil {
// log.Log("flush writer:", err)
return
}
if f, ok := fw.w.(http.Flusher); ok {
f.Flush()
}
return
}

View File

@ -0,0 +1,186 @@
package h2
import (
"crypto/tls"
"errors"
"net"
"net/http"
"time"
"github.com/go-gost/gost/pkg/components/internal/utils"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
"golang.org/x/net/http2"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
net.Listener
md metadata
server *http2.Server
connChan chan net.Conn
errChan chan error
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
ln, err := net.Listen("tcp", l.md.addr)
if err != nil {
return
}
l.Listener = &utils.TCPKeepAliveListener{
TCPListener: ln.(*net.TCPListener),
KeepAlivePeriod: l.md.keepAlivePeriod,
}
// TODO: tune http2 server config
l.server = &http2.Server{
// MaxConcurrentStreams: 1000,
PermitProhibitedCipherSuites: true,
IdleTimeout: 5 * time.Minute,
}
queueSize := l.md.connQueueSize
if queueSize <= 0 {
queueSize = defaultQueueSize
}
l.connChan = make(chan net.Conn, queueSize)
l.errChan = make(chan error, 1)
go l.listenLoop()
return
}
func (l *Listener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *Listener) listenLoop() {
for {
conn, err := l.Listener.Accept()
if err != nil {
// log.Log("[http2] accept:", err)
l.errChan <- err
close(l.errChan)
return
}
go l.handleLoop(conn)
}
}
func (l *Listener) handleLoop(conn net.Conn) {
if l.md.tlsConfig != nil {
tlsConn := tls.Server(conn, l.md.tlsConfig)
// NOTE: HTTP2 server will check the TLS version,
// so we must ensure that the TLS connection is handshake completed.
if err := tlsConn.Handshake(); err != nil {
// log.Logf("[http2] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err)
return
}
conn = tlsConn
}
opt := http2.ServeConnOpts{
Handler: http.HandlerFunc(l.handleFunc),
}
l.server.ServeConn(conn, &opt)
}
func (l *Listener) handleFunc(w http.ResponseWriter, r *http.Request) {
/*
log.Logf("[http2] %s -> %s %s %s %s",
r.RemoteAddr, r.Host, r.Method, r.RequestURI, r.Proto)
if Debug {
dump, _ := httputil.DumpRequest(r, false)
log.Log("[http2]", string(dump))
}
*/
// w.Header().Set("Proxy-Agent", "gost/"+Version)
conn, err := l.upgrade(w, r)
if err != nil {
// log.Logf("[http2] %s - %s %s %s %s: %s",
// r.RemoteAddr, r.Host, r.Method, r.RequestURI, r.Proto, err)
return
}
select {
case l.connChan <- conn:
default:
conn.Close()
// log.Logf("[http2] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr())
}
<-conn.closed // NOTE: we need to wait for streaming end, or the connection will be closed
}
func (l *Listener) upgrade(w http.ResponseWriter, r *http.Request) (*conn, error) {
if l.md.path == "" && r.Method != http.MethodConnect {
w.WriteHeader(http.StatusMethodNotAllowed)
return nil, errors.New("method not allowed")
}
if l.md.path != "" && r.RequestURI != l.md.path {
w.WriteHeader(http.StatusBadRequest)
return nil, errors.New("bad request")
}
w.WriteHeader(http.StatusOK)
if fw, ok := w.(http.Flusher); ok {
fw.Flush() // write header to client
}
remoteAddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
if remoteAddr == nil {
remoteAddr = &net.TCPAddr{
IP: net.IPv4zero,
Port: 0,
}
}
return &conn{
r: r.Body,
w: flushWriter{w},
localAddr: l.Listener.Addr(),
remoteAddr: remoteAddr,
closed: make(chan struct{}),
}, nil
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
if err != nil {
return
}
return
}

View File

@ -0,0 +1,38 @@
package h2
import (
"crypto/tls"
"net/http"
"time"
)
const (
addr = "addr"
path = "path"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
connQueueSize = "connQueueSize"
)
const (
defaultQueueSize = 128
)
type metadata struct {
addr string
path string
tlsConfig *tls.Config
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
readBufferSize int
writeBufferSize int
enableCompression bool
responseHeader http.Header
connQueueSize int
keepAlivePeriod time.Duration
}

View File

@ -0,0 +1,140 @@
package http2
import (
"crypto/tls"
"errors"
"net"
"net/http"
"github.com/go-gost/gost/pkg/components/internal/utils"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
"golang.org/x/net/http2"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
server *http.Server
addr net.Addr
connChan chan *conn
errChan chan error
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
l.server = &http.Server{
Addr: l.md.addr,
Handler: http.HandlerFunc(l.handleFunc),
TLSConfig: l.md.tlsConfig,
}
if err := http2.ConfigureServer(l.server, nil); err != nil {
return err
}
ln, err := net.Listen("tcp", addr)
if err != nil {
return err
}
l.addr = ln.Addr()
ln = tls.NewListener(
&utils.TCPKeepAliveListener{
TCPListener: ln.(*net.TCPListener),
KeepAlivePeriod: l.md.keepAlivePeriod,
},
l.md.tlsConfig,
)
queueSize := l.md.connQueueSize
if queueSize <= 0 {
queueSize = defaultQueueSize
}
l.connChan = make(chan *conn, queueSize)
l.errChan = make(chan error, 1)
go func() {
if err := l.server.Serve(ln); err != nil {
// log.Log("[http2]", err)
}
}()
return
}
func (l *Listener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *Listener) Addr() net.Addr {
return l.addr
}
func (l *Listener) Close() (err error) {
select {
case <-l.errChan:
default:
err = l.server.Close()
l.errChan <- err
close(l.errChan)
}
return nil
}
func (l *Listener) handleFunc(w http.ResponseWriter, r *http.Request) {
conn := &conn{
r: r,
w: w,
closed: make(chan struct{}),
}
select {
case l.connChan <- conn:
default:
// log.Logf("[http2] %s - %s: connection queue is full", r.RemoteAddr, l.server.Addr)
return
}
<-conn.closed
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
if err != nil {
return
}
return
}

View File

@ -0,0 +1,38 @@
package http2
import (
"crypto/tls"
"net/http"
"time"
)
const (
addr = "addr"
path = "path"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
connQueueSize = "connQueueSize"
)
const (
defaultQueueSize = 128
)
type metadata struct {
addr string
path string
tlsConfig *tls.Config
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
readBufferSize int
writeBufferSize int
enableCompression bool
responseHeader http.Header
connQueueSize int
keepAlivePeriod time.Duration
}

View File

@ -0,0 +1,115 @@
package kcp
import (
"crypto/sha1"
"github.com/xtaci/kcp-go/v5"
"golang.org/x/crypto/pbkdf2"
)
var (
// Salt is the default salt for KCP cipher.
Salt = "kcp-go"
)
var (
// DefaultKCPConfig is the default KCP config.
DefaultConfig = &Config{
Key: "it's a secrect",
Crypt: "aes",
Mode: "fast",
MTU: 1350,
SndWnd: 1024,
RcvWnd: 1024,
DataShard: 10,
ParityShard: 3,
DSCP: 0,
NoComp: false,
AckNodelay: false,
NoDelay: 0,
Interval: 50,
Resend: 0,
NoCongestion: 0,
SockBuf: 4194304,
KeepAlive: 10,
SnmpLog: "",
SnmpPeriod: 60,
Signal: false,
TCP: false,
}
)
// KCPConfig describes the config for KCP.
type Config struct {
Key string `json:"key"`
Crypt string `json:"crypt"`
Mode string `json:"mode"`
MTU int `json:"mtu"`
SndWnd int `json:"sndwnd"`
RcvWnd int `json:"rcvwnd"`
DataShard int `json:"datashard"`
ParityShard int `json:"parityshard"`
DSCP int `json:"dscp"`
NoComp bool `json:"nocomp"`
AckNodelay bool `json:"acknodelay"`
NoDelay int `json:"nodelay"`
Interval int `json:"interval"`
Resend int `json:"resend"`
NoCongestion int `json:"nc"`
SockBuf int `json:"sockbuf"`
KeepAlive int `json:"keepalive"`
SnmpLog string `json:"snmplog"`
SnmpPeriod int `json:"snmpperiod"`
Signal bool `json:"signal"` // Signal enables the signal SIGUSR1 feature.
TCP bool `json:"tcp"`
}
// Init initializes the KCP config.
func (c *Config) Init() {
switch c.Mode {
case "normal":
c.NoDelay, c.Interval, c.Resend, c.NoCongestion = 0, 40, 2, 1
case "fast":
c.NoDelay, c.Interval, c.Resend, c.NoCongestion = 0, 30, 2, 1
case "fast2":
c.NoDelay, c.Interval, c.Resend, c.NoCongestion = 1, 20, 2, 1
case "fast3":
c.NoDelay, c.Interval, c.Resend, c.NoCongestion = 1, 10, 2, 1
}
}
func blockCrypt(key, crypt, salt string) (block kcp.BlockCrypt) {
pass := pbkdf2.Key([]byte(key), []byte(salt), 4096, 32, sha1.New)
switch crypt {
case "sm4":
block, _ = kcp.NewSM4BlockCrypt(pass[:16])
case "tea":
block, _ = kcp.NewTEABlockCrypt(pass[:16])
case "xor":
block, _ = kcp.NewSimpleXORBlockCrypt(pass)
case "none":
block, _ = kcp.NewNoneBlockCrypt(pass)
case "aes-128":
block, _ = kcp.NewAESBlockCrypt(pass[:16])
case "aes-192":
block, _ = kcp.NewAESBlockCrypt(pass[:24])
case "blowfish":
block, _ = kcp.NewBlowfishBlockCrypt(pass)
case "twofish":
block, _ = kcp.NewTwofishBlockCrypt(pass)
case "cast5":
block, _ = kcp.NewCast5BlockCrypt(pass[:16])
case "3des":
block, _ = kcp.NewTripleDESBlockCrypt(pass[:24])
case "xtea":
block, _ = kcp.NewXTEABlockCrypt(pass[:16])
case "salsa20":
block, _ = kcp.NewSalsa20BlockCrypt(pass)
case "aes":
fallthrough
default: // aes
block, _ = kcp.NewAESBlockCrypt(pass)
}
return
}

View File

@ -0,0 +1,179 @@
package kcp
import (
"errors"
"net"
"time"
"github.com/go-gost/gost/pkg/components/internal/utils"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
"github.com/xtaci/kcp-go/v5"
"github.com/xtaci/smux"
"github.com/xtaci/tcpraw"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
ln *kcp.Listener
connChan chan net.Conn
errChan chan error
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
config := l.md.config
if config == nil {
config = DefaultConfig
}
config.Init()
var ln *kcp.Listener
if config.TCP {
var conn net.PacketConn
conn, err = tcpraw.Listen("tcp", addr)
if err != nil {
return
}
ln, err = kcp.ServeConn(
blockCrypt(config.Key, config.Crypt, Salt), config.DataShard, config.ParityShard, conn)
} else {
ln, err = kcp.ListenWithOptions(addr,
blockCrypt(config.Key, config.Crypt, Salt), config.DataShard, config.ParityShard)
}
if err != nil {
return
}
if config.DSCP > 0 {
if err = ln.SetDSCP(config.DSCP); err != nil {
l.logger.Warn(err)
}
}
if err = ln.SetReadBuffer(config.SockBuf); err != nil {
l.logger.Warn(err)
}
if err = ln.SetWriteBuffer(config.SockBuf); err != nil {
l.logger.Warn(err)
}
l.ln = ln
l.connChan = make(chan net.Conn, l.md.connQueueSize)
l.errChan = make(chan error, 1)
go l.listenLoop()
return
}
func (l *Listener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *Listener) Close() error {
return l.ln.Close()
}
func (l *Listener) Addr() net.Addr {
return l.ln.Addr()
}
func (l *Listener) listenLoop() {
for {
conn, err := l.ln.AcceptKCP()
if err != nil {
l.logger.Error("accept:", err)
l.errChan <- err
close(l.errChan)
return
}
conn.SetStreamMode(true)
conn.SetWriteDelay(false)
conn.SetNoDelay(
l.md.config.NoDelay,
l.md.config.Interval,
l.md.config.Resend,
l.md.config.NoCongestion,
)
conn.SetMtu(l.md.config.MTU)
conn.SetWindowSize(l.md.config.SndWnd, l.md.config.RcvWnd)
conn.SetACKNoDelay(l.md.config.AckNodelay)
go l.mux(conn)
}
}
func (l *Listener) mux(conn net.Conn) {
defer conn.Close()
smuxConfig := smux.DefaultConfig()
smuxConfig.MaxReceiveBuffer = l.md.config.SockBuf
smuxConfig.KeepAliveInterval = time.Duration(l.md.config.KeepAlive) * time.Second
if !l.md.config.NoComp {
conn = utils.KCPCompStreamConn(conn)
}
mux, err := smux.Server(conn, smuxConfig)
if err != nil {
l.logger.Error(err)
return
}
defer mux.Close()
for {
stream, err := mux.AcceptStream()
if err != nil {
l.logger.Error("accept stream:", err)
return
}
select {
case l.connChan <- stream:
case <-stream.GetDieCh():
stream.Close()
default:
stream.Close()
l.logger.Error("connection queue is full")
}
}
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
return
}

View File

@ -0,0 +1,18 @@
package kcp
const (
addr = "addr"
connQueueSize = "connQueueSize"
)
const (
defaultQueueSize = 128
)
type metadata struct {
addr string
config *Config
connQueueSize int
}

View File

@ -0,0 +1,20 @@
package listener
import (
"errors"
"net"
)
var (
ErrClosed = errors.New("accpet on closed listener")
)
// Listener is a server listener, just like a net.Listener.
type Listener interface {
net.Listener
}
// Accepter represents a network endpoint that can accept connection from peer.
type Accepter interface {
Accept() (net.Conn, error)
}

View File

@ -0,0 +1,3 @@
package listener
type Metadata map[string]string

View File

@ -0,0 +1,140 @@
package http
import (
"bufio"
"bytes"
"crypto/sha1"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"net/http"
"sync"
"time"
)
type conn struct {
net.Conn
rbuf bytes.Buffer
wbuf bytes.Buffer
handshaked bool
handshakeMutex sync.Mutex
}
func (c *conn) Handshake() (err error) {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
if c.handshaked {
return nil
}
if err = c.handshake(); err != nil {
return
}
c.handshaked = true
return nil
}
func (c *conn) handshake() (err error) {
br := bufio.NewReader(c.Conn)
r, err := http.ReadRequest(br)
if err != nil {
return
}
/*
if Debug {
dump, _ := httputil.DumpRequest(r, false)
log.Logf("[ohttp] %s -> %s\n%s", c.RemoteAddr(), c.LocalAddr(), string(dump))
}
*/
if r.ContentLength > 0 {
_, err = io.Copy(&c.rbuf, r.Body)
} else {
var b []byte
b, err = br.Peek(br.Buffered())
if len(b) > 0 {
_, err = c.rbuf.Write(b)
}
}
if err != nil {
// log.Logf("[ohttp] %s -> %s : %v", c.Conn.RemoteAddr(), c.Conn.LocalAddr(), err)
return
}
b := bytes.Buffer{}
if r.Method != http.MethodGet || r.Header.Get("Upgrade") != "websocket" {
b.WriteString("HTTP/1.1 503 Service Unavailable\r\n")
b.WriteString("Content-Length: 0\r\n")
b.WriteString("Date: " + time.Now().Format(time.RFC1123) + "\r\n")
b.WriteString("\r\n")
/*
if Debug {
log.Logf("[ohttp] %s <- %s\n%s", c.RemoteAddr(), c.LocalAddr(), b.String())
}
*/
b.WriteTo(c.Conn)
return errors.New("bad request")
}
b.WriteString("HTTP/1.1 101 Switching Protocols\r\n")
b.WriteString("Server: nginx/1.10.0\r\n")
b.WriteString("Date: " + time.Now().Format(time.RFC1123) + "\r\n")
b.WriteString("Connection: Upgrade\r\n")
b.WriteString("Upgrade: websocket\r\n")
b.WriteString(fmt.Sprintf("Sec-WebSocket-Accept: %s\r\n", computeAcceptKey(r.Header.Get("Sec-WebSocket-Key"))))
b.WriteString("\r\n")
/*
if Debug {
log.Logf("[ohttp] %s <- %s\n%s", c.RemoteAddr(), c.LocalAddr(), b.String())
}
*/
if c.rbuf.Len() > 0 {
c.wbuf = b // cache the response header if there are extra data in the request body.
return
}
_, err = b.WriteTo(c.Conn)
return
}
func (c *conn) Read(b []byte) (n int, err error) {
if err = c.Handshake(); err != nil {
return
}
if c.rbuf.Len() > 0 {
return c.rbuf.Read(b)
}
return c.Conn.Read(b)
}
func (c *conn) Write(b []byte) (n int, err error) {
if err = c.Handshake(); err != nil {
return
}
if c.wbuf.Len() > 0 {
c.wbuf.Write(b) // append the data to the cached header
_, err = c.wbuf.WriteTo(c.Conn)
n = len(b) // exclude the header length
return
}
return c.Conn.Write(b)
}
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
func computeAcceptKey(challengeKey string) string {
h := sha1.New()
h.Write([]byte(challengeKey))
h.Write(keyGUID)
return base64.StdEncoding.EncodeToString(h.Sum(nil))
}

View File

@ -0,0 +1,88 @@
package http
import (
"errors"
"net"
"strconv"
"time"
"github.com/go-gost/gost/pkg/components/internal/utils"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
net.Listener
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
laddr, err := net.ResolveTCPAddr("tcp", l.md.addr)
if err != nil {
return
}
ln, err := net.ListenTCP("tcp", laddr)
if err != nil {
return
}
if l.md.keepAlive {
l.Listener = &utils.TCPKeepAliveListener{
TCPListener: ln,
KeepAlivePeriod: l.md.keepAlivePeriod,
}
return
}
l.Listener = ln
return
}
func (l *Listener) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return &conn{Conn: c}, nil
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
m.keepAlive = true
if val, ok := md[keepAlive]; ok {
m.keepAlive, _ = strconv.ParseBool(val)
}
if val, ok := md[keepAlivePeriod]; ok {
m.keepAlivePeriod, _ = time.ParseDuration(val)
}
return
}

View File

@ -0,0 +1,19 @@
package http
import "time"
const (
addr = "addr"
keepAlive = "keepAlive"
keepAlivePeriod = "keepAlivePeriod"
)
const (
defaultKeepAlivePeriod = 180 * time.Second
)
type metadata struct {
addr string
keepAlive bool
keepAlivePeriod time.Duration
}

View File

@ -0,0 +1,306 @@
package tls
import (
"bytes"
"crypto/rand"
"crypto/tls"
"errors"
"net"
"sync"
"time"
dissector "github.com/ginuerzh/tls-dissector"
)
const (
maxTLSDataLen = 16384
)
var (
cipherSuites = []uint16{
0xc02c, 0xc030, 0x009f, 0xcca9, 0xcca8, 0xccaa, 0xc02b, 0xc02f,
0x009e, 0xc024, 0xc028, 0x006b, 0xc023, 0xc027, 0x0067, 0xc00a,
0xc014, 0x0039, 0xc009, 0xc013, 0x0033, 0x009d, 0x009c, 0x003d,
0x003c, 0x0035, 0x002f, 0x00ff,
}
compressionMethods = []uint8{0x00}
algorithms = []uint16{
0x0601, 0x0602, 0x0603, 0x0501, 0x0502, 0x0503, 0x0401, 0x0402,
0x0403, 0x0301, 0x0302, 0x0303, 0x0201, 0x0202, 0x0203,
}
tlsRecordTypes = []uint8{0x16, 0x14, 0x16, 0x17}
tlsVersionMinors = []uint8{0x01, 0x03, 0x03, 0x03}
ErrBadType = errors.New("bad type")
ErrBadMajorVersion = errors.New("bad major version")
ErrBadMinorVersion = errors.New("bad minor version")
ErrMaxDataLen = errors.New("bad tls data len")
)
const (
tlsRecordStateType = iota
tlsRecordStateVersion0
tlsRecordStateVersion1
tlsRecordStateLength0
tlsRecordStateLength1
tlsRecordStateData
)
type obfsTLSParser struct {
step uint8
state uint8
length uint16
}
func (r *obfsTLSParser) Parse(b []byte) (int, error) {
i := 0
last := 0
length := len(b)
for i < length {
ch := b[i]
switch r.state {
case tlsRecordStateType:
if tlsRecordTypes[r.step] != ch {
return 0, ErrBadType
}
r.state = tlsRecordStateVersion0
i++
case tlsRecordStateVersion0:
if ch != 0x03 {
return 0, ErrBadMajorVersion
}
r.state = tlsRecordStateVersion1
i++
case tlsRecordStateVersion1:
if ch != tlsVersionMinors[r.step] {
return 0, ErrBadMinorVersion
}
r.state = tlsRecordStateLength0
i++
case tlsRecordStateLength0:
r.length = uint16(ch) << 8
r.state = tlsRecordStateLength1
i++
case tlsRecordStateLength1:
r.length |= uint16(ch)
if r.step == 0 {
r.length = 91
} else if r.step == 1 {
r.length = 1
} else if r.length > maxTLSDataLen {
return 0, ErrMaxDataLen
}
if r.length > 0 {
r.state = tlsRecordStateData
} else {
r.state = tlsRecordStateType
r.step++
}
i++
case tlsRecordStateData:
left := uint16(length - i)
if left > r.length {
left = r.length
}
if r.step >= 2 {
skip := i - last
copy(b[last:], b[i:length])
length -= int(skip)
last += int(left)
i = last
} else {
i += int(left)
}
r.length -= left
if r.length == 0 {
if r.step < 3 {
r.step++
}
r.state = tlsRecordStateType
}
}
}
if last == 0 {
return 0, nil
} else if last < length {
length -= last
}
return length, nil
}
type conn struct {
net.Conn
rbuf bytes.Buffer
wbuf bytes.Buffer
host string
handshaked chan struct{}
parser *obfsTLSParser
handshakeMutex sync.Mutex
}
// newConn creates a connection for obfs-tls server.
func newConn(c net.Conn, host string) net.Conn {
return &conn{
Conn: c,
host: host,
handshaked: make(chan struct{}),
}
}
func (c *conn) Handshaked() bool {
select {
case <-c.handshaked:
return true
default:
return false
}
}
func (c *conn) Handshake(payload []byte) (err error) {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
if c.Handshaked() {
return
}
if err = c.handshake(); err != nil {
return
}
close(c.handshaked)
return nil
}
func (c *conn) handshake() error {
record := &dissector.Record{}
if _, err := record.ReadFrom(c.Conn); err != nil {
// log.Log(err)
return err
}
if record.Type != dissector.Handshake {
return dissector.ErrBadType
}
clientMsg := &dissector.ClientHelloMsg{}
if err := clientMsg.Decode(record.Opaque); err != nil {
// log.Log(err)
return err
}
for _, ext := range clientMsg.Extensions {
if ext.Type() == dissector.ExtSessionTicket {
b, err := ext.Encode()
if err != nil {
// log.Log(err)
return err
}
c.rbuf.Write(b)
break
}
}
serverMsg := &dissector.ServerHelloMsg{
Version: tls.VersionTLS12,
SessionID: clientMsg.SessionID,
CipherSuite: 0xcca8,
CompressionMethod: 0x00,
Extensions: []dissector.Extension{
&dissector.RenegotiationInfoExtension{},
&dissector.ExtendedMasterSecretExtension{},
&dissector.ECPointFormatsExtension{
Formats: []uint8{0x00},
},
},
}
serverMsg.Random.Time = uint32(time.Now().Unix())
rand.Read(serverMsg.Random.Opaque[:])
b, err := serverMsg.Encode()
if err != nil {
return err
}
record = &dissector.Record{
Type: dissector.Handshake,
Version: tls.VersionTLS10,
Opaque: b,
}
if _, err := record.WriteTo(&c.wbuf); err != nil {
return err
}
record = &dissector.Record{
Type: dissector.ChangeCipherSpec,
Version: tls.VersionTLS12,
Opaque: []byte{0x01},
}
if _, err := record.WriteTo(&c.wbuf); err != nil {
return err
}
return nil
}
func (c *conn) Read(b []byte) (n int, err error) {
if err = c.Handshake(nil); err != nil {
return
}
select {
case <-c.handshaked:
}
if c.rbuf.Len() > 0 {
return c.rbuf.Read(b)
}
record := &dissector.Record{}
if _, err = record.ReadFrom(c.Conn); err != nil {
return
}
n = copy(b, record.Opaque)
_, err = c.rbuf.Write(record.Opaque[n:])
return
}
func (c *conn) Write(b []byte) (n int, err error) {
n = len(b)
if !c.Handshaked() {
if err = c.Handshake(b); err != nil {
return
}
}
for len(b) > 0 {
data := b
if len(b) > maxTLSDataLen {
data = b[:maxTLSDataLen]
b = b[maxTLSDataLen:]
} else {
b = b[:0]
}
record := &dissector.Record{
Type: dissector.AppData,
Version: tls.VersionTLS12,
Opaque: data,
}
if c.wbuf.Len() > 0 {
record.Type = dissector.Handshake
record.WriteTo(&c.wbuf)
_, err = c.wbuf.WriteTo(c.Conn)
return
}
if _, err = record.WriteTo(c.Conn); err != nil {
return
}
}
return
}

View File

@ -0,0 +1,88 @@
package tls
import (
"errors"
"net"
"strconv"
"time"
"github.com/go-gost/gost/pkg/components/internal/utils"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
net.Listener
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
laddr, err := net.ResolveTCPAddr("tcp", l.md.addr)
if err != nil {
return
}
ln, err := net.ListenTCP("tcp", laddr)
if err != nil {
return
}
if l.md.keepAlive {
l.Listener = &utils.TCPKeepAliveListener{
TCPListener: ln,
KeepAlivePeriod: l.md.keepAlivePeriod,
}
return
}
l.Listener = ln
return
}
func (l *Listener) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return &conn{Conn: c}, nil
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
m.keepAlive = true
if val, ok := md[keepAlive]; ok {
m.keepAlive, _ = strconv.ParseBool(val)
}
if val, ok := md[keepAlivePeriod]; ok {
m.keepAlivePeriod, _ = time.ParseDuration(val)
}
return
}

View File

@ -0,0 +1,19 @@
package tls
import "time"
const (
addr = "addr"
keepAlive = "keepAlive"
keepAlivePeriod = "keepAlivePeriod"
)
const (
defaultKeepAlivePeriod = 180 * time.Second
)
type metadata struct {
addr string
keepAlive bool
keepAlivePeriod time.Duration
}

View File

@ -0,0 +1,17 @@
package listener
import (
"github.com/go-gost/gost/pkg/logger"
)
type Options struct {
Logger logger.Logger
}
type Option func(opts *Options)
func LoggerOption(logger logger.Logger) Option {
return func(opts *Options) {
opts.Logger = logger
}
}

View File

@ -0,0 +1,142 @@
package quic
import (
"context"
"errors"
"net"
"github.com/go-gost/gost/pkg/components/internal/utils"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
"github.com/lucas-clemente/quic-go"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
ln quic.Listener
connChan chan net.Conn
errChan chan error
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
laddr, err := net.ResolveUDPAddr("udp", l.md.addr)
if err != nil {
return
}
var conn net.PacketConn
conn, err = net.ListenUDP("udp", laddr)
if err != nil {
return
}
if l.md.cipherKey != nil {
conn = utils.QUICCipherConn(conn, l.md.cipherKey)
}
config := &quic.Config{
KeepAlive: l.md.keepAlive,
HandshakeIdleTimeout: l.md.HandshakeTimeout,
MaxIdleTimeout: l.md.MaxIdleTimeout,
}
ln, err := quic.Listen(conn, l.md.tlsConfig, config)
if err != nil {
return
}
l.ln = ln
l.connChan = make(chan net.Conn, l.md.connQueueSize)
l.errChan = make(chan error, 1)
go l.listenLoop()
return
}
func (l *Listener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *Listener) Close() error {
return l.ln.Close()
}
func (l *Listener) Addr() net.Addr {
return l.ln.Addr()
}
func (l *Listener) listenLoop() {
for {
ctx := context.Background()
session, err := l.ln.Accept(ctx)
if err != nil {
l.logger.Error("accept:", err)
l.errChan <- err
close(l.errChan)
return
}
go l.mux(ctx, session)
}
}
func (l *Listener) mux(ctx context.Context, session quic.Session) {
defer session.CloseWithError(0, "")
for {
stream, err := session.AcceptStream(ctx)
if err != nil {
l.logger.Error("accept stream:", err)
return
}
conn := utils.QUICConn(session, stream)
select {
case l.connChan <- conn:
case <-stream.Context().Done():
stream.Close()
default:
stream.Close()
l.logger.Error("connection queue is full")
}
}
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
return
}

View File

@ -0,0 +1,32 @@
package quic
import (
"crypto/tls"
"time"
)
const (
addr = "addr"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
keepAlive = "keepAlive"
keepAlivePeriod = "keepAlivePeriod"
)
const (
defaultKeepAlivePeriod = 180 * time.Second
)
type metadata struct {
addr string
tlsConfig *tls.Config
keepAlive bool
HandshakeTimeout time.Duration
MaxIdleTimeout time.Duration
cipherKey []byte
connQueueSize int
}

View File

@ -0,0 +1,79 @@
package tcp
import (
"errors"
"net"
"strconv"
"time"
"github.com/go-gost/gost/pkg/components/internal/utils"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
net.Listener
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
laddr, err := net.ResolveTCPAddr("tcp", l.md.addr)
if err != nil {
return
}
ln, err := net.ListenTCP("tcp", laddr)
if err != nil {
return
}
if l.md.keepAlive {
l.Listener = &utils.TCPKeepAliveListener{
TCPListener: ln,
KeepAlivePeriod: l.md.keepAlivePeriod,
}
return
}
l.Listener = ln
return
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
m.keepAlive = true
if val, ok := md[keepAlive]; ok {
m.keepAlive, _ = strconv.ParseBool(val)
}
if val, ok := md[keepAlivePeriod]; ok {
m.keepAlivePeriod, _ = time.ParseDuration(val)
}
return
}

View File

@ -0,0 +1,19 @@
package tcp
import "time"
const (
addr = "addr"
keepAlive = "keepAlive"
keepAlivePeriod = "keepAlivePeriod"
)
const (
defaultKeepAlivePeriod = 180 * time.Second
)
type metadata struct {
addr string
keepAlive bool
keepAlivePeriod time.Duration
}

View File

@ -0,0 +1,75 @@
package tls
import (
"crypto/tls"
"errors"
"net"
"time"
"github.com/go-gost/gost/pkg/components/internal/utils"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
net.Listener
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
ln, err := net.Listen("tcp", l.md.addr)
if err != nil {
return
}
ln = tls.NewListener(
&utils.TCPKeepAliveListener{
TCPListener: ln.(*net.TCPListener),
KeepAlivePeriod: l.md.keepAlivePeriod,
},
l.md.tlsConfig,
)
l.Listener = ln
return
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
if err != nil {
return
}
if val, ok := md[keepAlivePeriod]; ok {
m.keepAlivePeriod, _ = time.ParseDuration(val)
}
return
}

View File

@ -0,0 +1,20 @@
package tls
import (
"crypto/tls"
"time"
)
const (
addr = "addr"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
keepAlivePeriod = "keepAlivePeriod"
)
type metadata struct {
addr string
tlsConfig *tls.Config
keepAlivePeriod time.Duration
}

View File

@ -0,0 +1,141 @@
package mux
import (
"crypto/tls"
"errors"
"net"
"github.com/go-gost/gost/pkg/components/internal/utils"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
"github.com/xtaci/smux"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
net.Listener
connChan chan net.Conn
errChan chan error
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
ln, err := net.Listen("tcp", l.md.addr)
if err != nil {
return
}
l.Listener = tls.NewListener(ln, l.md.tlsConfig)
queueSize := l.md.connQueueSize
if queueSize <= 0 {
queueSize = defaultQueueSize
}
l.connChan = make(chan net.Conn, queueSize)
l.errChan = make(chan error, 1)
go l.listenLoop()
return
}
func (l *Listener) listenLoop() {
for {
conn, err := l.Listener.Accept()
if err != nil {
l.errChan <- err
close(l.errChan)
return
}
go l.mux(conn)
}
}
func (l *Listener) mux(conn net.Conn) {
smuxConfig := smux.DefaultConfig()
smuxConfig.KeepAliveDisabled = l.md.muxKeepAliveDisabled
if l.md.muxKeepAlivePeriod > 0 {
smuxConfig.KeepAliveInterval = l.md.muxKeepAlivePeriod
}
if l.md.muxKeepAliveTimeout > 0 {
smuxConfig.KeepAliveTimeout = l.md.muxKeepAliveTimeout
}
if l.md.muxMaxFrameSize > 0 {
smuxConfig.MaxFrameSize = l.md.muxMaxFrameSize
}
if l.md.muxMaxReceiveBuffer > 0 {
smuxConfig.MaxReceiveBuffer = l.md.muxMaxReceiveBuffer
}
if l.md.muxMaxStreamBuffer > 0 {
smuxConfig.MaxStreamBuffer = l.md.muxMaxStreamBuffer
}
session, err := smux.Server(conn, smuxConfig)
if err != nil {
l.logger.Error(err)
return
}
defer session.Close()
for {
stream, err := session.AcceptStream()
if err != nil {
l.logger.Error("accept stream:", err)
return
}
select {
case l.connChan <- stream:
case <-stream.GetDieCh():
stream.Close()
default:
stream.Close()
l.logger.Error("connection queue is full")
}
}
}
func (l *Listener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
if err != nil {
return
}
return
}

View File

@ -0,0 +1,38 @@
package mux
import (
"crypto/tls"
"time"
)
const (
addr = "addr"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
muxKeepAliveDisabled = "muxKeepAliveDisabled"
muxKeepAlivePeriod = "muxKeepAlivePeriod"
muxKeepAliveTimeout = "muxKeepAliveTimeout"
muxMaxFrameSize = "muxMaxFrameSize"
muxMaxReceiveBuffer = "muxMaxReceiveBuffer"
muxMaxStreamBuffer = "muxMaxStreamBuffer"
)
const (
defaultQueueSize = 128
)
type metadata struct {
addr string
tlsConfig *tls.Config
muxKeepAliveDisabled bool
muxKeepAlivePeriod time.Duration
muxKeepAliveTimeout time.Duration
muxMaxFrameSize int
muxMaxReceiveBuffer int
muxMaxStreamBuffer int
connQueueSize int
}

View File

@ -0,0 +1,115 @@
package udp
import (
"errors"
"net"
"sync"
"sync/atomic"
"time"
)
// serverConn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
type serverConn struct {
net.PacketConn
raddr net.Addr
rc chan []byte // data receive queue
fresh int32
closed chan struct{}
closeMutex sync.Mutex
config *serverConnConfig
}
type serverConnConfig struct {
ttl time.Duration
qsize int
onClose func()
}
func newServerConn(conn net.PacketConn, raddr net.Addr, cfg *serverConnConfig) *serverConn {
if conn == nil || raddr == nil {
return nil
}
if cfg == nil {
cfg = &serverConnConfig{}
}
c := &serverConn{
PacketConn: conn,
raddr: raddr,
rc: make(chan []byte, cfg.qsize),
closed: make(chan struct{}),
config: cfg,
}
go c.ttlWait()
return c
}
func (c *serverConn) send(b []byte) error {
select {
case c.rc <- b:
return nil
default:
return errors.New("queue is full")
}
}
func (c *serverConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *serverConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
select {
case bb := <-c.rc:
n = copy(b, bb)
atomic.StoreInt32(&c.fresh, 1)
case <-c.closed:
err = errors.New("read from closed connection")
return
}
addr = c.raddr
return
}
func (c *serverConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.raddr)
}
func (c *serverConn) Close() error {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()
select {
case <-c.closed:
return errors.New("connection is closed")
default:
if c.config.onClose != nil {
c.config.onClose()
}
close(c.closed)
}
return nil
}
func (c *serverConn) RemoteAddr() net.Addr {
return c.raddr
}
func (c *serverConn) ttlWait() {
ticker := time.NewTicker(c.config.ttl)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if !atomic.CompareAndSwapInt32(&c.fresh, 1, 0) {
c.Close()
return
}
case <-c.closed:
return
}
}
}

View File

@ -0,0 +1,168 @@
package udp
import (
"errors"
"net"
"sync"
"sync/atomic"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
conn net.PacketConn
connChan chan net.Conn
errChan chan error
connPool connPool
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
laddr, err := net.ResolveUDPAddr("udp", l.md.addr)
if err != nil {
return
}
var conn net.PacketConn
conn, err = net.ListenUDP("udp", laddr)
if err != nil {
return
}
l.conn = conn
l.connChan = make(chan net.Conn, l.md.connQueueSize)
l.errChan = make(chan error, 1)
go l.listenLoop()
return
}
func (l *Listener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *Listener) Close() error {
err := l.conn.Close()
l.connPool.Range(func(k interface{}, v *serverConn) bool {
v.Close()
return true
})
return err
}
func (l *Listener) Addr() net.Addr {
return l.conn.LocalAddr()
}
func (l *Listener) listenLoop() {
for {
b := make([]byte, l.md.readBufferSize)
n, raddr, err := l.conn.ReadFrom(b)
if err != nil {
l.logger.Error("accept:", err)
l.errChan <- err
close(l.errChan)
return
}
conn, ok := l.connPool.Get(raddr.String())
if !ok {
conn = newServerConn(l.conn, raddr,
&serverConnConfig{
ttl: l.md.ttl,
qsize: l.md.readQueueSize,
onClose: func() {
l.connPool.Delete(raddr.String())
},
})
select {
case l.connChan <- conn:
l.connPool.Set(raddr.String(), conn)
default:
conn.Close()
l.logger.Error("connection queue is full")
}
}
if err := conn.send(b[:n]); err != nil {
l.logger.Warn("data discarded:", err)
}
l.logger.Debug("recv", n)
}
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
return
}
type connPool struct {
size int64
m sync.Map
}
func (p *connPool) Get(key interface{}) (conn *serverConn, ok bool) {
v, ok := p.m.Load(key)
if ok {
conn, ok = v.(*serverConn)
}
return
}
func (p *connPool) Set(key interface{}, conn *serverConn) {
p.m.Store(key, conn)
atomic.AddInt64(&p.size, 1)
}
func (p *connPool) Delete(key interface{}) {
p.m.Delete(key)
atomic.AddInt64(&p.size, -1)
}
func (p *connPool) Range(f func(key interface{}, value *serverConn) bool) {
p.m.Range(func(k, v interface{}) bool {
return f(k, v.(*serverConn))
})
}
func (p *connPool) Size() int64 {
return atomic.LoadInt64(&p.size)
}

View File

@ -0,0 +1,23 @@
package udp
import "time"
const (
defaultTTL = 60 * time.Second
defaultReadBufferSize = 1024
defaultReadQueueSize = 128
defaultConnQueueSize = 128
)
const (
addr = "addr"
)
type metadata struct {
addr string
ttl time.Duration
readBufferSize int
readQueueSize int
connQueueSize int
}

View File

@ -0,0 +1,143 @@
package ws
import (
"crypto/tls"
"errors"
"net"
"net/http"
"github.com/go-gost/gost/pkg/components/internal/utils"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
"github.com/gorilla/websocket"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
addr net.Addr
upgrader *websocket.Upgrader
srv *http.Server
connChan chan net.Conn
errChan chan error
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
l.upgrader = &websocket.Upgrader{
HandshakeTimeout: l.md.handshakeTimeout,
ReadBufferSize: l.md.readBufferSize,
WriteBufferSize: l.md.writeBufferSize,
CheckOrigin: func(r *http.Request) bool { return true },
EnableCompression: l.md.enableCompression,
}
path := l.md.path
if path == "" {
path = defaultPath
}
mux := http.NewServeMux()
mux.Handle(path, http.HandlerFunc(l.upgrade))
l.srv = &http.Server{
Addr: l.md.addr,
TLSConfig: l.md.tlsConfig,
Handler: mux,
ReadHeaderTimeout: l.md.readHeaderTimeout,
}
queueSize := l.md.connQueueSize
if queueSize <= 0 {
queueSize = defaultQueueSize
}
l.connChan = make(chan net.Conn, queueSize)
l.errChan = make(chan error, 1)
ln, err := net.Listen("tcp", l.md.addr)
if err != nil {
return
}
if l.md.tlsConfig != nil {
ln = tls.NewListener(ln, l.md.tlsConfig)
}
l.addr = ln.Addr()
go func() {
err := l.srv.Serve(ln)
if err != nil {
l.errChan <- err
}
close(l.errChan)
}()
return
}
func (l *Listener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *Listener) Close() error {
return l.srv.Close()
}
func (l *Listener) Addr() net.Addr {
return l.addr
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
if err != nil {
return
}
return
}
func (l *Listener) upgrade(w http.ResponseWriter, r *http.Request) {
conn, err := l.upgrader.Upgrade(w, r, l.md.responseHeader)
if err != nil {
l.logger.Error(err)
return
}
select {
case l.connChan <- utils.WebsocketServerConn(conn):
default:
conn.Close()
l.logger.Warn("connection queue is full")
}
}

View File

@ -0,0 +1,40 @@
package ws
import (
"crypto/tls"
"net/http"
"time"
)
const (
addr = "addr"
path = "path"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
enableCompression = "enableCompression"
responseHeader = "responseHeader"
connQueueSize = "connQueueSize"
)
const (
defaultPath = "/ws"
defaultQueueSize = 128
)
type metadata struct {
addr string
path string
tlsConfig *tls.Config
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
readBufferSize int
writeBufferSize int
enableCompression bool
responseHeader http.Header
connQueueSize int
}

View File

@ -0,0 +1,178 @@
package mux
import (
"crypto/tls"
"errors"
"net"
"net/http"
"github.com/go-gost/gost/pkg/components/internal/utils"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
"github.com/gorilla/websocket"
"github.com/xtaci/smux"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
addr net.Addr
upgrader *websocket.Upgrader
srv *http.Server
connChan chan net.Conn
errChan chan error
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
l.upgrader = &websocket.Upgrader{
HandshakeTimeout: l.md.handshakeTimeout,
ReadBufferSize: l.md.readBufferSize,
WriteBufferSize: l.md.writeBufferSize,
CheckOrigin: func(r *http.Request) bool { return true },
EnableCompression: l.md.enableCompression,
}
path := l.md.path
if path == "" {
path = defaultPath
}
mux := http.NewServeMux()
mux.Handle(path, http.HandlerFunc(l.upgrade))
l.srv = &http.Server{
Addr: l.md.addr,
TLSConfig: l.md.tlsConfig,
Handler: mux,
ReadHeaderTimeout: l.md.readHeaderTimeout,
}
l.connChan = make(chan net.Conn, l.md.connQueueSize)
l.errChan = make(chan error, 1)
ln, err := net.Listen("tcp", l.md.addr)
if err != nil {
return
}
if l.md.tlsConfig != nil {
ln = tls.NewListener(ln, l.md.tlsConfig)
}
l.addr = ln.Addr()
go func() {
err := l.srv.Serve(ln)
if err != nil {
l.errChan <- err
}
close(l.errChan)
}()
return
}
func (l *Listener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *Listener) Close() error {
return l.srv.Close()
}
func (l *Listener) Addr() net.Addr {
return l.addr
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
if err != nil {
return
}
return
}
func (l *Listener) upgrade(w http.ResponseWriter, r *http.Request) {
conn, err := l.upgrader.Upgrade(w, r, l.md.responseHeader)
if err != nil {
l.logger.Error(err)
return
}
l.mux(utils.WebsocketServerConn(conn))
}
func (l *Listener) mux(conn net.Conn) {
smuxConfig := smux.DefaultConfig()
smuxConfig.KeepAliveDisabled = l.md.muxKeepAliveDisabled
if l.md.muxKeepAlivePeriod > 0 {
smuxConfig.KeepAliveInterval = l.md.muxKeepAlivePeriod
}
if l.md.muxKeepAliveTimeout > 0 {
smuxConfig.KeepAliveTimeout = l.md.muxKeepAliveTimeout
}
if l.md.muxMaxFrameSize > 0 {
smuxConfig.MaxFrameSize = l.md.muxMaxFrameSize
}
if l.md.muxMaxReceiveBuffer > 0 {
smuxConfig.MaxReceiveBuffer = l.md.muxMaxReceiveBuffer
}
if l.md.muxMaxStreamBuffer > 0 {
smuxConfig.MaxStreamBuffer = l.md.muxMaxStreamBuffer
}
session, err := smux.Server(conn, smuxConfig)
if err != nil {
l.logger.Error(err)
return
}
defer session.Close()
for {
stream, err := session.AcceptStream()
if err != nil {
l.logger.Error("accept stream:", err)
return
}
select {
case l.connChan <- stream:
case <-stream.GetDieCh():
stream.Close()
default:
stream.Close()
l.logger.Error("connection queue is full")
}
}
}

View File

@ -0,0 +1,54 @@
package mux
import (
"crypto/tls"
"net/http"
"time"
)
const (
addr = "addr"
path = "path"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
enableCompression = "enableCompression"
responseHeader = "responseHeader"
connQueueSize = "connQueueSize"
muxKeepAliveDisabled = "muxKeepAliveDisabled"
muxKeepAlivePeriod = "muxKeepAlivePeriod"
muxKeepAliveTimeout = "muxKeepAliveTimeout"
muxMaxFrameSize = "muxMaxFrameSize"
muxMaxReceiveBuffer = "muxMaxReceiveBuffer"
muxMaxStreamBuffer = "muxMaxStreamBuffer"
)
const (
defaultPath = "/ws"
defaultQueueSize = 128
)
type metadata struct {
addr string
path string
tlsConfig *tls.Config
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
readBufferSize int
writeBufferSize int
enableCompression bool
responseHeader http.Header
muxKeepAliveDisabled bool
muxKeepAlivePeriod time.Duration
muxKeepAliveTimeout time.Duration
muxMaxFrameSize int
muxMaxReceiveBuffer int
muxMaxStreamBuffer int
connQueueSize int
}