diff --git a/shadow/client.go b/shadow/client.go index 72e5aaa..f0d85e7 100644 --- a/shadow/client.go +++ b/shadow/client.go @@ -61,38 +61,37 @@ func handlerClient(conn net.Conn, serverAddress string, fakeAddressSNI string) { p := &PackAppData{Conn: dial.NetConn()} - go io.Copy(conn, p) - go io.Copy(p, conn) + defer p.Close() + defer conn.Close() + exitCh := make(chan int, 1) + + go MyCopy(conn, p, exitCh) + go MyCopy(p, conn, exitCh) + <-exitCh } -func MyCopy(src io.ReadWriteCloser, dst io.ReadWriteCloser) { +func MyCopy(src io.ReadWriteCloser, dst io.ReadWriteCloser, ch chan int) { buf := make([]byte, 32*1024) for { nr, er := src.Read(buf) - if nr > 0 { - nw, ew := dst.Write(buf[0:nr]) - if nw < 0 || nr < nw { - nw = 0 - if ew == nil { - fmt.Printf("err1:\n") - } + if er != nil { + if er == io.EOF { + break + } else { + fmt.Printf("Read err: %v\n", er) + break } + } else { + nw, ew := dst.Write(buf[0:nr]) if ew != nil { - fmt.Printf("err2:%v\n", ew) + fmt.Printf("Write error:%v\n", ew) break } if nr != nw { - fmt.Printf("nr != nw \n") + fmt.Printf("Write less then buffered \n") break } } - if er != nil { - if er != io.EOF { - fmt.Printf("err3:%v\n", er) - } else { - fmt.Println("EOF") - } - break - } } + ch <- 1 } diff --git a/shadow/client_test.go b/shadow/client_test.go index 6dcfef8..d4ac593 100644 --- a/shadow/client_test.go +++ b/shadow/client_test.go @@ -23,6 +23,14 @@ func TestName(t *testing.T) { } func TestName2(t *testing.T) { - v := "12345678" - fmt.Println(v[2:2]) + b := []byte("ABC") + encrypt, err := AesEncrypt(b, []byte("1234567812345678")) + if err != nil { + fmt.Println(err) + } + decrypt, err := AesDecrypt(encrypt, []byte("1234567812345678")) + if err != nil { + fmt.Println(err) + } + fmt.Println(string(decrypt)) } diff --git a/shadow/packer.go b/shadow/packer.go index 1f8edad..f8b7608 100644 --- a/shadow/packer.go +++ b/shadow/packer.go @@ -3,12 +3,14 @@ package shadow import ( "bytes" "encoding/binary" + "errors" "fmt" + "io" "net" ) var ( - AppDataHeader = []byte("GGGGGGGG") + AppDataHeader = []byte{0x17, 0x3, 0x3} HeaderLength = len(AppDataHeader) ) @@ -17,25 +19,48 @@ type PackAppData struct { } func (m PackAppData) Read(p []byte) (n int, err error) { - buf := make([]byte, 32*1024) - read, err := m.Conn.Read(buf[0 : HeaderLength+2]) + + buf := make([]byte, 32*1024+HeaderLength+2) + + headRead, err := io.ReadAtLeast(m.Conn, buf[0:HeaderLength+2], HeaderLength+2) + if err != nil { + fmt.Printf("Read header error: %v\n", err) + return 0, err + } + if headRead < HeaderLength+2 { + return 0, errors.New("Read header failed") + } if bytes.Equal(buf[0:HeaderLength], AppDataHeader) { - u := int(binary.BigEndian.Uint16(buf[HeaderLength : HeaderLength+2])) - r, err := m.Conn.Read(buf[HeaderLength+2 : u+HeaderLength+2]) - copy(p, buf[HeaderLength+2:r+HeaderLength+2]) - return r, err + payLoadLength := int(binary.BigEndian.Uint16(buf[HeaderLength : HeaderLength+2])) + sum := 0 + for sum < payLoadLength { + r, e := m.Conn.Read(buf[HeaderLength+2+sum : HeaderLength+2+payLoadLength]) + if e != nil { + if e == io.EOF { + break + } else { + return 0, e + } + } + copy(p[sum:], buf[HeaderLength+2+sum:HeaderLength+2+sum+r]) + sum += r + } + return sum, err } else { - fmt.Println("Header is not present") - return read, err + fmt.Printf("Invalid header") + return 0, errors.New("invalid header") } } func (m PackAppData) Write(p []byte) (n int, err error) { - t := make([]byte, len(p)+HeaderLength+2) - copy(t[0:], AppDataHeader) - binary.BigEndian.PutUint16(t[HeaderLength:], uint16(len(p))) - copy(t[HeaderLength+2:], p) - write, err := m.Conn.Write(t) + lenNum := make([]byte, 2) + binary.BigEndian.PutUint16(lenNum, uint16(len(p))) + + packetBuf := bytes.NewBuffer(AppDataHeader) + packetBuf.Write(lenNum) + packetBuf.Write(p) + + write, err := m.Conn.Write(packetBuf.Bytes()) write = write - HeaderLength - 2 return write, err } diff --git a/shadow/server.go b/shadow/server.go index 4f7d9f0..f76cb6b 100644 --- a/shadow/server.go +++ b/shadow/server.go @@ -70,8 +70,66 @@ func handler(conn net.Conn, targetAddress string, fakeAddress string) { } p := &PackAppData{Conn: conn} - go io.Copy(realConnection, p) - go io.Copy(p, realConnection) + + defer p.Close() + defer realConnection.Close() + exit := make(chan int, 1) + + go MyCopy(p, realConnection, exit) + go MyCopy(realConnection, p, exit) + <-exit + + //go func() { + // buf := make([]byte, 64*1024) + // for { + // nr, er := realConnection.Read(buf) + // if er != nil { + // if er == io.EOF { + // continue + // } else { + // fmt.Println("read err:", er) + // break + // } + // } else { + // lenNum := make([]byte, 2) + // binary.BigEndian.PutUint16(lenNum, uint16(nr)) + // + // packetBuf := bytes.NewBuffer(AppDataHeader) + // packetBuf.Write(lenNum) + // packetBuf.Write(buf[0:nr]) + // + // _, ew := conn.Write(packetBuf.Bytes()) + // if ew != nil { + // fmt.Printf("err2:%v\n", ew) + // break + // } + // } + // } + //}() + // + //go func() { + // result := bytes.NewBuffer(nil) + // var buf [65542]byte // 由于 标识数据包长度 的只有两个字节 故数据包最大为 2^16+4(魔数)+2(长度标识) + // for { + // n, er := conn.Read(buf[0:]) + // result.Write(buf[0:n]) + // if er != nil { + // if er == io.EOF { + // continue + // } else { + // fmt.Println("read err:", er) + // break + // } + // } else { + // scanner := bufio.NewScanner(result) + // scanner.Split(packetSlitFunc) + // for scanner.Scan() { + // realConnection.Write(scanner.Bytes()[HeaderLength+2:]) + // } + // } + // result.Reset() + // } + //}() } func processHandshake(src net.Conn, dst net.Conn, waitCh chan int) {