diff --git a/shadow/client.go b/shadow/client.go index 6e0406d..f194037 100644 --- a/shadow/client.go +++ b/shadow/client.go @@ -1,9 +1,11 @@ package shadow import ( + "crypto/tls" "fmt" "io" "net" + "time" ) type Client struct { @@ -22,8 +24,6 @@ func NewClient(listenAddress string, serverAddress string, fakeAddressSNI string } func (c *Client) Start() { - bridge := NewTLSBridge(c.ServerAddress, c.FakeAddressSNI) - listen, err := net.Listen("tcp", c.ListenAddress) if err != nil { fmt.Printf("[Client] Start client error: %v\n", err) @@ -37,15 +37,28 @@ func (c *Client) Start() { fmt.Printf("[Client] accept error: %v\n", err) continue } - - stream := bridge.GetStream() - if stream == nil { - conn.Close() - fmt.Println("[Client] connect to server error") - continue - } - fmt.Printf("[Client] New TCP connection: %v <-> %v \n", conn.LocalAddr().String(), conn.RemoteAddr().String()) - go io.Copy(conn, stream) - go io.Copy(stream, conn) + go handlerClient(conn, c.ServerAddress, c.FakeAddressSNI) } } + +func handlerClient(conn net.Conn, serverAddress string, fakeAddressSNI string) { + dial, err := tls.DialWithDialer(&net.Dialer{ + Timeout: time.Second * 5, + }, "tcp", serverAddress, &tls.Config{ + ServerName: fakeAddressSNI, + }) + if err != nil { + fmt.Printf("[Client] Dial server error: %v\n", err) + return + } + err = dial.Handshake() + if err != nil { + fmt.Printf("[Client] Handshake error: %v\n", err) + return + } + dial.NetConn().SetDeadline(time.Now()) + dial.NetConn().SetDeadline(time.Time{}) + + go io.Copy(conn, dial.NetConn()) + go io.Copy(dial.NetConn(), conn) +} diff --git a/shadow/server.go b/shadow/server.go index b77c994..e7aefe9 100644 --- a/shadow/server.go +++ b/shadow/server.go @@ -2,7 +2,6 @@ package shadow import ( "fmt" - "github.com/xtaci/smux" "io" "net" "time" @@ -61,21 +60,16 @@ func handler(conn net.Conn, targetAddress string, fakeAddress string) { conn.SetDeadline(time.Now()) conn.SetDeadline(time.Time{}) - //Process real tcp connection - session, err := smux.Server(conn, nil) + realConnection, err := net.Dial("tcp", targetAddress) if err != nil { - fmt.Printf("[Server] smux error: %v\n", err) + fmt.Printf("[Server] Dial target error : %v\n", err) return } - for { - stream, err := session.AcceptStream() - if err != nil { - fmt.Printf("[Server] AcceptStream error: %v\n", err) - break - } - go handlerMux(stream, targetAddress) + if err != nil { + return } - + go io.Copy(realConnection, conn) + go io.Copy(conn, realConnection) } func processHandshake(src net.Conn, dst net.Conn, waitCh chan int) { @@ -123,17 +117,3 @@ func processHandshake(src net.Conn, dst net.Conn, waitCh chan int) { } waitCh <- 1 } - -func handlerMux(conn *smux.Stream, targetAddress string) { - - realConnection, err := net.Dial("tcp", targetAddress) - if err != nil { - fmt.Printf("[Server] Dial target error : %v\n", err) - return - } - if err != nil { - return - } - go io.Copy(realConnection, conn) - go io.Copy(conn, realConnection) -} diff --git a/shadow/tls_bridge.go b/shadow/tls_bridge.go deleted file mode 100644 index 951e8b7..0000000 --- a/shadow/tls_bridge.go +++ /dev/null @@ -1,75 +0,0 @@ -package shadow - -import ( - "crypto/tls" - "github.com/xtaci/smux" - "net" - "sync" - "time" -) - -type TLSBridge struct { - session *smux.Session - locker sync.Mutex - serverAddress string - fakeAddressSNI string -} - -func NewTLSBridge(serverAddress string, fakeAddressSNI string) *TLSBridge { - t := &TLSBridge{ - session: nil, - locker: sync.Mutex{}, - serverAddress: serverAddress, - fakeAddressSNI: fakeAddressSNI, - } - return t -} - -func (t *TLSBridge) dial() error { - if t.session != nil { - t.session.Close() - } - dial, err := tls.DialWithDialer(&net.Dialer{ - Timeout: time.Second * 5, - }, "tcp", t.serverAddress, &tls.Config{ - ServerName: t.fakeAddressSNI, - }) - if err != nil { - return err - } - err = dial.Handshake() - if err != nil { - return err - } - dial.NetConn().SetDeadline(time.Now()) - dial.NetConn().SetDeadline(time.Time{}) - time.Sleep(time.Millisecond * 100) - session, err := smux.Client(dial.NetConn(), nil) - if err != nil { - return err - } - //force openStream to prevent first connection problem - session.OpenStream() - t.session = session - return nil -} - -func (t *TLSBridge) GetStream() *smux.Stream { - t.locker.Lock() - defer t.locker.Unlock() - - if t.session == nil { - err := t.dial() - if err != nil { - return nil - } - } - - openStream, err := t.session.OpenStream() - if err != nil { - t.session.Close() - t.session = nil - } - return openStream - -}