init
This commit is contained in:
270
resolver_test.go
Normal file
270
resolver_test.go
Normal file
@ -0,0 +1,270 @@
|
||||
package gost
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var dnsTests = []struct {
|
||||
ns NameServer
|
||||
host string
|
||||
pass bool
|
||||
}{
|
||||
{NameServer{Addr: "1.1.1.1"}, "192.168.1.1", true},
|
||||
{NameServer{Addr: "1.1.1.1"}, "github", true},
|
||||
{NameServer{Addr: "1.1.1.1"}, "github.com", true},
|
||||
{NameServer{Addr: "1.1.1.1:53"}, "github.com", true},
|
||||
{NameServer{Addr: "1.1.1.1:53", Protocol: "tcp"}, "github.com", true},
|
||||
{NameServer{Addr: "1.1.1.1:853", Protocol: "tls"}, "github.com", true},
|
||||
{NameServer{Addr: "1.1.1.1:853", Protocol: "tls", Hostname: "example.com"}, "github.com", false},
|
||||
{NameServer{Addr: "1.1.1.1:853", Protocol: "tls", Hostname: "cloudflare-dns.com"}, "github.com", true},
|
||||
{NameServer{Addr: "https://cloudflare-dns.com/dns-query", Protocol: "https"}, "github.com", true},
|
||||
{NameServer{Addr: "https://1.0.0.1/dns-query", Protocol: "https"}, "github.com", true},
|
||||
{NameServer{Addr: "1.1.1.1:12345"}, "github.com", false},
|
||||
{NameServer{Addr: "1.1.1.1:12345", Protocol: "tcp"}, "github.com", false},
|
||||
{NameServer{Addr: "1.1.1.1:12345", Protocol: "tls"}, "github.com", false},
|
||||
{NameServer{Addr: "https://1.0.0.1:12345/dns-query", Protocol: "https"}, "github.com", false},
|
||||
}
|
||||
|
||||
func dnsResolverRoundtrip(t *testing.T, r Resolver, host string) error {
|
||||
ips, err := r.Resolve(host)
|
||||
t.Log(host, ips, err)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestDNSResolver(t *testing.T) {
|
||||
for i, tc := range dnsTests {
|
||||
tc := tc
|
||||
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
|
||||
ns := tc.ns
|
||||
t.Log(ns)
|
||||
r := NewResolver(0, ns)
|
||||
resolv := r.(*resolver)
|
||||
resolv.domain = "com"
|
||||
if err := r.Init(); err != nil {
|
||||
t.Error("got error:", err)
|
||||
}
|
||||
err := dnsResolverRoundtrip(t, r, tc.host)
|
||||
if err != nil {
|
||||
if tc.pass {
|
||||
t.Error("got error:", err)
|
||||
}
|
||||
} else {
|
||||
if !tc.pass {
|
||||
t.Error("should failed")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var resolverCacheTests = []struct {
|
||||
name string
|
||||
ips []net.IP
|
||||
ttl time.Duration
|
||||
result []net.IP
|
||||
}{
|
||||
{"", nil, 0, nil},
|
||||
{"", []net.IP{net.IPv4(192, 168, 1, 1)}, 0, nil},
|
||||
{"", []net.IP{net.IPv4(192, 168, 1, 1)}, 10 * time.Second, nil},
|
||||
{"example.com", nil, 10 * time.Second, nil},
|
||||
{"example.com", []net.IP{}, 10 * time.Second, nil},
|
||||
{"example.com", []net.IP{net.IPv4(192, 168, 1, 1)}, 0, nil},
|
||||
{"example.com", []net.IP{net.IPv4(192, 168, 1, 1)}, -1, nil},
|
||||
{"example.com", []net.IP{net.IPv4(192, 168, 1, 1)}, 10 * time.Second,
|
||||
[]net.IP{net.IPv4(192, 168, 1, 1)}},
|
||||
{"example.com", []net.IP{net.IPv4(192, 168, 1, 1), net.IPv4(192, 168, 1, 2)}, 10 * time.Second,
|
||||
[]net.IP{net.IPv4(192, 168, 1, 1), net.IPv4(192, 168, 1, 2)}},
|
||||
}
|
||||
|
||||
/*
|
||||
func TestResolverCache(t *testing.T) {
|
||||
isEqual := func(a, b []net.IP) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if a == nil || b == nil || len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := range a {
|
||||
if !a[i].Equal(b[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
for i, tc := range resolverCacheTests {
|
||||
tc := tc
|
||||
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
|
||||
r := newResolver(tc.ttl)
|
||||
r.cache.storeCache(tc.name, tc.ips, tc.ttl)
|
||||
ips := r.cache.loadCache(tc.name, tc.ttl)
|
||||
|
||||
if !isEqual(tc.result, ips) {
|
||||
t.Error("unexpected cache value:", tc.name, ips, tc.ttl)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
var resolverReloadTests = []struct {
|
||||
r io.Reader
|
||||
|
||||
timeout time.Duration
|
||||
ttl time.Duration
|
||||
domain string
|
||||
period time.Duration
|
||||
ns *NameServer
|
||||
|
||||
stopped bool
|
||||
}{
|
||||
{
|
||||
r: nil,
|
||||
},
|
||||
{
|
||||
r: bytes.NewBufferString(""),
|
||||
},
|
||||
{
|
||||
r: bytes.NewBufferString("reload 10s"),
|
||||
period: 10 * time.Second,
|
||||
},
|
||||
{
|
||||
r: bytes.NewBufferString("timeout 10s\nreload 10s\n"),
|
||||
timeout: 10 * time.Second,
|
||||
period: 10 * time.Second,
|
||||
},
|
||||
{
|
||||
r: bytes.NewBufferString("ttl 10s\ntimeout 10s\nreload 10s\n"),
|
||||
timeout: 10 * time.Second,
|
||||
period: 10 * time.Second,
|
||||
ttl: 10 * time.Second,
|
||||
},
|
||||
{
|
||||
r: bytes.NewBufferString("domain example.com\nttl 10s\ntimeout 10s\nreload 10s\n"),
|
||||
timeout: 10 * time.Second,
|
||||
period: 10 * time.Second,
|
||||
ttl: 10 * time.Second,
|
||||
domain: "example.com",
|
||||
},
|
||||
{
|
||||
r: bytes.NewBufferString("1.1.1.1"),
|
||||
ns: &NameServer{
|
||||
Addr: "1.1.1.1",
|
||||
},
|
||||
stopped: true,
|
||||
},
|
||||
{
|
||||
r: bytes.NewBufferString("\n# comment\ntimeout 10s\nsearch\nnameserver \nnameserver 1.1.1.1 udp"),
|
||||
ns: &NameServer{
|
||||
Protocol: "udp",
|
||||
Addr: "1.1.1.1",
|
||||
},
|
||||
timeout: 10 * time.Second,
|
||||
stopped: true,
|
||||
},
|
||||
{
|
||||
r: bytes.NewBufferString("1.1.1.1 tcp"),
|
||||
ns: &NameServer{
|
||||
Addr: "1.1.1.1",
|
||||
Protocol: "tcp",
|
||||
},
|
||||
stopped: true,
|
||||
},
|
||||
{
|
||||
r: bytes.NewBufferString("1.1.1.1:853 tls cloudflare-dns.com"),
|
||||
ns: &NameServer{
|
||||
Addr: "1.1.1.1:853",
|
||||
Protocol: "tls",
|
||||
Hostname: "cloudflare-dns.com",
|
||||
},
|
||||
stopped: true,
|
||||
},
|
||||
{
|
||||
r: bytes.NewBufferString("1.1.1.1:853 tls"),
|
||||
ns: &NameServer{
|
||||
Addr: "1.1.1.1:853",
|
||||
Protocol: "tls",
|
||||
},
|
||||
stopped: true,
|
||||
},
|
||||
{
|
||||
r: bytes.NewBufferString("1.0.0.1:53 https"),
|
||||
stopped: true,
|
||||
},
|
||||
{
|
||||
r: bytes.NewBufferString("https://1.0.0.1/dns-query"),
|
||||
ns: &NameServer{
|
||||
Addr: "https://1.0.0.1/dns-query",
|
||||
Protocol: "https",
|
||||
},
|
||||
stopped: true,
|
||||
},
|
||||
}
|
||||
|
||||
func TestResolverReload(t *testing.T) {
|
||||
for i, tc := range resolverReloadTests {
|
||||
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
|
||||
r := newResolver(0)
|
||||
if err := r.Reload(tc.r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
t.Log(r.String())
|
||||
if r.TTL() != tc.ttl {
|
||||
t.Errorf("ttl value should be %v, got %v",
|
||||
tc.ttl, r.TTL())
|
||||
}
|
||||
if r.Period() != tc.period {
|
||||
t.Errorf("period value should be %v, got %v",
|
||||
tc.period, r.period)
|
||||
}
|
||||
if r.domain != tc.domain {
|
||||
t.Errorf("domain value should be %v, got %v",
|
||||
tc.domain, r.domain)
|
||||
}
|
||||
|
||||
var ns *NameServer
|
||||
if len(r.servers) > 0 {
|
||||
ns = &r.servers[0]
|
||||
}
|
||||
|
||||
if !compareNameServer(ns, tc.ns) {
|
||||
t.Errorf("nameserver not equal, should be %v, got %v",
|
||||
tc.ns, r.servers)
|
||||
}
|
||||
|
||||
if tc.stopped {
|
||||
r.Stop()
|
||||
if r.Period() >= 0 {
|
||||
t.Errorf("period of the stopped reloader should be minus value")
|
||||
}
|
||||
}
|
||||
if r.Stopped() != tc.stopped {
|
||||
t.Errorf("stopped value should be %v, got %v",
|
||||
tc.stopped, r.Stopped())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func compareNameServer(n1, n2 *NameServer) bool {
|
||||
if n1 == n2 {
|
||||
return true
|
||||
}
|
||||
if n1 == nil || n2 == nil {
|
||||
return false
|
||||
}
|
||||
return n1.Addr == n2.Addr &&
|
||||
n1.Hostname == n2.Hostname &&
|
||||
n1.Protocol == n2.Protocol
|
||||
}
|
Reference in New Issue
Block a user