get real client ip for http forwarding

This commit is contained in:
ginuerzh 2023-10-18 21:22:44 +08:00
parent 54b56df214
commit f2fd6554ad
2 changed files with 89 additions and 6 deletions

View File

@ -10,6 +10,8 @@ import (
"net"
"net/http"
"net/http/httputil"
"strconv"
"strings"
"time"
"github.com/go-gost/core/chain"
@ -91,8 +93,6 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
network = "udp"
}
ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(conn.RemoteAddr().String()))
var rw io.ReadWriter = conn
var host string
var protocol string
@ -108,7 +108,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
}
if protocol == forward.ProtoHTTP {
h.handleHTTP(ctx, rw, log)
h.handleHTTP(ctx, rw, conn.RemoteAddr(), log)
return nil
}
@ -116,6 +116,8 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
host = net.JoinHostPort(host, "0")
}
ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(conn.RemoteAddr().String()))
var target *chain.Node
if host != "" {
target = &chain.Node{
@ -179,7 +181,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
return nil
}
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log logger.Logger) (err error) {
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remoteAddr net.Addr, log logger.Logger) (err error) {
br := bufio.NewReader(rw)
var cc net.Conn
@ -198,6 +200,15 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
return err
}
if addr := getRealClientAddr(req, remoteAddr); addr != remoteAddr {
log = log.WithFields(map[string]any{
"src": addr.String(),
})
remoteAddr = addr
}
ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(remoteAddr.String()))
target := &chain.Node{
Addr: req.Host,
}
@ -326,3 +337,34 @@ func (h *forwardHandler) checkRateLimit(addr net.Addr) bool {
return true
}
func getRealClientAddr(req *http.Request, raddr net.Addr) net.Addr {
if req == nil {
return nil
}
// cloudflare CDN
sip := req.Header.Get("CF-Connecting-IP")
if sip == "" {
ss := strings.Split(req.Header.Get("X-Forwarded-For"), ",")
if len(ss) > 0 && ss[0] != "" {
sip = ss[0]
}
}
if sip == "" {
sip = req.Header.Get("X-Real-Ip")
}
ip := net.ParseIP(sip)
if ip == nil {
return raddr
}
_, sp, _ := net.SplitHostPort(raddr.String())
port, _ := strconv.Atoi(sp)
return &net.TCPAddr{
IP: ip,
Port: port,
}
}

View File

@ -11,6 +11,7 @@ import (
"net/http"
"net/http/httputil"
"strconv"
"strings"
"time"
"github.com/go-gost/core/chain"
@ -93,8 +94,6 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
network = "udp"
}
ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(conn.RemoteAddr().String()))
localAddr := convertAddr(conn.LocalAddr())
var rw io.ReadWriter = conn
@ -115,6 +114,8 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
return nil
}
ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(conn.RemoteAddr().String()))
if md, ok := conn.(mdata.Metadatable); ok {
if v := mdutil.GetString(md.Metadata(), "host"); v != "" {
host = v
@ -200,6 +201,15 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
return err
}
if addr := getRealClientAddr(req, remoteAddr); addr != remoteAddr {
log = log.WithFields(map[string]any{
"src": addr.String(),
})
remoteAddr = addr
}
ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(remoteAddr.String()))
target := &chain.Node{
Addr: req.Host,
}
@ -354,3 +364,34 @@ func convertAddr(addr net.Addr) net.Addr {
}
}
}
func getRealClientAddr(req *http.Request, raddr net.Addr) net.Addr {
if req == nil {
return nil
}
// cloudflare CDN
sip := req.Header.Get("CF-Connecting-IP")
if sip == "" {
ss := strings.Split(req.Header.Get("X-Forwarded-For"), ",")
if len(ss) > 0 && ss[0] != "" {
sip = ss[0]
}
}
if sip == "" {
sip = req.Header.Get("X-Real-Ip")
}
ip := net.ParseIP(sip)
if ip == nil {
return raddr
}
_, sp, _ := net.SplitHostPort(raddr.String())
port, _ := strconv.Atoi(sp)
return &net.TCPAddr{
IP: ip,
Port: port,
}
}