118 lines
2.1 KiB
Go
118 lines
2.1 KiB
Go
package relay
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
|
|
"github.com/go-gost/gost/pkg/logger"
|
|
"github.com/go-gost/relay"
|
|
)
|
|
|
|
type conn struct {
|
|
net.Conn
|
|
udp bool
|
|
wbuf bytes.Buffer
|
|
once sync.Once
|
|
headerSent bool
|
|
logger logger.Logger
|
|
}
|
|
|
|
func (c *conn) Read(b []byte) (n int, err error) {
|
|
c.once.Do(func() {
|
|
resp := relay.Response{}
|
|
_, err = resp.ReadFrom(c.Conn)
|
|
if err != nil {
|
|
return
|
|
}
|
|
if resp.Version != relay.Version1 {
|
|
err = relay.ErrBadVersion
|
|
return
|
|
}
|
|
if resp.Status != relay.StatusOK {
|
|
err = fmt.Errorf("status %d", resp.Status)
|
|
return
|
|
}
|
|
})
|
|
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if !c.udp {
|
|
return c.Conn.Read(b)
|
|
}
|
|
|
|
var bb [2]byte
|
|
_, err = io.ReadFull(c.Conn, bb[:])
|
|
if err != nil {
|
|
return
|
|
}
|
|
dlen := int(binary.BigEndian.Uint16(bb[:]))
|
|
if len(b) >= dlen {
|
|
return io.ReadFull(c.Conn, b[:dlen])
|
|
}
|
|
buf := make([]byte, dlen)
|
|
_, err = io.ReadFull(c.Conn, buf)
|
|
n = copy(b, buf)
|
|
return
|
|
}
|
|
|
|
func (c *conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
|
n, err = c.Read(b)
|
|
addr = c.Conn.RemoteAddr()
|
|
return
|
|
}
|
|
|
|
func (c *conn) Write(b []byte) (n int, err error) {
|
|
if len(b) > 0xFFFF {
|
|
err = errors.New("write: data maximum exceeded")
|
|
return
|
|
}
|
|
n = len(b) // force byte length consistent
|
|
if c.wbuf.Len() > 0 {
|
|
if c.udp {
|
|
var bb [2]byte
|
|
binary.BigEndian.PutUint16(bb[:2], uint16(len(b)))
|
|
c.wbuf.Write(bb[:])
|
|
c.headerSent = true
|
|
}
|
|
c.wbuf.Write(b) // append the data to the cached header
|
|
// _, err = c.Conn.Write(c.wbuf.Bytes())
|
|
// c.wbuf.Reset()
|
|
_, err = c.wbuf.WriteTo(c.Conn)
|
|
return
|
|
}
|
|
|
|
if !c.udp {
|
|
return c.Conn.Write(b)
|
|
}
|
|
if !c.headerSent {
|
|
c.headerSent = true
|
|
b2 := make([]byte, len(b)+2)
|
|
copy(b2, b)
|
|
_, err = c.Conn.Write(b2)
|
|
return
|
|
}
|
|
nsize := 2 + len(b)
|
|
var buf []byte
|
|
if nsize <= mediumBufferSize {
|
|
buf = mPool.Get().([]byte)
|
|
defer mPool.Put(buf)
|
|
} else {
|
|
buf = make([]byte, nsize)
|
|
}
|
|
binary.BigEndian.PutUint16(buf[:2], uint16(len(b)))
|
|
n = copy(buf[2:], b)
|
|
_, err = c.Conn.Write(buf[:nsize])
|
|
return
|
|
}
|
|
|
|
func (c *relayConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
|
return c.Write(b)
|
|
}
|