370 lines
7.7 KiB
Go
370 lines
7.7 KiB
Go
package gost
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/url"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/go-gost/relay"
|
|
"github.com/go-log/log"
|
|
)
|
|
|
|
// relayConnector is the client-side Connector implementation for the relay
// protocol.
type relayConnector struct {
	// user holds optional credentials. When non-nil, ConnectContext sends them
	// to the server as a relay.UserAuthFeature.
	user *url.Userinfo

	// remoteAddr is not referenced by any code in this part of the file.
	// NOTE(review): possibly unused — confirm before removing.
	remoteAddr string
}
|
|
|
|
// RelayConnector creates a Connector for TCP/UDP data relay.
|
|
func RelayConnector(user *url.Userinfo) Connector {
|
|
return &relayConnector{
|
|
user: user,
|
|
}
|
|
}
|
|
|
|
// Connect returns conn unchanged and never fails. It performs no relay
// handshake; address and options are ignored. The relay request is built in
// ConnectContext instead.
func (c *relayConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
	return conn, nil
}
|
|
|
|
func (c *relayConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
|
|
opts := &ConnectOptions{}
|
|
for _, option := range options {
|
|
option(opts)
|
|
}
|
|
|
|
timeout := opts.Timeout
|
|
if timeout <= 0 {
|
|
timeout = ConnectTimeout
|
|
}
|
|
|
|
conn.SetDeadline(time.Now().Add(timeout))
|
|
defer conn.SetDeadline(time.Time{})
|
|
|
|
var udp bool
|
|
if network == "udp" || network == "udp4" || network == "udp6" {
|
|
udp = true
|
|
}
|
|
|
|
req := &relay.Request{
|
|
Version: relay.Version1,
|
|
}
|
|
if udp {
|
|
req.Flags |= relay.FUDP
|
|
}
|
|
|
|
if c.user != nil {
|
|
pwd, _ := c.user.Password()
|
|
req.Features = append(req.Features, &relay.UserAuthFeature{
|
|
Username: c.user.Username(),
|
|
Password: pwd,
|
|
})
|
|
}
|
|
if address != "" {
|
|
host, port, _ := net.SplitHostPort(address)
|
|
nport, _ := strconv.ParseUint(port, 10, 16)
|
|
if host == "" {
|
|
host = net.IPv4zero.String()
|
|
}
|
|
|
|
if nport > 0 {
|
|
var atype uint8
|
|
ip := net.ParseIP(host)
|
|
if ip == nil {
|
|
atype = relay.AddrDomain
|
|
} else if ip.To4() == nil {
|
|
atype = relay.AddrIPv6
|
|
} else {
|
|
atype = relay.AddrIPv4
|
|
}
|
|
|
|
req.Features = append(req.Features, &relay.AddrFeature{
|
|
AType: atype,
|
|
Host: host,
|
|
Port: uint16(nport),
|
|
})
|
|
}
|
|
}
|
|
|
|
rc := &relayConn{
|
|
udp: udp,
|
|
Conn: conn,
|
|
}
|
|
|
|
// write the header at once.
|
|
if opts.NoDelay {
|
|
if _, err := req.WriteTo(rc); err != nil {
|
|
return nil, err
|
|
}
|
|
} else {
|
|
if _, err := req.WriteTo(&rc.wbuf); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return rc, nil
|
|
}
|
|
|
|
// relayHandler is the server-side Handler for the relay protocol. It embeds
// *baseForwardHandler, from which it uses the options, group and Init members
// referenced elsewhere in this file.
type relayHandler struct {
	*baseForwardHandler
}
|
|
|
|
// RelayHandler creates a server Handler for TCP/UDP relay server.
|
|
func RelayHandler(raddr string, opts ...HandlerOption) Handler {
|
|
h := &relayHandler{
|
|
baseForwardHandler: &baseForwardHandler{
|
|
raddr: raddr,
|
|
group: NewNodeGroup(),
|
|
options: &HandlerOptions{},
|
|
},
|
|
}
|
|
for _, opt := range opts {
|
|
opt(h.options)
|
|
}
|
|
return h
|
|
}
|
|
|
|
// Init forwards the given options to the embedded baseForwardHandler's Init.
func (h *relayHandler) Init(options ...HandlerOption) {
	h.baseForwardHandler.Init(options...)
}
|
|
|
|
// Handle serves one relay client connection and closes conn when it returns.
//
// It reads the relay request, checks the protocol version, extracts the
// optional credentials and target address from the request features,
// authenticates the client, validates the target, dials it (with retries) and
// then pipes data between the client and the target until transport returns.
//
// Failures after a valid request header are reported to the client through a
// relay.Response with a non-OK status; a read error or bad version just logs
// and drops the connection without a response.
func (h *relayHandler) Handle(conn net.Conn) {
	defer conn.Close()

	req := &relay.Request{}
	if _, err := req.ReadFrom(conn); err != nil {
		log.Logf("[relay] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err)
		return
	}

	if req.Version != relay.Version1 {
		log.Logf("[relay] %s - %s : bad version", conn.RemoteAddr(), conn.LocalAddr())
		return
	}

	// Collect credentials and the requested target address from the request
	// features. If a feature type appears more than once, the last one wins.
	var user, pass string
	var raddr string
	for _, f := range req.Features {
		if f.Type() == relay.FeatureUserAuth {
			feature := f.(*relay.UserAuthFeature)
			user, pass = feature.Username, feature.Password
		}
		if f.Type() == relay.FeatureAddr {
			feature := f.(*relay.AddrFeature)
			raddr = net.JoinHostPort(feature.Host, strconv.Itoa(int(feature.Port)))
		}
	}

	// The response starts out as OK; error paths overwrite Status and send it
	// immediately, the success path defers sending it (see below).
	resp := &relay.Response{
		Version: relay.Version1,
		Status:  relay.StatusOK,
	}
	if h.options.Authenticator != nil && !h.options.Authenticator.Authenticate(user, pass) {
		resp.Status = relay.StatusUnauthorized
		resp.WriteTo(conn)
		log.Logf("[relay] %s -> %s : %s unauthorized", conn.RemoteAddr(), conn.LocalAddr(), user)
		return
	}

	// A client-supplied target and server-configured forward nodes are
	// mutually exclusive: with nodes configured the client may not choose a
	// target, and without nodes the client must supply one.
	if raddr != "" {
		if len(h.group.Nodes()) > 0 {
			resp.Status = relay.StatusForbidden
			resp.WriteTo(conn)
			log.Logf("[relay] %s -> %s : relay to %s is forbidden",
				conn.RemoteAddr(), conn.LocalAddr(), raddr)
			return
		}
	} else {
		if len(h.group.Nodes()) == 0 {
			resp.Status = relay.StatusBadRequest
			resp.WriteTo(conn)
			log.Logf("[relay] %s -> %s : bad request, target addr is needed",
				conn.RemoteAddr(), conn.LocalAddr())
			return
		}
	}

	// Retry count: default 1, overridden by the chain's setting, which is in
	// turn overridden by the handler's own Retries option.
	udp := (req.Flags & relay.FUDP) == relay.FUDP
	retries := 1
	if h.options.Chain != nil && h.options.Chain.Retries > 0 {
		retries = h.options.Chain.Retries
	}
	if h.options.Retries > 0 {
		retries = h.options.Retries
	}

	network := "tcp"
	if udp {
		network = "udp"
	}
	// NOTE(review): this check runs before a forward node is selected, so when
	// forwarding to configured nodes raddr is still the empty string here —
	// confirm that Can treats an empty address as intended.
	if !Can(network, raddr, h.options.Whitelist, h.options.Blacklist) {
		resp.Status = relay.StatusForbidden
		resp.WriteTo(conn)
		log.Logf("[relay] %s -> %s : relay to %s is forbidden",
			conn.RemoteAddr(), conn.LocalAddr(), raddr)
		return
	}

	ctx := context.TODO()
	var cc net.Conn
	var node Node
	var err error
	for i := 0; i < retries; i++ {
		// In forward mode pick the next node on every attempt; its address
		// replaces raddr.
		if len(h.group.Nodes()) > 0 {
			node, err = h.group.Next()
			if err != nil {
				resp.Status = relay.StatusServiceUnavailable
				resp.WriteTo(conn)
				log.Logf("[relay] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err)
				return
			}
			raddr = node.Addr
		}

		log.Logf("[relay] %s -> %s -> %s", conn.RemoteAddr(), conn.LocalAddr(), raddr)
		cc, err = h.options.Chain.DialContext(ctx,
			network, raddr,
			RetryChainOption(h.options.Retries),
			TimeoutChainOption(h.options.Timeout),
		)
		if err != nil {
			log.Logf("[relay] %s -> %s : %s", conn.RemoteAddr(), raddr, err)
			// node is the zero Node when no forward nodes are configured.
			// NOTE(review): relies on MarkDead tolerating that — confirm.
			node.MarkDead()
		} else {
			break
		}
	}
	// err is the result of the last dial attempt; non-nil means all attempts
	// failed.
	if err != nil {
		resp.Status = relay.StatusServiceUnavailable
		resp.WriteTo(conn)
		return
	}

	node.ResetDead()
	defer cc.Close()

	// Wrap the client connection so UDP framing is applied, and stage the OK
	// response in the wrapper's write buffer: it is sent together with the
	// first payload written back to the client rather than on its own.
	sc := &relayConn{
		Conn:     conn,
		isServer: true,
		udp:      udp,
	}
	resp.WriteTo(&sc.wbuf)
	conn = sc

	log.Logf("[relay] %s <-> %s", conn.RemoteAddr(), raddr)
	transport(conn, cc)
	log.Logf("[relay] %s >-< %s", conn.RemoteAddr(), raddr)
}
|
|
|
|
type relayConn struct {
|
|
net.Conn
|
|
isServer bool
|
|
udp bool
|
|
wbuf bytes.Buffer
|
|
once sync.Once
|
|
headerSent bool
|
|
}
|
|
|
|
func (c *relayConn) Read(b []byte) (n int, err error) {
|
|
c.once.Do(func() {
|
|
if c.isServer {
|
|
return
|
|
}
|
|
resp := new(relay.Response)
|
|
_, err = resp.ReadFrom(c.Conn)
|
|
if err != nil {
|
|
return
|
|
}
|
|
if resp.Version != relay.Version1 {
|
|
err = relay.ErrBadVersion
|
|
return
|
|
}
|
|
if resp.Status != relay.StatusOK {
|
|
err = fmt.Errorf("status %d", resp.Status)
|
|
return
|
|
}
|
|
})
|
|
|
|
if err != nil {
|
|
log.Logf("[relay] %s <- %s: %s", c.Conn.LocalAddr(), c.Conn.RemoteAddr(), err)
|
|
return
|
|
}
|
|
|
|
if !c.udp {
|
|
return c.Conn.Read(b)
|
|
}
|
|
var bb [2]byte
|
|
_, err = io.ReadFull(c.Conn, bb[:])
|
|
if err != nil {
|
|
return
|
|
}
|
|
dlen := int(binary.BigEndian.Uint16(bb[:]))
|
|
if len(b) >= dlen {
|
|
return io.ReadFull(c.Conn, b[:dlen])
|
|
}
|
|
buf := make([]byte, dlen)
|
|
_, err = io.ReadFull(c.Conn, buf)
|
|
n = copy(b, buf)
|
|
return
|
|
}
|
|
|
|
func (c *relayConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
|
n, err = c.Read(b)
|
|
addr = c.Conn.RemoteAddr()
|
|
return
|
|
}
|
|
|
|
// Write sends b over the relay connection.
//
// Payloads larger than 0xFFFF bytes are rejected, because the UDP frame length
// is a 16-bit value. On the paths that do not delegate directly to the
// underlying connection, the returned n is len(b) regardless of how many bytes
// the underlying write reported.
//
// Behavior by state:
//   - A header is pending in wbuf: the payload (length-prefixed in UDP mode)
//     is appended to it and both are flushed in a single write.
//   - TCP mode, nothing pending: b is written as-is.
//   - UDP mode, nothing pending, first write: b is written without a length
//     prefix and followed by two zero bytes (see note in the body).
//   - UDP mode otherwise: b is written as one frame, a 2-byte big-endian
//     length followed by the data.
func (c *relayConn) Write(b []byte) (n int, err error) {
	if len(b) > 0xFFFF {
		err = errors.New("write: data maximum exceeded")
		return
	}
	n = len(b) // force byte length consistent
	if c.wbuf.Len() > 0 {
		if c.udp {
			// Frame the payload: 2-byte big-endian length goes after the
			// buffered header and before the data.
			var bb [2]byte
			binary.BigEndian.PutUint16(bb[:2], uint16(len(b)))
			c.wbuf.Write(bb[:])
			c.headerSent = true
		}
		c.wbuf.Write(b) // append the data to the cached header
		// _, err = c.Conn.Write(c.wbuf.Bytes())
		// c.wbuf.Reset()
		// WriteTo drains wbuf, so later writes skip this branch.
		_, err = c.wbuf.WriteTo(c.Conn)
		return
	}

	if !c.udp {
		return c.Conn.Write(b)
	}
	if !c.headerSent {
		// NOTE(review): this branch is reached in UDP mode when nothing was
		// buffered, which is the case when ConnectContext writes the request
		// header through this method (NoDelay). Presumably b is that header,
		// hence no length prefix, and the two trailing zero bytes read as an
		// empty frame on the peer — confirm the intent.
		c.headerSent = true
		b2 := make([]byte, len(b)+2)
		copy(b2, b)
		_, err = c.Conn.Write(b2)
		return
	}
	nsize := 2 + len(b)
	var buf []byte
	// Use a pooled buffer when the frame fits; assumes buffers from mPool are
	// at least mediumBufferSize long — TODO confirm against the pool's
	// definition.
	if nsize <= mediumBufferSize {
		buf = mPool.Get().([]byte)
		defer mPool.Put(buf)
	} else {
		buf = make([]byte, nsize)
	}
	binary.BigEndian.PutUint16(buf[:2], uint16(len(b)))
	n = copy(buf[2:], b)
	_, err = c.Conn.Write(buf[:nsize])
	return
}
|
|
|
|
// WriteTo writes b via Write. The addr argument is ignored.
func (c *relayConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
	return c.Write(b)
}
|