initial commit

This commit is contained in:
ginuerzh
2022-03-14 20:27:14 +08:00
commit 9397cb5351
175 changed files with 16196 additions and 0 deletions

210
listener/dns/listener.go Normal file
View File

@ -0,0 +1,210 @@
package dns
import (
"bytes"
"encoding/base64"
"errors"
"io/ioutil"
"net"
"net/http"
"strings"
"github.com/go-gost/gost/v3/pkg/common/metrics"
"github.com/go-gost/gost/v3/pkg/listener"
"github.com/go-gost/gost/v3/pkg/logger"
md "github.com/go-gost/gost/v3/pkg/metadata"
"github.com/go-gost/gost/v3/pkg/registry"
"github.com/miekg/dns"
)
func init() {
registry.ListenerRegistry().Register("dns", NewListener)
}
type dnsListener struct {
addr net.Addr
server Server
cqueue chan net.Conn
errChan chan error
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &dnsListener{
logger: options.Logger,
options: options,
}
}
func (l *dnsListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
}
l.addr, err = net.ResolveTCPAddr("tcp", l.options.Addr)
if err != nil {
return err
}
switch strings.ToLower(l.md.mode) {
case "tcp":
l.server = &dns.Server{
Net: "tcp",
Addr: l.options.Addr,
Handler: l,
ReadTimeout: l.md.readTimeout,
WriteTimeout: l.md.writeTimeout,
}
case "tls":
l.server = &dns.Server{
Net: "tcp-tls",
Addr: l.options.Addr,
Handler: l,
TLSConfig: l.options.TLSConfig,
ReadTimeout: l.md.readTimeout,
WriteTimeout: l.md.writeTimeout,
}
case "https":
l.server = &dohServer{
addr: l.options.Addr,
tlsConfig: l.options.TLSConfig,
server: &http.Server{
Handler: l,
ReadTimeout: l.md.readTimeout,
WriteTimeout: l.md.writeTimeout,
},
}
default:
l.addr, err = net.ResolveUDPAddr("udp", l.options.Addr)
l.server = &dns.Server{
Net: "udp",
Addr: l.options.Addr,
Handler: l,
UDPSize: l.md.readBufferSize,
ReadTimeout: l.md.readTimeout,
WriteTimeout: l.md.writeTimeout,
}
}
if err != nil {
return
}
l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1)
go func() {
err := l.server.ListenAndServe()
if err != nil {
l.errChan <- err
}
close(l.errChan)
}()
return
}
func (l *dnsListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.cqueue:
conn = metrics.WrapConn(l.options.Service, conn)
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *dnsListener) Close() error {
return l.server.Shutdown()
}
func (l *dnsListener) Addr() net.Addr {
return l.addr
}
func (l *dnsListener) ServeDNS(w dns.ResponseWriter, m *dns.Msg) {
b, err := m.Pack()
if err != nil {
l.logger.Error(err)
return
}
if err := l.serve(w, b); err != nil {
l.logger.Error(err)
}
}
// Based on https://github.com/semihalev/sdns
func (l *dnsListener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var buf []byte
var err error
switch r.Method {
case http.MethodGet:
buf, err = base64.RawURLEncoding.DecodeString(r.URL.Query().Get("dns"))
if len(buf) == 0 || err != nil {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
case http.MethodPost:
if ct := r.Header.Get("Content-Type"); ct != "application/dns-message" {
l.logger.Errorf("unsupported media type: %s", ct)
http.Error(w, http.StatusText(http.StatusUnsupportedMediaType), http.StatusUnsupportedMediaType)
return
}
buf, err = ioutil.ReadAll(r.Body)
if err != nil {
l.logger.Error(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
default:
l.logger.Errorf("method not allowd: %s", r.Method)
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
}
mq := &dns.Msg{}
if err := mq.Unpack(buf); err != nil {
l.logger.Error(err)
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
w.Header().Set("Server", "SDNS")
w.Header().Set("Content-Type", "application/dns-message")
raddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
if raddr == nil {
raddr = &net.TCPAddr{}
}
if err := l.serve(&dohResponseWriter{raddr: raddr, ResponseWriter: w}, buf); err != nil {
l.logger.Error(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
}
func (l *dnsListener) serve(w ResponseWriter, msg []byte) (err error) {
conn := &serverConn{
r: bytes.NewReader(msg),
w: w,
laddr: l.addr,
closed: make(chan struct{}),
}
select {
case l.cqueue <- conn:
default:
l.logger.Warnf("connection queue is full, client %s discarded", w.RemoteAddr())
return errors.New("connection queue is full")
}
return conn.Wait()
}

41
listener/dns/metadata.go Normal file
View File

@ -0,0 +1,41 @@
package dns
import (
"time"
mdata "github.com/go-gost/gost/v3/pkg/metadata"
)
const (
defaultBacklog = 128
)
type metadata struct {
mode string
readBufferSize int
readTimeout time.Duration
writeTimeout time.Duration
backlog int
}
func (l *dnsListener) parseMetadata(md mdata.Metadata) (err error) {
const (
backlog = "backlog"
mode = "mode"
readBufferSize = "readBufferSize"
readTimeout = "readTimeout"
writeTimeout = "writeTimeout"
)
l.md.mode = mdata.GetString(md, mode)
l.md.readBufferSize = mdata.GetInt(md, readBufferSize)
l.md.readTimeout = mdata.GetDuration(md, readTimeout)
l.md.writeTimeout = mdata.GetDuration(md, writeTimeout)
l.md.backlog = mdata.GetInt(md, backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
return
}

110
listener/dns/server.go Normal file
View File

@ -0,0 +1,110 @@
package dns
import (
"context"
"crypto/tls"
"errors"
"io"
"net"
"net/http"
"time"
)
type Server interface {
ListenAndServe() error
Shutdown() error
}
type dohServer struct {
addr string
tlsConfig *tls.Config
server *http.Server
}
func (s *dohServer) ListenAndServe() error {
ln, err := net.Listen("tcp", s.addr)
if err != nil {
return err
}
ln = tls.NewListener(ln, s.tlsConfig)
return s.server.Serve(ln)
}
func (s *dohServer) Shutdown() error {
return s.server.Shutdown(context.Background())
}
type ResponseWriter interface {
io.Writer
RemoteAddr() net.Addr
}
type dohResponseWriter struct {
raddr net.Addr
http.ResponseWriter
}
func (w *dohResponseWriter) RemoteAddr() net.Addr {
return w.raddr
}
type serverConn struct {
r io.Reader
w ResponseWriter
laddr net.Addr
closed chan struct{}
}
func (c *serverConn) Read(b []byte) (n int, err error) {
select {
case <-c.closed:
err = io.ErrClosedPipe
return
default:
return c.r.Read(b)
}
}
func (c *serverConn) Write(b []byte) (n int, err error) {
select {
case <-c.closed:
err = io.ErrClosedPipe
return
default:
return c.w.Write(b)
}
}
func (c *serverConn) Close() error {
select {
case <-c.closed:
default:
close(c.closed)
}
return nil
}
func (c *serverConn) Wait() error {
<-c.closed
return nil
}
func (c *serverConn) LocalAddr() net.Addr {
return c.laddr
}
func (c *serverConn) RemoteAddr() net.Addr {
return c.w.RemoteAddr()
}
func (c *serverConn) SetDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "dns", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *serverConn) SetReadDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "dns", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *serverConn) SetWriteDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "dns", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}