refactor: rewrite needRenewal

This commit is contained in:
Fernandez Ludovic 2026-02-06 18:25:41 +01:00
commit 62dcd9d6f7
2 changed files with 91 additions and 89 deletions

View file

@ -11,6 +11,7 @@ import (
"math/rand"
"os"
"slices"
"sort"
"strings"
"sync"
"time"
@ -26,6 +27,8 @@ import (
"github.com/urfave/cli/v3"
)
const noDays = -math.MaxInt
type lzSetUp func() (*lego.Client, error)
func createRenew() *cli.Command {
@ -117,19 +120,24 @@ func renewForDomains(ctx context.Context, cmd *cli.Command, lazyClient lzSetUp,
cert := certificates[0]
if cert.IsCA {
return fmt.Errorf("certificate bundle for %q starts with a CA certificate", domain)
}
ariRenewalTime, replacesCertID, err := getARIInfo(ctx, cmd, lazyClient, domain, cert)
if err != nil {
return err
}
forceDomains := cmd.Bool(flgForceCertDomains)
certDomains := certcrypto.ExtractDomains(cert)
days := getFlagRenewDays(cmd)
renewalDomains := slices.Clone(domains)
if !cmd.Bool(flgForceCertDomains) {
renewalDomains = merge(certDomains, domains)
}
if ariRenewalTime == nil && !needRenewal(cert, domain, days, cmd.Bool(flgRenewForce)) &&
(!forceDomains || slices.Equal(certDomains, domains)) {
if ariRenewalTime == nil && !cmd.Bool(flgRenewForce) && sameDomains(certDomains, renewalDomains) &&
!isInRenewalPeriod(cert, domain, getFlagRenewDays(cmd), time.Now()) {
return nil
}
@ -162,11 +170,6 @@ func renewForDomains(ctx context.Context, cmd *cli.Command, lazyClient lzSetUp,
randomSleep(cmd)
renewalDomains := slices.Clone(domains)
if !forceDomains {
renewalDomains = merge(certDomains, domains)
}
request := newObtainRequest(cmd, renewalDomains)
request.PrivateKey = privateKey
@ -215,14 +218,17 @@ func renewForCSR(ctx context.Context, cmd *cli.Command, lazyClient lzSetUp, cert
cert := certificates[0]
if cert.IsCA {
return fmt.Errorf("certificate bundle for %q starts with a CA certificate", domain)
}
ariRenewalTime, replacesCertID, err := getARIInfo(ctx, cmd, lazyClient, domain, cert)
if err != nil {
return err
}
days := getFlagRenewDays(cmd)
if ariRenewalTime == nil && !needRenewal(cert, domain, days, cmd.Bool(flgRenewForce)) {
if ariRenewalTime == nil && !cmd.Bool(flgRenewForce) && sameDomainsCertificate(cert, csr) &&
!isInRenewalPeriod(cert, domain, getFlagRenewDays(cmd), time.Now()) {
return nil
}
@ -267,63 +273,19 @@ func getFlagRenewDays(cmd *cli.Command) int {
return cmd.Int(flgRenewDays)
}
return -math.MaxInt
return noDays
}
func needRenewal(x509Cert *x509.Certificate, domain string, days int, force bool) bool {
if x509Cert.IsCA {
log.Fatal("Certificate bundle starts with a CA certificate.", log.DomainAttr(domain))
}
func isInRenewalPeriod(cert *x509.Certificate, domain string, days int, now time.Time) bool {
dueDate := getDueDate(cert, days, now)
if force {
return true
}
// Default behavior
if days == -math.MaxInt {
return needRenewalDynamic(x509Cert, domain, time.Now())
}
return needRenewalDays(x509Cert, domain, days)
}
func needRenewalDays(x509Cert *x509.Certificate, domain string, days int) bool {
if days < 0 {
// if the number of days is negative: always renew the certificate.
return true
}
notAfter := int(time.Until(x509Cert.NotAfter).Hours() / 24.0)
if notAfter <= days {
return true
}
log.Infof(
log.LazySprintf("Skip renewal: the certificate expires in %d days, the number of days defined to perform the renewal is %d.",
notAfter, days),
log.DomainAttr(domain),
)
return false
}
func needRenewalDynamic(x509Cert *x509.Certificate, domain string, now time.Time) bool {
lifetime := x509Cert.NotAfter.Sub(x509Cert.NotBefore)
var divisor int64 = 3
if lifetime.Round(24*time.Hour).Hours()/24.0 <= 10 {
divisor = 2
}
dueDate := x509Cert.NotAfter.Add(-1 * time.Duration(lifetime.Nanoseconds()/divisor))
if dueDate.Before(now) {
if dueDate.Before(now) || dueDate.Equal(now) {
return true
}
log.Infof(
log.LazySprintf("Skip renewal: The certificate expires at %s, the renewal can be performed in %s.",
x509Cert.NotAfter.Format(time.RFC3339),
cert.NotAfter.Format(time.RFC3339),
FormattableDuration(dueDate.Sub(now)),
),
log.DomainAttr(domain),
@ -332,6 +294,26 @@ func needRenewalDynamic(x509Cert *x509.Certificate, domain string, now time.Time
return false
}
func getDueDate(x509Cert *x509.Certificate, days int, now time.Time) time.Time {
if days == noDays {
lifetime := x509Cert.NotAfter.Sub(x509Cert.NotBefore)
var divisor int64 = 3
if lifetime.Round(24*time.Hour).Hours()/24.0 <= 10 {
divisor = 2
}
return x509Cert.NotAfter.Add(-1 * time.Duration(lifetime.Nanoseconds()/divisor))
}
if days < 0 {
// if the number of days is negative: always renew the certificate.
return now
}
return x509Cert.NotAfter.Add(-1 * time.Duration(days) * 24 * time.Hour)
}
func getARIInfo(ctx context.Context, cmd *cli.Command, lazyClient lzSetUp, domain string, cert *x509.Certificate) (*time.Time, string, error) {
if cmd.Bool(flgARIDisable) {
return nil, "", nil
@ -440,6 +422,24 @@ func merge(prevDomains, nextDomains []string) []string {
return prevDomains
}
func sameDomainsCertificate(cert *x509.Certificate, csr *x509.CertificateRequest) bool {
return sameDomains(certcrypto.ExtractDomains(cert), certcrypto.ExtractDomainsCSR(csr))
}
func sameDomains(a, b []string) bool {
if len(a) != len(b) {
return false
}
aClone := slices.Clone(a)
sort.Strings(aClone)
bClone := slices.Clone(b)
sort.Strings(bClone)
return slices.Equal(aClone, bClone)
}
type FormattableDuration time.Duration
func (f FormattableDuration) String() string {

View file

@ -57,65 +57,69 @@ func Test_merge(t *testing.T) {
}
}
func Test_needRenewal(t *testing.T) {
func Test_isInRenewalPeriod_days(t *testing.T) {
now := time.Date(2025, 1, 19, 1, 1, 1, 1, time.UTC)
oneDay := 24 * time.Hour
testCases := []struct {
desc string
x509Cert *x509.Certificate
cert *x509.Certificate
days int
expected bool
expected assert.BoolAssertionFunc
}{
{
desc: "30 days, NotAfter now",
x509Cert: &x509.Certificate{
NotAfter: time.Now(),
desc: "days: 30 days, NotAfter now",
cert: &x509.Certificate{
NotAfter: now,
},
days: 30,
expected: true,
expected: assert.True,
},
{
desc: "30 days, NotAfter 31 days",
x509Cert: &x509.Certificate{
NotAfter: time.Now().Add(31*24*time.Hour + 1*time.Second),
desc: "days: 30 days, NotAfter 31 days",
cert: &x509.Certificate{
NotAfter: now.Add(31*oneDay + 1*time.Second),
},
days: 30,
expected: false,
expected: assert.False,
},
{
desc: "30 days, NotAfter 30 days",
x509Cert: &x509.Certificate{
NotAfter: time.Now().Add(30 * 24 * time.Hour),
desc: "days: 30 days, NotAfter 30 days",
cert: &x509.Certificate{
NotAfter: now.Add(30 * oneDay),
},
days: 30,
expected: true,
expected: assert.True,
},
{
desc: "0 days, NotAfter 30 days: only the day of the expiration",
x509Cert: &x509.Certificate{
NotAfter: time.Now().Add(30 * 24 * time.Hour),
desc: "days: 0 days, NotAfter 30 days: only the day of the expiration",
cert: &x509.Certificate{
NotAfter: now.Add(30 * oneDay),
},
days: 0,
expected: false,
expected: assert.False,
},
{
desc: "-1 days, NotAfter 30 days: always renew",
x509Cert: &x509.Certificate{
NotAfter: time.Now().Add(30 * 24 * time.Hour),
desc: "days: -1 days, NotAfter 30 days: always renew",
cert: &x509.Certificate{
NotAfter: now.Add(30 * oneDay),
},
days: -1,
expected: true,
expected: assert.True,
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
actual := needRenewal(test.x509Cert, "foo.com", test.days, false)
actual := isInRenewalPeriod(test.cert, "foo.com", test.days, now)
assert.Equal(t, test.expected, actual)
test.expected(t, actual)
})
}
}
func Test_needRenewalDynamic(t *testing.T) {
func Test_isInRenewalPeriod_dynamic(t *testing.T) {
testCases := []struct {
desc string
now time.Time
@ -154,16 +158,14 @@ func Test_needRenewalDynamic(t *testing.T) {
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
x509Cert := &x509.Certificate{
cert := &x509.Certificate{
NotBefore: test.notBefore,
NotAfter: test.notAfter,
}
ok := needRenewalDynamic(x509Cert, "example.com", test.now)
actual := isInRenewalPeriod(cert, "foo.com", noDays, test.now)
test.expected(t, ok)
test.expected(t, actual)
})
}
}