chore: update linter (#2699)

This commit is contained in:
Ludovic Fernandez 2025-10-30 13:02:35 +01:00 committed by GitHub
commit 81e0f2b42a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
498 changed files with 1439 additions and 112 deletions

View file

@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
env: env:
GO_VERSION: stable GO_VERSION: stable
GOLANGCI_LINT_VERSION: v2.5.0 GOLANGCI_LINT_VERSION: v2.6.0
HUGO_VERSION: 0.148.2 HUGO_VERSION: 0.148.2
CGO_ENABLED: 0 CGO_ENABLED: 0
LEGO_E2E_TESTS: CI LEGO_E2E_TESTS: CI

View file

@ -50,11 +50,8 @@ linters:
- tagliatelle - tagliatelle
- testpackage # not relevant - testpackage # not relevant
- tparallel # not relevant - tparallel # not relevant
- usestdlibvars # false-positive https://github.com/sashamelentyev/usestdlibvars/issues/96
- varnamelen # not relevant - varnamelen # not relevant
- wrapcheck - wrapcheck
- wsl_v5 # should be enabled the future.
- embeddedstructfieldcheck # should be enabled the future.
settings: settings:
depguard: depguard:

View file

@ -13,6 +13,7 @@ type AccountService service
// New Creates a new account. // New Creates a new account.
func (a *AccountService) New(req acme.Account) (acme.ExtendedAccount, error) { func (a *AccountService) New(req acme.Account) (acme.ExtendedAccount, error) {
var account acme.Account var account acme.Account
resp, err := a.core.post(a.core.GetDirectory().NewAccountURL, req, &account) resp, err := a.core.post(a.core.GetDirectory().NewAccountURL, req, &account)
location := getLocation(resp) location := getLocation(resp)
@ -51,10 +52,12 @@ func (a *AccountService) Get(accountURL string) (acme.Account, error) {
} }
var account acme.Account var account acme.Account
_, err := a.core.postAsGet(accountURL, &account) _, err := a.core.postAsGet(accountURL, &account)
if err != nil { if err != nil {
return acme.Account{}, err return acme.Account{}, err
} }
return account, nil return account, nil
} }
@ -65,6 +68,7 @@ func (a *AccountService) Update(accountURL string, req acme.Account) (acme.Accou
} }
var account acme.Account var account acme.Account
_, err := a.core.post(accountURL, req, &account) _, err := a.core.post(accountURL, req, &account)
if err != nil { if err != nil {
return acme.Account{}, err return acme.Account{}, err
@ -81,6 +85,7 @@ func (a *AccountService) Deactivate(accountURL string) error {
req := acme.Account{Status: acme.StatusDeactivated} req := acme.Account{Status: acme.StatusDeactivated}
_, err := a.core.post(accountURL, req, nil) _, err := a.core.post(accountURL, req, nil)
return err return err
} }

View file

@ -155,6 +155,7 @@ func getDirectory(do *sender.Doer, caDirURL string) (acme.Directory, error) {
if dir.NewAccountURL == "" { if dir.NewAccountURL == "" {
return dir, errors.New("directory missing new registration URL") return dir, errors.New("directory missing new registration URL")
} }
if dir.NewOrderURL == "" { if dir.NewOrderURL == "" {
return dir, errors.New("directory missing new order URL") return dir, errors.New("directory missing new order URL")
} }

View file

@ -15,10 +15,12 @@ func (c *AuthorizationService) Get(authzURL string) (acme.Authorization, error)
} }
var authz acme.Authorization var authz acme.Authorization
_, err := c.core.postAsGet(authzURL, &authz) _, err := c.core.postAsGet(authzURL, &authz)
if err != nil { if err != nil {
return acme.Authorization{}, err return acme.Authorization{}, err
} }
return authz, nil return authz, nil
} }
@ -29,6 +31,8 @@ func (c *AuthorizationService) Deactivate(authzURL string) error {
} }
var disabledAuth acme.Authorization var disabledAuth acme.Authorization
_, err := c.core.post(authzURL, acme.Authorization{Status: acme.StatusDeactivated}, &disabledAuth) _, err := c.core.post(authzURL, acme.Authorization{Status: acme.StatusDeactivated}, &disabledAuth)
return err return err
} }

View file

@ -17,6 +17,7 @@ func (c *ChallengeService) New(chlgURL string) (acme.ExtendedChallenge, error) {
// Challenge initiation is done by sending a JWS payload containing the trivial JSON object `{}`. // Challenge initiation is done by sending a JWS payload containing the trivial JSON object `{}`.
// We use an empty struct instance as the postJSON payload here to achieve this result. // We use an empty struct instance as the postJSON payload here to achieve this result.
var chlng acme.ExtendedChallenge var chlng acme.ExtendedChallenge
resp, err := c.core.post(chlgURL, struct{}{}, &chlng) resp, err := c.core.post(chlgURL, struct{}{}, &chlng)
if err != nil { if err != nil {
return acme.ExtendedChallenge{}, err return acme.ExtendedChallenge{}, err
@ -24,6 +25,7 @@ func (c *ChallengeService) New(chlgURL string) (acme.ExtendedChallenge, error) {
chlng.AuthorizationURL = getLink(resp.Header, "up") chlng.AuthorizationURL = getLink(resp.Header, "up")
chlng.RetryAfter = getRetryAfter(resp) chlng.RetryAfter = getRetryAfter(resp)
return chlng, nil return chlng, nil
} }
@ -34,6 +36,7 @@ func (c *ChallengeService) Get(chlgURL string) (acme.ExtendedChallenge, error) {
} }
var chlng acme.ExtendedChallenge var chlng acme.ExtendedChallenge
resp, err := c.core.postAsGet(chlgURL, &chlng) resp, err := c.core.postAsGet(chlgURL, &chlng)
if err != nil { if err != nil {
return acme.ExtendedChallenge{}, err return acme.ExtendedChallenge{}, err
@ -41,5 +44,6 @@ func (c *ChallengeService) Get(chlgURL string) (acme.ExtendedChallenge, error) {
chlng.AuthorizationURL = getLink(resp.Header, "up") chlng.AuthorizationURL = getLink(resp.Header, "up")
chlng.RetryAfter = getRetryAfter(resp) chlng.RetryAfter = getRetryAfter(resp)
return chlng, nil return chlng, nil
} }

View file

@ -11,10 +11,11 @@ import (
// Manager Manages nonces. // Manager Manages nonces.
type Manager struct { type Manager struct {
sync.Mutex
do *sender.Doer do *sender.Doer
nonceURL string nonceURL string
nonces []string nonces []string
sync.Mutex
} }
// NewManager Creates a new Manager. // NewManager Creates a new Manager.
@ -36,6 +37,7 @@ func (n *Manager) Pop() (string, bool) {
nonce := n.nonces[len(n.nonces)-1] nonce := n.nonces[len(n.nonces)-1]
n.nonces = n.nonces[:len(n.nonces)-1] n.nonces = n.nonces[:len(n.nonces)-1]
return nonce, true return nonce, true
} }
@ -43,6 +45,7 @@ func (n *Manager) Pop() (string, bool) {
func (n *Manager) Push(nonce string) { func (n *Manager) Push(nonce string) {
n.Lock() n.Lock()
defer n.Unlock() defer n.Unlock()
n.nonces = append(n.nonces, nonce) n.nonces = append(n.nonces, nonce)
} }
@ -51,6 +54,7 @@ func (n *Manager) Nonce() (string, error) {
if nonce, ok := n.Pop(); ok { if nonce, ok := n.Pop(); ok {
return nonce, nil return nonce, nil
} }
return n.getNonce() return n.getNonce()
} }

View file

@ -30,11 +30,13 @@ func TestNotHoldingLockWhileMakingHTTPRequests(t *testing.T) {
ch := make(chan bool) ch := make(chan bool)
resultCh := make(chan bool) resultCh := make(chan bool)
go func() { go func() {
_, errN := manager.Nonce() _, errN := manager.Nonce()
if errN != nil { if errN != nil {
t.Log(errN) t.Log(errN)
} }
ch <- true ch <- true
}() }()
go func() { go func() {
@ -42,13 +44,16 @@ func TestNotHoldingLockWhileMakingHTTPRequests(t *testing.T) {
if errN != nil { if errN != nil {
t.Log(errN) t.Log(errN)
} }
ch <- true ch <- true
}() }()
go func() { go func() {
<-ch <-ch
<-ch <-ch
resultCh <- true resultCh <- true
}() }()
select { select {
case <-resultCh: case <-resultCh:
case <-time.After(500 * time.Millisecond): case <-time.After(500 * time.Millisecond):

View file

@ -36,6 +36,7 @@ func (j *JWS) SetKid(kid string) {
// SignContent Signs a content with the JWS. // SignContent Signs a content with the JWS.
func (j *JWS) SignContent(url string, content []byte) (*jose.JSONWebSignature, error) { func (j *JWS) SignContent(url string, content []byte) (*jose.JSONWebSignature, error) {
var alg jose.SignatureAlgorithm var alg jose.SignatureAlgorithm
switch k := j.privKey.(type) { switch k := j.privKey.(type) {
case *rsa.PrivateKey: case *rsa.PrivateKey:
alg = jose.RS256 alg = jose.RS256
@ -72,12 +73,14 @@ func (j *JWS) SignContent(url string, content []byte) (*jose.JSONWebSignature, e
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to sign content: %w", err) return nil, fmt.Errorf("failed to sign content: %w", err)
} }
return signed, nil return signed, nil
} }
// SignEABContent Signs an external account binding content with the JWS. // SignEABContent Signs an external account binding content with the JWS.
func (j *JWS) SignEABContent(url, kid string, hmac []byte) (*jose.JSONWebSignature, error) { func (j *JWS) SignEABContent(url, kid string, hmac []byte) (*jose.JSONWebSignature, error) {
jwk := jose.JSONWebKey{Key: j.privKey} jwk := jose.JSONWebKey{Key: j.privKey}
jwkJSON, err := jwk.Public().MarshalJSON() jwkJSON, err := jwk.Public().MarshalJSON()
if err != nil { if err != nil {
return nil, fmt.Errorf("acme: error encoding eab jwk key: %w", err) return nil, fmt.Errorf("acme: error encoding eab jwk key: %w", err)
@ -108,6 +111,7 @@ func (j *JWS) SignEABContent(url, kid string, hmac []byte) (*jose.JSONWebSignatu
// GetKeyAuthorization Gets the key authorization for a token. // GetKeyAuthorization Gets the key authorization for a token.
func (j *JWS) GetKeyAuthorization(token string) (string, error) { func (j *JWS) GetKeyAuthorization(token string) (string, error) {
var publicKey crypto.PublicKey var publicKey crypto.PublicKey
switch k := j.privKey.(type) { switch k := j.privKey.(type) {
case *ecdsa.PrivateKey: case *ecdsa.PrivateKey:
publicKey = k.Public() publicKey = k.Public()

View file

@ -31,11 +31,13 @@ func TestNotHoldingLockWhileMakingHTTPRequests(t *testing.T) {
ch := make(chan bool) ch := make(chan bool)
resultCh := make(chan bool) resultCh := make(chan bool)
go func() { go func() {
_, errN := manager.Nonce() _, errN := manager.Nonce()
if errN != nil { if errN != nil {
t.Log(errN) t.Log(errN)
} }
ch <- true ch <- true
}() }()
go func() { go func() {
@ -43,13 +45,16 @@ func TestNotHoldingLockWhileMakingHTTPRequests(t *testing.T) {
if errN != nil { if errN != nil {
t.Log(errN) t.Log(errN)
} }
ch <- true ch <- true
}() }()
go func() { go func() {
<-ch <-ch
<-ch <-ch
resultCh <- true resultCh <- true
}() }()
select { select {
case <-resultCh: case <-resultCh:
case <-time.After(500 * time.Millisecond): case <-time.After(500 * time.Millisecond):

View file

@ -127,6 +127,7 @@ func checkError(req *http.Request, resp *http.Response) error {
} }
var errorDetails *acme.ProblemDetails var errorDetails *acme.ProblemDetails
err = json.Unmarshal(body, &errorDetails) err = json.Unmarshal(body, &errorDetails)
if err != nil { if err != nil {
return fmt.Errorf("%d ::%s :: %s :: %w :: %s", resp.StatusCode, req.Method, req.URL, err, string(body)) return fmt.Errorf("%d ::%s :: %s :: %w :: %s", resp.StatusCode, req.Method, req.URL, err, string(body))
@ -150,6 +151,7 @@ func checkError(req *http.Request, resp *http.Response) error {
return errorDetails return errorDetails
} }
return nil return nil
} }

View file

@ -12,6 +12,7 @@ import (
func TestDo_UserAgentOnAllHTTPMethod(t *testing.T) { func TestDo_UserAgentOnAllHTTPMethod(t *testing.T) {
var ua, method string var ua, method string
server := httptest.NewTLSServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { server := httptest.NewTLSServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
ua = r.Header.Get("User-Agent") ua = r.Header.Get("User-Agent")
method = r.Method method = r.Method
@ -60,9 +61,11 @@ func TestDo_CustomUserAgent(t *testing.T) {
ua := doer.formatUserAgent() ua := doer.formatUserAgent()
assert.Contains(t, ua, ourUserAgent) assert.Contains(t, ua, ourUserAgent)
assert.Contains(t, ua, customUA) assert.Contains(t, ua, customUA)
if strings.HasSuffix(ua, " ") { if strings.HasSuffix(ua, " ") {
t.Errorf("UA should not have trailing spaces; got '%s'", ua) t.Errorf("UA should not have trailing spaces; got '%s'", ua)
} }
assert.Len(t, strings.Split(ua, " "), 5) assert.Len(t, strings.Split(ua, " "), 5)
} }

View file

@ -56,6 +56,7 @@ func (o *OrderService) NewWithOptions(domains []string, opts *OrderOptions) (acm
} }
var order acme.Order var order acme.Order
resp, err := o.core.post(o.core.GetDirectory().NewOrderURL, orderReq, &order) resp, err := o.core.post(o.core.GetDirectory().NewOrderURL, orderReq, &order)
if err != nil { if err != nil {
are := &acme.AlreadyReplacedError{} are := &acme.AlreadyReplacedError{}
@ -107,6 +108,7 @@ func (o *OrderService) Get(orderURL string) (acme.ExtendedOrder, error) {
} }
var order acme.Order var order acme.Order
_, err := o.core.postAsGet(orderURL, &order) _, err := o.core.postAsGet(orderURL, &order)
if err != nil { if err != nil {
return acme.ExtendedOrder{}, err return acme.ExtendedOrder{}, err
@ -122,6 +124,7 @@ func (o *OrderService) UpdateForCSR(orderURL string, csr []byte) (acme.ExtendedO
} }
var order acme.Order var order acme.Order
_, err := o.core.post(orderURL, csrMsg, &order) _, err := o.core.post(orderURL, csrMsg, &order)
if err != nil { if err != nil {
return acme.ExtendedOrder{}, err return acme.ExtendedOrder{}, err

View file

@ -32,6 +32,7 @@ func TestOrderService_NewWithOptions(t *testing.T) {
} }
order := acme.Order{} order := acme.Order{}
err = json.Unmarshal(body, &order) err = json.Unmarshal(body, &order)
if err != nil { if err != nil {
http.Error(rw, err.Error(), http.StatusBadRequest) http.Error(rw, err.Error(), http.StatusBadRequest)
@ -107,6 +108,7 @@ func readSignedBody(r *http.Request, privateKey *rsa.PrivateKey) ([]byte, error)
} }
sigAlgs := []jose.SignatureAlgorithm{jose.RS256} sigAlgs := []jose.SignatureAlgorithm{jose.RS256}
jws, err := jose.ParseSigned(string(reqBody), sigAlgs) jws, err := jose.ParseSigned(string(reqBody), sigAlgs)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -23,11 +23,13 @@ func getLinks(header http.Header, rel string) []string {
linkExpr := regexp.MustCompile(`<(.+?)>(?:;[^;]+)*?;\s*rel="(.+?)"`) linkExpr := regexp.MustCompile(`<(.+?)>(?:;[^;]+)*?;\s*rel="(.+?)"`)
var links []string var links []string
for _, link := range header["Link"] { for _, link := range header["Link"] {
for _, m := range linkExpr.FindAllStringSubmatch(link, -1) { for _, m := range linkExpr.FindAllStringSubmatch(link, -1) {
if len(m) != 3 { if len(m) != 3 {
continue continue
} }
if m[2] == rel { if m[2] == rel {
links = append(links, m[1]) links = append(links, m[1])
} }

View file

@ -84,6 +84,7 @@ type Meta struct {
// ExtendedAccount an extended Account. // ExtendedAccount an extended Account.
type ExtendedAccount struct { type ExtendedAccount struct {
Account Account
// Contains the value of the response header `Location` // Contains the value of the response header `Location`
Location string `json:"-"` Location string `json:"-"`
} }
@ -220,11 +221,11 @@ type Authorization struct {
// The timestamp after which the server will consider this authorization invalid, // The timestamp after which the server will consider this authorization invalid,
// encoded in the format specified in RFC 3339 [RFC3339]. // encoded in the format specified in RFC 3339 [RFC3339].
// This field is REQUIRED for objects with "valid" in the "status" field. // This field is REQUIRED for objects with "valid" in the "status" field.
Expires time.Time `json:"expires,omitempty"` Expires time.Time `json:"expires,omitzero"`
// identifier (required, object): // identifier (required, object):
// The identifier that the account is authorized to represent // The identifier that the account is authorized to represent
Identifier Identifier `json:"identifier,omitempty"` Identifier Identifier `json:"identifier"`
// challenges (required, array of objects): // challenges (required, array of objects):
// For pending authorizations, the challenges that the client can fulfill in order to prove possession of the identifier. // For pending authorizations, the challenges that the client can fulfill in order to prove possession of the identifier.
@ -244,6 +245,7 @@ type Authorization struct {
// ExtendedChallenge a extended Challenge. // ExtendedChallenge a extended Challenge.
type ExtendedChallenge struct { type ExtendedChallenge struct {
Challenge Challenge
// Contains the value of the response header `Retry-After` // Contains the value of the response header `Retry-After`
RetryAfter string `json:"-"` RetryAfter string `json:"-"`
// Contains the value of the response header `Link` rel="up" // Contains the value of the response header `Link` rel="up"
@ -270,7 +272,7 @@ type Challenge struct {
// The time at which the server validated this challenge, // The time at which the server validated this challenge,
// encoded in the format specified in RFC 3339 [RFC3339]. // encoded in the format specified in RFC 3339 [RFC3339].
// This field is REQUIRED if the "status" field is "valid". // This field is REQUIRED if the "status" field is "valid".
Validated time.Time `json:"validated,omitempty"` Validated time.Time `json:"validated,omitzero"`
// error (optional, object): // error (optional, object):
// Error that occurred while the server was validating the challenge, if any, // Error that occurred while the server was validating the challenge, if any,

View file

@ -2,6 +2,7 @@ package acme
import ( import (
"fmt" "fmt"
"strings"
) )
// Errors types. // Errors types.
@ -27,21 +28,25 @@ type ProblemDetails struct {
} }
func (p *ProblemDetails) Error() string { func (p *ProblemDetails) Error() string {
msg := fmt.Sprintf("acme: error: %d", p.HTTPStatus) var msg strings.Builder
msg.WriteString(fmt.Sprintf("acme: error: %d", p.HTTPStatus))
if p.Method != "" || p.URL != "" { if p.Method != "" || p.URL != "" {
msg += fmt.Sprintf(" :: %s :: %s", p.Method, p.URL) msg.WriteString(fmt.Sprintf(" :: %s :: %s", p.Method, p.URL))
} }
msg += fmt.Sprintf(" :: %s :: %s", p.Type, p.Detail)
msg.WriteString(fmt.Sprintf(" :: %s :: %s", p.Type, p.Detail))
for _, sub := range p.SubProblems { for _, sub := range p.SubProblems {
msg += fmt.Sprintf(", problem: %q :: %s", sub.Type, sub.Detail) msg.WriteString(fmt.Sprintf(", problem: %q :: %s", sub.Type, sub.Detail))
} }
if p.Instance != "" { if p.Instance != "" {
msg += ", url: " + p.Instance msg.WriteString(", url: " + p.Instance)
} }
return msg return msg.String()
} }
// SubProblem a "subproblems". // SubProblem a "subproblems".
@ -49,7 +54,7 @@ func (p *ProblemDetails) Error() string {
type SubProblem struct { type SubProblem struct {
Type string `json:"type,omitempty"` Type string `json:"type,omitempty"`
Detail string `json:"detail,omitempty"` Detail string `json:"detail,omitempty"`
Identifier Identifier `json:"identifier,omitempty"` Identifier Identifier `json:"identifier"`
} }
// NonceError represents the error which is returned // NonceError represents the error which is returned

View file

@ -57,8 +57,10 @@ type DERCertificateBytes []byte
// ParsePEMBundle parses a certificate bundle from top to bottom and returns // ParsePEMBundle parses a certificate bundle from top to bottom and returns
// a slice of x509 certificates. This function will error if no certificates are found. // a slice of x509 certificates. This function will error if no certificates are found.
func ParsePEMBundle(bundle []byte) ([]*x509.Certificate, error) { func ParsePEMBundle(bundle []byte) ([]*x509.Certificate, error) {
var certificates []*x509.Certificate var (
var certDERBlock *pem.Block certificates []*x509.Certificate
certDERBlock *pem.Block
)
for { for {
certDERBlock, bundle = pem.Decode(bundle) certDERBlock, bundle = pem.Decode(bundle)
@ -71,6 +73,7 @@ func ParsePEMBundle(bundle []byte) ([]*x509.Certificate, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
certificates = append(certificates, cert) certificates = append(certificates, cert)
} }
} }
@ -152,8 +155,11 @@ type CSROptions struct {
} }
func CreateCSR(privateKey crypto.PrivateKey, opts CSROptions) ([]byte, error) { func CreateCSR(privateKey crypto.PrivateKey, opts CSROptions) ([]byte, error) {
var dnsNames []string var (
var ipAddresses []net.IP dnsNames []string
ipAddresses []net.IP
)
for _, altname := range opts.SAN { for _, altname := range opts.SAN {
if ip := net.ParseIP(altname); ip != nil { if ip := net.ParseIP(altname); ip != nil {
ipAddresses = append(ipAddresses, ip) ipAddresses = append(ipAddresses, ip)
@ -185,6 +191,7 @@ func PEMEncode(data any) []byte {
func PEMBlock(data any) *pem.Block { func PEMBlock(data any) *pem.Block {
var pemBlock *pem.Block var pemBlock *pem.Block
switch key := data.(type) { switch key := data.(type) {
case *ecdsa.PrivateKey: case *ecdsa.PrivateKey:
keyBytes, _ := x509.MarshalECPrivateKey(key) keyBytes, _ := x509.MarshalECPrivateKey(key)
@ -265,6 +272,7 @@ func ExtractDomains(cert *x509.Certificate) []string {
if sanDomain == cert.Subject.CommonName { if sanDomain == cert.Subject.CommonName {
continue continue
} }
domains = append(domains, sanDomain) domains = append(domains, sanDomain)
} }
@ -316,6 +324,7 @@ func GeneratePemCert(privateKey *rsa.PrivateKey, domain string, extensions []pki
func generateDerCert(privateKey *rsa.PrivateKey, expiration time.Time, domain string, extensions []pkix.Extension) ([]byte, error) { func generateDerCert(privateKey *rsa.PrivateKey, expiration time.Time, domain string, extensions []pkix.Extension) ([]byte, error) {
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -179,6 +179,7 @@ func TestParsePEMPrivateKey(t *testing.T) {
// ignoring precomputed values. // ignoring precomputed values.
decoded, err := ParsePEMPrivateKey(pemPrivateKey) decoded, err := ParsePEMPrivateKey(pemPrivateKey)
require.NoError(t, err) require.NoError(t, err)
decodedRsaPrivateKey := decoded.(*rsa.PrivateKey) decodedRsaPrivateKey := decoded.(*rsa.PrivateKey)
require.True(t, decodedRsaPrivateKey.Equal(privateKey)) require.True(t, decodedRsaPrivateKey.Equal(privateKey))

View file

@ -29,6 +29,7 @@ func (c *Certifier) getAuthorizations(order acme.ExtendedOrder) ([]acme.Authoriz
var responses []acme.Authorization var responses []acme.Authorization
failures := newObtainError() failures := newObtainError()
for range len(order.Authorizations) { for range len(order.Authorizations) {
select { select {
case res := <-resc: case res := <-resc:
@ -62,6 +63,7 @@ func (c *Certifier) deactivateAuthorizations(order acme.ExtendedOrder, force boo
} }
log.Infof("Deactivating auth: %s", authzURL) log.Infof("Deactivating auth: %s", authzURL)
if c.core.Authorizations.Deactivate(authzURL) != nil { if c.core.Authorizations.Deactivate(authzURL) != nil {
log.Infof("Unable to deactivate the authorization: %s", authzURL) log.Infof("Unable to deactivate the authorization: %s", authzURL)
} }

View file

@ -198,6 +198,7 @@ func (c *Certifier) Obtain(request ObtainRequest) (*Resource, error) {
log.Infof("[%s] acme: Validations succeeded; requesting certificates", strings.Join(domains, ", ")) log.Infof("[%s] acme: Validations succeeded; requesting certificates", strings.Join(domains, ", "))
failures := newObtainError() failures := newObtainError()
cert, err := c.getForOrder(domains, order, request) cert, err := c.getForOrder(domains, order, request)
if err != nil { if err != nil {
for _, auth := range authz { for _, auth := range authz {
@ -295,6 +296,7 @@ func (c *Certifier) getForOrder(domains []string, order acme.ExtendedOrder, requ
if privateKey == nil { if privateKey == nil {
var err error var err error
privateKey, err = certcrypto.GeneratePrivateKey(c.options.KeyType) privateKey, err = certcrypto.GeneratePrivateKey(c.options.KeyType)
if err != nil { if err != nil {
return nil, err return nil, err
@ -490,6 +492,7 @@ type RenewOptions struct {
// If bundle is true, the []byte contains both the issuer certificate and your issued certificate as a bundle. // If bundle is true, the []byte contains both the issuer certificate and your issued certificate as a bundle.
// //
// For private key reuse the PrivateKey property of the passed in Resource should be non-nil. // For private key reuse the PrivateKey property of the passed in Resource should be non-nil.
//
// Deprecated: use RenewWithOptions instead. // Deprecated: use RenewWithOptions instead.
func (c *Certifier) Renew(certRes Resource, bundle, mustStaple bool, preferredChain string) (*Resource, error) { func (c *Certifier) Renew(certRes Resource, bundle, mustStaple bool, preferredChain string) (*Resource, error) {
return c.RenewWithOptions(certRes, &RenewOptions{ return c.RenewWithOptions(certRes, &RenewOptions{
@ -722,6 +725,7 @@ func checkOrderStatus(order acme.ExtendedOrder) (bool, error) {
// https://www.rfc-editor.org/rfc/rfc5280.html#section-7 // https://www.rfc-editor.org/rfc/rfc5280.html#section-7
func sanitizeDomain(domains []string) []string { func sanitizeDomain(domains []string) []string {
var sanitizedDomains []string var sanitizedDomains []string
for _, domain := range domains { for _, domain := range domains {
sanitizedDomain, err := idna.ToASCII(domain) sanitizedDomain, err := idna.ToASCII(domain)
if err != nil { if err != nil {
@ -730,5 +734,6 @@ func sanitizeDomain(domains []string) []string {
sanitizedDomains = append(sanitizedDomains, sanitizedDomain) sanitizedDomains = append(sanitizedDomains, sanitizedDomain)
} }
} }
return sanitizedDomains return sanitizedDomains
} }

View file

@ -85,6 +85,7 @@ func (c *Certifier) GetRenewalInfo(req RenewalInfoRequest) (*RenewalInfoResponse
defer resp.Body.Close() defer resp.Body.Close()
var info RenewalInfoResponse var info RenewalInfoResponse
err = json.NewDecoder(resp.Body).Decode(&info) err = json.NewDecoder(resp.Body).Decode(&info)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -40,5 +40,6 @@ func GetTargetedDomain(authz acme.Authorization) string {
if authz.Wildcard { if authz.Wildcard {
return "*." + authz.Identifier.Value return "*." + authz.Identifier.Value
} }
return authz.Identifier.Value return authz.Identifier.Value
} }

View file

@ -40,6 +40,7 @@ func CondOption(condition bool, opt ChallengeOption) ChallengeOption {
return nil return nil
} }
} }
return opt return opt
} }
@ -118,6 +119,7 @@ func (c *Challenge) Solve(authz acme.Authorization) error {
info := GetChallengeInfo(authz.Identifier.Value, keyAuth) info := GetChallengeInfo(authz.Identifier.Value, keyAuth)
var timeout, interval time.Duration var timeout, interval time.Duration
switch provider := c.provider.(type) { switch provider := c.provider.(type) {
case challenge.ProviderTimeout: case challenge.ProviderTimeout:
timeout, interval = provider.Timeout() timeout, interval = provider.Timeout()
@ -134,6 +136,7 @@ func (c *Challenge) Solve(authz acme.Authorization) error {
if !stop || errP != nil { if !stop || errP != nil {
log.Infof("[%s] acme: Waiting for DNS record propagation.", domain) log.Infof("[%s] acme: Waiting for DNS record propagation.", domain)
} }
return stop, errP return stop, errP
}) })
if err != nil { if err != nil {
@ -141,6 +144,7 @@ func (c *Challenge) Solve(authz acme.Authorization) error {
} }
chlng.KeyAuthorization = keyAuth chlng.KeyAuthorization = keyAuth
return c.validate(c.core, domain, chlng) return c.validate(c.core, domain, chlng)
} }
@ -165,6 +169,7 @@ func (c *Challenge) Sequential() (bool, time.Duration) {
if p, ok := c.provider.(sequential); ok { if p, ok := c.provider.(sequential); ok {
return ok, p.Sequential() return ok, p.Sequential()
} }
return false, 0 return false, 0
} }
@ -173,6 +178,7 @@ type sequential interface {
} }
// GetRecord returns a DNS record which will fulfill the `dns-01` challenge. // GetRecord returns a DNS record which will fulfill the `dns-01` challenge.
//
// Deprecated: use GetChallengeInfo instead. // Deprecated: use GetChallengeInfo instead.
func GetRecord(domain, keyAuth string) (fqdn, value string) { func GetRecord(domain, keyAuth string) (fqdn, value string) {
info := GetChallengeInfo(domain, keyAuth) info := GetChallengeInfo(domain, keyAuth)

View file

@ -18,6 +18,7 @@ func TestDNSProviderManual(t *testing.T) {
Build(t)) Build(t))
backupStdin := os.Stdin backupStdin := os.Stdin
defer func() { os.Stdin = backupStdin }() defer func() { os.Stdin = backupStdin }()
testCases := []struct { testCases := []struct {

View file

@ -184,6 +184,7 @@ func TestChallenge_Solve(t *testing.T) {
if test.preCheck != nil { if test.preCheck != nil {
options = append(options, WrapPreCheck(test.preCheck)) options = append(options, WrapPreCheck(test.preCheck))
} }
chlg := NewChallenge(core, test.validate, test.provider, options...) chlg := NewChallenge(core, test.validate, test.provider, options...)
authz := acme.Authorization{ authz := acme.Authorization{

View file

@ -19,6 +19,7 @@ func UnFqdn(name string) string {
if n != 0 && name[n-1] == '.' { if n != 0 && name[n-1] == '.' {
return name[:n-1] return name[:n-1]
} }
return name return name
} }

View file

@ -40,6 +40,7 @@ func mockResolver(t *testing.T, addr net.Addr) {
require.NoError(t, err) require.NoError(t, err)
originalDefaultNameserverPort := defaultNameserverPort originalDefaultNameserverPort := defaultNameserverPort
t.Cleanup(func() { t.Cleanup(func() {
defaultNameserverPort = originalDefaultNameserverPort defaultNameserverPort = originalDefaultNameserverPort
}) })
@ -47,6 +48,7 @@ func mockResolver(t *testing.T, addr net.Addr) {
defaultNameserverPort = port defaultNameserverPort = port
originalResolver := net.DefaultResolver originalResolver := net.DefaultResolver
t.Cleanup(func() { t.Cleanup(func() {
net.DefaultResolver = originalResolver net.DefaultResolver = originalResolver
}) })
@ -70,6 +72,7 @@ func useAsNameserver(t *testing.T, addr net.Addr) {
}) })
originalRecursiveNameservers := recursiveNameservers originalRecursiveNameservers := recursiveNameservers
t.Cleanup(func() { t.Cleanup(func() {
recursiveNameservers = originalRecursiveNameservers recursiveNameservers = originalRecursiveNameservers
}) })

View file

@ -81,6 +81,7 @@ func getNameservers(path string, defaults []string) []string {
func ParseNameservers(servers []string) []string { func ParseNameservers(servers []string) []string {
var resolvers []string var resolvers []string
for _, resolver := range servers { for _, resolver := range servers {
// ensure all servers have a port number // ensure all servers have a port number
if _, _, err := net.SplitHostPort(resolver); err != nil { if _, _, err := net.SplitHostPort(resolver); err != nil {
@ -89,6 +90,7 @@ func ParseNameservers(servers []string) []string {
resolvers = append(resolvers, resolver) resolvers = append(resolvers, resolver)
} }
} }
return resolvers return resolvers
} }
@ -132,6 +134,7 @@ func FindPrimaryNsByFqdnCustom(fqdn string, nameservers []string) (string, error
if err != nil { if err != nil {
return "", fmt.Errorf("[fqdn=%s] %w", fqdn, err) return "", fmt.Errorf("[fqdn=%s] %w", fqdn, err)
} }
return soa.primaryNs, nil return soa.primaryNs, nil
} }
@ -148,6 +151,7 @@ func FindZoneByFqdnCustom(fqdn string, nameservers []string) (string, error) {
if err != nil { if err != nil {
return "", fmt.Errorf("[fqdn=%s] %w", fqdn, err) return "", fmt.Errorf("[fqdn=%s] %w", fqdn, err)
} }
return soa.zone, nil return soa.zone, nil
} }
@ -172,8 +176,10 @@ func lookupSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error)
} }
func fetchSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error) { func fetchSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error) {
var err error var (
var r *dns.Msg err error
r *dns.Msg
)
for domain := range DomainsSeq(fqdn) { for domain := range DomainsSeq(fqdn) {
r, err = dnsQuery(domain, dns.TypeSOA, nameservers, true) r, err = dnsQuery(domain, dns.TypeSOA, nameservers, true)
@ -229,9 +235,11 @@ func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (
return nil, &DNSError{Message: "empty list of nameservers"} return nil, &DNSError{Message: "empty list of nameservers"}
} }
var r *dns.Msg var (
var err error r *dns.Msg
var errAll error err error
errAll error
)
for _, ns := range nameservers { for _, ns := range nameservers {
r, err = sendDNSQuery(m, ns) r, err = sendDNSQuery(m, ns)
@ -264,6 +272,7 @@ func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {
func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) { func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) {
if ok, _ := strconv.ParseBool(os.Getenv("LEGO_EXPERIMENTAL_DNS_TCP_ONLY")); ok { if ok, _ := strconv.ParseBool(os.Getenv("LEGO_EXPERIMENTAL_DNS_TCP_ONLY")); ok {
tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout} tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout}
r, _, err := tcp.Exchange(m, ns) r, _, err := tcp.Exchange(m, ns)
if err != nil { if err != nil {
return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err} return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}

View file

@ -29,6 +29,7 @@ func WrapPreCheck(wrap WrapPreCheckFunc) ChallengeOption {
} }
// DisableCompletePropagationRequirement obsolete. // DisableCompletePropagationRequirement obsolete.
//
// Deprecated: use DisableAuthoritativeNssPropagationRequirement instead. // Deprecated: use DisableAuthoritativeNssPropagationRequirement instead.
func DisableCompletePropagationRequirement() ChallengeOption { func DisableCompletePropagationRequirement() ChallengeOption {
return DisableAuthoritativeNssPropagationRequirement() return DisableAuthoritativeNssPropagationRequirement()
@ -140,9 +141,11 @@ func checkNameserversPropagation(fqdn, value string, nameservers []string, addPo
var records []string var records []string
var found bool var found bool
for _, rr := range r.Answer { for _, rr := range r.Answer {
if txt, ok := rr.(*dns.TXT); ok { if txt, ok := rr.(*dns.TXT); ok {
record := strings.Join(txt.Txt, "") record := strings.Join(txt.Txt, "")
records = append(records, record) records = append(records, record)
if record == value { if record == value {
found = true found = true

View file

@ -88,6 +88,7 @@ func (m *forwardedMatcher) matches(r *http.Request, domain string) bool {
} }
host := fwds[0]["host"] host := fwds[0]["host"]
return matchDomain(host, domain) return matchDomain(host, domain)
} }
@ -99,6 +100,7 @@ func parseForwardedHeader(s string) (elements []map[string]string, err error) {
inquote := false inquote := false
pos := 0 pos := 0
l := len(s) l := len(s)
for i := 0; i < l; i++ { for i := 0; i < l; i++ {
r := rune(s[i]) r := rune(s[i])
@ -110,6 +112,7 @@ func parseForwardedHeader(s string) (elements []map[string]string, err error) {
pos = i pos = i
inquote = false inquote = false
} }
continue continue
} }
@ -118,6 +121,7 @@ func parseForwardedHeader(s string) (elements []map[string]string, err error) {
if key == "" { if key == "" {
return nil, fmt.Errorf("unexpected quoted string as pos %d", i) return nil, fmt.Errorf("unexpected quoted string as pos %d", i)
} }
inquote = true inquote = true
pos = i + 1 pos = i + 1
@ -137,6 +141,7 @@ func parseForwardedHeader(s string) (elements []map[string]string, err error) {
val = s[pos:i] val = s[pos:i]
cur[key] = val cur[key] = val
} }
elements = append(elements, cur) elements = append(elements, cur)
cur = make(map[string]string) cur = make(map[string]string)
key = "" key = ""
@ -159,11 +164,14 @@ func parseForwardedHeader(s string) (elements []map[string]string, err error) {
if pos < len(s) { if pos < len(s) {
val = s[pos:] val = s[pos:]
} }
cur[key] = val cur[key] = val
} }
if len(cur) > 0 { if len(cur) > 0 {
elements = append(elements, cur) elements = append(elements, cur)
} }
return elements, nil return elements, nil
} }
@ -178,6 +186,7 @@ func skipWS(s string, i int) int {
for isWS(rune(s[i+1])) { for isWS(rune(s[i+1])) {
i++ i++
} }
return i return i
} }

View file

@ -74,6 +74,7 @@ func (c *Challenge) Solve(authz acme.Authorization) error {
if err != nil { if err != nil {
return fmt.Errorf("[%s] acme: error presenting token: %w", domain, err) return fmt.Errorf("[%s] acme: error presenting token: %w", domain, err)
} }
defer func() { defer func() {
err := c.provider.CleanUp(authz.Identifier.Value, chlng.Token, keyAuth) err := c.provider.CleanUp(authz.Identifier.Value, chlng.Token, keyAuth)
if err != nil { if err != nil {
@ -86,5 +87,6 @@ func (c *Challenge) Solve(authz acme.Authorization) error {
} }
chlng.KeyAuthorization = keyAuth chlng.KeyAuthorization = keyAuth
return c.validate(c.core, domain, chlng) return c.validate(c.core, domain, chlng)
} }

View file

@ -44,6 +44,7 @@ func NewUnixProviderServer(socketPath string, mode fs.FileMode) *ProviderServer
// Present starts a web server and makes the token available at `ChallengePath(token)` for web requests. // Present starts a web server and makes the token available at `ChallengePath(token)` for web requests.
func (s *ProviderServer) Present(domain, token, keyAuth string) error { func (s *ProviderServer) Present(domain, token, keyAuth string) error {
var err error var err error
s.listener, err = net.Listen(s.network, s.GetAddress()) s.listener, err = net.Listen(s.network, s.GetAddress())
if err != nil { if err != nil {
return fmt.Errorf("could not start HTTP server for challenge: %w", err) return fmt.Errorf("could not start HTTP server for challenge: %w", err)
@ -120,6 +121,7 @@ func (s *ProviderServer) serve(domain, token, keyAuth string) {
} }
log.Infof("[%s] Served key authentication", domain) log.Infof("[%s] Served key authentication", domain)
return return
} }

View file

@ -88,6 +88,7 @@ func TestChallenge(t *testing.T) {
if err != nil { if err != nil {
return err return err
} }
bodyStr := string(body) bodyStr := string(body)
if bodyStr != chlng.KeyAuthorization { if bodyStr != chlng.KeyAuthorization {
@ -157,6 +158,7 @@ func TestChallengeUnix(t *testing.T) {
if err != nil { if err != nil {
return err return err
} }
bodyStr := string(body) bodyStr := string(body)
if bodyStr != chlng.KeyAuthorization { if bodyStr != chlng.KeyAuthorization {
@ -224,6 +226,7 @@ func (h *testProxyHeader) update(r *http.Request) {
if h == nil || len(h.values) == 0 { if h == nil || len(h.values) == 0 {
return return
} }
if h.name == "Host" { if h.name == "Host" {
r.Host = h.values[0] r.Host = h.values[0]
} else if h.name != "" { } else if h.name != "" {
@ -385,6 +388,7 @@ func testServeWithProxy(t *testing.T, header, extra *testProxyHeader, expectErro
if err != nil { if err != nil {
return err return err
} }
header.update(req) header.update(req)
extra.update(req) extra.update(req)
@ -402,6 +406,7 @@ func testServeWithProxy(t *testing.T, header, extra *testProxyHeader, expectErro
if err != nil { if err != nil {
return err return err
} }
bodyStr := string(body) bodyStr := string(body)
if bodyStr != chlng.KeyAuthorization { if bodyStr != chlng.KeyAuthorization {

View file

@ -16,10 +16,12 @@ func (e obtainError) Error() string {
for domain := range e { for domain := range e {
domains = append(domains, domain) domains = append(domains, domain)
} }
sort.Strings(domains) sort.Strings(domains)
for _, domain := range domains { for _, domain := range domains {
_, _ = fmt.Fprintf(buffer, "[%s] %s\n", domain, e[domain]) _, _ = fmt.Fprintf(buffer, "[%s] %s\n", domain, e[domain])
} }
return buffer.String() return buffer.String()
} }

View file

@ -50,11 +50,14 @@ func NewProber(solverManager *SolverManager) *Prober {
func (p *Prober) Solve(authorizations []acme.Authorization) error { func (p *Prober) Solve(authorizations []acme.Authorization) error {
failures := make(obtainError) failures := make(obtainError)
var authSolvers []*selectedAuthSolver var (
var authSolversSequential []*selectedAuthSolver authSolvers []*selectedAuthSolver
authSolversSequential []*selectedAuthSolver
)
// Loop through the resources, basically through the domains. // Loop through the resources, basically through the domains.
// First pass just selects a solver for each authz. // First pass just selects a solver for each authz.
for _, authz := range authorizations { for _, authz := range authorizations {
domain := challenge.GetTargetedDomain(authz) domain := challenge.GetTargetedDomain(authz)
if authz.Status == acme.StatusValid { if authz.Status == acme.StatusValid {
@ -90,6 +93,7 @@ func (p *Prober) Solve(authorizations []acme.Authorization) error {
if len(failures) > 0 { if len(failures) > 0 {
return failures return failures
} }
return nil return nil
} }
@ -102,7 +106,9 @@ func sequentialSolve(authSolvers []*selectedAuthSolver, failures obtainError) {
err := solvr.PreSolve(authSolver.authz) err := solvr.PreSolve(authSolver.authz)
if err != nil { if err != nil {
failures[domain] = err failures[domain] = err
cleanUp(authSolver.solver, authSolver.authz) cleanUp(authSolver.solver, authSolver.authz)
continue continue
} }
} }
@ -111,7 +117,9 @@ func sequentialSolve(authSolvers []*selectedAuthSolver, failures obtainError) {
err := authSolver.solver.Solve(authSolver.authz) err := authSolver.solver.Solve(authSolver.authz)
if err != nil { if err != nil {
failures[domain] = err failures[domain] = err
cleanUp(authSolver.solver, authSolver.authz) cleanUp(authSolver.solver, authSolver.authz)
continue continue
} }
@ -149,6 +157,7 @@ func parallelSolve(authSolvers []*selectedAuthSolver, failures obtainError) {
// Finally solve all challenges for real // Finally solve all challenges for real
for _, authSolver := range authSolvers { for _, authSolver := range authSolvers {
authz := authSolver.authz authz := authSolver.authz
domain := challenge.GetTargetedDomain(authz) domain := challenge.GetTargetedDomain(authz)
if failures[domain] != nil { if failures[domain] != nil {
// already failed in previous loop // already failed in previous loop
@ -165,6 +174,7 @@ func parallelSolve(authSolvers []*selectedAuthSolver, failures obtainError) {
func cleanUp(solvr solver, authz acme.Authorization) { func cleanUp(solvr solver, authz acme.Authorization) {
if solvr, ok := solvr.(cleanup); ok { if solvr, ok := solvr.(cleanup); ok {
domain := challenge.GetTargetedDomain(authz) domain := challenge.GetTargetedDomain(authz)
err := solvr.CleanUp(authz) err := solvr.CleanUp(authz)
if err != nil { if err != nil {
log.Warnf("[%s] acme: cleaning up failed: %v ", domain, err) log.Warnf("[%s] acme: cleaning up failed: %v ", domain, err)

View file

@ -71,6 +71,7 @@ func (c *SolverManager) chooseSolver(authz acme.Authorization) solver {
log.Infof("[%s] acme: use %s solver", domain, chlg.Type) log.Infof("[%s] acme: use %s solver", domain, chlg.Type)
return solvr return solvr
} }
log.Infof("[%s] acme: Could not find solver for: %s", domain, chlg.Type) log.Infof("[%s] acme: Could not find solver for: %s", domain, chlg.Type)
} }
@ -101,6 +102,7 @@ func validate(core *api.Core, domain string, chlg acme.Challenge) error {
// https://github.com/letsencrypt/boulder/blob/master/docs/acme-divergences.md#section-82 // https://github.com/letsencrypt/boulder/blob/master/docs/acme-divergences.md#section-82
ra = 5 ra = 5
} }
initialInterval := time.Duration(ra) * time.Second initialInterval := time.Duration(ra) * time.Second
ctx := context.Background() ctx := context.Background()
@ -162,6 +164,7 @@ func checkAuthorizationStatus(authz acme.Authorization) (bool, error) {
return false, fmt.Errorf("invalid authorization: %w", chlg.Err()) return false, fmt.Errorf("invalid authorization: %w", chlg.Err())
} }
} }
return false, errors.New("invalid authorization") return false, errors.New("invalid authorization")
default: default:
return false, fmt.Errorf("the server returned an unexpected authorization status: %s", authz.Status) return false, fmt.Errorf("the server returned an unexpected authorization status: %s", authz.Status)

View file

@ -260,6 +260,7 @@ func validateNoBody(privateKey *rsa.PrivateKey, r *http.Request) error {
} }
sigAlgs := []jose.SignatureAlgorithm{jose.RS256} sigAlgs := []jose.SignatureAlgorithm{jose.RS256}
jws, err := jose.ParseSigned(string(reqBody), sigAlgs) jws, err := jose.ParseSigned(string(reqBody), sigAlgs)
if err != nil { if err != nil {
return err return err
@ -276,5 +277,6 @@ func validateNoBody(privateKey *rsa.PrivateKey, r *http.Request) error {
if bodyStr := string(body); bodyStr != "{}" && bodyStr != "" { if bodyStr := string(body); bodyStr != "{}" && bodyStr != "" {
return fmt.Errorf(`expected JWS POST body "{}" or "", got %q`, bodyStr) return fmt.Errorf(`expected JWS POST body "{}" or "", got %q`, bodyStr)
} }
return nil return nil
} }

View file

@ -80,6 +80,7 @@ func (c *Challenge) Solve(authz acme.Authorization) error {
if err != nil { if err != nil {
return fmt.Errorf("[%s] acme: error presenting token: %w", challenge.GetTargetedDomain(authz), err) return fmt.Errorf("[%s] acme: error presenting token: %w", challenge.GetTargetedDomain(authz), err)
} }
defer func() { defer func() {
err := c.provider.CleanUp(domain, chlng.Token, keyAuth) err := c.provider.CleanUp(domain, chlng.Token, keyAuth)
if err != nil { if err != nil {
@ -92,6 +93,7 @@ func (c *Challenge) Solve(authz acme.Authorization) error {
} }
chlng.KeyAuthorization = keyAuth chlng.KeyAuthorization = keyAuth
return c.validate(c.core, domain, chlng) return c.validate(c.core, domain, chlng)
} }

View file

@ -42,6 +42,7 @@ func TestChallenge(t *testing.T) {
assert.NotEmpty(t, remoteCert.Extensions, "Expected the challenge certificate to contain extensions") assert.NotEmpty(t, remoteCert.Extensions, "Expected the challenge certificate to contain extensions")
idx := -1 idx := -1
for i, ext := range remoteCert.Extensions { for i, ext := range remoteCert.Extensions {
if idPeAcmeIdentifierV1.Equal(ext.Id) { if idPeAcmeIdentifierV1.Equal(ext.Id) {
idx = i idx = i
@ -145,18 +146,24 @@ func TestChallengeIPaddress(t *testing.T) {
assert.True(t, net.ParseIP("127.0.0.1").Equal(remoteCert.IPAddresses[0]), "challenge certificate IPAddress ") assert.True(t, net.ParseIP("127.0.0.1").Equal(remoteCert.IPAddresses[0]), "challenge certificate IPAddress ")
assert.NotEmpty(t, remoteCert.Extensions, "Expected the challenge certificate to contain extensions") assert.NotEmpty(t, remoteCert.Extensions, "Expected the challenge certificate to contain extensions")
var foundAcmeIdentifier bool var (
var extValue []byte foundAcmeIdentifier bool
extValue []byte
)
for _, ext := range remoteCert.Extensions { for _, ext := range remoteCert.Extensions {
if idPeAcmeIdentifierV1.Equal(ext.Id) { if idPeAcmeIdentifierV1.Equal(ext.Id) {
assert.True(t, ext.Critical, "Expected the challenge certificate id-pe-acmeIdentifier extension to be marked as critical") assert.True(t, ext.Critical, "Expected the challenge certificate id-pe-acmeIdentifier extension to be marked as critical")
foundAcmeIdentifier = true foundAcmeIdentifier = true
extValue = ext.Value extValue = ext.Value
break break
} }
} }
require.True(t, foundAcmeIdentifier, "Expected the challenge certificate to contain an extension with the id-pe-acmeIdentifier id,") require.True(t, foundAcmeIdentifier, "Expected the challenge certificate to contain an extension with the id-pe-acmeIdentifier id,")
zBytes := sha256.Sum256([]byte(chlng.KeyAuthorization)) zBytes := sha256.Sum256([]byte(chlng.KeyAuthorization))
value, err := asn1.Marshal(zBytes[:sha256.Size]) value, err := asn1.Marshal(zBytes[:sha256.Size])
require.NoError(t, err, "Expected marshaling of the keyAuth to return no error") require.NoError(t, err, "Expected marshaling of the keyAuth to return no error")

View file

@ -96,6 +96,7 @@ func (s *AccountsStorage) ExistsAccountFilePath() bool {
} else if err != nil { } else if err != nil {
log.Fatal(err) log.Fatal(err)
} }
return true return true
} }
@ -127,6 +128,7 @@ func (s *AccountsStorage) LoadAccount(privateKey crypto.PrivateKey) *Account {
} }
var account Account var account Account
err = json.Unmarshal(fileBytes, &account) err = json.Unmarshal(fileBytes, &account)
if err != nil { if err != nil {
log.Fatalf("Could not parse file for account %s: %v", s.userID, err) log.Fatalf("Could not parse file for account %s: %v", s.userID, err)
@ -141,6 +143,7 @@ func (s *AccountsStorage) LoadAccount(privateKey crypto.PrivateKey) *Account {
} }
account.Registration = reg account.Registration = reg
err = s.Save(&account) err = s.Save(&account)
if err != nil { if err != nil {
log.Fatalf("Could not save account for %s. Registration is nil: %#v", s.userID, err) log.Fatalf("Could not save account for %s. Registration is nil: %#v", s.userID, err)
@ -163,6 +166,7 @@ func (s *AccountsStorage) GetPrivateKey(keyType certcrypto.KeyType) crypto.Priva
} }
log.Printf("Saved key to %s", accKeyPath) log.Printf("Saved key to %s", accKeyPath)
return privateKey return privateKey
} }
@ -193,6 +197,7 @@ func generatePrivateKey(file string, keyType certcrypto.KeyType) (crypto.Private
defer certOut.Close() defer certOut.Close()
pemKey := certcrypto.PEMBlock(privateKey) pemKey := certcrypto.PEMBlock(privateKey)
err = pem.Encode(certOut, pemKey) err = pem.Encode(certOut, pemKey)
if err != nil { if err != nil {
return nil, err return nil, err
@ -211,6 +216,7 @@ func loadPrivateKey(file string) (crypto.PrivateKey, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return privateKey, nil return privateKey, nil
} }
@ -229,5 +235,6 @@ func tryRecoverRegistration(ctx *cli.Context, privateKey crypto.PrivateKey) (*re
if err != nil { if err != nil {
return nil, err return nil, err
} }
return reg, nil return reg, nil
} }

View file

@ -158,6 +158,7 @@ func (s *CertificatesStorage) ExistsFile(domain, extension string) bool {
} else if err != nil { } else if err != nil {
log.Fatal(err) log.Fatal(err)
} }
return true return true
} }
@ -283,6 +284,7 @@ func getCertificateChain(certRes *certificate.Resource) ([]*x509.Certificate, er
} }
var certChain []*x509.Certificate var certChain []*x509.Certificate
for chainCertPemBlock != nil { for chainCertPemBlock != nil {
chainCert, err := x509.ParseCertificate(chainCertPemBlock.Bytes) chainCert, err := x509.ParseCertificate(chainCertPemBlock.Bytes)
if err != nil { if err != nil {
@ -298,6 +300,7 @@ func getCertificateChain(certRes *certificate.Resource) ([]*x509.Certificate, er
func getPFXEncoder(pfxFormat string) (*pkcs12.Encoder, error) { func getPFXEncoder(pfxFormat string) (*pkcs12.Encoder, error) {
var encoder *pkcs12.Encoder var encoder *pkcs12.Encoder
switch pfxFormat { switch pfxFormat {
case "SHA256": case "SHA256":
encoder = pkcs12.Modern2023 encoder = pkcs12.Modern2023
@ -318,5 +321,6 @@ func sanitizedDomain(domain string) string {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
return safe return safe
} }

View file

@ -67,6 +67,7 @@ func listCertificates(ctx *cli.Context) error {
if !names { if !names {
fmt.Println("No certificates found.") fmt.Println("No certificates found.")
} }
return nil return nil
} }
@ -122,6 +123,7 @@ func listAccount(ctx *cli.Context) error {
} }
fmt.Println("Found the following accounts:") fmt.Println("Found the following accounts:")
for _, filename := range matches { for _, filename := range matches {
data, err := os.ReadFile(filename) data, err := os.ReadFile(filename)
if err != nil { if err != nil {
@ -129,6 +131,7 @@ func listAccount(ctx *cli.Context) error {
} }
var account Account var account Account
err = json.Unmarshal(data, &account) err = json.Unmarshal(data, &account)
if err != nil { if err != nil {
return err return err

View file

@ -39,16 +39,20 @@ func createRenew() *cli.Command {
Before: func(ctx *cli.Context) error { Before: func(ctx *cli.Context) error {
// we require either domains or csr, but not both // we require either domains or csr, but not both
hasDomains := len(ctx.StringSlice(flgDomains)) > 0 hasDomains := len(ctx.StringSlice(flgDomains)) > 0
hasCsr := ctx.String(flgCSR) != "" hasCsr := ctx.String(flgCSR) != ""
if hasDomains && hasCsr { if hasDomains && hasCsr {
log.Fatalf("Please specify either --%s/-d or --%s/-c, but not both", flgDomains, flgCSR) log.Fatalf("Please specify either --%s/-d or --%s/-c, but not both", flgDomains, flgCSR)
} }
if !hasDomains && !hasCsr { if !hasDomains && !hasCsr {
log.Fatalf("Please specify --%s/-d (or --%s/-c if you already have a CSR)", flgDomains, flgCSR) log.Fatalf("Please specify --%s/-d (or --%s/-c if you already have a CSR)", flgDomains, flgCSR)
} }
if ctx.Bool(flgForceCertDomains) && hasCsr { if ctx.Bool(flgForceCertDomains) && hasCsr {
log.Fatalf("--%s only works with --%s/-d, --%s/-c doesn't support this option.", flgForceCertDomains, flgDomains, flgCSR) log.Fatalf("--%s only works with --%s/-d, --%s/-c doesn't support this option.", flgForceCertDomains, flgDomains, flgCSR)
} }
return nil return nil
}, },
Flags: []cli.Flag{ Flags: []cli.Flag{
@ -165,8 +169,10 @@ func renewForDomains(ctx *cli.Context, account *Account, keyType certcrypto.KeyT
cert := certificates[0] cert := certificates[0]
var ariRenewalTime *time.Time var (
var replacesCertID string ariRenewalTime *time.Time
replacesCertID string
)
var client *lego.Client var client *lego.Client
@ -208,6 +214,7 @@ func renewForDomains(ctx *cli.Context, account *Account, keyType certcrypto.KeyT
log.Infof("[%s] acme: Trying renewal with %d hours remaining", domain, int(timeLeft.Hours())) log.Infof("[%s] acme: Trying renewal with %d hours remaining", domain, int(timeLeft.Hours()))
var privateKey crypto.PrivateKey var privateKey crypto.PrivateKey
if ctx.Bool(flgReuseKey) { if ctx.Bool(flgReuseKey) {
keyBytes, errR := certsStorage.ReadFile(domain, keyExt) keyBytes, errR := certsStorage.ReadFile(domain, keyExt)
if errR != nil { if errR != nil {
@ -225,6 +232,7 @@ func renewForDomains(ctx *cli.Context, account *Account, keyType certcrypto.KeyT
if !isatty.IsTerminal(os.Stdout.Fd()) && !ctx.Bool(flgNoRandomSleep) { if !isatty.IsTerminal(os.Stdout.Fd()) && !ctx.Bool(flgNoRandomSleep) {
// https://github.com/certbot/certbot/blob/284023a1b7672be2bd4018dd7623b3b92197d4b0/certbot/certbot/_internal/renewal.py#L472 // https://github.com/certbot/certbot/blob/284023a1b7672be2bd4018dd7623b3b92197d4b0/certbot/certbot/_internal/renewal.py#L472
const jitter = 8 * time.Minute const jitter = 8 * time.Minute
rnd := rand.New(rand.NewSource(time.Now().UnixNano())) rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
sleepTime := time.Duration(rnd.Int63n(int64(jitter))) sleepTime := time.Duration(rnd.Int63n(int64(jitter)))
@ -288,8 +296,10 @@ func renewForCSR(ctx *cli.Context, account *Account, keyType certcrypto.KeyType,
cert := certificates[0] cert := certificates[0]
var ariRenewalTime *time.Time var (
var replacesCertID string ariRenewalTime *time.Time
replacesCertID string
)
var client *lego.Client var client *lego.Client
@ -408,16 +418,20 @@ func getARIRenewalTime(ctx *cli.Context, cert *x509.Certificate, domain string,
log.Warnf("[%s] acme: %v", domain, err) log.Warnf("[%s] acme: %v", domain, err)
return nil return nil
} }
log.Warnf("[%s] acme: calling renewal info endpoint: %v", domain, err) log.Warnf("[%s] acme: calling renewal info endpoint: %v", domain, err)
return nil return nil
} }
now := time.Now().UTC() now := time.Now().UTC()
renewalTime := renewalInfo.ShouldRenewAt(now, ctx.Duration(flgARIWaitToRenewDuration)) renewalTime := renewalInfo.ShouldRenewAt(now, ctx.Duration(flgARIWaitToRenewDuration))
if renewalTime == nil { if renewalTime == nil {
log.Infof("[%s] acme: renewalInfo endpoint indicates that renewal is not needed", domain) log.Infof("[%s] acme: renewalInfo endpoint indicates that renewal is not needed", domain)
return nil return nil
} }
log.Infof("[%s] acme: renewalInfo endpoint indicates that renewal is needed", domain) log.Infof("[%s] acme: renewalInfo endpoint indicates that renewal is needed", domain)
if renewalInfo.ExplanationURL != "" { if renewalInfo.ExplanationURL != "" {

View file

@ -35,13 +35,16 @@ func createRun() *cli.Command {
Before: func(ctx *cli.Context) error { Before: func(ctx *cli.Context) error {
// we require either domains or csr, but not both // we require either domains or csr, but not both
hasDomains := len(ctx.StringSlice(flgDomains)) > 0 hasDomains := len(ctx.StringSlice(flgDomains)) > 0
hasCsr := ctx.String(flgCSR) != "" hasCsr := ctx.String(flgCSR) != ""
if hasDomains && hasCsr { if hasDomains && hasCsr {
log.Fatal("Please specify either --domains/-d or --csr/-c, but not both") log.Fatal("Please specify either --domains/-d or --csr/-c, but not both")
} }
if !hasDomains && !hasCsr { if !hasDomains && !hasCsr {
log.Fatal("Please specify --domains/-d (or --csr/-c if you already have a CSR)") log.Fatal("Please specify --domains/-d (or --csr/-c if you already have a CSR)")
} }
return nil return nil
}, },
Action: run, Action: run,
@ -155,10 +158,12 @@ func handleTOS(ctx *cli.Context, client *lego.Client) bool {
} }
reader := bufio.NewReader(os.Stdin) reader := bufio.NewReader(os.Stdin)
log.Printf("Please review the TOS at %s", client.GetToSURL()) log.Printf("Please review the TOS at %s", client.GetToSURL())
for { for {
fmt.Println("Do you accept the TOS? Y/n") fmt.Println("Do you accept the TOS? Y/n")
text, err := reader.ReadString('\n') text, err := reader.ReadString('\n')
if err != nil { if err != nil {
log.Fatalf("Could not read from console: %v", err) log.Fatalf("Could not read from console: %v", err)
@ -219,6 +224,7 @@ func obtainCertificate(ctx *cli.Context, client *lego.Client) (*certificate.Reso
if ctx.IsSet(flgPrivateKey) { if ctx.IsSet(flgPrivateKey) {
var err error var err error
request.PrivateKey, err = loadPrivateKey(ctx.String(flgPrivateKey)) request.PrivateKey, err = loadPrivateKey(ctx.String(flgPrivateKey))
if err != nil { if err != nil {
return nil, fmt.Errorf("load private key: %w", err) return nil, fmt.Errorf("load private key: %w", err)
@ -247,6 +253,7 @@ func obtainCertificate(ctx *cli.Context, client *lego.Client) (*certificate.Reso
if ctx.IsSet(flgPrivateKey) { if ctx.IsSet(flgPrivateKey) {
var err error var err error
request.PrivateKey, err = loadPrivateKey(ctx.String(flgPrivateKey)) request.PrivateKey, err = loadPrivateKey(ctx.String(flgPrivateKey))
if err != nil { if err != nil {
return nil, fmt.Errorf("load private key: %w", err) return nil, fmt.Errorf("load private key: %w", err)

View file

@ -258,5 +258,6 @@ func getTime(ctx *cli.Context, name string) time.Time {
if value == nil { if value == nil {
return time.Time{} return time.Time{}
} }
return *value return *value
} }

View file

@ -34,6 +34,7 @@ func launchHook(hook string, timeout time.Duration, meta map[string]string) erro
parts := strings.Fields(hook) parts := strings.Fields(hook)
cmd := exec.CommandContext(ctxCmd, parts[0], parts[1:]...) cmd := exec.CommandContext(ctxCmd, parts[0], parts[1:]...)
cmd.Env = append(os.Environ(), metaToEnv(meta)...) cmd.Env = append(os.Environ(), metaToEnv(meta)...)
stdout, err := cmd.StdoutPipe() stdout, err := cmd.StdoutPipe()
@ -50,6 +51,7 @@ func launchHook(hook string, timeout time.Duration, meta map[string]string) erro
go func() { go func() {
<-ctxCmd.Done() <-ctxCmd.Done()
if ctxCmd.Err() != nil { if ctxCmd.Err() != nil {
_ = cmd.Process.Kill() _ = cmd.Process.Kill()
_ = stdout.Close() _ = stdout.Close()

View file

@ -26,6 +26,7 @@ func main() {
} }
var defaultPath string var defaultPath string
cwd, err := os.Getwd() cwd, err := os.Getwd()
if err == nil { if err == nil {
defaultPath = filepath.Join(cwd, ".lego") defaultPath = filepath.Join(cwd, ".lego")

View file

@ -114,6 +114,7 @@ func getKeyType(ctx *cli.Context) certcrypto.KeyType {
} }
log.Fatalf("Unsupported KeyType: %s", keyType) log.Fatalf("Unsupported KeyType: %s", keyType)
return "" return ""
} }
@ -122,6 +123,7 @@ func getEmail(ctx *cli.Context) string {
if email == "" { if email == "" {
log.Fatalf("You have to pass an account (email address) to the program using --%s or -m", flgEmail) log.Fatalf("You have to pass an account (email address) to the program using --%s or -m", flgEmail)
} }
return email return email
} }
@ -135,6 +137,7 @@ func createNonExistingFolder(path string) error {
} else if err != nil { } else if err != nil {
return err return err
} }
return nil return nil
} }
@ -143,10 +146,12 @@ func readCSRFile(filename string) (*x509.CertificateRequest, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
raw := bytes raw := bytes
// see if we can find a PEM-encoded CSR // see if we can find a PEM-encoded CSR
var p *pem.Block var p *pem.Block
rest := bytes rest := bytes
for { for {
// decode a PEM block // decode a PEM block

View file

@ -54,18 +54,21 @@ func setupHTTPProvider(ctx *cli.Context) challenge.Provider {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
return ps return ps
case ctx.IsSet(flgHTTPMemcachedHost): case ctx.IsSet(flgHTTPMemcachedHost):
ps, err := memcached.NewMemcachedProvider(ctx.StringSlice(flgHTTPMemcachedHost)) ps, err := memcached.NewMemcachedProvider(ctx.StringSlice(flgHTTPMemcachedHost))
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
return ps return ps
case ctx.IsSet(flgHTTPS3Bucket): case ctx.IsSet(flgHTTPS3Bucket):
ps, err := s3.NewHTTPProvider(ctx.String(flgHTTPS3Bucket)) ps, err := s3.NewHTTPProvider(ctx.String(flgHTTPS3Bucket))
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
return ps return ps
case ctx.IsSet(flgHTTPPort): case ctx.IsSet(flgHTTPPort):
iface := ctx.String(flgHTTPPort) iface := ctx.String(flgHTTPPort)
@ -82,12 +85,14 @@ func setupHTTPProvider(ctx *cli.Context) challenge.Provider {
if header := ctx.String(flgHTTPProxyHeader); header != "" { if header := ctx.String(flgHTTPProxyHeader); header != "" {
srv.SetProxyHeader(header) srv.SetProxyHeader(header)
} }
return srv return srv
case ctx.Bool(flgHTTP): case ctx.Bool(flgHTTP):
srv := http01.NewProviderServer("", "") srv := http01.NewProviderServer("", "")
if header := ctx.String(flgHTTPProxyHeader); header != "" { if header := ctx.String(flgHTTPProxyHeader); header != "" {
srv.SetProxyHeader(header) srv.SetProxyHeader(header)
} }
return srv return srv
default: default:
log.Fatal("Invalid HTTP challenge options.") log.Fatal("Invalid HTTP challenge options.")

View file

@ -205,6 +205,7 @@ func TestChallengeTLS_Run_Revoke_Non_ASCII(t *testing.T) {
func TestChallengeHTTP_Client_Obtain(t *testing.T) { func TestChallengeHTTP_Client_Obtain(t *testing.T) {
err := os.Setenv("LEGO_CA_CERTIFICATES", "./fixtures/certs/pebble.minica.pem") err := os.Setenv("LEGO_CA_CERTIFICATES", "./fixtures/certs/pebble.minica.pem")
require.NoError(t, err) require.NoError(t, err)
defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }() defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
@ -222,6 +223,7 @@ func TestChallengeHTTP_Client_Obtain(t *testing.T) {
reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}) reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
require.NoError(t, err) require.NoError(t, err)
user.registration = reg user.registration = reg
request := certificate.ObtainRequest{ request := certificate.ObtainRequest{
@ -243,6 +245,7 @@ func TestChallengeHTTP_Client_Obtain(t *testing.T) {
func TestChallengeHTTP_Client_Obtain_profile(t *testing.T) { func TestChallengeHTTP_Client_Obtain_profile(t *testing.T) {
err := os.Setenv("LEGO_CA_CERTIFICATES", "./fixtures/certs/pebble.minica.pem") err := os.Setenv("LEGO_CA_CERTIFICATES", "./fixtures/certs/pebble.minica.pem")
require.NoError(t, err) require.NoError(t, err)
defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }() defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
@ -260,6 +263,7 @@ func TestChallengeHTTP_Client_Obtain_profile(t *testing.T) {
reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}) reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
require.NoError(t, err) require.NoError(t, err)
user.registration = reg user.registration = reg
request := certificate.ObtainRequest{ request := certificate.ObtainRequest{
@ -282,6 +286,7 @@ func TestChallengeHTTP_Client_Obtain_profile(t *testing.T) {
func TestChallengeHTTP_Client_Obtain_emails_csr(t *testing.T) { func TestChallengeHTTP_Client_Obtain_emails_csr(t *testing.T) {
err := os.Setenv("LEGO_CA_CERTIFICATES", "./fixtures/certs/pebble.minica.pem") err := os.Setenv("LEGO_CA_CERTIFICATES", "./fixtures/certs/pebble.minica.pem")
require.NoError(t, err) require.NoError(t, err)
defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }() defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
@ -299,6 +304,7 @@ func TestChallengeHTTP_Client_Obtain_emails_csr(t *testing.T) {
reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}) reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
require.NoError(t, err) require.NoError(t, err)
user.registration = reg user.registration = reg
request := certificate.ObtainRequest{ request := certificate.ObtainRequest{
@ -321,6 +327,7 @@ func TestChallengeHTTP_Client_Obtain_emails_csr(t *testing.T) {
func TestChallengeHTTP_Client_Obtain_notBefore_notAfter(t *testing.T) { func TestChallengeHTTP_Client_Obtain_notBefore_notAfter(t *testing.T) {
err := os.Setenv("LEGO_CA_CERTIFICATES", "./fixtures/certs/pebble.minica.pem") err := os.Setenv("LEGO_CA_CERTIFICATES", "./fixtures/certs/pebble.minica.pem")
require.NoError(t, err) require.NoError(t, err)
defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }() defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
@ -338,6 +345,7 @@ func TestChallengeHTTP_Client_Obtain_notBefore_notAfter(t *testing.T) {
reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}) reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
require.NoError(t, err) require.NoError(t, err)
user.registration = reg user.registration = reg
now := time.Now().UTC() now := time.Now().UTC()
@ -368,6 +376,7 @@ func TestChallengeHTTP_Client_Obtain_notBefore_notAfter(t *testing.T) {
func TestChallengeHTTP_Client_Registration_QueryRegistration(t *testing.T) { func TestChallengeHTTP_Client_Registration_QueryRegistration(t *testing.T) {
err := os.Setenv("LEGO_CA_CERTIFICATES", "./fixtures/certs/pebble.minica.pem") err := os.Setenv("LEGO_CA_CERTIFICATES", "./fixtures/certs/pebble.minica.pem")
require.NoError(t, err) require.NoError(t, err)
defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }() defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
@ -385,6 +394,7 @@ func TestChallengeHTTP_Client_Registration_QueryRegistration(t *testing.T) {
reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}) reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
require.NoError(t, err) require.NoError(t, err)
user.registration = reg user.registration = reg
resource, err := client.Registration.QueryRegistration() resource, err := client.Registration.QueryRegistration()
@ -400,6 +410,7 @@ func TestChallengeHTTP_Client_Registration_QueryRegistration(t *testing.T) {
func TestChallengeTLS_Client_Obtain(t *testing.T) { func TestChallengeTLS_Client_Obtain(t *testing.T) {
err := os.Setenv("LEGO_CA_CERTIFICATES", "./fixtures/certs/pebble.minica.pem") err := os.Setenv("LEGO_CA_CERTIFICATES", "./fixtures/certs/pebble.minica.pem")
require.NoError(t, err) require.NoError(t, err)
defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }() defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
@ -417,6 +428,7 @@ func TestChallengeTLS_Client_Obtain(t *testing.T) {
reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}) reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
require.NoError(t, err) require.NoError(t, err)
user.registration = reg user.registration = reg
// https://github.com/letsencrypt/pebble/issues/285 // https://github.com/letsencrypt/pebble/issues/285
@ -443,6 +455,7 @@ func TestChallengeTLS_Client_Obtain(t *testing.T) {
func TestChallengeTLS_Client_ObtainForCSR(t *testing.T) { func TestChallengeTLS_Client_ObtainForCSR(t *testing.T) {
err := os.Setenv("LEGO_CA_CERTIFICATES", "./fixtures/certs/pebble.minica.pem") err := os.Setenv("LEGO_CA_CERTIFICATES", "./fixtures/certs/pebble.minica.pem")
require.NoError(t, err) require.NoError(t, err)
defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }() defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
@ -460,6 +473,7 @@ func TestChallengeTLS_Client_ObtainForCSR(t *testing.T) {
reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}) reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
require.NoError(t, err) require.NoError(t, err)
user.registration = reg user.registration = reg
csr, err := x509.ParseCertificateRequest(createTestCSR(t)) csr, err := x509.ParseCertificateRequest(createTestCSR(t))
@ -483,6 +497,7 @@ func TestChallengeTLS_Client_ObtainForCSR(t *testing.T) {
func TestChallengeTLS_Client_ObtainForCSR_profile(t *testing.T) { func TestChallengeTLS_Client_ObtainForCSR_profile(t *testing.T) {
err := os.Setenv("LEGO_CA_CERTIFICATES", "./fixtures/certs/pebble.minica.pem") err := os.Setenv("LEGO_CA_CERTIFICATES", "./fixtures/certs/pebble.minica.pem")
require.NoError(t, err) require.NoError(t, err)
defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }() defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
@ -500,6 +515,7 @@ func TestChallengeTLS_Client_ObtainForCSR_profile(t *testing.T) {
reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}) reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
require.NoError(t, err) require.NoError(t, err)
user.registration = reg user.registration = reg
csr, err := x509.ParseCertificateRequest(createTestCSR(t)) csr, err := x509.ParseCertificateRequest(createTestCSR(t))
@ -524,6 +540,7 @@ func TestChallengeTLS_Client_ObtainForCSR_profile(t *testing.T) {
func TestRegistrar_UpdateAccount(t *testing.T) { func TestRegistrar_UpdateAccount(t *testing.T) {
err := os.Setenv("LEGO_CA_CERTIFICATES", "./fixtures/certs/pebble.minica.pem") err := os.Setenv("LEGO_CA_CERTIFICATES", "./fixtures/certs/pebble.minica.pem")
require.NoError(t, err) require.NoError(t, err)
defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }() defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) privateKey, err := rsa.GenerateKey(rand.Reader, 2048)

View file

@ -75,10 +75,12 @@ func TestChallengeDNS_Run(t *testing.T) {
func TestChallengeDNS_Client_Obtain(t *testing.T) { func TestChallengeDNS_Client_Obtain(t *testing.T) {
err := os.Setenv("LEGO_CA_CERTIFICATES", "../fixtures/certs/pebble.minica.pem") err := os.Setenv("LEGO_CA_CERTIFICATES", "../fixtures/certs/pebble.minica.pem")
require.NoError(t, err) require.NoError(t, err)
defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }() defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }()
err = os.Setenv("EXEC_PATH", "../fixtures/update-dns.sh") err = os.Setenv("EXEC_PATH", "../fixtures/update-dns.sh")
require.NoError(t, err) require.NoError(t, err)
defer func() { _ = os.Unsetenv("EXEC_PATH") }() defer func() { _ = os.Unsetenv("EXEC_PATH") }()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
@ -101,6 +103,7 @@ func TestChallengeDNS_Client_Obtain(t *testing.T) {
reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}) reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
require.NoError(t, err) require.NoError(t, err)
user.registration = reg user.registration = reg
domains := []string{testDomain2, testDomain1} domains := []string{testDomain2, testDomain1}
@ -129,10 +132,12 @@ func TestChallengeDNS_Client_Obtain(t *testing.T) {
func TestChallengeDNS_Client_Obtain_profile(t *testing.T) { func TestChallengeDNS_Client_Obtain_profile(t *testing.T) {
err := os.Setenv("LEGO_CA_CERTIFICATES", "../fixtures/certs/pebble.minica.pem") err := os.Setenv("LEGO_CA_CERTIFICATES", "../fixtures/certs/pebble.minica.pem")
require.NoError(t, err) require.NoError(t, err)
defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }() defer func() { _ = os.Unsetenv("LEGO_CA_CERTIFICATES") }()
err = os.Setenv("EXEC_PATH", "../fixtures/update-dns.sh") err = os.Setenv("EXEC_PATH", "../fixtures/update-dns.sh")
require.NoError(t, err) require.NoError(t, err)
defer func() { _ = os.Unsetenv("EXEC_PATH") }() defer func() { _ = os.Unsetenv("EXEC_PATH") }()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
@ -155,6 +160,7 @@ func TestChallengeDNS_Client_Obtain_profile(t *testing.T) {
reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}) reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
require.NoError(t, err) require.NoError(t, err)
user.registration = reg user.registration = reg
domains := []string{testDomain2, testDomain1} domains := []string{testDomain2, testDomain1}

View file

@ -43,12 +43,14 @@ func (l *EnvLoader) MainTest(m *testing.M) int {
if _, e2e := os.LookupEnv("LEGO_E2E_TESTS"); !e2e { if _, e2e := os.LookupEnv("LEGO_E2E_TESTS"); !e2e {
fmt.Fprintln(os.Stderr, "skipping test: e2e tests are disabled. (no 'LEGO_E2E_TESTS' env var)") fmt.Fprintln(os.Stderr, "skipping test: e2e tests are disabled. (no 'LEGO_E2E_TESTS' env var)")
fmt.Println("PASS") fmt.Println("PASS")
return 0 return 0
} }
if _, err := exec.LookPath("git"); err != nil { if _, err := exec.LookPath("git"); err != nil {
fmt.Fprintln(os.Stderr, "skipping because git command not found") fmt.Fprintln(os.Stderr, "skipping because git command not found")
fmt.Println("PASS") fmt.Println("PASS")
return 0 return 0
} }
@ -56,6 +58,7 @@ func (l *EnvLoader) MainTest(m *testing.M) int {
if _, err := exec.LookPath(cmdNamePebble); err != nil { if _, err := exec.LookPath(cmdNamePebble); err != nil {
fmt.Fprintln(os.Stderr, "skipping because pebble binary not found") fmt.Fprintln(os.Stderr, "skipping because pebble binary not found")
fmt.Println("PASS") fmt.Println("PASS")
return 0 return 0
} }
} }
@ -64,6 +67,7 @@ func (l *EnvLoader) MainTest(m *testing.M) int {
if _, err := exec.LookPath(cmdNameChallSrv); err != nil { if _, err := exec.LookPath(cmdNameChallSrv); err != nil {
fmt.Fprintln(os.Stderr, "skipping because challtestsrv binary not found") fmt.Fprintln(os.Stderr, "skipping because challtestsrv binary not found")
fmt.Println("PASS") fmt.Println("PASS")
return 0 return 0
} }
} }
@ -76,6 +80,7 @@ func (l *EnvLoader) MainTest(m *testing.M) int {
legoBinary, tearDown, err := buildLego() legoBinary, tearDown, err := buildLego()
defer tearDown() defer tearDown()
if err != nil { if err != nil {
fmt.Fprintln(os.Stderr, err) fmt.Fprintln(os.Stderr, err)
return 1 return 1
@ -136,6 +141,7 @@ func (l *EnvLoader) launchPebble() func() {
} }
pebble, outPebble := l.cmdPebble() pebble, outPebble := l.cmdPebble()
go func() { go func() {
err := pebble.Run() err := pebble.Run()
if err != nil { if err != nil {
@ -148,6 +154,7 @@ func (l *EnvLoader) launchPebble() func() {
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} }
fmt.Println(outPebble.String()) fmt.Println(outPebble.String())
} }
} }
@ -160,11 +167,13 @@ func (l *EnvLoader) cmdPebble() (*exec.Cmd, *bytes.Buffer) {
if err != nil { if err != nil {
panic(err) panic(err)
} }
cmd.Dir = dir cmd.Dir = dir
fmt.Printf("$ %s\n", strings.Join(cmd.Args, " ")) fmt.Printf("$ %s\n", strings.Join(cmd.Args, " "))
var b bytes.Buffer var b bytes.Buffer
cmd.Stdout = &b cmd.Stdout = &b
cmd.Stderr = &b cmd.Stderr = &b
@ -173,6 +182,7 @@ func (l *EnvLoader) cmdPebble() (*exec.Cmd, *bytes.Buffer) {
func pebbleHealthCheck(options *CmdOption) { func pebbleHealthCheck(options *CmdOption) {
client := &http.Client{Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}} client := &http.Client{Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}}
err := wait.For("pebble", 10*time.Second, 500*time.Millisecond, func() (bool, error) { err := wait.For("pebble", 10*time.Second, 500*time.Millisecond, func() (bool, error) {
resp, err := client.Get(options.HealthCheckURL) resp, err := client.Get(options.HealthCheckURL)
if err != nil { if err != nil {
@ -196,6 +206,7 @@ func (l *EnvLoader) launchChallSrv() func() {
} }
challtestsrv, outChalSrv := l.cmdChallSrv() challtestsrv, outChalSrv := l.cmdChallSrv()
go func() { go func() {
err := challtestsrv.Run() err := challtestsrv.Run()
if err != nil { if err != nil {
@ -208,6 +219,7 @@ func (l *EnvLoader) launchChallSrv() func() {
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} }
fmt.Println(outChalSrv.String()) fmt.Println(outChalSrv.String())
} }
} }
@ -218,6 +230,7 @@ func (l *EnvLoader) cmdChallSrv() (*exec.Cmd, *bytes.Buffer) {
fmt.Printf("$ %s\n", strings.Join(cmd.Args, " ")) fmt.Printf("$ %s\n", strings.Join(cmd.Args, " "))
var b bytes.Buffer var b bytes.Buffer
cmd.Stdout = &b cmd.Stdout = &b
cmd.Stderr = &b cmd.Stderr = &b
@ -229,6 +242,7 @@ func buildLego() (string, func(), error) {
if err != nil { if err != nil {
return "", func() {}, err return "", func() {}, err
} }
defer func() { _ = os.Chdir(here) }() defer func() { _ = os.Chdir(here) }()
buildPath, err := os.MkdirTemp("", "lego_test") buildPath, err := os.MkdirTemp("", "lego_test")
@ -262,6 +276,7 @@ func buildLego() (string, func(), error) {
return binary, func() { return binary, func() {
_ = os.RemoveAll(buildPath) _ = os.RemoveAll(buildPath)
CleanLegoFiles() CleanLegoFiles()
}, nil }, nil
} }
@ -283,6 +298,7 @@ func build(binary string) error {
if err != nil { if err != nil {
return err return err
} }
cmd := exec.Command(toolPath, "build", "-o", binary) cmd := exec.Command(toolPath, "build", "-o", binary)
output, err := cmd.CombinedOutput() output, err := cmd.CombinedOutput()
@ -334,6 +350,7 @@ func goTool() (string, error) {
func CleanLegoFiles() { func CleanLegoFiles() {
cmd := exec.Command("rm", "-rf", ".lego") cmd := exec.Command("rm", "-rf", ".lego")
fmt.Printf("$ %s\n", strings.Join(cmd.Args, " ")) fmt.Printf("$ %s\n", strings.Join(cmd.Args, " "))
output, err := cmd.CombinedOutput() output, err := cmd.CombinedOutput()
if err != nil { if err != nil {
fmt.Println(string(output)) fmt.Println(string(output))

View file

@ -50,6 +50,7 @@ func generate() error {
// collect output of various help pages // collect output of various help pages
var help []commandHelp var help []commandHelp
for _, args := range [][]string{ for _, args := range [][]string{
{"lego", "help"}, {"lego", "help"},
{"lego", "help", "run"}, {"lego", "help", "run"},
@ -72,7 +73,9 @@ func generate() error {
} }
err = outputTpl.Execute(f, help) err = outputTpl.Execute(f, help)
defer func() { _ = f.Close() }() defer func() { _ = f.Close() }()
if err != nil { if err != nil {
return fmt.Errorf("failed to write cli_help.toml: %w", err) return fmt.Errorf("failed to write cli_help.toml: %w", err)
} }
@ -98,9 +101,11 @@ func createStubApp() *cli.App {
func run(app *cli.App, args []string) (h commandHelp, err error) { func run(app *cli.App, args []string) (h commandHelp, err error) {
w := app.Writer w := app.Writer
defer func() { app.Writer = w }() defer func() { app.Writer = w }()
var buf bytes.Buffer var buf bytes.Buffer
app.Writer = &buf app.Writer = &buf
if err := app.Run(args); err != nil { if err := app.Run(args); err != nil {

View file

@ -116,6 +116,7 @@ func generateCLIHelp(models *descriptors.Providers) error {
defer func() { _ = file.Close() }() defer func() { _ = file.Close() }()
b := &bytes.Buffer{} b := &bytes.Buffer{}
err = template.Must( err = template.Must(
template.New(filepath.Base(cliTemplate)).Funcs(map[string]any{ template.New(filepath.Base(cliTemplate)).Funcs(map[string]any{
"safe": func(src string) string { "safe": func(src string) string {
@ -134,6 +135,7 @@ func generateCLIHelp(models *descriptors.Providers) error {
} }
_, err = file.Write(source) _, err = file.Write(source)
return err return err
} }
@ -161,6 +163,7 @@ func generateReadMe(models *descriptors.Providers) error {
if err = tpl.Execute(buffer, providers); err != nil { if err = tpl.Execute(buffer, providers); err != nil {
return err return err
} }
skip = true skip = true
} }
@ -198,8 +201,10 @@ func orderProviders(models *descriptors.Providers) [][]descriptors.Provider {
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name)) return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
}) })
var matrix [][]descriptors.Provider var (
var row []descriptors.Provider matrix [][]descriptors.Provider
row []descriptors.Provider
)
for i, p := range providers { for i, p := range providers {
switch { switch {
@ -212,6 +217,7 @@ func orderProviders(models *descriptors.Providers) [][]descriptors.Provider {
for j := len(row); j < nbCol; j++ { for j := len(row); j < nbCol; j++ {
row = append(row, descriptors.Provider{}) row = append(row, descriptors.Provider{})
} }
matrix = append(matrix, row) matrix = append(matrix, row)
default: default:
@ -223,6 +229,7 @@ func orderProviders(models *descriptors.Providers) [][]descriptors.Provider {
for j := len(row); j < nbCol; j++ { for j := len(row); j < nbCol; j++ {
row = append(row, descriptors.Provider{}) row = append(row, descriptors.Provider{})
} }
matrix = append(matrix, row) matrix = append(matrix, row)
} }

View file

@ -46,6 +46,7 @@ func generate() error {
defer func() { _ = file.Close() }() defer func() { _ = file.Close() }()
b := &bytes.Buffer{} b := &bytes.Buffer{}
err = template.Must( err = template.Must(
template.New("").Funcs(map[string]any{ template.New("").Funcs(map[string]any{
"cleanName": func(src string) string { "cleanName": func(src string) string {

View file

@ -108,6 +108,7 @@ func detach(_ *cli.Context) error {
func readCurrentVersion(filename string) (*hcversion.Version, error) { func readCurrentVersion(filename string) (*hcversion.Version, error) {
fset := token.NewFileSet() fset := token.NewFileSet()
file, err := parser.ParseFile(fset, filename, nil, parser.AllErrors) file, err := parser.ParseFile(fset, filename, nil, parser.AllErrors)
if err != nil { if err != nil {
return nil, err return nil, err
@ -141,6 +142,7 @@ func (v visitor) Visit(n ast.Node) ast.Visitor {
if !ok { if !ok {
continue continue
} }
if len(valueSpec.Names) != 1 || len(valueSpec.Values) != 1 { if len(valueSpec.Names) != 1 || len(valueSpec.Values) != 1 {
continue continue
} }
@ -149,6 +151,7 @@ func (v visitor) Visit(n ast.Node) ast.Visitor {
if !ok { if !ok {
continue continue
} }
if va.Kind != token.STRING { if va.Kind != token.STRING {
continue continue
} }
@ -164,6 +167,7 @@ func (v visitor) Visit(n ast.Node) ast.Visitor {
default: default:
// noop // noop
} }
return v return v
} }

View file

@ -16,11 +16,13 @@ func Get(names ...string) (map[string]string, error) {
values := map[string]string{} values := map[string]string{}
var missingEnvVars []string var missingEnvVars []string
for _, envVar := range names { for _, envVar := range names {
value := GetOrFile(envVar) value := GetOrFile(envVar)
if value == "" { if value == "" {
missingEnvVars = append(missingEnvVars, envVar) missingEnvVars = append(missingEnvVars, envVar)
} }
values[envVar] = value values[envVar] = value
} }
@ -58,6 +60,7 @@ func GetWithFallback(groups ...[]string) (map[string]string, error) {
values := map[string]string{} values := map[string]string{}
var missingEnvVars []string var missingEnvVars []string
for _, names := range groups { for _, names := range groups {
if len(names) == 0 { if len(names) == 0 {
return nil, errors.New("undefined environment variable names") return nil, errors.New("undefined environment variable names")
@ -68,6 +71,7 @@ func GetWithFallback(groups ...[]string) (map[string]string, error) {
missingEnvVars = append(missingEnvVars, envVar) missingEnvVars = append(missingEnvVars, envVar)
continue continue
} }
values[envVar] = value values[envVar] = value
} }
@ -148,6 +152,7 @@ func GetOrFile(envVar string) string {
} }
fileVar := envVar + "_FILE" fileVar := envVar + "_FILE"
fileVarValue := os.Getenv(fileVar) fileVarValue := os.Getenv(fileVar)
if fileVarValue == "" { if fileVarValue == "" {
return envVarValue return envVarValue

View file

@ -42,6 +42,7 @@ func WriteJSONResponse(w http.ResponseWriter, body any) error {
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
if _, err := w.Write(bs); err != nil { if _, err := w.Write(bs); err != nil {
return err return err
} }

View file

@ -21,6 +21,7 @@ type EnvTest struct {
// NewEnvTest Creates an EnvTest. // NewEnvTest Creates an EnvTest.
func NewEnvTest(keys ...string) *EnvTest { func NewEnvTest(keys ...string) *EnvTest {
values := make(map[string]string) values := make(map[string]string)
for _, key := range keys { for _, key := range keys {
value := os.Getenv(key) value := os.Getenv(key)
if value != "" { if value != "" {
@ -39,6 +40,7 @@ func NewEnvTest(keys ...string) *EnvTest {
func (e *EnvTest) WithDomain(key string) *EnvTest { func (e *EnvTest) WithDomain(key string) *EnvTest {
e.domainKey = key e.domainKey = key
e.domain = os.Getenv(key) e.domain = os.Getenv(key)
return e return e
} }

View file

@ -18,6 +18,7 @@ const (
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
exitCode := m.Run() exitCode := m.Run()
clearEnv() clearEnv()
os.Exit(exitCode) os.Exit(exitCode)
} }
@ -39,6 +40,7 @@ func clearEnv() {
os.Unsetenv(strings.Split(key, "=")[0]) os.Unsetenv(strings.Split(key, "=")[0])
} }
} }
os.Unsetenv("EXTRA_LEGO_TEST") os.Unsetenv("EXTRA_LEGO_TEST")
} }
@ -325,6 +327,7 @@ func TestEnvTest(t *testing.T) {
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
defer clearEnv() defer clearEnv()
applyEnv(test.envVars) applyEnv(test.envVars)
envTest := test.envTestSetup() envTest := test.envTestSetup()

View file

@ -43,6 +43,7 @@ func (l *FormLink) Bind(next http.Handler) http.Handler {
if len(form) != len(l.values)+len(l.regexes) { if len(form) != len(l.values)+len(l.regexes) {
msg := fmt.Sprintf("invalid query parameters, got %v, want %v", req.Form, l.values) msg := fmt.Sprintf("invalid query parameters, got %v, want %v", req.Form, l.values)
http.Error(rw, msg, l.statusCode) http.Error(rw, msg, l.statusCode)
return return
} }
} }
@ -52,6 +53,7 @@ func (l *FormLink) Bind(next http.Handler) http.Handler {
if !slices.Equal(v, value) { if !slices.Equal(v, value) {
msg := fmt.Sprintf("invalid %q form value, got %q, want %q", k, value, v) msg := fmt.Sprintf("invalid %q form value, got %q, want %q", k, value, v)
http.Error(rw, msg, l.statusCode) http.Error(rw, msg, l.statusCode)
return return
} }
} }
@ -61,6 +63,7 @@ func (l *FormLink) Bind(next http.Handler) http.Handler {
if !exp.MatchString(value) { if !exp.MatchString(value) {
msg := fmt.Sprintf("invalid %q form value, %q doesn't match to %q", k, value, exp) msg := fmt.Sprintf("invalid %q form value, %q doesn't match to %q", k, value, exp)
http.Error(rw, msg, l.statusCode) http.Error(rw, msg, l.statusCode)
return return
} }
} }

View file

@ -55,6 +55,7 @@ func (l *HeaderLink) Bind(next http.Handler) http.Handler {
if !exp.MatchString(value) { if !exp.MatchString(value) {
msg := fmt.Sprintf("invalid %q header value, %q doesn't match to %q", k, value, exp) msg := fmt.Sprintf("invalid %q header value, %q doesn't match to %q", k, value, exp)
http.Error(rw, msg, l.statusCode) http.Error(rw, msg, l.statusCode)
return return
} }
} }

View file

@ -32,6 +32,7 @@ func (l *QueryParameterLink) Bind(next http.Handler) http.Handler {
if len(query) != len(l.values)+len(l.regexes) { if len(query) != len(l.values)+len(l.regexes) {
msg := fmt.Sprintf("invalid query parameters, got %v, want %v", query, l.values) msg := fmt.Sprintf("invalid query parameters, got %v, want %v", query, l.values)
http.Error(rw, msg, l.statusCode) http.Error(rw, msg, l.statusCode)
return return
} }
} }
@ -41,6 +42,7 @@ func (l *QueryParameterLink) Bind(next http.Handler) http.Handler {
if p != v { if p != v {
msg := fmt.Sprintf("invalid %q query parameter value, got %q, want %q", k, p, v) msg := fmt.Sprintf("invalid %q query parameter value, got %q, want %q", k, p, v)
http.Error(rw, msg, l.statusCode) http.Error(rw, msg, l.statusCode)
return return
} }
} }
@ -50,6 +52,7 @@ func (l *QueryParameterLink) Bind(next http.Handler) http.Handler {
if !exp.MatchString(value) { if !exp.MatchString(value) {
msg := fmt.Sprintf("invalid %q query parameter value, %q doesn't match to %q", k, value, exp) msg := fmt.Sprintf("invalid %q query parameter value, %q doesn't match to %q", k, value, exp)
http.Error(rw, msg, l.statusCode) http.Error(rw, msg, l.statusCode)
return return
} }
} }

View file

@ -76,6 +76,7 @@ func (l *RequestBodyLink) Bind(next http.Handler) http.Handler {
msg := fmt.Sprintf("%s: request body differences: got: %s, want: %s", req.URL.Path, msg := fmt.Sprintf("%s: request body differences: got: %s, want: %s", req.URL.Path,
string(bytes.TrimSpace(body)), string(bytes.TrimSpace(expectedRaw))) string(bytes.TrimSpace(body)), string(bytes.TrimSpace(expectedRaw)))
http.Error(rw, msg, http.StatusBadRequest) http.Error(rw, msg, http.StatusBadRequest)
return return
} }

View file

@ -90,6 +90,7 @@ func (l *RequestBodyJSONLink) Bind(next http.Handler) http.Handler {
if err != nil { if err != nil {
msg := fmt.Sprintf("%s: the expected request body is not valid JSON: %v", req.URL.Path, err) msg := fmt.Sprintf("%s: the expected request body is not valid JSON: %v", req.URL.Path, err)
http.Error(rw, msg, http.StatusBadRequest) http.Error(rw, msg, http.StatusBadRequest)
return return
} }
@ -97,12 +98,14 @@ func (l *RequestBodyJSONLink) Bind(next http.Handler) http.Handler {
if err != nil { if err != nil {
msg := fmt.Sprintf("%s: request body is not valid JSON: %v", req.URL.Path, err) msg := fmt.Sprintf("%s: request body is not valid JSON: %v", req.URL.Path, err)
http.Error(rw, msg, http.StatusBadRequest) http.Error(rw, msg, http.StatusBadRequest)
return return
} }
if !cmp.Equal(actual, expected) { if !cmp.Equal(actual, expected) {
msg := fmt.Sprintf("%s: request body differences: %s", req.URL.Path, cmp.Diff(actual, expected)) msg := fmt.Sprintf("%s: request body differences: %s", req.URL.Path, cmp.Diff(actual, expected))
http.Error(rw, msg, http.StatusBadRequest) http.Error(rw, msg, http.StatusBadRequest)
return return
} }

View file

@ -14,13 +14,16 @@ func For(msg string, timeout, interval time.Duration, f func() (bool, error)) er
log.Infof("Wait for %s [timeout: %s, interval: %s]", msg, timeout, interval) log.Infof("Wait for %s [timeout: %s, interval: %s]", msg, timeout, interval)
var lastErr error var lastErr error
timeUp := time.After(timeout) timeUp := time.After(timeout)
for { for {
select { select {
case <-timeUp: case <-timeUp:
if lastErr == nil { if lastErr == nil {
return fmt.Errorf("%s: time limit exceeded", msg) return fmt.Errorf("%s: time limit exceeded", msg)
} }
return fmt.Errorf("%s: time limit exceeded: last error: %w", msg, lastErr) return fmt.Errorf("%s: time limit exceeded: last error: %w", msg, lastErr)
default: default:
} }
@ -44,5 +47,6 @@ func Retry(ctx context.Context, operation func() error, opts ...backoff.RetryOpt
_, err := backoff.Retry(ctx, func() (any, error) { _, err := backoff.Retry(ctx, func() (any, error) {
return nil, operation() return nil, operation()
}, opts...) }, opts...)
return err return err
} }

View file

@ -19,6 +19,7 @@ func TestFor_timeout(t *testing.T) {
go func() { go func() {
c <- For("test", 3*time.Second, 1*time.Second, func() (bool, error) { c <- For("test", 3*time.Second, 1*time.Second, func() (bool, error) {
io.Add(1) io.Add(1)
if io.Load() == 1 { if io.Load() == 1 {
return false, nil return false, nil
} }

View file

@ -114,6 +114,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
} }
// NewDNSProviderClient creates an ACME-DNS DNSProvider with the given acmeDNSClient and [goacmedns.Storage]. // NewDNSProviderClient creates an ACME-DNS DNSProvider with the given acmeDNSClient and [goacmedns.Storage].
//
// Deprecated: use [NewDNSProviderConfig] instead. // Deprecated: use [NewDNSProviderConfig] instead.
func NewDNSProviderClient(client acmeDNSClient, store goacmedns.Storage) (*DNSProvider, error) { func NewDNSProviderClient(client acmeDNSClient, store goacmedns.Storage) (*DNSProvider, error) {
if client == nil { if client == nil {

View file

@ -107,6 +107,7 @@ func newMockStorage() *mockStorage {
if acct, ok := m.accounts[domain]; ok { if acct, ok := m.accounts[domain]; ok {
return acct, nil return acct, nil
} }
return goacmedns.Account{}, storage.ErrDomainNotFound return goacmedns.Account{}, storage.ErrDomainNotFound
} }

View file

@ -50,6 +50,7 @@ func TestNewDNSProvider(t *testing.T) {
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
defer envTest.RestoreEnv() defer envTest.RestoreEnv()
envTest.ClearEnv() envTest.ClearEnv()
envTest.Apply(test.envVars) envTest.Apply(test.envVars)
@ -124,6 +125,7 @@ func TestLivePresent(t *testing.T) {
} }
envTest.RestoreEnv() envTest.RestoreEnv()
provider, err := NewDNSProvider() provider, err := NewDNSProvider()
require.NoError(t, err) require.NoError(t, err)
@ -137,6 +139,7 @@ func TestLiveCleanUp(t *testing.T) {
} }
envTest.RestoreEnv() envTest.RestoreEnv()
provider, err := NewDNSProvider() provider, err := NewDNSProvider()
require.NoError(t, err) require.NoError(t, err)

View file

@ -170,6 +170,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
if err != nil { if err != nil {
return fmt.Errorf("alicloud: API call failed: %w", err) return fmt.Errorf("alicloud: API call failed: %w", err)
} }
return nil return nil
} }
@ -233,6 +234,7 @@ func (d *DNSProvider) getHostedZone(ctx context.Context, domain string) (string,
} }
var hostedZone *alidns.DescribeDomainsResponseBodyDomainsDomain var hostedZone *alidns.DescribeDomainsResponseBodyDomainsDomain
for _, zone := range domains { for _, zone := range domains {
if ptr.Deref(zone.DomainName) == dns01.UnFqdn(authZone) || ptr.Deref(zone.PunyCode) == dns01.UnFqdn(authZone) { if ptr.Deref(zone.DomainName) == dns01.UnFqdn(authZone) || ptr.Deref(zone.PunyCode) == dns01.UnFqdn(authZone) {
hostedZone = zone hostedZone = zone
@ -287,6 +289,7 @@ func (d *DNSProvider) findTxtRecords(ctx context.Context, fqdn string) ([]*alidn
records = append(records, record) records = append(records, record)
} }
} }
return records, nil return records, nil
} }

View file

@ -64,6 +64,7 @@ func TestNewDNSProvider(t *testing.T) {
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
defer envTest.RestoreEnv() defer envTest.RestoreEnv()
envTest.ClearEnv() envTest.ClearEnv()
envTest.Apply(test.envVars) envTest.Apply(test.envVars)
@ -142,6 +143,7 @@ func TestLivePresent(t *testing.T) {
} }
envTest.RestoreEnv() envTest.RestoreEnv()
provider, err := NewDNSProvider() provider, err := NewDNSProvider()
require.NoError(t, err) require.NoError(t, err)
@ -155,6 +157,7 @@ func TestLiveCleanUp(t *testing.T) {
} }
envTest.RestoreEnv() envTest.RestoreEnv()
provider, err := NewDNSProvider() provider, err := NewDNSProvider()
require.NoError(t, err) require.NoError(t, err)

View file

@ -176,6 +176,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
d.recordIDsMu.Lock() d.recordIDsMu.Lock()
recordID, ok := d.recordIDs[token] recordID, ok := d.recordIDs[token]
d.recordIDsMu.Unlock() d.recordIDsMu.Unlock()
if !ok { if !ok {
return fmt.Errorf("allinkl: unknown record ID for '%s' '%s'", info.EffectiveFQDN, token) return fmt.Errorf("allinkl: unknown record ID for '%s' '%s'", info.EffectiveFQDN, token)
} }

View file

@ -53,6 +53,7 @@ func TestNewDNSProvider(t *testing.T) {
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
defer envTest.RestoreEnv() defer envTest.RestoreEnv()
envTest.ClearEnv() envTest.ClearEnv()
envTest.Apply(test.envVars) envTest.Apply(test.envVars)
@ -121,6 +122,7 @@ func TestLivePresent(t *testing.T) {
} }
envTest.RestoreEnv() envTest.RestoreEnv()
provider, err := NewDNSProvider() provider, err := NewDNSProvider()
require.NoError(t, err) require.NoError(t, err)
@ -134,6 +136,7 @@ func TestLiveCleanUp(t *testing.T) {
} }
envTest.RestoreEnv() envTest.RestoreEnv()
provider, err := NewDNSProvider() provider, err := NewDNSProvider()
require.NoError(t, err) require.NoError(t, err)

View file

@ -57,6 +57,7 @@ func (c *Client) GetDNSSettings(ctx context.Context, zone, recordID string) ([]R
} }
var g GetDNSSettingsAPIResponse var g GetDNSSettingsAPIResponse
err = c.do(req, &g) err = c.do(req, &g)
if err != nil { if err != nil {
return nil, err return nil, err
@ -75,6 +76,7 @@ func (c *Client) AddDNSSettings(ctx context.Context, record DNSRequest) (string,
} }
var g AddDNSSettingsAPIResponse var g AddDNSSettingsAPIResponse
err = c.do(req, &g) err = c.do(req, &g)
if err != nil { if err != nil {
return "", err return "", err
@ -95,6 +97,7 @@ func (c *Client) DeleteDNSSettings(ctx context.Context, recordID string) (string
} }
var g DeleteDNSSettingsAPIResponse var g DeleteDNSSettingsAPIResponse
err = c.do(req, &g) err = c.do(req, &g)
if err != nil { if err != nil {
return "", err return "", err

View file

@ -17,6 +17,7 @@ func (tr Trimmer) Token() (xml.Token, error) {
if cd, ok := t.(xml.CharData); ok { if cd, ok := t.(xml.CharData); ok {
t = xml.CharData(bytes.TrimSpace(cd)) t = xml.CharData(bytes.TrimSpace(cd))
} }
return t, err return t, err
} }
@ -53,6 +54,7 @@ func decodeXML[T any](reader io.Reader) (*T, error) {
} }
var result T var result T
err = xml.NewTokenDecoder(Trimmer{decoder: xml.NewDecoder(bytes.NewReader(raw))}).Decode(&result) err = xml.NewTokenDecoder(Trimmer{decoder: xml.NewDecoder(bytes.NewReader(raw))}).Decode(&result)
if err != nil { if err != nil {
return nil, fmt.Errorf("decode XML response: %w", err) return nil, fmt.Errorf("decode XML response: %w", err)

View file

@ -96,6 +96,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
if config.APIURL != "" { if config.APIURL != "" {
var err error var err error
client.BaseURL, err = url.Parse(config.APIURL) client.BaseURL, err = url.Parse(config.APIURL)
if err != nil { if err != nil {
return nil, fmt.Errorf("anexia: %w", err) return nil, fmt.Errorf("anexia: %w", err)

View file

@ -42,6 +42,7 @@ func TestNewDNSProvider(t *testing.T) {
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
defer envTest.RestoreEnv() defer envTest.RestoreEnv()
envTest.ClearEnv() envTest.ClearEnv()
envTest.Apply(test.envVars) envTest.Apply(test.envVars)
@ -102,6 +103,7 @@ func TestLivePresent(t *testing.T) {
} }
envTest.RestoreEnv() envTest.RestoreEnv()
provider, err := NewDNSProvider() provider, err := NewDNSProvider()
require.NoError(t, err) require.NoError(t, err)
@ -115,6 +117,7 @@ func TestLiveCleanUp(t *testing.T) {
} }
envTest.RestoreEnv() envTest.RestoreEnv()
provider, err := NewDNSProvider() provider, err := NewDNSProvider()
require.NoError(t, err) require.NoError(t, err)

View file

@ -49,6 +49,7 @@ func (c *Client) CreateRecord(ctx context.Context, zoneName string, record Recor
} }
var zone Zone var zone Zone
err = c.do(req, &zone) err = c.do(req, &zone)
if err != nil { if err != nil {
return nil, err return nil, err
@ -147,6 +148,7 @@ func parseError(req *http.Request, resp *http.Response) error {
raw, _ := io.ReadAll(resp.Body) raw, _ := io.ReadAll(resp.Body)
var errAPI APIError var errAPI APIError
err := json.Unmarshal(raw, &errAPI) err := json.Unmarshal(raw, &errAPI)
if err != nil { if err != nil {
return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw) return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw)

View file

@ -167,6 +167,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
d.recordIDsMu.Lock() d.recordIDsMu.Lock()
recordID, ok := d.recordIDs[token] recordID, ok := d.recordIDs[token]
d.recordIDsMu.Unlock() d.recordIDsMu.Unlock()
if !ok { if !ok {
return fmt.Errorf("arvancloud: unknown record ID for '%s' '%s'", info.EffectiveFQDN, token) return fmt.Errorf("arvancloud: unknown record ID for '%s' '%s'", info.EffectiveFQDN, token)
} }

View file

@ -37,6 +37,7 @@ func TestNewDNSProvider(t *testing.T) {
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
defer envTest.RestoreEnv() defer envTest.RestoreEnv()
envTest.ClearEnv() envTest.ClearEnv()
envTest.Apply(test.envVars) envTest.Apply(test.envVars)
@ -104,6 +105,7 @@ func TestLivePresent(t *testing.T) {
} }
envTest.RestoreEnv() envTest.RestoreEnv()
provider, err := NewDNSProvider() provider, err := NewDNSProvider()
require.NoError(t, err) require.NoError(t, err)
@ -117,6 +119,7 @@ func TestLiveCleanUp(t *testing.T) {
} }
envTest.RestoreEnv() envTest.RestoreEnv()
provider, err := NewDNSProvider() provider, err := NewDNSProvider()
require.NoError(t, err) require.NoError(t, err)

View file

@ -70,6 +70,7 @@ func (c *Client) getRecords(ctx context.Context, domain, search string) ([]DNSRe
} }
response := &apiResponse[[]DNSRecord]{} response := &apiResponse[[]DNSRecord]{}
err = c.do(req, http.StatusOK, response) err = c.do(req, http.StatusOK, response)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not get records %s: Domain: %s: %w", search, domain, err) return nil, fmt.Errorf("could not get records %s: Domain: %s: %w", search, domain, err)
@ -89,6 +90,7 @@ func (c *Client) CreateRecord(ctx context.Context, domain string, record DNSReco
} }
response := &apiResponse[*DNSRecord]{} response := &apiResponse[*DNSRecord]{}
err = c.do(req, http.StatusCreated, response) err = c.do(req, http.StatusCreated, response)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not create record; Domain: %s: %w", domain, err) return nil, fmt.Errorf("could not create record; Domain: %s: %w", domain, err)

View file

@ -81,8 +81,10 @@ func TestClient_CreateRecord(t *testing.T) {
func TestClient_DeleteRecord(t *testing.T) { func TestClient_DeleteRecord(t *testing.T) {
const apiKey = "myKeyC" const apiKey = "myKeyC"
const domain = "example.com" const (
const recordID = "recordId" domain = "example.com"
recordID = "recordId"
)
client := mockBuilder(apiKey). client := mockBuilder(apiKey).
Route("DELETE /cdn/4.0/domains/"+domain+"/dns-records/"+recordID, nil). Route("DELETE /cdn/4.0/domains/"+domain+"/dns-records/"+recordID, nil).

View file

@ -71,6 +71,7 @@ func TestNewDNSProvider(t *testing.T) {
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
defer envTest.RestoreEnv() defer envTest.RestoreEnv()
envTest.ClearEnv() envTest.ClearEnv()
envTest.Apply(test.envVars) envTest.Apply(test.envVars)

View file

@ -57,6 +57,7 @@ func TestNewDNSProvider(t *testing.T) {
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
defer envTest.RestoreEnv() defer envTest.RestoreEnv()
envTest.ClearEnv() envTest.ClearEnv()
envTest.Apply(test.envVars) envTest.Apply(test.envVars)
@ -131,6 +132,7 @@ func TestLivePresent(t *testing.T) {
} }
envTest.RestoreEnv() envTest.RestoreEnv()
provider, err := NewDNSProvider() provider, err := NewDNSProvider()
require.NoError(t, err) require.NoError(t, err)
@ -144,6 +146,7 @@ func TestLiveCleanUp(t *testing.T) {
} }
envTest.RestoreEnv() envTest.RestoreEnv()
provider, err := NewDNSProvider() provider, err := NewDNSProvider()
require.NoError(t, err) require.NoError(t, err)

View file

@ -55,6 +55,7 @@ func (c *Client) RemoveTXTRecords(ctx context.Context, domain string, records []
zoneStream := &ZoneStream{Removes: records} zoneStream := &ZoneStream{Removes: records}
_, err := c.updateZone(ctx, domain, zoneStream) _, err := c.updateZone(ctx, domain, zoneStream)
return err return err
} }

View file

@ -50,6 +50,7 @@ func TestNewDNSProvider(t *testing.T) {
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
defer envTest.RestoreEnv() defer envTest.RestoreEnv()
envTest.ClearEnv() envTest.ClearEnv()
envTest.Apply(test.envVars) envTest.Apply(test.envVars)
@ -120,6 +121,7 @@ func TestLivePresent(t *testing.T) {
} }
envTest.RestoreEnv() envTest.RestoreEnv()
provider, err := NewDNSProvider() provider, err := NewDNSProvider()
require.NoError(t, err) require.NoError(t, err)
@ -133,6 +135,7 @@ func TestLiveCleanUp(t *testing.T) {
} }
envTest.RestoreEnv() envTest.RestoreEnv()
provider, err := NewDNSProvider() provider, err := NewDNSProvider()
require.NoError(t, err) require.NoError(t, err)

View file

@ -174,6 +174,7 @@ func parseError(req *http.Request, resp *http.Response) error {
raw, _ := io.ReadAll(resp.Body) raw, _ := io.ReadAll(resp.Body)
var errAPI APIError var errAPI APIError
err := json.Unmarshal(raw, &errAPI) err := json.Unmarshal(raw, &errAPI)
if err != nil { if err != nil {
return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw) return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw)

View file

@ -137,6 +137,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
record.SetTtl(int32(d.config.TTL)) record.SetTtl(int32(d.config.TTL))
var resp *idns.PostOrPutRecordResponse var resp *idns.PostOrPutRecordResponse
if existingRecord != nil { if existingRecord != nil {
// Update existing record by adding the new value to the existing ones // Update existing record by adding the new value to the existing ones
record.SetAnswersList(append(existingRecord.GetAnswersList(), info.Value)) record.SetAnswersList(append(existingRecord.GetAnswersList(), info.Value))
@ -161,6 +162,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
} }
results := resp.GetResults() results := resp.GetResults()
d.recordIDsMu.Lock() d.recordIDsMu.Lock()
d.recordIDs[token] = results.GetId() d.recordIDs[token] = results.GetId()
d.recordIDsMu.Unlock() d.recordIDsMu.Unlock()
@ -203,6 +205,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
currentAnswers := existingRecord.GetAnswersList() currentAnswers := existingRecord.GetAnswersList()
var updatedAnswers []string var updatedAnswers []string
for _, answer := range currentAnswers { for _, answer := range currentAnswers {
if answer != info.Value { if answer != info.Value {
updatedAnswers = append(updatedAnswers, answer) updatedAnswers = append(updatedAnswers, answer)

View file

@ -40,6 +40,7 @@ func TestNewDNSProvider(t *testing.T) {
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
defer envTest.RestoreEnv() defer envTest.RestoreEnv()
envTest.ClearEnv() envTest.ClearEnv()
envTest.Apply(test.envVars) envTest.Apply(test.envVars)
@ -99,6 +100,7 @@ func TestLivePresent(t *testing.T) {
} }
envTest.RestoreEnv() envTest.RestoreEnv()
provider, err := NewDNSProvider() provider, err := NewDNSProvider()
require.NoError(t, err) require.NoError(t, err)
@ -112,6 +114,7 @@ func TestLiveCleanUp(t *testing.T) {
} }
envTest.RestoreEnv() envTest.RestoreEnv()
provider, err := NewDNSProvider() provider, err := NewDNSProvider()
require.NoError(t, err) require.NoError(t, err)

View file

@ -89,6 +89,7 @@ type DNSProvider struct {
// If the credentials are _not_ set via the environment, // If the credentials are _not_ set via the environment,
// then it will attempt to get a bearer token via the instance metadata service. // then it will attempt to get a bearer token via the instance metadata service.
// see: https://github.com/Azure/go-autorest/blob/v10.14.0/autorest/azure/auth/auth.go#L38-L42 // see: https://github.com/Azure/go-autorest/blob/v10.14.0/autorest/azure/auth/auth.go#L38-L42
//
// Deprecated: use azuredns instead. // Deprecated: use azuredns instead.
func NewDNSProvider() (*DNSProvider, error) { func NewDNSProvider() (*DNSProvider, error) {
config := NewDefaultConfig() config := NewDefaultConfig()
@ -96,6 +97,7 @@ func NewDNSProvider() (*DNSProvider, error) {
environmentName := env.GetOrFile(EnvEnvironment) environmentName := env.GetOrFile(EnvEnvironment)
if environmentName != "" { if environmentName != "" {
var environment aazure.Environment var environment aazure.Environment
switch environmentName { switch environmentName {
case "china": case "china":
environment = aazure.ChinaCloud environment = aazure.ChinaCloud
@ -124,6 +126,7 @@ func NewDNSProvider() (*DNSProvider, error) {
} }
// NewDNSProviderConfig return a DNSProvider instance configured for Azure. // NewDNSProviderConfig return a DNSProvider instance configured for Azure.
//
// Deprecated: use azuredns instead. // Deprecated: use azuredns instead.
func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
if config == nil { if config == nil {
@ -148,6 +151,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
if subsID == "" { if subsID == "" {
return nil, errors.New("azure: SubscriptionID is missing") return nil, errors.New("azure: SubscriptionID is missing")
} }
config.SubscriptionID = subsID config.SubscriptionID = subsID
} }
@ -160,6 +164,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
if resGroup == "" { if resGroup == "" {
return nil, errors.New("azure: ResourceGroup is missing") return nil, errors.New("azure: ResourceGroup is missing")
} }
config.ResourceGroup = resGroup config.ResourceGroup = resGroup
} }

View file

@ -54,6 +54,7 @@ func TestNewDNSProvider(t *testing.T) {
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
defer envTest.RestoreEnv() defer envTest.RestoreEnv()
envTest.ClearEnv() envTest.ClearEnv()
envTest.Apply(test.envVars) envTest.Apply(test.envVars)
@ -158,6 +159,7 @@ func TestNewDNSProviderConfig(t *testing.T) {
} else { } else {
mux.HandleFunc("/", test.handler) mux.HandleFunc("/", test.handler)
} }
config.MetadataEndpoint = server.URL config.MetadataEndpoint = server.URL
p, err := NewDNSProviderConfig(config) p, err := NewDNSProviderConfig(config)
@ -186,6 +188,7 @@ func TestLivePresent(t *testing.T) {
} }
envTest.RestoreEnv() envTest.RestoreEnv()
provider, err := NewDNSProvider() provider, err := NewDNSProvider()
require.NoError(t, err) require.NoError(t, err)
@ -199,6 +202,7 @@ func TestLiveCleanUp(t *testing.T) {
} }
envTest.RestoreEnv() envTest.RestoreEnv()
provider, err := NewDNSProvider() provider, err := NewDNSProvider()
require.NoError(t, err) require.NoError(t, err)

View file

@ -54,6 +54,7 @@ func (d *dnsProviderPrivate) Present(domain, token, keyAuth string) error {
// Construct unique TXT records using map // Construct unique TXT records using map
uniqRecords := map[string]struct{}{info.Value: {}} uniqRecords := map[string]struct{}{info.Value: {}}
if rset.RecordSetProperties != nil && rset.TxtRecords != nil { if rset.RecordSetProperties != nil && rset.TxtRecords != nil {
for _, txtRecord := range *rset.TxtRecords { for _, txtRecord := range *rset.TxtRecords {
// Assume Value doesn't contain multiple strings // Assume Value doesn't contain multiple strings
@ -81,6 +82,7 @@ func (d *dnsProviderPrivate) Present(domain, token, keyAuth string) error {
if err != nil { if err != nil {
return fmt.Errorf("azure: %w", err) return fmt.Errorf("azure: %w", err)
} }
return nil return nil
} }
@ -106,6 +108,7 @@ func (d *dnsProviderPrivate) CleanUp(domain, token, keyAuth string) error {
if err != nil { if err != nil {
return fmt.Errorf("azure: %w", err) return fmt.Errorf("azure: %w", err)
} }
return nil return nil
} }

View file

@ -54,6 +54,7 @@ func (d *dnsProviderPublic) Present(domain, token, keyAuth string) error {
// Construct unique TXT records using map // Construct unique TXT records using map
uniqRecords := map[string]struct{}{info.Value: {}} uniqRecords := map[string]struct{}{info.Value: {}}
if rset.RecordSetProperties != nil && rset.TxtRecords != nil { if rset.RecordSetProperties != nil && rset.TxtRecords != nil {
for _, txtRecord := range *rset.TxtRecords { for _, txtRecord := range *rset.TxtRecords {
// Assume Value doesn't contain multiple strings // Assume Value doesn't contain multiple strings
@ -81,6 +82,7 @@ func (d *dnsProviderPublic) Present(domain, token, keyAuth string) error {
if err != nil { if err != nil {
return fmt.Errorf("azure: %w", err) return fmt.Errorf("azure: %w", err)
} }
return nil return nil
} }
@ -106,6 +108,7 @@ func (d *dnsProviderPublic) CleanUp(domain, token, keyAuth string) error {
if err != nil { if err != nil {
return fmt.Errorf("azure: %w", err) return fmt.Errorf("azure: %w", err)
} }
return nil return nil
} }

View file

@ -35,6 +35,7 @@ func TestNewDNSProvider(t *testing.T) {
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
defer envTest.RestoreEnv() defer envTest.RestoreEnv()
envTest.ClearEnv() envTest.ClearEnv()
envTest.Apply(test.envVars) envTest.Apply(test.envVars)
@ -61,6 +62,7 @@ func TestLivePresent(t *testing.T) {
} }
envTest.RestoreEnv() envTest.RestoreEnv()
provider, err := NewDNSProvider() provider, err := NewDNSProvider()
require.NoError(t, err) require.NoError(t, err)
@ -74,6 +76,7 @@ func TestLiveCleanUp(t *testing.T) {
} }
envTest.RestoreEnv() envTest.RestoreEnv()
provider, err := NewDNSProvider() provider, err := NewDNSProvider()
require.NoError(t, err) require.NoError(t, err)

View file

@ -51,6 +51,7 @@ func getCredentials(config *Config) (azcore.TokenCredential, error) {
if config.TenantID != "" { if config.TenantID != "" {
credOptions = &azidentity.AzureCLICredentialOptions{TenantID: config.TenantID} credOptions = &azidentity.AzureCLICredentialOptions{TenantID: config.TenantID}
} }
return azidentity.NewAzureCLICredential(credOptions) return azidentity.NewAzureCLICredential(credOptions)
case authMethodOIDC: case authMethodOIDC:

View file

@ -181,6 +181,7 @@ func (c privateZoneClient) Delete(ctx context.Context, subDomain string) (armpri
func privateUniqueRecords(recordSet armprivatedns.RecordSet, value string) map[string]struct{} { func privateUniqueRecords(recordSet armprivatedns.RecordSet, value string) map[string]struct{} {
uniqRecords := map[string]struct{}{value: {}} uniqRecords := map[string]struct{}{value: {}}
if recordSet.Properties != nil && recordSet.Properties.TxtRecords != nil { if recordSet.Properties != nil && recordSet.Properties.TxtRecords != nil {
for _, txtRecord := range recordSet.Properties.TxtRecords { for _, txtRecord := range recordSet.Properties.TxtRecords {
// Assume Value doesn't contain multiple strings // Assume Value doesn't contain multiple strings

View file

@ -179,6 +179,7 @@ func (c publicZoneClient) Delete(ctx context.Context, subDomain string) (armdns.
func publicUniqueRecords(recordSet armdns.RecordSet, value string) map[string]struct{} { func publicUniqueRecords(recordSet armdns.RecordSet, value string) map[string]struct{} {
uniqRecords := map[string]struct{}{value: {}} uniqRecords := map[string]struct{}{value: {}}
if recordSet.Properties != nil && recordSet.Properties.TxtRecords != nil { if recordSet.Properties != nil && recordSet.Properties.TxtRecords != nil {
for _, txtRecord := range recordSet.Properties.TxtRecords { for _, txtRecord := range recordSet.Properties.TxtRecords {
// Assume Value doesn't contain multiple strings // Assume Value doesn't contain multiple strings

View file

@ -46,6 +46,7 @@ func discoverDNSZones(ctx context.Context, config *Config, credentials azcore.To
} }
zones := map[string]ServiceDiscoveryZone{} zones := map[string]ServiceDiscoveryZone{}
for { for {
// create the query request // create the query request
request := armresourcegraph.QueryRequest{ request := armresourcegraph.QueryRequest{

Some files were not shown because too many files have changed in this diff Show more