diff --git a/.golangci.yml b/.golangci.yml index 6f8c15cb..e182de8f 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -161,7 +161,7 @@ issues: linters: - gocyclo - funlen - - path: providers/dns/checkdomain/client.go + - path: providers/dns/checkdomain/internal/types.go text: '`payed` is a misspelling of `paid`' - path: providers/dns/namecheap/namecheap_test.go text: 'cognitive complexity (\d+) of func `TestDNSProvider_getHosts` is high' @@ -174,7 +174,7 @@ issues: text: 'yodaStyleExpr' - path: providers/dns/dns_providers.go text: 'Function name: NewDNSChallengeProviderByName,' - - path: providers/dns/sakuracloud/client.go + - path: providers/dns/sakuracloud/wrapper.go text: 'mu is a global variable' - path: providers/dns/hosttech/internal/client_test.go text: 'Duplicate words \(0\) found' diff --git a/README.md b/README.md index b35ba09e..9daf227e 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ Detailed documentation is available [here](https://go-acme.github.io/lego/dns). |---------------------------------------------------------------------------------|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------| | [Akamai EdgeDNS](https://go-acme.github.io/lego/dns/edgedns/) | [Alibaba Cloud DNS](https://go-acme.github.io/lego/dns/alidns/) | [all-inkl](https://go-acme.github.io/lego/dns/allinkl/) | [Amazon Lightsail](https://go-acme.github.io/lego/dns/lightsail/) | | [Amazon Route 53](https://go-acme.github.io/lego/dns/route53/) | [ArvanCloud](https://go-acme.github.io/lego/dns/arvancloud/) | [Aurora DNS](https://go-acme.github.io/lego/dns/auroradns/) | [Autodns](https://go-acme.github.io/lego/dns/autodns/) | -| [Azure](https://go-acme.github.io/lego/dns/azure/) | [Bindman](https://go-acme.github.io/lego/dns/bindman/) | [Bluecat](https://go-acme.github.io/lego/dns/bluecat/) | [BRANDIT](https://go-acme.github.io/lego/dns/brandit/) | +| [Azure](https://go-acme.github.io/lego/dns/azure/) | [Bindman](https://go-acme.github.io/lego/dns/bindman/) | [Bluecat](https://go-acme.github.io/lego/dns/bluecat/) | [Brandit](https://go-acme.github.io/lego/dns/brandit/) | | [Bunny](https://go-acme.github.io/lego/dns/bunny/) | [Checkdomain](https://go-acme.github.io/lego/dns/checkdomain/) | [Civo](https://go-acme.github.io/lego/dns/civo/) | [CloudDNS](https://go-acme.github.io/lego/dns/clouddns/) | | [Cloudflare](https://go-acme.github.io/lego/dns/cloudflare/) | [ClouDNS](https://go-acme.github.io/lego/dns/cloudns/) | [CloudXNS](https://go-acme.github.io/lego/dns/cloudxns/) | [ConoHa](https://go-acme.github.io/lego/dns/conoha/) | | [Constellix](https://go-acme.github.io/lego/dns/constellix/) | [deSEC.io](https://go-acme.github.io/lego/dns/desec/) | [Designate DNSaaS for Openstack](https://go-acme.github.io/lego/dns/designate/) | [Digital Ocean](https://go-acme.github.io/lego/dns/digitalocean/) | diff --git a/cmd/zz_gen_cmd_dnshelp.go b/cmd/zz_gen_cmd_dnshelp.go index f839e949..59414b82 100644 --- a/cmd/zz_gen_cmd_dnshelp.go +++ b/cmd/zz_gen_cmd_dnshelp.go @@ -335,7 +335,7 @@ func displayDNSHelp(w io.Writer, name string) error { case "brandit": // generated from: providers/dns/brandit/brandit.toml - ew.writeln(`Configuration for BRANDIT.`) + ew.writeln(`Configuration for Brandit.`) ew.writeln(`Code: 'brandit'`) ew.writeln(`Since: 'v4.11.0'`) ew.writeln() diff --git a/docs/content/dns/zz_gen_brandit.md b/docs/content/dns/zz_gen_brandit.md index 307dfa57..237c02af 100644 --- a/docs/content/dns/zz_gen_brandit.md +++ b/docs/content/dns/zz_gen_brandit.md @@ -1,5 +1,5 @@ --- -title: "BRANDIT" +title: "Brandit" date: 2019-03-03T16:39:46+01:00 draft: false slug: brandit @@ -14,7 +14,7 @@ dnsprovider: -Configuration for [BRANDIT](https://www.brandit.com/). +Configuration for [Brandit](https://www.brandit.com/). @@ -23,7 +23,7 @@ Configuration for [BRANDIT](https://www.brandit.com/). - Since: v4.11.0 -Here is an example bash command using the BRANDIT provider: +Here is an example bash command using the Brandit provider: ```bash BRANDIT_API_KEY=xxxxxxxxxxxxxxxxxxxxx \ diff --git a/docs/content/dns/zz_gen_otc.md b/docs/content/dns/zz_gen_otc.md index 836f623e..0a7136cb 100644 --- a/docs/content/dns/zz_gen_otc.md +++ b/docs/content/dns/zz_gen_otc.md @@ -61,7 +61,7 @@ More information [here]({{< ref "dns#configuration-and-credentials" >}}). ## More information -- [API documentation](https://docs.otc.t-systems.com/en-us/dns/index.html) +- [API documentation](https://docs.otc.t-systems.com/domain-name-service/api-ref/index.html) diff --git a/go.mod b/go.mod index 8353b374..9aac58bb 100644 --- a/go.mod +++ b/go.mod @@ -63,9 +63,9 @@ require ( github.com/vultr/govultr/v2 v2.17.2 github.com/yandex-cloud/go-genproto v0.0.0-20220805142335-27b56ddae16f github.com/yandex-cloud/go-sdk v0.0.0-20220805164847-cf028e604997 - golang.org/x/crypto v0.5.0 - golang.org/x/net v0.7.0 - golang.org/x/oauth2 v0.5.0 + golang.org/x/crypto v0.7.0 + golang.org/x/net v0.8.0 + golang.org/x/oauth2 v0.6.0 golang.org/x/time v0.3.0 google.golang.org/api v0.111.0 gopkg.in/ns1/ns1-go.v2 v2.6.5 @@ -126,10 +126,10 @@ require ( github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect go.opencensus.io v0.24.0 // indirect go.uber.org/ratelimit v0.2.0 // indirect - golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect - golang.org/x/sys v0.5.0 // indirect - golang.org/x/text v0.7.0 // indirect - golang.org/x/tools v0.1.12 // indirect + golang.org/x/mod v0.8.0 // indirect + golang.org/x/sys v0.6.0 // indirect + golang.org/x/text v0.8.0 // indirect + golang.org/x/tools v0.6.0 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/genproto v0.0.0-20230223222841-637eb2293923 // indirect google.golang.org/grpc v1.53.0 // indirect diff --git a/go.sum b/go.sum index 60f610f0..fc925961 100644 --- a/go.sum +++ b/go.sum @@ -595,8 +595,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20211202192323-5770296d904e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE= -golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= +golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= +golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -619,8 +619,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -650,14 +650,14 @@ golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20210913180222-943fd674d43e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= -golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= +golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.5.0 h1:HuArIo48skDwlrvM3sEdHXElYslAMsf3KwRkkW4MC4s= -golang.org/x/oauth2 v0.5.0/go.mod h1:9/XBHVqLaWO3/BRHs5jbpYCnOZVjj5V0ndyaAM7KB4I= +golang.org/x/oauth2 v0.6.0 h1:Lh8GPgSKBfWSwFvtuWOfeI3aAAnbXTSutYxJiOJFgIw= +golang.org/x/oauth2 v0.6.0/go.mod h1:ycmewcwgD4Rpr3eZJLSB4Kyyljb3qDh40vJ8STE5HKw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -712,12 +712,12 @@ golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.5.0 h1:n2a8QNdAb0sZNpU9R1ALUXBbY+w51fCQDN+7EdxNBsY= +golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -726,8 +726,8 @@ golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= +golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -758,8 +758,8 @@ golang.org/x/tools v0.0.0-20200918232735-d647fc253266/go.mod h1:z6u4i615ZeAfBE4X golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210114065538-d78b04bdf963/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.6-0.20210726203631-07bc1bf47fb2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/platform/wait/wait.go b/platform/wait/wait.go index d0c078b8..6ad817b2 100644 --- a/platform/wait/wait.go +++ b/platform/wait/wait.go @@ -1,7 +1,6 @@ package wait import ( - "errors" "fmt" "time" @@ -18,9 +17,9 @@ func For(msg string, timeout, interval time.Duration, f func() (bool, error)) er select { case <-timeUp: if lastErr == nil { - return errors.New("time limit exceeded") + return fmt.Errorf("%s: time limit exceeded", msg) } - return fmt.Errorf("time limit exceeded: last error: %w", lastErr) + return fmt.Errorf("%s: time limit exceeded: last error: %w", msg, lastErr) default: } diff --git a/providers/dns/alidns/alidns.go b/providers/dns/alidns/alidns.go index 7520be85..23320b23 100644 --- a/providers/dns/alidns/alidns.go +++ b/providers/dns/alidns/alidns.go @@ -198,7 +198,7 @@ func (d *DNSProvider) getHostedZone(domain string) (string, error) { authZone, err := dns01.FindZoneByFqdn(domain) if err != nil { - return "", err + return "", fmt.Errorf("could not find zone for FQDN %q: %w", domain, err) } var hostedZone alidns.DomainInDescribeDomains diff --git a/providers/dns/allinkl/allinkl.go b/providers/dns/allinkl/allinkl.go index b82ba379..6525a119 100644 --- a/providers/dns/allinkl/allinkl.go +++ b/providers/dns/allinkl/allinkl.go @@ -2,6 +2,7 @@ package allinkl import ( + "context" "errors" "fmt" "net/http" @@ -49,7 +50,9 @@ func NewDefaultConfig() *Config { // DNSProvider implements the challenge.Provider interface. type DNSProvider struct { config *Config - client *internal.Client + + identifier *internal.Identifier + client *internal.Client recordIDs map[string]string recordIDsMu sync.Mutex @@ -80,16 +83,23 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("allinkl: missing credentials") } - client := internal.NewClient(config.Login, config.Password) + identifier := internal.NewIdentifier(config.Login, config.Password) + + if config.HTTPClient != nil { + identifier.HTTPClient = config.HTTPClient + } + + client := internal.NewClient(config.Login) if config.HTTPClient != nil { client.HTTPClient = config.HTTPClient } return &DNSProvider{ - config: config, - client: client, - recordIDs: make(map[string]string), + config: config, + identifier: identifier, + client: client, + recordIDs: make(map[string]string), }, nil } @@ -105,14 +115,18 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("allinkl: could not determine zone for domain %q: %w", domain, err) + return fmt.Errorf("allinkl: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } - credential, err := d.client.Authentication(60, true) + ctx := context.Background() + + credential, err := d.identifier.Authentication(ctx, 60, true) if err != nil { return fmt.Errorf("allinkl: %w", err) } + ctx = internal.WithContext(ctx, credential) + subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) if err != nil { return fmt.Errorf("allinkl: %w", err) @@ -125,7 +139,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { RecordData: info.Value, } - recordID, err := d.client.AddDNSSettings(credential, record) + recordID, err := d.client.AddDNSSettings(ctx, record) if err != nil { return fmt.Errorf("allinkl: %w", err) } @@ -141,11 +155,15 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - credential, err := d.client.Authentication(60, true) + ctx := context.Background() + + credential, err := d.identifier.Authentication(ctx, 60, true) if err != nil { return fmt.Errorf("allinkl: %w", err) } + ctx = internal.WithContext(ctx, credential) + // gets the record's unique ID from when we created it d.recordIDsMu.Lock() recordID, ok := d.recordIDs[token] @@ -154,7 +172,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return fmt.Errorf("allinkl: unknown record ID for '%s' '%s'", info.EffectiveFQDN, token) } - _, err = d.client.DeleteDNSSettings(credential, recordID) + _, err = d.client.DeleteDNSSettings(ctx, recordID) if err != nil { return fmt.Errorf("allinkl: %w", err) } diff --git a/providers/dns/allinkl/internal/client.go b/providers/dns/allinkl/internal/client.go index 75eefaff..87894433 100644 --- a/providers/dns/allinkl/internal/client.go +++ b/providers/dns/allinkl/internal/client.go @@ -2,126 +2,64 @@ package internal import ( "bytes" + "context" "encoding/json" - "encoding/xml" "fmt" - "io" "net/http" "strconv" "strings" + "sync" "time" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" "github.com/mitchellh/mapstructure" ) -const ( - authEndpoint = "https://kasapi.kasserver.com/soap/KasAuth.php" - apiEndpoint = "https://kasapi.kasserver.com/soap/KasApi.php" -) +const apiEndpoint = "https://kasapi.kasserver.com/soap/KasApi.php" + +type Authentication interface { + Authentication(ctx context.Context, sessionLifetime int, sessionUpdateLifetime bool) (string, error) +} // Client a KAS server client. type Client struct { - login string - password string + login string - authEndpoint string - apiEndpoint string - HTTPClient *http.Client - floodTime time.Time + floodTime time.Time + muFloodTime sync.Mutex + + baseURL string + HTTPClient *http.Client } // NewClient creates a new Client. -func NewClient(login string, password string) *Client { +func NewClient(login string) *Client { return &Client{ - login: login, - password: password, - authEndpoint: authEndpoint, - apiEndpoint: apiEndpoint, - HTTPClient: &http.Client{Timeout: 10 * time.Second}, + login: login, + baseURL: apiEndpoint, + HTTPClient: &http.Client{Timeout: 10 * time.Second}, } } -// Authentication Creates a credential token. -// - sessionLifetime: Validity of the token in seconds. -// - sessionUpdateLifetime: with `true` the session is extended with every request. -func (c Client) Authentication(sessionLifetime int, sessionUpdateLifetime bool) (string, error) { - sul := "N" - if sessionUpdateLifetime { - sul = "Y" - } - - ar := AuthRequest{ - Login: c.login, - AuthData: c.password, - AuthType: "plain", - SessionLifetime: sessionLifetime, - SessionUpdateLifetime: sul, - } - - body, err := json.Marshal(ar) - if err != nil { - return "", fmt.Errorf("request marshal: %w", err) - } - - payload := []byte(strings.TrimSpace(fmt.Sprintf(kasAuthEnvelope, body))) - - req, err := http.NewRequest(http.MethodPost, c.authEndpoint, bytes.NewReader(payload)) - if err != nil { - return "", fmt.Errorf("request creation: %w", err) - } - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return "", fmt.Errorf("request execution: %w", err) - } - - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - data, _ := io.ReadAll(resp.Body) - return "", fmt.Errorf("invalid status code: %d %s", resp.StatusCode, string(data)) - } - - data, err := io.ReadAll(resp.Body) - if err != nil { - return "", fmt.Errorf("response read: %w", err) - } - - var e KasAuthEnvelope - decoder := xml.NewTokenDecoder(Trimmer{decoder: xml.NewDecoder(bytes.NewReader(data))}) - err = decoder.Decode(&e) - if err != nil { - return "", fmt.Errorf("response xml decode: %w", err) - } - - if e.Body.Fault != nil { - return "", e.Body.Fault - } - - return e.Body.KasAuthResponse.Return.Text, nil -} - // GetDNSSettings Reading out the DNS settings of a zone. // - zone: host zone. // - recordID: the ID of the resource record (optional). -func (c *Client) GetDNSSettings(credentialToken, zone, recordID string) ([]ReturnInfo, error) { +func (c *Client) GetDNSSettings(ctx context.Context, zone, recordID string) ([]ReturnInfo, error) { requestParams := map[string]string{"zone_host": zone} if recordID != "" { requestParams["record_id"] = recordID } - item, err := c.do(credentialToken, "get_dns_settings", requestParams) + req, err := c.newRequest(ctx, "get_dns_settings", requestParams) if err != nil { return nil, err } - raw := getValue(item) - var g GetDNSSettingsAPIResponse - err = mapstructure.Decode(raw, &g) + err = c.do(req, &g) if err != nil { - return nil, fmt.Errorf("response struct decode: %w", err) + return nil, err } c.updateFloodTime(g.Response.KasFloodDelay) @@ -130,18 +68,16 @@ func (c *Client) GetDNSSettings(credentialToken, zone, recordID string) ([]Retur } // AddDNSSettings Creation of a DNS resource record. -func (c *Client) AddDNSSettings(credentialToken string, record DNSRequest) (string, error) { - item, err := c.do(credentialToken, "add_dns_settings", record) +func (c *Client) AddDNSSettings(ctx context.Context, record DNSRequest) (string, error) { + req, err := c.newRequest(ctx, "add_dns_settings", record) if err != nil { return "", err } - raw := getValue(item) - var g AddDNSSettingsAPIResponse - err = mapstructure.Decode(raw, &g) + err = c.do(req, &g) if err != nil { - return "", fmt.Errorf("response struct decode: %w", err) + return "", err } c.updateFloodTime(g.Response.KasFloodDelay) @@ -150,20 +86,18 @@ func (c *Client) AddDNSSettings(credentialToken string, record DNSRequest) (stri } // DeleteDNSSettings Deleting a DNS Resource Record. -func (c *Client) DeleteDNSSettings(credentialToken, recordID string) (bool, error) { +func (c *Client) DeleteDNSSettings(ctx context.Context, recordID string) (bool, error) { requestParams := map[string]string{"record_id": recordID} - item, err := c.do(credentialToken, "delete_dns_settings", requestParams) + req, err := c.newRequest(ctx, "delete_dns_settings", requestParams) if err != nil { return false, err } - raw := getValue(item) - var g DeleteDNSSettingsAPIResponse - err = mapstructure.Decode(raw, &g) + err = c.do(req, &g) if err != nil { - return false, fmt.Errorf("response struct decode: %w", err) + return false, err } c.updateFloodTime(g.Response.KasFloodDelay) @@ -171,65 +105,72 @@ func (c *Client) DeleteDNSSettings(credentialToken, recordID string) (bool, erro return g.Response.ReturnInfo, nil } -func (c Client) do(credentialToken, action string, requestParams interface{}) (*Item, error) { - time.Sleep(time.Until(c.floodTime)) - +func (c *Client) newRequest(ctx context.Context, action string, requestParams any) (*http.Request, error) { ar := KasRequest{ Login: c.login, AuthType: "session", - AuthData: credentialToken, + AuthData: getToken(ctx), Action: action, RequestParams: requestParams, } body, err := json.Marshal(ar) if err != nil { - return nil, fmt.Errorf("request marshal: %w", err) + return nil, fmt.Errorf("failed to create request JSON body: %w", err) } payload := []byte(strings.TrimSpace(fmt.Sprintf(kasAPIEnvelope, body))) - req, err := http.NewRequest(http.MethodPost, c.apiEndpoint, bytes.NewReader(payload)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL, bytes.NewReader(payload)) if err != nil { - return nil, fmt.Errorf("request creation: %w", err) + return nil, fmt.Errorf("unable to create request: %w", err) } + return req, nil +} + +func (c *Client) do(req *http.Request, result any) error { + c.muFloodTime.Lock() + time.Sleep(time.Until(c.floodTime)) + c.muFloodTime.Unlock() + resp, err := c.HTTPClient.Do(req) if err != nil { - return nil, fmt.Errorf("request execution: %w", err) + return errutils.NewHTTPDoError(req, err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - data, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("invalid status code: %d %s", resp.StatusCode, string(data)) + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) } - data, err := io.ReadAll(resp.Body) + envlp, err := decodeXML[KasAPIResponseEnvelope](resp.Body) if err != nil { - return nil, fmt.Errorf("response read: %w", err) + return err } - var e KasAPIResponseEnvelope - decoder := xml.NewTokenDecoder(Trimmer{decoder: xml.NewDecoder(bytes.NewReader(data))}) - err = decoder.Decode(&e) + if envlp.Body.Fault != nil { + return envlp.Body.Fault + } + + raw := getValue(envlp.Body.KasAPIResponse.Return) + + err = mapstructure.Decode(raw, result) if err != nil { - return nil, fmt.Errorf("response xml decode: %w", err) + return fmt.Errorf("response struct decode: %w", err) } - if e.Body.Fault != nil { - return nil, e.Body.Fault - } - - return e.Body.KasAPIResponse.Return, nil + return nil } func (c *Client) updateFloodTime(delay float64) { + c.muFloodTime.Lock() c.floodTime = time.Now().Add(time.Duration(delay * float64(time.Second))) + c.muFloodTime.Unlock() } -func getValue(item *Item) interface{} { +func getValue(item *Item) any { switch { case item.Raw != "": v, _ := strconv.ParseBool(item.Raw) @@ -253,7 +194,7 @@ func getValue(item *Item) interface{} { return getValue(item.Value) case len(item.Items) > 0 && item.Type == "SOAP-ENC:Array": - var v []interface{} + var v []any for _, i := range item.Items { v = append(v, getValue(i)) } @@ -261,7 +202,7 @@ func getValue(item *Item) interface{} { return v case len(item.Items) > 0: - v := map[string]interface{}{} + v := map[string]any{} for _, i := range item.Items { v[getKey(i)] = getValue(i) } diff --git a/providers/dns/allinkl/internal/client_test.go b/providers/dns/allinkl/internal/client_test.go index e2b51d1e..3eb7c21a 100644 --- a/providers/dns/allinkl/internal/client_test.go +++ b/providers/dns/allinkl/internal/client_test.go @@ -13,36 +13,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestClient_Authentication(t *testing.T) { - mux := http.NewServeMux() - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - - mux.HandleFunc("/", testHandler("auth.xml")) - - client := NewClient("user", "secret") - client.authEndpoint = server.URL - - credentialToken, err := client.Authentication(60, false) - require.NoError(t, err) - - assert.Equal(t, "593959ca04f0de9689b586c6a647d15d", credentialToken) -} - -func TestClient_Authentication_error(t *testing.T) { - mux := http.NewServeMux() - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - - mux.HandleFunc("/", testHandler("auth_fault.xml")) - - client := NewClient("user", "secret") - client.authEndpoint = server.URL - - _, err := client.Authentication(60, false) - require.Error(t, err) -} - func TestClient_GetDNSSettings(t *testing.T) { mux := http.NewServeMux() server := httptest.NewServer(mux) @@ -50,12 +20,10 @@ func TestClient_GetDNSSettings(t *testing.T) { mux.HandleFunc("/", testHandler("get_dns_settings.xml")) - client := NewClient("user", "secret") - client.apiEndpoint = server.URL + client := NewClient("user") + client.baseURL = server.URL - token := "sha1secret" - - records, err := client.GetDNSSettings(token, "example.com", "") + records, err := client.GetDNSSettings(mockContext(), "example.com", "") require.NoError(t, err) expected := []ReturnInfo{ @@ -134,10 +102,8 @@ func TestClient_AddDNSSettings(t *testing.T) { mux.HandleFunc("/", testHandler("add_dns_settings.xml")) - client := NewClient("user", "secret") - client.apiEndpoint = server.URL - - token := "sha1secret" + client := NewClient("user") + client.baseURL = server.URL record := DNSRequest{ ZoneHost: "42cnc.de.", @@ -146,7 +112,7 @@ func TestClient_AddDNSSettings(t *testing.T) { RecordData: "abcdefgh", } - recordID, err := client.AddDNSSettings(token, record) + recordID, err := client.AddDNSSettings(mockContext(), record) require.NoError(t, err) assert.Equal(t, "57347444", recordID) @@ -159,12 +125,10 @@ func TestClient_DeleteDNSSettings(t *testing.T) { mux.HandleFunc("/", testHandler("delete_dns_settings.xml")) - client := NewClient("user", "secret") - client.apiEndpoint = server.URL + client := NewClient("user") + client.baseURL = server.URL - token := "sha1secret" - - r, err := client.DeleteDNSSettings(token, "57347450") + r, err := client.DeleteDNSSettings(mockContext(), "57347450") require.NoError(t, err) assert.True(t, r) diff --git a/providers/dns/allinkl/internal/identity.go b/providers/dns/allinkl/internal/identity.go new file mode 100644 index 00000000..4353ece3 --- /dev/null +++ b/providers/dns/allinkl/internal/identity.go @@ -0,0 +1,104 @@ +package internal + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" +) + +// authEndpoint represents the Identity API endpoint to call. +const authEndpoint = "https://kasapi.kasserver.com/soap/KasAuth.php" + +type token string + +const tokenKey token = "token" + +// Identifier generates credential tokens. +type Identifier struct { + login string + password string + + authEndpoint string + HTTPClient *http.Client +} + +// NewIdentifier creates a new Identifier. +func NewIdentifier(login string, password string) *Identifier { + return &Identifier{ + login: login, + password: password, + authEndpoint: authEndpoint, + HTTPClient: &http.Client{Timeout: 10 * time.Second}, + } +} + +// Authentication Creates a credential token. +// - sessionLifetime: Validity of the token in seconds. +// - sessionUpdateLifetime: with `true` the session is extended with every request. +func (c *Identifier) Authentication(ctx context.Context, sessionLifetime int, sessionUpdateLifetime bool) (string, error) { + sul := "N" + if sessionUpdateLifetime { + sul = "Y" + } + + ar := AuthRequest{ + Login: c.login, + AuthData: c.password, + AuthType: "plain", + SessionLifetime: sessionLifetime, + SessionUpdateLifetime: sul, + } + + body, err := json.Marshal(ar) + if err != nil { + return "", fmt.Errorf("failed to create request JSON body: %w", err) + } + + payload := []byte(strings.TrimSpace(fmt.Sprintf(kasAuthEnvelope, body))) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.authEndpoint, bytes.NewReader(payload)) + if err != nil { + return "", fmt.Errorf("unable to create request: %w", err) + } + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return "", errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return "", errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + envlp, err := decodeXML[KasAuthEnvelope](resp.Body) + if err != nil { + return "", err + } + + if envlp.Body.Fault != nil { + return "", envlp.Body.Fault + } + + return envlp.Body.KasAuthResponse.Return.Text, nil +} + +func WithContext(ctx context.Context, credential string) context.Context { + return context.WithValue(ctx, tokenKey, credential) +} + +func getToken(ctx context.Context) string { + credential, ok := ctx.Value(tokenKey).(string) + if !ok { + return "" + } + + return credential +} diff --git a/providers/dns/allinkl/internal/identity_test.go b/providers/dns/allinkl/internal/identity_test.go new file mode 100644 index 00000000..0753f386 --- /dev/null +++ b/providers/dns/allinkl/internal/identity_test.go @@ -0,0 +1,45 @@ +package internal + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func mockContext() context.Context { + return context.WithValue(context.Background(), tokenKey, "593959ca04f0de9689b586c6a647d15d") +} + +func TestIdentifier_Authentication(t *testing.T) { + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + mux.HandleFunc("/", testHandler("auth.xml")) + + client := NewIdentifier("user", "secret") + client.authEndpoint = server.URL + + credentialToken, err := client.Authentication(context.Background(), 60, false) + require.NoError(t, err) + + assert.Equal(t, "593959ca04f0de9689b586c6a647d15d", credentialToken) +} + +func TestIdentifier_Authentication_error(t *testing.T) { + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + mux.HandleFunc("/", testHandler("auth_fault.xml")) + + client := NewIdentifier("user", "secret") + client.authEndpoint = server.URL + + _, err := client.Authentication(context.Background(), 60, false) + require.Error(t, err) +} diff --git a/providers/dns/allinkl/internal/types.go b/providers/dns/allinkl/internal/types.go index ac2ddd39..b5c6ba0d 100644 --- a/providers/dns/allinkl/internal/types.go +++ b/providers/dns/allinkl/internal/types.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/xml" "fmt" + "io" ) // Trimmer trim all XML fields. @@ -44,3 +45,18 @@ type Item struct { Value *Item `xml:"value" json:"value,omitempty"` Items []*Item `xml:"item" json:"item,omitempty"` } + +func decodeXML[T any](reader io.Reader) (*T, error) { + raw, err := io.ReadAll(reader) + if err != nil { + return nil, fmt.Errorf("read response body: %w", err) + } + + var result T + err = xml.NewTokenDecoder(Trimmer{decoder: xml.NewDecoder(bytes.NewReader(raw))}).Decode(&result) + if err != nil { + return nil, fmt.Errorf("decode XML response: %w", err) + } + + return &result, nil +} diff --git a/providers/dns/allinkl/internal/types_api.go b/providers/dns/allinkl/internal/types_api.go index 49db25a3..9207dc1a 100644 --- a/providers/dns/allinkl/internal/types_api.go +++ b/providers/dns/allinkl/internal/types_api.go @@ -35,7 +35,7 @@ type KasRequest struct { // Action API function. Action string `json:"kas_action,omitempty"` // RequestParams Parameters to the API function. - RequestParams interface{} `json:"KasRequestParams,omitempty"` + RequestParams any `json:"KasRequestParams,omitempty"` } type DNSRequest struct { @@ -64,13 +64,13 @@ type GetDNSSettingsResponse struct { } type ReturnInfo struct { - ID interface{} `json:"record_id,omitempty" mapstructure:"record_id"` - Zone string `json:"record_zone,omitempty" mapstructure:"record_zone"` - Name string `json:"record_name,omitempty" mapstructure:"record_name"` - Type string `json:"record_type,omitempty" mapstructure:"record_type"` - Data string `json:"record_data,omitempty" mapstructure:"record_data"` - Changeable string `json:"record_changeable,omitempty" mapstructure:"record_changeable"` - Aux int `json:"record_aux,omitempty" mapstructure:"record_aux"` + ID any `json:"record_id,omitempty" mapstructure:"record_id"` + Zone string `json:"record_zone,omitempty" mapstructure:"record_zone"` + Name string `json:"record_name,omitempty" mapstructure:"record_name"` + Type string `json:"record_type,omitempty" mapstructure:"record_type"` + Data string `json:"record_data,omitempty" mapstructure:"record_data"` + Changeable string `json:"record_changeable,omitempty" mapstructure:"record_changeable"` + Aux int `json:"record_aux,omitempty" mapstructure:"record_aux"` } type AddDNSSettingsAPIResponse struct { diff --git a/providers/dns/arvancloud/arvancloud.go b/providers/dns/arvancloud/arvancloud.go index 6e5935b4..dde75724 100644 --- a/providers/dns/arvancloud/arvancloud.go +++ b/providers/dns/arvancloud/arvancloud.go @@ -2,6 +2,7 @@ package arvancloud import ( + "context" "errors" "fmt" "net/http" @@ -108,11 +109,13 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - authZone, err := getZone(info.EffectiveFQDN) + authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return err + return fmt.Errorf("arvancloud: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } + authZone = dns01.UnFqdn(authZone) + subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) if err != nil { return fmt.Errorf("arvancloud: %w", err) @@ -131,7 +134,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { }, } - newRecord, err := d.client.CreateRecord(authZone, record) + newRecord, err := d.client.CreateRecord(context.Background(), authZone, record) if err != nil { return fmt.Errorf("arvancloud: failed to add TXT record: fqdn=%s: %w", info.EffectiveFQDN, err) } @@ -147,11 +150,13 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - authZone, err := getZone(info.EffectiveFQDN) + authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return err + return fmt.Errorf("arvancloud: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } + authZone = dns01.UnFqdn(authZone) + // gets the record's unique ID from when we created it d.recordIDsMu.Lock() recordID, ok := d.recordIDs[token] @@ -160,7 +165,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return fmt.Errorf("arvancloud: unknown record ID for '%s' '%s'", info.EffectiveFQDN, token) } - if err := d.client.DeleteRecord(authZone, recordID); err != nil { + if err := d.client.DeleteRecord(context.Background(), authZone, recordID); err != nil { return fmt.Errorf("arvancloud: failed to delate TXT record: id=%s: %w", recordID, err) } @@ -171,12 +176,3 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return nil } - -func getZone(fqdn string) (string, error) { - authZone, err := dns01.FindZoneByFqdn(fqdn) - if err != nil { - return "", err - } - - return dns01.UnFqdn(authZone), nil -} diff --git a/providers/dns/arvancloud/internal/client.go b/providers/dns/arvancloud/internal/client.go index 9cf5b85a..3caff392 100644 --- a/providers/dns/arvancloud/internal/client.go +++ b/providers/dns/arvancloud/internal/client.go @@ -2,39 +2,45 @@ package internal import ( "bytes" + "context" "encoding/json" "fmt" "io" "net/http" "net/url" "strings" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) // defaultBaseURL represents the API endpoint to call. const defaultBaseURL = "https://napi.arvancloud.ir" -const authHeader = "Authorization" +const authorizationHeader = "Authorization" // Client the ArvanCloud client. type Client struct { - HTTPClient *http.Client - BaseURL string - apiKey string + + baseURL *url.URL + HTTPClient *http.Client } -// NewClient Creates a new ArvanCloud client. +// NewClient Creates a new Client. func NewClient(apiKey string) *Client { + baseURL, _ := url.Parse(defaultBaseURL) + return &Client{ - HTTPClient: http.DefaultClient, - BaseURL: defaultBaseURL, apiKey: apiKey, + baseURL: baseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, } } // GetTxtRecord gets a TXT record. -func (c *Client) GetTxtRecord(domain, name, value string) (*DNSRecord, error) { - records, err := c.getRecords(domain, name) +func (c *Client) GetTxtRecord(ctx context.Context, domain, name, value string) (*DNSRecord, error) { + records, err := c.getRecords(ctx, domain, name) if err != nil { return nil, err } @@ -49,11 +55,8 @@ func (c *Client) GetTxtRecord(domain, name, value string) (*DNSRecord, error) { } // https://www.arvancloud.ir/docs/api/cdn/4.0#operation/dns_records.list -func (c *Client) getRecords(domain, search string) ([]DNSRecord, error) { - endpoint, err := c.createEndpoint("cdn", "4.0", "domains", domain, "dns-records") - if err != nil { - return nil, fmt.Errorf("failed to create endpoint: %w", err) - } +func (c *Client) getRecords(ctx context.Context, domain, search string) ([]DNSRecord, error) { + endpoint := c.baseURL.JoinPath("cdn", "4.0", "domains", domain, "dns-records") if search != "" { query := endpoint.Query() @@ -61,123 +64,110 @@ func (c *Client) getRecords(domain, search string) ([]DNSRecord, error) { endpoint.RawQuery = query.Encode() } - resp, err := c.do(http.MethodGet, endpoint.String(), nil) + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, err } - defer func() { _ = resp.Body.Close() }() - - body, err := io.ReadAll(resp.Body) + response := &apiResponse[[]DNSRecord]{} + err = c.do(req, http.StatusOK, response) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("could not get records %s: Domain: %s: %w", search, domain, err) } - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("could not get records %s: Domain: %s; Status: %s; Body: %s", - search, domain, resp.Status, string(body)) - } - - response := &apiResponse{} - err = json.Unmarshal(body, response) - if err != nil { - return nil, fmt.Errorf("failed to decode response body: %w", err) - } - - var records []DNSRecord - err = json.Unmarshal(response.Data, &records) - if err != nil { - return nil, fmt.Errorf("failed to decode records: %w", err) - } - - return records, nil + return response.Data, nil } // CreateRecord creates a DNS record. // https://www.arvancloud.ir/docs/api/cdn/4.0#operation/dns_records.create -func (c *Client) CreateRecord(domain string, record DNSRecord) (*DNSRecord, error) { - reqBody, err := json.Marshal(record) +func (c *Client) CreateRecord(ctx context.Context, domain string, record DNSRecord) (*DNSRecord, error) { + endpoint := c.baseURL.JoinPath("cdn", "4.0", "domains", domain, "dns-records") + + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record) if err != nil { return nil, err } - endpoint, err := c.createEndpoint("cdn", "4.0", "domains", domain, "dns-records") + response := &apiResponse[*DNSRecord]{} + err = c.do(req, http.StatusCreated, response) if err != nil { - return nil, fmt.Errorf("failed to create endpoint: %w", err) + return nil, fmt.Errorf("could not create record; Domain: %s: %w", domain, err) } - resp, err := c.do(http.MethodPost, endpoint.String(), bytes.NewReader(reqBody)) - if err != nil { - return nil, err - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - - if resp.StatusCode != http.StatusCreated { - return nil, fmt.Errorf("could not create record %s; Domain: %s; Status: %s; Body: %s", string(reqBody), domain, resp.Status, string(body)) - } - - response := &apiResponse{} - err = json.Unmarshal(body, response) - if err != nil { - return nil, fmt.Errorf("failed to decode response body: %w", err) - } - - var newRecord DNSRecord - err = json.Unmarshal(response.Data, &newRecord) - if err != nil { - return nil, fmt.Errorf("failed to decode record: %w", err) - } - - return &newRecord, nil + return response.Data, nil } // DeleteRecord deletes a DNS record. // https://www.arvancloud.ir/docs/api/cdn/4.0#operation/dns_records.remove -func (c *Client) DeleteRecord(domain, id string) error { - endpoint, err := c.createEndpoint("cdn", "4.0", "domains", domain, "dns-records", id) - if err != nil { - return fmt.Errorf("failed to create endpoint: %w", err) - } +func (c *Client) DeleteRecord(ctx context.Context, domain, id string) error { + endpoint := c.baseURL.JoinPath("cdn", "4.0", "domains", domain, "dns-records", id) - resp, err := c.do(http.MethodDelete, endpoint.String(), nil) + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) if err != nil { return err } - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return fmt.Errorf("could not delete record %s; Domain: %s; Status: %s; Body: %s", id, domain, resp.Status, string(body)) + err = c.do(req, http.StatusOK, nil) + if err != nil { + return fmt.Errorf("could not delete record %s; Domain: %s: %w", id, domain, err) } return nil } -func (c *Client) do(method, endpoint string, body io.Reader) (*http.Response, error) { - req, err := http.NewRequest(method, endpoint, body) +func (c *Client) do(req *http.Request, expectedStatus int, result any) error { + req.Header.Set(authorizationHeader, c.apiKey) + + resp, err := c.HTTPClient.Do(req) if err != nil { - return nil, err + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != expectedStatus { + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + if result == nil { + return nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return nil +} + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) } req.Header.Set("Accept", "application/json") - if body != nil { + + if payload != nil { req.Header.Set("Content-Type", "application/json") } - req.Header.Set(authHeader, c.apiKey) - return c.HTTPClient.Do(req) -} - -func (c *Client) createEndpoint(parts ...string) (*url.URL, error) { - baseURL, err := url.Parse(c.BaseURL) - if err != nil { - return nil, err - } - - return baseURL.JoinPath(parts...), nil + return req, nil } func equalsTXTRecord(record DNSRecord, name, value string) bool { @@ -189,7 +179,7 @@ func equalsTXTRecord(record DNSRecord, name, value string) bool { return false } - data, ok := record.Value.(map[string]interface{}) + data, ok := record.Value.(map[string]any) if !ok { return false } diff --git a/providers/dns/arvancloud/internal/client_test.go b/providers/dns/arvancloud/internal/client_test.go index f21311ae..5c9154c6 100644 --- a/providers/dns/arvancloud/internal/client_test.go +++ b/providers/dns/arvancloud/internal/client_test.go @@ -1,10 +1,12 @@ package internal import ( + "context" "fmt" "io" "net/http" "net/http/httptest" + "net/url" "os" "testing" @@ -12,21 +14,34 @@ import ( "github.com/stretchr/testify/require" ) -func TestClient_GetTxtRecord(t *testing.T) { +func setupTest(t *testing.T, apiKey string) (*Client, *http.ServeMux) { + t.Helper() + mux := http.NewServeMux() server := httptest.NewServer(mux) t.Cleanup(server.Close) - const domain = "example.com" + client := NewClient(apiKey) + client.baseURL, _ = url.Parse(server.URL) + client.HTTPClient = server.Client() + + return client, mux +} + +func TestClient_GetTxtRecord(t *testing.T) { const apiKey = "myKeyA" + client, mux := setupTest(t, apiKey) + + const domain = "example.com" + mux.HandleFunc("/cdn/4.0/domains/"+domain+"/dns-records", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodGet { http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusMethodNotAllowed) return } - auth := req.Header.Get(authHeader) + auth := req.Header.Get(authorizationHeader) if auth != apiKey { http.Error(rw, fmt.Sprintf("invalid API key: %s", auth), http.StatusUnauthorized) return @@ -46,20 +61,16 @@ func TestClient_GetTxtRecord(t *testing.T) { } }) - client := NewClient(apiKey) - client.BaseURL = server.URL - - _, err := client.GetTxtRecord(domain, "_acme-challenge", "txtxtxt") + _, err := client.GetTxtRecord(context.Background(), domain, "_acme-challenge", "txtxtxt") require.NoError(t, err) } func TestClient_CreateRecord(t *testing.T) { - mux := http.NewServeMux() - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + const apiKey = "myKeyB" + + client, mux := setupTest(t, apiKey) const domain = "example.com" - const apiKey = "myKeyB" mux.HandleFunc("/cdn/4.0/domains/"+domain+"/dns-records", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodPost { @@ -67,7 +78,7 @@ func TestClient_CreateRecord(t *testing.T) { return } - auth := req.Header.Get(authHeader) + auth := req.Header.Get(authorizationHeader) if auth != apiKey { http.Error(rw, fmt.Sprintf("invalid API key: %s", auth), http.StatusUnauthorized) return @@ -88,9 +99,6 @@ func TestClient_CreateRecord(t *testing.T) { } }) - client := NewClient(apiKey) - client.BaseURL = server.URL - record := DNSRecord{ Name: "_acme-challenge", Type: "txt", @@ -98,7 +106,7 @@ func TestClient_CreateRecord(t *testing.T) { TTL: 600, } - newRecord, err := client.CreateRecord(domain, record) + newRecord, err := client.CreateRecord(context.Background(), domain, record) require.NoError(t, err) expected := &DNSRecord{ @@ -119,12 +127,11 @@ func TestClient_CreateRecord(t *testing.T) { } func TestClient_DeleteRecord(t *testing.T) { - mux := http.NewServeMux() - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + const apiKey = "myKeyC" + + client, mux := setupTest(t, apiKey) const domain = "example.com" - const apiKey = "myKeyC" const recordID = "recordId" mux.HandleFunc("/cdn/4.0/domains/"+domain+"/dns-records/"+recordID, func(rw http.ResponseWriter, req *http.Request) { @@ -133,16 +140,13 @@ func TestClient_DeleteRecord(t *testing.T) { return } - auth := req.Header.Get(authHeader) + auth := req.Header.Get(authorizationHeader) if auth != apiKey { http.Error(rw, fmt.Sprintf("invalid API key: %s", auth), http.StatusUnauthorized) return } }) - client := NewClient(apiKey) - client.BaseURL = server.URL - - err := client.DeleteRecord(domain, recordID) + err := client.DeleteRecord(context.Background(), domain, recordID) require.NoError(t, err) } diff --git a/providers/dns/arvancloud/internal/model.go b/providers/dns/arvancloud/internal/types.go similarity index 80% rename from providers/dns/arvancloud/internal/model.go rename to providers/dns/arvancloud/internal/types.go index f26043bc..dc6e04e5 100644 --- a/providers/dns/arvancloud/internal/model.go +++ b/providers/dns/arvancloud/internal/types.go @@ -1,17 +1,15 @@ package internal -import "encoding/json" - -type apiResponse struct { - Message string `json:"message"` - Data json.RawMessage `json:"data"` +type apiResponse[T any] struct { + Message string `json:"message"` + Data T `json:"data"` } // DNSRecord a DNS record. type DNSRecord struct { ID string `json:"id,omitempty"` Type string `json:"type"` - Value interface{} `json:"value,omitempty"` + Value any `json:"value,omitempty"` Name string `json:"name,omitempty"` TTL int `json:"ttl,omitempty"` UpstreamHTTPS string `json:"upstream_https,omitempty"` diff --git a/providers/dns/auroradns/auroradns.go b/providers/dns/auroradns/auroradns.go index 743a226d..700046c4 100644 --- a/providers/dns/auroradns/auroradns.go +++ b/providers/dns/auroradns/auroradns.go @@ -108,7 +108,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("aurora: could not determine zone for domain %q: %w", domain, err) + return fmt.Errorf("aurora: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } // 1. Aurora will happily create the TXT record when it is provided a fqdn, @@ -155,24 +155,24 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { d.recordIDsMu.Unlock() if !ok { - return fmt.Errorf("unknown recordID for %q", info.EffectiveFQDN) + return fmt.Errorf("aurora: unknown recordID for %q", info.EffectiveFQDN) } authZone, err := dns01.FindZoneByFqdn(dns01.ToFqdn(info.EffectiveFQDN)) if err != nil { - return fmt.Errorf("could not determine zone for domain %q: %w", domain, err) + return fmt.Errorf("aurora: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } authZone = dns01.UnFqdn(authZone) zone, err := d.getZoneInformationByName(authZone) if err != nil { - return err + return fmt.Errorf("aurora: %w", err) } _, _, err = d.client.DeleteRecord(zone.ID, recordID) if err != nil { - return err + return fmt.Errorf("aurora: %w", err) } d.recordIDsMu.Lock() diff --git a/providers/dns/autodns/autodns.go b/providers/dns/autodns/autodns.go index abc3433d..3ab31ab1 100644 --- a/providers/dns/autodns/autodns.go +++ b/providers/dns/autodns/autodns.go @@ -2,6 +2,7 @@ package autodns import ( + "context" "errors" "fmt" "net/http" @@ -10,6 +11,7 @@ import ( "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/platform/config/env" + "github.com/go-acme/lego/v4/providers/dns/autodns/internal" ) // Environment variables names. @@ -27,11 +29,6 @@ const ( EnvHTTPTimeout = envNamespace + "HTTP_TIMEOUT" ) -const ( - defaultEndpointContext int = 4 - defaultTTL int = 600 -) - // Config is used to configure the creation of the DNSProvider. type Config struct { Endpoint *url.URL @@ -46,12 +43,12 @@ type Config struct { // NewDefaultConfig returns a default configuration for the DNSProvider. func NewDefaultConfig() *Config { - endpoint, _ := url.Parse(env.GetOrDefaultString(EnvAPIEndpoint, defaultEndpoint)) + endpoint, _ := url.Parse(env.GetOrDefaultString(EnvAPIEndpoint, internal.DefaultEndpoint)) return &Config{ Endpoint: endpoint, - Context: env.GetOrDefaultInt(EnvAPIEndpointContext, defaultEndpointContext), - TTL: env.GetOrDefaultInt(EnvTTL, defaultTTL), + Context: env.GetOrDefaultInt(EnvAPIEndpointContext, internal.DefaultEndpointContext), + TTL: env.GetOrDefaultInt(EnvTTL, 600), PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, 2*time.Minute), PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, 2*time.Second), HTTPClient: &http.Client{ @@ -63,6 +60,7 @@ func NewDefaultConfig() *Config { // DNSProvider implements the challenge.Provider interface. type DNSProvider struct { config *Config + client *internal.Client } // NewDNSProvider returns a DNSProvider instance configured for autoDNS. @@ -94,7 +92,17 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("autodns: missing password") } - return &DNSProvider{config: config}, nil + client := internal.NewClient(config.Username, config.Password, config.Context) + + if config.Endpoint != nil { + client.BaseURL = config.Endpoint + } + + if config.HTTPClient != nil { + client.HTTPClient = config.HTTPClient + } + + return &DNSProvider{config: config, client: client}, nil } // Timeout returns the timeout and interval to use when checking for DNS propagation. @@ -107,7 +115,7 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - records := []*ResourceRecord{{ + records := []*internal.ResourceRecord{{ Name: info.EffectiveFQDN, TTL: int64(d.config.TTL), Type: "TXT", @@ -115,7 +123,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { }} // TODO(ldez) replace domain by FQDN to follow CNAME. - _, err := d.addTxtRecord(domain, records) + _, err := d.client.AddTxtRecords(context.Background(), domain, records) if err != nil { return fmt.Errorf("autodns: %w", err) } @@ -127,7 +135,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - records := []*ResourceRecord{{ + records := []*internal.ResourceRecord{{ Name: info.EffectiveFQDN, TTL: int64(d.config.TTL), Type: "TXT", @@ -135,7 +143,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { }} // TODO(ldez) replace domain by FQDN to follow CNAME. - if err := d.removeTXTRecord(domain, records); err != nil { + if err := d.client.RemoveTXTRecords(context.Background(), domain, records); err != nil { return fmt.Errorf("autodns: %w", err) } diff --git a/providers/dns/autodns/client.go b/providers/dns/autodns/client.go deleted file mode 100644 index 1c58ed81..00000000 --- a/providers/dns/autodns/client.go +++ /dev/null @@ -1,159 +0,0 @@ -package autodns - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "net/http" - "strconv" -) - -const ( - defaultEndpoint = "https://api.autodns.com/v1/" -) - -type ResponseMessage struct { - Text string `json:"text"` - Messages []string `json:"messages"` - Objects []string `json:"objects"` - Code string `json:"code"` - Status string `json:"status"` -} - -type ResponseStatus struct { - Code string `json:"code"` - Text string `json:"text"` - Type string `json:"type"` -} - -type ResponseObject struct { - Type string `json:"type"` - Value string `json:"value"` - Summary int32 `json:"summary"` - Data string -} - -type DataZoneResponse struct { - STID string `json:"stid"` - CTID string `json:"ctid"` - Messages []*ResponseMessage `json:"messages"` - Status *ResponseStatus `json:"status"` - Object interface{} `json:"object"` - Data []*Zone `json:"data"` -} - -// ResourceRecord holds a resource record. -type ResourceRecord struct { - Name string `json:"name"` - TTL int64 `json:"ttl"` - Type string `json:"type"` - Value string `json:"value"` - Pref int32 `json:"pref,omitempty"` -} - -// Zone is an autodns zone record with all for us relevant fields. -type Zone struct { - Name string `json:"origin"` - ResourceRecords []*ResourceRecord `json:"resourceRecords"` - Action string `json:"action"` - VirtualNameServer string `json:"virtualNameServer"` -} - -type ZoneStream struct { - Adds []*ResourceRecord `json:"adds"` - Removes []*ResourceRecord `json:"rems"` -} - -func (d *DNSProvider) addTxtRecord(domain string, records []*ResourceRecord) (*Zone, error) { - zoneStream := &ZoneStream{Adds: records} - - return d.makeZoneUpdateRequest(zoneStream, domain) -} - -func (d *DNSProvider) removeTXTRecord(domain string, records []*ResourceRecord) error { - zoneStream := &ZoneStream{Removes: records} - - _, err := d.makeZoneUpdateRequest(zoneStream, domain) - return err -} - -func (d *DNSProvider) makeZoneUpdateRequest(zoneStream *ZoneStream, domain string) (*Zone, error) { - reqBody := &bytes.Buffer{} - if err := json.NewEncoder(reqBody).Encode(zoneStream); err != nil { - return nil, err - } - - endpoint := d.config.Endpoint.JoinPath("zone", domain, "_stream") - - req, err := d.makeRequest(http.MethodPost, endpoint.String(), reqBody) - if err != nil { - return nil, err - } - - var resp *Zone - if err := d.sendRequest(req, &resp); err != nil { - return nil, err - } - return resp, nil -} - -func (d *DNSProvider) makeRequest(method, endpoint string, body io.Reader) (*http.Request, error) { - req, err := http.NewRequest(method, endpoint, body) - if err != nil { - return nil, err - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("X-Domainrobot-Context", strconv.Itoa(d.config.Context)) - req.SetBasicAuth(d.config.Username, d.config.Password) - - return req, nil -} - -func (d *DNSProvider) sendRequest(req *http.Request, result interface{}) error { - resp, err := d.config.HTTPClient.Do(req) - if err != nil { - return err - } - - if err = checkResponse(resp); err != nil { - return err - } - - defer func() { _ = resp.Body.Close() }() - - if result == nil { - return nil - } - - raw, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - - err = json.Unmarshal(raw, result) - if err != nil { - return fmt.Errorf("unmarshaling %T error [status code=%d]: %w: %s", result, resp.StatusCode, err, string(raw)) - } - return err -} - -func checkResponse(resp *http.Response) error { - if resp.StatusCode < http.StatusBadRequest { - return nil - } - - if resp.Body == nil { - return fmt.Errorf("response body is nil, status code=%d", resp.StatusCode) - } - - defer func() { _ = resp.Body.Close() }() - - raw, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("unable to read body: status code=%d, error=%w", resp.StatusCode, err) - } - - return fmt.Errorf("status code=%d: %s", resp.StatusCode, string(raw)) -} diff --git a/providers/dns/autodns/internal/client.go b/providers/dns/autodns/internal/client.go new file mode 100644 index 00000000..363250d0 --- /dev/null +++ b/providers/dns/autodns/internal/client.go @@ -0,0 +1,132 @@ +package internal + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" +) + +// DefaultEndpoint default API endpoint. +const DefaultEndpoint = "https://api.autodns.com/v1/" + +// DefaultEndpointContext default API endpoint context. +const DefaultEndpointContext int = 4 + +// Client the Autodns API client. +type Client struct { + username string + password string + context int + + BaseURL *url.URL + HTTPClient *http.Client +} + +// NewClient creates a new Client. +func NewClient(username string, password string, clientContext int) *Client { + baseURL, _ := url.Parse(DefaultEndpoint) + + return &Client{ + username: username, + password: password, + context: clientContext, + BaseURL: baseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + } +} + +// AddTxtRecords adds TXT records. +func (c *Client) AddTxtRecords(ctx context.Context, domain string, records []*ResourceRecord) (*Zone, error) { + zoneStream := &ZoneStream{Adds: records} + + return c.updateZone(ctx, domain, zoneStream) +} + +// RemoveTXTRecords removes TXT records. +func (c *Client) RemoveTXTRecords(ctx context.Context, domain string, records []*ResourceRecord) error { + zoneStream := &ZoneStream{Removes: records} + + _, err := c.updateZone(ctx, domain, zoneStream) + return err +} + +// https://github.com/InterNetX/domainrobot-api/blob/bdc8fe92a2f32fcbdb29e30bf6006ab446f81223/src/domainrobot.json#L21090 +func (c *Client) updateZone(ctx context.Context, domain string, zoneStream *ZoneStream) (*Zone, error) { + endpoint := c.BaseURL.JoinPath("zone", domain, "_stream") + + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, zoneStream) + if err != nil { + return nil, err + } + + var zone *Zone + if err := c.do(req, &zone); err != nil { + return nil, err + } + + return zone, nil +} + +func (c *Client) do(req *http.Request, result any) error { + req.Header.Set("X-Domainrobot-Context", strconv.Itoa(c.context)) + req.SetBasicAuth(c.username, c.password) + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode/100 != 2 { + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + if result == nil { + return nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return nil +} + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil +} diff --git a/providers/dns/autodns/internal/client_test.go b/providers/dns/autodns/internal/client_test.go new file mode 100644 index 00000000..f8743b24 --- /dev/null +++ b/providers/dns/autodns/internal/client_test.go @@ -0,0 +1,96 @@ +package internal + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setupTest(t *testing.T, method, pattern string, status int, file string) *Client { + t.Helper() + + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + mux.HandleFunc(pattern, func(rw http.ResponseWriter, req *http.Request) { + if req.Method != method { + http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusBadRequest) + return + } + + apiUser, apiKey, ok := req.BasicAuth() + if apiUser != "user" || apiKey != "secret" || !ok { + http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + + if file == "" { + rw.WriteHeader(status) + return + } + + open, err := os.Open(filepath.Join("fixtures", file)) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + + defer func() { _ = open.Close() }() + + rw.WriteHeader(status) + _, err = io.Copy(rw, open) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + }) + + client := NewClient("user", "secret", 123) + client.HTTPClient = server.Client() + client.BaseURL, _ = url.Parse(server.URL) + + return client +} + +func TestClient_AddTxtRecords(t *testing.T) { + client := setupTest(t, http.MethodPost, "/zone/example.com/_stream", http.StatusOK, "add-record.json") + + records := []*ResourceRecord{{}} + + zone, err := client.AddTxtRecords(context.Background(), "example.com", records) + require.NoError(t, err) + + expected := &Zone{ + Name: "example.com", + ResourceRecords: []*ResourceRecord{{ + Name: "example.com", + TTL: 120, + Type: "TXT", + Value: "txt", + Pref: 1, + }}, + Action: "xxx", + VirtualNameServer: "yyy", + } + + assert.Equal(t, expected, zone) +} + +func TestClient_RemoveTXTRecords(t *testing.T) { + client := setupTest(t, http.MethodPost, "/zone/example.com/_stream", http.StatusOK, "add-record.json") + + records := []*ResourceRecord{{}} + + err := client.RemoveTXTRecords(context.Background(), "example.com", records) + require.NoError(t, err) +} diff --git a/providers/dns/autodns/internal/fixtures/add-record.json b/providers/dns/autodns/internal/fixtures/add-record.json new file mode 100644 index 00000000..4a95f078 --- /dev/null +++ b/providers/dns/autodns/internal/fixtures/add-record.json @@ -0,0 +1,14 @@ +{ + "origin": "example.com", + "resourceRecords": [ + { + "name": "example.com", + "ttl": 120, + "type": "TXT", + "value": "txt", + "pref": 1 + } + ], + "action": "xxx", + "virtualNameServer": "yyy" +} diff --git a/providers/dns/autodns/internal/fixtures/remove-record.json b/providers/dns/autodns/internal/fixtures/remove-record.json new file mode 100644 index 00000000..4a95f078 --- /dev/null +++ b/providers/dns/autodns/internal/fixtures/remove-record.json @@ -0,0 +1,14 @@ +{ + "origin": "example.com", + "resourceRecords": [ + { + "name": "example.com", + "ttl": 120, + "type": "TXT", + "value": "txt", + "pref": 1 + } + ], + "action": "xxx", + "virtualNameServer": "yyy" +} diff --git a/providers/dns/autodns/internal/types.go b/providers/dns/autodns/internal/types.go new file mode 100644 index 00000000..93fd678c --- /dev/null +++ b/providers/dns/autodns/internal/types.go @@ -0,0 +1,57 @@ +package internal + +type ResponseMessage struct { + Text string `json:"text"` + Messages []string `json:"messages"` + Objects []string `json:"objects"` + Code string `json:"code"` + Status string `json:"status"` +} + +type ResponseStatus struct { + Code string `json:"code"` + Text string `json:"text"` + Type string `json:"type"` +} + +type ResponseObject struct { + Type string `json:"type"` + Value string `json:"value"` + Summary int32 `json:"summary"` + Data string +} + +type DataZoneResponse struct { + STID string `json:"stid"` + CTID string `json:"ctid"` + Messages []*ResponseMessage `json:"messages"` + Status *ResponseStatus `json:"status"` + Object any `json:"object"` + Data []*Zone `json:"data"` +} + +// ResourceRecord holds a resource record. +// https://help.internetx.com/display/APIXMLEN/Resource+Record+Object +type ResourceRecord struct { + Name string `json:"name"` + TTL int64 `json:"ttl"` + Type string `json:"type"` + Value string `json:"value"` + Pref int32 `json:"pref,omitempty"` +} + +// Zone is an autodns zone record with all for us relevant fields. +// https://help.internetx.com/display/APIXMLEN/Zone+Object +type Zone struct { + Name string `json:"origin"` + ResourceRecords []*ResourceRecord `json:"resourceRecords"` + Action string `json:"action"` + VirtualNameServer string `json:"virtualNameServer"` +} + +// ZoneStream body of the requests. +// https://github.com/InterNetX/domainrobot-api/blob/bdc8fe92a2f32fcbdb29e30bf6006ab446f81223/src/domainrobot.json#L35914-L35932 +type ZoneStream struct { + Adds []*ResourceRecord `json:"adds"` + Removes []*ResourceRecord `json:"rems"` +} diff --git a/providers/dns/azure/azure.go b/providers/dns/azure/azure.go index 8ace21fb..4b693efe 100644 --- a/providers/dns/azure/azure.go +++ b/providers/dns/azure/azure.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "net/url" "time" "github.com/Azure/go-autorest/autorest" @@ -14,6 +15,7 @@ import ( "github.com/Azure/go-autorest/autorest/azure/auth" "github.com/go-acme/lego/v4/challenge" "github.com/go-acme/lego/v4/platform/config/env" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) const defaultMetadataEndpoint = "http://169.254.169.254" @@ -122,7 +124,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { } if config.HTTPClient == nil { - config.HTTPClient = http.DefaultClient + config.HTTPClient = &http.Client{Timeout: 5 * time.Second} } authorizer, err := getAuthorizer(config) @@ -208,8 +210,12 @@ func getMetadata(config *Config, field string) (string, error) { metadataEndpoint = defaultMetadataEndpoint } - resource := fmt.Sprintf("%s/metadata/instance/compute/%s", metadataEndpoint, field) - req, err := http.NewRequest(http.MethodGet, resource, nil) + endpoint, err := url.JoinPath(metadataEndpoint, "metadata", "instance", "compute", field) + if err != nil { + return "", err + } + + req, err := http.NewRequest(http.MethodGet, endpoint, nil) if err != nil { return "", err } @@ -223,14 +229,15 @@ func getMetadata(config *Config, field string) (string, error) { resp, err := config.HTTPClient.Do(req) if err != nil { - return "", err + return "", errutils.NewHTTPDoError(req, err) } - defer resp.Body.Close() - respBody, err := io.ReadAll(resp.Body) + defer func() { _ = resp.Body.Close() }() + + raw, err := io.ReadAll(resp.Body) if err != nil { - return "", err + return "", errutils.NewReadResponseError(req, resp.StatusCode, err) } - return string(respBody), nil + return string(raw), nil } diff --git a/providers/dns/azure/private.go b/providers/dns/azure/private.go index 3994bf20..6f1aa822 100644 --- a/providers/dns/azure/private.go +++ b/providers/dns/azure/private.go @@ -118,7 +118,7 @@ func (d *dnsProviderPrivate) getHostedZoneID(ctx context.Context, fqdn string) ( authZone, err := dns01.FindZoneByFqdn(fqdn) if err != nil { - return "", err + return "", fmt.Errorf("could not find zone for FQDN %q: %w", fqdn, err) } dc := privatedns.NewPrivateZonesClientWithBaseURI(d.config.ResourceManagerEndpoint, d.config.SubscriptionID) diff --git a/providers/dns/azure/public.go b/providers/dns/azure/public.go index 4f3c1ff9..aca6869b 100644 --- a/providers/dns/azure/public.go +++ b/providers/dns/azure/public.go @@ -118,7 +118,7 @@ func (d *dnsProviderPublic) getHostedZoneID(ctx context.Context, fqdn string) (s authZone, err := dns01.FindZoneByFqdn(fqdn) if err != nil { - return "", err + return "", fmt.Errorf("could not find zone for FQDN %q: %w", fqdn, err) } dc := dns.NewZonesClientWithBaseURI(d.config.ResourceManagerEndpoint, d.config.SubscriptionID) diff --git a/providers/dns/bluecat/bluecat.go b/providers/dns/bluecat/bluecat.go index 3e14d309..58ac2147 100644 --- a/providers/dns/bluecat/bluecat.go +++ b/providers/dns/bluecat/bluecat.go @@ -2,6 +2,7 @@ package bluecat import ( + "context" "errors" "fmt" "net/http" @@ -97,7 +98,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("bluecat: credentials missing") } - client := internal.NewClient(config.BaseURL) + client := internal.NewClient(config.BaseURL, config.UserName, config.Password) if config.HTTPClient != nil { client.HTTPClient = config.HTTPClient @@ -112,17 +113,17 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - err := d.client.Login(d.config.UserName, d.config.Password) + ctx, err := d.client.CreateAuthenticatedContext(context.Background()) if err != nil { return fmt.Errorf("bluecat: login: %w", err) } - viewID, err := d.client.LookupViewID(d.config.ConfigName, d.config.DNSView) + viewID, err := d.client.LookupViewID(ctx, d.config.ConfigName, d.config.DNSView) if err != nil { return fmt.Errorf("bluecat: lookupViewID: %w", err) } - parentZoneID, name, err := d.client.LookupParentZoneID(viewID, info.EffectiveFQDN) + parentZoneID, name, err := d.client.LookupParentZoneID(ctx, viewID, info.EffectiveFQDN) if err != nil { return fmt.Errorf("bluecat: lookupParentZoneID: %w", err) } @@ -137,17 +138,17 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { Properties: fmt.Sprintf("ttl=%d|absoluteName=%s|txt=%s|", d.config.TTL, info.EffectiveFQDN, info.Value), } - _, err = d.client.AddEntity(parentZoneID, txtRecord) + _, err = d.client.AddEntity(ctx, parentZoneID, txtRecord) if err != nil { return fmt.Errorf("bluecat: add TXT record: %w", err) } - err = d.client.Deploy(parentZoneID) + err = d.client.Deploy(ctx, parentZoneID) if err != nil { return fmt.Errorf("bluecat: deploy: %w", err) } - err = d.client.Logout() + err = d.client.Logout(ctx) if err != nil { return fmt.Errorf("bluecat: logout: %w", err) } @@ -159,37 +160,37 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - err := d.client.Login(d.config.UserName, d.config.Password) + ctx, err := d.client.CreateAuthenticatedContext(context.Background()) if err != nil { return fmt.Errorf("bluecat: login: %w", err) } - viewID, err := d.client.LookupViewID(d.config.ConfigName, d.config.DNSView) + viewID, err := d.client.LookupViewID(ctx, d.config.ConfigName, d.config.DNSView) if err != nil { return fmt.Errorf("bluecat: lookupViewID: %w", err) } - parentZoneID, name, err := d.client.LookupParentZoneID(viewID, info.EffectiveFQDN) + parentZoneID, name, err := d.client.LookupParentZoneID(ctx, viewID, info.EffectiveFQDN) if err != nil { return fmt.Errorf("bluecat: lookupParentZoneID: %w", err) } - txtRecord, err := d.client.GetEntityByName(parentZoneID, name, internal.TXTType) + txtRecord, err := d.client.GetEntityByName(ctx, parentZoneID, name, internal.TXTType) if err != nil { return fmt.Errorf("bluecat: get TXT record: %w", err) } - err = d.client.Delete(txtRecord.ID) + err = d.client.Delete(ctx, txtRecord.ID) if err != nil { return fmt.Errorf("bluecat: delete TXT record: %w", err) } - err = d.client.Deploy(parentZoneID) + err = d.client.Deploy(ctx, parentZoneID) if err != nil { return fmt.Errorf("bluecat: deploy: %w", err) } - err = d.client.Logout() + err = d.client.Logout(ctx) if err != nil { return fmt.Errorf("bluecat: logout: %w", err) } diff --git a/providers/dns/bluecat/internal/client.go b/providers/dns/bluecat/internal/client.go index bb61f9da..e6451343 100644 --- a/providers/dns/bluecat/internal/client.go +++ b/providers/dns/bluecat/internal/client.go @@ -2,14 +2,18 @@ package internal import ( "bytes" + "context" "encoding/json" "fmt" "io" "net/http" + "net/url" "regexp" "strconv" "strings" "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) // Object types. @@ -20,153 +24,88 @@ const ( TXTType = "TXTRecord" ) +const authorizationHeader = "Authorization" + type Client struct { - HTTPClient *http.Client + username string + password string - baseURL string - - token string tokenExp *regexp.Regexp + + baseURL *url.URL + HTTPClient *http.Client } -func NewClient(baseURL string) *Client { +func NewClient(baseURL string, username, password string) *Client { + bu, _ := url.Parse(baseURL) + return &Client{ - HTTPClient: &http.Client{Timeout: 30 * time.Second}, - baseURL: baseURL, + username: username, + password: password, tokenExp: regexp.MustCompile("BAMAuthToken: [^ ]+"), + baseURL: bu, + HTTPClient: &http.Client{Timeout: 30 * time.Second}, } } -// Login Logs in as API user. -// Authenticates and receives a token to be used in for subsequent requests. -// https://docs.bluecatnetworks.com/r/Address-Manager-API-Guide/GET/v1/login/9.1.0 -func (c *Client) Login(username, password string) error { - queryArgs := map[string]string{ - "username": username, - "password": password, - } - - resp, err := c.sendRequest(http.MethodGet, "login", nil, queryArgs) - if err != nil { - return err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - data, _ := io.ReadAll(resp.Body) - return &APIError{ - StatusCode: resp.StatusCode, - Resource: "login", - Message: string(data), - } - } - - authBytes, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - - authResp := string(authBytes) - if strings.Contains(authResp, "Authentication Error") { - return fmt.Errorf("request failed: %s", strings.Trim(authResp, `"`)) - } - - // Upon success, API responds with "Session Token-> BAMAuthToken: dQfuRMTUxNjc3MjcyNDg1ODppcGFybXM= <- for User : username" - c.token = c.tokenExp.FindString(authResp) - - return nil -} - -// Logout Logs out of the current API session. -// https://docs.bluecatnetworks.com/r/Address-Manager-API-Guide/GET/v1/logout/9.1.0 -func (c *Client) Logout() error { - if c.token == "" { - // nothing to do - return nil - } - - resp, err := c.sendRequest(http.MethodGet, "logout", nil, nil) - if err != nil { - return err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - data, _ := io.ReadAll(resp.Body) - return &APIError{ - StatusCode: resp.StatusCode, - Resource: "logout", - Message: string(data), - } - } - - authBytes, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - - authResp := string(authBytes) - if !strings.Contains(authResp, "successfully") { - return fmt.Errorf("request failed to delete session: %s", strings.Trim(authResp, `"`)) - } - - c.token = "" - - return nil -} - // Deploy the DNS config for the specified entity to the authoritative servers. -// https://docs.bluecatnetworks.com/r/Address-Manager-API-Guide/POST/v1/quickDeploy/9.1.0 -func (c *Client) Deploy(entityID uint) error { - queryArgs := map[string]string{ - "entityId": strconv.FormatUint(uint64(entityID), 10), - } +// https://docs.bluecatnetworks.com/r/Address-Manager-Legacy-v1-API-Guide/POST/v1/quickDeploy/9.5.0 +func (c *Client) Deploy(ctx context.Context, entityID uint) error { + endpoint := c.createEndpoint("quickDeploy") - resp, err := c.sendRequest(http.MethodPost, "quickDeploy", nil, queryArgs) + q := endpoint.Query() + q.Set("entityId", strconv.FormatUint(uint64(entityID), 10)) + endpoint.RawQuery = q.Encode() + + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, nil) if err != nil { return err } - defer resp.Body.Close() + + resp, err := c.doAuthenticated(ctx, req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() // The API doc says that 201 is expected but in the reality 200 is return. if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { - data, _ := io.ReadAll(resp.Body) - return &APIError{ - StatusCode: resp.StatusCode, - Resource: "quickDeploy", - Message: string(data), - } + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) } return nil } // AddEntity A generic method for adding configurations, DNS zones, and DNS resource records. -// https://docs.bluecatnetworks.com/r/Address-Manager-API-Guide/POST/v1/addEntity/9.1.0 -func (c *Client) AddEntity(parentID uint, entity Entity) (uint64, error) { - queryArgs := map[string]string{ - "parentId": strconv.FormatUint(uint64(parentID), 10), - } +// https://docs.bluecatnetworks.com/r/Address-Manager-Legacy-v1-API-Guide/POST/v1/addEntity/9.5.0 +func (c *Client) AddEntity(ctx context.Context, parentID uint, entity Entity) (uint64, error) { + endpoint := c.createEndpoint("addEntity") - resp, err := c.sendRequest(http.MethodPost, "addEntity", entity, queryArgs) + q := endpoint.Query() + q.Set("parentId", strconv.FormatUint(uint64(parentID), 10)) + endpoint.RawQuery = q.Encode() + + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, entity) if err != nil { return 0, err } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - data, _ := io.ReadAll(resp.Body) - return 0, &APIError{ - StatusCode: resp.StatusCode, - Resource: "addEntity", - Message: string(data), - } + resp, err := c.doAuthenticated(ctx, req) + if err != nil { + return 0, errutils.NewHTTPDoError(req, err) } - addTxtBytes, _ := io.ReadAll(resp.Body) + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return 0, errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + raw, _ := io.ReadAll(resp.Body) // addEntity responds only with body text containing the ID of the created record - addTxtResp := string(addTxtBytes) + addTxtResp := string(raw) id, err := strconv.ParseUint(addTxtResp, 10, 64) if err != nil { return 0, fmt.Errorf("addEntity request failed: %s", addTxtResp) @@ -176,73 +115,84 @@ func (c *Client) AddEntity(parentID uint, entity Entity) (uint64, error) { } // GetEntityByName Returns objects from the database referenced by their database ID and with its properties fields populated. -// https://docs.bluecatnetworks.com/r/Address-Manager-API-Guide/GET/v1/getEntityById/9.1.0 -func (c *Client) GetEntityByName(parentID uint, name, objType string) (*EntityResponse, error) { - queryArgs := map[string]string{ - "parentId": strconv.FormatUint(uint64(parentID), 10), - "name": name, - "type": objType, - } +// https://docs.bluecatnetworks.com/r/Address-Manager-Legacy-v1-API-Guide/GET/v1/getEntityById/9.5.0 +func (c *Client) GetEntityByName(ctx context.Context, parentID uint, name, objType string) (*EntityResponse, error) { + endpoint := c.createEndpoint("getEntityByName") - resp, err := c.sendRequest(http.MethodGet, "getEntityByName", nil, queryArgs) + q := endpoint.Query() + q.Set("parentId", strconv.FormatUint(uint64(parentID), 10)) + q.Set("name", name) + q.Set("type", objType) + endpoint.RawQuery = q.Encode() + + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, err } - defer resp.Body.Close() + + resp, err := c.doAuthenticated(ctx, req) + if err != nil { + return nil, errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - data, _ := io.ReadAll(resp.Body) - return nil, &APIError{ - StatusCode: resp.StatusCode, - Resource: "getEntityByName", - Message: string(data), - } + return nil, errutils.NewUnexpectedResponseStatusCodeError(req, resp) } - var txtRec EntityResponse - if err = json.NewDecoder(resp.Body).Decode(&txtRec); err != nil { - return nil, fmt.Errorf("JSON decode: %w", err) + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errutils.NewReadResponseError(req, resp.StatusCode, err) } - return &txtRec, nil + var entity EntityResponse + err = json.Unmarshal(raw, &entity) + if err != nil { + return nil, errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return &entity, nil } // Delete Deletes an object using the generic delete method. -// https://docs.bluecatnetworks.com/r/Address-Manager-API-Guide/DELETE/v1/delete/9.1.0 -func (c *Client) Delete(objectID uint) error { - queryArgs := map[string]string{ - "objectId": strconv.FormatUint(uint64(objectID), 10), - } +// https://docs.bluecatnetworks.com/r/Address-Manager-Legacy-v1-API-Guide/DELETE/v1/delete/9.5.0 +func (c *Client) Delete(ctx context.Context, objectID uint) error { + endpoint := c.createEndpoint("delete") - resp, err := c.sendRequest(http.MethodDelete, "delete", nil, queryArgs) + q := endpoint.Query() + q.Set("objectId", strconv.FormatUint(uint64(objectID), 10)) + endpoint.RawQuery = q.Encode() + + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) if err != nil { return err } - defer resp.Body.Close() + resp, err := c.doAuthenticated(ctx, req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } - // The API doc says that 204 is expected but in the reality 200 is return. + defer func() { _ = resp.Body.Close() }() + + // The API doc says that 204 is expected but in the reality 200 is returned. if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK { - data, _ := io.ReadAll(resp.Body) - return &APIError{ - StatusCode: resp.StatusCode, - Resource: "delete", - Message: string(data), - } + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) } return nil } // LookupViewID Find the DNS view with the given name within. -func (c *Client) LookupViewID(configName, viewName string) (uint, error) { +func (c *Client) LookupViewID(ctx context.Context, configName, viewName string) (uint, error) { // Lookup the entity ID of the configuration named in our properties. - conf, err := c.GetEntityByName(0, configName, ConfigType) + conf, err := c.GetEntityByName(ctx, 0, configName, ConfigType) if err != nil { return 0, err } - view, err := c.GetEntityByName(conf.ID, viewName, ViewType) + view, err := c.GetEntityByName(ctx, conf.ID, viewName, ViewType) if err != nil { return 0, err } @@ -252,7 +202,7 @@ func (c *Client) LookupViewID(configName, viewName string) (uint, error) { // LookupParentZoneID Return the entityId of the parent zone by recursing from the root view. // Also return the simple name of the host. -func (c *Client) LookupParentZoneID(viewID uint, fqdn string) (uint, string, error) { +func (c *Client) LookupParentZoneID(ctx context.Context, viewID uint, fqdn string) (uint, string, error) { if fqdn == "" { return viewID, "", nil } @@ -263,7 +213,7 @@ func (c *Client) LookupParentZoneID(viewID uint, fqdn string) (uint, string, err parentViewID := viewID for i := len(zones) - 1; i > -1; i-- { - zone, err := c.GetEntityByName(parentViewID, zones[i], ZoneType) + zone, err := c.GetEntityByName(ctx, parentViewID, zones[i], ZoneType) if err != nil { return 0, "", fmt.Errorf("could not find zone named %s: %w", name, err) } @@ -282,32 +232,39 @@ func (c *Client) LookupParentZoneID(viewID uint, fqdn string) (uint, string, err return parentViewID, name, nil } -// Send a REST request, using query parameters specified. -// The Authorization header will be set if we have an active auth token. -func (c *Client) sendRequest(method, resource string, payload interface{}, queryParams map[string]string) (*http.Response, error) { - url := fmt.Sprintf("%s/Services/REST/v1/%s", c.baseURL, resource) +func (c *Client) createEndpoint(resource string) *url.URL { + return c.baseURL.JoinPath("Services", "REST", "v1", resource) +} - body, err := json.Marshal(payload) - if err != nil { - return nil, err +func (c *Client) doAuthenticated(ctx context.Context, req *http.Request) (*http.Response, error) { + tok := getToken(ctx) + if tok != "" { + req.Header.Set(authorizationHeader, tok) } - req, err := http.NewRequest(method, url, bytes.NewReader(body)) - if err != nil { - return nil, err - } - - req.Header.Set("Content-Type", "application/json") - - if c.token != "" { - req.Header.Set("Authorization", c.token) - } - - q := req.URL.Query() - for k, v := range queryParams { - q.Set(k, v) - } - req.URL.RawQuery = q.Encode() - return c.HTTPClient.Do(req) } + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil +} diff --git a/providers/dns/bluecat/internal/client_test.go b/providers/dns/bluecat/internal/client_test.go index 072f6254..206d7d1a 100644 --- a/providers/dns/bluecat/internal/client_test.go +++ b/providers/dns/bluecat/internal/client_test.go @@ -1,6 +1,7 @@ package internal import ( + "context" "encoding/json" "net/http" "net/http/httptest" @@ -15,7 +16,8 @@ func TestClient_LookupParentZoneID(t *testing.T) { server := httptest.NewServer(mux) t.Cleanup(server.Close) - client := NewClient(server.URL) + client := NewClient(server.URL, "user", "secret") + client.HTTPClient = server.Client() mux.HandleFunc("/Services/REST/v1/getEntityByName", func(rw http.ResponseWriter, req *http.Request) { query := req.URL.Query() @@ -33,7 +35,7 @@ func TestClient_LookupParentZoneID(t *testing.T) { http.Error(rw, "{}", http.StatusOK) }) - parentID, name, err := client.LookupParentZoneID(2, "foo.example.com") + parentID, name, err := client.LookupParentZoneID(context.Background(), 2, "foo.example.com") require.NoError(t, err) assert.EqualValues(t, 2, parentID) diff --git a/providers/dns/bluecat/internal/identity.go b/providers/dns/bluecat/internal/identity.go new file mode 100644 index 00000000..425e9cd8 --- /dev/null +++ b/providers/dns/bluecat/internal/identity.go @@ -0,0 +1,115 @@ +package internal + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" +) + +type token string + +const tokenKey token = "token" + +// login Logs in as API user. +// Authenticates and receives a token to be used in for subsequent requests. +// https://docs.bluecatnetworks.com/r/Address-Manager-Legacy-v1-API-Guide/GET/v1/login/9.5.0 +func (c *Client) login(ctx context.Context) (string, error) { + endpoint := c.createEndpoint("login") + + q := endpoint.Query() + q.Set("username", c.username) + q.Set("password", c.password) + endpoint.RawQuery = q.Encode() + + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return "", err + } + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return "", errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return "", errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return "", errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + authResp := string(raw) + if strings.Contains(authResp, "Authentication Error") { + return "", fmt.Errorf("request failed: %s", strings.Trim(authResp, `"`)) + } + + // Upon success, API responds with "Session Token-> BAMAuthToken: dQfuRMTUxNjc3MjcyNDg1ODppcGFybXM= <- for User : username" + tok := c.tokenExp.FindString(authResp) + + return tok, nil +} + +// Logout Logs out of the current API session. +// https://docs.bluecatnetworks.com/r/Address-Manager-Legacy-v1-API-Guide/GET/v1/logout/9.5.0 +func (c *Client) Logout(ctx context.Context) error { + if getToken(ctx) == "" { + // nothing to do + return nil + } + + endpoint := c.createEndpoint("logout") + + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return err + } + + resp, err := c.doAuthenticated(ctx, req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + authResp := string(raw) + if !strings.Contains(authResp, "successfully") { + return fmt.Errorf("request failed to delete session: %s", strings.Trim(authResp, `"`)) + } + + return nil +} + +func (c *Client) CreateAuthenticatedContext(ctx context.Context) (context.Context, error) { + tok, err := c.login(ctx) + if err != nil { + return nil, err + } + + return context.WithValue(ctx, tokenKey, tok), nil +} + +func getToken(ctx context.Context) string { + tok, ok := ctx.Value(tokenKey).(string) + if !ok { + return "" + } + + return tok +} diff --git a/providers/dns/bluecat/internal/identity_test.go b/providers/dns/bluecat/internal/identity_test.go new file mode 100644 index 00000000..378f6ab3 --- /dev/null +++ b/providers/dns/bluecat/internal/identity_test.go @@ -0,0 +1,59 @@ +package internal + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const fakeToken = "BAMAuthToken: dQfuRMTUxNjc3MjcyNDg1ODppcGFybXM=" + +func TestClient_CreateAuthenticatedContext(t *testing.T) { + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + client := NewClient(server.URL, "user", "secret") + client.HTTPClient = server.Client() + + mux.HandleFunc("/Services/REST/v1/login", func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet { + http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest) + return + } + + query := req.URL.Query() + if query.Get("username") != "user" { + http.Error(rw, fmt.Sprintf("invalid username %s", query.Get("username")), http.StatusUnauthorized) + return + } + + if query.Get("password") != "secret" { + http.Error(rw, fmt.Sprintf("invalid password %s", query.Get("password")), http.StatusUnauthorized) + return + } + + _, _ = fmt.Fprint(rw, fakeToken) + }) + mux.HandleFunc("/Services/REST/v1/delete", func(rw http.ResponseWriter, req *http.Request) { + authorization := req.Header.Get(authorizationHeader) + if authorization != fakeToken { + http.Error(rw, fmt.Sprintf("invalid credential: %s", authorization), http.StatusUnauthorized) + return + } + }) + + ctx, err := client.CreateAuthenticatedContext(context.Background()) + require.NoError(t, err) + + at := getToken(ctx) + assert.Equal(t, fakeToken, at) + + err = client.Delete(ctx, 123) + require.NoError(t, err) +} diff --git a/providers/dns/bluecat/internal/types.go b/providers/dns/bluecat/internal/types.go index b3b7b412..5f1bf772 100644 --- a/providers/dns/bluecat/internal/types.go +++ b/providers/dns/bluecat/internal/types.go @@ -1,7 +1,5 @@ package internal -import "fmt" - // Entity JSON body for Bluecat entity requests. type Entity struct { ID string `json:"id,omitempty"` @@ -17,13 +15,3 @@ type EntityResponse struct { Type string `json:"type"` Properties string `json:"properties"` } - -type APIError struct { - StatusCode int - Resource string - Message string -} - -func (a APIError) Error() string { - return fmt.Sprintf("resource: %s, status code: %d, message: %s", a.Resource, a.StatusCode, a.Message) -} diff --git a/providers/dns/brandit/brandit.go b/providers/dns/brandit/brandit.go index 0b80f490..33af186c 100644 --- a/providers/dns/brandit/brandit.go +++ b/providers/dns/brandit/brandit.go @@ -1,9 +1,11 @@ package brandit import ( + "context" "errors" "fmt" "net/http" + "strconv" "sync" "time" @@ -12,8 +14,6 @@ import ( "github.com/go-acme/lego/v4/providers/dns/brandit/internal" ) -const defaultTTL = 600 - // Environment variables names. const ( envNamespace = "BRANDIT_" @@ -21,11 +21,10 @@ const ( EnvAPIKey = envNamespace + "API_KEY" EnvAPIUsername = envNamespace + "API_USERNAME" - EnvTTL = envNamespace + "TTL" - EnvPropagationTimeout = envNamespace + "PROPAGATION_TIMEOUT" - EnvPollingInterval = envNamespace + "POLLING_INTERVAL" - EnvHTTPTimeout = envNamespace + "HTTP_TIMEOUT" - DefaultBrandItPropagationTimeout = 600 * time.Second + EnvTTL = envNamespace + "TTL" + EnvPropagationTimeout = envNamespace + "PROPAGATION_TIMEOUT" + EnvPollingInterval = envNamespace + "POLLING_INTERVAL" + EnvHTTPTimeout = envNamespace + "HTTP_TIMEOUT" ) // Config is used to configure the creation of the DNSProvider. @@ -42,8 +41,8 @@ type Config struct { // NewDefaultConfig returns a default configuration for the DNSProvider. func NewDefaultConfig() *Config { return &Config{ - TTL: env.GetOrDefaultInt(EnvTTL, defaultTTL), - PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, DefaultBrandItPropagationTimeout), + TTL: env.GetOrDefaultInt(EnvTTL, 600), + PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, 10*time.Minute), PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, dns01.DefaultPollingInterval), HTTPClient: &http.Client{ Timeout: env.GetOrDefaultSecond(EnvHTTPTimeout, 30*time.Second), @@ -97,13 +96,19 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { }, nil } +// Timeout returns the timeout and interval to use when checking for DNS propagation. +// Adjusting here to cope with spikes in propagation times. +func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { + return d.config.PropagationTimeout, d.config.PollingInterval +} + // Present creates a TXT record using the specified parameters. func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("brandit: %w", err) + return fmt.Errorf("brandit: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) @@ -111,6 +116,8 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { return fmt.Errorf("brandit: %w", err) } + ctx := context.Background() + record := internal.Record{ Type: "TXT", Name: subDomain, @@ -119,18 +126,18 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { } // find the account associated with the domain - account, err := d.client.StatusDomain(dns01.UnFqdn(authZone)) + account, err := d.client.StatusDomain(ctx, dns01.UnFqdn(authZone)) if err != nil { return fmt.Errorf("brandit: status domain: %w", err) } // Find the next record id - recordID, err := d.client.ListRecords(account.Response.Registrar[0], dns01.UnFqdn(authZone)) + recordID, err := d.client.ListRecords(ctx, account.Registrar[0], dns01.UnFqdn(authZone)) if err != nil { return fmt.Errorf("brandit: list records: %w", err) } - result, err := d.client.AddRecord(dns01.UnFqdn(authZone), account.Response.Registrar[0], fmt.Sprint(recordID.Response.Total[0]), record) + result, err := d.client.AddRecord(ctx, dns01.UnFqdn(authZone), account.Registrar[0], strconv.Itoa(recordID.Total[0]), record) if err != nil { return fmt.Errorf("brandit: add record: %w", err) } @@ -148,7 +155,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("brandit: %w", err) + return fmt.Errorf("brandit: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } // gets the record's unique ID @@ -159,25 +166,27 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return fmt.Errorf("brandit: unknown record ID for '%s' '%s'", info.EffectiveFQDN, token) } + ctx := context.Background() + // find the account associated with the domain - account, err := d.client.StatusDomain(dns01.UnFqdn(authZone)) + account, err := d.client.StatusDomain(ctx, dns01.UnFqdn(authZone)) if err != nil { return fmt.Errorf("brandit: status domain: %w", err) } - records, err := d.client.ListRecords(account.Response.Registrar[0], dns01.UnFqdn(authZone)) + records, err := d.client.ListRecords(ctx, account.Registrar[0], dns01.UnFqdn(authZone)) if err != nil { return fmt.Errorf("brandit: list records: %w", err) } var recordID int - for i, r := range records.Response.RR { + for i, r := range records.RR { if r == dnsRecord { recordID = i } } - _, err = d.client.DeleteRecord(dns01.UnFqdn(authZone), account.Response.Registrar[0], dnsRecord, fmt.Sprint(recordID)) + err = d.client.DeleteRecord(ctx, dns01.UnFqdn(authZone), account.Registrar[0], dnsRecord, strconv.Itoa(recordID)) if err != nil { return fmt.Errorf("brandit: delete record: %w", err) } @@ -189,9 +198,3 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return nil } - -// Timeout returns the timeout and interval to use when checking for DNS propagation. -// Adjusting here to cope with spikes in propagation times. -func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { - return d.config.PropagationTimeout, d.config.PollingInterval -} diff --git a/providers/dns/brandit/brandit.toml b/providers/dns/brandit/brandit.toml index 07346c2b..acf61bd7 100644 --- a/providers/dns/brandit/brandit.toml +++ b/providers/dns/brandit/brandit.toml @@ -1,4 +1,4 @@ -Name = "BRANDIT" +Name = "Brandit" Description = '''''' URL = "https://www.brandit.com/" Code = "brandit" diff --git a/providers/dns/brandit/internal/client.go b/providers/dns/brandit/internal/client.go index d145315d..12e28fdf 100644 --- a/providers/dns/brandit/internal/client.go +++ b/providers/dns/brandit/internal/client.go @@ -1,6 +1,7 @@ package internal import ( + "context" "crypto/hmac" "crypto/sha256" "encoding/hex" @@ -12,6 +13,8 @@ import ( "net/url" "strings" "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) const defaultBaseURL = "https://portal.brandit.com/api/v3/" @@ -20,8 +23,9 @@ const defaultBaseURL = "https://portal.brandit.com/api/v3/" type Client struct { apiUsername string apiKey string - BaseURL string - HTTPClient *http.Client + + baseURL string + HTTPClient *http.Client } // NewClient creates a new Client. @@ -33,70 +37,69 @@ func NewClient(apiUsername, apiKey string) (*Client, error) { return &Client{ apiUsername: apiUsername, apiKey: apiKey, - BaseURL: defaultBaseURL, + baseURL: defaultBaseURL, HTTPClient: &http.Client{Timeout: 10 * time.Second}, }, nil } // ListRecords lists all records. // https://portal.brandit.com/apidocv3#listDNSRR -func (c *Client) ListRecords(account, dnsZone string) (*ListRecords, error) { - // Create a new query +func (c *Client) ListRecords(ctx context.Context, account, dnsZone string) (*ListRecordsResponse, error) { query := url.Values{} query.Add("command", "listDNSRR") query.Add("account", account) query.Add("dnszone", dnsZone) - result := &ListRecords{} + result := &Response[*ListRecordsResponse]{} - err := c.do(query, result) + err := c.do(ctx, query, result) if err != nil { - return nil, fmt.Errorf("do: %w", err) + return nil, err } for len(result.Response.RR) < result.Response.Total[0] { query.Add("first", fmt.Sprint(result.Response.Last[0]+1)) - tmp := &ListRecords{} - err := c.do(query, tmp) + tmp := &Response[*ListRecordsResponse]{} + err := c.do(ctx, query, tmp) if err != nil { - return nil, fmt.Errorf("do: %w", err) + return nil, err } result.Response.RR = append(result.Response.RR, tmp.Response.RR...) result.Response.Last = tmp.Response.Last } - return result, nil + return result.Response, nil } // AddRecord adds a DNS record. // https://portal.brandit.com/apidocv3#addDNSRR -func (c *Client) AddRecord(domainName, account, newRecordID string, record Record) (*AddRecord, error) { - // Create a new query +func (c *Client) AddRecord(ctx context.Context, domainName, account, newRecordID string, record Record) (*AddRecord, error) { + value := strings.Join([]string{record.Name, fmt.Sprint(record.TTL), "IN", record.Type, record.Content}, " ") query := url.Values{} query.Add("command", "addDNSRR") query.Add("account", account) query.Add("dnszone", domainName) - query.Add("rrdata", strings.Join([]string{record.Name, fmt.Sprint(record.TTL), "IN", record.Type, record.Content}, " ")) + query.Add("rrdata", value) query.Add("key", newRecordID) result := &AddRecord{} - err := c.do(query, result) + err := c.do(ctx, query, result) if err != nil { - return nil, fmt.Errorf("do: %w", err) + return nil, err } - result.Record = strings.Join([]string{record.Name, fmt.Sprint(record.TTL), "IN", record.Type, record.Content}, " ") + + result.Record = value return result, nil } // DeleteRecord deletes a DNS record. // https://portal.brandit.com/apidocv3#deleteDNSRR -func (c *Client) DeleteRecord(domainName, account, dnsRecord, recordID string) (*DeleteRecord, error) { - // Create a new query +func (c *Client) DeleteRecord(ctx context.Context, domainName, account, dnsRecord, recordID string) error { query := url.Values{} query.Add("command", "deleteDNSRR") query.Add("account", account) @@ -104,68 +107,70 @@ func (c *Client) DeleteRecord(domainName, account, dnsRecord, recordID string) ( query.Add("rrdata", dnsRecord) query.Add("key", recordID) - result := &DeleteRecord{} - - err := c.do(query, result) - if err != nil { - return nil, fmt.Errorf("do: %w", err) - } - - return result, nil + return c.do(ctx, query, nil) } // StatusDomain returns the status of a domain and account associated with it. // https://portal.brandit.com/apidocv3#statusDomain -func (c *Client) StatusDomain(domain string) (*StatusDomain, error) { - // Create a new query +func (c *Client) StatusDomain(ctx context.Context, domain string) (*StatusResponse, error) { query := url.Values{} query.Add("command", "statusDomain") query.Add("domain", domain) - result := &StatusDomain{} + result := &Response[*StatusResponse]{} - err := c.do(query, result) + err := c.do(ctx, query, result) if err != nil { - return nil, fmt.Errorf("do: %w", err) + return nil, err } - return result, nil + return result.Response, nil } -func (c *Client) do(query url.Values, result any) error { - // Add signature - v, err := sign(c.apiUsername, c.apiKey, query) - if err != nil { - return fmt.Errorf("signature: %w", err) - } - - resp, err := c.HTTPClient.PostForm(c.BaseURL, v) +func (c *Client) do(ctx context.Context, query url.Values, result any) error { + values, err := sign(c.apiUsername, c.apiKey, query) if err != nil { return err } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL, strings.NewReader(values.Encode())) + if err != nil { + return fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + defer func() { _ = resp.Body.Close() }() raw, err := io.ReadAll(resp.Body) if err != nil { - return fmt.Errorf("read response body: %w", err) + return errutils.NewReadResponseError(req, resp.StatusCode, err) } // Unmarshal the error response, because the API returns a 200 OK even if there is an error. var apiError APIError err = json.Unmarshal(raw, &apiError) if err != nil { - return fmt.Errorf("unmarshal error response: %w %s", err, string(raw)) + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) } if apiError.Code > 299 || apiError.Status != "success" { return apiError } + if result == nil { + return nil + } + err = json.Unmarshal(raw, result) if err != nil { - return fmt.Errorf("unmarshal response body: %w %s", err, string(raw)) + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) } return nil diff --git a/providers/dns/brandit/internal/client_test.go b/providers/dns/brandit/internal/client_test.go index 6e75294d..a37e51a2 100644 --- a/providers/dns/brandit/internal/client_test.go +++ b/providers/dns/brandit/internal/client_test.go @@ -1,30 +1,32 @@ package internal import ( + "context" "io" "net/http" "net/http/httptest" "os" + "path/filepath" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func setupTest(t *testing.T, file string) *Client { +func setupTest(t *testing.T, filename string) *Client { t.Helper() server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - open, err := os.Open(file) + file, err := os.Open(filepath.Join("fixtures", filename)) if err != nil { http.Error(rw, err.Error(), http.StatusInternalServerError) return } - defer func() { _ = open.Close() }() + defer func() { _ = file.Close() }() rw.WriteHeader(http.StatusOK) - _, err = io.Copy(rw, open) + _, err = io.Copy(rw, file) if err != nil { http.Error(rw, err.Error(), http.StatusInternalServerError) return @@ -36,78 +38,82 @@ func setupTest(t *testing.T, file string) *Client { require.NoError(t, err) client.HTTPClient = server.Client() - client.BaseURL = server.URL + client.baseURL = server.URL return client } func TestClient_StatusDomain(t *testing.T) { - client := setupTest(t, "./fixtures/status-domain.json") + client := setupTest(t, "status-domain.json") - domain, err := client.StatusDomain("example.com") + domain, err := client.StatusDomain(context.Background(), "example.com") require.NoError(t, err) - expected := &StatusDomain{ - Response: StatusResponse{ - RenewalMode: []string{"DEFAULT"}, - Status: []string{"clientTransferProhibited"}, - TransferLock: []int{1}, - Registrar: []string{"brandit"}, - PaidUntilDate: []string{"2021-12-15 05:00:00.0"}, - Nameserver: []string{"NS1.RRPPROXY.NET", "NS2.RRPPROXY.NET"}, - RegistrationExpirationDate: []string{"2021-12-15 05:00:00.0"}, - Domain: []string{"example.com"}, - RenewalDate: []string{"2024-01-19 05:00:00.0"}, - UpdatedDate: []string{"2022-12-16 08:01:27.0"}, - BillingContact: []string{"example"}, - XDomainRoID: []string{"example"}, - AdminContact: []string{"example"}, - TechContact: []string{"example"}, - DomainIDN: []string{"example.com"}, - CreatedDate: []string{"2016-12-16 05:00:00.0"}, - RegistrarTransferDate: []string{"2021-12-09 05:17:42.0"}, - Zone: []string{"com"}, - Auth: []string{"example"}, - UpdatedBy: []string{"example"}, - RoID: []string{"example"}, - OwnerContact: []string{"example"}, - CreatedBy: []string{"example"}, - TransferMode: []string{"auto"}, - }, - Code: 200, - Status: "success", - Error: "", + expected := &StatusResponse{ + RenewalMode: []string{"DEFAULT"}, + Status: []string{"clientTransferProhibited"}, + TransferLock: []int{1}, + Registrar: []string{"brandit"}, + PaidUntilDate: []string{"2021-12-15 05:00:00.0"}, + Nameserver: []string{"NS1.RRPPROXY.NET", "NS2.RRPPROXY.NET"}, + RegistrationExpirationDate: []string{"2021-12-15 05:00:00.0"}, + Domain: []string{"example.com"}, + RenewalDate: []string{"2024-01-19 05:00:00.0"}, + UpdatedDate: []string{"2022-12-16 08:01:27.0"}, + BillingContact: []string{"example"}, + XDomainRoID: []string{"example"}, + AdminContact: []string{"example"}, + TechContact: []string{"example"}, + DomainIDN: []string{"example.com"}, + CreatedDate: []string{"2016-12-16 05:00:00.0"}, + RegistrarTransferDate: []string{"2021-12-09 05:17:42.0"}, + Zone: []string{"com"}, + Auth: []string{"example"}, + UpdatedBy: []string{"example"}, + RoID: []string{"example"}, + OwnerContact: []string{"example"}, + CreatedBy: []string{"example"}, + TransferMode: []string{"auto"}, } assert.Equal(t, expected, domain) } -func TestClient_ListRecords(t *testing.T) { - client := setupTest(t, "./fixtures/list-records.json") +func TestClient_StatusDomain_error(t *testing.T) { + client := setupTest(t, "error.json") - resp, err := client.ListRecords("example", "example.com") + _, err := client.StatusDomain(context.Background(), "example.com") + require.ErrorIs(t, err, APIError{Code: 402, Status: "error", Message: "Invalid user."}) +} + +func TestClient_ListRecords(t *testing.T) { + client := setupTest(t, "list-records.json") + + resp, err := client.ListRecords(context.Background(), "example", "example.com") require.NoError(t, err) - expected := &ListRecords{ - Response: ListRecordsResponse{ - Limit: []int{100}, - Column: []string{"rr"}, - Count: []int{1}, - First: []int{0}, - Total: []int{1}, - RR: []string{"example.com. 600 IN TXT txttxttxt"}, - Last: []int{0}, - }, - Code: 200, - Status: "success", - Error: "", + expected := &ListRecordsResponse{ + Limit: []int{100}, + Column: []string{"rr"}, + Count: []int{1}, + First: []int{0}, + Total: []int{1}, + RR: []string{"example.com. 600 IN TXT txttxttxt"}, + Last: []int{0}, } assert.Equal(t, expected, resp) } +func TestClient_ListRecords_error(t *testing.T) { + client := setupTest(t, "error.json") + + _, err := client.ListRecords(context.Background(), "example", "example.com") + require.ErrorIs(t, err, APIError{Code: 402, Status: "error", Message: "Invalid user."}) +} + func TestClient_AddRecord(t *testing.T) { - client := setupTest(t, "./fixtures/add-record.json") + client := setupTest(t, "add-record.json") testRecord := Record{ ID: 2565, @@ -116,7 +122,7 @@ func TestClient_AddRecord(t *testing.T) { Content: "txttxttxt", TTL: 600, } - resp, err := client.AddRecord("example.com", "test", "2565", testRecord) + resp, err := client.AddRecord(context.Background(), "example.com", "test", "2565", testRecord) require.NoError(t, err) expected := &AddRecord{ @@ -133,17 +139,31 @@ func TestClient_AddRecord(t *testing.T) { assert.Equal(t, expected, resp) } -func TestClient_DeleteRecord(t *testing.T) { - client := setupTest(t, "./fixtures/delete-record.json") +func TestClient_AddRecord_error(t *testing.T) { + client := setupTest(t, "error.json") - resp, err := client.DeleteRecord("example.com", "test", "example.com 600 IN TXT txttxttxt", "2374") - require.NoError(t, err) - - expected := &DeleteRecord{ - Code: 200, - Status: "success", - Error: "", + testRecord := Record{ + ID: 2565, + Type: "TXT", + Name: "example.com", + Content: "txttxttxt", + TTL: 600, } - assert.Equal(t, expected, resp) + _, err := client.AddRecord(context.Background(), "example.com", "test", "2565", testRecord) + require.ErrorIs(t, err, APIError{Code: 402, Status: "error", Message: "Invalid user."}) +} + +func TestClient_DeleteRecord(t *testing.T) { + client := setupTest(t, "delete-record.json") + + err := client.DeleteRecord(context.Background(), "example.com", "test", "example.com 600 IN TXT txttxttxt", "2374") + require.NoError(t, err) +} + +func TestClient_DeleteRecord_error(t *testing.T) { + client := setupTest(t, "error.json") + + err := client.DeleteRecord(context.Background(), "example.com", "test", "example.com 600 IN TXT txttxttxt", "2374") + require.ErrorIs(t, err, APIError{Code: 402, Status: "error", Message: "Invalid user."}) } diff --git a/providers/dns/brandit/internal/fixtures/error.json b/providers/dns/brandit/internal/fixtures/error.json new file mode 100644 index 00000000..63bc2abd --- /dev/null +++ b/providers/dns/brandit/internal/fixtures/error.json @@ -0,0 +1,5 @@ +{ + "code": 402, + "status": "error", + "error": "Invalid user." +} diff --git a/providers/dns/brandit/internal/types.go b/providers/dns/brandit/internal/types.go index 099c5fe1..a0a5e50b 100644 --- a/providers/dns/brandit/internal/types.go +++ b/providers/dns/brandit/internal/types.go @@ -2,11 +2,11 @@ package internal import "fmt" -type StatusDomain struct { - Response StatusResponse `json:"response,omitempty"` - Code int `json:"code"` - Status string `json:"status"` - Error string `json:"error"` +type Response[T any] struct { + Response T `json:"response,omitempty"` + Code int `json:"code"` + Status string `json:"status"` + Error string `json:"error"` } type StatusResponse struct { @@ -36,13 +36,6 @@ type StatusResponse struct { TransferMode []string `json:"transfermode"` } -type ListRecords struct { - Response ListRecordsResponse `json:"response,omitempty"` - Code int `json:"code"` - Status string `json:"status"` - Error string `json:"error"` -} - type ListRecordsResponse struct { Limit []int `json:"limit,omitempty"` Column []string `json:"column,omitempty"` @@ -83,9 +76,3 @@ type Record struct { Content string `json:"content,omitempty"` TTL int `json:"ttl,omitempty"` // default 600 } - -type DeleteRecord struct { - Code int `json:"code"` - Status string `json:"status"` - Error string `json:"error"` -} diff --git a/providers/dns/checkdomain/checkdomain.go b/providers/dns/checkdomain/checkdomain.go index 2a9787ba..7228fe29 100644 --- a/providers/dns/checkdomain/checkdomain.go +++ b/providers/dns/checkdomain/checkdomain.go @@ -2,15 +2,16 @@ package checkdomain import ( + "context" "errors" "fmt" "net/http" "net/url" - "sync" "time" "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/platform/config/env" + "github.com/go-acme/lego/v4/providers/dns/checkdomain/internal" ) // Environment variables names. @@ -26,11 +27,6 @@ const ( EnvHTTPTimeout = envNamespace + "HTTP_TIMEOUT" ) -const ( - defaultEndpoint = "https://api.checkdomain.de" - defaultTTL = 300 -) - // Config is used to configure the creation of the DNSProvider. type Config struct { Endpoint *url.URL @@ -44,7 +40,7 @@ type Config struct { // NewDefaultConfig returns a default configuration for the DNSProvider. func NewDefaultConfig() *Config { return &Config{ - TTL: env.GetOrDefaultInt(EnvTTL, defaultTTL), + TTL: env.GetOrDefaultInt(EnvTTL, 300), PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, 5*time.Minute), PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, 7*time.Second), HTTPClient: &http.Client{ @@ -56,9 +52,7 @@ func NewDefaultConfig() *Config { // DNSProvider implements the challenge.Provider interface. type DNSProvider struct { config *Config - - domainIDMu sync.Mutex - domainIDMapping map[string]int + client *internal.Client } // NewDNSProvider returns a DNSProvider instance configured for CheckDomain. @@ -71,7 +65,7 @@ func NewDNSProvider() (*DNSProvider, error) { config := NewDefaultConfig() config.Token = values[EnvToken] - endpoint, err := url.Parse(env.GetOrDefaultString(EnvEndpoint, defaultEndpoint)) + endpoint, err := url.Parse(env.GetOrDefaultString(EnvEndpoint, internal.DefaultEndpoint)) if err != nil { return nil, fmt.Errorf("checkdomain: invalid %s: %w", EnvEndpoint, err) } @@ -89,32 +83,33 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("checkdomain: missing token") } - if config.HTTPClient == nil { - config.HTTPClient = http.DefaultClient + client := internal.NewClient(internal.OAuthStaticAccessToken(config.HTTPClient, config.Token)) + + if config.Endpoint != nil { + client.BaseURL = config.Endpoint } - return &DNSProvider{ - config: config, - domainIDMapping: make(map[string]int), - }, nil + return &DNSProvider{config: config, client: client}, nil } // Present creates a TXT record to fulfill the dns-01 challenge. func (d *DNSProvider) Present(domain, token, keyAuth string) error { + ctx := context.Background() + // TODO(ldez) replace domain by FQDN to follow CNAME. - domainID, err := d.getDomainIDByName(domain) + domainID, err := d.client.GetDomainIDByName(ctx, domain) if err != nil { return fmt.Errorf("checkdomain: %w", err) } - err = d.checkNameservers(domainID) + err = d.client.CheckNameservers(ctx, domainID) if err != nil { return fmt.Errorf("checkdomain: %w", err) } info := dns01.GetChallengeInfo(domain, keyAuth) - err = d.createRecord(domainID, &Record{ + err = d.client.CreateRecord(ctx, domainID, &internal.Record{ Name: info.EffectiveFQDN, TTL: d.config.TTL, Type: "TXT", @@ -130,28 +125,28 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { // CleanUp removes the TXT record previously created. func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { + ctx := context.Background() + // TODO(ldez) replace domain by FQDN to follow CNAME. - domainID, err := d.getDomainIDByName(domain) + domainID, err := d.client.GetDomainIDByName(ctx, domain) if err != nil { return fmt.Errorf("checkdomain: %w", err) } - err = d.checkNameservers(domainID) + err = d.client.CheckNameservers(ctx, domainID) if err != nil { return fmt.Errorf("checkdomain: %w", err) } info := dns01.GetChallengeInfo(domain, keyAuth) - err = d.deleteTXTRecord(domainID, info.EffectiveFQDN, info.Value) + defer d.client.CleanCache(info.EffectiveFQDN) + + err = d.client.DeleteTXTRecord(ctx, domainID, info.EffectiveFQDN, info.Value) if err != nil { return fmt.Errorf("checkdomain: %w", err) } - d.domainIDMu.Lock() - delete(d.domainIDMapping, info.EffectiveFQDN) - d.domainIDMu.Unlock() - return nil } diff --git a/providers/dns/checkdomain/checkdomain_test.go b/providers/dns/checkdomain/checkdomain_test.go index eb9f05d3..b94f9397 100644 --- a/providers/dns/checkdomain/checkdomain_test.go +++ b/providers/dns/checkdomain/checkdomain_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/go-acme/lego/v4/platform/tester" + "github.com/go-acme/lego/v4/providers/dns/checkdomain/internal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -83,7 +84,7 @@ func TestNewDNSProviderConfig(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { config := NewDefaultConfig() - config.Endpoint, _ = url.Parse(defaultEndpoint) + config.Endpoint, _ = url.Parse(internal.DefaultEndpoint) if test.token != "" { config.Token = test.token diff --git a/providers/dns/checkdomain/client.go b/providers/dns/checkdomain/client.go deleted file mode 100644 index 8b401a7e..00000000 --- a/providers/dns/checkdomain/client.go +++ /dev/null @@ -1,416 +0,0 @@ -package checkdomain - -import ( - "bytes" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "strconv" - "strings" -) - -const ( - ns1 = "ns.checkdomain.de" - ns2 = "ns2.checkdomain.de" -) - -const domainNotFound = -1 - -// max page limit that the checkdomain api allows. -const maxLimit = 100 - -// max integer value. -const maxInt = int((^uint(0)) >> 1) - -type ( - // Some fields have been omitted from the structs - // because they are not required for this application. - - DomainListingResponse struct { - Page int `json:"page"` - Limit int `json:"limit"` - Pages int `json:"pages"` - Total int `json:"total"` - Embedded EmbeddedDomainList `json:"_embedded"` - } - - EmbeddedDomainList struct { - Domains []*Domain `json:"domains"` - } - - Domain struct { - ID int `json:"id"` - Name string `json:"name"` - } - - DomainResponse struct { - ID int `json:"id"` - Name string `json:"name"` - Created string `json:"created"` - PaidUp string `json:"payed_up"` - Active bool `json:"active"` - } - - NameserverResponse struct { - General NameserverGeneral `json:"general"` - Nameservers []*Nameserver `json:"nameservers"` - SOA NameserverSOA `json:"soa"` - } - - NameserverGeneral struct { - IPv4 string `json:"ip_v4"` - IPv6 string `json:"ip_v6"` - IncludeWWW bool `json:"include_www"` - } - - NameserverSOA struct { - Mail string `json:"mail"` - Refresh int `json:"refresh"` - Retry int `json:"retry"` - Expiry int `json:"expiry"` - TTL int `json:"ttl"` - } - - Nameserver struct { - Name string `json:"name"` - } - - RecordListingResponse struct { - Page int `json:"page"` - Limit int `json:"limit"` - Pages int `json:"pages"` - Total int `json:"total"` - Embedded EmbeddedRecordList `json:"_embedded"` - } - - EmbeddedRecordList struct { - Records []*Record `json:"records"` - } - - Record struct { - Name string `json:"name"` - Value string `json:"value"` - TTL int `json:"ttl"` - Priority int `json:"priority"` - Type string `json:"type"` - } -) - -func (d *DNSProvider) getDomainIDByName(name string) (int, error) { - // Load from cache if exists - d.domainIDMu.Lock() - id, ok := d.domainIDMapping[name] - d.domainIDMu.Unlock() - if ok { - return id, nil - } - - // Find out by querying API - domains, err := d.listDomains() - if err != nil { - return domainNotFound, err - } - - // Linear search over all registered domains - for _, domain := range domains { - if domain.Name == name || strings.HasSuffix(name, "."+domain.Name) { - d.domainIDMu.Lock() - d.domainIDMapping[name] = domain.ID - d.domainIDMu.Unlock() - - return domain.ID, nil - } - } - - return domainNotFound, errors.New("domain not found") -} - -func (d *DNSProvider) listDomains() ([]*Domain, error) { - req, err := d.makeRequest(http.MethodGet, "/v1/domains", http.NoBody) - if err != nil { - return nil, fmt.Errorf("failed to make request: %w", err) - } - - // Checkdomain also provides a query param 'query' which allows filtering domains for a string. - // But that functionality is kinda broken, - // so we scan through the whole list of registered domains to later find the one that is of interest to us. - q := req.URL.Query() - q.Set("limit", strconv.Itoa(maxLimit)) - - currentPage := 1 - totalPages := maxInt - - var domainList []*Domain - for currentPage <= totalPages { - q.Set("page", strconv.Itoa(currentPage)) - req.URL.RawQuery = q.Encode() - - var res DomainListingResponse - if err := d.sendRequest(req, &res); err != nil { - return nil, fmt.Errorf("failed to send domain listing request: %w", err) - } - - // This is the first response, - // so we update totalPages and allocate the slice memory. - if totalPages == maxInt { - totalPages = res.Pages - domainList = make([]*Domain, 0, res.Total) - } - - domainList = append(domainList, res.Embedded.Domains...) - currentPage++ - } - - return domainList, nil -} - -func (d *DNSProvider) getNameserverInfo(domainID int) (*NameserverResponse, error) { - req, err := d.makeRequest(http.MethodGet, fmt.Sprintf("/v1/domains/%d/nameservers", domainID), http.NoBody) - if err != nil { - return nil, err - } - - res := &NameserverResponse{} - if err := d.sendRequest(req, res); err != nil { - return nil, err - } - - return res, nil -} - -func (d *DNSProvider) checkNameservers(domainID int) error { - info, err := d.getNameserverInfo(domainID) - if err != nil { - return err - } - - var found1, found2 bool - for _, item := range info.Nameservers { - switch item.Name { - case ns1: - found1 = true - case ns2: - found2 = true - } - } - - if !found1 || !found2 { - return errors.New("not using checkdomain nameservers, can not update records") - } - - return nil -} - -func (d *DNSProvider) createRecord(domainID int, record *Record) error { - bs, err := json.Marshal(record) - if err != nil { - return fmt.Errorf("encoding record failed: %w", err) - } - - req, err := d.makeRequest(http.MethodPost, fmt.Sprintf("/v1/domains/%d/nameservers/records", domainID), bytes.NewReader(bs)) - if err != nil { - return err - } - - return d.sendRequest(req, nil) -} - -// Checkdomain doesn't seem provide a way to delete records but one can replace all records at once. -// The current solution is to fetch all records and then use that list minus the record deleted as the new record list. -// TODO: Simplify this function once Checkdomain do provide the functionality. -func (d *DNSProvider) deleteTXTRecord(domainID int, recordName, recordValue string) error { - domainInfo, err := d.getDomainInfo(domainID) - if err != nil { - return err - } - - nsInfo, err := d.getNameserverInfo(domainID) - if err != nil { - return err - } - - allRecords, err := d.listRecords(domainID, "") - if err != nil { - return err - } - - recordName = strings.TrimSuffix(recordName, "."+domainInfo.Name+".") - - var recordsToKeep []*Record - - // Find and delete matching records - for _, record := range allRecords { - if skipRecord(recordName, recordValue, record, nsInfo) { - continue - } - - // Checkdomain API can return records without any TTL set (indicated by the value of 0). - // The API Call to replace the records would fail if we wouldn't specify a value. - // Thus, we use the default TTL queried beforehand - if record.TTL == 0 { - record.TTL = nsInfo.SOA.TTL - } - - recordsToKeep = append(recordsToKeep, record) - } - - return d.replaceRecords(domainID, recordsToKeep) -} - -func (d *DNSProvider) getDomainInfo(domainID int) (*DomainResponse, error) { - req, err := d.makeRequest(http.MethodGet, fmt.Sprintf("/v1/domains/%d", domainID), http.NoBody) - if err != nil { - return nil, err - } - - var res DomainResponse - err = d.sendRequest(req, &res) - if err != nil { - return nil, err - } - - return &res, nil -} - -func (d *DNSProvider) listRecords(domainID int, recordType string) ([]*Record, error) { - req, err := d.makeRequest(http.MethodGet, fmt.Sprintf("/v1/domains/%d/nameservers/records", domainID), http.NoBody) - if err != nil { - return nil, fmt.Errorf("failed to make request: %w", err) - } - - q := req.URL.Query() - q.Set("limit", strconv.Itoa(maxLimit)) - if recordType != "" { - q.Set("type", recordType) - } - - currentPage := 1 - totalPages := maxInt - - var recordList []*Record - for currentPage <= totalPages { - q.Set("page", strconv.Itoa(currentPage)) - req.URL.RawQuery = q.Encode() - - var res RecordListingResponse - if err := d.sendRequest(req, &res); err != nil { - return nil, fmt.Errorf("failed to send record listing request: %w", err) - } - - // This is the first response, so we update totalPages and allocate the slice memory. - if totalPages == maxInt { - totalPages = res.Pages - recordList = make([]*Record, 0, res.Total) - } - - recordList = append(recordList, res.Embedded.Records...) - currentPage++ - } - - return recordList, nil -} - -func (d *DNSProvider) replaceRecords(domainID int, records []*Record) error { - bs, err := json.Marshal(records) - if err != nil { - return fmt.Errorf("encoding record failed: %w", err) - } - - req, err := d.makeRequest(http.MethodPut, fmt.Sprintf("/v1/domains/%d/nameservers/records", domainID), bytes.NewReader(bs)) - if err != nil { - return err - } - - return d.sendRequest(req, nil) -} - -func skipRecord(recordName, recordValue string, record *Record, nsInfo *NameserverResponse) bool { - // Skip empty records - if record.Value == "" { - return true - } - - // Skip some special records, otherwise we would get a "Nameserver update failed" - if record.Type == "SOA" || record.Type == "NS" || record.Name == "@" || (nsInfo.General.IncludeWWW && record.Name == "www") { - return true - } - - nameMatch := recordName == "" || record.Name == recordName - valueMatch := recordValue == "" || record.Value == recordValue - - // Skip our matching record - if record.Type == "TXT" && nameMatch && valueMatch { - return true - } - - return false -} - -func (d *DNSProvider) makeRequest(method, resource string, body io.Reader) (*http.Request, error) { - uri, err := d.config.Endpoint.Parse(resource) - if err != nil { - return nil, err - } - - req, err := http.NewRequest(method, uri.String(), body) - if err != nil { - return nil, err - } - - req.Header.Set("Accept", "application/json") - req.Header.Set("Authorization", "Bearer "+d.config.Token) - if method != http.MethodGet { - req.Header.Set("Content-Type", "application/json") - } - - return req, nil -} - -func (d *DNSProvider) sendRequest(req *http.Request, result interface{}) error { - resp, err := d.config.HTTPClient.Do(req) - if err != nil { - return err - } - - if err = checkResponse(resp); err != nil { - return err - } - - defer func() { _ = resp.Body.Close() }() - - if result == nil { - return nil - } - - raw, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - - err = json.Unmarshal(raw, result) - if err != nil { - return fmt.Errorf("unmarshaling %T error [status code=%d]: %w: %s", result, resp.StatusCode, err, string(raw)) - } - return nil -} - -func checkResponse(resp *http.Response) error { - if resp.StatusCode < http.StatusBadRequest { - return nil - } - - if resp.Body == nil { - return fmt.Errorf("response body is nil, status code=%d", resp.StatusCode) - } - - defer func() { _ = resp.Body.Close() }() - - raw, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("unable to read body: status code=%d, error=%w", resp.StatusCode, err) - } - - return fmt.Errorf("status code=%d: %s", resp.StatusCode, string(raw)) -} diff --git a/providers/dns/checkdomain/internal/client.go b/providers/dns/checkdomain/internal/client.go new file mode 100644 index 00000000..74189dee --- /dev/null +++ b/providers/dns/checkdomain/internal/client.go @@ -0,0 +1,383 @@ +package internal + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" + "sync" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" + "golang.org/x/oauth2" +) + +const ( + ns1 = "ns.checkdomain.de" + ns2 = "ns2.checkdomain.de" +) + +// DefaultEndpoint the default API endpoint. +const DefaultEndpoint = "https://api.checkdomain.de" + +const domainNotFound = -1 + +// max page limit that the checkdomain api allows. +const maxLimit = 100 + +// max integer value. +const maxInt = int((^uint(0)) >> 1) + +// Client the Autodns API client. +type Client struct { + domainIDMapping map[string]int + domainIDMu sync.Mutex + + BaseURL *url.URL + httpClient *http.Client +} + +// NewClient creates a new Client. +func NewClient(hc *http.Client) *Client { + baseURL, _ := url.Parse(DefaultEndpoint) + + if hc == nil { + hc = &http.Client{Timeout: 10 * time.Second} + } + + return &Client{ + BaseURL: baseURL, + httpClient: hc, + domainIDMapping: make(map[string]int), + } +} + +func (c *Client) GetDomainIDByName(ctx context.Context, name string) (int, error) { + // Load from cache if exists + c.domainIDMu.Lock() + id, ok := c.domainIDMapping[name] + c.domainIDMu.Unlock() + if ok { + return id, nil + } + + // Find out by querying API + domains, err := c.listDomains(ctx) + if err != nil { + return domainNotFound, err + } + + // Linear search over all registered domains + for _, domain := range domains { + if domain.Name == name || strings.HasSuffix(name, "."+domain.Name) { + c.domainIDMu.Lock() + c.domainIDMapping[name] = domain.ID + c.domainIDMu.Unlock() + + return domain.ID, nil + } + } + + return domainNotFound, errors.New("domain not found") +} + +func (c *Client) listDomains(ctx context.Context) ([]*Domain, error) { + endpoint := c.BaseURL.JoinPath("v1", "domains") + + // Checkdomain also provides a query param 'query' which allows filtering domains for a string. + // But that functionality is kinda broken, + // so we scan through the whole list of registered domains to later find the one that is of interest to us. + q := endpoint.Query() + q.Set("limit", strconv.Itoa(maxLimit)) + + currentPage := 1 + totalPages := maxInt + + var domainList []*Domain + for currentPage <= totalPages { + q.Set("page", strconv.Itoa(currentPage)) + endpoint.RawQuery = q.Encode() + + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, fmt.Errorf("failed to make request: %w", err) + } + + var res DomainListingResponse + if err := c.do(req, &res); err != nil { + return nil, fmt.Errorf("failed to send domain listing request: %w", err) + } + + // This is the first response, + // so we update totalPages and allocate the slice memory. + if totalPages == maxInt { + totalPages = res.Pages + domainList = make([]*Domain, 0, res.Total) + } + + domainList = append(domainList, res.Embedded.Domains...) + currentPage++ + } + + return domainList, nil +} + +func (c *Client) getNameserverInfo(ctx context.Context, domainID int) (*NameserverResponse, error) { + endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers") + + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, err + } + + res := &NameserverResponse{} + if err := c.do(req, res); err != nil { + return nil, err + } + + return res, nil +} + +func (c *Client) CheckNameservers(ctx context.Context, domainID int) error { + info, err := c.getNameserverInfo(ctx, domainID) + if err != nil { + return err + } + + var found1, found2 bool + for _, item := range info.Nameservers { + switch item.Name { + case ns1: + found1 = true + case ns2: + found2 = true + } + } + + if !found1 || !found2 { + return errors.New("not using checkdomain nameservers, can not update records") + } + + return nil +} + +func (c *Client) CreateRecord(ctx context.Context, domainID int, record *Record) error { + endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers", "records") + + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record) + if err != nil { + return err + } + + return c.do(req, nil) +} + +// DeleteTXTRecord Checkdomain doesn't seem provide a way to delete records but one can replace all records at once. +// The current solution is to fetch all records and then use that list minus the record deleted as the new record list. +// TODO: Simplify this function once Checkdomain do provide the functionality. +func (c *Client) DeleteTXTRecord(ctx context.Context, domainID int, recordName, recordValue string) error { + domainInfo, err := c.getDomainInfo(ctx, domainID) + if err != nil { + return err + } + + nsInfo, err := c.getNameserverInfo(ctx, domainID) + if err != nil { + return err + } + + allRecords, err := c.listRecords(ctx, domainID, "") + if err != nil { + return err + } + + recordName = strings.TrimSuffix(recordName, "."+domainInfo.Name+".") + + var recordsToKeep []*Record + + // Find and delete matching records + for _, record := range allRecords { + if skipRecord(recordName, recordValue, record, nsInfo) { + continue + } + + // Checkdomain API can return records without any TTL set (indicated by the value of 0). + // The API Call to replace the records would fail if we wouldn't specify a value. + // Thus, we use the default TTL queried beforehand + if record.TTL == 0 { + record.TTL = nsInfo.SOA.TTL + } + + recordsToKeep = append(recordsToKeep, record) + } + + return c.replaceRecords(ctx, domainID, recordsToKeep) +} + +func (c *Client) getDomainInfo(ctx context.Context, domainID int) (*DomainResponse, error) { + endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID)) + + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, err + } + + var res DomainResponse + err = c.do(req, &res) + if err != nil { + return nil, err + } + + return &res, nil +} + +func (c *Client) listRecords(ctx context.Context, domainID int, recordType string) ([]*Record, error) { + endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers", "records") + + q := endpoint.Query() + q.Set("limit", strconv.Itoa(maxLimit)) + if recordType != "" { + q.Set("type", recordType) + } + + currentPage := 1 + totalPages := maxInt + + var recordList []*Record + for currentPage <= totalPages { + q.Set("page", strconv.Itoa(currentPage)) + endpoint.RawQuery = q.Encode() + + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + var res RecordListingResponse + if err := c.do(req, &res); err != nil { + return nil, fmt.Errorf("failed to send record listing request: %w", err) + } + + // This is the first response, so we update totalPages and allocate the slice memory. + if totalPages == maxInt { + totalPages = res.Pages + recordList = make([]*Record, 0, res.Total) + } + + recordList = append(recordList, res.Embedded.Records...) + currentPage++ + } + + return recordList, nil +} + +func (c *Client) replaceRecords(ctx context.Context, domainID int, records []*Record) error { + endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers", "records") + + req, err := newJSONRequest(ctx, http.MethodPut, endpoint, records) + if err != nil { + return err + } + + return c.do(req, nil) +} + +func (c *Client) do(req *http.Request, result any) error { + resp, err := c.httpClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode/100 != 2 { + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + if result == nil { + return nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return nil +} + +func (c *Client) CleanCache(fqdn string) { + c.domainIDMu.Lock() + delete(c.domainIDMapping, fqdn) + c.domainIDMu.Unlock() +} + +func skipRecord(recordName, recordValue string, record *Record, nsInfo *NameserverResponse) bool { + // Skip empty records + if record.Value == "" { + return true + } + + // Skip some special records, otherwise we would get a "Nameserver update failed" + if record.Type == "SOA" || record.Type == "NS" || record.Name == "@" || (nsInfo.General.IncludeWWW && record.Name == "www") { + return true + } + + nameMatch := recordName == "" || record.Name == recordName + valueMatch := recordValue == "" || record.Value == recordValue + + // Skip our matching record + if record.Type == "TXT" && nameMatch && valueMatch { + return true + } + + return false +} + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil +} + +func OAuthStaticAccessToken(client *http.Client, accessToken string) *http.Client { + if client == nil { + client = &http.Client{Timeout: 5 * time.Second} + } + + client.Transport = &oauth2.Transport{ + Source: oauth2.StaticTokenSource(&oauth2.Token{AccessToken: accessToken}), + Base: client.Transport, + } + + return client +} diff --git a/providers/dns/checkdomain/client_test.go b/providers/dns/checkdomain/internal/client_test.go similarity index 60% rename from providers/dns/checkdomain/client_test.go rename to providers/dns/checkdomain/internal/client_test.go index f7c488be..3f6a7e7a 100644 --- a/providers/dns/checkdomain/client_test.go +++ b/providers/dns/checkdomain/internal/client_test.go @@ -1,6 +1,8 @@ -package checkdomain +package internal import ( + "bytes" + "context" "encoding/json" "fmt" "io" @@ -15,32 +17,42 @@ import ( "github.com/stretchr/testify/require" ) -func setupTestProvider(t *testing.T) (*DNSProvider, *http.ServeMux) { +func setupTest(t *testing.T) (*Client, *http.ServeMux) { t.Helper() mux := http.NewServeMux() server := httptest.NewServer(mux) t.Cleanup(server.Close) - config := NewDefaultConfig() - config.Endpoint, _ = url.Parse(server.URL) - config.Token = "secret" + client := NewClient(OAuthStaticAccessToken(server.Client(), "secret")) + client.BaseURL, _ = url.Parse(server.URL) - p, err := NewDNSProviderConfig(config) - require.NoError(t, err) - - return p, mux + return client, mux } -func Test_getDomainIDByName(t *testing.T) { - prd, handler := setupTestProvider(t) +func checkAuthorizationHeader(req *http.Request) error { + val := req.Header.Get("Authorization") + if val != "Bearer secret" { + return fmt.Errorf("invalid header value, got: %s want %s", val, "Bearer secret") + } + return nil +} - handler.HandleFunc("/v1/domains", func(rw http.ResponseWriter, req *http.Request) { +func TestClient_GetDomainIDByName(t *testing.T) { + client, mux := setupTest(t) + + mux.HandleFunc("/v1/domains", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodGet { http.Error(rw, "invalid method: "+req.Method, http.StatusBadRequest) return } + err := checkAuthorizationHeader(req) + if err != nil { + http.Error(rw, err.Error(), http.StatusUnauthorized) + return + } + domainList := DomainListingResponse{ Embedded: EmbeddedDomainList{Domains: []*Domain{ {ID: 1, Name: "test.com"}, @@ -48,28 +60,34 @@ func Test_getDomainIDByName(t *testing.T) { }}, } - err := json.NewEncoder(rw).Encode(domainList) + err = json.NewEncoder(rw).Encode(domainList) if err != nil { http.Error(rw, err.Error(), http.StatusInternalServerError) return } }) - id, err := prd.getDomainIDByName("test.com") + id, err := client.GetDomainIDByName(context.Background(), "test.com") require.NoError(t, err) assert.Equal(t, 1, id) } -func Test_checkNameservers(t *testing.T) { - prd, handler := setupTestProvider(t) +func TestClient_CheckNameservers(t *testing.T) { + client, mux := setupTest(t) - handler.HandleFunc("/v1/domains/1/nameservers", func(rw http.ResponseWriter, req *http.Request) { + mux.HandleFunc("/v1/domains/1/nameservers", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodGet { http.Error(rw, "invalid method: "+req.Method, http.StatusBadRequest) return } + err := checkAuthorizationHeader(req) + if err != nil { + http.Error(rw, err.Error(), http.StatusUnauthorized) + return + } + nsResp := NameserverResponse{ Nameservers: []*Nameserver{ {Name: ns1}, @@ -78,33 +96,39 @@ func Test_checkNameservers(t *testing.T) { }, } - err := json.NewEncoder(rw).Encode(nsResp) + err = json.NewEncoder(rw).Encode(nsResp) if err != nil { http.Error(rw, err.Error(), http.StatusInternalServerError) return } }) - err := prd.checkNameservers(1) + err := client.CheckNameservers(context.Background(), 1) require.NoError(t, err) } -func Test_createRecord(t *testing.T) { - prd, handler := setupTestProvider(t) +func TestClient_CreateRecord(t *testing.T) { + client, mux := setupTest(t) - handler.HandleFunc("/v1/domains/1/nameservers/records", func(rw http.ResponseWriter, req *http.Request) { + mux.HandleFunc("/v1/domains/1/nameservers/records", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodPost { http.Error(rw, "invalid method: "+req.Method, http.StatusBadRequest) return } + err := checkAuthorizationHeader(req) + if err != nil { + http.Error(rw, err.Error(), http.StatusUnauthorized) + return + } + content, err := io.ReadAll(req.Body) if err != nil { http.Error(rw, err.Error(), http.StatusBadRequest) return } - if string(content) != `{"name":"test.com","value":"value","ttl":300,"priority":0,"type":"TXT"}` { + if string(bytes.TrimSpace(content)) != `{"name":"test.com","value":"value","ttl":300,"priority":0,"type":"TXT"}` { http.Error(rw, "invalid request body: "+string(content), http.StatusBadRequest) return } @@ -117,12 +141,12 @@ func Test_createRecord(t *testing.T) { Value: "value", } - err := prd.createRecord(1, record) + err := client.CreateRecord(context.Background(), 1, record) require.NoError(t, err) } -func Test_deleteTXTRecord(t *testing.T) { - prd, handler := setupTestProvider(t) +func TestClient_DeleteTXTRecord(t *testing.T) { + client, mux := setupTest(t) domainName := "lego.test" recordValue := "test" @@ -158,20 +182,26 @@ func Test_deleteTXTRecord(t *testing.T) { }, } - handler.HandleFunc("/v1/domains/1", func(rw http.ResponseWriter, req *http.Request) { + mux.HandleFunc("/v1/domains/1", func(rw http.ResponseWriter, req *http.Request) { + err := checkAuthorizationHeader(req) + if err != nil { + http.Error(rw, err.Error(), http.StatusUnauthorized) + return + } + resp := DomainResponse{ ID: 1, Name: domainName, } - err := json.NewEncoder(rw).Encode(resp) + err = json.NewEncoder(rw).Encode(resp) if err != nil { http.Error(rw, err.Error(), http.StatusInternalServerError) return } }) - handler.HandleFunc("/v1/domains/1/nameservers", func(rw http.ResponseWriter, req *http.Request) { + mux.HandleFunc("/v1/domains/1/nameservers", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodGet { http.Error(rw, "invalid method: "+req.Method, http.StatusBadRequest) return @@ -188,7 +218,7 @@ func Test_deleteTXTRecord(t *testing.T) { } }) - handler.HandleFunc("/v1/domains/1/nameservers/records", func(rw http.ResponseWriter, req *http.Request) { + mux.HandleFunc("/v1/domains/1/nameservers/records", func(rw http.ResponseWriter, req *http.Request) { switch req.Method { case http.MethodGet: resp := RecordListingResponse{ @@ -226,6 +256,6 @@ func Test_deleteTXTRecord(t *testing.T) { }) info := dns01.GetChallengeInfo(domainName, "abc") - err := prd.deleteTXTRecord(1, info.EffectiveFQDN, recordValue) + err := client.DeleteTXTRecord(context.Background(), 1, info.EffectiveFQDN, recordValue) require.NoError(t, err) } diff --git a/providers/dns/checkdomain/internal/types.go b/providers/dns/checkdomain/internal/types.go new file mode 100644 index 00000000..06e0b018 --- /dev/null +++ b/providers/dns/checkdomain/internal/types.go @@ -0,0 +1,73 @@ +package internal + +// Some fields have been omitted from the structs +// because they are not required for this application. + +type DomainListingResponse struct { + Page int `json:"page"` + Limit int `json:"limit"` + Pages int `json:"pages"` + Total int `json:"total"` + Embedded EmbeddedDomainList `json:"_embedded"` +} + +type EmbeddedDomainList struct { + Domains []*Domain `json:"domains"` +} + +type Domain struct { + ID int `json:"id"` + Name string `json:"name"` +} + +type DomainResponse struct { + ID int `json:"id"` + Name string `json:"name"` + Created string `json:"created"` + PaidUp string `json:"payed_up"` + Active bool `json:"active"` +} + +type NameserverResponse struct { + General NameserverGeneral `json:"general"` + Nameservers []*Nameserver `json:"nameservers"` + SOA NameserverSOA `json:"soa"` +} + +type NameserverGeneral struct { + IPv4 string `json:"ip_v4"` + IPv6 string `json:"ip_v6"` + IncludeWWW bool `json:"include_www"` +} + +type NameserverSOA struct { + Mail string `json:"mail"` + Refresh int `json:"refresh"` + Retry int `json:"retry"` + Expiry int `json:"expiry"` + TTL int `json:"ttl"` +} + +type Nameserver struct { + Name string `json:"name"` +} + +type RecordListingResponse struct { + Page int `json:"page"` + Limit int `json:"limit"` + Pages int `json:"pages"` + Total int `json:"total"` + Embedded EmbeddedRecordList `json:"_embedded"` +} + +type EmbeddedRecordList struct { + Records []*Record `json:"records"` +} + +type Record struct { + Name string `json:"name"` + Value string `json:"value"` + TTL int `json:"ttl"` + Priority int `json:"priority"` + Type string `json:"type"` +} diff --git a/providers/dns/civo/civo.go b/providers/dns/civo/civo.go index 6190ca2f..3d639eb6 100644 --- a/providers/dns/civo/civo.go +++ b/providers/dns/civo/civo.go @@ -93,11 +93,13 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - zone, err := getZone(info.EffectiveFQDN) + authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("civo: failed to find zone: fqdn=%s: %w", info.EffectiveFQDN, err) + return fmt.Errorf("civo: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } + zone := dns01.UnFqdn(authZone) + dnsDomain, err := d.client.GetDNSDomain(zone) if err != nil { return fmt.Errorf("civo: %w", err) @@ -125,11 +127,13 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - zone, err := getZone(info.EffectiveFQDN) + authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("civo: failed to find zone: fqdn=%s: %w", info.EffectiveFQDN, err) + return fmt.Errorf("civo: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } + zone := dns01.UnFqdn(authZone) + dnsDomain, err := d.client.GetDNSDomain(zone) if err != nil { return fmt.Errorf("civo: %w", err) @@ -166,12 +170,3 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { return d.config.PropagationTimeout, d.config.PollingInterval } - -func getZone(fqdn string) (string, error) { - authZone, err := dns01.FindZoneByFqdn(fqdn) - if err != nil { - return "", err - } - - return dns01.UnFqdn(authZone), nil -} diff --git a/providers/dns/clouddns/clouddns.go b/providers/dns/clouddns/clouddns.go index e12054ec..7b0644f7 100644 --- a/providers/dns/clouddns/clouddns.go +++ b/providers/dns/clouddns/clouddns.go @@ -2,6 +2,7 @@ package clouddns import ( + "context" "errors" "fmt" "net/http" @@ -89,10 +90,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { client.HTTPClient = config.HTTPClient } - return &DNSProvider{ - client: client, - config: config, - }, nil + return &DNSProvider{client: client, config: config}, nil } // Timeout returns the timeout and interval to use when checking for DNS propagation. @@ -107,12 +105,17 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("clouddns: %w", err) + return fmt.Errorf("clouddns: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } - err = d.client.AddRecord(authZone, info.EffectiveFQDN, info.Value) + ctx, err := d.client.CreateAuthenticatedContext(context.Background()) if err != nil { - return fmt.Errorf("clouddns: %w", err) + return err + } + + err = d.client.AddRecord(ctx, authZone, info.EffectiveFQDN, info.Value) + if err != nil { + return fmt.Errorf("clouddns: add record: %w", err) } return nil @@ -124,12 +127,17 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("clouddns: %w", err) + return fmt.Errorf("clouddns: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } - err = d.client.DeleteRecord(authZone, info.EffectiveFQDN) + ctx, err := d.client.CreateAuthenticatedContext(context.Background()) if err != nil { - return fmt.Errorf("clouddns: %w", err) + return err + } + + err = d.client.DeleteRecord(ctx, authZone, info.EffectiveFQDN) + if err != nil { + return fmt.Errorf("clouddns: delete record: %w", err) } return nil diff --git a/providers/dns/clouddns/internal/client.go b/providers/dns/clouddns/internal/client.go index 7ea6234c..cd3da50c 100644 --- a/providers/dns/clouddns/internal/client.go +++ b/providers/dns/clouddns/internal/client.go @@ -2,117 +2,127 @@ package internal import ( "bytes" + "context" "encoding/json" - "errors" "fmt" "io" "net/http" + "net/url" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) -const ( - apiBaseURL = "https://admin.vshosting.cloud/clouddns" - loginURL = "https://admin.vshosting.cloud/api/public/auth/login" -) +const apiBaseURL = "https://admin.vshosting.cloud/clouddns" + +const authorizationHeader = "Authorization" // Client handles all communication with CloudDNS API. type Client struct { - AccessToken string - ClientID string - Email string - Password string - TTL int - HTTPClient *http.Client + clientID string + email string + password string + ttl int - apiBaseURL string - loginURL string + apiBaseURL *url.URL + + loginURL *url.URL + + HTTPClient *http.Client } // NewClient returns a Client instance configured to handle CloudDNS API communication. func NewClient(clientID, email, password string, ttl int) *Client { + baseURL, _ := url.Parse(apiBaseURL) + loginBaseURL, _ := url.Parse(loginURL) + return &Client{ - ClientID: clientID, - Email: email, - Password: password, - TTL: ttl, - HTTPClient: &http.Client{}, - apiBaseURL: apiBaseURL, - loginURL: loginURL, + clientID: clientID, + email: email, + password: password, + ttl: ttl, + apiBaseURL: baseURL, + loginURL: loginBaseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, } } // AddRecord is a high level method to add a new record into CloudDNS zone. -func (c *Client) AddRecord(zone, recordName, recordValue string) error { - domain, err := c.getDomain(zone) +func (c *Client) AddRecord(ctx context.Context, zone, recordName, recordValue string) error { + domain, err := c.getDomain(ctx, zone) if err != nil { return err } record := Record{DomainID: domain.ID, Name: recordName, Value: recordValue, Type: "TXT"} - err = c.addTxtRecord(record) + err = c.addTxtRecord(ctx, record) if err != nil { return err } - return c.publishRecords(domain.ID) + return c.publishRecords(ctx, domain.ID) } // DeleteRecord is a high level method to remove a record from zone. -func (c *Client) DeleteRecord(zone, recordName string) error { - domain, err := c.getDomain(zone) +func (c *Client) DeleteRecord(ctx context.Context, zone, recordName string) error { + domain, err := c.getDomain(ctx, zone) if err != nil { return err } - record, err := c.getRecord(domain.ID, recordName) + record, err := c.getRecord(ctx, domain.ID, recordName) if err != nil { return err } - err = c.deleteRecord(record) + err = c.deleteRecord(ctx, record) if err != nil { return err } - return c.publishRecords(domain.ID) + return c.publishRecords(ctx, domain.ID) } -func (c *Client) addTxtRecord(record Record) error { - body, err := json.Marshal(record) +func (c *Client) addTxtRecord(ctx context.Context, record Record) error { + endpoint := c.apiBaseURL.JoinPath("record-txt") + + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record) if err != nil { return err } - _, err = c.doAPIRequest(http.MethodPost, "record-txt", bytes.NewReader(body)) - return err + return c.do(req, nil) } -func (c *Client) deleteRecord(record Record) error { - endpoint := fmt.Sprintf("record/%s", record.ID) - _, err := c.doAPIRequest(http.MethodDelete, endpoint, nil) - return err +func (c *Client) deleteRecord(ctx context.Context, record Record) error { + endpoint := c.apiBaseURL.JoinPath("record", record.ID) + + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) + if err != nil { + return err + } + + return c.do(req, nil) } -func (c *Client) getDomain(zone string) (Domain, error) { +func (c *Client) getDomain(ctx context.Context, zone string) (Domain, error) { searchQuery := SearchQuery{ Search: []Search{ - {Name: "clientId", Operator: "eq", Value: c.ClientID}, + {Name: "clientId", Operator: "eq", Value: c.clientID}, {Name: "domainName", Operator: "eq", Value: zone}, }, } - body, err := json.Marshal(searchQuery) - if err != nil { - return Domain{}, err - } + endpoint := c.apiBaseURL.JoinPath("domain", "search") - resp, err := c.doAPIRequest(http.MethodPost, "domain/search", bytes.NewReader(body)) + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, searchQuery) if err != nil { return Domain{}, err } var result SearchResponse - err = json.Unmarshal(resp, &result) + err = c.do(req, &result) if err != nil { return Domain{}, err } @@ -124,15 +134,16 @@ func (c *Client) getDomain(zone string) (Domain, error) { return result.Items[0], nil } -func (c *Client) getRecord(domainID, recordName string) (Record, error) { - endpoint := fmt.Sprintf("domain/%s", domainID) - resp, err := c.doAPIRequest(http.MethodGet, endpoint, nil) +func (c *Client) getRecord(ctx context.Context, domainID, recordName string) (Record, error) { + endpoint := c.apiBaseURL.JoinPath("domain", domainID) + + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return Record{}, err } var result DomainInfo - err = json.Unmarshal(resp, &result) + err = c.do(req, &result) if err != nil { return Record{}, err } @@ -146,116 +157,85 @@ func (c *Client) getRecord(domainID, recordName string) (Record, error) { return Record{}, fmt.Errorf("record not found: domainID %s, name %s", domainID, recordName) } -func (c *Client) publishRecords(domainID string) error { - body, err := json.Marshal(DomainInfo{SoaTTL: c.TTL}) +func (c *Client) publishRecords(ctx context.Context, domainID string) error { + endpoint := c.apiBaseURL.JoinPath("domain", domainID, "publish") + + payload := DomainInfo{SoaTTL: c.ttl} + + req, err := newJSONRequest(ctx, http.MethodPut, endpoint, payload) if err != nil { return err } - endpoint := fmt.Sprintf("domain/%s/publish", domainID) - _, err = c.doAPIRequest(http.MethodPut, endpoint, bytes.NewReader(body)) - return err + return c.do(req, nil) } -func (c *Client) login() error { - authorization := Authorization{Email: c.Email, Password: c.Password} - - body, err := json.Marshal(authorization) - if err != nil { - return err +func (c *Client) do(req *http.Request, result any) error { + at := getAccessToken(req.Context()) + if at != "" { + req.Header.Set(authorizationHeader, "Bearer "+at) } - req, err := http.NewRequest(http.MethodPost, c.loginURL, bytes.NewReader(body)) + resp, err := c.HTTPClient.Do(req) if err != nil { - return err + return errutils.NewHTTPDoError(req, err) } - req.Header.Set("Content-Type", "application/json") + defer func() { _ = resp.Body.Close() }() - content, err := c.doRequest(req) - if err != nil { - return err + if resp.StatusCode/100 != 2 { + return parseError(req, resp) } - var result AuthResponse - err = json.Unmarshal(content, &result) - if err != nil { - return err + if result == nil { + return nil } - c.AccessToken = result.Auth.AccessToken + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } return nil } -func (c *Client) doAPIRequest(method, endpoint string, body io.Reader) ([]byte, error) { - if c.AccessToken == "" { - err := c.login() +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create request JSON body: %w", err) } } - url := fmt.Sprintf("%s/%s", c.apiBaseURL, endpoint) - - req, err := c.newRequest(method, url, body) + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) if err != nil { - return nil, err + return nil, fmt.Errorf("unable to create request: %w", err) } - content, err := c.doRequest(req) - if err != nil { - return nil, err + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") } - return content, nil -} - -func (c *Client) newRequest(method, reqURL string, body io.Reader) (*http.Request, error) { - req, err := http.NewRequest(method, reqURL, body) - if err != nil { - return nil, err - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.AccessToken)) - return req, nil } -func (c *Client) doRequest(req *http.Request) ([]byte, error) { - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() +func parseError(req *http.Request, resp *http.Response) error { + raw, _ := io.ReadAll(resp.Body) - if resp.StatusCode >= http.StatusBadRequest { - return nil, readError(req, resp) + var response APIError + err := json.Unmarshal(raw, &response) + if err != nil { + return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw) } - content, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - return content, nil -} - -func readError(req *http.Request, resp *http.Response) error { - content, err := io.ReadAll(resp.Body) - if err != nil { - return errors.New(toUnreadableBodyMessage(req, content)) - } - - var errInfo APIError - err = json.Unmarshal(content, &errInfo) - if err != nil { - return fmt.Errorf("APIError unmarshaling error: %w: %s", err, toUnreadableBodyMessage(req, content)) - } - - return fmt.Errorf("HTTP %d: code %v: %s", resp.StatusCode, errInfo.Error.Code, errInfo.Error.Message) -} - -func toUnreadableBodyMessage(req *http.Request, rawBody []byte) string { - return fmt.Sprintf("the request %s sent a response with a body which is an invalid format: %q", req.URL, string(rawBody)) + return fmt.Errorf("[status code %d] %w", resp.StatusCode, response.Error) } diff --git a/providers/dns/clouddns/internal/client_test.go b/providers/dns/clouddns/internal/client_test.go index 68f500d8..2a4891cc 100644 --- a/providers/dns/clouddns/internal/client_test.go +++ b/providers/dns/clouddns/internal/client_test.go @@ -1,16 +1,33 @@ package internal import ( + "context" "encoding/json" "net/http" "net/http/httptest" + "net/url" "testing" "github.com/stretchr/testify/require" ) -func TestClient_AddRecord(t *testing.T) { +func setupTest(t *testing.T) (*Client, *http.ServeMux) { + t.Helper() + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + client := NewClient("clientID", "email@example.com", "secret", 300) + client.HTTPClient = server.Client() + client.apiBaseURL, _ = url.Parse(server.URL + "/api") + client.loginURL, _ = url.Parse(server.URL + "/login") + + return client, mux +} + +func TestClient_AddRecord(t *testing.T) { + client, mux := setupTest(t) mux.HandleFunc("/api/domain/search", func(rw http.ResponseWriter, req *http.Request) { response := SearchResponse{ @@ -45,19 +62,12 @@ func TestClient_AddRecord(t *testing.T) { } }) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - - client := NewClient("clientID", "email@example.com", "secret", 300) - client.apiBaseURL = server.URL + "/api" - client.loginURL = server.URL + "/login" - - err := client.AddRecord("example.com", "_acme-challenge.example.com", "txt") + err := client.AddRecord(context.Background(), "example.com", "_acme-challenge.example.com", "txt") require.NoError(t, err) } func TestClient_DeleteRecord(t *testing.T) { - mux := http.NewServeMux() + client, mux := setupTest(t) mux.HandleFunc("/api/domain/search", func(rw http.ResponseWriter, req *http.Request) { response := SearchResponse{ @@ -114,13 +124,9 @@ func TestClient_DeleteRecord(t *testing.T) { } }) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + ctx, err := client.CreateAuthenticatedContext(context.Background()) + require.NoError(t, err) - client := NewClient("clientID", "email@example.com", "secret", 300) - client.apiBaseURL = server.URL + "/api" - client.loginURL = server.URL + "/login" - - err := client.DeleteRecord("example.com", "_acme-challenge.example.com") + err = client.DeleteRecord(ctx, "example.com", "_acme-challenge.example.com") require.NoError(t, err) } diff --git a/providers/dns/clouddns/internal/identity.go b/providers/dns/clouddns/internal/identity.go new file mode 100644 index 00000000..4ea5c504 --- /dev/null +++ b/providers/dns/clouddns/internal/identity.go @@ -0,0 +1,47 @@ +package internal + +import ( + "context" + "net/http" +) + +const loginURL = "https://admin.vshosting.cloud/api/public/auth/login" + +type token string + +const accessTokenKey token = "accessToken" + +func (c *Client) login(ctx context.Context) (*AuthResponse, error) { + authorization := Authorization{Email: c.email, Password: c.password} + + req, err := newJSONRequest(ctx, http.MethodPost, c.loginURL, authorization) + if err != nil { + return nil, err + } + + var result AuthResponse + err = c.do(req, &result) + if err != nil { + return nil, err + } + + return &result, nil +} + +func (c *Client) CreateAuthenticatedContext(ctx context.Context) (context.Context, error) { + tok, err := c.login(ctx) + if err != nil { + return nil, err + } + + return context.WithValue(ctx, accessTokenKey, tok.Auth.AccessToken), nil +} + +func getAccessToken(ctx context.Context) string { + tok, ok := ctx.Value(accessTokenKey).(string) + if !ok { + return "" + } + + return tok +} diff --git a/providers/dns/clouddns/internal/identity_test.go b/providers/dns/clouddns/internal/identity_test.go new file mode 100644 index 00000000..3c727448 --- /dev/null +++ b/providers/dns/clouddns/internal/identity_test.go @@ -0,0 +1,46 @@ +package internal + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestClient_CreateAuthenticatedContext(t *testing.T) { + client, mux := setupTest(t) + + mux.HandleFunc("/login", func(rw http.ResponseWriter, req *http.Request) { + response := AuthResponse{ + Auth: Auth{ + AccessToken: "at", + RefreshToken: "", + }, + } + + err := json.NewEncoder(rw).Encode(response) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + }) + mux.HandleFunc("/api/record/xxx", func(rw http.ResponseWriter, req *http.Request) { + authorization := req.Header.Get(authorizationHeader) + if authorization != "Bearer at" { + http.Error(rw, "invalid credential: "+authorization, http.StatusUnauthorized) + return + } + }) + + ctx, err := client.CreateAuthenticatedContext(context.Background()) + require.NoError(t, err) + + at := getAccessToken(ctx) + assert.Equal(t, "at", at) + + err = client.deleteRecord(ctx, Record{ID: "xxx"}) + require.NoError(t, err) +} diff --git a/providers/dns/clouddns/internal/models.go b/providers/dns/clouddns/internal/types.go similarity index 95% rename from providers/dns/clouddns/internal/models.go rename to providers/dns/clouddns/internal/types.go index a46bfdf0..a53c958a 100644 --- a/providers/dns/clouddns/internal/models.go +++ b/providers/dns/clouddns/internal/types.go @@ -1,5 +1,7 @@ package internal +import "fmt" + type APIError struct { Error ErrorContent `json:"error"` } @@ -9,6 +11,10 @@ type ErrorContent struct { Message string `json:"message,omitempty"` } +func (e ErrorContent) Error() string { + return fmt.Sprintf("%d: %s", e.Code, e.Message) +} + type Authorization struct { Email string `json:"email,omitempty"` Password string `json:"password,omitempty"` diff --git a/providers/dns/cloudflare/cloudflare.go b/providers/dns/cloudflare/cloudflare.go index 16bfe390..b32e91ff 100644 --- a/providers/dns/cloudflare/cloudflare.go +++ b/providers/dns/cloudflare/cloudflare.go @@ -126,7 +126,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("cloudflare: %w", err) + return fmt.Errorf("cloudflare: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } zoneID, err := d.client.ZoneIDByName(authZone) @@ -165,7 +165,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("cloudflare: %w", err) + return fmt.Errorf("cloudflare: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } zoneID, err := d.client.ZoneIDByName(authZone) diff --git a/providers/dns/cloudflare/client.go b/providers/dns/cloudflare/wrapper.go similarity index 100% rename from providers/dns/cloudflare/client.go rename to providers/dns/cloudflare/wrapper.go diff --git a/providers/dns/cloudns/cloudns.go b/providers/dns/cloudns/cloudns.go index f75b9096..554e5416 100644 --- a/providers/dns/cloudns/cloudns.go +++ b/providers/dns/cloudns/cloudns.go @@ -2,6 +2,7 @@ package cloudns import ( + "context" "errors" "fmt" "net/http" @@ -104,29 +105,33 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - zone, err := d.client.GetZone(info.EffectiveFQDN) + ctx := context.Background() + + zone, err := d.client.GetZone(ctx, info.EffectiveFQDN) if err != nil { return fmt.Errorf("ClouDNS: %w", err) } - err = d.client.AddTxtRecord(zone.Name, info.EffectiveFQDN, info.Value, d.config.TTL) + err = d.client.AddTxtRecord(ctx, zone.Name, info.EffectiveFQDN, info.Value, d.config.TTL) if err != nil { return fmt.Errorf("ClouDNS: %w", err) } - return d.waitNameservers(domain, zone) + return d.waitNameservers(ctx, domain, zone) } // CleanUp removes the TXT records matching the specified parameters. func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - zone, err := d.client.GetZone(info.EffectiveFQDN) + ctx := context.Background() + + zone, err := d.client.GetZone(ctx, info.EffectiveFQDN) if err != nil { return fmt.Errorf("ClouDNS: %w", err) } - records, err := d.client.ListTxtRecords(zone.Name, info.EffectiveFQDN) + records, err := d.client.ListTxtRecords(ctx, zone.Name, info.EffectiveFQDN) if err != nil { return fmt.Errorf("ClouDNS: %w", err) } @@ -136,7 +141,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { } for _, record := range records { - err = d.client.RemoveTxtRecord(record.ID, zone.Name) + err = d.client.RemoveTxtRecord(ctx, record.ID, zone.Name) if err != nil { return fmt.Errorf("ClouDNS: %w", err) } @@ -153,9 +158,9 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { // waitNameservers At the time of writing 4 servers are found as authoritative, but 8 are reported during the sync. // If this is not done, the secondary verification done by Let's Encrypt server will fail quire a bit. -func (d *DNSProvider) waitNameservers(domain string, zone *internal.Zone) error { +func (d *DNSProvider) waitNameservers(ctx context.Context, domain string, zone *internal.Zone) error { return wait.For("Nameserver sync on "+domain, d.config.PropagationTimeout, d.config.PollingInterval, func() (bool, error) { - syncProgress, err := d.client.GetUpdateStatus(zone.Name) + syncProgress, err := d.client.GetUpdateStatus(ctx, zone.Name) if err != nil { return false, err } diff --git a/providers/dns/cloudns/internal/client.go b/providers/dns/cloudns/internal/client.go index 65270f80..c4f350f6 100644 --- a/providers/dns/cloudns/internal/client.go +++ b/providers/dns/cloudns/internal/client.go @@ -1,6 +1,7 @@ package internal import ( + "context" "encoding/json" "errors" "fmt" @@ -8,8 +9,10 @@ import ( "net/http" "net/url" "strconv" + "time" "github.com/go-acme/lego/v4/challenge/dns01" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) const defaultBaseURL = "https://api.cloudns.net/dns/" @@ -19,8 +22,9 @@ type Client struct { authID string subAuthID string authPassword string - HTTPClient *http.Client - BaseURL *url.URL + + BaseURL *url.URL + HTTPClient *http.Client } // NewClient creates a ClouDNS client. @@ -42,16 +46,16 @@ func NewClient(authID, subAuthID, authPassword string) (*Client, error) { authID: authID, subAuthID: subAuthID, authPassword: authPassword, - HTTPClient: &http.Client{}, BaseURL: baseURL, + HTTPClient: &http.Client{Timeout: 10 * time.Second}, }, nil } // GetZone Get domain name information for a FQDN. -func (c *Client) GetZone(authFQDN string) (*Zone, error) { +func (c *Client) GetZone(ctx context.Context, authFQDN string) (*Zone, error) { authZone, err := dns01.FindZoneByFqdn(authFQDN) if err != nil { - return nil, err + return nil, fmt.Errorf("could not find zone for FQDN %q: %w", authFQDN, err) } authZoneName := dns01.UnFqdn(authZone) @@ -62,16 +66,21 @@ func (c *Client) GetZone(authFQDN string) (*Zone, error) { q.Set("domain-name", authZoneName) endpoint.RawQuery = q.Encode() - result, err := c.doRequest(http.MethodGet, endpoint) + req, err := c.newRequest(ctx, http.MethodGet, endpoint) + if err != nil { + return nil, err + } + + rawMessage, err := c.do(req) if err != nil { return nil, err } var zone Zone - if len(result) > 0 { - if err = json.Unmarshal(result, &zone); err != nil { - return nil, fmt.Errorf("failed to unmarshal zone: %w", err) + if len(rawMessage) > 0 { + if err = json.Unmarshal(rawMessage, &zone); err != nil { + return nil, errutils.NewUnmarshalError(req, http.StatusOK, rawMessage, err) } } @@ -83,7 +92,7 @@ func (c *Client) GetZone(authFQDN string) (*Zone, error) { } // FindTxtRecord returns the TXT record a zone ID and a FQDN. -func (c *Client) FindTxtRecord(zoneName, fqdn string) (*TXTRecord, error) { +func (c *Client) FindTxtRecord(ctx context.Context, zoneName, fqdn string) (*TXTRecord, error) { subDomain, err := dns01.ExtractSubDomain(fqdn, zoneName) if err != nil { return nil, err @@ -97,19 +106,24 @@ func (c *Client) FindTxtRecord(zoneName, fqdn string) (*TXTRecord, error) { q.Set("type", "TXT") endpoint.RawQuery = q.Encode() - result, err := c.doRequest(http.MethodGet, endpoint) + req, err := c.newRequest(ctx, http.MethodGet, endpoint) + if err != nil { + return nil, err + } + + rawMessage, err := c.do(req) if err != nil { return nil, err } // the API returns [] when there is no records. - if string(result) == "[]" { + if string(rawMessage) == "[]" { return nil, nil } var records map[string]TXTRecord - if err = json.Unmarshal(result, &records); err != nil { - return nil, fmt.Errorf("failed to unmarshall TXT records: %w: %s", err, string(result)) + if err = json.Unmarshal(rawMessage, &records); err != nil { + return nil, errutils.NewUnmarshalError(req, http.StatusOK, rawMessage, err) } for _, record := range records { @@ -122,7 +136,7 @@ func (c *Client) FindTxtRecord(zoneName, fqdn string) (*TXTRecord, error) { } // ListTxtRecords returns the TXT records a zone ID and a FQDN. -func (c *Client) ListTxtRecords(zoneName, fqdn string) ([]TXTRecord, error) { +func (c *Client) ListTxtRecords(ctx context.Context, zoneName, fqdn string) ([]TXTRecord, error) { subDomain, err := dns01.ExtractSubDomain(fqdn, zoneName) if err != nil { return nil, err @@ -136,19 +150,24 @@ func (c *Client) ListTxtRecords(zoneName, fqdn string) ([]TXTRecord, error) { q.Set("type", "TXT") endpoint.RawQuery = q.Encode() - result, err := c.doRequest(http.MethodGet, endpoint) + req, err := c.newRequest(ctx, http.MethodGet, endpoint) + if err != nil { + return nil, err + } + + rawMessage, err := c.do(req) if err != nil { return nil, err } // the API returns [] when there is no records. - if string(result) == "[]" { + if string(rawMessage) == "[]" { return nil, nil } var raw map[string]TXTRecord - if err = json.Unmarshal(result, &raw); err != nil { - return nil, fmt.Errorf("failed to unmarshall TXT records: %w: %s", err, string(result)) + if err = json.Unmarshal(rawMessage, &raw); err != nil { + return nil, errutils.NewUnmarshalError(req, http.StatusOK, rawMessage, err) } var records []TXTRecord @@ -162,7 +181,7 @@ func (c *Client) ListTxtRecords(zoneName, fqdn string) ([]TXTRecord, error) { } // AddTxtRecord adds a TXT record. -func (c *Client) AddTxtRecord(zoneName, fqdn, value string, ttl int) error { +func (c *Client) AddTxtRecord(ctx context.Context, zoneName, fqdn, value string, ttl int) error { subDomain, err := dns01.ExtractSubDomain(fqdn, zoneName) if err != nil { return err @@ -178,14 +197,19 @@ func (c *Client) AddTxtRecord(zoneName, fqdn, value string, ttl int) error { q.Set("record-type", "TXT") endpoint.RawQuery = q.Encode() - raw, err := c.doRequest(http.MethodPost, endpoint) + req, err := c.newRequest(ctx, http.MethodPost, endpoint) + if err != nil { + return err + } + + rawMessage, err := c.do(req) if err != nil { return err } resp := apiResponse{} - if err = json.Unmarshal(raw, &resp); err != nil { - return fmt.Errorf("failed to unmarshal API response: %w: %s", err, string(raw)) + if err = json.Unmarshal(rawMessage, &resp); err != nil { + return errutils.NewUnmarshalError(req, http.StatusOK, rawMessage, err) } if resp.Status != "Success" { @@ -196,7 +220,7 @@ func (c *Client) AddTxtRecord(zoneName, fqdn, value string, ttl int) error { } // RemoveTxtRecord removes a TXT record. -func (c *Client) RemoveTxtRecord(recordID int, zoneName string) error { +func (c *Client) RemoveTxtRecord(ctx context.Context, recordID int, zoneName string) error { endpoint := c.BaseURL.JoinPath("delete-record.json") q := endpoint.Query() @@ -204,14 +228,19 @@ func (c *Client) RemoveTxtRecord(recordID int, zoneName string) error { q.Set("record-id", strconv.Itoa(recordID)) endpoint.RawQuery = q.Encode() - raw, err := c.doRequest(http.MethodPost, endpoint) + req, err := c.newRequest(ctx, http.MethodPost, endpoint) + if err != nil { + return err + } + + rawMessage, err := c.do(req) if err != nil { return err } resp := apiResponse{} - if err = json.Unmarshal(raw, &resp); err != nil { - return fmt.Errorf("failed to unmarshal API response: %w: %s", err, string(raw)) + if err = json.Unmarshal(rawMessage, &resp); err != nil { + return errutils.NewUnmarshalError(req, http.StatusOK, rawMessage, err) } if resp.Status != "Success" { @@ -222,26 +251,31 @@ func (c *Client) RemoveTxtRecord(recordID int, zoneName string) error { } // GetUpdateStatus gets sync progress of all CloudDNS NS servers. -func (c *Client) GetUpdateStatus(zoneName string) (*SyncProgress, error) { +func (c *Client) GetUpdateStatus(ctx context.Context, zoneName string) (*SyncProgress, error) { endpoint := c.BaseURL.JoinPath("update-status.json") q := endpoint.Query() q.Set("domain-name", zoneName) endpoint.RawQuery = q.Encode() - result, err := c.doRequest(http.MethodGet, endpoint) + req, err := c.newRequest(ctx, http.MethodGet, endpoint) + if err != nil { + return nil, err + } + + rawMessage, err := c.do(req) if err != nil { return nil, err } // the API returns [] when there is no records. - if string(result) == "[]" { + if string(rawMessage) == "[]" { return nil, errors.New("no nameservers records returned") } var records []UpdateRecord - if err = json.Unmarshal(result, &records); err != nil { - return nil, fmt.Errorf("failed to unmarshal UpdateRecord: %w: %s", err, string(result)) + if err = json.Unmarshal(rawMessage, &records); err != nil { + return nil, errutils.NewUnmarshalError(req, http.StatusOK, rawMessage, err) } updatedCount := 0 @@ -254,33 +288,8 @@ func (c *Client) GetUpdateStatus(zoneName string) (*SyncProgress, error) { return &SyncProgress{Complete: updatedCount == len(records), Updated: updatedCount, Total: len(records)}, nil } -func (c *Client) doRequest(method string, uri *url.URL) (json.RawMessage, error) { - req, err := c.buildRequest(method, uri) - if err != nil { - return nil, err - } - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, err - } - - defer resp.Body.Close() - - content, err := io.ReadAll(resp.Body) - if err != nil { - return nil, errors.New(toUnreadableBodyMessage(req, content)) - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("invalid code (%d), error: %s", resp.StatusCode, content) - } - - return content, nil -} - -func (c *Client) buildRequest(method string, uri *url.URL) (*http.Request, error) { - q := uri.Query() +func (c *Client) newRequest(ctx context.Context, method string, endpoint *url.URL) (*http.Request, error) { + q := endpoint.Query() if c.subAuthID != "" { q.Set("sub-auth-id", c.subAuthID) @@ -290,18 +299,34 @@ func (c *Client) buildRequest(method string, uri *url.URL) (*http.Request, error q.Set("auth-password", c.authPassword) - uri.RawQuery = q.Encode() + endpoint.RawQuery = q.Encode() - req, err := http.NewRequest(method, uri.String(), nil) + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), nil) if err != nil { - return nil, fmt.Errorf("invalid request: %w", err) + return nil, fmt.Errorf("unable to create request: %w", err) } return req, nil } -func toUnreadableBodyMessage(req *http.Request, rawBody []byte) string { - return fmt.Sprintf("the request %s sent a response with a body which is an invalid format: %q", req.URL, string(rawBody)) +func (c *Client) do(req *http.Request) (json.RawMessage, error) { + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return nil, errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + return raw, nil } // Rounds the given TTL in seconds to the next accepted value. diff --git a/providers/dns/cloudns/internal/client_test.go b/providers/dns/cloudns/internal/client_test.go index 277063a8..554bf008 100644 --- a/providers/dns/cloudns/internal/client_test.go +++ b/providers/dns/cloudns/internal/client_test.go @@ -1,6 +1,7 @@ package internal import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -11,6 +12,21 @@ import ( "github.com/stretchr/testify/require" ) +func setupTest(t *testing.T, subAuthID string, handler http.HandlerFunc) *Client { + t.Helper() + + server := httptest.NewServer(handler) + t.Cleanup(server.Close) + + client, err := NewClient("myAuthID", subAuthID, "myAuthPassword") + require.NoError(t, err) + + client.BaseURL, _ = url.Parse(server.URL) + client.HTTPClient = server.Client() + + return client +} + func handlerMock(method string, jsonData []byte) http.HandlerFunc { return func(rw http.ResponseWriter, req *http.Request) { if req.Method != method { @@ -109,22 +125,16 @@ func TestClient_GetZone(t *testing.T) { authFQDN: "_acme-challenge.foo.com.", apiResponse: `[{}]`, expected: expected{ - errorMsg: "failed to unmarshal zone: json: cannot unmarshal array into Go value of type internal.Zone", + errorMsg: "unable to unmarshal response: [status code: 200] body: [{}] error: json: cannot unmarshal array into Go value of type internal.Zone", }, }, } for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - server := httptest.NewServer(handlerMock(http.MethodGet, []byte(test.apiResponse))) - t.Cleanup(server.Close) + client := setupTest(t, "", handlerMock(http.MethodGet, []byte(test.apiResponse))) - client, err := NewClient("myAuthID", "", "myAuthPassword") - require.NoError(t, err) - - client.BaseURL, _ = url.Parse(server.URL) - - zone, err := client.GetZone(test.authFQDN) + zone, err := client.GetZone(context.Background(), test.authFQDN) if test.expected.errorMsg != "" { require.EqualError(t, err, test.expected.errorMsg) @@ -222,22 +232,16 @@ func TestClient_FindTxtRecord(t *testing.T) { zoneName: "example.com", apiResponse: `[{}]`, expected: expected{ - errorMsg: "failed to unmarshall TXT records: json: cannot unmarshal array into Go value of type map[string]internal.TXTRecord: [{}]", + errorMsg: "unable to unmarshal response: [status code: 200] body: [{}] error: json: cannot unmarshal array into Go value of type map[string]internal.TXTRecord", }, }, } for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - server := httptest.NewServer(handlerMock(http.MethodGet, []byte(test.apiResponse))) - t.Cleanup(server.Close) + client := setupTest(t, "", handlerMock(http.MethodGet, []byte(test.apiResponse))) - client, err := NewClient("myAuthID", "", "myAuthPassword") - require.NoError(t, err) - - client.BaseURL, _ = url.Parse(server.URL) - - txtRecord, err := client.FindTxtRecord(test.zoneName, test.authFQDN) + txtRecord, err := client.FindTxtRecord(context.Background(), test.zoneName, test.authFQDN) if test.expected.errorMsg != "" { require.EqualError(t, err, test.expected.errorMsg) @@ -337,22 +341,16 @@ func TestClient_ListTxtRecord(t *testing.T) { zoneName: "example.com", apiResponse: `[{}]`, expected: expected{ - errorMsg: "failed to unmarshall TXT records: json: cannot unmarshal array into Go value of type map[string]internal.TXTRecord: [{}]", + errorMsg: "unable to unmarshal response: [status code: 200] body: [{}] error: json: cannot unmarshal array into Go value of type map[string]internal.TXTRecord", }, }, } for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - server := httptest.NewServer(handlerMock(http.MethodGet, []byte(test.apiResponse))) - t.Cleanup(server.Close) + client := setupTest(t, "", handlerMock(http.MethodGet, []byte(test.apiResponse))) - client, err := NewClient("myAuthID", "", "myAuthPassword") - require.NoError(t, err) - - client.BaseURL, _ = url.Parse(server.URL) - - txtRecords, err := client.ListTxtRecords(test.zoneName, test.authFQDN) + txtRecords, err := client.ListTxtRecords(context.Background(), test.zoneName, test.authFQDN) if test.expected.errorMsg != "" { require.EqualError(t, err, test.expected.errorMsg) @@ -440,14 +438,14 @@ func TestClient_AddTxtRecord(t *testing.T) { apiResponse: `[{}]`, expected: expected{ query: `auth-id=myAuthID&auth-password=myAuthPassword&domain-name=bar.com&host=_acme-challenge&record=TXTtxtTXTtxtTXTtxtTXTtxt&record-type=TXT&ttl=300`, - errorMsg: "failed to unmarshal API response: json: cannot unmarshal array into Go value of type internal.apiResponse: [{}]", + errorMsg: "unable to unmarshal response: [status code: 200] body: [{}] error: json: cannot unmarshal array into Go value of type internal.apiResponse", }, }, } for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + client := setupTest(t, test.subAuthID, func(rw http.ResponseWriter, req *http.Request) { if test.expected.query != req.URL.RawQuery { msg := fmt.Sprintf("got: %s, want: %s", test.expected.query, req.URL.RawQuery) http.Error(rw, msg, http.StatusBadRequest) @@ -455,15 +453,9 @@ func TestClient_AddTxtRecord(t *testing.T) { } handlerMock(http.MethodPost, []byte(test.apiResponse))(rw, req) - })) - t.Cleanup(server.Close) + }) - client, err := NewClient(test.authID, test.subAuthID, "myAuthPassword") - require.NoError(t, err) - - client.BaseURL, _ = url.Parse(server.URL) - - err = client.AddTxtRecord(test.zoneName, test.authFQDN, test.value, test.ttl) + err := client.AddTxtRecord(context.Background(), test.zoneName, test.authFQDN, test.value, test.ttl) if test.expected.errorMsg != "" { require.EqualError(t, err, test.expected.errorMsg) @@ -513,7 +505,7 @@ func TestClient_RemoveTxtRecord(t *testing.T) { apiResponse: `[{}]`, expected: expected{ query: `auth-id=myAuthID&auth-password=myAuthPassword&domain-name=foo-plus.com&record-id=44`, - errorMsg: "failed to unmarshal API response: json: cannot unmarshal array into Go value of type internal.apiResponse: [{}]", + errorMsg: "unable to unmarshal response: [status code: 200] body: [{}] error: json: cannot unmarshal array into Go value of type internal.apiResponse", }, }, } @@ -536,7 +528,7 @@ func TestClient_RemoveTxtRecord(t *testing.T) { client.BaseURL, _ = url.Parse(server.URL) - err = client.RemoveTxtRecord(test.id, test.zoneName) + err = client.RemoveTxtRecord(context.Background(), test.id, test.zoneName) if test.expected.errorMsg != "" { require.EqualError(t, err, test.expected.errorMsg) @@ -592,7 +584,7 @@ func TestClient_GetUpdateStatus(t *testing.T) { authFQDN: "_acme-challenge.foo.com.", zoneName: "test-zone", apiResponse: `[x]`, - expected: expected{errorMsg: "failed to unmarshal UpdateRecord: invalid character 'x' looking for beginning of value: [x]"}, + expected: expected{errorMsg: "unable to unmarshal response: [status code: 200] body: [x] error: invalid character 'x' looking for beginning of value"}, }, } @@ -606,7 +598,7 @@ func TestClient_GetUpdateStatus(t *testing.T) { client.BaseURL, _ = url.Parse(server.URL) - syncProgress, err := client.GetUpdateStatus(test.zoneName) + syncProgress, err := client.GetUpdateStatus(context.Background(), test.zoneName) if test.expected.errorMsg != "" { require.EqualError(t, err, test.expected.errorMsg) diff --git a/providers/dns/cloudxns/cloudxns.go b/providers/dns/cloudxns/cloudxns.go index d3bd9b6b..6269b8da 100644 --- a/providers/dns/cloudxns/cloudxns.go +++ b/providers/dns/cloudxns/cloudxns.go @@ -2,6 +2,7 @@ package cloudxns import ( + "context" "errors" "fmt" "net/http" @@ -59,7 +60,7 @@ type DNSProvider struct { func NewDNSProvider() (*DNSProvider, error) { values, err := env.Get(EnvAPIKey, EnvSecretKey) if err != nil { - return nil, fmt.Errorf("CloudXNS: %w", err) + return nil, fmt.Errorf("cloudxns: %w", err) } config := NewDefaultConfig() @@ -72,15 +73,17 @@ func NewDNSProvider() (*DNSProvider, error) { // NewDNSProviderConfig return a DNSProvider instance configured for CloudXNS. func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { if config == nil { - return nil, errors.New("CloudXNS: the configuration of the DNS provider is nil") + return nil, errors.New("cloudxns: the configuration of the DNS provider is nil") } client, err := internal.NewClient(config.APIKey, config.SecretKey) if err != nil { - return nil, err + return nil, fmt.Errorf("cloudxns: %w", err) } - client.HTTPClient = config.HTTPClient + if config.HTTPClient != nil { + client.HTTPClient = config.HTTPClient + } return &DNSProvider{client: client, config: config}, nil } @@ -89,29 +92,43 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { func (d *DNSProvider) Present(domain, token, keyAuth string) error { challengeInfo := dns01.GetChallengeInfo(domain, keyAuth) - info, err := d.client.GetDomainInformation(challengeInfo.EffectiveFQDN) + ctx := context.Background() + + info, err := d.client.GetDomainInformation(ctx, challengeInfo.EffectiveFQDN) if err != nil { - return err + return fmt.Errorf("cloudxns: %w", err) } - return d.client.AddTxtRecord(info, challengeInfo.EffectiveFQDN, challengeInfo.Value, d.config.TTL) + err = d.client.AddTxtRecord(ctx, info, challengeInfo.EffectiveFQDN, challengeInfo.Value, d.config.TTL) + if err != nil { + return fmt.Errorf("cloudxns: %w", err) + } + + return nil } // CleanUp removes the TXT record matching the specified parameters. func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { challengeInfo := dns01.GetChallengeInfo(domain, keyAuth) - info, err := d.client.GetDomainInformation(challengeInfo.EffectiveFQDN) + ctx := context.Background() + + info, err := d.client.GetDomainInformation(ctx, challengeInfo.EffectiveFQDN) if err != nil { - return err + return fmt.Errorf("cloudxns: %w", err) } - record, err := d.client.FindTxtRecord(info.ID, challengeInfo.EffectiveFQDN) + record, err := d.client.FindTxtRecord(ctx, info.ID, challengeInfo.EffectiveFQDN) if err != nil { - return err + return fmt.Errorf("cloudxns: %w", err) } - return d.client.RemoveTxtRecord(record.RecordID, info.ID) + err = d.client.RemoveTxtRecord(ctx, record.RecordID, info.ID) + if err != nil { + return fmt.Errorf("cloudxns: %w", err) + } + + return nil } // Timeout returns the timeout and interval to use when checking for DNS propagation. diff --git a/providers/dns/cloudxns/cloudxns_test.go b/providers/dns/cloudxns/cloudxns_test.go index 43dd8a99..0b327176 100644 --- a/providers/dns/cloudxns/cloudxns_test.go +++ b/providers/dns/cloudxns/cloudxns_test.go @@ -34,7 +34,7 @@ func TestNewDNSProvider(t *testing.T) { EnvAPIKey: "", EnvSecretKey: "", }, - expected: "CloudXNS: some credentials information are missing: CLOUDXNS_API_KEY,CLOUDXNS_SECRET_KEY", + expected: "cloudxns: some credentials information are missing: CLOUDXNS_API_KEY,CLOUDXNS_SECRET_KEY", }, { desc: "missing API key", @@ -42,7 +42,7 @@ func TestNewDNSProvider(t *testing.T) { EnvAPIKey: "", EnvSecretKey: "456", }, - expected: "CloudXNS: some credentials information are missing: CLOUDXNS_API_KEY", + expected: "cloudxns: some credentials information are missing: CLOUDXNS_API_KEY", }, { desc: "missing secret key", @@ -50,7 +50,7 @@ func TestNewDNSProvider(t *testing.T) { EnvAPIKey: "123", EnvSecretKey: "", }, - expected: "CloudXNS: some credentials information are missing: CLOUDXNS_SECRET_KEY", + expected: "cloudxns: some credentials information are missing: CLOUDXNS_SECRET_KEY", }, } @@ -89,17 +89,17 @@ func TestNewDNSProviderConfig(t *testing.T) { }, { desc: "missing credentials", - expected: "CloudXNS: credentials missing: apiKey", + expected: "cloudxns: credentials missing: apiKey", }, { desc: "missing api key", secretKey: "456", - expected: "CloudXNS: credentials missing: apiKey", + expected: "cloudxns: credentials missing: apiKey", }, { desc: "missing secret key", apiKey: "123", - expected: "CloudXNS: credentials missing: secretKey", + expected: "cloudxns: credentials missing: secretKey", }, } diff --git a/providers/dns/cloudxns/internal/client.go b/providers/dns/cloudxns/internal/client.go index dd151184..2fc6aab2 100644 --- a/providers/dns/cloudxns/internal/client.go +++ b/providers/dns/cloudxns/internal/client.go @@ -2,6 +2,7 @@ package internal import ( "bytes" + "context" "crypto/md5" "encoding/hex" "encoding/json" @@ -9,83 +10,63 @@ import ( "fmt" "io" "net/http" + "net/url" "strconv" "time" "github.com/go-acme/lego/v4/challenge/dns01" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) const defaultBaseURL = "https://www.cloudxns.net/api2/" -type apiResponse struct { - Code int `json:"code"` - Message string `json:"message"` - Data json.RawMessage `json:"data,omitempty"` -} +// Client CloudXNS client. +type Client struct { + apiKey string + secretKey string -// Data Domain information. -type Data struct { - ID string `json:"id"` - Domain string `json:"domain"` - TTL int `json:"ttl,omitempty"` -} - -// TXTRecord a TXT record. -type TXTRecord struct { - ID int `json:"domain_id,omitempty"` - RecordID string `json:"record_id,omitempty"` - - Host string `json:"host"` - Value string `json:"value"` - Type string `json:"type"` - LineID int `json:"line_id,string"` - TTL int `json:"ttl,string"` + baseURL *url.URL + HTTPClient *http.Client } // NewClient creates a CloudXNS client. func NewClient(apiKey, secretKey string) (*Client, error) { if apiKey == "" { - return nil, errors.New("CloudXNS: credentials missing: apiKey") + return nil, errors.New("credentials missing: apiKey") } if secretKey == "" { - return nil, errors.New("CloudXNS: credentials missing: secretKey") + return nil, errors.New("credentials missing: secretKey") } + baseURL, _ := url.Parse(defaultBaseURL) + return &Client{ apiKey: apiKey, secretKey: secretKey, - HTTPClient: &http.Client{}, - BaseURL: defaultBaseURL, + baseURL: baseURL, + HTTPClient: &http.Client{Timeout: 10 * time.Second}, }, nil } -// Client CloudXNS client. -type Client struct { - apiKey string - secretKey string - HTTPClient *http.Client - BaseURL string -} - // GetDomainInformation Get domain name information for a FQDN. -func (c *Client) GetDomainInformation(fqdn string) (*Data, error) { - authZone, err := dns01.FindZoneByFqdn(fqdn) +func (c *Client) GetDomainInformation(ctx context.Context, fqdn string) (*Data, error) { + endpoint := c.baseURL.JoinPath("domain") + + req, err := c.newRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, err } - result, err := c.doRequest(http.MethodGet, "domain", nil) + authZone, err := dns01.FindZoneByFqdn(fqdn) if err != nil { - return nil, err + return nil, fmt.Errorf("cloudflare: could not find zone for FQDN %q: %w", fqdn, err) } var domains []Data - if len(result) > 0 { - err = json.Unmarshal(result, &domains) - if err != nil { - return nil, fmt.Errorf("CloudXNS: domains unmarshaling error: %w", err) - } + err = c.do(req, &domains) + if err != nil { + return nil, err } for _, data := range domains { @@ -94,20 +75,28 @@ func (c *Client) GetDomainInformation(fqdn string) (*Data, error) { } } - return nil, fmt.Errorf("CloudXNS: zone %s not found for domain %s", authZone, fqdn) + return nil, fmt.Errorf("zone %s not found for domain %s", authZone, fqdn) } // FindTxtRecord return the TXT record a zone ID and a FQDN. -func (c *Client) FindTxtRecord(zoneID, fqdn string) (*TXTRecord, error) { - result, err := c.doRequest(http.MethodGet, fmt.Sprintf("record/%s?host_id=0&offset=0&row_num=2000", zoneID), nil) +func (c *Client) FindTxtRecord(ctx context.Context, zoneID, fqdn string) (*TXTRecord, error) { + endpoint := c.baseURL.JoinPath("record", zoneID) + + query := endpoint.Query() + query.Set("host_id", "0") + query.Set("offset", "0") + query.Set("row_num", "2000") + endpoint.RawQuery = query.Encode() + + req, err := c.newRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, err } var records []TXTRecord - err = json.Unmarshal(result, &records) + err = c.do(req, &records) if err != nil { - return nil, fmt.Errorf("CloudXNS: TXT record unmarshaling error: %w", err) + return nil, err } for _, record := range records { @@ -116,22 +105,24 @@ func (c *Client) FindTxtRecord(zoneID, fqdn string) (*TXTRecord, error) { } } - return nil, fmt.Errorf("CloudXNS: no existing record found for %q", fqdn) + return nil, fmt.Errorf("no existing record found for %q", fqdn) } // AddTxtRecord add a TXT record. -func (c *Client) AddTxtRecord(info *Data, fqdn, value string, ttl int) error { +func (c *Client) AddTxtRecord(ctx context.Context, info *Data, fqdn, value string, ttl int) error { id, err := strconv.Atoi(info.ID) if err != nil { - return fmt.Errorf("CloudXNS: invalid zone ID: %w", err) + return fmt.Errorf("invalid zone ID: %w", err) } + endpoint := c.baseURL.JoinPath("record") + subDomain, err := dns01.ExtractSubDomain(fqdn, info.Domain) if err != nil { - return fmt.Errorf("CloudXNS: %w", err) + return err } - payload := TXTRecord{ + record := TXTRecord{ ID: id, Host: subDomain, Value: value, @@ -140,74 +131,91 @@ func (c *Client) AddTxtRecord(info *Data, fqdn, value string, ttl int) error { TTL: ttl, } - body, err := json.Marshal(payload) + req, err := c.newRequest(ctx, http.MethodPost, endpoint, record) if err != nil { - return fmt.Errorf("CloudXNS: record unmarshaling error: %w", err) + return err } - _, err = c.doRequest(http.MethodPost, "record", body) - return err + return c.do(req, nil) } // RemoveTxtRecord remove a TXT record. -func (c *Client) RemoveTxtRecord(recordID, zoneID string) error { - _, err := c.doRequest(http.MethodDelete, fmt.Sprintf("record/%s/%s", recordID, zoneID), nil) - return err -} +func (c *Client) RemoveTxtRecord(ctx context.Context, recordID, zoneID string) error { + endpoint := c.baseURL.JoinPath("record", recordID, zoneID) -func (c *Client) doRequest(method, uri string, body []byte) (json.RawMessage, error) { - req, err := c.buildRequest(method, uri, body) + req, err := c.newRequest(ctx, http.MethodDelete, endpoint, nil) if err != nil { - return nil, err + return err } + return c.do(req, nil) +} + +func (c *Client) do(req *http.Request, result any) error { resp, err := c.HTTPClient.Do(req) if err != nil { - return nil, fmt.Errorf("CloudXNS: %w", err) + return errutils.NewHTTPDoError(req, err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() - content, err := io.ReadAll(resp.Body) + raw, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("CloudXNS: %s", toUnreadableBodyMessage(req, content)) + return errutils.NewReadResponseError(req, resp.StatusCode, err) } - var r apiResponse - err = json.Unmarshal(content, &r) + var response apiResponse + err = json.Unmarshal(raw, &response) if err != nil { - return nil, fmt.Errorf("CloudXNS: response unmashaling error: %w: %s", err, toUnreadableBodyMessage(req, content)) + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) } - if r.Code != 1 { - return nil, fmt.Errorf("CloudXNS: invalid code (%v), error: %s", r.Code, r.Message) + if response.Code != 1 { + return fmt.Errorf("[status code %d] invalid code (%v) error: %s", resp.StatusCode, response.Code, response.Message) } - return r.Data, nil + + if result == nil { + return nil + } + + if len(response.Data) == 0 { + return nil + } + + err = json.Unmarshal(response.Data, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return nil } -func (c *Client) buildRequest(method, uri string, body []byte) (*http.Request, error) { - url := c.BaseURL + uri +func (c *Client) newRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) - req, err := http.NewRequest(method, url, bytes.NewReader(body)) + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) if err != nil { - return nil, fmt.Errorf("CloudXNS: invalid request: %w", err) + return nil, fmt.Errorf("unable to create request: %w", err) } requestDate := time.Now().Format(time.RFC1123Z) req.Header.Set("API-KEY", c.apiKey) req.Header.Set("API-REQUEST-DATE", requestDate) - req.Header.Set("API-HMAC", c.hmac(url, requestDate, string(body))) + req.Header.Set("API-HMAC", c.hmac(endpoint.String(), requestDate, buf.String())) req.Header.Set("API-FORMAT", "json") return req, nil } -func (c *Client) hmac(url, date, body string) string { - sum := md5.Sum([]byte(c.apiKey + url + body + date + c.secretKey)) +func (c *Client) hmac(endpoint, date, body string) string { + sum := md5.Sum([]byte(c.apiKey + endpoint + body + date + c.secretKey)) return hex.EncodeToString(sum[:]) } - -func toUnreadableBodyMessage(req *http.Request, rawBody []byte) string { - return fmt.Sprintf("the request %s sent a response with a body which is an invalid format: %q", req.URL, string(rawBody)) -} diff --git a/providers/dns/cloudxns/internal/client_test.go b/providers/dns/cloudxns/internal/client_test.go index 618ac027..e4972174 100644 --- a/providers/dns/cloudxns/internal/client_test.go +++ b/providers/dns/cloudxns/internal/client_test.go @@ -1,19 +1,35 @@ package internal import ( + "bytes" + "context" "encoding/json" "fmt" "io" "net/http" "net/http/httptest" + "net/url" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func handlerMock(method string, response *apiResponse, data interface{}) http.Handler { - return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { +func setupTest(t *testing.T, handler http.HandlerFunc) *Client { + t.Helper() + + server := httptest.NewServer(handler) + t.Cleanup(server.Close) + + client, _ := NewClient("myKey", "mySecret") + client.baseURL, _ = url.Parse(server.URL + "/") + client.HTTPClient = server.Client() + + return client +} + +func handlerMock(method string, response *apiResponse, data interface{}) http.HandlerFunc { + return func(rw http.ResponseWriter, req *http.Request) { if req.Method != method { content, err := json.Marshal(apiResponse{ Code: 999, // random code only for the test @@ -47,10 +63,10 @@ func handlerMock(method string, response *apiResponse, data interface{}) http.Ha http.Error(rw, err.Error(), http.StatusInternalServerError) return } - }) + } } -func TestClientGetDomainInformation(t *testing.T) { +func TestClient_GetDomainInformation(t *testing.T) { type result struct { domain *Data error bool @@ -106,13 +122,9 @@ func TestClientGetDomainInformation(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - server := httptest.NewServer(handlerMock(http.MethodGet, test.response, test.data)) - t.Cleanup(server.Close) + client := setupTest(t, handlerMock(http.MethodGet, test.response, test.data)) - client, _ := NewClient("myKey", "mySecret") - client.BaseURL = server.URL + "/" - - domain, err := client.GetDomainInformation(test.fqdn) + domain, err := client.GetDomainInformation(context.Background(), test.fqdn) if test.expected.error { require.Error(t, err) @@ -124,7 +136,7 @@ func TestClientGetDomainInformation(t *testing.T) { } } -func TestClientFindTxtRecord(t *testing.T) { +func TestClient_FindTxtRecord(t *testing.T) { type result struct { txtRecord *TXTRecord error bool @@ -210,13 +222,9 @@ func TestClientFindTxtRecord(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - server := httptest.NewServer(handlerMock(http.MethodGet, test.response, test.txtRecords)) - t.Cleanup(server.Close) + client := setupTest(t, handlerMock(http.MethodGet, test.response, test.txtRecords)) - client, _ := NewClient("myKey", "mySecret") - client.BaseURL = server.URL + "/" - - txtRecord, err := client.FindTxtRecord(test.zoneID, test.fqdn) + txtRecord, err := client.FindTxtRecord(context.Background(), test.zoneID, test.fqdn) if test.expected.error { require.Error(t, err) @@ -228,7 +236,7 @@ func TestClientFindTxtRecord(t *testing.T) { } } -func TestClientAddTxtRecord(t *testing.T) { +func TestClient_AddTxtRecord(t *testing.T) { testCases := []struct { desc string domain *Data @@ -267,21 +275,17 @@ func TestClientAddTxtRecord(t *testing.T) { Code: 1, } - server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + client := setupTest(t, func(rw http.ResponseWriter, req *http.Request) { assert.NotNil(t, req.Body) content, err := io.ReadAll(req.Body) require.NoError(t, err) - assert.Equal(t, test.expected, string(content)) + assert.Equal(t, test.expected, string(bytes.TrimSpace(content))) handlerMock(http.MethodPost, response, nil).ServeHTTP(rw, req) - })) - t.Cleanup(server.Close) + }) - client, _ := NewClient("myKey", "mySecret") - client.BaseURL = server.URL + "/" - - err := client.AddTxtRecord(test.domain, test.fqdn, test.value, test.ttl) + err := client.AddTxtRecord(context.Background(), test.domain, test.fqdn, test.value, test.ttl) require.NoError(t, err) }) } diff --git a/providers/dns/cloudxns/internal/types.go b/providers/dns/cloudxns/internal/types.go new file mode 100644 index 00000000..c1b24e30 --- /dev/null +++ b/providers/dns/cloudxns/internal/types.go @@ -0,0 +1,28 @@ +package internal + +import "encoding/json" + +type apiResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data,omitempty"` +} + +// Data Domain information. +type Data struct { + ID string `json:"id"` + Domain string `json:"domain"` + TTL int `json:"ttl,omitempty"` +} + +// TXTRecord a TXT record. +type TXTRecord struct { + ID int `json:"domain_id,omitempty"` + RecordID string `json:"record_id,omitempty"` + + Host string `json:"host"` + Value string `json:"value"` + Type string `json:"type"` + LineID int `json:"line_id,string"` + TTL int `json:"ttl,string"` +} diff --git a/providers/dns/conoha/conoha.go b/providers/dns/conoha/conoha.go index b107ed37..1e4b0e18 100644 --- a/providers/dns/conoha/conoha.go +++ b/providers/dns/conoha/conoha.go @@ -2,6 +2,7 @@ package conoha import ( + "context" "errors" "fmt" "net/http" @@ -85,6 +86,15 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("conoha: some credentials information are missing") } + identifier, err := internal.NewIdentifier(config.Region) + if err != nil { + return nil, fmt.Errorf("conoha: failed to create identity client: %w", err) + } + + if config.HTTPClient != nil { + identifier.HTTPClient = config.HTTPClient + } + auth := internal.Auth{ TenantID: config.TenantID, PasswordCredentials: internal.PasswordCredentials{ @@ -93,11 +103,20 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { }, } - client, err := internal.NewClient(config.Region, auth, config.HTTPClient) + tokens, err := identifier.GetToken(context.TODO(), auth) + if err != nil { + return nil, fmt.Errorf("conoha: failed to login: %w", err) + } + + client, err := internal.NewClient(config.Region, tokens.Access.Token.ID) if err != nil { return nil, fmt.Errorf("conoha: failed to create client: %w", err) } + if config.HTTPClient != nil { + client.HTTPClient = config.HTTPClient + } + return &DNSProvider{config: config, client: client}, nil } @@ -107,10 +126,12 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return err + return fmt.Errorf("conoha: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } - id, err := d.client.GetDomainID(authZone) + ctx := context.Background() + + id, err := d.client.GetDomainID(ctx, authZone) if err != nil { return fmt.Errorf("conoha: failed to get domain ID: %w", err) } @@ -122,7 +143,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { TTL: d.config.TTL, } - err = d.client.CreateRecord(id, record) + err = d.client.CreateRecord(ctx, id, record) if err != nil { return fmt.Errorf("conoha: failed to create record: %w", err) } @@ -136,20 +157,22 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return err + return fmt.Errorf("conoha: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } - domID, err := d.client.GetDomainID(authZone) + ctx := context.Background() + + domID, err := d.client.GetDomainID(ctx, authZone) if err != nil { return fmt.Errorf("conoha: failed to get domain ID: %w", err) } - recID, err := d.client.GetRecordID(domID, info.EffectiveFQDN, "TXT", info.Value) + recID, err := d.client.GetRecordID(ctx, domID, info.EffectiveFQDN, "TXT", info.Value) if err != nil { return fmt.Errorf("conoha: failed to get record ID: %w", err) } - err = d.client.DeleteRecord(domID, recID) + err = d.client.DeleteRecord(ctx, domID, recID) if err != nil { return fmt.Errorf("conoha: failed to delete record: %w", err) } diff --git a/providers/dns/conoha/conoha_test.go b/providers/dns/conoha/conoha_test.go index 75e38b7f..8d8197f3 100644 --- a/providers/dns/conoha/conoha_test.go +++ b/providers/dns/conoha/conoha_test.go @@ -29,7 +29,7 @@ func TestNewDNSProvider(t *testing.T) { EnvAPIUsername: "api_username", EnvAPIPassword: "api_password", }, - expected: `conoha: failed to create client: failed to login: HTTP request failed with status code 401: {"unauthorized":{"message":"Invalid user: api_username","code":401}}`, + expected: `conoha: failed to login: unexpected status code: [status code: 401] body: {"unauthorized":{"message":"Invalid user: api_username","code":401}}`, }, { desc: "missing credentials", @@ -99,7 +99,7 @@ func TestNewDNSProviderConfig(t *testing.T) { }{ { desc: "complete credentials, but login failed", - expected: `conoha: failed to create client: failed to login: HTTP request failed with status code 401: {"unauthorized":{"message":"Invalid user: api_username","code":401}}`, + expected: `conoha: failed to login: unexpected status code: [status code: 401] body: {"unauthorized":{"message":"Invalid user: api_username","code":401}}`, tenant: "tenant_id", username: "api_username", password: "api_password", diff --git a/providers/dns/conoha/internal/client.go b/providers/dns/conoha/internal/client.go index f8a5e192..87fbe5a0 100644 --- a/providers/dns/conoha/internal/client.go +++ b/providers/dns/conoha/internal/client.go @@ -2,121 +2,45 @@ package internal import ( "bytes" + "context" "encoding/json" "errors" "fmt" "io" "net/http" + "net/url" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) -const ( - identityBaseURL = "https://identity.%s.conoha.io" - dnsServiceBaseURL = "https://dns-service.%s.conoha.io" -) - -// IdentityRequest is an authentication request body. -type IdentityRequest struct { - Auth Auth `json:"auth"` -} - -// Auth is an authentication information. -type Auth struct { - TenantID string `json:"tenantId"` - PasswordCredentials PasswordCredentials `json:"passwordCredentials"` -} - -// PasswordCredentials is API-user's credentials. -type PasswordCredentials struct { - Username string `json:"username"` - Password string `json:"password"` -} - -// IdentityResponse is an authentication response body. -type IdentityResponse struct { - Access Access `json:"access"` -} - -// Access is an identity information. -type Access struct { - Token Token `json:"token"` -} - -// Token is an api access token. -type Token struct { - ID string `json:"id"` -} - -// DomainListResponse is a response of a domain listing request. -type DomainListResponse struct { - Domains []Domain `json:"domains"` -} - -// Domain is a hosted domain entry. -type Domain struct { - ID string `json:"id"` - Name string `json:"name"` -} - -// RecordListResponse is a response of record listing request. -type RecordListResponse struct { - Records []Record `json:"records"` -} - -// Record is a record entry. -type Record struct { - ID string `json:"id,omitempty"` - Name string `json:"name"` - Type string `json:"type"` - Data string `json:"data"` - TTL int `json:"ttl"` -} +const dnsServiceBaseURL = "https://dns-service.%s.conoha.io" // Client is a ConoHa API client. type Client struct { - token string - endpoint string - httpClient *http.Client + token string + + baseURL *url.URL + HTTPClient *http.Client } // NewClient returns a client instance logged into the ConoHa service. -func NewClient(region string, auth Auth, httpClient *http.Client) (*Client, error) { - if httpClient == nil { - httpClient = &http.Client{} - } - - c := &Client{httpClient: httpClient} - - c.endpoint = fmt.Sprintf(identityBaseURL, region) - - identity, err := c.getIdentity(auth) - if err != nil { - return nil, fmt.Errorf("failed to login: %w", err) - } - - c.token = identity.Access.Token.ID - c.endpoint = fmt.Sprintf(dnsServiceBaseURL, region) - - return c, nil -} - -func (c *Client) getIdentity(auth Auth) (*IdentityResponse, error) { - req := &IdentityRequest{Auth: auth} - - identity := &IdentityResponse{} - - err := c.do(http.MethodPost, "/v2.0/tokens", req, identity) +func NewClient(region string, token string) (*Client, error) { + baseURL, err := url.Parse(fmt.Sprintf(dnsServiceBaseURL, region)) if err != nil { return nil, err } - return identity, nil + return &Client{ + token: token, + baseURL: baseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + }, nil } // GetDomainID returns an ID of specified domain. -func (c *Client) GetDomainID(domainName string) (string, error) { - domainList := &DomainListResponse{} - - err := c.do(http.MethodGet, "/v1/domains", nil, domainList) +func (c *Client) GetDomainID(ctx context.Context, domainName string) (string, error) { + domainList, err := c.getDomains(ctx) if err != nil { return "", err } @@ -126,14 +50,32 @@ func (c *Client) GetDomainID(domainName string) (string, error) { return domain.ID, nil } } + return "", fmt.Errorf("no such domain: %s", domainName) } -// GetRecordID returns an ID of specified record. -func (c *Client) GetRecordID(domainID, recordName, recordType, data string) (string, error) { - recordList := &RecordListResponse{} +// https://www.conoha.jp/docs/paas-dns-list-domains.php +func (c *Client) getDomains(ctx context.Context) (*DomainListResponse, error) { + endpoint := c.baseURL.JoinPath("v1", "domains") - err := c.do(http.MethodGet, fmt.Sprintf("/v1/domains/%s/records", domainID), nil, recordList) + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, err + } + + domainList := &DomainListResponse{} + + err = c.do(req, domainList) + if err != nil { + return nil, err + } + + return domainList, nil +} + +// GetRecordID returns an ID of specified record. +func (c *Client) GetRecordID(ctx context.Context, domainID, recordName, recordType, data string) (string, error) { + recordList, err := c.getRecords(ctx, domainID) if err != nil { return "", err } @@ -143,63 +85,119 @@ func (c *Client) GetRecordID(domainID, recordName, recordType, data string) (str return record.ID, nil } } + return "", errors.New("no such record") } +// https://www.conoha.jp/docs/paas-dns-list-records-in-a-domain.php +func (c *Client) getRecords(ctx context.Context, domainID string) (*RecordListResponse, error) { + endpoint := c.baseURL.JoinPath("v1", "domains", domainID, "records") + + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, err + } + + recordList := &RecordListResponse{} + + err = c.do(req, recordList) + if err != nil { + return nil, err + } + + return recordList, nil +} + // CreateRecord adds new record. -func (c *Client) CreateRecord(domainID string, record Record) error { - return c.do(http.MethodPost, fmt.Sprintf("/v1/domains/%s/records", domainID), record, nil) +func (c *Client) CreateRecord(ctx context.Context, domainID string, record Record) error { + _, err := c.createRecord(ctx, domainID, record) + return err +} + +// https://www.conoha.jp/docs/paas-dns-create-record.php +func (c *Client) createRecord(ctx context.Context, domainID string, record Record) (*Record, error) { + endpoint := c.baseURL.JoinPath("v1", "domains", domainID, "records") + + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record) + if err != nil { + return nil, err + } + + newRecord := &Record{} + err = c.do(req, newRecord) + if err != nil { + return nil, err + } + + return newRecord, nil } // DeleteRecord removes specified record. -func (c *Client) DeleteRecord(domainID, recordID string) error { - return c.do(http.MethodDelete, fmt.Sprintf("/v1/domains/%s/records/%s", domainID, recordID), nil, nil) +// https://www.conoha.jp/docs/paas-dns-delete-a-record.php +func (c *Client) DeleteRecord(ctx context.Context, domainID, recordID string) error { + endpoint := c.baseURL.JoinPath("v1", "domains", domainID, "records", recordID) + + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) + if err != nil { + return err + } + + return c.do(req, nil) } -func (c *Client) do(method, path string, payload, result interface{}) error { - body := bytes.NewReader(nil) - - if payload != nil { - bodyBytes, err := json.Marshal(payload) - if err != nil { - return err - } - body = bytes.NewReader(bodyBytes) +func (c *Client) do(req *http.Request, result any) error { + if c.token != "" { + req.Header.Set("X-Auth-Token", c.token) } - req, err := http.NewRequest(method, c.endpoint+path, body) + resp, err := c.HTTPClient.Do(req) if err != nil { - return err + return errutils.NewHTTPDoError(req, err) } - req.Header.Set("Accept", "application/json") - req.Header.Set("Content-Type", "application/json") - req.Header.Set("X-Auth-Token", c.token) - - resp, err := c.httpClient.Do(req) - if err != nil { - return err - } + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - defer resp.Body.Close() - - return fmt.Errorf("HTTP request failed with status code %d: %s", resp.StatusCode, string(respBody)) + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) } - if result != nil { - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - defer resp.Body.Close() + if result == nil { + return nil + } - return json.Unmarshal(respBody, result) + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) } return nil } + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil +} diff --git a/providers/dns/conoha/internal/client_test.go b/providers/dns/conoha/internal/client_test.go index 44f16b1c..bc27ec21 100644 --- a/providers/dns/conoha/internal/client_test.go +++ b/providers/dns/conoha/internal/client_test.go @@ -1,30 +1,71 @@ package internal import ( + "bytes" + "context" "fmt" "io" "net/http" "net/http/httptest" + "net/url" + "os" + "path/filepath" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func setupTest(t *testing.T) (*http.ServeMux, *Client) { +func setupTest(t *testing.T) (*Client, *http.ServeMux) { t.Helper() mux := http.NewServeMux() server := httptest.NewServer(mux) t.Cleanup(server.Close) - client := &Client{ - token: "secret", - endpoint: server.URL, - httpClient: server.Client(), - } + client, err := NewClient("tyo1", "secret") + require.NoError(t, err) - return mux, client + client.HTTPClient = server.Client() + client.baseURL, _ = url.Parse(server.URL) + + return client, mux +} + +func writeFixtureHandler(method, filename string) http.HandlerFunc { + return func(rw http.ResponseWriter, req *http.Request) { + if req.Method != method { + http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest) + return + } + + writeFixture(rw, filename) + } +} + +func writeBodyHandler(method, content string) http.HandlerFunc { + return func(rw http.ResponseWriter, req *http.Request) { + if req.Method != method { + http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest) + return + } + _, err := fmt.Fprint(rw, content) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + } +} + +func writeFixture(rw http.ResponseWriter, filename string) { + file, err := os.Open(filepath.Join("fixtures", filename)) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + defer func() { _ = file.Close() }() + + _, _ = io.Copy(rw, file) } func TestClient_GetDomainID(t *testing.T) { @@ -42,91 +83,30 @@ func TestClient_GetDomainID(t *testing.T) { { desc: "success", domainName: "domain1.com.", - handler: func(rw http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodGet { - http.Error(rw, fmt.Sprintf("%s: %s", http.StatusText(http.StatusMethodNotAllowed), req.Method), http.StatusMethodNotAllowed) - return - } - - content := ` -{ - "domains":[ - { - "id": "09494b72-b65b-4297-9efb-187f65a0553e", - "name": "domain1.com.", - "ttl": 3600, - "serial": 1351800668, - "email": "nsadmin@example.org", - "gslb": 0, - "created_at": "2012-11-01T20:11:08.000000", - "updated_at": null, - "description": "memo" - }, - { - "id": "cf661142-e577-40b5-b3eb-75795cdc0cd7", - "name": "domain2.com.", - "ttl": 7200, - "serial": 1351800670, - "email": "nsadmin2@example.org", - "gslb": 1, - "created_at": "2012-11-01T20:11:08.000000", - "updated_at": "2012-12-01T20:11:08.000000", - "description": "memomemo" - } - ] -} -` - _, err := fmt.Fprint(rw, content) - if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - return - } - }, - expected: expected{domainID: "09494b72-b65b-4297-9efb-187f65a0553e"}, + handler: writeFixtureHandler(http.MethodGet, "domains_GET.json"), + expected: expected{domainID: "09494b72-b65b-4297-9efb-187f65a0553e"}, }, { desc: "non existing domain", domainName: "domain1.com.", - handler: func(rw http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodGet { - http.Error(rw, fmt.Sprintf("%s: %s", http.StatusText(http.StatusMethodNotAllowed), req.Method), http.StatusMethodNotAllowed) - return - } - - _, err := fmt.Fprint(rw, "{}") - if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - return - } - }, - expected: expected{error: true}, + handler: writeBodyHandler(http.MethodGet, "{}"), + expected: expected{error: true}, }, { desc: "marshaling error", domainName: "domain1.com.", - handler: func(rw http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodGet { - http.Error(rw, fmt.Sprintf("%s: %s", http.StatusText(http.StatusMethodNotAllowed), req.Method), http.StatusMethodNotAllowed) - return - } - - _, err := fmt.Fprint(rw, "[]") - if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - return - } - }, - expected: expected{error: true}, + handler: writeBodyHandler(http.MethodGet, "[]"), + expected: expected{error: true}, }, } for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.Handle("/v1/domains", test.handler) - domainID, err := client.GetDomainID(test.domainName) + domainID, err := client.GetDomainID(context.Background(), test.domainName) if test.expected.error { require.Error(t, err) @@ -140,15 +120,15 @@ func TestClient_GetDomainID(t *testing.T) { func TestClient_CreateRecord(t *testing.T) { testCases := []struct { - desc string - handler http.HandlerFunc - expectError bool + desc string + handler http.HandlerFunc + assert require.ErrorAssertionFunc }{ { desc: "success", handler: func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodPost { - http.Error(rw, fmt.Sprintf("%s: %s", http.StatusText(http.StatusMethodNotAllowed), req.Method), http.StatusMethodNotAllowed) + http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest) return } @@ -157,31 +137,34 @@ func TestClient_CreateRecord(t *testing.T) { http.Error(rw, err.Error(), http.StatusBadRequest) return } - defer req.Body.Close() + defer func() { _ = req.Body.Close() }() - if string(raw) != `{"name":"lego.com.","type":"TXT","data":"txtTXTtxt","ttl":300}` { + if string(bytes.TrimSpace(raw)) != `{"name":"lego.com.","type":"TXT","data":"txtTXTtxt","ttl":300}` { http.Error(rw, fmt.Sprintf("invalid request body: %s", string(raw)), http.StatusBadRequest) return } + + writeFixture(rw, "domains-records_POST.json") }, + assert: require.NoError, }, { desc: "bad request", handler: func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodPost { - http.Error(rw, fmt.Sprintf("%s: %s", http.StatusText(http.StatusMethodNotAllowed), req.Method), http.StatusMethodNotAllowed) + http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest) return } http.Error(rw, "OOPS", http.StatusBadRequest) }, - expectError: true, + assert: require.Error, }, } for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.Handle("/v1/domains/lego/records", test.handler) @@ -194,13 +177,36 @@ func TestClient_CreateRecord(t *testing.T) { TTL: 300, } - err := client.CreateRecord(domainID, record) - - if test.expectError { - require.Error(t, err) - } else { - require.NoError(t, err) - } + err := client.CreateRecord(context.Background(), domainID, record) + test.assert(t, err) }) } } + +func TestClient_GetRecordID(t *testing.T) { + client, mux := setupTest(t) + + mux.HandleFunc("/v1/domains/89acac79-38e7-497d-807c-a011e1310438/records", + writeFixtureHandler(http.MethodGet, "domains-records_GET.json")) + + recordID, err := client.GetRecordID(context.Background(), "89acac79-38e7-497d-807c-a011e1310438", "www.example.com.", "A", "15.185.172.153") + require.NoError(t, err) + + assert.Equal(t, "2e32e609-3a4f-45ba-bdef-e50eacd345ad", recordID) +} + +func TestClient_DeleteRecord(t *testing.T) { + client, mux := setupTest(t) + + mux.HandleFunc("/v1/domains/89acac79-38e7-497d-807c-a011e1310438/records/2e32e609-3a4f-45ba-bdef-e50eacd345ad", func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodDelete { + http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest) + return + } + + rw.WriteHeader(http.StatusOK) + }) + + err := client.DeleteRecord(context.Background(), "89acac79-38e7-497d-807c-a011e1310438", "2e32e609-3a4f-45ba-bdef-e50eacd345ad") + require.NoError(t, err) +} diff --git a/providers/dns/conoha/internal/fixtures/domains-records_GET.json b/providers/dns/conoha/internal/fixtures/domains-records_GET.json new file mode 100644 index 00000000..6b8ca263 --- /dev/null +++ b/providers/dns/conoha/internal/fixtures/domains-records_GET.json @@ -0,0 +1,43 @@ +{ + "records": [ + { + "id": "2e32e609-3a4f-45ba-bdef-e50eacd345ad", + "name": "www.example.com.", + "type": "A", + "ttl": 3600, + "created_at": "2012-11-02T19:56:26.000000", + "updated_at": "2012-11-04T13:22:36.000000", + "data": "15.185.172.153", + "domain_id": "89acac79-38e7-497d-807c-a011e1310438", + "version": 1, + "gslb_region": "JP", + "gslb_weight": 250, + "gslb_check": 12300 + }, + { + "id": "8e9ecf3e-fb92-4a3a-a8ae-7596f167bea3", + "name": "host1.example.com.", + "type": "A", + "ttl": 3600, + "created_at": "2012-11-04T13:57:50.000000", + "updated_at": null, + "data": "15.185.172.154", + "domain_id": "89acac79-38e7-497d-807c-a011e1310438", + "version": 1, + "gslb_region": "US", + "gslb_weight": 220, + "gslb_check": 12200 + }, + { + "id": "4ad19089-3e62-40f8-9482-17cc8ccb92cb", + "name": "web.example.com.", + "type": "CNAME", + "ttl": 3600, + "created_at": "2012-11-04T13:58:16.393735", + "updated_at": null, + "data": "www.example.com.", + "domain_id": "89acac79-38e7-497d-807c-a011e1310438", + "version": 1 + } + ] +} diff --git a/providers/dns/conoha/internal/fixtures/domains-records_POST.json b/providers/dns/conoha/internal/fixtures/domains-records_POST.json new file mode 100644 index 00000000..832d7b22 --- /dev/null +++ b/providers/dns/conoha/internal/fixtures/domains-records_POST.json @@ -0,0 +1,13 @@ +{ + "id": "2e32e609-3a4f-45ba-bdef-e50eacd345ad", + "name": "www.example.com.", + "type": "A", + "created_at": "2012-11-02T19:56:26.366792", + "updated_at": null, + "domain_id": "89acac79-38e7-497d-807c-a011e1310438", + "ttl": null, + "data": "192.0.2.3", + "gslb_check": 1, + "gslb_region": "JP", + "gslb_weight": 250 +} diff --git a/providers/dns/conoha/internal/fixtures/domains_GET.json b/providers/dns/conoha/internal/fixtures/domains_GET.json new file mode 100644 index 00000000..bafc4585 --- /dev/null +++ b/providers/dns/conoha/internal/fixtures/domains_GET.json @@ -0,0 +1,26 @@ +{ + "domains":[ + { + "id": "09494b72-b65b-4297-9efb-187f65a0553e", + "name": "domain1.com.", + "ttl": 3600, + "serial": 1351800668, + "email": "nsadmin@example.org", + "gslb": 0, + "created_at": "2012-11-01T20:11:08.000000", + "updated_at": null, + "description": "memo" + }, + { + "id": "cf661142-e577-40b5-b3eb-75795cdc0cd7", + "name": "domain2.com.", + "ttl": 7200, + "serial": 1351800670, + "email": "nsadmin2@example.org", + "gslb": 1, + "created_at": "2012-11-01T20:11:08.000000", + "updated_at": "2012-12-01T20:11:08.000000", + "description": "memomemo" + } + ] +} diff --git a/providers/dns/conoha/internal/fixtures/tokens_POST.json b/providers/dns/conoha/internal/fixtures/tokens_POST.json new file mode 100644 index 00000000..ac917186 --- /dev/null +++ b/providers/dns/conoha/internal/fixtures/tokens_POST.json @@ -0,0 +1,17 @@ +{ + "access": { + "token": { + "issued_at": "2015-05-19T07:08:21.927295", + "expires": "2015-05-20T07:08:21Z", + "id": "sample00d88246078f2bexample788f7", + "tenant": { + "name": "example00000000", + "enabled": true, + "tyo1_image_size": "550GB" + }, + "endpoints_links": [], + "type": "mailhosting", + "name": "Mail Hosting Service" + } + } +} diff --git a/providers/dns/conoha/internal/identity.go b/providers/dns/conoha/internal/identity.go new file mode 100644 index 00000000..995d55bb --- /dev/null +++ b/providers/dns/conoha/internal/identity.go @@ -0,0 +1,82 @@ +package internal + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" +) + +const identityBaseURL = "https://identity.%s.conoha.io" + +type Identifier struct { + baseURL *url.URL + HTTPClient *http.Client +} + +// NewIdentifier creates a new Identifier. +func NewIdentifier(region string) (*Identifier, error) { + baseURL, err := url.Parse(fmt.Sprintf(identityBaseURL, region)) + if err != nil { + return nil, err + } + + return &Identifier{ + baseURL: baseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + }, nil +} + +// GetToken gets valid token information. +// https://www.conoha.jp/docs/identity-post_tokens.php +func (c *Identifier) GetToken(ctx context.Context, auth Auth) (*IdentityResponse, error) { + endpoint := c.baseURL.JoinPath("v2.0", "tokens") + + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, &IdentityRequest{Auth: auth}) + if err != nil { + return nil, err + } + + identity := &IdentityResponse{} + + err = c.do(req, identity) + if err != nil { + return nil, err + } + + return identity, nil +} + +func (c *Identifier) do(req *http.Request, result any) error { + resp, err := c.HTTPClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + if result == nil { + return nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return nil +} diff --git a/providers/dns/conoha/internal/identity_test.go b/providers/dns/conoha/internal/identity_test.go new file mode 100644 index 00000000..027c7f2c --- /dev/null +++ b/providers/dns/conoha/internal/identity_test.go @@ -0,0 +1,41 @@ +package internal + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewClient(t *testing.T) { + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + identifier, err := NewIdentifier("tyo1") + require.NoError(t, err) + + identifier.HTTPClient = server.Client() + identifier.baseURL, _ = url.Parse(server.URL) + + mux.HandleFunc("/v2.0/tokens", writeFixtureHandler(http.MethodPost, "tokens_POST.json")) + + auth := Auth{ + TenantID: "487727e3921d44e3bfe7ebb337bf085e", + PasswordCredentials: PasswordCredentials{ + Username: "ConoHa", + Password: "paSSword123456#$%", + }, + } + + token, err := identifier.GetToken(context.Background(), auth) + require.NoError(t, err) + + expected := &IdentityResponse{Access: Access{Token: Token{ID: "sample00d88246078f2bexample788f7"}}} + + assert.Equal(t, expected, token) +} diff --git a/providers/dns/conoha/internal/types.go b/providers/dns/conoha/internal/types.go new file mode 100644 index 00000000..7749aded --- /dev/null +++ b/providers/dns/conoha/internal/types.go @@ -0,0 +1,58 @@ +package internal + +// IdentityRequest is an authentication request body. +type IdentityRequest struct { + Auth Auth `json:"auth"` +} + +// Auth is an authentication information. +type Auth struct { + TenantID string `json:"tenantId"` + PasswordCredentials PasswordCredentials `json:"passwordCredentials"` +} + +// PasswordCredentials is API-user's credentials. +type PasswordCredentials struct { + Username string `json:"username"` + Password string `json:"password"` +} + +// IdentityResponse is an authentication response body. +type IdentityResponse struct { + Access Access `json:"access"` +} + +// Access is an identity information. +type Access struct { + Token Token `json:"token"` +} + +// Token is an api access token. +type Token struct { + ID string `json:"id"` +} + +// DomainListResponse is a response of a domain listing request. +type DomainListResponse struct { + Domains []Domain `json:"domains"` +} + +// Domain is a hosted domain entry. +type Domain struct { + ID string `json:"id"` + Name string `json:"name"` +} + +// RecordListResponse is a response of record listing request. +type RecordListResponse struct { + Records []Record `json:"records"` +} + +// Record is a record entry. +type Record struct { + ID string `json:"id,omitempty"` + Name string `json:"name"` + Type string `json:"type"` + Data string `json:"data"` + TTL int `json:"ttl"` +} diff --git a/providers/dns/constellix/constellix.go b/providers/dns/constellix/constellix.go index 6f43b531..17ca1ab6 100644 --- a/providers/dns/constellix/constellix.go +++ b/providers/dns/constellix/constellix.go @@ -2,6 +2,7 @@ package constellix import ( + "context" "errors" "fmt" "net/http" @@ -101,10 +102,12 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("constellix: could not find zone for domain %q and fqdn %q : %w", domain, info.EffectiveFQDN, err) + return fmt.Errorf("constellix: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } - dom, err := d.client.Domains.GetByName(dns01.UnFqdn(authZone)) + ctx := context.Background() + + dom, err := d.client.Domains.GetByName(ctx, dns01.UnFqdn(authZone)) if err != nil { return fmt.Errorf("constellix: failed to get domain (%s): %w", authZone, err) } @@ -114,7 +117,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { return fmt.Errorf("constellix: %w", err) } - records, err := d.client.TxtRecords.Search(dom.ID, internal.Exact, recordName) + records, err := d.client.TxtRecords.Search(ctx, dom.ID, internal.Exact, recordName) if err != nil { return fmt.Errorf("constellix: failed to search TXT records: %w", err) } @@ -125,10 +128,10 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { // TXT record entry already existing if len(records) == 1 { - return d.appendRecordValue(dom, records[0].ID, info.Value) + return d.appendRecordValue(ctx, dom, records[0].ID, info.Value) } - err = d.createRecord(dom, info.EffectiveFQDN, recordName, info.Value) + err = d.createRecord(ctx, dom, info.EffectiveFQDN, recordName, info.Value) if err != nil { return fmt.Errorf("constellix: %w", err) } @@ -142,10 +145,12 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("constellix: could not find zone for domain %q and fqdn %q : %w", domain, info.EffectiveFQDN, err) + return fmt.Errorf("constellix: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } - dom, err := d.client.Domains.GetByName(dns01.UnFqdn(authZone)) + ctx := context.Background() + + dom, err := d.client.Domains.GetByName(ctx, dns01.UnFqdn(authZone)) if err != nil { return fmt.Errorf("constellix: failed to get domain (%s): %w", authZone, err) } @@ -155,7 +160,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return fmt.Errorf("constellix: %w", err) } - records, err := d.client.TxtRecords.Search(dom.ID, internal.Exact, recordName) + records, err := d.client.TxtRecords.Search(ctx, dom.ID, internal.Exact, recordName) if err != nil { return fmt.Errorf("constellix: failed to search TXT records: %w", err) } @@ -168,7 +173,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return nil } - record, err := d.client.TxtRecords.Get(dom.ID, records[0].ID) + record, err := d.client.TxtRecords.Get(ctx, dom.ID, records[0].ID) if err != nil { return fmt.Errorf("constellix: failed to get TXT records: %w", err) } @@ -179,14 +184,14 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { // only 1 record value, the whole record must be deleted. if len(record.Value) == 1 { - _, err = d.client.TxtRecords.Delete(dom.ID, record.ID) + _, err = d.client.TxtRecords.Delete(ctx, dom.ID, record.ID) if err != nil { return fmt.Errorf("constellix: failed to delete TXT records: %w", err) } return nil } - err = d.removeRecordValue(dom, record, info.Value) + err = d.removeRecordValue(ctx, dom, record, info.Value) if err != nil { return fmt.Errorf("constellix: %w", err) } @@ -194,7 +199,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return nil } -func (d *DNSProvider) createRecord(dom internal.Domain, fqdn, recordName, value string) error { +func (d *DNSProvider) createRecord(ctx context.Context, dom internal.Domain, fqdn, recordName, value string) error { request := internal.RecordRequest{ Name: recordName, TTL: d.config.TTL, @@ -203,7 +208,7 @@ func (d *DNSProvider) createRecord(dom internal.Domain, fqdn, recordName, value }, } - _, err := d.client.TxtRecords.Create(dom.ID, request) + _, err := d.client.TxtRecords.Create(ctx, dom.ID, request) if err != nil { return fmt.Errorf("failed to create TXT record %s: %w", fqdn, err) } @@ -211,8 +216,8 @@ func (d *DNSProvider) createRecord(dom internal.Domain, fqdn, recordName, value return nil } -func (d *DNSProvider) appendRecordValue(dom internal.Domain, recordID int64, value string) error { - record, err := d.client.TxtRecords.Get(dom.ID, recordID) +func (d *DNSProvider) appendRecordValue(ctx context.Context, dom internal.Domain, recordID int64, value string) error { + record, err := d.client.TxtRecords.Get(ctx, dom.ID, recordID) if err != nil { return fmt.Errorf("failed to get TXT records: %w", err) } @@ -227,7 +232,7 @@ func (d *DNSProvider) appendRecordValue(dom internal.Domain, recordID int64, val RoundRobin: append(record.RoundRobin, internal.RecordValue{Value: fmt.Sprintf(`%q`, value)}), } - _, err = d.client.TxtRecords.Update(dom.ID, record.ID, request) + _, err = d.client.TxtRecords.Update(ctx, dom.ID, record.ID, request) if err != nil { return fmt.Errorf("failed to update TXT records: %w", err) } @@ -235,7 +240,7 @@ func (d *DNSProvider) appendRecordValue(dom internal.Domain, recordID int64, val return nil } -func (d *DNSProvider) removeRecordValue(dom internal.Domain, record *internal.Record, value string) error { +func (d *DNSProvider) removeRecordValue(ctx context.Context, dom internal.Domain, record *internal.Record, value string) error { request := internal.RecordRequest{ Name: record.Name, TTL: record.TTL, @@ -247,7 +252,7 @@ func (d *DNSProvider) removeRecordValue(dom internal.Domain, record *internal.Re } } - _, err := d.client.TxtRecords.Update(dom.ID, record.ID, request) + _, err := d.client.TxtRecords.Update(ctx, dom.ID, record.ID, request) if err != nil { return fmt.Errorf("failed to update TXT records: %w", err) } diff --git a/providers/dns/constellix/internal/client.go b/providers/dns/constellix/internal/client.go index af399fa3..fee0c5a3 100644 --- a/providers/dns/constellix/internal/client.go +++ b/providers/dns/constellix/internal/client.go @@ -6,6 +6,9 @@ import ( "io" "net/http" "net/url" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) const ( @@ -28,7 +31,7 @@ type Client struct { // NewClient Creates a Constellix client. func NewClient(httpClient *http.Client) *Client { if httpClient == nil { - httpClient = http.DefaultClient + httpClient = &http.Client{Timeout: 5 * time.Second} } client := &Client{ @@ -48,13 +51,15 @@ type service struct { } // do sends an API request and returns the API response. -func (c *Client) do(req *http.Request, v interface{}) error { +func (c *Client) do(req *http.Request, result any) error { + req.Header.Set("Accept", "application/json") req.Header.Set("Content-Type", "application/json") resp, err := c.HTTPClient.Do(req) if err != nil { - return err + return errutils.NewHTTPDoError(req, err) } + defer func() { _ = resp.Body.Close() }() err = checkResponse(resp) @@ -64,11 +69,11 @@ func (c *Client) do(req *http.Request, v interface{}) error { raw, err := io.ReadAll(resp.Body) if err != nil { - return fmt.Errorf("failed to read body: %w", err) + return errutils.NewReadResponseError(req, resp.StatusCode, err) } - if err = json.Unmarshal(raw, v); err != nil { - return fmt.Errorf("unmarshaling %T error: %w: %s", v, err, string(raw)) + if err = json.Unmarshal(raw, result); err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) } return nil @@ -83,21 +88,21 @@ func checkResponse(resp *http.Response) error { return nil } - data, err := io.ReadAll(resp.Body) - if err == nil && data != nil { - msg := &APIError{StatusCode: resp.StatusCode} + raw, err := io.ReadAll(resp.Body) + if err == nil && raw != nil { + errAPI := &APIError{StatusCode: resp.StatusCode} - if json.Unmarshal(data, msg) != nil { - return fmt.Errorf("API error: status code: %d: %v", resp.StatusCode, string(data)) + if json.Unmarshal(raw, errAPI) != nil { + return fmt.Errorf("API error: status code: %d: %v", resp.StatusCode, string(raw)) } switch resp.StatusCode { case http.StatusNotFound: - return &NotFound{APIError: msg} + return &NotFound{APIError: errAPI} case http.StatusBadRequest: - return &BadRequest{APIError: msg} + return &BadRequest{APIError: errAPI} default: - return msg + return errAPI } } diff --git a/providers/dns/constellix/internal/domains.go b/providers/dns/constellix/internal/domains.go index c6e2480d..485f0d53 100644 --- a/providers/dns/constellix/internal/domains.go +++ b/providers/dns/constellix/internal/domains.go @@ -1,6 +1,7 @@ package internal import ( + "context" "errors" "fmt" "net/http" @@ -13,15 +14,15 @@ type DomainService service // GetAll domains. // https://api-docs.constellix.com/?version=latest#484c3f21-d724-4ee4-a6fa-ab22c8eb9e9b -func (s *DomainService) GetAll(params *PaginationParameters) ([]Domain, error) { +func (s *DomainService) GetAll(ctx context.Context, params *PaginationParameters) ([]Domain, error) { endpoint, err := s.client.createEndpoint(defaultVersion, "domains") if err != nil { return nil, fmt.Errorf("failed to create request endpoint: %w", err) } - req, err := http.NewRequest(http.MethodGet, endpoint, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) + return nil, fmt.Errorf("unable to create request: %w", err) } if params != nil { @@ -42,8 +43,8 @@ func (s *DomainService) GetAll(params *PaginationParameters) ([]Domain, error) { } // GetByName Gets domain by name. -func (s *DomainService) GetByName(domainName string) (Domain, error) { - domains, err := s.Search(Exact, domainName) +func (s *DomainService) GetByName(ctx context.Context, domainName string) (Domain, error) { + domains, err := s.Search(ctx, Exact, domainName) if err != nil { return Domain{}, err } @@ -61,15 +62,15 @@ func (s *DomainService) GetByName(domainName string) (Domain, error) { // Search searches for a domain by name. // https://api-docs.constellix.com/?version=latest#3d7b2679-2209-49f3-b011-b7d24e512008 -func (s *DomainService) Search(filter searchFilter, value string) ([]Domain, error) { +func (s *DomainService) Search(ctx context.Context, filter searchFilter, value string) ([]Domain, error) { endpoint, err := s.client.createEndpoint(defaultVersion, "domains", "search") if err != nil { return nil, fmt.Errorf("failed to create request endpoint: %w", err) } - req, err := http.NewRequest(http.MethodGet, endpoint, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) + return nil, fmt.Errorf("unable to create request: %w", err) } query := req.URL.Query() diff --git a/providers/dns/constellix/internal/domains_test.go b/providers/dns/constellix/internal/domains_test.go index 5df3f423..1b0779b3 100644 --- a/providers/dns/constellix/internal/domains_test.go +++ b/providers/dns/constellix/internal/domains_test.go @@ -1,6 +1,7 @@ package internal import ( + "context" "io" "net/http" "net/http/httptest" @@ -47,7 +48,7 @@ func TestDomainService_GetAll(t *testing.T) { } }) - data, err := client.Domains.GetAll(nil) + data, err := client.Domains.GetAll(context.Background(), nil) require.NoError(t, err) expected := []Domain{ @@ -83,7 +84,7 @@ func TestDomainService_Search(t *testing.T) { } }) - data, err := client.Domains.Search(Exact, "lego.wtf") + data, err := client.Domains.Search(context.Background(), Exact, "lego.wtf") require.NoError(t, err) expected := []Domain{ diff --git a/providers/dns/constellix/internal/txtrecords.go b/providers/dns/constellix/internal/txtrecords.go index e9df28e6..7880da4d 100644 --- a/providers/dns/constellix/internal/txtrecords.go +++ b/providers/dns/constellix/internal/txtrecords.go @@ -2,6 +2,7 @@ package internal import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -14,20 +15,20 @@ type TxtRecordService service // Create a TXT record. // https://api-docs.constellix.com/?version=latest#22e24d5b-9ec0-49a7-b2b0-5ff0a28e71be -func (s *TxtRecordService) Create(domainID int64, record RecordRequest) ([]Record, error) { - body, err := json.Marshal(record) - if err != nil { - return nil, fmt.Errorf("failed to marshall request body: %w", err) - } - +func (s *TxtRecordService) Create(ctx context.Context, domainID int64, record RecordRequest) ([]Record, error) { endpoint, err := s.client.createEndpoint(defaultVersion, "domains", strconv.FormatInt(domainID, 10), "records", "txt") if err != nil { return nil, fmt.Errorf("failed to create request endpoint: %w", err) } - req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(body)) + body, err := json.Marshal(record) if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) } var records []Record @@ -41,15 +42,15 @@ func (s *TxtRecordService) Create(domainID int64, record RecordRequest) ([]Recor // GetAll TXT records. // https://api-docs.constellix.com/?version=latest#e7103c53-2ad8-4bc8-b5b3-4c22c4b571b2 -func (s *TxtRecordService) GetAll(domainID int64) ([]Record, error) { +func (s *TxtRecordService) GetAll(ctx context.Context, domainID int64) ([]Record, error) { endpoint, err := s.client.createEndpoint(defaultVersion, "domains", strconv.FormatInt(domainID, 10), "records", "txt") if err != nil { - return nil, fmt.Errorf("failed to create request endpoint: %w", err) + return nil, fmt.Errorf("failed to create endpoint: %w", err) } - req, err := http.NewRequest(http.MethodGet, endpoint, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) + return nil, fmt.Errorf("unable to create request: %w", err) } var records []Record @@ -63,15 +64,15 @@ func (s *TxtRecordService) GetAll(domainID int64) ([]Record, error) { // Get a TXT record. // https://api-docs.constellix.com/?version=latest#e7103c53-2ad8-4bc8-b5b3-4c22c4b571b2 -func (s *TxtRecordService) Get(domainID, recordID int64) (*Record, error) { +func (s *TxtRecordService) Get(ctx context.Context, domainID, recordID int64) (*Record, error) { endpoint, err := s.client.createEndpoint(defaultVersion, "domains", strconv.FormatInt(domainID, 10), "records", "txt", strconv.FormatInt(recordID, 10)) if err != nil { return nil, fmt.Errorf("failed to create request endpoint: %w", err) } - req, err := http.NewRequest(http.MethodGet, endpoint, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) + return nil, fmt.Errorf("unable to create request: %w", err) } var records Record @@ -85,20 +86,20 @@ func (s *TxtRecordService) Get(domainID, recordID int64) (*Record, error) { // Update a TXT record. // https://api-docs.constellix.com/?version=latest#d4e9ab2e-fac0-45a6-b0e4-cf62a2d2e3da -func (s *TxtRecordService) Update(domainID, recordID int64, record RecordRequest) (*SuccessMessage, error) { - body, err := json.Marshal(record) - if err != nil { - return nil, fmt.Errorf("failed to marshall request body: %w", err) - } - +func (s *TxtRecordService) Update(ctx context.Context, domainID, recordID int64, record RecordRequest) (*SuccessMessage, error) { endpoint, err := s.client.createEndpoint(defaultVersion, "domains", strconv.FormatInt(domainID, 10), "records", "txt", strconv.FormatInt(recordID, 10)) if err != nil { return nil, fmt.Errorf("failed to create request endpoint: %w", err) } - req, err := http.NewRequest(http.MethodPut, endpoint, bytes.NewReader(body)) + body, err := json.Marshal(record) if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPut, endpoint, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) } var msg SuccessMessage @@ -112,15 +113,15 @@ func (s *TxtRecordService) Update(domainID, recordID int64, record RecordRequest // Delete a TXT record. // https://api-docs.constellix.com/?version=latest#135947f7-d6c8-481a-83c7-4d387b0bdf9e -func (s *TxtRecordService) Delete(domainID, recordID int64) (*SuccessMessage, error) { +func (s *TxtRecordService) Delete(ctx context.Context, domainID, recordID int64) (*SuccessMessage, error) { endpoint, err := s.client.createEndpoint(defaultVersion, "domains", strconv.FormatInt(domainID, 10), "records", "txt", strconv.FormatInt(recordID, 10)) if err != nil { return nil, fmt.Errorf("failed to create request endpoint: %w", err) } - req, err := http.NewRequest(http.MethodDelete, endpoint, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, endpoint, nil) if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) + return nil, fmt.Errorf("unable to create request: %w", err) } var msg *SuccessMessage @@ -134,15 +135,15 @@ func (s *TxtRecordService) Delete(domainID, recordID int64) (*SuccessMessage, er // Search searches for a TXT record by name. // https://api-docs.constellix.com/?version=latest#81003e4f-bd3f-413f-a18d-6d9d18f10201 -func (s *TxtRecordService) Search(domainID int64, filter searchFilter, value string) ([]Record, error) { +func (s *TxtRecordService) Search(ctx context.Context, domainID int64, filter searchFilter, value string) ([]Record, error) { endpoint, err := s.client.createEndpoint(defaultVersion, "domains", strconv.FormatInt(domainID, 10), "records", "txt", "search") if err != nil { return nil, fmt.Errorf("failed to create request endpoint: %w", err) } - req, err := http.NewRequest(http.MethodGet, endpoint, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) + return nil, fmt.Errorf("unable to create request: %w", err) } query := req.URL.Query() diff --git a/providers/dns/constellix/internal/txtrecords_test.go b/providers/dns/constellix/internal/txtrecords_test.go index e0c4de6d..7adc4af5 100644 --- a/providers/dns/constellix/internal/txtrecords_test.go +++ b/providers/dns/constellix/internal/txtrecords_test.go @@ -1,6 +1,7 @@ package internal import ( + "context" "encoding/json" "io" "net/http" @@ -34,7 +35,7 @@ func TestTxtRecordService_Create(t *testing.T) { } }) - records, err := client.TxtRecords.Create(12345, RecordRequest{}) + records, err := client.TxtRecords.Create(context.Background(), 12345, RecordRequest{}) require.NoError(t, err) recordsJSON, err := json.Marshal(records) @@ -69,7 +70,7 @@ func TestTxtRecordService_GetAll(t *testing.T) { } }) - records, err := client.TxtRecords.GetAll(12345) + records, err := client.TxtRecords.GetAll(context.Background(), 12345) require.NoError(t, err) recordsJSON, err := json.Marshal(records) @@ -104,7 +105,7 @@ func TestTxtRecordService_Get(t *testing.T) { } }) - record, err := client.TxtRecords.Get(12345, 6789) + record, err := client.TxtRecords.Get(context.Background(), 12345, 6789) require.NoError(t, err) expected := &Record{ @@ -145,7 +146,7 @@ func TestTxtRecordService_Update(t *testing.T) { } }) - msg, err := client.TxtRecords.Update(12345, 6789, RecordRequest{}) + msg, err := client.TxtRecords.Update(context.Background(), 12345, 6789, RecordRequest{}) require.NoError(t, err) expected := &SuccessMessage{Success: "Record updated successfully"} @@ -168,7 +169,7 @@ func TestTxtRecordService_Delete(t *testing.T) { } }) - msg, err := client.TxtRecords.Delete(12345, 6789) + msg, err := client.TxtRecords.Delete(context.Background(), 12345, 6789) require.NoError(t, err) expected := &SuccessMessage{Success: "Record deleted successfully"} @@ -198,7 +199,7 @@ func TestTxtRecordService_Search(t *testing.T) { } }) - records, err := client.TxtRecords.Search(12345, Exact, "test") + records, err := client.TxtRecords.Search(context.Background(), 12345, Exact, "test") require.NoError(t, err) recordsJSON, err := json.Marshal(records) diff --git a/providers/dns/constellix/internal/model.go b/providers/dns/constellix/internal/types.go similarity index 100% rename from providers/dns/constellix/internal/model.go rename to providers/dns/constellix/internal/types.go diff --git a/providers/dns/desec/desec.go b/providers/dns/desec/desec.go index e391fd38..1c88ad08 100644 --- a/providers/dns/desec/desec.go +++ b/providers/dns/desec/desec.go @@ -106,7 +106,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("desec: could not find zone for domain %q and fqdn %q : %w", domain, info.EffectiveFQDN, err) + return fmt.Errorf("desec: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } recordName, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) @@ -156,7 +156,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("desec: could not find zone for domain %q and fqdn %q : %w", domain, info.EffectiveFQDN, err) + return fmt.Errorf("desec: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } recordName, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) diff --git a/providers/dns/designate/designate.go b/providers/dns/designate/designate.go index a8502819..da3b6f78 100644 --- a/providers/dns/designate/designate.go +++ b/providers/dns/designate/designate.go @@ -128,12 +128,12 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("designate: couldn't get zone ID in Present: %w", err) + return fmt.Errorf("designate: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } zoneID, err := d.getZoneID(authZone) if err != nil { - return fmt.Errorf("designate: %w", err) + return fmt.Errorf("designate: couldn't get zone ID in Present: %w", err) } // use mutex to prevent race condition between creating the record and verifying it @@ -168,7 +168,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return err + return fmt.Errorf("designate: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } zoneID, err := d.getZoneID(authZone) diff --git a/providers/dns/designate/designate_test.go b/providers/dns/designate/designate_test.go index f80fee1c..881faeef 100644 --- a/providers/dns/designate/designate_test.go +++ b/providers/dns/designate/designate_test.go @@ -286,6 +286,9 @@ func setupTestProvider(t *testing.T) string { t.Helper() mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte(`{ "access": { @@ -319,9 +322,6 @@ func setupTestProvider(t *testing.T) string { w.WriteHeader(http.StatusOK) }) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - return server.URL } diff --git a/providers/dns/digitalocean/client.go b/providers/dns/digitalocean/client.go deleted file mode 100644 index 82580e78..00000000 --- a/providers/dns/digitalocean/client.go +++ /dev/null @@ -1,131 +0,0 @@ -package digitalocean - -import ( - "bytes" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - - "github.com/go-acme/lego/v4/challenge/dns01" -) - -const defaultBaseURL = "https://api.digitalocean.com" - -// txtRecordResponse represents a response from DO's API after making a TXT record. -type txtRecordResponse struct { - DomainRecord record `json:"domain_record"` -} - -type record struct { - ID int `json:"id,omitempty"` - Type string `json:"type,omitempty"` - Name string `json:"name,omitempty"` - Data string `json:"data,omitempty"` - TTL int `json:"ttl,omitempty"` -} - -type apiError struct { - ID string `json:"id"` - Message string `json:"message"` -} - -func (d *DNSProvider) removeTxtRecord(domain string, recordID int) error { - authZone, err := dns01.FindZoneByFqdn(dns01.ToFqdn(domain)) - if err != nil { - return fmt.Errorf("could not determine zone for domain %q: %w", domain, err) - } - - reqURL := fmt.Sprintf("%s/v2/domains/%s/records/%d", d.config.BaseURL, dns01.UnFqdn(authZone), recordID) - req, err := d.newRequest(http.MethodDelete, reqURL, nil) - if err != nil { - return err - } - - resp, err := d.config.HTTPClient.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - - if resp.StatusCode >= http.StatusBadRequest { - return readError(req, resp) - } - - return nil -} - -func (d *DNSProvider) addTxtRecord(fqdn, value string) (*txtRecordResponse, error) { - authZone, err := dns01.FindZoneByFqdn(dns01.ToFqdn(fqdn)) - if err != nil { - return nil, fmt.Errorf("could not determine zone for domain %q: %w", fqdn, err) - } - - reqData := record{Type: "TXT", Name: fqdn, Data: value, TTL: d.config.TTL} - body, err := json.Marshal(reqData) - if err != nil { - return nil, err - } - - reqURL := fmt.Sprintf("%s/v2/domains/%s/records", d.config.BaseURL, dns01.UnFqdn(authZone)) - req, err := d.newRequest(http.MethodPost, reqURL, bytes.NewReader(body)) - if err != nil { - return nil, err - } - - resp, err := d.config.HTTPClient.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - if resp.StatusCode >= http.StatusBadRequest { - return nil, readError(req, resp) - } - - content, err := io.ReadAll(resp.Body) - if err != nil { - return nil, errors.New(toUnreadableBodyMessage(req, content)) - } - - // Everything looks good; but we'll need the ID later to delete the record - respData := &txtRecordResponse{} - err = json.Unmarshal(content, respData) - if err != nil { - return nil, fmt.Errorf("%w: %s", err, toUnreadableBodyMessage(req, content)) - } - - return respData, nil -} - -func (d *DNSProvider) newRequest(method, reqURL string, body io.Reader) (*http.Request, error) { - req, err := http.NewRequest(method, reqURL, body) - if err != nil { - return nil, err - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", d.config.AuthToken)) - - return req, nil -} - -func readError(req *http.Request, resp *http.Response) error { - content, err := io.ReadAll(resp.Body) - if err != nil { - return errors.New(toUnreadableBodyMessage(req, content)) - } - - var errInfo apiError - err = json.Unmarshal(content, &errInfo) - if err != nil { - return fmt.Errorf("apiError unmarshaling error: %w: %s", err, toUnreadableBodyMessage(req, content)) - } - - return fmt.Errorf("HTTP %d: %s: %s", resp.StatusCode, errInfo.ID, errInfo.Message) -} - -func toUnreadableBodyMessage(req *http.Request, rawBody []byte) string { - return fmt.Sprintf("the request %s sent a response with a body which is an invalid format: %q", req.URL, string(rawBody)) -} diff --git a/providers/dns/digitalocean/digitalocean.go b/providers/dns/digitalocean/digitalocean.go index df27244d..dd790faa 100644 --- a/providers/dns/digitalocean/digitalocean.go +++ b/providers/dns/digitalocean/digitalocean.go @@ -2,14 +2,17 @@ package digitalocean import ( + "context" "errors" "fmt" "net/http" + "net/url" "sync" "time" "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/platform/config/env" + "github.com/go-acme/lego/v4/providers/dns/digitalocean/internal" ) // Environment variables names. @@ -38,7 +41,7 @@ type Config struct { // NewDefaultConfig returns a default configuration for the DNSProvider. func NewDefaultConfig() *Config { return &Config{ - BaseURL: env.GetOrDefaultString(EnvAPIUrl, defaultBaseURL), + BaseURL: env.GetOrDefaultString(EnvAPIUrl, internal.DefaultBaseURL), TTL: env.GetOrDefaultInt(EnvTTL, 30), PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, 60*time.Second), PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, 5*time.Second), @@ -50,7 +53,9 @@ func NewDefaultConfig() *Config { // DNSProvider implements the challenge.Provider interface. type DNSProvider struct { - config *Config + config *Config + client *internal.Client + recordIDs map[string]int recordIDsMu sync.Mutex } @@ -80,12 +85,19 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("digitalocean: credentials missing") } - if config.BaseURL == "" { - config.BaseURL = defaultBaseURL + client := internal.NewClient(internal.OAuthStaticAccessToken(config.HTTPClient, config.AuthToken)) + + if config.BaseURL != "" { + var err error + client.BaseURL, err = url.Parse(config.BaseURL) + if err != nil { + return nil, fmt.Errorf("digitalocean: %w", err) + } } return &DNSProvider{ config: config, + client: client, recordIDs: make(map[string]int), }, nil } @@ -100,7 +112,14 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - respData, err := d.addTxtRecord(info.EffectiveFQDN, info.Value) + authZone, err := dns01.FindZoneByFqdn(dns01.ToFqdn(info.EffectiveFQDN)) + if err != nil { + return fmt.Errorf("designate: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) + } + + record := internal.Record{Type: "TXT", Name: info.EffectiveFQDN, Data: info.Value, TTL: d.config.TTL} + + respData, err := d.client.AddTxtRecord(context.Background(), authZone, record) if err != nil { return fmt.Errorf("digitalocean: %w", err) } @@ -118,7 +137,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("digitalocean: %w", err) + return fmt.Errorf("designate: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } // get the record's unique ID from when we created it @@ -129,7 +148,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return fmt.Errorf("digitalocean: unknown record ID for '%s'", info.EffectiveFQDN) } - err = d.removeTxtRecord(authZone, recordID) + err = d.client.RemoveTxtRecord(context.Background(), authZone, recordID) if err != nil { return fmt.Errorf("digitalocean: %w", err) } diff --git a/providers/dns/digitalocean/digitalocean_test.go b/providers/dns/digitalocean/digitalocean_test.go index 7cdc9638..bfd2d68c 100644 --- a/providers/dns/digitalocean/digitalocean_test.go +++ b/providers/dns/digitalocean/digitalocean_test.go @@ -1,6 +1,7 @@ package digitalocean import ( + "bytes" "fmt" "io" "net/http" @@ -115,6 +116,7 @@ func TestDNSProvider_Present(t *testing.T) { mux.HandleFunc("/v2/domains/example.com/records", func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, http.MethodPost, r.Method, "method") + assert.Equal(t, "application/json", r.Header.Get("Accept"), "Accept") assert.Equal(t, "application/json", r.Header.Get("Content-Type"), "Content-Type") assert.Equal(t, "Bearer asdf1234", r.Header.Get("Authorization"), "Authorization") @@ -125,7 +127,7 @@ func TestDNSProvider_Present(t *testing.T) { } expectedReqBody := `{"type":"TXT","name":"_acme-challenge.example.com.","data":"w6uP8Tcg6K2QR905Rms8iXTlksL6OD1KOWBxTK7wxPI","ttl":30}` - assert.Equal(t, expectedReqBody, string(reqBody)) + assert.Equal(t, expectedReqBody, string(bytes.TrimSpace(reqBody))) w.WriteHeader(http.StatusCreated) _, err = fmt.Fprintf(w, `{ @@ -157,7 +159,7 @@ func TestDNSProvider_CleanUp(t *testing.T) { assert.Equal(t, "/v2/domains/example.com/records/1234567", r.URL.Path, "Path") - // NOTE: Even though the body is empty, DigitalOcean API docs still show setting this Content-Type... + assert.Equal(t, "application/json", r.Header.Get("Accept"), "Accept") assert.Equal(t, "application/json", r.Header.Get("Content-Type"), "Content-Type") assert.Equal(t, "Bearer asdf1234", r.Header.Get("Authorization"), "Authorization") diff --git a/providers/dns/digitalocean/internal/client.go b/providers/dns/digitalocean/internal/client.go new file mode 100644 index 00000000..e7dd181b --- /dev/null +++ b/providers/dns/digitalocean/internal/client.go @@ -0,0 +1,142 @@ +package internal + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "time" + + "github.com/go-acme/lego/v4/challenge/dns01" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" + "golang.org/x/oauth2" +) + +// DefaultBaseURL default API endpoint. +const DefaultBaseURL = "https://api.digitalocean.com" + +// Client the Digital Ocean API client. +type Client struct { + BaseURL *url.URL + httpClient *http.Client +} + +// NewClient creates a new Client. +func NewClient(hc *http.Client) *Client { + baseURL, _ := url.Parse(DefaultBaseURL) + + if hc == nil { + hc = &http.Client{Timeout: 5 * time.Second} + } + + return &Client{BaseURL: baseURL, httpClient: hc} +} + +func (c *Client) AddTxtRecord(ctx context.Context, zone string, record Record) (*TxtRecordResponse, error) { + endpoint := c.BaseURL.JoinPath("v2", "domains", dns01.UnFqdn(zone), "records") + + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record) + if err != nil { + return nil, err + } + + respData := &TxtRecordResponse{} + err = c.do(req, respData) + if err != nil { + return nil, err + } + + return respData, nil +} + +func (c *Client) RemoveTxtRecord(ctx context.Context, zone string, recordID int) error { + endpoint := c.BaseURL.JoinPath("v2", "domains", dns01.UnFqdn(zone), "records", strconv.Itoa(recordID)) + + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) + if err != nil { + return err + } + + return c.do(req, nil) +} + +func (c *Client) do(req *http.Request, result any) error { + resp, err := c.httpClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= http.StatusBadRequest { + return parseError(req, resp) + } + + if result == nil { + return nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return nil +} + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + // NOTE: Even though the body is empty, DigitalOcean API docs still show setting this Content-Type... + req.Header.Set("Content-Type", "application/json") + + return req, nil +} + +func parseError(req *http.Request, resp *http.Response) error { + raw, _ := io.ReadAll(resp.Body) + + var errInfo APIError + err := json.Unmarshal(raw, &errInfo) + if err != nil { + return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw) + } + + return fmt.Errorf("[status code %d] %w", resp.StatusCode, errInfo) +} + +func OAuthStaticAccessToken(client *http.Client, accessToken string) *http.Client { + if client == nil { + client = &http.Client{Timeout: 5 * time.Second} + } + + client.Transport = &oauth2.Transport{ + Source: oauth2.StaticTokenSource(&oauth2.Token{AccessToken: accessToken}), + Base: client.Transport, + } + + return client +} diff --git a/providers/dns/digitalocean/internal/client_test.go b/providers/dns/digitalocean/internal/client_test.go new file mode 100644 index 00000000..081e1a10 --- /dev/null +++ b/providers/dns/digitalocean/internal/client_test.go @@ -0,0 +1,139 @@ +package internal + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setupTest(t *testing.T, pattern string, handler http.HandlerFunc) *Client { + t.Helper() + + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + client := NewClient(OAuthStaticAccessToken(server.Client(), "secret")) + client.BaseURL, _ = url.Parse(server.URL) + + mux.HandleFunc(pattern, handler) + + return client +} + +func checkHeader(req *http.Request, name, value string) error { + val := req.Header.Get(name) + if val != value { + return fmt.Errorf("invalid header value, got: %s want %s", val, value) + } + return nil +} + +func writeFixture(rw http.ResponseWriter, filename string) { + file, err := os.Open(filepath.Join("fixtures", filename)) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + defer func() { _ = file.Close() }() + + _, _ = io.Copy(rw, file) +} + +func TestClient_AddTxtRecord(t *testing.T) { + client := setupTest(t, "/v2/domains/example.com/records", func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusMethodNotAllowed) + return + } + + err := checkHeader(req, "Accept", "application/json") + if err != nil { + http.Error(rw, err.Error(), http.StatusBadRequest) + return + } + + err = checkHeader(req, "Content-Type", "application/json") + if err != nil { + http.Error(rw, err.Error(), http.StatusBadRequest) + return + } + + err = checkHeader(req, "Authorization", "Bearer secret") + if err != nil { + http.Error(rw, err.Error(), http.StatusUnauthorized) + return + } + + reqBody, err := io.ReadAll(req.Body) + if err != nil { + http.Error(rw, err.Error(), http.StatusBadRequest) + return + } + + expectedReqBody := `{"type":"TXT","name":"_acme-challenge.example.com.","data":"w6uP8Tcg6K2QR905Rms8iXTlksL6OD1KOWBxTK7wxPI","ttl":30}` + if expectedReqBody != string(bytes.TrimSpace(reqBody)) { + http.Error(rw, fmt.Sprintf("unexpected request body: %s", string(bytes.TrimSpace(reqBody))), http.StatusBadRequest) + return + } + + rw.WriteHeader(http.StatusCreated) + writeFixture(rw, "domains-records_POST.json") + }) + + record := Record{ + Type: "TXT", + Name: "_acme-challenge.example.com.", + Data: "w6uP8Tcg6K2QR905Rms8iXTlksL6OD1KOWBxTK7wxPI", + TTL: 30, + } + + newRecord, err := client.AddTxtRecord(context.Background(), "example.com", record) + require.NoError(t, err) + + expected := &TxtRecordResponse{DomainRecord: Record{ + ID: 1234567, + Type: "TXT", + Name: "_acme-challenge", + Data: "w6uP8Tcg6K2QR905Rms8iXTlksL6OD1KOWBxTK7wxPI", + TTL: 0, + }} + + assert.Equal(t, expected, newRecord) +} + +func TestClient_RemoveTxtRecord(t *testing.T) { + client := setupTest(t, "/v2/domains/example.com/records/1234567", func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodDelete { + http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusMethodNotAllowed) + return + } + + err := checkHeader(req, "Accept", "application/json") + if err != nil { + http.Error(rw, err.Error(), http.StatusBadRequest) + return + } + + err = checkHeader(req, "Authorization", "Bearer secret") + if err != nil { + http.Error(rw, err.Error(), http.StatusUnauthorized) + return + } + + rw.WriteHeader(http.StatusNoContent) + }) + + err := client.RemoveTxtRecord(context.Background(), "example.com", 1234567) + require.NoError(t, err) +} diff --git a/providers/dns/digitalocean/internal/fixtures/domains-records_POST.json b/providers/dns/digitalocean/internal/fixtures/domains-records_POST.json new file mode 100644 index 00000000..8f13835a --- /dev/null +++ b/providers/dns/digitalocean/internal/fixtures/domains-records_POST.json @@ -0,0 +1,11 @@ +{ + "domain_record": { + "id": 1234567, + "type": "TXT", + "name": "_acme-challenge", + "data": "w6uP8Tcg6K2QR905Rms8iXTlksL6OD1KOWBxTK7wxPI", + "priority": null, + "port": null, + "weight": null + } +} diff --git a/providers/dns/digitalocean/internal/types.go b/providers/dns/digitalocean/internal/types.go new file mode 100644 index 00000000..c1246e6e --- /dev/null +++ b/providers/dns/digitalocean/internal/types.go @@ -0,0 +1,25 @@ +package internal + +import "fmt" + +// TxtRecordResponse represents a response from DO's API after making a TXT record. +type TxtRecordResponse struct { + DomainRecord Record `json:"domain_record"` +} + +type Record struct { + ID int `json:"id,omitempty"` + Type string `json:"type,omitempty"` + Name string `json:"name,omitempty"` + Data string `json:"data,omitempty"` + TTL int `json:"ttl,omitempty"` +} + +type APIError struct { + ID string `json:"id"` + Message string `json:"message"` +} + +func (a APIError) Error() string { + return fmt.Sprintf("%s: %s", a.ID, a.Message) +} diff --git a/providers/dns/dns_providers.go b/providers/dns/dns_providers.go index 685b6f80..07fbd330 100644 --- a/providers/dns/dns_providers.go +++ b/providers/dns/dns_providers.go @@ -126,7 +126,7 @@ import ( // NewDNSChallengeProviderByName Factory for DNS providers. func NewDNSChallengeProviderByName(name string) (challenge.Provider, error) { switch name { - case "acme-dns": + case "acme-dns": // TODO(ldez): remove "-" in v5 return acmedns.NewDNSProvider() case "alidns": return alidns.NewDNSProvider() diff --git a/providers/dns/dnshomede/dnshomede.go b/providers/dns/dnshomede/dnshomede.go index f098d573..1b81be74 100644 --- a/providers/dns/dnshomede/dnshomede.go +++ b/providers/dns/dnshomede/dnshomede.go @@ -2,6 +2,7 @@ package dnshomede import ( + "context" "errors" "fmt" "net/http" @@ -99,7 +100,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { func (d *DNSProvider) Present(domain, _, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - err := d.client.Add(dns01.UnFqdn(info.EffectiveFQDN), info.Value) + err := d.client.Add(context.Background(), dns01.UnFqdn(info.EffectiveFQDN), info.Value) if err != nil { return fmt.Errorf("dnshomede: %w", err) } @@ -111,7 +112,7 @@ func (d *DNSProvider) Present(domain, _, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, _, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - err := d.client.Remove(dns01.UnFqdn(info.EffectiveFQDN), info.Value) + err := d.client.Remove(context.Background(), dns01.UnFqdn(info.EffectiveFQDN), info.Value) if err != nil { return fmt.Errorf("dnshomede: %w", err) } diff --git a/providers/dns/dnshomede/internal/client.go b/providers/dns/dnshomede/internal/client.go index 175b12e8..591c32a4 100644 --- a/providers/dns/dnshomede/internal/client.go +++ b/providers/dns/dnshomede/internal/client.go @@ -1,6 +1,7 @@ package internal import ( + "context" "errors" "fmt" "io" @@ -9,6 +10,8 @@ import ( "strings" "sync" "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) const ( @@ -22,8 +25,8 @@ const defaultBaseURL = "https://www.dnshome.de/dyndns.php" // Client the dnsHome.de client. type Client struct { - HTTPClient *http.Client baseURL string + HTTPClient *http.Client credentials map[string]string credMu sync.Mutex @@ -40,75 +43,48 @@ func NewClient(credentials map[string]string) *Client { // Add adds a TXT record. // only one TXT record for ACME is allowed, so it will update the "current" TXT record. -func (c *Client) Add(hostname, value string) error { +func (c *Client) Add(ctx context.Context, hostname, value string) error { domain := strings.TrimPrefix(hostname, "_acme-challenge.") - c.credMu.Lock() - password, ok := c.credentials[domain] - c.credMu.Unlock() - - if !ok { - return fmt.Errorf("domain %s not found in credentials, check your credentials map", domain) - } - - return c.do(url.UserPassword(domain, password), addAction, value) + return c.doAction(ctx, domain, addAction, value) } // Remove removes a TXT record. // only one TXT record for ACME is allowed, so it will remove "all" the TXT records. -func (c *Client) Remove(hostname, value string) error { +func (c *Client) Remove(ctx context.Context, hostname, value string) error { domain := strings.TrimPrefix(hostname, "_acme-challenge.") - c.credMu.Lock() - password, ok := c.credentials[domain] - c.credMu.Unlock() - - if !ok { - return fmt.Errorf("domain %s not found in credentials, check your credentials map", domain) - } - - return c.do(url.UserPassword(domain, password), removeAction, value) + return c.doAction(ctx, domain, removeAction, value) } -func (c *Client) do(userInfo *url.Userinfo, action, value string) error { - if len(value) < 12 { - return fmt.Errorf("the TXT value must have more than 12 characters: %s", value) - } - - apiEndpoint, err := url.Parse(c.baseURL) +func (c *Client) doAction(ctx context.Context, domain, action, value string) error { + endpoint, err := c.createEndpoint(domain, action, value) if err != nil { return err } - apiEndpoint.User = userInfo - - query := apiEndpoint.Query() - query.Set("acme", action) - query.Set("txt", value) - apiEndpoint.RawQuery = query.Encode() - - req, err := http.NewRequest(http.MethodPost, apiEndpoint.String(), http.NoBody) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint.String(), http.NoBody) if err != nil { - return err + return fmt.Errorf("unable to create request: %w", err) } resp, err := c.HTTPClient.Do(req) if err != nil { - return err + return errutils.NewHTTPDoError(req, err) } + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - all, _ := io.ReadAll(resp.Body) - return fmt.Errorf("%d: %s", resp.StatusCode, string(all)) + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) } - all, err := io.ReadAll(resp.Body) + raw, err := io.ReadAll(resp.Body) if err != nil { - return err + return errutils.NewReadResponseError(req, resp.StatusCode, err) } - output := string(all) + output := string(raw) if !strings.HasPrefix(output, successCode) { return errors.New(output) @@ -116,3 +92,31 @@ func (c *Client) do(userInfo *url.Userinfo, action, value string) error { return nil } + +func (c *Client) createEndpoint(domain, action, value string) (*url.URL, error) { + if len(value) < 12 { + return nil, fmt.Errorf("the TXT value must have more than 12 characters: %s", value) + } + + endpoint, err := url.Parse(c.baseURL) + if err != nil { + return nil, err + } + + c.credMu.Lock() + password, ok := c.credentials[domain] + c.credMu.Unlock() + + if !ok { + return nil, fmt.Errorf("domain %s not found in credentials, check your credentials map", domain) + } + + endpoint.User = url.UserPassword(domain, password) + + query := endpoint.Query() + query.Set("acme", action) + query.Set("txt", value) + endpoint.RawQuery = query.Encode() + + return endpoint, nil +} diff --git a/providers/dns/dnshomede/internal/client_test.go b/providers/dns/dnshomede/internal/client_test.go index 305d83cb..e6f2c1b7 100644 --- a/providers/dns/dnshomede/internal/client_test.go +++ b/providers/dns/dnshomede/internal/client_test.go @@ -1,6 +1,7 @@ package internal import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -9,79 +10,55 @@ import ( "github.com/stretchr/testify/require" ) -func TestClient_Add(t *testing.T) { - txtValue := "123456789012" +func setupTest(t *testing.T, credentials map[string]string, handler http.HandlerFunc) *Client { + t.Helper() mux := http.NewServeMux() - mux.HandleFunc("/", handlerMock(addAction, txtValue)) server := httptest.NewServer(mux) + t.Cleanup(server.Close) - credentials := map[string]string{ - "example.org": "secret", - } + mux.HandleFunc("/", handler) client := NewClient(credentials) client.HTTPClient = server.Client() client.baseURL = server.URL - err := client.Add("example.org", txtValue) + return client +} + +func TestClient_Add(t *testing.T) { + txtValue := "123456789012" + + client := setupTest(t, map[string]string{"example.org": "secret"}, handlerMock(addAction, txtValue)) + + err := client.Add(context.Background(), "example.org", txtValue) require.NoError(t, err) } func TestClient_Add_error(t *testing.T) { txtValue := "123456789012" - mux := http.NewServeMux() - mux.HandleFunc("/", handlerMock(addAction, txtValue)) - server := httptest.NewServer(mux) + client := setupTest(t, map[string]string{"example.com": "secret"}, handlerMock(addAction, txtValue)) - credentials := map[string]string{ - "example.com": "secret", - } - - client := NewClient(credentials) - client.HTTPClient = server.Client() - client.baseURL = server.URL - - err := client.Add("example.org", txtValue) + err := client.Add(context.Background(), "example.org", txtValue) require.Error(t, err) } func TestClient_Remove(t *testing.T) { txtValue := "ABCDEFGHIJKL" - mux := http.NewServeMux() - mux.HandleFunc("/", handlerMock(removeAction, txtValue)) - server := httptest.NewServer(mux) + client := setupTest(t, map[string]string{"example.org": "secret"}, handlerMock(removeAction, txtValue)) - credentials := map[string]string{ - "example.org": "secret", - } - - client := NewClient(credentials) - client.HTTPClient = server.Client() - client.baseURL = server.URL - - err := client.Remove("example.org", txtValue) + err := client.Remove(context.Background(), "example.org", txtValue) require.NoError(t, err) } func TestClient_Remove_error(t *testing.T) { txtValue := "ABCDEFGHIJKL" - mux := http.NewServeMux() - mux.HandleFunc("/", handlerMock(removeAction, txtValue)) - server := httptest.NewServer(mux) + client := setupTest(t, map[string]string{"example.com": "secret"}, handlerMock(removeAction, txtValue)) - credentials := map[string]string{ - "example.com": "secret", - } - - client := NewClient(credentials) - client.HTTPClient = server.Client() - client.baseURL = server.URL - - err := client.Remove("example.org", txtValue) + err := client.Remove(context.Background(), "example.org", txtValue) require.Error(t, err) } diff --git a/providers/dns/dnsimple/dnsimple.go b/providers/dns/dnsimple/dnsimple.go index 67f3b3e9..4a5b8788 100644 --- a/providers/dns/dnsimple/dnsimple.go +++ b/providers/dns/dnsimple/dnsimple.go @@ -149,7 +149,7 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { func (d *DNSProvider) getHostedZone(domain string) (string, error) { authZone, err := dns01.FindZoneByFqdn(domain) if err != nil { - return "", err + return "", fmt.Errorf("could not find zone for FQDN %q: %w", domain, err) } accountID, err := d.getAccountID() diff --git a/providers/dns/dnsmadeeasy/dnsmadeeasy.go b/providers/dns/dnsmadeeasy/dnsmadeeasy.go index b3ae9245..50512fe6 100644 --- a/providers/dns/dnsmadeeasy/dnsmadeeasy.go +++ b/providers/dns/dnsmadeeasy/dnsmadeeasy.go @@ -2,10 +2,12 @@ package dnsmadeeasy import ( + "context" "crypto/tls" "errors" "fmt" "net/http" + "net/url" "strings" "time" @@ -86,12 +88,12 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { var baseURL string if config.Sandbox { - baseURL = "https://api.sandbox.dnsmadeeasy.com/V2.0" + baseURL = internal.DefaultSandboxBaseURL } else { - if len(config.BaseURL) > 0 { - baseURL = config.BaseURL + if config.BaseURL == "" { + baseURL = internal.DefaultProdBaseURL } else { - baseURL = "https://api.dnsmadeeasy.com/V2.0" + baseURL = config.BaseURL } } @@ -101,7 +103,10 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { } client.HTTPClient = config.HTTPClient - client.BaseURL = baseURL + client.BaseURL, err = url.Parse(baseURL) + if err != nil { + return nil, err + } return &DNSProvider{ client: client, @@ -115,11 +120,13 @@ func (d *DNSProvider) Present(domainName, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("dnsmadeeasy: unable to find zone for %s: %w", info.EffectiveFQDN, err) + return fmt.Errorf("dnsmadeeasy: could not find zone for domain %q (%s): %w", domainName, info.EffectiveFQDN, err) } + ctx := context.Background() + // fetch the domain details - domain, err := d.client.GetDomain(authZone) + domain, err := d.client.GetDomain(ctx, authZone) if err != nil { return fmt.Errorf("dnsmadeeasy: unable to get domain for zone %s: %w", authZone, err) } @@ -128,7 +135,7 @@ func (d *DNSProvider) Present(domainName, token, keyAuth string) error { name := strings.Replace(info.EffectiveFQDN, "."+authZone, "", 1) record := &internal.Record{Type: "TXT", Name: name, Value: info.Value, TTL: d.config.TTL} - err = d.client.CreateRecord(domain, record) + err = d.client.CreateRecord(ctx, domain, record) if err != nil { return fmt.Errorf("dnsmadeeasy: unable to create record for %s: %w", name, err) } @@ -141,18 +148,20 @@ func (d *DNSProvider) CleanUp(domainName, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("dnsmadeeasy: unable to find zone for %s: %w", info.EffectiveFQDN, err) + return fmt.Errorf("dnsmadeeasy: could not find zone for domain %q (%s): %w", domainName, info.EffectiveFQDN, err) } + ctx := context.Background() + // fetch the domain details - domain, err := d.client.GetDomain(authZone) + domain, err := d.client.GetDomain(ctx, authZone) if err != nil { return fmt.Errorf("dnsmadeeasy: unable to get domain for zone %s: %w", authZone, err) } // find matching records name := strings.Replace(info.EffectiveFQDN, "."+authZone, "", 1) - records, err := d.client.GetRecords(domain, name, "TXT") + records, err := d.client.GetRecords(ctx, domain, name, "TXT") if err != nil { return fmt.Errorf("dnsmadeeasy: unable to get records for domain %s: %w", domain.Name, err) } @@ -160,7 +169,7 @@ func (d *DNSProvider) CleanUp(domainName, token, keyAuth string) error { // delete records var lastError error for _, record := range *records { - err = d.client.DeleteRecord(record) + err = d.client.DeleteRecord(ctx, record) if err != nil { lastError = fmt.Errorf("dnsmadeeasy: unable to delete record [id=%d, name=%s]: %w", record.ID, record.Name, err) } diff --git a/providers/dns/dnsmadeeasy/internal/client.go b/providers/dns/dnsmadeeasy/internal/client.go index 85d18a0f..9890de8b 100644 --- a/providers/dns/dnsmadeeasy/internal/client.go +++ b/providers/dns/dnsmadeeasy/internal/client.go @@ -2,6 +2,7 @@ package internal import ( "bytes" + "context" "crypto/hmac" "crypto/sha1" "encoding/hex" @@ -10,34 +11,25 @@ import ( "fmt" "io" "net/http" + "net/url" + "strconv" "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) -// Domain holds the DNSMadeEasy API representation of a Domain. -type Domain struct { - ID int `json:"id"` - Name string `json:"name"` -} - -// Record holds the DNSMadeEasy API representation of a Domain Record. -type Record struct { - ID int `json:"id"` - Type string `json:"type"` - Name string `json:"name"` - Value string `json:"value"` - TTL int `json:"ttl"` - SourceID int `json:"sourceId"` -} - -type recordsResponse struct { - Records *[]Record `json:"data"` -} +// Default API endpoints. +const ( + DefaultSandboxBaseURL = "https://api.sandbox.dnsmadeeasy.com/V2.0" + DefaultProdBaseURL = "https://api.dnsmadeeasy.com/V2.0" +) // Client DNSMadeEasy client. type Client struct { - apiKey string - apiSecret string - BaseURL string + apiKey string + apiSecret string + + BaseURL *url.URL HTTPClient *http.Client } @@ -51,26 +43,33 @@ func NewClient(apiKey, apiSecret string) (*Client, error) { return nil, errors.New("credentials missing: API secret") } + baseURL, _ := url.Parse(DefaultProdBaseURL) + return &Client{ apiKey: apiKey, apiSecret: apiSecret, - HTTPClient: &http.Client{}, + BaseURL: baseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, }, nil } // GetDomain gets a domain. -func (c *Client) GetDomain(authZone string) (*Domain, error) { - domainName := authZone[0 : len(authZone)-1] - resource := fmt.Sprintf("%s%s", "/dns/managed/name?domainname=", domainName) +func (c *Client) GetDomain(ctx context.Context, authZone string) (*Domain, error) { + endpoint := c.BaseURL.JoinPath("dns", "managed", "name") - resp, err := c.sendRequest(http.MethodGet, resource, nil) + domainName := authZone[0 : len(authZone)-1] + + query := endpoint.Query() + query.Set("domainname", domainName) + endpoint.RawQuery = query.Encode() + + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, err } - defer resp.Body.Close() domain := &Domain{} - err = json.NewDecoder(resp.Body).Decode(&domain) + err = c.do(req, domain) if err != nil { return nil, err } @@ -79,17 +78,20 @@ func (c *Client) GetDomain(authZone string) (*Domain, error) { } // GetRecords gets all TXT records. -func (c *Client) GetRecords(domain *Domain, recordName, recordType string) (*[]Record, error) { - resource := fmt.Sprintf("%s/%d/%s%s%s%s", "/dns/managed", domain.ID, "records?recordName=", recordName, "&type=", recordType) +func (c *Client) GetRecords(ctx context.Context, domain *Domain, recordName, recordType string) (*[]Record, error) { + endpoint := c.BaseURL.JoinPath("dns", "managed", strconv.Itoa(domain.ID), "records") - resp, err := c.sendRequest(http.MethodGet, resource, nil) + query := endpoint.Query() + query.Set("recordName", recordName) + query.Set("type", recordType) + + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, err } - defer resp.Body.Close() records := &recordsResponse{} - err = json.NewDecoder(resp.Body).Decode(&records) + err = c.do(req, records) if err != nil { return nil, err } @@ -98,69 +100,73 @@ func (c *Client) GetRecords(domain *Domain, recordName, recordType string) (*[]R } // CreateRecord creates a TXT records. -func (c *Client) CreateRecord(domain *Domain, record *Record) error { - url := fmt.Sprintf("%s/%d/%s", "/dns/managed", domain.ID, "records") +func (c *Client) CreateRecord(ctx context.Context, domain *Domain, record *Record) error { + endpoint := c.BaseURL.JoinPath("dns", "managed", strconv.Itoa(domain.ID), "records") - resp, err := c.sendRequest(http.MethodPost, url, record) + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record) if err != nil { return err } - defer resp.Body.Close() - return nil + return c.do(req, nil) } // DeleteRecord deletes a TXT records. -func (c *Client) DeleteRecord(record Record) error { - resource := fmt.Sprintf("%s/%d/%s/%d", "/dns/managed", record.SourceID, "records", record.ID) +func (c *Client) DeleteRecord(ctx context.Context, record Record) error { + endpoint := c.BaseURL.JoinPath("/dns/managed", strconv.Itoa(record.SourceID), "records", strconv.Itoa(record.ID)) - resp, err := c.sendRequest(http.MethodDelete, resource, nil) + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) if err != nil { return err } - defer resp.Body.Close() + + return c.do(req, nil) +} + +func (c *Client) do(req *http.Request, result any) error { + err := c.sign(req, time.Now().UTC().Format(time.RFC1123)) + if err != nil { + return err + } + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode/100 != 2 { + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + if result == nil { + return nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + if err = json.Unmarshal(raw, result); err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } return nil } -func (c *Client) sendRequest(method, resource string, payload interface{}) (*http.Response, error) { - url := fmt.Sprintf("%s%s", c.BaseURL, resource) - - body, err := json.Marshal(payload) - if err != nil { - return nil, err - } - - timestamp := time.Now().UTC().Format(time.RFC1123) +func (c *Client) sign(req *http.Request, timestamp string) error { signature, err := computeHMAC(timestamp, c.apiSecret) if err != nil { - return nil, err + return err } - req, err := http.NewRequest(method, url, bytes.NewReader(body)) - if err != nil { - return nil, err - } req.Header.Set("x-dnsme-apiKey", c.apiKey) req.Header.Set("x-dnsme-requestDate", timestamp) req.Header.Set("x-dnsme-hmac", signature) - req.Header.Set("accept", "application/json") - req.Header.Set("content-type", "application/json") - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, err - } - - if resp.StatusCode > 299 { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("request failed with HTTP status code %d", resp.StatusCode) - } - return nil, fmt.Errorf("request failed with HTTP status code %d: %s", resp.StatusCode, string(body)) - } - - return resp, nil + return nil } func computeHMAC(message, secret string) (string, error) { @@ -172,3 +178,27 @@ func computeHMAC(message, secret string) (string, error) { } return hex.EncodeToString(h.Sum(nil)), nil } + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil +} diff --git a/providers/dns/dnsmadeeasy/internal/client_test.go b/providers/dns/dnsmadeeasy/internal/client_test.go new file mode 100644 index 00000000..72121469 --- /dev/null +++ b/providers/dns/dnsmadeeasy/internal/client_test.go @@ -0,0 +1,28 @@ +package internal + +import ( + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_sign(t *testing.T) { + apiKey := "key" + + client := Client{apiKey: apiKey, apiSecret: "secret"} + + req, err := http.NewRequest(http.MethodGet, "", http.NoBody) + require.NoError(t, err) + + timestamp := time.Date(2015, time.June, 2, 2, 36, 7, 0, time.UTC).Format(time.RFC1123) + + err = client.sign(req, timestamp) + require.NoError(t, err) + + assert.Equal(t, apiKey, req.Header.Get("x-dnsme-apiKey")) + assert.Equal(t, timestamp, req.Header.Get("x-dnsme-requestDate")) + assert.Equal(t, "6b6c8432119c31e1d3776eb4cd3abd92fae4a71c", req.Header.Get("x-dnsme-hmac")) +} diff --git a/providers/dns/dnsmadeeasy/internal/types.go b/providers/dns/dnsmadeeasy/internal/types.go new file mode 100644 index 00000000..a10da88e --- /dev/null +++ b/providers/dns/dnsmadeeasy/internal/types.go @@ -0,0 +1,21 @@ +package internal + +// Domain holds the DNSMadeEasy API representation of a Domain. +type Domain struct { + ID int `json:"id"` + Name string `json:"name"` +} + +// Record holds the DNSMadeEasy API representation of a Domain Record. +type Record struct { + ID int `json:"id"` + Type string `json:"type"` + Name string `json:"name"` + Value string `json:"value"` + TTL int `json:"ttl"` + SourceID int `json:"sourceId"` +} + +type recordsResponse struct { + Records *[]Record `json:"data"` +} diff --git a/providers/dns/dnspod/dnspod.go b/providers/dns/dnspod/dnspod.go index 740c648e..c20caf3c 100644 --- a/providers/dns/dnspod/dnspod.go +++ b/providers/dns/dnspod/dnspod.go @@ -143,7 +143,7 @@ func (d *DNSProvider) getHostedZone(domain string) (string, string, error) { authZone, err := dns01.FindZoneByFqdn(domain) if err != nil { - return "", "", err + return "", "", fmt.Errorf("could not find zone for FQDN %q: %w", domain, err) } var hostedZone dnspod.Domain diff --git a/providers/dns/dode/client.go b/providers/dns/dode/client.go deleted file mode 100644 index d788c1a7..00000000 --- a/providers/dns/dode/client.go +++ /dev/null @@ -1,57 +0,0 @@ -package dode - -import ( - "encoding/json" - "fmt" - "io" - "net/url" - - "github.com/go-acme/lego/v4/challenge/dns01" -) - -type apiResponse struct { - Domain string - Success bool -} - -// updateTxtRecord Update the domains TXT record -// To update the TXT record we just need to make one simple get request. -func (d *DNSProvider) updateTxtRecord(fqdn, token, txt string, clear bool) error { - u, _ := url.Parse("https://www.do.de/api/letsencrypt") - - query := u.Query() - query.Set("token", token) - query.Set("domain", dns01.UnFqdn(fqdn)) - - // api call differs per set/delete - if clear { - query.Set("action", "delete") - } else { - query.Set("value", txt) - } - - u.RawQuery = query.Encode() - - response, err := d.config.HTTPClient.Get(u.String()) - if err != nil { - return err - } - defer response.Body.Close() - - bodyBytes, err := io.ReadAll(response.Body) - if err != nil { - return err - } - - var r apiResponse - err = json.Unmarshal(bodyBytes, &r) - if err != nil { - return fmt.Errorf("request to change TXT record for do.de returned the following invalid json (%s); used url [%s]", string(bodyBytes), u) - } - - body := string(bodyBytes) - if !r.Success { - return fmt.Errorf("request to change TXT record for do.de returned the following error result (%s); used url [%s]", body, u) - } - return nil -} diff --git a/providers/dns/dode/dode.go b/providers/dns/dode/dode.go index 18c16b10..04393fb0 100644 --- a/providers/dns/dode/dode.go +++ b/providers/dns/dode/dode.go @@ -2,6 +2,7 @@ package dode import ( + "context" "errors" "fmt" "net/http" @@ -9,6 +10,7 @@ import ( "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/platform/config/env" + "github.com/go-acme/lego/v4/providers/dns/dode/internal" ) // Environment variables names. @@ -47,6 +49,7 @@ func NewDefaultConfig() *Config { // DNSProvider implements the challenge.Provider interface. type DNSProvider struct { config *Config + client *internal.Client } // NewDNSProvider returns a new DNS provider using @@ -73,19 +76,25 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("do.de: credentials missing") } - return &DNSProvider{config: config}, nil + client := internal.NewClient(config.Token) + + if config.HTTPClient != nil { + client.HTTPClient = config.HTTPClient + } + + return &DNSProvider{config: config, client: client}, nil } // Present creates a TXT record to fulfill the dns-01 challenge. func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - return d.updateTxtRecord(info.EffectiveFQDN, d.config.Token, info.Value, false) + return d.client.UpdateTxtRecord(context.Background(), info.EffectiveFQDN, info.Value, false) } // CleanUp clears TXT record. func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - return d.updateTxtRecord(info.EffectiveFQDN, d.config.Token, "", true) + return d.client.UpdateTxtRecord(context.Background(), info.EffectiveFQDN, "", true) } // Timeout returns the timeout and interval to use when checking for DNS propagation. diff --git a/providers/dns/dode/dode_test.go b/providers/dns/dode/dode_test.go index 67a61433..3d8e9395 100644 --- a/providers/dns/dode/dode_test.go +++ b/providers/dns/dode/dode_test.go @@ -10,8 +10,7 @@ import ( const envDomain = envNamespace + "DOMAIN" -var envTest = tester.NewEnvTest(EnvToken). - WithDomain(envDomain) +var envTest = tester.NewEnvTest(EnvToken).WithDomain(envDomain) func TestNewDNSProvider(t *testing.T) { testCases := []struct { diff --git a/providers/dns/dode/internal/client.go b/providers/dns/dode/internal/client.go new file mode 100644 index 00000000..4568cd9b --- /dev/null +++ b/providers/dns/dode/internal/client.go @@ -0,0 +1,84 @@ +package internal + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "time" + + "github.com/go-acme/lego/v4/challenge/dns01" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" +) + +const defaultBaseURL = "https://www.do.de/api" + +// Client the do.de API client. +type Client struct { + token string + + baseURL *url.URL + HTTPClient *http.Client +} + +// NewClient Creates a new Client. +func NewClient(token string) *Client { + baseURL, _ := url.Parse(defaultBaseURL) + + return &Client{ + token: token, + baseURL: baseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + } +} + +// UpdateTxtRecord Update the domains TXT record +// To update the TXT record we just need to make one simple get request. +func (c Client) UpdateTxtRecord(ctx context.Context, fqdn, txt string, clear bool) error { + endpoint := c.baseURL.JoinPath("letsencrypt") + + query := endpoint.Query() + query.Set("token", c.token) + query.Set("domain", dns01.UnFqdn(fqdn)) + + // api call differs per set/delete + if clear { + query.Set("action", "delete") + } else { + query.Set("value", txt) + } + + endpoint.RawQuery = query.Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), http.NoBody) + if err != nil { + return fmt.Errorf("unable to create request: %w", err) + } + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + var response apiResponse + err = json.Unmarshal(raw, &response) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + body := string(raw) + if !response.Success { + return fmt.Errorf("request to change TXT record for do.de returned the following error result (%s); used url [%s]", body, endpoint) + } + + return nil +} diff --git a/providers/dns/dode/internal/client_test.go b/providers/dns/dode/internal/client_test.go new file mode 100644 index 00000000..116ca8c4 --- /dev/null +++ b/providers/dns/dode/internal/client_test.go @@ -0,0 +1,93 @@ +package internal + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func setupTest(t *testing.T, method, pattern string, status int, file string) *Client { + t.Helper() + + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + mux.HandleFunc(pattern, func(rw http.ResponseWriter, req *http.Request) { + if req.Method != method { + http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusBadRequest) + return + } + + query := req.URL.Query() + if query.Get("token") != "secret" { + http.Error(rw, fmt.Sprintf("invalid credentials: %q", query.Get("token")), http.StatusUnauthorized) + return + } + + if query.Get("domain") != "example.com" { + http.Error(rw, fmt.Sprintf("invalid domain: %q", query.Get("domain")), http.StatusBadRequest) + return + } + + if query.Has("action") { + if query.Get("action") != "delete" { + http.Error(rw, fmt.Sprintf("invalid action: %q", query.Get("action")), http.StatusBadRequest) + return + } + } else { + if query.Get("value") != "value" { + http.Error(rw, fmt.Sprintf("invalid value: %q", query.Get("value")), http.StatusBadRequest) + return + } + } + + if file == "" { + rw.WriteHeader(status) + return + } + + open, err := os.Open(filepath.Join("fixtures", file)) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + + defer func() { _ = open.Close() }() + + rw.WriteHeader(status) + _, err = io.Copy(rw, open) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + }) + + client := NewClient("secret") + client.HTTPClient = server.Client() + client.baseURL, _ = url.Parse(server.URL) + + return client +} + +func TestClient_UpdateTxtRecord(t *testing.T) { + client := setupTest(t, http.MethodGet, "/letsencrypt", http.StatusOK, "success.json") + + err := client.UpdateTxtRecord(context.Background(), "example.com.", "value", false) + require.NoError(t, err) +} + +func TestClient_UpdateTxtRecord_clear(t *testing.T) { + client := setupTest(t, http.MethodGet, "/letsencrypt", http.StatusOK, "success.json") + + err := client.UpdateTxtRecord(context.Background(), "example.com.", "value", true) + require.NoError(t, err) +} diff --git a/providers/dns/dode/internal/fixtures/success.json b/providers/dns/dode/internal/fixtures/success.json new file mode 100644 index 00000000..d6622346 --- /dev/null +++ b/providers/dns/dode/internal/fixtures/success.json @@ -0,0 +1,4 @@ +{ + "Domain" : "example.com", + "Success": true +} diff --git a/providers/dns/dode/internal/types.go b/providers/dns/dode/internal/types.go new file mode 100644 index 00000000..cc95ba14 --- /dev/null +++ b/providers/dns/dode/internal/types.go @@ -0,0 +1,6 @@ +package internal + +type apiResponse struct { + Domain string + Success bool +} diff --git a/providers/dns/domeneshop/domeneshop.go b/providers/dns/domeneshop/domeneshop.go index c71cd1f9..c9f7fcd9 100644 --- a/providers/dns/domeneshop/domeneshop.go +++ b/providers/dns/domeneshop/domeneshop.go @@ -2,6 +2,7 @@ package domeneshop import ( + "context" "errors" "fmt" "net/http" @@ -100,12 +101,14 @@ func (d *DNSProvider) Present(domain, _, keyAuth string) error { return fmt.Errorf("domeneshop: %w", err) } - domainInstance, err := d.client.GetDomainByName(zone) + ctx := context.Background() + + domainInstance, err := d.client.GetDomainByName(ctx, zone) if err != nil { return fmt.Errorf("domeneshop: %w", err) } - err = d.client.CreateTXTRecord(domainInstance, host, info.Value) + err = d.client.CreateTXTRecord(ctx, domainInstance, host, info.Value) if err != nil { return fmt.Errorf("domeneshop: failed to create record: %w", err) } @@ -122,12 +125,14 @@ func (d *DNSProvider) CleanUp(domain, _, keyAuth string) error { return fmt.Errorf("domeneshop: %w", err) } - domainInstance, err := d.client.GetDomainByName(zone) + ctx := context.Background() + + domainInstance, err := d.client.GetDomainByName(ctx, zone) if err != nil { return fmt.Errorf("domeneshop: %w", err) } - if err := d.client.DeleteTXTRecord(domainInstance, host, info.Value); err != nil { + if err := d.client.DeleteTXTRecord(ctx, domainInstance, host, info.Value); err != nil { return fmt.Errorf("domeneshop: failed to create record: %w", err) } @@ -138,7 +143,7 @@ func (d *DNSProvider) CleanUp(domain, _, keyAuth string) error { func (d *DNSProvider) splitDomain(fqdn string) (string, string, error) { zone, err := dns01.FindZoneByFqdn(fqdn) if err != nil { - return "", "", err + return "", "", fmt.Errorf("could not find zone for FQDN %q: %w", fqdn, err) } subDomain, err := dns01.ExtractSubDomain(fqdn, zone) diff --git a/providers/dns/domeneshop/internal/client.go b/providers/dns/domeneshop/internal/client.go index f578fb42..9b48d326 100644 --- a/providers/dns/domeneshop/internal/client.go +++ b/providers/dns/domeneshop/internal/client.go @@ -2,11 +2,16 @@ package internal import ( "bytes" + "context" "encoding/json" "fmt" "io" "net/http" + "net/url" + "strconv" "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) const defaultBaseURL string = "https://api.domeneshop.no/v0" @@ -15,28 +20,38 @@ const defaultBaseURL string = "https://api.domeneshop.no/v0" // For now it will only deal with adding and removing TXT records, as required by ACME providers. // https://api.domeneshop.no/docs/ type Client struct { + apiToken string + apiSecret string + + baseURL *url.URL HTTPClient *http.Client - baseURL string - apiToken string - apiSecret string } // NewClient returns an instance of the Domeneshop API wrapper. func NewClient(apiToken, apiSecret string) *Client { + baseURL, _ := url.Parse(defaultBaseURL) + return &Client{ - HTTPClient: &http.Client{Timeout: 5 * time.Second}, - baseURL: defaultBaseURL, apiToken: apiToken, apiSecret: apiSecret, + baseURL: baseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, } } // GetDomainByName fetches the domain list and returns the Domain object for the matching domain. // https://api.domeneshop.no/docs/#operation/getDomains -func (c *Client) GetDomainByName(domain string) (*Domain, error) { +func (c *Client) GetDomainByName(ctx context.Context, domain string) (*Domain, error) { + endpoint := c.baseURL.JoinPath("domains") + + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, err + } + var domains []Domain - err := c.doRequest(http.MethodGet, "domains", nil, &domains) + err = c.do(req, &domains) if err != nil { return nil, err } @@ -57,37 +72,55 @@ func (c *Client) GetDomainByName(domain string) (*Domain, error) { // CreateTXTRecord creates a TXT record with the provided host (subdomain) and data. // https://api.domeneshop.no/docs/#tag/dns/paths/~1domains~1{domainId}~1dns/post -func (c *Client) CreateTXTRecord(domain *Domain, host string, data string) error { - jsonRecord, err := json.Marshal(DNSRecord{ +func (c *Client) CreateTXTRecord(ctx context.Context, domain *Domain, host string, data string) error { + endpoint := c.baseURL.JoinPath("domains", strconv.Itoa(domain.ID), "dns") + + record := DNSRecord{ Data: data, Host: host, TTL: 300, Type: "TXT", - }) + } + + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record) if err != nil { return err } - return c.doRequest(http.MethodPost, fmt.Sprintf("domains/%d/dns", domain.ID), jsonRecord, nil) + return c.do(req, nil) } // DeleteTXTRecord deletes the DNS record matching the provided host and data. // https://api.domeneshop.no/docs/#tag/dns/paths/~1domains~1{domainId}~1dns~1{recordId}/delete -func (c *Client) DeleteTXTRecord(domain *Domain, host string, data string) error { - record, err := c.getDNSRecordByHostData(*domain, host, data) +func (c *Client) DeleteTXTRecord(ctx context.Context, domain *Domain, host string, data string) error { + record, err := c.getDNSRecordByHostData(ctx, *domain, host, data) if err != nil { return err } - return c.doRequest(http.MethodDelete, fmt.Sprintf("domains/%d/dns/%d", domain.ID, record.ID), nil, nil) + endpoint := c.baseURL.JoinPath("domains", strconv.Itoa(domain.ID), "dns", strconv.Itoa(record.ID)) + + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) + if err != nil { + return err + } + + return c.do(req, nil) } // getDNSRecordByHostData finds the first matching DNS record with the provided host and data. // https://api.domeneshop.no/docs/#operation/getDnsRecords -func (c *Client) getDNSRecordByHostData(domain Domain, host string, data string) (*DNSRecord, error) { +func (c *Client) getDNSRecordByHostData(ctx context.Context, domain Domain, host string, data string) (*DNSRecord, error) { + endpoint := c.baseURL.JoinPath("domains", strconv.Itoa(domain.ID), "dns") + + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, err + } + var records []DNSRecord - err := c.doRequest(http.MethodGet, fmt.Sprintf("domains/%d/dns", domain.ID), nil, &records) + err = c.do(req, &records) if err != nil { return nil, err } @@ -101,35 +134,59 @@ func (c *Client) getDNSRecordByHostData(domain Domain, host string, data string) return nil, fmt.Errorf("failed to find record with host %s for domain %s", host, domain.Name) } -// doRequest makes a request against the API with an optional body, +// do a request against the API, // and makes sure that the required Authorization header is set using `setBasicAuth`. -func (c *Client) doRequest(method string, endpoint string, reqBody []byte, v interface{}) error { - req, err := http.NewRequest(method, fmt.Sprintf("%s/%s", c.baseURL, endpoint), bytes.NewBuffer(reqBody)) - if err != nil { - return err - } - +func (c *Client) do(req *http.Request, result any) error { req.SetBasicAuth(c.apiToken, c.apiSecret) resp, err := c.HTTPClient.Do(req) if err != nil { - return err + return errutils.NewHTTPDoError(req, err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode >= http.StatusBadRequest { - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - - return fmt.Errorf("API returned %s: %s", resp.Status, respBody) + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) } - if v != nil { - return json.NewDecoder(resp.Body).Decode(&v) + if result == nil { + return nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) } return nil } + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil +} diff --git a/providers/dns/domeneshop/internal/client_test.go b/providers/dns/domeneshop/internal/client_test.go index 569ca403..71205cac 100644 --- a/providers/dns/domeneshop/internal/client_test.go +++ b/providers/dns/domeneshop/internal/client_test.go @@ -1,31 +1,34 @@ package internal import ( + "context" "net/http" "net/http/httptest" + "net/url" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func setup(t *testing.T) (*Client, *http.ServeMux) { +const authorizationHeader = "Authorization" + +func setupTest(t *testing.T) (*Client, *http.ServeMux) { t.Helper() mux := http.NewServeMux() - server := httptest.NewServer(mux) t.Cleanup(server.Close) client := NewClient("token", "secret") - - client.baseURL = server.URL + client.HTTPClient = server.Client() + client.baseURL, _ = url.Parse(server.URL) return client, mux } func TestClient_CreateTXTRecord(t *testing.T) { - client, mux := setup(t) + client, mux := setupTest(t) mux.HandleFunc("/domains/1/dns", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodPost { @@ -33,21 +36,21 @@ func TestClient_CreateTXTRecord(t *testing.T) { return } - auth := req.Header.Get("Authorization") + auth := req.Header.Get(authorizationHeader) if auth != "Basic dG9rZW46c2VjcmV0" { - http.Error(rw, "invalid method: "+req.Method, http.StatusUnauthorized) + http.Error(rw, "invalid credentials: "+auth, http.StatusUnauthorized) return } _, _ = rw.Write([]byte(`{"id": 1}`)) }) - err := client.CreateTXTRecord(&Domain{ID: 1}, "example", "txtTXTtxt") + err := client.CreateTXTRecord(context.Background(), &Domain{ID: 1}, "example", "txtTXTtxt") require.NoError(t, err) } func TestClient_DeleteTXTRecord(t *testing.T) { - client, mux := setup(t) + client, mux := setupTest(t) mux.HandleFunc("/domains/1/dns", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodGet { @@ -55,9 +58,9 @@ func TestClient_DeleteTXTRecord(t *testing.T) { return } - auth := req.Header.Get("Authorization") + auth := req.Header.Get(authorizationHeader) if auth != "Basic dG9rZW46c2VjcmV0" { - http.Error(rw, "invalid method: "+req.Method, http.StatusUnauthorized) + http.Error(rw, "invalid credentials: "+auth, http.StatusUnauthorized) return } @@ -78,19 +81,19 @@ func TestClient_DeleteTXTRecord(t *testing.T) { return } - auth := req.Header.Get("Authorization") + auth := req.Header.Get(authorizationHeader) if auth != "Basic dG9rZW46c2VjcmV0" { - http.Error(rw, "invalid method: "+req.Method, http.StatusUnauthorized) + http.Error(rw, "invalid credentials: "+auth, http.StatusUnauthorized) return } }) - err := client.DeleteTXTRecord(&Domain{ID: 1}, "example.com", "txtTXTtxt") + err := client.DeleteTXTRecord(context.Background(), &Domain{ID: 1}, "example.com", "txtTXTtxt") require.NoError(t, err) } func TestClient_getDNSRecordByHostData(t *testing.T) { - client, mux := setup(t) + client, mux := setupTest(t) mux.HandleFunc("/domains/1/dns", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodGet { @@ -98,9 +101,9 @@ func TestClient_getDNSRecordByHostData(t *testing.T) { return } - auth := req.Header.Get("Authorization") + auth := req.Header.Get(authorizationHeader) if auth != "Basic dG9rZW46c2VjcmV0" { - http.Error(rw, "invalid method: "+req.Method, http.StatusUnauthorized) + http.Error(rw, "invalid credentials: "+auth, http.StatusUnauthorized) return } @@ -115,7 +118,7 @@ func TestClient_getDNSRecordByHostData(t *testing.T) { ]`)) }) - record, err := client.getDNSRecordByHostData(Domain{ID: 1}, "example.com", "txtTXTtxt") + record, err := client.getDNSRecordByHostData(context.Background(), Domain{ID: 1}, "example.com", "txtTXTtxt") require.NoError(t, err) expected := &DNSRecord{ @@ -130,7 +133,7 @@ func TestClient_getDNSRecordByHostData(t *testing.T) { } func TestClient_GetDomainByName(t *testing.T) { - client, mux := setup(t) + client, mux := setupTest(t) mux.HandleFunc("/domains", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodGet { @@ -138,9 +141,9 @@ func TestClient_GetDomainByName(t *testing.T) { return } - auth := req.Header.Get("Authorization") + auth := req.Header.Get(authorizationHeader) if auth != "Basic dG9rZW46c2VjcmV0" { - http.Error(rw, "invalid method: "+req.Method, http.StatusUnauthorized) + http.Error(rw, "invalid credentials: "+auth, http.StatusUnauthorized) return } @@ -168,7 +171,7 @@ func TestClient_GetDomainByName(t *testing.T) { ]`)) }) - domain, err := client.GetDomainByName("example.com") + domain, err := client.GetDomainByName(context.Background(), "example.com") require.NoError(t, err) expected := &Domain{ diff --git a/providers/dns/dreamhost/client.go b/providers/dns/dreamhost/client.go deleted file mode 100644 index 8ddc4da8..00000000 --- a/providers/dns/dreamhost/client.go +++ /dev/null @@ -1,74 +0,0 @@ -package dreamhost - -import ( - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - - "github.com/go-acme/lego/v4/log" -) - -const ( - defaultBaseURL = "https://api.dreamhost.com" - - cmdAddRecord = "dns-add_record" - cmdRemoveRecord = "dns-remove_record" -) - -type apiResponse struct { - Data string `json:"data"` - Result string `json:"result"` -} - -func (d *DNSProvider) buildQuery(action, domain, txt string) (*url.URL, error) { - u, err := url.Parse(d.config.BaseURL) - if err != nil { - return nil, err - } - - query := u.Query() - query.Set("key", d.config.APIKey) - query.Set("cmd", action) - query.Set("format", "json") - query.Set("record", domain) - query.Set("type", "TXT") - query.Set("value", txt) - query.Set("comment", url.QueryEscape("Managed By lego")) - u.RawQuery = query.Encode() - - return u, nil -} - -// updateTxtRecord will either add or remove a TXT record. -// action is either cmdAddRecord or cmdRemoveRecord. -func (d *DNSProvider) updateTxtRecord(u fmt.Stringer) error { - resp, err := d.config.HTTPClient.Get(u.String()) - if err != nil { - return err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("request failed with HTTP status code %d", resp.StatusCode) - } - - raw, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read body: %w", err) - } - - var response apiResponse - err = json.Unmarshal(raw, &response) - if err != nil { - return fmt.Errorf("unable to decode API server response: %w: %s", err, string(raw)) - } - - if response.Result == "error" { - return fmt.Errorf("add TXT record failed: %s", response.Data) - } - - log.Infof("dreamhost: %s", response.Data) - return nil -} diff --git a/providers/dns/dreamhost/dreamhost.go b/providers/dns/dreamhost/dreamhost.go index 56cf74fb..8f0c850d 100644 --- a/providers/dns/dreamhost/dreamhost.go +++ b/providers/dns/dreamhost/dreamhost.go @@ -4,6 +4,7 @@ package dreamhost import ( + "context" "errors" "fmt" "net/http" @@ -11,6 +12,7 @@ import ( "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/platform/config/env" + "github.com/go-acme/lego/v4/providers/dns/dreamhost/internal" ) // Environment variables names. @@ -36,7 +38,7 @@ type Config struct { // NewDefaultConfig returns a default configuration for the DNSProvider. func NewDefaultConfig() *Config { return &Config{ - BaseURL: defaultBaseURL, + BaseURL: internal.DefaultBaseURL, PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, 60*time.Minute), PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, 1*time.Minute), HTTPClient: &http.Client{ @@ -48,6 +50,7 @@ func NewDefaultConfig() *Config { // DNSProvider implements the challenge.Provider interface. type DNSProvider struct { config *Config + client *internal.Client } // NewDNSProvider returns a new DNS provider using @@ -74,44 +77,39 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("dreamhost: credentials missing") } - if config.BaseURL == "" { - config.BaseURL = defaultBaseURL + client := internal.NewClient(config.APIKey) + + if config.HTTPClient != nil { + client.HTTPClient = config.HTTPClient } - return &DNSProvider{config: config}, nil + if config.BaseURL != "" { + client.BaseURL = config.BaseURL + } + + return &DNSProvider{config: config, client: client}, nil } // Present creates a TXT record using the specified parameters. func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - record := dns01.UnFqdn(info.EffectiveFQDN) - - u, err := d.buildQuery(cmdAddRecord, record, info.Value) + err := d.client.AddRecord(context.Background(), dns01.UnFqdn(info.EffectiveFQDN), info.Value) if err != nil { return fmt.Errorf("dreamhost: %w", err) } - err = d.updateTxtRecord(u) - if err != nil { - return fmt.Errorf("dreamhost: %w", err) - } return nil } // CleanUp removes the TXT record matching the specified parameters. func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - record := dns01.UnFqdn(info.EffectiveFQDN) - u, err := d.buildQuery(cmdRemoveRecord, record, info.Value) + err := d.client.RemoveRecord(context.Background(), dns01.UnFqdn(info.EffectiveFQDN), info.Value) if err != nil { return fmt.Errorf("dreamhost: %w", err) } - err = d.updateTxtRecord(u) - if err != nil { - return fmt.Errorf("dreamhost: %w", err) - } return nil } diff --git a/providers/dns/dreamhost/internal/client.go b/providers/dns/dreamhost/internal/client.go new file mode 100644 index 00000000..dee808ac --- /dev/null +++ b/providers/dns/dreamhost/internal/client.go @@ -0,0 +1,114 @@ +package internal + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" +) + +// DefaultBaseURL the default API endpoint. +const DefaultBaseURL = "https://api.dreamhost.com" + +const ( + cmdAddRecord = "dns-add_record" + cmdRemoveRecord = "dns-remove_record" +) + +// Client the Dreamhost API client. +type Client struct { + apiKey string + + BaseURL string + HTTPClient *http.Client +} + +// NewClient Creates a new Client. +func NewClient(apiKey string) *Client { + return &Client{ + apiKey: apiKey, + BaseURL: DefaultBaseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + } +} + +// AddRecord adds a TXT record. +func (c *Client) AddRecord(ctx context.Context, domain, value string) error { + query, err := c.buildEndpoint(cmdAddRecord, domain, value) + if err != nil { + return err + } + + return c.updateTxtRecord(ctx, query) +} + +// RemoveRecord removes a TXT record. +func (c *Client) RemoveRecord(ctx context.Context, domain, value string) error { + query, err := c.buildEndpoint(cmdRemoveRecord, domain, value) + if err != nil { + return err + } + + return c.updateTxtRecord(ctx, query) +} + +// action is either cmdAddRecord or cmdRemoveRecord. +func (c *Client) buildEndpoint(action, domain, txt string) (*url.URL, error) { + endpoint, err := url.Parse(c.BaseURL) + if err != nil { + return nil, err + } + + query := endpoint.Query() + query.Set("key", c.apiKey) + query.Set("cmd", action) + query.Set("format", "json") + query.Set("record", domain) + query.Set("type", "TXT") + query.Set("value", txt) + query.Set("comment", url.QueryEscape("Managed By lego")) + endpoint.RawQuery = query.Encode() + + return endpoint, nil +} + +// updateTxtRecord will either add or remove a TXT record. +func (c *Client) updateTxtRecord(ctx context.Context, endpoint *url.URL) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), http.NoBody) + if err != nil { + return fmt.Errorf("unable to create request: %w", err) + } + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + var response apiResponse + err = json.Unmarshal(raw, &response) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + if response.Result == "error" { + return fmt.Errorf("add TXT record failed: %s", response.Data) + } + + return nil +} diff --git a/providers/dns/dreamhost/client_test.go b/providers/dns/dreamhost/internal/client_test.go similarity index 71% rename from providers/dns/dreamhost/client_test.go rename to providers/dns/dreamhost/internal/client_test.go index c8d195bd..348c50ce 100644 --- a/providers/dns/dreamhost/client_test.go +++ b/providers/dns/dreamhost/internal/client_test.go @@ -1,4 +1,4 @@ -package dreamhost +package internal import ( "testing" @@ -7,7 +7,9 @@ import ( "github.com/stretchr/testify/require" ) -func TestDNSProvider_buildQuery(t *testing.T) { +const fakeAPIKey = "asdf1234" + +func TestClient_buildQuery(t *testing.T) { testCases := []struct { desc string apiKey string @@ -40,23 +42,18 @@ func TestDNSProvider_buildQuery(t *testing.T) { t.Run(test.desc, func(t *testing.T) { t.Parallel() - config := NewDefaultConfig() - config.APIKey = test.apiKey + client := NewClient(test.apiKey) if test.baseURL != "" { - config.BaseURL = test.baseURL + client.BaseURL = test.baseURL } - provider, err := NewDNSProviderConfig(config) - require.NoError(t, err) - require.NotNil(t, provider) - - u, err := provider.buildQuery(test.action, test.domain, test.txt) + endpoint, err := client.buildEndpoint(test.action, test.domain, test.txt) if test.expected == "" { require.Error(t, err) } else { require.NoError(t, err) - assert.Equal(t, test.expected, u.String()) + assert.Equal(t, test.expected, endpoint.String()) } }) } diff --git a/providers/dns/dreamhost/internal/types.go b/providers/dns/dreamhost/internal/types.go new file mode 100644 index 00000000..6a1e903f --- /dev/null +++ b/providers/dns/dreamhost/internal/types.go @@ -0,0 +1,6 @@ +package internal + +type apiResponse struct { + Data string `json:"data"` + Result string `json:"result"` +} diff --git a/providers/dns/duckdns/client.go b/providers/dns/duckdns/client.go deleted file mode 100644 index 5eb9cb44..00000000 --- a/providers/dns/duckdns/client.go +++ /dev/null @@ -1,68 +0,0 @@ -package duckdns - -import ( - "fmt" - "io" - "net/url" - "strconv" - "strings" - - "github.com/go-acme/lego/v4/challenge/dns01" - "github.com/miekg/dns" -) - -// updateTxtRecord Update the domains TXT record -// To update the TXT record we just need to make one simple get request. -// In DuckDNS you only have one TXT record shared with the domain and all sub domains. -func (d *DNSProvider) updateTxtRecord(domain, token, txt string, clear bool) error { - u, _ := url.Parse("https://www.duckdns.org/update") - - mainDomain := getMainDomain(domain) - if mainDomain == "" { - return fmt.Errorf("unable to find the main domain for: %s", domain) - } - - query := u.Query() - query.Set("domains", mainDomain) - query.Set("token", token) - query.Set("clear", strconv.FormatBool(clear)) - query.Set("txt", txt) - u.RawQuery = query.Encode() - - response, err := d.config.HTTPClient.Get(u.String()) - if err != nil { - return err - } - defer response.Body.Close() - - bodyBytes, err := io.ReadAll(response.Body) - if err != nil { - return err - } - - body := string(bodyBytes) - if body != "OK" { - return fmt.Errorf("request to change TXT record for DuckDNS returned the following result (%s) this does not match expectation (OK) used url [%s]", body, u) - } - return nil -} - -// DuckDNS only lets you write to your subdomain. -// It must be in format subdomain.duckdns.org, -// not in format subsubdomain.subdomain.duckdns.org. -// So strip off everything that is not top 3 levels. -func getMainDomain(domain string) string { - domain = dns01.UnFqdn(domain) - - split := dns.Split(domain) - if strings.HasSuffix(strings.ToLower(domain), "duckdns.org") { - if len(split) < 3 { - return "" - } - - firstSubDomainIndex := split[len(split)-3] - return domain[firstSubDomainIndex:] - } - - return domain[split[len(split)-1]:] -} diff --git a/providers/dns/duckdns/duckdns.go b/providers/dns/duckdns/duckdns.go index 65dafab7..8cb82aed 100644 --- a/providers/dns/duckdns/duckdns.go +++ b/providers/dns/duckdns/duckdns.go @@ -3,6 +3,7 @@ package duckdns import ( + "context" "errors" "fmt" "net/http" @@ -10,6 +11,7 @@ import ( "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/platform/config/env" + "github.com/go-acme/lego/v4/providers/dns/duckdns/internal" ) // Environment variables names. @@ -48,6 +50,7 @@ func NewDefaultConfig() *Config { // DNSProvider implements the challenge.Provider interface. type DNSProvider struct { config *Config + client *internal.Client } // NewDNSProvider returns a new DNS provider using @@ -74,19 +77,25 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("duckdns: credentials missing") } - return &DNSProvider{config: config}, nil + client := internal.NewClient(config.Token) + + if config.HTTPClient != nil { + client.HTTPClient = config.HTTPClient + } + + return &DNSProvider{config: config, client: client}, nil } // Present creates a TXT record to fulfill the dns-01 challenge. func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - return d.updateTxtRecord(dns01.UnFqdn(info.EffectiveFQDN), d.config.Token, info.Value, false) + return d.client.AddTXTRecord(context.Background(), dns01.UnFqdn(info.EffectiveFQDN), info.Value) } // CleanUp clears DuckDNS TXT record. func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - return d.updateTxtRecord(dns01.UnFqdn(info.EffectiveFQDN), d.config.Token, "", true) + return d.client.RemoveTXTRecord(context.Background(), dns01.UnFqdn(info.EffectiveFQDN)) } // Timeout returns the timeout and interval to use when checking for DNS propagation. diff --git a/providers/dns/duckdns/duckdns_test.go b/providers/dns/duckdns/duckdns_test.go index 0fd291ab..b89966a3 100644 --- a/providers/dns/duckdns/duckdns_test.go +++ b/providers/dns/duckdns/duckdns_test.go @@ -5,7 +5,6 @@ import ( "time" "github.com/go-acme/lego/v4/platform/tester" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -89,65 +88,6 @@ func TestNewDNSProviderConfig(t *testing.T) { } } -func Test_getMainDomain(t *testing.T) { - testCases := []struct { - desc string - domain string - expected string - }{ - { - desc: "empty", - domain: "", - expected: "", - }, - { - desc: "missing sub domain", - domain: "duckdns.org", - expected: "", - }, - { - desc: "explicit domain: sub domain", - domain: "_acme-challenge.sub.duckdns.org", - expected: "sub.duckdns.org", - }, - { - desc: "explicit domain: subsub domain", - domain: "_acme-challenge.my.sub.duckdns.org", - expected: "sub.duckdns.org", - }, - { - desc: "explicit domain: subsubsub domain", - domain: "_acme-challenge.my.sub.sub.duckdns.org", - expected: "sub.duckdns.org", - }, - { - desc: "only subname: sub domain", - domain: "_acme-challenge.sub", - expected: "sub", - }, - { - desc: "only subname: subsub domain", - domain: "_acme-challenge.my.sub", - expected: "sub", - }, - { - desc: "only subname: subsubsub domain", - domain: "_acme-challenge.my.sub.sub", - expected: "sub", - }, - } - - for _, test := range testCases { - test := test - t.Run(test.desc, func(t *testing.T) { - t.Parallel() - - wDomain := getMainDomain(test.domain) - assert.Equal(t, test.expected, wDomain) - }) - } -} - func TestLivePresent(t *testing.T) { if !envTest.IsLiveTest() { t.Skip("skipping live test") diff --git a/providers/dns/duckdns/internal/client.go b/providers/dns/duckdns/internal/client.go new file mode 100644 index 00000000..e1985ee7 --- /dev/null +++ b/providers/dns/duckdns/internal/client.go @@ -0,0 +1,103 @@ +package internal + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/go-acme/lego/v4/challenge/dns01" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" + "github.com/miekg/dns" +) + +const defaultBaseURL = "https://www.duckdns.org/update" + +// Client the DuckDNS API client. +type Client struct { + token string + + HTTPClient *http.Client +} + +// NewClient Creates a new Client. +func NewClient(token string) *Client { + return &Client{ + token: token, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + } +} + +func (c Client) AddTXTRecord(ctx context.Context, domain, value string) error { + return c.UpdateTxtRecord(ctx, domain, value, false) +} + +func (c Client) RemoveTXTRecord(ctx context.Context, domain string) error { + return c.UpdateTxtRecord(ctx, domain, "", true) +} + +// UpdateTxtRecord Update the domains TXT record +// To update the TXT record we just need to make one simple get request. +// In DuckDNS you only have one TXT record shared with the domain and all subdomains. +func (c Client) UpdateTxtRecord(ctx context.Context, domain, txt string, clear bool) error { + endpoint, _ := url.Parse(defaultBaseURL) + + mainDomain := getMainDomain(domain) + if mainDomain == "" { + return fmt.Errorf("unable to find the main domain for: %s", domain) + } + + query := endpoint.Query() + query.Set("domains", mainDomain) + query.Set("token", c.token) + query.Set("clear", strconv.FormatBool(clear)) + query.Set("txt", txt) + endpoint.RawQuery = query.Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), http.NoBody) + if err != nil { + return fmt.Errorf("unable to create request: %w", err) + } + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + body := string(raw) + if body != "OK" { + return fmt.Errorf("request to change TXT record for DuckDNS returned the following result (%s) this does not match expectation (OK) used url [%s]", body, endpoint) + } + return nil +} + +// DuckDNS only lets you write to your subdomain. +// It must be in format subdomain.duckdns.org, +// not in format subsubdomain.subdomain.duckdns.org. +// So strip off everything that is not top 3 levels. +func getMainDomain(domain string) string { + domain = dns01.UnFqdn(domain) + + split := dns.Split(domain) + if strings.HasSuffix(strings.ToLower(domain), "duckdns.org") { + if len(split) < 3 { + return "" + } + + firstSubDomainIndex := split[len(split)-3] + return domain[firstSubDomainIndex:] + } + + return domain[split[len(split)-1]:] +} diff --git a/providers/dns/duckdns/internal/client_test.go b/providers/dns/duckdns/internal/client_test.go new file mode 100644 index 00000000..ec3196a7 --- /dev/null +++ b/providers/dns/duckdns/internal/client_test.go @@ -0,0 +1,66 @@ +package internal + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_getMainDomain(t *testing.T) { + testCases := []struct { + desc string + domain string + expected string + }{ + { + desc: "empty", + domain: "", + expected: "", + }, + { + desc: "missing sub domain", + domain: "duckdns.org", + expected: "", + }, + { + desc: "explicit domain: sub domain", + domain: "_acme-challenge.sub.duckdns.org", + expected: "sub.duckdns.org", + }, + { + desc: "explicit domain: subsub domain", + domain: "_acme-challenge.my.sub.duckdns.org", + expected: "sub.duckdns.org", + }, + { + desc: "explicit domain: subsubsub domain", + domain: "_acme-challenge.my.sub.sub.duckdns.org", + expected: "sub.duckdns.org", + }, + { + desc: "only subname: sub domain", + domain: "_acme-challenge.sub", + expected: "sub", + }, + { + desc: "only subname: subsub domain", + domain: "_acme-challenge.my.sub", + expected: "sub", + }, + { + desc: "only subname: subsubsub domain", + domain: "_acme-challenge.my.sub.sub", + expected: "sub", + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + wDomain := getMainDomain(test.domain) + assert.Equal(t, test.expected, wDomain) + }) + } +} diff --git a/providers/dns/dyn/client.go b/providers/dns/dyn/client.go deleted file mode 100644 index fecef33c..00000000 --- a/providers/dns/dyn/client.go +++ /dev/null @@ -1,147 +0,0 @@ -package dyn - -import ( - "bytes" - "encoding/json" - "errors" - "fmt" - "net/http" -) - -const defaultBaseURL = "https://api.dynect.net/REST" - -type dynResponse struct { - // One of 'success', 'failure', or 'incomplete' - Status string `json:"status"` - - // The structure containing the actual results of the request - Data json.RawMessage `json:"data"` - - // The ID of the job that was created in response to a request. - JobID int `json:"job_id"` - - // A list of zero or more messages - Messages json.RawMessage `json:"msgs"` -} - -type credentials struct { - Customer string `json:"customer_name"` - User string `json:"user_name"` - Pass string `json:"password"` -} - -type session struct { - Token string `json:"token"` - Version string `json:"version"` -} - -type publish struct { - Publish bool `json:"publish"` - Notes string `json:"notes"` -} - -// Starts a new Dyn API Session. Authenticates using customerName, userName, -// password and receives a token to be used in for subsequent requests. -func (d *DNSProvider) login() error { - payload := &credentials{Customer: d.config.CustomerName, User: d.config.UserName, Pass: d.config.Password} - dynRes, err := d.sendRequest(http.MethodPost, "Session", payload) - if err != nil { - return err - } - - var s session - err = json.Unmarshal(dynRes.Data, &s) - if err != nil { - return err - } - - d.token = s.Token - - return nil -} - -// Destroys Dyn Session. -func (d *DNSProvider) logout() error { - if d.token == "" { - // nothing to do - return nil - } - - url := fmt.Sprintf("%s/Session", defaultBaseURL) - req, err := http.NewRequest(http.MethodDelete, url, nil) - if err != nil { - return err - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Auth-Token", d.token) - - resp, err := d.config.HTTPClient.Do(req) - if err != nil { - return err - } - resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("API request failed to delete session with HTTP status code %d", resp.StatusCode) - } - - d.token = "" - - return nil -} - -func (d *DNSProvider) publish(zone, notes string) error { - pub := &publish{Publish: true, Notes: notes} - resource := fmt.Sprintf("Zone/%s/", zone) - - _, err := d.sendRequest(http.MethodPut, resource, pub) - return err -} - -func (d *DNSProvider) sendRequest(method, resource string, payload interface{}) (*dynResponse, error) { - url := fmt.Sprintf("%s/%s", defaultBaseURL, resource) - - body, err := json.Marshal(payload) - if err != nil { - return nil, err - } - - req, err := http.NewRequest(method, url, bytes.NewReader(body)) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - if len(d.token) > 0 { - req.Header.Set("Auth-Token", d.token) - } - - resp, err := d.config.HTTPClient.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - if resp.StatusCode >= http.StatusInternalServerError { - return nil, fmt.Errorf("API request failed with HTTP status code %d", resp.StatusCode) - } - - var dynRes dynResponse - err = json.NewDecoder(resp.Body).Decode(&dynRes) - if err != nil { - return nil, err - } - - if resp.StatusCode >= http.StatusBadRequest { - return nil, fmt.Errorf("API request failed with HTTP status code %d: %s", resp.StatusCode, dynRes.Messages) - } else if resp.StatusCode == http.StatusTemporaryRedirect { - // TODO add support for HTTP 307 response and long running jobs - return nil, errors.New("API request returned HTTP 307. This is currently unsupported") - } - - if dynRes.Status == "failure" { - // TODO add better error handling - return nil, fmt.Errorf("API request failed: %s", dynRes.Messages) - } - - return &dynRes, nil -} diff --git a/providers/dns/dyn/dyn.go b/providers/dns/dyn/dyn.go index d0dec91a..1b2d8254 100644 --- a/providers/dns/dyn/dyn.go +++ b/providers/dns/dyn/dyn.go @@ -2,14 +2,15 @@ package dyn import ( + "context" "errors" "fmt" "net/http" - "strconv" "time" "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/platform/config/env" + "github.com/go-acme/lego/v4/providers/dns/dyn/internal" ) // Environment variables names. @@ -52,7 +53,7 @@ func NewDefaultConfig() *Config { // DNSProvider implements the challenge.Provider interface. type DNSProvider struct { config *Config - token string + client *internal.Client } // NewDNSProvider returns a DNSProvider instance configured for Dyn DNS. @@ -82,7 +83,13 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("dyn: credentials missing") } - return &DNSProvider{config: config}, nil + client := internal.NewClient(config.CustomerName, config.UserName, config.Password) + + if config.HTTPClient != nil { + client.HTTPClient = config.HTTPClient + } + + return &DNSProvider{config: config, client: client}, nil } // Present creates a TXT record using the specified parameters. @@ -91,33 +98,25 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("dyn: %w", err) + return fmt.Errorf("dyn: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } - err = d.login() + ctx, err := d.client.CreateAuthenticatedContext(context.Background()) if err != nil { return fmt.Errorf("dyn: %w", err) } - data := map[string]interface{}{ - "rdata": map[string]string{ - "txtdata": info.Value, - }, - "ttl": strconv.Itoa(d.config.TTL), - } - - resource := fmt.Sprintf("TXTRecord/%s/%s/", authZone, info.EffectiveFQDN) - _, err = d.sendRequest(http.MethodPost, resource, data) + err = d.client.AddTXTRecord(ctx, authZone, info.EffectiveFQDN, info.Value, d.config.TTL) if err != nil { return fmt.Errorf("dyn: %w", err) } - err = d.publish(authZone, "Added TXT record for ACME dns-01 challenge using lego client") + err = d.client.Publish(ctx, authZone, "Added TXT record for ACME dns-01 challenge using lego client") if err != nil { return fmt.Errorf("dyn: %w", err) } - return d.logout() + return d.client.Logout(ctx) } // CleanUp removes the TXT record matching the specified parameters. @@ -126,41 +125,25 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("dyn: %w", err) + return fmt.Errorf("dyn: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } - err = d.login() + ctx, err := d.client.CreateAuthenticatedContext(context.Background()) if err != nil { return fmt.Errorf("dyn: %w", err) } - resource := fmt.Sprintf("TXTRecord/%s/%s/", authZone, info.EffectiveFQDN) - url := fmt.Sprintf("%s/%s", defaultBaseURL, resource) - - req, err := http.NewRequest(http.MethodDelete, url, nil) + err = d.client.RemoveTXTRecord(ctx, authZone, info.EffectiveFQDN) if err != nil { return fmt.Errorf("dyn: %w", err) } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Auth-Token", d.token) - - resp, err := d.config.HTTPClient.Do(req) - if err != nil { - return fmt.Errorf("dyn: %w", err) - } - resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("dyn: API request failed to delete TXT record HTTP status code %d", resp.StatusCode) - } - - err = d.publish(authZone, "Removed TXT record for ACME dns-01 challenge using lego client") + err = d.client.Publish(ctx, authZone, "Removed TXT record for ACME dns-01 challenge using lego client") if err != nil { return fmt.Errorf("dyn: %w", err) } - return d.logout() + return d.client.Logout(ctx) } // Timeout returns the timeout and interval to use when checking for DNS propagation. diff --git a/providers/dns/dyn/internal/client.go b/providers/dns/dyn/internal/client.go new file mode 100644 index 00000000..43981cc4 --- /dev/null +++ b/providers/dns/dyn/internal/client.go @@ -0,0 +1,178 @@ +package internal + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" +) + +const defaultBaseURL = "https://api.dynect.net/REST" + +// Client the Dyn API client. +type Client struct { + customerName string + username string + password string + + baseURL *url.URL + HTTPClient *http.Client +} + +// NewClient Creates a new Client. +func NewClient(customerName string, username string, password string) *Client { + baseURL, _ := url.Parse(defaultBaseURL) + + return &Client{ + customerName: customerName, + username: username, + password: password, + baseURL: baseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + } +} + +// Publish updating Zone settings. +// https://help.dyn.com/update-zone-api/ +func (c *Client) Publish(ctx context.Context, zone, notes string) error { + endpoint := c.baseURL.JoinPath("Zone", zone) + + payload := &publish{Publish: true, Notes: notes} + + req, err := newJSONRequest(ctx, http.MethodPut, endpoint, payload) + if err != nil { + return err + } + + _, err = c.do(req) + if err != nil { + return err + } + + return nil +} + +// AddTXTRecord creating TXT Records. +// https://help.dyn.com/create-txt-record-api/ +func (c *Client) AddTXTRecord(ctx context.Context, authZone, fqdn, value string, ttl int) error { + endpoint := c.baseURL.JoinPath("TXTRecord", authZone, fqdn) + + payload := map[string]any{ + "rdata": map[string]string{ + "txtdata": value, + }, + "ttl": strconv.Itoa(ttl), + } + + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, payload) + if err != nil { + return err + } + + _, err = c.do(req) + if err != nil { + return err + } + + return nil +} + +// RemoveTXTRecord deleting one or all existing TXT Records. +// https://help.dyn.com/delete-txt-records-api/ +func (c *Client) RemoveTXTRecord(ctx context.Context, authZone, fqdn string) error { + endpoint := c.baseURL.JoinPath("TXTRecord", authZone, fqdn) + + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) + if err != nil { + return err + } + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + return nil +} + +func (c *Client) do(req *http.Request) (*APIResponse, error) { + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= http.StatusInternalServerError { + return nil, errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + var response APIResponse + err = json.Unmarshal(raw, &response) + if err != nil { + return nil, errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + if resp.StatusCode >= http.StatusBadRequest { + return nil, fmt.Errorf("%s: %w", response.Messages, errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw)) + } + + if resp.StatusCode == http.StatusTemporaryRedirect { + // TODO add support for HTTP 307 response and long running jobs + return nil, errors.New("API request returned HTTP 307. This is currently unsupported") + } + + if response.Status == "failure" { + return nil, fmt.Errorf("%s: %w", response.Messages, errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw)) + } + + return &response, nil +} + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + tok := getToken(req.Context()) + if tok != "" { + req.Header.Set(authTokenHeader, tok) + } + + return req, nil +} diff --git a/providers/dns/dyn/internal/client_test.go b/providers/dns/dyn/internal/client_test.go new file mode 100644 index 00000000..87bee1cd --- /dev/null +++ b/providers/dns/dyn/internal/client_test.go @@ -0,0 +1,122 @@ +package internal + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func setupTest(t *testing.T, pattern string, handlerFunc http.HandlerFunc) *Client { + t.Helper() + + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + mux.HandleFunc(pattern, handlerFunc) + + client := NewClient("bob", "user", "secret") + client.HTTPClient = server.Client() + client.baseURL, _ = url.Parse(server.URL) + + return client +} + +func authenticatedHandler(method string, status int, file string) http.HandlerFunc { + return func(rw http.ResponseWriter, req *http.Request) { + if req.Method != method { + http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusBadRequest) + return + } + + token := req.Header.Get(authTokenHeader) + if token != "tok" { + http.Error(rw, fmt.Sprintf("invalid credentials: %q", token), http.StatusUnauthorized) + return + } + + if file == "" { + rw.WriteHeader(status) + return + } + + open, err := os.Open(filepath.Join("fixtures", file)) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + + defer func() { _ = open.Close() }() + + rw.WriteHeader(status) + _, err = io.Copy(rw, open) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + } +} + +func unauthenticatedHandler(method string, status int, file string) http.HandlerFunc { + return func(rw http.ResponseWriter, req *http.Request) { + if req.Method != method { + http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusBadRequest) + return + } + + token := req.Header.Get(authTokenHeader) + if token != "" { + http.Error(rw, fmt.Sprintf("invalid credentials: %q", token), http.StatusUnauthorized) + return + } + + if file == "" { + rw.WriteHeader(status) + return + } + + open, err := os.Open(filepath.Join("fixtures", file)) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + + defer func() { _ = open.Close() }() + + rw.WriteHeader(status) + _, err = io.Copy(rw, open) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + } +} + +func TestClient_Publish(t *testing.T) { + client := setupTest(t, "/Zone/example.com", unauthenticatedHandler(http.MethodPut, http.StatusOK, "publish.json")) + + err := client.Publish(context.Background(), "example.com", "my message") + require.NoError(t, err) +} + +func TestClient_AddTXTRecord(t *testing.T) { + client := setupTest(t, "/TXTRecord/example.com/example.com.", unauthenticatedHandler(http.MethodPost, http.StatusCreated, "create-txt-record.json")) + + err := client.AddTXTRecord(context.Background(), "example.com", "example.com.", "txt", 120) + require.NoError(t, err) +} + +func TestClient_RemoveTXTRecord(t *testing.T) { + client := setupTest(t, "/TXTRecord/example.com/example.com.", unauthenticatedHandler(http.MethodDelete, http.StatusOK, "")) + + err := client.RemoveTXTRecord(context.Background(), "example.com", "example.com.") + require.NoError(t, err) +} diff --git a/providers/dns/dyn/internal/fixtures/create-txt-record.json b/providers/dns/dyn/internal/fixtures/create-txt-record.json new file mode 100644 index 00000000..fd09a5d4 --- /dev/null +++ b/providers/dns/dyn/internal/fixtures/create-txt-record.json @@ -0,0 +1,10 @@ +{ + "fqdn": "example.com.", + "rdata": { + "txtdata": "txt" + }, + "record_type": "TXT", + "ttl": 120, + "zone": "example.com" +} + diff --git a/providers/dns/dyn/internal/fixtures/login.json b/providers/dns/dyn/internal/fixtures/login.json new file mode 100644 index 00000000..86434d7b --- /dev/null +++ b/providers/dns/dyn/internal/fixtures/login.json @@ -0,0 +1,9 @@ +{ + "status": "success", + "data": { + "token": "tok", + "version": "456" + }, + "job_id": 123, + "msgs": [] +} diff --git a/providers/dns/dyn/internal/fixtures/publish.json b/providers/dns/dyn/internal/fixtures/publish.json new file mode 100644 index 00000000..c7e27945 --- /dev/null +++ b/providers/dns/dyn/internal/fixtures/publish.json @@ -0,0 +1,6 @@ +{ + "status": "success", + "data": {}, + "job_id": 123, + "msgs": [] +} diff --git a/providers/dns/dyn/internal/session.go b/providers/dns/dyn/internal/session.go new file mode 100644 index 00000000..647080fa --- /dev/null +++ b/providers/dns/dyn/internal/session.go @@ -0,0 +1,89 @@ +package internal + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" +) + +type token string + +const tokenKey token = "token" + +const authTokenHeader = "Auth-Token" + +// login Starts a new Dyn API Session. Authenticates using customerName, username, password +// and receives a token to be used in for subsequent requests. +// https://help.dyn.com/session-log-in/ +func (c *Client) login(ctx context.Context) (session, error) { + endpoint := c.baseURL.JoinPath("Session") + + payload := &credentials{Customer: c.customerName, User: c.username, Pass: c.password} + + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, payload) + if err != nil { + return session{}, err + } + + dynRes, err := c.do(req) + if err != nil { + return session{}, err + } + + var s session + err = json.Unmarshal(dynRes.Data, &s) + if err != nil { + return session{}, errutils.NewUnmarshalError(req, http.StatusOK, dynRes.Data, err) + } + + return s, nil +} + +// Logout Destroys Dyn Session. +// https://help.dyn.com/session-log-out/ +func (c *Client) Logout(ctx context.Context) error { + endpoint := c.baseURL.JoinPath("Session") + + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) + if err != nil { + return err + } + + tok := getToken(ctx) + if tok != "" { + req.Header.Set(authTokenHeader, tok) + } + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + return nil +} + +func (c *Client) CreateAuthenticatedContext(ctx context.Context) (context.Context, error) { + tok, err := c.login(ctx) + if err != nil { + return nil, err + } + + return context.WithValue(ctx, tokenKey, tok.Token), nil +} + +func getToken(ctx context.Context) string { + tok, ok := ctx.Value(tokenKey).(string) + if !ok { + return "" + } + + return tok +} diff --git a/providers/dns/dyn/internal/session_test.go b/providers/dns/dyn/internal/session_test.go new file mode 100644 index 00000000..76d5bef4 --- /dev/null +++ b/providers/dns/dyn/internal/session_test.go @@ -0,0 +1,42 @@ +package internal + +import ( + "context" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func mockContext() context.Context { + return context.WithValue(context.Background(), tokenKey, "tok") +} + +func TestClient_login(t *testing.T) { + client := setupTest(t, "/Session", unauthenticatedHandler(http.MethodPost, http.StatusOK, "login.json")) + + sess, err := client.login(context.Background()) + require.NoError(t, err) + + expected := session{Token: "tok", Version: "456"} + + assert.Equal(t, expected, sess) +} + +func TestClient_Logout(t *testing.T) { + client := setupTest(t, "/Session", authenticatedHandler(http.MethodDelete, http.StatusOK, "")) + + err := client.Logout(mockContext()) + require.NoError(t, err) +} + +func TestClient_CreateAuthenticatedContext(t *testing.T) { + client := setupTest(t, "/Session", unauthenticatedHandler(http.MethodPost, http.StatusOK, "login.json")) + + ctx, err := client.CreateAuthenticatedContext(context.Background()) + require.NoError(t, err) + + at := getToken(ctx) + assert.Equal(t, "tok", at) +} diff --git a/providers/dns/dyn/internal/types.go b/providers/dns/dyn/internal/types.go new file mode 100644 index 00000000..2b039c4e --- /dev/null +++ b/providers/dns/dyn/internal/types.go @@ -0,0 +1,33 @@ +package internal + +import "encoding/json" + +type APIResponse struct { + // One of 'success', 'failure', or 'incomplete' + Status string `json:"status"` + + // The structure containing the actual results of the request + Data json.RawMessage `json:"data"` + + // The ID of the job that was created in response to a request. + JobID int `json:"job_id"` + + // A list of zero or more messages + Messages json.RawMessage `json:"msgs"` +} + +type credentials struct { + Customer string `json:"customer_name"` + User string `json:"user_name"` + Pass string `json:"password"` +} + +type session struct { + Token string `json:"token"` + Version string `json:"version"` +} + +type publish struct { + Publish bool `json:"publish"` + Notes string `json:"notes"` +} diff --git a/providers/dns/dynu/dynu.go b/providers/dns/dynu/dynu.go index a533a1c2..d0c396a2 100644 --- a/providers/dns/dynu/dynu.go +++ b/providers/dns/dynu/dynu.go @@ -2,6 +2,7 @@ package dynu import ( + "context" "errors" "fmt" "net/http" @@ -97,12 +98,14 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - rootDomain, err := d.client.GetRootDomain(dns01.UnFqdn(info.EffectiveFQDN)) + ctx := context.Background() + + rootDomain, err := d.client.GetRootDomain(ctx, dns01.UnFqdn(info.EffectiveFQDN)) if err != nil { return fmt.Errorf("dynu: could not find root domain for %s: %w", domain, err) } - records, err := d.client.GetRecords(dns01.UnFqdn(info.EffectiveFQDN), "TXT") + records, err := d.client.GetRecords(ctx, dns01.UnFqdn(info.EffectiveFQDN), "TXT") if err != nil { return fmt.Errorf("dynu: failed to get records for %s: %w", domain, err) } @@ -129,7 +132,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { TTL: d.config.TTL, } - err = d.client.AddNewRecord(rootDomain.ID, record) + err = d.client.AddNewRecord(ctx, rootDomain.ID, record) if err != nil { return fmt.Errorf("dynu: failed to add record to %s: %w", domain, err) } @@ -141,19 +144,21 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - rootDomain, err := d.client.GetRootDomain(dns01.UnFqdn(info.EffectiveFQDN)) + ctx := context.Background() + + rootDomain, err := d.client.GetRootDomain(ctx, dns01.UnFqdn(info.EffectiveFQDN)) if err != nil { return fmt.Errorf("dynu: could not find root domain for %s: %w", domain, err) } - records, err := d.client.GetRecords(dns01.UnFqdn(info.EffectiveFQDN), "TXT") + records, err := d.client.GetRecords(ctx, dns01.UnFqdn(info.EffectiveFQDN), "TXT") if err != nil { return fmt.Errorf("dynu: failed to get records for %s: %w", domain, err) } for _, record := range records { if record.Hostname == dns01.UnFqdn(info.EffectiveFQDN) && record.TextData == info.Value { - err = d.client.DeleteRecord(rootDomain.ID, record.ID) + err = d.client.DeleteRecord(ctx, rootDomain.ID, record.ID) if err != nil { return fmt.Errorf("dynu: failed to remove TXT record for %s: %w", domain, err) } diff --git a/providers/dns/dynu/internal/client.go b/providers/dns/dynu/internal/client.go index a65681ca..d9e6e5bf 100644 --- a/providers/dns/dynu/internal/client.go +++ b/providers/dns/dynu/internal/client.go @@ -2,6 +2,7 @@ package internal import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -13,35 +14,35 @@ import ( "github.com/cenkalti/backoff/v4" "github.com/go-acme/lego/v4/log" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) const defaultBaseURL = "https://api.dynu.com/v2" type Client struct { + baseURL *url.URL HTTPClient *http.Client - BaseURL string } func NewClient() *Client { + baseURL, _ := url.Parse(defaultBaseURL) + return &Client{ - HTTPClient: http.DefaultClient, - BaseURL: defaultBaseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + baseURL: baseURL, } } // GetRecords Get DNS records based on a hostname and resource record type. -func (c Client) GetRecords(hostname, recordType string) ([]DNSRecord, error) { - endpoint, err := c.createEndpoint("dns", "record", hostname) - if err != nil { - return nil, err - } +func (c Client) GetRecords(ctx context.Context, hostname, recordType string) ([]DNSRecord, error) { + endpoint := c.baseURL.JoinPath("dns", "record", hostname) query := endpoint.Query() query.Set("recordType", recordType) endpoint.RawQuery = query.Encode() apiResp := RecordsResponse{} - err = c.doRetry(http.MethodGet, endpoint.String(), nil, &apiResp) + err := c.doRetry(ctx, http.MethodGet, endpoint.String(), nil, &apiResp) if err != nil { return nil, err } @@ -54,19 +55,16 @@ func (c Client) GetRecords(hostname, recordType string) ([]DNSRecord, error) { } // AddNewRecord Add a new DNS record for DNS service. -func (c Client) AddNewRecord(domainID int64, record DNSRecord) error { - endpoint, err := c.createEndpoint("dns", strconv.FormatInt(domainID, 10), "record") - if err != nil { - return err - } +func (c Client) AddNewRecord(ctx context.Context, domainID int64, record DNSRecord) error { + endpoint := c.baseURL.JoinPath("dns", strconv.FormatInt(domainID, 10), "record") reqBody, err := json.Marshal(record) if err != nil { - return err + return fmt.Errorf("failed to create request JSON body: %w", err) } apiResp := RecordResponse{} - err = c.doRetry(http.MethodPost, endpoint.String(), reqBody, &apiResp) + err = c.doRetry(ctx, http.MethodPost, endpoint.String(), reqBody, &apiResp) if err != nil { return err } @@ -79,14 +77,11 @@ func (c Client) AddNewRecord(domainID int64, record DNSRecord) error { } // DeleteRecord Remove a DNS record from DNS service. -func (c Client) DeleteRecord(domainID, recordID int64) error { - endpoint, err := c.createEndpoint("dns", strconv.FormatInt(domainID, 10), "record", strconv.FormatInt(recordID, 10)) - if err != nil { - return err - } +func (c Client) DeleteRecord(ctx context.Context, domainID, recordID int64) error { + endpoint := c.baseURL.JoinPath("dns", strconv.FormatInt(domainID, 10), "record", strconv.FormatInt(recordID, 10)) apiResp := APIException{} - err = c.doRetry(http.MethodDelete, endpoint.String(), nil, &apiResp) + err := c.doRetry(ctx, http.MethodDelete, endpoint.String(), nil, &apiResp) if err != nil { return err } @@ -99,14 +94,11 @@ func (c Client) DeleteRecord(domainID, recordID int64) error { } // GetRootDomain Get the root domain name based on a hostname. -func (c Client) GetRootDomain(hostname string) (*DNSHostname, error) { - endpoint, err := c.createEndpoint("dns", "getroot", hostname) - if err != nil { - return nil, err - } +func (c Client) GetRootDomain(ctx context.Context, hostname string) (*DNSHostname, error) { + endpoint := c.baseURL.JoinPath("dns", "getroot", hostname) apiResp := DNSHostname{} - err = c.doRetry(http.MethodGet, endpoint.String(), nil, &apiResp) + err := c.doRetry(ctx, http.MethodGet, endpoint.String(), nil, &apiResp) if err != nil { return nil, err } @@ -119,33 +111,9 @@ func (c Client) GetRootDomain(hostname string) (*DNSHostname, error) { } // doRetry the API is really unstable so we need to retry on EOF. -func (c Client) doRetry(method, uri string, body []byte, data interface{}) error { - var resp *http.Response - +func (c Client) doRetry(ctx context.Context, method, uri string, body []byte, result any) error { operation := func() error { - var reqBody io.Reader - if len(body) > 0 { - reqBody = bytes.NewReader(body) - } - - req, err := http.NewRequest(method, uri, reqBody) - if err != nil { - return err - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json") - - resp, err = c.HTTPClient.Do(req) - if errors.Is(err, io.EOF) { - return err - } - - if err != nil { - return backoff.Permanent(fmt.Errorf("client error: %w", err)) - } - - return nil + return c.do(ctx, method, uri, body, result) } notify := func(err error, duration time.Duration) { @@ -160,21 +128,43 @@ func (c Client) doRetry(method, uri string, body []byte, data interface{}) error return err } + return nil +} + +func (c Client) do(ctx context.Context, method, uri string, body []byte, result any) error { + var reqBody io.Reader + if len(body) > 0 { + reqBody = bytes.NewReader(body) + } + + req, err := http.NewRequestWithContext(ctx, method, uri, reqBody) + if err != nil { + return fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/json") + + resp, err := c.HTTPClient.Do(req) + if errors.Is(err, io.EOF) { + return err + } + + if err != nil { + return backoff.Permanent(fmt.Errorf("client error: %w", errutils.NewHTTPDoError(req, err))) + } + defer func() { _ = resp.Body.Close() }() - all, err := io.ReadAll(resp.Body) + raw, err := io.ReadAll(resp.Body) if err != nil { - return fmt.Errorf("failed to read response body: %w", err) + return backoff.Permanent(errutils.NewReadResponseError(req, resp.StatusCode, err)) } - return json.Unmarshal(all, data) -} - -func (c Client) createEndpoint(fragments ...string) (*url.URL, error) { - baseURL, err := url.Parse(c.BaseURL) + err = json.Unmarshal(raw, result) if err != nil { - return nil, err + return backoff.Permanent(errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)) } - return baseURL.JoinPath(fragments...), nil + return nil } diff --git a/providers/dns/dynu/internal/client_test.go b/providers/dns/dynu/internal/client_test.go index 56eef940..005ceff1 100644 --- a/providers/dns/dynu/internal/client_test.go +++ b/providers/dns/dynu/internal/client_test.go @@ -1,10 +1,12 @@ package internal import ( + "context" "fmt" "io" "net/http" "net/http/httptest" + "net/url" "os" "testing" @@ -43,7 +45,7 @@ func setupTest(t *testing.T, method, pattern string, status int, file string) *C client := NewClient() client.HTTPClient = server.Client() - client.BaseURL = server.URL + client.baseURL, _ = url.Parse(server.URL) return client } @@ -96,7 +98,7 @@ func TestGetRootDomain(t *testing.T) { client := setupTest(t, http.MethodGet, test.pattern, test.status, test.file) - domain, err := client.GetRootDomain("test.lego.freeddns.org") + domain, err := client.GetRootDomain(context.Background(), "test.lego.freeddns.org") if test.expected.error != "" { assert.EqualError(t, err, test.expected.error) @@ -185,7 +187,7 @@ func TestGetRecords(t *testing.T) { client := setupTest(t, http.MethodGet, test.pattern, test.status, test.file) - records, err := client.GetRecords("_acme-challenge.lego.freeddns.org", "TXT") + records, err := client.GetRecords(context.Background(), "_acme-challenge.lego.freeddns.org", "TXT") if test.expected.error != "" { assert.EqualError(t, err, test.expected.error) @@ -246,7 +248,7 @@ func TestAddNewRecord(t *testing.T) { TTL: 300, } - err := client.AddNewRecord(9007481, record) + err := client.AddNewRecord(context.Background(), 9007481, record) if test.expected.error != "" { assert.EqualError(t, err, test.expected.error) @@ -294,7 +296,7 @@ func TestDeleteRecord(t *testing.T) { client := setupTest(t, http.MethodDelete, test.pattern, test.status, test.file) - err := client.DeleteRecord(9007481, 6041418) + err := client.DeleteRecord(context.Background(), 9007481, 6041418) if test.expected.error != "" { assert.EqualError(t, err, test.expected.error) diff --git a/providers/dns/dynu/internal/model.go b/providers/dns/dynu/internal/types.go similarity index 100% rename from providers/dns/dynu/internal/model.go rename to providers/dns/dynu/internal/types.go diff --git a/providers/dns/easydns/client.go b/providers/dns/easydns/client.go deleted file mode 100644 index 2d3950ef..00000000 --- a/providers/dns/easydns/client.go +++ /dev/null @@ -1,96 +0,0 @@ -package easydns - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" -) - -const defaultEndpoint = "https://rest.easydns.net" - -type zoneRecord struct { - ID string `json:"id,omitempty"` - Domain string `json:"domain"` - Host string `json:"host"` - TTL string `json:"ttl"` - Prio string `json:"prio"` - Type string `json:"type"` - Rdata string `json:"rdata"` - LastMod string `json:"last_mod,omitempty"` - Revoked int `json:"revoked,omitempty"` - NewHost string `json:"new_host,omitempty"` -} - -type addRecordResponse struct { - Msg string `json:"msg"` - Tm int `json:"tm"` - Data zoneRecord `json:"data"` - Status int `json:"status"` -} - -func (d *DNSProvider) addRecord(domain string, record interface{}) (string, error) { - endpoint := d.config.Endpoint.JoinPath("zones", "records", "add", domain, "TXT") - - response := &addRecordResponse{} - err := d.doRequest(http.MethodPut, endpoint, record, response) - if err != nil { - return "", err - } - - recordID := response.Data.ID - - return recordID, nil -} - -func (d *DNSProvider) deleteRecord(domain, recordID string) error { - endpoint := d.config.Endpoint.JoinPath("zones", "records", domain, recordID) - - return d.doRequest(http.MethodDelete, endpoint, nil, nil) -} - -func (d *DNSProvider) doRequest(method string, endpoint *url.URL, requestMsg, responseMsg interface{}) error { - reqBody := &bytes.Buffer{} - if requestMsg != nil { - err := json.NewEncoder(reqBody).Encode(requestMsg) - if err != nil { - return err - } - } - - query := endpoint.Query() - query.Set("format", "json") - endpoint.RawQuery = query.Encode() - - request, err := http.NewRequest(method, endpoint.String(), reqBody) - if err != nil { - return err - } - - request.Header.Set("Content-Type", "application/json") - request.Header.Set("Accept", "application/json") - request.SetBasicAuth(d.config.Token, d.config.Key) - - response, err := d.config.HTTPClient.Do(request) - if err != nil { - return err - } - defer response.Body.Close() - - if response.StatusCode >= http.StatusBadRequest { - body, err := io.ReadAll(response.Body) - if err != nil { - return fmt.Errorf("%d: failed to read response body: %w", response.StatusCode, err) - } - - return fmt.Errorf("%d: request failed: %v", response.StatusCode, string(body)) - } - - if responseMsg != nil { - return json.NewDecoder(response.Body).Decode(responseMsg) - } - - return nil -} diff --git a/providers/dns/easydns/easydns.go b/providers/dns/easydns/easydns.go index 5ec4b97c..40d2ec0c 100644 --- a/providers/dns/easydns/easydns.go +++ b/providers/dns/easydns/easydns.go @@ -2,6 +2,7 @@ package easydns import ( + "context" "errors" "fmt" "net/http" @@ -13,6 +14,7 @@ import ( "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/platform/config/env" + "github.com/go-acme/lego/v4/providers/dns/easydns/internal" "github.com/miekg/dns" ) @@ -58,7 +60,9 @@ func NewDefaultConfig() *Config { // DNSProvider implements the challenge.Provider interface. type DNSProvider struct { - config *Config + config *Config + client *internal.Client + recordIDs map[string]string recordIDsMu sync.Mutex } @@ -67,7 +71,7 @@ type DNSProvider struct { func NewDNSProvider() (*DNSProvider, error) { config := NewDefaultConfig() - endpoint, err := url.Parse(env.GetOrDefaultString(EnvEndpoint, defaultEndpoint)) + endpoint, err := url.Parse(env.GetOrDefaultString(EnvEndpoint, internal.DefaultBaseURL)) if err != nil { return nil, fmt.Errorf("easydns: %w", err) } @@ -98,7 +102,17 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("easydns: the API key is missing") } - return &DNSProvider{config: config, recordIDs: map[string]string{}}, nil + client := internal.NewClient(config.Token, config.Key) + + if config.HTTPClient != nil { + client.HTTPClient = config.HTTPClient + } + + if config.Endpoint != nil { + client.BaseURL = config.Endpoint + } + + return &DNSProvider{config: config, client: client, recordIDs: map[string]string{}}, nil } // Present creates a TXT record to fulfill the dns-01 challenge. @@ -106,16 +120,17 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) apiHost, apiDomain := splitFqdn(info.EffectiveFQDN) - record := &zoneRecord{ - Domain: apiDomain, - Host: apiHost, - Type: "TXT", - Rdata: info.Value, - TTL: strconv.Itoa(d.config.TTL), - Prio: "0", + + record := internal.ZoneRecord{ + Domain: apiDomain, + Host: apiHost, + Type: "TXT", + Rdata: info.Value, + TTL: strconv.Itoa(d.config.TTL), + Priority: "0", } - recordID, err := d.addRecord(apiDomain, record) + recordID, err := d.client.AddRecord(context.Background(), apiDomain, record) if err != nil { return fmt.Errorf("easydns: error adding zone record: %w", err) } @@ -134,13 +149,18 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) key := getMapKey(info.EffectiveFQDN, info.Value) + + d.recordIDsMu.Lock() recordID, exists := d.recordIDs[key] + d.recordIDsMu.Unlock() + if !exists { return nil } _, apiDomain := splitFqdn(info.EffectiveFQDN) - err := d.deleteRecord(apiDomain, recordID) + + err := d.client.DeleteRecord(context.Background(), apiDomain, recordID) d.recordIDsMu.Lock() defer delete(d.recordIDs, key) diff --git a/providers/dns/easydns/easydns_test.go b/providers/dns/easydns/easydns_test.go index f67c7406..ea1f854c 100644 --- a/providers/dns/easydns/easydns_test.go +++ b/providers/dns/easydns/easydns_test.go @@ -14,6 +14,8 @@ import ( "github.com/stretchr/testify/require" ) +const authorizationHeader = "Authorization" + const envDomain = envNamespace + "DOMAIN" var envTest = tester.NewEnvTest( @@ -149,7 +151,7 @@ func TestDNSProvider_Present(t *testing.T) { assert.Equal(t, http.MethodPut, r.Method, "method") assert.Equal(t, "format=json", r.URL.RawQuery, "query") assert.Equal(t, "application/json", r.Header.Get("Content-Type"), "Content-Type") - assert.Equal(t, "Basic VE9LRU46U0VDUkVU", r.Header.Get("Authorization"), "Authorization") + assert.Equal(t, "Basic VE9LRU46U0VDUkVU", r.Header.Get(authorizationHeader), authorizationHeader) reqBody, err := io.ReadAll(r.Body) if err != nil { @@ -201,7 +203,7 @@ func TestDNSProvider_Cleanup_WhenRecordIdSet_DeletesTxtRecord(t *testing.T) { mux.HandleFunc("/zones/records/example.com/123456", func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, http.MethodDelete, r.Method, "method") assert.Equal(t, "format=json", r.URL.RawQuery, "query") - assert.Equal(t, "Basic VE9LRU46U0VDUkVU", r.Header.Get("Authorization"), "Authorization") + assert.Equal(t, "Basic VE9LRU46U0VDUkVU", r.Header.Get(authorizationHeader), authorizationHeader) w.WriteHeader(http.StatusOK) _, err := fmt.Fprintf(w, `{ @@ -235,7 +237,7 @@ func TestDNSProvider_Cleanup_WhenHttpError_ReturnsError(t *testing.T) { mux.HandleFunc("/zones/records/example.com/123456", func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, http.MethodDelete, r.Method, "method") assert.Equal(t, "format=json", r.URL.RawQuery, "query") - assert.Equal(t, "Basic VE9LRU46U0VDUkVU", r.Header.Get("Authorization"), "Authorization") + assert.Equal(t, "Basic VE9LRU46U0VDUkVU", r.Header.Get(authorizationHeader), authorizationHeader) w.WriteHeader(http.StatusNotAcceptable) _, err := fmt.Fprint(w, errorMessage) @@ -247,7 +249,7 @@ func TestDNSProvider_Cleanup_WhenHttpError_ReturnsError(t *testing.T) { provider.recordIDs["_acme-challenge.example.com.|pW9ZKG0xz_PCriK-nCMOjADy9eJcgGWIzkkj2fN4uZM"] = "123456" err := provider.CleanUp("example.com", "token", "keyAuth") - expectedError := fmt.Sprintf("easydns: 406: request failed: %v", errorMessage) + expectedError := fmt.Sprintf("easydns: unexpected status code: [status code: 406] body: %v", errorMessage) require.EqualError(t, err, expectedError) } diff --git a/providers/dns/easydns/internal/client.go b/providers/dns/easydns/internal/client.go new file mode 100644 index 00000000..363a2fc7 --- /dev/null +++ b/providers/dns/easydns/internal/client.go @@ -0,0 +1,127 @@ +package internal + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" +) + +// DefaultBaseURL the default API endpoint. +const DefaultBaseURL = "https://rest.easydns.net" + +// Client the EasyDNS API client. +type Client struct { + token string + key string + + BaseURL *url.URL + HTTPClient *http.Client +} + +// NewClient Creates a new Client. +func NewClient(token string, key string) *Client { + baseURL, _ := url.Parse(DefaultBaseURL) + + return &Client{ + token: token, + key: key, + BaseURL: baseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + } +} + +func (c *Client) AddRecord(ctx context.Context, domain string, record ZoneRecord) (string, error) { + endpoint := c.BaseURL.JoinPath("zones", "records", "add", domain, "TXT") + + req, err := newJSONRequest(ctx, http.MethodPut, endpoint, record) + if err != nil { + return "", err + } + + response := &addRecordResponse{} + err = c.do(req, response) + if err != nil { + return "", err + } + + recordID := response.Data.ID + + return recordID, nil +} + +func (c *Client) DeleteRecord(ctx context.Context, domain, recordID string) error { + endpoint := c.BaseURL.JoinPath("zones", "records", domain, recordID) + + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) + if err != nil { + return err + } + + return c.do(req, nil) +} + +func (c *Client) do(req *http.Request, result any) error { + req.SetBasicAuth(c.token, c.key) + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode/100 != 2 { + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + if result == nil { + return nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return nil +} + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + query := endpoint.Query() + query.Set("format", "json") + endpoint.RawQuery = query.Encode() + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil +} diff --git a/providers/dns/easydns/internal/client_test.go b/providers/dns/easydns/internal/client_test.go new file mode 100644 index 00000000..7ea61d3c --- /dev/null +++ b/providers/dns/easydns/internal/client_test.go @@ -0,0 +1,93 @@ +package internal + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setupTest(t *testing.T, method, pattern string, status int, file string) *Client { + t.Helper() + + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + mux.HandleFunc(pattern, func(rw http.ResponseWriter, req *http.Request) { + if req.Method != method { + http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusBadRequest) + return + } + + token, key, ok := req.BasicAuth() + if token != "tok" || key != "k" || !ok { + http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + + if req.URL.Query().Get("format") != "json" { + http.Error(rw, fmt.Sprintf("invalid format: %s", req.URL.Query().Get("format")), http.StatusBadRequest) + return + } + + if file == "" { + rw.WriteHeader(status) + return + } + + open, err := os.Open(filepath.Join("fixtures", file)) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + + defer func() { _ = open.Close() }() + + rw.WriteHeader(status) + _, err = io.Copy(rw, open) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + }) + + client := NewClient("tok", "k") + client.HTTPClient = server.Client() + client.BaseURL, _ = url.Parse(server.URL) + + return client +} + +func TestClient_AddRecord(t *testing.T) { + client := setupTest(t, http.MethodPut, "/zones/records/add/example.com/TXT", http.StatusCreated, "add-record.json") + + record := ZoneRecord{ + Domain: "example.com", + Host: "test631", + Type: "TXT", + Rdata: "txt", + TTL: "300", + Priority: "0", + } + + recordID, err := client.AddRecord(context.Background(), "example.com", record) + require.NoError(t, err) + + assert.Equal(t, "xxx", recordID) +} + +func TestClient_DeleteRecord(t *testing.T) { + client := setupTest(t, http.MethodDelete, "/zones/records/example.com/xxx", http.StatusOK, "") + + err := client.DeleteRecord(context.Background(), "example.com", "xxx") + require.NoError(t, err) +} diff --git a/providers/dns/easydns/internal/fixtures/add-record.json b/providers/dns/easydns/internal/fixtures/add-record.json new file mode 100644 index 00000000..66ddf4bc --- /dev/null +++ b/providers/dns/easydns/internal/fixtures/add-record.json @@ -0,0 +1,14 @@ +{ + "msg": "message", + "tm": 1, + "data": { + "id": "xxx", + "domain": "example.com", + "host": "test631", + "ttl": "300", + "prio": "0", + "type": "TXT", + "rdata": "txt" + }, + "status": 201 +} diff --git a/providers/dns/easydns/internal/types.go b/providers/dns/easydns/internal/types.go new file mode 100644 index 00000000..5235c4d7 --- /dev/null +++ b/providers/dns/easydns/internal/types.go @@ -0,0 +1,21 @@ +package internal + +type ZoneRecord struct { + ID string `json:"id,omitempty"` + Domain string `json:"domain"` + Host string `json:"host"` + TTL string `json:"ttl"` + Priority string `json:"prio"` + Type string `json:"type"` + Rdata string `json:"rdata"` + LastMod string `json:"last_mod,omitempty"` + Revoked int `json:"revoked,omitempty"` + NewHost string `json:"new_host,omitempty"` +} + +type addRecordResponse struct { + Msg string `json:"msg"` + Tm int `json:"tm"` + Data ZoneRecord `json:"data"` + Status int `json:"status"` +} diff --git a/providers/dns/edgedns/edgedns.go b/providers/dns/edgedns/edgedns.go index 221f5b71..df6d93e7 100644 --- a/providers/dns/edgedns/edgedns.go +++ b/providers/dns/edgedns/edgedns.go @@ -109,7 +109,7 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - zone, err := findZone(info.EffectiveFQDN) + zone, err := getZone(info.EffectiveFQDN) if err != nil { return fmt.Errorf("edgedns: %w", err) } @@ -161,7 +161,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - zone, err := findZone(info.EffectiveFQDN) + zone, err := getZone(info.EffectiveFQDN) if err != nil { return fmt.Errorf("edgedns: %w", err) } @@ -214,10 +214,10 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return nil } -func findZone(domain string) (string, error) { +func getZone(domain string) (string, error) { zone, err := dns01.FindZoneByFqdn(domain) if err != nil { - return "", err + return "", fmt.Errorf("could not find zone for FQDN %q: %w", domain, err) } return dns01.UnFqdn(zone), nil diff --git a/providers/dns/edgedns/edgedns_integration_test.go b/providers/dns/edgedns/edgedns_integration_test.go index 5ad0a5e3..c4044d1e 100644 --- a/providers/dns/edgedns/edgedns_integration_test.go +++ b/providers/dns/edgedns/edgedns_integration_test.go @@ -66,7 +66,7 @@ func TestLiveTTL(t *testing.T) { }() fqdn := "_acme-challenge." + domain + "." - zone, err := findZone(fqdn) + zone, err := getZone(fqdn) require.NoError(t, err) resourceRecordSets, err := configdns.GetRecordList(zone, fqdn, "TXT") diff --git a/providers/dns/edgedns/edgedns_test.go b/providers/dns/edgedns/edgedns_test.go index 3e855292..a7f17b16 100644 --- a/providers/dns/edgedns/edgedns_test.go +++ b/providers/dns/edgedns/edgedns_test.go @@ -173,7 +173,7 @@ func TestDNSProvider_findZone(t *testing.T) { t.Run(test.desc, func(t *testing.T) { t.Parallel() - zone, err := findZone(test.domain) + zone, err := getZone(test.domain) require.NoError(t, err) require.Equal(t, test.expected, zone) }) diff --git a/providers/dns/epik/epik.go b/providers/dns/epik/epik.go index 0c82b7dc..8114a21c 100644 --- a/providers/dns/epik/epik.go +++ b/providers/dns/epik/epik.go @@ -2,6 +2,7 @@ package epik import ( + "context" "errors" "fmt" "net/http" @@ -98,7 +99,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { // find authZone authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("epik: %w", err) + return fmt.Errorf("epik: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) @@ -113,7 +114,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { TTL: d.config.TTL, } - _, err = d.client.CreateHostRecord(dns01.UnFqdn(authZone), record) + _, err = d.client.CreateHostRecord(context.Background(), dns01.UnFqdn(authZone), record) if err != nil { return fmt.Errorf("epik: %w", err) } @@ -128,12 +129,14 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { // find authZone authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("epik: %w", err) + return fmt.Errorf("epik: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } dom := dns01.UnFqdn(authZone) - records, err := d.client.GetDNSRecords(dom) + ctx := context.Background() + + records, err := d.client.GetDNSRecords(ctx, dom) if err != nil { return fmt.Errorf("epik: %w", err) } @@ -145,7 +148,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { for _, record := range records { if strings.EqualFold(record.Type, "TXT") && record.Data == info.Value && record.Name == subDomain { - _, err = d.client.RemoveHostRecord(dom, record.ID) + _, err = d.client.RemoveHostRecord(ctx, dom, record.ID) if err != nil { return fmt.Errorf("epik: %w", err) } diff --git a/providers/dns/epik/internal/client.go b/providers/dns/epik/internal/client.go index d5fb7829..0ca46c2c 100644 --- a/providers/dns/epik/internal/client.go +++ b/providers/dns/epik/internal/client.go @@ -2,56 +2,52 @@ package internal import ( "bytes" + "context" "encoding/json" "fmt" "io" "net/http" "net/url" "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) const defaultBaseURL = "https://usersapiv2.epik.com/v2" +// Client the Epik API client. type Client struct { - HTTPClient *http.Client + signature string + baseURL *url.URL - signature string + HTTPClient *http.Client } +// NewClient Creates a new Client. func NewClient(signature string) *Client { baseURL, _ := url.Parse(defaultBaseURL) return &Client{ - HTTPClient: &http.Client{Timeout: 5 * time.Second}, - baseURL: baseURL, signature: signature, + baseURL: baseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, } } // GetDNSRecords gets DNS records for a domain. // https://docs.userapi.epik.com/v2/#/DNS%20Host%20Records/getDnsRecord -func (c Client) GetDNSRecords(domain string) ([]Record, error) { - resp, err := c.do(http.MethodGet, domain, url.Values{}, nil) - if err != nil { - return nil, err - } +func (c Client) GetDNSRecords(ctx context.Context, domain string) ([]Record, error) { + endpoint := c.createEndpoint(domain, url.Values{}) - defer func() { _ = resp.Body.Close() }() - - all, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read request body (%d): %w", resp.StatusCode, err) - } - - err = checkError(resp.StatusCode, all) + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, err } var data GetDNSRecordResponse - err = json.Unmarshal(all, &data) + err = c.do(req, &data) if err != nil { - return nil, fmt.Errorf("failed to unmarshal request body (%d): %s", resp.StatusCode, string(all)) + return nil, err } return data.Data.Records, nil @@ -59,35 +55,20 @@ func (c Client) GetDNSRecords(domain string) ([]Record, error) { // CreateHostRecord creates a record for a domain. // https://docs.userapi.epik.com/v2/#/DNS%20Host%20Records/createHostRecord -func (c Client) CreateHostRecord(domain string, record RecordRequest) (*Data, error) { +func (c Client) CreateHostRecord(ctx context.Context, domain string, record RecordRequest) (*Data, error) { + endpoint := c.createEndpoint(domain, url.Values{}) + payload := CreateHostRecords{Payload: record} - body, err := json.Marshal(payload) - if err != nil { - return nil, err - } - - resp, err := c.do(http.MethodPost, domain, url.Values{}, bytes.NewReader(body)) - if err != nil { - return nil, err - } - - defer func() { _ = resp.Body.Close() }() - - all, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read request body (%d): %w", resp.StatusCode, err) - } - - err = checkError(resp.StatusCode, all) + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, payload) if err != nil { return nil, err } var data Data - err = json.Unmarshal(all, &data) + err = c.do(req, &data) if err != nil { - return nil, fmt.Errorf("%d: %s", resp.StatusCode, string(all)) + return nil, err } return &data, nil @@ -95,64 +76,95 @@ func (c Client) CreateHostRecord(domain string, record RecordRequest) (*Data, er // RemoveHostRecord removes a record for a domain. // https://docs.userapi.epik.com/v2/#/DNS%20Host%20Records/removeHostRecord -func (c Client) RemoveHostRecord(domain string, recordID string) (*Data, error) { +func (c Client) RemoveHostRecord(ctx context.Context, domain string, recordID string) (*Data, error) { params := url.Values{} params.Set("ID", recordID) - resp, err := c.do(http.MethodDelete, domain, params, nil) - if err != nil { - return nil, err - } + endpoint := c.createEndpoint(domain, params) - defer func() { _ = resp.Body.Close() }() - - all, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read request body (%d): %w", resp.StatusCode, err) - } - - err = checkError(resp.StatusCode, all) + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) if err != nil { return nil, err } var data Data - err = json.Unmarshal(all, &data) + err = c.do(req, &data) if err != nil { - return nil, fmt.Errorf("%d: %s", resp.StatusCode, string(all)) + return nil, err } return &data, nil } -func (c *Client) do(method, domain string, params url.Values, body io.Reader) (*http.Response, error) { +func (c Client) do(req *http.Request, result any) error { + resp, err := c.HTTPClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return parseError(req, resp) + } + + if result == nil { + return nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return nil +} + +func (c Client) createEndpoint(domain string, params url.Values) *url.URL { endpoint := c.baseURL.JoinPath("domains", domain, "records") params.Set("SIGNATURE", c.signature) endpoint.RawQuery = params.Encode() - req, err := http.NewRequest(method, endpoint.String(), body) + return endpoint +} + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) if err != nil { - return nil, err + return nil, fmt.Errorf("unable to create request: %w", err) } req.Header.Set("Accept", "application/json") - if body != nil { + + if payload != nil { req.Header.Set("Content-Type", "application/json") } - return c.HTTPClient.Do(req) + return req, nil } -func checkError(statusCode int, all []byte) error { - if statusCode == http.StatusOK { - return nil - } +func parseError(req *http.Request, resp *http.Response) error { + raw, _ := io.ReadAll(resp.Body) var apiErr APIError - err := json.Unmarshal(all, &apiErr) + err := json.Unmarshal(raw, &apiErr) if err != nil { - return fmt.Errorf("%d: %s", statusCode, string(all)) + return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw) } return &apiErr diff --git a/providers/dns/epik/internal/client_test.go b/providers/dns/epik/internal/client_test.go index 47159d24..a1d0186a 100644 --- a/providers/dns/epik/internal/client_test.go +++ b/providers/dns/epik/internal/client_test.go @@ -1,6 +1,7 @@ package internal import ( + "context" "fmt" "io" "net/http" @@ -14,8 +15,9 @@ import ( "github.com/stretchr/testify/require" ) -func setupTest(t *testing.T) (*http.ServeMux, *Client) { +func setupTest(t *testing.T) (*Client, *http.ServeMux) { t.Helper() + mux := http.NewServeMux() server := httptest.NewServer(mux) t.Cleanup(server.Close) @@ -24,15 +26,15 @@ func setupTest(t *testing.T) (*http.ServeMux, *Client) { client.HTTPClient = server.Client() client.baseURL, _ = url.Parse(server.URL) - return mux, client + return client, mux } func TestClient_GetDNSRecords(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/domains/example.com/records", testHandler(http.MethodGet, http.StatusOK, "getDnsRecord.json")) - records, err := client.GetDNSRecords("example.com") + records, err := client.GetDNSRecords(context.Background(), "example.com") require.NoError(t, err) expected := []Record{ @@ -87,16 +89,16 @@ func TestClient_GetDNSRecords(t *testing.T) { } func TestClient_GetDNSRecords_error(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/domains/example.com/records", testHandler(http.MethodGet, http.StatusUnauthorized, "error.json")) - _, err := client.GetDNSRecords("example.com") + _, err := client.GetDNSRecords(context.Background(), "example.com") assert.Error(t, err) } func TestClient_CreateHostRecord(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/domains/example.com/records", testHandler(http.MethodPost, http.StatusOK, "createHostRecord.json")) @@ -108,7 +110,7 @@ func TestClient_CreateHostRecord(t *testing.T) { TTL: 300, } - data, err := client.CreateHostRecord("example.com", record) + data, err := client.CreateHostRecord(context.Background(), "example.com", record) require.NoError(t, err) expected := &Data{ @@ -120,7 +122,7 @@ func TestClient_CreateHostRecord(t *testing.T) { } func TestClient_CreateHostRecord_error(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/domains/example.com/records", testHandler(http.MethodPost, http.StatusUnauthorized, "error.json")) @@ -132,16 +134,16 @@ func TestClient_CreateHostRecord_error(t *testing.T) { TTL: 300, } - _, err := client.CreateHostRecord("example.com", record) + _, err := client.CreateHostRecord(context.Background(), "example.com", record) assert.Error(t, err) } func TestClient_RemoveHostRecord(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/domains/example.com/records", testHandler(http.MethodDelete, http.StatusOK, "removeHostRecord.json")) - data, err := client.RemoveHostRecord("example.com", "abc123") + data, err := client.RemoveHostRecord(context.Background(), "example.com", "abc123") require.NoError(t, err) expected := &Data{ @@ -153,11 +155,11 @@ func TestClient_RemoveHostRecord(t *testing.T) { } func TestClient_RemoveHostRecord_error(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/domains/example.com/records", testHandler(http.MethodDelete, http.StatusUnauthorized, "error.json")) - _, err := client.RemoveHostRecord("example.com", "abc123") + _, err := client.RemoveHostRecord(context.Background(), "example.com", "abc123") assert.Error(t, err) } diff --git a/providers/dns/exec/exec.go b/providers/dns/exec/exec.go index a6252ac7..a07cba0a 100644 --- a/providers/dns/exec/exec.go +++ b/providers/dns/exec/exec.go @@ -2,6 +2,7 @@ package exec import ( + "context" "errors" "fmt" "os" @@ -67,7 +68,7 @@ func NewDNSProvider() (*DNSProvider, error) { // for adding and removing the DNS record. func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { if config == nil { - return nil, errors.New("the configuration is nil") + return nil, errors.New("exec: the configuration is nil") } return &DNSProvider{config: config}, nil @@ -75,42 +76,22 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { // Present creates a TXT record to fulfill the dns-01 challenge. func (d *DNSProvider) Present(domain, token, keyAuth string) error { - var args []string - if d.config.Mode == "RAW" { - args = []string{"present", "--", domain, token, keyAuth} - } else { - info := dns01.GetChallengeInfo(domain, keyAuth) - args = []string{"present", info.EffectiveFQDN, info.Value} + err := d.run(context.Background(), "present", domain, token, keyAuth) + if err != nil { + return fmt.Errorf("exec: %w", err) } - cmd := exec.Command(d.config.Program, args...) - - output, err := cmd.CombinedOutput() - if len(output) > 0 { - log.Println(string(output)) - } - - return err + return nil } // CleanUp removes the TXT record matching the specified parameters. func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { - var args []string - if d.config.Mode == "RAW" { - args = []string{"cleanup", "--", domain, token, keyAuth} - } else { - info := dns01.GetChallengeInfo(domain, keyAuth) - args = []string{"cleanup", info.EffectiveFQDN, info.Value} + err := d.run(context.Background(), "cleanup", domain, token, keyAuth) + if err != nil { + return fmt.Errorf("exec: %w", err) } - cmd := exec.Command(d.config.Program, args...) - - output, err := cmd.CombinedOutput() - if len(output) > 0 { - log.Println(string(output)) - } - - return err + return nil } // Timeout returns the timeout and interval to use when checking for DNS propagation. @@ -124,3 +105,22 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { func (d *DNSProvider) Sequential() time.Duration { return d.config.SequenceInterval } + +func (d *DNSProvider) run(ctx context.Context, command, domain, token, keyAuth string) error { + var args []string + if d.config.Mode == "RAW" { + args = []string{command, "--", domain, token, keyAuth} + } else { + info := dns01.GetChallengeInfo(domain, keyAuth) + args = []string{command, info.EffectiveFQDN, info.Value} + } + + cmd := exec.CommandContext(ctx, d.config.Program, args...) + + output, err := cmd.CombinedOutput() + if len(output) > 0 { + log.Println(string(output)) + } + + return err +} diff --git a/providers/dns/exoscale/exoscale.go b/providers/dns/exoscale/exoscale.go index e92bdd44..770899f9 100644 --- a/providers/dns/exoscale/exoscale.go +++ b/providers/dns/exoscale/exoscale.go @@ -246,7 +246,7 @@ func (d *DNSProvider) findExistingRecordID(zoneID, recordName string) (string, e func (d *DNSProvider) findZoneAndRecordName(fqdn string) (string, string, error) { zone, err := dns01.FindZoneByFqdn(fqdn) if err != nil { - return "", "", err + return "", "", fmt.Errorf("designate: could not find zone for FQDN %q: %w", fqdn, err) } zone = dns01.UnFqdn(zone) diff --git a/providers/dns/gandi/client.go b/providers/dns/gandi/client.go deleted file mode 100644 index acdc8a4c..00000000 --- a/providers/dns/gandi/client.go +++ /dev/null @@ -1,322 +0,0 @@ -package gandi - -import ( - "bytes" - "encoding/xml" - "errors" - "fmt" - "io" -) - -// types for XML-RPC method calls and parameters - -type param interface { - param() -} - -type paramString struct { - XMLName xml.Name `xml:"param"` - Value string `xml:"value>string"` -} - -type paramInt struct { - XMLName xml.Name `xml:"param"` - Value int `xml:"value>int"` -} - -type structMember interface { - structMember() -} - -type structMemberString struct { - Name string `xml:"name"` - Value string `xml:"value>string"` -} - -type structMemberInt struct { - Name string `xml:"name"` - Value int `xml:"value>int"` -} - -type paramStruct struct { - XMLName xml.Name `xml:"param"` - StructMembers []structMember `xml:"value>struct>member"` -} - -func (p paramString) param() {} -func (p paramInt) param() {} -func (m structMemberString) structMember() {} -func (m structMemberInt) structMember() {} -func (p paramStruct) param() {} - -type methodCall struct { - XMLName xml.Name `xml:"methodCall"` - MethodName string `xml:"methodName"` - Params []param `xml:"params"` -} - -// types for XML-RPC responses - -type response interface { - faultCode() int - faultString() string -} - -type responseFault struct { - FaultCode int `xml:"fault>value>struct>member>value>int"` - FaultString string `xml:"fault>value>struct>member>value>string"` -} - -func (r responseFault) faultCode() int { return r.FaultCode } -func (r responseFault) faultString() string { return r.FaultString } - -type responseStruct struct { - responseFault - StructMembers []struct { - Name string `xml:"name"` - ValueInt int `xml:"value>int"` - } `xml:"params>param>value>struct>member"` -} - -type responseInt struct { - responseFault - Value int `xml:"params>param>value>int"` -} - -type responseBool struct { - responseFault - Value bool `xml:"params>param>value>boolean"` -} - -type rpcError struct { - faultCode int - faultString string -} - -func (e rpcError) Error() string { - return fmt.Sprintf("Gandi DNS: RPC Error: (%d) %s", e.faultCode, e.faultString) -} - -// rpcCall makes an XML-RPC call to Gandi's RPC endpoint by -// marshaling the data given in the call argument to XML and sending -// that via HTTP Post to Gandi. -// The response is then unmarshalled into the resp argument. -func (d *DNSProvider) rpcCall(call *methodCall, resp response) error { - // marshal - b, err := xml.MarshalIndent(call, "", " ") - if err != nil { - return fmt.Errorf("marshal error: %w", err) - } - - // post - b = append([]byte(``+"\n"), b...) - respBody, err := d.httpPost(d.config.BaseURL, "text/xml", bytes.NewReader(b)) - if err != nil { - return err - } - - // unmarshal - err = xml.Unmarshal(respBody, resp) - if err != nil { - return fmt.Errorf("unmarshal error: %w", err) - } - if resp.faultCode() != 0 { - return rpcError{ - faultCode: resp.faultCode(), faultString: resp.faultString(), - } - } - return nil -} - -// functions to perform API actions - -func (d *DNSProvider) getZoneID(domain string) (int, error) { - resp := &responseStruct{} - err := d.rpcCall(&methodCall{ - MethodName: "domain.info", - Params: []param{ - paramString{Value: d.config.APIKey}, - paramString{Value: domain}, - }, - }, resp) - if err != nil { - return 0, err - } - - var zoneID int - for _, member := range resp.StructMembers { - if member.Name == "zone_id" { - zoneID = member.ValueInt - } - } - - if zoneID == 0 { - return 0, fmt.Errorf("could not determine zone_id for %s", domain) - } - return zoneID, nil -} - -func (d *DNSProvider) cloneZone(zoneID int, name string) (int, error) { - resp := &responseStruct{} - err := d.rpcCall(&methodCall{ - MethodName: "domain.zone.clone", - Params: []param{ - paramString{Value: d.config.APIKey}, - paramInt{Value: zoneID}, - paramInt{Value: 0}, - paramStruct{ - StructMembers: []structMember{ - structMemberString{ - Name: "name", - Value: name, - }, - }, - }, - }, - }, resp) - if err != nil { - return 0, err - } - - var newZoneID int - for _, member := range resp.StructMembers { - if member.Name == "id" { - newZoneID = member.ValueInt - } - } - - if newZoneID == 0 { - return 0, errors.New("could not determine cloned zone_id") - } - return newZoneID, nil -} - -func (d *DNSProvider) newZoneVersion(zoneID int) (int, error) { - resp := &responseInt{} - err := d.rpcCall(&methodCall{ - MethodName: "domain.zone.version.new", - Params: []param{ - paramString{Value: d.config.APIKey}, - paramInt{Value: zoneID}, - }, - }, resp) - if err != nil { - return 0, err - } - - if resp.Value == 0 { - return 0, errors.New("could not create new zone version") - } - return resp.Value, nil -} - -func (d *DNSProvider) addTXTRecord(zoneID, version int, name, value string, ttl int) error { - resp := &responseStruct{} - err := d.rpcCall(&methodCall{ - MethodName: "domain.zone.record.add", - Params: []param{ - paramString{Value: d.config.APIKey}, - paramInt{Value: zoneID}, - paramInt{Value: version}, - paramStruct{ - StructMembers: []structMember{ - structMemberString{ - Name: "type", - Value: "TXT", - }, structMemberString{ - Name: "name", - Value: name, - }, structMemberString{ - Name: "value", - Value: value, - }, structMemberInt{ - Name: "ttl", - Value: ttl, - }, - }, - }, - }, - }, resp) - return err -} - -func (d *DNSProvider) setZoneVersion(zoneID, version int) error { - resp := &responseBool{} - err := d.rpcCall(&methodCall{ - MethodName: "domain.zone.version.set", - Params: []param{ - paramString{Value: d.config.APIKey}, - paramInt{Value: zoneID}, - paramInt{Value: version}, - }, - }, resp) - if err != nil { - return err - } - - if !resp.Value { - return errors.New("could not set zone version") - } - return nil -} - -func (d *DNSProvider) setZone(domain string, zoneID int) error { - resp := &responseStruct{} - err := d.rpcCall(&methodCall{ - MethodName: "domain.zone.set", - Params: []param{ - paramString{Value: d.config.APIKey}, - paramString{Value: domain}, - paramInt{Value: zoneID}, - }, - }, resp) - if err != nil { - return err - } - - var respZoneID int - for _, member := range resp.StructMembers { - if member.Name == "zone_id" { - respZoneID = member.ValueInt - } - } - - if respZoneID != zoneID { - return fmt.Errorf("could not set new zone_id for %s", domain) - } - return nil -} - -func (d *DNSProvider) deleteZone(zoneID int) error { - resp := &responseBool{} - err := d.rpcCall(&methodCall{ - MethodName: "domain.zone.delete", - Params: []param{ - paramString{Value: d.config.APIKey}, - paramInt{Value: zoneID}, - }, - }, resp) - if err != nil { - return err - } - - if !resp.Value { - return errors.New("could not delete zone_id") - } - return nil -} - -func (d *DNSProvider) httpPost(url, bodyType string, body io.Reader) ([]byte, error) { - resp, err := d.config.HTTPClient.Post(url, bodyType, body) - if err != nil { - return nil, fmt.Errorf("HTTP Post Error: %w", err) - } - defer resp.Body.Close() - - b, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("HTTP Post Error: %w", err) - } - - return b, nil -} diff --git a/providers/dns/gandi/gandi.go b/providers/dns/gandi/gandi.go index ccf46c78..29af01a9 100644 --- a/providers/dns/gandi/gandi.go +++ b/providers/dns/gandi/gandi.go @@ -2,6 +2,7 @@ package gandi import ( + "context" "errors" "fmt" "net/http" @@ -10,16 +11,13 @@ import ( "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/platform/config/env" + "github.com/go-acme/lego/v4/providers/dns/gandi/internal" ) // Gandi API reference: http://doc.rpc.gandi.net/index.html // Gandi API domain examples: http://doc.rpc.gandi.net/domain/faq.html -const ( - // defaultBaseURL Gandi XML-RPC endpoint used by Present and CleanUp. - defaultBaseURL = "https://rpc.gandi.net/xmlrpc/" - minTTL = 300 -) +const minTTL = 300 // Environment variables names. const ( @@ -64,11 +62,16 @@ type inProgressInfo struct { // DNSProvider implements the challenge.Provider interface. type DNSProvider struct { + config *Config + client *internal.Client + inProgressFQDNs map[string]inProgressInfo inProgressAuthZones map[string]struct{} inProgressMu sync.Mutex - config *Config - // findZoneByFqdn determines the DNS zone of an fqdn. It is overridden during tests. + + // findZoneByFqdn determines the DNS zone of a FQDN. + // It is overridden during tests. + // only for testing purpose. findZoneByFqdn func(fqdn string) (string, error) } @@ -96,12 +99,19 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("gandi: no API Key given") } - if config.BaseURL == "" { - config.BaseURL = defaultBaseURL + client := internal.NewClient(config.APIKey) + + if config.BaseURL != "" { + client.BaseURL = config.BaseURL + } + + if config.HTTPClient != nil { + client.HTTPClient = config.HTTPClient } return &DNSProvider{ config: config, + client: client, inProgressFQDNs: make(map[string]inProgressInfo), inProgressAuthZones: make(map[string]struct{}), findZoneByFqdn: dns01.FindZoneByFqdn, @@ -121,10 +131,12 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { // find authZone and Gandi zone_id for fqdn authZone, err := d.findZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("gandi: findZoneByFqdn failure: %w", err) + return fmt.Errorf("gandi: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } - zoneID, err := d.getZoneID(authZone) + ctx := context.Background() + + zoneID, err := d.client.GetZoneID(ctx, authZone) if err != nil { return fmt.Errorf("gandi: %w", err) } @@ -148,27 +160,27 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { // containing the required TXT record newZoneName := fmt.Sprintf("%s [ACME Challenge %s]", dns01.UnFqdn(authZone), time.Now().Format(time.RFC822Z)) - newZoneID, err := d.cloneZone(zoneID, newZoneName) + newZoneID, err := d.client.CloneZone(ctx, zoneID, newZoneName) if err != nil { return err } - newZoneVersion, err := d.newZoneVersion(newZoneID) + newZoneVersion, err := d.client.NewZoneVersion(ctx, newZoneID) if err != nil { return fmt.Errorf("gandi: %w", err) } - err = d.addTXTRecord(newZoneID, newZoneVersion, subDomain, info.Value, d.config.TTL) + err = d.client.AddTXTRecord(ctx, newZoneID, newZoneVersion, subDomain, info.Value, d.config.TTL) if err != nil { return fmt.Errorf("gandi: %w", err) } - err = d.setZoneVersion(newZoneID, newZoneVersion) + err = d.client.SetZoneVersion(ctx, newZoneID, newZoneVersion) if err != nil { return fmt.Errorf("gandi: %w", err) } - err = d.setZone(authZone, newZoneID) + err = d.client.SetZone(ctx, authZone, newZoneID) if err != nil { return fmt.Errorf("gandi: %w", err) } @@ -205,13 +217,15 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { delete(d.inProgressFQDNs, info.EffectiveFQDN) delete(d.inProgressAuthZones, authZone) + ctx := context.Background() + // perform API actions to restore old gandi zone for authZone - err := d.setZone(authZone, zoneID) + err := d.client.SetZone(ctx, authZone, zoneID) if err != nil { return fmt.Errorf("gandi: %w", err) } - return d.deleteZone(newZoneID) + return d.client.DeleteZone(ctx, newZoneID) } // Timeout returns the values (40*time.Minute, 60*time.Second) which diff --git a/providers/dns/gandi/gandi_mock_test.go b/providers/dns/gandi/gandi_mock_test.go index 970588ef..34783fe8 100644 --- a/providers/dns/gandi/gandi_mock_test.go +++ b/providers/dns/gandi/gandi_mock_test.go @@ -1,7 +1,7 @@ package gandi // CleanUp Request->Response 1 (setZone). -const cleanupSetZoneRequestMock = ` +const cleanupSetZoneRequestMock = ` domain.zone.set @@ -22,7 +22,7 @@ const cleanupSetZoneRequestMock = ` ` // CleanUp Request->Response 1 (setZone). -const cleanupSetZoneResponseMock = ` +const cleanupSetZoneResponseMock = ` @@ -192,7 +192,7 @@ const cleanupSetZoneResponseMock = ` ` // CleanUp Request->Response 2 (deleteZone). -const cleanupDeleteZoneRequestMock = ` +const cleanupDeleteZoneRequestMock = ` domain.zone.delete @@ -208,7 +208,7 @@ const cleanupDeleteZoneRequestMock = ` ` // CleanUp Request->Response 2 (deleteZone). -const cleanupDeleteZoneResponseMock = ` +const cleanupDeleteZoneResponseMock = ` @@ -219,7 +219,7 @@ const cleanupDeleteZoneResponseMock = ` ` // Present Request->Response 1 (getZoneID). -const presentGetZoneIDRequestMock = ` +const presentGetZoneIDRequestMock = ` domain.info @@ -235,7 +235,7 @@ const presentGetZoneIDRequestMock = ` ` // Present Request->Response 1 (getZoneID). -const presentGetZoneIDResponseMock = ` +const presentGetZoneIDResponseMock = ` @@ -405,7 +405,7 @@ const presentGetZoneIDResponseMock = ` ` // Present Request->Response 2 (cloneZone). -const presentCloneZoneRequestMock = ` +const presentCloneZoneRequestMock = ` domain.zone.clone @@ -438,7 +438,7 @@ const presentCloneZoneRequestMock = ` ` // Present Request->Response 2 (cloneZone). -const presentCloneZoneResponseMock = ` +const presentCloneZoneResponseMock = ` @@ -484,7 +484,7 @@ const presentCloneZoneResponseMock = ` ` // Present Request->Response 3 (newZoneVersion). -const presentNewZoneVersionRequestMock = ` +const presentNewZoneVersionRequestMock = ` domain.zone.version.new @@ -500,7 +500,7 @@ const presentNewZoneVersionRequestMock = ` ` // Present Request->Response 3 (newZoneVersion). -const presentNewZoneVersionResponseMock = ` +const presentNewZoneVersionResponseMock = ` @@ -511,7 +511,7 @@ const presentNewZoneVersionResponseMock = ` ` // Present Request->Response 4 (addTXTRecord). -const presentAddTXTRecordRequestMock = ` +const presentAddTXTRecordRequestMock = ` domain.zone.record.add @@ -562,7 +562,7 @@ const presentAddTXTRecordRequestMock = ` ` // Present Request->Response 4 (addTXTRecord). -const presentAddTXTRecordResponseMock = ` +const presentAddTXTRecordResponseMock = ` @@ -594,7 +594,7 @@ const presentAddTXTRecordResponseMock = ` ` // Present Request->Response 5 (setZoneVersion). -const presentSetZoneVersionRequestMock = ` +const presentSetZoneVersionRequestMock = ` domain.zone.version.set @@ -615,7 +615,7 @@ const presentSetZoneVersionRequestMock = ` ` // Present Request->Response 5 (setZoneVersion). -const presentSetZoneVersionResponseMock = ` +const presentSetZoneVersionResponseMock = ` @@ -626,7 +626,7 @@ const presentSetZoneVersionResponseMock = ` ` // Present Request->Response 6 (setZone). -const presentSetZoneRequestMock = ` +const presentSetZoneRequestMock = ` domain.zone.set @@ -647,7 +647,7 @@ const presentSetZoneRequestMock = ` ` // Present Request->Response 6 (setZone). -const presentSetZoneResponseMock = ` +const presentSetZoneResponseMock = ` diff --git a/providers/dns/gandi/gandi_test.go b/providers/dns/gandi/gandi_test.go index f53ef3c2..36bc4ccd 100644 --- a/providers/dns/gandi/gandi_test.go +++ b/providers/dns/gandi/gandi_test.go @@ -132,7 +132,7 @@ func TestDNSProvider(t *testing.T) { req = regexpDate.ReplaceAllLiteral(req, []byte(`[ACME Challenge 01 Jan 16 00:00 +0000]`)) resp, ok := serverResponses[string(req)] - require.True(t, ok, "Server response for request not found") + require.Truef(t, ok, "Server response for request not found: %s", string(req)) _, errS = io.Copy(w, strings.NewReader(resp)) require.NoError(t, errS) diff --git a/providers/dns/gandi/internal/client.go b/providers/dns/gandi/internal/client.go new file mode 100644 index 00000000..6dc09648 --- /dev/null +++ b/providers/dns/gandi/internal/client.go @@ -0,0 +1,289 @@ +package internal + +import ( + "bytes" + "context" + "encoding/xml" + "errors" + "fmt" + "io" + "net/http" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" +) + +// defaultBaseURL Gandi XML-RPC endpoint used by Present and CleanUp. +const defaultBaseURL = "https://rpc.gandi.net/xmlrpc/" + +// Client the Gandi API client. +type Client struct { + apiKey string + + BaseURL string + HTTPClient *http.Client +} + +// NewClient Creates a new Client. +func NewClient(apiKey string) *Client { + return &Client{ + apiKey: apiKey, + BaseURL: defaultBaseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + } +} + +func (c *Client) GetZoneID(ctx context.Context, domain string) (int, error) { + call := &methodCall{ + MethodName: "domain.info", + Params: []param{ + paramString{Value: c.apiKey}, + paramString{Value: domain}, + }, + } + + resp := &responseStruct{} + + err := c.rpcCall(ctx, call, resp) + if err != nil { + return 0, err + } + + var zoneID int + for _, member := range resp.StructMembers { + if member.Name == "zone_id" { + zoneID = member.ValueInt + } + } + + if zoneID == 0 { + return 0, fmt.Errorf("could not find zone_id for %s", domain) + } + return zoneID, nil +} + +func (c *Client) CloneZone(ctx context.Context, zoneID int, name string) (int, error) { + call := &methodCall{ + MethodName: "domain.zone.clone", + Params: []param{ + paramString{Value: c.apiKey}, + paramInt{Value: zoneID}, + paramInt{Value: 0}, + paramStruct{ + StructMembers: []structMember{ + structMemberString{ + Name: "name", + Value: name, + }, + }, + }, + }, + } + + resp := &responseStruct{} + + err := c.rpcCall(ctx, call, resp) + if err != nil { + return 0, err + } + + var newZoneID int + for _, member := range resp.StructMembers { + if member.Name == "id" { + newZoneID = member.ValueInt + } + } + + if newZoneID == 0 { + return 0, errors.New("could not determine cloned zone_id") + } + return newZoneID, nil +} + +func (c *Client) NewZoneVersion(ctx context.Context, zoneID int) (int, error) { + call := &methodCall{ + MethodName: "domain.zone.version.new", + Params: []param{ + paramString{Value: c.apiKey}, + paramInt{Value: zoneID}, + }, + } + + resp := &responseInt{} + + err := c.rpcCall(ctx, call, resp) + if err != nil { + return 0, err + } + + if resp.Value == 0 { + return 0, errors.New("could not create new zone version") + } + return resp.Value, nil +} + +func (c *Client) AddTXTRecord(ctx context.Context, zoneID, version int, name, value string, ttl int) error { + call := &methodCall{ + MethodName: "domain.zone.record.add", + Params: []param{ + paramString{Value: c.apiKey}, + paramInt{Value: zoneID}, + paramInt{Value: version}, + paramStruct{ + StructMembers: []structMember{ + structMemberString{ + Name: "type", + Value: "TXT", + }, structMemberString{ + Name: "name", + Value: name, + }, structMemberString{ + Name: "value", + Value: value, + }, structMemberInt{ + Name: "ttl", + Value: ttl, + }, + }, + }, + }, + } + + resp := &responseStruct{} + + return c.rpcCall(ctx, call, resp) +} + +func (c *Client) SetZoneVersion(ctx context.Context, zoneID, version int) error { + call := &methodCall{ + MethodName: "domain.zone.version.set", + Params: []param{ + paramString{Value: c.apiKey}, + paramInt{Value: zoneID}, + paramInt{Value: version}, + }, + } + + resp := &responseBool{} + + err := c.rpcCall(ctx, call, resp) + if err != nil { + return err + } + + if !resp.Value { + return errors.New("could not set zone version") + } + return nil +} + +func (c *Client) SetZone(ctx context.Context, domain string, zoneID int) error { + call := &methodCall{ + MethodName: "domain.zone.set", + Params: []param{ + paramString{Value: c.apiKey}, + paramString{Value: domain}, + paramInt{Value: zoneID}, + }, + } + + resp := &responseStruct{} + + err := c.rpcCall(ctx, call, resp) + if err != nil { + return err + } + + var respZoneID int + for _, member := range resp.StructMembers { + if member.Name == "zone_id" { + respZoneID = member.ValueInt + } + } + + if respZoneID != zoneID { + return fmt.Errorf("could not set new zone_id for %s", domain) + } + return nil +} + +func (c *Client) DeleteZone(ctx context.Context, zoneID int) error { + call := &methodCall{ + MethodName: "domain.zone.delete", + Params: []param{ + paramString{Value: c.apiKey}, + paramInt{Value: zoneID}, + }, + } + + resp := &responseBool{} + + err := c.rpcCall(ctx, call, resp) + if err != nil { + return err + } + + if !resp.Value { + return errors.New("could not delete zone_id") + } + + return nil +} + +// rpcCall makes an XML-RPC call to Gandi's RPC endpoint by marshaling the data given in the call argument to XML +// and sending that via HTTP Post to Gandi. +// The response is then unmarshalled into the resp argument. +func (c *Client) rpcCall(ctx context.Context, call *methodCall, result response) error { + req, err := newXMLRequest(ctx, c.BaseURL, call) + if err != nil { + return err + } + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = xml.Unmarshal(raw, result) + if err != nil { + return fmt.Errorf("unmarshal error: %w", err) + } + + if result.faultCode() != 0 { + return RPCError{ + FaultCode: result.faultCode(), + FaultString: result.faultString(), + } + } + + return nil +} + +func newXMLRequest(ctx context.Context, endpoint string, payload *methodCall) (*http.Request, error) { + body := new(bytes.Buffer) + body.WriteString(xml.Header) + + encoder := xml.NewEncoder(body) + encoder.Indent("", " ") + + err := encoder.Encode(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, body) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Content-Type", "text/xml") + + return req, nil +} diff --git a/providers/dns/gandi/internal/types.go b/providers/dns/gandi/internal/types.go new file mode 100644 index 00000000..cdcd0a65 --- /dev/null +++ b/providers/dns/gandi/internal/types.go @@ -0,0 +1,95 @@ +package internal + +import ( + "encoding/xml" + "fmt" +) + +// types for XML-RPC method calls and parameters + +type param interface { + param() +} + +type paramString struct { + XMLName xml.Name `xml:"param"` + Value string `xml:"value>string"` +} + +type paramInt struct { + XMLName xml.Name `xml:"param"` + Value int `xml:"value>int"` +} + +type structMember interface { + structMember() +} + +type structMemberString struct { + Name string `xml:"name"` + Value string `xml:"value>string"` +} + +type structMemberInt struct { + Name string `xml:"name"` + Value int `xml:"value>int"` +} + +type paramStruct struct { + XMLName xml.Name `xml:"param"` + StructMembers []structMember `xml:"value>struct>member"` +} + +func (p paramString) param() {} +func (p paramInt) param() {} +func (m structMemberString) structMember() {} +func (m structMemberInt) structMember() {} +func (p paramStruct) param() {} + +type methodCall struct { + XMLName xml.Name `xml:"methodCall"` + MethodName string `xml:"methodName"` + Params []param `xml:"params"` +} + +// types for XML-RPC responses + +type response interface { + faultCode() int + faultString() string +} + +type responseFault struct { + FaultCode int `xml:"fault>value>struct>member>value>int"` + FaultString string `xml:"fault>value>struct>member>value>string"` +} + +func (r responseFault) faultCode() int { return r.FaultCode } +func (r responseFault) faultString() string { return r.FaultString } + +type responseStruct struct { + responseFault + StructMembers []struct { + Name string `xml:"name"` + ValueInt int `xml:"value>int"` + } `xml:"params>param>value>struct>member"` +} + +type responseInt struct { + responseFault + Value int `xml:"params>param>value>int"` +} + +type responseBool struct { + responseFault + Value bool `xml:"params>param>value>boolean"` +} + +type RPCError struct { + FaultCode int + FaultString string +} + +func (e RPCError) Error() string { + return fmt.Sprintf("Gandi DNS: RPC Error: (%d) %s", e.FaultCode, e.FaultString) +} diff --git a/providers/dns/gandiv5/client.go b/providers/dns/gandiv5/client.go deleted file mode 100644 index 4ec3e1b5..00000000 --- a/providers/dns/gandiv5/client.go +++ /dev/null @@ -1,200 +0,0 @@ -package gandiv5 - -import ( - "bytes" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - - "github.com/go-acme/lego/v4/log" -) - -const apiKeyHeader = "X-Api-Key" - -// types for JSON responses with only a message. -type apiResponse struct { - Message string `json:"message"` - UUID string `json:"uuid,omitempty"` -} - -// Record TXT record representation. -type Record struct { - RRSetTTL int `json:"rrset_ttl"` - RRSetValues []string `json:"rrset_values"` - RRSetName string `json:"rrset_name,omitempty"` - RRSetType string `json:"rrset_type,omitempty"` -} - -func (d *DNSProvider) addTXTRecord(domain, name, value string, ttl int) error { - // Get exiting values for the TXT records - // Needed to create challenges for both wildcard and base name domains - txtRecord, err := d.getTXTRecord(domain, name) - if err != nil { - return err - } - - values := []string{value} - if len(txtRecord.RRSetValues) > 0 { - values = append(values, txtRecord.RRSetValues...) - } - - target := fmt.Sprintf("domains/%s/records/%s/TXT", domain, name) - - newRecord := &Record{RRSetTTL: ttl, RRSetValues: values} - req, err := d.newRequest(http.MethodPut, target, newRecord) - if err != nil { - return err - } - - message := apiResponse{} - err = d.do(req, &message) - if err != nil { - return fmt.Errorf("unable to create TXT record for domain %s and name %s: %w", domain, name, err) - } - - if len(message.Message) > 0 { - log.Infof("API response: %s", message.Message) - } - - return nil -} - -func (d *DNSProvider) getTXTRecord(domain, name string) (*Record, error) { - target := fmt.Sprintf("domains/%s/records/%s/TXT", domain, name) - - // Get exiting values for the TXT records - // Needed to create challenges for both wildcard and base name domains - req, err := d.newRequest(http.MethodGet, target, nil) - if err != nil { - return nil, err - } - - txtRecord := &Record{} - err = d.do(req, txtRecord) - if err != nil { - return nil, fmt.Errorf("unable to get TXT records for domain %s and name %s: %w", domain, name, err) - } - - return txtRecord, nil -} - -func (d *DNSProvider) deleteTXTRecord(domain, name string) error { - target := fmt.Sprintf("domains/%s/records/%s/TXT", domain, name) - - req, err := d.newRequest(http.MethodDelete, target, nil) - if err != nil { - return err - } - - message := apiResponse{} - err = d.do(req, &message) - if err != nil { - return fmt.Errorf("unable to delete TXT record for domain %s and name %s: %w", domain, name, err) - } - - if len(message.Message) > 0 { - log.Infof("API response: %s", message.Message) - } - - return nil -} - -func (d *DNSProvider) newRequest(method, resource string, body interface{}) (*http.Request, error) { - u := fmt.Sprintf("%s/%s", d.config.BaseURL, resource) - - if body == nil { - req, err := http.NewRequest(method, u, nil) - if err != nil { - return nil, err - } - - return req, nil - } - - reqBody, err := json.Marshal(body) - if err != nil { - return nil, err - } - - req, err := http.NewRequest(method, u, bytes.NewBuffer(reqBody)) - if err != nil { - return nil, err - } - - req.Header.Set("Content-Type", "application/json") - - return req, nil -} - -func (d *DNSProvider) do(req *http.Request, v interface{}) error { - if len(d.config.APIKey) > 0 { - req.Header.Set(apiKeyHeader, d.config.APIKey) - } - - resp, err := d.config.HTTPClient.Do(req) - if err != nil { - return err - } - - err = checkResponse(resp) - if err != nil { - return err - } - - if v == nil { - return nil - } - - raw, err := readBody(resp) - if err != nil { - return fmt.Errorf("failed to read body: %w", err) - } - - if len(raw) > 0 { - err = json.Unmarshal(raw, v) - if err != nil { - return fmt.Errorf("unmarshaling error: %w: %s", err, string(raw)) - } - } - - return nil -} - -func checkResponse(resp *http.Response) error { - if resp.StatusCode == http.StatusNotFound && resp.Request.Method == http.MethodGet { - return nil - } - - if resp.StatusCode >= http.StatusBadRequest { - data, err := readBody(resp) - if err != nil { - return fmt.Errorf("%d [%s] request failed: %w", resp.StatusCode, http.StatusText(resp.StatusCode), err) - } - - message := &apiResponse{} - err = json.Unmarshal(data, message) - if err != nil { - return fmt.Errorf("%d [%s] request failed: %w: %s", resp.StatusCode, http.StatusText(resp.StatusCode), err, data) - } - return fmt.Errorf("%d [%s] request failed: %s", resp.StatusCode, http.StatusText(resp.StatusCode), message.Message) - } - - return nil -} - -func readBody(resp *http.Response) ([]byte, error) { - if resp.Body == nil { - return nil, errors.New("response body is nil") - } - - defer resp.Body.Close() - - rawBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - return rawBody, nil -} diff --git a/providers/dns/gandiv5/gandiv5.go b/providers/dns/gandiv5/gandiv5.go index 7a877ee2..44859397 100644 --- a/providers/dns/gandiv5/gandiv5.go +++ b/providers/dns/gandiv5/gandiv5.go @@ -2,23 +2,22 @@ package gandiv5 import ( + "context" "errors" "fmt" "net/http" + "net/url" "sync" "time" "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/platform/config/env" + "github.com/go-acme/lego/v4/providers/dns/gandiv5/internal" ) // Gandi API reference: http://doc.livedns.gandi.net/ -const ( - // defaultBaseURL endpoint is the Gandi API endpoint used by Present and CleanUp. - defaultBaseURL = "https://dns.api.gandi.net/api/v5" - minTTL = 300 -) +const minTTL = 300 // Environment variables names. const ( @@ -62,10 +61,15 @@ func NewDefaultConfig() *Config { // DNSProvider implements the challenge.Provider interface. type DNSProvider struct { - config *Config + config *Config + client *internal.Client + inProgressFQDNs map[string]inProgressInfo inProgressMu sync.Mutex - // findZoneByFqdn determines the DNS zone of an fqdn. It is overridden during tests. + + // findZoneByFqdn determines the DNS zone of a FQDN. + // It is overridden during tests. + // only for testing purpose. findZoneByFqdn func(fqdn string) (string, error) } @@ -93,16 +97,27 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("gandiv5: no API Key given") } - if config.BaseURL == "" { - config.BaseURL = defaultBaseURL - } - if config.TTL < minTTL { return nil, fmt.Errorf("gandiv5: invalid TTL, TTL (%d) must be greater than %d", config.TTL, minTTL) } + client := internal.NewClient(config.APIKey) + + if config.BaseURL != "" { + baseURL, err := url.Parse(config.BaseURL) + if err != nil { + return nil, fmt.Errorf("gandiv5: %w", err) + } + client.BaseURL = baseURL + } + + if config.HTTPClient != nil { + client.HTTPClient = config.HTTPClient + } + return &DNSProvider{ config: config, + client: client, inProgressFQDNs: make(map[string]inProgressInfo), findZoneByFqdn: dns01.FindZoneByFqdn, }, nil @@ -115,7 +130,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { // find authZone authZone, err := d.findZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("gandiv5: findZoneByFqdn failure: %w", err) + return fmt.Errorf("gandiv5: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } // determine name of TXT record @@ -130,7 +145,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { defer d.inProgressMu.Unlock() // add TXT record into authZone - err = d.addTXTRecord(dns01.UnFqdn(authZone), subDomain, info.Value, d.config.TTL) + err = d.client.AddTXTRecord(context.Background(), dns01.UnFqdn(authZone), subDomain, info.Value, d.config.TTL) if err != nil { return err } @@ -160,7 +175,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { delete(d.inProgressFQDNs, info.EffectiveFQDN) // delete TXT record from authZone - err := d.deleteTXTRecord(dns01.UnFqdn(authZone), fieldName) + err := d.client.DeleteTXTRecord(context.Background(), dns01.UnFqdn(authZone), fieldName) if err != nil { return fmt.Errorf("gandiv5: %w", err) } diff --git a/providers/dns/gandiv5/gandiv5_test.go b/providers/dns/gandiv5/gandiv5_test.go index 31156e23..52e3b961 100644 --- a/providers/dns/gandiv5/gandiv5_test.go +++ b/providers/dns/gandiv5/gandiv5_test.go @@ -10,6 +10,7 @@ import ( "github.com/go-acme/lego/v4/log" "github.com/go-acme/lego/v4/platform/tester" + "github.com/go-acme/lego/v4/providers/dns/gandiv5/internal" "github.com/stretchr/testify/require" ) @@ -115,10 +116,13 @@ func TestDNSProvider(t *testing.T) { // start fake RPC server mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + mux.HandleFunc("/domains/example.com/records/_acme-challenge.abc.def/TXT", func(rw http.ResponseWriter, req *http.Request) { log.Infof("request: %s %s", req.Method, req.URL) - if req.Header.Get(apiKeyHeader) == "" { + if req.Header.Get(internal.APIKeyHeader) == "" { http.Error(rw, `{"message": "missing API key"}`, http.StatusUnauthorized) return } @@ -155,9 +159,6 @@ func TestDNSProvider(t *testing.T) { http.Error(rw, fmt.Sprintf(`{"message": "URL doesn't match: %s"}`, req.URL), http.StatusNotFound) }) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - // define function to override findZoneByFqdn with fakeFindZoneByFqdn := func(fqdn string) (string, error) { return "example.com.", nil diff --git a/providers/dns/gandiv5/internal/client.go b/providers/dns/gandiv5/internal/client.go new file mode 100644 index 00000000..bb280a3c --- /dev/null +++ b/providers/dns/gandiv5/internal/client.go @@ -0,0 +1,208 @@ +package internal + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "time" + + "github.com/go-acme/lego/v4/log" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" +) + +// defaultBaseURL endpoint is the Gandi API endpoint used by Present and CleanUp. +const defaultBaseURL = "https://dns.api.gandi.net/api/v5" + +// APIKeyHeader API key header. +const APIKeyHeader = "X-Api-Key" + +// Client the Gandi API v5 client. +type Client struct { + apiKey string + + BaseURL *url.URL + HTTPClient *http.Client +} + +// NewClient Creates a new Client. +func NewClient(apiKey string) *Client { + baseURL, _ := url.Parse(defaultBaseURL) + + return &Client{ + apiKey: apiKey, + BaseURL: baseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + } +} + +func (c *Client) AddTXTRecord(ctx context.Context, domain, name, value string, ttl int) error { + // Get exiting values for the TXT records + // Needed to create challenges for both wildcard and base name domains + txtRecord, err := c.getTXTRecord(ctx, domain, name) + if err != nil { + return err + } + + values := []string{value} + if len(txtRecord.RRSetValues) > 0 { + values = append(values, txtRecord.RRSetValues...) + } + + newRecord := &Record{RRSetTTL: ttl, RRSetValues: values} + + err = c.addTXTRecord(ctx, domain, name, newRecord) + if err != nil { + return err + } + + return nil +} + +func (c *Client) getTXTRecord(ctx context.Context, domain, name string) (*Record, error) { + endpoint := c.BaseURL.JoinPath("domains", domain, "records", name, "TXT") + + // Get exiting values for the TXT records + // Needed to create challenges for both wildcard and base name domains + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, err + } + + txtRecord := &Record{} + err = c.do(req, txtRecord) + if err != nil { + return nil, fmt.Errorf("unable to get TXT records for domain %s and name %s: %w", domain, name, err) + } + + return txtRecord, nil +} + +func (c *Client) addTXTRecord(ctx context.Context, domain, name string, newRecord *Record) error { + endpoint := c.BaseURL.JoinPath("domains", domain, "records", name, "TXT") + + req, err := newJSONRequest(ctx, http.MethodPut, endpoint, newRecord) + if err != nil { + return err + } + + message := apiResponse{} + err = c.do(req, &message) + if err != nil { + return fmt.Errorf("unable to create TXT record for domain %s and name %s: %w", domain, name, err) + } + + if message.Message != "" { + log.Infof("API response: %s", message.Message) + } + + return nil +} + +func (c *Client) DeleteTXTRecord(ctx context.Context, domain, name string) error { + endpoint := c.BaseURL.JoinPath("domains", domain, "records", name, "TXT") + + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) + if err != nil { + return err + } + + message := apiResponse{} + err = c.do(req, &message) + if err != nil { + return fmt.Errorf("unable to delete TXT record for domain %s and name %s: %w", domain, name, err) + } + + if message.Message != "" { + log.Infof("API response: %s", message.Message) + } + + return nil +} + +func (c *Client) do(req *http.Request, result any) error { + if c.apiKey != "" { + req.Header.Set(APIKeyHeader, c.apiKey) + } + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + err = checkResponse(req, resp) + if err != nil { + return err + } + + if result == nil { + return nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + if len(raw) > 0 { + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + } + + return nil +} + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil +} + +func checkResponse(req *http.Request, resp *http.Response) error { + if resp.StatusCode == http.StatusNotFound && resp.Request.Method == http.MethodGet { + return nil + } + + if resp.StatusCode < http.StatusBadRequest { + return nil + } + + return parseError(req, resp) +} + +func parseError(req *http.Request, resp *http.Response) error { + raw, _ := io.ReadAll(resp.Body) + + response := apiResponse{} + err := json.Unmarshal(raw, &response) + if err != nil { + return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw) + } + + return fmt.Errorf("%d: request failed: %s", resp.StatusCode, response.Message) +} diff --git a/providers/dns/gandiv5/internal/types.go b/providers/dns/gandiv5/internal/types.go new file mode 100644 index 00000000..2c0ba534 --- /dev/null +++ b/providers/dns/gandiv5/internal/types.go @@ -0,0 +1,15 @@ +package internal + +// types for JSON responses with only a message. +type apiResponse struct { + Message string `json:"message"` + UUID string `json:"uuid,omitempty"` +} + +// Record TXT record representation. +type Record struct { + RRSetTTL int `json:"rrset_ttl"` + RRSetValues []string `json:"rrset_values"` + RRSetName string `json:"rrset_name,omitempty"` + RRSetType string `json:"rrset_type,omitempty"` +} diff --git a/providers/dns/gcloud/googlecloud.go b/providers/dns/gcloud/googlecloud.go index 54169678..34a7d1e0 100644 --- a/providers/dns/gcloud/googlecloud.go +++ b/providers/dns/gcloud/googlecloud.go @@ -75,7 +75,7 @@ type DNSProvider struct { // or by specifying the keyfile location: GCE_SERVICE_ACCOUNT_FILE. func NewDNSProvider() (*DNSProvider, error) { // Use a service account file if specified via environment variable. - if saKey := env.GetOrFile(EnvServiceAccount); len(saKey) > 0 { + if saKey := env.GetOrFile(EnvServiceAccount); saKey != "" { return NewDNSProviderServiceAccountKey([]byte(saKey)) } @@ -312,7 +312,7 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { func (d *DNSProvider) getHostedZone(domain string) (string, error) { authZone, err := dns01.FindZoneByFqdn(dns01.ToFqdn(domain)) if err != nil { - return "", err + return "", fmt.Errorf("designate: could not find zone for FQDN %q: %w", domain, err) } zones, err := d.client.ManagedZones. diff --git a/providers/dns/gcloud/googlecloud_test.go b/providers/dns/gcloud/googlecloud_test.go index 87ba9dbd..02071b1c 100644 --- a/providers/dns/gcloud/googlecloud_test.go +++ b/providers/dns/gcloud/googlecloud_test.go @@ -144,6 +144,8 @@ func TestNewDNSProviderConfig(t *testing.T) { func TestPresentNoExistingRR(t *testing.T) { mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) // getHostedZone: /manhattan/managedZones?alt=json&dnsName=lego.wtf. mux.HandleFunc("/dns/v1/projects/manhattan/managedZones", func(w http.ResponseWriter, r *http.Request) { @@ -205,11 +207,8 @@ func TestPresentNoExistingRR(t *testing.T) { } }) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - config := NewDefaultConfig() - config.HTTPClient = &http.Client{} + config.HTTPClient = &http.Client{Timeout: 10 * time.Second} config.Project = "manhattan" p, err := NewDNSProviderConfig(config) @@ -225,6 +224,8 @@ func TestPresentNoExistingRR(t *testing.T) { func TestPresentWithExistingRR(t *testing.T) { mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) // getHostedZone: /manhattan/managedZones?alt=json&dnsName=lego.wtf. mux.HandleFunc("/dns/v1/projects/manhattan/managedZones", func(w http.ResponseWriter, r *http.Request) { @@ -306,11 +307,8 @@ func TestPresentWithExistingRR(t *testing.T) { } }) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - config := NewDefaultConfig() - config.HTTPClient = &http.Client{} + config.HTTPClient = &http.Client{Timeout: 10 * time.Second} config.Project = "manhattan" p, err := NewDNSProviderConfig(config) @@ -326,6 +324,8 @@ func TestPresentWithExistingRR(t *testing.T) { func TestPresentSkipExistingRR(t *testing.T) { mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) // getHostedZone: /manhattan/managedZones?alt=json&dnsName=lego.wtf. mux.HandleFunc("/dns/v1/projects/manhattan/managedZones", func(w http.ResponseWriter, r *http.Request) { @@ -370,11 +370,8 @@ func TestPresentSkipExistingRR(t *testing.T) { } }) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - config := NewDefaultConfig() - config.HTTPClient = &http.Client{} + config.HTTPClient = &http.Client{Timeout: 10 * time.Second} config.Project = "manhattan" p, err := NewDNSProviderConfig(config) diff --git a/providers/dns/gcore/internal/client.go b/providers/dns/gcore/internal/client.go index 18160681..65841487 100644 --- a/providers/dns/gcore/internal/client.go +++ b/providers/dns/gcore/internal/client.go @@ -1,6 +1,7 @@ package internal import ( + "bytes" "context" "encoding/json" "errors" @@ -8,21 +9,26 @@ import ( "io" "net/http" "net/url" - "strings" "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) +const defaultBaseURL = "https://api.gcorelabs.com/dns" + const ( - defaultBaseURL = "https://api.gcorelabs.com/dns" - tokenHeader = "APIKey" - txtRecordType = "TXT" + authorizationHeader = "Authorization" + tokenTypeHeader = "APIKey" ) +const txtRecordType = "TXT" + // Client for DNS API. type Client struct { - HTTPClient *http.Client + token string + baseURL *url.URL - token string + HTTPClient *http.Client } // NewClient constructor of Client. @@ -42,7 +48,7 @@ func (c *Client) GetZone(ctx context.Context, name string) (Zone, error) { endpoint := c.baseURL.JoinPath("v2", "zones", name) zone := Zone{} - err := c.do(ctx, http.MethodGet, endpoint, nil, &zone) + err := c.doRequest(ctx, http.MethodGet, endpoint, nil, &zone) if err != nil { return Zone{}, fmt.Errorf("get zone %s: %w", name, err) } @@ -56,7 +62,7 @@ func (c *Client) GetRRSet(ctx context.Context, zone, name string) (RRSet, error) endpoint := c.baseURL.JoinPath("v2", "zones", zone, name, txtRecordType) var result RRSet - err := c.do(ctx, http.MethodGet, endpoint, nil, &result) + err := c.doRequest(ctx, http.MethodGet, endpoint, nil, &result) if err != nil { return RRSet{}, fmt.Errorf("get txt records %s -> %s: %w", zone, name, err) } @@ -69,7 +75,7 @@ func (c *Client) GetRRSet(ctx context.Context, zone, name string) (RRSet, error) func (c *Client) DeleteRRSet(ctx context.Context, zone, name string) error { endpoint := c.baseURL.JoinPath("v2", "zones", zone, name, txtRecordType) - err := c.do(ctx, http.MethodDelete, endpoint, nil, nil) + err := c.doRequest(ctx, http.MethodDelete, endpoint, nil, nil) if err != nil { // Support DELETE idempotence https://developer.mozilla.org/en-US/docs/Glossary/Idempotent statusErr := new(APIError) @@ -100,59 +106,84 @@ func (c *Client) AddRRSet(ctx context.Context, zone, recordName, value string, t func (c *Client) createRRSet(ctx context.Context, zone, name string, record RRSet) error { endpoint := c.baseURL.JoinPath("v2", "zones", zone, name, txtRecordType) - return c.do(ctx, http.MethodPost, endpoint, record, nil) + return c.doRequest(ctx, http.MethodPost, endpoint, record, nil) } // https://dnsapi.gcorelabs.com/docs#operation/UpdateRRSet func (c *Client) updateRRSet(ctx context.Context, zone, name string, record RRSet) error { endpoint := c.baseURL.JoinPath("v2", "zones", zone, name, txtRecordType) - return c.do(ctx, http.MethodPut, endpoint, record, nil) + return c.doRequest(ctx, http.MethodPut, endpoint, record, nil) } -func (c *Client) do(ctx context.Context, method string, endpoint *url.URL, bodyParams interface{}, dest interface{}) error { - var bs []byte - if bodyParams != nil { - var err error - bs, err = json.Marshal(bodyParams) - if err != nil { - return fmt.Errorf("encode bodyParams: %w", err) - } - } - - req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), strings.NewReader(string(bs))) +func (c *Client) doRequest(ctx context.Context, method string, endpoint *url.URL, bodyParams any, result any) error { + req, err := newJSONRequest(ctx, method, endpoint, bodyParams) if err != nil { return fmt.Errorf("new request: %w", err) } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("%s %s", tokenHeader, c.token)) + req.Header.Set(authorizationHeader, fmt.Sprintf("%s %s", tokenTypeHeader, c.token)) resp, err := c.HTTPClient.Do(req) if err != nil { - return fmt.Errorf("send request: %w", err) + return errutils.NewHTTPDoError(req, err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode/100 != 2 { - all, _ := io.ReadAll(resp.Body) - - e := APIError{ - StatusCode: resp.StatusCode, - } - - err := json.Unmarshal(all, &e) - if err != nil { - e.Message = string(all) - } - - return e + return parseError(resp) } - if dest == nil { + if result == nil { return nil } - return json.NewDecoder(resp.Body).Decode(dest) + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return nil +} + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil +} + +func parseError(resp *http.Response) error { + raw, _ := io.ReadAll(resp.Body) + + errAPI := APIError{StatusCode: resp.StatusCode} + err := json.Unmarshal(raw, &errAPI) + if err != nil { + errAPI.Message = string(raw) + } + + return errAPI } diff --git a/providers/dns/gcore/internal/client_test.go b/providers/dns/gcore/internal/client_test.go index 86872260..f414b33e 100644 --- a/providers/dns/gcore/internal/client_test.go +++ b/providers/dns/gcore/internal/client_test.go @@ -21,22 +21,21 @@ const ( testTTL = 10 ) -func setupTest(t *testing.T) (*http.ServeMux, *Client) { +func setupTest(t *testing.T) (*Client, *http.ServeMux) { t.Helper() mux := http.NewServeMux() - server := httptest.NewServer(mux) t.Cleanup(server.Close) client := NewClient(testToken) client.baseURL, _ = url.Parse(server.URL) - return mux, client + return client, mux } func TestClient_GetZone(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) expected := Zone{Name: "example.com"} @@ -52,7 +51,7 @@ func TestClient_GetZone(t *testing.T) { } func TestClient_GetZone_error(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.Handle("/v2/zones/example.com", validationHandler{ method: http.MethodGet, @@ -64,7 +63,7 @@ func TestClient_GetZone_error(t *testing.T) { } func TestClient_GetRRSet(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) expected := RRSet{ TTL: testTTL, @@ -85,7 +84,7 @@ func TestClient_GetRRSet(t *testing.T) { } func TestClient_GetRRSet_error(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.Handle("/v2/zones/example.com/foo.example.com/TXT", validationHandler{ method: http.MethodGet, @@ -97,7 +96,7 @@ func TestClient_GetRRSet_error(t *testing.T) { } func TestClient_DeleteRRSet(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.Handle("/v2/zones/test.example.com/my.test.example.com/"+txtRecordType, validationHandler{method: http.MethodDelete}) @@ -107,7 +106,7 @@ func TestClient_DeleteRRSet(t *testing.T) { } func TestClient_DeleteRRSet_error(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.Handle("/v2/zones/test.example.com/my.test.example.com/"+txtRecordType, validationHandler{ method: http.MethodDelete, @@ -178,7 +177,7 @@ func TestClient_AddRRSet(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - mux, cl := setupTest(t) + cl, mux := setupTest(t) for pattern, handler := range test.handlers { mux.Handle(pattern, handler) @@ -201,7 +200,7 @@ type validationHandler struct { } func (v validationHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - if req.Header.Get("Authorization") != fmt.Sprintf("%s %s", tokenHeader, testToken) { + if req.Header.Get(authorizationHeader) != fmt.Sprintf("%s %s", tokenTypeHeader, testToken) { rw.WriteHeader(http.StatusForbidden) _ = json.NewEncoder(rw).Encode(APIError{Message: "token up for parsing was not passed through the context"}) return diff --git a/providers/dns/glesys/client.go b/providers/dns/glesys/client.go deleted file mode 100644 index 32207280..00000000 --- a/providers/dns/glesys/client.go +++ /dev/null @@ -1,91 +0,0 @@ -package glesys - -import ( - "bytes" - "encoding/json" - "fmt" - "net/http" - - "github.com/go-acme/lego/v4/log" -) - -// types for JSON method calls, parameters, and responses - -type addRecordRequest struct { - DomainName string `json:"domainname"` - Host string `json:"host"` - Type string `json:"type"` - Data string `json:"data"` - TTL int `json:"ttl,omitempty"` -} - -type deleteRecordRequest struct { - RecordID int `json:"recordid"` -} - -type responseStruct struct { - Response struct { - Status struct { - Code int `json:"code"` - } `json:"status"` - Record deleteRecordRequest `json:"record"` - } `json:"response"` -} - -func (d *DNSProvider) addTXTRecord(fqdn, domain, name, value string, ttl int) (int, error) { - response, err := d.sendRequest(http.MethodPost, "addrecord", addRecordRequest{ - DomainName: domain, - Host: name, - Type: "TXT", - Data: value, - TTL: ttl, - }) - - if response != nil && response.Response.Status.Code == http.StatusOK { - log.Infof("[%s]: Successfully created record id %d", fqdn, response.Response.Record.RecordID) - return response.Response.Record.RecordID, nil - } - return 0, err -} - -func (d *DNSProvider) deleteTXTRecord(fqdn string, recordid int) error { - response, err := d.sendRequest(http.MethodPost, "deleterecord", deleteRecordRequest{ - RecordID: recordid, - }) - if response != nil && response.Response.Status.Code == 200 { - log.Infof("[%s]: Successfully deleted record id %d", fqdn, recordid) - } - return err -} - -func (d *DNSProvider) sendRequest(method, resource string, payload interface{}) (*responseStruct, error) { - url := fmt.Sprintf("%s/%s", defaultBaseURL, resource) - - body, err := json.Marshal(payload) - if err != nil { - return nil, err - } - - req, err := http.NewRequest(method, url, bytes.NewReader(body)) - if err != nil { - return nil, err - } - - req.Header.Set("Content-Type", "application/json") - req.SetBasicAuth(d.config.APIUser, d.config.APIKey) - - resp, err := d.config.HTTPClient.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - if resp.StatusCode >= http.StatusBadRequest { - return nil, fmt.Errorf("request failed with HTTP status code %d", resp.StatusCode) - } - - var response responseStruct - err = json.NewDecoder(resp.Body).Decode(&response) - - return &response, err -} diff --git a/providers/dns/glesys/glesys.go b/providers/dns/glesys/glesys.go index 2b5379b3..acdf1b44 100644 --- a/providers/dns/glesys/glesys.go +++ b/providers/dns/glesys/glesys.go @@ -2,6 +2,7 @@ package glesys import ( + "context" "errors" "fmt" "net/http" @@ -10,6 +11,7 @@ import ( "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/platform/config/env" + "github.com/go-acme/lego/v4/providers/dns/glesys/internal" ) const ( @@ -55,7 +57,9 @@ func NewDefaultConfig() *Config { // DNSProvider implements the challenge.Provider interface. type DNSProvider struct { - config *Config + config *Config + client *internal.Client + activeRecords map[string]int inProgressMu sync.Mutex } @@ -90,8 +94,15 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, fmt.Errorf("glesys: invalid TTL, TTL (%d) must be greater than %d", config.TTL, minTTL) } + client := internal.NewClient(config.APIUser, config.APIKey) + + if config.HTTPClient != nil { + client.HTTPClient = config.HTTPClient + } + return &DNSProvider{ config: config, + client: client, activeRecords: make(map[string]int), }, nil } @@ -103,7 +114,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { // find authZone authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("glesys: findZoneByFqdn failure: %w", err) + return fmt.Errorf("glesys: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) @@ -111,14 +122,12 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { return fmt.Errorf("glesys: %w", err) } - // acquire lock and check there is not a challenge already in - // progress for this value of authZone + // acquire lock and check there is not a challenge already in progress for this value of authZone d.inProgressMu.Lock() defer d.inProgressMu.Unlock() // add TXT record into authZone - // TODO(ldez) replace domain by FQDN to follow CNAME. - recordID, err := d.addTXTRecord(domain, dns01.UnFqdn(authZone), subDomain, info.Value, d.config.TTL) + recordID, err := d.client.AddTXTRecord(context.Background(), dns01.UnFqdn(authZone), subDomain, info.Value, d.config.TTL) if err != nil { return err } @@ -144,8 +153,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { delete(d.activeRecords, info.EffectiveFQDN) // delete TXT record from authZone - // TODO(ldez) replace domain by FQDN to follow CNAME. - return d.deleteTXTRecord(domain, recordID) + return d.client.DeleteTXTRecord(context.Background(), recordID) } // Timeout returns the values (20*time.Minute, 20*time.Second) which diff --git a/providers/dns/glesys/internal/client.go b/providers/dns/glesys/internal/client.go new file mode 100644 index 00000000..038c6f0d --- /dev/null +++ b/providers/dns/glesys/internal/client.go @@ -0,0 +1,135 @@ +package internal + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" +) + +// defaultBaseURL is the GleSYS API endpoint used by Present and CleanUp. +const defaultBaseURL = "https://api.glesys.com/" + +type Client struct { + apiUser string + apiKey string + + baseURL *url.URL + HTTPClient *http.Client +} + +func NewClient(apiUser string, apiKey string) *Client { + baseURL, _ := url.Parse(defaultBaseURL) + + return &Client{ + apiUser: apiUser, + apiKey: apiKey, + baseURL: baseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + } +} + +// AddTXTRecord adds a dns record to a domain. +// https://github.com/GleSYS/API/wiki/API-Documentation#domainaddrecord +func (c *Client) AddTXTRecord(ctx context.Context, domain, name, value string, ttl int) (int, error) { + endpoint := c.baseURL.JoinPath("domain", "addrecord") + + request := addRecordRequest{ + DomainName: domain, + Host: name, + Type: "TXT", + Data: value, + TTL: ttl, + } + + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, request) + if err != nil { + return 0, err + } + + response, err := c.do(req) + if err != nil { + return 0, err + } + + if response != nil && response.Response.Status.Code == http.StatusOK { + return response.Response.Record.RecordID, nil + } + + return 0, err +} + +// DeleteTXTRecord removes a dns record from a domain. +// https://github.com/GleSYS/API/wiki/API-Documentation#domaindeleterecord +func (c *Client) DeleteTXTRecord(ctx context.Context, recordID int) error { + endpoint := c.baseURL.JoinPath("domain", "deleterecord") + + request := deleteRecordRequest{RecordID: recordID} + + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, request) + if err != nil { + return err + } + + _, err = c.do(req) + + return err +} + +func (c *Client) do(req *http.Request) (*apiResponse, error) { + req.SetBasicAuth(c.apiUser, c.apiKey) + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode/100 != 2 { + return nil, errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + var response apiResponse + err = json.Unmarshal(raw, &response) + if err != nil { + return nil, errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return &response, nil +} + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil +} diff --git a/providers/dns/glesys/internal/client_test.go b/providers/dns/glesys/internal/client_test.go new file mode 100644 index 00000000..7e8ca972 --- /dev/null +++ b/providers/dns/glesys/internal/client_test.go @@ -0,0 +1,79 @@ +package internal + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setupTest(t *testing.T, method, pattern string, status int, file string) *Client { + t.Helper() + + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + mux.HandleFunc(pattern, func(rw http.ResponseWriter, req *http.Request) { + if req.Method != method { + http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusBadRequest) + return + } + + apiUser, apiKey, ok := req.BasicAuth() + if apiUser != "user" || apiKey != "secret" || !ok { + http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + + if file == "" { + rw.WriteHeader(status) + return + } + + open, err := os.Open(filepath.Join("fixtures", file)) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + + defer func() { _ = open.Close() }() + + rw.WriteHeader(status) + _, err = io.Copy(rw, open) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + }) + + client := NewClient("user", "secret") + client.HTTPClient = server.Client() + client.baseURL, _ = url.Parse(server.URL) + + return client +} + +func TestClient_AddTXTRecord(t *testing.T) { + client := setupTest(t, http.MethodPost, "/domain/addrecord", http.StatusOK, "add-record.json") + + recordID, err := client.AddTXTRecord(context.Background(), "example.com", "foo", "txt", 120) + require.NoError(t, err) + + assert.Equal(t, 123, recordID) +} + +func TestClient_DeleteTXTRecord(t *testing.T) { + client := setupTest(t, http.MethodPost, "/domain/deleterecord", http.StatusOK, "delete-record.json") + + err := client.DeleteTXTRecord(context.Background(), 123) + require.NoError(t, err) +} diff --git a/providers/dns/glesys/internal/fixtures/add-record.json b/providers/dns/glesys/internal/fixtures/add-record.json new file mode 100644 index 00000000..c7d1fc82 --- /dev/null +++ b/providers/dns/glesys/internal/fixtures/add-record.json @@ -0,0 +1,10 @@ +{ + "response": { + "status": { + "code": 200 + }, + "record": { + "recordid": 123 + } + } +} diff --git a/providers/dns/glesys/internal/fixtures/delete-record.json b/providers/dns/glesys/internal/fixtures/delete-record.json new file mode 100644 index 00000000..c7d1fc82 --- /dev/null +++ b/providers/dns/glesys/internal/fixtures/delete-record.json @@ -0,0 +1,10 @@ +{ + "response": { + "status": { + "code": 200 + }, + "record": { + "recordid": 123 + } + } +} diff --git a/providers/dns/glesys/internal/types.go b/providers/dns/glesys/internal/types.go new file mode 100644 index 00000000..61949d1f --- /dev/null +++ b/providers/dns/glesys/internal/types.go @@ -0,0 +1,30 @@ +package internal + +type addRecordRequest struct { + DomainName string `json:"domainname"` + Host string `json:"host"` + Type string `json:"type"` + Data string `json:"data"` + TTL int `json:"ttl,omitempty"` +} + +type deleteRecordRequest struct { + RecordID int `json:"recordid"` +} + +type apiResponse struct { + Response Response `json:"response"` +} + +type Response struct { + Status Status `json:"status"` + Record Record `json:"record"` +} + +type Status struct { + Code int `json:"code"` +} + +type Record struct { + RecordID int `json:"recordid"` +} diff --git a/providers/dns/godaddy/godaddy.go b/providers/dns/godaddy/godaddy.go index 961b2fd9..d5adbeb3 100644 --- a/providers/dns/godaddy/godaddy.go +++ b/providers/dns/godaddy/godaddy.go @@ -2,6 +2,7 @@ package godaddy import ( + "context" "errors" "fmt" "net/http" @@ -104,17 +105,21 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - domainZone, err := getZone(info.EffectiveFQDN) + authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("godaddy: failed to get zone: %w", err) + return fmt.Errorf("godaddy: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } - subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, domainZone) + authZone = dns01.UnFqdn(authZone) + + subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) if err != nil { return fmt.Errorf("godaddy: %w", err) } - records, err := d.client.GetRecords(domainZone, "TXT", subDomain) + ctx := context.Background() + + records, err := d.client.GetRecords(ctx, authZone, "TXT", subDomain) if err != nil { return fmt.Errorf("godaddy: failed to get TXT records: %w", err) } @@ -134,7 +139,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { } newRecords = append(newRecords, record) - err = d.client.UpdateTxtRecords(newRecords, domainZone, subDomain) + err = d.client.UpdateTxtRecords(ctx, newRecords, authZone, subDomain) if err != nil { return fmt.Errorf("godaddy: failed to add TXT record: %w", err) } @@ -146,17 +151,21 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - domainZone, err := getZone(info.EffectiveFQDN) + authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("godaddy: failed to get zone: %w", err) + return fmt.Errorf("godaddy: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } - subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, domainZone) + authZone = dns01.UnFqdn(authZone) + + subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) if err != nil { return fmt.Errorf("godaddy: %w", err) } - records, err := d.client.GetRecords(domainZone, "TXT", subDomain) + ctx := context.Background() + + records, err := d.client.GetRecords(ctx, authZone, "TXT", subDomain) if err != nil { return fmt.Errorf("godaddy: failed to get TXT records: %w", err) } @@ -165,7 +174,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return nil } - allTxtRecords, err := d.client.GetRecords(domainZone, "TXT", "") + allTxtRecords, err := d.client.GetRecords(ctx, authZone, "TXT", "") if err != nil { return fmt.Errorf("godaddy: failed to get all TXT records: %w", err) } @@ -183,19 +192,10 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { recordsKeep = append(recordsKeep, emptyRecord) } - err = d.client.UpdateTxtRecords(recordsKeep, domainZone, "") + err = d.client.UpdateTxtRecords(ctx, recordsKeep, authZone, "") if err != nil { return fmt.Errorf("godaddy: failed to remove TXT record: %w", err) } return nil } - -func getZone(fqdn string) (string, error) { - authZone, err := dns01.FindZoneByFqdn(fqdn) - if err != nil { - return "", err - } - - return dns01.UnFqdn(authZone), nil -} diff --git a/providers/dns/godaddy/internal/client.go b/providers/dns/godaddy/internal/client.go index 90f1ef01..64f9f0bf 100644 --- a/providers/dns/godaddy/internal/client.go +++ b/providers/dns/godaddy/internal/client.go @@ -2,54 +2,51 @@ package internal import ( "bytes" + "context" "encoding/json" "fmt" "io" "net/http" "net/url" - "path" "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) // DefaultBaseURL represents the API endpoint to call. const DefaultBaseURL = "https://api.godaddy.com" +const authorizationHeader = "Authorization" + type Client struct { - HTTPClient *http.Client + apiKey string + apiSecret string + baseURL *url.URL - apiKey string - apiSecret string + HTTPClient *http.Client } func NewClient(apiKey string, apiSecret string) *Client { baseURL, _ := url.Parse(DefaultBaseURL) return &Client{ - HTTPClient: &http.Client{Timeout: 5 * time.Second}, - baseURL: baseURL, apiKey: apiKey, apiSecret: apiSecret, + baseURL: baseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, } } -func (d *Client) GetRecords(domainZone, rType, recordName string) ([]DNSRecord, error) { - resource := path.Clean(fmt.Sprintf("/v1/domains/%s/records/%s/%s", domainZone, rType, recordName)) +func (c *Client) GetRecords(ctx context.Context, domainZone, rType, recordName string) ([]DNSRecord, error) { + endpoint := c.baseURL.JoinPath("v1", "domains", domainZone, "records", rType, recordName) - resp, err := d.makeRequest(http.MethodGet, resource, nil) + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, err } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("could not get records: Domain: %s; Record: %s, Status: %v; Body: %s", - domainZone, recordName, resp.StatusCode, string(bodyBytes)) - } - var records []DNSRecord - err = json.NewDecoder(resp.Body).Decode(&records) + err = c.do(req, &records) if err != nil { return nil, err } @@ -57,41 +54,68 @@ func (d *Client) GetRecords(domainZone, rType, recordName string) ([]DNSRecord, return records, nil } -func (d *Client) UpdateTxtRecords(records []DNSRecord, domainZone, recordName string) error { - body, err := json.Marshal(records) +func (c *Client) UpdateTxtRecords(ctx context.Context, records []DNSRecord, domainZone, recordName string) error { + endpoint := c.baseURL.JoinPath("v1", "domains", domainZone, "records", "TXT", recordName) + + req, err := newJSONRequest(ctx, http.MethodPut, endpoint, records) if err != nil { return err } - resource := path.Clean(fmt.Sprintf("/v1/domains/%s/records/TXT/%s", domainZone, recordName)) + return c.do(req, nil) +} - var resp *http.Response - resp, err = d.makeRequest(http.MethodPut, resource, bytes.NewReader(body)) +func (c *Client) do(req *http.Request, result any) error { + req.Header.Set(authorizationHeader, fmt.Sprintf("sso-key %s:%s", c.apiKey, c.apiSecret)) + + resp, err := c.HTTPClient.Do(req) if err != nil { - return err + return errutils.NewHTTPDoError(req, err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - return fmt.Errorf("could not create record %v; Status: %v; Body: %s", string(body), resp.StatusCode, string(bodyBytes)) + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + if result == nil { + return nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) } return nil } -func (d *Client) makeRequest(method, uri string, body io.Reader) (*http.Response, error) { - endpoint := d.baseURL.JoinPath(uri) +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) - req, err := http.NewRequest(method, endpoint.String(), body) + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) if err != nil { - return nil, err + return nil, fmt.Errorf("unable to create request: %w", err) } req.Header.Set("Accept", "application/json") - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("sso-key %s:%s", d.apiKey, d.apiSecret)) - return d.HTTPClient.Do(req) + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil } diff --git a/providers/dns/godaddy/internal/client_test.go b/providers/dns/godaddy/internal/client_test.go index 5e297453..ccbab16d 100644 --- a/providers/dns/godaddy/internal/client_test.go +++ b/providers/dns/godaddy/internal/client_test.go @@ -1,6 +1,7 @@ package internal import ( + "context" "fmt" "io" "net/http" @@ -14,7 +15,7 @@ import ( "github.com/stretchr/testify/require" ) -func setupTest(t *testing.T) (*http.ServeMux, *Client) { +func setupTest(t *testing.T) (*Client, *http.ServeMux) { t.Helper() mux := http.NewServeMux() @@ -25,15 +26,15 @@ func setupTest(t *testing.T) (*http.ServeMux, *Client) { client.HTTPClient = server.Client() client.baseURL, _ = url.Parse(server.URL) - return mux, client + return client, mux } func TestClient_GetRecords(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/v1/domains/example.com/records/TXT/", testHandler(http.MethodGet, http.StatusOK, "getrecords.json")) - records, err := client.GetRecords("example.com", "TXT", "") + records, err := client.GetRecords(context.Background(), "example.com", "TXT", "") require.NoError(t, err) expected := []DNSRecord{ @@ -49,17 +50,17 @@ func TestClient_GetRecords(t *testing.T) { } func TestClient_GetRecords_errors(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/v1/domains/example.com/records/TXT/", testHandler(http.MethodGet, http.StatusUnprocessableEntity, "errors.json")) - records, err := client.GetRecords("example.com", "TXT", "") + records, err := client.GetRecords(context.Background(), "example.com", "TXT", "") require.Error(t, err) assert.Nil(t, records) } func TestClient_UpdateTxtRecords(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/v1/domains/example.com/records/TXT/lego", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodPut { @@ -67,7 +68,7 @@ func TestClient_UpdateTxtRecords(t *testing.T) { return } - auth := req.Header.Get("Authorization") + auth := req.Header.Get(authorizationHeader) if auth != "sso-key key:secret" { http.Error(rw, fmt.Sprintf("invalid API key or secret: %s", auth), http.StatusUnauthorized) return @@ -83,12 +84,12 @@ func TestClient_UpdateTxtRecords(t *testing.T) { {Name: "_acme-challenge.lego", Type: "TXT", Data: "acme", TTL: 600}, } - err := client.UpdateTxtRecords(records, "example.com", "lego") + err := client.UpdateTxtRecords(context.Background(), records, "example.com", "lego") require.NoError(t, err) } func TestClient_UpdateTxtRecords_errors(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/v1/domains/example.com/records/TXT/lego", testHandler(http.MethodPut, http.StatusUnprocessableEntity, "errors.json")) @@ -102,7 +103,7 @@ func TestClient_UpdateTxtRecords_errors(t *testing.T) { {Name: "_acme-challenge.lego", Type: "TXT", Data: "acme", TTL: 600}, } - err := client.UpdateTxtRecords(records, "example.com", "lego") + err := client.UpdateTxtRecords(context.Background(), records, "example.com", "lego") require.Error(t, err) } @@ -113,7 +114,7 @@ func testHandler(method string, statusCode int, filename string) http.HandlerFun return } - auth := req.Header.Get("Authorization") + auth := req.Header.Get(authorizationHeader) if auth != "sso-key key:secret" { http.Error(rw, fmt.Sprintf("invalid API key or secret: %s", auth), http.StatusUnauthorized) return diff --git a/providers/dns/hetzner/hetzner.go b/providers/dns/hetzner/hetzner.go index 405d02d3..58916b4a 100644 --- a/providers/dns/hetzner/hetzner.go +++ b/providers/dns/hetzner/hetzner.go @@ -2,6 +2,7 @@ package hetzner import ( + "context" "errors" "fmt" "net/http" @@ -100,12 +101,16 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - zone, err := getZone(info.EffectiveFQDN) + authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("hetzner: failed to find zone: fqdn=%s: %w", info.EffectiveFQDN, err) + return fmt.Errorf("hetzner: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } - zoneID, err := d.client.GetZoneID(zone) + zone := dns01.UnFqdn(authZone) + + ctx := context.Background() + + zoneID, err := d.client.GetZoneID(ctx, zone) if err != nil { return fmt.Errorf("hetzner: %w", err) } @@ -123,7 +128,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { ZoneID: zoneID, } - if err := d.client.CreateRecord(record); err != nil { + if err := d.client.CreateRecord(ctx, record); err != nil { return fmt.Errorf("hetzner: failed to add TXT record: fqdn=%s, zoneID=%s: %w", info.EffectiveFQDN, zoneID, err) } @@ -134,12 +139,16 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - zone, err := getZone(info.EffectiveFQDN) + authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("hetzner: failed to find zone: fqdn=%s: %w", info.EffectiveFQDN, err) + return fmt.Errorf("hetzner: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } - zoneID, err := d.client.GetZoneID(zone) + zone := dns01.UnFqdn(authZone) + + ctx := context.Background() + + zoneID, err := d.client.GetZoneID(ctx, zone) if err != nil { return fmt.Errorf("hetzner: %w", err) } @@ -149,23 +158,14 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return fmt.Errorf("hetzner: %w", err) } - record, err := d.client.GetTxtRecord(subDomain, info.Value, zoneID) + record, err := d.client.GetTxtRecord(ctx, subDomain, info.Value, zoneID) if err != nil { return fmt.Errorf("hetzner: %w", err) } - if err := d.client.DeleteRecord(record.ID); err != nil { + if err := d.client.DeleteRecord(ctx, record.ID); err != nil { return fmt.Errorf("hetzner: failed to delate TXT record: id=%s, name=%s: %w", record.ID, record.Name, err) } return nil } - -func getZone(fqdn string) (string, error) { - authZone, err := dns01.FindZoneByFqdn(fqdn) - if err != nil { - return "", err - } - - return dns01.UnFqdn(authZone), nil -} diff --git a/providers/dns/hetzner/internal/client.go b/providers/dns/hetzner/internal/client.go index 326ecdb0..38192226 100644 --- a/providers/dns/hetzner/internal/client.go +++ b/providers/dns/hetzner/internal/client.go @@ -2,11 +2,15 @@ package internal import ( "bytes" + "context" "encoding/json" "fmt" "io" "net/http" "net/url" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) // defaultBaseURL represents the API endpoint to call. @@ -16,24 +20,26 @@ const authHeader = "Auth-API-Token" // Client the Hetzner client. type Client struct { - HTTPClient *http.Client - BaseURL string - apiKey string + + baseURL *url.URL + HTTPClient *http.Client } // NewClient Creates a new Hetzner client. func NewClient(apiKey string) *Client { + baseURL, _ := url.Parse(defaultBaseURL) + return &Client{ - HTTPClient: http.DefaultClient, - BaseURL: defaultBaseURL, apiKey: apiKey, + baseURL: baseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, } } // GetTxtRecord gets a TXT record. -func (c *Client) GetTxtRecord(name, value, zoneID string) (*DNSRecord, error) { - records, err := c.getRecords(zoneID) +func (c *Client) GetTxtRecord(ctx context.Context, name, value, zoneID string) (*DNSRecord, error) { + records, err := c.getRecords(ctx, zoneID) if err != nil { return nil, err } @@ -48,33 +54,38 @@ func (c *Client) GetTxtRecord(name, value, zoneID string) (*DNSRecord, error) { } // https://dns.hetzner.com/api-docs#operation/GetRecords -func (c *Client) getRecords(zoneID string) (*DNSRecords, error) { - endpoint, err := c.createEndpoint("api", "v1", "records") - if err != nil { - return nil, fmt.Errorf("failed to create endpoint: %w", err) - } +func (c *Client) getRecords(ctx context.Context, zoneID string) (*DNSRecords, error) { + endpoint := c.baseURL.JoinPath("api", "v1", "records") query := endpoint.Query() query.Set("zone_id", zoneID) endpoint.RawQuery = query.Encode() - resp, err := c.do(http.MethodGet, endpoint, nil) + req, err := c.newRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, err } + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, errutils.NewHTTPDoError(req, err) + } + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("could not get records: zone ID: %s; Status: %s; Body: %s", - zoneID, resp.Status, string(bodyBytes)) + return nil, errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errutils.NewReadResponseError(req, resp.StatusCode, err) } records := &DNSRecords{} - err = json.NewDecoder(resp.Body).Decode(records) + err = json.Unmarshal(raw, records) if err != nil { - return nil, fmt.Errorf("failed to decode response body: %w", err) + return nil, errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) } return records, nil @@ -82,25 +93,23 @@ func (c *Client) getRecords(zoneID string) (*DNSRecords, error) { // CreateRecord creates a DNS record. // https://dns.hetzner.com/api-docs#operation/CreateRecord -func (c *Client) CreateRecord(record DNSRecord) error { - body, err := json.Marshal(record) +func (c *Client) CreateRecord(ctx context.Context, record DNSRecord) error { + endpoint := c.baseURL.JoinPath("api", "v1", "records") + + req, err := c.newRequest(ctx, http.MethodPost, endpoint, record) if err != nil { return err } - endpoint, err := c.createEndpoint("api", "v1", "records") + resp, err := c.HTTPClient.Do(req) if err != nil { - return fmt.Errorf("failed to create endpoint: %w", err) + return errutils.NewHTTPDoError(req, err) } - resp, err := c.do(http.MethodPost, endpoint, bytes.NewReader(body)) - if err != nil { - return err - } + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - return fmt.Errorf("could not create record %s; Status: %s; Body: %s", string(body), resp.Status, string(bodyBytes)) + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) } return nil @@ -108,27 +117,31 @@ func (c *Client) CreateRecord(record DNSRecord) error { // DeleteRecord deletes a DNS record. // https://dns.hetzner.com/api-docs#operation/DeleteRecord -func (c *Client) DeleteRecord(recordID string) error { - endpoint, err := c.createEndpoint("api", "v1", "records", recordID) - if err != nil { - return fmt.Errorf("failed to create endpoint: %w", err) - } +func (c *Client) DeleteRecord(ctx context.Context, recordID string) error { + endpoint := c.baseURL.JoinPath("api", "v1", "records", recordID) - resp, err := c.do(http.MethodDelete, endpoint, nil) + req, err := c.newRequest(ctx, http.MethodDelete, endpoint, nil) if err != nil { return err } + resp, err := c.HTTPClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { - return fmt.Errorf("could not delete record: %s; Status: %s", resp.Status, recordID) + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) } return nil } // GetZoneID gets the zone ID for a domain. -func (c *Client) GetZoneID(domain string) (string, error) { - zones, err := c.getZones(domain) +func (c *Client) GetZoneID(ctx context.Context, domain string) (string, error) { + zones, err := c.getZones(ctx, domain) if err != nil { return "", err } @@ -143,57 +156,70 @@ func (c *Client) GetZoneID(domain string) (string, error) { } // https://dns.hetzner.com/api-docs#operation/GetZones -func (c *Client) getZones(name string) (*Zones, error) { - endpoint, err := c.createEndpoint("api", "v1", "zones") - if err != nil { - return nil, fmt.Errorf("failed to create endpoint: %w", err) - } +func (c *Client) getZones(ctx context.Context, name string) (*Zones, error) { + endpoint := c.baseURL.JoinPath("api", "v1", "zones") query := endpoint.Query() query.Set("name", name) endpoint.RawQuery = query.Encode() - resp, err := c.do(http.MethodGet, endpoint, nil) + req, err := c.newRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, fmt.Errorf("could not get zones: %w", err) } + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + // EOF fallback if resp.StatusCode == http.StatusNotFound { return &Zones{}, nil } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("could not get zones: %s", resp.Status) + return nil, errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errutils.NewReadResponseError(req, resp.StatusCode, err) } zones := &Zones{} - err = json.NewDecoder(resp.Body).Decode(zones) + err = json.Unmarshal(raw, zones) if err != nil { - return nil, fmt.Errorf("failed to decode response body: %w", err) + return nil, errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) } return zones, nil } -func (c *Client) do(method string, endpoint fmt.Stringer, body io.Reader) (*http.Response, error) { - req, err := http.NewRequest(method, endpoint.String(), body) +func (c *Client) newRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) if err != nil { - return nil, err + return nil, fmt.Errorf("unable to create request: %w", err) } req.Header.Set("Accept", "application/json") - req.Header.Set("Content-Type", "application/json") - req.Header.Set(authHeader, c.apiKey) - return c.HTTPClient.Do(req) -} - -func (c *Client) createEndpoint(parts ...string) (*url.URL, error) { - baseURL, err := url.Parse(c.BaseURL) - if err != nil { - return nil, err + if payload != nil { + req.Header.Set("Content-Type", "application/json") } - return baseURL.JoinPath(parts...), nil + req.Header.Set(authHeader, c.apiKey) + + return req, nil } diff --git a/providers/dns/hetzner/internal/client_test.go b/providers/dns/hetzner/internal/client_test.go index 269f984a..aa217540 100644 --- a/providers/dns/hetzner/internal/client_test.go +++ b/providers/dns/hetzner/internal/client_test.go @@ -1,10 +1,12 @@ package internal import ( + "context" "fmt" "io" "net/http" "net/http/httptest" + "net/url" "os" "testing" @@ -12,14 +14,26 @@ import ( "github.com/stretchr/testify/require" ) -func TestClient_GetTxtRecord(t *testing.T) { +func setupTest(t *testing.T, apiKey string) (*Client, *http.ServeMux) { + t.Helper() + mux := http.NewServeMux() server := httptest.NewServer(mux) t.Cleanup(server.Close) + client := NewClient(apiKey) + client.baseURL, _ = url.Parse(server.URL) + client.HTTPClient = server.Client() + + return client, mux +} + +func TestClient_GetTxtRecord(t *testing.T) { const zoneID = "zoneA" const apiKey = "myKeyA" + client, mux := setupTest(t, apiKey) + mux.HandleFunc("/api/v1/records", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodGet { http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusMethodNotAllowed) @@ -52,23 +66,18 @@ func TestClient_GetTxtRecord(t *testing.T) { } }) - client := NewClient(apiKey) - client.BaseURL = server.URL - - record, err := client.GetTxtRecord("test1", "txttxttxt", zoneID) + record, err := client.GetTxtRecord(context.Background(), "test1", "txttxttxt", zoneID) require.NoError(t, err) fmt.Println(record) } func TestClient_CreateRecord(t *testing.T) { - mux := http.NewServeMux() - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - const zoneID = "zoneA" const apiKey = "myKeyB" + client, mux := setupTest(t, apiKey) + mux.HandleFunc("/api/v1/records", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodPost { http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusMethodNotAllowed) @@ -95,9 +104,6 @@ func TestClient_CreateRecord(t *testing.T) { } }) - client := NewClient(apiKey) - client.BaseURL = server.URL - record := DNSRecord{ Name: "test", Type: "TXT", @@ -106,17 +112,15 @@ func TestClient_CreateRecord(t *testing.T) { ZoneID: zoneID, } - err := client.CreateRecord(record) + err := client.CreateRecord(context.Background(), record) require.NoError(t, err) } func TestClient_DeleteRecord(t *testing.T) { - mux := http.NewServeMux() - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - const apiKey = "myKeyC" + client, mux := setupTest(t, apiKey) + mux.HandleFunc("/api/v1/records/recordID", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodDelete { http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusMethodNotAllowed) @@ -130,19 +134,15 @@ func TestClient_DeleteRecord(t *testing.T) { } }) - client := NewClient(apiKey) - client.BaseURL = server.URL - - err := client.DeleteRecord("recordID") + err := client.DeleteRecord(context.Background(), "recordID") require.NoError(t, err) } func TestClient_GetZoneID(t *testing.T) { - mux := http.NewServeMux() - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - const apiKey = "myKeyD" + + client, mux := setupTest(t, apiKey) + mux.HandleFunc("/api/v1/zones", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodGet { http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusMethodNotAllowed) @@ -169,10 +169,7 @@ func TestClient_GetZoneID(t *testing.T) { } }) - client := NewClient(apiKey) - client.BaseURL = server.URL - - zoneID, err := client.GetZoneID("example.com") + zoneID, err := client.GetZoneID(context.Background(), "example.com") require.NoError(t, err) assert.Equal(t, "zoneA", zoneID) diff --git a/providers/dns/hostingde/client.go b/providers/dns/hostingde/client.go deleted file mode 100644 index 047bb740..00000000 --- a/providers/dns/hostingde/client.go +++ /dev/null @@ -1,123 +0,0 @@ -package hostingde - -import ( - "bytes" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "time" - - "github.com/cenkalti/backoff/v4" -) - -const defaultBaseURL = "https://secure.hosting.de/api/dns/v1/json" - -// https://www.hosting.de/api/?json#list-zoneconfigs -func (d *DNSProvider) listZoneConfigs(findRequest ZoneConfigsFindRequest) (*ZoneConfigsFindResponse, error) { - uri := defaultBaseURL + "/zoneConfigsFind" - - findResponse := &ZoneConfigsFindResponse{} - - rawResp, err := d.post(uri, findRequest, findResponse) - if err != nil { - return nil, err - } - - if len(findResponse.Response.Data) == 0 { - return nil, fmt.Errorf("%w: %s", err, toUnreadableBodyMessage(uri, rawResp)) - } - - if findResponse.Status != "success" && findResponse.Status != "pending" { - return findResponse, errors.New(toUnreadableBodyMessage(uri, rawResp)) - } - - return findResponse, nil -} - -// https://www.hosting.de/api/?json#updating-zones -func (d *DNSProvider) updateZone(updateRequest ZoneUpdateRequest) (*ZoneUpdateResponse, error) { - uri := defaultBaseURL + "/zoneUpdate" - - // but we'll need the ID later to delete the record - updateResponse := &ZoneUpdateResponse{} - - rawResp, err := d.post(uri, updateRequest, updateResponse) - if err != nil { - return nil, err - } - - if updateResponse.Status != "success" && updateResponse.Status != "pending" { - return nil, errors.New(toUnreadableBodyMessage(uri, rawResp)) - } - - return updateResponse, nil -} - -func (d *DNSProvider) getZone(findRequest ZoneConfigsFindRequest) (*ZoneConfig, error) { - var zoneConfig *ZoneConfig - - operation := func() error { - findResponse, err := d.listZoneConfigs(findRequest) - if err != nil { - return backoff.Permanent(err) - } - - if findResponse.Response.Data[0].Status != "active" { - return fmt.Errorf("unexpected status: %q", findResponse.Response.Data[0].Status) - } - - zoneConfig = &findResponse.Response.Data[0] - - return nil - } - - bo := backoff.NewExponentialBackOff() - bo.InitialInterval = 3 * time.Second - bo.MaxInterval = 10 * bo.InitialInterval - bo.MaxElapsedTime = 100 * bo.InitialInterval - - // retry in case the zone was edited recently and is not yet active - err := backoff.Retry(operation, bo) - if err != nil { - return nil, err - } - - return zoneConfig, nil -} - -func (d *DNSProvider) post(uri string, request, response interface{}) ([]byte, error) { - body, err := json.Marshal(request) - if err != nil { - return nil, err - } - - req, err := http.NewRequest(http.MethodPost, uri, bytes.NewReader(body)) - if err != nil { - return nil, err - } - - resp, err := d.config.HTTPClient.Do(req) - if err != nil { - return nil, fmt.Errorf("error querying API: %w", err) - } - - defer resp.Body.Close() - - content, err := io.ReadAll(resp.Body) - if err != nil { - return nil, errors.New(toUnreadableBodyMessage(uri, content)) - } - - err = json.Unmarshal(content, response) - if err != nil { - return nil, fmt.Errorf("%w: %s", err, toUnreadableBodyMessage(uri, content)) - } - - return content, nil -} - -func toUnreadableBodyMessage(uri string, rawBody []byte) string { - return fmt.Sprintf("the request %s sent a response with a body which is an invalid format: %q", uri, string(rawBody)) -} diff --git a/providers/dns/hostingde/hostingde.go b/providers/dns/hostingde/hostingde.go index e8ebefe1..10d9b5c0 100644 --- a/providers/dns/hostingde/hostingde.go +++ b/providers/dns/hostingde/hostingde.go @@ -2,6 +2,7 @@ package hostingde import ( + "context" "errors" "fmt" "net/http" @@ -10,6 +11,7 @@ import ( "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/platform/config/env" + "github.com/go-acme/lego/v4/providers/dns/hostingde/internal" ) // Environment variables names. @@ -49,7 +51,9 @@ func NewDefaultConfig() *Config { // DNSProvider implements the challenge.Provider interface. type DNSProvider struct { - config *Config + config *Config + client *internal.Client + recordIDs map[string]string recordIDsMu sync.Mutex } @@ -82,6 +86,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return &DNSProvider{ config: config, + client: internal.NewClient(config.APIKey), recordIDs: make(map[string]string), }, nil } @@ -98,42 +103,43 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { zoneName, err := d.getZoneName(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("hostingde: could not determine zone for domain %q: %w", domain, err) + return fmt.Errorf("hostingde: could not find zone for domain %q: %w", domain, err) } + ctx := context.Background() + // get the ZoneConfig for that domain - zonesFind := ZoneConfigsFindRequest{ - Filter: Filter{Field: "zoneName", Value: zoneName}, + zonesFind := internal.ZoneConfigsFindRequest{ + Filter: internal.Filter{Field: "zoneName", Value: zoneName}, Limit: 1, Page: 1, } - zonesFind.AuthToken = d.config.APIKey - zoneConfig, err := d.getZone(zonesFind) + zoneConfig, err := d.client.GetZone(ctx, zonesFind) if err != nil { return fmt.Errorf("hostingde: %w", err) } + zoneConfig.Name = zoneName - rec := []DNSRecord{{ + rec := []internal.DNSRecord{{ Type: "TXT", Name: dns01.UnFqdn(info.EffectiveFQDN), Content: info.Value, TTL: d.config.TTL, }} - req := ZoneUpdateRequest{ + req := internal.ZoneUpdateRequest{ ZoneConfig: *zoneConfig, RecordsToAdd: rec, } - req.AuthToken = d.config.APIKey - resp, err := d.updateZone(req) + response, err := d.client.UpdateZone(ctx, req) if err != nil { return fmt.Errorf("hostingde: %w", err) } - for _, record := range resp.Response.Records { + for _, record := range response.Records { if record.Name == dns01.UnFqdn(info.EffectiveFQDN) && record.Content == fmt.Sprintf(`%q`, info.Value) { d.recordIDsMu.Lock() d.recordIDs[info.EffectiveFQDN] = record.ID @@ -154,41 +160,41 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { zoneName, err := d.getZoneName(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("hostingde: could not determine zone for domain %q: %w", domain, err) + return fmt.Errorf("hostingde: could not find zone for domain %q: %w", domain, err) } - rec := []DNSRecord{{ - Type: "TXT", - Name: dns01.UnFqdn(info.EffectiveFQDN), - Content: `"` + info.Value + `"`, - }} + ctx := context.Background() // get the ZoneConfig for that domain - zonesFind := ZoneConfigsFindRequest{ - Filter: Filter{Field: "zoneName", Value: zoneName}, + zonesFind := internal.ZoneConfigsFindRequest{ + Filter: internal.Filter{Field: "zoneName", Value: zoneName}, Limit: 1, Page: 1, } - zonesFind.AuthToken = d.config.APIKey - zoneConfig, err := d.getZone(zonesFind) + zoneConfig, err := d.client.GetZone(ctx, zonesFind) if err != nil { return fmt.Errorf("hostingde: %w", err) } zoneConfig.Name = zoneName - req := ZoneUpdateRequest{ + rec := []internal.DNSRecord{{ + Type: "TXT", + Name: dns01.UnFqdn(info.EffectiveFQDN), + Content: `"` + info.Value + `"`, + }} + + req := internal.ZoneUpdateRequest{ ZoneConfig: *zoneConfig, RecordsToDelete: rec, } - req.AuthToken = d.config.APIKey // Delete record ID from map d.recordIDsMu.Lock() delete(d.recordIDs, info.EffectiveFQDN) d.recordIDsMu.Unlock() - _, err = d.updateZone(req) + _, err = d.client.UpdateZone(ctx, req) if err != nil { return fmt.Errorf("hostingde: %w", err) } @@ -202,7 +208,7 @@ func (d *DNSProvider) getZoneName(fqdn string) (string, error) { zoneName, err := dns01.FindZoneByFqdn(fqdn) if err != nil { - return "", err + return "", fmt.Errorf("could not find zone for FQDN %q: %w", fqdn, err) } if zoneName == "" { diff --git a/providers/dns/hostingde/internal/client.go b/providers/dns/hostingde/internal/client.go new file mode 100644 index 00000000..0f5c6d18 --- /dev/null +++ b/providers/dns/hostingde/internal/client.go @@ -0,0 +1,147 @@ +package internal + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "time" + + "github.com/cenkalti/backoff/v4" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" +) + +const defaultBaseURL = "https://secure.hosting.de/api/dns/v1/json" + +// Client the API client for Hosting.de. +type Client struct { + apiKey string + + baseURL *url.URL + HTTPClient *http.Client +} + +// NewClient creates new Client. +func NewClient(apiKey string) *Client { + baseURL, _ := url.Parse(defaultBaseURL) + + return &Client{ + apiKey: apiKey, + baseURL: baseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + } +} + +// GetZone gets a zone. +func (c Client) GetZone(ctx context.Context, req ZoneConfigsFindRequest) (*ZoneConfig, error) { + var zoneConfig *ZoneConfig + + operation := func() error { + response, err := c.ListZoneConfigs(ctx, req) + if err != nil { + return backoff.Permanent(err) + } + + if response.Data[0].Status != "active" { + return fmt.Errorf("unexpected status: %q", response.Data[0].Status) + } + + zoneConfig = &response.Data[0] + + return nil + } + + bo := backoff.NewExponentialBackOff() + bo.InitialInterval = 3 * time.Second + bo.MaxInterval = 10 * bo.InitialInterval + bo.MaxElapsedTime = 100 * bo.InitialInterval + + // retry in case the zone was edited recently and is not yet active + err := backoff.Retry(operation, bo) + if err != nil { + return nil, err + } + + return zoneConfig, nil +} + +// ListZoneConfigs lists zone configuration. +// https://www.hosting.de/api/?json#list-zoneconfigs +func (c Client) ListZoneConfigs(ctx context.Context, req ZoneConfigsFindRequest) (*ZoneResponse, error) { + endpoint := c.baseURL.JoinPath("zoneConfigsFind") + + req.AuthToken = c.apiKey + + response := &BaseResponse[*ZoneResponse]{} + + rawResp, err := c.post(ctx, endpoint, req, response) + if err != nil { + return nil, err + } + + if response.Status != "success" && response.Status != "pending" { + return nil, fmt.Errorf("unexpected status: %q, %s", response.Status, string(rawResp)) + } + + if response.Response == nil || len(response.Response.Data) == 0 { + return nil, fmt.Errorf("no data, status: %q, %s", response.Status, string(rawResp)) + } + + return response.Response, nil +} + +// UpdateZone updates a zone. +// https://www.hosting.de/api/?json#updating-zones +func (c Client) UpdateZone(ctx context.Context, req ZoneUpdateRequest) (*Zone, error) { + endpoint := c.baseURL.JoinPath("zoneUpdate") + + req.AuthToken = c.apiKey + + // but we'll need the ID later to delete the record + response := &BaseResponse[*Zone]{} + + rawResp, err := c.post(ctx, endpoint, req, response) + if err != nil { + return nil, err + } + + if response.Status != "success" && response.Status != "pending" { + return nil, fmt.Errorf("unexpected status: %q, %s", response.Status, string(rawResp)) + } + + return response.Response, nil +} + +func (c Client) post(ctx context.Context, endpoint *url.URL, request, result any) ([]byte, error) { + body, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint.String(), bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return nil, errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return raw, nil +} diff --git a/providers/dns/hostingde/internal/client_test.go b/providers/dns/hostingde/internal/client_test.go new file mode 100644 index 00000000..af76d0d2 --- /dev/null +++ b/providers/dns/hostingde/internal/client_test.go @@ -0,0 +1,264 @@ +package internal + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setupTest(t *testing.T, pattern string, handler http.HandlerFunc) *Client { + t.Helper() + + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + client := NewClient("secret") + client.HTTPClient = server.Client() + client.baseURL, _ = url.Parse(server.URL) + + mux.HandleFunc(pattern, handler) + + return client +} + +func writeFixture(rw http.ResponseWriter, filename string) { + file, err := os.Open(filepath.Join("fixtures", filename)) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + defer func() { _ = file.Close() }() + + _, _ = io.Copy(rw, file) +} + +func TestClient_ListZoneConfigs(t *testing.T) { + client := setupTest(t, "/zoneConfigsFind", func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest) + return + } + + raw, err := io.ReadAll(req.Body) + if err != nil { + http.Error(rw, err.Error(), http.StatusBadRequest) + return + } + + body := string(bytes.TrimSpace(raw)) + if body != `{"authToken":"secret","filter":{"field":"zoneName","value":"example.com"},"limit":1,"page":1}` { + http.Error(rw, fmt.Sprintf("unexpected body: got %s", body), http.StatusBadRequest) + return + } + + writeFixture(rw, "zoneConfigsFind.json") + }) + + zonesFind := ZoneConfigsFindRequest{ + Filter: Filter{Field: "zoneName", Value: "example.com"}, + Limit: 1, + Page: 1, + } + + zoneResponse, err := client.ListZoneConfigs(context.Background(), zonesFind) + require.NoError(t, err) + + expected := &ZoneResponse{ + Limit: 10, + Page: 1, + TotalEntries: 15, + TotalPages: 2, + Type: "FindZoneConfigsResult", + Data: []ZoneConfig{{ + ID: "123", + AccountID: "456", + Status: "s", + Name: "n", + NameUnicode: "u", + MasterIP: "m", + Type: "t", + EMailAddress: "e", + ZoneTransferWhitelist: []string{"a", "b"}, + LastChangeDate: "l", + DNSServerGroupID: "g", + DNSSecMode: "m", + SOAValues: &SOAValues{ + Refresh: 1, + Retry: 2, + Expire: 3, + TTL: 4, + NegativeTTL: 5, + }, + TemplateValues: json.RawMessage(nil), + }}, + } + + assert.Equal(t, expected, zoneResponse) +} + +func TestClient_ListZoneConfigs_error(t *testing.T) { + client := setupTest(t, "/zoneConfigsFind", func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest) + return + } + + writeFixture(rw, "zoneConfigsFind_error.json") + }) + + zonesFind := ZoneConfigsFindRequest{ + Filter: Filter{Field: "zoneName", Value: "example.com"}, + Limit: 1, + Page: 1, + } + + _, err := client.ListZoneConfigs(context.Background(), zonesFind) + require.Error(t, err) +} + +func TestClient_UpdateZone(t *testing.T) { + client := setupTest(t, "/zoneUpdate", func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest) + return + } + + raw, err := io.ReadAll(req.Body) + if err != nil { + http.Error(rw, err.Error(), http.StatusBadRequest) + return + } + + body := string(bytes.TrimSpace(raw)) + if body != `{"authToken":"secret","zoneConfig":{"id":"123","accountId":"456","status":"s","name":"n","nameUnicode":"u","masterIp":"m","type":"t","emailAddress":"e","zoneTransferWhitelist":["a","b"],"lastChangeDate":"l","dnsServerGroupId":"g","dnsSecMode":"m","soaValues":{"refresh":1,"retry":2,"expire":3,"ttl":4,"negativeTtl":5}},"recordsToAdd":null,"recordsToDelete":[{"name":"_acme-challenge.example.com","type":"TXT","content":"\"txt\""}]}` { + http.Error(rw, fmt.Sprintf("unexpected body: got %s", body), http.StatusBadRequest) + return + } + + writeFixture(rw, "zoneUpdate.json") + }) + + request := ZoneUpdateRequest{ + ZoneConfig: ZoneConfig{ + ID: "123", + AccountID: "456", + Status: "s", + Name: "n", + NameUnicode: "u", + MasterIP: "m", + Type: "t", + EMailAddress: "e", + ZoneTransferWhitelist: []string{"a", "b"}, + LastChangeDate: "l", + DNSServerGroupID: "g", + DNSSecMode: "m", + SOAValues: &SOAValues{ + Refresh: 1, + Retry: 2, + Expire: 3, + TTL: 4, + NegativeTTL: 5, + }, + }, + RecordsToDelete: []DNSRecord{{ + Type: "TXT", + Name: "_acme-challenge.example.com", + Content: `"txt"`, + }}, + } + + response, err := client.UpdateZone(context.Background(), request) + require.NoError(t, err) + + expected := &Zone{ + Records: []DNSRecord{{ + ID: "123", + ZoneID: "456", + RecordTemplateID: "789", + Name: "n", + Type: "TXT", + Content: "txt", + TTL: 120, + Priority: 5, + LastChangeDate: "d", + }}, + ZoneConfig: ZoneConfig{ + ID: "123", + AccountID: "456", + Status: "s", + Name: "n", + NameUnicode: "u", + MasterIP: "m", + Type: "t", + EMailAddress: "e", + ZoneTransferWhitelist: []string{"a", "b"}, + LastChangeDate: "l", + DNSServerGroupID: "g", + DNSSecMode: "m", + SOAValues: &SOAValues{ + Refresh: 1, + Retry: 2, + Expire: 3, + TTL: 4, + NegativeTTL: 5, + }, + }, + } + + assert.Equal(t, expected, response) +} + +func TestClient_UpdateZone_error(t *testing.T) { + client := setupTest(t, "/zoneUpdate", func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest) + return + } + + writeFixture(rw, "zoneUpdate_error.json") + }) + + request := ZoneUpdateRequest{ + ZoneConfig: ZoneConfig{ + ID: "123", + AccountID: "456", + Status: "s", + Name: "n", + NameUnicode: "u", + MasterIP: "m", + Type: "t", + EMailAddress: "e", + ZoneTransferWhitelist: []string{"a", "b"}, + LastChangeDate: "l", + DNSServerGroupID: "g", + DNSSecMode: "m", + SOAValues: &SOAValues{ + Refresh: 1, + Retry: 2, + Expire: 3, + TTL: 4, + NegativeTTL: 5, + }, + }, + RecordsToDelete: []DNSRecord{{ + Type: "TXT", + Name: "_acme-challenge.example.com", + Content: `"txt"`, + }}, + } + + _, err := client.UpdateZone(context.Background(), request) + require.Error(t, err) +} diff --git a/providers/dns/hostingde/internal/fixtures/zoneConfigsFind.json b/providers/dns/hostingde/internal/fixtures/zoneConfigsFind.json new file mode 100644 index 00000000..7c44d5d0 --- /dev/null +++ b/providers/dns/hostingde/internal/fixtures/zoneConfigsFind.json @@ -0,0 +1,44 @@ +{ + "metadata": { + "clientTransactionId": "1", + "serverTransactionId": "2" + }, + "warnings": [ + "aaa", + "bbb" + ], + "status": "success", + "response": { + "limit": 10, + "page": 1, + "totalEntries": 15, + "totalPages": 2, + "type": "FindZoneConfigsResult", + "data": [ + { + "id": "123", + "accountId": "456", + "status": "s", + "name": "n", + "nameUnicode": "u", + "masterIp": "m", + "type": "t", + "emailAddress": "e", + "zoneTransferWhitelist": [ + "a", + "b" + ], + "lastChangeDate": "l", + "dnsServerGroupId": "g", + "dnsSecMode": "m", + "soaValues": { + "refresh": 1, + "retry": 2, + "expire": 3, + "ttl": 4, + "negativeTtl": 5 + } + } + ] + } +} diff --git a/providers/dns/hostingde/internal/fixtures/zoneConfigsFind_error.json b/providers/dns/hostingde/internal/fixtures/zoneConfigsFind_error.json new file mode 100644 index 00000000..101f02f0 --- /dev/null +++ b/providers/dns/hostingde/internal/fixtures/zoneConfigsFind_error.json @@ -0,0 +1,57 @@ +{ + "errors": [ + { + "code": 123, + "contextObject": "o", + "contextPath": "p", + "details": [ + "a", + "b" + ], + "text": "t", + "value": "v" + } + ], + "metadata": { + "clientTransactionId": "1", + "serverTransactionId": "2" + }, + "warnings": [ + "aaa", + "bbb" + ], + "status": "error", + "response": { + "limit": 10, + "page": 1, + "totalEntries": 15, + "totalPages": 2, + "type": "FindZoneConfigsResult", + "data": [ + { + "id": "123", + "accountId": "456", + "status": "s", + "name": "n", + "nameUnicode": "u", + "masterIp": "m", + "type": "t", + "emailAddress": "e", + "zoneTransferWhitelist": [ + "a", + "b" + ], + "lastChangeDate": "l", + "dnsServerGroupId": "g", + "dnsSecMode": "m", + "soaValues": { + "refresh": 1, + "retry": 2, + "expire": 3, + "ttl": 4, + "negativeTtl": 5 + } + } + ] + } +} diff --git a/providers/dns/hostingde/internal/fixtures/zoneUpdate.json b/providers/dns/hostingde/internal/fixtures/zoneUpdate.json new file mode 100644 index 00000000..ac758c07 --- /dev/null +++ b/providers/dns/hostingde/internal/fixtures/zoneUpdate.json @@ -0,0 +1,50 @@ +{ + "metadata": { + "clientTransactionId": "", + "serverTransactionId": "" + }, + "warnings": [ + "aaa", + "bbb" + ], + "status": "success", + "response": { + "records": [ + { + "id": "123", + "zoneId": "456", + "recordTemplateId": "789", + "name": "n", + "type": "TXT", + "content": "txt", + "ttl": 120, + "priority": 5, + "lastChangeDate": "d" + } + ], + "zoneConfig": { + "id": "123", + "accountId": "456", + "status": "s", + "name": "n", + "nameUnicode": "u", + "masterIp": "m", + "type": "t", + "emailAddress": "e", + "zoneTransferWhitelist": [ + "a", + "b" + ], + "lastChangeDate": "l", + "dnsServerGroupId": "g", + "dnsSecMode": "m", + "soaValues": { + "refresh": 1, + "retry": 2, + "expire": 3, + "ttl": 4, + "negativeTtl": 5 + } + } + } +} diff --git a/providers/dns/hostingde/internal/fixtures/zoneUpdate_error.json b/providers/dns/hostingde/internal/fixtures/zoneUpdate_error.json new file mode 100644 index 00000000..74a26508 --- /dev/null +++ b/providers/dns/hostingde/internal/fixtures/zoneUpdate_error.json @@ -0,0 +1,63 @@ +{ + "errors": [ + { + "code": 123, + "contextObject": "o", + "contextPath": "p", + "details": [ + "a", + "b" + ], + "text": "t", + "value": "v" + } + ], + "metadata": { + "clientTransactionId": "", + "serverTransactionId": "" + }, + "warnings": [ + "aaa", + "bbb" + ], + "status": "error", + "response": { + "records": [ + { + "id": "123", + "zoneId": "456", + "recordTemplateId": "789", + "name": "n", + "type": "TXT", + "content": "txt", + "ttl": 120, + "priority": 5, + "lastChangeDate": "d" + } + ], + "zoneConfig": { + "id": "123", + "accountId": "456", + "status": "s", + "name": "n", + "nameUnicode": "u", + "masterIp": "m", + "type": "t", + "emailAddress": "e", + "zoneTransferWhitelist": [ + "a", + "b" + ], + "lastChangeDate": "l", + "dnsServerGroupId": "g", + "dnsSecMode": "m", + "soaValues": { + "refresh": 1, + "retry": 2, + "expire": 3, + "ttl": 4, + "negativeTtl": 5 + } + } + } +} diff --git a/providers/dns/hostingde/model.go b/providers/dns/hostingde/internal/types.go similarity index 84% rename from providers/dns/hostingde/model.go rename to providers/dns/hostingde/internal/types.go index 9c67784b..a706008a 100644 --- a/providers/dns/hostingde/model.go +++ b/providers/dns/hostingde/internal/types.go @@ -1,4 +1,4 @@ -package hostingde +package internal import "encoding/json" @@ -93,13 +93,6 @@ type ZoneUpdateRequest struct { RecordsToDelete []DNSRecord `json:"recordsToDelete"` } -// ZoneUpdateResponse represents a response from the API. -// https://www.hosting.de/api/?json#updating-zones -type ZoneUpdateResponse struct { - BaseResponse - Response Zone `json:"response"` -} - // ZoneConfigsFindRequest represents a API ZonesFind request. // https://www.hosting.de/api/?json#list-zoneconfigs type ZoneConfigsFindRequest struct { @@ -110,27 +103,25 @@ type ZoneConfigsFindRequest struct { Sort *Sort `json:"sort,omitempty"` } -// ZoneConfigsFindResponse represents the API response for ZoneConfigsFind. -// https://www.hosting.de/api/?json#list-zoneconfigs -type ZoneConfigsFindResponse struct { - BaseResponse - Response struct { - Limit int `json:"limit"` - Page int `json:"page"` - TotalEntries int `json:"totalEntries"` - TotalPages int `json:"totalPages"` - Type string `json:"type"` - Data []ZoneConfig `json:"data"` - } `json:"response"` +type ZoneResponse struct { + Limit int `json:"limit"` + Page int `json:"page"` + TotalEntries int `json:"totalEntries"` + TotalPages int `json:"totalPages"` + Type string `json:"type"` + Data []ZoneConfig `json:"data"` } // BaseResponse Common response struct. -// https://www.hosting.de/api/?json#responses -type BaseResponse struct { +// base: https://www.hosting.de/api/?json#responses +// ZoneConfigsFind: https://www.hosting.de/api/?json#list-zoneconfigs +// ZoneUpdate: https://www.hosting.de/api/?json#updating-zones +type BaseResponse[T any] struct { Errors []APIError `json:"errors"` Metadata Metadata `json:"metadata"` Warnings []string `json:"warnings"` Status string `json:"status"` + Response T `json:"response"` } // BaseRequest Common request struct. diff --git a/providers/dns/hosttech/hosttech.go b/providers/dns/hosttech/hosttech.go index bd153150..41073f3c 100644 --- a/providers/dns/hosttech/hosttech.go +++ b/providers/dns/hosttech/hosttech.go @@ -2,6 +2,7 @@ package hosttech import ( + "context" "errors" "fmt" "net/http" @@ -80,11 +81,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("hosttech: missing credentials") } - client := internal.NewClient(config.APIKey) - - if config.HTTPClient != nil { - client.HTTPClient = config.HTTPClient - } + client := internal.NewClient(internal.OAuthStaticAccessToken(config.HTTPClient, config.APIKey)) return &DNSProvider{ config: config, @@ -105,10 +102,12 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("hosttech: could not determine zone for domain %q: %w", domain, err) + return fmt.Errorf("hosttech: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } - zone, err := d.client.GetZone(dns01.UnFqdn(authZone)) + ctx := context.Background() + + zone, err := d.client.GetZone(ctx, dns01.UnFqdn(authZone)) if err != nil { return fmt.Errorf("hosttech: could not find zone for domain %q (%s): %w", domain, authZone, err) } @@ -125,7 +124,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { TTL: d.config.TTL, } - newRecord, err := d.client.AddRecord(strconv.Itoa(zone.ID), record) + newRecord, err := d.client.AddRecord(ctx, strconv.Itoa(zone.ID), record) if err != nil { return fmt.Errorf("hosttech: %w", err) } @@ -143,10 +142,12 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("hosttech: could not determine zone for domain %q: %w", domain, err) + return fmt.Errorf("hosttech: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } - zone, err := d.client.GetZone(dns01.UnFqdn(authZone)) + ctx := context.Background() + + zone, err := d.client.GetZone(ctx, dns01.UnFqdn(authZone)) if err != nil { return fmt.Errorf("hosttech: could not find zone for domain %q (%s): %w", domain, authZone, err) } @@ -159,7 +160,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return fmt.Errorf("hosttech: unknown record ID for '%s' '%s'", info.EffectiveFQDN, token) } - err = d.client.DeleteRecord(strconv.Itoa(zone.ID), strconv.Itoa(recordID)) + err = d.client.DeleteRecord(ctx, strconv.Itoa(zone.ID), strconv.Itoa(recordID)) if err != nil { return fmt.Errorf("hosttech: %w", err) } diff --git a/providers/dns/hosttech/internal/client.go b/providers/dns/hosttech/internal/client.go index 0786674c..78b59455 100644 --- a/providers/dns/hosttech/internal/client.go +++ b/providers/dns/hosttech/internal/client.go @@ -2,6 +2,7 @@ package internal import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -9,32 +10,33 @@ import ( "net/url" "strconv" "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" + "golang.org/x/oauth2" ) const defaultBaseURL = "https://api.ns1.hosttech.eu/api" // Client a Hosttech client. type Client struct { - HTTPClient *http.Client baseURL *url.URL - - apiKey string + httpClient *http.Client } // NewClient creates a new Client. -func NewClient(apiKey string) *Client { +func NewClient(hc *http.Client) *Client { baseURL, _ := url.Parse(defaultBaseURL) - return &Client{ - HTTPClient: &http.Client{Timeout: 10 * time.Second}, - baseURL: baseURL, - apiKey: apiKey, + if hc == nil { + hc = &http.Client{Timeout: 10 * time.Second} } + + return &Client{baseURL: baseURL, httpClient: hc} } // GetZones Get a list of all zones. // https://api.ns1.hosttech.eu/api/documentation/#/Zones/get_api_user_v1_zones -func (c Client) GetZones(query string, limit, offset int) ([]Zone, error) { +func (c Client) GetZones(ctx context.Context, query string, limit, offset int) ([]Zone, error) { endpoint := c.baseURL.JoinPath("user", "v1", "zones") values := endpoint.Query() @@ -50,52 +52,42 @@ func (c Client) GetZones(query string, limit, offset int) ([]Zone, error) { endpoint.RawQuery = values.Encode() - req, err := http.NewRequest(http.MethodGet, endpoint.String(), nil) + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, fmt.Errorf("create request: %w", err) } - raw, err := c.do(req) + result := apiResponse[[]Zone]{} + err = c.do(req, &result) if err != nil { return nil, err } - var r []Zone - err = json.Unmarshal(raw, &r) - if err != nil { - return nil, fmt.Errorf("unmarshal response data: %s: %w", string(raw), err) - } - - return r, nil + return result.Data, nil } // GetZone Get a single zone. // https://api.ns1.hosttech.eu/api/documentation/#/Zones/get_api_user_v1_zones__zoneId_ -func (c Client) GetZone(zoneID string) (*Zone, error) { +func (c Client) GetZone(ctx context.Context, zoneID string) (*Zone, error) { endpoint := c.baseURL.JoinPath("user", "v1", "zones", zoneID) - req, err := http.NewRequest(http.MethodGet, endpoint.String(), nil) + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, fmt.Errorf("create request: %w", err) } - raw, err := c.do(req) + result := apiResponse[*Zone]{} + err = c.do(req, &result) if err != nil { return nil, err } - var r Zone - err = json.Unmarshal(raw, &r) - if err != nil { - return nil, fmt.Errorf("unmarshal response data: %s: %w", string(raw), err) - } - - return &r, nil + return result.Data, nil } // GetRecords Returns a list of all records for the given zone. // https://api.ns1.hosttech.eu/api/documentation/#/Records/get_api_user_v1_zones__zoneId__records -func (c Client) GetRecords(zoneID, recordType string) ([]Record, error) { +func (c Client) GetRecords(ctx context.Context, zoneID, recordType string) ([]Record, error) { endpoint := c.baseURL.JoinPath("user", "v1", "zones", zoneID, "records") values := endpoint.Query() @@ -106,107 +98,127 @@ func (c Client) GetRecords(zoneID, recordType string) ([]Record, error) { endpoint.RawQuery = values.Encode() - req, err := http.NewRequest(http.MethodGet, endpoint.String(), nil) + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, fmt.Errorf("create request: %w", err) } - raw, err := c.do(req) + result := apiResponse[[]Record]{} + err = c.do(req, &result) if err != nil { return nil, err } - var r []Record - err = json.Unmarshal(raw, &r) - if err != nil { - return nil, fmt.Errorf("unmarshal response data: %s: %w", string(raw), err) - } - - return r, nil + return result.Data, nil } // AddRecord Adds a new record to the zone and returns the newly created record. // https://api.ns1.hosttech.eu/api/documentation/#/Records/post_api_user_v1_zones__zoneId__records -func (c Client) AddRecord(zoneID string, record Record) (*Record, error) { +func (c Client) AddRecord(ctx context.Context, zoneID string, record Record) (*Record, error) { endpoint := c.baseURL.JoinPath("user", "v1", "zones", zoneID, "records") - body, err := json.Marshal(record) - if err != nil { - return nil, fmt.Errorf("marshal request data: %w", err) - } - - req, err := http.NewRequest(http.MethodPost, endpoint.String(), bytes.NewReader(body)) + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record) if err != nil { return nil, fmt.Errorf("create request: %w", err) } - raw, err := c.do(req) + result := apiResponse[*Record]{} + err = c.do(req, &result) if err != nil { return nil, err } - var r Record - err = json.Unmarshal(raw, &r) - if err != nil { - return nil, fmt.Errorf("unmarshal response data: %s: %w", string(raw), err) - } - - return &r, nil + return result.Data, nil } // DeleteRecord Deletes a single record for the given id. // https://api.ns1.hosttech.eu/api/documentation/#/Records/delete_api_user_v1_zones__zoneId__records__recordId_ -func (c Client) DeleteRecord(zoneID, recordID string) error { +func (c Client) DeleteRecord(ctx context.Context, zoneID, recordID string) error { endpoint := c.baseURL.JoinPath("user", "v1", "zones", zoneID, "records", recordID) - req, err := http.NewRequest(http.MethodDelete, endpoint.String(), nil) + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) if err != nil { return fmt.Errorf("create request: %w", err) } - _, err = c.do(req) - - return err + return c.do(req, nil) } -func (c Client) do(req *http.Request) (json.RawMessage, error) { - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey)) - - resp, errD := c.HTTPClient.Do(req) +func (c Client) do(req *http.Request, result any) error { + resp, errD := c.httpClient.Do(req) if errD != nil { - return nil, fmt.Errorf("send request: %w", errD) + return errutils.NewHTTPDoError(req, errD) } + defer func() { _ = resp.Body.Close() }() switch resp.StatusCode { case http.StatusOK, http.StatusCreated: - all, err := io.ReadAll(resp.Body) + raw, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("read response: %w", err) + return errutils.NewReadResponseError(req, resp.StatusCode, err) } - var r apiResponse - err = json.Unmarshal(all, &r) + err = json.Unmarshal(raw, result) if err != nil { - return nil, fmt.Errorf("unmarshal response: %w", err) + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) } - return r.Data, nil + return nil case http.StatusNoContent: - return nil, nil + return nil default: - data, _ := io.ReadAll(resp.Body) - - e := APIError{StatusCode: resp.StatusCode} - err := json.Unmarshal(data, &e) - if err != nil { - e.Message = string(data) - } - - return nil, e + return parseError(req, resp) } } + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil +} + +func parseError(req *http.Request, resp *http.Response) error { + raw, _ := io.ReadAll(resp.Body) + + errAPI := &APIError{StatusCode: resp.StatusCode} + err := json.Unmarshal(raw, errAPI) + if err != nil { + return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw) + } + + return errAPI +} + +func OAuthStaticAccessToken(client *http.Client, accessToken string) *http.Client { + if client == nil { + client = &http.Client{Timeout: 5 * time.Second} + } + + client.Transport = &oauth2.Transport{ + Source: oauth2.StaticTokenSource(&oauth2.Token{AccessToken: accessToken}), + Base: client.Transport, + } + + return client +} diff --git a/providers/dns/hosttech/internal/client_test.go b/providers/dns/hosttech/internal/client_test.go index b1073cfe..bf90acc9 100644 --- a/providers/dns/hosttech/internal/client_test.go +++ b/providers/dns/hosttech/internal/client_test.go @@ -1,6 +1,7 @@ package internal import ( + "context" "fmt" "io" "net/http" @@ -19,7 +20,7 @@ const testAPIKey = "secret" func TestClient_GetZones(t *testing.T) { client := setupTest(t, "/user/v1/zones", testHandler(http.MethodGet, http.StatusOK, "zones.json")) - zones, err := client.GetZones("", 100, 0) + zones, err := client.GetZones(context.Background(), "", 100, 0) require.NoError(t, err) expected := []Zone{ @@ -40,14 +41,14 @@ func TestClient_GetZones(t *testing.T) { func TestClient_GetZones_error(t *testing.T) { client := setupTest(t, "/user/v1/zones", testHandler(http.MethodGet, http.StatusUnauthorized, "error.json")) - _, err := client.GetZones("", 100, 0) + _, err := client.GetZones(context.Background(), "", 100, 0) require.Error(t, err) } func TestClient_GetZone(t *testing.T) { client := setupTest(t, "/user/v1/zones/123", testHandler(http.MethodGet, http.StatusOK, "zone.json")) - zone, err := client.GetZone("123") + zone, err := client.GetZone(context.Background(), "123") require.NoError(t, err) expected := &Zone{ @@ -66,14 +67,14 @@ func TestClient_GetZone(t *testing.T) { func TestClient_GetZone_error(t *testing.T) { client := setupTest(t, "/user/v1/zones/123", testHandler(http.MethodGet, http.StatusUnauthorized, "error.json")) - _, err := client.GetZone("123") + _, err := client.GetZone(context.Background(), "123") require.Error(t, err) } func TestClient_GetRecords(t *testing.T) { client := setupTest(t, "/user/v1/zones/123/records", testHandler(http.MethodGet, http.StatusOK, "records.json")) - records, err := client.GetRecords("123", "TXT") + records, err := client.GetRecords(context.Background(), "123", "TXT") require.NoError(t, err) expected := []Record{ @@ -153,7 +154,7 @@ func TestClient_GetRecords(t *testing.T) { func TestClient_GetRecords_error(t *testing.T) { client := setupTest(t, "/user/v1/zones/123/records", testHandler(http.MethodGet, http.StatusUnauthorized, "error.json")) - _, err := client.GetRecords("123", "TXT") + _, err := client.GetRecords(context.Background(), "123", "TXT") require.Error(t, err) } @@ -168,7 +169,7 @@ func TestClient_AddRecord(t *testing.T) { Comment: "example", } - newRecord, err := client.AddRecord("123", record) + newRecord, err := client.AddRecord(context.Background(), "123", record) require.NoError(t, err) expected := &Record{ @@ -194,21 +195,21 @@ func TestClient_AddRecord_error(t *testing.T) { Comment: "example", } - _, err := client.AddRecord("123", record) + _, err := client.AddRecord(context.Background(), "123", record) require.Error(t, err) } func TestClient_DeleteRecord(t *testing.T) { client := setupTest(t, "/user/v1/zones/123/records/6", testHandler(http.MethodDelete, http.StatusUnauthorized, "error.json")) - err := client.DeleteRecord("123", "6") + err := client.DeleteRecord(context.Background(), "123", "6") require.Error(t, err) } func TestClient_DeleteRecord_error(t *testing.T) { client := setupTest(t, "/user/v1/zones/123/records/6", testHandler(http.MethodDelete, http.StatusNoContent, "")) - err := client.DeleteRecord("123", "6") + err := client.DeleteRecord(context.Background(), "123", "6") require.NoError(t, err) } @@ -221,7 +222,7 @@ func setupTest(t *testing.T, path string, handler http.Handler) *Client { mux.Handle(path, handler) - client := NewClient(testAPIKey) + client := NewClient(OAuthStaticAccessToken(server.Client(), testAPIKey)) client.baseURL, _ = url.Parse(server.URL) return client diff --git a/providers/dns/hosttech/internal/types.go b/providers/dns/hosttech/internal/types.go index 53489e60..bf86964f 100644 --- a/providers/dns/hosttech/internal/types.go +++ b/providers/dns/hosttech/internal/types.go @@ -1,18 +1,17 @@ package internal import ( - "encoding/json" "fmt" ) -type apiResponse struct { - Data json.RawMessage `json:"data"` +type apiResponse[T any] struct { + Data T `json:"data"` } type APIError struct { - Message string `json:"message,omitempty"` - Errors map[string]interface{} `json:"errors,omitempty"` - StatusCode int `json:"-"` + Message string `json:"message,omitempty"` + Errors map[string]any `json:"errors,omitempty"` + StatusCode int `json:"-"` } func (a APIError) Error() string { diff --git a/providers/dns/httpreq/httpreq.go b/providers/dns/httpreq/httpreq.go index f7b97027..782f9a2a 100644 --- a/providers/dns/httpreq/httpreq.go +++ b/providers/dns/httpreq/httpreq.go @@ -3,16 +3,17 @@ package httpreq import ( "bytes" + "context" "encoding/json" "errors" "fmt" - "io" "net/http" "net/url" "time" "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/platform/config/env" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) // Environment variables names. @@ -108,6 +109,8 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { // Present creates a TXT record to fulfill the dns-01 challenge. func (d *DNSProvider) Present(domain, token, keyAuth string) error { + ctx := context.Background() + if d.config.Mode == "RAW" { msg := &messageRaw{ Domain: domain, @@ -115,7 +118,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { KeyAuth: keyAuth, } - err := d.doPost("/present", msg) + err := d.doPost(ctx, "/present", msg) if err != nil { return fmt.Errorf("httpreq: %w", err) } @@ -128,7 +131,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { Value: info.Value, } - err := d.doPost("/present", msg) + err := d.doPost(ctx, "/present", msg) if err != nil { return fmt.Errorf("httpreq: %w", err) } @@ -137,6 +140,8 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { // CleanUp removes the TXT record matching the specified parameters. func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { + ctx := context.Background() + if d.config.Mode == "RAW" { msg := &messageRaw{ Domain: domain, @@ -144,7 +149,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { KeyAuth: keyAuth, } - err := d.doPost("/cleanup", msg) + err := d.doPost(ctx, "/cleanup", msg) if err != nil { return fmt.Errorf("httpreq: %w", err) } @@ -157,46 +162,43 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { Value: info.Value, } - err := d.doPost("/cleanup", msg) + err := d.doPost(ctx, "/cleanup", msg) if err != nil { return fmt.Errorf("httpreq: %w", err) } return nil } -func (d *DNSProvider) doPost(uri string, msg interface{}) error { - reqBody := &bytes.Buffer{} +func (d *DNSProvider) doPost(ctx context.Context, uri string, msg any) error { + reqBody := new(bytes.Buffer) err := json.NewEncoder(reqBody).Encode(msg) if err != nil { - return err + return fmt.Errorf("failed to create request JSON body: %w", err) } endpoint := d.config.Endpoint.JoinPath(uri) - req, err := http.NewRequest(http.MethodPost, endpoint.String(), reqBody) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint.String(), reqBody) if err != nil { - return err + return fmt.Errorf("unable to create request: %w", err) } + req.Header.Set("Accept", "application/json") req.Header.Set("Content-Type", "application/json") - if len(d.config.Username) > 0 && len(d.config.Password) > 0 { + if d.config.Username != "" && d.config.Password != "" { req.SetBasicAuth(d.config.Username, d.config.Password) } resp, err := d.config.HTTPClient.Do(req) if err != nil { - return err + return errutils.NewHTTPDoError(req, err) } - defer resp.Body.Close() - if resp.StatusCode >= http.StatusBadRequest { - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("%d: failed to read response body: %w", resp.StatusCode, err) - } + defer func() { _ = resp.Body.Close() }() - return fmt.Errorf("%d: request failed: %v", resp.StatusCode, string(body)) + if resp.StatusCode/100 != 2 { + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) } return nil diff --git a/providers/dns/httpreq/httpreq_test.go b/providers/dns/httpreq/httpreq_test.go index 99c56371..a545bd17 100644 --- a/providers/dns/httpreq/httpreq_test.go +++ b/providers/dns/httpreq/httpreq_test.go @@ -121,7 +121,7 @@ func TestNewDNSProvider_Present(t *testing.T) { { desc: "error", handler: http.NotFound, - expectedError: "httpreq: 404: request failed: 404 page not found\n", + expectedError: "httpreq: unexpected status code: [status code: 404] body: 404 page not found", }, { desc: "success raw mode", @@ -132,7 +132,7 @@ func TestNewDNSProvider_Present(t *testing.T) { desc: "error raw mode", mode: "RAW", handler: http.NotFound, - expectedError: "httpreq: 404: request failed: 404 page not found\n", + expectedError: "httpreq: unexpected status code: [status code: 404] body: 404 page not found", }, { desc: "basic auth", @@ -157,11 +157,11 @@ func TestNewDNSProvider_Present(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.HandleFunc(path.Join("/", test.pathPrefix, "present"), test.handler) - server := httptest.NewServer(mux) t.Cleanup(server.Close) + mux.HandleFunc(path.Join("/", test.pathPrefix, "present"), test.handler) + config := NewDefaultConfig() config.Endpoint = mustParse(server.URL + test.pathPrefix) config.Mode = test.mode @@ -199,7 +199,7 @@ func TestNewDNSProvider_Cleanup(t *testing.T) { { desc: "error", handler: http.NotFound, - expectedError: "httpreq: 404: request failed: 404 page not found\n", + expectedError: "httpreq: unexpected status code: [status code: 404] body: 404 page not found", }, { desc: "success raw mode", @@ -210,7 +210,7 @@ func TestNewDNSProvider_Cleanup(t *testing.T) { desc: "error raw mode", mode: "RAW", handler: http.NotFound, - expectedError: "httpreq: 404: request failed: 404 page not found\n", + expectedError: "httpreq: unexpected status code: [status code: 404] body: 404 page not found", }, { desc: "basic auth", @@ -234,11 +234,11 @@ func TestNewDNSProvider_Cleanup(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.HandleFunc("/cleanup", test.handler) - server := httptest.NewServer(mux) t.Cleanup(server.Close) + mux.HandleFunc("/cleanup", test.handler) + config := NewDefaultConfig() config.Endpoint = mustParse(server.URL) config.Mode = test.mode diff --git a/providers/dns/hurricane/internal/client.go b/providers/dns/hurricane/internal/client.go index e5848112..bbc90758 100644 --- a/providers/dns/hurricane/internal/client.go +++ b/providers/dns/hurricane/internal/client.go @@ -12,6 +12,7 @@ import ( "sync" "time" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" "golang.org/x/time/rate" ) @@ -59,7 +60,7 @@ func (c *Client) UpdateTxtRecord(ctx context.Context, hostname string, txt strin c.credMu.Unlock() if !ok { - return fmt.Errorf("hurricane: Domain %s not found in credentials, check your credentials map", domain) + return fmt.Errorf("domain %s not found in credentials, check your credentials map", domain) } data := url.Values{} @@ -67,32 +68,37 @@ func (c *Client) UpdateTxtRecord(ctx context.Context, hostname string, txt strin data.Set("hostname", hostname) data.Set("txt", txt) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL, strings.NewReader(data.Encode())) + if err != nil { + return fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rl, _ := c.rateLimiters.LoadOrStore(hostname, rate.NewLimiter(limit(defaultBurst), defaultBurst)) - err := rl.(*rate.Limiter).Wait(ctx) + err = rl.(*rate.Limiter).Wait(ctx) if err != nil { return err } - resp, err := c.HTTPClient.PostForm(c.baseURL, data) + resp, err := c.HTTPClient.Do(req) if err != nil { - return err + return errutils.NewHTTPDoError(req, err) } defer func() { _ = resp.Body.Close() }() - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - - body := string(bytes.TrimSpace(bodyBytes)) - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("%d: attempt to change TXT record %s returned %s", resp.StatusCode, hostname, body) + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) } - return evaluateBody(body, hostname) + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + return evaluateBody(string(bytes.TrimSpace(raw)), hostname) } func evaluateBody(body string, hostname string) error { diff --git a/providers/dns/hurricane/internal/client_test.go b/providers/dns/hurricane/internal/client_test.go index f68d9b73..16d4f60f 100644 --- a/providers/dns/hurricane/internal/client_test.go +++ b/providers/dns/hurricane/internal/client_test.go @@ -74,6 +74,7 @@ func TestClient_UpdateTxtRecord(t *testing.T) { client := NewClient(map[string]string{"example.com": "secret"}) client.baseURL = server.URL + client.HTTPClient = server.Client() err := client.UpdateTxtRecord(context.Background(), "_acme-challenge.example.com", "foo") test.expected(t, err) diff --git a/providers/dns/hyperone/hyperone.go b/providers/dns/hyperone/hyperone.go index c5fcdc0c..5e23c0a3 100644 --- a/providers/dns/hyperone/hyperone.go +++ b/providers/dns/hyperone/hyperone.go @@ -2,6 +2,7 @@ package hyperone import ( + "context" "fmt" "net/http" "os" @@ -105,18 +106,20 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - zone, err := d.getHostedZone(info.EffectiveFQDN) + ctx := context.Background() + + zone, err := d.getHostedZone(ctx, info.EffectiveFQDN) if err != nil { return fmt.Errorf("hyperone: failed to get zone for fqdn=%s: %w", info.EffectiveFQDN, err) } - recordset, err := d.client.FindRecordset(zone.ID, "TXT", info.EffectiveFQDN) + recordset, err := d.client.FindRecordset(ctx, zone.ID, "TXT", info.EffectiveFQDN) if err != nil { return fmt.Errorf("hyperone: fqdn=%s, zone ID=%s: %w", info.EffectiveFQDN, zone.ID, err) } if recordset == nil { - _, err = d.client.CreateRecordset(zone.ID, "TXT", info.EffectiveFQDN, info.Value, d.config.TTL) + _, err = d.client.CreateRecordset(ctx, zone.ID, "TXT", info.EffectiveFQDN, info.Value, d.config.TTL) if err != nil { return fmt.Errorf("hyperone: failed to create recordset: fqdn=%s, zone ID=%s, value=%s: %w", info.EffectiveFQDN, zone.ID, info.Value, err) } @@ -124,7 +127,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { return nil } - _, err = d.client.CreateRecord(zone.ID, recordset.ID, info.Value) + _, err = d.client.CreateRecord(ctx, zone.ID, recordset.ID, info.Value) if err != nil { return fmt.Errorf("hyperone: failed to create record: fqdn=%s, zone ID=%s, recordset ID=%s: %w", info.EffectiveFQDN, zone.ID, recordset.ID, err) } @@ -137,12 +140,14 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, _, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - zone, err := d.getHostedZone(info.EffectiveFQDN) + ctx := context.Background() + + zone, err := d.getHostedZone(ctx, info.EffectiveFQDN) if err != nil { return fmt.Errorf("hyperone: failed to get zone for fqdn=%s: %w", info.EffectiveFQDN, err) } - recordset, err := d.client.FindRecordset(zone.ID, "TXT", info.EffectiveFQDN) + recordset, err := d.client.FindRecordset(ctx, zone.ID, "TXT", info.EffectiveFQDN) if err != nil { return fmt.Errorf("hyperone: fqdn=%s, zone ID=%s: %w", info.EffectiveFQDN, zone.ID, err) } @@ -151,7 +156,7 @@ func (d *DNSProvider) CleanUp(domain, _, keyAuth string) error { return fmt.Errorf("hyperone: recordset to remove not found: fqdn=%s", info.EffectiveFQDN) } - records, err := d.client.GetRecords(zone.ID, recordset.ID) + records, err := d.client.GetRecords(ctx, zone.ID, recordset.ID) if err != nil { return fmt.Errorf("hyperone: %w", err) } @@ -160,7 +165,7 @@ func (d *DNSProvider) CleanUp(domain, _, keyAuth string) error { return fmt.Errorf("hyperone: record with content %s not found: fqdn=%s", info.Value, info.EffectiveFQDN) } - err = d.client.DeleteRecordset(zone.ID, recordset.ID) + err = d.client.DeleteRecordset(ctx, zone.ID, recordset.ID) if err != nil { return fmt.Errorf("hyperone: failed to delete record: fqdn=%s, zone ID=%s, recordset ID=%s: %w", info.EffectiveFQDN, zone.ID, recordset.ID, err) } @@ -170,7 +175,7 @@ func (d *DNSProvider) CleanUp(domain, _, keyAuth string) error { for _, record := range records { if record.Content == info.Value { - err = d.client.DeleteRecord(zone.ID, recordset.ID, record.ID) + err = d.client.DeleteRecord(ctx, zone.ID, recordset.ID, record.ID) if err != nil { return fmt.Errorf("hyperone: fqdn=%s, zone ID=%s, recordset ID=%s, record ID=%s: %w", info.EffectiveFQDN, zone.ID, recordset.ID, record.ID, err) } @@ -183,13 +188,13 @@ func (d *DNSProvider) CleanUp(domain, _, keyAuth string) error { } // getHostedZone gets the hosted zone. -func (d *DNSProvider) getHostedZone(fqdn string) (*internal.Zone, error) { +func (d *DNSProvider) getHostedZone(ctx context.Context, fqdn string) (*internal.Zone, error) { authZone, err := dns01.FindZoneByFqdn(fqdn) if err != nil { - return nil, err + return nil, fmt.Errorf("hetzner: could not find zone for FQDN %q: %w", fqdn, err) } - return d.client.FindZone(authZone) + return d.client.FindZone(ctx, authZone) } func GetDefaultPassportLocation() (string, error) { diff --git a/providers/dns/hyperone/internal/client.go b/providers/dns/hyperone/internal/client.go index 1231d9bc..09fa6876 100644 --- a/providers/dns/hyperone/internal/client.go +++ b/providers/dns/hyperone/internal/client.go @@ -2,6 +2,7 @@ package internal import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -9,6 +10,8 @@ import ( "net/http" "net/url" "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) const defaultBaseURL = "https://api.hyperone.com/v2" @@ -21,12 +24,11 @@ type signer interface { // Client the HyperOne client. type Client struct { - HTTPClient *http.Client - - apiEndpoint *url.URL - passport *Passport signer signer + + baseURL *url.URL + HTTPClient *http.Client } // NewClient Creates a new HyperOne client. @@ -62,10 +64,10 @@ func NewClient(apiEndpoint, locationID string, passport *Passport) (*Client, err } client := &Client{ - HTTPClient: &http.Client{Timeout: 5 * time.Second}, - apiEndpoint: baseURL.JoinPath("dns", locationID, "project", projectID), - passport: passport, - signer: tokenSigner, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + baseURL: baseURL.JoinPath("dns", locationID, "project", projectID), + passport: passport, + signer: tokenSigner, } return client, nil @@ -74,11 +76,11 @@ func NewClient(apiEndpoint, locationID string, passport *Passport) (*Client, err // FindRecordset looks for recordset with given recordType and name and returns it. // In case if recordset is not found returns nil. // https://api.hyperone.com/v2/docs#operation/dns_project_zone_recordset_list -func (c *Client) FindRecordset(zoneID, recordType, name string) (*Recordset, error) { +func (c *Client) FindRecordset(ctx context.Context, zoneID, recordType, name string) (*Recordset, error) { // https://api.hyperone.com/v2/dns/{locationId}/project/{projectId}/zone/{zoneId}/recordset - endpoint := c.apiEndpoint.JoinPath("zone", zoneID, "recordset") + endpoint := c.baseURL.JoinPath("zone", zoneID, "recordset") - req, err := c.createRequest(http.MethodGet, endpoint.String(), nil) + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, err } @@ -102,7 +104,10 @@ func (c *Client) FindRecordset(zoneID, recordType, name string) (*Recordset, err // CreateRecordset creates recordset and record with given value within one request. // https://api.hyperone.com/v2/docs#operation/dns_project_zone_recordset_create -func (c *Client) CreateRecordset(zoneID, recordType, name, recordValue string, ttl int) (*Recordset, error) { +func (c *Client) CreateRecordset(ctx context.Context, zoneID, recordType, name, recordValue string, ttl int) (*Recordset, error) { + // https://api.hyperone.com/v2/dns/{locationId}/project/{projectId}/zone/{zoneId}/recordset + endpoint := c.baseURL.JoinPath("zone", zoneID, "recordset") + recordsetInput := Recordset{ RecordType: recordType, Name: name, @@ -110,15 +115,7 @@ func (c *Client) CreateRecordset(zoneID, recordType, name, recordValue string, t Record: &Record{Content: recordValue}, } - requestBody, err := json.Marshal(recordsetInput) - if err != nil { - return nil, fmt.Errorf("failed to marshal recordset: %w", err) - } - - // https://api.hyperone.com/v2/dns/{locationId}/project/{projectId}/zone/{zoneId}/recordset - endpoint := c.apiEndpoint.JoinPath("zone", zoneID, "recordset") - - req, err := c.createRequest(http.MethodPost, endpoint.String(), bytes.NewBuffer(requestBody)) + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, recordsetInput) if err != nil { return nil, err } @@ -135,11 +132,11 @@ func (c *Client) CreateRecordset(zoneID, recordType, name, recordValue string, t // DeleteRecordset deletes a recordset. // https://api.hyperone.com/v2/docs#operation/dns_project_zone_recordset_delete -func (c *Client) DeleteRecordset(zoneID string, recordsetID string) error { +func (c *Client) DeleteRecordset(ctx context.Context, zoneID string, recordsetID string) error { // https://api.hyperone.com/v2/dns/{locationId}/project/{projectId}/zone/{zoneId}/recordset/{recordsetId} - endpoint := c.apiEndpoint.JoinPath("zone", zoneID, "recordset", recordsetID) + endpoint := c.baseURL.JoinPath("zone", zoneID, "recordset", recordsetID) - req, err := c.createRequest(http.MethodDelete, endpoint.String(), nil) + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) if err != nil { return err } @@ -149,11 +146,11 @@ func (c *Client) DeleteRecordset(zoneID string, recordsetID string) error { // GetRecords gets all records within specified recordset. // https://api.hyperone.com/v2/docs#operation/dns_project_zone_recordset_record_list -func (c *Client) GetRecords(zoneID string, recordsetID string) ([]Record, error) { +func (c *Client) GetRecords(ctx context.Context, zoneID string, recordsetID string) ([]Record, error) { // https://api.hyperone.com/v2/dns/{locationId}/project/{projectId}/zone/{zoneId}/recordset/{recordsetId}/record - endpoint := c.apiEndpoint.JoinPath("zone", zoneID, "recordset", recordsetID, "record") + endpoint := c.baseURL.JoinPath("zone", zoneID, "recordset", recordsetID, "record") - req, err := c.createRequest(http.MethodGet, endpoint.String(), nil) + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, err } @@ -170,16 +167,11 @@ func (c *Client) GetRecords(zoneID string, recordsetID string) ([]Record, error) // CreateRecord creates a record. // https://api.hyperone.com/v2/docs#operation/dns_project_zone_recordset_record_create -func (c *Client) CreateRecord(zoneID, recordsetID, recordContent string) (*Record, error) { +func (c *Client) CreateRecord(ctx context.Context, zoneID, recordsetID, recordContent string) (*Record, error) { // https://api.hyperone.com/v2/dns/{locationId}/project/{projectId}/zone/{zoneId}/recordset/{recordsetId}/record - endpoint := c.apiEndpoint.JoinPath("zone", zoneID, "recordset", recordsetID, "record") + endpoint := c.baseURL.JoinPath("zone", zoneID, "recordset", recordsetID, "record") - requestBody, err := json.Marshal(Record{Content: recordContent}) - if err != nil { - return nil, fmt.Errorf("failed to marshal record: %w", err) - } - - req, err := c.createRequest(http.MethodPost, endpoint.String(), bytes.NewBuffer(requestBody)) + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, Record{Content: recordContent}) if err != nil { return nil, err } @@ -196,11 +188,11 @@ func (c *Client) CreateRecord(zoneID, recordsetID, recordContent string) (*Recor // DeleteRecord deletes a record. // https://api.hyperone.com/v2/docs#operation/dns_project_zone_recordset_record_delete -func (c *Client) DeleteRecord(zoneID, recordsetID, recordID string) error { +func (c *Client) DeleteRecord(ctx context.Context, zoneID, recordsetID, recordID string) error { // https://api.hyperone.com/v2/dns/{locationId}/project/{projectId}/zone/{zoneId}/recordset/{recordsetId}/record/{recordId} - endpoint := c.apiEndpoint.JoinPath("zone", zoneID, "recordset", recordsetID, "record", recordID) + endpoint := c.baseURL.JoinPath("zone", zoneID, "recordset", recordsetID, "record", recordID) - req, err := c.createRequest(http.MethodDelete, endpoint.String(), nil) + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) if err != nil { return err } @@ -209,8 +201,8 @@ func (c *Client) DeleteRecord(zoneID, recordsetID, recordID string) error { } // FindZone looks for DNS Zone and returns nil if it does not exist. -func (c *Client) FindZone(name string) (*Zone, error) { - zones, err := c.GetZones() +func (c *Client) FindZone(ctx context.Context, name string) (*Zone, error) { + zones, err := c.GetZones(ctx) if err != nil { return nil, err } @@ -226,11 +218,11 @@ func (c *Client) FindZone(name string) (*Zone, error) { // GetZones gets all user's zones. // https://api.hyperone.com/v2/docs#operation/dns_project_zone_list -func (c *Client) GetZones() ([]Zone, error) { +func (c *Client) GetZones(ctx context.Context) ([]Zone, error) { // https://api.hyperone.com/v2/dns/{locationId}/project/{projectId}/zone - endpoint := c.apiEndpoint.JoinPath("zone") + endpoint := c.baseURL.JoinPath("zone") - req, err := c.createRequest(http.MethodGet, endpoint.String(), nil) + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, err } @@ -245,69 +237,72 @@ func (c *Client) GetZones() ([]Zone, error) { return zones, nil } -func (c *Client) createRequest(method, endpoint string, body io.Reader) (*http.Request, error) { - req, err := http.NewRequest(method, endpoint, body) - if err != nil { - return nil, err - } - +func (c *Client) do(req *http.Request, result any) error { jwt, err := c.signer.GetJWT() if err != nil { - return nil, fmt.Errorf("failed to sign the request: %w", err) + return fmt.Errorf("failed to sign the request: %w", err) } req.Header.Set("Authorization", "Bearer "+jwt) - req.Header.Set("Content-Type", "application/json") - return req, nil -} - -func (c *Client) do(req *http.Request, v interface{}) error { resp, err := c.HTTPClient.Do(req) if err != nil { - return err + return errutils.NewHTTPDoError(req, err) } defer func() { _ = resp.Body.Close() }() - err = checkResponse(resp) - if err != nil { - return err + if resp.StatusCode/100 != 2 { + return parseError(req, resp) } - if v == nil { + if result == nil { return nil } raw, err := io.ReadAll(resp.Body) if err != nil { - return fmt.Errorf("failed to read body: %w", err) + return errutils.NewReadResponseError(req, resp.StatusCode, err) } - if err = json.Unmarshal(raw, v); err != nil { - return fmt.Errorf("unmarshaling %T error: %w: %s", v, err, string(raw)) + if err = json.Unmarshal(raw, result); err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) } return nil } -func checkResponse(resp *http.Response) error { - if resp.StatusCode/100 == 2 { - return nil +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } } + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil +} + +func parseError(req *http.Request, resp *http.Response) error { var msg string if resp.StatusCode == http.StatusForbidden { msg = "forbidden: check if service account you are trying to use has permissions required for managing DNS" } else { - msg = fmt.Sprintf("%d: unknown error", resp.StatusCode) + msg = "unknown error" } - // add response body to error message if not empty - responseBody, _ := io.ReadAll(resp.Body) - if len(responseBody) > 0 { - msg = fmt.Sprintf("%s: %s", msg, string(responseBody)) - } - - return errors.New(msg) + return fmt.Errorf("%s: %w", msg, errutils.NewUnexpectedResponseStatusCodeError(req, resp)) } diff --git a/providers/dns/hyperone/internal/client_test.go b/providers/dns/hyperone/internal/client_test.go index 2f503094..e3a1073e 100644 --- a/providers/dns/hyperone/internal/client_test.go +++ b/providers/dns/hyperone/internal/client_test.go @@ -2,6 +2,7 @@ package internal import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -24,7 +25,7 @@ func (s signerMock) GetJWT() (string, error) { func TestClient_FindRecordset(t *testing.T) { client := setupTest(t, http.MethodGet, "/dns/loc123/project/proj123/zone/zone321/recordset", respFromFile("recordset.json")) - recordset, err := client.FindRecordset("zone321", "SOA", "example.com.") + recordset, err := client.FindRecordset(context.Background(), "zone321", "SOA", "example.com.") require.NoError(t, err) expected := &Recordset{ @@ -48,7 +49,7 @@ func TestClient_CreateRecordset(t *testing.T) { client := setupTest(t, http.MethodPost, "/dns/loc123/project/proj123/zone/zone123/recordset", hasReqBody(expectedReqBody), respFromFile("createRecordset.json")) - rs, err := client.CreateRecordset("zone123", "TXT", "test.example.com.", "value", 3600) + rs, err := client.CreateRecordset(context.Background(), "zone123", "TXT", "test.example.com.", "value", 3600) require.NoError(t, err) expected := &Recordset{RecordType: "TXT", Name: "test.example.com.", TTL: 3600, ID: "1234567890qwertyuiop"} @@ -58,14 +59,14 @@ func TestClient_CreateRecordset(t *testing.T) { func TestClient_DeleteRecordset(t *testing.T) { client := setupTest(t, http.MethodDelete, "/dns/loc123/project/proj123/zone/zone321/recordset/rs322") - err := client.DeleteRecordset("zone321", "rs322") + err := client.DeleteRecordset(context.Background(), "zone321", "rs322") require.NoError(t, err) } func TestClient_GetRecords(t *testing.T) { client := setupTest(t, http.MethodGet, "/dns/loc123/project/proj123/zone/321/recordset/322/record", respFromFile("record.json")) - records, err := client.GetRecords("321", "322") + records, err := client.GetRecords(context.Background(), "321", "322") require.NoError(t, err) expected := []Record{ @@ -87,7 +88,7 @@ func TestClient_CreateRecord(t *testing.T) { client := setupTest(t, http.MethodPost, "/dns/loc123/project/proj123/zone/z123/recordset/rs325/record", hasReqBody(expectedReqBody), respFromFile("createRecord.json")) - rs, err := client.CreateRecord("z123", "rs325", "value") + rs, err := client.CreateRecord(context.Background(), "z123", "rs325", "value") require.NoError(t, err) expected := &Record{ID: "123321qwerqwewqerq", Content: "value", Enabled: true} @@ -97,14 +98,14 @@ func TestClient_CreateRecord(t *testing.T) { func TestClient_DeleteRecord(t *testing.T) { client := setupTest(t, http.MethodDelete, "/dns/loc123/project/proj123/zone/321/recordset/322/record/323") - err := client.DeleteRecord("321", "322", "323") + err := client.DeleteRecord(context.Background(), "321", "322", "323") require.NoError(t, err) } func TestClient_FindZone(t *testing.T) { client := setupTest(t, http.MethodGet, "/dns/loc123/project/proj123/zone", respFromFile("zones.json")) - zone, err := client.FindZone("example.com") + zone, err := client.FindZone(context.Background(), "example.com") require.NoError(t, err) expected := &Zone{ @@ -121,7 +122,7 @@ func TestClient_FindZone(t *testing.T) { func TestClient_GetZones(t *testing.T) { client := setupTest(t, http.MethodGet, "/dns/loc123/project/proj123/zone", respFromFile("zones.json")) - zones, err := client.GetZones() + zones, err := client.GetZones(context.Background()) require.NoError(t, err) expected := []Zone{ @@ -194,7 +195,7 @@ func hasReqBody(v interface{}) assertHandler { return http.StatusInternalServerError, err } - if !bytes.Equal(marshal, reqBody) { + if !bytes.Equal(marshal, bytes.TrimSpace(reqBody)) { return http.StatusBadRequest, fmt.Errorf("invalid request body, got: %s, expect: %s", string(reqBody), string(marshal)) } diff --git a/providers/dns/hyperone/internal/models.go b/providers/dns/hyperone/internal/types.go similarity index 100% rename from providers/dns/hyperone/internal/models.go rename to providers/dns/hyperone/internal/types.go diff --git a/providers/dns/iij/iij.go b/providers/dns/iij/iij.go index 48838172..ed5b8770 100644 --- a/providers/dns/iij/iij.go +++ b/providers/dns/iij/iij.go @@ -231,7 +231,7 @@ func splitDomain(domain string, zones []string) (string, string, error) { zone = strings.Join(parts[i:], ".") if zoneContains(zone, zones) { baseOwner := strings.Join(parts[0:i], ".") - if len(baseOwner) > 0 { + if baseOwner != "" { baseOwner = "." + baseOwner } owner = "_acme-challenge" + baseOwner diff --git a/providers/dns/iijdpf/client.go b/providers/dns/iijdpf/wrapper.go similarity index 100% rename from providers/dns/iijdpf/client.go rename to providers/dns/iijdpf/wrapper.go diff --git a/providers/dns/infomaniak/infomaniak.go b/providers/dns/infomaniak/infomaniak.go index eafba97c..53d98c4f 100644 --- a/providers/dns/infomaniak/infomaniak.go +++ b/providers/dns/infomaniak/infomaniak.go @@ -2,6 +2,7 @@ package infomaniak import ( + "context" "errors" "fmt" "net/http" @@ -29,8 +30,6 @@ const ( EnvHTTPTimeout = envNamespace + "HTTP_TIMEOUT" ) -const defaultBaseURL = "https://api.infomaniak.com" - // Config is used to configure the creation of the DNSProvider. type Config struct { APIEndpoint string @@ -44,7 +43,7 @@ type Config struct { // NewDefaultConfig returns a default configuration for the DNSProvider. func NewDefaultConfig() *Config { return &Config{ - APIEndpoint: env.GetOrDefaultString(EnvEndpoint, defaultBaseURL), + APIEndpoint: env.GetOrDefaultString(EnvEndpoint, internal.DefaultBaseURL), TTL: env.GetOrDefaultInt(EnvTTL, 7200), PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, dns01.DefaultPropagationTimeout), PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, dns01.DefaultPollingInterval), @@ -94,10 +93,9 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("infomaniak: missing access token") } - client := internal.New(config.APIEndpoint, config.AccessToken) - - if config.HTTPClient != nil { - client.HTTPClient = config.HTTPClient + client, err := internal.New(internal.OAuthStaticAccessToken(config.HTTPClient, config.AccessToken), config.APIEndpoint) + if err != nil { + return nil, fmt.Errorf("infomaniak: %w", err) } return &DNSProvider{ @@ -112,7 +110,9 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - ikDomain, err := d.client.GetDomainByName(dns01.UnFqdn(info.EffectiveFQDN)) + ctx := context.Background() + + ikDomain, err := d.client.GetDomainByName(ctx, dns01.UnFqdn(info.EffectiveFQDN)) if err != nil { return fmt.Errorf("infomaniak: could not get domain %q: %w", info.EffectiveFQDN, err) } @@ -133,7 +133,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { TTL: d.config.TTL, } - recordID, err := d.client.CreateDNSRecord(ikDomain, record) + recordID, err := d.client.CreateDNSRecord(ctx, ikDomain, record) if err != nil { return fmt.Errorf("infomaniak: error when calling api to create DNS record: %w", err) } @@ -165,7 +165,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return fmt.Errorf("infomaniak: unknown domain ID for '%s'", info.EffectiveFQDN) } - err := d.client.DeleteDNSRecord(domainID, recordID) + err := d.client.DeleteDNSRecord(context.Background(), domainID, recordID) if err != nil { return fmt.Errorf("infomaniak: could not delete record %q: %w", dns01.UnFqdn(info.EffectiveFQDN), err) } diff --git a/providers/dns/infomaniak/internal/client.go b/providers/dns/infomaniak/internal/client.go index cb27cfb6..886a8966 100644 --- a/providers/dns/infomaniak/internal/client.go +++ b/providers/dns/infomaniak/internal/client.go @@ -2,6 +2,7 @@ package internal import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -13,71 +14,63 @@ import ( "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/log" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" + "golang.org/x/oauth2" ) +// DefaultBaseURL Default API endpoint. +const DefaultBaseURL = "https://api.infomaniak.com" + // Client the Infomaniak client. type Client struct { - apiEndpoint string - apiToken string - HTTPClient *http.Client + baseURL *url.URL + httpClient *http.Client } // New Creates a new Infomaniak client. -func New(apiEndpoint, apiToken string) *Client { - return &Client{ - apiEndpoint: apiEndpoint, - apiToken: apiToken, - HTTPClient: &http.Client{Timeout: 5 * time.Second}, +func New(hc *http.Client, apiEndpoint string) (*Client, error) { + baseURL, err := url.Parse(apiEndpoint) + if err != nil { + return nil, err } + + if hc == nil { + hc = &http.Client{Timeout: 5 * time.Second} + } + + return &Client{baseURL: baseURL, httpClient: hc}, nil } -func (c *Client) CreateDNSRecord(domain *DNSDomain, record Record) (string, error) { - rawJSON, err := json.Marshal(record) - if err != nil { - return "", err - } +func (c *Client) CreateDNSRecord(ctx context.Context, domain *DNSDomain, record Record) (string, error) { + endpoint := c.baseURL.JoinPath("1", "domain", strconv.FormatUint(domain.ID, 10), "dns", "record") - endpoint, err := url.JoinPath(c.apiEndpoint, "1", "domain", strconv.FormatUint(domain.ID, 10), "dns", "record") - if err != nil { - return "", err - } - - req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewBuffer(rawJSON)) + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record) if err != nil { return "", fmt.Errorf("failed to create request: %w", err) } - resp, err := c.do(req) + result := APIResponse[string]{} + err = c.do(req, &result) if err != nil { return "", err } - var recordID string - if err = json.Unmarshal(resp.Data, &recordID); err != nil { - return "", fmt.Errorf("expected record, got: %s", string(resp.Data)) - } - - return recordID, err + return result.Data, err } -func (c *Client) DeleteDNSRecord(domainID uint64, recordID string) error { - endpoint, err := url.JoinPath(c.apiEndpoint, "1", "domain", strconv.FormatUint(domainID, 10), "dns", "record", recordID) - if err != nil { - return err - } +func (c *Client) DeleteDNSRecord(ctx context.Context, domainID uint64, recordID string) error { + endpoint := c.baseURL.JoinPath("1", "domain", strconv.FormatUint(domainID, 10), "dns", "record", recordID) - req, err := http.NewRequest(http.MethodDelete, endpoint, http.NoBody) + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) if err != nil { return fmt.Errorf("failed to create request: %w", err) } - _, err = c.do(req) - - return err + return c.do(req, &APIResponse[json.RawMessage]{}) } // GetDomainByName gets a Domain object from its name. -func (c *Client) GetDomainByName(name string) (*DNSDomain, error) { +func (c *Client) GetDomainByName(ctx context.Context, name string) (*DNSDomain, error) { name = dns01.UnFqdn(name) // Try to find the most specific domain @@ -88,7 +81,7 @@ func (c *Client) GetDomainByName(name string) (*DNSDomain, error) { break } - domain, err := c.getDomainByName(name) + domain, err := c.getDomainByName(ctx, name) if err != nil { return nil, err } @@ -105,35 +98,26 @@ func (c *Client) GetDomainByName(name string) (*DNSDomain, error) { return nil, fmt.Errorf("domain not found %s", name) } -func (c *Client) getDomainByName(name string) (*DNSDomain, error) { - baseURL, err := url.Parse(c.apiEndpoint) - if err != nil { - return nil, err - } - - endpoint := baseURL.JoinPath("1", "product") +func (c *Client) getDomainByName(ctx context.Context, name string) (*DNSDomain, error) { + endpoint := c.baseURL.JoinPath("1", "product") query := endpoint.Query() query.Add("service_name", "domain") query.Add("customer_name", name) endpoint.RawQuery = query.Encode() - req, err := http.NewRequest(http.MethodGet, endpoint.String(), http.NoBody) + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, err } - resp, err := c.do(req) + result := APIResponse[[]DNSDomain]{} + err = c.do(req, &result) if err != nil { return nil, err } - var domains []DNSDomain - if err = json.Unmarshal(resp.Data, &domains); err != nil { - return nil, fmt.Errorf("failed to marshal domains: %s", string(resp.Data)) - } - - for _, domain := range domains { + for _, domain := range result.Data { if domain.CustomerName == name { return &domain, nil } @@ -142,30 +126,63 @@ func (c *Client) getDomainByName(name string) (*DNSDomain, error) { return nil, nil } -func (c *Client) do(req *http.Request) (*APIResponse, error) { - req.Header.Set("Authorization", "Bearer "+c.apiToken) - req.Header.Set("Content-Type", "application/json") - - rawResp, err := c.HTTPClient.Do(req) +func (c *Client) do(req *http.Request, result Response) error { + resp, err := c.httpClient.Do(req) if err != nil { - return nil, fmt.Errorf("failed to perform API request: %w", err) + return errutils.NewHTTPDoError(req, err) } - defer func() { _ = rawResp.Body.Close() }() + defer func() { _ = resp.Body.Close() }() - content, err := io.ReadAll(rawResp.Body) + raw, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("failed to read the response body, status code: %d", rawResp.StatusCode) + return errutils.NewReadResponseError(req, resp.StatusCode, err) } - var resp APIResponse - if err := json.Unmarshal(content, &resp); err != nil { - return nil, fmt.Errorf("failed to unmarshal the response body: %s, %w", string(content), err) + if err := json.Unmarshal(raw, result); err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) } - if resp.Result != "success" { - return nil, fmt.Errorf("%d: unexpected API result (%s): %w", rawResp.StatusCode, resp.Result, resp.ErrResponse) + if result.GetResult() != "success" { + return fmt.Errorf("%d: unexpected API result (%s): %w", resp.StatusCode, result.GetResult(), result.GetError()) } - return &resp, nil + return nil +} + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil +} + +func OAuthStaticAccessToken(client *http.Client, accessToken string) *http.Client { + if client == nil { + client = &http.Client{Timeout: 5 * time.Second} + } + + client.Transport = &oauth2.Transport{ + Source: oauth2.StaticTokenSource(&oauth2.Token{AccessToken: accessToken}), + Base: client.Transport, + } + + return client } diff --git a/providers/dns/infomaniak/internal/client_test.go b/providers/dns/infomaniak/internal/client_test.go index 7dd8c9e2..4fadaf0f 100644 --- a/providers/dns/infomaniak/internal/client_test.go +++ b/providers/dns/infomaniak/internal/client_test.go @@ -1,6 +1,8 @@ package internal import ( + "bytes" + "context" "fmt" "io" "net/http" @@ -18,7 +20,10 @@ func setupTest(t *testing.T) (*Client, *http.ServeMux) { server := httptest.NewServer(mux) t.Cleanup(server.Close) - return New(server.URL, "token"), mux + client, err := New(OAuthStaticAccessToken(server.Client(), "token"), server.URL) + require.NoError(t, err) + + return client, mux } func TestClient_CreateDNSRecord(t *testing.T) { @@ -42,7 +47,7 @@ func TestClient_CreateDNSRecord(t *testing.T) { } defer func() { _ = req.Body.Close() }() - if string(raw) != `{"source":"foo","type":"TXT","ttl":60,"target":"txtxtxttxt"}` { + if string(bytes.TrimSpace(raw)) != `{"source":"foo","type":"TXT","ttl":60,"target":"txtxtxttxt"}` { http.Error(rw, fmt.Sprintf("invalid request body: %s", string(raw)), http.StatusBadRequest) return } @@ -68,7 +73,7 @@ func TestClient_CreateDNSRecord(t *testing.T) { TTL: 60, } - recordID, err := client.CreateDNSRecord(domain, record) + recordID, err := client.CreateDNSRecord(context.Background(), domain, record) require.NoError(t, err) assert.Equal(t, "123", recordID) @@ -95,7 +100,6 @@ func TestClient_GetDomainByName(t *testing.T) { } customerName := req.URL.Query().Get("customer_name") - fmt.Println("customerName", customerName) if customerName == "" { http.Error(rw, fmt.Sprintf("invalid customer_name: %s", customerName), http.StatusBadRequest) return @@ -124,7 +128,7 @@ func TestClient_GetDomainByName(t *testing.T) { } }) - domain, err := client.GetDomainByName("one.two.three.example.com.") + domain, err := client.GetDomainByName(context.Background(), "one.two.three.example.com.") require.NoError(t, err) expected := &DNSDomain{ID: 123, CustomerName: "two.three.example.com"} @@ -152,6 +156,6 @@ func TestClient_DeleteDNSRecord(t *testing.T) { } }) - err := client.DeleteDNSRecord(123, "456") + err := client.DeleteDNSRecord(context.Background(), 123, "456") require.NoError(t, err) } diff --git a/providers/dns/infomaniak/internal/models.go b/providers/dns/infomaniak/internal/types.go similarity index 74% rename from providers/dns/infomaniak/internal/models.go rename to providers/dns/infomaniak/internal/types.go index 7056354d..059bc9e9 100644 --- a/providers/dns/infomaniak/internal/models.go +++ b/providers/dns/infomaniak/internal/types.go @@ -1,7 +1,6 @@ package internal import ( - "encoding/json" "fmt" ) @@ -19,12 +18,25 @@ type DNSDomain struct { CustomerName string `json:"customer_name,omitempty"` } -type APIResponse struct { +type Response interface { + GetResult() string + GetError() *APIErrorResponse +} + +type APIResponse[T any] struct { Result string `json:"result"` - Data json.RawMessage `json:"data,omitempty"` + Data T `json:"data,omitempty"` ErrResponse *APIErrorResponse `json:"error,omitempty"` } +func (a APIResponse[T]) GetResult() string { + return a.Result +} + +func (a APIResponse[T]) GetError() *APIErrorResponse { + return a.ErrResponse +} + type APIErrorResponse struct { Code string `json:"code"` Description string `json:"description,omitempty"` diff --git a/providers/dns/internal/errutils/client.go b/providers/dns/internal/errutils/client.go new file mode 100644 index 00000000..09f1344b --- /dev/null +++ b/providers/dns/internal/errutils/client.go @@ -0,0 +1,133 @@ +package errutils + +import ( + "bytes" + "fmt" + "io" + "net/http" + "os" + "strconv" +) + +const legoDebugClientVerboseError = "LEGO_DEBUG_CLIENT_VERBOSE_ERROR" + +// HTTPDoError uses with `(http.Client).Do` error. +type HTTPDoError struct { + req *http.Request + err error +} + +// NewHTTPDoError creates a new HTTPDoError. +func NewHTTPDoError(req *http.Request, err error) *HTTPDoError { + return &HTTPDoError{req: req, err: err} +} + +func (h HTTPDoError) Error() string { + msg := "unable to communicate with the API server:" + + if ok, _ := strconv.ParseBool(os.Getenv(legoDebugClientVerboseError)); ok { + msg += fmt.Sprintf(" [request: %s %s]", h.req.Method, h.req.URL) + } + + if h.err == nil { + return msg + } + + return msg + fmt.Sprintf(" error: %v", h.err) +} + +func (h HTTPDoError) Unwrap() error { + return h.err +} + +// ReadResponseError use with `io.ReadAll` when reading response body. +type ReadResponseError struct { + req *http.Request + StatusCode int + err error +} + +// NewReadResponseError creates a new ReadResponseError. +func NewReadResponseError(req *http.Request, statusCode int, err error) *ReadResponseError { + return &ReadResponseError{req: req, StatusCode: statusCode, err: err} +} + +func (r ReadResponseError) Error() string { + msg := "unable to read response body:" + + if ok, _ := strconv.ParseBool(os.Getenv(legoDebugClientVerboseError)); ok { + msg += fmt.Sprintf(" [request: %s %s]", r.req.Method, r.req.URL) + } + + msg += fmt.Sprintf(" [status code: %d]", r.StatusCode) + + if r.err == nil { + return msg + } + + return msg + fmt.Sprintf(" error: %v", r.err) +} + +func (r ReadResponseError) Unwrap() error { + return r.err +} + +// UnmarshalError uses with `json.Unmarshal` or `xml.Unmarshal` when reading response body. +type UnmarshalError struct { + req *http.Request + StatusCode int + Body []byte + err error +} + +// NewUnmarshalError creates a new UnmarshalError. +func NewUnmarshalError(req *http.Request, statusCode int, body []byte, err error) *UnmarshalError { + return &UnmarshalError{req: req, StatusCode: statusCode, Body: bytes.TrimSpace(body), err: err} +} + +func (u UnmarshalError) Error() string { + msg := "unable to unmarshal response:" + + if ok, _ := strconv.ParseBool(os.Getenv(legoDebugClientVerboseError)); ok { + msg += fmt.Sprintf(" [request: %s %s]", u.req.Method, u.req.URL) + } + + msg += fmt.Sprintf(" [status code: %d] body: %s", u.StatusCode, string(u.Body)) + + if u.err == nil { + return msg + } + + return msg + fmt.Sprintf(" error: %v", u.err) +} + +func (u UnmarshalError) Unwrap() error { + return u.err +} + +// UnexpectedStatusCodeError use when the status of the response is unexpected but there is no API error type. +type UnexpectedStatusCodeError struct { + req *http.Request + StatusCode int + Body []byte +} + +// NewUnexpectedStatusCodeError creates a new UnexpectedStatusCodeError. +func NewUnexpectedStatusCodeError(req *http.Request, statusCode int, body []byte) *UnexpectedStatusCodeError { + return &UnexpectedStatusCodeError{req: req, StatusCode: statusCode, Body: bytes.TrimSpace(body)} +} + +func NewUnexpectedResponseStatusCodeError(req *http.Request, resp *http.Response) *UnexpectedStatusCodeError { + raw, _ := io.ReadAll(resp.Body) + return &UnexpectedStatusCodeError{req: req, StatusCode: resp.StatusCode, Body: bytes.TrimSpace(raw)} +} + +func (u UnexpectedStatusCodeError) Error() string { + msg := "unexpected status code:" + + if ok, _ := strconv.ParseBool(os.Getenv(legoDebugClientVerboseError)); ok { + msg += fmt.Sprintf(" [request: %s %s]", u.req.Method, u.req.URL) + } + + return msg + fmt.Sprintf(" [status code: %d] body: %s", u.StatusCode, string(u.Body)) +} diff --git a/providers/dns/internal/rimuhosting/client.go b/providers/dns/internal/rimuhosting/client.go index c4f67ed5..4976f378 100644 --- a/providers/dns/internal/rimuhosting/client.go +++ b/providers/dns/internal/rimuhosting/client.go @@ -1,13 +1,17 @@ package rimuhosting import ( + "context" "encoding/xml" "errors" + "fmt" "io" "net/http" "net/url" "regexp" + "time" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" querystring "github.com/google/go-querystring/query" ) @@ -35,9 +39,9 @@ type Client struct { // NewClient Creates a RimuHosting/Zonomi client. func NewClient(apiKey string) *Client { return &Client{ - HTTPClient: http.DefaultClient, - BaseURL: DefaultZonomiBaseURL, apiKey: apiKey, + BaseURL: DefaultZonomiBaseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, } } @@ -45,14 +49,14 @@ func NewClient(apiKey string) *Client { // ex: // - https://zonomi.com/app/dns/dyndns.jsp?action=QUERY&name=example.com&api_key=apikeyvaluehere // - https://zonomi.com/app/dns/dyndns.jsp?action=QUERY&name=**.example.com&api_key=apikeyvaluehere -func (c Client) FindTXTRecords(domain string) ([]Record, error) { +func (c Client) FindTXTRecords(ctx context.Context, domain string) ([]Record, error) { action := ActionParameter{ Action: QueryAction, Name: domain, Type: "TXT", } - resp, err := c.DoActions(action) + resp, err := c.DoActions(ctx, action) if err != nil { return nil, err } @@ -61,7 +65,7 @@ func (c Client) FindTXTRecords(domain string) ([]Record, error) { } // DoActions performs actions. -func (c Client) DoActions(actions ...ActionParameter) (*DNSAPIResult, error) { +func (c Client) DoActions(ctx context.Context, actions ...ActionParameter) (*DNSAPIResult, error) { if len(actions) == 0 { return nil, errors.New("no action") } @@ -74,7 +78,7 @@ func (c Client) DoActions(actions ...ActionParameter) (*DNSAPIResult, error) { APIKey: c.apiKey, } - err := c.do(action, resp) + err := c.do(ctx, action, resp) if err != nil { return nil, err } @@ -82,7 +86,7 @@ func (c Client) DoActions(actions ...ActionParameter) (*DNSAPIResult, error) { } multi := c.toMultiParameters(actions) - err := c.do(multi, resp) + err := c.do(ctx, multi, resp) if err != nil { return nil, err } @@ -105,7 +109,7 @@ func (c Client) toMultiParameters(params []ActionParameter) multiActionParameter return multi } -func (c Client) do(params, data interface{}) error { +func (c Client) do(ctx context.Context, params, result any) error { baseURL, err := url.Parse(c.BaseURL) if err != nil { return err @@ -117,47 +121,55 @@ func (c Client) do(params, data interface{}) error { } exp := regexp.MustCompile(`(%5B)(%5D)(\d+)=`) - baseURL.RawQuery = exp.ReplaceAllString(v.Encode(), "${1}${3}${2}=") - req, err := http.NewRequest(http.MethodGet, baseURL.String(), nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL.String(), http.NoBody) if err != nil { - return err + return fmt.Errorf("unable to create request: %w", err) } resp, err := c.HTTPClient.Do(req) if err != nil { - return err + return errutils.NewHTTPDoError(req, err) } defer func() { _ = resp.Body.Close() }() - all, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - if resp.StatusCode/100 != 2 { - r := APIError{} - err = xml.Unmarshal(all, &r) - if err != nil { - return err - } - return r + return parseError(req, resp) } - if data != nil { - err := xml.Unmarshal(all, data) - if err != nil { - return err - } + if result == nil { + return nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = xml.Unmarshal(raw, result) + if err != nil { + return fmt.Errorf("unmarshaling %T error: %w: %s", result, err, string(raw)) } return nil } -// AddRecord helper to create an action to add a TXT record. -func AddRecord(domain, content string, ttl int) ActionParameter { +func parseError(req *http.Request, resp *http.Response) error { + raw, _ := io.ReadAll(resp.Body) + + errAPI := APIError{} + err := xml.Unmarshal(raw, &errAPI) + if err != nil { + return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw) + } + + return errAPI +} + +// NewAddRecordAction helper to create an action to add a TXT record. +func NewAddRecordAction(domain, content string, ttl int) ActionParameter { return ActionParameter{ Action: SetAction, Name: domain, @@ -167,8 +179,8 @@ func AddRecord(domain, content string, ttl int) ActionParameter { } } -// DeleteRecord helper to create an action to delete a TXT record. -func DeleteRecord(domain, content string) ActionParameter { +// NewDeleteRecordAction helper to create an action to delete a TXT record. +func NewDeleteRecordAction(domain, content string) ActionParameter { return ActionParameter{ Action: DeleteAction, Name: domain, diff --git a/providers/dns/internal/rimuhosting/client_test.go b/providers/dns/internal/rimuhosting/client_test.go index 76ba18d2..ecd55b0b 100644 --- a/providers/dns/internal/rimuhosting/client_test.go +++ b/providers/dns/internal/rimuhosting/client_test.go @@ -1,6 +1,7 @@ package rimuhosting import ( + "context" "encoding/xml" "fmt" "io" @@ -14,11 +15,23 @@ import ( "github.com/stretchr/testify/require" ) -func TestClient_FindTXTRecords(t *testing.T) { +func setupTest(t *testing.T) (*Client, *http.ServeMux) { + t.Helper() + mux := http.NewServeMux() server := httptest.NewServer(mux) t.Cleanup(server.Close) + client := NewClient("apikeyvaluehere") + client.BaseURL = server.URL + client.HTTPClient = server.Client() + + return client, mux +} + +func TestClient_FindTXTRecords(t *testing.T) { + client, mux := setupTest(t) + mux.HandleFunc("/", func(rw http.ResponseWriter, req *http.Request) { query := req.URL.Query() @@ -39,9 +52,6 @@ func TestClient_FindTXTRecords(t *testing.T) { } }) - client := NewClient("apikeyvaluehere") - client.BaseURL = server.URL - testCases := []struct { desc string domain string @@ -89,7 +99,7 @@ func TestClient_FindTXTRecords(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - records, err := client.FindTXTRecords(test.domain) + records, err := client.FindTXTRecords(context.Background(), test.domain) require.NoError(t, err) assert.Equal(t, test.expected, records) @@ -113,7 +123,7 @@ func TestClient_DoActions(t *testing.T) { { desc: "SET error", actions: []ActionParameter{ - AddRecord("example.com", "txttxtx", 0), + NewAddRecordAction("example.com", "txttxtx", 0), }, fixture: "./fixtures/add_record_error.xml", expected: expected{ @@ -124,7 +134,7 @@ func TestClient_DoActions(t *testing.T) { { desc: "SET simple", actions: []ActionParameter{ - AddRecord("example.org", "txttxtx", 0), + NewAddRecordAction("example.org", "txttxtx", 0), }, fixture: "./fixtures/add_record.xml", expected: expected{ @@ -153,8 +163,8 @@ func TestClient_DoActions(t *testing.T) { { desc: "SET multiple values", actions: []ActionParameter{ - AddRecord("example.org", "txttxtx", 0), - AddRecord("example.org", "sample", 0), + NewAddRecordAction("example.org", "txttxtx", 0), + NewAddRecordAction("example.org", "sample", 0), }, fixture: "./fixtures/add_record_same_domain.xml", expected: expected{ @@ -192,7 +202,7 @@ func TestClient_DoActions(t *testing.T) { { desc: "DELETE error", actions: []ActionParameter{ - DeleteRecord("example.com", "txttxtx"), + NewDeleteRecordAction("example.com", "txttxtx"), }, fixture: "./fixtures/delete_record_error.xml", expected: expected{ @@ -203,7 +213,7 @@ func TestClient_DoActions(t *testing.T) { { desc: "DELETE nothing", actions: []ActionParameter{ - DeleteRecord("example.org", "nothing"), + NewDeleteRecordAction("example.org", "nothing"), }, fixture: "./fixtures/delete_record_nothing.xml", expected: expected{ @@ -226,7 +236,7 @@ func TestClient_DoActions(t *testing.T) { { desc: "DELETE simple", actions: []ActionParameter{ - DeleteRecord("example.org", "txttxtx"), + NewDeleteRecordAction("example.org", "txttxtx"), }, fixture: "./fixtures/delete_record.xml", expected: expected{ @@ -256,9 +266,7 @@ func TestClient_DoActions(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - mux := http.NewServeMux() - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + client, mux := setupTest(t) mux.HandleFunc("/", func(rw http.ResponseWriter, req *http.Request) { query, err := url.QueryUnescape(req.URL.RawQuery) @@ -283,10 +291,7 @@ func TestClient_DoActions(t *testing.T) { } }) - client := NewClient("apikeyvaluehere") - client.BaseURL = server.URL - - resp, err := client.DoActions(test.actions...) + resp, err := client.DoActions(context.Background(), test.actions...) if test.expected.Error != "" { require.EqualError(t, err, test.expected.Error) return diff --git a/providers/dns/internal/rimuhosting/model.go b/providers/dns/internal/rimuhosting/types.go similarity index 100% rename from providers/dns/internal/rimuhosting/model.go rename to providers/dns/internal/rimuhosting/types.go diff --git a/providers/dns/internal/selectel/client.go b/providers/dns/internal/selectel/client.go index 92e4746f..dcefa34b 100644 --- a/providers/dns/internal/selectel/client.go +++ b/providers/dns/internal/selectel/client.go @@ -2,11 +2,17 @@ package selectel import ( "bytes" + "context" "encoding/json" "fmt" "io" "net/http" + "net/url" + "strconv" "strings" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) // Base URL for the Selectel/VScale DNS services. @@ -15,39 +21,43 @@ const ( DefaultVScaleBaseURL = "https://api.vscale.io/v1/domains" ) +const tokenHeader = "X-Token" + // Client represents DNS client. type Client struct { - BaseURL string + token string + + BaseURL *url.URL HTTPClient *http.Client - token string } // NewClient returns a client instance. func NewClient(token string) *Client { + baseURL, _ := url.Parse(DefaultVScaleBaseURL) + return &Client{ token: token, - BaseURL: DefaultVScaleBaseURL, - HTTPClient: &http.Client{}, + BaseURL: baseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, } } // GetDomainByName gets Domain object by its name. If `domainName` level > 2 and there is // no such domain on the account - it'll recursively search for the first // which is exists in Selectel Domain API. -func (c *Client) GetDomainByName(domainName string) (*Domain, error) { - uri := fmt.Sprintf("/%s", domainName) - req, err := c.newRequest(http.MethodGet, uri, nil) +func (c *Client) GetDomainByName(ctx context.Context, domainName string) (*Domain, error) { + req, err := newJSONRequest(ctx, http.MethodGet, c.BaseURL.JoinPath(domainName), nil) if err != nil { return nil, err } domain := &Domain{} - resp, err := c.do(req, domain) + statusCode, err := c.do(req, domain) if err != nil { - if resp != nil && resp.StatusCode == http.StatusNotFound && strings.Count(domainName, ".") > 1 { + if statusCode == http.StatusNotFound && strings.Count(domainName, ".") > 1 { // Look up for the next sub domain subIndex := strings.Index(domainName, ".") - return c.GetDomainByName(domainName[subIndex+1:]) + return c.GetDomainByName(ctx, domainName[subIndex+1:]) } return nil, err @@ -57,9 +67,8 @@ func (c *Client) GetDomainByName(domainName string) (*Domain, error) { } // AddRecord adds Record for given domain. -func (c *Client) AddRecord(domainID int, body Record) (*Record, error) { - uri := fmt.Sprintf("/%d/records/", domainID) - req, err := c.newRequest(http.MethodPost, uri, body) +func (c *Client) AddRecord(ctx context.Context, domainID int, body Record) (*Record, error) { + req, err := newJSONRequest(ctx, http.MethodPost, c.BaseURL.JoinPath(strconv.Itoa(domainID), "records", "/"), body) if err != nil { return nil, err } @@ -74,9 +83,8 @@ func (c *Client) AddRecord(domainID int, body Record) (*Record, error) { } // ListRecords returns list records for specific domain. -func (c *Client) ListRecords(domainID int) ([]Record, error) { - uri := fmt.Sprintf("/%d/records/", domainID) - req, err := c.newRequest(http.MethodGet, uri, nil) +func (c *Client) ListRecords(ctx context.Context, domainID int) ([]Record, error) { + req, err := newJSONRequest(ctx, http.MethodGet, c.BaseURL.JoinPath(strconv.Itoa(domainID), "records", "/"), nil) if err != nil { return nil, err } @@ -86,13 +94,15 @@ func (c *Client) ListRecords(domainID int) ([]Record, error) { if err != nil { return nil, err } + return records, nil } // DeleteRecord deletes specific record. -func (c *Client) DeleteRecord(domainID, recordID int) error { - uri := fmt.Sprintf("/%d/records/%d", domainID, recordID) - req, err := c.newRequest(http.MethodDelete, uri, nil) +func (c *Client) DeleteRecord(ctx context.Context, domainID, recordID int) error { + endpoint := c.BaseURL.JoinPath(strconv.Itoa(domainID), "records", strconv.Itoa(recordID)) + + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) if err != nil { return err } @@ -101,83 +111,69 @@ func (c *Client) DeleteRecord(domainID, recordID int) error { return err } -func (c *Client) newRequest(method, uri string, body interface{}) (*http.Request, error) { +func (c *Client) do(req *http.Request, result any) (int, error) { + req.Header.Set(tokenHeader, c.token) + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return 0, errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode/100 != 2 { + return resp.StatusCode, parseError(req, resp) + } + + if result == nil { + return resp.StatusCode, nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return resp.StatusCode, errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return resp.StatusCode, errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return resp.StatusCode, nil +} + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { buf := new(bytes.Buffer) - if body != nil { - err := json.NewEncoder(buf).Encode(body) + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) if err != nil { - return nil, fmt.Errorf("failed to encode request body with error: %w", err) + return nil, fmt.Errorf("failed to create request JSON body: %w", err) } } - req, err := http.NewRequest(method, c.BaseURL+uri, buf) + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) if err != nil { - return nil, fmt.Errorf("failed to create new http request with error: %w", err) + return nil, fmt.Errorf("unable to create request: %w", err) } - req.Header.Set("X-Token", c.token) - req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + return req, nil } -func (c *Client) do(req *http.Request, to interface{}) (*http.Response, error) { - resp, err := c.HTTPClient.Do(req) +func parseError(req *http.Request, resp *http.Response) error { + raw, _ := io.ReadAll(resp.Body) + + errAPI := &APIError{} + err := json.Unmarshal(raw, errAPI) if err != nil { - return nil, fmt.Errorf("request failed with error: %w", err) + return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw) } - err = checkResponse(resp) - if err != nil { - return resp, err - } - - if to != nil { - if err = unmarshalBody(resp, to); err != nil { - return resp, err - } - } - - return resp, nil -} - -func checkResponse(resp *http.Response) error { - if resp.StatusCode >= http.StatusBadRequest { - if resp.Body == nil { - return fmt.Errorf("request failed with status code %d and empty body", resp.StatusCode) - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - defer resp.Body.Close() - - apiError := APIError{} - err = json.Unmarshal(body, &apiError) - if err != nil { - return fmt.Errorf("request failed with status code %d, response body: %s", resp.StatusCode, string(body)) - } - - return fmt.Errorf("request failed with status code %d: %w", resp.StatusCode, apiError) - } - - return nil -} - -func unmarshalBody(resp *http.Response, to interface{}) error { - body, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - defer resp.Body.Close() - - err = json.Unmarshal(body, to) - if err != nil { - return fmt.Errorf("unmarshaling error: %w: %s", err, string(body)) - } - - return nil + return fmt.Errorf("request failed with status code %d: %w", resp.StatusCode, errAPI) } diff --git a/providers/dns/internal/selectel/client_test.go b/providers/dns/internal/selectel/client_test.go index c0bf3007..fd658ae3 100644 --- a/providers/dns/internal/selectel/client_test.go +++ b/providers/dns/internal/selectel/client_test.go @@ -1,11 +1,13 @@ package selectel import ( + "context" "encoding/json" "fmt" "io" "net/http" "net/http/httptest" + "net/url" "os" "testing" @@ -13,11 +15,23 @@ import ( "github.com/stretchr/testify/require" ) -func TestClient_ListRecords(t *testing.T) { +func setupTest(t *testing.T) (*Client, *http.ServeMux) { + t.Helper() + mux := http.NewServeMux() server := httptest.NewServer(mux) t.Cleanup(server.Close) + client := NewClient("token") + client.BaseURL, _ = url.Parse(server.URL) + client.HTTPClient = server.Client() + + return client, mux +} + +func TestClient_ListRecords(t *testing.T) { + client, mux := setupTest(t) + mux.HandleFunc("/123/records/", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodGet { http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusMethodNotAllowed) @@ -33,10 +47,7 @@ func TestClient_ListRecords(t *testing.T) { } }) - client := NewClient("token") - client.BaseURL = server.URL - - records, err := client.ListRecords(123) + records, err := client.ListRecords(context.Background(), 123) require.NoError(t, err) expected := []Record{ @@ -49,9 +60,7 @@ func TestClient_ListRecords(t *testing.T) { } func TestClient_ListRecords_error(t *testing.T) { - mux := http.NewServeMux() - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + client, mux := setupTest(t) mux.HandleFunc("/123/records/", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodGet { @@ -67,19 +76,14 @@ func TestClient_ListRecords_error(t *testing.T) { } }) - client := NewClient("token") - client.BaseURL = server.URL - - records, err := client.ListRecords(123) + records, err := client.ListRecords(context.Background(), 123) assert.EqualError(t, err, "request failed with status code 401: API error: 400 - error description - field that the error occurred in") assert.Nil(t, records) } func TestClient_GetDomainByName(t *testing.T) { - mux := http.NewServeMux() - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + client, mux := setupTest(t) mux.HandleFunc("/sub.sub.example.org", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodGet { @@ -114,10 +118,7 @@ func TestClient_GetDomainByName(t *testing.T) { } }) - client := NewClient("token") - client.BaseURL = server.URL - - domain, err := client.GetDomainByName("sub.sub.example.org") + domain, err := client.GetDomainByName(context.Background(), "sub.sub.example.org") require.NoError(t, err) expected := &Domain{ @@ -129,9 +130,7 @@ func TestClient_GetDomainByName(t *testing.T) { } func TestClient_AddRecord(t *testing.T) { - mux := http.NewServeMux() - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + client, mux := setupTest(t) mux.HandleFunc("/123/records/", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodPost { @@ -156,10 +155,7 @@ func TestClient_AddRecord(t *testing.T) { } }) - client := NewClient("token") - client.BaseURL = server.URL - - record, err := client.AddRecord(123, Record{ + record, err := client.AddRecord(context.Background(), 123, Record{ Name: "example.org", Type: "TXT", TTL: 60, @@ -182,9 +178,7 @@ func TestClient_AddRecord(t *testing.T) { } func TestClient_DeleteRecord(t *testing.T) { - mux := http.NewServeMux() - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + client, mux := setupTest(t) mux.HandleFunc("/", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodDelete { @@ -193,10 +187,7 @@ func TestClient_DeleteRecord(t *testing.T) { } }) - client := NewClient("token") - client.BaseURL = server.URL - - err := client.DeleteRecord(123, 456) + err := client.DeleteRecord(context.Background(), 123, 456) require.NoError(t, err) } diff --git a/providers/dns/internal/selectel/models.go b/providers/dns/internal/selectel/types.go similarity index 100% rename from providers/dns/internal/selectel/models.go rename to providers/dns/internal/selectel/types.go diff --git a/providers/dns/internetbs/internal/client.go b/providers/dns/internetbs/internal/client.go index 9334586f..771408c5 100644 --- a/providers/dns/internetbs/internal/client.go +++ b/providers/dns/internetbs/internal/client.go @@ -1,6 +1,7 @@ package internal import ( + "context" "encoding/json" "fmt" "io" @@ -12,6 +13,7 @@ import ( "time" "unicode" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" querystring "github.com/google/go-querystring/query" ) @@ -22,12 +24,13 @@ const statusSuccess = "SUCCESS" // Client is the API client. type Client struct { - HTTPClient *http.Client - baseURL *url.URL - debug bool - apiKey string password string + + debug bool + + baseURL *url.URL + HTTPClient *http.Client } // NewClient creates a new Client. @@ -35,17 +38,17 @@ func NewClient(apiKey string, password string) *Client { baseURL, _ := url.Parse(baseURL) return &Client{ - HTTPClient: &http.Client{Timeout: 10 * time.Second}, - baseURL: baseURL, apiKey: apiKey, password: password, + baseURL: baseURL, + HTTPClient: &http.Client{Timeout: 10 * time.Second}, } } // AddRecord The command is intended to add a new DNS record to a specific zone (domain). -func (c Client) AddRecord(query RecordQuery) error { +func (c Client) AddRecord(ctx context.Context, query RecordQuery) error { var r APIResponse - err := c.do("Add", query, &r) + err := c.doRequest(ctx, "Add", query, &r) if err != nil { return err } @@ -58,9 +61,9 @@ func (c Client) AddRecord(query RecordQuery) error { } // RemoveRecord The command is intended to remove a DNS record from a specific zone. -func (c Client) RemoveRecord(query RecordQuery) error { +func (c Client) RemoveRecord(ctx context.Context, query RecordQuery) error { var r APIResponse - err := c.do("Remove", query, &r) + err := c.doRequest(ctx, "Remove", query, &r) if err != nil { return err } @@ -73,9 +76,9 @@ func (c Client) RemoveRecord(query RecordQuery) error { } // ListRecords The command is intended to retrieve the list of DNS records for a specific domain. -func (c Client) ListRecords(query ListRecordQuery) ([]Record, error) { +func (c Client) ListRecords(ctx context.Context, query ListRecordQuery) ([]Record, error) { var l ListResponse - err := c.do("List", query, &l) + err := c.doRequest(ctx, "List", query, &l) if err != nil { return nil, err } @@ -87,7 +90,7 @@ func (c Client) ListRecords(query ListRecordQuery) ([]Record, error) { return l.Records, nil } -func (c Client) do(action string, params interface{}, response interface{}) error { +func (c Client) doRequest(ctx context.Context, action string, params any, result any) error { endpoint := c.baseURL.JoinPath("Domain", "DnsRecord", action) values, err := querystring.Values(params) @@ -99,27 +102,43 @@ func (c Client) do(action string, params interface{}, response interface{}) erro values.Set("password", c.password) values.Set("ResponseFormat", "JSON") - resp, err := c.HTTPClient.PostForm(endpoint.String(), values) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint.String(), strings.NewReader(values.Encode())) if err != nil { - return fmt.Errorf("post request: %w", err) + return fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode/100 != 2 { - data, _ := io.ReadAll(resp.Body) - return fmt.Errorf("status code: %d, %s", resp.StatusCode, string(data)) + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) } if c.debug { - return dump(endpoint, resp, response) + return dump(endpoint, resp, result) } - return json.NewDecoder(resp.Body).Decode(response) + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return nil } -func dump(endpoint *url.URL, resp *http.Response, response interface{}) error { - data, err := io.ReadAll(resp.Body) +func dump(endpoint *url.URL, resp *http.Response, response any) error { + raw, err := io.ReadAll(resp.Body) if err != nil { return err } @@ -128,10 +147,10 @@ func dump(endpoint *url.URL, resp *http.Response, response interface{}) error { return !unicode.IsLetter(r) && !unicode.IsNumber(r) }) - err = os.WriteFile(filepath.Join("fixtures", strings.Join(fields, "_")+".json"), data, 0o666) + err = os.WriteFile(filepath.Join("fixtures", strings.Join(fields, "_")+".json"), raw, 0o666) if err != nil { return err } - return json.Unmarshal(data, response) + return json.Unmarshal(raw, response) } diff --git a/providers/dns/internetbs/internal/client_test.go b/providers/dns/internetbs/internal/client_test.go index 0efc6cab..a22f1b12 100644 --- a/providers/dns/internetbs/internal/client_test.go +++ b/providers/dns/internetbs/internal/client_test.go @@ -1,6 +1,7 @@ package internal import ( + "context" "fmt" "io" "net/http" @@ -31,7 +32,7 @@ func TestClient_AddRecord(t *testing.T) { TTL: 36000, } - err := client.AddRecord(query) + err := client.AddRecord(context.Background(), query) require.NoError(t, err) } @@ -45,7 +46,7 @@ func TestClient_AddRecord_error(t *testing.T) { TTL: 36000, } - err := client.AddRecord(query) + err := client.AddRecord(context.Background(), query) require.Error(t, err) } @@ -66,7 +67,7 @@ func TestClient_AddRecord_integration(t *testing.T) { TTL: 36000, } - err := client.AddRecord(query) + err := client.AddRecord(context.Background(), query) require.NoError(t, err) query = RecordQuery{ @@ -76,7 +77,7 @@ func TestClient_AddRecord_integration(t *testing.T) { TTL: 36000, } - err = client.AddRecord(query) + err = client.AddRecord(context.Background(), query) require.NoError(t, err) } @@ -88,7 +89,7 @@ func TestClient_RemoveRecord(t *testing.T) { Type: "TXT", Value: "", } - err := client.RemoveRecord(query) + err := client.RemoveRecord(context.Background(), query) require.NoError(t, err) } @@ -100,7 +101,7 @@ func TestClient_RemoveRecord_error(t *testing.T) { Type: "TXT", Value: "", } - err := client.RemoveRecord(query) + err := client.RemoveRecord(context.Background(), query) require.Error(t, err) } @@ -120,7 +121,7 @@ func TestClient_RemoveRecord_integration(t *testing.T) { Value: "", } - err := client.RemoveRecord(query) + err := client.RemoveRecord(context.Background(), query) require.NoError(t, err) } @@ -131,7 +132,7 @@ func TestClient_ListRecords(t *testing.T) { Domain: "example.com", } - records, err := client.ListRecords(query) + records, err := client.ListRecords(context.Background(), query) require.NoError(t, err) expected := []Record{ @@ -183,7 +184,7 @@ func TestClient_ListRecords_error(t *testing.T) { Domain: "www.example.com", } - _, err := client.ListRecords(query) + _, err := client.ListRecords(context.Background(), query) require.Error(t, err) } @@ -201,7 +202,7 @@ func TestClient_ListRecords_integration(t *testing.T) { Domain: "example.com", } - records, err := client.ListRecords(query) + records, err := client.ListRecords(context.Background(), query) require.NoError(t, err) for _, record := range records { diff --git a/providers/dns/internetbs/internetbs.go b/providers/dns/internetbs/internetbs.go index 27b48de0..89b33eae 100644 --- a/providers/dns/internetbs/internetbs.go +++ b/providers/dns/internetbs/internetbs.go @@ -2,6 +2,7 @@ package internetbs import ( + "context" "errors" "fmt" "net/http" @@ -107,7 +108,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { TTL: d.config.TTL, } - err := d.client.AddRecord(query) + err := d.client.AddRecord(context.Background(), query) if err != nil { return fmt.Errorf("internetbs: %w", err) } @@ -126,7 +127,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { TTL: d.config.TTL, } - err := d.client.RemoveRecord(query) + err := d.client.RemoveRecord(context.Background(), query) if err != nil { return fmt.Errorf("internetbs: %w", err) } diff --git a/providers/dns/inwx/inwx.go b/providers/dns/inwx/inwx.go index 3d593073..fdfa4e54 100644 --- a/providers/dns/inwx/inwx.go +++ b/providers/dns/inwx/inwx.go @@ -97,7 +97,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(challengeInfo.EffectiveFQDN) if err != nil { - return fmt.Errorf("inwx: %w", err) + return fmt.Errorf("inwx: could not find zone for domain %q (%s): %w", domain, challengeInfo.EffectiveFQDN, err) } info, err := d.client.Account.Login() @@ -147,7 +147,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(challengeInfo.EffectiveFQDN) if err != nil { - return fmt.Errorf("inwx: %w", err) + return fmt.Errorf("inwx: could not find zone for domain %q (%s): %w", domain, challengeInfo.EffectiveFQDN, err) } info, err := d.client.Account.Login() diff --git a/providers/dns/ionos/internal/client.go b/providers/dns/ionos/internal/client.go index 3abd5657..8b37d5f1 100644 --- a/providers/dns/ionos/internal/client.go +++ b/providers/dns/ionos/internal/client.go @@ -8,7 +8,9 @@ import ( "io" "net/http" "net/url" + "time" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" querystring "github.com/google/go-querystring/query" ) @@ -17,10 +19,10 @@ const defaultBaseURL = "https://api.hosting.ionos.com/dns" // Client Ionos API client. type Client struct { - HTTPClient *http.Client - BaseURL *url.URL - apiKey string + + BaseURL *url.URL + HTTPClient *http.Client } // NewClient creates a new Client. @@ -31,9 +33,9 @@ func NewClient(apiKey string) (*Client, error) { } return &Client{ - HTTPClient: http.DefaultClient, - BaseURL: baseURL, apiKey: apiKey, + BaseURL: baseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, }, nil } @@ -41,28 +43,17 @@ func NewClient(apiKey string) (*Client, error) { func (c *Client) ListZones(ctx context.Context) ([]Zone, error) { endpoint := c.BaseURL.JoinPath("v1", "zones") - req, err := c.makeRequest(ctx, http.MethodGet, endpoint, nil) + req, err := makeJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } - resp, err := c.HTTPClient.Do(req) + var zones []Zone + err = c.do(req, &zones) if err != nil { return nil, fmt.Errorf("failed to call API: %w", err) } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - return nil, readError(resp.Body, resp.StatusCode) - } - - var zones []Zone - err = json.NewDecoder(resp.Body).Decode(&zones) - if err != nil { - return nil, fmt.Errorf("failed to parse response: %w", err) - } - return zones, nil } @@ -70,27 +61,16 @@ func (c *Client) ListZones(ctx context.Context) ([]Zone, error) { func (c *Client) ReplaceRecords(ctx context.Context, zoneID string, records []Record) error { endpoint := c.BaseURL.JoinPath("v1", "zones", zoneID) - body, err := json.Marshal(records) - if err != nil { - return fmt.Errorf("failed to marshal request body: %w", err) - } - - req, err := c.makeRequest(ctx, http.MethodPatch, endpoint, bytes.NewReader(body)) + req, err := makeJSONRequest(ctx, http.MethodPatch, endpoint, records) if err != nil { return fmt.Errorf("failed to create request: %w", err) } - resp, err := c.HTTPClient.Do(req) + err = c.do(req, nil) if err != nil { return fmt.Errorf("failed to call API: %w", err) } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - return readError(resp.Body, resp.StatusCode) - } - return nil } @@ -98,7 +78,7 @@ func (c *Client) ReplaceRecords(ctx context.Context, zoneID string, records []Re func (c *Client) GetRecords(ctx context.Context, zoneID string, filter *RecordsFilter) ([]Record, error) { endpoint := c.BaseURL.JoinPath("v1", "zones", zoneID) - req, err := c.makeRequest(ctx, http.MethodGet, endpoint, nil) + req, err := makeJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } @@ -112,23 +92,12 @@ func (c *Client) GetRecords(ctx context.Context, zoneID string, filter *RecordsF req.URL.RawQuery = v.Encode() } - resp, err := c.HTTPClient.Do(req) + var zone CustomerZone + err = c.do(req, &zone) if err != nil { return nil, fmt.Errorf("failed to call API: %w", err) } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - return nil, readError(resp.Body, resp.StatusCode) - } - - var zone CustomerZone - err = json.NewDecoder(resp.Body).Decode(&zone) - if err != nil { - return nil, fmt.Errorf("failed to parse response: %w", err) - } - return zone.Records, nil } @@ -136,48 +105,82 @@ func (c *Client) GetRecords(ctx context.Context, zoneID string, filter *RecordsF func (c *Client) RemoveRecord(ctx context.Context, zoneID, recordID string) error { endpoint := c.BaseURL.JoinPath("v1", "zones", zoneID, "records", recordID) - req, err := c.makeRequest(ctx, http.MethodDelete, endpoint, nil) + req, err := makeJSONRequest(ctx, http.MethodDelete, endpoint, nil) if err != nil { return fmt.Errorf("failed to create request: %w", err) } - resp, err := c.HTTPClient.Do(req) + err = c.do(req, nil) if err != nil { return fmt.Errorf("failed to call API: %w", err) } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - return readError(resp.Body, resp.StatusCode) - } - return nil } -func (c *Client) makeRequest(ctx context.Context, method string, endpoint *url.URL, body io.Reader) (*http.Request, error) { - req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), body) +func (c *Client) do(req *http.Request, result any) error { + req.Header.Set("X-API-Key", c.apiKey) + + resp, err := c.HTTPClient.Do(req) if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return parseError(req, resp) + } + + if result == nil { + return nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return nil +} + +func makeJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) } req.Header.Set("Accept", "application/json") - req.Header.Set("Content-Type", "application/json") - req.Header.Set("X-API-Key", c.apiKey) + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } return req, nil } -func readError(body io.Reader, statusCode int) error { - bodyBytes, _ := io.ReadAll(body) +func parseError(req *http.Request, resp *http.Response) error { + raw, _ := io.ReadAll(resp.Body) - cErr := &ClientError{StatusCode: statusCode} - - err := json.Unmarshal(bodyBytes, &cErr.errors) + errClient := &ClientError{StatusCode: resp.StatusCode} + err := json.Unmarshal(raw, &errClient.errors) if err != nil { - cErr.message = string(bodyBytes) - return cErr + return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw) } - return cErr + return errClient } diff --git a/providers/dns/ionos/internal/client_test.go b/providers/dns/ionos/internal/client_test.go index 5d40b73b..21a7a267 100644 --- a/providers/dns/ionos/internal/client_test.go +++ b/providers/dns/ionos/internal/client_test.go @@ -17,7 +17,7 @@ import ( ) func TestClient_ListZones(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/v1/zones", mockHandler(http.MethodGet, http.StatusOK, "list_zones.json")) @@ -34,7 +34,7 @@ func TestClient_ListZones(t *testing.T) { } func TestClient_ListZones_error(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/v1/zones", mockHandler(http.MethodGet, http.StatusUnauthorized, "list_zones_error.json")) @@ -49,7 +49,7 @@ func TestClient_ListZones_error(t *testing.T) { } func TestClient_GetRecords(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/v1/zones/azone01", mockHandler(http.MethodGet, http.StatusOK, "get_records.json")) @@ -67,7 +67,7 @@ func TestClient_GetRecords(t *testing.T) { } func TestClient_GetRecords_error(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/v1/zones/azone01", mockHandler(http.MethodGet, http.StatusUnauthorized, "get_records_error.json")) @@ -82,7 +82,7 @@ func TestClient_GetRecords_error(t *testing.T) { } func TestClient_RemoveRecord(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/v1/zones/azone01/records/arecord01", mockHandler(http.MethodDelete, http.StatusOK, "")) @@ -91,7 +91,7 @@ func TestClient_RemoveRecord(t *testing.T) { } func TestClient_RemoveRecord_error(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/v1/zones/azone01/records/arecord01", mockHandler(http.MethodDelete, http.StatusInternalServerError, "remove_record_error.json")) @@ -104,7 +104,7 @@ func TestClient_RemoveRecord_error(t *testing.T) { } func TestClient_ReplaceRecords(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/v1/zones/azone01", mockHandler(http.MethodPatch, http.StatusOK, "")) @@ -120,7 +120,7 @@ func TestClient_ReplaceRecords(t *testing.T) { } func TestClient_ReplaceRecords_error(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/v1/zones/azone01", mockHandler(http.MethodPatch, http.StatusBadRequest, "replace_records_error.json")) @@ -139,7 +139,7 @@ func TestClient_ReplaceRecords_error(t *testing.T) { assert.Equal(t, http.StatusBadRequest, cErr.StatusCode) } -func setupTest(t *testing.T) (*http.ServeMux, *Client) { +func setupTest(t *testing.T) (*Client, *http.ServeMux) { t.Helper() mux := http.NewServeMux() @@ -151,7 +151,7 @@ func setupTest(t *testing.T) (*http.ServeMux, *Client) { client.BaseURL, _ = url.Parse(server.URL) - return mux, client + return client, mux } func mockHandler(method string, statusCode int, filename string) func(http.ResponseWriter, *http.Request) { diff --git a/providers/dns/ionos/ionos.go b/providers/dns/ionos/ionos.go index b0e64e9f..d6150a70 100644 --- a/providers/dns/ionos/ionos.go +++ b/providers/dns/ionos/ionos.go @@ -92,10 +92,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { client.HTTPClient = config.HTTPClient } - return &DNSProvider{ - config: config, - client: client, - }, nil + return &DNSProvider{config: config, client: client}, nil } // Timeout returns the timeout and interval to use when checking for DNS propagation. diff --git a/providers/dns/iwantmyname/internal/client.go b/providers/dns/iwantmyname/internal/client.go index 22fef84c..7a7c50e2 100644 --- a/providers/dns/iwantmyname/internal/client.go +++ b/providers/dns/iwantmyname/internal/client.go @@ -3,28 +3,21 @@ package internal import ( "context" "fmt" - "io" "net/http" "net/url" "time" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" querystring "github.com/google/go-querystring/query" ) const defaultBaseURL = "https://iwantmyname.com/basicauth/ddns" -// Record represents a record. -type Record struct { - Hostname string `url:"hostname,omitempty"` - Type string `url:"type,omitempty"` - Value string `url:"value,omitempty"` - TTL int `url:"ttl,omitempty"` -} - // Client iwantmyname client. type Client struct { - username string - password string + username string + password string + baseURL *url.URL HTTPClient *http.Client } @@ -32,6 +25,7 @@ type Client struct { // NewClient creates a new Client. func NewClient(username string, password string) *Client { baseURL, _ := url.Parse(defaultBaseURL) + return &Client{ username: username, password: password, @@ -40,8 +34,8 @@ func NewClient(username string, password string) *Client { } } -// Do send a request (create/add/delete) to the API. -func (c Client) Do(ctx context.Context, record Record) error { +// SendRequest send a request (create/add/delete) to the API. +func (c Client) SendRequest(ctx context.Context, record Record) error { values, err := querystring.Values(record) if err != nil { return err @@ -52,19 +46,20 @@ func (c Client) Do(ctx context.Context, record Record) error { req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint.String(), http.NoBody) if err != nil { - return err + return fmt.Errorf("unable to create request: %w", err) } req.SetBasicAuth(c.username, c.password) resp, err := c.HTTPClient.Do(req) if err != nil { - return err + return errutils.NewHTTPDoError(req, err) } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode/100 != 2 { - data, _ := io.ReadAll(resp.Body) - return fmt.Errorf("status code: %d, %s", resp.StatusCode, string(data)) + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) } return nil diff --git a/providers/dns/iwantmyname/internal/client_test.go b/providers/dns/iwantmyname/internal/client_test.go index 76ea2532..b26f7c0f 100644 --- a/providers/dns/iwantmyname/internal/client_test.go +++ b/providers/dns/iwantmyname/internal/client_test.go @@ -18,14 +18,23 @@ func checkParameter(query url.Values, key, expected string) error { return nil } -func TestClient_Do(t *testing.T) { +func setupTest(t *testing.T) (*Client, *http.ServeMux) { + t.Helper() + mux := http.NewServeMux() server := httptest.NewServer(mux) + t.Cleanup(server.Close) client := NewClient("user", "secret") client.HTTPClient = server.Client() client.baseURL, _ = url.Parse(server.URL) + return client, mux +} + +func TestClient_Do(t *testing.T) { + client, mux := setupTest(t) + mux.HandleFunc("/", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodPost { http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusMethodNotAllowed) @@ -73,6 +82,6 @@ func TestClient_Do(t *testing.T) { TTL: 120, } - err := client.Do(context.Background(), record) + err := client.SendRequest(context.Background(), record) require.NoError(t, err) } diff --git a/providers/dns/iwantmyname/internal/types.go b/providers/dns/iwantmyname/internal/types.go new file mode 100644 index 00000000..b259235f --- /dev/null +++ b/providers/dns/iwantmyname/internal/types.go @@ -0,0 +1,9 @@ +package internal + +// Record represents a record. +type Record struct { + Hostname string `url:"hostname,omitempty"` + Type string `url:"type,omitempty"` + Value string `url:"value,omitempty"` + TTL int `url:"ttl,omitempty"` +} diff --git a/providers/dns/iwantmyname/iwantmyname.go b/providers/dns/iwantmyname/iwantmyname.go index dfb75e37..e828446a 100644 --- a/providers/dns/iwantmyname/iwantmyname.go +++ b/providers/dns/iwantmyname/iwantmyname.go @@ -108,7 +108,7 @@ func (d *DNSProvider) Present(domain, _, keyAuth string) error { TTL: d.config.TTL, } - err := d.client.Do(context.Background(), record) + err := d.client.SendRequest(context.Background(), record) if err != nil { return fmt.Errorf("iwantmyname: %w", err) } @@ -127,7 +127,7 @@ func (d *DNSProvider) CleanUp(domain, _, keyAuth string) error { TTL: d.config.TTL, } - err := d.client.Do(context.Background(), record) + err := d.client.SendRequest(context.Background(), record) if err != nil { return fmt.Errorf("iwantmyname: %w", err) } diff --git a/providers/dns/joker/internal/dmapi/client.go b/providers/dns/joker/internal/dmapi/client.go index 00a84b53..04f4350a 100644 --- a/providers/dns/joker/internal/dmapi/client.go +++ b/providers/dns/joker/internal/dmapi/client.go @@ -3,6 +3,7 @@ package dmapi import ( + "context" "errors" "fmt" "io" @@ -10,9 +11,12 @@ import ( "net/url" "strconv" "strings" + "sync" + "time" "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/log" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) const defaultBaseURL = "https://dmapi.joker.com/request/" @@ -30,129 +34,90 @@ type AuthInfo struct { APIKey string Username string Password string - authSid string } // Client a DMAPI Client. type Client struct { - HTTPClient *http.Client + apiKey string + username string + password string + + token *Token + muToken sync.Mutex + + Debug bool BaseURL string - - Debug bool - - auth AuthInfo + HTTPClient *http.Client } // NewClient creates a new DMAPI Client. -func NewClient(auth AuthInfo) *Client { +func NewClient(authInfo AuthInfo) *Client { return &Client{ - HTTPClient: http.DefaultClient, + apiKey: authInfo.APIKey, + username: authInfo.Username, + password: authInfo.Password, BaseURL: defaultBaseURL, - Debug: false, - auth: auth, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, } } -// Login performs a login to Joker's DMAPI. -func (c *Client) Login() (*Response, error) { - if c.auth.authSid != "" { - // already logged in - return nil, nil - } - - var values url.Values - switch { - case c.auth.Username != "" && c.auth.Password != "": - values = url.Values{ - "username": {c.auth.Username}, - "password": {c.auth.Password}, - } - case c.auth.APIKey != "": - values = url.Values{"api-key": {c.auth.APIKey}} - default: - return nil, errors.New("no username and password or api-key") - } - - response, err := c.postRequest("login", values) - if err != nil { - return response, err - } - - if response == nil { - return nil, errors.New("login returned nil response") - } - - if response.AuthSid == "" { - return response, errors.New("login did not return valid Auth-Sid") - } - - c.auth.authSid = response.AuthSid - - return response, nil -} - -// Logout closes authenticated session with Joker's DMAPI. -func (c *Client) Logout() (*Response, error) { - if c.auth.authSid == "" { - return nil, errors.New("already logged out") - } - - response, err := c.postRequest("logout", url.Values{}) - if err == nil { - c.auth.authSid = "" - } - return response, err -} - // GetZone returns content of DNS zone for domain. -func (c *Client) GetZone(domain string) (*Response, error) { - if c.auth.authSid == "" { +func (c *Client) GetZone(ctx context.Context, domain string) (*Response, error) { + if getSessionID(ctx) == "" { return nil, errors.New("must be logged in to get zone") } - return c.postRequest("dns-zone-get", url.Values{"domain": {dns01.UnFqdn(domain)}}) + return c.postRequest(ctx, "dns-zone-get", url.Values{"domain": {dns01.UnFqdn(domain)}}) } // PutZone uploads DNS zone to Joker DMAPI. -func (c *Client) PutZone(domain, zone string) (*Response, error) { - if c.auth.authSid == "" { +func (c *Client) PutZone(ctx context.Context, domain, zone string) (*Response, error) { + if getSessionID(ctx) == "" { return nil, errors.New("must be logged in to put zone") } - return c.postRequest("dns-zone-put", url.Values{"domain": {dns01.UnFqdn(domain)}, "zone": {strings.TrimSpace(zone)}}) + return c.postRequest(ctx, "dns-zone-put", url.Values{"domain": {dns01.UnFqdn(domain)}, "zone": {strings.TrimSpace(zone)}}) } // postRequest performs actual HTTP request. -func (c *Client) postRequest(cmd string, data url.Values) (*Response, error) { +func (c *Client) postRequest(ctx context.Context, cmd string, data url.Values) (*Response, error) { endpoint, err := url.JoinPath(c.BaseURL, cmd) if err != nil { return nil, err } - if c.auth.authSid != "" { - data.Set("auth-sid", c.auth.authSid) + if getSessionID(ctx) != "" { + data.Set("auth-sid", getSessionID(ctx)) } if c.Debug { log.Infof("postRequest:\n\tURL: %q\n\tData: %v", endpoint, data) } - resp, err := c.HTTPClient.PostForm(endpoint, data) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(data.Encode())) if err != nil { - return nil, err + return nil, fmt.Errorf("unable to create request: %w", err) } - defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := c.HTTPClient.Do(req) if err != nil { - return nil, err + return nil, errutils.NewHTTPDoError(req, err) } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("HTTP error %d [%s]: %v", resp.StatusCode, http.StatusText(resp.StatusCode), string(body)) + return nil, errutils.NewUnexpectedResponseStatusCodeError(req, resp) } - return parseResponse(string(body)), nil + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + return parseResponse(string(raw)), nil } // parseResponse parses HTTP response body. diff --git a/providers/dns/joker/internal/dmapi/client_test.go b/providers/dns/joker/internal/dmapi/client_test.go index b2835498..7bdb07ed 100644 --- a/providers/dns/joker/internal/dmapi/client_test.go +++ b/providers/dns/joker/internal/dmapi/client_test.go @@ -23,223 +23,17 @@ const ( serverErrorUsername = "error" ) -func setup(t *testing.T) (*http.ServeMux, string) { +func setupTest(t *testing.T) (*http.ServeMux, string) { t.Helper() mux := http.NewServeMux() - server := httptest.NewServer(mux) t.Cleanup(server.Close) return mux, server.URL } -func TestDNSProvider_login_api_key(t *testing.T) { - testCases := []struct { - desc string - apiKey string - expectedError bool - expectedStatusCode int - expectedAuthSid string - }{ - { - desc: "correct key", - apiKey: correctAPIKey, - expectedStatusCode: 0, - expectedAuthSid: correctAPIKey, - }, - { - desc: "incorrect key", - apiKey: incorrectAPIKey, - expectedStatusCode: 2200, - expectedError: true, - }, - { - desc: "server error", - apiKey: serverErrorAPIKey, - expectedStatusCode: -500, - expectedError: true, - }, - { - desc: "non-ok status code", - apiKey: "333", - expectedStatusCode: 2202, - expectedError: true, - }, - } - - mux, serverURL := setup(t) - - mux.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodPost, r.Method) - - switch r.FormValue("api-key") { - case correctAPIKey: - _, _ = io.WriteString(w, "Status-Code: 0\nStatus-Text: OK\nAuth-Sid: 123\n\ncom\nnet") - case incorrectAPIKey: - _, _ = io.WriteString(w, "Status-Code: 2200\nStatus-Text: Authentication error") - case serverErrorAPIKey: - http.NotFound(w, r) - default: - _, _ = io.WriteString(w, "Status-Code: 2202\nStatus-Text: OK\n\ncom\nnet") - } - }) - - for _, test := range testCases { - t.Run(test.desc, func(t *testing.T) { - client := NewClient(AuthInfo{APIKey: test.apiKey}) - client.BaseURL = serverURL - - response, err := client.Login() - if test.expectedError { - require.Error(t, err) - } else { - require.NoError(t, err) - require.NotNil(t, response) - assert.Equal(t, test.expectedStatusCode, response.StatusCode) - assert.Equal(t, test.expectedAuthSid, response.AuthSid) - } - }) - } -} - -func TestDNSProvider_login_username(t *testing.T) { - testCases := []struct { - desc string - username string - password string - expectedError bool - expectedStatusCode int - expectedAuthSid string - }{ - { - desc: "correct username and password", - username: correctUsername, - password: "go-acme", - expectedError: false, - expectedStatusCode: 0, - expectedAuthSid: correctAPIKey, - }, - { - desc: "incorrect username", - username: incorrectUsername, - password: "go-acme", - expectedStatusCode: 2200, - expectedError: true, - }, - { - desc: "server error", - username: serverErrorUsername, - password: "go-acme", - expectedStatusCode: -500, - expectedError: true, - }, - { - desc: "non-ok status code", - username: "random", - password: "go-acme", - expectedStatusCode: 2202, - expectedError: true, - }, - } - - mux, serverURL := setup(t) - - mux.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodPost, r.Method) - - switch r.FormValue("username") { - case correctUsername: - _, _ = io.WriteString(w, "Status-Code: 0\nStatus-Text: OK\nAuth-Sid: 123\n\ncom\nnet") - case incorrectUsername: - _, _ = io.WriteString(w, "Status-Code: 2200\nStatus-Text: Authentication error") - case serverErrorUsername: - http.NotFound(w, r) - default: - _, _ = io.WriteString(w, "Status-Code: 2202\nStatus-Text: OK\n\ncom\nnet") - } - }) - - for _, test := range testCases { - t.Run(test.desc, func(t *testing.T) { - client := NewClient(AuthInfo{Username: test.username, Password: test.password}) - client.BaseURL = serverURL - - response, err := client.Login() - if test.expectedError { - require.Error(t, err) - } else { - require.NoError(t, err) - require.NotNil(t, response) - assert.Equal(t, test.expectedStatusCode, response.StatusCode) - assert.Equal(t, test.expectedAuthSid, response.AuthSid) - } - }) - } -} - -func TestDNSProvider_logout(t *testing.T) { - testCases := []struct { - desc string - authSid string - expectedError bool - expectedStatusCode int - }{ - { - desc: "correct auth-sid", - authSid: correctAPIKey, - expectedStatusCode: 0, - }, - { - desc: "incorrect auth-sid", - authSid: incorrectAPIKey, - expectedStatusCode: 2200, - }, - { - desc: "already logged out", - authSid: "", - expectedError: true, - }, - { - desc: "server error", - authSid: serverErrorAPIKey, - expectedError: true, - }, - } - - mux, serverURL := setup(t) - - mux.HandleFunc("/logout", func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodPost, r.Method) - - switch r.FormValue("auth-sid") { - case correctAPIKey: - _, _ = io.WriteString(w, "Status-Code: 0\nStatus-Text: OK\n") - case incorrectAPIKey: - _, _ = io.WriteString(w, "Status-Code: 2200\nStatus-Text: Authentication error") - default: - http.NotFound(w, r) - } - }) - - for _, test := range testCases { - t.Run(test.desc, func(t *testing.T) { - client := NewClient(AuthInfo{APIKey: "12345", authSid: test.authSid}) - client.BaseURL = serverURL - - response, err := client.Logout() - if test.expectedError { - require.Error(t, err) - } else { - require.NoError(t, err) - require.NotNil(t, response) - assert.Equal(t, test.expectedStatusCode, response.StatusCode) - } - }) - } -} - -func TestDNSProvider_getZone(t *testing.T) { +func TestClient_GetZone(t *testing.T) { testZone := "@ A 0 192.0.2.2 3600" testCases := []struct { @@ -276,7 +70,7 @@ func TestDNSProvider_getZone(t *testing.T) { }, } - mux, serverURL := setup(t) + mux, serverURL := setupTest(t) mux.HandleFunc("/dns-zone-get", func(w http.ResponseWriter, r *http.Request) { require.Equal(t, http.MethodPost, r.Method) @@ -296,10 +90,10 @@ func TestDNSProvider_getZone(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - client := NewClient(AuthInfo{APIKey: "12345", authSid: test.authSid}) + client := NewClient(AuthInfo{APIKey: "12345"}) client.BaseURL = serverURL - response, err := client.GetZone(test.domain) + response, err := client.GetZone(mockContext(test.authSid), test.domain) if test.expectedError { require.Error(t, err) } else { @@ -387,7 +181,7 @@ func Test_parseResponse(t *testing.T) { } } -func Test_removeTxtEntryFromZone(t *testing.T) { +func Test_RemoveTxtEntryFromZone(t *testing.T) { testCases := []struct { desc string input string @@ -438,7 +232,7 @@ func Test_removeTxtEntryFromZone(t *testing.T) { } } -func Test_addTxtEntryToZone(t *testing.T) { +func Test_AddTxtEntryToZone(t *testing.T) { testCases := []struct { desc string input string diff --git a/providers/dns/joker/internal/dmapi/identity.go b/providers/dns/joker/internal/dmapi/identity.go new file mode 100644 index 00000000..351d987e --- /dev/null +++ b/providers/dns/joker/internal/dmapi/identity.go @@ -0,0 +1,110 @@ +package dmapi + +import ( + "context" + "errors" + "fmt" + "net/url" + "time" +) + +type token string + +const sessionIDKey token = "session-id" + +// Token session ID. +// > Every request (except "login") requires the presence of the Auth-Sid variable ("Session ID"), +// > which is returned by the "login" request (login). An active session will expire after some inactivity period (default: 1 hour). +// https://joker.com/faq/content/22/12/en/commonalities-for-all-requests.html +type Token struct { + SessionID string + ExpireAt time.Time +} + +// login performs a log in to Joker's DMAPI. +func (c *Client) login(ctx context.Context) (*Response, error) { + var values url.Values + switch { + case c.username != "" && c.password != "": + values = url.Values{ + "username": {c.username}, + "password": {c.password}, + } + case c.apiKey != "": + values = url.Values{"api-key": {c.apiKey}} + default: + return nil, errors.New("no username and password or api-key") + } + + response, err := c.postRequest(ctx, "login", values) + if err != nil { + return response, err + } + + if response == nil { + return nil, errors.New("login returned nil response") + } + + if response.AuthSid == "" { + return response, errors.New("login did not return valid Auth-Sid") + } + + return response, nil +} + +// Logout closes authenticated session with Joker's DMAPI. +func (c *Client) Logout(ctx context.Context) (*Response, error) { + if c.token == nil { + return nil, errors.New("already logged out") + } + + response, err := c.postRequest(ctx, "logout", url.Values{}) + + c.muToken.Lock() + c.token = nil + c.muToken.Unlock() + + if err != nil { + return response, err + } + + return response, nil +} + +func (c *Client) CreateAuthenticatedContext(ctx context.Context) (context.Context, error) { + c.muToken.Lock() + defer c.muToken.Unlock() + + if c.token != nil && time.Now().UTC().Before(c.token.ExpireAt) { + return context.WithValue(ctx, sessionIDKey, c.token.SessionID), nil + } + + response, err := c.login(ctx) + if err != nil { + return nil, formatResponseError(response, err) + } + + c.token = &Token{ + SessionID: response.AuthSid, + ExpireAt: time.Now().UTC().Add(1 * time.Hour), + } + + return context.WithValue(ctx, sessionIDKey, response.AuthSid), nil +} + +func getSessionID(ctx context.Context) string { + tok, ok := ctx.Value(sessionIDKey).(string) + if !ok { + return "" + } + + return tok +} + +// formatResponseError formats error with optional details from DMAPI response. +func formatResponseError(response *Response, err error) error { + if response != nil { + return fmt.Errorf("joker: DMAPI error: %w Response: %v", err, response.Headers) + } + return fmt.Errorf("joker: DMAPI error: %w", err) +} diff --git a/providers/dns/joker/internal/dmapi/identity_test.go b/providers/dns/joker/internal/dmapi/identity_test.go new file mode 100644 index 00000000..418deaf4 --- /dev/null +++ b/providers/dns/joker/internal/dmapi/identity_test.go @@ -0,0 +1,280 @@ +package dmapi + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func mockContext(sessionID string) context.Context { + if sessionID == "" { + sessionID = "xxx" + } + + return context.WithValue(context.Background(), sessionIDKey, sessionID) +} + +func TestClient_login_apikey(t *testing.T) { + testCases := []struct { + desc string + apiKey string + expectedError bool + expectedStatusCode int + expectedAuthSid string + }{ + { + desc: "correct key", + apiKey: correctAPIKey, + expectedStatusCode: 0, + expectedAuthSid: correctAPIKey, + }, + { + desc: "incorrect key", + apiKey: incorrectAPIKey, + expectedStatusCode: 2200, + expectedError: true, + }, + { + desc: "server error", + apiKey: serverErrorAPIKey, + expectedStatusCode: -500, + expectedError: true, + }, + { + desc: "non-ok status code", + apiKey: "333", + expectedStatusCode: 2202, + expectedError: true, + }, + } + + mux, serverURL := setupTest(t) + + mux.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + + switch r.FormValue("api-key") { + case correctAPIKey: + _, _ = io.WriteString(w, "Status-Code: 0\nStatus-Text: OK\nAuth-Sid: 123\n\ncom\nnet") + case incorrectAPIKey: + _, _ = io.WriteString(w, "Status-Code: 2200\nStatus-Text: Authentication error") + case serverErrorAPIKey: + http.NotFound(w, r) + default: + _, _ = io.WriteString(w, "Status-Code: 2202\nStatus-Text: OK\n\ncom\nnet") + } + }) + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + client := NewClient(AuthInfo{APIKey: test.apiKey}) + client.BaseURL = serverURL + + response, err := client.login(context.Background()) + if test.expectedError { + require.Error(t, err) + } else { + require.NoError(t, err) + require.NotNil(t, response) + assert.Equal(t, test.expectedStatusCode, response.StatusCode) + assert.Equal(t, test.expectedAuthSid, response.AuthSid) + } + }) + } +} + +func TestClient_login_username(t *testing.T) { + testCases := []struct { + desc string + username string + password string + expectedError bool + expectedStatusCode int + expectedAuthSid string + }{ + { + desc: "correct username and password", + username: correctUsername, + password: "go-acme", + expectedError: false, + expectedStatusCode: 0, + expectedAuthSid: correctAPIKey, + }, + { + desc: "incorrect username", + username: incorrectUsername, + password: "go-acme", + expectedStatusCode: 2200, + expectedError: true, + }, + { + desc: "server error", + username: serverErrorUsername, + password: "go-acme", + expectedStatusCode: -500, + expectedError: true, + }, + { + desc: "non-ok status code", + username: "random", + password: "go-acme", + expectedStatusCode: 2202, + expectedError: true, + }, + } + + mux, serverURL := setupTest(t) + + mux.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + + switch r.FormValue("username") { + case correctUsername: + _, _ = io.WriteString(w, "Status-Code: 0\nStatus-Text: OK\nAuth-Sid: 123\n\ncom\nnet") + case incorrectUsername: + _, _ = io.WriteString(w, "Status-Code: 2200\nStatus-Text: Authentication error") + case serverErrorUsername: + http.NotFound(w, r) + default: + _, _ = io.WriteString(w, "Status-Code: 2202\nStatus-Text: OK\n\ncom\nnet") + } + }) + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + client := NewClient(AuthInfo{Username: test.username, Password: test.password}) + client.BaseURL = serverURL + + response, err := client.login(context.Background()) + if test.expectedError { + require.Error(t, err) + } else { + require.NoError(t, err) + require.NotNil(t, response) + assert.Equal(t, test.expectedStatusCode, response.StatusCode) + assert.Equal(t, test.expectedAuthSid, response.AuthSid) + } + }) + } +} + +func TestClient_logout(t *testing.T) { + testCases := []struct { + desc string + authSid string + expectedError bool + expectedStatusCode int + }{ + { + desc: "correct auth-sid", + authSid: correctAPIKey, + expectedStatusCode: 0, + }, + { + desc: "incorrect auth-sid", + authSid: incorrectAPIKey, + expectedStatusCode: 2200, + }, + { + desc: "already logged out", + authSid: "", + expectedError: true, + }, + { + desc: "server error", + authSid: serverErrorAPIKey, + expectedError: true, + }, + } + + mux, serverURL := setupTest(t) + + mux.HandleFunc("/logout", func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + + switch r.FormValue("auth-sid") { + case correctAPIKey: + _, _ = io.WriteString(w, "Status-Code: 0\nStatus-Text: OK\n") + case incorrectAPIKey: + _, _ = io.WriteString(w, "Status-Code: 2200\nStatus-Text: Authentication error") + default: + http.NotFound(w, r) + } + }) + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + client := NewClient(AuthInfo{APIKey: "12345"}) + client.BaseURL = serverURL + client.token = &Token{SessionID: test.authSid} + + response, err := client.Logout(mockContext(test.authSid)) + if test.expectedError { + require.Error(t, err) + } else { + require.NoError(t, err) + require.NotNil(t, response) + assert.Equal(t, test.expectedStatusCode, response.StatusCode) + } + }) + } +} + +func TestClient_CreateAuthenticatedContext(t *testing.T) { + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + id := atomic.Int32{} + id.Add(100) + + mux.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + + switch r.FormValue("username") { + case correctUsername: + _, _ = fmt.Fprintf(w, "Status-Code: 0\nStatus-Text: OK\nAuth-Sid: %d\n\ncom\nnet", id.Load()) + id.Add(100) + + default: + _, _ = io.WriteString(w, "Status-Code: 2200\nStatus-Text: Authentication error") + } + }) + + client := NewClient(AuthInfo{Username: correctUsername, Password: "secret"}) + client.HTTPClient = server.Client() + client.BaseURL = server.URL + + ctx, err := client.CreateAuthenticatedContext(context.Background()) + require.NoError(t, err) + + assert.Equal(t, "100", getSessionID(ctx)) + + // the token is not expired then we use the "cache". + client.muToken.Lock() + client.token.SessionID = "cache" + client.muToken.Unlock() + + ctx, err = client.CreateAuthenticatedContext(context.Background()) + require.NoError(t, err) + + assert.Equal(t, "cache", getSessionID(ctx)) + + // force the expiration of the token + client.muToken.Lock() + client.token.ExpireAt = time.Now().UTC().Add(-1 * time.Hour) + client.muToken.Unlock() + + ctx, err = client.CreateAuthenticatedContext(context.Background()) + require.NoError(t, err) + + assert.Equal(t, "200", getSessionID(ctx)) +} diff --git a/providers/dns/joker/internal/svc/client.go b/providers/dns/joker/internal/svc/client.go index 28b98432..6d3a54f9 100644 --- a/providers/dns/joker/internal/svc/client.go +++ b/providers/dns/joker/internal/svc/client.go @@ -3,11 +3,14 @@ package svc import ( + "context" "fmt" "io" "net/http" "strings" + "time" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" querystring "github.com/google/go-querystring/query" ) @@ -23,24 +26,24 @@ type request struct { } type Client struct { - HTTPClient *http.Client - BaseURL string - username string password string + + BaseURL string + HTTPClient *http.Client } func NewClient(username, password string) *Client { return &Client{ - HTTPClient: http.DefaultClient, - BaseURL: defaultBaseURL, username: username, password: password, + BaseURL: defaultBaseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, } } -func (c *Client) Send(zone, label, value string) error { - req := request{ +func (c *Client) SendRequest(ctx context.Context, zone, label, value string) error { + payload := request{ Username: c.username, Password: c.password, Zone: zone, @@ -49,24 +52,31 @@ func (c *Client) Send(zone, label, value string) error { Value: value, } - v, err := querystring.Values(req) + v, err := querystring.Values(payload) if err != nil { return err } - resp, err := c.HTTPClient.PostForm(c.BaseURL, v) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.BaseURL, strings.NewReader(v.Encode())) if err != nil { - return err + return fmt.Errorf("unable to create request: %w", err) } - all, err := io.ReadAll(resp.Body) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := c.HTTPClient.Do(req) if err != nil { - return err + return errutils.NewHTTPDoError(req, err) } - if resp.StatusCode == http.StatusOK && strings.HasPrefix(string(all), "OK") { + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + if resp.StatusCode == http.StatusOK && strings.HasPrefix(string(raw), "OK") { return nil } - return fmt.Errorf("error: %d: %s", resp.StatusCode, string(all)) + return fmt.Errorf("error: %d: %s", resp.StatusCode, string(raw)) } diff --git a/providers/dns/joker/internal/svc/client_test.go b/providers/dns/joker/internal/svc/client_test.go index b75139a2..6803ae84 100644 --- a/providers/dns/joker/internal/svc/client_test.go +++ b/providers/dns/joker/internal/svc/client_test.go @@ -1,6 +1,7 @@ package svc import ( + "context" "fmt" "io" "net/http" @@ -10,11 +11,23 @@ import ( "github.com/stretchr/testify/require" ) -func TestClient_Send(t *testing.T) { +func setupTest(t *testing.T) (*Client, *http.ServeMux) { + t.Helper() + mux := http.NewServeMux() server := httptest.NewServer(mux) t.Cleanup(server.Close) + client := NewClient("test", "secret") + client.BaseURL = server.URL + client.HTTPClient = server.Client() + + return client, mux +} + +func TestClient_Send(t *testing.T) { + client, mux := setupTest(t) + mux.HandleFunc("/", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodPost { http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusMethodNotAllowed) @@ -35,21 +48,16 @@ func TestClient_Send(t *testing.T) { } }) - client := NewClient("test", "secret") - client.BaseURL = server.URL - zone := "example.com" label := "_acme-challenge" value := "123" - err := client.Send(zone, label, value) + err := client.SendRequest(context.Background(), zone, label, value) require.NoError(t, err) } func TestClient_Send_empty(t *testing.T) { - mux := http.NewServeMux() - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + client, mux := setupTest(t) mux.HandleFunc("/", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodPost { @@ -71,13 +79,10 @@ func TestClient_Send_empty(t *testing.T) { } }) - client := NewClient("test", "secret") - client.BaseURL = server.URL - zone := "example.com" label := "_acme-challenge" value := "" - err := client.Send(zone, label, value) + err := client.SendRequest(context.Background(), zone, label, value) require.NoError(t, err) } diff --git a/providers/dns/joker/provider_dmapi.go b/providers/dns/joker/provider_dmapi.go index 35ef43bf..b33d7d48 100644 --- a/providers/dns/joker/provider_dmapi.go +++ b/providers/dns/joker/provider_dmapi.go @@ -1,6 +1,7 @@ package joker import ( + "context" "errors" "fmt" "time" @@ -77,7 +78,7 @@ func (d *dmapiProvider) Present(domain, token, keyAuth string) error { zone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("joker: %w", err) + return fmt.Errorf("joker: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, zone) @@ -89,19 +90,19 @@ func (d *dmapiProvider) Present(domain, token, keyAuth string) error { log.Infof("[%s] joker: adding TXT record %q to zone %q with value %q", domain, subDomain, zone, info.Value) } - response, err := d.client.Login() + ctx, err := d.client.CreateAuthenticatedContext(context.Background()) if err != nil { - return formatResponseError(response, err) + return err } - response, err = d.client.GetZone(zone) + response, err := d.client.GetZone(ctx, zone) if err != nil || response.StatusCode != 0 { return formatResponseError(response, err) } dnsZone := dmapi.AddTxtEntryToZone(response.Body, subDomain, info.Value, d.config.TTL) - response, err = d.client.PutZone(zone, dnsZone) + response, err = d.client.PutZone(ctx, zone, dnsZone) if err != nil || response.StatusCode != 0 { return formatResponseError(response, err) } @@ -115,7 +116,7 @@ func (d *dmapiProvider) CleanUp(domain, token, keyAuth string) error { zone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("joker: %w", err) + return fmt.Errorf("joker: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, zone) @@ -127,30 +128,30 @@ func (d *dmapiProvider) CleanUp(domain, token, keyAuth string) error { log.Infof("[%s] joker: removing entry %q from zone %q", domain, subDomain, zone) } - response, err := d.client.Login() + ctx, err := d.client.CreateAuthenticatedContext(context.Background()) if err != nil { - return formatResponseError(response, err) + return err } defer func() { // Try to log out in case of errors - _, _ = d.client.Logout() + _, _ = d.client.Logout(ctx) }() - response, err = d.client.GetZone(zone) + response, err := d.client.GetZone(ctx, zone) if err != nil || response.StatusCode != 0 { return formatResponseError(response, err) } dnsZone, modified := dmapi.RemoveTxtEntryFromZone(response.Body, subDomain) if modified { - response, err = d.client.PutZone(zone, dnsZone) + response, err = d.client.PutZone(ctx, zone, dnsZone) if err != nil || response.StatusCode != 0 { return formatResponseError(response, err) } } - response, err = d.client.Logout() + response, err = d.client.Logout(ctx) if err != nil { return formatResponseError(response, err) } diff --git a/providers/dns/joker/provider_svc.go b/providers/dns/joker/provider_svc.go index 68b64404..837a0503 100644 --- a/providers/dns/joker/provider_svc.go +++ b/providers/dns/joker/provider_svc.go @@ -1,6 +1,7 @@ package joker import ( + "context" "errors" "fmt" "time" @@ -58,7 +59,7 @@ func (d *svcProvider) Present(domain, token, keyAuth string) error { zone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("joker: %w", err) + return fmt.Errorf("joker: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, zone) @@ -66,7 +67,7 @@ func (d *svcProvider) Present(domain, token, keyAuth string) error { return fmt.Errorf("joker: %w", err) } - return d.client.Send(dns01.UnFqdn(zone), subDomain, info.Value) + return d.client.SendRequest(context.Background(), dns01.UnFqdn(zone), subDomain, info.Value) } // CleanUp removes the TXT record matching the specified parameters. @@ -75,7 +76,7 @@ func (d *svcProvider) CleanUp(domain, token, keyAuth string) error { zone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("joker: %w", err) + return fmt.Errorf("joker: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, zone) @@ -83,7 +84,7 @@ func (d *svcProvider) CleanUp(domain, token, keyAuth string) error { return fmt.Errorf("joker: %w", err) } - return d.client.Send(dns01.UnFqdn(zone), subDomain, "") + return d.client.SendRequest(context.Background(), dns01.UnFqdn(zone), subDomain, "") } // Sequential All DNS challenges for this provider will be resolved sequentially. diff --git a/providers/dns/liara/internal/client.go b/providers/dns/liara/internal/client.go index 56fe2e85..89794f04 100644 --- a/providers/dns/liara/internal/client.go +++ b/providers/dns/liara/internal/client.go @@ -2,170 +2,208 @@ package internal import ( "bytes" + "context" "encoding/json" "fmt" "io" "net/http" "net/url" "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" + "golang.org/x/oauth2" ) const defaultBaseURL = "https://dns-service.iran.liara.ir" // Client a Liara DNS API client. type Client struct { - apiKey string baseURL *url.URL - HTTPClient *http.Client + httpClient *http.Client } // NewClient creates a new Client. -func NewClient(apiKey string) *Client { +func NewClient(hc *http.Client) *Client { baseURL, _ := url.Parse(defaultBaseURL) - return &Client{ - apiKey: apiKey, - HTTPClient: &http.Client{Timeout: 10 * time.Second}, - baseURL: baseURL, + if hc == nil { + hc = &http.Client{Timeout: 10 * time.Second} } + + return &Client{httpClient: hc, baseURL: baseURL} } // GetRecords gets the records of a domain. // https://dns-service.iran.liara.ir/swagger -func (c Client) GetRecords(domainName string) ([]Record, error) { +func (c Client) GetRecords(ctx context.Context, domainName string) ([]Record, error) { endpoint := c.baseURL.JoinPath("api", "v1", "zones", domainName, "dns-records") - req, err := http.NewRequest(http.MethodGet, endpoint.String(), nil) + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, fmt.Errorf("create request: %w", err) } - req.Header.Set("Authorization", "Bearer "+c.apiKey) - - resp, err := c.HTTPClient.Do(req) + resp, err := c.httpClient.Do(req) if err != nil { - return nil, err + return nil, errutils.NewHTTPDoError(req, err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - return nil, readError(resp) + return nil, parseError(req, resp) } - var response RecordsResponse - err = json.NewDecoder(resp.Body).Decode(&response) + raw, err := io.ReadAll(resp.Body) if err != nil { - return nil, err + return nil, errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + var response Response[[]Record] + err = json.Unmarshal(raw, &response) + if err != nil { + return nil, errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) } return response.Data, nil } // CreateRecord creates a record. -func (c Client) CreateRecord(domainName string, record Record) (*Record, error) { +func (c Client) CreateRecord(ctx context.Context, domainName string, record Record) (*Record, error) { endpoint := c.baseURL.JoinPath("api", "v1", "zones", domainName, "dns-records") - body, err := json.Marshal(record) - if err != nil { - return nil, fmt.Errorf("marshal request data: %w", err) - } - - req, err := http.NewRequest(http.MethodPost, endpoint.String(), bytes.NewReader(body)) + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record) if err != nil { return nil, fmt.Errorf("create request: %w", err) } - req.Header.Set("Authorization", "Bearer "+c.apiKey) - req.Header.Set("Content-Type", "application/json") - - resp, err := c.HTTPClient.Do(req) + resp, err := c.httpClient.Do(req) if err != nil { - return nil, err + return nil, errutils.NewHTTPDoError(req, err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusCreated { - return nil, readError(resp) + return nil, parseError(req, resp) } - var response RecordResponse - err = json.NewDecoder(resp.Body).Decode(&response) + raw, err := io.ReadAll(resp.Body) if err != nil { - return nil, err + return nil, errutils.NewReadResponseError(req, resp.StatusCode, err) } - return &response.Data, nil + var response Response[*Record] + err = json.Unmarshal(raw, &response) + if err != nil { + return nil, errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return response.Data, nil } // GetRecord gets a specific record. -func (c Client) GetRecord(domainName, recordID string) (*Record, error) { +func (c Client) GetRecord(ctx context.Context, domainName, recordID string) (*Record, error) { endpoint := c.baseURL.JoinPath("api", "v1", "zones", domainName, "dns-records", recordID) - req, err := http.NewRequest(http.MethodGet, endpoint.String(), nil) + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, fmt.Errorf("create request: %w", err) } - req.Header.Set("Authorization", "Bearer "+c.apiKey) - - resp, err := c.HTTPClient.Do(req) + resp, err := c.httpClient.Do(req) if err != nil { - return nil, err + return nil, errutils.NewHTTPDoError(req, err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - return nil, readError(resp) + return nil, parseError(req, resp) } - var response RecordResponse - err = json.NewDecoder(resp.Body).Decode(&response) + raw, err := io.ReadAll(resp.Body) if err != nil { - return nil, err + return nil, errutils.NewReadResponseError(req, resp.StatusCode, err) } - return &response.Data, nil + var response Response[*Record] + err = json.Unmarshal(raw, &response) + if err != nil { + return nil, errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return response.Data, nil } // DeleteRecord deletes a record. -func (c Client) DeleteRecord(domainName, recordID string) error { +func (c Client) DeleteRecord(ctx context.Context, domainName, recordID string) error { endpoint := c.baseURL.JoinPath("api", "v1", "zones", domainName, "dns-records", recordID) - req, err := http.NewRequest(http.MethodDelete, endpoint.String(), nil) + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) if err != nil { return fmt.Errorf("create request: %w", err) } - req.Header.Set("Authorization", "Bearer "+c.apiKey) - - resp, err := c.HTTPClient.Do(req) + resp, err := c.httpClient.Do(req) if err != nil { - return err + return errutils.NewHTTPDoError(req, err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusNotFound { - return readError(resp) + return parseError(req, resp) } return nil } -func readError(resp *http.Response) error { - all, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("API error (status code: %d)", resp.StatusCode) +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } } - var apiError APIError - err = json.Unmarshal(all, &apiError) + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) if err != nil { - return fmt.Errorf("API error (status code: %d): %s", resp.StatusCode, string(all)) + return nil, fmt.Errorf("unable to create request: %w", err) } - return fmt.Errorf("API error (status code: %d): %w", resp.StatusCode, &apiError) + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil +} + +func parseError(req *http.Request, resp *http.Response) error { + raw, _ := io.ReadAll(resp.Body) + + var errAPI APIError + err := json.Unmarshal(raw, &errAPI) + if err != nil { + return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw) + } + + return fmt.Errorf("[status code: %d] %w", resp.StatusCode, &errAPI) +} + +func OAuthStaticAccessToken(client *http.Client, accessToken string) *http.Client { + if client == nil { + client = &http.Client{Timeout: 5 * time.Second} + } + + client.Transport = &oauth2.Transport{ + Source: oauth2.StaticTokenSource(&oauth2.Token{AccessToken: accessToken}), + Base: client.Transport, + } + + return client } diff --git a/providers/dns/liara/internal/client_test.go b/providers/dns/liara/internal/client_test.go index f083af4e..ed6672ab 100644 --- a/providers/dns/liara/internal/client_test.go +++ b/providers/dns/liara/internal/client_test.go @@ -1,6 +1,7 @@ package internal import ( + "context" "fmt" "io" "net/http" @@ -17,11 +18,11 @@ import ( const apiKey = "key" func TestClient_GetRecords(t *testing.T) { - client, mux := setup(t) + client, mux := setupTest(t) mux.HandleFunc("/api/v1/zones/example.com/dns-records", testHandler("./RecordsResponse.json", http.MethodGet, http.StatusOK)) - records, err := client.GetRecords("example.com") + records, err := client.GetRecords(context.Background(), "example.com") require.NoError(t, err) expected := []Record{ @@ -41,11 +42,11 @@ func TestClient_GetRecords(t *testing.T) { } func TestClient_GetRecord(t *testing.T) { - client, mux := setup(t) + client, mux := setupTest(t) mux.HandleFunc("/api/v1/zones/example.com/dns-records/123", testHandler("./RecordResponse.json", http.MethodGet, http.StatusOK)) - record, err := client.GetRecord("example.com", "123") + record, err := client.GetRecord(context.Background(), "example.com", "123") require.NoError(t, err) expected := &Record{ @@ -63,7 +64,7 @@ func TestClient_GetRecord(t *testing.T) { } func TestClient_CreateRecord(t *testing.T) { - client, mux := setup(t) + client, mux := setupTest(t) mux.HandleFunc("/api/v1/zones/example.com/dns-records", testHandler("./RecordResponse.json", http.MethodPost, http.StatusCreated)) @@ -78,7 +79,7 @@ func TestClient_CreateRecord(t *testing.T) { TTL: 3600, } - record, err := client.CreateRecord("example.com", data) + record, err := client.CreateRecord(context.Background(), "example.com", data) require.NoError(t, err) expected := &Record{ @@ -97,33 +98,33 @@ func TestClient_CreateRecord(t *testing.T) { } func TestClient_DeleteRecord(t *testing.T) { - client, mux := setup(t) + client, mux := setupTest(t) mux.HandleFunc("/api/v1/zones/example.com/dns-records/123", func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusNoContent) }) - err := client.DeleteRecord("example.com", "123") + err := client.DeleteRecord(context.Background(), "example.com", "123") require.NoError(t, err) } func TestClient_DeleteRecord_NotFound_Response(t *testing.T) { - client, mux := setup(t) + client, mux := setupTest(t) mux.HandleFunc("/api/v1/zones/example.com/dns-records/123", func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusNotFound) }) - err := client.DeleteRecord("example.com", "123") + err := client.DeleteRecord(context.Background(), "example.com", "123") require.NoError(t, err) } func TestClient_DeleteRecord_error(t *testing.T) { - client, mux := setup(t) + client, mux := setupTest(t) mux.HandleFunc("/api/v1/zones/example.com/dns-records/123", testHandler("./error.json", http.MethodDelete, http.StatusUnauthorized)) - err := client.DeleteRecord("example.com", "123") + err := client.DeleteRecord(context.Background(), "example.com", "123") require.Error(t, err) } @@ -158,16 +159,14 @@ func testHandler(filename string, method string, statusCode int) http.HandlerFun } } -func setup(t *testing.T) (*Client, *http.ServeMux) { +func setupTest(t *testing.T) (*Client, *http.ServeMux) { t.Helper() mux := http.NewServeMux() - server := httptest.NewServer(mux) t.Cleanup(server.Close) - client := NewClient(apiKey) - client.HTTPClient = server.Client() + client := NewClient(OAuthStaticAccessToken(server.Client(), apiKey)) client.baseURL, _ = url.Parse(server.URL) return client, mux diff --git a/providers/dns/liara/internal/types.go b/providers/dns/liara/internal/types.go index 0b817c24..34ae2c2c 100644 --- a/providers/dns/liara/internal/types.go +++ b/providers/dns/liara/internal/types.go @@ -14,14 +14,9 @@ type Record struct { Contents []Content `json:"contents"` } -type RecordResponse struct { +type Response[D any] struct { Status string `json:"status"` - Data Record `json:"data"` -} - -type RecordsResponse struct { - Status string `json:"status"` - Data []Record `json:"data"` + Data D `json:"data"` } type APIError struct { diff --git a/providers/dns/liara/liara.go b/providers/dns/liara/liara.go index 053b5a1d..27d3e600 100644 --- a/providers/dns/liara/liara.go +++ b/providers/dns/liara/liara.go @@ -2,6 +2,7 @@ package liara import ( + "context" "errors" "fmt" "net/http" @@ -94,8 +95,6 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, fmt.Errorf("liara: invalid TTL, TTL (%d) must be lower than %d", config.TTL, maxTTL) } - client := internal.NewClient(config.APIKey) - retryClient := retryablehttp.NewClient() retryClient.RetryMax = 5 if config.HTTPClient != nil { @@ -103,7 +102,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { } retryClient.Logger = log.Logger - client.HTTPClient = retryClient.StandardClient() + client := internal.NewClient(internal.OAuthStaticAccessToken(retryClient.StandardClient(), config.APIKey)) return &DNSProvider{ config: config, @@ -124,7 +123,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("liara: %w", err) + return fmt.Errorf("liara: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) @@ -138,7 +137,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { Contents: []internal.Content{{Text: info.Value}}, TTL: d.config.TTL, } - newRecord, err := d.client.CreateRecord(dns01.UnFqdn(authZone), record) + newRecord, err := d.client.CreateRecord(context.Background(), dns01.UnFqdn(authZone), record) if err != nil { return fmt.Errorf("liara: failed to create TXT record, fqdn=%s: %w", info.EffectiveFQDN, err) } @@ -156,7 +155,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("liara: %w", err) + return fmt.Errorf("liara: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } // gets the record's unique ID @@ -167,7 +166,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return fmt.Errorf("liara: unknown record ID for '%s' '%s'", info.EffectiveFQDN, token) } - err = d.client.DeleteRecord(dns01.UnFqdn(authZone), recordID) + err = d.client.DeleteRecord(context.Background(), dns01.UnFqdn(authZone), recordID) if err != nil { return fmt.Errorf("liara: failed to delete TXT record, id=%s: %w", recordID, err) } diff --git a/providers/dns/linode/linode.go b/providers/dns/linode/linode.go index 8569cfac..4143f03e 100644 --- a/providers/dns/linode/linode.go +++ b/providers/dns/linode/linode.go @@ -91,21 +91,17 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, fmt.Errorf("linode: invalid TTL, TTL (%d) must be greater than %d", config.TTL, minTTL) } - tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: config.Token}) oauth2Client := &http.Client{ Timeout: config.HTTPTimeout, Transport: &oauth2.Transport{ - Source: tokenSource, + Source: oauth2.StaticTokenSource(&oauth2.Token{AccessToken: config.Token}), }, } client := linodego.NewClient(oauth2Client) - client.SetUserAgent("lego-dns https://github.com/linode/linodego") + client.SetUserAgent("go-acme/lego https://github.com/linode/linodego") - return &DNSProvider{ - config: config, - client: &client, - }, nil + return &DNSProvider{config: config, client: &client}, nil } // Timeout returns the timeout and interval to use when checking for DNS @@ -158,7 +154,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { } // Get all TXT records for the specified domain. - listOpts := linodego.NewListOptions(0, "{\"type\":\"TXT\"}") + listOpts := linodego.NewListOptions(0, `{"type":"TXT"}`) resources, err := d.client.ListDomainRecords(context.Background(), zone.domainID, listOpts) if err != nil { return err @@ -181,16 +177,16 @@ func (d *DNSProvider) getHostedZoneInfo(fqdn string) (*hostedZoneInfo, error) { // Lookup the zone that handles the specified FQDN. authZone, err := dns01.FindZoneByFqdn(fqdn) if err != nil { - return nil, err + return nil, fmt.Errorf("inwx: could not find zone for FQDN %q: %w", fqdn, err) } // Query the authority zone. - data, err := json.Marshal(map[string]string{"domain": dns01.UnFqdn(authZone)}) + filter, err := json.Marshal(map[string]string{"domain": dns01.UnFqdn(authZone)}) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create JSON filter: %w", err) } - listOpts := linodego.NewListOptions(0, string(data)) + listOpts := linodego.NewListOptions(0, string(filter)) domains, err := d.client.ListDomains(context.Background(), listOpts) if err != nil { return nil, err diff --git a/providers/dns/loopia/internal/client.go b/providers/dns/loopia/internal/client.go index 013e5a99..d521ffee 100644 --- a/providers/dns/loopia/internal/client.go +++ b/providers/dns/loopia/internal/client.go @@ -2,6 +2,7 @@ package internal import ( "bytes" + "context" "encoding/xml" "errors" "fmt" @@ -9,6 +10,8 @@ import ( "net/http" "strings" "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) // DefaultBaseURL is url to the XML-RPC api. @@ -16,29 +19,30 @@ const DefaultBaseURL = "https://api.loopia.se/RPCSERV" // Client the Loopia client. type Client struct { - APIUser string - APIPassword string - BaseURL string - HTTPClient *http.Client + apiUser string + apiPassword string + + BaseURL string + HTTPClient *http.Client } // NewClient creates a new Loopia Client. func NewClient(apiUser, apiPassword string) *Client { return &Client{ - APIUser: apiUser, - APIPassword: apiPassword, + apiUser: apiUser, + apiPassword: apiPassword, BaseURL: DefaultBaseURL, HTTPClient: &http.Client{Timeout: 10 * time.Second}, } } // AddTXTRecord adds a TXT record. -func (c *Client) AddTXTRecord(domain string, subdomain string, ttl int, value string) error { +func (c *Client) AddTXTRecord(ctx context.Context, domain string, subdomain string, ttl int, value string) error { call := &methodCall{ MethodName: "addZoneRecord", Params: []param{ - paramString{Value: c.APIUser}, - paramString{Value: c.APIPassword}, + paramString{Value: c.apiUser}, + paramString{Value: c.apiPassword}, paramString{Value: domain}, paramString{Value: subdomain}, paramStruct{ @@ -54,7 +58,7 @@ func (c *Client) AddTXTRecord(domain string, subdomain string, ttl int, value st } resp := &responseString{} - err := c.rpcCall(call, resp) + err := c.rpcCall(ctx, call, resp) if err != nil { return err } @@ -63,12 +67,12 @@ func (c *Client) AddTXTRecord(domain string, subdomain string, ttl int, value st } // RemoveTXTRecord removes a TXT record. -func (c *Client) RemoveTXTRecord(domain string, subdomain string, recordID int) error { +func (c *Client) RemoveTXTRecord(ctx context.Context, domain string, subdomain string, recordID int) error { call := &methodCall{ MethodName: "removeZoneRecord", Params: []param{ - paramString{Value: c.APIUser}, - paramString{Value: c.APIPassword}, + paramString{Value: c.apiUser}, + paramString{Value: c.apiPassword}, paramString{Value: domain}, paramString{Value: subdomain}, paramInt{Value: recordID}, @@ -76,7 +80,7 @@ func (c *Client) RemoveTXTRecord(domain string, subdomain string, recordID int) } resp := &responseString{} - err := c.rpcCall(call, resp) + err := c.rpcCall(ctx, call, resp) if err != nil { return err } @@ -85,37 +89,37 @@ func (c *Client) RemoveTXTRecord(domain string, subdomain string, recordID int) } // GetTXTRecords gets TXT records. -func (c *Client) GetTXTRecords(domain string, subdomain string) ([]RecordObj, error) { +func (c *Client) GetTXTRecords(ctx context.Context, domain string, subdomain string) ([]RecordObj, error) { call := &methodCall{ MethodName: "getZoneRecords", Params: []param{ - paramString{Value: c.APIUser}, - paramString{Value: c.APIPassword}, + paramString{Value: c.apiUser}, + paramString{Value: c.apiPassword}, paramString{Value: domain}, paramString{Value: subdomain}, }, } resp := &recordObjectsResponse{} - err := c.rpcCall(call, resp) + err := c.rpcCall(ctx, call, resp) return resp.Params, err } // RemoveSubdomain remove a sub-domain. -func (c *Client) RemoveSubdomain(domain, subdomain string) error { +func (c *Client) RemoveSubdomain(ctx context.Context, domain, subdomain string) error { call := &methodCall{ MethodName: "removeSubdomain", Params: []param{ - paramString{Value: c.APIUser}, - paramString{Value: c.APIPassword}, + paramString{Value: c.apiUser}, + paramString{Value: c.apiPassword}, paramString{Value: domain}, paramString{Value: subdomain}, }, } resp := &responseString{} - err := c.rpcCall(call, resp) + err := c.rpcCall(ctx, call, resp) if err != nil { return err } @@ -123,55 +127,66 @@ func (c *Client) RemoveSubdomain(domain, subdomain string) error { return checkResponse(resp.Value) } -// rpcCall makes an XML-RPC call to Loopia's RPC endpoint -// by marshaling the data given in the call argument to XML and sending that via HTTP Post to Loopia. +// rpcCall makes an XML-RPC call to Loopia's RPC endpoint by marshaling the data given in the call argument to XML +// and sending that via HTTP Post to Loopia. // The response is then unmarshalled into the resp argument. -func (c *Client) rpcCall(call *methodCall, resp response) error { - body, err := xml.MarshalIndent(call, "", " ") - if err != nil { - return fmt.Errorf("error during unmarshalling the request body: %w", err) - } - - body = append([]byte(``+"\n"), body...) - - respBody, err := c.httpPost(c.BaseURL, "text/xml", bytes.NewReader(body)) +func (c *Client) rpcCall(ctx context.Context, call *methodCall, result response) error { + req, err := newXMLRequest(ctx, c.BaseURL, call) if err != nil { return err } - err = xml.Unmarshal(respBody, resp) + resp, err := c.HTTPClient.Do(req) if err != nil { - return fmt.Errorf("error during unmarshalling the response body: %w", err) + return errutils.NewHTTPDoError(req, err) } - if resp.faultCode() != 0 { - return rpcError{ - faultCode: resp.faultCode(), - faultString: strings.TrimSpace(resp.faultString()), + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = xml.Unmarshal(raw, result) + if err != nil { + return fmt.Errorf("unmarshal error: %w", err) + } + + if result.faultCode() != 0 { + return RPCError{ + FaultCode: result.faultCode(), + FaultString: strings.TrimSpace(result.faultString()), } } return nil } -func (c *Client) httpPost(url string, bodyType string, body io.Reader) ([]byte, error) { - resp, err := c.HTTPClient.Post(url, bodyType, body) +func newXMLRequest(ctx context.Context, endpoint string, payload any) (*http.Request, error) { + body := new(bytes.Buffer) + body.WriteString(xml.Header) + + encoder := xml.NewEncoder(body) + encoder.Indent("", " ") + + err := encoder.Encode(payload) if err != nil { - return nil, fmt.Errorf("HTTP Post Error: %w", err) + return nil, err } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("HTTP Post Error: %d", resp.StatusCode) - } - - b, err := io.ReadAll(resp.Body) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, body) if err != nil { - return nil, fmt.Errorf("HTTP Post Error: %w", err) + return nil, fmt.Errorf("unable to create request: %w", err) } - return b, nil + req.Header.Set("Content-Type", "text/xml") + + return req, nil } func checkResponse(value string) error { diff --git a/providers/dns/loopia/internal/client_test.go b/providers/dns/loopia/internal/client_test.go index 67758177..e62fc2b6 100644 --- a/providers/dns/loopia/internal/client_test.go +++ b/providers/dns/loopia/internal/client_test.go @@ -1,6 +1,7 @@ package internal import ( + "context" "encoding/xml" "fmt" "io" @@ -49,7 +50,7 @@ func TestClient_AddZoneRecord(t *testing.T) { desc: "empty response", password: "goodpassword", domain: "empty.com", - err: "error during unmarshalling the response body: EOF", + err: "unmarshal error: EOF", }, } @@ -58,7 +59,7 @@ func TestClient_AddZoneRecord(t *testing.T) { client := NewClient("apiuser", test.password) client.BaseURL = serverURL + "/" - err := client.AddTXTRecord(test.domain, exampleSubDomain, 123, "TXTrecord") + err := client.AddTXTRecord(context.Background(), test.domain, exampleSubDomain, 123, "TXTrecord") if test.err == "" { require.NoError(t, err) } else { @@ -106,7 +107,7 @@ func TestClient_RemoveSubdomain(t *testing.T) { desc: "empty response", password: "goodpassword", domain: "empty.com", - err: "error during unmarshalling the response body: EOF", + err: "unmarshal error: EOF", }, } @@ -115,7 +116,7 @@ func TestClient_RemoveSubdomain(t *testing.T) { client := NewClient("apiuser", test.password) client.BaseURL = serverURL + "/" - err := client.RemoveSubdomain(test.domain, exampleSubDomain) + err := client.RemoveSubdomain(context.Background(), test.domain, exampleSubDomain) if test.err == "" { require.NoError(t, err) } else { @@ -163,7 +164,7 @@ func TestClient_RemoveZoneRecord(t *testing.T) { desc: "empty response", password: "goodpassword", domain: "empty.com", - err: "error during unmarshalling the response body: EOF", + err: "unmarshal error: EOF", }, } @@ -172,7 +173,7 @@ func TestClient_RemoveZoneRecord(t *testing.T) { client := NewClient("apiuser", test.password) client.BaseURL = serverURL + "/" - err := client.RemoveTXTRecord(test.domain, exampleSubDomain, 12345678) + err := client.RemoveTXTRecord(context.Background(), test.domain, exampleSubDomain, 12345678) if test.err == "" { require.NoError(t, err) } else { @@ -193,7 +194,7 @@ func TestClient_GetZoneRecord(t *testing.T) { client := NewClient("apiuser", "goodpassword") client.BaseURL = serverURL + "/" - recordObjs, err := client.GetTXTRecords(exampleDomain, exampleSubDomain) + recordObjs, err := client.GetTXTRecords(context.Background(), exampleDomain, exampleSubDomain) require.NoError(t, err) expected := []RecordObj{ @@ -237,8 +238,8 @@ func TestClient_rpcCall_404(t *testing.T) { client := NewClient("apiuser", "apipassword") client.BaseURL = server.URL + "/" - err := client.rpcCall(call, &responseString{}) - assert.EqualError(t, err, "HTTP Post Error: 404") + err := client.rpcCall(context.Background(), call, &responseString{}) + assert.EqualError(t, err, "unexpected status code: [status code: 404] body: ") } func TestClient_rpcCall_RPCError(t *testing.T) { @@ -268,7 +269,7 @@ func TestClient_rpcCall_RPCError(t *testing.T) { client := NewClient("apiuser", "apipassword") client.BaseURL = server.URL + "/" - err := client.rpcCall(call, &responseString{}) + err := client.rpcCall(context.Background(), call, &responseString{}) assert.EqualError(t, err, "RPC Error: (201) Method signature error: 42") } diff --git a/providers/dns/loopia/internal/mock_test.go b/providers/dns/loopia/internal/mock_test.go index 13b9970b..7896b520 100644 --- a/providers/dns/loopia/internal/mock_test.go +++ b/providers/dns/loopia/internal/mock_test.go @@ -6,7 +6,7 @@ const ( exampleRdata = "LHDhK3oGRvkiefQnx7OOczTY5Tic_xZ6HcMOc_gmtoM" ) -// Testdata based on real traffic between an xml-rpc client and the api. +// Testdata based on real traffic between a xml-rpc client and the api. const responseOk = ` @@ -76,7 +76,7 @@ const responseRPCError = ` ` -const addZoneRecordGoodAuth = ` +const addZoneRecordGoodAuth = ` addZoneRecord @@ -139,7 +139,7 @@ const addZoneRecordGoodAuth = ` ` -const addZoneRecordBadAuth = ` +const addZoneRecordBadAuth = ` addZoneRecord @@ -202,7 +202,7 @@ const addZoneRecordBadAuth = ` ` -const addZoneRecordNonValidDomain = ` +const addZoneRecordNonValidDomain = ` addZoneRecord @@ -265,7 +265,7 @@ const addZoneRecordNonValidDomain = ` ` -const addZoneRecordEmptyResponse = ` +const addZoneRecordEmptyResponse = ` addZoneRecord @@ -328,7 +328,7 @@ const addZoneRecordEmptyResponse = ` ` -const getZoneRecords = ` +const getZoneRecords = ` getZoneRecords @@ -423,7 +423,7 @@ const getZoneRecordsResponse = ` ` -const removeRecordGoodAuth = ` +const removeRecordGoodAuth = ` removeZoneRecord @@ -455,7 +455,7 @@ const removeRecordGoodAuth = ` ` -const removeRecordBadAuth = ` +const removeRecordBadAuth = ` removeZoneRecord @@ -487,7 +487,7 @@ const removeRecordBadAuth = ` ` -const removeRecordNonValidDomain = ` +const removeRecordNonValidDomain = ` removeZoneRecord @@ -519,7 +519,7 @@ const removeRecordNonValidDomain = ` ` -const removeRecordEmptyResponse = ` +const removeRecordEmptyResponse = ` removeZoneRecord @@ -551,7 +551,7 @@ const removeRecordEmptyResponse = ` ` -const removeSubdomainGoodAuth = ` +const removeSubdomainGoodAuth = ` removeSubdomain @@ -578,7 +578,7 @@ const removeSubdomainGoodAuth = ` ` -const removeSubdomainBadAuth = ` +const removeSubdomainBadAuth = ` removeSubdomain @@ -605,7 +605,7 @@ const removeSubdomainBadAuth = ` ` -const removeSubdomainNonValidDomain = ` +const removeSubdomainNonValidDomain = ` removeSubdomain @@ -632,7 +632,7 @@ const removeSubdomainNonValidDomain = ` ` -const removeSubdomainEmptyResponse = ` +const removeSubdomainEmptyResponse = ` removeSubdomain diff --git a/providers/dns/loopia/internal/types.go b/providers/dns/loopia/internal/types.go index 9d96da40..c286c01f 100644 --- a/providers/dns/loopia/internal/types.go +++ b/providers/dns/loopia/internal/types.go @@ -77,13 +77,13 @@ type responseFault struct { func (r responseFault) faultCode() int { return r.FaultCode } func (r responseFault) faultString() string { return r.FaultString } -type rpcError struct { - faultCode int - faultString string +type RPCError struct { + FaultCode int + FaultString string } -func (e rpcError) Error() string { - return fmt.Sprintf("RPC Error: (%d) %s", e.faultCode, e.faultString) +func (e RPCError) Error() string { + return fmt.Sprintf("RPC Error: (%d) %s", e.FaultCode, e.FaultString) } type recordObjectsResponse struct { diff --git a/providers/dns/loopia/loopia.go b/providers/dns/loopia/loopia.go index 579a3efc..ed0fd02d 100644 --- a/providers/dns/loopia/loopia.go +++ b/providers/dns/loopia/loopia.go @@ -2,6 +2,7 @@ package loopia import ( + "context" "errors" "fmt" "net/http" @@ -30,10 +31,10 @@ const ( ) type dnsClient interface { - AddTXTRecord(domain string, subdomain string, ttl int, value string) error - RemoveTXTRecord(domain string, subdomain string, recordID int) error - GetTXTRecords(domain string, subdomain string) ([]internal.RecordObj, error) - RemoveSubdomain(domain, subdomain string) error + AddTXTRecord(ctx context.Context, domain string, subdomain string, ttl int, value string) error + RemoveTXTRecord(ctx context.Context, domain string, subdomain string, recordID int) error + GetTXTRecords(ctx context.Context, domain string, subdomain string) ([]internal.RecordObj, error) + RemoveSubdomain(ctx context.Context, domain, subdomain string) error } // Config is used to configure the creation of the DNSProvider. @@ -67,6 +68,7 @@ type DNSProvider struct { inProgressInfo map[string]int inProgressMu sync.Mutex + // only for testing purpose. findZoneByFqdn func(fqdn string) (string, error) } @@ -135,12 +137,14 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { return fmt.Errorf("loopia: %w", err) } - err = d.client.AddTXTRecord(authZone, subDomain, d.config.TTL, info.Value) + ctx := context.Background() + + err = d.client.AddTXTRecord(ctx, authZone, subDomain, d.config.TTL, info.Value) if err != nil { return fmt.Errorf("loopia: failed to add TXT record: %w", err) } - txtRecords, err := d.client.GetTXTRecords(authZone, subDomain) + txtRecords, err := d.client.GetTXTRecords(ctx, authZone, subDomain) if err != nil { return fmt.Errorf("loopia: failed to get TXT records: %w", err) } @@ -170,12 +174,14 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { d.inProgressMu.Lock() defer d.inProgressMu.Unlock() - err = d.client.RemoveTXTRecord(authZone, subDomain, d.inProgressInfo[token]) + ctx := context.Background() + + err = d.client.RemoveTXTRecord(ctx, authZone, subDomain, d.inProgressInfo[token]) if err != nil { return fmt.Errorf("loopia: failed to remove TXT record: %w", err) } - records, err := d.client.GetTXTRecords(authZone, subDomain) + records, err := d.client.GetTXTRecords(ctx, authZone, subDomain) if err != nil { return fmt.Errorf("loopia: failed to get TXT records: %w", err) } @@ -184,7 +190,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return nil } - err = d.client.RemoveSubdomain(authZone, subDomain) + err = d.client.RemoveSubdomain(ctx, authZone, subDomain) if err != nil { return fmt.Errorf("loopia: failed to remove sub-domain: %w", err) } @@ -193,13 +199,15 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { } func (d *DNSProvider) splitDomain(fqdn string) (string, string, error) { - authZone, _ := d.findZoneByFqdn(fqdn) - authZone = dns01.UnFqdn(authZone) + authZone, err := d.findZoneByFqdn(fqdn) + if err != nil { + return "", "", fmt.Errorf("desec: could not find zone for FQDN %q: %w", fqdn, err) + } subDomain, err := dns01.ExtractSubDomain(fqdn, authZone) if err != nil { return "", "", err } - return subDomain, authZone, nil + return subDomain, dns01.UnFqdn(authZone), nil } diff --git a/providers/dns/loopia/loopia_mock_test.go b/providers/dns/loopia/loopia_mock_test.go index b8f108fa..79fe2b13 100644 --- a/providers/dns/loopia/loopia_mock_test.go +++ b/providers/dns/loopia/loopia_mock_test.go @@ -1,6 +1,7 @@ package loopia import ( + "context" "errors" "fmt" "testing" @@ -215,22 +216,22 @@ type mockedClient struct { mock.Mock } -func (c *mockedClient) RemoveTXTRecord(domain string, subdomain string, recordID int) error { +func (c *mockedClient) RemoveTXTRecord(ctx context.Context, domain string, subdomain string, recordID int) error { args := c.Called(domain, subdomain, recordID) return args.Error(0) } -func (c *mockedClient) AddTXTRecord(domain string, subdomain string, ttl int, value string) error { +func (c *mockedClient) AddTXTRecord(ctx context.Context, domain string, subdomain string, ttl int, value string) error { args := c.Called(domain, subdomain, ttl, value) return args.Error(0) } -func (c *mockedClient) GetTXTRecords(domain string, subdomain string) ([]internal.RecordObj, error) { +func (c *mockedClient) GetTXTRecords(ctx context.Context, domain string, subdomain string) ([]internal.RecordObj, error) { args := c.Called(domain, subdomain) return args.Get(0).([]internal.RecordObj), args.Error(1) } -func (c *mockedClient) RemoveSubdomain(domain, subdomain string) error { +func (c *mockedClient) RemoveSubdomain(ctx context.Context, domain, subdomain string) error { args := c.Called(domain, subdomain) return args.Error(0) } diff --git a/providers/dns/luadns/internal/client.go b/providers/dns/luadns/internal/client.go index 6f853cc2..8e46418f 100644 --- a/providers/dns/luadns/internal/client.go +++ b/providers/dns/luadns/internal/client.go @@ -2,10 +2,16 @@ package internal import ( "bytes" + "context" "encoding/json" "fmt" "io" "net/http" + "net/url" + "strconv" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) // defaultBaseURL represents the API endpoint to call. @@ -13,49 +19,39 @@ const defaultBaseURL = "https://api.luadns.com" // Client Lua DNS API client. type Client struct { - HTTPClient *http.Client - BaseURL string - apiUsername string apiToken string + + baseURL *url.URL + HTTPClient *http.Client } // NewClient creates a new Client. func NewClient(apiUsername, apiToken string) *Client { + baseURL, _ := url.Parse(defaultBaseURL) + return &Client{ - HTTPClient: http.DefaultClient, - BaseURL: defaultBaseURL, apiUsername: apiUsername, apiToken: apiToken, + baseURL: baseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, } } // ListZones gets all the hosted zones. // https://luadns.com/api.html#list-zones -func (d *Client) ListZones() ([]DNSZone, error) { - resp, err := d.do(http.MethodGet, "/v1/zones", nil) +func (c *Client) ListZones(ctx context.Context) ([]DNSZone, error) { + endpoint := c.baseURL.JoinPath("v1", "zones") + + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, err } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - - var errResp errorResponse - err = json.Unmarshal(bodyBytes, &errResp) - if err == nil { - return nil, fmt.Errorf("api call error: Status=%v: %w", resp.StatusCode, errResp) - } - - return nil, fmt.Errorf("api call error: Status=%d: %s", resp.StatusCode, string(bodyBytes)) - } - var zones []DNSZone - err = json.NewDecoder(resp.Body).Decode(&zones) + err = c.do(req, &zones) if err != nil { - return nil, fmt.Errorf("failed to unmarshal response body: %w", err) + return nil, fmt.Errorf("could not list zones: %w", err) } return zones, nil @@ -63,39 +59,18 @@ func (d *Client) ListZones() ([]DNSZone, error) { // CreateRecord creates a new record in a zone. // https://luadns.com/api.html#create-a-record -func (d *Client) CreateRecord(zone DNSZone, newRecord DNSRecord) (*DNSRecord, error) { - body, err := json.Marshal(newRecord) - if err != nil { - return nil, fmt.Errorf("failed to marshal request body: %w", err) - } +func (c *Client) CreateRecord(ctx context.Context, zone DNSZone, newRecord DNSRecord) (*DNSRecord, error) { + endpoint := c.baseURL.JoinPath("v1", "zones", strconv.Itoa(zone.ID), "records") - resource := fmt.Sprintf("/v1/zones/%d/records", zone.ID) - - resp, err := d.do(http.MethodPost, resource, bytes.NewReader(body)) + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, newRecord) if err != nil { return nil, err } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - - var errResp errorResponse - err = json.Unmarshal(bodyBytes, &errResp) - if err == nil { - return nil, fmt.Errorf("could not create record %v: Status=%d: %w", - string(body), resp.StatusCode, errResp) - } - - return nil, fmt.Errorf("could not create record %v: Status=%d: %s", - string(body), resp.StatusCode, string(bodyBytes)) - } - var record *DNSRecord - err = json.NewDecoder(resp.Body).Decode(&record) + err = c.do(req, &record) if err != nil { - return nil, fmt.Errorf("failed to unmarshal response body: %w", err) + return nil, fmt.Errorf("could not create record %#v: %w", record, err) } return record, nil @@ -103,47 +78,85 @@ func (d *Client) CreateRecord(zone DNSZone, newRecord DNSRecord) (*DNSRecord, er // DeleteRecord deletes a record. // https://luadns.com/api.html#delete-a-record -func (d *Client) DeleteRecord(record *DNSRecord) error { - body, err := json.Marshal(record) - if err != nil { - return fmt.Errorf("failed to marshal request body: %w", err) - } +func (c *Client) DeleteRecord(ctx context.Context, record *DNSRecord) error { + endpoint := c.baseURL.JoinPath("v1", "zones", strconv.Itoa(record.ZoneID), "records", strconv.Itoa(record.ID)) - resource := fmt.Sprintf("/v1/zones/%d/records/%d", record.ZoneID, record.ID) - - resp, err := d.do(http.MethodDelete, resource, bytes.NewReader(body)) + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, record) if err != nil { return err } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - - var errResp errorResponse - err = json.Unmarshal(bodyBytes, &errResp) - if err == nil { - return fmt.Errorf("could not delete record %v: Status=%d: %w", - string(body), resp.StatusCode, errResp) - } - - return fmt.Errorf("could not delete record %v: Status=%d: %s", - string(body), resp.StatusCode, string(bodyBytes)) + err = c.do(req, nil) + if err != nil { + return fmt.Errorf("could not delete record %#v: %w", record, err) } return nil } -func (d *Client) do(method, uri string, body io.Reader) (*http.Response, error) { - req, err := http.NewRequest(method, fmt.Sprintf("%s%s", d.BaseURL, uri), body) +func (c *Client) do(req *http.Request, result any) error { + req.SetBasicAuth(c.apiUsername, c.apiToken) + + resp, err := c.HTTPClient.Do(req) if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return parseError(req, resp) + } + + if result == nil { + return nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return nil +} + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) } req.Header.Set("Accept", "application/json") - req.Header.Set("Content-Type", "application/json") - req.SetBasicAuth(d.apiUsername, d.apiToken) - return d.HTTPClient.Do(req) + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil +} + +func parseError(req *http.Request, resp *http.Response) error { + raw, _ := io.ReadAll(resp.Body) + + var errResp errorResponse + err := json.Unmarshal(raw, &errResp) + if err != nil { + return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw) + } + + return fmt.Errorf("status=%d: %w", resp.StatusCode, errResp) } diff --git a/providers/dns/luadns/internal/client_test.go b/providers/dns/luadns/internal/client_test.go index f3cc4642..1fd3efd7 100644 --- a/providers/dns/luadns/internal/client_test.go +++ b/providers/dns/luadns/internal/client_test.go @@ -1,10 +1,12 @@ package internal import ( + "context" "fmt" "io" "net/http" "net/http/httptest" + "net/url" "os" "testing" @@ -12,13 +14,22 @@ import ( "github.com/stretchr/testify/require" ) -func TestClient_ListZones(t *testing.T) { +func setupTest(t *testing.T, apiToken string) (*Client, *http.ServeMux) { + t.Helper() + mux := http.NewServeMux() server := httptest.NewServer(mux) t.Cleanup(server.Close) - client := NewClient("me", "secretA") - client.BaseURL = server.URL + client := NewClient("me", apiToken) + client.baseURL, _ = url.Parse(server.URL) + client.HTTPClient = server.Client() + + return client, mux +} + +func TestClient_ListZones(t *testing.T) { + client, mux := setupTest(t, "secretA") mux.HandleFunc("/v1/zones", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodGet { @@ -46,7 +57,7 @@ func TestClient_ListZones(t *testing.T) { } }) - zones, err := client.ListZones() + zones, err := client.ListZones(context.Background()) require.NoError(t, err) expected := []DNSZone{ @@ -78,12 +89,7 @@ func TestClient_ListZones(t *testing.T) { } func TestClient_CreateRecord(t *testing.T) { - mux := http.NewServeMux() - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - - client := NewClient("me", "secretB") - client.BaseURL = server.URL + client, mux := setupTest(t, "secretB") mux.HandleFunc("/v1/zones/1/records", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodPost { @@ -120,7 +126,7 @@ func TestClient_CreateRecord(t *testing.T) { TTL: 300, } - newRecord, err := client.CreateRecord(zone, record) + newRecord, err := client.CreateRecord(context.Background(), zone, record) require.NoError(t, err) expected := &DNSRecord{ @@ -136,12 +142,7 @@ func TestClient_CreateRecord(t *testing.T) { } func TestClient_DeleteRecord(t *testing.T) { - mux := http.NewServeMux() - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - - client := NewClient("me", "secretC") - client.BaseURL = server.URL + client, mux := setupTest(t, "secretC") mux.HandleFunc("/v1/zones/1/records/2", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodDelete { @@ -178,6 +179,6 @@ func TestClient_DeleteRecord(t *testing.T) { ZoneID: 1, } - err := client.DeleteRecord(record) + err := client.DeleteRecord(context.Background(), record) require.NoError(t, err) } diff --git a/providers/dns/luadns/internal/model.go b/providers/dns/luadns/internal/types.go similarity index 100% rename from providers/dns/luadns/internal/model.go rename to providers/dns/luadns/internal/types.go diff --git a/providers/dns/luadns/luadns.go b/providers/dns/luadns/luadns.go index 089c6be6..5f6f6cc2 100644 --- a/providers/dns/luadns/luadns.go +++ b/providers/dns/luadns/luadns.go @@ -2,6 +2,7 @@ package luadns import ( + "context" "errors" "fmt" "net/http" @@ -114,14 +115,16 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - zones, err := d.client.ListZones() + ctx := context.Background() + + zones, err := d.client.ListZones(ctx) if err != nil { return fmt.Errorf("luadns: failed to get zones: %w", err) } authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("luadns: failed to find zone: %w", err) + return fmt.Errorf("luadns: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } zone := findZone(zones, dns01.UnFqdn(authZone)) @@ -136,7 +139,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { TTL: d.config.TTL, } - record, err := d.client.CreateRecord(*zone, newRecord) + record, err := d.client.CreateRecord(ctx, *zone, newRecord) if err != nil { return fmt.Errorf("luadns: failed to create record: %w", err) } @@ -160,7 +163,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return fmt.Errorf("luadns: unknown record ID for '%s'", info.EffectiveFQDN) } - err := d.client.DeleteRecord(record) + err := d.client.DeleteRecord(context.Background(), record) if err != nil { return fmt.Errorf("luadns: failed to delete record: %w", err) } diff --git a/providers/dns/mydnsjp/client.go b/providers/dns/mydnsjp/client.go deleted file mode 100644 index 16bfa734..00000000 --- a/providers/dns/mydnsjp/client.go +++ /dev/null @@ -1,52 +0,0 @@ -package mydnsjp - -import ( - "fmt" - "io" - "net/http" - "net/url" - "strings" -) - -func (d *DNSProvider) doRequest(domain, value, cmd string) error { - req, err := d.buildRequest(domain, value, cmd) - if err != nil { - return err - } - - resp, err := d.config.HTTPClient.Do(req) - if err != nil { - return fmt.Errorf("error querying API: %w", err) - } - - defer resp.Body.Close() - - if resp.StatusCode >= http.StatusBadRequest { - var content []byte - content, err = io.ReadAll(resp.Body) - if err != nil { - return err - } - - return fmt.Errorf("request %s failed [status code %d]: %s", req.URL, resp.StatusCode, string(content)) - } - - return nil -} - -func (d *DNSProvider) buildRequest(domain, value, cmd string) (*http.Request, error) { - params := url.Values{} - params.Set("CERTBOT_DOMAIN", domain) - params.Set("CERTBOT_VALIDATION", value) - params.Set("EDIT_CMD", cmd) - - req, err := http.NewRequest(http.MethodPost, defaultBaseURL, strings.NewReader(params.Encode())) - if err != nil { - return nil, fmt.Errorf("invalid request: %w", err) - } - - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.SetBasicAuth(d.config.MasterID, d.config.Password) - - return req, nil -} diff --git a/providers/dns/mydnsjp/internal/client.go b/providers/dns/mydnsjp/internal/client.go new file mode 100644 index 00000000..9859ed68 --- /dev/null +++ b/providers/dns/mydnsjp/internal/client.go @@ -0,0 +1,81 @@ +package internal + +import ( + "context" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" +) + +const defaultBaseURL = "https://www.mydns.jp/directedit.html" + +// Client the MyDNS.jp client. +type Client struct { + masterID string + password string + + baseURL *url.URL + HTTPClient *http.Client +} + +// NewClient Creates a new Client. +func NewClient(masterID string, password string) *Client { + baseURL, _ := url.Parse(defaultBaseURL) + + return &Client{ + masterID: masterID, + password: password, + baseURL: baseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + } +} + +func (c Client) AddTXTRecord(ctx context.Context, domain, value string) error { + return c.doRequest(ctx, domain, value, "REGIST") +} + +func (c Client) DeleteTXTRecord(ctx context.Context, domain, value string) error { + return c.doRequest(ctx, domain, value, "DELETE") +} + +func (c Client) buildRequest(ctx context.Context, domain, value, cmd string) (*http.Request, error) { + params := url.Values{} + params.Set("CERTBOT_DOMAIN", domain) + params.Set("CERTBOT_VALIDATION", value) + params.Set("EDIT_CMD", cmd) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL.String(), strings.NewReader(params.Encode())) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + return req, nil +} + +func (c Client) doRequest(ctx context.Context, domain, value, cmd string) error { + req, err := c.buildRequest(ctx, domain, value, cmd) + if err != nil { + return err + } + + req.SetBasicAuth(c.masterID, c.password) + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode/100 != 2 { + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + return nil +} diff --git a/providers/dns/mydnsjp/internal/client_test.go b/providers/dns/mydnsjp/internal/client_test.go new file mode 100644 index 00000000..a68f6888 --- /dev/null +++ b/providers/dns/mydnsjp/internal/client_test.go @@ -0,0 +1,92 @@ +package internal + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/require" +) + +func setupTest(t *testing.T, cmdName string) *Client { + t.Helper() + + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + mux.HandleFunc("/", func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + http.Error(rw, fmt.Sprintf("invalid method: %s", req.Method), http.StatusMethodNotAllowed) + return + } + + username, password, ok := req.BasicAuth() + if !ok { + http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + + if username != "xxx" { + http.Error(rw, fmt.Sprintf("username: want %s got %s", username, "xxx"), http.StatusUnauthorized) + return + } + + if password != "secret" { + http.Error(rw, fmt.Sprintf("password: want %s got %s", password, "secret"), http.StatusUnauthorized) + return + } + + if req.Header.Get("Content-Type") != "application/x-www-form-urlencoded" { + http.Error(rw, fmt.Sprintf("invalid Content-Type: %s", req.Header.Get("Content-Type")), http.StatusBadRequest) + return + } + + err := req.ParseForm() + if err != nil { + http.Error(rw, err.Error(), http.StatusBadRequest) + return + } + + domain := req.Form.Get("CERTBOT_DOMAIN") + if domain != "example.com" { + http.Error(rw, fmt.Sprintf("unexpected CERTBOT_DOMAIN: %s", domain), http.StatusBadRequest) + return + } + + validation := req.Form.Get("CERTBOT_VALIDATION") + if validation != "txt" { + http.Error(rw, fmt.Sprintf("unexpected CERTBOT_VALIDATION: %s", validation), http.StatusBadRequest) + return + } + + cmd := req.Form.Get("EDIT_CMD") + if cmd != cmdName { + http.Error(rw, fmt.Sprintf("unexpected EDIT_CMD: %s", cmd), http.StatusBadRequest) + return + } + }) + + client := NewClient("xxx", "secret") + client.HTTPClient = server.Client() + client.baseURL, _ = url.Parse(server.URL) + + return client +} + +func TestClient_AddTXTRecord(t *testing.T) { + client := setupTest(t, "REGIST") + + err := client.AddTXTRecord(context.Background(), "example.com", "txt") + require.NoError(t, err) +} + +func TestClient_DeleteTXTRecord(t *testing.T) { + client := setupTest(t, "DELETE") + + err := client.DeleteTXTRecord(context.Background(), "example.com", "txt") + require.NoError(t, err) +} diff --git a/providers/dns/mydnsjp/mydnsjp.go b/providers/dns/mydnsjp/mydnsjp.go index da52c8a8..beaaf49a 100644 --- a/providers/dns/mydnsjp/mydnsjp.go +++ b/providers/dns/mydnsjp/mydnsjp.go @@ -2,6 +2,7 @@ package mydnsjp import ( + "context" "errors" "fmt" "net/http" @@ -9,10 +10,9 @@ import ( "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/platform/config/env" + "github.com/go-acme/lego/v4/providers/dns/mydnsjp/internal" ) -const defaultBaseURL = "https://www.mydns.jp/directedit.html" - // Environment variables names. const ( envNamespace = "MYDNSJP_" @@ -48,6 +48,7 @@ func NewDefaultConfig() *Config { // DNSProvider implements the challenge.Provider interface. type DNSProvider struct { config *Config + client *internal.Client } // NewDNSProvider returns a DNSProvider instance configured for MyDNS.jp. @@ -75,7 +76,10 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("mydnsjp: some credentials information are missing") } - return &DNSProvider{config: config}, nil + return &DNSProvider{ + config: config, + client: internal.NewClient(config.MasterID, config.Password), + }, nil } // Timeout returns the timeout and interval to use when checking for DNS propagation. @@ -89,7 +93,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) // TODO(ldez) replace domain by FQDN to follow CNAME. - err := d.doRequest(domain, info.Value, "REGIST") + err := d.client.AddTXTRecord(context.Background(), domain, info.Value) if err != nil { return fmt.Errorf("mydnsjp: %w", err) } @@ -101,7 +105,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) // TODO(ldez) replace domain by FQDN to follow CNAME. - err := d.doRequest(domain, info.Value, "DELETE") + err := d.client.DeleteTXTRecord(context.Background(), domain, info.Value) if err != nil { return fmt.Errorf("mydnsjp: %w", err) } diff --git a/providers/dns/mythicbeasts/client.go b/providers/dns/mythicbeasts/client.go deleted file mode 100644 index 473e4e77..00000000 --- a/providers/dns/mythicbeasts/client.go +++ /dev/null @@ -1,233 +0,0 @@ -package mythicbeasts - -import ( - "bytes" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "strings" - "time" -) - -const ( - apiBaseURL = "https://api.mythic-beasts.com/dns/v2" - authBaseURL = "https://auth.mythic-beasts.com/login" -) - -type authResponse struct { - // The bearer token for use in API requests - Token string `json:"access_token"` - - // The maximum lifetime of the token in seconds - Lifetime int `json:"expires_in"` - - // The token type (must be 'bearer') - TokenType string `json:"token_type"` - - Deadline time.Time `json:"-"` -} - -type authResponseError struct { - ErrorMsg string `json:"error"` - ErrorDescription string `json:"error_description"` -} - -func (a authResponseError) Error() string { - return fmt.Sprintf("%s: %s", a.ErrorMsg, a.ErrorDescription) -} - -type createTXTRequest struct { - Records []createTXTRecord `json:"records"` -} - -type createTXTRecord struct { - Host string `json:"host"` - TTL int `json:"ttl"` - Type string `json:"type"` - Data string `json:"data"` -} - -type createTXTResponse struct { - Added int `json:"records_added"` - Removed int `json:"records_removed"` - Message string `json:"message"` -} - -type deleteTXTResponse struct { - Removed int `json:"records_removed"` - Message string `json:"message"` -} - -// Logs into mythic beasts and acquires a bearer token for use in future API calls. -// https://www.mythic-beasts.com/support/api/auth#sec-obtaining-a-token -func (d *DNSProvider) login() error { - d.muToken.Lock() - defer d.muToken.Unlock() - - if d.token != nil && time.Now().Before(d.token.Deadline) { - // Already authenticated, stop now - return nil - } - - req, err := http.NewRequest(http.MethodPost, d.config.AuthAPIEndpoint.String(), strings.NewReader("grant_type=client_credentials")) - if err != nil { - return err - } - - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.SetBasicAuth(d.config.UserName, d.config.Password) - - resp, err := d.config.HTTPClient.Do(req) - if err != nil { - return err - } - - defer func() { _ = resp.Body.Close() }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("login: %w", err) - } - - if resp.StatusCode != http.StatusOK { - if resp.StatusCode < 400 || resp.StatusCode > 499 { - return fmt.Errorf("login: unknown error in auth API: %d", resp.StatusCode) - } - - // Returned body should be a JSON thing - errResp := &authResponseError{} - err = json.Unmarshal(body, errResp) - if err != nil { - return fmt.Errorf("login: error parsing error: %w", err) - } - - return fmt.Errorf("login: %d: %w", resp.StatusCode, errResp) - } - - authResp := authResponse{} - err = json.Unmarshal(body, &authResp) - if err != nil { - return fmt.Errorf("login: error parsing response: %w", err) - } - - if authResp.TokenType != "bearer" { - return fmt.Errorf("login: received unexpected token type: %s", authResp.TokenType) - } - - authResp.Deadline = time.Now().Add(time.Duration(authResp.Lifetime) * time.Second) - d.token = &authResp - - // Success - return nil -} - -// https://www.mythic-beasts.com/support/api/dnsv2#ep-get-zoneszonerecords -func (d *DNSProvider) createTXTRecord(zone, leaf, value string) error { - if d.token == nil { - return fmt.Errorf("createTXTRecord: not logged in") - } - - createReq := createTXTRequest{ - Records: []createTXTRecord{{ - Host: leaf, - TTL: d.config.TTL, - Type: "TXT", - Data: value, - }}, - } - - reqBody, err := json.Marshal(createReq) - if err != nil { - return fmt.Errorf("createTXTRecord: marshaling request body failed: %w", err) - } - - endpoint := d.config.APIEndpoint.JoinPath("zones", zone, "records", leaf, "TXT") - - req, err := http.NewRequest(http.MethodPost, endpoint.String(), bytes.NewReader(reqBody)) - if err != nil { - return fmt.Errorf("createTXTRecord: %w", err) - } - - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", d.token.Token)) - req.Header.Set("Content-Type", "application/json") - - resp, err := d.config.HTTPClient.Do(req) - if err != nil { - return fmt.Errorf("createTXTRecord: unable to perform HTTP request: %w", err) - } - - defer func() { _ = resp.Body.Close() }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("createTXTRecord: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("createTXTRecord: error in API: %d", resp.StatusCode) - } - - createResp := createTXTResponse{} - err = json.Unmarshal(body, &createResp) - if err != nil { - return fmt.Errorf("createTXTRecord: error parsing response: %w", err) - } - - if createResp.Added != 1 { - return errors.New("createTXTRecord: did not add TXT record for some reason") - } - - // Success - return nil -} - -// https://www.mythic-beasts.com/support/api/dnsv2#ep-delete-zoneszonerecords -func (d *DNSProvider) removeTXTRecord(zone, leaf, value string) error { - if d.token == nil { - return fmt.Errorf("removeTXTRecord: not logged in") - } - - endpoint := d.config.APIEndpoint.JoinPath("zones", zone, "records", leaf, "TXT") - - query := endpoint.Query() - query.Add("data", value) - endpoint.RawQuery = query.Encode() - - req, err := http.NewRequest(http.MethodDelete, endpoint.String(), nil) - if err != nil { - return fmt.Errorf("removeTXTRecord: %w", err) - } - - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", d.token.Token)) - - resp, err := d.config.HTTPClient.Do(req) - if err != nil { - return fmt.Errorf("removeTXTRecord: unable to perform HTTP request: %w", err) - } - - defer func() { _ = resp.Body.Close() }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("removeTXTRecord: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("removeTXTRecord: error in API: %d", resp.StatusCode) - } - - deleteResp := deleteTXTResponse{} - err = json.Unmarshal(body, &deleteResp) - if err != nil { - return fmt.Errorf("removeTXTRecord: error parsing response: %w", err) - } - - if deleteResp.Removed != 1 { - return errors.New("removeTXTRecord: did not add TXT record for some reason") - } - - // Success - return nil -} diff --git a/providers/dns/mythicbeasts/internal/client.go b/providers/dns/mythicbeasts/internal/client.go new file mode 100644 index 00000000..7f7b0446 --- /dev/null +++ b/providers/dns/mythicbeasts/internal/client.go @@ -0,0 +1,186 @@ +package internal + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "sync" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" +) + +// Default API endpoints. +const ( + APIBaseURL = "https://api.mythic-beasts.com/dns/v2" + AuthBaseURL = "https://auth.mythic-beasts.com/login" +) + +// Client the EasyDNS API client. +type Client struct { + username string + password string + + APIEndpoint *url.URL + AuthEndpoint *url.URL + HTTPClient *http.Client + + token *Token + muToken sync.Mutex +} + +// NewClient Creates a new Client. +func NewClient(username string, password string) *Client { + apiEndpoint, _ := url.Parse(APIBaseURL) + authEndpoint, _ := url.Parse(AuthBaseURL) + + return &Client{ + username: username, + password: password, + APIEndpoint: apiEndpoint, + AuthEndpoint: authEndpoint, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + } +} + +// CreateTXTRecord creates a TXT record. +// https://www.mythic-beasts.com/support/api/dnsv2#ep-get-zoneszonerecords +func (c *Client) CreateTXTRecord(ctx context.Context, zone, leaf, value string, ttl int) error { + resp, err := c.createTXTRecord(ctx, zone, leaf, "TXT", value, ttl) + if err != nil { + return err + } + + if resp.Added != 1 { + return fmt.Errorf("did not add TXT record for some reason: %s", resp.Message) + } + + // Success + return nil +} + +// RemoveTXTRecord removes a TXT records. +// https://www.mythic-beasts.com/support/api/dnsv2#ep-delete-zoneszonerecords +func (c *Client) RemoveTXTRecord(ctx context.Context, zone, leaf, value string) error { + resp, err := c.removeTXTRecord(ctx, zone, leaf, "TXT", value) + if err != nil { + return err + } + + if resp.Removed != 1 { + return fmt.Errorf("did not remove TXT record for some reason: %s", resp.Message) + } + + // Success + return nil +} + +// https://www.mythic-beasts.com/support/api/dnsv2#ep-post-zoneszonerecords +func (c *Client) createTXTRecord(ctx context.Context, zone, leaf, recordType, value string, ttl int) (*createTXTResponse, error) { + endpoint := c.APIEndpoint.JoinPath("zones", zone, "records", leaf, recordType) + + createReq := createTXTRequest{ + Records: []createTXTRecord{{ + Host: leaf, + TTL: ttl, + Type: "TXT", + Data: value, + }}, + } + + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, createReq) + if err != nil { + return nil, err + } + + resp := &createTXTResponse{} + err = c.do(req, resp) + if err != nil { + return nil, err + } + + return resp, nil +} + +// https://www.mythic-beasts.com/support/api/dnsv2#ep-delete-zoneszonerecords +func (c *Client) removeTXTRecord(ctx context.Context, zone, leaf, recordType, value string) (*deleteTXTResponse, error) { + endpoint := c.APIEndpoint.JoinPath("zones", zone, "records", leaf, recordType) + + query := endpoint.Query() + query.Add("data", value) + endpoint.RawQuery = query.Encode() + + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) + if err != nil { + return nil, err + } + + resp := &deleteTXTResponse{} + + err = c.do(req, resp) + if err != nil { + return nil, err + } + + return resp, nil +} + +func (c *Client) do(req *http.Request, result any) error { + tok := getToken(req.Context()) + if tok != nil { + req.Header.Set("Authorization", "Bearer "+tok.Token) + } else { + return fmt.Errorf("not logged in") + } + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return nil +} + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil +} diff --git a/providers/dns/mythicbeasts/internal/client_test.go b/providers/dns/mythicbeasts/internal/client_test.go new file mode 100644 index 00000000..7e385798 --- /dev/null +++ b/providers/dns/mythicbeasts/internal/client_test.go @@ -0,0 +1,69 @@ +package internal + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func setupTest(t *testing.T, pattern string, handler http.HandlerFunc) *Client { + t.Helper() + + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + mux.HandleFunc(pattern, handler) + + client := NewClient("user", "secret") + client.HTTPClient = server.Client() + client.APIEndpoint, _ = url.Parse(server.URL) + client.token = &Token{ + Token: "secret", + Lifetime: 60, + TokenType: "bearer", + Deadline: time.Now().Add(1 * time.Minute), + } + + return client +} + +func writeFixtureHandler(method, filename string) http.HandlerFunc { + return func(rw http.ResponseWriter, req *http.Request) { + if req.Method != method { + http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest) + return + } + + file, err := os.Open(filepath.Join("fixtures", filename)) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + defer func() { _ = file.Close() }() + + _, _ = io.Copy(rw, file) + } +} + +func TestClient_CreateTXTRecord(t *testing.T) { + client := setupTest(t, "/zones/example.com/records/foo/TXT", writeFixtureHandler(http.MethodPost, "post-zoneszonerecords.json")) + + err := client.CreateTXTRecord(mockContext(), "example.com", "foo", "txt", 120) + require.NoError(t, err) +} + +func TestClient_RemoveTXTRecord(t *testing.T) { + client := setupTest(t, "/zones/example.com/records/foo/TXT", writeFixtureHandler(http.MethodDelete, "delete-zoneszonerecords.json")) + + err := client.RemoveTXTRecord(mockContext(), "example.com", "foo", "txt") + require.NoError(t, err) +} diff --git a/providers/dns/mythicbeasts/internal/fixtures/delete-zoneszonerecords.json b/providers/dns/mythicbeasts/internal/fixtures/delete-zoneszonerecords.json new file mode 100644 index 00000000..5bb325af --- /dev/null +++ b/providers/dns/mythicbeasts/internal/fixtures/delete-zoneszonerecords.json @@ -0,0 +1,4 @@ +{ + "records_removed": 1, + "message": "1 record removed" +} diff --git a/providers/dns/mythicbeasts/internal/fixtures/post-zoneszonerecords.json b/providers/dns/mythicbeasts/internal/fixtures/post-zoneszonerecords.json new file mode 100644 index 00000000..96c7ab11 --- /dev/null +++ b/providers/dns/mythicbeasts/internal/fixtures/post-zoneszonerecords.json @@ -0,0 +1,4 @@ +{ + "records_added": 1, + "message": "1 record added" +} diff --git a/providers/dns/mythicbeasts/internal/identity.go b/providers/dns/mythicbeasts/internal/identity.go new file mode 100644 index 00000000..417f1c75 --- /dev/null +++ b/providers/dns/mythicbeasts/internal/identity.go @@ -0,0 +1,101 @@ +package internal + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" +) + +type token string + +const tokenKey token = "token" + +// obtainToken Logs into mythic beasts and acquires a bearer token for use in future API calls. +// https://www.mythic-beasts.com/support/api/auth#sec-obtaining-a-token +func (c *Client) obtainToken(ctx context.Context) (*Token, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.AuthEndpoint.String(), strings.NewReader("grant_type=client_credentials")) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth(c.username, c.password) + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return nil, parseError(req, resp) + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + tok := Token{} + err = json.Unmarshal(raw, &tok) + if err != nil { + return nil, errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + if tok.TokenType != "bearer" { + return nil, fmt.Errorf("received unexpected token type: %s", tok.TokenType) + } + + tok.Deadline = time.Now().Add(time.Duration(tok.Lifetime) * time.Second) + + return &tok, nil +} + +func (c *Client) CreateAuthenticatedContext(ctx context.Context) (context.Context, error) { + c.muToken.Lock() + defer c.muToken.Unlock() + + if c.token != nil && time.Now().Before(c.token.Deadline) { + // Already authenticated, stop now + return context.WithValue(ctx, tokenKey, c.token), nil + } + + tok, err := c.obtainToken(ctx) + if err != nil { + return nil, err + } + + return context.WithValue(ctx, tokenKey, tok), nil +} + +func parseError(req *http.Request, resp *http.Response) error { + if resp.StatusCode < 400 || resp.StatusCode > 499 { + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + raw, _ := io.ReadAll(resp.Body) + + errResp := &authResponseError{} + err := json.Unmarshal(raw, errResp) + if err != nil { + return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw) + } + + return fmt.Errorf("%d: %w", resp.StatusCode, errResp) +} + +func getToken(ctx context.Context) *Token { + tok, ok := ctx.Value(tokenKey).(*Token) + if !ok { + return nil + } + + return tok +} diff --git a/providers/dns/mythicbeasts/internal/identity_test.go b/providers/dns/mythicbeasts/internal/identity_test.go new file mode 100644 index 00000000..9d8daf82 --- /dev/null +++ b/providers/dns/mythicbeasts/internal/identity_test.go @@ -0,0 +1,81 @@ +package internal + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func mockContext() context.Context { + return context.WithValue(context.Background(), tokenKey, &Token{Token: "xxx"}) +} + +func tokenHandler(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + http.Error(rw, fmt.Sprintf("invalid method, got %s want %s", req.Method, http.MethodPost), http.StatusMethodNotAllowed) + return + } + + username, password, ok := req.BasicAuth() + if !ok || username != "user" || password != "secret" { + http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + + _ = json.NewEncoder(rw).Encode(Token{ + Token: "xxx", + Lifetime: 666, + TokenType: "bearer", + }) +} + +func TestClient_obtainToken(t *testing.T) { + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + mux.HandleFunc("/", tokenHandler) + + client := NewClient("user", "secret") + client.HTTPClient = server.Client() + client.AuthEndpoint, _ = url.Parse(server.URL) + + assert.Nil(t, client.token) + + tok, err := client.obtainToken(context.Background()) + require.NoError(t, err) + + assert.NotNil(t, tok) + assert.NotZero(t, tok.Deadline) + assert.Equal(t, "xxx", tok.Token) +} + +func TestClient_CreateAuthenticatedContext(t *testing.T) { + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + mux.HandleFunc("/", tokenHandler) + + client := NewClient("user", "secret") + client.HTTPClient = server.Client() + client.AuthEndpoint, _ = url.Parse(server.URL) + + assert.Nil(t, client.token) + + ctx, err := client.CreateAuthenticatedContext(context.Background()) + require.NoError(t, err) + + tok := getToken(ctx) + + assert.NotNil(t, tok) + assert.NotZero(t, tok.Deadline) + assert.Equal(t, "xxx", tok.Token) +} diff --git a/providers/dns/mythicbeasts/internal/types.go b/providers/dns/mythicbeasts/internal/types.go new file mode 100644 index 00000000..c68d0dc0 --- /dev/null +++ b/providers/dns/mythicbeasts/internal/types.go @@ -0,0 +1,50 @@ +package internal + +import ( + "fmt" + "time" +) + +type Token struct { + // The bearer token for use in API requests + Token string `json:"access_token"` + + // The maximum lifetime of the token in seconds + Lifetime int `json:"expires_in"` + + // The token type (must be 'bearer') + TokenType string `json:"token_type"` + + Deadline time.Time `json:"-"` +} + +type authResponseError struct { + ErrorMsg string `json:"error"` + ErrorDescription string `json:"error_description"` +} + +func (a authResponseError) Error() string { + return fmt.Sprintf("%s: %s", a.ErrorMsg, a.ErrorDescription) +} + +type createTXTRequest struct { + Records []createTXTRecord `json:"records"` +} + +type createTXTRecord struct { + Host string `json:"host"` + TTL int `json:"ttl"` + Type string `json:"type"` + Data string `json:"data"` +} + +type createTXTResponse struct { + Added int `json:"records_added"` + Removed int `json:"records_removed"` + Message string `json:"message"` +} + +type deleteTXTResponse struct { + Removed int `json:"records_removed"` + Message string `json:"message"` +} diff --git a/providers/dns/mythicbeasts/mythicbeasts.go b/providers/dns/mythicbeasts/mythicbeasts.go index 0c787cce..7545b3fb 100644 --- a/providers/dns/mythicbeasts/mythicbeasts.go +++ b/providers/dns/mythicbeasts/mythicbeasts.go @@ -2,15 +2,16 @@ package mythicbeasts import ( + "context" "errors" "fmt" "net/http" "net/url" - "sync" "time" "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/platform/config/env" + "github.com/go-acme/lego/v4/providers/dns/mythicbeasts/internal" ) // Environment variables names. @@ -42,12 +43,12 @@ type Config struct { // NewDefaultConfig returns a default configuration for the DNSProvider. func NewDefaultConfig() (*Config, error) { - apiEndpoint, err := url.Parse(env.GetOrDefaultString(EnvAPIEndpoint, apiBaseURL)) + apiEndpoint, err := url.Parse(env.GetOrDefaultString(EnvAPIEndpoint, internal.APIBaseURL)) if err != nil { return nil, fmt.Errorf("mythicbeasts: Unable to parse API URL: %w", err) } - authEndpoint, err := url.Parse(env.GetOrDefaultString(EnvAuthAPIEndpoint, authBaseURL)) + authEndpoint, err := url.Parse(env.GetOrDefaultString(EnvAuthAPIEndpoint, internal.AuthBaseURL)) if err != nil { return nil, fmt.Errorf("mythicbeasts: Unable to parse AUTH API URL: %w", err) } @@ -67,10 +68,7 @@ func NewDefaultConfig() (*Config, error) { // DNSProvider implements the challenge.Provider interface. type DNSProvider struct { config *Config - - // token string - token *authResponse - muToken sync.Mutex + client *internal.Client } // NewDNSProvider returns a DNSProvider instance configured for mythicbeasts DNSv2 API. @@ -102,7 +100,21 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("mythicbeasts: incomplete credentials, missing username and/or password") } - return &DNSProvider{config: config}, nil + client := internal.NewClient(config.UserName, config.Password) + + if config.APIEndpoint != nil { + client.APIEndpoint = config.APIEndpoint + } + + if config.AuthAPIEndpoint != nil { + client.AuthEndpoint = config.AuthAPIEndpoint + } + + if config.HTTPClient != nil { + client.HTTPClient = config.HTTPClient + } + + return &DNSProvider{config: config, client: client}, nil } // Present creates a TXT record using the specified parameters. @@ -111,7 +123,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("mythicbeasts: %w", err) + return fmt.Errorf("mythicbeasts: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) @@ -121,14 +133,14 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone = dns01.UnFqdn(authZone) - err = d.login() + ctx, err := d.client.CreateAuthenticatedContext(context.Background()) if err != nil { - return fmt.Errorf("mythicbeasts: %w", err) + return fmt.Errorf("mythicbeasts: login: %w", err) } - err = d.createTXTRecord(authZone, subDomain, info.Value) + err = d.client.CreateTXTRecord(ctx, authZone, subDomain, info.Value, d.config.TTL) if err != nil { - return fmt.Errorf("mythicbeasts: %w", err) + return fmt.Errorf("mythicbeasts: CreateTXTRecord: %w", err) } return nil @@ -140,7 +152,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("mythicbeasts: %w", err) + return fmt.Errorf("mythicbeasts: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) @@ -150,14 +162,14 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone = dns01.UnFqdn(authZone) - err = d.login() + ctx, err := d.client.CreateAuthenticatedContext(context.Background()) if err != nil { - return fmt.Errorf("mythicbeasts: %w", err) + return fmt.Errorf("mythicbeasts: login: %w", err) } - err = d.removeTXTRecord(authZone, subDomain, info.Value) + err = d.client.RemoveTXTRecord(ctx, authZone, subDomain, info.Value) if err != nil { - return fmt.Errorf("mythicbeasts: %w", err) + return fmt.Errorf("mythicbeasts: RemoveTXTRecord: %w", err) } return nil diff --git a/providers/dns/namecheap/client.go b/providers/dns/namecheap/client.go deleted file mode 100644 index 6d62df8b..00000000 --- a/providers/dns/namecheap/client.go +++ /dev/null @@ -1,189 +0,0 @@ -package namecheap - -import ( - "encoding/xml" - "errors" - "fmt" - "io" - "net/http" - "net/url" - "strings" -) - -// Record describes a DNS record returned by the Namecheap DNS gethosts API. -// Namecheap uses the term "host" to refer to all DNS records that include -// a host field (A, AAAA, CNAME, NS, TXT, URL). -type Record struct { - Type string `xml:",attr"` - Name string `xml:",attr"` - Address string `xml:",attr"` - MXPref string `xml:",attr"` - TTL string `xml:",attr"` -} - -// apiError describes an error record in a namecheap API response. -type apiError struct { - Number int `xml:",attr"` - Description string `xml:",innerxml"` -} - -type setHostsResponse struct { - XMLName xml.Name `xml:"ApiResponse"` - Status string `xml:"Status,attr"` - Errors []apiError `xml:"Errors>Error"` - Result struct { - IsSuccess string `xml:",attr"` - } `xml:"CommandResponse>DomainDNSSetHostsResult"` -} - -type getHostsResponse struct { - XMLName xml.Name `xml:"ApiResponse"` - Status string `xml:"Status,attr"` - Errors []apiError `xml:"Errors>Error"` - Hosts []Record `xml:"CommandResponse>DomainDNSGetHostsResult>host"` -} - -// getHosts reads the full list of DNS host records. -// https://www.namecheap.com/support/api/methods/domains-dns/get-hosts.aspx -func (d *DNSProvider) getHosts(sld, tld string) ([]Record, error) { - request, err := d.newRequestGet("namecheap.domains.dns.getHosts", - addParam("SLD", sld), - addParam("TLD", tld), - ) - if err != nil { - return nil, err - } - - var ghr getHostsResponse - err = d.do(request, &ghr) - if err != nil { - return nil, err - } - - if len(ghr.Errors) > 0 { - return nil, fmt.Errorf("%s [%d]", ghr.Errors[0].Description, ghr.Errors[0].Number) - } - - return ghr.Hosts, nil -} - -// setHosts writes the full list of DNS host records . -// https://www.namecheap.com/support/api/methods/domains-dns/set-hosts.aspx -func (d *DNSProvider) setHosts(sld, tld string, hosts []Record) error { - req, err := d.newRequestPost("namecheap.domains.dns.setHosts", - addParam("SLD", sld), - addParam("TLD", tld), - func(values url.Values) { - for i, h := range hosts { - ind := fmt.Sprintf("%d", i+1) - values.Add("HostName"+ind, h.Name) - values.Add("RecordType"+ind, h.Type) - values.Add("Address"+ind, h.Address) - values.Add("MXPref"+ind, h.MXPref) - values.Add("TTL"+ind, h.TTL) - } - }, - ) - if err != nil { - return err - } - - var shr setHostsResponse - err = d.do(req, &shr) - if err != nil { - return err - } - - if len(shr.Errors) > 0 { - return fmt.Errorf("%s [%d]", shr.Errors[0].Description, shr.Errors[0].Number) - } - if shr.Result.IsSuccess != "true" { - return errors.New("setHosts failed") - } - - return nil -} - -func (d *DNSProvider) do(req *http.Request, out interface{}) error { - resp, err := d.config.HTTPClient.Do(req) - if err != nil { - return err - } - - if resp.StatusCode >= http.StatusBadRequest { - var body []byte - body, err = readBody(resp) - if err != nil { - return fmt.Errorf("HTTP error %d [%s]: %w", resp.StatusCode, http.StatusText(resp.StatusCode), err) - } - return fmt.Errorf("HTTP error %d [%s]: %s", resp.StatusCode, http.StatusText(resp.StatusCode), string(body)) - } - - body, err := readBody(resp) - if err != nil { - return err - } - - return xml.Unmarshal(body, out) -} - -func (d *DNSProvider) newRequestGet(cmd string, params ...func(url.Values)) (*http.Request, error) { - query := d.makeQuery(cmd, params...) - - reqURL, err := url.Parse(d.config.BaseURL) - if err != nil { - return nil, err - } - - reqURL.RawQuery = query.Encode() - - return http.NewRequest(http.MethodGet, reqURL.String(), nil) -} - -func (d *DNSProvider) newRequestPost(cmd string, params ...func(url.Values)) (*http.Request, error) { - query := d.makeQuery(cmd, params...) - - req, err := http.NewRequest(http.MethodPost, d.config.BaseURL, strings.NewReader(query.Encode())) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - return req, nil -} - -func (d *DNSProvider) makeQuery(cmd string, params ...func(url.Values)) url.Values { - queryParams := make(url.Values) - queryParams.Set("ApiUser", d.config.APIUser) - queryParams.Set("ApiKey", d.config.APIKey) - queryParams.Set("UserName", d.config.APIUser) - queryParams.Set("Command", cmd) - queryParams.Set("ClientIp", d.config.ClientIP) - - for _, param := range params { - param(queryParams) - } - - return queryParams -} - -func addParam(key, value string) func(url.Values) { - return func(values url.Values) { - values.Set(key, value) - } -} - -func readBody(resp *http.Response) ([]byte, error) { - if resp.Body == nil { - return nil, errors.New("response body is nil") - } - - defer resp.Body.Close() - - rawBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - return rawBody, nil -} diff --git a/providers/dns/namecheap/internal/client.go b/providers/dns/namecheap/internal/client.go new file mode 100644 index 00000000..f2124f83 --- /dev/null +++ b/providers/dns/namecheap/internal/client.go @@ -0,0 +1,175 @@ +package internal + +import ( + "context" + "encoding/xml" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" +) + +// Default API endpoints. +const ( + DefaultBaseURL = "https://api.namecheap.com/xml.response" + SandboxBaseURL = "https://api.sandbox.namecheap.com/xml.response" +) + +// Client the API client for Namecheap. +type Client struct { + apiUser string + apiKey string + clientIP string + + BaseURL string + HTTPClient *http.Client +} + +// NewClient creates a new Client. +func NewClient(apiUser string, apiKey string, clientIP string) *Client { + return &Client{ + apiUser: apiUser, + apiKey: apiKey, + clientIP: clientIP, + BaseURL: DefaultBaseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + } +} + +// GetHosts reads the full list of DNS host records. +// https://www.namecheap.com/support/api/methods/domains-dns/get-hosts.aspx +func (c *Client) GetHosts(ctx context.Context, sld, tld string) ([]Record, error) { + request, err := c.newRequestGet(ctx, "namecheap.domains.dns.getHosts", + addParam("SLD", sld), + addParam("TLD", tld), + ) + if err != nil { + return nil, err + } + + var ghr getHostsResponse + err = c.do(request, &ghr) + if err != nil { + return nil, err + } + + if len(ghr.Errors) > 0 { + return nil, ghr.Errors[0] + } + + return ghr.Hosts, nil +} + +// SetHosts writes the full list of DNS host records . +// https://www.namecheap.com/support/api/methods/domains-dns/set-hosts.aspx +func (c *Client) SetHosts(ctx context.Context, sld, tld string, hosts []Record) error { + req, err := c.newRequestPost(ctx, "namecheap.domains.dns.setHosts", + addParam("SLD", sld), + addParam("TLD", tld), + func(values url.Values) { + for i, h := range hosts { + ind := fmt.Sprintf("%d", i+1) + values.Add("HostName"+ind, h.Name) + values.Add("RecordType"+ind, h.Type) + values.Add("Address"+ind, h.Address) + values.Add("MXPref"+ind, h.MXPref) + values.Add("TTL"+ind, h.TTL) + } + }, + ) + if err != nil { + return err + } + + var shr setHostsResponse + err = c.do(req, &shr) + if err != nil { + return err + } + + if len(shr.Errors) > 0 { + return shr.Errors[0] + } + if shr.Result.IsSuccess != "true" { + return errors.New("setHosts failed") + } + + return nil +} + +func (c *Client) do(req *http.Request, result any) error { + resp, err := c.HTTPClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= http.StatusBadRequest { + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + return xml.Unmarshal(raw, result) +} + +func (c *Client) newRequestGet(ctx context.Context, cmd string, params ...func(url.Values)) (*http.Request, error) { + query := c.makeQuery(cmd, params...) + + endpoint, err := url.Parse(c.BaseURL) + if err != nil { + return nil, err + } + + endpoint.RawQuery = query.Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + return req, nil +} + +func (c *Client) newRequestPost(ctx context.Context, cmd string, params ...func(url.Values)) (*http.Request, error) { + query := c.makeQuery(cmd, params...) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.BaseURL, strings.NewReader(query.Encode())) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + return req, nil +} + +func (c *Client) makeQuery(cmd string, params ...func(url.Values)) url.Values { + queryParams := make(url.Values) + queryParams.Set("ApiUser", c.apiUser) + queryParams.Set("ApiKey", c.apiKey) + queryParams.Set("UserName", c.apiUser) + queryParams.Set("Command", cmd) + queryParams.Set("ClientIp", c.clientIP) + + for _, param := range params { + param(queryParams) + } + + return queryParams +} + +func addParam(key, value string) func(url.Values) { + return func(values url.Values) { + values.Set(key, value) + } +} diff --git a/providers/dns/namecheap/internal/client_test.go b/providers/dns/namecheap/internal/client_test.go new file mode 100644 index 00000000..9d78ee21 --- /dev/null +++ b/providers/dns/namecheap/internal/client_test.go @@ -0,0 +1,173 @@ +package internal + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setupTest(t *testing.T, handler http.HandlerFunc) *Client { + t.Helper() + + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + mux.HandleFunc("/", handler) + + client := NewClient("user", "secret", "127.0.0.1") + client.HTTPClient = server.Client() + client.BaseURL = server.URL + + return client +} + +func writeFixture(rw http.ResponseWriter, filename string) { + file, err := os.Open(filepath.Join("fixtures", filename)) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + defer func() { _ = file.Close() }() + + _, _ = io.Copy(rw, file) +} + +func TestClient_GetHosts(t *testing.T) { + client := setupTest(t, func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet { + http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest) + return + } + + expectedParams := map[string]string{ + "ApiKey": "secret", + "ApiUser": "user", + "ClientIp": "127.0.0.1", + "Command": "namecheap.domains.dns.getHosts", + "SLD": "foo", + "TLD": "example.com", + "UserName": "user", + } + + query := req.URL.Query() + for k, v := range expectedParams { + if query.Get(k) != v { + http.Error(rw, fmt.Sprintf("invalid query parameter %s value: %s", k, query.Get(k)), http.StatusBadRequest) + return + } + } + + writeFixture(rw, "getHosts.xml") + }) + + hosts, err := client.GetHosts(context.Background(), "foo", "example.com") + require.NoError(t, err) + + expected := []Record{ + {Type: "A", Name: "@", Address: "1.2.3.4", MXPref: "10", TTL: "1800"}, + {Type: "A", Name: "www", Address: "122.23.3.7", MXPref: "10", TTL: "1800"}, + } + + assert.Equal(t, expected, hosts) +} + +func TestClient_GetHosts_error(t *testing.T) { + client := setupTest(t, func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet { + http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest) + return + } + + writeFixture(rw, "getHosts_errorBadAPIKey1.xml") + }) + + _, err := client.GetHosts(context.Background(), "foo", "example.com") + require.ErrorAs(t, err, &apiError{}) +} + +func TestClient_SetHosts(t *testing.T) { + client := setupTest(t, func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest) + return + } + + if req.Header.Get("Content-Type") != "application/x-www-form-urlencoded" { + http.Error(rw, fmt.Sprintf("invalid Content-Type: %s", req.Header.Get("Content-Type")), http.StatusBadRequest) + return + } + + err := req.ParseForm() + if err != nil { + http.Error(rw, err.Error(), http.StatusBadRequest) + return + } + + expectedParams := map[string]string{ + "HostName1": "_acme-challenge.test.example.com", + "RecordType1": "TXT", + "Address1": "txtTXTtxt", + "MXPref1": "10", + "TTL1": "120", + + "HostName2": "_acme-challenge.test.example.org", + "RecordType2": "TXT", + "Address2": "txtTXTtxt", + "MXPref2": "10", + "TTL2": "120", + + "ApiKey": "secret", + "ApiUser": "user", + "ClientIp": "127.0.0.1", + "Command": "namecheap.domains.dns.setHosts", + "SLD": "foo", + "TLD": "example.com", + "UserName": "user", + } + + for k, v := range expectedParams { + if req.Form.Get(k) != v { + http.Error(rw, fmt.Sprintf("invalid form data %s value: %q", k, req.Form.Get(k)), http.StatusBadRequest) + return + } + } + + writeFixture(rw, "setHosts.xml") + }) + + records := []Record{ + {Name: "_acme-challenge.test.example.com", Type: "TXT", Address: "txtTXTtxt", MXPref: "10", TTL: "120"}, + {Name: "_acme-challenge.test.example.org", Type: "TXT", Address: "txtTXTtxt", MXPref: "10", TTL: "120"}, + } + + err := client.SetHosts(context.Background(), "foo", "example.com", records) + require.NoError(t, err) +} + +func TestClient_SetHosts_error(t *testing.T) { + client := setupTest(t, func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest) + return + } + + writeFixture(rw, "setHosts_errorBadAPIKey1.xml") + }) + + records := []Record{ + {Name: "_acme-challenge.test.example.com", Type: "TXT", Address: "txtTXTtxt", MXPref: "10", TTL: "120"}, + {Name: "_acme-challenge.test.example.org", Type: "TXT", Address: "txtTXTtxt", MXPref: "10", TTL: "120"}, + } + + err := client.SetHosts(context.Background(), "foo", "example.com", records) + require.ErrorAs(t, err, &apiError{}) +} diff --git a/providers/dns/namecheap/internal/fixtures/getHosts.xml b/providers/dns/namecheap/internal/fixtures/getHosts.xml new file mode 100644 index 00000000..ee2fdcca --- /dev/null +++ b/providers/dns/namecheap/internal/fixtures/getHosts.xml @@ -0,0 +1,14 @@ + + + + namecheap.domains.dns.getHosts + + + + + + + SERVER-NAME + +5 + 32.76 + diff --git a/providers/dns/namecheap/internal/fixtures/getHosts_errorBadAPIKey1.xml b/providers/dns/namecheap/internal/fixtures/getHosts_errorBadAPIKey1.xml new file mode 100644 index 00000000..3f73a678 --- /dev/null +++ b/providers/dns/namecheap/internal/fixtures/getHosts_errorBadAPIKey1.xml @@ -0,0 +1,11 @@ + + + + API Key is invalid or API access has not been enabled + + + + PHX01SBAPI01 + --5:00 + 0 + diff --git a/providers/dns/namecheap/internal/fixtures/getHosts_success1.xml b/providers/dns/namecheap/internal/fixtures/getHosts_success1.xml new file mode 100644 index 00000000..7a53f6d4 --- /dev/null +++ b/providers/dns/namecheap/internal/fixtures/getHosts_success1.xml @@ -0,0 +1,19 @@ + + + + + namecheap.domains.dns.getHosts + + + + + + + + + + + PHX01SBAPI01 + --5:00 + 3.338 + diff --git a/providers/dns/namecheap/internal/fixtures/getHosts_success2.xml b/providers/dns/namecheap/internal/fixtures/getHosts_success2.xml new file mode 100644 index 00000000..b382674a --- /dev/null +++ b/providers/dns/namecheap/internal/fixtures/getHosts_success2.xml @@ -0,0 +1,15 @@ + + + + + namecheap.domains.dns.getHosts + + + + + + + PHX01SBAPI01 + --5:00 + 3.338 + diff --git a/providers/dns/namecheap/internal/fixtures/setHosts.xml b/providers/dns/namecheap/internal/fixtures/setHosts.xml new file mode 100644 index 00000000..11366dff --- /dev/null +++ b/providers/dns/namecheap/internal/fixtures/setHosts.xml @@ -0,0 +1,11 @@ + + + + namecheap.domains.dns.setHosts + + + + SERVER-NAME + +5 + 32.76 + diff --git a/providers/dns/namecheap/internal/fixtures/setHosts_errorBadAPIKey1.xml b/providers/dns/namecheap/internal/fixtures/setHosts_errorBadAPIKey1.xml new file mode 100644 index 00000000..3f73a678 --- /dev/null +++ b/providers/dns/namecheap/internal/fixtures/setHosts_errorBadAPIKey1.xml @@ -0,0 +1,11 @@ + + + + API Key is invalid or API access has not been enabled + + + + PHX01SBAPI01 + --5:00 + 0 + diff --git a/providers/dns/namecheap/internal/fixtures/setHosts_success1.xml b/providers/dns/namecheap/internal/fixtures/setHosts_success1.xml new file mode 100644 index 00000000..e428e787 --- /dev/null +++ b/providers/dns/namecheap/internal/fixtures/setHosts_success1.xml @@ -0,0 +1,14 @@ + + + + + namecheap.domains.dns.setHosts + + + + + + PHX01SBAPI01 + --5:00 + 2.347 + diff --git a/providers/dns/namecheap/internal/fixtures/setHosts_success2.xml b/providers/dns/namecheap/internal/fixtures/setHosts_success2.xml new file mode 100644 index 00000000..e428e787 --- /dev/null +++ b/providers/dns/namecheap/internal/fixtures/setHosts_success2.xml @@ -0,0 +1,14 @@ + + + + + namecheap.domains.dns.setHosts + + + + + + PHX01SBAPI01 + --5:00 + 2.347 + diff --git a/providers/dns/namecheap/internal/ip.go b/providers/dns/namecheap/internal/ip.go new file mode 100644 index 00000000..5823212d --- /dev/null +++ b/providers/dns/namecheap/internal/ip.go @@ -0,0 +1,45 @@ +package internal + +import ( + "context" + "fmt" + "io" + "net/http" + "time" + + "github.com/go-acme/lego/v4/log" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" +) + +const getIPURL = "https://dynamicdns.park-your-domain.com/getip" + +// GetClientIP returns the client's public IP address. +// It uses namecheap's IP discovery service to perform the lookup. +func GetClientIP(ctx context.Context, client *http.Client, debug bool) (addr string, err error) { + if client == nil { + client = &http.Client{Timeout: 5 * time.Second} + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, getIPURL, http.NoBody) + if err != nil { + return "", fmt.Errorf("unable to create request: %w", err) + } + + resp, err := client.Do(req) + if err != nil { + return "", err + } + + defer func() { _ = resp.Body.Close() }() + + clientIP, err := io.ReadAll(resp.Body) + if err != nil { + return "", errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + if debug { + log.Println("Client IP:", string(clientIP)) + } + + return string(clientIP), nil +} diff --git a/providers/dns/namecheap/internal/types.go b/providers/dns/namecheap/internal/types.go new file mode 100644 index 00000000..336776d8 --- /dev/null +++ b/providers/dns/namecheap/internal/types.go @@ -0,0 +1,43 @@ +package internal + +import ( + "encoding/xml" + "fmt" +) + +// Record describes a DNS record returned by the Namecheap DNS gethosts API. +// Namecheap uses the term "host" to refer to all DNS records that include +// a host field (A, AAAA, CNAME, NS, TXT, URL). +type Record struct { + Type string `xml:",attr"` + Name string `xml:",attr"` + Address string `xml:",attr"` + MXPref string `xml:",attr"` + TTL string `xml:",attr"` +} + +// apiError describes an error record in a namecheap API response. +type apiError struct { + Number int `xml:",attr"` + Description string `xml:",innerxml"` +} + +func (a apiError) Error() string { + return fmt.Sprintf("%s [%d]", a.Description, a.Number) +} + +type setHostsResponse struct { + XMLName xml.Name `xml:"ApiResponse"` + Status string `xml:"Status,attr"` + Errors []apiError `xml:"Errors>Error"` + Result struct { + IsSuccess string `xml:",attr"` + } `xml:"CommandResponse>DomainDNSSetHostsResult"` +} + +type getHostsResponse struct { + XMLName xml.Name `xml:"ApiResponse"` + Status string `xml:"Status,attr"` + Errors []apiError `xml:"Errors>Error"` + Hosts []Record `xml:"CommandResponse>DomainDNSGetHostsResult>host"` +} diff --git a/providers/dns/namecheap/namecheap.go b/providers/dns/namecheap/namecheap.go index 4b0c8fb3..eb94c9ce 100644 --- a/providers/dns/namecheap/namecheap.go +++ b/providers/dns/namecheap/namecheap.go @@ -2,9 +2,9 @@ package namecheap import ( + "context" "errors" "fmt" - "io" "net/http" "strconv" "strings" @@ -13,6 +13,7 @@ import ( "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/log" "github.com/go-acme/lego/v4/platform/config/env" + "github.com/go-acme/lego/v4/providers/dns/namecheap/internal" "golang.org/x/net/publicsuffix" ) @@ -29,12 +30,6 @@ import ( // address as a form or query string value. This code uses a namecheap // service to query the client's IP address. -const ( - defaultBaseURL = "https://api.namecheap.com/xml.response" - sandboxBaseURL = "https://api.sandbox.namecheap.com/xml.response" - getIPURL = "https://dynamicdns.park-your-domain.com/getip" -) - // Environment variables names. const ( envNamespace = "NAMECHEAP_" @@ -62,177 +57,6 @@ type challenge struct { host string } -// Config is used to configure the creation of the DNSProvider. -type Config struct { - Debug bool - BaseURL string - APIUser string - APIKey string - ClientIP string - PropagationTimeout time.Duration - PollingInterval time.Duration - TTL int - HTTPClient *http.Client -} - -// NewDefaultConfig returns a default configuration for the DNSProvider. -func NewDefaultConfig() *Config { - baseURL := defaultBaseURL - if env.GetOrDefaultBool(EnvSandbox, false) { - baseURL = sandboxBaseURL - } - - return &Config{ - BaseURL: baseURL, - Debug: env.GetOrDefaultBool(EnvDebug, false), - TTL: env.GetOrDefaultInt(EnvTTL, dns01.DefaultTTL), - PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, 60*time.Minute), - PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, 15*time.Second), - HTTPClient: &http.Client{ - Timeout: env.GetOrDefaultSecond(EnvHTTPTimeout, 60*time.Second), - }, - } -} - -// DNSProvider implements the challenge.Provider interface. -type DNSProvider struct { - config *Config -} - -// NewDNSProvider returns a DNSProvider instance configured for namecheap. -// Credentials must be passed in the environment variables: -// NAMECHEAP_API_USER and NAMECHEAP_API_KEY. -func NewDNSProvider() (*DNSProvider, error) { - values, err := env.Get(EnvAPIUser, EnvAPIKey) - if err != nil { - return nil, fmt.Errorf("namecheap: %w", err) - } - - config := NewDefaultConfig() - config.APIUser = values[EnvAPIUser] - config.APIKey = values[EnvAPIKey] - - return NewDNSProviderConfig(config) -} - -// NewDNSProviderConfig return a DNSProvider instance configured for Namecheap. -func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { - if config == nil { - return nil, errors.New("namecheap: the configuration of the DNS provider is nil") - } - - if config.APIUser == "" || config.APIKey == "" { - return nil, errors.New("namecheap: credentials missing") - } - - if config.ClientIP == "" { - clientIP, err := getClientIP(config.HTTPClient, config.Debug) - if err != nil { - return nil, fmt.Errorf("namecheap: %w", err) - } - config.ClientIP = clientIP - } - - return &DNSProvider{config: config}, nil -} - -// Timeout returns the timeout and interval to use when checking for DNS propagation. -// Namecheap can sometimes take a long time to complete an update, so wait up to 60 minutes for the update to propagate. -func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { - return d.config.PropagationTimeout, d.config.PollingInterval -} - -// Present installs a TXT record for the DNS challenge. -func (d *DNSProvider) Present(domain, token, keyAuth string) error { - // TODO(ldez) replace domain by FQDN to follow CNAME. - ch, err := newChallenge(domain, keyAuth) - if err != nil { - return fmt.Errorf("namecheap: %w", err) - } - - records, err := d.getHosts(ch.sld, ch.tld) - if err != nil { - return fmt.Errorf("namecheap: %w", err) - } - - record := Record{ - Name: ch.key, - Type: "TXT", - Address: ch.keyValue, - MXPref: "10", - TTL: strconv.Itoa(d.config.TTL), - } - - records = append(records, record) - - if d.config.Debug { - for _, h := range records { - log.Printf("%-5.5s %-30.30s %-6s %-70.70s", h.Type, h.Name, h.TTL, h.Address) - } - } - - err = d.setHosts(ch.sld, ch.tld, records) - if err != nil { - return fmt.Errorf("namecheap: %w", err) - } - return nil -} - -// CleanUp removes a TXT record used for a previous DNS challenge. -func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { - // TODO(ldez) replace domain by FQDN to follow CNAME. - ch, err := newChallenge(domain, keyAuth) - if err != nil { - return fmt.Errorf("namecheap: %w", err) - } - - records, err := d.getHosts(ch.sld, ch.tld) - if err != nil { - return fmt.Errorf("namecheap: %w", err) - } - - // Find the challenge TXT record and remove it if found. - var found bool - var newRecords []Record - for _, h := range records { - if h.Name == ch.key && h.Type == "TXT" { - found = true - } else { - newRecords = append(newRecords, h) - } - } - - if !found { - return nil - } - - err = d.setHosts(ch.sld, ch.tld, newRecords) - if err != nil { - return fmt.Errorf("namecheap: %w", err) - } - return nil -} - -// getClientIP returns the client's public IP address. -// It uses namecheap's IP discovery service to perform the lookup. -func getClientIP(client *http.Client, debug bool) (addr string, err error) { - resp, err := client.Get(getIPURL) - if err != nil { - return "", err - } - defer resp.Body.Close() - - clientIP, err := io.ReadAll(resp.Body) - if err != nil { - return "", err - } - - if debug { - log.Println("Client IP:", string(clientIP)) - } - return string(clientIP), nil -} - // newChallenge builds a challenge record from a domain name and a challenge authentication key. func newChallenge(domain, keyAuth string) (*challenge, error) { domain = dns01.UnFqdn(domain) @@ -263,3 +87,166 @@ func newChallenge(domain, keyAuth string) (*challenge, error) { host: host, }, nil } + +// Config is used to configure the creation of the DNSProvider. +type Config struct { + Debug bool + BaseURL string + APIUser string + APIKey string + ClientIP string + PropagationTimeout time.Duration + PollingInterval time.Duration + TTL int + HTTPClient *http.Client +} + +// NewDefaultConfig returns a default configuration for the DNSProvider. +func NewDefaultConfig() *Config { + baseURL := internal.DefaultBaseURL + if env.GetOrDefaultBool(EnvSandbox, false) { + baseURL = internal.SandboxBaseURL + } + + return &Config{ + BaseURL: baseURL, + Debug: env.GetOrDefaultBool(EnvDebug, false), + TTL: env.GetOrDefaultInt(EnvTTL, dns01.DefaultTTL), + PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, 60*time.Minute), + PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, 15*time.Second), + HTTPClient: &http.Client{ + Timeout: env.GetOrDefaultSecond(EnvHTTPTimeout, 60*time.Second), + }, + } +} + +// DNSProvider implements the challenge.Provider interface. +type DNSProvider struct { + config *Config + client *internal.Client +} + +// NewDNSProvider returns a DNSProvider instance configured for namecheap. +// Credentials must be passed in the environment variables: +// NAMECHEAP_API_USER and NAMECHEAP_API_KEY. +func NewDNSProvider() (*DNSProvider, error) { + values, err := env.Get(EnvAPIUser, EnvAPIKey) + if err != nil { + return nil, fmt.Errorf("namecheap: %w", err) + } + + config := NewDefaultConfig() + config.APIUser = values[EnvAPIUser] + config.APIKey = values[EnvAPIKey] + + return NewDNSProviderConfig(config) +} + +// NewDNSProviderConfig return a DNSProvider instance configured for Namecheap. +func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { + if config == nil { + return nil, errors.New("namecheap: the configuration of the DNS provider is nil") + } + + if config.APIUser == "" || config.APIKey == "" { + return nil, errors.New("namecheap: credentials missing") + } + + if config.ClientIP == "" { + clientIP, err := internal.GetClientIP(context.Background(), config.HTTPClient, config.Debug) + if err != nil { + return nil, fmt.Errorf("namecheap: %w", err) + } + config.ClientIP = clientIP + } + + client := internal.NewClient(config.APIUser, config.APIKey, config.ClientIP) + client.BaseURL = config.BaseURL + + if config.HTTPClient != nil { + client.HTTPClient = config.HTTPClient + } + + return &DNSProvider{config: config, client: client}, nil +} + +// Timeout returns the timeout and interval to use when checking for DNS propagation. +// Namecheap can sometimes take a long time to complete an update, so wait up to 60 minutes for the update to propagate. +func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { + return d.config.PropagationTimeout, d.config.PollingInterval +} + +// Present installs a TXT record for the DNS challenge. +func (d *DNSProvider) Present(domain, token, keyAuth string) error { + // TODO(ldez) replace domain by FQDN to follow CNAME. + ch, err := newChallenge(domain, keyAuth) + if err != nil { + return fmt.Errorf("namecheap: %w", err) + } + + ctx := context.Background() + + records, err := d.client.GetHosts(ctx, ch.sld, ch.tld) + if err != nil { + return fmt.Errorf("namecheap: %w", err) + } + + record := internal.Record{ + Name: ch.key, + Type: "TXT", + Address: ch.keyValue, + MXPref: "10", + TTL: strconv.Itoa(d.config.TTL), + } + + records = append(records, record) + + if d.config.Debug { + for _, h := range records { + log.Printf("%-5.5s %-30.30s %-6s %-70.70s", h.Type, h.Name, h.TTL, h.Address) + } + } + + err = d.client.SetHosts(ctx, ch.sld, ch.tld, records) + if err != nil { + return fmt.Errorf("namecheap: %w", err) + } + return nil +} + +// CleanUp removes a TXT record used for a previous DNS challenge. +func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { + // TODO(ldez) replace domain by FQDN to follow CNAME. + ch, err := newChallenge(domain, keyAuth) + if err != nil { + return fmt.Errorf("namecheap: %w", err) + } + + ctx := context.Background() + + records, err := d.client.GetHosts(ctx, ch.sld, ch.tld) + if err != nil { + return fmt.Errorf("namecheap: %w", err) + } + + // Find the challenge TXT record and remove it if found. + var found bool + var newRecords []internal.Record + for _, h := range records { + if h.Name == ch.key && h.Type == "TXT" { + found = true + } else { + newRecords = append(newRecords, h) + } + } + + if !found { + return nil + } + + err = d.client.SetHosts(ctx, ch.sld, ch.tld, newRecords) + if err != nil { + return fmt.Errorf("namecheap: %w", err) + } + return nil +} diff --git a/providers/dns/namecheap/namecheap_test.go b/providers/dns/namecheap/namecheap_test.go index 6b2afb5b..e3cbb78b 100644 --- a/providers/dns/namecheap/namecheap_test.go +++ b/providers/dns/namecheap/namecheap_test.go @@ -1,13 +1,16 @@ package namecheap import ( - "fmt" + "io" "net/http" "net/http/httptest" "net/url" + "os" + "path/filepath" "testing" "time" + "github.com/go-acme/lego/v4/providers/dns/namecheap/internal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -18,66 +21,130 @@ const ( envTestClientIP = "10.0.0.1" ) -func TestDNSProvider_getHosts(t *testing.T) { - for _, test := range testCases { - t.Run(test.name, func(t *testing.T) { - p := setupTest(t, &test) - - ch, err := newChallenge(test.domain, "") - require.NoError(t, err) - - hosts, err := p.getHosts(ch.sld, ch.tld) - if test.errString != "" { - assert.EqualError(t, err, test.errString) - } else { - assert.NoError(t, err) - } - - next1: - for _, h := range hosts { - for _, th := range test.hosts { - if h == th { - continue next1 - } - } - t.Errorf("getHosts case %s unexpected record [%s:%s:%s]", test.name, h.Type, h.Name, h.Address) - } - - next2: - for _, th := range test.hosts { - for _, h := range hosts { - if h == th { - continue next2 - } - } - t.Errorf("getHosts case %s missing record [%s:%s:%s]", test.name, th.Type, th.Name, th.Address) - } - }) - } +type testCase struct { + name string + domain string + hosts []internal.Record + errString string + getHostsResponse string + setHostsResponse string } -func TestDNSProvider_setHosts(t *testing.T) { - for _, test := range testCases { - t.Run(test.name, func(t *testing.T) { - p := setupTest(t, &test) +var testCases = []testCase{ + { + name: "Test:Success:1", + domain: "test.example.com", + hosts: []internal.Record{ + {Type: "A", Name: "home", Address: "10.0.0.1", MXPref: "10", TTL: "1799"}, + {Type: "A", Name: "www", Address: "10.0.0.2", MXPref: "10", TTL: "1200"}, + {Type: "AAAA", Name: "a", Address: "::0", MXPref: "10", TTL: "1799"}, + {Type: "CNAME", Name: "*", Address: "example.com.", MXPref: "10", TTL: "1799"}, + {Type: "MXE", Name: "example.com", Address: "10.0.0.5", MXPref: "10", TTL: "1800"}, + {Type: "URL", Name: "xyz", Address: "https://google.com", MXPref: "10", TTL: "1799"}, + }, + getHostsResponse: "getHosts_success1.xml", + setHostsResponse: "setHosts_success1.xml", + }, + { + name: "Test:Success:2", + domain: "example.com", + hosts: []internal.Record{ + {Type: "A", Name: "@", Address: "10.0.0.2", MXPref: "10", TTL: "1200"}, + {Type: "A", Name: "www", Address: "10.0.0.3", MXPref: "10", TTL: "60"}, + }, + getHostsResponse: "getHosts_success2.xml", + setHostsResponse: "setHosts_success2.xml", + }, + { + name: "Test:Error:BadApiKey:1", + domain: "test.example.com", + errString: "API Key is invalid or API access has not been enabled [1011102]", + getHostsResponse: "getHosts_errorBadAPIKey1.xml", + }, +} - ch, err := newChallenge(test.domain, "") - require.NoError(t, err) +func setupTest(t *testing.T, tc *testCase) *DNSProvider { + t.Helper() - hosts, err := p.getHosts(ch.sld, ch.tld) - if test.errString != "" { - assert.EqualError(t, err, test.errString) - } else { - require.NoError(t, err) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + values := r.URL.Query() + cmd := values.Get("Command") + switch cmd { + case "namecheap.domains.dns.getHosts": + assertHdr(t, tc, &values) + w.WriteHeader(http.StatusOK) + writeFixture(w, tc.getHostsResponse) + default: + t.Errorf("Unexpected GET command: %s", cmd) } + + case http.MethodPost: + err := r.ParseForm() if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) return } + values := r.Form + cmd := values.Get("Command") + switch cmd { + case "namecheap.domains.dns.setHosts": + assertHdr(t, tc, &values) + w.WriteHeader(http.StatusOK) + writeFixture(w, tc.setHostsResponse) + default: + t.Errorf("Unexpected POST command: %s", cmd) + } - err = p.setHosts(ch.sld, ch.tld, hosts) - require.NoError(t, err) - }) + default: + t.Errorf("Unexpected http method: %s", r.Method) + } + }) + + server := httptest.NewServer(handler) + t.Cleanup(server.Close) + + return mockDNSProvider(t, server.URL) +} + +func mockDNSProvider(t *testing.T, baseURL string) *DNSProvider { + t.Helper() + + config := NewDefaultConfig() + config.BaseURL = baseURL + config.APIUser = envTestUser + config.APIKey = envTestKey + config.ClientIP = envTestClientIP + config.HTTPClient = &http.Client{Timeout: 60 * time.Second} + + provider, err := NewDNSProviderConfig(config) + require.NoError(t, err) + + return provider +} + +func assertHdr(t *testing.T, tc *testCase, values *url.Values) { + t.Helper() + + ch, _ := newChallenge(tc.domain, "") + assert.Equal(t, envTestUser, values.Get("ApiUser"), "ApiUser") + assert.Equal(t, envTestKey, values.Get("ApiKey"), "ApiKey") + assert.Equal(t, envTestUser, values.Get("UserName"), "UserName") + assert.Equal(t, envTestClientIP, values.Get("ClientIp"), "ClientIp") + assert.Equal(t, ch.sld, values.Get("SLD"), "SLD") + assert.Equal(t, ch.tld, values.Get("TLD"), "TLD") +} + +func writeFixture(rw http.ResponseWriter, filename string) { + file, err := os.Open(filepath.Join("internal", "fixtures", filename)) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return } + defer func() { _ = file.Close() }() + + _, _ = io.Copy(rw, file) } func TestDNSProvider_Present(t *testing.T) { @@ -160,196 +227,3 @@ func TestDomainSplit(t *testing.T) { }) } } - -func assertHdr(t *testing.T, tc *testCase, values *url.Values) { - t.Helper() - - ch, _ := newChallenge(tc.domain, "") - assert.Equal(t, envTestUser, values.Get("ApiUser"), "ApiUser") - assert.Equal(t, envTestKey, values.Get("ApiKey"), "ApiKey") - assert.Equal(t, envTestUser, values.Get("UserName"), "UserName") - assert.Equal(t, envTestClientIP, values.Get("ClientIp"), "ClientIp") - assert.Equal(t, ch.sld, values.Get("SLD"), "SLD") - assert.Equal(t, ch.tld, values.Get("TLD"), "TLD") -} - -func setupTest(t *testing.T, tc *testCase) *DNSProvider { - t.Helper() - - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - values := r.URL.Query() - cmd := values.Get("Command") - switch cmd { - case "namecheap.domains.dns.getHosts": - assertHdr(t, tc, &values) - w.WriteHeader(http.StatusOK) - fmt.Fprint(w, tc.getHostsResponse) - default: - t.Errorf("Unexpected GET command: %s", cmd) - } - - case http.MethodPost: - err := r.ParseForm() - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - values := r.Form - cmd := values.Get("Command") - switch cmd { - case "namecheap.domains.dns.setHosts": - assertHdr(t, tc, &values) - w.WriteHeader(http.StatusOK) - fmt.Fprint(w, tc.setHostsResponse) - default: - t.Errorf("Unexpected POST command: %s", cmd) - } - - default: - t.Errorf("Unexpected http method: %s", r.Method) - } - }) - - server := httptest.NewServer(handler) - t.Cleanup(server.Close) - - return mockDNSProvider(t, server.URL) -} - -func mockDNSProvider(t *testing.T, baseURL string) *DNSProvider { - t.Helper() - - config := NewDefaultConfig() - config.BaseURL = baseURL - config.APIUser = envTestUser - config.APIKey = envTestKey - config.ClientIP = envTestClientIP - config.HTTPClient = &http.Client{Timeout: 60 * time.Second} - - provider, err := NewDNSProviderConfig(config) - require.NoError(t, err) - - return provider -} - -type testCase struct { - name string - domain string - hosts []Record - errString string - getHostsResponse string - setHostsResponse string -} - -var testCases = []testCase{ - { - name: "Test:Success:1", - domain: "test.example.com", - hosts: []Record{ - {Type: "A", Name: "home", Address: "10.0.0.1", MXPref: "10", TTL: "1799"}, - {Type: "A", Name: "www", Address: "10.0.0.2", MXPref: "10", TTL: "1200"}, - {Type: "AAAA", Name: "a", Address: "::0", MXPref: "10", TTL: "1799"}, - {Type: "CNAME", Name: "*", Address: "example.com.", MXPref: "10", TTL: "1799"}, - {Type: "MXE", Name: "example.com", Address: "10.0.0.5", MXPref: "10", TTL: "1800"}, - {Type: "URL", Name: "xyz", Address: "https://google.com", MXPref: "10", TTL: "1799"}, - }, - getHostsResponse: responseGetHostsSuccess1, - setHostsResponse: responseSetHostsSuccess1, - }, - { - name: "Test:Success:2", - domain: "example.com", - hosts: []Record{ - {Type: "A", Name: "@", Address: "10.0.0.2", MXPref: "10", TTL: "1200"}, - {Type: "A", Name: "www", Address: "10.0.0.3", MXPref: "10", TTL: "60"}, - }, - getHostsResponse: responseGetHostsSuccess2, - setHostsResponse: responseSetHostsSuccess2, - }, - { - name: "Test:Error:BadApiKey:1", - domain: "test.example.com", - errString: "API Key is invalid or API access has not been enabled [1011102]", - getHostsResponse: responseGetHostsErrorBadAPIKey1, - }, -} - -const responseGetHostsSuccess1 = ` - - - - namecheap.domains.dns.getHosts - - - - - - - - - - - PHX01SBAPI01 - --5:00 - 3.338 -` - -const responseSetHostsSuccess1 = ` - - - - namecheap.domains.dns.setHosts - - - - - - PHX01SBAPI01 - --5:00 - 2.347 -` - -const responseGetHostsSuccess2 = ` - - - - namecheap.domains.dns.getHosts - - - - - - - PHX01SBAPI01 - --5:00 - 3.338 -` - -const responseSetHostsSuccess2 = ` - - - - namecheap.domains.dns.setHosts - - - - - - PHX01SBAPI01 - --5:00 - 2.347 -` - -const responseGetHostsErrorBadAPIKey1 = ` - - - API Key is invalid or API access has not been enabled - - - - PHX01SBAPI01 - --5:00 - 0 -` diff --git a/providers/dns/namesilo/namesilo.go b/providers/dns/namesilo/namesilo.go index 1f59911e..bd1a3553 100644 --- a/providers/dns/namesilo/namesilo.go +++ b/providers/dns/namesilo/namesilo.go @@ -88,11 +88,13 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - zoneName, err := getZoneNameByDomain(info.EffectiveFQDN) + zone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("namesilo: %w", err) + return fmt.Errorf("namesilo: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } + zoneName := dns01.UnFqdn(zone) + subdomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, zoneName) if err != nil { return fmt.Errorf("namesilo: %w", err) @@ -120,11 +122,13 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, _, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - zoneName, err := getZoneNameByDomain(info.EffectiveFQDN) + zone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("namesilo: %w", err) + return fmt.Errorf("namesilo: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } + zoneName := dns01.UnFqdn(zone) + resp, err := d.client.DnsListRecords(&namesilo.DnsListRecordsParams{Domain: zoneName}) if err != nil { return fmt.Errorf("namesilo: %w", err) @@ -152,11 +156,3 @@ func (d *DNSProvider) CleanUp(domain, _, keyAuth string) error { func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { return d.config.PropagationTimeout, d.config.PollingInterval } - -func getZoneNameByDomain(domain string) (string, error) { - zone, err := dns01.FindZoneByFqdn(domain) - if err != nil { - return "", fmt.Errorf("failed to find zone for domain: %s, %w", domain, err) - } - return dns01.UnFqdn(zone), nil -} diff --git a/providers/dns/nearlyfreespeech/internal/client.go b/providers/dns/nearlyfreespeech/internal/client.go index a59636ae..1242c6ad 100644 --- a/providers/dns/nearlyfreespeech/internal/client.go +++ b/providers/dns/nearlyfreespeech/internal/client.go @@ -1,6 +1,7 @@ package internal import ( + "context" "crypto/sha1" "encoding/json" "fmt" @@ -13,6 +14,7 @@ import ( "time" "github.com/go-acme/lego/v4/challenge/dns01" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" querystring "github.com/google/go-querystring/query" ) @@ -23,24 +25,25 @@ const authenticationHeader = "X-NFSN-Authentication" const saltBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" type Client struct { - HTTPClient *http.Client - baseURL *url.URL - login string apiKey string + + baseURL *url.URL + HTTPClient *http.Client } func NewClient(login string, apiKey string) *Client { baseURL, _ := url.Parse(apiURL) + return &Client{ - HTTPClient: &http.Client{Timeout: 10 * time.Second}, - baseURL: baseURL, login: login, apiKey: apiKey, + baseURL: baseURL, + HTTPClient: &http.Client{Timeout: 10 * time.Second}, } } -func (c Client) AddRecord(domain string, record Record) error { +func (c Client) AddRecord(ctx context.Context, domain string, record Record) error { endpoint := c.baseURL.JoinPath("dns", dns01.UnFqdn(domain), "addRR") params, err := querystring.Values(record) @@ -48,10 +51,10 @@ func (c Client) AddRecord(domain string, record Record) error { return err } - return c.do(endpoint, params) + return c.doRequest(ctx, endpoint, params) } -func (c Client) RemoveRecord(domain string, record Record) error { +func (c Client) RemoveRecord(ctx context.Context, domain string, record Record) error { endpoint := c.baseURL.JoinPath("dns", dns01.UnFqdn(domain), "removeRR") params, err := querystring.Values(record) @@ -59,15 +62,15 @@ func (c Client) RemoveRecord(domain string, record Record) error { return err } - return c.do(endpoint, params) + return c.doRequest(ctx, endpoint, params) } -func (c Client) do(endpoint *url.URL, params url.Values) error { +func (c Client) doRequest(ctx context.Context, endpoint *url.URL, params url.Values) error { payload := params.Encode() - req, err := http.NewRequest(http.MethodPost, endpoint.String(), strings.NewReader(payload)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint.String(), strings.NewReader(payload)) if err != nil { - return err + return fmt.Errorf("unable to create request: %w", err) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -75,21 +78,13 @@ func (c Client) do(endpoint *url.URL, params url.Values) error { resp, err := c.HTTPClient.Do(req) if err != nil { - return err + return errutils.NewHTTPDoError(req, err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - data, _ := io.ReadAll(resp.Body) - - apiErr := &APIError{} - err := json.Unmarshal(data, apiErr) - if err != nil { - return fmt.Errorf("%s: %s", resp.Status, data) - } - - return apiErr + return parseError(req, resp) } return nil @@ -113,3 +108,15 @@ func (c Client) createSignature(uri string, body string) string { return fmt.Sprintf("%s;%s;%s;%02x", c.login, timestamp, salt, sha1.Sum([]byte(hashInput))) } + +func parseError(req *http.Request, resp *http.Response) error { + raw, _ := io.ReadAll(resp.Body) + + errAPI := &APIError{} + err := json.Unmarshal(raw, errAPI) + if err != nil { + return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw) + } + + return errAPI +} diff --git a/providers/dns/nearlyfreespeech/internal/client_test.go b/providers/dns/nearlyfreespeech/internal/client_test.go index b5bf30a9..05d7d676 100644 --- a/providers/dns/nearlyfreespeech/internal/client_test.go +++ b/providers/dns/nearlyfreespeech/internal/client_test.go @@ -1,6 +1,7 @@ package internal import ( + "context" "fmt" "io" "net/http" @@ -16,8 +17,8 @@ func setupTest(t *testing.T) (*Client, *http.ServeMux) { t.Helper() mux := http.NewServeMux() - server := httptest.NewServer(mux) + t.Cleanup(server.Close) client := NewClient("user", "secret") client.HTTPClient = server.Client() @@ -91,7 +92,7 @@ func TestClient_AddRecord(t *testing.T) { TTL: 30, } - err := client.AddRecord("example.com", record) + err := client.AddRecord(context.Background(), "example.com", record) require.NoError(t, err) } @@ -107,7 +108,7 @@ func TestClient_AddRecord_error(t *testing.T) { TTL: 30, } - err := client.AddRecord("example.com", record) + err := client.AddRecord(context.Background(), "example.com", record) require.Error(t, err) } @@ -128,7 +129,7 @@ func TestClient_RemoveRecord(t *testing.T) { Data: "txtTXTtxt", } - err := client.RemoveRecord("example.com", record) + err := client.RemoveRecord(context.Background(), "example.com", record) require.NoError(t, err) } @@ -143,6 +144,6 @@ func TestClient_RemoveRecord_error(t *testing.T) { Data: "txtTXTtxt", } - err := client.RemoveRecord("example.com", record) + err := client.RemoveRecord(context.Background(), "example.com", record) require.Error(t, err) } diff --git a/providers/dns/nearlyfreespeech/nearlyfreespeech.go b/providers/dns/nearlyfreespeech/nearlyfreespeech.go index b6c9e810..eb001da8 100644 --- a/providers/dns/nearlyfreespeech/nearlyfreespeech.go +++ b/providers/dns/nearlyfreespeech/nearlyfreespeech.go @@ -2,6 +2,7 @@ package nearlyfreespeech import ( + "context" "errors" "fmt" "net/http" @@ -112,7 +113,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("nearlyfreespeech: could not determine zone for domain %q: %w", info.EffectiveFQDN, err) + return fmt.Errorf("nearlyfreespeech: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } recordName, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) @@ -127,7 +128,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { TTL: d.config.TTL, } - err = d.client.AddRecord(authZone, record) + err = d.client.AddRecord(context.Background(), authZone, record) if err != nil { return fmt.Errorf("nearlyfreespeech: %w", err) } @@ -141,7 +142,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("nearlyfreespeech: could not determine zone for domain %q: %w", info.EffectiveFQDN, err) + return fmt.Errorf("nearlyfreespeech: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } recordName, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) @@ -155,7 +156,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { Data: info.Value, } - err = d.client.RemoveRecord(domain, record) + err = d.client.RemoveRecord(context.Background(), domain, record) if err != nil { return fmt.Errorf("nearlyfreespeech: %w", err) } diff --git a/providers/dns/netcup/internal/client.go b/providers/dns/netcup/internal/client.go index bba0250d..9573c09c 100644 --- a/providers/dns/netcup/internal/client.go +++ b/providers/dns/netcup/internal/client.go @@ -2,123 +2,28 @@ package internal import ( "bytes" + "context" "encoding/json" "errors" "fmt" "io" "net/http" "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) // defaultBaseURL for reaching the jSON-based API-Endpoint of netcup. const defaultBaseURL = "https://ccp.netcup.net/run/webservice/servers/endpoint.php?JSON" -// success response status. -const success = "success" - -// Request wrapper as specified in netcup wiki -// needed for every request to netcup API around *Msg. -// https://www.netcup-wiki.de/wiki/CCP_API#Anmerkungen_zu_JSON-Requests -type Request struct { - Action string `json:"action"` - Param interface{} `json:"param"` -} - -// LoginRequest as specified in netcup WSDL. -// https://ccp.netcup.net/run/webservice/servers/endpoint.php#login -type LoginRequest struct { - CustomerNumber string `json:"customernumber"` - APIKey string `json:"apikey"` - APIPassword string `json:"apipassword"` - ClientRequestID string `json:"clientrequestid,omitempty"` -} - -// LogoutRequest as specified in netcup WSDL. -// https://ccp.netcup.net/run/webservice/servers/endpoint.php#logout -type LogoutRequest struct { - CustomerNumber string `json:"customernumber"` - APIKey string `json:"apikey"` - APISessionID string `json:"apisessionid"` - ClientRequestID string `json:"clientrequestid,omitempty"` -} - -// UpdateDNSRecordsRequest as specified in netcup WSDL. -// https://ccp.netcup.net/run/webservice/servers/endpoint.php#updateDnsRecords -type UpdateDNSRecordsRequest struct { - DomainName string `json:"domainname"` - CustomerNumber string `json:"customernumber"` - APIKey string `json:"apikey"` - APISessionID string `json:"apisessionid"` - ClientRequestID string `json:"clientrequestid,omitempty"` - DNSRecordSet DNSRecordSet `json:"dnsrecordset"` -} - -// DNSRecordSet as specified in netcup WSDL. -// needed in UpdateDNSRecordsRequest. -// https://ccp.netcup.net/run/webservice/servers/endpoint.php#Dnsrecordset -type DNSRecordSet struct { - DNSRecords []DNSRecord `json:"dnsrecords"` -} - -// InfoDNSRecordsRequest as specified in netcup WSDL. -// https://ccp.netcup.net/run/webservice/servers/endpoint.php#infoDnsRecords -type InfoDNSRecordsRequest struct { - DomainName string `json:"domainname"` - CustomerNumber string `json:"customernumber"` - APIKey string `json:"apikey"` - APISessionID string `json:"apisessionid"` - ClientRequestID string `json:"clientrequestid,omitempty"` -} - -// DNSRecord as specified in netcup WSDL. -// https://ccp.netcup.net/run/webservice/servers/endpoint.php#Dnsrecord -type DNSRecord struct { - ID int `json:"id,string,omitempty"` - Hostname string `json:"hostname"` - RecordType string `json:"type"` - Priority string `json:"priority,omitempty"` - Destination string `json:"destination"` - DeleteRecord bool `json:"deleterecord,omitempty"` - State string `json:"state,omitempty"` - TTL int `json:"ttl,omitempty"` -} - -// ResponseMsg as specified in netcup WSDL. -// https://ccp.netcup.net/run/webservice/servers/endpoint.php#Responsemessage -type ResponseMsg struct { - ServerRequestID string `json:"serverrequestid"` - ClientRequestID string `json:"clientrequestid,omitempty"` - Action string `json:"action"` - Status string `json:"status"` - StatusCode int `json:"statuscode"` - ShortMessage string `json:"shortmessage"` - LongMessage string `json:"longmessage"` - ResponseData json.RawMessage `json:"responsedata,omitempty"` -} - -func (r *ResponseMsg) Error() string { - return fmt.Sprintf("an error occurred during the action %s: [Status=%s, StatusCode=%d, ShortMessage=%s, LongMessage=%s]", - r.Action, r.Status, r.StatusCode, r.ShortMessage, r.LongMessage) -} - -// LoginResponse response to login action. -type LoginResponse struct { - APISessionID string `json:"apisessionid"` -} - -// InfoDNSRecordsResponse response to infoDnsRecords action. -type InfoDNSRecordsResponse struct { - APISessionID string `json:"apisessionid"` - DNSRecords []DNSRecord `json:"dnsrecords,omitempty"` -} - // Client netcup DNS client. type Client struct { customerNumber string apiKey string apiPassword string - HTTPClient *http.Client - BaseURL string + + baseURL string + HTTPClient *http.Client } // NewClient creates a netcup DNS client. @@ -131,73 +36,27 @@ func NewClient(customerNumber, apiKey, apiPassword string) (*Client, error) { customerNumber: customerNumber, apiKey: apiKey, apiPassword: apiPassword, - BaseURL: defaultBaseURL, - HTTPClient: &http.Client{ - Timeout: 10 * time.Second, - }, + baseURL: defaultBaseURL, + HTTPClient: &http.Client{Timeout: 10 * time.Second}, }, nil } -// Login performs the login as specified by the netcup WSDL -// returns sessionID needed to perform remaining actions. -// https://ccp.netcup.net/run/webservice/servers/endpoint.php -func (c *Client) Login() (string, error) { - payload := &Request{ - Action: "login", - Param: &LoginRequest{ - CustomerNumber: c.customerNumber, - APIKey: c.apiKey, - APIPassword: c.apiPassword, - ClientRequestID: "", - }, - } - - var responseData LoginResponse - err := c.doRequest(payload, &responseData) - if err != nil { - return "", fmt.Errorf("loging error: %w", err) - } - - return responseData.APISessionID, nil -} - -// Logout performs the logout with the supplied sessionID as specified by the netcup WSDL. -// https://ccp.netcup.net/run/webservice/servers/endpoint.php -func (c *Client) Logout(sessionID string) error { - payload := &Request{ - Action: "logout", - Param: &LogoutRequest{ - CustomerNumber: c.customerNumber, - APIKey: c.apiKey, - APISessionID: sessionID, - ClientRequestID: "", - }, - } - - err := c.doRequest(payload, nil) - if err != nil { - return fmt.Errorf("logout error: %w", err) - } - - return nil -} - // UpdateDNSRecord performs an update of the DNSRecords as specified by the netcup WSDL. // https://ccp.netcup.net/run/webservice/servers/endpoint.php -func (c *Client) UpdateDNSRecord(sessionID, domainName string, records []DNSRecord) error { +func (c *Client) UpdateDNSRecord(ctx context.Context, domainName string, records []DNSRecord) error { payload := &Request{ Action: "updateDnsRecords", Param: UpdateDNSRecordsRequest{ DomainName: domainName, CustomerNumber: c.customerNumber, APIKey: c.apiKey, - APISessionID: sessionID, + APISessionID: getSessionID(ctx), ClientRequestID: "", DNSRecordSet: DNSRecordSet{DNSRecords: records}, }, } - err := c.doRequest(payload, nil) + err := c.doRequest(ctx, payload, nil) if err != nil { return fmt.Errorf("error when sending the request: %w", err) } @@ -208,20 +67,20 @@ func (c *Client) UpdateDNSRecord(sessionID, domainName string, records []DNSReco // GetDNSRecords retrieves all dns records of an DNS-Zone as specified by the netcup WSDL // returns an array of DNSRecords. // https://ccp.netcup.net/run/webservice/servers/endpoint.php -func (c *Client) GetDNSRecords(hostname, apiSessionID string) ([]DNSRecord, error) { +func (c *Client) GetDNSRecords(ctx context.Context, hostname string) ([]DNSRecord, error) { payload := &Request{ Action: "infoDnsRecords", Param: InfoDNSRecordsRequest{ DomainName: hostname, CustomerNumber: c.customerNumber, APIKey: c.apiKey, - APISessionID: apiSessionID, + APISessionID: getSessionID(ctx), ClientRequestID: "", }, } var responseData InfoDNSRecordsResponse - err := c.doRequest(payload, &responseData) + err := c.doRequest(ctx, payload, &responseData) if err != nil { return nil, fmt.Errorf("error when sending the request: %w", err) } @@ -231,30 +90,26 @@ func (c *Client) GetDNSRecords(hostname, apiSessionID string) ([]DNSRecord, erro // doRequest marshals given body to JSON, send the request to netcup API // and returns body of response. -func (c *Client) doRequest(payload, responseData interface{}) error { - body, err := json.Marshal(payload) - if err != nil { - return err - } - - req, err := http.NewRequest(http.MethodPost, c.BaseURL, bytes.NewReader(body)) +func (c *Client) doRequest(ctx context.Context, payload, result any) error { + req, err := newJSONRequest(ctx, http.MethodPost, c.baseURL, payload) if err != nil { return err } req.Close = true - req.Header.Set("content-type", "application/json") resp, err := c.HTTPClient.Do(req) if err != nil { - return err + return errutils.NewHTTPDoError(req, err) } - if err = checkResponse(resp); err != nil { - return err + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= http.StatusMultipleChoices { + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) } - respMsg, err := decodeResponseMsg(resp) + respMsg, err := unmarshalResponseMsg(req, resp) if err != nil { return err } @@ -263,58 +118,18 @@ func (c *Client) doRequest(payload, responseData interface{}) error { return respMsg } - if responseData != nil { - err = json.Unmarshal(respMsg.ResponseData, responseData) - if err != nil { - //nolint:errorlint // in this context respMsg is not an error. - return fmt.Errorf("%v: unmarshaling %T error: %w: %s", - respMsg, responseData, err, string(respMsg.ResponseData)) - } + if result == nil { + return nil + } + + err = json.Unmarshal(respMsg.ResponseData, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, respMsg.ResponseData, err) } return nil } -func checkResponse(resp *http.Response) error { - if resp.StatusCode > 299 { - if resp.Body == nil { - return fmt.Errorf("response body is nil, status code=%d", resp.StatusCode) - } - - defer resp.Body.Close() - - raw, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("unable to read body: status code=%d, error=%w", resp.StatusCode, err) - } - - return fmt.Errorf("status code=%d: %s", resp.StatusCode, string(raw)) - } - - return nil -} - -func decodeResponseMsg(resp *http.Response) (*ResponseMsg, error) { - if resp.Body == nil { - return nil, fmt.Errorf("response body is nil, status code=%d", resp.StatusCode) - } - - defer resp.Body.Close() - - raw, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("unable to read body: status code=%d, error=%w", resp.StatusCode, err) - } - - var respMsg ResponseMsg - err = json.Unmarshal(raw, &respMsg) - if err != nil { - return nil, fmt.Errorf("unmarshaling %T error [status code=%d]: %w: %s", respMsg, resp.StatusCode, err, string(raw)) - } - - return &respMsg, nil -} - // GetDNSRecordIdx searches a given array of DNSRecords for a given DNSRecord // equivalence is determined by Destination and RecortType attributes // returns index of given DNSRecord in given array of DNSRecords. @@ -326,3 +141,42 @@ func GetDNSRecordIdx(records []DNSRecord, record DNSRecord) (int, error) { } return -1, errors.New("no DNS Record found") } + +func newJSONRequest(ctx context.Context, method string, endpoint string, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint, buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil +} + +func unmarshalResponseMsg(req *http.Request, resp *http.Response) (*ResponseMsg, error) { + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + var respMsg ResponseMsg + err = json.Unmarshal(raw, &respMsg) + if err != nil { + return nil, errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return &respMsg, nil +} diff --git a/providers/dns/netcup/internal/client_test.go b/providers/dns/netcup/internal/client_test.go index d62e5b2c..80923fee 100644 --- a/providers/dns/netcup/internal/client_test.go +++ b/providers/dns/netcup/internal/client_test.go @@ -1,11 +1,12 @@ package internal import ( + "bytes" + "context" "fmt" "io" "net/http" "net/http/httptest" - "strconv" "strings" "testing" @@ -31,8 +32,8 @@ func setupTest(t *testing.T) (*Client, *http.ServeMux) { client, err := NewClient("a", "b", "c") require.NoError(t, err) + client.baseURL = server.URL client.HTTPClient = server.Client() - client.BaseURL = server.URL return client, mux } @@ -139,205 +140,6 @@ func TestGetDNSRecordIdx(t *testing.T) { } } -func TestClient_Login(t *testing.T) { - client, mux := setupTest(t) - - mux.HandleFunc("/", func(rw http.ResponseWriter, req *http.Request) { - raw, err := io.ReadAll(req.Body) - if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - return - } - - if string(raw) != `{"action":"login","param":{"customernumber":"a","apikey":"b","apipassword":"c"}}` { - http.Error(rw, fmt.Sprintf("invalid request body: %s", string(raw)), http.StatusBadRequest) - return - } - - response := ` - { - "serverrequestid": "srv-request-id", - "clientrequestid": "", - "action": "login", - "status": "success", - "statuscode": 2000, - "shortmessage": "Login successful", - "longmessage": "Session has been created successful.", - "responsedata": { - "apisessionid": "api-session-id" - } - } - ` - _, err = rw.Write([]byte(response)) - if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - return - } - }) - - sessionID, err := client.Login() - require.NoError(t, err) - - assert.Equal(t, "api-session-id", sessionID) -} - -func TestClient_Login_errors(t *testing.T) { - testCases := []struct { - desc string - handler func(rw http.ResponseWriter, req *http.Request) - }{ - { - desc: "HTTP error", - handler: func(rw http.ResponseWriter, _ *http.Request) { - http.Error(rw, "error message", http.StatusInternalServerError) - }, - }, - { - desc: "API error", - handler: func(rw http.ResponseWriter, _ *http.Request) { - response := ` - { - "serverrequestid":"YxTr4EzdbJ101T211zR4yzUEMVE", - "clientrequestid":"", - "action":"login", - "status":"error", - "statuscode":4013, - "shortmessage":"Validation Error.", - "longmessage":"Message is empty.", - "responsedata":"" - }` - _, err := rw.Write([]byte(response)) - if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - return - } - }, - }, - { - desc: "responsedata marshaling error", - handler: func(rw http.ResponseWriter, _ *http.Request) { - response := ` - { - "serverrequestid": "srv-request-id", - "clientrequestid": "", - "action": "login", - "status": "success", - "statuscode": 2000, - "shortmessage": "Login successful", - "longmessage": "Session has been created successful.", - "responsedata": "" - }` - _, err := rw.Write([]byte(response)) - if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - return - } - }, - }, - } - - for _, test := range testCases { - test := test - t.Run(test.desc, func(t *testing.T) { - t.Parallel() - - client, mux := setupTest(t) - - mux.HandleFunc("/", test.handler) - - sessionID, err := client.Login() - assert.Error(t, err) - assert.Equal(t, "", sessionID) - }) - } -} - -func TestClient_Logout(t *testing.T) { - client, mux := setupTest(t) - - mux.HandleFunc("/", func(rw http.ResponseWriter, req *http.Request) { - raw, err := io.ReadAll(req.Body) - if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - return - } - - if string(raw) != `{"action":"logout","param":{"customernumber":"a","apikey":"b","apisessionid":"session-id"}}` { - http.Error(rw, fmt.Sprintf("invalid request body: %s", string(raw)), http.StatusBadRequest) - return - } - - response := ` - { - "serverrequestid": "request-id", - "clientrequestid": "", - "action": "logout", - "status": "success", - "statuscode": 2000, - "shortmessage": "Logout successful", - "longmessage": "Session has been terminated successful.", - "responsedata": "" - }` - _, err = rw.Write([]byte(response)) - if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - return - } - }) - - err := client.Logout("session-id") - require.NoError(t, err) -} - -func TestClient_Logout_errors(t *testing.T) { - testCases := []struct { - desc string - handler func(rw http.ResponseWriter, req *http.Request) - }{ - { - desc: "HTTP error", - handler: func(rw http.ResponseWriter, _ *http.Request) { - http.Error(rw, "error message", http.StatusInternalServerError) - }, - }, - { - desc: "API error", - handler: func(rw http.ResponseWriter, _ *http.Request) { - response := ` - { - "serverrequestid":"YxTr4EzdbJ101T211zR4yzUEMVE", - "clientrequestid":"", - "action":"logout", - "status":"error", - "statuscode":4013, - "shortmessage":"Validation Error.", - "longmessage":"Message is empty.", - "responsedata":"" - }` - _, err := rw.Write([]byte(response)) - if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - return - } - }, - }, - } - - for _, test := range testCases { - test := test - t.Run(test.desc, func(t *testing.T) { - t.Parallel() - - client, mux := setupTest(t) - - mux.HandleFunc("/", test.handler) - - err := client.Logout("session-id") - require.Error(t, err) - }) - } -} - func TestClient_GetDNSRecords(t *testing.T) { client, mux := setupTest(t) @@ -348,7 +150,7 @@ func TestClient_GetDNSRecords(t *testing.T) { return } - if string(raw) != `{"action":"infoDnsRecords","param":{"domainname":"example.com","customernumber":"a","apikey":"b","apisessionid":"api-session-id"}}` { + if string(bytes.TrimSpace(raw)) != `{"action":"infoDnsRecords","param":{"domainname":"example.com","customernumber":"a","apikey":"b","apisessionid":""}}` { http.Error(rw, fmt.Sprintf("invalid request body: %s", string(raw)), http.StatusBadRequest) return } @@ -413,7 +215,7 @@ func TestClient_GetDNSRecords(t *testing.T) { TTL: 300, }} - records, err := client.GetDNSRecords("example.com", "api-session-id") + records, err := client.GetDNSRecords(context.Background(), "example.com") require.NoError(t, err) assert.Equal(t, expected, records) @@ -494,14 +296,14 @@ func TestClient_GetDNSRecords_errors(t *testing.T) { mux.HandleFunc("/", test.handler) - records, err := client.GetDNSRecords("example.com", "api-session-id") + records, err := client.GetDNSRecords(context.Background(), "example.com") require.Error(t, err) assert.Empty(t, records) }) } } -func TestLiveClientAuth(t *testing.T) { +func TestClient_GetDNSRecords_Live(t *testing.T) { if !envTest.IsLiveTest() { t.Skip("skipping live test") } @@ -515,35 +317,7 @@ func TestLiveClientAuth(t *testing.T) { envTest.GetValue("NETCUP_API_PASSWORD")) require.NoError(t, err) - for i := 1; i < 4; i++ { - i := i - t.Run("Test_"+strconv.Itoa(i), func(t *testing.T) { - t.Parallel() - - sessionID, err := client.Login() - require.NoError(t, err) - - err = client.Logout(sessionID) - require.NoError(t, err) - }) - } -} - -func TestLiveClientGetDnsRecords(t *testing.T) { - if !envTest.IsLiveTest() { - t.Skip("skipping live test") - } - - // Setup - envTest.RestoreEnv() - - client, err := NewClient( - envTest.GetValue("NETCUP_CUSTOMER_NUMBER"), - envTest.GetValue("NETCUP_API_KEY"), - envTest.GetValue("NETCUP_API_PASSWORD")) - require.NoError(t, err) - - sessionID, err := client.Login() + ctx, err := client.CreateSessionContext(context.Background()) require.NoError(t, err) info := dns01.GetChallengeInfo(envTest.GetDomain(), "123d==") @@ -554,15 +328,15 @@ func TestLiveClientGetDnsRecords(t *testing.T) { zone = dns01.UnFqdn(zone) // TestMethod - _, err = client.GetDNSRecords(zone, sessionID) + _, err = client.GetDNSRecords(ctx, zone) require.NoError(t, err) // Tear down - err = client.Logout(sessionID) + err = client.Logout(ctx) require.NoError(t, err) } -func TestLiveClientUpdateDnsRecord(t *testing.T) { +func TestClient_UpdateDNSRecord_Live(t *testing.T) { if !envTest.IsLiveTest() { t.Skip("skipping live test") } @@ -576,7 +350,7 @@ func TestLiveClientUpdateDnsRecord(t *testing.T) { envTest.GetValue("NETCUP_API_PASSWORD")) require.NoError(t, err) - sessionID, err := client.Login() + ctx, err := client.CreateSessionContext(context.Background()) require.NoError(t, err) info := dns01.GetChallengeInfo(envTest.GetDomain(), "123d==") @@ -597,10 +371,10 @@ func TestLiveClientUpdateDnsRecord(t *testing.T) { // test zone = dns01.UnFqdn(zone) - err = client.UpdateDNSRecord(sessionID, zone, []DNSRecord{record}) + err = client.UpdateDNSRecord(ctx, zone, []DNSRecord{record}) require.NoError(t, err) - records, err := client.GetDNSRecords(zone, sessionID) + records, err := client.GetDNSRecords(ctx, zone) require.NoError(t, err) recordIdx, err := GetDNSRecordIdx(records, record) @@ -614,9 +388,9 @@ func TestLiveClientUpdateDnsRecord(t *testing.T) { records[recordIdx].DeleteRecord = true // Tear down - err = client.UpdateDNSRecord(sessionID, envTest.GetDomain(), []DNSRecord{records[recordIdx]}) + err = client.UpdateDNSRecord(ctx, envTest.GetDomain(), []DNSRecord{records[recordIdx]}) require.NoError(t, err, "Did not remove record! Please do so yourself.") - err = client.Logout(sessionID) + err = client.Logout(ctx) require.NoError(t, err) } diff --git a/providers/dns/netcup/internal/session.go b/providers/dns/netcup/internal/session.go new file mode 100644 index 00000000..6627d74e --- /dev/null +++ b/providers/dns/netcup/internal/session.go @@ -0,0 +1,72 @@ +package internal + +import ( + "context" + "fmt" +) + +type sessionKey string + +const sessionIDKey sessionKey = "sessionID" + +// login performs the login as specified by the netcup WSDL +// returns sessionID needed to perform remaining actions. +// https://ccp.netcup.net/run/webservice/servers/endpoint.php +func (c *Client) login(ctx context.Context) (string, error) { + payload := &Request{ + Action: "login", + Param: &LoginRequest{ + CustomerNumber: c.customerNumber, + APIKey: c.apiKey, + APIPassword: c.apiPassword, + ClientRequestID: "", + }, + } + + var responseData LoginResponse + err := c.doRequest(ctx, payload, &responseData) + if err != nil { + return "", fmt.Errorf("loging error: %w", err) + } + + return responseData.APISessionID, nil +} + +// Logout performs the logout with the supplied sessionID as specified by the netcup WSDL. +// https://ccp.netcup.net/run/webservice/servers/endpoint.php +func (c *Client) Logout(ctx context.Context) error { + payload := &Request{ + Action: "logout", + Param: &LogoutRequest{ + CustomerNumber: c.customerNumber, + APIKey: c.apiKey, + APISessionID: getSessionID(ctx), + ClientRequestID: "", + }, + } + + err := c.doRequest(ctx, payload, nil) + if err != nil { + return fmt.Errorf("logout error: %w", err) + } + + return nil +} + +func (c *Client) CreateSessionContext(ctx context.Context) (context.Context, error) { + sessID, err := c.login(ctx) + if err != nil { + return nil, err + } + + return context.WithValue(ctx, sessionIDKey, sessID), nil +} + +func getSessionID(ctx context.Context) string { + sessID, ok := ctx.Value(sessionIDKey).(string) + if !ok { + return "" + } + + return sessID +} diff --git a/providers/dns/netcup/internal/session_test.go b/providers/dns/netcup/internal/session_test.go new file mode 100644 index 00000000..ab605779 --- /dev/null +++ b/providers/dns/netcup/internal/session_test.go @@ -0,0 +1,245 @@ +package internal + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func mockContext() context.Context { + return context.WithValue(context.Background(), sessionIDKey, "session-id") +} + +func TestClient_Login(t *testing.T) { + client, mux := setupTest(t) + + mux.HandleFunc("/", func(rw http.ResponseWriter, req *http.Request) { + raw, err := io.ReadAll(req.Body) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + + if string(bytes.TrimSpace(raw)) != `{"action":"login","param":{"customernumber":"a","apikey":"b","apipassword":"c"}}` { + http.Error(rw, fmt.Sprintf("invalid request body: %s", string(raw)), http.StatusBadRequest) + return + } + + response := ` + { + "serverrequestid": "srv-request-id", + "clientrequestid": "", + "action": "login", + "status": "success", + "statuscode": 2000, + "shortmessage": "Login successful", + "longmessage": "Session has been created successful.", + "responsedata": { + "apisessionid": "api-session-id" + } + } + ` + _, err = rw.Write([]byte(response)) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + }) + + sessionID, err := client.login(context.Background()) + require.NoError(t, err) + + assert.Equal(t, "api-session-id", sessionID) +} + +func TestClient_Login_errors(t *testing.T) { + testCases := []struct { + desc string + handler func(rw http.ResponseWriter, req *http.Request) + }{ + { + desc: "HTTP error", + handler: func(rw http.ResponseWriter, _ *http.Request) { + http.Error(rw, "error message", http.StatusInternalServerError) + }, + }, + { + desc: "API error", + handler: func(rw http.ResponseWriter, _ *http.Request) { + response := ` + { + "serverrequestid":"YxTr4EzdbJ101T211zR4yzUEMVE", + "clientrequestid":"", + "action":"login", + "status":"error", + "statuscode":4013, + "shortmessage":"Validation Error.", + "longmessage":"Message is empty.", + "responsedata":"" + }` + _, err := rw.Write([]byte(response)) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + }, + }, + { + desc: "responsedata marshaling error", + handler: func(rw http.ResponseWriter, _ *http.Request) { + response := ` + { + "serverrequestid": "srv-request-id", + "clientrequestid": "", + "action": "login", + "status": "success", + "statuscode": 2000, + "shortmessage": "Login successful", + "longmessage": "Session has been created successful.", + "responsedata": "" + }` + _, err := rw.Write([]byte(response)) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + }, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + client, mux := setupTest(t) + + mux.HandleFunc("/", test.handler) + + sessionID, err := client.login(context.Background()) + assert.Error(t, err) + assert.Equal(t, "", sessionID) + }) + } +} + +func TestClient_Logout(t *testing.T) { + client, mux := setupTest(t) + + mux.HandleFunc("/", func(rw http.ResponseWriter, req *http.Request) { + raw, err := io.ReadAll(req.Body) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + + if string(bytes.TrimSpace(raw)) != `{"action":"logout","param":{"customernumber":"a","apikey":"b","apisessionid":"session-id"}}` { + http.Error(rw, fmt.Sprintf("invalid request body: %s", string(raw)), http.StatusBadRequest) + return + } + + response := ` + { + "serverrequestid": "request-id", + "clientrequestid": "", + "action": "logout", + "status": "success", + "statuscode": 2000, + "shortmessage": "Logout successful", + "longmessage": "Session has been terminated successful.", + "responsedata": "" + }` + _, err = rw.Write([]byte(response)) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + }) + + err := client.Logout(mockContext()) + require.NoError(t, err) +} + +func TestClient_Logout_errors(t *testing.T) { + testCases := []struct { + desc string + handler func(rw http.ResponseWriter, req *http.Request) + }{ + { + desc: "HTTP error", + handler: func(rw http.ResponseWriter, _ *http.Request) { + http.Error(rw, "error message", http.StatusInternalServerError) + }, + }, + { + desc: "API error", + handler: func(rw http.ResponseWriter, _ *http.Request) { + response := ` + { + "serverrequestid":"YxTr4EzdbJ101T211zR4yzUEMVE", + "clientrequestid":"", + "action":"logout", + "status":"error", + "statuscode":4013, + "shortmessage":"Validation Error.", + "longmessage":"Message is empty.", + "responsedata":"" + }` + _, err := rw.Write([]byte(response)) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + }, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + client, mux := setupTest(t) + + mux.HandleFunc("/", test.handler) + + err := client.Logout(context.Background()) + require.Error(t, err) + }) + } +} + +func TestLiveClientAuth(t *testing.T) { + if !envTest.IsLiveTest() { + t.Skip("skipping live test") + } + + // Setup + envTest.RestoreEnv() + + client, err := NewClient( + envTest.GetValue("NETCUP_CUSTOMER_NUMBER"), + envTest.GetValue("NETCUP_API_KEY"), + envTest.GetValue("NETCUP_API_PASSWORD")) + require.NoError(t, err) + + for i := 1; i < 4; i++ { + i := i + t.Run("Test_"+strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + + ctx, err := client.CreateSessionContext(context.Background()) + require.NoError(t, err) + + err = client.Logout(ctx) + require.NoError(t, err) + }) + } +} diff --git a/providers/dns/netcup/internal/types.go b/providers/dns/netcup/internal/types.go new file mode 100644 index 00000000..55212f90 --- /dev/null +++ b/providers/dns/netcup/internal/types.go @@ -0,0 +1,105 @@ +package internal + +import ( + "encoding/json" + "fmt" +) + +// success response status. +const success = "success" + +// Request wrapper as specified in netcup wiki +// needed for every request to netcup API around *Msg. +// https://www.netcup-wiki.de/wiki/CCP_API#Anmerkungen_zu_JSON-Requests +type Request struct { + Action string `json:"action"` + Param any `json:"param"` +} + +// LoginRequest as specified in netcup WSDL. +// https://ccp.netcup.net/run/webservice/servers/endpoint.php#login +type LoginRequest struct { + CustomerNumber string `json:"customernumber"` + APIKey string `json:"apikey"` + APIPassword string `json:"apipassword"` + ClientRequestID string `json:"clientrequestid,omitempty"` +} + +// LogoutRequest as specified in netcup WSDL. +// https://ccp.netcup.net/run/webservice/servers/endpoint.php#logout +type LogoutRequest struct { + CustomerNumber string `json:"customernumber"` + APIKey string `json:"apikey"` + APISessionID string `json:"apisessionid"` + ClientRequestID string `json:"clientrequestid,omitempty"` +} + +// UpdateDNSRecordsRequest as specified in netcup WSDL. +// https://ccp.netcup.net/run/webservice/servers/endpoint.php#updateDnsRecords +type UpdateDNSRecordsRequest struct { + DomainName string `json:"domainname"` + CustomerNumber string `json:"customernumber"` + APIKey string `json:"apikey"` + APISessionID string `json:"apisessionid"` + ClientRequestID string `json:"clientrequestid,omitempty"` + DNSRecordSet DNSRecordSet `json:"dnsrecordset"` +} + +// DNSRecordSet as specified in netcup WSDL. +// needed in UpdateDNSRecordsRequest. +// https://ccp.netcup.net/run/webservice/servers/endpoint.php#Dnsrecordset +type DNSRecordSet struct { + DNSRecords []DNSRecord `json:"dnsrecords"` +} + +// InfoDNSRecordsRequest as specified in netcup WSDL. +// https://ccp.netcup.net/run/webservice/servers/endpoint.php#infoDnsRecords +type InfoDNSRecordsRequest struct { + DomainName string `json:"domainname"` + CustomerNumber string `json:"customernumber"` + APIKey string `json:"apikey"` + APISessionID string `json:"apisessionid"` + ClientRequestID string `json:"clientrequestid,omitempty"` +} + +// DNSRecord as specified in netcup WSDL. +// https://ccp.netcup.net/run/webservice/servers/endpoint.php#Dnsrecord +type DNSRecord struct { + ID int `json:"id,string,omitempty"` + Hostname string `json:"hostname"` + RecordType string `json:"type"` + Priority string `json:"priority,omitempty"` + Destination string `json:"destination"` + DeleteRecord bool `json:"deleterecord,omitempty"` + State string `json:"state,omitempty"` + TTL int `json:"ttl,omitempty"` +} + +// ResponseMsg as specified in netcup WSDL. +// https://ccp.netcup.net/run/webservice/servers/endpoint.php#Responsemessage +type ResponseMsg struct { + ServerRequestID string `json:"serverrequestid"` + ClientRequestID string `json:"clientrequestid,omitempty"` + Action string `json:"action"` + Status string `json:"status"` + StatusCode int `json:"statuscode"` + ShortMessage string `json:"shortmessage"` + LongMessage string `json:"longmessage"` + ResponseData json.RawMessage `json:"responsedata,omitempty"` +} + +func (r *ResponseMsg) Error() string { + return fmt.Sprintf("an error occurred during the action %s: [Status=%s, StatusCode=%d, ShortMessage=%s, LongMessage=%s]", + r.Action, r.Status, r.StatusCode, r.ShortMessage, r.LongMessage) +} + +// LoginResponse response to login action. +type LoginResponse struct { + APISessionID string `json:"apisessionid"` +} + +// InfoDNSRecordsResponse response to infoDnsRecords action. +type InfoDNSRecordsResponse struct { + APISessionID string `json:"apisessionid"` + DNSRecords []DNSRecord `json:"dnsrecords,omitempty"` +} diff --git a/providers/dns/netcup/netcup.go b/providers/dns/netcup/netcup.go index d3d16483..328c25a1 100644 --- a/providers/dns/netcup/netcup.go +++ b/providers/dns/netcup/netcup.go @@ -2,6 +2,7 @@ package netcup import ( + "context" "errors" "fmt" "net/http" @@ -96,16 +97,16 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { zone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("netcup: failed to find DNSZone, %w", err) + return fmt.Errorf("netcup: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } - sessionID, err := d.client.Login() + ctx, err := d.client.CreateSessionContext(context.Background()) if err != nil { return fmt.Errorf("netcup: %w", err) } defer func() { - err = d.client.Logout(sessionID) + err = d.client.Logout(ctx) if err != nil { log.Print("netcup: %v", err) } @@ -121,7 +122,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { zone = dns01.UnFqdn(zone) - records, err := d.client.GetDNSRecords(zone, sessionID) + records, err := d.client.GetDNSRecords(ctx, zone) if err != nil { // skip no existing records log.Infof("no existing records, error ignored: %v", err) @@ -129,7 +130,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { records = append(records, record) - err = d.client.UpdateDNSRecord(sessionID, zone, records) + err = d.client.UpdateDNSRecord(ctx, zone, records) if err != nil { return fmt.Errorf("netcup: failed to add TXT-Record: %w", err) } @@ -143,16 +144,16 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { zone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("netcup: failed to find DNSZone, %w", err) + return fmt.Errorf("netcup: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } - sessionID, err := d.client.Login() + ctx, err := d.client.CreateSessionContext(context.Background()) if err != nil { return fmt.Errorf("netcup: %w", err) } defer func() { - err = d.client.Logout(sessionID) + err = d.client.Logout(ctx) if err != nil { log.Print("netcup: %v", err) } @@ -162,7 +163,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { zone = dns01.UnFqdn(zone) - records, err := d.client.GetDNSRecords(zone, sessionID) + records, err := d.client.GetDNSRecords(ctx, zone) if err != nil { return fmt.Errorf("netcup: %w", err) } @@ -180,7 +181,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { records[idx].DeleteRecord = true - err = d.client.UpdateDNSRecord(sessionID, zone, []internal.DNSRecord{records[idx]}) + err = d.client.UpdateDNSRecord(ctx, zone, []internal.DNSRecord{records[idx]}) if err != nil { return fmt.Errorf("netcup: %w", err) } diff --git a/providers/dns/netlify/internal/client.go b/providers/dns/netlify/internal/client.go index 5955625f..06651bde 100644 --- a/providers/dns/netlify/internal/client.go +++ b/providers/dns/netlify/internal/client.go @@ -2,157 +2,161 @@ package internal import ( "bytes" + "context" "encoding/json" "fmt" "io" "net/http" "net/url" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" + "golang.org/x/oauth2" ) const defaultBaseURL = "https://api.netlify.com/api/v1" // Client Netlify API client. type Client struct { - HTTPClient *http.Client - BaseURL string - - token string + baseURL *url.URL + httpClient *http.Client } // NewClient creates a new Client. -func NewClient(token string) *Client { - return &Client{ - HTTPClient: http.DefaultClient, - BaseURL: defaultBaseURL, - token: token, +func NewClient(hc *http.Client) *Client { + baseURL, _ := url.Parse(defaultBaseURL) + + if hc == nil { + hc = &http.Client{Timeout: 5 * time.Second} } + + return &Client{baseURL: baseURL, httpClient: hc} } // GetRecords gets a DNS records. -func (c *Client) GetRecords(zoneID string) ([]DNSRecord, error) { - endpoint, err := c.createEndpoint("dns_zones", zoneID, "dns_records") - if err != nil { - return nil, fmt.Errorf("failed to parse endpoint: %w", err) - } +func (c *Client) GetRecords(ctx context.Context, zoneID string) ([]DNSRecord, error) { + endpoint := c.baseURL.JoinPath("dns_zones", zoneID, "dns_records") - req, err := http.NewRequest(http.MethodGet, endpoint.String(), nil) + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } - req.Header.Set("Accept", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.token)) - - resp, err := c.HTTPClient.Do(req) + resp, err := c.httpClient.Do(req) if err != nil { - return nil, fmt.Errorf("API call failed: %w", err) + return nil, errutils.NewHTTPDoError(req, err) } defer func() { _ = resp.Body.Close() }() - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + if resp.StatusCode != http.StatusOK { + return nil, errutils.NewUnexpectedResponseStatusCodeError(req, resp) } - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("invalid status code: %s: %s", resp.Status, string(body)) + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errutils.NewReadResponseError(req, resp.StatusCode, err) } var records []DNSRecord - err = json.Unmarshal(body, &records) + err = json.Unmarshal(raw, &records) if err != nil { - return nil, fmt.Errorf("failed to marshal response body: %w", err) + return nil, errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) } return records, nil } // CreateRecord creates a DNS records. -func (c *Client) CreateRecord(zoneID string, record DNSRecord) (*DNSRecord, error) { - endpoint, err := c.createEndpoint("dns_zones", zoneID, "dns_records") - if err != nil { - return nil, fmt.Errorf("failed to parse endpoint: %w", err) - } +func (c *Client) CreateRecord(ctx context.Context, zoneID string, record DNSRecord) (*DNSRecord, error) { + endpoint := c.baseURL.JoinPath("dns_zones", zoneID, "dns_records") - marshaledRecord, err := json.Marshal(record) - if err != nil { - return nil, fmt.Errorf("failed to marshal request body: %w", err) - } - - req, err := http.NewRequest(http.MethodPost, endpoint.String(), bytes.NewReader(marshaledRecord)) + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } - req.Header.Set("Accept", "application/json") - req.Header.Set("Content-Type", "application/json; charset=utf-8") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.token)) - - resp, err := c.HTTPClient.Do(req) + resp, err := c.httpClient.Do(req) if err != nil { - return nil, fmt.Errorf("API call failed: %w", err) + return nil, errutils.NewHTTPDoError(req, err) } defer func() { _ = resp.Body.Close() }() - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + if resp.StatusCode != http.StatusCreated { + return nil, errutils.NewUnexpectedResponseStatusCodeError(req, resp) } - if resp.StatusCode != http.StatusCreated { - return nil, fmt.Errorf("invalid status code: %s: %s", resp.Status, string(body)) + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errutils.NewReadResponseError(req, resp.StatusCode, err) } var recordResp DNSRecord - err = json.Unmarshal(body, &recordResp) + err = json.Unmarshal(raw, &recordResp) if err != nil { - return nil, fmt.Errorf("failed to marshal response body: %w", err) + return nil, errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) } return &recordResp, nil } // RemoveRecord removes a DNS records. -func (c *Client) RemoveRecord(zoneID, recordID string) error { - endpoint, err := c.createEndpoint("dns_zones", zoneID, "dns_records", recordID) +func (c *Client) RemoveRecord(ctx context.Context, zoneID, recordID string) error { + endpoint := c.baseURL.JoinPath("dns_zones", zoneID, "dns_records", recordID) + + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) if err != nil { - return fmt.Errorf("failed to parse endpoint: %w", err) + return err } - req, err := http.NewRequest(http.MethodDelete, endpoint.String(), nil) + resp, err := c.httpClient.Do(req) if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Accept", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.token)) - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return fmt.Errorf("API call failed: %w", err) + return errutils.NewHTTPDoError(req, err) } defer func() { _ = resp.Body.Close() }() - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read response body: %w", err) - } - if resp.StatusCode != http.StatusNoContent { - return fmt.Errorf("invalid status code: %s: %s", resp.Status, string(body)) + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) } return nil } -func (c *Client) createEndpoint(parts ...string) (*url.URL, error) { - base, err := url.Parse(c.BaseURL) - if err != nil { - return nil, fmt.Errorf("failed to parse base URL: %w", err) +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload interface{}) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } } - return base.JoinPath(parts...), nil + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json; charset=utf-8") + } + + return req, nil +} + +func OAuthStaticAccessToken(client *http.Client, accessToken string) *http.Client { + if client == nil { + client = &http.Client{Timeout: 5 * time.Second} + } + + client.Transport = &oauth2.Transport{ + Source: oauth2.StaticTokenSource(&oauth2.Token{AccessToken: accessToken}), + Base: client.Transport, + } + + return client } diff --git a/providers/dns/netlify/internal/client_test.go b/providers/dns/netlify/internal/client_test.go index a645bde1..e06a579b 100644 --- a/providers/dns/netlify/internal/client_test.go +++ b/providers/dns/netlify/internal/client_test.go @@ -1,10 +1,12 @@ package internal import ( + "context" "fmt" "io" "net/http" "net/http/httptest" + "net/url" "os" "testing" @@ -12,11 +14,22 @@ import ( "github.com/stretchr/testify/require" ) -func TestClient_GetRecords(t *testing.T) { +func setupTest(t *testing.T, token string) (*Client, *http.ServeMux) { + t.Helper() + mux := http.NewServeMux() server := httptest.NewServer(mux) t.Cleanup(server.Close) + client := NewClient(OAuthStaticAccessToken(server.Client(), token)) + client.baseURL, _ = url.Parse(server.URL) + + return client, mux +} + +func TestClient_GetRecords(t *testing.T) { + client, mux := setupTest(t, "tokenA") + mux.HandleFunc("/dns_zones/zoneID/dns_records", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodGet { http.Error(rw, "unsupported method", http.StatusMethodNotAllowed) @@ -45,10 +58,7 @@ func TestClient_GetRecords(t *testing.T) { } }) - client := NewClient("tokenA") - client.BaseURL = server.URL - - records, err := client.GetRecords("zoneID") + records, err := client.GetRecords(context.Background(), "zoneID") require.NoError(t, err) expected := []DNSRecord{ @@ -60,9 +70,7 @@ func TestClient_GetRecords(t *testing.T) { } func TestClient_CreateRecord(t *testing.T) { - mux := http.NewServeMux() - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + client, mux := setupTest(t, "tokenB") mux.HandleFunc("/dns_zones/zoneID/dns_records", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodPost { @@ -93,9 +101,6 @@ func TestClient_CreateRecord(t *testing.T) { } }) - client := NewClient("tokenB") - client.BaseURL = server.URL - record := DNSRecord{ Hostname: "_acme-challenge.example.com", TTL: 300, @@ -103,7 +108,7 @@ func TestClient_CreateRecord(t *testing.T) { Value: "txtxtxtxtxtxt", } - result, err := client.CreateRecord("zoneID", record) + result, err := client.CreateRecord(context.Background(), "zoneID", record) require.NoError(t, err) expected := &DNSRecord{ @@ -118,9 +123,7 @@ func TestClient_CreateRecord(t *testing.T) { } func TestClient_RemoveRecord(t *testing.T) { - mux := http.NewServeMux() - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + client, mux := setupTest(t, "tokenC") mux.HandleFunc("/dns_zones/zoneID/dns_records/recordID", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodDelete { @@ -137,9 +140,6 @@ func TestClient_RemoveRecord(t *testing.T) { rw.WriteHeader(http.StatusNoContent) }) - client := NewClient("tokenC") - client.BaseURL = server.URL - - err := client.RemoveRecord("zoneID", "recordID") + err := client.RemoveRecord(context.Background(), "zoneID", "recordID") require.NoError(t, err) } diff --git a/providers/dns/netlify/internal/model.go b/providers/dns/netlify/internal/types.go similarity index 100% rename from providers/dns/netlify/internal/model.go rename to providers/dns/netlify/internal/types.go diff --git a/providers/dns/netlify/netlify.go b/providers/dns/netlify/netlify.go index 4608cd3e..28e85f54 100644 --- a/providers/dns/netlify/netlify.go +++ b/providers/dns/netlify/netlify.go @@ -2,6 +2,7 @@ package netlify import ( + "context" "errors" "fmt" "net/http" @@ -80,11 +81,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("netlify: incomplete credentials, missing token") } - client := internal.NewClient(config.Token) - - if config.HTTPClient != nil { - client.HTTPClient = config.HTTPClient - } + client := internal.NewClient(internal.OAuthStaticAccessToken(config.HTTPClient, config.Token)) return &DNSProvider{ config: config, @@ -105,7 +102,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("netlify: failed to find zone: %w", err) + return fmt.Errorf("netlify: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } authZone = dns01.UnFqdn(authZone) @@ -117,7 +114,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { Value: info.Value, } - resp, err := d.client.CreateRecord(strings.ReplaceAll(authZone, ".", "_"), record) + resp, err := d.client.CreateRecord(context.Background(), strings.ReplaceAll(authZone, ".", "_"), record) if err != nil { return fmt.Errorf("netlify: failed to create TXT records: fqdn=%s, authZone=%s: %w", info.EffectiveFQDN, authZone, err) } @@ -135,7 +132,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("netlify: failed to find zone: %w", err) + return fmt.Errorf("netlify: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } authZone = dns01.UnFqdn(authZone) @@ -148,7 +145,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return fmt.Errorf("netlify: unknown record ID for '%s' '%s'", info.EffectiveFQDN, token) } - err = d.client.RemoveRecord(strings.ReplaceAll(authZone, ".", "_"), recordID) + err = d.client.RemoveRecord(context.Background(), strings.ReplaceAll(authZone, ".", "_"), recordID) if err != nil { return fmt.Errorf("netlify: failed to delete TXT records: fqdn=%s, authZone=%s, recordID=%s: %w", info.EffectiveFQDN, authZone, recordID, err) } diff --git a/providers/dns/nicmanager/internal/client.go b/providers/dns/nicmanager/internal/client.go index 5f165662..3134fc4f 100644 --- a/providers/dns/nicmanager/internal/client.go +++ b/providers/dns/nicmanager/internal/client.go @@ -2,6 +2,7 @@ package internal import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -10,6 +11,7 @@ import ( "strconv" "time" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" "github.com/pquerna/otp/totp" ) @@ -39,14 +41,14 @@ type Options struct { // Client a nicmanager DNS client. type Client struct { - HTTPClient *http.Client - baseURL *url.URL - username string password string otp string mode string + + baseURL *url.URL + HTTPClient *http.Client } // NewClient create a new Client. @@ -72,29 +74,16 @@ func NewClient(opts Options) *Client { return c } -func (c Client) GetZone(name string) (*Zone, error) { +func (c Client) GetZone(ctx context.Context, name string) (*Zone, error) { endpoint := c.baseURL.JoinPath(c.mode, name) - resp, err := c.do(http.MethodGet, endpoint, nil) + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, err } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode >= http.StatusBadRequest { - b, _ := io.ReadAll(resp.Body) - - msg := APIError{StatusCode: resp.StatusCode} - if err = json.Unmarshal(b, &msg); err != nil { - return nil, fmt.Errorf("failed to get zone info for %s", name) - } - - return nil, msg - } - var zone Zone - err = json.NewDecoder(resp.Body).Decode(&zone) + err = c.do(req, http.StatusOK, &zone) if err != nil { return nil, err } @@ -102,83 +91,109 @@ func (c Client) GetZone(name string) (*Zone, error) { return &zone, nil } -func (c Client) AddRecord(zone string, req RecordCreateUpdate) error { +func (c Client) AddRecord(ctx context.Context, zone string, payload RecordCreateUpdate) error { endpoint := c.baseURL.JoinPath(c.mode, zone, "records") - resp, err := c.do(http.MethodPost, endpoint, req) + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, payload) if err != nil { return err } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusAccepted { - b, _ := io.ReadAll(resp.Body) - - msg := APIError{StatusCode: resp.StatusCode} - if err = json.Unmarshal(b, &msg); err != nil { - return fmt.Errorf("records create should've returned %d but returned %d", http.StatusAccepted, resp.StatusCode) - } - - return msg + err = c.do(req, http.StatusAccepted, nil) + if err != nil { + return err } return nil } -func (c Client) DeleteRecord(zone string, record int) error { +func (c Client) DeleteRecord(ctx context.Context, zone string, record int) error { endpoint := c.baseURL.JoinPath(c.mode, zone, "records", strconv.Itoa(record)) - resp, err := c.do(http.MethodDelete, endpoint, nil) + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) if err != nil { return err } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusAccepted { - b, _ := io.ReadAll(resp.Body) - - msg := APIError{StatusCode: resp.StatusCode} - if err = json.Unmarshal(b, &msg); err != nil { - return fmt.Errorf("records delete should've returned %d but returned %d", http.StatusAccepted, resp.StatusCode) - } - - return msg + err = c.do(req, http.StatusAccepted, nil) + if err != nil { + return err } return nil } -func (c Client) do(method string, endpoint *url.URL, body interface{}) (*http.Response, error) { - var reqBody io.Reader - if body != nil { - jsonValue, err := json.Marshal(body) - if err != nil { - return nil, err - } - - reqBody = bytes.NewBuffer(jsonValue) - } - - r, err := http.NewRequest(method, endpoint.String(), reqBody) - if err != nil { - return nil, err - } - - r.Header.Set("Accept", "application/json") - r.Header.Set("Content-Type", "application/json") - - r.SetBasicAuth(c.username, c.password) +func (c Client) do(req *http.Request, expectedStatusCode int, result any) error { + req.SetBasicAuth(c.username, c.password) if c.otp != "" { tan, err := totp.GenerateCode(c.otp, time.Now()) if err != nil { - return nil, err + return err } - r.Header.Set(headerTOTPToken, tan) + req.Header.Set(headerTOTPToken, tan) } - return c.HTTPClient.Do(r) + resp, err := c.HTTPClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != expectedStatusCode { + return parseError(req, resp) + } + + if result == nil { + return nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return err +} + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil +} + +func parseError(req *http.Request, resp *http.Response) error { + raw, _ := io.ReadAll(resp.Body) + + errAPI := APIError{StatusCode: resp.StatusCode} + if err := json.Unmarshal(raw, &errAPI); err != nil { + return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw) + } + + return errAPI } diff --git a/providers/dns/nicmanager/internal/client_test.go b/providers/dns/nicmanager/internal/client_test.go index 3823020b..822ec0db 100644 --- a/providers/dns/nicmanager/internal/client_test.go +++ b/providers/dns/nicmanager/internal/client_test.go @@ -1,6 +1,7 @@ package internal import ( + "context" "fmt" "io" "net/http" @@ -17,7 +18,7 @@ import ( func TestClient_GetZone(t *testing.T) { client := setupTest(t, "/anycast/nicmanager-anycastdns4.net", testHandler(http.MethodGet, http.StatusOK, "zone.json")) - zone, err := client.GetZone("nicmanager-anycastdns4.net") + zone, err := client.GetZone(context.Background(), "nicmanager-anycastdns4.net") require.NoError(t, err) expected := &Zone{ @@ -40,7 +41,7 @@ func TestClient_GetZone(t *testing.T) { func TestClient_GetZone_error(t *testing.T) { client := setupTest(t, "/anycast/foo", testHandler(http.MethodGet, http.StatusNotFound, "error.json")) - _, err := client.GetZone("foo") + _, err := client.GetZone(context.Background(), "foo") require.Error(t, err) } @@ -54,7 +55,7 @@ func TestClient_AddRecord(t *testing.T) { TTL: 3600, } - err := client.AddRecord("zonedomain.tld", record) + err := client.AddRecord(context.Background(), "zonedomain.tld", record) require.NoError(t, err) } @@ -68,21 +69,21 @@ func TestClient_AddRecord_error(t *testing.T) { TTL: 3600, } - err := client.AddRecord("zonedomain.tld", record) + err := client.AddRecord(context.Background(), "zonedomain.tld", record) require.Error(t, err) } func TestClient_DeleteRecord(t *testing.T) { client := setupTest(t, "/anycast/zonedomain.tld/records/6", testHandler(http.MethodDelete, http.StatusAccepted, "error.json")) - err := client.DeleteRecord("zonedomain.tld", 6) + err := client.DeleteRecord(context.Background(), "zonedomain.tld", 6) require.NoError(t, err) } func TestClient_DeleteRecord_error(t *testing.T) { client := setupTest(t, "/anycast/zonedomain.tld/records/6", testHandler(http.MethodDelete, http.StatusNoContent, "")) - err := client.DeleteRecord("zonedomain.tld", 7) + err := client.DeleteRecord(context.Background(), "zonedomain.tld", 7) require.Error(t, err) } diff --git a/providers/dns/nicmanager/nicmanager.go b/providers/dns/nicmanager/nicmanager.go index 48561c76..e37efba9 100644 --- a/providers/dns/nicmanager/nicmanager.go +++ b/providers/dns/nicmanager/nicmanager.go @@ -2,6 +2,7 @@ package nicmanager import ( + "context" "errors" "fmt" "net/http" @@ -139,10 +140,12 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { rootDomain, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("nicmanager: could not determine zone for domain %q: %w", info.EffectiveFQDN, err) + return fmt.Errorf("nicmanager: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } - zone, err := d.client.GetZone(dns01.UnFqdn(rootDomain)) + ctx := context.Background() + + zone, err := d.client.GetZone(ctx, dns01.UnFqdn(rootDomain)) if err != nil { return fmt.Errorf("nicmanager: failed to get zone %q: %w", rootDomain, err) } @@ -156,7 +159,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { Value: info.Value, } - err = d.client.AddRecord(zone.Name, record) + err = d.client.AddRecord(ctx, zone.Name, record) if err != nil { return fmt.Errorf("nicmanager: failed to create record [zone: %q, fqdn: %q]: %w", zone.Name, info.EffectiveFQDN, err) } @@ -170,10 +173,12 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { rootDomain, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("nicmanager: could not determine zone for domain %q: %w", info.EffectiveFQDN, err) + return fmt.Errorf("nicmanager: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } - zone, err := d.client.GetZone(dns01.UnFqdn(rootDomain)) + ctx := context.Background() + + zone, err := d.client.GetZone(ctx, dns01.UnFqdn(rootDomain)) if err != nil { return fmt.Errorf("nicmanager: failed to get zone %q: %w", rootDomain, err) } @@ -190,7 +195,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { } if existingRecordFound { - err = d.client.DeleteRecord(zone.Name, existingRecord.ID) + err = d.client.DeleteRecord(ctx, zone.Name, existingRecord.ID) if err != nil { return fmt.Errorf("nicmanager: failed to delete record [zone: %q, domain: %q]: %w", zone.Name, name, err) } diff --git a/providers/dns/nifcloud/internal/client.go b/providers/dns/nifcloud/internal/client.go index 24ad195a..3ad95488 100644 --- a/providers/dns/nifcloud/internal/client.go +++ b/providers/dns/nifcloud/internal/client.go @@ -2,14 +2,19 @@ package internal import ( "bytes" + "context" "crypto/hmac" "crypto/sha1" "encoding/base64" "encoding/xml" "errors" "fmt" + "io" "net/http" + "net/url" "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) const ( @@ -19,72 +24,13 @@ const ( XMLNs = "https://route53.amazonaws.com/doc/2012-12-12/" ) -// ChangeResourceRecordSetsRequest is a complex type that contains change information for the resource record set. -type ChangeResourceRecordSetsRequest struct { - XMLNs string `xml:"xmlns,attr"` - ChangeBatch ChangeBatch `xml:"ChangeBatch"` -} +// Client the API client for NIFCLOUD DNS. +type Client struct { + accessKey string + secretKey string -// ChangeResourceRecordSetsResponse is a complex type containing the response for the request. -type ChangeResourceRecordSetsResponse struct { - ChangeInfo ChangeInfo `xml:"ChangeInfo"` -} - -// GetChangeResponse is a complex type that contains the ChangeInfo element. -type GetChangeResponse struct { - ChangeInfo ChangeInfo `xml:"ChangeInfo"` -} - -// ErrorResponse is the information for any errors. -type ErrorResponse struct { - Error struct { - Type string `xml:"Type"` - Message string `xml:"Message"` - Code string `xml:"Code"` - } `xml:"Error"` - RequestID string `xml:"RequestId"` -} - -// ChangeBatch is the information for a change request. -type ChangeBatch struct { - Changes Changes `xml:"Changes"` - Comment string `xml:"Comment"` -} - -// Changes is array of Change. -type Changes struct { - Change []Change `xml:"Change"` -} - -// Change is the information for each resource record set that you want to change. -type Change struct { - Action string `xml:"Action"` - ResourceRecordSet ResourceRecordSet `xml:"ResourceRecordSet"` -} - -// ResourceRecordSet is the information about the resource record set to create or delete. -type ResourceRecordSet struct { - Name string `xml:"Name"` - Type string `xml:"Type"` - TTL int `xml:"TTL"` - ResourceRecords ResourceRecords `xml:"ResourceRecords"` -} - -// ResourceRecords is array of ResourceRecord. -type ResourceRecords struct { - ResourceRecord []ResourceRecord `xml:"ResourceRecord"` -} - -// ResourceRecord is the information specific to the resource record. -type ResourceRecord struct { - Value string `xml:"Value"` -} - -// ChangeInfo is A complex type that describes change information about changes made to your hosted zone. -type ChangeInfo struct { - ID string `xml:"Id"` - Status string `xml:"Status"` - SubmittedAt string `xml:"SubmittedAt"` + BaseURL *url.URL + HTTPClient *http.Client } // NewClient Creates a new client of NIFCLOUD DNS. @@ -93,117 +39,86 @@ func NewClient(accessKey, secretKey string) (*Client, error) { return nil, errors.New("credentials missing") } + baseURL, _ := url.Parse(defaultBaseURL) + return &Client{ accessKey: accessKey, secretKey: secretKey, - BaseURL: defaultBaseURL, - HTTPClient: &http.Client{}, + BaseURL: baseURL, + HTTPClient: &http.Client{Timeout: 10 * time.Second}, }, nil } -// Client client of NIFCLOUD DNS. -type Client struct { - accessKey string - secretKey string - BaseURL string - HTTPClient *http.Client -} - // ChangeResourceRecordSets Call ChangeResourceRecordSets API and return response. -func (c *Client) ChangeResourceRecordSets(hostedZoneID string, input ChangeResourceRecordSetsRequest) (*ChangeResourceRecordSetsResponse, error) { - requestURL := fmt.Sprintf("%s/%s/hostedzone/%s/rrset", c.BaseURL, apiVersion, hostedZoneID) +func (c *Client) ChangeResourceRecordSets(ctx context.Context, hostedZoneID string, input ChangeResourceRecordSetsRequest) (*ChangeResourceRecordSetsResponse, error) { + endpoint := c.BaseURL.JoinPath(apiVersion, "hostedzone", hostedZoneID, "rrset") - body := &bytes.Buffer{} - body.WriteString(xml.Header) - err := xml.NewEncoder(body).Encode(input) + req, err := newXMLRequest(ctx, http.MethodPost, endpoint, input) if err != nil { return nil, err } - req, err := http.NewRequest(http.MethodPost, requestURL, body) - if err != nil { - return nil, err - } - - req.Header.Set("Content-Type", "text/xml; charset=utf-8") - - err = c.sign(req) - if err != nil { - return nil, fmt.Errorf("an error occurred during the creation of the signature: %w", err) - } - - res, err := c.HTTPClient.Do(req) - if err != nil { - return nil, err - } - if res.Body == nil { - return nil, errors.New("the response body is nil") - } - - defer res.Body.Close() - - if res.StatusCode != http.StatusOK { - errResp := &ErrorResponse{} - err = xml.NewDecoder(res.Body).Decode(errResp) - if err != nil { - return nil, fmt.Errorf("an error occurred while unmarshaling the error body to XML: %w", err) - } - - return nil, fmt.Errorf("an error occurred: %s", errResp.Error.Message) - } - output := &ChangeResourceRecordSetsResponse{} - err = xml.NewDecoder(res.Body).Decode(output) - if err != nil { - return nil, fmt.Errorf("an error occurred while unmarshaling the response body to XML: %w", err) - } - - return output, err -} - -// GetChange Call GetChange API and return response. -func (c *Client) GetChange(statusID string) (*GetChangeResponse, error) { - requestURL := fmt.Sprintf("%s/%s/change/%s", c.BaseURL, apiVersion, statusID) - - req, err := http.NewRequest(http.MethodGet, requestURL, nil) + err = c.do(req, output) if err != nil { return nil, err } - err = c.sign(req) - if err != nil { - return nil, fmt.Errorf("an error occurred during the creation of the signature: %w", err) - } - - res, err := c.HTTPClient.Do(req) - if err != nil { - return nil, err - } - if res.Body == nil { - return nil, errors.New("the response body is nil") - } - - defer res.Body.Close() - - if res.StatusCode != http.StatusOK { - errResp := &ErrorResponse{} - err = xml.NewDecoder(res.Body).Decode(errResp) - if err != nil { - return nil, fmt.Errorf("an error occurred while unmarshaling the error body to XML: %w", err) - } - - return nil, fmt.Errorf("an error occurred: %s", errResp.Error.Message) - } - - output := &GetChangeResponse{} - err = xml.NewDecoder(res.Body).Decode(output) - if err != nil { - return nil, fmt.Errorf("an error occurred while unmarshaling the response body to XML: %w", err) - } - return output, nil } +// GetChange Call GetChange API and return response. +func (c *Client) GetChange(ctx context.Context, statusID string) (*GetChangeResponse, error) { + endpoint := c.BaseURL.JoinPath(apiVersion, "change", statusID) + + req, err := newXMLRequest(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, err + } + + output := &GetChangeResponse{} + err = c.do(req, output) + if err != nil { + return nil, err + } + + return output, nil +} + +func (c *Client) do(req *http.Request, result any) error { + err := c.sign(req) + if err != nil { + return fmt.Errorf("an error occurred during the creation of the signature: %w", err) + } + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return parseError(req, resp) + } + + if result == nil { + return nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = xml.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return nil +} + func (c *Client) sign(req *http.Request) error { if req.Header.Get("Date") == "" { location, err := time.LoadLocation("GMT") @@ -232,3 +147,39 @@ func (c *Client) sign(req *http.Request) error { return nil } + +func newXMLRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + body := new(bytes.Buffer) + body.WriteString(xml.Header) + err := xml.NewEncoder(body).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request XML body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + if payload != nil { + req.Header.Set("Content-Type", "text/xml; charset=utf-8") + } + + return req, nil +} + +func parseError(req *http.Request, resp *http.Response) error { + raw, _ := io.ReadAll(resp.Body) + + errResp := &ErrorResponse{} + err := xml.Unmarshal(raw, errResp) + if err != nil { + return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw) + } + + return errResp.Error +} diff --git a/providers/dns/nifcloud/internal/client_test.go b/providers/dns/nifcloud/internal/client_test.go index cc0ae5b0..38a220b6 100644 --- a/providers/dns/nifcloud/internal/client_test.go +++ b/providers/dns/nifcloud/internal/client_test.go @@ -1,9 +1,11 @@ package internal import ( + "context" "fmt" "net/http" "net/http/httptest" + "net/url" "testing" "github.com/stretchr/testify/assert" @@ -25,7 +27,7 @@ func setupTest(t *testing.T, responseBody string, statusCode int) *Client { require.NoError(t, err) client.HTTPClient = server.Client() - client.BaseURL = server.URL + client.BaseURL, _ = url.Parse(server.URL) return client } @@ -43,7 +45,7 @@ func TestChangeResourceRecordSets(t *testing.T) { client := setupTest(t, responseBody, http.StatusOK) - res, err := client.ChangeResourceRecordSets("example.com", ChangeResourceRecordSetsRequest{}) + res, err := client.ChangeResourceRecordSets(context.Background(), "example.com", ChangeResourceRecordSetsRequest{}) require.NoError(t, err) assert.Equal(t, "xxxxx", res.ChangeInfo.ID) @@ -70,19 +72,19 @@ func TestChangeResourceRecordSetsErrors(t *testing.T) { `, statusCode: http.StatusUnauthorized, - expected: "an error occurred: The request signature we calculated does not match the signature you provided.", + expected: "Sender(AuthFailed): The request signature we calculated does not match the signature you provided.", }, { desc: "response body error", responseBody: "foo", statusCode: http.StatusOK, - expected: "an error occurred while unmarshaling the response body to XML: EOF", + expected: "unable to unmarshal response: [status code: 200] body: foo error: EOF", }, { desc: "error message error", responseBody: "foo", statusCode: http.StatusInternalServerError, - expected: "an error occurred while unmarshaling the error body to XML: EOF", + expected: "unexpected status code: [status code: 500] body: foo", }, } @@ -91,7 +93,7 @@ func TestChangeResourceRecordSetsErrors(t *testing.T) { t.Run(test.desc, func(t *testing.T) { client := setupTest(t, test.responseBody, test.statusCode) - res, err := client.ChangeResourceRecordSets("example.com", ChangeResourceRecordSetsRequest{}) + res, err := client.ChangeResourceRecordSets(context.Background(), "example.com", ChangeResourceRecordSetsRequest{}) assert.Nil(t, res) assert.EqualError(t, err, test.expected) }) @@ -111,7 +113,7 @@ func TestGetChange(t *testing.T) { client := setupTest(t, responseBody, http.StatusOK) - res, err := client.GetChange("12345") + res, err := client.GetChange(context.Background(), "12345") require.NoError(t, err) assert.Equal(t, "xxxxx", res.ChangeInfo.ID) @@ -138,19 +140,19 @@ func TestGetChangeErrors(t *testing.T) { `, statusCode: http.StatusUnauthorized, - expected: "an error occurred: The request signature we calculated does not match the signature you provided.", + expected: "Sender(AuthFailed): The request signature we calculated does not match the signature you provided.", }, { desc: "response body error", responseBody: "foo", statusCode: http.StatusOK, - expected: "an error occurred while unmarshaling the response body to XML: EOF", + expected: "unable to unmarshal response: [status code: 200] body: foo error: EOF", }, { desc: "error message error", responseBody: "foo", statusCode: http.StatusInternalServerError, - expected: "an error occurred while unmarshaling the error body to XML: EOF", + expected: "unexpected status code: [status code: 500] body: foo", }, } @@ -159,7 +161,7 @@ func TestGetChangeErrors(t *testing.T) { t.Run(test.desc, func(t *testing.T) { client := setupTest(t, test.responseBody, test.statusCode) - res, err := client.GetChange("12345") + res, err := client.GetChange(context.Background(), "12345") assert.Nil(t, res) assert.EqualError(t, err, test.expected) }) diff --git a/providers/dns/nifcloud/internal/types.go b/providers/dns/nifcloud/internal/types.go new file mode 100644 index 00000000..2df9f1e5 --- /dev/null +++ b/providers/dns/nifcloud/internal/types.go @@ -0,0 +1,77 @@ +package internal + +import "fmt" + +// ChangeResourceRecordSetsRequest is a complex type that contains change information for the resource record set. +type ChangeResourceRecordSetsRequest struct { + XMLNs string `xml:"xmlns,attr"` + ChangeBatch ChangeBatch `xml:"ChangeBatch"` +} + +// ChangeResourceRecordSetsResponse is a complex type containing the response for the request. +type ChangeResourceRecordSetsResponse struct { + ChangeInfo ChangeInfo `xml:"ChangeInfo"` +} + +// GetChangeResponse is a complex type that contains the ChangeInfo element. +type GetChangeResponse struct { + ChangeInfo ChangeInfo `xml:"ChangeInfo"` +} + +type Error struct { + Type string `xml:"Type"` + Message string `xml:"Message"` + Code string `xml:"Code"` +} + +func (e Error) Error() string { + return fmt.Sprintf("%s(%s): %s", e.Type, e.Code, e.Message) +} + +// ErrorResponse is the information for any errors. +type ErrorResponse struct { + Error Error `xml:"Error"` + RequestID string `xml:"RequestId"` +} + +// ChangeBatch is the information for a change request. +type ChangeBatch struct { + Changes Changes `xml:"Changes"` + Comment string `xml:"Comment"` +} + +// Changes is array of Change. +type Changes struct { + Change []Change `xml:"Change"` +} + +// Change is the information for each resource record set that you want to change. +type Change struct { + Action string `xml:"Action"` + ResourceRecordSet ResourceRecordSet `xml:"ResourceRecordSet"` +} + +// ResourceRecordSet is the information about the resource record set to create or delete. +type ResourceRecordSet struct { + Name string `xml:"Name"` + Type string `xml:"Type"` + TTL int `xml:"TTL"` + ResourceRecords ResourceRecords `xml:"ResourceRecords"` +} + +// ResourceRecords is array of ResourceRecord. +type ResourceRecords struct { + ResourceRecord []ResourceRecord `xml:"ResourceRecord"` +} + +// ResourceRecord is the information specific to the resource record. +type ResourceRecord struct { + Value string `xml:"Value"` +} + +// ChangeInfo is A complex type that describes change information about changes made to your hosted zone. +type ChangeInfo struct { + ID string `xml:"Id"` + Status string `xml:"Status"` + SubmittedAt string `xml:"SubmittedAt"` +} diff --git a/providers/dns/nifcloud/nifcloud.go b/providers/dns/nifcloud/nifcloud.go index 8e0ff334..5078175a 100644 --- a/providers/dns/nifcloud/nifcloud.go +++ b/providers/dns/nifcloud/nifcloud.go @@ -2,9 +2,11 @@ package nifcloud import ( + "context" "errors" "fmt" "net/http" + "net/url" "time" "github.com/go-acme/lego/v4/challenge/dns01" @@ -88,8 +90,13 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { client.HTTPClient = config.HTTPClient } - if len(config.BaseURL) > 0 { - client.BaseURL = config.BaseURL + if config.BaseURL != "" { + baseURL, err := url.Parse(config.BaseURL) + if err != nil { + return nil, fmt.Errorf("nifcloud: %w", err) + } + + client.BaseURL = baseURL } return &DNSProvider{client: client, config: config}, nil @@ -154,10 +161,12 @@ func (d *DNSProvider) changeRecord(action, fqdn, value string, ttl int) error { authZone, err := dns01.FindZoneByFqdn(fqdn) if err != nil { - return fmt.Errorf("failed to find zone: %w", err) + return fmt.Errorf("could not find zone for FQDN %q: %w", fqdn, err) } - resp, err := d.client.ChangeResourceRecordSets(dns01.UnFqdn(authZone), reqParams) + ctx := context.Background() + + resp, err := d.client.ChangeResourceRecordSets(ctx, dns01.UnFqdn(authZone), reqParams) if err != nil { return fmt.Errorf("failed to change record set: %w", err) } @@ -165,7 +174,7 @@ func (d *DNSProvider) changeRecord(action, fqdn, value string, ttl int) error { statusID := resp.ChangeInfo.ID return wait.For("nifcloud", 120*time.Second, 4*time.Second, func() (bool, error) { - resp, err := d.client.GetChange(statusID) + resp, err := d.client.GetChange(ctx, statusID) if err != nil { return false, fmt.Errorf("failed to query change status: %w", err) } diff --git a/providers/dns/njalla/internal/client.go b/providers/dns/njalla/internal/client.go index 8b1a1a8b..f7e0023a 100644 --- a/providers/dns/njalla/internal/client.go +++ b/providers/dns/njalla/internal/client.go @@ -2,53 +2,60 @@ package internal import ( "bytes" + "context" "encoding/json" "fmt" + "io" "net/http" "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) const apiEndpoint = "https://njal.la/api/1/" +const authorizationHeader = "Authorization" + // Client is a Njalla API client. type Client struct { - HTTPClient *http.Client + token string + apiEndpoint string - token string + HTTPClient *http.Client } // NewClient creates a new Client. func NewClient(token string) *Client { return &Client{ - HTTPClient: &http.Client{Timeout: 5 * time.Second}, - apiEndpoint: apiEndpoint, token: token, + apiEndpoint: apiEndpoint, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, } } // AddRecord adds a record. -func (c *Client) AddRecord(record Record) (*Record, error) { +func (c *Client) AddRecord(ctx context.Context, record Record) (*Record, error) { data := APIRequest{ Method: "add-record", Params: record, } - result, err := c.do(data) + req, err := newJSONRequest(ctx, http.MethodPost, c.apiEndpoint, data) if err != nil { return nil, err } - var rcd Record - err = json.Unmarshal(result, &rcd) + var result APIResponse[*Record] + err = c.do(req, &result) if err != nil { - return nil, fmt.Errorf("failed to unmarshal response result: %w", err) + return nil, err } - return &rcd, nil + return result.Result, nil } // RemoveRecord removes a record. -func (c *Client) RemoveRecord(id string, domain string) error { +func (c *Client) RemoveRecord(ctx context.Context, id string, domain string) error { data := APIRequest{ Method: "remove-record", Params: Record{ @@ -57,7 +64,12 @@ func (c *Client) RemoveRecord(id string, domain string) error { }, } - _, err := c.do(data) + req, err := newJSONRequest(ctx, http.MethodPost, c.apiEndpoint, data) + if err != nil { + return err + } + + err = c.do(req, &APIResponse[json.RawMessage]{}) if err != nil { return err } @@ -66,7 +78,7 @@ func (c *Client) RemoveRecord(id string, domain string) error { } // ListRecords list the records for one domain. -func (c *Client) ListRecords(domain string) ([]Record, error) { +func (c *Client) ListRecords(ctx context.Context, domain string) ([]Record, error) { data := APIRequest{ Method: "list-records", Params: Record{ @@ -74,64 +86,67 @@ func (c *Client) ListRecords(domain string) ([]Record, error) { }, } - result, err := c.do(data) + req, err := newJSONRequest(ctx, http.MethodPost, c.apiEndpoint, data) if err != nil { return nil, err } - var rcds Records - err = json.Unmarshal(result, &rcds) + var result APIResponse[Records] + err = c.do(req, &result) if err != nil { - return nil, fmt.Errorf("failed to unmarshal response result: %w", err) + return nil, err } - return rcds.Records, nil + return result.Result.Records, nil } -func (c *Client) do(data APIRequest) (json.RawMessage, error) { - req, err := c.createRequest(data) - if err != nil { - return nil, err - } +func (c *Client) do(req *http.Request, result Response) error { + req.Header.Set(authorizationHeader, "Njalla "+c.token) resp, err := c.HTTPClient.Do(req) if err != nil { - return nil, fmt.Errorf("failed to perform request: %w", err) + return errutils.NewHTTPDoError(req, err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected error: %d", resp.StatusCode) + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) } - apiResponse := APIResponse{} - err = json.NewDecoder(resp.Body).Decode(&apiResponse) + raw, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) + return errutils.NewReadResponseError(req, resp.StatusCode, err) } - if apiResponse.Error != nil { - return nil, apiResponse.Error + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) } - return apiResponse.Result, nil + return result.GetError() } -func (c *Client) createRequest(data APIRequest) (*http.Request, error) { - reqBody, err := json.Marshal(data) - if err != nil { - return nil, fmt.Errorf("failed to marshall request body: %w", err) +func newJSONRequest(ctx context.Context, method string, endpoint string, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } } - req, err := http.NewRequest(http.MethodPost, c.apiEndpoint, bytes.NewReader(reqBody)) + req, err := http.NewRequestWithContext(ctx, method, endpoint, buf) if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) + return nil, fmt.Errorf("unable to create request: %w", err) } req.Header.Set("Accept", "application/json") - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Njalla "+c.token) + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } return req, nil } diff --git a/providers/dns/njalla/internal/client_test.go b/providers/dns/njalla/internal/client_test.go index 934cbe76..3f173db6 100644 --- a/providers/dns/njalla/internal/client_test.go +++ b/providers/dns/njalla/internal/client_test.go @@ -1,6 +1,7 @@ package internal import ( + "context" "encoding/json" "fmt" "net/http" @@ -11,7 +12,7 @@ import ( "github.com/stretchr/testify/require" ) -func setup(t *testing.T, handler func(http.ResponseWriter, *http.Request)) *Client { +func setupTest(t *testing.T, handler func(http.ResponseWriter, *http.Request)) *Client { t.Helper() mux := http.NewServeMux() @@ -24,7 +25,7 @@ func setup(t *testing.T, handler func(http.ResponseWriter, *http.Request)) *Clie return } - token := req.Header.Get("Authorization") + token := req.Header.Get(authorizationHeader) if token != "Njalla secret" { _, _ = rw.Write([]byte(`{"jsonrpc":"2.0", "Error": {"code": 403, "message": "Invalid token."}}`)) return @@ -44,7 +45,7 @@ func setup(t *testing.T, handler func(http.ResponseWriter, *http.Request)) *Clie } func TestClient_AddRecord(t *testing.T) { - client := setup(t, func(rw http.ResponseWriter, req *http.Request) { + client := setupTest(t, func(rw http.ResponseWriter, req *http.Request) { apiReq := struct { Method string `json:"method"` Params Record `json:"params"` @@ -79,7 +80,7 @@ func TestClient_AddRecord(t *testing.T) { Type: "TXT", } - result, err := client.AddRecord(record) + result, err := client.AddRecord(context.Background(), record) require.NoError(t, err) expected := &Record{ @@ -94,7 +95,7 @@ func TestClient_AddRecord(t *testing.T) { } func TestClient_AddRecord_error(t *testing.T) { - client := setup(t, nil) + client := setupTest(t, nil) client.token = "invalid" record := Record{ @@ -105,14 +106,14 @@ func TestClient_AddRecord_error(t *testing.T) { Type: "TXT", } - result, err := client.AddRecord(record) + result, err := client.AddRecord(context.Background(), record) require.Error(t, err) assert.Nil(t, result) } func TestClient_ListRecords(t *testing.T) { - client := setup(t, func(rw http.ResponseWriter, req *http.Request) { + client := setupTest(t, func(rw http.ResponseWriter, req *http.Request) { apiReq := struct { Method string `json:"method"` Params Record `json:"params"` @@ -156,7 +157,7 @@ func TestClient_ListRecords(t *testing.T) { } }) - records, err := client.ListRecords("example.com") + records, err := client.ListRecords(context.Background(), "example.com") require.NoError(t, err) expected := []Record{ @@ -182,17 +183,17 @@ func TestClient_ListRecords(t *testing.T) { } func TestClient_ListRecords_error(t *testing.T) { - client := setup(t, nil) + client := setupTest(t, nil) client.token = "invalid" - records, err := client.ListRecords("example.com") + records, err := client.ListRecords(context.Background(), "example.com") require.Error(t, err) assert.Empty(t, records) } func TestClient_RemoveRecord(t *testing.T) { - client := setup(t, func(rw http.ResponseWriter, req *http.Request) { + client := setupTest(t, func(rw http.ResponseWriter, req *http.Request) { apiReq := struct { Method string `json:"method"` Params Record `json:"params"` @@ -217,14 +218,14 @@ func TestClient_RemoveRecord(t *testing.T) { _, _ = rw.Write([]byte(`{"jsonrpc":"2.0"}`)) }) - err := client.RemoveRecord("123", "example.com") + err := client.RemoveRecord(context.Background(), "123", "example.com") require.NoError(t, err) } func TestClient_RemoveRecord_error(t *testing.T) { - client := setup(t, nil) + client := setupTest(t, nil) client.token = "invalid" - err := client.RemoveRecord("123", "example.com") + err := client.RemoveRecord(context.Background(), "123", "example.com") require.Error(t, err) } diff --git a/providers/dns/njalla/internal/types.go b/providers/dns/njalla/internal/types.go index 74efe8d4..d6b8167d 100644 --- a/providers/dns/njalla/internal/types.go +++ b/providers/dns/njalla/internal/types.go @@ -1,22 +1,33 @@ package internal import ( - "encoding/json" "fmt" ) // APIRequest represents an API request body. type APIRequest struct { - Method string `json:"method"` - Params interface{} `json:"params"` + Method string `json:"method"` + Params any `json:"params"` +} + +type Response interface { + GetError() error } // APIResponse represents an API response body. -type APIResponse struct { - ID string `json:"id"` - RPC string `json:"jsonrpc"` - Error *APIError `json:"error,omitempty"` - Result json.RawMessage `json:"result,omitempty"` +type APIResponse[T any] struct { + ID string `json:"id"` + RPC string `json:"jsonrpc"` + Error *APIError `json:"error,omitempty"` + Result T `json:"result,omitempty"` +} + +func (a APIResponse[T]) GetError() error { + if a.Error == (*APIError)(nil) { + return nil + } + + return a.Error } // APIError is an API error. diff --git a/providers/dns/njalla/njalla.go b/providers/dns/njalla/njalla.go index c090fc5d..fe23e8d6 100644 --- a/providers/dns/njalla/njalla.go +++ b/providers/dns/njalla/njalla.go @@ -2,6 +2,7 @@ package njalla import ( + "context" "errors" "fmt" "net/http" @@ -116,7 +117,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { Type: "TXT", } - resp, err := d.client.AddRecord(record) + resp, err := d.client.AddRecord(context.Background(), record) if err != nil { return fmt.Errorf("njalla: failed to add record: %w", err) } @@ -145,7 +146,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return fmt.Errorf("njalla: unknown record ID for '%s' '%s'", info.EffectiveFQDN, token) } - err = d.client.RemoveRecord(recordID, dns01.UnFqdn(rootDomain)) + err = d.client.RemoveRecord(context.Background(), recordID, dns01.UnFqdn(rootDomain)) if err != nil { return fmt.Errorf("njalla: failed to delete TXT records: fqdn=%s, recordID=%s: %w", info.EffectiveFQDN, recordID, err) } diff --git a/providers/dns/nodion/nodion.go b/providers/dns/nodion/nodion.go index 8da79ea2..6b2a0be2 100644 --- a/providers/dns/nodion/nodion.go +++ b/providers/dns/nodion/nodion.go @@ -109,7 +109,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("nodion: could not find zone for domain %q and fqdn %q : %w", domain, info.EffectiveFQDN, err) + return fmt.Errorf("nodion: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) @@ -160,7 +160,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("nodion: could not find zone for domain %q and fqdn %q : %w", domain, info.EffectiveFQDN, err) + return fmt.Errorf("nodion: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } d.zoneIDsMu.Lock() diff --git a/providers/dns/ns1/ns1.go b/providers/dns/ns1/ns1.go index 811e9d85..906ec8d6 100644 --- a/providers/dns/ns1/ns1.go +++ b/providers/dns/ns1/ns1.go @@ -150,11 +150,13 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { } func (d *DNSProvider) getHostedZone(fqdn string) (*dns.Zone, error) { - authZone, err := getAuthZone(fqdn) + authZone, err := dns01.FindZoneByFqdn(fqdn) if err != nil { - return nil, fmt.Errorf("failed to extract auth zone from fqdn %q: %w", fqdn, err) + return nil, fmt.Errorf("could not find zone for FQDN %q: %w", fqdn, err) } + authZone = dns01.UnFqdn(authZone) + zone, _, err := d.client.Zones.Get(authZone) if err != nil { return nil, fmt.Errorf("failed to get zone [authZone: %q, fqdn: %q]: %w", authZone, fqdn, err) @@ -162,12 +164,3 @@ func (d *DNSProvider) getHostedZone(fqdn string) (*dns.Zone, error) { return zone, nil } - -func getAuthZone(fqdn string) (string, error) { - authZone, err := dns01.FindZoneByFqdn(fqdn) - if err != nil { - return "", err - } - - return dns01.UnFqdn(authZone), nil -} diff --git a/providers/dns/ns1/ns1_test.go b/providers/dns/ns1/ns1_test.go index ea4eaa64..6df6b4af 100644 --- a/providers/dns/ns1/ns1_test.go +++ b/providers/dns/ns1/ns1_test.go @@ -5,7 +5,6 @@ import ( "time" "github.com/go-acme/lego/v4/platform/tester" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -91,57 +90,6 @@ func TestNewDNSProviderConfig(t *testing.T) { } } -func Test_getAuthZone(t *testing.T) { - type expected struct { - AuthZone string - Error string - } - - testCases := []struct { - desc string - fqdn string - expected expected - }{ - { - desc: "valid fqdn", - fqdn: "_acme-challenge.myhost.sub.example.com.", - expected: expected{ - AuthZone: "example.com", - }, - }, - { - desc: "invalid fqdn", - fqdn: "_acme-challenge.myhost.sub.example.com", - expected: expected{ - Error: "could not find the start of authority for _acme-challenge.myhost.sub.example.com: dns: domain must be fully qualified", - }, - }, - { - desc: "invalid authority", - fqdn: "_acme-challenge.myhost.sub.domain.tld.", - expected: expected{ - Error: "could not find the start of authority for _acme-challenge.myhost.sub.domain.tld.: NXDOMAIN", - }, - }, - } - - for _, test := range testCases { - test := test - t.Run(test.desc, func(t *testing.T) { - t.Parallel() - - authZone, err := getAuthZone(test.fqdn) - - if len(test.expected.Error) > 0 { - assert.EqualError(t, err, test.expected.Error) - } else { - require.NoError(t, err) - assert.Equal(t, test.expected.AuthZone, authZone) - } - }) - } -} - func TestLivePresent(t *testing.T) { if !envTest.IsLiveTest() { t.Skip("skipping live test") diff --git a/providers/dns/oraclecloud/oraclecloud.go b/providers/dns/oraclecloud/oraclecloud.go index f6739b46..de3a9eed 100644 --- a/providers/dns/oraclecloud/oraclecloud.go +++ b/providers/dns/oraclecloud/oraclecloud.go @@ -105,9 +105,9 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - zoneNameOrID, err1 := dns01.FindZoneByFqdn(info.EffectiveFQDN) - if err1 != nil { - return fmt.Errorf("oraclecloud: could not find zone for domain %q and fqdn %q : %w", domain, info.EffectiveFQDN, err1) + zoneNameOrID, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) + if err != nil { + return fmt.Errorf("oraclecloud: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } // generate request to dns.PatchDomainRecordsRequest @@ -128,7 +128,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { }, } - _, err := d.client.PatchDomainRecords(context.Background(), request) + _, err = d.client.PatchDomainRecords(context.Background(), request) if err != nil { return fmt.Errorf("oraclecloud: %w", err) } @@ -140,9 +140,9 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - zoneNameOrID, err1 := dns01.FindZoneByFqdn(info.EffectiveFQDN) - if err1 != nil { - return fmt.Errorf("oraclecloud: could not find zone for domain %q and fqdn %q : %w", domain, info.EffectiveFQDN, err1) + zoneNameOrID, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) + if err != nil { + return fmt.Errorf("oraclecloud: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } // search to TXT record's hash to delete diff --git a/providers/dns/otc/client.go b/providers/dns/otc/client.go deleted file mode 100644 index 6ad4cdfd..00000000 --- a/providers/dns/otc/client.go +++ /dev/null @@ -1,266 +0,0 @@ -package otc - -import ( - "bytes" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" -) - -type recordset struct { - Name string `json:"name"` - Description string `json:"description"` - Type string `json:"type"` - TTL int `json:"ttl"` - Records []string `json:"records"` -} - -type nameResponse struct { - Name string `json:"name"` -} - -type userResponse struct { - Name string `json:"name"` - Password string `json:"password"` - Domain nameResponse `json:"domain"` -} - -type passwordResponse struct { - User userResponse `json:"user"` -} - -type identityResponse struct { - Methods []string `json:"methods"` - Password passwordResponse `json:"password"` -} - -type scopeResponse struct { - Project nameResponse `json:"project"` -} - -type authResponse struct { - Identity identityResponse `json:"identity"` - Scope scopeResponse `json:"scope"` -} - -type loginResponse struct { - Auth authResponse `json:"auth"` -} - -type endpointResponse struct { - Token token `json:"token"` -} - -type token struct { - Catalog []catalog `json:"catalog"` -} - -type catalog struct { - Type string `json:"type"` - Endpoints []endpoint `json:"endpoints"` -} - -type endpoint struct { - URL string `json:"url"` -} - -type zoneItem struct { - ID string `json:"id"` - Name string `json:"name"` -} - -type zonesResponse struct { - Zones []zoneItem `json:"zones"` -} - -type recordSet struct { - ID string `json:"id"` -} - -type recordSetsResponse struct { - RecordSets []recordSet `json:"recordsets"` -} - -// Starts a new OTC API Session. Authenticates using userName, password -// and receives a token to be used in for subsequent requests. -func (d *DNSProvider) login() error { - return d.loginRequest() -} - -func (d *DNSProvider) loginRequest() error { - userResp := userResponse{ - Name: d.config.UserName, - Password: d.config.Password, - Domain: nameResponse{ - Name: d.config.DomainName, - }, - } - - loginResp := loginResponse{ - Auth: authResponse{ - Identity: identityResponse{ - Methods: []string{"password"}, - Password: passwordResponse{ - User: userResp, - }, - }, - Scope: scopeResponse{ - Project: nameResponse{ - Name: d.config.ProjectName, - }, - }, - }, - } - - body, err := json.Marshal(loginResp) - if err != nil { - return err - } - - req, err := http.NewRequest(http.MethodPost, d.config.IdentityEndpoint, bytes.NewReader(body)) - if err != nil { - return err - } - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{Timeout: d.config.HTTPClient.Timeout} - resp, err := client.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - - if resp.StatusCode >= http.StatusBadRequest { - return fmt.Errorf("OTC API request failed with HTTP status code %d", resp.StatusCode) - } - - d.token = resp.Header.Get("X-Subject-Token") - - if d.token == "" { - return errors.New("unable to get auth token") - } - - var endpointResp endpointResponse - - err = json.NewDecoder(resp.Body).Decode(&endpointResp) - if err != nil { - return err - } - - var endpoints []endpoint - for _, v := range endpointResp.Token.Catalog { - if v.Type == "dns" { - endpoints = append(endpoints, v.Endpoints...) - } - } - - if len(endpoints) > 0 { - d.baseURL = fmt.Sprintf("%s/v2", endpoints[0].URL) - } else { - return errors.New("unable to get dns endpoint") - } - - return nil -} - -func (d *DNSProvider) getZoneID(zone string) (string, error) { - resource := fmt.Sprintf("zones?name=%s", zone) - resp, err := d.sendRequest(http.MethodGet, resource, nil) - if err != nil { - return "", err - } - - var zonesRes zonesResponse - err = json.NewDecoder(resp).Decode(&zonesRes) - if err != nil { - return "", err - } - - if len(zonesRes.Zones) < 1 { - return "", fmt.Errorf("zone %s not found", zone) - } - - for _, z := range zonesRes.Zones { - if z.Name == zone { - return z.ID, nil - } - } - - return "", fmt.Errorf("zone %s not found", zone) -} - -func (d *DNSProvider) getRecordSetID(zoneID, fqdn string) (string, error) { - resource := fmt.Sprintf("zones/%s/recordsets?type=TXT&name=%s", zoneID, fqdn) - resp, err := d.sendRequest(http.MethodGet, resource, nil) - if err != nil { - return "", err - } - - var recordSetsRes recordSetsResponse - err = json.NewDecoder(resp).Decode(&recordSetsRes) - if err != nil { - return "", err - } - - if len(recordSetsRes.RecordSets) < 1 { - return "", errors.New("record not found") - } - - if len(recordSetsRes.RecordSets) > 1 { - return "", errors.New("to many records found") - } - - if recordSetsRes.RecordSets[0].ID == "" { - return "", errors.New("id not found") - } - - return recordSetsRes.RecordSets[0].ID, nil -} - -func (d *DNSProvider) deleteRecordSet(zoneID, recordID string) error { - resource := fmt.Sprintf("zones/%s/recordsets/%s", zoneID, recordID) - - _, err := d.sendRequest(http.MethodDelete, resource, nil) - return err -} - -func (d *DNSProvider) sendRequest(method, resource string, payload interface{}) (io.Reader, error) { - url := fmt.Sprintf("%s/%s", d.baseURL, resource) - - var body io.Reader - if payload != nil { - content, err := json.Marshal(payload) - if err != nil { - return nil, err - } - body = bytes.NewReader(content) - } - - req, err := http.NewRequest(method, url, body) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - if len(d.token) > 0 { - req.Header.Set("X-Auth-Token", d.token) - } - - resp, err := d.config.HTTPClient.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - if resp.StatusCode >= http.StatusBadRequest { - return nil, fmt.Errorf("OTC API request %s failed with HTTP status code %d", url, resp.StatusCode) - } - - body1, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - return bytes.NewReader(body1), nil -} diff --git a/providers/dns/otc/internal/client.go b/providers/dns/otc/internal/client.go new file mode 100644 index 00000000..59a68514 --- /dev/null +++ b/providers/dns/otc/internal/client.go @@ -0,0 +1,221 @@ +package internal + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "sync" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" +) + +type Client struct { + username string + password string + domainName string + projectName string + + IdentityEndpoint string + token string + muToken sync.Mutex + + baseURL *url.URL + muBaseURL sync.Mutex + + HTTPClient *http.Client +} + +func NewClient(username string, password string, domainName string, projectName string) *Client { + return &Client{ + username: username, + password: password, + domainName: domainName, + projectName: projectName, + IdentityEndpoint: DefaultIdentityEndpoint, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + } +} + +func (c *Client) GetZoneID(ctx context.Context, zone string) (string, error) { + zonesResp, err := c.getZones(ctx, zone) + if err != nil { + return "", err + } + + if len(zonesResp.Zones) < 1 { + return "", fmt.Errorf("zone %s not found", zone) + } + + for _, z := range zonesResp.Zones { + if z.Name == zone { + return z.ID, nil + } + } + + return "", fmt.Errorf("zone %s not found", zone) +} + +// https://docs.otc.t-systems.com/domain-name-service/api-ref/apis/public_zone_management/querying_public_zones.html +func (c *Client) getZones(ctx context.Context, zone string) (*ZonesResponse, error) { + c.muBaseURL.Lock() + endpoint := c.baseURL.JoinPath("zones") + c.muBaseURL.Unlock() + + query := endpoint.Query() + query.Set("name", zone) + endpoint.RawQuery = query.Encode() + + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, err + } + + var zones ZonesResponse + err = c.do(req, &zones) + if err != nil { + return nil, err + } + + return &zones, nil +} + +func (c *Client) GetRecordSetID(ctx context.Context, zoneID, fqdn string) (string, error) { + recordSetsRes, err := c.getRecordSet(ctx, zoneID, fqdn) + if err != nil { + return "", err + } + + if len(recordSetsRes.RecordSets) < 1 { + return "", errors.New("record not found") + } + + if len(recordSetsRes.RecordSets) > 1 { + return "", errors.New("to many records found") + } + + if recordSetsRes.RecordSets[0].ID == "" { + return "", errors.New("id not found") + } + + return recordSetsRes.RecordSets[0].ID, nil +} + +// https://docs.otc.t-systems.com/domain-name-service/api-ref/apis/record_set_management/querying_all_record_sets.html +func (c *Client) getRecordSet(ctx context.Context, zoneID, fqdn string) (*RecordSetsResponse, error) { + c.muBaseURL.Lock() + endpoint := c.baseURL.JoinPath("zones", zoneID, "recordsets") + c.muBaseURL.Unlock() + + query := endpoint.Query() + query.Set("type", "TXT") + query.Set("name", fqdn) + endpoint.RawQuery = query.Encode() + + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, err + } + + var recordSetsRes RecordSetsResponse + err = c.do(req, &recordSetsRes) + if err != nil { + return nil, err + } + + return &recordSetsRes, nil +} + +// CreateRecordSet creates a record. +// https://docs.otc.t-systems.com/domain-name-service/api-ref/apis/record_set_management/creating_a_record_set.html +func (c *Client) CreateRecordSet(ctx context.Context, zoneID string, record RecordSets) error { + c.muBaseURL.Lock() + endpoint := c.baseURL.JoinPath("zones", zoneID, "recordsets") + c.muBaseURL.Unlock() + + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record) + if err != nil { + return err + } + + return c.do(req, nil) +} + +// DeleteRecordSet delete a record set. +// https://docs.otc.t-systems.com/domain-name-service/api-ref/apis/record_set_management/deleting_a_record_set.html +func (c *Client) DeleteRecordSet(ctx context.Context, zoneID, recordID string) error { + c.muBaseURL.Lock() + endpoint := c.baseURL.JoinPath("zones", zoneID, "recordsets", recordID) + c.muBaseURL.Unlock() + + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) + if err != nil { + return err + } + + return c.do(req, nil) +} + +func (c *Client) do(req *http.Request, result any) error { + c.muToken.Lock() + if c.token != "" { + req.Header.Set("X-Auth-Token", c.token) + } + c.muToken.Unlock() + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= http.StatusBadRequest { + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + if result == nil { + return nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return nil +} + +func newJSONRequest[T string | *url.URL](ctx context.Context, method string, endpoint T, payload interface{}) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, fmt.Sprintf("%s", endpoint), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil +} diff --git a/providers/dns/otc/internal/fixtures/zones-recordsets_DELETE.json b/providers/dns/otc/internal/fixtures/zones-recordsets_DELETE.json new file mode 100644 index 00000000..3090cc6a --- /dev/null +++ b/providers/dns/otc/internal/fixtures/zones-recordsets_DELETE.json @@ -0,0 +1,17 @@ +{ + "id": "2c9eb155587228570158722b6ac30007", + "name": "www.example.com.", + "description": "This is an example record set.", + "type": "A", + "ttl": 300, + "status": "PENDING_DELETE", + "links": { + "self": "https://Endpoint/v2/zones/2c9eb155587194ec01587224c9f90149/recordsets/2c9eb155587228570158722b6ac30007" + }, + "zone_id": "2c9eb155587194ec01587224c9f90149", + "zone_name": "example.com.", + "create_at": "2016-11-17T12:03:17.827", + "update_at": "2016-11-17T12:56:03.827", + "default": false, + "project_id": "e55c6f3dc4e34c9f86353b664ae0e70c" +} diff --git a/providers/dns/otc/internal/fixtures/zones-recordsets_GET.json b/providers/dns/otc/internal/fixtures/zones-recordsets_GET.json new file mode 100644 index 00000000..bfec4cfe --- /dev/null +++ b/providers/dns/otc/internal/fixtures/zones-recordsets_GET.json @@ -0,0 +1,30 @@ +{ + "links": { + "self": "https://Endpoint/v2/recordsets", + "next": "https://Endpoint/v2/recordsets?id=&limit=11&marker=2c9eb155587194ec01587224c9f9014a" + }, + "recordsets": [ + { + "id": "321321", + "name": "_acme-challenge.example.com", + "type": "TXT", + "ttl": 300, + "records": [ + "ns1.hotrot.de. xx.example.com. (1 7200 900 1209600 300)" + ], + "status": "ACTIVE", + "links": { + "self": "https://Endpoint/v2/zones/2c9eb155587194ec01587224c9f90149/recordsets/2c9eb155587194ec01587224c9f9014a" + }, + "zone_id": "2c9eb155587194ec01587224c9f90149", + "zone_name": "example.com.", + "create_at": "2016-11-17T11:56:03.439", + "update_at": "2016-11-17T11:56:03.827", + "default": true, + "project_id": "e55c6f3dc4e34c9f86353b664ae0e70c" + } + ], + "metadata": { + "total_count": 1 + } +} diff --git a/providers/dns/otc/internal/fixtures/zones-recordsets_GET_empty.json b/providers/dns/otc/internal/fixtures/zones-recordsets_GET_empty.json new file mode 100644 index 00000000..7899f988 --- /dev/null +++ b/providers/dns/otc/internal/fixtures/zones-recordsets_GET_empty.json @@ -0,0 +1,3 @@ +{ + "recordsets": [] +} diff --git a/providers/dns/otc/internal/fixtures/zones-recordsets_POST.json b/providers/dns/otc/internal/fixtures/zones-recordsets_POST.json new file mode 100644 index 00000000..f70c1744 --- /dev/null +++ b/providers/dns/otc/internal/fixtures/zones-recordsets_POST.json @@ -0,0 +1,21 @@ +{ + "id": "2c9eb155587228570158722b6ac30007", + "name": "www.example.com.", + "description": "This is an example record set.", + "type": "A", + "ttl": 300, + "records": [ + "192.168.10.1", + "192.168.10.2" + ], + "status": "PENDING_CREATE", + "links": { + "self": "https://Endpoint/v2/zones/2c9eb155587194ec01587224c9f90149/recordsets/2c9eb155587228570158722b6ac30007" + }, + "zone_id": "2c9eb155587194ec01587224c9f90149", + "zone_name": "example.com.", + "create_at": "2016-11-17T12:03:17.827", + "update_at": null, + "default": false, + "project_id": "e55c6f3dc4e34c9f86353b664ae0e70c" +} diff --git a/providers/dns/otc/internal/fixtures/zones_GET.json b/providers/dns/otc/internal/fixtures/zones_GET.json new file mode 100644 index 00000000..fcc327b7 --- /dev/null +++ b/providers/dns/otc/internal/fixtures/zones_GET.json @@ -0,0 +1,49 @@ +{ + "links": { + "self": "https://Endpoint/v2/zones?type=public&limit=11", + "next": "https://Endpoint/v2/zones?type=public&limit=11&marker=2c9eb155587194ec01587224c9f90149" + }, + "zones": [ + { + "id": "123123", + "name": "example.com.", + "description": "This is an example zone.", + "email": "xx@example.com", + "ttl": 300, + "serial": 0, + "masters": [], + "status": "ACTIVE", + "links": { + "self": "https://Endpoint/v2/zones/2c9eb155587194ec01587224c9f90149" + }, + "pool_id": "00000000570e54ee01570e9939b20019", + "project_id": "e55c6f3dc4e34c9f86353b664ae0e70c", + "zone_type": "public", + "created_at": "2016-11-17T11:56:03.439", + "updated_at": "2016-11-17T11:56:05.528", + "record_num": 2 + }, + { + "id": "2c9eb155587228570158722996c50001", + "name": "example.org.", + "description": "This is an example zone.", + "email": "xx@example.org", + "ttl": 300, + "serial": 0, + "masters": [], + "status": "PENDING_CREATE", + "links": { + "self": "https://Endpoint/v2/zones/2c9eb155587228570158722996c50001" + }, + "pool_id": "00000000570e54ee01570e9939b20019", + "project_id": "e55c6f3dc4e34c9f86353b664ae0e70c", + "zone_type": "public", + "created_at": "2016-11-17T12:01:17.996", + "updated_at": "2016-11-17T12:01:18.528", + "record_num": 2 + } + ], + "metadata": { + "total_count": 2 + } +} diff --git a/providers/dns/otc/internal/fixtures/zones_GET_empty.json b/providers/dns/otc/internal/fixtures/zones_GET_empty.json new file mode 100644 index 00000000..ee59e4e0 --- /dev/null +++ b/providers/dns/otc/internal/fixtures/zones_GET_empty.json @@ -0,0 +1,3 @@ +{ + "zones": [] +} diff --git a/providers/dns/otc/internal/identity.go b/providers/dns/otc/internal/identity.go new file mode 100644 index 00000000..f9e7cb08 --- /dev/null +++ b/providers/dns/otc/internal/identity.go @@ -0,0 +1,125 @@ +package internal + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/url" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" +) + +// DefaultIdentityEndpoint the default API identity endpoint. +const DefaultIdentityEndpoint = "https://iam.eu-de.otc.t-systems.com:443/v3/auth/tokens" + +// Login Starts a new OTC API Session. Authenticates using userName, password +// and receives a token to be used in for subsequent requests. +func (c *Client) Login(ctx context.Context) error { + payload := LoginRequest{ + Auth: Auth{ + Identity: Identity{ + Methods: []string{"password"}, + Password: Password{ + User: User{ + Name: c.username, + Password: c.password, + Domain: Domain{ + Name: c.domainName, + }, + }, + }, + }, + Scope: Scope{ + Project: Project{ + Name: c.projectName, + }, + }, + }, + } + + tokenResp, token, err := c.obtainUserToken(ctx, payload) + if err != nil { + return err + } + + c.muToken.Lock() + defer c.muToken.Unlock() + c.token = token + + if c.token == "" { + return errors.New("unable to get auth token") + } + + baseURL, err := getBaseURL(tokenResp) + if err != nil { + return err + } + + c.muBaseURL.Lock() + c.baseURL = baseURL + c.muBaseURL.Unlock() + + return nil +} + +// https://docs.otc.t-systems.com/identity-access-management/api-ref/apis/token_management/obtaining_a_user_token.html +func (c *Client) obtainUserToken(ctx context.Context, payload LoginRequest) (*TokenResponse, string, error) { + req, err := newJSONRequest(ctx, http.MethodPost, c.IdentityEndpoint, payload) + if err != nil { + return nil, "", err + } + + client := &http.Client{Timeout: c.HTTPClient.Timeout} + + resp, err := client.Do(req) + if err != nil { + return nil, "", err + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode/100 != 2 { + return nil, "", errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + token := resp.Header.Get("X-Subject-Token") + + if token == "" { + return nil, "", errors.New("unable to get auth token") + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, "", errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + var newToken TokenResponse + err = json.Unmarshal(raw, &newToken) + if err != nil { + return nil, "", errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return &newToken, token, nil +} + +func getBaseURL(tokenResp *TokenResponse) (*url.URL, error) { + var endpoints []Endpoint + for _, v := range tokenResp.Token.Catalog { + if v.Type == "dns" { + endpoints = append(endpoints, v.Endpoints...) + } + } + + if len(endpoints) == 0 { + return nil, errors.New("unable to get dns endpoint") + } + + baseURL, err := url.JoinPath(endpoints[0].URL, "v2") + if err != nil { + return nil, err + } + + return url.Parse(baseURL) +} diff --git a/providers/dns/otc/internal/identity_test.go b/providers/dns/otc/internal/identity_test.go new file mode 100644 index 00000000..18627869 --- /dev/null +++ b/providers/dns/otc/internal/identity_test.go @@ -0,0 +1,25 @@ +package internal + +import ( + "context" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestClient_Login(t *testing.T) { + mock := NewDNSServerMock(t) + mock.HandleAuthSuccessfully() + + client := NewClient("user", "secret", "example.com", "test") + client.IdentityEndpoint, _ = url.JoinPath(mock.GetServerURL(), "/v3/auth/token") + + err := client.Login(context.Background()) + require.NoError(t, err) + + serverURL, _ := url.Parse(mock.GetServerURL()) + assert.Equal(t, serverURL.JoinPath("v2").String(), client.baseURL.String()) + assert.Equal(t, fakeOTCToken, client.token) +} diff --git a/providers/dns/otc/mock_test.go b/providers/dns/otc/internal/mock.go similarity index 53% rename from providers/dns/otc/mock_test.go rename to providers/dns/otc/internal/mock.go index 706d94f9..33cb0728 100644 --- a/providers/dns/otc/mock_test.go +++ b/providers/dns/otc/internal/mock.go @@ -1,10 +1,12 @@ -package otc +package internal import ( "fmt" "io" "net/http" "net/http/httptest" + "os" + "path/filepath" "testing" "github.com/stretchr/testify/assert" @@ -12,11 +14,22 @@ import ( const fakeOTCToken = "62244bc21da68d03ebac94e6636ff01f" +func writeFixture(rw http.ResponseWriter, filename string) { + file, err := os.Open(filepath.Join("internal", "fixtures", filename)) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + defer func() { _ = file.Close() }() + + _, _ = io.Copy(rw, file) +} + // DNSServerMock mock. type DNSServerMock struct { t *testing.T server *httptest.Server - Mux *http.ServeMux + mux *http.ServeMux } // NewDNSServerMock create a new DNSServerMock. @@ -28,7 +41,7 @@ func NewDNSServerMock(t *testing.T) *DNSServerMock { return &DNSServerMock{ t: t, server: httptest.NewServer(mux), - Mux: mux, + mux: mux, } } @@ -43,10 +56,10 @@ func (m *DNSServerMock) ShutdownServer() { // HandleAuthSuccessfully Handle auth successfully. func (m *DNSServerMock) HandleAuthSuccessfully() { - m.Mux.HandleFunc("/v3/auth/token", func(w http.ResponseWriter, _ *http.Request) { + m.mux.HandleFunc("/v3/auth/token", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("X-Subject-Token", fakeOTCToken) - fmt.Fprintf(w, `{ + _, _ = fmt.Fprintf(w, `{ "token": { "catalog": [ { @@ -70,84 +83,66 @@ func (m *DNSServerMock) HandleAuthSuccessfully() { // HandleListZonesSuccessfully Handle list zones successfully. func (m *DNSServerMock) HandleListZonesSuccessfully() { - m.Mux.HandleFunc("/v2/zones", func(w http.ResponseWriter, r *http.Request) { - assert.Equal(m.t, r.Method, http.MethodGet) - assert.Equal(m.t, r.URL.Path, "/v2/zones") - assert.Equal(m.t, r.URL.RawQuery, "name=example.com.") - assert.Equal(m.t, r.Header.Get("Content-Type"), "application/json") + m.mux.HandleFunc("/v2/zones", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(m.t, http.MethodGet, r.Method) + assert.Equal(m.t, "/v2/zones", r.URL.Path) + assert.Equal(m.t, "name=example.com.", r.URL.RawQuery) + assert.Equal(m.t, "application/json", r.Header.Get("Accept")) - fmt.Fprintf(w, `{ - "zones":[{ - "id":"123123", - "name":"example.com." - }]} - `) + writeFixture(w, "zones_GET.json") }) } // HandleListZonesEmpty Handle list zones empty. func (m *DNSServerMock) HandleListZonesEmpty() { - m.Mux.HandleFunc("/v2/zones", func(w http.ResponseWriter, r *http.Request) { - assert.Equal(m.t, r.Method, http.MethodGet) - assert.Equal(m.t, r.URL.Path, "/v2/zones") - assert.Equal(m.t, r.URL.RawQuery, "name=example.com.") - assert.Equal(m.t, r.Header.Get("Content-Type"), "application/json") + m.mux.HandleFunc("/v2/zones", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(m.t, http.MethodGet, r.Method) + assert.Equal(m.t, "/v2/zones", r.URL.Path) + assert.Equal(m.t, "name=example.com.", r.URL.RawQuery) + assert.Equal(m.t, "application/json", r.Header.Get("Accept")) - fmt.Fprintf(w, `{ - "zones":[ - ]} - `) + writeFixture(w, "zones_GET_empty.json") }) } // HandleDeleteRecordsetsSuccessfully Handle delete recordsets successfully. func (m *DNSServerMock) HandleDeleteRecordsetsSuccessfully() { - m.Mux.HandleFunc("/v2/zones/123123/recordsets/321321", func(w http.ResponseWriter, r *http.Request) { - assert.Equal(m.t, r.Method, http.MethodDelete) - assert.Equal(m.t, r.URL.Path, "/v2/zones/123123/recordsets/321321") - assert.Equal(m.t, r.Header.Get("Content-Type"), "application/json") + m.mux.HandleFunc("/v2/zones/123123/recordsets/321321", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(m.t, http.MethodDelete, r.Method) + assert.Equal(m.t, "/v2/zones/123123/recordsets/321321", r.URL.Path) + assert.Equal(m.t, "application/json", r.Header.Get("Accept")) - fmt.Fprintf(w, `{ - "zones":[{ - "id":"123123" - }]} - `) + writeFixture(w, "zones-recordsets_DELETE.json") }) } // HandleListRecordsetsEmpty Handle list recordsets empty. func (m *DNSServerMock) HandleListRecordsetsEmpty() { - m.Mux.HandleFunc("/v2/zones/123123/recordsets", func(w http.ResponseWriter, r *http.Request) { - assert.Equal(m.t, r.URL.Path, "/v2/zones/123123/recordsets") - assert.Equal(m.t, r.URL.RawQuery, "type=TXT&name=_acme-challenge.example.com.") + m.mux.HandleFunc("/v2/zones/123123/recordsets", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(m.t, "/v2/zones/123123/recordsets", r.URL.Path) + assert.Equal(m.t, "name=_acme-challenge.example.com.&type=TXT", r.URL.RawQuery) - fmt.Fprintf(w, `{ - "recordsets":[ - ]} - `) + writeFixture(w, "zones-recordsets_GET_empty.json") }) } // HandleListRecordsetsSuccessfully Handle list recordsets successfully. func (m *DNSServerMock) HandleListRecordsetsSuccessfully() { - m.Mux.HandleFunc("/v2/zones/123123/recordsets", func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodGet { - assert.Equal(m.t, r.URL.Path, "/v2/zones/123123/recordsets") - assert.Equal(m.t, r.URL.RawQuery, "type=TXT&name=_acme-challenge.example.com.") - assert.Equal(m.t, r.Header.Get("Content-Type"), "application/json") + m.mux.HandleFunc("/v2/zones/123123/recordsets", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(m.t, "application/json", r.Header.Get("Accept")) - fmt.Fprintf(w, `{ - "recordsets":[{ - "id":"321321" - }]} - `) + if r.Method == http.MethodGet { + assert.Equal(m.t, "/v2/zones/123123/recordsets", r.URL.Path) + assert.Equal(m.t, "name=_acme-challenge.example.com.&type=TXT", r.URL.RawQuery) + + writeFixture(w, "zones-recordsets_GET.json") return } if r.Method == http.MethodPost { - assert.Equal(m.t, r.Header.Get("Content-Type"), "application/json") + assert.Equal(m.t, "application/json", r.Header.Get("Content-Type")) - body, err := io.ReadAll(r.Body) + raw, err := io.ReadAll(r.Body) assert.Nil(m.t, err) exceptedString := `{ "name": "_acme-challenge.example.com.", @@ -156,12 +151,10 @@ func (m *DNSServerMock) HandleListRecordsetsSuccessfully() { "ttl": 300, "records": ["\"w6uP8Tcg6K2QR905Rms8iXTlksL6OD1KOWBxTK7wxPI\""] }` - assert.JSONEq(m.t, string(body), exceptedString) - fmt.Fprintf(w, `{ - "recordsets":[{ - "id":"321321" - }]} - `) + + assert.JSONEq(m.t, exceptedString, string(raw)) + + writeFixture(w, "zones-recordsets_POST.json") return } diff --git a/providers/dns/otc/internal/types.go b/providers/dns/otc/internal/types.go new file mode 100644 index 00000000..38da4f11 --- /dev/null +++ b/providers/dns/otc/internal/types.go @@ -0,0 +1,147 @@ +package internal + +// LoginRequest + +type LoginRequest struct { + Auth Auth `json:"auth"` +} + +type Auth struct { + Identity Identity `json:"identity"` + Scope Scope `json:"scope"` +} + +type Identity struct { + Methods []string `json:"methods"` + Password Password `json:"password"` +} + +type Password struct { + User User `json:"user"` +} + +type User struct { + Name string `json:"name"` + Password string `json:"password"` + Domain Domain `json:"domain"` +} + +type Scope struct { + Project Project `json:"project"` +} + +type Project struct { + Name string `json:"name"` +} + +// TokenResponse + +type TokenResponse struct { + Token Token `json:"token"` +} + +type Token struct { + User UserR `json:"user,omitempty"` + Domain Domain `json:"domain,omitempty"` + Catalog []Catalog `json:"catalog,omitempty"` + Methods []string `json:"methods,omitempty"` + Roles []Role `json:"roles,omitempty"` + ExpiresAt string `json:"expires_at,omitempty"` + IssuedAt string `json:"issued_at,omitempty"` +} + +type Catalog struct { + ID string `json:"id,omitempty"` + Type string `json:"type,omitempty"` + Name string `json:"name,omitempty"` + Endpoints []Endpoint `json:"endpoints,omitempty"` +} + +type UserR struct { + ID string `json:"id,omitempty"` + Domain Domain `json:"domain,omitempty"` + Name string `json:"name,omitempty"` + PasswordExpiresAt string `json:"password_expires_at,omitempty"` +} + +type Endpoint struct { + ID string `json:"id,omitempty"` + URL string `json:"url,omitempty"` + Region string `json:"region,omitempty"` + RegionID string `json:"region_id,omitempty"` + Interface string `json:"interface,omitempty"` +} + +type Role struct { + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` +} + +// RecordSetsResponse + +type RecordSetsResponse struct { + Links Links `json:"links"` + RecordSets []RecordSets `json:"recordsets"` + Metadata Metadata `json:"metadata"` +} + +type RecordSets struct { + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Type string `json:"type,omitempty"` + TTL int `json:"ttl,omitempty"` + Records []string `json:"records,omitempty"` + + Status string `json:"status,omitempty"` + Links *Links `json:"links,omitempty"` + ZoneID string `json:"zone_id,omitempty"` + ZoneName string `json:"zone_name,omitempty"` + CreateAt string `json:"create_at,omitempty"` + UpdateAt string `json:"update_at,omitempty"` + Default bool `json:"default,omitempty"` + ProjectID string `json:"project_id,omitempty"` +} + +// ZonesResponse + +type ZonesResponse struct { + Links Links `json:"links,omitempty"` + Zones []Zone `json:"zones"` + Metadata Metadata `json:"metadata"` +} + +type Zone struct { + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Email string `json:"email,omitempty"` + TTL int `json:"ttl,omitempty"` + Serial int `json:"serial,omitempty"` + Status string `json:"status,omitempty"` + Links *Links `json:"links,omitempty"` + PoolID string `json:"pool_id,omitempty"` + ProjectID string `json:"project_id,omitempty"` + ZoneType string `json:"zone_type,omitempty"` + CreatedAt string `json:"created_at,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` + RecordNum int `json:"record_num,omitempty"` +} + +// Response + +type Links struct { + Self string `json:"self,omitempty"` + Next string `json:"next,omitempty"` +} + +type Metadata struct { + TotalCount int `json:"total_count,omitempty"` +} + +// Shared + +type Domain struct { + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` +} diff --git a/providers/dns/otc/otc.go b/providers/dns/otc/otc.go index ffcf7552..e3e0b925 100644 --- a/providers/dns/otc/otc.go +++ b/providers/dns/otc/otc.go @@ -2,6 +2,7 @@ package otc import ( + "context" "errors" "fmt" "net" @@ -10,6 +11,7 @@ import ( "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/platform/config/env" + "github.com/go-acme/lego/v4/providers/dns/otc/internal" ) const defaultIdentityEndpoint = "https://iam.eu-de.otc.t-systems.com:443/v3/auth/tokens" @@ -60,7 +62,6 @@ func NewDefaultConfig() *Config { DialContext: (&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, - DualStack: true, }).DialContext, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, @@ -76,9 +77,8 @@ func NewDefaultConfig() *Config { // DNSProvider implements the challenge.Provider interface. type DNSProvider struct { - config *Config - baseURL string - token string + config *Config + client *internal.Client } // NewDNSProvider returns a DNSProvider instance configured for OTC DNS. @@ -113,11 +113,17 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, fmt.Errorf("otc: invalid TTL, TTL (%d) must be greater than %d", config.TTL, minTTL) } - if config.IdentityEndpoint == "" { - config.IdentityEndpoint = defaultIdentityEndpoint + client := internal.NewClient(config.UserName, config.Password, config.DomainName, config.ProjectName) + + if config.IdentityEndpoint != "" { + client.IdentityEndpoint = config.IdentityEndpoint } - return &DNSProvider{config: config}, nil + if config.HTTPClient != nil { + client.HTTPClient = config.HTTPClient + } + + return &DNSProvider{config: config, client: client}, nil } // Present creates a TXT record using the specified parameters. @@ -126,22 +132,22 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("otc: %w", err) + return fmt.Errorf("otc: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } - err = d.login() + ctx := context.Background() + + err = d.client.Login(ctx) if err != nil { return fmt.Errorf("otc: %w", err) } - zoneID, err := d.getZoneID(authZone) + zoneID, err := d.client.GetZoneID(ctx, authZone) if err != nil { return fmt.Errorf("otc: unable to get zone: %w", err) } - resource := fmt.Sprintf("zones/%s/recordsets", zoneID) - - r1 := &recordset{ + record := internal.RecordSets{ Name: info.EffectiveFQDN, Description: "Added TXT record for ACME dns-01 challenge using lego client", Type: "TXT", @@ -149,10 +155,11 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { Records: []string{fmt.Sprintf("%q", info.Value)}, } - _, err = d.sendRequest(http.MethodPost, resource, r1) + err = d.client.CreateRecordSet(ctx, zoneID, record) if err != nil { return fmt.Errorf("otc: %w", err) } + return nil } @@ -162,28 +169,31 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("otc: %w", err) + return fmt.Errorf("otc: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } - err = d.login() + ctx := context.Background() + + err = d.client.Login(ctx) if err != nil { return fmt.Errorf("otc: %w", err) } - zoneID, err := d.getZoneID(authZone) + zoneID, err := d.client.GetZoneID(ctx, authZone) if err != nil { return fmt.Errorf("otc: %w", err) } - recordID, err := d.getRecordSetID(zoneID, info.EffectiveFQDN) + recordID, err := d.client.GetRecordSetID(ctx, zoneID, info.EffectiveFQDN) if err != nil { - return fmt.Errorf("otc: unable go get record %s for zone %s: %w", info.EffectiveFQDN, domain, err) + return fmt.Errorf("otc: unable to get record %s for zone %s: %w", info.EffectiveFQDN, domain, err) } - err = d.deleteRecordSet(zoneID, recordID) + err = d.client.DeleteRecordSet(ctx, zoneID, recordID) if err != nil { return fmt.Errorf("otc: %w", err) } + return nil } diff --git a/providers/dns/otc/otc.toml b/providers/dns/otc/otc.toml index d60aa3fc..7f9703bd 100644 --- a/providers/dns/otc/otc.toml +++ b/providers/dns/otc/otc.toml @@ -20,5 +20,5 @@ Example = '''''' OTC_HTTP_TIMEOUT = "API request timeout" [Links] - API = "https://docs.otc.t-systems.com/en-us/dns/index.html" + API = "https://docs.otc.t-systems.com/domain-name-service/api-ref/index.html" diff --git a/providers/dns/otc/otc_test.go b/providers/dns/otc/otc_test.go index 03e02c81..3edfca8b 100644 --- a/providers/dns/otc/otc_test.go +++ b/providers/dns/otc/otc_test.go @@ -6,18 +6,20 @@ import ( "testing" "github.com/go-acme/lego/v4/platform/tester" + "github.com/go-acme/lego/v4/providers/dns/otc/internal" "github.com/stretchr/testify/suite" ) type OTCSuite struct { suite.Suite - Mock *DNSServerMock + + mock *internal.DNSServerMock envTest *tester.EnvTest } func (s *OTCSuite) SetupTest() { - s.Mock = NewDNSServerMock(s.T()) - s.Mock.HandleAuthSuccessfully() + s.mock = internal.NewDNSServerMock(s.T()) + s.mock.HandleAuthSuccessfully() s.envTest = tester.NewEnvTest( EnvDomainName, EnvUserName, @@ -29,7 +31,7 @@ func (s *OTCSuite) SetupTest() { func (s *OTCSuite) TearDownTest() { s.envTest.RestoreEnv() - s.Mock.ShutdownServer() + s.mock.ShutdownServer() } func TestTestSuite(t *testing.T) { @@ -42,22 +44,11 @@ func (s *OTCSuite) createDNSProvider() (*DNSProvider, error) { config.Password = "Password" config.DomainName = "DomainName" config.ProjectName = "ProjectName" - config.IdentityEndpoint = fmt.Sprintf("%s/v3/auth/token", s.Mock.GetServerURL()) + config.IdentityEndpoint = fmt.Sprintf("%s/v3/auth/token", s.mock.GetServerURL()) return NewDNSProviderConfig(config) } -func (s *OTCSuite) TestLogin() { - provider, err := s.createDNSProvider() - s.Require().NoError(err) - - err = provider.loginRequest() - s.Require().NoError(err) - - s.Equal(provider.baseURL, fmt.Sprintf("%s/v2", s.Mock.GetServerURL())) - s.Equal(fakeOTCToken, provider.token) -} - func (s *OTCSuite) TestLoginEnv() { s.envTest.ClearEnv() @@ -94,8 +85,8 @@ func (s *OTCSuite) TestLoginEnvEmpty() { } func (s *OTCSuite) TestDNSProvider_Present() { - s.Mock.HandleListZonesSuccessfully() - s.Mock.HandleListRecordsetsSuccessfully() + s.mock.HandleListZonesSuccessfully() + s.mock.HandleListRecordsetsSuccessfully() provider, err := s.createDNSProvider() s.Require().NoError(err) @@ -105,8 +96,8 @@ func (s *OTCSuite) TestDNSProvider_Present() { } func (s *OTCSuite) TestDNSProvider_Present_EmptyZone() { - s.Mock.HandleListZonesEmpty() - s.Mock.HandleListRecordsetsSuccessfully() + s.mock.HandleListZonesEmpty() + s.mock.HandleListRecordsetsSuccessfully() provider, err := s.createDNSProvider() s.Require().NoError(err) @@ -116,9 +107,9 @@ func (s *OTCSuite) TestDNSProvider_Present_EmptyZone() { } func (s *OTCSuite) TestDNSProvider_CleanUp() { - s.Mock.HandleListZonesSuccessfully() - s.Mock.HandleListRecordsetsSuccessfully() - s.Mock.HandleDeleteRecordsetsSuccessfully() + s.mock.HandleListZonesSuccessfully() + s.mock.HandleListRecordsetsSuccessfully() + s.mock.HandleDeleteRecordsetsSuccessfully() provider, err := s.createDNSProvider() s.Require().NoError(err) @@ -128,8 +119,8 @@ func (s *OTCSuite) TestDNSProvider_CleanUp() { } func (s *OTCSuite) TestDNSProvider_CleanUp_EmptyRecordset() { - s.Mock.HandleListZonesSuccessfully() - s.Mock.HandleListRecordsetsEmpty() + s.mock.HandleListZonesSuccessfully() + s.mock.HandleListRecordsetsEmpty() provider, err := s.createDNSProvider() s.Require().NoError(err) diff --git a/providers/dns/ovh/ovh.go b/providers/dns/ovh/ovh.go index 014daf15..9f269f6d 100644 --- a/providers/dns/ovh/ovh.go +++ b/providers/dns/ovh/ovh.go @@ -127,7 +127,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { // Parse domain name authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("ovh: could not determine zone for domain %q: %w", info.EffectiveFQDN, err) + return fmt.Errorf("ovh: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } authZone = dns01.UnFqdn(authZone) @@ -175,7 +175,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("ovh: could not determine zone for domain %q: %w", info.EffectiveFQDN, err) + return fmt.Errorf("ovh: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } authZone = dns01.UnFqdn(authZone) diff --git a/providers/dns/pdns/client.go b/providers/dns/pdns/client.go deleted file mode 100644 index 47ffff2e..00000000 --- a/providers/dns/pdns/client.go +++ /dev/null @@ -1,217 +0,0 @@ -package pdns - -import ( - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "path" - "strconv" - "strings" - - "github.com/go-acme/lego/v4/challenge/dns01" - "github.com/miekg/dns" -) - -type Record struct { - Content string `json:"content"` - Disabled bool `json:"disabled"` - - // pre-v1 API - Name string `json:"name"` - Type string `json:"type"` - TTL int `json:"ttl,omitempty"` -} - -type hostedZone struct { - ID string `json:"id"` - Name string `json:"name"` - URL string `json:"url"` - Kind string `json:"kind"` - RRSets []rrSet `json:"rrsets"` - - // pre-v1 API - Records []Record `json:"records"` -} - -type rrSet struct { - Name string `json:"name"` - Type string `json:"type"` - Kind string `json:"kind"` - ChangeType string `json:"changetype"` - Records []Record `json:"records,omitempty"` - TTL int `json:"ttl,omitempty"` -} - -type rrSets struct { - RRSets []rrSet `json:"rrsets"` -} - -type apiError struct { - ShortMsg string `json:"error"` -} - -func (a apiError) Error() string { - return a.ShortMsg -} - -type apiVersion struct { - URL string `json:"url"` - Version int `json:"version"` -} - -func (d *DNSProvider) getHostedZone(fqdn string) (*hostedZone, error) { - authZone, err := dns01.FindZoneByFqdn(fqdn) - if err != nil { - return nil, err - } - - p := path.Join("/servers", d.config.ServerName, "/zones/", dns.Fqdn(authZone)) - - result, err := d.sendRequest(http.MethodGet, p, nil) - if err != nil { - return nil, err - } - - var zone hostedZone - err = json.Unmarshal(result, &zone) - if err != nil { - return nil, err - } - - // convert pre-v1 API result - if len(zone.Records) > 0 { - zone.RRSets = []rrSet{} - for _, record := range zone.Records { - set := rrSet{ - Name: record.Name, - Type: record.Type, - Records: []Record{record}, - } - zone.RRSets = append(zone.RRSets, set) - } - } - - return &zone, nil -} - -func (d *DNSProvider) findTxtRecord(fqdn string) (*rrSet, error) { - zone, err := d.getHostedZone(fqdn) - if err != nil { - return nil, err - } - - _, err = d.sendRequest(http.MethodGet, zone.URL, nil) - if err != nil { - return nil, err - } - - for _, set := range zone.RRSets { - if set.Type == "TXT" && (set.Name == dns01.UnFqdn(fqdn) || set.Name == fqdn) { - return &set, nil - } - } - - return nil, nil -} - -func (d *DNSProvider) getAPIVersion() (int, error) { - result, err := d.sendRequest(http.MethodGet, "/api", nil) - if err != nil { - return 0, err - } - - var versions []apiVersion - err = json.Unmarshal(result, &versions) - if err != nil { - return 0, err - } - - latestVersion := 0 - for _, v := range versions { - if v.Version > latestVersion { - latestVersion = v.Version - } - } - - return latestVersion, err -} - -func (d *DNSProvider) notify(zone *hostedZone) error { - if d.apiVersion < 1 || zone.Kind != "Master" && zone.Kind != "Slave" { - return nil - } - - _, err := d.sendRequest(http.MethodPut, path.Join(zone.URL, "/notify"), nil) - if err != nil { - return err - } - - return nil -} - -func (d *DNSProvider) sendRequest(method, uri string, body io.Reader) (json.RawMessage, error) { - req, err := d.makeRequest(method, uri, body) - if err != nil { - return nil, err - } - - resp, err := d.config.HTTPClient.Do(req) - if err != nil { - return nil, fmt.Errorf("error talking to PDNS API: %w", err) - } - - defer resp.Body.Close() - - if resp.StatusCode != http.StatusUnprocessableEntity && (resp.StatusCode < 200 || resp.StatusCode >= 300) { - return nil, fmt.Errorf("unexpected HTTP status code %d when %sing '%s'", resp.StatusCode, req.Method, req.URL) - } - - var msg json.RawMessage - err = json.NewDecoder(resp.Body).Decode(&msg) - if err != nil { - if errors.Is(err, io.EOF) { - // empty body - return nil, nil - } - // other error - return nil, err - } - - // check for PowerDNS error message - if len(msg) > 0 && msg[0] == '{' { - var errInfo apiError - err = json.Unmarshal(msg, &errInfo) - if err != nil { - return nil, err - } - if errInfo.ShortMsg != "" { - return nil, fmt.Errorf("error talking to PDNS API: %w", errInfo) - } - } - return msg, nil -} - -func (d *DNSProvider) makeRequest(method, uri string, body io.Reader) (*http.Request, error) { - p := path.Join("/", uri) - - if p != "/api" && d.apiVersion > 0 && !strings.HasPrefix(p, "/api/v") { - p = path.Join("/api", "v"+strconv.Itoa(d.apiVersion), p) - } - - endpoint := d.config.Host.JoinPath(p) - - req, err := http.NewRequest(method, strings.TrimSuffix(endpoint.String(), "/"), body) - if err != nil { - return nil, err - } - - req.Header.Set("X-API-Key", d.config.APIKey) - - if method != http.MethodGet && method != http.MethodDelete { - req.Header.Set("Content-Type", "application/json") - } - - return req, nil -} diff --git a/providers/dns/pdns/client_test.go b/providers/dns/pdns/client_test.go deleted file mode 100644 index cb8befb9..00000000 --- a/providers/dns/pdns/client_test.go +++ /dev/null @@ -1,114 +0,0 @@ -package pdns - -import ( - "net/http" - "net/url" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestDNSProvider_makeRequest(t *testing.T) { - testCases := []struct { - desc string - apiVersion int - baseURL string - uri string - expected string - }{ - { - desc: "host with path", - apiVersion: 1, - baseURL: "https://example.com/test", - uri: "/foo", - expected: "https://example.com/test/api/v1/foo", - }, - { - desc: "host with path + trailing slash", - apiVersion: 1, - baseURL: "https://example.com/test/", - uri: "/foo", - expected: "https://example.com/test/api/v1/foo", - }, - { - desc: "no URI", - apiVersion: 1, - baseURL: "https://example.com/test", - uri: "", - expected: "https://example.com/test/api/v1", - }, - { - desc: "host without path", - apiVersion: 1, - baseURL: "https://example.com", - uri: "/foo", - expected: "https://example.com/api/v1/foo", - }, - { - desc: "api", - apiVersion: 1, - baseURL: "https://example.com", - uri: "/api", - expected: "https://example.com/api", - }, - { - desc: "API version 0, host with path", - apiVersion: 0, - baseURL: "https://example.com/test", - uri: "/foo", - expected: "https://example.com/test/foo", - }, - { - desc: "API version 0, host with path + trailing slash", - apiVersion: 0, - baseURL: "https://example.com/test/", - uri: "/foo", - expected: "https://example.com/test/foo", - }, - { - desc: "API version 0, no URI", - apiVersion: 0, - baseURL: "https://example.com/test", - uri: "", - expected: "https://example.com/test", - }, - { - desc: "API version 0, host without path", - apiVersion: 0, - baseURL: "https://example.com", - uri: "/foo", - expected: "https://example.com/foo", - }, - { - desc: "API version 0, api", - apiVersion: 0, - baseURL: "https://example.com", - uri: "/api", - expected: "https://example.com/api", - }, - } - - for _, test := range testCases { - test := test - t.Run(test.desc, func(t *testing.T) { - t.Parallel() - - host, err := url.Parse(test.baseURL) - require.NoError(t, err) - - config := &Config{Host: host, APIKey: "secret"} - - p := &DNSProvider{ - config: config, - apiVersion: test.apiVersion, - } - - req, err := p.makeRequest(http.MethodGet, test.uri, nil) - require.NoError(t, err) - - assert.Equal(t, test.expected, req.URL.String()) - assert.Equal(t, "secret", req.Header.Get("X-API-Key")) - }) - } -} diff --git a/providers/dns/pdns/internal/client.go b/providers/dns/pdns/internal/client.go new file mode 100644 index 00000000..cddb1c63 --- /dev/null +++ b/providers/dns/pdns/internal/client.go @@ -0,0 +1,226 @@ +package internal + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "path" + "strconv" + "strings" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" + "github.com/miekg/dns" +) + +// Client the PowerDNS API client. +type Client struct { + serverName string + apiKey string + + apiVersion int + + Host *url.URL + HTTPClient *http.Client +} + +// NewClient creates a new Client. +func NewClient(host *url.URL, serverName string, apiKey string) *Client { + return &Client{ + serverName: serverName, + apiKey: apiKey, + Host: host, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + } +} + +func (c *Client) APIVersion() int { + return c.apiVersion +} + +func (c *Client) SetAPIVersion(ctx context.Context) error { + var err error + + c.apiVersion, err = c.getAPIVersion(ctx) + + return err +} + +func (c *Client) getAPIVersion(ctx context.Context) (int, error) { + endpoint := c.joinPath("/", "api") + + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return 0, err + } + + result, err := c.do(req) + if err != nil { + return 0, err + } + + var versions []apiVersion + err = json.Unmarshal(result, &versions) + if err != nil { + return 0, err + } + + latestVersion := 0 + for _, v := range versions { + if v.Version > latestVersion { + latestVersion = v.Version + } + } + + return latestVersion, err +} + +func (c *Client) GetHostedZone(ctx context.Context, authZone string) (*HostedZone, error) { + endpoint := c.joinPath("/", "servers", c.serverName, "zones", dns.Fqdn(authZone)) + + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, err + } + + result, err := c.do(req) + if err != nil { + return nil, err + } + + var zone HostedZone + err = json.Unmarshal(result, &zone) + if err != nil { + return nil, err + } + + // convert pre-v1 API result + if len(zone.Records) > 0 { + zone.RRSets = []RRSet{} + for _, record := range zone.Records { + set := RRSet{ + Name: record.Name, + Type: record.Type, + Records: []Record{record}, + } + zone.RRSets = append(zone.RRSets, set) + } + } + + return &zone, nil +} + +func (c *Client) UpdateRecords(ctx context.Context, zone *HostedZone, sets RRSets) error { + endpoint := c.joinPath("/", zone.URL) + + req, err := newJSONRequest(ctx, http.MethodPatch, endpoint, sets) + if err != nil { + return err + } + + _, err = c.do(req) + if err != nil { + return err + } + + return nil +} + +func (c *Client) Notify(ctx context.Context, zone *HostedZone) error { + if c.apiVersion < 1 || zone.Kind != "Master" && zone.Kind != "Slave" { + return nil + } + + endpoint := c.joinPath("/", zone.URL, "/notify") + + req, err := newJSONRequest(ctx, http.MethodPut, endpoint, nil) + if err != nil { + return err + } + + _, err = c.do(req) + if err != nil { + return err + } + + return nil +} + +func (c *Client) joinPath(elem ...string) *url.URL { + p := path.Join(elem...) + + if p != "/api" && c.apiVersion > 0 && !strings.HasPrefix(p, "/api/v") { + p = path.Join("/api", "v"+strconv.Itoa(c.apiVersion), p) + } + + return c.Host.JoinPath(p) +} + +func (c *Client) do(req *http.Request) (json.RawMessage, error) { + req.Header.Set("X-API-Key", c.apiKey) + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusUnprocessableEntity && (resp.StatusCode < 200 || resp.StatusCode >= 300) { + return nil, errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + var msg json.RawMessage + err = json.NewDecoder(resp.Body).Decode(&msg) + if err != nil { + if errors.Is(err, io.EOF) { + // empty body + return nil, nil + } + // other error + return nil, err + } + + // check for PowerDNS error message + if len(msg) > 0 && msg[0] == '{' { + var errInfo apiError + err = json.Unmarshal(msg, &errInfo) + if err != nil { + return nil, errutils.NewUnmarshalError(req, resp.StatusCode, msg, err) + } + if errInfo.ShortMsg != "" { + return nil, fmt.Errorf("error talking to PDNS API: %w", errInfo) + } + } + + return msg, nil +} + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, strings.TrimSuffix(endpoint.String(), "/"), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil +} diff --git a/providers/dns/pdns/internal/client_test.go b/providers/dns/pdns/internal/client_test.go new file mode 100644 index 00000000..d102a5ef --- /dev/null +++ b/providers/dns/pdns/internal/client_test.go @@ -0,0 +1,352 @@ +package internal + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setupTest(t *testing.T, method, pattern string, status int, file string) *Client { + t.Helper() + + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + mux.HandleFunc(pattern, func(rw http.ResponseWriter, req *http.Request) { + if req.Method != method { + http.Error(rw, fmt.Sprintf("unsupported method: %s", req.Method), http.StatusBadRequest) + return + } + + apiKey := req.Header.Get("X-API-Key") + if apiKey != "secret" { + http.Error(rw, fmt.Sprintf("invalid credentials: %s", apiKey), http.StatusBadRequest) + return + } + + if file == "" { + rw.WriteHeader(status) + return + } + + open, err := os.Open(filepath.Join("fixtures", file)) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + + defer func() { _ = open.Close() }() + + rw.WriteHeader(status) + _, err = io.Copy(rw, open) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + }) + + serverURL, _ := url.Parse(server.URL) + + client := NewClient(serverURL, "server", "secret") + client.HTTPClient = server.Client() + + return client +} + +func TestClient_joinPath(t *testing.T) { + testCases := []struct { + desc string + apiVersion int + baseURL string + uri string + expected string + }{ + { + desc: "host with path", + apiVersion: 1, + baseURL: "https://example.com/test", + uri: "/foo", + expected: "https://example.com/test/api/v1/foo", + }, + { + desc: "host with path + trailing slash", + apiVersion: 1, + baseURL: "https://example.com/test/", + uri: "/foo", + expected: "https://example.com/test/api/v1/foo", + }, + { + desc: "no URI", + apiVersion: 1, + baseURL: "https://example.com/test", + uri: "", + expected: "https://example.com/test/api/v1", + }, + { + desc: "host without path", + apiVersion: 1, + baseURL: "https://example.com", + uri: "/foo", + expected: "https://example.com/api/v1/foo", + }, + { + desc: "api", + apiVersion: 1, + baseURL: "https://example.com", + uri: "/api", + expected: "https://example.com/api", + }, + { + desc: "API version 0, host with path", + apiVersion: 0, + baseURL: "https://example.com/test", + uri: "/foo", + expected: "https://example.com/test/foo", + }, + { + desc: "API version 0, host with path + trailing slash", + apiVersion: 0, + baseURL: "https://example.com/test/", + uri: "/foo", + expected: "https://example.com/test/foo", + }, + { + desc: "API version 0, no URI", + apiVersion: 0, + baseURL: "https://example.com/test", + uri: "", + expected: "https://example.com/test", + }, + { + desc: "API version 0, host without path", + apiVersion: 0, + baseURL: "https://example.com", + uri: "/foo", + expected: "https://example.com/foo", + }, + { + desc: "API version 0, api", + apiVersion: 0, + baseURL: "https://example.com", + uri: "/api", + expected: "https://example.com/api", + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + host, err := url.Parse(test.baseURL) + require.NoError(t, err) + + client := NewClient(host, "test", "secret") + client.apiVersion = test.apiVersion + + endpoint := client.joinPath(test.uri) + + assert.Equal(t, test.expected, endpoint.String()) + }) + } +} + +func TestClient_GetHostedZone(t *testing.T) { + client := setupTest(t, http.MethodGet, "/api/v1/servers/server/zones/example.org.", http.StatusOK, "zone.json") + client.apiVersion = 1 + + zone, err := client.GetHostedZone(context.Background(), "example.org.") + require.NoError(t, err) + + expected := &HostedZone{ + ID: "example.org.", + Name: "example.org.", + URL: "api/v1/servers/localhost/zones/example.org.", + Kind: "Master", + RRSets: []RRSet{ + { + Name: "example.org.", + Type: "NS", + Records: []Record{{Content: "ns2.example.org."}, {Content: "ns1.example.org."}}, + TTL: 86400, + }, + { + Name: "example.org.", + Type: "SOA", + Records: []Record{{Content: "ns1.example.org. hostmaster.example.org. 2015120401 10800 15 604800 10800"}}, + TTL: 86400, + }, + { + Name: "ns1.example.org.", + Type: "A", + Records: []Record{{Content: "192.168.0.1"}}, + TTL: 86400, + }, + { + Name: "www.example.org.", + Type: "A", + Records: []Record{{Content: "192.168.0.2"}}, + TTL: 86400, + }, + }, + } + + assert.Equal(t, expected, zone) +} + +func TestClient_GetHostedZone_error(t *testing.T) { + client := setupTest(t, http.MethodGet, "/api/v1/servers/server/zones/example.org.", http.StatusUnprocessableEntity, "error.json") + client.apiVersion = 1 + + _, err := client.GetHostedZone(context.Background(), "example.org.") + require.ErrorAs(t, err, &apiError{}) +} + +func TestClient_GetHostedZone_v0(t *testing.T) { + client := setupTest(t, http.MethodGet, "/servers/server/zones/example.org.", http.StatusOK, "zone.json") + client.apiVersion = 0 + + zone, err := client.GetHostedZone(context.Background(), "example.org.") + require.NoError(t, err) + + expected := &HostedZone{ + ID: "example.org.", + Name: "example.org.", + URL: "api/v1/servers/localhost/zones/example.org.", + Kind: "Master", + RRSets: []RRSet{ + { + Name: "example.org.", + Type: "NS", + Records: []Record{{Content: "ns2.example.org."}, {Content: "ns1.example.org."}}, + TTL: 86400, + }, + { + Name: "example.org.", + Type: "SOA", + Records: []Record{{Content: "ns1.example.org. hostmaster.example.org. 2015120401 10800 15 604800 10800"}}, + TTL: 86400, + }, + { + Name: "ns1.example.org.", + Type: "A", + Records: []Record{{Content: "192.168.0.1"}}, + TTL: 86400, + }, + { + Name: "www.example.org.", + Type: "A", + Records: []Record{{Content: "192.168.0.2"}}, + TTL: 86400, + }, + }, + } + + assert.Equal(t, expected, zone) +} + +func TestClient_UpdateRecords(t *testing.T) { + client := setupTest(t, http.MethodPatch, "/api/v1/servers/localhost/zones/example.org.", http.StatusOK, "zone.json") + client.apiVersion = 1 + + zone := &HostedZone{ + ID: "example.org.", + Name: "example.org.", + URL: "api/v1/servers/localhost/zones/example.org.", + Kind: "Master", + } + + rrSets := RRSets{ + RRSets: []RRSet{{ + Name: "example.org.", + Type: "NS", + ChangeType: "REPLACE", + Records: []Record{{ + Content: "192.0.2.5", + Name: "ns1.example.org.", + TTL: 86400, + Type: "A", + }}, + }}, + } + + err := client.UpdateRecords(context.Background(), zone, rrSets) + require.NoError(t, err) +} + +func TestClient_UpdateRecords_v0(t *testing.T) { + client := setupTest(t, http.MethodPatch, "/servers/localhost/zones/example.org.", http.StatusOK, "zone.json") + client.apiVersion = 0 + + zone := &HostedZone{ + ID: "example.org.", + Name: "example.org.", + URL: "servers/localhost/zones/example.org.", + Kind: "Master", + } + + rrSets := RRSets{ + RRSets: []RRSet{{ + Name: "example.org.", + Type: "NS", + ChangeType: "REPLACE", + Records: []Record{{ + Content: "192.0.2.5", + Name: "ns1.example.org.", + TTL: 86400, + Type: "A", + }}, + }}, + } + + err := client.UpdateRecords(context.Background(), zone, rrSets) + require.NoError(t, err) +} + +func TestClient_Notify(t *testing.T) { + client := setupTest(t, http.MethodPut, "/api/v1/servers/localhost/zones/example.org./notify", http.StatusOK, "") + client.apiVersion = 1 + + zone := &HostedZone{ + ID: "example.org.", + Name: "example.org.", + URL: "api/v1/servers/localhost/zones/example.org.", + Kind: "Master", + } + + err := client.Notify(context.Background(), zone) + require.NoError(t, err) +} + +func TestClient_Notify_v0(t *testing.T) { + client := setupTest(t, http.MethodPut, "/api/v1/servers/localhost/zones/example.org./notify", http.StatusOK, "") + + zone := &HostedZone{ + ID: "example.org.", + Name: "example.org.", + URL: "servers/localhost/zones/example.org.", + Kind: "Master", + } + + err := client.Notify(context.Background(), zone) + require.NoError(t, err) +} + +func TestClient_getAPIVersion(t *testing.T) { + client := setupTest(t, http.MethodGet, "/api", http.StatusOK, "versions.json") + + version, err := client.getAPIVersion(context.Background()) + require.NoError(t, err) + + assert.Equal(t, 4, version) +} diff --git a/providers/dns/pdns/internal/fixtures/error.json b/providers/dns/pdns/internal/fixtures/error.json new file mode 100644 index 00000000..90b28efc --- /dev/null +++ b/providers/dns/pdns/internal/fixtures/error.json @@ -0,0 +1,3 @@ +{ + "error": "A human readable error message" +} diff --git a/providers/dns/pdns/internal/fixtures/versions.json b/providers/dns/pdns/internal/fixtures/versions.json new file mode 100644 index 00000000..4d7694d8 --- /dev/null +++ b/providers/dns/pdns/internal/fixtures/versions.json @@ -0,0 +1,18 @@ +[ + { + "url": "/fooa", + "version": 0 + }, + { + "url": "/foob", + "version": 4 + }, + { + "url": "/fooc", + "version": 2 + }, + { + "url": "/food", + "version": 1 + } +] diff --git a/providers/dns/pdns/internal/fixtures/zone.json b/providers/dns/pdns/internal/fixtures/zone.json new file mode 100644 index 00000000..dabf69b2 --- /dev/null +++ b/providers/dns/pdns/internal/fixtures/zone.json @@ -0,0 +1,69 @@ +{ + "id": "example.org.", + "url": "api/v1/servers/localhost/zones/example.org.", + "name": "example.org.", + "kind": "Master", + "dnssec": false, + "account": "", + "masters": [], + "serial": 2015120401, + "notified_serial": 0, + "last_check": 0, + "soa_edit_api": "", + "soa_edit": "", + "rrsets": [ + { + "comments": [], + "name": "example.org.", + "records": [ + { + "content": "ns2.example.org.", + "disabled": false + }, + { + "content": "ns1.example.org.", + "disabled": false + } + ], + "ttl": 86400, + "type": "NS" + }, + { + "comments": [], + "name": "example.org.", + "type": "SOA", + "ttl": 86400, + "records": [ + { + "disabled": false, + "content": "ns1.example.org. hostmaster.example.org. 2015120401 10800 15 604800 10800" + } + ] + }, + { + "comments": [], + "name": "ns1.example.org.", + "type": "A", + "ttl": 86400, + "records": [ + { + "content": "192.168.0.1", + "disabled": false + } + ] + }, + { + "comments": [], + "name": "www.example.org.", + "type": "A", + "ttl": 86400, + "records": [ + { + "disabled": false, + "content": "192.168.0.2" + } + ] + } + ] +} + diff --git a/providers/dns/pdns/internal/types.go b/providers/dns/pdns/internal/types.go new file mode 100644 index 00000000..df885d9c --- /dev/null +++ b/providers/dns/pdns/internal/types.go @@ -0,0 +1,48 @@ +package internal + +type Record struct { + Content string `json:"content"` + Disabled bool `json:"disabled"` + + // pre-v1 API + Name string `json:"name"` + Type string `json:"type"` + TTL int `json:"ttl,omitempty"` +} + +type HostedZone struct { + ID string `json:"id"` + Name string `json:"name"` + URL string `json:"url"` + Kind string `json:"kind"` + RRSets []RRSet `json:"rrsets"` + + // pre-v1 API + Records []Record `json:"records"` +} + +type RRSet struct { + Name string `json:"name"` + Type string `json:"type"` + Kind string `json:"kind"` + ChangeType string `json:"changetype"` + Records []Record `json:"records,omitempty"` + TTL int `json:"ttl,omitempty"` +} + +type RRSets struct { + RRSets []RRSet `json:"rrsets"` +} + +type apiError struct { + ShortMsg string `json:"error"` +} + +func (a apiError) Error() string { + return a.ShortMsg +} + +type apiVersion struct { + URL string `json:"url"` + Version int `json:"version"` +} diff --git a/providers/dns/pdns/pdns.go b/providers/dns/pdns/pdns.go index 99e3c82e..9885362d 100644 --- a/providers/dns/pdns/pdns.go +++ b/providers/dns/pdns/pdns.go @@ -2,8 +2,7 @@ package pdns import ( - "bytes" - "encoding/json" + "context" "errors" "fmt" "net/http" @@ -13,6 +12,7 @@ import ( "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/log" "github.com/go-acme/lego/v4/platform/config/env" + "github.com/go-acme/lego/v4/providers/dns/pdns/internal" ) // Environment variables names. @@ -55,8 +55,8 @@ func NewDefaultConfig() *Config { // DNSProvider implements the challenge.Provider interface. type DNSProvider struct { - apiVersion int - config *Config + config *Config + client *internal.Client } // NewDNSProvider returns a DNSProvider instance configured for pdns. @@ -94,15 +94,14 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("pdns: API URL missing") } - d := &DNSProvider{config: config} + client := internal.NewClient(config.Host, config.ServerName, config.APIKey) - apiVersion, err := d.getAPIVersion() + err := client.SetAPIVersion(context.Background()) if err != nil { log.Warnf("pdns: failed to get API version %v", err) } - d.apiVersion = apiVersion - return d, nil + return &DNSProvider{config: config, client: client}, nil } // Timeout returns the timeout and interval to use when checking for DNS @@ -115,20 +114,35 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - zone, err := d.getHostedZone(info.EffectiveFQDN) + authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) + if err != nil { + return fmt.Errorf("pdns: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) + } + + ctx := context.Background() + + zone, err := d.client.GetHostedZone(ctx, authZone) if err != nil { return fmt.Errorf("pdns: %w", err) } name := info.EffectiveFQDN - - // pre-v1 API wants non-fqdn - if d.apiVersion == 0 { + if d.client.APIVersion() == 0 { + // pre-v1 API wants non-fqdn name = dns01.UnFqdn(info.EffectiveFQDN) } - rec := Record{ - Content: "\"" + info.Value + "\"", + // Look for existing records. + existingRRSet := findTxtRecord(zone, info.EffectiveFQDN) + + // merge the existing and new records + var records []internal.Record + if existingRRSet != nil { + records = existingRRSet.Records + } + + rec := internal.Record{ + Content: "\"" + info.EffectiveFQDN + "\"", Disabled: false, // pre-v1 API @@ -137,64 +151,51 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { TTL: d.config.TTL, } - // Look for existing records. - existingRrSet, err := d.findTxtRecord(info.EffectiveFQDN) - if err != nil { - return fmt.Errorf("pdns: %w", err) - } - - // merge the existing and new records - var records []Record - if existingRrSet != nil { - records = existingRrSet.Records - } - records = append(records, rec) - - rrsets := rrSets{ - RRSets: []rrSet{ + rrSets := internal.RRSets{ + RRSets: []internal.RRSet{ { Name: name, ChangeType: "REPLACE", Type: "TXT", Kind: "Master", TTL: d.config.TTL, - Records: records, + Records: append(records, rec), }, }, } - body, err := json.Marshal(rrsets) + err = d.client.UpdateRecords(ctx, zone, rrSets) if err != nil { return fmt.Errorf("pdns: %w", err) } - _, err = d.sendRequest(http.MethodPatch, zone.URL, bytes.NewReader(body)) - if err != nil { - return fmt.Errorf("pdns: %w", err) - } - - return d.notify(zone) + return d.client.Notify(ctx, zone) } // CleanUp removes the TXT record matching the specified parameters. func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - zone, err := d.getHostedZone(info.EffectiveFQDN) + authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) + if err != nil { + return fmt.Errorf("pdns: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) + } + + ctx := context.Background() + + zone, err := d.client.GetHostedZone(ctx, authZone) if err != nil { return fmt.Errorf("pdns: %w", err) } - set, err := d.findTxtRecord(info.EffectiveFQDN) - if err != nil { - return fmt.Errorf("pdns: %w", err) - } + set := findTxtRecord(zone, info.EffectiveFQDN) + if set == nil { return fmt.Errorf("pdns: no existing record found for %s", info.EffectiveFQDN) } - rrsets := rrSets{ - RRSets: []rrSet{ + rrSets := internal.RRSets{ + RRSets: []internal.RRSet{ { Name: set.Name, Type: set.Type, @@ -202,15 +203,21 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { }, }, } - body, err := json.Marshal(rrsets) + + err = d.client.UpdateRecords(ctx, zone, rrSets) if err != nil { return fmt.Errorf("pdns: %w", err) } - _, err = d.sendRequest(http.MethodPatch, zone.URL, bytes.NewReader(body)) - if err != nil { - return fmt.Errorf("pdns: %w", err) - } - - return d.notify(zone) + return d.client.Notify(ctx, zone) +} + +func findTxtRecord(zone *internal.HostedZone, fqdn string) *internal.RRSet { + for _, set := range zone.RRSets { + if set.Type == "TXT" && (set.Name == dns01.UnFqdn(fqdn) || set.Name == fqdn) { + return &set + } + } + + return nil } diff --git a/providers/dns/plesk/internal/client.go b/providers/dns/plesk/internal/client.go index cb61d388..9dd9d5ee 100644 --- a/providers/dns/plesk/internal/client.go +++ b/providers/dns/plesk/internal/client.go @@ -2,6 +2,7 @@ package internal import ( "bytes" + "context" "encoding/xml" "errors" "fmt" @@ -9,34 +10,37 @@ import ( "net/http" "net/url" "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) // Client the Plesk API client. type Client struct { - HTTPClient *http.Client + login string + password string + baseURL *url.URL - login string - password string + HTTPClient *http.Client } // NewClient created a new Client. func NewClient(baseURL *url.URL, login string, password string) *Client { return &Client{ - HTTPClient: &http.Client{Timeout: 10 * time.Second}, - baseURL: baseURL, login: login, password: password, + baseURL: baseURL, + HTTPClient: &http.Client{Timeout: 10 * time.Second}, } } // GetSite gets a site. // https://docs.plesk.com/en-US/obsidian/api-rpc/about-xml-api/reference/managing-sites-domains/getting-information-about-sites.66583/ -func (c Client) GetSite(domain string) (int, error) { +func (c Client) GetSite(ctx context.Context, domain string) (int, error) { payload := RequestPacketType{Site: &SiteTypeRequest{Get: SiteGetRequest{Filter: &SiteFilterType{ Name: domain, }}}} - response, err := c.do(payload) + response, err := c.doRequest(ctx, payload) if err != nil { return 0, err } @@ -58,7 +62,7 @@ func (c Client) GetSite(domain string) (int, error) { // AddRecord adds a TXT record. // https://docs.plesk.com/en-US/obsidian/api-rpc/about-xml-api/reference/managing-dns/managing-dns-records/adding-dns-record.34798/ -func (c Client) AddRecord(siteID int, host, value string) (int, error) { +func (c Client) AddRecord(ctx context.Context, siteID int, host, value string) (int, error) { payload := RequestPacketType{DNS: &DNSInputType{AddRec: []AddRecRequest{{ SiteID: siteID, Type: "TXT", @@ -66,7 +70,7 @@ func (c Client) AddRecord(siteID int, host, value string) (int, error) { Value: value, }}}} - response, err := c.do(payload) + response, err := c.doRequest(ctx, payload) if err != nil { return 0, err } @@ -88,12 +92,12 @@ func (c Client) AddRecord(siteID int, host, value string) (int, error) { // DeleteRecord Deletes a TXT record. // https://docs.plesk.com/en-US/obsidian/api-rpc/about-xml-api/reference/managing-dns/managing-dns-records/deleting-dns-records.34864/ -func (c Client) DeleteRecord(recordID int) (int, error) { +func (c Client) DeleteRecord(ctx context.Context, recordID int) (int, error) { payload := RequestPacketType{DNS: &DNSInputType{DelRec: []DelRecRequest{{Filter: DNSSelectionFilterType{ ID: recordID, }}}}} - response, err := c.do(payload) + response, err := c.doRequest(ctx, payload) if err != nil { return 0, err } @@ -113,36 +117,45 @@ func (c Client) DeleteRecord(recordID int) (int, error) { return response.DNS.DelRec[0].Result.ID, nil } -func (c Client) do(payload RequestPacketType) (*ResponsePacketType, error) { +func (c Client) doRequest(ctx context.Context, payload RequestPacketType) (*ResponsePacketType, error) { endpoint := c.baseURL.JoinPath("/enterprise/control/agent.php") - body := &bytes.Buffer{} + body := new(bytes.Buffer) err := xml.NewEncoder(body).Encode(payload) if err != nil { return nil, err } - req, _ := http.NewRequest(http.MethodPost, endpoint.String(), body) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint.String(), body) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + req.Header.Set("Content-Type", "text/xml") + req.Header.Set("Http_auth_login", c.login) req.Header.Set("Http_auth_passwd", c.password) resp, err := c.HTTPClient.Do(req) if err != nil { - return nil, err + return nil, errutils.NewHTTPDoError(req, err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode/100 != 2 { - all, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("API error: %s", string(all)) + return nil, errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errutils.NewReadResponseError(req, resp.StatusCode, err) } var response ResponsePacketType - err = xml.NewDecoder(resp.Body).Decode(&response) + err = xml.Unmarshal(raw, &response) if err != nil { - return nil, err + return nil, errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) } return &response, nil diff --git a/providers/dns/plesk/internal/client_test.go b/providers/dns/plesk/internal/client_test.go index 0f9636e0..5d59a4c8 100644 --- a/providers/dns/plesk/internal/client_test.go +++ b/providers/dns/plesk/internal/client_test.go @@ -1,6 +1,7 @@ package internal import ( + "context" "fmt" "io" "net/http" @@ -16,6 +17,7 @@ import ( func setupTest(t *testing.T, filename string) *Client { t.Helper() + mux := http.NewServeMux() server := httptest.NewServer(mux) t.Cleanup(server.Close) @@ -64,7 +66,7 @@ func setupTest(t *testing.T, filename string) *Client { func TestClient_GetSite(t *testing.T) { client := setupTest(t, "get-site.xml") - siteID, err := client.GetSite("example.com") + siteID, err := client.GetSite(context.Background(), "example.com") require.NoError(t, err) assert.Equal(t, 82, siteID) @@ -73,7 +75,7 @@ func TestClient_GetSite(t *testing.T) { func TestClient_GetSite_error(t *testing.T) { client := setupTest(t, "get-site-error.xml") - siteID, err := client.GetSite("example.com") + siteID, err := client.GetSite(context.Background(), "example.com") require.Error(t, err) assert.Equal(t, 0, siteID) @@ -82,7 +84,7 @@ func TestClient_GetSite_error(t *testing.T) { func TestClient_GetSite_system_error(t *testing.T) { client := setupTest(t, "global-error.xml") - siteID, err := client.GetSite("example.com") + siteID, err := client.GetSite(context.Background(), "example.com") require.Error(t, err) assert.Equal(t, 0, siteID) @@ -91,7 +93,7 @@ func TestClient_GetSite_system_error(t *testing.T) { func TestClient_AddRecord(t *testing.T) { client := setupTest(t, "add-record.xml") - recordID, err := client.AddRecord(123, "_acme-challenge.example.com", "txtTXTtxt") + recordID, err := client.AddRecord(context.Background(), 123, "_acme-challenge.example.com", "txtTXTtxt") require.NoError(t, err) assert.Equal(t, 4537, recordID) @@ -100,8 +102,8 @@ func TestClient_AddRecord(t *testing.T) { func TestClient_AddRecord_error(t *testing.T) { client := setupTest(t, "add-record-error.xml") - recordID, err := client.AddRecord(123, "_acme-challenge.example.com", "txtTXTtxt") - require.Error(t, err) + recordID, err := client.AddRecord(context.Background(), 123, "_acme-challenge.example.com", "txtTXTtxt") + require.ErrorAs(t, err, new(RecResult)) assert.Equal(t, 0, recordID) } @@ -109,8 +111,8 @@ func TestClient_AddRecord_error(t *testing.T) { func TestClient_AddRecord_system_error(t *testing.T) { client := setupTest(t, "global-error.xml") - recordID, err := client.AddRecord(123, "_acme-challenge.example.com", "txtTXTtxt") - require.Error(t, err) + recordID, err := client.AddRecord(context.Background(), 123, "_acme-challenge.example.com", "txtTXTtxt") + require.ErrorAs(t, err, new(*System)) assert.Equal(t, 0, recordID) } @@ -118,7 +120,7 @@ func TestClient_AddRecord_system_error(t *testing.T) { func TestClient_DeleteRecord(t *testing.T) { client := setupTest(t, "delete-record.xml") - recordID, err := client.DeleteRecord(4537) + recordID, err := client.DeleteRecord(context.Background(), 4537) require.NoError(t, err) assert.Equal(t, 4537, recordID) @@ -127,8 +129,8 @@ func TestClient_DeleteRecord(t *testing.T) { func TestClient_DeleteRecord_error(t *testing.T) { client := setupTest(t, "delete-record-error.xml") - recordID, err := client.DeleteRecord(4537) - require.Error(t, err) + recordID, err := client.DeleteRecord(context.Background(), 4537) + require.ErrorAs(t, err, new(RecResult)) assert.Equal(t, 0, recordID) } @@ -136,8 +138,8 @@ func TestClient_DeleteRecord_error(t *testing.T) { func TestClient_DeleteRecord_system_error(t *testing.T) { client := setupTest(t, "global-error.xml") - recordID, err := client.DeleteRecord(4537) - require.Error(t, err) + recordID, err := client.DeleteRecord(context.Background(), 4537) + require.ErrorAs(t, err, new(*System)) assert.Equal(t, 0, recordID) } diff --git a/providers/dns/plesk/plesk.go b/providers/dns/plesk/plesk.go index eb3dacff..aa0fc1dd 100644 --- a/providers/dns/plesk/plesk.go +++ b/providers/dns/plesk/plesk.go @@ -2,6 +2,7 @@ package plesk import ( + "context" "errors" "fmt" "net/http" @@ -122,10 +123,12 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("plesk: could not find zone for domain %q and fqdn %q : %w", domain, info.EffectiveFQDN, err) + return fmt.Errorf("plesk: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } - siteID, err := d.client.GetSite(dns01.UnFqdn(authZone)) + ctx := context.Background() + + siteID, err := d.client.GetSite(ctx, dns01.UnFqdn(authZone)) if err != nil { return fmt.Errorf("plesk: failed to get site: %w", err) } @@ -135,7 +138,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { return fmt.Errorf("nodion: %w", err) } - recordID, err := d.client.AddRecord(siteID, subDomain, info.Value) + recordID, err := d.client.AddRecord(ctx, siteID, subDomain, info.Value) if err != nil { return fmt.Errorf("plesk: failed to add record: %w", err) } @@ -158,7 +161,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return fmt.Errorf("plesk: unknown record ID for '%s' '%s'", info.EffectiveFQDN, token) } - _, err := d.client.DeleteRecord(recordID) + _, err := d.client.DeleteRecord(context.Background(), recordID) if err != nil { return fmt.Errorf("plesk: failed to delete record (%d): %w", recordID, err) } diff --git a/providers/dns/porkbun/porkbun.go b/providers/dns/porkbun/porkbun.go index 8a55b955..86435f37 100644 --- a/providers/dns/porkbun/porkbun.go +++ b/providers/dns/porkbun/porkbun.go @@ -171,7 +171,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { func splitDomain(fqdn string) (string, string, error) { zone, err := dns01.FindZoneByFqdn(fqdn) if err != nil { - return "", "", err + return "", "", fmt.Errorf("could not find zone for FQDN %q: %w", fqdn, err) } subDomain, err := dns01.ExtractSubDomain(fqdn, zone) diff --git a/providers/dns/porkbun/porkbun_test.go b/providers/dns/porkbun/porkbun_test.go index 9bc86adc..cdf022b5 100644 --- a/providers/dns/porkbun/porkbun_test.go +++ b/providers/dns/porkbun/porkbun_test.go @@ -1,7 +1,6 @@ package porkbun import ( - "fmt" "testing" "github.com/go-acme/lego/v4/platform/tester" @@ -144,7 +143,3 @@ func TestLiveCleanUp(t *testing.T) { err = provider.CleanUp(envTest.GetDomain(), "", "123d==") require.NoError(t, err) } - -func TestName(t *testing.T) { - fmt.Println(splitDomain("_acme-challenge.example.com.")) -} diff --git a/providers/dns/rackspace/client.go b/providers/dns/rackspace/client.go deleted file mode 100644 index b689c1e5..00000000 --- a/providers/dns/rackspace/client.go +++ /dev/null @@ -1,205 +0,0 @@ -package rackspace - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "net/http" - - "github.com/go-acme/lego/v4/challenge/dns01" -) - -// APIKeyCredentials API credential. -type APIKeyCredentials struct { - Username string `json:"username"` - APIKey string `json:"apiKey"` -} - -// Auth auth credentials. -type Auth struct { - APIKeyCredentials `json:"RAX-KSKEY:apiKeyCredentials"` -} - -// AuthData Auth data. -type AuthData struct { - Auth `json:"auth"` -} - -// Identity Identity. -type Identity struct { - Access Access `json:"access"` -} - -// Access Access. -type Access struct { - ServiceCatalog []ServiceCatalog `json:"serviceCatalog"` - Token Token `json:"token"` -} - -// Token Token. -type Token struct { - ID string `json:"id"` -} - -// ServiceCatalog ServiceCatalog. -type ServiceCatalog struct { - Endpoints []Endpoint `json:"endpoints"` - Name string `json:"name"` -} - -// Endpoint Endpoint. -type Endpoint struct { - PublicURL string `json:"publicURL"` - TenantID string `json:"tenantId"` -} - -// ZoneSearchResponse represents the response when querying Rackspace DNS zones. -type ZoneSearchResponse struct { - TotalEntries int `json:"totalEntries"` - HostedZones []HostedZone `json:"domains"` -} - -// HostedZone HostedZone. -type HostedZone struct { - ID string `json:"id"` - Name string `json:"name"` -} - -// Records is the list of records sent/received from the DNS API. -type Records struct { - Record []Record `json:"records"` -} - -// Record represents a Rackspace DNS record. -type Record struct { - Name string `json:"name"` - Type string `json:"type"` - Data string `json:"data"` - TTL int `json:"ttl,omitempty"` - ID string `json:"id,omitempty"` -} - -// getHostedZoneID performs a lookup to get the DNS zone which needs -// modifying for a given FQDN. -func (d *DNSProvider) getHostedZoneID(fqdn string) (string, error) { - authZone, err := dns01.FindZoneByFqdn(fqdn) - if err != nil { - return "", err - } - - result, err := d.makeRequest(http.MethodGet, fmt.Sprintf("/domains?name=%s", dns01.UnFqdn(authZone)), nil) - if err != nil { - return "", err - } - - var zoneSearchResponse ZoneSearchResponse - err = json.Unmarshal(result, &zoneSearchResponse) - if err != nil { - return "", err - } - - // If nothing was returned, or for whatever reason more than 1 was returned (the search uses exact match, so should not occur) - if zoneSearchResponse.TotalEntries != 1 { - return "", fmt.Errorf("found %d zones for %s in Rackspace for domain %s", zoneSearchResponse.TotalEntries, authZone, fqdn) - } - - return zoneSearchResponse.HostedZones[0].ID, nil -} - -// findTxtRecord searches a DNS zone for a TXT record with a specific name. -func (d *DNSProvider) findTxtRecord(fqdn string, zoneID string) (*Record, error) { - result, err := d.makeRequest(http.MethodGet, fmt.Sprintf("/domains/%s/records?type=TXT&name=%s", zoneID, dns01.UnFqdn(fqdn)), nil) - if err != nil { - return nil, err - } - - var records Records - err = json.Unmarshal(result, &records) - if err != nil { - return nil, err - } - - switch len(records.Record) { - case 1: - case 0: - return nil, fmt.Errorf("no TXT record found for %s", fqdn) - default: - return nil, fmt.Errorf("more than 1 TXT record found for %s", fqdn) - } - - return &records.Record[0], nil -} - -// makeRequest is a wrapper function used for making DNS API requests. -func (d *DNSProvider) makeRequest(method, uri string, body io.Reader) (json.RawMessage, error) { - url := d.cloudDNSEndpoint + uri - - req, err := http.NewRequest(method, url, body) - if err != nil { - return nil, err - } - - req.Header.Set("X-Auth-Token", d.token) - req.Header.Set("Content-Type", "application/json") - - resp, err := d.config.HTTPClient.Do(req) - if err != nil { - return nil, fmt.Errorf("error querying DNS API: %w", err) - } - - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { - return nil, fmt.Errorf("request failed for %s %s. Response code: %d", method, url, resp.StatusCode) - } - - var r json.RawMessage - err = json.NewDecoder(resp.Body).Decode(&r) - if err != nil { - return nil, fmt.Errorf("JSON decode failed for %s %s. Response code: %d", method, url, resp.StatusCode) - } - - return r, nil -} - -func login(config *Config) (*Identity, error) { - authData := AuthData{ - Auth: Auth{ - APIKeyCredentials: APIKeyCredentials{ - Username: config.APIUser, - APIKey: config.APIKey, - }, - }, - } - - body, err := json.Marshal(authData) - if err != nil { - return nil, err - } - - req, err := http.NewRequest(http.MethodPost, config.BaseURL, bytes.NewReader(body)) - if err != nil { - return nil, err - } - - req.Header.Set("Content-Type", "application/json") - - resp, err := config.HTTPClient.Do(req) - if err != nil { - return nil, fmt.Errorf("error querying Identity API: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("authentication failed: response code: %d", resp.StatusCode) - } - - var identity Identity - err = json.NewDecoder(resp.Body).Decode(&identity) - if err != nil { - return nil, err - } - - return &identity, nil -} diff --git a/providers/dns/rackspace/internal/client.go b/providers/dns/rackspace/internal/client.go new file mode 100644 index 00000000..525556a2 --- /dev/null +++ b/providers/dns/rackspace/internal/client.go @@ -0,0 +1,216 @@ +package internal + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "time" + + "github.com/go-acme/lego/v4/challenge/dns01" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" +) + +type Client struct { + token string + + baseURL *url.URL + HTTPClient *http.Client +} + +func NewClient(endpoint string, token string) (*Client, error) { + baseURL, err := url.Parse(endpoint) + if err != nil { + return nil, err + } + + return &Client{ + token: token, + baseURL: baseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + }, nil +} + +// AddRecord Adds one record to a specified domain. +// https://docs.rackspace.com/docs/cloud-dns/v1/api-reference/records#add-records +func (c *Client) AddRecord(ctx context.Context, zoneID string, record Record) error { + endpoint := c.baseURL.JoinPath("domains", zoneID, "records") + + records := Records{Records: []Record{record}} + + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, records) + if err != nil { + return err + } + + err = c.do(req, nil) + if err != nil { + return err + } + + return nil +} + +// DeleteRecord Deletes a record from the domain. +// https://docs.rackspace.com/docs/cloud-dns/v1/api-reference/records#delete-records +func (c *Client) DeleteRecord(ctx context.Context, zoneID, recordID string) error { + endpoint := c.baseURL.JoinPath("domains", zoneID, "records") + + query := endpoint.Query() + query.Set("id", recordID) + endpoint.RawQuery = query.Encode() + + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) + if err != nil { + return err + } + + err = c.do(req, nil) + if err != nil { + return err + } + + return nil +} + +// GetHostedZoneID performs a lookup to get the DNS zone which needs modifying for a given FQDN. +func (c *Client) GetHostedZoneID(ctx context.Context, fqdn string) (string, error) { + authZone, err := dns01.FindZoneByFqdn(fqdn) + if err != nil { + return "", fmt.Errorf("could not find zone for FQDN %q: %w", fqdn, err) + } + + zoneSearchResponse, err := c.listDomainsByName(ctx, dns01.UnFqdn(authZone)) + if err != nil { + return "", err + } + + // If nothing was returned, or for whatever reason more than 1 was returned (the search uses exact match, so should not occur) + if zoneSearchResponse.TotalEntries != 1 { + return "", fmt.Errorf("found %d zones for %s in Rackspace for domain %s", zoneSearchResponse.TotalEntries, authZone, fqdn) + } + + return zoneSearchResponse.HostedZones[0].ID, nil +} + +// listDomainsByName Filters domains by domain name. +// https://docs.rackspace.com/docs/cloud-dns/v1/api-reference/domains#list-domains-by-name +func (c *Client) listDomainsByName(ctx context.Context, domain string) (*ZoneSearchResponse, error) { + endpoint := c.baseURL.JoinPath("domains") + + query := endpoint.Query() + query.Set("name", domain) + endpoint.RawQuery = query.Encode() + + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, err + } + + var zoneSearchResponse ZoneSearchResponse + err = c.do(req, &zoneSearchResponse) + if err != nil { + return nil, err + } + + return &zoneSearchResponse, nil +} + +// FindTxtRecord searches a DNS zone for a TXT record with a specific name. +func (c *Client) FindTxtRecord(ctx context.Context, fqdn string, zoneID string) (*Record, error) { + records, err := c.searchRecords(ctx, zoneID, dns01.UnFqdn(fqdn), "TXT") + if err != nil { + return nil, err + } + + switch len(records.Records) { + case 1: + case 0: + return nil, fmt.Errorf("no TXT record found for %s", fqdn) + default: + return nil, fmt.Errorf("more than 1 TXT record found for %s", fqdn) + } + + return &records.Records[0], nil +} + +// https://docs.rackspace.com/docs/cloud-dns/v1/api-reference/records#search-records +func (c *Client) searchRecords(ctx context.Context, zoneID, recordName, recordType string) (*Records, error) { + endpoint := c.baseURL.JoinPath("domains", zoneID, "records") + + query := endpoint.Query() + query.Set("type", recordType) + query.Set("name", recordName) + endpoint.RawQuery = query.Encode() + + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, err + } + + var records Records + err = c.do(req, &records) + if err != nil { + return nil, err + } + + return &records, nil +} + +func (c *Client) do(req *http.Request, result any) error { + req.Header.Set("X-Auth-Token", c.token) + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + if result == nil { + return nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return nil +} + +func newJSONRequest[T string | *url.URL](ctx context.Context, method string, endpoint T, payload interface{}) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, fmt.Sprintf("%s", endpoint), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil +} diff --git a/providers/dns/rackspace/internal/client_test.go b/providers/dns/rackspace/internal/client_test.go new file mode 100644 index 00000000..993d34d9 --- /dev/null +++ b/providers/dns/rackspace/internal/client_test.go @@ -0,0 +1,108 @@ +package internal + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setupTest(t *testing.T, pattern string, handler http.HandlerFunc) *Client { + t.Helper() + + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + client, err := NewClient(server.URL, "secret") + require.NoError(t, err) + + client.HTTPClient = server.Client() + + mux.HandleFunc(pattern, handler) + + return client +} + +func writeFixtureHandler(method, filename string) http.HandlerFunc { + return func(rw http.ResponseWriter, req *http.Request) { + if req.Method != method { + http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest) + return + } + + if req.Header.Get("X-Auth-Token") != "secret" { + http.Error(rw, fmt.Sprintf("invalid token: %q", req.Header.Get("X-Auth-Token")), http.StatusUnauthorized) + return + } + + if filename == "" { + return + } + + file, err := os.Open(filepath.Join("fixtures", filename)) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + defer func() { _ = file.Close() }() + + _, _ = io.Copy(rw, file) + } +} + +func TestClient_AddRecord(t *testing.T) { + client := setupTest(t, "/domains/1234/records", writeFixtureHandler(http.MethodPost, "add-records.json")) + + err := client.AddRecord(context.Background(), "1234", Record{}) + require.NoError(t, err) +} + +func TestClient_DeleteRecord(t *testing.T) { + client := setupTest(t, "/domains/1234/records", writeFixtureHandler(http.MethodDelete, "")) + + err := client.DeleteRecord(context.Background(), "1234", "2725233") + require.NoError(t, err) +} + +func TestClient_searchRecords(t *testing.T) { + client := setupTest(t, "/domains/1234/records", writeFixtureHandler(http.MethodGet, "search-records.json")) + + records, err := client.searchRecords(context.Background(), "1234", "2725233", "A") + require.NoError(t, err) + + expected := &Records{ + TotalEntries: 6, + Records: []Record{ + {Name: "ftp.example.com", Type: "A", Data: "192.0.2.8", TTL: 5771, ID: "A-6817754"}, + {Name: "example.com", Type: "A", Data: "192.0.2.17", TTL: 86400, ID: "A-6822994"}, + {Name: "example.com", Type: "NS", Data: "ns.rackspace.com", TTL: 3600, ID: "NS-6251982"}, + {Name: "example.com", Type: "NS", Data: "ns2.rackspace.com", TTL: 3600, ID: "NS-6251983"}, + {Name: "example.com", Type: "MX", Data: "mail.example.com", TTL: 3600, ID: "MX-3151218"}, + {Name: "www.example.com", Type: "CNAME", Data: "example.com", TTL: 5400, ID: "CNAME-9778009"}, + }, + } + + assert.Equal(t, expected, records) +} + +func TestClient_listDomainsByName(t *testing.T) { + client := setupTest(t, "/domains", writeFixtureHandler(http.MethodGet, "list-domains-by-name.json")) + + domains, err := client.listDomainsByName(context.Background(), "1234") + require.NoError(t, err) + + expected := &ZoneSearchResponse{ + TotalEntries: 114, + HostedZones: []HostedZone{{ID: "2725257", Name: "sub1.example.com"}}, + } + + assert.Equal(t, expected, domains) +} diff --git a/providers/dns/rackspace/internal/fixtures/add-records.json b/providers/dns/rackspace/internal/fixtures/add-records.json new file mode 100644 index 00000000..18e507ff --- /dev/null +++ b/providers/dns/rackspace/internal/fixtures/add-records.json @@ -0,0 +1,61 @@ +{ + "totalEntries": 6, + "records": [ + { + "name": "ftp.example.com", + "id": "A-6817754", + "type": "A", + "data": "192.0.2.8", + "updated": "2011-05-19T13:07:08.000+0000", + "ttl": 5771, + "created": "2011-05-18T19:53:09.000+0000" + }, + { + "name": "example.com", + "id": "A-6822994", + "type": "A", + "data": "192.0.2.17", + "updated": "2011-06-24T01:12:52.000+0000", + "ttl": 86400, + "created": "2011-06-24T01:12:52.000+0000" + }, + { + "name": "example.com", + "id": "NS-6251982", + "type": "NS", + "data": "ns.rackspace.com", + "updated": "2011-06-24T01:12:51.000+0000", + "ttl": 3600, + "created": "2011-06-24T01:12:51.000+0000" + }, + { + "name": "example.com", + "id": "NS-6251983", + "type": "NS", + "data": "ns2.rackspace.com", + "updated": "2011-06-24T01:12:51.000+0000", + "ttl": 3600, + "created": "2011-06-24T01:12:51.000+0000" + }, + { + "name": "example.com", + "priority": 5, + "id": "MX-3151218", + "type": "MX", + "data": "mail.example.com", + "updated": "2011-06-24T01:12:53.000+0000", + "ttl": 3600, + "created": "2011-06-24T01:12:53.000+0000" + }, + { + "name": "www.example.com", + "id": "CNAME-9778009", + "type": "CNAME", + "comment": "This is a comment on the CNAME record", + "data": "example.com", + "updated": "2011-06-24T01:12:54.000+0000", + "ttl": 5400, + "created": "2011-06-24T01:12:54.000+0000" + } + ] +} diff --git a/providers/dns/rackspace/internal/fixtures/delete-records_error.json b/providers/dns/rackspace/internal/fixtures/delete-records_error.json new file mode 100644 index 00000000..9fd735fb --- /dev/null +++ b/providers/dns/rackspace/internal/fixtures/delete-records_error.json @@ -0,0 +1,16 @@ +{ + "failedItems" : { + "faults" : [ { + "message" : "Object not Found.", + "code" : 404, + "details" : "Domain ID: 2720150; Record ID: 111111111" + }, { + "message" : "Object not Found.", + "code" : 404, + "details" : "Domain ID: 2720150; Record ID: 222222222" + } ] + }, + "message" : "One or more items could not be deleted.", + "code" : 500, + "details" : "See errors list for details." +} diff --git a/providers/dns/rackspace/internal/fixtures/list-domains-by-name.json b/providers/dns/rackspace/internal/fixtures/list-domains-by-name.json new file mode 100644 index 00000000..d34bd3b0 --- /dev/null +++ b/providers/dns/rackspace/internal/fixtures/list-domains-by-name.json @@ -0,0 +1,13 @@ +{ + "domains": [ + { + "name": "sub1.example.com", + "id": "2725257", + "comment": "1st sample subdomain", + "updated": "2011-06-23T03:09:34.000+0000", + "emailAddress": "sample@rackspace.com", + "created": "2011-06-23T03:09:33.000+0000" + } + ], + "totalEntries": 114 +} diff --git a/providers/dns/rackspace/internal/fixtures/search-records.json b/providers/dns/rackspace/internal/fixtures/search-records.json new file mode 100644 index 00000000..18e507ff --- /dev/null +++ b/providers/dns/rackspace/internal/fixtures/search-records.json @@ -0,0 +1,61 @@ +{ + "totalEntries": 6, + "records": [ + { + "name": "ftp.example.com", + "id": "A-6817754", + "type": "A", + "data": "192.0.2.8", + "updated": "2011-05-19T13:07:08.000+0000", + "ttl": 5771, + "created": "2011-05-18T19:53:09.000+0000" + }, + { + "name": "example.com", + "id": "A-6822994", + "type": "A", + "data": "192.0.2.17", + "updated": "2011-06-24T01:12:52.000+0000", + "ttl": 86400, + "created": "2011-06-24T01:12:52.000+0000" + }, + { + "name": "example.com", + "id": "NS-6251982", + "type": "NS", + "data": "ns.rackspace.com", + "updated": "2011-06-24T01:12:51.000+0000", + "ttl": 3600, + "created": "2011-06-24T01:12:51.000+0000" + }, + { + "name": "example.com", + "id": "NS-6251983", + "type": "NS", + "data": "ns2.rackspace.com", + "updated": "2011-06-24T01:12:51.000+0000", + "ttl": 3600, + "created": "2011-06-24T01:12:51.000+0000" + }, + { + "name": "example.com", + "priority": 5, + "id": "MX-3151218", + "type": "MX", + "data": "mail.example.com", + "updated": "2011-06-24T01:12:53.000+0000", + "ttl": 3600, + "created": "2011-06-24T01:12:53.000+0000" + }, + { + "name": "www.example.com", + "id": "CNAME-9778009", + "type": "CNAME", + "comment": "This is a comment on the CNAME record", + "data": "example.com", + "updated": "2011-06-24T01:12:54.000+0000", + "ttl": 5400, + "created": "2011-06-24T01:12:54.000+0000" + } + ] +} diff --git a/providers/dns/rackspace/internal/fixtures/tokens.json b/providers/dns/rackspace/internal/fixtures/tokens.json new file mode 100644 index 00000000..ca44076f --- /dev/null +++ b/providers/dns/rackspace/internal/fixtures/tokens.json @@ -0,0 +1,87 @@ +{ + "access": { + "token": { + "id": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + "expires": "2014-11-24T22:05:39.115Z", + "tenant": { + "id": "110011", + "name": "110011" + }, + "RAX-AUTH:authenticatedBy": [ + "APIKEY" + ] + }, + "serviceCatalog": [ + { + "name": "cloudDatabases", + "endpoints": [ + { + "publicURL": "https://syd.databases.api.rackspacecloud.com/v1.0/110011", + "region": "SYD", + "tenantId": "110011" + }, + { + "publicURL": "https://dfw.databases.api.rackspacecloud.com/v1.0/110011", + "region": "DFW", + "tenantId": "110011" + }, + { + "publicURL": "https://ord.databases.api.rackspacecloud.com/v1.0/110011", + "region": "ORD", + "tenantId": "110011" + }, + { + "publicURL": "https://iad.databases.api.rackspacecloud.com/v1.0/110011", + "region": "IAD", + "tenantId": "110011" + }, + { + "publicURL": "https://hkg.databases.api.rackspacecloud.com/v1.0/110011", + "region": "HKG", + "tenantId": "110011" + } + ], + "type": "rax:database" + }, + { + "name": "cloudDNS", + "endpoints": [ + { + "publicURL": "https://dns.api.rackspacecloud.com/v1.0/110011", + "tenantId": "110011" + } + ], + "type": "rax:dns" + }, + { + "name": "rackCDN", + "endpoints": [ + { + "internalURL": "https://global.cdn.api.rackspacecloud.com/v1.0/110011", + "publicURL": "https://global.cdn.api.rackspacecloud.com/v1.0/110011", + "tenantId": "110011" + } + ], + "type": "rax:cdn" + } + ], + "user": { + "id": "123456", + "roles": [ + { + "description": "A Role that allows a user access to keystone Service methods", + "id": "6", + "name": "compute:default", + "tenantId": "110011" + }, + { + "description": "User Admin Role.", + "id": "3", + "name": "identity:user-admin" + } + ], + "name": "jsmith", + "RAX-AUTH:defaultRegion": "ORD" + } + } +} diff --git a/providers/dns/rackspace/internal/identity.go b/providers/dns/rackspace/internal/identity.go new file mode 100644 index 00000000..062350df --- /dev/null +++ b/providers/dns/rackspace/internal/identity.go @@ -0,0 +1,74 @@ +package internal + +import ( + "context" + "encoding/json" + "io" + "net/http" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" +) + +// DefaultIdentityURL represents the Identity API endpoint to call. +const DefaultIdentityURL = "https://identity.api.rackspacecloud.com/v2.0/tokens" + +type Identifier struct { + baseURL string + httpClient *http.Client +} + +// NewIdentifier creates a new Identifier. +func NewIdentifier(httpClient *http.Client, baseURL string) *Identifier { + if httpClient == nil { + httpClient = &http.Client{Timeout: 5 * time.Second} + } + + if baseURL == "" { + baseURL = DefaultIdentityURL + } + + return &Identifier{baseURL: baseURL, httpClient: httpClient} +} + +// Login sends an authentication request. +// https://docs.rackspace.com/docs/cloud-dns/v1/getting-started/authenticate +func (a *Identifier) Login(ctx context.Context, apiUser, apiKey string) (*Identity, error) { + authData := AuthData{ + Auth: Auth{ + APIKeyCredentials: APIKeyCredentials{ + Username: apiUser, + APIKey: apiKey, + }, + }, + } + + req, err := newJSONRequest(ctx, http.MethodPost, a.baseURL, authData) + if err != nil { + return nil, err + } + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return nil, errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + var identity Identity + err = json.Unmarshal(raw, &identity) + if err != nil { + return nil, errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return &identity, nil +} diff --git a/providers/dns/rackspace/internal/identity_test.go b/providers/dns/rackspace/internal/identity_test.go new file mode 100644 index 00000000..9ba5abb5 --- /dev/null +++ b/providers/dns/rackspace/internal/identity_test.go @@ -0,0 +1,95 @@ +package internal + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func writeIdentityFixtureHandler(method, filename string) http.HandlerFunc { + return func(rw http.ResponseWriter, req *http.Request) { + if req.Method != method { + http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest) + return + } + + if filename == "" { + return + } + + file, err := os.Open(filepath.Join("fixtures", filename)) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + defer func() { _ = file.Close() }() + + _, _ = io.Copy(rw, file) + } +} + +func TestIdentifier_Login(t *testing.T) { + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + identifier := NewIdentifier(server.Client(), server.URL) + + mux.HandleFunc("/", writeIdentityFixtureHandler(http.MethodPost, "tokens.json")) + + identity, err := identifier.Login(context.Background(), "user", "secret") + require.NoError(t, err) + + expected := &Identity{ + Access: Access{ + Token: Token{ + ID: "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + Expires: "2014-11-24T22:05:39.115Z", + Tenant: Tenant{ID: "110011", Name: "110011"}, + RAXAUTHAuthenticatedBy: []string{"APIKEY"}, + }, + ServiceCatalog: []ServiceCatalog{ + { + Name: "cloudDatabases", + Type: "rax:database", + Endpoints: []Endpoint{ + {PublicURL: "https://syd.databases.api.rackspacecloud.com/v1.0/110011", Region: "SYD", TenantID: "110011", InternalURL: ""}, + {PublicURL: "https://dfw.databases.api.rackspacecloud.com/v1.0/110011", Region: "DFW", TenantID: "110011", InternalURL: ""}, + {PublicURL: "https://ord.databases.api.rackspacecloud.com/v1.0/110011", Region: "ORD", TenantID: "110011", InternalURL: ""}, + {PublicURL: "https://iad.databases.api.rackspacecloud.com/v1.0/110011", Region: "IAD", TenantID: "110011", InternalURL: ""}, + {PublicURL: "https://hkg.databases.api.rackspacecloud.com/v1.0/110011", Region: "HKG", TenantID: "110011", InternalURL: ""}, + }, + }, + { + Name: "cloudDNS", + Type: "rax:dns", + Endpoints: []Endpoint{{PublicURL: "https://dns.api.rackspacecloud.com/v1.0/110011", Region: "", TenantID: "110011", InternalURL: ""}}, + }, + { + Name: "rackCDN", + Type: "rax:cdn", + Endpoints: []Endpoint{{PublicURL: "https://global.cdn.api.rackspacecloud.com/v1.0/110011", Region: "", TenantID: "110011", InternalURL: "https://global.cdn.api.rackspacecloud.com/v1.0/110011"}}, + }, + }, + User: User{ + ID: "123456", + Roles: []Role{ + {Description: "A Role that allows a user access to keystone Service methods", ID: "6", Name: "compute:default", TenantID: "110011"}, + {Description: "User Admin Role.", ID: "3", Name: "identity:user-admin", TenantID: ""}, + }, + Name: "jsmith", + RAXAUTHDefaultRegion: "ORD", + }, + }, + } + + assert.Equal(t, expected, identity) +} diff --git a/providers/dns/rackspace/internal/types.go b/providers/dns/rackspace/internal/types.go new file mode 100644 index 00000000..b34d3a33 --- /dev/null +++ b/providers/dns/rackspace/internal/types.go @@ -0,0 +1,104 @@ +package internal + +// Authentication response. + +// Identity api structure. +type Identity struct { + Access Access `json:"access"` +} + +// Access api structure. +type Access struct { + Token Token `json:"token"` + ServiceCatalog []ServiceCatalog `json:"serviceCatalog"` + User User `json:"user"` +} + +// Token api structure. +type Token struct { + ID string `json:"id"` + Expires string `json:"expires"` + Tenant Tenant `json:"tenant"` + RAXAUTHAuthenticatedBy []string `json:"RAX-AUTH:authenticatedBy"` +} + +// ServiceCatalog service catalog. +type ServiceCatalog struct { + Name string `json:"name"` + Type string `json:"type"` + Endpoints []Endpoint `json:"endpoints"` +} + +type Tenant struct { + ID string `json:"id"` + Name string `json:"name"` +} + +// Endpoint api structure. +type Endpoint struct { + PublicURL string `json:"publicURL"` + Region string `json:"region,omitempty"` + TenantID string `json:"tenantId"` + InternalURL string `json:"internalURL,omitempty"` +} + +type Role struct { + Description string `json:"description"` + ID string `json:"id"` + Name string `json:"name"` + TenantID string `json:"tenantId,omitempty"` +} + +type User struct { + ID string `json:"id"` + Roles []Role `json:"roles"` + Name string `json:"name"` + RAXAUTHDefaultRegion string `json:"RAX-AUTH:defaultRegion"` +} + +// Authentication request. + +// AuthData api structure. +type AuthData struct { + Auth `json:"auth"` +} + +// Auth api structure. +type Auth struct { + APIKeyCredentials `json:"RAX-KSKEY:apiKeyCredentials"` +} + +// APIKeyCredentials api structure. +type APIKeyCredentials struct { + Username string `json:"username"` + APIKey string `json:"apiKey"` +} + +// API responses. + +// ZoneSearchResponse represents the response when querying Rackspace DNS zones. +type ZoneSearchResponse struct { + TotalEntries int `json:"totalEntries"` + HostedZones []HostedZone `json:"domains"` +} + +// HostedZone api structure. +type HostedZone struct { + ID string `json:"id"` + Name string `json:"name"` +} + +// Records is the list of records sent/received from the DNS API. +type Records struct { + TotalEntries int `json:"totalEntries,omitempty"` + Records []Record `json:"records,omitempty"` +} + +// Record represents a Rackspace DNS record. +type Record struct { + Name string `json:"name"` + Type string `json:"type"` + Data string `json:"data"` + TTL int `json:"ttl,omitempty"` + ID string `json:"id,omitempty"` +} diff --git a/providers/dns/rackspace/rackspace.go b/providers/dns/rackspace/rackspace.go index 5b79c2a9..c877de3b 100644 --- a/providers/dns/rackspace/rackspace.go +++ b/providers/dns/rackspace/rackspace.go @@ -2,8 +2,7 @@ package rackspace import ( - "bytes" - "encoding/json" + "context" "errors" "fmt" "net/http" @@ -11,11 +10,9 @@ import ( "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/platform/config/env" + "github.com/go-acme/lego/v4/providers/dns/rackspace/internal" ) -// defaultBaseURL represents the Identity API endpoint to call. -const defaultBaseURL = "https://identity.api.rackspacecloud.com/v2.0/tokens" - // Environment variables names. const ( envNamespace = "RACKSPACE_" @@ -43,7 +40,7 @@ type Config struct { // NewDefaultConfig returns a default configuration for the DNSProvider. func NewDefaultConfig() *Config { return &Config{ - BaseURL: defaultBaseURL, + BaseURL: internal.DefaultIdentityURL, TTL: env.GetOrDefaultInt(EnvTTL, 300), PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, dns01.DefaultPropagationTimeout), PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, dns01.DefaultPollingInterval), @@ -55,7 +52,9 @@ func NewDefaultConfig() *Config { // DNSProvider implements the challenge.Provider interface. type DNSProvider struct { - config *Config + config *Config + client *internal.Client + token string cloudDNSEndpoint string } @@ -87,7 +86,9 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("rackspace: credentials missing") } - identity, err := login(config) + identifier := internal.NewIdentifier(config.HTTPClient, config.BaseURL) + + identity, err := identifier.Login(context.Background(), config.APIUser, config.APIKey) if err != nil { return nil, fmt.Errorf("rackspace: %w", err) } @@ -105,8 +106,18 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("rackspace: failed to populate DNS endpoint, check Rackspace API for changes") } + client, err := internal.NewClient(dnsEndpoint, identity.Access.Token.ID) + if err != nil { + return nil, fmt.Errorf("rackspace: %w", err) + } + + if config.HTTPClient != nil { + client.HTTPClient = config.HTTPClient + } + return &DNSProvider{ config: config, + client: client, token: identity.Access.Token.ID, cloudDNSEndpoint: dnsEndpoint, }, nil @@ -116,29 +127,25 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - zoneID, err := d.getHostedZoneID(info.EffectiveFQDN) + ctx := context.Background() + + zoneID, err := d.client.GetHostedZoneID(ctx, info.EffectiveFQDN) if err != nil { return fmt.Errorf("rackspace: %w", err) } - rec := Records{ - Record: []Record{{ - Name: dns01.UnFqdn(info.EffectiveFQDN), - Type: "TXT", - Data: info.Value, - TTL: d.config.TTL, - }}, + record := internal.Record{ + Name: dns01.UnFqdn(info.EffectiveFQDN), + Type: "TXT", + Data: info.Value, + TTL: d.config.TTL, } - body, err := json.Marshal(rec) + err = d.client.AddRecord(ctx, zoneID, record) if err != nil { return fmt.Errorf("rackspace: %w", err) } - _, err = d.makeRequest(http.MethodPost, fmt.Sprintf("/domains/%s/records", zoneID), bytes.NewReader(body)) - if err != nil { - return fmt.Errorf("rackspace: %w", err) - } return nil } @@ -146,20 +153,23 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - zoneID, err := d.getHostedZoneID(info.EffectiveFQDN) + ctx := context.Background() + + zoneID, err := d.client.GetHostedZoneID(ctx, info.EffectiveFQDN) if err != nil { return fmt.Errorf("rackspace: %w", err) } - record, err := d.findTxtRecord(info.EffectiveFQDN, zoneID) + record, err := d.client.FindTxtRecord(ctx, info.EffectiveFQDN, zoneID) if err != nil { return fmt.Errorf("rackspace: %w", err) } - _, err = d.makeRequest(http.MethodDelete, fmt.Sprintf("/domains/%s/records?id=%s", zoneID, record.ID), nil) + err = d.client.DeleteRecord(ctx, zoneID, record.ID) if err != nil { return fmt.Errorf("rackspace: %w", err) } + return nil } diff --git a/providers/dns/rackspace/rackspace_test.go b/providers/dns/rackspace/rackspace_test.go index 8cea7c44..1e120e09 100644 --- a/providers/dns/rackspace/rackspace_test.go +++ b/providers/dns/rackspace/rackspace_test.go @@ -1,6 +1,7 @@ package rackspace import ( + "bytes" "fmt" "io" "net/http" @@ -124,14 +125,14 @@ func identityHandler(dnsEndpoint string) http.Handler { return } - if string(reqBody) != `{"auth":{"RAX-KSKEY:apiKeyCredentials":{"username":"testUser","apiKey":"testKey"}}}` { - w.WriteHeader(http.StatusBadRequest) + if string(bytes.TrimSpace(reqBody)) != `{"auth":{"RAX-KSKEY:apiKeyCredentials":{"username":"testUser","apiKey":"testKey"}}}` { + http.Error(w, fmt.Sprintf("invalid body: %s", string(reqBody)), http.StatusBadRequest) return } resp := strings.Replace(identityResponseMock, "https://dns.api.rackspacecloud.com/v1.0/123456", dnsEndpoint, 1) w.WriteHeader(http.StatusOK) - fmt.Fprint(w, resp) + _, _ = fmt.Fprint(w, resp) }) } @@ -142,7 +143,7 @@ func dnsHandler() *http.ServeMux { mux.HandleFunc("/123456/domains", func(w http.ResponseWriter, r *http.Request) { if r.URL.Query().Get("name") == "example.com" { w.WriteHeader(http.StatusOK) - fmt.Fprint(w, zoneDetailsMock) + _, _ = fmt.Fprint(w, zoneDetailsMock) return } w.WriteHeader(http.StatusBadRequest) @@ -158,27 +159,30 @@ func dnsHandler() *http.ServeMux { return } - if string(reqBody) != `{"records":[{"name":"_acme-challenge.example.com","type":"TXT","data":"pW9ZKG0xz_PCriK-nCMOjADy9eJcgGWIzkkj2fN4uZM","ttl":300}]}` { - w.WriteHeader(http.StatusBadRequest) + if string(bytes.TrimSpace(reqBody)) != `{"records":[{"name":"_acme-challenge.example.com","type":"TXT","data":"pW9ZKG0xz_PCriK-nCMOjADy9eJcgGWIzkkj2fN4uZM","ttl":300}]}` { + http.Error(w, fmt.Sprintf("invalid body: %s", string(reqBody)), http.StatusBadRequest) return } w.WriteHeader(http.StatusAccepted) - fmt.Fprint(w, recordResponseMock) - // Used by `findTxtRecord()` finding `record.ID` "?type=TXT&name=_acme-challenge.example.com" + _, _ = fmt.Fprint(w, recordResponseMock) + + // Used by `findTxtRecord()` finding `record.ID` "?type=TXT&name=_acme-challenge.example.com" case http.MethodGet: if r.URL.Query().Get("type") == "TXT" && r.URL.Query().Get("name") == "_acme-challenge.example.com" { w.WriteHeader(http.StatusOK) - fmt.Fprint(w, recordDetailsMock) + _, _ = fmt.Fprint(w, recordDetailsMock) return } + w.WriteHeader(http.StatusBadRequest) return - // Used by `CleanUp()` deleting the TXT record "?id=445566" + + // Used by `CleanUp()` deleting the TXT record "?id=445566" case http.MethodDelete: if r.URL.Query().Get("id") == "TXT-654321" { w.WriteHeader(http.StatusOK) - fmt.Fprint(w, recordDeleteMock) + _, _ = fmt.Fprint(w, recordDeleteMock) return } w.WriteHeader(http.StatusBadRequest) @@ -186,8 +190,7 @@ func dnsHandler() *http.ServeMux { }) mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNotFound) - fmt.Printf("Not Found for Request: (%+v)\n\n", r) + http.Error(w, fmt.Sprintf("Not Found for Request: (%+v)", r), http.StatusNotFound) }) return mux diff --git a/providers/dns/regru/internal/client.go b/providers/dns/regru/internal/client.go index cac3a815..b4b81dc0 100644 --- a/providers/dns/regru/internal/client.go +++ b/providers/dns/regru/internal/client.go @@ -1,11 +1,15 @@ package internal import ( + "context" "encoding/json" "fmt" "io" "net/http" "net/url" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) const defaultBaseURL = "https://api.reg.ru/api/regru2/" @@ -15,36 +19,36 @@ type Client struct { username string password string - BaseURL string + baseURL *url.URL HTTPClient *http.Client } // NewClient Creates a reg.ru client. func NewClient(username, password string) *Client { + baseURL, _ := url.Parse(defaultBaseURL) + return &Client{ username: username, password: password, - BaseURL: defaultBaseURL, - HTTPClient: http.DefaultClient, + baseURL: baseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, } } // RemoveTxtRecord removes a TXT record. // https://www.reg.ru/support/help/api2#zone_remove_record -func (c Client) RemoveTxtRecord(domain, subDomain, content string) error { +func (c Client) RemoveTxtRecord(ctx context.Context, domain, subDomain, content string) error { request := RemoveRecordRequest{ - Username: c.username, - Password: c.password, - Domains: []Domain{ - {DName: domain}, - }, + Username: c.username, + Password: c.password, + Domains: []Domain{{DName: domain}}, SubDomain: subDomain, Content: content, RecordType: "TXT", OutputContentType: "plain", } - resp, err := c.do(request, "zone", "remove_record") + resp, err := c.doRequest(ctx, request, "zone", "remove_record") if err != nil { return err } @@ -54,19 +58,17 @@ func (c Client) RemoveTxtRecord(domain, subDomain, content string) error { // AddTXTRecord adds a TXT record. // https://www.reg.ru/support/help/api2#zone_add_txt -func (c Client) AddTXTRecord(domain, subDomain, content string) error { +func (c Client) AddTXTRecord(ctx context.Context, domain, subDomain, content string) error { request := AddTxtRequest{ - Username: c.username, - Password: c.password, - Domains: []Domain{ - {DName: domain}, - }, + Username: c.username, + Password: c.password, + Domains: []Domain{{DName: domain}}, SubDomain: subDomain, Text: content, OutputContentType: "plain", } - resp, err := c.do(request, "zone", "add_txt") + resp, err := c.doRequest(ctx, request, "zone", "add_txt") if err != nil { return err } @@ -74,15 +76,12 @@ func (c Client) AddTXTRecord(domain, subDomain, content string) error { return resp.HasError() } -func (c Client) do(request interface{}, fragments ...string) (*APIResponse, error) { - endpoint, err := c.createEndpoint(fragments...) - if err != nil { - return nil, err - } +func (c Client) doRequest(ctx context.Context, request any, fragments ...string) (*APIResponse, error) { + endpoint := c.baseURL.JoinPath(fragments...) inputData, err := json.Marshal(request) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create input data: %w", err) } query := endpoint.Query() @@ -90,47 +89,44 @@ func (c Client) do(request interface{}, fragments ...string) (*APIResponse, erro query.Add("input_format", "json") endpoint.RawQuery = query.Encode() - resp, err := http.Get(endpoint.String()) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), http.NoBody) if err != nil { - return nil, err + return nil, fmt.Errorf("unable to create request: %w", err) + } + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, errutils.NewHTTPDoError(req, err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode/100 != 2 { - all, errB := io.ReadAll(resp.Body) - if errB != nil { - return nil, fmt.Errorf("API error, status code: %d", resp.StatusCode) - } - - var apiResp APIResponse - errB = json.Unmarshal(all, &apiResp) - if errB != nil { - return nil, fmt.Errorf("API error, status code: %d, %s", resp.StatusCode, string(all)) - } - - return nil, fmt.Errorf("%w, status code: %d", apiResp, resp.StatusCode) + return nil, parseError(req, resp) } - all, err := io.ReadAll(resp.Body) + raw, err := io.ReadAll(resp.Body) if err != nil { - return nil, err + return nil, errutils.NewReadResponseError(req, resp.StatusCode, err) } var apiResp APIResponse - err = json.Unmarshal(all, &apiResp) + err = json.Unmarshal(raw, &apiResp) if err != nil { - return nil, err + return nil, errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) } return &apiResp, nil } -func (c Client) createEndpoint(fragments ...string) (*url.URL, error) { - baseURL, err := url.Parse(c.BaseURL) +func parseError(req *http.Request, resp *http.Response) error { + raw, _ := io.ReadAll(resp.Body) + + var errAPI APIResponse + err := json.Unmarshal(raw, &errAPI) if err != nil { - return nil, err + return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw) } - return baseURL.JoinPath(fragments...), nil + return fmt.Errorf("status code: %d, %w", resp.StatusCode, errAPI) } diff --git a/providers/dns/regru/internal/client_test.go b/providers/dns/regru/internal/client_test.go index 0b7e7771..a599a2c3 100644 --- a/providers/dns/regru/internal/client_test.go +++ b/providers/dns/regru/internal/client_test.go @@ -1,7 +1,9 @@ package internal import ( + "context" "net/http" + "net/url" "os" "testing" "time" @@ -22,7 +24,7 @@ func TestRemoveRecord(t *testing.T) { client := NewClient(officialTestUser, officialTestPassword) client.HTTPClient = &http.Client{Timeout: 30 * time.Second} - err := client.RemoveTxtRecord("test.ru", "_acme-challenge", "txttxttxt") + err := client.RemoveTxtRecord(context.Background(), "test.ru", "_acme-challenge", "txttxttxt") require.NoError(t, err) } @@ -65,9 +67,9 @@ func TestRemoveRecord_errors(t *testing.T) { client := NewClient(test.username, test.username) client.HTTPClient = &http.Client{Timeout: 30 * time.Second} - client.BaseURL = test.baseURL + client.baseURL, _ = url.Parse(test.baseURL) - err := client.RemoveTxtRecord(test.domain, "_acme-challenge", "txttxttxt") + err := client.RemoveTxtRecord(context.Background(), test.domain, "_acme-challenge", "txttxttxt") require.EqualError(t, err, test.expected) }) } @@ -80,7 +82,7 @@ func TestAddTXTRecord(t *testing.T) { client := NewClient(officialTestUser, officialTestPassword) client.HTTPClient = &http.Client{Timeout: 30 * time.Second} - err := client.AddTXTRecord("test.ru", "_acme-challenge", "txttxttxt") + err := client.AddTXTRecord(context.Background(), "test.ru", "_acme-challenge", "txttxttxt") require.NoError(t, err) } @@ -123,9 +125,9 @@ func TestAddTXTRecord_errors(t *testing.T) { client := NewClient(test.username, test.username) client.HTTPClient = &http.Client{Timeout: 30 * time.Second} - client.BaseURL = test.baseURL + client.baseURL, _ = url.Parse(test.baseURL) - err := client.AddTXTRecord(test.domain, "_acme-challenge", "txttxttxt") + err := client.AddTXTRecord(context.Background(), test.domain, "_acme-challenge", "txttxttxt") require.EqualError(t, err, test.expected) }) } diff --git a/providers/dns/regru/internal/model.go b/providers/dns/regru/internal/types.go similarity index 100% rename from providers/dns/regru/internal/model.go rename to providers/dns/regru/internal/types.go diff --git a/providers/dns/regru/regru.go b/providers/dns/regru/regru.go index 5c5e5c95..b9ab272f 100644 --- a/providers/dns/regru/regru.go +++ b/providers/dns/regru/regru.go @@ -2,6 +2,7 @@ package regru import ( + "context" "errors" "fmt" "net/http" @@ -101,7 +102,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("regru: could not find zone for domain %q and fqdn %q : %w", domain, info.EffectiveFQDN, err) + return fmt.Errorf("regru: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) @@ -109,7 +110,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { return fmt.Errorf("regru: %w", err) } - err = d.client.AddTXTRecord(dns01.UnFqdn(authZone), subDomain, info.Value) + err = d.client.AddTXTRecord(context.Background(), dns01.UnFqdn(authZone), subDomain, info.Value) if err != nil { return fmt.Errorf("regru: failed to create TXT records [domain: %s, sub domain: %s]: %w", dns01.UnFqdn(authZone), subDomain, err) @@ -124,7 +125,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("regru: could not find zone for domain %q and fqdn %q : %w", domain, info.EffectiveFQDN, err) + return fmt.Errorf("regru: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) @@ -132,7 +133,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return fmt.Errorf("regru: %w", err) } - err = d.client.RemoveTxtRecord(dns01.UnFqdn(authZone), subDomain, info.Value) + err = d.client.RemoveTxtRecord(context.Background(), dns01.UnFqdn(authZone), subDomain, info.Value) if err != nil { return fmt.Errorf("regru: failed to remove TXT records [domain: %s, sub domain: %s]: %w", dns01.UnFqdn(authZone), subDomain, err) diff --git a/providers/dns/rfc2136/rfc2136.go b/providers/dns/rfc2136/rfc2136.go index 981302e1..bcff990d 100644 --- a/providers/dns/rfc2136/rfc2136.go +++ b/providers/dns/rfc2136/rfc2136.go @@ -179,7 +179,7 @@ func (d *DNSProvider) changeRecord(action, fqdn, value string, ttl int) error { c.SingleInflight = true // TSIG authentication / msg signing - if len(d.config.TSIGKey) > 0 && len(d.config.TSIGSecret) > 0 { + if d.config.TSIGKey != "" && d.config.TSIGSecret != "" { key := strings.ToLower(dns.Fqdn(d.config.TSIGKey)) alg := dns.Fqdn(d.config.TSIGAlgorithm) m.SetTsig(key, alg, 300, time.Now().Unix()) diff --git a/providers/dns/rimuhosting/rimuhosting.go b/providers/dns/rimuhosting/rimuhosting.go index 52c6f769..09b31d4f 100644 --- a/providers/dns/rimuhosting/rimuhosting.go +++ b/providers/dns/rimuhosting/rimuhosting.go @@ -2,6 +2,7 @@ package rimuhosting import ( + "context" "errors" "fmt" "net/http" @@ -96,20 +97,22 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - records, err := d.client.FindTXTRecords(dns01.UnFqdn(info.EffectiveFQDN)) + ctx := context.Background() + + records, err := d.client.FindTXTRecords(ctx, dns01.UnFqdn(info.EffectiveFQDN)) if err != nil { return fmt.Errorf("rimuhosting: failed to find record(s) for %s: %w", domain, err) } actions := []rimuhosting.ActionParameter{ - rimuhosting.AddRecord(dns01.UnFqdn(info.EffectiveFQDN), info.Value, d.config.TTL), + rimuhosting.NewAddRecordAction(dns01.UnFqdn(info.EffectiveFQDN), info.Value, d.config.TTL), } for _, record := range records { - actions = append(actions, rimuhosting.AddRecord(record.Name, record.Content, d.config.TTL)) + actions = append(actions, rimuhosting.NewAddRecordAction(record.Name, record.Content, d.config.TTL)) } - _, err = d.client.DoActions(actions...) + _, err = d.client.DoActions(ctx, actions...) if err != nil { return fmt.Errorf("rimuhosting: failed to add record(s) for %s: %w", domain, err) } @@ -121,9 +124,9 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - action := rimuhosting.DeleteRecord(dns01.UnFqdn(info.EffectiveFQDN), info.Value) + action := rimuhosting.NewDeleteRecordAction(dns01.UnFqdn(info.EffectiveFQDN), info.Value) - _, err := d.client.DoActions(action) + _, err := d.client.DoActions(context.Background(), action) if err != nil { return fmt.Errorf("rimuhosting: failed to delete record for %s: %w", domain, err) } diff --git a/providers/dns/route53/route53.go b/providers/dns/route53/route53.go index 9ddeece2..df504cbd 100644 --- a/providers/dns/route53/route53.go +++ b/providers/dns/route53/route53.go @@ -278,7 +278,7 @@ func (d *DNSProvider) getHostedZoneID(fqdn string) (string, error) { authZone, err := dns01.FindZoneByFqdn(fqdn) if err != nil { - return "", err + return "", fmt.Errorf("could not find zone for FQDN %q: %w", fqdn, err) } // .DNSName should not have a trailing dot diff --git a/providers/dns/safedns/internal/client.go b/providers/dns/safedns/internal/client.go index af416ece..254ec097 100644 --- a/providers/dns/safedns/internal/client.go +++ b/providers/dns/safedns/internal/client.go @@ -2,8 +2,8 @@ package internal import ( "bytes" + "context" "encoding/json" - "errors" "fmt" "io" "net/http" @@ -12,10 +12,13 @@ import ( "time" "github.com/go-acme/lego/v4/challenge/dns01" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) const defaultBaseURL = "https://api.ukfast.io/safedns/v1" +const authorizationHeader = "Authorization" + // Client the UKFast SafeDNS client. type Client struct { authToken string @@ -27,6 +30,7 @@ type Client struct { // NewClient Creates a new Client. func NewClient(authToken string) *Client { baseURL, _ := url.Parse(defaultBaseURL) + return &Client{ authToken: authToken, baseURL: baseURL, @@ -35,93 +39,103 @@ func NewClient(authToken string) *Client { } // AddRecord adds a DNS record. -func (c *Client) AddRecord(zone string, record Record) (*AddRecordResponse, error) { - body, err := json.Marshal(record) - if err != nil { - return nil, err - } - +func (c *Client) AddRecord(ctx context.Context, zone string, record Record) (*AddRecordResponse, error) { endpoint := c.baseURL.JoinPath("zones", dns01.UnFqdn(zone), "records") - req, err := c.newRequest(http.MethodPost, endpoint.String(), bytes.NewReader(body)) + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record) if err != nil { return nil, err } - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, err - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode >= http.StatusBadRequest { - return nil, readError(req, resp) - } - - content, err := io.ReadAll(resp.Body) - if err != nil { - return nil, errors.New(toUnreadableBodyMessage(req, content)) - } - respData := &AddRecordResponse{} - err = json.Unmarshal(content, respData) + err = c.do(req, respData) if err != nil { - return nil, fmt.Errorf("%w: %s", err, toUnreadableBodyMessage(req, content)) + return nil, fmt.Errorf("remove record: %w", err) } return respData, nil } // RemoveRecord removes a DNS record. -func (c *Client) RemoveRecord(zone string, recordID int) error { +func (c *Client) RemoveRecord(ctx context.Context, zone string, recordID int) error { endpoint := c.baseURL.JoinPath("zones", dns01.UnFqdn(zone), "records", strconv.Itoa(recordID)) - req, err := c.newRequest(http.MethodDelete, endpoint.String(), nil) + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) if err != nil { return err } - resp, err := c.HTTPClient.Do(req) + err = c.do(req, nil) if err != nil { - return err - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode >= http.StatusBadRequest { - return readError(req, resp) + return fmt.Errorf("remove record: %w", err) } return nil } -func (c *Client) newRequest(method, endpoint string, body io.Reader) (*http.Request, error) { - req, err := http.NewRequest(method, endpoint, body) +func (c *Client) do(req *http.Request, result any) error { + req.Header.Set(authorizationHeader, c.authToken) + + resp, err := c.HTTPClient.Do(req) if err != nil { - return nil, err + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode/100 != 2 { + return parseError(req, resp) + } + + if result == nil { + return nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return nil +} + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) } - req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") - req.Header.Set("Authorization", c.authToken) + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } return req, nil } -func readError(req *http.Request, resp *http.Response) error { - content, err := io.ReadAll(resp.Body) +func parseError(req *http.Request, resp *http.Response) error { + raw, _ := io.ReadAll(resp.Body) + + var errAPI APIError + err := json.Unmarshal(raw, &errAPI) if err != nil { - return errors.New(toUnreadableBodyMessage(req, content)) + return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw) } - var errInfo APIError - err = json.Unmarshal(content, &errInfo) - if err != nil { - return fmt.Errorf("unmarshaling error: %w: %s", err, toUnreadableBodyMessage(req, content)) - } - - return errInfo -} - -func toUnreadableBodyMessage(req *http.Request, rawBody []byte) string { - return fmt.Sprintf("the request %s received a response with an invalid format: %q", req.URL, string(rawBody)) + return fmt.Errorf("[status code: %d] %w", resp.StatusCode, errAPI) } diff --git a/providers/dns/safedns/internal/client_test.go b/providers/dns/safedns/internal/client_test.go index c6c493d7..6709277c 100644 --- a/providers/dns/safedns/internal/client_test.go +++ b/providers/dns/safedns/internal/client_test.go @@ -1,11 +1,13 @@ package internal import ( + "context" "fmt" "io" "net/http" "net/http/httptest" "net/url" + "strings" "testing" "github.com/go-acme/lego/v4/challenge/dns01" @@ -35,7 +37,7 @@ func TestClient_AddRecord(t *testing.T) { return } - if req.Header.Get("Authorization") != "secret" { + if req.Header.Get(authorizationHeader) != "secret" { http.Error(rw, `{"message":"Unauthenticated"}`, http.StatusUnauthorized) return } @@ -47,7 +49,7 @@ func TestClient_AddRecord(t *testing.T) { } expectedReqBody := `{"name":"_acme-challenge.example.com","type":"TXT","content":"\"w6uP8Tcg6K2QR905Rms8iXTlksL6OD1KOWBxTK7wxPI\"","ttl":120}` - if string(reqBody) != expectedReqBody { + if strings.TrimSpace(string(reqBody)) != expectedReqBody { http.Error(rw, `{"message":"invalid request"}`, http.StatusBadRequest) return } @@ -76,7 +78,7 @@ func TestClient_AddRecord(t *testing.T) { TTL: dns01.DefaultTTL, } - response, err := client.AddRecord("example.com", record) + response, err := client.AddRecord(context.Background(), "example.com", record) require.NoError(t, err) expected := &AddRecordResponse{ @@ -104,7 +106,7 @@ func TestClient_RemoveRecord(t *testing.T) { return } - if req.Header.Get("Authorization") != "secret" { + if req.Header.Get(authorizationHeader) != "secret" { http.Error(rw, `{"message":"Unauthenticated"}`, http.StatusUnauthorized) return } @@ -112,6 +114,6 @@ func TestClient_RemoveRecord(t *testing.T) { rw.WriteHeader(http.StatusNoContent) }) - err := client.RemoveRecord("example.com", 1234567) + err := client.RemoveRecord(context.Background(), "example.com", 1234567) require.NoError(t, err) } diff --git a/providers/dns/safedns/safedns.go b/providers/dns/safedns/safedns.go index 9cb63ad7..8285f3a0 100644 --- a/providers/dns/safedns/safedns.go +++ b/providers/dns/safedns/safedns.go @@ -2,6 +2,7 @@ package safedns import ( + "context" "errors" "fmt" "net/http" @@ -104,7 +105,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { zone, err := dns01.FindZoneByFqdn(dns01.ToFqdn(info.EffectiveFQDN)) if err != nil { - return fmt.Errorf("safedns: could not determine zone for domain: %q: %w", info.EffectiveFQDN, err) + return fmt.Errorf("safedns: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } record := internal.Record{ @@ -114,7 +115,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { TTL: d.config.TTL, } - resp, err := d.client.AddRecord(zone, record) + resp, err := d.client.AddRecord(context.Background(), zone, record) if err != nil { return fmt.Errorf("safedns: %w", err) } @@ -132,7 +133,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("safedns: %w", err) + return fmt.Errorf("safedns: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } d.recordIDsMu.Lock() @@ -142,7 +143,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return fmt.Errorf("safedns: unknown record ID for '%s'", info.EffectiveFQDN) } - err = d.client.RemoveRecord(authZone, recordID) + err = d.client.RemoveRecord(context.Background(), authZone, recordID) if err != nil { return fmt.Errorf("safedns: %w", err) } diff --git a/providers/dns/sakuracloud/client.go b/providers/dns/sakuracloud/wrapper.go similarity index 96% rename from providers/dns/sakuracloud/client.go rename to providers/dns/sakuracloud/wrapper.go index ed124dd6..2bf8ac9f 100644 --- a/providers/dns/sakuracloud/client.go +++ b/providers/dns/sakuracloud/wrapper.go @@ -82,7 +82,7 @@ func (d *DNSProvider) cleanupTXTRecord(fqdn, value string) error { func (d *DNSProvider) getHostedZone(domain string) (*iaas.DNS, error) { authZone, err := dns01.FindZoneByFqdn(domain) if err != nil { - return nil, err + return nil, fmt.Errorf("could not find zone for FQDN %q: %w", domain, err) } zoneName := dns01.UnFqdn(authZone) diff --git a/providers/dns/sakuracloud/client_test.go b/providers/dns/sakuracloud/wrapper_test.go similarity index 100% rename from providers/dns/sakuracloud/client_test.go rename to providers/dns/sakuracloud/wrapper_test.go diff --git a/providers/dns/selectel/selectel.go b/providers/dns/selectel/selectel.go index e9378366..933115c7 100644 --- a/providers/dns/selectel/selectel.go +++ b/providers/dns/selectel/selectel.go @@ -4,9 +4,11 @@ package selectel import ( + "context" "errors" "fmt" "net/http" + "net/url" "time" "github.com/go-acme/lego/v4/challenge/dns01" @@ -87,8 +89,15 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { } client := selectel.NewClient(config.Token) - client.BaseURL = config.BaseURL - client.HTTPClient = config.HTTPClient + if config.HTTPClient != nil { + client.HTTPClient = config.HTTPClient + } + + var err error + client.BaseURL, err = url.Parse(config.BaseURL) + if err != nil { + return nil, fmt.Errorf("selectel: %w", err) + } return &DNSProvider{config: config, client: client}, nil } @@ -103,8 +112,10 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) + ctx := context.Background() + // TODO(ldez) replace domain by FQDN to follow CNAME. - domainObj, err := d.client.GetDomainByName(domain) + domainObj, err := d.client.GetDomainByName(ctx, domain) if err != nil { return fmt.Errorf("selectel: %w", err) } @@ -115,7 +126,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { Name: info.EffectiveFQDN, Content: info.Value, } - _, err = d.client.AddRecord(domainObj.ID, txtRecord) + _, err = d.client.AddRecord(ctx, domainObj.ID, txtRecord) if err != nil { return fmt.Errorf("selectel: %w", err) } @@ -129,13 +140,15 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { recordName := dns01.UnFqdn(info.EffectiveFQDN) + ctx := context.Background() + // TODO(ldez) replace domain by FQDN to follow CNAME. - domainObj, err := d.client.GetDomainByName(domain) + domainObj, err := d.client.GetDomainByName(ctx, domain) if err != nil { return fmt.Errorf("selectel: %w", err) } - records, err := d.client.ListRecords(domainObj.ID) + records, err := d.client.ListRecords(ctx, domainObj.ID) if err != nil { return fmt.Errorf("selectel: %w", err) } @@ -144,7 +157,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { var lastErr error for _, record := range records { if record.Name == recordName { - err = d.client.DeleteRecord(domainObj.ID, record.ID) + err = d.client.DeleteRecord(ctx, domainObj.ID, record.ID) if err != nil { lastErr = fmt.Errorf("selectel: %w", err) } diff --git a/providers/dns/servercow/internal/client.go b/providers/dns/servercow/internal/client.go index 5230ebe2..8f03d9a9 100644 --- a/providers/dns/servercow/internal/client.go +++ b/providers/dns/servercow/internal/client.go @@ -2,59 +2,52 @@ package internal import ( "bytes" + "context" "encoding/json" "errors" "fmt" "io" "net/http" + "net/url" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) const baseAPIURL = "https://api.servercow.de/dns/v1/domains" // Client the Servercow client. type Client struct { - BaseURL string - HTTPClient *http.Client - username string password string + + baseURL *url.URL + HTTPClient *http.Client } // NewClient Creates a Servercow client. func NewClient(username, password string) *Client { + baseURL, _ := url.Parse(baseAPIURL) + return &Client{ - HTTPClient: http.DefaultClient, - BaseURL: baseAPIURL, username: username, password: password, + baseURL: baseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, } } // GetRecords from API. -func (c *Client) GetRecords(domain string) ([]Record, error) { - req, err := c.createRequest(http.MethodGet, domain, nil) +func (c *Client) GetRecords(ctx context.Context, domain string) ([]Record, error) { + endpoint := c.baseURL.JoinPath(domain) + + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, err } - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, err - } - defer func() { _ = resp.Body.Close() }() - - // Note the API always return 200 even if the authentication failed. - if resp.StatusCode/100 != 2 { - return nil, fmt.Errorf("error: status code %d", resp.StatusCode) - } - - raw, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read body: %w", err) - } - var records []Record - err = unmarshal(raw, &records) + err = c.do(req, &records) if err != nil { return nil, err } @@ -63,30 +56,16 @@ func (c *Client) GetRecords(domain string) ([]Record, error) { } // CreateUpdateRecord creates or updates a record. -func (c *Client) CreateUpdateRecord(domain string, data Record) (*Message, error) { - req, err := c.createRequest(http.MethodPost, domain, &data) +func (c *Client) CreateUpdateRecord(ctx context.Context, domain string, data Record) (*Message, error) { + endpoint := c.baseURL.JoinPath(domain) + + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, data) if err != nil { return nil, err } - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, err - } - defer func() { _ = resp.Body.Close() }() - - // Note the API always return 200 even if the authentication failed. - if resp.StatusCode/100 != 2 { - return nil, fmt.Errorf("error: status code %d", resp.StatusCode) - } - - raw, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read body: %w", err) - } - var msg Message - err = json.Unmarshal(raw, &msg) + err = c.do(req, &msg) if err != nil { return nil, err } @@ -99,33 +78,18 @@ func (c *Client) CreateUpdateRecord(domain string, data Record) (*Message, error } // DeleteRecord deletes a record. -func (c *Client) DeleteRecord(domain string, data Record) (*Message, error) { - req, err := c.createRequest(http.MethodDelete, domain, &data) +func (c *Client) DeleteRecord(ctx context.Context, domain string, data Record) (*Message, error) { + endpoint := c.baseURL.JoinPath(domain) + + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, data) if err != nil { return nil, err } - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, err - } - defer func() { _ = resp.Body.Close() }() - - // Note the API always return 200 even if the authentication failed. - if resp.StatusCode/100 != 2 { - return nil, fmt.Errorf("error: status code %d", resp.StatusCode) - } - - raw, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read body: %w", err) - } - var msg Message - err = json.Unmarshal(raw, &msg) + err = c.do(req, &msg) if err != nil { - //nolint:errorlint // in this context msg is not an error, and we just get the type. - return nil, fmt.Errorf("unmarshaling %T error: %w: %s", msg, err, string(raw)) + return nil, err } if msg.ErrorMsg != "" { @@ -135,40 +99,80 @@ func (c *Client) DeleteRecord(domain string, data Record) (*Message, error) { return &msg, nil } -func (c *Client) createRequest(method, domain string, payload *Record) (*http.Request, error) { - body, err := json.Marshal(payload) - if err != nil { - return nil, err - } - - req, err := http.NewRequest(method, c.BaseURL+"/"+domain, bytes.NewReader(body)) - if err != nil { - return nil, err - } - +func (c *Client) do(req *http.Request, result any) error { req.Header.Set("X-Auth-Username", c.username) req.Header.Set("X-Auth-Password", c.password) - req.Header.Set("Content-Type", "application/json") + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + // Note the API always return 200 even if the authentication failed. + if resp.StatusCode/100 != 2 { + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + if result == nil { + return nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = unmarshal(raw, result) + if err != nil { + return err + } + + return nil +} + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } return req, nil } -func unmarshal(raw []byte, v interface{}) error { +func unmarshal(raw []byte, v any) error { err := json.Unmarshal(raw, v) if err == nil { return nil } - var e *json.UnmarshalTypeError - if errors.As(err, &e) { - var apiError Message - errU := json.Unmarshal(raw, &apiError) - if errU != nil { - return fmt.Errorf("unmarshaling %T error: %w: %s", v, err, string(raw)) - } + var utErr *json.UnmarshalTypeError - return apiError + if !errors.As(err, &utErr) { + return fmt.Errorf("unmarshaling %T error: %w: %s", v, err, string(raw)) } - return fmt.Errorf("unmarshaling %T error: %w: %s", v, err, string(raw)) + var apiErr Message + errU := json.Unmarshal(raw, &apiErr) + if errU != nil { + return fmt.Errorf("unmarshaling %T error: %w: %s", v, err, string(raw)) + } + + return apiErr } diff --git a/providers/dns/servercow/internal/client_test.go b/providers/dns/servercow/internal/client_test.go index 1c880973..8597d7e1 100644 --- a/providers/dns/servercow/internal/client_test.go +++ b/providers/dns/servercow/internal/client_test.go @@ -1,10 +1,12 @@ package internal import ( + "context" "encoding/json" "io" "net/http" "net/http/httptest" + "net/url" "os" "testing" @@ -20,7 +22,8 @@ func setupTest(t *testing.T) (*Client, *http.ServeMux) { t.Cleanup(server.Close) client := NewClient("", "") - client.BaseURL = server.URL + client.HTTPClient = server.Client() + client.baseURL, _ = url.Parse(server.URL) return client, mux } @@ -48,7 +51,7 @@ func TestClient_GetRecords(t *testing.T) { } }) - records, err := client.GetRecords("lego.wtf") + records, err := client.GetRecords(context.Background(), "lego.wtf") require.NoError(t, err) recordsJSON, err := json.Marshal(records) @@ -76,7 +79,7 @@ func TestClient_GetRecords_error(t *testing.T) { } }) - records, err := client.GetRecords("lego.wtf") + records, err := client.GetRecords(context.Background(), "lego.wtf") require.Error(t, err) assert.Nil(t, records) @@ -118,7 +121,7 @@ func TestClient_CreateUpdateRecord(t *testing.T) { Content: Value{"aaa", "bbb"}, } - msg, err := client.CreateUpdateRecord("lego.wtf", record) + msg, err := client.CreateUpdateRecord(context.Background(), "lego.wtf", record) require.NoError(t, err) expected := &Message{Message: "ok"} @@ -145,7 +148,7 @@ func TestClient_CreateUpdateRecord_error(t *testing.T) { Name: "_acme-challenge.www", } - msg, err := client.CreateUpdateRecord("lego.wtf", record) + msg, err := client.CreateUpdateRecord(context.Background(), "lego.wtf", record) require.Error(t, err) assert.Nil(t, msg) @@ -185,7 +188,7 @@ func TestClient_DeleteRecord(t *testing.T) { Type: "TXT", } - msg, err := client.DeleteRecord("lego.wtf", record) + msg, err := client.DeleteRecord(context.Background(), "lego.wtf", record) require.NoError(t, err) expected := &Message{Message: "ok"} @@ -212,7 +215,7 @@ func TestClient_DeleteRecord_error(t *testing.T) { Name: "_acme-challenge.www", } - msg, err := client.DeleteRecord("lego.wtf", record) + msg, err := client.DeleteRecord(context.Background(), "lego.wtf", record) require.Error(t, err) assert.Nil(t, msg) diff --git a/providers/dns/servercow/internal/model.go b/providers/dns/servercow/internal/types.go similarity index 100% rename from providers/dns/servercow/internal/model.go rename to providers/dns/servercow/internal/types.go diff --git a/providers/dns/servercow/internal/model_test.go b/providers/dns/servercow/internal/types_test.go similarity index 100% rename from providers/dns/servercow/internal/model_test.go rename to providers/dns/servercow/internal/types_test.go diff --git a/providers/dns/servercow/servercow.go b/providers/dns/servercow/servercow.go index 25b8ad29..324fa660 100644 --- a/providers/dns/servercow/servercow.go +++ b/providers/dns/servercow/servercow.go @@ -2,6 +2,7 @@ package servercow import ( + "context" "errors" "fmt" "net/http" @@ -12,8 +13,6 @@ import ( "github.com/go-acme/lego/v4/providers/dns/servercow/internal" ) -const defaultTTL = 120 - // Environment variables names. const ( envNamespace = "SERVERCOW_" @@ -41,7 +40,7 @@ type Config struct { // NewDefaultConfig returns a default configuration for the DNSProvider. func NewDefaultConfig() *Config { return &Config{ - TTL: env.GetOrDefaultInt(EnvTTL, defaultTTL), + TTL: env.GetOrDefaultInt(EnvTTL, 120), PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, dns01.DefaultPropagationTimeout), PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, dns01.DefaultPollingInterval), HTTPClient: &http.Client{ @@ -76,12 +75,11 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("servercow: incomplete credentials, missing username and/or password") } - if config.HTTPClient == nil { - config.HTTPClient = http.DefaultClient - } - client := internal.NewClient(config.Username, config.Password) - client.HTTPClient = config.HTTPClient + + if config.HTTPClient == nil { + client.HTTPClient = config.HTTPClient + } return &DNSProvider{ config: config, @@ -104,7 +102,9 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { return fmt.Errorf("servercow: %w", err) } - records, err := d.client.GetRecords(authZone) + ctx := context.Background() + + records, err := d.client.GetRecords(ctx, authZone) if err != nil { return fmt.Errorf("servercow: %w", err) } @@ -129,7 +129,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { Content: append(record.Content, info.Value), } - _, err = d.client.CreateUpdateRecord(authZone, request) + _, err = d.client.CreateUpdateRecord(ctx, authZone, request) if err != nil { return fmt.Errorf("servercow: failed to update TXT records: %w", err) } @@ -143,7 +143,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { Content: internal.Value{info.Value}, } - _, err = d.client.CreateUpdateRecord(authZone, request) + _, err = d.client.CreateUpdateRecord(ctx, authZone, request) if err != nil { return fmt.Errorf("servercow: failed to create TXT record %s: %w", info.EffectiveFQDN, err) } @@ -160,7 +160,9 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return fmt.Errorf("servercow: %w", err) } - records, err := d.client.GetRecords(authZone) + ctx := context.Background() + + records, err := d.client.GetRecords(ctx, authZone) if err != nil { return fmt.Errorf("servercow: failed to get TXT records: %w", err) } @@ -181,7 +183,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { // only 1 record value, the whole record must be deleted. if len(record.Content) == 1 { - _, err = d.client.DeleteRecord(authZone, *record) + _, err = d.client.DeleteRecord(ctx, authZone, *record) if err != nil { return fmt.Errorf("servercow: failed to delete TXT records: %w", err) } @@ -200,7 +202,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { } } - _, err = d.client.CreateUpdateRecord(authZone, request) + _, err = d.client.CreateUpdateRecord(ctx, authZone, request) if err != nil { return fmt.Errorf("servercow: failed to update TXT records: %w", err) } @@ -211,7 +213,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { func getAuthZone(domain string) (string, error) { authZone, err := dns01.FindZoneByFqdn(domain) if err != nil { - return "", fmt.Errorf("could not find zone for domain %q: %w", domain, err) + return "", fmt.Errorf("could not find zone for FQDN %q: %w", domain, err) } zoneName := dns01.UnFqdn(authZone) diff --git a/providers/dns/simply/internal/client.go b/providers/dns/simply/internal/client.go index a128fd47..f4221194 100644 --- a/providers/dns/simply/internal/client.go +++ b/providers/dns/simply/internal/client.go @@ -2,23 +2,28 @@ package internal import ( "bytes" + "context" "encoding/json" "errors" "fmt" + "io" "net/http" "net/url" "strings" "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) const defaultBaseURL = "https://api.simply.com/1/" // Client is a Simply.com API client. type Client struct { - HTTPClient *http.Client - baseURL *url.URL accountName string apiKey string + + baseURL *url.URL + HTTPClient *http.Client } // NewClient creates a new Client. @@ -37,98 +42,126 @@ func NewClient(accountName string, apiKey string) (*Client, error) { } return &Client{ - HTTPClient: &http.Client{Timeout: 5 * time.Second}, - baseURL: baseURL, accountName: accountName, apiKey: apiKey, + baseURL: baseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, }, nil } // GetRecords lists all the records in the zone. -func (c *Client) GetRecords(zoneName string) ([]Record, error) { - resp, err := c.do(zoneName, "/", http.MethodGet, nil) - if err != nil { - return nil, err - } +func (c *Client) GetRecords(ctx context.Context, zoneName string) ([]Record, error) { + endpoint := c.createEndpoint(zoneName, "/") - var records []Record - err = json.Unmarshal(resp.Records, &records) - if err != nil { - return nil, fmt.Errorf("failed to unmarshal response result: %w", err) - } - - return records, nil -} - -// AddRecord adds a record. -func (c *Client) AddRecord(zoneName string, record Record) (int64, error) { - reqBody, err := json.Marshal(record) - if err != nil { - return 0, fmt.Errorf("failed to marshall request body: %w", err) - } - - resp, err := c.do(zoneName, "/", http.MethodPost, reqBody) - if err != nil { - return 0, err - } - - var rcd recordHeader - err = json.Unmarshal(resp.Record, &rcd) - if err != nil { - return 0, fmt.Errorf("failed to unmarshal response result: %w", err) - } - - return rcd.ID, nil -} - -// EditRecord updates a record. -func (c *Client) EditRecord(zoneName string, id int64, record Record) error { - reqBody, err := json.Marshal(record) - if err != nil { - return fmt.Errorf("failed to marshall request body: %w", err) - } - - _, err = c.do(zoneName, fmt.Sprintf("%d", id), http.MethodPut, reqBody) - return err -} - -// DeleteRecord deletes a record. -func (c *Client) DeleteRecord(zoneName string, id int64) error { - _, err := c.do(zoneName, fmt.Sprintf("%d", id), http.MethodDelete, nil) - return err -} - -func (c *Client) do(zoneName string, endpoint string, reqMethod string, reqBody []byte) (*apiResponse, error) { - reqURL := c.baseURL.JoinPath(c.accountName, c.apiKey, "my", "products", zoneName, "dns", "records", endpoint) - - req, err := http.NewRequest(reqMethod, strings.TrimSuffix(reqURL.String(), "/"), bytes.NewReader(reqBody)) + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } - req.Header.Set("Accept", "application/json") - req.Header.Set("Content-Type", "application/json") + result := &apiResponse[[]Record, json.RawMessage]{} + err = c.do(req, result) + if err != nil { + return nil, err + } + return result.Records, nil +} + +// AddRecord adds a record. +func (c *Client) AddRecord(ctx context.Context, zoneName string, record Record) (int64, error) { + endpoint := c.createEndpoint(zoneName, "/") + + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record) + if err != nil { + return 0, fmt.Errorf("failed to create request: %w", err) + } + + result := &apiResponse[json.RawMessage, recordHeader]{} + err = c.do(req, result) + if err != nil { + return 0, err + } + + return result.Record.ID, nil +} + +// EditRecord updates a record. +func (c *Client) EditRecord(ctx context.Context, zoneName string, id int64, record Record) error { + endpoint := c.createEndpoint(zoneName, fmt.Sprintf("%d", id)) + + req, err := newJSONRequest(ctx, http.MethodPut, endpoint, record) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + return c.do(req, &apiResponse[json.RawMessage, json.RawMessage]{}) +} + +// DeleteRecord deletes a record. +func (c *Client) DeleteRecord(ctx context.Context, zoneName string, id int64) error { + endpoint := c.createEndpoint(zoneName, fmt.Sprintf("%d", id)) + + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + return c.do(req, &apiResponse[json.RawMessage, json.RawMessage]{}) +} + +func (c *Client) createEndpoint(zoneName string, uri string) *url.URL { + return c.baseURL.JoinPath(c.accountName, c.apiKey, "my", "products", zoneName, "dns", "records", strings.TrimSuffix(uri, "/")) +} + +func (c *Client) do(req *http.Request, result Response) error { resp, err := c.HTTPClient.Do(req) if err != nil { - return nil, fmt.Errorf("failed to perform request: %w", err) + return errutils.NewHTTPDoError(req, err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode >= http.StatusInternalServerError { - return nil, fmt.Errorf("unexpected error: %d", resp.StatusCode) + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) } - response := apiResponse{} - err = json.NewDecoder(resp.Body).Decode(&response) + raw, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) + return errutils.NewReadResponseError(req, resp.StatusCode, err) } - if response.Status != http.StatusOK { - return nil, fmt.Errorf("unexpected error: %s", response.Message) + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) } - return &response, nil + if result.GetStatus() != http.StatusOK { + return fmt.Errorf("unexpected error: %s", result.GetMessage()) + } + + return nil +} + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil } diff --git a/providers/dns/simply/internal/client_test.go b/providers/dns/simply/internal/client_test.go index 575ada9c..c9b97e94 100644 --- a/providers/dns/simply/internal/client_test.go +++ b/providers/dns/simply/internal/client_test.go @@ -1,6 +1,7 @@ package internal import ( + "context" "fmt" "io" "net/http" @@ -16,11 +17,11 @@ import ( ) func TestClient_GetRecords(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/accountname/apikey/my/products/azone01/dns/records", mockHandler(http.MethodGet, http.StatusOK, "get_records.json")) - records, err := client.GetRecords("azone01") + records, err := client.GetRecords(context.Background(), "azone01") require.NoError(t, err) expected := []Record{ @@ -62,18 +63,18 @@ func TestClient_GetRecords(t *testing.T) { } func TestClient_GetRecords_error(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/accountname/apikey/my/products/azone01/dns/records", mockHandler(http.MethodGet, http.StatusBadRequest, "bad_auth_error.json")) - records, err := client.GetRecords("azone01") + records, err := client.GetRecords(context.Background(), "azone01") require.Error(t, err) assert.Nil(t, records) } func TestClient_AddRecord(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/accountname/apikey/my/products/azone01/dns/records", mockHandler(http.MethodPost, http.StatusOK, "add_record.json")) @@ -85,14 +86,14 @@ func TestClient_AddRecord(t *testing.T) { Priority: 0, } - recordID, err := client.AddRecord("azone01", record) + recordID, err := client.AddRecord(context.Background(), "azone01", record) require.NoError(t, err) assert.EqualValues(t, 123456789, recordID) } func TestClient_AddRecord_error(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/accountname/apikey/my/products/azone01/dns/records", mockHandler(http.MethodPost, http.StatusNotFound, "bad_zone_error.json")) @@ -104,14 +105,14 @@ func TestClient_AddRecord_error(t *testing.T) { Priority: 0, } - recordID, err := client.AddRecord("azone01", record) + recordID, err := client.AddRecord(context.Background(), "azone01", record) require.Error(t, err) assert.Zero(t, recordID) } func TestClient_EditRecord(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/accountname/apikey/my/products/azone01/dns/records/123456789", mockHandler(http.MethodPut, http.StatusOK, "success.json")) @@ -123,12 +124,12 @@ func TestClient_EditRecord(t *testing.T) { Priority: 0, } - err := client.EditRecord("azone01", 123456789, record) + err := client.EditRecord(context.Background(), "azone01", 123456789, record) require.NoError(t, err) } func TestClient_EditRecord_error(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/accountname/apikey/my/products/azone01/dns/records/123456789", mockHandler(http.MethodPut, http.StatusNotFound, "invalid_record_id.json")) @@ -140,29 +141,29 @@ func TestClient_EditRecord_error(t *testing.T) { Priority: 0, } - err := client.EditRecord("azone01", 123456789, record) + err := client.EditRecord(context.Background(), "azone01", 123456789, record) require.Error(t, err) } func TestClient_DeleteRecord(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/accountname/apikey/my/products/azone01/dns/records/123456789", mockHandler(http.MethodDelete, http.StatusOK, "success.json")) - err := client.DeleteRecord("azone01", 123456789) + err := client.DeleteRecord(context.Background(), "azone01", 123456789) require.NoError(t, err) } func TestClient_DeleteRecord_error(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/accountname/apikey/my/products/azone01/dns/records/123456789", mockHandler(http.MethodDelete, http.StatusNotFound, "invalid_record_id.json")) - err := client.DeleteRecord("azone01", 123456789) + err := client.DeleteRecord(context.Background(), "azone01", 123456789) require.Error(t, err) } -func setupTest(t *testing.T) (*http.ServeMux, *Client) { +func setupTest(t *testing.T) (*Client, *http.ServeMux) { t.Helper() mux := http.NewServeMux() @@ -174,7 +175,7 @@ func setupTest(t *testing.T) (*http.ServeMux, *Client) { client.baseURL, _ = url.Parse(server.URL) - return mux, client + return client, mux } func mockHandler(method string, statusCode int, filename string) func(http.ResponseWriter, *http.Request) { diff --git a/providers/dns/simply/internal/types.go b/providers/dns/simply/internal/types.go index e2440c31..7bc53345 100644 --- a/providers/dns/simply/internal/types.go +++ b/providers/dns/simply/internal/types.go @@ -1,7 +1,5 @@ package internal -import "encoding/json" - // Record represents the content of a DNS record. type Record struct { ID int64 `json:"record_id,omitempty"` @@ -12,12 +10,25 @@ type Record struct { Priority int `json:"priority,omitempty"` } +type Response interface { + GetStatus() int + GetMessage() string +} + // apiResponse represents an API response. -type apiResponse struct { - Status int `json:"status"` - Message string `json:"message"` - Records json.RawMessage `json:"records,omitempty"` - Record json.RawMessage `json:"record,omitempty"` +type apiResponse[S any, R any] struct { + Status int `json:"status"` + Message string `json:"message"` + Records S `json:"records,omitempty"` + Record R `json:"record,omitempty"` +} + +func (a apiResponse[S, R]) GetStatus() int { + return a.Status +} + +func (a apiResponse[S, R]) GetMessage() string { + return a.Message } type recordHeader struct { diff --git a/providers/dns/simply/simply.go b/providers/dns/simply/simply.go index f4962d92..5376b3a4 100644 --- a/providers/dns/simply/simply.go +++ b/providers/dns/simply/simply.go @@ -2,6 +2,7 @@ package simply import ( + "context" "errors" "fmt" "net/http" @@ -114,8 +115,9 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("simply: could not determine zone for domain %q: %w", domain, err) + return fmt.Errorf("simply: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } + authZone = dns01.UnFqdn(authZone) subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) @@ -130,7 +132,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { TTL: d.config.TTL, } - recordID, err := d.client.AddRecord(authZone, recordBody) + recordID, err := d.client.AddRecord(context.Background(), authZone, recordBody) if err != nil { return fmt.Errorf("simply: failed to add record: %w", err) } @@ -148,8 +150,9 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("simply: could not determine zone for domain %q: %w", domain, err) + return fmt.Errorf("simply: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } + authZone = dns01.UnFqdn(authZone) // gets the record's unique ID from when we created it @@ -160,7 +163,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return fmt.Errorf("simply: unknown record ID for '%s' '%s'", info.EffectiveFQDN, token) } - err = d.client.DeleteRecord(authZone, recordID) + err = d.client.DeleteRecord(context.Background(), authZone, recordID) if err != nil { return fmt.Errorf("simply: failed to delete TXT records: fqdn=%s, recordID=%d: %w", info.EffectiveFQDN, recordID, err) } diff --git a/providers/dns/sonic/internal/client.go b/providers/dns/sonic/internal/client.go index ba52411c..aac85c63 100644 --- a/providers/dns/sonic/internal/client.go +++ b/providers/dns/sonic/internal/client.go @@ -2,35 +2,25 @@ package internal import ( "bytes" + "context" "encoding/json" "errors" "fmt" "io" "net/http" + "net/url" "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) const baseURL = "https://public-api.sonic.net/dyndns" -type APIResponse struct { - Message string `json:"message"` - Result int `json:"result"` -} - -// Record holds the Sonic API representation of a Domain Record. -type Record struct { - UserID string `json:"userid"` - APIKey string `json:"apikey"` - Hostname string `json:"hostname"` - Value string `json:"value"` - TTL int `json:"ttl"` - Type string `json:"type"` -} - // Client Sonic client. type Client struct { - userID string - apiKey string + userID string + apiKey string + baseURL string HTTPClient *http.Client } @@ -52,7 +42,7 @@ func NewClient(userID, apiKey string) (*Client, error) { // SetRecord creates or updates a TXT records. // Sonic does not provide a delete record API endpoint. // https://public-api.sonic.net/dyndns#updating_or_adding_host_records -func (c *Client) SetRecord(hostname string, value string, ttl int) error { +func (c *Client) SetRecord(ctx context.Context, hostname string, value string, ttl int) error { payload := &Record{ UserID: c.userID, APIKey: c.apiKey, @@ -64,32 +54,38 @@ func (c *Client) SetRecord(hostname string, value string, ttl int) error { body, err := json.Marshal(payload) if err != nil { - return err + return fmt.Errorf("failed to create request JSON body: %w", err) } - req, err := http.NewRequest(http.MethodPut, c.baseURL+"/host", bytes.NewReader(body)) + endpoint, err := url.JoinPath(c.baseURL, "host") if err != nil { return err } + req, err := http.NewRequestWithContext(ctx, http.MethodPut, endpoint, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") req.Header.Set("content-type", "application/json") resp, err := c.HTTPClient.Do(req) if err != nil { - return err + return errutils.NewHTTPDoError(req, err) } defer func() { _ = resp.Body.Close() }() raw, err := io.ReadAll(resp.Body) if err != nil { - return fmt.Errorf("failed to read response body: %w", err) + return errutils.NewReadResponseError(req, resp.StatusCode, err) } r := APIResponse{} err = json.Unmarshal(raw, &r) if err != nil { - return fmt.Errorf("failed to unmarshal response: %w: %s", err, string(raw)) + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) } if r.Result != 200 { diff --git a/providers/dns/sonic/internal/client_test.go b/providers/dns/sonic/internal/client_test.go index 6317e16d..f4a6105e 100644 --- a/providers/dns/sonic/internal/client_test.go +++ b/providers/dns/sonic/internal/client_test.go @@ -1,6 +1,7 @@ package internal import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -9,7 +10,7 @@ import ( "github.com/stretchr/testify/require" ) -func setup(t *testing.T, body string) *Client { +func setupTest(t *testing.T, body string) *Client { t.Helper() mux := http.NewServeMux() @@ -51,9 +52,9 @@ func TestClient_SetRecord(t *testing.T) { t.Run(test.desc, func(t *testing.T) { t.Parallel() - client := setup(t, test.response) + client := setupTest(t, test.response) - err := client.SetRecord("example.com", "txttxttxt", 10) + err := client.SetRecord(context.Background(), "example.com", "txttxttxt", 10) test.assert(t, err) }) } diff --git a/providers/dns/sonic/internal/types.go b/providers/dns/sonic/internal/types.go new file mode 100644 index 00000000..d6caed3a --- /dev/null +++ b/providers/dns/sonic/internal/types.go @@ -0,0 +1,16 @@ +package internal + +type APIResponse struct { + Message string `json:"message"` + Result int `json:"result"` +} + +// Record holds the Sonic API representation of a Domain Record. +type Record struct { + UserID string `json:"userid"` + APIKey string `json:"apikey"` + Hostname string `json:"hostname"` + Value string `json:"value"` + TTL int `json:"ttl"` + Type string `json:"type"` +} diff --git a/providers/dns/sonic/sonic.go b/providers/dns/sonic/sonic.go index 907ee7d0..19c5769b 100644 --- a/providers/dns/sonic/sonic.go +++ b/providers/dns/sonic/sonic.go @@ -2,6 +2,7 @@ package sonic import ( + "context" "errors" "fmt" "net/http" @@ -94,7 +95,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - err := d.client.SetRecord(dns01.UnFqdn(info.EffectiveFQDN), info.Value, d.config.TTL) + err := d.client.SetRecord(context.Background(), dns01.UnFqdn(info.EffectiveFQDN), info.Value, d.config.TTL) if err != nil { return fmt.Errorf("sonic: unable to create record for %s: %w", info.EffectiveFQDN, err) } @@ -106,7 +107,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - err := d.client.SetRecord(dns01.UnFqdn(info.EffectiveFQDN), "_", d.config.TTL) + err := d.client.SetRecord(context.Background(), dns01.UnFqdn(info.EffectiveFQDN), "_", d.config.TTL) if err != nil { return fmt.Errorf("sonic: unable to clean record for %s: %w", info.EffectiveFQDN, err) } diff --git a/providers/dns/stackpath/client.go b/providers/dns/stackpath/client.go deleted file mode 100644 index f38f4e75..00000000 --- a/providers/dns/stackpath/client.go +++ /dev/null @@ -1,207 +0,0 @@ -package stackpath - -import ( - "bytes" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - - "github.com/go-acme/lego/v4/challenge/dns01" - "golang.org/x/net/publicsuffix" -) - -// Zones is the response struct from the Stackpath api GetZones. -type Zones struct { - Zones []Zone `json:"zones"` -} - -// Zone a DNS zone representation. -type Zone struct { - ID string - Domain string -} - -// Records is the response struct from the Stackpath api GetZoneRecords. -type Records struct { - Records []Record `json:"records"` -} - -// Record a DNS record representation. -type Record struct { - ID string `json:"id,omitempty"` - Name string `json:"name"` - Type string `json:"type"` - TTL int `json:"ttl"` - Data string `json:"data"` -} - -// ErrorResponse the API error response representation. -type ErrorResponse struct { - Code int `json:"code"` - Message string `json:"error"` -} - -func (e *ErrorResponse) Error() string { - return fmt.Sprintf("%d %s", e.Code, e.Message) -} - -// https://developer.stackpath.com/en/api/dns/#operation/GetZones -func (d *DNSProvider) getZones(domain string) (*Zone, error) { - tld, err := publicsuffix.EffectiveTLDPlusOne(dns01.UnFqdn(domain)) - if err != nil { - return nil, err - } - - req, err := d.newRequest(http.MethodGet, "/zones", nil) - if err != nil { - return nil, err - } - - query := req.URL.Query() - query.Add("page_request.filter", fmt.Sprintf("domain='%s'", tld)) - req.URL.RawQuery = query.Encode() - - var zones Zones - err = d.do(req, &zones) - if err != nil { - return nil, err - } - - if len(zones.Zones) == 0 { - return nil, fmt.Errorf("did not find zone with domain %s", domain) - } - - return &zones.Zones[0], nil -} - -// https://developer.stackpath.com/en/api/dns/#operation/GetZoneRecords -func (d *DNSProvider) getZoneRecords(name string, zone *Zone) ([]Record, error) { - u := fmt.Sprintf("/zones/%s/records", zone.ID) - req, err := d.newRequest(http.MethodGet, u, nil) - if err != nil { - return nil, err - } - - query := req.URL.Query() - query.Add("page_request.filter", fmt.Sprintf("name='%s' and type='TXT'", name)) - req.URL.RawQuery = query.Encode() - - var records Records - err = d.do(req, &records) - if err != nil { - return nil, err - } - - if len(records.Records) == 0 { - return nil, fmt.Errorf("did not find record with name %s", name) - } - - return records.Records, nil -} - -// https://developer.stackpath.com/en/api/dns/#operation/CreateZoneRecord -func (d *DNSProvider) createZoneRecord(zone *Zone, record Record) error { - u := fmt.Sprintf("/zones/%s/records", zone.ID) - req, err := d.newRequest(http.MethodPost, u, record) - if err != nil { - return err - } - - return d.do(req, nil) -} - -// https://developer.stackpath.com/en/api/dns/#operation/DeleteZoneRecord -func (d *DNSProvider) deleteZoneRecord(zone *Zone, record Record) error { - u := fmt.Sprintf("/zones/%s/records/%s", zone.ID, record.ID) - req, err := d.newRequest(http.MethodDelete, u, nil) - if err != nil { - return err - } - - return d.do(req, nil) -} - -func (d *DNSProvider) newRequest(method, urlStr string, body interface{}) (*http.Request, error) { - u := d.BaseURL.JoinPath(d.config.StackID, urlStr) - - if body == nil { - return http.NewRequest(method, u.String(), nil) - } - - reqBody, err := json.Marshal(body) - if err != nil { - return nil, err - } - - req, err := http.NewRequest(method, u.String(), bytes.NewBuffer(reqBody)) - if err != nil { - return nil, err - } - - req.Header.Set("Content-Type", "application/json") - - return req, nil -} - -func (d *DNSProvider) do(req *http.Request, v interface{}) error { - resp, err := d.client.Do(req) - if err != nil { - return err - } - - err = checkResponse(resp) - if err != nil { - return err - } - - if v == nil { - return nil - } - - raw, err := readBody(resp) - if err != nil { - return fmt.Errorf("failed to read body: %w", err) - } - - err = json.Unmarshal(raw, v) - if err != nil { - return fmt.Errorf("unmarshaling error: %w: %s", err, string(raw)) - } - - return nil -} - -func checkResponse(resp *http.Response) error { - if resp.StatusCode > 299 { - data, err := readBody(resp) - if err != nil { - return &ErrorResponse{Code: resp.StatusCode, Message: err.Error()} - } - - errResp := &ErrorResponse{} - err = json.Unmarshal(data, errResp) - if err != nil { - return &ErrorResponse{Code: resp.StatusCode, Message: fmt.Sprintf("unmarshaling error: %v: %s", err, string(data))} - } - return errResp - } - - return nil -} - -func readBody(resp *http.Response) ([]byte, error) { - if resp.Body == nil { - return nil, errors.New("response body is nil") - } - - defer resp.Body.Close() - - rawBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - return rawBody, nil -} diff --git a/providers/dns/stackpath/internal/client.go b/providers/dns/stackpath/internal/client.go new file mode 100644 index 00000000..bd11bf23 --- /dev/null +++ b/providers/dns/stackpath/internal/client.go @@ -0,0 +1,186 @@ +package internal + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + + "github.com/go-acme/lego/v4/challenge/dns01" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" + "golang.org/x/net/publicsuffix" +) + +const defaultBaseURL = "https://gateway.stackpath.com/dns/v1/stacks/" + +// Client the API client for Stackpath. +type Client struct { + stackID string + + baseURL *url.URL + httpClient *http.Client +} + +// NewClient creates a new Client. +func NewClient(ctx context.Context, stackID, clientID, clientSecret string) *Client { + baseURL, _ := url.Parse(defaultBaseURL) + + return &Client{ + baseURL: baseURL, + stackID: stackID, + httpClient: createOAuthClient(ctx, clientID, clientSecret), + } +} + +// GetZones gets all zones. +// https://stackpath.dev/reference/getzones +func (c *Client) GetZones(ctx context.Context, domain string) (*Zone, error) { + endpoint := c.baseURL.JoinPath(c.stackID, "zones") + + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, err + } + + tld, err := publicsuffix.EffectiveTLDPlusOne(dns01.UnFqdn(domain)) + if err != nil { + return nil, err + } + + query := req.URL.Query() + query.Add("page_request.filter", fmt.Sprintf("domain='%s'", tld)) + req.URL.RawQuery = query.Encode() + + var zones Zones + err = c.do(req, &zones) + if err != nil { + return nil, err + } + + if len(zones.Zones) == 0 { + return nil, fmt.Errorf("did not find zone with domain %s", domain) + } + + return &zones.Zones[0], nil +} + +// GetZoneRecords gets all records. +// https://stackpath.dev/reference/getzonerecords +func (c *Client) GetZoneRecords(ctx context.Context, name string, zone *Zone) ([]Record, error) { + endpoint := c.baseURL.JoinPath(c.stackID, "zones", zone.ID, "records") + + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, err + } + + query := req.URL.Query() + query.Add("page_request.filter", fmt.Sprintf("name='%s' and type='TXT'", name)) + req.URL.RawQuery = query.Encode() + + var records Records + err = c.do(req, &records) + if err != nil { + return nil, err + } + + if len(records.Records) == 0 { + return nil, fmt.Errorf("did not find record with name %s", name) + } + + return records.Records, nil +} + +// CreateZoneRecord creates a record. +// https://stackpath.dev/reference/createzonerecord +func (c *Client) CreateZoneRecord(ctx context.Context, zone *Zone, record Record) error { + endpoint := c.baseURL.JoinPath(c.stackID, "zones", zone.ID, "records") + + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record) + if err != nil { + return err + } + + return c.do(req, nil) +} + +// DeleteZoneRecord deletes a record. +// https://stackpath.dev/reference/deletezonerecord +func (c *Client) DeleteZoneRecord(ctx context.Context, zone *Zone, record Record) error { + endpoint := c.baseURL.JoinPath(c.stackID, "zones", zone.ID, "records", record.ID) + + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) + if err != nil { + return err + } + + return c.do(req, nil) +} + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil +} + +func (c *Client) do(req *http.Request, result any) error { + resp, err := c.httpClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode/100 != 2 { + return parseError(req, resp) + } + + if result == nil { + return nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return nil +} + +func parseError(req *http.Request, resp *http.Response) error { + raw, _ := io.ReadAll(resp.Body) + + errResp := &ErrorResponse{} + err := json.Unmarshal(raw, errResp) + if err != nil { + return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw) + } + + return errResp +} diff --git a/providers/dns/stackpath/internal/client_test.go b/providers/dns/stackpath/internal/client_test.go new file mode 100644 index 00000000..2de1d476 --- /dev/null +++ b/providers/dns/stackpath/internal/client_test.go @@ -0,0 +1,131 @@ +package internal + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setupTest(t *testing.T) (*Client, *http.ServeMux) { + t.Helper() + + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + client := NewClient(context.Background(), "STACK_ID", "CLIENT_ID", "CLIENT_SECRET") + client.httpClient = server.Client() + client.baseURL, _ = url.Parse(server.URL + "/") + + return client, mux +} + +func TestClient_GetZoneRecords(t *testing.T) { + client, mux := setupTest(t) + + mux.HandleFunc("/STACK_ID/zones/A/records", func(w http.ResponseWriter, _ *http.Request) { + content := ` + { + "records": [ + {"id":"1","name":"foo1","type":"TXT","ttl":120,"data":"txtTXTtxt"}, + {"id":"2","name":"foo2","type":"TXT","ttl":121,"data":"TXTtxtTXT"} + ] + }` + + _, err := w.Write([]byte(content)) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + }) + + records, err := client.GetZoneRecords(context.Background(), "foo1", &Zone{ID: "A", Domain: "test"}) + require.NoError(t, err) + + expected := []Record{ + {ID: "1", Name: "foo1", Type: "TXT", TTL: 120, Data: "txtTXTtxt"}, + {ID: "2", Name: "foo2", Type: "TXT", TTL: 121, Data: "TXTtxtTXT"}, + } + + assert.Equal(t, expected, records) +} + +func TestClient_GetZoneRecords_apiError(t *testing.T) { + client, mux := setupTest(t) + + mux.HandleFunc("/STACK_ID/zones/A/records", func(w http.ResponseWriter, _ *http.Request) { + content := ` +{ + "code": 401, + "error": "an unauthorized request is attempted." +}` + + w.WriteHeader(http.StatusUnauthorized) + _, err := w.Write([]byte(content)) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + }) + + _, err := client.GetZoneRecords(context.Background(), "foo1", &Zone{ID: "A", Domain: "test"}) + + expected := &ErrorResponse{Code: 401, Message: "an unauthorized request is attempted."} + assert.Equal(t, expected, err) +} + +func TestClient_GetZones(t *testing.T) { + client, mux := setupTest(t) + + mux.HandleFunc("/STACK_ID/zones", func(w http.ResponseWriter, _ *http.Request) { + content := ` +{ + "pageInfo": { + "totalCount": "5", + "hasPreviousPage": false, + "hasNextPage": false, + "startCursor": "1", + "endCursor": "1" + }, + "zones": [ + { + "stackId": "my_stack", + "accountId": "my_account", + "id": "A", + "domain": "foo.com", + "version": "1", + "labels": { + "property1": "val1", + "property2": "val2" + }, + "created": "2018-10-07T02:31:49Z", + "updated": "2018-10-07T02:31:49Z", + "nameservers": [ + "1.1.1.1" + ], + "verified": "2018-10-07T02:31:49Z", + "status": "ACTIVE", + "disabled": false + } + ] +}` + + _, err := w.Write([]byte(content)) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + }) + + zone, err := client.GetZones(context.Background(), "sub.foo.com") + require.NoError(t, err) + + expected := &Zone{ID: "A", Domain: "foo.com"} + + assert.Equal(t, expected, zone) +} diff --git a/providers/dns/stackpath/internal/identity.go b/providers/dns/stackpath/internal/identity.go new file mode 100644 index 00000000..5c6e6ab1 --- /dev/null +++ b/providers/dns/stackpath/internal/identity.go @@ -0,0 +1,20 @@ +package internal + +import ( + "context" + "net/http" + + "golang.org/x/oauth2/clientcredentials" +) + +const defaultAuthURL = "https://gateway.stackpath.com/identity/v1/oauth2/token" + +func createOAuthClient(ctx context.Context, clientID, clientSecret string) *http.Client { + config := &clientcredentials.Config{ + TokenURL: defaultAuthURL, + ClientID: clientID, + ClientSecret: clientSecret, + } + + return config.Client(ctx) +} diff --git a/providers/dns/stackpath/internal/types.go b/providers/dns/stackpath/internal/types.go new file mode 100644 index 00000000..1ca29e81 --- /dev/null +++ b/providers/dns/stackpath/internal/types.go @@ -0,0 +1,38 @@ +package internal + +import "fmt" + +// Zones is the response struct from the Stackpath api GetZones. +type Zones struct { + Zones []Zone `json:"zones"` +} + +// Zone a DNS zone representation. +type Zone struct { + ID string + Domain string +} + +// Records is the response struct from the Stackpath api GetZoneRecords. +type Records struct { + Records []Record `json:"records"` +} + +// Record a DNS record representation. +type Record struct { + ID string `json:"id,omitempty"` + Name string `json:"name"` + Type string `json:"type"` + TTL int `json:"ttl"` + Data string `json:"data"` +} + +// ErrorResponse the API error response representation. +type ErrorResponse struct { + Code int `json:"code"` + Message string `json:"error"` +} + +func (e *ErrorResponse) Error() string { + return fmt.Sprintf("%d %s", e.Code, e.Message) +} diff --git a/providers/dns/stackpath/stackpath.go b/providers/dns/stackpath/stackpath.go index ee56a47e..97cfd8aa 100644 --- a/providers/dns/stackpath/stackpath.go +++ b/providers/dns/stackpath/stackpath.go @@ -6,19 +6,12 @@ import ( "context" "errors" "fmt" - "net/http" - "net/url" "time" "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/log" "github.com/go-acme/lego/v4/platform/config/env" - "golang.org/x/oauth2/clientcredentials" -) - -const ( - defaultBaseURL = "https://gateway.stackpath.com/dns/v1/stacks/" - defaultAuthURL = "https://gateway.stackpath.com/identity/v1/oauth2/token" + "github.com/go-acme/lego/v4/providers/dns/stackpath/internal" ) // Environment variables names. @@ -55,9 +48,8 @@ func NewDefaultConfig() *Config { // DNSProvider implements the challenge.Provider interface. type DNSProvider struct { - BaseURL *url.URL - client *http.Client - config *Config + config *Config + client *internal.Client } // NewDNSProvider returns a DNSProvider instance configured for Stackpath. @@ -91,30 +83,18 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("stackpath: stack id missing") } - baseURL, _ := url.Parse(defaultBaseURL) + client := internal.NewClient(context.Background(), config.StackID, config.ClientID, config.ClientSecret) - return &DNSProvider{ - BaseURL: baseURL, - client: getOathClient(config), - config: config, - }, nil -} - -func getOathClient(config *Config) *http.Client { - oathConfig := &clientcredentials.Config{ - TokenURL: defaultAuthURL, - ClientID: config.ClientID, - ClientSecret: config.ClientSecret, - } - - return oathConfig.Client(context.Background()) + return &DNSProvider{config: config, client: client}, nil } // Present creates a TXT record to fulfill the dns-01 challenge. func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - zone, err := d.getZones(info.EffectiveFQDN) + ctx := context.Background() + + zone, err := d.client.GetZones(ctx, info.EffectiveFQDN) if err != nil { return fmt.Errorf("stackpath: %w", err) } @@ -124,21 +104,23 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { return fmt.Errorf("stackpath: %w", err) } - record := Record{ + record := internal.Record{ Name: subDomain, Type: "TXT", TTL: d.config.TTL, Data: info.Value, } - return d.createZoneRecord(zone, record) + return d.client.CreateZoneRecord(ctx, zone, record) } // CleanUp removes the TXT record matching the specified parameters. func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - zone, err := d.getZones(info.EffectiveFQDN) + ctx := context.Background() + + zone, err := d.client.GetZones(ctx, info.EffectiveFQDN) if err != nil { return fmt.Errorf("stackpath: %w", err) } @@ -148,13 +130,13 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return fmt.Errorf("stackpath: %w", err) } - records, err := d.getZoneRecords(subDomain, zone) + records, err := d.client.GetZoneRecords(ctx, subDomain, zone) if err != nil { return err } for _, record := range records { - err = d.deleteZoneRecord(zone, record) + err = d.client.DeleteZoneRecord(ctx, zone, record) if err != nil { log.Printf("stackpath: failed to delete TXT record: %v", err) } diff --git a/providers/dns/stackpath/stackpath_test.go b/providers/dns/stackpath/stackpath_test.go index 1a575bbf..a72f268a 100644 --- a/providers/dns/stackpath/stackpath_test.go +++ b/providers/dns/stackpath/stackpath_test.go @@ -1,9 +1,6 @@ package stackpath import ( - "net/http" - "net/http/httptest" - "net/url" "testing" "time" @@ -135,132 +132,6 @@ func TestNewDNSProviderConfig(t *testing.T) { } } -func setupTest(t *testing.T) (*DNSProvider, *http.ServeMux) { - t.Helper() - - mux := http.NewServeMux() - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - - config := NewDefaultConfig() - config.ClientID = "CLIENT_ID" - config.ClientSecret = "CLIENT_SECRET" - config.StackID = "STACK_ID" - - provider, err := NewDNSProviderConfig(config) - require.NoError(t, err) - - provider.client = http.DefaultClient - provider.BaseURL, _ = url.Parse(server.URL + "/") - - return provider, mux -} - -func TestDNSProvider_getZoneRecords(t *testing.T) { - provider, mux := setupTest(t) - - mux.HandleFunc("/STACK_ID/zones/A/records", func(w http.ResponseWriter, _ *http.Request) { - content := ` - { - "records": [ - {"id":"1","name":"foo1","type":"TXT","ttl":120,"data":"txtTXTtxt"}, - {"id":"2","name":"foo2","type":"TXT","ttl":121,"data":"TXTtxtTXT"} - ] - }` - - _, err := w.Write([]byte(content)) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - }) - - records, err := provider.getZoneRecords("foo1", &Zone{ID: "A", Domain: "test"}) - require.NoError(t, err) - - expected := []Record{ - {ID: "1", Name: "foo1", Type: "TXT", TTL: 120, Data: "txtTXTtxt"}, - {ID: "2", Name: "foo2", Type: "TXT", TTL: 121, Data: "TXTtxtTXT"}, - } - - assert.Equal(t, expected, records) -} - -func TestDNSProvider_getZoneRecords_apiError(t *testing.T) { - provider, mux := setupTest(t) - - mux.HandleFunc("/STACK_ID/zones/A/records", func(w http.ResponseWriter, _ *http.Request) { - content := ` -{ - "code": 401, - "error": "an unauthorized request is attempted." -}` - - w.WriteHeader(http.StatusUnauthorized) - _, err := w.Write([]byte(content)) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - }) - - _, err := provider.getZoneRecords("foo1", &Zone{ID: "A", Domain: "test"}) - - expected := &ErrorResponse{Code: 401, Message: "an unauthorized request is attempted."} - assert.Equal(t, expected, err) -} - -func TestDNSProvider_getZones(t *testing.T) { - provider, mux := setupTest(t) - - mux.HandleFunc("/STACK_ID/zones", func(w http.ResponseWriter, _ *http.Request) { - content := ` -{ - "pageInfo": { - "totalCount": "5", - "hasPreviousPage": false, - "hasNextPage": false, - "startCursor": "1", - "endCursor": "1" - }, - "zones": [ - { - "stackId": "my_stack", - "accountId": "my_account", - "id": "A", - "domain": "foo.com", - "version": "1", - "labels": { - "property1": "val1", - "property2": "val2" - }, - "created": "2018-10-07T02:31:49Z", - "updated": "2018-10-07T02:31:49Z", - "nameservers": [ - "1.1.1.1" - ], - "verified": "2018-10-07T02:31:49Z", - "status": "ACTIVE", - "disabled": false - } - ] -}` - - _, err := w.Write([]byte(content)) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - }) - - zone, err := provider.getZones("sub.foo.com") - require.NoError(t, err) - - expected := &Zone{ID: "A", Domain: "foo.com"} - - assert.Equal(t, expected, zone) -} - func TestLivePresent(t *testing.T) { if !envTest.IsLiveTest() { t.Skip("skipping live test") diff --git a/providers/dns/tencentcloud/client.go b/providers/dns/tencentcloud/wrapper.go similarity index 96% rename from providers/dns/tencentcloud/client.go rename to providers/dns/tencentcloud/wrapper.go index 1c15ee38..af608bb3 100644 --- a/providers/dns/tencentcloud/client.go +++ b/providers/dns/tencentcloud/wrapper.go @@ -33,7 +33,7 @@ func (d *DNSProvider) getHostedZone(domain string) (*dnspod.DomainListItem, erro authZone, err := dns01.FindZoneByFqdn(domain) if err != nil { - return nil, err + return nil, fmt.Errorf("could not find zone for FQDN %q : %w", domain, err) } var hostedZone *dnspod.DomainListItem diff --git a/providers/dns/transip/transip.go b/providers/dns/transip/transip.go index 4b859bc6..e18f2f0f 100644 --- a/providers/dns/transip/transip.go +++ b/providers/dns/transip/transip.go @@ -95,7 +95,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return err + return fmt.Errorf("transip: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } // get the subDomain @@ -127,7 +127,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return err + return fmt.Errorf("transip: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } // get the subDomain diff --git a/providers/dns/ultradns/ultradns.go b/providers/dns/ultradns/ultradns.go index 5a1fb40c..2c39e9c0 100644 --- a/providers/dns/ultradns/ultradns.go +++ b/providers/dns/ultradns/ultradns.go @@ -105,7 +105,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("ultradns: %w", err) + return fmt.Errorf("ultradns: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } recordService, err := record.Get(d.client) @@ -146,7 +146,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("ultradns: %w", err) + return fmt.Errorf("ultradns: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } recordService, err := record.Get(d.client) diff --git a/providers/dns/variomedia/internal/client.go b/providers/dns/variomedia/internal/client.go index 6df23dd6..4a671e88 100644 --- a/providers/dns/variomedia/internal/client.go +++ b/providers/dns/variomedia/internal/client.go @@ -2,22 +2,30 @@ package internal import ( "bytes" + "context" "encoding/json" "fmt" "io" "net/http" "net/url" "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) const defaultBaseURL = "https://api.variomedia.de" +const authorizationHeader = "Authorization" + +// Client the API client for Variomedia. type Client struct { - apiToken string + apiToken string + baseURL *url.URL HTTPClient *http.Client } +// NewClient creates a new Client. func NewClient(apiToken string) *Client { baseURL, _ := url.Parse(defaultBaseURL) @@ -28,7 +36,9 @@ func NewClient(apiToken string) *Client { } } -func (c Client) CreateDNSRecord(record DNSRecord) (*CreateDNSRecordResponse, error) { +// CreateDNSRecord creates a new DNS entry. +// https://api.variomedia.de/docs/dns-records.html#erstellen +func (c Client) CreateDNSRecord(ctx context.Context, record DNSRecord) (*CreateDNSRecordResponse, error) { endpoint := c.baseURL.JoinPath("dns-records") data := CreateDNSRecordRequest{Data: Data{ @@ -36,12 +46,7 @@ func (c Client) CreateDNSRecord(record DNSRecord) (*CreateDNSRecordResponse, err Attributes: record, }} - body, err := json.Marshal(data) - if err != nil { - return nil, err - } - - req, err := http.NewRequest(http.MethodPost, endpoint.String(), bytes.NewReader(body)) + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, data) if err != nil { return nil, err } @@ -55,10 +60,12 @@ func (c Client) CreateDNSRecord(record DNSRecord) (*CreateDNSRecordResponse, err return &result, nil } -func (c Client) DeleteDNSRecord(id string) (*DeleteRecordResponse, error) { +// DeleteDNSRecord deletes a DNS record. +// https://api.variomedia.de/docs/dns-records.html#l%C3%B6schen +func (c Client) DeleteDNSRecord(ctx context.Context, id string) (*DeleteRecordResponse, error) { endpoint := c.baseURL.JoinPath("dns-records", id) - req, err := http.NewRequest(http.MethodDelete, endpoint.String(), nil) + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) if err != nil { return nil, err } @@ -72,10 +79,12 @@ func (c Client) DeleteDNSRecord(id string) (*DeleteRecordResponse, error) { return &result, nil } -func (c Client) GetJob(id string) (*GetJobResponse, error) { +// GetJob returns a single job based on its ID. +// https://api.variomedia.de/docs/job-queue.html +func (c Client) GetJob(ctx context.Context, id string) (*GetJobResponse, error) { endpoint := c.baseURL.JoinPath("queue-jobs", id) - req, err := http.NewRequest(http.MethodGet, endpoint.String(), nil) + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, err } @@ -89,39 +98,65 @@ func (c Client) GetJob(id string) (*GetJobResponse, error) { return &result, nil } -func (c Client) do(req *http.Request, data interface{}) error { - req.Header.Set("Content-Type", "application/vnd.api+json") - req.Header.Set("Accept", "application/vnd.variomedia.v1+json") - req.Header.Set("Authorization", "token "+c.apiToken) +func (c Client) do(req *http.Request, data any) error { + req.Header.Set(authorizationHeader, "token "+c.apiToken) resp, err := c.HTTPClient.Do(req) if err != nil { - return err + return errutils.NewHTTPDoError(req, err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode/100 != 2 { - all, _ := io.ReadAll(resp.Body) - - var e APIError - err = json.Unmarshal(all, &e) - if err != nil { - return fmt.Errorf("%d: %s", resp.StatusCode, string(all)) - } - - return e + return parseError(req, resp) } - content, err := io.ReadAll(resp.Body) + raw, err := io.ReadAll(resp.Body) if err != nil { - return err + return errutils.NewReadResponseError(req, resp.StatusCode, err) } - err = json.Unmarshal(content, data) + err = json.Unmarshal(raw, data) if err != nil { - return fmt.Errorf("%w: %s", err, string(content)) + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) } return nil } + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/vnd.variomedia.v1+json") + + if payload != nil { + req.Header.Set("Content-Type", "application/vnd.api+json") + } + + return req, nil +} + +func parseError(req *http.Request, resp *http.Response) error { + raw, _ := io.ReadAll(resp.Body) + + var errAPI APIError + err := json.Unmarshal(raw, &errAPI) + if err != nil { + return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw) + } + + return errAPI +} diff --git a/providers/dns/variomedia/internal/client_test.go b/providers/dns/variomedia/internal/client_test.go index a01e3037..c0017f24 100644 --- a/providers/dns/variomedia/internal/client_test.go +++ b/providers/dns/variomedia/internal/client_test.go @@ -1,6 +1,7 @@ package internal import ( + "context" "fmt" "io" "net/http" @@ -13,7 +14,7 @@ import ( "github.com/stretchr/testify/require" ) -func setup(t *testing.T) (*Client, *http.ServeMux) { +func setupTest(t *testing.T) (*Client, *http.ServeMux) { t.Helper() mux := http.NewServeMux() @@ -36,7 +37,7 @@ func mockHandler(method string, filename string) http.HandlerFunc { filename = "./fixtures/" + filename statusCode := http.StatusOK - if req.Header.Get("Authorization") != "token secret" { + if req.Header.Get(authorizationHeader) != "token secret" { statusCode = http.StatusUnauthorized filename = "./fixtures/error.json" } @@ -59,7 +60,7 @@ func mockHandler(method string, filename string) http.HandlerFunc { } func TestClient_CreateDNSRecord(t *testing.T) { - client, mux := setup(t) + client, mux := setupTest(t) mux.HandleFunc("/dns-records", mockHandler(http.MethodPost, "POST_dns-records.json")) @@ -71,7 +72,7 @@ func TestClient_CreateDNSRecord(t *testing.T) { TTL: 300, } - resp, err := client.CreateDNSRecord(record) + resp, err := client.CreateDNSRecord(context.Background(), record) require.NoError(t, err) expected := &CreateDNSRecordResponse{ @@ -107,11 +108,11 @@ func TestClient_CreateDNSRecord(t *testing.T) { } func TestClient_DeleteDNSRecord(t *testing.T) { - client, mux := setup(t) + client, mux := setupTest(t) mux.HandleFunc("/dns-records/test", mockHandler(http.MethodDelete, "DELETE_dns-records_pending.json")) - resp, err := client.DeleteDNSRecord("test") + resp, err := client.DeleteDNSRecord(context.Background(), "test") require.NoError(t, err) expected := &DeleteRecordResponse{ @@ -142,11 +143,11 @@ func TestClient_DeleteDNSRecord(t *testing.T) { } func TestClient_GetJob(t *testing.T) { - client, mux := setup(t) + client, mux := setupTest(t) mux.HandleFunc("/queue-jobs/test", mockHandler(http.MethodGet, "GET_queue-jobs.json")) - resp, err := client.GetJob("test") + resp, err := client.GetJob(context.Background(), "test") require.NoError(t, err) expected := &GetJobResponse{ diff --git a/providers/dns/variomedia/variomedia.go b/providers/dns/variomedia/variomedia.go index b4f21a56..e87220f4 100644 --- a/providers/dns/variomedia/variomedia.go +++ b/providers/dns/variomedia/variomedia.go @@ -2,6 +2,7 @@ package variomedia import ( + "context" "errors" "fmt" "net/http" @@ -16,8 +17,6 @@ import ( "github.com/go-acme/lego/v4/providers/dns/variomedia/internal" ) -const defaultTTL = 300 - // Environment variables names. const ( envNamespace = "VARIOMEDIA_" @@ -45,7 +44,7 @@ type Config struct { // NewDefaultConfig returns a default configuration for the DNSProvider. func NewDefaultConfig() *Config { return &Config{ - TTL: env.GetOrDefaultInt(EnvTTL, defaultTTL), + TTL: env.GetOrDefaultInt(EnvTTL, 300), PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, dns01.DefaultPropagationTimeout), PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, dns01.DefaultPollingInterval), SequenceInterval: env.GetOrDefaultSecond(EnvSequenceInterval, dns01.DefaultPropagationTimeout), @@ -83,10 +82,6 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("variomedia: missing credentials") } - if config.HTTPClient == nil { - config.HTTPClient = http.DefaultClient - } - client := internal.NewClient(config.APIToken) if config.HTTPClient != nil { @@ -118,7 +113,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("variomedia: %w", err) + return fmt.Errorf("variomedia: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) @@ -126,6 +121,8 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { return fmt.Errorf("variomedia: %w", err) } + ctx := context.Background() + record := internal.DNSRecord{ RecordType: "TXT", Name: subDomain, @@ -134,12 +131,12 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { TTL: d.config.TTL, } - cdrr, err := d.client.CreateDNSRecord(record) + cdrr, err := d.client.CreateDNSRecord(ctx, record) if err != nil { return fmt.Errorf("variomedia: %w", err) } - err = d.waitJob(domain, cdrr.Data.ID) + err = d.waitJob(ctx, domain, cdrr.Data.ID) if err != nil { return fmt.Errorf("variomedia: %w", err) } @@ -155,6 +152,8 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) + ctx := context.Background() + // get the record's unique ID from when we created it d.recordIDsMu.Lock() recordID, ok := d.recordIDs[token] @@ -163,12 +162,12 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return fmt.Errorf("variomedia: unknown record ID for '%s'", info.EffectiveFQDN) } - ddrr, err := d.client.DeleteDNSRecord(recordID) + ddrr, err := d.client.DeleteDNSRecord(ctx, recordID) if err != nil { return fmt.Errorf("variomedia: %w", err) } - err = d.waitJob(domain, ddrr.Data.ID) + err = d.waitJob(ctx, domain, ddrr.Data.ID) if err != nil { return fmt.Errorf("variomedia: %w", err) } @@ -176,9 +175,9 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return nil } -func (d *DNSProvider) waitJob(domain string, id string) error { +func (d *DNSProvider) waitJob(ctx context.Context, domain string, id string) error { return wait.For("variomedia: apply change on "+domain, d.config.PropagationTimeout, d.config.PollingInterval, func() (bool, error) { - result, err := d.client.GetJob(id) + result, err := d.client.GetJob(ctx, id) if err != nil { return false, err } diff --git a/providers/dns/vercel/internal/client.go b/providers/dns/vercel/internal/client.go index cf168edf..4bc59ba0 100644 --- a/providers/dns/vercel/internal/client.go +++ b/providers/dns/vercel/internal/client.go @@ -2,8 +2,8 @@ package internal import ( "bytes" + "context" "encoding/json" - "errors" "fmt" "io" "net/http" @@ -11,65 +11,49 @@ import ( "time" "github.com/go-acme/lego/v4/challenge/dns01" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" + "golang.org/x/oauth2" ) const defaultBaseURL = "https://api.vercel.com" // Client Vercel client. type Client struct { - authToken string - teamID string + teamID string + baseURL *url.URL - HTTPClient *http.Client + httpClient *http.Client } // NewClient creates a Client. -func NewClient(authToken string, teamID string) *Client { +func NewClient(hc *http.Client, teamID string) *Client { baseURL, _ := url.Parse(defaultBaseURL) + if hc == nil { + hc = &http.Client{Timeout: 10 * time.Second} + } + return &Client{ - authToken: authToken, teamID: teamID, baseURL: baseURL, - HTTPClient: &http.Client{Timeout: 10 * time.Second}, + httpClient: hc, } } // CreateRecord creates a DNS record. // https://vercel.com/docs/rest-api#endpoints/dns/create-a-dns-record -func (c *Client) CreateRecord(zone string, record Record) (*CreateRecordResponse, error) { +func (c *Client) CreateRecord(ctx context.Context, zone string, record Record) (*CreateRecordResponse, error) { endpoint := c.baseURL.JoinPath("v2", "domains", dns01.UnFqdn(zone), "records") - body, err := json.Marshal(record) + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record) if err != nil { return nil, err } - req, err := c.newRequest(http.MethodPost, endpoint.String(), bytes.NewReader(body)) - if err != nil { - return nil, err - } - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, err - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode >= http.StatusBadRequest { - return nil, readError(req, resp) - } - - content, err := io.ReadAll(resp.Body) - if err != nil { - return nil, errors.New(toUnreadableBodyMessage(req, content)) - } - - // Everything looks good; but we'll need the ID later to delete the record respData := &CreateRecordResponse{} - err = json.Unmarshal(content, respData) + err = c.do(req, respData) if err != nil { - return nil, fmt.Errorf("%w: %s", err, toUnreadableBodyMessage(req, content)) + return nil, err } return respData, nil @@ -77,60 +61,97 @@ func (c *Client) CreateRecord(zone string, record Record) (*CreateRecordResponse // DeleteRecord deletes a DNS record. // https://vercel.com/docs/rest-api#endpoints/dns/delete-a-dns-record -func (c *Client) DeleteRecord(zone string, recordID string) error { +func (c *Client) DeleteRecord(ctx context.Context, zone string, recordID string) error { endpoint := c.baseURL.JoinPath("v2", "domains", dns01.UnFqdn(zone), "records", recordID) - req, err := c.newRequest(http.MethodDelete, endpoint.String(), nil) + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) if err != nil { return err } - resp, err := c.HTTPClient.Do(req) - if err != nil { - return err - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode >= http.StatusBadRequest { - return readError(req, resp) - } - - return nil + return c.do(req, nil) } -func (c *Client) newRequest(method, reqURL string, body io.Reader) (*http.Request, error) { - req, err := http.NewRequest(method, reqURL, body) - if err != nil { - return nil, err - } - +func (c *Client) do(req *http.Request, result any) error { if c.teamID != "" { query := req.URL.Query() query.Add("teamId", c.teamID) req.URL.RawQuery = query.Encode() } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.authToken)) + resp, err := c.httpClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode/100 != 2 { + return parseError(req, resp) + } + + if result == nil { + return nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return nil +} + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } return req, nil } -func readError(req *http.Request, resp *http.Response) error { - content, err := io.ReadAll(resp.Body) +func parseError(req *http.Request, resp *http.Response) error { + raw, _ := io.ReadAll(resp.Body) + + var response APIErrorResponse + err := json.Unmarshal(raw, &response) if err != nil { - return errors.New(toUnreadableBodyMessage(req, content)) + return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw) } - var errInfo APIErrorResponse - err = json.Unmarshal(content, &errInfo) - if err != nil { - return fmt.Errorf("API Error unmarshaling error: %w: %s", err, toUnreadableBodyMessage(req, content)) + return fmt.Errorf("[status code: %d] %w", resp.StatusCode, response.Error) +} + +func OAuthStaticAccessToken(client *http.Client, accessToken string) *http.Client { + if client == nil { + client = &http.Client{Timeout: 5 * time.Second} } - return fmt.Errorf("HTTP %d: %w", resp.StatusCode, errInfo.Error) -} + client.Transport = &oauth2.Transport{ + Source: oauth2.StaticTokenSource(&oauth2.Token{AccessToken: accessToken}), + Base: client.Transport, + } -func toUnreadableBodyMessage(req *http.Request, rawBody []byte) string { - return fmt.Sprintf("the request %s sent a response with a body which is an invalid format: %q", req.URL, string(rawBody)) + return client } diff --git a/providers/dns/vercel/internal/client_test.go b/providers/dns/vercel/internal/client_test.go index 0d759c77..771349b2 100644 --- a/providers/dns/vercel/internal/client_test.go +++ b/providers/dns/vercel/internal/client_test.go @@ -1,6 +1,8 @@ package internal import ( + "bytes" + "context" "fmt" "io" "net/http" @@ -12,23 +14,21 @@ import ( "github.com/stretchr/testify/require" ) -func setup(t *testing.T) (*Client, *http.ServeMux) { +func setupTest(t *testing.T) (*Client, *http.ServeMux) { t.Helper() mux := http.NewServeMux() server := httptest.NewServer(mux) t.Cleanup(server.Close) - client := NewClient("secret", "123") - - client.HTTPClient = server.Client() + client := NewClient(OAuthStaticAccessToken(server.Client(), "secret"), "123") client.baseURL, _ = url.Parse(server.URL) return client, mux } func TestClient_CreateRecord(t *testing.T) { - client, mux := setup(t) + client, mux := setupTest(t) mux.HandleFunc("/v2/domains/example.com/records", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodPost { @@ -55,7 +55,7 @@ func TestClient_CreateRecord(t *testing.T) { } expectedReqBody := `{"name":"_acme-challenge.example.com.","type":"TXT","value":"w6uP8Tcg6K2QR905Rms8iXTlksL6OD1KOWBxTK7wxPI","ttl":60}` - assert.Equal(t, expectedReqBody, string(reqBody)) + assert.Equal(t, expectedReqBody, string(bytes.TrimSpace(reqBody))) rw.WriteHeader(http.StatusOK) _, err = fmt.Fprintf(rw, `{ @@ -75,7 +75,7 @@ func TestClient_CreateRecord(t *testing.T) { TTL: 60, } - resp, err := client.CreateRecord("example.com.", record) + resp, err := client.CreateRecord(context.Background(), "example.com.", record) require.NoError(t, err) expected := &CreateRecordResponse{ @@ -87,7 +87,7 @@ func TestClient_CreateRecord(t *testing.T) { } func TestClient_DeleteRecord(t *testing.T) { - client, mux := setup(t) + client, mux := setupTest(t) mux.HandleFunc("/v2/domains/example.com/records/1234567", func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodDelete { @@ -109,6 +109,6 @@ func TestClient_DeleteRecord(t *testing.T) { rw.WriteHeader(http.StatusOK) }) - err := client.DeleteRecord("example.com.", "1234567") + err := client.DeleteRecord(context.Background(), "example.com.", "1234567") require.NoError(t, err) } diff --git a/providers/dns/vercel/vercel.go b/providers/dns/vercel/vercel.go index 76ed7c96..efc401c4 100644 --- a/providers/dns/vercel/vercel.go +++ b/providers/dns/vercel/vercel.go @@ -2,6 +2,7 @@ package vercel import ( + "context" "errors" "fmt" "net/http" @@ -82,11 +83,7 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("vercel: credentials missing") } - client := internal.NewClient(config.AuthToken, config.TeamID) - - if config.HTTPClient != nil { - client.HTTPClient = config.HTTPClient - } + client := internal.NewClient(internal.OAuthStaticAccessToken(config.HTTPClient, config.AuthToken), config.TeamID) return &DNSProvider{ config: config, @@ -107,7 +104,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("vercel: could not determine zone for domain %q: %w", domain, err) + return fmt.Errorf("vercel: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } record := internal.Record{ @@ -117,7 +114,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { TTL: d.config.TTL, } - respData, err := d.client.CreateRecord(authZone, record) + respData, err := d.client.CreateRecord(context.Background(), authZone, record) if err != nil { return fmt.Errorf("vercel: %w", err) } @@ -135,7 +132,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("vercel: %w", err) + return fmt.Errorf("vercel: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } // get the record's unique ID from when we created it @@ -146,7 +143,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return fmt.Errorf("vercel: unknown record ID for '%s'", info.EffectiveFQDN) } - err = d.client.DeleteRecord(authZone, recordID) + err = d.client.DeleteRecord(context.Background(), authZone, recordID) if err != nil { return fmt.Errorf("vercel: %w", err) } diff --git a/providers/dns/versio/client.go b/providers/dns/versio/client.go deleted file mode 100644 index b7ca67da..00000000 --- a/providers/dns/versio/client.go +++ /dev/null @@ -1,119 +0,0 @@ -package versio - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "net/http" -) - -const defaultBaseURL = "https://www.versio.nl/api/v1/" - -type dnsRecordsResponse struct { - Record dnsRecord `json:"domainInfo"` -} - -type dnsRecord struct { - DNSRecords []record `json:"dns_records"` -} - -type record struct { - Type string `json:"type,omitempty"` - Name string `json:"name,omitempty"` - Value string `json:"value,omitempty"` - Priority int `json:"prio,omitempty"` - TTL int `json:"ttl,omitempty"` -} - -type dnsErrorResponse struct { - Error errorMessage `json:"error"` -} - -type errorMessage struct { - Code int `json:"code,omitempty"` - Message string `json:"message,omitempty"` -} - -func (d *DNSProvider) postDNSRecords(domain string, msg interface{}) error { - reqBody := &bytes.Buffer{} - err := json.NewEncoder(reqBody).Encode(msg) - if err != nil { - return err - } - - endpoint := d.config.BaseURL.JoinPath("domains", domain, "update") - - req, err := http.NewRequest(http.MethodPost, endpoint.String(), reqBody) - if err != nil { - return err - } - - return d.do(req, nil) -} - -func (d *DNSProvider) getDNSRecords(domain string) (*dnsRecordsResponse, error) { - endpoint := d.config.BaseURL.JoinPath("domains", domain) - - query := endpoint.Query() - query.Set("show_dns_records", "true") - endpoint.RawQuery = query.Encode() - - req, err := http.NewRequest(http.MethodGet, endpoint.String(), nil) - if err != nil { - return nil, err - } - - // we'll need all the dns_records to add the new TXT record - respData := &dnsRecordsResponse{} - err = d.do(req, respData) - if err != nil { - return nil, err - } - - return respData, nil -} - -func (d *DNSProvider) do(req *http.Request, result interface{}) error { - req.Header.Set("Content-Type", "application/json") - - if len(d.config.Username) > 0 && len(d.config.Password) > 0 { - req.SetBasicAuth(d.config.Username, d.config.Password) - } - - resp, err := d.config.HTTPClient.Do(req) - if resp != nil { - defer resp.Body.Close() - } - if err != nil { - return err - } - - if resp.StatusCode >= http.StatusBadRequest { - var body []byte - body, err = io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("%d: failed to read response body: %w", resp.StatusCode, err) - } - - respError := &dnsErrorResponse{} - err = json.Unmarshal(body, respError) - if err != nil { - return fmt.Errorf("%d: request failed: %s", resp.StatusCode, string(body)) - } - return fmt.Errorf("%d: request failed: %s", resp.StatusCode, respError.Error.Message) - } - - if result != nil { - content, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("request failed: %w", err) - } - - if err = json.Unmarshal(content, result); err != nil { - return fmt.Errorf("%w: %s", err, content) - } - } - - return nil -} diff --git a/providers/dns/versio/internal/client.go b/providers/dns/versio/internal/client.go new file mode 100644 index 00000000..6f70aacd --- /dev/null +++ b/providers/dns/versio/internal/client.go @@ -0,0 +1,149 @@ +package internal + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" +) + +// DefaultBaseURL default API endpoint. +const DefaultBaseURL = "https://www.versio.nl/api/v1/" + +// Client the API client for Versio DNS. +type Client struct { + username string + password string + + BaseURL *url.URL + HTTPClient *http.Client +} + +// NewClient creates a new Client. +func NewClient(username string, password string) *Client { + baseURL, _ := url.Parse(DefaultBaseURL) + + return &Client{ + username: username, + password: password, + BaseURL: baseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + } +} + +// UpdateDomain updates domain information. +// https://www.versio.nl/RESTapidoc/#api-Domains-Update +func (c *Client) UpdateDomain(ctx context.Context, domain string, msg *DomainInfo) (*DomainInfoResponse, error) { + endpoint := c.BaseURL.JoinPath("domains", domain, "update") + + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, msg) + if err != nil { + return nil, err + } + + respData := &DomainInfoResponse{} + err = c.do(req, respData) + if err != nil { + return nil, err + } + + return respData, nil +} + +// GetDomain gets domain information. +// https://www.versio.nl/RESTapidoc/#api-Domains-Domain +func (c *Client) GetDomain(ctx context.Context, domain string) (*DomainInfoResponse, error) { + endpoint := c.BaseURL.JoinPath("domains", domain) + + query := endpoint.Query() + query.Set("show_dns_records", "true") + endpoint.RawQuery = query.Encode() + + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, err + } + + respData := &DomainInfoResponse{} + err = c.do(req, respData) + if err != nil { + return nil, err + } + + return respData, nil +} + +func (c *Client) do(req *http.Request, result any) error { + if c.username != "" && c.password != "" { + req.SetBasicAuth(c.username, c.password) + } + + resp, err := c.HTTPClient.Do(req) + if resp != nil { + defer func() { _ = resp.Body.Close() }() + } + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + if resp.StatusCode/100 != 2 { + return parseError(req, resp) + } + + if result == nil { + return nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + if err = json.Unmarshal(raw, result); err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return nil +} + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil +} + +func parseError(req *http.Request, resp *http.Response) error { + raw, _ := io.ReadAll(resp.Body) + + response := &ErrorResponse{} + err := json.Unmarshal(raw, response) + if err != nil { + return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw) + } + + return fmt.Errorf("[status code: %d] %w", resp.StatusCode, response.Message) +} diff --git a/providers/dns/versio/internal/client_test.go b/providers/dns/versio/internal/client_test.go new file mode 100644 index 00000000..f1015d28 --- /dev/null +++ b/providers/dns/versio/internal/client_test.go @@ -0,0 +1,179 @@ +package internal + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setupTest(t *testing.T, pattern string, h http.HandlerFunc) *Client { + t.Helper() + + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + mux.HandleFunc(pattern, h) + + client := NewClient("user", "secret") + client.HTTPClient = server.Client() + client.BaseURL, _ = url.Parse(server.URL) + + return client +} + +func writeFixture(rw http.ResponseWriter, filename string) { + file, err := os.Open(filepath.Join("fixtures", filename)) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + defer func() { _ = file.Close() }() + + _, _ = io.Copy(rw, file) +} + +func TestClient_GetDomain(t *testing.T) { + client := setupTest(t, "/domains/example.com", func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet { + http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest) + return + } + + auth := req.Header.Get("Authorization") + if auth != "Basic dXNlcjpzZWNyZXQ=" { + http.Error(rw, "invalid credentials: "+auth, http.StatusUnauthorized) + return + } + + writeFixture(rw, "get-domain.json") + }) + + records, err := client.GetDomain(context.Background(), "example.com") + require.NoError(t, err) + + expected := &DomainInfoResponse{DomainInfo: DomainInfo{DNSRecords: []Record{ + {Type: "MX", Name: "example.com", Value: "fallback.axc.eu", Priority: 20, TTL: 3600}, + {Type: "TXT", Name: "example.com", Value: "\"v=spf1 a mx ip4:127.0.0.1 a:spf.spamexperts.axc.nl ~all\"", Priority: 0, TTL: 3600}, + {Type: "A", Name: "example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "ftp.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "localhost.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "pop.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "smtp.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "www.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "dev.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "_domainkey.domain.com.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "MX", Name: "example.com", Value: "spamfilter2.axc.eu", Priority: 0, TTL: 3600}, + {Type: "A", Name: "redirect.example.com", Value: "localhost", Priority: 10, TTL: 14400}, + }}} + + assert.Equal(t, expected, records) +} + +func TestClient_GetDomain_error(t *testing.T) { + client := setupTest(t, "/domains/example.com", func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet { + http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest) + return + } + + rw.WriteHeader(http.StatusUnauthorized) + + writeFixture(rw, "get-domain-error.json") + }) + + _, err := client.GetDomain(context.Background(), "example.com") + require.ErrorAs(t, err, &ErrorMessage{}) +} + +func TestClient_UpdateDomain(t *testing.T) { + client := setupTest(t, "/domains/example.com/update", func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest) + return + } + + auth := req.Header.Get("Authorization") + if auth != "Basic dXNlcjpzZWNyZXQ=" { + http.Error(rw, "invalid credentials: "+auth, http.StatusUnauthorized) + return + } + + writeFixture(rw, "update-domain.json") + }) + + msg := &DomainInfo{DNSRecords: []Record{ + {Type: "MX", Name: "example.com", Value: "fallback.axc.eu", Priority: 20, TTL: 3600}, + {Type: "TXT", Name: "example.com", Value: "\"v=spf1 a mx ip4:127.0.0.1 a:spf.spamexperts.axc.nl ~all\"", Priority: 0, TTL: 3600}, + {Type: "A", Name: "example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "ftp.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "localhost.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "pop.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "smtp.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "www.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "dev.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "_domainkey.domain.com.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "MX", Name: "example.com", Value: "spamfilter2.axc.eu", Priority: 0, TTL: 3600}, + {Type: "A", Name: "redirect.example.com", Value: "localhost", Priority: 10, TTL: 14400}, + }} + + records, err := client.UpdateDomain(context.Background(), "example.com", msg) + require.NoError(t, err) + + expected := &DomainInfoResponse{DomainInfo: DomainInfo{DNSRecords: []Record{ + {Type: "MX", Name: "example.com", Value: "fallback.axc.eu", Priority: 20, TTL: 3600}, + {Type: "TXT", Name: "example.com", Value: "\"v=spf1 a mx ip4:127.0.0.1 a:spf.spamexperts.axc.nl ~all\"", Priority: 0, TTL: 3600}, + {Type: "A", Name: "example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "ftp.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "localhost.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "pop.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "smtp.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "www.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "dev.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "_domainkey.domain.com.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "MX", Name: "example.com", Value: "spamfilter2.axc.eu", Priority: 0, TTL: 3600}, + {Type: "A", Name: "redirect.example.com", Value: "localhost", Priority: 10, TTL: 14400}, + }}} + + assert.Equal(t, expected, records) +} + +func TestClient_UpdateDomain_error(t *testing.T) { + client := setupTest(t, "/domains/example.com/update", func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest) + return + } + + rw.WriteHeader(http.StatusUnauthorized) + + writeFixture(rw, "update-domain.json") + }) + + msg := &DomainInfo{DNSRecords: []Record{ + {Type: "MX", Name: "example.com", Value: "fallback.axc.eu", Priority: 20, TTL: 3600}, + {Type: "TXT", Name: "example.com", Value: "\"v=spf1 a mx ip4:127.0.0.1 a:spf.spamexperts.axc.nl ~all\"", Priority: 0, TTL: 3600}, + {Type: "A", Name: "example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "ftp.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "localhost.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "pop.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "smtp.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "www.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "dev.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "A", Name: "_domainkey.domain.com.example.com", Value: "185.13.227.159", Priority: 0, TTL: 14400}, + {Type: "MX", Name: "example.com", Value: "spamfilter2.axc.eu", Priority: 0, TTL: 3600}, + {Type: "A", Name: "redirect.example.com", Value: "localhost", Priority: 10, TTL: 14400}, + }} + + _, err := client.UpdateDomain(context.Background(), "example.com", msg) + require.ErrorAs(t, err, &ErrorMessage{}) +} diff --git a/providers/dns/versio/internal/fixtures/README.md b/providers/dns/versio/internal/fixtures/README.md new file mode 100644 index 00000000..b9564b0e --- /dev/null +++ b/providers/dns/versio/internal/fixtures/README.md @@ -0,0 +1,5 @@ + +Note: the snippets from the API documentation are wrong: +invalid field type (ex: prio, TTL), and JSON format contains errors. + +So the files inside the fixtures have been partially adapted to fit the reality. diff --git a/providers/dns/versio/internal/fixtures/get-domain-error.json b/providers/dns/versio/internal/fixtures/get-domain-error.json new file mode 100644 index 00000000..3250fbc7 --- /dev/null +++ b/providers/dns/versio/internal/fixtures/get-domain-error.json @@ -0,0 +1,6 @@ +{ + "error": { + "code": 401, + "message": "You are not authorized to access this resource. Did you supply the correct credentials and have you IP (xxxx:xxxx:xxx:xxxx:xxxx:xxxx:xxxx:xxxx) whitelisted for API use?" + } +} diff --git a/providers/dns/versio/internal/fixtures/get-domain.json b/providers/dns/versio/internal/fixtures/get-domain.json new file mode 100644 index 00000000..40bb0b91 --- /dev/null +++ b/providers/dns/versio/internal/fixtures/get-domain.json @@ -0,0 +1,120 @@ +{ + "domainInfo": { + "domain": "example.com", + "status": "OK", + "expire-date": "2020-10-01", + "registrant_id": "4334", + "reseller_id": "3253", + "category_id": "674", + "dnstemplate_id": "674", + "lock": false, + "auto_renew": false, + "epp_code": "3fFerggEg", + "ns": [], + "dns_management": true, + "dns_records": [ + { + "type": "MX", + "name": "example.com", + "value": "fallback.axc.eu", + "prio": 20, + "ttl": 3600 + }, + { + "type": "TXT", + "name": "example.com", + "value": "\"v=spf1 a mx ip4:127.0.0.1 a:spf.spamexperts.axc.nl ~all\"", + "prio": 0, + "ttl": 3600 + }, + { + "type": "A", + "name": "example.com", + "value": "185.13.227.159", + "prio": 0, + "ttl": 14400 + }, + { + "type": "A", + "name": "ftp.example.com", + "value": "185.13.227.159", + "prio": 0, + "ttl": 14400 + }, + { + "type": "A", + "name": "localhost.example.com", + "value": "185.13.227.159", + "prio": 0, + "ttl": 14400 + }, + { + "type": "A", + "name": "pop.example.com", + "value": "185.13.227.159", + "prio": 0, + "ttl": 14400 + }, + { + "type": "A", + "name": "smtp.example.com", + "value": "185.13.227.159", + "prio": 0, + "ttl": 14400 + }, + { + "type": "A", + "name": "www.example.com", + "value": "185.13.227.159", + "prio": 0, + "ttl": 14400 + }, + { + "type": "A", + "name": "dev.example.com", + "value": "185.13.227.159", + "prio": 0, + "ttl": 14400 + }, + { + "type": "A", + "name": "_domainkey.domain.com.example.com", + "value": "185.13.227.159", + "prio": 0, + "ttl": 14400 + }, + { + "type": "MX", + "name": "example.com", + "value": "spamfilter2.axc.eu", + "prio": 0, + "ttl": 3600 + }, + { + "type": "A", + "name": "redirect.example.com", + "value": "localhost", + "prio": 10, + "ttl": 14400 + } + ], + "dns_redirections": [ + { + "from": "redirect.example.com", + "destination": "http:\/\/www.google.nl" + } + ], + "dnssec_keys": [ + { + "flags": 256, + "algorithm": 3, + "public_key": "AwEAAZKsuPDwO1+Usao2X1rgdFhdT3LAxy5cbRNFNEy1qsauwSIYov5SU4GlG6ylXIVQwHF5AWfbD7lcZzw1IlNegvaLnoirJjcYZhz4ppQU5+M/1hfH7aNZIsyz7AhHwX7gpOeUdGBXTiXQ3m7ksGccVQ79h7yl2fiBDCryBSf49vOTqo3dI7KZM48vmeqOxPth3ANMXzt6osHENGIchdGgIOVy5Y7AsVecL4V+lbn2t47fFfJ2O9PwuuDBzO0HCCT/mmYVsvZ33kgc7QPFKB3LojoXdHFHl1jCsC98phIVGzJR54H2xRohQvfC2WAXFEx+YNDW1yv7zQFrUVVMFwCe/E8=" + }, + { + "flags": 257, + "algorithm": 8, + "public_key": "AwEAAZKsuPDwO1+Usao2X1rgdFhdT3LAxy5cbRNFNEy1qsauwSIYov5SU4GlG6ylXIVQwHF5AWfbD7lcZzw1IlNegvaLnoirJjcYZhz4ppQU5+M/1hfH7aNZIsyz7AhHwX7gpOeUdGBXTiXQ3m7ksGccVQ79h7yl2fiBDCryBSf49vOTqo3dI7KZM48vmeqOxPth3ANMXzt6osHENGIchdGgIOVy5Y7AsVecL4V+lbn2t47fFfJ2O9PwuuDBzO0HCCT/mmYVsvZ33kgc7QPFKB3LojoXdHFHl1jCsC98phIVGzJR54H2xRohQvfC2WAXFEx+YNDW1yv7zQFrUVVMFwCe/E8=" + } + ] + } +} diff --git a/providers/dns/versio/internal/fixtures/update-domain-error.json b/providers/dns/versio/internal/fixtures/update-domain-error.json new file mode 100644 index 00000000..3250fbc7 --- /dev/null +++ b/providers/dns/versio/internal/fixtures/update-domain-error.json @@ -0,0 +1,6 @@ +{ + "error": { + "code": 401, + "message": "You are not authorized to access this resource. Did you supply the correct credentials and have you IP (xxxx:xxxx:xxx:xxxx:xxxx:xxxx:xxxx:xxxx) whitelisted for API use?" + } +} diff --git a/providers/dns/versio/internal/fixtures/update-domain.json b/providers/dns/versio/internal/fixtures/update-domain.json new file mode 100644 index 00000000..40bb0b91 --- /dev/null +++ b/providers/dns/versio/internal/fixtures/update-domain.json @@ -0,0 +1,120 @@ +{ + "domainInfo": { + "domain": "example.com", + "status": "OK", + "expire-date": "2020-10-01", + "registrant_id": "4334", + "reseller_id": "3253", + "category_id": "674", + "dnstemplate_id": "674", + "lock": false, + "auto_renew": false, + "epp_code": "3fFerggEg", + "ns": [], + "dns_management": true, + "dns_records": [ + { + "type": "MX", + "name": "example.com", + "value": "fallback.axc.eu", + "prio": 20, + "ttl": 3600 + }, + { + "type": "TXT", + "name": "example.com", + "value": "\"v=spf1 a mx ip4:127.0.0.1 a:spf.spamexperts.axc.nl ~all\"", + "prio": 0, + "ttl": 3600 + }, + { + "type": "A", + "name": "example.com", + "value": "185.13.227.159", + "prio": 0, + "ttl": 14400 + }, + { + "type": "A", + "name": "ftp.example.com", + "value": "185.13.227.159", + "prio": 0, + "ttl": 14400 + }, + { + "type": "A", + "name": "localhost.example.com", + "value": "185.13.227.159", + "prio": 0, + "ttl": 14400 + }, + { + "type": "A", + "name": "pop.example.com", + "value": "185.13.227.159", + "prio": 0, + "ttl": 14400 + }, + { + "type": "A", + "name": "smtp.example.com", + "value": "185.13.227.159", + "prio": 0, + "ttl": 14400 + }, + { + "type": "A", + "name": "www.example.com", + "value": "185.13.227.159", + "prio": 0, + "ttl": 14400 + }, + { + "type": "A", + "name": "dev.example.com", + "value": "185.13.227.159", + "prio": 0, + "ttl": 14400 + }, + { + "type": "A", + "name": "_domainkey.domain.com.example.com", + "value": "185.13.227.159", + "prio": 0, + "ttl": 14400 + }, + { + "type": "MX", + "name": "example.com", + "value": "spamfilter2.axc.eu", + "prio": 0, + "ttl": 3600 + }, + { + "type": "A", + "name": "redirect.example.com", + "value": "localhost", + "prio": 10, + "ttl": 14400 + } + ], + "dns_redirections": [ + { + "from": "redirect.example.com", + "destination": "http:\/\/www.google.nl" + } + ], + "dnssec_keys": [ + { + "flags": 256, + "algorithm": 3, + "public_key": "AwEAAZKsuPDwO1+Usao2X1rgdFhdT3LAxy5cbRNFNEy1qsauwSIYov5SU4GlG6ylXIVQwHF5AWfbD7lcZzw1IlNegvaLnoirJjcYZhz4ppQU5+M/1hfH7aNZIsyz7AhHwX7gpOeUdGBXTiXQ3m7ksGccVQ79h7yl2fiBDCryBSf49vOTqo3dI7KZM48vmeqOxPth3ANMXzt6osHENGIchdGgIOVy5Y7AsVecL4V+lbn2t47fFfJ2O9PwuuDBzO0HCCT/mmYVsvZ33kgc7QPFKB3LojoXdHFHl1jCsC98phIVGzJR54H2xRohQvfC2WAXFEx+YNDW1yv7zQFrUVVMFwCe/E8=" + }, + { + "flags": 257, + "algorithm": 8, + "public_key": "AwEAAZKsuPDwO1+Usao2X1rgdFhdT3LAxy5cbRNFNEy1qsauwSIYov5SU4GlG6ylXIVQwHF5AWfbD7lcZzw1IlNegvaLnoirJjcYZhz4ppQU5+M/1hfH7aNZIsyz7AhHwX7gpOeUdGBXTiXQ3m7ksGccVQ79h7yl2fiBDCryBSf49vOTqo3dI7KZM48vmeqOxPth3ANMXzt6osHENGIchdGgIOVy5Y7AsVecL4V+lbn2t47fFfJ2O9PwuuDBzO0HCCT/mmYVsvZ33kgc7QPFKB3LojoXdHFHl1jCsC98phIVGzJR54H2xRohQvfC2WAXFEx+YNDW1yv7zQFrUVVMFwCe/E8=" + } + ] + } +} diff --git a/providers/dns/versio/internal/types.go b/providers/dns/versio/internal/types.go new file mode 100644 index 00000000..44a5c53d --- /dev/null +++ b/providers/dns/versio/internal/types.go @@ -0,0 +1,32 @@ +package internal + +import "fmt" + +type DomainInfoResponse struct { + DomainInfo DomainInfo `json:"domainInfo"` +} + +type DomainInfo struct { + DNSRecords []Record `json:"dns_records"` +} + +type Record struct { + Type string `json:"type,omitempty"` + Name string `json:"name,omitempty"` + Value string `json:"value,omitempty"` + Priority int `json:"prio,omitempty"` + TTL int `json:"ttl,omitempty"` +} + +type ErrorResponse struct { + Message ErrorMessage `json:"error"` +} + +type ErrorMessage struct { + Code int `json:"code,omitempty"` + Message string `json:"message,omitempty"` +} + +func (e ErrorMessage) Error() string { + return fmt.Sprintf("%d: %s", e.Code, e.Message) +} diff --git a/providers/dns/versio/versio.go b/providers/dns/versio/versio.go index 6b84cf5d..bee7e526 100644 --- a/providers/dns/versio/versio.go +++ b/providers/dns/versio/versio.go @@ -2,6 +2,7 @@ package versio import ( + "context" "errors" "fmt" "net/http" @@ -11,6 +12,7 @@ import ( "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/platform/config/env" + "github.com/go-acme/lego/v4/providers/dns/versio/internal" ) // Environment variables names. @@ -42,9 +44,9 @@ type Config struct { // NewDefaultConfig returns a default configuration for the DNSProvider. func NewDefaultConfig() *Config { - baseURL, err := url.Parse(env.GetOrDefaultString(EnvEndpoint, defaultBaseURL)) + baseURL, err := url.Parse(env.GetOrDefaultString(EnvEndpoint, internal.DefaultBaseURL)) if err != nil { - baseURL, _ = url.Parse(defaultBaseURL) + baseURL, _ = url.Parse(internal.DefaultBaseURL) } return &Config{ @@ -61,7 +63,9 @@ func NewDefaultConfig() *Config { // DNSProvider implements the challenge.Provider interface. type DNSProvider struct { - config *Config + config *Config + client *internal.Client + dnsEntriesMu sync.Mutex } @@ -91,7 +95,17 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("versio: the versio password is missing") } - return &DNSProvider{config: config}, nil + client := internal.NewClient(config.Username, config.Password) + + if config.BaseURL != nil { + client.BaseURL = config.BaseURL + } + + if config.HTTPClient != nil { + client.HTTPClient = config.HTTPClient + } + + return &DNSProvider{config: config, client: client}, nil } // Timeout returns the timeout and interval to use when checking for DNS propagation. @@ -106,30 +120,35 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("versio: %w", err) + return fmt.Errorf("versio: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } // use mutex to prevent race condition from getDNSRecords until postDNSRecords d.dnsEntriesMu.Lock() defer d.dnsEntriesMu.Unlock() + ctx := context.Background() + zoneName := dns01.UnFqdn(authZone) - domains, err := d.getDNSRecords(zoneName) + + domains, err := d.client.GetDomain(ctx, zoneName) if err != nil { return fmt.Errorf("versio: %w", err) } - txtRecord := record{ + txtRecord := internal.Record{ Type: "TXT", Name: info.EffectiveFQDN, Value: `"` + info.Value + `"`, TTL: d.config.TTL, } - // Add new txtRercord to existing array of DNSRecords - msg := &domains.Record + + // Add new txtRecord to existing array of DNSRecords. + // We'll need all the dns_records to add a new TXT record. + msg := &domains.DomainInfo msg.DNSRecords = append(msg.DNSRecords, txtRecord) - err = d.postDNSRecords(zoneName, msg) + _, err = d.client.UpdateDomain(ctx, zoneName, msg) if err != nil { return fmt.Errorf("versio: %w", err) } @@ -142,28 +161,31 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("versio: %w", err) + return fmt.Errorf("versio: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } // use mutex to prevent race condition from getDNSRecords until postDNSRecords d.dnsEntriesMu.Lock() defer d.dnsEntriesMu.Unlock() + ctx := context.Background() + zoneName := dns01.UnFqdn(authZone) - domains, err := d.getDNSRecords(zoneName) + + domains, err := d.client.GetDomain(ctx, zoneName) if err != nil { return fmt.Errorf("versio: %w", err) } // loop through the existing entries and remove the specific record - msg := &dnsRecord{} - for _, e := range domains.Record.DNSRecords { + msg := &internal.DomainInfo{} + for _, e := range domains.DomainInfo.DNSRecords { if e.Name != info.EffectiveFQDN { msg.DNSRecords = append(msg.DNSRecords, e) } } - err = d.postDNSRecords(zoneName, msg) + _, err = d.client.UpdateDomain(ctx, zoneName, msg) if err != nil { return fmt.Errorf("versio: %w", err) } diff --git a/providers/dns/versio/versio_test.go b/providers/dns/versio/versio_test.go index 7144d43a..09040ab4 100644 --- a/providers/dns/versio/versio_test.go +++ b/providers/dns/versio/versio_test.go @@ -135,12 +135,12 @@ func TestDNSProvider_Present(t *testing.T) { { desc: "FailToFindZone", handler: muxFailToFindZone(), - expectedError: `versio: 401: request failed: ObjectDoesNotExist|Domain not found`, + expectedError: `versio: [status code: 401] 401: ObjectDoesNotExist|Domain not found`, }, { desc: "FailToCreateTXT", handler: muxFailToCreateTXT(), - expectedError: `versio: 400: request failed: ProcessError|DNS record invalid type _acme-challenge.example.eu. TST`, + expectedError: `versio: [status code: 400] 400: ProcessError|DNS record invalid type _acme-challenge.example.eu. TST`, }, } @@ -182,7 +182,7 @@ func TestDNSProvider_CleanUp(t *testing.T) { { desc: "FailToFindZone", handler: muxFailToFindZone(), - expectedError: `versio: 401: request failed: ObjectDoesNotExist|Domain not found`, + expectedError: `versio: [status code: 401] 401: ObjectDoesNotExist|Domain not found`, }, } diff --git a/providers/dns/vinyldns/mock_test.go b/providers/dns/vinyldns/mock_test.go index b7f1e241..54fd8e21 100644 --- a/providers/dns/vinyldns/mock_test.go +++ b/providers/dns/vinyldns/mock_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/require" ) -func setup(t *testing.T) (*http.ServeMux, *DNSProvider) { +func setupTest(t *testing.T) (*http.ServeMux, *DNSProvider) { t.Helper() mux := http.NewServeMux() diff --git a/providers/dns/vinyldns/vinyldns.go b/providers/dns/vinyldns/vinyldns.go index 56ce273d..dca58fb9 100644 --- a/providers/dns/vinyldns/vinyldns.go +++ b/providers/dns/vinyldns/vinyldns.go @@ -8,7 +8,6 @@ import ( "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/platform/config/env" - "github.com/go-acme/lego/v4/platform/wait" "github.com/vinyldns/go-vinyldns/vinyldns" ) @@ -172,122 +171,3 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { return d.config.PropagationTimeout, d.config.PollingInterval } - -func (d *DNSProvider) getRecordSet(fqdn string) (*vinyldns.RecordSet, error) { - zoneName, hostName, err := splitDomain(fqdn) - if err != nil { - return nil, err - } - - zone, err := d.client.ZoneByName(zoneName) - if err != nil { - return nil, err - } - - allRecordSets, err := d.client.RecordSetsListAll(zone.ID, vinyldns.ListFilter{NameFilter: hostName}) - if err != nil { - return nil, err - } - - var recordSets []vinyldns.RecordSet - for _, i := range allRecordSets { - if i.Type == "TXT" { - recordSets = append(recordSets, i) - } - } - - switch { - case len(recordSets) > 1: - return nil, fmt.Errorf("ambiguous recordset definition of %s", fqdn) - case len(recordSets) == 1: - return &recordSets[0], nil - default: - return nil, nil - } -} - -func (d *DNSProvider) createRecordSet(fqdn string, records []vinyldns.Record) error { - zoneName, hostName, err := splitDomain(fqdn) - if err != nil { - return err - } - - zone, err := d.client.ZoneByName(zoneName) - if err != nil { - return err - } - - recordSet := vinyldns.RecordSet{ - Name: hostName, - ZoneID: zone.ID, - Type: "TXT", - TTL: d.config.TTL, - Records: records, - } - - resp, err := d.client.RecordSetCreate(&recordSet) - if err != nil { - return err - } - - return d.waitForChanges("CreateRS", resp) -} - -func (d *DNSProvider) updateRecordSet(recordSet *vinyldns.RecordSet, newRecords []vinyldns.Record) error { - operation := "delete" - if len(recordSet.Records) < len(newRecords) { - operation = "add" - } - - recordSet.Records = newRecords - recordSet.TTL = d.config.TTL - - resp, err := d.client.RecordSetUpdate(recordSet) - if err != nil { - return err - } - - return d.waitForChanges("UpdateRS - "+operation, resp) -} - -func (d *DNSProvider) deleteRecordSet(existingRecord *vinyldns.RecordSet) error { - resp, err := d.client.RecordSetDelete(existingRecord.ZoneID, existingRecord.ID) - if err != nil { - return err - } - - return d.waitForChanges("DeleteRS", resp) -} - -func (d *DNSProvider) waitForChanges(operation string, resp *vinyldns.RecordSetUpdateResponse) error { - return wait.For("vinyldns", d.config.PropagationTimeout, d.config.PollingInterval, - func() (bool, error) { - change, err := d.client.RecordSetChange(resp.Zone.ID, resp.RecordSet.ID, resp.ChangeID) - if err != nil { - return false, fmt.Errorf("failed to query change status: %w", err) - } - - if change.Status == "Complete" { - return true, nil - } - - return false, fmt.Errorf("waiting operation: %s, zoneID: %s, recordsetID: %s, changeID: %s", - operation, resp.Zone.ID, resp.RecordSet.ID, resp.ChangeID) - }, - ) -} - -// splitDomain splits the hostname from the authoritative zone, and returns both parts. -func splitDomain(fqdn string) (string, string, error) { - zone, err := dns01.FindZoneByFqdn(fqdn) - if err != nil { - return "", "", err - } - - subDomain, err := dns01.ExtractSubDomain(fqdn, zone) - if err != nil { - return "", "", err - } - - return zone, subDomain, nil -} diff --git a/providers/dns/vinyldns/vinyldns_test.go b/providers/dns/vinyldns/vinyldns_test.go index 1f7e7d5a..c05b5a73 100644 --- a/providers/dns/vinyldns/vinyldns_test.go +++ b/providers/dns/vinyldns/vinyldns_test.go @@ -192,7 +192,7 @@ func TestDNSProvider_Present(t *testing.T) { t.Run(test.desc, func(t *testing.T) { t.Parallel() - mux, p := setup(t) + mux, p := setupTest(t) mux.Handle("/", test.handler) err := p.Present(targetDomain, "token"+test.keyAuth, test.keyAuth) @@ -202,7 +202,7 @@ func TestDNSProvider_Present(t *testing.T) { } func TestDNSProvider_CleanUp(t *testing.T) { - mux, p := setup(t) + mux, p := setupTest(t) mux.Handle("/", newMockRouter(). Get("/zones/name/"+targetRootDomain+".", http.StatusOK, "zoneByName"). diff --git a/providers/dns/vinyldns/wrapper.go b/providers/dns/vinyldns/wrapper.go new file mode 100644 index 00000000..34b93e9e --- /dev/null +++ b/providers/dns/vinyldns/wrapper.go @@ -0,0 +1,128 @@ +package vinyldns + +import ( + "fmt" + + "github.com/go-acme/lego/v4/challenge/dns01" + "github.com/go-acme/lego/v4/platform/wait" + "github.com/vinyldns/go-vinyldns/vinyldns" +) + +func (d *DNSProvider) getRecordSet(fqdn string) (*vinyldns.RecordSet, error) { + zoneName, hostName, err := splitDomain(fqdn) + if err != nil { + return nil, err + } + + zone, err := d.client.ZoneByName(zoneName) + if err != nil { + return nil, err + } + + allRecordSets, err := d.client.RecordSetsListAll(zone.ID, vinyldns.ListFilter{NameFilter: hostName}) + if err != nil { + return nil, err + } + + var recordSets []vinyldns.RecordSet + for _, i := range allRecordSets { + if i.Type == "TXT" { + recordSets = append(recordSets, i) + } + } + + switch { + case len(recordSets) > 1: + return nil, fmt.Errorf("ambiguous recordset definition of %s", fqdn) + case len(recordSets) == 1: + return &recordSets[0], nil + default: + return nil, nil + } +} + +func (d *DNSProvider) createRecordSet(fqdn string, records []vinyldns.Record) error { + zoneName, hostName, err := splitDomain(fqdn) + if err != nil { + return err + } + + zone, err := d.client.ZoneByName(zoneName) + if err != nil { + return err + } + + recordSet := vinyldns.RecordSet{ + Name: hostName, + ZoneID: zone.ID, + Type: "TXT", + TTL: d.config.TTL, + Records: records, + } + + resp, err := d.client.RecordSetCreate(&recordSet) + if err != nil { + return err + } + + return d.waitForChanges("CreateRS", resp) +} + +func (d *DNSProvider) updateRecordSet(recordSet *vinyldns.RecordSet, newRecords []vinyldns.Record) error { + operation := "delete" + if len(recordSet.Records) < len(newRecords) { + operation = "add" + } + + recordSet.Records = newRecords + recordSet.TTL = d.config.TTL + + resp, err := d.client.RecordSetUpdate(recordSet) + if err != nil { + return err + } + + return d.waitForChanges("UpdateRS - "+operation, resp) +} + +func (d *DNSProvider) deleteRecordSet(existingRecord *vinyldns.RecordSet) error { + resp, err := d.client.RecordSetDelete(existingRecord.ZoneID, existingRecord.ID) + if err != nil { + return err + } + + return d.waitForChanges("DeleteRS", resp) +} + +func (d *DNSProvider) waitForChanges(operation string, resp *vinyldns.RecordSetUpdateResponse) error { + return wait.For("vinyldns", d.config.PropagationTimeout, d.config.PollingInterval, + func() (bool, error) { + change, err := d.client.RecordSetChange(resp.Zone.ID, resp.RecordSet.ID, resp.ChangeID) + if err != nil { + return false, fmt.Errorf("failed to query change status: %w", err) + } + + if change.Status == "Complete" { + return true, nil + } + + return false, fmt.Errorf("waiting operation: %s, zoneID: %s, recordsetID: %s, changeID: %s", + operation, resp.Zone.ID, resp.RecordSet.ID, resp.ChangeID) + }, + ) +} + +// splitDomain splits the hostname from the authoritative zone, and returns both parts. +func splitDomain(fqdn string) (string, string, error) { + zone, err := dns01.FindZoneByFqdn(fqdn) + if err != nil { + return "", "", fmt.Errorf("could not find zone for FDQN %q: %w", fqdn, err) + } + + subDomain, err := dns01.ExtractSubDomain(fqdn, zone) + if err != nil { + return "", "", err + } + + return zone, subDomain, nil +} diff --git a/providers/dns/vkcloud/internal/client.go b/providers/dns/vkcloud/internal/client.go index a76293c3..5ced88d2 100644 --- a/providers/dns/vkcloud/internal/client.go +++ b/providers/dns/vkcloud/internal/client.go @@ -12,10 +12,10 @@ import ( // Client VK client. type Client struct { - baseURL *url.URL openstack *gophercloud.ProviderClient authOpts gophercloud.AuthOptions authenticated bool + baseURL *url.URL } // NewClient creates a Client. @@ -36,18 +36,18 @@ func NewClient(endpoint string, authOpts gophercloud.AuthOptions) (*Client, erro } return &Client{ - baseURL: baseURL, openstack: openstackClient, authOpts: authOpts, + baseURL: baseURL, }, nil } func (c *Client) ListZones() ([]DNSZone, error) { + endpoint := c.baseURL.JoinPath("/") + var zones []DNSZone opts := &gophercloud.RequestOpts{JSONResponse: &zones} - endpoint := c.baseURL.JoinPath("/") - err := c.request(http.MethodGet, endpoint, opts) if err != nil { return nil, err @@ -57,11 +57,11 @@ func (c *Client) ListZones() ([]DNSZone, error) { } func (c *Client) ListTXTRecords(zoneUUID string) ([]DNSTXTRecord, error) { + endpoint := c.baseURL.JoinPath(zoneUUID, "txt", "/") + var records []DNSTXTRecord opts := &gophercloud.RequestOpts{JSONResponse: &records} - endpoint := c.baseURL.JoinPath(zoneUUID, "txt", "/") - err := c.request(http.MethodGet, endpoint, opts) if err != nil { return nil, err @@ -71,13 +71,13 @@ func (c *Client) ListTXTRecords(zoneUUID string) ([]DNSTXTRecord, error) { } func (c *Client) CreateTXTRecord(zoneUUID string, record *DNSTXTRecord) error { + endpoint := c.baseURL.JoinPath(zoneUUID, "txt", "/") + opts := &gophercloud.RequestOpts{ JSONBody: record, JSONResponse: record, } - endpoint := c.baseURL.JoinPath(zoneUUID, "txt", "/") - return c.request(http.MethodPost, endpoint, opts) } diff --git a/providers/dns/vkcloud/vkcloud.go b/providers/dns/vkcloud/vkcloud.go index 9d1b03d9..775f4005 100644 --- a/providers/dns/vkcloud/vkcloud.go +++ b/providers/dns/vkcloud/vkcloud.go @@ -17,8 +17,6 @@ const ( defaultDNSEndpoint = "https://mcs.mail.ru/public-dns/v2/dns" ) -const defaultTTL = 60 - const defaultDomainName = "users" // Environment variables names. @@ -58,7 +56,7 @@ type Config struct { // NewDefaultConfig returns a default configuration for the DNSProvider. func NewDefaultConfig() *Config { return &Config{ - TTL: env.GetOrDefaultInt(EnvTTL, defaultTTL), + TTL: env.GetOrDefaultInt(EnvTTL, 60), PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, dns01.DefaultPropagationTimeout), PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, dns01.DefaultPollingInterval), } @@ -123,7 +121,7 @@ func (r *DNSProvider) Present(domain, _, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("vkcloud: %w", err) + return fmt.Errorf("vkcloud: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } authZone = dns01.UnFqdn(authZone) @@ -163,7 +161,7 @@ func (r *DNSProvider) CleanUp(domain, _, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("vkcloud: %w", err) + return fmt.Errorf("vkcloud: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } authZone = dns01.UnFqdn(authZone) diff --git a/providers/dns/vscale/vscale.go b/providers/dns/vscale/vscale.go index 31a2de04..fa81f58d 100644 --- a/providers/dns/vscale/vscale.go +++ b/providers/dns/vscale/vscale.go @@ -4,9 +4,11 @@ package vscale import ( + "context" "errors" "fmt" "net/http" + "net/url" "time" "github.com/go-acme/lego/v4/challenge/dns01" @@ -87,8 +89,15 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { } client := selectel.NewClient(config.Token) - client.BaseURL = config.BaseURL - client.HTTPClient = config.HTTPClient + if config.HTTPClient != nil { + client.HTTPClient = config.HTTPClient + } + + var err error + client.BaseURL, err = url.Parse(config.BaseURL) + if err != nil { + return nil, fmt.Errorf("vscale: %w", err) + } return &DNSProvider{config: config, client: client}, nil } @@ -103,8 +112,10 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) + ctx := context.Background() + // TODO(ldez) replace domain by FQDN to follow CNAME. - domainObj, err := d.client.GetDomainByName(domain) + domainObj, err := d.client.GetDomainByName(ctx, domain) if err != nil { return fmt.Errorf("vscale: %w", err) } @@ -115,7 +126,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { Name: info.EffectiveFQDN, Content: info.Value, } - _, err = d.client.AddRecord(domainObj.ID, txtRecord) + _, err = d.client.AddRecord(ctx, domainObj.ID, txtRecord) if err != nil { return fmt.Errorf("vscale: %w", err) } @@ -129,13 +140,15 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { recordName := dns01.UnFqdn(info.EffectiveFQDN) + ctx := context.Background() + // TODO(ldez) replace domain by FQDN to follow CNAME. - domainObj, err := d.client.GetDomainByName(domain) + domainObj, err := d.client.GetDomainByName(ctx, domain) if err != nil { return fmt.Errorf("vscale: %w", err) } - records, err := d.client.ListRecords(domainObj.ID) + records, err := d.client.ListRecords(ctx, domainObj.ID) if err != nil { return fmt.Errorf("vscale: %w", err) } @@ -144,7 +157,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { var lastErr error for _, record := range records { if record.Name == recordName { - err = d.client.DeleteRecord(domainObj.ID, record.ID) + err = d.client.DeleteRecord(ctx, domainObj.ID, record.ID) if err != nil { lastErr = fmt.Errorf("vscale: %w", err) } diff --git a/providers/dns/vultr/vultr.go b/providers/dns/vultr/vultr.go index c238e6cb..f63abc5f 100644 --- a/providers/dns/vultr/vultr.go +++ b/providers/dns/vultr/vultr.go @@ -78,17 +78,10 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("vultr: credentials missing") } - httpClient := config.HTTPClient - if httpClient == nil { - httpClient = &http.Client{ - Timeout: config.HTTPTimeout, - Transport: &oauth2.Transport{ - Source: oauth2.StaticTokenSource(&oauth2.Token{AccessToken: config.APIKey}), - }, - } - } + authClient := OAuthStaticAccessToken(config.HTTPClient, config.APIKey) + authClient.Timeout = config.HTTPTimeout - client := govultr.NewClient(httpClient) + client := govultr.NewClient(authClient) return &DNSProvider{client: client, config: config}, nil } @@ -228,3 +221,16 @@ func (d *DNSProvider) findTxtRecords(ctx context.Context, domain, fqdn string) ( return zoneDomain, records, nil } + +func OAuthStaticAccessToken(client *http.Client, accessToken string) *http.Client { + if client == nil { + client = &http.Client{Timeout: 5 * time.Second} + } + + client.Transport = &oauth2.Transport{ + Source: oauth2.StaticTokenSource(&oauth2.Token{AccessToken: accessToken}), + Base: client.Transport, + } + + return client +} diff --git a/providers/dns/websupport/internal/client.go b/providers/dns/websupport/internal/client.go index 4cb803a4..cc40e9de 100644 --- a/providers/dns/websupport/internal/client.go +++ b/providers/dns/websupport/internal/client.go @@ -2,6 +2,7 @@ package internal import ( "bytes" + "context" "crypto/hmac" "crypto/sha1" "encoding/hex" @@ -13,6 +14,8 @@ import ( "net/url" "strconv" "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) const defaultBaseURL = "https://rest.websupport.sk" @@ -22,9 +25,10 @@ const StatusSuccess = "success" // Client a Websupport DNS API client. type Client struct { - apiKey string - secretKey string - BaseURL string + apiKey string + secretKey string + + baseURL *url.URL HTTPClient *http.Client } @@ -34,23 +38,22 @@ func NewClient(apiKey, secretKey string) (*Client, error) { return nil, errors.New("credentials missing") } + baseURL, _ := url.Parse(defaultBaseURL) + return &Client{ apiKey: apiKey, secretKey: secretKey, - BaseURL: defaultBaseURL, + baseURL: baseURL, HTTPClient: &http.Client{Timeout: 10 * time.Second}, }, nil } // GetUser gets a user detail. // https://rest.websupport.sk/docs/v1.user#user -func (c *Client) GetUser(userID string) (*User, error) { - endpoint, err := url.JoinPath(c.BaseURL, "v1", "user", userID) - if err != nil { - return nil, fmt.Errorf("base url parsing: %w", err) - } +func (c *Client) GetUser(ctx context.Context, userID string) (*User, error) { + endpoint := c.baseURL.JoinPath("v1", "user", userID) - req, err := http.NewRequest(http.MethodGet, endpoint, http.NoBody) + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, fmt.Errorf("request payload: %w", err) } @@ -67,13 +70,10 @@ func (c *Client) GetUser(userID string) (*User, error) { // ListRecords lists all records. // https://rest.websupport.sk/docs/v1.zone#records -func (c *Client) ListRecords(domainName string) (*ListResponse, error) { - endpoint, err := url.JoinPath(c.BaseURL, "v1", "user", "self", "zone", domainName, "record") - if err != nil { - return nil, fmt.Errorf("base url parsing: %w", err) - } +func (c *Client) ListRecords(ctx context.Context, domainName string) (*ListResponse, error) { + endpoint := c.baseURL.JoinPath("v1", "user", "self", "zone", domainName, "record") - req, err := http.NewRequest(http.MethodGet, endpoint, http.NoBody) + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, fmt.Errorf("request payload: %w", err) } @@ -89,13 +89,10 @@ func (c *Client) ListRecords(domainName string) (*ListResponse, error) { } // GetRecords gets a DNS record. -func (c *Client) GetRecords(domainName string, recordID int) (*Record, error) { - endpoint, err := url.JoinPath(c.BaseURL, "v1", "user", "self", "zone", domainName, "record", strconv.Itoa(recordID)) - if err != nil { - return nil, fmt.Errorf("base url parsing: %w", err) - } +func (c *Client) GetRecords(ctx context.Context, domainName string, recordID int) (*Record, error) { + endpoint := c.baseURL.JoinPath("v1", "user", "self", "zone", domainName, "record", strconv.Itoa(recordID)) - req, err := http.NewRequest(http.MethodGet, endpoint, http.NoBody) + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, err } @@ -112,20 +109,12 @@ func (c *Client) GetRecords(domainName string, recordID int) (*Record, error) { // AddRecord adds a DNS record. // https://rest.websupport.sk/docs/v1.zone#post-record -func (c *Client) AddRecord(domainName string, record Record) (*Response, error) { - endpoint, err := url.JoinPath(c.BaseURL, "v1", "user", "self", "zone", domainName, "record") - if err != nil { - return nil, fmt.Errorf("base url parsing: %w", err) - } +func (c *Client) AddRecord(ctx context.Context, domainName string, record Record) (*Response, error) { + endpoint := c.baseURL.JoinPath("v1", "user", "self", "zone", domainName, "record") - payload, err := json.Marshal(record) + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record) if err != nil { - return nil, fmt.Errorf("request payload: %w", err) - } - - req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(payload)) - if err != nil { - return nil, err + return nil, fmt.Errorf("create request: %w", err) } result := &Response{} @@ -140,15 +129,12 @@ func (c *Client) AddRecord(domainName string, record Record) (*Response, error) // DeleteRecord deletes a DNS record. // https://rest.websupport.sk/docs/v1.zone#delete-record -func (c *Client) DeleteRecord(domainName string, recordID int) (*Response, error) { - endpoint, err := url.JoinPath(c.BaseURL, "v1", "user", "self", "zone", domainName, "record", strconv.Itoa(recordID)) - if err != nil { - return nil, fmt.Errorf("base url parsing: %w", err) - } +func (c *Client) DeleteRecord(ctx context.Context, domainName string, recordID int) (*Response, error) { + endpoint := c.baseURL.JoinPath("v1", "user", "self", "zone", domainName, "record", strconv.Itoa(recordID)) - req, err := http.NewRequest(http.MethodDelete, endpoint, http.NoBody) + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) if err != nil { - return nil, fmt.Errorf("request payload: %w", err) + return nil, fmt.Errorf("create request: %w", err) } result := &Response{} @@ -162,8 +148,6 @@ func (c *Client) DeleteRecord(domainName string, recordID int) (*Response, error } func (c *Client) do(req *http.Request, result any) error { - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json") req.Header.Set("Accept-Language", "en_us") location, err := time.LoadLocation("GMT") @@ -178,31 +162,23 @@ func (c *Client) do(req *http.Request, result any) error { resp, err := c.HTTPClient.Do(req) if err != nil { - return err + return errutils.NewHTTPDoError(req, err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode > http.StatusBadRequest { - all, _ := io.ReadAll(resp.Body) - - var e APIError - err = json.Unmarshal(all, &e) - if err != nil { - return fmt.Errorf("%d: %s", resp.StatusCode, string(all)) - } - - return &e + return parseError(req, resp) } - all, err := io.ReadAll(resp.Body) + raw, err := io.ReadAll(resp.Body) if err != nil { - return fmt.Errorf("read response body: %w", err) + return errutils.NewReadResponseError(req, resp.StatusCode, err) } - err = json.Unmarshal(all, result) + err = json.Unmarshal(raw, result) if err != nil { - return fmt.Errorf("unmarshal response body: %w", err) + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) } return nil @@ -230,3 +206,39 @@ func (c *Client) sign(req *http.Request, now time.Time) error { return nil } + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil +} + +func parseError(req *http.Request, resp *http.Response) error { + raw, _ := io.ReadAll(resp.Body) + + var errAPI APIError + err := json.Unmarshal(raw, &errAPI) + if err != nil { + return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw) + } + + return &errAPI +} diff --git a/providers/dns/websupport/internal/client_test.go b/providers/dns/websupport/internal/client_test.go index 312d74f3..9612f609 100644 --- a/providers/dns/websupport/internal/client_test.go +++ b/providers/dns/websupport/internal/client_test.go @@ -1,11 +1,13 @@ package internal import ( + "context" "encoding/json" "fmt" "io" "net/http" "net/http/httptest" + "net/url" "os" "testing" @@ -46,7 +48,7 @@ func setupTest(t *testing.T, method, pattern string, status int, file string) *C require.NoError(t, err) client.HTTPClient = server.Client() - client.BaseURL = server.URL + client.baseURL, _ = url.Parse(server.URL) return client } @@ -54,7 +56,7 @@ func setupTest(t *testing.T, method, pattern string, status int, file string) *C func TestClient_GetUser(t *testing.T) { client := setupTest(t, http.MethodGet, "/v1/user/self", http.StatusOK, "./fixtures/get-user.json") - user, err := client.GetUser("self") + user, err := client.GetUser(context.Background(), "self") require.NoError(t, err) expected := &User{ @@ -89,7 +91,7 @@ func TestClient_GetUser(t *testing.T) { func TestClient_ListRecords(t *testing.T) { client := setupTest(t, http.MethodGet, "/v1/user/self/zone/example.com/record", http.StatusOK, "./fixtures/list-records.json") - resp, err := client.ListRecords("example.com") + resp, err := client.ListRecords(context.Background(), "example.com") require.NoError(t, err) expected := &ListResponse{ @@ -124,7 +126,7 @@ func TestClient_AddRecord(t *testing.T) { TTL: 600, } - resp, err := client.AddRecord("example.com", record) + resp, err := client.AddRecord(context.Background(), "example.com", record) require.NoError(t, err) expected := &Response{ @@ -157,7 +159,7 @@ func TestClient_AddRecord_error_400(t *testing.T) { TTL: 600, } - resp, err := client.AddRecord("example.com", record) + resp, err := client.AddRecord(context.Background(), "example.com", record) require.NoError(t, err) assert.Equal(t, "error", resp.Status) @@ -190,7 +192,7 @@ func TestClient_AddRecord_error_404(t *testing.T) { TTL: 600, } - resp, err := client.AddRecord("example.com", record) + resp, err := client.AddRecord(context.Background(), "example.com", record) require.Error(t, err) assert.Nil(t, resp) @@ -199,7 +201,7 @@ func TestClient_AddRecord_error_404(t *testing.T) { func TestClient_DeleteRecord(t *testing.T) { client := setupTest(t, http.MethodDelete, "/v1/user/self/zone/example.com/record/123", http.StatusOK, "./fixtures/delete-record.json") - resp, err := client.DeleteRecord("example.com", 123) + resp, err := client.DeleteRecord(context.Background(), "example.com", 123) require.NoError(t, err) expected := &Response{ @@ -225,7 +227,7 @@ func TestClient_DeleteRecord(t *testing.T) { func TestClient_DeleteRecord_error(t *testing.T) { client := setupTest(t, http.MethodDelete, "/v1/user/self/zone/example.com/record/123", http.StatusNotFound, "./fixtures/delete-record-error-404.json") - resp, err := client.DeleteRecord("example.com", 123) + resp, err := client.DeleteRecord(context.Background(), "example.com", 123) require.Error(t, err) assert.Nil(t, resp) diff --git a/providers/dns/websupport/internal/types.go b/providers/dns/websupport/internal/types.go index cada90ce..0923282a 100644 --- a/providers/dns/websupport/internal/types.go +++ b/providers/dns/websupport/internal/types.go @@ -30,8 +30,8 @@ type Zone struct { } type Response struct { - Status string `json:"status"` Item *Record `json:"item"` + Status string `json:"status"` Errors json.RawMessage `json:"errors"` } @@ -72,13 +72,13 @@ func (e *Errors) Error() string { // ParseError extract error from Response. func ParseError(resp *Response) error { - var apiError Errors - err := json.Unmarshal(resp.Errors, &apiError) + var errAPI Errors + err := json.Unmarshal(resp.Errors, &errAPI) if err != nil { return err } - return &apiError + return &errAPI } type User struct { diff --git a/providers/dns/websupport/websupport.go b/providers/dns/websupport/websupport.go index c74cd142..a8b74010 100644 --- a/providers/dns/websupport/websupport.go +++ b/providers/dns/websupport/websupport.go @@ -2,6 +2,7 @@ package websupport import ( + "context" "errors" "fmt" "net/http" @@ -13,8 +14,6 @@ import ( "github.com/go-acme/lego/v4/providers/dns/websupport/internal" ) -const defaultTTL = 600 - // Environment variables names. const ( envNamespace = "WEBSUPPORT_" @@ -44,7 +43,7 @@ type Config struct { // NewDefaultConfig returns a default configuration for the DNSProvider. func NewDefaultConfig() *Config { return &Config{ - TTL: env.GetOrDefaultInt(EnvTTL, defaultTTL), + TTL: env.GetOrDefaultInt(EnvTTL, 600), PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, dns01.DefaultPropagationTimeout), PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, dns01.DefaultPollingInterval), SequenceInterval: env.GetOrDefaultSecond(EnvSequenceInterval, dns01.DefaultPropagationTimeout), @@ -106,7 +105,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("websupport: %w", err) + return fmt.Errorf("websupport: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) @@ -121,7 +120,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { TTL: d.config.TTL, } - resp, err := d.client.AddRecord(dns01.UnFqdn(authZone), record) + resp, err := d.client.AddRecord(context.Background(), dns01.UnFqdn(authZone), record) if err != nil { return fmt.Errorf("websupport: add record: %w", err) } @@ -148,7 +147,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("websupport: %w", err) + return fmt.Errorf("websupport: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } // gets the record's unique ID @@ -159,7 +158,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return fmt.Errorf("websupport: unknown record ID for '%s' '%s'", info.EffectiveFQDN, token) } - resp, err := d.client.DeleteRecord(dns01.UnFqdn(authZone), recordID) + resp, err := d.client.DeleteRecord(context.Background(), dns01.UnFqdn(authZone), recordID) if err != nil { return fmt.Errorf("websupport: delete record: %w", err) } diff --git a/providers/dns/wedos/internal/client.go b/providers/dns/wedos/internal/client.go index 2d0f94ac..defcabf6 100644 --- a/providers/dns/wedos/internal/client.go +++ b/providers/dns/wedos/internal/client.go @@ -11,61 +11,21 @@ import ( "time" "github.com/go-acme/lego/v4/challenge/dns01" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" ) const baseURL = "https://api.wedos.com/wapi/json" -const codeOk = 1000 - -const ( - commandPing = "ping" - commandDNSDomainCommit = "dns-domain-commit" - commandDNSRowsList = "dns-rows-list" - commandDNSRowDelete = "dns-row-delete" - commandDNSRowAdd = "dns-row-add" - commandDNSRowUpdate = "dns-row-update" -) - -type ResponsePayload struct { - Code int `json:"code,omitempty"` - Result string `json:"result,omitempty"` - Timestamp int `json:"timestamp,omitempty"` - SvTRID string `json:"svTRID,omitempty"` - Command string `json:"command,omitempty"` - Data json.RawMessage `json:"data"` -} - -type DNSRow struct { - ID string `json:"ID,omitempty"` - Name string `json:"name,omitempty"` - TTL json.Number `json:"ttl,omitempty" type:"integer"` - Type string `json:"rdtype,omitempty"` - Data string `json:"rdata"` -} - -type DNSRowRequest struct { - ID string `json:"row_id,omitempty"` - Domain string `json:"domain,omitempty"` - Name string `json:"name,omitempty"` - TTL json.Number `json:"ttl,omitempty" type:"integer"` - Type string `json:"type,omitempty"` - Data string `json:"rdata"` -} - -type APIRequest struct { - User string `json:"user,omitempty"` - Auth string `json:"auth,omitempty"` - Command string `json:"command,omitempty"` - Data interface{} `json:"data,omitempty"` -} - +// Client the API client for Webos. type Client struct { - username string - password string + username string + password string + baseURL string HTTPClient *http.Client } +// NewClient creates a new Client. func NewClient(username string, password string) *Client { return &Client{ username: username, @@ -78,25 +38,23 @@ func NewClient(username string, password string) *Client { // GetRecords lists all the records in the zone. // https://kb.wedos.com/en/wapi-api-interface/wapi-command-dns-rows-list/ func (c *Client) GetRecords(ctx context.Context, zone string) ([]DNSRow, error) { - payload := map[string]interface{}{ + payload := map[string]any{ "domain": dns01.UnFqdn(zone), } - resp, err := c.do(ctx, commandDNSRowsList, payload) + req, err := c.newRequest(ctx, commandDNSRowsList, payload) if err != nil { return nil, err } - arrayWrapper := struct { - Rows []DNSRow `json:"row"` - }{} + result := APIResponse[Rows]{} - err = json.Unmarshal(resp.Data, &arrayWrapper) + err = c.do(req, &result) if err != nil { return nil, err } - return arrayWrapper.Rows, err + return result.Response.Data.Rows, err } // AddRecord adds a record in the zone, either by updating existing records or creating new ones. @@ -118,12 +76,12 @@ func (c *Client) AddRecord(ctx context.Context, zone string, record DNSRow) erro payload.ID = record.ID } - _, err := c.do(ctx, cmd, payload) + req, err := c.newRequest(ctx, cmd, payload) if err != nil { return err } - return nil + return c.do(req, &APIResponse[json.RawMessage]{}) } // DeleteRecord deletes a record from the zone. @@ -135,40 +93,67 @@ func (c *Client) DeleteRecord(ctx context.Context, zone string, recordID string) ID: recordID, } - _, err := c.do(ctx, commandDNSRowDelete, payload) + req, err := c.newRequest(ctx, commandDNSRowDelete, payload) if err != nil { return err } - return nil + return c.do(req, &APIResponse[json.RawMessage]{}) } // Commit not really required, all changes will be auto-committed after 5 minutes. // https://kb.wedos.com/en/wapi-api-interface/wapi-command-dns-domain-commit/ func (c *Client) Commit(ctx context.Context, zone string) error { - payload := map[string]interface{}{ + payload := map[string]any{ "name": dns01.UnFqdn(zone), } - _, err := c.do(ctx, commandDNSDomainCommit, payload) + req, err := c.newRequest(ctx, commandDNSDomainCommit, payload) if err != nil { return err } - return nil + return c.do(req, &APIResponse[json.RawMessage]{}) } func (c *Client) Ping(ctx context.Context) error { - _, err := c.do(ctx, commandPing, nil) + req, err := c.newRequest(ctx, commandPing, nil) if err != nil { return err } - return nil + return c.do(req, &APIResponse[json.RawMessage]{}) } -func (c *Client) do(ctx context.Context, command string, payload interface{}) (*ResponsePayload, error) { - requestObject := map[string]interface{}{ +func (c *Client) do(req *http.Request, result Response) error { + resp, err := c.HTTPClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + if resp.StatusCode/100 != 2 { + return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + if result.GetCode() != codeOk { + return fmt.Errorf("error %d: %s", result.GetCode(), result.GetResult()) + } + + return err +} + +func (c *Client) newRequest(ctx context.Context, command string, payload any) (*http.Request, error) { + requestObject := map[string]any{ "request": APIRequest{ User: c.username, Auth: authToken(c.username, c.password), @@ -177,46 +162,20 @@ func (c *Client) do(ctx context.Context, command string, payload interface{}) (* }, } - jsonBytes, err := json.Marshal(requestObject) + object, err := json.Marshal(requestObject) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create request JSON body: %w", err) } form := url.Values{} - form.Add("request", string(jsonBytes)) + form.Add("request", string(object)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL, strings.NewReader(form.Encode())) if err != nil { - return nil, err + return nil, fmt.Errorf("unable to create request: %w", err) } + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, err - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - if resp.StatusCode/100 != 2 { - return nil, fmt.Errorf("API error, status code: %d", resp.StatusCode) - } - - responseWrapper := struct { - Response ResponsePayload `json:"response"` - }{} - - err = json.Unmarshal(body, &responseWrapper) - if err != nil { - return nil, err - } - - if responseWrapper.Response.Code != codeOk { - return nil, fmt.Errorf("wedos responded with error code %d = %s", responseWrapper.Response.Code, responseWrapper.Response.Result) - } - - return &responseWrapper.Response, err + return req, nil } diff --git a/providers/dns/wedos/internal/types.go b/providers/dns/wedos/internal/types.go new file mode 100644 index 00000000..bb8194b8 --- /dev/null +++ b/providers/dns/wedos/internal/types.go @@ -0,0 +1,68 @@ +package internal + +import "encoding/json" + +const codeOk = 1000 + +const ( + commandPing = "ping" + commandDNSDomainCommit = "dns-domain-commit" + commandDNSRowsList = "dns-rows-list" + commandDNSRowDelete = "dns-row-delete" + commandDNSRowAdd = "dns-row-add" + commandDNSRowUpdate = "dns-row-update" +) + +type Response interface { + GetCode() int + GetResult() string +} + +type APIResponse[D any] struct { + Response ResponsePayload[D] `json:"response"` +} + +func (a APIResponse[D]) GetCode() int { + return a.Response.Code +} + +func (a APIResponse[D]) GetResult() string { + return a.Response.Result +} + +type ResponsePayload[D any] struct { + Code int `json:"code,omitempty"` + Result string `json:"result,omitempty"` + Timestamp int `json:"timestamp,omitempty"` + SvTRID string `json:"svTRID,omitempty"` + Command string `json:"command,omitempty"` + Data D `json:"data"` +} + +type Rows struct { + Rows []DNSRow `json:"row"` +} + +type DNSRow struct { + ID string `json:"ID,omitempty"` + Name string `json:"name,omitempty"` + TTL json.Number `json:"ttl,omitempty"` + Type string `json:"rdtype,omitempty"` + Data string `json:"rdata"` +} + +type DNSRowRequest struct { + ID string `json:"row_id,omitempty"` + Domain string `json:"domain,omitempty"` + Name string `json:"name,omitempty"` + TTL json.Number `json:"ttl,omitempty"` + Type string `json:"type,omitempty"` + Data string `json:"rdata"` +} + +type APIRequest struct { + User string `json:"user,omitempty"` + Auth string `json:"auth,omitempty"` + Command string `json:"command,omitempty"` + Data any `json:"data,omitempty"` +} diff --git a/providers/dns/wedos/wedos.go b/providers/dns/wedos/wedos.go index 95e17302..8fffd3ad 100644 --- a/providers/dns/wedos/wedos.go +++ b/providers/dns/wedos/wedos.go @@ -108,7 +108,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("wedos: could not determine zone for domain %q: %w", domain, err) + return fmt.Errorf("wedos: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) @@ -156,7 +156,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("wedos: could not determine zone for domain %q: %w", domain, err) + return fmt.Errorf("wedos: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone) diff --git a/providers/dns/yandex/internal/client.go b/providers/dns/yandex/internal/client.go index acf98dc3..5d7e6bff 100644 --- a/providers/dns/yandex/internal/client.go +++ b/providers/dns/yandex/internal/client.go @@ -1,12 +1,17 @@ package internal import ( + "bytes" + "context" "encoding/json" "errors" "fmt" + "io" "net/http" - "strings" + "net/url" + "time" + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" "github.com/google/go-querystring/query" ) @@ -17,119 +22,139 @@ const successCode = "ok" const pddTokenHeader = "PddToken" type Client struct { + pddToken string + + baseURL *url.URL HTTPClient *http.Client - BaseURL string - pddToken string } func NewClient(pddToken string) (*Client, error) { if pddToken == "" { return nil, errors.New("PDD token is required") } + + baseURL, _ := url.Parse(defaultBaseURL) + return &Client{ - HTTPClient: &http.Client{}, - BaseURL: defaultBaseURL, pddToken: pddToken, + baseURL: baseURL, + HTTPClient: &http.Client{Timeout: 10 * time.Second}, }, nil } -func (c *Client) AddRecord(data Record) (*Record, error) { - resp, err := c.postForm("/add", data) +func (c *Client) AddRecord(ctx context.Context, payload Record) (*Record, error) { + endpoint := c.baseURL.JoinPath("add") + + req, err := newRequest(ctx, http.MethodPost, endpoint, payload) if err != nil { return nil, err } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("API response error: %d", resp.StatusCode) - } r := AddResponse{} - err = json.NewDecoder(resp.Body).Decode(&r) + err = c.do(req, &r) if err != nil { return nil, err } - if r.Success != successCode { - return nil, fmt.Errorf("error during record addition: %s", r.Error) - } - return r.Record, nil } -func (c *Client) RemoveRecord(data Record) (int, error) { - resp, err := c.postForm("/del", data) +func (c *Client) RemoveRecord(ctx context.Context, payload Record) (int, error) { + endpoint := c.baseURL.JoinPath("del") + + req, err := newRequest(ctx, http.MethodPost, endpoint, payload) if err != nil { return 0, err } - defer func() { _ = resp.Body.Close() }() r := RemoveResponse{} - err = json.NewDecoder(resp.Body).Decode(&r) + err = c.do(req, &r) if err != nil { return 0, err } - if r.Success != successCode { - return 0, fmt.Errorf("error during record addition: %s", r.Error) - } - return r.RecordID, nil } -func (c *Client) GetRecords(domain string) ([]Record, error) { - resp, err := c.get("/list", struct { +func (c *Client) GetRecords(ctx context.Context, domain string) ([]Record, error) { + endpoint := c.baseURL.JoinPath("list") + + payload := struct { Domain string `url:"domain"` - }{Domain: domain}) + }{Domain: domain} + + req, err := newRequest(ctx, http.MethodGet, endpoint, payload) if err != nil { return nil, err } - defer func() { _ = resp.Body.Close() }() r := ListResponse{} - err = json.NewDecoder(resp.Body).Decode(&r) + err = c.do(req, &r) if err != nil { return nil, err } - if r.Success != successCode { - return nil, fmt.Errorf("error during record addition: %s", r.Error) - } - return r.Records, nil } -func (c *Client) postForm(uri string, data interface{}) (*http.Response, error) { - values, err := query.Values(data) - if err != nil { - return nil, err - } - - req, err := http.NewRequest(http.MethodPost, c.BaseURL+uri, strings.NewReader(values.Encode())) - if err != nil { - return nil, err - } - - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") +func (c *Client) do(req *http.Request, result Response) error { req.Header.Set(pddTokenHeader, c.pddToken) - return c.HTTPClient.Do(req) -} - -func (c *Client) get(uri string, data interface{}) (*http.Response, error) { - req, err := http.NewRequest(http.MethodGet, c.BaseURL+uri, nil) + resp, err := c.HTTPClient.Do(req) if err != nil { - return nil, err + return errutils.NewHTTPDoError(req, err) } - req.Header.Set(pddTokenHeader, c.pddToken) + defer func() { _ = resp.Body.Close() }() - values, err := query.Values(data) + raw, err := io.ReadAll(resp.Body) if err != nil { - return nil, err + return errutils.NewReadResponseError(req, resp.StatusCode, err) } - req.URL.RawQuery = values.Encode() + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } - return c.HTTPClient.Do(req) + if result.GetSuccess() != successCode { + return fmt.Errorf("error during operation: %s %s", result.GetSuccess(), result.GetError()) + } + + return nil +} + +func newRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + switch method { + case http.MethodPost: + values, err := query.Values(payload) + if err != nil { + return nil, err + } + + buf.WriteString(values.Encode()) + + case http.MethodGet: + values, err := query.Values(payload) + if err != nil { + return nil, err + } + + endpoint.RawQuery = values.Encode() + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + if method == http.MethodPost { + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } + + return req, nil } diff --git a/providers/dns/yandex/internal/client_test.go b/providers/dns/yandex/internal/client_test.go index 7b7bb589..346b1ff6 100644 --- a/providers/dns/yandex/internal/client_test.go +++ b/providers/dns/yandex/internal/client_test.go @@ -1,16 +1,18 @@ package internal import ( + "context" "encoding/json" "net/http" "net/http/httptest" + "net/url" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func setupTest(t *testing.T) (*http.ServeMux, *Client) { +func setupTest(t *testing.T) (*Client, *http.ServeMux) { t.Helper() mux := http.NewServeMux() @@ -21,9 +23,9 @@ func setupTest(t *testing.T) (*http.ServeMux, *Client) { require.NoError(t, err) client.HTTPClient = server.Client() - client.BaseURL = server.URL + client.baseURL, _ = url.Parse(server.URL) - return mux, client + return client, mux } func TestAddRecord(t *testing.T) { @@ -58,7 +60,9 @@ func TestAddRecord(t *testing.T) { Content: "txtTXTtxtTXTtxtTXT", TTL: 300, }, - Success: "ok", + BaseResponse: BaseResponse{ + Success: "ok", + }, } err = json.NewEncoder(w).Encode(response) @@ -90,9 +94,11 @@ func TestAddRecord(t *testing.T) { assert.Equal(t, `content=txtTXTtxtTXTtxtTXT&domain=example.com&subdomain=foo&ttl=300&type=TXT`, r.PostForm.Encode()) response := AddResponse{ - Domain: "example.com", - Success: "error", - Error: "bad things", + Domain: "example.com", + BaseResponse: BaseResponse{ + Success: "error", + Error: "bad things", + }, } err = json.NewEncoder(w).Encode(response) @@ -114,11 +120,11 @@ func TestAddRecord(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/add", test.handler) - record, err := client.AddRecord(test.data) + record, err := client.AddRecord(context.Background(), test.data) if test.expectError { require.Error(t, err) require.Nil(t, record) @@ -154,7 +160,9 @@ func TestRemoveRecord(t *testing.T) { response := RemoveResponse{ Domain: "example.com", RecordID: 6, - Success: "ok", + BaseResponse: BaseResponse{ + Success: "ok", + }, } err = json.NewEncoder(w).Encode(response) @@ -185,8 +193,10 @@ func TestRemoveRecord(t *testing.T) { response := RemoveResponse{ Domain: "example.com", RecordID: 6, - Success: "error", - Error: "bad things", + BaseResponse: BaseResponse{ + Success: "error", + Error: "bad things", + }, } err = json.NewEncoder(w).Encode(response) @@ -205,11 +215,11 @@ func TestRemoveRecord(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/del", test.handler) - id, err := client.RemoveRecord(test.data) + id, err := client.RemoveRecord(context.Background(), test.data) if test.expectError { require.Error(t, err) require.Equal(t, 0, id) @@ -258,7 +268,9 @@ func TestGetRecords(t *testing.T) { TTL: 300, }, }, - Success: "ok", + BaseResponse: BaseResponse{ + Success: "ok", + }, } err := json.NewEncoder(w).Encode(response) @@ -278,9 +290,11 @@ func TestGetRecords(t *testing.T) { assert.Equal(t, "domain=example.com", r.URL.RawQuery) response := ListResponse{ - Domain: "example.com", - Success: "error", - Error: "bad things", + Domain: "example.com", + BaseResponse: BaseResponse{ + Success: "error", + Error: "bad things", + }, } err := json.NewEncoder(w).Encode(response) @@ -298,11 +312,11 @@ func TestGetRecords(t *testing.T) { test := test t.Run(test.desc, func(t *testing.T) { t.Parallel() - mux, client := setupTest(t) + client, mux := setupTest(t) mux.HandleFunc("/list", test.handler) - records, err := client.GetRecords(test.domain) + records, err := client.GetRecords(context.Background(), test.domain) if test.expectError { require.Error(t, err) require.Empty(t, records) diff --git a/providers/dns/yandex/internal/types.go b/providers/dns/yandex/internal/types.go index 3432d7eb..ed1873ce 100644 --- a/providers/dns/yandex/internal/types.go +++ b/providers/dns/yandex/internal/types.go @@ -10,23 +10,38 @@ type Record struct { Content string `json:"content,omitempty" url:"content,omitempty"` } +type Response interface { + GetSuccess() string + GetError() string +} + +type BaseResponse struct { + Success string `json:"success"` + Error string `json:"error,omitempty"` +} + +func (r BaseResponse) GetSuccess() string { + return r.Success +} + +func (r BaseResponse) GetError() string { + return r.Error +} + type AddResponse struct { - Domain string `json:"domain,omitempty"` - Record *Record `json:"record,omitempty"` - Success string `json:"success"` - Error string `json:"error,omitempty"` + BaseResponse + Domain string `json:"domain,omitempty"` + Record *Record `json:"record,omitempty"` } type RemoveResponse struct { + BaseResponse Domain string `json:"domain,omitempty"` RecordID int `json:"record_id,omitempty"` - Success string `json:"success"` - Error string `json:"error,omitempty"` } type ListResponse struct { + BaseResponse Domain string `json:"domain,omitempty"` Records []Record `json:"records,omitempty"` - Success string `json:"success"` - Error string `json:"error,omitempty"` } diff --git a/providers/dns/yandex/yandex.go b/providers/dns/yandex/yandex.go index 3b1ac04a..e747be33 100644 --- a/providers/dns/yandex/yandex.go +++ b/providers/dns/yandex/yandex.go @@ -2,6 +2,7 @@ package yandex import ( + "context" "errors" "fmt" "net/http" @@ -13,8 +14,6 @@ import ( "github.com/miekg/dns" ) -const defaultTTL = 21600 - // Environment variables names. const ( envNamespace = "YANDEX_" @@ -39,7 +38,7 @@ type Config struct { // NewDefaultConfig returns a default configuration for the DNSProvider. func NewDefaultConfig() *Config { return &Config{ - TTL: env.GetOrDefaultInt(EnvTTL, defaultTTL), + TTL: env.GetOrDefaultInt(EnvTTL, 21600), PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, dns01.DefaultPropagationTimeout), PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, dns01.DefaultPollingInterval), HTTPClient: &http.Client{ @@ -106,7 +105,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { Content: info.Value, } - _, err = d.client.AddRecord(data) + _, err = d.client.AddRecord(context.Background(), data) if err != nil { return fmt.Errorf("yandex: %w", err) } @@ -123,7 +122,9 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return fmt.Errorf("yandex: %w", err) } - records, err := d.client.GetRecords(rootDomain) + ctx := context.Background() + + records, err := d.client.GetRecords(ctx, rootDomain) if err != nil { return fmt.Errorf("yandex: %w", err) } @@ -146,7 +147,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { Domain: rootDomain, } - _, err = d.client.RemoveRecord(data) + _, err = d.client.RemoveRecord(ctx, data) if err != nil { return fmt.Errorf("yandex: %w", err) } diff --git a/providers/dns/yandexcloud/yandexcloud.go b/providers/dns/yandexcloud/yandexcloud.go index 22f77570..f30aef76 100644 --- a/providers/dns/yandexcloud/yandexcloud.go +++ b/providers/dns/yandexcloud/yandexcloud.go @@ -17,8 +17,6 @@ import ( "github.com/yandex-cloud/go-sdk/iamkey" ) -const defaultTTL = 60 - // Environment variables names. const ( envNamespace = "YANDEX_CLOUD_" @@ -44,7 +42,7 @@ type Config struct { // NewDefaultConfig returns a default configuration for the DNSProvider. func NewDefaultConfig() *Config { return &Config{ - TTL: env.GetOrDefaultInt(EnvTTL, defaultTTL), + TTL: env.GetOrDefaultInt(EnvTTL, 60), PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, dns01.DefaultPropagationTimeout), PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, dns01.DefaultPollingInterval), } @@ -106,7 +104,7 @@ func (r *DNSProvider) Present(domain, _, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("yandexcloud: %w", err) + return fmt.Errorf("yandexcloud: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } ctx := context.Background() @@ -147,7 +145,7 @@ func (r *DNSProvider) CleanUp(domain, _, keyAuth string) error { authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("yandexcloud: %w", err) + return fmt.Errorf("yandexcloud: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } ctx := context.Background() diff --git a/providers/dns/zoneee/client.go b/providers/dns/zoneee/client.go deleted file mode 100644 index a2f340d6..00000000 --- a/providers/dns/zoneee/client.go +++ /dev/null @@ -1,122 +0,0 @@ -package zoneee - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "net/http" -) - -const defaultEndpoint = "https://api.zone.eu/v2/dns/" - -type txtRecord struct { - // Identifier (identificator) - ID string `json:"id,omitempty"` - // Hostname - Name string `json:"name"` - // TXT content value - Destination string `json:"destination"` - // Can this record be deleted - Delete bool `json:"delete,omitempty"` - // Can this record be modified - Modify bool `json:"modify,omitempty"` - // API url to get this entity - ResourceURL string `json:"resource_url,omitempty"` -} - -func (d *DNSProvider) addTxtRecord(domain string, record txtRecord) ([]txtRecord, error) { - reqBody := &bytes.Buffer{} - if err := json.NewEncoder(reqBody).Encode(record); err != nil { - return nil, err - } - - endpoint := d.config.Endpoint.JoinPath(domain, "txt") - - req, err := http.NewRequest(http.MethodPost, endpoint.String(), reqBody) - if err != nil { - return nil, err - } - - var resp []txtRecord - if err := d.sendRequest(req, &resp); err != nil { - return nil, err - } - return resp, nil -} - -func (d *DNSProvider) getTxtRecords(domain string) ([]txtRecord, error) { - endpoint := d.config.Endpoint.JoinPath(domain, "txt") - - req, err := http.NewRequest(http.MethodGet, endpoint.String(), http.NoBody) - if err != nil { - return nil, err - } - - var resp []txtRecord - if err := d.sendRequest(req, &resp); err != nil { - return nil, err - } - return resp, nil -} - -func (d *DNSProvider) removeTxtRecord(domain, id string) error { - endpoint := d.config.Endpoint.JoinPath(domain, "txt", id) - - req, err := http.NewRequest(http.MethodDelete, endpoint.String(), http.NoBody) - if err != nil { - return err - } - - return d.sendRequest(req, nil) -} - -func (d *DNSProvider) sendRequest(req *http.Request, result interface{}) error { - req.Header.Set("Content-Type", "application/json") - req.SetBasicAuth(d.config.Username, d.config.APIKey) - - resp, err := d.config.HTTPClient.Do(req) - if err != nil { - return err - } - - if err = checkResponse(resp); err != nil { - return err - } - - defer resp.Body.Close() - - if result == nil { - return nil - } - - raw, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - - err = json.Unmarshal(raw, result) - if err != nil { - return fmt.Errorf("unmarshaling %T error [status code=%d]: %w: %s", result, resp.StatusCode, err, string(raw)) - } - return err -} - -func checkResponse(resp *http.Response) error { - if resp.StatusCode < http.StatusBadRequest { - return nil - } - - if resp.Body == nil { - return fmt.Errorf("response body is nil, status code=%d", resp.StatusCode) - } - - defer resp.Body.Close() - - raw, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("unable to read body: status code=%d, error=%w", resp.StatusCode, err) - } - - return fmt.Errorf("status code=%d: %s", resp.StatusCode, string(raw)) -} diff --git a/providers/dns/zoneee/internal/client.go b/providers/dns/zoneee/internal/client.go new file mode 100644 index 00000000..e4463b83 --- /dev/null +++ b/providers/dns/zoneee/internal/client.go @@ -0,0 +1,142 @@ +package internal + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "time" + + "github.com/go-acme/lego/v4/providers/dns/internal/errutils" +) + +// DefaultEndpoint the default API endpoint. +const DefaultEndpoint = "https://api.zone.eu/v2/" + +// Client the API client for Zoneee. +type Client struct { + username string + apiKey string + + BaseURL *url.URL + HTTPClient *http.Client +} + +// NewClient creates a new Client. +func NewClient(username string, apiKey string) *Client { + baseURL, _ := url.Parse(DefaultEndpoint) + + return &Client{ + username: username, + apiKey: apiKey, + BaseURL: baseURL, + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + } +} + +// GetTxtRecords get TXT records. +// https://api.zone.eu/v2#operation/getdnstxtrecords +func (c *Client) GetTxtRecords(ctx context.Context, domain string) ([]TXTRecord, error) { + endpoint := c.BaseURL.JoinPath("dns", domain, "txt") + + req, err := newJSONRequest(ctx, http.MethodGet, endpoint, http.NoBody) + if err != nil { + return nil, err + } + + var records []TXTRecord + if err := c.do(req, &records); err != nil { + return nil, err + } + + return records, nil +} + +// AddTxtRecord creates a TXT records. +// https://api.zone.eu/v2#operation/creatednstxtrecord +func (c *Client) AddTxtRecord(ctx context.Context, domain string, record TXTRecord) ([]TXTRecord, error) { + endpoint := c.BaseURL.JoinPath("dns", domain, "txt") + + req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record) + if err != nil { + return nil, err + } + + var records []TXTRecord + if err := c.do(req, &records); err != nil { + return nil, err + } + + return records, nil +} + +// RemoveTxtRecord deletes a TXT record. +// https://api.zone.eu/v2#operation/deletednstxtrecord +func (c *Client) RemoveTxtRecord(ctx context.Context, domain, id string) error { + endpoint := c.BaseURL.JoinPath("dns", domain, "txt", id) + + req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil) + if err != nil { + return err + } + + return c.do(req, nil) +} + +func (c *Client) do(req *http.Request, result any) error { + req.SetBasicAuth(c.username, c.apiKey) + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return errutils.NewHTTPDoError(req, err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode/100 != 2 { + return errutils.NewUnexpectedResponseStatusCodeError(req, resp) + } + + if result == nil { + return nil + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return errutils.NewReadResponseError(req, resp.StatusCode, err) + } + + err = json.Unmarshal(raw, result) + if err != nil { + return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err) + } + + return nil +} + +func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) { + buf := new(bytes.Buffer) + + if payload != nil { + err := json.NewEncoder(buf).Encode(payload) + if err != nil { + return nil, fmt.Errorf("failed to create request JSON body: %w", err) + } + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf) + if err != nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + return req, nil +} diff --git a/providers/dns/zoneee/internal/client_test.go b/providers/dns/zoneee/internal/client_test.go new file mode 100644 index 00000000..9e53117a --- /dev/null +++ b/providers/dns/zoneee/internal/client_test.go @@ -0,0 +1,90 @@ +package internal + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setupTest(t *testing.T, method, pattern string, status int, file string) *Client { + t.Helper() + + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + mux.HandleFunc(pattern, func(rw http.ResponseWriter, req *http.Request) { + if req.Method != method { + http.Error(rw, fmt.Sprintf("unsupported method %s", req.Method), http.StatusBadRequest) + return + } + + if file == "" { + rw.WriteHeader(status) + return + } + + open, err := os.Open(filepath.Join("fixtures", file)) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + + defer func() { _ = open.Close() }() + + rw.WriteHeader(status) + _, err = io.Copy(rw, open) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + }) + + client := NewClient("user", "secret") + client.HTTPClient = server.Client() + client.BaseURL, _ = url.Parse(server.URL) + + return client +} + +func TestClient_GetTxtRecords(t *testing.T) { + client := setupTest(t, http.MethodGet, "/dns/example.com/txt", http.StatusOK, "get-txt-records.json") + + records, err := client.GetTxtRecords(context.Background(), "example.com") + require.NoError(t, err) + + expected := []TXTRecord{ + {ID: "123", Name: "prefix.example.com", Destination: "server.example.com", Delete: true, Modify: true, ResourceURL: "string"}, + } + + assert.Equal(t, expected, records) +} + +func TestClient_AddTxtRecord(t *testing.T) { + client := setupTest(t, http.MethodPost, "/dns/example.com/txt", http.StatusCreated, "create-txt-record.json") + + records, err := client.AddTxtRecord(context.Background(), "example.com", TXTRecord{Name: "prefix.example.com", Destination: "server.example.com"}) + require.NoError(t, err) + + expected := []TXTRecord{ + {ID: "123", Name: "prefix.example.com", Destination: "server.example.com", Delete: true, Modify: true, ResourceURL: "string"}, + } + + assert.Equal(t, expected, records) +} + +func TestClient_RemoveTxtRecord(t *testing.T) { + client := setupTest(t, http.MethodDelete, "/dns/example.com/txt/123", http.StatusNoContent, "") + + err := client.RemoveTxtRecord(context.Background(), "example.com", "123") + require.NoError(t, err) +} diff --git a/providers/dns/zoneee/internal/fixtures/create-txt-record.json b/providers/dns/zoneee/internal/fixtures/create-txt-record.json new file mode 100644 index 00000000..53af789a --- /dev/null +++ b/providers/dns/zoneee/internal/fixtures/create-txt-record.json @@ -0,0 +1,10 @@ +[ + { + "resource_url": "string", + "destination": "server.example.com", + "id": "123", + "name": "prefix.example.com", + "delete": true, + "modify": true + } +] diff --git a/providers/dns/zoneee/internal/fixtures/get-txt-records.json b/providers/dns/zoneee/internal/fixtures/get-txt-records.json new file mode 100644 index 00000000..53af789a --- /dev/null +++ b/providers/dns/zoneee/internal/fixtures/get-txt-records.json @@ -0,0 +1,10 @@ +[ + { + "resource_url": "string", + "destination": "server.example.com", + "id": "123", + "name": "prefix.example.com", + "delete": true, + "modify": true + } +] diff --git a/providers/dns/zoneee/internal/types.go b/providers/dns/zoneee/internal/types.go new file mode 100644 index 00000000..f086a85c --- /dev/null +++ b/providers/dns/zoneee/internal/types.go @@ -0,0 +1,16 @@ +package internal + +type TXTRecord struct { + // Identifier (identificator) + ID string `json:"id,omitempty"` + // Hostname + Name string `json:"name"` + // TXT content value + Destination string `json:"destination"` + // Can this record be deleted + Delete bool `json:"delete,omitempty"` + // Can this record be modified + Modify bool `json:"modify,omitempty"` + // API url to get this entity + ResourceURL string `json:"resource_url,omitempty"` +} diff --git a/providers/dns/zoneee/zoneee.go b/providers/dns/zoneee/zoneee.go index 2d764bfb..b0f0e5ab 100644 --- a/providers/dns/zoneee/zoneee.go +++ b/providers/dns/zoneee/zoneee.go @@ -2,6 +2,7 @@ package zoneee import ( + "context" "errors" "fmt" "net/http" @@ -10,6 +11,7 @@ import ( "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/platform/config/env" + "github.com/go-acme/lego/v4/providers/dns/zoneee/internal" ) // Environment variables names. @@ -37,7 +39,7 @@ type Config struct { // NewDefaultConfig returns a default configuration for the DNSProvider. func NewDefaultConfig() *Config { - endpoint, _ := url.Parse(defaultEndpoint) + endpoint, _ := url.Parse(internal.DefaultEndpoint) return &Config{ Endpoint: endpoint, @@ -53,6 +55,7 @@ func NewDefaultConfig() *Config { // DNSProvider implements the challenge.Provider interface. type DNSProvider struct { config *Config + client *internal.Client } // NewDNSProvider returns a DNSProvider instance. @@ -62,7 +65,7 @@ func NewDNSProvider() (*DNSProvider, error) { return nil, fmt.Errorf("zoneee: %w", err) } - rawEndpoint := env.GetOrDefaultString(EnvEndpoint, defaultEndpoint) + rawEndpoint := env.GetOrDefaultString(EnvEndpoint, internal.DefaultEndpoint) endpoint, err := url.Parse(rawEndpoint) if err != nil { return nil, fmt.Errorf("zoneee: %w", err) @@ -94,7 +97,16 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { return nil, errors.New("zoneee: the endpoint is missing") } - return &DNSProvider{config: config}, nil + client := internal.NewClient(config.Username, config.APIKey) + + if config.HTTPClient != nil { + client.HTTPClient = config.HTTPClient + } + if config.Endpoint != nil { + client.BaseURL = config.Endpoint + } + + return &DNSProvider{config: config, client: client}, nil } // Timeout returns the timeout and interval to use when checking for DNS propagation. @@ -107,17 +119,19 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - record := txtRecord{ + record := internal.TXTRecord{ Name: dns01.UnFqdn(info.EffectiveFQDN), Destination: info.Value, } - authZone, err := getHostedZone(info.EffectiveFQDN) + authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("zoneee: %w", err) + return fmt.Errorf("zoneee: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } - _, err = d.addTxtRecord(authZone, record) + authZone = dns01.UnFqdn(authZone) + + _, err = d.client.AddTxtRecord(context.Background(), authZone, record) if err != nil { return fmt.Errorf("zoneee: %w", err) } @@ -128,12 +142,16 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - authZone, err := getHostedZone(info.EffectiveFQDN) + authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) if err != nil { - return fmt.Errorf("zoneee: %w", err) + return fmt.Errorf("zoneee: could not find zone for domain %q (%s): %w", domain, info.EffectiveFQDN, err) } - records, err := d.getTxtRecords(authZone) + authZone = dns01.UnFqdn(authZone) + + ctx := context.Background() + + records, err := d.client.GetTxtRecords(ctx, authZone) if err != nil { return fmt.Errorf("zoneee: %w", err) } @@ -149,18 +167,9 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { return fmt.Errorf("zoneee: txt record does not exist for %s", info.Value) } - if err = d.removeTxtRecord(authZone, id); err != nil { + if err = d.client.RemoveTxtRecord(ctx, authZone, id); err != nil { return fmt.Errorf("zoneee: %w", err) } return nil } - -func getHostedZone(domain string) (string, error) { - authZone, err := dns01.FindZoneByFqdn(domain) - if err != nil { - return "", err - } - - return dns01.UnFqdn(authZone), nil -} diff --git a/providers/dns/zoneee/zoneee_test.go b/providers/dns/zoneee/zoneee_test.go index cc67fac7..7a284266 100644 --- a/providers/dns/zoneee/zoneee_test.go +++ b/providers/dns/zoneee/zoneee_test.go @@ -6,10 +6,12 @@ import ( "net/http" "net/http/httptest" "net/url" + "path" "testing" "time" "github.com/go-acme/lego/v4/platform/tester" + "github.com/go-acme/lego/v4/providers/dns/zoneee/internal" "github.com/stretchr/testify/require" ) @@ -122,7 +124,7 @@ func TestNewDNSProviderConfig(t *testing.T) { config.APIKey = test.apiKey config.Username = test.apiUser - if len(test.endpoint) > 0 { + if test.endpoint != "" { config.Endpoint = mustParse(test.endpoint) } @@ -155,7 +157,7 @@ func TestDNSProvider_Present(t *testing.T) { username: "bar", apiKey: "foo", handlers: map[string]http.HandlerFunc{ - "/" + hostedZone + "/txt": mockHandlerCreateRecord, + path.Join("/", "dns", hostedZone, "txt"): mockHandlerCreateRecord, }, }, { @@ -163,15 +165,15 @@ func TestDNSProvider_Present(t *testing.T) { username: "nope", apiKey: "foo", handlers: map[string]http.HandlerFunc{ - "/" + hostedZone + "/txt": mockHandlerCreateRecord, + path.Join("/", "dns", hostedZone, "txt"): mockHandlerCreateRecord, }, - expectedError: "zoneee: status code=401: Unauthorized\n", + expectedError: "zoneee: unexpected status code: [status code: 401] body: Unauthorized", }, { desc: "error", username: "bar", apiKey: "foo", - expectedError: "zoneee: status code=404: 404 page not found\n", + expectedError: "zoneee: unexpected status code: [status code: 404] body: 404 page not found", }, } @@ -181,13 +183,14 @@ func TestDNSProvider_Present(t *testing.T) { t.Parallel() mux := http.NewServeMux() - for uri, handler := range test.handlers { - mux.HandleFunc(uri, handler) - } - server := httptest.NewServer(mux) t.Cleanup(server.Close) + for uri, handler := range test.handlers { + handler := handler + mux.HandleFunc(uri, handler) + } + config := NewDefaultConfig() config.Endpoint = mustParse(server.URL) config.Username = test.username @@ -222,14 +225,14 @@ func TestDNSProvider_Cleanup(t *testing.T) { username: "bar", apiKey: "foo", handlers: map[string]http.HandlerFunc{ - "/" + hostedZone + "/txt": mockHandlerGetRecords([]txtRecord{{ + path.Join("/", "dns", hostedZone, "txt"): mockHandlerGetRecords([]internal.TXTRecord{{ ID: "1234", Name: domain, Destination: "LHDhK3oGRvkiefQnx7OOczTY5Tic_xZ6HcMOc_gmtoM", Delete: true, Modify: true, }}), - "/" + hostedZone + "/txt/1234": mockHandlerDeleteRecord, + path.Join("/", "dns", hostedZone, "txt", "1234"): mockHandlerDeleteRecord, }, }, { @@ -237,8 +240,8 @@ func TestDNSProvider_Cleanup(t *testing.T) { username: "bar", apiKey: "foo", handlers: map[string]http.HandlerFunc{ - "/" + hostedZone + "/txt": mockHandlerGetRecords([]txtRecord{}), - "/" + hostedZone + "/txt/1234": mockHandlerDeleteRecord, + path.Join("/", "dns", hostedZone, "txt"): mockHandlerGetRecords([]internal.TXTRecord{}), + path.Join("/", "dns", hostedZone, "txt", "1234"): mockHandlerDeleteRecord, }, expectedError: "zoneee: txt record does not exist for LHDhK3oGRvkiefQnx7OOczTY5Tic_xZ6HcMOc_gmtoM", }, @@ -247,22 +250,22 @@ func TestDNSProvider_Cleanup(t *testing.T) { username: "nope", apiKey: "foo", handlers: map[string]http.HandlerFunc{ - "/" + hostedZone + "/txt": mockHandlerGetRecords([]txtRecord{{ + path.Join("/", "dns", hostedZone, "txt"): mockHandlerGetRecords([]internal.TXTRecord{{ ID: "1234", Name: domain, Destination: "LHDhK3oGRvkiefQnx7OOczTY5Tic_xZ6HcMOc_gmtoM", Delete: true, Modify: true, }}), - "/" + hostedZone + "/txt/1234": mockHandlerDeleteRecord, + path.Join("/", "dns", hostedZone, "txt", "1234"): mockHandlerDeleteRecord, }, - expectedError: "zoneee: status code=401: Unauthorized\n", + expectedError: "zoneee: unexpected status code: [status code: 401] body: Unauthorized", }, { desc: "error", username: "bar", apiKey: "foo", - expectedError: "zoneee: status code=404: 404 page not found\n", + expectedError: "zoneee: unexpected status code: [status code: 404] body: 404 page not found", }, } @@ -272,13 +275,14 @@ func TestDNSProvider_Cleanup(t *testing.T) { t.Parallel() mux := http.NewServeMux() - for uri, handler := range test.handlers { - mux.HandleFunc(uri, handler) - } - server := httptest.NewServer(mux) t.Cleanup(server.Close) + for uri, handler := range test.handlers { + handler := handler + mux.HandleFunc(uri, handler) + } + config := NewDefaultConfig() config.Endpoint = mustParse(server.URL) config.Username = test.username @@ -346,7 +350,7 @@ func mockHandlerCreateRecord(rw http.ResponseWriter, req *http.Request) { return } - record := txtRecord{} + record := internal.TXTRecord{} err := json.NewDecoder(req.Body).Decode(&record) if err != nil { http.Error(rw, err.Error(), http.StatusBadRequest) @@ -358,7 +362,7 @@ func mockHandlerCreateRecord(rw http.ResponseWriter, req *http.Request) { record.Modify = true record.ResourceURL = req.URL.String() + "/1234" - bytes, err := json.Marshal([]txtRecord{record}) + bytes, err := json.Marshal([]internal.TXTRecord{record}) if err != nil { http.Error(rw, err.Error(), http.StatusInternalServerError) return @@ -370,7 +374,7 @@ func mockHandlerCreateRecord(rw http.ResponseWriter, req *http.Request) { } } -func mockHandlerGetRecords(records []txtRecord) http.HandlerFunc { +func mockHandlerGetRecords(records []internal.TXTRecord) http.HandlerFunc { return func(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodGet { http.Error(rw, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) diff --git a/providers/dns/zonomi/zonomi.go b/providers/dns/zonomi/zonomi.go index 740d32a0..5d1a2c79 100644 --- a/providers/dns/zonomi/zonomi.go +++ b/providers/dns/zonomi/zonomi.go @@ -2,6 +2,7 @@ package zonomi import ( + "context" "errors" "fmt" "net/http" @@ -96,20 +97,22 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { func (d *DNSProvider) Present(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - records, err := d.client.FindTXTRecords(dns01.UnFqdn(info.EffectiveFQDN)) + ctx := context.Background() + + records, err := d.client.FindTXTRecords(ctx, dns01.UnFqdn(info.EffectiveFQDN)) if err != nil { return fmt.Errorf("zonomi: failed to find record(s) for %s: %w", domain, err) } actions := []rimuhosting.ActionParameter{ - rimuhosting.AddRecord(dns01.UnFqdn(info.EffectiveFQDN), info.Value, d.config.TTL), + rimuhosting.NewAddRecordAction(dns01.UnFqdn(info.EffectiveFQDN), info.Value, d.config.TTL), } for _, record := range records { - actions = append(actions, rimuhosting.AddRecord(record.Name, record.Content, d.config.TTL)) + actions = append(actions, rimuhosting.NewAddRecordAction(record.Name, record.Content, d.config.TTL)) } - _, err = d.client.DoActions(actions...) + _, err = d.client.DoActions(ctx, actions...) if err != nil { return fmt.Errorf("zonomi: failed to add record(s) for %s: %w", domain, err) } @@ -121,9 +124,9 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { info := dns01.GetChallengeInfo(domain, keyAuth) - action := rimuhosting.DeleteRecord(dns01.UnFqdn(info.EffectiveFQDN), info.Value) + action := rimuhosting.NewDeleteRecordAction(dns01.UnFqdn(info.EffectiveFQDN), info.Value) - _, err := d.client.DoActions(action) + _, err := d.client.DoActions(context.Background(), action) if err != nil { return fmt.Errorf("zonomi: failed to delete record for %s: %w", domain, err) }