package resolver import ( "bytes" "context" "encoding/json" "errors" "fmt" "io" "net" "net/http" "github.com/go-gost/core/logger" "github.com/go-gost/core/resolver" "github.com/go-gost/plugin/resolver/proto" ctxvalue "github.com/go-gost/x/internal/ctx" "github.com/go-gost/x/internal/plugin" "google.golang.org/grpc" ) type grpcPlugin struct { conn grpc.ClientConnInterface client proto.ResolverClient log logger.Logger } // NewGRPCPlugin creates a Resolver plugin based on gRPC. func NewGRPCPlugin(name string, addr string, opts ...plugin.Option) (resolver.Resolver, error) { var options plugin.Options for _, opt := range opts { opt(&options) } log := logger.Default().WithFields(map[string]any{ "kind": "resolver", "resolover": name, }) conn, err := plugin.NewGRPCConn(addr, &options) if err != nil { log.Error(err) } p := &grpcPlugin{ conn: conn, log: log, } if conn != nil { p.client = proto.NewResolverClient(conn) } return p, nil } func (p *grpcPlugin) Resolve(ctx context.Context, network, host string, opts ...resolver.Option) (ips []net.IP, err error) { p.log.Debugf("resolve %s/%s", host, network) if p.client == nil { return } r, err := p.client.Resolve(ctx, &proto.ResolveRequest{ Network: network, Host: host, Client: string(ctxvalue.ClientIDFromContext(ctx)), }) if err != nil { p.log.Error(err) return } for _, s := range r.Ips { if ip := net.ParseIP(s); ip != nil { ips = append(ips, ip) } } return } func (p *grpcPlugin) Close() error { if closer, ok := p.conn.(io.Closer); ok { return closer.Close() } return nil } type httpPluginRequest struct { Network string `json:"network"` Host string `json:"host"` Client string `json:"client"` } type httpPluginResponse struct { IPs []string `json:"ips"` OK bool `json:"ok"` } type httpPlugin struct { url string client *http.Client header http.Header log logger.Logger } // NewHTTPPlugin creates an Resolver plugin based on HTTP. func NewHTTPPlugin(name string, url string, opts ...plugin.Option) resolver.Resolver { var options plugin.Options for _, opt := range opts { opt(&options) } return &httpPlugin{ url: url, client: plugin.NewHTTPClient(&options), header: options.Header, log: logger.Default().WithFields(map[string]any{ "kind": "resolver", "resolver": name, }), } } func (p *httpPlugin) Resolve(ctx context.Context, network, host string, opts ...resolver.Option) (ips []net.IP, err error) { p.log.Debugf("resolve %s/%s", host, network) if p.client == nil { return } rb := httpPluginRequest{ Network: network, Host: host, Client: string(ctxvalue.ClientIDFromContext(ctx)), } v, err := json.Marshal(&rb) if err != nil { return } req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url, bytes.NewReader(v)) if err != nil { return } if p.header != nil { req.Header = p.header.Clone() } req.Header.Set("Content-Type", "application/json") resp, err := p.client.Do(req) if err != nil { return } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { err = fmt.Errorf("%s", resp.Status) return } res := httpPluginResponse{} if err = json.NewDecoder(resp.Body).Decode(&res); err != nil { return } if !res.OK { return nil, errors.New("resolve failed") } for _, s := range res.IPs { if ip := net.ParseIP(s); ip != nil { ips = append(ips, ip) } } return ips, nil }