initial commit
This commit is contained in:
280
handler/dns/handler.go
Normal file
280
handler/dns/handler.go
Normal file
@ -0,0 +1,280 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/chain"
|
||||
"github.com/go-gost/gost/v3/pkg/common/bufpool"
|
||||
resolver_util "github.com/go-gost/gost/v3/pkg/common/util/resolver"
|
||||
"github.com/go-gost/gost/v3/pkg/handler"
|
||||
"github.com/go-gost/gost/v3/pkg/hosts"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
"github.com/go-gost/gost/v3/pkg/resolver/exchanger"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultNameserver = "udp://127.0.0.1:53"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.HandlerRegistry().Register("dns", NewHandler)
|
||||
}
|
||||
|
||||
type dnsHandler struct {
|
||||
exchangers []exchanger.Exchanger
|
||||
cache *resolver_util.Cache
|
||||
router *chain.Router
|
||||
hosts hosts.HostMapper
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
|
||||
func NewHandler(opts ...handler.Option) handler.Handler {
|
||||
options := handler.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
return &dnsHandler{
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *dnsHandler) Init(md md.Metadata) (err error) {
|
||||
if err = h.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
log := h.options.Logger
|
||||
|
||||
h.cache = resolver_util.NewCache().WithLogger(log)
|
||||
|
||||
h.router = h.options.Router
|
||||
if h.router == nil {
|
||||
h.router = (&chain.Router{}).WithLogger(log)
|
||||
}
|
||||
h.hosts = h.router.Hosts()
|
||||
|
||||
for _, server := range h.md.dns {
|
||||
server = strings.TrimSpace(server)
|
||||
if server == "" {
|
||||
continue
|
||||
}
|
||||
ex, err := exchanger.NewExchanger(
|
||||
server,
|
||||
exchanger.RouterOption(h.router),
|
||||
exchanger.TimeoutOption(h.md.timeout),
|
||||
exchanger.LoggerOption(log),
|
||||
)
|
||||
if err != nil {
|
||||
log.Warnf("parse %s: %v", server, err)
|
||||
continue
|
||||
}
|
||||
h.exchangers = append(h.exchangers, ex)
|
||||
}
|
||||
if len(h.exchangers) == 0 {
|
||||
ex, err := exchanger.NewExchanger(
|
||||
defaultNameserver,
|
||||
exchanger.RouterOption(h.router),
|
||||
exchanger.TimeoutOption(h.md.timeout),
|
||||
exchanger.LoggerOption(log),
|
||||
)
|
||||
log.Warnf("resolver not found, default to %s", defaultNameserver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.exchangers = append(h.exchangers, ex)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
|
||||
defer conn.Close()
|
||||
|
||||
start := time.Now()
|
||||
log := h.options.Logger.WithFields(map[string]any{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
|
||||
b := bufpool.Get(4096)
|
||||
defer bufpool.Put(b)
|
||||
|
||||
n, err := conn.Read(*b)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
reply, err := h.exchange(ctx, (*b)[:n], log)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer bufpool.Put(&reply)
|
||||
|
||||
if _, err = conn.Write(reply); err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger) ([]byte, error) {
|
||||
mq := dns.Msg{}
|
||||
if err := mq.Unpack(msg); err != nil {
|
||||
log.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(mq.Question) == 0 {
|
||||
return nil, errors.New("msg: empty question")
|
||||
}
|
||||
|
||||
resolver_util.AddSubnetOpt(&mq, h.md.clientIP)
|
||||
|
||||
if log.IsLevelEnabled(logger.DebugLevel) {
|
||||
log.Debug(mq.String())
|
||||
}
|
||||
|
||||
var mr *dns.Msg
|
||||
|
||||
if log.IsLevelEnabled(logger.DebugLevel) {
|
||||
defer func() {
|
||||
if mr != nil {
|
||||
log.Debug(mr.String())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
mr = h.lookupHosts(&mq, log)
|
||||
if mr != nil {
|
||||
b := bufpool.Get(4096)
|
||||
return mr.PackBuffer(*b)
|
||||
}
|
||||
|
||||
// only cache for single question message.
|
||||
if len(mq.Question) == 1 {
|
||||
key := resolver_util.NewCacheKey(&mq.Question[0])
|
||||
mr = h.cache.Load(key)
|
||||
if mr != nil {
|
||||
log.Debugf("exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
|
||||
mr.Id = mq.Id
|
||||
|
||||
b := bufpool.Get(4096)
|
||||
return mr.PackBuffer(*b)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if mr != nil {
|
||||
h.cache.Store(key, mr, h.md.ttl)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
b := bufpool.Get(4096)
|
||||
defer bufpool.Put(b)
|
||||
|
||||
query, err := mq.PackBuffer(*b)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var reply []byte
|
||||
for _, ex := range h.exchangers {
|
||||
log.Debugf("exchange message %d via %s: %s", mq.Id, ex.String(), mq.Question[0].String())
|
||||
reply, err = ex.Exchange(ctx, query)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
log.Error(err)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mr = &dns.Msg{}
|
||||
if err = mr.Unpack(reply); err != nil {
|
||||
log.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// lookup host mapper
|
||||
func (h *dnsHandler) lookupHosts(r *dns.Msg, log logger.Logger) (m *dns.Msg) {
|
||||
if h.hosts == nil ||
|
||||
r.Question[0].Qclass != dns.ClassINET ||
|
||||
(r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA) {
|
||||
return nil
|
||||
}
|
||||
|
||||
m = &dns.Msg{}
|
||||
m.SetReply(r)
|
||||
|
||||
host := strings.TrimSuffix(r.Question[0].Name, ".")
|
||||
|
||||
switch r.Question[0].Qtype {
|
||||
case dns.TypeA:
|
||||
ips, _ := h.hosts.Lookup("ip4", host)
|
||||
if len(ips) == 0 {
|
||||
return nil
|
||||
}
|
||||
log.Debugf("hit host mapper: %s -> %s", host, ips)
|
||||
|
||||
for _, ip := range ips {
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s IN A %s\n", r.Question[0].Name, ip.String()))
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return nil
|
||||
}
|
||||
m.Answer = append(m.Answer, rr)
|
||||
}
|
||||
|
||||
case dns.TypeAAAA:
|
||||
ips, _ := h.hosts.Lookup("ip6", host)
|
||||
if len(ips) == 0 {
|
||||
return nil
|
||||
}
|
||||
log.Debugf("hit host mapper: %s -> %s", host, ips)
|
||||
|
||||
for _, ip := range ips {
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s IN AAAA %s\n", r.Question[0].Name, ip.String()))
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return nil
|
||||
}
|
||||
m.Answer = append(m.Answer, rr)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (h *dnsHandler) dumpMsgHeader(m *dns.Msg) string {
|
||||
buf := new(bytes.Buffer)
|
||||
buf.WriteString(m.MsgHdr.String() + " ")
|
||||
buf.WriteString("QUERY: " + strconv.Itoa(len(m.Question)) + ", ")
|
||||
buf.WriteString("ANSWER: " + strconv.Itoa(len(m.Answer)) + ", ")
|
||||
buf.WriteString("AUTHORITY: " + strconv.Itoa(len(m.Ns)) + ", ")
|
||||
buf.WriteString("ADDITIONAL: " + strconv.Itoa(len(m.Extra)))
|
||||
return buf.String()
|
||||
}
|
41
handler/dns/metadata.go
Normal file
41
handler/dns/metadata.go
Normal file
@ -0,0 +1,41 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
readTimeout time.Duration
|
||||
ttl time.Duration
|
||||
timeout time.Duration
|
||||
clientIP net.IP
|
||||
// nameservers
|
||||
dns []string
|
||||
}
|
||||
|
||||
func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
readTimeout = "readTimeout"
|
||||
ttl = "ttl"
|
||||
timeout = "timeout"
|
||||
clientIP = "clientIP"
|
||||
dns = "dns"
|
||||
)
|
||||
|
||||
h.md.readTimeout = mdata.GetDuration(md, readTimeout)
|
||||
h.md.ttl = mdata.GetDuration(md, ttl)
|
||||
h.md.timeout = mdata.GetDuration(md, timeout)
|
||||
if h.md.timeout <= 0 {
|
||||
h.md.timeout = 5 * time.Second
|
||||
}
|
||||
sip := mdata.GetString(md, clientIP)
|
||||
if sip != "" {
|
||||
h.md.clientIP = net.ParseIP(sip)
|
||||
}
|
||||
h.md.dns = mdata.GetStrings(md, dns)
|
||||
|
||||
return
|
||||
}
|
46
handler/http2/conn.go
Normal file
46
handler/http2/conn.go
Normal file
@ -0,0 +1,46 @@
|
||||
package http2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type readWriter struct {
|
||||
r io.Reader
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func (rw *readWriter) Read(p []byte) (n int, err error) {
|
||||
return rw.r.Read(p)
|
||||
}
|
||||
|
||||
func (rw *readWriter) Write(p []byte) (n int, err error) {
|
||||
return rw.w.Write(p)
|
||||
}
|
||||
|
||||
type flushWriter struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func (fw flushWriter) Write(p []byte) (n int, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if s, ok := r.(string); ok {
|
||||
err = errors.New(s)
|
||||
return
|
||||
}
|
||||
err = r.(error)
|
||||
}
|
||||
}()
|
||||
|
||||
n, err = fw.w.Write(p)
|
||||
if err != nil {
|
||||
// log.Log("flush writer:", err)
|
||||
return
|
||||
}
|
||||
if f, ok := fw.w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
return
|
||||
}
|
343
handler/http2/handler.go
Normal file
343
handler/http2/handler.go
Normal file
@ -0,0 +1,343 @@
|
||||
package http2
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/chain"
|
||||
netpkg "github.com/go-gost/gost/v3/pkg/common/net"
|
||||
"github.com/go-gost/gost/v3/pkg/handler"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
http2_util "github.com/go-gost/x/internal/util/http2"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.HandlerRegistry().Register("http2", NewHandler)
|
||||
}
|
||||
|
||||
type http2Handler struct {
|
||||
router *chain.Router
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
|
||||
func NewHandler(opts ...handler.Option) handler.Handler {
|
||||
options := handler.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
return &http2Handler{
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *http2Handler) Init(md md.Metadata) error {
|
||||
if err := h.parseMetadata(md); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.router = h.options.Router
|
||||
if h.router == nil {
|
||||
h.router = (&chain.Router{}).WithLogger(h.options.Logger)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *http2Handler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
|
||||
defer conn.Close()
|
||||
|
||||
start := time.Now()
|
||||
log := h.options.Logger.WithFields(map[string]any{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
|
||||
cc, ok := conn.(*http2_util.ServerConn)
|
||||
if !ok {
|
||||
err := errors.New("wrong connection type")
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
return h.roundTrip(ctx, cc.Writer(), cc.Request(), log)
|
||||
}
|
||||
|
||||
// NOTE: there is an issue (golang/go#43989) will cause the client hangs
|
||||
// when server returns an non-200 status code,
|
||||
// May be fixed in go1.18.
|
||||
func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req *http.Request, log logger.Logger) error {
|
||||
// Try to get the actual host.
|
||||
// Compatible with GOST 2.x.
|
||||
if v := req.Header.Get("Gost-Target"); v != "" {
|
||||
if h, err := h.decodeServerName(v); err == nil {
|
||||
req.Host = h
|
||||
}
|
||||
}
|
||||
req.Header.Del("Gost-Target")
|
||||
|
||||
if v := req.Header.Get("X-Gost-Target"); v != "" {
|
||||
if h, err := h.decodeServerName(v); err == nil {
|
||||
req.Host = h
|
||||
}
|
||||
}
|
||||
req.Header.Del("X-Gost-Target")
|
||||
|
||||
addr := req.Host
|
||||
if _, port, _ := net.SplitHostPort(addr); port == "" {
|
||||
addr = net.JoinHostPort(addr, "80")
|
||||
}
|
||||
|
||||
fields := map[string]any{
|
||||
"dst": addr,
|
||||
}
|
||||
if u, _, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization")); u != "" {
|
||||
fields["user"] = u
|
||||
}
|
||||
log = log.WithFields(fields)
|
||||
|
||||
if log.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpRequest(req, false)
|
||||
log.Debug(string(dump))
|
||||
}
|
||||
log.Infof("%s >> %s", req.RemoteAddr, addr)
|
||||
|
||||
for k := range h.md.header {
|
||||
w.Header().Set(k, h.md.header.Get(k))
|
||||
}
|
||||
|
||||
if h.options.Bypass != nil && h.options.Bypass.Contains(addr) {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
log.Info("bypass: ", addr)
|
||||
return nil
|
||||
}
|
||||
|
||||
resp := &http.Response{
|
||||
ProtoMajor: 2,
|
||||
ProtoMinor: 0,
|
||||
Header: http.Header{},
|
||||
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
|
||||
}
|
||||
|
||||
if !h.authenticate(w, req, resp, log) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// delete the proxy related headers.
|
||||
req.Header.Del("Proxy-Authorization")
|
||||
req.Header.Del("Proxy-Connection")
|
||||
|
||||
cc, err := h.router.Dial(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return err
|
||||
}
|
||||
defer cc.Close()
|
||||
|
||||
if req.Method == http.MethodConnect {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if fw, ok := w.(http.Flusher); ok {
|
||||
fw.Flush()
|
||||
}
|
||||
|
||||
// compatible with HTTP1.x
|
||||
if hj, ok := w.(http.Hijacker); ok && req.ProtoMajor == 1 {
|
||||
// we take over the underly connection
|
||||
conn, _, err := hj.Hijack()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
start := time.Now()
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
|
||||
netpkg.Transport(conn, cc)
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >-< %s", conn.RemoteAddr(), addr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
log.Infof("%s <-> %s", req.RemoteAddr, addr)
|
||||
netpkg.Transport(&readWriter{r: req.Body, w: flushWriter{w}}, cc)
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >-< %s", req.RemoteAddr, addr)
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: forward request
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *http2Handler) decodeServerName(s string) (string, error) {
|
||||
b, err := base64.RawURLEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(b) < 4 {
|
||||
return "", errors.New("invalid name")
|
||||
}
|
||||
v, err := base64.RawURLEncoding.DecodeString(string(b[4:]))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if crc32.ChecksumIEEE(v) != binary.BigEndian.Uint32(b[:4]) {
|
||||
return "", errors.New("invalid name")
|
||||
}
|
||||
return string(v), nil
|
||||
}
|
||||
|
||||
func (h *http2Handler) basicProxyAuth(proxyAuth string) (username, password string, ok bool) {
|
||||
if proxyAuth == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(proxyAuth, "Basic ") {
|
||||
return
|
||||
}
|
||||
c, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(proxyAuth, "Basic "))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cs := string(c)
|
||||
s := strings.IndexByte(cs, ':')
|
||||
if s < 0 {
|
||||
return
|
||||
}
|
||||
|
||||
return cs[:s], cs[s+1:], true
|
||||
}
|
||||
|
||||
func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp *http.Response, log logger.Logger) (ok bool) {
|
||||
u, p, _ := h.basicProxyAuth(r.Header.Get("Proxy-Authorization"))
|
||||
if h.options.Auther == nil || h.options.Auther.Authenticate(u, p) {
|
||||
return true
|
||||
}
|
||||
|
||||
pr := h.md.probeResistance
|
||||
// probing resistance is enabled, and knocking host is mismatch.
|
||||
if pr != nil && (pr.Knock == "" || !strings.EqualFold(r.URL.Hostname(), pr.Knock)) {
|
||||
resp.StatusCode = http.StatusServiceUnavailable // default status code
|
||||
switch pr.Type {
|
||||
case "code":
|
||||
resp.StatusCode, _ = strconv.Atoi(pr.Value)
|
||||
case "web":
|
||||
url := pr.Value
|
||||
if !strings.HasPrefix(url, "http") {
|
||||
url = "http://" + url
|
||||
}
|
||||
r, err := http.Get(url)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
break
|
||||
}
|
||||
resp = r
|
||||
defer resp.Body.Close()
|
||||
case "host":
|
||||
cc, err := net.Dial("tcp", pr.Value)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
break
|
||||
}
|
||||
defer cc.Close()
|
||||
|
||||
if err := h.forwardRequest(w, r, cc); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
return
|
||||
case "file":
|
||||
f, _ := os.Open(pr.Value)
|
||||
if f != nil {
|
||||
defer f.Close()
|
||||
|
||||
resp.StatusCode = http.StatusOK
|
||||
if finfo, _ := f.Stat(); finfo != nil {
|
||||
resp.ContentLength = finfo.Size()
|
||||
}
|
||||
resp.Header.Set("Content-Type", "text/html")
|
||||
resp.Body = f
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if resp.StatusCode == 0 {
|
||||
resp.StatusCode = http.StatusProxyAuthRequired
|
||||
resp.Header.Add("Proxy-Authenticate", "Basic realm=\"gost\"")
|
||||
if strings.ToLower(r.Header.Get("Proxy-Connection")) == "keep-alive" {
|
||||
// XXX libcurl will keep sending auth request in same conn
|
||||
// which we don't supported yet.
|
||||
resp.Header.Add("Connection", "close")
|
||||
resp.Header.Add("Proxy-Connection", "close")
|
||||
}
|
||||
|
||||
log.Info("proxy authentication required")
|
||||
} else {
|
||||
resp.Header = http.Header{}
|
||||
resp.Header.Set("Server", "nginx/1.20.1")
|
||||
resp.Header.Set("Date", time.Now().Format(http.TimeFormat))
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
resp.Header.Set("Connection", "keep-alive")
|
||||
}
|
||||
}
|
||||
|
||||
if log.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpResponse(resp, false)
|
||||
log.Debug(string(dump))
|
||||
}
|
||||
|
||||
h.writeResponse(w, resp)
|
||||
|
||||
return
|
||||
}
|
||||
func (h *http2Handler) forwardRequest(w http.ResponseWriter, r *http.Request, rw io.ReadWriter) (err error) {
|
||||
if err = r.Write(rw); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := http.ReadResponse(bufio.NewReader(rw), r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return h.writeResponse(w, resp)
|
||||
}
|
||||
|
||||
func (h *http2Handler) writeResponse(w http.ResponseWriter, resp *http.Response) error {
|
||||
for k, v := range resp.Header {
|
||||
for _, vv := range v {
|
||||
w.Header().Add(k, vv)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
_, err := io.Copy(flushWriter{w}, resp.Body)
|
||||
return err
|
||||
}
|
47
handler/http2/metadata.go
Normal file
47
handler/http2/metadata.go
Normal file
@ -0,0 +1,47 @@
|
||||
package http2
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
probeResistance *probeResistance
|
||||
header http.Header
|
||||
}
|
||||
|
||||
func (h *http2Handler) parseMetadata(md mdata.Metadata) error {
|
||||
const (
|
||||
header = "header"
|
||||
probeResistKey = "probeResistance"
|
||||
knock = "knock"
|
||||
)
|
||||
|
||||
if m := mdata.GetStringMapString(md, header); len(m) > 0 {
|
||||
hd := http.Header{}
|
||||
for k, v := range m {
|
||||
hd.Add(k, v)
|
||||
}
|
||||
h.md.header = hd
|
||||
}
|
||||
|
||||
if v := mdata.GetString(md, probeResistKey); v != "" {
|
||||
if ss := strings.SplitN(v, ":", 2); len(ss) == 2 {
|
||||
h.md.probeResistance = &probeResistance{
|
||||
Type: ss[0],
|
||||
Value: ss[1],
|
||||
Knock: mdata.GetString(md, knock),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type probeResistance struct {
|
||||
Type string
|
||||
Value string
|
||||
Knock string
|
||||
}
|
112
handler/redirect/handler.go
Normal file
112
handler/redirect/handler.go
Normal file
@ -0,0 +1,112 @@
|
||||
package redirect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/chain"
|
||||
netpkg "github.com/go-gost/gost/v3/pkg/common/net"
|
||||
"github.com/go-gost/gost/v3/pkg/handler"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.HandlerRegistry().Register("red", NewHandler)
|
||||
registry.HandlerRegistry().Register("redu", NewHandler)
|
||||
registry.HandlerRegistry().Register("redir", NewHandler)
|
||||
registry.HandlerRegistry().Register("redirect", NewHandler)
|
||||
}
|
||||
|
||||
type redirectHandler struct {
|
||||
router *chain.Router
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
|
||||
func NewHandler(opts ...handler.Option) handler.Handler {
|
||||
options := handler.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
return &redirectHandler{
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *redirectHandler) Init(md md.Metadata) (err error) {
|
||||
if err = h.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
h.router = h.options.Router
|
||||
if h.router == nil {
|
||||
h.router = (&chain.Router{}).WithLogger(h.options.Logger)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
|
||||
defer conn.Close()
|
||||
|
||||
start := time.Now()
|
||||
log := h.options.Logger.WithFields(map[string]any{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
|
||||
network := "tcp"
|
||||
var dstAddr net.Addr
|
||||
var err error
|
||||
|
||||
if _, ok := conn.(net.PacketConn); ok {
|
||||
network = "udp"
|
||||
dstAddr = conn.LocalAddr()
|
||||
}
|
||||
|
||||
if network == "tcp" {
|
||||
dstAddr, conn, err = h.getOriginalDstAddr(conn)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
log = log.WithFields(map[string]any{
|
||||
"dst": fmt.Sprintf("%s/%s", dstAddr, network),
|
||||
})
|
||||
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), dstAddr)
|
||||
|
||||
if h.options.Bypass != nil && h.options.Bypass.Contains(dstAddr.String()) {
|
||||
log.Info("bypass: ", dstAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
cc, err := h.router.Dial(ctx, network, dstAddr.String())
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
defer cc.Close()
|
||||
|
||||
t := time.Now()
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), dstAddr)
|
||||
netpkg.Transport(conn, cc)
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(t),
|
||||
}).Infof("%s >-< %s", conn.RemoteAddr(), dstAddr)
|
||||
|
||||
return nil
|
||||
}
|
40
handler/redirect/handler_linux.go
Normal file
40
handler/redirect/handler_linux.go
Normal file
@ -0,0 +1,40 @@
|
||||
package redirect
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func (h *redirectHandler) getOriginalDstAddr(conn net.Conn) (addr net.Addr, c net.Conn, err error) {
|
||||
defer conn.Close()
|
||||
|
||||
tc, ok := conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
err = errors.New("wrong connection type, must be TCP")
|
||||
return
|
||||
}
|
||||
|
||||
fc, err := tc.File()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer fc.Close()
|
||||
|
||||
mreq, err := syscall.GetsockoptIPv6Mreq(int(fc.Fd()), syscall.IPPROTO_IP, 80)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// only ipv4 support
|
||||
ip := net.IPv4(mreq.Multiaddr[4], mreq.Multiaddr[5], mreq.Multiaddr[6], mreq.Multiaddr[7])
|
||||
port := uint16(mreq.Multiaddr[2])<<8 + uint16(mreq.Multiaddr[3])
|
||||
addr, err = net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", ip.String(), port))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c, err = net.FileConn(fc)
|
||||
return
|
||||
}
|
15
handler/redirect/handler_other.go
Normal file
15
handler/redirect/handler_other.go
Normal file
@ -0,0 +1,15 @@
|
||||
//go:build !linux
|
||||
|
||||
package redirect
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
func (h *redirectHandler) getOriginalDstAddr(conn net.Conn) (addr net.Addr, c net.Conn, err error) {
|
||||
defer conn.Close()
|
||||
|
||||
err = errors.New("TCP redirect is not available on non-linux platform")
|
||||
return
|
||||
}
|
18
handler/redirect/metadata.go
Normal file
18
handler/redirect/metadata.go
Normal file
@ -0,0 +1,18 @@
|
||||
package redirect
|
||||
|
||||
import (
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
retryCount int
|
||||
}
|
||||
|
||||
func (h *redirectHandler) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
retryCount = "retry"
|
||||
)
|
||||
|
||||
h.md.retryCount = mdata.GetInt(md, retryCount)
|
||||
return
|
||||
}
|
193
handler/relay/bind.go
Normal file
193
handler/relay/bind.go
Normal file
@ -0,0 +1,193 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
netpkg "github.com/go-gost/gost/v3/pkg/common/net"
|
||||
net_relay "github.com/go-gost/gost/v3/pkg/common/net/relay"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
"github.com/go-gost/relay"
|
||||
"github.com/go-gost/x/internal/util/mux"
|
||||
relay_util "github.com/go-gost/x/internal/util/relay"
|
||||
)
|
||||
|
||||
func (h *relayHandler) handleBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error {
|
||||
log = log.WithFields(map[string]any{
|
||||
"dst": fmt.Sprintf("%s/%s", address, network),
|
||||
"cmd": "bind",
|
||||
})
|
||||
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), address)
|
||||
|
||||
resp := relay.Response{
|
||||
Version: relay.Version1,
|
||||
Status: relay.StatusOK,
|
||||
}
|
||||
|
||||
if !h.md.enableBind {
|
||||
resp.Status = relay.StatusForbidden
|
||||
log.Error("relay: BIND is disabled")
|
||||
_, err := resp.WriteTo(conn)
|
||||
return err
|
||||
}
|
||||
|
||||
if network == "tcp" {
|
||||
return h.bindTCP(ctx, conn, network, address, log)
|
||||
} else {
|
||||
return h.bindUDP(ctx, conn, network, address, log)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error {
|
||||
resp := relay.Response{
|
||||
Version: relay.Version1,
|
||||
Status: relay.StatusOK,
|
||||
}
|
||||
|
||||
ln, err := net.Listen(network, address) // strict mode: if the port already in use, it will return error
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
resp.Status = relay.StatusServiceUnavailable
|
||||
resp.WriteTo(conn)
|
||||
return err
|
||||
}
|
||||
|
||||
af := &relay.AddrFeature{}
|
||||
err = af.ParseFrom(ln.Addr().String())
|
||||
if err != nil {
|
||||
log.Warn(err)
|
||||
}
|
||||
|
||||
// Issue: may not reachable when host has multi-interface
|
||||
af.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String())
|
||||
af.AType = relay.AddrIPv4
|
||||
resp.Features = append(resp.Features, af)
|
||||
if _, err := resp.WriteTo(conn); err != nil {
|
||||
log.Error(err)
|
||||
ln.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
log = log.WithFields(map[string]any{
|
||||
"bind": fmt.Sprintf("%s/%s", ln.Addr(), ln.Addr().Network()),
|
||||
})
|
||||
log.Debugf("bind on %s OK", ln.Addr())
|
||||
|
||||
return h.serveTCPBind(ctx, conn, ln, log)
|
||||
}
|
||||
|
||||
func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error {
|
||||
resp := relay.Response{
|
||||
Version: relay.Version1,
|
||||
Status: relay.StatusOK,
|
||||
}
|
||||
|
||||
bindAddr, _ := net.ResolveUDPAddr(network, address)
|
||||
pc, err := net.ListenUDP(network, bindAddr)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
defer pc.Close()
|
||||
|
||||
af := &relay.AddrFeature{}
|
||||
err = af.ParseFrom(pc.LocalAddr().String())
|
||||
if err != nil {
|
||||
log.Warn(err)
|
||||
}
|
||||
|
||||
// Issue: may not reachable when host has multi-interface
|
||||
af.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String())
|
||||
af.AType = relay.AddrIPv4
|
||||
resp.Features = append(resp.Features, af)
|
||||
if _, err := resp.WriteTo(conn); err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
log = log.WithFields(map[string]any{
|
||||
"bind": pc.LocalAddr().String(),
|
||||
})
|
||||
log.Debugf("bind on %s OK", pc.LocalAddr())
|
||||
|
||||
r := net_relay.NewUDPRelay(relay_util.UDPTunServerConn(conn), pc).
|
||||
WithBypass(h.options.Bypass).
|
||||
WithLogger(log)
|
||||
r.SetBufferSize(h.md.udpBufferSize)
|
||||
|
||||
t := time.Now()
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||
r.Run()
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(t),
|
||||
}).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.Listener, log logger.Logger) error {
|
||||
// Upgrade connection to multiplex stream.
|
||||
session, err := mux.ClientSession(conn)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
go func() {
|
||||
defer ln.Close()
|
||||
for {
|
||||
conn, err := session.Accept()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
conn.Close() // we do not handle incoming connections.
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
rc, err := ln.Accept()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
log.Debugf("peer %s accepted", rc.RemoteAddr())
|
||||
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
|
||||
log = log.WithFields(map[string]any{
|
||||
"local": ln.Addr().String(),
|
||||
"remote": c.RemoteAddr().String(),
|
||||
})
|
||||
|
||||
sc, err := session.GetConn()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
defer sc.Close()
|
||||
|
||||
af := &relay.AddrFeature{}
|
||||
af.ParseFrom(c.RemoteAddr().String())
|
||||
resp := relay.Response{
|
||||
Version: relay.Version1,
|
||||
Status: relay.StatusOK,
|
||||
Features: []relay.Feature{af},
|
||||
}
|
||||
if _, err := resp.WriteTo(sc); err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
t := time.Now()
|
||||
log.Infof("%s <-> %s", c.LocalAddr(), c.RemoteAddr())
|
||||
netpkg.Transport(sc, c)
|
||||
log.WithFields(map[string]any{"duration": time.Since(t)}).
|
||||
Infof("%s >-< %s", c.LocalAddr(), c.RemoteAddr())
|
||||
}(rc)
|
||||
}
|
||||
}
|
81
handler/relay/conn.go
Normal file
81
handler/relay/conn.go
Normal file
@ -0,0 +1,81 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
)
|
||||
|
||||
type tcpConn struct {
|
||||
net.Conn
|
||||
wbuf bytes.Buffer
|
||||
}
|
||||
|
||||
func (c *tcpConn) Read(b []byte) (n int, err error) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return c.Conn.Read(b)
|
||||
}
|
||||
|
||||
func (c *tcpConn) Write(b []byte) (n int, err error) {
|
||||
n = len(b) // force byte length consistent
|
||||
if c.wbuf.Len() > 0 {
|
||||
c.wbuf.Write(b) // append the data to the cached header
|
||||
_, err = c.wbuf.WriteTo(c.Conn)
|
||||
return
|
||||
}
|
||||
_, err = c.Conn.Write(b)
|
||||
return
|
||||
}
|
||||
|
||||
type udpConn struct {
|
||||
net.Conn
|
||||
wbuf bytes.Buffer
|
||||
}
|
||||
|
||||
func (c *udpConn) Read(b []byte) (n int, err error) {
|
||||
var bb [2]byte
|
||||
_, err = io.ReadFull(c.Conn, bb[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
dlen := int(binary.BigEndian.Uint16(bb[:]))
|
||||
if len(b) >= dlen {
|
||||
return io.ReadFull(c.Conn, b[:dlen])
|
||||
}
|
||||
buf := make([]byte, dlen)
|
||||
_, err = io.ReadFull(c.Conn, buf)
|
||||
n = copy(b, buf)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *udpConn) Write(b []byte) (n int, err error) {
|
||||
if len(b) > math.MaxUint16 {
|
||||
err = errors.New("write: data maximum exceeded")
|
||||
return
|
||||
}
|
||||
|
||||
n = len(b)
|
||||
if c.wbuf.Len() > 0 {
|
||||
var bb [2]byte
|
||||
binary.BigEndian.PutUint16(bb[:], uint16(len(b)))
|
||||
c.wbuf.Write(bb[:])
|
||||
c.wbuf.Write(b) // append the data to the cached header
|
||||
_, err = c.wbuf.WriteTo(c.Conn)
|
||||
return
|
||||
}
|
||||
|
||||
var bb [2]byte
|
||||
binary.BigEndian.PutUint16(bb[:], uint16(len(b)))
|
||||
_, err = c.Conn.Write(bb[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return c.Conn.Write(b)
|
||||
}
|
91
handler/relay/connect.go
Normal file
91
handler/relay/connect.go
Normal file
@ -0,0 +1,91 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
netpkg "github.com/go-gost/gost/v3/pkg/common/net"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
"github.com/go-gost/relay"
|
||||
)
|
||||
|
||||
func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error {
|
||||
log = log.WithFields(map[string]any{
|
||||
"dst": fmt.Sprintf("%s/%s", address, network),
|
||||
"cmd": "connect",
|
||||
})
|
||||
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), address)
|
||||
|
||||
resp := relay.Response{
|
||||
Version: relay.Version1,
|
||||
Status: relay.StatusOK,
|
||||
}
|
||||
|
||||
if address == "" {
|
||||
resp.Status = relay.StatusBadRequest
|
||||
resp.WriteTo(conn)
|
||||
err := errors.New("target not specified")
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
if h.options.Bypass != nil && h.options.Bypass.Contains(address) {
|
||||
log.Info("bypass: ", address)
|
||||
resp.Status = relay.StatusForbidden
|
||||
_, err := resp.WriteTo(conn)
|
||||
return err
|
||||
}
|
||||
|
||||
cc, err := h.router.Dial(ctx, network, address)
|
||||
if err != nil {
|
||||
resp.Status = relay.StatusNetworkUnreachable
|
||||
resp.WriteTo(conn)
|
||||
return err
|
||||
}
|
||||
defer cc.Close()
|
||||
|
||||
if h.md.noDelay {
|
||||
if _, err := resp.WriteTo(conn); err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
switch network {
|
||||
case "udp", "udp4", "udp6":
|
||||
rc := &udpConn{
|
||||
Conn: conn,
|
||||
}
|
||||
if !h.md.noDelay {
|
||||
// cache the header
|
||||
if _, err := resp.WriteTo(&rc.wbuf); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
conn = rc
|
||||
default:
|
||||
rc := &tcpConn{
|
||||
Conn: conn,
|
||||
}
|
||||
if !h.md.noDelay {
|
||||
// cache the header
|
||||
if _, err := resp.WriteTo(&rc.wbuf); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
conn = rc
|
||||
}
|
||||
|
||||
t := time.Now()
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), address)
|
||||
netpkg.Transport(conn, cc)
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(t),
|
||||
}).Infof("%s >-< %s", conn.RemoteAddr(), address)
|
||||
|
||||
return nil
|
||||
}
|
91
handler/relay/forward.go
Normal file
91
handler/relay/forward.go
Normal file
@ -0,0 +1,91 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
netpkg "github.com/go-gost/gost/v3/pkg/common/net"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
"github.com/go-gost/relay"
|
||||
)
|
||||
|
||||
func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network string, log logger.Logger) error {
|
||||
resp := relay.Response{
|
||||
Version: relay.Version1,
|
||||
Status: relay.StatusOK,
|
||||
}
|
||||
target := h.group.Next()
|
||||
if target == nil {
|
||||
resp.Status = relay.StatusServiceUnavailable
|
||||
resp.WriteTo(conn)
|
||||
err := errors.New("target not available")
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
log = log.WithFields(map[string]any{
|
||||
"dst": fmt.Sprintf("%s/%s", target.Addr, network),
|
||||
"cmd": "forward",
|
||||
})
|
||||
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr)
|
||||
|
||||
cc, err := h.router.Dial(ctx, network, target.Addr)
|
||||
if err != nil {
|
||||
// TODO: the router itself may be failed due to the failed node in the router,
|
||||
// the dead marker may be a wrong operation.
|
||||
target.Marker.Mark()
|
||||
|
||||
resp.Status = relay.StatusHostUnreachable
|
||||
resp.WriteTo(conn)
|
||||
log.Error(err)
|
||||
|
||||
return err
|
||||
}
|
||||
defer cc.Close()
|
||||
target.Marker.Reset()
|
||||
|
||||
if h.md.noDelay {
|
||||
if _, err := resp.WriteTo(conn); err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
switch network {
|
||||
case "udp", "udp4", "udp6":
|
||||
rc := &udpConn{
|
||||
Conn: conn,
|
||||
}
|
||||
if !h.md.noDelay {
|
||||
// cache the header
|
||||
if _, err := resp.WriteTo(&rc.wbuf); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
conn = rc
|
||||
default:
|
||||
rc := &tcpConn{
|
||||
Conn: conn,
|
||||
}
|
||||
if !h.md.noDelay {
|
||||
// cache the header
|
||||
if _, err := resp.WriteTo(&rc.wbuf); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
conn = rc
|
||||
}
|
||||
|
||||
t := time.Now()
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr)
|
||||
netpkg.Transport(conn, cc)
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(t),
|
||||
}).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr)
|
||||
|
||||
return nil
|
||||
}
|
147
handler/relay/handler.go
Normal file
147
handler/relay/handler.go
Normal file
@ -0,0 +1,147 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/chain"
|
||||
"github.com/go-gost/gost/v3/pkg/handler"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
"github.com/go-gost/relay"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrBadVersion = errors.New("relay: bad version")
|
||||
ErrUnknownCmd = errors.New("relay: unknown command")
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.HandlerRegistry().Register("relay", NewHandler)
|
||||
}
|
||||
|
||||
type relayHandler struct {
|
||||
group *chain.NodeGroup
|
||||
router *chain.Router
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
|
||||
func NewHandler(opts ...handler.Option) handler.Handler {
|
||||
options := handler.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
return &relayHandler{
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *relayHandler) Init(md md.Metadata) (err error) {
|
||||
if err := h.parseMetadata(md); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.router = h.options.Router
|
||||
if h.router == nil {
|
||||
h.router = (&chain.Router{}).WithLogger(h.options.Logger)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Forward implements handler.Forwarder.
|
||||
func (h *relayHandler) Forward(group *chain.NodeGroup) {
|
||||
h.group = group
|
||||
}
|
||||
|
||||
func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
|
||||
defer conn.Close()
|
||||
|
||||
start := time.Now()
|
||||
log := h.options.Logger.WithFields(map[string]any{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
|
||||
if h.md.readTimeout > 0 {
|
||||
conn.SetReadDeadline(time.Now().Add(h.md.readTimeout))
|
||||
}
|
||||
|
||||
req := relay.Request{}
|
||||
if _, err := req.ReadFrom(conn); err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
|
||||
if req.Version != relay.Version1 {
|
||||
err := ErrBadVersion
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
var user, pass string
|
||||
var address string
|
||||
for _, f := range req.Features {
|
||||
if f.Type() == relay.FeatureUserAuth {
|
||||
feature := f.(*relay.UserAuthFeature)
|
||||
user, pass = feature.Username, feature.Password
|
||||
}
|
||||
if f.Type() == relay.FeatureAddr {
|
||||
feature := f.(*relay.AddrFeature)
|
||||
address = net.JoinHostPort(feature.Host, strconv.Itoa(int(feature.Port)))
|
||||
}
|
||||
}
|
||||
|
||||
if user != "" {
|
||||
log = log.WithFields(map[string]any{"user": user})
|
||||
}
|
||||
|
||||
resp := relay.Response{
|
||||
Version: relay.Version1,
|
||||
Status: relay.StatusOK,
|
||||
}
|
||||
if h.options.Auther != nil && !h.options.Auther.Authenticate(user, pass) {
|
||||
resp.Status = relay.StatusUnauthorized
|
||||
log.Error("unauthorized")
|
||||
_, err := resp.WriteTo(conn)
|
||||
return err
|
||||
}
|
||||
|
||||
network := "tcp"
|
||||
if (req.Flags & relay.FUDP) == relay.FUDP {
|
||||
network = "udp"
|
||||
}
|
||||
|
||||
if h.group != nil {
|
||||
if address != "" {
|
||||
resp.Status = relay.StatusForbidden
|
||||
log.Error("forward mode, connect is forbidden")
|
||||
_, err := resp.WriteTo(conn)
|
||||
return err
|
||||
}
|
||||
// forward mode
|
||||
return h.handleForward(ctx, conn, network, log)
|
||||
}
|
||||
|
||||
switch req.Flags & relay.CmdMask {
|
||||
case 0, relay.CONNECT:
|
||||
return h.handleConnect(ctx, conn, network, address, log)
|
||||
case relay.BIND:
|
||||
return h.handleBind(ctx, conn, network, address, log)
|
||||
}
|
||||
return ErrUnknownCmd
|
||||
}
|
35
handler/relay/metadata.go
Normal file
35
handler/relay/metadata.go
Normal file
@ -0,0 +1,35 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
readTimeout time.Duration
|
||||
enableBind bool
|
||||
udpBufferSize int
|
||||
noDelay bool
|
||||
}
|
||||
|
||||
func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
readTimeout = "readTimeout"
|
||||
enableBind = "bind"
|
||||
udpBufferSize = "udpBufferSize"
|
||||
noDelay = "nodelay"
|
||||
)
|
||||
|
||||
h.md.readTimeout = mdata.GetDuration(md, readTimeout)
|
||||
h.md.enableBind = mdata.GetBool(md, enableBind)
|
||||
h.md.noDelay = mdata.GetBool(md, noDelay)
|
||||
|
||||
if bs := mdata.GetInt(md, udpBufferSize); bs > 0 {
|
||||
h.md.udpBufferSize = int(math.Min(math.Max(float64(bs), 512), 64*1024))
|
||||
} else {
|
||||
h.md.udpBufferSize = 1024
|
||||
}
|
||||
return
|
||||
}
|
20
handler/sni/conn.go
Normal file
20
handler/sni/conn.go
Normal file
@ -0,0 +1,20 @@
|
||||
package sni
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
type cacheConn struct {
|
||||
net.Conn
|
||||
buf []byte
|
||||
}
|
||||
|
||||
func (c *cacheConn) Read(b []byte) (n int, err error) {
|
||||
if len(c.buf) > 0 {
|
||||
n = copy(b, c.buf)
|
||||
c.buf = c.buf[n:]
|
||||
return
|
||||
}
|
||||
|
||||
return c.Conn.Read(b)
|
||||
}
|
222
handler/sni/handler.go
Normal file
222
handler/sni/handler.go
Normal file
@ -0,0 +1,222 @@
|
||||
package sni
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/chain"
|
||||
"github.com/go-gost/gost/v3/pkg/common/bufpool"
|
||||
netpkg "github.com/go-gost/gost/v3/pkg/common/net"
|
||||
"github.com/go-gost/gost/v3/pkg/handler"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
dissector "github.com/go-gost/tls-dissector"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.HandlerRegistry().Register("sni", NewHandler)
|
||||
}
|
||||
|
||||
type sniHandler struct {
|
||||
httpHandler handler.Handler
|
||||
router *chain.Router
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
|
||||
func NewHandler(opts ...handler.Option) handler.Handler {
|
||||
options := handler.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
h := &sniHandler{
|
||||
options: options,
|
||||
}
|
||||
|
||||
if f := registry.HandlerRegistry().Get("http"); f != nil {
|
||||
v := append(opts,
|
||||
handler.LoggerOption(h.options.Logger.WithFields(map[string]any{"type": "http"})))
|
||||
h.httpHandler = f(v...)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *sniHandler) Init(md md.Metadata) (err error) {
|
||||
if err = h.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
if h.httpHandler != nil {
|
||||
if md != nil {
|
||||
md.Set("sni", true)
|
||||
}
|
||||
if err = h.httpHandler.Init(md); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
h.router = h.options.Router
|
||||
if h.router == nil {
|
||||
h.router = (&chain.Router{}).WithLogger(h.options.Logger)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sniHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
|
||||
defer conn.Close()
|
||||
|
||||
start := time.Now()
|
||||
log := h.options.Logger.WithFields(map[string]any{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
|
||||
var hdr [dissector.RecordHeaderLen]byte
|
||||
if _, err := io.ReadFull(conn, hdr[:]); err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
if hdr[0] != dissector.Handshake {
|
||||
// We assume it is an HTTP request
|
||||
conn = &cacheConn{
|
||||
Conn: conn,
|
||||
buf: hdr[:],
|
||||
}
|
||||
|
||||
if h.httpHandler != nil {
|
||||
return h.httpHandler.Handle(ctx, conn)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
length := binary.BigEndian.Uint16(hdr[3:5])
|
||||
|
||||
buf := bufpool.Get(int(length) + dissector.RecordHeaderLen)
|
||||
defer bufpool.Put(buf)
|
||||
if _, err := io.ReadFull(conn, (*buf)[dissector.RecordHeaderLen:]); err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
copy(*buf, hdr[:])
|
||||
|
||||
opaque, host, err := h.decodeHost(bytes.NewReader(*buf))
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
target := net.JoinHostPort(host, "443")
|
||||
|
||||
log = log.WithFields(map[string]any{
|
||||
"dst": target,
|
||||
})
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), target)
|
||||
|
||||
if h.options.Bypass != nil && h.options.Bypass.Contains(target) {
|
||||
log.Info("bypass: ", target)
|
||||
return nil
|
||||
}
|
||||
|
||||
cc, err := h.router.Dial(ctx, "tcp", target)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
defer cc.Close()
|
||||
|
||||
if _, err := cc.Write(opaque); err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
t := time.Now()
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), target)
|
||||
netpkg.Transport(conn, cc)
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(t),
|
||||
}).Infof("%s >-< %s", conn.RemoteAddr(), target)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sniHandler) decodeHost(r io.Reader) (opaque []byte, host string, err error) {
|
||||
record, err := dissector.ReadRecord(r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
clientHello := dissector.ClientHelloMsg{}
|
||||
if err = clientHello.Decode(record.Opaque); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var extensions []dissector.Extension
|
||||
for _, ext := range clientHello.Extensions {
|
||||
if ext.Type() == 0xFFFE {
|
||||
b, _ := ext.Encode()
|
||||
if v, err := h.decodeServerName(string(b)); err == nil {
|
||||
host = v
|
||||
}
|
||||
continue
|
||||
}
|
||||
extensions = append(extensions, ext)
|
||||
}
|
||||
clientHello.Extensions = extensions
|
||||
|
||||
for _, ext := range clientHello.Extensions {
|
||||
if ext.Type() == dissector.ExtServerName {
|
||||
snExtension := ext.(*dissector.ServerNameExtension)
|
||||
if host == "" {
|
||||
host = snExtension.Name
|
||||
} else {
|
||||
snExtension.Name = host
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
record.Opaque, err = clientHello.Encode()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
if _, err = record.WriteTo(buf); err != nil {
|
||||
return
|
||||
}
|
||||
opaque = buf.Bytes()
|
||||
return
|
||||
}
|
||||
|
||||
func (h *sniHandler) decodeServerName(s string) (string, error) {
|
||||
b, err := base64.RawURLEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(b) < 4 {
|
||||
return "", errors.New("invalid name")
|
||||
}
|
||||
v, err := base64.RawURLEncoding.DecodeString(string(b[4:]))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if crc32.ChecksumIEEE(v) != binary.BigEndian.Uint32(b[:4]) {
|
||||
return "", errors.New("invalid name")
|
||||
}
|
||||
return string(v), nil
|
||||
}
|
20
handler/sni/metadata.go
Normal file
20
handler/sni/metadata.go
Normal file
@ -0,0 +1,20 @@
|
||||
package sni
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
readTimeout time.Duration
|
||||
}
|
||||
|
||||
func (h *sniHandler) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
readTimeout = "readTimeout"
|
||||
)
|
||||
|
||||
h.md.readTimeout = mdata.GetDuration(md, readTimeout)
|
||||
return
|
||||
}
|
119
handler/ss/handler.go
Normal file
119
handler/ss/handler.go
Normal file
@ -0,0 +1,119 @@
|
||||
package ss
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gosocks5"
|
||||
"github.com/go-gost/gost/v3/pkg/chain"
|
||||
netpkg "github.com/go-gost/gost/v3/pkg/common/net"
|
||||
"github.com/go-gost/gost/v3/pkg/handler"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
"github.com/go-gost/x/internal/util/ss"
|
||||
"github.com/shadowsocks/go-shadowsocks2/core"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.HandlerRegistry().Register("ss", NewHandler)
|
||||
}
|
||||
|
||||
type ssHandler struct {
|
||||
cipher core.Cipher
|
||||
router *chain.Router
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
|
||||
func NewHandler(opts ...handler.Option) handler.Handler {
|
||||
options := handler.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
return &ssHandler{
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ssHandler) Init(md md.Metadata) (err error) {
|
||||
if err = h.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
if h.options.Auth != nil {
|
||||
method := h.options.Auth.Username()
|
||||
password, _ := h.options.Auth.Password()
|
||||
h.cipher, err = ss.ShadowCipher(method, password, h.md.key)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
h.router = h.options.Router
|
||||
if h.router == nil {
|
||||
h.router = (&chain.Router{}).WithLogger(h.options.Logger)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (h *ssHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
|
||||
defer conn.Close()
|
||||
|
||||
start := time.Now()
|
||||
log := h.options.Logger.WithFields(map[string]any{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
|
||||
if h.cipher != nil {
|
||||
conn = ss.ShadowConn(h.cipher.StreamConn(conn), nil)
|
||||
}
|
||||
|
||||
if h.md.readTimeout > 0 {
|
||||
conn.SetReadDeadline(time.Now().Add(h.md.readTimeout))
|
||||
}
|
||||
|
||||
addr := &gosocks5.Addr{}
|
||||
if _, err := addr.ReadFrom(conn); err != nil {
|
||||
log.Error(err)
|
||||
io.Copy(ioutil.Discard, conn)
|
||||
return err
|
||||
}
|
||||
|
||||
log = log.WithFields(map[string]any{
|
||||
"dst": addr.String(),
|
||||
})
|
||||
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), addr)
|
||||
|
||||
if h.options.Bypass != nil && h.options.Bypass.Contains(addr.String()) {
|
||||
log.Info("bypass: ", addr.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
cc, err := h.router.Dial(ctx, "tcp", addr.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cc.Close()
|
||||
|
||||
t := time.Now()
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
|
||||
netpkg.Transport(conn, cc)
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(t),
|
||||
}).Infof("%s >-< %s", conn.RemoteAddr(), addr)
|
||||
|
||||
return nil
|
||||
}
|
24
handler/ss/metadata.go
Normal file
24
handler/ss/metadata.go
Normal file
@ -0,0 +1,24 @@
|
||||
package ss
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
key string
|
||||
readTimeout time.Duration
|
||||
}
|
||||
|
||||
func (h *ssHandler) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
key = "key"
|
||||
readTimeout = "readTimeout"
|
||||
)
|
||||
|
||||
h.md.key = mdata.GetString(md, key)
|
||||
h.md.readTimeout = mdata.GetDuration(md, readTimeout)
|
||||
|
||||
return
|
||||
}
|
188
handler/ss/udp/handler.go
Normal file
188
handler/ss/udp/handler.go
Normal file
@ -0,0 +1,188 @@
|
||||
package ss
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/chain"
|
||||
"github.com/go-gost/gost/v3/pkg/common/bufpool"
|
||||
"github.com/go-gost/gost/v3/pkg/handler"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
"github.com/go-gost/x/internal/util/relay"
|
||||
"github.com/go-gost/x/internal/util/ss"
|
||||
"github.com/shadowsocks/go-shadowsocks2/core"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.HandlerRegistry().Register("ssu", NewHandler)
|
||||
}
|
||||
|
||||
type ssuHandler struct {
|
||||
cipher core.Cipher
|
||||
router *chain.Router
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
|
||||
func NewHandler(opts ...handler.Option) handler.Handler {
|
||||
options := handler.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
return &ssuHandler{
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ssuHandler) Init(md md.Metadata) (err error) {
|
||||
if err = h.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if h.options.Auth != nil {
|
||||
method := h.options.Auth.Username()
|
||||
password, _ := h.options.Auth.Password()
|
||||
h.cipher, err = ss.ShadowCipher(method, password, h.md.key)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
h.router = h.options.Router
|
||||
if h.router == nil {
|
||||
h.router = (&chain.Router{}).WithLogger(h.options.Logger)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
|
||||
defer conn.Close()
|
||||
|
||||
start := time.Now()
|
||||
log := h.options.Logger.WithFields(map[string]any{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
|
||||
pc, ok := conn.(net.PacketConn)
|
||||
if ok {
|
||||
if h.cipher != nil {
|
||||
pc = h.cipher.PacketConn(pc)
|
||||
}
|
||||
// standard UDP relay.
|
||||
pc = ss.UDPServerConn(pc, conn.RemoteAddr(), h.md.bufferSize)
|
||||
} else {
|
||||
if h.cipher != nil {
|
||||
conn = ss.ShadowConn(h.cipher.StreamConn(conn), nil)
|
||||
}
|
||||
// UDP over TCP
|
||||
pc = relay.UDPTunServerConn(conn)
|
||||
}
|
||||
|
||||
// obtain a udp connection
|
||||
c, err := h.router.Dial(ctx, "udp", "") // UDP association
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
cc, ok := c.(net.PacketConn)
|
||||
if !ok {
|
||||
err := errors.New("ss: wrong connection type")
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
t := time.Now()
|
||||
log.Infof("%s <-> %s", conn.LocalAddr(), cc.LocalAddr())
|
||||
h.relayPacket(pc, cc, log)
|
||||
log.WithFields(map[string]any{"duration": time.Since(t)}).
|
||||
Infof("%s >-< %s", conn.LocalAddr(), cc.LocalAddr())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn, log logger.Logger) (err error) {
|
||||
bufSize := h.md.bufferSize
|
||||
errc := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
err := func() error {
|
||||
b := bufpool.Get(bufSize)
|
||||
defer bufpool.Put(b)
|
||||
|
||||
n, addr, err := pc1.ReadFrom(*b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if h.options.Bypass != nil && h.options.Bypass.Contains(addr.String()) {
|
||||
log.Warn("bypass: ", addr)
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err = pc2.WriteTo((*b)[:n], addr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("%s >>> %s data: %d",
|
||||
pc2.LocalAddr(), addr, n)
|
||||
return nil
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
errc <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
err := func() error {
|
||||
b := bufpool.Get(bufSize)
|
||||
defer bufpool.Put(b)
|
||||
|
||||
n, raddr, err := pc2.ReadFrom(*b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if h.options.Bypass != nil && h.options.Bypass.Contains(raddr.String()) {
|
||||
log.Warn("bypass: ", raddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err = pc1.WriteTo((*b)[:n], raddr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("%s <<< %s data: %d",
|
||||
pc2.LocalAddr(), raddr, n)
|
||||
return nil
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
errc <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return <-errc
|
||||
}
|
32
handler/ss/udp/metadata.go
Normal file
32
handler/ss/udp/metadata.go
Normal file
@ -0,0 +1,32 @@
|
||||
package ss
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
key string
|
||||
readTimeout time.Duration
|
||||
bufferSize int
|
||||
}
|
||||
|
||||
func (h *ssuHandler) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
key = "key"
|
||||
readTimeout = "readTimeout"
|
||||
bufferSize = "bufferSize"
|
||||
)
|
||||
|
||||
h.md.key = mdata.GetString(md, key)
|
||||
h.md.readTimeout = mdata.GetDuration(md, readTimeout)
|
||||
|
||||
if bs := mdata.GetInt(md, bufferSize); bs > 0 {
|
||||
h.md.bufferSize = int(math.Min(math.Max(float64(bs), 512), 64*1024))
|
||||
} else {
|
||||
h.md.bufferSize = 1024
|
||||
}
|
||||
return
|
||||
}
|
245
handler/sshd/handler.go
Normal file
245
handler/sshd/handler.go
Normal file
@ -0,0 +1,245 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/chain"
|
||||
netpkg "github.com/go-gost/gost/v3/pkg/common/net"
|
||||
"github.com/go-gost/gost/v3/pkg/handler"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
sshd_util "github.com/go-gost/x/internal/util/sshd"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// Applicable SSH Request types for Port Forwarding - RFC 4254 7.X
|
||||
const (
|
||||
ForwardedTCPReturnRequest = "forwarded-tcpip" // RFC 4254 7.2
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.HandlerRegistry().Register("sshd", NewHandler)
|
||||
}
|
||||
|
||||
type forwardHandler struct {
|
||||
router *chain.Router
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
|
||||
func NewHandler(opts ...handler.Option) handler.Handler {
|
||||
options := handler.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
return &forwardHandler{
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *forwardHandler) Init(md md.Metadata) (err error) {
|
||||
if err = h.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
h.router = h.options.Router
|
||||
if h.router == nil {
|
||||
h.router = (&chain.Router{}).WithLogger(h.options.Logger)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
|
||||
defer conn.Close()
|
||||
|
||||
log := h.options.Logger.WithFields(map[string]any{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
switch cc := conn.(type) {
|
||||
case *sshd_util.DirectForwardConn:
|
||||
return h.handleDirectForward(ctx, cc, log)
|
||||
case *sshd_util.RemoteForwardConn:
|
||||
return h.handleRemoteForward(ctx, cc, log)
|
||||
default:
|
||||
err := errors.New("sshd: wrong connection type")
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (h *forwardHandler) handleDirectForward(ctx context.Context, conn *sshd_util.DirectForwardConn, log logger.Logger) error {
|
||||
targetAddr := conn.DstAddr()
|
||||
|
||||
log = log.WithFields(map[string]any{
|
||||
"dst": fmt.Sprintf("%s/%s", targetAddr, "tcp"),
|
||||
"cmd": "connect",
|
||||
})
|
||||
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), targetAddr)
|
||||
|
||||
if h.options.Bypass != nil && h.options.Bypass.Contains(targetAddr) {
|
||||
log.Infof("bypass %s", targetAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
cc, err := h.router.Dial(ctx, "tcp", targetAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cc.Close()
|
||||
|
||||
t := time.Now()
|
||||
log.Infof("%s <-> %s", cc.LocalAddr(), targetAddr)
|
||||
netpkg.Transport(conn, cc)
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(t),
|
||||
}).Infof("%s >-< %s", cc.LocalAddr(), targetAddr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *forwardHandler) handleRemoteForward(ctx context.Context, conn *sshd_util.RemoteForwardConn, log logger.Logger) error {
|
||||
req := conn.Request()
|
||||
|
||||
t := tcpipForward{}
|
||||
if err := ssh.Unmarshal(req.Payload, &t); err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
network := "tcp"
|
||||
addr := net.JoinHostPort(t.Host, strconv.Itoa(int(t.Port)))
|
||||
|
||||
log = log.WithFields(map[string]any{
|
||||
"dst": fmt.Sprintf("%s/%s", addr, network),
|
||||
"cmd": "bind",
|
||||
})
|
||||
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), addr)
|
||||
|
||||
// tie to the client connection
|
||||
ln, err := net.Listen(network, addr)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
req.Reply(false, nil)
|
||||
return err
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
log = log.WithFields(map[string]any{
|
||||
"bind": fmt.Sprintf("%s/%s", ln.Addr(), ln.Addr().Network()),
|
||||
})
|
||||
log.Debugf("bind on %s OK", ln.Addr())
|
||||
|
||||
err = func() error {
|
||||
if t.Port == 0 && req.WantReply { // Client sent port 0. let them know which port is actually being used
|
||||
_, port, err := getHostPortFromAddr(ln.Addr())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var b [4]byte
|
||||
binary.BigEndian.PutUint32(b[:], uint32(port))
|
||||
t.Port = uint32(port)
|
||||
return req.Reply(true, b[:])
|
||||
}
|
||||
return req.Reply(true, nil)
|
||||
}()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
sshConn := conn.Conn()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
cc, err := ln.Accept()
|
||||
if err != nil { // Unable to accept new connection - listener is likely closed
|
||||
return
|
||||
}
|
||||
|
||||
go func(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
log := log.WithFields(map[string]any{
|
||||
"local": conn.LocalAddr().String(),
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
})
|
||||
|
||||
p := directForward{}
|
||||
var err error
|
||||
|
||||
var portnum int
|
||||
p.Host1 = t.Host
|
||||
p.Port1 = t.Port
|
||||
p.Host2, portnum, err = getHostPortFromAddr(conn.RemoteAddr())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.Port2 = uint32(portnum)
|
||||
ch, reqs, err := sshConn.OpenChannel(ForwardedTCPReturnRequest, ssh.Marshal(p))
|
||||
if err != nil {
|
||||
log.Error("open forwarded channel: ", err)
|
||||
return
|
||||
}
|
||||
defer ch.Close()
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
t := time.Now()
|
||||
log.Infof("%s <-> %s", conn.LocalAddr(), conn.RemoteAddr())
|
||||
netpkg.Transport(ch, conn)
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(t),
|
||||
}).Infof("%s >-< %s", conn.LocalAddr(), conn.RemoteAddr())
|
||||
}(cc)
|
||||
}
|
||||
}()
|
||||
|
||||
tm := time.Now()
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
|
||||
<-conn.Done()
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(tm),
|
||||
}).Infof("%s >-< %s", conn.RemoteAddr(), addr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) {
|
||||
host, portString, err := net.SplitHostPort(addr.String())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
port, err = strconv.Atoi(portString)
|
||||
return
|
||||
}
|
||||
|
||||
// directForward is structure for RFC 4254 7.2 - can be used for "forwarded-tcpip" and "direct-tcpip"
|
||||
type directForward struct {
|
||||
Host1 string
|
||||
Port1 uint32
|
||||
Host2 string
|
||||
Port2 uint32
|
||||
}
|
||||
|
||||
func (p directForward) String() string {
|
||||
return fmt.Sprintf("%s:%d -> %s:%d", p.Host2, p.Port2, p.Host1, p.Port1)
|
||||
}
|
||||
|
||||
// tcpipForward is structure for RFC 4254 7.1 "tcpip-forward" request
|
||||
type tcpipForward struct {
|
||||
Host string
|
||||
Port uint32
|
||||
}
|
12
handler/sshd/metadata.go
Normal file
12
handler/sshd/metadata.go
Normal file
@ -0,0 +1,12 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
}
|
||||
|
||||
func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) {
|
||||
return
|
||||
}
|
341
handler/tap/handler.go
Normal file
341
handler/tap/handler.go
Normal file
@ -0,0 +1,341 @@
|
||||
package tap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/chain"
|
||||
"github.com/go-gost/gost/v3/pkg/common/bufpool"
|
||||
"github.com/go-gost/gost/v3/pkg/handler"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
"github.com/go-gost/x/internal/util/ss"
|
||||
tap_util "github.com/go-gost/x/internal/util/tap"
|
||||
"github.com/shadowsocks/go-shadowsocks2/core"
|
||||
"github.com/shadowsocks/go-shadowsocks2/shadowaead"
|
||||
"github.com/songgao/water/waterutil"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.HandlerRegistry().Register("tap", NewHandler)
|
||||
}
|
||||
|
||||
type tapHandler struct {
|
||||
group *chain.NodeGroup
|
||||
routes sync.Map
|
||||
exit chan struct{}
|
||||
cipher core.Cipher
|
||||
router *chain.Router
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
|
||||
func NewHandler(opts ...handler.Option) handler.Handler {
|
||||
options := handler.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
return &tapHandler{
|
||||
exit: make(chan struct{}, 1),
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *tapHandler) Init(md md.Metadata) (err error) {
|
||||
if err = h.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if h.options.Auth != nil {
|
||||
method := h.options.Auth.Username()
|
||||
password, _ := h.options.Auth.Password()
|
||||
h.cipher, err = ss.ShadowCipher(method, password, h.md.key)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
h.router = h.options.Router
|
||||
if h.router == nil {
|
||||
h.router = (&chain.Router{}).WithLogger(h.options.Logger)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Forward implements handler.Forwarder.
|
||||
func (h *tapHandler) Forward(group *chain.NodeGroup) {
|
||||
h.group = group
|
||||
}
|
||||
|
||||
func (h *tapHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
|
||||
defer os.Exit(0)
|
||||
defer conn.Close()
|
||||
|
||||
log := h.options.Logger
|
||||
cc, ok := conn.(*tap_util.Conn)
|
||||
if !ok || cc.Config() == nil {
|
||||
err := errors.New("tap: wrong connection type")
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
log = log.WithFields(map[string]any{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
|
||||
network := "udp"
|
||||
var raddr net.Addr
|
||||
var err error
|
||||
|
||||
target := h.group.Next()
|
||||
if target != nil {
|
||||
raddr, err = net.ResolveUDPAddr(network, target.Addr)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
log = log.WithFields(map[string]any{
|
||||
"dst": fmt.Sprintf("%s/%s", raddr.String(), raddr.Network()),
|
||||
})
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr)
|
||||
}
|
||||
|
||||
h.handleLoop(ctx, conn, raddr, cc.Config(), log)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *tapHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Addr, config *tap_util.Config, log logger.Logger) {
|
||||
var tempDelay time.Duration
|
||||
for {
|
||||
err := func() error {
|
||||
var err error
|
||||
var pc net.PacketConn
|
||||
|
||||
if addr != nil {
|
||||
cc, err := h.router.Dial(ctx, addr.Network(), "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var ok bool
|
||||
pc, ok = cc.(net.PacketConn)
|
||||
if !ok {
|
||||
return errors.New("wrong connection type")
|
||||
}
|
||||
} else {
|
||||
laddr, _ := net.ResolveUDPAddr("udp", conn.LocalAddr().String())
|
||||
pc, err = net.ListenUDP("udp", laddr)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if h.cipher != nil {
|
||||
pc = h.cipher.PacketConn(pc)
|
||||
}
|
||||
defer pc.Close()
|
||||
|
||||
return h.transport(conn, pc, addr, config, log)
|
||||
}()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-h.exit:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if tempDelay == 0 {
|
||||
tempDelay = 1000 * time.Millisecond
|
||||
} else {
|
||||
tempDelay *= 2
|
||||
}
|
||||
if max := 6 * time.Second; tempDelay > max {
|
||||
tempDelay = max
|
||||
}
|
||||
time.Sleep(tempDelay)
|
||||
continue
|
||||
}
|
||||
tempDelay = 0
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (h *tapHandler) transport(tap net.Conn, conn net.PacketConn, raddr net.Addr, config *tap_util.Config, log logger.Logger) error {
|
||||
errc := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
err := func() error {
|
||||
b := bufpool.Get(h.md.bufferSize)
|
||||
defer bufpool.Put(b)
|
||||
|
||||
n, err := tap.Read(*b)
|
||||
if err != nil {
|
||||
select {
|
||||
case h.exit <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
src := waterutil.MACSource((*b)[:n])
|
||||
dst := waterutil.MACDestination((*b)[:n])
|
||||
eType := etherType(waterutil.MACEthertype((*b)[:n]))
|
||||
|
||||
log.Debugf("%s >> %s %s %d", src, dst, eType, n)
|
||||
|
||||
// client side, deliver frame directly.
|
||||
if raddr != nil {
|
||||
_, err := conn.WriteTo((*b)[:n], raddr)
|
||||
return err
|
||||
}
|
||||
|
||||
// server side, broadcast.
|
||||
if waterutil.IsBroadcast(dst) {
|
||||
go h.routes.Range(func(k, v any) bool {
|
||||
conn.WriteTo((*b)[:n], v.(net.Addr))
|
||||
return true
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
var addr net.Addr
|
||||
if v, ok := h.routes.Load(hwAddrToTapRouteKey(dst)); ok {
|
||||
addr = v.(net.Addr)
|
||||
}
|
||||
if addr == nil {
|
||||
log.Warnf("no route for %s -> %s %s %d", src, dst, eType, n)
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := conn.WriteTo((*b)[:n], addr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
errc <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
err := func() error {
|
||||
b := bufpool.Get(h.md.bufferSize)
|
||||
defer bufpool.Put(b)
|
||||
|
||||
n, addr, err := conn.ReadFrom(*b)
|
||||
if err != nil &&
|
||||
err != shadowaead.ErrShortPacket {
|
||||
return err
|
||||
}
|
||||
|
||||
src := waterutil.MACSource((*b)[:n])
|
||||
dst := waterutil.MACDestination((*b)[:n])
|
||||
eType := etherType(waterutil.MACEthertype((*b)[:n]))
|
||||
|
||||
log.Debugf("%s >> %s %s %d", src, dst, eType, n)
|
||||
|
||||
// client side, deliver frame to tap device.
|
||||
if raddr != nil {
|
||||
_, err := tap.Write((*b)[:n])
|
||||
return err
|
||||
}
|
||||
|
||||
// server side, record route.
|
||||
rkey := hwAddrToTapRouteKey(src)
|
||||
if actual, loaded := h.routes.LoadOrStore(rkey, addr); loaded {
|
||||
if actual.(net.Addr).String() != addr.String() {
|
||||
log.Debugf("update route: %s -> %s (old %s)",
|
||||
src, addr, actual.(net.Addr))
|
||||
h.routes.Store(rkey, addr)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("new route: %s -> %s", src, addr)
|
||||
}
|
||||
|
||||
if waterutil.IsBroadcast(dst) {
|
||||
go h.routes.Range(func(k, v any) bool {
|
||||
if k.(tapRouteKey) != rkey {
|
||||
conn.WriteTo((*b)[:n], v.(net.Addr))
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
if v, ok := h.routes.Load(hwAddrToTapRouteKey(dst)); ok {
|
||||
log.Debugf("find route: %s -> %s", dst, v)
|
||||
_, err := conn.WriteTo((*b)[:n], v.(net.Addr))
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := tap.Write((*b)[:n]); err != nil {
|
||||
select {
|
||||
case h.exit <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
errc <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
err := <-errc
|
||||
if err != nil && err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var mEtherTypes = map[waterutil.Ethertype]string{
|
||||
waterutil.IPv4: "ip",
|
||||
waterutil.ARP: "arp",
|
||||
waterutil.RARP: "rarp",
|
||||
waterutil.IPv6: "ip6",
|
||||
}
|
||||
|
||||
func etherType(et waterutil.Ethertype) string {
|
||||
if s, ok := mEtherTypes[et]; ok {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("unknown(%v)", et)
|
||||
}
|
||||
|
||||
type tapRouteKey [6]byte
|
||||
|
||||
func hwAddrToTapRouteKey(addr net.HardwareAddr) (key tapRouteKey) {
|
||||
copy(key[:], addr)
|
||||
return
|
||||
}
|
24
handler/tap/metadata.go
Normal file
24
handler/tap/metadata.go
Normal file
@ -0,0 +1,24 @@
|
||||
package tap
|
||||
|
||||
import (
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
key string
|
||||
bufferSize int
|
||||
}
|
||||
|
||||
func (h *tapHandler) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
key = "key"
|
||||
bufferSize = "bufferSize"
|
||||
)
|
||||
|
||||
h.md.key = mdata.GetString(md, key)
|
||||
h.md.bufferSize = mdata.GetInt(md, bufferSize)
|
||||
if h.md.bufferSize <= 0 {
|
||||
h.md.bufferSize = 1500
|
||||
}
|
||||
return
|
||||
}
|
391
handler/tun/handler.go
Normal file
391
handler/tun/handler.go
Normal file
@ -0,0 +1,391 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/v3/pkg/chain"
|
||||
"github.com/go-gost/gost/v3/pkg/common/bufpool"
|
||||
"github.com/go-gost/gost/v3/pkg/handler"
|
||||
"github.com/go-gost/gost/v3/pkg/logger"
|
||||
md "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
"github.com/go-gost/gost/v3/pkg/registry"
|
||||
"github.com/go-gost/x/internal/util/ss"
|
||||
tun_util "github.com/go-gost/x/internal/util/tun"
|
||||
"github.com/shadowsocks/go-shadowsocks2/core"
|
||||
"github.com/shadowsocks/go-shadowsocks2/shadowaead"
|
||||
"github.com/songgao/water/waterutil"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.HandlerRegistry().Register("tun", NewHandler)
|
||||
}
|
||||
|
||||
type tunHandler struct {
|
||||
group *chain.NodeGroup
|
||||
routes sync.Map
|
||||
exit chan struct{}
|
||||
cipher core.Cipher
|
||||
router *chain.Router
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
|
||||
func NewHandler(opts ...handler.Option) handler.Handler {
|
||||
options := handler.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
return &tunHandler{
|
||||
exit: make(chan struct{}, 1),
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *tunHandler) Init(md md.Metadata) (err error) {
|
||||
if err = h.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if h.options.Auth != nil {
|
||||
method := h.options.Auth.Username()
|
||||
password, _ := h.options.Auth.Password()
|
||||
h.cipher, err = ss.ShadowCipher(method, password, h.md.key)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
h.router = h.options.Router
|
||||
if h.router == nil {
|
||||
h.router = (&chain.Router{}).WithLogger(h.options.Logger)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Forward implements handler.Forwarder.
|
||||
func (h *tunHandler) Forward(group *chain.NodeGroup) {
|
||||
h.group = group
|
||||
}
|
||||
|
||||
func (h *tunHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
|
||||
defer os.Exit(0)
|
||||
defer conn.Close()
|
||||
|
||||
log := h.options.Logger
|
||||
|
||||
cc, ok := conn.(*tun_util.Conn)
|
||||
if !ok || cc.Config() == nil {
|
||||
err := errors.New("tun: wrong connection type")
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
log = log.WithFields(map[string]any{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
|
||||
network := "udp"
|
||||
var raddr net.Addr
|
||||
var err error
|
||||
|
||||
target := h.group.Next()
|
||||
if target != nil {
|
||||
raddr, err = net.ResolveUDPAddr(network, target.Addr)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
log = log.WithFields(map[string]any{
|
||||
"dst": fmt.Sprintf("%s/%s", raddr.String(), raddr.Network()),
|
||||
})
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr)
|
||||
}
|
||||
|
||||
h.handleLoop(ctx, conn, raddr, cc.Config(), log)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Addr, config *tun_util.Config, log logger.Logger) {
|
||||
var tempDelay time.Duration
|
||||
for {
|
||||
err := func() error {
|
||||
var err error
|
||||
var pc net.PacketConn
|
||||
if addr != nil {
|
||||
cc, err := h.router.Dial(ctx, addr.Network(), "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var ok bool
|
||||
pc, ok = cc.(net.PacketConn)
|
||||
if !ok {
|
||||
cc.Close()
|
||||
return errors.New("wrong connection type")
|
||||
}
|
||||
} else {
|
||||
laddr, _ := net.ResolveUDPAddr("udp", conn.LocalAddr().String())
|
||||
pc, err = net.ListenUDP("udp", laddr)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if h.cipher != nil {
|
||||
pc = h.cipher.PacketConn(pc)
|
||||
}
|
||||
defer pc.Close()
|
||||
|
||||
return h.transport(conn, pc, addr, config, log)
|
||||
}()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-h.exit:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if tempDelay == 0 {
|
||||
tempDelay = 1000 * time.Millisecond
|
||||
} else {
|
||||
tempDelay *= 2
|
||||
}
|
||||
if max := 6 * time.Second; tempDelay > max {
|
||||
tempDelay = max
|
||||
}
|
||||
time.Sleep(tempDelay)
|
||||
continue
|
||||
}
|
||||
tempDelay = 0
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr, config *tun_util.Config, log logger.Logger) error {
|
||||
errc := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
err := func() error {
|
||||
b := bufpool.Get(h.md.bufferSize)
|
||||
defer bufpool.Put(b)
|
||||
|
||||
n, err := tun.Read(*b)
|
||||
if err != nil {
|
||||
select {
|
||||
case h.exit <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var src, dst net.IP
|
||||
if waterutil.IsIPv4((*b)[:n]) {
|
||||
header, err := ipv4.ParseHeader((*b)[:n])
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return nil
|
||||
}
|
||||
log.Debugf("%s >> %s %-4s %d/%-4d %-4x %d",
|
||||
header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])),
|
||||
header.Len, header.TotalLen, header.ID, header.Flags)
|
||||
|
||||
src, dst = header.Src, header.Dst
|
||||
} else if waterutil.IsIPv6((*b)[:n]) {
|
||||
header, err := ipv6.ParseHeader((*b)[:n])
|
||||
if err != nil {
|
||||
log.Warn(err)
|
||||
return nil
|
||||
}
|
||||
log.Debugf("%s >> %s %s %d %d",
|
||||
header.Src, header.Dst,
|
||||
ipProtocol(waterutil.IPProtocol(header.NextHeader)),
|
||||
header.PayloadLen, header.TrafficClass)
|
||||
|
||||
src, dst = header.Src, header.Dst
|
||||
} else {
|
||||
log.Warn("unknown packet, discarded")
|
||||
return nil
|
||||
}
|
||||
|
||||
// client side, deliver packet directly.
|
||||
if raddr != nil {
|
||||
_, err := conn.WriteTo((*b)[:n], raddr)
|
||||
return err
|
||||
}
|
||||
|
||||
addr := h.findRouteFor(dst, config.Routes...)
|
||||
if addr == nil {
|
||||
log.Warnf("no route for %s -> %s", src, dst)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debugf("find route: %s -> %s", dst, addr)
|
||||
|
||||
if _, err := conn.WriteTo((*b)[:n], addr); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
errc <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
err := func() error {
|
||||
b := bufpool.Get(h.md.bufferSize)
|
||||
defer bufpool.Put(b)
|
||||
|
||||
n, addr, err := conn.ReadFrom(*b)
|
||||
if err != nil &&
|
||||
err != shadowaead.ErrShortPacket {
|
||||
return err
|
||||
}
|
||||
|
||||
var src, dst net.IP
|
||||
if waterutil.IsIPv4((*b)[:n]) {
|
||||
header, err := ipv4.ParseHeader((*b)[:n])
|
||||
if err != nil {
|
||||
log.Warn(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debugf("%s >> %s %-4s %d/%-4d %-4x %d",
|
||||
header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])),
|
||||
header.Len, header.TotalLen, header.ID, header.Flags)
|
||||
|
||||
src, dst = header.Src, header.Dst
|
||||
} else if waterutil.IsIPv6((*b)[:n]) {
|
||||
header, err := ipv6.ParseHeader((*b)[:n])
|
||||
if err != nil {
|
||||
log.Warn(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debugf("%s > %s %s %d %d",
|
||||
header.Src, header.Dst,
|
||||
ipProtocol(waterutil.IPProtocol(header.NextHeader)),
|
||||
header.PayloadLen, header.TrafficClass)
|
||||
|
||||
src, dst = header.Src, header.Dst
|
||||
} else {
|
||||
log.Warn("unknown packet, discarded")
|
||||
return nil
|
||||
}
|
||||
|
||||
// client side, deliver packet to tun device.
|
||||
if raddr != nil {
|
||||
_, err := tun.Write((*b)[:n])
|
||||
return err
|
||||
}
|
||||
|
||||
rkey := ipToTunRouteKey(src)
|
||||
if actual, loaded := h.routes.LoadOrStore(rkey, addr); loaded {
|
||||
if actual.(net.Addr).String() != addr.String() {
|
||||
log.Debugf("update route: %s -> %s (old %s)",
|
||||
src, addr, actual.(net.Addr))
|
||||
h.routes.Store(rkey, addr)
|
||||
}
|
||||
} else {
|
||||
log.Warnf("no route for %s -> %s", src, addr)
|
||||
}
|
||||
|
||||
if addr := h.findRouteFor(dst, config.Routes...); addr != nil {
|
||||
log.Debugf("find route: %s -> %s", dst, addr)
|
||||
|
||||
_, err := conn.WriteTo((*b)[:n], addr)
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := tun.Write((*b)[:n]); err != nil {
|
||||
select {
|
||||
case h.exit <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
errc <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
err := <-errc
|
||||
if err != nil && err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *tunHandler) findRouteFor(dst net.IP, routes ...tun_util.Route) net.Addr {
|
||||
if v, ok := h.routes.Load(ipToTunRouteKey(dst)); ok {
|
||||
return v.(net.Addr)
|
||||
}
|
||||
for _, route := range routes {
|
||||
if route.Net.Contains(dst) && route.Gateway != nil {
|
||||
if v, ok := h.routes.Load(ipToTunRouteKey(route.Gateway)); ok {
|
||||
return v.(net.Addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var mIPProts = map[waterutil.IPProtocol]string{
|
||||
waterutil.HOPOPT: "HOPOPT",
|
||||
waterutil.ICMP: "ICMP",
|
||||
waterutil.IGMP: "IGMP",
|
||||
waterutil.GGP: "GGP",
|
||||
waterutil.TCP: "TCP",
|
||||
waterutil.UDP: "UDP",
|
||||
waterutil.IPv6_Route: "IPv6-Route",
|
||||
waterutil.IPv6_Frag: "IPv6-Frag",
|
||||
waterutil.IPv6_ICMP: "IPv6-ICMP",
|
||||
}
|
||||
|
||||
func ipProtocol(p waterutil.IPProtocol) string {
|
||||
if v, ok := mIPProts[p]; ok {
|
||||
return v
|
||||
}
|
||||
return fmt.Sprintf("unknown(%d)", p)
|
||||
}
|
||||
|
||||
type tunRouteKey [16]byte
|
||||
|
||||
func ipToTunRouteKey(ip net.IP) (key tunRouteKey) {
|
||||
copy(key[:], ip.To16())
|
||||
return
|
||||
}
|
24
handler/tun/metadata.go
Normal file
24
handler/tun/metadata.go
Normal file
@ -0,0 +1,24 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
mdata "github.com/go-gost/gost/v3/pkg/metadata"
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
key string
|
||||
bufferSize int
|
||||
}
|
||||
|
||||
func (h *tunHandler) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
key = "key"
|
||||
bufferSize = "bufferSize"
|
||||
)
|
||||
|
||||
h.md.key = mdata.GetString(md, key)
|
||||
h.md.bufferSize = mdata.GetInt(md, bufferSize)
|
||||
if h.md.bufferSize <= 0 {
|
||||
h.md.bufferSize = 1500
|
||||
}
|
||||
return
|
||||
}
|
Reference in New Issue
Block a user