initial commit
This commit is contained in:
210
listener/dns/listener.go
Normal file
210
listener/dns/listener.go
Normal file
@ -0,0 +1,210 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/common/metrics"
|
||||
"github.com/go-gost/gost/v3/pkg/listener"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.ListenerRegistry().Register("dns", NewListener)
|
||||
}
|
||||
|
||||
type dnsListener struct {
|
||||
addr net.Addr
|
||||
server Server
|
||||
cqueue chan net.Conn
|
||||
errChan chan error
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options listener.Options
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return &dnsListener{
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *dnsListener) Init(md md.Metadata) (err error) {
|
||||
if err = l.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.addr, err = net.ResolveTCPAddr("tcp", l.options.Addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch strings.ToLower(l.md.mode) {
|
||||
case "tcp":
|
||||
l.server = &dns.Server{
|
||||
Net: "tcp",
|
||||
Addr: l.options.Addr,
|
||||
Handler: l,
|
||||
ReadTimeout: l.md.readTimeout,
|
||||
WriteTimeout: l.md.writeTimeout,
|
||||
}
|
||||
case "tls":
|
||||
l.server = &dns.Server{
|
||||
Net: "tcp-tls",
|
||||
Addr: l.options.Addr,
|
||||
Handler: l,
|
||||
TLSConfig: l.options.TLSConfig,
|
||||
ReadTimeout: l.md.readTimeout,
|
||||
WriteTimeout: l.md.writeTimeout,
|
||||
}
|
||||
case "https":
|
||||
l.server = &dohServer{
|
||||
addr: l.options.Addr,
|
||||
tlsConfig: l.options.TLSConfig,
|
||||
server: &http.Server{
|
||||
Handler: l,
|
||||
ReadTimeout: l.md.readTimeout,
|
||||
WriteTimeout: l.md.writeTimeout,
|
||||
},
|
||||
}
|
||||
default:
|
||||
l.addr, err = net.ResolveUDPAddr("udp", l.options.Addr)
|
||||
l.server = &dns.Server{
|
||||
Net: "udp",
|
||||
Addr: l.options.Addr,
|
||||
Handler: l,
|
||||
UDPSize: l.md.readBufferSize,
|
||||
ReadTimeout: l.md.readTimeout,
|
||||
WriteTimeout: l.md.writeTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.cqueue = make(chan net.Conn, l.md.backlog)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
err := l.server.ListenAndServe()
|
||||
if err != nil {
|
||||
l.errChan <- err
|
||||
}
|
||||
close(l.errChan)
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
func (l *dnsListener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.cqueue:
|
||||
conn = metrics.WrapConn(l.options.Service, conn)
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *dnsListener) Close() error {
|
||||
return l.server.Shutdown()
|
||||
}
|
||||
|
||||
func (l *dnsListener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
||||
|
||||
func (l *dnsListener) ServeDNS(w dns.ResponseWriter, m *dns.Msg) {
|
||||
b, err := m.Pack()
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
return
|
||||
}
|
||||
if err := l.serve(w, b); err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Based on https://github.com/semihalev/sdns
|
||||
func (l *dnsListener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
var buf []byte
|
||||
var err error
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
buf, err = base64.RawURLEncoding.DecodeString(r.URL.Query().Get("dns"))
|
||||
if len(buf) == 0 || err != nil {
|
||||
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
case http.MethodPost:
|
||||
if ct := r.Header.Get("Content-Type"); ct != "application/dns-message" {
|
||||
l.logger.Errorf("unsupported media type: %s", ct)
|
||||
http.Error(w, http.StatusText(http.StatusUnsupportedMediaType), http.StatusUnsupportedMediaType)
|
||||
return
|
||||
}
|
||||
|
||||
buf, err = ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
default:
|
||||
l.logger.Errorf("method not allowd: %s", r.Method)
|
||||
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
mq := &dns.Msg{}
|
||||
if err := mq.Unpack(buf); err != nil {
|
||||
l.logger.Error(err)
|
||||
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Server", "SDNS")
|
||||
w.Header().Set("Content-Type", "application/dns-message")
|
||||
|
||||
raddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
|
||||
if raddr == nil {
|
||||
raddr = &net.TCPAddr{}
|
||||
}
|
||||
if err := l.serve(&dohResponseWriter{raddr: raddr, ResponseWriter: w}, buf); err != nil {
|
||||
l.logger.Error(err)
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *dnsListener) serve(w ResponseWriter, msg []byte) (err error) {
|
||||
conn := &serverConn{
|
||||
r: bytes.NewReader(msg),
|
||||
w: w,
|
||||
laddr: l.addr,
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
|
||||
select {
|
||||
case l.cqueue <- conn:
|
||||
default:
|
||||
l.logger.Warnf("connection queue is full, client %s discarded", w.RemoteAddr())
|
||||
return errors.New("connection queue is full")
|
||||
}
|
||||
|
||||
return conn.Wait()
|
||||
}
|
41
listener/dns/metadata.go
Normal file
41
listener/dns/metadata.go
Normal file
@ -0,0 +1,41 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultBacklog = 128
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
mode string
|
||||
readBufferSize int
|
||||
readTimeout time.Duration
|
||||
writeTimeout time.Duration
|
||||
backlog int
|
||||
}
|
||||
|
||||
func (l *dnsListener) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
backlog = "backlog"
|
||||
mode = "mode"
|
||||
readBufferSize = "readBufferSize"
|
||||
readTimeout = "readTimeout"
|
||||
writeTimeout = "writeTimeout"
|
||||
)
|
||||
|
||||
l.md.mode = mdata.GetString(md, mode)
|
||||
l.md.readBufferSize = mdata.GetInt(md, readBufferSize)
|
||||
l.md.readTimeout = mdata.GetDuration(md, readTimeout)
|
||||
l.md.writeTimeout = mdata.GetDuration(md, writeTimeout)
|
||||
|
||||
l.md.backlog = mdata.GetInt(md, backlog)
|
||||
if l.md.backlog <= 0 {
|
||||
l.md.backlog = defaultBacklog
|
||||
}
|
||||
|
||||
return
|
||||
}
|
110
listener/dns/server.go
Normal file
110
listener/dns/server.go
Normal file
@ -0,0 +1,110 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Server interface {
|
||||
ListenAndServe() error
|
||||
Shutdown() error
|
||||
}
|
||||
|
||||
type dohServer struct {
|
||||
addr string
|
||||
tlsConfig *tls.Config
|
||||
server *http.Server
|
||||
}
|
||||
|
||||
func (s *dohServer) ListenAndServe() error {
|
||||
ln, err := net.Listen("tcp", s.addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ln = tls.NewListener(ln, s.tlsConfig)
|
||||
return s.server.Serve(ln)
|
||||
}
|
||||
|
||||
func (s *dohServer) Shutdown() error {
|
||||
return s.server.Shutdown(context.Background())
|
||||
}
|
||||
|
||||
type ResponseWriter interface {
|
||||
io.Writer
|
||||
RemoteAddr() net.Addr
|
||||
}
|
||||
|
||||
type dohResponseWriter struct {
|
||||
raddr net.Addr
|
||||
http.ResponseWriter
|
||||
}
|
||||
|
||||
func (w *dohResponseWriter) RemoteAddr() net.Addr {
|
||||
return w.raddr
|
||||
}
|
||||
|
||||
type serverConn struct {
|
||||
r io.Reader
|
||||
w ResponseWriter
|
||||
laddr net.Addr
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func (c *serverConn) Read(b []byte) (n int, err error) {
|
||||
select {
|
||||
case <-c.closed:
|
||||
err = io.ErrClosedPipe
|
||||
return
|
||||
default:
|
||||
return c.r.Read(b)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *serverConn) Write(b []byte) (n int, err error) {
|
||||
select {
|
||||
case <-c.closed:
|
||||
err = io.ErrClosedPipe
|
||||
return
|
||||
default:
|
||||
return c.w.Write(b)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *serverConn) Close() error {
|
||||
select {
|
||||
case <-c.closed:
|
||||
default:
|
||||
close(c.closed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *serverConn) Wait() error {
|
||||
<-c.closed
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *serverConn) LocalAddr() net.Addr {
|
||||
return c.laddr
|
||||
}
|
||||
|
||||
func (c *serverConn) RemoteAddr() net.Addr {
|
||||
return c.w.RemoteAddr()
|
||||
}
|
||||
|
||||
func (c *serverConn) SetDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "dns", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *serverConn) SetReadDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "dns", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *serverConn) SetWriteDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "dns", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
124
listener/ftcp/conn.go
Normal file
124
listener/ftcp/conn.go
Normal file
@ -0,0 +1,124 @@
|
||||
package ftcp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// serverConn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
|
||||
type serverConn struct {
|
||||
pc net.PacketConn
|
||||
raddr net.Addr
|
||||
rc chan []byte // data receive queue
|
||||
fresh int32
|
||||
closed chan struct{}
|
||||
closeMutex sync.Mutex
|
||||
config *serverConnConfig
|
||||
}
|
||||
|
||||
type serverConnConfig struct {
|
||||
ttl time.Duration
|
||||
qsize int
|
||||
onClose func()
|
||||
}
|
||||
|
||||
func newServerConn(conn net.PacketConn, raddr net.Addr, cfg *serverConnConfig) *serverConn {
|
||||
if conn == nil || raddr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if cfg == nil {
|
||||
cfg = &serverConnConfig{}
|
||||
}
|
||||
c := &serverConn{
|
||||
pc: conn,
|
||||
raddr: raddr,
|
||||
rc: make(chan []byte, cfg.qsize),
|
||||
closed: make(chan struct{}),
|
||||
config: cfg,
|
||||
}
|
||||
go c.ttlWait()
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *serverConn) send(b []byte) error {
|
||||
select {
|
||||
case c.rc <- b:
|
||||
return nil
|
||||
default:
|
||||
return errors.New("queue is full")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *serverConn) Read(b []byte) (n int, err error) {
|
||||
select {
|
||||
case bb := <-c.rc:
|
||||
n = copy(b, bb)
|
||||
atomic.StoreInt32(&c.fresh, 1)
|
||||
case <-c.closed:
|
||||
err = errors.New("read from closed connection")
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *serverConn) Write(b []byte) (n int, err error) {
|
||||
return c.pc.WriteTo(b, c.raddr)
|
||||
}
|
||||
|
||||
func (c *serverConn) Close() error {
|
||||
c.closeMutex.Lock()
|
||||
defer c.closeMutex.Unlock()
|
||||
|
||||
select {
|
||||
case <-c.closed:
|
||||
return errors.New("connection is closed")
|
||||
default:
|
||||
if c.config.onClose != nil {
|
||||
c.config.onClose()
|
||||
}
|
||||
close(c.closed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *serverConn) LocalAddr() net.Addr {
|
||||
return c.pc.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *serverConn) RemoteAddr() net.Addr {
|
||||
return c.raddr
|
||||
}
|
||||
|
||||
func (c *serverConn) SetDeadline(t time.Time) error {
|
||||
return c.pc.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (c *serverConn) SetReadDeadline(t time.Time) error {
|
||||
return c.pc.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *serverConn) SetWriteDeadline(t time.Time) error {
|
||||
return c.pc.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (c *serverConn) ttlWait() {
|
||||
ticker := time.NewTicker(c.config.ttl)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if !atomic.CompareAndSwapInt32(&c.fresh, 1, 0) {
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
case <-c.closed:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
159
listener/ftcp/listener.go
Normal file
159
listener/ftcp/listener.go
Normal file
@ -0,0 +1,159 @@
|
||||
package ftcp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/common/metrics"
|
||||
"github.com/go-gost/gost/v3/pkg/listener"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
"github.com/xtaci/tcpraw"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.ListenerRegistry().Register("ftcp", NewListener)
|
||||
}
|
||||
|
||||
type ftcpListener struct {
|
||||
conn net.PacketConn
|
||||
connChan chan net.Conn
|
||||
errChan chan error
|
||||
connPool connPool
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options listener.Options
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return &ftcpListener{
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ftcpListener) Init(md md.Metadata) (err error) {
|
||||
if err = l.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.conn, err = tcpraw.Listen("tcp", l.options.Addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.connChan = make(chan net.Conn, l.md.connQueueSize)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
go l.listenLoop()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *ftcpListener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.connChan:
|
||||
conn = metrics.WrapConn(l.options.Service, conn)
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *ftcpListener) Close() error {
|
||||
err := l.conn.Close()
|
||||
l.connPool.Range(func(k any, v *serverConn) bool {
|
||||
v.Close()
|
||||
return true
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (l *ftcpListener) Addr() net.Addr {
|
||||
return l.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (l *ftcpListener) listenLoop() {
|
||||
for {
|
||||
b := make([]byte, l.md.readBufferSize)
|
||||
|
||||
n, raddr, err := l.conn.ReadFrom(b)
|
||||
if err != nil {
|
||||
l.logger.Error("accept:", err)
|
||||
l.errChan <- err
|
||||
close(l.errChan)
|
||||
return
|
||||
}
|
||||
|
||||
conn, ok := l.connPool.Get(raddr.String())
|
||||
if !ok {
|
||||
conn = newServerConn(l.conn, raddr,
|
||||
&serverConnConfig{
|
||||
ttl: l.md.ttl,
|
||||
qsize: l.md.readQueueSize,
|
||||
onClose: func() {
|
||||
l.connPool.Delete(raddr.String())
|
||||
},
|
||||
})
|
||||
|
||||
select {
|
||||
case l.connChan <- conn:
|
||||
l.connPool.Set(raddr.String(), conn)
|
||||
default:
|
||||
conn.Close()
|
||||
l.logger.Error("connection queue is full")
|
||||
}
|
||||
}
|
||||
|
||||
if err := conn.send(b[:n]); err != nil {
|
||||
l.logger.Warn("data discarded:", err)
|
||||
}
|
||||
l.logger.Debug("recv", n)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ftcpListener) parseMetadata(md md.Metadata) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
type connPool struct {
|
||||
size int64
|
||||
m sync.Map
|
||||
}
|
||||
|
||||
func (p *connPool) Get(key any) (conn *serverConn, ok bool) {
|
||||
v, ok := p.m.Load(key)
|
||||
if ok {
|
||||
conn, ok = v.(*serverConn)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (p *connPool) Set(key any, conn *serverConn) {
|
||||
p.m.Store(key, conn)
|
||||
atomic.AddInt64(&p.size, 1)
|
||||
}
|
||||
|
||||
func (p *connPool) Delete(key any) {
|
||||
p.m.Delete(key)
|
||||
atomic.AddInt64(&p.size, -1)
|
||||
}
|
||||
|
||||
func (p *connPool) Range(f func(key any, value *serverConn) bool) {
|
||||
p.m.Range(func(k, v any) bool {
|
||||
return f(k, v.(*serverConn))
|
||||
})
|
||||
}
|
||||
|
||||
func (p *connPool) Size() int64 {
|
||||
return atomic.LoadInt64(&p.size)
|
||||
}
|
22
listener/ftcp/metadata.go
Normal file
22
listener/ftcp/metadata.go
Normal file
@ -0,0 +1,22 @@
|
||||
package ftcp
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
defaultTTL = 60 * time.Second
|
||||
defaultReadBufferSize = 1024
|
||||
defaultReadQueueSize = 128
|
||||
defaultConnQueueSize = 128
|
||||
)
|
||||
|
||||
const (
|
||||
addr = "addr"
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
ttl time.Duration
|
||||
|
||||
readBufferSize int
|
||||
readQueueSize int
|
||||
connQueueSize int
|
||||
}
|
100
listener/grpc/listener.go
Normal file
100
listener/grpc/listener.go
Normal file
@ -0,0 +1,100 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/common/admission"
|
||||
"github.com/go-gost/gost/v3/pkg/common/metrics"
|
||||
"github.com/go-gost/gost/v3/pkg/listener"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
pb "github.com/go-gost/x/internal/util/grpc/proto"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.ListenerRegistry().Register("grpc", NewListener)
|
||||
}
|
||||
|
||||
type grpcListener struct {
|
||||
addr net.Addr
|
||||
server *grpc.Server
|
||||
cqueue chan net.Conn
|
||||
errChan chan error
|
||||
md metadata
|
||||
logger logger.Logger
|
||||
options listener.Options
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return &grpcListener{
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *grpcListener) Init(md md.Metadata) (err error) {
|
||||
if err = l.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", l.options.Addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ln = metrics.WrapListener(l.options.Service, ln)
|
||||
ln = admission.WrapListener(l.options.Admission, ln)
|
||||
|
||||
var opts []grpc.ServerOption
|
||||
if !l.md.insecure {
|
||||
opts = append(opts, grpc.Creds(credentials.NewTLS(l.options.TLSConfig)))
|
||||
}
|
||||
|
||||
l.server = grpc.NewServer(opts...)
|
||||
l.addr = ln.Addr()
|
||||
l.cqueue = make(chan net.Conn, l.md.backlog)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
pb.RegisterGostTunelServer(l.server, &server{
|
||||
cqueue: l.cqueue,
|
||||
localAddr: l.addr,
|
||||
logger: l.options.Logger,
|
||||
})
|
||||
|
||||
go func() {
|
||||
err := l.server.Serve(ln)
|
||||
if err != nil {
|
||||
l.errChan <- err
|
||||
}
|
||||
close(l.errChan)
|
||||
}()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *grpcListener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.cqueue:
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *grpcListener) Close() error {
|
||||
l.server.Stop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *grpcListener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
29
listener/grpc/metadata.go
Normal file
29
listener/grpc/metadata.go
Normal file
@ -0,0 +1,29 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultBacklog = 128
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
backlog int
|
||||
insecure bool
|
||||
}
|
||||
|
||||
func (l *grpcListener) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
backlog = "backlog"
|
||||
insecure = "grpcInsecure"
|
||||
)
|
||||
|
||||
l.md.backlog = mdata.GetInt(md, backlog)
|
||||
if l.md.backlog <= 0 {
|
||||
l.md.backlog = defaultBacklog
|
||||
}
|
||||
|
||||
l.md.insecure = mdata.GetBool(md, insecure)
|
||||
return
|
||||
}
|
124
listener/grpc/server.go
Normal file
124
listener/grpc/server.go
Normal file
@ -0,0 +1,124 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
pb "github.com/go-gost/x/internal/util/grpc/proto"
|
||||
"google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
type server struct {
|
||||
cqueue chan net.Conn
|
||||
localAddr net.Addr
|
||||
pb.UnimplementedGostTunelServer
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func (s *server) Tunnel(srv pb.GostTunel_TunnelServer) error {
|
||||
c := &conn{
|
||||
s: srv,
|
||||
localAddr: s.localAddr,
|
||||
remoteAddr: &net.TCPAddr{},
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
if p, ok := peer.FromContext(srv.Context()); ok {
|
||||
c.remoteAddr = p.Addr
|
||||
}
|
||||
|
||||
select {
|
||||
case s.cqueue <- c:
|
||||
default:
|
||||
c.Close()
|
||||
s.logger.Warnf("connection queue is full, client discarded")
|
||||
}
|
||||
|
||||
<-c.closed
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
s pb.GostTunel_TunnelServer
|
||||
rb []byte
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func (c *conn) Read(b []byte) (n int, err error) {
|
||||
select {
|
||||
case <-c.s.Context().Done():
|
||||
err = c.s.Context().Err()
|
||||
return
|
||||
case <-c.closed:
|
||||
err = io.ErrClosedPipe
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if len(c.rb) == 0 {
|
||||
chunk, err := c.s.Recv()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
c.rb = chunk.Data
|
||||
}
|
||||
|
||||
n = copy(b, c.rb)
|
||||
c.rb = c.rb[n:]
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) Write(b []byte) (n int, err error) {
|
||||
select {
|
||||
case <-c.s.Context().Done():
|
||||
err = c.s.Context().Err()
|
||||
return
|
||||
case <-c.closed:
|
||||
err = io.ErrClosedPipe
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if err = c.s.Send(&pb.Chunk{
|
||||
Data: b,
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
n = len(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) Close() error {
|
||||
select {
|
||||
case <-c.closed:
|
||||
default:
|
||||
close(c.closed)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) LocalAddr() net.Addr {
|
||||
return c.localAddr
|
||||
}
|
||||
|
||||
func (c *conn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *conn) SetDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "grpc", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *conn) SetReadDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "grpc", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *conn) SetWriteDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "grpc", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
54
listener/http2/conn.go
Normal file
54
listener/http2/conn.go
Normal file
@ -0,0 +1,54 @@
|
||||
package http2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// a dummy HTTP2 server conn used by HTTP2 handler
|
||||
type conn struct {
|
||||
r *http.Request
|
||||
w http.ResponseWriter
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func (c *conn) Read(b []byte) (n int, err error) {
|
||||
return 0, &net.OpError{Op: "read", Net: "http2", Source: nil, Addr: nil, Err: errors.New("read not supported")}
|
||||
}
|
||||
|
||||
func (c *conn) Write(b []byte) (n int, err error) {
|
||||
return 0, &net.OpError{Op: "write", Net: "http2", Source: nil, Addr: nil, Err: errors.New("write not supported")}
|
||||
}
|
||||
|
||||
func (c *conn) Close() error {
|
||||
select {
|
||||
case <-c.closed:
|
||||
default:
|
||||
close(c.closed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) LocalAddr() net.Addr {
|
||||
addr, _ := net.ResolveTCPAddr("tcp", c.r.Host)
|
||||
return addr
|
||||
}
|
||||
|
||||
func (c *conn) RemoteAddr() net.Addr {
|
||||
addr, _ := net.ResolveTCPAddr("tcp", c.r.RemoteAddr)
|
||||
return addr
|
||||
}
|
||||
|
||||
func (c *conn) SetDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *conn) SetReadDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *conn) SetWriteDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
89
listener/http2/h2/conn.go
Normal file
89
listener/http2/h2/conn.go
Normal file
@ -0,0 +1,89 @@
|
||||
package h2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HTTP2 connection, wrapped up just like a net.Conn
|
||||
type conn struct {
|
||||
r io.Reader
|
||||
w io.Writer
|
||||
remoteAddr net.Addr
|
||||
localAddr net.Addr
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func (c *conn) Read(b []byte) (n int, err error) {
|
||||
return c.r.Read(b)
|
||||
}
|
||||
|
||||
func (c *conn) Write(b []byte) (n int, err error) {
|
||||
return c.w.Write(b)
|
||||
}
|
||||
|
||||
func (c *conn) Close() (err error) {
|
||||
select {
|
||||
case <-c.closed:
|
||||
return
|
||||
default:
|
||||
close(c.closed)
|
||||
}
|
||||
if rc, ok := c.r.(io.Closer); ok {
|
||||
err = rc.Close()
|
||||
}
|
||||
if w, ok := c.w.(io.Closer); ok {
|
||||
err = w.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) LocalAddr() net.Addr {
|
||||
return c.localAddr
|
||||
}
|
||||
|
||||
func (c *conn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *conn) SetDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *conn) SetReadDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *conn) SetWriteDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
type flushWriter struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func (fw flushWriter) Write(p []byte) (n int, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if s, ok := r.(string); ok {
|
||||
err = errors.New(s)
|
||||
// log.Log("[http2]", err)
|
||||
return
|
||||
}
|
||||
err = r.(error)
|
||||
}
|
||||
}()
|
||||
|
||||
n, err = fw.w.Write(p)
|
||||
if err != nil {
|
||||
// log.Log("flush writer:", err)
|
||||
return
|
||||
}
|
||||
if f, ok := fw.w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
return
|
||||
}
|
178
listener/http2/h2/listener.go
Normal file
178
listener/http2/h2/listener.go
Normal file
@ -0,0 +1,178 @@
|
||||
package h2
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/common/admission"
|
||||
"github.com/go-gost/gost/v3/pkg/common/metrics"
|
||||
"github.com/go-gost/gost/v3/pkg/listener"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.ListenerRegistry().Register("h2c", NewListener)
|
||||
registry.ListenerRegistry().Register("h2", NewTLSListener)
|
||||
}
|
||||
|
||||
type h2Listener struct {
|
||||
server *http.Server
|
||||
addr net.Addr
|
||||
cqueue chan net.Conn
|
||||
errChan chan error
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
h2c bool
|
||||
options listener.Options
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return &h2Listener{
|
||||
h2c: true,
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func NewTLSListener(opts ...listener.Option) listener.Listener {
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return &h2Listener{
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *h2Listener) Init(md md.Metadata) (err error) {
|
||||
if err = l.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.server = &http.Server{
|
||||
Addr: l.options.Addr,
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", l.options.Addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l.addr = ln.Addr()
|
||||
ln = metrics.WrapListener(l.options.Service, ln)
|
||||
ln = admission.WrapListener(l.options.Admission, ln)
|
||||
|
||||
if l.h2c {
|
||||
l.server.Handler = h2c.NewHandler(
|
||||
http.HandlerFunc(l.handleFunc), &http2.Server{})
|
||||
} else {
|
||||
l.server.Handler = http.HandlerFunc(l.handleFunc)
|
||||
l.server.TLSConfig = l.options.TLSConfig
|
||||
if err := http2.ConfigureServer(l.server, nil); err != nil {
|
||||
ln.Close()
|
||||
return err
|
||||
}
|
||||
ln = tls.NewListener(ln, l.options.TLSConfig)
|
||||
}
|
||||
|
||||
l.cqueue = make(chan net.Conn, l.md.backlog)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
if err := l.server.Serve(ln); err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *h2Listener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.cqueue:
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *h2Listener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
||||
|
||||
func (l *h2Listener) Close() (err error) {
|
||||
select {
|
||||
case <-l.errChan:
|
||||
default:
|
||||
err = l.server.Close()
|
||||
l.errChan <- err
|
||||
close(l.errChan)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *h2Listener) handleFunc(w http.ResponseWriter, r *http.Request) {
|
||||
if l.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpRequest(r, false)
|
||||
l.logger.Debug(string(dump))
|
||||
}
|
||||
conn, err := l.upgrade(w, r)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case l.cqueue <- conn:
|
||||
default:
|
||||
conn.Close()
|
||||
l.logger.Warnf("connection queue is full, client %s discarded", r.RemoteAddr)
|
||||
}
|
||||
|
||||
<-conn.closed // NOTE: we need to wait for streaming end, or the connection will be closed
|
||||
}
|
||||
|
||||
func (l *h2Listener) upgrade(w http.ResponseWriter, r *http.Request) (*conn, error) {
|
||||
if l.md.path == "" && r.Method != http.MethodConnect {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return nil, errors.New("method not allowed")
|
||||
}
|
||||
|
||||
if l.md.path != "" && r.RequestURI != l.md.path {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return nil, errors.New("bad request")
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if fw, ok := w.(http.Flusher); ok {
|
||||
fw.Flush() // write header to client
|
||||
}
|
||||
|
||||
remoteAddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
|
||||
if remoteAddr == nil {
|
||||
remoteAddr = &net.TCPAddr{
|
||||
IP: net.IPv4zero,
|
||||
Port: 0,
|
||||
}
|
||||
}
|
||||
return &conn{
|
||||
r: r.Body,
|
||||
w: flushWriter{w},
|
||||
localAddr: l.addr,
|
||||
remoteAddr: remoteAddr,
|
||||
closed: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
29
listener/http2/h2/metadata.go
Normal file
29
listener/http2/h2/metadata.go
Normal file
@ -0,0 +1,29 @@
|
||||
package h2
|
||||
|
||||
import (
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultBacklog = 128
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
path string
|
||||
backlog int
|
||||
}
|
||||
|
||||
func (l *h2Listener) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
path = "path"
|
||||
backlog = "backlog"
|
||||
)
|
||||
|
||||
l.md.backlog = mdata.GetInt(md, backlog)
|
||||
if l.md.backlog <= 0 {
|
||||
l.md.backlog = defaultBacklog
|
||||
}
|
||||
|
||||
l.md.path = mdata.GetString(md, path)
|
||||
return
|
||||
}
|
120
listener/http2/listener.go
Normal file
120
listener/http2/listener.go
Normal file
@ -0,0 +1,120 @@
|
||||
package http2
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/common/admission"
|
||||
"github.com/go-gost/gost/v3/pkg/common/metrics"
|
||||
"github.com/go-gost/gost/v3/pkg/listener"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
http2_util "github.com/go-gost/x/internal/util/http2"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.ListenerRegistry().Register("http2", NewListener)
|
||||
}
|
||||
|
||||
type http2Listener struct {
|
||||
server *http.Server
|
||||
addr net.Addr
|
||||
cqueue chan net.Conn
|
||||
errChan chan error
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options listener.Options
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return &http2Listener{
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *http2Listener) Init(md md.Metadata) (err error) {
|
||||
if err = l.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.server = &http.Server{
|
||||
Addr: l.options.Addr,
|
||||
Handler: http.HandlerFunc(l.handleFunc),
|
||||
TLSConfig: l.options.TLSConfig,
|
||||
}
|
||||
if err := http2.ConfigureServer(l.server, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", l.options.Addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l.addr = ln.Addr()
|
||||
ln = metrics.WrapListener(l.options.Service, ln)
|
||||
ln = admission.WrapListener(l.options.Admission, ln)
|
||||
|
||||
ln = tls.NewListener(
|
||||
ln,
|
||||
l.options.TLSConfig,
|
||||
)
|
||||
|
||||
l.cqueue = make(chan net.Conn, l.md.backlog)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
if err := l.server.Serve(ln); err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *http2Listener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.cqueue:
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *http2Listener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
||||
|
||||
func (l *http2Listener) Close() (err error) {
|
||||
select {
|
||||
case <-l.errChan:
|
||||
default:
|
||||
err = l.server.Close()
|
||||
l.errChan <- err
|
||||
close(l.errChan)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *http2Listener) handleFunc(w http.ResponseWriter, r *http.Request) {
|
||||
raddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
|
||||
conn := http2_util.NewServerConn(w, r, l.addr, raddr)
|
||||
select {
|
||||
case l.cqueue <- conn:
|
||||
default:
|
||||
l.logger.Warnf("connection queue is full, client %s discarded", r.RemoteAddr)
|
||||
return
|
||||
}
|
||||
|
||||
<-conn.Done()
|
||||
}
|
25
listener/http2/metadata.go
Normal file
25
listener/http2/metadata.go
Normal file
@ -0,0 +1,25 @@
|
||||
package http2
|
||||
|
||||
import (
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultBacklog = 128
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
backlog int
|
||||
}
|
||||
|
||||
func (l *http2Listener) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
backlog = "backlog"
|
||||
)
|
||||
|
||||
l.md.backlog = mdata.GetInt(md, backlog)
|
||||
if l.md.backlog <= 0 {
|
||||
l.md.backlog = defaultBacklog
|
||||
}
|
||||
return
|
||||
}
|
81
listener/http3/listener.go
Normal file
81
listener/http3/listener.go
Normal file
@ -0,0 +1,81 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/common/metrics"
|
||||
"github.com/go-gost/gost/v3/pkg/listener"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
pht_util "github.com/go-gost/x/internal/util/pht"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.ListenerRegistry().Register("http3", NewListener)
|
||||
registry.ListenerRegistry().Register("h3", NewListener)
|
||||
}
|
||||
|
||||
type http3Listener struct {
|
||||
addr net.Addr
|
||||
server *pht_util.Server
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options listener.Options
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return &http3Listener{
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *http3Listener) Init(md md.Metadata) (err error) {
|
||||
if err = l.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.addr, err = net.ResolveUDPAddr("udp", l.options.Addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.server = pht_util.NewHTTP3Server(
|
||||
l.options.Addr,
|
||||
&quic.Config{},
|
||||
pht_util.TLSConfigServerOption(l.options.TLSConfig),
|
||||
pht_util.BacklogServerOption(l.md.backlog),
|
||||
pht_util.PathServerOption(l.md.authorizePath, l.md.pushPath, l.md.pullPath),
|
||||
pht_util.LoggerServerOption(l.options.Logger),
|
||||
)
|
||||
|
||||
go func() {
|
||||
if err := l.server.ListenAndServe(); err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *http3Listener) Accept() (conn net.Conn, err error) {
|
||||
conn, err = l.server.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return metrics.WrapConn(l.options.Service, conn), nil
|
||||
}
|
||||
|
||||
func (l *http3Listener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
||||
|
||||
func (l *http3Listener) Close() (err error) {
|
||||
return l.server.Close()
|
||||
}
|
51
listener/http3/metadata.go
Normal file
51
listener/http3/metadata.go
Normal file
@ -0,0 +1,51 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultAuthorizePath = "/authorize"
|
||||
defaultPushPath = "/push"
|
||||
defaultPullPath = "/pull"
|
||||
defaultBacklog = 128
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
authorizePath string
|
||||
pushPath string
|
||||
pullPath string
|
||||
backlog int
|
||||
}
|
||||
|
||||
func (l *http3Listener) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
authorizePath = "authorizePath"
|
||||
pushPath = "pushPath"
|
||||
pullPath = "pullPath"
|
||||
|
||||
backlog = "backlog"
|
||||
)
|
||||
|
||||
l.md.authorizePath = mdata.GetString(md, authorizePath)
|
||||
if !strings.HasPrefix(l.md.authorizePath, "/") {
|
||||
l.md.authorizePath = defaultAuthorizePath
|
||||
}
|
||||
l.md.pushPath = mdata.GetString(md, pushPath)
|
||||
if !strings.HasPrefix(l.md.pushPath, "/") {
|
||||
l.md.pushPath = defaultPushPath
|
||||
}
|
||||
l.md.pullPath = mdata.GetString(md, pullPath)
|
||||
if !strings.HasPrefix(l.md.pullPath, "/") {
|
||||
l.md.pullPath = defaultPullPath
|
||||
}
|
||||
|
||||
l.md.backlog = mdata.GetInt(md, backlog)
|
||||
if l.md.backlog <= 0 {
|
||||
l.md.backlog = defaultBacklog
|
||||
}
|
||||
|
||||
return
|
||||
}
|
21
listener/icmp/conn.go
Normal file
21
listener/icmp/conn.go
Normal file
@ -0,0 +1,21 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
type quicConn struct {
|
||||
quic.Stream
|
||||
laddr net.Addr
|
||||
raddr net.Addr
|
||||
}
|
||||
|
||||
func (c *quicConn) LocalAddr() net.Addr {
|
||||
return c.laddr
|
||||
}
|
||||
|
||||
func (c *quicConn) RemoteAddr() net.Addr {
|
||||
return c.raddr
|
||||
}
|
147
listener/icmp/listener.go
Normal file
147
listener/icmp/listener.go
Normal file
@ -0,0 +1,147 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/common/admission"
|
||||
"github.com/go-gost/gost/v3/pkg/common/metrics"
|
||||
"github.com/go-gost/gost/v3/pkg/listener"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
icmp_pkg "github.com/go-gost/x/internal/util/icmp"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"golang.org/x/net/icmp"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.ListenerRegistry().Register("icmp", NewListener)
|
||||
}
|
||||
|
||||
type icmpListener struct {
|
||||
ln quic.Listener
|
||||
cqueue chan net.Conn
|
||||
errChan chan error
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options listener.Options
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return &icmpListener{
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *icmpListener) Init(md md.Metadata) (err error) {
|
||||
if err = l.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
addr := l.options.Addr
|
||||
if host, _, err := net.SplitHostPort(addr); err == nil {
|
||||
addr = host
|
||||
}
|
||||
|
||||
var conn net.PacketConn
|
||||
conn, err = icmp.ListenPacket("ip4:icmp", addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
conn = icmp_pkg.ServerConn(conn)
|
||||
conn = metrics.WrapPacketConn(l.options.Service, conn)
|
||||
conn = admission.WrapPacketConn(l.options.Admission, conn)
|
||||
|
||||
config := &quic.Config{
|
||||
KeepAlive: l.md.keepAlive,
|
||||
HandshakeIdleTimeout: l.md.handshakeTimeout,
|
||||
MaxIdleTimeout: l.md.maxIdleTimeout,
|
||||
Versions: []quic.VersionNumber{
|
||||
quic.Version1,
|
||||
quic.VersionDraft29,
|
||||
},
|
||||
}
|
||||
|
||||
tlsCfg := l.options.TLSConfig
|
||||
tlsCfg.NextProtos = []string{"http/3", "quic/v1"}
|
||||
|
||||
ln, err := quic.Listen(conn, tlsCfg, config)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.ln = ln
|
||||
l.cqueue = make(chan net.Conn, l.md.backlog)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
go l.listenLoop()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *icmpListener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.cqueue:
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *icmpListener) Close() error {
|
||||
return l.ln.Close()
|
||||
}
|
||||
|
||||
func (l *icmpListener) Addr() net.Addr {
|
||||
return l.ln.Addr()
|
||||
}
|
||||
|
||||
func (l *icmpListener) listenLoop() {
|
||||
for {
|
||||
ctx := context.Background()
|
||||
session, err := l.ln.Accept(ctx)
|
||||
if err != nil {
|
||||
l.logger.Error("accept: ", err)
|
||||
l.errChan <- err
|
||||
close(l.errChan)
|
||||
return
|
||||
}
|
||||
l.logger.Infof("new client session: %v", session.RemoteAddr())
|
||||
go l.mux(ctx, session)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *icmpListener) mux(ctx context.Context, session quic.Session) {
|
||||
defer session.CloseWithError(0, "closed")
|
||||
|
||||
for {
|
||||
stream, err := session.AcceptStream(ctx)
|
||||
if err != nil {
|
||||
l.logger.Error("accept stream: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
conn := &quicConn{
|
||||
Stream: stream,
|
||||
laddr: session.LocalAddr(),
|
||||
raddr: session.RemoteAddr(),
|
||||
}
|
||||
select {
|
||||
case l.cqueue <- conn:
|
||||
case <-stream.Context().Done():
|
||||
stream.Close()
|
||||
default:
|
||||
stream.Close()
|
||||
l.logger.Warnf("connection queue is full, client %s discarded", session.RemoteAddr())
|
||||
}
|
||||
}
|
||||
}
|
41
listener/icmp/metadata.go
Normal file
41
listener/icmp/metadata.go
Normal file
@ -0,0 +1,41 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultBacklog = 128
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
keepAlive bool
|
||||
handshakeTimeout time.Duration
|
||||
maxIdleTimeout time.Duration
|
||||
|
||||
cipherKey []byte
|
||||
backlog int
|
||||
}
|
||||
|
||||
func (l *icmpListener) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
keepAlive = "keepAlive"
|
||||
handshakeTimeout = "handshakeTimeout"
|
||||
maxIdleTimeout = "maxIdleTimeout"
|
||||
|
||||
backlog = "backlog"
|
||||
)
|
||||
|
||||
l.md.backlog = mdata.GetInt(md, backlog)
|
||||
if l.md.backlog <= 0 {
|
||||
l.md.backlog = defaultBacklog
|
||||
}
|
||||
|
||||
l.md.keepAlive = mdata.GetBool(md, keepAlive)
|
||||
l.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout)
|
||||
l.md.maxIdleTimeout = mdata.GetDuration(md, maxIdleTimeout)
|
||||
|
||||
return
|
||||
}
|
179
listener/kcp/listener.go
Normal file
179
listener/kcp/listener.go
Normal file
@ -0,0 +1,179 @@
|
||||
package kcp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/common/admission"
|
||||
"github.com/go-gost/gost/v3/pkg/common/metrics"
|
||||
"github.com/go-gost/gost/v3/pkg/listener"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
kcp_util "github.com/go-gost/x/internal/util/kcp"
|
||||
"github.com/xtaci/kcp-go/v5"
|
||||
"github.com/xtaci/smux"
|
||||
"github.com/xtaci/tcpraw"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.ListenerRegistry().Register("kcp", NewListener)
|
||||
}
|
||||
|
||||
type kcpListener struct {
|
||||
conn net.PacketConn
|
||||
ln *kcp.Listener
|
||||
cqueue chan net.Conn
|
||||
errChan chan error
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options listener.Options
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return &kcpListener{
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *kcpListener) Init(md md.Metadata) (err error) {
|
||||
if err = l.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
config := l.md.config
|
||||
config.Init()
|
||||
|
||||
var conn net.PacketConn
|
||||
if config.TCP {
|
||||
conn, err = tcpraw.Listen("tcp", l.options.Addr)
|
||||
} else {
|
||||
var udpAddr *net.UDPAddr
|
||||
udpAddr, err = net.ResolveUDPAddr("udp", l.options.Addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
conn, err = net.ListenUDP("udp", udpAddr)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
conn = metrics.WrapUDPConn(l.options.Service, conn)
|
||||
conn = admission.WrapUDPConn(l.options.Admission, conn)
|
||||
|
||||
ln, err := kcp.ServeConn(
|
||||
kcp_util.BlockCrypt(config.Key, config.Crypt, kcp_util.DefaultSalt),
|
||||
config.DataShard, config.ParityShard, conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if config.DSCP > 0 {
|
||||
if er := ln.SetDSCP(config.DSCP); er != nil {
|
||||
l.logger.Warn(er)
|
||||
}
|
||||
}
|
||||
if er := ln.SetReadBuffer(config.SockBuf); er != nil {
|
||||
l.logger.Warn(er)
|
||||
}
|
||||
if er := ln.SetWriteBuffer(config.SockBuf); er != nil {
|
||||
l.logger.Warn(er)
|
||||
}
|
||||
|
||||
l.ln = ln
|
||||
l.conn = conn
|
||||
l.cqueue = make(chan net.Conn, l.md.backlog)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
go l.listenLoop()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *kcpListener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.cqueue:
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *kcpListener) Close() error {
|
||||
l.conn.Close()
|
||||
return l.ln.Close()
|
||||
}
|
||||
|
||||
func (l *kcpListener) Addr() net.Addr {
|
||||
return l.ln.Addr()
|
||||
}
|
||||
|
||||
func (l *kcpListener) listenLoop() {
|
||||
for {
|
||||
conn, err := l.ln.AcceptKCP()
|
||||
if err != nil {
|
||||
l.logger.Error("accept:", err)
|
||||
l.errChan <- err
|
||||
close(l.errChan)
|
||||
return
|
||||
}
|
||||
|
||||
conn.SetStreamMode(true)
|
||||
conn.SetWriteDelay(false)
|
||||
conn.SetNoDelay(
|
||||
l.md.config.NoDelay,
|
||||
l.md.config.Interval,
|
||||
l.md.config.Resend,
|
||||
l.md.config.NoCongestion,
|
||||
)
|
||||
conn.SetMtu(l.md.config.MTU)
|
||||
conn.SetWindowSize(l.md.config.SndWnd, l.md.config.RcvWnd)
|
||||
conn.SetACKNoDelay(l.md.config.AckNodelay)
|
||||
go l.mux(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *kcpListener) mux(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
smuxConfig := smux.DefaultConfig()
|
||||
smuxConfig.MaxReceiveBuffer = l.md.config.SockBuf
|
||||
smuxConfig.KeepAliveInterval = time.Duration(l.md.config.KeepAlive) * time.Second
|
||||
|
||||
if !l.md.config.NoComp {
|
||||
conn = kcp_util.CompStreamConn(conn)
|
||||
}
|
||||
|
||||
mux, err := smux.Server(conn, smuxConfig)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
return
|
||||
}
|
||||
defer mux.Close()
|
||||
|
||||
for {
|
||||
stream, err := mux.AcceptStream()
|
||||
if err != nil {
|
||||
l.logger.Error("accept stream: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case l.cqueue <- stream:
|
||||
case <-stream.GetDieCh():
|
||||
stream.Close()
|
||||
default:
|
||||
stream.Close()
|
||||
l.logger.Warnf("connection queue is full, client %s discarded", stream.RemoteAddr())
|
||||
}
|
||||
}
|
||||
}
|
47
listener/kcp/metadata.go
Normal file
47
listener/kcp/metadata.go
Normal file
@ -0,0 +1,47 @@
|
||||
package kcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
kcp_util "github.com/go-gost/x/internal/util/kcp"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultBacklog = 128
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
config *kcp_util.Config
|
||||
backlog int
|
||||
}
|
||||
|
||||
func (l *kcpListener) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
backlog = "backlog"
|
||||
config = "config"
|
||||
)
|
||||
|
||||
if m := mdata.GetStringMap(md, config); len(m) > 0 {
|
||||
b, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg := &kcp_util.Config{}
|
||||
if err := json.Unmarshal(b, cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
l.md.config = cfg
|
||||
}
|
||||
|
||||
if l.md.config == nil {
|
||||
l.md.config = kcp_util.DefaultConfig
|
||||
}
|
||||
|
||||
l.md.backlog = mdata.GetInt(md, backlog)
|
||||
if l.md.backlog <= 0 {
|
||||
l.md.backlog = defaultBacklog
|
||||
}
|
||||
|
||||
return
|
||||
}
|
129
listener/mtls/listener.go
Normal file
129
listener/mtls/listener.go
Normal file
@ -0,0 +1,129 @@
|
||||
package mtls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/common/admission"
|
||||
"github.com/go-gost/gost/v3/pkg/common/metrics"
|
||||
"github.com/go-gost/gost/v3/pkg/listener"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
"github.com/xtaci/smux"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.ListenerRegistry().Register("mtls", NewListener)
|
||||
}
|
||||
|
||||
type mtlsListener struct {
|
||||
net.Listener
|
||||
cqueue chan net.Conn
|
||||
errChan chan error
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options listener.Options
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return &mtlsListener{
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *mtlsListener) Init(md md.Metadata) (err error) {
|
||||
if err = l.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", l.options.Addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ln = metrics.WrapListener(l.options.Service, ln)
|
||||
ln = admission.WrapListener(l.options.Admission, ln)
|
||||
l.Listener = tls.NewListener(ln, l.options.TLSConfig)
|
||||
|
||||
l.cqueue = make(chan net.Conn, l.md.backlog)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
go l.listenLoop()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *mtlsListener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.cqueue:
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *mtlsListener) listenLoop() {
|
||||
for {
|
||||
conn, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
l.errChan <- err
|
||||
close(l.errChan)
|
||||
return
|
||||
}
|
||||
go l.mux(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *mtlsListener) mux(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
smuxConfig := smux.DefaultConfig()
|
||||
smuxConfig.KeepAliveDisabled = l.md.muxKeepAliveDisabled
|
||||
if l.md.muxKeepAliveInterval > 0 {
|
||||
smuxConfig.KeepAliveInterval = l.md.muxKeepAliveInterval
|
||||
}
|
||||
if l.md.muxKeepAliveTimeout > 0 {
|
||||
smuxConfig.KeepAliveTimeout = l.md.muxKeepAliveTimeout
|
||||
}
|
||||
if l.md.muxMaxFrameSize > 0 {
|
||||
smuxConfig.MaxFrameSize = l.md.muxMaxFrameSize
|
||||
}
|
||||
if l.md.muxMaxReceiveBuffer > 0 {
|
||||
smuxConfig.MaxReceiveBuffer = l.md.muxMaxReceiveBuffer
|
||||
}
|
||||
if l.md.muxMaxStreamBuffer > 0 {
|
||||
smuxConfig.MaxStreamBuffer = l.md.muxMaxStreamBuffer
|
||||
}
|
||||
session, err := smux.Server(conn, smuxConfig)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
return
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
for {
|
||||
stream, err := session.AcceptStream()
|
||||
if err != nil {
|
||||
l.logger.Error("accept stream: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case l.cqueue <- stream:
|
||||
case <-stream.GetDieCh():
|
||||
stream.Close()
|
||||
default:
|
||||
stream.Close()
|
||||
l.logger.Warnf("connection queue is full, client %s discarded", stream.RemoteAddr())
|
||||
}
|
||||
}
|
||||
}
|
49
listener/mtls/metadata.go
Normal file
49
listener/mtls/metadata.go
Normal file
@ -0,0 +1,49 @@
|
||||
package mtls
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultBacklog = 128
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
muxKeepAliveDisabled bool
|
||||
muxKeepAliveInterval time.Duration
|
||||
muxKeepAliveTimeout time.Duration
|
||||
muxMaxFrameSize int
|
||||
muxMaxReceiveBuffer int
|
||||
muxMaxStreamBuffer int
|
||||
|
||||
backlog int
|
||||
}
|
||||
|
||||
func (l *mtlsListener) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
backlog = "backlog"
|
||||
|
||||
muxKeepAliveDisabled = "muxKeepAliveDisabled"
|
||||
muxKeepAliveInterval = "muxKeepAliveInterval"
|
||||
muxKeepAliveTimeout = "muxKeepAliveTimeout"
|
||||
muxMaxFrameSize = "muxMaxFrameSize"
|
||||
muxMaxReceiveBuffer = "muxMaxReceiveBuffer"
|
||||
muxMaxStreamBuffer = "muxMaxStreamBuffer"
|
||||
)
|
||||
|
||||
l.md.backlog = mdata.GetInt(md, backlog)
|
||||
if l.md.backlog <= 0 {
|
||||
l.md.backlog = defaultBacklog
|
||||
}
|
||||
|
||||
l.md.muxKeepAliveDisabled = mdata.GetBool(md, muxKeepAliveDisabled)
|
||||
l.md.muxKeepAliveInterval = mdata.GetDuration(md, muxKeepAliveInterval)
|
||||
l.md.muxKeepAliveTimeout = mdata.GetDuration(md, muxKeepAliveTimeout)
|
||||
l.md.muxMaxFrameSize = mdata.GetInt(md, muxMaxFrameSize)
|
||||
l.md.muxMaxReceiveBuffer = mdata.GetInt(md, muxMaxReceiveBuffer)
|
||||
l.md.muxMaxStreamBuffer = mdata.GetInt(md, muxMaxStreamBuffer)
|
||||
|
||||
return
|
||||
}
|
194
listener/mws/listener.go
Normal file
194
listener/mws/listener.go
Normal file
@ -0,0 +1,194 @@
|
||||
package mws
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/common/admission"
|
||||
"github.com/go-gost/gost/v3/pkg/common/metrics"
|
||||
"github.com/go-gost/gost/v3/pkg/listener"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
ws_util "github.com/go-gost/x/internal/util/ws"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/xtaci/smux"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.ListenerRegistry().Register("mws", NewListener)
|
||||
registry.ListenerRegistry().Register("mwss", NewTLSListener)
|
||||
}
|
||||
|
||||
type mwsListener struct {
|
||||
addr net.Addr
|
||||
upgrader *websocket.Upgrader
|
||||
srv *http.Server
|
||||
cqueue chan net.Conn
|
||||
errChan chan error
|
||||
tlsEnabled bool
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options listener.Options
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return &mwsListener{
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func NewTLSListener(opts ...listener.Option) listener.Listener {
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return &mwsListener{
|
||||
tlsEnabled: true,
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *mwsListener) Init(md md.Metadata) (err error) {
|
||||
if err = l.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.upgrader = &websocket.Upgrader{
|
||||
HandshakeTimeout: l.md.handshakeTimeout,
|
||||
ReadBufferSize: l.md.readBufferSize,
|
||||
WriteBufferSize: l.md.writeBufferSize,
|
||||
EnableCompression: l.md.enableCompression,
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
|
||||
path := l.md.path
|
||||
if path == "" {
|
||||
path = defaultPath
|
||||
}
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle(path, http.HandlerFunc(l.upgrade))
|
||||
l.srv = &http.Server{
|
||||
Addr: l.options.Addr,
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: l.md.readHeaderTimeout,
|
||||
}
|
||||
|
||||
l.cqueue = make(chan net.Conn, l.md.backlog)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
ln, err := net.Listen("tcp", l.options.Addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ln = metrics.WrapListener(l.options.Service, ln)
|
||||
ln = admission.WrapListener(l.options.Admission, ln)
|
||||
|
||||
if l.tlsEnabled {
|
||||
ln = tls.NewListener(ln, l.options.TLSConfig)
|
||||
}
|
||||
|
||||
l.addr = ln.Addr()
|
||||
|
||||
go func() {
|
||||
err := l.srv.Serve(ln)
|
||||
if err != nil {
|
||||
l.errChan <- err
|
||||
}
|
||||
close(l.errChan)
|
||||
}()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *mwsListener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.cqueue:
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *mwsListener) Close() error {
|
||||
return l.srv.Close()
|
||||
}
|
||||
|
||||
func (l *mwsListener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
||||
|
||||
func (l *mwsListener) upgrade(w http.ResponseWriter, r *http.Request) {
|
||||
if l.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
log := l.logger.WithFields(map[string]any{
|
||||
"local": l.addr.String(),
|
||||
"remote": r.RemoteAddr,
|
||||
})
|
||||
dump, _ := httputil.DumpRequest(r, false)
|
||||
log.Debug(string(dump))
|
||||
}
|
||||
|
||||
conn, err := l.upgrader.Upgrade(w, r, l.md.header)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
l.mux(ws_util.Conn(conn))
|
||||
}
|
||||
|
||||
func (l *mwsListener) mux(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
smuxConfig := smux.DefaultConfig()
|
||||
smuxConfig.KeepAliveDisabled = l.md.muxKeepAliveDisabled
|
||||
if l.md.muxKeepAliveInterval > 0 {
|
||||
smuxConfig.KeepAliveInterval = l.md.muxKeepAliveInterval
|
||||
}
|
||||
if l.md.muxKeepAliveTimeout > 0 {
|
||||
smuxConfig.KeepAliveTimeout = l.md.muxKeepAliveTimeout
|
||||
}
|
||||
if l.md.muxMaxFrameSize > 0 {
|
||||
smuxConfig.MaxFrameSize = l.md.muxMaxFrameSize
|
||||
}
|
||||
if l.md.muxMaxReceiveBuffer > 0 {
|
||||
smuxConfig.MaxReceiveBuffer = l.md.muxMaxReceiveBuffer
|
||||
}
|
||||
if l.md.muxMaxStreamBuffer > 0 {
|
||||
smuxConfig.MaxStreamBuffer = l.md.muxMaxStreamBuffer
|
||||
}
|
||||
session, err := smux.Server(conn, smuxConfig)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
return
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
for {
|
||||
stream, err := session.AcceptStream()
|
||||
if err != nil {
|
||||
l.logger.Error("accept stream: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case l.cqueue <- stream:
|
||||
case <-stream.GetDieCh():
|
||||
stream.Close()
|
||||
default:
|
||||
stream.Close()
|
||||
l.logger.Warnf("connection queue is full, client %s discarded", stream.RemoteAddr())
|
||||
}
|
||||
}
|
||||
}
|
85
listener/mws/metadata.go
Normal file
85
listener/mws/metadata.go
Normal file
@ -0,0 +1,85 @@
|
||||
package mws
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPath = "/ws"
|
||||
defaultBacklog = 128
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
path string
|
||||
backlog int
|
||||
header http.Header
|
||||
|
||||
handshakeTimeout time.Duration
|
||||
readHeaderTimeout time.Duration
|
||||
readBufferSize int
|
||||
writeBufferSize int
|
||||
enableCompression bool
|
||||
|
||||
muxKeepAliveDisabled bool
|
||||
muxKeepAliveInterval time.Duration
|
||||
muxKeepAliveTimeout time.Duration
|
||||
muxMaxFrameSize int
|
||||
muxMaxReceiveBuffer int
|
||||
muxMaxStreamBuffer int
|
||||
}
|
||||
|
||||
func (l *mwsListener) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
path = "path"
|
||||
backlog = "backlog"
|
||||
header = "header"
|
||||
|
||||
handshakeTimeout = "handshakeTimeout"
|
||||
readHeaderTimeout = "readHeaderTimeout"
|
||||
readBufferSize = "readBufferSize"
|
||||
writeBufferSize = "writeBufferSize"
|
||||
enableCompression = "enableCompression"
|
||||
|
||||
muxKeepAliveDisabled = "muxKeepAliveDisabled"
|
||||
muxKeepAliveInterval = "muxKeepAliveInterval"
|
||||
muxKeepAliveTimeout = "muxKeepAliveTimeout"
|
||||
muxMaxFrameSize = "muxMaxFrameSize"
|
||||
muxMaxReceiveBuffer = "muxMaxReceiveBuffer"
|
||||
muxMaxStreamBuffer = "muxMaxStreamBuffer"
|
||||
)
|
||||
|
||||
l.md.path = mdata.GetString(md, path)
|
||||
if l.md.path == "" {
|
||||
l.md.path = defaultPath
|
||||
}
|
||||
|
||||
l.md.backlog = mdata.GetInt(md, backlog)
|
||||
if l.md.backlog <= 0 {
|
||||
l.md.backlog = defaultBacklog
|
||||
}
|
||||
|
||||
l.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout)
|
||||
l.md.readHeaderTimeout = mdata.GetDuration(md, readHeaderTimeout)
|
||||
l.md.readBufferSize = mdata.GetInt(md, readBufferSize)
|
||||
l.md.writeBufferSize = mdata.GetInt(md, writeBufferSize)
|
||||
l.md.enableCompression = mdata.GetBool(md, enableCompression)
|
||||
|
||||
l.md.muxKeepAliveDisabled = mdata.GetBool(md, muxKeepAliveDisabled)
|
||||
l.md.muxKeepAliveInterval = mdata.GetDuration(md, muxKeepAliveInterval)
|
||||
l.md.muxKeepAliveTimeout = mdata.GetDuration(md, muxKeepAliveTimeout)
|
||||
l.md.muxMaxFrameSize = mdata.GetInt(md, muxMaxFrameSize)
|
||||
l.md.muxMaxReceiveBuffer = mdata.GetInt(md, muxMaxReceiveBuffer)
|
||||
l.md.muxMaxStreamBuffer = mdata.GetInt(md, muxMaxStreamBuffer)
|
||||
|
||||
if mm := mdata.GetStringMapString(md, header); len(mm) > 0 {
|
||||
hd := http.Header{}
|
||||
for k, v := range mm {
|
||||
hd.Add(k, v)
|
||||
}
|
||||
l.md.header = hd
|
||||
}
|
||||
return
|
||||
}
|
145
listener/obfs/http/conn.go
Normal file
145
listener/obfs/http/conn.go
Normal file
@ -0,0 +1,145 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
)
|
||||
|
||||
type obfsHTTPConn struct {
|
||||
net.Conn
|
||||
rbuf bytes.Buffer
|
||||
wbuf bytes.Buffer
|
||||
handshaked bool
|
||||
handshakeMutex sync.Mutex
|
||||
header http.Header
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func (c *obfsHTTPConn) Handshake() (err error) {
|
||||
c.handshakeMutex.Lock()
|
||||
defer c.handshakeMutex.Unlock()
|
||||
|
||||
if c.handshaked {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err = c.handshake(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.handshaked = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *obfsHTTPConn) handshake() (err error) {
|
||||
br := bufio.NewReader(c.Conn)
|
||||
r, err := http.ReadRequest(br)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if c.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpRequest(r, false)
|
||||
c.logger.Debug(string(dump))
|
||||
}
|
||||
|
||||
if r.ContentLength > 0 {
|
||||
_, err = io.Copy(&c.rbuf, r.Body)
|
||||
} else {
|
||||
var b []byte
|
||||
b, err = br.Peek(br.Buffered())
|
||||
if len(b) > 0 {
|
||||
_, err = c.rbuf.Write(b)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
c.logger.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
resp := http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Header: c.header,
|
||||
}
|
||||
if resp.Header == nil {
|
||||
resp.Header = http.Header{}
|
||||
}
|
||||
resp.Header.Set("Date", time.Now().Format(time.RFC1123))
|
||||
|
||||
if r.Method != http.MethodGet || r.Header.Get("Upgrade") != "websocket" {
|
||||
resp.StatusCode = http.StatusBadRequest
|
||||
|
||||
if c.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpResponse(&resp, false)
|
||||
c.logger.Debug(string(dump))
|
||||
}
|
||||
|
||||
resp.Write(c.Conn)
|
||||
return errors.New("bad request")
|
||||
}
|
||||
|
||||
resp.StatusCode = http.StatusSwitchingProtocols
|
||||
resp.Header.Set("Connection", "Upgrade")
|
||||
resp.Header.Set("Upgrade", "websocket")
|
||||
resp.Header.Set("Sec-WebSocket-Accept", c.computeAcceptKey(r.Header.Get("Sec-WebSocket-Key")))
|
||||
|
||||
if c.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpResponse(&resp, false)
|
||||
c.logger.Debug(string(dump))
|
||||
}
|
||||
|
||||
if c.rbuf.Len() > 0 {
|
||||
// cache the response header if there are extra data in the request body.
|
||||
resp.Write(&c.wbuf)
|
||||
return
|
||||
}
|
||||
|
||||
err = resp.Write(c.Conn)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *obfsHTTPConn) Read(b []byte) (n int, err error) {
|
||||
if err = c.Handshake(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if c.rbuf.Len() > 0 {
|
||||
return c.rbuf.Read(b)
|
||||
}
|
||||
return c.Conn.Read(b)
|
||||
}
|
||||
|
||||
func (c *obfsHTTPConn) Write(b []byte) (n int, err error) {
|
||||
if err = c.Handshake(); err != nil {
|
||||
return
|
||||
}
|
||||
if c.wbuf.Len() > 0 {
|
||||
c.wbuf.Write(b) // append the data to the cached header
|
||||
_, err = c.wbuf.WriteTo(c.Conn)
|
||||
n = len(b) // exclude the header length
|
||||
return
|
||||
}
|
||||
return c.Conn.Write(b)
|
||||
}
|
||||
|
||||
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
||||
|
||||
func (c *obfsHTTPConn) computeAcceptKey(challengeKey string) string {
|
||||
h := sha1.New()
|
||||
h.Write([]byte(challengeKey))
|
||||
h.Write(keyGUID)
|
||||
return base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||
}
|
63
listener/obfs/http/listener.go
Normal file
63
listener/obfs/http/listener.go
Normal file
@ -0,0 +1,63 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/common/admission"
|
||||
"github.com/go-gost/gost/v3/pkg/common/metrics"
|
||||
"github.com/go-gost/gost/v3/pkg/listener"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.ListenerRegistry().Register("ohttp", NewListener)
|
||||
}
|
||||
|
||||
type obfsListener struct {
|
||||
net.Listener
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options listener.Options
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return &obfsListener{
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *obfsListener) Init(md md.Metadata) (err error) {
|
||||
if err = l.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", l.options.Addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ln = metrics.WrapListener(l.options.Service, ln)
|
||||
ln = admission.WrapListener(l.options.Admission, ln)
|
||||
|
||||
l.Listener = ln
|
||||
return
|
||||
}
|
||||
|
||||
func (l *obfsListener) Accept() (net.Conn, error) {
|
||||
c, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &obfsHTTPConn{
|
||||
Conn: c,
|
||||
header: l.md.header,
|
||||
logger: l.logger,
|
||||
}, nil
|
||||
}
|
26
listener/obfs/http/metadata.go
Normal file
26
listener/obfs/http/metadata.go
Normal file
@ -0,0 +1,26 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
header http.Header
|
||||
}
|
||||
|
||||
func (l *obfsListener) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
header = "header"
|
||||
)
|
||||
|
||||
if mm := mdata.GetStringMapString(md, header); len(mm) > 0 {
|
||||
hd := http.Header{}
|
||||
for k, v := range mm {
|
||||
hd.Add(k, v)
|
||||
}
|
||||
l.md.header = hd
|
||||
}
|
||||
return
|
||||
}
|
161
listener/obfs/tls/conn.go
Normal file
161
listener/obfs/tls/conn.go
Normal file
@ -0,0 +1,161 @@
|
||||
package tls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
dissector "github.com/go-gost/tls-dissector"
|
||||
)
|
||||
|
||||
const (
|
||||
maxTLSDataLen = 16384
|
||||
)
|
||||
|
||||
type obfsTLSConn struct {
|
||||
net.Conn
|
||||
rbuf bytes.Buffer
|
||||
wbuf bytes.Buffer
|
||||
handshaked bool
|
||||
handshakeMutex sync.Mutex
|
||||
}
|
||||
|
||||
func (c *obfsTLSConn) Handshake() (err error) {
|
||||
c.handshakeMutex.Lock()
|
||||
defer c.handshakeMutex.Unlock()
|
||||
|
||||
if c.handshaked {
|
||||
return
|
||||
}
|
||||
|
||||
if err = c.handshake(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.handshaked = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *obfsTLSConn) handshake() error {
|
||||
record := &dissector.Record{}
|
||||
if _, err := record.ReadFrom(c.Conn); err != nil {
|
||||
// log.Log(err)
|
||||
return err
|
||||
}
|
||||
if record.Type != dissector.Handshake {
|
||||
return dissector.ErrBadType
|
||||
}
|
||||
|
||||
clientMsg := &dissector.ClientHelloMsg{}
|
||||
if err := clientMsg.Decode(record.Opaque); err != nil {
|
||||
// log.Log(err)
|
||||
return err
|
||||
}
|
||||
|
||||
for _, ext := range clientMsg.Extensions {
|
||||
if ext.Type() == dissector.ExtSessionTicket {
|
||||
b, err := ext.Encode()
|
||||
if err != nil {
|
||||
// log.Log(err)
|
||||
return err
|
||||
}
|
||||
c.rbuf.Write(b)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
serverMsg := &dissector.ServerHelloMsg{
|
||||
Version: tls.VersionTLS12,
|
||||
SessionID: clientMsg.SessionID,
|
||||
CipherSuite: 0xcca8,
|
||||
CompressionMethod: 0x00,
|
||||
Extensions: []dissector.Extension{
|
||||
&dissector.RenegotiationInfoExtension{},
|
||||
&dissector.ExtendedMasterSecretExtension{},
|
||||
&dissector.ECPointFormatsExtension{
|
||||
Formats: []uint8{0x00},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
serverMsg.Random.Time = uint32(time.Now().Unix())
|
||||
rand.Read(serverMsg.Random.Opaque[:])
|
||||
b, err := serverMsg.Encode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
record = &dissector.Record{
|
||||
Type: dissector.Handshake,
|
||||
Version: tls.VersionTLS10,
|
||||
Opaque: b,
|
||||
}
|
||||
|
||||
if _, err := record.WriteTo(&c.wbuf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
record = &dissector.Record{
|
||||
Type: dissector.ChangeCipherSpec,
|
||||
Version: tls.VersionTLS12,
|
||||
Opaque: []byte{0x01},
|
||||
}
|
||||
if _, err := record.WriteTo(&c.wbuf); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *obfsTLSConn) Read(b []byte) (n int, err error) {
|
||||
if err = c.Handshake(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if c.rbuf.Len() > 0 {
|
||||
return c.rbuf.Read(b)
|
||||
}
|
||||
record := &dissector.Record{}
|
||||
if _, err = record.ReadFrom(c.Conn); err != nil {
|
||||
return
|
||||
}
|
||||
n = copy(b, record.Opaque)
|
||||
_, err = c.rbuf.Write(record.Opaque[n:])
|
||||
return
|
||||
}
|
||||
|
||||
func (c *obfsTLSConn) Write(b []byte) (n int, err error) {
|
||||
if err = c.Handshake(); err != nil {
|
||||
return
|
||||
}
|
||||
n = len(b)
|
||||
|
||||
for len(b) > 0 {
|
||||
data := b
|
||||
if len(b) > maxTLSDataLen {
|
||||
data = b[:maxTLSDataLen]
|
||||
b = b[maxTLSDataLen:]
|
||||
} else {
|
||||
b = b[:0]
|
||||
}
|
||||
record := &dissector.Record{
|
||||
Type: dissector.AppData,
|
||||
Version: tls.VersionTLS12,
|
||||
Opaque: data,
|
||||
}
|
||||
|
||||
if c.wbuf.Len() > 0 {
|
||||
record.Type = dissector.Handshake
|
||||
record.WriteTo(&c.wbuf)
|
||||
_, err = c.wbuf.WriteTo(c.Conn)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = record.WriteTo(c.Conn); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
61
listener/obfs/tls/listener.go
Normal file
61
listener/obfs/tls/listener.go
Normal file
@ -0,0 +1,61 @@
|
||||
package tls
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/common/admission"
|
||||
"github.com/go-gost/gost/v3/pkg/common/metrics"
|
||||
"github.com/go-gost/gost/v3/pkg/listener"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.ListenerRegistry().Register("otls", NewListener)
|
||||
}
|
||||
|
||||
type obfsListener struct {
|
||||
net.Listener
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options listener.Options
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return &obfsListener{
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *obfsListener) Init(md md.Metadata) (err error) {
|
||||
if err = l.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", l.options.Addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ln = metrics.WrapListener(l.options.Service, ln)
|
||||
ln = admission.WrapListener(l.options.Admission, ln)
|
||||
|
||||
l.Listener = ln
|
||||
return
|
||||
}
|
||||
|
||||
func (l *obfsListener) Accept() (net.Conn, error) {
|
||||
c, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &obfsTLSConn{
|
||||
Conn: c,
|
||||
}, nil
|
||||
}
|
12
listener/obfs/tls/metadata.go
Normal file
12
listener/obfs/tls/metadata.go
Normal file
@ -0,0 +1,12 @@
|
||||
package tls
|
||||
|
||||
import (
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
}
|
||||
|
||||
func (l *obfsListener) parseMetadata(md md.Metadata) (err error) {
|
||||
return
|
||||
}
|
96
listener/pht/listener.go
Normal file
96
listener/pht/listener.go
Normal file
@ -0,0 +1,96 @@
|
||||
// plain http tunnel
|
||||
|
||||
package pht
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/common/metrics"
|
||||
"github.com/go-gost/gost/v3/pkg/listener"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
pht_util "github.com/go-gost/x/internal/util/pht"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.ListenerRegistry().Register("pht", NewListener)
|
||||
registry.ListenerRegistry().Register("phts", NewTLSListener)
|
||||
}
|
||||
|
||||
type phtListener struct {
|
||||
addr net.Addr
|
||||
tlsEnabled bool
|
||||
server *pht_util.Server
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options listener.Options
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return &phtListener{
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func NewTLSListener(opts ...listener.Option) listener.Listener {
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return &phtListener{
|
||||
tlsEnabled: true,
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *phtListener) Init(md md.Metadata) (err error) {
|
||||
if err = l.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.addr, err = net.ResolveTCPAddr("tcp", l.options.Addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.server = pht_util.NewServer(
|
||||
l.options.Addr,
|
||||
pht_util.TLSConfigServerOption(l.options.TLSConfig),
|
||||
pht_util.EnableTLSServerOption(l.tlsEnabled),
|
||||
pht_util.BacklogServerOption(l.md.backlog),
|
||||
pht_util.PathServerOption(l.md.authorizePath, l.md.pushPath, l.md.pullPath),
|
||||
pht_util.LoggerServerOption(l.options.Logger),
|
||||
)
|
||||
|
||||
go func() {
|
||||
if err := l.server.ListenAndServe(); err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *phtListener) Accept() (conn net.Conn, err error) {
|
||||
conn, err = l.server.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
conn = metrics.WrapConn(l.options.Service, conn)
|
||||
return
|
||||
}
|
||||
|
||||
func (l *phtListener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
||||
|
||||
func (l *phtListener) Close() (err error) {
|
||||
return l.server.Close()
|
||||
}
|
51
listener/pht/metadata.go
Normal file
51
listener/pht/metadata.go
Normal file
@ -0,0 +1,51 @@
|
||||
package pht
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultAuthorizePath = "/authorize"
|
||||
defaultPushPath = "/push"
|
||||
defaultPullPath = "/pull"
|
||||
defaultBacklog = 128
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
authorizePath string
|
||||
pushPath string
|
||||
pullPath string
|
||||
backlog int
|
||||
}
|
||||
|
||||
func (l *phtListener) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
authorizePath = "authorizePath"
|
||||
pushPath = "pushPath"
|
||||
pullPath = "pullPath"
|
||||
|
||||
backlog = "backlog"
|
||||
)
|
||||
|
||||
l.md.authorizePath = mdata.GetString(md, authorizePath)
|
||||
if !strings.HasPrefix(l.md.authorizePath, "/") {
|
||||
l.md.authorizePath = defaultAuthorizePath
|
||||
}
|
||||
l.md.pushPath = mdata.GetString(md, pushPath)
|
||||
if !strings.HasPrefix(l.md.pushPath, "/") {
|
||||
l.md.pushPath = defaultPushPath
|
||||
}
|
||||
l.md.pullPath = mdata.GetString(md, pullPath)
|
||||
if !strings.HasPrefix(l.md.pullPath, "/") {
|
||||
l.md.pullPath = defaultPullPath
|
||||
}
|
||||
|
||||
l.md.backlog = mdata.GetInt(md, backlog)
|
||||
if l.md.backlog <= 0 {
|
||||
l.md.backlog = defaultBacklog
|
||||
}
|
||||
|
||||
return
|
||||
}
|
21
listener/quic/conn.go
Normal file
21
listener/quic/conn.go
Normal file
@ -0,0 +1,21 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
type quicConn struct {
|
||||
quic.Stream
|
||||
laddr net.Addr
|
||||
raddr net.Addr
|
||||
}
|
||||
|
||||
func (c *quicConn) LocalAddr() net.Addr {
|
||||
return c.laddr
|
||||
}
|
||||
|
||||
func (c *quicConn) RemoteAddr() net.Addr {
|
||||
return c.raddr
|
||||
}
|
151
listener/quic/listener.go
Normal file
151
listener/quic/listener.go
Normal file
@ -0,0 +1,151 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/common/metrics"
|
||||
"github.com/go-gost/gost/v3/pkg/listener"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
quic_util "github.com/go-gost/x/internal/util/quic"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.ListenerRegistry().Register("quic", NewListener)
|
||||
}
|
||||
|
||||
type quicListener struct {
|
||||
ln quic.Listener
|
||||
cqueue chan net.Conn
|
||||
errChan chan error
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options listener.Options
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return &quicListener{
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *quicListener) Init(md md.Metadata) (err error) {
|
||||
if err = l.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
addr := l.options.Addr
|
||||
if _, _, err := net.SplitHostPort(addr); err != nil {
|
||||
addr = net.JoinHostPort(addr, "0")
|
||||
}
|
||||
|
||||
var laddr *net.UDPAddr
|
||||
laddr, err = net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var conn net.PacketConn
|
||||
conn, err = net.ListenUDP("udp", laddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if l.md.cipherKey != nil {
|
||||
conn = quic_util.CipherPacketConn(conn, l.md.cipherKey)
|
||||
}
|
||||
|
||||
config := &quic.Config{
|
||||
KeepAlive: l.md.keepAlive,
|
||||
HandshakeIdleTimeout: l.md.handshakeTimeout,
|
||||
MaxIdleTimeout: l.md.maxIdleTimeout,
|
||||
Versions: []quic.VersionNumber{
|
||||
quic.Version1,
|
||||
quic.VersionDraft29,
|
||||
},
|
||||
}
|
||||
|
||||
tlsCfg := l.options.TLSConfig
|
||||
tlsCfg.NextProtos = []string{"http/3", "quic/v1"}
|
||||
|
||||
ln, err := quic.Listen(conn, tlsCfg, config)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.ln = ln
|
||||
l.cqueue = make(chan net.Conn, l.md.backlog)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
go l.listenLoop()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *quicListener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.cqueue:
|
||||
conn = metrics.WrapConn(l.options.Service, conn)
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *quicListener) Close() error {
|
||||
return l.ln.Close()
|
||||
}
|
||||
|
||||
func (l *quicListener) Addr() net.Addr {
|
||||
return l.ln.Addr()
|
||||
}
|
||||
|
||||
func (l *quicListener) listenLoop() {
|
||||
for {
|
||||
ctx := context.Background()
|
||||
session, err := l.ln.Accept(ctx)
|
||||
if err != nil {
|
||||
l.logger.Error("accept:", err)
|
||||
l.errChan <- err
|
||||
close(l.errChan)
|
||||
return
|
||||
}
|
||||
go l.mux(ctx, session)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *quicListener) mux(ctx context.Context, session quic.Session) {
|
||||
defer session.CloseWithError(0, "closed")
|
||||
|
||||
for {
|
||||
stream, err := session.AcceptStream(ctx)
|
||||
if err != nil {
|
||||
l.logger.Error("accept stream:", err)
|
||||
return
|
||||
}
|
||||
|
||||
conn := &quicConn{
|
||||
Stream: stream,
|
||||
laddr: session.LocalAddr(),
|
||||
raddr: session.RemoteAddr(),
|
||||
}
|
||||
select {
|
||||
case l.cqueue <- conn:
|
||||
case <-stream.Context().Done():
|
||||
stream.Close()
|
||||
default:
|
||||
stream.Close()
|
||||
l.logger.Warnf("connection queue is full, client %s discarded", session.RemoteAddr())
|
||||
}
|
||||
}
|
||||
}
|
46
listener/quic/metadata.go
Normal file
46
listener/quic/metadata.go
Normal file
@ -0,0 +1,46 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultBacklog = 128
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
keepAlive bool
|
||||
handshakeTimeout time.Duration
|
||||
maxIdleTimeout time.Duration
|
||||
|
||||
cipherKey []byte
|
||||
backlog int
|
||||
}
|
||||
|
||||
func (l *quicListener) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
keepAlive = "keepAlive"
|
||||
handshakeTimeout = "handshakeTimeout"
|
||||
maxIdleTimeout = "maxIdleTimeout"
|
||||
|
||||
backlog = "backlog"
|
||||
cipherKey = "cipherKey"
|
||||
)
|
||||
|
||||
l.md.backlog = mdata.GetInt(md, backlog)
|
||||
if l.md.backlog <= 0 {
|
||||
l.md.backlog = defaultBacklog
|
||||
}
|
||||
|
||||
if key := mdata.GetString(md, cipherKey); key != "" {
|
||||
l.md.cipherKey = []byte(key)
|
||||
}
|
||||
|
||||
l.md.keepAlive = mdata.GetBool(md, keepAlive)
|
||||
l.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout)
|
||||
l.md.maxIdleTimeout = mdata.GetDuration(md, maxIdleTimeout)
|
||||
|
||||
return
|
||||
}
|
42
listener/redirect/udp/conn.go
Normal file
42
listener/redirect/udp/conn.go
Normal file
@ -0,0 +1,42 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/common/bufpool"
|
||||
)
|
||||
|
||||
type redirConn struct {
|
||||
net.Conn
|
||||
buf []byte
|
||||
ttl time.Duration
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (c *redirConn) Read(b []byte) (n int, err error) {
|
||||
if c.ttl > 0 {
|
||||
c.SetReadDeadline(time.Now().Add(c.ttl))
|
||||
defer c.SetReadDeadline(time.Time{})
|
||||
}
|
||||
|
||||
c.once.Do(func() {
|
||||
n = copy(b, c.buf)
|
||||
bufpool.Put(&c.buf)
|
||||
c.buf = nil
|
||||
})
|
||||
|
||||
if n == 0 {
|
||||
n, err = c.Conn.Read(b)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *redirConn) Write(b []byte) (n int, err error) {
|
||||
if c.ttl > 0 {
|
||||
c.SetWriteDeadline(time.Now().Add(c.ttl))
|
||||
defer c.SetWriteDeadline(time.Time{})
|
||||
}
|
||||
return c.Conn.Write(b)
|
||||
}
|
68
listener/redirect/udp/listener.go
Normal file
68
listener/redirect/udp/listener.go
Normal file
@ -0,0 +1,68 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/listener"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.ListenerRegistry().Register("redu", NewListener)
|
||||
}
|
||||
|
||||
type redirectListener struct {
|
||||
ln *net.UDPConn
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options listener.Options
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return &redirectListener{
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *redirectListener) Init(md md.Metadata) (err error) {
|
||||
if err = l.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
laddr, err := net.ResolveUDPAddr("udp", l.options.Addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ln, err := l.listenUDP(laddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.ln = ln
|
||||
return
|
||||
}
|
||||
|
||||
func (l *redirectListener) Accept() (conn net.Conn, err error) {
|
||||
conn, err = l.accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// conn = metrics.WrapConn(l.options.Service, conn)
|
||||
return
|
||||
}
|
||||
|
||||
func (l *redirectListener) Addr() net.Addr {
|
||||
return l.ln.LocalAddr()
|
||||
}
|
||||
|
||||
func (l *redirectListener) Close() error {
|
||||
return l.ln.Close()
|
||||
}
|
37
listener/redirect/udp/listener_linux.go
Normal file
37
listener/redirect/udp/listener_linux.go
Normal file
@ -0,0 +1,37 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/LiamHaworth/go-tproxy"
|
||||
"github.com/go-gost/gost/v3/pkg/common/bufpool"
|
||||
)
|
||||
|
||||
func (l *redirectListener) listenUDP(addr *net.UDPAddr) (*net.UDPConn, error) {
|
||||
return tproxy.ListenUDP("udp", addr)
|
||||
}
|
||||
|
||||
func (l *redirectListener) accept() (conn net.Conn, err error) {
|
||||
b := bufpool.Get(l.md.readBufferSize)
|
||||
|
||||
n, raddr, dstAddr, err := tproxy.ReadFromUDP(l.ln, *b)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
l.logger.Infof("%s >> %s", raddr.String(), dstAddr.String())
|
||||
|
||||
c, err := tproxy.DialUDP("udp", dstAddr, raddr)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
conn = &redirConn{
|
||||
Conn: c,
|
||||
buf: (*b)[:n],
|
||||
ttl: l.md.ttl,
|
||||
}
|
||||
return
|
||||
}
|
16
listener/redirect/udp/listener_other.go
Normal file
16
listener/redirect/udp/listener_other.go
Normal file
@ -0,0 +1,16 @@
|
||||
//go:build !linux
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
func (l *redirectListener) listenUDP(addr *net.UDPAddr) (*net.UDPConn, error) {
|
||||
return nil, errors.New("UDP redirect is not available on non-linux platform")
|
||||
}
|
||||
|
||||
func (l *redirectListener) accept() (conn net.Conn, err error) {
|
||||
return nil, errors.New("UDP redirect is not available on non-linux platform")
|
||||
}
|
36
listener/redirect/udp/metadata.go
Normal file
36
listener/redirect/udp/metadata.go
Normal file
@ -0,0 +1,36 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTTL = 60 * time.Second
|
||||
defaultReadBufferSize = 1024
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
ttl time.Duration
|
||||
readBufferSize int
|
||||
}
|
||||
|
||||
func (l *redirectListener) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
ttl = "ttl"
|
||||
readBufferSize = "readBufferSize"
|
||||
)
|
||||
|
||||
l.md.ttl = mdata.GetDuration(md, ttl)
|
||||
if l.md.ttl <= 0 {
|
||||
l.md.ttl = defaultTTL
|
||||
}
|
||||
|
||||
l.md.readBufferSize = mdata.GetInt(md, readBufferSize)
|
||||
if l.md.readBufferSize <= 0 {
|
||||
l.md.readBufferSize = defaultReadBufferSize
|
||||
}
|
||||
|
||||
return
|
||||
}
|
148
listener/ssh/listener.go
Normal file
148
listener/ssh/listener.go
Normal file
@ -0,0 +1,148 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/common/admission"
|
||||
"github.com/go-gost/gost/v3/pkg/common/metrics"
|
||||
"github.com/go-gost/gost/v3/pkg/listener"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
ssh_util "github.com/go-gost/x/internal/util/ssh"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.ListenerRegistry().Register("ssh", NewListener)
|
||||
}
|
||||
|
||||
type sshListener struct {
|
||||
net.Listener
|
||||
config *ssh.ServerConfig
|
||||
cqueue chan net.Conn
|
||||
errChan chan error
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options listener.Options
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return &sshListener{
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *sshListener) Init(md md.Metadata) (err error) {
|
||||
if err = l.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", l.options.Addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ln = metrics.WrapListener(l.options.Service, ln)
|
||||
ln = admission.WrapListener(l.options.Admission, ln)
|
||||
l.Listener = ln
|
||||
|
||||
config := &ssh.ServerConfig{
|
||||
PasswordCallback: ssh_util.PasswordCallback(l.options.Auther),
|
||||
PublicKeyCallback: ssh_util.PublicKeyCallback(l.md.authorizedKeys),
|
||||
}
|
||||
config.AddHostKey(l.md.signer)
|
||||
if l.options.Auther == nil && len(l.md.authorizedKeys) == 0 {
|
||||
config.NoClientAuth = true
|
||||
}
|
||||
|
||||
l.config = config
|
||||
l.cqueue = make(chan net.Conn, l.md.backlog)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
go l.listenLoop()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *sshListener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.cqueue:
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *sshListener) listenLoop() {
|
||||
for {
|
||||
conn, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
l.logger.Error("accept:", err)
|
||||
l.errChan <- err
|
||||
close(l.errChan)
|
||||
return
|
||||
}
|
||||
go l.serveConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *sshListener) serveConn(conn net.Conn) {
|
||||
start := time.Now()
|
||||
l.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
l.logger.WithFields(map[string]any{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
|
||||
sc, chans, reqs, err := ssh.NewServerConn(conn, l.config)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
defer sc.Close()
|
||||
|
||||
go ssh.DiscardRequests(reqs)
|
||||
go func() {
|
||||
for newChannel := range chans {
|
||||
// Check the type of channel
|
||||
t := newChannel.ChannelType()
|
||||
switch t {
|
||||
case ssh_util.GostSSHTunnelRequest:
|
||||
channel, requests, err := newChannel.Accept()
|
||||
if err != nil {
|
||||
l.logger.Warnf("could not accept channel: %s", err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
go ssh.DiscardRequests(requests)
|
||||
cc := ssh_util.NewConn(conn, channel)
|
||||
select {
|
||||
case l.cqueue <- cc:
|
||||
default:
|
||||
l.logger.Warnf("connection queue is full, client %s discarded", conn.RemoteAddr())
|
||||
newChannel.Reject(ssh.ResourceShortage, "connection queue is full")
|
||||
cc.Close()
|
||||
}
|
||||
|
||||
default:
|
||||
l.logger.Warnf("unsupported channel type: %s", t)
|
||||
newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unsupported channel type: %s", t))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
sc.Wait()
|
||||
}
|
68
listener/ssh/metadata.go
Normal file
68
listener/ssh/metadata.go
Normal file
@ -0,0 +1,68 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
|
||||
tls_util "github.com/go-gost/gost/v3/pkg/common/util/tls"
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
ssh_util "github.com/go-gost/x/internal/util/ssh"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultBacklog = 128
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
signer ssh.Signer
|
||||
authorizedKeys map[string]bool
|
||||
backlog int
|
||||
}
|
||||
|
||||
func (l *sshListener) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
authorizedKeys = "authorizedKeys"
|
||||
privateKeyFile = "privateKeyFile"
|
||||
passphrase = "passphrase"
|
||||
backlog = "backlog"
|
||||
)
|
||||
|
||||
if key := mdata.GetString(md, privateKeyFile); key != "" {
|
||||
data, err := ioutil.ReadFile(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pp := mdata.GetString(md, passphrase)
|
||||
if pp == "" {
|
||||
l.md.signer, err = ssh.ParsePrivateKey(data)
|
||||
} else {
|
||||
l.md.signer, err = ssh.ParsePrivateKeyWithPassphrase(data, []byte(pp))
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if l.md.signer == nil {
|
||||
signer, err := ssh.NewSignerFromKey(tls_util.DefaultConfig.Clone().Certificates[0].PrivateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l.md.signer = signer
|
||||
}
|
||||
|
||||
if name := mdata.GetString(md, authorizedKeys); name != "" {
|
||||
m, err := ssh_util.ParseAuthorizedKeysFile(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l.md.authorizedKeys = m
|
||||
}
|
||||
|
||||
l.md.backlog = mdata.GetInt(md, backlog)
|
||||
if l.md.backlog <= 0 {
|
||||
l.md.backlog = defaultBacklog
|
||||
}
|
||||
|
||||
return
|
||||
}
|
199
listener/sshd/listener.go
Normal file
199
listener/sshd/listener.go
Normal file
@ -0,0 +1,199 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/common/admission"
|
||||
"github.com/go-gost/gost/v3/pkg/common/metrics"
|
||||
"github.com/go-gost/gost/v3/pkg/listener"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
ssh_util "github.com/go-gost/x/internal/util/ssh"
|
||||
sshd_util "github.com/go-gost/x/internal/util/sshd"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// Applicable SSH Request types for Port Forwarding - RFC 4254 7.X
|
||||
const (
|
||||
DirectForwardRequest = "direct-tcpip" // RFC 4254 7.2
|
||||
RemoteForwardRequest = "tcpip-forward" // RFC 4254 7.1
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.ListenerRegistry().Register("sshd", NewListener)
|
||||
}
|
||||
|
||||
type sshdListener struct {
|
||||
net.Listener
|
||||
config *ssh.ServerConfig
|
||||
cqueue chan net.Conn
|
||||
errChan chan error
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options listener.Options
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return &sshdListener{
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *sshdListener) Init(md md.Metadata) (err error) {
|
||||
if err = l.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", l.options.Addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ln = metrics.WrapListener(l.options.Service, ln)
|
||||
ln = admission.WrapListener(l.options.Admission, ln)
|
||||
l.Listener = ln
|
||||
|
||||
config := &ssh.ServerConfig{
|
||||
PasswordCallback: ssh_util.PasswordCallback(l.options.Auther),
|
||||
PublicKeyCallback: ssh_util.PublicKeyCallback(l.md.authorizedKeys),
|
||||
}
|
||||
config.AddHostKey(l.md.signer)
|
||||
if l.options.Auther == nil && len(l.md.authorizedKeys) == 0 {
|
||||
config.NoClientAuth = true
|
||||
}
|
||||
|
||||
l.config = config
|
||||
l.cqueue = make(chan net.Conn, l.md.backlog)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
go l.listenLoop()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *sshdListener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.cqueue:
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *sshdListener) listenLoop() {
|
||||
for {
|
||||
conn, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
l.logger.Error("accept:", err)
|
||||
l.errChan <- err
|
||||
close(l.errChan)
|
||||
return
|
||||
}
|
||||
go l.serveConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *sshdListener) serveConn(conn net.Conn) {
|
||||
start := time.Now()
|
||||
l.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
l.logger.WithFields(map[string]any{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
|
||||
sc, chans, reqs, err := ssh.NewServerConn(conn, l.config)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
defer sc.Close()
|
||||
|
||||
go func() {
|
||||
for newChannel := range chans {
|
||||
// Check the type of channel
|
||||
t := newChannel.ChannelType()
|
||||
switch t {
|
||||
case DirectForwardRequest:
|
||||
channel, requests, err := newChannel.Accept()
|
||||
if err != nil {
|
||||
l.logger.Warnf("could not accept channel: %s", err.Error())
|
||||
continue
|
||||
}
|
||||
p := directForward{}
|
||||
ssh.Unmarshal(newChannel.ExtraData(), &p)
|
||||
|
||||
l.logger.Debug(p.String())
|
||||
|
||||
if p.Host1 == "<nil>" {
|
||||
p.Host1 = ""
|
||||
}
|
||||
|
||||
go ssh.DiscardRequests(requests)
|
||||
cc := sshd_util.NewDirectForwardConn(sc, channel, net.JoinHostPort(p.Host1, strconv.Itoa(int(p.Port1))))
|
||||
|
||||
select {
|
||||
case l.cqueue <- cc:
|
||||
default:
|
||||
l.logger.Warnf("connection queue is full, client %s discarded", conn.RemoteAddr())
|
||||
newChannel.Reject(ssh.ResourceShortage, "connection queue is full")
|
||||
cc.Close()
|
||||
}
|
||||
|
||||
default:
|
||||
l.logger.Warnf("unsupported channel type: %s", t)
|
||||
newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unsupported channel type: %s", t))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
for req := range reqs {
|
||||
switch req.Type {
|
||||
case RemoteForwardRequest:
|
||||
cc := sshd_util.NewRemoteForwardConn(ctx, sc, req)
|
||||
|
||||
select {
|
||||
case l.cqueue <- cc:
|
||||
default:
|
||||
l.logger.Warnf("connection queue is full, client %s discarded", conn.RemoteAddr())
|
||||
req.Reply(false, []byte("connection queue is full"))
|
||||
cc.Close()
|
||||
}
|
||||
default:
|
||||
l.logger.Warnf("unsupported request type: %s, want reply: %v", req.Type, req.WantReply)
|
||||
req.Reply(false, nil)
|
||||
}
|
||||
}
|
||||
}()
|
||||
sc.Wait()
|
||||
}
|
||||
|
||||
// directForward is structure for RFC 4254 7.2 - can be used for "forwarded-tcpip" and "direct-tcpip"
|
||||
type directForward struct {
|
||||
Host1 string
|
||||
Port1 uint32
|
||||
Host2 string
|
||||
Port2 uint32
|
||||
}
|
||||
|
||||
func (p directForward) String() string {
|
||||
return fmt.Sprintf("%s:%d -> %s:%d", p.Host2, p.Port2, p.Host1, p.Port1)
|
||||
}
|
68
listener/sshd/metadata.go
Normal file
68
listener/sshd/metadata.go
Normal file
@ -0,0 +1,68 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
|
||||
tls_util "github.com/go-gost/gost/v3/pkg/common/util/tls"
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
ssh_util "github.com/go-gost/x/internal/util/ssh"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultBacklog = 128
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
signer ssh.Signer
|
||||
authorizedKeys map[string]bool
|
||||
backlog int
|
||||
}
|
||||
|
||||
func (l *sshdListener) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
authorizedKeys = "authorizedKeys"
|
||||
privateKeyFile = "privateKeyFile"
|
||||
passphrase = "passphrase"
|
||||
backlog = "backlog"
|
||||
)
|
||||
|
||||
if key := mdata.GetString(md, privateKeyFile); key != "" {
|
||||
data, err := ioutil.ReadFile(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pp := mdata.GetString(md, passphrase)
|
||||
if pp == "" {
|
||||
l.md.signer, err = ssh.ParsePrivateKey(data)
|
||||
} else {
|
||||
l.md.signer, err = ssh.ParsePrivateKeyWithPassphrase(data, []byte(pp))
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if l.md.signer == nil {
|
||||
signer, err := ssh.NewSignerFromKey(tls_util.DefaultConfig.Clone().Certificates[0].PrivateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l.md.signer = signer
|
||||
}
|
||||
|
||||
if name := mdata.GetString(md, authorizedKeys); name != "" {
|
||||
m, err := ssh_util.ParseAuthorizedKeysFile(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l.md.authorizedKeys = m
|
||||
}
|
||||
|
||||
l.md.backlog = mdata.GetInt(md, backlog)
|
||||
if l.md.backlog <= 0 {
|
||||
l.md.backlog = defaultBacklog
|
||||
}
|
||||
|
||||
return
|
||||
}
|
96
listener/tap/listener.go
Normal file
96
listener/tap/listener.go
Normal file
@ -0,0 +1,96 @@
|
||||
package tap
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/listener"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
tap_util "github.com/go-gost/x/internal/util/tap"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.ListenerRegistry().Register("tap", NewListener)
|
||||
}
|
||||
|
||||
type tapListener struct {
|
||||
saddr string
|
||||
addr net.Addr
|
||||
cqueue chan net.Conn
|
||||
closed chan struct{}
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
options := &listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
return &tapListener{
|
||||
saddr: options.Addr,
|
||||
logger: options.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *tapListener) Init(md md.Metadata) (err error) {
|
||||
if err = l.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.addr, err = net.ResolveUDPAddr("udp", l.saddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ifce, ip, err := l.createTap()
|
||||
if err != nil {
|
||||
if ifce != nil {
|
||||
ifce.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
itf, err := net.InterfaceByName(ifce.Name())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
addrs, _ := itf.Addrs()
|
||||
l.logger.Infof("name: %s, mac: %s, mtu: %d, addrs: %s",
|
||||
itf.Name, itf.HardwareAddr, itf.MTU, addrs)
|
||||
|
||||
l.cqueue = make(chan net.Conn, 1)
|
||||
l.closed = make(chan struct{})
|
||||
|
||||
conn := tap_util.NewConn(l.md.config, ifce, l.addr, &net.IPAddr{IP: ip})
|
||||
|
||||
l.cqueue <- conn
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *tapListener) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case conn := <-l.cqueue:
|
||||
return conn, nil
|
||||
case <-l.closed:
|
||||
}
|
||||
|
||||
return nil, listener.ErrClosed
|
||||
}
|
||||
|
||||
func (l *tapListener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
||||
|
||||
func (l *tapListener) Close() error {
|
||||
select {
|
||||
case <-l.closed:
|
||||
return net.ErrClosed
|
||||
default:
|
||||
close(l.closed)
|
||||
}
|
||||
return nil
|
||||
}
|
44
listener/tap/metadata.go
Normal file
44
listener/tap/metadata.go
Normal file
@ -0,0 +1,44 @@
|
||||
package tap
|
||||
|
||||
import (
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
tap_util "github.com/go-gost/x/internal/util/tap"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultMTU = 1350
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
config *tap_util.Config
|
||||
}
|
||||
|
||||
func (l *tapListener) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
name = "name"
|
||||
netKey = "net"
|
||||
mtu = "mtu"
|
||||
routes = "routes"
|
||||
gateway = "gw"
|
||||
)
|
||||
|
||||
config := &tap_util.Config{
|
||||
Name: mdata.GetString(md, name),
|
||||
Net: mdata.GetString(md, netKey),
|
||||
MTU: mdata.GetInt(md, mtu),
|
||||
Gateway: mdata.GetString(md, gateway),
|
||||
}
|
||||
if config.MTU <= 0 {
|
||||
config.MTU = DefaultMTU
|
||||
}
|
||||
|
||||
for _, s := range mdata.GetStrings(md, routes) {
|
||||
if s != "" {
|
||||
config.Routes = append(config.Routes, s)
|
||||
}
|
||||
}
|
||||
|
||||
l.md.config = config
|
||||
|
||||
return
|
||||
}
|
13
listener/tap/tap_darwin.go
Normal file
13
listener/tap/tap_darwin.go
Normal file
@ -0,0 +1,13 @@
|
||||
package tap
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"github.com/songgao/water"
|
||||
)
|
||||
|
||||
func (l *tapListener) createTap() (ifce *water.Interface, ip net.IP, err error) {
|
||||
err = errors.New("tap is not supported on darwin")
|
||||
return
|
||||
}
|
69
listener/tap/tap_linux.go
Normal file
69
listener/tap/tap_linux.go
Normal file
@ -0,0 +1,69 @@
|
||||
package tap
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/docker/libcontainer/netlink"
|
||||
"github.com/milosgajdos/tenus"
|
||||
"github.com/songgao/water"
|
||||
)
|
||||
|
||||
func (l *tapListener) createTap() (ifce *water.Interface, ip net.IP, err error) {
|
||||
var ipNet *net.IPNet
|
||||
if l.md.config.Net != "" {
|
||||
ip, ipNet, err = net.ParseCIDR(l.md.config.Net)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ifce, err = water.New(water.Config{
|
||||
DeviceType: water.TAP,
|
||||
PlatformSpecificParams: water.PlatformSpecificParams{
|
||||
Name: l.md.config.Name,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
link, err := tenus.NewLinkFrom(ifce.Name())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.logger.Debugf("ip link set dev %s mtu %d", ifce.Name(), l.md.config.MTU)
|
||||
|
||||
if err = link.SetLinkMTU(l.md.config.MTU); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if l.md.config.Net != "" {
|
||||
l.logger.Debugf("ip address add %s dev %s", l.md.config.Net, ifce.Name())
|
||||
|
||||
if err = link.SetLinkIp(ip, ipNet); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
l.logger.Debugf("ip link set dev %s up", ifce.Name())
|
||||
if err = link.SetLinkUp(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = l.addRoutes(ifce.Name(), l.md.config.Gateway, l.md.config.Routes...); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *tapListener) addRoutes(ifName string, gw string, routes ...string) error {
|
||||
for _, route := range routes {
|
||||
l.logger.Debugf("ip route add %s via %s dev %s", route, gw, ifName)
|
||||
if err := netlink.AddRoute(route, "", gw, ifName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
61
listener/tap/tap_unix.go
Normal file
61
listener/tap/tap_unix.go
Normal file
@ -0,0 +1,61 @@
|
||||
//go:build !linux && !windows && !darwin
|
||||
|
||||
package tap
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/songgao/water"
|
||||
)
|
||||
|
||||
func (l *tapListener) createTap() (ifce *water.Interface, ip net.IP, err error) {
|
||||
ip, _, _ = net.ParseCIDR(l.md.config.Net)
|
||||
|
||||
ifce, err = water.New(water.Config{
|
||||
DeviceType: water.TAP,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var cmd string
|
||||
if l.md.config.Net != "" {
|
||||
cmd = fmt.Sprintf("ifconfig %s inet %s mtu %d up", ifce.Name(), l.md.config.Net, l.md.config.MTU)
|
||||
} else {
|
||||
cmd = fmt.Sprintf("ifconfig %s mtu %d up", ifce.Name(), l.md.config.MTU)
|
||||
}
|
||||
l.logger.Debug(cmd)
|
||||
|
||||
args := strings.Split(cmd, " ")
|
||||
if er := exec.Command(args[0], args[1:]...).Run(); er != nil {
|
||||
err = fmt.Errorf("%s: %v", cmd, er)
|
||||
return
|
||||
}
|
||||
|
||||
if err = l.addRoutes(ifce.Name(), l.md.config.Gateway, l.md.config.Routes...); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *tapListener) addRoutes(ifName string, gw string, routes ...string) error {
|
||||
for _, route := range routes {
|
||||
if route == "" {
|
||||
continue
|
||||
}
|
||||
cmd := fmt.Sprintf("route add -net %s dev %s", route, ifName)
|
||||
if gw != "" {
|
||||
cmd += " gw " + gw
|
||||
}
|
||||
l.logger.Debug(cmd)
|
||||
args := strings.Split(cmd, " ")
|
||||
if er := exec.Command(args[0], args[1:]...).Run(); er != nil {
|
||||
return fmt.Errorf("%s: %v", cmd, er)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
75
listener/tap/tap_windows.go
Normal file
75
listener/tap/tap_windows.go
Normal file
@ -0,0 +1,75 @@
|
||||
package tap
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/songgao/water"
|
||||
)
|
||||
|
||||
func (l *tapListener) createTap() (ifce *water.Interface, ip net.IP, err error) {
|
||||
ip, ipNet, _ := net.ParseCIDR(l.md.config.Net)
|
||||
|
||||
ifce, err = water.New(water.Config{
|
||||
DeviceType: water.TAP,
|
||||
PlatformSpecificParams: water.PlatformSpecificParams{
|
||||
ComponentID: "tap0901",
|
||||
InterfaceName: l.md.config.Name,
|
||||
Network: l.md.config.Net,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if ip != nil && ipNet != nil {
|
||||
cmd := fmt.Sprintf("netsh interface ip set address name=%s "+
|
||||
"source=static addr=%s mask=%s gateway=none",
|
||||
ifce.Name(), ip.String(), ipMask(ipNet.Mask))
|
||||
l.logger.Debug(cmd)
|
||||
|
||||
args := strings.Split(cmd, " ")
|
||||
if er := exec.Command(args[0], args[1:]...).Run(); er != nil {
|
||||
err = fmt.Errorf("%s: %v", cmd, er)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err = l.addRoutes(ifce.Name(), l.md.config.Gateway, l.md.config.Routes...); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *tapListener) addRoutes(ifName string, gw string, routes ...string) error {
|
||||
for _, route := range routes {
|
||||
l.deleteRoute(ifName, route)
|
||||
|
||||
cmd := fmt.Sprintf("netsh interface ip add route prefix=%s interface=%s store=active",
|
||||
route, ifName)
|
||||
if gw != "" {
|
||||
cmd += " nexthop=" + gw
|
||||
}
|
||||
l.logger.Debug(cmd)
|
||||
args := strings.Split(cmd, " ")
|
||||
if er := exec.Command(args[0], args[1:]...).Run(); er != nil {
|
||||
return fmt.Errorf("%s: %v", cmd, er)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *tapListener) deleteRoute(ifName string, route string) error {
|
||||
cmd := fmt.Sprintf("netsh interface ip delete route prefix=%s interface=%s store=active",
|
||||
route, ifName)
|
||||
l.logger.Debug(cmd)
|
||||
args := strings.Split(cmd, " ")
|
||||
return exec.Command(args[0], args[1:]...).Run()
|
||||
}
|
||||
|
||||
func ipMask(mask net.IPMask) string {
|
||||
return fmt.Sprintf("%d.%d.%d.%d", mask[0], mask[1], mask[2], mask[3])
|
||||
}
|
96
listener/tun/listener.go
Normal file
96
listener/tun/listener.go
Normal file
@ -0,0 +1,96 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/listener"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
tun_util "github.com/go-gost/x/internal/util/tun"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.ListenerRegistry().Register("tun", NewListener)
|
||||
}
|
||||
|
||||
type tunListener struct {
|
||||
addr net.Addr
|
||||
cqueue chan net.Conn
|
||||
closed chan struct{}
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options listener.Options
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return &tunListener{
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *tunListener) Init(md md.Metadata) (err error) {
|
||||
if err = l.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.addr, err = net.ResolveUDPAddr("udp", l.options.Addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ifce, ip, err := l.createTun()
|
||||
if err != nil {
|
||||
if ifce != nil {
|
||||
ifce.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
itf, err := net.InterfaceByName(ifce.Name())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
addrs, _ := itf.Addrs()
|
||||
l.logger.Infof("name: %s, net: %s, mtu: %d, addrs: %s",
|
||||
itf.Name, ip, itf.MTU, addrs)
|
||||
|
||||
l.cqueue = make(chan net.Conn, 1)
|
||||
l.closed = make(chan struct{})
|
||||
|
||||
conn := tun_util.NewConn(l.md.config, ifce, l.addr, &net.IPAddr{IP: ip})
|
||||
|
||||
l.cqueue <- conn
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *tunListener) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case conn := <-l.cqueue:
|
||||
return conn, nil
|
||||
case <-l.closed:
|
||||
}
|
||||
|
||||
return nil, listener.ErrClosed
|
||||
}
|
||||
|
||||
func (l *tunListener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
||||
|
||||
func (l *tunListener) Close() error {
|
||||
select {
|
||||
case <-l.closed:
|
||||
return net.ErrClosed
|
||||
default:
|
||||
close(l.closed)
|
||||
}
|
||||
return nil
|
||||
}
|
63
listener/tun/metadata.go
Normal file
63
listener/tun/metadata.go
Normal file
@ -0,0 +1,63 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
tun_util "github.com/go-gost/x/internal/util/tun"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultMTU = 1350
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
config *tun_util.Config
|
||||
}
|
||||
|
||||
func (l *tunListener) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
name = "name"
|
||||
netKey = "net"
|
||||
peer = "peer"
|
||||
mtu = "mtu"
|
||||
routes = "routes"
|
||||
gateway = "gw"
|
||||
)
|
||||
|
||||
config := &tun_util.Config{
|
||||
Name: mdata.GetString(md, name),
|
||||
Net: mdata.GetString(md, netKey),
|
||||
Peer: mdata.GetString(md, peer),
|
||||
MTU: mdata.GetInt(md, mtu),
|
||||
Gateway: mdata.GetString(md, gateway),
|
||||
}
|
||||
if config.MTU <= 0 {
|
||||
config.MTU = DefaultMTU
|
||||
}
|
||||
|
||||
gw := net.ParseIP(config.Gateway)
|
||||
|
||||
for _, s := range mdata.GetStrings(md, routes) {
|
||||
ss := strings.SplitN(s, " ", 2)
|
||||
if len(ss) == 2 {
|
||||
var route tun_util.Route
|
||||
_, ipNet, _ := net.ParseCIDR(strings.TrimSpace(ss[0]))
|
||||
if ipNet == nil {
|
||||
continue
|
||||
}
|
||||
route.Net = *ipNet
|
||||
route.Gateway = net.ParseIP(ss[1])
|
||||
if route.Gateway == nil {
|
||||
route.Gateway = gw
|
||||
}
|
||||
|
||||
config.Routes = append(config.Routes, route)
|
||||
}
|
||||
}
|
||||
|
||||
l.md.config = config
|
||||
|
||||
return
|
||||
}
|
57
listener/tun/tun_darwin.go
Normal file
57
listener/tun/tun_darwin.go
Normal file
@ -0,0 +1,57 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
tun_util "github.com/go-gost/x/internal/util/tun"
|
||||
"github.com/songgao/water"
|
||||
)
|
||||
|
||||
func (l *tunListener) createTun() (ifce *water.Interface, ip net.IP, err error) {
|
||||
ip, _, err = net.ParseCIDR(l.md.config.Net)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ifce, err = water.New(water.Config{
|
||||
DeviceType: water.TUN,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
peer := l.md.config.Peer
|
||||
if peer == "" {
|
||||
peer = ip.String()
|
||||
}
|
||||
|
||||
cmd := fmt.Sprintf("ifconfig %s inet %s %s mtu %d up",
|
||||
ifce.Name(), l.md.config.Net, l.md.config.Peer, l.md.config.MTU)
|
||||
l.logger.Debug(cmd)
|
||||
|
||||
args := strings.Split(cmd, " ")
|
||||
if err = exec.Command(args[0], args[1:]...).Run(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = l.addRoutes(ifce.Name(), l.md.config.Routes...); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *tunListener) addRoutes(ifName string, routes ...tun_util.Route) error {
|
||||
for _, route := range routes {
|
||||
cmd := fmt.Sprintf("route add -net %s -interface %s", route.Net.String(), ifName)
|
||||
l.logger.Debug(cmd)
|
||||
args := strings.Split(cmd, " ")
|
||||
if err := exec.Command(args[0], args[1:]...).Run(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
67
listener/tun/tun_linux.go
Normal file
67
listener/tun/tun_linux.go
Normal file
@ -0,0 +1,67 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"syscall"
|
||||
|
||||
"github.com/docker/libcontainer/netlink"
|
||||
tun_util "github.com/go-gost/x/internal/util/tun"
|
||||
"github.com/milosgajdos/tenus"
|
||||
"github.com/songgao/water"
|
||||
)
|
||||
|
||||
func (l *tunListener) createTun() (ifce *water.Interface, ip net.IP, err error) {
|
||||
ip, ipNet, err := net.ParseCIDR(l.md.config.Net)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ifce, err = water.New(water.Config{
|
||||
DeviceType: water.TUN,
|
||||
PlatformSpecificParams: water.PlatformSpecificParams{
|
||||
Name: l.md.config.Name,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
link, err := tenus.NewLinkFrom(ifce.Name())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.logger.Debugf("ip link set dev %s mtu %d", ifce.Name(), l.md.config.MTU)
|
||||
|
||||
if err = link.SetLinkMTU(l.md.config.MTU); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.logger.Debugf("ip address add %s dev %s", l.md.config.Net, ifce.Name())
|
||||
|
||||
if err = link.SetLinkIp(ip, ipNet); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.logger.Debugf("ip link set dev %s up", ifce.Name())
|
||||
if err = link.SetLinkUp(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = l.addRoutes(ifce.Name(), l.md.config.Routes...); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *tunListener) addRoutes(ifName string, routes ...tun_util.Route) error {
|
||||
for _, route := range routes {
|
||||
l.logger.Debugf("ip route add %s dev %s", route.Net.String(), ifName)
|
||||
if err := netlink.AddRoute(route.Net.String(), "", "", ifName); err != nil && !errors.Is(err, syscall.EEXIST) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
55
listener/tun/tun_unix.go
Normal file
55
listener/tun/tun_unix.go
Normal file
@ -0,0 +1,55 @@
|
||||
//go:build !linux && !windows && !darwin
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
tun_util "github.com/go-gost/x/internal/util/tun"
|
||||
"github.com/songgao/water"
|
||||
)
|
||||
|
||||
func (l *tunListener) createTun() (ifce *water.Interface, ip net.IP, err error) {
|
||||
ip, _, err = net.ParseCIDR(l.md.config.Net)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ifce, err = water.New(water.Config{
|
||||
DeviceType: water.TUN,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cmd := fmt.Sprintf("ifconfig %s inet %s mtu %d up",
|
||||
ifce.Name(), l.md.config.Net, l.md.config.MTU)
|
||||
l.logger.Debug(cmd)
|
||||
|
||||
args := strings.Split(cmd, " ")
|
||||
if er := exec.Command(args[0], args[1:]...).Run(); er != nil {
|
||||
err = fmt.Errorf("%s: %v", cmd, er)
|
||||
return
|
||||
}
|
||||
|
||||
if err = l.addRoutes(ifce.Name(), l.md.config.Routes...); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *tunListener) addRoutes(ifName string, routes ...tun_util.Route) error {
|
||||
for _, route := range routes {
|
||||
cmd := fmt.Sprintf("route add -net %s -interface %s", route.Net.String(), ifName)
|
||||
l.logger.Debug(cmd)
|
||||
args := strings.Split(cmd, " ")
|
||||
if er := exec.Command(args[0], args[1:]...).Run(); er != nil {
|
||||
return fmt.Errorf("%s: %v", cmd, er)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
77
listener/tun/tun_windows.go
Normal file
77
listener/tun/tun_windows.go
Normal file
@ -0,0 +1,77 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
tun_util "github.com/go-gost/x/internal/util/tun"
|
||||
"github.com/songgao/water"
|
||||
)
|
||||
|
||||
func (l *tunListener) createTun() (ifce *water.Interface, ip net.IP, err error) {
|
||||
ip, ipNet, err := net.ParseCIDR(l.md.config.Net)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ifce, err = water.New(water.Config{
|
||||
DeviceType: water.TUN,
|
||||
PlatformSpecificParams: water.PlatformSpecificParams{
|
||||
ComponentID: "tap0901",
|
||||
InterfaceName: l.md.config.Name,
|
||||
Network: l.md.config.Net,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cmd := fmt.Sprintf("netsh interface ip set address name=%s "+
|
||||
"source=static addr=%s mask=%s gateway=none",
|
||||
ifce.Name(), ip.String(), ipMask(ipNet.Mask))
|
||||
l.logger.Debug(cmd)
|
||||
|
||||
args := strings.Split(cmd, " ")
|
||||
if er := exec.Command(args[0], args[1:]...).Run(); er != nil {
|
||||
err = fmt.Errorf("%s: %v", cmd, er)
|
||||
return
|
||||
}
|
||||
|
||||
if err = l.addRoutes(ifce.Name(), l.md.config.Gateway, l.md.config.Routes...); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *tunListener) addRoutes(ifName string, gw string, routes ...tun_util.Route) error {
|
||||
for _, route := range routes {
|
||||
l.deleteRoute(ifName, route.Net.String())
|
||||
|
||||
cmd := fmt.Sprintf("netsh interface ip add route prefix=%s interface=%s store=active",
|
||||
route.Net.String(), ifName)
|
||||
if gw != "" {
|
||||
cmd += " nexthop=" + gw
|
||||
}
|
||||
l.logger.Debug(cmd)
|
||||
args := strings.Split(cmd, " ")
|
||||
if er := exec.Command(args[0], args[1:]...).Run(); er != nil {
|
||||
return fmt.Errorf("%s: %v", cmd, er)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *tunListener) deleteRoute(ifName string, route string) error {
|
||||
cmd := fmt.Sprintf("netsh interface ip delete route prefix=%s interface=%s store=active",
|
||||
route, ifName)
|
||||
l.logger.Debug(cmd)
|
||||
args := strings.Split(cmd, " ")
|
||||
return exec.Command(args[0], args[1:]...).Run()
|
||||
}
|
||||
|
||||
func ipMask(mask net.IPMask) string {
|
||||
return fmt.Sprintf("%d.%d.%d.%d", mask[0], mask[1], mask[2], mask[3])
|
||||
}
|
149
listener/ws/listener.go
Normal file
149
listener/ws/listener.go
Normal file
@ -0,0 +1,149 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/common/admission"
|
||||
"github.com/go-gost/gost/v3/pkg/common/metrics"
|
||||
"github.com/go-gost/gost/v3/pkg/listener"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
ws_util "github.com/go-gost/x/internal/util/ws"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.ListenerRegistry().Register("ws", NewListener)
|
||||
registry.ListenerRegistry().Register("wss", NewTLSListener)
|
||||
}
|
||||
|
||||
type wsListener struct {
|
||||
addr net.Addr
|
||||
upgrader *websocket.Upgrader
|
||||
srv *http.Server
|
||||
tlsEnabled bool
|
||||
cqueue chan net.Conn
|
||||
errChan chan error
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options listener.Options
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return &wsListener{
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func NewTLSListener(opts ...listener.Option) listener.Listener {
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return &wsListener{
|
||||
tlsEnabled: true,
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *wsListener) Init(md md.Metadata) (err error) {
|
||||
if err = l.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.upgrader = &websocket.Upgrader{
|
||||
HandshakeTimeout: l.md.handshakeTimeout,
|
||||
ReadBufferSize: l.md.readBufferSize,
|
||||
WriteBufferSize: l.md.writeBufferSize,
|
||||
EnableCompression: l.md.enableCompression,
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle(l.md.path, http.HandlerFunc(l.upgrade))
|
||||
l.srv = &http.Server{
|
||||
Addr: l.options.Addr,
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: l.md.readHeaderTimeout,
|
||||
}
|
||||
|
||||
l.cqueue = make(chan net.Conn, l.md.backlog)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
ln, err := net.Listen("tcp", l.options.Addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ln = metrics.WrapListener(l.options.Service, ln)
|
||||
ln = admission.WrapListener(l.options.Admission, ln)
|
||||
|
||||
if l.tlsEnabled {
|
||||
ln = tls.NewListener(ln, l.options.TLSConfig)
|
||||
}
|
||||
|
||||
l.addr = ln.Addr()
|
||||
|
||||
go func() {
|
||||
err := l.srv.Serve(ln)
|
||||
if err != nil {
|
||||
l.errChan <- err
|
||||
}
|
||||
close(l.errChan)
|
||||
}()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *wsListener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.cqueue:
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *wsListener) Close() error {
|
||||
return l.srv.Close()
|
||||
}
|
||||
|
||||
func (l *wsListener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
||||
|
||||
func (l *wsListener) upgrade(w http.ResponseWriter, r *http.Request) {
|
||||
if l.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
log := l.logger.WithFields(map[string]any{
|
||||
"local": l.addr.String(),
|
||||
"remote": r.RemoteAddr,
|
||||
})
|
||||
dump, _ := httputil.DumpRequest(r, false)
|
||||
log.Debug(string(dump))
|
||||
}
|
||||
|
||||
conn, err := l.upgrader.Upgrade(w, r, l.md.header)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case l.cqueue <- ws_util.Conn(conn):
|
||||
default:
|
||||
conn.Close()
|
||||
l.logger.Warnf("connection queue is full, client %s discarded", conn.RemoteAddr())
|
||||
}
|
||||
}
|
66
listener/ws/metadata.go
Normal file
66
listener/ws/metadata.go
Normal file
@ -0,0 +1,66 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPath = "/ws"
|
||||
defaultBacklog = 128
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
path string
|
||||
backlog int
|
||||
|
||||
handshakeTimeout time.Duration
|
||||
readHeaderTimeout time.Duration
|
||||
readBufferSize int
|
||||
writeBufferSize int
|
||||
enableCompression bool
|
||||
|
||||
header http.Header
|
||||
}
|
||||
|
||||
func (l *wsListener) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
path = "path"
|
||||
backlog = "backlog"
|
||||
|
||||
handshakeTimeout = "handshakeTimeout"
|
||||
readHeaderTimeout = "readHeaderTimeout"
|
||||
readBufferSize = "readBufferSize"
|
||||
writeBufferSize = "writeBufferSize"
|
||||
enableCompression = "enableCompression"
|
||||
|
||||
header = "header"
|
||||
)
|
||||
|
||||
l.md.path = mdata.GetString(md, path)
|
||||
if l.md.path == "" {
|
||||
l.md.path = defaultPath
|
||||
}
|
||||
|
||||
l.md.backlog = mdata.GetInt(md, backlog)
|
||||
if l.md.backlog <= 0 {
|
||||
l.md.backlog = defaultBacklog
|
||||
}
|
||||
|
||||
l.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout)
|
||||
l.md.readHeaderTimeout = mdata.GetDuration(md, readHeaderTimeout)
|
||||
l.md.readBufferSize = mdata.GetInt(md, readBufferSize)
|
||||
l.md.writeBufferSize = mdata.GetInt(md, writeBufferSize)
|
||||
l.md.enableCompression = mdata.GetBool(md, enableCompression)
|
||||
|
||||
if mm := mdata.GetStringMapString(md, header); len(mm) > 0 {
|
||||
hd := http.Header{}
|
||||
for k, v := range mm {
|
||||
hd.Add(k, v)
|
||||
}
|
||||
l.md.header = hd
|
||||
}
|
||||
return
|
||||
}
|
Reference in New Issue
Block a user