feat: new package dnsnew

This commit is contained in:
Fernandez Ludovic 2026-01-14 05:17:58 +01:00
commit a4499b729c
23 changed files with 2157 additions and 0 deletions

155
challenge/dnsnew/client.go Normal file
View file

@ -0,0 +1,155 @@
package dnsnew
import (
"context"
"errors"
"os"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/go-acme/lego/v5/challenge"
"github.com/miekg/dns"
)
const defaultResolvConf = "/etc/resolv.conf"
var defaultClient atomic.Pointer[Client]
func init() {
defaultClient.Store(NewClient(nil))
}
func DefaultClient() *Client { return defaultClient.Load() }
func SetDefaultClient(c *Client) {
defaultClient.Store(c)
}
type Options struct {
RecursiveNameservers []string
Timeout time.Duration
TCPOnly bool
NetworkStack challenge.NetworkStack
}
type Client struct {
recursiveNameservers []string
// authoritativeNSPort used by authoritative NS.
// For testing purposes only.
authoritativeNSPort string
tcpClient *dns.Client
udpClient *dns.Client
tcpOnly bool
fqdnSoaCache map[string]*soaCacheEntry
muFqdnSoaCache sync.Mutex
}
func NewClient(opts *Options) *Client {
if opts == nil {
tcpOnly, _ := strconv.ParseBool(os.Getenv("LEGO_EXPERIMENTAL_DNS_TCP_ONLY"))
opts = &Options{TCPOnly: tcpOnly}
}
if len(opts.RecursiveNameservers) == 0 {
defaultNameservers := []string{
"google-public-dns-a.google.com:53",
"google-public-dns-b.google.com:53",
}
opts.RecursiveNameservers = getNameservers(defaultResolvConf, defaultNameservers)
}
if opts.Timeout == 0 {
opts.Timeout = dnsTimeout
}
return &Client{
recursiveNameservers: opts.RecursiveNameservers,
authoritativeNSPort: "53",
tcpClient: &dns.Client{
Net: opts.NetworkStack.Network("tcp"),
Timeout: opts.Timeout,
},
udpClient: &dns.Client{
Net: opts.NetworkStack.Network("udp"),
Timeout: opts.Timeout,
},
tcpOnly: opts.TCPOnly,
fqdnSoaCache: map[string]*soaCacheEntry{},
muFqdnSoaCache: sync.Mutex{},
}
}
func (c *Client) sendQuery(ctx context.Context, fqdn string, rtype uint16, recursive bool) (*dns.Msg, error) {
return c.sendQueryCustom(ctx, fqdn, rtype, c.recursiveNameservers, recursive)
}
func (c *Client) sendQueryCustom(ctx context.Context, fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) {
m := createDNSMsg(fqdn, rtype, recursive)
if len(nameservers) == 0 {
return nil, &DNSError{Message: "empty list of nameservers"}
}
var (
r *dns.Msg
err error
errAll error
)
for _, ns := range nameservers {
r, err = c.exchange(ctx, m, ns)
if err == nil && len(r.Answer) > 0 {
break
}
errAll = errors.Join(errAll, err)
}
if err != nil {
return r, errAll
}
return r, nil
}
func (c *Client) exchange(ctx context.Context, m *dns.Msg, ns string) (*dns.Msg, error) {
if c.tcpOnly {
r, _, err := c.tcpClient.ExchangeContext(ctx, m, ns)
if err != nil {
return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
}
return r, nil
}
r, _, err := c.udpClient.ExchangeContext(ctx, m, ns)
if r != nil && r.Truncated {
// If the TCP request succeeds, the "err" will reset to nil
r, _, err = c.tcpClient.ExchangeContext(ctx, m, ns)
}
if err != nil {
return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
}
return r, nil
}
func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {
m := new(dns.Msg)
m.SetQuestion(fqdn, rtype)
m.SetEdns0(4096, false)
if !recursive {
m.RecursionDesired = false
}
return m
}

View file

@ -0,0 +1,34 @@
package dnsnew
import (
"time"
"github.com/miekg/dns"
)
// soaCacheEntry holds a cached SOA record (only selected fields).
type soaCacheEntry struct {
zone string // zone apex (a domain name)
primaryNs string // primary nameserver for the zone apex
expires time.Time // time when this cache entry should be evicted
}
func newSoaCacheEntry(soa *dns.SOA) *soaCacheEntry {
return &soaCacheEntry{
zone: soa.Hdr.Name,
primaryNs: soa.Ns,
expires: time.Now().Add(time.Duration(soa.Refresh) * time.Second),
}
}
// isExpired checks whether a cache entry should be considered expired.
func (cache *soaCacheEntry) isExpired() bool {
return time.Now().After(cache.expires)
}
// ClearFqdnCache clears the cache of fqdn to zone mappings. Primarily used in testing.
func (c *Client) ClearFqdnCache() {
c.muFqdnSoaCache.Lock()
c.fqdnSoaCache = map[string]*soaCacheEntry{}
c.muFqdnSoaCache.Unlock()
}

View file

@ -0,0 +1,57 @@
package dnsnew
import (
"context"
"slices"
"strings"
"github.com/go-acme/lego/v5/log"
"github.com/miekg/dns"
)
func (c *Client) lookupCNAME(ctx context.Context, fqdn string) string {
// recursion counter so it doesn't spin out of control
for range 50 {
// Keep following CNAMEs
r, err := c.sendQuery(ctx, fqdn, dns.TypeCNAME, true)
if err != nil || r.Rcode != dns.RcodeSuccess {
// TODO(ldez): logs the error in v5
// No more CNAME records to follow, exit
break
}
// Check if the domain has CNAME then use that
cname := updateDomainWithCName(r, fqdn)
if cname == fqdn {
break
}
log.Info("Found CNAME entry.", "fqdn", fqdn, "cname", cname)
fqdn = cname
}
return fqdn
}
// Update FQDN with CNAME if any.
func updateDomainWithCName(r *dns.Msg, fqdn string) string {
for _, rr := range r.Answer {
if cn, ok := rr.(*dns.CNAME); ok {
if strings.EqualFold(cn.Hdr.Name, fqdn) {
return cn.Target
}
}
}
return fqdn
}
// dnsMsgContainsCNAME checks for a CNAME answer in msg.
func dnsMsgContainsCNAME(msg *dns.Msg) bool {
return slices.ContainsFunc(msg.Answer, func(rr dns.RR) bool {
_, ok := rr.(*dns.CNAME)
return ok
})
}

View file

@ -0,0 +1,35 @@
package dnsnew
import (
"strings"
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
func Test_updateDomainWithCName_caseInsensitive(t *testing.T) {
qname := "_acme-challenge.uppercase-test.example.com."
cnameTarget := "_acme-challenge.uppercase-test.cname-target.example.com."
msg := &dns.Msg{
MsgHdr: dns.MsgHdr{
Authoritative: true,
},
Answer: []dns.RR{
&dns.CNAME{
Hdr: dns.RR_Header{
Name: strings.ToUpper(qname), // CNAME names are case-insensitive
Rrtype: dns.TypeCNAME,
Class: dns.ClassINET,
Ttl: 3600,
},
Target: cnameTarget,
},
},
}
fqdn := updateDomainWithCName(msg, qname)
assert.Equal(t, cnameTarget, fqdn)
}

View file

@ -0,0 +1,64 @@
package dnsnew
import (
"fmt"
"strings"
"github.com/miekg/dns"
)
// DNSError error related to DNS calls.
type DNSError struct {
Message string
NS string
MsgIn *dns.Msg
MsgOut *dns.Msg
Err error
}
func (d *DNSError) Error() string {
var details []string
if d.NS != "" {
details = append(details, "ns="+d.NS)
}
if d.MsgIn != nil && len(d.MsgIn.Question) > 0 {
details = append(details, fmt.Sprintf("question='%s'", formatQuestions(d.MsgIn.Question)))
}
if d.MsgOut != nil {
if d.MsgIn == nil || len(d.MsgIn.Question) == 0 {
details = append(details, fmt.Sprintf("question='%s'", formatQuestions(d.MsgOut.Question)))
}
details = append(details, "code="+dns.RcodeToString[d.MsgOut.Rcode])
}
msg := "DNS error"
if d.Message != "" {
msg = d.Message
}
if d.Err != nil {
msg += ": " + d.Err.Error()
}
if len(details) > 0 {
msg += " [" + strings.Join(details, ", ") + "]"
}
return msg
}
func (d *DNSError) Unwrap() error {
return d.Err
}
func formatQuestions(questions []dns.Question) string {
var parts []string
for _, question := range questions {
parts = append(parts, strings.ReplaceAll(strings.TrimPrefix(question.String(), ";"), "\t", " "))
}
return strings.Join(parts, ";")
}

View file

@ -0,0 +1,75 @@
package dnsnew
import (
"errors"
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
func TestDNSError_Error(t *testing.T) {
msgIn := createDNSMsg("example.com.", dns.TypeTXT, true)
msgOut := createDNSMsg("example.org.", dns.TypeSOA, true)
msgOut.Rcode = dns.RcodeNameError
testCases := []struct {
desc string
err *DNSError
expected string
}{
{
desc: "empty error",
err: &DNSError{},
expected: "DNS error",
},
{
desc: "all fields",
err: &DNSError{
Message: "Oops",
NS: "example.com.",
MsgIn: msgIn,
MsgOut: msgOut,
Err: errors.New("I did it again"),
},
expected: "Oops: I did it again [ns=example.com., question='example.com. IN TXT', code=NXDOMAIN]",
},
{
desc: "only NS",
err: &DNSError{
NS: "example.com.",
},
expected: "DNS error [ns=example.com.]",
},
{
desc: "only MsgIn",
err: &DNSError{
MsgIn: msgIn,
},
expected: "DNS error [question='example.com. IN TXT']",
},
{
desc: "only MsgOut",
err: &DNSError{
MsgOut: msgOut,
},
expected: "DNS error [question='example.org. IN SOA', code=NXDOMAIN]",
},
{
desc: "only Err",
err: &DNSError{
Err: errors.New("I did it again"),
},
expected: "DNS error: I did it again",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
assert.EqualError(t, test.err, test.expected)
})
}
}

View file

@ -0,0 +1,107 @@
package dnsnew
import (
"context"
"fmt"
"net"
"strings"
"github.com/miekg/dns"
)
// checkNameserversPropagation queries each of the recursive nameservers for the expected TXT record.
func (c *Client) checkNameserversPropagation(ctx context.Context, fqdn, value string, addPort bool) (bool, error) {
return c.checkNameserversPropagationCustom(ctx, fqdn, value, c.recursiveNameservers, addPort)
}
// checkNameserversPropagationCustom queries each of the given nameservers for the expected TXT record.
func (c *Client) checkNameserversPropagationCustom(ctx context.Context, fqdn, value string, nameservers []string, addPort bool) (bool, error) {
for _, ns := range nameservers {
if addPort {
ns = net.JoinHostPort(ns, c.authoritativeNSPort)
}
r, err := c.sendQueryCustom(ctx, fqdn, dns.TypeTXT, []string{ns}, false)
if err != nil {
return false, err
}
if r.Rcode != dns.RcodeSuccess {
return false, fmt.Errorf("NS %s returned %s for %s", ns, dns.RcodeToString[r.Rcode], fqdn)
}
var records []string
var found bool
for _, rr := range r.Answer {
if txt, ok := rr.(*dns.TXT); ok {
record := strings.Join(txt.Txt, "")
records = append(records, record)
if record == value {
found = true
break
}
}
}
if !found {
return false, fmt.Errorf("NS %s did not return the expected TXT record [fqdn: %s, value: %s]: %s", ns, fqdn, value, strings.Join(records, " ,"))
}
}
return true, nil
}
// lookupAuthoritativeNameservers returns the authoritative nameservers for the given fqdn.
func (c *Client) lookupAuthoritativeNameservers(ctx context.Context, fqdn string) ([]string, error) {
var authoritativeNss []string
zone, err := c.FindZoneByFqdn(ctx, fqdn)
if err != nil {
return nil, fmt.Errorf("could not find zone: %w", err)
}
r, err := c.sendQuery(ctx, zone, dns.TypeNS, true)
if err != nil {
return nil, fmt.Errorf("NS call failed: %w", err)
}
for _, rr := range r.Answer {
if ns, ok := rr.(*dns.NS); ok {
authoritativeNss = append(authoritativeNss, strings.ToLower(ns.Ns))
}
}
if len(authoritativeNss) > 0 {
return authoritativeNss, nil
}
return nil, fmt.Errorf("[zone=%s] could not determine authoritative nameservers", zone)
}
// getNameservers attempts to get systems nameservers before falling back to the defaults.
func getNameservers(path string, defaults []string) []string {
config, err := dns.ClientConfigFromFile(path)
if err != nil || len(config.Servers) == 0 {
return defaults
}
return parseNameservers(config.Servers)
}
func parseNameservers(servers []string) []string {
var resolvers []string
for _, resolver := range servers {
// ensure all servers have a port number
if _, _, err := net.SplitHostPort(resolver); err != nil {
resolvers = append(resolvers, net.JoinHostPort(resolver, "53"))
} else {
resolvers = append(resolvers, resolver)
}
}
return resolvers
}

View file

@ -0,0 +1,216 @@
package dnsnew
import (
"sort"
"testing"
"github.com/go-acme/lego/v5/platform/tester/dnsmock"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestClient_checkNameserversPropagationCustom_authoritativeNss(t *testing.T) {
testCases := []struct {
desc string
fqdn, value string
fakeDNSServer *dnsmock.Builder
expectedError string
}{
{
desc: "TXT RR w/ expected value",
// NS: asnums.routeviews.org.
fqdn: "8.8.8.8.asn.routeviews.org.",
value: "151698.8.8.024",
fakeDNSServer: dnsmock.NewServer().
Query("8.8.8.8.asn.routeviews.org. TXT",
dnsmock.Answer(
fakeTXT("8.8.8.8.asn.routeviews.org.", "151698.8.8.024"),
),
),
},
{
desc: "TXT RR w/ unexpected value",
// NS: asnums.routeviews.org.
fqdn: "8.8.8.8.asn.routeviews.org.",
value: "fe01=",
fakeDNSServer: dnsmock.NewServer().
Query("8.8.8.8.asn.routeviews.org. TXT",
dnsmock.Answer(
fakeTXT("8.8.8.8.asn.routeviews.org.", "15169"),
fakeTXT("8.8.8.8.asn.routeviews.org.", "8.8.8.0"),
fakeTXT("8.8.8.8.asn.routeviews.org.", "24"),
),
),
expectedError: "did not return the expected TXT record [fqdn: 8.8.8.8.asn.routeviews.org., value: fe01=]: 15169 ,8.8.8.0 ,24",
},
{
desc: "No TXT RR",
// NS: ns2.google.com.
fqdn: "ns1.google.com.",
value: "fe01=",
fakeDNSServer: dnsmock.NewServer().
Query("ns1.google.com.", dnsmock.Noop),
expectedError: "did not return the expected TXT record [fqdn: ns1.google.com., value: fe01=]: ",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
client := NewClient(nil)
addr := test.fakeDNSServer.Build(t)
ok, err := client.checkNameserversPropagationCustom(t.Context(), test.fqdn, test.value, []string{addr.String()}, false)
if test.expectedError == "" {
require.NoError(t, err)
assert.True(t, ok)
} else {
require.Error(t, err)
require.ErrorContains(t, err, test.expectedError)
assert.False(t, ok)
}
})
}
}
func TestClient_lookupAuthoritativeNameservers_OK(t *testing.T) {
testCases := []struct {
desc string
fakeDNSServer *dnsmock.Builder
fqdn string
expected []string
}{
{
fqdn: "en.wikipedia.org.localhost.",
fakeDNSServer: dnsmock.NewServer().
Query("en.wikipedia.org.localhost SOA", dnsmock.CNAME("dyna.wikimedia.org.localhost")).
Query("wikipedia.org.localhost SOA", dnsmock.SOA("")).
Query("wikipedia.org.localhost NS",
dnsmock.Answer(
fakeNS("wikipedia.org.localhost.", "ns0.wikimedia.org.localhost."),
fakeNS("wikipedia.org.localhost.", "ns1.wikimedia.org.localhost."),
fakeNS("wikipedia.org.localhost.", "ns2.wikimedia.org.localhost."),
),
),
expected: []string{"ns0.wikimedia.org.localhost.", "ns1.wikimedia.org.localhost.", "ns2.wikimedia.org.localhost."},
},
{
fqdn: "www.google.com.localhost.",
fakeDNSServer: dnsmock.NewServer().
Query("www.google.com.localhost. SOA", dnsmock.Noop).
Query("google.com.localhost. SOA", dnsmock.SOA("")).
Query("google.com.localhost. NS",
dnsmock.Answer(
fakeNS("google.com.localhost.", "ns1.google.com.localhost."),
fakeNS("google.com.localhost.", "ns2.google.com.localhost."),
fakeNS("google.com.localhost.", "ns3.google.com.localhost."),
fakeNS("google.com.localhost.", "ns4.google.com.localhost."),
),
),
expected: []string{"ns1.google.com.localhost.", "ns2.google.com.localhost.", "ns3.google.com.localhost.", "ns4.google.com.localhost."},
},
{
fqdn: "mail.proton.me.localhost.",
fakeDNSServer: dnsmock.NewServer().
Query("mail.proton.me.localhost. SOA", dnsmock.Noop).
Query("proton.me.localhost. SOA", dnsmock.SOA("")).
Query("proton.me.localhost. NS",
dnsmock.Answer(
fakeNS("proton.me.localhost.", "ns1.proton.me.localhost."),
fakeNS("proton.me.localhost.", "ns2.proton.me.localhost."),
fakeNS("proton.me.localhost.", "ns3.proton.me.localhost."),
),
),
expected: []string{"ns1.proton.me.localhost.", "ns2.proton.me.localhost.", "ns3.proton.me.localhost."},
},
}
for _, test := range testCases {
t.Run(test.fqdn, func(t *testing.T) {
client := NewClient(&Options{RecursiveNameservers: []string{test.fakeDNSServer.Build(t).String()}})
nss, err := client.lookupAuthoritativeNameservers(t.Context(), test.fqdn)
require.NoError(t, err)
sort.Strings(nss)
sort.Strings(test.expected)
assert.Equal(t, test.expected, nss)
})
}
}
func TestClient_lookupAuthoritativeNameservers_error(t *testing.T) {
testCases := []struct {
desc string
fqdn string
fakeDNSServer *dnsmock.Builder
error string
}{
{
desc: "NXDOMAIN",
fqdn: "example.invalid.",
fakeDNSServer: dnsmock.NewServer().
Query(". SOA", dnsmock.Error(dns.RcodeNameError)),
error: "could not find zone: [fqdn=example.invalid.] could not find the start of authority for 'example.invalid.' [question='invalid. IN SOA', code=NXDOMAIN]",
},
{
desc: "NS error",
fqdn: "example.com.",
fakeDNSServer: dnsmock.NewServer().
Query("example.com. SOA", dnsmock.SOA("")).
Query("example.com. NS", dnsmock.Error(dns.RcodeServerFailure)),
error: "[zone=example.com.] could not determine authoritative nameservers",
},
{
desc: "empty NS",
fqdn: "example.com.",
fakeDNSServer: dnsmock.NewServer().
Query("example.com. SOA", dnsmock.SOA("")).
Query("example.me NS", dnsmock.Noop),
error: "[zone=example.com.] could not determine authoritative nameservers",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
client := NewClient(&Options{RecursiveNameservers: []string{test.fakeDNSServer.Build(t).String()}})
_, err := client.lookupAuthoritativeNameservers(t.Context(), test.fqdn)
require.Error(t, err)
assert.EqualError(t, err, test.error)
})
}
}
func Test_getNameservers_ResolveConfServers(t *testing.T) {
testCases := []struct {
fixture string
expected []string
defaults []string
}{
{
fixture: "fixtures/resolv.conf.1",
defaults: []string{"127.0.0.1:53"},
expected: []string{"10.200.3.249:53", "10.200.3.250:5353", "[2001:4860:4860::8844]:53", "[10.0.0.1]:5353"},
},
{
fixture: "fixtures/resolv.conf.nonexistant",
defaults: []string{"127.0.0.1:53"},
expected: []string{"127.0.0.1:53"},
},
}
for _, test := range testCases {
t.Run(test.fixture, func(t *testing.T) {
result := getNameservers(test.fixture, test.defaults)
sort.Strings(result)
sort.Strings(test.expected)
assert.Equal(t, test.expected, result)
})
}
}

View file

@ -0,0 +1,8 @@
//go:build !windows
package dnsnew
import "time"
// dnsTimeout is used to override the default DNS timeout of 10 seconds.
const dnsTimeout = 10 * time.Second

View file

@ -0,0 +1,8 @@
//go:build windows
package dnsnew
import "time"
// dnsTimeout is used to override the default DNS timeout of 20 seconds.
const dnsTimeout = 20 * time.Second

View file

@ -0,0 +1,89 @@
package dnsnew
import (
"context"
"fmt"
"github.com/miekg/dns"
)
// FindZoneByFqdn determines the zone apex for the given fqdn
// by recursing up the domain labels until the nameserver returns a SOA record in the answer section.
func (c *Client) FindZoneByFqdn(ctx context.Context, fqdn string) (string, error) {
return c.FindZoneByFqdnCustom(ctx, fqdn, c.recursiveNameservers)
}
// FindZoneByFqdnCustom determines the zone apex for the given fqdn
// by recursing up the domain labels until the nameserver returns a SOA record in the answer section.
func (c *Client) FindZoneByFqdnCustom(ctx context.Context, fqdn string, nameservers []string) (string, error) {
soa, err := c.lookupSoaByFqdn(ctx, fqdn, nameservers)
if err != nil {
return "", fmt.Errorf("[fqdn=%s] %w", fqdn, err)
}
return soa.zone, nil
}
func (c *Client) lookupSoaByFqdn(ctx context.Context, fqdn string, nameservers []string) (*soaCacheEntry, error) {
c.muFqdnSoaCache.Lock()
defer c.muFqdnSoaCache.Unlock()
// Do we have it cached and is it still fresh?
if ent := c.fqdnSoaCache[fqdn]; ent != nil && !ent.isExpired() {
return ent, nil
}
ent, err := c.fetchSoaByFqdn(ctx, fqdn, nameservers)
if err != nil {
return nil, err
}
c.fqdnSoaCache[fqdn] = ent
return ent, nil
}
func (c *Client) fetchSoaByFqdn(ctx context.Context, fqdn string, nameservers []string) (*soaCacheEntry, error) {
var (
err error
r *dns.Msg
)
for domain := range DomainsSeq(fqdn) {
r, err = c.sendQueryCustom(ctx, domain, dns.TypeSOA, nameservers, true)
if err != nil {
continue
}
if r == nil {
continue
}
switch r.Rcode {
case dns.RcodeSuccess:
// Check if we got a SOA RR in the answer section
if len(r.Answer) == 0 {
continue
}
// CNAME records cannot/should not exist at the root of a zone.
// So we skip a domain when a CNAME is found.
if dnsMsgContainsCNAME(r) {
continue
}
for _, ans := range r.Answer {
if soa, ok := ans.(*dns.SOA); ok {
return newSoaCacheEntry(soa), nil
}
}
case dns.RcodeNameError:
// NXDOMAIN
default:
// Any response code other than NOERROR and NXDOMAIN is treated as error
return nil, &DNSError{Message: fmt.Sprintf("unexpected response for '%s'", domain), MsgOut: r}
}
}
return nil, &DNSError{Message: fmt.Sprintf("could not find the start of authority for '%s'", fqdn), MsgOut: r, Err: err}
}

View file

@ -0,0 +1,158 @@
package dnsnew
import (
"testing"
"github.com/go-acme/lego/v5/platform/tester/dnsmock"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type lookupSoaByFqdnTestCase struct {
desc string
fqdn string
zone string
primaryNs string
nameservers []string
expectedError string
}
func lookupSoaByFqdnTestCases(t *testing.T) []lookupSoaByFqdnTestCase {
t.Helper()
return []lookupSoaByFqdnTestCase{
{
desc: "domain is a CNAME",
fqdn: "mail.example.com.",
zone: "example.com.",
primaryNs: "ns1.example.com.",
nameservers: []string{
dnsmock.NewServer().
Query("mail.example.com. SOA", dnsmock.CNAME("example.com.")).
Query("example.com. SOA", dnsmock.SOA("")).
Build(t).
String(),
},
},
{
desc: "domain is a non-existent subdomain",
fqdn: "foo.example.com.",
zone: "example.com.",
primaryNs: "ns1.example.com.",
nameservers: []string{
dnsmock.NewServer().
Query("foo.example.com. SOA", dnsmock.Error(dns.RcodeNameError)).
Query("example.com. SOA", dnsmock.SOA("")).
Build(t).
String(),
},
},
{
desc: "domain is a eTLD",
fqdn: "example.com.ac.",
zone: "ac.",
primaryNs: "ns1.nic.ac.",
nameservers: []string{
dnsmock.NewServer().
Query("example.com.ac. SOA", dnsmock.Error(dns.RcodeNameError)).
Query("com.ac. SOA", dnsmock.Error(dns.RcodeNameError)).
Query("ac. SOA", dnsmock.SOA("")).
Build(t).
String(),
},
},
{
desc: "domain is a cross-zone CNAME",
fqdn: "cross-zone-example.example.com.",
zone: "example.com.",
primaryNs: "ns1.example.com.",
nameservers: []string{
dnsmock.NewServer().
Query("cross-zone-example.example.com. SOA", dnsmock.CNAME("example.org.")).
Query("example.com. SOA", dnsmock.SOA("")).
Build(t).
String(),
},
},
{
desc: "NXDOMAIN",
fqdn: "test.lego.invalid.",
zone: "lego.invalid.",
nameservers: []string{
dnsmock.NewServer().
Query("test.lego.invalid. SOA", dnsmock.Error(dns.RcodeNameError)).
Query("lego.invalid. SOA", dnsmock.Error(dns.RcodeNameError)).
Query("invalid. SOA", dnsmock.Error(dns.RcodeNameError)).
Build(t).
String(),
},
expectedError: `[fqdn=test.lego.invalid.] could not find the start of authority for 'test.lego.invalid.' [question='invalid. IN SOA', code=NXDOMAIN]`,
},
{
desc: "several non existent nameservers",
fqdn: "mail.example.com.",
zone: "example.com.",
primaryNs: "ns1.example.com.",
nameservers: []string{
":7053",
":8053",
dnsmock.NewServer().
Query("mail.example.com. SOA", dnsmock.CNAME("example.com.")).
Query("example.com. SOA", dnsmock.SOA("")).
Build(t).
String(),
},
},
{
desc: "only non-existent nameservers",
fqdn: "mail.example.com.",
zone: "example.com.",
nameservers: []string{":7053", ":8053", ":9053"},
// use only the start of the message because the port changes with each call: 127.0.0.1:XXXXX->127.0.0.1:7053.
expectedError: "[fqdn=mail.example.com.] could not find the start of authority for 'mail.example.com.': DNS call error: read udp ",
},
{
desc: "no nameservers",
fqdn: "test.example.com.",
zone: "example.com.",
nameservers: []string{},
expectedError: "[fqdn=test.example.com.] could not find the start of authority for 'test.example.com.': empty list of nameservers",
},
}
}
func TestClient_FindZoneByFqdnCustom(t *testing.T) {
for _, test := range lookupSoaByFqdnTestCases(t) {
t.Run(test.desc, func(t *testing.T) {
client := NewClient(nil)
zone, err := client.FindZoneByFqdnCustom(t.Context(), test.fqdn, test.nameservers)
if test.expectedError != "" {
require.Error(t, err)
assert.ErrorContains(t, err, test.expectedError)
} else {
require.NoError(t, err)
assert.Equal(t, test.zone, zone)
}
})
}
}
func TestClient_FindZoneByFqdn(t *testing.T) {
for _, test := range lookupSoaByFqdnTestCases(t) {
t.Run(test.desc, func(t *testing.T) {
client := NewClient(nil)
client.recursiveNameservers = test.nameservers
zone, err := client.FindZoneByFqdn(t.Context(), test.fqdn)
if test.expectedError != "" {
require.Error(t, err)
assert.ErrorContains(t, err, test.expectedError)
} else {
require.NoError(t, err)
assert.Equal(t, test.zone, zone)
}
})
}
}

View file

@ -0,0 +1,200 @@
package dnsnew
import (
"context"
"crypto/sha256"
"encoding/base64"
"fmt"
"os"
"strconv"
"strings"
"time"
"github.com/go-acme/lego/v5/acme"
"github.com/go-acme/lego/v5/acme/api"
"github.com/go-acme/lego/v5/challenge"
"github.com/go-acme/lego/v5/log"
"github.com/go-acme/lego/v5/platform/wait"
)
const (
// DefaultPropagationTimeout default propagation timeout.
DefaultPropagationTimeout = 60 * time.Second
// DefaultPollingInterval default polling interval.
DefaultPollingInterval = 2 * time.Second
// DefaultTTL default TTL.
DefaultTTL = 120
)
type ValidateFunc func(ctx context.Context, core *api.Core, domain string, chlng acme.Challenge) error
// Challenge implements the dns-01 challenge.
type Challenge struct {
core *api.Core
validate ValidateFunc
provider challenge.Provider
preCheck preCheck
}
func NewChallenge(core *api.Core, validate ValidateFunc, provider challenge.Provider, opts ...ChallengeOption) *Challenge {
chlg := &Challenge{
core: core,
validate: validate,
provider: provider,
preCheck: newPreCheck(),
}
for _, opt := range opts {
err := opt(chlg)
if err != nil {
log.Warn("Challenge option skipped.", "error", err)
}
}
return chlg
}
// PreSolve just submits the txt record to the dns provider.
// It does not validate record propagation or do anything at all with the ACME server.
func (c *Challenge) PreSolve(ctx context.Context, authz acme.Authorization) error {
domain := challenge.GetTargetedDomain(authz)
log.Info("acme: Preparing to solve DNS-01.", "domain", domain)
chlng, err := challenge.FindChallenge(challenge.DNS01, authz)
if err != nil {
return err
}
if c.provider == nil {
return fmt.Errorf("[%s] acme: no DNS Provider configured", domain)
}
// Generate the Key Authorization for the challenge
keyAuth, err := c.core.GetKeyAuthorization(chlng.Token)
if err != nil {
return err
}
err = c.provider.Present(authz.Identifier.Value, chlng.Token, keyAuth)
if err != nil {
return fmt.Errorf("[%s] acme: error presenting token: %w", domain, err)
}
return nil
}
func (c *Challenge) Solve(ctx context.Context, authz acme.Authorization) error {
domain := challenge.GetTargetedDomain(authz)
log.Info("acme: Trying to solve DNS-01.", "domain", domain)
chlng, err := challenge.FindChallenge(challenge.DNS01, authz)
if err != nil {
return err
}
// Generate the Key Authorization for the challenge
keyAuth, err := c.core.GetKeyAuthorization(chlng.Token)
if err != nil {
return err
}
info := GetChallengeInfo(ctx, authz.Identifier.Value, keyAuth)
var timeout, interval time.Duration
switch provider := c.provider.(type) {
case challenge.ProviderTimeout:
timeout, interval = provider.Timeout()
default:
timeout, interval = DefaultPropagationTimeout, DefaultPollingInterval
}
log.Info("acme: Checking DNS record propagation.",
"domain", domain, "nameservers", strings.Join(DefaultClient().recursiveNameservers, ","))
time.Sleep(interval)
err = wait.For("propagation", timeout, interval, func() (bool, error) {
stop, errP := c.preCheck.call(ctx, domain, info.EffectiveFQDN, info.Value)
if !stop || errP != nil {
log.Info("acme: Waiting for DNS record propagation.", "domain", domain)
}
return stop, errP
})
if err != nil {
return err
}
chlng.KeyAuthorization = keyAuth
return c.validate(ctx, c.core, domain, chlng)
}
// CleanUp cleans the challenge.
func (c *Challenge) CleanUp(authz acme.Authorization) error {
log.Info("acme: Cleaning DNS-01 challenge.", "domain", challenge.GetTargetedDomain(authz))
chlng, err := challenge.FindChallenge(challenge.DNS01, authz)
if err != nil {
return err
}
keyAuth, err := c.core.GetKeyAuthorization(chlng.Token)
if err != nil {
return err
}
return c.provider.CleanUp(authz.Identifier.Value, chlng.Token, keyAuth)
}
func (c *Challenge) Sequential() (bool, time.Duration) {
if p, ok := c.provider.(sequential); ok {
return ok, p.Sequential()
}
return false, 0
}
type sequential interface {
Sequential() time.Duration
}
// ChallengeInfo contains the information use to create the TXT record.
type ChallengeInfo struct {
// FQDN is the full-qualified challenge domain (i.e. `_acme-challenge.[domain].`)
FQDN string
// EffectiveFQDN contains the resulting FQDN after the CNAMEs resolutions.
EffectiveFQDN string
// Value contains the value for the TXT record.
Value string
}
// GetChallengeInfo returns information used to create a DNS record which will fulfill the `dns-01` challenge.
func GetChallengeInfo(ctx context.Context, domain, keyAuth string) ChallengeInfo {
keyAuthShaBytes := sha256.Sum256([]byte(keyAuth))
// base64URL encoding without padding
value := base64.RawURLEncoding.EncodeToString(keyAuthShaBytes[:sha256.Size])
ok, _ := strconv.ParseBool(os.Getenv("LEGO_DISABLE_CNAME_SUPPORT"))
return ChallengeInfo{
Value: value,
FQDN: getChallengeFQDN(ctx, domain, false),
EffectiveFQDN: getChallengeFQDN(ctx, domain, !ok),
}
}
func getChallengeFQDN(ctx context.Context, domain string, followCNAME bool) string {
fqdn := fmt.Sprintf("_acme-challenge.%s.", domain)
if !followCNAME {
return fqdn
}
return DefaultClient().lookupCNAME(ctx, fqdn)
}

View file

@ -0,0 +1,46 @@
package dnsnew
import (
"context"
"time"
)
type ChallengeOption func(*Challenge) error
// CondOption Conditional challenge option.
func CondOption(condition bool, opt ChallengeOption) ChallengeOption {
if !condition {
// NoOp options
return func(*Challenge) error {
return nil
}
}
return opt
}
func DisableAuthoritativeNssPropagationRequirement() ChallengeOption {
return func(chlg *Challenge) error {
chlg.preCheck.requireAuthoritativeNssPropagation = false
return nil
}
}
func RecursiveNSsPropagationRequirement() ChallengeOption {
return func(chlg *Challenge) error {
chlg.preCheck.requireRecursiveNssPropagation = true
return nil
}
}
func PropagationWait(wait time.Duration, skipCheck bool) ChallengeOption {
return WrapPreCheck(func(ctx context.Context, domain, fqdn, value string, check PreCheckFunc) (bool, error) {
time.Sleep(wait)
if skipCheck {
return true, nil
}
return check(ctx, fqdn, value)
})
}

View file

@ -0,0 +1,86 @@
package dnsnew
import (
"context"
"fmt"
"github.com/miekg/dns"
)
// PreCheckFunc checks DNS propagation before notifying ACME that the DNS challenge is ready.
type PreCheckFunc func(ctx context.Context, fqdn, value string) (bool, error)
// WrapPreCheckFunc wraps a PreCheckFunc in order to do extra operations before or after
// the main check, put it in a loop, etc.
type WrapPreCheckFunc func(ctx context.Context, domain, fqdn, value string, check PreCheckFunc) (bool, error)
// WrapPreCheck Allow to define checks before notifying ACME that the DNS challenge is ready.
func WrapPreCheck(wrap WrapPreCheckFunc) ChallengeOption {
return func(chlg *Challenge) error {
chlg.preCheck.checkFunc = wrap
return nil
}
}
type preCheck struct {
// checks DNS propagation before notifying ACME that the DNS challenge is ready.
checkFunc WrapPreCheckFunc
// require the TXT record to be propagated to all authoritative name servers
requireAuthoritativeNssPropagation bool
// require the TXT record to be propagated to all recursive name servers
requireRecursiveNssPropagation bool
}
func newPreCheck() preCheck {
return preCheck{
requireAuthoritativeNssPropagation: true,
}
}
func (p preCheck) call(ctx context.Context, domain, fqdn, value string) (bool, error) {
if p.checkFunc == nil {
return p.checkDNSPropagation(ctx, fqdn, value)
}
return p.checkFunc(ctx, domain, fqdn, value, p.checkDNSPropagation)
}
// checkDNSPropagation checks if the expected TXT record has been propagated to all authoritative nameservers.
func (p preCheck) checkDNSPropagation(ctx context.Context, fqdn, value string) (bool, error) {
client := DefaultClient()
// Initial attempt to resolve at the recursive NS (require getting CNAME)
r, err := client.sendQuery(ctx, fqdn, dns.TypeTXT, true)
if err != nil {
return false, fmt.Errorf("initial recursive nameserver: %w", err)
}
if r.Rcode == dns.RcodeSuccess {
fqdn = updateDomainWithCName(r, fqdn)
}
if p.requireRecursiveNssPropagation {
_, err = client.checkNameserversPropagation(ctx, fqdn, value, false)
if err != nil {
return false, fmt.Errorf("recursive nameservers: %w", err)
}
}
if !p.requireAuthoritativeNssPropagation {
return true, nil
}
authoritativeNss, err := client.lookupAuthoritativeNameservers(ctx, fqdn)
if err != nil {
return false, err
}
found, err := client.checkNameserversPropagationCustom(ctx, fqdn, value, authoritativeNss, true)
if err != nil {
return found, fmt.Errorf("authoritative nameservers: %w", err)
}
return found, nil
}

View file

@ -0,0 +1,78 @@
package dnsnew
import (
"testing"
"github.com/go-acme/lego/v5/platform/tester/dnsmock"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
func Test_preCheck_checkDNSPropagation(t *testing.T) {
mockDefault(t,
dnsmock.NewServer().
Query("acme-staging.api.example.com. SOA", dnsmock.Error(dns.RcodeNameError)).
Query("api.example.com. SOA", dnsmock.Error(dns.RcodeNameError)).
Query("example.com. SOA", dnsmock.SOA("")).
Query("example.com. NS",
dnsmock.Answer(
fakeNS("example.com.", "ns0.lego.localhost."),
fakeNS("example.com.", "ns1.lego.localhost."),
),
).
Build(t),
mockResolver(
dnsmock.NewServer().
Query("ns0.lego.localhost. A",
dnsmock.Answer(fakeA("ns0.lego.localhost.", "127.0.0.1"))).
Query("ns1.lego.localhost. A",
dnsmock.Answer(fakeA("ns1.lego.localhost.", "127.0.0.1"))).
Query("example.com. TXT",
dnsmock.Answer(
fakeTXT("example.com.", "one"),
fakeTXT("example.com.", "two"),
fakeTXT("example.com.", "three"),
fakeTXT("example.com.", "four"),
fakeTXT("example.com.", "five"),
),
).
Build(t),
),
)
testCases := []struct {
desc string
fqdn string
value string
expectedError string
}{
{
desc: "success",
fqdn: "example.com.",
value: "four",
},
{
desc: "no matching TXT record",
fqdn: "acme-staging.api.example.com.",
value: "fe01=",
expectedError: "did not return the expected TXT record [fqdn: acme-staging.api.example.com., value: fe01=]: one ,two ,three ,four ,five",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
DefaultClient().ClearFqdnCache()
check := newPreCheck()
ok, err := check.checkDNSPropagation(t.Context(), test.fqdn, test.value)
if test.expectedError != "" {
assert.ErrorContainsf(t, err, test.expectedError, "PreCheckDNS must fail for %s", test.fqdn)
assert.False(t, ok, "PreCheckDNS must fail for %s", test.fqdn)
} else {
assert.NoErrorf(t, err, "PreCheckDNS failed for %s", test.fqdn)
assert.True(t, ok, "PreCheckDNS failed for %s", test.fqdn)
}
})
}
}

View file

@ -0,0 +1,348 @@
package dnsnew
import (
"context"
"crypto/rand"
"crypto/rsa"
"errors"
"testing"
"time"
"github.com/go-acme/lego/v5/acme"
"github.com/go-acme/lego/v5/acme/api"
"github.com/go-acme/lego/v5/challenge"
"github.com/go-acme/lego/v5/platform/tester"
"github.com/go-acme/lego/v5/platform/tester/dnsmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type providerMock struct {
present, cleanUp error
}
func (p *providerMock) Present(domain, token, keyAuth string) error { return p.present }
func (p *providerMock) CleanUp(domain, token, keyAuth string) error { return p.cleanUp }
type providerTimeoutMock struct {
present, cleanUp error
timeout, interval time.Duration
}
func (p *providerTimeoutMock) Present(domain, token, keyAuth string) error { return p.present }
func (p *providerTimeoutMock) CleanUp(domain, token, keyAuth string) error { return p.cleanUp }
func (p *providerTimeoutMock) Timeout() (time.Duration, time.Duration) { return p.timeout, p.interval }
func TestChallenge_PreSolve(t *testing.T) {
server := tester.MockACMEServer().BuildHTTPS(t)
privateKey, err := rsa.GenerateKey(rand.Reader, 1024)
require.NoError(t, err)
core, err := api.New(server.Client(), "lego-test", server.URL+"/dir", "", privateKey)
require.NoError(t, err)
testCases := []struct {
desc string
validate ValidateFunc
preCheck WrapPreCheckFunc
provider challenge.Provider
expectError bool
}{
{
desc: "success",
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{},
},
{
desc: "validate fail",
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") },
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{
present: nil,
cleanUp: nil,
},
},
{
desc: "preCheck fail",
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) {
return false, errors.New("OOPS")
},
provider: &providerTimeoutMock{
timeout: 2 * time.Second,
interval: 500 * time.Millisecond,
},
},
{
desc: "present fail",
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{
present: errors.New("OOPS"),
},
expectError: true,
},
{
desc: "cleanUp fail",
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{
cleanUp: errors.New("OOPS"),
},
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
chlg := NewChallenge(core, test.validate, test.provider, WrapPreCheck(test.preCheck))
authz := acme.Authorization{
Identifier: acme.Identifier{
Value: "example.com",
},
Challenges: []acme.Challenge{
{Type: challenge.DNS01.String()},
},
}
err = chlg.PreSolve(t.Context(), authz)
if test.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
}
})
}
}
func TestChallenge_Solve(t *testing.T) {
mockDefault(t, dnsmock.NewServer().
Query("_acme-challenge.example.com. CNAME", dnsmock.Noop).
Build(t))
server := tester.MockACMEServer().BuildHTTPS(t)
privateKey, err := rsa.GenerateKey(rand.Reader, 1024)
require.NoError(t, err)
core, err := api.New(server.Client(), "lego-test", server.URL+"/dir", "", privateKey)
require.NoError(t, err)
testCases := []struct {
desc string
validate ValidateFunc
preCheck WrapPreCheckFunc
provider challenge.Provider
expectError bool
}{
{
desc: "success",
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{},
},
{
desc: "validate fail",
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") },
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{
present: nil,
cleanUp: nil,
},
expectError: true,
},
{
desc: "preCheck fail",
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) {
return false, errors.New("OOPS")
},
provider: &providerTimeoutMock{
timeout: 2 * time.Second,
interval: 500 * time.Millisecond,
},
expectError: true,
},
{
desc: "present fail",
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{
present: errors.New("OOPS"),
},
},
{
desc: "cleanUp fail",
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{
cleanUp: errors.New("OOPS"),
},
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
var options []ChallengeOption
if test.preCheck != nil {
options = append(options, WrapPreCheck(test.preCheck))
}
chlg := NewChallenge(core, test.validate, test.provider, options...)
authz := acme.Authorization{
Identifier: acme.Identifier{
Value: "example.com",
},
Challenges: []acme.Challenge{
{Type: challenge.DNS01.String()},
},
}
err = chlg.Solve(t.Context(), authz)
if test.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
}
})
}
}
func TestChallenge_CleanUp(t *testing.T) {
server := tester.MockACMEServer().BuildHTTPS(t)
privateKey, err := rsa.GenerateKey(rand.Reader, 1024)
require.NoError(t, err)
core, err := api.New(server.Client(), "lego-test", server.URL+"/dir", "", privateKey)
require.NoError(t, err)
testCases := []struct {
desc string
validate ValidateFunc
preCheck WrapPreCheckFunc
provider challenge.Provider
expectError bool
}{
{
desc: "success",
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{},
},
{
desc: "validate fail",
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") },
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{
present: nil,
cleanUp: nil,
},
},
{
desc: "preCheck fail",
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) {
return false, errors.New("OOPS")
},
provider: &providerTimeoutMock{
timeout: 2 * time.Second,
interval: 500 * time.Millisecond,
},
},
{
desc: "present fail",
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{
present: errors.New("OOPS"),
},
},
{
desc: "cleanUp fail",
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{
cleanUp: errors.New("OOPS"),
},
expectError: true,
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
chlg := NewChallenge(core, test.validate, test.provider, WrapPreCheck(test.preCheck))
authz := acme.Authorization{
Identifier: acme.Identifier{
Value: "example.com",
},
Challenges: []acme.Challenge{
{Type: challenge.DNS01.String()},
},
}
err = chlg.CleanUp(authz)
if test.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
}
})
}
}
func TestGetChallengeInfo(t *testing.T) {
mockDefault(t, dnsmock.NewServer().
Query("_acme-challenge.example.com. CNAME", dnsmock.Noop).
Build(t))
info := GetChallengeInfo(t.Context(), "example.com", "123")
expected := ChallengeInfo{
FQDN: "_acme-challenge.example.com.",
EffectiveFQDN: "_acme-challenge.example.com.",
Value: "pmWkWSBCL51Bfkhn79xPuKBKHz__H6B-mY6G9_eieuM",
}
assert.Equal(t, expected, info)
}
func TestGetChallengeInfo_CNAME(t *testing.T) {
mockDefault(t, dnsmock.NewServer().
Query("_acme-challenge.example.com. CNAME", dnsmock.CNAME("example.org.")).
Query("example.org. CNAME", dnsmock.Noop).
Build(t))
info := GetChallengeInfo(t.Context(), "example.com", "123")
expected := ChallengeInfo{
FQDN: "_acme-challenge.example.com.",
EffectiveFQDN: "example.org.",
Value: "pmWkWSBCL51Bfkhn79xPuKBKHz__H6B-mY6G9_eieuM",
}
assert.Equal(t, expected, info)
}
func TestGetChallengeInfo_CNAME_disabled(t *testing.T) {
mockDefault(t, dnsmock.NewServer().
// Never called when the env var works.
Query("_acme-challenge.example.com. CNAME", dnsmock.CNAME("example.org.")).
Build(t))
t.Setenv("LEGO_DISABLE_CNAME_SUPPORT", "true")
info := GetChallengeInfo(t.Context(), "example.com", "123")
expected := ChallengeInfo{
FQDN: "_acme-challenge.example.com.",
EffectiveFQDN: "_acme-challenge.example.com.",
Value: "pmWkWSBCL51Bfkhn79xPuKBKHz__H6B-mY6G9_eieuM",
}
assert.Equal(t, expected, info)
}

View file

@ -0,0 +1,24 @@
package dnsnew
import (
"fmt"
"strings"
"github.com/miekg/dns"
)
// ExtractSubDomain extracts the subdomain part from a domain and a zone.
func ExtractSubDomain(domain, zone string) (string, error) {
canonDomain := dns.Fqdn(domain)
canonZone := dns.Fqdn(zone)
if canonDomain == canonZone {
return "", fmt.Errorf("no subdomain because the domain and the zone are identical: %s", canonDomain)
}
if !dns.IsSubDomain(canonZone, canonDomain) {
return "", fmt.Errorf("%s is not a subdomain of %s", canonDomain, canonZone)
}
return strings.TrimSuffix(canonDomain, "."+canonZone), nil
}

View file

@ -0,0 +1,102 @@
package dnsnew
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExtractSubDomain(t *testing.T) {
testCases := []struct {
desc string
domain string
zone string
expected string
}{
{
desc: "no FQDN",
domain: "_acme-challenge.example.com",
zone: "example.com",
expected: "_acme-challenge",
},
{
desc: "no FQDN zone",
domain: "_acme-challenge.example.com.",
zone: "example.com",
expected: "_acme-challenge",
},
{
desc: "no FQDN domain",
domain: "_acme-challenge.example.com",
zone: "example.com.",
expected: "_acme-challenge",
},
{
desc: "FQDN",
domain: "_acme-challenge.example.com.",
zone: "example.com.",
expected: "_acme-challenge",
},
{
desc: "multi-level subdomain",
domain: "_acme-challenge.one.example.com.",
zone: "example.com.",
expected: "_acme-challenge.one",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
subDomain, err := ExtractSubDomain(test.domain, test.zone)
require.NoError(t, err)
assert.Equal(t, test.expected, subDomain)
})
}
}
func TestExtractSubDomain_errors(t *testing.T) {
testCases := []struct {
desc string
domain string
zone string
}{
{
desc: "same domain",
domain: "example.com",
zone: "example.com",
},
{
desc: "same domain, no FQDN zone",
domain: "example.com.",
zone: "example.com",
},
{
desc: "same domain, no FQDN domain",
domain: "example.com",
zone: "example.com.",
},
{
desc: "same domain, FQDN",
domain: "example.com.",
zone: "example.com.",
},
{
desc: "zone and domain are unrelated",
domain: "_acme-challenge.example.com",
zone: "example.org",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
_, err := ExtractSubDomain(test.domain, test.zone)
require.Error(t, err)
})
}
}

View file

@ -0,0 +1,5 @@
domain example.com
nameserver 10.200.3.249
nameserver 10.200.3.250:5353
nameserver 2001:4860:4860::8844
nameserver [10.0.0.1]:5353

47
challenge/dnsnew/fqdn.go Normal file
View file

@ -0,0 +1,47 @@
package dnsnew
import (
"iter"
"github.com/miekg/dns"
)
// UnFqdn converts the fqdn into a name removing the trailing dot.
func UnFqdn(name string) string {
n := len(name)
if n != 0 && name[n-1] == '.' {
return name[:n-1]
}
return name
}
// UnFqdnDomainsSeq generates a sequence of "unFQDNed" domain names derived from a domain (FQDN or not) in descending order.
func UnFqdnDomainsSeq(fqdn string) iter.Seq[string] {
return func(yield func(string) bool) {
if fqdn == "" {
return
}
for _, index := range dns.Split(fqdn) {
if !yield(UnFqdn(fqdn[index:])) {
return
}
}
}
}
// DomainsSeq generates a sequence of domain names derived from a domain (FQDN or not) in descending order.
func DomainsSeq(fqdn string) iter.Seq[string] {
return func(yield func(string) bool) {
if fqdn == "" {
return
}
for _, index := range dns.Split(fqdn) {
if !yield(fqdn[index:]) {
return
}
}
}
}

View file

@ -0,0 +1,137 @@
package dnsnew
import (
"slices"
"testing"
"github.com/stretchr/testify/assert"
)
func TestUnFqdn(t *testing.T) {
testCases := []struct {
desc string
fqdn string
expected string
}{
{
desc: "simple",
fqdn: "foo.example.",
expected: "foo.example",
},
{
desc: "already domain",
fqdn: "foo.example",
expected: "foo.example",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
domain := UnFqdn(test.fqdn)
assert.Equal(t, test.expected, domain)
})
}
}
func TestUnFqdnDomainsSeq(t *testing.T) {
testCases := []struct {
desc string
fqdn string
expected []string
}{
{
desc: "empty",
fqdn: "",
expected: nil,
},
{
desc: "TLD",
fqdn: "com",
expected: []string{"com"},
},
{
desc: "2 levels",
fqdn: "example.com",
expected: []string{"example.com", "com"},
},
{
desc: "3 levels",
fqdn: "foo.example.com",
expected: []string{"foo.example.com", "example.com", "com"},
},
}
for _, test := range testCases {
for name, suffix := range map[string]string{"": "", " FQDN": "."} { //nolint:gocritic
t.Run(test.desc+name, func(t *testing.T) {
t.Parallel()
actual := slices.Collect(UnFqdnDomainsSeq(test.fqdn + suffix))
assert.Equal(t, test.expected, actual)
})
}
}
}
func TestDomainsSeq(t *testing.T) {
testCases := []struct {
desc string
fqdn string
expected []string
}{
{
desc: "empty",
fqdn: "",
expected: nil,
},
{
desc: "empty FQDN",
fqdn: ".",
expected: nil,
},
{
desc: "TLD FQDN",
fqdn: "com",
expected: []string{"com"},
},
{
desc: "TLD",
fqdn: "com.",
expected: []string{"com."},
},
{
desc: "2 levels",
fqdn: "example.com",
expected: []string{"example.com", "com"},
},
{
desc: "2 levels FQDN",
fqdn: "example.com.",
expected: []string{"example.com.", "com."},
},
{
desc: "3 levels",
fqdn: "foo.example.com",
expected: []string{"foo.example.com", "example.com", "com"},
},
{
desc: "3 levels FQDN",
fqdn: "foo.example.com.",
expected: []string{"foo.example.com.", "example.com.", "com."},
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
actual := slices.Collect(DomainsSeq(test.fqdn))
assert.Equal(t, test.expected, actual)
})
}
}

View file

@ -0,0 +1,78 @@
package dnsnew
import (
"context"
"net"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/require"
)
func fakeNS(name, ns string) *dns.NS {
return &dns.NS{
Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 172800},
Ns: ns,
}
}
func fakeA(name, ip string) *dns.A {
return &dns.A{
Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 10},
A: net.ParseIP(ip),
}
}
func fakeTXT(name, value string) *dns.TXT {
return &dns.TXT{
Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 10},
Txt: []string{value},
}
}
// mockResolver modifies the default DNS resolver to use a custom network address during the test execution.
// IMPORTANT: it modifying std global variables.
func mockResolver(authoritativeNS net.Addr) func(t *testing.T, client *Client) {
return func(t *testing.T, client *Client) {
t.Helper()
_, port, err := net.SplitHostPort(authoritativeNS.String())
require.NoError(t, err)
client.authoritativeNSPort = port
originalResolver := net.DefaultResolver
t.Cleanup(func() {
net.DefaultResolver = originalResolver
})
net.DefaultResolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{Timeout: 1 * time.Second}
return d.DialContext(ctx, network, authoritativeNS.String())
},
}
}
}
func mockDefault(t *testing.T, recursiveNS net.Addr, opts ...func(t *testing.T, client *Client)) {
t.Helper()
backup := DefaultClient()
t.Cleanup(func() {
SetDefaultClient(backup)
})
client := NewClient(&Options{RecursiveNameservers: []string{recursiveNS.String()}})
for _, opt := range opts {
opt(t, client)
}
SetDefaultClient(client)
}