init
This commit is contained in:
369
relay.go
Normal file
369
relay.go
Normal file
@ -0,0 +1,369 @@
|
||||
package gost
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/relay"
|
||||
"github.com/go-log/log"
|
||||
)
|
||||
|
||||
type relayConnector struct {
|
||||
user *url.Userinfo
|
||||
remoteAddr string
|
||||
}
|
||||
|
||||
// RelayConnector creates a Connector for TCP/UDP data relay.
|
||||
func RelayConnector(user *url.Userinfo) Connector {
|
||||
return &relayConnector{
|
||||
user: user,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *relayConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *relayConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
|
||||
opts := &ConnectOptions{}
|
||||
for _, option := range options {
|
||||
option(opts)
|
||||
}
|
||||
|
||||
timeout := opts.Timeout
|
||||
if timeout <= 0 {
|
||||
timeout = ConnectTimeout
|
||||
}
|
||||
|
||||
conn.SetDeadline(time.Now().Add(timeout))
|
||||
defer conn.SetDeadline(time.Time{})
|
||||
|
||||
var udp bool
|
||||
if network == "udp" || network == "udp4" || network == "udp6" {
|
||||
udp = true
|
||||
}
|
||||
|
||||
req := &relay.Request{
|
||||
Version: relay.Version1,
|
||||
}
|
||||
if udp {
|
||||
req.Flags |= relay.FUDP
|
||||
}
|
||||
|
||||
if c.user != nil {
|
||||
pwd, _ := c.user.Password()
|
||||
req.Features = append(req.Features, &relay.UserAuthFeature{
|
||||
Username: c.user.Username(),
|
||||
Password: pwd,
|
||||
})
|
||||
}
|
||||
if address != "" {
|
||||
host, port, _ := net.SplitHostPort(address)
|
||||
nport, _ := strconv.ParseUint(port, 10, 16)
|
||||
if host == "" {
|
||||
host = net.IPv4zero.String()
|
||||
}
|
||||
|
||||
if nport > 0 {
|
||||
var atype uint8
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
atype = relay.AddrDomain
|
||||
} else if ip.To4() == nil {
|
||||
atype = relay.AddrIPv6
|
||||
} else {
|
||||
atype = relay.AddrIPv4
|
||||
}
|
||||
|
||||
req.Features = append(req.Features, &relay.AddrFeature{
|
||||
AType: atype,
|
||||
Host: host,
|
||||
Port: uint16(nport),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
rc := &relayConn{
|
||||
udp: udp,
|
||||
Conn: conn,
|
||||
}
|
||||
|
||||
// write the header at once.
|
||||
if opts.NoDelay {
|
||||
if _, err := req.WriteTo(rc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if _, err := req.WriteTo(&rc.wbuf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return rc, nil
|
||||
}
|
||||
|
||||
type relayHandler struct {
|
||||
*baseForwardHandler
|
||||
}
|
||||
|
||||
// RelayHandler creates a server Handler for TCP/UDP relay server.
|
||||
func RelayHandler(raddr string, opts ...HandlerOption) Handler {
|
||||
h := &relayHandler{
|
||||
baseForwardHandler: &baseForwardHandler{
|
||||
raddr: raddr,
|
||||
group: NewNodeGroup(),
|
||||
options: &HandlerOptions{},
|
||||
},
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(h.options)
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *relayHandler) Init(options ...HandlerOption) {
|
||||
h.baseForwardHandler.Init(options...)
|
||||
}
|
||||
|
||||
func (h *relayHandler) Handle(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
req := &relay.Request{}
|
||||
if _, err := req.ReadFrom(conn); err != nil {
|
||||
log.Logf("[relay] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Version != relay.Version1 {
|
||||
log.Logf("[relay] %s - %s : bad version", conn.RemoteAddr(), conn.LocalAddr())
|
||||
return
|
||||
}
|
||||
|
||||
var user, pass string
|
||||
var raddr string
|
||||
for _, f := range req.Features {
|
||||
if f.Type() == relay.FeatureUserAuth {
|
||||
feature := f.(*relay.UserAuthFeature)
|
||||
user, pass = feature.Username, feature.Password
|
||||
}
|
||||
if f.Type() == relay.FeatureAddr {
|
||||
feature := f.(*relay.AddrFeature)
|
||||
raddr = net.JoinHostPort(feature.Host, strconv.Itoa(int(feature.Port)))
|
||||
}
|
||||
}
|
||||
|
||||
resp := &relay.Response{
|
||||
Version: relay.Version1,
|
||||
Status: relay.StatusOK,
|
||||
}
|
||||
if h.options.Authenticator != nil && !h.options.Authenticator.Authenticate(user, pass) {
|
||||
resp.Status = relay.StatusUnauthorized
|
||||
resp.WriteTo(conn)
|
||||
log.Logf("[relay] %s -> %s : %s unauthorized", conn.RemoteAddr(), conn.LocalAddr(), user)
|
||||
return
|
||||
}
|
||||
|
||||
if raddr != "" {
|
||||
if len(h.group.Nodes()) > 0 {
|
||||
resp.Status = relay.StatusForbidden
|
||||
resp.WriteTo(conn)
|
||||
log.Logf("[relay] %s -> %s : relay to %s is forbidden",
|
||||
conn.RemoteAddr(), conn.LocalAddr(), raddr)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if len(h.group.Nodes()) == 0 {
|
||||
resp.Status = relay.StatusBadRequest
|
||||
resp.WriteTo(conn)
|
||||
log.Logf("[relay] %s -> %s : bad request, target addr is needed",
|
||||
conn.RemoteAddr(), conn.LocalAddr())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
udp := (req.Flags & relay.FUDP) == relay.FUDP
|
||||
retries := 1
|
||||
if h.options.Chain != nil && h.options.Chain.Retries > 0 {
|
||||
retries = h.options.Chain.Retries
|
||||
}
|
||||
if h.options.Retries > 0 {
|
||||
retries = h.options.Retries
|
||||
}
|
||||
|
||||
network := "tcp"
|
||||
if udp {
|
||||
network = "udp"
|
||||
}
|
||||
if !Can(network, raddr, h.options.Whitelist, h.options.Blacklist) {
|
||||
resp.Status = relay.StatusForbidden
|
||||
resp.WriteTo(conn)
|
||||
log.Logf("[relay] %s -> %s : relay to %s is forbidden",
|
||||
conn.RemoteAddr(), conn.LocalAddr(), raddr)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.TODO()
|
||||
var cc net.Conn
|
||||
var node Node
|
||||
var err error
|
||||
for i := 0; i < retries; i++ {
|
||||
if len(h.group.Nodes()) > 0 {
|
||||
node, err = h.group.Next()
|
||||
if err != nil {
|
||||
resp.Status = relay.StatusServiceUnavailable
|
||||
resp.WriteTo(conn)
|
||||
log.Logf("[relay] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err)
|
||||
return
|
||||
}
|
||||
raddr = node.Addr
|
||||
}
|
||||
|
||||
log.Logf("[relay] %s -> %s -> %s", conn.RemoteAddr(), conn.LocalAddr(), raddr)
|
||||
cc, err = h.options.Chain.DialContext(ctx,
|
||||
network, raddr,
|
||||
RetryChainOption(h.options.Retries),
|
||||
TimeoutChainOption(h.options.Timeout),
|
||||
)
|
||||
if err != nil {
|
||||
log.Logf("[relay] %s -> %s : %s", conn.RemoteAddr(), raddr, err)
|
||||
node.MarkDead()
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
resp.Status = relay.StatusServiceUnavailable
|
||||
resp.WriteTo(conn)
|
||||
return
|
||||
}
|
||||
|
||||
node.ResetDead()
|
||||
defer cc.Close()
|
||||
|
||||
sc := &relayConn{
|
||||
Conn: conn,
|
||||
isServer: true,
|
||||
udp: udp,
|
||||
}
|
||||
resp.WriteTo(&sc.wbuf)
|
||||
conn = sc
|
||||
|
||||
log.Logf("[relay] %s <-> %s", conn.RemoteAddr(), raddr)
|
||||
transport(conn, cc)
|
||||
log.Logf("[relay] %s >-< %s", conn.RemoteAddr(), raddr)
|
||||
}
|
||||
|
||||
type relayConn struct {
|
||||
net.Conn
|
||||
isServer bool
|
||||
udp bool
|
||||
wbuf bytes.Buffer
|
||||
once sync.Once
|
||||
headerSent bool
|
||||
}
|
||||
|
||||
func (c *relayConn) Read(b []byte) (n int, err error) {
|
||||
c.once.Do(func() {
|
||||
if c.isServer {
|
||||
return
|
||||
}
|
||||
resp := new(relay.Response)
|
||||
_, err = resp.ReadFrom(c.Conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if resp.Version != relay.Version1 {
|
||||
err = relay.ErrBadVersion
|
||||
return
|
||||
}
|
||||
if resp.Status != relay.StatusOK {
|
||||
err = fmt.Errorf("status %d", resp.Status)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Logf("[relay] %s <- %s: %s", c.Conn.LocalAddr(), c.Conn.RemoteAddr(), err)
|
||||
return
|
||||
}
|
||||
|
||||
if !c.udp {
|
||||
return c.Conn.Read(b)
|
||||
}
|
||||
var bb [2]byte
|
||||
_, err = io.ReadFull(c.Conn, bb[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
dlen := int(binary.BigEndian.Uint16(bb[:]))
|
||||
if len(b) >= dlen {
|
||||
return io.ReadFull(c.Conn, b[:dlen])
|
||||
}
|
||||
buf := make([]byte, dlen)
|
||||
_, err = io.ReadFull(c.Conn, buf)
|
||||
n = copy(b, buf)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *relayConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||
n, err = c.Read(b)
|
||||
addr = c.Conn.RemoteAddr()
|
||||
return
|
||||
}
|
||||
|
||||
func (c *relayConn) Write(b []byte) (n int, err error) {
|
||||
if len(b) > 0xFFFF {
|
||||
err = errors.New("write: data maximum exceeded")
|
||||
return
|
||||
}
|
||||
n = len(b) // force byte length consistent
|
||||
if c.wbuf.Len() > 0 {
|
||||
if c.udp {
|
||||
var bb [2]byte
|
||||
binary.BigEndian.PutUint16(bb[:2], uint16(len(b)))
|
||||
c.wbuf.Write(bb[:])
|
||||
c.headerSent = true
|
||||
}
|
||||
c.wbuf.Write(b) // append the data to the cached header
|
||||
// _, err = c.Conn.Write(c.wbuf.Bytes())
|
||||
// c.wbuf.Reset()
|
||||
_, err = c.wbuf.WriteTo(c.Conn)
|
||||
return
|
||||
}
|
||||
|
||||
if !c.udp {
|
||||
return c.Conn.Write(b)
|
||||
}
|
||||
if !c.headerSent {
|
||||
c.headerSent = true
|
||||
b2 := make([]byte, len(b)+2)
|
||||
copy(b2, b)
|
||||
_, err = c.Conn.Write(b2)
|
||||
return
|
||||
}
|
||||
nsize := 2 + len(b)
|
||||
var buf []byte
|
||||
if nsize <= mediumBufferSize {
|
||||
buf = mPool.Get().([]byte)
|
||||
defer mPool.Put(buf)
|
||||
} else {
|
||||
buf = make([]byte, nsize)
|
||||
}
|
||||
binary.BigEndian.PutUint16(buf[:2], uint16(len(b)))
|
||||
n = copy(buf[2:], b)
|
||||
_, err = c.Conn.Write(buf[:nsize])
|
||||
return
|
||||
}
|
||||
|
||||
func (c *relayConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
return c.Write(b)
|
||||
}
|
Reference in New Issue
Block a user