add vhost for port forwarding

This commit is contained in:
ginuerzh
2022-11-11 22:23:36 +08:00
parent 81b6efc9b8
commit 1ff2bab1f0
12 changed files with 195 additions and 23 deletions

View File

@ -0,0 +1,84 @@
package forward
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"encoding/binary"
"io"
"net/http"
"strings"
dissector "github.com/go-gost/tls-dissector"
xio "github.com/go-gost/x/internal/io"
)
func SniffHost(ctx context.Context, rdw io.ReadWriter) (rw io.ReadWriter, host string, err error) {
rw = rdw
// try to sniff TLS traffic
var hdr [dissector.RecordHeaderLen]byte
_, err = io.ReadFull(rw, hdr[:])
rw = xio.NewReadWriter(io.MultiReader(bytes.NewReader(hdr[:]), rw), rw)
if err == nil &&
hdr[0] == dissector.Handshake &&
binary.BigEndian.Uint16(hdr[1:3]) == tls.VersionTLS10 {
return sniffSNI(ctx, rw)
}
// try to sniff HTTP traffic
if isHTTP(string(hdr[:])) {
buf := new(bytes.Buffer)
var r *http.Request
r, err = http.ReadRequest(bufio.NewReader(io.TeeReader(rw, buf)))
rw = xio.NewReadWriter(io.MultiReader(buf, rw), rw)
if err == nil {
host = r.Host
return
}
}
return
}
func sniffSNI(ctx context.Context, rw io.ReadWriter) (io.ReadWriter, string, error) {
buf := new(bytes.Buffer)
host, err := getServerName(ctx, io.TeeReader(rw, buf))
rw = xio.NewReadWriter(io.MultiReader(buf, rw), rw)
return rw, host, err
}
func getServerName(ctx context.Context, r io.Reader) (host string, err error) {
record, err := dissector.ReadRecord(r)
if err != nil {
return
}
clientHello := dissector.ClientHelloMsg{}
if err = clientHello.Decode(record.Opaque); err != nil {
return
}
for _, ext := range clientHello.Extensions {
if ext.Type() == dissector.ExtServerName {
snExtension := ext.(*dissector.ServerNameExtension)
host = snExtension.Name
break
}
}
return
}
func isHTTP(s string) bool {
return strings.HasPrefix(http.MethodGet, s[:3]) ||
strings.HasPrefix(http.MethodPost, s[:4]) ||
strings.HasPrefix(http.MethodPut, s[:3]) ||
strings.HasPrefix(http.MethodDelete, s) ||
strings.HasPrefix(http.MethodOptions, s) ||
strings.HasPrefix(http.MethodPatch, s) ||
strings.HasPrefix(http.MethodHead, s[:4]) ||
strings.HasPrefix(http.MethodConnect, s) ||
strings.HasPrefix(http.MethodTrace, s)
}

View File

@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"time"
@ -11,6 +12,7 @@ import (
"github.com/go-gost/core/handler"
md "github.com/go-gost/core/metadata"
xchain "github.com/go-gost/x/chain"
"github.com/go-gost/x/handler/forward/internal/forward"
netpkg "github.com/go-gost/x/internal/net"
"github.com/go-gost/x/registry"
)
@ -84,18 +86,29 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
return nil
}
target := h.hop.Select(ctx)
network := "tcp"
if _, ok := conn.(net.PacketConn); ok {
network = "udp"
}
var rw io.ReadWriter
var host string
if h.md.sniffing {
if network == "tcp" {
rw, host, _ = forward.SniffHost(ctx, conn)
}
}
var target *chain.Node
if h.hop != nil {
target = h.hop.Select(ctx, chain.HostSelectOption(host))
}
if target == nil {
err := errors.New("target not available")
log.Error(err)
return err
}
network := "tcp"
if _, ok := conn.(net.PacketConn); ok {
network = "udp"
}
log = log.WithFields(map[string]any{
"dst": fmt.Sprintf("%s/%s", target.Addr, network),
})
@ -119,7 +132,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
t := time.Now()
log.Debugf("%s <-> %s", conn.RemoteAddr(), target.Addr)
netpkg.Transport(conn, cc)
netpkg.Transport(rw, cc)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Debugf("%s >-< %s", conn.RemoteAddr(), target.Addr)

View File

@ -9,13 +9,16 @@ import (
type metadata struct {
readTimeout time.Duration
sniffing bool
}
func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) {
const (
readTimeout = "readTimeout"
sniffing = "sniffing"
)
h.md.readTimeout = mdutil.GetDuration(md, readTimeout)
h.md.sniffing = mdutil.GetBool(md, sniffing)
return
}

View File

@ -4,12 +4,14 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"time"
"github.com/go-gost/core/chain"
"github.com/go-gost/core/handler"
md "github.com/go-gost/core/metadata"
"github.com/go-gost/x/handler/forward/internal/forward"
netpkg "github.com/go-gost/x/internal/net"
"github.com/go-gost/x/registry"
)
@ -75,9 +77,21 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
return nil
}
network := "tcp"
if _, ok := conn.(net.PacketConn); ok {
network = "udp"
}
var rw io.ReadWriter
var host string
if h.md.sniffing {
if network == "tcp" {
rw, host, _ = forward.SniffHost(ctx, conn)
}
}
var target *chain.Node
if h.hop != nil {
target = h.hop.Select(ctx)
target = h.hop.Select(ctx, chain.HostSelectOption(host))
}
if target == nil {
err := errors.New("target not available")
@ -85,11 +99,6 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
return err
}
network := "tcp"
if _, ok := conn.(net.PacketConn); ok {
network = "udp"
}
log = log.WithFields(map[string]any{
"dst": fmt.Sprintf("%s/%s", target.Addr, network),
})
@ -113,7 +122,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
t := time.Now()
log.Debugf("%s <-> %s", conn.RemoteAddr(), target.Addr)
netpkg.Transport(conn, cc)
netpkg.Transport(rw, cc)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Debugf("%s >-< %s", conn.RemoteAddr(), target.Addr)

View File

@ -9,13 +9,16 @@ import (
type metadata struct {
readTimeout time.Duration
sniffing bool
}
func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) {
const (
readTimeout = "readTimeout"
sniffing = "sniffing"
)
h.md.readTimeout = mdutil.GetDuration(md, readTimeout)
h.md.sniffing = mdutil.GetBool(md, sniffing)
return
}