add chain

This commit is contained in:
ginuerzh
2021-10-26 21:07:46 +08:00
parent ce13b2a82a
commit 3351aa5974
78 changed files with 917 additions and 185 deletions

33
pkg/chain/chain.go Normal file
View File

@ -0,0 +1,33 @@
package chain
type Chain struct {
groups []*NodeGroup
}
func (c *Chain) AddNodeGroup(group *NodeGroup) {
c.groups = append(c.groups, group)
}
func (c *Chain) GetRoute() (r *Route) {
if c == nil || len(c.groups) == 0 {
return
}
r = &Route{}
for _, group := range c.groups {
node := group.Next()
if node == nil {
return
}
// TODO: bypass
if node.Transport().IsMultiplex() {
tr := node.Transport().WithRoute(r)
node = node.WithTransport(tr)
r = &Route{}
}
r.AddNode(node)
}
return r
}

59
pkg/chain/node.go Normal file
View File

@ -0,0 +1,59 @@
package chain
type Node struct {
name string
addr string
transport *Transport
}
func NewNode(name, addr string) *Node {
return &Node{
name: name,
addr: addr,
}
}
func (node *Node) Name() string {
return node.name
}
func (node *Node) Addr() string {
return node.addr
}
func (node *Node) Transport() *Transport {
return node.transport
}
func (node *Node) WithTransport(tr *Transport) *Node {
node.transport = tr
return node
}
type NodeGroup struct {
nodes []*Node
selector Selector
}
func NewNodeGroup(nodes ...*Node) *NodeGroup {
return &NodeGroup{
nodes: nodes,
}
}
func (g *NodeGroup) AddNode(node *Node) {
g.nodes = append(g.nodes, node)
}
func (g *NodeGroup) WithSelector(selector Selector) {
g.selector = selector
}
func (g *NodeGroup) Next() *Node {
selector := g.selector
if selector == nil {
// selector = defaultSelector
return g.nodes[0]
}
return selector.Select(g.nodes...)
}

93
pkg/chain/route.go Normal file
View File

@ -0,0 +1,93 @@
package chain
import (
"context"
"errors"
"net"
)
type Route struct {
nodes []*Node
}
func (r *Route) AddNode(node *Node) {
r.nodes = append(r.nodes, node)
}
func (r *Route) Connect(ctx context.Context) (conn net.Conn, err error) {
if r.IsEmpty() {
return nil, errors.New("empty route")
}
node := r.nodes[0]
cc, err := node.Transport().Dial(ctx, r.nodes[0].Addr())
if err != nil {
return
}
cn, err := node.Transport().Handshake(ctx, cc)
if err != nil {
cc.Close()
return
}
preNode := node
for _, node := range r.nodes[1:] {
cc, err = preNode.Transport().Connect(ctx, cn, "tcp", node.Addr())
if err != nil {
cn.Close()
return
}
cc, err = node.transport.Handshake(ctx, cc)
if err != nil {
cn.Close()
}
cn = cc
preNode = node
}
conn = cn
return
}
func (r *Route) Dial(ctx context.Context, network, address string) (net.Conn, error) {
if r.IsEmpty() {
return r.dialDirect(ctx, network, address)
}
conn, err := r.Connect(ctx)
if err != nil {
return nil, err
}
cc, err := r.Last().Transport().Connect(ctx, conn, network, address)
if err != nil {
conn.Close()
return nil, err
}
return cc, nil
}
func (r *Route) dialDirect(ctx context.Context, network, address string) (net.Conn, error) {
switch network {
case "udp", "udp4", "udp6":
if address == "" {
return net.ListenUDP(network, nil)
}
default:
}
d := &net.Dialer{}
return d.DialContext(ctx, network, address)
}
func (r *Route) IsEmpty() bool {
return r == nil || len(r.nodes) == 0
}
func (r Route) Last() *Node {
if r.IsEmpty() {
return nil
}
return r.nodes[len(r.nodes)-1]
}

41
pkg/chain/selector.go Normal file
View File

@ -0,0 +1,41 @@
package chain
var (
defaultSelector Selector = NewSelector(nil)
)
type Filter interface {
Filter(nodes ...*Node) []*Node
String() string
}
type Strategy interface {
Apply(nodes ...*Node) *Node
String() string
}
type Selector interface {
Select(nodes ...*Node) *Node
}
type selector struct {
strategy Strategy
filters []Filter
}
func NewSelector(strategy Strategy, filters ...Filter) Selector {
return &selector{
filters: filters,
strategy: strategy,
}
}
func (s *selector) Select(nodes ...*Node) *Node {
for _, filter := range s.filters {
nodes = filter.Filter(nodes...)
}
if len(nodes) == 0 {
return nil
}
return s.strategy.Apply(nodes...)
}

66
pkg/chain/transport.go Normal file
View File

@ -0,0 +1,66 @@
package chain
import (
"context"
"net"
"github.com/go-gost/gost/pkg/components/connector"
"github.com/go-gost/gost/pkg/components/dialer"
)
type Transport struct {
route *Route
dialer dialer.Dialer
connector connector.Connector
}
func (tr *Transport) WithDialer(dialer dialer.Dialer) *Transport {
tr.dialer = dialer
return tr
}
func (tr *Transport) WithConnector(connector connector.Connector) *Transport {
tr.connector = connector
return tr
}
func (tr *Transport) Dial(ctx context.Context, addr string) (net.Conn, error) {
return tr.dialer.Dial(ctx, addr, tr.dialOptions()...)
}
func (tr *Transport) dialOptions() []dialer.DialOption {
var opts []dialer.DialOption
if tr.route != nil {
opts = append(opts,
dialer.DialFuncDialOption(
func(ctx context.Context, addr string) (net.Conn, error) {
return tr.route.Dial(ctx, "tcp", addr)
},
),
)
}
return opts
}
func (tr *Transport) Handshake(ctx context.Context, conn net.Conn) (net.Conn, error) {
if hs, ok := tr.dialer.(dialer.Handshaker); ok {
return hs.Handshake(ctx, conn)
}
return conn, nil
}
func (tr *Transport) Connect(ctx context.Context, conn net.Conn, network, address string) (net.Conn, error) {
return tr.connector.Connect(ctx, conn, network, address)
}
func (tr *Transport) IsMultiplex() bool {
if mux, ok := tr.dialer.(dialer.Multiplexer); ok {
return mux.IsMultiplex()
}
return false
}
func (tr *Transport) WithRoute(r *Route) *Transport {
tr.route = r
return tr
}

View File

@ -0,0 +1,11 @@
package connector
import (
"context"
"net"
)
// Connector is responsible for connecting to the destination address.
type Connector interface {
Connect(ctx context.Context, conn net.Conn, network, address string, opts ...ConnectOption) (net.Conn, error)
}

View File

@ -0,0 +1,99 @@
package http
import (
"bufio"
"context"
"encoding/base64"
"fmt"
"log"
"net"
"net/http"
"net/url"
"github.com/go-gost/gost/pkg/components/connector"
"github.com/go-gost/gost/pkg/logger"
)
var (
_ connector.Connector = (*Connector)(nil)
)
type Connector struct {
md metadata
logger logger.Logger
}
func NewConnector(opts ...connector.Option) *Connector {
options := &connector.Options{}
for _, opt := range opts {
opt(options)
}
return &Connector{
logger: options.Logger,
}
}
func (c *Connector) Init(md connector.Metadata) (err error) {
c.md, err = c.parseMetadata(md)
if err != nil {
return
}
return nil
}
func (c *Connector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) {
req := &http.Request{
Method: http.MethodConnect,
URL: &url.URL{Host: address},
Host: address,
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),
}
if c.md.UserAgent != "" {
log.Println(c.md.UserAgent)
req.Header.Set("User-Agent", c.md.UserAgent)
}
req.Header.Set("Proxy-Connection", "keep-alive")
if user := c.md.User; user != nil {
u := user.Username()
p, _ := user.Password()
req.Header.Set("Proxy-Authorization",
"Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p)))
}
req = req.WithContext(ctx)
if err := req.Write(conn); err != nil {
return nil, err
}
resp, err := http.ReadResponse(bufio.NewReader(conn), req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%s", resp.Status)
}
return conn, nil
}
func (c *Connector) parseMetadata(md connector.Metadata) (m metadata, err error) {
if md == nil {
md = connector.Metadata{}
}
m.UserAgent = md[userAgent]
if m.UserAgent == "" {
m.UserAgent = defaultUserAgent
}
if v, ok := md[username]; ok {
m.User = url.UserPassword(v, md[password])
}
return
}

View File

@ -0,0 +1,18 @@
package http
import "net/url"
const (
userAgent = "userAgent"
username = "username"
password = "password"
)
const (
defaultUserAgent = "Chrome/78.0.3904.106"
)
type metadata struct {
UserAgent string
User *url.Userinfo
}

View File

@ -0,0 +1,3 @@
package connector
type Metadata map[string]string

View File

@ -0,0 +1,22 @@
package connector
import (
"github.com/go-gost/gost/pkg/logger"
)
type Options struct {
Logger logger.Logger
}
type Option func(opts *Options)
func LoggerOption(logger logger.Logger) Option {
return func(opts *Options) {
opts.Logger = logger
}
}
type ConnectOptions struct {
}
type ConnectOption func(opts *ConnectOptions)

View File

@ -0,0 +1,54 @@
package ss
import (
"context"
"net"
"github.com/go-gost/gost/pkg/components/connector"
"github.com/go-gost/gost/pkg/logger"
)
var (
_ connector.Connector = (*Connector)(nil)
)
type Connector struct {
md metadata
logger logger.Logger
}
func NewConnector(opts ...connector.Option) *Connector {
options := &connector.Options{}
for _, opt := range opts {
opt(options)
}
return &Connector{
logger: options.Logger,
}
}
func (c *Connector) Init(md connector.Metadata) (err error) {
c.md, err = c.parseMetadata(md)
if err != nil {
return
}
return nil
}
func (c *Connector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) {
return conn, nil
}
func (c *Connector) parseMetadata(md connector.Metadata) (m metadata, err error) {
if md == nil {
md = connector.Metadata{}
}
m.method = md[method]
m.password = md[password]
return
}

View File

@ -0,0 +1,11 @@
package ss
const (
method = "method"
password = "password"
)
type metadata struct {
method string
password string
}

View File

@ -0,0 +1,19 @@
package dialer
import (
"context"
"net"
)
// Transporter is responsible for dialing to the proxy server.
type Dialer interface {
Dial(ctx context.Context, addr string, opts ...DialOption) (net.Conn, error)
}
type Handshaker interface {
Handshake(ctx context.Context, conn net.Conn) (net.Conn, error)
}
type Multiplexer interface {
IsMultiplex() bool
}

View File

@ -0,0 +1,3 @@
package dialer
type Metadata map[string]string

View File

@ -0,0 +1,32 @@
package dialer
import (
"context"
"net"
"github.com/go-gost/gost/pkg/logger"
)
type Options struct {
Logger logger.Logger
}
type Option func(opts *Options)
func LoggerOption(logger logger.Logger) Option {
return func(opts *Options) {
opts.Logger = logger
}
}
type DialOptions struct {
DialFunc func(ctx context.Context, addr string) (net.Conn, error)
}
type DialOption func(opts *DialOptions)
func DialFuncDialOption(dialf func(ctx context.Context, addr string) (net.Conn, error)) DialOption {
return func(opts *DialOptions) {
opts.DialFunc = dialf
}
}

View File

@ -0,0 +1,56 @@
package tcp
import (
"context"
"net"
"github.com/go-gost/gost/pkg/components/dialer"
"github.com/go-gost/gost/pkg/logger"
)
var (
_ dialer.Dialer = (*Dialer)(nil)
)
type Dialer struct {
md metadata
logger logger.Logger
}
func NewDialer(opts ...dialer.Option) *Dialer {
options := &dialer.Options{}
for _, opt := range opts {
opt(options)
}
return &Dialer{
logger: options.Logger,
}
}
func (d *Dialer) Init(md dialer.Metadata) (err error) {
d.md, err = d.parseMetadata(md)
if err != nil {
return
}
return nil
}
func (d *Dialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) {
var options dialer.DialOptions
for _, opt := range opts {
opt(&options)
}
dial := options.DialFunc
if dial != nil {
return dial(ctx, addr)
}
var netd net.Dialer
return netd.DialContext(ctx, "tcp", addr)
}
func (d *Dialer) parseMetadata(md dialer.Metadata) (m metadata, err error) {
return
}

View File

@ -0,0 +1,15 @@
package tcp
import "time"
const (
dialTimeout = "dialTimeout"
)
const (
defaultDialTimeout = 5 * time.Second
)
type metadata struct {
dialTimeout time.Duration
}

View File

@ -0,0 +1,10 @@
package handler
import (
"context"
"net"
)
type Handler interface {
Handle(context.Context, net.Conn)
}

View File

@ -0,0 +1,225 @@
package http
import (
"bufio"
"context"
"net"
"net/http"
"github.com/go-gost/gost/pkg/chain"
"github.com/go-gost/gost/pkg/components/handler"
"github.com/go-gost/gost/pkg/logger"
)
var (
_ handler.Handler = (*Handler)(nil)
)
type Handler struct {
chain *chain.Chain
logger logger.Logger
md metadata
}
func NewHandler(opts ...handler.Option) *Handler {
options := &handler.Options{}
for _, opt := range opts {
opt(options)
}
return &Handler{
chain: options.Chain,
logger: options.Logger,
}
}
func (h *Handler) Init(md handler.Metadata) error {
return nil
}
func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
defer conn.Close()
req, err := http.ReadRequest(bufio.NewReader(conn))
if err != nil {
h.logger.WithFields(map[string]interface{}{
"src": conn.RemoteAddr(),
"local": conn.LocalAddr(),
}).Error(err)
return
}
defer req.Body.Close()
h.handleRequest(ctx, conn, req)
}
func (h *Handler) handleRequest(ctx context.Context, conn net.Conn, req *http.Request) {
if req == nil {
return
}
/*
// try to get the actual host.
if v := req.Header.Get("Gost-Target"); v != "" {
if h, err := decodeServerName(v); err == nil {
req.Host = h
}
}
*/
host := req.Host
if _, port, _ := net.SplitHostPort(host); port == "" {
host = net.JoinHostPort(host, "80")
}
/*
u, _, _ := basicProxyAuth(req.Header.Get("Proxy-Authorization"))
if u != "" {
u += "@"
}
log.Logf("[http] %s%s -> %s -> %s",
u, conn.RemoteAddr(), h.options.Node.String(), host)
if Debug {
dump, _ := httputil.DumpRequest(req, false)
log.Logf("[http] %s -> %s\n%s", conn.RemoteAddr(), conn.LocalAddr(), string(dump))
}
req.Header.Del("Gost-Target")
*/
resp := &http.Response{
ProtoMajor: 1,
ProtoMinor: 1,
Header: http.Header{},
}
if h.md.proxyAgent != "" {
resp.Header.Add("Proxy-Agent", h.md.proxyAgent)
}
/*
if !Can("tcp", host, h.options.Whitelist, h.options.Blacklist) {
log.Logf("[http] %s - %s : Unauthorized to tcp connect to %s",
conn.RemoteAddr(), conn.LocalAddr(), host)
resp.StatusCode = http.StatusForbidden
if Debug {
dump, _ := httputil.DumpResponse(resp, false)
log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), conn.LocalAddr(), string(dump))
}
resp.Write(conn)
return
}
*/
/*
if h.options.Bypass.Contains(host) {
resp.StatusCode = http.StatusForbidden
log.Logf("[http] %s - %s bypass %s",
conn.RemoteAddr(), conn.LocalAddr(), host)
if Debug {
dump, _ := httputil.DumpResponse(resp, false)
log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), conn.LocalAddr(), string(dump))
}
resp.Write(conn)
return
}
*/
/*
if !h.authenticate(conn, req, resp) {
return
}
*/
if req.Method == "PRI" ||
(req.Method != http.MethodConnect && req.URL.Scheme != "http") {
resp.StatusCode = http.StatusBadRequest
/*
if Debug {
dump, _ := httputil.DumpResponse(resp, false)
log.Logf("[http] %s <- %s\n%s",
conn.RemoteAddr(), conn.LocalAddr(), string(dump))
}
*/
resp.Write(conn)
return
}
req.Header.Del("Proxy-Authorization")
cc, err := h.dial(ctx, host)
if err != nil {
resp.StatusCode = http.StatusServiceUnavailable
/*
if Debug {
dump, _ := httputil.DumpResponse(resp, false)
log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), conn.LocalAddr(), string(dump))
}
*/
resp.Write(conn)
return
}
defer cc.Close()
if req.Method == http.MethodConnect {
resp.StatusCode = http.StatusOK
resp.Status = "200 Connection established"
resp.Write(conn)
} else {
req.Header.Del("Proxy-Connection")
if err = req.Write(cc); err != nil {
return
}
}
handler.Transport(conn, cc)
}
func (h *Handler) dial(ctx context.Context, addr string) (conn net.Conn, err error) {
count := h.md.retryCount + 1
if count <= 0 {
count = 1
}
for i := 0; i < count; i++ {
route := h.chain.GetRoute()
/*
buf := bytes.Buffer{}
fmt.Fprintf(&buf, "%s -> %s -> ",
conn.RemoteAddr(), h.options.Node.String())
for _, nd := range route.route {
fmt.Fprintf(&buf, "%d@%s -> ", nd.ID, nd.String())
}
fmt.Fprintf(&buf, "%s", host)
log.Log("[route]", buf.String())
*/
/*
// forward http request
lastNode := route.LastNode()
if req.Method != http.MethodConnect && lastNode.Protocol == "http" {
err = h.forwardRequest(conn, req, route)
if err == nil {
return
}
log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err)
continue
}
*/
conn, err = route.Dial(ctx, "tcp", addr)
if err != nil {
continue
}
}
return
}

View File

@ -0,0 +1,7 @@
package http
type metadata struct {
addr string
proxyAgent string
retryCount int
}

View File

@ -0,0 +1,3 @@
package handler
type Metadata map[string]string

View File

@ -0,0 +1,25 @@
package handler
import (
"github.com/go-gost/gost/pkg/chain"
"github.com/go-gost/gost/pkg/logger"
)
type Options struct {
Chain *chain.Chain
Logger logger.Logger
}
type Option func(opts *Options)
func LoggerOption(logger logger.Logger) Option {
return func(opts *Options) {
opts.Logger = logger
}
}
func ChainOption(chain *chain.Chain) Option {
return func(opts *Options) {
opts.Chain = chain
}
}

View File

@ -0,0 +1,129 @@
package ss
import (
"bytes"
"context"
"net"
"time"
"github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/components/handler"
"github.com/go-gost/gost/pkg/logger"
"github.com/shadowsocks/go-shadowsocks2/core"
ss "github.com/shadowsocks/shadowsocks-go/shadowsocks"
)
var (
_ handler.Handler = (*Handler)(nil)
)
type Handler struct {
logger logger.Logger
md metadata
}
func NewHandler(opts ...handler.Option) *Handler {
options := &handler.Options{}
for _, opt := range opts {
opt(options)
}
return &Handler{
logger: options.Logger,
}
}
func (h *Handler) Init(md handler.Metadata) (err error) {
h.md, err = h.parseMetadata(md)
if err != nil {
return
}
return nil
}
func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
defer conn.Close()
if h.md.cipher != nil {
conn = &shadowConn{
Conn: h.md.cipher.StreamConn(conn),
}
}
if h.md.readTimeout > 0 {
conn.SetReadDeadline(time.Now().Add(h.md.readTimeout))
}
addr := &gosocks5.Addr{}
_, err := addr.ReadFrom(conn)
if err != nil {
h.logger.Error(err)
return
}
conn.SetReadDeadline(time.Time{})
host := addr.String()
cc, err := net.Dial("tcp", host)
if err != nil {
return
}
defer cc.Close()
handler.Transport(conn, cc)
}
func (h *Handler) parseMetadata(md handler.Metadata) (m metadata, err error) {
m.cipher, err = h.initCipher(md[method], md[password], md[key])
if err != nil {
return
}
if v, ok := md[readTimeout]; ok {
m.readTimeout, _ = time.ParseDuration(v)
}
return
}
func (h *Handler) initCipher(method, password string, key string) (core.Cipher, error) {
if method == "" && password == "" {
return nil, nil
}
c, _ := ss.NewCipher(method, password)
if c != nil {
return &shadowCipher{cipher: c}, nil
}
return core.PickCipher(method, []byte(key), password)
}
type shadowCipher struct {
cipher *ss.Cipher
}
func (c *shadowCipher) StreamConn(conn net.Conn) net.Conn {
return ss.NewConn(conn, c.cipher.Copy())
}
func (c *shadowCipher) PacketConn(conn net.PacketConn) net.PacketConn {
return ss.NewSecurePacketConn(conn, c.cipher.Copy())
}
// Due to in/out byte length is inconsistent of the shadowsocks.Conn.Write,
// we wrap around it to make io.Copy happy.
type shadowConn struct {
net.Conn
wbuf bytes.Buffer
}
func (c *shadowConn) Write(b []byte) (n int, err error) {
n = len(b) // force byte length consistent
if c.wbuf.Len() > 0 {
c.wbuf.Write(b) // append the data to the cached header
_, err = c.Conn.Write(c.wbuf.Bytes())
c.wbuf.Reset()
return
}
_, err = c.Conn.Write(b)
return
}

View File

@ -0,0 +1,19 @@
package ss
import (
"time"
"github.com/shadowsocks/go-shadowsocks2/core"
)
const (
method = "method"
password = "password"
key = "key"
readTimeout = "readTimeout"
)
type metadata struct {
cipher core.Cipher
readTimeout time.Duration
}

View File

@ -0,0 +1,80 @@
package ss
import (
"context"
"net"
"time"
"github.com/go-gost/gost/pkg/components/handler"
"github.com/go-gost/gost/pkg/logger"
"github.com/shadowsocks/go-shadowsocks2/core"
ss "github.com/shadowsocks/shadowsocks-go/shadowsocks"
)
var (
_ handler.Handler = (*Handler)(nil)
)
type Handler struct {
logger logger.Logger
md metadata
}
func NewHandler(opts ...handler.Option) *Handler {
options := &handler.Options{}
for _, opt := range opts {
opt(options)
}
return &Handler{
logger: options.Logger,
}
}
func (h *Handler) Init(md handler.Metadata) (err error) {
h.md, err = h.parseMetadata(md)
if err != nil {
return
}
return nil
}
func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
defer conn.Close()
}
func (h *Handler) parseMetadata(md handler.Metadata) (m metadata, err error) {
m.cipher, err = h.initCipher(md[method], md[password], md[key])
if err != nil {
return
}
if v, ok := md[readTimeout]; ok {
m.readTimeout, _ = time.ParseDuration(v)
}
return
}
func (h *Handler) initCipher(method, password string, key string) (core.Cipher, error) {
if method == "" && password == "" {
return nil, nil
}
c, _ := ss.NewCipher(method, password)
if c != nil {
return &shadowCipher{cipher: c}, nil
}
return core.PickCipher(method, []byte(key), password)
}
type shadowCipher struct {
cipher *ss.Cipher
}
func (c *shadowCipher) StreamConn(conn net.Conn) net.Conn {
return ss.NewConn(conn, c.cipher.Copy())
}
func (c *shadowCipher) PacketConn(conn net.PacketConn) net.PacketConn {
return ss.NewSecurePacketConn(conn, c.cipher.Copy())
}

View File

@ -0,0 +1,19 @@
package ss
import (
"time"
"github.com/shadowsocks/go-shadowsocks2/core"
)
const (
method = "method"
password = "password"
key = "key"
readTimeout = "readTimeout"
)
type metadata struct {
cipher core.Cipher
readTimeout time.Duration
}

View File

@ -0,0 +1,43 @@
package handler
import (
"io"
"sync"
)
const (
poolBufferSize = 32 * 1024
)
var (
pool = sync.Pool{
New: func() interface{} {
return make([]byte, poolBufferSize)
},
}
)
func Transport(rw1, rw2 io.ReadWriter) error {
errc := make(chan error, 1)
go func() {
errc <- copyBuffer(rw1, rw2)
}()
go func() {
errc <- copyBuffer(rw2, rw1)
}()
err := <-errc
if err != nil && err == io.EOF {
err = nil
}
return err
}
func copyBuffer(dst io.Writer, src io.Reader) error {
buf := pool.Get().([]byte)
defer pool.Put(buf)
_, err := io.CopyBuffer(dst, src, buf)
return err
}

View File

@ -0,0 +1,34 @@
package utils
import (
"net"
"github.com/golang/snappy"
)
type kcpCompStreamConn struct {
net.Conn
w *snappy.Writer
r *snappy.Reader
}
func KCPCompStreamConn(conn net.Conn) net.Conn {
return &kcpCompStreamConn{
Conn: conn,
w: snappy.NewBufferedWriter(conn),
r: snappy.NewReader(conn),
}
}
func (c *kcpCompStreamConn) Read(b []byte) (n int, err error) {
return c.r.Read(b)
}
func (c *kcpCompStreamConn) Write(b []byte) (n int, err error) {
n, err = c.w.Write(b)
if err != nil {
return
}
err = c.w.Flush()
return n, err
}

View File

@ -0,0 +1,104 @@
package utils
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"errors"
"io"
"net"
"github.com/lucas-clemente/quic-go"
)
type quicConn struct {
quic.Session
quic.Stream
}
func QUICConn(session quic.Session, stream quic.Stream) net.Conn {
return &quicConn{
Session: session,
Stream: stream,
}
}
type quicCipherConn struct {
net.PacketConn
key []byte
}
func QUICCipherConn(conn net.PacketConn, key []byte) net.PacketConn {
return &quicCipherConn{
PacketConn: conn,
key: key,
}
}
func (conn *quicCipherConn) ReadFrom(data []byte) (n int, addr net.Addr, err error) {
n, addr, err = conn.PacketConn.ReadFrom(data)
if err != nil {
return
}
b, err := conn.decrypt(data[:n])
if err != nil {
return
}
copy(data, b)
return len(b), addr, nil
}
func (conn *quicCipherConn) WriteTo(data []byte, addr net.Addr) (n int, err error) {
b, err := conn.encrypt(data)
if err != nil {
return
}
_, err = conn.PacketConn.WriteTo(b, addr)
if err != nil {
return
}
return len(b), nil
}
func (conn *quicCipherConn) encrypt(data []byte) ([]byte, error) {
c, err := aes.NewCipher(conn.key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(c)
if err != nil {
return nil, err
}
nonce := make([]byte, gcm.NonceSize())
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
}
return gcm.Seal(nonce, nonce, data, nil), nil
}
func (conn *quicCipherConn) decrypt(data []byte) ([]byte, error) {
c, err := aes.NewCipher(conn.key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(c)
if err != nil {
return nil, err
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return nil, errors.New("ciphertext too short")
}
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
return gcm.Open(nil, nonce, ciphertext, nil)
}

View File

@ -0,0 +1,32 @@
package utils
import (
"net"
"time"
)
const (
defaultKeepAlivePeriod = 180 * time.Second
)
// TCPKeepAliveListener is a TCP listener with keep alive enabled.
type TCPKeepAliveListener struct {
KeepAlivePeriod time.Duration
*net.TCPListener
}
func (l *TCPKeepAliveListener) Accept() (c net.Conn, err error) {
tc, err := l.AcceptTCP()
if err != nil {
return
}
tc.SetKeepAlive(true)
period := l.KeepAlivePeriod
if period <= 0 {
period = defaultKeepAlivePeriod
}
tc.SetKeepAlivePeriod(period)
return tc, nil
}

View File

@ -0,0 +1,40 @@
package utils
import (
"crypto/tls"
"crypto/x509"
"errors"
"io/ioutil"
)
// LoadTLSConfig loads the certificate from cert & key files and optional client CA file.
func LoadTLSConfig(certFile, keyFile, caFile string) (*tls.Config, error) {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
cfg := &tls.Config{Certificates: []tls.Certificate{cert}}
if pool, _ := loadCA(caFile); pool != nil {
cfg.ClientCAs = pool
cfg.ClientAuth = tls.RequireAndVerifyClientCert
}
return cfg, nil
}
func loadCA(caFile string) (cp *x509.CertPool, err error) {
if caFile == "" {
return
}
cp = x509.NewCertPool()
data, err := ioutil.ReadFile(caFile)
if err != nil {
return nil, err
}
if !cp.AppendCertsFromPEM(data) {
return nil, errors.New("AppendCertsFromPEM failed")
}
return
}

View File

@ -0,0 +1,41 @@
package utils
import (
"net"
"time"
"github.com/gorilla/websocket"
)
type websocketConn struct {
*websocket.Conn
rb []byte
}
func WebsocketServerConn(conn *websocket.Conn) net.Conn {
return &websocketConn{
Conn: conn,
}
}
func (c *websocketConn) Read(b []byte) (n int, err error) {
if len(c.rb) == 0 {
_, c.rb, err = c.ReadMessage()
}
n = copy(b, c.rb)
c.rb = c.rb[n:]
return
}
func (c *websocketConn) Write(b []byte) (n int, err error) {
err = c.WriteMessage(websocket.BinaryMessage, b)
n = len(b)
return
}
func (c *websocketConn) SetDeadline(t time.Time) error {
if err := c.SetReadDeadline(t); err != nil {
return err
}
return c.SetWriteDeadline(t)
}

View File

@ -0,0 +1,115 @@
package ftcp
import (
"errors"
"net"
"sync"
"sync/atomic"
"time"
)
// serverConn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
type serverConn struct {
net.PacketConn
raddr net.Addr
rc chan []byte // data receive queue
fresh int32
closed chan struct{}
closeMutex sync.Mutex
config *serverConnConfig
}
type serverConnConfig struct {
ttl time.Duration
qsize int
onClose func()
}
func newServerConn(conn net.PacketConn, raddr net.Addr, cfg *serverConnConfig) *serverConn {
if conn == nil || raddr == nil {
return nil
}
if cfg == nil {
cfg = &serverConnConfig{}
}
c := &serverConn{
PacketConn: conn,
raddr: raddr,
rc: make(chan []byte, cfg.qsize),
closed: make(chan struct{}),
config: cfg,
}
go c.ttlWait()
return c
}
func (c *serverConn) send(b []byte) error {
select {
case c.rc <- b:
return nil
default:
return errors.New("queue is full")
}
}
func (c *serverConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *serverConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
select {
case bb := <-c.rc:
n = copy(b, bb)
atomic.StoreInt32(&c.fresh, 1)
case <-c.closed:
err = errors.New("read from closed connection")
return
}
addr = c.raddr
return
}
func (c *serverConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.raddr)
}
func (c *serverConn) Close() error {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()
select {
case <-c.closed:
return errors.New("connection is closed")
default:
if c.config.onClose != nil {
c.config.onClose()
}
close(c.closed)
}
return nil
}
func (c *serverConn) RemoteAddr() net.Addr {
return c.raddr
}
func (c *serverConn) ttlWait() {
ticker := time.NewTicker(c.config.ttl)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if !atomic.CompareAndSwapInt32(&c.fresh, 1, 0) {
c.Close()
return
}
case <-c.closed:
return
}
}
}

View File

@ -0,0 +1,162 @@
package ftcp
import (
"errors"
"net"
"sync"
"sync/atomic"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
"github.com/xtaci/tcpraw"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
conn net.PacketConn
connChan chan net.Conn
errChan chan error
connPool connPool
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
l.conn, err = tcpraw.Listen("tcp", addr)
if err != nil {
return
}
l.connChan = make(chan net.Conn, l.md.connQueueSize)
l.errChan = make(chan error, 1)
go l.listenLoop()
return
}
func (l *Listener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *Listener) Close() error {
err := l.conn.Close()
l.connPool.Range(func(k interface{}, v *serverConn) bool {
v.Close()
return true
})
return err
}
func (l *Listener) Addr() net.Addr {
return l.conn.LocalAddr()
}
func (l *Listener) listenLoop() {
for {
b := make([]byte, l.md.readBufferSize)
n, raddr, err := l.conn.ReadFrom(b)
if err != nil {
l.logger.Error("accept:", err)
l.errChan <- err
close(l.errChan)
return
}
conn, ok := l.connPool.Get(raddr.String())
if !ok {
conn = newServerConn(l.conn, raddr,
&serverConnConfig{
ttl: l.md.ttl,
qsize: l.md.readQueueSize,
onClose: func() {
l.connPool.Delete(raddr.String())
},
})
select {
case l.connChan <- conn:
l.connPool.Set(raddr.String(), conn)
default:
conn.Close()
l.logger.Error("connection queue is full")
}
}
if err := conn.send(b[:n]); err != nil {
l.logger.Warn("data discarded:", err)
}
l.logger.Debug("recv", n)
}
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
return
}
type connPool struct {
size int64
m sync.Map
}
func (p *connPool) Get(key interface{}) (conn *serverConn, ok bool) {
v, ok := p.m.Load(key)
if ok {
conn, ok = v.(*serverConn)
}
return
}
func (p *connPool) Set(key interface{}, conn *serverConn) {
p.m.Store(key, conn)
atomic.AddInt64(&p.size, 1)
}
func (p *connPool) Delete(key interface{}) {
p.m.Delete(key)
atomic.AddInt64(&p.size, -1)
}
func (p *connPool) Range(f func(key interface{}, value *serverConn) bool) {
p.m.Range(func(k, v interface{}) bool {
return f(k, v.(*serverConn))
})
}
func (p *connPool) Size() int64 {
return atomic.LoadInt64(&p.size)
}

View File

@ -0,0 +1,23 @@
package ftcp
import "time"
const (
defaultTTL = 60 * time.Second
defaultReadBufferSize = 1024
defaultReadQueueSize = 128
defaultConnQueueSize = 128
)
const (
addr = "addr"
)
type metadata struct {
addr string
ttl time.Duration
readBufferSize int
readQueueSize int
connQueueSize int
}

View File

@ -0,0 +1,54 @@
package http2
import (
"errors"
"net"
"net/http"
"time"
)
// a dummy HTTP2 server conn used by HTTP2 handler
type conn struct {
r *http.Request
w http.ResponseWriter
closed chan struct{}
}
func (c *conn) Read(b []byte) (n int, err error) {
return 0, &net.OpError{Op: "read", Net: "http2", Source: nil, Addr: nil, Err: errors.New("read not supported")}
}
func (c *conn) Write(b []byte) (n int, err error) {
return 0, &net.OpError{Op: "write", Net: "http2", Source: nil, Addr: nil, Err: errors.New("write not supported")}
}
func (c *conn) Close() error {
select {
case <-c.closed:
default:
close(c.closed)
}
return nil
}
func (c *conn) LocalAddr() net.Addr {
addr, _ := net.ResolveTCPAddr("tcp", c.r.Host)
return addr
}
func (c *conn) RemoteAddr() net.Addr {
addr, _ := net.ResolveTCPAddr("tcp", c.r.RemoteAddr)
return addr
}
func (c *conn) SetDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *conn) SetReadDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *conn) SetWriteDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}

View File

@ -0,0 +1,89 @@
package h2
import (
"errors"
"io"
"net"
"net/http"
"time"
)
// HTTP2 connection, wrapped up just like a net.Conn
type conn struct {
r io.Reader
w io.Writer
remoteAddr net.Addr
localAddr net.Addr
closed chan struct{}
}
func (c *conn) Read(b []byte) (n int, err error) {
return c.r.Read(b)
}
func (c *conn) Write(b []byte) (n int, err error) {
return c.w.Write(b)
}
func (c *conn) Close() (err error) {
select {
case <-c.closed:
return
default:
close(c.closed)
}
if rc, ok := c.r.(io.Closer); ok {
err = rc.Close()
}
if w, ok := c.w.(io.Closer); ok {
err = w.Close()
}
return
}
func (c *conn) LocalAddr() net.Addr {
return c.localAddr
}
func (c *conn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *conn) SetDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *conn) SetReadDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *conn) SetWriteDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
type flushWriter struct {
w io.Writer
}
func (fw flushWriter) Write(p []byte) (n int, err error) {
defer func() {
if r := recover(); r != nil {
if s, ok := r.(string); ok {
err = errors.New(s)
// log.Log("[http2]", err)
return
}
err = r.(error)
}
}()
n, err = fw.w.Write(p)
if err != nil {
// log.Log("flush writer:", err)
return
}
if f, ok := fw.w.(http.Flusher); ok {
f.Flush()
}
return
}

View File

@ -0,0 +1,186 @@
package h2
import (
"crypto/tls"
"errors"
"net"
"net/http"
"time"
"github.com/go-gost/gost/pkg/components/internal/utils"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
"golang.org/x/net/http2"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
net.Listener
md metadata
server *http2.Server
connChan chan net.Conn
errChan chan error
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
ln, err := net.Listen("tcp", l.md.addr)
if err != nil {
return
}
l.Listener = &utils.TCPKeepAliveListener{
TCPListener: ln.(*net.TCPListener),
KeepAlivePeriod: l.md.keepAlivePeriod,
}
// TODO: tune http2 server config
l.server = &http2.Server{
// MaxConcurrentStreams: 1000,
PermitProhibitedCipherSuites: true,
IdleTimeout: 5 * time.Minute,
}
queueSize := l.md.connQueueSize
if queueSize <= 0 {
queueSize = defaultQueueSize
}
l.connChan = make(chan net.Conn, queueSize)
l.errChan = make(chan error, 1)
go l.listenLoop()
return
}
func (l *Listener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *Listener) listenLoop() {
for {
conn, err := l.Listener.Accept()
if err != nil {
// log.Log("[http2] accept:", err)
l.errChan <- err
close(l.errChan)
return
}
go l.handleLoop(conn)
}
}
func (l *Listener) handleLoop(conn net.Conn) {
if l.md.tlsConfig != nil {
tlsConn := tls.Server(conn, l.md.tlsConfig)
// NOTE: HTTP2 server will check the TLS version,
// so we must ensure that the TLS connection is handshake completed.
if err := tlsConn.Handshake(); err != nil {
// log.Logf("[http2] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err)
return
}
conn = tlsConn
}
opt := http2.ServeConnOpts{
Handler: http.HandlerFunc(l.handleFunc),
}
l.server.ServeConn(conn, &opt)
}
func (l *Listener) handleFunc(w http.ResponseWriter, r *http.Request) {
/*
log.Logf("[http2] %s -> %s %s %s %s",
r.RemoteAddr, r.Host, r.Method, r.RequestURI, r.Proto)
if Debug {
dump, _ := httputil.DumpRequest(r, false)
log.Log("[http2]", string(dump))
}
*/
// w.Header().Set("Proxy-Agent", "gost/"+Version)
conn, err := l.upgrade(w, r)
if err != nil {
// log.Logf("[http2] %s - %s %s %s %s: %s",
// r.RemoteAddr, r.Host, r.Method, r.RequestURI, r.Proto, err)
return
}
select {
case l.connChan <- conn:
default:
conn.Close()
// log.Logf("[http2] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr())
}
<-conn.closed // NOTE: we need to wait for streaming end, or the connection will be closed
}
func (l *Listener) upgrade(w http.ResponseWriter, r *http.Request) (*conn, error) {
if l.md.path == "" && r.Method != http.MethodConnect {
w.WriteHeader(http.StatusMethodNotAllowed)
return nil, errors.New("method not allowed")
}
if l.md.path != "" && r.RequestURI != l.md.path {
w.WriteHeader(http.StatusBadRequest)
return nil, errors.New("bad request")
}
w.WriteHeader(http.StatusOK)
if fw, ok := w.(http.Flusher); ok {
fw.Flush() // write header to client
}
remoteAddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
if remoteAddr == nil {
remoteAddr = &net.TCPAddr{
IP: net.IPv4zero,
Port: 0,
}
}
return &conn{
r: r.Body,
w: flushWriter{w},
localAddr: l.Listener.Addr(),
remoteAddr: remoteAddr,
closed: make(chan struct{}),
}, nil
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
if err != nil {
return
}
return
}

View File

@ -0,0 +1,38 @@
package h2
import (
"crypto/tls"
"net/http"
"time"
)
const (
addr = "addr"
path = "path"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
connQueueSize = "connQueueSize"
)
const (
defaultQueueSize = 128
)
type metadata struct {
addr string
path string
tlsConfig *tls.Config
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
readBufferSize int
writeBufferSize int
enableCompression bool
responseHeader http.Header
connQueueSize int
keepAlivePeriod time.Duration
}

View File

@ -0,0 +1,140 @@
package http2
import (
"crypto/tls"
"errors"
"net"
"net/http"
"github.com/go-gost/gost/pkg/components/internal/utils"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
"golang.org/x/net/http2"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
server *http.Server
addr net.Addr
connChan chan *conn
errChan chan error
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
l.server = &http.Server{
Addr: l.md.addr,
Handler: http.HandlerFunc(l.handleFunc),
TLSConfig: l.md.tlsConfig,
}
if err := http2.ConfigureServer(l.server, nil); err != nil {
return err
}
ln, err := net.Listen("tcp", addr)
if err != nil {
return err
}
l.addr = ln.Addr()
ln = tls.NewListener(
&utils.TCPKeepAliveListener{
TCPListener: ln.(*net.TCPListener),
KeepAlivePeriod: l.md.keepAlivePeriod,
},
l.md.tlsConfig,
)
queueSize := l.md.connQueueSize
if queueSize <= 0 {
queueSize = defaultQueueSize
}
l.connChan = make(chan *conn, queueSize)
l.errChan = make(chan error, 1)
go func() {
if err := l.server.Serve(ln); err != nil {
// log.Log("[http2]", err)
}
}()
return
}
func (l *Listener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *Listener) Addr() net.Addr {
return l.addr
}
func (l *Listener) Close() (err error) {
select {
case <-l.errChan:
default:
err = l.server.Close()
l.errChan <- err
close(l.errChan)
}
return nil
}
func (l *Listener) handleFunc(w http.ResponseWriter, r *http.Request) {
conn := &conn{
r: r,
w: w,
closed: make(chan struct{}),
}
select {
case l.connChan <- conn:
default:
// log.Logf("[http2] %s - %s: connection queue is full", r.RemoteAddr, l.server.Addr)
return
}
<-conn.closed
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
if err != nil {
return
}
return
}

View File

@ -0,0 +1,38 @@
package http2
import (
"crypto/tls"
"net/http"
"time"
)
const (
addr = "addr"
path = "path"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
connQueueSize = "connQueueSize"
)
const (
defaultQueueSize = 128
)
type metadata struct {
addr string
path string
tlsConfig *tls.Config
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
readBufferSize int
writeBufferSize int
enableCompression bool
responseHeader http.Header
connQueueSize int
keepAlivePeriod time.Duration
}

View File

@ -0,0 +1,115 @@
package kcp
import (
"crypto/sha1"
"github.com/xtaci/kcp-go/v5"
"golang.org/x/crypto/pbkdf2"
)
var (
// Salt is the default salt for KCP cipher.
Salt = "kcp-go"
)
var (
// DefaultKCPConfig is the default KCP config.
DefaultConfig = &Config{
Key: "it's a secrect",
Crypt: "aes",
Mode: "fast",
MTU: 1350,
SndWnd: 1024,
RcvWnd: 1024,
DataShard: 10,
ParityShard: 3,
DSCP: 0,
NoComp: false,
AckNodelay: false,
NoDelay: 0,
Interval: 50,
Resend: 0,
NoCongestion: 0,
SockBuf: 4194304,
KeepAlive: 10,
SnmpLog: "",
SnmpPeriod: 60,
Signal: false,
TCP: false,
}
)
// KCPConfig describes the config for KCP.
type Config struct {
Key string `json:"key"`
Crypt string `json:"crypt"`
Mode string `json:"mode"`
MTU int `json:"mtu"`
SndWnd int `json:"sndwnd"`
RcvWnd int `json:"rcvwnd"`
DataShard int `json:"datashard"`
ParityShard int `json:"parityshard"`
DSCP int `json:"dscp"`
NoComp bool `json:"nocomp"`
AckNodelay bool `json:"acknodelay"`
NoDelay int `json:"nodelay"`
Interval int `json:"interval"`
Resend int `json:"resend"`
NoCongestion int `json:"nc"`
SockBuf int `json:"sockbuf"`
KeepAlive int `json:"keepalive"`
SnmpLog string `json:"snmplog"`
SnmpPeriod int `json:"snmpperiod"`
Signal bool `json:"signal"` // Signal enables the signal SIGUSR1 feature.
TCP bool `json:"tcp"`
}
// Init initializes the KCP config.
func (c *Config) Init() {
switch c.Mode {
case "normal":
c.NoDelay, c.Interval, c.Resend, c.NoCongestion = 0, 40, 2, 1
case "fast":
c.NoDelay, c.Interval, c.Resend, c.NoCongestion = 0, 30, 2, 1
case "fast2":
c.NoDelay, c.Interval, c.Resend, c.NoCongestion = 1, 20, 2, 1
case "fast3":
c.NoDelay, c.Interval, c.Resend, c.NoCongestion = 1, 10, 2, 1
}
}
func blockCrypt(key, crypt, salt string) (block kcp.BlockCrypt) {
pass := pbkdf2.Key([]byte(key), []byte(salt), 4096, 32, sha1.New)
switch crypt {
case "sm4":
block, _ = kcp.NewSM4BlockCrypt(pass[:16])
case "tea":
block, _ = kcp.NewTEABlockCrypt(pass[:16])
case "xor":
block, _ = kcp.NewSimpleXORBlockCrypt(pass)
case "none":
block, _ = kcp.NewNoneBlockCrypt(pass)
case "aes-128":
block, _ = kcp.NewAESBlockCrypt(pass[:16])
case "aes-192":
block, _ = kcp.NewAESBlockCrypt(pass[:24])
case "blowfish":
block, _ = kcp.NewBlowfishBlockCrypt(pass)
case "twofish":
block, _ = kcp.NewTwofishBlockCrypt(pass)
case "cast5":
block, _ = kcp.NewCast5BlockCrypt(pass[:16])
case "3des":
block, _ = kcp.NewTripleDESBlockCrypt(pass[:24])
case "xtea":
block, _ = kcp.NewXTEABlockCrypt(pass[:16])
case "salsa20":
block, _ = kcp.NewSalsa20BlockCrypt(pass)
case "aes":
fallthrough
default: // aes
block, _ = kcp.NewAESBlockCrypt(pass)
}
return
}

View File

@ -0,0 +1,179 @@
package kcp
import (
"errors"
"net"
"time"
"github.com/go-gost/gost/pkg/components/internal/utils"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
"github.com/xtaci/kcp-go/v5"
"github.com/xtaci/smux"
"github.com/xtaci/tcpraw"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
ln *kcp.Listener
connChan chan net.Conn
errChan chan error
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
config := l.md.config
if config == nil {
config = DefaultConfig
}
config.Init()
var ln *kcp.Listener
if config.TCP {
var conn net.PacketConn
conn, err = tcpraw.Listen("tcp", addr)
if err != nil {
return
}
ln, err = kcp.ServeConn(
blockCrypt(config.Key, config.Crypt, Salt), config.DataShard, config.ParityShard, conn)
} else {
ln, err = kcp.ListenWithOptions(addr,
blockCrypt(config.Key, config.Crypt, Salt), config.DataShard, config.ParityShard)
}
if err != nil {
return
}
if config.DSCP > 0 {
if err = ln.SetDSCP(config.DSCP); err != nil {
l.logger.Warn(err)
}
}
if err = ln.SetReadBuffer(config.SockBuf); err != nil {
l.logger.Warn(err)
}
if err = ln.SetWriteBuffer(config.SockBuf); err != nil {
l.logger.Warn(err)
}
l.ln = ln
l.connChan = make(chan net.Conn, l.md.connQueueSize)
l.errChan = make(chan error, 1)
go l.listenLoop()
return
}
func (l *Listener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *Listener) Close() error {
return l.ln.Close()
}
func (l *Listener) Addr() net.Addr {
return l.ln.Addr()
}
func (l *Listener) listenLoop() {
for {
conn, err := l.ln.AcceptKCP()
if err != nil {
l.logger.Error("accept:", err)
l.errChan <- err
close(l.errChan)
return
}
conn.SetStreamMode(true)
conn.SetWriteDelay(false)
conn.SetNoDelay(
l.md.config.NoDelay,
l.md.config.Interval,
l.md.config.Resend,
l.md.config.NoCongestion,
)
conn.SetMtu(l.md.config.MTU)
conn.SetWindowSize(l.md.config.SndWnd, l.md.config.RcvWnd)
conn.SetACKNoDelay(l.md.config.AckNodelay)
go l.mux(conn)
}
}
func (l *Listener) mux(conn net.Conn) {
defer conn.Close()
smuxConfig := smux.DefaultConfig()
smuxConfig.MaxReceiveBuffer = l.md.config.SockBuf
smuxConfig.KeepAliveInterval = time.Duration(l.md.config.KeepAlive) * time.Second
if !l.md.config.NoComp {
conn = utils.KCPCompStreamConn(conn)
}
mux, err := smux.Server(conn, smuxConfig)
if err != nil {
l.logger.Error(err)
return
}
defer mux.Close()
for {
stream, err := mux.AcceptStream()
if err != nil {
l.logger.Error("accept stream:", err)
return
}
select {
case l.connChan <- stream:
case <-stream.GetDieCh():
stream.Close()
default:
stream.Close()
l.logger.Error("connection queue is full")
}
}
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
return
}

View File

@ -0,0 +1,18 @@
package kcp
const (
addr = "addr"
connQueueSize = "connQueueSize"
)
const (
defaultQueueSize = 128
)
type metadata struct {
addr string
config *Config
connQueueSize int
}

View File

@ -0,0 +1,20 @@
package listener
import (
"errors"
"net"
)
var (
ErrClosed = errors.New("accpet on closed listener")
)
// Listener is a server listener, just like a net.Listener.
type Listener interface {
net.Listener
}
// Accepter represents a network endpoint that can accept connection from peer.
type Accepter interface {
Accept() (net.Conn, error)
}

View File

@ -0,0 +1,3 @@
package listener
type Metadata map[string]string

View File

@ -0,0 +1,140 @@
package http
import (
"bufio"
"bytes"
"crypto/sha1"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"net/http"
"sync"
"time"
)
type conn struct {
net.Conn
rbuf bytes.Buffer
wbuf bytes.Buffer
handshaked bool
handshakeMutex sync.Mutex
}
func (c *conn) Handshake() (err error) {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
if c.handshaked {
return nil
}
if err = c.handshake(); err != nil {
return
}
c.handshaked = true
return nil
}
func (c *conn) handshake() (err error) {
br := bufio.NewReader(c.Conn)
r, err := http.ReadRequest(br)
if err != nil {
return
}
/*
if Debug {
dump, _ := httputil.DumpRequest(r, false)
log.Logf("[ohttp] %s -> %s\n%s", c.RemoteAddr(), c.LocalAddr(), string(dump))
}
*/
if r.ContentLength > 0 {
_, err = io.Copy(&c.rbuf, r.Body)
} else {
var b []byte
b, err = br.Peek(br.Buffered())
if len(b) > 0 {
_, err = c.rbuf.Write(b)
}
}
if err != nil {
// log.Logf("[ohttp] %s -> %s : %v", c.Conn.RemoteAddr(), c.Conn.LocalAddr(), err)
return
}
b := bytes.Buffer{}
if r.Method != http.MethodGet || r.Header.Get("Upgrade") != "websocket" {
b.WriteString("HTTP/1.1 503 Service Unavailable\r\n")
b.WriteString("Content-Length: 0\r\n")
b.WriteString("Date: " + time.Now().Format(time.RFC1123) + "\r\n")
b.WriteString("\r\n")
/*
if Debug {
log.Logf("[ohttp] %s <- %s\n%s", c.RemoteAddr(), c.LocalAddr(), b.String())
}
*/
b.WriteTo(c.Conn)
return errors.New("bad request")
}
b.WriteString("HTTP/1.1 101 Switching Protocols\r\n")
b.WriteString("Server: nginx/1.10.0\r\n")
b.WriteString("Date: " + time.Now().Format(time.RFC1123) + "\r\n")
b.WriteString("Connection: Upgrade\r\n")
b.WriteString("Upgrade: websocket\r\n")
b.WriteString(fmt.Sprintf("Sec-WebSocket-Accept: %s\r\n", computeAcceptKey(r.Header.Get("Sec-WebSocket-Key"))))
b.WriteString("\r\n")
/*
if Debug {
log.Logf("[ohttp] %s <- %s\n%s", c.RemoteAddr(), c.LocalAddr(), b.String())
}
*/
if c.rbuf.Len() > 0 {
c.wbuf = b // cache the response header if there are extra data in the request body.
return
}
_, err = b.WriteTo(c.Conn)
return
}
func (c *conn) Read(b []byte) (n int, err error) {
if err = c.Handshake(); err != nil {
return
}
if c.rbuf.Len() > 0 {
return c.rbuf.Read(b)
}
return c.Conn.Read(b)
}
func (c *conn) Write(b []byte) (n int, err error) {
if err = c.Handshake(); err != nil {
return
}
if c.wbuf.Len() > 0 {
c.wbuf.Write(b) // append the data to the cached header
_, err = c.wbuf.WriteTo(c.Conn)
n = len(b) // exclude the header length
return
}
return c.Conn.Write(b)
}
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
func computeAcceptKey(challengeKey string) string {
h := sha1.New()
h.Write([]byte(challengeKey))
h.Write(keyGUID)
return base64.StdEncoding.EncodeToString(h.Sum(nil))
}

View File

@ -0,0 +1,88 @@
package http
import (
"errors"
"net"
"strconv"
"time"
"github.com/go-gost/gost/pkg/components/internal/utils"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
net.Listener
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
laddr, err := net.ResolveTCPAddr("tcp", l.md.addr)
if err != nil {
return
}
ln, err := net.ListenTCP("tcp", laddr)
if err != nil {
return
}
if l.md.keepAlive {
l.Listener = &utils.TCPKeepAliveListener{
TCPListener: ln,
KeepAlivePeriod: l.md.keepAlivePeriod,
}
return
}
l.Listener = ln
return
}
func (l *Listener) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return &conn{Conn: c}, nil
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
m.keepAlive = true
if val, ok := md[keepAlive]; ok {
m.keepAlive, _ = strconv.ParseBool(val)
}
if val, ok := md[keepAlivePeriod]; ok {
m.keepAlivePeriod, _ = time.ParseDuration(val)
}
return
}

View File

@ -0,0 +1,19 @@
package http
import "time"
const (
addr = "addr"
keepAlive = "keepAlive"
keepAlivePeriod = "keepAlivePeriod"
)
const (
defaultKeepAlivePeriod = 180 * time.Second
)
type metadata struct {
addr string
keepAlive bool
keepAlivePeriod time.Duration
}

View File

@ -0,0 +1,306 @@
package tls
import (
"bytes"
"crypto/rand"
"crypto/tls"
"errors"
"net"
"sync"
"time"
dissector "github.com/ginuerzh/tls-dissector"
)
const (
maxTLSDataLen = 16384
)
var (
cipherSuites = []uint16{
0xc02c, 0xc030, 0x009f, 0xcca9, 0xcca8, 0xccaa, 0xc02b, 0xc02f,
0x009e, 0xc024, 0xc028, 0x006b, 0xc023, 0xc027, 0x0067, 0xc00a,
0xc014, 0x0039, 0xc009, 0xc013, 0x0033, 0x009d, 0x009c, 0x003d,
0x003c, 0x0035, 0x002f, 0x00ff,
}
compressionMethods = []uint8{0x00}
algorithms = []uint16{
0x0601, 0x0602, 0x0603, 0x0501, 0x0502, 0x0503, 0x0401, 0x0402,
0x0403, 0x0301, 0x0302, 0x0303, 0x0201, 0x0202, 0x0203,
}
tlsRecordTypes = []uint8{0x16, 0x14, 0x16, 0x17}
tlsVersionMinors = []uint8{0x01, 0x03, 0x03, 0x03}
ErrBadType = errors.New("bad type")
ErrBadMajorVersion = errors.New("bad major version")
ErrBadMinorVersion = errors.New("bad minor version")
ErrMaxDataLen = errors.New("bad tls data len")
)
const (
tlsRecordStateType = iota
tlsRecordStateVersion0
tlsRecordStateVersion1
tlsRecordStateLength0
tlsRecordStateLength1
tlsRecordStateData
)
type obfsTLSParser struct {
step uint8
state uint8
length uint16
}
func (r *obfsTLSParser) Parse(b []byte) (int, error) {
i := 0
last := 0
length := len(b)
for i < length {
ch := b[i]
switch r.state {
case tlsRecordStateType:
if tlsRecordTypes[r.step] != ch {
return 0, ErrBadType
}
r.state = tlsRecordStateVersion0
i++
case tlsRecordStateVersion0:
if ch != 0x03 {
return 0, ErrBadMajorVersion
}
r.state = tlsRecordStateVersion1
i++
case tlsRecordStateVersion1:
if ch != tlsVersionMinors[r.step] {
return 0, ErrBadMinorVersion
}
r.state = tlsRecordStateLength0
i++
case tlsRecordStateLength0:
r.length = uint16(ch) << 8
r.state = tlsRecordStateLength1
i++
case tlsRecordStateLength1:
r.length |= uint16(ch)
if r.step == 0 {
r.length = 91
} else if r.step == 1 {
r.length = 1
} else if r.length > maxTLSDataLen {
return 0, ErrMaxDataLen
}
if r.length > 0 {
r.state = tlsRecordStateData
} else {
r.state = tlsRecordStateType
r.step++
}
i++
case tlsRecordStateData:
left := uint16(length - i)
if left > r.length {
left = r.length
}
if r.step >= 2 {
skip := i - last
copy(b[last:], b[i:length])
length -= int(skip)
last += int(left)
i = last
} else {
i += int(left)
}
r.length -= left
if r.length == 0 {
if r.step < 3 {
r.step++
}
r.state = tlsRecordStateType
}
}
}
if last == 0 {
return 0, nil
} else if last < length {
length -= last
}
return length, nil
}
type conn struct {
net.Conn
rbuf bytes.Buffer
wbuf bytes.Buffer
host string
handshaked chan struct{}
parser *obfsTLSParser
handshakeMutex sync.Mutex
}
// newConn creates a connection for obfs-tls server.
func newConn(c net.Conn, host string) net.Conn {
return &conn{
Conn: c,
host: host,
handshaked: make(chan struct{}),
}
}
func (c *conn) Handshaked() bool {
select {
case <-c.handshaked:
return true
default:
return false
}
}
func (c *conn) Handshake(payload []byte) (err error) {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
if c.Handshaked() {
return
}
if err = c.handshake(); err != nil {
return
}
close(c.handshaked)
return nil
}
func (c *conn) handshake() error {
record := &dissector.Record{}
if _, err := record.ReadFrom(c.Conn); err != nil {
// log.Log(err)
return err
}
if record.Type != dissector.Handshake {
return dissector.ErrBadType
}
clientMsg := &dissector.ClientHelloMsg{}
if err := clientMsg.Decode(record.Opaque); err != nil {
// log.Log(err)
return err
}
for _, ext := range clientMsg.Extensions {
if ext.Type() == dissector.ExtSessionTicket {
b, err := ext.Encode()
if err != nil {
// log.Log(err)
return err
}
c.rbuf.Write(b)
break
}
}
serverMsg := &dissector.ServerHelloMsg{
Version: tls.VersionTLS12,
SessionID: clientMsg.SessionID,
CipherSuite: 0xcca8,
CompressionMethod: 0x00,
Extensions: []dissector.Extension{
&dissector.RenegotiationInfoExtension{},
&dissector.ExtendedMasterSecretExtension{},
&dissector.ECPointFormatsExtension{
Formats: []uint8{0x00},
},
},
}
serverMsg.Random.Time = uint32(time.Now().Unix())
rand.Read(serverMsg.Random.Opaque[:])
b, err := serverMsg.Encode()
if err != nil {
return err
}
record = &dissector.Record{
Type: dissector.Handshake,
Version: tls.VersionTLS10,
Opaque: b,
}
if _, err := record.WriteTo(&c.wbuf); err != nil {
return err
}
record = &dissector.Record{
Type: dissector.ChangeCipherSpec,
Version: tls.VersionTLS12,
Opaque: []byte{0x01},
}
if _, err := record.WriteTo(&c.wbuf); err != nil {
return err
}
return nil
}
func (c *conn) Read(b []byte) (n int, err error) {
if err = c.Handshake(nil); err != nil {
return
}
select {
case <-c.handshaked:
}
if c.rbuf.Len() > 0 {
return c.rbuf.Read(b)
}
record := &dissector.Record{}
if _, err = record.ReadFrom(c.Conn); err != nil {
return
}
n = copy(b, record.Opaque)
_, err = c.rbuf.Write(record.Opaque[n:])
return
}
func (c *conn) Write(b []byte) (n int, err error) {
n = len(b)
if !c.Handshaked() {
if err = c.Handshake(b); err != nil {
return
}
}
for len(b) > 0 {
data := b
if len(b) > maxTLSDataLen {
data = b[:maxTLSDataLen]
b = b[maxTLSDataLen:]
} else {
b = b[:0]
}
record := &dissector.Record{
Type: dissector.AppData,
Version: tls.VersionTLS12,
Opaque: data,
}
if c.wbuf.Len() > 0 {
record.Type = dissector.Handshake
record.WriteTo(&c.wbuf)
_, err = c.wbuf.WriteTo(c.Conn)
return
}
if _, err = record.WriteTo(c.Conn); err != nil {
return
}
}
return
}

View File

@ -0,0 +1,88 @@
package tls
import (
"errors"
"net"
"strconv"
"time"
"github.com/go-gost/gost/pkg/components/internal/utils"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
net.Listener
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
laddr, err := net.ResolveTCPAddr("tcp", l.md.addr)
if err != nil {
return
}
ln, err := net.ListenTCP("tcp", laddr)
if err != nil {
return
}
if l.md.keepAlive {
l.Listener = &utils.TCPKeepAliveListener{
TCPListener: ln,
KeepAlivePeriod: l.md.keepAlivePeriod,
}
return
}
l.Listener = ln
return
}
func (l *Listener) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return &conn{Conn: c}, nil
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
m.keepAlive = true
if val, ok := md[keepAlive]; ok {
m.keepAlive, _ = strconv.ParseBool(val)
}
if val, ok := md[keepAlivePeriod]; ok {
m.keepAlivePeriod, _ = time.ParseDuration(val)
}
return
}

View File

@ -0,0 +1,19 @@
package tls
import "time"
const (
addr = "addr"
keepAlive = "keepAlive"
keepAlivePeriod = "keepAlivePeriod"
)
const (
defaultKeepAlivePeriod = 180 * time.Second
)
type metadata struct {
addr string
keepAlive bool
keepAlivePeriod time.Duration
}

View File

@ -0,0 +1,17 @@
package listener
import (
"github.com/go-gost/gost/pkg/logger"
)
type Options struct {
Logger logger.Logger
}
type Option func(opts *Options)
func LoggerOption(logger logger.Logger) Option {
return func(opts *Options) {
opts.Logger = logger
}
}

View File

@ -0,0 +1,142 @@
package quic
import (
"context"
"errors"
"net"
"github.com/go-gost/gost/pkg/components/internal/utils"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
"github.com/lucas-clemente/quic-go"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
ln quic.Listener
connChan chan net.Conn
errChan chan error
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
laddr, err := net.ResolveUDPAddr("udp", l.md.addr)
if err != nil {
return
}
var conn net.PacketConn
conn, err = net.ListenUDP("udp", laddr)
if err != nil {
return
}
if l.md.cipherKey != nil {
conn = utils.QUICCipherConn(conn, l.md.cipherKey)
}
config := &quic.Config{
KeepAlive: l.md.keepAlive,
HandshakeIdleTimeout: l.md.HandshakeTimeout,
MaxIdleTimeout: l.md.MaxIdleTimeout,
}
ln, err := quic.Listen(conn, l.md.tlsConfig, config)
if err != nil {
return
}
l.ln = ln
l.connChan = make(chan net.Conn, l.md.connQueueSize)
l.errChan = make(chan error, 1)
go l.listenLoop()
return
}
func (l *Listener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *Listener) Close() error {
return l.ln.Close()
}
func (l *Listener) Addr() net.Addr {
return l.ln.Addr()
}
func (l *Listener) listenLoop() {
for {
ctx := context.Background()
session, err := l.ln.Accept(ctx)
if err != nil {
l.logger.Error("accept:", err)
l.errChan <- err
close(l.errChan)
return
}
go l.mux(ctx, session)
}
}
func (l *Listener) mux(ctx context.Context, session quic.Session) {
defer session.CloseWithError(0, "")
for {
stream, err := session.AcceptStream(ctx)
if err != nil {
l.logger.Error("accept stream:", err)
return
}
conn := utils.QUICConn(session, stream)
select {
case l.connChan <- conn:
case <-stream.Context().Done():
stream.Close()
default:
stream.Close()
l.logger.Error("connection queue is full")
}
}
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
return
}

View File

@ -0,0 +1,32 @@
package quic
import (
"crypto/tls"
"time"
)
const (
addr = "addr"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
keepAlive = "keepAlive"
keepAlivePeriod = "keepAlivePeriod"
)
const (
defaultKeepAlivePeriod = 180 * time.Second
)
type metadata struct {
addr string
tlsConfig *tls.Config
keepAlive bool
HandshakeTimeout time.Duration
MaxIdleTimeout time.Duration
cipherKey []byte
connQueueSize int
}

View File

@ -0,0 +1,79 @@
package tcp
import (
"errors"
"net"
"strconv"
"time"
"github.com/go-gost/gost/pkg/components/internal/utils"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
net.Listener
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
laddr, err := net.ResolveTCPAddr("tcp", l.md.addr)
if err != nil {
return
}
ln, err := net.ListenTCP("tcp", laddr)
if err != nil {
return
}
if l.md.keepAlive {
l.Listener = &utils.TCPKeepAliveListener{
TCPListener: ln,
KeepAlivePeriod: l.md.keepAlivePeriod,
}
return
}
l.Listener = ln
return
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
m.keepAlive = true
if val, ok := md[keepAlive]; ok {
m.keepAlive, _ = strconv.ParseBool(val)
}
if val, ok := md[keepAlivePeriod]; ok {
m.keepAlivePeriod, _ = time.ParseDuration(val)
}
return
}

View File

@ -0,0 +1,19 @@
package tcp
import "time"
const (
addr = "addr"
keepAlive = "keepAlive"
keepAlivePeriod = "keepAlivePeriod"
)
const (
defaultKeepAlivePeriod = 180 * time.Second
)
type metadata struct {
addr string
keepAlive bool
keepAlivePeriod time.Duration
}

View File

@ -0,0 +1,75 @@
package tls
import (
"crypto/tls"
"errors"
"net"
"time"
"github.com/go-gost/gost/pkg/components/internal/utils"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
net.Listener
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
ln, err := net.Listen("tcp", l.md.addr)
if err != nil {
return
}
ln = tls.NewListener(
&utils.TCPKeepAliveListener{
TCPListener: ln.(*net.TCPListener),
KeepAlivePeriod: l.md.keepAlivePeriod,
},
l.md.tlsConfig,
)
l.Listener = ln
return
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
if err != nil {
return
}
if val, ok := md[keepAlivePeriod]; ok {
m.keepAlivePeriod, _ = time.ParseDuration(val)
}
return
}

View File

@ -0,0 +1,20 @@
package tls
import (
"crypto/tls"
"time"
)
const (
addr = "addr"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
keepAlivePeriod = "keepAlivePeriod"
)
type metadata struct {
addr string
tlsConfig *tls.Config
keepAlivePeriod time.Duration
}

View File

@ -0,0 +1,141 @@
package mux
import (
"crypto/tls"
"errors"
"net"
"github.com/go-gost/gost/pkg/components/internal/utils"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
"github.com/xtaci/smux"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
net.Listener
connChan chan net.Conn
errChan chan error
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
ln, err := net.Listen("tcp", l.md.addr)
if err != nil {
return
}
l.Listener = tls.NewListener(ln, l.md.tlsConfig)
queueSize := l.md.connQueueSize
if queueSize <= 0 {
queueSize = defaultQueueSize
}
l.connChan = make(chan net.Conn, queueSize)
l.errChan = make(chan error, 1)
go l.listenLoop()
return
}
func (l *Listener) listenLoop() {
for {
conn, err := l.Listener.Accept()
if err != nil {
l.errChan <- err
close(l.errChan)
return
}
go l.mux(conn)
}
}
func (l *Listener) mux(conn net.Conn) {
smuxConfig := smux.DefaultConfig()
smuxConfig.KeepAliveDisabled = l.md.muxKeepAliveDisabled
if l.md.muxKeepAlivePeriod > 0 {
smuxConfig.KeepAliveInterval = l.md.muxKeepAlivePeriod
}
if l.md.muxKeepAliveTimeout > 0 {
smuxConfig.KeepAliveTimeout = l.md.muxKeepAliveTimeout
}
if l.md.muxMaxFrameSize > 0 {
smuxConfig.MaxFrameSize = l.md.muxMaxFrameSize
}
if l.md.muxMaxReceiveBuffer > 0 {
smuxConfig.MaxReceiveBuffer = l.md.muxMaxReceiveBuffer
}
if l.md.muxMaxStreamBuffer > 0 {
smuxConfig.MaxStreamBuffer = l.md.muxMaxStreamBuffer
}
session, err := smux.Server(conn, smuxConfig)
if err != nil {
l.logger.Error(err)
return
}
defer session.Close()
for {
stream, err := session.AcceptStream()
if err != nil {
l.logger.Error("accept stream:", err)
return
}
select {
case l.connChan <- stream:
case <-stream.GetDieCh():
stream.Close()
default:
stream.Close()
l.logger.Error("connection queue is full")
}
}
}
func (l *Listener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
if err != nil {
return
}
return
}

View File

@ -0,0 +1,38 @@
package mux
import (
"crypto/tls"
"time"
)
const (
addr = "addr"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
muxKeepAliveDisabled = "muxKeepAliveDisabled"
muxKeepAlivePeriod = "muxKeepAlivePeriod"
muxKeepAliveTimeout = "muxKeepAliveTimeout"
muxMaxFrameSize = "muxMaxFrameSize"
muxMaxReceiveBuffer = "muxMaxReceiveBuffer"
muxMaxStreamBuffer = "muxMaxStreamBuffer"
)
const (
defaultQueueSize = 128
)
type metadata struct {
addr string
tlsConfig *tls.Config
muxKeepAliveDisabled bool
muxKeepAlivePeriod time.Duration
muxKeepAliveTimeout time.Duration
muxMaxFrameSize int
muxMaxReceiveBuffer int
muxMaxStreamBuffer int
connQueueSize int
}

View File

@ -0,0 +1,115 @@
package udp
import (
"errors"
"net"
"sync"
"sync/atomic"
"time"
)
// serverConn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
type serverConn struct {
net.PacketConn
raddr net.Addr
rc chan []byte // data receive queue
fresh int32
closed chan struct{}
closeMutex sync.Mutex
config *serverConnConfig
}
type serverConnConfig struct {
ttl time.Duration
qsize int
onClose func()
}
func newServerConn(conn net.PacketConn, raddr net.Addr, cfg *serverConnConfig) *serverConn {
if conn == nil || raddr == nil {
return nil
}
if cfg == nil {
cfg = &serverConnConfig{}
}
c := &serverConn{
PacketConn: conn,
raddr: raddr,
rc: make(chan []byte, cfg.qsize),
closed: make(chan struct{}),
config: cfg,
}
go c.ttlWait()
return c
}
func (c *serverConn) send(b []byte) error {
select {
case c.rc <- b:
return nil
default:
return errors.New("queue is full")
}
}
func (c *serverConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *serverConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
select {
case bb := <-c.rc:
n = copy(b, bb)
atomic.StoreInt32(&c.fresh, 1)
case <-c.closed:
err = errors.New("read from closed connection")
return
}
addr = c.raddr
return
}
func (c *serverConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.raddr)
}
func (c *serverConn) Close() error {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()
select {
case <-c.closed:
return errors.New("connection is closed")
default:
if c.config.onClose != nil {
c.config.onClose()
}
close(c.closed)
}
return nil
}
func (c *serverConn) RemoteAddr() net.Addr {
return c.raddr
}
func (c *serverConn) ttlWait() {
ticker := time.NewTicker(c.config.ttl)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if !atomic.CompareAndSwapInt32(&c.fresh, 1, 0) {
c.Close()
return
}
case <-c.closed:
return
}
}
}

View File

@ -0,0 +1,168 @@
package udp
import (
"errors"
"net"
"sync"
"sync/atomic"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
conn net.PacketConn
connChan chan net.Conn
errChan chan error
connPool connPool
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
laddr, err := net.ResolveUDPAddr("udp", l.md.addr)
if err != nil {
return
}
var conn net.PacketConn
conn, err = net.ListenUDP("udp", laddr)
if err != nil {
return
}
l.conn = conn
l.connChan = make(chan net.Conn, l.md.connQueueSize)
l.errChan = make(chan error, 1)
go l.listenLoop()
return
}
func (l *Listener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *Listener) Close() error {
err := l.conn.Close()
l.connPool.Range(func(k interface{}, v *serverConn) bool {
v.Close()
return true
})
return err
}
func (l *Listener) Addr() net.Addr {
return l.conn.LocalAddr()
}
func (l *Listener) listenLoop() {
for {
b := make([]byte, l.md.readBufferSize)
n, raddr, err := l.conn.ReadFrom(b)
if err != nil {
l.logger.Error("accept:", err)
l.errChan <- err
close(l.errChan)
return
}
conn, ok := l.connPool.Get(raddr.String())
if !ok {
conn = newServerConn(l.conn, raddr,
&serverConnConfig{
ttl: l.md.ttl,
qsize: l.md.readQueueSize,
onClose: func() {
l.connPool.Delete(raddr.String())
},
})
select {
case l.connChan <- conn:
l.connPool.Set(raddr.String(), conn)
default:
conn.Close()
l.logger.Error("connection queue is full")
}
}
if err := conn.send(b[:n]); err != nil {
l.logger.Warn("data discarded:", err)
}
l.logger.Debug("recv", n)
}
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
return
}
type connPool struct {
size int64
m sync.Map
}
func (p *connPool) Get(key interface{}) (conn *serverConn, ok bool) {
v, ok := p.m.Load(key)
if ok {
conn, ok = v.(*serverConn)
}
return
}
func (p *connPool) Set(key interface{}, conn *serverConn) {
p.m.Store(key, conn)
atomic.AddInt64(&p.size, 1)
}
func (p *connPool) Delete(key interface{}) {
p.m.Delete(key)
atomic.AddInt64(&p.size, -1)
}
func (p *connPool) Range(f func(key interface{}, value *serverConn) bool) {
p.m.Range(func(k, v interface{}) bool {
return f(k, v.(*serverConn))
})
}
func (p *connPool) Size() int64 {
return atomic.LoadInt64(&p.size)
}

View File

@ -0,0 +1,23 @@
package udp
import "time"
const (
defaultTTL = 60 * time.Second
defaultReadBufferSize = 1024
defaultReadQueueSize = 128
defaultConnQueueSize = 128
)
const (
addr = "addr"
)
type metadata struct {
addr string
ttl time.Duration
readBufferSize int
readQueueSize int
connQueueSize int
}

View File

@ -0,0 +1,143 @@
package ws
import (
"crypto/tls"
"errors"
"net"
"net/http"
"github.com/go-gost/gost/pkg/components/internal/utils"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
"github.com/gorilla/websocket"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
addr net.Addr
upgrader *websocket.Upgrader
srv *http.Server
connChan chan net.Conn
errChan chan error
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
l.upgrader = &websocket.Upgrader{
HandshakeTimeout: l.md.handshakeTimeout,
ReadBufferSize: l.md.readBufferSize,
WriteBufferSize: l.md.writeBufferSize,
CheckOrigin: func(r *http.Request) bool { return true },
EnableCompression: l.md.enableCompression,
}
path := l.md.path
if path == "" {
path = defaultPath
}
mux := http.NewServeMux()
mux.Handle(path, http.HandlerFunc(l.upgrade))
l.srv = &http.Server{
Addr: l.md.addr,
TLSConfig: l.md.tlsConfig,
Handler: mux,
ReadHeaderTimeout: l.md.readHeaderTimeout,
}
queueSize := l.md.connQueueSize
if queueSize <= 0 {
queueSize = defaultQueueSize
}
l.connChan = make(chan net.Conn, queueSize)
l.errChan = make(chan error, 1)
ln, err := net.Listen("tcp", l.md.addr)
if err != nil {
return
}
if l.md.tlsConfig != nil {
ln = tls.NewListener(ln, l.md.tlsConfig)
}
l.addr = ln.Addr()
go func() {
err := l.srv.Serve(ln)
if err != nil {
l.errChan <- err
}
close(l.errChan)
}()
return
}
func (l *Listener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *Listener) Close() error {
return l.srv.Close()
}
func (l *Listener) Addr() net.Addr {
return l.addr
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
if err != nil {
return
}
return
}
func (l *Listener) upgrade(w http.ResponseWriter, r *http.Request) {
conn, err := l.upgrader.Upgrade(w, r, l.md.responseHeader)
if err != nil {
l.logger.Error(err)
return
}
select {
case l.connChan <- utils.WebsocketServerConn(conn):
default:
conn.Close()
l.logger.Warn("connection queue is full")
}
}

View File

@ -0,0 +1,40 @@
package ws
import (
"crypto/tls"
"net/http"
"time"
)
const (
addr = "addr"
path = "path"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
enableCompression = "enableCompression"
responseHeader = "responseHeader"
connQueueSize = "connQueueSize"
)
const (
defaultPath = "/ws"
defaultQueueSize = 128
)
type metadata struct {
addr string
path string
tlsConfig *tls.Config
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
readBufferSize int
writeBufferSize int
enableCompression bool
responseHeader http.Header
connQueueSize int
}

View File

@ -0,0 +1,178 @@
package mux
import (
"crypto/tls"
"errors"
"net"
"net/http"
"github.com/go-gost/gost/pkg/components/internal/utils"
"github.com/go-gost/gost/pkg/components/listener"
"github.com/go-gost/gost/pkg/logger"
"github.com/gorilla/websocket"
"github.com/xtaci/smux"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
addr net.Addr
upgrader *websocket.Upgrader
srv *http.Server
connChan chan net.Conn
errChan chan error
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
l.upgrader = &websocket.Upgrader{
HandshakeTimeout: l.md.handshakeTimeout,
ReadBufferSize: l.md.readBufferSize,
WriteBufferSize: l.md.writeBufferSize,
CheckOrigin: func(r *http.Request) bool { return true },
EnableCompression: l.md.enableCompression,
}
path := l.md.path
if path == "" {
path = defaultPath
}
mux := http.NewServeMux()
mux.Handle(path, http.HandlerFunc(l.upgrade))
l.srv = &http.Server{
Addr: l.md.addr,
TLSConfig: l.md.tlsConfig,
Handler: mux,
ReadHeaderTimeout: l.md.readHeaderTimeout,
}
l.connChan = make(chan net.Conn, l.md.connQueueSize)
l.errChan = make(chan error, 1)
ln, err := net.Listen("tcp", l.md.addr)
if err != nil {
return
}
if l.md.tlsConfig != nil {
ln = tls.NewListener(ln, l.md.tlsConfig)
}
l.addr = ln.Addr()
go func() {
err := l.srv.Serve(ln)
if err != nil {
l.errChan <- err
}
close(l.errChan)
}()
return
}
func (l *Listener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *Listener) Close() error {
return l.srv.Close()
}
func (l *Listener) Addr() net.Addr {
return l.addr
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
if err != nil {
return
}
return
}
func (l *Listener) upgrade(w http.ResponseWriter, r *http.Request) {
conn, err := l.upgrader.Upgrade(w, r, l.md.responseHeader)
if err != nil {
l.logger.Error(err)
return
}
l.mux(utils.WebsocketServerConn(conn))
}
func (l *Listener) mux(conn net.Conn) {
smuxConfig := smux.DefaultConfig()
smuxConfig.KeepAliveDisabled = l.md.muxKeepAliveDisabled
if l.md.muxKeepAlivePeriod > 0 {
smuxConfig.KeepAliveInterval = l.md.muxKeepAlivePeriod
}
if l.md.muxKeepAliveTimeout > 0 {
smuxConfig.KeepAliveTimeout = l.md.muxKeepAliveTimeout
}
if l.md.muxMaxFrameSize > 0 {
smuxConfig.MaxFrameSize = l.md.muxMaxFrameSize
}
if l.md.muxMaxReceiveBuffer > 0 {
smuxConfig.MaxReceiveBuffer = l.md.muxMaxReceiveBuffer
}
if l.md.muxMaxStreamBuffer > 0 {
smuxConfig.MaxStreamBuffer = l.md.muxMaxStreamBuffer
}
session, err := smux.Server(conn, smuxConfig)
if err != nil {
l.logger.Error(err)
return
}
defer session.Close()
for {
stream, err := session.AcceptStream()
if err != nil {
l.logger.Error("accept stream:", err)
return
}
select {
case l.connChan <- stream:
case <-stream.GetDieCh():
stream.Close()
default:
stream.Close()
l.logger.Error("connection queue is full")
}
}
}

View File

@ -0,0 +1,54 @@
package mux
import (
"crypto/tls"
"net/http"
"time"
)
const (
addr = "addr"
path = "path"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
enableCompression = "enableCompression"
responseHeader = "responseHeader"
connQueueSize = "connQueueSize"
muxKeepAliveDisabled = "muxKeepAliveDisabled"
muxKeepAlivePeriod = "muxKeepAlivePeriod"
muxKeepAliveTimeout = "muxKeepAliveTimeout"
muxMaxFrameSize = "muxMaxFrameSize"
muxMaxReceiveBuffer = "muxMaxReceiveBuffer"
muxMaxStreamBuffer = "muxMaxStreamBuffer"
)
const (
defaultPath = "/ws"
defaultQueueSize = 128
)
type metadata struct {
addr string
path string
tlsConfig *tls.Config
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
readBufferSize int
writeBufferSize int
enableCompression bool
responseHeader http.Header
muxKeepAliveDisabled bool
muxKeepAlivePeriod time.Duration
muxKeepAliveTimeout time.Duration
muxMaxFrameSize int
muxMaxReceiveBuffer int
muxMaxStreamBuffer int
connQueueSize int
}

96
pkg/logger/gost_logger.go Normal file
View File

@ -0,0 +1,96 @@
package logger
import (
"os"
"github.com/sirupsen/logrus"
)
var (
_ Logger = (*logger)(nil)
)
type logger struct {
logger *logrus.Entry
}
func newLogger(name string) *logger {
l := logrus.New()
l.SetOutput(os.Stdout)
gl := &logger{
logger: l.WithFields(logrus.Fields{
logFieldScope: name,
}),
}
return gl
}
// EnableJSONOutput enables JSON formatted output log.
func (l *logger) EnableJSONOutput(enabled bool) {
}
// SetOutputLevel sets log output level
func (l *logger) SetLevel(level LogLevel) {
lvl, _ := logrus.ParseLevel(string(level))
l.logger.Logger.SetLevel(lvl)
}
// WithFields adds new fields to log.
func (l *logger) WithFields(fields map[string]interface{}) Logger {
return &logger{
logger: l.logger.WithFields(logrus.Fields(fields)),
}
}
// Info logs a message at level Info.
func (l *logger) Info(args ...interface{}) {
l.logger.Log(logrus.InfoLevel, args...)
}
// Infof logs a message at level Info.
func (l *logger) Infof(format string, args ...interface{}) {
l.logger.Logf(logrus.InfoLevel, format, args...)
}
// Debug logs a message at level Debug.
func (l *logger) Debug(args ...interface{}) {
l.logger.Log(logrus.DebugLevel, args...)
}
// Debugf logs a message at level Debug.
func (l *logger) Debugf(format string, args ...interface{}) {
l.logger.Logf(logrus.DebugLevel, format, args...)
}
// Warn logs a message at level Warn.
func (l *logger) Warn(args ...interface{}) {
l.logger.Log(logrus.WarnLevel, args...)
}
// Warnf logs a message at level Warn.
func (l *logger) Warnf(format string, args ...interface{}) {
l.logger.Logf(logrus.WarnLevel, format, args...)
}
// Error logs a message at level Error.
func (l *logger) Error(args ...interface{}) {
l.logger.Log(logrus.ErrorLevel, args...)
}
// Errorf logs a message at level Error.
func (l *logger) Errorf(format string, args ...interface{}) {
l.logger.Logf(logrus.ErrorLevel, format, args...)
}
// Fatal logs a message at level Fatal then the process will exit with status set to 1.
func (l *logger) Fatal(args ...interface{}) {
l.logger.Fatal(args...)
}
// Fatalf logs a message at level Fatal then the process will exit with status set to 1.
func (l *logger) Fatalf(format string, args ...interface{}) {
l.logger.Fatalf(format, args...)
}

57
pkg/logger/logger.go Normal file
View File

@ -0,0 +1,57 @@
package logger
import "sync"
const (
logFieldScope = "scope"
)
// LogLevel is Logger Level type
type LogLevel string
const (
// DebugLevel has verbose message
DebugLevel LogLevel = "debug"
// InfoLevel is default log level
InfoLevel LogLevel = "info"
// WarnLevel is for logging messages about possible issues
WarnLevel LogLevel = "warn"
// ErrorLevel is for logging errors
ErrorLevel LogLevel = "error"
// FatalLevel is for logging fatal messages. The system shuts down after logging the message.
FatalLevel LogLevel = "fatal"
)
var (
globalLoggers = make(map[string]Logger)
globalLoggersLock sync.RWMutex
)
type Logger interface {
EnableJSONOutput(enabled bool)
SetLevel(level LogLevel)
WithFields(map[string]interface{}) Logger
Debug(args ...interface{})
Debugf(format string, args ...interface{})
Info(args ...interface{})
Infof(format string, args ...interface{})
Warn(args ...interface{})
Warnf(format string, args ...interface{})
Error(args ...interface{})
Errorf(format string, args ...interface{})
Fatal(args ...interface{})
Fatalf(format string, args ...interface{})
}
func NewLogger(name string) Logger {
globalLoggersLock.Lock()
defer globalLoggersLock.Unlock()
logger, ok := globalLoggers[name]
if !ok {
logger = newLogger(name)
globalLoggers[name] = logger
}
return logger
}

63
pkg/service/service.go Normal file
View File

@ -0,0 +1,63 @@
package service
import (
"context"
"net"
"time"
"github.com/go-gost/gost/pkg/components/handler"
"github.com/go-gost/gost/pkg/components/listener"
)
type Service struct {
listener listener.Listener
handler handler.Handler
}
func (s *Service) WithListener(ln listener.Listener) *Service {
s.listener = ln
return s
}
func (s *Service) WithHandler(h handler.Handler) *Service {
s.handler = h
return s
}
func (s *Service) Addr() net.Addr {
return s.listener.Addr()
}
func (s *Service) Run() error {
return s.serve()
}
func (s *Service) Close() error {
return s.listener.Close()
}
func (s *Service) serve() error {
var tempDelay time.Duration
for {
conn, e := s.listener.Accept()
if e != nil {
if ne, ok := e.(net.Error); ok && ne.Temporary() {
if tempDelay == 0 {
tempDelay = 5 * time.Millisecond
} else {
tempDelay *= 2
}
if max := 1 * time.Second; tempDelay > max {
tempDelay = max
}
// log.Logf("server: Accept error: %v; retrying in %v", e, tempDelay)
time.Sleep(tempDelay)
continue
}
return e
}
tempDelay = 0
go s.handler.Handle(context.Background(), conn)
}
}