add sshd listener

This commit is contained in:
ginuerzh 2022-01-26 15:53:33 +08:00
parent a134026e76
commit 04dfc8c4c3
39 changed files with 1101 additions and 848 deletions

View File

@ -274,6 +274,13 @@ func buildServiceConfig(url *url.URL) (*config.ServiceConfig, error) {
Metadata: md,
}
if svc.Handler.Type == "sshd" {
svc.Handler.Auths = nil
}
if svc.Listener.Type == "sshd" {
svc.Listener.Auths = auths
}
return svc, nil
}
@ -354,6 +361,13 @@ func buildNodeConfig(url *url.URL) (*config.NodeConfig, error) {
Metadata: md,
}
if node.Connector.Type == "sshd" {
node.Connector.Auth = nil
}
if node.Dialer.Type == "sshd" {
node.Dialer.Auth = auth
}
return node, nil
}

View File

@ -91,6 +91,7 @@ func buildService(cfg *config.Config) (services []*service.Service) {
ln := registry.GetListener(svc.Listener.Type)(
listener.AddrOption(svc.Addr),
listener.ChainOption(chains[svc.Listener.Chain]),
listener.AuthsOption(authsFromConfig(svc.Listener.Auths...)...),
listener.TLSConfigOption(tlsConfig),
listener.LoggerOption(listenerLogger),

View File

@ -37,7 +37,6 @@ import (
_ "github.com/go-gost/gost/pkg/handler/dns"
_ "github.com/go-gost/gost/pkg/handler/forward/local"
_ "github.com/go-gost/gost/pkg/handler/forward/remote"
_ "github.com/go-gost/gost/pkg/handler/forward/ssh"
_ "github.com/go-gost/gost/pkg/handler/http"
_ "github.com/go-gost/gost/pkg/handler/http2"
_ "github.com/go-gost/gost/pkg/handler/redirect"
@ -47,6 +46,7 @@ import (
_ "github.com/go-gost/gost/pkg/handler/socks/v5"
_ "github.com/go-gost/gost/pkg/handler/ss"
_ "github.com/go-gost/gost/pkg/handler/ss/udp"
_ "github.com/go-gost/gost/pkg/handler/sshd"
_ "github.com/go-gost/gost/pkg/handler/tap"
_ "github.com/go-gost/gost/pkg/handler/tun"
@ -65,6 +65,7 @@ import (
_ "github.com/go-gost/gost/pkg/listener/rtcp"
_ "github.com/go-gost/gost/pkg/listener/rudp"
_ "github.com/go-gost/gost/pkg/listener/ssh"
_ "github.com/go-gost/gost/pkg/listener/sshd"
_ "github.com/go-gost/gost/pkg/listener/tap"
_ "github.com/go-gost/gost/pkg/listener/tcp"
_ "github.com/go-gost/gost/pkg/listener/tls"

View File

@ -15,6 +15,7 @@ import (
func init() {
registry.RegisterDialer("http3", NewDialer)
registry.RegisterDialer("h3", NewDialer)
}
type http3Dialer struct {

View File

@ -164,7 +164,7 @@ func (d *sshDialer) dial(ctx context.Context, network, addr string, opts *dialer
func (d *sshDialer) initSession(ctx context.Context, addr string, conn net.Conn) (*sshSession, error) {
config := ssh.ClientConfig{
// Timeout: timeout,
Timeout: 30 * time.Second,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
if d.md.user != nil {

View File

@ -9,7 +9,6 @@ import (
"github.com/go-gost/gosocks4"
"github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
"github.com/go-gost/relay"
@ -24,42 +23,37 @@ type autoHandler struct {
socks4Handler handler.Handler
socks5Handler handler.Handler
relayHandler handler.Handler
log logger.Logger
options handler.Options
}
func NewHandler(opts ...handler.Option) handler.Handler {
options := &handler.Options{}
options := handler.Options{}
for _, opt := range opts {
opt(options)
}
log := options.Logger
if log == nil {
log = logger.Default()
opt(&options)
}
h := &autoHandler{
log: log,
options: options,
}
if f := registry.GetHandler("http"); f != nil {
v := append(opts,
handler.LoggerOption(log.WithFields(map[string]interface{}{"type": "http"})))
handler.LoggerOption(options.Logger.WithFields(map[string]interface{}{"type": "http"})))
h.httpHandler = f(v...)
}
if f := registry.GetHandler("socks4"); f != nil {
v := append(opts,
handler.LoggerOption(log.WithFields(map[string]interface{}{"type": "socks4"})))
handler.LoggerOption(options.Logger.WithFields(map[string]interface{}{"type": "socks4"})))
h.socks4Handler = f(v...)
}
if f := registry.GetHandler("socks5"); f != nil {
v := append(opts,
handler.LoggerOption(log.WithFields(map[string]interface{}{"type": "socks5"})))
handler.LoggerOption(options.Logger.WithFields(map[string]interface{}{"type": "socks5"})))
h.socks5Handler = f(v...)
}
if f := registry.GetHandler("relay"); f != nil {
v := append(opts,
handler.LoggerOption(log.WithFields(map[string]interface{}{"type": "relay"})))
handler.LoggerOption(options.Logger.WithFields(map[string]interface{}{"type": "relay"})))
h.relayHandler = f(v...)
}
@ -92,15 +86,15 @@ func (h *autoHandler) Init(md md.Metadata) error {
}
func (h *autoHandler) Handle(ctx context.Context, conn net.Conn) {
h.log = h.log.WithFields(map[string]interface{}{
log := h.options.Logger.WithFields(map[string]interface{}{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
start := time.Now()
h.log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
h.log.WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
@ -108,7 +102,7 @@ func (h *autoHandler) Handle(ctx context.Context, conn net.Conn) {
br := bufio.NewReader(conn)
b, err := br.Peek(1)
if err != nil {
h.log.Error(err)
log.Error(err)
conn.Close()
return
}
@ -132,5 +126,4 @@ func (h *autoHandler) Handle(ctx context.Context, conn net.Conn) {
h.httpHandler.Handle(ctx, conn)
}
}
}

View File

@ -33,7 +33,6 @@ type dnsHandler struct {
exchangers []exchanger.Exchanger
cache *resolver_util.Cache
router *chain.Router
logger logger.Logger
md metadata
options handler.Options
}
@ -50,19 +49,18 @@ func NewHandler(opts ...handler.Option) handler.Handler {
}
func (h *dnsHandler) Init(md md.Metadata) (err error) {
h.logger = h.options.Logger
if err = h.parseMetadata(md); err != nil {
return
}
log := h.options.Logger
h.cache = resolver_util.NewCache().WithLogger(h.options.Logger)
h.cache = resolver_util.NewCache().WithLogger(log)
h.router = &chain.Router{
Retries: h.options.Retries,
Chain: h.options.Chain,
Resolver: h.options.Resolver,
// Hosts: h.options.Hosts,
Logger: h.options.Logger,
Logger: log,
}
for _, server := range h.md.dns {
@ -74,10 +72,10 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) {
server,
exchanger.RouterOption(h.router),
exchanger.TimeoutOption(h.md.timeout),
exchanger.LoggerOption(h.logger),
exchanger.LoggerOption(log),
)
if err != nil {
h.logger.Warnf("parse %s: %v", server, err)
log.Warnf("parse %s: %v", server, err)
continue
}
h.exchangers = append(h.exchangers, ex)
@ -87,9 +85,9 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) {
defaultNameserver,
exchanger.RouterOption(h.router),
exchanger.TimeoutOption(h.md.timeout),
exchanger.LoggerOption(h.logger),
exchanger.LoggerOption(log),
)
h.logger.Warnf("resolver not found, default to %s", defaultNameserver)
log.Warnf("resolver not found, default to %s", defaultNameserver)
if err != nil {
return err
}
@ -103,14 +101,14 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) {
defer conn.Close()
start := time.Now()
h.logger = h.logger.WithFields(map[string]interface{}{
log := h.options.Logger.WithFields(map[string]interface{}{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
h.logger.WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
@ -120,26 +118,25 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) {
n, err := conn.Read(*b)
if err != nil {
h.logger.Error(err)
log.Error(err)
return
}
h.logger.Info("read data: ", n)
reply, err := h.exchange(ctx, (*b)[:n])
reply, err := h.exchange(ctx, (*b)[:n], log)
if err != nil {
return
}
defer bufpool.Put(&reply)
if _, err = conn.Write(reply); err != nil {
h.logger.Error(err)
log.Error(err)
}
}
func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger) ([]byte, error) {
mq := dns.Msg{}
if err := mq.Unpack(msg); err != nil {
h.logger.Error(err)
log.Error(err)
return nil, err
}
@ -149,23 +146,23 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
resolver_util.AddSubnetOpt(&mq, h.md.clientIP)
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(mq.String())
if log.IsLevelEnabled(logger.DebugLevel) {
log.Debug(mq.String())
} else {
h.logger.Info(h.dumpMsgHeader(&mq))
log.Info(h.dumpMsgHeader(&mq))
}
var mr *dns.Msg
if h.logger.IsLevelEnabled(logger.DebugLevel) {
if log.IsLevelEnabled(logger.DebugLevel) {
defer func() {
if mr != nil {
h.logger.Debug(mr.String())
log.Debug(mr.String())
}
}()
}
mr = h.lookupHosts(&mq)
mr = h.lookupHosts(&mq, log)
if mr != nil {
b := bufpool.Get(4096)
return mr.PackBuffer(*b)
@ -176,7 +173,7 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
key := resolver_util.NewCacheKey(&mq.Question[0])
mr = h.cache.Load(key)
if mr != nil {
h.logger.Debugf("exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
log.Debugf("exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
mr.Id = mq.Id
b := bufpool.Get(4096)
@ -195,18 +192,18 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
query, err := mq.PackBuffer(*b)
if err != nil {
h.logger.Error(err)
log.Error(err)
return nil, err
}
var reply []byte
for _, ex := range h.exchangers {
h.logger.Infof("exchange message %d via %s: %s", mq.Id, ex.String(), mq.Question[0].String())
log.Infof("exchange message %d via %s: %s", mq.Id, ex.String(), mq.Question[0].String())
reply, err = ex.Exchange(ctx, query)
if err == nil {
break
}
h.logger.Error(err)
log.Error(err)
}
if err != nil {
return nil, err
@ -214,21 +211,21 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
mr = &dns.Msg{}
if err = mr.Unpack(reply); err != nil {
h.logger.Error(err)
log.Error(err)
return nil, err
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(mr.String())
if log.IsLevelEnabled(logger.DebugLevel) {
log.Debug(mr.String())
} else {
h.logger.Info(h.dumpMsgHeader(mr))
log.Info(h.dumpMsgHeader(mr))
}
return reply, nil
}
// lookup host mapper
func (h *dnsHandler) lookupHosts(r *dns.Msg) (m *dns.Msg) {
func (h *dnsHandler) lookupHosts(r *dns.Msg, log logger.Logger) (m *dns.Msg) {
if h.options.Hosts == nil ||
r.Question[0].Qclass != dns.ClassINET ||
(r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA) {
@ -246,12 +243,12 @@ func (h *dnsHandler) lookupHosts(r *dns.Msg) (m *dns.Msg) {
if len(ips) == 0 {
return nil
}
h.logger.Debugf("hit host mapper: %s -> %s", host, ips)
log.Debugf("hit host mapper: %s -> %s", host, ips)
for _, ip := range ips {
rr, err := dns.NewRR(fmt.Sprintf("%s IN A %s\n", r.Question[0].Name, ip.String()))
if err != nil {
h.logger.Error(err)
log.Error(err)
return nil
}
m.Answer = append(m.Answer, rr)
@ -262,12 +259,12 @@ func (h *dnsHandler) lookupHosts(r *dns.Msg) (m *dns.Msg) {
if len(ips) == 0 {
return nil
}
h.logger.Debugf("hit host mapper: %s -> %s", host, ips)
log.Debugf("hit host mapper: %s -> %s", host, ips)
for _, ip := range ips {
rr, err := dns.NewRR(fmt.Sprintf("%s IN AAAA %s\n", r.Question[0].Name, ip.String()))
if err != nil {
h.logger.Error(err)
log.Error(err)
return nil
}
m.Answer = append(m.Answer, rr)

View File

@ -8,7 +8,6 @@ import (
"github.com/go-gost/gost/pkg/chain"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
)
@ -22,7 +21,6 @@ func init() {
type forwardHandler struct {
group *chain.NodeGroup
router *chain.Router
logger logger.Logger
md metadata
options handler.Options
}
@ -55,7 +53,6 @@ func (h *forwardHandler) Init(md md.Metadata) (err error) {
Hosts: h.options.Hosts,
Logger: h.options.Logger,
}
h.logger = h.options.Logger
return
}
@ -69,21 +66,21 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) {
defer conn.Close()
start := time.Now()
h.logger = h.logger.WithFields(map[string]interface{}{
log := h.options.Logger.WithFields(map[string]interface{}{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
h.logger.WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
target := h.group.Next()
if target == nil {
h.logger.Error("no target available")
log.Error("no target available")
return
}
@ -92,15 +89,15 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) {
network = "udp"
}
h.logger = h.logger.WithFields(map[string]interface{}{
log = log.WithFields(map[string]interface{}{
"dst": fmt.Sprintf("%s/%s", target.Addr(), network),
})
h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr())
log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr())
cc, err := h.router.Dial(ctx, network, target.Addr())
if err != nil {
h.logger.Error(err)
log.Error(err)
// TODO: the router itself may be failed due to the failed node in the router,
// the dead marker may be a wrong operation.
target.Marker().Mark()
@ -110,11 +107,9 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) {
target.Marker().Reset()
t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr())
log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr())
handler.Transport(conn, cc)
h.logger.
WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(t),
}).
Infof("%s >-< %s", conn.RemoteAddr(), target.Addr())
}).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr())
}

View File

@ -8,7 +8,6 @@ import (
"github.com/go-gost/gost/pkg/chain"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
)
@ -21,7 +20,6 @@ func init() {
type forwardHandler struct {
group *chain.NodeGroup
router *chain.Router
logger logger.Logger
md metadata
options handler.Options
}
@ -49,7 +47,6 @@ func (h *forwardHandler) Init(md md.Metadata) (err error) {
Hosts: h.options.Hosts,
Logger: h.options.Logger,
}
h.logger = h.options.Logger
return
}
@ -63,21 +60,21 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) {
defer conn.Close()
start := time.Now()
h.logger = h.logger.WithFields(map[string]interface{}{
log := h.options.Logger.WithFields(map[string]interface{}{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
h.logger.WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
target := h.group.Next()
if target == nil {
h.logger.Error("no target available")
log.Error("no target available")
return
}
@ -86,15 +83,15 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) {
network = "udp"
}
h.logger = h.logger.WithFields(map[string]interface{}{
log = log.WithFields(map[string]interface{}{
"dst": fmt.Sprintf("%s/%s", target.Addr(), network),
})
h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr())
log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr())
cc, err := h.router.Dial(ctx, network, target.Addr())
if err != nil {
h.logger.Error(err)
log.Error(err)
// TODO: the router itself may be failed due to the failed node in the router,
// the dead marker may be a wrong operation.
target.Marker().Mark()
@ -104,11 +101,9 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) {
target.Marker().Reset()
t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr())
log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr())
handler.Transport(conn, cc)
h.logger.
WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(t),
}).
Infof("%s >-< %s", conn.RemoteAddr(), target.Addr())
}).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr())
}

View File

@ -1,300 +0,0 @@
package ssh
import (
"context"
"encoding/binary"
"fmt"
"net"
"strconv"
"time"
"github.com/go-gost/gost/pkg/chain"
auth_util "github.com/go-gost/gost/pkg/common/util/auth"
"github.com/go-gost/gost/pkg/handler"
ssh_util "github.com/go-gost/gost/pkg/internal/util/ssh"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
"golang.org/x/crypto/ssh"
)
// Applicable SSH Request types for Port Forwarding - RFC 4254 7.X
const (
DirectForwardRequest = "direct-tcpip" // RFC 4254 7.2
RemoteForwardRequest = "tcpip-forward" // RFC 4254 7.1
ForwardedTCPReturnRequest = "forwarded-tcpip" // RFC 4254 7.2
CancelRemoteForwardRequest = "cancel-tcpip-forward" // RFC 4254 7.1
)
func init() {
registry.RegisterHandler("sshd", NewHandler)
}
type forwardHandler struct {
config *ssh.ServerConfig
router *chain.Router
logger logger.Logger
md metadata
options handler.Options
}
func NewHandler(opts ...handler.Option) handler.Handler {
options := handler.Options{}
for _, opt := range opts {
opt(&options)
}
return &forwardHandler{
options: options,
}
}
func (h *forwardHandler) Init(md md.Metadata) (err error) {
if err = h.parseMetadata(md); err != nil {
return
}
authenticator := auth_util.AuthFromUsers(h.options.Auths...)
config := &ssh.ServerConfig{
PasswordCallback: ssh_util.PasswordCallback(authenticator),
PublicKeyCallback: ssh_util.PublicKeyCallback(h.md.authorizedKeys),
}
config.AddHostKey(h.md.signer)
if authenticator == nil && len(h.md.authorizedKeys) == 0 {
config.NoClientAuth = true
}
h.config = config
h.router = &chain.Router{
Retries: h.options.Retries,
Chain: h.options.Chain,
Resolver: h.options.Resolver,
Hosts: h.options.Hosts,
Logger: h.options.Logger,
}
h.logger = h.options.Logger
return nil
}
func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) {
defer conn.Close()
start := time.Now()
h.logger = h.logger.WithFields(map[string]interface{}{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
h.logger.WithFields(map[string]interface{}{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
sshConn, chans, reqs, err := ssh.NewServerConn(conn, h.config)
if err != nil {
h.logger.Error(err)
return
}
h.handleForward(ctx, sshConn, chans, reqs)
}
func (h *forwardHandler) handleForward(ctx context.Context, conn ssh.Conn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) {
quit := make(chan struct{})
defer close(quit) // quit signal
go func() {
for req := range reqs {
switch req.Type {
case RemoteForwardRequest:
go h.tcpipForwardRequest(conn, req, quit)
default:
h.logger.Warnf("unsupported request type: %s, want reply: %v", req.Type, req.WantReply)
if req.WantReply {
req.Reply(false, nil)
}
}
}
}()
go func() {
for newChannel := range chans {
// Check the type of channel
t := newChannel.ChannelType()
switch t {
case DirectForwardRequest:
channel, requests, err := newChannel.Accept()
if err != nil {
h.logger.Warnf("could not accept channel: %s", err.Error())
continue
}
p := directForward{}
ssh.Unmarshal(newChannel.ExtraData(), &p)
h.logger.Debug(p.String())
if p.Host1 == "<nil>" {
p.Host1 = ""
}
go ssh.DiscardRequests(requests)
go h.directPortForwardChannel(ctx, channel, net.JoinHostPort(p.Host1, strconv.Itoa(int(p.Port1))))
default:
h.logger.Warnf("unsupported channel type: %s", t)
newChannel.Reject(ssh.Prohibited, fmt.Sprintf("unsupported channel type: %s", t))
}
}
}()
conn.Wait()
}
func (h *forwardHandler) directPortForwardChannel(ctx context.Context, channel ssh.Channel, raddr string) {
defer channel.Close()
// log.Logf("[ssh-tcp] %s - %s", h.options.Node.Addr, raddr)
/*
if !Can("tcp", raddr, h.options.Whitelist, h.options.Blacklist) {
log.Logf("[ssh-tcp] Unauthorized to tcp connect to %s", raddr)
return
}
*/
if h.options.Bypass != nil && h.options.Bypass.Contains(raddr) {
h.logger.Infof("bypass %s", raddr)
return
}
conn, err := h.router.Dial(ctx, "tcp", raddr)
if err != nil {
return
}
defer conn.Close()
t := time.Now()
h.logger.Infof("%s <-> %s", conn.LocalAddr(), conn.RemoteAddr())
handler.Transport(conn, channel)
h.logger.WithFields(map[string]interface{}{
"duration": time.Since(t),
}).Infof("%s >-< %s", conn.LocalAddr(), conn.RemoteAddr())
}
// directForward is structure for RFC 4254 7.2 - can be used for "forwarded-tcpip" and "direct-tcpip"
type directForward struct {
Host1 string
Port1 uint32
Host2 string
Port2 uint32
}
func (p directForward) String() string {
return fmt.Sprintf("%s:%d -> %s:%d", p.Host2, p.Port2, p.Host1, p.Port1)
}
func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) {
host, portString, err := net.SplitHostPort(addr.String())
if err != nil {
return
}
port, err = strconv.Atoi(portString)
return
}
// tcpipForward is structure for RFC 4254 7.1 "tcpip-forward" request
type tcpipForward struct {
Host string
Port uint32
}
func (h *forwardHandler) tcpipForwardRequest(sshConn ssh.Conn, req *ssh.Request, quit <-chan struct{}) {
t := tcpipForward{}
ssh.Unmarshal(req.Payload, &t)
addr := net.JoinHostPort(t.Host, strconv.Itoa(int(t.Port)))
/*
if !Can("rtcp", addr, h.options.Whitelist, h.options.Blacklist) {
log.Logf("[ssh-rtcp] Unauthorized to tcp bind to %s", addr)
req.Reply(false, nil)
return
}
*/
// tie to the client connection
ln, err := net.Listen("tcp", addr)
if err != nil {
h.logger.Error(err)
req.Reply(false, nil)
return
}
defer ln.Close()
h.logger.Debugf("bind on %s OK", ln.Addr())
err = func() error {
if t.Port == 0 && req.WantReply { // Client sent port 0. let them know which port is actually being used
_, port, err := getHostPortFromAddr(ln.Addr())
if err != nil {
return err
}
var b [4]byte
binary.BigEndian.PutUint32(b[:], uint32(port))
t.Port = uint32(port)
return req.Reply(true, b[:])
}
return req.Reply(true, nil)
}()
if err != nil {
h.logger.Error(err)
return
}
go func() {
for {
conn, err := ln.Accept()
if err != nil { // Unable to accept new connection - listener is likely closed
return
}
go func(conn net.Conn) {
defer conn.Close()
p := directForward{}
var err error
var portnum int
p.Host1 = t.Host
p.Port1 = t.Port
p.Host2, portnum, err = getHostPortFromAddr(conn.RemoteAddr())
if err != nil {
return
}
p.Port2 = uint32(portnum)
ch, reqs, err := sshConn.OpenChannel(ForwardedTCPReturnRequest, ssh.Marshal(p))
if err != nil {
h.logger.Error("open forwarded channel: ", err)
return
}
defer ch.Close()
go ssh.DiscardRequests(reqs)
t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), conn.LocalAddr())
handler.Transport(ch, conn)
h.logger.WithFields(map[string]interface{}{
"duration": time.Since(t),
}).Infof("%s >-< %s", conn.RemoteAddr(), conn.LocalAddr())
}(conn)
}
}()
<-quit
}

View File

@ -32,7 +32,6 @@ func init() {
type httpHandler struct {
router *chain.Router
authenticator auth.Authenticator
logger logger.Logger
md metadata
options handler.Options
}
@ -61,7 +60,6 @@ func (h *httpHandler) Init(md md.Metadata) error {
Hosts: h.options.Hosts,
Logger: h.options.Logger,
}
h.logger = h.options.Logger
return nil
}
@ -70,28 +68,28 @@ func (h *httpHandler) Handle(ctx context.Context, conn net.Conn) {
defer conn.Close()
start := time.Now()
h.logger = h.logger.WithFields(map[string]interface{}{
log := h.options.Logger.WithFields(map[string]interface{}{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
h.logger.WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
req, err := http.ReadRequest(bufio.NewReader(conn))
if err != nil {
h.logger.Error(err)
log.Error(err)
return
}
defer req.Body.Close()
h.handleRequest(ctx, conn, req)
h.handleRequest(ctx, conn, req, log)
}
func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *http.Request) {
func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *http.Request, log logger.Logger) {
if req == nil {
return
}
@ -129,16 +127,16 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
fields := map[string]interface{}{
"dst": addr,
}
if u, _, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization")); u != "" {
if u, _, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization"), log); u != "" {
fields["user"] = u
}
h.logger = h.logger.WithFields(fields)
log = log.WithFields(fields)
if h.logger.IsLevelEnabled(logger.DebugLevel) {
if log.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpRequest(req, false)
h.logger.Debug(string(dump))
log.Debug(string(dump))
}
h.logger.Infof("%s >> %s", conn.RemoteAddr(), addr)
log.Infof("%s >> %s", conn.RemoteAddr(), addr)
resp := &http.Response{
ProtoMajor: 1,
@ -152,22 +150,22 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
if h.options.Bypass != nil && h.options.Bypass.Contains(addr) {
resp.StatusCode = http.StatusForbidden
if h.logger.IsLevelEnabled(logger.DebugLevel) {
if log.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpResponse(resp, false)
h.logger.Debug(string(dump))
log.Debug(string(dump))
}
h.logger.Info("bypass: ", addr)
log.Info("bypass: ", addr)
resp.Write(conn)
return
}
if !h.authenticate(conn, req, resp) {
if !h.authenticate(conn, req, resp, log) {
return
}
if network == "udp" {
h.handleUDP(ctx, conn, network, req.Host)
h.handleUDP(ctx, conn, network, req.Host, log)
return
}
@ -176,9 +174,9 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
resp.StatusCode = http.StatusBadRequest
resp.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) {
if log.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpResponse(resp, false)
h.logger.Debug(string(dump))
log.Debug(string(dump))
}
return
@ -191,9 +189,9 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
resp.StatusCode = http.StatusServiceUnavailable
resp.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) {
if log.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpResponse(resp, false)
h.logger.Debug(string(dump))
log.Debug(string(dump))
}
return
}
@ -203,30 +201,28 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
resp.StatusCode = http.StatusOK
resp.Status = "200 Connection established"
if h.logger.IsLevelEnabled(logger.DebugLevel) {
if log.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpResponse(resp, false)
h.logger.Debug(string(dump))
log.Debug(string(dump))
}
if err = resp.Write(conn); err != nil {
h.logger.Error(err)
log.Error(err)
return
}
} else {
req.Header.Del("Proxy-Connection")
if err = req.Write(cc); err != nil {
h.logger.Error(err)
log.Error(err)
return
}
}
start := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr)
log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
handler.Transport(conn, cc)
h.logger.
WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(start),
}).
Infof("%s >-< %s", conn.RemoteAddr(), addr)
}).Infof("%s >-< %s", conn.RemoteAddr(), addr)
}
func (h *httpHandler) decodeServerName(s string) (string, error) {
@ -247,7 +243,7 @@ func (h *httpHandler) decodeServerName(s string) (string, error) {
return string(v), nil
}
func (h *httpHandler) basicProxyAuth(proxyAuth string) (username, password string, ok bool) {
func (h *httpHandler) basicProxyAuth(proxyAuth string, log logger.Logger) (username, password string, ok bool) {
if proxyAuth == "" {
return
}
@ -268,8 +264,8 @@ func (h *httpHandler) basicProxyAuth(proxyAuth string) (username, password strin
return cs[:s], cs[s+1:], true
}
func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.Response) (ok bool) {
u, p, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization"))
func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.Response, log logger.Logger) (ok bool) {
u, p, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization"), log)
if h.authenticator == nil || h.authenticator.Authenticate(u, p) {
return true
}
@ -289,7 +285,7 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.
}
r, err := http.Get(url)
if err != nil {
h.logger.Error(err)
log.Error(err)
break
}
resp = r
@ -297,7 +293,7 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.
case "host":
cc, err := net.Dial("tcp", pr.Value)
if err != nil {
h.logger.Error(err)
log.Error(err)
break
}
defer cc.Close()
@ -333,7 +329,7 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.
resp.Header.Add("Proxy-Connection", "close")
}
h.logger.Info("proxy authentication required")
log.Info("proxy authentication required")
} else {
resp.Header.Set("Server", "nginx/1.20.1")
resp.Header.Set("Date", time.Now().Format(http.TimeFormat))
@ -342,9 +338,9 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.
}
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
if log.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpResponse(resp, false)
h.logger.Debug(string(dump))
log.Debug(string(dump))
}
resp.Write(conn)

View File

@ -12,8 +12,8 @@ import (
"github.com/go-gost/gost/pkg/logger"
)
func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, address string) {
h.logger = h.logger.WithFields(map[string]interface{}{
func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) {
log = log.WithFields(map[string]interface{}{
"cmd": "udp",
})
@ -30,49 +30,47 @@ func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, add
resp.StatusCode = http.StatusForbidden
resp.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) {
if log.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpResponse(resp, false)
h.logger.Debug(string(dump))
log.Debug(string(dump))
}
h.logger.Error("UDP relay is diabled")
log.Error("UDP relay is diabled")
return
}
resp.StatusCode = http.StatusOK
if h.logger.IsLevelEnabled(logger.DebugLevel) {
if log.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpResponse(resp, false)
h.logger.Debug(string(dump))
log.Debug(string(dump))
}
if err := resp.Write(conn); err != nil {
h.logger.Error(err)
log.Error(err)
return
}
// obtain a udp connection
c, err := h.router.Dial(ctx, "udp", "") // UDP association
if err != nil {
h.logger.Error(err)
log.Error(err)
return
}
defer c.Close()
pc, ok := c.(net.PacketConn)
if !ok {
h.logger.Errorf("wrong connection type")
log.Errorf("wrong connection type")
return
}
relay := handler.NewUDPRelay(socks.UDPTunServerConn(conn), pc).
WithBypass(h.options.Bypass).
WithLogger(h.logger)
WithLogger(log)
t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
log.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
relay.Run()
h.logger.
WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(t),
}).
Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
}).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
}

View File

@ -35,7 +35,6 @@ func init() {
type http2Handler struct {
router *chain.Router
authenticator auth.Authenticator
logger logger.Logger
md metadata
options handler.Options
}
@ -64,7 +63,6 @@ func (h *http2Handler) Init(md md.Metadata) error {
Hosts: h.options.Hosts,
Logger: h.options.Logger,
}
h.logger = h.options.Logger
return nil
}
@ -72,29 +70,29 @@ func (h *http2Handler) Handle(ctx context.Context, conn net.Conn) {
defer conn.Close()
start := time.Now()
h.logger = h.logger.WithFields(map[string]interface{}{
log := h.options.Logger.WithFields(map[string]interface{}{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
h.logger.WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
cc, ok := conn.(*http2_util.ServerConn)
if !ok {
h.logger.Error("wrong connection type")
log.Error("wrong connection type")
return
}
h.roundTrip(ctx, cc.Writer(), cc.Request())
h.roundTrip(ctx, cc.Writer(), cc.Request(), log)
}
// NOTE: there is an issue (golang/go#43989) will cause the client hangs
// when server returns an non-200 status code,
// May be fixed in go1.18.
func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req *http.Request) {
func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req *http.Request, log logger.Logger) {
// Try to get the actual host.
// Compatible with GOST 2.x.
if v := req.Header.Get("Gost-Target"); v != "" {
@ -122,21 +120,21 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req
if u, _, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization")); u != "" {
fields["user"] = u
}
h.logger = h.logger.WithFields(fields)
log = log.WithFields(fields)
if h.logger.IsLevelEnabled(logger.DebugLevel) {
if log.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpRequest(req, false)
h.logger.Debug(string(dump))
log.Debug(string(dump))
}
h.logger.Infof("%s >> %s", req.RemoteAddr, addr)
log.Infof("%s >> %s", req.RemoteAddr, addr)
if h.md.proxyAgent != "" {
w.Header().Set("Proxy-Agent", h.md.proxyAgent)
for k := range h.md.header {
w.Header().Set(k, h.md.header.Get(k))
}
if h.options.Bypass != nil && h.options.Bypass.Contains(addr) {
w.WriteHeader(http.StatusForbidden)
h.logger.Info("bypass: ", addr)
log.Info("bypass: ", addr)
return
}
@ -147,7 +145,7 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
if !h.authenticate(w, req, resp) {
if !h.authenticate(w, req, resp, log) {
return
}
@ -157,7 +155,7 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req
cc, err := h.router.Dial(ctx, "tcp", addr)
if err != nil {
h.logger.Error(err)
log.Error(err)
w.WriteHeader(http.StatusServiceUnavailable)
return
}
@ -174,30 +172,28 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req
// we take over the underly connection
conn, _, err := hj.Hijack()
if err != nil {
h.logger.Error(err)
log.Error(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
defer conn.Close()
start := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr)
log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
handler.Transport(conn, cc)
h.logger.
WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(start),
}).
Infof("%s >-< %s", conn.RemoteAddr(), addr)
}).Infof("%s >-< %s", conn.RemoteAddr(), addr)
return
}
start := time.Now()
h.logger.Infof("%s <-> %s", req.RemoteAddr, addr)
log.Infof("%s <-> %s", req.RemoteAddr, addr)
handler.Transport(&readWriter{r: req.Body, w: flushWriter{w}}, cc)
h.logger.
WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(start),
}).
Infof("%s >-< %s", req.RemoteAddr, addr)
}).Infof("%s >-< %s", req.RemoteAddr, addr)
return
}
}
@ -241,7 +237,7 @@ func (h *http2Handler) basicProxyAuth(proxyAuth string) (username, password stri
return cs[:s], cs[s+1:], true
}
func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp *http.Response) (ok bool) {
func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp *http.Response, log logger.Logger) (ok bool) {
u, p, _ := h.basicProxyAuth(r.Header.Get("Proxy-Authorization"))
if h.authenticator == nil || h.authenticator.Authenticate(u, p) {
return true
@ -261,7 +257,7 @@ func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp
}
r, err := http.Get(url)
if err != nil {
h.logger.Error(err)
log.Error(err)
break
}
resp = r
@ -269,13 +265,13 @@ func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp
case "host":
cc, err := net.Dial("tcp", pr.Value)
if err != nil {
h.logger.Error(err)
log.Error(err)
break
}
defer cc.Close()
if err := h.forwardRequest(w, r, cc); err != nil {
h.logger.Error(err)
log.Error(err)
}
return
case "file":
@ -303,7 +299,7 @@ func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp
resp.Header.Add("Proxy-Connection", "close")
}
h.logger.Info("proxy authentication required")
log.Info("proxy authentication required")
} else {
resp.Header = http.Header{}
resp.Header.Set("Server", "nginx/1.20.1")
@ -313,9 +309,9 @@ func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp
}
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
if log.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpResponse(resp, false)
h.logger.Debug(string(dump))
log.Debug(string(dump))
}
h.writeResponse(w, resp)

View File

@ -1,28 +1,31 @@
package http2
import (
"net/http"
"strings"
mdata "github.com/go-gost/gost/pkg/metadata"
)
type metadata struct {
proxyAgent string
probeResistance *probeResistance
sni bool
enableUDP bool
header http.Header
}
func (h *http2Handler) parseMetadata(md mdata.Metadata) error {
const (
proxyAgent = "proxyAgent"
header = "header"
probeResistKey = "probeResistance"
knock = "knock"
sni = "sni"
enableUDP = "udp"
)
h.md.proxyAgent = mdata.GetString(md, proxyAgent)
if m := mdata.GetStringMapString(md, header); len(m) > 0 {
hd := http.Header{}
for k, v := range m {
hd.Add(k, v)
}
h.md.header = hd
}
if v := mdata.GetString(md, probeResistKey); v != "" {
if ss := strings.SplitN(v, ":", 2); len(ss) == 2 {
@ -33,8 +36,6 @@ func (h *http2Handler) parseMetadata(md mdata.Metadata) error {
}
}
}
h.md.sni = mdata.GetBool(md, sni)
h.md.enableUDP = mdata.GetBool(md, enableUDP)
return nil
}

View File

@ -8,7 +8,6 @@ import (
"github.com/go-gost/gost/pkg/chain"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
)
@ -22,7 +21,6 @@ func init() {
type redirectHandler struct {
router *chain.Router
logger logger.Logger
md metadata
options handler.Options
}
@ -50,7 +48,6 @@ func (h *redirectHandler) Init(md md.Metadata) (err error) {
Hosts: h.options.Hosts,
Logger: h.options.Logger,
}
h.logger = h.options.Logger
return
}
@ -59,14 +56,14 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn) {
defer conn.Close()
start := time.Now()
h.logger = h.logger.WithFields(map[string]interface{}{
log := h.options.Logger.WithFields(map[string]interface{}{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
h.logger.WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
@ -83,35 +80,33 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn) {
if network == "tcp" {
dstAddr, conn, err = h.getOriginalDstAddr(conn)
if err != nil {
h.logger.Error(err)
log.Error(err)
return
}
}
h.logger = h.logger.WithFields(map[string]interface{}{
log = log.WithFields(map[string]interface{}{
"dst": fmt.Sprintf("%s/%s", dstAddr, network),
})
h.logger.Infof("%s >> %s", conn.RemoteAddr(), dstAddr)
log.Infof("%s >> %s", conn.RemoteAddr(), dstAddr)
if h.options.Bypass != nil && h.options.Bypass.Contains(dstAddr.String()) {
h.logger.Info("bypass: ", dstAddr)
log.Info("bypass: ", dstAddr)
return
}
cc, err := h.router.Dial(ctx, network, dstAddr.String())
if err != nil {
h.logger.Error(err)
log.Error(err)
return
}
defer cc.Close()
t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), dstAddr)
log.Infof("%s <-> %s", conn.RemoteAddr(), dstAddr)
handler.Transport(conn, cc)
h.logger.
WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(t),
}).
Infof("%s >-< %s", conn.RemoteAddr(), dstAddr)
}).Infof("%s >-< %s", conn.RemoteAddr(), dstAddr)
}

View File

@ -13,7 +13,6 @@ func (h *redirectHandler) getOriginalDstAddr(conn net.Conn) (addr net.Addr, c ne
tc, ok := conn.(*net.TCPConn)
if !ok {
err = errors.New("wrong connection type, must be TCP")
h.logger.Error(err)
return
}

View File

@ -10,16 +10,17 @@ import (
"github.com/go-gost/gost/pkg/common/util/mux"
"github.com/go-gost/gost/pkg/common/util/socks"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger"
"github.com/go-gost/relay"
)
func (h *relayHandler) handleBind(ctx context.Context, conn net.Conn, network, address string) {
h.logger = h.logger.WithFields(map[string]interface{}{
func (h *relayHandler) handleBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) {
log = log.WithFields(map[string]interface{}{
"dst": fmt.Sprintf("%s/%s", address, network),
"cmd": "bind",
})
h.logger.Infof("%s >> %s", conn.RemoteAddr(), address)
log.Infof("%s >> %s", conn.RemoteAddr(), address)
resp := relay.Response{
Version: relay.Version1,
@ -29,18 +30,18 @@ func (h *relayHandler) handleBind(ctx context.Context, conn net.Conn, network, a
if !h.md.enableBind {
resp.Status = relay.StatusForbidden
resp.WriteTo(conn)
h.logger.Error("BIND is diabled")
log.Error("BIND is diabled")
return
}
if network == "tcp" {
h.bindTCP(ctx, conn, network, address)
h.bindTCP(ctx, conn, network, address, log)
} else {
h.bindUDP(ctx, conn, network, address)
h.bindUDP(ctx, conn, network, address, log)
}
}
func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, address string) {
func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) {
resp := relay.Response{
Version: relay.Version1,
Status: relay.StatusOK,
@ -48,7 +49,7 @@ func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, addr
ln, err := net.Listen(network, address) // strict mode: if the port already in use, it will return error
if err != nil {
h.logger.Error(err)
log.Error(err)
resp.Status = relay.StatusServiceUnavailable
resp.WriteTo(conn)
return
@ -57,7 +58,7 @@ func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, addr
af := &relay.AddrFeature{}
err = af.ParseFrom(ln.Addr().String())
if err != nil {
h.logger.Warn(err)
log.Warn(err)
}
// Issue: may not reachable when host has multi-interface
@ -65,20 +66,20 @@ func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, addr
af.AType = relay.AddrIPv4
resp.Features = append(resp.Features, af)
if _, err := resp.WriteTo(conn); err != nil {
h.logger.Error(err)
log.Error(err)
ln.Close()
return
}
h.logger = h.logger.WithFields(map[string]interface{}{
log = log.WithFields(map[string]interface{}{
"bind": fmt.Sprintf("%s/%s", ln.Addr(), ln.Addr().Network()),
})
h.logger.Debugf("bind on %s OK", ln.Addr())
log.Debugf("bind on %s OK", ln.Addr())
h.serveTCPBind(ctx, conn, ln)
h.serveTCPBind(ctx, conn, ln, log)
}
func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, address string) {
func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) {
resp := relay.Response{
Version: relay.Version1,
Status: relay.StatusOK,
@ -87,7 +88,7 @@ func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, addr
bindAddr, _ := net.ResolveUDPAddr(network, address)
pc, err := net.ListenUDP(network, bindAddr)
if err != nil {
h.logger.Error(err)
log.Error(err)
return
}
defer pc.Close()
@ -95,7 +96,7 @@ func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, addr
af := &relay.AddrFeature{}
err = af.ParseFrom(pc.LocalAddr().String())
if err != nil {
h.logger.Warn(err)
log.Warn(err)
}
// Issue: may not reachable when host has multi-interface
@ -103,33 +104,32 @@ func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, addr
af.AType = relay.AddrIPv4
resp.Features = append(resp.Features, af)
if _, err := resp.WriteTo(conn); err != nil {
h.logger.Error(err)
log.Error(err)
return
}
h.logger = h.logger.WithFields(map[string]interface{}{
log = log.WithFields(map[string]interface{}{
"bind": pc.LocalAddr().String(),
})
h.logger.Debugf("bind on %s OK", pc.LocalAddr())
log.Debugf("bind on %s OK", pc.LocalAddr())
t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
log.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
h.tunnelServerUDP(
socks.UDPTunServerConn(conn),
pc,
log,
)
h.logger.
WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(t),
}).
Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
}).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
}
func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.Listener) {
func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.Listener, log logger.Logger) {
// Upgrade connection to multiplex stream.
session, err := mux.ClientSession(conn)
if err != nil {
h.logger.Error(err)
log.Error(err)
return
}
defer session.Close()
@ -139,7 +139,7 @@ func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.L
for {
conn, err := session.Accept()
if err != nil {
h.logger.Error(err)
log.Error(err)
return
}
conn.Close() // we do not handle incoming connections.
@ -149,17 +149,22 @@ func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.L
for {
rc, err := ln.Accept()
if err != nil {
h.logger.Error(err)
log.Error(err)
return
}
h.logger.Debugf("peer %s accepted", rc.RemoteAddr())
log.Debugf("peer %s accepted", rc.RemoteAddr())
go func(c net.Conn) {
defer c.Close()
log = log.WithFields(map[string]interface{}{
"local": ln.Addr().String(),
"remote": c.RemoteAddr().String(),
})
sc, err := session.GetConn()
if err != nil {
h.logger.Error(err)
log.Error(err)
return
}
defer sc.Close()
@ -172,21 +177,20 @@ func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.L
Features: []relay.Feature{af},
}
if _, err := resp.WriteTo(sc); err != nil {
h.logger.Error(err)
log.Error(err)
return
}
t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), c.RemoteAddr().String())
log.Infof("%s <-> %s", c.LocalAddr(), c.RemoteAddr())
handler.Transport(sc, c)
h.logger.
WithFields(map[string]interface{}{"duration": time.Since(t)}).
Infof("%s >-< %s", conn.RemoteAddr(), c.RemoteAddr().String())
log.WithFields(map[string]interface{}{"duration": time.Since(t)}).
Infof("%s >-< %s", c.LocalAddr(), c.RemoteAddr())
}(rc)
}
}
func (h *relayHandler) tunnelServerUDP(tunnel, c net.PacketConn) (err error) {
func (h *relayHandler) tunnelServerUDP(tunnel, c net.PacketConn, log logger.Logger) (err error) {
bufSize := h.md.udpBufferSize
errc := make(chan error, 2)
@ -202,7 +206,7 @@ func (h *relayHandler) tunnelServerUDP(tunnel, c net.PacketConn) (err error) {
}
if h.options.Bypass != nil && h.options.Bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
log.Warn("bypass: ", raddr)
return nil
}
@ -210,7 +214,7 @@ func (h *relayHandler) tunnelServerUDP(tunnel, c net.PacketConn) (err error) {
return err
}
h.logger.Debugf("%s >>> %s data: %d",
log.Debugf("%s >>> %s data: %d",
c.LocalAddr(), raddr, n)
return nil
@ -235,14 +239,14 @@ func (h *relayHandler) tunnelServerUDP(tunnel, c net.PacketConn) (err error) {
}
if h.options.Bypass != nil && h.options.Bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
log.Warn("bypass: ", raddr)
return nil
}
if _, err := tunnel.WriteTo((*b)[:n], raddr); err != nil {
return err
}
h.logger.Debugf("%s <<< %s data: %d",
log.Debugf("%s <<< %s data: %d",
c.LocalAddr(), raddr, n)
return nil

View File

@ -7,16 +7,17 @@ import (
"time"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger"
"github.com/go-gost/relay"
)
func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network, address string) {
h.logger = h.logger.WithFields(map[string]interface{}{
func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) {
log = log.WithFields(map[string]interface{}{
"dst": fmt.Sprintf("%s/%s", address, network),
"cmd": "connect",
})
h.logger.Infof("%s >> %s", conn.RemoteAddr(), address)
log.Infof("%s >> %s", conn.RemoteAddr(), address)
resp := relay.Response{
Version: relay.Version1,
@ -26,12 +27,12 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network
if address == "" {
resp.Status = relay.StatusBadRequest
resp.WriteTo(conn)
h.logger.Error("target not specified")
log.Error("target not specified")
return
}
if h.options.Bypass != nil && h.options.Bypass.Contains(address) {
h.logger.Info("bypass: ", address)
log.Info("bypass: ", address)
resp.Status = relay.StatusForbidden
resp.WriteTo(conn)
return
@ -47,7 +48,7 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network
if h.md.noDelay {
if _, err := resp.WriteTo(conn); err != nil {
h.logger.Error(err)
log.Error(err)
return
}
}
@ -78,11 +79,9 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network
}
t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), address)
log.Infof("%s <-> %s", conn.RemoteAddr(), address)
handler.Transport(conn, cc)
h.logger.
WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(t),
}).
Infof("%s >-< %s", conn.RemoteAddr(), address)
}).Infof("%s >-< %s", conn.RemoteAddr(), address)
}

View File

@ -7,10 +7,11 @@ import (
"time"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger"
"github.com/go-gost/relay"
)
func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network string) {
func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network string, log logger.Logger) {
resp := relay.Response{
Version: relay.Version1,
Status: relay.StatusOK,
@ -19,15 +20,16 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network
if target == nil {
resp.Status = relay.StatusServiceUnavailable
resp.WriteTo(conn)
h.logger.Error("no target available")
log.Error("no target available")
return
}
h.logger = h.logger.WithFields(map[string]interface{}{
log = log.WithFields(map[string]interface{}{
"dst": fmt.Sprintf("%s/%s", target.Addr(), network),
"cmd": "forward",
})
h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr())
log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr())
cc, err := h.router.Dial(ctx, network, target.Addr())
if err != nil {
@ -37,7 +39,7 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network
resp.Status = relay.StatusHostUnreachable
resp.WriteTo(conn)
h.logger.Error(err)
log.Error(err)
return
}
@ -46,7 +48,7 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network
if h.md.noDelay {
if _, err := resp.WriteTo(conn); err != nil {
h.logger.Error(err)
log.Error(err)
return
}
}
@ -77,11 +79,9 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network
}
t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr())
log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr())
handler.Transport(conn, cc)
h.logger.
WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(t),
}).
Infof("%s >-< %s", conn.RemoteAddr(), target.Addr())
}).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr())
}

View File

@ -10,7 +10,6 @@ import (
"github.com/go-gost/gost/pkg/chain"
auth_util "github.com/go-gost/gost/pkg/common/util/auth"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
"github.com/go-gost/relay"
@ -24,7 +23,6 @@ type relayHandler struct {
group *chain.NodeGroup
router *chain.Router
authenticator auth.Authenticator
logger logger.Logger
md metadata
options handler.Options
}
@ -53,7 +51,6 @@ func (h *relayHandler) Init(md md.Metadata) (err error) {
Hosts: h.options.Hosts,
Logger: h.options.Logger,
}
h.logger = h.options.Logger
return nil
}
@ -66,14 +63,14 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) {
defer conn.Close()
start := time.Now()
h.logger = h.logger.WithFields(map[string]interface{}{
log := h.options.Logger.WithFields(map[string]interface{}{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
h.logger.WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
@ -84,14 +81,14 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) {
req := relay.Request{}
if _, err := req.ReadFrom(conn); err != nil {
h.logger.Error(err)
log.Error(err)
return
}
conn.SetReadDeadline(time.Time{})
if req.Version != relay.Version1 {
h.logger.Error("bad version")
log.Error("bad version")
return
}
@ -109,7 +106,7 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) {
}
if user != "" {
h.logger = h.logger.WithFields(map[string]interface{}{"user": user})
log = log.WithFields(map[string]interface{}{"user": user})
}
resp := relay.Response{
@ -119,7 +116,7 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) {
if h.authenticator != nil && !h.authenticator.Authenticate(user, pass) {
resp.Status = relay.StatusUnauthorized
resp.WriteTo(conn)
h.logger.Error("unauthorized")
log.Error("unauthorized")
return
}
@ -132,18 +129,18 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) {
if address != "" {
resp.Status = relay.StatusForbidden
resp.WriteTo(conn)
h.logger.Error("forward mode, connect is forbidden")
log.Error("forward mode, connect is forbidden")
return
}
// forward mode
h.handleForward(ctx, conn, network)
h.handleForward(ctx, conn, network, log)
return
}
switch req.Flags & relay.CmdMask {
case 0, relay.CONNECT:
h.handleConnect(ctx, conn, network, address)
h.handleConnect(ctx, conn, network, address, log)
case relay.BIND:
h.handleBind(ctx, conn, network, address)
h.handleBind(ctx, conn, network, address, log)
}
}

View File

@ -14,7 +14,6 @@ import (
"github.com/go-gost/gost/pkg/chain"
"github.com/go-gost/gost/pkg/common/bufpool"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
dissector "github.com/go-gost/tls-dissector"
@ -27,7 +26,6 @@ func init() {
type sniHandler struct {
httpHandler handler.Handler
router *chain.Router
logger logger.Logger
md metadata
options handler.Options
}
@ -38,19 +36,13 @@ func NewHandler(opts ...handler.Option) handler.Handler {
opt(&options)
}
log := options.Logger
if log == nil {
log = logger.Default()
}
h := &sniHandler{
options: options,
logger: log,
}
if f := registry.GetHandler("http"); f != nil {
v := append(opts,
handler.LoggerOption(log.WithFields(map[string]interface{}{"type": "http"})))
handler.LoggerOption(h.options.Logger.WithFields(map[string]interface{}{"type": "http"})))
h.httpHandler = f(v...)
}
@ -85,21 +77,21 @@ func (h *sniHandler) Handle(ctx context.Context, conn net.Conn) {
defer conn.Close()
start := time.Now()
h.logger = h.logger.WithFields(map[string]interface{}{
log := h.options.Logger.WithFields(map[string]interface{}{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
h.logger.WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
var hdr [dissector.RecordHeaderLen]byte
if _, err := io.ReadFull(conn, hdr[:]); err != nil {
h.logger.Error(err)
log.Error(err)
return
}
@ -121,25 +113,25 @@ func (h *sniHandler) Handle(ctx context.Context, conn net.Conn) {
buf := bufpool.Get(int(length) + dissector.RecordHeaderLen)
defer bufpool.Put(buf)
if _, err := io.ReadFull(conn, (*buf)[dissector.RecordHeaderLen:]); err != nil {
h.logger.Error(err)
log.Error(err)
return
}
copy(*buf, hdr[:])
opaque, host, err := h.decodeHost(bytes.NewReader(*buf))
if err != nil {
h.logger.Error(err)
log.Error(err)
return
}
target := net.JoinHostPort(host, "443")
h.logger = h.logger.WithFields(map[string]interface{}{
log = log.WithFields(map[string]interface{}{
"dst": target,
})
h.logger.Infof("%s >> %s", conn.RemoteAddr(), target)
log.Infof("%s >> %s", conn.RemoteAddr(), target)
if h.options.Bypass != nil && h.options.Bypass.Contains(target) {
h.logger.Info("bypass: ", target)
log.Info("bypass: ", target)
return
}
@ -150,18 +142,16 @@ func (h *sniHandler) Handle(ctx context.Context, conn net.Conn) {
defer cc.Close()
if _, err := cc.Write(opaque); err != nil {
h.logger.Error(err)
log.Error(err)
return
}
t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), target)
log.Infof("%s <-> %s", conn.RemoteAddr(), target)
handler.Transport(conn, cc)
h.logger.
WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(t),
}).
Infof("%s >-< %s", conn.RemoteAddr(), target)
}).Infof("%s >-< %s", conn.RemoteAddr(), target)
}
func (h *sniHandler) decodeHost(r io.Reader) (opaque []byte, host string, err error) {

View File

@ -23,7 +23,6 @@ func init() {
type socks4Handler struct {
router *chain.Router
authenticator auth.Authenticator
logger logger.Logger
md metadata
options handler.Options
}
@ -52,7 +51,6 @@ func (h *socks4Handler) Init(md md.Metadata) (err error) {
Hosts: h.options.Hosts,
Logger: h.options.Logger,
}
h.logger = h.options.Logger
return nil
}
@ -62,14 +60,14 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn) {
start := time.Now()
h.logger = h.logger.WithFields(map[string]interface{}{
log := h.options.Logger.WithFields(map[string]interface{}{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
h.logger.WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
@ -80,10 +78,10 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn) {
req, err := gosocks4.ReadRequest(conn)
if err != nil {
h.logger.Error(err)
log.Error(err)
return
}
h.logger.Debug(req)
log.Debug(req)
conn.SetReadDeadline(time.Time{})
@ -91,33 +89,33 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn) {
!h.authenticator.Authenticate(string(req.Userid), "") {
resp := gosocks4.NewReply(gosocks4.RejectedUserid, nil)
resp.Write(conn)
h.logger.Debug(resp)
log.Debug(resp)
return
}
switch req.Cmd {
case gosocks4.CmdConnect:
h.handleConnect(ctx, conn, req)
h.handleConnect(ctx, conn, req, log)
case gosocks4.CmdBind:
h.handleBind(ctx, conn, req)
default:
h.logger.Errorf("unknown cmd: %d", req.Cmd)
log.Errorf("unknown cmd: %d", req.Cmd)
}
}
func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *gosocks4.Request) {
func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *gosocks4.Request, log logger.Logger) {
addr := req.Addr.String()
h.logger = h.logger.WithFields(map[string]interface{}{
log = log.WithFields(map[string]interface{}{
"dst": addr,
})
h.logger.Infof("%s >> %s", conn.RemoteAddr(), addr)
log.Infof("%s >> %s", conn.RemoteAddr(), addr)
if h.options.Bypass != nil && h.options.Bypass.Contains(addr) {
resp := gosocks4.NewReply(gosocks4.Rejected, nil)
resp.Write(conn)
h.logger.Debug(resp)
h.logger.Info("bypass: ", addr)
log.Debug(resp)
log.Info("bypass: ", addr)
return
}
@ -125,7 +123,7 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g
if err != nil {
resp := gosocks4.NewReply(gosocks4.Failed, nil)
resp.Write(conn)
h.logger.Debug(resp)
log.Debug(resp)
return
}
@ -133,19 +131,17 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g
resp := gosocks4.NewReply(gosocks4.Granted, nil)
if err := resp.Write(conn); err != nil {
h.logger.Error(err)
log.Error(err)
return
}
h.logger.Debug(resp)
log.Debug(resp)
t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr)
log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
handler.Transport(conn, cc)
h.logger.
WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(t),
}).
Infof("%s >-< %s", conn.RemoteAddr(), addr)
}).Infof("%s >-< %s", conn.RemoteAddr(), addr)
}
func (h *socks4Handler) handleBind(ctx context.Context, conn net.Conn, req *gosocks4.Request) {

View File

@ -8,43 +8,44 @@ import (
"github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger"
)
func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, network, address string) {
h.logger = h.logger.WithFields(map[string]interface{}{
func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) {
log = log.WithFields(map[string]interface{}{
"dst": fmt.Sprintf("%s/%s", address, network),
"cmd": "bind",
})
h.logger.Infof("%s >> %s", conn.RemoteAddr(), address)
log.Infof("%s >> %s", conn.RemoteAddr(), address)
if !h.md.enableBind {
reply := gosocks5.NewReply(gosocks5.NotAllowed, nil)
reply.Write(conn)
h.logger.Debug(reply)
h.logger.Error("BIND is diabled")
log.Debug(reply)
log.Error("BIND is diabled")
return
}
// BIND does not support chain.
h.bindLocal(ctx, conn, network, address)
h.bindLocal(ctx, conn, network, address, log)
}
func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, network, address string) {
func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) {
ln, err := net.Listen(network, address) // strict mode: if the port already in use, it will return error
if err != nil {
h.logger.Error(err)
log.Error(err)
reply := gosocks5.NewReply(gosocks5.Failure, nil)
if err := reply.Write(conn); err != nil {
h.logger.Error(err)
log.Error(err)
}
h.logger.Debug(reply)
log.Debug(reply)
return
}
socksAddr := gosocks5.Addr{}
if err := socksAddr.ParseFrom(ln.Addr().String()); err != nil {
h.logger.Warn(err)
log.Warn(err)
}
// Issue: may not reachable when host has multi-interface
@ -52,22 +53,22 @@ func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, network, a
socksAddr.Type = 0
reply := gosocks5.NewReply(gosocks5.Succeeded, &socksAddr)
if err := reply.Write(conn); err != nil {
h.logger.Error(err)
log.Error(err)
ln.Close()
return
}
h.logger.Debug(reply)
log.Debug(reply)
h.logger = h.logger.WithFields(map[string]interface{}{
log = log.WithFields(map[string]interface{}{
"bind": fmt.Sprintf("%s/%s", ln.Addr(), ln.Addr().Network()),
})
h.logger.Debugf("bind on %s OK", ln.Addr())
log.Debugf("bind on %s OK", ln.Addr())
h.serveBind(ctx, conn, ln)
h.serveBind(ctx, conn, ln, log)
}
func (h *socks5Handler) serveBind(ctx context.Context, conn net.Conn, ln net.Listener) {
func (h *socks5Handler) serveBind(ctx context.Context, conn net.Conn, ln net.Listener, log logger.Logger) {
var rc net.Conn
accept := func() <-chan error {
errc := make(chan error, 1)
@ -105,38 +106,42 @@ func (h *socks5Handler) serveBind(ctx context.Context, conn net.Conn, ln net.Lis
select {
case err := <-accept():
if err != nil {
h.logger.Error(err)
log.Error(err)
reply := gosocks5.NewReply(gosocks5.Failure, nil)
if err := reply.Write(pc2); err != nil {
h.logger.Error(err)
log.Error(err)
}
h.logger.Debug(reply)
log.Debug(reply)
return
}
defer rc.Close()
h.logger.Debugf("peer %s accepted", rc.RemoteAddr())
log.Debugf("peer %s accepted", rc.RemoteAddr())
log = log.WithFields(map[string]interface{}{
"local": rc.LocalAddr().String(),
"remote": rc.RemoteAddr().String(),
})
raddr := gosocks5.Addr{}
raddr.ParseFrom(rc.RemoteAddr().String())
reply := gosocks5.NewReply(gosocks5.Succeeded, &raddr)
if err := reply.Write(pc2); err != nil {
h.logger.Error(err)
log.Error(err)
}
h.logger.Debug(reply)
log.Debug(reply)
start := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), raddr.String())
log.Infof("%s <-> %s", rc.LocalAddr(), rc.RemoteAddr())
handler.Transport(pc2, rc)
h.logger.
WithFields(map[string]interface{}{"duration": time.Since(start)}).
Infof("%s >-< %s", conn.RemoteAddr(), raddr.String())
log.WithFields(map[string]interface{}{"duration": time.Since(start)}).
Infof("%s >-< %s", rc.LocalAddr(), rc.RemoteAddr())
case err := <-pipe():
if err != nil {
h.logger.Error(err)
log.Error(err)
}
ln.Close()
return

View File

@ -8,20 +8,21 @@ import (
"github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger"
)
func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, network, address string) {
h.logger = h.logger.WithFields(map[string]interface{}{
func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) {
log = log.WithFields(map[string]interface{}{
"dst": fmt.Sprintf("%s/%s", address, network),
"cmd": "connect",
})
h.logger.Infof("%s >> %s", conn.RemoteAddr(), address)
log.Infof("%s >> %s", conn.RemoteAddr(), address)
if h.options.Bypass != nil && h.options.Bypass.Contains(address) {
resp := gosocks5.NewReply(gosocks5.NotAllowed, nil)
resp.Write(conn)
h.logger.Debug(resp)
h.logger.Info("bypass: ", address)
log.Debug(resp)
log.Info("bypass: ", address)
return
}
@ -29,7 +30,7 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, networ
if err != nil {
resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil)
resp.Write(conn)
h.logger.Debug(resp)
log.Debug(resp)
return
}
@ -37,17 +38,15 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, networ
resp := gosocks5.NewReply(gosocks5.Succeeded, nil)
if err := resp.Write(conn); err != nil {
h.logger.Error(err)
log.Error(err)
return
}
h.logger.Debug(resp)
log.Debug(resp)
t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), address)
log.Infof("%s <-> %s", conn.RemoteAddr(), address)
handler.Transport(conn, cc)
h.logger.
WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(t),
}).
Infof("%s >-< %s", conn.RemoteAddr(), address)
}).Infof("%s >-< %s", conn.RemoteAddr(), address)
}

View File

@ -10,7 +10,6 @@ import (
auth_util "github.com/go-gost/gost/pkg/common/util/auth"
"github.com/go-gost/gost/pkg/common/util/socks"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
)
@ -23,7 +22,6 @@ func init() {
type socks5Handler struct {
selector gosocks5.Selector
router *chain.Router
logger logger.Logger
md metadata
options handler.Options
}
@ -44,7 +42,6 @@ func (h *socks5Handler) Init(md md.Metadata) (err error) {
return
}
h.logger = h.options.Logger
h.router = &chain.Router{
Retries: h.options.Retries,
Chain: h.options.Chain,
@ -56,7 +53,7 @@ func (h *socks5Handler) Init(md md.Metadata) (err error) {
h.selector = &serverSelector{
Authenticator: auth_util.AuthFromUsers(h.options.Auths...),
TLSConfig: h.options.TLSConfig,
logger: h.logger,
logger: h.options.Logger,
noTLS: h.md.noTLS,
}
@ -68,14 +65,14 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn) {
start := time.Now()
h.logger = h.logger.WithFields(map[string]interface{}{
log := h.options.Logger.WithFields(map[string]interface{}{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
h.logger.WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
@ -87,30 +84,30 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn) {
conn = gosocks5.ServerConn(conn, h.selector)
req, err := gosocks5.ReadRequest(conn)
if err != nil {
h.logger.Error(err)
log.Error(err)
return
}
h.logger.Debug(req)
log.Debug(req)
conn.SetReadDeadline(time.Time{})
address := req.Addr.String()
switch req.Cmd {
case gosocks5.CmdConnect:
h.handleConnect(ctx, conn, "tcp", address)
h.handleConnect(ctx, conn, "tcp", address, log)
case gosocks5.CmdBind:
h.handleBind(ctx, conn, "tcp", address)
h.handleBind(ctx, conn, "tcp", address, log)
case socks.CmdMuxBind:
h.handleMuxBind(ctx, conn, "tcp", address)
h.handleMuxBind(ctx, conn, "tcp", address, log)
case gosocks5.CmdUdp:
h.handleUDP(ctx, conn)
h.handleUDP(ctx, conn, log)
case socks.CmdUDPTun:
h.handleUDPTun(ctx, conn, "udp", address)
h.handleUDPTun(ctx, conn, "udp", address, log)
default:
h.logger.Errorf("unknown cmd: %d", req.Cmd)
log.Errorf("unknown cmd: %d", req.Cmd)
resp := gosocks5.NewReply(gosocks5.CmdUnsupported, nil)
resp.Write(conn)
h.logger.Debug(resp)
log.Debug(resp)
return
}
}

View File

@ -9,43 +9,44 @@ import (
"github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/common/util/mux"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger"
)
func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, network, address string) {
h.logger = h.logger.WithFields(map[string]interface{}{
func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) {
log = log.WithFields(map[string]interface{}{
"dst": fmt.Sprintf("%s/%s", address, network),
"cmd": "mbind",
})
h.logger.Infof("%s >> %s", conn.RemoteAddr(), address)
log.Infof("%s >> %s", conn.RemoteAddr(), address)
if !h.md.enableBind {
reply := gosocks5.NewReply(gosocks5.NotAllowed, nil)
reply.Write(conn)
h.logger.Debug(reply)
h.logger.Error("BIND is diabled")
log.Debug(reply)
log.Error("BIND is diabled")
return
}
h.muxBindLocal(ctx, conn, network, address)
h.muxBindLocal(ctx, conn, network, address, log)
}
func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, network, address string) {
func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) {
ln, err := net.Listen(network, address) // strict mode: if the port already in use, it will return error
if err != nil {
h.logger.Error(err)
log.Error(err)
reply := gosocks5.NewReply(gosocks5.Failure, nil)
if err := reply.Write(conn); err != nil {
h.logger.Error(err)
log.Error(err)
}
h.logger.Debug(reply)
log.Debug(reply)
return
}
socksAddr := gosocks5.Addr{}
err = socksAddr.ParseFrom(ln.Addr().String())
if err != nil {
h.logger.Warn(err)
log.Warn(err)
}
// Issue: may not reachable when host has multi-interface
@ -53,26 +54,26 @@ func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, network
socksAddr.Type = 0
reply := gosocks5.NewReply(gosocks5.Succeeded, &socksAddr)
if err := reply.Write(conn); err != nil {
h.logger.Error(err)
log.Error(err)
ln.Close()
return
}
h.logger.Debug(reply)
log.Debug(reply)
h.logger = h.logger.WithFields(map[string]interface{}{
log = log.WithFields(map[string]interface{}{
"bind": fmt.Sprintf("%s/%s", ln.Addr(), ln.Addr().Network()),
})
h.logger.Debugf("bind on %s OK", ln.Addr())
log.Debugf("bind on %s OK", ln.Addr())
h.serveMuxBind(ctx, conn, ln)
h.serveMuxBind(ctx, conn, ln, log)
}
func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net.Listener) {
func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net.Listener, log logger.Logger) {
// Upgrade connection to multiplex stream.
session, err := mux.ClientSession(conn)
if err != nil {
h.logger.Error(err)
log.Error(err)
return
}
defer session.Close()
@ -82,7 +83,7 @@ func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net.
for {
conn, err := session.Accept()
if err != nil {
h.logger.Error(err)
log.Error(err)
return
}
conn.Close() // we do not handle incoming connections.
@ -92,17 +93,21 @@ func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net.
for {
rc, err := ln.Accept()
if err != nil {
h.logger.Error(err)
log.Error(err)
return
}
h.logger.Debugf("peer %s accepted", rc.RemoteAddr())
log.Debugf("peer %s accepted", rc.RemoteAddr())
go func(c net.Conn) {
defer c.Close()
log = log.WithFields(map[string]interface{}{
"local": rc.LocalAddr().String(),
"remote": rc.RemoteAddr().String(),
})
sc, err := session.GetConn()
if err != nil {
h.logger.Error(err)
log.Error(err)
return
}
defer sc.Close()
@ -113,18 +118,17 @@ func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net.
addr.ParseFrom(c.RemoteAddr().String())
reply := gosocks5.NewReply(gosocks5.Succeeded, &addr)
if err := reply.Write(sc); err != nil {
h.logger.Error(err)
log.Error(err)
return
}
h.logger.Debug(reply)
log.Debug(reply)
}
t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), c.RemoteAddr().String())
log.Infof("%s <-> %s", c.LocalAddr(), c.RemoteAddr())
handler.Transport(sc, c)
h.logger.
WithFields(map[string]interface{}{"duration": time.Since(t)}).
Infof("%s >-< %s", conn.RemoteAddr(), c.RemoteAddr().String())
log.WithFields(map[string]interface{}{"duration": time.Since(t)}).
Infof("%s >-< %s", c.LocalAddr(), c.RemoteAddr())
}(rc)
}
}

View File

@ -11,27 +11,28 @@ import (
"github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/common/util/socks"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger"
)
func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) {
h.logger = h.logger.WithFields(map[string]interface{}{
func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, log logger.Logger) {
log = log.WithFields(map[string]interface{}{
"cmd": "udp",
})
if !h.md.enableUDP {
reply := gosocks5.NewReply(gosocks5.NotAllowed, nil)
reply.Write(conn)
h.logger.Debug(reply)
h.logger.Error("UDP relay is diabled")
log.Debug(reply)
log.Error("UDP relay is diabled")
return
}
cc, err := net.ListenUDP("udp", nil)
if err != nil {
h.logger.Error(err)
log.Error(err)
reply := gosocks5.NewReply(gosocks5.Failure, nil)
reply.Write(conn)
h.logger.Debug(reply)
log.Debug(reply)
return
}
defer cc.Close()
@ -42,41 +43,40 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) {
saddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) // replace the IP to the out-going interface's
reply := gosocks5.NewReply(gosocks5.Succeeded, &saddr)
if err := reply.Write(conn); err != nil {
h.logger.Error(err)
log.Error(err)
return
}
h.logger.Debug(reply)
log.Debug(reply)
h.logger = h.logger.WithFields(map[string]interface{}{
log = log.WithFields(map[string]interface{}{
"bind": fmt.Sprintf("%s/%s", cc.LocalAddr(), cc.LocalAddr().Network()),
})
h.logger.Debugf("bind on %s OK", cc.LocalAddr())
log.Debugf("bind on %s OK", cc.LocalAddr())
// obtain a udp connection
c, err := h.router.Dial(ctx, "udp", "") // UDP association
if err != nil {
h.logger.Error(err)
log.Error(err)
return
}
defer c.Close()
pc, ok := c.(net.PacketConn)
if !ok {
h.logger.Errorf("wrong connection type")
log.Errorf("wrong connection type")
return
}
relay := handler.NewUDPRelay(socks.UDPConn(cc, h.md.udpBufferSize), pc).
WithBypass(h.options.Bypass).
WithLogger(h.logger)
WithLogger(log)
relay.SetBufferSize(h.md.udpBufferSize)
go relay.Run()
t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr())
log.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr())
io.Copy(ioutil.Discard, conn)
h.logger.
WithFields(map[string]interface{}{"duration": time.Since(t)}).
log.WithFields(map[string]interface{}{"duration": time.Since(t)}).
Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr())
}

View File

@ -8,54 +8,53 @@ import (
"github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/common/util/socks"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger"
)
func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network, address string) {
h.logger = h.logger.WithFields(map[string]interface{}{
func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) {
log = log.WithFields(map[string]interface{}{
"cmd": "udp-tun",
})
if !h.md.enableUDP {
reply := gosocks5.NewReply(gosocks5.NotAllowed, nil)
reply.Write(conn)
h.logger.Debug(reply)
h.logger.Error("UDP relay is diabled")
log.Debug(reply)
log.Error("UDP relay is diabled")
return
}
// dummy bind
reply := gosocks5.NewReply(gosocks5.Succeeded, nil)
if err := reply.Write(conn); err != nil {
h.logger.Error(err)
log.Error(err)
return
}
h.logger.Debug(reply)
log.Debug(reply)
// obtain a udp connection
c, err := h.router.Dial(ctx, "udp", "") // UDP association
if err != nil {
h.logger.Error(err)
log.Error(err)
return
}
defer c.Close()
pc, ok := c.(net.PacketConn)
if !ok {
h.logger.Errorf("wrong connection type")
log.Errorf("wrong connection type")
return
}
relay := handler.NewUDPRelay(socks.UDPTunServerConn(conn), pc).
WithBypass(h.options.Bypass).
WithLogger(h.logger)
WithLogger(log)
relay.SetBufferSize(h.md.udpBufferSize)
t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
log.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
relay.Run()
h.logger.
WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(t),
}).
Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
}).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
}

View File

@ -11,7 +11,6 @@ import (
"github.com/go-gost/gost/pkg/chain"
"github.com/go-gost/gost/pkg/common/util/ss"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
"github.com/shadowsocks/go-shadowsocks2/core"
@ -24,7 +23,6 @@ func init() {
type ssHandler struct {
cipher core.Cipher
router *chain.Router
logger logger.Logger
md metadata
options handler.Options
}
@ -60,7 +58,6 @@ func (h *ssHandler) Init(md md.Metadata) (err error) {
Hosts: h.options.Hosts,
Logger: h.options.Logger,
}
h.logger = h.options.Logger
return
}
@ -69,14 +66,14 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) {
defer conn.Close()
start := time.Now()
h.logger = h.logger.WithFields(map[string]interface{}{
log := h.options.Logger.WithFields(map[string]interface{}{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
h.logger.WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
@ -91,19 +88,19 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) {
addr := &gosocks5.Addr{}
if _, err := addr.ReadFrom(conn); err != nil {
h.logger.Error(err)
log.Error(err)
io.Copy(ioutil.Discard, conn)
return
}
h.logger = h.logger.WithFields(map[string]interface{}{
log = log.WithFields(map[string]interface{}{
"dst": addr.String(),
})
h.logger.Infof("%s >> %s", conn.RemoteAddr(), addr)
log.Infof("%s >> %s", conn.RemoteAddr(), addr)
if h.options.Bypass != nil && h.options.Bypass.Contains(addr.String()) {
h.logger.Info("bypass: ", addr.String())
log.Info("bypass: ", addr.String())
return
}
@ -114,11 +111,9 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) {
defer cc.Close()
t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr)
log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
handler.Transport(conn, cc)
h.logger.
WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(t),
}).
Infof("%s >-< %s", conn.RemoteAddr(), addr)
}).Infof("%s >-< %s", conn.RemoteAddr(), addr)
}

View File

@ -23,7 +23,6 @@ func init() {
type ssuHandler struct {
cipher core.Cipher
router *chain.Router
logger logger.Logger
md metadata
options handler.Options
}
@ -60,7 +59,6 @@ func (h *ssuHandler) Init(md md.Metadata) (err error) {
Hosts: h.options.Hosts,
Logger: h.options.Logger,
}
h.logger = h.options.Logger
return
}
@ -69,14 +67,14 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) {
defer conn.Close()
start := time.Now()
h.logger = h.logger.WithFields(map[string]interface{}{
log := h.options.Logger.WithFields(map[string]interface{}{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
h.logger.WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
@ -99,26 +97,25 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) {
// obtain a udp connection
c, err := h.router.Dial(ctx, "udp", "") // UDP association
if err != nil {
h.logger.Error(err)
log.Error(err)
return
}
defer c.Close()
cc, ok := c.(net.PacketConn)
if !ok {
h.logger.Errorf("wrong connection type")
log.Errorf("wrong connection type")
return
}
t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr())
h.relayPacket(pc, cc)
h.logger.
WithFields(map[string]interface{}{"duration": time.Since(t)}).
log.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr())
h.relayPacket(pc, cc, log)
log.WithFields(map[string]interface{}{"duration": time.Since(t)}).
Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr())
}
func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) {
func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn, log logger.Logger) (err error) {
bufSize := h.md.bufferSize
errc := make(chan error, 2)
@ -134,7 +131,7 @@ func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) {
}
if h.options.Bypass != nil && h.options.Bypass.Contains(addr.String()) {
h.logger.Warn("bypass: ", addr)
log.Warn("bypass: ", addr)
return nil
}
@ -142,7 +139,7 @@ func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) {
return err
}
h.logger.Debugf("%s >>> %s data: %d",
log.Debugf("%s >>> %s data: %d",
pc2.LocalAddr(), addr, n)
return nil
}()
@ -166,7 +163,7 @@ func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) {
}
if h.options.Bypass != nil && h.options.Bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
log.Warn("bypass: ", raddr)
return nil
}
@ -174,7 +171,7 @@ func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) {
return err
}
h.logger.Debugf("%s <<< %s data: %d",
log.Debugf("%s <<< %s data: %d",
pc2.LocalAddr(), raddr, n)
return nil
}()

238
pkg/handler/sshd/handler.go Normal file
View File

@ -0,0 +1,238 @@
package ssh
import (
"context"
"encoding/binary"
"fmt"
"net"
"strconv"
"time"
"github.com/go-gost/gost/pkg/chain"
"github.com/go-gost/gost/pkg/handler"
sshd_util "github.com/go-gost/gost/pkg/internal/util/sshd"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
"golang.org/x/crypto/ssh"
)
// Applicable SSH Request types for Port Forwarding - RFC 4254 7.X
const (
ForwardedTCPReturnRequest = "forwarded-tcpip" // RFC 4254 7.2
)
func init() {
registry.RegisterHandler("sshd", NewHandler)
}
type forwardHandler struct {
router *chain.Router
md metadata
options handler.Options
}
func NewHandler(opts ...handler.Option) handler.Handler {
options := handler.Options{}
for _, opt := range opts {
opt(&options)
}
return &forwardHandler{
options: options,
}
}
func (h *forwardHandler) Init(md md.Metadata) (err error) {
if err = h.parseMetadata(md); err != nil {
return
}
h.router = &chain.Router{
Retries: h.options.Retries,
Chain: h.options.Chain,
Resolver: h.options.Resolver,
Hosts: h.options.Hosts,
Logger: h.options.Logger,
}
return nil
}
func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) {
defer conn.Close()
log := h.options.Logger.WithFields(map[string]interface{}{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
switch cc := conn.(type) {
case *sshd_util.DirectForwardConn:
h.handleDirectForward(ctx, cc, log)
case *sshd_util.RemoteForwardConn:
h.handleRemoteForward(ctx, cc, log)
default:
log.Error("wrong connection type")
return
}
}
func (h *forwardHandler) handleDirectForward(ctx context.Context, conn *sshd_util.DirectForwardConn, log logger.Logger) {
targetAddr := conn.DstAddr()
log = log.WithFields(map[string]interface{}{
"dst": fmt.Sprintf("%s/%s", targetAddr, "tcp"),
"cmd": "connect",
})
log.Infof("%s >> %s", conn.RemoteAddr(), targetAddr)
if h.options.Bypass != nil && h.options.Bypass.Contains(targetAddr) {
log.Infof("bypass %s", targetAddr)
return
}
cc, err := h.router.Dial(ctx, "tcp", targetAddr)
if err != nil {
return
}
defer cc.Close()
t := time.Now()
log.Infof("%s <-> %s", cc.LocalAddr(), targetAddr)
handler.Transport(conn, cc)
log.WithFields(map[string]interface{}{
"duration": time.Since(t),
}).Infof("%s >-< %s", cc.LocalAddr(), targetAddr)
}
func (h *forwardHandler) handleRemoteForward(ctx context.Context, conn *sshd_util.RemoteForwardConn, log logger.Logger) {
req := conn.Request()
t := tcpipForward{}
ssh.Unmarshal(req.Payload, &t)
network := "tcp"
addr := net.JoinHostPort(t.Host, strconv.Itoa(int(t.Port)))
log = log.WithFields(map[string]interface{}{
"dst": fmt.Sprintf("%s/%s", addr, network),
"cmd": "bind",
})
log.Infof("%s >> %s", conn.RemoteAddr(), addr)
// tie to the client connection
ln, err := net.Listen(network, addr)
if err != nil {
log.Error(err)
req.Reply(false, nil)
return
}
defer ln.Close()
log = log.WithFields(map[string]interface{}{
"bind": fmt.Sprintf("%s/%s", ln.Addr(), ln.Addr().Network()),
})
log.Debugf("bind on %s OK", ln.Addr())
err = func() error {
if t.Port == 0 && req.WantReply { // Client sent port 0. let them know which port is actually being used
_, port, err := getHostPortFromAddr(ln.Addr())
if err != nil {
return err
}
var b [4]byte
binary.BigEndian.PutUint32(b[:], uint32(port))
t.Port = uint32(port)
return req.Reply(true, b[:])
}
return req.Reply(true, nil)
}()
if err != nil {
log.Error(err)
return
}
sshConn := conn.Conn()
go func() {
for {
cc, err := ln.Accept()
if err != nil { // Unable to accept new connection - listener is likely closed
return
}
go func(conn net.Conn) {
defer conn.Close()
log := log.WithFields(map[string]interface{}{
"local": conn.LocalAddr().String(),
"remote": conn.RemoteAddr().String(),
})
p := directForward{}
var err error
var portnum int
p.Host1 = t.Host
p.Port1 = t.Port
p.Host2, portnum, err = getHostPortFromAddr(conn.RemoteAddr())
if err != nil {
return
}
p.Port2 = uint32(portnum)
ch, reqs, err := sshConn.OpenChannel(ForwardedTCPReturnRequest, ssh.Marshal(p))
if err != nil {
log.Error("open forwarded channel: ", err)
return
}
defer ch.Close()
go ssh.DiscardRequests(reqs)
t := time.Now()
log.Infof("%s <-> %s", conn.LocalAddr(), conn.RemoteAddr())
handler.Transport(ch, conn)
log.WithFields(map[string]interface{}{
"duration": time.Since(t),
}).Infof("%s >-< %s", conn.LocalAddr(), conn.RemoteAddr())
}(cc)
}
}()
tm := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
<-conn.Done()
log.WithFields(map[string]interface{}{
"duration": time.Since(tm),
}).Infof("%s >-< %s", conn.RemoteAddr(), addr)
}
func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) {
host, portString, err := net.SplitHostPort(addr.String())
if err != nil {
return
}
port, err = strconv.Atoi(portString)
return
}
// directForward is structure for RFC 4254 7.2 - can be used for "forwarded-tcpip" and "direct-tcpip"
type directForward struct {
Host1 string
Port1 uint32
Host2 string
Port2 uint32
}
func (p directForward) String() string {
return fmt.Sprintf("%s:%d -> %s:%d", p.Host2, p.Port2, p.Host1, p.Port1)
}
// tcpipForward is structure for RFC 4254 7.1 "tcpip-forward" request
type tcpipForward struct {
Host string
Port uint32
}

View File

@ -0,0 +1,12 @@
package ssh
import (
mdata "github.com/go-gost/gost/pkg/metadata"
)
type metadata struct {
}
func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) {
return
}

View File

@ -33,7 +33,6 @@ type tapHandler struct {
exit chan struct{}
cipher core.Cipher
router *chain.Router
logger logger.Logger
md metadata
options handler.Options
}
@ -71,7 +70,6 @@ func (h *tapHandler) Init(md md.Metadata) (err error) {
Hosts: h.options.Hosts,
Logger: h.options.Logger,
}
h.logger = h.options.Logger
return
}
@ -85,21 +83,22 @@ func (h *tapHandler) Handle(ctx context.Context, conn net.Conn) {
defer os.Exit(0)
defer conn.Close()
log := h.options.Logger
cc, ok := conn.(*tap_util.Conn)
if !ok || cc.Config() == nil {
h.logger.Error("invalid connection")
log.Error("invalid connection")
return
}
start := time.Now()
h.logger = h.logger.WithFields(map[string]interface{}{
log = log.WithFields(map[string]interface{}{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
h.logger.WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
@ -112,19 +111,19 @@ func (h *tapHandler) Handle(ctx context.Context, conn net.Conn) {
if target != nil {
raddr, err = net.ResolveUDPAddr(network, target.Addr())
if err != nil {
h.logger.Error(err)
log.Error(err)
return
}
h.logger = h.logger.WithFields(map[string]interface{}{
log = log.WithFields(map[string]interface{}{
"dst": fmt.Sprintf("%s/%s", raddr.String(), raddr.Network()),
})
h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr())
log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr())
}
h.handleLoop(ctx, conn, raddr, cc.Config())
h.handleLoop(ctx, conn, raddr, cc.Config(), log)
}
func (h *tapHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Addr, config *tap_util.Config) {
func (h *tapHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Addr, config *tap_util.Config, log logger.Logger) {
var tempDelay time.Duration
for {
err := func() error {
@ -154,10 +153,10 @@ func (h *tapHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add
pc = h.cipher.PacketConn(pc)
}
return h.transport(conn, pc, addr)
return h.transport(conn, pc, addr, log)
}()
if err != nil {
h.logger.Error(err)
log.Error(err)
}
select {
@ -183,7 +182,7 @@ func (h *tapHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add
}
func (h *tapHandler) transport(tap net.Conn, conn net.PacketConn, raddr net.Addr) error {
func (h *tapHandler) transport(tap net.Conn, conn net.PacketConn, raddr net.Addr, log logger.Logger) error {
errc := make(chan error, 1)
go func() {
@ -205,7 +204,7 @@ func (h *tapHandler) transport(tap net.Conn, conn net.PacketConn, raddr net.Addr
dst := waterutil.MACDestination((*b)[:n])
eType := etherType(waterutil.MACEthertype((*b)[:n]))
h.logger.Debugf("%s >> %s %s %d", src, dst, eType, n)
log.Debugf("%s >> %s %s %d", src, dst, eType, n)
// client side, deliver frame directly.
if raddr != nil {
@ -227,7 +226,7 @@ func (h *tapHandler) transport(tap net.Conn, conn net.PacketConn, raddr net.Addr
addr = v.(net.Addr)
}
if addr == nil {
h.logger.Warnf("no route for %s -> %s %s %d", src, dst, eType, n)
log.Warnf("no route for %s -> %s %s %d", src, dst, eType, n)
return nil
}
@ -261,7 +260,7 @@ func (h *tapHandler) transport(tap net.Conn, conn net.PacketConn, raddr net.Addr
dst := waterutil.MACDestination((*b)[:n])
eType := etherType(waterutil.MACEthertype((*b)[:n]))
h.logger.Debugf("%s >> %s %s %d", src, dst, eType, n)
log.Debugf("%s >> %s %s %d", src, dst, eType, n)
// client side, deliver frame to tap device.
if raddr != nil {
@ -273,12 +272,12 @@ func (h *tapHandler) transport(tap net.Conn, conn net.PacketConn, raddr net.Addr
rkey := hwAddrToTapRouteKey(src)
if actual, loaded := h.routes.LoadOrStore(rkey, addr); loaded {
if actual.(net.Addr).String() != addr.String() {
h.logger.Debugf("update route: %s -> %s (old %s)",
log.Debugf("update route: %s -> %s (old %s)",
src, addr, actual.(net.Addr))
h.routes.Store(rkey, addr)
}
} else {
h.logger.Debugf("new route: %s -> %s", src, addr)
log.Debugf("new route: %s -> %s", src, addr)
}
if waterutil.IsBroadcast(dst) {
@ -291,7 +290,7 @@ func (h *tapHandler) transport(tap net.Conn, conn net.PacketConn, raddr net.Addr
}
if v, ok := h.routes.Load(hwAddrToTapRouteKey(dst)); ok {
h.logger.Debugf("find route: %s -> %s", dst, v)
log.Debugf("find route: %s -> %s", dst, v)
_, err := conn.WriteTo((*b)[:n], v.(net.Addr))
return err
}

View File

@ -35,7 +35,6 @@ type tunHandler struct {
exit chan struct{}
cipher core.Cipher
router *chain.Router
logger logger.Logger
md metadata
options handler.Options
}
@ -73,7 +72,6 @@ func (h *tunHandler) Init(md md.Metadata) (err error) {
Hosts: h.options.Hosts,
Logger: h.options.Logger,
}
h.logger = h.options.Logger
return
}
@ -87,21 +85,23 @@ func (h *tunHandler) Handle(ctx context.Context, conn net.Conn) {
defer os.Exit(0)
defer conn.Close()
log := h.options.Logger
cc, ok := conn.(*tun_util.Conn)
if !ok || cc.Config() == nil {
h.logger.Error("invalid connection")
log.Error("invalid connection")
return
}
start := time.Now()
h.logger = h.logger.WithFields(map[string]interface{}{
log = log.WithFields(map[string]interface{}{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
h.logger.WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
@ -114,19 +114,19 @@ func (h *tunHandler) Handle(ctx context.Context, conn net.Conn) {
if target != nil {
raddr, err = net.ResolveUDPAddr(network, target.Addr())
if err != nil {
h.logger.Error(err)
log.Error(err)
return
}
h.logger = h.logger.WithFields(map[string]interface{}{
log = log.WithFields(map[string]interface{}{
"dst": fmt.Sprintf("%s/%s", raddr.String(), raddr.Network()),
})
h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr())
log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr())
}
h.handleLoop(ctx, conn, raddr, cc.Config())
h.handleLoop(ctx, conn, raddr, cc.Config(), log)
}
func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Addr, config *tun_util.Config) {
func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Addr, config *tun_util.Config, log logger.Logger) {
var tempDelay time.Duration
for {
err := func() error {
@ -155,10 +155,10 @@ func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add
pc = h.cipher.PacketConn(pc)
}
return h.transport(conn, pc, addr)
return h.transport(conn, pc, addr, log)
}()
if err != nil {
h.logger.Error(err)
log.Error(err)
}
select {
@ -184,7 +184,7 @@ func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add
}
func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr) error {
func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr, log logger.Logger) error {
errc := make(chan error, 1)
go func() {
@ -206,10 +206,10 @@ func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr
if waterutil.IsIPv4((*b)[:n]) {
header, err := ipv4.ParseHeader((*b)[:n])
if err != nil {
h.logger.Error(err)
log.Error(err)
return nil
}
h.logger.Debugf("%s >> %s %-4s %d/%-4d %-4x %d",
log.Debugf("%s >> %s %-4s %d/%-4d %-4x %d",
header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])),
header.Len, header.TotalLen, header.ID, header.Flags)
@ -217,17 +217,17 @@ func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr
} else if waterutil.IsIPv6((*b)[:n]) {
header, err := ipv6.ParseHeader((*b)[:n])
if err != nil {
h.logger.Warn(err)
log.Warn(err)
return nil
}
h.logger.Debugf("%s >> %s %s %d %d",
log.Debugf("%s >> %s %s %d %d",
header.Src, header.Dst,
ipProtocol(waterutil.IPProtocol(header.NextHeader)),
header.PayloadLen, header.TrafficClass)
src, dst = header.Src, header.Dst
} else {
h.logger.Warn("unknown packet, discarded")
log.Warn("unknown packet, discarded")
return nil
}
@ -239,11 +239,11 @@ func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr
addr := h.findRouteFor(dst)
if addr == nil {
h.logger.Warnf("no route for %s -> %s", src, dst)
log.Warnf("no route for %s -> %s", src, dst)
return nil
}
h.logger.Debugf("find route: %s -> %s", dst, addr)
log.Debugf("find route: %s -> %s", dst, addr)
if _, err := conn.WriteTo((*b)[:n], addr); err != nil {
return err
@ -274,11 +274,11 @@ func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr
if waterutil.IsIPv4((*b)[:n]) {
header, err := ipv4.ParseHeader((*b)[:n])
if err != nil {
h.logger.Warn(err)
log.Warn(err)
return nil
}
h.logger.Debugf("%s >> %s %-4s %d/%-4d %-4x %d",
log.Debugf("%s >> %s %-4s %d/%-4d %-4x %d",
header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])),
header.Len, header.TotalLen, header.ID, header.Flags)
@ -286,18 +286,18 @@ func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr
} else if waterutil.IsIPv6((*b)[:n]) {
header, err := ipv6.ParseHeader((*b)[:n])
if err != nil {
h.logger.Warn(err)
log.Warn(err)
return nil
}
h.logger.Debugf("%s > %s %s %d %d",
log.Debugf("%s > %s %s %d %d",
header.Src, header.Dst,
ipProtocol(waterutil.IPProtocol(header.NextHeader)),
header.PayloadLen, header.TrafficClass)
src, dst = header.Src, header.Dst
} else {
h.logger.Warn("unknown packet, discarded")
log.Warn("unknown packet, discarded")
return nil
}
@ -310,16 +310,16 @@ func (h *tunHandler) transport(tun net.Conn, conn net.PacketConn, raddr net.Addr
rkey := ipToTunRouteKey(src)
if actual, loaded := h.routes.LoadOrStore(rkey, addr); loaded {
if actual.(net.Addr).String() != addr.String() {
h.logger.Debugf("update route: %s -> %s (old %s)",
log.Debugf("update route: %s -> %s (old %s)",
src, addr, actual.(net.Addr))
h.routes.Store(rkey, addr)
}
} else {
h.logger.Warnf("no route for %s -> %s", src, addr)
log.Warnf("no route for %s -> %s", src, addr)
}
if addr := h.findRouteFor(dst); addr != nil {
h.logger.Debugf("find route: %s -> %s", dst, addr)
log.Debugf("find route: %s -> %s", dst, addr)
_, err := conn.WriteTo((*b)[:n], addr)
return err

View File

@ -0,0 +1,118 @@
package sshd
import (
"context"
"errors"
"net"
"time"
"golang.org/x/crypto/ssh"
)
type DirectForwardConn struct {
conn ssh.Conn
channel ssh.Channel
dstAddr string
}
func NewDirectForwardConn(conn ssh.Conn, channel ssh.Channel, dstAddr string) net.Conn {
return &DirectForwardConn{
conn: conn,
channel: channel,
dstAddr: dstAddr,
}
}
func (c *DirectForwardConn) Read(b []byte) (n int, err error) {
return c.channel.Read(b)
}
func (c *DirectForwardConn) Write(b []byte) (n int, err error) {
return c.channel.Write(b)
}
func (c *DirectForwardConn) Close() error {
return c.channel.Close()
}
func (c *DirectForwardConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *DirectForwardConn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
func (c *DirectForwardConn) SetDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *DirectForwardConn) SetReadDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *DirectForwardConn) SetWriteDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *DirectForwardConn) DstAddr() string {
return c.dstAddr
}
type RemoteForwardConn struct {
ctx context.Context
conn ssh.Conn
req *ssh.Request
}
func NewRemoteForwardConn(ctx context.Context, conn ssh.Conn, req *ssh.Request) net.Conn {
return &RemoteForwardConn{
ctx: ctx,
conn: conn,
req: req,
}
}
func (c *RemoteForwardConn) Conn() ssh.Conn {
return c.conn
}
func (c *RemoteForwardConn) Request() *ssh.Request {
return c.req
}
func (c *RemoteForwardConn) Read(b []byte) (n int, err error) {
return 0, &net.OpError{Op: "read", Net: "nop", Source: nil, Addr: nil, Err: errors.New("read not supported")}
}
func (c *RemoteForwardConn) Write(b []byte) (n int, err error) {
return 0, &net.OpError{Op: "write", Net: "nop", Source: nil, Addr: nil, Err: errors.New("write not supported")}
}
func (c *RemoteForwardConn) Close() error {
return &net.OpError{Op: "close", Net: "nop", Source: nil, Addr: nil, Err: errors.New("close not supported")}
}
func (c *RemoteForwardConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *RemoteForwardConn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
func (c *RemoteForwardConn) SetDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *RemoteForwardConn) SetReadDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *RemoteForwardConn) SetWriteDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *RemoteForwardConn) Done() <-chan struct{} {
return c.ctx.Done()
}

View File

@ -15,6 +15,7 @@ import (
func init() {
registry.RegisterListener("http3", NewListener)
registry.RegisterListener("h3", NewListener)
}
type phtListener struct {

View File

@ -3,6 +3,7 @@ package ssh
import (
"fmt"
"net"
"time"
auth_util "github.com/go-gost/gost/pkg/common/util/auth"
ssh_util "github.com/go-gost/gost/pkg/internal/util/ssh"
@ -29,13 +30,14 @@ type sshListener struct {
}
func NewListener(opts ...listener.Option) listener.Listener {
options := &listener.Options{}
options := listener.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &sshListener{
addr: options.Addr,
logger: options.Logger,
options: options,
}
}
@ -96,6 +98,14 @@ func (l *sshListener) listenLoop() {
}
func (l *sshListener) serveConn(conn net.Conn) {
start := time.Now()
l.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
l.logger.WithFields(map[string]interface{}{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
sc, chans, reqs, err := ssh.NewServerConn(conn, l.config)
if err != nil {
l.logger.Error(err)
@ -122,8 +132,9 @@ func (l *sshListener) serveConn(conn net.Conn) {
select {
case l.cqueue <- cc:
default:
cc.Close()
l.logger.Warnf("connection queue is full, client %s discarded", conn.RemoteAddr())
newChannel.Reject(ssh.ResourceShortage, "connection queue is full")
cc.Close()
}
default:

View File

@ -0,0 +1,199 @@
package ssh
import (
"context"
"fmt"
"net"
"strconv"
"time"
auth_util "github.com/go-gost/gost/pkg/common/util/auth"
ssh_util "github.com/go-gost/gost/pkg/internal/util/ssh"
sshd_util "github.com/go-gost/gost/pkg/internal/util/sshd"
"github.com/go-gost/gost/pkg/listener"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
"golang.org/x/crypto/ssh"
)
// Applicable SSH Request types for Port Forwarding - RFC 4254 7.X
const (
DirectForwardRequest = "direct-tcpip" // RFC 4254 7.2
RemoteForwardRequest = "tcpip-forward" // RFC 4254 7.1
)
func init() {
registry.RegisterListener("sshd", NewListener)
}
type sshdListener struct {
addr string
net.Listener
config *ssh.ServerConfig
cqueue chan net.Conn
errChan chan error
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &sshdListener{
addr: options.Addr,
logger: options.Logger,
options: options,
}
}
func (l *sshdListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
}
ln, err := net.Listen("tcp", l.addr)
if err != nil {
return err
}
l.Listener = ln
authenticator := auth_util.AuthFromUsers(l.options.Auths...)
config := &ssh.ServerConfig{
PasswordCallback: ssh_util.PasswordCallback(authenticator),
PublicKeyCallback: ssh_util.PublicKeyCallback(l.md.authorizedKeys),
}
config.AddHostKey(l.md.signer)
if authenticator == nil && len(l.md.authorizedKeys) == 0 {
config.NoClientAuth = true
}
l.config = config
l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1)
go l.listenLoop()
return
}
func (l *sshdListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.cqueue:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *sshdListener) listenLoop() {
for {
conn, err := l.Listener.Accept()
if err != nil {
l.logger.Error("accept:", err)
l.errChan <- err
close(l.errChan)
return
}
go l.serveConn(conn)
}
}
func (l *sshdListener) serveConn(conn net.Conn) {
start := time.Now()
l.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
l.logger.WithFields(map[string]interface{}{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
sc, chans, reqs, err := ssh.NewServerConn(conn, l.config)
if err != nil {
l.logger.Error(err)
conn.Close()
return
}
defer sc.Close()
go func() {
for newChannel := range chans {
// Check the type of channel
t := newChannel.ChannelType()
switch t {
case DirectForwardRequest:
channel, requests, err := newChannel.Accept()
if err != nil {
l.logger.Warnf("could not accept channel: %s", err.Error())
continue
}
p := directForward{}
ssh.Unmarshal(newChannel.ExtraData(), &p)
l.logger.Debug(p.String())
if p.Host1 == "<nil>" {
p.Host1 = ""
}
go ssh.DiscardRequests(requests)
cc := sshd_util.NewDirectForwardConn(sc, channel, net.JoinHostPort(p.Host1, strconv.Itoa(int(p.Port1))))
select {
case l.cqueue <- cc:
default:
l.logger.Warnf("connection queue is full, client %s discarded", conn.RemoteAddr())
newChannel.Reject(ssh.ResourceShortage, "connection queue is full")
cc.Close()
}
default:
l.logger.Warnf("unsupported channel type: %s", t)
newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unsupported channel type: %s", t))
}
}
}()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
for req := range reqs {
switch req.Type {
case RemoteForwardRequest:
cc := sshd_util.NewRemoteForwardConn(ctx, sc, req)
select {
case l.cqueue <- cc:
default:
l.logger.Warnf("connection queue is full, client %s discarded", conn.RemoteAddr())
req.Reply(false, []byte("connection queue is full"))
cc.Close()
}
default:
l.logger.Warnf("unsupported request type: %s, want reply: %v", req.Type, req.WantReply)
req.Reply(false, nil)
}
}
}()
sc.Wait()
}
// directForward is structure for RFC 4254 7.2 - can be used for "forwarded-tcpip" and "direct-tcpip"
type directForward struct {
Host1 string
Port1 uint32
Host2 string
Port2 uint32
}
func (p directForward) String() string {
return fmt.Sprintf("%s:%d -> %s:%d", p.Host2, p.Port2, p.Host1, p.Port1)
}

View File

@ -9,16 +9,22 @@ import (
"golang.org/x/crypto/ssh"
)
const (
defaultBacklog = 128
)
type metadata struct {
signer ssh.Signer
authorizedKeys map[string]bool
backlog int
}
func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) {
func (l *sshdListener) parseMetadata(md mdata.Metadata) (err error) {
const (
authorizedKeys = "authorizedKeys"
privateKeyFile = "privateKeyFile"
passphrase = "passphrase"
backlog = "backlog"
)
if key := mdata.GetString(md, privateKeyFile); key != "" {
@ -29,20 +35,20 @@ func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) {
pp := mdata.GetString(md, passphrase)
if pp == "" {
h.md.signer, err = ssh.ParsePrivateKey(data)
l.md.signer, err = ssh.ParsePrivateKey(data)
} else {
h.md.signer, err = ssh.ParsePrivateKeyWithPassphrase(data, []byte(pp))
l.md.signer, err = ssh.ParsePrivateKeyWithPassphrase(data, []byte(pp))
}
if err != nil {
return err
}
}
if h.md.signer == nil {
if l.md.signer == nil {
signer, err := ssh.NewSignerFromKey(tls_util.DefaultConfig.Clone().Certificates[0].PrivateKey)
if err != nil {
return err
}
h.md.signer = signer
l.md.signer = signer
}
if name := mdata.GetString(md, authorizedKeys); name != "" {
@ -50,7 +56,12 @@ func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) {
if err != nil {
return err
}
h.md.authorizedKeys = m
l.md.authorizedKeys = m
}
l.md.backlog = mdata.GetInt(md, backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
return