add network for bypass

This commit is contained in:
ginuerzh
2023-09-30 17:51:55 +08:00
parent ea585fc25d
commit 836cf6eade
24 changed files with 92 additions and 160 deletions

View File

@ -199,7 +199,7 @@ func (h *dnsHandler) request(ctx context.Context, msg []byte, log logger.Logger)
}
if h.options.Bypass != nil && mq.Question[0].Qclass == dns.ClassINET {
if h.options.Bypass.Contains(context.Background(), strings.Trim(mq.Question[0].Name, ".")) {
if h.options.Bypass.Contains(context.Background(), "udp", strings.Trim(mq.Question[0].Name, ".")) {
log.Debug("bypass: ", mq.Question[0].Name)
mr = (&dns.Msg{}).SetReply(&mq)
b := bufpool.Get(h.md.bufferSize)

View File

@ -154,7 +154,7 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
}
ctx = auth_util.ContextWithID(ctx, auth_util.ID(id))
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, addr) {
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, network, addr) {
resp.StatusCode = http.StatusForbidden
if log.IsLevelEnabled(logger.TraceLevel) {

View File

@ -155,7 +155,7 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req
}
ctx = auth_util.ContextWithID(ctx, auth_util.ID(id))
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, addr) {
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, "tcp", addr) {
w.WriteHeader(http.StatusForbidden)
log.Debug("bypass: ", addr)
return nil

View File

@ -10,8 +10,8 @@ import (
"time"
"github.com/go-gost/core/chain"
"github.com/go-gost/core/hop"
"github.com/go-gost/core/handler"
"github.com/go-gost/core/hop"
"github.com/go-gost/core/logger"
md "github.com/go-gost/core/metadata"
sx "github.com/go-gost/x/internal/util/selector"
@ -106,7 +106,7 @@ func (h *http3Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req
w.Header().Set(k, h.md.header.Get(k))
}
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, addr) {
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, "udp", addr) {
w.WriteHeader(http.StatusForbidden)
log.Debug("bypass: ", addr)
return nil

View File

@ -122,7 +122,7 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han
log.Debugf("%s >> %s", conn.RemoteAddr(), dstAddr)
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, dstAddr.String()) {
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, dstAddr.Network(), dstAddr.String()) {
log.Debug("bypass: ", dstAddr)
return nil
}
@ -163,7 +163,7 @@ func (h *redirectHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, radd
"host": host,
})
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, host) {
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, "tcp", host) {
log.Debug("bypass: ", host)
return nil
}
@ -232,7 +232,7 @@ func (h *redirectHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, rad
"host": host,
})
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, host) {
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, "tcp", host) {
log.Debug("bypass: ", host)
return nil
}

View File

@ -75,7 +75,7 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han
log.Debugf("%s >> %s", conn.RemoteAddr(), dstAddr)
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, dstAddr.String()) {
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, dstAddr.Network(), dstAddr.String()) {
log.Debug("bypass: ", dstAddr)
return nil
}

View File

@ -44,7 +44,7 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network
return
}
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, address) {
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, network, address) {
log.Debug("bypass: ", address)
resp.Status = relay.StatusForbidden
_, err = resp.WriteTo(conn)
@ -131,7 +131,7 @@ func (h *relayHandler) handleConnectTunnel(ctx context.Context, conn net.Conn, n
host, sp, _ := net.SplitHostPort(address)
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, address) {
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, network, address) {
log.Debug("bypass: ", address)
resp.Status = relay.StatusForbidden
_, err := resp.WriteTo(conn)

View File

@ -115,7 +115,7 @@ func (h *sniHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, raddr net
"host": host,
})
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, host) {
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, "tcp", host) {
log.Debug("bypass: ", host)
return nil
}
@ -183,7 +183,7 @@ func (h *sniHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, raddr ne
})
log.Debugf("%s >> %s", raddr, host)
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, host) {
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, "tcp", host) {
log.Debug("bypass: ", host)
return nil
}

View File

@ -123,7 +123,7 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g
})
log.Debugf("%s >> %s", conn.RemoteAddr(), addr)
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, addr) {
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, "tcp", addr) {
resp := gosocks4.NewReply(gosocks4.Rejected, nil)
log.Trace(resp)
log.Debug("bypass: ", addr)

View File

@ -19,7 +19,7 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, networ
})
log.Debugf("%s >> %s", conn.RemoteAddr(), address)
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, address) {
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, network, address) {
resp := gosocks5.NewReply(gosocks5.NotAllowed, nil)
log.Trace(resp)
log.Debug("bypass: ", address)

View File

@ -1,99 +0,0 @@
package v5
import (
"context"
"crypto/tls"
"net"
"github.com/go-gost/core/auth"
"github.com/go-gost/core/logger"
"github.com/go-gost/gosocks5"
auth_util "github.com/go-gost/x/internal/util/auth"
"github.com/go-gost/x/internal/util/socks"
)
type serverSelector struct {
methods []uint8
Authenticator auth.Authenticator
TLSConfig *tls.Config
logger logger.Logger
noTLS bool
}
func (selector *serverSelector) Methods() []uint8 {
return selector.methods
}
func (s *serverSelector) Select(methods ...uint8) (method uint8) {
s.logger.Debugf("%d %d %v", gosocks5.Ver5, len(methods), methods)
method = gosocks5.MethodNoAuth
for _, m := range methods {
if m == socks.MethodTLS && !s.noTLS {
method = m
break
}
}
// when Authenticator is set, auth is mandatory
if s.Authenticator != nil {
if method == gosocks5.MethodNoAuth {
method = gosocks5.MethodUserPass
}
if method == socks.MethodTLS && !s.noTLS {
method = socks.MethodTLSAuth
}
}
return
}
func (s *serverSelector) OnSelected(method uint8, conn net.Conn) (string, net.Conn, error) {
s.logger.Debugf("%d %d", gosocks5.Ver5, method)
switch method {
case socks.MethodTLS:
conn = tls.Server(conn, s.TLSConfig)
return "", conn, nil
case gosocks5.MethodUserPass, socks.MethodTLSAuth:
if method == socks.MethodTLSAuth {
conn = tls.Server(conn, s.TLSConfig)
}
req, err := gosocks5.ReadUserPassRequest(conn)
if err != nil {
s.logger.Error(err)
return "", nil, err
}
s.logger.Trace(req)
var id string
if s.Authenticator != nil {
var ok bool
ctx := auth_util.ContextWithClientAddr(context.Background(), auth_util.ClientAddr(conn.RemoteAddr().String()))
id, ok = s.Authenticator.Authenticate(ctx, req.Username, req.Password)
if !ok {
resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Failure)
if err := resp.Write(conn); err != nil {
s.logger.Error(err)
return "", nil, err
}
s.logger.Info(resp)
return "", nil, gosocks5.ErrAuthFailure
}
}
resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Succeeded)
s.logger.Trace(resp)
if err := resp.Write(conn); err != nil {
s.logger.Error(err)
return "", nil, err
}
return id, conn, nil
case gosocks5.MethodNoAcceptable:
return "", nil, gosocks5.ErrBadMethod
default:
return "", nil, gosocks5.ErrBadFormat
}
}

View File

@ -101,7 +101,7 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.H
log.Debugf("%s >> %s", conn.RemoteAddr(), addr)
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, addr.String()) {
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, "tcp", addr.String()) {
log.Debug("bypass: ", addr.String())
return nil
}

View File

@ -135,7 +135,7 @@ func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn, log logger.Logger) (er
return err
}
if h.options.Bypass != nil && h.options.Bypass.Contains(context.Background(), addr.String()) {
if h.options.Bypass != nil && h.options.Bypass.Contains(context.Background(), addr.Network(), addr.String()) {
log.Warn("bypass: ", addr)
return nil
}
@ -167,7 +167,7 @@ func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn, log logger.Logger) (er
return err
}
if h.options.Bypass != nil && h.options.Bypass.Contains(context.Background(), raddr.String()) {
if h.options.Bypass != nil && h.options.Bypass.Contains(context.Background(), raddr.Network(), raddr.String()) {
log.Warn("bypass: ", raddr)
return nil
}

View File

@ -92,7 +92,7 @@ func (h *forwardHandler) handleDirectForward(ctx context.Context, conn *sshd_uti
log.Debugf("%s >> %s", conn.RemoteAddr(), targetAddr)
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, targetAddr) {
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, "tcp", targetAddr) {
log.Debugf("bypass %s", targetAddr)
return nil
}