lego/providers/dns/checkdomain/client.go
2021-08-25 11:44:11 +02:00

417 lines
10 KiB
Go

package checkdomain
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
)
// Constants shared by the Checkdomain API client.
const (
	// ns1 and ns2 are the nameservers that checkNameservers expects a domain to use.
	ns1 = "ns.checkdomain.de"
	ns2 = "ns2.checkdomain.de"

	// domainNotFound is the ID returned alongside an error when no domain could be resolved.
	domainNotFound = -1

	// maxLimit is the max page limit that the checkdomain api allows.
	maxLimit = 100

	// maxInt is the max integer value.
	maxInt = int(^uint(0) >> 1)
)
// Some fields have been omitted from the structs below
// because they are not required for this application.

// DomainListingResponse is one page of the domain listing decoded by listDomains.
type DomainListingResponse struct {
	Page     int                `json:"page"`
	Limit    int                `json:"limit"`
	Pages    int                `json:"pages"`
	Total    int                `json:"total"`
	Embedded EmbeddedDomainList `json:"_embedded"`
}

// EmbeddedDomainList holds the domains carried by a DomainListingResponse.
type EmbeddedDomainList struct {
	Domains []*Domain `json:"domains"`
}

// Domain is a single entry of the domain listing.
type Domain struct {
	ID   int    `json:"id"`
	Name string `json:"name"`
}

// DomainResponse is the payload decoded by getDomainInfo for a single domain.
type DomainResponse struct {
	ID      int    `json:"id"`
	Name    string `json:"name"`
	Created string `json:"created"`
	PaidUp  string `json:"payed_up"`
	Active  bool   `json:"active"`
}

// NameserverResponse is the payload decoded by getNameserverInfo.
type NameserverResponse struct {
	General     NameserverGeneral `json:"general"`
	Nameservers []*Nameserver     `json:"nameservers"`
	SOA         NameserverSOA     `json:"soa"`
}

// NameserverGeneral is the "general" section of a NameserverResponse.
type NameserverGeneral struct {
	IPv4       string `json:"ip_v4"`
	IPv6       string `json:"ip_v6"`
	IncludeWWW bool   `json:"include_www"`
}

// NameserverSOA is the "soa" section of a NameserverResponse.
// Its TTL is used by deleteTXTRecord as the fallback for records that come back with a TTL of 0.
type NameserverSOA struct {
	Mail    string `json:"mail"`
	Refresh int    `json:"refresh"`
	Retry   int    `json:"retry"`
	Expiry  int    `json:"expiry"`
	TTL     int    `json:"ttl"`
}

// Nameserver is a single nameserver entry of a NameserverResponse.
type Nameserver struct {
	Name string `json:"name"`
}

// RecordListingResponse is one page of the record listing decoded by listRecords.
type RecordListingResponse struct {
	Page     int                `json:"page"`
	Limit    int                `json:"limit"`
	Pages    int                `json:"pages"`
	Total    int                `json:"total"`
	Embedded EmbeddedRecordList `json:"_embedded"`
}

// EmbeddedRecordList holds the records carried by a RecordListingResponse.
type EmbeddedRecordList struct {
	Records []*Record `json:"records"`
}

// Record is a DNS record as read from and sent to the API.
type Record struct {
	Name     string `json:"name"`
	Value    string `json:"value"`
	TTL      int    `json:"ttl"`
	Priority int    `json:"priority"`
	Type     string `json:"type"`
}
// getDomainIDByName resolves the ID of the registered domain that name belongs to.
//
// A registered domain matches when it equals name or when name is a subdomain of it.
// If several registered domains match (e.g. both "example.com" and "sub.example.com"),
// the longest, i.e. most specific, one is chosen so that the result does not depend on
// the order in which the API lists the domains.
//
// Successful lookups are cached per name in d.domainIDMapping.
// On failure domainNotFound is returned together with a non-nil error.
func (d *DNSProvider) getDomainIDByName(name string) (int, error) {
	// Load from cache if exists.
	d.domainIDMu.Lock()
	id, ok := d.domainIDMapping[name]
	d.domainIDMu.Unlock()

	if ok {
		return id, nil
	}

	// Find out by querying API.
	domains, err := d.listDomains()
	if err != nil {
		return domainNotFound, err
	}

	// Linear search over all registered domains, keeping the most specific match.
	var best *Domain
	for _, domain := range domains {
		// A null entry in the listing cannot match anything.
		if domain == nil {
			continue
		}

		if domain.Name != name && !strings.HasSuffix(name, "."+domain.Name) {
			continue
		}

		if best == nil || len(domain.Name) > len(best.Name) {
			best = domain
		}
	}

	if best == nil {
		return domainNotFound, errors.New("domain not found")
	}

	d.domainIDMu.Lock()
	d.domainIDMapping[name] = best.ID
	d.domainIDMu.Unlock()

	return best.ID, nil
}
// listDomains collects every domain returned by the paginated /v1/domains listing.
// Pages are requested with the largest allowed page size (maxLimit),
// and the number of pages to fetch is taken from the first response.
func (d *DNSProvider) listDomains() ([]*Domain, error) {
	req, err := d.makeRequest(http.MethodGet, "/v1/domains", http.NoBody)
	if err != nil {
		return nil, fmt.Errorf("failed to make request: %w", err)
	}

	// Checkdomain also offers a 'query' param to filter domains by a string,
	// but that functionality is kinda broken.
	// So the whole list of registered domains is fetched and searched by the caller.
	params := req.URL.Query()
	params.Set("limit", strconv.Itoa(maxLimit))

	var all []*Domain

	// maxInt marks "page count not known yet" until the first response arrives.
	lastPage := maxInt

	for page := 1; page <= lastPage; page++ {
		params.Set("page", strconv.Itoa(page))
		req.URL.RawQuery = params.Encode()

		var listing DomainListingResponse
		if err := d.sendRequest(req, &listing); err != nil {
			return nil, fmt.Errorf("failed to send domain listing request: %w", err)
		}

		// First response: learn the page count and reserve room for all entries.
		if lastPage == maxInt {
			lastPage = listing.Pages
			all = make([]*Domain, 0, listing.Total)
		}

		all = append(all, listing.Embedded.Domains...)
	}

	return all, nil
}
// getNameserverInfo fetches the nameserver settings of the domain with the given ID
// from /v1/domains/{id}/nameservers.
func (d *DNSProvider) getNameserverInfo(domainID int) (*NameserverResponse, error) {
	resource := fmt.Sprintf("/v1/domains/%d/nameservers", domainID)

	req, err := d.makeRequest(http.MethodGet, resource, http.NoBody)
	if err != nil {
		return nil, err
	}

	var info NameserverResponse

	err = d.sendRequest(req, &info)
	if err != nil {
		return nil, err
	}

	return &info, nil
}
// checkNameservers verifies that the domain lists both Checkdomain nameservers (ns1 and ns2).
// It returns an error when the nameserver info cannot be fetched or either nameserver is missing.
func (d *DNSProvider) checkNameservers(domainID int) error {
	info, err := d.getNameserverInfo(domainID)
	if err != nil {
		return err
	}

	hasFirst, hasSecond := false, false

	for _, server := range info.Nameservers {
		name := server.Name

		if name == ns1 {
			hasFirst = true
		} else if name == ns2 {
			hasSecond = true
		}
	}

	if hasFirst && hasSecond {
		return nil
	}

	return errors.New("not using checkdomain nameservers, can not update records")
}
// createRecord adds record to the domain by POSTing its JSON encoding
// to /v1/domains/{id}/nameservers/records.
func (d *DNSProvider) createRecord(domainID int, record *Record) error {
	payload, err := json.Marshal(record)
	if err != nil {
		return fmt.Errorf("encoding record failed: %w", err)
	}

	resource := fmt.Sprintf("/v1/domains/%d/nameservers/records", domainID)

	req, err := d.makeRequest(http.MethodPost, resource, bytes.NewReader(payload))
	if err != nil {
		return err
	}

	// The response body is not needed, only the outcome.
	return d.sendRequest(req, nil)
}
// deleteTXTRecord removes the TXT records matching recordName and recordValue from the domain.
//
// Checkdomain doesn't seem to provide a way to delete a single record,
// but all records can be replaced at once.
// So every record is fetched and the list, minus the records to delete, is written back.
// TODO: Simplify this function once Checkdomain provides the functionality.
func (d *DNSProvider) deleteTXTRecord(domainID int, recordName, recordValue string) error {
	domain, err := d.getDomainInfo(domainID)
	if err != nil {
		return err
	}

	nameservers, err := d.getNameserverInfo(domainID)
	if err != nil {
		return err
	}

	existing, err := d.listRecords(domainID, "")
	if err != nil {
		return err
	}

	// Strip the ".<domain>." suffix, leaving the name relative to the domain.
	relativeName := strings.TrimSuffix(recordName, "."+domain.Name+".")

	var kept []*Record

	for _, rec := range existing {
		if !skipRecord(relativeName, recordValue, rec, nameservers) {
			// The API can return records without any TTL set (indicated by 0),
			// but replacing the records fails when no value is given.
			// Fall back to the default TTL from the SOA settings queried above.
			if rec.TTL == 0 {
				rec.TTL = nameservers.SOA.TTL
			}

			kept = append(kept, rec)
		}
	}

	return d.replaceRecords(domainID, kept)
}
// getDomainInfo fetches the details of the domain with the given ID from /v1/domains/{id}.
func (d *DNSProvider) getDomainInfo(domainID int) (*DomainResponse, error) {
	resource := fmt.Sprintf("/v1/domains/%d", domainID)

	req, err := d.makeRequest(http.MethodGet, resource, http.NoBody)
	if err != nil {
		return nil, err
	}

	info := &DomainResponse{}
	if err := d.sendRequest(req, info); err != nil {
		return nil, err
	}

	return info, nil
}
// listRecords collects every record returned by the paginated
// /v1/domains/{id}/nameservers/records listing.
// A non-empty recordType is passed to the API as the "type" query parameter;
// an empty one leaves the parameter out.
func (d *DNSProvider) listRecords(domainID int, recordType string) ([]*Record, error) {
	resource := fmt.Sprintf("/v1/domains/%d/nameservers/records", domainID)

	req, err := d.makeRequest(http.MethodGet, resource, http.NoBody)
	if err != nil {
		return nil, fmt.Errorf("failed to make request: %w", err)
	}

	params := req.URL.Query()
	params.Set("limit", strconv.Itoa(maxLimit))

	if recordType != "" {
		params.Set("type", recordType)
	}

	var all []*Record

	// maxInt marks "page count not known yet" until the first response arrives.
	lastPage := maxInt

	for page := 1; page <= lastPage; page++ {
		params.Set("page", strconv.Itoa(page))
		req.URL.RawQuery = params.Encode()

		var listing RecordListingResponse
		if err := d.sendRequest(req, &listing); err != nil {
			return nil, fmt.Errorf("failed to send record listing request: %w", err)
		}

		// First response: learn the page count and reserve room for all entries.
		if lastPage == maxInt {
			lastPage = listing.Pages
			all = make([]*Record, 0, listing.Total)
		}

		all = append(all, listing.Embedded.Records...)
	}

	return all, nil
}
// replaceRecords overwrites the record set of the domain with records
// by sending their JSON encoding in a PUT request to /v1/domains/{id}/nameservers/records.
//
// A nil slice is sent as an empty JSON array.
// encoding/json encodes a nil slice as null, which is not a record list;
// deleteTXTRecord passes a nil slice whenever no record is left to keep.
func (d *DNSProvider) replaceRecords(domainID int, records []*Record) error {
	if records == nil {
		records = []*Record{}
	}

	bs, err := json.Marshal(records)
	if err != nil {
		return fmt.Errorf("encoding record failed: %w", err)
	}

	req, err := d.makeRequest(http.MethodPut, fmt.Sprintf("/v1/domains/%d/nameservers/records", domainID), bytes.NewReader(bs))
	if err != nil {
		return err
	}

	// The response body is not needed, only the outcome.
	return d.sendRequest(req, nil)
}
// skipRecord reports whether record has to be left out of the list
// that deleteTXTRecord sends back when replacing the records of a domain.
//
// Left out are:
//   - records with an empty value,
//   - special records (SOA, NS, the "@" name, and "www" when IncludeWWW is set),
//     which would otherwise trigger a "Nameserver update failed",
//   - TXT records matching recordName and recordValue, i.e. the ones to delete.
//
// An empty recordName or recordValue matches any name or value respectively.
func skipRecord(recordName, recordValue string, record *Record, nsInfo *NameserverResponse) bool {
	switch {
	case record.Value == "":
		return true
	case record.Type == "SOA", record.Type == "NS", record.Name == "@":
		return true
	case nsInfo.General.IncludeWWW && record.Name == "www":
		return true
	case record.Type != "TXT":
		// Only TXT records are candidates for deletion.
		return false
	}

	sameName := recordName == "" || record.Name == recordName
	sameValue := recordValue == "" || record.Value == recordValue

	return sameName && sameValue
}
// makeRequest builds an API request for resource, which is resolved against
// the configured endpoint through its Parse method.
// The request accepts JSON and authenticates with the configured bearer token;
// every method other than GET also declares a JSON content type.
func (d *DNSProvider) makeRequest(method, resource string, body io.Reader) (*http.Request, error) {
	target, err := d.config.Endpoint.Parse(resource)
	if err != nil {
		return nil, err
	}

	request, err := http.NewRequest(method, target.String(), body)
	if err != nil {
		return nil, err
	}

	headers := request.Header
	headers.Set("Accept", "application/json")
	headers.Set("Authorization", "Bearer "+d.config.Token)

	if method != http.MethodGet {
		headers.Set("Content-Type", "application/json")
	}

	return request, nil
}
// sendRequest executes req with the configured HTTP client and,
// when result is non-nil, decodes the JSON response body into it.
//
// A response with a status code of 400 or above is turned into an error by checkResponse.
// When result is nil the body is drained and discarded, so that the transport
// is able to reuse the underlying keep-alive connection.
func (d *DNSProvider) sendRequest(req *http.Request, result interface{}) error {
	resp, err := d.config.HTTPClient.Do(req)
	if err != nil {
		return err
	}

	// On error checkResponse has already consumed and closed the body (if there is one).
	if err = checkResponse(resp); err != nil {
		return err
	}

	defer func() { _ = resp.Body.Close() }()

	if result == nil {
		// Best effort: a failure while draining only costs the connection reuse,
		// the request itself has already succeeded.
		_, _ = io.Copy(io.Discard, resp.Body)

		return nil
	}

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return err
	}

	err = json.Unmarshal(raw, result)
	if err != nil {
		return fmt.Errorf("unmarshaling %T error [status code=%d]: %w: %s", result, resp.StatusCode, err, string(raw))
	}

	return nil
}
// checkResponse converts a response with a status code of 400 or above into an error
// carrying the status code and the response body.
// Responses with a lower status code yield nil and their body is left untouched.
// For error responses the body, when present, is read completely and closed.
func checkResponse(resp *http.Response) error {
	code := resp.StatusCode
	if code < http.StatusBadRequest {
		return nil
	}

	body := resp.Body
	if body == nil {
		return fmt.Errorf("response body is nil, status code=%d", code)
	}

	defer func() { _ = body.Close() }()

	content, readErr := io.ReadAll(body)
	if readErr != nil {
		return fmt.Errorf("unable to read body: status code=%d, error=%w", code, readErr)
	}

	return fmt.Errorf("status code=%d: %s", code, content)
}