add sshd listener
This commit is contained in:
parent
a134026e76
commit
04dfc8c4c3
@ -274,6 +274,13 @@ func buildServiceConfig(url *url.URL) (*config.ServiceConfig, error) {
|
||||
Metadata: md,
|
||||
}
|
||||
|
||||
if svc.Handler.Type == "sshd" {
|
||||
svc.Handler.Auths = nil
|
||||
}
|
||||
if svc.Listener.Type == "sshd" {
|
||||
svc.Listener.Auths = auths
|
||||
}
|
||||
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
@ -354,6 +361,13 @@ func buildNodeConfig(url *url.URL) (*config.NodeConfig, error) {
|
||||
Metadata: md,
|
||||
}
|
||||
|
||||
if node.Connector.Type == "sshd" {
|
||||
node.Connector.Auth = nil
|
||||
}
|
||||
if node.Dialer.Type == "sshd" {
|
||||
node.Dialer.Auth = auth
|
||||
}
|
||||
|
||||
return node, nil
|
||||
}
|
||||
|
||||
|
@ -91,6 +91,7 @@ func buildService(cfg *config.Config) (services []*service.Service) {
|
||||
|
||||
ln := registry.GetListener(svc.Listener.Type)(
|
||||
listener.AddrOption(svc.Addr),
|
||||
listener.ChainOption(chains[svc.Listener.Chain]),
|
||||
listener.AuthsOption(authsFromConfig(svc.Listener.Auths...)...),
|
||||
listener.TLSConfigOption(tlsConfig),
|
||||
listener.LoggerOption(listenerLogger),
|
||||
|
@ -37,7 +37,6 @@ import (
|
||||
_ "github.com/go-gost/gost/pkg/handler/dns"
|
||||
_ "github.com/go-gost/gost/pkg/handler/forward/local"
|
||||
_ "github.com/go-gost/gost/pkg/handler/forward/remote"
|
||||
_ "github.com/go-gost/gost/pkg/handler/forward/ssh"
|
||||
_ "github.com/go-gost/gost/pkg/handler/http"
|
||||
_ "github.com/go-gost/gost/pkg/handler/http2"
|
||||
_ "github.com/go-gost/gost/pkg/handler/redirect"
|
||||
@ -47,6 +46,7 @@ import (
|
||||
_ "github.com/go-gost/gost/pkg/handler/socks/v5"
|
||||
_ "github.com/go-gost/gost/pkg/handler/ss"
|
||||
_ "github.com/go-gost/gost/pkg/handler/ss/udp"
|
||||
_ "github.com/go-gost/gost/pkg/handler/sshd"
|
||||
_ "github.com/go-gost/gost/pkg/handler/tap"
|
||||
_ "github.com/go-gost/gost/pkg/handler/tun"
|
||||
|
||||
@ -65,6 +65,7 @@ import (
|
||||
_ "github.com/go-gost/gost/pkg/listener/rtcp"
|
||||
_ "github.com/go-gost/gost/pkg/listener/rudp"
|
||||
_ "github.com/go-gost/gost/pkg/listener/ssh"
|
||||
_ "github.com/go-gost/gost/pkg/listener/sshd"
|
||||
_ "github.com/go-gost/gost/pkg/listener/tap"
|
||||
_ "github.com/go-gost/gost/pkg/listener/tcp"
|
||||
_ "github.com/go-gost/gost/pkg/listener/tls"
|
||||
|
@ -15,6 +15,7 @@ import (
|
||||
|
||||
func init() {
|
||||
registry.RegisterDialer("http3", NewDialer)
|
||||
registry.RegisterDialer("h3", NewDialer)
|
||||
}
|
||||
|
||||
type http3Dialer struct {
|
||||
|
@ -164,7 +164,7 @@ func (d *sshDialer) dial(ctx context.Context, network, addr string, opts *dialer
|
||||
|
||||
func (d *sshDialer) initSession(ctx context.Context, addr string, conn net.Conn) (*sshSession, error) {
|
||||
config := ssh.ClientConfig{
|
||||
// Timeout: timeout,
|
||||
Timeout: 30 * time.Second,
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
}
|
||||
if d.md.user != nil {
|
||||
|
@ -9,7 +9,6 @@ import (
|
||||
"github.com/go-gost/gosocks4"
|
||||
"github.com/go-gost/gosocks5"
|
||||
"github.com/go-gost/gost/pkg/handler"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
md "github.com/go-gost/gost/pkg/metadata"
|
||||
"github.com/go-gost/gost/pkg/registry"
|
||||
"github.com/go-gost/relay"
|
||||
@ -24,42 +23,37 @@ type autoHandler struct {
|
||||
socks4Handler handler.Handler
|
||||
socks5Handler handler.Handler
|
||||
relayHandler handler.Handler
|
||||
log logger.Logger
|
||||
options handler.Options
|
||||
}
|
||||
|
||||
func NewHandler(opts ...handler.Option) handler.Handler {
|
||||
options := &handler.Options{}
|
||||
options := handler.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
log := options.Logger
|
||||
if log == nil {
|
||||
log = logger.Default()
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
h := &autoHandler{
|
||||
log: log,
|
||||
options: options,
|
||||
}
|
||||
|
||||
if f := registry.GetHandler("http"); f != nil {
|
||||
v := append(opts,
|
||||
handler.LoggerOption(log.WithFields(map[string]interface{}{"type": "http"})))
|
||||
handler.LoggerOption(options.Logger.WithFields(map[string]interface{}{"type": "http"})))
|
||||
h.httpHandler = f(v...)
|
||||
}
|
||||
if f := registry.GetHandler("socks4"); f != nil {
|
||||
v := append(opts,
|
||||
handler.LoggerOption(log.WithFields(map[string]interface{}{"type": "socks4"})))
|
||||
handler.LoggerOption(options.Logger.WithFields(map[string]interface{}{"type": "socks4"})))
|
||||
h.socks4Handler = f(v...)
|
||||
}
|
||||
if f := registry.GetHandler("socks5"); f != nil {
|
||||
v := append(opts,
|
||||
handler.LoggerOption(log.WithFields(map[string]interface{}{"type": "socks5"})))
|
||||
handler.LoggerOption(options.Logger.WithFields(map[string]interface{}{"type": "socks5"})))
|
||||
h.socks5Handler = f(v...)
|
||||
}
|
||||
if f := registry.GetHandler("relay"); f != nil {
|
||||
v := append(opts,
|
||||
handler.LoggerOption(log.WithFields(map[string]interface{}{"type": "relay"})))
|
||||
handler.LoggerOption(options.Logger.WithFields(map[string]interface{}{"type": "relay"})))
|
||||
h.relayHandler = f(v...)
|
||||
}
|
||||
|
||||
@ -92,15 +86,15 @@ func (h *autoHandler) Init(md md.Metadata) error {
|
||||
}
|
||||
|
||||
func (h *autoHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
h.log = h.log.WithFields(map[string]interface{}{
|
||||
log := h.options.Logger.WithFields(map[string]interface{}{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
start := time.Now()
|
||||
h.log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
h.log.WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
@ -108,7 +102,7 @@ func (h *autoHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
br := bufio.NewReader(conn)
|
||||
b, err := br.Peek(1)
|
||||
if err != nil {
|
||||
h.log.Error(err)
|
||||
log.Error(err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
@ -132,5 +126,4 @@ func (h *autoHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
h.httpHandler.Handle(ctx, conn)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -33,7 +33,6 @@ type dnsHandler struct {
|
||||
exchangers []exchanger.Exchanger
|
||||
cache *resolver_util.Cache
|
||||
router *chain.Router
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
@ -50,19 +49,18 @@ func NewHandler(opts ...handler.Option) handler.Handler {
|
||||
}
|
||||
|
||||
func (h *dnsHandler) Init(md md.Metadata) (err error) {
|
||||
h.logger = h.options.Logger
|
||||
|
||||
if err = h.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
log := h.options.Logger
|
||||
|
||||
h.cache = resolver_util.NewCache().WithLogger(h.options.Logger)
|
||||
h.cache = resolver_util.NewCache().WithLogger(log)
|
||||
h.router = &chain.Router{
|
||||
Retries: h.options.Retries,
|
||||
Chain: h.options.Chain,
|
||||
Resolver: h.options.Resolver,
|
||||
// Hosts: h.options.Hosts,
|
||||
Logger: h.options.Logger,
|
||||
Logger: log,
|
||||
}
|
||||
|
||||
for _, server := range h.md.dns {
|
||||
@ -74,10 +72,10 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) {
|
||||
server,
|
||||
exchanger.RouterOption(h.router),
|
||||
exchanger.TimeoutOption(h.md.timeout),
|
||||
exchanger.LoggerOption(h.logger),
|
||||
exchanger.LoggerOption(log),
|
||||
)
|
||||
if err != nil {
|
||||
h.logger.Warnf("parse %s: %v", server, err)
|
||||
log.Warnf("parse %s: %v", server, err)
|
||||
continue
|
||||
}
|
||||
h.exchangers = append(h.exchangers, ex)
|
||||
@ -87,9 +85,9 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) {
|
||||
defaultNameserver,
|
||||
exchanger.RouterOption(h.router),
|
||||
exchanger.TimeoutOption(h.md.timeout),
|
||||
exchanger.LoggerOption(h.logger),
|
||||
exchanger.LoggerOption(log),
|
||||
)
|
||||
h.logger.Warnf("resolver not found, default to %s", defaultNameserver)
|
||||
log.Warnf("resolver not found, default to %s", defaultNameserver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -103,14 +101,14 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
start := time.Now()
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log := h.options.Logger.WithFields(map[string]interface{}{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
h.logger.WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
@ -120,26 +118,25 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
|
||||
n, err := conn.Read(*b)
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
h.logger.Info("read data: ", n)
|
||||
|
||||
reply, err := h.exchange(ctx, (*b)[:n])
|
||||
reply, err := h.exchange(ctx, (*b)[:n], log)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer bufpool.Put(&reply)
|
||||
|
||||
if _, err = conn.Write(reply); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
||||
func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger) ([]byte, error) {
|
||||
mq := dns.Msg{}
|
||||
if err := mq.Unpack(msg); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -149,23 +146,23 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
||||
|
||||
resolver_util.AddSubnetOpt(&mq, h.md.clientIP)
|
||||
|
||||
if h.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
h.logger.Debug(mq.String())
|
||||
if log.IsLevelEnabled(logger.DebugLevel) {
|
||||
log.Debug(mq.String())
|
||||
} else {
|
||||
h.logger.Info(h.dumpMsgHeader(&mq))
|
||||
log.Info(h.dumpMsgHeader(&mq))
|
||||
}
|
||||
|
||||
var mr *dns.Msg
|
||||
|
||||
if h.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
if log.IsLevelEnabled(logger.DebugLevel) {
|
||||
defer func() {
|
||||
if mr != nil {
|
||||
h.logger.Debug(mr.String())
|
||||
log.Debug(mr.String())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
mr = h.lookupHosts(&mq)
|
||||
mr = h.lookupHosts(&mq, log)
|
||||
if mr != nil {
|
||||
b := bufpool.Get(4096)
|
||||
return mr.PackBuffer(*b)
|
||||
@ -176,7 +173,7 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
||||
key := resolver_util.NewCacheKey(&mq.Question[0])
|
||||
mr = h.cache.Load(key)
|
||||
if mr != nil {
|
||||
h.logger.Debugf("exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
|
||||
log.Debugf("exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
|
||||
mr.Id = mq.Id
|
||||
|
||||
b := bufpool.Get(4096)
|
||||
@ -195,18 +192,18 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
||||
|
||||
query, err := mq.PackBuffer(*b)
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var reply []byte
|
||||
for _, ex := range h.exchangers {
|
||||
h.logger.Infof("exchange message %d via %s: %s", mq.Id, ex.String(), mq.Question[0].String())
|
||||
log.Infof("exchange message %d via %s: %s", mq.Id, ex.String(), mq.Question[0].String())
|
||||
reply, err = ex.Exchange(ctx, query)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -214,21 +211,21 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
||||
|
||||
mr = &dns.Msg{}
|
||||
if err = mr.Unpack(reply); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if h.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
h.logger.Debug(mr.String())
|
||||
if log.IsLevelEnabled(logger.DebugLevel) {
|
||||
log.Debug(mr.String())
|
||||
} else {
|
||||
h.logger.Info(h.dumpMsgHeader(mr))
|
||||
log.Info(h.dumpMsgHeader(mr))
|
||||
}
|
||||
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// lookup host mapper
|
||||
func (h *dnsHandler) lookupHosts(r *dns.Msg) (m *dns.Msg) {
|
||||
func (h *dnsHandler) lookupHosts(r *dns.Msg, log logger.Logger) (m *dns.Msg) {
|
||||
if h.options.Hosts == nil ||
|
||||
r.Question[0].Qclass != dns.ClassINET ||
|
||||
(r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA) {
|
||||
@ -246,12 +243,12 @@ func (h *dnsHandler) lookupHosts(r *dns.Msg) (m *dns.Msg) {
|
||||
if len(ips) == 0 {
|
||||
return nil
|
||||
}
|
||||
h.logger.Debugf("hit host mapper: %s -> %s", host, ips)
|
||||
log.Debugf("hit host mapper: %s -> %s", host, ips)
|
||||
|
||||
for _, ip := range ips {
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s IN A %s\n", r.Question[0].Name, ip.String()))
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return nil
|
||||
}
|
||||
m.Answer = append(m.Answer, rr)
|
||||
@ -262,12 +259,12 @@ func (h *dnsHandler) lookupHosts(r *dns.Msg) (m *dns.Msg) {
|
||||
if len(ips) == 0 {
|
||||
return nil
|
||||
}
|
||||
h.logger.Debugf("hit host mapper: %s -> %s", host, ips)
|
||||
log.Debugf("hit host mapper: %s -> %s", host, ips)
|
||||
|
||||
for _, ip := range ips {
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s IN AAAA %s\n", r.Question[0].Name, ip.String()))
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return nil
|
||||
}
|
||||
m.Answer = append(m.Answer, rr)
|
||||
|
@ -8,7 +8,6 @@ import (
|
||||
|
||||
"github.com/go-gost/gost/pkg/chain"
|
||||
"github.com/go-gost/gost/pkg/handler"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
md "github.com/go-gost/gost/pkg/metadata"
|
||||
"github.com/go-gost/gost/pkg/registry"
|
||||
)
|
||||
@ -22,7 +21,6 @@ func init() {
|
||||
type forwardHandler struct {
|
||||
group *chain.NodeGroup
|
||||
router *chain.Router
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
@ -55,7 +53,6 @@ func (h *forwardHandler) Init(md md.Metadata) (err error) {
|
||||
Hosts: h.options.Hosts,
|
||||
Logger: h.options.Logger,
|
||||
}
|
||||
h.logger = h.options.Logger
|
||||
|
||||
return
|
||||
}
|
||||
@ -69,21 +66,21 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
start := time.Now()
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log := h.options.Logger.WithFields(map[string]interface{}{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
h.logger.WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
|
||||
target := h.group.Next()
|
||||
if target == nil {
|
||||
h.logger.Error("no target available")
|
||||
log.Error("no target available")
|
||||
return
|
||||
}
|
||||
|
||||
@ -92,15 +89,15 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
network = "udp"
|
||||
}
|
||||
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"dst": fmt.Sprintf("%s/%s", target.Addr(), network),
|
||||
})
|
||||
|
||||
h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr())
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr())
|
||||
|
||||
cc, err := h.router.Dial(ctx, network, target.Addr())
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
// TODO: the router itself may be failed due to the failed node in the router,
|
||||
// the dead marker may be a wrong operation.
|
||||
target.Marker().Mark()
|
||||
@ -110,11 +107,9 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
target.Marker().Reset()
|
||||
|
||||
t := time.Now()
|
||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr())
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr())
|
||||
handler.Transport(conn, cc)
|
||||
h.logger.
|
||||
WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(t),
|
||||
}).
|
||||
Infof("%s >-< %s", conn.RemoteAddr(), target.Addr())
|
||||
}).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr())
|
||||
}
|
||||
|
@ -8,7 +8,6 @@ import (
|
||||
|
||||
"github.com/go-gost/gost/pkg/chain"
|
||||
"github.com/go-gost/gost/pkg/handler"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
md "github.com/go-gost/gost/pkg/metadata"
|
||||
"github.com/go-gost/gost/pkg/registry"
|
||||
)
|
||||
@ -21,7 +20,6 @@ func init() {
|
||||
type forwardHandler struct {
|
||||
group *chain.NodeGroup
|
||||
router *chain.Router
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
@ -49,7 +47,6 @@ func (h *forwardHandler) Init(md md.Metadata) (err error) {
|
||||
Hosts: h.options.Hosts,
|
||||
Logger: h.options.Logger,
|
||||
}
|
||||
h.logger = h.options.Logger
|
||||
|
||||
return
|
||||
}
|
||||
@ -63,21 +60,21 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
start := time.Now()
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log := h.options.Logger.WithFields(map[string]interface{}{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
h.logger.WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
|
||||
target := h.group.Next()
|
||||
if target == nil {
|
||||
h.logger.Error("no target available")
|
||||
log.Error("no target available")
|
||||
return
|
||||
}
|
||||
|
||||
@ -86,15 +83,15 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
network = "udp"
|
||||
}
|
||||
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"dst": fmt.Sprintf("%s/%s", target.Addr(), network),
|
||||
})
|
||||
|
||||
h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr())
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr())
|
||||
|
||||
cc, err := h.router.Dial(ctx, network, target.Addr())
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
// TODO: the router itself may be failed due to the failed node in the router,
|
||||
// the dead marker may be a wrong operation.
|
||||
target.Marker().Mark()
|
||||
@ -104,11 +101,9 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
target.Marker().Reset()
|
||||
|
||||
t := time.Now()
|
||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr())
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr())
|
||||
handler.Transport(conn, cc)
|
||||
h.logger.
|
||||
WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(t),
|
||||
}).
|
||||
Infof("%s >-< %s", conn.RemoteAddr(), target.Addr())
|
||||
}).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr())
|
||||
}
|
||||
|
@ -1,300 +0,0 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/pkg/chain"
|
||||
auth_util "github.com/go-gost/gost/pkg/common/util/auth"
|
||||
"github.com/go-gost/gost/pkg/handler"
|
||||
ssh_util "github.com/go-gost/gost/pkg/internal/util/ssh"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
md "github.com/go-gost/gost/pkg/metadata"
|
||||
"github.com/go-gost/gost/pkg/registry"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// Applicable SSH Request types for Port Forwarding - RFC 4254 7.X
|
||||
const (
|
||||
DirectForwardRequest = "direct-tcpip" // RFC 4254 7.2
|
||||
RemoteForwardRequest = "tcpip-forward" // RFC 4254 7.1
|
||||
ForwardedTCPReturnRequest = "forwarded-tcpip" // RFC 4254 7.2
|
||||
CancelRemoteForwardRequest = "cancel-tcpip-forward" // RFC 4254 7.1
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.RegisterHandler("sshd", NewHandler)
|
||||
}
|
||||
|
||||
type forwardHandler struct {
|
||||
config *ssh.ServerConfig
|
||||
router *chain.Router
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
|
||||
func NewHandler(opts ...handler.Option) handler.Handler {
|
||||
options := handler.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
return &forwardHandler{
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *forwardHandler) Init(md md.Metadata) (err error) {
|
||||
if err = h.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
authenticator := auth_util.AuthFromUsers(h.options.Auths...)
|
||||
|
||||
config := &ssh.ServerConfig{
|
||||
PasswordCallback: ssh_util.PasswordCallback(authenticator),
|
||||
PublicKeyCallback: ssh_util.PublicKeyCallback(h.md.authorizedKeys),
|
||||
}
|
||||
|
||||
config.AddHostKey(h.md.signer)
|
||||
|
||||
if authenticator == nil && len(h.md.authorizedKeys) == 0 {
|
||||
config.NoClientAuth = true
|
||||
}
|
||||
|
||||
h.config = config
|
||||
h.router = &chain.Router{
|
||||
Retries: h.options.Retries,
|
||||
Chain: h.options.Chain,
|
||||
Resolver: h.options.Resolver,
|
||||
Hosts: h.options.Hosts,
|
||||
Logger: h.options.Logger,
|
||||
}
|
||||
h.logger = h.options.Logger
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
start := time.Now()
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
h.logger.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
|
||||
sshConn, chans, reqs, err := ssh.NewServerConn(conn, h.config)
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
h.handleForward(ctx, sshConn, chans, reqs)
|
||||
}
|
||||
|
||||
func (h *forwardHandler) handleForward(ctx context.Context, conn ssh.Conn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) {
|
||||
quit := make(chan struct{})
|
||||
defer close(quit) // quit signal
|
||||
|
||||
go func() {
|
||||
for req := range reqs {
|
||||
switch req.Type {
|
||||
case RemoteForwardRequest:
|
||||
go h.tcpipForwardRequest(conn, req, quit)
|
||||
default:
|
||||
h.logger.Warnf("unsupported request type: %s, want reply: %v", req.Type, req.WantReply)
|
||||
if req.WantReply {
|
||||
req.Reply(false, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for newChannel := range chans {
|
||||
// Check the type of channel
|
||||
t := newChannel.ChannelType()
|
||||
switch t {
|
||||
case DirectForwardRequest:
|
||||
channel, requests, err := newChannel.Accept()
|
||||
if err != nil {
|
||||
h.logger.Warnf("could not accept channel: %s", err.Error())
|
||||
continue
|
||||
}
|
||||
p := directForward{}
|
||||
ssh.Unmarshal(newChannel.ExtraData(), &p)
|
||||
|
||||
h.logger.Debug(p.String())
|
||||
|
||||
if p.Host1 == "<nil>" {
|
||||
p.Host1 = ""
|
||||
}
|
||||
|
||||
go ssh.DiscardRequests(requests)
|
||||
go h.directPortForwardChannel(ctx, channel, net.JoinHostPort(p.Host1, strconv.Itoa(int(p.Port1))))
|
||||
default:
|
||||
h.logger.Warnf("unsupported channel type: %s", t)
|
||||
newChannel.Reject(ssh.Prohibited, fmt.Sprintf("unsupported channel type: %s", t))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
conn.Wait()
|
||||
}
|
||||
|
||||
func (h *forwardHandler) directPortForwardChannel(ctx context.Context, channel ssh.Channel, raddr string) {
|
||||
defer channel.Close()
|
||||
|
||||
// log.Logf("[ssh-tcp] %s - %s", h.options.Node.Addr, raddr)
|
||||
|
||||
/*
|
||||
if !Can("tcp", raddr, h.options.Whitelist, h.options.Blacklist) {
|
||||
log.Logf("[ssh-tcp] Unauthorized to tcp connect to %s", raddr)
|
||||
return
|
||||
}
|
||||
*/
|
||||
|
||||
if h.options.Bypass != nil && h.options.Bypass.Contains(raddr) {
|
||||
h.logger.Infof("bypass %s", raddr)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := h.router.Dial(ctx, "tcp", raddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
t := time.Now()
|
||||
h.logger.Infof("%s <-> %s", conn.LocalAddr(), conn.RemoteAddr())
|
||||
handler.Transport(conn, channel)
|
||||
h.logger.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(t),
|
||||
}).Infof("%s >-< %s", conn.LocalAddr(), conn.RemoteAddr())
|
||||
}
|
||||
|
||||
// directForward is structure for RFC 4254 7.2 - can be used for "forwarded-tcpip" and "direct-tcpip"
|
||||
type directForward struct {
|
||||
Host1 string
|
||||
Port1 uint32
|
||||
Host2 string
|
||||
Port2 uint32
|
||||
}
|
||||
|
||||
func (p directForward) String() string {
|
||||
return fmt.Sprintf("%s:%d -> %s:%d", p.Host2, p.Port2, p.Host1, p.Port1)
|
||||
}
|
||||
|
||||
func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) {
|
||||
host, portString, err := net.SplitHostPort(addr.String())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
port, err = strconv.Atoi(portString)
|
||||
return
|
||||
}
|
||||
|
||||
// tcpipForward is structure for RFC 4254 7.1 "tcpip-forward" request
|
||||
type tcpipForward struct {
|
||||
Host string
|
||||
Port uint32
|
||||
}
|
||||
|
||||
func (h *forwardHandler) tcpipForwardRequest(sshConn ssh.Conn, req *ssh.Request, quit <-chan struct{}) {
|
||||
t := tcpipForward{}
|
||||
ssh.Unmarshal(req.Payload, &t)
|
||||
|
||||
addr := net.JoinHostPort(t.Host, strconv.Itoa(int(t.Port)))
|
||||
|
||||
/*
|
||||
if !Can("rtcp", addr, h.options.Whitelist, h.options.Blacklist) {
|
||||
log.Logf("[ssh-rtcp] Unauthorized to tcp bind to %s", addr)
|
||||
req.Reply(false, nil)
|
||||
return
|
||||
}
|
||||
*/
|
||||
|
||||
// tie to the client connection
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
req.Reply(false, nil)
|
||||
return
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
h.logger.Debugf("bind on %s OK", ln.Addr())
|
||||
|
||||
err = func() error {
|
||||
if t.Port == 0 && req.WantReply { // Client sent port 0. let them know which port is actually being used
|
||||
_, port, err := getHostPortFromAddr(ln.Addr())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var b [4]byte
|
||||
binary.BigEndian.PutUint32(b[:], uint32(port))
|
||||
t.Port = uint32(port)
|
||||
return req.Reply(true, b[:])
|
||||
}
|
||||
return req.Reply(true, nil)
|
||||
}()
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil { // Unable to accept new connection - listener is likely closed
|
||||
return
|
||||
}
|
||||
|
||||
go func(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
p := directForward{}
|
||||
var err error
|
||||
|
||||
var portnum int
|
||||
p.Host1 = t.Host
|
||||
p.Port1 = t.Port
|
||||
p.Host2, portnum, err = getHostPortFromAddr(conn.RemoteAddr())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.Port2 = uint32(portnum)
|
||||
ch, reqs, err := sshConn.OpenChannel(ForwardedTCPReturnRequest, ssh.Marshal(p))
|
||||
if err != nil {
|
||||
h.logger.Error("open forwarded channel: ", err)
|
||||
return
|
||||
}
|
||||
defer ch.Close()
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
t := time.Now()
|
||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
handler.Transport(ch, conn)
|
||||
h.logger.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(t),
|
||||
}).Infof("%s >-< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
<-quit
|
||||
}
|
@ -32,7 +32,6 @@ func init() {
|
||||
type httpHandler struct {
|
||||
router *chain.Router
|
||||
authenticator auth.Authenticator
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
@ -61,7 +60,6 @@ func (h *httpHandler) Init(md md.Metadata) error {
|
||||
Hosts: h.options.Hosts,
|
||||
Logger: h.options.Logger,
|
||||
}
|
||||
h.logger = h.options.Logger
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -70,28 +68,28 @@ func (h *httpHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
start := time.Now()
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log := h.options.Logger.WithFields(map[string]interface{}{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
h.logger.WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
|
||||
req, err := http.ReadRequest(bufio.NewReader(conn))
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
h.handleRequest(ctx, conn, req)
|
||||
h.handleRequest(ctx, conn, req, log)
|
||||
}
|
||||
|
||||
func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *http.Request) {
|
||||
func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *http.Request, log logger.Logger) {
|
||||
if req == nil {
|
||||
return
|
||||
}
|
||||
@ -129,16 +127,16 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
|
||||
fields := map[string]interface{}{
|
||||
"dst": addr,
|
||||
}
|
||||
if u, _, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization")); u != "" {
|
||||
if u, _, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization"), log); u != "" {
|
||||
fields["user"] = u
|
||||
}
|
||||
h.logger = h.logger.WithFields(fields)
|
||||
log = log.WithFields(fields)
|
||||
|
||||
if h.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
if log.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpRequest(req, false)
|
||||
h.logger.Debug(string(dump))
|
||||
log.Debug(string(dump))
|
||||
}
|
||||
h.logger.Infof("%s >> %s", conn.RemoteAddr(), addr)
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), addr)
|
||||
|
||||
resp := &http.Response{
|
||||
ProtoMajor: 1,
|
||||
@ -152,22 +150,22 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
|
||||
if h.options.Bypass != nil && h.options.Bypass.Contains(addr) {
|
||||
resp.StatusCode = http.StatusForbidden
|
||||
|
||||
if h.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
if log.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpResponse(resp, false)
|
||||
h.logger.Debug(string(dump))
|
||||
log.Debug(string(dump))
|
||||
}
|
||||
h.logger.Info("bypass: ", addr)
|
||||
log.Info("bypass: ", addr)
|
||||
|
||||
resp.Write(conn)
|
||||
return
|
||||
}
|
||||
|
||||
if !h.authenticate(conn, req, resp) {
|
||||
if !h.authenticate(conn, req, resp, log) {
|
||||
return
|
||||
}
|
||||
|
||||
if network == "udp" {
|
||||
h.handleUDP(ctx, conn, network, req.Host)
|
||||
h.handleUDP(ctx, conn, network, req.Host, log)
|
||||
return
|
||||
}
|
||||
|
||||
@ -176,9 +174,9 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
|
||||
resp.StatusCode = http.StatusBadRequest
|
||||
resp.Write(conn)
|
||||
|
||||
if h.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
if log.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpResponse(resp, false)
|
||||
h.logger.Debug(string(dump))
|
||||
log.Debug(string(dump))
|
||||
}
|
||||
|
||||
return
|
||||
@ -191,9 +189,9 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
|
||||
resp.StatusCode = http.StatusServiceUnavailable
|
||||
resp.Write(conn)
|
||||
|
||||
if h.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
if log.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpResponse(resp, false)
|
||||
h.logger.Debug(string(dump))
|
||||
log.Debug(string(dump))
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -203,30 +201,28 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
|
||||
resp.StatusCode = http.StatusOK
|
||||
resp.Status = "200 Connection established"
|
||||
|
||||
if h.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
if log.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpResponse(resp, false)
|
||||
h.logger.Debug(string(dump))
|
||||
log.Debug(string(dump))
|
||||
}
|
||||
if err = resp.Write(conn); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
req.Header.Del("Proxy-Connection")
|
||||
if err = req.Write(cc); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr)
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
|
||||
handler.Transport(conn, cc)
|
||||
h.logger.
|
||||
WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(start),
|
||||
}).
|
||||
Infof("%s >-< %s", conn.RemoteAddr(), addr)
|
||||
}).Infof("%s >-< %s", conn.RemoteAddr(), addr)
|
||||
}
|
||||
|
||||
func (h *httpHandler) decodeServerName(s string) (string, error) {
|
||||
@ -247,7 +243,7 @@ func (h *httpHandler) decodeServerName(s string) (string, error) {
|
||||
return string(v), nil
|
||||
}
|
||||
|
||||
func (h *httpHandler) basicProxyAuth(proxyAuth string) (username, password string, ok bool) {
|
||||
func (h *httpHandler) basicProxyAuth(proxyAuth string, log logger.Logger) (username, password string, ok bool) {
|
||||
if proxyAuth == "" {
|
||||
return
|
||||
}
|
||||
@ -268,8 +264,8 @@ func (h *httpHandler) basicProxyAuth(proxyAuth string) (username, password strin
|
||||
return cs[:s], cs[s+1:], true
|
||||
}
|
||||
|
||||
func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.Response) (ok bool) {
|
||||
u, p, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization"))
|
||||
func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.Response, log logger.Logger) (ok bool) {
|
||||
u, p, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization"), log)
|
||||
if h.authenticator == nil || h.authenticator.Authenticate(u, p) {
|
||||
return true
|
||||
}
|
||||
@ -289,7 +285,7 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.
|
||||
}
|
||||
r, err := http.Get(url)
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
break
|
||||
}
|
||||
resp = r
|
||||
@ -297,7 +293,7 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.
|
||||
case "host":
|
||||
cc, err := net.Dial("tcp", pr.Value)
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
break
|
||||
}
|
||||
defer cc.Close()
|
||||
@ -333,7 +329,7 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.
|
||||
resp.Header.Add("Proxy-Connection", "close")
|
||||
}
|
||||
|
||||
h.logger.Info("proxy authentication required")
|
||||
log.Info("proxy authentication required")
|
||||
} else {
|
||||
resp.Header.Set("Server", "nginx/1.20.1")
|
||||
resp.Header.Set("Date", time.Now().Format(http.TimeFormat))
|
||||
@ -342,9 +338,9 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.
|
||||
}
|
||||
}
|
||||
|
||||
if h.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
if log.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpResponse(resp, false)
|
||||
h.logger.Debug(string(dump))
|
||||
log.Debug(string(dump))
|
||||
}
|
||||
|
||||
resp.Write(conn)
|
||||
|
@ -12,8 +12,8 @@ import (
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
)
|
||||
|
||||
func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, address string) {
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) {
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"cmd": "udp",
|
||||
})
|
||||
|
||||
@ -30,49 +30,47 @@ func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, add
|
||||
resp.StatusCode = http.StatusForbidden
|
||||
resp.Write(conn)
|
||||
|
||||
if h.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
if log.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpResponse(resp, false)
|
||||
h.logger.Debug(string(dump))
|
||||
log.Debug(string(dump))
|
||||
}
|
||||
h.logger.Error("UDP relay is diabled")
|
||||
log.Error("UDP relay is diabled")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resp.StatusCode = http.StatusOK
|
||||
if h.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
if log.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpResponse(resp, false)
|
||||
h.logger.Debug(string(dump))
|
||||
log.Debug(string(dump))
|
||||
}
|
||||
if err := resp.Write(conn); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
// obtain a udp connection
|
||||
c, err := h.router.Dial(ctx, "udp", "") // UDP association
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
pc, ok := c.(net.PacketConn)
|
||||
if !ok {
|
||||
h.logger.Errorf("wrong connection type")
|
||||
log.Errorf("wrong connection type")
|
||||
return
|
||||
}
|
||||
|
||||
relay := handler.NewUDPRelay(socks.UDPTunServerConn(conn), pc).
|
||||
WithBypass(h.options.Bypass).
|
||||
WithLogger(h.logger)
|
||||
WithLogger(log)
|
||||
|
||||
t := time.Now()
|
||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||
relay.Run()
|
||||
h.logger.
|
||||
WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(t),
|
||||
}).
|
||||
Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||
}).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||
}
|
||||
|
@ -35,7 +35,6 @@ func init() {
|
||||
type http2Handler struct {
|
||||
router *chain.Router
|
||||
authenticator auth.Authenticator
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
@ -64,7 +63,6 @@ func (h *http2Handler) Init(md md.Metadata) error {
|
||||
Hosts: h.options.Hosts,
|
||||
Logger: h.options.Logger,
|
||||
}
|
||||
h.logger = h.options.Logger
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -72,29 +70,29 @@ func (h *http2Handler) Handle(ctx context.Context, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
start := time.Now()
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log := h.options.Logger.WithFields(map[string]interface{}{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
h.logger.WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
|
||||
cc, ok := conn.(*http2_util.ServerConn)
|
||||
if !ok {
|
||||
h.logger.Error("wrong connection type")
|
||||
log.Error("wrong connection type")
|
||||
return
|
||||
}
|
||||
h.roundTrip(ctx, cc.Writer(), cc.Request())
|
||||
h.roundTrip(ctx, cc.Writer(), cc.Request(), log)
|
||||
}
|
||||
|
||||
// NOTE: there is an issue (golang/go#43989) will cause the client hangs
|
||||
// when server returns an non-200 status code,
|
||||
// May be fixed in go1.18.
|
||||
func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req *http.Request) {
|
||||
func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req *http.Request, log logger.Logger) {
|
||||
// Try to get the actual host.
|
||||
// Compatible with GOST 2.x.
|
||||
if v := req.Header.Get("Gost-Target"); v != "" {
|
||||
@ -122,21 +120,21 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req
|
||||
if u, _, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization")); u != "" {
|
||||
fields["user"] = u
|
||||
}
|
||||
h.logger = h.logger.WithFields(fields)
|
||||
log = log.WithFields(fields)
|
||||
|
||||
if h.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
if log.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpRequest(req, false)
|
||||
h.logger.Debug(string(dump))
|
||||
log.Debug(string(dump))
|
||||
}
|
||||
h.logger.Infof("%s >> %s", req.RemoteAddr, addr)
|
||||
log.Infof("%s >> %s", req.RemoteAddr, addr)
|
||||
|
||||
if h.md.proxyAgent != "" {
|
||||
w.Header().Set("Proxy-Agent", h.md.proxyAgent)
|
||||
for k := range h.md.header {
|
||||
w.Header().Set(k, h.md.header.Get(k))
|
||||
}
|
||||
|
||||
if h.options.Bypass != nil && h.options.Bypass.Contains(addr) {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
h.logger.Info("bypass: ", addr)
|
||||
log.Info("bypass: ", addr)
|
||||
return
|
||||
}
|
||||
|
||||
@ -147,7 +145,7 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req
|
||||
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
|
||||
}
|
||||
|
||||
if !h.authenticate(w, req, resp) {
|
||||
if !h.authenticate(w, req, resp, log) {
|
||||
return
|
||||
}
|
||||
|
||||
@ -157,7 +155,7 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req
|
||||
|
||||
cc, err := h.router.Dial(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
@ -174,30 +172,28 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req
|
||||
// we take over the underly connection
|
||||
conn, _, err := hj.Hijack()
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
start := time.Now()
|
||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr)
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
|
||||
handler.Transport(conn, cc)
|
||||
h.logger.
|
||||
WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(start),
|
||||
}).
|
||||
Infof("%s >-< %s", conn.RemoteAddr(), addr)
|
||||
}).Infof("%s >-< %s", conn.RemoteAddr(), addr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
h.logger.Infof("%s <-> %s", req.RemoteAddr, addr)
|
||||
log.Infof("%s <-> %s", req.RemoteAddr, addr)
|
||||
handler.Transport(&readWriter{r: req.Body, w: flushWriter{w}}, cc)
|
||||
h.logger.
|
||||
WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(start),
|
||||
}).
|
||||
Infof("%s >-< %s", req.RemoteAddr, addr)
|
||||
}).Infof("%s >-< %s", req.RemoteAddr, addr)
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -241,7 +237,7 @@ func (h *http2Handler) basicProxyAuth(proxyAuth string) (username, password stri
|
||||
return cs[:s], cs[s+1:], true
|
||||
}
|
||||
|
||||
func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp *http.Response) (ok bool) {
|
||||
func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp *http.Response, log logger.Logger) (ok bool) {
|
||||
u, p, _ := h.basicProxyAuth(r.Header.Get("Proxy-Authorization"))
|
||||
if h.authenticator == nil || h.authenticator.Authenticate(u, p) {
|
||||
return true
|
||||
@ -261,7 +257,7 @@ func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp
|
||||
}
|
||||
r, err := http.Get(url)
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
break
|
||||
}
|
||||
resp = r
|
||||
@ -269,13 +265,13 @@ func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp
|
||||
case "host":
|
||||
cc, err := net.Dial("tcp", pr.Value)
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
break
|
||||
}
|
||||
defer cc.Close()
|
||||
|
||||
if err := h.forwardRequest(w, r, cc); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
}
|
||||
return
|
||||
case "file":
|
||||
@ -303,7 +299,7 @@ func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp
|
||||
resp.Header.Add("Proxy-Connection", "close")
|
||||
}
|
||||
|
||||
h.logger.Info("proxy authentication required")
|
||||
log.Info("proxy authentication required")
|
||||
} else {
|
||||
resp.Header = http.Header{}
|
||||
resp.Header.Set("Server", "nginx/1.20.1")
|
||||
@ -313,9 +309,9 @@ func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp
|
||||
}
|
||||
}
|
||||
|
||||
if h.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
if log.IsLevelEnabled(logger.DebugLevel) {
|
||||
dump, _ := httputil.DumpResponse(resp, false)
|
||||
h.logger.Debug(string(dump))
|
||||
log.Debug(string(dump))
|
||||
}
|
||||
|
||||
h.writeResponse(w, resp)
|
||||
|
@ -1,28 +1,31 @@
|
||||
package http2
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
mdata "github.com/go-gost/gost/pkg/metadata"
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
proxyAgent string
|
||||
probeResistance *probeResistance
|
||||
sni bool
|
||||
enableUDP bool
|
||||
header http.Header
|
||||
}
|
||||
|
||||
func (h *http2Handler) parseMetadata(md mdata.Metadata) error {
|
||||
const (
|
||||
proxyAgent = "proxyAgent"
|
||||
header = "header"
|
||||
probeResistKey = "probeResistance"
|
||||
knock = "knock"
|
||||
sni = "sni"
|
||||
enableUDP = "udp"
|
||||
)
|
||||
|
||||
h.md.proxyAgent = mdata.GetString(md, proxyAgent)
|
||||
if m := mdata.GetStringMapString(md, header); len(m) > 0 {
|
||||
hd := http.Header{}
|
||||
for k, v := range m {
|
||||
hd.Add(k, v)
|
||||
}
|
||||
h.md.header = hd
|
||||
}
|
||||
|
||||
if v := mdata.GetString(md, probeResistKey); v != "" {
|
||||
if ss := strings.SplitN(v, ":", 2); len(ss) == 2 {
|
||||
@ -33,8 +36,6 @@ func (h *http2Handler) parseMetadata(md mdata.Metadata) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
h.md.sni = mdata.GetBool(md, sni)
|
||||
h.md.enableUDP = mdata.GetBool(md, enableUDP)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -8,7 +8,6 @@ import (
|
||||
|
||||
"github.com/go-gost/gost/pkg/chain"
|
||||
"github.com/go-gost/gost/pkg/handler"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
md "github.com/go-gost/gost/pkg/metadata"
|
||||
"github.com/go-gost/gost/pkg/registry"
|
||||
)
|
||||
@ -22,7 +21,6 @@ func init() {
|
||||
|
||||
type redirectHandler struct {
|
||||
router *chain.Router
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
@ -50,7 +48,6 @@ func (h *redirectHandler) Init(md md.Metadata) (err error) {
|
||||
Hosts: h.options.Hosts,
|
||||
Logger: h.options.Logger,
|
||||
}
|
||||
h.logger = h.options.Logger
|
||||
|
||||
return
|
||||
}
|
||||
@ -59,14 +56,14 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
start := time.Now()
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log := h.options.Logger.WithFields(map[string]interface{}{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
h.logger.WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
@ -83,35 +80,33 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
if network == "tcp" {
|
||||
dstAddr, conn, err = h.getOriginalDstAddr(conn)
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"dst": fmt.Sprintf("%s/%s", dstAddr, network),
|
||||
})
|
||||
|
||||
h.logger.Infof("%s >> %s", conn.RemoteAddr(), dstAddr)
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), dstAddr)
|
||||
|
||||
if h.options.Bypass != nil && h.options.Bypass.Contains(dstAddr.String()) {
|
||||
h.logger.Info("bypass: ", dstAddr)
|
||||
log.Info("bypass: ", dstAddr)
|
||||
return
|
||||
}
|
||||
|
||||
cc, err := h.router.Dial(ctx, network, dstAddr.String())
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
defer cc.Close()
|
||||
|
||||
t := time.Now()
|
||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), dstAddr)
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), dstAddr)
|
||||
handler.Transport(conn, cc)
|
||||
h.logger.
|
||||
WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(t),
|
||||
}).
|
||||
Infof("%s >-< %s", conn.RemoteAddr(), dstAddr)
|
||||
}).Infof("%s >-< %s", conn.RemoteAddr(), dstAddr)
|
||||
}
|
||||
|
@ -13,7 +13,6 @@ func (h *redirectHandler) getOriginalDstAddr(conn net.Conn) (addr net.Addr, c ne
|
||||
tc, ok := conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
err = errors.New("wrong connection type, must be TCP")
|
||||
h.logger.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -10,16 +10,17 @@ import (
|
||||
"github.com/go-gost/gost/pkg/common/util/mux"
|
||||
"github.com/go-gost/gost/pkg/common/util/socks"
|
||||
"github.com/go-gost/gost/pkg/handler"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
"github.com/go-gost/relay"
|
||||
)
|
||||
|
||||
func (h *relayHandler) handleBind(ctx context.Context, conn net.Conn, network, address string) {
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
func (h *relayHandler) handleBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) {
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"dst": fmt.Sprintf("%s/%s", address, network),
|
||||
"cmd": "bind",
|
||||
})
|
||||
|
||||
h.logger.Infof("%s >> %s", conn.RemoteAddr(), address)
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), address)
|
||||
|
||||
resp := relay.Response{
|
||||
Version: relay.Version1,
|
||||
@ -29,18 +30,18 @@ func (h *relayHandler) handleBind(ctx context.Context, conn net.Conn, network, a
|
||||
if !h.md.enableBind {
|
||||
resp.Status = relay.StatusForbidden
|
||||
resp.WriteTo(conn)
|
||||
h.logger.Error("BIND is diabled")
|
||||
log.Error("BIND is diabled")
|
||||
return
|
||||
}
|
||||
|
||||
if network == "tcp" {
|
||||
h.bindTCP(ctx, conn, network, address)
|
||||
h.bindTCP(ctx, conn, network, address, log)
|
||||
} else {
|
||||
h.bindUDP(ctx, conn, network, address)
|
||||
h.bindUDP(ctx, conn, network, address, log)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, address string) {
|
||||
func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) {
|
||||
resp := relay.Response{
|
||||
Version: relay.Version1,
|
||||
Status: relay.StatusOK,
|
||||
@ -48,7 +49,7 @@ func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, addr
|
||||
|
||||
ln, err := net.Listen(network, address) // strict mode: if the port already in use, it will return error
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
resp.Status = relay.StatusServiceUnavailable
|
||||
resp.WriteTo(conn)
|
||||
return
|
||||
@ -57,7 +58,7 @@ func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, addr
|
||||
af := &relay.AddrFeature{}
|
||||
err = af.ParseFrom(ln.Addr().String())
|
||||
if err != nil {
|
||||
h.logger.Warn(err)
|
||||
log.Warn(err)
|
||||
}
|
||||
|
||||
// Issue: may not reachable when host has multi-interface
|
||||
@ -65,20 +66,20 @@ func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, addr
|
||||
af.AType = relay.AddrIPv4
|
||||
resp.Features = append(resp.Features, af)
|
||||
if _, err := resp.WriteTo(conn); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
ln.Close()
|
||||
return
|
||||
}
|
||||
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"bind": fmt.Sprintf("%s/%s", ln.Addr(), ln.Addr().Network()),
|
||||
})
|
||||
h.logger.Debugf("bind on %s OK", ln.Addr())
|
||||
log.Debugf("bind on %s OK", ln.Addr())
|
||||
|
||||
h.serveTCPBind(ctx, conn, ln)
|
||||
h.serveTCPBind(ctx, conn, ln, log)
|
||||
}
|
||||
|
||||
func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, address string) {
|
||||
func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) {
|
||||
resp := relay.Response{
|
||||
Version: relay.Version1,
|
||||
Status: relay.StatusOK,
|
||||
@ -87,7 +88,7 @@ func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, addr
|
||||
bindAddr, _ := net.ResolveUDPAddr(network, address)
|
||||
pc, err := net.ListenUDP(network, bindAddr)
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
defer pc.Close()
|
||||
@ -95,7 +96,7 @@ func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, addr
|
||||
af := &relay.AddrFeature{}
|
||||
err = af.ParseFrom(pc.LocalAddr().String())
|
||||
if err != nil {
|
||||
h.logger.Warn(err)
|
||||
log.Warn(err)
|
||||
}
|
||||
|
||||
// Issue: may not reachable when host has multi-interface
|
||||
@ -103,33 +104,32 @@ func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, addr
|
||||
af.AType = relay.AddrIPv4
|
||||
resp.Features = append(resp.Features, af)
|
||||
if _, err := resp.WriteTo(conn); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"bind": pc.LocalAddr().String(),
|
||||
})
|
||||
h.logger.Debugf("bind on %s OK", pc.LocalAddr())
|
||||
log.Debugf("bind on %s OK", pc.LocalAddr())
|
||||
|
||||
t := time.Now()
|
||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||
h.tunnelServerUDP(
|
||||
socks.UDPTunServerConn(conn),
|
||||
pc,
|
||||
log,
|
||||
)
|
||||
h.logger.
|
||||
WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(t),
|
||||
}).
|
||||
Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||
}).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||
}
|
||||
|
||||
func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.Listener) {
|
||||
func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.Listener, log logger.Logger) {
|
||||
// Upgrade connection to multiplex stream.
|
||||
session, err := mux.ClientSession(conn)
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
defer session.Close()
|
||||
@ -139,7 +139,7 @@ func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.L
|
||||
for {
|
||||
conn, err := session.Accept()
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
conn.Close() // we do not handle incoming connections.
|
||||
@ -149,17 +149,22 @@ func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.L
|
||||
for {
|
||||
rc, err := ln.Accept()
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
h.logger.Debugf("peer %s accepted", rc.RemoteAddr())
|
||||
log.Debugf("peer %s accepted", rc.RemoteAddr())
|
||||
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"local": ln.Addr().String(),
|
||||
"remote": c.RemoteAddr().String(),
|
||||
})
|
||||
|
||||
sc, err := session.GetConn()
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
defer sc.Close()
|
||||
@ -172,21 +177,20 @@ func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.L
|
||||
Features: []relay.Feature{af},
|
||||
}
|
||||
if _, err := resp.WriteTo(sc); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
t := time.Now()
|
||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), c.RemoteAddr().String())
|
||||
log.Infof("%s <-> %s", c.LocalAddr(), c.RemoteAddr())
|
||||
handler.Transport(sc, c)
|
||||
h.logger.
|
||||
WithFields(map[string]interface{}{"duration": time.Since(t)}).
|
||||
Infof("%s >-< %s", conn.RemoteAddr(), c.RemoteAddr().String())
|
||||
log.WithFields(map[string]interface{}{"duration": time.Since(t)}).
|
||||
Infof("%s >-< %s", c.LocalAddr(), c.RemoteAddr())
|
||||
}(rc)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *relayHandler) tunnelServerUDP(tunnel, c net.PacketConn) (err error) {
|
||||
func (h *relayHandler) tunnelServerUDP(tunnel, c net.PacketConn, log logger.Logger) (err error) {
|
||||
bufSize := h.md.udpBufferSize
|
||||
errc := make(chan error, 2)
|
||||
|
||||
@ -202,7 +206,7 @@ func (h *relayHandler) tunnelServerUDP(tunnel, c net.PacketConn) (err error) {
|
||||
}
|
||||
|
||||
if h.options.Bypass != nil && h.options.Bypass.Contains(raddr.String()) {
|
||||
h.logger.Warn("bypass: ", raddr)
|
||||
log.Warn("bypass: ", raddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -210,7 +214,7 @@ func (h *relayHandler) tunnelServerUDP(tunnel, c net.PacketConn) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
h.logger.Debugf("%s >>> %s data: %d",
|
||||
log.Debugf("%s >>> %s data: %d",
|
||||
c.LocalAddr(), raddr, n)
|
||||
|
||||
return nil
|
||||
@ -235,14 +239,14 @@ func (h *relayHandler) tunnelServerUDP(tunnel, c net.PacketConn) (err error) {
|
||||
}
|
||||
|
||||
if h.options.Bypass != nil && h.options.Bypass.Contains(raddr.String()) {
|
||||
h.logger.Warn("bypass: ", raddr)
|
||||
log.Warn("bypass: ", raddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := tunnel.WriteTo((*b)[:n], raddr); err != nil {
|
||||
return err
|
||||
}
|
||||
h.logger.Debugf("%s <<< %s data: %d",
|
||||
log.Debugf("%s <<< %s data: %d",
|
||||
c.LocalAddr(), raddr, n)
|
||||
|
||||
return nil
|
||||
|
@ -7,16 +7,17 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/pkg/handler"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
"github.com/go-gost/relay"
|
||||
)
|
||||
|
||||
func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network, address string) {
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) {
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"dst": fmt.Sprintf("%s/%s", address, network),
|
||||
"cmd": "connect",
|
||||
})
|
||||
|
||||
h.logger.Infof("%s >> %s", conn.RemoteAddr(), address)
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), address)
|
||||
|
||||
resp := relay.Response{
|
||||
Version: relay.Version1,
|
||||
@ -26,12 +27,12 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network
|
||||
if address == "" {
|
||||
resp.Status = relay.StatusBadRequest
|
||||
resp.WriteTo(conn)
|
||||
h.logger.Error("target not specified")
|
||||
log.Error("target not specified")
|
||||
return
|
||||
}
|
||||
|
||||
if h.options.Bypass != nil && h.options.Bypass.Contains(address) {
|
||||
h.logger.Info("bypass: ", address)
|
||||
log.Info("bypass: ", address)
|
||||
resp.Status = relay.StatusForbidden
|
||||
resp.WriteTo(conn)
|
||||
return
|
||||
@ -47,7 +48,7 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network
|
||||
|
||||
if h.md.noDelay {
|
||||
if _, err := resp.WriteTo(conn); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -78,11 +79,9 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network
|
||||
}
|
||||
|
||||
t := time.Now()
|
||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), address)
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), address)
|
||||
handler.Transport(conn, cc)
|
||||
h.logger.
|
||||
WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(t),
|
||||
}).
|
||||
Infof("%s >-< %s", conn.RemoteAddr(), address)
|
||||
}).Infof("%s >-< %s", conn.RemoteAddr(), address)
|
||||
}
|
||||
|
@ -7,10 +7,11 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/pkg/handler"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
"github.com/go-gost/relay"
|
||||
)
|
||||
|
||||
func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network string) {
|
||||
func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network string, log logger.Logger) {
|
||||
resp := relay.Response{
|
||||
Version: relay.Version1,
|
||||
Status: relay.StatusOK,
|
||||
@ -19,15 +20,16 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network
|
||||
if target == nil {
|
||||
resp.Status = relay.StatusServiceUnavailable
|
||||
resp.WriteTo(conn)
|
||||
h.logger.Error("no target available")
|
||||
log.Error("no target available")
|
||||
return
|
||||
}
|
||||
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"dst": fmt.Sprintf("%s/%s", target.Addr(), network),
|
||||
"cmd": "forward",
|
||||
})
|
||||
|
||||
h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr())
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr())
|
||||
|
||||
cc, err := h.router.Dial(ctx, network, target.Addr())
|
||||
if err != nil {
|
||||
@ -37,7 +39,7 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network
|
||||
|
||||
resp.Status = relay.StatusHostUnreachable
|
||||
resp.WriteTo(conn)
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
|
||||
return
|
||||
}
|
||||
@ -46,7 +48,7 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network
|
||||
|
||||
if h.md.noDelay {
|
||||
if _, err := resp.WriteTo(conn); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -77,11 +79,9 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network
|
||||
}
|
||||
|
||||
t := time.Now()
|
||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr())
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr())
|
||||
handler.Transport(conn, cc)
|
||||
h.logger.
|
||||
WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(t),
|
||||
}).
|
||||
Infof("%s >-< %s", conn.RemoteAddr(), target.Addr())
|
||||
}).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr())
|
||||
}
|
||||
|
@ -10,7 +10,6 @@ import (
|
||||
"github.com/go-gost/gost/pkg/chain"
|
||||
auth_util "github.com/go-gost/gost/pkg/common/util/auth"
|
||||
"github.com/go-gost/gost/pkg/handler"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
md "github.com/go-gost/gost/pkg/metadata"
|
||||
"github.com/go-gost/gost/pkg/registry"
|
||||
"github.com/go-gost/relay"
|
||||
@ -24,7 +23,6 @@ type relayHandler struct {
|
||||
group *chain.NodeGroup
|
||||
router *chain.Router
|
||||
authenticator auth.Authenticator
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
@ -53,7 +51,6 @@ func (h *relayHandler) Init(md md.Metadata) (err error) {
|
||||
Hosts: h.options.Hosts,
|
||||
Logger: h.options.Logger,
|
||||
}
|
||||
h.logger = h.options.Logger
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -66,14 +63,14 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
start := time.Now()
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log := h.options.Logger.WithFields(map[string]interface{}{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
h.logger.WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
@ -84,14 +81,14 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
|
||||
req := relay.Request{}
|
||||
if _, err := req.ReadFrom(conn); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
|
||||
if req.Version != relay.Version1 {
|
||||
h.logger.Error("bad version")
|
||||
log.Error("bad version")
|
||||
return
|
||||
}
|
||||
|
||||
@ -109,7 +106,7 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
}
|
||||
|
||||
if user != "" {
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{"user": user})
|
||||
log = log.WithFields(map[string]interface{}{"user": user})
|
||||
}
|
||||
|
||||
resp := relay.Response{
|
||||
@ -119,7 +116,7 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
if h.authenticator != nil && !h.authenticator.Authenticate(user, pass) {
|
||||
resp.Status = relay.StatusUnauthorized
|
||||
resp.WriteTo(conn)
|
||||
h.logger.Error("unauthorized")
|
||||
log.Error("unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
@ -132,18 +129,18 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
if address != "" {
|
||||
resp.Status = relay.StatusForbidden
|
||||
resp.WriteTo(conn)
|
||||
h.logger.Error("forward mode, connect is forbidden")
|
||||
log.Error("forward mode, connect is forbidden")
|
||||
return
|
||||
}
|
||||
// forward mode
|
||||
h.handleForward(ctx, conn, network)
|
||||
h.handleForward(ctx, conn, network, log)
|
||||
return
|
||||
}
|
||||
|
||||
switch req.Flags & relay.CmdMask {
|
||||
case 0, relay.CONNECT:
|
||||
h.handleConnect(ctx, conn, network, address)
|
||||
h.handleConnect(ctx, conn, network, address, log)
|
||||
case relay.BIND:
|
||||
h.handleBind(ctx, conn, network, address)
|
||||
h.handleBind(ctx, conn, network, address, log)
|
||||
}
|
||||
}
|
||||
|
@ -14,7 +14,6 @@ import (
|
||||
"github.com/go-gost/gost/pkg/chain"
|
||||
"github.com/go-gost/gost/pkg/common/bufpool"
|
||||
"github.com/go-gost/gost/pkg/handler"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
md "github.com/go-gost/gost/pkg/metadata"
|
||||
"github.com/go-gost/gost/pkg/registry"
|
||||
dissector "github.com/go-gost/tls-dissector"
|
||||
@ -27,7 +26,6 @@ func init() {
|
||||
type sniHandler struct {
|
||||
httpHandler handler.Handler
|
||||
router *chain.Router
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
@ -38,19 +36,13 @@ func NewHandler(opts ...handler.Option) handler.Handler {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
log := options.Logger
|
||||
if log == nil {
|
||||
log = logger.Default()
|
||||
}
|
||||
|
||||
h := &sniHandler{
|
||||
options: options,
|
||||
logger: log,
|
||||
}
|
||||
|
||||
if f := registry.GetHandler("http"); f != nil {
|
||||
v := append(opts,
|
||||
handler.LoggerOption(log.WithFields(map[string]interface{}{"type": "http"})))
|
||||
handler.LoggerOption(h.options.Logger.WithFields(map[string]interface{}{"type": "http"})))
|
||||
h.httpHandler = f(v...)
|
||||
}
|
||||
|
||||
@ -85,21 +77,21 @@ func (h *sniHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
start := time.Now()
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log := h.options.Logger.WithFields(map[string]interface{}{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
h.logger.WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
|
||||
var hdr [dissector.RecordHeaderLen]byte
|
||||
if _, err := io.ReadFull(conn, hdr[:]); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -121,25 +113,25 @@ func (h *sniHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
buf := bufpool.Get(int(length) + dissector.RecordHeaderLen)
|
||||
defer bufpool.Put(buf)
|
||||
if _, err := io.ReadFull(conn, (*buf)[dissector.RecordHeaderLen:]); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
copy(*buf, hdr[:])
|
||||
|
||||
opaque, host, err := h.decodeHost(bytes.NewReader(*buf))
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
target := net.JoinHostPort(host, "443")
|
||||
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"dst": target,
|
||||
})
|
||||
h.logger.Infof("%s >> %s", conn.RemoteAddr(), target)
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), target)
|
||||
|
||||
if h.options.Bypass != nil && h.options.Bypass.Contains(target) {
|
||||
h.logger.Info("bypass: ", target)
|
||||
log.Info("bypass: ", target)
|
||||
return
|
||||
}
|
||||
|
||||
@ -150,18 +142,16 @@ func (h *sniHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
defer cc.Close()
|
||||
|
||||
if _, err := cc.Write(opaque); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
t := time.Now()
|
||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), target)
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), target)
|
||||
handler.Transport(conn, cc)
|
||||
h.logger.
|
||||
WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(t),
|
||||
}).
|
||||
Infof("%s >-< %s", conn.RemoteAddr(), target)
|
||||
}).Infof("%s >-< %s", conn.RemoteAddr(), target)
|
||||
}
|
||||
|
||||
func (h *sniHandler) decodeHost(r io.Reader) (opaque []byte, host string, err error) {
|
||||
|
@ -23,7 +23,6 @@ func init() {
|
||||
type socks4Handler struct {
|
||||
router *chain.Router
|
||||
authenticator auth.Authenticator
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
@ -52,7 +51,6 @@ func (h *socks4Handler) Init(md md.Metadata) (err error) {
|
||||
Hosts: h.options.Hosts,
|
||||
Logger: h.options.Logger,
|
||||
}
|
||||
h.logger = h.options.Logger
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -62,14 +60,14 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn) {
|
||||
|
||||
start := time.Now()
|
||||
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log := h.options.Logger.WithFields(map[string]interface{}{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
h.logger.WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
@ -80,10 +78,10 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn) {
|
||||
|
||||
req, err := gosocks4.ReadRequest(conn)
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
h.logger.Debug(req)
|
||||
log.Debug(req)
|
||||
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
|
||||
@ -91,33 +89,33 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn) {
|
||||
!h.authenticator.Authenticate(string(req.Userid), "") {
|
||||
resp := gosocks4.NewReply(gosocks4.RejectedUserid, nil)
|
||||
resp.Write(conn)
|
||||
h.logger.Debug(resp)
|
||||
log.Debug(resp)
|
||||
return
|
||||
}
|
||||
|
||||
switch req.Cmd {
|
||||
case gosocks4.CmdConnect:
|
||||
h.handleConnect(ctx, conn, req)
|
||||
h.handleConnect(ctx, conn, req, log)
|
||||
case gosocks4.CmdBind:
|
||||
h.handleBind(ctx, conn, req)
|
||||
default:
|
||||
h.logger.Errorf("unknown cmd: %d", req.Cmd)
|
||||
log.Errorf("unknown cmd: %d", req.Cmd)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *gosocks4.Request) {
|
||||
func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *gosocks4.Request, log logger.Logger) {
|
||||
addr := req.Addr.String()
|
||||
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"dst": addr,
|
||||
})
|
||||
h.logger.Infof("%s >> %s", conn.RemoteAddr(), addr)
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), addr)
|
||||
|
||||
if h.options.Bypass != nil && h.options.Bypass.Contains(addr) {
|
||||
resp := gosocks4.NewReply(gosocks4.Rejected, nil)
|
||||
resp.Write(conn)
|
||||
h.logger.Debug(resp)
|
||||
h.logger.Info("bypass: ", addr)
|
||||
log.Debug(resp)
|
||||
log.Info("bypass: ", addr)
|
||||
return
|
||||
}
|
||||
|
||||
@ -125,7 +123,7 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g
|
||||
if err != nil {
|
||||
resp := gosocks4.NewReply(gosocks4.Failed, nil)
|
||||
resp.Write(conn)
|
||||
h.logger.Debug(resp)
|
||||
log.Debug(resp)
|
||||
return
|
||||
}
|
||||
|
||||
@ -133,19 +131,17 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g
|
||||
|
||||
resp := gosocks4.NewReply(gosocks4.Granted, nil)
|
||||
if err := resp.Write(conn); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
h.logger.Debug(resp)
|
||||
log.Debug(resp)
|
||||
|
||||
t := time.Now()
|
||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr)
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
|
||||
handler.Transport(conn, cc)
|
||||
h.logger.
|
||||
WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(t),
|
||||
}).
|
||||
Infof("%s >-< %s", conn.RemoteAddr(), addr)
|
||||
}).Infof("%s >-< %s", conn.RemoteAddr(), addr)
|
||||
}
|
||||
|
||||
func (h *socks4Handler) handleBind(ctx context.Context, conn net.Conn, req *gosocks4.Request) {
|
||||
|
@ -8,43 +8,44 @@ import (
|
||||
|
||||
"github.com/go-gost/gosocks5"
|
||||
"github.com/go-gost/gost/pkg/handler"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
)
|
||||
|
||||
func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, network, address string) {
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) {
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"dst": fmt.Sprintf("%s/%s", address, network),
|
||||
"cmd": "bind",
|
||||
})
|
||||
|
||||
h.logger.Infof("%s >> %s", conn.RemoteAddr(), address)
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), address)
|
||||
|
||||
if !h.md.enableBind {
|
||||
reply := gosocks5.NewReply(gosocks5.NotAllowed, nil)
|
||||
reply.Write(conn)
|
||||
h.logger.Debug(reply)
|
||||
h.logger.Error("BIND is diabled")
|
||||
log.Debug(reply)
|
||||
log.Error("BIND is diabled")
|
||||
return
|
||||
}
|
||||
|
||||
// BIND does not support chain.
|
||||
h.bindLocal(ctx, conn, network, address)
|
||||
h.bindLocal(ctx, conn, network, address, log)
|
||||
}
|
||||
|
||||
func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, network, address string) {
|
||||
func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) {
|
||||
ln, err := net.Listen(network, address) // strict mode: if the port already in use, it will return error
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
reply := gosocks5.NewReply(gosocks5.Failure, nil)
|
||||
if err := reply.Write(conn); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
}
|
||||
h.logger.Debug(reply)
|
||||
log.Debug(reply)
|
||||
return
|
||||
}
|
||||
|
||||
socksAddr := gosocks5.Addr{}
|
||||
if err := socksAddr.ParseFrom(ln.Addr().String()); err != nil {
|
||||
h.logger.Warn(err)
|
||||
log.Warn(err)
|
||||
}
|
||||
|
||||
// Issue: may not reachable when host has multi-interface
|
||||
@ -52,22 +53,22 @@ func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, network, a
|
||||
socksAddr.Type = 0
|
||||
reply := gosocks5.NewReply(gosocks5.Succeeded, &socksAddr)
|
||||
if err := reply.Write(conn); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
ln.Close()
|
||||
return
|
||||
}
|
||||
h.logger.Debug(reply)
|
||||
log.Debug(reply)
|
||||
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"bind": fmt.Sprintf("%s/%s", ln.Addr(), ln.Addr().Network()),
|
||||
})
|
||||
|
||||
h.logger.Debugf("bind on %s OK", ln.Addr())
|
||||
log.Debugf("bind on %s OK", ln.Addr())
|
||||
|
||||
h.serveBind(ctx, conn, ln)
|
||||
h.serveBind(ctx, conn, ln, log)
|
||||
}
|
||||
|
||||
func (h *socks5Handler) serveBind(ctx context.Context, conn net.Conn, ln net.Listener) {
|
||||
func (h *socks5Handler) serveBind(ctx context.Context, conn net.Conn, ln net.Listener, log logger.Logger) {
|
||||
var rc net.Conn
|
||||
accept := func() <-chan error {
|
||||
errc := make(chan error, 1)
|
||||
@ -105,38 +106,42 @@ func (h *socks5Handler) serveBind(ctx context.Context, conn net.Conn, ln net.Lis
|
||||
select {
|
||||
case err := <-accept():
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
|
||||
reply := gosocks5.NewReply(gosocks5.Failure, nil)
|
||||
if err := reply.Write(pc2); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
}
|
||||
h.logger.Debug(reply)
|
||||
log.Debug(reply)
|
||||
|
||||
return
|
||||
}
|
||||
defer rc.Close()
|
||||
|
||||
h.logger.Debugf("peer %s accepted", rc.RemoteAddr())
|
||||
log.Debugf("peer %s accepted", rc.RemoteAddr())
|
||||
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"local": rc.LocalAddr().String(),
|
||||
"remote": rc.RemoteAddr().String(),
|
||||
})
|
||||
|
||||
raddr := gosocks5.Addr{}
|
||||
raddr.ParseFrom(rc.RemoteAddr().String())
|
||||
reply := gosocks5.NewReply(gosocks5.Succeeded, &raddr)
|
||||
if err := reply.Write(pc2); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
}
|
||||
h.logger.Debug(reply)
|
||||
log.Debug(reply)
|
||||
|
||||
start := time.Now()
|
||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), raddr.String())
|
||||
log.Infof("%s <-> %s", rc.LocalAddr(), rc.RemoteAddr())
|
||||
handler.Transport(pc2, rc)
|
||||
h.logger.
|
||||
WithFields(map[string]interface{}{"duration": time.Since(start)}).
|
||||
Infof("%s >-< %s", conn.RemoteAddr(), raddr.String())
|
||||
log.WithFields(map[string]interface{}{"duration": time.Since(start)}).
|
||||
Infof("%s >-< %s", rc.LocalAddr(), rc.RemoteAddr())
|
||||
|
||||
case err := <-pipe():
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
}
|
||||
ln.Close()
|
||||
return
|
||||
|
@ -8,20 +8,21 @@ import (
|
||||
|
||||
"github.com/go-gost/gosocks5"
|
||||
"github.com/go-gost/gost/pkg/handler"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
)
|
||||
|
||||
func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, network, address string) {
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) {
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"dst": fmt.Sprintf("%s/%s", address, network),
|
||||
"cmd": "connect",
|
||||
})
|
||||
h.logger.Infof("%s >> %s", conn.RemoteAddr(), address)
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), address)
|
||||
|
||||
if h.options.Bypass != nil && h.options.Bypass.Contains(address) {
|
||||
resp := gosocks5.NewReply(gosocks5.NotAllowed, nil)
|
||||
resp.Write(conn)
|
||||
h.logger.Debug(resp)
|
||||
h.logger.Info("bypass: ", address)
|
||||
log.Debug(resp)
|
||||
log.Info("bypass: ", address)
|
||||
return
|
||||
}
|
||||
|
||||
@ -29,7 +30,7 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, networ
|
||||
if err != nil {
|
||||
resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil)
|
||||
resp.Write(conn)
|
||||
h.logger.Debug(resp)
|
||||
log.Debug(resp)
|
||||
return
|
||||
}
|
||||
|
||||
@ -37,17 +38,15 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, networ
|
||||
|
||||
resp := gosocks5.NewReply(gosocks5.Succeeded, nil)
|
||||
if err := resp.Write(conn); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
h.logger.Debug(resp)
|
||||
log.Debug(resp)
|
||||
|
||||
t := time.Now()
|
||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), address)
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), address)
|
||||
handler.Transport(conn, cc)
|
||||
h.logger.
|
||||
WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(t),
|
||||
}).
|
||||
Infof("%s >-< %s", conn.RemoteAddr(), address)
|
||||
}).Infof("%s >-< %s", conn.RemoteAddr(), address)
|
||||
}
|
||||
|
@ -10,7 +10,6 @@ import (
|
||||
auth_util "github.com/go-gost/gost/pkg/common/util/auth"
|
||||
"github.com/go-gost/gost/pkg/common/util/socks"
|
||||
"github.com/go-gost/gost/pkg/handler"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
md "github.com/go-gost/gost/pkg/metadata"
|
||||
"github.com/go-gost/gost/pkg/registry"
|
||||
)
|
||||
@ -23,7 +22,6 @@ func init() {
|
||||
type socks5Handler struct {
|
||||
selector gosocks5.Selector
|
||||
router *chain.Router
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
@ -44,7 +42,6 @@ func (h *socks5Handler) Init(md md.Metadata) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
h.logger = h.options.Logger
|
||||
h.router = &chain.Router{
|
||||
Retries: h.options.Retries,
|
||||
Chain: h.options.Chain,
|
||||
@ -56,7 +53,7 @@ func (h *socks5Handler) Init(md md.Metadata) (err error) {
|
||||
h.selector = &serverSelector{
|
||||
Authenticator: auth_util.AuthFromUsers(h.options.Auths...),
|
||||
TLSConfig: h.options.TLSConfig,
|
||||
logger: h.logger,
|
||||
logger: h.options.Logger,
|
||||
noTLS: h.md.noTLS,
|
||||
}
|
||||
|
||||
@ -68,14 +65,14 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn) {
|
||||
|
||||
start := time.Now()
|
||||
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log := h.options.Logger.WithFields(map[string]interface{}{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
h.logger.WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
@ -87,30 +84,30 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn) {
|
||||
conn = gosocks5.ServerConn(conn, h.selector)
|
||||
req, err := gosocks5.ReadRequest(conn)
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
h.logger.Debug(req)
|
||||
log.Debug(req)
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
|
||||
address := req.Addr.String()
|
||||
|
||||
switch req.Cmd {
|
||||
case gosocks5.CmdConnect:
|
||||
h.handleConnect(ctx, conn, "tcp", address)
|
||||
h.handleConnect(ctx, conn, "tcp", address, log)
|
||||
case gosocks5.CmdBind:
|
||||
h.handleBind(ctx, conn, "tcp", address)
|
||||
h.handleBind(ctx, conn, "tcp", address, log)
|
||||
case socks.CmdMuxBind:
|
||||
h.handleMuxBind(ctx, conn, "tcp", address)
|
||||
h.handleMuxBind(ctx, conn, "tcp", address, log)
|
||||
case gosocks5.CmdUdp:
|
||||
h.handleUDP(ctx, conn)
|
||||
h.handleUDP(ctx, conn, log)
|
||||
case socks.CmdUDPTun:
|
||||
h.handleUDPTun(ctx, conn, "udp", address)
|
||||
h.handleUDPTun(ctx, conn, "udp", address, log)
|
||||
default:
|
||||
h.logger.Errorf("unknown cmd: %d", req.Cmd)
|
||||
log.Errorf("unknown cmd: %d", req.Cmd)
|
||||
resp := gosocks5.NewReply(gosocks5.CmdUnsupported, nil)
|
||||
resp.Write(conn)
|
||||
h.logger.Debug(resp)
|
||||
log.Debug(resp)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -9,43 +9,44 @@ import (
|
||||
"github.com/go-gost/gosocks5"
|
||||
"github.com/go-gost/gost/pkg/common/util/mux"
|
||||
"github.com/go-gost/gost/pkg/handler"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
)
|
||||
|
||||
func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, network, address string) {
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) {
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"dst": fmt.Sprintf("%s/%s", address, network),
|
||||
"cmd": "mbind",
|
||||
})
|
||||
|
||||
h.logger.Infof("%s >> %s", conn.RemoteAddr(), address)
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), address)
|
||||
|
||||
if !h.md.enableBind {
|
||||
reply := gosocks5.NewReply(gosocks5.NotAllowed, nil)
|
||||
reply.Write(conn)
|
||||
h.logger.Debug(reply)
|
||||
h.logger.Error("BIND is diabled")
|
||||
log.Debug(reply)
|
||||
log.Error("BIND is diabled")
|
||||
return
|
||||
}
|
||||
|
||||
h.muxBindLocal(ctx, conn, network, address)
|
||||
h.muxBindLocal(ctx, conn, network, address, log)
|
||||
}
|
||||
|
||||
func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, network, address string) {
|
||||
func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) {
|
||||
ln, err := net.Listen(network, address) // strict mode: if the port already in use, it will return error
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
reply := gosocks5.NewReply(gosocks5.Failure, nil)
|
||||
if err := reply.Write(conn); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
}
|
||||
h.logger.Debug(reply)
|
||||
log.Debug(reply)
|
||||
return
|
||||
}
|
||||
|
||||
socksAddr := gosocks5.Addr{}
|
||||
err = socksAddr.ParseFrom(ln.Addr().String())
|
||||
if err != nil {
|
||||
h.logger.Warn(err)
|
||||
log.Warn(err)
|
||||
}
|
||||
|
||||
// Issue: may not reachable when host has multi-interface
|
||||
@ -53,26 +54,26 @@ func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, network
|
||||
socksAddr.Type = 0
|
||||
reply := gosocks5.NewReply(gosocks5.Succeeded, &socksAddr)
|
||||
if err := reply.Write(conn); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
ln.Close()
|
||||
return
|
||||
}
|
||||
h.logger.Debug(reply)
|
||||
log.Debug(reply)
|
||||
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"bind": fmt.Sprintf("%s/%s", ln.Addr(), ln.Addr().Network()),
|
||||
})
|
||||
|
||||
h.logger.Debugf("bind on %s OK", ln.Addr())
|
||||
log.Debugf("bind on %s OK", ln.Addr())
|
||||
|
||||
h.serveMuxBind(ctx, conn, ln)
|
||||
h.serveMuxBind(ctx, conn, ln, log)
|
||||
}
|
||||
|
||||
func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net.Listener) {
|
||||
func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net.Listener, log logger.Logger) {
|
||||
// Upgrade connection to multiplex stream.
|
||||
session, err := mux.ClientSession(conn)
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
defer session.Close()
|
||||
@ -82,7 +83,7 @@ func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net.
|
||||
for {
|
||||
conn, err := session.Accept()
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
conn.Close() // we do not handle incoming connections.
|
||||
@ -92,17 +93,21 @@ func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net.
|
||||
for {
|
||||
rc, err := ln.Accept()
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
h.logger.Debugf("peer %s accepted", rc.RemoteAddr())
|
||||
log.Debugf("peer %s accepted", rc.RemoteAddr())
|
||||
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"local": rc.LocalAddr().String(),
|
||||
"remote": rc.RemoteAddr().String(),
|
||||
})
|
||||
sc, err := session.GetConn()
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
defer sc.Close()
|
||||
@ -113,18 +118,17 @@ func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net.
|
||||
addr.ParseFrom(c.RemoteAddr().String())
|
||||
reply := gosocks5.NewReply(gosocks5.Succeeded, &addr)
|
||||
if err := reply.Write(sc); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
h.logger.Debug(reply)
|
||||
log.Debug(reply)
|
||||
}
|
||||
|
||||
t := time.Now()
|
||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), c.RemoteAddr().String())
|
||||
log.Infof("%s <-> %s", c.LocalAddr(), c.RemoteAddr())
|
||||
handler.Transport(sc, c)
|
||||
h.logger.
|
||||
WithFields(map[string]interface{}{"duration": time.Since(t)}).
|
||||
Infof("%s >-< %s", conn.RemoteAddr(), c.RemoteAddr().String())
|
||||
log.WithFields(map[string]interface{}{"duration": time.Since(t)}).
|
||||
Infof("%s >-< %s", c.LocalAddr(), c.RemoteAddr())
|
||||
}(rc)
|
||||
}
|
||||
}
|
||||
|
@ -11,27 +11,28 @@ import (
|
||||
"github.com/go-gost/gosocks5"
|
||||
"github.com/go-gost/gost/pkg/common/util/socks"
|
||||
"github.com/go-gost/gost/pkg/handler"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
)
|
||||
|
||||
func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) {
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, log logger.Logger) {
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"cmd": "udp",
|
||||
})
|
||||
|
||||
if !h.md.enableUDP {
|
||||
reply := gosocks5.NewReply(gosocks5.NotAllowed, nil)
|
||||
reply.Write(conn)
|
||||
h.logger.Debug(reply)
|
||||
h.logger.Error("UDP relay is diabled")
|
||||
log.Debug(reply)
|
||||
log.Error("UDP relay is diabled")
|
||||
return
|
||||
}
|
||||
|
||||
cc, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
reply := gosocks5.NewReply(gosocks5.Failure, nil)
|
||||
reply.Write(conn)
|
||||
h.logger.Debug(reply)
|
||||
log.Debug(reply)
|
||||
return
|
||||
}
|
||||
defer cc.Close()
|
||||
@ -42,41 +43,40 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) {
|
||||
saddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) // replace the IP to the out-going interface's
|
||||
reply := gosocks5.NewReply(gosocks5.Succeeded, &saddr)
|
||||
if err := reply.Write(conn); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
h.logger.Debug(reply)
|
||||
log.Debug(reply)
|
||||
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"bind": fmt.Sprintf("%s/%s", cc.LocalAddr(), cc.LocalAddr().Network()),
|
||||
})
|
||||
h.logger.Debugf("bind on %s OK", cc.LocalAddr())
|
||||
log.Debugf("bind on %s OK", cc.LocalAddr())
|
||||
|
||||
// obtain a udp connection
|
||||
c, err := h.router.Dial(ctx, "udp", "") // UDP association
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
pc, ok := c.(net.PacketConn)
|
||||
if !ok {
|
||||
h.logger.Errorf("wrong connection type")
|
||||
log.Errorf("wrong connection type")
|
||||
return
|
||||
}
|
||||
|
||||
relay := handler.NewUDPRelay(socks.UDPConn(cc, h.md.udpBufferSize), pc).
|
||||
WithBypass(h.options.Bypass).
|
||||
WithLogger(h.logger)
|
||||
WithLogger(log)
|
||||
relay.SetBufferSize(h.md.udpBufferSize)
|
||||
|
||||
go relay.Run()
|
||||
|
||||
t := time.Now()
|
||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr())
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr())
|
||||
io.Copy(ioutil.Discard, conn)
|
||||
h.logger.
|
||||
WithFields(map[string]interface{}{"duration": time.Since(t)}).
|
||||
log.WithFields(map[string]interface{}{"duration": time.Since(t)}).
|
||||
Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr())
|
||||
}
|
||||
|
@ -8,54 +8,53 @@ import (
|
||||
"github.com/go-gost/gosocks5"
|
||||
"github.com/go-gost/gost/pkg/common/util/socks"
|
||||
"github.com/go-gost/gost/pkg/handler"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
)
|
||||
|
||||
func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network, address string) {
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) {
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"cmd": "udp-tun",
|
||||
})
|
||||
|
||||
if !h.md.enableUDP {
|
||||
reply := gosocks5.NewReply(gosocks5.NotAllowed, nil)
|
||||
reply.Write(conn)
|
||||
h.logger.Debug(reply)
|
||||
h.logger.Error("UDP relay is diabled")
|
||||
log.Debug(reply)
|
||||
log.Error("UDP relay is diabled")
|
||||
return
|
||||
}
|
||||
|
||||
// dummy bind
|
||||
reply := gosocks5.NewReply(gosocks5.Succeeded, nil)
|
||||
if err := reply.Write(conn); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
h.logger.Debug(reply)
|
||||
log.Debug(reply)
|
||||
|
||||
// obtain a udp connection
|
||||
c, err := h.router.Dial(ctx, "udp", "") // UDP association
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
pc, ok := c.(net.PacketConn)
|
||||
if !ok {
|
||||
h.logger.Errorf("wrong connection type")
|
||||
log.Errorf("wrong connection type")
|
||||
return
|
||||
}
|
||||
|
||||
relay := handler.NewUDPRelay(socks.UDPTunServerConn(conn), pc).
|
||||
WithBypass(h.options.Bypass).
|
||||
WithLogger(h.logger)
|
||||
WithLogger(log)
|
||||
relay.SetBufferSize(h.md.udpBufferSize)
|
||||
|
||||
t := time.Now()
|
||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||
relay.Run()
|
||||
h.logger.
|
||||
WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(t),
|
||||
}).
|
||||
Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||
}).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||
}
|
||||
|
@ -11,7 +11,6 @@ import (
|
||||
"github.com/go-gost/gost/pkg/chain"
|
||||
"github.com/go-gost/gost/pkg/common/util/ss"
|
||||
"github.com/go-gost/gost/pkg/handler"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
md "github.com/go-gost/gost/pkg/metadata"
|
||||
"github.com/go-gost/gost/pkg/registry"
|
||||
"github.com/shadowsocks/go-shadowsocks2/core"
|
||||
@ -24,7 +23,6 @@ func init() {
|
||||
type ssHandler struct {
|
||||
cipher core.Cipher
|
||||
router *chain.Router
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
@ -60,7 +58,6 @@ func (h *ssHandler) Init(md md.Metadata) (err error) {
|
||||
Hosts: h.options.Hosts,
|
||||
Logger: h.options.Logger,
|
||||
}
|
||||
h.logger = h.options.Logger
|
||||
|
||||
return
|
||||
}
|
||||
@ -69,14 +66,14 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
start := time.Now()
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log := h.options.Logger.WithFields(map[string]interface{}{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
h.logger.WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
@ -91,19 +88,19 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
|
||||
addr := &gosocks5.Addr{}
|
||||
if _, err := addr.ReadFrom(conn); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
io.Copy(ioutil.Discard, conn)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"dst": addr.String(),
|
||||
})
|
||||
|
||||
h.logger.Infof("%s >> %s", conn.RemoteAddr(), addr)
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), addr)
|
||||
|
||||
if h.options.Bypass != nil && h.options.Bypass.Contains(addr.String()) {
|
||||
h.logger.Info("bypass: ", addr.String())
|
||||
log.Info("bypass: ", addr.String())
|
||||
return
|
||||
}
|
||||
|
||||
@ -114,11 +111,9 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
defer cc.Close()
|
||||
|
||||
t := time.Now()
|
||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr)
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
|
||||
handler.Transport(conn, cc)
|
||||
h.logger.
|
||||
WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(t),
|
||||
}).
|
||||
Infof("%s >-< %s", conn.RemoteAddr(), addr)
|
||||
}).Infof("%s >-< %s", conn.RemoteAddr(), addr)
|
||||
}
|
||||
|
@ -23,7 +23,6 @@ func init() {
|
||||
type ssuHandler struct {
|
||||
cipher core.Cipher
|
||||
router *chain.Router
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
@ -60,7 +59,6 @@ func (h *ssuHandler) Init(md md.Metadata) (err error) {
|
||||
Hosts: h.options.Hosts,
|
||||
Logger: h.options.Logger,
|
||||
}
|
||||
h.logger = h.options.Logger
|
||||
|
||||
return
|
||||
}
|
||||
@ -69,14 +67,14 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
start := time.Now()
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log := h.options.Logger.WithFields(map[string]interface{}{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
h.logger.WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
@ -99,26 +97,25 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
// obtain a udp connection
|
||||
c, err := h.router.Dial(ctx, "udp", "") // UDP association
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
cc, ok := c.(net.PacketConn)
|
||||
if !ok {
|
||||
h.logger.Errorf("wrong connection type")
|
||||
log.Errorf("wrong connection type")
|
||||
return
|
||||
}
|
||||
|
||||
t := time.Now()
|
||||
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr())
|
||||
h.relayPacket(pc, cc)
|
||||
h.logger.
|
||||
WithFields(map[string]interface{}{"duration": time.Since(t)}).
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr())
|
||||
h.relayPacket(pc, cc, log)
|
||||
log.WithFields(map[string]interface{}{"duration": time.Since(t)}).
|
||||
Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr())
|
||||
}
|
||||
|
||||
func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) {
|
||||
func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn, log logger.Logger) (err error) {
|
||||
bufSize := h.md.bufferSize
|
||||
errc := make(chan error, 2)
|
||||
|
||||
@ -134,7 +131,7 @@ func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) {
|
||||
}
|
||||
|
||||
if h.options.Bypass != nil && h.options.Bypass.Contains(addr.String()) {
|
||||
h.logger.Warn("bypass: ", addr)
|
||||
log.Warn("bypass: ", addr)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -142,7 +139,7 @@ func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
h.logger.Debugf("%s >>> %s data: %d",
|
||||
log.Debugf("%s >>> %s data: %d",
|
||||
pc2.LocalAddr(), addr, n)
|
||||
return nil
|
||||
}()
|
||||
@ -166,7 +163,7 @@ func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) {
|
||||
}
|
||||
|
||||
if h.options.Bypass != nil && h.options.Bypass.Contains(raddr.String()) {
|
||||
h.logger.Warn("bypass: ", raddr)
|
||||
log.Warn("bypass: ", raddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -174,7 +171,7 @@ func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
h.logger.Debugf("%s <<< %s data: %d",
|
||||
log.Debugf("%s <<< %s data: %d",
|
||||
pc2.LocalAddr(), raddr, n)
|
||||
return nil
|
||||
}()
|
||||
|
238
pkg/handler/sshd/handler.go
Normal file
238
pkg/handler/sshd/handler.go
Normal file
@ -0,0 +1,238 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/pkg/chain"
|
||||
"github.com/go-gost/gost/pkg/handler"
|
||||
sshd_util "github.com/go-gost/gost/pkg/internal/util/sshd"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
md "github.com/go-gost/gost/pkg/metadata"
|
||||
"github.com/go-gost/gost/pkg/registry"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// Applicable SSH Request types for Port Forwarding - RFC 4254 7.X
|
||||
const (
|
||||
ForwardedTCPReturnRequest = "forwarded-tcpip" // RFC 4254 7.2
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.RegisterHandler("sshd", NewHandler)
|
||||
}
|
||||
|
||||
type forwardHandler struct {
|
||||
router *chain.Router
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
|
||||
func NewHandler(opts ...handler.Option) handler.Handler {
|
||||
options := handler.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
return &forwardHandler{
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *forwardHandler) Init(md md.Metadata) (err error) {
|
||||
if err = h.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
h.router = &chain.Router{
|
||||
Retries: h.options.Retries,
|
||||
Chain: h.options.Chain,
|
||||
Resolver: h.options.Resolver,
|
||||
Hosts: h.options.Hosts,
|
||||
Logger: h.options.Logger,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
log := h.options.Logger.WithFields(map[string]interface{}{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
switch cc := conn.(type) {
|
||||
case *sshd_util.DirectForwardConn:
|
||||
h.handleDirectForward(ctx, cc, log)
|
||||
case *sshd_util.RemoteForwardConn:
|
||||
h.handleRemoteForward(ctx, cc, log)
|
||||
default:
|
||||
log.Error("wrong connection type")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (h *forwardHandler) handleDirectForward(ctx context.Context, conn *sshd_util.DirectForwardConn, log logger.Logger) {
|
||||
targetAddr := conn.DstAddr()
|
||||
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"dst": fmt.Sprintf("%s/%s", targetAddr, "tcp"),
|
||||
"cmd": "connect",
|
||||
})
|
||||
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), targetAddr)
|
||||
|
||||
if h.options.Bypass != nil && h.options.Bypass.Contains(targetAddr) {
|
||||
log.Infof("bypass %s", targetAddr)
|
||||
return
|
||||
}
|
||||
|
||||
cc, err := h.router.Dial(ctx, "tcp", targetAddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer cc.Close()
|
||||
|
||||
t := time.Now()
|
||||
log.Infof("%s <-> %s", cc.LocalAddr(), targetAddr)
|
||||
handler.Transport(conn, cc)
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(t),
|
||||
}).Infof("%s >-< %s", cc.LocalAddr(), targetAddr)
|
||||
}
|
||||
|
||||
func (h *forwardHandler) handleRemoteForward(ctx context.Context, conn *sshd_util.RemoteForwardConn, log logger.Logger) {
|
||||
req := conn.Request()
|
||||
|
||||
t := tcpipForward{}
|
||||
ssh.Unmarshal(req.Payload, &t)
|
||||
|
||||
network := "tcp"
|
||||
addr := net.JoinHostPort(t.Host, strconv.Itoa(int(t.Port)))
|
||||
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"dst": fmt.Sprintf("%s/%s", addr, network),
|
||||
"cmd": "bind",
|
||||
})
|
||||
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), addr)
|
||||
|
||||
// tie to the client connection
|
||||
ln, err := net.Listen(network, addr)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
req.Reply(false, nil)
|
||||
return
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"bind": fmt.Sprintf("%s/%s", ln.Addr(), ln.Addr().Network()),
|
||||
})
|
||||
log.Debugf("bind on %s OK", ln.Addr())
|
||||
|
||||
err = func() error {
|
||||
if t.Port == 0 && req.WantReply { // Client sent port 0. let them know which port is actually being used
|
||||
_, port, err := getHostPortFromAddr(ln.Addr())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var b [4]byte
|
||||
binary.BigEndian.PutUint32(b[:], uint32(port))
|
||||
t.Port = uint32(port)
|
||||
return req.Reply(true, b[:])
|
||||
}
|
||||
return req.Reply(true, nil)
|
||||
}()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
sshConn := conn.Conn()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
cc, err := ln.Accept()
|
||||
if err != nil { // Unable to accept new connection - listener is likely closed
|
||||
return
|
||||
}
|
||||
|
||||
go func(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
log := log.WithFields(map[string]interface{}{
|
||||
"local": conn.LocalAddr().String(),
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
})
|
||||
|
||||
p := directForward{}
|
||||
var err error
|
||||
|
||||
var portnum int
|
||||
p.Host1 = t.Host
|
||||
p.Port1 = t.Port
|
||||
p.Host2, portnum, err = getHostPortFromAddr(conn.RemoteAddr())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.Port2 = uint32(portnum)
|
||||
ch, reqs, err := sshConn.OpenChannel(ForwardedTCPReturnRequest, ssh.Marshal(p))
|
||||
if err != nil {
|
||||
log.Error("open forwarded channel: ", err)
|
||||
return
|
||||
}
|
||||
defer ch.Close()
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
t := time.Now()
|
||||
log.Infof("%s <-> %s", conn.LocalAddr(), conn.RemoteAddr())
|
||||
handler.Transport(ch, conn)
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(t),
|
||||
}).Infof("%s >-< %s", conn.LocalAddr(), conn.RemoteAddr())
|
||||
}(cc)
|
||||
}
|
||||
}()
|
||||
|
||||
tm := time.Now()
|
||||
log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
|
||||
<-conn.Done()
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(tm),
|
||||
}).Infof("%s >-< %s", conn.RemoteAddr(), addr)
|
||||
}
|
||||
|
||||
func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) {
|
||||
host, portString, err := net.SplitHostPort(addr.String())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
port, err = strconv.Atoi(portString)
|
||||
return
|
||||
}
|
||||
|
||||
// directForward is structure for RFC 4254 7.2 - can be used for "forwarded-tcpip" and "direct-tcpip"
|
||||
type directForward struct {
|
||||
Host1 string
|
||||
Port1 uint32
|
||||
Host2 string
|
||||
Port2 uint32
|
||||
}
|
||||
|
||||
func (p directForward) String() string {
|
||||
return fmt.Sprintf("%s:%d -> %s:%d", p.Host2, p.Port2, p.Host1, p.Port1)
|
||||
}
|
||||
|
||||
// tcpipForward is structure for RFC 4254 7.1 "tcpip-forward" request
|
||||
type tcpipForward struct {
|
||||
Host string
|
||||
Port uint32
|
||||
}
|
12
pkg/handler/sshd/metadata.go
Normal file
12
pkg/handler/sshd/metadata.go
Normal file
@ -0,0 +1,12 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
mdata "github.com/go-gost/gost/pkg/metadata"
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
}
|
||||
|
||||
func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) {
|
||||
return
|
||||
}
|
@ -33,7 +33,6 @@ type tapHandler struct {
|
||||
exit chan struct{}
|
||||
cipher core.Cipher
|
||||
router *chain.Router
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
@ -71,7 +70,6 @@ func (h *tapHandler) Init(md md.Metadata) (err error) {
|
||||
Hosts: h.options.Hosts,
|
||||
Logger: h.options.Logger,
|
||||
}
|
||||
h.logger = h.options.Logger
|
||||
|
||||
return
|
||||
}
|
||||
@ -85,21 +83,22 @@ func (h *tapHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
defer os.Exit(0)
|
||||
defer conn.Close()
|
||||
|
||||
log := h.options.Logger
|
||||
cc, ok := conn.(*tap_util.Conn)
|
||||
if !ok || cc.Config() == nil {
|
||||
h.logger.Error("invalid connection")
|
||||
log.Error("invalid connection")
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
h.logger.WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
@ -112,19 +111,19 @@ func (h *tapHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
if target != nil {
|
||||
raddr, err = net.ResolveUDPAddr(network, target.Addr())
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"dst": fmt.Sprintf("%s/%s", raddr.String(), raddr.Network()),
|
||||
})
|
||||
h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr())
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr())
|
||||
}
|
||||
|
||||
h.handleLoop(ctx, conn, raddr, cc.Config())
|
||||
h.handleLoop(ctx, conn, raddr, cc.Config(), log)
|
||||
}
|
||||
|
||||
func (h *tapHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Addr, config *tap_util.Config) {
|
||||
func (h *tapHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Addr, config *tap_util.Config, log logger.Logger) {
|
||||
var tempDelay time.Duration
|
||||
for {
|
||||
err := func() error {
|
||||
@ -154,10 +153,10 @@ func (h *tapHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add
|
||||
pc = h.cipher.PacketConn(pc)
|
||||
}
|
||||
|
||||
return h.transport(conn, pc, addr)
|
||||
return h.transport(conn, pc, addr, log)
|
||||
}()
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
select {
|
||||
@ -183,7 +182,7 @@ func (h *tapHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add
|
||||
|
||||
}
|
||||
|
||||
func (h *tapHandler) transport(tap net.Conn, conn net.PacketConn, raddr net.Addr) error {
|
||||
func (h *tapHandler) transport(tap net.Conn, conn net.PacketConn, raddr net.Addr, log logger.Logger) error {
|
||||
errc := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
@ -205,7 +204,7 @@ func (h *tapHandler) transport(tap net.Conn, conn net.PacketConn, raddr net.Addr
|
||||
dst := waterutil.MACDestination((*b)[:n])
|
||||
eType := etherType(waterutil.MACEthertype((*b)[:n]))
|
||||
|
||||
h.logger.Debugf("%s >> %s %s %d", src, dst, eType, n)
|
||||
log.Debugf("%s >> %s %s %d", src, dst, eType, n)
|
||||
|
||||
// client side, deliver frame directly.
|
||||
if raddr != nil {
|
||||
@ -227,7 +226,7 @@ func (h *tapHandler) transport(tap net.Conn, conn net.PacketConn, raddr net.Addr
|
||||
addr = v.(net.Addr)
|
||||
}
|
||||
if addr == nil {
|
||||
h.logger.Warnf("no route for %s -> %s %s %d", src, dst, eType, n)
|
||||
log.Warnf("no route for %s -> %s %s %d", src, dst, eType, n)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -261,7 +260,7 @@ func (h *tapHandler) transport(tap net.Conn, conn net.PacketConn, raddr net.Addr
|
||||
dst := waterutil.MACDestination((*b)[:n])
|
||||
eType := etherType(waterutil.MACEthertype((*b)[:n]))
|
||||
|
||||
h.logger.Debugf("%s >> %s %s %d", src, dst, eType, n)
|
||||
log.Debugf("%s >> %s %s %d", src, dst, eType, n)
|
||||
|
||||
// client side, deliver frame to tap device.
|
||||
if raddr != nil {
|
||||
@ -273,12 +272,12 @@ func (h *tapHandler) transport(tap net.Conn, conn net.PacketConn, raddr net.Addr
|
||||
rkey := hwAddrToTapRouteKey(src)
|
||||
if actual, loaded := h.routes.LoadOrStore(rkey, addr); loaded {
|
||||
if actual.(net.Addr).String() != addr.String() {
|
||||
h.logger.Debugf("update route: %s -> %s (old %s)",
|
||||
log.Debugf("update route: %s -> %s (old %s)",
|
||||
src, addr, actual.(net.Addr))
|
||||
h.routes.Store(rkey, addr)
|
||||
}
|
||||
} else {
|
||||
h.logger.Debugf("new route: %s -> %s", src, addr)
|
||||
log.Debugf("new route: %s -> %s", src, addr)
|
||||
}
|
||||
|
||||
if waterutil.IsBroadcast(dst) {
|
||||
@ -291,7 +290,7 @@ func (h *tapHandler) transport(tap net.Conn, conn net.PacketConn, raddr net.Addr
|
||||
}
|
||||
|
||||
if v, ok := h.routes.Load(hwAddrToTapRouteKey(dst)); ok {
|
||||
h.logger.Debugf("find route: %s -> %s", dst, v)
|
||||
log.Debugf("find route: %s -> %s", dst, v)
|
||||
_, err := conn.WriteTo((*b)[:n], v.(net.Addr))
|
||||
return err
|
||||
}
|
||||
|
@ -35,7 +35,6 @@ type tunHandler struct {
|
||||
exit chan struct{}
|
||||
cipher core.Cipher
|
||||
router *chain.Router
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
@ -73,7 +72,6 @@ func (h *tunHandler) Init(md md.Metadata) (err error) {
|
||||
Hosts: h.options.Hosts,
|
||||
Logger: h.options.Logger,
|
||||
}
|
||||
h.logger = h.options.Logger
|
||||
|
||||
return
|
||||
}
|
||||
@ -87,21 +85,23 @@ func (h *tunHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
defer os.Exit(0)
|
||||
defer conn.Close()
|
||||
|
||||
log := h.options.Logger
|
||||
|
||||
cc, ok := conn.(*tun_util.Conn)
|
||||
if !ok || cc.Config() == nil {
|
||||
h.logger.Error("invalid connection")
|
||||
log.Error("invalid connection")
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
h.logger.WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
@ -114,19 +114,19 @@ func (h *tunHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
if target != nil {
|
||||
raddr, err = net.ResolveUDPAddr(network, target.Addr())
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log = log.WithFields(map[string]interface{}{
|
||||
"dst": fmt.Sprintf("%s/%s", raddr.String(), raddr.Network()),
|
||||
})
|
||||
h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr())
|
||||
log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr())
|
||||
}
|
||||
|
||||
h.handleLoop(ctx, conn, raddr, cc.Config())
|
||||
h.handleLoop(ctx, conn, raddr, cc.Config(), log)
|
||||
}
|
||||
|
||||
func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Addr, config *tun_util.Config) {
|
||||
func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Addr, config *tun_util.Config, log logger.Logger) {
|
||||
var tempDelay time.Duration
|
||||
for {
|
||||
err := func() error {
|
||||
@ -155,10 +155,10 @@ func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add
|
||||
pc = h.cipher.PacketConn(pc)
|
||||
}
|
||||
|
||||
return h.transport(conn, pc, addr)
|
||||
return h.transport(conn, pc, addr, log)
|
||||
}()
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
select {
|
||||
@ -184,7 +184,7 @@ func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add
|
||||
|
||||
}
|
||||
|
||||
func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr) error {
|
||||
func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr, log logger.Logger) error {
|
||||
errc := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
@ -206,10 +206,10 @@ func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr
|
||||
if waterutil.IsIPv4((*b)[:n]) {
|
||||
header, err := ipv4.ParseHeader((*b)[:n])
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return nil
|
||||
}
|
||||
h.logger.Debugf("%s >> %s %-4s %d/%-4d %-4x %d",
|
||||
log.Debugf("%s >> %s %-4s %d/%-4d %-4x %d",
|
||||
header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])),
|
||||
header.Len, header.TotalLen, header.ID, header.Flags)
|
||||
|
||||
@ -217,17 +217,17 @@ func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr
|
||||
} else if waterutil.IsIPv6((*b)[:n]) {
|
||||
header, err := ipv6.ParseHeader((*b)[:n])
|
||||
if err != nil {
|
||||
h.logger.Warn(err)
|
||||
log.Warn(err)
|
||||
return nil
|
||||
}
|
||||
h.logger.Debugf("%s >> %s %s %d %d",
|
||||
log.Debugf("%s >> %s %s %d %d",
|
||||
header.Src, header.Dst,
|
||||
ipProtocol(waterutil.IPProtocol(header.NextHeader)),
|
||||
header.PayloadLen, header.TrafficClass)
|
||||
|
||||
src, dst = header.Src, header.Dst
|
||||
} else {
|
||||
h.logger.Warn("unknown packet, discarded")
|
||||
log.Warn("unknown packet, discarded")
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -239,11 +239,11 @@ func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr
|
||||
|
||||
addr := h.findRouteFor(dst)
|
||||
if addr == nil {
|
||||
h.logger.Warnf("no route for %s -> %s", src, dst)
|
||||
log.Warnf("no route for %s -> %s", src, dst)
|
||||
return nil
|
||||
}
|
||||
|
||||
h.logger.Debugf("find route: %s -> %s", dst, addr)
|
||||
log.Debugf("find route: %s -> %s", dst, addr)
|
||||
|
||||
if _, err := conn.WriteTo((*b)[:n], addr); err != nil {
|
||||
return err
|
||||
@ -274,11 +274,11 @@ func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr
|
||||
if waterutil.IsIPv4((*b)[:n]) {
|
||||
header, err := ipv4.ParseHeader((*b)[:n])
|
||||
if err != nil {
|
||||
h.logger.Warn(err)
|
||||
log.Warn(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
h.logger.Debugf("%s >> %s %-4s %d/%-4d %-4x %d",
|
||||
log.Debugf("%s >> %s %-4s %d/%-4d %-4x %d",
|
||||
header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])),
|
||||
header.Len, header.TotalLen, header.ID, header.Flags)
|
||||
|
||||
@ -286,18 +286,18 @@ func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr
|
||||
} else if waterutil.IsIPv6((*b)[:n]) {
|
||||
header, err := ipv6.ParseHeader((*b)[:n])
|
||||
if err != nil {
|
||||
h.logger.Warn(err)
|
||||
log.Warn(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
h.logger.Debugf("%s > %s %s %d %d",
|
||||
log.Debugf("%s > %s %s %d %d",
|
||||
header.Src, header.Dst,
|
||||
ipProtocol(waterutil.IPProtocol(header.NextHeader)),
|
||||
header.PayloadLen, header.TrafficClass)
|
||||
|
||||
src, dst = header.Src, header.Dst
|
||||
} else {
|
||||
h.logger.Warn("unknown packet, discarded")
|
||||
log.Warn("unknown packet, discarded")
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -310,16 +310,16 @@ func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr
|
||||
rkey := ipToTunRouteKey(src)
|
||||
if actual, loaded := h.routes.LoadOrStore(rkey, addr); loaded {
|
||||
if actual.(net.Addr).String() != addr.String() {
|
||||
h.logger.Debugf("update route: %s -> %s (old %s)",
|
||||
log.Debugf("update route: %s -> %s (old %s)",
|
||||
src, addr, actual.(net.Addr))
|
||||
h.routes.Store(rkey, addr)
|
||||
}
|
||||
} else {
|
||||
h.logger.Warnf("no route for %s -> %s", src, addr)
|
||||
log.Warnf("no route for %s -> %s", src, addr)
|
||||
}
|
||||
|
||||
if addr := h.findRouteFor(dst); addr != nil {
|
||||
h.logger.Debugf("find route: %s -> %s", dst, addr)
|
||||
log.Debugf("find route: %s -> %s", dst, addr)
|
||||
|
||||
_, err := conn.WriteTo((*b)[:n], addr)
|
||||
return err
|
||||
|
118
pkg/internal/util/sshd/conn.go
Normal file
118
pkg/internal/util/sshd/conn.go
Normal file
@ -0,0 +1,118 @@
|
||||
package sshd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type DirectForwardConn struct {
|
||||
conn ssh.Conn
|
||||
channel ssh.Channel
|
||||
dstAddr string
|
||||
}
|
||||
|
||||
func NewDirectForwardConn(conn ssh.Conn, channel ssh.Channel, dstAddr string) net.Conn {
|
||||
return &DirectForwardConn{
|
||||
conn: conn,
|
||||
channel: channel,
|
||||
dstAddr: dstAddr,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *DirectForwardConn) Read(b []byte) (n int, err error) {
|
||||
return c.channel.Read(b)
|
||||
}
|
||||
|
||||
func (c *DirectForwardConn) Write(b []byte) (n int, err error) {
|
||||
return c.channel.Write(b)
|
||||
}
|
||||
|
||||
func (c *DirectForwardConn) Close() error {
|
||||
return c.channel.Close()
|
||||
}
|
||||
|
||||
func (c *DirectForwardConn) LocalAddr() net.Addr {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *DirectForwardConn) RemoteAddr() net.Addr {
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (c *DirectForwardConn) SetDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *DirectForwardConn) SetReadDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *DirectForwardConn) SetWriteDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *DirectForwardConn) DstAddr() string {
|
||||
return c.dstAddr
|
||||
}
|
||||
|
||||
type RemoteForwardConn struct {
|
||||
ctx context.Context
|
||||
conn ssh.Conn
|
||||
req *ssh.Request
|
||||
}
|
||||
|
||||
func NewRemoteForwardConn(ctx context.Context, conn ssh.Conn, req *ssh.Request) net.Conn {
|
||||
return &RemoteForwardConn{
|
||||
ctx: ctx,
|
||||
conn: conn,
|
||||
req: req,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RemoteForwardConn) Conn() ssh.Conn {
|
||||
return c.conn
|
||||
}
|
||||
|
||||
func (c *RemoteForwardConn) Request() *ssh.Request {
|
||||
return c.req
|
||||
}
|
||||
|
||||
func (c *RemoteForwardConn) Read(b []byte) (n int, err error) {
|
||||
return 0, &net.OpError{Op: "read", Net: "nop", Source: nil, Addr: nil, Err: errors.New("read not supported")}
|
||||
}
|
||||
|
||||
func (c *RemoteForwardConn) Write(b []byte) (n int, err error) {
|
||||
return 0, &net.OpError{Op: "write", Net: "nop", Source: nil, Addr: nil, Err: errors.New("write not supported")}
|
||||
}
|
||||
|
||||
func (c *RemoteForwardConn) Close() error {
|
||||
return &net.OpError{Op: "close", Net: "nop", Source: nil, Addr: nil, Err: errors.New("close not supported")}
|
||||
}
|
||||
|
||||
func (c *RemoteForwardConn) LocalAddr() net.Addr {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *RemoteForwardConn) RemoteAddr() net.Addr {
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (c *RemoteForwardConn) SetDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *RemoteForwardConn) SetReadDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *RemoteForwardConn) SetWriteDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *RemoteForwardConn) Done() <-chan struct{} {
|
||||
return c.ctx.Done()
|
||||
}
|
@ -15,6 +15,7 @@ import (
|
||||
|
||||
func init() {
|
||||
registry.RegisterListener("http3", NewListener)
|
||||
registry.RegisterListener("h3", NewListener)
|
||||
}
|
||||
|
||||
type phtListener struct {
|
||||
|
@ -3,6 +3,7 @@ package ssh
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
auth_util "github.com/go-gost/gost/pkg/common/util/auth"
|
||||
ssh_util "github.com/go-gost/gost/pkg/internal/util/ssh"
|
||||
@ -29,13 +30,14 @@ type sshListener struct {
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
options := &listener.Options{}
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
opt(&options)
|
||||
}
|
||||
return &sshListener{
|
||||
addr: options.Addr,
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
@ -96,6 +98,14 @@ func (l *sshListener) listenLoop() {
|
||||
}
|
||||
|
||||
func (l *sshListener) serveConn(conn net.Conn) {
|
||||
start := time.Now()
|
||||
l.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
l.logger.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
|
||||
sc, chans, reqs, err := ssh.NewServerConn(conn, l.config)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
@ -122,8 +132,9 @@ func (l *sshListener) serveConn(conn net.Conn) {
|
||||
select {
|
||||
case l.cqueue <- cc:
|
||||
default:
|
||||
cc.Close()
|
||||
l.logger.Warnf("connection queue is full, client %s discarded", conn.RemoteAddr())
|
||||
newChannel.Reject(ssh.ResourceShortage, "connection queue is full")
|
||||
cc.Close()
|
||||
}
|
||||
|
||||
default:
|
||||
|
199
pkg/listener/sshd/listener.go
Normal file
199
pkg/listener/sshd/listener.go
Normal file
@ -0,0 +1,199 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
auth_util "github.com/go-gost/gost/pkg/common/util/auth"
|
||||
ssh_util "github.com/go-gost/gost/pkg/internal/util/ssh"
|
||||
sshd_util "github.com/go-gost/gost/pkg/internal/util/sshd"
|
||||
"github.com/go-gost/gost/pkg/listener"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
md "github.com/go-gost/gost/pkg/metadata"
|
||||
"github.com/go-gost/gost/pkg/registry"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// Applicable SSH Request types for Port Forwarding - RFC 4254 7.X
|
||||
const (
|
||||
DirectForwardRequest = "direct-tcpip" // RFC 4254 7.2
|
||||
RemoteForwardRequest = "tcpip-forward" // RFC 4254 7.1
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.RegisterListener("sshd", NewListener)
|
||||
}
|
||||
|
||||
type sshdListener struct {
|
||||
addr string
|
||||
net.Listener
|
||||
config *ssh.ServerConfig
|
||||
cqueue chan net.Conn
|
||||
errChan chan error
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options listener.Options
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return &sshdListener{
|
||||
addr: options.Addr,
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *sshdListener) Init(md md.Metadata) (err error) {
|
||||
if err = l.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", l.addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l.Listener = ln
|
||||
|
||||
authenticator := auth_util.AuthFromUsers(l.options.Auths...)
|
||||
config := &ssh.ServerConfig{
|
||||
PasswordCallback: ssh_util.PasswordCallback(authenticator),
|
||||
PublicKeyCallback: ssh_util.PublicKeyCallback(l.md.authorizedKeys),
|
||||
}
|
||||
config.AddHostKey(l.md.signer)
|
||||
if authenticator == nil && len(l.md.authorizedKeys) == 0 {
|
||||
config.NoClientAuth = true
|
||||
}
|
||||
|
||||
l.config = config
|
||||
l.cqueue = make(chan net.Conn, l.md.backlog)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
go l.listenLoop()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *sshdListener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.cqueue:
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *sshdListener) listenLoop() {
|
||||
for {
|
||||
conn, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
l.logger.Error("accept:", err)
|
||||
l.errChan <- err
|
||||
close(l.errChan)
|
||||
return
|
||||
}
|
||||
go l.serveConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *sshdListener) serveConn(conn net.Conn) {
|
||||
start := time.Now()
|
||||
l.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
l.logger.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
|
||||
sc, chans, reqs, err := ssh.NewServerConn(conn, l.config)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
defer sc.Close()
|
||||
|
||||
go func() {
|
||||
for newChannel := range chans {
|
||||
// Check the type of channel
|
||||
t := newChannel.ChannelType()
|
||||
switch t {
|
||||
case DirectForwardRequest:
|
||||
channel, requests, err := newChannel.Accept()
|
||||
if err != nil {
|
||||
l.logger.Warnf("could not accept channel: %s", err.Error())
|
||||
continue
|
||||
}
|
||||
p := directForward{}
|
||||
ssh.Unmarshal(newChannel.ExtraData(), &p)
|
||||
|
||||
l.logger.Debug(p.String())
|
||||
|
||||
if p.Host1 == "<nil>" {
|
||||
p.Host1 = ""
|
||||
}
|
||||
|
||||
go ssh.DiscardRequests(requests)
|
||||
cc := sshd_util.NewDirectForwardConn(sc, channel, net.JoinHostPort(p.Host1, strconv.Itoa(int(p.Port1))))
|
||||
|
||||
select {
|
||||
case l.cqueue <- cc:
|
||||
default:
|
||||
l.logger.Warnf("connection queue is full, client %s discarded", conn.RemoteAddr())
|
||||
newChannel.Reject(ssh.ResourceShortage, "connection queue is full")
|
||||
cc.Close()
|
||||
}
|
||||
|
||||
default:
|
||||
l.logger.Warnf("unsupported channel type: %s", t)
|
||||
newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unsupported channel type: %s", t))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
for req := range reqs {
|
||||
switch req.Type {
|
||||
case RemoteForwardRequest:
|
||||
cc := sshd_util.NewRemoteForwardConn(ctx, sc, req)
|
||||
|
||||
select {
|
||||
case l.cqueue <- cc:
|
||||
default:
|
||||
l.logger.Warnf("connection queue is full, client %s discarded", conn.RemoteAddr())
|
||||
req.Reply(false, []byte("connection queue is full"))
|
||||
cc.Close()
|
||||
}
|
||||
default:
|
||||
l.logger.Warnf("unsupported request type: %s, want reply: %v", req.Type, req.WantReply)
|
||||
req.Reply(false, nil)
|
||||
}
|
||||
}
|
||||
}()
|
||||
sc.Wait()
|
||||
}
|
||||
|
||||
// directForward is structure for RFC 4254 7.2 - can be used for "forwarded-tcpip" and "direct-tcpip"
|
||||
type directForward struct {
|
||||
Host1 string
|
||||
Port1 uint32
|
||||
Host2 string
|
||||
Port2 uint32
|
||||
}
|
||||
|
||||
func (p directForward) String() string {
|
||||
return fmt.Sprintf("%s:%d -> %s:%d", p.Host2, p.Port2, p.Host1, p.Port1)
|
||||
}
|
@ -9,16 +9,22 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultBacklog = 128
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
signer ssh.Signer
|
||||
authorizedKeys map[string]bool
|
||||
backlog int
|
||||
}
|
||||
|
||||
func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) {
|
||||
func (l *sshdListener) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
authorizedKeys = "authorizedKeys"
|
||||
privateKeyFile = "privateKeyFile"
|
||||
passphrase = "passphrase"
|
||||
backlog = "backlog"
|
||||
)
|
||||
|
||||
if key := mdata.GetString(md, privateKeyFile); key != "" {
|
||||
@ -29,20 +35,20 @@ func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) {
|
||||
|
||||
pp := mdata.GetString(md, passphrase)
|
||||
if pp == "" {
|
||||
h.md.signer, err = ssh.ParsePrivateKey(data)
|
||||
l.md.signer, err = ssh.ParsePrivateKey(data)
|
||||
} else {
|
||||
h.md.signer, err = ssh.ParsePrivateKeyWithPassphrase(data, []byte(pp))
|
||||
l.md.signer, err = ssh.ParsePrivateKeyWithPassphrase(data, []byte(pp))
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if h.md.signer == nil {
|
||||
if l.md.signer == nil {
|
||||
signer, err := ssh.NewSignerFromKey(tls_util.DefaultConfig.Clone().Certificates[0].PrivateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.md.signer = signer
|
||||
l.md.signer = signer
|
||||
}
|
||||
|
||||
if name := mdata.GetString(md, authorizedKeys); name != "" {
|
||||
@ -50,7 +56,12 @@ func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.md.authorizedKeys = m
|
||||
l.md.authorizedKeys = m
|
||||
}
|
||||
|
||||
l.md.backlog = mdata.GetInt(md, backlog)
|
||||
if l.md.backlog <= 0 {
|
||||
l.md.backlog = defaultBacklog
|
||||
}
|
||||
|
||||
return
|
Loading…
Reference in New Issue
Block a user