tests: new DNS router/server/mock (#2613)

This commit is contained in:
Ludovic Fernandez 2025-08-08 18:28:50 +02:00 committed by GitHub
commit 0012e20e52
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 1195 additions and 326 deletions

View file

@ -195,8 +195,8 @@ linters:
text: dnsTimeout is a global variable
linters:
- gochecknoglobals
- path: challenge/dns01/nameserver_test.go
text: findXByFqdnTestCases is a global variable
- path: challenge/dns01/precheck.go
text: defaultNameserverPort is a global variable
linters:
- gochecknoglobals
- path: challenge/http01/domain_matcher.go

View file

@ -5,10 +5,18 @@ import (
"os"
"testing"
"github.com/go-acme/lego/v4/platform/tester/dnsmock"
"github.com/miekg/dns"
"github.com/stretchr/testify/require"
)
func TestDNSProviderManual(t *testing.T) {
useAsNameserver(t, dnsmock.NewServer().
Query("_acme-challenge.example.com. CNAME", dnsmock.Noop).
Query("_acme-challenge.example.com. SOA", dnsmock.Error(dns.RcodeNameError)).
Query("example.com. SOA", dnsmock.SOA("")).
Build(t))
backupStdin := os.Stdin
defer func() { os.Stdin = backupStdin }()

View file

@ -11,6 +11,8 @@ import (
"github.com/go-acme/lego/v4/acme/api"
"github.com/go-acme/lego/v4/challenge"
"github.com/go-acme/lego/v4/platform/tester"
"github.com/go-acme/lego/v4/platform/tester/dnsmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -113,6 +115,10 @@ func TestChallenge_PreSolve(t *testing.T) {
}
func TestChallenge_Solve(t *testing.T) {
useAsNameserver(t, dnsmock.NewServer().
Query("_acme-challenge.example.com. CNAME", dnsmock.Noop).
Build(t))
server := tester.MockACMEServer().BuildHTTPS(t)
privateKey, err := rsa.GenerateKey(rand.Reader, 1024)
@ -280,3 +286,55 @@ func TestChallenge_CleanUp(t *testing.T) {
})
}
}
func TestGetChallengeInfo(t *testing.T) {
useAsNameserver(t, dnsmock.NewServer().
Query("_acme-challenge.example.com. CNAME", dnsmock.Noop).
Build(t))
info := GetChallengeInfo("example.com", "123")
expected := ChallengeInfo{
FQDN: "_acme-challenge.example.com.",
EffectiveFQDN: "_acme-challenge.example.com.",
Value: "pmWkWSBCL51Bfkhn79xPuKBKHz__H6B-mY6G9_eieuM",
}
assert.Equal(t, expected, info)
}
func TestGetChallengeInfo_CNAME(t *testing.T) {
useAsNameserver(t, dnsmock.NewServer().
Query("_acme-challenge.example.com. CNAME", dnsmock.CNAME("example.org.")).
Query("example.org. CNAME", dnsmock.Noop).
Build(t))
info := GetChallengeInfo("example.com", "123")
expected := ChallengeInfo{
FQDN: "_acme-challenge.example.com.",
EffectiveFQDN: "example.org.",
Value: "pmWkWSBCL51Bfkhn79xPuKBKHz__H6B-mY6G9_eieuM",
}
assert.Equal(t, expected, info)
}
func TestGetChallengeInfo_CNAME_disabled(t *testing.T) {
useAsNameserver(t, dnsmock.NewServer().
// Never called when the env var works.
Query("_acme-challenge.example.com. CNAME", dnsmock.CNAME("example.org.")).
Build(t))
t.Setenv("LEGO_DISABLE_CNAME_SUPPORT", "true")
info := GetChallengeInfo("example.com", "123")
expected := ChallengeInfo{
FQDN: "_acme-challenge.example.com.",
EffectiveFQDN: "_acme-challenge.example.com.",
Value: "pmWkWSBCL51Bfkhn79xPuKBKHz__H6B-mY6G9_eieuM",
}
assert.Equal(t, expected, info)
}

View file

@ -1,4 +1,4 @@
domain company.com
domain example.com
nameserver 10.200.3.249
nameserver 10.200.3.250:5353
nameserver 2001:4860:4860::8844

View file

@ -0,0 +1,78 @@
package dns01
import (
"context"
"net"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/require"
)
func fakeNS(name, ns string) *dns.NS {
return &dns.NS{
Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 172800},
Ns: ns,
}
}
func fakeA(name, ip string) *dns.A {
return &dns.A{
Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 10},
A: net.ParseIP(ip),
}
}
func fakeTXT(name, value string) *dns.TXT {
return &dns.TXT{
Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 10},
Txt: []string{value},
}
}
// mockResolver modifies the default DNS resolver to use a custom network address during the test execution.
// IMPORTANT: it modifying global variables.
func mockResolver(t *testing.T, addr net.Addr) {
t.Helper()
_, port, err := net.SplitHostPort(addr.String())
require.NoError(t, err)
originalDefaultNameserverPort := defaultNameserverPort
t.Cleanup(func() {
defaultNameserverPort = originalDefaultNameserverPort
})
defaultNameserverPort = port
originalResolver := net.DefaultResolver
t.Cleanup(func() {
net.DefaultResolver = originalResolver
})
net.DefaultResolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{Timeout: 1 * time.Second}
return d.DialContext(ctx, network, addr.String())
},
}
}
func useAsNameserver(t *testing.T, addr net.Addr) {
t.Helper()
ClearFqdnCache()
t.Cleanup(func() {
ClearFqdnCache()
})
originalRecursiveNameservers := recursiveNameservers
t.Cleanup(func() {
recursiveNameservers = originalRecursiveNameservers
})
recursiveNameservers = ParseNameservers([]string{addr.String()})
}

View file

@ -5,138 +5,237 @@ import (
"sort"
"testing"
"github.com/go-acme/lego/v4/platform/tester/dnsmock"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLookupNameserversOK(t *testing.T) {
func Test_lookupNameserversOK(t *testing.T) {
testCases := []struct {
fqdn string
nss []string
desc string
fakeDNSServer *dnsmock.Builder
fqdn string
expected []string
}{
{
fqdn: "en.wikipedia.org.",
nss: []string{"ns0.wikimedia.org.", "ns1.wikimedia.org.", "ns2.wikimedia.org."},
fqdn: "en.wikipedia.org.localhost.",
fakeDNSServer: dnsmock.NewServer().
Query("en.wikipedia.org.localhost SOA", dnsmock.CNAME("dyna.wikimedia.org.localhost")).
Query("wikipedia.org.localhost SOA", dnsmock.SOA("")).
Query("wikipedia.org.localhost NS",
dnsmock.Answer(
fakeNS("wikipedia.org.localhost.", "ns0.wikimedia.org.localhost."),
fakeNS("wikipedia.org.localhost.", "ns1.wikimedia.org.localhost."),
fakeNS("wikipedia.org.localhost.", "ns2.wikimedia.org.localhost."),
),
),
expected: []string{"ns0.wikimedia.org.localhost.", "ns1.wikimedia.org.localhost.", "ns2.wikimedia.org.localhost."},
},
{
fqdn: "www.google.com.",
nss: []string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."},
fqdn: "www.google.com.localhost.",
fakeDNSServer: dnsmock.NewServer().
Query("www.google.com.localhost. SOA", dnsmock.Noop).
Query("google.com.localhost. SOA", dnsmock.SOA("")).
Query("google.com.localhost. NS",
dnsmock.Answer(
fakeNS("google.com.localhost.", "ns1.google.com.localhost."),
fakeNS("google.com.localhost.", "ns2.google.com.localhost."),
fakeNS("google.com.localhost.", "ns3.google.com.localhost."),
fakeNS("google.com.localhost.", "ns4.google.com.localhost."),
),
),
expected: []string{"ns1.google.com.localhost.", "ns2.google.com.localhost.", "ns3.google.com.localhost.", "ns4.google.com.localhost."},
},
{
fqdn: "mail.proton.me.",
nss: []string{"ns1.proton.me.", "ns2.proton.me.", "ns3.proton.me."},
fqdn: "mail.proton.me.localhost.",
fakeDNSServer: dnsmock.NewServer().
Query("mail.proton.me.localhost. SOA", dnsmock.Noop).
Query("proton.me.localhost. SOA", dnsmock.SOA("")).
Query("proton.me.localhost. NS",
dnsmock.Answer(
fakeNS("proton.me.localhost.", "ns1.proton.me.localhost."),
fakeNS("proton.me.localhost.", "ns2.proton.me.localhost."),
fakeNS("proton.me.localhost.", "ns3.proton.me.localhost."),
),
),
expected: []string{"ns1.proton.me.localhost.", "ns2.proton.me.localhost.", "ns3.proton.me.localhost."},
},
}
for _, test := range testCases {
t.Run(test.fqdn, func(t *testing.T) {
t.Parallel()
useAsNameserver(t, test.fakeDNSServer.Build(t))
nss, err := lookupNameservers(test.fqdn)
require.NoError(t, err)
sort.Strings(nss)
sort.Strings(test.nss)
sort.Strings(test.expected)
assert.Equal(t, test.nss, nss)
assert.Equal(t, test.expected, nss)
})
}
}
func TestLookupNameserversErr(t *testing.T) {
func Test_lookupNameserversErr(t *testing.T) {
testCases := []struct {
desc string
fqdn string
error string
desc string
fqdn string
fakeDNSServer *dnsmock.Builder
error string
}{
{
desc: "invalid tld",
fqdn: "example.invalid.",
error: "could not find zone",
desc: "NXDOMAIN",
fqdn: "example.invalid.",
fakeDNSServer: dnsmock.NewServer().
Query(". SOA", dnsmock.Error(dns.RcodeNameError)),
error: "could not find zone: [fqdn=example.invalid.] could not find the start of authority for 'example.invalid.' [question='invalid. IN SOA', code=NXDOMAIN]",
},
{
desc: "NS error",
fqdn: "example.com.",
fakeDNSServer: dnsmock.NewServer().
Query("example.com. SOA", dnsmock.SOA("")).
Query("example.com. NS", dnsmock.Error(dns.RcodeServerFailure)),
error: "[zone=example.com.] could not determine authoritative nameservers",
},
{
desc: "empty NS",
fqdn: "example.com.",
fakeDNSServer: dnsmock.NewServer().
Query("example.com. SOA", dnsmock.SOA("")).
Query("example.me NS", dnsmock.Noop),
error: "[zone=example.com.] could not determine authoritative nameservers",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
useAsNameserver(t, test.fakeDNSServer.Build(t))
_, err := lookupNameservers(test.fqdn)
require.Error(t, err)
assert.Contains(t, err.Error(), test.error)
assert.EqualError(t, err, test.error)
})
}
}
var findXByFqdnTestCases = []struct {
type lookupSoaByFqdnTestCase struct {
desc string
fqdn string
zone string
primaryNs string
nameservers []string
expectedError string
}{
{
desc: "domain is a CNAME",
fqdn: "mail.google.com.",
zone: "google.com.",
primaryNs: "ns1.google.com.",
nameservers: recursiveNameservers,
},
{
desc: "domain is a non-existent subdomain",
fqdn: "foo.google.com.",
zone: "google.com.",
primaryNs: "ns1.google.com.",
nameservers: recursiveNameservers,
},
{
desc: "domain is a eTLD",
fqdn: "example.com.ac.",
zone: "ac.",
primaryNs: "a0.nic.ac.",
nameservers: recursiveNameservers,
},
{
desc: "domain is a cross-zone CNAME",
fqdn: "cross-zone-example.assets.sh.",
zone: "assets.sh.",
primaryNs: "gina.ns.cloudflare.com.",
nameservers: recursiveNameservers,
},
{
desc: "NXDOMAIN",
fqdn: "test.lego.invalid.",
zone: "lego.invalid.",
nameservers: []string{"8.8.8.8:53"},
expectedError: `[fqdn=test.lego.invalid.] could not find the start of authority for 'test.lego.invalid.' [question='invalid. IN SOA', code=NXDOMAIN]`,
},
{
desc: "several non existent nameservers",
fqdn: "mail.google.com.",
zone: "google.com.",
primaryNs: "ns1.google.com.",
nameservers: []string{":7053", ":8053", "8.8.8.8:53"},
},
{
desc: "only non-existent nameservers",
fqdn: "mail.google.com.",
zone: "google.com.",
nameservers: []string{":7053", ":8053", ":9053"},
// use only the start of the message because the port changes with each call: 127.0.0.1:XXXXX->127.0.0.1:7053.
expectedError: "[fqdn=mail.google.com.] could not find the start of authority for 'mail.google.com.': DNS call error: read udp ",
},
{
desc: "no nameservers",
fqdn: "test.example.com.",
zone: "example.com.",
nameservers: []string{},
expectedError: "[fqdn=test.example.com.] could not find the start of authority for 'test.example.com.': empty list of nameservers",
},
}
func lookupSoaByFqdnTestCases(t *testing.T) []lookupSoaByFqdnTestCase {
t.Helper()
return []lookupSoaByFqdnTestCase{
{
desc: "domain is a CNAME",
fqdn: "mail.example.com.",
zone: "example.com.",
primaryNs: "ns1.example.com.",
nameservers: []string{
dnsmock.NewServer().
Query("mail.example.com. SOA", dnsmock.CNAME("example.com.")).
Query("example.com. SOA", dnsmock.SOA("")).
Build(t).
String(),
},
},
{
desc: "domain is a non-existent subdomain",
fqdn: "foo.example.com.",
zone: "example.com.",
primaryNs: "ns1.example.com.",
nameservers: []string{
dnsmock.NewServer().
Query("foo.example.com. SOA", dnsmock.Error(dns.RcodeNameError)).
Query("example.com. SOA", dnsmock.SOA("")).
Build(t).
String(),
},
},
{
desc: "domain is a eTLD",
fqdn: "example.com.ac.",
zone: "ac.",
primaryNs: "ns1.nic.ac.",
nameservers: []string{
dnsmock.NewServer().
Query("example.com.ac. SOA", dnsmock.Error(dns.RcodeNameError)).
Query("com.ac. SOA", dnsmock.Error(dns.RcodeNameError)).
Query("ac. SOA", dnsmock.SOA("")).
Build(t).
String(),
},
},
{
desc: "domain is a cross-zone CNAME",
fqdn: "cross-zone-example.example.com.",
zone: "example.com.",
primaryNs: "ns1.example.com.",
nameservers: []string{
dnsmock.NewServer().
Query("cross-zone-example.example.com. SOA", dnsmock.CNAME("example.org.")).
Query("example.com. SOA", dnsmock.SOA("")).
Build(t).
String(),
},
},
{
desc: "NXDOMAIN",
fqdn: "test.lego.invalid.",
zone: "lego.invalid.",
nameservers: []string{
dnsmock.NewServer().
Query("test.lego.invalid. SOA", dnsmock.Error(dns.RcodeNameError)).
Query("lego.invalid. SOA", dnsmock.Error(dns.RcodeNameError)).
Query("invalid. SOA", dnsmock.Error(dns.RcodeNameError)).
Build(t).
String(),
},
expectedError: `[fqdn=test.lego.invalid.] could not find the start of authority for 'test.lego.invalid.' [question='invalid. IN SOA', code=NXDOMAIN]`,
},
{
desc: "several non existent nameservers",
fqdn: "mail.example.com.",
zone: "example.com.",
primaryNs: "ns1.example.com.",
nameservers: []string{
":7053",
":8053",
dnsmock.NewServer().
Query("mail.example.com. SOA", dnsmock.CNAME("example.com.")).
Query("example.com. SOA", dnsmock.SOA("")).
Build(t).
String(),
},
},
{
desc: "only non-existent nameservers",
fqdn: "mail.example.com.",
zone: "example.com.",
nameservers: []string{":7053", ":8053", ":9053"},
// use only the start of the message because the port changes with each call: 127.0.0.1:XXXXX->127.0.0.1:7053.
expectedError: "[fqdn=mail.example.com.] could not find the start of authority for 'mail.example.com.': DNS call error: read udp ",
},
{
desc: "no nameservers",
fqdn: "test.example.com.",
zone: "example.com.",
nameservers: []string{},
expectedError: "[fqdn=test.example.com.] could not find the start of authority for 'test.example.com.': empty list of nameservers",
},
}
}
func TestFindZoneByFqdnCustom(t *testing.T) {
for _, test := range findXByFqdnTestCases {
for _, test := range lookupSoaByFqdnTestCases(t) {
t.Run(test.desc, func(t *testing.T) {
ClearFqdnCache()
@ -153,7 +252,7 @@ func TestFindZoneByFqdnCustom(t *testing.T) {
}
func TestFindPrimaryNsByFqdnCustom(t *testing.T) {
for _, test := range findXByFqdnTestCases {
for _, test := range lookupSoaByFqdnTestCases(t) {
t.Run(test.desc, func(t *testing.T) {
ClearFqdnCache()
@ -169,7 +268,7 @@ func TestFindPrimaryNsByFqdnCustom(t *testing.T) {
}
}
func TestResolveConfServers(t *testing.T) {
func Test_getNameservers_ResolveConfServers(t *testing.T) {
testCases := []struct {
fixture string
expected []string

View file

@ -9,6 +9,10 @@ import (
"github.com/miekg/dns"
)
// defaultNameserverPort used by authoritative NS.
// This is for tests only.
var defaultNameserverPort = "53"
// PreCheckFunc checks DNS propagation before notifying ACME that the DNS challenge is ready.
type PreCheckFunc func(fqdn, value string) (bool, error)
@ -121,7 +125,7 @@ func (p preCheck) checkDNSPropagation(fqdn, value string) (bool, error) {
func checkNameserversPropagation(fqdn, value string, nameservers []string, addPort bool) (bool, error) {
for _, ns := range nameservers {
if addPort {
ns = net.JoinHostPort(ns, "53")
ns = net.JoinHostPort(ns, defaultNameserverPort)
}
r, err := dnsQuery(fqdn, dns.TypeTXT, []string{ns}, false)

View file

@ -3,40 +3,73 @@ package dns01
import (
"testing"
"github.com/go-acme/lego/v4/platform/tester/dnsmock"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCheckDNSPropagation(t *testing.T) {
func Test_preCheck_checkDNSPropagation(t *testing.T) {
mockResolver(t,
dnsmock.NewServer().
Query("ns0.lego.localhost. A",
dnsmock.Answer(fakeA("ns0.lego.localhost.", "127.0.0.1"))).
Query("ns1.lego.localhost. A",
dnsmock.Answer(fakeA("ns1.lego.localhost.", "127.0.0.1"))).
Query("example.com. TXT",
dnsmock.Answer(
fakeTXT("example.com.", "one"),
fakeTXT("example.com.", "two"),
fakeTXT("example.com.", "three"),
fakeTXT("example.com.", "four"),
fakeTXT("example.com.", "five"),
),
).
Build(t),
)
useAsNameserver(t,
dnsmock.NewServer().
Query("acme-staging.api.example.com. SOA", dnsmock.Error(dns.RcodeNameError)).
Query("api.example.com. SOA", dnsmock.Error(dns.RcodeNameError)).
Query("example.com. SOA", dnsmock.SOA("")).
Query("example.com. NS",
dnsmock.Answer(
fakeNS("example.com.", "ns0.lego.localhost."),
fakeNS("example.com.", "ns1.lego.localhost."),
),
).
Build(t),
)
testCases := []struct {
desc string
fqdn string
value string
expectError bool
desc string
fqdn string
value string
expectedError string
}{
{
desc: "success",
fqdn: "postman-echo.com.",
value: "postman-domain-verification=c85de626cb79d941310696e06558e2e790223802f3697dfbdcaf65510152d52c",
fqdn: "example.com.",
value: "four",
},
{
desc: "no TXT record",
fqdn: "acme-staging.api.letsencrypt.org.",
value: "fe01=",
expectError: true,
desc: "no matching TXT record",
fqdn: "acme-staging.api.example.com.",
value: "fe01=",
expectedError: "did not return the expected TXT record [fqdn: acme-staging.api.example.com., value: fe01=]: one ,two ,three ,four ,five",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
ClearFqdnCache()
check := newPreCheck()
ok, err := check.checkDNSPropagation(test.fqdn, test.value)
if test.expectError {
assert.Errorf(t, err, "PreCheckDNS must fail for %s", test.fqdn)
if test.expectedError != "" {
assert.ErrorContainsf(t, err, test.expectedError, "PreCheckDNS must fail for %s", test.fqdn)
assert.False(t, ok, "PreCheckDNS must fail for %s", test.fqdn)
} else {
assert.NoErrorf(t, err, "PreCheckDNS failed for %s", test.fqdn)
@ -46,69 +79,67 @@ func TestCheckDNSPropagation(t *testing.T) {
}
}
func TestCheckAuthoritativeNss(t *testing.T) {
func Test_checkNameserversPropagation_authoritativeNss(t *testing.T) {
testCases := []struct {
desc string
fqdn, value string
ns []string
expected bool
desc string
fqdn, value string
fakeDNSServer *dnsmock.Builder
expectedError string
}{
{
desc: "TXT RR w/ expected value",
fqdn: "8.8.8.8.asn.routeviews.org.",
value: "151698.8.8.024",
ns: []string{"asnums.routeviews.org."},
expected: true,
desc: "TXT RR w/ expected value",
// NS: asnums.routeviews.org.
fqdn: "8.8.8.8.asn.routeviews.org.",
value: "151698.8.8.024",
fakeDNSServer: dnsmock.NewServer().
Query("8.8.8.8.asn.routeviews.org. TXT",
dnsmock.Answer(
fakeTXT("8.8.8.8.asn.routeviews.org.", "151698.8.8.024"),
),
),
},
{
desc: "TXT RR w/ unexpected value",
// NS: asnums.routeviews.org.
fqdn: "8.8.8.8.asn.routeviews.org.",
value: "fe01=",
fakeDNSServer: dnsmock.NewServer().
Query("8.8.8.8.asn.routeviews.org. TXT",
dnsmock.Answer(
fakeTXT("8.8.8.8.asn.routeviews.org.", "15169"),
fakeTXT("8.8.8.8.asn.routeviews.org.", "8.8.8.0"),
fakeTXT("8.8.8.8.asn.routeviews.org.", "24"),
),
),
expectedError: "did not return the expected TXT record [fqdn: 8.8.8.8.asn.routeviews.org., value: fe01=]: 15169 ,8.8.8.0 ,24",
},
{
desc: "No TXT RR",
fqdn: "ns1.google.com.",
ns: []string{"ns2.google.com."},
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
ClearFqdnCache()
ok, _ := checkNameserversPropagation(test.fqdn, test.value, test.ns, true)
assert.Equal(t, test.expected, ok, test.fqdn)
})
}
}
func TestCheckAuthoritativeNssErr(t *testing.T) {
testCases := []struct {
desc string
fqdn, value string
ns []string
error string
}{
{
desc: "TXT RR /w unexpected value",
fqdn: "8.8.8.8.asn.routeviews.org.",
value: "fe01=",
ns: []string{"asnums.routeviews.org."},
error: "did not return the expected TXT record",
},
{
desc: "No TXT RR",
// NS: ns2.google.com.
fqdn: "ns1.google.com.",
value: "fe01=",
ns: []string{"ns2.google.com."},
error: "did not return the expected TXT record",
fakeDNSServer: dnsmock.NewServer().
Query("ns1.google.com.", dnsmock.Noop),
expectedError: "did not return the expected TXT record [fqdn: ns1.google.com., value: fe01=]: ",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
ClearFqdnCache()
_, err := checkNameserversPropagation(test.fqdn, test.value, test.ns, true)
require.Error(t, err)
assert.Contains(t, err.Error(), test.error)
addr := test.fakeDNSServer.Build(t)
ok, err := checkNameserversPropagation(test.fqdn, test.value, []string{addr.String()}, false)
if test.expectedError == "" {
require.NoError(t, err)
assert.True(t, ok)
} else {
require.Error(t, err)
require.ErrorContains(t, err, test.expectedError)
assert.False(t, ok)
}
})
}
}

View file

@ -0,0 +1,191 @@
package dnsmock
import (
"fmt"
"math"
"net"
"strings"
"sync"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/require"
)
const noType uint16 = math.MaxUint16
type Option func(*dns.Server) error
type Builder struct {
// domain -> op -> type
routes map[string]map[int]map[uint16]dns.Handler
stringToType map[string]uint16
}
func NewServer() *Builder {
stringToType := make(map[string]uint16)
for typ, str := range dns.TypeToString {
stringToType[str] = typ
}
return &Builder{
routes: make(map[string]map[int]map[uint16]dns.Handler),
stringToType: stringToType,
}
}
func (b *Builder) Query(pattern string, handler dns.HandlerFunc) *Builder {
route, err := b.route(pattern, dns.OpcodeQuery, handler)
if err != nil {
panic(err.Error())
}
return route
}
func (b *Builder) Update(pattern string, handler dns.HandlerFunc) *Builder {
route, err := b.route(pattern, dns.OpcodeUpdate, handler)
if err != nil {
panic(err.Error())
}
return route
}
func (b *Builder) route(pattern string, op int, handler dns.HandlerFunc) (*Builder, error) {
parts := strings.Fields(pattern)
domain := parts[0]
_, ok := dns.IsDomainName(domain)
if !ok {
return nil, fmt.Errorf("%s: invalid domain: %s", dns.OpcodeToString[op], domain)
}
if _, ok := b.routes[domain]; !ok {
b.routes[domain] = make(map[int]map[uint16]dns.Handler)
}
if _, ok := b.routes[domain][op]; !ok {
b.routes[domain][op] = make(map[uint16]dns.Handler)
}
if _, ok := b.routes[domain][op][noType]; ok {
return nil, fmt.Errorf("%s: a global route already exists for the domain: %s", dns.OpcodeToString[op], domain)
}
switch len(parts) {
case 1:
if len(b.routes[domain][op]) > 0 {
return nil, fmt.Errorf("%s: global route and specific routes cannot be mixed for the same domain: %s", dns.OpcodeToString[op], domain)
}
b.routes[domain][op][noType] = handler
return b, nil
case 2:
raw := parts[1]
qType, ok := b.stringToType[raw]
if !ok {
return nil, fmt.Errorf("%s: unknown type: %s", dns.OpcodeToString[op], raw)
}
if _, ok := b.routes[domain][op][qType]; ok {
return nil, fmt.Errorf("%s: duplicate route: %s", dns.OpcodeToString[op], pattern)
}
b.routes[domain][op][qType] = handler
return b, nil
default:
return nil, fmt.Errorf("%s: invalid pattern: %s", dns.OpcodeToString[op], pattern)
}
}
func (b *Builder) Build(t *testing.T, options ...Option) net.Addr {
t.Helper()
mux := dns.NewServeMux()
server := &dns.Server{
Addr: "127.0.0.1:0",
Net: "udp",
ReadTimeout: time.Hour,
WriteTimeout: time.Hour,
Handler: mux,
MsgAcceptFunc: func(dh dns.Header) dns.MsgAcceptAction {
// bypass defaultMsgAcceptFunc to allow dynamic update (https://github.com/miekg/dns/pull/830)
return dns.MsgAccept
},
}
for _, option := range options {
require.NoError(t, option(server))
}
for pattern, ops := range b.routes {
mux.HandleFunc(pattern, func(w dns.ResponseWriter, req *dns.Msg) {
mTypes, ok := ops[req.Opcode]
if !ok {
_ = w.WriteMsg(new(dns.Msg).SetRcode(req, dns.RcodeNotImplemented))
return
}
if h, found := mTypes[noType]; found {
h.ServeDNS(w, req)
return
}
// For safety but it doesn't happen.
if len(req.Question) == 0 {
_ = w.WriteMsg(new(dns.Msg).SetRcode(req, dns.RcodeRefused))
return
}
// For safety but it doesn't happen.
if req.Question[0].Qclass != dns.ClassINET {
_ = w.WriteMsg(new(dns.Msg).SetRcode(req, dns.RcodeRefused))
return
}
// Works only for [Query].
h, ok := mTypes[req.Question[0].Qtype]
if !ok {
_ = w.WriteMsg(new(dns.Msg).SetRcode(req, dns.RcodeNotImplemented))
return
}
h.ServeDNS(w, req)
})
}
t.Cleanup(func() {
_ = server.Shutdown()
})
waitLock := sync.Mutex{}
waitLock.Lock()
server.NotifyStartedFunc = waitLock.Unlock
go func() {
err := server.ListenAndServe()
if err != nil {
t.Log(err)
}
}()
waitLock.Lock()
return server.PacketConn.LocalAddr()
}

View file

@ -0,0 +1,240 @@
package dnsmock
import (
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestServer_Query_matchType(t *testing.T) {
addr := NewServer().
Query("example.com. SOA", Noop).
Build(t)
client := &dns.Client{Timeout: 1 * time.Second}
m := new(dns.Msg).SetQuestion("example.com.", dns.TypeSOA)
r, _, err := client.Exchange(m, addr.String())
require.NoError(t, err)
require.Equalf(t, dns.RcodeSuccess, r.Rcode,
"expected %s, got %s", dns.RcodeToString[dns.RcodeSuccess], dns.RcodeToString[r.Rcode])
assert.Equal(t, m.Question, r.Question)
}
func TestServer_Query_noType(t *testing.T) {
addr := NewServer().
Query("example.com.", Noop).
Build(t)
client := &dns.Client{Timeout: 1 * time.Second}
m := new(dns.Msg).SetQuestion("example.com.", dns.TypeSOA)
r, _, err := client.Exchange(m, addr.String())
require.NoError(t, err)
require.Equalf(t, dns.RcodeSuccess, r.Rcode,
"expected %s, got %s", dns.RcodeToString[dns.RcodeSuccess], dns.RcodeToString[r.Rcode])
assert.Equal(t, m.Question, r.Question)
}
func TestServer_Query_noMatch_domain(t *testing.T) {
addr := NewServer().
Query("example.com. SOA", Noop).
Build(t)
client := &dns.Client{Timeout: 1 * time.Second}
m := new(dns.Msg).SetQuestion("example.org.", dns.TypeSOA)
r, _, err := client.Exchange(m, addr.String())
require.NoError(t, err)
require.Equalf(t, dns.RcodeRefused, r.Rcode,
"expected %s, got %s", dns.RcodeToString[dns.RcodeRefused], dns.RcodeToString[r.Rcode])
assert.Equal(t, m.Question, r.Question)
}
func TestServer_Query_noMatch_type(t *testing.T) {
addr := NewServer().
Query("example.com. SOA", Noop).
Build(t)
client := &dns.Client{Timeout: 1 * time.Second}
m := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
r, _, err := client.Exchange(m, addr.String())
require.NoError(t, err)
require.Equalf(t, dns.RcodeNotImplemented, r.Rcode,
"expected %s, got %s", dns.RcodeToString[dns.RcodeNotImplemented], dns.RcodeToString[r.Rcode])
assert.Equal(t, m.Question, r.Question)
}
func TestServer_Query_noMatch_opType(t *testing.T) {
addr := NewServer().
Query("example.com.", Noop).
Build(t)
client := &dns.Client{Timeout: 1 * time.Second}
m := new(dns.Msg).SetUpdate("example.com.")
m.Insert([]dns.RR{
&dns.TXT{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1},
Txt: []string{"foo"},
},
})
r, _, err := client.Exchange(m, addr.String())
require.NoError(t, err)
require.Equalf(t, dns.RcodeNotImplemented, r.Rcode,
"expected %s, got %s", dns.RcodeToString[dns.RcodeNotImplemented], dns.RcodeToString[r.Rcode])
assert.Equal(t, m.Question, r.Question)
}
func TestServer_Query_unknownType(t *testing.T) {
assert.PanicsWithValue(t, "QUERY: unknown type: ABC", func() {
NewServer().
Query("example.com. ABC", Noop).
Build(t)
})
}
func TestServer_Query_duplicate(t *testing.T) {
assert.PanicsWithValue(t, "QUERY: duplicate route: example.com. SOA", func() {
NewServer().
Query("example.com. SOA", Noop).
Query("example.com. SOA", Noop).
Build(t)
})
}
func TestServer_Query_duplicateGlobal(t *testing.T) {
assert.PanicsWithValue(t, "QUERY: a global route already exists for the domain: example.com.", func() {
NewServer().
Query("example.com.", Noop).
Query("example.com.", Noop).
Build(t)
})
}
func TestServer_Query_mixed(t *testing.T) {
assert.PanicsWithValue(t, "QUERY: global route and specific routes cannot be mixed for the same domain: example.com.", func() {
NewServer().
Query("example.com. SOA", Noop).
Query("example.com.", Noop).
Build(t)
})
}
func TestServer_Query_invalidDomain(t *testing.T) {
assert.PanicsWithValue(t, "QUERY: invalid domain: .example.com.", func() {
NewServer().
Query(".example.com. SOA", Noop).
Build(t)
})
}
func TestServer_Query_invalidPattern(t *testing.T) {
assert.PanicsWithValue(t, "QUERY: invalid pattern: example.com. SOA 13", func() {
NewServer().
Query("example.com. SOA 13", Noop).
Build(t)
})
}
func TestServer_Update(t *testing.T) {
addr := NewServer().
Update("example.com.", Noop).
Build(t)
client := &dns.Client{Timeout: 1 * time.Second}
m := new(dns.Msg).SetUpdate("example.com.")
m.Insert([]dns.RR{
&dns.TXT{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1},
Txt: []string{"foo"},
},
})
r, _, err := client.Exchange(m, addr.String())
require.NoError(t, err)
require.Equalf(t, dns.RcodeSuccess, r.Rcode,
"expected %s, got %s", dns.RcodeToString[dns.RcodeSuccess], dns.RcodeToString[r.Rcode])
assert.Equal(t, m.Question, r.Question)
}
func TestServer_Update_noMatch_domain(t *testing.T) {
addr := NewServer().
Update("example.com.", Noop).
Build(t)
client := &dns.Client{Timeout: 1 * time.Second}
m := new(dns.Msg).SetUpdate("example.org.")
m.Insert([]dns.RR{
&dns.TXT{
Hdr: dns.RR_Header{Name: "example.org.", Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1},
Txt: []string{"foo"},
},
})
r, _, err := client.Exchange(m, addr.String())
require.NoError(t, err)
require.Equalf(t, dns.RcodeRefused, r.Rcode,
"expected %s, got %s", dns.RcodeToString[dns.RcodeRefused], dns.RcodeToString[r.Rcode])
assert.Equal(t, m.Question, r.Question)
}
func TestServer_Update_noMatch_opType(t *testing.T) {
addr := NewServer().
Update("example.com.", Noop).
Build(t)
client := &dns.Client{Timeout: 1 * time.Second}
m := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
r, _, err := client.Exchange(m, addr.String())
require.NoError(t, err)
require.Equalf(t, dns.RcodeNotImplemented, r.Rcode,
"expected %s, got %s", dns.RcodeToString[dns.RcodeNotImplemented], dns.RcodeToString[r.Rcode])
assert.Equal(t, m.Question, r.Question)
}
func TestServer_Update_duplicate(t *testing.T) {
assert.PanicsWithValue(t, "UPDATE: a global route already exists for the domain: example.com.", func() {
NewServer().
Update("example.com.", Noop).
Update("example.com.", Noop).
Build(t)
})
}
func TestServer_Update_invalidDomain(t *testing.T) {
assert.PanicsWithValue(t, "UPDATE: invalid domain: .example.com.", func() {
NewServer().
Update(".example.com.", Noop).
Build(t)
})
}
func TestServer_Update_invalidPattern(t *testing.T) {
assert.PanicsWithValue(t, "UPDATE: invalid pattern: example.com. SOA 13", func() {
NewServer().
Update("example.com. SOA 13", Noop).
Build(t)
})
}

View file

@ -0,0 +1,76 @@
package dnsmock
import (
"fmt"
"github.com/miekg/dns"
)
func DumpRequest() dns.HandlerFunc {
return func(w dns.ResponseWriter, req *dns.Msg) {
fmt.Println(req)
Noop(w, req)
}
}
func SOA(name string) dns.HandlerFunc {
return func(w dns.ResponseWriter, req *dns.Msg) {
if name == "" {
name = req.Question[0].Name
}
// Handle TLD
base := name
if dns.CountLabel(req.Question[0].Name) == 1 {
base = "nic." + req.Question[0].Name
}
answer := &dns.SOA{
Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: 120},
Ns: "ns1." + base,
Mbox: "admin." + base,
Serial: 2016022801,
Refresh: 28800,
Retry: 7200,
Expire: 2419200,
Minttl: 1200,
}
Answer(answer)(w, req)
}
}
func CNAME(target string) dns.HandlerFunc {
return func(w dns.ResponseWriter, req *dns.Msg) {
answer := &dns.CNAME{
Hdr: dns.RR_Header{Name: req.Question[0].Name, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 1},
Target: dns.Fqdn(target),
}
Answer(answer)(w, req)
}
}
func Noop(w dns.ResponseWriter, req *dns.Msg) {
_ = w.WriteMsg(new(dns.Msg).SetReply(req))
}
func Error(rcode int) dns.HandlerFunc {
return func(w dns.ResponseWriter, req *dns.Msg) {
_ = w.WriteMsg(new(dns.Msg).SetRcode(req, rcode))
}
}
func Answer(answer ...dns.RR) func(w dns.ResponseWriter, req *dns.Msg) {
return func(w dns.ResponseWriter, req *dns.Msg) {
m := new(dns.Msg).SetReply(req)
m.Answer = answer
err := w.WriteMsg(m)
if err != nil {
panic(err.Error())
}
}
}

View file

@ -0,0 +1,156 @@
package dnsmock
import (
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSOA_self(t *testing.T) {
addr := NewServer().
Query("example.com. SOA", SOA("")).
Build(t)
client := &dns.Client{Timeout: 1 * time.Second}
m := new(dns.Msg).SetQuestion("example.com.", dns.TypeSOA)
r, _, err := client.Exchange(m, addr.String())
require.NoError(t, err)
expectedSOA := []dns.RR{&dns.SOA{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: 120, Rdlength: 56},
Ns: "ns1.example.com.",
Mbox: "admin.example.com.",
Serial: 2016022801,
Refresh: 28800,
Retry: 7200,
Expire: 2419200,
Minttl: 1200,
}}
require.Equal(t, dns.RcodeSuccess, r.Rcode)
assert.Equal(t, expectedSOA, r.Answer)
assert.Equal(t, m.Question, r.Question)
}
func TestSOA_differentDomain(t *testing.T) {
addr := NewServer().
Query("example.com. SOA", SOA("example.org.")).
Build(t)
client := &dns.Client{Timeout: 1 * time.Second}
m := new(dns.Msg).SetQuestion("example.com.", dns.TypeSOA)
r, _, err := client.Exchange(m, addr.String())
require.NoError(t, err)
require.Equalf(t, dns.RcodeSuccess, r.Rcode,
"expected %s, got %s", dns.RcodeToString[dns.RcodeSuccess], dns.RcodeToString[r.Rcode])
expectedSOA := []dns.RR{&dns.SOA{
Hdr: dns.RR_Header{Name: "example.org.", Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: 120, Rdlength: 56},
Ns: "ns1.example.org.",
Mbox: "admin.example.org.",
Serial: 2016022801,
Refresh: 28800,
Retry: 7200,
Expire: 2419200,
Minttl: 1200,
}}
assert.Equal(t, expectedSOA, r.Answer)
assert.Equal(t, m.Question, r.Question)
}
func TestSOA_tld(t *testing.T) {
addr := NewServer().
Query("com. SOA", SOA("")).
Build(t)
client := &dns.Client{Timeout: 1 * time.Second}
m := new(dns.Msg).SetQuestion("com.", dns.TypeSOA)
r, _, err := client.Exchange(m, addr.String())
require.NoError(t, err)
require.Equalf(t, dns.RcodeSuccess, r.Rcode,
"expected %s, got %s", dns.RcodeToString[dns.RcodeSuccess], dns.RcodeToString[r.Rcode])
expectedSOA := []dns.RR{&dns.SOA{
Hdr: dns.RR_Header{Name: "com.", Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: 120, Rdlength: 48},
Ns: "ns1.nic.com.",
Mbox: "admin.nic.com.",
Serial: 2016022801,
Refresh: 28800,
Retry: 7200,
Expire: 2419200,
Minttl: 1200,
}}
assert.Equal(t, expectedSOA, r.Answer)
assert.Equal(t, m.Question, r.Question)
}
func TestCNAME(t *testing.T) {
addr := NewServer().
Query("example.com. CNAME", CNAME("example.org.")).
Build(t)
client := &dns.Client{Timeout: 1 * time.Second}
m := new(dns.Msg).SetQuestion("example.com.", dns.TypeCNAME)
r, _, err := client.Exchange(m, addr.String())
require.NoError(t, err)
require.Equalf(t, dns.RcodeSuccess, r.Rcode,
"expected %s, got %s", dns.RcodeToString[dns.RcodeSuccess], dns.RcodeToString[r.Rcode])
expectedCNAME := []dns.RR{&dns.CNAME{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 1, Rdlength: 13},
Target: "example.org.",
}}
assert.Equal(t, expectedCNAME, r.Answer)
assert.Equal(t, m.Question, r.Question)
}
func TestNoop(t *testing.T) {
addr := NewServer().
Query("example.com. CNAME", Noop).
Build(t)
client := &dns.Client{Timeout: 1 * time.Second}
m := new(dns.Msg).SetQuestion("example.com.", dns.TypeCNAME)
r, _, err := client.Exchange(m, addr.String())
require.NoError(t, err)
require.Equalf(t, dns.RcodeSuccess, r.Rcode,
"expected %s, got %s", dns.RcodeToString[dns.RcodeSuccess], dns.RcodeToString[r.Rcode])
assert.Equal(t, m.Question, r.Question)
}
func TestError(t *testing.T) {
addr := NewServer().
Query("example.com. CNAME", Error(dns.RcodeNameError)).
Build(t)
client := &dns.Client{Timeout: 1 * time.Second}
m := new(dns.Msg).SetQuestion("example.com.", dns.TypeCNAME)
r, _, err := client.Exchange(m, addr.String())
require.NoError(t, err)
require.Equalf(t, dns.RcodeNameError, r.Rcode,
"expected %s, got %s", dns.RcodeToString[dns.RcodeNameError], dns.RcodeToString[r.Rcode])
assert.Equal(t, m.Question, r.Question)
}

View file

@ -131,7 +131,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
config.TSIGSecret = ""
} else {
// zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2)
config.TSIGKey = strings.ToLower(dns.Fqdn(config.TSIGKey))
config.TSIGKey = dns.CanonicalName(config.TSIGKey)
}
if config.TSIGAlgorithm == "" {
@ -193,14 +193,14 @@ func (d *DNSProvider) changeRecord(action, fqdn, value string, ttl int) error {
}
// Create RR
rr := new(dns.TXT)
rr.Hdr = dns.RR_Header{Name: fqdn, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: uint32(ttl)}
rr.Txt = []string{value}
rrs := []dns.RR{rr}
rrs := []dns.RR{&dns.TXT{
Hdr: dns.RR_Header{Name: fqdn, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: uint32(ttl)},
Txt: []string{value},
}}
// Create dynamic update packet
m := new(dns.Msg)
m.SetUpdate(zone)
m := new(dns.Msg).SetUpdate(zone)
switch action {
case "INSERT":
// Always remove old challenge left over from who knows what.

View file

@ -2,23 +2,21 @@ package rfc2136
import (
"bytes"
"fmt"
"strings"
"sync"
"testing"
"time"
"github.com/go-acme/lego/v4/challenge/dns01"
"github.com/go-acme/lego/v4/platform/tester"
"github.com/go-acme/lego/v4/platform/tester/dnsmock"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
fakeDomain = "123456789.www.example.com"
fakeKeyAuth = "123d=="
fakeValue = "Now36o-3BmlB623-0c1qCIUmgWVVmDJb88KGl24pqpo"
fakeValue = "ADw2sEd82DUgXcQ9hNBZThJs7zVJkR5v9JeSbAb9mZY"
fakeFqdn = "_acme-challenge.123456789.www.example.com."
fakeZone = "example.com."
fakeTTL = 120
@ -162,33 +160,16 @@ func TestNewDNSProviderConfig(t *testing.T) {
}
}
func TestCanaryLocalTestServer(t *testing.T) {
func TestDNSProvider_Present_success(t *testing.T) {
dns01.ClearFqdnCache()
mux, addr := runLocalDNSTestServer(t, false)
mux.HandleFunc("example.com.", serverHandlerHello)
c := new(dns.Client)
m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeTXT)
r, _, err := c.Exchange(m, addr)
require.NoError(t, err, "Failed to communicate with test server")
assert.Len(t, r.Extra, 1, "Failed to communicate with test server")
txt := r.Extra[0].(*dns.TXT).Txt[0]
assert.Equal(t, "Hello world", txt)
}
func TestServerSuccess(t *testing.T) {
dns01.ClearFqdnCache()
mux, addr := runLocalDNSTestServer(t, false)
mux.HandleFunc(fakeZone, serverHandlerReturnSuccess)
addr := dnsmock.NewServer().
Query(fakeZone+" SOA", dnsmock.SOA("")).
Update(fakeZone+" SOA", dnsmock.Noop).
Build(t)
config := NewDefaultConfig()
config.Nameserver = addr
config.Nameserver = addr.String()
provider, err := NewDNSProviderConfig(config)
require.NoError(t, err)
@ -197,14 +178,72 @@ func TestServerSuccess(t *testing.T) {
require.NoError(t, err)
}
func TestServerError(t *testing.T) {
func TestDNSProvider_Present_success_updatePacket(t *testing.T) {
dns01.ClearFqdnCache()
mux, addr := runLocalDNSTestServer(t, false)
mux.HandleFunc(fakeZone, serverHandlerReturnErr)
reqChan := make(chan *dns.Msg, 1)
addr := dnsmock.NewServer().
Query("_acme-challenge.123456789.www.example.com. SOA", dnsmock.SOA(fakeZone)).
Update(fakeZone+" SOA", func(w dns.ResponseWriter, req *dns.Msg) {
dnsmock.Noop(w, req)
// Only talk back when it is not the SOA RR.
reqChan <- req
}).
Build(t)
config := NewDefaultConfig()
config.Nameserver = addr
config.Nameserver = addr.String()
provider, err := NewDNSProviderConfig(config)
require.NoError(t, err)
err = provider.Present(fakeDomain, "", fakeKeyAuth)
require.NoError(t, err)
select {
case <-time.After(time.Second):
t.Fatal("timeout waiting for request")
case rcvMsg := <-reqChan:
txtRR := &dns.TXT{
Hdr: dns.RR_Header{Name: fakeFqdn, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: fakeTTL},
Txt: []string{fakeValue},
}
m := new(dns.Msg).SetUpdate(fakeZone)
m.RemoveRRset([]dns.RR{txtRR})
m.Insert([]dns.RR{txtRR})
expected, err := m.Pack()
require.NoError(t, err, "error packing")
rcvMsg.Id = m.Id
actual, err := rcvMsg.Pack()
require.NoError(t, err, "error packing")
if !bytes.Equal(actual, expected) {
tmp := new(dns.Msg)
require.NoError(t, tmp.Unpack(actual))
t.Errorf("Expected msg:\n%s", m)
t.Errorf("Actual msg:\n%s", tmp)
}
}
}
func TestDNSProvider_Present_error(t *testing.T) {
dns01.ClearFqdnCache()
addr := dnsmock.NewServer().
Query(fakeZone+" SOA", dnsmock.Error(dns.RcodeNotZone)).
Build(t)
config := NewDefaultConfig()
config.Nameserver = addr.String()
provider, err := NewDNSProviderConfig(config)
require.NoError(t, err)
@ -216,14 +255,20 @@ func TestServerError(t *testing.T) {
}
}
func TestTsigClient(t *testing.T) {
func TestDNSProvider_Present_tsig_success(t *testing.T) {
dns01.ClearFqdnCache()
mux, addr := runLocalDNSTestServer(t, true)
mux.HandleFunc(fakeZone, serverHandlerReturnSuccess)
addr := dnsmock.NewServer().
Query(fakeZone+" SOA", dnsmock.SOA("")).
Update(fakeZone+" SOA", handleTSIG).
Build(t, func(server *dns.Server) error {
server.TsigSecret = map[string]string{fakeTsigKey: fakeTsigSecret}
return nil
})
config := NewDefaultConfig()
config.Nameserver = addr
config.Nameserver = addr.String()
config.TSIGKey = fakeTsigKey
config.TSIGSecret = fakeTsigSecret
@ -234,167 +279,50 @@ func TestTsigClient(t *testing.T) {
require.NoError(t, err)
}
func TestValidUpdatePacket(t *testing.T) {
reqChan := make(chan *dns.Msg, 10)
func TestDNSProvider_Present_tsig_error(t *testing.T) {
dns01.ClearFqdnCache()
mux, addr := runLocalDNSTestServer(t, false)
mux.HandleFunc(fakeZone, serverHandlerPassBackRequest(reqChan))
addr := dnsmock.NewServer().
Query(fakeZone+" SOA", dnsmock.SOA("")).
Update(fakeZone+" SOA", handleTSIG).
Build(t, func(server *dns.Server) error {
server.TsigSecret = map[string]string{"example.org": fakeTsigSecret}
txtRR, _ := dns.NewRR(fmt.Sprintf("%s %d IN TXT %s", fakeFqdn, fakeTTL, fakeValue))
m := new(dns.Msg)
m.SetUpdate(fakeZone)
m.RemoveRRset([]dns.RR{txtRR})
m.Insert([]dns.RR{txtRR})
expectStr := m.String()
expect, err := m.Pack()
require.NoError(t, err, "error packing")
return nil
})
config := NewDefaultConfig()
config.Nameserver = addr
config.Nameserver = addr.String()
config.TSIGKey = fakeTsigKey
config.TSIGSecret = fakeTsigSecret
provider, err := NewDNSProviderConfig(config)
require.NoError(t, err)
err = provider.Present(fakeDomain, "", "1234d==")
require.NoError(t, err)
rcvMsg := <-reqChan
rcvMsg.Id = m.Id
actual, err := rcvMsg.Pack()
require.NoError(t, err, "error packing")
if !bytes.Equal(actual, expect) {
tmp := new(dns.Msg)
if err := tmp.Unpack(actual); err != nil {
t.Fatalf("Error unpacking actual msg: %v", err)
}
t.Errorf("Expected msg:\n%s", expectStr)
t.Errorf("Actual msg:\n%v", tmp)
}
err = provider.Present(fakeDomain, "", fakeKeyAuth)
require.Error(t, err)
require.EqualError(t, err, "rfc2136: failed to insert: DNS update failed: server replied: NOTZONE")
}
func runLocalDNSTestServer(t *testing.T, tsig bool) (*dns.ServeMux, string) {
t.Helper()
mux := dns.NewServeMux()
server := &dns.Server{
Addr: "127.0.0.1:0",
Net: "udp",
ReadTimeout: time.Hour,
WriteTimeout: time.Hour,
MsgAcceptFunc: func(dh dns.Header) dns.MsgAcceptAction {
// bypass defaultMsgAcceptFunc to allow dynamic update (https://github.com/miekg/dns/pull/830)
return dns.MsgAccept
},
Handler: mux,
}
t.Cleanup(func() {
_ = server.Shutdown()
})
if tsig {
server.TsigSecret = map[string]string{fakeTsigKey: fakeTsigSecret}
}
waitLock := sync.Mutex{}
waitLock.Lock()
server.NotifyStartedFunc = waitLock.Unlock
go func() {
err := server.ListenAndServe()
if err != nil {
t.Log(err)
}
}()
waitLock.Lock()
return mux, server.PacketConn.LocalAddr().String()
}
func serverHandlerHello(w dns.ResponseWriter, req *dns.Msg) {
func handleTSIG(w dns.ResponseWriter, req *dns.Msg) {
m := new(dns.Msg)
m.SetReply(req)
m.Extra = make([]dns.RR, 1)
m.Extra[0] = &dns.TXT{
Hdr: dns.RR_Header{Name: m.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0},
Txt: []string{"Hello world"},
tsig := req.IsTsig()
if tsig == nil {
_ = w.WriteMsg(m.SetRcode(req, dns.RcodeRefused))
return
}
_ = w.WriteMsg(m)
}
err := w.TsigStatus()
if err != nil {
_ = w.WriteMsg(m.SetRcode(req, dns.RcodeNotZone))
func serverHandlerReturnSuccess(w dns.ResponseWriter, req *dns.Msg) {
m := new(dns.Msg)
m.SetReply(req)
if req.Opcode == dns.OpcodeQuery && req.Question[0].Qtype == dns.TypeSOA && req.Question[0].Qclass == dns.ClassINET {
// Return SOA to appease findZoneByFqdn()
m.Answer = []dns.RR{fakeSOAAnswer()}
return
}
if t := req.IsTsig(); t != nil {
if w.TsigStatus() == nil {
// Validated
m.SetTsig(fakeZone, dns.HmacSHA1, 300, time.Now().Unix())
}
}
_ = w.WriteMsg(m)
}
func serverHandlerReturnErr(w dns.ResponseWriter, req *dns.Msg) {
m := new(dns.Msg)
m.SetRcode(req, dns.RcodeNotZone)
_ = w.WriteMsg(m)
}
func serverHandlerPassBackRequest(reqChan chan *dns.Msg) func(w dns.ResponseWriter, req *dns.Msg) {
return func(w dns.ResponseWriter, req *dns.Msg) {
m := new(dns.Msg)
m.SetReply(req)
if req.Opcode == dns.OpcodeQuery && req.Question[0].Qtype == dns.TypeSOA && req.Question[0].Qclass == dns.ClassINET {
// Return SOA to appease findZoneByFqdn()
m.Answer = []dns.RR{fakeSOAAnswer()}
}
if t := req.IsTsig(); t != nil {
if w.TsigStatus() == nil {
// Validated
m.SetTsig(fakeZone, dns.HmacSHA1, 300, time.Now().Unix())
}
}
_ = w.WriteMsg(m)
if req.Opcode != dns.OpcodeQuery || req.Question[0].Qtype != dns.TypeSOA || req.Question[0].Qclass != dns.ClassINET {
// Only talk back when it is not the SOA RR.
reqChan <- req
}
}
}
func fakeSOAAnswer() *dns.SOA {
return &dns.SOA{
Hdr: dns.RR_Header{Name: fakeZone, Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: fakeTTL},
Ns: "ns1." + fakeZone,
Mbox: "admin." + fakeZone,
Serial: 2016022801,
Refresh: 28800,
Retry: 7200,
Expire: 2419200,
Minttl: 1200,
}
// Validated
_ = w.WriteMsg(m.
SetReply(req).
SetTsig(tsig.Hdr.Name, tsig.Algorithm, tsig.Fudge, time.Now().Unix()),
)
}