add pht(s)

This commit is contained in:
ginuerzh
2022-01-15 23:35:12 +08:00
parent f86fcb7eba
commit b7dd9dea3f
9 changed files with 689 additions and 4 deletions

159
pkg/dialer/pht/conn.go Normal file
View File

@ -0,0 +1,159 @@
package pht
import (
"bufio"
"bytes"
"encoding/base64"
"errors"
"fmt"
"net"
"net/http"
"time"
"github.com/go-gost/gost/pkg/logger"
)
type conn struct {
cid string
addr string
client *http.Client
tlsEnabled bool
buf []byte
rxc chan []byte
closed chan struct{}
md metadata
logger logger.Logger
}
func (c *conn) Read(b []byte) (n int, err error) {
if len(c.buf) == 0 {
select {
case c.buf = <-c.rxc:
case <-c.closed:
err = net.ErrClosed
return
}
}
n = copy(b, c.buf)
c.buf = c.buf[n:]
return
}
func (c *conn) Write(b []byte) (n int, err error) {
if len(b) == 0 {
return
}
buf := bytes.NewBufferString(base64.StdEncoding.EncodeToString(b))
buf.WriteByte('\n')
var url string
if c.tlsEnabled {
url = fmt.Sprintf("https://%s%s?token=%s", c.addr, c.md.pushPath, c.cid)
} else {
url = fmt.Sprintf("http://%s%s?token=%s", c.addr, c.md.pushPath, c.cid)
}
r, err := http.NewRequest(http.MethodPost, url, buf)
if err != nil {
return
}
resp, err := c.client.Do(r)
if err != nil {
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
err = errors.New(resp.Status)
return
}
n = len(b)
return
}
func (c *conn) readLoop() {
defer c.Close()
var url string
if c.tlsEnabled {
url = fmt.Sprintf("https://%s%s?token=%s", c.addr, c.md.pullPath, c.cid)
} else {
url = fmt.Sprintf("http://%s%s?token=%s", c.addr, c.md.pullPath, c.cid)
}
for {
err := func() error {
r, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return err
}
resp, err := c.client.Do(r)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return errors.New(resp.Status)
}
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
b, err := base64.StdEncoding.DecodeString(scanner.Text())
if err != nil {
return err
}
select {
case c.rxc <- b:
case <-c.closed:
return net.ErrClosed
}
}
return scanner.Err()
}()
if err != nil {
c.logger.Error(err)
return
}
}
}
func (c *conn) LocalAddr() net.Addr {
return &net.TCPAddr{}
}
func (c *conn) RemoteAddr() net.Addr {
addr, _ := net.ResolveTCPAddr("tcp", c.addr)
if addr == nil {
addr = &net.TCPAddr{}
}
return addr
}
func (c *conn) Close() error {
select {
case <-c.closed:
default:
close(c.closed)
}
return nil
}
func (c *conn) SetReadDeadline(t time.Time) error {
return nil
}
func (c *conn) SetWriteDeadline(t time.Time) error {
return nil
}
func (c *conn) SetDeadline(t time.Time) error {
return nil
}

143
pkg/dialer/pht/dialer.go Normal file
View File

@ -0,0 +1,143 @@
package pht
import (
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httputil"
"strings"
"time"
"github.com/go-gost/gost/pkg/dialer"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
)
func init() {
registry.RegisterDialer("pht", NewDialer)
registry.RegisterDialer("phts", NewTLSDialer)
}
type phtDialer struct {
tlsEnabled bool
md metadata
logger logger.Logger
options dialer.Options
}
func NewDialer(opts ...dialer.Option) dialer.Dialer {
options := dialer.Options{}
for _, opt := range opts {
opt(&options)
}
return &phtDialer{
logger: options.Logger,
options: options,
}
}
func NewTLSDialer(opts ...dialer.Option) dialer.Dialer {
options := dialer.Options{}
for _, opt := range opts {
opt(&options)
}
return &phtDialer{
tlsEnabled: true,
logger: options.Logger,
options: options,
}
}
func (d *phtDialer) Init(md md.Metadata) (err error) {
return d.parseMetadata(md)
}
func (d *phtDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) {
tr := &http.Transport{
// Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
if d.tlsEnabled {
tr.TLSClientConfig = d.options.TLSConfig
}
client := &http.Client{
Timeout: 60 * time.Second,
Transport: tr,
}
token, err := d.authorize(ctx, client, addr)
if err != nil {
d.logger.Error(err)
return nil, err
}
c := &conn{
cid: token,
addr: addr,
client: client,
tlsEnabled: d.tlsEnabled,
rxc: make(chan []byte, 128),
closed: make(chan struct{}),
md: d.md,
logger: d.logger,
}
go c.readLoop()
return c, nil
}
func (d *phtDialer) authorize(ctx context.Context, client *http.Client, addr string) (token string, err error) {
var url string
if d.tlsEnabled {
url = fmt.Sprintf("https://%s%s", addr, d.md.authorizePath)
} else {
url = fmt.Sprintf("http://%s%s", addr, d.md.authorizePath)
}
r, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return
}
if d.logger.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpRequest(r, false)
d.logger.Debug(string(dump))
}
resp, err := client.Do(r)
if err != nil {
return
}
defer resp.Body.Close()
if d.logger.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpResponse(resp, false)
d.logger.Debug(string(dump))
}
data, err := io.ReadAll(resp.Body)
if err != nil {
return
}
if strings.HasPrefix(string(data), "token=") {
token = strings.TrimPrefix(string(data), "token=")
}
if token == "" {
err = errors.New("authorize failed")
}
return
}

View File

@ -0,0 +1,48 @@
package pht
import (
"strings"
"time"
mdata "github.com/go-gost/gost/pkg/metadata"
)
const (
dialTimeout = "dialTimeout"
defaultAuthorizePath = "/authorize"
defaultPushPath = "/push"
defaultPullPath = "/pull"
)
const (
defaultDialTimeout = 5 * time.Second
)
type metadata struct {
dialTimeout time.Duration
authorizePath string
pushPath string
pullPath string
}
func (d *phtDialer) parseMetadata(md mdata.Metadata) (err error) {
const (
authorizePath = "authorizePath"
pushPath = "pushPath"
pullPath = "pullPath"
)
d.md.authorizePath = mdata.GetString(md, authorizePath)
if !strings.HasPrefix(d.md.authorizePath, "/") {
d.md.authorizePath = defaultAuthorizePath
}
d.md.pushPath = mdata.GetString(md, pushPath)
if !strings.HasPrefix(d.md.pushPath, "/") {
d.md.pushPath = defaultPushPath
}
d.md.pullPath = mdata.GetString(md, pullPath)
if !strings.HasPrefix(d.md.pullPath, "/") {
d.md.pullPath = defaultPullPath
}
return
}

20
pkg/listener/pht/conn.go Normal file
View File

@ -0,0 +1,20 @@
package pht
import (
"net"
)
// pht connection, wrapped up just like a net.Conn
type conn struct {
net.Conn
remoteAddr net.Addr
localAddr net.Addr
}
func (c *conn) LocalAddr() net.Addr {
return c.localAddr
}
func (c *conn) RemoteAddr() net.Addr {
return c.remoteAddr
}

View File

@ -0,0 +1,282 @@
// plain http tunnel
package pht
import (
"bufio"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"net"
"net/http"
"net/http/httputil"
"os"
"strings"
"sync"
"time"
"github.com/go-gost/gost/pkg/common/bufpool"
"github.com/go-gost/gost/pkg/listener"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
"github.com/rs/xid"
)
func init() {
registry.RegisterListener("pht", NewListener)
registry.RegisterListener("phts", NewTLSListener)
}
type phtListener struct {
tlsEnabled bool
server *http.Server
addr net.Addr
conns sync.Map
cqueue chan net.Conn
errChan chan error
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &phtListener{
logger: options.Logger,
options: options,
}
}
func NewTLSListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &phtListener{
tlsEnabled: true,
logger: options.Logger,
options: options,
}
}
func (l *phtListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
}
ln, err := net.Listen("tcp", l.options.Addr)
if err != nil {
return err
}
l.addr = ln.Addr()
mux := http.NewServeMux()
mux.HandleFunc("/authorize", l.handleAuthorize)
mux.HandleFunc("/push", l.handlePush)
mux.HandleFunc("/pull", l.handlePull)
l.server = &http.Server{
Addr: l.options.Addr,
Handler: mux,
}
if l.tlsEnabled {
l.server.TLSConfig = l.options.TLSConfig
ln = tls.NewListener(ln, l.options.TLSConfig)
}
l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1)
go func() {
if err := l.server.Serve(ln); err != nil {
l.logger.Error(err)
}
}()
return
}
func (l *phtListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.cqueue:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *phtListener) Addr() net.Addr {
return l.addr
}
func (l *phtListener) Close() (err error) {
select {
case <-l.errChan:
default:
err = l.server.Close()
l.errChan <- err
close(l.errChan)
}
return nil
}
func (l *phtListener) handleAuthorize(w http.ResponseWriter, r *http.Request) {
if l.logger.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpRequest(r, false)
l.logger.Debug(string(dump))
}
raddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
if raddr == nil {
raddr = &net.TCPAddr{}
}
// connection id
cid := xid.New().String()
c1, c2 := net.Pipe()
c := &conn{
Conn: c1,
localAddr: l.addr,
remoteAddr: raddr,
}
select {
case l.cqueue <- c:
default:
c.Close()
l.logger.Warnf("connection queue is full, client %s discarded", r.RemoteAddr)
w.WriteHeader(http.StatusTooManyRequests)
return
}
w.Write([]byte(fmt.Sprintf("token=%s", cid)))
l.conns.Store(cid, c2)
}
func (l *phtListener) handlePush(w http.ResponseWriter, r *http.Request) {
if l.logger.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpRequest(r, false)
l.logger.Debug(string(dump))
}
if r.Method != http.MethodPost {
w.WriteHeader(http.StatusBadRequest)
return
}
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
cid := r.Form.Get("token")
v, ok := l.conns.Load(cid)
if !ok {
w.WriteHeader(http.StatusForbidden)
return
}
conn := v.(net.Conn)
br := bufio.NewReader(r.Body)
data, err := br.ReadString('\n')
if err != nil {
l.logger.Error(err)
conn.Close()
l.conns.Delete(cid)
w.WriteHeader(http.StatusBadRequest)
return
}
data = strings.TrimSuffix(data, "\n")
if len(data) == 0 {
return
}
b, err := base64.StdEncoding.DecodeString(data)
if err != nil {
l.logger.Error(err)
l.conns.Delete(cid)
conn.Close()
w.WriteHeader(http.StatusBadRequest)
return
}
conn.SetWriteDeadline(time.Now().Add(30 * time.Second))
defer conn.SetWriteDeadline(time.Time{})
if _, err := conn.Write(b); err != nil {
l.logger.Error(err)
l.conns.Delete(cid)
conn.Close()
w.WriteHeader(http.StatusGone)
}
}
func (l *phtListener) handlePull(w http.ResponseWriter, r *http.Request) {
if l.logger.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpRequest(r, false)
l.logger.Debug(string(dump))
}
if r.Method != http.MethodGet {
w.WriteHeader(http.StatusBadRequest)
return
}
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
cid := r.Form.Get("token")
v, ok := l.conns.Load(cid)
if !ok {
w.WriteHeader(http.StatusForbidden)
return
}
conn := v.(net.Conn)
w.WriteHeader(http.StatusOK)
if fw, ok := w.(http.Flusher); ok {
fw.Flush()
}
b := bufpool.Get(4096)
defer bufpool.Put(b)
for {
conn.SetReadDeadline(time.Now().Add(10 * time.Second))
n, err := conn.Read(*b)
if err != nil {
if !errors.Is(err, os.ErrDeadlineExceeded) {
l.logger.Error(err)
l.conns.Delete(cid)
conn.Close()
} else {
(*b)[0] = '\n'
w.Write((*b)[:1])
}
return
}
bw := bufio.NewWriter(w)
bw.WriteString(base64.StdEncoding.EncodeToString((*b)[:n]))
bw.WriteString("\n")
if err := bw.Flush(); err != nil {
return
}
if fw, ok := w.(http.Flusher); ok {
fw.Flush()
}
}
}

View File

@ -0,0 +1,29 @@
package pht
import (
mdata "github.com/go-gost/gost/pkg/metadata"
)
const (
defaultBacklog = 128
)
type metadata struct {
path string
backlog int
}
func (l *phtListener) parseMetadata(md mdata.Metadata) (err error) {
const (
path = "path"
backlog = "backlog"
)
l.md.backlog = mdata.GetInt(md, backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
l.md.path = mdata.GetString(md, path)
return
}