271 lines
6.2 KiB
Go
271 lines
6.2 KiB
Go
package gost
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
var dnsTests = []struct {
|
|
ns NameServer
|
|
host string
|
|
pass bool
|
|
}{
|
|
{NameServer{Addr: "1.1.1.1"}, "192.168.1.1", true},
|
|
{NameServer{Addr: "1.1.1.1"}, "github", true},
|
|
{NameServer{Addr: "1.1.1.1"}, "github.com", true},
|
|
{NameServer{Addr: "1.1.1.1:53"}, "github.com", true},
|
|
{NameServer{Addr: "1.1.1.1:53", Protocol: "tcp"}, "github.com", true},
|
|
{NameServer{Addr: "1.1.1.1:853", Protocol: "tls"}, "github.com", true},
|
|
{NameServer{Addr: "1.1.1.1:853", Protocol: "tls", Hostname: "example.com"}, "github.com", false},
|
|
{NameServer{Addr: "1.1.1.1:853", Protocol: "tls", Hostname: "cloudflare-dns.com"}, "github.com", true},
|
|
{NameServer{Addr: "https://cloudflare-dns.com/dns-query", Protocol: "https"}, "github.com", true},
|
|
{NameServer{Addr: "https://1.0.0.1/dns-query", Protocol: "https"}, "github.com", true},
|
|
{NameServer{Addr: "1.1.1.1:12345"}, "github.com", false},
|
|
{NameServer{Addr: "1.1.1.1:12345", Protocol: "tcp"}, "github.com", false},
|
|
{NameServer{Addr: "1.1.1.1:12345", Protocol: "tls"}, "github.com", false},
|
|
{NameServer{Addr: "https://1.0.0.1:12345/dns-query", Protocol: "https"}, "github.com", false},
|
|
}
|
|
|
|
func dnsResolverRoundtrip(t *testing.T, r Resolver, host string) error {
|
|
ips, err := r.Resolve(host)
|
|
t.Log(host, ips, err)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func TestDNSResolver(t *testing.T) {
|
|
for i, tc := range dnsTests {
|
|
tc := tc
|
|
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
|
|
ns := tc.ns
|
|
t.Log(ns)
|
|
r := NewResolver(0, ns)
|
|
resolv := r.(*resolver)
|
|
resolv.domain = "com"
|
|
if err := r.Init(); err != nil {
|
|
t.Error("got error:", err)
|
|
}
|
|
err := dnsResolverRoundtrip(t, r, tc.host)
|
|
if err != nil {
|
|
if tc.pass {
|
|
t.Error("got error:", err)
|
|
}
|
|
} else {
|
|
if !tc.pass {
|
|
t.Error("should failed")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// resolverCacheTests lists cache entries (host name, addresses, TTL) together
// with the addresses expected back from a cache lookup. It is referenced only
// by the commented-out TestResolverCache in this file.
var resolverCacheTests = []struct {
	name   string
	ips    []net.IP
	ttl    time.Duration
	result []net.IP
}{
	// Empty host name: nothing is expected back, whatever the addresses or TTL.
	{name: "", ips: nil, ttl: 0, result: nil},
	{name: "", ips: []net.IP{net.IPv4(192, 168, 1, 1)}, ttl: 0, result: nil},
	{name: "", ips: []net.IP{net.IPv4(192, 168, 1, 1)}, ttl: 10 * time.Second, result: nil},

	// Nil or empty address list: nothing is expected back.
	{name: "example.com", ips: nil, ttl: 10 * time.Second, result: nil},
	{name: "example.com", ips: []net.IP{}, ttl: 10 * time.Second, result: nil},

	// Zero or negative TTL: nothing is expected back.
	{name: "example.com", ips: []net.IP{net.IPv4(192, 168, 1, 1)}, ttl: 0, result: nil},
	{name: "example.com", ips: []net.IP{net.IPv4(192, 168, 1, 1)}, ttl: -1, result: nil},

	// Named host, addresses and a positive TTL: the stored addresses are expected back.
	{
		name:   "example.com",
		ips:    []net.IP{net.IPv4(192, 168, 1, 1)},
		ttl:    10 * time.Second,
		result: []net.IP{net.IPv4(192, 168, 1, 1)},
	},
	{
		name:   "example.com",
		ips:    []net.IP{net.IPv4(192, 168, 1, 1), net.IPv4(192, 168, 1, 2)},
		ttl:    10 * time.Second,
		result: []net.IP{net.IPv4(192, 168, 1, 1), net.IPv4(192, 168, 1, 2)},
	},
}
|
|
|
|
/*
NOTE(review): TestResolverCache is commented out, so it is neither compiled nor
run. It stores each resolverCacheTests entry through r.cache.storeCache, reads
it back through r.cache.loadCache and compares the addresses with tc.result.
Presumably it was disabled because that cache API changed; confirm before
re-enabling.

func TestResolverCache(t *testing.T) {
	isEqual := func(a, b []net.IP) bool {
		if a == nil && b == nil {
			return true
		}

		if a == nil || b == nil || len(a) != len(b) {
			return false
		}

		for i := range a {
			if !a[i].Equal(b[i]) {
				return false
			}
		}
		return true
	}
	for i, tc := range resolverCacheTests {
		tc := tc
		t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
			r := newResolver(tc.ttl)
			r.cache.storeCache(tc.name, tc.ips, tc.ttl)
			ips := r.cache.loadCache(tc.name, tc.ttl)

			if !isEqual(tc.result, ips) {
				t.Error("unexpected cache value:", tc.name, ips, tc.ttl)
			}
		})
	}
}
*/
|
|
|
|
var resolverReloadTests = []struct {
|
|
r io.Reader
|
|
|
|
timeout time.Duration
|
|
ttl time.Duration
|
|
domain string
|
|
period time.Duration
|
|
ns *NameServer
|
|
|
|
stopped bool
|
|
}{
|
|
{
|
|
r: nil,
|
|
},
|
|
{
|
|
r: bytes.NewBufferString(""),
|
|
},
|
|
{
|
|
r: bytes.NewBufferString("reload 10s"),
|
|
period: 10 * time.Second,
|
|
},
|
|
{
|
|
r: bytes.NewBufferString("timeout 10s\nreload 10s\n"),
|
|
timeout: 10 * time.Second,
|
|
period: 10 * time.Second,
|
|
},
|
|
{
|
|
r: bytes.NewBufferString("ttl 10s\ntimeout 10s\nreload 10s\n"),
|
|
timeout: 10 * time.Second,
|
|
period: 10 * time.Second,
|
|
ttl: 10 * time.Second,
|
|
},
|
|
{
|
|
r: bytes.NewBufferString("domain example.com\nttl 10s\ntimeout 10s\nreload 10s\n"),
|
|
timeout: 10 * time.Second,
|
|
period: 10 * time.Second,
|
|
ttl: 10 * time.Second,
|
|
domain: "example.com",
|
|
},
|
|
{
|
|
r: bytes.NewBufferString("1.1.1.1"),
|
|
ns: &NameServer{
|
|
Addr: "1.1.1.1",
|
|
},
|
|
stopped: true,
|
|
},
|
|
{
|
|
r: bytes.NewBufferString("\n# comment\ntimeout 10s\nsearch\nnameserver \nnameserver 1.1.1.1 udp"),
|
|
ns: &NameServer{
|
|
Protocol: "udp",
|
|
Addr: "1.1.1.1",
|
|
},
|
|
timeout: 10 * time.Second,
|
|
stopped: true,
|
|
},
|
|
{
|
|
r: bytes.NewBufferString("1.1.1.1 tcp"),
|
|
ns: &NameServer{
|
|
Addr: "1.1.1.1",
|
|
Protocol: "tcp",
|
|
},
|
|
stopped: true,
|
|
},
|
|
{
|
|
r: bytes.NewBufferString("1.1.1.1:853 tls cloudflare-dns.com"),
|
|
ns: &NameServer{
|
|
Addr: "1.1.1.1:853",
|
|
Protocol: "tls",
|
|
Hostname: "cloudflare-dns.com",
|
|
},
|
|
stopped: true,
|
|
},
|
|
{
|
|
r: bytes.NewBufferString("1.1.1.1:853 tls"),
|
|
ns: &NameServer{
|
|
Addr: "1.1.1.1:853",
|
|
Protocol: "tls",
|
|
},
|
|
stopped: true,
|
|
},
|
|
{
|
|
r: bytes.NewBufferString("1.0.0.1:53 https"),
|
|
stopped: true,
|
|
},
|
|
{
|
|
r: bytes.NewBufferString("https://1.0.0.1/dns-query"),
|
|
ns: &NameServer{
|
|
Addr: "https://1.0.0.1/dns-query",
|
|
Protocol: "https",
|
|
},
|
|
stopped: true,
|
|
},
|
|
}
|
|
|
|
func TestResolverReload(t *testing.T) {
|
|
for i, tc := range resolverReloadTests {
|
|
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
|
|
r := newResolver(0)
|
|
if err := r.Reload(tc.r); err != nil {
|
|
t.Error(err)
|
|
}
|
|
t.Log(r.String())
|
|
if r.TTL() != tc.ttl {
|
|
t.Errorf("ttl value should be %v, got %v",
|
|
tc.ttl, r.TTL())
|
|
}
|
|
if r.Period() != tc.period {
|
|
t.Errorf("period value should be %v, got %v",
|
|
tc.period, r.period)
|
|
}
|
|
if r.domain != tc.domain {
|
|
t.Errorf("domain value should be %v, got %v",
|
|
tc.domain, r.domain)
|
|
}
|
|
|
|
var ns *NameServer
|
|
if len(r.servers) > 0 {
|
|
ns = &r.servers[0]
|
|
}
|
|
|
|
if !compareNameServer(ns, tc.ns) {
|
|
t.Errorf("nameserver not equal, should be %v, got %v",
|
|
tc.ns, r.servers)
|
|
}
|
|
|
|
if tc.stopped {
|
|
r.Stop()
|
|
if r.Period() >= 0 {
|
|
t.Errorf("period of the stopped reloader should be minus value")
|
|
}
|
|
}
|
|
if r.Stopped() != tc.stopped {
|
|
t.Errorf("stopped value should be %v, got %v",
|
|
tc.stopped, r.Stopped())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func compareNameServer(n1, n2 *NameServer) bool {
|
|
if n1 == n2 {
|
|
return true
|
|
}
|
|
if n1 == nil || n2 == nil {
|
|
return false
|
|
}
|
|
return n1.Addr == n2.Addr &&
|
|
n1.Hostname == n2.Hostname &&
|
|
n1.Protocol == n2.Protocol
|
|
}
|