gost/sni_test.go
2022-08-03 10:04:41 +08:00

149 lines
2.6 KiB
Go

package gost
import (
"bufio"
"bytes"
"crypto/rand"
"crypto/tls"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
)
// sniRoundtrip sends data to targetURL through the proxy served by server and
// checks that the response body is identical to data.
//
// It dials the server's address with client, performs the transport handshake,
// asks the client to connect to the target URL's host and then writes a
// proxy-form GET request carrying data as its body. For an "https" target the
// connection is first wrapped in a TLS client with certificate verification
// disabled. A nil error is returned only when the response status is 200 OK
// and the received body equals data.
func sniRoundtrip(client *Client, server *Server, targetURL string, data []byte) (err error) {
	addr := server.Addr().String()

	conn, err := client.Dial(addr)
	if err != nil {
		return err
	}

	// Handshake's result replaces conn, so keep a handle to the dialed
	// connection and close it if the handshake fails; otherwise it would leak
	// because the deferred Close below is never registered on that path.
	dialed := conn
	conn, err = client.Handshake(dialed, AddrHandshakeOption(addr))
	if err != nil {
		dialed.Close()
		return err
	}
	// The receiver is bound here, so later reassignments of conn do not change
	// which connection gets closed.
	defer conn.Close()

	u, err := url.Parse(targetURL)
	if err != nil {
		return err
	}

	// Bound the whole exchange so a stalled proxy cannot hang the test.
	conn.SetDeadline(time.Now().Add(3 * time.Second))
	defer conn.SetDeadline(time.Time{})

	conn, err = client.Connect(conn, u.Host)
	if err != nil {
		return err
	}

	if u.Scheme == "https" {
		conn = tls.Client(conn,
			&tls.Config{
				InsecureSkipVerify: true,
				// ServerName: u.Hostname(),
			})
		// TLS is provided by the wrapped connection; the request itself is
		// built from an http URL.
		u.Scheme = "http"
	}

	req, err := http.NewRequest(
		http.MethodGet,
		u.String(),
		bytes.NewReader(data),
	)
	if err != nil {
		return err
	}
	if err = req.WriteProxy(conn); err != nil {
		return err
	}

	resp, err := http.ReadResponse(bufio.NewReader(conn), req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return errors.New(resp.Status)
	}

	recv, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return err
	}
	if !bytes.Equal(data, recv) {
		return errors.New("data not equal")
	}
	return nil
}
// sniProxyRoundtrip starts a temporary SNI proxy server and performs a single
// round trip to targetURL through it via sniRoundtrip.
//
// The server is built on the listener returned by TCPListener("") and an SNI
// handler configured with the target URL's host. The client pairs an SNI
// connector created from host with a TCP transporter. The error from
// sniRoundtrip, or from setting up the URL or listener, is returned.
func sniProxyRoundtrip(targetURL string, data []byte, host string) error {
	// Parse the URL before opening the listener so that a malformed URL cannot
	// leave an open listener behind.
	u, err := url.Parse(targetURL)
	if err != nil {
		return err
	}

	ln, err := TCPListener("")
	if err != nil {
		return err
	}

	client := &Client{
		Connector:   SNIConnector(host),
		Transporter: TCPTransporter(),
	}
	server := &Server{
		Listener: ln,
		Handler:  SNIHandler(HostHandlerOption(u.Host)),
	}

	// NOTE(review): presumably server.Close closes ln and makes Run return —
	// confirm against the Server implementation.
	go server.Run()
	defer server.Close()

	return sniRoundtrip(client, server, targetURL, data)
}
// TestSNIProxy runs round trips through the SNI proxy against a plain HTTP and
// a TLS httptest server, each with an empty and a non-empty connector host,
// and compares the outcome of every case with its expected pass flag.
func TestSNIProxy(t *testing.T) {
	httpSrv := httptest.NewServer(httpTestHandler)
	defer httpSrv.Close()

	httpsSrv := httptest.NewTLSServer(httpTestHandler)
	defer httpsSrv.Close()

	sendData := make([]byte, 128)
	// Fail fast if no random payload could be produced; otherwise the test
	// would silently run with an all-zero body.
	if _, err := rand.Read(sendData); err != nil {
		t.Fatalf("generating random payload: %v", err)
	}

	sniProxyTests := []struct {
		targetURL string
		host      string
		pass      bool
	}{
		{httpSrv.URL, "", true},
		{httpSrv.URL, "example.com", true},
		{httpsSrv.URL, "", true},
		{httpsSrv.URL, "example.com", true},
	}

	for i, tc := range sniProxyTests {
		tc := tc // per-iteration copy for toolchains older than Go 1.22
		t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
			err := sniProxyRoundtrip(tc.targetURL, sendData, tc.host)
			switch {
			case err == nil && !tc.pass:
				t.Errorf("#%d should failed", i)
			case err != nil && tc.pass:
				t.Errorf("#%d got error: %v", i, err)
			}
		})
	}
}