package shadow

import (
	"fmt"
	"io"
	"net"
	"time"

	"github.com/xtaci/smux"
)

// Server accepts TCP connections on ListenAddress. For each client it first
// relays raw bytes to and from FakeAddress until a ChangeCipherSpec record is
// seen, then serves smux streams over the same client connection, forwarding
// every stream to TargetAddress.
type Server struct {
	// ListenAddress is the TCP address Start listens on.
	ListenAddress string
	// TargetAddress is dialed once per accepted smux stream.
	TargetAddress string
	// FakeAddress is dialed once per client connection to carry the relayed
	// handshake bytes.
	FakeAddress string
}

// NewServer returns a Server configured with the given addresses.
func NewServer(listenAddress string, targetAddress string, fakeAddress string) *Server {
	server := &Server{
		ListenAddress: listenAddress,
		TargetAddress: targetAddress,
		FakeAddress:   fakeAddress,
	}
	return server
}

// Start listens on s.ListenAddress and serves each accepted connection in its
// own goroutine. It returns only if the listener cannot be created; otherwise
// it loops forever. Errors are printed, not returned.
func (s *Server) Start() {
	listen, err := net.Listen("tcp", s.ListenAddress)
	if err != nil {
		fmt.Printf("[Server] Start server error: %v\n", err)
		return
	}
	defer listen.Close()
	fmt.Printf("[Server] Listening at:%v\n", s.ListenAddress)

	// Back off on consecutive Accept failures (e.g. file-descriptor
	// exhaustion) so a persistent error does not become a busy loop.
	const (
		minBackoff = 5 * time.Millisecond
		maxBackoff = time.Second
	)
	var backoff time.Duration
	for {
		conn, err := listen.Accept()
		if err != nil {
			fmt.Printf("[Server] Accept error: %v\n", err)
			if backoff == 0 {
				backoff = minBackoff
			} else {
				backoff *= 2
			}
			if backoff > maxBackoff {
				backoff = maxBackoff
			}
			time.Sleep(backoff)
			continue
		}
		backoff = 0
		go handler(conn, s.TargetAddress, s.FakeAddress)
	}
}

// handler serves one client connection.
//
// Phase 1 dials fakeAddress and relays bytes in both directions via
// processHandshake. Each direction stops at a ChangeCipherSpec record (closing
// its destination) or on a read/write error. handler waits for both.
//
// Phase 2 runs an smux server over conn and hands every accepted stream to
// handlerMux.
//
// conn is always closed when handler returns.
func handler(conn net.Conn, targetAddress string, fakeAddress string) {
	// This goroutine owns conn; release it on every exit path.
	defer conn.Close()

	// Process fake TLS handshake.
	fmt.Println("[Server] Perform handshake")
	fakeConn, err := net.Dial("tcp", fakeAddress)
	if err != nil {
		fmt.Printf("[Server] Dial fake error : %v\n", err)
		return
	}
	waitCh := make(chan int, 2)
	go processHandshake(conn, fakeConn, waitCh)
	go processHandshake(fakeConn, conn, waitCh)
	<-waitCh
	<-waitCh
	// Both relay goroutines have returned, so fakeConn is no longer in use.
	// It is often already closed by processHandshake; closing it again only
	// returns an ignored error, and this covers exits on read/write errors.
	fakeConn.Close()
	// NOTE(review): if a peer stops sending without closing, the relay
	// goroutines (and this handler) block indefinitely; consider deadlines.

	// Process real TCP connection.
	session, err := smux.Server(conn, nil)
	if err != nil {
		fmt.Printf("[Server] smux error: %v\n", err)
		return
	}
	// Release the session once it stops accepting streams.
	defer session.Close()
	for {
		stream, err := session.AcceptStream()
		if err != nil {
			fmt.Printf("[Server] AcceptStream error: %v\n", err)
			return
		}
		go handlerMux(stream, targetAddress)
	}
}

// processHandshake copies bytes from src to dst. It stops when:
//   - a chunk is read whose header ParseAndVerifyTLSHeader reports as
//     ChangeCipherSpec (the chunk is forwarded first, then dst is closed);
//   - a write fails or is short;
//   - src returns an error, including io.EOF.
//
// It sends exactly one value on waitCh when it returns.
func processHandshake(src net.Conn, dst net.Conn, waitCh chan int) {
	// Signal completion on every return path.
	defer func() { waitCh <- 1 }()

	buf := make([]byte, 32*1024)
	for {
		nr, er := src.Read(buf)
		if nr > 0 {
			chunk := buf[0:nr]
			// NOTE(review): only the start of each Read is inspected. Presumably
			// ParseAndVerifyTLSHeader returns nil when chunk does not begin with
			// a valid TLS record header; confirm against its definition.
			header := ParseAndVerifyTLSHeader(chunk)
			nw, ew := dst.Write(chunk)
			if header != nil && header.Type == ChangeCipherSpec {
				fmt.Println("[Server] handshake complete")
				dst.Close()
				return
			}
			// A write error, a short write, or an invalid count from a
			// misbehaving writer all end the relay.
			if ew != nil || nw != nr {
				return
			}
		}
		if er != nil {
			// io.EOF is the normal end of stream; any other error also ends it.
			return
		}
	}
}

// handlerMux dials targetAddress and relays data in both directions between
// the smux stream conn and the new connection. As soon as either direction
// finishes, both conn and the target connection are closed.
func handlerMux(conn *smux.Stream, targetAddress string) {
	defer conn.Close()
	realConnection, err := net.Dial("tcp", targetAddress)
	if err != nil {
		fmt.Printf("[Server] Dial target error : %v\n", err)
		return
	}
	defer realConnection.Close()

	// Capacity 2 lets the copier still running when we return deliver its
	// signal and exit instead of blocking forever.
	done := make(chan struct{}, 2)
	go func() {
		// Copy errors are expected: either side closing ends the relay.
		_, _ = io.Copy(realConnection, conn)
		done <- struct{}{}
	}()
	go func() {
		_, _ = io.Copy(conn, realConnection)
		done <- struct{}{}
	}()
	// Wait for the first direction to finish. The deferred Closes then
	// unblock the other io.Copy.
	<-done
}