initial commit

This commit is contained in:
ginuerzh
2022-03-14 20:27:14 +08:00
commit 9397cb5351
175 changed files with 16196 additions and 0 deletions

210
listener/dns/listener.go Normal file
View File

@ -0,0 +1,210 @@
package dns
import (
"bytes"
"encoding/base64"
"errors"
"io/ioutil"
"net"
"net/http"
"strings"
"github.com/go-gost/gost/v3/pkg/common/metrics"
"github.com/go-gost/gost/v3/pkg/listener"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
"github.com/miekg/dns"
)
func init() {
registry.ListenerRegistry().Register("dns", NewListener)
}
type dnsListener struct {
addr net.Addr
server Server
cqueue chan net.Conn
errChan chan error
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &dnsListener{
logger: options.Logger,
options: options,
}
}
func (l *dnsListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
}
l.addr, err = net.ResolveTCPAddr("tcp", l.options.Addr)
if err != nil {
return err
}
switch strings.ToLower(l.md.mode) {
case "tcp":
l.server = &dns.Server{
Net: "tcp",
Addr: l.options.Addr,
Handler: l,
ReadTimeout: l.md.readTimeout,
WriteTimeout: l.md.writeTimeout,
}
case "tls":
l.server = &dns.Server{
Net: "tcp-tls",
Addr: l.options.Addr,
Handler: l,
TLSConfig: l.options.TLSConfig,
ReadTimeout: l.md.readTimeout,
WriteTimeout: l.md.writeTimeout,
}
case "https":
l.server = &dohServer{
addr: l.options.Addr,
tlsConfig: l.options.TLSConfig,
server: &http.Server{
Handler: l,
ReadTimeout: l.md.readTimeout,
WriteTimeout: l.md.writeTimeout,
},
}
default:
l.addr, err = net.ResolveUDPAddr("udp", l.options.Addr)
l.server = &dns.Server{
Net: "udp",
Addr: l.options.Addr,
Handler: l,
UDPSize: l.md.readBufferSize,
ReadTimeout: l.md.readTimeout,
WriteTimeout: l.md.writeTimeout,
}
}
if err != nil {
return
}
l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1)
go func() {
err := l.server.ListenAndServe()
if err != nil {
l.errChan <- err
}
close(l.errChan)
}()
return
}
func (l *dnsListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.cqueue:
conn = metrics.WrapConn(l.options.Service, conn)
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *dnsListener) Close() error {
return l.server.Shutdown()
}
func (l *dnsListener) Addr() net.Addr {
return l.addr
}
func (l *dnsListener) ServeDNS(w dns.ResponseWriter, m *dns.Msg) {
b, err := m.Pack()
if err != nil {
l.logger.Error(err)
return
}
if err := l.serve(w, b); err != nil {
l.logger.Error(err)
}
}
// Based on https://github.com/semihalev/sdns
func (l *dnsListener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var buf []byte
var err error
switch r.Method {
case http.MethodGet:
buf, err = base64.RawURLEncoding.DecodeString(r.URL.Query().Get("dns"))
if len(buf) == 0 || err != nil {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
case http.MethodPost:
if ct := r.Header.Get("Content-Type"); ct != "application/dns-message" {
l.logger.Errorf("unsupported media type: %s", ct)
http.Error(w, http.StatusText(http.StatusUnsupportedMediaType), http.StatusUnsupportedMediaType)
return
}
buf, err = ioutil.ReadAll(r.Body)
if err != nil {
l.logger.Error(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
default:
l.logger.Errorf("method not allowd: %s", r.Method)
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
}
mq := &dns.Msg{}
if err := mq.Unpack(buf); err != nil {
l.logger.Error(err)
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
w.Header().Set("Server", "SDNS")
w.Header().Set("Content-Type", "application/dns-message")
raddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
if raddr == nil {
raddr = &net.TCPAddr{}
}
if err := l.serve(&dohResponseWriter{raddr: raddr, ResponseWriter: w}, buf); err != nil {
l.logger.Error(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
}
func (l *dnsListener) serve(w ResponseWriter, msg []byte) (err error) {
conn := &serverConn{
r: bytes.NewReader(msg),
w: w,
laddr: l.addr,
closed: make(chan struct{}),
}
select {
case l.cqueue <- conn:
default:
l.logger.Warnf("connection queue is full, client %s discarded", w.RemoteAddr())
return errors.New("connection queue is full")
}
return conn.Wait()
}

41
listener/dns/metadata.go Normal file
View File

@ -0,0 +1,41 @@
package dns
import (
"time"
mdata "github.com/go-gost/gost/v3/pkg/metadata"
)
const (
defaultBacklog = 128
)
type metadata struct {
mode string
readBufferSize int
readTimeout time.Duration
writeTimeout time.Duration
backlog int
}
func (l *dnsListener) parseMetadata(md mdata.Metadata) (err error) {
const (
backlog = "backlog"
mode = "mode"
readBufferSize = "readBufferSize"
readTimeout = "readTimeout"
writeTimeout = "writeTimeout"
)
l.md.mode = mdata.GetString(md, mode)
l.md.readBufferSize = mdata.GetInt(md, readBufferSize)
l.md.readTimeout = mdata.GetDuration(md, readTimeout)
l.md.writeTimeout = mdata.GetDuration(md, writeTimeout)
l.md.backlog = mdata.GetInt(md, backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
return
}

110
listener/dns/server.go Normal file
View File

@ -0,0 +1,110 @@
package dns
import (
"context"
"crypto/tls"
"errors"
"io"
"net"
"net/http"
"time"
)
type Server interface {
ListenAndServe() error
Shutdown() error
}
type dohServer struct {
addr string
tlsConfig *tls.Config
server *http.Server
}
func (s *dohServer) ListenAndServe() error {
ln, err := net.Listen("tcp", s.addr)
if err != nil {
return err
}
ln = tls.NewListener(ln, s.tlsConfig)
return s.server.Serve(ln)
}
func (s *dohServer) Shutdown() error {
return s.server.Shutdown(context.Background())
}
type ResponseWriter interface {
io.Writer
RemoteAddr() net.Addr
}
type dohResponseWriter struct {
raddr net.Addr
http.ResponseWriter
}
func (w *dohResponseWriter) RemoteAddr() net.Addr {
return w.raddr
}
type serverConn struct {
r io.Reader
w ResponseWriter
laddr net.Addr
closed chan struct{}
}
func (c *serverConn) Read(b []byte) (n int, err error) {
select {
case <-c.closed:
err = io.ErrClosedPipe
return
default:
return c.r.Read(b)
}
}
func (c *serverConn) Write(b []byte) (n int, err error) {
select {
case <-c.closed:
err = io.ErrClosedPipe
return
default:
return c.w.Write(b)
}
}
func (c *serverConn) Close() error {
select {
case <-c.closed:
default:
close(c.closed)
}
return nil
}
func (c *serverConn) Wait() error {
<-c.closed
return nil
}
func (c *serverConn) LocalAddr() net.Addr {
return c.laddr
}
func (c *serverConn) RemoteAddr() net.Addr {
return c.w.RemoteAddr()
}
func (c *serverConn) SetDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "dns", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *serverConn) SetReadDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "dns", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *serverConn) SetWriteDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "dns", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}

124
listener/ftcp/conn.go Normal file
View File

@ -0,0 +1,124 @@
package ftcp
import (
"errors"
"net"
"sync"
"sync/atomic"
"time"
)
// serverConn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
type serverConn struct {
pc net.PacketConn
raddr net.Addr
rc chan []byte // data receive queue
fresh int32
closed chan struct{}
closeMutex sync.Mutex
config *serverConnConfig
}
type serverConnConfig struct {
ttl time.Duration
qsize int
onClose func()
}
func newServerConn(conn net.PacketConn, raddr net.Addr, cfg *serverConnConfig) *serverConn {
if conn == nil || raddr == nil {
return nil
}
if cfg == nil {
cfg = &serverConnConfig{}
}
c := &serverConn{
pc: conn,
raddr: raddr,
rc: make(chan []byte, cfg.qsize),
closed: make(chan struct{}),
config: cfg,
}
go c.ttlWait()
return c
}
func (c *serverConn) send(b []byte) error {
select {
case c.rc <- b:
return nil
default:
return errors.New("queue is full")
}
}
func (c *serverConn) Read(b []byte) (n int, err error) {
select {
case bb := <-c.rc:
n = copy(b, bb)
atomic.StoreInt32(&c.fresh, 1)
case <-c.closed:
err = errors.New("read from closed connection")
return
}
return
}
func (c *serverConn) Write(b []byte) (n int, err error) {
return c.pc.WriteTo(b, c.raddr)
}
func (c *serverConn) Close() error {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()
select {
case <-c.closed:
return errors.New("connection is closed")
default:
if c.config.onClose != nil {
c.config.onClose()
}
close(c.closed)
}
return nil
}
func (c *serverConn) LocalAddr() net.Addr {
return c.pc.LocalAddr()
}
func (c *serverConn) RemoteAddr() net.Addr {
return c.raddr
}
func (c *serverConn) SetDeadline(t time.Time) error {
return c.pc.SetDeadline(t)
}
func (c *serverConn) SetReadDeadline(t time.Time) error {
return c.pc.SetReadDeadline(t)
}
func (c *serverConn) SetWriteDeadline(t time.Time) error {
return c.pc.SetWriteDeadline(t)
}
func (c *serverConn) ttlWait() {
ticker := time.NewTicker(c.config.ttl)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if !atomic.CompareAndSwapInt32(&c.fresh, 1, 0) {
c.Close()
return
}
case <-c.closed:
return
}
}
}

159
listener/ftcp/listener.go Normal file
View File

@ -0,0 +1,159 @@
package ftcp
import (
"net"
"sync"
"sync/atomic"
"github.com/go-gost/gost/v3/pkg/common/metrics"
"github.com/go-gost/gost/v3/pkg/listener"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
"github.com/xtaci/tcpraw"
)
func init() {
registry.ListenerRegistry().Register("ftcp", NewListener)
}
type ftcpListener struct {
conn net.PacketConn
connChan chan net.Conn
errChan chan error
connPool connPool
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &ftcpListener{
logger: options.Logger,
options: options,
}
}
func (l *ftcpListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
}
l.conn, err = tcpraw.Listen("tcp", l.options.Addr)
if err != nil {
return
}
l.connChan = make(chan net.Conn, l.md.connQueueSize)
l.errChan = make(chan error, 1)
go l.listenLoop()
return
}
func (l *ftcpListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
conn = metrics.WrapConn(l.options.Service, conn)
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *ftcpListener) Close() error {
err := l.conn.Close()
l.connPool.Range(func(k any, v *serverConn) bool {
v.Close()
return true
})
return err
}
func (l *ftcpListener) Addr() net.Addr {
return l.conn.LocalAddr()
}
func (l *ftcpListener) listenLoop() {
for {
b := make([]byte, l.md.readBufferSize)
n, raddr, err := l.conn.ReadFrom(b)
if err != nil {
l.logger.Error("accept:", err)
l.errChan <- err
close(l.errChan)
return
}
conn, ok := l.connPool.Get(raddr.String())
if !ok {
conn = newServerConn(l.conn, raddr,
&serverConnConfig{
ttl: l.md.ttl,
qsize: l.md.readQueueSize,
onClose: func() {
l.connPool.Delete(raddr.String())
},
})
select {
case l.connChan <- conn:
l.connPool.Set(raddr.String(), conn)
default:
conn.Close()
l.logger.Error("connection queue is full")
}
}
if err := conn.send(b[:n]); err != nil {
l.logger.Warn("data discarded:", err)
}
l.logger.Debug("recv", n)
}
}
func (l *ftcpListener) parseMetadata(md md.Metadata) (err error) {
return
}
type connPool struct {
size int64
m sync.Map
}
func (p *connPool) Get(key any) (conn *serverConn, ok bool) {
v, ok := p.m.Load(key)
if ok {
conn, ok = v.(*serverConn)
}
return
}
func (p *connPool) Set(key any, conn *serverConn) {
p.m.Store(key, conn)
atomic.AddInt64(&p.size, 1)
}
func (p *connPool) Delete(key any) {
p.m.Delete(key)
atomic.AddInt64(&p.size, -1)
}
func (p *connPool) Range(f func(key any, value *serverConn) bool) {
p.m.Range(func(k, v any) bool {
return f(k, v.(*serverConn))
})
}
func (p *connPool) Size() int64 {
return atomic.LoadInt64(&p.size)
}

22
listener/ftcp/metadata.go Normal file
View File

@ -0,0 +1,22 @@
package ftcp
import "time"
const (
defaultTTL = 60 * time.Second
defaultReadBufferSize = 1024
defaultReadQueueSize = 128
defaultConnQueueSize = 128
)
const (
addr = "addr"
)
type metadata struct {
ttl time.Duration
readBufferSize int
readQueueSize int
connQueueSize int
}

100
listener/grpc/listener.go Normal file
View File

@ -0,0 +1,100 @@
package grpc
import (
"net"
"github.com/go-gost/gost/v3/pkg/common/admission"
"github.com/go-gost/gost/v3/pkg/common/metrics"
"github.com/go-gost/gost/v3/pkg/listener"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
pb "github.com/go-gost/x/internal/util/grpc/proto"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
func init() {
registry.ListenerRegistry().Register("grpc", NewListener)
}
type grpcListener struct {
addr net.Addr
server *grpc.Server
cqueue chan net.Conn
errChan chan error
md metadata
logger logger.Logger
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &grpcListener{
logger: options.Logger,
options: options,
}
}
func (l *grpcListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
}
ln, err := net.Listen("tcp", l.options.Addr)
if err != nil {
return
}
ln = metrics.WrapListener(l.options.Service, ln)
ln = admission.WrapListener(l.options.Admission, ln)
var opts []grpc.ServerOption
if !l.md.insecure {
opts = append(opts, grpc.Creds(credentials.NewTLS(l.options.TLSConfig)))
}
l.server = grpc.NewServer(opts...)
l.addr = ln.Addr()
l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1)
pb.RegisterGostTunelServer(l.server, &server{
cqueue: l.cqueue,
localAddr: l.addr,
logger: l.options.Logger,
})
go func() {
err := l.server.Serve(ln)
if err != nil {
l.errChan <- err
}
close(l.errChan)
}()
return
}
func (l *grpcListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.cqueue:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *grpcListener) Close() error {
l.server.Stop()
return nil
}
func (l *grpcListener) Addr() net.Addr {
return l.addr
}

29
listener/grpc/metadata.go Normal file
View File

@ -0,0 +1,29 @@
package grpc
import (
mdata "github.com/go-gost/gost/v3/pkg/metadata"
)
const (
defaultBacklog = 128
)
type metadata struct {
backlog int
insecure bool
}
func (l *grpcListener) parseMetadata(md mdata.Metadata) (err error) {
const (
backlog = "backlog"
insecure = "grpcInsecure"
)
l.md.backlog = mdata.GetInt(md, backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
l.md.insecure = mdata.GetBool(md, insecure)
return
}

124
listener/grpc/server.go Normal file
View File

@ -0,0 +1,124 @@
package grpc
import (
"errors"
"io"
"net"
"time"
"github.com/go-gost/gost/v3/pkg/logger"
pb "github.com/go-gost/x/internal/util/grpc/proto"
"google.golang.org/grpc/peer"
)
type server struct {
cqueue chan net.Conn
localAddr net.Addr
pb.UnimplementedGostTunelServer
logger logger.Logger
}
func (s *server) Tunnel(srv pb.GostTunel_TunnelServer) error {
c := &conn{
s: srv,
localAddr: s.localAddr,
remoteAddr: &net.TCPAddr{},
closed: make(chan struct{}),
}
if p, ok := peer.FromContext(srv.Context()); ok {
c.remoteAddr = p.Addr
}
select {
case s.cqueue <- c:
default:
c.Close()
s.logger.Warnf("connection queue is full, client discarded")
}
<-c.closed
return nil
}
type conn struct {
s pb.GostTunel_TunnelServer
rb []byte
localAddr net.Addr
remoteAddr net.Addr
closed chan struct{}
}
func (c *conn) Read(b []byte) (n int, err error) {
select {
case <-c.s.Context().Done():
err = c.s.Context().Err()
return
case <-c.closed:
err = io.ErrClosedPipe
return
default:
}
if len(c.rb) == 0 {
chunk, err := c.s.Recv()
if err != nil {
return 0, err
}
c.rb = chunk.Data
}
n = copy(b, c.rb)
c.rb = c.rb[n:]
return
}
func (c *conn) Write(b []byte) (n int, err error) {
select {
case <-c.s.Context().Done():
err = c.s.Context().Err()
return
case <-c.closed:
err = io.ErrClosedPipe
return
default:
}
if err = c.s.Send(&pb.Chunk{
Data: b,
}); err != nil {
return
}
n = len(b)
return
}
func (c *conn) Close() error {
select {
case <-c.closed:
default:
close(c.closed)
}
return nil
}
func (c *conn) LocalAddr() net.Addr {
return c.localAddr
}
func (c *conn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *conn) SetDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "grpc", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *conn) SetReadDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "grpc", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *conn) SetWriteDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "grpc", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}

54
listener/http2/conn.go Normal file
View File

@ -0,0 +1,54 @@
package http2
import (
"errors"
"net"
"net/http"
"time"
)
// a dummy HTTP2 server conn used by HTTP2 handler
type conn struct {
r *http.Request
w http.ResponseWriter
closed chan struct{}
}
func (c *conn) Read(b []byte) (n int, err error) {
return 0, &net.OpError{Op: "read", Net: "http2", Source: nil, Addr: nil, Err: errors.New("read not supported")}
}
func (c *conn) Write(b []byte) (n int, err error) {
return 0, &net.OpError{Op: "write", Net: "http2", Source: nil, Addr: nil, Err: errors.New("write not supported")}
}
func (c *conn) Close() error {
select {
case <-c.closed:
default:
close(c.closed)
}
return nil
}
func (c *conn) LocalAddr() net.Addr {
addr, _ := net.ResolveTCPAddr("tcp", c.r.Host)
return addr
}
func (c *conn) RemoteAddr() net.Addr {
addr, _ := net.ResolveTCPAddr("tcp", c.r.RemoteAddr)
return addr
}
func (c *conn) SetDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *conn) SetReadDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *conn) SetWriteDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}

89
listener/http2/h2/conn.go Normal file
View File

@ -0,0 +1,89 @@
package h2
import (
"errors"
"io"
"net"
"net/http"
"time"
)
// HTTP2 connection, wrapped up just like a net.Conn
type conn struct {
r io.Reader
w io.Writer
remoteAddr net.Addr
localAddr net.Addr
closed chan struct{}
}
func (c *conn) Read(b []byte) (n int, err error) {
return c.r.Read(b)
}
func (c *conn) Write(b []byte) (n int, err error) {
return c.w.Write(b)
}
func (c *conn) Close() (err error) {
select {
case <-c.closed:
return
default:
close(c.closed)
}
if rc, ok := c.r.(io.Closer); ok {
err = rc.Close()
}
if w, ok := c.w.(io.Closer); ok {
err = w.Close()
}
return
}
func (c *conn) LocalAddr() net.Addr {
return c.localAddr
}
func (c *conn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *conn) SetDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *conn) SetReadDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *conn) SetWriteDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
type flushWriter struct {
w io.Writer
}
func (fw flushWriter) Write(p []byte) (n int, err error) {
defer func() {
if r := recover(); r != nil {
if s, ok := r.(string); ok {
err = errors.New(s)
// log.Log("[http2]", err)
return
}
err = r.(error)
}
}()
n, err = fw.w.Write(p)
if err != nil {
// log.Log("flush writer:", err)
return
}
if f, ok := fw.w.(http.Flusher); ok {
f.Flush()
}
return
}

View File

@ -0,0 +1,178 @@
package h2
import (
"crypto/tls"
"errors"
"net"
"net/http"
"net/http/httputil"
"github.com/go-gost/gost/v3/pkg/common/admission"
"github.com/go-gost/gost/v3/pkg/common/metrics"
"github.com/go-gost/gost/v3/pkg/listener"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
func init() {
registry.ListenerRegistry().Register("h2c", NewListener)
registry.ListenerRegistry().Register("h2", NewTLSListener)
}
type h2Listener struct {
server *http.Server
addr net.Addr
cqueue chan net.Conn
errChan chan error
logger logger.Logger
md metadata
h2c bool
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &h2Listener{
h2c: true,
logger: options.Logger,
options: options,
}
}
func NewTLSListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &h2Listener{
logger: options.Logger,
options: options,
}
}
func (l *h2Listener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
}
l.server = &http.Server{
Addr: l.options.Addr,
}
ln, err := net.Listen("tcp", l.options.Addr)
if err != nil {
return err
}
l.addr = ln.Addr()
ln = metrics.WrapListener(l.options.Service, ln)
ln = admission.WrapListener(l.options.Admission, ln)
if l.h2c {
l.server.Handler = h2c.NewHandler(
http.HandlerFunc(l.handleFunc), &http2.Server{})
} else {
l.server.Handler = http.HandlerFunc(l.handleFunc)
l.server.TLSConfig = l.options.TLSConfig
if err := http2.ConfigureServer(l.server, nil); err != nil {
ln.Close()
return err
}
ln = tls.NewListener(ln, l.options.TLSConfig)
}
l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1)
go func() {
if err := l.server.Serve(ln); err != nil {
l.logger.Error(err)
}
}()
return
}
func (l *h2Listener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.cqueue:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *h2Listener) Addr() net.Addr {
return l.addr
}
func (l *h2Listener) Close() (err error) {
select {
case <-l.errChan:
default:
err = l.server.Close()
l.errChan <- err
close(l.errChan)
}
return nil
}
func (l *h2Listener) handleFunc(w http.ResponseWriter, r *http.Request) {
if l.logger.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpRequest(r, false)
l.logger.Debug(string(dump))
}
conn, err := l.upgrade(w, r)
if err != nil {
l.logger.Error(err)
return
}
select {
case l.cqueue <- conn:
default:
conn.Close()
l.logger.Warnf("connection queue is full, client %s discarded", r.RemoteAddr)
}
<-conn.closed // NOTE: we need to wait for streaming end, or the connection will be closed
}
func (l *h2Listener) upgrade(w http.ResponseWriter, r *http.Request) (*conn, error) {
if l.md.path == "" && r.Method != http.MethodConnect {
w.WriteHeader(http.StatusMethodNotAllowed)
return nil, errors.New("method not allowed")
}
if l.md.path != "" && r.RequestURI != l.md.path {
w.WriteHeader(http.StatusBadRequest)
return nil, errors.New("bad request")
}
w.WriteHeader(http.StatusOK)
if fw, ok := w.(http.Flusher); ok {
fw.Flush() // write header to client
}
remoteAddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
if remoteAddr == nil {
remoteAddr = &net.TCPAddr{
IP: net.IPv4zero,
Port: 0,
}
}
return &conn{
r: r.Body,
w: flushWriter{w},
localAddr: l.addr,
remoteAddr: remoteAddr,
closed: make(chan struct{}),
}, nil
}

View File

@ -0,0 +1,29 @@
package h2
import (
mdata "github.com/go-gost/gost/v3/pkg/metadata"
)
const (
defaultBacklog = 128
)
type metadata struct {
path string
backlog int
}
func (l *h2Listener) parseMetadata(md mdata.Metadata) (err error) {
const (
path = "path"
backlog = "backlog"
)
l.md.backlog = mdata.GetInt(md, backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
l.md.path = mdata.GetString(md, path)
return
}

120
listener/http2/listener.go Normal file
View File

@ -0,0 +1,120 @@
package http2
import (
"crypto/tls"
"net"
"net/http"
"github.com/go-gost/gost/v3/pkg/common/admission"
"github.com/go-gost/gost/v3/pkg/common/metrics"
"github.com/go-gost/gost/v3/pkg/listener"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
http2_util "github.com/go-gost/x/internal/util/http2"
"golang.org/x/net/http2"
)
func init() {
registry.ListenerRegistry().Register("http2", NewListener)
}
type http2Listener struct {
server *http.Server
addr net.Addr
cqueue chan net.Conn
errChan chan error
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &http2Listener{
logger: options.Logger,
options: options,
}
}
func (l *http2Listener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
}
l.server = &http.Server{
Addr: l.options.Addr,
Handler: http.HandlerFunc(l.handleFunc),
TLSConfig: l.options.TLSConfig,
}
if err := http2.ConfigureServer(l.server, nil); err != nil {
return err
}
ln, err := net.Listen("tcp", l.options.Addr)
if err != nil {
return err
}
l.addr = ln.Addr()
ln = metrics.WrapListener(l.options.Service, ln)
ln = admission.WrapListener(l.options.Admission, ln)
ln = tls.NewListener(
ln,
l.options.TLSConfig,
)
l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1)
go func() {
if err := l.server.Serve(ln); err != nil {
l.logger.Error(err)
}
}()
return
}
func (l *http2Listener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.cqueue:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *http2Listener) Addr() net.Addr {
return l.addr
}
func (l *http2Listener) Close() (err error) {
select {
case <-l.errChan:
default:
err = l.server.Close()
l.errChan <- err
close(l.errChan)
}
return nil
}
func (l *http2Listener) handleFunc(w http.ResponseWriter, r *http.Request) {
raddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
conn := http2_util.NewServerConn(w, r, l.addr, raddr)
select {
case l.cqueue <- conn:
default:
l.logger.Warnf("connection queue is full, client %s discarded", r.RemoteAddr)
return
}
<-conn.Done()
}

View File

@ -0,0 +1,25 @@
package http2
import (
mdata "github.com/go-gost/gost/v3/pkg/metadata"
)
const (
defaultBacklog = 128
)
type metadata struct {
backlog int
}
func (l *http2Listener) parseMetadata(md mdata.Metadata) (err error) {
const (
backlog = "backlog"
)
l.md.backlog = mdata.GetInt(md, backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
return
}

View File

@ -0,0 +1,81 @@
package http3
import (
"net"
"github.com/go-gost/gost/v3/pkg/common/metrics"
"github.com/go-gost/gost/v3/pkg/listener"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
pht_util "github.com/go-gost/x/internal/util/pht"
"github.com/lucas-clemente/quic-go"
)
func init() {
registry.ListenerRegistry().Register("http3", NewListener)
registry.ListenerRegistry().Register("h3", NewListener)
}
type http3Listener struct {
addr net.Addr
server *pht_util.Server
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &http3Listener{
logger: options.Logger,
options: options,
}
}
func (l *http3Listener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
}
l.addr, err = net.ResolveUDPAddr("udp", l.options.Addr)
if err != nil {
return
}
l.server = pht_util.NewHTTP3Server(
l.options.Addr,
&quic.Config{},
pht_util.TLSConfigServerOption(l.options.TLSConfig),
pht_util.BacklogServerOption(l.md.backlog),
pht_util.PathServerOption(l.md.authorizePath, l.md.pushPath, l.md.pullPath),
pht_util.LoggerServerOption(l.options.Logger),
)
go func() {
if err := l.server.ListenAndServe(); err != nil {
l.logger.Error(err)
}
}()
return
}
func (l *http3Listener) Accept() (conn net.Conn, err error) {
conn, err = l.server.Accept()
if err != nil {
return
}
return metrics.WrapConn(l.options.Service, conn), nil
}
func (l *http3Listener) Addr() net.Addr {
return l.addr
}
func (l *http3Listener) Close() (err error) {
return l.server.Close()
}

View File

@ -0,0 +1,51 @@
package http3
import (
"strings"
mdata "github.com/go-gost/gost/v3/pkg/metadata"
)
const (
defaultAuthorizePath = "/authorize"
defaultPushPath = "/push"
defaultPullPath = "/pull"
defaultBacklog = 128
)
type metadata struct {
authorizePath string
pushPath string
pullPath string
backlog int
}
func (l *http3Listener) parseMetadata(md mdata.Metadata) (err error) {
const (
authorizePath = "authorizePath"
pushPath = "pushPath"
pullPath = "pullPath"
backlog = "backlog"
)
l.md.authorizePath = mdata.GetString(md, authorizePath)
if !strings.HasPrefix(l.md.authorizePath, "/") {
l.md.authorizePath = defaultAuthorizePath
}
l.md.pushPath = mdata.GetString(md, pushPath)
if !strings.HasPrefix(l.md.pushPath, "/") {
l.md.pushPath = defaultPushPath
}
l.md.pullPath = mdata.GetString(md, pullPath)
if !strings.HasPrefix(l.md.pullPath, "/") {
l.md.pullPath = defaultPullPath
}
l.md.backlog = mdata.GetInt(md, backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
return
}

21
listener/icmp/conn.go Normal file
View File

@ -0,0 +1,21 @@
package quic
import (
"net"
"github.com/lucas-clemente/quic-go"
)
type quicConn struct {
quic.Stream
laddr net.Addr
raddr net.Addr
}
func (c *quicConn) LocalAddr() net.Addr {
return c.laddr
}
func (c *quicConn) RemoteAddr() net.Addr {
return c.raddr
}

147
listener/icmp/listener.go Normal file
View File

@ -0,0 +1,147 @@
package quic
import (
"context"
"net"
"github.com/go-gost/gost/v3/pkg/common/admission"
"github.com/go-gost/gost/v3/pkg/common/metrics"
"github.com/go-gost/gost/v3/pkg/listener"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
icmp_pkg "github.com/go-gost/x/internal/util/icmp"
"github.com/lucas-clemente/quic-go"
"golang.org/x/net/icmp"
)
func init() {
registry.ListenerRegistry().Register("icmp", NewListener)
}
type icmpListener struct {
ln quic.Listener
cqueue chan net.Conn
errChan chan error
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &icmpListener{
logger: options.Logger,
options: options,
}
}
func (l *icmpListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
}
addr := l.options.Addr
if host, _, err := net.SplitHostPort(addr); err == nil {
addr = host
}
var conn net.PacketConn
conn, err = icmp.ListenPacket("ip4:icmp", addr)
if err != nil {
return
}
conn = icmp_pkg.ServerConn(conn)
conn = metrics.WrapPacketConn(l.options.Service, conn)
conn = admission.WrapPacketConn(l.options.Admission, conn)
config := &quic.Config{
KeepAlive: l.md.keepAlive,
HandshakeIdleTimeout: l.md.handshakeTimeout,
MaxIdleTimeout: l.md.maxIdleTimeout,
Versions: []quic.VersionNumber{
quic.Version1,
quic.VersionDraft29,
},
}
tlsCfg := l.options.TLSConfig
tlsCfg.NextProtos = []string{"http/3", "quic/v1"}
ln, err := quic.Listen(conn, tlsCfg, config)
if err != nil {
return
}
l.ln = ln
l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1)
go l.listenLoop()
return
}
func (l *icmpListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.cqueue:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *icmpListener) Close() error {
return l.ln.Close()
}
func (l *icmpListener) Addr() net.Addr {
return l.ln.Addr()
}
func (l *icmpListener) listenLoop() {
for {
ctx := context.Background()
session, err := l.ln.Accept(ctx)
if err != nil {
l.logger.Error("accept: ", err)
l.errChan <- err
close(l.errChan)
return
}
l.logger.Infof("new client session: %v", session.RemoteAddr())
go l.mux(ctx, session)
}
}
func (l *icmpListener) mux(ctx context.Context, session quic.Session) {
defer session.CloseWithError(0, "closed")
for {
stream, err := session.AcceptStream(ctx)
if err != nil {
l.logger.Error("accept stream: ", err)
return
}
conn := &quicConn{
Stream: stream,
laddr: session.LocalAddr(),
raddr: session.RemoteAddr(),
}
select {
case l.cqueue <- conn:
case <-stream.Context().Done():
stream.Close()
default:
stream.Close()
l.logger.Warnf("connection queue is full, client %s discarded", session.RemoteAddr())
}
}
}

41
listener/icmp/metadata.go Normal file
View File

@ -0,0 +1,41 @@
package quic
import (
"time"
mdata "github.com/go-gost/gost/v3/pkg/metadata"
)
const (
defaultBacklog = 128
)
type metadata struct {
keepAlive bool
handshakeTimeout time.Duration
maxIdleTimeout time.Duration
cipherKey []byte
backlog int
}
func (l *icmpListener) parseMetadata(md mdata.Metadata) (err error) {
const (
keepAlive = "keepAlive"
handshakeTimeout = "handshakeTimeout"
maxIdleTimeout = "maxIdleTimeout"
backlog = "backlog"
)
l.md.backlog = mdata.GetInt(md, backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
l.md.keepAlive = mdata.GetBool(md, keepAlive)
l.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout)
l.md.maxIdleTimeout = mdata.GetDuration(md, maxIdleTimeout)
return
}

179
listener/kcp/listener.go Normal file
View File

@ -0,0 +1,179 @@
package kcp
import (
"net"
"time"
"github.com/go-gost/gost/v3/pkg/common/admission"
"github.com/go-gost/gost/v3/pkg/common/metrics"
"github.com/go-gost/gost/v3/pkg/listener"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
kcp_util "github.com/go-gost/x/internal/util/kcp"
"github.com/xtaci/kcp-go/v5"
"github.com/xtaci/smux"
"github.com/xtaci/tcpraw"
)
func init() {
registry.ListenerRegistry().Register("kcp", NewListener)
}
type kcpListener struct {
conn net.PacketConn
ln *kcp.Listener
cqueue chan net.Conn
errChan chan error
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &kcpListener{
logger: options.Logger,
options: options,
}
}
func (l *kcpListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
}
config := l.md.config
config.Init()
var conn net.PacketConn
if config.TCP {
conn, err = tcpraw.Listen("tcp", l.options.Addr)
} else {
var udpAddr *net.UDPAddr
udpAddr, err = net.ResolveUDPAddr("udp", l.options.Addr)
if err != nil {
return
}
conn, err = net.ListenUDP("udp", udpAddr)
}
if err != nil {
return
}
conn = metrics.WrapUDPConn(l.options.Service, conn)
conn = admission.WrapUDPConn(l.options.Admission, conn)
ln, err := kcp.ServeConn(
kcp_util.BlockCrypt(config.Key, config.Crypt, kcp_util.DefaultSalt),
config.DataShard, config.ParityShard, conn)
if err != nil {
return
}
if config.DSCP > 0 {
if er := ln.SetDSCP(config.DSCP); er != nil {
l.logger.Warn(er)
}
}
if er := ln.SetReadBuffer(config.SockBuf); er != nil {
l.logger.Warn(er)
}
if er := ln.SetWriteBuffer(config.SockBuf); er != nil {
l.logger.Warn(er)
}
l.ln = ln
l.conn = conn
l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1)
go l.listenLoop()
return
}
func (l *kcpListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.cqueue:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *kcpListener) Close() error {
l.conn.Close()
return l.ln.Close()
}
func (l *kcpListener) Addr() net.Addr {
return l.ln.Addr()
}
func (l *kcpListener) listenLoop() {
for {
conn, err := l.ln.AcceptKCP()
if err != nil {
l.logger.Error("accept:", err)
l.errChan <- err
close(l.errChan)
return
}
conn.SetStreamMode(true)
conn.SetWriteDelay(false)
conn.SetNoDelay(
l.md.config.NoDelay,
l.md.config.Interval,
l.md.config.Resend,
l.md.config.NoCongestion,
)
conn.SetMtu(l.md.config.MTU)
conn.SetWindowSize(l.md.config.SndWnd, l.md.config.RcvWnd)
conn.SetACKNoDelay(l.md.config.AckNodelay)
go l.mux(conn)
}
}
func (l *kcpListener) mux(conn net.Conn) {
defer conn.Close()
smuxConfig := smux.DefaultConfig()
smuxConfig.MaxReceiveBuffer = l.md.config.SockBuf
smuxConfig.KeepAliveInterval = time.Duration(l.md.config.KeepAlive) * time.Second
if !l.md.config.NoComp {
conn = kcp_util.CompStreamConn(conn)
}
mux, err := smux.Server(conn, smuxConfig)
if err != nil {
l.logger.Error(err)
return
}
defer mux.Close()
for {
stream, err := mux.AcceptStream()
if err != nil {
l.logger.Error("accept stream: ", err)
return
}
select {
case l.cqueue <- stream:
case <-stream.GetDieCh():
stream.Close()
default:
stream.Close()
l.logger.Warnf("connection queue is full, client %s discarded", stream.RemoteAddr())
}
}
}

47
listener/kcp/metadata.go Normal file
View File

@ -0,0 +1,47 @@
package kcp
import (
"encoding/json"
mdata "github.com/go-gost/gost/v3/pkg/metadata"
kcp_util "github.com/go-gost/x/internal/util/kcp"
)
const (
defaultBacklog = 128
)
type metadata struct {
config *kcp_util.Config
backlog int
}
func (l *kcpListener) parseMetadata(md mdata.Metadata) (err error) {
const (
backlog = "backlog"
config = "config"
)
if m := mdata.GetStringMap(md, config); len(m) > 0 {
b, err := json.Marshal(m)
if err != nil {
return err
}
cfg := &kcp_util.Config{}
if err := json.Unmarshal(b, cfg); err != nil {
return err
}
l.md.config = cfg
}
if l.md.config == nil {
l.md.config = kcp_util.DefaultConfig
}
l.md.backlog = mdata.GetInt(md, backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
return
}

129
listener/mtls/listener.go Normal file
View File

@ -0,0 +1,129 @@
package mtls
import (
"crypto/tls"
"net"
"github.com/go-gost/gost/v3/pkg/common/admission"
"github.com/go-gost/gost/v3/pkg/common/metrics"
"github.com/go-gost/gost/v3/pkg/listener"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
"github.com/xtaci/smux"
)
func init() {
registry.ListenerRegistry().Register("mtls", NewListener)
}
type mtlsListener struct {
net.Listener
cqueue chan net.Conn
errChan chan error
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &mtlsListener{
logger: options.Logger,
options: options,
}
}
func (l *mtlsListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
}
ln, err := net.Listen("tcp", l.options.Addr)
if err != nil {
return
}
ln = metrics.WrapListener(l.options.Service, ln)
ln = admission.WrapListener(l.options.Admission, ln)
l.Listener = tls.NewListener(ln, l.options.TLSConfig)
l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1)
go l.listenLoop()
return
}
func (l *mtlsListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.cqueue:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *mtlsListener) listenLoop() {
for {
conn, err := l.Listener.Accept()
if err != nil {
l.errChan <- err
close(l.errChan)
return
}
go l.mux(conn)
}
}
func (l *mtlsListener) mux(conn net.Conn) {
defer conn.Close()
smuxConfig := smux.DefaultConfig()
smuxConfig.KeepAliveDisabled = l.md.muxKeepAliveDisabled
if l.md.muxKeepAliveInterval > 0 {
smuxConfig.KeepAliveInterval = l.md.muxKeepAliveInterval
}
if l.md.muxKeepAliveTimeout > 0 {
smuxConfig.KeepAliveTimeout = l.md.muxKeepAliveTimeout
}
if l.md.muxMaxFrameSize > 0 {
smuxConfig.MaxFrameSize = l.md.muxMaxFrameSize
}
if l.md.muxMaxReceiveBuffer > 0 {
smuxConfig.MaxReceiveBuffer = l.md.muxMaxReceiveBuffer
}
if l.md.muxMaxStreamBuffer > 0 {
smuxConfig.MaxStreamBuffer = l.md.muxMaxStreamBuffer
}
session, err := smux.Server(conn, smuxConfig)
if err != nil {
l.logger.Error(err)
return
}
defer session.Close()
for {
stream, err := session.AcceptStream()
if err != nil {
l.logger.Error("accept stream: ", err)
return
}
select {
case l.cqueue <- stream:
case <-stream.GetDieCh():
stream.Close()
default:
stream.Close()
l.logger.Warnf("connection queue is full, client %s discarded", stream.RemoteAddr())
}
}
}

49
listener/mtls/metadata.go Normal file
View File

@ -0,0 +1,49 @@
package mtls
import (
"time"
mdata "github.com/go-gost/gost/v3/pkg/metadata"
)
const (
defaultBacklog = 128
)
type metadata struct {
muxKeepAliveDisabled bool
muxKeepAliveInterval time.Duration
muxKeepAliveTimeout time.Duration
muxMaxFrameSize int
muxMaxReceiveBuffer int
muxMaxStreamBuffer int
backlog int
}
func (l *mtlsListener) parseMetadata(md mdata.Metadata) (err error) {
const (
backlog = "backlog"
muxKeepAliveDisabled = "muxKeepAliveDisabled"
muxKeepAliveInterval = "muxKeepAliveInterval"
muxKeepAliveTimeout = "muxKeepAliveTimeout"
muxMaxFrameSize = "muxMaxFrameSize"
muxMaxReceiveBuffer = "muxMaxReceiveBuffer"
muxMaxStreamBuffer = "muxMaxStreamBuffer"
)
l.md.backlog = mdata.GetInt(md, backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
l.md.muxKeepAliveDisabled = mdata.GetBool(md, muxKeepAliveDisabled)
l.md.muxKeepAliveInterval = mdata.GetDuration(md, muxKeepAliveInterval)
l.md.muxKeepAliveTimeout = mdata.GetDuration(md, muxKeepAliveTimeout)
l.md.muxMaxFrameSize = mdata.GetInt(md, muxMaxFrameSize)
l.md.muxMaxReceiveBuffer = mdata.GetInt(md, muxMaxReceiveBuffer)
l.md.muxMaxStreamBuffer = mdata.GetInt(md, muxMaxStreamBuffer)
return
}

194
listener/mws/listener.go Normal file
View File

@ -0,0 +1,194 @@
package mws
import (
"crypto/tls"
"net"
"net/http"
"net/http/httputil"
"github.com/go-gost/gost/v3/pkg/common/admission"
"github.com/go-gost/gost/v3/pkg/common/metrics"
"github.com/go-gost/gost/v3/pkg/listener"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
ws_util "github.com/go-gost/x/internal/util/ws"
"github.com/gorilla/websocket"
"github.com/xtaci/smux"
)
func init() {
registry.ListenerRegistry().Register("mws", NewListener)
registry.ListenerRegistry().Register("mwss", NewTLSListener)
}
type mwsListener struct {
addr net.Addr
upgrader *websocket.Upgrader
srv *http.Server
cqueue chan net.Conn
errChan chan error
tlsEnabled bool
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &mwsListener{
logger: options.Logger,
options: options,
}
}
func NewTLSListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &mwsListener{
tlsEnabled: true,
logger: options.Logger,
options: options,
}
}
func (l *mwsListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
}
l.upgrader = &websocket.Upgrader{
HandshakeTimeout: l.md.handshakeTimeout,
ReadBufferSize: l.md.readBufferSize,
WriteBufferSize: l.md.writeBufferSize,
EnableCompression: l.md.enableCompression,
CheckOrigin: func(r *http.Request) bool { return true },
}
path := l.md.path
if path == "" {
path = defaultPath
}
mux := http.NewServeMux()
mux.Handle(path, http.HandlerFunc(l.upgrade))
l.srv = &http.Server{
Addr: l.options.Addr,
Handler: mux,
ReadHeaderTimeout: l.md.readHeaderTimeout,
}
l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1)
ln, err := net.Listen("tcp", l.options.Addr)
if err != nil {
return
}
ln = metrics.WrapListener(l.options.Service, ln)
ln = admission.WrapListener(l.options.Admission, ln)
if l.tlsEnabled {
ln = tls.NewListener(ln, l.options.TLSConfig)
}
l.addr = ln.Addr()
go func() {
err := l.srv.Serve(ln)
if err != nil {
l.errChan <- err
}
close(l.errChan)
}()
return
}
func (l *mwsListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.cqueue:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *mwsListener) Close() error {
return l.srv.Close()
}
func (l *mwsListener) Addr() net.Addr {
return l.addr
}
func (l *mwsListener) upgrade(w http.ResponseWriter, r *http.Request) {
if l.logger.IsLevelEnabled(logger.DebugLevel) {
log := l.logger.WithFields(map[string]any{
"local": l.addr.String(),
"remote": r.RemoteAddr,
})
dump, _ := httputil.DumpRequest(r, false)
log.Debug(string(dump))
}
conn, err := l.upgrader.Upgrade(w, r, l.md.header)
if err != nil {
l.logger.Error(err)
return
}
l.mux(ws_util.Conn(conn))
}
func (l *mwsListener) mux(conn net.Conn) {
defer conn.Close()
smuxConfig := smux.DefaultConfig()
smuxConfig.KeepAliveDisabled = l.md.muxKeepAliveDisabled
if l.md.muxKeepAliveInterval > 0 {
smuxConfig.KeepAliveInterval = l.md.muxKeepAliveInterval
}
if l.md.muxKeepAliveTimeout > 0 {
smuxConfig.KeepAliveTimeout = l.md.muxKeepAliveTimeout
}
if l.md.muxMaxFrameSize > 0 {
smuxConfig.MaxFrameSize = l.md.muxMaxFrameSize
}
if l.md.muxMaxReceiveBuffer > 0 {
smuxConfig.MaxReceiveBuffer = l.md.muxMaxReceiveBuffer
}
if l.md.muxMaxStreamBuffer > 0 {
smuxConfig.MaxStreamBuffer = l.md.muxMaxStreamBuffer
}
session, err := smux.Server(conn, smuxConfig)
if err != nil {
l.logger.Error(err)
return
}
defer session.Close()
for {
stream, err := session.AcceptStream()
if err != nil {
l.logger.Error("accept stream: ", err)
return
}
select {
case l.cqueue <- stream:
case <-stream.GetDieCh():
stream.Close()
default:
stream.Close()
l.logger.Warnf("connection queue is full, client %s discarded", stream.RemoteAddr())
}
}
}

85
listener/mws/metadata.go Normal file
View File

@ -0,0 +1,85 @@
package mws
import (
"net/http"
"time"
mdata "github.com/go-gost/gost/v3/pkg/metadata"
)
const (
defaultPath = "/ws"
defaultBacklog = 128
)
type metadata struct {
path string
backlog int
header http.Header
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
readBufferSize int
writeBufferSize int
enableCompression bool
muxKeepAliveDisabled bool
muxKeepAliveInterval time.Duration
muxKeepAliveTimeout time.Duration
muxMaxFrameSize int
muxMaxReceiveBuffer int
muxMaxStreamBuffer int
}
func (l *mwsListener) parseMetadata(md mdata.Metadata) (err error) {
const (
path = "path"
backlog = "backlog"
header = "header"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
enableCompression = "enableCompression"
muxKeepAliveDisabled = "muxKeepAliveDisabled"
muxKeepAliveInterval = "muxKeepAliveInterval"
muxKeepAliveTimeout = "muxKeepAliveTimeout"
muxMaxFrameSize = "muxMaxFrameSize"
muxMaxReceiveBuffer = "muxMaxReceiveBuffer"
muxMaxStreamBuffer = "muxMaxStreamBuffer"
)
l.md.path = mdata.GetString(md, path)
if l.md.path == "" {
l.md.path = defaultPath
}
l.md.backlog = mdata.GetInt(md, backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
l.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout)
l.md.readHeaderTimeout = mdata.GetDuration(md, readHeaderTimeout)
l.md.readBufferSize = mdata.GetInt(md, readBufferSize)
l.md.writeBufferSize = mdata.GetInt(md, writeBufferSize)
l.md.enableCompression = mdata.GetBool(md, enableCompression)
l.md.muxKeepAliveDisabled = mdata.GetBool(md, muxKeepAliveDisabled)
l.md.muxKeepAliveInterval = mdata.GetDuration(md, muxKeepAliveInterval)
l.md.muxKeepAliveTimeout = mdata.GetDuration(md, muxKeepAliveTimeout)
l.md.muxMaxFrameSize = mdata.GetInt(md, muxMaxFrameSize)
l.md.muxMaxReceiveBuffer = mdata.GetInt(md, muxMaxReceiveBuffer)
l.md.muxMaxStreamBuffer = mdata.GetInt(md, muxMaxStreamBuffer)
if mm := mdata.GetStringMapString(md, header); len(mm) > 0 {
hd := http.Header{}
for k, v := range mm {
hd.Add(k, v)
}
l.md.header = hd
}
return
}

145
listener/obfs/http/conn.go Normal file
View File

@ -0,0 +1,145 @@
package http
import (
"bufio"
"bytes"
"crypto/sha1"
"encoding/base64"
"errors"
"io"
"net"
"net/http"
"net/http/httputil"
"sync"
"time"
"github.com/go-gost/gost/v3/pkg/logger"
)
type obfsHTTPConn struct {
net.Conn
rbuf bytes.Buffer
wbuf bytes.Buffer
handshaked bool
handshakeMutex sync.Mutex
header http.Header
logger logger.Logger
}
func (c *obfsHTTPConn) Handshake() (err error) {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
if c.handshaked {
return nil
}
if err = c.handshake(); err != nil {
return
}
c.handshaked = true
return nil
}
func (c *obfsHTTPConn) handshake() (err error) {
br := bufio.NewReader(c.Conn)
r, err := http.ReadRequest(br)
if err != nil {
return
}
if c.logger.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpRequest(r, false)
c.logger.Debug(string(dump))
}
if r.ContentLength > 0 {
_, err = io.Copy(&c.rbuf, r.Body)
} else {
var b []byte
b, err = br.Peek(br.Buffered())
if len(b) > 0 {
_, err = c.rbuf.Write(b)
}
}
if err != nil {
c.logger.Error(err)
return
}
resp := http.Response{
StatusCode: http.StatusOK,
ProtoMajor: 1,
ProtoMinor: 1,
Header: c.header,
}
if resp.Header == nil {
resp.Header = http.Header{}
}
resp.Header.Set("Date", time.Now().Format(time.RFC1123))
if r.Method != http.MethodGet || r.Header.Get("Upgrade") != "websocket" {
resp.StatusCode = http.StatusBadRequest
if c.logger.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpResponse(&resp, false)
c.logger.Debug(string(dump))
}
resp.Write(c.Conn)
return errors.New("bad request")
}
resp.StatusCode = http.StatusSwitchingProtocols
resp.Header.Set("Connection", "Upgrade")
resp.Header.Set("Upgrade", "websocket")
resp.Header.Set("Sec-WebSocket-Accept", c.computeAcceptKey(r.Header.Get("Sec-WebSocket-Key")))
if c.logger.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpResponse(&resp, false)
c.logger.Debug(string(dump))
}
if c.rbuf.Len() > 0 {
// cache the response header if there are extra data in the request body.
resp.Write(&c.wbuf)
return
}
err = resp.Write(c.Conn)
return
}
func (c *obfsHTTPConn) Read(b []byte) (n int, err error) {
if err = c.Handshake(); err != nil {
return
}
if c.rbuf.Len() > 0 {
return c.rbuf.Read(b)
}
return c.Conn.Read(b)
}
func (c *obfsHTTPConn) Write(b []byte) (n int, err error) {
if err = c.Handshake(); err != nil {
return
}
if c.wbuf.Len() > 0 {
c.wbuf.Write(b) // append the data to the cached header
_, err = c.wbuf.WriteTo(c.Conn)
n = len(b) // exclude the header length
return
}
return c.Conn.Write(b)
}
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
func (c *obfsHTTPConn) computeAcceptKey(challengeKey string) string {
h := sha1.New()
h.Write([]byte(challengeKey))
h.Write(keyGUID)
return base64.StdEncoding.EncodeToString(h.Sum(nil))
}

View File

@ -0,0 +1,63 @@
package http
import (
"net"
"github.com/go-gost/gost/v3/pkg/common/admission"
"github.com/go-gost/gost/v3/pkg/common/metrics"
"github.com/go-gost/gost/v3/pkg/listener"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
)
func init() {
registry.ListenerRegistry().Register("ohttp", NewListener)
}
type obfsListener struct {
net.Listener
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &obfsListener{
logger: options.Logger,
options: options,
}
}
func (l *obfsListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
}
ln, err := net.Listen("tcp", l.options.Addr)
if err != nil {
return
}
ln = metrics.WrapListener(l.options.Service, ln)
ln = admission.WrapListener(l.options.Admission, ln)
l.Listener = ln
return
}
func (l *obfsListener) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return &obfsHTTPConn{
Conn: c,
header: l.md.header,
logger: l.logger,
}, nil
}

View File

@ -0,0 +1,26 @@
package http
import (
"net/http"
mdata "github.com/go-gost/gost/v3/pkg/metadata"
)
type metadata struct {
header http.Header
}
func (l *obfsListener) parseMetadata(md mdata.Metadata) (err error) {
const (
header = "header"
)
if mm := mdata.GetStringMapString(md, header); len(mm) > 0 {
hd := http.Header{}
for k, v := range mm {
hd.Add(k, v)
}
l.md.header = hd
}
return
}

161
listener/obfs/tls/conn.go Normal file
View File

@ -0,0 +1,161 @@
package tls
import (
"bytes"
"crypto/rand"
"crypto/tls"
"net"
"sync"
"time"
dissector "github.com/go-gost/tls-dissector"
)
const (
maxTLSDataLen = 16384
)
type obfsTLSConn struct {
net.Conn
rbuf bytes.Buffer
wbuf bytes.Buffer
handshaked bool
handshakeMutex sync.Mutex
}
func (c *obfsTLSConn) Handshake() (err error) {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
if c.handshaked {
return
}
if err = c.handshake(); err != nil {
return
}
c.handshaked = true
return nil
}
func (c *obfsTLSConn) handshake() error {
record := &dissector.Record{}
if _, err := record.ReadFrom(c.Conn); err != nil {
// log.Log(err)
return err
}
if record.Type != dissector.Handshake {
return dissector.ErrBadType
}
clientMsg := &dissector.ClientHelloMsg{}
if err := clientMsg.Decode(record.Opaque); err != nil {
// log.Log(err)
return err
}
for _, ext := range clientMsg.Extensions {
if ext.Type() == dissector.ExtSessionTicket {
b, err := ext.Encode()
if err != nil {
// log.Log(err)
return err
}
c.rbuf.Write(b)
break
}
}
serverMsg := &dissector.ServerHelloMsg{
Version: tls.VersionTLS12,
SessionID: clientMsg.SessionID,
CipherSuite: 0xcca8,
CompressionMethod: 0x00,
Extensions: []dissector.Extension{
&dissector.RenegotiationInfoExtension{},
&dissector.ExtendedMasterSecretExtension{},
&dissector.ECPointFormatsExtension{
Formats: []uint8{0x00},
},
},
}
serverMsg.Random.Time = uint32(time.Now().Unix())
rand.Read(serverMsg.Random.Opaque[:])
b, err := serverMsg.Encode()
if err != nil {
return err
}
record = &dissector.Record{
Type: dissector.Handshake,
Version: tls.VersionTLS10,
Opaque: b,
}
if _, err := record.WriteTo(&c.wbuf); err != nil {
return err
}
record = &dissector.Record{
Type: dissector.ChangeCipherSpec,
Version: tls.VersionTLS12,
Opaque: []byte{0x01},
}
if _, err := record.WriteTo(&c.wbuf); err != nil {
return err
}
return nil
}
func (c *obfsTLSConn) Read(b []byte) (n int, err error) {
if err = c.Handshake(); err != nil {
return
}
if c.rbuf.Len() > 0 {
return c.rbuf.Read(b)
}
record := &dissector.Record{}
if _, err = record.ReadFrom(c.Conn); err != nil {
return
}
n = copy(b, record.Opaque)
_, err = c.rbuf.Write(record.Opaque[n:])
return
}
func (c *obfsTLSConn) Write(b []byte) (n int, err error) {
if err = c.Handshake(); err != nil {
return
}
n = len(b)
for len(b) > 0 {
data := b
if len(b) > maxTLSDataLen {
data = b[:maxTLSDataLen]
b = b[maxTLSDataLen:]
} else {
b = b[:0]
}
record := &dissector.Record{
Type: dissector.AppData,
Version: tls.VersionTLS12,
Opaque: data,
}
if c.wbuf.Len() > 0 {
record.Type = dissector.Handshake
record.WriteTo(&c.wbuf)
_, err = c.wbuf.WriteTo(c.Conn)
return
}
if _, err = record.WriteTo(c.Conn); err != nil {
return
}
}
return
}

View File

@ -0,0 +1,61 @@
package tls
import (
"net"
"github.com/go-gost/gost/v3/pkg/common/admission"
"github.com/go-gost/gost/v3/pkg/common/metrics"
"github.com/go-gost/gost/v3/pkg/listener"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
)
func init() {
registry.ListenerRegistry().Register("otls", NewListener)
}
type obfsListener struct {
net.Listener
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &obfsListener{
logger: options.Logger,
options: options,
}
}
func (l *obfsListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
}
ln, err := net.Listen("tcp", l.options.Addr)
if err != nil {
return
}
ln = metrics.WrapListener(l.options.Service, ln)
ln = admission.WrapListener(l.options.Admission, ln)
l.Listener = ln
return
}
func (l *obfsListener) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return &obfsTLSConn{
Conn: c,
}, nil
}

View File

@ -0,0 +1,12 @@
package tls
import (
md "github.com/go-gost/gost/v3/pkg/metadata"
)
type metadata struct {
}
func (l *obfsListener) parseMetadata(md md.Metadata) (err error) {
return
}

96
listener/pht/listener.go Normal file
View File

@ -0,0 +1,96 @@
// plain http tunnel
package pht
import (
"net"
"github.com/go-gost/gost/v3/pkg/common/metrics"
"github.com/go-gost/gost/v3/pkg/listener"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
pht_util "github.com/go-gost/x/internal/util/pht"
)
func init() {
registry.ListenerRegistry().Register("pht", NewListener)
registry.ListenerRegistry().Register("phts", NewTLSListener)
}
type phtListener struct {
addr net.Addr
tlsEnabled bool
server *pht_util.Server
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &phtListener{
logger: options.Logger,
options: options,
}
}
func NewTLSListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &phtListener{
tlsEnabled: true,
logger: options.Logger,
options: options,
}
}
func (l *phtListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
}
l.addr, err = net.ResolveTCPAddr("tcp", l.options.Addr)
if err != nil {
return
}
l.server = pht_util.NewServer(
l.options.Addr,
pht_util.TLSConfigServerOption(l.options.TLSConfig),
pht_util.EnableTLSServerOption(l.tlsEnabled),
pht_util.BacklogServerOption(l.md.backlog),
pht_util.PathServerOption(l.md.authorizePath, l.md.pushPath, l.md.pullPath),
pht_util.LoggerServerOption(l.options.Logger),
)
go func() {
if err := l.server.ListenAndServe(); err != nil {
l.logger.Error(err)
}
}()
return
}
func (l *phtListener) Accept() (conn net.Conn, err error) {
conn, err = l.server.Accept()
if err != nil {
return
}
conn = metrics.WrapConn(l.options.Service, conn)
return
}
func (l *phtListener) Addr() net.Addr {
return l.addr
}
func (l *phtListener) Close() (err error) {
return l.server.Close()
}

51
listener/pht/metadata.go Normal file
View File

@ -0,0 +1,51 @@
package pht
import (
"strings"
mdata "github.com/go-gost/gost/v3/pkg/metadata"
)
const (
defaultAuthorizePath = "/authorize"
defaultPushPath = "/push"
defaultPullPath = "/pull"
defaultBacklog = 128
)
type metadata struct {
authorizePath string
pushPath string
pullPath string
backlog int
}
func (l *phtListener) parseMetadata(md mdata.Metadata) (err error) {
const (
authorizePath = "authorizePath"
pushPath = "pushPath"
pullPath = "pullPath"
backlog = "backlog"
)
l.md.authorizePath = mdata.GetString(md, authorizePath)
if !strings.HasPrefix(l.md.authorizePath, "/") {
l.md.authorizePath = defaultAuthorizePath
}
l.md.pushPath = mdata.GetString(md, pushPath)
if !strings.HasPrefix(l.md.pushPath, "/") {
l.md.pushPath = defaultPushPath
}
l.md.pullPath = mdata.GetString(md, pullPath)
if !strings.HasPrefix(l.md.pullPath, "/") {
l.md.pullPath = defaultPullPath
}
l.md.backlog = mdata.GetInt(md, backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
return
}

21
listener/quic/conn.go Normal file
View File

@ -0,0 +1,21 @@
package quic
import (
"net"
"github.com/lucas-clemente/quic-go"
)
type quicConn struct {
quic.Stream
laddr net.Addr
raddr net.Addr
}
func (c *quicConn) LocalAddr() net.Addr {
return c.laddr
}
func (c *quicConn) RemoteAddr() net.Addr {
return c.raddr
}

151
listener/quic/listener.go Normal file
View File

@ -0,0 +1,151 @@
package quic
import (
"context"
"net"
"github.com/go-gost/gost/v3/pkg/common/metrics"
"github.com/go-gost/gost/v3/pkg/listener"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
quic_util "github.com/go-gost/x/internal/util/quic"
"github.com/lucas-clemente/quic-go"
)
func init() {
registry.ListenerRegistry().Register("quic", NewListener)
}
type quicListener struct {
ln quic.Listener
cqueue chan net.Conn
errChan chan error
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &quicListener{
logger: options.Logger,
options: options,
}
}
func (l *quicListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
}
addr := l.options.Addr
if _, _, err := net.SplitHostPort(addr); err != nil {
addr = net.JoinHostPort(addr, "0")
}
var laddr *net.UDPAddr
laddr, err = net.ResolveUDPAddr("udp", addr)
if err != nil {
return
}
var conn net.PacketConn
conn, err = net.ListenUDP("udp", laddr)
if err != nil {
return
}
if l.md.cipherKey != nil {
conn = quic_util.CipherPacketConn(conn, l.md.cipherKey)
}
config := &quic.Config{
KeepAlive: l.md.keepAlive,
HandshakeIdleTimeout: l.md.handshakeTimeout,
MaxIdleTimeout: l.md.maxIdleTimeout,
Versions: []quic.VersionNumber{
quic.Version1,
quic.VersionDraft29,
},
}
tlsCfg := l.options.TLSConfig
tlsCfg.NextProtos = []string{"http/3", "quic/v1"}
ln, err := quic.Listen(conn, tlsCfg, config)
if err != nil {
return
}
l.ln = ln
l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1)
go l.listenLoop()
return
}
func (l *quicListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.cqueue:
conn = metrics.WrapConn(l.options.Service, conn)
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *quicListener) Close() error {
return l.ln.Close()
}
func (l *quicListener) Addr() net.Addr {
return l.ln.Addr()
}
func (l *quicListener) listenLoop() {
for {
ctx := context.Background()
session, err := l.ln.Accept(ctx)
if err != nil {
l.logger.Error("accept:", err)
l.errChan <- err
close(l.errChan)
return
}
go l.mux(ctx, session)
}
}
func (l *quicListener) mux(ctx context.Context, session quic.Session) {
defer session.CloseWithError(0, "closed")
for {
stream, err := session.AcceptStream(ctx)
if err != nil {
l.logger.Error("accept stream:", err)
return
}
conn := &quicConn{
Stream: stream,
laddr: session.LocalAddr(),
raddr: session.RemoteAddr(),
}
select {
case l.cqueue <- conn:
case <-stream.Context().Done():
stream.Close()
default:
stream.Close()
l.logger.Warnf("connection queue is full, client %s discarded", session.RemoteAddr())
}
}
}

46
listener/quic/metadata.go Normal file
View File

@ -0,0 +1,46 @@
package quic
import (
"time"
mdata "github.com/go-gost/gost/v3/pkg/metadata"
)
const (
defaultBacklog = 128
)
type metadata struct {
keepAlive bool
handshakeTimeout time.Duration
maxIdleTimeout time.Duration
cipherKey []byte
backlog int
}
func (l *quicListener) parseMetadata(md mdata.Metadata) (err error) {
const (
keepAlive = "keepAlive"
handshakeTimeout = "handshakeTimeout"
maxIdleTimeout = "maxIdleTimeout"
backlog = "backlog"
cipherKey = "cipherKey"
)
l.md.backlog = mdata.GetInt(md, backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
if key := mdata.GetString(md, cipherKey); key != "" {
l.md.cipherKey = []byte(key)
}
l.md.keepAlive = mdata.GetBool(md, keepAlive)
l.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout)
l.md.maxIdleTimeout = mdata.GetDuration(md, maxIdleTimeout)
return
}

View File

@ -0,0 +1,42 @@
package udp
import (
"net"
"sync"
"time"
"github.com/go-gost/gost/v3/pkg/common/bufpool"
)
type redirConn struct {
net.Conn
buf []byte
ttl time.Duration
once sync.Once
}
func (c *redirConn) Read(b []byte) (n int, err error) {
if c.ttl > 0 {
c.SetReadDeadline(time.Now().Add(c.ttl))
defer c.SetReadDeadline(time.Time{})
}
c.once.Do(func() {
n = copy(b, c.buf)
bufpool.Put(&c.buf)
c.buf = nil
})
if n == 0 {
n, err = c.Conn.Read(b)
}
return
}
func (c *redirConn) Write(b []byte) (n int, err error) {
if c.ttl > 0 {
c.SetWriteDeadline(time.Now().Add(c.ttl))
defer c.SetWriteDeadline(time.Time{})
}
return c.Conn.Write(b)
}

View File

@ -0,0 +1,68 @@
package udp
import (
"net"
"github.com/go-gost/gost/v3/pkg/listener"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
)
func init() {
registry.ListenerRegistry().Register("redu", NewListener)
}
type redirectListener struct {
ln *net.UDPConn
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &redirectListener{
logger: options.Logger,
options: options,
}
}
func (l *redirectListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
}
laddr, err := net.ResolveUDPAddr("udp", l.options.Addr)
if err != nil {
return
}
ln, err := l.listenUDP(laddr)
if err != nil {
return
}
l.ln = ln
return
}
func (l *redirectListener) Accept() (conn net.Conn, err error) {
conn, err = l.accept()
if err != nil {
return
}
// conn = metrics.WrapConn(l.options.Service, conn)
return
}
func (l *redirectListener) Addr() net.Addr {
return l.ln.LocalAddr()
}
func (l *redirectListener) Close() error {
return l.ln.Close()
}

View File

@ -0,0 +1,37 @@
package udp
import (
"net"
"github.com/LiamHaworth/go-tproxy"
"github.com/go-gost/gost/v3/pkg/common/bufpool"
)
func (l *redirectListener) listenUDP(addr *net.UDPAddr) (*net.UDPConn, error) {
return tproxy.ListenUDP("udp", addr)
}
func (l *redirectListener) accept() (conn net.Conn, err error) {
b := bufpool.Get(l.md.readBufferSize)
n, raddr, dstAddr, err := tproxy.ReadFromUDP(l.ln, *b)
if err != nil {
l.logger.Error(err)
return
}
l.logger.Infof("%s >> %s", raddr.String(), dstAddr.String())
c, err := tproxy.DialUDP("udp", dstAddr, raddr)
if err != nil {
l.logger.Error(err)
return
}
conn = &redirConn{
Conn: c,
buf: (*b)[:n],
ttl: l.md.ttl,
}
return
}

View File

@ -0,0 +1,16 @@
//go:build !linux
package udp
import (
"errors"
"net"
)
func (l *redirectListener) listenUDP(addr *net.UDPAddr) (*net.UDPConn, error) {
return nil, errors.New("UDP redirect is not available on non-linux platform")
}
func (l *redirectListener) accept() (conn net.Conn, err error) {
return nil, errors.New("UDP redirect is not available on non-linux platform")
}

View File

@ -0,0 +1,36 @@
package udp
import (
"time"
mdata "github.com/go-gost/gost/v3/pkg/metadata"
)
const (
defaultTTL = 60 * time.Second
defaultReadBufferSize = 1024
)
type metadata struct {
ttl time.Duration
readBufferSize int
}
func (l *redirectListener) parseMetadata(md mdata.Metadata) (err error) {
const (
ttl = "ttl"
readBufferSize = "readBufferSize"
)
l.md.ttl = mdata.GetDuration(md, ttl)
if l.md.ttl <= 0 {
l.md.ttl = defaultTTL
}
l.md.readBufferSize = mdata.GetInt(md, readBufferSize)
if l.md.readBufferSize <= 0 {
l.md.readBufferSize = defaultReadBufferSize
}
return
}

148
listener/ssh/listener.go Normal file
View File

@ -0,0 +1,148 @@
package ssh
import (
"fmt"
"net"
"time"
"github.com/go-gost/gost/v3/pkg/common/admission"
"github.com/go-gost/gost/v3/pkg/common/metrics"
"github.com/go-gost/gost/v3/pkg/listener"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
ssh_util "github.com/go-gost/x/internal/util/ssh"
"golang.org/x/crypto/ssh"
)
func init() {
registry.ListenerRegistry().Register("ssh", NewListener)
}
type sshListener struct {
net.Listener
config *ssh.ServerConfig
cqueue chan net.Conn
errChan chan error
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &sshListener{
logger: options.Logger,
options: options,
}
}
func (l *sshListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
}
ln, err := net.Listen("tcp", l.options.Addr)
if err != nil {
return err
}
ln = metrics.WrapListener(l.options.Service, ln)
ln = admission.WrapListener(l.options.Admission, ln)
l.Listener = ln
config := &ssh.ServerConfig{
PasswordCallback: ssh_util.PasswordCallback(l.options.Auther),
PublicKeyCallback: ssh_util.PublicKeyCallback(l.md.authorizedKeys),
}
config.AddHostKey(l.md.signer)
if l.options.Auther == nil && len(l.md.authorizedKeys) == 0 {
config.NoClientAuth = true
}
l.config = config
l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1)
go l.listenLoop()
return
}
func (l *sshListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.cqueue:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *sshListener) listenLoop() {
for {
conn, err := l.Listener.Accept()
if err != nil {
l.logger.Error("accept:", err)
l.errChan <- err
close(l.errChan)
return
}
go l.serveConn(conn)
}
}
func (l *sshListener) serveConn(conn net.Conn) {
start := time.Now()
l.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
l.logger.WithFields(map[string]any{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
sc, chans, reqs, err := ssh.NewServerConn(conn, l.config)
if err != nil {
l.logger.Error(err)
conn.Close()
return
}
defer sc.Close()
go ssh.DiscardRequests(reqs)
go func() {
for newChannel := range chans {
// Check the type of channel
t := newChannel.ChannelType()
switch t {
case ssh_util.GostSSHTunnelRequest:
channel, requests, err := newChannel.Accept()
if err != nil {
l.logger.Warnf("could not accept channel: %s", err.Error())
continue
}
go ssh.DiscardRequests(requests)
cc := ssh_util.NewConn(conn, channel)
select {
case l.cqueue <- cc:
default:
l.logger.Warnf("connection queue is full, client %s discarded", conn.RemoteAddr())
newChannel.Reject(ssh.ResourceShortage, "connection queue is full")
cc.Close()
}
default:
l.logger.Warnf("unsupported channel type: %s", t)
newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unsupported channel type: %s", t))
}
}
}()
sc.Wait()
}

68
listener/ssh/metadata.go Normal file
View File

@ -0,0 +1,68 @@
package ssh
import (
"io/ioutil"
tls_util "github.com/go-gost/gost/v3/pkg/common/util/tls"
mdata "github.com/go-gost/gost/v3/pkg/metadata"
ssh_util "github.com/go-gost/x/internal/util/ssh"
"golang.org/x/crypto/ssh"
)
const (
defaultBacklog = 128
)
type metadata struct {
signer ssh.Signer
authorizedKeys map[string]bool
backlog int
}
func (l *sshListener) parseMetadata(md mdata.Metadata) (err error) {
const (
authorizedKeys = "authorizedKeys"
privateKeyFile = "privateKeyFile"
passphrase = "passphrase"
backlog = "backlog"
)
if key := mdata.GetString(md, privateKeyFile); key != "" {
data, err := ioutil.ReadFile(key)
if err != nil {
return err
}
pp := mdata.GetString(md, passphrase)
if pp == "" {
l.md.signer, err = ssh.ParsePrivateKey(data)
} else {
l.md.signer, err = ssh.ParsePrivateKeyWithPassphrase(data, []byte(pp))
}
if err != nil {
return err
}
}
if l.md.signer == nil {
signer, err := ssh.NewSignerFromKey(tls_util.DefaultConfig.Clone().Certificates[0].PrivateKey)
if err != nil {
return err
}
l.md.signer = signer
}
if name := mdata.GetString(md, authorizedKeys); name != "" {
m, err := ssh_util.ParseAuthorizedKeysFile(name)
if err != nil {
return err
}
l.md.authorizedKeys = m
}
l.md.backlog = mdata.GetInt(md, backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
return
}

199
listener/sshd/listener.go Normal file
View File

@ -0,0 +1,199 @@
package ssh
import (
"context"
"fmt"
"net"
"strconv"
"time"
"github.com/go-gost/gost/v3/pkg/common/admission"
"github.com/go-gost/gost/v3/pkg/common/metrics"
"github.com/go-gost/gost/v3/pkg/listener"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
ssh_util "github.com/go-gost/x/internal/util/ssh"
sshd_util "github.com/go-gost/x/internal/util/sshd"
"golang.org/x/crypto/ssh"
)
// Applicable SSH Request types for Port Forwarding - RFC 4254 7.X
const (
DirectForwardRequest = "direct-tcpip" // RFC 4254 7.2
RemoteForwardRequest = "tcpip-forward" // RFC 4254 7.1
)
func init() {
registry.ListenerRegistry().Register("sshd", NewListener)
}
type sshdListener struct {
net.Listener
config *ssh.ServerConfig
cqueue chan net.Conn
errChan chan error
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &sshdListener{
logger: options.Logger,
options: options,
}
}
func (l *sshdListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
}
ln, err := net.Listen("tcp", l.options.Addr)
if err != nil {
return err
}
ln = metrics.WrapListener(l.options.Service, ln)
ln = admission.WrapListener(l.options.Admission, ln)
l.Listener = ln
config := &ssh.ServerConfig{
PasswordCallback: ssh_util.PasswordCallback(l.options.Auther),
PublicKeyCallback: ssh_util.PublicKeyCallback(l.md.authorizedKeys),
}
config.AddHostKey(l.md.signer)
if l.options.Auther == nil && len(l.md.authorizedKeys) == 0 {
config.NoClientAuth = true
}
l.config = config
l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1)
go l.listenLoop()
return
}
func (l *sshdListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.cqueue:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *sshdListener) listenLoop() {
for {
conn, err := l.Listener.Accept()
if err != nil {
l.logger.Error("accept:", err)
l.errChan <- err
close(l.errChan)
return
}
go l.serveConn(conn)
}
}
func (l *sshdListener) serveConn(conn net.Conn) {
start := time.Now()
l.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
l.logger.WithFields(map[string]any{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
sc, chans, reqs, err := ssh.NewServerConn(conn, l.config)
if err != nil {
l.logger.Error(err)
conn.Close()
return
}
defer sc.Close()
go func() {
for newChannel := range chans {
// Check the type of channel
t := newChannel.ChannelType()
switch t {
case DirectForwardRequest:
channel, requests, err := newChannel.Accept()
if err != nil {
l.logger.Warnf("could not accept channel: %s", err.Error())
continue
}
p := directForward{}
ssh.Unmarshal(newChannel.ExtraData(), &p)
l.logger.Debug(p.String())
if p.Host1 == "<nil>" {
p.Host1 = ""
}
go ssh.DiscardRequests(requests)
cc := sshd_util.NewDirectForwardConn(sc, channel, net.JoinHostPort(p.Host1, strconv.Itoa(int(p.Port1))))
select {
case l.cqueue <- cc:
default:
l.logger.Warnf("connection queue is full, client %s discarded", conn.RemoteAddr())
newChannel.Reject(ssh.ResourceShortage, "connection queue is full")
cc.Close()
}
default:
l.logger.Warnf("unsupported channel type: %s", t)
newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unsupported channel type: %s", t))
}
}
}()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
for req := range reqs {
switch req.Type {
case RemoteForwardRequest:
cc := sshd_util.NewRemoteForwardConn(ctx, sc, req)
select {
case l.cqueue <- cc:
default:
l.logger.Warnf("connection queue is full, client %s discarded", conn.RemoteAddr())
req.Reply(false, []byte("connection queue is full"))
cc.Close()
}
default:
l.logger.Warnf("unsupported request type: %s, want reply: %v", req.Type, req.WantReply)
req.Reply(false, nil)
}
}
}()
sc.Wait()
}
// directForward is structure for RFC 4254 7.2 - can be used for "forwarded-tcpip" and "direct-tcpip"
type directForward struct {
Host1 string
Port1 uint32
Host2 string
Port2 uint32
}
func (p directForward) String() string {
return fmt.Sprintf("%s:%d -> %s:%d", p.Host2, p.Port2, p.Host1, p.Port1)
}

68
listener/sshd/metadata.go Normal file
View File

@ -0,0 +1,68 @@
package ssh
import (
"io/ioutil"
tls_util "github.com/go-gost/gost/v3/pkg/common/util/tls"
mdata "github.com/go-gost/gost/v3/pkg/metadata"
ssh_util "github.com/go-gost/x/internal/util/ssh"
"golang.org/x/crypto/ssh"
)
const (
defaultBacklog = 128
)
type metadata struct {
signer ssh.Signer
authorizedKeys map[string]bool
backlog int
}
func (l *sshdListener) parseMetadata(md mdata.Metadata) (err error) {
const (
authorizedKeys = "authorizedKeys"
privateKeyFile = "privateKeyFile"
passphrase = "passphrase"
backlog = "backlog"
)
if key := mdata.GetString(md, privateKeyFile); key != "" {
data, err := ioutil.ReadFile(key)
if err != nil {
return err
}
pp := mdata.GetString(md, passphrase)
if pp == "" {
l.md.signer, err = ssh.ParsePrivateKey(data)
} else {
l.md.signer, err = ssh.ParsePrivateKeyWithPassphrase(data, []byte(pp))
}
if err != nil {
return err
}
}
if l.md.signer == nil {
signer, err := ssh.NewSignerFromKey(tls_util.DefaultConfig.Clone().Certificates[0].PrivateKey)
if err != nil {
return err
}
l.md.signer = signer
}
if name := mdata.GetString(md, authorizedKeys); name != "" {
m, err := ssh_util.ParseAuthorizedKeysFile(name)
if err != nil {
return err
}
l.md.authorizedKeys = m
}
l.md.backlog = mdata.GetInt(md, backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
return
}

96
listener/tap/listener.go Normal file
View File

@ -0,0 +1,96 @@
package tap
import (
"net"
"github.com/go-gost/gost/v3/pkg/listener"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
tap_util "github.com/go-gost/x/internal/util/tap"
)
func init() {
registry.ListenerRegistry().Register("tap", NewListener)
}
type tapListener struct {
saddr string
addr net.Addr
cqueue chan net.Conn
closed chan struct{}
logger logger.Logger
md metadata
}
func NewListener(opts ...listener.Option) listener.Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &tapListener{
saddr: options.Addr,
logger: options.Logger,
}
}
func (l *tapListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
}
l.addr, err = net.ResolveUDPAddr("udp", l.saddr)
if err != nil {
return
}
ifce, ip, err := l.createTap()
if err != nil {
if ifce != nil {
ifce.Close()
}
return
}
itf, err := net.InterfaceByName(ifce.Name())
if err != nil {
return
}
addrs, _ := itf.Addrs()
l.logger.Infof("name: %s, mac: %s, mtu: %d, addrs: %s",
itf.Name, itf.HardwareAddr, itf.MTU, addrs)
l.cqueue = make(chan net.Conn, 1)
l.closed = make(chan struct{})
conn := tap_util.NewConn(l.md.config, ifce, l.addr, &net.IPAddr{IP: ip})
l.cqueue <- conn
return
}
func (l *tapListener) Accept() (net.Conn, error) {
select {
case conn := <-l.cqueue:
return conn, nil
case <-l.closed:
}
return nil, listener.ErrClosed
}
func (l *tapListener) Addr() net.Addr {
return l.addr
}
func (l *tapListener) Close() error {
select {
case <-l.closed:
return net.ErrClosed
default:
close(l.closed)
}
return nil
}

44
listener/tap/metadata.go Normal file
View File

@ -0,0 +1,44 @@
package tap
import (
mdata "github.com/go-gost/gost/v3/pkg/metadata"
tap_util "github.com/go-gost/x/internal/util/tap"
)
const (
DefaultMTU = 1350
)
type metadata struct {
config *tap_util.Config
}
func (l *tapListener) parseMetadata(md mdata.Metadata) (err error) {
const (
name = "name"
netKey = "net"
mtu = "mtu"
routes = "routes"
gateway = "gw"
)
config := &tap_util.Config{
Name: mdata.GetString(md, name),
Net: mdata.GetString(md, netKey),
MTU: mdata.GetInt(md, mtu),
Gateway: mdata.GetString(md, gateway),
}
if config.MTU <= 0 {
config.MTU = DefaultMTU
}
for _, s := range mdata.GetStrings(md, routes) {
if s != "" {
config.Routes = append(config.Routes, s)
}
}
l.md.config = config
return
}

View File

@ -0,0 +1,13 @@
package tap
import (
"errors"
"net"
"github.com/songgao/water"
)
func (l *tapListener) createTap() (ifce *water.Interface, ip net.IP, err error) {
err = errors.New("tap is not supported on darwin")
return
}

69
listener/tap/tap_linux.go Normal file
View File

@ -0,0 +1,69 @@
package tap
import (
"net"
"github.com/docker/libcontainer/netlink"
"github.com/milosgajdos/tenus"
"github.com/songgao/water"
)
func (l *tapListener) createTap() (ifce *water.Interface, ip net.IP, err error) {
var ipNet *net.IPNet
if l.md.config.Net != "" {
ip, ipNet, err = net.ParseCIDR(l.md.config.Net)
if err != nil {
return
}
}
ifce, err = water.New(water.Config{
DeviceType: water.TAP,
PlatformSpecificParams: water.PlatformSpecificParams{
Name: l.md.config.Name,
},
})
if err != nil {
return
}
link, err := tenus.NewLinkFrom(ifce.Name())
if err != nil {
return
}
l.logger.Debugf("ip link set dev %s mtu %d", ifce.Name(), l.md.config.MTU)
if err = link.SetLinkMTU(l.md.config.MTU); err != nil {
return
}
if l.md.config.Net != "" {
l.logger.Debugf("ip address add %s dev %s", l.md.config.Net, ifce.Name())
if err = link.SetLinkIp(ip, ipNet); err != nil {
return
}
}
l.logger.Debugf("ip link set dev %s up", ifce.Name())
if err = link.SetLinkUp(); err != nil {
return
}
if err = l.addRoutes(ifce.Name(), l.md.config.Gateway, l.md.config.Routes...); err != nil {
return
}
return
}
func (l *tapListener) addRoutes(ifName string, gw string, routes ...string) error {
for _, route := range routes {
l.logger.Debugf("ip route add %s via %s dev %s", route, gw, ifName)
if err := netlink.AddRoute(route, "", gw, ifName); err != nil {
return err
}
}
return nil
}

61
listener/tap/tap_unix.go Normal file
View File

@ -0,0 +1,61 @@
//go:build !linux && !windows && !darwin
package tap
import (
"fmt"
"net"
"os/exec"
"strings"
"github.com/songgao/water"
)
func (l *tapListener) createTap() (ifce *water.Interface, ip net.IP, err error) {
ip, _, _ = net.ParseCIDR(l.md.config.Net)
ifce, err = water.New(water.Config{
DeviceType: water.TAP,
})
if err != nil {
return
}
var cmd string
if l.md.config.Net != "" {
cmd = fmt.Sprintf("ifconfig %s inet %s mtu %d up", ifce.Name(), l.md.config.Net, l.md.config.MTU)
} else {
cmd = fmt.Sprintf("ifconfig %s mtu %d up", ifce.Name(), l.md.config.MTU)
}
l.logger.Debug(cmd)
args := strings.Split(cmd, " ")
if er := exec.Command(args[0], args[1:]...).Run(); er != nil {
err = fmt.Errorf("%s: %v", cmd, er)
return
}
if err = l.addRoutes(ifce.Name(), l.md.config.Gateway, l.md.config.Routes...); err != nil {
return
}
return
}
func (l *tapListener) addRoutes(ifName string, gw string, routes ...string) error {
for _, route := range routes {
if route == "" {
continue
}
cmd := fmt.Sprintf("route add -net %s dev %s", route, ifName)
if gw != "" {
cmd += " gw " + gw
}
l.logger.Debug(cmd)
args := strings.Split(cmd, " ")
if er := exec.Command(args[0], args[1:]...).Run(); er != nil {
return fmt.Errorf("%s: %v", cmd, er)
}
}
return nil
}

View File

@ -0,0 +1,75 @@
package tap
import (
"fmt"
"net"
"os/exec"
"strings"
"github.com/songgao/water"
)
func (l *tapListener) createTap() (ifce *water.Interface, ip net.IP, err error) {
ip, ipNet, _ := net.ParseCIDR(l.md.config.Net)
ifce, err = water.New(water.Config{
DeviceType: water.TAP,
PlatformSpecificParams: water.PlatformSpecificParams{
ComponentID: "tap0901",
InterfaceName: l.md.config.Name,
Network: l.md.config.Net,
},
})
if err != nil {
return
}
if ip != nil && ipNet != nil {
cmd := fmt.Sprintf("netsh interface ip set address name=%s "+
"source=static addr=%s mask=%s gateway=none",
ifce.Name(), ip.String(), ipMask(ipNet.Mask))
l.logger.Debug(cmd)
args := strings.Split(cmd, " ")
if er := exec.Command(args[0], args[1:]...).Run(); er != nil {
err = fmt.Errorf("%s: %v", cmd, er)
return
}
}
if err = l.addRoutes(ifce.Name(), l.md.config.Gateway, l.md.config.Routes...); err != nil {
return
}
return
}
func (l *tapListener) addRoutes(ifName string, gw string, routes ...string) error {
for _, route := range routes {
l.deleteRoute(ifName, route)
cmd := fmt.Sprintf("netsh interface ip add route prefix=%s interface=%s store=active",
route, ifName)
if gw != "" {
cmd += " nexthop=" + gw
}
l.logger.Debug(cmd)
args := strings.Split(cmd, " ")
if er := exec.Command(args[0], args[1:]...).Run(); er != nil {
return fmt.Errorf("%s: %v", cmd, er)
}
}
return nil
}
func (l *tapListener) deleteRoute(ifName string, route string) error {
cmd := fmt.Sprintf("netsh interface ip delete route prefix=%s interface=%s store=active",
route, ifName)
l.logger.Debug(cmd)
args := strings.Split(cmd, " ")
return exec.Command(args[0], args[1:]...).Run()
}
func ipMask(mask net.IPMask) string {
return fmt.Sprintf("%d.%d.%d.%d", mask[0], mask[1], mask[2], mask[3])
}

96
listener/tun/listener.go Normal file
View File

@ -0,0 +1,96 @@
package tun
import (
"net"
"github.com/go-gost/gost/v3/pkg/listener"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
tun_util "github.com/go-gost/x/internal/util/tun"
)
func init() {
registry.ListenerRegistry().Register("tun", NewListener)
}
type tunListener struct {
addr net.Addr
cqueue chan net.Conn
closed chan struct{}
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &tunListener{
logger: options.Logger,
options: options,
}
}
func (l *tunListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
}
l.addr, err = net.ResolveUDPAddr("udp", l.options.Addr)
if err != nil {
return
}
ifce, ip, err := l.createTun()
if err != nil {
if ifce != nil {
ifce.Close()
}
return
}
itf, err := net.InterfaceByName(ifce.Name())
if err != nil {
return
}
addrs, _ := itf.Addrs()
l.logger.Infof("name: %s, net: %s, mtu: %d, addrs: %s",
itf.Name, ip, itf.MTU, addrs)
l.cqueue = make(chan net.Conn, 1)
l.closed = make(chan struct{})
conn := tun_util.NewConn(l.md.config, ifce, l.addr, &net.IPAddr{IP: ip})
l.cqueue <- conn
return
}
func (l *tunListener) Accept() (net.Conn, error) {
select {
case conn := <-l.cqueue:
return conn, nil
case <-l.closed:
}
return nil, listener.ErrClosed
}
func (l *tunListener) Addr() net.Addr {
return l.addr
}
func (l *tunListener) Close() error {
select {
case <-l.closed:
return net.ErrClosed
default:
close(l.closed)
}
return nil
}

63
listener/tun/metadata.go Normal file
View File

@ -0,0 +1,63 @@
package tun
import (
"net"
"strings"
mdata "github.com/go-gost/gost/v3/pkg/metadata"
tun_util "github.com/go-gost/x/internal/util/tun"
)
const (
DefaultMTU = 1350
)
type metadata struct {
config *tun_util.Config
}
func (l *tunListener) parseMetadata(md mdata.Metadata) (err error) {
const (
name = "name"
netKey = "net"
peer = "peer"
mtu = "mtu"
routes = "routes"
gateway = "gw"
)
config := &tun_util.Config{
Name: mdata.GetString(md, name),
Net: mdata.GetString(md, netKey),
Peer: mdata.GetString(md, peer),
MTU: mdata.GetInt(md, mtu),
Gateway: mdata.GetString(md, gateway),
}
if config.MTU <= 0 {
config.MTU = DefaultMTU
}
gw := net.ParseIP(config.Gateway)
for _, s := range mdata.GetStrings(md, routes) {
ss := strings.SplitN(s, " ", 2)
if len(ss) == 2 {
var route tun_util.Route
_, ipNet, _ := net.ParseCIDR(strings.TrimSpace(ss[0]))
if ipNet == nil {
continue
}
route.Net = *ipNet
route.Gateway = net.ParseIP(ss[1])
if route.Gateway == nil {
route.Gateway = gw
}
config.Routes = append(config.Routes, route)
}
}
l.md.config = config
return
}

View File

@ -0,0 +1,57 @@
package tun
import (
"fmt"
"net"
"os/exec"
"strings"
tun_util "github.com/go-gost/x/internal/util/tun"
"github.com/songgao/water"
)
func (l *tunListener) createTun() (ifce *water.Interface, ip net.IP, err error) {
ip, _, err = net.ParseCIDR(l.md.config.Net)
if err != nil {
return
}
ifce, err = water.New(water.Config{
DeviceType: water.TUN,
})
if err != nil {
return
}
peer := l.md.config.Peer
if peer == "" {
peer = ip.String()
}
cmd := fmt.Sprintf("ifconfig %s inet %s %s mtu %d up",
ifce.Name(), l.md.config.Net, l.md.config.Peer, l.md.config.MTU)
l.logger.Debug(cmd)
args := strings.Split(cmd, " ")
if err = exec.Command(args[0], args[1:]...).Run(); err != nil {
return
}
if err = l.addRoutes(ifce.Name(), l.md.config.Routes...); err != nil {
return
}
return
}
func (l *tunListener) addRoutes(ifName string, routes ...tun_util.Route) error {
for _, route := range routes {
cmd := fmt.Sprintf("route add -net %s -interface %s", route.Net.String(), ifName)
l.logger.Debug(cmd)
args := strings.Split(cmd, " ")
if err := exec.Command(args[0], args[1:]...).Run(); err != nil {
return err
}
}
return nil
}

67
listener/tun/tun_linux.go Normal file
View File

@ -0,0 +1,67 @@
package tun
import (
"errors"
"net"
"syscall"
"github.com/docker/libcontainer/netlink"
tun_util "github.com/go-gost/x/internal/util/tun"
"github.com/milosgajdos/tenus"
"github.com/songgao/water"
)
func (l *tunListener) createTun() (ifce *water.Interface, ip net.IP, err error) {
ip, ipNet, err := net.ParseCIDR(l.md.config.Net)
if err != nil {
return
}
ifce, err = water.New(water.Config{
DeviceType: water.TUN,
PlatformSpecificParams: water.PlatformSpecificParams{
Name: l.md.config.Name,
},
})
if err != nil {
return
}
link, err := tenus.NewLinkFrom(ifce.Name())
if err != nil {
return
}
l.logger.Debugf("ip link set dev %s mtu %d", ifce.Name(), l.md.config.MTU)
if err = link.SetLinkMTU(l.md.config.MTU); err != nil {
return
}
l.logger.Debugf("ip address add %s dev %s", l.md.config.Net, ifce.Name())
if err = link.SetLinkIp(ip, ipNet); err != nil {
return
}
l.logger.Debugf("ip link set dev %s up", ifce.Name())
if err = link.SetLinkUp(); err != nil {
return
}
if err = l.addRoutes(ifce.Name(), l.md.config.Routes...); err != nil {
return
}
return
}
func (l *tunListener) addRoutes(ifName string, routes ...tun_util.Route) error {
for _, route := range routes {
l.logger.Debugf("ip route add %s dev %s", route.Net.String(), ifName)
if err := netlink.AddRoute(route.Net.String(), "", "", ifName); err != nil && !errors.Is(err, syscall.EEXIST) {
return err
}
}
return nil
}

55
listener/tun/tun_unix.go Normal file
View File

@ -0,0 +1,55 @@
//go:build !linux && !windows && !darwin
package tun
import (
"fmt"
"net"
"os/exec"
"strings"
tun_util "github.com/go-gost/x/internal/util/tun"
"github.com/songgao/water"
)
func (l *tunListener) createTun() (ifce *water.Interface, ip net.IP, err error) {
ip, _, err = net.ParseCIDR(l.md.config.Net)
if err != nil {
return
}
ifce, err = water.New(water.Config{
DeviceType: water.TUN,
})
if err != nil {
return
}
cmd := fmt.Sprintf("ifconfig %s inet %s mtu %d up",
ifce.Name(), l.md.config.Net, l.md.config.MTU)
l.logger.Debug(cmd)
args := strings.Split(cmd, " ")
if er := exec.Command(args[0], args[1:]...).Run(); er != nil {
err = fmt.Errorf("%s: %v", cmd, er)
return
}
if err = l.addRoutes(ifce.Name(), l.md.config.Routes...); err != nil {
return
}
return
}
func (l *tunListener) addRoutes(ifName string, routes ...tun_util.Route) error {
for _, route := range routes {
cmd := fmt.Sprintf("route add -net %s -interface %s", route.Net.String(), ifName)
l.logger.Debug(cmd)
args := strings.Split(cmd, " ")
if er := exec.Command(args[0], args[1:]...).Run(); er != nil {
return fmt.Errorf("%s: %v", cmd, er)
}
}
return nil
}

View File

@ -0,0 +1,77 @@
package tun
import (
"fmt"
"net"
"os/exec"
"strings"
tun_util "github.com/go-gost/x/internal/util/tun"
"github.com/songgao/water"
)
func (l *tunListener) createTun() (ifce *water.Interface, ip net.IP, err error) {
ip, ipNet, err := net.ParseCIDR(l.md.config.Net)
if err != nil {
return
}
ifce, err = water.New(water.Config{
DeviceType: water.TUN,
PlatformSpecificParams: water.PlatformSpecificParams{
ComponentID: "tap0901",
InterfaceName: l.md.config.Name,
Network: l.md.config.Net,
},
})
if err != nil {
return
}
cmd := fmt.Sprintf("netsh interface ip set address name=%s "+
"source=static addr=%s mask=%s gateway=none",
ifce.Name(), ip.String(), ipMask(ipNet.Mask))
l.logger.Debug(cmd)
args := strings.Split(cmd, " ")
if er := exec.Command(args[0], args[1:]...).Run(); er != nil {
err = fmt.Errorf("%s: %v", cmd, er)
return
}
if err = l.addRoutes(ifce.Name(), l.md.config.Gateway, l.md.config.Routes...); err != nil {
return
}
return
}
func (l *tunListener) addRoutes(ifName string, gw string, routes ...tun_util.Route) error {
for _, route := range routes {
l.deleteRoute(ifName, route.Net.String())
cmd := fmt.Sprintf("netsh interface ip add route prefix=%s interface=%s store=active",
route.Net.String(), ifName)
if gw != "" {
cmd += " nexthop=" + gw
}
l.logger.Debug(cmd)
args := strings.Split(cmd, " ")
if er := exec.Command(args[0], args[1:]...).Run(); er != nil {
return fmt.Errorf("%s: %v", cmd, er)
}
}
return nil
}
func (l *tunListener) deleteRoute(ifName string, route string) error {
cmd := fmt.Sprintf("netsh interface ip delete route prefix=%s interface=%s store=active",
route, ifName)
l.logger.Debug(cmd)
args := strings.Split(cmd, " ")
return exec.Command(args[0], args[1:]...).Run()
}
func ipMask(mask net.IPMask) string {
return fmt.Sprintf("%d.%d.%d.%d", mask[0], mask[1], mask[2], mask[3])
}

149
listener/ws/listener.go Normal file
View File

@ -0,0 +1,149 @@
package ws
import (
"crypto/tls"
"net"
"net/http"
"net/http/httputil"
"github.com/go-gost/gost/v3/pkg/common/admission"
"github.com/go-gost/gost/v3/pkg/common/metrics"
"github.com/go-gost/gost/v3/pkg/listener"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
ws_util "github.com/go-gost/x/internal/util/ws"
"github.com/gorilla/websocket"
)
func init() {
registry.ListenerRegistry().Register("ws", NewListener)
registry.ListenerRegistry().Register("wss", NewTLSListener)
}
type wsListener struct {
addr net.Addr
upgrader *websocket.Upgrader
srv *http.Server
tlsEnabled bool
cqueue chan net.Conn
errChan chan error
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &wsListener{
logger: options.Logger,
options: options,
}
}
func NewTLSListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &wsListener{
tlsEnabled: true,
logger: options.Logger,
options: options,
}
}
func (l *wsListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
}
l.upgrader = &websocket.Upgrader{
HandshakeTimeout: l.md.handshakeTimeout,
ReadBufferSize: l.md.readBufferSize,
WriteBufferSize: l.md.writeBufferSize,
EnableCompression: l.md.enableCompression,
CheckOrigin: func(r *http.Request) bool { return true },
}
mux := http.NewServeMux()
mux.Handle(l.md.path, http.HandlerFunc(l.upgrade))
l.srv = &http.Server{
Addr: l.options.Addr,
Handler: mux,
ReadHeaderTimeout: l.md.readHeaderTimeout,
}
l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1)
ln, err := net.Listen("tcp", l.options.Addr)
if err != nil {
return
}
ln = metrics.WrapListener(l.options.Service, ln)
ln = admission.WrapListener(l.options.Admission, ln)
if l.tlsEnabled {
ln = tls.NewListener(ln, l.options.TLSConfig)
}
l.addr = ln.Addr()
go func() {
err := l.srv.Serve(ln)
if err != nil {
l.errChan <- err
}
close(l.errChan)
}()
return
}
func (l *wsListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.cqueue:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *wsListener) Close() error {
return l.srv.Close()
}
func (l *wsListener) Addr() net.Addr {
return l.addr
}
func (l *wsListener) upgrade(w http.ResponseWriter, r *http.Request) {
if l.logger.IsLevelEnabled(logger.DebugLevel) {
log := l.logger.WithFields(map[string]any{
"local": l.addr.String(),
"remote": r.RemoteAddr,
})
dump, _ := httputil.DumpRequest(r, false)
log.Debug(string(dump))
}
conn, err := l.upgrader.Upgrade(w, r, l.md.header)
if err != nil {
l.logger.Error(err)
return
}
select {
case l.cqueue <- ws_util.Conn(conn):
default:
conn.Close()
l.logger.Warnf("connection queue is full, client %s discarded", conn.RemoteAddr())
}
}

66
listener/ws/metadata.go Normal file
View File

@ -0,0 +1,66 @@
package ws
import (
"net/http"
"time"
mdata "github.com/go-gost/gost/v3/pkg/metadata"
)
const (
defaultPath = "/ws"
defaultBacklog = 128
)
type metadata struct {
path string
backlog int
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
readBufferSize int
writeBufferSize int
enableCompression bool
header http.Header
}
func (l *wsListener) parseMetadata(md mdata.Metadata) (err error) {
const (
path = "path"
backlog = "backlog"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
enableCompression = "enableCompression"
header = "header"
)
l.md.path = mdata.GetString(md, path)
if l.md.path == "" {
l.md.path = defaultPath
}
l.md.backlog = mdata.GetInt(md, backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
l.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout)
l.md.readHeaderTimeout = mdata.GetDuration(md, readHeaderTimeout)
l.md.readBufferSize = mdata.GetInt(md, readBufferSize)
l.md.writeBufferSize = mdata.GetInt(md, writeBufferSize)
l.md.enableCompression = mdata.GetBool(md, enableCompression)
if mm := mdata.GetStringMapString(md, header); len(mm) > 0 {
hd := http.Header{}
for k, v := range mm {
hd.Add(k, v)
}
l.md.header = hd
}
return
}