From a4499b729c238df940bbd7106c0fc57839531d56 Mon Sep 17 00:00:00 2001 From: Fernandez Ludovic Date: Wed, 14 Jan 2026 05:17:58 +0100 Subject: [PATCH] feat: new package dnsnew --- challenge/dnsnew/client.go | 155 ++++++++ challenge/dnsnew/client_cache.go | 34 ++ challenge/dnsnew/client_cname.go | 57 +++ challenge/dnsnew/client_cname_test.go | 35 ++ challenge/dnsnew/client_error.go | 64 ++++ challenge/dnsnew/client_error_test.go | 75 ++++ challenge/dnsnew/client_nameservers.go | 107 ++++++ challenge/dnsnew/client_nameservers_test.go | 216 +++++++++++ challenge/dnsnew/client_timeout_unix.go | 8 + challenge/dnsnew/client_timeout_windows.go | 8 + challenge/dnsnew/client_zone.go | 89 +++++ challenge/dnsnew/client_zone_test.go | 158 ++++++++ challenge/dnsnew/dns_challenge.go | 200 ++++++++++ challenge/dnsnew/dns_challenge_options.go | 46 +++ challenge/dnsnew/dns_challenge_precheck.go | 86 +++++ .../dnsnew/dns_challenge_precheck_test.go | 78 ++++ challenge/dnsnew/dns_challenge_test.go | 348 ++++++++++++++++++ challenge/dnsnew/domain.go | 24 ++ challenge/dnsnew/domain_test.go | 102 +++++ challenge/dnsnew/fixtures/resolv.conf.1 | 5 + challenge/dnsnew/fqdn.go | 47 +++ challenge/dnsnew/fqdn_test.go | 137 +++++++ challenge/dnsnew/mock_test.go | 78 ++++ 23 files changed, 2157 insertions(+) create mode 100644 challenge/dnsnew/client.go create mode 100644 challenge/dnsnew/client_cache.go create mode 100644 challenge/dnsnew/client_cname.go create mode 100644 challenge/dnsnew/client_cname_test.go create mode 100644 challenge/dnsnew/client_error.go create mode 100644 challenge/dnsnew/client_error_test.go create mode 100644 challenge/dnsnew/client_nameservers.go create mode 100644 challenge/dnsnew/client_nameservers_test.go create mode 100644 challenge/dnsnew/client_timeout_unix.go create mode 100644 challenge/dnsnew/client_timeout_windows.go create mode 100644 challenge/dnsnew/client_zone.go create mode 100644 challenge/dnsnew/client_zone_test.go create mode 100644 challenge/dnsnew/dns_challenge.go create mode 100644 challenge/dnsnew/dns_challenge_options.go create mode 100644 challenge/dnsnew/dns_challenge_precheck.go create mode 100644 challenge/dnsnew/dns_challenge_precheck_test.go create mode 100644 challenge/dnsnew/dns_challenge_test.go create mode 100644 challenge/dnsnew/domain.go create mode 100644 challenge/dnsnew/domain_test.go create mode 100644 challenge/dnsnew/fixtures/resolv.conf.1 create mode 100644 challenge/dnsnew/fqdn.go create mode 100644 challenge/dnsnew/fqdn_test.go create mode 100644 challenge/dnsnew/mock_test.go diff --git a/challenge/dnsnew/client.go b/challenge/dnsnew/client.go new file mode 100644 index 000000000..8f153c582 --- /dev/null +++ b/challenge/dnsnew/client.go @@ -0,0 +1,155 @@ +package dnsnew + +import ( + "context" + "errors" + "os" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/go-acme/lego/v5/challenge" + "github.com/miekg/dns" +) + +const defaultResolvConf = "/etc/resolv.conf" + +var defaultClient atomic.Pointer[Client] + +func init() { + defaultClient.Store(NewClient(nil)) +} + +func DefaultClient() *Client { return defaultClient.Load() } + +func SetDefaultClient(c *Client) { + defaultClient.Store(c) +} + +type Options struct { + RecursiveNameservers []string + Timeout time.Duration + TCPOnly bool + NetworkStack challenge.NetworkStack +} + +type Client struct { + recursiveNameservers []string + + // authoritativeNSPort used by authoritative NS. + // For testing purposes only. + authoritativeNSPort string + + tcpClient *dns.Client + udpClient *dns.Client + tcpOnly bool + + fqdnSoaCache map[string]*soaCacheEntry + muFqdnSoaCache sync.Mutex +} + +func NewClient(opts *Options) *Client { + if opts == nil { + tcpOnly, _ := strconv.ParseBool(os.Getenv("LEGO_EXPERIMENTAL_DNS_TCP_ONLY")) + opts = &Options{TCPOnly: tcpOnly} + } + + if len(opts.RecursiveNameservers) == 0 { + defaultNameservers := []string{ + "google-public-dns-a.google.com:53", + "google-public-dns-b.google.com:53", + } + + opts.RecursiveNameservers = getNameservers(defaultResolvConf, defaultNameservers) + } + + if opts.Timeout == 0 { + opts.Timeout = dnsTimeout + } + + return &Client{ + recursiveNameservers: opts.RecursiveNameservers, + authoritativeNSPort: "53", + tcpClient: &dns.Client{ + Net: opts.NetworkStack.Network("tcp"), + Timeout: opts.Timeout, + }, + udpClient: &dns.Client{ + Net: opts.NetworkStack.Network("udp"), + Timeout: opts.Timeout, + }, + tcpOnly: opts.TCPOnly, + fqdnSoaCache: map[string]*soaCacheEntry{}, + muFqdnSoaCache: sync.Mutex{}, + } +} + +func (c *Client) sendQuery(ctx context.Context, fqdn string, rtype uint16, recursive bool) (*dns.Msg, error) { + return c.sendQueryCustom(ctx, fqdn, rtype, c.recursiveNameservers, recursive) +} + +func (c *Client) sendQueryCustom(ctx context.Context, fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) { + m := createDNSMsg(fqdn, rtype, recursive) + + if len(nameservers) == 0 { + return nil, &DNSError{Message: "empty list of nameservers"} + } + + var ( + r *dns.Msg + err error + errAll error + ) + + for _, ns := range nameservers { + r, err = c.exchange(ctx, m, ns) + if err == nil && len(r.Answer) > 0 { + break + } + + errAll = errors.Join(errAll, err) + } + + if err != nil { + return r, errAll + } + + return r, nil +} + +func (c *Client) exchange(ctx context.Context, m *dns.Msg, ns string) (*dns.Msg, error) { + if c.tcpOnly { + r, _, err := c.tcpClient.ExchangeContext(ctx, m, ns) + if err != nil { + return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err} + } + + return r, nil + } + + r, _, err := c.udpClient.ExchangeContext(ctx, m, ns) + + if r != nil && r.Truncated { + // If the TCP request succeeds, the "err" will reset to nil + r, _, err = c.tcpClient.ExchangeContext(ctx, m, ns) + } + + if err != nil { + return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err} + } + + return r, nil +} + +func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg { + m := new(dns.Msg) + m.SetQuestion(fqdn, rtype) + m.SetEdns0(4096, false) + + if !recursive { + m.RecursionDesired = false + } + + return m +} diff --git a/challenge/dnsnew/client_cache.go b/challenge/dnsnew/client_cache.go new file mode 100644 index 000000000..4b960cd6c --- /dev/null +++ b/challenge/dnsnew/client_cache.go @@ -0,0 +1,34 @@ +package dnsnew + +import ( + "time" + + "github.com/miekg/dns" +) + +// soaCacheEntry holds a cached SOA record (only selected fields). +type soaCacheEntry struct { + zone string // zone apex (a domain name) + primaryNs string // primary nameserver for the zone apex + expires time.Time // time when this cache entry should be evicted +} + +func newSoaCacheEntry(soa *dns.SOA) *soaCacheEntry { + return &soaCacheEntry{ + zone: soa.Hdr.Name, + primaryNs: soa.Ns, + expires: time.Now().Add(time.Duration(soa.Refresh) * time.Second), + } +} + +// isExpired checks whether a cache entry should be considered expired. +func (cache *soaCacheEntry) isExpired() bool { + return time.Now().After(cache.expires) +} + +// ClearFqdnCache clears the cache of fqdn to zone mappings. Primarily used in testing. +func (c *Client) ClearFqdnCache() { + c.muFqdnSoaCache.Lock() + c.fqdnSoaCache = map[string]*soaCacheEntry{} + c.muFqdnSoaCache.Unlock() +} diff --git a/challenge/dnsnew/client_cname.go b/challenge/dnsnew/client_cname.go new file mode 100644 index 000000000..4a03b84dc --- /dev/null +++ b/challenge/dnsnew/client_cname.go @@ -0,0 +1,57 @@ +package dnsnew + +import ( + "context" + "slices" + "strings" + + "github.com/go-acme/lego/v5/log" + "github.com/miekg/dns" +) + +func (c *Client) lookupCNAME(ctx context.Context, fqdn string) string { + // recursion counter so it doesn't spin out of control + for range 50 { + // Keep following CNAMEs + r, err := c.sendQuery(ctx, fqdn, dns.TypeCNAME, true) + + if err != nil || r.Rcode != dns.RcodeSuccess { + // TODO(ldez): logs the error in v5 + // No more CNAME records to follow, exit + break + } + + // Check if the domain has CNAME then use that + cname := updateDomainWithCName(r, fqdn) + if cname == fqdn { + break + } + + log.Info("Found CNAME entry.", "fqdn", fqdn, "cname", cname) + + fqdn = cname + } + + return fqdn +} + +// Update FQDN with CNAME if any. +func updateDomainWithCName(r *dns.Msg, fqdn string) string { + for _, rr := range r.Answer { + if cn, ok := rr.(*dns.CNAME); ok { + if strings.EqualFold(cn.Hdr.Name, fqdn) { + return cn.Target + } + } + } + + return fqdn +} + +// dnsMsgContainsCNAME checks for a CNAME answer in msg. +func dnsMsgContainsCNAME(msg *dns.Msg) bool { + return slices.ContainsFunc(msg.Answer, func(rr dns.RR) bool { + _, ok := rr.(*dns.CNAME) + return ok + }) +} diff --git a/challenge/dnsnew/client_cname_test.go b/challenge/dnsnew/client_cname_test.go new file mode 100644 index 000000000..838adf526 --- /dev/null +++ b/challenge/dnsnew/client_cname_test.go @@ -0,0 +1,35 @@ +package dnsnew + +import ( + "strings" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" +) + +func Test_updateDomainWithCName_caseInsensitive(t *testing.T) { + qname := "_acme-challenge.uppercase-test.example.com." + cnameTarget := "_acme-challenge.uppercase-test.cname-target.example.com." + + msg := &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Authoritative: true, + }, + Answer: []dns.RR{ + &dns.CNAME{ + Hdr: dns.RR_Header{ + Name: strings.ToUpper(qname), // CNAME names are case-insensitive + Rrtype: dns.TypeCNAME, + Class: dns.ClassINET, + Ttl: 3600, + }, + Target: cnameTarget, + }, + }, + } + + fqdn := updateDomainWithCName(msg, qname) + + assert.Equal(t, cnameTarget, fqdn) +} diff --git a/challenge/dnsnew/client_error.go b/challenge/dnsnew/client_error.go new file mode 100644 index 000000000..3ab8d1e62 --- /dev/null +++ b/challenge/dnsnew/client_error.go @@ -0,0 +1,64 @@ +package dnsnew + +import ( + "fmt" + "strings" + + "github.com/miekg/dns" +) + +// DNSError error related to DNS calls. +type DNSError struct { + Message string + NS string + MsgIn *dns.Msg + MsgOut *dns.Msg + Err error +} + +func (d *DNSError) Error() string { + var details []string + if d.NS != "" { + details = append(details, "ns="+d.NS) + } + + if d.MsgIn != nil && len(d.MsgIn.Question) > 0 { + details = append(details, fmt.Sprintf("question='%s'", formatQuestions(d.MsgIn.Question))) + } + + if d.MsgOut != nil { + if d.MsgIn == nil || len(d.MsgIn.Question) == 0 { + details = append(details, fmt.Sprintf("question='%s'", formatQuestions(d.MsgOut.Question))) + } + + details = append(details, "code="+dns.RcodeToString[d.MsgOut.Rcode]) + } + + msg := "DNS error" + if d.Message != "" { + msg = d.Message + } + + if d.Err != nil { + msg += ": " + d.Err.Error() + } + + if len(details) > 0 { + msg += " [" + strings.Join(details, ", ") + "]" + } + + return msg +} + +func (d *DNSError) Unwrap() error { + return d.Err +} + +func formatQuestions(questions []dns.Question) string { + var parts []string + for _, question := range questions { + parts = append(parts, strings.ReplaceAll(strings.TrimPrefix(question.String(), ";"), "\t", " ")) + } + + return strings.Join(parts, ";") +} diff --git a/challenge/dnsnew/client_error_test.go b/challenge/dnsnew/client_error_test.go new file mode 100644 index 000000000..f30f59ab8 --- /dev/null +++ b/challenge/dnsnew/client_error_test.go @@ -0,0 +1,75 @@ +package dnsnew + +import ( + "errors" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" +) + +func TestDNSError_Error(t *testing.T) { + msgIn := createDNSMsg("example.com.", dns.TypeTXT, true) + + msgOut := createDNSMsg("example.org.", dns.TypeSOA, true) + msgOut.Rcode = dns.RcodeNameError + + testCases := []struct { + desc string + err *DNSError + expected string + }{ + { + desc: "empty error", + err: &DNSError{}, + expected: "DNS error", + }, + { + desc: "all fields", + err: &DNSError{ + Message: "Oops", + NS: "example.com.", + MsgIn: msgIn, + MsgOut: msgOut, + Err: errors.New("I did it again"), + }, + expected: "Oops: I did it again [ns=example.com., question='example.com. IN TXT', code=NXDOMAIN]", + }, + { + desc: "only NS", + err: &DNSError{ + NS: "example.com.", + }, + expected: "DNS error [ns=example.com.]", + }, + { + desc: "only MsgIn", + err: &DNSError{ + MsgIn: msgIn, + }, + expected: "DNS error [question='example.com. IN TXT']", + }, + { + desc: "only MsgOut", + err: &DNSError{ + MsgOut: msgOut, + }, + expected: "DNS error [question='example.org. IN SOA', code=NXDOMAIN]", + }, + { + desc: "only Err", + err: &DNSError{ + Err: errors.New("I did it again"), + }, + expected: "DNS error: I did it again", + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + assert.EqualError(t, test.err, test.expected) + }) + } +} diff --git a/challenge/dnsnew/client_nameservers.go b/challenge/dnsnew/client_nameservers.go new file mode 100644 index 000000000..8867b9143 --- /dev/null +++ b/challenge/dnsnew/client_nameservers.go @@ -0,0 +1,107 @@ +package dnsnew + +import ( + "context" + "fmt" + "net" + "strings" + + "github.com/miekg/dns" +) + +// checkNameserversPropagation queries each of the recursive nameservers for the expected TXT record. +func (c *Client) checkNameserversPropagation(ctx context.Context, fqdn, value string, addPort bool) (bool, error) { + return c.checkNameserversPropagationCustom(ctx, fqdn, value, c.recursiveNameservers, addPort) +} + +// checkNameserversPropagationCustom queries each of the given nameservers for the expected TXT record. +func (c *Client) checkNameserversPropagationCustom(ctx context.Context, fqdn, value string, nameservers []string, addPort bool) (bool, error) { + for _, ns := range nameservers { + if addPort { + ns = net.JoinHostPort(ns, c.authoritativeNSPort) + } + + r, err := c.sendQueryCustom(ctx, fqdn, dns.TypeTXT, []string{ns}, false) + if err != nil { + return false, err + } + + if r.Rcode != dns.RcodeSuccess { + return false, fmt.Errorf("NS %s returned %s for %s", ns, dns.RcodeToString[r.Rcode], fqdn) + } + + var records []string + + var found bool + + for _, rr := range r.Answer { + if txt, ok := rr.(*dns.TXT); ok { + record := strings.Join(txt.Txt, "") + + records = append(records, record) + if record == value { + found = true + break + } + } + } + + if !found { + return false, fmt.Errorf("NS %s did not return the expected TXT record [fqdn: %s, value: %s]: %s", ns, fqdn, value, strings.Join(records, " ,")) + } + } + + return true, nil +} + +// lookupAuthoritativeNameservers returns the authoritative nameservers for the given fqdn. +func (c *Client) lookupAuthoritativeNameservers(ctx context.Context, fqdn string) ([]string, error) { + var authoritativeNss []string + + zone, err := c.FindZoneByFqdn(ctx, fqdn) + if err != nil { + return nil, fmt.Errorf("could not find zone: %w", err) + } + + r, err := c.sendQuery(ctx, zone, dns.TypeNS, true) + if err != nil { + return nil, fmt.Errorf("NS call failed: %w", err) + } + + for _, rr := range r.Answer { + if ns, ok := rr.(*dns.NS); ok { + authoritativeNss = append(authoritativeNss, strings.ToLower(ns.Ns)) + } + } + + if len(authoritativeNss) > 0 { + return authoritativeNss, nil + } + + return nil, fmt.Errorf("[zone=%s] could not determine authoritative nameservers", zone) +} + +// getNameservers attempts to get systems nameservers before falling back to the defaults. +func getNameservers(path string, defaults []string) []string { + config, err := dns.ClientConfigFromFile(path) + if err != nil || len(config.Servers) == 0 { + return defaults + } + + return parseNameservers(config.Servers) +} + +func parseNameservers(servers []string) []string { + var resolvers []string + + for _, resolver := range servers { + // ensure all servers have a port number + if _, _, err := net.SplitHostPort(resolver); err != nil { + resolvers = append(resolvers, net.JoinHostPort(resolver, "53")) + } else { + resolvers = append(resolvers, resolver) + } + } + + return resolvers +} diff --git a/challenge/dnsnew/client_nameservers_test.go b/challenge/dnsnew/client_nameservers_test.go new file mode 100644 index 000000000..5ca7e927c --- /dev/null +++ b/challenge/dnsnew/client_nameservers_test.go @@ -0,0 +1,216 @@ +package dnsnew + +import ( + "sort" + "testing" + + "github.com/go-acme/lego/v5/platform/tester/dnsmock" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestClient_checkNameserversPropagationCustom_authoritativeNss(t *testing.T) { + testCases := []struct { + desc string + fqdn, value string + fakeDNSServer *dnsmock.Builder + expectedError string + }{ + { + desc: "TXT RR w/ expected value", + // NS: asnums.routeviews.org. + fqdn: "8.8.8.8.asn.routeviews.org.", + value: "151698.8.8.024", + fakeDNSServer: dnsmock.NewServer(). + Query("8.8.8.8.asn.routeviews.org. TXT", + dnsmock.Answer( + fakeTXT("8.8.8.8.asn.routeviews.org.", "151698.8.8.024"), + ), + ), + }, + { + desc: "TXT RR w/ unexpected value", + // NS: asnums.routeviews.org. + fqdn: "8.8.8.8.asn.routeviews.org.", + value: "fe01=", + fakeDNSServer: dnsmock.NewServer(). + Query("8.8.8.8.asn.routeviews.org. TXT", + dnsmock.Answer( + fakeTXT("8.8.8.8.asn.routeviews.org.", "15169"), + fakeTXT("8.8.8.8.asn.routeviews.org.", "8.8.8.0"), + fakeTXT("8.8.8.8.asn.routeviews.org.", "24"), + ), + ), + expectedError: "did not return the expected TXT record [fqdn: 8.8.8.8.asn.routeviews.org., value: fe01=]: 15169 ,8.8.8.0 ,24", + }, + { + desc: "No TXT RR", + // NS: ns2.google.com. + fqdn: "ns1.google.com.", + value: "fe01=", + fakeDNSServer: dnsmock.NewServer(). + Query("ns1.google.com.", dnsmock.Noop), + expectedError: "did not return the expected TXT record [fqdn: ns1.google.com., value: fe01=]: ", + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + client := NewClient(nil) + + addr := test.fakeDNSServer.Build(t) + + ok, err := client.checkNameserversPropagationCustom(t.Context(), test.fqdn, test.value, []string{addr.String()}, false) + + if test.expectedError == "" { + require.NoError(t, err) + assert.True(t, ok) + } else { + require.Error(t, err) + require.ErrorContains(t, err, test.expectedError) + assert.False(t, ok) + } + }) + } +} + +func TestClient_lookupAuthoritativeNameservers_OK(t *testing.T) { + testCases := []struct { + desc string + fakeDNSServer *dnsmock.Builder + fqdn string + expected []string + }{ + { + fqdn: "en.wikipedia.org.localhost.", + fakeDNSServer: dnsmock.NewServer(). + Query("en.wikipedia.org.localhost SOA", dnsmock.CNAME("dyna.wikimedia.org.localhost")). + Query("wikipedia.org.localhost SOA", dnsmock.SOA("")). + Query("wikipedia.org.localhost NS", + dnsmock.Answer( + fakeNS("wikipedia.org.localhost.", "ns0.wikimedia.org.localhost."), + fakeNS("wikipedia.org.localhost.", "ns1.wikimedia.org.localhost."), + fakeNS("wikipedia.org.localhost.", "ns2.wikimedia.org.localhost."), + ), + ), + expected: []string{"ns0.wikimedia.org.localhost.", "ns1.wikimedia.org.localhost.", "ns2.wikimedia.org.localhost."}, + }, + { + fqdn: "www.google.com.localhost.", + fakeDNSServer: dnsmock.NewServer(). + Query("www.google.com.localhost. SOA", dnsmock.Noop). + Query("google.com.localhost. SOA", dnsmock.SOA("")). + Query("google.com.localhost. NS", + dnsmock.Answer( + fakeNS("google.com.localhost.", "ns1.google.com.localhost."), + fakeNS("google.com.localhost.", "ns2.google.com.localhost."), + fakeNS("google.com.localhost.", "ns3.google.com.localhost."), + fakeNS("google.com.localhost.", "ns4.google.com.localhost."), + ), + ), + expected: []string{"ns1.google.com.localhost.", "ns2.google.com.localhost.", "ns3.google.com.localhost.", "ns4.google.com.localhost."}, + }, + { + fqdn: "mail.proton.me.localhost.", + fakeDNSServer: dnsmock.NewServer(). + Query("mail.proton.me.localhost. SOA", dnsmock.Noop). + Query("proton.me.localhost. SOA", dnsmock.SOA("")). + Query("proton.me.localhost. NS", + dnsmock.Answer( + fakeNS("proton.me.localhost.", "ns1.proton.me.localhost."), + fakeNS("proton.me.localhost.", "ns2.proton.me.localhost."), + fakeNS("proton.me.localhost.", "ns3.proton.me.localhost."), + ), + ), + expected: []string{"ns1.proton.me.localhost.", "ns2.proton.me.localhost.", "ns3.proton.me.localhost."}, + }, + } + + for _, test := range testCases { + t.Run(test.fqdn, func(t *testing.T) { + client := NewClient(&Options{RecursiveNameservers: []string{test.fakeDNSServer.Build(t).String()}}) + + nss, err := client.lookupAuthoritativeNameservers(t.Context(), test.fqdn) + require.NoError(t, err) + + sort.Strings(nss) + sort.Strings(test.expected) + + assert.Equal(t, test.expected, nss) + }) + } +} + +func TestClient_lookupAuthoritativeNameservers_error(t *testing.T) { + testCases := []struct { + desc string + fqdn string + fakeDNSServer *dnsmock.Builder + error string + }{ + { + desc: "NXDOMAIN", + fqdn: "example.invalid.", + fakeDNSServer: dnsmock.NewServer(). + Query(". SOA", dnsmock.Error(dns.RcodeNameError)), + error: "could not find zone: [fqdn=example.invalid.] could not find the start of authority for 'example.invalid.' [question='invalid. IN SOA', code=NXDOMAIN]", + }, + { + desc: "NS error", + fqdn: "example.com.", + fakeDNSServer: dnsmock.NewServer(). + Query("example.com. SOA", dnsmock.SOA("")). + Query("example.com. NS", dnsmock.Error(dns.RcodeServerFailure)), + error: "[zone=example.com.] could not determine authoritative nameservers", + }, + { + desc: "empty NS", + fqdn: "example.com.", + fakeDNSServer: dnsmock.NewServer(). + Query("example.com. SOA", dnsmock.SOA("")). + Query("example.me NS", dnsmock.Noop), + error: "[zone=example.com.] could not determine authoritative nameservers", + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + client := NewClient(&Options{RecursiveNameservers: []string{test.fakeDNSServer.Build(t).String()}}) + + _, err := client.lookupAuthoritativeNameservers(t.Context(), test.fqdn) + require.Error(t, err) + assert.EqualError(t, err, test.error) + }) + } +} + +func Test_getNameservers_ResolveConfServers(t *testing.T) { + testCases := []struct { + fixture string + expected []string + defaults []string + }{ + { + fixture: "fixtures/resolv.conf.1", + defaults: []string{"127.0.0.1:53"}, + expected: []string{"10.200.3.249:53", "10.200.3.250:5353", "[2001:4860:4860::8844]:53", "[10.0.0.1]:5353"}, + }, + { + fixture: "fixtures/resolv.conf.nonexistant", + defaults: []string{"127.0.0.1:53"}, + expected: []string{"127.0.0.1:53"}, + }, + } + + for _, test := range testCases { + t.Run(test.fixture, func(t *testing.T) { + result := getNameservers(test.fixture, test.defaults) + + sort.Strings(result) + sort.Strings(test.expected) + + assert.Equal(t, test.expected, result) + }) + } +} diff --git a/challenge/dnsnew/client_timeout_unix.go b/challenge/dnsnew/client_timeout_unix.go new file mode 100644 index 000000000..49dded01c --- /dev/null +++ b/challenge/dnsnew/client_timeout_unix.go @@ -0,0 +1,8 @@ +//go:build !windows + +package dnsnew + +import "time" + +// dnsTimeout is used to override the default DNS timeout of 10 seconds. +const dnsTimeout = 10 * time.Second diff --git a/challenge/dnsnew/client_timeout_windows.go b/challenge/dnsnew/client_timeout_windows.go new file mode 100644 index 000000000..5dc94d056 --- /dev/null +++ b/challenge/dnsnew/client_timeout_windows.go @@ -0,0 +1,8 @@ +//go:build windows + +package dnsnew + +import "time" + +// dnsTimeout is used to override the default DNS timeout of 20 seconds. +const dnsTimeout = 20 * time.Second diff --git a/challenge/dnsnew/client_zone.go b/challenge/dnsnew/client_zone.go new file mode 100644 index 000000000..899318e5c --- /dev/null +++ b/challenge/dnsnew/client_zone.go @@ -0,0 +1,89 @@ +package dnsnew + +import ( + "context" + "fmt" + + "github.com/miekg/dns" +) + +// FindZoneByFqdn determines the zone apex for the given fqdn +// by recursing up the domain labels until the nameserver returns a SOA record in the answer section. +func (c *Client) FindZoneByFqdn(ctx context.Context, fqdn string) (string, error) { + return c.FindZoneByFqdnCustom(ctx, fqdn, c.recursiveNameservers) +} + +// FindZoneByFqdnCustom determines the zone apex for the given fqdn +// by recursing up the domain labels until the nameserver returns a SOA record in the answer section. +func (c *Client) FindZoneByFqdnCustom(ctx context.Context, fqdn string, nameservers []string) (string, error) { + soa, err := c.lookupSoaByFqdn(ctx, fqdn, nameservers) + if err != nil { + return "", fmt.Errorf("[fqdn=%s] %w", fqdn, err) + } + + return soa.zone, nil +} + +func (c *Client) lookupSoaByFqdn(ctx context.Context, fqdn string, nameservers []string) (*soaCacheEntry, error) { + c.muFqdnSoaCache.Lock() + defer c.muFqdnSoaCache.Unlock() + + // Do we have it cached and is it still fresh? + if ent := c.fqdnSoaCache[fqdn]; ent != nil && !ent.isExpired() { + return ent, nil + } + + ent, err := c.fetchSoaByFqdn(ctx, fqdn, nameservers) + if err != nil { + return nil, err + } + + c.fqdnSoaCache[fqdn] = ent + + return ent, nil +} + +func (c *Client) fetchSoaByFqdn(ctx context.Context, fqdn string, nameservers []string) (*soaCacheEntry, error) { + var ( + err error + r *dns.Msg + ) + + for domain := range DomainsSeq(fqdn) { + r, err = c.sendQueryCustom(ctx, domain, dns.TypeSOA, nameservers, true) + if err != nil { + continue + } + + if r == nil { + continue + } + + switch r.Rcode { + case dns.RcodeSuccess: + // Check if we got a SOA RR in the answer section + if len(r.Answer) == 0 { + continue + } + + // CNAME records cannot/should not exist at the root of a zone. + // So we skip a domain when a CNAME is found. + if dnsMsgContainsCNAME(r) { + continue + } + + for _, ans := range r.Answer { + if soa, ok := ans.(*dns.SOA); ok { + return newSoaCacheEntry(soa), nil + } + } + case dns.RcodeNameError: + // NXDOMAIN + default: + // Any response code other than NOERROR and NXDOMAIN is treated as error + return nil, &DNSError{Message: fmt.Sprintf("unexpected response for '%s'", domain), MsgOut: r} + } + } + + return nil, &DNSError{Message: fmt.Sprintf("could not find the start of authority for '%s'", fqdn), MsgOut: r, Err: err} +} diff --git a/challenge/dnsnew/client_zone_test.go b/challenge/dnsnew/client_zone_test.go new file mode 100644 index 000000000..8ea10ec50 --- /dev/null +++ b/challenge/dnsnew/client_zone_test.go @@ -0,0 +1,158 @@ +package dnsnew + +import ( + "testing" + + "github.com/go-acme/lego/v5/platform/tester/dnsmock" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type lookupSoaByFqdnTestCase struct { + desc string + fqdn string + zone string + primaryNs string + nameservers []string + expectedError string +} + +func lookupSoaByFqdnTestCases(t *testing.T) []lookupSoaByFqdnTestCase { + t.Helper() + + return []lookupSoaByFqdnTestCase{ + { + desc: "domain is a CNAME", + fqdn: "mail.example.com.", + zone: "example.com.", + primaryNs: "ns1.example.com.", + nameservers: []string{ + dnsmock.NewServer(). + Query("mail.example.com. SOA", dnsmock.CNAME("example.com.")). + Query("example.com. SOA", dnsmock.SOA("")). + Build(t). + String(), + }, + }, + { + desc: "domain is a non-existent subdomain", + fqdn: "foo.example.com.", + zone: "example.com.", + primaryNs: "ns1.example.com.", + nameservers: []string{ + dnsmock.NewServer(). + Query("foo.example.com. SOA", dnsmock.Error(dns.RcodeNameError)). + Query("example.com. SOA", dnsmock.SOA("")). + Build(t). + String(), + }, + }, + { + desc: "domain is a eTLD", + fqdn: "example.com.ac.", + zone: "ac.", + primaryNs: "ns1.nic.ac.", + nameservers: []string{ + dnsmock.NewServer(). + Query("example.com.ac. SOA", dnsmock.Error(dns.RcodeNameError)). + Query("com.ac. SOA", dnsmock.Error(dns.RcodeNameError)). + Query("ac. SOA", dnsmock.SOA("")). + Build(t). + String(), + }, + }, + { + desc: "domain is a cross-zone CNAME", + fqdn: "cross-zone-example.example.com.", + zone: "example.com.", + primaryNs: "ns1.example.com.", + nameservers: []string{ + dnsmock.NewServer(). + Query("cross-zone-example.example.com. SOA", dnsmock.CNAME("example.org.")). + Query("example.com. SOA", dnsmock.SOA("")). + Build(t). + String(), + }, + }, + { + desc: "NXDOMAIN", + fqdn: "test.lego.invalid.", + zone: "lego.invalid.", + nameservers: []string{ + dnsmock.NewServer(). + Query("test.lego.invalid. SOA", dnsmock.Error(dns.RcodeNameError)). + Query("lego.invalid. SOA", dnsmock.Error(dns.RcodeNameError)). + Query("invalid. SOA", dnsmock.Error(dns.RcodeNameError)). + Build(t). + String(), + }, + expectedError: `[fqdn=test.lego.invalid.] could not find the start of authority for 'test.lego.invalid.' [question='invalid. IN SOA', code=NXDOMAIN]`, + }, + { + desc: "several non existent nameservers", + fqdn: "mail.example.com.", + zone: "example.com.", + primaryNs: "ns1.example.com.", + nameservers: []string{ + ":7053", + ":8053", + dnsmock.NewServer(). + Query("mail.example.com. SOA", dnsmock.CNAME("example.com.")). + Query("example.com. SOA", dnsmock.SOA("")). + Build(t). + String(), + }, + }, + { + desc: "only non-existent nameservers", + fqdn: "mail.example.com.", + zone: "example.com.", + nameservers: []string{":7053", ":8053", ":9053"}, + // use only the start of the message because the port changes with each call: 127.0.0.1:XXXXX->127.0.0.1:7053. + expectedError: "[fqdn=mail.example.com.] could not find the start of authority for 'mail.example.com.': DNS call error: read udp ", + }, + { + desc: "no nameservers", + fqdn: "test.example.com.", + zone: "example.com.", + nameservers: []string{}, + expectedError: "[fqdn=test.example.com.] could not find the start of authority for 'test.example.com.': empty list of nameservers", + }, + } +} + +func TestClient_FindZoneByFqdnCustom(t *testing.T) { + for _, test := range lookupSoaByFqdnTestCases(t) { + t.Run(test.desc, func(t *testing.T) { + client := NewClient(nil) + + zone, err := client.FindZoneByFqdnCustom(t.Context(), test.fqdn, test.nameservers) + if test.expectedError != "" { + require.Error(t, err) + assert.ErrorContains(t, err, test.expectedError) + } else { + require.NoError(t, err) + assert.Equal(t, test.zone, zone) + } + }) + } +} + +func TestClient_FindZoneByFqdn(t *testing.T) { + for _, test := range lookupSoaByFqdnTestCases(t) { + t.Run(test.desc, func(t *testing.T) { + client := NewClient(nil) + client.recursiveNameservers = test.nameservers + + zone, err := client.FindZoneByFqdn(t.Context(), test.fqdn) + if test.expectedError != "" { + require.Error(t, err) + assert.ErrorContains(t, err, test.expectedError) + } else { + require.NoError(t, err) + assert.Equal(t, test.zone, zone) + } + }) + } +} diff --git a/challenge/dnsnew/dns_challenge.go b/challenge/dnsnew/dns_challenge.go new file mode 100644 index 000000000..0efcdd2a8 --- /dev/null +++ b/challenge/dnsnew/dns_challenge.go @@ -0,0 +1,200 @@ +package dnsnew + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "fmt" + "os" + "strconv" + "strings" + "time" + + "github.com/go-acme/lego/v5/acme" + "github.com/go-acme/lego/v5/acme/api" + "github.com/go-acme/lego/v5/challenge" + "github.com/go-acme/lego/v5/log" + "github.com/go-acme/lego/v5/platform/wait" +) + +const ( + // DefaultPropagationTimeout default propagation timeout. + DefaultPropagationTimeout = 60 * time.Second + + // DefaultPollingInterval default polling interval. + DefaultPollingInterval = 2 * time.Second + + // DefaultTTL default TTL. + DefaultTTL = 120 +) + +type ValidateFunc func(ctx context.Context, core *api.Core, domain string, chlng acme.Challenge) error + +// Challenge implements the dns-01 challenge. +type Challenge struct { + core *api.Core + validate ValidateFunc + provider challenge.Provider + preCheck preCheck +} + +func NewChallenge(core *api.Core, validate ValidateFunc, provider challenge.Provider, opts ...ChallengeOption) *Challenge { + chlg := &Challenge{ + core: core, + validate: validate, + provider: provider, + preCheck: newPreCheck(), + } + + for _, opt := range opts { + err := opt(chlg) + if err != nil { + log.Warn("Challenge option skipped.", "error", err) + } + } + + return chlg +} + +// PreSolve just submits the txt record to the dns provider. +// It does not validate record propagation or do anything at all with the ACME server. +func (c *Challenge) PreSolve(ctx context.Context, authz acme.Authorization) error { + domain := challenge.GetTargetedDomain(authz) + log.Info("acme: Preparing to solve DNS-01.", "domain", domain) + + chlng, err := challenge.FindChallenge(challenge.DNS01, authz) + if err != nil { + return err + } + + if c.provider == nil { + return fmt.Errorf("[%s] acme: no DNS Provider configured", domain) + } + + // Generate the Key Authorization for the challenge + keyAuth, err := c.core.GetKeyAuthorization(chlng.Token) + if err != nil { + return err + } + + err = c.provider.Present(authz.Identifier.Value, chlng.Token, keyAuth) + if err != nil { + return fmt.Errorf("[%s] acme: error presenting token: %w", domain, err) + } + + return nil +} + +func (c *Challenge) Solve(ctx context.Context, authz acme.Authorization) error { + domain := challenge.GetTargetedDomain(authz) + log.Info("acme: Trying to solve DNS-01.", "domain", domain) + + chlng, err := challenge.FindChallenge(challenge.DNS01, authz) + if err != nil { + return err + } + + // Generate the Key Authorization for the challenge + keyAuth, err := c.core.GetKeyAuthorization(chlng.Token) + if err != nil { + return err + } + + info := GetChallengeInfo(ctx, authz.Identifier.Value, keyAuth) + + var timeout, interval time.Duration + + switch provider := c.provider.(type) { + case challenge.ProviderTimeout: + timeout, interval = provider.Timeout() + default: + timeout, interval = DefaultPropagationTimeout, DefaultPollingInterval + } + + log.Info("acme: Checking DNS record propagation.", + "domain", domain, "nameservers", strings.Join(DefaultClient().recursiveNameservers, ",")) + + time.Sleep(interval) + + err = wait.For("propagation", timeout, interval, func() (bool, error) { + stop, errP := c.preCheck.call(ctx, domain, info.EffectiveFQDN, info.Value) + if !stop || errP != nil { + log.Info("acme: Waiting for DNS record propagation.", "domain", domain) + } + + return stop, errP + }) + if err != nil { + return err + } + + chlng.KeyAuthorization = keyAuth + + return c.validate(ctx, c.core, domain, chlng) +} + +// CleanUp cleans the challenge. +func (c *Challenge) CleanUp(authz acme.Authorization) error { + log.Info("acme: Cleaning DNS-01 challenge.", "domain", challenge.GetTargetedDomain(authz)) + + chlng, err := challenge.FindChallenge(challenge.DNS01, authz) + if err != nil { + return err + } + + keyAuth, err := c.core.GetKeyAuthorization(chlng.Token) + if err != nil { + return err + } + + return c.provider.CleanUp(authz.Identifier.Value, chlng.Token, keyAuth) +} + +func (c *Challenge) Sequential() (bool, time.Duration) { + if p, ok := c.provider.(sequential); ok { + return ok, p.Sequential() + } + + return false, 0 +} + +type sequential interface { + Sequential() time.Duration +} + +// ChallengeInfo contains the information use to create the TXT record. +type ChallengeInfo struct { + // FQDN is the full-qualified challenge domain (i.e. `_acme-challenge.[domain].`) + FQDN string + + // EffectiveFQDN contains the resulting FQDN after the CNAMEs resolutions. + EffectiveFQDN string + + // Value contains the value for the TXT record. + Value string +} + +// GetChallengeInfo returns information used to create a DNS record which will fulfill the `dns-01` challenge. +func GetChallengeInfo(ctx context.Context, domain, keyAuth string) ChallengeInfo { + keyAuthShaBytes := sha256.Sum256([]byte(keyAuth)) + // base64URL encoding without padding + value := base64.RawURLEncoding.EncodeToString(keyAuthShaBytes[:sha256.Size]) + + ok, _ := strconv.ParseBool(os.Getenv("LEGO_DISABLE_CNAME_SUPPORT")) + + return ChallengeInfo{ + Value: value, + FQDN: getChallengeFQDN(ctx, domain, false), + EffectiveFQDN: getChallengeFQDN(ctx, domain, !ok), + } +} + +func getChallengeFQDN(ctx context.Context, domain string, followCNAME bool) string { + fqdn := fmt.Sprintf("_acme-challenge.%s.", domain) + + if !followCNAME { + return fqdn + } + + return DefaultClient().lookupCNAME(ctx, fqdn) +} diff --git a/challenge/dnsnew/dns_challenge_options.go b/challenge/dnsnew/dns_challenge_options.go new file mode 100644 index 000000000..4d5554b14 --- /dev/null +++ b/challenge/dnsnew/dns_challenge_options.go @@ -0,0 +1,46 @@ +package dnsnew + +import ( + "context" + "time" +) + +type ChallengeOption func(*Challenge) error + +// CondOption Conditional challenge option. +func CondOption(condition bool, opt ChallengeOption) ChallengeOption { + if !condition { + // NoOp options + return func(*Challenge) error { + return nil + } + } + + return opt +} + +func DisableAuthoritativeNssPropagationRequirement() ChallengeOption { + return func(chlg *Challenge) error { + chlg.preCheck.requireAuthoritativeNssPropagation = false + return nil + } +} + +func RecursiveNSsPropagationRequirement() ChallengeOption { + return func(chlg *Challenge) error { + chlg.preCheck.requireRecursiveNssPropagation = true + return nil + } +} + +func PropagationWait(wait time.Duration, skipCheck bool) ChallengeOption { + return WrapPreCheck(func(ctx context.Context, domain, fqdn, value string, check PreCheckFunc) (bool, error) { + time.Sleep(wait) + + if skipCheck { + return true, nil + } + + return check(ctx, fqdn, value) + }) +} diff --git a/challenge/dnsnew/dns_challenge_precheck.go b/challenge/dnsnew/dns_challenge_precheck.go new file mode 100644 index 000000000..8833a0fc1 --- /dev/null +++ b/challenge/dnsnew/dns_challenge_precheck.go @@ -0,0 +1,86 @@ +package dnsnew + +import ( + "context" + "fmt" + + "github.com/miekg/dns" +) + +// PreCheckFunc checks DNS propagation before notifying ACME that the DNS challenge is ready. +type PreCheckFunc func(ctx context.Context, fqdn, value string) (bool, error) + +// WrapPreCheckFunc wraps a PreCheckFunc in order to do extra operations before or after +// the main check, put it in a loop, etc. +type WrapPreCheckFunc func(ctx context.Context, domain, fqdn, value string, check PreCheckFunc) (bool, error) + +// WrapPreCheck Allow to define checks before notifying ACME that the DNS challenge is ready. +func WrapPreCheck(wrap WrapPreCheckFunc) ChallengeOption { + return func(chlg *Challenge) error { + chlg.preCheck.checkFunc = wrap + return nil + } +} + +type preCheck struct { + // checks DNS propagation before notifying ACME that the DNS challenge is ready. + checkFunc WrapPreCheckFunc + + // require the TXT record to be propagated to all authoritative name servers + requireAuthoritativeNssPropagation bool + + // require the TXT record to be propagated to all recursive name servers + requireRecursiveNssPropagation bool +} + +func newPreCheck() preCheck { + return preCheck{ + requireAuthoritativeNssPropagation: true, + } +} + +func (p preCheck) call(ctx context.Context, domain, fqdn, value string) (bool, error) { + if p.checkFunc == nil { + return p.checkDNSPropagation(ctx, fqdn, value) + } + + return p.checkFunc(ctx, domain, fqdn, value, p.checkDNSPropagation) +} + +// checkDNSPropagation checks if the expected TXT record has been propagated to all authoritative nameservers. +func (p preCheck) checkDNSPropagation(ctx context.Context, fqdn, value string) (bool, error) { + client := DefaultClient() + + // Initial attempt to resolve at the recursive NS (require getting CNAME) + r, err := client.sendQuery(ctx, fqdn, dns.TypeTXT, true) + if err != nil { + return false, fmt.Errorf("initial recursive nameserver: %w", err) + } + + if r.Rcode == dns.RcodeSuccess { + fqdn = updateDomainWithCName(r, fqdn) + } + + if p.requireRecursiveNssPropagation { + _, err = client.checkNameserversPropagation(ctx, fqdn, value, false) + if err != nil { + return false, fmt.Errorf("recursive nameservers: %w", err) + } + } + + if !p.requireAuthoritativeNssPropagation { + return true, nil + } + + authoritativeNss, err := client.lookupAuthoritativeNameservers(ctx, fqdn) + if err != nil { + return false, err + } + + found, err := client.checkNameserversPropagationCustom(ctx, fqdn, value, authoritativeNss, true) + if err != nil { + return found, fmt.Errorf("authoritative nameservers: %w", err) + } + + return found, nil +} diff --git a/challenge/dnsnew/dns_challenge_precheck_test.go b/challenge/dnsnew/dns_challenge_precheck_test.go new file mode 100644 index 000000000..48b98111e --- /dev/null +++ b/challenge/dnsnew/dns_challenge_precheck_test.go @@ -0,0 +1,78 @@ +package dnsnew + +import ( + "testing" + + "github.com/go-acme/lego/v5/platform/tester/dnsmock" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" +) + +func Test_preCheck_checkDNSPropagation(t *testing.T) { + mockDefault(t, + dnsmock.NewServer(). + Query("acme-staging.api.example.com. SOA", dnsmock.Error(dns.RcodeNameError)). + Query("api.example.com. SOA", dnsmock.Error(dns.RcodeNameError)). + Query("example.com. SOA", dnsmock.SOA("")). + Query("example.com. NS", + dnsmock.Answer( + fakeNS("example.com.", "ns0.lego.localhost."), + fakeNS("example.com.", "ns1.lego.localhost."), + ), + ). + Build(t), + mockResolver( + dnsmock.NewServer(). + Query("ns0.lego.localhost. A", + dnsmock.Answer(fakeA("ns0.lego.localhost.", "127.0.0.1"))). + Query("ns1.lego.localhost. A", + dnsmock.Answer(fakeA("ns1.lego.localhost.", "127.0.0.1"))). + Query("example.com. TXT", + dnsmock.Answer( + fakeTXT("example.com.", "one"), + fakeTXT("example.com.", "two"), + fakeTXT("example.com.", "three"), + fakeTXT("example.com.", "four"), + fakeTXT("example.com.", "five"), + ), + ). + Build(t), + ), + ) + + testCases := []struct { + desc string + fqdn string + value string + expectedError string + }{ + { + desc: "success", + fqdn: "example.com.", + value: "four", + }, + { + desc: "no matching TXT record", + fqdn: "acme-staging.api.example.com.", + value: "fe01=", + expectedError: "did not return the expected TXT record [fqdn: acme-staging.api.example.com., value: fe01=]: one ,two ,three ,four ,five", + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + DefaultClient().ClearFqdnCache() + + check := newPreCheck() + + ok, err := check.checkDNSPropagation(t.Context(), test.fqdn, test.value) + if test.expectedError != "" { + assert.ErrorContainsf(t, err, test.expectedError, "PreCheckDNS must fail for %s", test.fqdn) + assert.False(t, ok, "PreCheckDNS must fail for %s", test.fqdn) + } else { + assert.NoErrorf(t, err, "PreCheckDNS failed for %s", test.fqdn) + assert.True(t, ok, "PreCheckDNS failed for %s", test.fqdn) + } + }) + } +} diff --git a/challenge/dnsnew/dns_challenge_test.go b/challenge/dnsnew/dns_challenge_test.go new file mode 100644 index 000000000..44f4b6097 --- /dev/null +++ b/challenge/dnsnew/dns_challenge_test.go @@ -0,0 +1,348 @@ +package dnsnew + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "errors" + "testing" + "time" + + "github.com/go-acme/lego/v5/acme" + "github.com/go-acme/lego/v5/acme/api" + "github.com/go-acme/lego/v5/challenge" + "github.com/go-acme/lego/v5/platform/tester" + "github.com/go-acme/lego/v5/platform/tester/dnsmock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type providerMock struct { + present, cleanUp error +} + +func (p *providerMock) Present(domain, token, keyAuth string) error { return p.present } +func (p *providerMock) CleanUp(domain, token, keyAuth string) error { return p.cleanUp } + +type providerTimeoutMock struct { + present, cleanUp error + timeout, interval time.Duration +} + +func (p *providerTimeoutMock) Present(domain, token, keyAuth string) error { return p.present } +func (p *providerTimeoutMock) CleanUp(domain, token, keyAuth string) error { return p.cleanUp } +func (p *providerTimeoutMock) Timeout() (time.Duration, time.Duration) { return p.timeout, p.interval } + +func TestChallenge_PreSolve(t *testing.T) { + server := tester.MockACMEServer().BuildHTTPS(t) + + privateKey, err := rsa.GenerateKey(rand.Reader, 1024) + require.NoError(t, err) + + core, err := api.New(server.Client(), "lego-test", server.URL+"/dir", "", privateKey) + require.NoError(t, err) + + testCases := []struct { + desc string + validate ValidateFunc + preCheck WrapPreCheckFunc + provider challenge.Provider + expectError bool + }{ + { + desc: "success", + validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil }, + preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil }, + provider: &providerMock{}, + }, + { + desc: "validate fail", + validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") }, + preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil }, + provider: &providerMock{ + present: nil, + cleanUp: nil, + }, + }, + { + desc: "preCheck fail", + validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil }, + preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { + return false, errors.New("OOPS") + }, + provider: &providerTimeoutMock{ + timeout: 2 * time.Second, + interval: 500 * time.Millisecond, + }, + }, + { + desc: "present fail", + validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil }, + preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil }, + provider: &providerMock{ + present: errors.New("OOPS"), + }, + expectError: true, + }, + { + desc: "cleanUp fail", + validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil }, + preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil }, + provider: &providerMock{ + cleanUp: errors.New("OOPS"), + }, + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + chlg := NewChallenge(core, test.validate, test.provider, WrapPreCheck(test.preCheck)) + + authz := acme.Authorization{ + Identifier: acme.Identifier{ + Value: "example.com", + }, + Challenges: []acme.Challenge{ + {Type: challenge.DNS01.String()}, + }, + } + + err = chlg.PreSolve(t.Context(), authz) + if test.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestChallenge_Solve(t *testing.T) { + mockDefault(t, dnsmock.NewServer(). + Query("_acme-challenge.example.com. CNAME", dnsmock.Noop). + Build(t)) + + server := tester.MockACMEServer().BuildHTTPS(t) + + privateKey, err := rsa.GenerateKey(rand.Reader, 1024) + require.NoError(t, err) + + core, err := api.New(server.Client(), "lego-test", server.URL+"/dir", "", privateKey) + require.NoError(t, err) + + testCases := []struct { + desc string + validate ValidateFunc + preCheck WrapPreCheckFunc + provider challenge.Provider + expectError bool + }{ + { + desc: "success", + validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil }, + preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil }, + provider: &providerMock{}, + }, + { + desc: "validate fail", + validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") }, + preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil }, + provider: &providerMock{ + present: nil, + cleanUp: nil, + }, + expectError: true, + }, + { + desc: "preCheck fail", + validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil }, + preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { + return false, errors.New("OOPS") + }, + provider: &providerTimeoutMock{ + timeout: 2 * time.Second, + interval: 500 * time.Millisecond, + }, + expectError: true, + }, + { + desc: "present fail", + validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil }, + preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil }, + provider: &providerMock{ + present: errors.New("OOPS"), + }, + }, + { + desc: "cleanUp fail", + validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil }, + preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil }, + provider: &providerMock{ + cleanUp: errors.New("OOPS"), + }, + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + var options []ChallengeOption + if test.preCheck != nil { + options = append(options, WrapPreCheck(test.preCheck)) + } + + chlg := NewChallenge(core, test.validate, test.provider, options...) + + authz := acme.Authorization{ + Identifier: acme.Identifier{ + Value: "example.com", + }, + Challenges: []acme.Challenge{ + {Type: challenge.DNS01.String()}, + }, + } + + err = chlg.Solve(t.Context(), authz) + if test.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestChallenge_CleanUp(t *testing.T) { + server := tester.MockACMEServer().BuildHTTPS(t) + + privateKey, err := rsa.GenerateKey(rand.Reader, 1024) + require.NoError(t, err) + + core, err := api.New(server.Client(), "lego-test", server.URL+"/dir", "", privateKey) + require.NoError(t, err) + + testCases := []struct { + desc string + validate ValidateFunc + preCheck WrapPreCheckFunc + provider challenge.Provider + expectError bool + }{ + { + desc: "success", + validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil }, + preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil }, + provider: &providerMock{}, + }, + { + desc: "validate fail", + validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") }, + preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil }, + provider: &providerMock{ + present: nil, + cleanUp: nil, + }, + }, + { + desc: "preCheck fail", + validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil }, + preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { + return false, errors.New("OOPS") + }, + provider: &providerTimeoutMock{ + timeout: 2 * time.Second, + interval: 500 * time.Millisecond, + }, + }, + { + desc: "present fail", + validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil }, + preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil }, + provider: &providerMock{ + present: errors.New("OOPS"), + }, + }, + { + desc: "cleanUp fail", + validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil }, + preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil }, + provider: &providerMock{ + cleanUp: errors.New("OOPS"), + }, + expectError: true, + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + chlg := NewChallenge(core, test.validate, test.provider, WrapPreCheck(test.preCheck)) + + authz := acme.Authorization{ + Identifier: acme.Identifier{ + Value: "example.com", + }, + Challenges: []acme.Challenge{ + {Type: challenge.DNS01.String()}, + }, + } + + err = chlg.CleanUp(authz) + if test.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestGetChallengeInfo(t *testing.T) { + mockDefault(t, dnsmock.NewServer(). + Query("_acme-challenge.example.com. CNAME", dnsmock.Noop). + Build(t)) + + info := GetChallengeInfo(t.Context(), "example.com", "123") + + expected := ChallengeInfo{ + FQDN: "_acme-challenge.example.com.", + EffectiveFQDN: "_acme-challenge.example.com.", + Value: "pmWkWSBCL51Bfkhn79xPuKBKHz__H6B-mY6G9_eieuM", + } + + assert.Equal(t, expected, info) +} + +func TestGetChallengeInfo_CNAME(t *testing.T) { + mockDefault(t, dnsmock.NewServer(). + Query("_acme-challenge.example.com. CNAME", dnsmock.CNAME("example.org.")). + Query("example.org. CNAME", dnsmock.Noop). + Build(t)) + + info := GetChallengeInfo(t.Context(), "example.com", "123") + + expected := ChallengeInfo{ + FQDN: "_acme-challenge.example.com.", + EffectiveFQDN: "example.org.", + Value: "pmWkWSBCL51Bfkhn79xPuKBKHz__H6B-mY6G9_eieuM", + } + + assert.Equal(t, expected, info) +} + +func TestGetChallengeInfo_CNAME_disabled(t *testing.T) { + mockDefault(t, dnsmock.NewServer(). + // Never called when the env var works. + Query("_acme-challenge.example.com. CNAME", dnsmock.CNAME("example.org.")). + Build(t)) + + t.Setenv("LEGO_DISABLE_CNAME_SUPPORT", "true") + + info := GetChallengeInfo(t.Context(), "example.com", "123") + + expected := ChallengeInfo{ + FQDN: "_acme-challenge.example.com.", + EffectiveFQDN: "_acme-challenge.example.com.", + Value: "pmWkWSBCL51Bfkhn79xPuKBKHz__H6B-mY6G9_eieuM", + } + + assert.Equal(t, expected, info) +} diff --git a/challenge/dnsnew/domain.go b/challenge/dnsnew/domain.go new file mode 100644 index 000000000..822128aab --- /dev/null +++ b/challenge/dnsnew/domain.go @@ -0,0 +1,24 @@ +package dnsnew + +import ( + "fmt" + "strings" + + "github.com/miekg/dns" +) + +// ExtractSubDomain extracts the subdomain part from a domain and a zone. +func ExtractSubDomain(domain, zone string) (string, error) { + canonDomain := dns.Fqdn(domain) + canonZone := dns.Fqdn(zone) + + if canonDomain == canonZone { + return "", fmt.Errorf("no subdomain because the domain and the zone are identical: %s", canonDomain) + } + + if !dns.IsSubDomain(canonZone, canonDomain) { + return "", fmt.Errorf("%s is not a subdomain of %s", canonDomain, canonZone) + } + + return strings.TrimSuffix(canonDomain, "."+canonZone), nil +} diff --git a/challenge/dnsnew/domain_test.go b/challenge/dnsnew/domain_test.go new file mode 100644 index 000000000..9453559ec --- /dev/null +++ b/challenge/dnsnew/domain_test.go @@ -0,0 +1,102 @@ +package dnsnew + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExtractSubDomain(t *testing.T) { + testCases := []struct { + desc string + domain string + zone string + expected string + }{ + { + desc: "no FQDN", + domain: "_acme-challenge.example.com", + zone: "example.com", + expected: "_acme-challenge", + }, + { + desc: "no FQDN zone", + domain: "_acme-challenge.example.com.", + zone: "example.com", + expected: "_acme-challenge", + }, + { + desc: "no FQDN domain", + domain: "_acme-challenge.example.com", + zone: "example.com.", + expected: "_acme-challenge", + }, + { + desc: "FQDN", + domain: "_acme-challenge.example.com.", + zone: "example.com.", + expected: "_acme-challenge", + }, + { + desc: "multi-level subdomain", + domain: "_acme-challenge.one.example.com.", + zone: "example.com.", + expected: "_acme-challenge.one", + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + subDomain, err := ExtractSubDomain(test.domain, test.zone) + require.NoError(t, err) + + assert.Equal(t, test.expected, subDomain) + }) + } +} + +func TestExtractSubDomain_errors(t *testing.T) { + testCases := []struct { + desc string + domain string + zone string + }{ + { + desc: "same domain", + domain: "example.com", + zone: "example.com", + }, + { + desc: "same domain, no FQDN zone", + domain: "example.com.", + zone: "example.com", + }, + { + desc: "same domain, no FQDN domain", + domain: "example.com", + zone: "example.com.", + }, + { + desc: "same domain, FQDN", + domain: "example.com.", + zone: "example.com.", + }, + { + desc: "zone and domain are unrelated", + domain: "_acme-challenge.example.com", + zone: "example.org", + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + _, err := ExtractSubDomain(test.domain, test.zone) + require.Error(t, err) + }) + } +} diff --git a/challenge/dnsnew/fixtures/resolv.conf.1 b/challenge/dnsnew/fixtures/resolv.conf.1 new file mode 100644 index 000000000..bc2a3c1ac --- /dev/null +++ b/challenge/dnsnew/fixtures/resolv.conf.1 @@ -0,0 +1,5 @@ +domain example.com +nameserver 10.200.3.249 +nameserver 10.200.3.250:5353 +nameserver 2001:4860:4860::8844 +nameserver [10.0.0.1]:5353 diff --git a/challenge/dnsnew/fqdn.go b/challenge/dnsnew/fqdn.go new file mode 100644 index 000000000..350c1e7e9 --- /dev/null +++ b/challenge/dnsnew/fqdn.go @@ -0,0 +1,47 @@ +package dnsnew + +import ( + "iter" + + "github.com/miekg/dns" +) + +// UnFqdn converts the fqdn into a name removing the trailing dot. +func UnFqdn(name string) string { + n := len(name) + if n != 0 && name[n-1] == '.' { + return name[:n-1] + } + + return name +} + +// UnFqdnDomainsSeq generates a sequence of "unFQDNed" domain names derived from a domain (FQDN or not) in descending order. +func UnFqdnDomainsSeq(fqdn string) iter.Seq[string] { + return func(yield func(string) bool) { + if fqdn == "" { + return + } + + for _, index := range dns.Split(fqdn) { + if !yield(UnFqdn(fqdn[index:])) { + return + } + } + } +} + +// DomainsSeq generates a sequence of domain names derived from a domain (FQDN or not) in descending order. +func DomainsSeq(fqdn string) iter.Seq[string] { + return func(yield func(string) bool) { + if fqdn == "" { + return + } + + for _, index := range dns.Split(fqdn) { + if !yield(fqdn[index:]) { + return + } + } + } +} diff --git a/challenge/dnsnew/fqdn_test.go b/challenge/dnsnew/fqdn_test.go new file mode 100644 index 000000000..e83724672 --- /dev/null +++ b/challenge/dnsnew/fqdn_test.go @@ -0,0 +1,137 @@ +package dnsnew + +import ( + "slices" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestUnFqdn(t *testing.T) { + testCases := []struct { + desc string + fqdn string + expected string + }{ + { + desc: "simple", + fqdn: "foo.example.", + expected: "foo.example", + }, + { + desc: "already domain", + fqdn: "foo.example", + expected: "foo.example", + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + domain := UnFqdn(test.fqdn) + + assert.Equal(t, test.expected, domain) + }) + } +} + +func TestUnFqdnDomainsSeq(t *testing.T) { + testCases := []struct { + desc string + fqdn string + expected []string + }{ + { + desc: "empty", + fqdn: "", + expected: nil, + }, + { + desc: "TLD", + fqdn: "com", + expected: []string{"com"}, + }, + { + desc: "2 levels", + fqdn: "example.com", + expected: []string{"example.com", "com"}, + }, + { + desc: "3 levels", + fqdn: "foo.example.com", + expected: []string{"foo.example.com", "example.com", "com"}, + }, + } + + for _, test := range testCases { + for name, suffix := range map[string]string{"": "", " FQDN": "."} { //nolint:gocritic + t.Run(test.desc+name, func(t *testing.T) { + t.Parallel() + + actual := slices.Collect(UnFqdnDomainsSeq(test.fqdn + suffix)) + + assert.Equal(t, test.expected, actual) + }) + } + } +} + +func TestDomainsSeq(t *testing.T) { + testCases := []struct { + desc string + fqdn string + expected []string + }{ + { + desc: "empty", + fqdn: "", + expected: nil, + }, + { + desc: "empty FQDN", + fqdn: ".", + expected: nil, + }, + { + desc: "TLD FQDN", + fqdn: "com", + expected: []string{"com"}, + }, + { + desc: "TLD", + fqdn: "com.", + expected: []string{"com."}, + }, + { + desc: "2 levels", + fqdn: "example.com", + expected: []string{"example.com", "com"}, + }, + { + desc: "2 levels FQDN", + fqdn: "example.com.", + expected: []string{"example.com.", "com."}, + }, + { + desc: "3 levels", + fqdn: "foo.example.com", + expected: []string{"foo.example.com", "example.com", "com"}, + }, + { + desc: "3 levels FQDN", + fqdn: "foo.example.com.", + expected: []string{"foo.example.com.", "example.com.", "com."}, + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + actual := slices.Collect(DomainsSeq(test.fqdn)) + + assert.Equal(t, test.expected, actual) + }) + } +} diff --git a/challenge/dnsnew/mock_test.go b/challenge/dnsnew/mock_test.go new file mode 100644 index 000000000..9aebb93a1 --- /dev/null +++ b/challenge/dnsnew/mock_test.go @@ -0,0 +1,78 @@ +package dnsnew + +import ( + "context" + "net" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/stretchr/testify/require" +) + +func fakeNS(name, ns string) *dns.NS { + return &dns.NS{ + Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 172800}, + Ns: ns, + } +} + +func fakeA(name, ip string) *dns.A { + return &dns.A{ + Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 10}, + A: net.ParseIP(ip), + } +} + +func fakeTXT(name, value string) *dns.TXT { + return &dns.TXT{ + Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 10}, + Txt: []string{value}, + } +} + +// mockResolver modifies the default DNS resolver to use a custom network address during the test execution. +// IMPORTANT: it modifying std global variables. +func mockResolver(authoritativeNS net.Addr) func(t *testing.T, client *Client) { + return func(t *testing.T, client *Client) { + t.Helper() + + _, port, err := net.SplitHostPort(authoritativeNS.String()) + require.NoError(t, err) + + client.authoritativeNSPort = port + + originalResolver := net.DefaultResolver + + t.Cleanup(func() { + net.DefaultResolver = originalResolver + }) + + net.DefaultResolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{Timeout: 1 * time.Second} + + return d.DialContext(ctx, network, authoritativeNS.String()) + }, + } + } +} + +func mockDefault(t *testing.T, recursiveNS net.Addr, opts ...func(t *testing.T, client *Client)) { + t.Helper() + + backup := DefaultClient() + + t.Cleanup(func() { + SetDefaultClient(backup) + }) + + client := NewClient(&Options{RecursiveNameservers: []string{recursiveNS.String()}}) + + for _, opt := range opts { + opt(t, client) + } + + SetDefaultClient(client) +}