diff --git a/shadow/client.go b/shadow/client.go index 9299a8d..669081b 100644 --- a/shadow/client.go +++ b/shadow/client.go @@ -41,7 +41,7 @@ func (c *Client) Start() { stream := bridge.GetStream() if stream == nil { conn.Close() - fmt.Printf("[Client] connect to server error: %v\n", err) + fmt.Printf("[Client] connect to server error") continue } fmt.Printf("[Client] New TCP connection: %v <-> %v \n", conn.LocalAddr().String(), conn.RemoteAddr().String()) diff --git a/shadow/client_test.go b/shadow/client_test.go index 28adb62..63dc7e4 100644 --- a/shadow/client_test.go +++ b/shadow/client_test.go @@ -1,128 +1,37 @@ package shadow import ( + "crypto/tls" "fmt" - "github.com/xtaci/smux" - "io" "net" "testing" + "time" ) -func Test(t *testing.T) { - listen, err := net.Listen("tcp", "0.0.0.0:11222") +func TestName(t *testing.T) { + dial, err := tls.DialWithDialer(&net.Dialer{ + Timeout: time.Second * 5, + }, "tcp", "www.baidu.com:443", &tls.Config{ + ServerName: "www.baidu.com", + }) + + err = dial.Handshake() if err != nil { - fmt.Printf("Start server failed : %v\n", err) - return - } - defer listen.Close() - for { - conn, err := listen.Accept() - if err != nil { - fmt.Printf("Accept error: %v\n", err) - continue - } - fmt.Println("A") - go handler(conn) + fmt.Println(err) } + time.Sleep(time.Minute) } -func handler(conn net.Conn) { +func TestName2(t *testing.T) { + dial, err := tls.DialWithDialer(&net.Dialer{ + Timeout: time.Second * 5, + }, "tcp", "evan.run:443", &tls.Config{ + ServerName: "evan.run", + }) - fakeConn, err := net.Dial("tcp", "127.0.0.1:5900") + err = dial.Handshake() if err != nil { - fmt.Printf("Dial fake failed : %v\n", err) - return + fmt.Println(err) } - - if err != nil { - return - } - go io.Copy(fakeConn, conn) - go io.Copy(conn, fakeConn) -} - -func TestSmuxServer(t *testing.T) { - listen, err := net.Listen("tcp", "0.0.0.0:7556") - if err != nil { - fmt.Printf("Start server failed : %v\n", err) - return - } - defer listen.Close() - for { - conn, err := listen.Accept() - if err != nil { - fmt.Printf("Accept error: %v\n", err) - continue - } - session, err := smux.Server(conn, nil) - if err != nil { - fmt.Printf("smux error: %v\n", err) - continue - } - - go func() { - for { - stream, err := session.AcceptStream() - if err != nil { - fmt.Printf("AcceptStream error: %v\n", err) - continue - } - fmt.Println("A") - go handlerMuxTest(stream) - } - }() - } - -} - -func TestSmuxClient(t *testing.T) { - listen, err := net.Listen("tcp", "0.0.0.0:11222") - if err != nil { - fmt.Printf("Start server failed : %v\n", err) - return - } - defer listen.Close() - - smuxConn, err := net.Dial("tcp", "127.0.0.1:7556") - if err != nil { - fmt.Printf("Start smuxConn failed : %v\n", err) - return - } - session, err := smux.Client(smuxConn, nil) - if err != nil { - fmt.Printf("Start smux.Client failed : %v\n", err) - return - } - for { - conn, err := listen.Accept() - if err != nil { - fmt.Printf("Accept error: %v\n", err) - continue - } - fmt.Println("A") - - stream, err := session.OpenStream() - if err != nil { - fmt.Printf("OpenStream error: %v\n", err) - continue - - } - go io.Copy(conn, stream) - go io.Copy(stream, conn) - } -} - -func handlerMuxTest(conn *smux.Stream) { - - fakeConn, err := net.Dial("tcp", "127.0.0.1:5900") - if err != nil { - fmt.Printf("Dial fake failed : %v\n", err) - return - } - fmt.Println("UUUUUUUUUUUUUU") - if err != nil { - return - } - go io.Copy(fakeConn, conn) - go io.Copy(conn, fakeConn) + time.Sleep(time.Minute) } diff --git a/shadow/server.go b/shadow/server.go index 1bd4f40..e9f5e16 100644 --- a/shadow/server.go +++ b/shadow/server.go @@ -49,12 +49,11 @@ func handler(conn net.Conn, targetAddress string, fakeAddress string) { fmt.Printf("[Server] Dial fake error : %v\n", err) return } - waitCh := make(chan int, 2) + waitCh := make(chan int, 1) go processHandshake(conn, fakeConn, waitCh) go processHandshake(fakeConn, conn, waitCh) - <-waitCh <-waitCh //Process real tcp connection @@ -83,8 +82,16 @@ func processHandshake(src net.Conn, dst net.Conn, waitCh chan int) { header := ParseAndVerifyTLSHeader(buf[0:nr]) nw, ew := dst.Write(buf[0:nr]) if header != nil && header.Type == ChangeCipherSpec { + //fmt.Println(header.toString()) fmt.Println("[Server] handshake complete") - dst.Close() + if header.ChangeCipherSpecNext == AppData { + dst.Close() + waitCh <- 1 + } else { + src.Close() + waitCh <- 1 + return + } break } if nw < 0 || nr < nw { diff --git a/shadow/tls_bridge.go b/shadow/tls_bridge.go index 19e49f6..f3876ed 100644 --- a/shadow/tls_bridge.go +++ b/shadow/tls_bridge.go @@ -45,6 +45,8 @@ func (t *TLSBridge) dial() error { if err != nil { return err } + //force openStream to prevent first connection problem + session.OpenStream() t.session = session return nil } diff --git a/shadow/tls_util.go b/shadow/tls_util.go index 6659619..e325372 100644 --- a/shadow/tls_util.go +++ b/shadow/tls_util.go @@ -22,10 +22,11 @@ const ( ) type TLSHeader struct { - Type uint8 - Version uint16 - Length int - HandshakeType uint8 + Type uint8 + Version uint16 + Length int + HandshakeType uint8 + ChangeCipherSpecNext uint8 } func (t *TLSHeader) toString() string { @@ -47,6 +48,14 @@ func (t *TLSHeader) toString() string { break case ChangeCipherSpec: desc += "Type=ChangeCipherSpec;" + switch t.ChangeCipherSpecNext { + case Handshake: + desc += "ChangeCipherSpecNext=Handshake;" + break + case AppData: + desc += "ChangeCipherSpecNext=AppData;" + break + } break case EncryptedAlert: desc += "Type=EncryptedAlert;" @@ -83,5 +92,8 @@ func ParseAndVerifyTLSHeader(data []byte) *TLSHeader { return nil } } + if header.Type == ChangeCipherSpec { + header.ChangeCipherSpecNext = data[6] + } return header }