x/internal/util/icmp/conn.go
2024-07-15 20:34:59 +08:00

314 lines
6.0 KiB
Go

package icmp
import (
"encoding/binary"
"errors"
"fmt"
"math"
"net"
"sync/atomic"
"github.com/go-gost/core/common/bufpool"
"github.com/go-gost/core/logger"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
// IP protocol numbers handed to icmp.ParseMessage to select which ICMP
// flavour an incoming packet is parsed as.
const (
	// ICMPv4 is the protocol number of ICMP for IPv4.
	ICMPv4 = 1
	// ICMPv6 is the protocol number of ICMP for IPv6.
	ICMPv6 = 58
)
const (
	// readBufferSize is the size in bytes of the pooled scratch buffer that
	// receives one raw ICMP packet in ReadFrom.
	readBufferSize = 4096
	// writeBufferSize is the size in bytes of the pooled scratch buffer that
	// an outgoing message (header plus payload) is encoded into in WriteTo.
	writeBufferSize = 4096
	// magicNumber opens every encoded message; Decode rejects payloads that
	// do not start with it. The four bytes are the ASCII letters "GOST".
	magicNumber = 0x474F5354
)
// messageHeaderLen is the fixed size in bytes of the header that precedes the
// payload of an encoded message: magic (4) + flags (2) + reserved (2) +
// payload length (2).
const messageHeaderLen = 10
// FlagAck is the flag bit carried by messages sent from the server side.
// serverConn.WriteTo sets it, clientConn.ReadFrom discards messages without
// it, and serverConn.ReadFrom discards messages that have it.
const FlagAck = 1
// Sentinel errors reported by this package.
var (
	// ErrInvalidPacket is returned by message.Decode when the input does not
	// start with the expected magic number.
	ErrInvalidPacket = errors.New("icmp: invalid packet")
	// ErrInvalidType is not referenced in the visible part of this file; it
	// is presumably meant for unexpected ICMP message types.
	ErrInvalidType = errors.New("icmp: invalid type")
	// ErrShortBuffer is returned when a buffer is too small: by
	// message.Encode for the destination, and by message.Decode when the
	// input is shorter than the header or than the declared payload.
	ErrShortBuffer = errors.New("icmp: short buffer")
)
// message is the unit carried in the Data field of an ICMP echo request or
// reply.
//
// Wire layout, big-endian, as written by Encode:
//
//	offset 0, 4 bytes: magic number
//	offset 4, 2 bytes: flags
//	offset 6, 2 bytes: reserved, written as zero
//	offset 8, 2 bytes: length of data
//	offset 10:         data
//
// Only flags and data are kept in memory; the magic number, the reserved
// field and the length are produced or consumed by Encode and Decode.
type message struct {
	flags uint16 // flag bits, see FlagAck
	data  []byte // payload bytes
}
// Encode serializes m into b as header followed by payload and returns the
// number of bytes written.
//
// The header is written big-endian: magic number, flags, a zeroed reserved
// field and the payload length.
//
// It returns ErrShortBuffer when b cannot hold the header plus m.data, and an
// error wrapping ErrInvalidPacket when m.data is longer than the 16-bit
// length field can express.
func (m *message) Encode(b []byte) (n int, err error) {
	size := len(m.data)
	total := messageHeaderLen + size
	if len(b) < total {
		return 0, ErrShortBuffer
	}
	// The length field is only 16 bits wide. Writing a larger length would
	// silently truncate it and produce a frame whose declared and actual
	// payload sizes disagree.
	if size > math.MaxUint16 {
		return 0, fmt.Errorf("%w: payload of %d bytes exceeds the 16-bit length field", ErrInvalidPacket, size)
	}

	binary.BigEndian.PutUint32(b[:4], magicNumber)     // magic number
	binary.BigEndian.PutUint16(b[4:6], m.flags)        // flags
	binary.BigEndian.PutUint16(b[6:8], 0)              // reserved
	binary.BigEndian.PutUint16(b[8:10], uint16(size)) // length of data
	copy(b[messageHeaderLen:], m.data)
	return total, nil
}
// Decode parses one encoded message from the front of b and returns the
// number of bytes consumed (header plus payload).
//
// On success m.flags is set from the header and m.data is a sub-slice of b,
// not a copy, so it is only valid while b is. On failure m is left untouched.
//
// It returns ErrShortBuffer when b is shorter than the header or than the
// payload length the header declares, and ErrInvalidPacket when b does not
// start with the magic number.
func (m *message) Decode(b []byte) (n int, err error) {
	if len(b) < messageHeaderLen {
		return 0, ErrShortBuffer
	}
	if binary.BigEndian.Uint32(b[:4]) != magicNumber {
		return 0, ErrInvalidPacket
	}

	flags := binary.BigEndian.Uint16(b[4:6])
	// Widen to int before adding the header length: in uint16 arithmetic the
	// sum wraps for lengths close to 65535 and yields an inverted slice range.
	size := int(binary.BigEndian.Uint16(b[8:10]))
	end := messageHeaderLen + size
	if len(b) < end {
		return 0, ErrShortBuffer
	}

	m.flags = flags
	m.data = b[messageHeaderLen:end]
	return end, nil
}
// clientConn wraps an ICMP net.PacketConn for the client side: payloads are
// sent as echo requests and received from echo replies.
type clientConn struct {
	// ip6 selects ICMPv6 message types and parsing instead of ICMPv4.
	ip6 bool
	// PacketConn is the underlying connection; methods not overridden here
	// are promoted from it.
	net.PacketConn
	// id is the ICMP echo identifier put on outgoing requests and required
	// on incoming replies.
	id int
	// seq is the echo sequence counter, incremented atomically per write.
	seq uint32
}
// ClientConn returns a net.PacketConn that tunnels payloads through ICMP echo
// requests written to conn and reads them back from echo replies.
//
// ip6 selects ICMPv6 instead of ICMPv4. id is the ICMP echo identifier used on
// outgoing requests and matched against incoming replies.
func ClientConn(ip6 bool, conn net.PacketConn, id int) net.PacketConn {
	c := new(clientConn)
	c.ip6 = ip6
	c.PacketConn = conn
	c.id = id
	return c
}
// ReadFrom blocks until an acknowledged message addressed to this client
// arrives, copies its payload into b and returns the number of bytes copied.
//
// Packets are silently skipped when they are not echo replies, carry a
// different echo identifier, or lack FlagAck; packets whose payload fails to
// decode are skipped after logging a warning. An error from the underlying
// read or from ICMP parsing ends the call.
//
// A *net.IPAddr source address is reported as a *net.UDPAddr whose Port is
// the client's echo identifier. If b is smaller than the payload, the excess
// is dropped.
func (c *clientConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
	proto := ICMPv4
	if c.ip6 {
		proto = ICMPv6
	}

	pkt := bufpool.Get(readBufferSize)
	defer bufpool.Put(pkt)

	for {
		n, addr, err = c.PacketConn.ReadFrom(pkt)
		if err != nil {
			return n, addr, err
		}

		parsed, perr := icmp.ParseMessage(proto, pkt[:n])
		if perr != nil {
			return 0, addr, perr
		}

		// Only echo replies are of interest on the client side.
		echo, isEcho := parsed.Body.(*icmp.Echo)
		if !isEcho {
			continue
		}
		if parsed.Type != ipv4.ICMPTypeEchoReply && parsed.Type != ipv6.ICMPTypeEchoReply {
			continue
		}
		// Replies for another echo identifier are not for this connection.
		if echo.ID != c.id {
			continue
		}

		var msg message
		if _, derr := msg.Decode(echo.Data); derr != nil {
			logger.Default().Warn(derr)
			continue
		}
		// Messages without the ack flag did not come from the server side.
		if msg.flags&FlagAck == 0 {
			continue
		}

		// msg.data aliases pkt, so it must be copied out before pkt goes
		// back to the pool.
		n = copy(b, msg.data)
		if ip, ok := addr.(*net.IPAddr); ok {
			addr = &net.UDPAddr{
				IP:   ip.IP,
				Port: c.id,
				Zone: ip.Zone,
			}
		}
		return n, addr, nil
	}
}
// WriteTo wraps b in a message, sends it to addr as an ICMP echo request and
// returns len(b) on success.
//
// A *net.UDPAddr is converted to a *net.IPAddr before sending; its port is
// ignored. Each call increments the connection's sequence counter.
//
// On any failure (payload too large for the write buffer, marshalling, or the
// underlying write) it returns 0 and the error.
func (c *clientConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
	if udp, ok := addr.(*net.UDPAddr); ok {
		addr = &net.IPAddr{IP: udp.IP, Zone: udp.Zone}
	}

	buf := bufpool.Get(writeBufferSize)
	defer bufpool.Put(buf)

	msg := message{data: b}
	size, err := msg.Encode(buf)
	if err != nil {
		return 0, err
	}

	m := icmp.Message{
		Type: ipv4.ICMPTypeEcho,
		Code: 0,
		Body: &icmp.Echo{
			ID:   c.id,
			Seq:  int(atomic.AddUint32(&c.seq, 1)),
			Data: buf[:size],
		},
	}
	if c.ip6 {
		m.Type = ipv6.ICMPTypeEchoRequest
	}

	wb, err := m.Marshal(nil)
	if err != nil {
		return 0, err
	}

	// Report the payload as written only when the packet actually went out.
	if _, err = c.PacketConn.WriteTo(wb, addr); err != nil {
		return 0, err
	}
	return len(b), nil
}
// serverConn wraps an ICMP net.PacketConn for the server side: payloads are
// received from echo requests and sent back as echo replies.
type serverConn struct {
	// ip6 selects ICMPv6 message types and parsing instead of ICMPv4.
	ip6 bool
	// PacketConn is the underlying connection; methods not overridden here
	// are promoted from it.
	net.PacketConn
	// seqs holds, at index id-1, the sequence number of the latest echo
	// request seen for echo identifier id. Entries are accessed atomically.
	seqs [65535]uint32
}
// ServerConn returns a net.PacketConn that reads tunnelled payloads from ICMP
// echo requests arriving on conn and writes payloads back as echo replies.
//
// ip6 selects ICMPv6 instead of ICMPv4.
func ServerConn(ip6 bool, conn net.PacketConn) net.PacketConn {
	s := new(serverConn)
	s.ip6 = ip6
	s.PacketConn = conn
	return s
}
// ReadFrom blocks until a client message arrives, copies its payload into b
// and returns the number of bytes copied.
//
// Packets are silently skipped when they are not echo requests, have a
// non-positive echo identifier, fail to decode, or carry FlagAck. An error
// from the underlying read or from ICMP parsing ends the call.
//
// The sequence number of every accepted echo request is recorded per echo
// identifier, before the payload is decoded, so that WriteTo can echo it
// back. A *net.IPAddr source address is reported as a *net.UDPAddr whose Port
// is the request's echo identifier. If b is smaller than the payload, the
// excess is dropped.
func (c *serverConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
	proto := ICMPv4
	if c.ip6 {
		proto = ICMPv6
	}

	pkt := bufpool.Get(readBufferSize)
	defer bufpool.Put(pkt)

	for {
		n, addr, err = c.PacketConn.ReadFrom(pkt)
		if err != nil {
			return n, addr, err
		}

		parsed, perr := icmp.ParseMessage(proto, pkt[:n])
		if perr != nil {
			return 0, addr, perr
		}

		// Only echo requests with a usable identifier are of interest.
		echo, isEcho := parsed.Body.(*icmp.Echo)
		if !isEcho || echo.ID <= 0 {
			continue
		}
		if parsed.Type != ipv4.ICMPTypeEcho && parsed.Type != ipv6.ICMPTypeEchoRequest {
			continue
		}

		// Remember the request's sequence number for the reply path.
		atomic.StoreUint32(&c.seqs[uint16(echo.ID-1)], uint32(echo.Seq))

		var msg message
		if _, derr := msg.Decode(echo.Data); derr != nil {
			continue
		}
		// Messages with the ack flag are server-originated; ignore them.
		if msg.flags&FlagAck != 0 {
			continue
		}

		// msg.data aliases pkt, so it must be copied out before pkt goes
		// back to the pool.
		n = copy(b, msg.data)
		if ip, ok := addr.(*net.IPAddr); ok {
			addr = &net.UDPAddr{
				IP:   ip.IP,
				Port: echo.ID,
				Zone: ip.Zone,
			}
		}
		return n, addr, nil
	}
}
// WriteTo wraps b in a message flagged with FlagAck, sends it to addr as an
// ICMP echo reply and returns len(b) on success.
//
// addr must be a *net.UDPAddr whose Port is the peer's echo identifier in the
// range 1..65535; any other address is rejected with an error. The reply
// carries the sequence number last recorded for that identifier by ReadFrom.
//
// On any failure (invalid identifier, payload too large for the write buffer,
// marshalling, or the underlying write) it returns 0 and the error.
func (c *serverConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
	var id int
	if udp, ok := addr.(*net.UDPAddr); ok {
		addr = &net.IPAddr{IP: udp.IP, Zone: udp.Zone}
		id = udp.Port
	}
	// The identifier also indexes c.seqs as id-1, so it must stay in range.
	if id <= 0 || id > math.MaxUint16 {
		return 0, fmt.Errorf("icmp: invalid message id %v", addr)
	}

	buf := bufpool.Get(writeBufferSize)
	defer bufpool.Put(buf)

	msg := message{
		flags: FlagAck,
		data:  b,
	}
	size, err := msg.Encode(buf)
	if err != nil {
		return 0, err
	}

	m := icmp.Message{
		Type: ipv4.ICMPTypeEchoReply,
		Code: 0,
		Body: &icmp.Echo{
			ID:   id,
			Seq:  int(atomic.LoadUint32(&c.seqs[id-1])),
			Data: buf[:size],
		},
	}
	if c.ip6 {
		m.Type = ipv6.ICMPTypeEchoReply
	}

	wb, err := m.Marshal(nil)
	if err != nil {
		return 0, err
	}

	// Report the payload as written only when the packet actually went out.
	if _, err = c.PacketConn.WriteTo(wb, addr); err != nil {
		return 0, err
	}
	return len(b), nil
}