diff --git a/.golangci.yml b/.golangci.yml index 66f3fd9d0..2fabe806c 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -195,8 +195,8 @@ linters: text: dnsTimeout is a global variable linters: - gochecknoglobals - - path: challenge/dns01/nameserver_test.go - text: findXByFqdnTestCases is a global variable + - path: challenge/dns01/precheck.go + text: defaultNameserverPort is a global variable linters: - gochecknoglobals - path: challenge/http01/domain_matcher.go diff --git a/challenge/dns01/dns_challenge_manual_test.go b/challenge/dns01/dns_challenge_manual_test.go index 26a508d1c..e0a2dc93a 100644 --- a/challenge/dns01/dns_challenge_manual_test.go +++ b/challenge/dns01/dns_challenge_manual_test.go @@ -5,10 +5,18 @@ import ( "os" "testing" + "github.com/go-acme/lego/v4/platform/tester/dnsmock" + "github.com/miekg/dns" "github.com/stretchr/testify/require" ) func TestDNSProviderManual(t *testing.T) { + useAsNameserver(t, dnsmock.NewServer(). + Query("_acme-challenge.example.com. CNAME", dnsmock.Noop). + Query("_acme-challenge.example.com. SOA", dnsmock.Error(dns.RcodeNameError)). + Query("example.com. SOA", dnsmock.SOA("")). + Build(t)) + backupStdin := os.Stdin defer func() { os.Stdin = backupStdin }() diff --git a/challenge/dns01/dns_challenge_test.go b/challenge/dns01/dns_challenge_test.go index 48bd9986c..7c723497c 100644 --- a/challenge/dns01/dns_challenge_test.go +++ b/challenge/dns01/dns_challenge_test.go @@ -11,6 +11,8 @@ import ( "github.com/go-acme/lego/v4/acme/api" "github.com/go-acme/lego/v4/challenge" "github.com/go-acme/lego/v4/platform/tester" + "github.com/go-acme/lego/v4/platform/tester/dnsmock" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -113,6 +115,10 @@ func TestChallenge_PreSolve(t *testing.T) { } func TestChallenge_Solve(t *testing.T) { + useAsNameserver(t, dnsmock.NewServer(). + Query("_acme-challenge.example.com. CNAME", dnsmock.Noop). + Build(t)) + server := tester.MockACMEServer().BuildHTTPS(t) privateKey, err := rsa.GenerateKey(rand.Reader, 1024) @@ -280,3 +286,55 @@ func TestChallenge_CleanUp(t *testing.T) { }) } } + +func TestGetChallengeInfo(t *testing.T) { + useAsNameserver(t, dnsmock.NewServer(). + Query("_acme-challenge.example.com. CNAME", dnsmock.Noop). + Build(t)) + + info := GetChallengeInfo("example.com", "123") + + expected := ChallengeInfo{ + FQDN: "_acme-challenge.example.com.", + EffectiveFQDN: "_acme-challenge.example.com.", + Value: "pmWkWSBCL51Bfkhn79xPuKBKHz__H6B-mY6G9_eieuM", + } + + assert.Equal(t, expected, info) +} + +func TestGetChallengeInfo_CNAME(t *testing.T) { + useAsNameserver(t, dnsmock.NewServer(). + Query("_acme-challenge.example.com. CNAME", dnsmock.CNAME("example.org.")). + Query("example.org. CNAME", dnsmock.Noop). + Build(t)) + + info := GetChallengeInfo("example.com", "123") + + expected := ChallengeInfo{ + FQDN: "_acme-challenge.example.com.", + EffectiveFQDN: "example.org.", + Value: "pmWkWSBCL51Bfkhn79xPuKBKHz__H6B-mY6G9_eieuM", + } + + assert.Equal(t, expected, info) +} + +func TestGetChallengeInfo_CNAME_disabled(t *testing.T) { + useAsNameserver(t, dnsmock.NewServer(). + // Never called when the env var works. + Query("_acme-challenge.example.com. CNAME", dnsmock.CNAME("example.org.")). + Build(t)) + + t.Setenv("LEGO_DISABLE_CNAME_SUPPORT", "true") + + info := GetChallengeInfo("example.com", "123") + + expected := ChallengeInfo{ + FQDN: "_acme-challenge.example.com.", + EffectiveFQDN: "_acme-challenge.example.com.", + Value: "pmWkWSBCL51Bfkhn79xPuKBKHz__H6B-mY6G9_eieuM", + } + + assert.Equal(t, expected, info) +} diff --git a/challenge/dns01/fixtures/resolv.conf.1 b/challenge/dns01/fixtures/resolv.conf.1 index 3098f99b5..bc2a3c1ac 100644 --- a/challenge/dns01/fixtures/resolv.conf.1 +++ b/challenge/dns01/fixtures/resolv.conf.1 @@ -1,4 +1,4 @@ -domain company.com +domain example.com nameserver 10.200.3.249 nameserver 10.200.3.250:5353 nameserver 2001:4860:4860::8844 diff --git a/challenge/dns01/mock_test.go b/challenge/dns01/mock_test.go new file mode 100644 index 000000000..535d79cda --- /dev/null +++ b/challenge/dns01/mock_test.go @@ -0,0 +1,78 @@ +package dns01 + +import ( + "context" + "net" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/stretchr/testify/require" +) + +func fakeNS(name, ns string) *dns.NS { + return &dns.NS{ + Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 172800}, + Ns: ns, + } +} + +func fakeA(name, ip string) *dns.A { + return &dns.A{ + Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 10}, + A: net.ParseIP(ip), + } +} + +func fakeTXT(name, value string) *dns.TXT { + return &dns.TXT{ + Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 10}, + Txt: []string{value}, + } +} + +// mockResolver modifies the default DNS resolver to use a custom network address during the test execution. +// IMPORTANT: it modifying global variables. +func mockResolver(t *testing.T, addr net.Addr) { + t.Helper() + + _, port, err := net.SplitHostPort(addr.String()) + require.NoError(t, err) + + originalDefaultNameserverPort := defaultNameserverPort + t.Cleanup(func() { + defaultNameserverPort = originalDefaultNameserverPort + }) + + defaultNameserverPort = port + + originalResolver := net.DefaultResolver + t.Cleanup(func() { + net.DefaultResolver = originalResolver + }) + + net.DefaultResolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{Timeout: 1 * time.Second} + + return d.DialContext(ctx, network, addr.String()) + }, + } +} + +func useAsNameserver(t *testing.T, addr net.Addr) { + t.Helper() + + ClearFqdnCache() + t.Cleanup(func() { + ClearFqdnCache() + }) + + originalRecursiveNameservers := recursiveNameservers + t.Cleanup(func() { + recursiveNameservers = originalRecursiveNameservers + }) + + recursiveNameservers = ParseNameservers([]string{addr.String()}) +} diff --git a/challenge/dns01/nameserver_test.go b/challenge/dns01/nameserver_test.go index 4eb7a5f15..dd4d66dcb 100644 --- a/challenge/dns01/nameserver_test.go +++ b/challenge/dns01/nameserver_test.go @@ -5,138 +5,237 @@ import ( "sort" "testing" + "github.com/go-acme/lego/v4/platform/tester/dnsmock" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestLookupNameserversOK(t *testing.T) { +func Test_lookupNameserversOK(t *testing.T) { testCases := []struct { - fqdn string - nss []string + desc string + fakeDNSServer *dnsmock.Builder + fqdn string + expected []string }{ { - fqdn: "en.wikipedia.org.", - nss: []string{"ns0.wikimedia.org.", "ns1.wikimedia.org.", "ns2.wikimedia.org."}, + fqdn: "en.wikipedia.org.localhost.", + fakeDNSServer: dnsmock.NewServer(). + Query("en.wikipedia.org.localhost SOA", dnsmock.CNAME("dyna.wikimedia.org.localhost")). + Query("wikipedia.org.localhost SOA", dnsmock.SOA("")). + Query("wikipedia.org.localhost NS", + dnsmock.Answer( + fakeNS("wikipedia.org.localhost.", "ns0.wikimedia.org.localhost."), + fakeNS("wikipedia.org.localhost.", "ns1.wikimedia.org.localhost."), + fakeNS("wikipedia.org.localhost.", "ns2.wikimedia.org.localhost."), + ), + ), + expected: []string{"ns0.wikimedia.org.localhost.", "ns1.wikimedia.org.localhost.", "ns2.wikimedia.org.localhost."}, }, { - fqdn: "www.google.com.", - nss: []string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."}, + fqdn: "www.google.com.localhost.", + fakeDNSServer: dnsmock.NewServer(). + Query("www.google.com.localhost. SOA", dnsmock.Noop). + Query("google.com.localhost. SOA", dnsmock.SOA("")). + Query("google.com.localhost. NS", + dnsmock.Answer( + fakeNS("google.com.localhost.", "ns1.google.com.localhost."), + fakeNS("google.com.localhost.", "ns2.google.com.localhost."), + fakeNS("google.com.localhost.", "ns3.google.com.localhost."), + fakeNS("google.com.localhost.", "ns4.google.com.localhost."), + ), + ), + expected: []string{"ns1.google.com.localhost.", "ns2.google.com.localhost.", "ns3.google.com.localhost.", "ns4.google.com.localhost."}, }, { - fqdn: "mail.proton.me.", - nss: []string{"ns1.proton.me.", "ns2.proton.me.", "ns3.proton.me."}, + fqdn: "mail.proton.me.localhost.", + fakeDNSServer: dnsmock.NewServer(). + Query("mail.proton.me.localhost. SOA", dnsmock.Noop). + Query("proton.me.localhost. SOA", dnsmock.SOA("")). + Query("proton.me.localhost. NS", + dnsmock.Answer( + fakeNS("proton.me.localhost.", "ns1.proton.me.localhost."), + fakeNS("proton.me.localhost.", "ns2.proton.me.localhost."), + fakeNS("proton.me.localhost.", "ns3.proton.me.localhost."), + ), + ), + expected: []string{"ns1.proton.me.localhost.", "ns2.proton.me.localhost.", "ns3.proton.me.localhost."}, }, } for _, test := range testCases { t.Run(test.fqdn, func(t *testing.T) { - t.Parallel() + useAsNameserver(t, test.fakeDNSServer.Build(t)) nss, err := lookupNameservers(test.fqdn) require.NoError(t, err) sort.Strings(nss) - sort.Strings(test.nss) + sort.Strings(test.expected) - assert.Equal(t, test.nss, nss) + assert.Equal(t, test.expected, nss) }) } } -func TestLookupNameserversErr(t *testing.T) { +func Test_lookupNameserversErr(t *testing.T) { testCases := []struct { - desc string - fqdn string - error string + desc string + fqdn string + fakeDNSServer *dnsmock.Builder + error string }{ { - desc: "invalid tld", - fqdn: "example.invalid.", - error: "could not find zone", + desc: "NXDOMAIN", + fqdn: "example.invalid.", + fakeDNSServer: dnsmock.NewServer(). + Query(". SOA", dnsmock.Error(dns.RcodeNameError)), + error: "could not find zone: [fqdn=example.invalid.] could not find the start of authority for 'example.invalid.' [question='invalid. IN SOA', code=NXDOMAIN]", + }, + { + desc: "NS error", + fqdn: "example.com.", + fakeDNSServer: dnsmock.NewServer(). + Query("example.com. SOA", dnsmock.SOA("")). + Query("example.com. NS", dnsmock.Error(dns.RcodeServerFailure)), + error: "[zone=example.com.] could not determine authoritative nameservers", + }, + { + desc: "empty NS", + fqdn: "example.com.", + fakeDNSServer: dnsmock.NewServer(). + Query("example.com. SOA", dnsmock.SOA("")). + Query("example.me NS", dnsmock.Noop), + error: "[zone=example.com.] could not determine authoritative nameservers", }, } for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - t.Parallel() + useAsNameserver(t, test.fakeDNSServer.Build(t)) _, err := lookupNameservers(test.fqdn) require.Error(t, err) - assert.Contains(t, err.Error(), test.error) + assert.EqualError(t, err, test.error) }) } } -var findXByFqdnTestCases = []struct { +type lookupSoaByFqdnTestCase struct { desc string fqdn string zone string primaryNs string nameservers []string expectedError string -}{ - { - desc: "domain is a CNAME", - fqdn: "mail.google.com.", - zone: "google.com.", - primaryNs: "ns1.google.com.", - nameservers: recursiveNameservers, - }, - { - desc: "domain is a non-existent subdomain", - fqdn: "foo.google.com.", - zone: "google.com.", - primaryNs: "ns1.google.com.", - nameservers: recursiveNameservers, - }, - { - desc: "domain is a eTLD", - fqdn: "example.com.ac.", - zone: "ac.", - primaryNs: "a0.nic.ac.", - nameservers: recursiveNameservers, - }, - { - desc: "domain is a cross-zone CNAME", - fqdn: "cross-zone-example.assets.sh.", - zone: "assets.sh.", - primaryNs: "gina.ns.cloudflare.com.", - nameservers: recursiveNameservers, - }, - { - desc: "NXDOMAIN", - fqdn: "test.lego.invalid.", - zone: "lego.invalid.", - nameservers: []string{"8.8.8.8:53"}, - expectedError: `[fqdn=test.lego.invalid.] could not find the start of authority for 'test.lego.invalid.' [question='invalid. IN SOA', code=NXDOMAIN]`, - }, - { - desc: "several non existent nameservers", - fqdn: "mail.google.com.", - zone: "google.com.", - primaryNs: "ns1.google.com.", - nameservers: []string{":7053", ":8053", "8.8.8.8:53"}, - }, - { - desc: "only non-existent nameservers", - fqdn: "mail.google.com.", - zone: "google.com.", - nameservers: []string{":7053", ":8053", ":9053"}, - // use only the start of the message because the port changes with each call: 127.0.0.1:XXXXX->127.0.0.1:7053. - expectedError: "[fqdn=mail.google.com.] could not find the start of authority for 'mail.google.com.': DNS call error: read udp ", - }, - { - desc: "no nameservers", - fqdn: "test.example.com.", - zone: "example.com.", - nameservers: []string{}, - expectedError: "[fqdn=test.example.com.] could not find the start of authority for 'test.example.com.': empty list of nameservers", - }, +} + +func lookupSoaByFqdnTestCases(t *testing.T) []lookupSoaByFqdnTestCase { + t.Helper() + + return []lookupSoaByFqdnTestCase{ + { + desc: "domain is a CNAME", + fqdn: "mail.example.com.", + zone: "example.com.", + primaryNs: "ns1.example.com.", + nameservers: []string{ + dnsmock.NewServer(). + Query("mail.example.com. SOA", dnsmock.CNAME("example.com.")). + Query("example.com. SOA", dnsmock.SOA("")). + Build(t). + String(), + }, + }, + { + desc: "domain is a non-existent subdomain", + fqdn: "foo.example.com.", + zone: "example.com.", + primaryNs: "ns1.example.com.", + nameservers: []string{ + dnsmock.NewServer(). + Query("foo.example.com. SOA", dnsmock.Error(dns.RcodeNameError)). + Query("example.com. SOA", dnsmock.SOA("")). + Build(t). + String(), + }, + }, + { + desc: "domain is a eTLD", + fqdn: "example.com.ac.", + zone: "ac.", + primaryNs: "ns1.nic.ac.", + nameservers: []string{ + dnsmock.NewServer(). + Query("example.com.ac. SOA", dnsmock.Error(dns.RcodeNameError)). + Query("com.ac. SOA", dnsmock.Error(dns.RcodeNameError)). + Query("ac. SOA", dnsmock.SOA("")). + Build(t). + String(), + }, + }, + { + desc: "domain is a cross-zone CNAME", + fqdn: "cross-zone-example.example.com.", + zone: "example.com.", + primaryNs: "ns1.example.com.", + nameservers: []string{ + dnsmock.NewServer(). + Query("cross-zone-example.example.com. SOA", dnsmock.CNAME("example.org.")). + Query("example.com. SOA", dnsmock.SOA("")). + Build(t). + String(), + }, + }, + { + desc: "NXDOMAIN", + fqdn: "test.lego.invalid.", + zone: "lego.invalid.", + nameservers: []string{ + dnsmock.NewServer(). + Query("test.lego.invalid. SOA", dnsmock.Error(dns.RcodeNameError)). + Query("lego.invalid. SOA", dnsmock.Error(dns.RcodeNameError)). + Query("invalid. SOA", dnsmock.Error(dns.RcodeNameError)). + Build(t). + String(), + }, + expectedError: `[fqdn=test.lego.invalid.] could not find the start of authority for 'test.lego.invalid.' [question='invalid. IN SOA', code=NXDOMAIN]`, + }, + { + desc: "several non existent nameservers", + fqdn: "mail.example.com.", + zone: "example.com.", + primaryNs: "ns1.example.com.", + nameservers: []string{ + ":7053", + ":8053", + dnsmock.NewServer(). + Query("mail.example.com. SOA", dnsmock.CNAME("example.com.")). + Query("example.com. SOA", dnsmock.SOA("")). + Build(t). + String(), + }, + }, + { + desc: "only non-existent nameservers", + fqdn: "mail.example.com.", + zone: "example.com.", + nameservers: []string{":7053", ":8053", ":9053"}, + // use only the start of the message because the port changes with each call: 127.0.0.1:XXXXX->127.0.0.1:7053. + expectedError: "[fqdn=mail.example.com.] could not find the start of authority for 'mail.example.com.': DNS call error: read udp ", + }, + { + desc: "no nameservers", + fqdn: "test.example.com.", + zone: "example.com.", + nameservers: []string{}, + expectedError: "[fqdn=test.example.com.] could not find the start of authority for 'test.example.com.': empty list of nameservers", + }, + } } func TestFindZoneByFqdnCustom(t *testing.T) { - for _, test := range findXByFqdnTestCases { + for _, test := range lookupSoaByFqdnTestCases(t) { t.Run(test.desc, func(t *testing.T) { ClearFqdnCache() @@ -153,7 +252,7 @@ func TestFindZoneByFqdnCustom(t *testing.T) { } func TestFindPrimaryNsByFqdnCustom(t *testing.T) { - for _, test := range findXByFqdnTestCases { + for _, test := range lookupSoaByFqdnTestCases(t) { t.Run(test.desc, func(t *testing.T) { ClearFqdnCache() @@ -169,7 +268,7 @@ func TestFindPrimaryNsByFqdnCustom(t *testing.T) { } } -func TestResolveConfServers(t *testing.T) { +func Test_getNameservers_ResolveConfServers(t *testing.T) { testCases := []struct { fixture string expected []string diff --git a/challenge/dns01/precheck.go b/challenge/dns01/precheck.go index 706e8dbec..e10efa33e 100644 --- a/challenge/dns01/precheck.go +++ b/challenge/dns01/precheck.go @@ -9,6 +9,10 @@ import ( "github.com/miekg/dns" ) +// defaultNameserverPort used by authoritative NS. +// This is for tests only. +var defaultNameserverPort = "53" + // PreCheckFunc checks DNS propagation before notifying ACME that the DNS challenge is ready. type PreCheckFunc func(fqdn, value string) (bool, error) @@ -121,7 +125,7 @@ func (p preCheck) checkDNSPropagation(fqdn, value string) (bool, error) { func checkNameserversPropagation(fqdn, value string, nameservers []string, addPort bool) (bool, error) { for _, ns := range nameservers { if addPort { - ns = net.JoinHostPort(ns, "53") + ns = net.JoinHostPort(ns, defaultNameserverPort) } r, err := dnsQuery(fqdn, dns.TypeTXT, []string{ns}, false) diff --git a/challenge/dns01/precheck_test.go b/challenge/dns01/precheck_test.go index 1f3ecbf7e..bda8c781e 100644 --- a/challenge/dns01/precheck_test.go +++ b/challenge/dns01/precheck_test.go @@ -3,40 +3,73 @@ package dns01 import ( "testing" + "github.com/go-acme/lego/v4/platform/tester/dnsmock" + "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestCheckDNSPropagation(t *testing.T) { +func Test_preCheck_checkDNSPropagation(t *testing.T) { + mockResolver(t, + dnsmock.NewServer(). + Query("ns0.lego.localhost. A", + dnsmock.Answer(fakeA("ns0.lego.localhost.", "127.0.0.1"))). + Query("ns1.lego.localhost. A", + dnsmock.Answer(fakeA("ns1.lego.localhost.", "127.0.0.1"))). + Query("example.com. TXT", + dnsmock.Answer( + fakeTXT("example.com.", "one"), + fakeTXT("example.com.", "two"), + fakeTXT("example.com.", "three"), + fakeTXT("example.com.", "four"), + fakeTXT("example.com.", "five"), + ), + ). + Build(t), + ) + + useAsNameserver(t, + dnsmock.NewServer(). + Query("acme-staging.api.example.com. SOA", dnsmock.Error(dns.RcodeNameError)). + Query("api.example.com. SOA", dnsmock.Error(dns.RcodeNameError)). + Query("example.com. SOA", dnsmock.SOA("")). + Query("example.com. NS", + dnsmock.Answer( + fakeNS("example.com.", "ns0.lego.localhost."), + fakeNS("example.com.", "ns1.lego.localhost."), + ), + ). + Build(t), + ) + testCases := []struct { - desc string - fqdn string - value string - expectError bool + desc string + fqdn string + value string + expectedError string }{ { desc: "success", - fqdn: "postman-echo.com.", - value: "postman-domain-verification=c85de626cb79d941310696e06558e2e790223802f3697dfbdcaf65510152d52c", + fqdn: "example.com.", + value: "four", }, { - desc: "no TXT record", - fqdn: "acme-staging.api.letsencrypt.org.", - value: "fe01=", - expectError: true, + desc: "no matching TXT record", + fqdn: "acme-staging.api.example.com.", + value: "fe01=", + expectedError: "did not return the expected TXT record [fqdn: acme-staging.api.example.com., value: fe01=]: one ,two ,three ,four ,five", }, } for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - t.Parallel() ClearFqdnCache() check := newPreCheck() ok, err := check.checkDNSPropagation(test.fqdn, test.value) - if test.expectError { - assert.Errorf(t, err, "PreCheckDNS must fail for %s", test.fqdn) + if test.expectedError != "" { + assert.ErrorContainsf(t, err, test.expectedError, "PreCheckDNS must fail for %s", test.fqdn) assert.False(t, ok, "PreCheckDNS must fail for %s", test.fqdn) } else { assert.NoErrorf(t, err, "PreCheckDNS failed for %s", test.fqdn) @@ -46,69 +79,67 @@ func TestCheckDNSPropagation(t *testing.T) { } } -func TestCheckAuthoritativeNss(t *testing.T) { +func Test_checkNameserversPropagation_authoritativeNss(t *testing.T) { testCases := []struct { - desc string - fqdn, value string - ns []string - expected bool + desc string + fqdn, value string + fakeDNSServer *dnsmock.Builder + expectedError string }{ { - desc: "TXT RR w/ expected value", - fqdn: "8.8.8.8.asn.routeviews.org.", - value: "151698.8.8.024", - ns: []string{"asnums.routeviews.org."}, - expected: true, + desc: "TXT RR w/ expected value", + // NS: asnums.routeviews.org. + fqdn: "8.8.8.8.asn.routeviews.org.", + value: "151698.8.8.024", + fakeDNSServer: dnsmock.NewServer(). + Query("8.8.8.8.asn.routeviews.org. TXT", + dnsmock.Answer( + fakeTXT("8.8.8.8.asn.routeviews.org.", "151698.8.8.024"), + ), + ), + }, + { + desc: "TXT RR w/ unexpected value", + // NS: asnums.routeviews.org. + fqdn: "8.8.8.8.asn.routeviews.org.", + value: "fe01=", + fakeDNSServer: dnsmock.NewServer(). + Query("8.8.8.8.asn.routeviews.org. TXT", + dnsmock.Answer( + fakeTXT("8.8.8.8.asn.routeviews.org.", "15169"), + fakeTXT("8.8.8.8.asn.routeviews.org.", "8.8.8.0"), + fakeTXT("8.8.8.8.asn.routeviews.org.", "24"), + ), + ), + expectedError: "did not return the expected TXT record [fqdn: 8.8.8.8.asn.routeviews.org., value: fe01=]: 15169 ,8.8.8.0 ,24", }, { desc: "No TXT RR", - fqdn: "ns1.google.com.", - ns: []string{"ns2.google.com."}, - }, - } - - for _, test := range testCases { - t.Run(test.desc, func(t *testing.T) { - t.Parallel() - ClearFqdnCache() - - ok, _ := checkNameserversPropagation(test.fqdn, test.value, test.ns, true) - assert.Equal(t, test.expected, ok, test.fqdn) - }) - } -} - -func TestCheckAuthoritativeNssErr(t *testing.T) { - testCases := []struct { - desc string - fqdn, value string - ns []string - error string - }{ - { - desc: "TXT RR /w unexpected value", - fqdn: "8.8.8.8.asn.routeviews.org.", - value: "fe01=", - ns: []string{"asnums.routeviews.org."}, - error: "did not return the expected TXT record", - }, - { - desc: "No TXT RR", + // NS: ns2.google.com. fqdn: "ns1.google.com.", value: "fe01=", - ns: []string{"ns2.google.com."}, - error: "did not return the expected TXT record", + fakeDNSServer: dnsmock.NewServer(). + Query("ns1.google.com.", dnsmock.Noop), + expectedError: "did not return the expected TXT record [fqdn: ns1.google.com., value: fe01=]: ", }, } for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - t.Parallel() ClearFqdnCache() - _, err := checkNameserversPropagation(test.fqdn, test.value, test.ns, true) - require.Error(t, err) - assert.Contains(t, err.Error(), test.error) + addr := test.fakeDNSServer.Build(t) + + ok, err := checkNameserversPropagation(test.fqdn, test.value, []string{addr.String()}, false) + + if test.expectedError == "" { + require.NoError(t, err) + assert.True(t, ok) + } else { + require.Error(t, err) + require.ErrorContains(t, err, test.expectedError) + assert.False(t, ok) + } }) } } diff --git a/platform/tester/dnsmock/dnsmock.go b/platform/tester/dnsmock/dnsmock.go new file mode 100644 index 000000000..6cb4f45b8 --- /dev/null +++ b/platform/tester/dnsmock/dnsmock.go @@ -0,0 +1,191 @@ +package dnsmock + +import ( + "fmt" + "math" + "net" + "strings" + "sync" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/stretchr/testify/require" +) + +const noType uint16 = math.MaxUint16 + +type Option func(*dns.Server) error + +type Builder struct { + // domain -> op -> type + routes map[string]map[int]map[uint16]dns.Handler + + stringToType map[string]uint16 +} + +func NewServer() *Builder { + stringToType := make(map[string]uint16) + for typ, str := range dns.TypeToString { + stringToType[str] = typ + } + + return &Builder{ + routes: make(map[string]map[int]map[uint16]dns.Handler), + stringToType: stringToType, + } +} + +func (b *Builder) Query(pattern string, handler dns.HandlerFunc) *Builder { + route, err := b.route(pattern, dns.OpcodeQuery, handler) + if err != nil { + panic(err.Error()) + } + + return route +} + +func (b *Builder) Update(pattern string, handler dns.HandlerFunc) *Builder { + route, err := b.route(pattern, dns.OpcodeUpdate, handler) + if err != nil { + panic(err.Error()) + } + + return route +} + +func (b *Builder) route(pattern string, op int, handler dns.HandlerFunc) (*Builder, error) { + parts := strings.Fields(pattern) + + domain := parts[0] + + _, ok := dns.IsDomainName(domain) + if !ok { + return nil, fmt.Errorf("%s: invalid domain: %s", dns.OpcodeToString[op], domain) + } + + if _, ok := b.routes[domain]; !ok { + b.routes[domain] = make(map[int]map[uint16]dns.Handler) + } + + if _, ok := b.routes[domain][op]; !ok { + b.routes[domain][op] = make(map[uint16]dns.Handler) + } + + if _, ok := b.routes[domain][op][noType]; ok { + return nil, fmt.Errorf("%s: a global route already exists for the domain: %s", dns.OpcodeToString[op], domain) + } + + switch len(parts) { + case 1: + if len(b.routes[domain][op]) > 0 { + return nil, fmt.Errorf("%s: global route and specific routes cannot be mixed for the same domain: %s", dns.OpcodeToString[op], domain) + } + + b.routes[domain][op][noType] = handler + + return b, nil + + case 2: + raw := parts[1] + + qType, ok := b.stringToType[raw] + if !ok { + return nil, fmt.Errorf("%s: unknown type: %s", dns.OpcodeToString[op], raw) + } + + if _, ok := b.routes[domain][op][qType]; ok { + return nil, fmt.Errorf("%s: duplicate route: %s", dns.OpcodeToString[op], pattern) + } + + b.routes[domain][op][qType] = handler + + return b, nil + + default: + return nil, fmt.Errorf("%s: invalid pattern: %s", dns.OpcodeToString[op], pattern) + } +} + +func (b *Builder) Build(t *testing.T, options ...Option) net.Addr { + t.Helper() + + mux := dns.NewServeMux() + + server := &dns.Server{ + Addr: "127.0.0.1:0", + Net: "udp", + ReadTimeout: time.Hour, + WriteTimeout: time.Hour, + Handler: mux, + MsgAcceptFunc: func(dh dns.Header) dns.MsgAcceptAction { + // bypass defaultMsgAcceptFunc to allow dynamic update (https://github.com/miekg/dns/pull/830) + return dns.MsgAccept + }, + } + + for _, option := range options { + require.NoError(t, option(server)) + } + + for pattern, ops := range b.routes { + mux.HandleFunc(pattern, func(w dns.ResponseWriter, req *dns.Msg) { + mTypes, ok := ops[req.Opcode] + if !ok { + _ = w.WriteMsg(new(dns.Msg).SetRcode(req, dns.RcodeNotImplemented)) + + return + } + + if h, found := mTypes[noType]; found { + h.ServeDNS(w, req) + + return + } + + // For safety but it doesn't happen. + if len(req.Question) == 0 { + _ = w.WriteMsg(new(dns.Msg).SetRcode(req, dns.RcodeRefused)) + + return + } + + // For safety but it doesn't happen. + if req.Question[0].Qclass != dns.ClassINET { + _ = w.WriteMsg(new(dns.Msg).SetRcode(req, dns.RcodeRefused)) + + return + } + + // Works only for [Query]. + h, ok := mTypes[req.Question[0].Qtype] + if !ok { + _ = w.WriteMsg(new(dns.Msg).SetRcode(req, dns.RcodeNotImplemented)) + + return + } + + h.ServeDNS(w, req) + }) + } + + t.Cleanup(func() { + _ = server.Shutdown() + }) + + waitLock := sync.Mutex{} + waitLock.Lock() + + server.NotifyStartedFunc = waitLock.Unlock + + go func() { + err := server.ListenAndServe() + if err != nil { + t.Log(err) + } + }() + + waitLock.Lock() + + return server.PacketConn.LocalAddr() +} diff --git a/platform/tester/dnsmock/dnsmock_test.go b/platform/tester/dnsmock/dnsmock_test.go new file mode 100644 index 000000000..77a67a402 --- /dev/null +++ b/platform/tester/dnsmock/dnsmock_test.go @@ -0,0 +1,240 @@ +package dnsmock + +import ( + "testing" + "time" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestServer_Query_matchType(t *testing.T) { + addr := NewServer(). + Query("example.com. SOA", Noop). + Build(t) + + client := &dns.Client{Timeout: 1 * time.Second} + + m := new(dns.Msg).SetQuestion("example.com.", dns.TypeSOA) + + r, _, err := client.Exchange(m, addr.String()) + require.NoError(t, err) + + require.Equalf(t, dns.RcodeSuccess, r.Rcode, + "expected %s, got %s", dns.RcodeToString[dns.RcodeSuccess], dns.RcodeToString[r.Rcode]) + assert.Equal(t, m.Question, r.Question) +} + +func TestServer_Query_noType(t *testing.T) { + addr := NewServer(). + Query("example.com.", Noop). + Build(t) + + client := &dns.Client{Timeout: 1 * time.Second} + + m := new(dns.Msg).SetQuestion("example.com.", dns.TypeSOA) + + r, _, err := client.Exchange(m, addr.String()) + require.NoError(t, err) + + require.Equalf(t, dns.RcodeSuccess, r.Rcode, + "expected %s, got %s", dns.RcodeToString[dns.RcodeSuccess], dns.RcodeToString[r.Rcode]) + assert.Equal(t, m.Question, r.Question) +} + +func TestServer_Query_noMatch_domain(t *testing.T) { + addr := NewServer(). + Query("example.com. SOA", Noop). + Build(t) + + client := &dns.Client{Timeout: 1 * time.Second} + + m := new(dns.Msg).SetQuestion("example.org.", dns.TypeSOA) + + r, _, err := client.Exchange(m, addr.String()) + require.NoError(t, err) + + require.Equalf(t, dns.RcodeRefused, r.Rcode, + "expected %s, got %s", dns.RcodeToString[dns.RcodeRefused], dns.RcodeToString[r.Rcode]) + assert.Equal(t, m.Question, r.Question) +} + +func TestServer_Query_noMatch_type(t *testing.T) { + addr := NewServer(). + Query("example.com. SOA", Noop). + Build(t) + + client := &dns.Client{Timeout: 1 * time.Second} + + m := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT) + + r, _, err := client.Exchange(m, addr.String()) + require.NoError(t, err) + + require.Equalf(t, dns.RcodeNotImplemented, r.Rcode, + "expected %s, got %s", dns.RcodeToString[dns.RcodeNotImplemented], dns.RcodeToString[r.Rcode]) + assert.Equal(t, m.Question, r.Question) +} + +func TestServer_Query_noMatch_opType(t *testing.T) { + addr := NewServer(). + Query("example.com.", Noop). + Build(t) + + client := &dns.Client{Timeout: 1 * time.Second} + + m := new(dns.Msg).SetUpdate("example.com.") + m.Insert([]dns.RR{ + &dns.TXT{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1}, + Txt: []string{"foo"}, + }, + }) + + r, _, err := client.Exchange(m, addr.String()) + require.NoError(t, err) + + require.Equalf(t, dns.RcodeNotImplemented, r.Rcode, + "expected %s, got %s", dns.RcodeToString[dns.RcodeNotImplemented], dns.RcodeToString[r.Rcode]) + assert.Equal(t, m.Question, r.Question) +} + +func TestServer_Query_unknownType(t *testing.T) { + assert.PanicsWithValue(t, "QUERY: unknown type: ABC", func() { + NewServer(). + Query("example.com. ABC", Noop). + Build(t) + }) +} + +func TestServer_Query_duplicate(t *testing.T) { + assert.PanicsWithValue(t, "QUERY: duplicate route: example.com. SOA", func() { + NewServer(). + Query("example.com. SOA", Noop). + Query("example.com. SOA", Noop). + Build(t) + }) +} + +func TestServer_Query_duplicateGlobal(t *testing.T) { + assert.PanicsWithValue(t, "QUERY: a global route already exists for the domain: example.com.", func() { + NewServer(). + Query("example.com.", Noop). + Query("example.com.", Noop). + Build(t) + }) +} + +func TestServer_Query_mixed(t *testing.T) { + assert.PanicsWithValue(t, "QUERY: global route and specific routes cannot be mixed for the same domain: example.com.", func() { + NewServer(). + Query("example.com. SOA", Noop). + Query("example.com.", Noop). + Build(t) + }) +} + +func TestServer_Query_invalidDomain(t *testing.T) { + assert.PanicsWithValue(t, "QUERY: invalid domain: .example.com.", func() { + NewServer(). + Query(".example.com. SOA", Noop). + Build(t) + }) +} + +func TestServer_Query_invalidPattern(t *testing.T) { + assert.PanicsWithValue(t, "QUERY: invalid pattern: example.com. SOA 13", func() { + NewServer(). + Query("example.com. SOA 13", Noop). + Build(t) + }) +} + +func TestServer_Update(t *testing.T) { + addr := NewServer(). + Update("example.com.", Noop). + Build(t) + + client := &dns.Client{Timeout: 1 * time.Second} + + m := new(dns.Msg).SetUpdate("example.com.") + m.Insert([]dns.RR{ + &dns.TXT{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1}, + Txt: []string{"foo"}, + }, + }) + + r, _, err := client.Exchange(m, addr.String()) + require.NoError(t, err) + + require.Equalf(t, dns.RcodeSuccess, r.Rcode, + "expected %s, got %s", dns.RcodeToString[dns.RcodeSuccess], dns.RcodeToString[r.Rcode]) + assert.Equal(t, m.Question, r.Question) +} + +func TestServer_Update_noMatch_domain(t *testing.T) { + addr := NewServer(). + Update("example.com.", Noop). + Build(t) + + client := &dns.Client{Timeout: 1 * time.Second} + + m := new(dns.Msg).SetUpdate("example.org.") + m.Insert([]dns.RR{ + &dns.TXT{ + Hdr: dns.RR_Header{Name: "example.org.", Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1}, + Txt: []string{"foo"}, + }, + }) + + r, _, err := client.Exchange(m, addr.String()) + require.NoError(t, err) + + require.Equalf(t, dns.RcodeRefused, r.Rcode, + "expected %s, got %s", dns.RcodeToString[dns.RcodeRefused], dns.RcodeToString[r.Rcode]) + assert.Equal(t, m.Question, r.Question) +} + +func TestServer_Update_noMatch_opType(t *testing.T) { + addr := NewServer(). + Update("example.com.", Noop). + Build(t) + + client := &dns.Client{Timeout: 1 * time.Second} + + m := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT) + + r, _, err := client.Exchange(m, addr.String()) + require.NoError(t, err) + + require.Equalf(t, dns.RcodeNotImplemented, r.Rcode, + "expected %s, got %s", dns.RcodeToString[dns.RcodeNotImplemented], dns.RcodeToString[r.Rcode]) + assert.Equal(t, m.Question, r.Question) +} + +func TestServer_Update_duplicate(t *testing.T) { + assert.PanicsWithValue(t, "UPDATE: a global route already exists for the domain: example.com.", func() { + NewServer(). + Update("example.com.", Noop). + Update("example.com.", Noop). + Build(t) + }) +} + +func TestServer_Update_invalidDomain(t *testing.T) { + assert.PanicsWithValue(t, "UPDATE: invalid domain: .example.com.", func() { + NewServer(). + Update(".example.com.", Noop). + Build(t) + }) +} + +func TestServer_Update_invalidPattern(t *testing.T) { + assert.PanicsWithValue(t, "UPDATE: invalid pattern: example.com. SOA 13", func() { + NewServer(). + Update("example.com. SOA 13", Noop). + Build(t) + }) +} diff --git a/platform/tester/dnsmock/handlers.go b/platform/tester/dnsmock/handlers.go new file mode 100644 index 000000000..e1b047318 --- /dev/null +++ b/platform/tester/dnsmock/handlers.go @@ -0,0 +1,76 @@ +package dnsmock + +import ( + "fmt" + + "github.com/miekg/dns" +) + +func DumpRequest() dns.HandlerFunc { + return func(w dns.ResponseWriter, req *dns.Msg) { + fmt.Println(req) + + Noop(w, req) + } +} + +func SOA(name string) dns.HandlerFunc { + return func(w dns.ResponseWriter, req *dns.Msg) { + if name == "" { + name = req.Question[0].Name + } + + // Handle TLD + base := name + if dns.CountLabel(req.Question[0].Name) == 1 { + base = "nic." + req.Question[0].Name + } + + answer := &dns.SOA{ + Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: 120}, + Ns: "ns1." + base, + Mbox: "admin." + base, + Serial: 2016022801, + Refresh: 28800, + Retry: 7200, + Expire: 2419200, + Minttl: 1200, + } + + Answer(answer)(w, req) + } +} + +func CNAME(target string) dns.HandlerFunc { + return func(w dns.ResponseWriter, req *dns.Msg) { + answer := &dns.CNAME{ + Hdr: dns.RR_Header{Name: req.Question[0].Name, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 1}, + Target: dns.Fqdn(target), + } + + Answer(answer)(w, req) + } +} + +func Noop(w dns.ResponseWriter, req *dns.Msg) { + _ = w.WriteMsg(new(dns.Msg).SetReply(req)) +} + +func Error(rcode int) dns.HandlerFunc { + return func(w dns.ResponseWriter, req *dns.Msg) { + _ = w.WriteMsg(new(dns.Msg).SetRcode(req, rcode)) + } +} + +func Answer(answer ...dns.RR) func(w dns.ResponseWriter, req *dns.Msg) { + return func(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg).SetReply(req) + + m.Answer = answer + + err := w.WriteMsg(m) + if err != nil { + panic(err.Error()) + } + } +} diff --git a/platform/tester/dnsmock/handlers_test.go b/platform/tester/dnsmock/handlers_test.go new file mode 100644 index 000000000..13cdc0e2d --- /dev/null +++ b/platform/tester/dnsmock/handlers_test.go @@ -0,0 +1,156 @@ +package dnsmock + +import ( + "testing" + "time" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSOA_self(t *testing.T) { + addr := NewServer(). + Query("example.com. SOA", SOA("")). + Build(t) + + client := &dns.Client{Timeout: 1 * time.Second} + + m := new(dns.Msg).SetQuestion("example.com.", dns.TypeSOA) + + r, _, err := client.Exchange(m, addr.String()) + require.NoError(t, err) + + expectedSOA := []dns.RR{&dns.SOA{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: 120, Rdlength: 56}, + Ns: "ns1.example.com.", + Mbox: "admin.example.com.", + Serial: 2016022801, + Refresh: 28800, + Retry: 7200, + Expire: 2419200, + Minttl: 1200, + }} + + require.Equal(t, dns.RcodeSuccess, r.Rcode) + assert.Equal(t, expectedSOA, r.Answer) + assert.Equal(t, m.Question, r.Question) +} + +func TestSOA_differentDomain(t *testing.T) { + addr := NewServer(). + Query("example.com. SOA", SOA("example.org.")). + Build(t) + + client := &dns.Client{Timeout: 1 * time.Second} + + m := new(dns.Msg).SetQuestion("example.com.", dns.TypeSOA) + + r, _, err := client.Exchange(m, addr.String()) + require.NoError(t, err) + + require.Equalf(t, dns.RcodeSuccess, r.Rcode, + "expected %s, got %s", dns.RcodeToString[dns.RcodeSuccess], dns.RcodeToString[r.Rcode]) + + expectedSOA := []dns.RR{&dns.SOA{ + Hdr: dns.RR_Header{Name: "example.org.", Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: 120, Rdlength: 56}, + Ns: "ns1.example.org.", + Mbox: "admin.example.org.", + Serial: 2016022801, + Refresh: 28800, + Retry: 7200, + Expire: 2419200, + Minttl: 1200, + }} + + assert.Equal(t, expectedSOA, r.Answer) + assert.Equal(t, m.Question, r.Question) +} + +func TestSOA_tld(t *testing.T) { + addr := NewServer(). + Query("com. SOA", SOA("")). + Build(t) + + client := &dns.Client{Timeout: 1 * time.Second} + + m := new(dns.Msg).SetQuestion("com.", dns.TypeSOA) + + r, _, err := client.Exchange(m, addr.String()) + require.NoError(t, err) + + require.Equalf(t, dns.RcodeSuccess, r.Rcode, + "expected %s, got %s", dns.RcodeToString[dns.RcodeSuccess], dns.RcodeToString[r.Rcode]) + + expectedSOA := []dns.RR{&dns.SOA{ + Hdr: dns.RR_Header{Name: "com.", Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: 120, Rdlength: 48}, + Ns: "ns1.nic.com.", + Mbox: "admin.nic.com.", + Serial: 2016022801, + Refresh: 28800, + Retry: 7200, + Expire: 2419200, + Minttl: 1200, + }} + + assert.Equal(t, expectedSOA, r.Answer) + assert.Equal(t, m.Question, r.Question) +} + +func TestCNAME(t *testing.T) { + addr := NewServer(). + Query("example.com. CNAME", CNAME("example.org.")). + Build(t) + + client := &dns.Client{Timeout: 1 * time.Second} + + m := new(dns.Msg).SetQuestion("example.com.", dns.TypeCNAME) + + r, _, err := client.Exchange(m, addr.String()) + require.NoError(t, err) + + require.Equalf(t, dns.RcodeSuccess, r.Rcode, + "expected %s, got %s", dns.RcodeToString[dns.RcodeSuccess], dns.RcodeToString[r.Rcode]) + + expectedCNAME := []dns.RR{&dns.CNAME{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 1, Rdlength: 13}, + Target: "example.org.", + }} + + assert.Equal(t, expectedCNAME, r.Answer) + assert.Equal(t, m.Question, r.Question) +} + +func TestNoop(t *testing.T) { + addr := NewServer(). + Query("example.com. CNAME", Noop). + Build(t) + + client := &dns.Client{Timeout: 1 * time.Second} + + m := new(dns.Msg).SetQuestion("example.com.", dns.TypeCNAME) + + r, _, err := client.Exchange(m, addr.String()) + require.NoError(t, err) + + require.Equalf(t, dns.RcodeSuccess, r.Rcode, + "expected %s, got %s", dns.RcodeToString[dns.RcodeSuccess], dns.RcodeToString[r.Rcode]) + assert.Equal(t, m.Question, r.Question) +} + +func TestError(t *testing.T) { + addr := NewServer(). + Query("example.com. CNAME", Error(dns.RcodeNameError)). + Build(t) + + client := &dns.Client{Timeout: 1 * time.Second} + + m := new(dns.Msg).SetQuestion("example.com.", dns.TypeCNAME) + + r, _, err := client.Exchange(m, addr.String()) + require.NoError(t, err) + + require.Equalf(t, dns.RcodeNameError, r.Rcode, + "expected %s, got %s", dns.RcodeToString[dns.RcodeNameError], dns.RcodeToString[r.Rcode]) + assert.Equal(t, m.Question, r.Question) +} diff --git a/providers/dns/rfc2136/rfc2136.go b/providers/dns/rfc2136/rfc2136.go index 6b5c47072..84655b450 100644 --- a/providers/dns/rfc2136/rfc2136.go +++ b/providers/dns/rfc2136/rfc2136.go @@ -131,7 +131,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { config.TSIGSecret = "" } else { // zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) - config.TSIGKey = strings.ToLower(dns.Fqdn(config.TSIGKey)) + config.TSIGKey = dns.CanonicalName(config.TSIGKey) } if config.TSIGAlgorithm == "" { @@ -193,14 +193,14 @@ func (d *DNSProvider) changeRecord(action, fqdn, value string, ttl int) error { } // Create RR - rr := new(dns.TXT) - rr.Hdr = dns.RR_Header{Name: fqdn, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: uint32(ttl)} - rr.Txt = []string{value} - rrs := []dns.RR{rr} + rrs := []dns.RR{&dns.TXT{ + Hdr: dns.RR_Header{Name: fqdn, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: uint32(ttl)}, + Txt: []string{value}, + }} // Create dynamic update packet - m := new(dns.Msg) - m.SetUpdate(zone) + m := new(dns.Msg).SetUpdate(zone) + switch action { case "INSERT": // Always remove old challenge left over from who knows what. diff --git a/providers/dns/rfc2136/rfc2136_test.go b/providers/dns/rfc2136/rfc2136_test.go index 31414a4d4..1dc7270d2 100644 --- a/providers/dns/rfc2136/rfc2136_test.go +++ b/providers/dns/rfc2136/rfc2136_test.go @@ -2,23 +2,21 @@ package rfc2136 import ( "bytes" - "fmt" "strings" - "sync" "testing" "time" "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/platform/tester" + "github.com/go-acme/lego/v4/platform/tester/dnsmock" "github.com/miekg/dns" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) const ( fakeDomain = "123456789.www.example.com" fakeKeyAuth = "123d==" - fakeValue = "Now36o-3BmlB623-0c1qCIUmgWVVmDJb88KGl24pqpo" + fakeValue = "ADw2sEd82DUgXcQ9hNBZThJs7zVJkR5v9JeSbAb9mZY" fakeFqdn = "_acme-challenge.123456789.www.example.com." fakeZone = "example.com." fakeTTL = 120 @@ -162,33 +160,16 @@ func TestNewDNSProviderConfig(t *testing.T) { } } -func TestCanaryLocalTestServer(t *testing.T) { +func TestDNSProvider_Present_success(t *testing.T) { dns01.ClearFqdnCache() - mux, addr := runLocalDNSTestServer(t, false) - mux.HandleFunc("example.com.", serverHandlerHello) - - c := new(dns.Client) - m := new(dns.Msg) - - m.SetQuestion("example.com.", dns.TypeTXT) - - r, _, err := c.Exchange(m, addr) - require.NoError(t, err, "Failed to communicate with test server") - assert.Len(t, r.Extra, 1, "Failed to communicate with test server") - - txt := r.Extra[0].(*dns.TXT).Txt[0] - assert.Equal(t, "Hello world", txt) -} - -func TestServerSuccess(t *testing.T) { - dns01.ClearFqdnCache() - - mux, addr := runLocalDNSTestServer(t, false) - mux.HandleFunc(fakeZone, serverHandlerReturnSuccess) + addr := dnsmock.NewServer(). + Query(fakeZone+" SOA", dnsmock.SOA("")). + Update(fakeZone+" SOA", dnsmock.Noop). + Build(t) config := NewDefaultConfig() - config.Nameserver = addr + config.Nameserver = addr.String() provider, err := NewDNSProviderConfig(config) require.NoError(t, err) @@ -197,14 +178,72 @@ func TestServerSuccess(t *testing.T) { require.NoError(t, err) } -func TestServerError(t *testing.T) { +func TestDNSProvider_Present_success_updatePacket(t *testing.T) { dns01.ClearFqdnCache() - mux, addr := runLocalDNSTestServer(t, false) - mux.HandleFunc(fakeZone, serverHandlerReturnErr) + reqChan := make(chan *dns.Msg, 1) + + addr := dnsmock.NewServer(). + Query("_acme-challenge.123456789.www.example.com. SOA", dnsmock.SOA(fakeZone)). + Update(fakeZone+" SOA", func(w dns.ResponseWriter, req *dns.Msg) { + dnsmock.Noop(w, req) + + // Only talk back when it is not the SOA RR. + reqChan <- req + }). + Build(t) config := NewDefaultConfig() - config.Nameserver = addr + config.Nameserver = addr.String() + + provider, err := NewDNSProviderConfig(config) + require.NoError(t, err) + + err = provider.Present(fakeDomain, "", fakeKeyAuth) + require.NoError(t, err) + + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for request") + + case rcvMsg := <-reqChan: + txtRR := &dns.TXT{ + Hdr: dns.RR_Header{Name: fakeFqdn, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: fakeTTL}, + Txt: []string{fakeValue}, + } + + m := new(dns.Msg).SetUpdate(fakeZone) + + m.RemoveRRset([]dns.RR{txtRR}) + m.Insert([]dns.RR{txtRR}) + + expected, err := m.Pack() + require.NoError(t, err, "error packing") + + rcvMsg.Id = m.Id + + actual, err := rcvMsg.Pack() + require.NoError(t, err, "error packing") + + if !bytes.Equal(actual, expected) { + tmp := new(dns.Msg) + require.NoError(t, tmp.Unpack(actual)) + + t.Errorf("Expected msg:\n%s", m) + t.Errorf("Actual msg:\n%s", tmp) + } + } +} + +func TestDNSProvider_Present_error(t *testing.T) { + dns01.ClearFqdnCache() + + addr := dnsmock.NewServer(). + Query(fakeZone+" SOA", dnsmock.Error(dns.RcodeNotZone)). + Build(t) + + config := NewDefaultConfig() + config.Nameserver = addr.String() provider, err := NewDNSProviderConfig(config) require.NoError(t, err) @@ -216,14 +255,20 @@ func TestServerError(t *testing.T) { } } -func TestTsigClient(t *testing.T) { +func TestDNSProvider_Present_tsig_success(t *testing.T) { dns01.ClearFqdnCache() - mux, addr := runLocalDNSTestServer(t, true) - mux.HandleFunc(fakeZone, serverHandlerReturnSuccess) + addr := dnsmock.NewServer(). + Query(fakeZone+" SOA", dnsmock.SOA("")). + Update(fakeZone+" SOA", handleTSIG). + Build(t, func(server *dns.Server) error { + server.TsigSecret = map[string]string{fakeTsigKey: fakeTsigSecret} + + return nil + }) config := NewDefaultConfig() - config.Nameserver = addr + config.Nameserver = addr.String() config.TSIGKey = fakeTsigKey config.TSIGSecret = fakeTsigSecret @@ -234,167 +279,50 @@ func TestTsigClient(t *testing.T) { require.NoError(t, err) } -func TestValidUpdatePacket(t *testing.T) { - reqChan := make(chan *dns.Msg, 10) - +func TestDNSProvider_Present_tsig_error(t *testing.T) { dns01.ClearFqdnCache() - mux, addr := runLocalDNSTestServer(t, false) - mux.HandleFunc(fakeZone, serverHandlerPassBackRequest(reqChan)) + addr := dnsmock.NewServer(). + Query(fakeZone+" SOA", dnsmock.SOA("")). + Update(fakeZone+" SOA", handleTSIG). + Build(t, func(server *dns.Server) error { + server.TsigSecret = map[string]string{"example.org": fakeTsigSecret} - txtRR, _ := dns.NewRR(fmt.Sprintf("%s %d IN TXT %s", fakeFqdn, fakeTTL, fakeValue)) - - m := new(dns.Msg) - m.SetUpdate(fakeZone) - m.RemoveRRset([]dns.RR{txtRR}) - m.Insert([]dns.RR{txtRR}) - - expectStr := m.String() - - expect, err := m.Pack() - require.NoError(t, err, "error packing") + return nil + }) config := NewDefaultConfig() - config.Nameserver = addr + config.Nameserver = addr.String() + config.TSIGKey = fakeTsigKey + config.TSIGSecret = fakeTsigSecret provider, err := NewDNSProviderConfig(config) require.NoError(t, err) - err = provider.Present(fakeDomain, "", "1234d==") - require.NoError(t, err) - - rcvMsg := <-reqChan - rcvMsg.Id = m.Id - - actual, err := rcvMsg.Pack() - require.NoError(t, err, "error packing") - - if !bytes.Equal(actual, expect) { - tmp := new(dns.Msg) - if err := tmp.Unpack(actual); err != nil { - t.Fatalf("Error unpacking actual msg: %v", err) - } - t.Errorf("Expected msg:\n%s", expectStr) - t.Errorf("Actual msg:\n%v", tmp) - } + err = provider.Present(fakeDomain, "", fakeKeyAuth) + require.Error(t, err) + require.EqualError(t, err, "rfc2136: failed to insert: DNS update failed: server replied: NOTZONE") } -func runLocalDNSTestServer(t *testing.T, tsig bool) (*dns.ServeMux, string) { - t.Helper() - - mux := dns.NewServeMux() - - server := &dns.Server{ - Addr: "127.0.0.1:0", - Net: "udp", - ReadTimeout: time.Hour, - WriteTimeout: time.Hour, - MsgAcceptFunc: func(dh dns.Header) dns.MsgAcceptAction { - // bypass defaultMsgAcceptFunc to allow dynamic update (https://github.com/miekg/dns/pull/830) - return dns.MsgAccept - }, - Handler: mux, - } - - t.Cleanup(func() { - _ = server.Shutdown() - }) - - if tsig { - server.TsigSecret = map[string]string{fakeTsigKey: fakeTsigSecret} - } - - waitLock := sync.Mutex{} - waitLock.Lock() - - server.NotifyStartedFunc = waitLock.Unlock - - go func() { - err := server.ListenAndServe() - if err != nil { - t.Log(err) - } - }() - - waitLock.Lock() - - return mux, server.PacketConn.LocalAddr().String() -} - -func serverHandlerHello(w dns.ResponseWriter, req *dns.Msg) { +func handleTSIG(w dns.ResponseWriter, req *dns.Msg) { m := new(dns.Msg) - m.SetReply(req) - m.Extra = make([]dns.RR, 1) - m.Extra[0] = &dns.TXT{ - Hdr: dns.RR_Header{Name: m.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0}, - Txt: []string{"Hello world"}, + tsig := req.IsTsig() + if tsig == nil { + _ = w.WriteMsg(m.SetRcode(req, dns.RcodeRefused)) + return } - _ = w.WriteMsg(m) -} + err := w.TsigStatus() + if err != nil { + _ = w.WriteMsg(m.SetRcode(req, dns.RcodeNotZone)) -func serverHandlerReturnSuccess(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetReply(req) - - if req.Opcode == dns.OpcodeQuery && req.Question[0].Qtype == dns.TypeSOA && req.Question[0].Qclass == dns.ClassINET { - // Return SOA to appease findZoneByFqdn() - m.Answer = []dns.RR{fakeSOAAnswer()} + return } - if t := req.IsTsig(); t != nil { - if w.TsigStatus() == nil { - // Validated - m.SetTsig(fakeZone, dns.HmacSHA1, 300, time.Now().Unix()) - } - } - - _ = w.WriteMsg(m) -} - -func serverHandlerReturnErr(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetRcode(req, dns.RcodeNotZone) - - _ = w.WriteMsg(m) -} - -func serverHandlerPassBackRequest(reqChan chan *dns.Msg) func(w dns.ResponseWriter, req *dns.Msg) { - return func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetReply(req) - - if req.Opcode == dns.OpcodeQuery && req.Question[0].Qtype == dns.TypeSOA && req.Question[0].Qclass == dns.ClassINET { - // Return SOA to appease findZoneByFqdn() - m.Answer = []dns.RR{fakeSOAAnswer()} - } - - if t := req.IsTsig(); t != nil { - if w.TsigStatus() == nil { - // Validated - m.SetTsig(fakeZone, dns.HmacSHA1, 300, time.Now().Unix()) - } - } - - _ = w.WriteMsg(m) - - if req.Opcode != dns.OpcodeQuery || req.Question[0].Qtype != dns.TypeSOA || req.Question[0].Qclass != dns.ClassINET { - // Only talk back when it is not the SOA RR. - reqChan <- req - } - } -} - -func fakeSOAAnswer() *dns.SOA { - return &dns.SOA{ - Hdr: dns.RR_Header{Name: fakeZone, Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: fakeTTL}, - Ns: "ns1." + fakeZone, - Mbox: "admin." + fakeZone, - Serial: 2016022801, - Refresh: 28800, - Retry: 7200, - Expire: 2419200, - Minttl: 1200, - } + // Validated + _ = w.WriteMsg(m. + SetReply(req). + SetTsig(tsig.Hdr.Name, tsig.Algorithm, tsig.Fudge, time.Now().Unix()), + ) }