add admission

This commit is contained in:
ginuerzh 2022-02-17 23:30:13 +08:00
parent 5daefb8e3c
commit 307a90c20e
22 changed files with 668 additions and 160 deletions

View File

@ -173,6 +173,24 @@ func buildConfigFromCmd(services, nodes stringList) (*config.Config, error) {
service.Handler.Retries = v
md.Del("retries")
}
if v := metadata.GetString(md, "admission"); v != "" {
admCfg := &config.AdmissionConfig{
Name: fmt.Sprintf("admission-%d", len(cfg.Admissions)),
}
if v[0] == '~' {
admCfg.Reverse = true
v = v[1:]
}
for _, s := range strings.Split(v, ",") {
if s == "" {
continue
}
admCfg.Matchers = append(admCfg.Matchers, s)
}
service.Admission = admCfg.Name
cfg.Admissions = append(cfg.Admissions, admCfg)
md.Del("admission")
}
if v := metadata.GetString(md, "bypass"); v != "" {
bypassCfg := &config.BypassConfig{
Name: fmt.Sprintf("bypass-%d", len(cfg.Bypasses)),

View File

@ -12,7 +12,7 @@ import (
"github.com/go-gost/gost/pkg/service"
)
func buildService(cfg *config.Config) (services []service.Servicer) {
func buildService(cfg *config.Config) (services []service.Service) {
if cfg == nil {
return
}
@ -25,6 +25,14 @@ func buildService(cfg *config.Config) (services []service.Servicer) {
}
}
for _, admissionCfg := range cfg.Admissions {
if adm := parsing.ParseAdmission(admissionCfg); adm != nil {
if err := registry.Admission().Register(admissionCfg.Name, adm); err != nil {
log.Fatal(err)
}
}
}
for _, bypassCfg := range cfg.Bypasses {
if bp := parsing.ParseBypass(bypassCfg); bp != nil {
if err := registry.Bypass().Register(bypassCfg.Name, bp); err != nil {

View File

@ -0,0 +1,85 @@
package admission
import (
"net"
"strconv"
"github.com/go-gost/gost/pkg/common/matcher"
"github.com/go-gost/gost/pkg/logger"
)
type Admission interface {
Admit(addr string) bool
}
type options struct {
logger logger.Logger
}
type Option func(opts *options)
func LoggerOption(logger logger.Logger) Option {
return func(opts *options) {
opts.logger = logger
}
}
type admission struct {
matchers []matcher.Matcher
reversed bool
options options
}
// NewAdmission creates and initializes a new Admission using matchers as its match rules.
// The rules will be reversed if the reversed is true.
func NewAdmission(reversed bool, matchers []matcher.Matcher, opts ...Option) Admission {
options := options{}
for _, opt := range opts {
opt(&options)
}
return &admission{
matchers: matchers,
reversed: reversed,
options: options,
}
}
// NewAdmissionPatterns creates and initializes a new Admission using matcher patterns as its match rules.
// The rules will be reversed if the reverse is true.
func NewAdmissionPatterns(reversed bool, patterns []string, opts ...Option) Admission {
var matchers []matcher.Matcher
for _, pattern := range patterns {
if m := matcher.NewMatcher(pattern); m != nil {
matchers = append(matchers, m)
}
}
return NewAdmission(reversed, matchers, opts...)
}
func (p *admission) Admit(addr string) bool {
if addr == "" || p == nil || len(p.matchers) == 0 {
return false
}
// try to strip the port
if host, port, _ := net.SplitHostPort(addr); host != "" && port != "" {
if p, _ := strconv.Atoi(port); p > 0 { // port is valid
addr = host
}
}
var matched bool
for _, matcher := range p.matchers {
if matcher == nil {
continue
}
if matcher.Match(addr) {
matched = true
break
}
}
b := !p.reversed && matched ||
p.reversed && !matched
return b
}

166
pkg/api/config_admission.go Normal file
View File

@ -0,0 +1,166 @@
package api
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/go-gost/gost/pkg/config"
"github.com/go-gost/gost/pkg/config/parsing"
"github.com/go-gost/gost/pkg/registry"
)
// swagger:parameters createAdmissionRequest
type createAdmissionRequest struct {
// in: body
Data config.AdmissionConfig `json:"data"`
}
// successful operation.
// swagger:response createAdmissionResponse
type createAdmissionResponse struct {
Data Response
}
func createAdmission(ctx *gin.Context) {
// swagger:route POST /config/admissions ConfigManagement createAdmissionRequest
//
// Create a new admission, the name of admission must be unique in admission list.
//
// Security:
// basicAuth: []
//
// Responses:
// 200: createAdmissionResponse
var req createAdmissionRequest
ctx.ShouldBindJSON(&req.Data)
if req.Data.Name == "" {
writeError(ctx, ErrInvalid)
return
}
v := parsing.ParseAdmission(&req.Data)
if err := registry.Admission().Register(req.Data.Name, v); err != nil {
writeError(ctx, ErrDup)
return
}
cfg := config.Global()
cfg.Admissions = append(cfg.Admissions, &req.Data)
config.SetGlobal(cfg)
ctx.JSON(http.StatusOK, Response{
Msg: "OK",
})
}
// swagger:parameters updateAdmissionRequest
type updateAdmissionRequest struct {
// in: path
// required: true
Admission string `uri:"admission" json:"admission"`
// in: body
Data config.AdmissionConfig `json:"data"`
}
// successful operation.
// swagger:response updateAdmissionResponse
type updateAdmissionResponse struct {
Data Response
}
func updateAdmission(ctx *gin.Context) {
// swagger:route PUT /config/admissions/{admission} ConfigManagement updateAdmissionRequest
//
// Update admission by name, the admission must already exist.
//
// Security:
// basicAuth: []
//
// Responses:
// 200: updateAdmissionResponse
var req updateAdmissionRequest
ctx.ShouldBindUri(&req)
ctx.ShouldBindJSON(&req.Data)
if !registry.Admission().IsRegistered(req.Admission) {
writeError(ctx, ErrNotFound)
return
}
req.Data.Name = req.Admission
v := parsing.ParseAdmission(&req.Data)
registry.Admission().Unregister(req.Admission)
if err := registry.Admission().Register(req.Admission, v); err != nil {
writeError(ctx, ErrDup)
return
}
cfg := config.Global()
for i := range cfg.Admissions {
if cfg.Admissions[i].Name == req.Admission {
cfg.Admissions[i] = &req.Data
break
}
}
config.SetGlobal(cfg)
ctx.JSON(http.StatusOK, Response{
Msg: "OK",
})
}
// swagger:parameters deleteAdmissionRequest
type deleteAdmissionRequest struct {
// in: path
// required: true
Admission string `uri:"admission" json:"admission"`
}
// successful operation.
// swagger:response deleteAdmissionResponse
type deleteAdmissionResponse struct {
Data Response
}
func deleteAdmission(ctx *gin.Context) {
// swagger:route DELETE /config/admissions/{admission} ConfigManagement deleteAdmissionRequest
//
// Delete admission by name.
//
// Security:
// basicAuth: []
//
// Responses:
// 200: deleteAdmissionResponse
var req deleteAdmissionRequest
ctx.ShouldBindUri(&req)
if !registry.Admission().IsRegistered(req.Admission) {
writeError(ctx, ErrNotFound)
return
}
registry.Admission().Unregister(req.Admission)
cfg := config.Global()
admissiones := cfg.Admissions
cfg.Admissions = nil
for _, s := range admissiones {
if s.Name == req.Admission {
continue
}
cfg.Admissions = append(cfg.Admissions, s)
}
config.SetGlobal(cfg)
ctx.JSON(http.StatusOK, Response{
Msg: "OK",
})
}

View File

@ -141,8 +141,7 @@ func deleteAuther(ctx *gin.Context) {
var req deleteAutherRequest
ctx.ShouldBindUri(&req)
svc := registry.Auther().Get(req.Auther)
if svc == nil {
if !registry.Auther().IsRegistered(req.Auther) {
writeError(ctx, ErrNotFound)
return
}

View File

@ -143,8 +143,7 @@ func deleteBypass(ctx *gin.Context) {
var req deleteBypassRequest
ctx.ShouldBindUri(&req)
svc := registry.Bypass().Get(req.Bypass)
if svc == nil {
if !registry.Bypass().IsRegistered(req.Bypass) {
writeError(ctx, ErrNotFound)
return
}

View File

@ -152,8 +152,7 @@ func deleteChain(ctx *gin.Context) {
var req deleteChainRequest
ctx.ShouldBindUri(&req)
svc := registry.Chain().Get(req.Chain)
if svc == nil {
if !registry.Chain().IsRegistered(req.Chain) {
writeError(ctx, ErrNotFound)
return
}

View File

@ -143,8 +143,7 @@ func deleteHosts(ctx *gin.Context) {
var req deleteHostsRequest
ctx.ShouldBindUri(&req)
svc := registry.Hosts().Get(req.Hosts)
if svc == nil {
if !registry.Hosts().IsRegistered(req.Hosts) {
writeError(ctx, ErrNotFound)
return
}

View File

@ -151,8 +151,7 @@ func deleteResolver(ctx *gin.Context) {
var req deleteResolverRequest
ctx.ShouldBindUri(&req)
svc := registry.Resolver().Get(req.Resolver)
if svc == nil {
if !registry.Resolver().IsRegistered(req.Resolver) {
writeError(ctx, ErrNotFound)
return
}

View File

@ -8,10 +8,10 @@ import (
)
var (
ErrInvalid = &Error{statusCode: http.StatusBadRequest, Code: 40001, Msg: "instance invalid"}
ErrDup = &Error{statusCode: http.StatusBadRequest, Code: 40002, Msg: "instance duplicated"}
ErrCreate = &Error{statusCode: http.StatusConflict, Code: 40003, Msg: "instance creation failed"}
ErrNotFound = &Error{statusCode: http.StatusBadRequest, Code: 40004, Msg: "instance not found"}
ErrInvalid = &Error{statusCode: http.StatusBadRequest, Code: 40001, Msg: "object invalid"}
ErrDup = &Error{statusCode: http.StatusBadRequest, Code: 40002, Msg: "object duplicated"}
ErrCreate = &Error{statusCode: http.StatusConflict, Code: 40003, Msg: "object creation failed"}
ErrNotFound = &Error{statusCode: http.StatusBadRequest, Code: 40004, Msg: "object not found"}
ErrSave = &Error{statusCode: http.StatusInternalServerError, Code: 40005, Msg: "save config failed"}
)

View File

@ -113,6 +113,10 @@ func registerConfig(config *gin.RouterGroup) {
config.PUT("/authers/:auther", updateAuther)
config.DELETE("/authers/:auther", deleteAuther)
config.POST("/admissions", createAdmission)
config.PUT("/admissions/:admission", updateAdmission)
config.DELETE("/admissions/:admission", deleteAdmission)
config.POST("/bypasses", createBypass)
config.PUT("/bypasses/:bypass", updateBypass)
config.DELETE("/bypasses/:bypass", deleteBypass)

View File

@ -20,6 +20,25 @@ definitions:
x-go-name: PathPrefix
type: object
x-go-package: github.com/go-gost/gost/pkg/config
AdmissionConfig:
properties:
matchers:
items:
type: string
type: array
x-go-name: Matchers
name:
type: string
x-go-name: Name
reverse:
type: boolean
x-go-name: Reverse
type:
description: inline, file, etc.
type: string
x-go-name: Type
type: object
x-go-package: github.com/go-gost/gost/pkg/config
AuthConfig:
properties:
password:
@ -81,6 +100,11 @@ definitions:
x-go-package: github.com/go-gost/gost/pkg/config
Config:
properties:
admissions:
items:
$ref: '#/definitions/AdmissionConfig'
type: array
x-go-name: Admissions
api:
$ref: '#/definitions/APIConfig'
authers:
@ -403,6 +427,9 @@ definitions:
addr:
type: string
x-go-name: Addr
admission:
type: string
x-go-name: Admission
bypass:
type: string
x-go-name: Bypass
@ -481,6 +508,65 @@ paths:
summary: Save current config to file (gost.yaml or gost.json).
tags:
- ConfigManagement
/config/admissions:
post:
operationId: createAdmissionRequest
parameters:
- in: body
name: data
schema:
$ref: '#/definitions/AdmissionConfig'
x-go-name: Data
responses:
"200":
$ref: '#/responses/createAdmissionResponse'
security:
- basicAuth:
- '[]'
summary: Create a new admission, the name of admission must be unique in admission
list.
tags:
- ConfigManagement
/config/admissions/{admission}:
delete:
operationId: deleteAdmissionRequest
parameters:
- in: path
name: admission
required: true
type: string
x-go-name: Admission
responses:
"200":
$ref: '#/responses/deleteAdmissionResponse'
security:
- basicAuth:
- '[]'
summary: Delete admission by name.
tags:
- ConfigManagement
put:
operationId: updateAdmissionRequest
parameters:
- in: path
name: admission
required: true
type: string
x-go-name: Admission
- in: body
name: data
schema:
$ref: '#/definitions/AdmissionConfig'
x-go-name: Data
responses:
"200":
$ref: '#/responses/updateAdmissionResponse'
security:
- basicAuth:
- '[]'
summary: Update admission by name, the admission must already exist.
tags:
- ConfigManagement
/config/authers:
post:
operationId: createAutherRequest
@ -835,6 +921,12 @@ paths:
produces:
- application/json
responses:
createAdmissionResponse:
description: successful operation.
headers:
Data: {}
schema:
$ref: '#/definitions/Response'
createAutherResponse:
description: successful operation.
headers:
@ -871,6 +963,12 @@ responses:
Data: {}
schema:
$ref: '#/definitions/Response'
deleteAdmissionResponse:
description: successful operation.
headers:
Data: {}
schema:
$ref: '#/definitions/Response'
deleteAutherResponse:
description: successful operation.
headers:
@ -919,6 +1017,12 @@ responses:
Data: {}
schema:
$ref: '#/definitions/Response'
updateAdmissionResponse:
description: successful operation.
headers:
Data: {}
schema:
$ref: '#/definitions/Response'
updateAutherResponse:
description: successful operation.
headers:

View File

@ -3,131 +3,39 @@ package bypass
import (
"net"
"strconv"
"strings"
"github.com/go-gost/gost/pkg/common/matcher"
"github.com/go-gost/gost/pkg/logger"
glob "github.com/gobwas/glob"
)
// Matcher is a generic pattern matcher,
// it gives the match result of the given pattern for specific v.
type Matcher interface {
Match(v string) bool
}
// NewMatcher creates a Matcher for the given pattern.
// The acutal Matcher depends on the pattern:
// IP Matcher if pattern is a valid IP address.
// CIDR Matcher if pattern is a valid CIDR address.
// Domain Matcher if both of the above are not.
func NewMatcher(pattern string) Matcher {
if pattern == "" {
return nil
}
if ip := net.ParseIP(pattern); ip != nil {
return IPMatcher(ip)
}
if _, inet, err := net.ParseCIDR(pattern); err == nil {
return CIDRMatcher(inet)
}
return DomainMatcher(pattern)
}
type ipMatcher struct {
ip net.IP
}
// IPMatcher creates a Matcher for a specific IP address.
func IPMatcher(ip net.IP) Matcher {
return &ipMatcher{
ip: ip,
}
}
func (m *ipMatcher) Match(ip string) bool {
if m == nil {
return false
}
return m.ip.Equal(net.ParseIP(ip))
}
type cidrMatcher struct {
ipNet *net.IPNet
}
// CIDRMatcher creates a Matcher for a specific CIDR notation IP address.
func CIDRMatcher(inet *net.IPNet) Matcher {
return &cidrMatcher{
ipNet: inet,
}
}
func (m *cidrMatcher) Match(ip string) bool {
if m == nil || m.ipNet == nil {
return false
}
return m.ipNet.Contains(net.ParseIP(ip))
}
type domainMatcher struct {
pattern string
glob glob.Glob
}
// DomainMatcher creates a Matcher for a specific domain pattern,
// the pattern can be a plain domain such as 'example.com',
// a wildcard such as '*.exmaple.com' or a special wildcard '.example.com'.
func DomainMatcher(pattern string) Matcher {
p := pattern
if strings.HasPrefix(pattern, ".") {
p = pattern[1:] // trim the prefix '.'
pattern = "*" + p
}
return &domainMatcher{
pattern: p,
glob: glob.MustCompile(pattern),
}
}
func (m *domainMatcher) Match(domain string) bool {
if m == nil || m.glob == nil {
return false
}
if domain == m.pattern {
return true
}
return m.glob.Match(domain)
}
// Bypass is a filter of address (IP or domain).
type Bypass interface {
// Contains reports whether the bypass includes addr.
Contains(addr string) bool
}
type bypassOptions struct {
type options struct {
logger logger.Logger
}
type BypassOption func(opts *bypassOptions)
type Option func(opts *options)
func LoggerBypassOption(logger logger.Logger) BypassOption {
return func(opts *bypassOptions) {
func LoggerOption(logger logger.Logger) Option {
return func(opts *options) {
opts.logger = logger
}
}
type bypass struct {
matchers []Matcher
matchers []matcher.Matcher
reversed bool
options bypassOptions
options options
}
// NewBypass creates and initializes a new Bypass using matchers as its match rules.
// The rules will be reversed if the reversed is true.
func NewBypass(reversed bool, matchers []Matcher, opts ...BypassOption) Bypass {
options := bypassOptions{}
func NewBypass(reversed bool, matchers []matcher.Matcher, opts ...Option) Bypass {
options := options{}
for _, opt := range opts {
opt(&options)
}
@ -140,10 +48,10 @@ func NewBypass(reversed bool, matchers []Matcher, opts ...BypassOption) Bypass {
// NewBypassPatterns creates and initializes a new Bypass using matcher patterns as its match rules.
// The rules will be reversed if the reverse is true.
func NewBypassPatterns(reversed bool, patterns []string, opts ...BypassOption) Bypass {
var matchers []Matcher
func NewBypassPatterns(reversed bool, patterns []string, opts ...Option) Bypass {
var matchers []matcher.Matcher
for _, pattern := range patterns {
if m := NewMatcher(pattern); m != nil {
if m := matcher.NewMatcher(pattern); m != nil {
matchers = append(matchers, m)
}
}

View File

@ -0,0 +1,99 @@
package matcher
import (
"net"
"strings"
"github.com/gobwas/glob"
)
// Matcher is a generic pattern matcher,
// it gives the match result of the given pattern for specific v.
type Matcher interface {
Match(v string) bool
}
// NewMatcher creates a Matcher for the given pattern.
// The acutal Matcher depends on the pattern:
// IP Matcher if pattern is a valid IP address.
// CIDR Matcher if pattern is a valid CIDR address.
// Domain Matcher if both of the above are not.
func NewMatcher(pattern string) Matcher {
if pattern == "" {
return nil
}
if ip := net.ParseIP(pattern); ip != nil {
return IPMatcher(ip)
}
if _, inet, err := net.ParseCIDR(pattern); err == nil {
return CIDRMatcher(inet)
}
return DomainMatcher(pattern)
}
type ipMatcher struct {
ip net.IP
}
// IPMatcher creates a Matcher for a specific IP address.
func IPMatcher(ip net.IP) Matcher {
return &ipMatcher{
ip: ip,
}
}
func (m *ipMatcher) Match(ip string) bool {
if m == nil {
return false
}
return m.ip.Equal(net.ParseIP(ip))
}
type cidrMatcher struct {
ipNet *net.IPNet
}
// CIDRMatcher creates a Matcher for a specific CIDR notation IP address.
func CIDRMatcher(inet *net.IPNet) Matcher {
return &cidrMatcher{
ipNet: inet,
}
}
func (m *cidrMatcher) Match(ip string) bool {
if m == nil || m.ipNet == nil {
return false
}
return m.ipNet.Contains(net.ParseIP(ip))
}
type domainMatcher struct {
pattern string
glob glob.Glob
}
// DomainMatcher creates a Matcher for a specific domain pattern,
// the pattern can be a plain domain such as 'example.com',
// a wildcard such as '*.exmaple.com' or a special wildcard '.example.com'.
func DomainMatcher(pattern string) Matcher {
p := pattern
if strings.HasPrefix(pattern, ".") {
p = pattern[1:] // trim the prefix '.'
pattern = "*" + p
}
return &domainMatcher{
pattern: p,
glob: glob.MustCompile(pattern),
}
}
func (m *domainMatcher) Match(domain string) bool {
if m == nil || m.glob == nil {
return false
}
if domain == m.pattern {
return true
}
return m.glob.Match(domain)
}

View File

@ -94,6 +94,14 @@ type SelectorConfig struct {
FailTimeout time.Duration `yaml:"failTimeout" json:"failTimeout"`
}
type AdmissionConfig struct {
Name string `json:"name"`
// inline, file, etc.
Type string `yaml:",omitempty" json:"type,omitempty"`
Reverse bool `yaml:",omitempty" json:"reverse,omitempty"`
Matchers []string `json:"matchers"`
}
type BypassConfig struct {
Name string `json:"name"`
// inline, file, etc.
@ -173,6 +181,7 @@ type ConnectorConfig struct {
type ServiceConfig struct {
Name string `json:"name"`
Addr string `yaml:",omitempty" json:"addr,omitempty"`
Admission string `yaml:",omitempty" json:"admission,omitempty"`
Bypass string `yaml:",omitempty" json:"bypass,omitempty"`
Resolver string `yaml:",omitempty" json:"resolver,omitempty"`
Hosts string `yaml:",omitempty" json:"hosts,omitempty"`
@ -210,6 +219,7 @@ type Config struct {
Services []*ServiceConfig `json:"services"`
Chains []*ChainConfig `yaml:",omitempty" json:"chains,omitempty"`
Authers []*AutherConfig `yaml:",omitempty" json:"authers,omitempty"`
Admissions []*AdmissionConfig `yaml:",omitempty" json:"admissions,omitempty"`
Bypasses []*BypassConfig `yaml:",omitempty" json:"bypasses,omitempty"`
Resolvers []*ResolverConfig `yaml:",omitempty" json:"resolvers,omitempty"`
Hosts []*HostsConfig `yaml:",omitempty" json:"hosts,omitempty"`

View File

@ -4,6 +4,7 @@ import (
"net"
"net/url"
"github.com/go-gost/gost/pkg/admission"
"github.com/go-gost/gost/pkg/auth"
"github.com/go-gost/gost/pkg/bypass"
"github.com/go-gost/gost/pkg/chain"
@ -79,6 +80,20 @@ func parseSelector(cfg *config.SelectorConfig) chain.Selector {
)
}
func ParseAdmission(cfg *config.AdmissionConfig) admission.Admission {
if cfg == nil {
return nil
}
return admission.NewAdmissionPatterns(
cfg.Reverse,
cfg.Matchers,
admission.LoggerOption(logger.Default().WithFields(map[string]interface{}{
"kind": "admission",
"admission": cfg.Name,
})),
)
}
func ParseBypass(cfg *config.BypassConfig) bypass.Bypass {
if cfg == nil {
return nil
@ -86,7 +101,7 @@ func ParseBypass(cfg *config.BypassConfig) bypass.Bypass {
return bypass.NewBypassPatterns(
cfg.Reverse,
cfg.Matchers,
bypass.LoggerBypassOption(logger.Default().WithFields(map[string]interface{}{
bypass.LoggerOption(logger.Default().WithFields(map[string]interface{}{
"kind": "bypass",
"bypass": cfg.Name,
})),

View File

@ -14,7 +14,7 @@ import (
"github.com/go-gost/gost/pkg/service"
)
func ParseService(cfg *config.ServiceConfig) (service.Servicer, error) {
func ParseService(cfg *config.ServiceConfig) (service.Service, error) {
if cfg.Listener == nil {
cfg.Listener = &config.ListenerConfig{
Type: "tcp",
@ -112,10 +112,10 @@ func ParseService(cfg *config.ServiceConfig) (service.Servicer, error) {
return nil, err
}
s := (&service.Service{}).
WithListener(ln).
WithHandler(h).
WithLogger(serviceLogger)
s := service.NewService(ln, h,
service.AdmissionOption(registry.Admission().Get(cfg.Admission)),
service.LoggerOption(serviceLogger),
)
serviceLogger.Infof("listening on %s/%s", s.Addr().String(), s.Addr().Network())
return s, nil

View File

@ -180,6 +180,9 @@ func (l *dnsListener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/dns-message")
raddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
if raddr == nil {
raddr = &net.TCPAddr{}
}
if err := l.serve(&dohResponseWriter{raddr: raddr, ResponseWriter: w}, buf); err != nil {
l.logger.Error(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)

View File

@ -8,6 +8,7 @@ import (
pb "github.com/go-gost/gost/pkg/common/util/grpc/proto"
"github.com/go-gost/gost/pkg/logger"
"google.golang.org/grpc/peer"
)
type server struct {
@ -24,6 +25,9 @@ func (s *server) Tunnel(srv pb.GostTunel_TunnelServer) error {
remoteAddr: &net.TCPAddr{},
closed: make(chan struct{}),
}
if p, ok := peer.FromContext(srv.Context()); ok {
c.remoteAddr = p.Addr
}
select {
case s.cqueue <- c:

65
pkg/registry/admission.go Normal file
View File

@ -0,0 +1,65 @@
package registry
import (
"sync"
"github.com/go-gost/gost/pkg/admission"
)
var (
admissionReg = &admissionRegistry{}
)
func Admission() *admissionRegistry {
return admissionReg
}
type admissionRegistry struct {
m sync.Map
}
func (r *admissionRegistry) Register(name string, admission admission.Admission) error {
if name == "" || admission == nil {
return nil
}
if _, loaded := r.m.LoadOrStore(name, admission); loaded {
return ErrDup
}
return nil
}
func (r *admissionRegistry) Unregister(name string) {
r.m.Delete(name)
}
func (r *admissionRegistry) IsRegistered(name string) bool {
_, ok := r.m.Load(name)
return ok
}
func (r *admissionRegistry) Get(name string) admission.Admission {
if name == "" {
return nil
}
return &admissionWrapper{name: name}
}
func (r *admissionRegistry) get(name string) admission.Admission {
if v, ok := r.m.Load(name); ok {
return v.(admission.Admission)
}
return nil
}
type admissionWrapper struct {
name string
}
func (w *admissionWrapper) Admit(addr string) bool {
p := admissionReg.get(w.name)
if p == nil {
return false
}
return p.Admit(addr)
}

View File

@ -18,7 +18,7 @@ type serviceRegistry struct {
m sync.Map
}
func (r *serviceRegistry) Register(name string, svc service.Servicer) error {
func (r *serviceRegistry) Register(name string, svc service.Service) error {
if name == "" || svc == nil {
return nil
}
@ -38,7 +38,7 @@ func (r *serviceRegistry) IsRegistered(name string) bool {
return ok
}
func (r *serviceRegistry) Get(name string) service.Servicer {
func (r *serviceRegistry) Get(name string) service.Service {
if name == "" {
return nil
}
@ -46,5 +46,5 @@ func (r *serviceRegistry) Get(name string) service.Servicer {
if !ok {
return nil
}
return v.(service.Servicer)
return v.(service.Service)
}

View File

@ -5,47 +5,64 @@ import (
"net"
"time"
"github.com/go-gost/gost/pkg/admission"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/listener"
"github.com/go-gost/gost/pkg/logger"
)
type Servicer interface {
type options struct {
admission admission.Admission
logger logger.Logger
}
type Option func(opts *options)
func AdmissionOption(admission admission.Admission) Option {
return func(opts *options) {
opts.admission = admission
}
}
func LoggerOption(logger logger.Logger) Option {
return func(opts *options) {
opts.logger = logger
}
}
type Service interface {
Serve() error
Addr() net.Addr
Close() error
}
type Service struct {
type service struct {
listener listener.Listener
handler handler.Handler
logger logger.Logger
options options
}
func (s *Service) WithListener(ln listener.Listener) *Service {
s.listener = ln
return s
func NewService(ln listener.Listener, h handler.Handler, opts ...Option) Service {
var options options
for _, opt := range opts {
opt(&options)
}
return &service{
listener: ln,
handler: h,
options: options,
}
}
func (s *Service) WithHandler(h handler.Handler) *Service {
s.handler = h
return s
}
func (s *Service) WithLogger(logger logger.Logger) *Service {
s.logger = logger
return s
}
func (s *Service) Addr() net.Addr {
func (s *service) Addr() net.Addr {
return s.listener.Addr()
}
func (s *Service) Close() error {
func (s *service) Close() error {
return s.listener.Close()
}
func (s *Service) Serve() error {
func (s *service) Serve() error {
var tempDelay time.Duration
for {
conn, e := s.listener.Accept()
@ -59,15 +76,22 @@ func (s *Service) Serve() error {
if max := 5 * time.Second; tempDelay > max {
tempDelay = max
}
s.logger.Warnf("accept: %v, retrying in %v", e, tempDelay)
s.options.logger.Warnf("accept: %v, retrying in %v", e, tempDelay)
time.Sleep(tempDelay)
continue
}
s.logger.Errorf("accept: %v", e)
s.options.logger.Errorf("accept: %v", e)
return e
}
tempDelay = 0
if s.options.admission != nil &&
!s.options.admission.Admit(conn.RemoteAddr().String()) {
s.options.logger.Infof("admission: %s is denied", conn.RemoteAddr())
conn.Close()
continue
}
go s.handler.Handle(context.Background(), conn)
}
}