add chain
This commit is contained in:
33
pkg/chain/chain.go
Normal file
33
pkg/chain/chain.go
Normal file
@ -0,0 +1,33 @@
|
||||
package chain
|
||||
|
||||
type Chain struct {
|
||||
groups []*NodeGroup
|
||||
}
|
||||
|
||||
func (c *Chain) AddNodeGroup(group *NodeGroup) {
|
||||
c.groups = append(c.groups, group)
|
||||
}
|
||||
|
||||
func (c *Chain) GetRoute() (r *Route) {
|
||||
if c == nil || len(c.groups) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
r = &Route{}
|
||||
for _, group := range c.groups {
|
||||
node := group.Next()
|
||||
if node == nil {
|
||||
return
|
||||
}
|
||||
// TODO: bypass
|
||||
|
||||
if node.Transport().IsMultiplex() {
|
||||
tr := node.Transport().WithRoute(r)
|
||||
node = node.WithTransport(tr)
|
||||
r = &Route{}
|
||||
}
|
||||
|
||||
r.AddNode(node)
|
||||
}
|
||||
return r
|
||||
}
|
59
pkg/chain/node.go
Normal file
59
pkg/chain/node.go
Normal file
@ -0,0 +1,59 @@
|
||||
package chain
|
||||
|
||||
type Node struct {
|
||||
name string
|
||||
addr string
|
||||
transport *Transport
|
||||
}
|
||||
|
||||
func NewNode(name, addr string) *Node {
|
||||
return &Node{
|
||||
name: name,
|
||||
addr: addr,
|
||||
}
|
||||
}
|
||||
|
||||
func (node *Node) Name() string {
|
||||
return node.name
|
||||
}
|
||||
|
||||
func (node *Node) Addr() string {
|
||||
return node.addr
|
||||
}
|
||||
|
||||
func (node *Node) Transport() *Transport {
|
||||
return node.transport
|
||||
}
|
||||
|
||||
func (node *Node) WithTransport(tr *Transport) *Node {
|
||||
node.transport = tr
|
||||
return node
|
||||
}
|
||||
|
||||
type NodeGroup struct {
|
||||
nodes []*Node
|
||||
selector Selector
|
||||
}
|
||||
|
||||
func NewNodeGroup(nodes ...*Node) *NodeGroup {
|
||||
return &NodeGroup{
|
||||
nodes: nodes,
|
||||
}
|
||||
}
|
||||
|
||||
func (g *NodeGroup) AddNode(node *Node) {
|
||||
g.nodes = append(g.nodes, node)
|
||||
}
|
||||
|
||||
func (g *NodeGroup) WithSelector(selector Selector) {
|
||||
g.selector = selector
|
||||
}
|
||||
|
||||
func (g *NodeGroup) Next() *Node {
|
||||
selector := g.selector
|
||||
if selector == nil {
|
||||
// selector = defaultSelector
|
||||
return g.nodes[0]
|
||||
}
|
||||
return selector.Select(g.nodes...)
|
||||
}
|
93
pkg/chain/route.go
Normal file
93
pkg/chain/route.go
Normal file
@ -0,0 +1,93 @@
|
||||
package chain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
type Route struct {
|
||||
nodes []*Node
|
||||
}
|
||||
|
||||
func (r *Route) AddNode(node *Node) {
|
||||
r.nodes = append(r.nodes, node)
|
||||
}
|
||||
|
||||
func (r *Route) Connect(ctx context.Context) (conn net.Conn, err error) {
|
||||
if r.IsEmpty() {
|
||||
return nil, errors.New("empty route")
|
||||
}
|
||||
|
||||
node := r.nodes[0]
|
||||
cc, err := node.Transport().Dial(ctx, r.nodes[0].Addr())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cn, err := node.Transport().Handshake(ctx, cc)
|
||||
if err != nil {
|
||||
cc.Close()
|
||||
return
|
||||
}
|
||||
|
||||
preNode := node
|
||||
for _, node := range r.nodes[1:] {
|
||||
cc, err = preNode.Transport().Connect(ctx, cn, "tcp", node.Addr())
|
||||
if err != nil {
|
||||
cn.Close()
|
||||
return
|
||||
}
|
||||
cc, err = node.transport.Handshake(ctx, cc)
|
||||
if err != nil {
|
||||
cn.Close()
|
||||
}
|
||||
cn = cc
|
||||
preNode = node
|
||||
}
|
||||
|
||||
conn = cn
|
||||
return
|
||||
}
|
||||
|
||||
func (r *Route) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
if r.IsEmpty() {
|
||||
return r.dialDirect(ctx, network, address)
|
||||
}
|
||||
|
||||
conn, err := r.Connect(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cc, err := r.Last().Transport().Connect(ctx, conn, network, address)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return cc, nil
|
||||
}
|
||||
|
||||
func (r *Route) dialDirect(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
switch network {
|
||||
case "udp", "udp4", "udp6":
|
||||
if address == "" {
|
||||
return net.ListenUDP(network, nil)
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
d := &net.Dialer{}
|
||||
return d.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
func (r *Route) IsEmpty() bool {
|
||||
return r == nil || len(r.nodes) == 0
|
||||
}
|
||||
|
||||
func (r Route) Last() *Node {
|
||||
if r.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
return r.nodes[len(r.nodes)-1]
|
||||
}
|
41
pkg/chain/selector.go
Normal file
41
pkg/chain/selector.go
Normal file
@ -0,0 +1,41 @@
|
||||
package chain
|
||||
|
||||
var (
|
||||
defaultSelector Selector = NewSelector(nil)
|
||||
)
|
||||
|
||||
type Filter interface {
|
||||
Filter(nodes ...*Node) []*Node
|
||||
String() string
|
||||
}
|
||||
|
||||
type Strategy interface {
|
||||
Apply(nodes ...*Node) *Node
|
||||
String() string
|
||||
}
|
||||
|
||||
type Selector interface {
|
||||
Select(nodes ...*Node) *Node
|
||||
}
|
||||
|
||||
type selector struct {
|
||||
strategy Strategy
|
||||
filters []Filter
|
||||
}
|
||||
|
||||
func NewSelector(strategy Strategy, filters ...Filter) Selector {
|
||||
return &selector{
|
||||
filters: filters,
|
||||
strategy: strategy,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *selector) Select(nodes ...*Node) *Node {
|
||||
for _, filter := range s.filters {
|
||||
nodes = filter.Filter(nodes...)
|
||||
}
|
||||
if len(nodes) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.strategy.Apply(nodes...)
|
||||
}
|
66
pkg/chain/transport.go
Normal file
66
pkg/chain/transport.go
Normal file
@ -0,0 +1,66 @@
|
||||
package chain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/go-gost/gost/pkg/components/connector"
|
||||
"github.com/go-gost/gost/pkg/components/dialer"
|
||||
)
|
||||
|
||||
type Transport struct {
|
||||
route *Route
|
||||
dialer dialer.Dialer
|
||||
connector connector.Connector
|
||||
}
|
||||
|
||||
func (tr *Transport) WithDialer(dialer dialer.Dialer) *Transport {
|
||||
tr.dialer = dialer
|
||||
return tr
|
||||
}
|
||||
|
||||
func (tr *Transport) WithConnector(connector connector.Connector) *Transport {
|
||||
tr.connector = connector
|
||||
return tr
|
||||
}
|
||||
|
||||
func (tr *Transport) Dial(ctx context.Context, addr string) (net.Conn, error) {
|
||||
return tr.dialer.Dial(ctx, addr, tr.dialOptions()...)
|
||||
}
|
||||
|
||||
func (tr *Transport) dialOptions() []dialer.DialOption {
|
||||
var opts []dialer.DialOption
|
||||
if tr.route != nil {
|
||||
opts = append(opts,
|
||||
dialer.DialFuncDialOption(
|
||||
func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
return tr.route.Dial(ctx, "tcp", addr)
|
||||
},
|
||||
),
|
||||
)
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
func (tr *Transport) Handshake(ctx context.Context, conn net.Conn) (net.Conn, error) {
|
||||
if hs, ok := tr.dialer.(dialer.Handshaker); ok {
|
||||
return hs.Handshake(ctx, conn)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (tr *Transport) Connect(ctx context.Context, conn net.Conn, network, address string) (net.Conn, error) {
|
||||
return tr.connector.Connect(ctx, conn, network, address)
|
||||
}
|
||||
|
||||
func (tr *Transport) IsMultiplex() bool {
|
||||
if mux, ok := tr.dialer.(dialer.Multiplexer); ok {
|
||||
return mux.IsMultiplex()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (tr *Transport) WithRoute(r *Route) *Transport {
|
||||
tr.route = r
|
||||
return tr
|
||||
}
|
11
pkg/components/connector/connector.go
Normal file
11
pkg/components/connector/connector.go
Normal file
@ -0,0 +1,11 @@
|
||||
package connector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
// Connector is responsible for connecting to the destination address.
|
||||
type Connector interface {
|
||||
Connect(ctx context.Context, conn net.Conn, network, address string, opts ...ConnectOption) (net.Conn, error)
|
||||
}
|
99
pkg/components/connector/http/connector.go
Normal file
99
pkg/components/connector/http/connector.go
Normal file
@ -0,0 +1,99 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/go-gost/gost/pkg/components/connector"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
_ connector.Connector = (*Connector)(nil)
|
||||
)
|
||||
|
||||
type Connector struct {
|
||||
md metadata
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewConnector(opts ...connector.Option) *Connector {
|
||||
options := &connector.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
return &Connector{
|
||||
logger: options.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connector) Init(md connector.Metadata) (err error) {
|
||||
c.md, err = c.parseMetadata(md)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Connector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) {
|
||||
req := &http.Request{
|
||||
Method: http.MethodConnect,
|
||||
URL: &url.URL{Host: address},
|
||||
Host: address,
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
if c.md.UserAgent != "" {
|
||||
log.Println(c.md.UserAgent)
|
||||
req.Header.Set("User-Agent", c.md.UserAgent)
|
||||
}
|
||||
req.Header.Set("Proxy-Connection", "keep-alive")
|
||||
|
||||
if user := c.md.User; user != nil {
|
||||
u := user.Username()
|
||||
p, _ := user.Password()
|
||||
req.Header.Set("Proxy-Authorization",
|
||||
"Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p)))
|
||||
}
|
||||
|
||||
req = req.WithContext(ctx)
|
||||
if err := req.Write(conn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := http.ReadResponse(bufio.NewReader(conn), req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("%s", resp.Status)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *Connector) parseMetadata(md connector.Metadata) (m metadata, err error) {
|
||||
if md == nil {
|
||||
md = connector.Metadata{}
|
||||
}
|
||||
m.UserAgent = md[userAgent]
|
||||
if m.UserAgent == "" {
|
||||
m.UserAgent = defaultUserAgent
|
||||
}
|
||||
|
||||
if v, ok := md[username]; ok {
|
||||
m.User = url.UserPassword(v, md[password])
|
||||
}
|
||||
|
||||
return
|
||||
}
|
18
pkg/components/connector/http/metadata.go
Normal file
18
pkg/components/connector/http/metadata.go
Normal file
@ -0,0 +1,18 @@
|
||||
package http
|
||||
|
||||
import "net/url"
|
||||
|
||||
const (
|
||||
userAgent = "userAgent"
|
||||
username = "username"
|
||||
password = "password"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultUserAgent = "Chrome/78.0.3904.106"
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
UserAgent string
|
||||
User *url.Userinfo
|
||||
}
|
3
pkg/components/connector/metadata.go
Normal file
3
pkg/components/connector/metadata.go
Normal file
@ -0,0 +1,3 @@
|
||||
package connector
|
||||
|
||||
type Metadata map[string]string
|
22
pkg/components/connector/option.go
Normal file
22
pkg/components/connector/option.go
Normal file
@ -0,0 +1,22 @@
|
||||
package connector
|
||||
|
||||
import (
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
Logger logger.Logger
|
||||
}
|
||||
|
||||
type Option func(opts *Options)
|
||||
|
||||
func LoggerOption(logger logger.Logger) Option {
|
||||
return func(opts *Options) {
|
||||
opts.Logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
type ConnectOptions struct {
|
||||
}
|
||||
|
||||
type ConnectOption func(opts *ConnectOptions)
|
54
pkg/components/connector/ss/connector.go
Normal file
54
pkg/components/connector/ss/connector.go
Normal file
@ -0,0 +1,54 @@
|
||||
package ss
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/go-gost/gost/pkg/components/connector"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
_ connector.Connector = (*Connector)(nil)
|
||||
)
|
||||
|
||||
type Connector struct {
|
||||
md metadata
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewConnector(opts ...connector.Option) *Connector {
|
||||
options := &connector.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
return &Connector{
|
||||
logger: options.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connector) Init(md connector.Metadata) (err error) {
|
||||
c.md, err = c.parseMetadata(md)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Connector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) {
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *Connector) parseMetadata(md connector.Metadata) (m metadata, err error) {
|
||||
if md == nil {
|
||||
md = connector.Metadata{}
|
||||
}
|
||||
|
||||
m.method = md[method]
|
||||
m.password = md[password]
|
||||
|
||||
return
|
||||
}
|
11
pkg/components/connector/ss/metadata.go
Normal file
11
pkg/components/connector/ss/metadata.go
Normal file
@ -0,0 +1,11 @@
|
||||
package ss
|
||||
|
||||
const (
|
||||
method = "method"
|
||||
password = "password"
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
method string
|
||||
password string
|
||||
}
|
19
pkg/components/dialer/dialer.go
Normal file
19
pkg/components/dialer/dialer.go
Normal file
@ -0,0 +1,19 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
// Transporter is responsible for dialing to the proxy server.
|
||||
type Dialer interface {
|
||||
Dial(ctx context.Context, addr string, opts ...DialOption) (net.Conn, error)
|
||||
}
|
||||
|
||||
type Handshaker interface {
|
||||
Handshake(ctx context.Context, conn net.Conn) (net.Conn, error)
|
||||
}
|
||||
|
||||
type Multiplexer interface {
|
||||
IsMultiplex() bool
|
||||
}
|
3
pkg/components/dialer/metadata.go
Normal file
3
pkg/components/dialer/metadata.go
Normal file
@ -0,0 +1,3 @@
|
||||
package dialer
|
||||
|
||||
type Metadata map[string]string
|
32
pkg/components/dialer/option.go
Normal file
32
pkg/components/dialer/option.go
Normal file
@ -0,0 +1,32 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
Logger logger.Logger
|
||||
}
|
||||
|
||||
type Option func(opts *Options)
|
||||
|
||||
func LoggerOption(logger logger.Logger) Option {
|
||||
return func(opts *Options) {
|
||||
opts.Logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
type DialOptions struct {
|
||||
DialFunc func(ctx context.Context, addr string) (net.Conn, error)
|
||||
}
|
||||
|
||||
type DialOption func(opts *DialOptions)
|
||||
|
||||
func DialFuncDialOption(dialf func(ctx context.Context, addr string) (net.Conn, error)) DialOption {
|
||||
return func(opts *DialOptions) {
|
||||
opts.DialFunc = dialf
|
||||
}
|
||||
}
|
56
pkg/components/dialer/tcp/dialer.go
Normal file
56
pkg/components/dialer/tcp/dialer.go
Normal file
@ -0,0 +1,56 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/go-gost/gost/pkg/components/dialer"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
_ dialer.Dialer = (*Dialer)(nil)
|
||||
)
|
||||
|
||||
type Dialer struct {
|
||||
md metadata
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewDialer(opts ...dialer.Option) *Dialer {
|
||||
options := &dialer.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
return &Dialer{
|
||||
logger: options.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Dialer) Init(md dialer.Metadata) (err error) {
|
||||
d.md, err = d.parseMetadata(md)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Dialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) {
|
||||
var options dialer.DialOptions
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
dial := options.DialFunc
|
||||
if dial != nil {
|
||||
return dial(ctx, addr)
|
||||
}
|
||||
|
||||
var netd net.Dialer
|
||||
return netd.DialContext(ctx, "tcp", addr)
|
||||
}
|
||||
|
||||
func (d *Dialer) parseMetadata(md dialer.Metadata) (m metadata, err error) {
|
||||
return
|
||||
}
|
15
pkg/components/dialer/tcp/metadata.go
Normal file
15
pkg/components/dialer/tcp/metadata.go
Normal file
@ -0,0 +1,15 @@
|
||||
package tcp
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
dialTimeout = "dialTimeout"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultDialTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
dialTimeout time.Duration
|
||||
}
|
10
pkg/components/handler/handler.go
Normal file
10
pkg/components/handler/handler.go
Normal file
@ -0,0 +1,10 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
type Handler interface {
|
||||
Handle(context.Context, net.Conn)
|
||||
}
|
225
pkg/components/handler/http/handler.go
Normal file
225
pkg/components/handler/http/handler.go
Normal file
@ -0,0 +1,225 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-gost/gost/pkg/chain"
|
||||
"github.com/go-gost/gost/pkg/components/handler"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
_ handler.Handler = (*Handler)(nil)
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
chain *chain.Chain
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
}
|
||||
|
||||
func NewHandler(opts ...handler.Option) *Handler {
|
||||
options := &handler.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
return &Handler{
|
||||
chain: options.Chain,
|
||||
logger: options.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) Init(md handler.Metadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
req, err := http.ReadRequest(bufio.NewReader(conn))
|
||||
if err != nil {
|
||||
h.logger.WithFields(map[string]interface{}{
|
||||
"src": conn.RemoteAddr(),
|
||||
"local": conn.LocalAddr(),
|
||||
}).Error(err)
|
||||
return
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
h.handleRequest(ctx, conn, req)
|
||||
}
|
||||
|
||||
func (h *Handler) handleRequest(ctx context.Context, conn net.Conn, req *http.Request) {
|
||||
if req == nil {
|
||||
return
|
||||
}
|
||||
|
||||
/*
|
||||
// try to get the actual host.
|
||||
if v := req.Header.Get("Gost-Target"); v != "" {
|
||||
if h, err := decodeServerName(v); err == nil {
|
||||
req.Host = h
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
host := req.Host
|
||||
if _, port, _ := net.SplitHostPort(host); port == "" {
|
||||
host = net.JoinHostPort(host, "80")
|
||||
}
|
||||
|
||||
/*
|
||||
u, _, _ := basicProxyAuth(req.Header.Get("Proxy-Authorization"))
|
||||
if u != "" {
|
||||
u += "@"
|
||||
}
|
||||
log.Logf("[http] %s%s -> %s -> %s",
|
||||
u, conn.RemoteAddr(), h.options.Node.String(), host)
|
||||
|
||||
if Debug {
|
||||
dump, _ := httputil.DumpRequest(req, false)
|
||||
log.Logf("[http] %s -> %s\n%s", conn.RemoteAddr(), conn.LocalAddr(), string(dump))
|
||||
}
|
||||
|
||||
req.Header.Del("Gost-Target")
|
||||
*/
|
||||
resp := &http.Response{
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
if h.md.proxyAgent != "" {
|
||||
resp.Header.Add("Proxy-Agent", h.md.proxyAgent)
|
||||
}
|
||||
|
||||
/*
|
||||
if !Can("tcp", host, h.options.Whitelist, h.options.Blacklist) {
|
||||
log.Logf("[http] %s - %s : Unauthorized to tcp connect to %s",
|
||||
conn.RemoteAddr(), conn.LocalAddr(), host)
|
||||
resp.StatusCode = http.StatusForbidden
|
||||
|
||||
if Debug {
|
||||
dump, _ := httputil.DumpResponse(resp, false)
|
||||
log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), conn.LocalAddr(), string(dump))
|
||||
}
|
||||
|
||||
resp.Write(conn)
|
||||
return
|
||||
}
|
||||
*/
|
||||
|
||||
/*
|
||||
if h.options.Bypass.Contains(host) {
|
||||
resp.StatusCode = http.StatusForbidden
|
||||
|
||||
log.Logf("[http] %s - %s bypass %s",
|
||||
conn.RemoteAddr(), conn.LocalAddr(), host)
|
||||
if Debug {
|
||||
dump, _ := httputil.DumpResponse(resp, false)
|
||||
log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), conn.LocalAddr(), string(dump))
|
||||
}
|
||||
|
||||
resp.Write(conn)
|
||||
return
|
||||
}
|
||||
*/
|
||||
|
||||
/*
|
||||
if !h.authenticate(conn, req, resp) {
|
||||
return
|
||||
}
|
||||
*/
|
||||
|
||||
if req.Method == "PRI" ||
|
||||
(req.Method != http.MethodConnect && req.URL.Scheme != "http") {
|
||||
resp.StatusCode = http.StatusBadRequest
|
||||
/*
|
||||
if Debug {
|
||||
dump, _ := httputil.DumpResponse(resp, false)
|
||||
log.Logf("[http] %s <- %s\n%s",
|
||||
conn.RemoteAddr(), conn.LocalAddr(), string(dump))
|
||||
}
|
||||
*/
|
||||
|
||||
resp.Write(conn)
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Del("Proxy-Authorization")
|
||||
|
||||
cc, err := h.dial(ctx, host)
|
||||
if err != nil {
|
||||
resp.StatusCode = http.StatusServiceUnavailable
|
||||
|
||||
/*
|
||||
if Debug {
|
||||
dump, _ := httputil.DumpResponse(resp, false)
|
||||
log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), conn.LocalAddr(), string(dump))
|
||||
}
|
||||
*/
|
||||
resp.Write(conn)
|
||||
return
|
||||
}
|
||||
defer cc.Close()
|
||||
|
||||
if req.Method == http.MethodConnect {
|
||||
resp.StatusCode = http.StatusOK
|
||||
resp.Status = "200 Connection established"
|
||||
resp.Write(conn)
|
||||
} else {
|
||||
req.Header.Del("Proxy-Connection")
|
||||
|
||||
if err = req.Write(cc); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
handler.Transport(conn, cc)
|
||||
}
|
||||
|
||||
func (h *Handler) dial(ctx context.Context, addr string) (conn net.Conn, err error) {
|
||||
count := h.md.retryCount + 1
|
||||
if count <= 0 {
|
||||
count = 1
|
||||
}
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
route := h.chain.GetRoute()
|
||||
|
||||
/*
|
||||
buf := bytes.Buffer{}
|
||||
fmt.Fprintf(&buf, "%s -> %s -> ",
|
||||
conn.RemoteAddr(), h.options.Node.String())
|
||||
for _, nd := range route.route {
|
||||
fmt.Fprintf(&buf, "%d@%s -> ", nd.ID, nd.String())
|
||||
}
|
||||
fmt.Fprintf(&buf, "%s", host)
|
||||
log.Log("[route]", buf.String())
|
||||
*/
|
||||
|
||||
/*
|
||||
// forward http request
|
||||
lastNode := route.LastNode()
|
||||
if req.Method != http.MethodConnect && lastNode.Protocol == "http" {
|
||||
err = h.forwardRequest(conn, req, route)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err)
|
||||
continue
|
||||
}
|
||||
*/
|
||||
|
||||
conn, err = route.Dial(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
7
pkg/components/handler/http/metadata.go
Normal file
7
pkg/components/handler/http/metadata.go
Normal file
@ -0,0 +1,7 @@
|
||||
package http
|
||||
|
||||
type metadata struct {
|
||||
addr string
|
||||
proxyAgent string
|
||||
retryCount int
|
||||
}
|
3
pkg/components/handler/metadata.go
Normal file
3
pkg/components/handler/metadata.go
Normal file
@ -0,0 +1,3 @@
|
||||
package handler
|
||||
|
||||
type Metadata map[string]string
|
25
pkg/components/handler/option.go
Normal file
25
pkg/components/handler/option.go
Normal file
@ -0,0 +1,25 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/go-gost/gost/pkg/chain"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
Chain *chain.Chain
|
||||
Logger logger.Logger
|
||||
}
|
||||
|
||||
type Option func(opts *Options)
|
||||
|
||||
func LoggerOption(logger logger.Logger) Option {
|
||||
return func(opts *Options) {
|
||||
opts.Logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
func ChainOption(chain *chain.Chain) Option {
|
||||
return func(opts *Options) {
|
||||
opts.Chain = chain
|
||||
}
|
||||
}
|
129
pkg/components/handler/ss/handler.go
Normal file
129
pkg/components/handler/ss/handler.go
Normal file
@ -0,0 +1,129 @@
|
||||
package ss
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gosocks5"
|
||||
"github.com/go-gost/gost/pkg/components/handler"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
"github.com/shadowsocks/go-shadowsocks2/core"
|
||||
ss "github.com/shadowsocks/shadowsocks-go/shadowsocks"
|
||||
)
|
||||
|
||||
var (
|
||||
_ handler.Handler = (*Handler)(nil)
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
}
|
||||
|
||||
func NewHandler(opts ...handler.Option) *Handler {
|
||||
options := &handler.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
return &Handler{
|
||||
logger: options.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) Init(md handler.Metadata) (err error) {
|
||||
h.md, err = h.parseMetadata(md)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
if h.md.cipher != nil {
|
||||
conn = &shadowConn{
|
||||
Conn: h.md.cipher.StreamConn(conn),
|
||||
}
|
||||
}
|
||||
|
||||
if h.md.readTimeout > 0 {
|
||||
conn.SetReadDeadline(time.Now().Add(h.md.readTimeout))
|
||||
}
|
||||
|
||||
addr := &gosocks5.Addr{}
|
||||
_, err := addr.ReadFrom(conn)
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
|
||||
host := addr.String()
|
||||
cc, err := net.Dial("tcp", host)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer cc.Close()
|
||||
|
||||
handler.Transport(conn, cc)
|
||||
}
|
||||
|
||||
func (h *Handler) parseMetadata(md handler.Metadata) (m metadata, err error) {
|
||||
m.cipher, err = h.initCipher(md[method], md[password], md[key])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if v, ok := md[readTimeout]; ok {
|
||||
m.readTimeout, _ = time.ParseDuration(v)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (h *Handler) initCipher(method, password string, key string) (core.Cipher, error) {
|
||||
if method == "" && password == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
c, _ := ss.NewCipher(method, password)
|
||||
if c != nil {
|
||||
return &shadowCipher{cipher: c}, nil
|
||||
}
|
||||
|
||||
return core.PickCipher(method, []byte(key), password)
|
||||
}
|
||||
|
||||
type shadowCipher struct {
|
||||
cipher *ss.Cipher
|
||||
}
|
||||
|
||||
func (c *shadowCipher) StreamConn(conn net.Conn) net.Conn {
|
||||
return ss.NewConn(conn, c.cipher.Copy())
|
||||
}
|
||||
|
||||
func (c *shadowCipher) PacketConn(conn net.PacketConn) net.PacketConn {
|
||||
return ss.NewSecurePacketConn(conn, c.cipher.Copy())
|
||||
}
|
||||
|
||||
// Due to in/out byte length is inconsistent of the shadowsocks.Conn.Write,
|
||||
// we wrap around it to make io.Copy happy.
|
||||
type shadowConn struct {
|
||||
net.Conn
|
||||
wbuf bytes.Buffer
|
||||
}
|
||||
|
||||
func (c *shadowConn) Write(b []byte) (n int, err error) {
|
||||
n = len(b) // force byte length consistent
|
||||
if c.wbuf.Len() > 0 {
|
||||
c.wbuf.Write(b) // append the data to the cached header
|
||||
_, err = c.Conn.Write(c.wbuf.Bytes())
|
||||
c.wbuf.Reset()
|
||||
return
|
||||
}
|
||||
_, err = c.Conn.Write(b)
|
||||
return
|
||||
}
|
19
pkg/components/handler/ss/metadata.go
Normal file
19
pkg/components/handler/ss/metadata.go
Normal file
@ -0,0 +1,19 @@
|
||||
package ss
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/shadowsocks/go-shadowsocks2/core"
|
||||
)
|
||||
|
||||
const (
|
||||
method = "method"
|
||||
password = "password"
|
||||
key = "key"
|
||||
readTimeout = "readTimeout"
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
cipher core.Cipher
|
||||
readTimeout time.Duration
|
||||
}
|
80
pkg/components/handler/ssu/handler.go
Normal file
80
pkg/components/handler/ssu/handler.go
Normal file
@ -0,0 +1,80 @@
|
||||
package ss
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/pkg/components/handler"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
"github.com/shadowsocks/go-shadowsocks2/core"
|
||||
ss "github.com/shadowsocks/shadowsocks-go/shadowsocks"
|
||||
)
|
||||
|
||||
var (
|
||||
_ handler.Handler = (*Handler)(nil)
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
}
|
||||
|
||||
func NewHandler(opts ...handler.Option) *Handler {
|
||||
options := &handler.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
return &Handler{
|
||||
logger: options.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) Init(md handler.Metadata) (err error) {
|
||||
h.md, err = h.parseMetadata(md)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
}
|
||||
|
||||
func (h *Handler) parseMetadata(md handler.Metadata) (m metadata, err error) {
|
||||
m.cipher, err = h.initCipher(md[method], md[password], md[key])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if v, ok := md[readTimeout]; ok {
|
||||
m.readTimeout, _ = time.ParseDuration(v)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (h *Handler) initCipher(method, password string, key string) (core.Cipher, error) {
|
||||
if method == "" && password == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
c, _ := ss.NewCipher(method, password)
|
||||
if c != nil {
|
||||
return &shadowCipher{cipher: c}, nil
|
||||
}
|
||||
|
||||
return core.PickCipher(method, []byte(key), password)
|
||||
}
|
||||
|
||||
type shadowCipher struct {
|
||||
cipher *ss.Cipher
|
||||
}
|
||||
|
||||
func (c *shadowCipher) StreamConn(conn net.Conn) net.Conn {
|
||||
return ss.NewConn(conn, c.cipher.Copy())
|
||||
}
|
||||
|
||||
func (c *shadowCipher) PacketConn(conn net.PacketConn) net.PacketConn {
|
||||
return ss.NewSecurePacketConn(conn, c.cipher.Copy())
|
||||
}
|
19
pkg/components/handler/ssu/metadata.go
Normal file
19
pkg/components/handler/ssu/metadata.go
Normal file
@ -0,0 +1,19 @@
|
||||
package ss
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/shadowsocks/go-shadowsocks2/core"
|
||||
)
|
||||
|
||||
const (
|
||||
method = "method"
|
||||
password = "password"
|
||||
key = "key"
|
||||
readTimeout = "readTimeout"
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
cipher core.Cipher
|
||||
readTimeout time.Duration
|
||||
}
|
43
pkg/components/handler/transport.go
Normal file
43
pkg/components/handler/transport.go
Normal file
@ -0,0 +1,43 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
poolBufferSize = 32 * 1024
|
||||
)
|
||||
|
||||
var (
|
||||
pool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, poolBufferSize)
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func Transport(rw1, rw2 io.ReadWriter) error {
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
errc <- copyBuffer(rw1, rw2)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
errc <- copyBuffer(rw2, rw1)
|
||||
}()
|
||||
|
||||
err := <-errc
|
||||
if err != nil && err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func copyBuffer(dst io.Writer, src io.Reader) error {
|
||||
buf := pool.Get().([]byte)
|
||||
defer pool.Put(buf)
|
||||
|
||||
_, err := io.CopyBuffer(dst, src, buf)
|
||||
return err
|
||||
}
|
34
pkg/components/internal/utils/kcp.go
Normal file
34
pkg/components/internal/utils/kcp.go
Normal file
@ -0,0 +1,34 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/golang/snappy"
|
||||
)
|
||||
|
||||
type kcpCompStreamConn struct {
|
||||
net.Conn
|
||||
w *snappy.Writer
|
||||
r *snappy.Reader
|
||||
}
|
||||
|
||||
func KCPCompStreamConn(conn net.Conn) net.Conn {
|
||||
return &kcpCompStreamConn{
|
||||
Conn: conn,
|
||||
w: snappy.NewBufferedWriter(conn),
|
||||
r: snappy.NewReader(conn),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *kcpCompStreamConn) Read(b []byte) (n int, err error) {
|
||||
return c.r.Read(b)
|
||||
}
|
||||
|
||||
func (c *kcpCompStreamConn) Write(b []byte) (n int, err error) {
|
||||
n, err = c.w.Write(b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = c.w.Flush()
|
||||
return n, err
|
||||
}
|
104
pkg/components/internal/utils/quic.go
Normal file
104
pkg/components/internal/utils/quic.go
Normal file
@ -0,0 +1,104 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
type quicConn struct {
|
||||
quic.Session
|
||||
quic.Stream
|
||||
}
|
||||
|
||||
func QUICConn(session quic.Session, stream quic.Stream) net.Conn {
|
||||
return &quicConn{
|
||||
Session: session,
|
||||
Stream: stream,
|
||||
}
|
||||
}
|
||||
|
||||
type quicCipherConn struct {
|
||||
net.PacketConn
|
||||
key []byte
|
||||
}
|
||||
|
||||
func QUICCipherConn(conn net.PacketConn, key []byte) net.PacketConn {
|
||||
return &quicCipherConn{
|
||||
PacketConn: conn,
|
||||
key: key,
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *quicCipherConn) ReadFrom(data []byte) (n int, addr net.Addr, err error) {
|
||||
n, addr, err = conn.PacketConn.ReadFrom(data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
b, err := conn.decrypt(data[:n])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
copy(data, b)
|
||||
|
||||
return len(b), addr, nil
|
||||
}
|
||||
|
||||
func (conn *quicCipherConn) WriteTo(data []byte, addr net.Addr) (n int, err error) {
|
||||
b, err := conn.encrypt(data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = conn.PacketConn.WriteTo(b, addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (conn *quicCipherConn) encrypt(data []byte) ([]byte, error) {
|
||||
c, err := aes.NewCipher(conn.key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return gcm.Seal(nonce, nonce, data, nil), nil
|
||||
}
|
||||
|
||||
func (conn *quicCipherConn) decrypt(data []byte) ([]byte, error) {
|
||||
c, err := aes.NewCipher(conn.key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(data) < nonceSize {
|
||||
return nil, errors.New("ciphertext too short")
|
||||
}
|
||||
|
||||
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
|
||||
return gcm.Open(nil, nonce, ciphertext, nil)
|
||||
}
|
32
pkg/components/internal/utils/tcp.go
Normal file
32
pkg/components/internal/utils/tcp.go
Normal file
@ -0,0 +1,32 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultKeepAlivePeriod = 180 * time.Second
|
||||
)
|
||||
|
||||
// TCPKeepAliveListener is a TCP listener with keep alive enabled.
|
||||
type TCPKeepAliveListener struct {
|
||||
KeepAlivePeriod time.Duration
|
||||
*net.TCPListener
|
||||
}
|
||||
|
||||
func (l *TCPKeepAliveListener) Accept() (c net.Conn, err error) {
|
||||
tc, err := l.AcceptTCP()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
tc.SetKeepAlive(true)
|
||||
period := l.KeepAlivePeriod
|
||||
if period <= 0 {
|
||||
period = defaultKeepAlivePeriod
|
||||
}
|
||||
tc.SetKeepAlivePeriod(period)
|
||||
|
||||
return tc, nil
|
||||
}
|
40
pkg/components/internal/utils/tls.go
Normal file
40
pkg/components/internal/utils/tls.go
Normal file
@ -0,0 +1,40 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
)
|
||||
|
||||
// LoadTLSConfig loads the certificate from cert & key files and optional client CA file.
|
||||
func LoadTLSConfig(certFile, keyFile, caFile string) (*tls.Config, error) {
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg := &tls.Config{Certificates: []tls.Certificate{cert}}
|
||||
|
||||
if pool, _ := loadCA(caFile); pool != nil {
|
||||
cfg.ClientCAs = pool
|
||||
cfg.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func loadCA(caFile string) (cp *x509.CertPool, err error) {
|
||||
if caFile == "" {
|
||||
return
|
||||
}
|
||||
cp = x509.NewCertPool()
|
||||
data, err := ioutil.ReadFile(caFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !cp.AppendCertsFromPEM(data) {
|
||||
return nil, errors.New("AppendCertsFromPEM failed")
|
||||
}
|
||||
return
|
||||
}
|
41
pkg/components/internal/utils/ws.go
Normal file
41
pkg/components/internal/utils/ws.go
Normal file
@ -0,0 +1,41 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type websocketConn struct {
|
||||
*websocket.Conn
|
||||
rb []byte
|
||||
}
|
||||
|
||||
func WebsocketServerConn(conn *websocket.Conn) net.Conn {
|
||||
return &websocketConn{
|
||||
Conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *websocketConn) Read(b []byte) (n int, err error) {
|
||||
if len(c.rb) == 0 {
|
||||
_, c.rb, err = c.ReadMessage()
|
||||
}
|
||||
n = copy(b, c.rb)
|
||||
c.rb = c.rb[n:]
|
||||
return
|
||||
}
|
||||
|
||||
func (c *websocketConn) Write(b []byte) (n int, err error) {
|
||||
err = c.WriteMessage(websocket.BinaryMessage, b)
|
||||
n = len(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *websocketConn) SetDeadline(t time.Time) error {
|
||||
if err := c.SetReadDeadline(t); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.SetWriteDeadline(t)
|
||||
}
|
115
pkg/components/listener/ftcp/conn.go
Normal file
115
pkg/components/listener/ftcp/conn.go
Normal file
@ -0,0 +1,115 @@
|
||||
package ftcp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// serverConn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
|
||||
type serverConn struct {
|
||||
net.PacketConn
|
||||
raddr net.Addr
|
||||
rc chan []byte // data receive queue
|
||||
fresh int32
|
||||
closed chan struct{}
|
||||
closeMutex sync.Mutex
|
||||
config *serverConnConfig
|
||||
}
|
||||
|
||||
type serverConnConfig struct {
|
||||
ttl time.Duration
|
||||
qsize int
|
||||
onClose func()
|
||||
}
|
||||
|
||||
func newServerConn(conn net.PacketConn, raddr net.Addr, cfg *serverConnConfig) *serverConn {
|
||||
if conn == nil || raddr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if cfg == nil {
|
||||
cfg = &serverConnConfig{}
|
||||
}
|
||||
c := &serverConn{
|
||||
PacketConn: conn,
|
||||
raddr: raddr,
|
||||
rc: make(chan []byte, cfg.qsize),
|
||||
closed: make(chan struct{}),
|
||||
config: cfg,
|
||||
}
|
||||
go c.ttlWait()
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *serverConn) send(b []byte) error {
|
||||
select {
|
||||
case c.rc <- b:
|
||||
return nil
|
||||
default:
|
||||
return errors.New("queue is full")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *serverConn) Read(b []byte) (n int, err error) {
|
||||
n, _, err = c.ReadFrom(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *serverConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||
select {
|
||||
case bb := <-c.rc:
|
||||
n = copy(b, bb)
|
||||
atomic.StoreInt32(&c.fresh, 1)
|
||||
case <-c.closed:
|
||||
err = errors.New("read from closed connection")
|
||||
return
|
||||
}
|
||||
|
||||
addr = c.raddr
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *serverConn) Write(b []byte) (n int, err error) {
|
||||
return c.WriteTo(b, c.raddr)
|
||||
}
|
||||
|
||||
func (c *serverConn) Close() error {
|
||||
c.closeMutex.Lock()
|
||||
defer c.closeMutex.Unlock()
|
||||
|
||||
select {
|
||||
case <-c.closed:
|
||||
return errors.New("connection is closed")
|
||||
default:
|
||||
if c.config.onClose != nil {
|
||||
c.config.onClose()
|
||||
}
|
||||
close(c.closed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *serverConn) RemoteAddr() net.Addr {
|
||||
return c.raddr
|
||||
}
|
||||
|
||||
func (c *serverConn) ttlWait() {
|
||||
ticker := time.NewTicker(c.config.ttl)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if !atomic.CompareAndSwapInt32(&c.fresh, 1, 0) {
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
case <-c.closed:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
162
pkg/components/listener/ftcp/listener.go
Normal file
162
pkg/components/listener/ftcp/listener.go
Normal file
@ -0,0 +1,162 @@
|
||||
package ftcp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/go-gost/gost/pkg/components/listener"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
"github.com/xtaci/tcpraw"
|
||||
)
|
||||
|
||||
var (
|
||||
_ listener.Listener = (*Listener)(nil)
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
md metadata
|
||||
conn net.PacketConn
|
||||
connChan chan net.Conn
|
||||
errChan chan error
|
||||
connPool connPool
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) *Listener {
|
||||
options := &listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
return &Listener{
|
||||
logger: options.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) Init(md listener.Metadata) (err error) {
|
||||
l.md, err = l.parseMetadata(md)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.conn, err = tcpraw.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.connChan = make(chan net.Conn, l.md.connQueueSize)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
go l.listenLoop()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.connChan:
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) Close() error {
|
||||
err := l.conn.Close()
|
||||
l.connPool.Range(func(k interface{}, v *serverConn) bool {
|
||||
v.Close()
|
||||
return true
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (l *Listener) Addr() net.Addr {
|
||||
return l.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (l *Listener) listenLoop() {
|
||||
for {
|
||||
b := make([]byte, l.md.readBufferSize)
|
||||
|
||||
n, raddr, err := l.conn.ReadFrom(b)
|
||||
if err != nil {
|
||||
l.logger.Error("accept:", err)
|
||||
l.errChan <- err
|
||||
close(l.errChan)
|
||||
return
|
||||
}
|
||||
|
||||
conn, ok := l.connPool.Get(raddr.String())
|
||||
if !ok {
|
||||
conn = newServerConn(l.conn, raddr,
|
||||
&serverConnConfig{
|
||||
ttl: l.md.ttl,
|
||||
qsize: l.md.readQueueSize,
|
||||
onClose: func() {
|
||||
l.connPool.Delete(raddr.String())
|
||||
},
|
||||
})
|
||||
|
||||
select {
|
||||
case l.connChan <- conn:
|
||||
l.connPool.Set(raddr.String(), conn)
|
||||
default:
|
||||
conn.Close()
|
||||
l.logger.Error("connection queue is full")
|
||||
}
|
||||
}
|
||||
|
||||
if err := conn.send(b[:n]); err != nil {
|
||||
l.logger.Warn("data discarded:", err)
|
||||
}
|
||||
l.logger.Debug("recv", n)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
|
||||
if val, ok := md[addr]; ok {
|
||||
m.addr = val
|
||||
} else {
|
||||
err = errors.New("missing address")
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type connPool struct {
|
||||
size int64
|
||||
m sync.Map
|
||||
}
|
||||
|
||||
func (p *connPool) Get(key interface{}) (conn *serverConn, ok bool) {
|
||||
v, ok := p.m.Load(key)
|
||||
if ok {
|
||||
conn, ok = v.(*serverConn)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (p *connPool) Set(key interface{}, conn *serverConn) {
|
||||
p.m.Store(key, conn)
|
||||
atomic.AddInt64(&p.size, 1)
|
||||
}
|
||||
|
||||
func (p *connPool) Delete(key interface{}) {
|
||||
p.m.Delete(key)
|
||||
atomic.AddInt64(&p.size, -1)
|
||||
}
|
||||
|
||||
func (p *connPool) Range(f func(key interface{}, value *serverConn) bool) {
|
||||
p.m.Range(func(k, v interface{}) bool {
|
||||
return f(k, v.(*serverConn))
|
||||
})
|
||||
}
|
||||
|
||||
func (p *connPool) Size() int64 {
|
||||
return atomic.LoadInt64(&p.size)
|
||||
}
|
23
pkg/components/listener/ftcp/metadata.go
Normal file
23
pkg/components/listener/ftcp/metadata.go
Normal file
@ -0,0 +1,23 @@
|
||||
package ftcp
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
defaultTTL = 60 * time.Second
|
||||
defaultReadBufferSize = 1024
|
||||
defaultReadQueueSize = 128
|
||||
defaultConnQueueSize = 128
|
||||
)
|
||||
|
||||
const (
|
||||
addr = "addr"
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
addr string
|
||||
ttl time.Duration
|
||||
|
||||
readBufferSize int
|
||||
readQueueSize int
|
||||
connQueueSize int
|
||||
}
|
54
pkg/components/listener/http2/conn.go
Normal file
54
pkg/components/listener/http2/conn.go
Normal file
@ -0,0 +1,54 @@
|
||||
package http2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// a dummy HTTP2 server conn used by HTTP2 handler
|
||||
type conn struct {
|
||||
r *http.Request
|
||||
w http.ResponseWriter
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func (c *conn) Read(b []byte) (n int, err error) {
|
||||
return 0, &net.OpError{Op: "read", Net: "http2", Source: nil, Addr: nil, Err: errors.New("read not supported")}
|
||||
}
|
||||
|
||||
func (c *conn) Write(b []byte) (n int, err error) {
|
||||
return 0, &net.OpError{Op: "write", Net: "http2", Source: nil, Addr: nil, Err: errors.New("write not supported")}
|
||||
}
|
||||
|
||||
func (c *conn) Close() error {
|
||||
select {
|
||||
case <-c.closed:
|
||||
default:
|
||||
close(c.closed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) LocalAddr() net.Addr {
|
||||
addr, _ := net.ResolveTCPAddr("tcp", c.r.Host)
|
||||
return addr
|
||||
}
|
||||
|
||||
func (c *conn) RemoteAddr() net.Addr {
|
||||
addr, _ := net.ResolveTCPAddr("tcp", c.r.RemoteAddr)
|
||||
return addr
|
||||
}
|
||||
|
||||
func (c *conn) SetDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *conn) SetReadDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *conn) SetWriteDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
89
pkg/components/listener/http2/h2/conn.go
Normal file
89
pkg/components/listener/http2/h2/conn.go
Normal file
@ -0,0 +1,89 @@
|
||||
package h2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HTTP2 connection, wrapped up just like a net.Conn
|
||||
type conn struct {
|
||||
r io.Reader
|
||||
w io.Writer
|
||||
remoteAddr net.Addr
|
||||
localAddr net.Addr
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func (c *conn) Read(b []byte) (n int, err error) {
|
||||
return c.r.Read(b)
|
||||
}
|
||||
|
||||
func (c *conn) Write(b []byte) (n int, err error) {
|
||||
return c.w.Write(b)
|
||||
}
|
||||
|
||||
func (c *conn) Close() (err error) {
|
||||
select {
|
||||
case <-c.closed:
|
||||
return
|
||||
default:
|
||||
close(c.closed)
|
||||
}
|
||||
if rc, ok := c.r.(io.Closer); ok {
|
||||
err = rc.Close()
|
||||
}
|
||||
if w, ok := c.w.(io.Closer); ok {
|
||||
err = w.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) LocalAddr() net.Addr {
|
||||
return c.localAddr
|
||||
}
|
||||
|
||||
func (c *conn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *conn) SetDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *conn) SetReadDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *conn) SetWriteDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
type flushWriter struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func (fw flushWriter) Write(p []byte) (n int, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if s, ok := r.(string); ok {
|
||||
err = errors.New(s)
|
||||
// log.Log("[http2]", err)
|
||||
return
|
||||
}
|
||||
err = r.(error)
|
||||
}
|
||||
}()
|
||||
|
||||
n, err = fw.w.Write(p)
|
||||
if err != nil {
|
||||
// log.Log("flush writer:", err)
|
||||
return
|
||||
}
|
||||
if f, ok := fw.w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
return
|
||||
}
|
186
pkg/components/listener/http2/h2/listener.go
Normal file
186
pkg/components/listener/http2/h2/listener.go
Normal file
@ -0,0 +1,186 @@
|
||||
package h2
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/pkg/components/internal/utils"
|
||||
"github.com/go-gost/gost/pkg/components/listener"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
var (
|
||||
_ listener.Listener = (*Listener)(nil)
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
net.Listener
|
||||
md metadata
|
||||
server *http2.Server
|
||||
connChan chan net.Conn
|
||||
errChan chan error
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) *Listener {
|
||||
options := &listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
return &Listener{
|
||||
logger: options.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) Init(md listener.Metadata) (err error) {
|
||||
l.md, err = l.parseMetadata(md)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", l.md.addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
l.Listener = &utils.TCPKeepAliveListener{
|
||||
TCPListener: ln.(*net.TCPListener),
|
||||
KeepAlivePeriod: l.md.keepAlivePeriod,
|
||||
}
|
||||
// TODO: tune http2 server config
|
||||
l.server = &http2.Server{
|
||||
// MaxConcurrentStreams: 1000,
|
||||
PermitProhibitedCipherSuites: true,
|
||||
IdleTimeout: 5 * time.Minute,
|
||||
}
|
||||
|
||||
queueSize := l.md.connQueueSize
|
||||
if queueSize <= 0 {
|
||||
queueSize = defaultQueueSize
|
||||
}
|
||||
l.connChan = make(chan net.Conn, queueSize)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
go l.listenLoop()
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.connChan:
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) listenLoop() {
|
||||
for {
|
||||
conn, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
// log.Log("[http2] accept:", err)
|
||||
l.errChan <- err
|
||||
close(l.errChan)
|
||||
return
|
||||
}
|
||||
go l.handleLoop(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) handleLoop(conn net.Conn) {
|
||||
if l.md.tlsConfig != nil {
|
||||
tlsConn := tls.Server(conn, l.md.tlsConfig)
|
||||
// NOTE: HTTP2 server will check the TLS version,
|
||||
// so we must ensure that the TLS connection is handshake completed.
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
// log.Logf("[http2] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err)
|
||||
return
|
||||
}
|
||||
conn = tlsConn
|
||||
}
|
||||
|
||||
opt := http2.ServeConnOpts{
|
||||
Handler: http.HandlerFunc(l.handleFunc),
|
||||
}
|
||||
l.server.ServeConn(conn, &opt)
|
||||
}
|
||||
|
||||
func (l *Listener) handleFunc(w http.ResponseWriter, r *http.Request) {
|
||||
/*
|
||||
log.Logf("[http2] %s -> %s %s %s %s",
|
||||
r.RemoteAddr, r.Host, r.Method, r.RequestURI, r.Proto)
|
||||
if Debug {
|
||||
dump, _ := httputil.DumpRequest(r, false)
|
||||
log.Log("[http2]", string(dump))
|
||||
}
|
||||
*/
|
||||
// w.Header().Set("Proxy-Agent", "gost/"+Version)
|
||||
conn, err := l.upgrade(w, r)
|
||||
if err != nil {
|
||||
// log.Logf("[http2] %s - %s %s %s %s: %s",
|
||||
// r.RemoteAddr, r.Host, r.Method, r.RequestURI, r.Proto, err)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case l.connChan <- conn:
|
||||
default:
|
||||
conn.Close()
|
||||
// log.Logf("[http2] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}
|
||||
|
||||
<-conn.closed // NOTE: we need to wait for streaming end, or the connection will be closed
|
||||
}
|
||||
|
||||
func (l *Listener) upgrade(w http.ResponseWriter, r *http.Request) (*conn, error) {
|
||||
if l.md.path == "" && r.Method != http.MethodConnect {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return nil, errors.New("method not allowed")
|
||||
}
|
||||
|
||||
if l.md.path != "" && r.RequestURI != l.md.path {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return nil, errors.New("bad request")
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if fw, ok := w.(http.Flusher); ok {
|
||||
fw.Flush() // write header to client
|
||||
}
|
||||
|
||||
remoteAddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
|
||||
if remoteAddr == nil {
|
||||
remoteAddr = &net.TCPAddr{
|
||||
IP: net.IPv4zero,
|
||||
Port: 0,
|
||||
}
|
||||
}
|
||||
return &conn{
|
||||
r: r.Body,
|
||||
w: flushWriter{w},
|
||||
localAddr: l.Listener.Addr(),
|
||||
remoteAddr: remoteAddr,
|
||||
closed: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
|
||||
if val, ok := md[addr]; ok {
|
||||
m.addr = val
|
||||
} else {
|
||||
err = errors.New("missing address")
|
||||
return
|
||||
}
|
||||
|
||||
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
38
pkg/components/listener/http2/h2/metadata.go
Normal file
38
pkg/components/listener/http2/h2/metadata.go
Normal file
@ -0,0 +1,38 @@
|
||||
package h2
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
addr = "addr"
|
||||
path = "path"
|
||||
certFile = "certFile"
|
||||
keyFile = "keyFile"
|
||||
caFile = "caFile"
|
||||
handshakeTimeout = "handshakeTimeout"
|
||||
readHeaderTimeout = "readHeaderTimeout"
|
||||
readBufferSize = "readBufferSize"
|
||||
writeBufferSize = "writeBufferSize"
|
||||
connQueueSize = "connQueueSize"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultQueueSize = 128
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
addr string
|
||||
path string
|
||||
tlsConfig *tls.Config
|
||||
handshakeTimeout time.Duration
|
||||
readHeaderTimeout time.Duration
|
||||
readBufferSize int
|
||||
writeBufferSize int
|
||||
enableCompression bool
|
||||
responseHeader http.Header
|
||||
connQueueSize int
|
||||
keepAlivePeriod time.Duration
|
||||
}
|
140
pkg/components/listener/http2/listener.go
Normal file
140
pkg/components/listener/http2/listener.go
Normal file
@ -0,0 +1,140 @@
|
||||
package http2
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-gost/gost/pkg/components/internal/utils"
|
||||
"github.com/go-gost/gost/pkg/components/listener"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
var (
|
||||
_ listener.Listener = (*Listener)(nil)
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
md metadata
|
||||
server *http.Server
|
||||
addr net.Addr
|
||||
connChan chan *conn
|
||||
errChan chan error
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) *Listener {
|
||||
options := &listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
return &Listener{
|
||||
logger: options.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) Init(md listener.Metadata) (err error) {
|
||||
l.md, err = l.parseMetadata(md)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.server = &http.Server{
|
||||
Addr: l.md.addr,
|
||||
Handler: http.HandlerFunc(l.handleFunc),
|
||||
TLSConfig: l.md.tlsConfig,
|
||||
}
|
||||
if err := http2.ConfigureServer(l.server, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l.addr = ln.Addr()
|
||||
|
||||
ln = tls.NewListener(
|
||||
&utils.TCPKeepAliveListener{
|
||||
TCPListener: ln.(*net.TCPListener),
|
||||
KeepAlivePeriod: l.md.keepAlivePeriod,
|
||||
},
|
||||
l.md.tlsConfig,
|
||||
)
|
||||
|
||||
queueSize := l.md.connQueueSize
|
||||
if queueSize <= 0 {
|
||||
queueSize = defaultQueueSize
|
||||
}
|
||||
l.connChan = make(chan *conn, queueSize)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
if err := l.server.Serve(ln); err != nil {
|
||||
// log.Log("[http2]", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.connChan:
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
||||
|
||||
func (l *Listener) Close() (err error) {
|
||||
select {
|
||||
case <-l.errChan:
|
||||
default:
|
||||
err = l.server.Close()
|
||||
l.errChan <- err
|
||||
close(l.errChan)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Listener) handleFunc(w http.ResponseWriter, r *http.Request) {
|
||||
conn := &conn{
|
||||
r: r,
|
||||
w: w,
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
select {
|
||||
case l.connChan <- conn:
|
||||
default:
|
||||
// log.Logf("[http2] %s - %s: connection queue is full", r.RemoteAddr, l.server.Addr)
|
||||
return
|
||||
}
|
||||
|
||||
<-conn.closed
|
||||
}
|
||||
|
||||
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
|
||||
if val, ok := md[addr]; ok {
|
||||
m.addr = val
|
||||
} else {
|
||||
err = errors.New("missing address")
|
||||
return
|
||||
}
|
||||
|
||||
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
38
pkg/components/listener/http2/metadata.go
Normal file
38
pkg/components/listener/http2/metadata.go
Normal file
@ -0,0 +1,38 @@
|
||||
package http2
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
addr = "addr"
|
||||
path = "path"
|
||||
certFile = "certFile"
|
||||
keyFile = "keyFile"
|
||||
caFile = "caFile"
|
||||
handshakeTimeout = "handshakeTimeout"
|
||||
readHeaderTimeout = "readHeaderTimeout"
|
||||
readBufferSize = "readBufferSize"
|
||||
writeBufferSize = "writeBufferSize"
|
||||
connQueueSize = "connQueueSize"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultQueueSize = 128
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
addr string
|
||||
path string
|
||||
tlsConfig *tls.Config
|
||||
handshakeTimeout time.Duration
|
||||
readHeaderTimeout time.Duration
|
||||
readBufferSize int
|
||||
writeBufferSize int
|
||||
enableCompression bool
|
||||
responseHeader http.Header
|
||||
connQueueSize int
|
||||
keepAlivePeriod time.Duration
|
||||
}
|
115
pkg/components/listener/kcp/config.go
Normal file
115
pkg/components/listener/kcp/config.go
Normal file
@ -0,0 +1,115 @@
|
||||
package kcp
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
|
||||
"github.com/xtaci/kcp-go/v5"
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
)
|
||||
|
||||
var (
|
||||
// Salt is the default salt for KCP cipher.
|
||||
Salt = "kcp-go"
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultKCPConfig is the default KCP config.
|
||||
DefaultConfig = &Config{
|
||||
Key: "it's a secrect",
|
||||
Crypt: "aes",
|
||||
Mode: "fast",
|
||||
MTU: 1350,
|
||||
SndWnd: 1024,
|
||||
RcvWnd: 1024,
|
||||
DataShard: 10,
|
||||
ParityShard: 3,
|
||||
DSCP: 0,
|
||||
NoComp: false,
|
||||
AckNodelay: false,
|
||||
NoDelay: 0,
|
||||
Interval: 50,
|
||||
Resend: 0,
|
||||
NoCongestion: 0,
|
||||
SockBuf: 4194304,
|
||||
KeepAlive: 10,
|
||||
SnmpLog: "",
|
||||
SnmpPeriod: 60,
|
||||
Signal: false,
|
||||
TCP: false,
|
||||
}
|
||||
)
|
||||
|
||||
// KCPConfig describes the config for KCP.
|
||||
type Config struct {
|
||||
Key string `json:"key"`
|
||||
Crypt string `json:"crypt"`
|
||||
Mode string `json:"mode"`
|
||||
MTU int `json:"mtu"`
|
||||
SndWnd int `json:"sndwnd"`
|
||||
RcvWnd int `json:"rcvwnd"`
|
||||
DataShard int `json:"datashard"`
|
||||
ParityShard int `json:"parityshard"`
|
||||
DSCP int `json:"dscp"`
|
||||
NoComp bool `json:"nocomp"`
|
||||
AckNodelay bool `json:"acknodelay"`
|
||||
NoDelay int `json:"nodelay"`
|
||||
Interval int `json:"interval"`
|
||||
Resend int `json:"resend"`
|
||||
NoCongestion int `json:"nc"`
|
||||
SockBuf int `json:"sockbuf"`
|
||||
KeepAlive int `json:"keepalive"`
|
||||
SnmpLog string `json:"snmplog"`
|
||||
SnmpPeriod int `json:"snmpperiod"`
|
||||
Signal bool `json:"signal"` // Signal enables the signal SIGUSR1 feature.
|
||||
TCP bool `json:"tcp"`
|
||||
}
|
||||
|
||||
// Init initializes the KCP config.
|
||||
func (c *Config) Init() {
|
||||
switch c.Mode {
|
||||
case "normal":
|
||||
c.NoDelay, c.Interval, c.Resend, c.NoCongestion = 0, 40, 2, 1
|
||||
case "fast":
|
||||
c.NoDelay, c.Interval, c.Resend, c.NoCongestion = 0, 30, 2, 1
|
||||
case "fast2":
|
||||
c.NoDelay, c.Interval, c.Resend, c.NoCongestion = 1, 20, 2, 1
|
||||
case "fast3":
|
||||
c.NoDelay, c.Interval, c.Resend, c.NoCongestion = 1, 10, 2, 1
|
||||
}
|
||||
}
|
||||
|
||||
func blockCrypt(key, crypt, salt string) (block kcp.BlockCrypt) {
|
||||
pass := pbkdf2.Key([]byte(key), []byte(salt), 4096, 32, sha1.New)
|
||||
|
||||
switch crypt {
|
||||
case "sm4":
|
||||
block, _ = kcp.NewSM4BlockCrypt(pass[:16])
|
||||
case "tea":
|
||||
block, _ = kcp.NewTEABlockCrypt(pass[:16])
|
||||
case "xor":
|
||||
block, _ = kcp.NewSimpleXORBlockCrypt(pass)
|
||||
case "none":
|
||||
block, _ = kcp.NewNoneBlockCrypt(pass)
|
||||
case "aes-128":
|
||||
block, _ = kcp.NewAESBlockCrypt(pass[:16])
|
||||
case "aes-192":
|
||||
block, _ = kcp.NewAESBlockCrypt(pass[:24])
|
||||
case "blowfish":
|
||||
block, _ = kcp.NewBlowfishBlockCrypt(pass)
|
||||
case "twofish":
|
||||
block, _ = kcp.NewTwofishBlockCrypt(pass)
|
||||
case "cast5":
|
||||
block, _ = kcp.NewCast5BlockCrypt(pass[:16])
|
||||
case "3des":
|
||||
block, _ = kcp.NewTripleDESBlockCrypt(pass[:24])
|
||||
case "xtea":
|
||||
block, _ = kcp.NewXTEABlockCrypt(pass[:16])
|
||||
case "salsa20":
|
||||
block, _ = kcp.NewSalsa20BlockCrypt(pass)
|
||||
case "aes":
|
||||
fallthrough
|
||||
default: // aes
|
||||
block, _ = kcp.NewAESBlockCrypt(pass)
|
||||
}
|
||||
return
|
||||
}
|
179
pkg/components/listener/kcp/listener.go
Normal file
179
pkg/components/listener/kcp/listener.go
Normal file
@ -0,0 +1,179 @@
|
||||
package kcp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/pkg/components/internal/utils"
|
||||
"github.com/go-gost/gost/pkg/components/listener"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
"github.com/xtaci/kcp-go/v5"
|
||||
"github.com/xtaci/smux"
|
||||
"github.com/xtaci/tcpraw"
|
||||
)
|
||||
|
||||
var (
|
||||
_ listener.Listener = (*Listener)(nil)
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
md metadata
|
||||
ln *kcp.Listener
|
||||
connChan chan net.Conn
|
||||
errChan chan error
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) *Listener {
|
||||
options := &listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
return &Listener{
|
||||
logger: options.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) Init(md listener.Metadata) (err error) {
|
||||
l.md, err = l.parseMetadata(md)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
config := l.md.config
|
||||
if config == nil {
|
||||
config = DefaultConfig
|
||||
}
|
||||
config.Init()
|
||||
|
||||
var ln *kcp.Listener
|
||||
|
||||
if config.TCP {
|
||||
var conn net.PacketConn
|
||||
conn, err = tcpraw.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ln, err = kcp.ServeConn(
|
||||
blockCrypt(config.Key, config.Crypt, Salt), config.DataShard, config.ParityShard, conn)
|
||||
} else {
|
||||
ln, err = kcp.ListenWithOptions(addr,
|
||||
blockCrypt(config.Key, config.Crypt, Salt), config.DataShard, config.ParityShard)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if config.DSCP > 0 {
|
||||
if err = ln.SetDSCP(config.DSCP); err != nil {
|
||||
l.logger.Warn(err)
|
||||
}
|
||||
}
|
||||
if err = ln.SetReadBuffer(config.SockBuf); err != nil {
|
||||
l.logger.Warn(err)
|
||||
}
|
||||
if err = ln.SetWriteBuffer(config.SockBuf); err != nil {
|
||||
l.logger.Warn(err)
|
||||
}
|
||||
|
||||
l.ln = ln
|
||||
l.connChan = make(chan net.Conn, l.md.connQueueSize)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
go l.listenLoop()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.connChan:
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) Close() error {
|
||||
return l.ln.Close()
|
||||
}
|
||||
|
||||
func (l *Listener) Addr() net.Addr {
|
||||
return l.ln.Addr()
|
||||
}
|
||||
|
||||
func (l *Listener) listenLoop() {
|
||||
for {
|
||||
conn, err := l.ln.AcceptKCP()
|
||||
if err != nil {
|
||||
l.logger.Error("accept:", err)
|
||||
l.errChan <- err
|
||||
close(l.errChan)
|
||||
return
|
||||
}
|
||||
|
||||
conn.SetStreamMode(true)
|
||||
conn.SetWriteDelay(false)
|
||||
conn.SetNoDelay(
|
||||
l.md.config.NoDelay,
|
||||
l.md.config.Interval,
|
||||
l.md.config.Resend,
|
||||
l.md.config.NoCongestion,
|
||||
)
|
||||
conn.SetMtu(l.md.config.MTU)
|
||||
conn.SetWindowSize(l.md.config.SndWnd, l.md.config.RcvWnd)
|
||||
conn.SetACKNoDelay(l.md.config.AckNodelay)
|
||||
go l.mux(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) mux(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
smuxConfig := smux.DefaultConfig()
|
||||
smuxConfig.MaxReceiveBuffer = l.md.config.SockBuf
|
||||
smuxConfig.KeepAliveInterval = time.Duration(l.md.config.KeepAlive) * time.Second
|
||||
|
||||
if !l.md.config.NoComp {
|
||||
conn = utils.KCPCompStreamConn(conn)
|
||||
}
|
||||
|
||||
mux, err := smux.Server(conn, smuxConfig)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
return
|
||||
}
|
||||
defer mux.Close()
|
||||
|
||||
for {
|
||||
stream, err := mux.AcceptStream()
|
||||
if err != nil {
|
||||
l.logger.Error("accept stream:", err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case l.connChan <- stream:
|
||||
case <-stream.GetDieCh():
|
||||
stream.Close()
|
||||
default:
|
||||
stream.Close()
|
||||
l.logger.Error("connection queue is full")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
|
||||
if val, ok := md[addr]; ok {
|
||||
m.addr = val
|
||||
} else {
|
||||
err = errors.New("missing address")
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
18
pkg/components/listener/kcp/metadata.go
Normal file
18
pkg/components/listener/kcp/metadata.go
Normal file
@ -0,0 +1,18 @@
|
||||
package kcp
|
||||
|
||||
const (
|
||||
addr = "addr"
|
||||
|
||||
connQueueSize = "connQueueSize"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultQueueSize = 128
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
addr string
|
||||
config *Config
|
||||
|
||||
connQueueSize int
|
||||
}
|
20
pkg/components/listener/listener.go
Normal file
20
pkg/components/listener/listener.go
Normal file
@ -0,0 +1,20 @@
|
||||
package listener
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrClosed = errors.New("accpet on closed listener")
|
||||
)
|
||||
|
||||
// Listener is a server listener, just like a net.Listener.
|
||||
type Listener interface {
|
||||
net.Listener
|
||||
}
|
||||
|
||||
// Accepter represents a network endpoint that can accept connection from peer.
|
||||
type Accepter interface {
|
||||
Accept() (net.Conn, error)
|
||||
}
|
3
pkg/components/listener/metadata.go
Normal file
3
pkg/components/listener/metadata.go
Normal file
@ -0,0 +1,3 @@
|
||||
package listener
|
||||
|
||||
type Metadata map[string]string
|
140
pkg/components/listener/obfs/http/conn.go
Normal file
140
pkg/components/listener/obfs/http/conn.go
Normal file
@ -0,0 +1,140 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type conn struct {
|
||||
net.Conn
|
||||
rbuf bytes.Buffer
|
||||
wbuf bytes.Buffer
|
||||
handshaked bool
|
||||
handshakeMutex sync.Mutex
|
||||
}
|
||||
|
||||
func (c *conn) Handshake() (err error) {
|
||||
c.handshakeMutex.Lock()
|
||||
defer c.handshakeMutex.Unlock()
|
||||
|
||||
if c.handshaked {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err = c.handshake(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.handshaked = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) handshake() (err error) {
|
||||
br := bufio.NewReader(c.Conn)
|
||||
r, err := http.ReadRequest(br)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
/*
|
||||
if Debug {
|
||||
dump, _ := httputil.DumpRequest(r, false)
|
||||
log.Logf("[ohttp] %s -> %s\n%s", c.RemoteAddr(), c.LocalAddr(), string(dump))
|
||||
}
|
||||
*/
|
||||
|
||||
if r.ContentLength > 0 {
|
||||
_, err = io.Copy(&c.rbuf, r.Body)
|
||||
} else {
|
||||
var b []byte
|
||||
b, err = br.Peek(br.Buffered())
|
||||
if len(b) > 0 {
|
||||
_, err = c.rbuf.Write(b)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// log.Logf("[ohttp] %s -> %s : %v", c.Conn.RemoteAddr(), c.Conn.LocalAddr(), err)
|
||||
return
|
||||
}
|
||||
|
||||
b := bytes.Buffer{}
|
||||
|
||||
if r.Method != http.MethodGet || r.Header.Get("Upgrade") != "websocket" {
|
||||
b.WriteString("HTTP/1.1 503 Service Unavailable\r\n")
|
||||
b.WriteString("Content-Length: 0\r\n")
|
||||
b.WriteString("Date: " + time.Now().Format(time.RFC1123) + "\r\n")
|
||||
b.WriteString("\r\n")
|
||||
|
||||
/*
|
||||
if Debug {
|
||||
log.Logf("[ohttp] %s <- %s\n%s", c.RemoteAddr(), c.LocalAddr(), b.String())
|
||||
}
|
||||
*/
|
||||
|
||||
b.WriteTo(c.Conn)
|
||||
return errors.New("bad request")
|
||||
}
|
||||
|
||||
b.WriteString("HTTP/1.1 101 Switching Protocols\r\n")
|
||||
b.WriteString("Server: nginx/1.10.0\r\n")
|
||||
b.WriteString("Date: " + time.Now().Format(time.RFC1123) + "\r\n")
|
||||
b.WriteString("Connection: Upgrade\r\n")
|
||||
b.WriteString("Upgrade: websocket\r\n")
|
||||
b.WriteString(fmt.Sprintf("Sec-WebSocket-Accept: %s\r\n", computeAcceptKey(r.Header.Get("Sec-WebSocket-Key"))))
|
||||
b.WriteString("\r\n")
|
||||
|
||||
/*
|
||||
if Debug {
|
||||
log.Logf("[ohttp] %s <- %s\n%s", c.RemoteAddr(), c.LocalAddr(), b.String())
|
||||
}
|
||||
*/
|
||||
|
||||
if c.rbuf.Len() > 0 {
|
||||
c.wbuf = b // cache the response header if there are extra data in the request body.
|
||||
return
|
||||
}
|
||||
|
||||
_, err = b.WriteTo(c.Conn)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) Read(b []byte) (n int, err error) {
|
||||
if err = c.Handshake(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if c.rbuf.Len() > 0 {
|
||||
return c.rbuf.Read(b)
|
||||
}
|
||||
return c.Conn.Read(b)
|
||||
}
|
||||
|
||||
func (c *conn) Write(b []byte) (n int, err error) {
|
||||
if err = c.Handshake(); err != nil {
|
||||
return
|
||||
}
|
||||
if c.wbuf.Len() > 0 {
|
||||
c.wbuf.Write(b) // append the data to the cached header
|
||||
_, err = c.wbuf.WriteTo(c.Conn)
|
||||
n = len(b) // exclude the header length
|
||||
return
|
||||
}
|
||||
return c.Conn.Write(b)
|
||||
}
|
||||
|
||||
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
||||
|
||||
func computeAcceptKey(challengeKey string) string {
|
||||
h := sha1.New()
|
||||
h.Write([]byte(challengeKey))
|
||||
h.Write(keyGUID)
|
||||
return base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||
}
|
88
pkg/components/listener/obfs/http/listener.go
Normal file
88
pkg/components/listener/obfs/http/listener.go
Normal file
@ -0,0 +1,88 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/pkg/components/internal/utils"
|
||||
"github.com/go-gost/gost/pkg/components/listener"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
_ listener.Listener = (*Listener)(nil)
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
md metadata
|
||||
net.Listener
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) *Listener {
|
||||
options := &listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
return &Listener{
|
||||
logger: options.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) Init(md listener.Metadata) (err error) {
|
||||
l.md, err = l.parseMetadata(md)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
laddr, err := net.ResolveTCPAddr("tcp", l.md.addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ln, err := net.ListenTCP("tcp", laddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if l.md.keepAlive {
|
||||
l.Listener = &utils.TCPKeepAliveListener{
|
||||
TCPListener: ln,
|
||||
KeepAlivePeriod: l.md.keepAlivePeriod,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
l.Listener = ln
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) Accept() (net.Conn, error) {
|
||||
c, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &conn{Conn: c}, nil
|
||||
}
|
||||
|
||||
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
|
||||
if val, ok := md[addr]; ok {
|
||||
m.addr = val
|
||||
} else {
|
||||
err = errors.New("missing address")
|
||||
return
|
||||
}
|
||||
|
||||
m.keepAlive = true
|
||||
if val, ok := md[keepAlive]; ok {
|
||||
m.keepAlive, _ = strconv.ParseBool(val)
|
||||
}
|
||||
|
||||
if val, ok := md[keepAlivePeriod]; ok {
|
||||
m.keepAlivePeriod, _ = time.ParseDuration(val)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
19
pkg/components/listener/obfs/http/metadata.go
Normal file
19
pkg/components/listener/obfs/http/metadata.go
Normal file
@ -0,0 +1,19 @@
|
||||
package http
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
addr = "addr"
|
||||
keepAlive = "keepAlive"
|
||||
keepAlivePeriod = "keepAlivePeriod"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultKeepAlivePeriod = 180 * time.Second
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
addr string
|
||||
keepAlive bool
|
||||
keepAlivePeriod time.Duration
|
||||
}
|
306
pkg/components/listener/obfs/tls/conn.go
Normal file
306
pkg/components/listener/obfs/tls/conn.go
Normal file
@ -0,0 +1,306 @@
|
||||
package tls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
dissector "github.com/ginuerzh/tls-dissector"
|
||||
)
|
||||
|
||||
const (
|
||||
maxTLSDataLen = 16384
|
||||
)
|
||||
|
||||
var (
|
||||
cipherSuites = []uint16{
|
||||
0xc02c, 0xc030, 0x009f, 0xcca9, 0xcca8, 0xccaa, 0xc02b, 0xc02f,
|
||||
0x009e, 0xc024, 0xc028, 0x006b, 0xc023, 0xc027, 0x0067, 0xc00a,
|
||||
0xc014, 0x0039, 0xc009, 0xc013, 0x0033, 0x009d, 0x009c, 0x003d,
|
||||
0x003c, 0x0035, 0x002f, 0x00ff,
|
||||
}
|
||||
|
||||
compressionMethods = []uint8{0x00}
|
||||
|
||||
algorithms = []uint16{
|
||||
0x0601, 0x0602, 0x0603, 0x0501, 0x0502, 0x0503, 0x0401, 0x0402,
|
||||
0x0403, 0x0301, 0x0302, 0x0303, 0x0201, 0x0202, 0x0203,
|
||||
}
|
||||
|
||||
tlsRecordTypes = []uint8{0x16, 0x14, 0x16, 0x17}
|
||||
tlsVersionMinors = []uint8{0x01, 0x03, 0x03, 0x03}
|
||||
|
||||
ErrBadType = errors.New("bad type")
|
||||
ErrBadMajorVersion = errors.New("bad major version")
|
||||
ErrBadMinorVersion = errors.New("bad minor version")
|
||||
ErrMaxDataLen = errors.New("bad tls data len")
|
||||
)
|
||||
|
||||
const (
|
||||
tlsRecordStateType = iota
|
||||
tlsRecordStateVersion0
|
||||
tlsRecordStateVersion1
|
||||
tlsRecordStateLength0
|
||||
tlsRecordStateLength1
|
||||
tlsRecordStateData
|
||||
)
|
||||
|
||||
type obfsTLSParser struct {
|
||||
step uint8
|
||||
state uint8
|
||||
length uint16
|
||||
}
|
||||
|
||||
func (r *obfsTLSParser) Parse(b []byte) (int, error) {
|
||||
i := 0
|
||||
last := 0
|
||||
length := len(b)
|
||||
|
||||
for i < length {
|
||||
ch := b[i]
|
||||
switch r.state {
|
||||
case tlsRecordStateType:
|
||||
if tlsRecordTypes[r.step] != ch {
|
||||
return 0, ErrBadType
|
||||
}
|
||||
r.state = tlsRecordStateVersion0
|
||||
i++
|
||||
case tlsRecordStateVersion0:
|
||||
if ch != 0x03 {
|
||||
return 0, ErrBadMajorVersion
|
||||
}
|
||||
r.state = tlsRecordStateVersion1
|
||||
i++
|
||||
case tlsRecordStateVersion1:
|
||||
if ch != tlsVersionMinors[r.step] {
|
||||
return 0, ErrBadMinorVersion
|
||||
}
|
||||
r.state = tlsRecordStateLength0
|
||||
i++
|
||||
case tlsRecordStateLength0:
|
||||
r.length = uint16(ch) << 8
|
||||
r.state = tlsRecordStateLength1
|
||||
i++
|
||||
case tlsRecordStateLength1:
|
||||
r.length |= uint16(ch)
|
||||
if r.step == 0 {
|
||||
r.length = 91
|
||||
} else if r.step == 1 {
|
||||
r.length = 1
|
||||
} else if r.length > maxTLSDataLen {
|
||||
return 0, ErrMaxDataLen
|
||||
}
|
||||
if r.length > 0 {
|
||||
r.state = tlsRecordStateData
|
||||
} else {
|
||||
r.state = tlsRecordStateType
|
||||
r.step++
|
||||
}
|
||||
i++
|
||||
case tlsRecordStateData:
|
||||
left := uint16(length - i)
|
||||
if left > r.length {
|
||||
left = r.length
|
||||
}
|
||||
if r.step >= 2 {
|
||||
skip := i - last
|
||||
copy(b[last:], b[i:length])
|
||||
length -= int(skip)
|
||||
last += int(left)
|
||||
i = last
|
||||
} else {
|
||||
i += int(left)
|
||||
}
|
||||
r.length -= left
|
||||
if r.length == 0 {
|
||||
if r.step < 3 {
|
||||
r.step++
|
||||
}
|
||||
r.state = tlsRecordStateType
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if last == 0 {
|
||||
return 0, nil
|
||||
} else if last < length {
|
||||
length -= last
|
||||
}
|
||||
|
||||
return length, nil
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
net.Conn
|
||||
rbuf bytes.Buffer
|
||||
wbuf bytes.Buffer
|
||||
host string
|
||||
handshaked chan struct{}
|
||||
parser *obfsTLSParser
|
||||
handshakeMutex sync.Mutex
|
||||
}
|
||||
|
||||
// newConn creates a connection for obfs-tls server.
|
||||
func newConn(c net.Conn, host string) net.Conn {
|
||||
return &conn{
|
||||
Conn: c,
|
||||
host: host,
|
||||
handshaked: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conn) Handshaked() bool {
|
||||
select {
|
||||
case <-c.handshaked:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conn) Handshake(payload []byte) (err error) {
|
||||
c.handshakeMutex.Lock()
|
||||
defer c.handshakeMutex.Unlock()
|
||||
|
||||
if c.Handshaked() {
|
||||
return
|
||||
}
|
||||
|
||||
if err = c.handshake(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
close(c.handshaked)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) handshake() error {
|
||||
record := &dissector.Record{}
|
||||
if _, err := record.ReadFrom(c.Conn); err != nil {
|
||||
// log.Log(err)
|
||||
return err
|
||||
}
|
||||
if record.Type != dissector.Handshake {
|
||||
return dissector.ErrBadType
|
||||
}
|
||||
|
||||
clientMsg := &dissector.ClientHelloMsg{}
|
||||
if err := clientMsg.Decode(record.Opaque); err != nil {
|
||||
// log.Log(err)
|
||||
return err
|
||||
}
|
||||
|
||||
for _, ext := range clientMsg.Extensions {
|
||||
if ext.Type() == dissector.ExtSessionTicket {
|
||||
b, err := ext.Encode()
|
||||
if err != nil {
|
||||
// log.Log(err)
|
||||
return err
|
||||
}
|
||||
c.rbuf.Write(b)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
serverMsg := &dissector.ServerHelloMsg{
|
||||
Version: tls.VersionTLS12,
|
||||
SessionID: clientMsg.SessionID,
|
||||
CipherSuite: 0xcca8,
|
||||
CompressionMethod: 0x00,
|
||||
Extensions: []dissector.Extension{
|
||||
&dissector.RenegotiationInfoExtension{},
|
||||
&dissector.ExtendedMasterSecretExtension{},
|
||||
&dissector.ECPointFormatsExtension{
|
||||
Formats: []uint8{0x00},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
serverMsg.Random.Time = uint32(time.Now().Unix())
|
||||
rand.Read(serverMsg.Random.Opaque[:])
|
||||
b, err := serverMsg.Encode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
record = &dissector.Record{
|
||||
Type: dissector.Handshake,
|
||||
Version: tls.VersionTLS10,
|
||||
Opaque: b,
|
||||
}
|
||||
|
||||
if _, err := record.WriteTo(&c.wbuf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
record = &dissector.Record{
|
||||
Type: dissector.ChangeCipherSpec,
|
||||
Version: tls.VersionTLS12,
|
||||
Opaque: []byte{0x01},
|
||||
}
|
||||
if _, err := record.WriteTo(&c.wbuf); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) Read(b []byte) (n int, err error) {
|
||||
if err = c.Handshake(nil); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-c.handshaked:
|
||||
}
|
||||
|
||||
if c.rbuf.Len() > 0 {
|
||||
return c.rbuf.Read(b)
|
||||
}
|
||||
record := &dissector.Record{}
|
||||
if _, err = record.ReadFrom(c.Conn); err != nil {
|
||||
return
|
||||
}
|
||||
n = copy(b, record.Opaque)
|
||||
_, err = c.rbuf.Write(record.Opaque[n:])
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) Write(b []byte) (n int, err error) {
|
||||
n = len(b)
|
||||
if !c.Handshaked() {
|
||||
if err = c.Handshake(b); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for len(b) > 0 {
|
||||
data := b
|
||||
if len(b) > maxTLSDataLen {
|
||||
data = b[:maxTLSDataLen]
|
||||
b = b[maxTLSDataLen:]
|
||||
} else {
|
||||
b = b[:0]
|
||||
}
|
||||
record := &dissector.Record{
|
||||
Type: dissector.AppData,
|
||||
Version: tls.VersionTLS12,
|
||||
Opaque: data,
|
||||
}
|
||||
|
||||
if c.wbuf.Len() > 0 {
|
||||
record.Type = dissector.Handshake
|
||||
record.WriteTo(&c.wbuf)
|
||||
_, err = c.wbuf.WriteTo(c.Conn)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = record.WriteTo(c.Conn); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
88
pkg/components/listener/obfs/tls/listener.go
Normal file
88
pkg/components/listener/obfs/tls/listener.go
Normal file
@ -0,0 +1,88 @@
|
||||
package tls
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/pkg/components/internal/utils"
|
||||
"github.com/go-gost/gost/pkg/components/listener"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
_ listener.Listener = (*Listener)(nil)
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
md metadata
|
||||
net.Listener
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) *Listener {
|
||||
options := &listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
return &Listener{
|
||||
logger: options.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) Init(md listener.Metadata) (err error) {
|
||||
l.md, err = l.parseMetadata(md)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
laddr, err := net.ResolveTCPAddr("tcp", l.md.addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ln, err := net.ListenTCP("tcp", laddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if l.md.keepAlive {
|
||||
l.Listener = &utils.TCPKeepAliveListener{
|
||||
TCPListener: ln,
|
||||
KeepAlivePeriod: l.md.keepAlivePeriod,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
l.Listener = ln
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) Accept() (net.Conn, error) {
|
||||
c, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &conn{Conn: c}, nil
|
||||
}
|
||||
|
||||
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
|
||||
if val, ok := md[addr]; ok {
|
||||
m.addr = val
|
||||
} else {
|
||||
err = errors.New("missing address")
|
||||
return
|
||||
}
|
||||
|
||||
m.keepAlive = true
|
||||
if val, ok := md[keepAlive]; ok {
|
||||
m.keepAlive, _ = strconv.ParseBool(val)
|
||||
}
|
||||
|
||||
if val, ok := md[keepAlivePeriod]; ok {
|
||||
m.keepAlivePeriod, _ = time.ParseDuration(val)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
19
pkg/components/listener/obfs/tls/metadata.go
Normal file
19
pkg/components/listener/obfs/tls/metadata.go
Normal file
@ -0,0 +1,19 @@
|
||||
package tls
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
addr = "addr"
|
||||
keepAlive = "keepAlive"
|
||||
keepAlivePeriod = "keepAlivePeriod"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultKeepAlivePeriod = 180 * time.Second
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
addr string
|
||||
keepAlive bool
|
||||
keepAlivePeriod time.Duration
|
||||
}
|
17
pkg/components/listener/option.go
Normal file
17
pkg/components/listener/option.go
Normal file
@ -0,0 +1,17 @@
|
||||
package listener
|
||||
|
||||
import (
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
Logger logger.Logger
|
||||
}
|
||||
|
||||
type Option func(opts *Options)
|
||||
|
||||
func LoggerOption(logger logger.Logger) Option {
|
||||
return func(opts *Options) {
|
||||
opts.Logger = logger
|
||||
}
|
||||
}
|
142
pkg/components/listener/quic/listener.go
Normal file
142
pkg/components/listener/quic/listener.go
Normal file
@ -0,0 +1,142 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"github.com/go-gost/gost/pkg/components/internal/utils"
|
||||
"github.com/go-gost/gost/pkg/components/listener"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
var (
|
||||
_ listener.Listener = (*Listener)(nil)
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
md metadata
|
||||
ln quic.Listener
|
||||
connChan chan net.Conn
|
||||
errChan chan error
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) *Listener {
|
||||
options := &listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
return &Listener{
|
||||
logger: options.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) Init(md listener.Metadata) (err error) {
|
||||
l.md, err = l.parseMetadata(md)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
laddr, err := net.ResolveUDPAddr("udp", l.md.addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var conn net.PacketConn
|
||||
conn, err = net.ListenUDP("udp", laddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if l.md.cipherKey != nil {
|
||||
conn = utils.QUICCipherConn(conn, l.md.cipherKey)
|
||||
}
|
||||
|
||||
config := &quic.Config{
|
||||
KeepAlive: l.md.keepAlive,
|
||||
HandshakeIdleTimeout: l.md.HandshakeTimeout,
|
||||
MaxIdleTimeout: l.md.MaxIdleTimeout,
|
||||
}
|
||||
|
||||
ln, err := quic.Listen(conn, l.md.tlsConfig, config)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.ln = ln
|
||||
l.connChan = make(chan net.Conn, l.md.connQueueSize)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
go l.listenLoop()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.connChan:
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) Close() error {
|
||||
return l.ln.Close()
|
||||
}
|
||||
|
||||
func (l *Listener) Addr() net.Addr {
|
||||
return l.ln.Addr()
|
||||
}
|
||||
|
||||
func (l *Listener) listenLoop() {
|
||||
for {
|
||||
ctx := context.Background()
|
||||
session, err := l.ln.Accept(ctx)
|
||||
if err != nil {
|
||||
l.logger.Error("accept:", err)
|
||||
l.errChan <- err
|
||||
close(l.errChan)
|
||||
return
|
||||
}
|
||||
go l.mux(ctx, session)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) mux(ctx context.Context, session quic.Session) {
|
||||
defer session.CloseWithError(0, "")
|
||||
|
||||
for {
|
||||
stream, err := session.AcceptStream(ctx)
|
||||
if err != nil {
|
||||
l.logger.Error("accept stream:", err)
|
||||
return
|
||||
}
|
||||
|
||||
conn := utils.QUICConn(session, stream)
|
||||
select {
|
||||
case l.connChan <- conn:
|
||||
case <-stream.Context().Done():
|
||||
stream.Close()
|
||||
default:
|
||||
stream.Close()
|
||||
l.logger.Error("connection queue is full")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
|
||||
if val, ok := md[addr]; ok {
|
||||
m.addr = val
|
||||
} else {
|
||||
err = errors.New("missing address")
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
32
pkg/components/listener/quic/metadata.go
Normal file
32
pkg/components/listener/quic/metadata.go
Normal file
@ -0,0 +1,32 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
addr = "addr"
|
||||
|
||||
certFile = "certFile"
|
||||
keyFile = "keyFile"
|
||||
caFile = "caFile"
|
||||
|
||||
keepAlive = "keepAlive"
|
||||
keepAlivePeriod = "keepAlivePeriod"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultKeepAlivePeriod = 180 * time.Second
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
addr string
|
||||
tlsConfig *tls.Config
|
||||
keepAlive bool
|
||||
HandshakeTimeout time.Duration
|
||||
MaxIdleTimeout time.Duration
|
||||
|
||||
cipherKey []byte
|
||||
connQueueSize int
|
||||
}
|
79
pkg/components/listener/tcp/listener.go
Normal file
79
pkg/components/listener/tcp/listener.go
Normal file
@ -0,0 +1,79 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/pkg/components/internal/utils"
|
||||
"github.com/go-gost/gost/pkg/components/listener"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
_ listener.Listener = (*Listener)(nil)
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
md metadata
|
||||
net.Listener
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) *Listener {
|
||||
options := &listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
return &Listener{
|
||||
logger: options.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) Init(md listener.Metadata) (err error) {
|
||||
l.md, err = l.parseMetadata(md)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
laddr, err := net.ResolveTCPAddr("tcp", l.md.addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ln, err := net.ListenTCP("tcp", laddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if l.md.keepAlive {
|
||||
l.Listener = &utils.TCPKeepAliveListener{
|
||||
TCPListener: ln,
|
||||
KeepAlivePeriod: l.md.keepAlivePeriod,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
l.Listener = ln
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
|
||||
if val, ok := md[addr]; ok {
|
||||
m.addr = val
|
||||
} else {
|
||||
err = errors.New("missing address")
|
||||
return
|
||||
}
|
||||
|
||||
m.keepAlive = true
|
||||
if val, ok := md[keepAlive]; ok {
|
||||
m.keepAlive, _ = strconv.ParseBool(val)
|
||||
}
|
||||
|
||||
if val, ok := md[keepAlivePeriod]; ok {
|
||||
m.keepAlivePeriod, _ = time.ParseDuration(val)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
19
pkg/components/listener/tcp/metadata.go
Normal file
19
pkg/components/listener/tcp/metadata.go
Normal file
@ -0,0 +1,19 @@
|
||||
package tcp
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
addr = "addr"
|
||||
keepAlive = "keepAlive"
|
||||
keepAlivePeriod = "keepAlivePeriod"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultKeepAlivePeriod = 180 * time.Second
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
addr string
|
||||
keepAlive bool
|
||||
keepAlivePeriod time.Duration
|
||||
}
|
75
pkg/components/listener/tls/listener.go
Normal file
75
pkg/components/listener/tls/listener.go
Normal file
@ -0,0 +1,75 @@
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/pkg/components/internal/utils"
|
||||
"github.com/go-gost/gost/pkg/components/listener"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
_ listener.Listener = (*Listener)(nil)
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
md metadata
|
||||
net.Listener
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) *Listener {
|
||||
options := &listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
return &Listener{
|
||||
logger: options.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) Init(md listener.Metadata) (err error) {
|
||||
l.md, err = l.parseMetadata(md)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", l.md.addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ln = tls.NewListener(
|
||||
&utils.TCPKeepAliveListener{
|
||||
TCPListener: ln.(*net.TCPListener),
|
||||
KeepAlivePeriod: l.md.keepAlivePeriod,
|
||||
},
|
||||
l.md.tlsConfig,
|
||||
)
|
||||
|
||||
l.Listener = ln
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
|
||||
if val, ok := md[addr]; ok {
|
||||
m.addr = val
|
||||
} else {
|
||||
err = errors.New("missing address")
|
||||
return
|
||||
}
|
||||
|
||||
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if val, ok := md[keepAlivePeriod]; ok {
|
||||
m.keepAlivePeriod, _ = time.ParseDuration(val)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
20
pkg/components/listener/tls/metadata.go
Normal file
20
pkg/components/listener/tls/metadata.go
Normal file
@ -0,0 +1,20 @@
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
addr = "addr"
|
||||
certFile = "certFile"
|
||||
keyFile = "keyFile"
|
||||
caFile = "caFile"
|
||||
keepAlivePeriod = "keepAlivePeriod"
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
addr string
|
||||
tlsConfig *tls.Config
|
||||
keepAlivePeriod time.Duration
|
||||
}
|
141
pkg/components/listener/tls/mux/listener.go
Normal file
141
pkg/components/listener/tls/mux/listener.go
Normal file
@ -0,0 +1,141 @@
|
||||
package mux
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"github.com/go-gost/gost/pkg/components/internal/utils"
|
||||
"github.com/go-gost/gost/pkg/components/listener"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
"github.com/xtaci/smux"
|
||||
)
|
||||
|
||||
var (
|
||||
_ listener.Listener = (*Listener)(nil)
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
md metadata
|
||||
net.Listener
|
||||
connChan chan net.Conn
|
||||
errChan chan error
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) *Listener {
|
||||
options := &listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
return &Listener{
|
||||
logger: options.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) Init(md listener.Metadata) (err error) {
|
||||
l.md, err = l.parseMetadata(md)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", l.md.addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
l.Listener = tls.NewListener(ln, l.md.tlsConfig)
|
||||
|
||||
queueSize := l.md.connQueueSize
|
||||
if queueSize <= 0 {
|
||||
queueSize = defaultQueueSize
|
||||
}
|
||||
l.connChan = make(chan net.Conn, queueSize)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
go l.listenLoop()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) listenLoop() {
|
||||
for {
|
||||
conn, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
l.errChan <- err
|
||||
close(l.errChan)
|
||||
return
|
||||
}
|
||||
go l.mux(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) mux(conn net.Conn) {
|
||||
smuxConfig := smux.DefaultConfig()
|
||||
smuxConfig.KeepAliveDisabled = l.md.muxKeepAliveDisabled
|
||||
if l.md.muxKeepAlivePeriod > 0 {
|
||||
smuxConfig.KeepAliveInterval = l.md.muxKeepAlivePeriod
|
||||
}
|
||||
if l.md.muxKeepAliveTimeout > 0 {
|
||||
smuxConfig.KeepAliveTimeout = l.md.muxKeepAliveTimeout
|
||||
}
|
||||
if l.md.muxMaxFrameSize > 0 {
|
||||
smuxConfig.MaxFrameSize = l.md.muxMaxFrameSize
|
||||
}
|
||||
if l.md.muxMaxReceiveBuffer > 0 {
|
||||
smuxConfig.MaxReceiveBuffer = l.md.muxMaxReceiveBuffer
|
||||
}
|
||||
if l.md.muxMaxStreamBuffer > 0 {
|
||||
smuxConfig.MaxStreamBuffer = l.md.muxMaxStreamBuffer
|
||||
}
|
||||
session, err := smux.Server(conn, smuxConfig)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
return
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
for {
|
||||
stream, err := session.AcceptStream()
|
||||
if err != nil {
|
||||
l.logger.Error("accept stream:", err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case l.connChan <- stream:
|
||||
case <-stream.GetDieCh():
|
||||
stream.Close()
|
||||
default:
|
||||
stream.Close()
|
||||
l.logger.Error("connection queue is full")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.connChan:
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
|
||||
if val, ok := md[addr]; ok {
|
||||
m.addr = val
|
||||
} else {
|
||||
err = errors.New("missing address")
|
||||
return
|
||||
}
|
||||
|
||||
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
38
pkg/components/listener/tls/mux/metadata.go
Normal file
38
pkg/components/listener/tls/mux/metadata.go
Normal file
@ -0,0 +1,38 @@
|
||||
package mux
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
addr = "addr"
|
||||
certFile = "certFile"
|
||||
keyFile = "keyFile"
|
||||
caFile = "caFile"
|
||||
|
||||
muxKeepAliveDisabled = "muxKeepAliveDisabled"
|
||||
muxKeepAlivePeriod = "muxKeepAlivePeriod"
|
||||
muxKeepAliveTimeout = "muxKeepAliveTimeout"
|
||||
muxMaxFrameSize = "muxMaxFrameSize"
|
||||
muxMaxReceiveBuffer = "muxMaxReceiveBuffer"
|
||||
muxMaxStreamBuffer = "muxMaxStreamBuffer"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultQueueSize = 128
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
addr string
|
||||
tlsConfig *tls.Config
|
||||
|
||||
muxKeepAliveDisabled bool
|
||||
muxKeepAlivePeriod time.Duration
|
||||
muxKeepAliveTimeout time.Duration
|
||||
muxMaxFrameSize int
|
||||
muxMaxReceiveBuffer int
|
||||
muxMaxStreamBuffer int
|
||||
|
||||
connQueueSize int
|
||||
}
|
115
pkg/components/listener/udp/conn.go
Normal file
115
pkg/components/listener/udp/conn.go
Normal file
@ -0,0 +1,115 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// serverConn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
|
||||
type serverConn struct {
|
||||
net.PacketConn
|
||||
raddr net.Addr
|
||||
rc chan []byte // data receive queue
|
||||
fresh int32
|
||||
closed chan struct{}
|
||||
closeMutex sync.Mutex
|
||||
config *serverConnConfig
|
||||
}
|
||||
|
||||
type serverConnConfig struct {
|
||||
ttl time.Duration
|
||||
qsize int
|
||||
onClose func()
|
||||
}
|
||||
|
||||
func newServerConn(conn net.PacketConn, raddr net.Addr, cfg *serverConnConfig) *serverConn {
|
||||
if conn == nil || raddr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if cfg == nil {
|
||||
cfg = &serverConnConfig{}
|
||||
}
|
||||
c := &serverConn{
|
||||
PacketConn: conn,
|
||||
raddr: raddr,
|
||||
rc: make(chan []byte, cfg.qsize),
|
||||
closed: make(chan struct{}),
|
||||
config: cfg,
|
||||
}
|
||||
go c.ttlWait()
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *serverConn) send(b []byte) error {
|
||||
select {
|
||||
case c.rc <- b:
|
||||
return nil
|
||||
default:
|
||||
return errors.New("queue is full")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *serverConn) Read(b []byte) (n int, err error) {
|
||||
n, _, err = c.ReadFrom(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *serverConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||
select {
|
||||
case bb := <-c.rc:
|
||||
n = copy(b, bb)
|
||||
atomic.StoreInt32(&c.fresh, 1)
|
||||
case <-c.closed:
|
||||
err = errors.New("read from closed connection")
|
||||
return
|
||||
}
|
||||
|
||||
addr = c.raddr
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *serverConn) Write(b []byte) (n int, err error) {
|
||||
return c.WriteTo(b, c.raddr)
|
||||
}
|
||||
|
||||
func (c *serverConn) Close() error {
|
||||
c.closeMutex.Lock()
|
||||
defer c.closeMutex.Unlock()
|
||||
|
||||
select {
|
||||
case <-c.closed:
|
||||
return errors.New("connection is closed")
|
||||
default:
|
||||
if c.config.onClose != nil {
|
||||
c.config.onClose()
|
||||
}
|
||||
close(c.closed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *serverConn) RemoteAddr() net.Addr {
|
||||
return c.raddr
|
||||
}
|
||||
|
||||
func (c *serverConn) ttlWait() {
|
||||
ticker := time.NewTicker(c.config.ttl)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if !atomic.CompareAndSwapInt32(&c.fresh, 1, 0) {
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
case <-c.closed:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
168
pkg/components/listener/udp/listener.go
Normal file
168
pkg/components/listener/udp/listener.go
Normal file
@ -0,0 +1,168 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/go-gost/gost/pkg/components/listener"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
_ listener.Listener = (*Listener)(nil)
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
md metadata
|
||||
conn net.PacketConn
|
||||
connChan chan net.Conn
|
||||
errChan chan error
|
||||
connPool connPool
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) *Listener {
|
||||
options := &listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
return &Listener{
|
||||
logger: options.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) Init(md listener.Metadata) (err error) {
|
||||
l.md, err = l.parseMetadata(md)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
laddr, err := net.ResolveUDPAddr("udp", l.md.addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var conn net.PacketConn
|
||||
conn, err = net.ListenUDP("udp", laddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.conn = conn
|
||||
l.connChan = make(chan net.Conn, l.md.connQueueSize)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
go l.listenLoop()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.connChan:
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) Close() error {
|
||||
err := l.conn.Close()
|
||||
l.connPool.Range(func(k interface{}, v *serverConn) bool {
|
||||
v.Close()
|
||||
return true
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (l *Listener) Addr() net.Addr {
|
||||
return l.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (l *Listener) listenLoop() {
|
||||
for {
|
||||
b := make([]byte, l.md.readBufferSize)
|
||||
|
||||
n, raddr, err := l.conn.ReadFrom(b)
|
||||
if err != nil {
|
||||
l.logger.Error("accept:", err)
|
||||
l.errChan <- err
|
||||
close(l.errChan)
|
||||
return
|
||||
}
|
||||
|
||||
conn, ok := l.connPool.Get(raddr.String())
|
||||
if !ok {
|
||||
conn = newServerConn(l.conn, raddr,
|
||||
&serverConnConfig{
|
||||
ttl: l.md.ttl,
|
||||
qsize: l.md.readQueueSize,
|
||||
onClose: func() {
|
||||
l.connPool.Delete(raddr.String())
|
||||
},
|
||||
})
|
||||
|
||||
select {
|
||||
case l.connChan <- conn:
|
||||
l.connPool.Set(raddr.String(), conn)
|
||||
default:
|
||||
conn.Close()
|
||||
l.logger.Error("connection queue is full")
|
||||
}
|
||||
}
|
||||
|
||||
if err := conn.send(b[:n]); err != nil {
|
||||
l.logger.Warn("data discarded:", err)
|
||||
}
|
||||
l.logger.Debug("recv", n)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
|
||||
if val, ok := md[addr]; ok {
|
||||
m.addr = val
|
||||
} else {
|
||||
err = errors.New("missing address")
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type connPool struct {
|
||||
size int64
|
||||
m sync.Map
|
||||
}
|
||||
|
||||
func (p *connPool) Get(key interface{}) (conn *serverConn, ok bool) {
|
||||
v, ok := p.m.Load(key)
|
||||
if ok {
|
||||
conn, ok = v.(*serverConn)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (p *connPool) Set(key interface{}, conn *serverConn) {
|
||||
p.m.Store(key, conn)
|
||||
atomic.AddInt64(&p.size, 1)
|
||||
}
|
||||
|
||||
func (p *connPool) Delete(key interface{}) {
|
||||
p.m.Delete(key)
|
||||
atomic.AddInt64(&p.size, -1)
|
||||
}
|
||||
|
||||
func (p *connPool) Range(f func(key interface{}, value *serverConn) bool) {
|
||||
p.m.Range(func(k, v interface{}) bool {
|
||||
return f(k, v.(*serverConn))
|
||||
})
|
||||
}
|
||||
|
||||
func (p *connPool) Size() int64 {
|
||||
return atomic.LoadInt64(&p.size)
|
||||
}
|
23
pkg/components/listener/udp/metadata.go
Normal file
23
pkg/components/listener/udp/metadata.go
Normal file
@ -0,0 +1,23 @@
|
||||
package udp
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
defaultTTL = 60 * time.Second
|
||||
defaultReadBufferSize = 1024
|
||||
defaultReadQueueSize = 128
|
||||
defaultConnQueueSize = 128
|
||||
)
|
||||
|
||||
const (
|
||||
addr = "addr"
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
addr string
|
||||
ttl time.Duration
|
||||
|
||||
readBufferSize int
|
||||
readQueueSize int
|
||||
connQueueSize int
|
||||
}
|
143
pkg/components/listener/ws/listener.go
Normal file
143
pkg/components/listener/ws/listener.go
Normal file
@ -0,0 +1,143 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-gost/gost/pkg/components/internal/utils"
|
||||
"github.com/go-gost/gost/pkg/components/listener"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var (
|
||||
_ listener.Listener = (*Listener)(nil)
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
md metadata
|
||||
addr net.Addr
|
||||
upgrader *websocket.Upgrader
|
||||
srv *http.Server
|
||||
connChan chan net.Conn
|
||||
errChan chan error
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) *Listener {
|
||||
options := &listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
return &Listener{
|
||||
logger: options.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) Init(md listener.Metadata) (err error) {
|
||||
l.md, err = l.parseMetadata(md)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.upgrader = &websocket.Upgrader{
|
||||
HandshakeTimeout: l.md.handshakeTimeout,
|
||||
ReadBufferSize: l.md.readBufferSize,
|
||||
WriteBufferSize: l.md.writeBufferSize,
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
EnableCompression: l.md.enableCompression,
|
||||
}
|
||||
|
||||
path := l.md.path
|
||||
if path == "" {
|
||||
path = defaultPath
|
||||
}
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle(path, http.HandlerFunc(l.upgrade))
|
||||
l.srv = &http.Server{
|
||||
Addr: l.md.addr,
|
||||
TLSConfig: l.md.tlsConfig,
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: l.md.readHeaderTimeout,
|
||||
}
|
||||
|
||||
queueSize := l.md.connQueueSize
|
||||
if queueSize <= 0 {
|
||||
queueSize = defaultQueueSize
|
||||
}
|
||||
l.connChan = make(chan net.Conn, queueSize)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
ln, err := net.Listen("tcp", l.md.addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if l.md.tlsConfig != nil {
|
||||
ln = tls.NewListener(ln, l.md.tlsConfig)
|
||||
}
|
||||
|
||||
l.addr = ln.Addr()
|
||||
|
||||
go func() {
|
||||
err := l.srv.Serve(ln)
|
||||
if err != nil {
|
||||
l.errChan <- err
|
||||
}
|
||||
close(l.errChan)
|
||||
}()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.connChan:
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) Close() error {
|
||||
return l.srv.Close()
|
||||
}
|
||||
|
||||
func (l *Listener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
||||
|
||||
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
|
||||
if val, ok := md[addr]; ok {
|
||||
m.addr = val
|
||||
} else {
|
||||
err = errors.New("missing address")
|
||||
return
|
||||
}
|
||||
|
||||
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) upgrade(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := l.upgrader.Upgrade(w, r, l.md.responseHeader)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case l.connChan <- utils.WebsocketServerConn(conn):
|
||||
default:
|
||||
conn.Close()
|
||||
l.logger.Warn("connection queue is full")
|
||||
}
|
||||
}
|
40
pkg/components/listener/ws/metadata.go
Normal file
40
pkg/components/listener/ws/metadata.go
Normal file
@ -0,0 +1,40 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
addr = "addr"
|
||||
path = "path"
|
||||
certFile = "certFile"
|
||||
keyFile = "keyFile"
|
||||
caFile = "caFile"
|
||||
handshakeTimeout = "handshakeTimeout"
|
||||
readHeaderTimeout = "readHeaderTimeout"
|
||||
readBufferSize = "readBufferSize"
|
||||
writeBufferSize = "writeBufferSize"
|
||||
enableCompression = "enableCompression"
|
||||
responseHeader = "responseHeader"
|
||||
connQueueSize = "connQueueSize"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPath = "/ws"
|
||||
defaultQueueSize = 128
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
addr string
|
||||
path string
|
||||
tlsConfig *tls.Config
|
||||
handshakeTimeout time.Duration
|
||||
readHeaderTimeout time.Duration
|
||||
readBufferSize int
|
||||
writeBufferSize int
|
||||
enableCompression bool
|
||||
responseHeader http.Header
|
||||
connQueueSize int
|
||||
}
|
178
pkg/components/listener/ws/mux/listener.go
Normal file
178
pkg/components/listener/ws/mux/listener.go
Normal file
@ -0,0 +1,178 @@
|
||||
package mux
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-gost/gost/pkg/components/internal/utils"
|
||||
"github.com/go-gost/gost/pkg/components/listener"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/xtaci/smux"
|
||||
)
|
||||
|
||||
var (
|
||||
_ listener.Listener = (*Listener)(nil)
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
md metadata
|
||||
addr net.Addr
|
||||
upgrader *websocket.Upgrader
|
||||
srv *http.Server
|
||||
connChan chan net.Conn
|
||||
errChan chan error
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) *Listener {
|
||||
options := &listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
return &Listener{
|
||||
logger: options.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) Init(md listener.Metadata) (err error) {
|
||||
l.md, err = l.parseMetadata(md)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.upgrader = &websocket.Upgrader{
|
||||
HandshakeTimeout: l.md.handshakeTimeout,
|
||||
ReadBufferSize: l.md.readBufferSize,
|
||||
WriteBufferSize: l.md.writeBufferSize,
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
EnableCompression: l.md.enableCompression,
|
||||
}
|
||||
|
||||
path := l.md.path
|
||||
if path == "" {
|
||||
path = defaultPath
|
||||
}
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle(path, http.HandlerFunc(l.upgrade))
|
||||
l.srv = &http.Server{
|
||||
Addr: l.md.addr,
|
||||
TLSConfig: l.md.tlsConfig,
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: l.md.readHeaderTimeout,
|
||||
}
|
||||
|
||||
l.connChan = make(chan net.Conn, l.md.connQueueSize)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
ln, err := net.Listen("tcp", l.md.addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if l.md.tlsConfig != nil {
|
||||
ln = tls.NewListener(ln, l.md.tlsConfig)
|
||||
}
|
||||
|
||||
l.addr = ln.Addr()
|
||||
|
||||
go func() {
|
||||
err := l.srv.Serve(ln)
|
||||
if err != nil {
|
||||
l.errChan <- err
|
||||
}
|
||||
close(l.errChan)
|
||||
}()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.connChan:
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) Close() error {
|
||||
return l.srv.Close()
|
||||
}
|
||||
|
||||
func (l *Listener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
||||
|
||||
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
|
||||
if val, ok := md[addr]; ok {
|
||||
m.addr = val
|
||||
} else {
|
||||
err = errors.New("missing address")
|
||||
return
|
||||
}
|
||||
|
||||
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) upgrade(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := l.upgrader.Upgrade(w, r, l.md.responseHeader)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
l.mux(utils.WebsocketServerConn(conn))
|
||||
}
|
||||
|
||||
func (l *Listener) mux(conn net.Conn) {
|
||||
smuxConfig := smux.DefaultConfig()
|
||||
smuxConfig.KeepAliveDisabled = l.md.muxKeepAliveDisabled
|
||||
if l.md.muxKeepAlivePeriod > 0 {
|
||||
smuxConfig.KeepAliveInterval = l.md.muxKeepAlivePeriod
|
||||
}
|
||||
if l.md.muxKeepAliveTimeout > 0 {
|
||||
smuxConfig.KeepAliveTimeout = l.md.muxKeepAliveTimeout
|
||||
}
|
||||
if l.md.muxMaxFrameSize > 0 {
|
||||
smuxConfig.MaxFrameSize = l.md.muxMaxFrameSize
|
||||
}
|
||||
if l.md.muxMaxReceiveBuffer > 0 {
|
||||
smuxConfig.MaxReceiveBuffer = l.md.muxMaxReceiveBuffer
|
||||
}
|
||||
if l.md.muxMaxStreamBuffer > 0 {
|
||||
smuxConfig.MaxStreamBuffer = l.md.muxMaxStreamBuffer
|
||||
}
|
||||
session, err := smux.Server(conn, smuxConfig)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
return
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
for {
|
||||
stream, err := session.AcceptStream()
|
||||
if err != nil {
|
||||
l.logger.Error("accept stream:", err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case l.connChan <- stream:
|
||||
case <-stream.GetDieCh():
|
||||
stream.Close()
|
||||
default:
|
||||
stream.Close()
|
||||
l.logger.Error("connection queue is full")
|
||||
}
|
||||
}
|
||||
}
|
54
pkg/components/listener/ws/mux/metadata.go
Normal file
54
pkg/components/listener/ws/mux/metadata.go
Normal file
@ -0,0 +1,54 @@
|
||||
package mux
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
addr = "addr"
|
||||
path = "path"
|
||||
certFile = "certFile"
|
||||
keyFile = "keyFile"
|
||||
caFile = "caFile"
|
||||
handshakeTimeout = "handshakeTimeout"
|
||||
readHeaderTimeout = "readHeaderTimeout"
|
||||
readBufferSize = "readBufferSize"
|
||||
writeBufferSize = "writeBufferSize"
|
||||
enableCompression = "enableCompression"
|
||||
responseHeader = "responseHeader"
|
||||
connQueueSize = "connQueueSize"
|
||||
|
||||
muxKeepAliveDisabled = "muxKeepAliveDisabled"
|
||||
muxKeepAlivePeriod = "muxKeepAlivePeriod"
|
||||
muxKeepAliveTimeout = "muxKeepAliveTimeout"
|
||||
muxMaxFrameSize = "muxMaxFrameSize"
|
||||
muxMaxReceiveBuffer = "muxMaxReceiveBuffer"
|
||||
muxMaxStreamBuffer = "muxMaxStreamBuffer"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPath = "/ws"
|
||||
defaultQueueSize = 128
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
addr string
|
||||
path string
|
||||
tlsConfig *tls.Config
|
||||
handshakeTimeout time.Duration
|
||||
readHeaderTimeout time.Duration
|
||||
readBufferSize int
|
||||
writeBufferSize int
|
||||
enableCompression bool
|
||||
responseHeader http.Header
|
||||
|
||||
muxKeepAliveDisabled bool
|
||||
muxKeepAlivePeriod time.Duration
|
||||
muxKeepAliveTimeout time.Duration
|
||||
muxMaxFrameSize int
|
||||
muxMaxReceiveBuffer int
|
||||
muxMaxStreamBuffer int
|
||||
connQueueSize int
|
||||
}
|
96
pkg/logger/gost_logger.go
Normal file
96
pkg/logger/gost_logger.go
Normal file
@ -0,0 +1,96 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
_ Logger = (*logger)(nil)
|
||||
)
|
||||
|
||||
type logger struct {
|
||||
logger *logrus.Entry
|
||||
}
|
||||
|
||||
func newLogger(name string) *logger {
|
||||
l := logrus.New()
|
||||
l.SetOutput(os.Stdout)
|
||||
|
||||
gl := &logger{
|
||||
logger: l.WithFields(logrus.Fields{
|
||||
logFieldScope: name,
|
||||
}),
|
||||
}
|
||||
|
||||
return gl
|
||||
}
|
||||
|
||||
// EnableJSONOutput enables JSON formatted output log.
|
||||
func (l *logger) EnableJSONOutput(enabled bool) {
|
||||
|
||||
}
|
||||
|
||||
// SetOutputLevel sets log output level
|
||||
func (l *logger) SetLevel(level LogLevel) {
|
||||
lvl, _ := logrus.ParseLevel(string(level))
|
||||
l.logger.Logger.SetLevel(lvl)
|
||||
}
|
||||
|
||||
// WithFields adds new fields to log.
|
||||
func (l *logger) WithFields(fields map[string]interface{}) Logger {
|
||||
return &logger{
|
||||
logger: l.logger.WithFields(logrus.Fields(fields)),
|
||||
}
|
||||
}
|
||||
|
||||
// Info logs a message at level Info.
|
||||
func (l *logger) Info(args ...interface{}) {
|
||||
l.logger.Log(logrus.InfoLevel, args...)
|
||||
}
|
||||
|
||||
// Infof logs a message at level Info.
|
||||
func (l *logger) Infof(format string, args ...interface{}) {
|
||||
l.logger.Logf(logrus.InfoLevel, format, args...)
|
||||
}
|
||||
|
||||
// Debug logs a message at level Debug.
|
||||
func (l *logger) Debug(args ...interface{}) {
|
||||
l.logger.Log(logrus.DebugLevel, args...)
|
||||
}
|
||||
|
||||
// Debugf logs a message at level Debug.
|
||||
func (l *logger) Debugf(format string, args ...interface{}) {
|
||||
l.logger.Logf(logrus.DebugLevel, format, args...)
|
||||
}
|
||||
|
||||
// Warn logs a message at level Warn.
|
||||
func (l *logger) Warn(args ...interface{}) {
|
||||
l.logger.Log(logrus.WarnLevel, args...)
|
||||
}
|
||||
|
||||
// Warnf logs a message at level Warn.
|
||||
func (l *logger) Warnf(format string, args ...interface{}) {
|
||||
l.logger.Logf(logrus.WarnLevel, format, args...)
|
||||
}
|
||||
|
||||
// Error logs a message at level Error.
|
||||
func (l *logger) Error(args ...interface{}) {
|
||||
l.logger.Log(logrus.ErrorLevel, args...)
|
||||
}
|
||||
|
||||
// Errorf logs a message at level Error.
|
||||
func (l *logger) Errorf(format string, args ...interface{}) {
|
||||
l.logger.Logf(logrus.ErrorLevel, format, args...)
|
||||
}
|
||||
|
||||
// Fatal logs a message at level Fatal then the process will exit with status set to 1.
|
||||
func (l *logger) Fatal(args ...interface{}) {
|
||||
l.logger.Fatal(args...)
|
||||
}
|
||||
|
||||
// Fatalf logs a message at level Fatal then the process will exit with status set to 1.
|
||||
func (l *logger) Fatalf(format string, args ...interface{}) {
|
||||
l.logger.Fatalf(format, args...)
|
||||
}
|
57
pkg/logger/logger.go
Normal file
57
pkg/logger/logger.go
Normal file
@ -0,0 +1,57 @@
|
||||
package logger
|
||||
|
||||
import "sync"
|
||||
|
||||
const (
|
||||
logFieldScope = "scope"
|
||||
)
|
||||
|
||||
// LogLevel is Logger Level type
|
||||
type LogLevel string
|
||||
|
||||
const (
|
||||
// DebugLevel has verbose message
|
||||
DebugLevel LogLevel = "debug"
|
||||
// InfoLevel is default log level
|
||||
InfoLevel LogLevel = "info"
|
||||
// WarnLevel is for logging messages about possible issues
|
||||
WarnLevel LogLevel = "warn"
|
||||
// ErrorLevel is for logging errors
|
||||
ErrorLevel LogLevel = "error"
|
||||
// FatalLevel is for logging fatal messages. The system shuts down after logging the message.
|
||||
FatalLevel LogLevel = "fatal"
|
||||
)
|
||||
|
||||
var (
|
||||
globalLoggers = make(map[string]Logger)
|
||||
globalLoggersLock sync.RWMutex
|
||||
)
|
||||
|
||||
type Logger interface {
|
||||
EnableJSONOutput(enabled bool)
|
||||
SetLevel(level LogLevel)
|
||||
WithFields(map[string]interface{}) Logger
|
||||
Debug(args ...interface{})
|
||||
Debugf(format string, args ...interface{})
|
||||
Info(args ...interface{})
|
||||
Infof(format string, args ...interface{})
|
||||
Warn(args ...interface{})
|
||||
Warnf(format string, args ...interface{})
|
||||
Error(args ...interface{})
|
||||
Errorf(format string, args ...interface{})
|
||||
Fatal(args ...interface{})
|
||||
Fatalf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
func NewLogger(name string) Logger {
|
||||
globalLoggersLock.Lock()
|
||||
defer globalLoggersLock.Unlock()
|
||||
|
||||
logger, ok := globalLoggers[name]
|
||||
if !ok {
|
||||
logger = newLogger(name)
|
||||
globalLoggers[name] = logger
|
||||
}
|
||||
|
||||
return logger
|
||||
}
|
63
pkg/service/service.go
Normal file
63
pkg/service/service.go
Normal file
@ -0,0 +1,63 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/pkg/components/handler"
|
||||
"github.com/go-gost/gost/pkg/components/listener"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
listener listener.Listener
|
||||
handler handler.Handler
|
||||
}
|
||||
|
||||
func (s *Service) WithListener(ln listener.Listener) *Service {
|
||||
s.listener = ln
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Service) WithHandler(h handler.Handler) *Service {
|
||||
s.handler = h
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Service) Addr() net.Addr {
|
||||
return s.listener.Addr()
|
||||
}
|
||||
|
||||
func (s *Service) Run() error {
|
||||
return s.serve()
|
||||
}
|
||||
|
||||
func (s *Service) Close() error {
|
||||
return s.listener.Close()
|
||||
}
|
||||
|
||||
func (s *Service) serve() error {
|
||||
var tempDelay time.Duration
|
||||
for {
|
||||
conn, e := s.listener.Accept()
|
||||
if e != nil {
|
||||
if ne, ok := e.(net.Error); ok && ne.Temporary() {
|
||||
if tempDelay == 0 {
|
||||
tempDelay = 5 * time.Millisecond
|
||||
} else {
|
||||
tempDelay *= 2
|
||||
}
|
||||
if max := 1 * time.Second; tempDelay > max {
|
||||
tempDelay = max
|
||||
}
|
||||
// log.Logf("server: Accept error: %v; retrying in %v", e, tempDelay)
|
||||
time.Sleep(tempDelay)
|
||||
continue
|
||||
}
|
||||
return e
|
||||
}
|
||||
tempDelay = 0
|
||||
|
||||
go s.handler.Handle(context.Background(), conn)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user