99 lines
2.0 KiB
Go
99 lines
2.0 KiB
Go
package net
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"runtime"
|
|
"strings"
|
|
"syscall"
|
|
|
|
"github.com/vishvananda/netns"
|
|
)
|
|
|
|
// SetBuffer is implemented by connections whose read and write buffer
// sizes can be adjusted.
type SetBuffer interface {
	// SetReadBuffer requests a receive buffer of the given size in bytes.
	SetReadBuffer(bytes int) error

	// SetWriteBuffer requests a transmit buffer of the given size in bytes.
	SetWriteBuffer(bytes int) error
}
|
|
|
|
// SyscallConn is implemented by values that can expose a raw connection
// for low-level access.
type SyscallConn interface {
	// SyscallConn returns the raw connection, or an error if it is not
	// available.
	SyscallConn() (syscall.RawConn, error)
}
|
|
|
|
// RemoteAddr is implemented by values that can report the network
// address of their peer.
type RemoteAddr interface {
	// RemoteAddr returns the address of the remote end.
	RemoteAddr() net.Addr
}
|
|
|
|
// SetDSCP is implemented by connections that accept a DSCP value.
// The original note on this declaration names tcpraw.TCPConn
// (presumably an implementer; that type is not visible in this file).
type SetDSCP interface {
	// SetDSCP applies the given DSCP value and reports any failure.
	SetDSCP(int) error
}
|
|
|
|
// IsIPv4 reports whether address is non-empty and starts with neither
// ':' nor '['. Only the first byte is inspected; the address is not
// parsed or otherwise validated.
func IsIPv4(address string) bool {
	if address == "" {
		return false
	}
	switch address[0] {
	case ':', '[':
		return false
	}
	return true
}
|
|
|
|
// ListenConfig extends net.ListenConfig with an optional network
// namespace in which listeners are created.
type ListenConfig struct {
	// Netns selects the network namespace to listen in. An empty value
	// means no namespace switch is performed. A value starting with "/"
	// is treated as a path; anything else is treated as a namespace name.
	Netns string

	// ListenConfig supplies the underlying listen options; its fields
	// and methods are promoted.
	net.ListenConfig
}
|
|
|
|
func (lc *ListenConfig) Listen(ctx context.Context, network, address string) (net.Listener, error) {
|
|
if lc.Netns != "" {
|
|
runtime.LockOSThread()
|
|
defer runtime.UnlockOSThread()
|
|
|
|
originNs, err := netns.Get()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("netns.Get(): %v", err)
|
|
}
|
|
defer netns.Set(originNs)
|
|
|
|
var ns netns.NsHandle
|
|
if strings.HasPrefix(lc.Netns, "/") {
|
|
ns, err = netns.GetFromPath(lc.Netns)
|
|
} else {
|
|
ns, err = netns.GetFromName(lc.Netns)
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("netns.Get(%s): %v", lc.Netns, err)
|
|
}
|
|
defer ns.Close()
|
|
|
|
if err := netns.Set(ns); err != nil {
|
|
return nil, fmt.Errorf("netns.Set(%s): %v", lc.Netns, err)
|
|
}
|
|
}
|
|
|
|
return lc.ListenConfig.Listen(ctx, network, address)
|
|
}
|
|
|
|
func (lc *ListenConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
|
|
if lc.Netns != "" {
|
|
runtime.LockOSThread()
|
|
defer runtime.UnlockOSThread()
|
|
|
|
originNs, err := netns.Get()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("netns.Get(): %v", err)
|
|
}
|
|
defer netns.Set(originNs)
|
|
|
|
var ns netns.NsHandle
|
|
if strings.HasPrefix(lc.Netns, "/") {
|
|
ns, err = netns.GetFromPath(lc.Netns)
|
|
} else {
|
|
ns, err = netns.GetFromName(lc.Netns)
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("netns.Get(%s): %v", lc.Netns, err)
|
|
}
|
|
defer ns.Close()
|
|
|
|
if err := netns.Set(ns); err != nil {
|
|
return nil, fmt.Errorf("netns.Set(%s): %v", lc.Netns, err)
|
|
}
|
|
}
|
|
return lc.ListenConfig.ListenPacket(ctx, network, address)
|
|
}
|