mirror of
https://github.com/go-acme/lego
synced 2026-03-14 14:35:48 +01:00
feat: new package dnsnew
This commit is contained in:
parent
2f2f587b03
commit
a4499b729c
23 changed files with 2157 additions and 0 deletions
155
challenge/dnsnew/client.go
Normal file
155
challenge/dnsnew/client.go
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
package dnsnew
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/go-acme/lego/v5/challenge"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const defaultResolvConf = "/etc/resolv.conf"
|
||||
|
||||
var defaultClient atomic.Pointer[Client]
|
||||
|
||||
func init() {
|
||||
defaultClient.Store(NewClient(nil))
|
||||
}
|
||||
|
||||
func DefaultClient() *Client { return defaultClient.Load() }
|
||||
|
||||
func SetDefaultClient(c *Client) {
|
||||
defaultClient.Store(c)
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
RecursiveNameservers []string
|
||||
Timeout time.Duration
|
||||
TCPOnly bool
|
||||
NetworkStack challenge.NetworkStack
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
recursiveNameservers []string
|
||||
|
||||
// authoritativeNSPort used by authoritative NS.
|
||||
// For testing purposes only.
|
||||
authoritativeNSPort string
|
||||
|
||||
tcpClient *dns.Client
|
||||
udpClient *dns.Client
|
||||
tcpOnly bool
|
||||
|
||||
fqdnSoaCache map[string]*soaCacheEntry
|
||||
muFqdnSoaCache sync.Mutex
|
||||
}
|
||||
|
||||
func NewClient(opts *Options) *Client {
|
||||
if opts == nil {
|
||||
tcpOnly, _ := strconv.ParseBool(os.Getenv("LEGO_EXPERIMENTAL_DNS_TCP_ONLY"))
|
||||
opts = &Options{TCPOnly: tcpOnly}
|
||||
}
|
||||
|
||||
if len(opts.RecursiveNameservers) == 0 {
|
||||
defaultNameservers := []string{
|
||||
"google-public-dns-a.google.com:53",
|
||||
"google-public-dns-b.google.com:53",
|
||||
}
|
||||
|
||||
opts.RecursiveNameservers = getNameservers(defaultResolvConf, defaultNameservers)
|
||||
}
|
||||
|
||||
if opts.Timeout == 0 {
|
||||
opts.Timeout = dnsTimeout
|
||||
}
|
||||
|
||||
return &Client{
|
||||
recursiveNameservers: opts.RecursiveNameservers,
|
||||
authoritativeNSPort: "53",
|
||||
tcpClient: &dns.Client{
|
||||
Net: opts.NetworkStack.Network("tcp"),
|
||||
Timeout: opts.Timeout,
|
||||
},
|
||||
udpClient: &dns.Client{
|
||||
Net: opts.NetworkStack.Network("udp"),
|
||||
Timeout: opts.Timeout,
|
||||
},
|
||||
tcpOnly: opts.TCPOnly,
|
||||
fqdnSoaCache: map[string]*soaCacheEntry{},
|
||||
muFqdnSoaCache: sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) sendQuery(ctx context.Context, fqdn string, rtype uint16, recursive bool) (*dns.Msg, error) {
|
||||
return c.sendQueryCustom(ctx, fqdn, rtype, c.recursiveNameservers, recursive)
|
||||
}
|
||||
|
||||
func (c *Client) sendQueryCustom(ctx context.Context, fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) {
|
||||
m := createDNSMsg(fqdn, rtype, recursive)
|
||||
|
||||
if len(nameservers) == 0 {
|
||||
return nil, &DNSError{Message: "empty list of nameservers"}
|
||||
}
|
||||
|
||||
var (
|
||||
r *dns.Msg
|
||||
err error
|
||||
errAll error
|
||||
)
|
||||
|
||||
for _, ns := range nameservers {
|
||||
r, err = c.exchange(ctx, m, ns)
|
||||
if err == nil && len(r.Answer) > 0 {
|
||||
break
|
||||
}
|
||||
|
||||
errAll = errors.Join(errAll, err)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return r, errAll
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (c *Client) exchange(ctx context.Context, m *dns.Msg, ns string) (*dns.Msg, error) {
|
||||
if c.tcpOnly {
|
||||
r, _, err := c.tcpClient.ExchangeContext(ctx, m, ns)
|
||||
if err != nil {
|
||||
return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
r, _, err := c.udpClient.ExchangeContext(ctx, m, ns)
|
||||
|
||||
if r != nil && r.Truncated {
|
||||
// If the TCP request succeeds, the "err" will reset to nil
|
||||
r, _, err = c.tcpClient.ExchangeContext(ctx, m, ns)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(fqdn, rtype)
|
||||
m.SetEdns0(4096, false)
|
||||
|
||||
if !recursive {
|
||||
m.RecursionDesired = false
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
34
challenge/dnsnew/client_cache.go
Normal file
34
challenge/dnsnew/client_cache.go
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
package dnsnew
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// soaCacheEntry holds a cached SOA record (only selected fields).
|
||||
type soaCacheEntry struct {
|
||||
zone string // zone apex (a domain name)
|
||||
primaryNs string // primary nameserver for the zone apex
|
||||
expires time.Time // time when this cache entry should be evicted
|
||||
}
|
||||
|
||||
func newSoaCacheEntry(soa *dns.SOA) *soaCacheEntry {
|
||||
return &soaCacheEntry{
|
||||
zone: soa.Hdr.Name,
|
||||
primaryNs: soa.Ns,
|
||||
expires: time.Now().Add(time.Duration(soa.Refresh) * time.Second),
|
||||
}
|
||||
}
|
||||
|
||||
// isExpired checks whether a cache entry should be considered expired.
|
||||
func (cache *soaCacheEntry) isExpired() bool {
|
||||
return time.Now().After(cache.expires)
|
||||
}
|
||||
|
||||
// ClearFqdnCache clears the cache of fqdn to zone mappings. Primarily used in testing.
|
||||
func (c *Client) ClearFqdnCache() {
|
||||
c.muFqdnSoaCache.Lock()
|
||||
c.fqdnSoaCache = map[string]*soaCacheEntry{}
|
||||
c.muFqdnSoaCache.Unlock()
|
||||
}
|
||||
57
challenge/dnsnew/client_cname.go
Normal file
57
challenge/dnsnew/client_cname.go
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
package dnsnew
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/go-acme/lego/v5/log"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func (c *Client) lookupCNAME(ctx context.Context, fqdn string) string {
|
||||
// recursion counter so it doesn't spin out of control
|
||||
for range 50 {
|
||||
// Keep following CNAMEs
|
||||
r, err := c.sendQuery(ctx, fqdn, dns.TypeCNAME, true)
|
||||
|
||||
if err != nil || r.Rcode != dns.RcodeSuccess {
|
||||
// TODO(ldez): logs the error in v5
|
||||
// No more CNAME records to follow, exit
|
||||
break
|
||||
}
|
||||
|
||||
// Check if the domain has CNAME then use that
|
||||
cname := updateDomainWithCName(r, fqdn)
|
||||
if cname == fqdn {
|
||||
break
|
||||
}
|
||||
|
||||
log.Info("Found CNAME entry.", "fqdn", fqdn, "cname", cname)
|
||||
|
||||
fqdn = cname
|
||||
}
|
||||
|
||||
return fqdn
|
||||
}
|
||||
|
||||
// Update FQDN with CNAME if any.
|
||||
func updateDomainWithCName(r *dns.Msg, fqdn string) string {
|
||||
for _, rr := range r.Answer {
|
||||
if cn, ok := rr.(*dns.CNAME); ok {
|
||||
if strings.EqualFold(cn.Hdr.Name, fqdn) {
|
||||
return cn.Target
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fqdn
|
||||
}
|
||||
|
||||
// dnsMsgContainsCNAME checks for a CNAME answer in msg.
|
||||
func dnsMsgContainsCNAME(msg *dns.Msg) bool {
|
||||
return slices.ContainsFunc(msg.Answer, func(rr dns.RR) bool {
|
||||
_, ok := rr.(*dns.CNAME)
|
||||
return ok
|
||||
})
|
||||
}
|
||||
35
challenge/dnsnew/client_cname_test.go
Normal file
35
challenge/dnsnew/client_cname_test.go
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
package dnsnew
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_updateDomainWithCName_caseInsensitive(t *testing.T) {
|
||||
qname := "_acme-challenge.uppercase-test.example.com."
|
||||
cnameTarget := "_acme-challenge.uppercase-test.cname-target.example.com."
|
||||
|
||||
msg := &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Authoritative: true,
|
||||
},
|
||||
Answer: []dns.RR{
|
||||
&dns.CNAME{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: strings.ToUpper(qname), // CNAME names are case-insensitive
|
||||
Rrtype: dns.TypeCNAME,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 3600,
|
||||
},
|
||||
Target: cnameTarget,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fqdn := updateDomainWithCName(msg, qname)
|
||||
|
||||
assert.Equal(t, cnameTarget, fqdn)
|
||||
}
|
||||
64
challenge/dnsnew/client_error.go
Normal file
64
challenge/dnsnew/client_error.go
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
package dnsnew
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// DNSError error related to DNS calls.
|
||||
type DNSError struct {
|
||||
Message string
|
||||
NS string
|
||||
MsgIn *dns.Msg
|
||||
MsgOut *dns.Msg
|
||||
Err error
|
||||
}
|
||||
|
||||
func (d *DNSError) Error() string {
|
||||
var details []string
|
||||
if d.NS != "" {
|
||||
details = append(details, "ns="+d.NS)
|
||||
}
|
||||
|
||||
if d.MsgIn != nil && len(d.MsgIn.Question) > 0 {
|
||||
details = append(details, fmt.Sprintf("question='%s'", formatQuestions(d.MsgIn.Question)))
|
||||
}
|
||||
|
||||
if d.MsgOut != nil {
|
||||
if d.MsgIn == nil || len(d.MsgIn.Question) == 0 {
|
||||
details = append(details, fmt.Sprintf("question='%s'", formatQuestions(d.MsgOut.Question)))
|
||||
}
|
||||
|
||||
details = append(details, "code="+dns.RcodeToString[d.MsgOut.Rcode])
|
||||
}
|
||||
|
||||
msg := "DNS error"
|
||||
if d.Message != "" {
|
||||
msg = d.Message
|
||||
}
|
||||
|
||||
if d.Err != nil {
|
||||
msg += ": " + d.Err.Error()
|
||||
}
|
||||
|
||||
if len(details) > 0 {
|
||||
msg += " [" + strings.Join(details, ", ") + "]"
|
||||
}
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
func (d *DNSError) Unwrap() error {
|
||||
return d.Err
|
||||
}
|
||||
|
||||
func formatQuestions(questions []dns.Question) string {
|
||||
var parts []string
|
||||
for _, question := range questions {
|
||||
parts = append(parts, strings.ReplaceAll(strings.TrimPrefix(question.String(), ";"), "\t", " "))
|
||||
}
|
||||
|
||||
return strings.Join(parts, ";")
|
||||
}
|
||||
75
challenge/dnsnew/client_error_test.go
Normal file
75
challenge/dnsnew/client_error_test.go
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
package dnsnew
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDNSError_Error(t *testing.T) {
|
||||
msgIn := createDNSMsg("example.com.", dns.TypeTXT, true)
|
||||
|
||||
msgOut := createDNSMsg("example.org.", dns.TypeSOA, true)
|
||||
msgOut.Rcode = dns.RcodeNameError
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
err *DNSError
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "empty error",
|
||||
err: &DNSError{},
|
||||
expected: "DNS error",
|
||||
},
|
||||
{
|
||||
desc: "all fields",
|
||||
err: &DNSError{
|
||||
Message: "Oops",
|
||||
NS: "example.com.",
|
||||
MsgIn: msgIn,
|
||||
MsgOut: msgOut,
|
||||
Err: errors.New("I did it again"),
|
||||
},
|
||||
expected: "Oops: I did it again [ns=example.com., question='example.com. IN TXT', code=NXDOMAIN]",
|
||||
},
|
||||
{
|
||||
desc: "only NS",
|
||||
err: &DNSError{
|
||||
NS: "example.com.",
|
||||
},
|
||||
expected: "DNS error [ns=example.com.]",
|
||||
},
|
||||
{
|
||||
desc: "only MsgIn",
|
||||
err: &DNSError{
|
||||
MsgIn: msgIn,
|
||||
},
|
||||
expected: "DNS error [question='example.com. IN TXT']",
|
||||
},
|
||||
{
|
||||
desc: "only MsgOut",
|
||||
err: &DNSError{
|
||||
MsgOut: msgOut,
|
||||
},
|
||||
expected: "DNS error [question='example.org. IN SOA', code=NXDOMAIN]",
|
||||
},
|
||||
{
|
||||
desc: "only Err",
|
||||
err: &DNSError{
|
||||
Err: errors.New("I did it again"),
|
||||
},
|
||||
expected: "DNS error: I did it again",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.EqualError(t, test.err, test.expected)
|
||||
})
|
||||
}
|
||||
}
|
||||
107
challenge/dnsnew/client_nameservers.go
Normal file
107
challenge/dnsnew/client_nameservers.go
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
package dnsnew
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// checkNameserversPropagation queries each of the recursive nameservers for the expected TXT record.
|
||||
func (c *Client) checkNameserversPropagation(ctx context.Context, fqdn, value string, addPort bool) (bool, error) {
|
||||
return c.checkNameserversPropagationCustom(ctx, fqdn, value, c.recursiveNameservers, addPort)
|
||||
}
|
||||
|
||||
// checkNameserversPropagationCustom queries each of the given nameservers for the expected TXT record.
|
||||
func (c *Client) checkNameserversPropagationCustom(ctx context.Context, fqdn, value string, nameservers []string, addPort bool) (bool, error) {
|
||||
for _, ns := range nameservers {
|
||||
if addPort {
|
||||
ns = net.JoinHostPort(ns, c.authoritativeNSPort)
|
||||
}
|
||||
|
||||
r, err := c.sendQueryCustom(ctx, fqdn, dns.TypeTXT, []string{ns}, false)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if r.Rcode != dns.RcodeSuccess {
|
||||
return false, fmt.Errorf("NS %s returned %s for %s", ns, dns.RcodeToString[r.Rcode], fqdn)
|
||||
}
|
||||
|
||||
var records []string
|
||||
|
||||
var found bool
|
||||
|
||||
for _, rr := range r.Answer {
|
||||
if txt, ok := rr.(*dns.TXT); ok {
|
||||
record := strings.Join(txt.Txt, "")
|
||||
|
||||
records = append(records, record)
|
||||
if record == value {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return false, fmt.Errorf("NS %s did not return the expected TXT record [fqdn: %s, value: %s]: %s", ns, fqdn, value, strings.Join(records, " ,"))
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// lookupAuthoritativeNameservers returns the authoritative nameservers for the given fqdn.
|
||||
func (c *Client) lookupAuthoritativeNameservers(ctx context.Context, fqdn string) ([]string, error) {
|
||||
var authoritativeNss []string
|
||||
|
||||
zone, err := c.FindZoneByFqdn(ctx, fqdn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not find zone: %w", err)
|
||||
}
|
||||
|
||||
r, err := c.sendQuery(ctx, zone, dns.TypeNS, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NS call failed: %w", err)
|
||||
}
|
||||
|
||||
for _, rr := range r.Answer {
|
||||
if ns, ok := rr.(*dns.NS); ok {
|
||||
authoritativeNss = append(authoritativeNss, strings.ToLower(ns.Ns))
|
||||
}
|
||||
}
|
||||
|
||||
if len(authoritativeNss) > 0 {
|
||||
return authoritativeNss, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("[zone=%s] could not determine authoritative nameservers", zone)
|
||||
}
|
||||
|
||||
// getNameservers attempts to get systems nameservers before falling back to the defaults.
|
||||
func getNameservers(path string, defaults []string) []string {
|
||||
config, err := dns.ClientConfigFromFile(path)
|
||||
if err != nil || len(config.Servers) == 0 {
|
||||
return defaults
|
||||
}
|
||||
|
||||
return parseNameservers(config.Servers)
|
||||
}
|
||||
|
||||
func parseNameservers(servers []string) []string {
|
||||
var resolvers []string
|
||||
|
||||
for _, resolver := range servers {
|
||||
// ensure all servers have a port number
|
||||
if _, _, err := net.SplitHostPort(resolver); err != nil {
|
||||
resolvers = append(resolvers, net.JoinHostPort(resolver, "53"))
|
||||
} else {
|
||||
resolvers = append(resolvers, resolver)
|
||||
}
|
||||
}
|
||||
|
||||
return resolvers
|
||||
}
|
||||
216
challenge/dnsnew/client_nameservers_test.go
Normal file
216
challenge/dnsnew/client_nameservers_test.go
Normal file
|
|
@ -0,0 +1,216 @@
|
|||
package dnsnew
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/go-acme/lego/v5/platform/tester/dnsmock"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClient_checkNameserversPropagationCustom_authoritativeNss(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
fqdn, value string
|
||||
fakeDNSServer *dnsmock.Builder
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
desc: "TXT RR w/ expected value",
|
||||
// NS: asnums.routeviews.org.
|
||||
fqdn: "8.8.8.8.asn.routeviews.org.",
|
||||
value: "151698.8.8.024",
|
||||
fakeDNSServer: dnsmock.NewServer().
|
||||
Query("8.8.8.8.asn.routeviews.org. TXT",
|
||||
dnsmock.Answer(
|
||||
fakeTXT("8.8.8.8.asn.routeviews.org.", "151698.8.8.024"),
|
||||
),
|
||||
),
|
||||
},
|
||||
{
|
||||
desc: "TXT RR w/ unexpected value",
|
||||
// NS: asnums.routeviews.org.
|
||||
fqdn: "8.8.8.8.asn.routeviews.org.",
|
||||
value: "fe01=",
|
||||
fakeDNSServer: dnsmock.NewServer().
|
||||
Query("8.8.8.8.asn.routeviews.org. TXT",
|
||||
dnsmock.Answer(
|
||||
fakeTXT("8.8.8.8.asn.routeviews.org.", "15169"),
|
||||
fakeTXT("8.8.8.8.asn.routeviews.org.", "8.8.8.0"),
|
||||
fakeTXT("8.8.8.8.asn.routeviews.org.", "24"),
|
||||
),
|
||||
),
|
||||
expectedError: "did not return the expected TXT record [fqdn: 8.8.8.8.asn.routeviews.org., value: fe01=]: 15169 ,8.8.8.0 ,24",
|
||||
},
|
||||
{
|
||||
desc: "No TXT RR",
|
||||
// NS: ns2.google.com.
|
||||
fqdn: "ns1.google.com.",
|
||||
value: "fe01=",
|
||||
fakeDNSServer: dnsmock.NewServer().
|
||||
Query("ns1.google.com.", dnsmock.Noop),
|
||||
expectedError: "did not return the expected TXT record [fqdn: ns1.google.com., value: fe01=]: ",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
client := NewClient(nil)
|
||||
|
||||
addr := test.fakeDNSServer.Build(t)
|
||||
|
||||
ok, err := client.checkNameserversPropagationCustom(t.Context(), test.fqdn, test.value, []string{addr.String()}, false)
|
||||
|
||||
if test.expectedError == "" {
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, test.expectedError)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_lookupAuthoritativeNameservers_OK(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
fakeDNSServer *dnsmock.Builder
|
||||
fqdn string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
fqdn: "en.wikipedia.org.localhost.",
|
||||
fakeDNSServer: dnsmock.NewServer().
|
||||
Query("en.wikipedia.org.localhost SOA", dnsmock.CNAME("dyna.wikimedia.org.localhost")).
|
||||
Query("wikipedia.org.localhost SOA", dnsmock.SOA("")).
|
||||
Query("wikipedia.org.localhost NS",
|
||||
dnsmock.Answer(
|
||||
fakeNS("wikipedia.org.localhost.", "ns0.wikimedia.org.localhost."),
|
||||
fakeNS("wikipedia.org.localhost.", "ns1.wikimedia.org.localhost."),
|
||||
fakeNS("wikipedia.org.localhost.", "ns2.wikimedia.org.localhost."),
|
||||
),
|
||||
),
|
||||
expected: []string{"ns0.wikimedia.org.localhost.", "ns1.wikimedia.org.localhost.", "ns2.wikimedia.org.localhost."},
|
||||
},
|
||||
{
|
||||
fqdn: "www.google.com.localhost.",
|
||||
fakeDNSServer: dnsmock.NewServer().
|
||||
Query("www.google.com.localhost. SOA", dnsmock.Noop).
|
||||
Query("google.com.localhost. SOA", dnsmock.SOA("")).
|
||||
Query("google.com.localhost. NS",
|
||||
dnsmock.Answer(
|
||||
fakeNS("google.com.localhost.", "ns1.google.com.localhost."),
|
||||
fakeNS("google.com.localhost.", "ns2.google.com.localhost."),
|
||||
fakeNS("google.com.localhost.", "ns3.google.com.localhost."),
|
||||
fakeNS("google.com.localhost.", "ns4.google.com.localhost."),
|
||||
),
|
||||
),
|
||||
expected: []string{"ns1.google.com.localhost.", "ns2.google.com.localhost.", "ns3.google.com.localhost.", "ns4.google.com.localhost."},
|
||||
},
|
||||
{
|
||||
fqdn: "mail.proton.me.localhost.",
|
||||
fakeDNSServer: dnsmock.NewServer().
|
||||
Query("mail.proton.me.localhost. SOA", dnsmock.Noop).
|
||||
Query("proton.me.localhost. SOA", dnsmock.SOA("")).
|
||||
Query("proton.me.localhost. NS",
|
||||
dnsmock.Answer(
|
||||
fakeNS("proton.me.localhost.", "ns1.proton.me.localhost."),
|
||||
fakeNS("proton.me.localhost.", "ns2.proton.me.localhost."),
|
||||
fakeNS("proton.me.localhost.", "ns3.proton.me.localhost."),
|
||||
),
|
||||
),
|
||||
expected: []string{"ns1.proton.me.localhost.", "ns2.proton.me.localhost.", "ns3.proton.me.localhost."},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.fqdn, func(t *testing.T) {
|
||||
client := NewClient(&Options{RecursiveNameservers: []string{test.fakeDNSServer.Build(t).String()}})
|
||||
|
||||
nss, err := client.lookupAuthoritativeNameservers(t.Context(), test.fqdn)
|
||||
require.NoError(t, err)
|
||||
|
||||
sort.Strings(nss)
|
||||
sort.Strings(test.expected)
|
||||
|
||||
assert.Equal(t, test.expected, nss)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_lookupAuthoritativeNameservers_error(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
fqdn string
|
||||
fakeDNSServer *dnsmock.Builder
|
||||
error string
|
||||
}{
|
||||
{
|
||||
desc: "NXDOMAIN",
|
||||
fqdn: "example.invalid.",
|
||||
fakeDNSServer: dnsmock.NewServer().
|
||||
Query(". SOA", dnsmock.Error(dns.RcodeNameError)),
|
||||
error: "could not find zone: [fqdn=example.invalid.] could not find the start of authority for 'example.invalid.' [question='invalid. IN SOA', code=NXDOMAIN]",
|
||||
},
|
||||
{
|
||||
desc: "NS error",
|
||||
fqdn: "example.com.",
|
||||
fakeDNSServer: dnsmock.NewServer().
|
||||
Query("example.com. SOA", dnsmock.SOA("")).
|
||||
Query("example.com. NS", dnsmock.Error(dns.RcodeServerFailure)),
|
||||
error: "[zone=example.com.] could not determine authoritative nameservers",
|
||||
},
|
||||
{
|
||||
desc: "empty NS",
|
||||
fqdn: "example.com.",
|
||||
fakeDNSServer: dnsmock.NewServer().
|
||||
Query("example.com. SOA", dnsmock.SOA("")).
|
||||
Query("example.me NS", dnsmock.Noop),
|
||||
error: "[zone=example.com.] could not determine authoritative nameservers",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
client := NewClient(&Options{RecursiveNameservers: []string{test.fakeDNSServer.Build(t).String()}})
|
||||
|
||||
_, err := client.lookupAuthoritativeNameservers(t.Context(), test.fqdn)
|
||||
require.Error(t, err)
|
||||
assert.EqualError(t, err, test.error)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_getNameservers_ResolveConfServers(t *testing.T) {
|
||||
testCases := []struct {
|
||||
fixture string
|
||||
expected []string
|
||||
defaults []string
|
||||
}{
|
||||
{
|
||||
fixture: "fixtures/resolv.conf.1",
|
||||
defaults: []string{"127.0.0.1:53"},
|
||||
expected: []string{"10.200.3.249:53", "10.200.3.250:5353", "[2001:4860:4860::8844]:53", "[10.0.0.1]:5353"},
|
||||
},
|
||||
{
|
||||
fixture: "fixtures/resolv.conf.nonexistant",
|
||||
defaults: []string{"127.0.0.1:53"},
|
||||
expected: []string{"127.0.0.1:53"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.fixture, func(t *testing.T) {
|
||||
result := getNameservers(test.fixture, test.defaults)
|
||||
|
||||
sort.Strings(result)
|
||||
sort.Strings(test.expected)
|
||||
|
||||
assert.Equal(t, test.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
8
challenge/dnsnew/client_timeout_unix.go
Normal file
8
challenge/dnsnew/client_timeout_unix.go
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
//go:build !windows
|
||||
|
||||
package dnsnew
|
||||
|
||||
import "time"
|
||||
|
||||
// dnsTimeout is used to override the default DNS timeout of 10 seconds.
|
||||
const dnsTimeout = 10 * time.Second
|
||||
8
challenge/dnsnew/client_timeout_windows.go
Normal file
8
challenge/dnsnew/client_timeout_windows.go
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
//go:build windows
|
||||
|
||||
package dnsnew
|
||||
|
||||
import "time"
|
||||
|
||||
// dnsTimeout is used to override the default DNS timeout of 20 seconds.
|
||||
const dnsTimeout = 20 * time.Second
|
||||
89
challenge/dnsnew/client_zone.go
Normal file
89
challenge/dnsnew/client_zone.go
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
package dnsnew
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// FindZoneByFqdn determines the zone apex for the given fqdn
|
||||
// by recursing up the domain labels until the nameserver returns a SOA record in the answer section.
|
||||
func (c *Client) FindZoneByFqdn(ctx context.Context, fqdn string) (string, error) {
|
||||
return c.FindZoneByFqdnCustom(ctx, fqdn, c.recursiveNameservers)
|
||||
}
|
||||
|
||||
// FindZoneByFqdnCustom determines the zone apex for the given fqdn
|
||||
// by recursing up the domain labels until the nameserver returns a SOA record in the answer section.
|
||||
func (c *Client) FindZoneByFqdnCustom(ctx context.Context, fqdn string, nameservers []string) (string, error) {
|
||||
soa, err := c.lookupSoaByFqdn(ctx, fqdn, nameservers)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("[fqdn=%s] %w", fqdn, err)
|
||||
}
|
||||
|
||||
return soa.zone, nil
|
||||
}
|
||||
|
||||
func (c *Client) lookupSoaByFqdn(ctx context.Context, fqdn string, nameservers []string) (*soaCacheEntry, error) {
|
||||
c.muFqdnSoaCache.Lock()
|
||||
defer c.muFqdnSoaCache.Unlock()
|
||||
|
||||
// Do we have it cached and is it still fresh?
|
||||
if ent := c.fqdnSoaCache[fqdn]; ent != nil && !ent.isExpired() {
|
||||
return ent, nil
|
||||
}
|
||||
|
||||
ent, err := c.fetchSoaByFqdn(ctx, fqdn, nameservers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.fqdnSoaCache[fqdn] = ent
|
||||
|
||||
return ent, nil
|
||||
}
|
||||
|
||||
func (c *Client) fetchSoaByFqdn(ctx context.Context, fqdn string, nameservers []string) (*soaCacheEntry, error) {
|
||||
var (
|
||||
err error
|
||||
r *dns.Msg
|
||||
)
|
||||
|
||||
for domain := range DomainsSeq(fqdn) {
|
||||
r, err = c.sendQueryCustom(ctx, domain, dns.TypeSOA, nameservers, true)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if r == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch r.Rcode {
|
||||
case dns.RcodeSuccess:
|
||||
// Check if we got a SOA RR in the answer section
|
||||
if len(r.Answer) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// CNAME records cannot/should not exist at the root of a zone.
|
||||
// So we skip a domain when a CNAME is found.
|
||||
if dnsMsgContainsCNAME(r) {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, ans := range r.Answer {
|
||||
if soa, ok := ans.(*dns.SOA); ok {
|
||||
return newSoaCacheEntry(soa), nil
|
||||
}
|
||||
}
|
||||
case dns.RcodeNameError:
|
||||
// NXDOMAIN
|
||||
default:
|
||||
// Any response code other than NOERROR and NXDOMAIN is treated as error
|
||||
return nil, &DNSError{Message: fmt.Sprintf("unexpected response for '%s'", domain), MsgOut: r}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, &DNSError{Message: fmt.Sprintf("could not find the start of authority for '%s'", fqdn), MsgOut: r, Err: err}
|
||||
}
|
||||
158
challenge/dnsnew/client_zone_test.go
Normal file
158
challenge/dnsnew/client_zone_test.go
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
package dnsnew
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/go-acme/lego/v5/platform/tester/dnsmock"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type lookupSoaByFqdnTestCase struct {
|
||||
desc string
|
||||
fqdn string
|
||||
zone string
|
||||
primaryNs string
|
||||
nameservers []string
|
||||
expectedError string
|
||||
}
|
||||
|
||||
func lookupSoaByFqdnTestCases(t *testing.T) []lookupSoaByFqdnTestCase {
|
||||
t.Helper()
|
||||
|
||||
return []lookupSoaByFqdnTestCase{
|
||||
{
|
||||
desc: "domain is a CNAME",
|
||||
fqdn: "mail.example.com.",
|
||||
zone: "example.com.",
|
||||
primaryNs: "ns1.example.com.",
|
||||
nameservers: []string{
|
||||
dnsmock.NewServer().
|
||||
Query("mail.example.com. SOA", dnsmock.CNAME("example.com.")).
|
||||
Query("example.com. SOA", dnsmock.SOA("")).
|
||||
Build(t).
|
||||
String(),
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "domain is a non-existent subdomain",
|
||||
fqdn: "foo.example.com.",
|
||||
zone: "example.com.",
|
||||
primaryNs: "ns1.example.com.",
|
||||
nameservers: []string{
|
||||
dnsmock.NewServer().
|
||||
Query("foo.example.com. SOA", dnsmock.Error(dns.RcodeNameError)).
|
||||
Query("example.com. SOA", dnsmock.SOA("")).
|
||||
Build(t).
|
||||
String(),
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "domain is a eTLD",
|
||||
fqdn: "example.com.ac.",
|
||||
zone: "ac.",
|
||||
primaryNs: "ns1.nic.ac.",
|
||||
nameservers: []string{
|
||||
dnsmock.NewServer().
|
||||
Query("example.com.ac. SOA", dnsmock.Error(dns.RcodeNameError)).
|
||||
Query("com.ac. SOA", dnsmock.Error(dns.RcodeNameError)).
|
||||
Query("ac. SOA", dnsmock.SOA("")).
|
||||
Build(t).
|
||||
String(),
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "domain is a cross-zone CNAME",
|
||||
fqdn: "cross-zone-example.example.com.",
|
||||
zone: "example.com.",
|
||||
primaryNs: "ns1.example.com.",
|
||||
nameservers: []string{
|
||||
dnsmock.NewServer().
|
||||
Query("cross-zone-example.example.com. SOA", dnsmock.CNAME("example.org.")).
|
||||
Query("example.com. SOA", dnsmock.SOA("")).
|
||||
Build(t).
|
||||
String(),
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "NXDOMAIN",
|
||||
fqdn: "test.lego.invalid.",
|
||||
zone: "lego.invalid.",
|
||||
nameservers: []string{
|
||||
dnsmock.NewServer().
|
||||
Query("test.lego.invalid. SOA", dnsmock.Error(dns.RcodeNameError)).
|
||||
Query("lego.invalid. SOA", dnsmock.Error(dns.RcodeNameError)).
|
||||
Query("invalid. SOA", dnsmock.Error(dns.RcodeNameError)).
|
||||
Build(t).
|
||||
String(),
|
||||
},
|
||||
expectedError: `[fqdn=test.lego.invalid.] could not find the start of authority for 'test.lego.invalid.' [question='invalid. IN SOA', code=NXDOMAIN]`,
|
||||
},
|
||||
{
|
||||
desc: "several non existent nameservers",
|
||||
fqdn: "mail.example.com.",
|
||||
zone: "example.com.",
|
||||
primaryNs: "ns1.example.com.",
|
||||
nameservers: []string{
|
||||
":7053",
|
||||
":8053",
|
||||
dnsmock.NewServer().
|
||||
Query("mail.example.com. SOA", dnsmock.CNAME("example.com.")).
|
||||
Query("example.com. SOA", dnsmock.SOA("")).
|
||||
Build(t).
|
||||
String(),
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "only non-existent nameservers",
|
||||
fqdn: "mail.example.com.",
|
||||
zone: "example.com.",
|
||||
nameservers: []string{":7053", ":8053", ":9053"},
|
||||
// use only the start of the message because the port changes with each call: 127.0.0.1:XXXXX->127.0.0.1:7053.
|
||||
expectedError: "[fqdn=mail.example.com.] could not find the start of authority for 'mail.example.com.': DNS call error: read udp ",
|
||||
},
|
||||
{
|
||||
desc: "no nameservers",
|
||||
fqdn: "test.example.com.",
|
||||
zone: "example.com.",
|
||||
nameservers: []string{},
|
||||
expectedError: "[fqdn=test.example.com.] could not find the start of authority for 'test.example.com.': empty list of nameservers",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_FindZoneByFqdnCustom(t *testing.T) {
|
||||
for _, test := range lookupSoaByFqdnTestCases(t) {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
client := NewClient(nil)
|
||||
|
||||
zone, err := client.FindZoneByFqdnCustom(t.Context(), test.fqdn, test.nameservers)
|
||||
if test.expectedError != "" {
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, test.expectedError)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, test.zone, zone)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_FindZoneByFqdn(t *testing.T) {
|
||||
for _, test := range lookupSoaByFqdnTestCases(t) {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
client := NewClient(nil)
|
||||
client.recursiveNameservers = test.nameservers
|
||||
|
||||
zone, err := client.FindZoneByFqdn(t.Context(), test.fqdn)
|
||||
if test.expectedError != "" {
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, test.expectedError)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, test.zone, zone)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
200
challenge/dnsnew/dns_challenge.go
Normal file
200
challenge/dnsnew/dns_challenge.go
Normal file
|
|
@ -0,0 +1,200 @@
|
|||
package dnsnew
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-acme/lego/v5/acme"
|
||||
"github.com/go-acme/lego/v5/acme/api"
|
||||
"github.com/go-acme/lego/v5/challenge"
|
||||
"github.com/go-acme/lego/v5/log"
|
||||
"github.com/go-acme/lego/v5/platform/wait"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultPropagationTimeout default propagation timeout.
|
||||
DefaultPropagationTimeout = 60 * time.Second
|
||||
|
||||
// DefaultPollingInterval default polling interval.
|
||||
DefaultPollingInterval = 2 * time.Second
|
||||
|
||||
// DefaultTTL default TTL.
|
||||
DefaultTTL = 120
|
||||
)
|
||||
|
||||
type ValidateFunc func(ctx context.Context, core *api.Core, domain string, chlng acme.Challenge) error
|
||||
|
||||
// Challenge implements the dns-01 challenge.
|
||||
type Challenge struct {
|
||||
core *api.Core
|
||||
validate ValidateFunc
|
||||
provider challenge.Provider
|
||||
preCheck preCheck
|
||||
}
|
||||
|
||||
func NewChallenge(core *api.Core, validate ValidateFunc, provider challenge.Provider, opts ...ChallengeOption) *Challenge {
|
||||
chlg := &Challenge{
|
||||
core: core,
|
||||
validate: validate,
|
||||
provider: provider,
|
||||
preCheck: newPreCheck(),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
err := opt(chlg)
|
||||
if err != nil {
|
||||
log.Warn("Challenge option skipped.", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return chlg
|
||||
}
|
||||
|
||||
// PreSolve just submits the txt record to the dns provider.
|
||||
// It does not validate record propagation or do anything at all with the ACME server.
|
||||
func (c *Challenge) PreSolve(ctx context.Context, authz acme.Authorization) error {
|
||||
domain := challenge.GetTargetedDomain(authz)
|
||||
log.Info("acme: Preparing to solve DNS-01.", "domain", domain)
|
||||
|
||||
chlng, err := challenge.FindChallenge(challenge.DNS01, authz)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.provider == nil {
|
||||
return fmt.Errorf("[%s] acme: no DNS Provider configured", domain)
|
||||
}
|
||||
|
||||
// Generate the Key Authorization for the challenge
|
||||
keyAuth, err := c.core.GetKeyAuthorization(chlng.Token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.provider.Present(authz.Identifier.Value, chlng.Token, keyAuth)
|
||||
if err != nil {
|
||||
return fmt.Errorf("[%s] acme: error presenting token: %w", domain, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Challenge) Solve(ctx context.Context, authz acme.Authorization) error {
|
||||
domain := challenge.GetTargetedDomain(authz)
|
||||
log.Info("acme: Trying to solve DNS-01.", "domain", domain)
|
||||
|
||||
chlng, err := challenge.FindChallenge(challenge.DNS01, authz)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Generate the Key Authorization for the challenge
|
||||
keyAuth, err := c.core.GetKeyAuthorization(chlng.Token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
info := GetChallengeInfo(ctx, authz.Identifier.Value, keyAuth)
|
||||
|
||||
var timeout, interval time.Duration
|
||||
|
||||
switch provider := c.provider.(type) {
|
||||
case challenge.ProviderTimeout:
|
||||
timeout, interval = provider.Timeout()
|
||||
default:
|
||||
timeout, interval = DefaultPropagationTimeout, DefaultPollingInterval
|
||||
}
|
||||
|
||||
log.Info("acme: Checking DNS record propagation.",
|
||||
"domain", domain, "nameservers", strings.Join(DefaultClient().recursiveNameservers, ","))
|
||||
|
||||
time.Sleep(interval)
|
||||
|
||||
err = wait.For("propagation", timeout, interval, func() (bool, error) {
|
||||
stop, errP := c.preCheck.call(ctx, domain, info.EffectiveFQDN, info.Value)
|
||||
if !stop || errP != nil {
|
||||
log.Info("acme: Waiting for DNS record propagation.", "domain", domain)
|
||||
}
|
||||
|
||||
return stop, errP
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
chlng.KeyAuthorization = keyAuth
|
||||
|
||||
return c.validate(ctx, c.core, domain, chlng)
|
||||
}
|
||||
|
||||
// CleanUp cleans the challenge.
|
||||
func (c *Challenge) CleanUp(authz acme.Authorization) error {
|
||||
log.Info("acme: Cleaning DNS-01 challenge.", "domain", challenge.GetTargetedDomain(authz))
|
||||
|
||||
chlng, err := challenge.FindChallenge(challenge.DNS01, authz)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
keyAuth, err := c.core.GetKeyAuthorization(chlng.Token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.provider.CleanUp(authz.Identifier.Value, chlng.Token, keyAuth)
|
||||
}
|
||||
|
||||
func (c *Challenge) Sequential() (bool, time.Duration) {
|
||||
if p, ok := c.provider.(sequential); ok {
|
||||
return ok, p.Sequential()
|
||||
}
|
||||
|
||||
return false, 0
|
||||
}
|
||||
|
||||
type sequential interface {
|
||||
Sequential() time.Duration
|
||||
}
|
||||
|
||||
// ChallengeInfo contains the information use to create the TXT record.
|
||||
type ChallengeInfo struct {
|
||||
// FQDN is the full-qualified challenge domain (i.e. `_acme-challenge.[domain].`)
|
||||
FQDN string
|
||||
|
||||
// EffectiveFQDN contains the resulting FQDN after the CNAMEs resolutions.
|
||||
EffectiveFQDN string
|
||||
|
||||
// Value contains the value for the TXT record.
|
||||
Value string
|
||||
}
|
||||
|
||||
// GetChallengeInfo returns information used to create a DNS record which will fulfill the `dns-01` challenge.
|
||||
func GetChallengeInfo(ctx context.Context, domain, keyAuth string) ChallengeInfo {
|
||||
keyAuthShaBytes := sha256.Sum256([]byte(keyAuth))
|
||||
// base64URL encoding without padding
|
||||
value := base64.RawURLEncoding.EncodeToString(keyAuthShaBytes[:sha256.Size])
|
||||
|
||||
ok, _ := strconv.ParseBool(os.Getenv("LEGO_DISABLE_CNAME_SUPPORT"))
|
||||
|
||||
return ChallengeInfo{
|
||||
Value: value,
|
||||
FQDN: getChallengeFQDN(ctx, domain, false),
|
||||
EffectiveFQDN: getChallengeFQDN(ctx, domain, !ok),
|
||||
}
|
||||
}
|
||||
|
||||
func getChallengeFQDN(ctx context.Context, domain string, followCNAME bool) string {
|
||||
fqdn := fmt.Sprintf("_acme-challenge.%s.", domain)
|
||||
|
||||
if !followCNAME {
|
||||
return fqdn
|
||||
}
|
||||
|
||||
return DefaultClient().lookupCNAME(ctx, fqdn)
|
||||
}
|
||||
46
challenge/dnsnew/dns_challenge_options.go
Normal file
46
challenge/dnsnew/dns_challenge_options.go
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
package dnsnew
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ChallengeOption func(*Challenge) error
|
||||
|
||||
// CondOption Conditional challenge option.
|
||||
func CondOption(condition bool, opt ChallengeOption) ChallengeOption {
|
||||
if !condition {
|
||||
// NoOp options
|
||||
return func(*Challenge) error {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return opt
|
||||
}
|
||||
|
||||
func DisableAuthoritativeNssPropagationRequirement() ChallengeOption {
|
||||
return func(chlg *Challenge) error {
|
||||
chlg.preCheck.requireAuthoritativeNssPropagation = false
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func RecursiveNSsPropagationRequirement() ChallengeOption {
|
||||
return func(chlg *Challenge) error {
|
||||
chlg.preCheck.requireRecursiveNssPropagation = true
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func PropagationWait(wait time.Duration, skipCheck bool) ChallengeOption {
|
||||
return WrapPreCheck(func(ctx context.Context, domain, fqdn, value string, check PreCheckFunc) (bool, error) {
|
||||
time.Sleep(wait)
|
||||
|
||||
if skipCheck {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return check(ctx, fqdn, value)
|
||||
})
|
||||
}
|
||||
86
challenge/dnsnew/dns_challenge_precheck.go
Normal file
86
challenge/dnsnew/dns_challenge_precheck.go
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
package dnsnew
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// PreCheckFunc checks DNS propagation before notifying ACME that the DNS challenge is ready.
|
||||
type PreCheckFunc func(ctx context.Context, fqdn, value string) (bool, error)
|
||||
|
||||
// WrapPreCheckFunc wraps a PreCheckFunc in order to do extra operations before or after
|
||||
// the main check, put it in a loop, etc.
|
||||
type WrapPreCheckFunc func(ctx context.Context, domain, fqdn, value string, check PreCheckFunc) (bool, error)
|
||||
|
||||
// WrapPreCheck Allow to define checks before notifying ACME that the DNS challenge is ready.
|
||||
func WrapPreCheck(wrap WrapPreCheckFunc) ChallengeOption {
|
||||
return func(chlg *Challenge) error {
|
||||
chlg.preCheck.checkFunc = wrap
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type preCheck struct {
|
||||
// checks DNS propagation before notifying ACME that the DNS challenge is ready.
|
||||
checkFunc WrapPreCheckFunc
|
||||
|
||||
// require the TXT record to be propagated to all authoritative name servers
|
||||
requireAuthoritativeNssPropagation bool
|
||||
|
||||
// require the TXT record to be propagated to all recursive name servers
|
||||
requireRecursiveNssPropagation bool
|
||||
}
|
||||
|
||||
func newPreCheck() preCheck {
|
||||
return preCheck{
|
||||
requireAuthoritativeNssPropagation: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (p preCheck) call(ctx context.Context, domain, fqdn, value string) (bool, error) {
|
||||
if p.checkFunc == nil {
|
||||
return p.checkDNSPropagation(ctx, fqdn, value)
|
||||
}
|
||||
|
||||
return p.checkFunc(ctx, domain, fqdn, value, p.checkDNSPropagation)
|
||||
}
|
||||
|
||||
// checkDNSPropagation checks if the expected TXT record has been propagated to all authoritative nameservers.
|
||||
func (p preCheck) checkDNSPropagation(ctx context.Context, fqdn, value string) (bool, error) {
|
||||
client := DefaultClient()
|
||||
|
||||
// Initial attempt to resolve at the recursive NS (require getting CNAME)
|
||||
r, err := client.sendQuery(ctx, fqdn, dns.TypeTXT, true)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("initial recursive nameserver: %w", err)
|
||||
}
|
||||
|
||||
if r.Rcode == dns.RcodeSuccess {
|
||||
fqdn = updateDomainWithCName(r, fqdn)
|
||||
}
|
||||
|
||||
if p.requireRecursiveNssPropagation {
|
||||
_, err = client.checkNameserversPropagation(ctx, fqdn, value, false)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("recursive nameservers: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !p.requireAuthoritativeNssPropagation {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
authoritativeNss, err := client.lookupAuthoritativeNameservers(ctx, fqdn)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
found, err := client.checkNameserversPropagationCustom(ctx, fqdn, value, authoritativeNss, true)
|
||||
if err != nil {
|
||||
return found, fmt.Errorf("authoritative nameservers: %w", err)
|
||||
}
|
||||
|
||||
return found, nil
|
||||
}
|
||||
78
challenge/dnsnew/dns_challenge_precheck_test.go
Normal file
78
challenge/dnsnew/dns_challenge_precheck_test.go
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
package dnsnew
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/go-acme/lego/v5/platform/tester/dnsmock"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_preCheck_checkDNSPropagation(t *testing.T) {
|
||||
mockDefault(t,
|
||||
dnsmock.NewServer().
|
||||
Query("acme-staging.api.example.com. SOA", dnsmock.Error(dns.RcodeNameError)).
|
||||
Query("api.example.com. SOA", dnsmock.Error(dns.RcodeNameError)).
|
||||
Query("example.com. SOA", dnsmock.SOA("")).
|
||||
Query("example.com. NS",
|
||||
dnsmock.Answer(
|
||||
fakeNS("example.com.", "ns0.lego.localhost."),
|
||||
fakeNS("example.com.", "ns1.lego.localhost."),
|
||||
),
|
||||
).
|
||||
Build(t),
|
||||
mockResolver(
|
||||
dnsmock.NewServer().
|
||||
Query("ns0.lego.localhost. A",
|
||||
dnsmock.Answer(fakeA("ns0.lego.localhost.", "127.0.0.1"))).
|
||||
Query("ns1.lego.localhost. A",
|
||||
dnsmock.Answer(fakeA("ns1.lego.localhost.", "127.0.0.1"))).
|
||||
Query("example.com. TXT",
|
||||
dnsmock.Answer(
|
||||
fakeTXT("example.com.", "one"),
|
||||
fakeTXT("example.com.", "two"),
|
||||
fakeTXT("example.com.", "three"),
|
||||
fakeTXT("example.com.", "four"),
|
||||
fakeTXT("example.com.", "five"),
|
||||
),
|
||||
).
|
||||
Build(t),
|
||||
),
|
||||
)
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
fqdn string
|
||||
value string
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
desc: "success",
|
||||
fqdn: "example.com.",
|
||||
value: "four",
|
||||
},
|
||||
{
|
||||
desc: "no matching TXT record",
|
||||
fqdn: "acme-staging.api.example.com.",
|
||||
value: "fe01=",
|
||||
expectedError: "did not return the expected TXT record [fqdn: acme-staging.api.example.com., value: fe01=]: one ,two ,three ,four ,five",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
DefaultClient().ClearFqdnCache()
|
||||
|
||||
check := newPreCheck()
|
||||
|
||||
ok, err := check.checkDNSPropagation(t.Context(), test.fqdn, test.value)
|
||||
if test.expectedError != "" {
|
||||
assert.ErrorContainsf(t, err, test.expectedError, "PreCheckDNS must fail for %s", test.fqdn)
|
||||
assert.False(t, ok, "PreCheckDNS must fail for %s", test.fqdn)
|
||||
} else {
|
||||
assert.NoErrorf(t, err, "PreCheckDNS failed for %s", test.fqdn)
|
||||
assert.True(t, ok, "PreCheckDNS failed for %s", test.fqdn)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
348
challenge/dnsnew/dns_challenge_test.go
Normal file
348
challenge/dnsnew/dns_challenge_test.go
Normal file
|
|
@ -0,0 +1,348 @@
|
|||
package dnsnew
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-acme/lego/v5/acme"
|
||||
"github.com/go-acme/lego/v5/acme/api"
|
||||
"github.com/go-acme/lego/v5/challenge"
|
||||
"github.com/go-acme/lego/v5/platform/tester"
|
||||
"github.com/go-acme/lego/v5/platform/tester/dnsmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type providerMock struct {
|
||||
present, cleanUp error
|
||||
}
|
||||
|
||||
func (p *providerMock) Present(domain, token, keyAuth string) error { return p.present }
|
||||
func (p *providerMock) CleanUp(domain, token, keyAuth string) error { return p.cleanUp }
|
||||
|
||||
type providerTimeoutMock struct {
|
||||
present, cleanUp error
|
||||
timeout, interval time.Duration
|
||||
}
|
||||
|
||||
func (p *providerTimeoutMock) Present(domain, token, keyAuth string) error { return p.present }
|
||||
func (p *providerTimeoutMock) CleanUp(domain, token, keyAuth string) error { return p.cleanUp }
|
||||
func (p *providerTimeoutMock) Timeout() (time.Duration, time.Duration) { return p.timeout, p.interval }
|
||||
|
||||
func TestChallenge_PreSolve(t *testing.T) {
|
||||
server := tester.MockACMEServer().BuildHTTPS(t)
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
require.NoError(t, err)
|
||||
|
||||
core, err := api.New(server.Client(), "lego-test", server.URL+"/dir", "", privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
validate ValidateFunc
|
||||
preCheck WrapPreCheckFunc
|
||||
provider challenge.Provider
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
desc: "success",
|
||||
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil },
|
||||
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
|
||||
provider: &providerMock{},
|
||||
},
|
||||
{
|
||||
desc: "validate fail",
|
||||
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") },
|
||||
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
|
||||
provider: &providerMock{
|
||||
present: nil,
|
||||
cleanUp: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "preCheck fail",
|
||||
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil },
|
||||
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) {
|
||||
return false, errors.New("OOPS")
|
||||
},
|
||||
provider: &providerTimeoutMock{
|
||||
timeout: 2 * time.Second,
|
||||
interval: 500 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "present fail",
|
||||
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil },
|
||||
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
|
||||
provider: &providerMock{
|
||||
present: errors.New("OOPS"),
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
desc: "cleanUp fail",
|
||||
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil },
|
||||
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
|
||||
provider: &providerMock{
|
||||
cleanUp: errors.New("OOPS"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
chlg := NewChallenge(core, test.validate, test.provider, WrapPreCheck(test.preCheck))
|
||||
|
||||
authz := acme.Authorization{
|
||||
Identifier: acme.Identifier{
|
||||
Value: "example.com",
|
||||
},
|
||||
Challenges: []acme.Challenge{
|
||||
{Type: challenge.DNS01.String()},
|
||||
},
|
||||
}
|
||||
|
||||
err = chlg.PreSolve(t.Context(), authz)
|
||||
if test.expectError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChallenge_Solve(t *testing.T) {
|
||||
mockDefault(t, dnsmock.NewServer().
|
||||
Query("_acme-challenge.example.com. CNAME", dnsmock.Noop).
|
||||
Build(t))
|
||||
|
||||
server := tester.MockACMEServer().BuildHTTPS(t)
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
require.NoError(t, err)
|
||||
|
||||
core, err := api.New(server.Client(), "lego-test", server.URL+"/dir", "", privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
validate ValidateFunc
|
||||
preCheck WrapPreCheckFunc
|
||||
provider challenge.Provider
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
desc: "success",
|
||||
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil },
|
||||
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
|
||||
provider: &providerMock{},
|
||||
},
|
||||
{
|
||||
desc: "validate fail",
|
||||
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") },
|
||||
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
|
||||
provider: &providerMock{
|
||||
present: nil,
|
||||
cleanUp: nil,
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
desc: "preCheck fail",
|
||||
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil },
|
||||
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) {
|
||||
return false, errors.New("OOPS")
|
||||
},
|
||||
provider: &providerTimeoutMock{
|
||||
timeout: 2 * time.Second,
|
||||
interval: 500 * time.Millisecond,
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
desc: "present fail",
|
||||
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil },
|
||||
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
|
||||
provider: &providerMock{
|
||||
present: errors.New("OOPS"),
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "cleanUp fail",
|
||||
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil },
|
||||
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
|
||||
provider: &providerMock{
|
||||
cleanUp: errors.New("OOPS"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
var options []ChallengeOption
|
||||
if test.preCheck != nil {
|
||||
options = append(options, WrapPreCheck(test.preCheck))
|
||||
}
|
||||
|
||||
chlg := NewChallenge(core, test.validate, test.provider, options...)
|
||||
|
||||
authz := acme.Authorization{
|
||||
Identifier: acme.Identifier{
|
||||
Value: "example.com",
|
||||
},
|
||||
Challenges: []acme.Challenge{
|
||||
{Type: challenge.DNS01.String()},
|
||||
},
|
||||
}
|
||||
|
||||
err = chlg.Solve(t.Context(), authz)
|
||||
if test.expectError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChallenge_CleanUp(t *testing.T) {
|
||||
server := tester.MockACMEServer().BuildHTTPS(t)
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
require.NoError(t, err)
|
||||
|
||||
core, err := api.New(server.Client(), "lego-test", server.URL+"/dir", "", privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
validate ValidateFunc
|
||||
preCheck WrapPreCheckFunc
|
||||
provider challenge.Provider
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
desc: "success",
|
||||
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil },
|
||||
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
|
||||
provider: &providerMock{},
|
||||
},
|
||||
{
|
||||
desc: "validate fail",
|
||||
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") },
|
||||
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
|
||||
provider: &providerMock{
|
||||
present: nil,
|
||||
cleanUp: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "preCheck fail",
|
||||
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil },
|
||||
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) {
|
||||
return false, errors.New("OOPS")
|
||||
},
|
||||
provider: &providerTimeoutMock{
|
||||
timeout: 2 * time.Second,
|
||||
interval: 500 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "present fail",
|
||||
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil },
|
||||
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
|
||||
provider: &providerMock{
|
||||
present: errors.New("OOPS"),
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "cleanUp fail",
|
||||
validate: func(_ context.Context, _ *api.Core, _ string, _ acme.Challenge) error { return nil },
|
||||
preCheck: func(_ context.Context, _, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
|
||||
provider: &providerMock{
|
||||
cleanUp: errors.New("OOPS"),
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
chlg := NewChallenge(core, test.validate, test.provider, WrapPreCheck(test.preCheck))
|
||||
|
||||
authz := acme.Authorization{
|
||||
Identifier: acme.Identifier{
|
||||
Value: "example.com",
|
||||
},
|
||||
Challenges: []acme.Challenge{
|
||||
{Type: challenge.DNS01.String()},
|
||||
},
|
||||
}
|
||||
|
||||
err = chlg.CleanUp(authz)
|
||||
if test.expectError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetChallengeInfo(t *testing.T) {
|
||||
mockDefault(t, dnsmock.NewServer().
|
||||
Query("_acme-challenge.example.com. CNAME", dnsmock.Noop).
|
||||
Build(t))
|
||||
|
||||
info := GetChallengeInfo(t.Context(), "example.com", "123")
|
||||
|
||||
expected := ChallengeInfo{
|
||||
FQDN: "_acme-challenge.example.com.",
|
||||
EffectiveFQDN: "_acme-challenge.example.com.",
|
||||
Value: "pmWkWSBCL51Bfkhn79xPuKBKHz__H6B-mY6G9_eieuM",
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, info)
|
||||
}
|
||||
|
||||
func TestGetChallengeInfo_CNAME(t *testing.T) {
|
||||
mockDefault(t, dnsmock.NewServer().
|
||||
Query("_acme-challenge.example.com. CNAME", dnsmock.CNAME("example.org.")).
|
||||
Query("example.org. CNAME", dnsmock.Noop).
|
||||
Build(t))
|
||||
|
||||
info := GetChallengeInfo(t.Context(), "example.com", "123")
|
||||
|
||||
expected := ChallengeInfo{
|
||||
FQDN: "_acme-challenge.example.com.",
|
||||
EffectiveFQDN: "example.org.",
|
||||
Value: "pmWkWSBCL51Bfkhn79xPuKBKHz__H6B-mY6G9_eieuM",
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, info)
|
||||
}
|
||||
|
||||
func TestGetChallengeInfo_CNAME_disabled(t *testing.T) {
|
||||
mockDefault(t, dnsmock.NewServer().
|
||||
// Never called when the env var works.
|
||||
Query("_acme-challenge.example.com. CNAME", dnsmock.CNAME("example.org.")).
|
||||
Build(t))
|
||||
|
||||
t.Setenv("LEGO_DISABLE_CNAME_SUPPORT", "true")
|
||||
|
||||
info := GetChallengeInfo(t.Context(), "example.com", "123")
|
||||
|
||||
expected := ChallengeInfo{
|
||||
FQDN: "_acme-challenge.example.com.",
|
||||
EffectiveFQDN: "_acme-challenge.example.com.",
|
||||
Value: "pmWkWSBCL51Bfkhn79xPuKBKHz__H6B-mY6G9_eieuM",
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, info)
|
||||
}
|
||||
24
challenge/dnsnew/domain.go
Normal file
24
challenge/dnsnew/domain.go
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
package dnsnew
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// ExtractSubDomain extracts the subdomain part from a domain and a zone.
|
||||
func ExtractSubDomain(domain, zone string) (string, error) {
|
||||
canonDomain := dns.Fqdn(domain)
|
||||
canonZone := dns.Fqdn(zone)
|
||||
|
||||
if canonDomain == canonZone {
|
||||
return "", fmt.Errorf("no subdomain because the domain and the zone are identical: %s", canonDomain)
|
||||
}
|
||||
|
||||
if !dns.IsSubDomain(canonZone, canonDomain) {
|
||||
return "", fmt.Errorf("%s is not a subdomain of %s", canonDomain, canonZone)
|
||||
}
|
||||
|
||||
return strings.TrimSuffix(canonDomain, "."+canonZone), nil
|
||||
}
|
||||
102
challenge/dnsnew/domain_test.go
Normal file
102
challenge/dnsnew/domain_test.go
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
package dnsnew
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExtractSubDomain(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
domain string
|
||||
zone string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "no FQDN",
|
||||
domain: "_acme-challenge.example.com",
|
||||
zone: "example.com",
|
||||
expected: "_acme-challenge",
|
||||
},
|
||||
{
|
||||
desc: "no FQDN zone",
|
||||
domain: "_acme-challenge.example.com.",
|
||||
zone: "example.com",
|
||||
expected: "_acme-challenge",
|
||||
},
|
||||
{
|
||||
desc: "no FQDN domain",
|
||||
domain: "_acme-challenge.example.com",
|
||||
zone: "example.com.",
|
||||
expected: "_acme-challenge",
|
||||
},
|
||||
{
|
||||
desc: "FQDN",
|
||||
domain: "_acme-challenge.example.com.",
|
||||
zone: "example.com.",
|
||||
expected: "_acme-challenge",
|
||||
},
|
||||
{
|
||||
desc: "multi-level subdomain",
|
||||
domain: "_acme-challenge.one.example.com.",
|
||||
zone: "example.com.",
|
||||
expected: "_acme-challenge.one",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
subDomain, err := ExtractSubDomain(test.domain, test.zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, test.expected, subDomain)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractSubDomain_errors(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
domain string
|
||||
zone string
|
||||
}{
|
||||
{
|
||||
desc: "same domain",
|
||||
domain: "example.com",
|
||||
zone: "example.com",
|
||||
},
|
||||
{
|
||||
desc: "same domain, no FQDN zone",
|
||||
domain: "example.com.",
|
||||
zone: "example.com",
|
||||
},
|
||||
{
|
||||
desc: "same domain, no FQDN domain",
|
||||
domain: "example.com",
|
||||
zone: "example.com.",
|
||||
},
|
||||
{
|
||||
desc: "same domain, FQDN",
|
||||
domain: "example.com.",
|
||||
zone: "example.com.",
|
||||
},
|
||||
{
|
||||
desc: "zone and domain are unrelated",
|
||||
domain: "_acme-challenge.example.com",
|
||||
zone: "example.org",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := ExtractSubDomain(test.domain, test.zone)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
5
challenge/dnsnew/fixtures/resolv.conf.1
Normal file
5
challenge/dnsnew/fixtures/resolv.conf.1
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
domain example.com
|
||||
nameserver 10.200.3.249
|
||||
nameserver 10.200.3.250:5353
|
||||
nameserver 2001:4860:4860::8844
|
||||
nameserver [10.0.0.1]:5353
|
||||
47
challenge/dnsnew/fqdn.go
Normal file
47
challenge/dnsnew/fqdn.go
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
package dnsnew
|
||||
|
||||
import (
|
||||
"iter"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// UnFqdn converts the fqdn into a name removing the trailing dot.
|
||||
func UnFqdn(name string) string {
|
||||
n := len(name)
|
||||
if n != 0 && name[n-1] == '.' {
|
||||
return name[:n-1]
|
||||
}
|
||||
|
||||
return name
|
||||
}
|
||||
|
||||
// UnFqdnDomainsSeq generates a sequence of "unFQDNed" domain names derived from a domain (FQDN or not) in descending order.
|
||||
func UnFqdnDomainsSeq(fqdn string) iter.Seq[string] {
|
||||
return func(yield func(string) bool) {
|
||||
if fqdn == "" {
|
||||
return
|
||||
}
|
||||
|
||||
for _, index := range dns.Split(fqdn) {
|
||||
if !yield(UnFqdn(fqdn[index:])) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DomainsSeq generates a sequence of domain names derived from a domain (FQDN or not) in descending order.
|
||||
func DomainsSeq(fqdn string) iter.Seq[string] {
|
||||
return func(yield func(string) bool) {
|
||||
if fqdn == "" {
|
||||
return
|
||||
}
|
||||
|
||||
for _, index := range dns.Split(fqdn) {
|
||||
if !yield(fqdn[index:]) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
137
challenge/dnsnew/fqdn_test.go
Normal file
137
challenge/dnsnew/fqdn_test.go
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
package dnsnew
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestUnFqdn(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
fqdn string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "simple",
|
||||
fqdn: "foo.example.",
|
||||
expected: "foo.example",
|
||||
},
|
||||
{
|
||||
desc: "already domain",
|
||||
fqdn: "foo.example",
|
||||
expected: "foo.example",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
domain := UnFqdn(test.fqdn)
|
||||
|
||||
assert.Equal(t, test.expected, domain)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnFqdnDomainsSeq(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
fqdn string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
desc: "empty",
|
||||
fqdn: "",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
desc: "TLD",
|
||||
fqdn: "com",
|
||||
expected: []string{"com"},
|
||||
},
|
||||
{
|
||||
desc: "2 levels",
|
||||
fqdn: "example.com",
|
||||
expected: []string{"example.com", "com"},
|
||||
},
|
||||
{
|
||||
desc: "3 levels",
|
||||
fqdn: "foo.example.com",
|
||||
expected: []string{"foo.example.com", "example.com", "com"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
for name, suffix := range map[string]string{"": "", " FQDN": "."} { //nolint:gocritic
|
||||
t.Run(test.desc+name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
actual := slices.Collect(UnFqdnDomainsSeq(test.fqdn + suffix))
|
||||
|
||||
assert.Equal(t, test.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDomainsSeq(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
fqdn string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
desc: "empty",
|
||||
fqdn: "",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
desc: "empty FQDN",
|
||||
fqdn: ".",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
desc: "TLD FQDN",
|
||||
fqdn: "com",
|
||||
expected: []string{"com"},
|
||||
},
|
||||
{
|
||||
desc: "TLD",
|
||||
fqdn: "com.",
|
||||
expected: []string{"com."},
|
||||
},
|
||||
{
|
||||
desc: "2 levels",
|
||||
fqdn: "example.com",
|
||||
expected: []string{"example.com", "com"},
|
||||
},
|
||||
{
|
||||
desc: "2 levels FQDN",
|
||||
fqdn: "example.com.",
|
||||
expected: []string{"example.com.", "com."},
|
||||
},
|
||||
{
|
||||
desc: "3 levels",
|
||||
fqdn: "foo.example.com",
|
||||
expected: []string{"foo.example.com", "example.com", "com"},
|
||||
},
|
||||
{
|
||||
desc: "3 levels FQDN",
|
||||
fqdn: "foo.example.com.",
|
||||
expected: []string{"foo.example.com.", "example.com.", "com."},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
actual := slices.Collect(DomainsSeq(test.fqdn))
|
||||
|
||||
assert.Equal(t, test.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
78
challenge/dnsnew/mock_test.go
Normal file
78
challenge/dnsnew/mock_test.go
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
package dnsnew
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func fakeNS(name, ns string) *dns.NS {
|
||||
return &dns.NS{
|
||||
Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 172800},
|
||||
Ns: ns,
|
||||
}
|
||||
}
|
||||
|
||||
func fakeA(name, ip string) *dns.A {
|
||||
return &dns.A{
|
||||
Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 10},
|
||||
A: net.ParseIP(ip),
|
||||
}
|
||||
}
|
||||
|
||||
func fakeTXT(name, value string) *dns.TXT {
|
||||
return &dns.TXT{
|
||||
Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 10},
|
||||
Txt: []string{value},
|
||||
}
|
||||
}
|
||||
|
||||
// mockResolver modifies the default DNS resolver to use a custom network address during the test execution.
|
||||
// IMPORTANT: it modifying std global variables.
|
||||
func mockResolver(authoritativeNS net.Addr) func(t *testing.T, client *Client) {
|
||||
return func(t *testing.T, client *Client) {
|
||||
t.Helper()
|
||||
|
||||
_, port, err := net.SplitHostPort(authoritativeNS.String())
|
||||
require.NoError(t, err)
|
||||
|
||||
client.authoritativeNSPort = port
|
||||
|
||||
originalResolver := net.DefaultResolver
|
||||
|
||||
t.Cleanup(func() {
|
||||
net.DefaultResolver = originalResolver
|
||||
})
|
||||
|
||||
net.DefaultResolver = &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d := net.Dialer{Timeout: 1 * time.Second}
|
||||
|
||||
return d.DialContext(ctx, network, authoritativeNS.String())
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func mockDefault(t *testing.T, recursiveNS net.Addr, opts ...func(t *testing.T, client *Client)) {
|
||||
t.Helper()
|
||||
|
||||
backup := DefaultClient()
|
||||
|
||||
t.Cleanup(func() {
|
||||
SetDefaultClient(backup)
|
||||
})
|
||||
|
||||
client := NewClient(&Options{RecursiveNameservers: []string{recursiveNS.String()}})
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(t, client)
|
||||
}
|
||||
|
||||
SetDefaultClient(client)
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue