add vhost for port forwarding
This commit is contained in:
84
handler/forward/internal/forward/forward.go
Normal file
84
handler/forward/internal/forward/forward.go
Normal file
@ -0,0 +1,84 @@
|
||||
package forward
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
dissector "github.com/go-gost/tls-dissector"
|
||||
xio "github.com/go-gost/x/internal/io"
|
||||
)
|
||||
|
||||
func SniffHost(ctx context.Context, rdw io.ReadWriter) (rw io.ReadWriter, host string, err error) {
|
||||
rw = rdw
|
||||
|
||||
// try to sniff TLS traffic
|
||||
var hdr [dissector.RecordHeaderLen]byte
|
||||
_, err = io.ReadFull(rw, hdr[:])
|
||||
rw = xio.NewReadWriter(io.MultiReader(bytes.NewReader(hdr[:]), rw), rw)
|
||||
if err == nil &&
|
||||
hdr[0] == dissector.Handshake &&
|
||||
binary.BigEndian.Uint16(hdr[1:3]) == tls.VersionTLS10 {
|
||||
return sniffSNI(ctx, rw)
|
||||
}
|
||||
|
||||
// try to sniff HTTP traffic
|
||||
if isHTTP(string(hdr[:])) {
|
||||
buf := new(bytes.Buffer)
|
||||
var r *http.Request
|
||||
r, err = http.ReadRequest(bufio.NewReader(io.TeeReader(rw, buf)))
|
||||
rw = xio.NewReadWriter(io.MultiReader(buf, rw), rw)
|
||||
if err == nil {
|
||||
host = r.Host
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func sniffSNI(ctx context.Context, rw io.ReadWriter) (io.ReadWriter, string, error) {
|
||||
buf := new(bytes.Buffer)
|
||||
host, err := getServerName(ctx, io.TeeReader(rw, buf))
|
||||
rw = xio.NewReadWriter(io.MultiReader(buf, rw), rw)
|
||||
return rw, host, err
|
||||
}
|
||||
|
||||
func getServerName(ctx context.Context, r io.Reader) (host string, err error) {
|
||||
record, err := dissector.ReadRecord(r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
clientHello := dissector.ClientHelloMsg{}
|
||||
if err = clientHello.Decode(record.Opaque); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, ext := range clientHello.Extensions {
|
||||
if ext.Type() == dissector.ExtServerName {
|
||||
snExtension := ext.(*dissector.ServerNameExtension)
|
||||
host = snExtension.Name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func isHTTP(s string) bool {
|
||||
return strings.HasPrefix(http.MethodGet, s[:3]) ||
|
||||
strings.HasPrefix(http.MethodPost, s[:4]) ||
|
||||
strings.HasPrefix(http.MethodPut, s[:3]) ||
|
||||
strings.HasPrefix(http.MethodDelete, s) ||
|
||||
strings.HasPrefix(http.MethodOptions, s) ||
|
||||
strings.HasPrefix(http.MethodPatch, s) ||
|
||||
strings.HasPrefix(http.MethodHead, s[:4]) ||
|
||||
strings.HasPrefix(http.MethodConnect, s) ||
|
||||
strings.HasPrefix(http.MethodTrace, s)
|
||||
}
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
@ -11,6 +12,7 @@ import (
|
||||
"github.com/go-gost/core/handler"
|
||||
md "github.com/go-gost/core/metadata"
|
||||
xchain "github.com/go-gost/x/chain"
|
||||
"github.com/go-gost/x/handler/forward/internal/forward"
|
||||
netpkg "github.com/go-gost/x/internal/net"
|
||||
"github.com/go-gost/x/registry"
|
||||
)
|
||||
@ -84,18 +86,29 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
|
||||
return nil
|
||||
}
|
||||
|
||||
target := h.hop.Select(ctx)
|
||||
network := "tcp"
|
||||
if _, ok := conn.(net.PacketConn); ok {
|
||||
network = "udp"
|
||||
}
|
||||
|
||||
var rw io.ReadWriter
|
||||
var host string
|
||||
if h.md.sniffing {
|
||||
if network == "tcp" {
|
||||
rw, host, _ = forward.SniffHost(ctx, conn)
|
||||
}
|
||||
}
|
||||
|
||||
var target *chain.Node
|
||||
if h.hop != nil {
|
||||
target = h.hop.Select(ctx, chain.HostSelectOption(host))
|
||||
}
|
||||
if target == nil {
|
||||
err := errors.New("target not available")
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
network := "tcp"
|
||||
if _, ok := conn.(net.PacketConn); ok {
|
||||
network = "udp"
|
||||
}
|
||||
|
||||
log = log.WithFields(map[string]any{
|
||||
"dst": fmt.Sprintf("%s/%s", target.Addr, network),
|
||||
})
|
||||
@ -119,7 +132,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
|
||||
|
||||
t := time.Now()
|
||||
log.Debugf("%s <-> %s", conn.RemoteAddr(), target.Addr)
|
||||
netpkg.Transport(conn, cc)
|
||||
netpkg.Transport(rw, cc)
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(t),
|
||||
}).Debugf("%s >-< %s", conn.RemoteAddr(), target.Addr)
|
||||
|
@ -9,13 +9,16 @@ import (
|
||||
|
||||
type metadata struct {
|
||||
readTimeout time.Duration
|
||||
sniffing bool
|
||||
}
|
||||
|
||||
func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
readTimeout = "readTimeout"
|
||||
sniffing = "sniffing"
|
||||
)
|
||||
|
||||
h.md.readTimeout = mdutil.GetDuration(md, readTimeout)
|
||||
h.md.sniffing = mdutil.GetBool(md, sniffing)
|
||||
return
|
||||
}
|
||||
|
@ -4,12 +4,14 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/core/chain"
|
||||
"github.com/go-gost/core/handler"
|
||||
md "github.com/go-gost/core/metadata"
|
||||
"github.com/go-gost/x/handler/forward/internal/forward"
|
||||
netpkg "github.com/go-gost/x/internal/net"
|
||||
"github.com/go-gost/x/registry"
|
||||
)
|
||||
@ -75,9 +77,21 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
|
||||
return nil
|
||||
}
|
||||
|
||||
network := "tcp"
|
||||
if _, ok := conn.(net.PacketConn); ok {
|
||||
network = "udp"
|
||||
}
|
||||
|
||||
var rw io.ReadWriter
|
||||
var host string
|
||||
if h.md.sniffing {
|
||||
if network == "tcp" {
|
||||
rw, host, _ = forward.SniffHost(ctx, conn)
|
||||
}
|
||||
}
|
||||
var target *chain.Node
|
||||
if h.hop != nil {
|
||||
target = h.hop.Select(ctx)
|
||||
target = h.hop.Select(ctx, chain.HostSelectOption(host))
|
||||
}
|
||||
if target == nil {
|
||||
err := errors.New("target not available")
|
||||
@ -85,11 +99,6 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
|
||||
return err
|
||||
}
|
||||
|
||||
network := "tcp"
|
||||
if _, ok := conn.(net.PacketConn); ok {
|
||||
network = "udp"
|
||||
}
|
||||
|
||||
log = log.WithFields(map[string]any{
|
||||
"dst": fmt.Sprintf("%s/%s", target.Addr, network),
|
||||
})
|
||||
@ -113,7 +122,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
|
||||
|
||||
t := time.Now()
|
||||
log.Debugf("%s <-> %s", conn.RemoteAddr(), target.Addr)
|
||||
netpkg.Transport(conn, cc)
|
||||
netpkg.Transport(rw, cc)
|
||||
log.WithFields(map[string]any{
|
||||
"duration": time.Since(t),
|
||||
}).Debugf("%s >-< %s", conn.RemoteAddr(), target.Addr)
|
||||
|
@ -9,13 +9,16 @@ import (
|
||||
|
||||
type metadata struct {
|
||||
readTimeout time.Duration
|
||||
sniffing bool
|
||||
}
|
||||
|
||||
func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
readTimeout = "readTimeout"
|
||||
sniffing = "sniffing"
|
||||
)
|
||||
|
||||
h.md.readTimeout = mdutil.GetDuration(md, readTimeout)
|
||||
h.md.sniffing = mdutil.GetBool(md, sniffing)
|
||||
return
|
||||
}
|
||||
|
Reference in New Issue
Block a user