x/service/service.go
2023-11-18 18:28:09 +08:00

217 lines
4.9 KiB
Go

package service
import (
"context"
"io"
"net"
"os/exec"
"strings"
"time"
"github.com/go-gost/core/admission"
"github.com/go-gost/core/handler"
"github.com/go-gost/core/listener"
"github.com/go-gost/core/logger"
"github.com/go-gost/core/metrics"
"github.com/go-gost/core/recorder"
"github.com/go-gost/core/service"
ctxvalue "github.com/go-gost/x/internal/ctx"
xmetrics "github.com/go-gost/x/metrics"
"github.com/rs/xid"
)
type options struct {
admission admission.Admission
recorders []recorder.RecorderObject
preUp []string
postUp []string
preDown []string
postDown []string
logger logger.Logger
}
type Option func(opts *options)
func AdmissionOption(admission admission.Admission) Option {
return func(opts *options) {
opts.admission = admission
}
}
func RecordersOption(recorders ...recorder.RecorderObject) Option {
return func(opts *options) {
opts.recorders = recorders
}
}
func PreUpOption(cmds []string) Option {
return func(opts *options) {
opts.preUp = cmds
}
}
func PreDownOption(cmds []string) Option {
return func(opts *options) {
opts.preDown = cmds
}
}
func PostUpOption(cmds []string) Option {
return func(opts *options) {
opts.postUp = cmds
}
}
func PostDownOption(cmds []string) Option {
return func(opts *options) {
opts.postDown = cmds
}
}
func LoggerOption(logger logger.Logger) Option {
return func(opts *options) {
opts.logger = logger
}
}
type defaultService struct {
name string
listener listener.Listener
handler handler.Handler
options options
}
func NewService(name string, ln listener.Listener, h handler.Handler, opts ...Option) service.Service {
var options options
for _, opt := range opts {
opt(&options)
}
s := &defaultService{
name: name,
listener: ln,
handler: h,
options: options,
}
s.execCmds("pre-up", s.options.preUp)
return s
}
func (s *defaultService) Addr() net.Addr {
return s.listener.Addr()
}
func (s *defaultService) Close() error {
s.execCmds("pre-down", s.options.preDown)
defer s.execCmds("post-down", s.options.postDown)
if closer, ok := s.handler.(io.Closer); ok {
closer.Close()
}
return s.listener.Close()
}
func (s *defaultService) Serve() error {
s.execCmds("post-up", s.options.postUp)
if v := xmetrics.GetGauge(
xmetrics.MetricServicesGauge,
metrics.Labels{}); v != nil {
v.Inc()
defer v.Dec()
}
var tempDelay time.Duration
for {
conn, e := s.listener.Accept()
if e != nil {
// TODO: remove Temporary checking
if ne, ok := e.(net.Error); ok && ne.Temporary() {
if tempDelay == 0 {
tempDelay = 1 * time.Second
} else {
tempDelay *= 2
}
if max := 5 * time.Second; tempDelay > max {
tempDelay = max
}
s.options.logger.Warnf("accept: %v, retrying in %v", e, tempDelay)
time.Sleep(tempDelay)
continue
}
s.options.logger.Errorf("accept: %v", e)
return e
}
tempDelay = 0
clientAddr := conn.RemoteAddr().String()
clientIP := clientAddr
if h, _, _ := net.SplitHostPort(clientAddr); h != "" {
clientIP = h
}
ctx := ctxvalue.ContextWithSid(context.Background(), ctxvalue.Sid(xid.New().String()))
ctx = ctxvalue.ContextWithClientAddr(ctx, ctxvalue.ClientAddr(clientAddr))
ctx = ctxvalue.ContextWithHash(ctx, &ctxvalue.Hash{Source: clientIP})
for _, rec := range s.options.recorders {
if rec.Record == recorder.RecorderServiceClientAddress {
if err := rec.Recorder.Record(ctx, []byte(clientIP)); err != nil {
s.options.logger.Errorf("record %s: %v", rec.Record, err)
}
break
}
}
if s.options.admission != nil &&
!s.options.admission.Admit(ctx, conn.RemoteAddr().String()) {
conn.Close()
s.options.logger.Debugf("admission: %s is denied", conn.RemoteAddr())
continue
}
go func() {
if v := xmetrics.GetCounter(xmetrics.MetricServiceRequestsCounter,
metrics.Labels{"service": s.name, "client": clientIP}); v != nil {
v.Inc()
}
if v := xmetrics.GetGauge(xmetrics.MetricServiceRequestsInFlightGauge,
metrics.Labels{"service": s.name, "client": clientIP}); v != nil {
v.Inc()
defer v.Dec()
}
start := time.Now()
if v := xmetrics.GetObserver(xmetrics.MetricServiceRequestsDurationObserver,
metrics.Labels{"service": s.name}); v != nil {
defer func() {
v.Observe(float64(time.Since(start).Seconds()))
}()
}
if err := s.handler.Handle(ctx, conn); err != nil {
s.options.logger.Error(err)
if v := xmetrics.GetCounter(xmetrics.MetricServiceHandlerErrorsCounter,
metrics.Labels{"service": s.name, "client": clientIP}); v != nil {
v.Inc()
}
}
}()
}
}
func (s *defaultService) execCmds(phase string, cmds []string) {
for _, cmd := range cmds {
cmd := strings.TrimSpace(cmd)
if cmd == "" {
continue
}
s.options.logger.Info(cmd)
if err := exec.Command("/bin/sh", "-c", cmd).Run(); err != nil {
s.options.logger.Warnf("[%s] %s: %v", phase, cmd, err)
}
}
}