mv non-core components to extended repo
This commit is contained in:
@ -1,285 +0,0 @@
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/go-gost/gost/pkg/common/bufpool"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
const (
|
||||
readBufferSize = 1500
|
||||
writeBufferSize = 1500
|
||||
magicNumber = 0x474F5354
|
||||
)
|
||||
|
||||
const (
|
||||
messageHeaderLen = 10
|
||||
)
|
||||
|
||||
const (
|
||||
FlagAck = 1
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidPacket = errors.New("icmp: invalid packet")
|
||||
ErrInvalidType = errors.New("icmp: invalid type")
|
||||
ErrShortBuffer = errors.New("icmp: short buffer")
|
||||
)
|
||||
|
||||
type message struct {
|
||||
// magic uint32 // magic number
|
||||
flags uint16 // flags
|
||||
// rsv uint16 // reserved field
|
||||
// len uint16 // length of data
|
||||
data []byte
|
||||
}
|
||||
|
||||
func (m *message) Encode(b []byte) (n int, err error) {
|
||||
if len(b) < messageHeaderLen+len(m.data) {
|
||||
err = ErrShortBuffer
|
||||
return
|
||||
}
|
||||
binary.BigEndian.PutUint32(b[:4], magicNumber) // magic number
|
||||
binary.BigEndian.PutUint16(b[4:6], m.flags) // flags
|
||||
binary.BigEndian.PutUint16(b[6:8], 0) // reserved
|
||||
binary.BigEndian.PutUint16(b[8:10], uint16(len(m.data)))
|
||||
copy(b[messageHeaderLen:], m.data)
|
||||
|
||||
n = messageHeaderLen + len(m.data)
|
||||
return
|
||||
}
|
||||
|
||||
func (m *message) Decode(b []byte) (n int, err error) {
|
||||
if len(b) < messageHeaderLen {
|
||||
err = ErrShortBuffer
|
||||
return
|
||||
}
|
||||
if binary.BigEndian.Uint32(b[:4]) != magicNumber {
|
||||
err = ErrInvalidPacket
|
||||
return
|
||||
}
|
||||
m.flags = binary.BigEndian.Uint16(b[4:6])
|
||||
length := binary.BigEndian.Uint16(b[8:10])
|
||||
if len(b[messageHeaderLen:]) < int(length) {
|
||||
err = ErrShortBuffer
|
||||
return
|
||||
}
|
||||
m.data = b[messageHeaderLen : messageHeaderLen+length]
|
||||
|
||||
n = messageHeaderLen + int(length)
|
||||
return
|
||||
}
|
||||
|
||||
type clientConn struct {
|
||||
net.PacketConn
|
||||
id int
|
||||
seq uint32
|
||||
}
|
||||
|
||||
func ClientConn(conn net.PacketConn, id int) net.PacketConn {
|
||||
return &clientConn{
|
||||
PacketConn: conn,
|
||||
id: id,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||
buf := bufpool.Get(readBufferSize)
|
||||
defer bufpool.Put(buf)
|
||||
|
||||
for {
|
||||
n, addr, err = c.PacketConn.ReadFrom(*buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
m, err := icmp.ParseMessage(1, (*buf)[:n])
|
||||
if err != nil {
|
||||
// logger.Default().Error("icmp: parse message %v", err)
|
||||
return 0, addr, err
|
||||
}
|
||||
echo, ok := m.Body.(*icmp.Echo)
|
||||
if !ok || m.Type != ipv4.ICMPTypeEchoReply {
|
||||
// logger.Default().Warnf("icmp: invalid type %s (discarded)", m.Type)
|
||||
continue // discard
|
||||
}
|
||||
|
||||
if echo.ID != c.id {
|
||||
// logger.Default().Warnf("icmp: id mismatch got %d, should be %d (discarded)", echo.ID, c.id)
|
||||
continue
|
||||
}
|
||||
|
||||
msg := message{}
|
||||
if _, err := msg.Decode(echo.Data); err != nil {
|
||||
logger.Default().Warn(err)
|
||||
continue
|
||||
}
|
||||
|
||||
if msg.flags&FlagAck == 0 {
|
||||
// logger.Default().Warn("icmp: invalid message (discarded)")
|
||||
continue
|
||||
}
|
||||
n = copy(b, msg.data)
|
||||
break
|
||||
}
|
||||
|
||||
if v, ok := addr.(*net.IPAddr); ok {
|
||||
addr = &net.UDPAddr{
|
||||
IP: v.IP,
|
||||
Port: c.id,
|
||||
}
|
||||
}
|
||||
// logger.Default().Infof("icmp: read from: %v %d", addr, n)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *clientConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
// logger.Default().Infof("icmp: write to: %v %d", addr, len(b))
|
||||
switch v := addr.(type) {
|
||||
case *net.UDPAddr:
|
||||
addr = &net.IPAddr{IP: v.IP}
|
||||
}
|
||||
|
||||
buf := bufpool.Get(writeBufferSize)
|
||||
defer bufpool.Put(buf)
|
||||
|
||||
msg := message{
|
||||
data: b,
|
||||
}
|
||||
nn, err := msg.Encode(*buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
echo := icmp.Echo{
|
||||
ID: c.id,
|
||||
Seq: int(atomic.AddUint32(&c.seq, 1)),
|
||||
Data: (*buf)[:nn],
|
||||
}
|
||||
m := icmp.Message{
|
||||
Type: ipv4.ICMPTypeEcho,
|
||||
Code: 0,
|
||||
Body: &echo,
|
||||
}
|
||||
wb, err := m.Marshal(nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = c.PacketConn.WriteTo(wb, addr)
|
||||
n = len(b)
|
||||
return
|
||||
}
|
||||
|
||||
type serverConn struct {
|
||||
net.PacketConn
|
||||
seqs [65535]uint32
|
||||
}
|
||||
|
||||
func ServerConn(conn net.PacketConn) net.PacketConn {
|
||||
return &serverConn{
|
||||
PacketConn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *serverConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||
buf := bufpool.Get(readBufferSize)
|
||||
defer bufpool.Put(buf)
|
||||
|
||||
for {
|
||||
n, addr, err = c.PacketConn.ReadFrom(*buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
m, err := icmp.ParseMessage(1, (*buf)[:n])
|
||||
if err != nil {
|
||||
// logger.Default().Error("icmp: parse message %v", err)
|
||||
return 0, addr, err
|
||||
}
|
||||
|
||||
echo, ok := m.Body.(*icmp.Echo)
|
||||
if !ok || m.Type != ipv4.ICMPTypeEcho || echo.ID <= 0 {
|
||||
// logger.Default().Warnf("icmp: invalid type %s (discarded)", m.Type)
|
||||
continue
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&c.seqs[uint16(echo.ID-1)], uint32(echo.Seq))
|
||||
|
||||
msg := message{}
|
||||
if _, err := msg.Decode(echo.Data); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if msg.flags&FlagAck > 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
n = copy(b, msg.data)
|
||||
|
||||
if v, ok := addr.(*net.IPAddr); ok {
|
||||
addr = &net.UDPAddr{
|
||||
IP: v.IP,
|
||||
Port: echo.ID,
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// logger.Default().Infof("icmp: read from: %v %d", addr, n)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *serverConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
// logger.Default().Infof("icmp: write to: %v %d", addr, len(b))
|
||||
var id int
|
||||
switch v := addr.(type) {
|
||||
case *net.UDPAddr:
|
||||
addr = &net.IPAddr{IP: v.IP}
|
||||
id = v.Port
|
||||
}
|
||||
|
||||
if id <= 0 || id > math.MaxUint16 {
|
||||
err = fmt.Errorf("icmp: invalid message id %v", addr)
|
||||
return
|
||||
}
|
||||
|
||||
buf := bufpool.Get(writeBufferSize)
|
||||
defer bufpool.Put(buf)
|
||||
|
||||
msg := message{
|
||||
flags: FlagAck,
|
||||
data: b,
|
||||
}
|
||||
nn, err := msg.Encode(*buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
echo := icmp.Echo{
|
||||
ID: id,
|
||||
Seq: int(atomic.LoadUint32(&c.seqs[id-1])),
|
||||
Data: (*buf)[:nn],
|
||||
}
|
||||
m := icmp.Message{
|
||||
Type: ipv4.ICMPTypeEchoReply,
|
||||
Code: 0,
|
||||
Body: &echo,
|
||||
}
|
||||
wb, err := m.Marshal(nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = c.PacketConn.WriteTo(wb, addr)
|
||||
n = len(b)
|
||||
return
|
||||
}
|
@ -1,105 +0,0 @@
|
||||
package pht
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
Host string
|
||||
Client *http.Client
|
||||
AuthorizePath string
|
||||
PushPath string
|
||||
PullPath string
|
||||
TLSEnabled bool
|
||||
Logger logger.Logger
|
||||
}
|
||||
|
||||
func (c *Client) Dial(ctx context.Context, addr string) (net.Conn, error) {
|
||||
raddr, err := net.ResolveTCPAddr("tcp", addr)
|
||||
if err != nil {
|
||||
c.Logger.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if c.Host != "" {
|
||||
addr = net.JoinHostPort(c.Host, strconv.Itoa(raddr.Port))
|
||||
}
|
||||
|
||||
token, err := c.authorize(ctx, addr)
|
||||
if err != nil {
|
||||
c.Logger.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cn := &clientConn{
|
||||
client: c.Client,
|
||||
rxc: make(chan []byte, 128),
|
||||
closed: make(chan struct{}),
|
||||
localAddr: &net.TCPAddr{},
|
||||
remoteAddr: raddr,
|
||||
logger: c.Logger,
|
||||
}
|
||||
|
||||
scheme := "http"
|
||||
if c.TLSEnabled {
|
||||
scheme = "https"
|
||||
}
|
||||
cn.pushURL = fmt.Sprintf("%s://%s%s?token=%s", scheme, addr, c.PushPath, token)
|
||||
cn.pullURL = fmt.Sprintf("%s://%s%s?token=%s", scheme, addr, c.PullPath, token)
|
||||
|
||||
go cn.readLoop()
|
||||
|
||||
return cn, nil
|
||||
}
|
||||
|
||||
func (c *Client) authorize(ctx context.Context, addr string) (token string, err error) {
|
||||
var url string
|
||||
if c.TLSEnabled {
|
||||
url = fmt.Sprintf("https://%s%s", addr, c.AuthorizePath)
|
||||
} else {
|
||||
url = fmt.Sprintf("http://%s%s", addr, c.AuthorizePath)
|
||||
}
|
||||
r, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if c.Logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpRequest(r, false)
|
||||
c.Logger.Debug(string(dump))
|
||||
}
|
||||
|
||||
resp, err := c.Client.Do(r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if c.Logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpResponse(resp, false)
|
||||
c.Logger.Debug(string(dump))
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if strings.HasPrefix(string(data), "token=") {
|
||||
token = strings.TrimPrefix(string(data), "token=")
|
||||
}
|
||||
if token == "" {
|
||||
err = errors.New("authorize failed")
|
||||
}
|
||||
return
|
||||
}
|
@ -1,176 +0,0 @@
|
||||
package pht
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
)
|
||||
|
||||
type clientConn struct {
|
||||
client *http.Client
|
||||
pushURL string
|
||||
pullURL string
|
||||
buf []byte
|
||||
rxc chan []byte
|
||||
closed chan struct{}
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func (c *clientConn) Read(b []byte) (n int, err error) {
|
||||
if len(c.buf) == 0 {
|
||||
select {
|
||||
case c.buf = <-c.rxc:
|
||||
case <-c.closed:
|
||||
err = net.ErrClosed
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
n = copy(b, c.buf)
|
||||
c.buf = c.buf[n:]
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *clientConn) Write(b []byte) (n int, err error) {
|
||||
if len(b) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
buf := bytes.NewBufferString(base64.StdEncoding.EncodeToString(b))
|
||||
buf.WriteByte('\n')
|
||||
|
||||
r, err := http.NewRequest(http.MethodPost, c.pushURL, buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if c.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpRequest(r, false)
|
||||
c.logger.Debug(string(dump))
|
||||
}
|
||||
|
||||
resp, err := c.client.Do(r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if c.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpResponse(resp, false)
|
||||
c.logger.Debug(string(dump))
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
err = errors.New(resp.Status)
|
||||
return
|
||||
}
|
||||
|
||||
n = len(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *clientConn) readLoop() {
|
||||
defer c.Close()
|
||||
|
||||
for {
|
||||
err := func() error {
|
||||
r, err := http.NewRequest(http.MethodGet, c.pullURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpRequest(r, false)
|
||||
c.logger.Debug(string(dump))
|
||||
}
|
||||
|
||||
resp, err := c.client.Do(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if c.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpResponse(resp, false)
|
||||
c.logger.Debug(string(dump))
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return errors.New(resp.Status)
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
b, err := base64.StdEncoding.DecodeString(scanner.Text())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
select {
|
||||
case c.rxc <- b:
|
||||
case <-c.closed:
|
||||
return net.ErrClosed
|
||||
}
|
||||
}
|
||||
|
||||
return scanner.Err()
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
c.logger.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientConn) LocalAddr() net.Addr {
|
||||
return c.localAddr
|
||||
}
|
||||
|
||||
func (c *clientConn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *clientConn) Close() error {
|
||||
select {
|
||||
case <-c.closed:
|
||||
default:
|
||||
close(c.closed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientConn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientConn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientConn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type serverConn struct {
|
||||
net.Conn
|
||||
remoteAddr net.Addr
|
||||
localAddr net.Addr
|
||||
}
|
||||
|
||||
func (c *serverConn) LocalAddr() net.Addr {
|
||||
return c.localAddr
|
||||
}
|
||||
|
||||
func (c *serverConn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
@ -1,344 +0,0 @@
|
||||
package pht
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/pkg/common/bufpool"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/http3"
|
||||
"github.com/rs/xid"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultBacklog = 128
|
||||
)
|
||||
|
||||
type serverOptions struct {
|
||||
authorizePath string
|
||||
pushPath string
|
||||
pullPath string
|
||||
backlog int
|
||||
tlsEnabled bool
|
||||
tlsConfig *tls.Config
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
type ServerOption func(opts *serverOptions)
|
||||
|
||||
func PathServerOption(authorizePath, pushPath, pullPath string) ServerOption {
|
||||
return func(opts *serverOptions) {
|
||||
opts.authorizePath = authorizePath
|
||||
opts.pullPath = pullPath
|
||||
opts.pushPath = pushPath
|
||||
}
|
||||
}
|
||||
|
||||
func BacklogServerOption(backlog int) ServerOption {
|
||||
return func(opts *serverOptions) {
|
||||
opts.backlog = backlog
|
||||
}
|
||||
}
|
||||
|
||||
func TLSConfigServerOption(tlsConfig *tls.Config) ServerOption {
|
||||
return func(opts *serverOptions) {
|
||||
opts.tlsConfig = tlsConfig
|
||||
}
|
||||
}
|
||||
|
||||
func EnableTLSServerOption(enable bool) ServerOption {
|
||||
return func(opts *serverOptions) {
|
||||
opts.tlsEnabled = enable
|
||||
}
|
||||
}
|
||||
|
||||
func LoggerServerOption(logger logger.Logger) ServerOption {
|
||||
return func(opts *serverOptions) {
|
||||
opts.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: remove stale clients from conns
|
||||
type Server struct {
|
||||
addr net.Addr
|
||||
httpServer *http.Server
|
||||
http3Server *http3.Server
|
||||
cqueue chan net.Conn
|
||||
conns sync.Map
|
||||
closed chan struct{}
|
||||
|
||||
options serverOptions
|
||||
}
|
||||
|
||||
func NewServer(addr string, opts ...ServerOption) *Server {
|
||||
var options serverOptions
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
if options.backlog <= 0 {
|
||||
options.backlog = defaultBacklog
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
httpServer: &http.Server{
|
||||
Addr: addr,
|
||||
ReadHeaderTimeout: 30 * time.Second,
|
||||
},
|
||||
cqueue: make(chan net.Conn, options.backlog),
|
||||
closed: make(chan struct{}),
|
||||
options: options,
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc(options.authorizePath, s.handleAuthorize)
|
||||
mux.HandleFunc(options.pushPath, s.handlePush)
|
||||
mux.HandleFunc(options.pullPath, s.handlePull)
|
||||
s.httpServer.Handler = mux
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func NewHTTP3Server(addr string, quicConfig *quic.Config, opts ...ServerOption) *Server {
|
||||
var options serverOptions
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
if options.backlog <= 0 {
|
||||
options.backlog = defaultBacklog
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
http3Server: &http3.Server{
|
||||
Server: &http.Server{
|
||||
Addr: addr,
|
||||
TLSConfig: options.tlsConfig,
|
||||
ReadHeaderTimeout: 30 * time.Second,
|
||||
},
|
||||
QuicConfig: quicConfig,
|
||||
},
|
||||
cqueue: make(chan net.Conn, options.backlog),
|
||||
closed: make(chan struct{}),
|
||||
options: options,
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc(options.authorizePath, s.handleAuthorize)
|
||||
mux.HandleFunc(options.pushPath, s.handlePush)
|
||||
mux.HandleFunc(options.pullPath, s.handlePull)
|
||||
s.http3Server.Handler = mux
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Server) ListenAndServe() error {
|
||||
if s.http3Server != nil {
|
||||
addr, err := net.ResolveUDPAddr("udp", s.http3Server.Addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.addr = addr
|
||||
return s.http3Server.ListenAndServe()
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", s.httpServer.Addr)
|
||||
if err != nil {
|
||||
s.options.logger.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
s.addr = ln.Addr()
|
||||
if s.options.tlsEnabled {
|
||||
s.httpServer.TLSConfig = s.options.tlsConfig
|
||||
ln = tls.NewListener(ln, s.options.tlsConfig)
|
||||
}
|
||||
|
||||
return s.httpServer.Serve(ln)
|
||||
}
|
||||
|
||||
func (s *Server) Accept() (conn net.Conn, err error) {
|
||||
select {
|
||||
case conn = <-s.cqueue:
|
||||
case <-s.closed:
|
||||
err = http.ErrServerClosed
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) Close() error {
|
||||
select {
|
||||
case <-s.closed:
|
||||
return http.ErrServerClosed
|
||||
default:
|
||||
close(s.closed)
|
||||
|
||||
if s.http3Server != nil {
|
||||
return s.http3Server.Close()
|
||||
}
|
||||
return s.httpServer.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleAuthorize(w http.ResponseWriter, r *http.Request) {
|
||||
if s.options.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpRequest(r, false)
|
||||
s.options.logger.Debug(string(dump))
|
||||
}
|
||||
|
||||
raddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
|
||||
if raddr == nil {
|
||||
raddr = &net.TCPAddr{}
|
||||
}
|
||||
|
||||
// connection id
|
||||
cid := xid.New().String()
|
||||
|
||||
c1, c2 := net.Pipe()
|
||||
c := &serverConn{
|
||||
Conn: c1,
|
||||
localAddr: s.addr,
|
||||
remoteAddr: raddr,
|
||||
}
|
||||
|
||||
select {
|
||||
case s.cqueue <- c:
|
||||
default:
|
||||
c.Close()
|
||||
s.options.logger.Warnf("connection queue is full, client %s discarded", r.RemoteAddr)
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
w.Write([]byte(fmt.Sprintf("token=%s", cid)))
|
||||
s.conns.Store(cid, c2)
|
||||
}
|
||||
|
||||
func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) {
|
||||
if s.options.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpRequest(r, false)
|
||||
s.options.logger.Debug(string(dump))
|
||||
}
|
||||
|
||||
if r.Method != http.MethodPost {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := r.ParseForm(); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
cid := r.Form.Get("token")
|
||||
v, ok := s.conns.Load(cid)
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
conn := v.(net.Conn)
|
||||
|
||||
br := bufio.NewReader(r.Body)
|
||||
data, err := br.ReadString('\n')
|
||||
if err != nil {
|
||||
s.options.logger.Error(err)
|
||||
conn.Close()
|
||||
s.conns.Delete(cid)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
data = strings.TrimSuffix(data, "\n")
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
b, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
s.options.logger.Error(err)
|
||||
s.conns.Delete(cid)
|
||||
conn.Close()
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
conn.SetWriteDeadline(time.Now().Add(30 * time.Second))
|
||||
defer conn.SetWriteDeadline(time.Time{})
|
||||
|
||||
if _, err := conn.Write(b); err != nil {
|
||||
s.options.logger.Error(err)
|
||||
s.conns.Delete(cid)
|
||||
conn.Close()
|
||||
w.WriteHeader(http.StatusGone)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) {
|
||||
if s.options.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpRequest(r, false)
|
||||
s.options.logger.Debug(string(dump))
|
||||
}
|
||||
|
||||
if r.Method != http.MethodGet {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := r.ParseForm(); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
cid := r.Form.Get("token")
|
||||
v, ok := s.conns.Load(cid)
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
conn := v.(net.Conn)
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if fw, ok := w.(http.Flusher); ok {
|
||||
fw.Flush()
|
||||
}
|
||||
|
||||
b := bufpool.Get(4096)
|
||||
defer bufpool.Put(b)
|
||||
|
||||
for {
|
||||
conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
n, err := conn.Read(*b)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
s.options.logger.Error(err)
|
||||
s.conns.Delete(cid)
|
||||
conn.Close()
|
||||
} else {
|
||||
(*b)[0] = '\n'
|
||||
w.Write((*b)[:1])
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
bw := bufio.NewWriter(w)
|
||||
bw.WriteString(base64.StdEncoding.EncodeToString((*b)[:n]))
|
||||
bw.WriteString("\n")
|
||||
if err := bw.Flush(); err != nil {
|
||||
return
|
||||
}
|
||||
if fw, ok := w.(http.Flusher); ok {
|
||||
fw.Flush()
|
||||
}
|
||||
}
|
||||
}
|
@ -1,90 +0,0 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
type cipherConn struct {
|
||||
net.PacketConn
|
||||
key []byte
|
||||
}
|
||||
|
||||
func CipherPacketConn(conn net.PacketConn, key []byte) net.PacketConn {
|
||||
return &cipherConn{
|
||||
PacketConn: conn,
|
||||
key: key,
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *cipherConn) ReadFrom(data []byte) (n int, addr net.Addr, err error) {
|
||||
n, addr, err = conn.PacketConn.ReadFrom(data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
b, err := conn.decrypt(data[:n])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
copy(data, b)
|
||||
|
||||
return len(b), addr, nil
|
||||
}
|
||||
|
||||
func (conn *cipherConn) WriteTo(data []byte, addr net.Addr) (n int, err error) {
|
||||
b, err := conn.encrypt(data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = conn.PacketConn.WriteTo(b, addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (conn *cipherConn) encrypt(data []byte) ([]byte, error) {
|
||||
c, err := aes.NewCipher(conn.key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return gcm.Seal(nonce, nonce, data, nil), nil
|
||||
}
|
||||
|
||||
func (conn *cipherConn) decrypt(data []byte) ([]byte, error) {
|
||||
c, err := aes.NewCipher(conn.key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(data) < nonceSize {
|
||||
return nil, errors.New("ciphertext too short")
|
||||
}
|
||||
|
||||
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
|
||||
return gcm.Open(nil, nonce, ciphertext, nil)
|
||||
}
|
@ -5,7 +5,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
|
172
pkg/internal/util/socks/conn.go
Normal file
172
pkg/internal/util/socks/conn.go
Normal file
@ -0,0 +1,172 @@
|
||||
package socks
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
|
||||
"github.com/go-gost/gosocks5"
|
||||
"github.com/go-gost/gost/v3/pkg/common/bufpool"
|
||||
)
|
||||
|
||||
type udpTunConn struct {
|
||||
net.Conn
|
||||
taddr net.Addr
|
||||
}
|
||||
|
||||
func UDPTunClientConn(c net.Conn, targetAddr net.Addr) net.Conn {
|
||||
return &udpTunConn{
|
||||
Conn: c,
|
||||
taddr: targetAddr,
|
||||
}
|
||||
}
|
||||
|
||||
func UDPTunClientPacketConn(c net.Conn) net.PacketConn {
|
||||
return &udpTunConn{
|
||||
Conn: c,
|
||||
}
|
||||
}
|
||||
|
||||
func UDPTunServerConn(c net.Conn) net.PacketConn {
|
||||
return &udpTunConn{
|
||||
Conn: c,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *udpTunConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||
socksAddr := gosocks5.Addr{}
|
||||
header := gosocks5.UDPHeader{
|
||||
Addr: &socksAddr,
|
||||
}
|
||||
dgram := gosocks5.UDPDatagram{
|
||||
Header: &header,
|
||||
Data: b,
|
||||
}
|
||||
_, err = dgram.ReadFrom(c.Conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
n = len(dgram.Data)
|
||||
if n > len(b) {
|
||||
n = copy(b, dgram.Data)
|
||||
}
|
||||
addr, err = net.ResolveUDPAddr("udp", socksAddr.String())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *udpTunConn) Read(b []byte) (n int, err error) {
|
||||
n, _, err = c.ReadFrom(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *udpTunConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
socksAddr := gosocks5.Addr{}
|
||||
if err = socksAddr.ParseFrom(addr.String()); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
header := gosocks5.UDPHeader{
|
||||
Addr: &socksAddr,
|
||||
}
|
||||
dgram := gosocks5.UDPDatagram{
|
||||
Header: &header,
|
||||
Data: b,
|
||||
}
|
||||
dgram.Header.Rsv = uint16(len(dgram.Data))
|
||||
dgram.Header.Frag = 0xff // UDP tun relay flag, used by shadowsocks
|
||||
_, err = dgram.WriteTo(c.Conn)
|
||||
n = len(b)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *udpTunConn) Write(b []byte) (n int, err error) {
|
||||
return c.WriteTo(b, c.taddr)
|
||||
}
|
||||
|
||||
var (
|
||||
DefaultBufferSize = 4096
|
||||
)
|
||||
|
||||
type udpConn struct {
|
||||
net.PacketConn
|
||||
raddr net.Addr
|
||||
taddr net.Addr
|
||||
bufferSize int
|
||||
}
|
||||
|
||||
func UDPConn(c net.PacketConn, bufferSize int) net.PacketConn {
|
||||
return &udpConn{
|
||||
PacketConn: c,
|
||||
bufferSize: bufferSize,
|
||||
}
|
||||
}
|
||||
|
||||
// ReadFrom reads an UDP datagram.
|
||||
// NOTE: for server side,
|
||||
// the returned addr is the target address the client want to relay to.
|
||||
func (c *udpConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||
rbuf := bufpool.Get(c.bufferSize)
|
||||
defer bufpool.Put(rbuf)
|
||||
|
||||
n, c.raddr, err = c.PacketConn.ReadFrom(*rbuf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
socksAddr := gosocks5.Addr{}
|
||||
header := gosocks5.UDPHeader{
|
||||
Addr: &socksAddr,
|
||||
}
|
||||
hlen, err := header.ReadFrom(bytes.NewReader((*rbuf)[:n]))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n = copy(b, (*rbuf)[hlen:n])
|
||||
|
||||
addr, err = net.ResolveUDPAddr("udp", socksAddr.String())
|
||||
return
|
||||
}
|
||||
|
||||
func (c *udpConn) Read(b []byte) (n int, err error) {
|
||||
n, _, err = c.ReadFrom(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *udpConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
wbuf := bufpool.Get(c.bufferSize)
|
||||
defer bufpool.Put(wbuf)
|
||||
|
||||
socksAddr := gosocks5.Addr{}
|
||||
if err = socksAddr.ParseFrom(addr.String()); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
header := gosocks5.UDPHeader{
|
||||
Addr: &socksAddr,
|
||||
}
|
||||
dgram := gosocks5.UDPDatagram{
|
||||
Header: &header,
|
||||
Data: b,
|
||||
}
|
||||
|
||||
buf := bytes.NewBuffer((*wbuf)[:0])
|
||||
_, err = dgram.WriteTo(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = c.PacketConn.WriteTo(buf.Bytes(), c.raddr)
|
||||
n = len(b)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *udpConn) Write(b []byte) (n int, err error) {
|
||||
return c.WriteTo(b, c.taddr)
|
||||
}
|
||||
|
||||
func (c *udpConn) RemoteAddr() net.Addr {
|
||||
return c.raddr
|
||||
}
|
18
pkg/internal/util/socks/socks.go
Normal file
18
pkg/internal/util/socks/socks.go
Normal file
@ -0,0 +1,18 @@
|
||||
package socks
|
||||
|
||||
const (
|
||||
// MethodTLS is an extended SOCKS5 method with tls encryption support.
|
||||
MethodTLS uint8 = 0x80
|
||||
// MethodTLSAuth is an extended SOCKS5 method with tls encryption and authentication support.
|
||||
MethodTLSAuth uint8 = 0x82
|
||||
// MethodMux is an extended SOCKS5 method for stream multiplexing.
|
||||
MethodMux = 0x88
|
||||
)
|
||||
|
||||
const (
|
||||
// CmdMuxBind is an extended SOCKS5 request CMD for
|
||||
// multiplexing transport with the binding server.
|
||||
CmdMuxBind uint8 = 0xF2
|
||||
// CmdUDPTun is an extended SOCKS5 request CMD for UDP over TCP.
|
||||
CmdUDPTun uint8 = 0xF3
|
||||
)
|
@ -1,48 +0,0 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// a dummy ssh client conn used by client connector
|
||||
type ClientConn struct {
|
||||
net.Conn
|
||||
client *ssh.Client
|
||||
}
|
||||
|
||||
func NewClientConn(conn net.Conn, client *ssh.Client) net.Conn {
|
||||
return &ClientConn{
|
||||
Conn: conn,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientConn) Client() *ssh.Client {
|
||||
return c.client
|
||||
}
|
||||
|
||||
type sshConn struct {
|
||||
channel ssh.Channel
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func NewConn(conn net.Conn, channel ssh.Channel) net.Conn {
|
||||
return &sshConn{
|
||||
Conn: conn,
|
||||
channel: channel,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *sshConn) Read(b []byte) (n int, err error) {
|
||||
return c.channel.Read(b)
|
||||
}
|
||||
|
||||
func (c *sshConn) Write(b []byte) (n int, err error) {
|
||||
return c.channel.Write(b)
|
||||
}
|
||||
|
||||
func (c *sshConn) Close() error {
|
||||
return c.channel.Close()
|
||||
}
|
@ -1,75 +0,0 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
|
||||
"github.com/go-gost/gost/pkg/auth"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
GostSSHTunnelRequest = "gost-tunnel" // extended request type for ssh tunnel
|
||||
)
|
||||
|
||||
var (
|
||||
ErrSessionDead = errors.New("session is dead")
|
||||
)
|
||||
|
||||
// PasswordCallbackFunc is a callback function used by SSH server.
|
||||
// It authenticates user using a password.
|
||||
type PasswordCallbackFunc func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error)
|
||||
|
||||
func PasswordCallback(au auth.Authenticator) PasswordCallbackFunc {
|
||||
if au == nil {
|
||||
return nil
|
||||
}
|
||||
return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
||||
if au.Authenticate(conn.User(), string(password)) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("password rejected for %s", conn.User())
|
||||
}
|
||||
}
|
||||
|
||||
// PublicKeyCallbackFunc is a callback function used by SSH server.
|
||||
// It offers a public key for authentication.
|
||||
type PublicKeyCallbackFunc func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error)
|
||||
|
||||
func PublicKeyCallback(keys map[string]bool) PublicKeyCallbackFunc {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
|
||||
if keys[string(pubKey.Marshal())] {
|
||||
return &ssh.Permissions{
|
||||
// Record the public key used for authentication.
|
||||
Extensions: map[string]string{
|
||||
"pubkey-fp": ssh.FingerprintSHA256(pubKey),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unknown public key for %q", c.User())
|
||||
}
|
||||
}
|
||||
|
||||
// ParseSSHAuthorizedKeysFile parses ssh authorized keys file.
|
||||
func ParseAuthorizedKeysFile(name string) (map[string]bool, error) {
|
||||
authorizedKeysBytes, err := ioutil.ReadFile(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authorizedKeysMap := make(map[string]bool)
|
||||
for len(authorizedKeysBytes) > 0 {
|
||||
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authorizedKeysMap[string(pubKey.Marshal())] = true
|
||||
authorizedKeysBytes = rest
|
||||
}
|
||||
|
||||
return authorizedKeysMap, nil
|
||||
}
|
@ -1,118 +0,0 @@
|
||||
package sshd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type DirectForwardConn struct {
|
||||
conn ssh.Conn
|
||||
channel ssh.Channel
|
||||
dstAddr string
|
||||
}
|
||||
|
||||
func NewDirectForwardConn(conn ssh.Conn, channel ssh.Channel, dstAddr string) net.Conn {
|
||||
return &DirectForwardConn{
|
||||
conn: conn,
|
||||
channel: channel,
|
||||
dstAddr: dstAddr,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *DirectForwardConn) Read(b []byte) (n int, err error) {
|
||||
return c.channel.Read(b)
|
||||
}
|
||||
|
||||
func (c *DirectForwardConn) Write(b []byte) (n int, err error) {
|
||||
return c.channel.Write(b)
|
||||
}
|
||||
|
||||
func (c *DirectForwardConn) Close() error {
|
||||
return c.channel.Close()
|
||||
}
|
||||
|
||||
func (c *DirectForwardConn) LocalAddr() net.Addr {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *DirectForwardConn) RemoteAddr() net.Addr {
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (c *DirectForwardConn) SetDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *DirectForwardConn) SetReadDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *DirectForwardConn) SetWriteDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *DirectForwardConn) DstAddr() string {
|
||||
return c.dstAddr
|
||||
}
|
||||
|
||||
type RemoteForwardConn struct {
|
||||
ctx context.Context
|
||||
conn ssh.Conn
|
||||
req *ssh.Request
|
||||
}
|
||||
|
||||
func NewRemoteForwardConn(ctx context.Context, conn ssh.Conn, req *ssh.Request) net.Conn {
|
||||
return &RemoteForwardConn{
|
||||
ctx: ctx,
|
||||
conn: conn,
|
||||
req: req,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RemoteForwardConn) Conn() ssh.Conn {
|
||||
return c.conn
|
||||
}
|
||||
|
||||
func (c *RemoteForwardConn) Request() *ssh.Request {
|
||||
return c.req
|
||||
}
|
||||
|
||||
func (c *RemoteForwardConn) Read(b []byte) (n int, err error) {
|
||||
return 0, &net.OpError{Op: "read", Net: "nop", Source: nil, Addr: nil, Err: errors.New("read not supported")}
|
||||
}
|
||||
|
||||
func (c *RemoteForwardConn) Write(b []byte) (n int, err error) {
|
||||
return 0, &net.OpError{Op: "write", Net: "nop", Source: nil, Addr: nil, Err: errors.New("write not supported")}
|
||||
}
|
||||
|
||||
func (c *RemoteForwardConn) Close() error {
|
||||
return &net.OpError{Op: "close", Net: "nop", Source: nil, Addr: nil, Err: errors.New("close not supported")}
|
||||
}
|
||||
|
||||
func (c *RemoteForwardConn) LocalAddr() net.Addr {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *RemoteForwardConn) RemoteAddr() net.Addr {
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (c *RemoteForwardConn) SetDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *RemoteForwardConn) SetReadDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *RemoteForwardConn) SetWriteDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *RemoteForwardConn) Done() <-chan struct{} {
|
||||
return c.ctx.Done()
|
||||
}
|
@ -1,9 +0,0 @@
|
||||
package tap
|
||||
|
||||
type Config struct {
|
||||
Name string
|
||||
Net string
|
||||
MTU int
|
||||
Routes []string
|
||||
Gateway string
|
||||
}
|
@ -1,61 +0,0 @@
|
||||
package tap
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/songgao/water"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
config *Config
|
||||
ifce *water.Interface
|
||||
laddr net.Addr
|
||||
raddr net.Addr
|
||||
}
|
||||
|
||||
func NewConn(config *Config, ifce *water.Interface, laddr, raddr net.Addr) *Conn {
|
||||
return &Conn{
|
||||
config: config,
|
||||
ifce: ifce,
|
||||
laddr: laddr,
|
||||
raddr: raddr,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Config() *Config {
|
||||
return c.config
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
return c.ifce.Read(b)
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
return c.ifce.Write(b)
|
||||
}
|
||||
|
||||
func (c *Conn) Close() (err error) {
|
||||
return c.ifce.Close()
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.laddr
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.raddr
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "tuntap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "tuntap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "tuntap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
32
pkg/internal/util/tcp.go
Normal file
32
pkg/internal/util/tcp.go
Normal file
@ -0,0 +1,32 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultKeepAlivePeriod = 180 * time.Second
|
||||
)
|
||||
|
||||
// TCPKeepAliveListener is a TCP listener with keep alive enabled.
|
||||
type TCPKeepAliveListener struct {
|
||||
KeepAlivePeriod time.Duration
|
||||
*net.TCPListener
|
||||
}
|
||||
|
||||
func (l *TCPKeepAliveListener) Accept() (c net.Conn, err error) {
|
||||
tc, err := l.AcceptTCP()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
tc.SetKeepAlive(true)
|
||||
period := l.KeepAlivePeriod
|
||||
if period <= 0 {
|
||||
period = defaultKeepAlivePeriod
|
||||
}
|
||||
tc.SetKeepAlivePeriod(period)
|
||||
|
||||
return tc, nil
|
||||
}
|
@ -1,19 +0,0 @@
|
||||
package tun
|
||||
|
||||
import "net"
|
||||
|
||||
// Route is an IP routing entry
|
||||
type Route struct {
|
||||
Net net.IPNet
|
||||
Gateway net.IP
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string
|
||||
Net string
|
||||
// peer addr of point-to-point on MacOS
|
||||
Peer string
|
||||
MTU int
|
||||
Gateway string
|
||||
Routes []Route
|
||||
}
|
@ -1,61 +0,0 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/songgao/water"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
config *Config
|
||||
ifce *water.Interface
|
||||
laddr net.Addr
|
||||
raddr net.Addr
|
||||
}
|
||||
|
||||
func NewConn(config *Config, ifce *water.Interface, laddr, raddr net.Addr) *Conn {
|
||||
return &Conn{
|
||||
config: config,
|
||||
ifce: ifce,
|
||||
laddr: laddr,
|
||||
raddr: raddr,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Config() *Config {
|
||||
return c.config
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
return c.ifce.Read(b)
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
return c.ifce.Write(b)
|
||||
}
|
||||
|
||||
func (c *Conn) Close() (err error) {
|
||||
return c.ifce.Close()
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.laddr
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.raddr
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "tuntap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "tuntap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "tuntap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
102
pkg/internal/util/udp/conn.go
Normal file
102
pkg/internal/util/udp/conn.go
Normal file
@ -0,0 +1,102 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/common/bufpool"
|
||||
)
|
||||
|
||||
// Conn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
|
||||
type Conn struct {
|
||||
net.PacketConn
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
rc chan []byte // data receive queue
|
||||
idle int32 // indicate the connection is idle
|
||||
closed chan struct{}
|
||||
closeMutex sync.Mutex
|
||||
}
|
||||
|
||||
func NewConn(c net.PacketConn, localAddr, remoteAddr net.Addr, queueSize int) *Conn {
|
||||
return &Conn{
|
||||
PacketConn: c,
|
||||
localAddr: localAddr,
|
||||
remoteAddr: remoteAddr,
|
||||
rc: make(chan []byte, queueSize),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||
select {
|
||||
case bb := <-c.rc:
|
||||
n = copy(b, bb)
|
||||
c.SetIdle(false)
|
||||
bufpool.Put(&bb)
|
||||
|
||||
case <-c.closed:
|
||||
err = net.ErrClosed
|
||||
return
|
||||
}
|
||||
|
||||
addr = c.remoteAddr
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
n, _, err = c.ReadFrom(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
return c.WriteTo(b, c.remoteAddr)
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
c.closeMutex.Lock()
|
||||
defer c.closeMutex.Unlock()
|
||||
|
||||
select {
|
||||
case <-c.closed:
|
||||
default:
|
||||
close(c.closed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.localAddr
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *Conn) IsIdle() bool {
|
||||
return atomic.LoadInt32(&c.idle) > 0
|
||||
}
|
||||
|
||||
func (c *Conn) SetIdle(idle bool) {
|
||||
v := int32(0)
|
||||
if idle {
|
||||
v = 1
|
||||
}
|
||||
atomic.StoreInt32(&c.idle, v)
|
||||
}
|
||||
|
||||
func (c *Conn) WriteQueue(b []byte) error {
|
||||
select {
|
||||
case c.rc <- b:
|
||||
return nil
|
||||
|
||||
case <-c.closed:
|
||||
return net.ErrClosed
|
||||
|
||||
default:
|
||||
return errors.New("recv queue is full")
|
||||
}
|
||||
}
|
120
pkg/internal/util/udp/listener.go
Normal file
120
pkg/internal/util/udp/listener.go
Normal file
@ -0,0 +1,120 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/common/bufpool"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
)
|
||||
|
||||
type listener struct {
|
||||
addr net.Addr
|
||||
conn net.PacketConn
|
||||
cqueue chan net.Conn
|
||||
readQueueSize int
|
||||
readBufferSize int
|
||||
connPool *ConnPool
|
||||
mux sync.Mutex
|
||||
closed chan struct{}
|
||||
errChan chan error
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewListener(conn net.PacketConn, addr net.Addr, backlog, dataQueueSize, dataBufferSize int, ttl time.Duration, logger logger.Logger) net.Listener {
|
||||
ln := &listener{
|
||||
conn: conn,
|
||||
addr: addr,
|
||||
cqueue: make(chan net.Conn, backlog),
|
||||
connPool: NewConnPool(ttl).WithLogger(logger),
|
||||
readQueueSize: dataQueueSize,
|
||||
readBufferSize: dataBufferSize,
|
||||
closed: make(chan struct{}),
|
||||
errChan: make(chan error, 1),
|
||||
logger: logger,
|
||||
}
|
||||
go ln.listenLoop()
|
||||
|
||||
return ln
|
||||
}
|
||||
|
||||
func (ln *listener) Accept() (conn net.Conn, err error) {
|
||||
select {
|
||||
case conn = <-ln.cqueue:
|
||||
return
|
||||
case <-ln.closed:
|
||||
return nil, net.ErrClosed
|
||||
case err = <-ln.errChan:
|
||||
if err == nil {
|
||||
err = net.ErrClosed
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (ln *listener) listenLoop() {
|
||||
for {
|
||||
select {
|
||||
case <-ln.closed:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
b := bufpool.Get(ln.readBufferSize)
|
||||
|
||||
n, raddr, err := ln.conn.ReadFrom(*b)
|
||||
if err != nil {
|
||||
ln.errChan <- err
|
||||
close(ln.errChan)
|
||||
return
|
||||
}
|
||||
|
||||
c := ln.getConn(raddr)
|
||||
if c == nil {
|
||||
bufpool.Put(b)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := c.WriteQueue((*b)[:n]); err != nil {
|
||||
ln.logger.Warn("data discarded: ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ln *listener) Addr() net.Addr {
|
||||
return ln.addr
|
||||
}
|
||||
|
||||
func (ln *listener) Close() error {
|
||||
select {
|
||||
case <-ln.closed:
|
||||
default:
|
||||
close(ln.closed)
|
||||
ln.conn.Close()
|
||||
ln.connPool.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ln *listener) getConn(raddr net.Addr) *Conn {
|
||||
ln.mux.Lock()
|
||||
defer ln.mux.Unlock()
|
||||
|
||||
c, ok := ln.connPool.Get(raddr.String())
|
||||
if ok {
|
||||
return c
|
||||
}
|
||||
|
||||
c = NewConn(ln.conn, ln.addr, raddr, ln.readQueueSize)
|
||||
select {
|
||||
case ln.cqueue <- c:
|
||||
ln.connPool.Set(raddr.String(), c)
|
||||
return c
|
||||
default:
|
||||
c.Close()
|
||||
ln.logger.Warnf("connection queue is full, client %s discarded", raddr)
|
||||
return nil
|
||||
}
|
||||
}
|
100
pkg/internal/util/udp/pool.go
Normal file
100
pkg/internal/util/udp/pool.go
Normal file
@ -0,0 +1,100 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
)
|
||||
|
||||
type ConnPool struct {
|
||||
m sync.Map
|
||||
ttl time.Duration
|
||||
closed chan struct{}
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewConnPool(ttl time.Duration) *ConnPool {
|
||||
p := &ConnPool{
|
||||
ttl: ttl,
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
go p.idleCheck()
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *ConnPool) WithLogger(logger logger.Logger) *ConnPool {
|
||||
p.logger = logger
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *ConnPool) Get(key any) (c *Conn, ok bool) {
|
||||
v, ok := p.m.Load(key)
|
||||
if ok {
|
||||
c, ok = v.(*Conn)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (p *ConnPool) Set(key any, c *Conn) {
|
||||
p.m.Store(key, c)
|
||||
}
|
||||
|
||||
func (p *ConnPool) Delete(key any) {
|
||||
p.m.Delete(key)
|
||||
}
|
||||
|
||||
func (p *ConnPool) Close() {
|
||||
select {
|
||||
case <-p.closed:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
close(p.closed)
|
||||
|
||||
p.m.Range(func(k, v any) bool {
|
||||
if c, ok := v.(*Conn); ok && c != nil {
|
||||
c.Close()
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (p *ConnPool) idleCheck() {
|
||||
ticker := time.NewTicker(p.ttl)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
size := 0
|
||||
idles := 0
|
||||
p.m.Range(func(key, value any) bool {
|
||||
c, ok := value.(*Conn)
|
||||
if !ok || c == nil {
|
||||
p.Delete(key)
|
||||
return true
|
||||
}
|
||||
size++
|
||||
|
||||
if c.IsIdle() {
|
||||
idles++
|
||||
p.Delete(key)
|
||||
c.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
c.SetIdle(true)
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
if idles > 0 {
|
||||
p.logger.Debugf("connection pool: size=%d, idle=%d", size, idles)
|
||||
}
|
||||
case <-p.closed:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
@ -1,56 +0,0 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type WebsocketConn interface {
|
||||
net.Conn
|
||||
WriteMessage(int, []byte) error
|
||||
ReadMessage() (int, []byte, error)
|
||||
}
|
||||
|
||||
type websocketConn struct {
|
||||
*websocket.Conn
|
||||
rb []byte
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func Conn(conn *websocket.Conn) WebsocketConn {
|
||||
return &websocketConn{
|
||||
Conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *websocketConn) Read(b []byte) (n int, err error) {
|
||||
if len(c.rb) == 0 {
|
||||
_, c.rb, err = c.Conn.ReadMessage()
|
||||
}
|
||||
n = copy(b, c.rb)
|
||||
c.rb = c.rb[n:]
|
||||
return
|
||||
}
|
||||
|
||||
func (c *websocketConn) Write(b []byte) (n int, err error) {
|
||||
err = c.WriteMessage(websocket.BinaryMessage, b)
|
||||
n = len(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *websocketConn) WriteMessage(messageType int, data []byte) error {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
return c.Conn.WriteMessage(messageType, data)
|
||||
}
|
||||
|
||||
func (c *websocketConn) SetDeadline(t time.Time) error {
|
||||
if err := c.SetReadDeadline(t); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.SetWriteDeadline(t)
|
||||
}
|
Reference in New Issue
Block a user