initial commit

This commit is contained in:
ginuerzh
2022-03-14 20:27:14 +08:00
commit 9397cb5351
175 changed files with 16196 additions and 0 deletions

280
handler/dns/handler.go Normal file
View File

@ -0,0 +1,280 @@
package dns
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"strconv"
"strings"
"time"
"github.com/go-gost/gost/v3/pkg/chain"
"github.com/go-gost/gost/v3/pkg/common/bufpool"
resolver_util "github.com/go-gost/gost/v3/pkg/common/util/resolver"
"github.com/go-gost/gost/v3/pkg/handler"
"github.com/go-gost/gost/v3/pkg/hosts"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
"github.com/go-gost/gost/v3/pkg/resolver/exchanger"
"github.com/miekg/dns"
)
const (
defaultNameserver = "udp://127.0.0.1:53"
)
func init() {
registry.HandlerRegistry().Register("dns", NewHandler)
}
type dnsHandler struct {
exchangers []exchanger.Exchanger
cache *resolver_util.Cache
router *chain.Router
hosts hosts.HostMapper
md metadata
options handler.Options
}
func NewHandler(opts ...handler.Option) handler.Handler {
options := handler.Options{}
for _, opt := range opts {
opt(&options)
}
return &dnsHandler{
options: options,
}
}
func (h *dnsHandler) Init(md md.Metadata) (err error) {
if err = h.parseMetadata(md); err != nil {
return
}
log := h.options.Logger
h.cache = resolver_util.NewCache().WithLogger(log)
h.router = h.options.Router
if h.router == nil {
h.router = (&chain.Router{}).WithLogger(log)
}
h.hosts = h.router.Hosts()
for _, server := range h.md.dns {
server = strings.TrimSpace(server)
if server == "" {
continue
}
ex, err := exchanger.NewExchanger(
server,
exchanger.RouterOption(h.router),
exchanger.TimeoutOption(h.md.timeout),
exchanger.LoggerOption(log),
)
if err != nil {
log.Warnf("parse %s: %v", server, err)
continue
}
h.exchangers = append(h.exchangers, ex)
}
if len(h.exchangers) == 0 {
ex, err := exchanger.NewExchanger(
defaultNameserver,
exchanger.RouterOption(h.router),
exchanger.TimeoutOption(h.md.timeout),
exchanger.LoggerOption(log),
)
log.Warnf("resolver not found, default to %s", defaultNameserver)
if err != nil {
return err
}
h.exchangers = append(h.exchangers, ex)
}
return
}
func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
defer conn.Close()
start := time.Now()
log := h.options.Logger.WithFields(map[string]any{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
log.WithFields(map[string]any{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
b := bufpool.Get(4096)
defer bufpool.Put(b)
n, err := conn.Read(*b)
if err != nil {
log.Error(err)
return err
}
reply, err := h.exchange(ctx, (*b)[:n], log)
if err != nil {
return err
}
defer bufpool.Put(&reply)
if _, err = conn.Write(reply); err != nil {
log.Error(err)
return err
}
return nil
}
func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger) ([]byte, error) {
mq := dns.Msg{}
if err := mq.Unpack(msg); err != nil {
log.Error(err)
return nil, err
}
if len(mq.Question) == 0 {
return nil, errors.New("msg: empty question")
}
resolver_util.AddSubnetOpt(&mq, h.md.clientIP)
if log.IsLevelEnabled(logger.DebugLevel) {
log.Debug(mq.String())
}
var mr *dns.Msg
if log.IsLevelEnabled(logger.DebugLevel) {
defer func() {
if mr != nil {
log.Debug(mr.String())
}
}()
}
mr = h.lookupHosts(&mq, log)
if mr != nil {
b := bufpool.Get(4096)
return mr.PackBuffer(*b)
}
// only cache for single question message.
if len(mq.Question) == 1 {
key := resolver_util.NewCacheKey(&mq.Question[0])
mr = h.cache.Load(key)
if mr != nil {
log.Debugf("exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
mr.Id = mq.Id
b := bufpool.Get(4096)
return mr.PackBuffer(*b)
}
defer func() {
if mr != nil {
h.cache.Store(key, mr, h.md.ttl)
}
}()
}
b := bufpool.Get(4096)
defer bufpool.Put(b)
query, err := mq.PackBuffer(*b)
if err != nil {
log.Error(err)
return nil, err
}
var reply []byte
for _, ex := range h.exchangers {
log.Debugf("exchange message %d via %s: %s", mq.Id, ex.String(), mq.Question[0].String())
reply, err = ex.Exchange(ctx, query)
if err == nil {
break
}
log.Error(err)
}
if err != nil {
return nil, err
}
mr = &dns.Msg{}
if err = mr.Unpack(reply); err != nil {
log.Error(err)
return nil, err
}
return reply, nil
}
// lookup host mapper
func (h *dnsHandler) lookupHosts(r *dns.Msg, log logger.Logger) (m *dns.Msg) {
if h.hosts == nil ||
r.Question[0].Qclass != dns.ClassINET ||
(r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA) {
return nil
}
m = &dns.Msg{}
m.SetReply(r)
host := strings.TrimSuffix(r.Question[0].Name, ".")
switch r.Question[0].Qtype {
case dns.TypeA:
ips, _ := h.hosts.Lookup("ip4", host)
if len(ips) == 0 {
return nil
}
log.Debugf("hit host mapper: %s -> %s", host, ips)
for _, ip := range ips {
rr, err := dns.NewRR(fmt.Sprintf("%s IN A %s\n", r.Question[0].Name, ip.String()))
if err != nil {
log.Error(err)
return nil
}
m.Answer = append(m.Answer, rr)
}
case dns.TypeAAAA:
ips, _ := h.hosts.Lookup("ip6", host)
if len(ips) == 0 {
return nil
}
log.Debugf("hit host mapper: %s -> %s", host, ips)
for _, ip := range ips {
rr, err := dns.NewRR(fmt.Sprintf("%s IN AAAA %s\n", r.Question[0].Name, ip.String()))
if err != nil {
log.Error(err)
return nil
}
m.Answer = append(m.Answer, rr)
}
}
return
}
func (h *dnsHandler) dumpMsgHeader(m *dns.Msg) string {
buf := new(bytes.Buffer)
buf.WriteString(m.MsgHdr.String() + " ")
buf.WriteString("QUERY: " + strconv.Itoa(len(m.Question)) + ", ")
buf.WriteString("ANSWER: " + strconv.Itoa(len(m.Answer)) + ", ")
buf.WriteString("AUTHORITY: " + strconv.Itoa(len(m.Ns)) + ", ")
buf.WriteString("ADDITIONAL: " + strconv.Itoa(len(m.Extra)))
return buf.String()
}

41
handler/dns/metadata.go Normal file
View File

@ -0,0 +1,41 @@
package dns
import (
"net"
"time"
mdata "github.com/go-gost/gost/v3/pkg/metadata"
)
type metadata struct {
readTimeout time.Duration
ttl time.Duration
timeout time.Duration
clientIP net.IP
// nameservers
dns []string
}
func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) {
const (
readTimeout = "readTimeout"
ttl = "ttl"
timeout = "timeout"
clientIP = "clientIP"
dns = "dns"
)
h.md.readTimeout = mdata.GetDuration(md, readTimeout)
h.md.ttl = mdata.GetDuration(md, ttl)
h.md.timeout = mdata.GetDuration(md, timeout)
if h.md.timeout <= 0 {
h.md.timeout = 5 * time.Second
}
sip := mdata.GetString(md, clientIP)
if sip != "" {
h.md.clientIP = net.ParseIP(sip)
}
h.md.dns = mdata.GetStrings(md, dns)
return
}

46
handler/http2/conn.go Normal file
View File

@ -0,0 +1,46 @@
package http2
import (
"errors"
"io"
"net/http"
)
type readWriter struct {
r io.Reader
w io.Writer
}
func (rw *readWriter) Read(p []byte) (n int, err error) {
return rw.r.Read(p)
}
func (rw *readWriter) Write(p []byte) (n int, err error) {
return rw.w.Write(p)
}
type flushWriter struct {
w io.Writer
}
func (fw flushWriter) Write(p []byte) (n int, err error) {
defer func() {
if r := recover(); r != nil {
if s, ok := r.(string); ok {
err = errors.New(s)
return
}
err = r.(error)
}
}()
n, err = fw.w.Write(p)
if err != nil {
// log.Log("flush writer:", err)
return
}
if f, ok := fw.w.(http.Flusher); ok {
f.Flush()
}
return
}

343
handler/http2/handler.go Normal file
View File

@ -0,0 +1,343 @@
package http2
import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/binary"
"errors"
"hash/crc32"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httputil"
"os"
"strconv"
"strings"
"time"
"github.com/go-gost/gost/v3/pkg/chain"
netpkg "github.com/go-gost/gost/v3/pkg/common/net"
"github.com/go-gost/gost/v3/pkg/handler"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
http2_util "github.com/go-gost/x/internal/util/http2"
)
func init() {
registry.HandlerRegistry().Register("http2", NewHandler)
}
type http2Handler struct {
router *chain.Router
md metadata
options handler.Options
}
func NewHandler(opts ...handler.Option) handler.Handler {
options := handler.Options{}
for _, opt := range opts {
opt(&options)
}
return &http2Handler{
options: options,
}
}
func (h *http2Handler) Init(md md.Metadata) error {
if err := h.parseMetadata(md); err != nil {
return err
}
h.router = h.options.Router
if h.router == nil {
h.router = (&chain.Router{}).WithLogger(h.options.Logger)
}
return nil
}
func (h *http2Handler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
defer conn.Close()
start := time.Now()
log := h.options.Logger.WithFields(map[string]any{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
log.WithFields(map[string]any{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
cc, ok := conn.(*http2_util.ServerConn)
if !ok {
err := errors.New("wrong connection type")
log.Error(err)
return err
}
return h.roundTrip(ctx, cc.Writer(), cc.Request(), log)
}
// NOTE: there is an issue (golang/go#43989) will cause the client hangs
// when server returns an non-200 status code,
// May be fixed in go1.18.
func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req *http.Request, log logger.Logger) error {
// Try to get the actual host.
// Compatible with GOST 2.x.
if v := req.Header.Get("Gost-Target"); v != "" {
if h, err := h.decodeServerName(v); err == nil {
req.Host = h
}
}
req.Header.Del("Gost-Target")
if v := req.Header.Get("X-Gost-Target"); v != "" {
if h, err := h.decodeServerName(v); err == nil {
req.Host = h
}
}
req.Header.Del("X-Gost-Target")
addr := req.Host
if _, port, _ := net.SplitHostPort(addr); port == "" {
addr = net.JoinHostPort(addr, "80")
}
fields := map[string]any{
"dst": addr,
}
if u, _, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization")); u != "" {
fields["user"] = u
}
log = log.WithFields(fields)
if log.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpRequest(req, false)
log.Debug(string(dump))
}
log.Infof("%s >> %s", req.RemoteAddr, addr)
for k := range h.md.header {
w.Header().Set(k, h.md.header.Get(k))
}
if h.options.Bypass != nil && h.options.Bypass.Contains(addr) {
w.WriteHeader(http.StatusForbidden)
log.Info("bypass: ", addr)
return nil
}
resp := &http.Response{
ProtoMajor: 2,
ProtoMinor: 0,
Header: http.Header{},
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
if !h.authenticate(w, req, resp, log) {
return nil
}
// delete the proxy related headers.
req.Header.Del("Proxy-Authorization")
req.Header.Del("Proxy-Connection")
cc, err := h.router.Dial(ctx, "tcp", addr)
if err != nil {
log.Error(err)
w.WriteHeader(http.StatusServiceUnavailable)
return err
}
defer cc.Close()
if req.Method == http.MethodConnect {
w.WriteHeader(http.StatusOK)
if fw, ok := w.(http.Flusher); ok {
fw.Flush()
}
// compatible with HTTP1.x
if hj, ok := w.(http.Hijacker); ok && req.ProtoMajor == 1 {
// we take over the underly connection
conn, _, err := hj.Hijack()
if err != nil {
log.Error(err)
w.WriteHeader(http.StatusInternalServerError)
return err
}
defer conn.Close()
start := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
netpkg.Transport(conn, cc)
log.WithFields(map[string]any{
"duration": time.Since(start),
}).Infof("%s >-< %s", conn.RemoteAddr(), addr)
return nil
}
start := time.Now()
log.Infof("%s <-> %s", req.RemoteAddr, addr)
netpkg.Transport(&readWriter{r: req.Body, w: flushWriter{w}}, cc)
log.WithFields(map[string]any{
"duration": time.Since(start),
}).Infof("%s >-< %s", req.RemoteAddr, addr)
return nil
}
// TODO: forward request
return nil
}
func (h *http2Handler) decodeServerName(s string) (string, error) {
b, err := base64.RawURLEncoding.DecodeString(s)
if err != nil {
return "", err
}
if len(b) < 4 {
return "", errors.New("invalid name")
}
v, err := base64.RawURLEncoding.DecodeString(string(b[4:]))
if err != nil {
return "", err
}
if crc32.ChecksumIEEE(v) != binary.BigEndian.Uint32(b[:4]) {
return "", errors.New("invalid name")
}
return string(v), nil
}
func (h *http2Handler) basicProxyAuth(proxyAuth string) (username, password string, ok bool) {
if proxyAuth == "" {
return
}
if !strings.HasPrefix(proxyAuth, "Basic ") {
return
}
c, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(proxyAuth, "Basic "))
if err != nil {
return
}
cs := string(c)
s := strings.IndexByte(cs, ':')
if s < 0 {
return
}
return cs[:s], cs[s+1:], true
}
func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp *http.Response, log logger.Logger) (ok bool) {
u, p, _ := h.basicProxyAuth(r.Header.Get("Proxy-Authorization"))
if h.options.Auther == nil || h.options.Auther.Authenticate(u, p) {
return true
}
pr := h.md.probeResistance
// probing resistance is enabled, and knocking host is mismatch.
if pr != nil && (pr.Knock == "" || !strings.EqualFold(r.URL.Hostname(), pr.Knock)) {
resp.StatusCode = http.StatusServiceUnavailable // default status code
switch pr.Type {
case "code":
resp.StatusCode, _ = strconv.Atoi(pr.Value)
case "web":
url := pr.Value
if !strings.HasPrefix(url, "http") {
url = "http://" + url
}
r, err := http.Get(url)
if err != nil {
log.Error(err)
break
}
resp = r
defer resp.Body.Close()
case "host":
cc, err := net.Dial("tcp", pr.Value)
if err != nil {
log.Error(err)
break
}
defer cc.Close()
if err := h.forwardRequest(w, r, cc); err != nil {
log.Error(err)
}
return
case "file":
f, _ := os.Open(pr.Value)
if f != nil {
defer f.Close()
resp.StatusCode = http.StatusOK
if finfo, _ := f.Stat(); finfo != nil {
resp.ContentLength = finfo.Size()
}
resp.Header.Set("Content-Type", "text/html")
resp.Body = f
}
}
}
if resp.StatusCode == 0 {
resp.StatusCode = http.StatusProxyAuthRequired
resp.Header.Add("Proxy-Authenticate", "Basic realm=\"gost\"")
if strings.ToLower(r.Header.Get("Proxy-Connection")) == "keep-alive" {
// XXX libcurl will keep sending auth request in same conn
// which we don't supported yet.
resp.Header.Add("Connection", "close")
resp.Header.Add("Proxy-Connection", "close")
}
log.Info("proxy authentication required")
} else {
resp.Header = http.Header{}
resp.Header.Set("Server", "nginx/1.20.1")
resp.Header.Set("Date", time.Now().Format(http.TimeFormat))
if resp.StatusCode == http.StatusOK {
resp.Header.Set("Connection", "keep-alive")
}
}
if log.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpResponse(resp, false)
log.Debug(string(dump))
}
h.writeResponse(w, resp)
return
}
func (h *http2Handler) forwardRequest(w http.ResponseWriter, r *http.Request, rw io.ReadWriter) (err error) {
if err = r.Write(rw); err != nil {
return
}
resp, err := http.ReadResponse(bufio.NewReader(rw), r)
if err != nil {
return
}
defer resp.Body.Close()
return h.writeResponse(w, resp)
}
func (h *http2Handler) writeResponse(w http.ResponseWriter, resp *http.Response) error {
for k, v := range resp.Header {
for _, vv := range v {
w.Header().Add(k, vv)
}
}
w.WriteHeader(resp.StatusCode)
_, err := io.Copy(flushWriter{w}, resp.Body)
return err
}

47
handler/http2/metadata.go Normal file
View File

@ -0,0 +1,47 @@
package http2
import (
"net/http"
"strings"
mdata "github.com/go-gost/gost/v3/pkg/metadata"
)
type metadata struct {
probeResistance *probeResistance
header http.Header
}
func (h *http2Handler) parseMetadata(md mdata.Metadata) error {
const (
header = "header"
probeResistKey = "probeResistance"
knock = "knock"
)
if m := mdata.GetStringMapString(md, header); len(m) > 0 {
hd := http.Header{}
for k, v := range m {
hd.Add(k, v)
}
h.md.header = hd
}
if v := mdata.GetString(md, probeResistKey); v != "" {
if ss := strings.SplitN(v, ":", 2); len(ss) == 2 {
h.md.probeResistance = &probeResistance{
Type: ss[0],
Value: ss[1],
Knock: mdata.GetString(md, knock),
}
}
}
return nil
}
type probeResistance struct {
Type string
Value string
Knock string
}

112
handler/redirect/handler.go Normal file
View File

@ -0,0 +1,112 @@
package redirect
import (
"context"
"fmt"
"net"
"time"
"github.com/go-gost/gost/v3/pkg/chain"
netpkg "github.com/go-gost/gost/v3/pkg/common/net"
"github.com/go-gost/gost/v3/pkg/handler"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
)
func init() {
registry.HandlerRegistry().Register("red", NewHandler)
registry.HandlerRegistry().Register("redu", NewHandler)
registry.HandlerRegistry().Register("redir", NewHandler)
registry.HandlerRegistry().Register("redirect", NewHandler)
}
type redirectHandler struct {
router *chain.Router
md metadata
options handler.Options
}
func NewHandler(opts ...handler.Option) handler.Handler {
options := handler.Options{}
for _, opt := range opts {
opt(&options)
}
return &redirectHandler{
options: options,
}
}
func (h *redirectHandler) Init(md md.Metadata) (err error) {
if err = h.parseMetadata(md); err != nil {
return
}
h.router = h.options.Router
if h.router == nil {
h.router = (&chain.Router{}).WithLogger(h.options.Logger)
}
return
}
func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
defer conn.Close()
start := time.Now()
log := h.options.Logger.WithFields(map[string]any{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
log.WithFields(map[string]any{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
network := "tcp"
var dstAddr net.Addr
var err error
if _, ok := conn.(net.PacketConn); ok {
network = "udp"
dstAddr = conn.LocalAddr()
}
if network == "tcp" {
dstAddr, conn, err = h.getOriginalDstAddr(conn)
if err != nil {
log.Error(err)
return err
}
}
log = log.WithFields(map[string]any{
"dst": fmt.Sprintf("%s/%s", dstAddr, network),
})
log.Infof("%s >> %s", conn.RemoteAddr(), dstAddr)
if h.options.Bypass != nil && h.options.Bypass.Contains(dstAddr.String()) {
log.Info("bypass: ", dstAddr)
return nil
}
cc, err := h.router.Dial(ctx, network, dstAddr.String())
if err != nil {
log.Error(err)
return err
}
defer cc.Close()
t := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), dstAddr)
netpkg.Transport(conn, cc)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Infof("%s >-< %s", conn.RemoteAddr(), dstAddr)
return nil
}

View File

@ -0,0 +1,40 @@
package redirect
import (
"errors"
"fmt"
"net"
"syscall"
)
func (h *redirectHandler) getOriginalDstAddr(conn net.Conn) (addr net.Addr, c net.Conn, err error) {
defer conn.Close()
tc, ok := conn.(*net.TCPConn)
if !ok {
err = errors.New("wrong connection type, must be TCP")
return
}
fc, err := tc.File()
if err != nil {
return
}
defer fc.Close()
mreq, err := syscall.GetsockoptIPv6Mreq(int(fc.Fd()), syscall.IPPROTO_IP, 80)
if err != nil {
return
}
// only ipv4 support
ip := net.IPv4(mreq.Multiaddr[4], mreq.Multiaddr[5], mreq.Multiaddr[6], mreq.Multiaddr[7])
port := uint16(mreq.Multiaddr[2])<<8 + uint16(mreq.Multiaddr[3])
addr, err = net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", ip.String(), port))
if err != nil {
return
}
c, err = net.FileConn(fc)
return
}

View File

@ -0,0 +1,15 @@
//go:build !linux
package redirect
import (
"errors"
"net"
)
func (h *redirectHandler) getOriginalDstAddr(conn net.Conn) (addr net.Addr, c net.Conn, err error) {
defer conn.Close()
err = errors.New("TCP redirect is not available on non-linux platform")
return
}

View File

@ -0,0 +1,18 @@
package redirect
import (
mdata "github.com/go-gost/gost/v3/pkg/metadata"
)
type metadata struct {
retryCount int
}
func (h *redirectHandler) parseMetadata(md mdata.Metadata) (err error) {
const (
retryCount = "retry"
)
h.md.retryCount = mdata.GetInt(md, retryCount)
return
}

193
handler/relay/bind.go Normal file
View File

@ -0,0 +1,193 @@
package relay
import (
"context"
"fmt"
"net"
"time"
netpkg "github.com/go-gost/gost/v3/pkg/common/net"
net_relay "github.com/go-gost/gost/v3/pkg/common/net/relay"
"github.com/go-gost/gost/v3/pkg/logger"
"github.com/go-gost/relay"
"github.com/go-gost/x/internal/util/mux"
relay_util "github.com/go-gost/x/internal/util/relay"
)
func (h *relayHandler) handleBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error {
log = log.WithFields(map[string]any{
"dst": fmt.Sprintf("%s/%s", address, network),
"cmd": "bind",
})
log.Infof("%s >> %s", conn.RemoteAddr(), address)
resp := relay.Response{
Version: relay.Version1,
Status: relay.StatusOK,
}
if !h.md.enableBind {
resp.Status = relay.StatusForbidden
log.Error("relay: BIND is disabled")
_, err := resp.WriteTo(conn)
return err
}
if network == "tcp" {
return h.bindTCP(ctx, conn, network, address, log)
} else {
return h.bindUDP(ctx, conn, network, address, log)
}
}
func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error {
resp := relay.Response{
Version: relay.Version1,
Status: relay.StatusOK,
}
ln, err := net.Listen(network, address) // strict mode: if the port already in use, it will return error
if err != nil {
log.Error(err)
resp.Status = relay.StatusServiceUnavailable
resp.WriteTo(conn)
return err
}
af := &relay.AddrFeature{}
err = af.ParseFrom(ln.Addr().String())
if err != nil {
log.Warn(err)
}
// Issue: may not reachable when host has multi-interface
af.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String())
af.AType = relay.AddrIPv4
resp.Features = append(resp.Features, af)
if _, err := resp.WriteTo(conn); err != nil {
log.Error(err)
ln.Close()
return err
}
log = log.WithFields(map[string]any{
"bind": fmt.Sprintf("%s/%s", ln.Addr(), ln.Addr().Network()),
})
log.Debugf("bind on %s OK", ln.Addr())
return h.serveTCPBind(ctx, conn, ln, log)
}
func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error {
resp := relay.Response{
Version: relay.Version1,
Status: relay.StatusOK,
}
bindAddr, _ := net.ResolveUDPAddr(network, address)
pc, err := net.ListenUDP(network, bindAddr)
if err != nil {
log.Error(err)
return err
}
defer pc.Close()
af := &relay.AddrFeature{}
err = af.ParseFrom(pc.LocalAddr().String())
if err != nil {
log.Warn(err)
}
// Issue: may not reachable when host has multi-interface
af.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String())
af.AType = relay.AddrIPv4
resp.Features = append(resp.Features, af)
if _, err := resp.WriteTo(conn); err != nil {
log.Error(err)
return err
}
log = log.WithFields(map[string]any{
"bind": pc.LocalAddr().String(),
})
log.Debugf("bind on %s OK", pc.LocalAddr())
r := net_relay.NewUDPRelay(relay_util.UDPTunServerConn(conn), pc).
WithBypass(h.options.Bypass).
WithLogger(log)
r.SetBufferSize(h.md.udpBufferSize)
t := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
r.Run()
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
return nil
}
func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.Listener, log logger.Logger) error {
// Upgrade connection to multiplex stream.
session, err := mux.ClientSession(conn)
if err != nil {
log.Error(err)
return err
}
defer session.Close()
go func() {
defer ln.Close()
for {
conn, err := session.Accept()
if err != nil {
log.Error(err)
return
}
conn.Close() // we do not handle incoming connections.
}
}()
for {
rc, err := ln.Accept()
if err != nil {
log.Error(err)
return err
}
log.Debugf("peer %s accepted", rc.RemoteAddr())
go func(c net.Conn) {
defer c.Close()
log = log.WithFields(map[string]any{
"local": ln.Addr().String(),
"remote": c.RemoteAddr().String(),
})
sc, err := session.GetConn()
if err != nil {
log.Error(err)
return
}
defer sc.Close()
af := &relay.AddrFeature{}
af.ParseFrom(c.RemoteAddr().String())
resp := relay.Response{
Version: relay.Version1,
Status: relay.StatusOK,
Features: []relay.Feature{af},
}
if _, err := resp.WriteTo(sc); err != nil {
log.Error(err)
return
}
t := time.Now()
log.Infof("%s <-> %s", c.LocalAddr(), c.RemoteAddr())
netpkg.Transport(sc, c)
log.WithFields(map[string]any{"duration": time.Since(t)}).
Infof("%s >-< %s", c.LocalAddr(), c.RemoteAddr())
}(rc)
}
}

81
handler/relay/conn.go Normal file
View File

@ -0,0 +1,81 @@
package relay
import (
"bytes"
"encoding/binary"
"errors"
"io"
"math"
"net"
)
type tcpConn struct {
net.Conn
wbuf bytes.Buffer
}
func (c *tcpConn) Read(b []byte) (n int, err error) {
if err != nil {
return
}
return c.Conn.Read(b)
}
func (c *tcpConn) Write(b []byte) (n int, err error) {
n = len(b) // force byte length consistent
if c.wbuf.Len() > 0 {
c.wbuf.Write(b) // append the data to the cached header
_, err = c.wbuf.WriteTo(c.Conn)
return
}
_, err = c.Conn.Write(b)
return
}
type udpConn struct {
net.Conn
wbuf bytes.Buffer
}
func (c *udpConn) Read(b []byte) (n int, err error) {
var bb [2]byte
_, err = io.ReadFull(c.Conn, bb[:])
if err != nil {
return
}
dlen := int(binary.BigEndian.Uint16(bb[:]))
if len(b) >= dlen {
return io.ReadFull(c.Conn, b[:dlen])
}
buf := make([]byte, dlen)
_, err = io.ReadFull(c.Conn, buf)
n = copy(b, buf)
return
}
func (c *udpConn) Write(b []byte) (n int, err error) {
if len(b) > math.MaxUint16 {
err = errors.New("write: data maximum exceeded")
return
}
n = len(b)
if c.wbuf.Len() > 0 {
var bb [2]byte
binary.BigEndian.PutUint16(bb[:], uint16(len(b)))
c.wbuf.Write(bb[:])
c.wbuf.Write(b) // append the data to the cached header
_, err = c.wbuf.WriteTo(c.Conn)
return
}
var bb [2]byte
binary.BigEndian.PutUint16(bb[:], uint16(len(b)))
_, err = c.Conn.Write(bb[:])
if err != nil {
return
}
return c.Conn.Write(b)
}

91
handler/relay/connect.go Normal file
View File

@ -0,0 +1,91 @@
package relay
import (
"context"
"errors"
"fmt"
"net"
"time"
netpkg "github.com/go-gost/gost/v3/pkg/common/net"
"github.com/go-gost/gost/v3/pkg/logger"
"github.com/go-gost/relay"
)
func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error {
log = log.WithFields(map[string]any{
"dst": fmt.Sprintf("%s/%s", address, network),
"cmd": "connect",
})
log.Infof("%s >> %s", conn.RemoteAddr(), address)
resp := relay.Response{
Version: relay.Version1,
Status: relay.StatusOK,
}
if address == "" {
resp.Status = relay.StatusBadRequest
resp.WriteTo(conn)
err := errors.New("target not specified")
log.Error(err)
return err
}
if h.options.Bypass != nil && h.options.Bypass.Contains(address) {
log.Info("bypass: ", address)
resp.Status = relay.StatusForbidden
_, err := resp.WriteTo(conn)
return err
}
cc, err := h.router.Dial(ctx, network, address)
if err != nil {
resp.Status = relay.StatusNetworkUnreachable
resp.WriteTo(conn)
return err
}
defer cc.Close()
if h.md.noDelay {
if _, err := resp.WriteTo(conn); err != nil {
log.Error(err)
return err
}
}
switch network {
case "udp", "udp4", "udp6":
rc := &udpConn{
Conn: conn,
}
if !h.md.noDelay {
// cache the header
if _, err := resp.WriteTo(&rc.wbuf); err != nil {
return err
}
}
conn = rc
default:
rc := &tcpConn{
Conn: conn,
}
if !h.md.noDelay {
// cache the header
if _, err := resp.WriteTo(&rc.wbuf); err != nil {
return err
}
}
conn = rc
}
t := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), address)
netpkg.Transport(conn, cc)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Infof("%s >-< %s", conn.RemoteAddr(), address)
return nil
}

91
handler/relay/forward.go Normal file
View File

@ -0,0 +1,91 @@
package relay
import (
"context"
"errors"
"fmt"
"net"
"time"
netpkg "github.com/go-gost/gost/v3/pkg/common/net"
"github.com/go-gost/gost/v3/pkg/logger"
"github.com/go-gost/relay"
)
func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network string, log logger.Logger) error {
resp := relay.Response{
Version: relay.Version1,
Status: relay.StatusOK,
}
target := h.group.Next()
if target == nil {
resp.Status = relay.StatusServiceUnavailable
resp.WriteTo(conn)
err := errors.New("target not available")
log.Error(err)
return err
}
log = log.WithFields(map[string]any{
"dst": fmt.Sprintf("%s/%s", target.Addr, network),
"cmd": "forward",
})
log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr)
cc, err := h.router.Dial(ctx, network, target.Addr)
if err != nil {
// TODO: the router itself may be failed due to the failed node in the router,
// the dead marker may be a wrong operation.
target.Marker.Mark()
resp.Status = relay.StatusHostUnreachable
resp.WriteTo(conn)
log.Error(err)
return err
}
defer cc.Close()
target.Marker.Reset()
if h.md.noDelay {
if _, err := resp.WriteTo(conn); err != nil {
log.Error(err)
return err
}
}
switch network {
case "udp", "udp4", "udp6":
rc := &udpConn{
Conn: conn,
}
if !h.md.noDelay {
// cache the header
if _, err := resp.WriteTo(&rc.wbuf); err != nil {
return err
}
}
conn = rc
default:
rc := &tcpConn{
Conn: conn,
}
if !h.md.noDelay {
// cache the header
if _, err := resp.WriteTo(&rc.wbuf); err != nil {
return err
}
}
conn = rc
}
t := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr)
netpkg.Transport(conn, cc)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr)
return nil
}

147
handler/relay/handler.go Normal file
View File

@ -0,0 +1,147 @@
package relay
import (
"context"
"errors"
"net"
"strconv"
"time"
"github.com/go-gost/gost/v3/pkg/chain"
"github.com/go-gost/gost/v3/pkg/handler"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
"github.com/go-gost/relay"
)
var (
ErrBadVersion = errors.New("relay: bad version")
ErrUnknownCmd = errors.New("relay: unknown command")
)
func init() {
registry.HandlerRegistry().Register("relay", NewHandler)
}
type relayHandler struct {
group *chain.NodeGroup
router *chain.Router
md metadata
options handler.Options
}
func NewHandler(opts ...handler.Option) handler.Handler {
options := handler.Options{}
for _, opt := range opts {
opt(&options)
}
return &relayHandler{
options: options,
}
}
func (h *relayHandler) Init(md md.Metadata) (err error) {
if err := h.parseMetadata(md); err != nil {
return err
}
h.router = h.options.Router
if h.router == nil {
h.router = (&chain.Router{}).WithLogger(h.options.Logger)
}
return nil
}
// Forward implements handler.Forwarder.
func (h *relayHandler) Forward(group *chain.NodeGroup) {
h.group = group
}
func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
defer conn.Close()
start := time.Now()
log := h.options.Logger.WithFields(map[string]any{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
log.WithFields(map[string]any{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
if h.md.readTimeout > 0 {
conn.SetReadDeadline(time.Now().Add(h.md.readTimeout))
}
req := relay.Request{}
if _, err := req.ReadFrom(conn); err != nil {
log.Error(err)
return err
}
conn.SetReadDeadline(time.Time{})
if req.Version != relay.Version1 {
err := ErrBadVersion
log.Error(err)
return err
}
var user, pass string
var address string
for _, f := range req.Features {
if f.Type() == relay.FeatureUserAuth {
feature := f.(*relay.UserAuthFeature)
user, pass = feature.Username, feature.Password
}
if f.Type() == relay.FeatureAddr {
feature := f.(*relay.AddrFeature)
address = net.JoinHostPort(feature.Host, strconv.Itoa(int(feature.Port)))
}
}
if user != "" {
log = log.WithFields(map[string]any{"user": user})
}
resp := relay.Response{
Version: relay.Version1,
Status: relay.StatusOK,
}
if h.options.Auther != nil && !h.options.Auther.Authenticate(user, pass) {
resp.Status = relay.StatusUnauthorized
log.Error("unauthorized")
_, err := resp.WriteTo(conn)
return err
}
network := "tcp"
if (req.Flags & relay.FUDP) == relay.FUDP {
network = "udp"
}
if h.group != nil {
if address != "" {
resp.Status = relay.StatusForbidden
log.Error("forward mode, connect is forbidden")
_, err := resp.WriteTo(conn)
return err
}
// forward mode
return h.handleForward(ctx, conn, network, log)
}
switch req.Flags & relay.CmdMask {
case 0, relay.CONNECT:
return h.handleConnect(ctx, conn, network, address, log)
case relay.BIND:
return h.handleBind(ctx, conn, network, address, log)
}
return ErrUnknownCmd
}

35
handler/relay/metadata.go Normal file
View File

@ -0,0 +1,35 @@
package relay
import (
"math"
"time"
mdata "github.com/go-gost/gost/v3/pkg/metadata"
)
type metadata struct {
readTimeout time.Duration
enableBind bool
udpBufferSize int
noDelay bool
}
func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) {
const (
readTimeout = "readTimeout"
enableBind = "bind"
udpBufferSize = "udpBufferSize"
noDelay = "nodelay"
)
h.md.readTimeout = mdata.GetDuration(md, readTimeout)
h.md.enableBind = mdata.GetBool(md, enableBind)
h.md.noDelay = mdata.GetBool(md, noDelay)
if bs := mdata.GetInt(md, udpBufferSize); bs > 0 {
h.md.udpBufferSize = int(math.Min(math.Max(float64(bs), 512), 64*1024))
} else {
h.md.udpBufferSize = 1024
}
return
}

20
handler/sni/conn.go Normal file
View File

@ -0,0 +1,20 @@
package sni
import (
"net"
)
type cacheConn struct {
net.Conn
buf []byte
}
func (c *cacheConn) Read(b []byte) (n int, err error) {
if len(c.buf) > 0 {
n = copy(b, c.buf)
c.buf = c.buf[n:]
return
}
return c.Conn.Read(b)
}

222
handler/sni/handler.go Normal file
View File

@ -0,0 +1,222 @@
package sni
import (
"bytes"
"context"
"encoding/base64"
"encoding/binary"
"errors"
"hash/crc32"
"io"
"net"
"time"
"github.com/go-gost/gost/v3/pkg/chain"
"github.com/go-gost/gost/v3/pkg/common/bufpool"
netpkg "github.com/go-gost/gost/v3/pkg/common/net"
"github.com/go-gost/gost/v3/pkg/handler"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
dissector "github.com/go-gost/tls-dissector"
)
func init() {
registry.HandlerRegistry().Register("sni", NewHandler)
}
type sniHandler struct {
httpHandler handler.Handler
router *chain.Router
md metadata
options handler.Options
}
func NewHandler(opts ...handler.Option) handler.Handler {
options := handler.Options{}
for _, opt := range opts {
opt(&options)
}
h := &sniHandler{
options: options,
}
if f := registry.HandlerRegistry().Get("http"); f != nil {
v := append(opts,
handler.LoggerOption(h.options.Logger.WithFields(map[string]any{"type": "http"})))
h.httpHandler = f(v...)
}
return h
}
func (h *sniHandler) Init(md md.Metadata) (err error) {
if err = h.parseMetadata(md); err != nil {
return
}
if h.httpHandler != nil {
if md != nil {
md.Set("sni", true)
}
if err = h.httpHandler.Init(md); err != nil {
return
}
}
h.router = h.options.Router
if h.router == nil {
h.router = (&chain.Router{}).WithLogger(h.options.Logger)
}
return nil
}
func (h *sniHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
defer conn.Close()
start := time.Now()
log := h.options.Logger.WithFields(map[string]any{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
log.WithFields(map[string]any{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
var hdr [dissector.RecordHeaderLen]byte
if _, err := io.ReadFull(conn, hdr[:]); err != nil {
log.Error(err)
return err
}
if hdr[0] != dissector.Handshake {
// We assume it is an HTTP request
conn = &cacheConn{
Conn: conn,
buf: hdr[:],
}
if h.httpHandler != nil {
return h.httpHandler.Handle(ctx, conn)
}
return nil
}
length := binary.BigEndian.Uint16(hdr[3:5])
buf := bufpool.Get(int(length) + dissector.RecordHeaderLen)
defer bufpool.Put(buf)
if _, err := io.ReadFull(conn, (*buf)[dissector.RecordHeaderLen:]); err != nil {
log.Error(err)
return err
}
copy(*buf, hdr[:])
opaque, host, err := h.decodeHost(bytes.NewReader(*buf))
if err != nil {
log.Error(err)
return err
}
target := net.JoinHostPort(host, "443")
log = log.WithFields(map[string]any{
"dst": target,
})
log.Infof("%s >> %s", conn.RemoteAddr(), target)
if h.options.Bypass != nil && h.options.Bypass.Contains(target) {
log.Info("bypass: ", target)
return nil
}
cc, err := h.router.Dial(ctx, "tcp", target)
if err != nil {
log.Error(err)
return err
}
defer cc.Close()
if _, err := cc.Write(opaque); err != nil {
log.Error(err)
return err
}
t := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), target)
netpkg.Transport(conn, cc)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Infof("%s >-< %s", conn.RemoteAddr(), target)
return nil
}
func (h *sniHandler) decodeHost(r io.Reader) (opaque []byte, host string, err error) {
record, err := dissector.ReadRecord(r)
if err != nil {
return
}
clientHello := dissector.ClientHelloMsg{}
if err = clientHello.Decode(record.Opaque); err != nil {
return
}
var extensions []dissector.Extension
for _, ext := range clientHello.Extensions {
if ext.Type() == 0xFFFE {
b, _ := ext.Encode()
if v, err := h.decodeServerName(string(b)); err == nil {
host = v
}
continue
}
extensions = append(extensions, ext)
}
clientHello.Extensions = extensions
for _, ext := range clientHello.Extensions {
if ext.Type() == dissector.ExtServerName {
snExtension := ext.(*dissector.ServerNameExtension)
if host == "" {
host = snExtension.Name
} else {
snExtension.Name = host
}
break
}
}
record.Opaque, err = clientHello.Encode()
if err != nil {
return
}
buf := &bytes.Buffer{}
if _, err = record.WriteTo(buf); err != nil {
return
}
opaque = buf.Bytes()
return
}
func (h *sniHandler) decodeServerName(s string) (string, error) {
b, err := base64.RawURLEncoding.DecodeString(s)
if err != nil {
return "", err
}
if len(b) < 4 {
return "", errors.New("invalid name")
}
v, err := base64.RawURLEncoding.DecodeString(string(b[4:]))
if err != nil {
return "", err
}
if crc32.ChecksumIEEE(v) != binary.BigEndian.Uint32(b[:4]) {
return "", errors.New("invalid name")
}
return string(v), nil
}

20
handler/sni/metadata.go Normal file
View File

@ -0,0 +1,20 @@
package sni
import (
"time"
mdata "github.com/go-gost/gost/v3/pkg/metadata"
)
type metadata struct {
readTimeout time.Duration
}
func (h *sniHandler) parseMetadata(md mdata.Metadata) (err error) {
const (
readTimeout = "readTimeout"
)
h.md.readTimeout = mdata.GetDuration(md, readTimeout)
return
}

119
handler/ss/handler.go Normal file
View File

@ -0,0 +1,119 @@
package ss
import (
"context"
"io"
"io/ioutil"
"net"
"time"
"github.com/go-gost/gosocks5"
"github.com/go-gost/gost/v3/pkg/chain"
netpkg "github.com/go-gost/gost/v3/pkg/common/net"
"github.com/go-gost/gost/v3/pkg/handler"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
"github.com/go-gost/x/internal/util/ss"
"github.com/shadowsocks/go-shadowsocks2/core"
)
func init() {
registry.HandlerRegistry().Register("ss", NewHandler)
}
type ssHandler struct {
cipher core.Cipher
router *chain.Router
md metadata
options handler.Options
}
func NewHandler(opts ...handler.Option) handler.Handler {
options := handler.Options{}
for _, opt := range opts {
opt(&options)
}
return &ssHandler{
options: options,
}
}
func (h *ssHandler) Init(md md.Metadata) (err error) {
if err = h.parseMetadata(md); err != nil {
return
}
if h.options.Auth != nil {
method := h.options.Auth.Username()
password, _ := h.options.Auth.Password()
h.cipher, err = ss.ShadowCipher(method, password, h.md.key)
if err != nil {
return
}
}
h.router = h.options.Router
if h.router == nil {
h.router = (&chain.Router{}).WithLogger(h.options.Logger)
}
return
}
func (h *ssHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
defer conn.Close()
start := time.Now()
log := h.options.Logger.WithFields(map[string]any{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
log.WithFields(map[string]any{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
if h.cipher != nil {
conn = ss.ShadowConn(h.cipher.StreamConn(conn), nil)
}
if h.md.readTimeout > 0 {
conn.SetReadDeadline(time.Now().Add(h.md.readTimeout))
}
addr := &gosocks5.Addr{}
if _, err := addr.ReadFrom(conn); err != nil {
log.Error(err)
io.Copy(ioutil.Discard, conn)
return err
}
log = log.WithFields(map[string]any{
"dst": addr.String(),
})
log.Infof("%s >> %s", conn.RemoteAddr(), addr)
if h.options.Bypass != nil && h.options.Bypass.Contains(addr.String()) {
log.Info("bypass: ", addr.String())
return nil
}
cc, err := h.router.Dial(ctx, "tcp", addr.String())
if err != nil {
return err
}
defer cc.Close()
t := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
netpkg.Transport(conn, cc)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Infof("%s >-< %s", conn.RemoteAddr(), addr)
return nil
}

24
handler/ss/metadata.go Normal file
View File

@ -0,0 +1,24 @@
package ss
import (
"time"
mdata "github.com/go-gost/gost/v3/pkg/metadata"
)
type metadata struct {
key string
readTimeout time.Duration
}
func (h *ssHandler) parseMetadata(md mdata.Metadata) (err error) {
const (
key = "key"
readTimeout = "readTimeout"
)
h.md.key = mdata.GetString(md, key)
h.md.readTimeout = mdata.GetDuration(md, readTimeout)
return
}

188
handler/ss/udp/handler.go Normal file
View File

@ -0,0 +1,188 @@
package ss
import (
"context"
"errors"
"net"
"time"
"github.com/go-gost/gost/v3/pkg/chain"
"github.com/go-gost/gost/v3/pkg/common/bufpool"
"github.com/go-gost/gost/v3/pkg/handler"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
"github.com/go-gost/x/internal/util/relay"
"github.com/go-gost/x/internal/util/ss"
"github.com/shadowsocks/go-shadowsocks2/core"
)
func init() {
registry.HandlerRegistry().Register("ssu", NewHandler)
}
type ssuHandler struct {
cipher core.Cipher
router *chain.Router
md metadata
options handler.Options
}
func NewHandler(opts ...handler.Option) handler.Handler {
options := handler.Options{}
for _, opt := range opts {
opt(&options)
}
return &ssuHandler{
options: options,
}
}
func (h *ssuHandler) Init(md md.Metadata) (err error) {
if err = h.parseMetadata(md); err != nil {
return
}
if h.options.Auth != nil {
method := h.options.Auth.Username()
password, _ := h.options.Auth.Password()
h.cipher, err = ss.ShadowCipher(method, password, h.md.key)
if err != nil {
return
}
}
h.router = h.options.Router
if h.router == nil {
h.router = (&chain.Router{}).WithLogger(h.options.Logger)
}
return
}
func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
defer conn.Close()
start := time.Now()
log := h.options.Logger.WithFields(map[string]any{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
log.WithFields(map[string]any{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
pc, ok := conn.(net.PacketConn)
if ok {
if h.cipher != nil {
pc = h.cipher.PacketConn(pc)
}
// standard UDP relay.
pc = ss.UDPServerConn(pc, conn.RemoteAddr(), h.md.bufferSize)
} else {
if h.cipher != nil {
conn = ss.ShadowConn(h.cipher.StreamConn(conn), nil)
}
// UDP over TCP
pc = relay.UDPTunServerConn(conn)
}
// obtain a udp connection
c, err := h.router.Dial(ctx, "udp", "") // UDP association
if err != nil {
log.Error(err)
return err
}
defer c.Close()
cc, ok := c.(net.PacketConn)
if !ok {
err := errors.New("ss: wrong connection type")
log.Error(err)
return err
}
t := time.Now()
log.Infof("%s <-> %s", conn.LocalAddr(), cc.LocalAddr())
h.relayPacket(pc, cc, log)
log.WithFields(map[string]any{"duration": time.Since(t)}).
Infof("%s >-< %s", conn.LocalAddr(), cc.LocalAddr())
return nil
}
func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn, log logger.Logger) (err error) {
bufSize := h.md.bufferSize
errc := make(chan error, 2)
go func() {
for {
err := func() error {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
n, addr, err := pc1.ReadFrom(*b)
if err != nil {
return err
}
if h.options.Bypass != nil && h.options.Bypass.Contains(addr.String()) {
log.Warn("bypass: ", addr)
return nil
}
if _, err = pc2.WriteTo((*b)[:n], addr); err != nil {
return err
}
log.Debugf("%s >>> %s data: %d",
pc2.LocalAddr(), addr, n)
return nil
}()
if err != nil {
errc <- err
return
}
}
}()
go func() {
for {
err := func() error {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
n, raddr, err := pc2.ReadFrom(*b)
if err != nil {
return err
}
if h.options.Bypass != nil && h.options.Bypass.Contains(raddr.String()) {
log.Warn("bypass: ", raddr)
return nil
}
if _, err = pc1.WriteTo((*b)[:n], raddr); err != nil {
return err
}
log.Debugf("%s <<< %s data: %d",
pc2.LocalAddr(), raddr, n)
return nil
}()
if err != nil {
errc <- err
return
}
}
}()
return <-errc
}

View File

@ -0,0 +1,32 @@
package ss
import (
"math"
"time"
mdata "github.com/go-gost/gost/v3/pkg/metadata"
)
type metadata struct {
key string
readTimeout time.Duration
bufferSize int
}
func (h *ssuHandler) parseMetadata(md mdata.Metadata) (err error) {
const (
key = "key"
readTimeout = "readTimeout"
bufferSize = "bufferSize"
)
h.md.key = mdata.GetString(md, key)
h.md.readTimeout = mdata.GetDuration(md, readTimeout)
if bs := mdata.GetInt(md, bufferSize); bs > 0 {
h.md.bufferSize = int(math.Min(math.Max(float64(bs), 512), 64*1024))
} else {
h.md.bufferSize = 1024
}
return
}

245
handler/sshd/handler.go Normal file
View File

@ -0,0 +1,245 @@
package ssh
import (
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"strconv"
"time"
"github.com/go-gost/gost/v3/pkg/chain"
netpkg "github.com/go-gost/gost/v3/pkg/common/net"
"github.com/go-gost/gost/v3/pkg/handler"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
sshd_util "github.com/go-gost/x/internal/util/sshd"
"golang.org/x/crypto/ssh"
)
// Applicable SSH Request types for Port Forwarding - RFC 4254 7.X
const (
ForwardedTCPReturnRequest = "forwarded-tcpip" // RFC 4254 7.2
)
func init() {
registry.HandlerRegistry().Register("sshd", NewHandler)
}
type forwardHandler struct {
router *chain.Router
md metadata
options handler.Options
}
func NewHandler(opts ...handler.Option) handler.Handler {
options := handler.Options{}
for _, opt := range opts {
opt(&options)
}
return &forwardHandler{
options: options,
}
}
func (h *forwardHandler) Init(md md.Metadata) (err error) {
if err = h.parseMetadata(md); err != nil {
return
}
h.router = h.options.Router
if h.router == nil {
h.router = (&chain.Router{}).WithLogger(h.options.Logger)
}
return nil
}
func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
defer conn.Close()
log := h.options.Logger.WithFields(map[string]any{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
switch cc := conn.(type) {
case *sshd_util.DirectForwardConn:
return h.handleDirectForward(ctx, cc, log)
case *sshd_util.RemoteForwardConn:
return h.handleRemoteForward(ctx, cc, log)
default:
err := errors.New("sshd: wrong connection type")
log.Error(err)
return err
}
}
func (h *forwardHandler) handleDirectForward(ctx context.Context, conn *sshd_util.DirectForwardConn, log logger.Logger) error {
targetAddr := conn.DstAddr()
log = log.WithFields(map[string]any{
"dst": fmt.Sprintf("%s/%s", targetAddr, "tcp"),
"cmd": "connect",
})
log.Infof("%s >> %s", conn.RemoteAddr(), targetAddr)
if h.options.Bypass != nil && h.options.Bypass.Contains(targetAddr) {
log.Infof("bypass %s", targetAddr)
return nil
}
cc, err := h.router.Dial(ctx, "tcp", targetAddr)
if err != nil {
return err
}
defer cc.Close()
t := time.Now()
log.Infof("%s <-> %s", cc.LocalAddr(), targetAddr)
netpkg.Transport(conn, cc)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Infof("%s >-< %s", cc.LocalAddr(), targetAddr)
return nil
}
func (h *forwardHandler) handleRemoteForward(ctx context.Context, conn *sshd_util.RemoteForwardConn, log logger.Logger) error {
req := conn.Request()
t := tcpipForward{}
if err := ssh.Unmarshal(req.Payload, &t); err != nil {
log.Error(err)
return err
}
network := "tcp"
addr := net.JoinHostPort(t.Host, strconv.Itoa(int(t.Port)))
log = log.WithFields(map[string]any{
"dst": fmt.Sprintf("%s/%s", addr, network),
"cmd": "bind",
})
log.Infof("%s >> %s", conn.RemoteAddr(), addr)
// tie to the client connection
ln, err := net.Listen(network, addr)
if err != nil {
log.Error(err)
req.Reply(false, nil)
return err
}
defer ln.Close()
log = log.WithFields(map[string]any{
"bind": fmt.Sprintf("%s/%s", ln.Addr(), ln.Addr().Network()),
})
log.Debugf("bind on %s OK", ln.Addr())
err = func() error {
if t.Port == 0 && req.WantReply { // Client sent port 0. let them know which port is actually being used
_, port, err := getHostPortFromAddr(ln.Addr())
if err != nil {
return err
}
var b [4]byte
binary.BigEndian.PutUint32(b[:], uint32(port))
t.Port = uint32(port)
return req.Reply(true, b[:])
}
return req.Reply(true, nil)
}()
if err != nil {
log.Error(err)
return err
}
sshConn := conn.Conn()
go func() {
for {
cc, err := ln.Accept()
if err != nil { // Unable to accept new connection - listener is likely closed
return
}
go func(conn net.Conn) {
defer conn.Close()
log := log.WithFields(map[string]any{
"local": conn.LocalAddr().String(),
"remote": conn.RemoteAddr().String(),
})
p := directForward{}
var err error
var portnum int
p.Host1 = t.Host
p.Port1 = t.Port
p.Host2, portnum, err = getHostPortFromAddr(conn.RemoteAddr())
if err != nil {
return
}
p.Port2 = uint32(portnum)
ch, reqs, err := sshConn.OpenChannel(ForwardedTCPReturnRequest, ssh.Marshal(p))
if err != nil {
log.Error("open forwarded channel: ", err)
return
}
defer ch.Close()
go ssh.DiscardRequests(reqs)
t := time.Now()
log.Infof("%s <-> %s", conn.LocalAddr(), conn.RemoteAddr())
netpkg.Transport(ch, conn)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Infof("%s >-< %s", conn.LocalAddr(), conn.RemoteAddr())
}(cc)
}
}()
tm := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
<-conn.Done()
log.WithFields(map[string]any{
"duration": time.Since(tm),
}).Infof("%s >-< %s", conn.RemoteAddr(), addr)
return nil
}
func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) {
host, portString, err := net.SplitHostPort(addr.String())
if err != nil {
return
}
port, err = strconv.Atoi(portString)
return
}
// directForward is structure for RFC 4254 7.2 - can be used for "forwarded-tcpip" and "direct-tcpip"
type directForward struct {
Host1 string
Port1 uint32
Host2 string
Port2 uint32
}
func (p directForward) String() string {
return fmt.Sprintf("%s:%d -> %s:%d", p.Host2, p.Port2, p.Host1, p.Port1)
}
// tcpipForward is structure for RFC 4254 7.1 "tcpip-forward" request
type tcpipForward struct {
Host string
Port uint32
}

12
handler/sshd/metadata.go Normal file
View File

@ -0,0 +1,12 @@
package ssh
import (
mdata "github.com/go-gost/gost/v3/pkg/metadata"
)
type metadata struct {
}
func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) {
return
}

341
handler/tap/handler.go Normal file
View File

@ -0,0 +1,341 @@
package tap
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"sync"
"time"
"github.com/go-gost/gost/v3/pkg/chain"
"github.com/go-gost/gost/v3/pkg/common/bufpool"
"github.com/go-gost/gost/v3/pkg/handler"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
"github.com/go-gost/x/internal/util/ss"
tap_util "github.com/go-gost/x/internal/util/tap"
"github.com/shadowsocks/go-shadowsocks2/core"
"github.com/shadowsocks/go-shadowsocks2/shadowaead"
"github.com/songgao/water/waterutil"
)
func init() {
registry.HandlerRegistry().Register("tap", NewHandler)
}
type tapHandler struct {
group *chain.NodeGroup
routes sync.Map
exit chan struct{}
cipher core.Cipher
router *chain.Router
md metadata
options handler.Options
}
func NewHandler(opts ...handler.Option) handler.Handler {
options := handler.Options{}
for _, opt := range opts {
opt(&options)
}
return &tapHandler{
exit: make(chan struct{}, 1),
options: options,
}
}
func (h *tapHandler) Init(md md.Metadata) (err error) {
if err = h.parseMetadata(md); err != nil {
return
}
if h.options.Auth != nil {
method := h.options.Auth.Username()
password, _ := h.options.Auth.Password()
h.cipher, err = ss.ShadowCipher(method, password, h.md.key)
if err != nil {
return
}
}
h.router = h.options.Router
if h.router == nil {
h.router = (&chain.Router{}).WithLogger(h.options.Logger)
}
return
}
// Forward implements handler.Forwarder.
func (h *tapHandler) Forward(group *chain.NodeGroup) {
h.group = group
}
func (h *tapHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
defer os.Exit(0)
defer conn.Close()
log := h.options.Logger
cc, ok := conn.(*tap_util.Conn)
if !ok || cc.Config() == nil {
err := errors.New("tap: wrong connection type")
log.Error(err)
return err
}
start := time.Now()
log = log.WithFields(map[string]any{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
log.WithFields(map[string]any{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
network := "udp"
var raddr net.Addr
var err error
target := h.group.Next()
if target != nil {
raddr, err = net.ResolveUDPAddr(network, target.Addr)
if err != nil {
log.Error(err)
return err
}
log = log.WithFields(map[string]any{
"dst": fmt.Sprintf("%s/%s", raddr.String(), raddr.Network()),
})
log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr)
}
h.handleLoop(ctx, conn, raddr, cc.Config(), log)
return nil
}
func (h *tapHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Addr, config *tap_util.Config, log logger.Logger) {
var tempDelay time.Duration
for {
err := func() error {
var err error
var pc net.PacketConn
if addr != nil {
cc, err := h.router.Dial(ctx, addr.Network(), "")
if err != nil {
return err
}
var ok bool
pc, ok = cc.(net.PacketConn)
if !ok {
return errors.New("wrong connection type")
}
} else {
laddr, _ := net.ResolveUDPAddr("udp", conn.LocalAddr().String())
pc, err = net.ListenUDP("udp", laddr)
}
if err != nil {
return err
}
if h.cipher != nil {
pc = h.cipher.PacketConn(pc)
}
defer pc.Close()
return h.transport(conn, pc, addr, config, log)
}()
if err != nil {
log.Error(err)
}
select {
case <-h.exit:
return
default:
}
if err != nil {
if tempDelay == 0 {
tempDelay = 1000 * time.Millisecond
} else {
tempDelay *= 2
}
if max := 6 * time.Second; tempDelay > max {
tempDelay = max
}
time.Sleep(tempDelay)
continue
}
tempDelay = 0
}
}
func (h *tapHandler) transport(tap net.Conn, conn net.PacketConn, raddr net.Addr, config *tap_util.Config, log logger.Logger) error {
errc := make(chan error, 1)
go func() {
for {
err := func() error {
b := bufpool.Get(h.md.bufferSize)
defer bufpool.Put(b)
n, err := tap.Read(*b)
if err != nil {
select {
case h.exit <- struct{}{}:
default:
}
return err
}
src := waterutil.MACSource((*b)[:n])
dst := waterutil.MACDestination((*b)[:n])
eType := etherType(waterutil.MACEthertype((*b)[:n]))
log.Debugf("%s >> %s %s %d", src, dst, eType, n)
// client side, deliver frame directly.
if raddr != nil {
_, err := conn.WriteTo((*b)[:n], raddr)
return err
}
// server side, broadcast.
if waterutil.IsBroadcast(dst) {
go h.routes.Range(func(k, v any) bool {
conn.WriteTo((*b)[:n], v.(net.Addr))
return true
})
return nil
}
var addr net.Addr
if v, ok := h.routes.Load(hwAddrToTapRouteKey(dst)); ok {
addr = v.(net.Addr)
}
if addr == nil {
log.Warnf("no route for %s -> %s %s %d", src, dst, eType, n)
return nil
}
if _, err := conn.WriteTo((*b)[:n], addr); err != nil {
return err
}
return nil
}()
if err != nil {
errc <- err
return
}
}
}()
go func() {
for {
err := func() error {
b := bufpool.Get(h.md.bufferSize)
defer bufpool.Put(b)
n, addr, err := conn.ReadFrom(*b)
if err != nil &&
err != shadowaead.ErrShortPacket {
return err
}
src := waterutil.MACSource((*b)[:n])
dst := waterutil.MACDestination((*b)[:n])
eType := etherType(waterutil.MACEthertype((*b)[:n]))
log.Debugf("%s >> %s %s %d", src, dst, eType, n)
// client side, deliver frame to tap device.
if raddr != nil {
_, err := tap.Write((*b)[:n])
return err
}
// server side, record route.
rkey := hwAddrToTapRouteKey(src)
if actual, loaded := h.routes.LoadOrStore(rkey, addr); loaded {
if actual.(net.Addr).String() != addr.String() {
log.Debugf("update route: %s -> %s (old %s)",
src, addr, actual.(net.Addr))
h.routes.Store(rkey, addr)
}
} else {
log.Debugf("new route: %s -> %s", src, addr)
}
if waterutil.IsBroadcast(dst) {
go h.routes.Range(func(k, v any) bool {
if k.(tapRouteKey) != rkey {
conn.WriteTo((*b)[:n], v.(net.Addr))
}
return true
})
}
if v, ok := h.routes.Load(hwAddrToTapRouteKey(dst)); ok {
log.Debugf("find route: %s -> %s", dst, v)
_, err := conn.WriteTo((*b)[:n], v.(net.Addr))
return err
}
if _, err := tap.Write((*b)[:n]); err != nil {
select {
case h.exit <- struct{}{}:
default:
}
return err
}
return nil
}()
if err != nil {
errc <- err
return
}
}
}()
err := <-errc
if err != nil && err == io.EOF {
err = nil
}
return err
}
var mEtherTypes = map[waterutil.Ethertype]string{
waterutil.IPv4: "ip",
waterutil.ARP: "arp",
waterutil.RARP: "rarp",
waterutil.IPv6: "ip6",
}
func etherType(et waterutil.Ethertype) string {
if s, ok := mEtherTypes[et]; ok {
return s
}
return fmt.Sprintf("unknown(%v)", et)
}
type tapRouteKey [6]byte
func hwAddrToTapRouteKey(addr net.HardwareAddr) (key tapRouteKey) {
copy(key[:], addr)
return
}

24
handler/tap/metadata.go Normal file
View File

@ -0,0 +1,24 @@
package tap
import (
mdata "github.com/go-gost/gost/v3/pkg/metadata"
)
type metadata struct {
key string
bufferSize int
}
func (h *tapHandler) parseMetadata(md mdata.Metadata) (err error) {
const (
key = "key"
bufferSize = "bufferSize"
)
h.md.key = mdata.GetString(md, key)
h.md.bufferSize = mdata.GetInt(md, bufferSize)
if h.md.bufferSize <= 0 {
h.md.bufferSize = 1500
}
return
}

391
handler/tun/handler.go Normal file
View File

@ -0,0 +1,391 @@
package tun
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"sync"
"time"
"github.com/go-gost/gost/v3/pkg/chain"
"github.com/go-gost/gost/v3/pkg/common/bufpool"
"github.com/go-gost/gost/v3/pkg/handler"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
"github.com/go-gost/x/internal/util/ss"
tun_util "github.com/go-gost/x/internal/util/tun"
"github.com/shadowsocks/go-shadowsocks2/core"
"github.com/shadowsocks/go-shadowsocks2/shadowaead"
"github.com/songgao/water/waterutil"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
func init() {
registry.HandlerRegistry().Register("tun", NewHandler)
}
type tunHandler struct {
group *chain.NodeGroup
routes sync.Map
exit chan struct{}
cipher core.Cipher
router *chain.Router
md metadata
options handler.Options
}
func NewHandler(opts ...handler.Option) handler.Handler {
options := handler.Options{}
for _, opt := range opts {
opt(&options)
}
return &tunHandler{
exit: make(chan struct{}, 1),
options: options,
}
}
func (h *tunHandler) Init(md md.Metadata) (err error) {
if err = h.parseMetadata(md); err != nil {
return
}
if h.options.Auth != nil {
method := h.options.Auth.Username()
password, _ := h.options.Auth.Password()
h.cipher, err = ss.ShadowCipher(method, password, h.md.key)
if err != nil {
return
}
}
h.router = h.options.Router
if h.router == nil {
h.router = (&chain.Router{}).WithLogger(h.options.Logger)
}
return
}
// Forward implements handler.Forwarder.
func (h *tunHandler) Forward(group *chain.NodeGroup) {
h.group = group
}
func (h *tunHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
defer os.Exit(0)
defer conn.Close()
log := h.options.Logger
cc, ok := conn.(*tun_util.Conn)
if !ok || cc.Config() == nil {
err := errors.New("tun: wrong connection type")
log.Error(err)
return err
}
start := time.Now()
log = log.WithFields(map[string]any{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
log.WithFields(map[string]any{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
network := "udp"
var raddr net.Addr
var err error
target := h.group.Next()
if target != nil {
raddr, err = net.ResolveUDPAddr(network, target.Addr)
if err != nil {
log.Error(err)
return err
}
log = log.WithFields(map[string]any{
"dst": fmt.Sprintf("%s/%s", raddr.String(), raddr.Network()),
})
log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr)
}
h.handleLoop(ctx, conn, raddr, cc.Config(), log)
return nil
}
func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Addr, config *tun_util.Config, log logger.Logger) {
var tempDelay time.Duration
for {
err := func() error {
var err error
var pc net.PacketConn
if addr != nil {
cc, err := h.router.Dial(ctx, addr.Network(), "")
if err != nil {
return err
}
var ok bool
pc, ok = cc.(net.PacketConn)
if !ok {
cc.Close()
return errors.New("wrong connection type")
}
} else {
laddr, _ := net.ResolveUDPAddr("udp", conn.LocalAddr().String())
pc, err = net.ListenUDP("udp", laddr)
}
if err != nil {
return err
}
if h.cipher != nil {
pc = h.cipher.PacketConn(pc)
}
defer pc.Close()
return h.transport(conn, pc, addr, config, log)
}()
if err != nil {
log.Error(err)
}
select {
case <-h.exit:
return
default:
}
if err != nil {
if tempDelay == 0 {
tempDelay = 1000 * time.Millisecond
} else {
tempDelay *= 2
}
if max := 6 * time.Second; tempDelay > max {
tempDelay = max
}
time.Sleep(tempDelay)
continue
}
tempDelay = 0
}
}
func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr, config *tun_util.Config, log logger.Logger) error {
errc := make(chan error, 1)
go func() {
for {
err := func() error {
b := bufpool.Get(h.md.bufferSize)
defer bufpool.Put(b)
n, err := tun.Read(*b)
if err != nil {
select {
case h.exit <- struct{}{}:
default:
}
return err
}
var src, dst net.IP
if waterutil.IsIPv4((*b)[:n]) {
header, err := ipv4.ParseHeader((*b)[:n])
if err != nil {
log.Error(err)
return nil
}
log.Debugf("%s >> %s %-4s %d/%-4d %-4x %d",
header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])),
header.Len, header.TotalLen, header.ID, header.Flags)
src, dst = header.Src, header.Dst
} else if waterutil.IsIPv6((*b)[:n]) {
header, err := ipv6.ParseHeader((*b)[:n])
if err != nil {
log.Warn(err)
return nil
}
log.Debugf("%s >> %s %s %d %d",
header.Src, header.Dst,
ipProtocol(waterutil.IPProtocol(header.NextHeader)),
header.PayloadLen, header.TrafficClass)
src, dst = header.Src, header.Dst
} else {
log.Warn("unknown packet, discarded")
return nil
}
// client side, deliver packet directly.
if raddr != nil {
_, err := conn.WriteTo((*b)[:n], raddr)
return err
}
addr := h.findRouteFor(dst, config.Routes...)
if addr == nil {
log.Warnf("no route for %s -> %s", src, dst)
return nil
}
log.Debugf("find route: %s -> %s", dst, addr)
if _, err := conn.WriteTo((*b)[:n], addr); err != nil {
return err
}
return nil
}()
if err != nil {
errc <- err
return
}
}
}()
go func() {
for {
err := func() error {
b := bufpool.Get(h.md.bufferSize)
defer bufpool.Put(b)
n, addr, err := conn.ReadFrom(*b)
if err != nil &&
err != shadowaead.ErrShortPacket {
return err
}
var src, dst net.IP
if waterutil.IsIPv4((*b)[:n]) {
header, err := ipv4.ParseHeader((*b)[:n])
if err != nil {
log.Warn(err)
return nil
}
log.Debugf("%s >> %s %-4s %d/%-4d %-4x %d",
header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])),
header.Len, header.TotalLen, header.ID, header.Flags)
src, dst = header.Src, header.Dst
} else if waterutil.IsIPv6((*b)[:n]) {
header, err := ipv6.ParseHeader((*b)[:n])
if err != nil {
log.Warn(err)
return nil
}
log.Debugf("%s > %s %s %d %d",
header.Src, header.Dst,
ipProtocol(waterutil.IPProtocol(header.NextHeader)),
header.PayloadLen, header.TrafficClass)
src, dst = header.Src, header.Dst
} else {
log.Warn("unknown packet, discarded")
return nil
}
// client side, deliver packet to tun device.
if raddr != nil {
_, err := tun.Write((*b)[:n])
return err
}
rkey := ipToTunRouteKey(src)
if actual, loaded := h.routes.LoadOrStore(rkey, addr); loaded {
if actual.(net.Addr).String() != addr.String() {
log.Debugf("update route: %s -> %s (old %s)",
src, addr, actual.(net.Addr))
h.routes.Store(rkey, addr)
}
} else {
log.Warnf("no route for %s -> %s", src, addr)
}
if addr := h.findRouteFor(dst, config.Routes...); addr != nil {
log.Debugf("find route: %s -> %s", dst, addr)
_, err := conn.WriteTo((*b)[:n], addr)
return err
}
if _, err := tun.Write((*b)[:n]); err != nil {
select {
case h.exit <- struct{}{}:
default:
}
return err
}
return nil
}()
if err != nil {
errc <- err
return
}
}
}()
err := <-errc
if err != nil && err == io.EOF {
err = nil
}
return err
}
func (h *tunHandler) findRouteFor(dst net.IP, routes ...tun_util.Route) net.Addr {
if v, ok := h.routes.Load(ipToTunRouteKey(dst)); ok {
return v.(net.Addr)
}
for _, route := range routes {
if route.Net.Contains(dst) && route.Gateway != nil {
if v, ok := h.routes.Load(ipToTunRouteKey(route.Gateway)); ok {
return v.(net.Addr)
}
}
}
return nil
}
var mIPProts = map[waterutil.IPProtocol]string{
waterutil.HOPOPT: "HOPOPT",
waterutil.ICMP: "ICMP",
waterutil.IGMP: "IGMP",
waterutil.GGP: "GGP",
waterutil.TCP: "TCP",
waterutil.UDP: "UDP",
waterutil.IPv6_Route: "IPv6-Route",
waterutil.IPv6_Frag: "IPv6-Frag",
waterutil.IPv6_ICMP: "IPv6-ICMP",
}
func ipProtocol(p waterutil.IPProtocol) string {
if v, ok := mIPProts[p]; ok {
return v
}
return fmt.Sprintf("unknown(%d)", p)
}
type tunRouteKey [16]byte
func ipToTunRouteKey(ip net.IP) (key tunRouteKey) {
copy(key[:], ip.To16())
return
}

24
handler/tun/metadata.go Normal file
View File

@ -0,0 +1,24 @@
package tun
import (
mdata "github.com/go-gost/gost/v3/pkg/metadata"
)
type metadata struct {
key string
bufferSize int
}
func (h *tunHandler) parseMetadata(md mdata.Metadata) (err error) {
const (
key = "key"
bufferSize = "bufferSize"
)
h.md.key = mdata.GetString(md, key)
h.md.bufferSize = mdata.GetInt(md, bufferSize)
if h.md.bufferSize <= 0 {
h.md.bufferSize = 1500
}
return
}