diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 33ca106cc..51aa685a8 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -44,10 +44,10 @@ jobs: install-only: true - name: Install Pebble - run: go install github.com/letsencrypt/pebble/v2/cmd/pebble@v2.9.0 + run: go install github.com/letsencrypt/pebble/v2/cmd/pebble@v2.10.0 - name: Install challtestsrv - run: go install github.com/letsencrypt/pebble/v2/cmd/pebble-challtestsrv@v2.9.0 + run: go install github.com/letsencrypt/pebble/v2/cmd/pebble-challtestsrv@v2.10.0 - name: Set up a Memcached server run: docker run -d --rm -p 11211:11211 memcached:1.6-alpine diff --git a/acme/commons.go b/acme/commons.go index e520299d8..b665ca6d6 100644 --- a/acme/commons.go +++ b/acme/commons.go @@ -281,7 +281,7 @@ type Challenge struct { // A challenge object with an error MUST have status equal to "invalid". Error *ProblemDetails `json:"error,omitempty"` - // token (required, string): + // token (required for dns-01, http-01, tlsalpn-01, string): // A random value that uniquely identifies the challenge. // This value MUST have at least 128 bits of entropy. // It MUST NOT contain any characters outside the base64url alphabet, @@ -291,6 +291,11 @@ type Challenge struct { // https://www.rfc-editor.org/rfc/rfc8555.html#section-8.4 Token string `json:"token"` + // issuer-domain-names (required for dns-persist-01, []string): + // A list of Issuer Domain Names used for dns-persist-01 challenges. + // https://www.ietf.org/archive/id/draft-ietf-acme-dns-persist-00.html#section-3.1 + IssuerDomainNames []string `json:"issuer-domain-names,omitempty"` + // https://www.rfc-editor.org/rfc/rfc8555.html#section-8.1 KeyAuthorization string `json:"keyAuthorization"` } diff --git a/challenge/challenges.go b/challenge/challenges.go index 8115fa851..562557a94 100644 --- a/challenge/challenges.go +++ b/challenge/challenges.go @@ -18,6 +18,9 @@ const ( // Note: GetRecord returns a DNS record which will fulfill this challenge. DNS01 = Type("dns-01") + // DNSPersist01 is the "dns-persist-01" ACME challenge https://datatracker.ietf.org/doc/draft-ietf-acme-dns-persist. + DNSPersist01 = Type("dns-persist-01") + // TLSALPN01 is the "tls-alpn-01" ACME challenge https://www.rfc-editor.org/rfc/rfc8737.html TLSALPN01 = Type("tls-alpn-01") ) diff --git a/challenge/dns01/dns_challenge.go b/challenge/dns01/dns_challenge.go index 20686addd..91e555ff2 100644 --- a/challenge/dns01/dns_challenge.go +++ b/challenge/dns01/dns_challenge.go @@ -71,7 +71,7 @@ func (c *Challenge) PreSolve(ctx context.Context, authz acme.Authorization) erro } if c.provider == nil { - return fmt.Errorf("[%s] acme: no DNS Provider configured", domain) + return fmt.Errorf("[%s] acme: no DNS Provider configured for DNS-01", domain) } // Generate the Key Authorization for the challenge @@ -82,7 +82,7 @@ func (c *Challenge) PreSolve(ctx context.Context, authz acme.Authorization) erro err = c.provider.Present(ctx, authz.Identifier.Value, chlng.Token, keyAuth) if err != nil { - return fmt.Errorf("[%s] acme: error presenting token: %w", domain, err) + return fmt.Errorf("[%s] acme: error presenting token for DNS-01: %w", domain, err) } return nil @@ -114,7 +114,7 @@ func (c *Challenge) Solve(ctx context.Context, authz acme.Authorization) error { timeout, interval = DefaultPropagationTimeout, DefaultPollingInterval } - log.Info("acme: waiting for DNS record propagation.", + log.Info("acme: waiting for DNS-01 record propagation.", log.DomainAttr(domain), slog.String("nameservers", strings.Join(DefaultClient().recursiveNameservers, ",")), ) @@ -124,7 +124,7 @@ func (c *Challenge) Solve(ctx context.Context, authz acme.Authorization) error { err = wait.For("propagation", timeout, interval, func() (bool, error) { stop, errP := c.preCheck.call(ctx, domain, info.EffectiveFQDN, info.Value) if !stop || errP != nil { - log.Info("acme: waiting for DNS record propagation.", log.DomainAttr(domain)) + log.Info("acme: waiting for DNS-01 record propagation.", log.DomainAttr(domain)) } return stop, errP diff --git a/challenge/dnspersist01/dns_persist_challenge.go b/challenge/dnspersist01/dns_persist_challenge.go new file mode 100644 index 000000000..0548d4ed9 --- /dev/null +++ b/challenge/dnspersist01/dns_persist_challenge.go @@ -0,0 +1,448 @@ +package dnspersist01 + +import ( + "context" + "errors" + "fmt" + "log/slog" + "slices" + "sort" + "strings" + "time" + + "github.com/go-acme/lego/v5/acme" + "github.com/go-acme/lego/v5/acme/api" + "github.com/go-acme/lego/v5/challenge" + "github.com/go-acme/lego/v5/log" + "github.com/go-acme/lego/v5/platform/wait" + "github.com/miekg/dns" +) + +const validationLabel = "_validation-persist" + +const ( + // DefaultPropagationTimeout default propagation timeout. + DefaultPropagationTimeout = 60 * time.Second + + // DefaultPollingInterval default polling interval. + DefaultPollingInterval = 2 * time.Second +) + +// ValidateFunc validates a challenge with the ACME server. +type ValidateFunc func(ctx context.Context, core *api.Core, domain string, chlng acme.Challenge) error + +// ChallengeOption configures the dns-persist-01 challenge. +type ChallengeOption func(*Challenge) error + +// ChallengeInfo contains the information used to create a dns-persist-01 TXT +// record. +type ChallengeInfo struct { + // FQDN is the full-qualified challenge domain (i.e. + // `_validation-persist.[domain].`). + FQDN string + + // Value contains the TXT record value, an RFC 8659 issue-value. + Value string + + // IssuerDomainName is the normalized issuer-domain-name used in Value. + IssuerDomainName string +} + +// Challenge implements the dns-persist-01 challenge exclusively with manual +// instructions for TXT record creation. +type Challenge struct { + core *api.Core + validate ValidateFunc + resolver *Resolver + preCheck preCheck + + accountURI string + userSuppliedIssuerDomainName string + persistUntil time.Time + recursiveNameservers []string + authoritativeNSPort string + + propagationTimeout time.Duration + propagationInterval time.Duration +} + +// NewChallenge creates a dns-persist-01 challenge. +func NewChallenge(core *api.Core, validate ValidateFunc, opts ...ChallengeOption) (*Challenge, error) { + chlg := &Challenge{ + core: core, + validate: validate, + resolver: NewResolver(nil), + preCheck: newPreCheck(), + recursiveNameservers: DefaultNameservers(), + authoritativeNSPort: defaultAuthoritativeNSPort, + + propagationTimeout: DefaultPropagationTimeout, + propagationInterval: DefaultPollingInterval, + } + + for _, opt := range opts { + err := opt(chlg) + if err != nil { + return nil, fmt.Errorf("dnspersist01: %w", err) + } + } + + if chlg.accountURI == "" { + return nil, errors.New("dnspersist01: account URI cannot be empty") + } + + return chlg, nil +} + +// CondOptions Conditional challenge options. +func CondOptions(condition bool, opt ...ChallengeOption) ChallengeOption { + if !condition { + // NoOp options + return func(*Challenge) error { + return nil + } + } + + return func(chlg *Challenge) error { + for _, opt := range opt { + err := opt(chlg) + if err != nil { + return err + } + } + + return nil + } +} + +// WithResolver overrides the resolver used for DNS lookups. +func WithResolver(resolver *Resolver) ChallengeOption { + return func(chlg *Challenge) error { + if resolver == nil { + return errors.New("resolver is nil") + } + + chlg.resolver = resolver + + return nil + } +} + +// WithNameservers overrides resolver nameservers using the default timeout. +func WithNameservers(nameservers []string) ChallengeOption { + return func(chlg *Challenge) error { + chlg.resolver = NewResolver(nameservers) + + return nil + } +} + +// WithDNSTimeout overrides the default DNS resolver timeout. +func WithDNSTimeout(timeout time.Duration) ChallengeOption { + return func(chlg *Challenge) error { + if chlg.resolver == nil { + chlg.resolver = NewResolver(nil) + } + + chlg.resolver.Timeout = timeout + + return nil + } +} + +// WithAccountURI sets the ACME account URI bound to dns-persist-01 records. It +// is required both to construct the `accounturi=` parameter and to match +// already-provisioned TXT records that should be updated. +func WithAccountURI(accountURI string) ChallengeOption { + return func(chlg *Challenge) error { + if accountURI == "" { + return errors.New("ACME account URI cannot be empty") + } + + chlg.accountURI = accountURI + + return nil + } +} + +// WithIssuerDomainName forces the issuer-domain-name used for dns-persist-01. +// When set, it overrides automatic issuer selection and must match one of the +// issuer-domain-names offered in the ACME challenge. User input is normalized +// and validated at configuration time. +func WithIssuerDomainName(issuerDomainName string) ChallengeOption { + return func(chlg *Challenge) error { + if issuerDomainName == "" { + return nil + } + + normalized, err := normalizeUserSuppliedIssuerDomainName(issuerDomainName) + if err != nil { + return err + } + + err = validateIssuerDomainName(normalized) + if err != nil { + return err + } + + chlg.userSuppliedIssuerDomainName = normalized + + return nil + } +} + +// WithPersistUntil sets the optional persistUntil value used when constructing +// dns-persist-01 TXT records. +func WithPersistUntil(persistUntil time.Time) ChallengeOption { + return func(chlg *Challenge) error { + if persistUntil.IsZero() { + return errors.New("persistUntil cannot be zero") + } + + chlg.persistUntil = persistUntil.UTC().Truncate(time.Second) + + return nil + } +} + +// WithPropagationTimeout overrides the propagation timeout duration. +func WithPropagationTimeout(timeout time.Duration) ChallengeOption { + return func(chlg *Challenge) error { + if timeout <= 0 { + return errors.New("propagation timeout must be positive") + } + + chlg.propagationTimeout = timeout + + return nil + } +} + +// WithPropagationInterval overrides the propagation polling interval. +func WithPropagationInterval(interval time.Duration) ChallengeOption { + return func(chlg *Challenge) error { + if interval <= 0 { + return errors.New("propagation interval must be positive") + } + + chlg.propagationInterval = interval + + return nil + } +} + +// Solve validates the dns-persist-01 challenge by prompting the user to create +// the required TXT record (if necessary) then performing propagation checks (or +// a wait-only delay) before notifying the ACME server. +// +//nolint:gocyclo // challenge flow has several required branches (reuse/manual/wait/propagation/validate). +func (c *Challenge) Solve(ctx context.Context, authz acme.Authorization) error { + if c.resolver == nil { + return errors.New("dnspersist01: resolver is nil") + } + + domain := authz.Identifier.Value + if domain == "" { + return errors.New("dnspersist01: empty identifier") + } + + chlng, err := challenge.FindChallenge(challenge.DNSPersist01, authz) + if err != nil { + return err + } + + err = validateIssuerDomainNames(chlng) + if err != nil { + return fmt.Errorf("dnspersist01: %w", err) + } + + fqdn := GetAuthorizationDomainName(domain) + + result, err := c.resolver.LookupTXT(fqdn) + if err != nil { + return err + } + + issuerDomainName, err := c.selectIssuerDomainName(chlng.IssuerDomainNames, result.Records, authz.Wildcard) + if err != nil { + return fmt.Errorf("dnspersist01: %w", err) + } + + matcher := func(records []TXTRecord) bool { + return c.hasMatchingRecord(records, issuerDomainName, authz.Wildcard) + } + + if !matcher(result.Records) { + info, infoErr := GetChallengeInfo(domain, issuerDomainName, c.accountURI, authz.Wildcard, c.persistUntil) + if infoErr != nil { + return infoErr + } + + displayRecordCreationInstructions(info.FQDN, info.Value) + + waitErr := waitForUser() + if waitErr != nil { + return waitErr + } + } else { + fmt.Printf("dnspersist01: Found existing matching TXT record for %s, no need to create a new one\n", fqdn) + } + + timeout := c.propagationTimeout + interval := c.propagationInterval + + log.Info("acme: Checking DNS-PERSIST-01 record propagation.", + log.DomainAttr(domain), slog.String("nameservers", strings.Join(c.getRecursiveNameservers(), ",")), + ) + + time.Sleep(interval) + + err = wait.For("propagation", timeout, interval, func() (bool, error) { + ok, callErr := c.preCheck.call(domain, fqdn, matcher, c.checkDNSPropagation) + if !ok || callErr != nil { + log.Info("acme: Waiting for DNS-PERSIST-01 record propagation.", log.DomainAttr(domain)) + } + + return ok, callErr + }) + if err != nil { + return err + } + + return c.validate(ctx, c.core, domain, chlng) +} + +func (c *Challenge) getRecursiveNameservers() []string { + if c == nil || len(c.recursiveNameservers) == 0 { + return DefaultNameservers() + } + + return slices.Clone(c.recursiveNameservers) +} + +// GetAuthorizationDomainName returns the fully-qualified DNS label used by the +// dns-persist-01 challenge for the given domain. +func GetAuthorizationDomainName(domain string) string { + return dns.Fqdn(validationLabel + "." + domain) +} + +// GetChallengeInfo returns information used to create a DNS TXT record which +// can fulfill the `dns-persist-01` challenge. Domain, issuerDomainName, and +// accountURI parameters are required. Wildcard and persistUntil parameters are +// optional. +func GetChallengeInfo(domain, issuerDomainName, accountURI string, wildcard bool, persistUntil time.Time) (ChallengeInfo, error) { + if domain == "" { + return ChallengeInfo{}, errors.New("dnspersist01: domain cannot be empty") + } + + value, err := BuildIssueValue(issuerDomainName, accountURI, wildcard, persistUntil) + if err != nil { + return ChallengeInfo{}, err + } + + return ChallengeInfo{ + FQDN: GetAuthorizationDomainName(domain), + Value: value, + IssuerDomainName: issuerDomainName, + }, nil +} + +// validateIssuerDomainNames validates the ACME challenge "issuer-domain-names" +// array for dns-persist-01. +// +// Rules enforced: +// - The array is required and must contain at least 1 entry. +// - The array must not contain more than 10 entries; larger arrays are +// treated as malformed challenges and rejected. +// +// Each issuer-domain-name must be a normalized domain name: +// - represented in A-label (Punycode, RFC5890) form +// - all lowercase +// - no trailing dot +// - maximum total length of 253 octets +// +// The returned list is intended for issuer selection when constructing or +// matching dns-persist-01 TXT records. The challenge can be satisfied by using +// any one valid issuer-domain-name from this list. +func validateIssuerDomainNames(chlng acme.Challenge) error { + if len(chlng.IssuerDomainNames) == 0 { + return errors.New("issuer-domain-names missing from the challenge") + } + + if len(chlng.IssuerDomainNames) > 10 { + return errors.New(" issuer-domain-names exceeds maximum length of 10") + } + + for _, issuerDomainName := range chlng.IssuerDomainNames { + err := validateIssuerDomainName(issuerDomainName) + if err != nil { + return err + } + } + + return nil +} + +// selectIssuerDomainName selects the issuer-domain-name to use for a +// dns-persist-01 challenge. If the user has supplied an issuer-domain-name, it +// is used after verifying that it is offered by the ACME challenge. Otherwise, +// the first issuer-domain-name with a matching TXT record is selected. If no +// issuer-domain-name has a matching TXT record, a deterministic default +// issuer-domain-name is selected using lexicographic ordering. +func (c *Challenge) selectIssuerDomainName(challIssuers []string, records []TXTRecord, wildcard bool) (string, error) { + if len(challIssuers) == 0 { + return "", errors.New("issuer-domain-names missing from the challenge") + } + + sortedIssuers := slices.Clone(challIssuers) + sort.Strings(sortedIssuers) + + if c.userSuppliedIssuerDomainName != "" { + if !slices.Contains(sortedIssuers, c.userSuppliedIssuerDomainName) { + return "", fmt.Errorf("provided issuer-domain-name %q not offered by the challenge", c.userSuppliedIssuerDomainName) + } + + return c.userSuppliedIssuerDomainName, nil + } + + for _, issuerDomainName := range sortedIssuers { + if c.hasMatchingRecord(records, issuerDomainName, wildcard) { + return issuerDomainName, nil + } + } + + return sortedIssuers[0], nil +} + +func (c *Challenge) hasMatchingRecord(records []TXTRecord, issuerDomainName string, wildcard bool) bool { + for _, record := range records { + parsed, err := ParseIssueValue(record.Value) + if err != nil { + continue + } + + if parsed.IssuerDomainName != issuerDomainName { + continue + } + + if parsed.AccountURI != c.accountURI { + continue + } + + if wildcard && !strings.EqualFold(parsed.Policy, policyWildcard) { + continue + } + + if c.persistUntil.IsZero() { + if !parsed.PersistUntil.IsZero() { + continue + } + } else if parsed.PersistUntil.IsZero() || !parsed.PersistUntil.Equal(c.persistUntil) { + continue + } + + return true + } + + return false +} diff --git a/challenge/dnspersist01/dns_persist_challenge_manual.go b/challenge/dnspersist01/dns_persist_challenge_manual.go new file mode 100644 index 000000000..00591b9d8 --- /dev/null +++ b/challenge/dnspersist01/dns_persist_challenge_manual.go @@ -0,0 +1,61 @@ +package dnspersist01 + +import ( + "bufio" + "fmt" + "os" + "strings" +) + +func displayRecordCreationInstructions(fqdn, value string) { + fmt.Printf("dnspersist01: Please create a TXT record with the following value:\n") + fmt.Printf("%s IN TXT %s\n", fqdn, formatTXTValue(value)) + fmt.Printf("dnspersist01: Press 'Enter' once the record is available\n") +} + +// formatTXTValue formats a TXT record value for display, splitting it into +// multiple quoted strings if it exceeds 255 octets, as per RFC 1035. +func formatTXTValue(value string) string { + chunks := splitTXTValue(value) + if len(chunks) == 1 { + return fmt.Sprintf("%q", chunks[0]) + } + + parts := make([]string, 0, len(chunks)) + for _, chunk := range chunks { + parts = append(parts, fmt.Sprintf("%q", chunk)) + } + + return strings.Join(parts, " ") +} + +// splitTXTValue splits a TXT value into RFC 1035 chunks of +// at most 255 octets so long TXT values can be represented as multiple strings +// in one RR. +func splitTXTValue(value string) []string { + const maxTXTStringOctets = 255 + if len(value) <= maxTXTStringOctets { + return []string{value} + } + + var chunks []string + for len(value) > maxTXTStringOctets { + chunks = append(chunks, value[:maxTXTStringOctets]) + value = value[maxTXTStringOctets:] + } + + if value != "" { + chunks = append(chunks, value) + } + + return chunks +} + +func waitForUser() error { + _, err := bufio.NewReader(os.Stdin).ReadBytes('\n') + if err != nil { + return fmt.Errorf("dnspersist01: %w", err) + } + + return nil +} diff --git a/challenge/dnspersist01/dns_persist_challenge_manual_test.go b/challenge/dnspersist01/dns_persist_challenge_manual_test.go new file mode 100644 index 000000000..a08e37b60 --- /dev/null +++ b/challenge/dnspersist01/dns_persist_challenge_manual_test.go @@ -0,0 +1,40 @@ +package dnspersist01 + +import ( + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_formatTXTValue(t *testing.T) { + longValue := strings.Repeat("z", 256) + + testCases := []struct { + desc string + value string + expected string + }{ + { + desc: "single quoted string", + value: "abc", + expected: `"abc"`, + }, + { + desc: "split and quoted across chunks", + value: longValue, + expected: fmt.Sprintf("%q %q", longValue[:255], longValue[255:]), + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + actual := formatTXTValue(test.value) + + assert.Equal(t, test.expected, actual) + }) + } +} diff --git a/challenge/dnspersist01/dns_persist_challenge_test.go b/challenge/dnspersist01/dns_persist_challenge_test.go new file mode 100644 index 000000000..7162bb3d4 --- /dev/null +++ b/challenge/dnspersist01/dns_persist_challenge_test.go @@ -0,0 +1,449 @@ +package dnspersist01 + +import ( + "strings" + "testing" + "time" + + "github.com/go-acme/lego/v5/acme" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetChallengeInfo(t *testing.T) { + testCases := []struct { + desc string + domain string + issuerDomainName string + accountURI string + wildcard bool + persistUntil time.Time + expected ChallengeInfo + expectErr string + }{ + { + desc: "basic", + domain: "example.com", + issuerDomainName: "authority.example", + accountURI: "https://ca.example/acct/123", + expected: ChallengeInfo{ + FQDN: "_validation-persist.example.com.", + Value: "authority.example; accounturi=https://ca.example/acct/123", + IssuerDomainName: "authority.example", + }, + }, + { + desc: "subdomain", + domain: "api.example.com", + issuerDomainName: "authority.example", + accountURI: "https://ca.example/acct/123", + expected: ChallengeInfo{ + FQDN: "_validation-persist.api.example.com.", + Value: "authority.example; accounturi=https://ca.example/acct/123", + IssuerDomainName: "authority.example", + }, + }, + { + desc: "wildcard with normalized issuer", + domain: "example.com", + issuerDomainName: "authority.example", + accountURI: "https://ca.example/acct/123", + wildcard: true, + expected: ChallengeInfo{ + FQDN: "_validation-persist.example.com.", + Value: "authority.example; accounturi=https://ca.example/acct/123; policy=wildcard", + IssuerDomainName: "authority.example", + }, + }, + { + desc: "uppercase issuer is rejected", + domain: "example.com", + issuerDomainName: "Authority.Example.", + accountURI: "https://ca.example/acct/123", + expectErr: "issuer-domain-name must be lowercase", + }, + { + desc: "unicode issuer is rejected", + domain: "example.com", + issuerDomainName: "bücher.example", + accountURI: "https://ca.example/acct/123", + expectErr: "must be a lowercase LDH label", + }, + { + desc: "issuer with trailing dot is rejected", + domain: "example.com", + issuerDomainName: "authority.example.", + accountURI: "https://ca.example/acct/123", + expectErr: "issuer-domain-name must not have a trailing dot", + }, + { + desc: "issuer with empty label is rejected", + domain: "example.com", + issuerDomainName: "authority..example", + accountURI: "https://ca.example/acct/123", + expectErr: "issuer-domain-name contains an empty label", + }, + { + desc: "issuer label length over 63 is rejected", + domain: "example.com", + issuerDomainName: strings.Repeat("a", 64) + ".example", + accountURI: "https://ca.example/acct/123", + expectErr: "issuer-domain-name label exceeds the maximum length of 63 octets", + }, + { + desc: "issuer with malformed punycode a-label is rejected", + domain: "example.com", + issuerDomainName: "xn--a.example", + accountURI: "https://ca.example/acct/123", + expectErr: "issuer-domain-name must be represented in A-label format:", + }, + { + desc: "includes persistUntil", + domain: "example.com", + issuerDomainName: "authority.example", + accountURI: "https://ca.example/acct/123", + wildcard: true, + persistUntil: time.Unix(4102444800, 0).UTC(), + expected: ChallengeInfo{ + FQDN: "_validation-persist.example.com.", + Value: "authority.example; accounturi=https://ca.example/acct/123; policy=wildcard; persistUntil=4102444800", + IssuerDomainName: "authority.example", + }, + }, + { + desc: "empty domain", + domain: "", + issuerDomainName: "authority.example", + accountURI: "https://ca.example/acct/123", + expectErr: "domain cannot be empty", + }, + { + desc: "empty account uri", + domain: "example.com", + issuerDomainName: "authority.example", + accountURI: "", + expectErr: "ACME account URI cannot be empty", + }, + { + desc: "invalid issuer", + domain: "example.com", + issuerDomainName: "ca_.example", + accountURI: "https://ca.example/acct/123", + expectErr: "must be a lowercase LDH label", + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + actual, err := GetChallengeInfo(test.domain, test.issuerDomainName, test.accountURI, test.wildcard, test.persistUntil) + if test.expectErr != "" { + require.Error(t, err) + assert.ErrorContains(t, err, test.expectErr) + + return + } + + require.NoError(t, err) + + assert.Equal(t, test.expected, actual) + }) + } +} + +func TestValidateIssuerDomainNames(t *testing.T) { + testCases := []struct { + desc string + issuers []string + assert assert.ErrorAssertionFunc + }{ + { + desc: "missing issuers", + assert: assert.Error, + }, + { + desc: "too many issuers", + issuers: []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11"}, + assert: assert.Error, + }, + { + desc: "valid issuer", + issuers: []string{"ca.example"}, + assert: assert.NoError, + }, + { + desc: "issuer all uppercase", + issuers: []string{"CA.EXAMPLE"}, + assert: assert.Error, + }, + { + desc: "issuer contains underscore", + issuers: []string{"ca_.example"}, + assert: assert.Error, + }, + { + desc: "issuer not in A-label format", + issuers: []string{"bücher.example"}, + assert: assert.Error, + }, + { + desc: "issuer too long", + issuers: []string{strings.Repeat("a", 63) + "." + strings.Repeat("b", 63) + "." + strings.Repeat("c", 63) + "." + strings.Repeat("d", 63)}, + assert: assert.Error, + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + err := validateIssuerDomainNames(acme.Challenge{IssuerDomainNames: test.issuers}) + test.assert(t, err) + }) + } +} + +func TestWithIssuerDomainName(t *testing.T) { + testCases := []struct { + desc string + input string + expected string + expectErr bool + }{ + { + desc: "normalizes uppercase and trailing dot", + input: "CA.EXAMPLE.", + expected: "ca.example", + }, + { + desc: "normalizes idna issuer", + input: "BÜCHER.example", + expected: "xn--bcher-kva.example", + }, + { + desc: "rejects invalid issuer", + input: "ca_.example", + expectErr: true, + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + chlg := &Challenge{} + + err := WithIssuerDomainName(test.input)(chlg) + if test.expectErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, test.expected, chlg.userSuppliedIssuerDomainName) + }) + } +} + +func TestChallenge_selectIssuerDomainName(t *testing.T) { + testCases := []struct { + desc string + issuers []string + records []TXTRecord + wildcard bool + overrideIssuerDomainName string + expectIssuerDomainName string + expectErr bool + }{ + { + desc: "default uses sorted first", + issuers: []string{"ca.example", "backup.example"}, + expectIssuerDomainName: "backup.example", + }, + { + desc: "default prefers existing matching record", + issuers: []string{ + "ca.example", "backup.example", + }, + records: []TXTRecord{ + {Value: mustChallengeValue(t, "ca.example", "https://authority.example/acct/123", false, time.Time{})}, + }, + expectIssuerDomainName: "ca.example", + }, + { + desc: "override still wins over matching existing record", + issuers: []string{ + "ca.example", "backup.example", + }, + records: []TXTRecord{ + {Value: mustChallengeValue(t, "ca.example", "https://authority.example/acct/123", false, time.Time{})}, + }, + overrideIssuerDomainName: "backup.example", + expectIssuerDomainName: "backup.example", + }, + { + desc: "override not offered in challenge", + issuers: []string{"ca.example"}, + overrideIssuerDomainName: "other.example", + expectErr: true, + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + chlg := &Challenge{ + accountURI: "https://authority.example/acct/123", + userSuppliedIssuerDomainName: test.overrideIssuerDomainName, + } + + issuer, err := chlg.selectIssuerDomainName(test.issuers, test.records, test.wildcard) + if test.expectErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, test.expectIssuerDomainName, issuer) + }) + } +} + +func TestChallenge_hasMatchingRecord(t *testing.T) { + expiredPersistUntil := time.Unix(1700000000, 0).UTC() + futurePersistUntil := time.Unix(4102444800, 0).UTC() + + testCases := []struct { + desc string + records []TXTRecord + issuer string + wildcard bool + requiredPersistUTC time.Time + assert assert.BoolAssertionFunc + }{ + { + desc: "match basic", + records: []TXTRecord{{Value: mustChallengeValue(t, "ca.example", "acc", false, time.Time{})}}, + issuer: "ca.example", + assert: assert.True, + }, + { + desc: "issuer mismatch", + records: []TXTRecord{{Value: mustChallengeValue(t, "other.example", "acc", false, time.Time{})}}, + issuer: "ca.example", + assert: assert.False, + }, + { + desc: "account mismatch", + records: []TXTRecord{{Value: mustChallengeValue(t, "ca.example", "other", false, time.Time{})}}, + issuer: "ca.example", + assert: assert.False, + }, + { + desc: "wildcard requires policy", + records: []TXTRecord{{Value: mustChallengeValue(t, "ca.example", "acc", false, time.Time{})}}, + issuer: "ca.example", + wildcard: true, + assert: assert.False, + }, + { + desc: "wildcard match", + records: []TXTRecord{{Value: mustChallengeValue(t, "ca.example", "acc", true, time.Time{})}}, + issuer: "ca.example", + wildcard: true, + assert: assert.True, + }, + { + desc: "policy wildcard allowed for non-wildcard", + records: []TXTRecord{{Value: mustChallengeValue(t, "ca.example", "acc", true, time.Time{})}}, + issuer: "ca.example", + wildcard: false, + assert: assert.True, + }, + { + desc: "matching malformed and matching valid record succeeds", + records: []TXTRecord{ + {Value: "ca.example;accounturi=acc;accounturi=other"}, + {Value: "ca.example;accounturi=acc"}, + }, + issuer: "ca.example", + assert: assert.True, + }, + { + desc: "wildcard accepts case-insensitive policy value", + records: []TXTRecord{{Value: "ca.example;accounturi=acc;policy=wIlDcArD"}}, + issuer: "ca.example", + wildcard: true, + assert: assert.True, + }, + { + desc: "wildcard policy mismatch is not a match", + records: []TXTRecord{{Value: "ca.example;accounturi=acc;policy=notwildcard"}}, + issuer: "ca.example", + wildcard: true, + assert: assert.False, + }, + { + desc: "persistUntil present without requirement is not a match", + records: []TXTRecord{{Value: mustChallengeValue(t, "ca.example", "acc", false, expiredPersistUntil)}}, + issuer: "ca.example", + assert: assert.False, + }, + { + desc: "future persistUntil without requirement is not a match", + records: []TXTRecord{{Value: mustChallengeValue(t, "ca.example", "acc", false, futurePersistUntil)}}, + issuer: "ca.example", + assert: assert.False, + }, + { + desc: "required persistUntil matches", + records: []TXTRecord{{Value: "ca.example;accounturi=acc;persistUntil=4102444800"}}, + issuer: "ca.example", + requiredPersistUTC: time.Unix(4102444800, 0).UTC(), + assert: assert.True, + }, + { + desc: "required persistUntil matches even when expired", + records: []TXTRecord{{Value: mustChallengeValue(t, "ca.example", "acc", false, expiredPersistUntil)}}, + issuer: "ca.example", + requiredPersistUTC: expiredPersistUntil, + assert: assert.True, + }, + { + desc: "required persistUntil mismatch", + records: []TXTRecord{{Value: "ca.example;accounturi=acc;persistUntil=4102444801"}}, + issuer: "ca.example", + requiredPersistUTC: time.Unix(4102444800, 0).UTC(), + assert: assert.False, + }, + { + desc: "required persistUntil missing", + records: []TXTRecord{{Value: "ca.example;accounturi=acc"}}, + issuer: "ca.example", + requiredPersistUTC: time.Unix(4102444800, 0).UTC(), + assert: assert.False, + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + chlg := &Challenge{ + accountURI: "acc", + persistUntil: test.requiredPersistUTC, + } + + match := chlg.hasMatchingRecord(test.records, test.issuer, test.wildcard) + + test.assert(t, match) + }) + } +} + +func mustChallengeValue(t *testing.T, issuerDomainName, accountURI string, wildcard bool, persistUntil time.Time) string { + t.Helper() + + info, err := GetChallengeInfo("example.com", issuerDomainName, accountURI, wildcard, persistUntil) + require.NoError(t, err) + + return info.Value +} diff --git a/challenge/dnspersist01/issue_values.go b/challenge/dnspersist01/issue_values.go new file mode 100644 index 000000000..7aa5ef2e1 --- /dev/null +++ b/challenge/dnspersist01/issue_values.go @@ -0,0 +1,157 @@ +package dnspersist01 + +import ( + "errors" + "fmt" + "strconv" + "strings" + "time" +) + +const ( + policyWildcard = "wildcard" + paramAccountURI = "accounturi" + paramPolicy = "policy" + paramPersistUntil = "persistuntil" +) + +// IssueValue represents a parsed dns-persist-01 issue-value. +type IssueValue struct { + IssuerDomainName string + AccountURI string + Policy string + PersistUntil time.Time +} + +// BuildIssueValue constructs an RFC 8659 issue-value for a dns-persist-01 TXT +// record. issuerDomainName and accountURI are required. wildcard and +// persistUntil are optional. +func BuildIssueValue(issuerDomainName, accountURI string, wildcard bool, persistUntil time.Time) (string, error) { + if accountURI == "" { + return "", errors.New("dnspersist01: ACME account URI cannot be empty") + } + + err := validateIssuerDomainName(issuerDomainName) + if err != nil { + return "", fmt.Errorf("dnspersist01: %w", err) + } + + value := issuerDomainName + "; " + paramAccountURI + "=" + accountURI + + if wildcard { + value += "; " + paramPolicy + "=" + policyWildcard + } + + if !persistUntil.IsZero() { + value += fmt.Sprintf("; persistUntil=%d", persistUntil.UTC().Unix()) + } + + return value, nil +} + +// trimWSP trims RFC 5234 WSP (SP / HTAB) characters from both ends of a +// string, as referenced by RFC 8659. +func trimWSP(s string) string { + return strings.TrimFunc(s, func(r rune) bool { + return r == ' ' || r == '\t' + }) +} + +// ParseIssueValue parses the issuer-domain-name and parameters for an RFC +// 8659 issue-value TXT record and returns the extracted fields. It returns +// an error if any portion of the value is malformed. +// +//nolint:gocyclo // parsing and validating tagged parameters requires branching +func ParseIssueValue(value string) (IssueValue, error) { + fields := strings.Split(value, ";") + + issuerDomainName := trimWSP(fields[0]) + if issuerDomainName == "" { + return IssueValue{}, errors.New("missing issuer-domain-name") + } + + parsed := IssueValue{ + IssuerDomainName: issuerDomainName, + } + + // Parse parameters (with optional surrounding WSP). + seenTags := map[string]bool{} + + for _, raw := range fields[1:] { + part := trimWSP(raw) + if part == "" { + return IssueValue{}, errors.New("empty parameter or trailing semicolon provided") + } + + // Capture each tag=value pair. + tag, val, found := strings.Cut(part, "=") + if !found { + return IssueValue{}, fmt.Errorf("malformed parameter %q should be tag=value pair", part) + } + + tag = trimWSP(tag) + val = trimWSP(val) + + if tag == "" { + return IssueValue{}, fmt.Errorf("malformed parameter %q, empty tag", part) + } + + canonicalTag := strings.ToLower(tag) + if seenTags[canonicalTag] { + return IssueValue{}, fmt.Errorf("duplicate parameter %q", tag) + } + + seenTags[canonicalTag] = true + // Ensure values contain no whitespace/control/non-ASCII characters. + for _, r := range val { + if (r >= 0x21 && r <= 0x3A) || (r >= 0x3C && r <= 0x7E) { + continue + } + + return IssueValue{}, fmt.Errorf("malformed value %q for tag %q", val, tag) + } + + // Finally, capture expected tag values. + // + // Note: according to RFC 8659 matching of tags is case insensitive. + switch canonicalTag { + case paramAccountURI: + if val == "" { + return IssueValue{}, fmt.Errorf("empty value provided for mandatory %q", paramAccountURI) + } + + parsed.AccountURI = val + + case paramPolicy: + // Per the dns-persist-01 specification, if the policy tag is + // present parameter's tag and defined values MUST be treated as + // case-insensitive. + if val != "" && !strings.EqualFold(val, policyWildcard) { + // If the policy parameter's value is anything other than + // "wildcard", the a CA MUST proceed as if the policy parameter + // were not present. + val = "" + } + + parsed.Policy = val + + case paramPersistUntil: + ts, err := strconv.ParseInt(val, 10, 64) + if err != nil { + return IssueValue{}, fmt.Errorf("malformed %q: %w", paramPersistUntil, err) + } + + parsed.PersistUntil = time.Unix(ts, 0).UTC() + + default: + // Unknown parameters are permitted but not currently consumed. + } + } + + return parsed, nil +} + +// Pointer returns a pointer to v. +// TODO(ldez) factorize. +// TODO(ldez) it must be replaced with the builtin 'new' function when min Go 1.26. +func Pointer[T any](v T) *T { return &v } diff --git a/challenge/dnspersist01/issue_values_test.go b/challenge/dnspersist01/issue_values_test.go new file mode 100644 index 000000000..253710204 --- /dev/null +++ b/challenge/dnspersist01/issue_values_test.go @@ -0,0 +1,208 @@ +package dnspersist01 + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBuildIssueValue(t *testing.T) { + testCases := []struct { + desc string + issuer string + accountURI string + wildcard bool + persistUTC time.Time + expect string + expectErrContains string + }{ + { + desc: "basic", + issuer: "authority.example", + accountURI: "https://authority.example/acct/123", + expect: "authority.example; accounturi=https://authority.example/acct/123", + }, + { + desc: "with persistUntil", + issuer: "authority.example", + accountURI: "https://authority.example/acct/123", + wildcard: true, + persistUTC: time.Unix(4102444800, 0).UTC(), + expect: "authority.example; accounturi=https://authority.example/acct/123; policy=wildcard; persistUntil=4102444800", + }, + { + desc: "missing account uri", + issuer: "authority.example", + expectErrContains: "ACME account URI cannot be empty", + }, + { + desc: "invalid issuer", + issuer: "Authority.Example.", + accountURI: "https://authority.example/acct/123", + expectErrContains: "issuer-domain-name must be lowercase", + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + actual, err := BuildIssueValue(test.issuer, test.accountURI, test.wildcard, test.persistUTC) + if test.expectErrContains != "" { + require.Error(t, err) + assert.ErrorContains(t, err, test.expectErrContains) + + return + } + + require.NoError(t, err) + assert.Equal(t, test.expect, actual) + }) + } +} + +func TestParseIssueValue(t *testing.T) { + testCases := []struct { + desc string + value string + expected IssueValue + expectErrContains string + }{ + { + desc: "basic", + value: "authority.example; accounturi=https://authority.example/acct/123", + expected: IssueValue{ + IssuerDomainName: "authority.example", + AccountURI: "https://authority.example/acct/123", + }, + }, + { + desc: "wildcard policy is case-insensitive", + value: "authority.example; accounturi=https://authority.example/acct/123; policy=wIlDcArD", + expected: IssueValue{ + IssuerDomainName: "authority.example", + AccountURI: "https://authority.example/acct/123", + Policy: "wIlDcArD", + }, + }, + { + desc: "unknown param", + value: "authority.example; accounturi=https://authority.example/acct/123; extra=value", + expected: IssueValue{ + IssuerDomainName: "authority.example", + AccountURI: "https://authority.example/acct/123", + }, + }, + { + desc: "unknown tag with empty value", + value: "authority.example; accounturi=https://authority.example/acct/123; foo=", + expected: IssueValue{ + IssuerDomainName: "authority.example", + AccountURI: "https://authority.example/acct/123", + }, + }, + { + desc: "unknown tags with unusual formatting are ignored", + value: "authority.example;accounturi=https://authority.example/acct/123;bad tag=value;\nweird=\\x01337", + expected: IssueValue{ + IssuerDomainName: "authority.example", + AccountURI: "https://authority.example/acct/123", + }, + }, + { + desc: "all known fields with heavy whitespace", + value: " authority.example ; accounturi = https://authority.example/acct/123 ; policy = wildcard ; persistUntil = 4102444800 ", + expected: IssueValue{ + IssuerDomainName: "authority.example", + AccountURI: "https://authority.example/acct/123", + Policy: "wildcard", + PersistUntil: time.Unix(4102444800, 0).UTC(), + }, + }, + { + desc: "policy other than wildcard is treated as absent", + value: "authority.example; accounturi=https://authority.example/acct/123; policy=notwildcard", + expected: IssueValue{ + IssuerDomainName: "authority.example", + AccountURI: "https://authority.example/acct/123", + }, + }, + { + desc: "missing accounturi", + value: "authority.example", + expected: IssueValue{ + IssuerDomainName: "authority.example", + }, + }, + { + desc: "missing issuer", + value: "; accounturi=https://authority.example/acct/123", + expectErrContains: "missing issuer-domain-name", + }, + { + desc: "invalid parameter", + value: "authority.example; badparam", + expectErrContains: `malformed parameter "badparam" should be tag=value pair`, + }, + { + desc: "empty tag is malformed", + value: "authority.example; accounturi=https://authority.example/acct/123; =abc", + expectErrContains: `malformed parameter "=abc", empty tag`, + }, + { + desc: "empty accounturi is malformed", + value: "authority.example; accounturi=", + expectErrContains: `empty value provided for mandatory "accounturi"`, + }, + { + desc: "invalid value character is malformed", + value: "authority.example; accounturi=https://authority.example/acct/123; policy=wild card", + expectErrContains: `malformed value "wild card" for tag "policy"`, + }, + { + desc: "persistUntil non unix timestamp is malformed", + value: "authority.example; accounturi=https://authority.example/acct/123; persistUntil=not-a-unix-timestamp", + expectErrContains: `malformed "persistuntil": strconv.ParseInt: parsing "not-a-unix-timestamp": invalid syntax`, + }, + { + desc: "duplicate unknown parameter is malformed", + value: "authority.example; accounturi=https://authority.example/acct/123; foo=bar; foo=baz", + expectErrContains: `duplicate parameter "foo"`, + }, + { + desc: "duplicate parameter is case-insensitive", + value: "authority.example; ACCOUNTURI=https://authority.example/acct/123; accounturi=https://authority.example/acct/456", + expectErrContains: `duplicate parameter "accounturi"`, + }, + { + desc: "trailing semicolon is malformed", + value: "authority.example; accounturi=https://authority.example/acct/123;", + expectErrContains: "empty parameter or trailing semicolon provided", + }, + { + desc: "empty persistUntil is malformed", + value: "authority.example; accounturi=https://authority.example/acct/123; persistUntil=", + expectErrContains: `malformed "persistuntil": strconv.ParseInt: parsing "": invalid syntax`, + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + parsed, err := ParseIssueValue(test.value) + if test.expectErrContains != "" { + require.Error(t, err) + assert.ErrorContains(t, err, test.expectErrContains) + + return + } + + require.NoError(t, err) + + assert.Equal(t, test.expected, parsed) + }) + } +} diff --git a/challenge/dnspersist01/issuer_domain_name.go b/challenge/dnspersist01/issuer_domain_name.go new file mode 100644 index 000000000..78e182bfd --- /dev/null +++ b/challenge/dnspersist01/issuer_domain_name.go @@ -0,0 +1,104 @@ +package dnspersist01 + +import ( + "errors" + "fmt" + "strings" + + "golang.org/x/net/idna" +) + +//nolint:gochecknoglobals // test seam for injecting IDNA conversion failures/variants. +var issuerDomainNameToASCII = idna.Lookup.ToASCII + +// validateIssuerDomainName validates a single issuer-domain-name according to +// the following rules: +// - lowercase only +// - no trailing dot +// - max 253 octets overall +// - non-empty labels, each max 63 octets +// - lowercase LDH label syntax +// - A-label (Punycode, RFC5890) +func validateIssuerDomainName(name string) error { + if name == "" { + return errors.New("issuer-domain-name cannot be empty") + } + + if strings.ToLower(name) != name { + return errors.New("issuer-domain-name must be lowercase") + } + + if strings.HasSuffix(name, ".") { + return errors.New("issuer-domain-name must not have a trailing dot") + } + + if len(name) > 253 { + return errors.New("issuer-domain-name exceeds the maximum length of 253 octets") + } + + labels := strings.SplitSeq(name, ".") + for label := range labels { + if label == "" { + return errors.New("issuer-domain-name contains an empty label") + } + + if len(label) > 63 { + return errors.New("issuer-domain-name label exceeds the maximum length of 63 octets") + } + + if !isLDHLabel(label) { + return fmt.Errorf("issuer-domain-name label %q must be a lowercase LDH label", label) + } + } + + ascii, err := issuerDomainNameToASCII(name) + if err != nil { + return fmt.Errorf("issuer-domain-name must be represented in A-label format: %w", err) + } + + if ascii != name { + return errors.New("issuer-domain-name must be represented in A-label format") + } + + return nil +} + +func isLDHLabel(label string) bool { + if label == "" { + return false + } + + if !isLowerAlphaNum(label[0]) || !isLowerAlphaNum(label[len(label)-1]) { + return false + } + + for i := range len(label) { + c := label[i] + if isLowerAlphaNum(c) || c == '-' { + continue + } + + return false + } + + return true +} + +func isLowerAlphaNum(c byte) bool { + return (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') +} + +// normalizeUserSuppliedIssuerDomainName normalizes a user supplied +// issuer-domain-name for comparison. Note: DO NOT normalize issuer-domain-names +// from the challenge, as they are expected to already be in the correct format. +func normalizeUserSuppliedIssuerDomainName(name string) (string, error) { + n := strings.TrimSpace(strings.TrimSuffix(name, ".")) + n = strings.ToLower(n) + + ascii, err := idna.Lookup.ToASCII(n) + if err != nil { + return "", fmt.Errorf("normalizing supplied issuer-domain-name %q: %w", n, err) + } + + return ascii, nil +} diff --git a/challenge/dnspersist01/issuer_domain_name_test.go b/challenge/dnspersist01/issuer_domain_name_test.go new file mode 100644 index 000000000..ffef4f1ee --- /dev/null +++ b/challenge/dnspersist01/issuer_domain_name_test.go @@ -0,0 +1,76 @@ +package dnspersist01 + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestValidateIssuerDomainName_Errors(t *testing.T) { + testCases := []struct { + desc string + name string + expectErr string + }{ + { + desc: "trailing dot", + name: "authority.example.", + expectErr: "issuer-domain-name must not have a trailing dot", + }, + { + desc: "empty label", + name: "authority..example", + expectErr: "issuer-domain-name contains an empty label", + }, + { + desc: "label too long", + name: strings.Repeat("a", 64) + ".example", + expectErr: "issuer-domain-name label exceeds the maximum length of 63 octets", + }, + { + desc: "invalid a-label with idna error", + name: "xn--a.example", + expectErr: `issuer-domain-name must be represented in A-label format: idna: invalid label "\u0080"`, + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + err := validateIssuerDomainName(test.name) + require.EqualError(t, err, test.expectErr) + }) + } +} + +func TestValidateIssuerDomainName_ErrorNonCanonicalALabel(t *testing.T) { + mockIssuerDomainNameToASCII(t, func(string) (string, error) { + return "different.example", nil + }) + + err := validateIssuerDomainName("authority.example") + require.EqualError(t, err, "issuer-domain-name must be represented in A-label format") +} + +func TestValidateIssuerDomainName_Valid(t *testing.T) { + mockIssuerDomainNameToASCII(t, func(name string) (string, error) { + return name, nil + }) + + err := validateIssuerDomainName("authority.example") + require.NoError(t, err) +} + +func mockIssuerDomainNameToASCII(t *testing.T, fn func(string) (string, error)) { + t.Helper() + + originalToASCII := issuerDomainNameToASCII + + t.Cleanup(func() { + issuerDomainNameToASCII = originalToASCII + }) + + issuerDomainNameToASCII = fn +} diff --git a/challenge/dnspersist01/mock_test.go b/challenge/dnspersist01/mock_test.go new file mode 100644 index 000000000..96809ae7a --- /dev/null +++ b/challenge/dnspersist01/mock_test.go @@ -0,0 +1,58 @@ +package dnspersist01 + +import ( + "context" + "net" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/stretchr/testify/require" +) + +func fakeNS(name, ns string) *dns.NS { + return &dns.NS{ + Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 172800}, + Ns: ns, + } +} + +func fakeA(name, ip string) *dns.A { + return &dns.A{ + Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 10}, + A: net.ParseIP(ip), + } +} + +func fakeTXT(name, value string, ttl uint32) *dns.TXT { + return &dns.TXT{ + Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: ttl}, + Txt: []string{value}, + } +} + +// mockResolver modifies the default DNS resolver to use a custom network address during the test execution. +// IMPORTANT: it modifies global variables. +func mockResolver(t *testing.T, addr net.Addr) string { + t.Helper() + + _, port, err := net.SplitHostPort(addr.String()) + require.NoError(t, err) + + originalResolver := net.DefaultResolver + + t.Cleanup(func() { + net.DefaultResolver = originalResolver + }) + + net.DefaultResolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{Timeout: 1 * time.Second} + + return d.DialContext(ctx, network, addr.String()) + }, + } + + return port +} diff --git a/challenge/dnspersist01/precheck.go b/challenge/dnspersist01/precheck.go new file mode 100644 index 000000000..dde4a5866 --- /dev/null +++ b/challenge/dnspersist01/precheck.go @@ -0,0 +1,254 @@ +package dnspersist01 + +import ( + "fmt" + "net" + "strings" + "time" + + "github.com/miekg/dns" +) + +const defaultAuthoritativeNSPort = "53" + +// RecordMatcher returns true when the expected record is present. +type RecordMatcher func(records []TXTRecord) bool + +// PreCheckFunc checks DNS propagation before notifying ACME that the challenge is ready. +type PreCheckFunc func(fqdn string, matcher RecordMatcher) (bool, error) + +// WrapPreCheckFunc wraps a PreCheckFunc in order to do extra operations before or after +// the main check, put it in a loop, etc. +type WrapPreCheckFunc func(domain, fqdn string, matcher RecordMatcher, check PreCheckFunc) (bool, error) + +// WrapPreCheck Allow to define checks before notifying ACME that the challenge is ready. +func WrapPreCheck(wrap WrapPreCheckFunc) ChallengeOption { + return func(chlg *Challenge) error { + chlg.preCheck.checkFunc = wrap + return nil + } +} + +// DisableAuthoritativeNssPropagationRequirement disables authoritative nameserver checks. +func DisableAuthoritativeNssPropagationRequirement() ChallengeOption { + return func(chlg *Challenge) error { + chlg.preCheck.requireAuthoritativeNssPropagation = false + return nil + } +} + +// DisableRecursiveNSsPropagationRequirement disables recursive nameserver checks. +func DisableRecursiveNSsPropagationRequirement() ChallengeOption { + return func(chlg *Challenge) error { + chlg.preCheck.requireRecursiveNssPropagation = false + return nil + } +} + +// AddRecursiveNameservers overrides recursive nameservers used for propagation checks. +func AddRecursiveNameservers(nameservers []string) ChallengeOption { + return func(chlg *Challenge) error { + chlg.recursiveNameservers = ParseNameservers(nameservers) + return nil + } +} + +// PropagationWait sleeps for the specified duration, optionally skipping checks. +func PropagationWait(wait time.Duration, skipCheck bool) ChallengeOption { + return WrapPreCheck(func(domain, fqdn string, matcher RecordMatcher, check PreCheckFunc) (bool, error) { + time.Sleep(wait) + + if skipCheck { + return true, nil + } + + return check(fqdn, matcher) + }) +} + +type preCheck struct { + // checks DNS propagation before notifying ACME that the DNS challenge is ready. + checkFunc WrapPreCheckFunc + + // require the TXT record to be propagated to all authoritative name servers + requireAuthoritativeNssPropagation bool + + // require the TXT record to be propagated to all recursive name servers + requireRecursiveNssPropagation bool +} + +func newPreCheck() preCheck { + return preCheck{ + requireAuthoritativeNssPropagation: true, + requireRecursiveNssPropagation: true, + } +} + +func (p preCheck) call(domain, fqdn string, matcher RecordMatcher, check PreCheckFunc) (bool, error) { + if p.checkFunc == nil { + return check(fqdn, matcher) + } + + return p.checkFunc(domain, fqdn, matcher, check) +} + +func (c *Challenge) checkDNSPropagation(fqdn string, matcher RecordMatcher) (bool, error) { + nameservers := c.getRecursiveNameservers() + + // Initial attempt to resolve at the recursive NS (require to get CNAME) + result, err := c.resolver.lookupTXT(fqdn, nameservers, true) + if err != nil { + return false, fmt.Errorf("initial recursive nameserver: %w", err) + } + + effectiveFQDN := dns.Fqdn(fqdn) + if len(result.CNAMEChain) > 0 { + effectiveFQDN = result.CNAMEChain[len(result.CNAMEChain)-1] + } + + if c.preCheck.requireRecursiveNssPropagation { + _, err = c.checkNameserversPropagation(effectiveFQDN, nameservers, false, true, matcher) + if err != nil { + return false, fmt.Errorf("recursive nameservers: %w", err) + } + } + + if !c.preCheck.requireAuthoritativeNssPropagation { + return true, nil + } + + authoritativeNss, err := lookupNameservers(effectiveFQDN, nameservers, c.resolver.Timeout) + if err != nil { + return false, err + } + + found, err := c.checkNameserversPropagation(effectiveFQDN, authoritativeNss, true, false, matcher) + if err != nil { + return found, fmt.Errorf("authoritative nameservers: %w", err) + } + + return found, nil +} + +func (c *Challenge) checkNameserversPropagation(fqdn string, nameservers []string, addPort, recursive bool, matcher RecordMatcher) (bool, error) { + for _, ns := range nameservers { + if addPort { + ns = net.JoinHostPort(ns, c.getAuthoritativeNSPort()) + } + + result, err := c.resolver.lookupTXT(fqdn, []string{ns}, recursive) + if err != nil { + return false, err + } + + if !matcher(result.Records) { + return false, fmt.Errorf("NS %s did not return a matching TXT record [fqdn: %s]: %s", ns, fqdn, result) + } + } + + return true, nil +} + +func (c *Challenge) getAuthoritativeNSPort() string { + if c == nil || c.authoritativeNSPort == "" { + return defaultAuthoritativeNSPort + } + + return c.authoritativeNSPort +} + +// lookupNameservers returns the authoritative nameservers for the given fqdn. +func lookupNameservers(fqdn string, nameservers []string, timeout time.Duration) ([]string, error) { + zone, err := findZoneByFqdn(fqdn, nameservers, timeout) + if err != nil { + return nil, fmt.Errorf("could not find zone: %w", err) + } + + r, err := dnsQueryWithTimeout(zone, dns.TypeNS, nameservers, true, timeout) + if err != nil { + return nil, fmt.Errorf("NS call failed: %w", err) + } + + var authoritativeNss []string + + for _, rr := range r.Answer { + if ns, ok := rr.(*dns.NS); ok { + authoritativeNss = append(authoritativeNss, strings.ToLower(ns.Ns)) + } + } + + if len(authoritativeNss) > 0 { + return authoritativeNss, nil + } + + return nil, fmt.Errorf("[zone=%s] could not determine authoritative nameservers", zone) +} + +func findZoneByFqdn(fqdn string, nameservers []string, timeout time.Duration) (string, error) { + var ( + err error + r *dns.Msg + ) + + for _, domain := range domainsSeq(fqdn) { + r, err = dnsQueryWithTimeout(domain, dns.TypeSOA, nameservers, true, timeout) + if err != nil { + continue + } + + if r == nil { + continue + } + + switch r.Rcode { + case dns.RcodeSuccess: + // Check if we got a SOA RR in the answer section + if len(r.Answer) == 0 { + continue + } + + // CNAME records cannot/should not exist at the root of a zone. + // So we skip a domain when a CNAME is found. + if dnsMsgContainsCNAME(r) { + continue + } + + for _, ans := range r.Answer { + if soa, ok := ans.(*dns.SOA); ok { + return soa.Hdr.Name, nil + } + } + case dns.RcodeNameError: + // NXDOMAIN + default: + // Any response code other than NOERROR and NXDOMAIN is treated as error + return "", &DNSError{Message: fmt.Sprintf("unexpected response for '%s'", domain), MsgOut: r} + } + } + + return "", &DNSError{Message: fmt.Sprintf("could not find the start of authority for '%s'", dns.Fqdn(fqdn)), MsgOut: r, Err: err} +} + +func dnsMsgContainsCNAME(msg *dns.Msg) bool { + for _, ans := range msg.Answer { + if _, ok := ans.(*dns.CNAME); ok { + return true + } + } + + return false +} + +func domainsSeq(fqdn string) []string { + fqdn = dns.Fqdn(fqdn) + if fqdn == "" { + return nil + } + + var domains []string + for _, index := range dns.Split(fqdn) { + domains = append(domains, fqdn[index:]) + } + + return domains +} diff --git a/challenge/dnspersist01/precheck_test.go b/challenge/dnspersist01/precheck_test.go new file mode 100644 index 000000000..c85c25fa6 --- /dev/null +++ b/challenge/dnspersist01/precheck_test.go @@ -0,0 +1,95 @@ +package dnspersist01 + +import ( + "testing" + + "github.com/go-acme/lego/v5/platform/tester/dnsmock" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_preCheck_checkDNSPropagation(t *testing.T) { + addr := dnsmock.NewServer(). + Query("ns0.lego.localhost. A", + dnsmock.Answer(fakeA("ns0.lego.localhost.", "127.0.0.1"))). + Query("ns1.lego.localhost. A", + dnsmock.Answer(fakeA("ns1.lego.localhost.", "127.0.0.1"))). + Query("example.com. TXT", + dnsmock.Answer( + fakeTXT("example.com.", "one", 10), + fakeTXT("example.com.", "two", 10), + fakeTXT("example.com.", "three", 10), + fakeTXT("example.com.", "four", 10), + fakeTXT("example.com.", "five", 10), + ), + ). + Query("acme-staging.api.example.com. TXT", + dnsmock.Answer( + fakeTXT("acme-staging.api.example.com.", "one", 10), + fakeTXT("acme-staging.api.example.com.", "two", 10), + fakeTXT("acme-staging.api.example.com.", "three", 10), + fakeTXT("acme-staging.api.example.com.", "four", 10), + fakeTXT("acme-staging.api.example.com.", "five", 10), + ), + ). + Query("acme-staging.api.example.com. SOA", dnsmock.Error(dns.RcodeNameError)). + Query("api.example.com. SOA", dnsmock.Error(dns.RcodeNameError)). + Query("example.com. SOA", dnsmock.SOA("")). + Query("example.com. NS", + dnsmock.Answer( + fakeNS("example.com.", "ns0.lego.localhost."), + fakeNS("example.com.", "ns1.lego.localhost."), + ), + ). + Build(t) + + chlg := &Challenge{ + resolver: NewResolver([]string{addr.String()}), + preCheck: newPreCheck(), + recursiveNameservers: ParseNameservers([]string{addr.String()}), + authoritativeNSPort: mockResolver(t, addr), + } + + testCases := []struct { + desc string + fqdn string + value string + expectedError bool + }{ + { + desc: "success", + fqdn: "example.com.", + value: "four", + }, + { + desc: "no matching TXT record", + fqdn: "acme-staging.api.example.com.", + value: "fe01=", + expectedError: true, + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + match := func(records []TXTRecord) bool { + for _, record := range records { + if record.Value == test.value { + return true + } + } + + return false + } + + ok, err := chlg.checkDNSPropagation(test.fqdn, match) + if test.expectedError { + require.Error(t, err) + assert.False(t, ok) + } else { + require.NoError(t, err) + assert.True(t, ok) + } + }) + } +} diff --git a/challenge/dnspersist01/resolver.go b/challenge/dnspersist01/resolver.go new file mode 100644 index 000000000..bda31b9e4 --- /dev/null +++ b/challenge/dnspersist01/resolver.go @@ -0,0 +1,318 @@ +package dnspersist01 + +import ( + "errors" + "fmt" + "net" + "os" + "strconv" + "strings" + "time" + + "github.com/miekg/dns" +) + +const defaultResolvConf = "/etc/resolv.conf" + +// Resolver performs DNS lookups using the configured nameservers and timeout. +type Resolver struct { + Nameservers []string + Timeout time.Duration +} + +// TXTRecord captures a DNS TXT record value and its TTL. +type TXTRecord struct { + Value string + TTL uint32 +} + +// TXTResult contains TXT records and any CNAMEs followed during lookup. +type TXTResult struct { + Records []TXTRecord + CNAMEChain []string +} + +func (r TXTResult) String() string { + values := make([]string, 0, len(r.Records)) + + for _, record := range r.Records { + values = append(values, record.Value) + } + + return strings.Join(values, ",") +} + +// NewResolver creates a resolver with normalized nameservers and default timeout. +// If nameservers is empty, the system resolv.conf is used, falling back to defaults. +func NewResolver(nameservers []string) *Resolver { + if len(nameservers) == 0 { + nameservers = DefaultNameservers() + } + + return &Resolver{ + Nameservers: ParseNameservers(nameservers), + Timeout: defaultDNSTimeout, + } +} + +// DefaultNameservers returns resolvers from resolv.conf, falling back to defaults. +func DefaultNameservers() []string { + config, err := dns.ClientConfigFromFile(defaultResolvConf) + if err != nil || len(config.Servers) == 0 { + return defaultFallbackNameservers() + } + + return ParseNameservers(config.Servers) +} + +func defaultFallbackNameservers() []string { + return []string{ + "google-public-dns-a.google.com:53", + "google-public-dns-b.google.com:53", + } +} + +// ParseNameservers ensures all servers have a port number. +func ParseNameservers(servers []string) []string { + var resolvers []string + + for _, resolver := range servers { + if _, _, err := net.SplitHostPort(resolver); err != nil { + resolvers = append(resolvers, net.JoinHostPort(resolver, "53")) + } else { + resolvers = append(resolvers, resolver) + } + } + + return resolvers +} + +// LookupTXT resolves TXT records at fqdn. If CNAMEs are returned, they are +// followed up to 50 times to resolve TXT records. +func (r *Resolver) LookupTXT(fqdn string) (TXTResult, error) { + return r.lookupTXT(fqdn, r.Nameservers, true) +} + +func (r *Resolver) lookupTXT(fqdn string, nameservers []string, recursive bool) (TXTResult, error) { + var result TXTResult + + if r == nil { + return result, errors.New("resolver is nil") + } + + nameservers = ParseNameservers(nameservers) + if len(nameservers) == 0 { + return result, errors.New("empty list of nameservers") + } + + timeout := r.Timeout + if timeout <= 0 { + timeout = defaultDNSTimeout + } + + const maxCNAMEFollows = 50 + + name := dns.Fqdn(fqdn) + seen := map[string]struct{}{} + followed := 0 + + for { + if _, ok := seen[name]; ok { + return result, fmt.Errorf("CNAME loop detected for %s", name) + } + + seen[name] = struct{}{} + + msg, err := dnsQueryWithTimeout(name, dns.TypeTXT, nameservers, recursive, timeout) + if err != nil { + return result, err + } + + switch msg.Rcode { + case dns.RcodeSuccess: + records := extractTXT(msg, name) + if len(records) > 0 { + result.Records = records + return result, nil + } + + cname := extractCNAME(msg, name) + if cname == "" { + return result, nil + } + + if followed >= maxCNAMEFollows { + return result, nil + } + + result.CNAMEChain = append(result.CNAMEChain, cname) + name = cname + followed++ + case dns.RcodeNameError: + return result, nil + default: + return result, &DNSError{Message: fmt.Sprintf("unexpected response for '%s'", name), MsgOut: msg} + } + } +} + +func extractTXT(msg *dns.Msg, name string) []TXTRecord { + var records []TXTRecord + + for _, rr := range msg.Answer { + txt, ok := rr.(*dns.TXT) + if !ok { + continue + } + + if !strings.EqualFold(txt.Hdr.Name, name) { + continue + } + + records = append(records, TXTRecord{ + Value: strings.Join(txt.Txt, ""), + TTL: txt.Hdr.Ttl, + }) + } + + return records +} + +func extractCNAME(msg *dns.Msg, name string) string { + for _, rr := range msg.Answer { + cn, ok := rr.(*dns.CNAME) + if !ok { + continue + } + + if strings.EqualFold(cn.Hdr.Name, name) { + return cn.Target + } + } + + return "" +} + +func dnsQueryWithTimeout(fqdn string, rtype uint16, nameservers []string, recursive bool, timeout time.Duration) (*dns.Msg, error) { + m := createDNSMsg(fqdn, rtype, recursive) + + if len(nameservers) == 0 { + return nil, &DNSError{Message: "empty list of nameservers"} + } + + var ( + msg *dns.Msg + err error + errAll error + ) + + for _, ns := range nameservers { + msg, err = sendDNSQuery(m, ns, timeout) + if err == nil && len(msg.Answer) > 0 { + break + } + + errAll = errors.Join(errAll, err) + } + + if err != nil { + return msg, errAll + } + + return msg, nil +} + +func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg { + m := new(dns.Msg) + m.SetQuestion(fqdn, rtype) + m.SetEdns0(4096, false) + + if !recursive { + m.RecursionDesired = false + } + + return m +} + +func sendDNSQuery(m *dns.Msg, ns string, timeout time.Duration) (*dns.Msg, error) { + if ok, _ := strconv.ParseBool(os.Getenv("LEGO_EXPERIMENTAL_DNS_TCP_ONLY")); ok { + tcp := &dns.Client{Net: "tcp", Timeout: timeout} + + msg, _, err := tcp.Exchange(m, ns) + if err != nil { + return msg, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err} + } + + return msg, nil + } + + udp := &dns.Client{Net: "udp", Timeout: timeout} + msg, _, err := udp.Exchange(m, ns) + + if msg != nil && msg.Truncated { + tcp := &dns.Client{Net: "tcp", Timeout: timeout} + msg, _, err = tcp.Exchange(m, ns) + } + + if err != nil { + return msg, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err} + } + + return msg, nil +} + +// DNSError is an error related to DNS calls. +type DNSError struct { + Message string + NS string + MsgIn *dns.Msg + MsgOut *dns.Msg + Err error +} + +func (d *DNSError) Error() string { + var details []string + if d.NS != "" { + details = append(details, "ns="+d.NS) + } + + formatQuestions := func(questions []dns.Question) string { + var parts []string + for _, question := range questions { + parts = append(parts, strings.ReplaceAll(strings.TrimPrefix(question.String(), ";"), "\t", " ")) + } + + return strings.Join(parts, ";") + } + + if d.MsgIn != nil && len(d.MsgIn.Question) > 0 { + details = append(details, fmt.Sprintf("question='%s'", formatQuestions(d.MsgIn.Question))) + } + + if d.MsgOut != nil { + if d.MsgIn == nil || len(d.MsgIn.Question) == 0 { + details = append(details, fmt.Sprintf("question='%s'", formatQuestions(d.MsgOut.Question))) + } + + details = append(details, "code="+dns.RcodeToString[d.MsgOut.Rcode]) + } + + msg := "DNS error" + if d.Message != "" { + msg = d.Message + } + + if d.Err != nil { + msg += ": " + d.Err.Error() + } + + if len(details) > 0 { + msg += " [" + strings.Join(details, ", ") + "]" + } + + return msg +} + +func (d *DNSError) Unwrap() error { + return d.Err +} diff --git a/challenge/dnspersist01/resolver_test.go b/challenge/dnspersist01/resolver_test.go new file mode 100644 index 000000000..898bbc2c1 --- /dev/null +++ b/challenge/dnspersist01/resolver_test.go @@ -0,0 +1,73 @@ +package dnspersist01 + +import ( + "testing" + + "github.com/go-acme/lego/v5/platform/tester/dnsmock" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestResolver_LookupTXT(t *testing.T) { + fqdn := "_validation-persist.example.com." + + testCases := []struct { + desc string + serverBuilder *dnsmock.Builder + expected TXTResult + }{ + { + desc: "direct TXT", + serverBuilder: dnsmock.NewServer(). + Query(fqdn+" TXT", dnsmock.Answer(fakeTXT(fqdn, "value", 120))), + expected: TXTResult{ + Records: []TXTRecord{{Value: "value", TTL: 120}}, + }, + }, + { + desc: "cname to txt", + serverBuilder: dnsmock.NewServer(). + Query(fqdn+" TXT", dnsmock.CNAME("alias.example.com.")). + Query("alias.example.com. TXT", dnsmock.Answer(fakeTXT("alias.example.com.", "value", 60))), + expected: TXTResult{ + Records: []TXTRecord{{Value: "value", TTL: 60}}, + CNAMEChain: []string{"alias.example.com."}, + }, + }, + { + desc: "cname chain follows multiple hops", + serverBuilder: dnsmock.NewServer(). + Query(fqdn+" TXT", dnsmock.CNAME("alias.example.com.")). + Query("alias.example.com. TXT", dnsmock.CNAME("alias2.example.com.")). + Query("alias2.example.com. TXT", dnsmock.Answer(fakeTXT("alias2.example.com.", "value", 30))), + expected: TXTResult{ + Records: []TXTRecord{{Value: "value", TTL: 30}}, + CNAMEChain: []string{"alias.example.com.", "alias2.example.com."}, + }, + }, + { + desc: "nxdomain", + serverBuilder: dnsmock.NewServer(). + Query(fqdn+" TXT", dnsmock.Error(dns.RcodeNameError)), + }, + { + desc: "empty answer", + serverBuilder: dnsmock.NewServer(). + Query(fqdn+" TXT", dnsmock.Noop), + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + addr := test.serverBuilder.Build(t) + + resolver := NewResolver([]string{addr.String()}) + + result, err := resolver.LookupTXT(fqdn) + require.NoError(t, err) + + assert.Equal(t, test.expected, result) + }) + } +} diff --git a/challenge/dnspersist01/resolver_unix.go b/challenge/dnspersist01/resolver_unix.go new file mode 100644 index 000000000..247cf2cea --- /dev/null +++ b/challenge/dnspersist01/resolver_unix.go @@ -0,0 +1,8 @@ +//go:build !windows + +package dnspersist01 + +import "time" + +// defaultDNSTimeout is used as the default DNS timeout on Unix-like systems. +const defaultDNSTimeout = 10 * time.Second diff --git a/challenge/dnspersist01/resolver_windows.go b/challenge/dnspersist01/resolver_windows.go new file mode 100644 index 000000000..81172fbd7 --- /dev/null +++ b/challenge/dnspersist01/resolver_windows.go @@ -0,0 +1,8 @@ +//go:build windows + +package dnspersist01 + +import "time" + +// defaultDNSTimeout is used as the default DNS timeout on Windows. +const defaultDNSTimeout = 20 * time.Second diff --git a/challenge/resolver/solver_manager.go b/challenge/resolver/solver_manager.go index 95219165b..de036ddc1 100644 --- a/challenge/resolver/solver_manager.go +++ b/challenge/resolver/solver_manager.go @@ -13,6 +13,7 @@ import ( "github.com/go-acme/lego/v5/acme/api" "github.com/go-acme/lego/v5/challenge" "github.com/go-acme/lego/v5/challenge/dns01" + "github.com/go-acme/lego/v5/challenge/dnspersist01" "github.com/go-acme/lego/v5/challenge/http01" "github.com/go-acme/lego/v5/challenge/tlsalpn01" "github.com/go-acme/lego/v5/log" @@ -21,9 +22,21 @@ import ( type byType []acme.Challenge -func (a byType) Len() int { return len(a) } -func (a byType) Swap(i, j int) { a[i], a[j] = a[j], a[i] } -func (a byType) Less(i, j int) bool { return a[i].Type > a[j].Type } +func (a byType) Len() int { return len(a) } +func (a byType) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a byType) Less(i, j int) bool { + // When users configure both DNS and DNS-PERSIST-01, prefer DNS-01 to avoid + // unexpectedly selecting the manual-only DNS-PERSIST-01 workflow. + if a[i].Type == string(challenge.DNS01) && a[j].Type == string(challenge.DNSPersist01) { + return true + } + + if a[i].Type == string(challenge.DNSPersist01) && a[j].Type == string(challenge.DNS01) { + return false + } + + return a[i].Type > a[j].Type +} type SolverManager struct { core *api.Core @@ -55,6 +68,18 @@ func (c *SolverManager) SetDNS01Provider(p challenge.Provider, opts ...dns01.Cha return nil } +// SetDNSPersist01 configures the dns-persist-01 challenge solver. +func (c *SolverManager) SetDNSPersist01(opts ...dnspersist01.ChallengeOption) error { + chlg, err := dnspersist01.NewChallenge(c.core, validate, opts...) + if err != nil { + return err + } + + c.solvers[challenge.DNSPersist01] = chlg + + return nil +} + // Remove removes a challenge type from the available solvers. func (c *SolverManager) Remove(chlgType challenge.Type) { delete(c.solvers, chlgType) diff --git a/challenge/resolver/solver_manager_test.go b/challenge/resolver/solver_manager_test.go index 9e3da899c..b72dc2149 100644 --- a/challenge/resolver/solver_manager_test.go +++ b/challenge/resolver/solver_manager_test.go @@ -20,13 +20,13 @@ import ( func TestByType(t *testing.T) { challenges := []acme.Challenge{ - {Type: "dns-01"}, {Type: "tlsalpn-01"}, {Type: "http-01"}, + {Type: "dns-01"}, {Type: "dns-persist-01"}, {Type: "tlsalpn-01"}, {Type: "http-01"}, } sort.Sort(byType(challenges)) expected := []acme.Challenge{ - {Type: "tlsalpn-01"}, {Type: "http-01"}, {Type: "dns-01"}, + {Type: "tlsalpn-01"}, {Type: "http-01"}, {Type: "dns-01"}, {Type: "dns-persist-01"}, } assert.Equal(t, expected, challenges) diff --git a/cmd/cmd_renew.go b/cmd/cmd_renew.go index 076e2d83f..c258abe7e 100644 --- a/cmd/cmd_renew.go +++ b/cmd/cmd_renew.go @@ -90,7 +90,7 @@ func renew(ctx context.Context, cmd *cli.Command) error { return nil, fmt.Errorf("new client: %w", err) } - setupChallenges(cmd, client) + setupChallenges(cmd, client, account) return client, nil }) diff --git a/cmd/cmd_run.go b/cmd/cmd_run.go index 137539f55..fa3c1d024 100644 --- a/cmd/cmd_run.go +++ b/cmd/cmd_run.go @@ -59,8 +59,6 @@ func run(ctx context.Context, cmd *cli.Command) error { return fmt.Errorf("new client: %w", err) } - setupChallenges(cmd, client) - if account.Registration == nil { var reg *registration.Resource @@ -77,6 +75,8 @@ func run(ctx context.Context, cmd *cli.Command) error { fmt.Printf(rootPathWarningMessage, accountsStorage.GetRootPath()) } + setupChallenges(cmd, client, account) + certRes, err := obtainCertificate(ctx, cmd, client) if err != nil { // Make sure to return a non-zero exit code if ObtainSANCertificate returned at least one error. diff --git a/cmd/flags.go b/cmd/flags.go index b0a2aa6e6..77230392b 100644 --- a/cmd/flags.go +++ b/cmd/flags.go @@ -21,16 +21,17 @@ import ( ) const ( - categoryHTTP01Challenge = "Flags related to the HTTP-01 challenge:" - categoryTLSALPN01Challenge = "Flags related to the TLS-ALPN-01 challenge:" - categoryDNS01Challenge = "Flags related to the DNS-01 challenge:" - categoryStorage = "Flags related to the storage:" - categoryHooks = "Flags related to hooks:" - categoryEAB = "Flags related to External Account Binding:" - categoryACMEClient = "Flags related to the ACME client:" - categoryAdvanced = "Flags related to advanced options:" - categoryARI = "Flags related to ACME Renewal Information (ARI) Extension:" - categoryLogs = "Flags related to logs:" + categoryHTTP01Challenge = "Flags related to the HTTP-01 challenge:" + categoryTLSALPN01Challenge = "Flags related to the TLS-ALPN-01 challenge:" + categoryDNS01Challenge = "Flags related to the DNS-01 challenge:" + categoryDNSPersist01Challenge = "Flags related to the DNS-PERSIST-01 challenge:" + categoryStorage = "Flags related to the storage:" + categoryHooks = "Flags related to hooks:" + categoryEAB = "Flags related to External Account Binding:" + categoryACMEClient = "Flags related to the ACME client:" + categoryAdvanced = "Flags related to advanced options:" + categoryARI = "Flags related to ACME Renewal Information (ARI) Extension:" + categoryLogs = "Flags related to logs:" ) // Flag aliases (short-codes). @@ -128,6 +129,18 @@ const ( flgDNSTimeout = "dns.timeout" ) +// Flag names related to the DNS-PERSIST-01 challenge. +const ( + flgDNSPersist = "dns-persist" + flgDNSPersistIssuerDomainName = "dns-persist.issuer-domain-name" + flgDNSPersistPersistUntil = "dns-persist.persist-until" + flgDNSPersistPropagationWait = "dns-persist.propagation.wait" + flgDNSPersistPropagationDisableANS = "dns-persist.propagation.disable-ans" + flgDNSSPersistPropagationDisableRNS = "dns-persist.propagation.disable-rns" + flgDNSPersistResolvers = "dns-persist.resolvers" + flgDNSPersistTimeout = "dns-persist.timeout" +) + // Flags names related to hooks. const ( flgDeployHook = "deploy-hook" @@ -255,6 +268,7 @@ func createChallengesFlags() []cli.Flag { flags = append(flags, createHTTPChallengeFlags()...) flags = append(flags, createTLSChallengeFlags()...) flags = append(flags, createDNSChallengeFlags()...) + flags = append(flags, createDNSPersistChallengeFlags()...) flags = append(flags, createNetworkStackFlags()...) return flags @@ -407,6 +421,71 @@ func createDNSChallengeFlags() []cli.Flag { } } +func createDNSPersistChallengeFlags() []cli.Flag { + return []cli.Flag{ + &cli.BoolFlag{ + Category: categoryDNSPersist01Challenge, + Name: flgDNSPersist, + Sources: cli.EnvVars(toEnvName(flgDNSPersist)), + Usage: "Use the DNS-PERSIST-01 challenge to solve challenges. Manual verification only. Can be mixed with other types of challenges.", + }, + &cli.StringFlag{ + Category: categoryDNSPersist01Challenge, + Name: flgDNSPersistIssuerDomainName, + Sources: cli.EnvVars(toEnvName(flgDNSPersistIssuerDomainName)), + Usage: "Override the issuer-domain-name to use for DNS-PERSIST-01 when multiple are offered. Must be offered by the challenge.", + }, + &cli.TimestampFlag{ + Name: flgDNSPersistPersistUntil, + Category: categoryDNSPersist01Challenge, + Usage: "Set the optional persistUntil for DNS-PERSIST-01 records as an RFC3339 timestamp (for example 2026-03-01T00:00:00Z).", + Sources: cli.EnvVars(toEnvName(flgDNSPersistPersistUntil)), + Config: cli.TimestampConfig{ + Layouts: []string{time.RFC3339}, + }, + }, + &cli.DurationFlag{ + Category: categoryDNSPersist01Challenge, + Name: flgDNSPersistPropagationWait, + Sources: cli.EnvVars(toEnvName(flgDNSPersistPropagationWait)), + Usage: "By setting this flag, disables all the propagation checks of the TXT record and uses a wait duration instead.", + Validator: func(d time.Duration) error { + if d < 0 { + return errors.New("it cannot be negative") + } + + return nil + }, + }, + &cli.BoolFlag{ + Category: categoryDNSPersist01Challenge, + Name: flgDNSPersistPropagationDisableANS, + Sources: cli.EnvVars(toEnvName(flgDNSPersistPropagationDisableANS)), + Usage: "By setting this flag to true, disables the need to await propagation of the TXT record to all authoritative name servers.", + }, + &cli.BoolFlag{ + Category: categoryDNSPersist01Challenge, + Name: flgDNSSPersistPropagationDisableRNS, + Sources: cli.EnvVars(toEnvName(flgDNSSPersistPropagationDisableRNS)), + Usage: "By setting this flag to true, disables the need to await propagation of the TXT record to all recursive name servers (aka resolvers).", + }, + &cli.StringSliceFlag{ + Category: categoryDNSPersist01Challenge, + Name: flgDNSPersistResolvers, + Sources: cli.EnvVars(toEnvName(flgDNSPersistResolvers)), + Usage: "Set the resolvers to use for DNS-PERSIST-01 TXT lookups." + + " Supported: host:port." + + " The default is to use the system resolvers, or Google's DNS resolvers if the system's cannot be determined.", + }, + &cli.IntFlag{ + Category: categoryDNSPersist01Challenge, + Name: flgDNSPersistTimeout, + Sources: cli.EnvVars(toEnvName(flgDNSPersistTimeout)), + Usage: "Set the DNS timeout value to a specific value in seconds. Used for DNS-PERSIST-01 lookups.", + }, + } +} + func createStorageFlags() []cli.Flag { return []cli.Flag{ createPathFlag(true), diff --git a/cmd/setup_challenges.go b/cmd/setup_challenges.go index dea753f9a..fcaf9c9ad 100644 --- a/cmd/setup_challenges.go +++ b/cmd/setup_challenges.go @@ -1,6 +1,7 @@ package cmd import ( + "errors" "fmt" "log/slog" "net" @@ -9,6 +10,7 @@ import ( "github.com/go-acme/lego/v5/challenge" "github.com/go-acme/lego/v5/challenge/dns01" + "github.com/go-acme/lego/v5/challenge/dnspersist01" "github.com/go-acme/lego/v5/challenge/http01" "github.com/go-acme/lego/v5/challenge/tlsalpn01" "github.com/go-acme/lego/v5/lego" @@ -17,12 +19,14 @@ import ( "github.com/go-acme/lego/v5/providers/http/memcached" "github.com/go-acme/lego/v5/providers/http/s3" "github.com/go-acme/lego/v5/providers/http/webroot" + "github.com/go-acme/lego/v5/registration" "github.com/urfave/cli/v3" ) -func setupChallenges(cmd *cli.Command, client *lego.Client) { - if !cmd.Bool(flgHTTP) && !cmd.Bool(flgTLS) && !cmd.IsSet(flgDNS) { - log.Fatal(fmt.Sprintf("No challenge selected. You must specify at least one challenge: `--%s`, `--%s`, `--%s`.", flgHTTP, flgTLS, flgDNS)) +//nolint:gocyclo // challenge setup dispatch is expected to branch by enabled challenge type. +func setupChallenges(cmd *cli.Command, client *lego.Client, account registration.User) { + if !cmd.Bool(flgHTTP) && !cmd.Bool(flgTLS) && !cmd.IsSet(flgDNS) && !cmd.Bool(flgDNSPersist) { + log.Fatal(fmt.Sprintf("No challenge selected. You must specify at least one challenge: `--%s`, `--%s`, `--%s`, `--%s`.", flgHTTP, flgTLS, flgDNS, flgDNSPersist)) } if cmd.Bool(flgHTTP) { @@ -45,6 +49,13 @@ func setupChallenges(cmd *cli.Command, client *lego.Client) { log.Fatal("Could not set DNS challenge provider.", log.ErrorAttr(err)) } } + + if cmd.Bool(flgDNSPersist) { + err := setupDNSPersist(cmd, client, account) + if err != nil { + log.Fatal("Could not set DNS-PERSIST-01 challenge provider.", log.ErrorAttr(err)) + } + } } //nolint:gocyclo // the complexity is expected. @@ -162,7 +173,7 @@ func setupTLSProvider(cmd *cli.Command) challenge.Provider { } func setupDNS(cmd *cli.Command, client *lego.Client) error { - err := validatePropagationExclusiveOptions(cmd) + err := validatePropagationExclusiveOptions(cmd, flgDNSPropagationWait, flgDNSPropagationDisableANS, flgDNSPropagationDisableRNS) if err != nil { return err } @@ -186,29 +197,70 @@ func setupDNS(cmd *cli.Command, client *lego.Client) error { err = client.Challenge.SetDNS01Provider(provider, dns01.CondOption(shouldWait, - dns01.PropagationWait(cmd.Duration(flgDNSPropagationWait), true)), + dns01.PropagationWait(cmd.Duration(flgDNSPropagationWait), true), + ), dns01.CondOption(!shouldWait && cmd.Bool(flgDNSPropagationDisableANS), - dns01.DisableAuthoritativeNssPropagationRequirement()), + dns01.DisableAuthoritativeNssPropagationRequirement(), + ), dns01.CondOption(!shouldWait && cmd.Bool(flgDNSPropagationDisableRNS), - dns01.DisableRecursiveNSsPropagationRequirement()), + dns01.DisableRecursiveNSsPropagationRequirement(), + ), ) return err } -func validatePropagationExclusiveOptions(cmd *cli.Command) error { - if !cmd.IsSet(flgDNSPropagationWait) { +func setupDNSPersist(cmd *cli.Command, client *lego.Client, account registration.User) error { + if account == nil || account.GetRegistration() == nil || account.GetRegistration().URI == "" { + return errors.New("dns-persist-01 requires a registered account with an account URI") + } + + err := validatePropagationExclusiveOptions(cmd, flgDNSPersistPropagationWait, flgDNSPersistPropagationDisableANS, flgDNSPersistIssuerDomainName) + if err != nil { + return err + } + + resolvers := cmd.StringSlice(flgDNSPersistResolvers) + shouldWait := cmd.IsSet(flgDNSPersistPropagationWait) + + return client.Challenge.SetDNSPersist01( + dnspersist01.WithAccountURI(account.GetRegistration().URI), + dnspersist01.WithIssuerDomainName(cmd.String(flgDNSPersistIssuerDomainName)), + dnspersist01.CondOptions(len(resolvers) > 0, + dnspersist01.WithNameservers(resolvers), + dnspersist01.AddRecursiveNameservers(resolvers), + ), + dnspersist01.CondOptions(cmd.IsSet(flgDNSPersistPersistUntil), + dnspersist01.WithPersistUntil(cmd.Timestamp(flgDNSPersistPersistUntil)), + ), + dnspersist01.CondOptions(cmd.IsSet(flgDNSPersistTimeout), + dnspersist01.WithDNSTimeout(time.Duration(cmd.Int(flgDNSPersistTimeout))*time.Second), + ), + dnspersist01.CondOptions(shouldWait, + dnspersist01.PropagationWait(cmd.Duration(flgDNSPersistPropagationWait), true), + ), + dnspersist01.CondOptions(!shouldWait, + dnspersist01.CondOptions(cmd.Bool(flgDNSPersistPropagationDisableANS), + dnspersist01.DisableAuthoritativeNssPropagationRequirement(), + ), + dnspersist01.CondOptions(cmd.Bool(flgDNSSPersistPropagationDisableRNS), + dnspersist01.DisableRecursiveNSsPropagationRequirement(), + ), + ), + ) +} + +func validatePropagationExclusiveOptions(cmd *cli.Command, flgWait, flgANS, flgDNS string) error { + if !cmd.IsSet(flgWait) { return nil } - if isSetBool(cmd, flgDNSPropagationDisableANS) { - return fmt.Errorf("'%s' and '%s' are mutually exclusive", - flgDNSPropagationWait, flgDNSPropagationDisableANS) + if isSetBool(cmd, flgANS) { + return fmt.Errorf("'%s' and '%s' are mutually exclusive", flgWait, flgANS) } - if isSetBool(cmd, flgDNSPropagationDisableRNS) { - return fmt.Errorf("'%s' and '%s' are mutually exclusive", - flgDNSPropagationWait, flgDNSPropagationDisableRNS) + if isSetBool(cmd, flgDNS) { + return fmt.Errorf("'%s' and '%s' are mutually exclusive", flgWait, flgDNS) } return nil diff --git a/docs/data/zz_cli_help.toml b/docs/data/zz_cli_help.toml index 2bae8d893..55ec666e1 100644 --- a/docs/data/zz_cli_help.toml +++ b/docs/data/zz_cli_help.toml @@ -90,6 +90,17 @@ OPTIONS: --dns.resolvers string [ --dns.resolvers string ] Set the resolvers to use for performing (recursive) CNAME resolving and apex domain determination. For DNS-01 challenge verification, the authoritative DNS server is queried directly. Supported: host:port. The default is to use the system resolvers, or Google's DNS resolvers if the system's cannot be determined. [$LEGO_DNS_RESOLVERS] --dns.timeout int Set the DNS timeout value to a specific value in seconds. Used only when performing authoritative name server queries. (default: 10) [$LEGO_DNS_TIMEOUT] + Flags related to the DNS-PERSIST-01 challenge: + + --dns-persist Use the DNS-PERSIST-01 challenge to solve challenges. Manual verification only. Can be mixed with other types of challenges. [$LEGO_DNS_PERSIST] + --dns-persist.issuer-domain-name string Override the issuer-domain-name to use for DNS-PERSIST-01 when multiple are offered. Must be offered by the challenge. [$LEGO_DNS_PERSIST_ISSUER_DOMAIN_NAME] + --dns-persist.persist-until time Set the optional persistUntil for DNS-PERSIST-01 records as an RFC3339 timestamp (for example 2026-03-01T00:00:00Z). [$LEGO_DNS_PERSIST_PERSIST_UNTIL] + --dns-persist.propagation.disable-ans By setting this flag to true, disables the need to await propagation of the TXT record to all authoritative name servers. [$LEGO_DNS_PERSIST_PROPAGATION_DISABLE_ANS] + --dns-persist.propagation.disable-rns By setting this flag to true, disables the need to await propagation of the TXT record to all recursive name servers (aka resolvers). [$LEGO_DNS_PERSIST_PROPAGATION_DISABLE_RNS] + --dns-persist.propagation.wait duration By setting this flag, disables all the propagation checks of the TXT record and uses a wait duration instead. (default: 0s) [$LEGO_DNS_PERSIST_PROPAGATION_WAIT] + --dns-persist.resolvers string [ --dns-persist.resolvers string ] Set the resolvers to use for DNS-PERSIST-01 TXT lookups. Supported: host:port. The default is to use the system resolvers, or Google's DNS resolvers if the system's cannot be determined. [$LEGO_DNS_PERSIST_RESOLVERS] + --dns-persist.timeout int Set the DNS timeout value to a specific value in seconds. Used for DNS-PERSIST-01 lookups. (default: 0) [$LEGO_DNS_PERSIST_TIMEOUT] + Flags related to the HTTP-01 challenge: --http Use the HTTP-01 challenge to solve challenges. Can be mixed with other types of challenges. [$LEGO_HTTP] @@ -187,6 +198,17 @@ OPTIONS: --dns.resolvers string [ --dns.resolvers string ] Set the resolvers to use for performing (recursive) CNAME resolving and apex domain determination. For DNS-01 challenge verification, the authoritative DNS server is queried directly. Supported: host:port. The default is to use the system resolvers, or Google's DNS resolvers if the system's cannot be determined. [$LEGO_DNS_RESOLVERS] --dns.timeout int Set the DNS timeout value to a specific value in seconds. Used only when performing authoritative name server queries. (default: 10) [$LEGO_DNS_TIMEOUT] + Flags related to the DNS-PERSIST-01 challenge: + + --dns-persist Use the DNS-PERSIST-01 challenge to solve challenges. Manual verification only. Can be mixed with other types of challenges. [$LEGO_DNS_PERSIST] + --dns-persist.issuer-domain-name string Override the issuer-domain-name to use for DNS-PERSIST-01 when multiple are offered. Must be offered by the challenge. [$LEGO_DNS_PERSIST_ISSUER_DOMAIN_NAME] + --dns-persist.persist-until time Set the optional persistUntil for DNS-PERSIST-01 records as an RFC3339 timestamp (for example 2026-03-01T00:00:00Z). [$LEGO_DNS_PERSIST_PERSIST_UNTIL] + --dns-persist.propagation.disable-ans By setting this flag to true, disables the need to await propagation of the TXT record to all authoritative name servers. [$LEGO_DNS_PERSIST_PROPAGATION_DISABLE_ANS] + --dns-persist.propagation.disable-rns By setting this flag to true, disables the need to await propagation of the TXT record to all recursive name servers (aka resolvers). [$LEGO_DNS_PERSIST_PROPAGATION_DISABLE_RNS] + --dns-persist.propagation.wait duration By setting this flag, disables all the propagation checks of the TXT record and uses a wait duration instead. (default: 0s) [$LEGO_DNS_PERSIST_PROPAGATION_WAIT] + --dns-persist.resolvers string [ --dns-persist.resolvers string ] Set the resolvers to use for DNS-PERSIST-01 TXT lookups. Supported: host:port. The default is to use the system resolvers, or Google's DNS resolvers if the system's cannot be determined. [$LEGO_DNS_PERSIST_RESOLVERS] + --dns-persist.timeout int Set the DNS timeout value to a specific value in seconds. Used for DNS-PERSIST-01 lookups. (default: 0) [$LEGO_DNS_PERSIST_TIMEOUT] + Flags related to the HTTP-01 challenge: --http Use the HTTP-01 challenge to solve challenges. Can be mixed with other types of challenges. [$LEGO_HTTP] diff --git a/e2e/dnschallenge/dns_persist_challenges_test.go b/e2e/dnschallenge/dns_persist_challenges_test.go new file mode 100644 index 000000000..760eff975 --- /dev/null +++ b/e2e/dnschallenge/dns_persist_challenges_test.go @@ -0,0 +1,390 @@ +package dnschallenge + +import ( + "bytes" + "context" + "crypto" + "crypto/rand" + "crypto/rsa" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "testing" + "time" + + "github.com/go-acme/lego/v5/certcrypto" + "github.com/go-acme/lego/v5/certificate" + "github.com/go-acme/lego/v5/challenge/dnspersist01" + "github.com/go-acme/lego/v5/e2e/loader" + "github.com/go-acme/lego/v5/lego" + "github.com/go-acme/lego/v5/registration" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + testPersistBaseDomain = "persist.localhost" + testPersistDomain = "*." + testPersistBaseDomain + testPersistIssuer = "pebble.letsencrypt.org" + + testPersistCLIDomain = "persist-cli.localhost" + testPersistCLIWildcardDomain = "*." + testPersistCLIDomain + testPersistCLIEmail = "persist-e2e@example.com" + testPersistCLIFreshEmail = "persist-e2e-fresh@example.com" + testPersistCLIRenewEmail = "persist-e2e-renew@example.com" +) + +func setTXTRecord(t *testing.T, host, value string) { + t.Helper() + + err := setTXTRecordRaw(host, value) + require.NoError(t, err) +} + +func setTXTRecordRaw(host, value string) error { + body, err := json.Marshal(map[string]string{ + "host": host, + "value": value, + }) + if err != nil { + return err + } + + resp, err := http.Post("http://localhost:8055/set-txt", "application/json", bytes.NewReader(body)) + if err != nil { + return err + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected status: %d", resp.StatusCode) + } + + return nil +} + +func clearTXTRecord(t *testing.T, host string) { + t.Helper() + + body, err := json.Marshal(map[string]string{ + "host": host, + }) + require.NoError(t, err) + + resp, err := http.Post("http://localhost:8055/clear-txt", "application/json", bytes.NewReader(body)) + require.NoError(t, err) + + defer func() { _ = resp.Body.Close() }() + + require.Equal(t, http.StatusOK, resp.StatusCode) +} + +//nolint:unparam // kept generic for future e2e tests. +func mustDNSPersistIssueValue(t *testing.T, issuerDomainName, accountURI string) string { + t.Helper() + + value, err := dnspersist01.BuildIssueValue(issuerDomainName, accountURI, true, time.Time{}) + require.NoError(t, err) + + return value +} + +func createCLIAccountState(t *testing.T, email string) string { + t.Helper() + + privateKey, err := certcrypto.GeneratePrivateKey(certcrypto.EC256) + require.NoError(t, err) + + user := &fakeUser{ + email: email, + privateKey: privateKey, + } + config := lego.NewConfig(user) + config.CADirURL = "https://localhost:15000/dir" + + client, err := lego.NewClient(config) + require.NoError(t, err) + + reg, err := client.Registration.Register(context.Background(), registration.RegisterOptions{TermsOfServiceAgreed: true}) + require.NoError(t, err) + require.NotEmpty(t, reg.URI) + + keyType := certcrypto.EC256 + accountPathRoot := filepath.Join(".lego", "accounts", "localhost_15000", email, string(keyType)) + err = os.MkdirAll(accountPathRoot, 0o700) + require.NoError(t, err) + + err = saveAccountPrivateKey(filepath.Join(accountPathRoot, email+".key"), privateKey) + require.NoError(t, err) + + accountPath := filepath.Join(accountPathRoot, "account.json") + content, err := json.MarshalIndent(struct { + ID string `json:"id"` + Email string `json:"email"` + KeyType certcrypto.KeyType `json:"keyType"` + Registration *registration.Resource `json:"registration"` + }{ + ID: email, + Email: email, + KeyType: keyType, + Registration: reg, + }, "", "\t") + require.NoError(t, err) + + err = os.WriteFile(accountPath, content, 0o600) + require.NoError(t, err) + + return reg.URI +} + +func saveAccountPrivateKey(path string, privateKey crypto.PrivateKey) error { + return os.WriteFile(path, certcrypto.PEMEncode(privateKey), 0o600) +} + +func cliAccountFilePath(email string) string { + return filepath.Join(".lego", "accounts", "localhost_15000", email, string(certcrypto.EC256), "account.json") +} + +func waitForCLIAccountURI(ctx context.Context, email string) (string, error) { + accountPath := cliAccountFilePath(email) + + type accountFile struct { + Registration *registration.Resource `json:"registration"` + } + + ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return "", ctx.Err() + case <-ticker.C: + content, err := os.ReadFile(accountPath) + if err != nil { + if os.IsNotExist(err) { + continue + } + + return "", err + } + + var account accountFile + + err = json.Unmarshal(content, &account) + if err != nil { + continue + } + + if account.Registration != nil && account.Registration.URI != "" { + return account.Registration.URI, nil + } + } + } +} + +func TestChallengeDNSPersist_Client_Obtain(t *testing.T) { + err := os.Setenv("LEGO_CA_CERTIFICATES", "../fixtures/certs/pebble.minica.pem") + require.NoError(t, err) + + defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err, "Could not generate test key") + + user := &fakeUser{privateKey: privateKey} + config := lego.NewConfig(user) + config.CADirURL = "https://localhost:15000/dir" + + client, err := lego.NewClient(config) + require.NoError(t, err) + + reg, err := client.Registration.Register(context.Background(), registration.RegisterOptions{TermsOfServiceAgreed: true}) + require.NoError(t, err) + require.NotEmpty(t, reg.URI) + + user.registration = reg + + txtHost := fmt.Sprintf("_validation-persist.%s", testPersistBaseDomain) + txtValue := mustDNSPersistIssueValue(t, testPersistIssuer, reg.URI) + + setTXTRecord(t, txtHost, txtValue) + defer clearTXTRecord(t, txtHost) + + err = client.Challenge.SetDNSPersist01( + dnspersist01.WithAccountURI(reg.URI), + dnspersist01.WithNameservers([]string{":8053"}), + dnspersist01.AddRecursiveNameservers([]string{":8053"}), + dnspersist01.DisableAuthoritativeNssPropagationRequirement(), + ) + require.NoError(t, err) + + privateKeyCSR, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err, "Could not generate test key") + + request := certificate.ObtainRequest{ + Domains: []string{testPersistDomain}, + Bundle: true, + PrivateKey: privateKeyCSR, + } + resource, err := client.Certificate.Obtain(context.Background(), request) + require.NoError(t, err) + + require.NotNil(t, resource) + assert.Equal(t, testPersistDomain, resource.Domains[0]) + assert.Regexp(t, `https://localhost:15000/certZ/[\w\d]{14,}`, resource.CertURL) + assert.Regexp(t, `https://localhost:15000/certZ/[\w\d]{14,}`, resource.CertStableURL) + assert.NotEmpty(t, resource.Certificate) + assert.NotEmpty(t, resource.IssuerCertificate) + assert.Empty(t, resource.CSR) +} + +func TestChallengeDNSPersist_Run(t *testing.T) { + loader.CleanLegoFiles(context.Background()) + + err := os.Setenv("LEGO_CA_CERTIFICATES", "../fixtures/certs/pebble.minica.pem") + require.NoError(t, err) + + defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }() + + accountURI := createCLIAccountState(t, testPersistCLIEmail) + require.NotEmpty(t, accountURI) + + txtHost := fmt.Sprintf("_validation-persist.%s", testPersistCLIDomain) + txtValue := mustDNSPersistIssueValue(t, testPersistIssuer, accountURI) + + setTXTRecord(t, txtHost, txtValue) + defer clearTXTRecord(t, txtHost) + + err = load.RunLego( + context.Background(), + "run", + "--email", testPersistCLIEmail, + "--accept-tos", + "--dns-persist", + "--dns-persist.resolvers", ":8053", + "--dns-persist.propagation.disable-ans", + "--dns-persist.issuer-domain-name", testPersistIssuer, + "--server", "https://localhost:15000/dir", + "--domains", testPersistCLIWildcardDomain, + "--domains", testPersistCLIDomain, + ) + require.NoError(t, err) +} + +func TestChallengeDNSPersist_Run_NewAccount(t *testing.T) { + loader.CleanLegoFiles(context.Background()) + + err := os.Setenv("LEGO_CA_CERTIFICATES", "../fixtures/certs/pebble.minica.pem") + require.NoError(t, err) + + defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }() + + txtHost := fmt.Sprintf("_validation-persist.%s", testPersistCLIDomain) + defer clearTXTRecord(t, txtHost) + + stdinReader, stdinWriter := io.Pipe() + + defer func() { _ = stdinReader.Close() }() + + errChan := make(chan error, 1) + + go func() { + defer func() { _ = stdinWriter.Close() }() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + accountURI, waitErr := waitForCLIAccountURI(ctx, testPersistCLIFreshEmail) + if waitErr != nil { + errChan <- fmt.Errorf("wait for account URI: %w", waitErr) + return + } + + txtValue := mustDNSPersistIssueValue(t, testPersistIssuer, accountURI) + + err = setTXTRecordRaw(txtHost, txtValue) + if err != nil { + errChan <- fmt.Errorf("set TXT record: %w", err) + return + } + + _, err = io.WriteString(stdinWriter, "\n") + if err != nil { + errChan <- fmt.Errorf("send enter to lego: %w", err) + return + } + + errChan <- nil + }() + + err = load.RunLegoWithInput( + context.Background(), + stdinReader, + "run", + "--email", testPersistCLIFreshEmail, + "--accept-tos", + "--dns-persist", + "--dns-persist.resolvers", ":8053", + "--dns-persist.propagation.disable-ans", + "--dns-persist.issuer-domain-name", testPersistIssuer, + "--server", "https://localhost:15000/dir", + "--domains", testPersistCLIWildcardDomain, + "--domains", testPersistCLIDomain, + ) + require.NoError(t, err) + require.NoError(t, <-errChan) +} + +func TestChallengeDNSPersist_Renew(t *testing.T) { + loader.CleanLegoFiles(context.Background()) + + err := os.Setenv("LEGO_CA_CERTIFICATES", "../fixtures/certs/pebble.minica.pem") + require.NoError(t, err) + + defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }() + + accountURI := createCLIAccountState(t, testPersistCLIRenewEmail) + require.NotEmpty(t, accountURI) + + txtHost := fmt.Sprintf("_validation-persist.%s", testPersistCLIDomain) + txtValue := mustDNSPersistIssueValue(t, testPersistIssuer, accountURI) + + setTXTRecord(t, txtHost, txtValue) + defer clearTXTRecord(t, txtHost) + + err = load.RunLego( + context.Background(), + "run", + "--email", testPersistCLIRenewEmail, + "--accept-tos", + "--dns-persist", + "--dns-persist.resolvers", ":8053", + "--dns-persist.propagation.disable-ans", + "--dns-persist.issuer-domain-name", testPersistIssuer, + "--server", "https://localhost:15000/dir", + "--domains", testPersistCLIWildcardDomain, + "--domains", testPersistCLIDomain, + ) + require.NoError(t, err) + + err = load.RunLego( + context.Background(), + "renew", + "--email", testPersistCLIRenewEmail, + "--dns-persist", + "--dns-persist.resolvers", ":8053", + "--dns-persist.propagation.disable-ans", + "--dns-persist.issuer-domain-name", testPersistIssuer, + "--server", "https://localhost:15000/dir", + "--domains", testPersistCLIWildcardDomain, + "--domains", testPersistCLIDomain, + "--renew-force", + "--no-random-sleep", + ) + require.NoError(t, err) +} diff --git a/e2e/loader/loader.go b/e2e/loader/loader.go index 456f5a2ef..d231d060a 100644 --- a/e2e/loader/loader.go +++ b/e2e/loader/loader.go @@ -7,6 +7,7 @@ import ( "crypto/tls" "errors" "fmt" + "io" "net/http" "os" "os/exec" @@ -113,8 +114,13 @@ func (l *EnvLoader) RunLegoCombinedOutput(ctx context.Context, arg ...string) ([ } func (l *EnvLoader) RunLego(ctx context.Context, arg ...string) error { + return l.RunLegoWithInput(ctx, nil, arg...) +} + +func (l *EnvLoader) RunLegoWithInput(ctx context.Context, stdin io.Reader, arg ...string) error { cmd := exec.CommandContext(ctx, l.lego, arg...) cmd.Env = l.LegoOptions + cmd.Stdin = stdin fmt.Printf("$ %s\n", strings.Join(cmd.Args, " ")) diff --git a/e2e/readme.md b/e2e/readme.md index 171170507..228b7a3ef 100644 --- a/e2e/readme.md +++ b/e2e/readme.md @@ -2,8 +2,8 @@ - Install [Pebble](https://github.com/letsencrypt/pebble): ```bash -go install github.com/letsencrypt/pebble/v2/cmd/pebble@v2.9.0 -go install github.com/letsencrypt/pebble/v2/cmd/pebble-challtestsrv@v2.9.0 +go install github.com/letsencrypt/pebble/v2/cmd/pebble@v2.10.0 +go install github.com/letsencrypt/pebble/v2/cmd/pebble-challtestsrv@v2.10.0 ``` - Launch tests: