mirror of
https://github.com/go-acme/lego
synced 2026-03-14 14:35:48 +01:00
refactor: create a core DNS client
This commit is contained in:
parent
d8f2938799
commit
0bf4a55f18
32 changed files with 571 additions and 745 deletions
|
|
@ -2,16 +2,9 @@ package dns01
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/go-acme/lego/v5/challenge"
|
||||
"github.com/go-acme/lego/v5/challenge/internal"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
var defaultClient atomic.Pointer[Client]
|
||||
|
|
@ -26,112 +19,37 @@ func SetDefaultClient(c *Client) {
|
|||
defaultClient.Store(c)
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
RecursiveNameservers []string
|
||||
Timeout time.Duration
|
||||
TCPOnly bool
|
||||
NetworkStack challenge.NetworkStack
|
||||
}
|
||||
type Options = internal.Options
|
||||
|
||||
type Client struct {
|
||||
recursiveNameservers []string
|
||||
core *internal.Client
|
||||
|
||||
// authoritativeNSPort used by authoritative NS.
|
||||
// For testing purposes only.
|
||||
authoritativeNSPort string
|
||||
|
||||
tcpClient *dns.Client
|
||||
udpClient *dns.Client
|
||||
tcpOnly bool
|
||||
|
||||
fqdnSoaCache map[string]*soaCacheEntry
|
||||
muFqdnSoaCache sync.Mutex
|
||||
}
|
||||
|
||||
func NewClient(opts *Options) *Client {
|
||||
if opts == nil {
|
||||
tcpOnly, _ := strconv.ParseBool(os.Getenv("LEGO_EXPERIMENTAL_DNS_TCP_ONLY"))
|
||||
opts = &Options{TCPOnly: tcpOnly}
|
||||
}
|
||||
|
||||
if len(opts.RecursiveNameservers) == 0 {
|
||||
opts.RecursiveNameservers = internal.GetNameservers(internal.DefaultResolvConf, opts.NetworkStack)
|
||||
}
|
||||
|
||||
if opts.Timeout == 0 {
|
||||
opts.Timeout = internal.DNSTimeout
|
||||
}
|
||||
|
||||
return &Client{
|
||||
recursiveNameservers: internal.ParseNameservers(opts.RecursiveNameservers),
|
||||
authoritativeNSPort: "53",
|
||||
tcpClient: &dns.Client{
|
||||
Net: opts.NetworkStack.Network("tcp"),
|
||||
Timeout: opts.Timeout,
|
||||
},
|
||||
udpClient: &dns.Client{
|
||||
Net: opts.NetworkStack.Network("udp"),
|
||||
Timeout: opts.Timeout,
|
||||
},
|
||||
tcpOnly: opts.TCPOnly,
|
||||
fqdnSoaCache: map[string]*soaCacheEntry{},
|
||||
muFqdnSoaCache: sync.Mutex{},
|
||||
core: internal.NewClient(opts),
|
||||
|
||||
authoritativeNSPort: "53",
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) sendQuery(ctx context.Context, fqdn string, rtype uint16, recursive bool) (*dns.Msg, error) {
|
||||
return c.sendQueryCustom(ctx, fqdn, rtype, c.recursiveNameservers, recursive)
|
||||
// FindZoneByFqdn determines the zone apex for the given fqdn
|
||||
// by recursing up the domain labels until the nameserver returns a SOA record in the answer section.
|
||||
func (c *Client) FindZoneByFqdn(ctx context.Context, fqdn string) (string, error) {
|
||||
return c.core.FindZoneByFqdn(ctx, fqdn)
|
||||
}
|
||||
|
||||
func (c *Client) sendQueryCustom(ctx context.Context, fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) {
|
||||
m := internal.CreateDNSMsg(fqdn, rtype, recursive)
|
||||
|
||||
if len(nameservers) == 0 {
|
||||
return nil, &internal.DNSError{Message: "empty list of nameservers"}
|
||||
}
|
||||
|
||||
var (
|
||||
r *dns.Msg
|
||||
err error
|
||||
errAll error
|
||||
)
|
||||
|
||||
for _, ns := range nameservers {
|
||||
r, err = c.exchange(ctx, m, ns)
|
||||
if err == nil && len(r.Answer) > 0 {
|
||||
break
|
||||
}
|
||||
|
||||
errAll = errors.Join(errAll, err)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return r, errAll
|
||||
}
|
||||
|
||||
return r, nil
|
||||
// FindZoneByFqdnCustom determines the zone apex for the given fqdn
|
||||
// by recursing up the domain labels until the nameserver returns a SOA record in the answer section.
|
||||
func (c *Client) FindZoneByFqdnCustom(ctx context.Context, fqdn string, nameservers []string) (string, error) {
|
||||
return c.core.FindZoneByFqdnCustom(ctx, fqdn, nameservers)
|
||||
}
|
||||
|
||||
func (c *Client) exchange(ctx context.Context, m *dns.Msg, ns string) (*dns.Msg, error) {
|
||||
if c.tcpOnly {
|
||||
r, _, err := c.tcpClient.ExchangeContext(ctx, m, ns)
|
||||
if err != nil {
|
||||
return r, &internal.DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
r, _, err := c.udpClient.ExchangeContext(ctx, m, ns)
|
||||
|
||||
if r != nil && r.Truncated {
|
||||
// If the TCP request succeeds, the "err" will reset to nil
|
||||
r, _, err = c.tcpClient.ExchangeContext(ctx, m, ns)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return r, &internal.DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
|
||||
}
|
||||
|
||||
return r, nil
|
||||
// ClearFqdnCache clears the cache of fqdn to zone mappings. Primarily used in testing.
|
||||
func (c *Client) ClearFqdnCache() {
|
||||
c.core.ClearFqdnCache()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import (
|
|||
)
|
||||
|
||||
func (c *Client) resolveCNAME(ctx context.Context, fqdn string) (string, error) {
|
||||
r, err := c.sendQuery(ctx, fqdn, dns.TypeTXT, true)
|
||||
r, err := c.core.SendQuery(ctx, fqdn, dns.TypeTXT, true)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("initial recursive nameserver: %w", err)
|
||||
}
|
||||
|
|
@ -27,7 +27,7 @@ func (c *Client) lookupCNAME(ctx context.Context, fqdn string) string {
|
|||
// recursion counter so it doesn't spin out of control
|
||||
for range 50 {
|
||||
// Keep following CNAMEs
|
||||
r, err := c.sendQuery(ctx, fqdn, dns.TypeCNAME, true)
|
||||
r, err := c.core.SendQuery(ctx, fqdn, dns.TypeCNAME, true)
|
||||
if err != nil {
|
||||
log.Debug("Lookup CNAME.",
|
||||
slog.String("fqdn", fqdn),
|
||||
|
|
|
|||
|
|
@ -11,12 +11,12 @@ import (
|
|||
|
||||
// checkRecursiveNameserversPropagation queries each of the recursive nameservers for the expected TXT record.
|
||||
func (c *Client) checkRecursiveNameserversPropagation(ctx context.Context, fqdn, value string) (bool, error) {
|
||||
return c.checkNameserversPropagationCustom(ctx, fqdn, value, c.recursiveNameservers, false)
|
||||
return c.checkNameserversPropagationCustom(ctx, fqdn, value, c.core.GetRecursiveNameservers(), false)
|
||||
}
|
||||
|
||||
// checkRecursiveNameserversPropagation queries each of the authoritative nameservers for the expected TXT record.
|
||||
func (c *Client) checkAuthoritativeNameserversPropagation(ctx context.Context, fqdn, value string) (bool, error) {
|
||||
authoritativeNss, err := c.lookupAuthoritativeNameservers(ctx, fqdn)
|
||||
authoritativeNss, err := c.core.LookupAuthoritativeNameservers(ctx, fqdn)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
|
@ -31,7 +31,7 @@ func (c *Client) checkNameserversPropagationCustom(ctx context.Context, fqdn, va
|
|||
ns = net.JoinHostPort(ns, c.authoritativeNSPort)
|
||||
}
|
||||
|
||||
r, err := c.sendQueryCustom(ctx, fqdn, dns.TypeTXT, []string{ns}, false)
|
||||
r, err := c.core.SendQueryCustom(ctx, fqdn, dns.TypeTXT, []string{ns}, false)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
|
@ -63,30 +63,3 @@ func (c *Client) checkNameserversPropagationCustom(ctx context.Context, fqdn, va
|
|||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// lookupAuthoritativeNameservers returns the authoritative nameservers for the given fqdn.
|
||||
func (c *Client) lookupAuthoritativeNameservers(ctx context.Context, fqdn string) ([]string, error) {
|
||||
var authoritativeNss []string
|
||||
|
||||
zone, err := c.FindZoneByFqdn(ctx, fqdn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not find zone: %w", err)
|
||||
}
|
||||
|
||||
r, err := c.sendQuery(ctx, zone, dns.TypeNS, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NS call failed: %w", err)
|
||||
}
|
||||
|
||||
for _, rr := range r.Answer {
|
||||
if ns, ok := rr.(*dns.NS); ok {
|
||||
authoritativeNss = append(authoritativeNss, strings.ToLower(ns.Ns))
|
||||
}
|
||||
}
|
||||
|
||||
if len(authoritativeNss) > 0 {
|
||||
return authoritativeNss, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("[zone=%s] could not determine authoritative nameservers", zone)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,11 +1,9 @@
|
|||
package dns01
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
dnsmock2 "github.com/go-acme/lego/v5/internal/tester/dnsmock"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
|
@ -74,113 +72,3 @@ func TestClient_checkNameserversPropagationCustom_authoritativeNss(t *testing.T)
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_lookupAuthoritativeNameservers_OK(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
fakeDNSServer *dnsmock2.Builder
|
||||
fqdn string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
fqdn: "en.wikipedia.org.localhost.",
|
||||
fakeDNSServer: dnsmock2.NewServer().
|
||||
Query("en.wikipedia.org.localhost SOA", dnsmock2.CNAME("dyna.wikimedia.org.localhost")).
|
||||
Query("wikipedia.org.localhost SOA", dnsmock2.SOA("")).
|
||||
Query("wikipedia.org.localhost NS",
|
||||
dnsmock2.Answer(
|
||||
fakeNS("wikipedia.org.localhost.", "ns0.wikimedia.org.localhost."),
|
||||
fakeNS("wikipedia.org.localhost.", "ns1.wikimedia.org.localhost."),
|
||||
fakeNS("wikipedia.org.localhost.", "ns2.wikimedia.org.localhost."),
|
||||
),
|
||||
),
|
||||
expected: []string{"ns0.wikimedia.org.localhost.", "ns1.wikimedia.org.localhost.", "ns2.wikimedia.org.localhost."},
|
||||
},
|
||||
{
|
||||
fqdn: "www.google.com.localhost.",
|
||||
fakeDNSServer: dnsmock2.NewServer().
|
||||
Query("www.google.com.localhost. SOA", dnsmock2.Noop).
|
||||
Query("google.com.localhost. SOA", dnsmock2.SOA("")).
|
||||
Query("google.com.localhost. NS",
|
||||
dnsmock2.Answer(
|
||||
fakeNS("google.com.localhost.", "ns1.google.com.localhost."),
|
||||
fakeNS("google.com.localhost.", "ns2.google.com.localhost."),
|
||||
fakeNS("google.com.localhost.", "ns3.google.com.localhost."),
|
||||
fakeNS("google.com.localhost.", "ns4.google.com.localhost."),
|
||||
),
|
||||
),
|
||||
expected: []string{"ns1.google.com.localhost.", "ns2.google.com.localhost.", "ns3.google.com.localhost.", "ns4.google.com.localhost."},
|
||||
},
|
||||
{
|
||||
fqdn: "mail.proton.me.localhost.",
|
||||
fakeDNSServer: dnsmock2.NewServer().
|
||||
Query("mail.proton.me.localhost. SOA", dnsmock2.Noop).
|
||||
Query("proton.me.localhost. SOA", dnsmock2.SOA("")).
|
||||
Query("proton.me.localhost. NS",
|
||||
dnsmock2.Answer(
|
||||
fakeNS("proton.me.localhost.", "ns1.proton.me.localhost."),
|
||||
fakeNS("proton.me.localhost.", "ns2.proton.me.localhost."),
|
||||
fakeNS("proton.me.localhost.", "ns3.proton.me.localhost."),
|
||||
),
|
||||
),
|
||||
expected: []string{"ns1.proton.me.localhost.", "ns2.proton.me.localhost.", "ns3.proton.me.localhost."},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.fqdn, func(t *testing.T) {
|
||||
client := NewClient(&Options{RecursiveNameservers: []string{test.fakeDNSServer.Build(t).String()}})
|
||||
|
||||
nss, err := client.lookupAuthoritativeNameservers(t.Context(), test.fqdn)
|
||||
require.NoError(t, err)
|
||||
|
||||
sort.Strings(nss)
|
||||
sort.Strings(test.expected)
|
||||
|
||||
assert.Equal(t, test.expected, nss)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_lookupAuthoritativeNameservers_error(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
fqdn string
|
||||
fakeDNSServer *dnsmock2.Builder
|
||||
error string
|
||||
}{
|
||||
{
|
||||
desc: "NXDOMAIN",
|
||||
fqdn: "example.invalid.",
|
||||
fakeDNSServer: dnsmock2.NewServer().
|
||||
Query(". SOA", dnsmock2.Error(dns.RcodeNameError)),
|
||||
error: "could not find zone: [fqdn=example.invalid.] could not find the start of authority for 'example.invalid.' [question='invalid. IN SOA', code=NXDOMAIN]",
|
||||
},
|
||||
{
|
||||
desc: "NS error",
|
||||
fqdn: "example.com.",
|
||||
fakeDNSServer: dnsmock2.NewServer().
|
||||
Query("example.com. SOA", dnsmock2.SOA("")).
|
||||
Query("example.com. NS", dnsmock2.Error(dns.RcodeServerFailure)),
|
||||
error: "[zone=example.com.] could not determine authoritative nameservers",
|
||||
},
|
||||
{
|
||||
desc: "empty NS",
|
||||
fqdn: "example.com.",
|
||||
fakeDNSServer: dnsmock2.NewServer().
|
||||
Query("example.com. SOA", dnsmock2.SOA("")).
|
||||
Query("example.me NS", dnsmock2.Noop),
|
||||
error: "[zone=example.com.] could not determine authoritative nameservers",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
client := NewClient(&Options{RecursiveNameservers: []string{test.fakeDNSServer.Build(t).String()}})
|
||||
|
||||
_, err := client.lookupAuthoritativeNameservers(t.Context(), test.fqdn)
|
||||
require.Error(t, err)
|
||||
assert.EqualError(t, err, test.error)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package dns01
|
|||
import (
|
||||
"iter"
|
||||
|
||||
"github.com/go-acme/lego/v5/challenge/internal"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
|
|
@ -33,15 +34,5 @@ func UnFqdnDomainsSeq(fqdn string) iter.Seq[string] {
|
|||
|
||||
// DomainsSeq generates a sequence of domain names derived from a domain (FQDN or not) in descending order.
|
||||
func DomainsSeq(fqdn string) iter.Seq[string] {
|
||||
return func(yield func(string) bool) {
|
||||
if fqdn == "" {
|
||||
return
|
||||
}
|
||||
|
||||
for _, index := range dns.Split(fqdn) {
|
||||
if !yield(fqdn[index:]) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return internal.DomainsSeq(fqdn)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -76,62 +76,3 @@ func TestUnFqdnDomainsSeq(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDomainsSeq(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
fqdn string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
desc: "empty",
|
||||
fqdn: "",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
desc: "empty FQDN",
|
||||
fqdn: ".",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
desc: "TLD FQDN",
|
||||
fqdn: "com",
|
||||
expected: []string{"com"},
|
||||
},
|
||||
{
|
||||
desc: "TLD",
|
||||
fqdn: "com.",
|
||||
expected: []string{"com."},
|
||||
},
|
||||
{
|
||||
desc: "2 levels",
|
||||
fqdn: "example.com",
|
||||
expected: []string{"example.com", "com"},
|
||||
},
|
||||
{
|
||||
desc: "2 levels FQDN",
|
||||
fqdn: "example.com.",
|
||||
expected: []string{"example.com.", "com."},
|
||||
},
|
||||
{
|
||||
desc: "3 levels",
|
||||
fqdn: "foo.example.com",
|
||||
expected: []string{"foo.example.com", "example.com", "com"},
|
||||
},
|
||||
{
|
||||
desc: "3 levels FQDN",
|
||||
fqdn: "foo.example.com.",
|
||||
expected: []string{"foo.example.com.", "example.com.", "com."},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
actual := slices.Collect(DomainsSeq(test.fqdn))
|
||||
|
||||
assert.Equal(t, test.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
//nolint:unparam // Keep the name for test readability.
|
||||
func fakeNS(name, ns string) *dns.NS {
|
||||
return &dns.NS{
|
||||
Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 172800},
|
||||
|
|
|
|||
|
|
@ -1,16 +1,9 @@
|
|||
package dnspersist01
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/go-acme/lego/v5/challenge"
|
||||
"github.com/go-acme/lego/v5/challenge/internal"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
var defaultClient atomic.Pointer[Client]
|
||||
|
|
@ -25,119 +18,25 @@ func SetDefaultClient(c *Client) {
|
|||
defaultClient.Store(c)
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
RecursiveNameservers []string
|
||||
Timeout time.Duration
|
||||
TCPOnly bool
|
||||
NetworkStack challenge.NetworkStack
|
||||
}
|
||||
type Options = internal.Options
|
||||
|
||||
type Client struct {
|
||||
recursiveNameservers []string
|
||||
core *internal.Client
|
||||
|
||||
// authoritativeNSPort used by authoritative NS.
|
||||
// For testing purposes only.
|
||||
authoritativeNSPort string
|
||||
|
||||
tcpClient *dns.Client
|
||||
udpClient *dns.Client
|
||||
tcpOnly bool
|
||||
}
|
||||
|
||||
func NewClient(opts *Options) *Client {
|
||||
if opts == nil {
|
||||
tcpOnly, _ := strconv.ParseBool(os.Getenv("LEGO_EXPERIMENTAL_DNS_TCP_ONLY"))
|
||||
opts = &Options{TCPOnly: tcpOnly}
|
||||
}
|
||||
|
||||
if len(opts.RecursiveNameservers) == 0 {
|
||||
opts.RecursiveNameservers = internal.GetNameservers(internal.DefaultResolvConf, opts.NetworkStack)
|
||||
}
|
||||
|
||||
if opts.Timeout == 0 {
|
||||
opts.Timeout = internal.DNSTimeout
|
||||
}
|
||||
|
||||
return &Client{
|
||||
recursiveNameservers: internal.ParseNameservers(opts.RecursiveNameservers),
|
||||
authoritativeNSPort: "53",
|
||||
tcpClient: &dns.Client{
|
||||
Net: opts.NetworkStack.Network("tcp"),
|
||||
Timeout: opts.Timeout,
|
||||
},
|
||||
udpClient: &dns.Client{
|
||||
Net: opts.NetworkStack.Network("udp"),
|
||||
Timeout: opts.Timeout,
|
||||
},
|
||||
tcpOnly: opts.TCPOnly,
|
||||
core: internal.NewClient(opts),
|
||||
|
||||
authoritativeNSPort: "53",
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* NOTE(ldez): This function is a duplication of `Client.sendQuery()` from `dns01/client.go`.
|
||||
* The 2 functions should be kept in sync.
|
||||
*/
|
||||
func (c *Client) sendQuery(ctx context.Context, fqdn string, rtype uint16, recursive bool) (*dns.Msg, error) {
|
||||
return c.sendQueryCustom(ctx, fqdn, rtype, c.recursiveNameservers, recursive)
|
||||
}
|
||||
|
||||
/*
|
||||
* NOTE(ldez): This function is a duplication of `Client.sendQueryCustom()` from `dns01/client.go`.
|
||||
* The 2 functions should be kept in sync.
|
||||
*/
|
||||
func (c *Client) sendQueryCustom(ctx context.Context, fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) {
|
||||
m := internal.CreateDNSMsg(fqdn, rtype, recursive)
|
||||
|
||||
if len(nameservers) == 0 {
|
||||
return nil, &internal.DNSError{Message: "empty list of nameservers"}
|
||||
}
|
||||
|
||||
var (
|
||||
r *dns.Msg
|
||||
err error
|
||||
errAll error
|
||||
)
|
||||
|
||||
for _, ns := range nameservers {
|
||||
r, err = c.exchange(ctx, m, ns)
|
||||
if err == nil && len(r.Answer) > 0 {
|
||||
break
|
||||
}
|
||||
|
||||
errAll = errors.Join(errAll, err)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return r, errAll
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
/*
|
||||
* NOTE(ldez): This function is a duplication of `Client.exchange()` from `dns01/client.go`.
|
||||
* The 2 functions should be kept in sync.
|
||||
*/
|
||||
func (c *Client) exchange(ctx context.Context, m *dns.Msg, ns string) (*dns.Msg, error) {
|
||||
if c.tcpOnly {
|
||||
r, _, err := c.tcpClient.ExchangeContext(ctx, m, ns)
|
||||
if err != nil {
|
||||
return r, &internal.DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
r, _, err := c.udpClient.ExchangeContext(ctx, m, ns)
|
||||
|
||||
if r != nil && r.Truncated {
|
||||
// If the TCP request succeeds, the "err" will reset to nil
|
||||
r, _, err = c.tcpClient.ExchangeContext(ctx, m, ns)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return r, &internal.DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
|
||||
}
|
||||
|
||||
return r, nil
|
||||
// ClearFqdnCache clears the cache of fqdn to zone mappings. Primarily used in testing.
|
||||
func (c *Client) ClearFqdnCache() {
|
||||
c.core.ClearFqdnCache()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,19 +4,16 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// checkRecursiveNameserversPropagation queries each of the recursive nameservers for the expected TXT record.
|
||||
func (c *Client) checkRecursiveNameserversPropagation(ctx context.Context, fqdn string, matcher RecordMatcher) (bool, error) {
|
||||
return c.checkNameserversPropagationCustom(ctx, fqdn, c.recursiveNameservers, matcher, false, true)
|
||||
return c.checkNameserversPropagationCustom(ctx, fqdn, c.core.GetRecursiveNameservers(), matcher, false, true)
|
||||
}
|
||||
|
||||
// checkRecursiveNameserversPropagation queries each of the authoritative nameservers for the expected TXT record.
|
||||
func (c *Client) checkAuthoritativeNameserversPropagation(ctx context.Context, fqdn string, matcher RecordMatcher) (bool, error) {
|
||||
authoritativeNss, err := c.lookupAuthoritativeNameservers(ctx, fqdn)
|
||||
authoritativeNss, err := c.core.LookupAuthoritativeNameservers(ctx, fqdn)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
|
@ -42,34 +39,3 @@ func (c *Client) checkNameserversPropagationCustom(ctx context.Context, fqdn str
|
|||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// lookupAuthoritativeNameservers returns the authoritative nameservers for the given fqdn.
|
||||
/*
|
||||
* NOTE(ldez): This function is a duplication of `lookupAuthoritativeNameservers()` from `dns01/client_nameservers.go`.
|
||||
* The 2 functions should be kept in sync.
|
||||
*/
|
||||
func (c *Client) lookupAuthoritativeNameservers(ctx context.Context, fqdn string) ([]string, error) {
|
||||
var authoritativeNss []string
|
||||
|
||||
zone, err := c.FindZoneByFqdn(ctx, fqdn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not find zone: %w", err)
|
||||
}
|
||||
|
||||
r, err := c.sendQuery(ctx, zone, dns.TypeNS, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NS call failed: %w", err)
|
||||
}
|
||||
|
||||
for _, rr := range r.Answer {
|
||||
if ns, ok := rr.(*dns.NS); ok {
|
||||
authoritativeNss = append(authoritativeNss, strings.ToLower(ns.Ns))
|
||||
}
|
||||
}
|
||||
|
||||
if len(authoritativeNss) > 0 {
|
||||
return authoritativeNss, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("[zone=%s] could not determine authoritative nameservers", zone)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ func (r TXTResult) String() string {
|
|||
// LookupTXT resolves TXT records at fqdn.
|
||||
// If CNAMEs are returned, they are followed up to 50 times to resolve TXT records.
|
||||
func (c *Client) LookupTXT(ctx context.Context, fqdn string) (TXTResult, error) {
|
||||
return c.lookupTXT(ctx, fqdn, c.recursiveNameservers, true)
|
||||
return c.lookupTXT(ctx, fqdn, c.core.GetRecursiveNameservers(), true)
|
||||
}
|
||||
|
||||
func (c *Client) lookupTXT(ctx context.Context, fqdn string, nameservers []string, recursive bool) (TXTResult, error) {
|
||||
|
|
@ -62,7 +62,7 @@ func (c *Client) lookupTXT(ctx context.Context, fqdn string, nameservers []strin
|
|||
|
||||
seen[name] = struct{}{}
|
||||
|
||||
msg, err := c.sendQueryCustom(ctx, name, dns.TypeTXT, nameservers, recursive)
|
||||
msg, err := c.core.SendQueryCustom(ctx, name, dns.TypeTXT, nameservers, recursive)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,83 +0,0 @@
|
|||
package dnspersist01
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-acme/lego/v5/challenge/internal"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// FindZoneByFqdn determines the zone apex for the given fqdn
|
||||
// by recursing up the domain labels until the nameserver returns a SOA record in the answer section.
|
||||
/*
|
||||
* NOTE(ldez): This function is a partial duplication of `Client.FindZoneByFqdn()` from `dns01/client_zone.go`.
|
||||
* The 2 functions should be kept in sync.
|
||||
*/
|
||||
func (c *Client) FindZoneByFqdn(ctx context.Context, fqdn string) (string, error) {
|
||||
return c.FindZoneByFqdnCustom(ctx, fqdn, c.recursiveNameservers)
|
||||
}
|
||||
|
||||
// FindZoneByFqdnCustom determines the zone apex for the given fqdn
|
||||
// by recursing up the domain labels until the nameserver returns a SOA record in the answer section.
|
||||
/*
|
||||
* NOTE(ldez): This function is a partial duplication of `Client.FindZoneByFqdnCustom()` from `dns01/client_zone.go`.
|
||||
* The 2 functions should be kept in sync.
|
||||
*/
|
||||
func (c *Client) FindZoneByFqdnCustom(ctx context.Context, fqdn string, nameservers []string) (string, error) {
|
||||
soa, err := c.fetchSoaByFqdn(ctx, fqdn, nameservers)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("[fqdn=%s] %w", fqdn, err)
|
||||
}
|
||||
|
||||
return soa, nil
|
||||
}
|
||||
|
||||
/*
|
||||
* NOTE(ldez): This function is a partial duplication of `Client.fetchSoaByFqdn()` from `dns01/client_zone.go`.
|
||||
* The 2 functions should be kept in sync.
|
||||
*/
|
||||
func (c *Client) fetchSoaByFqdn(ctx context.Context, fqdn string, nameservers []string) (string, error) {
|
||||
var (
|
||||
err error
|
||||
r *dns.Msg
|
||||
)
|
||||
|
||||
for domain := range domainsSeq(fqdn) {
|
||||
r, err = c.sendQueryCustom(ctx, domain, dns.TypeSOA, nameservers, true)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if r == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch r.Rcode {
|
||||
case dns.RcodeSuccess:
|
||||
// Check if we got a SOA RR in the answer section
|
||||
if len(r.Answer) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// CNAME records cannot/should not exist at the root of a zone.
|
||||
// So we skip a domain when a CNAME is found.
|
||||
if internal.MsgContainsCNAME(r) {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, ans := range r.Answer {
|
||||
if soa, ok := ans.(*dns.SOA); ok {
|
||||
return soa.Hdr.Name, nil
|
||||
}
|
||||
}
|
||||
case dns.RcodeNameError:
|
||||
// NXDOMAIN
|
||||
default:
|
||||
// Any response code other than NOERROR and NXDOMAIN is treated as error
|
||||
return "", &internal.DNSError{Message: fmt.Sprintf("unexpected response for '%s'", domain), MsgOut: r}
|
||||
}
|
||||
}
|
||||
|
||||
return "", &internal.DNSError{Message: fmt.Sprintf("could not find the start of authority for '%s'", fqdn), MsgOut: r, Err: err}
|
||||
}
|
||||
|
|
@ -66,6 +66,8 @@ func Test_preCheck_checkDNSPropagation(t *testing.T) {
|
|||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
DefaultClient().ClearFqdnCache()
|
||||
|
||||
match := func(records []TXTRecord) bool {
|
||||
for _, record := range records {
|
||||
if record.Value == test.value {
|
||||
|
|
|
|||
|
|
@ -11,10 +11,6 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
/*
|
||||
* NOTE(ldez): This function is a duplication of `fakeNS` from `dns01/mock_test.go`.
|
||||
* The 2 functions should be kept in sync.
|
||||
*/
|
||||
func fakeNS(name, ns string) *dns.NS {
|
||||
return &dns.NS{
|
||||
Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 172800},
|
||||
|
|
@ -22,10 +18,6 @@ func fakeNS(name, ns string) *dns.NS {
|
|||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* NOTE(ldez): This function is a partial duplication of `fakeA` from `dns01/mock_test.go`.
|
||||
* The 2 functions should be kept in sync.
|
||||
*/
|
||||
func fakeA(name, ip string) *dns.A {
|
||||
return &dns.A{
|
||||
Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 10},
|
||||
|
|
@ -33,10 +25,6 @@ func fakeA(name, ip string) *dns.A {
|
|||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* NOTE(ldez): This function is a partial duplication of `fakeTXT` from `dns01/mock_test.go`.
|
||||
* The 2 functions should be kept in sync.
|
||||
*/
|
||||
func fakeTXT(name, value string, ttl uint32) *dns.TXT {
|
||||
return &dns.TXT{
|
||||
Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: ttl},
|
||||
|
|
|
|||
135
challenge/internal/client.go
Normal file
135
challenge/internal/client.go
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-acme/lego/v5/challenge"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
RecursiveNameservers []string
|
||||
Timeout time.Duration
|
||||
TCPOnly bool
|
||||
NetworkStack challenge.NetworkStack
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
recursiveNameservers []string
|
||||
|
||||
// authoritativeNSPort used by authoritative NS.
|
||||
// For testing purposes only.
|
||||
authoritativeNSPort string
|
||||
|
||||
tcpClient *dns.Client
|
||||
udpClient *dns.Client
|
||||
tcpOnly bool
|
||||
|
||||
fqdnSoaCache map[string]*soaCacheEntry
|
||||
muFqdnSoaCache sync.Mutex
|
||||
}
|
||||
|
||||
func NewClient(opts *Options) *Client {
|
||||
if opts == nil {
|
||||
tcpOnly, _ := strconv.ParseBool(os.Getenv("LEGO_EXPERIMENTAL_DNS_TCP_ONLY"))
|
||||
opts = &Options{TCPOnly: tcpOnly}
|
||||
}
|
||||
|
||||
if len(opts.RecursiveNameservers) == 0 {
|
||||
opts.RecursiveNameservers = getNameservers(DefaultResolvConf, opts.NetworkStack)
|
||||
}
|
||||
|
||||
if opts.Timeout == 0 {
|
||||
opts.Timeout = dnsTimeout
|
||||
}
|
||||
|
||||
return &Client{
|
||||
recursiveNameservers: parseNameservers(opts.RecursiveNameservers),
|
||||
authoritativeNSPort: "53",
|
||||
tcpClient: &dns.Client{
|
||||
Net: opts.NetworkStack.Network("tcp"),
|
||||
Timeout: opts.Timeout,
|
||||
},
|
||||
udpClient: &dns.Client{
|
||||
Net: opts.NetworkStack.Network("udp"),
|
||||
Timeout: opts.Timeout,
|
||||
},
|
||||
tcpOnly: opts.TCPOnly,
|
||||
fqdnSoaCache: map[string]*soaCacheEntry{},
|
||||
muFqdnSoaCache: sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) SendQuery(ctx context.Context, fqdn string, rtype uint16, recursive bool) (*dns.Msg, error) {
|
||||
return c.SendQueryCustom(ctx, fqdn, rtype, c.recursiveNameservers, recursive)
|
||||
}
|
||||
|
||||
func (c *Client) SendQueryCustom(ctx context.Context, fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) {
|
||||
m := createDNSMsg(fqdn, rtype, recursive)
|
||||
|
||||
if len(nameservers) == 0 {
|
||||
return nil, &DNSError{Message: "empty list of nameservers"}
|
||||
}
|
||||
|
||||
var (
|
||||
r *dns.Msg
|
||||
err error
|
||||
errAll error
|
||||
)
|
||||
|
||||
for _, ns := range nameservers {
|
||||
r, err = c.exchange(ctx, m, ns)
|
||||
if err == nil && len(r.Answer) > 0 {
|
||||
break
|
||||
}
|
||||
|
||||
errAll = errors.Join(errAll, err)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return r, errAll
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (c *Client) exchange(ctx context.Context, m *dns.Msg, ns string) (*dns.Msg, error) {
|
||||
if c.tcpOnly {
|
||||
r, _, err := c.tcpClient.ExchangeContext(ctx, m, ns)
|
||||
if err != nil {
|
||||
return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
r, _, err := c.udpClient.ExchangeContext(ctx, m, ns)
|
||||
|
||||
if r != nil && r.Truncated {
|
||||
// If the TCP request succeeds, the "err" will reset to nil
|
||||
r, _, err = c.tcpClient.ExchangeContext(ctx, m, ns)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(fqdn, rtype)
|
||||
m.SetEdns0(4096, false)
|
||||
|
||||
if !recursive {
|
||||
m.RecursionDesired = false
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
package dns01
|
||||
package internal
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
|
@ -9,9 +9,9 @@ import (
|
|||
)
|
||||
|
||||
func TestDNSError_Error(t *testing.T) {
|
||||
msgIn := CreateDNSMsg("example.com.", dns.TypeTXT, true)
|
||||
msgIn := createDNSMsg("example.com.", dns.TypeTXT, true)
|
||||
|
||||
msgOut := CreateDNSMsg("example.org.", dns.TypeSOA, true)
|
||||
msgOut := createDNSMsg("example.org.", dns.TypeSOA, true)
|
||||
msgOut.Rcode = dns.RcodeNameError
|
||||
|
||||
testCases := []struct {
|
||||
|
|
|
|||
91
challenge/internal/client_nameservers.go
Normal file
91
challenge/internal/client_nameservers.go
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/go-acme/lego/v5/challenge"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const DefaultResolvConf = "/etc/resolv.conf"
|
||||
|
||||
// GetRecursiveNameservers returns a copy of the recursive nameservers.
|
||||
func (c *Client) GetRecursiveNameservers() []string {
|
||||
return slices.Clone(c.recursiveNameservers)
|
||||
}
|
||||
|
||||
// LookupAuthoritativeNameservers returns the authoritative nameservers for the given fqdn.
|
||||
func (c *Client) LookupAuthoritativeNameservers(ctx context.Context, fqdn string) ([]string, error) {
|
||||
var authoritativeNss []string
|
||||
|
||||
zone, err := c.FindZoneByFqdn(ctx, fqdn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not find zone: %w", err)
|
||||
}
|
||||
|
||||
r, err := c.SendQuery(ctx, zone, dns.TypeNS, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NS call failed: %w", err)
|
||||
}
|
||||
|
||||
for _, rr := range r.Answer {
|
||||
if ns, ok := rr.(*dns.NS); ok {
|
||||
authoritativeNss = append(authoritativeNss, strings.ToLower(ns.Ns))
|
||||
}
|
||||
}
|
||||
|
||||
if len(authoritativeNss) > 0 {
|
||||
return authoritativeNss, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("[zone=%s] could not determine authoritative nameservers", zone)
|
||||
}
|
||||
|
||||
// getNameservers attempts to get systems nameservers before falling back to the defaults.
|
||||
func getNameservers(path string, stack challenge.NetworkStack) []string {
|
||||
config, err := dns.ClientConfigFromFile(path)
|
||||
if err == nil && len(config.Servers) > 0 {
|
||||
return config.Servers
|
||||
}
|
||||
|
||||
switch stack {
|
||||
case challenge.IPv4Only:
|
||||
return []string{
|
||||
"1.1.1.1:53",
|
||||
"1.0.0.1:53",
|
||||
}
|
||||
|
||||
case challenge.IPv6Only:
|
||||
return []string{
|
||||
"[2606:4700:4700::1111]:53",
|
||||
"[2606:4700:4700::1001]:53",
|
||||
}
|
||||
|
||||
default:
|
||||
return []string{
|
||||
"1.1.1.1:53",
|
||||
"1.0.0.1:53",
|
||||
"[2606:4700:4700::1111]:53",
|
||||
"[2606:4700:4700::1001]:53",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseNameservers(servers []string) []string {
|
||||
var resolvers []string
|
||||
|
||||
for _, resolver := range servers {
|
||||
// ensure all servers have a port number
|
||||
if _, _, err := net.SplitHostPort(resolver); err != nil {
|
||||
resolvers = append(resolvers, net.JoinHostPort(resolver, "53"))
|
||||
} else {
|
||||
resolvers = append(resolvers, resolver)
|
||||
}
|
||||
}
|
||||
|
||||
return resolvers
|
||||
}
|
||||
196
challenge/internal/client_nameservers_test.go
Normal file
196
challenge/internal/client_nameservers_test.go
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/go-acme/lego/v5/challenge"
|
||||
dnsmock2 "github.com/go-acme/lego/v5/internal/tester/dnsmock"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClient_LookupAuthoritativeNameservers_OK(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
fakeDNSServer *dnsmock2.Builder
|
||||
fqdn string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
fqdn: "en.wikipedia.org.localhost.",
|
||||
fakeDNSServer: dnsmock2.NewServer().
|
||||
Query("en.wikipedia.org.localhost SOA", dnsmock2.CNAME("dyna.wikimedia.org.localhost")).
|
||||
Query("wikipedia.org.localhost SOA", dnsmock2.SOA("")).
|
||||
Query("wikipedia.org.localhost NS",
|
||||
dnsmock2.Answer(
|
||||
fakeNS("wikipedia.org.localhost.", "ns0.wikimedia.org.localhost."),
|
||||
fakeNS("wikipedia.org.localhost.", "ns1.wikimedia.org.localhost."),
|
||||
fakeNS("wikipedia.org.localhost.", "ns2.wikimedia.org.localhost."),
|
||||
),
|
||||
),
|
||||
expected: []string{"ns0.wikimedia.org.localhost.", "ns1.wikimedia.org.localhost.", "ns2.wikimedia.org.localhost."},
|
||||
},
|
||||
{
|
||||
fqdn: "www.google.com.localhost.",
|
||||
fakeDNSServer: dnsmock2.NewServer().
|
||||
Query("www.google.com.localhost. SOA", dnsmock2.Noop).
|
||||
Query("google.com.localhost. SOA", dnsmock2.SOA("")).
|
||||
Query("google.com.localhost. NS",
|
||||
dnsmock2.Answer(
|
||||
fakeNS("google.com.localhost.", "ns1.google.com.localhost."),
|
||||
fakeNS("google.com.localhost.", "ns2.google.com.localhost."),
|
||||
fakeNS("google.com.localhost.", "ns3.google.com.localhost."),
|
||||
fakeNS("google.com.localhost.", "ns4.google.com.localhost."),
|
||||
),
|
||||
),
|
||||
expected: []string{"ns1.google.com.localhost.", "ns2.google.com.localhost.", "ns3.google.com.localhost.", "ns4.google.com.localhost."},
|
||||
},
|
||||
{
|
||||
fqdn: "mail.proton.me.localhost.",
|
||||
fakeDNSServer: dnsmock2.NewServer().
|
||||
Query("mail.proton.me.localhost. SOA", dnsmock2.Noop).
|
||||
Query("proton.me.localhost. SOA", dnsmock2.SOA("")).
|
||||
Query("proton.me.localhost. NS",
|
||||
dnsmock2.Answer(
|
||||
fakeNS("proton.me.localhost.", "ns1.proton.me.localhost."),
|
||||
fakeNS("proton.me.localhost.", "ns2.proton.me.localhost."),
|
||||
fakeNS("proton.me.localhost.", "ns3.proton.me.localhost."),
|
||||
),
|
||||
),
|
||||
expected: []string{"ns1.proton.me.localhost.", "ns2.proton.me.localhost.", "ns3.proton.me.localhost."},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.fqdn, func(t *testing.T) {
|
||||
client := NewClient(&Options{RecursiveNameservers: []string{test.fakeDNSServer.Build(t).String()}})
|
||||
|
||||
nss, err := client.LookupAuthoritativeNameservers(t.Context(), test.fqdn)
|
||||
require.NoError(t, err)
|
||||
|
||||
sort.Strings(nss)
|
||||
sort.Strings(test.expected)
|
||||
|
||||
assert.Equal(t, test.expected, nss)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_LookupAuthoritativeNameservers_error(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
fqdn string
|
||||
fakeDNSServer *dnsmock2.Builder
|
||||
error string
|
||||
}{
|
||||
{
|
||||
desc: "NXDOMAIN",
|
||||
fqdn: "example.invalid.",
|
||||
fakeDNSServer: dnsmock2.NewServer().
|
||||
Query(". SOA", dnsmock2.Error(dns.RcodeNameError)),
|
||||
error: "could not find zone: [fqdn=example.invalid.] could not find the start of authority for 'example.invalid.' [question='invalid. IN SOA', code=NXDOMAIN]",
|
||||
},
|
||||
{
|
||||
desc: "NS error",
|
||||
fqdn: "example.com.",
|
||||
fakeDNSServer: dnsmock2.NewServer().
|
||||
Query("example.com. SOA", dnsmock2.SOA("")).
|
||||
Query("example.com. NS", dnsmock2.Error(dns.RcodeServerFailure)),
|
||||
error: "[zone=example.com.] could not determine authoritative nameservers",
|
||||
},
|
||||
{
|
||||
desc: "empty NS",
|
||||
fqdn: "example.com.",
|
||||
fakeDNSServer: dnsmock2.NewServer().
|
||||
Query("example.com. SOA", dnsmock2.SOA("")).
|
||||
Query("example.me NS", dnsmock2.Noop),
|
||||
error: "[zone=example.com.] could not determine authoritative nameservers",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
client := NewClient(&Options{RecursiveNameservers: []string{test.fakeDNSServer.Build(t).String()}})
|
||||
|
||||
_, err := client.LookupAuthoritativeNameservers(t.Context(), test.fqdn)
|
||||
require.Error(t, err)
|
||||
assert.EqualError(t, err, test.error)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_getNameservers(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
path string
|
||||
stack challenge.NetworkStack
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
desc: "with resolv.conf",
|
||||
path: "fixtures/resolv.conf.1",
|
||||
stack: challenge.DualStack,
|
||||
expected: []string{"10.200.3.249", "10.200.3.250:5353", "2001:4860:4860::8844", "[10.0.0.1]:5353"},
|
||||
},
|
||||
{
|
||||
desc: "with nonexistent resolv.conf",
|
||||
path: "fixtures/resolv.conf.nonexistant",
|
||||
stack: challenge.DualStack,
|
||||
expected: []string{"1.0.0.1:53", "1.1.1.1:53", "[2606:4700:4700::1001]:53", "[2606:4700:4700::1111]:53"},
|
||||
},
|
||||
{
|
||||
desc: "default with IPv4Only",
|
||||
path: "resolv.conf.nonexistant",
|
||||
stack: challenge.IPv4Only,
|
||||
expected: []string{"1.0.0.1:53", "1.1.1.1:53"},
|
||||
},
|
||||
{
|
||||
desc: "default with IPv6Only",
|
||||
path: "resolv.conf.nonexistant",
|
||||
stack: challenge.IPv6Only,
|
||||
expected: []string{"[2606:4700:4700::1001]:53", "[2606:4700:4700::1111]:53"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
result := getNameservers(test.path, test.stack)
|
||||
|
||||
sort.Strings(result)
|
||||
sort.Strings(test.expected)
|
||||
|
||||
assert.Equal(t, test.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseNameservers(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
servers []string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
desc: "without explicit port",
|
||||
servers: []string{"ns1.example.com", "2001:db8::1"},
|
||||
expected: []string{"ns1.example.com:53", "[2001:db8::1]:53"},
|
||||
},
|
||||
{
|
||||
desc: "with explicit port",
|
||||
servers: []string{"ns1.example.com:53", "[2001:db8::1]:53"},
|
||||
expected: []string{"ns1.example.com:53", "[2001:db8::1]:53"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result := parseNameservers(test.servers)
|
||||
|
||||
assert.Equal(t, test.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
8
challenge/internal/client_timeout_unix.go
Normal file
8
challenge/internal/client_timeout_unix.go
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
//go:build !windows
|
||||
|
||||
package internal
|
||||
|
||||
import "time"
|
||||
|
||||
// dnsTimeout is used as the default DNS timeout on Unix-like systems.
|
||||
const dnsTimeout = 10 * time.Second
|
||||
8
challenge/internal/client_timeout_windows.go
Normal file
8
challenge/internal/client_timeout_windows.go
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
//go:build windows
|
||||
|
||||
package internal
|
||||
|
||||
import "time"
|
||||
|
||||
// dnsTimeout is used as the default DNS timeout on Windows.
|
||||
const dnsTimeout = 20 * time.Second
|
||||
|
|
@ -1,10 +1,9 @@
|
|||
package dns01
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-acme/lego/v5/challenge/internal"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
|
|
@ -51,7 +50,7 @@ func (c *Client) fetchSoaByFqdn(ctx context.Context, fqdn string, nameservers []
|
|||
)
|
||||
|
||||
for domain := range DomainsSeq(fqdn) {
|
||||
r, err = c.sendQueryCustom(ctx, domain, dns.TypeSOA, nameservers, true)
|
||||
r, err = c.SendQueryCustom(ctx, domain, dns.TypeSOA, nameservers, true)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
|
@ -69,7 +68,7 @@ func (c *Client) fetchSoaByFqdn(ctx context.Context, fqdn string, nameservers []
|
|||
|
||||
// CNAME records cannot/should not exist at the root of a zone.
|
||||
// So we skip a domain when a CNAME is found.
|
||||
if internal.MsgContainsCNAME(r) {
|
||||
if msgContainsCNAME(r) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
|
@ -82,9 +81,9 @@ func (c *Client) fetchSoaByFqdn(ctx context.Context, fqdn string, nameservers []
|
|||
// NXDOMAIN
|
||||
default:
|
||||
// Any response code other than NOERROR and NXDOMAIN is treated as error
|
||||
return nil, &internal.DNSError{Message: fmt.Sprintf("unexpected response for '%s'", domain), MsgOut: r}
|
||||
return nil, &DNSError{Message: fmt.Sprintf("unexpected response for '%s'", domain), MsgOut: r}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, &internal.DNSError{Message: fmt.Sprintf("could not find the start of authority for '%s'", fqdn), MsgOut: r, Err: err}
|
||||
return nil, &DNSError{Message: fmt.Sprintf("could not find the start of authority for '%s'", fqdn), MsgOut: r, Err: err}
|
||||
}
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
package dns01
|
||||
package internal
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
|
@ -23,8 +23,8 @@ func ExtractCNAME(msg *dns.Msg, name string) string {
|
|||
return ""
|
||||
}
|
||||
|
||||
// MsgContainsCNAME checks for a CNAME answer in msg.
|
||||
func MsgContainsCNAME(msg *dns.Msg) bool {
|
||||
// msgContainsCNAME checks for a CNAME answer in msg.
|
||||
func msgContainsCNAME(msg *dns.Msg) bool {
|
||||
return slices.ContainsFunc(msg.Answer, func(rr dns.RR) bool {
|
||||
_, ok := rr.(*dns.CNAME)
|
||||
return ok
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
package dns01
|
||||
package internal
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
|
@ -8,7 +8,7 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_updateDomainWithCName_caseInsensitive(t *testing.T) {
|
||||
func Test_ExtractCNAME_caseInsensitive(t *testing.T) {
|
||||
qname := "_acme-challenge.uppercase-test.example.com."
|
||||
cnameTarget := "_acme-challenge.uppercase-test.cname-target.example.com."
|
||||
|
||||
|
|
@ -29,7 +29,7 @@ func Test_updateDomainWithCName_caseInsensitive(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
fqdn := updateDomainWithCName(msg, qname)
|
||||
fqdn := ExtractCNAME(msg, qname)
|
||||
|
||||
assert.Equal(t, cnameTarget, fqdn)
|
||||
}
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
package internal
|
||||
|
||||
import "github.com/miekg/dns"
|
||||
|
||||
func CreateDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(fqdn, rtype)
|
||||
m.SetEdns0(4096, false)
|
||||
|
||||
if !recursive {
|
||||
m.RecursionDesired = false
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
|
@ -1,8 +0,0 @@
|
|||
//go:build !windows
|
||||
|
||||
package internal
|
||||
|
||||
import "time"
|
||||
|
||||
// DNSTimeout is used as the default DNS timeout on Unix-like systems.
|
||||
const DNSTimeout = 10 * time.Second
|
||||
|
|
@ -1,8 +0,0 @@
|
|||
//go:build windows
|
||||
|
||||
package internal
|
||||
|
||||
import "time"
|
||||
|
||||
// DNSTimeout is used as the default DNS timeout on Windows.
|
||||
const DNSTimeout = 20 * time.Second
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
package dnspersist01
|
||||
package internal
|
||||
|
||||
import (
|
||||
"iter"
|
||||
|
|
@ -6,13 +6,8 @@ import (
|
|||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
/*
|
||||
* NOTE(ldez): This file is a duplication of `dns01/fqdn.go`.
|
||||
* The 2 files should be kept in sync.
|
||||
*/
|
||||
|
||||
// domainsSeq generates a sequence of domain names derived from a domain (FQDN or not) in descending order.
|
||||
func domainsSeq(fqdn string) iter.Seq[string] {
|
||||
// DomainsSeq generates a sequence of domain names derived from a domain (FQDN or not) in descending order.
|
||||
func DomainsSeq(fqdn string) iter.Seq[string] {
|
||||
return func(yield func(string) bool) {
|
||||
if fqdn == "" {
|
||||
return
|
||||
67
challenge/internal/fqdn_test.go
Normal file
67
challenge/internal/fqdn_test.go
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDomainsSeq(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
fqdn string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
desc: "empty",
|
||||
fqdn: "",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
desc: "empty FQDN",
|
||||
fqdn: ".",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
desc: "TLD FQDN",
|
||||
fqdn: "com",
|
||||
expected: []string{"com"},
|
||||
},
|
||||
{
|
||||
desc: "TLD",
|
||||
fqdn: "com.",
|
||||
expected: []string{"com."},
|
||||
},
|
||||
{
|
||||
desc: "2 levels",
|
||||
fqdn: "example.com",
|
||||
expected: []string{"example.com", "com"},
|
||||
},
|
||||
{
|
||||
desc: "2 levels FQDN",
|
||||
fqdn: "example.com.",
|
||||
expected: []string{"example.com.", "com."},
|
||||
},
|
||||
{
|
||||
desc: "3 levels",
|
||||
fqdn: "foo.example.com",
|
||||
expected: []string{"foo.example.com", "example.com", "com"},
|
||||
},
|
||||
{
|
||||
desc: "3 levels FQDN",
|
||||
fqdn: "foo.example.com.",
|
||||
expected: []string{"foo.example.com.", "example.com.", "com."},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
actual := slices.Collect(DomainsSeq(test.fqdn))
|
||||
|
||||
assert.Equal(t, test.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
12
challenge/internal/mock_test.go
Normal file
12
challenge/internal/mock_test.go
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func fakeNS(name, ns string) *dns.NS {
|
||||
return &dns.NS{
|
||||
Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 172800},
|
||||
Ns: ns,
|
||||
}
|
||||
}
|
||||
|
|
@ -1,55 +0,0 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/go-acme/lego/v5/challenge"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const DefaultResolvConf = "/etc/resolv.conf"
|
||||
|
||||
// GetNameservers attempts to get systems nameservers before falling back to the defaults.
|
||||
func GetNameservers(path string, stack challenge.NetworkStack) []string {
|
||||
config, err := dns.ClientConfigFromFile(path)
|
||||
if err == nil && len(config.Servers) > 0 {
|
||||
return config.Servers
|
||||
}
|
||||
|
||||
switch stack {
|
||||
case challenge.IPv4Only:
|
||||
return []string{
|
||||
"1.1.1.1:53",
|
||||
"1.0.0.1:53",
|
||||
}
|
||||
|
||||
case challenge.IPv6Only:
|
||||
return []string{
|
||||
"[2606:4700:4700::1111]:53",
|
||||
"[2606:4700:4700::1001]:53",
|
||||
}
|
||||
|
||||
default:
|
||||
return []string{
|
||||
"1.1.1.1:53",
|
||||
"1.0.0.1:53",
|
||||
"[2606:4700:4700::1111]:53",
|
||||
"[2606:4700:4700::1001]:53",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ParseNameservers(servers []string) []string {
|
||||
var resolvers []string
|
||||
|
||||
for _, resolver := range servers {
|
||||
// ensure all servers have a port number
|
||||
if _, _, err := net.SplitHostPort(resolver); err != nil {
|
||||
resolvers = append(resolvers, net.JoinHostPort(resolver, "53"))
|
||||
} else {
|
||||
resolvers = append(resolvers, resolver)
|
||||
}
|
||||
}
|
||||
|
||||
return resolvers
|
||||
}
|
||||
|
|
@ -1,83 +0,0 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/go-acme/lego/v5/challenge"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetNameservers(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
path string
|
||||
stack challenge.NetworkStack
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
desc: "with resolv.conf",
|
||||
path: "fixtures/resolv.conf.1",
|
||||
stack: challenge.DualStack,
|
||||
expected: []string{"10.200.3.249", "10.200.3.250:5353", "2001:4860:4860::8844", "[10.0.0.1]:5353"},
|
||||
},
|
||||
{
|
||||
desc: "with nonexistent resolv.conf",
|
||||
path: "fixtures/resolv.conf.nonexistant",
|
||||
stack: challenge.DualStack,
|
||||
expected: []string{"1.0.0.1:53", "1.1.1.1:53", "[2606:4700:4700::1001]:53", "[2606:4700:4700::1111]:53"},
|
||||
},
|
||||
{
|
||||
desc: "default with IPv4Only",
|
||||
path: "resolv.conf.nonexistant",
|
||||
stack: challenge.IPv4Only,
|
||||
expected: []string{"1.0.0.1:53", "1.1.1.1:53"},
|
||||
},
|
||||
{
|
||||
desc: "default with IPv6Only",
|
||||
path: "resolv.conf.nonexistant",
|
||||
stack: challenge.IPv6Only,
|
||||
expected: []string{"[2606:4700:4700::1001]:53", "[2606:4700:4700::1111]:53"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
result := GetNameservers(test.path, test.stack)
|
||||
|
||||
sort.Strings(result)
|
||||
sort.Strings(test.expected)
|
||||
|
||||
assert.Equal(t, test.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseNameservers(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
servers []string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
desc: "without explicit port",
|
||||
servers: []string{"ns1.example.com", "2001:db8::1"},
|
||||
expected: []string{"ns1.example.com:53", "[2001:db8::1]:53"},
|
||||
},
|
||||
{
|
||||
desc: "with explicit port",
|
||||
servers: []string{"ns1.example.com:53", "[2001:db8::1]:53"},
|
||||
expected: []string{"ns1.example.com:53", "[2001:db8::1]:53"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result := ParseNameservers(test.servers)
|
||||
|
||||
assert.Equal(t, test.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue