refactor: factorize some functions

This commit is contained in:
Fernandez Ludovic 2026-02-25 19:47:21 +01:00
commit 1d174c769d
9 changed files with 32 additions and 114 deletions

View file

@ -84,10 +84,10 @@ func (c *Client) sendQuery(ctx context.Context, fqdn string, rtype uint16, recur
}
func (c *Client) sendQueryCustom(ctx context.Context, fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) {
m := createDNSMsg(fqdn, rtype, recursive)
m := internal.CreateDNSMsg(fqdn, rtype, recursive)
if len(nameservers) == 0 {
return nil, &DNSError{Message: "empty list of nameservers"}
return nil, &internal.DNSError{Message: "empty list of nameservers"}
}
var (
@ -116,7 +116,7 @@ func (c *Client) exchange(ctx context.Context, m *dns.Msg, ns string) (*dns.Msg,
if c.tcpOnly {
r, _, err := c.tcpClient.ExchangeContext(ctx, m, ns)
if err != nil {
return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
return r, &internal.DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
}
return r, nil
@ -130,20 +130,8 @@ func (c *Client) exchange(ctx context.Context, m *dns.Msg, ns string) (*dns.Msg,
}
if err != nil {
return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
return r, &internal.DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
}
return r, nil
}
func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {
m := new(dns.Msg)
m.SetQuestion(fqdn, rtype)
m.SetEdns0(4096, false)
if !recursive {
m.RecursionDesired = false
}
return m
}

View file

@ -82,9 +82,9 @@ func (c *Client) fetchSoaByFqdn(ctx context.Context, fqdn string, nameservers []
// NXDOMAIN
default:
// Any response code other than NOERROR and NXDOMAIN is treated as error
return nil, &DNSError{Message: fmt.Sprintf("unexpected response for '%s'", domain), MsgOut: r}
return nil, &internal.DNSError{Message: fmt.Sprintf("unexpected response for '%s'", domain), MsgOut: r}
}
}
return nil, &DNSError{Message: fmt.Sprintf("could not find the start of authority for '%s'", fqdn), MsgOut: r, Err: err}
return nil, &internal.DNSError{Message: fmt.Sprintf("could not find the start of authority for '%s'", fqdn), MsgOut: r, Err: err}
}

View file

@ -82,10 +82,10 @@ func (c *Client) sendQuery(ctx context.Context, fqdn string, rtype uint16, recur
}
func (c *Client) sendQueryCustom(ctx context.Context, fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) {
m := createDNSMsg(fqdn, rtype, recursive)
m := internal.CreateDNSMsg(fqdn, rtype, recursive)
if len(nameservers) == 0 {
return nil, &DNSError{Message: "empty list of nameservers"}
return nil, &internal.DNSError{Message: "empty list of nameservers"}
}
var (
@ -118,7 +118,7 @@ func (c *Client) exchange(ctx context.Context, m *dns.Msg, ns string) (*dns.Msg,
if c.tcpOnly {
r, _, err := c.tcpClient.ExchangeContext(ctx, m, ns)
if err != nil {
return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
return r, &internal.DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
}
return r, nil
@ -132,24 +132,8 @@ func (c *Client) exchange(ctx context.Context, m *dns.Msg, ns string) (*dns.Msg,
}
if err != nil {
return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
return r, &internal.DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
}
return r, nil
}
/*
* NOTE(ldez): This function is a duplication of `Client.createDNSMsg()` from `dns01/client.go`.
* The 2 functions should be kept in sync.
*/
func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {
m := new(dns.Msg)
m.SetQuestion(fqdn, rtype)
m.SetEdns0(4096, false)
if !recursive {
m.RecursionDesired = false
}
return m
}

View file

@ -1,69 +0,0 @@
package dnspersist01
/*
* NOTE(ldez): This file is a duplication of `dns01/client_error.go`.
* The 2 files should be kept in sync.
*/
import (
"fmt"
"strings"
"github.com/miekg/dns"
)
// DNSError error related to DNS calls.
type DNSError struct {
Message string
NS string
MsgIn *dns.Msg
MsgOut *dns.Msg
Err error
}
func (d *DNSError) Error() string {
var details []string
if d.NS != "" {
details = append(details, "ns="+d.NS)
}
if d.MsgIn != nil && len(d.MsgIn.Question) > 0 {
details = append(details, fmt.Sprintf("question='%s'", formatQuestions(d.MsgIn.Question)))
}
if d.MsgOut != nil {
if d.MsgIn == nil || len(d.MsgIn.Question) == 0 {
details = append(details, fmt.Sprintf("question='%s'", formatQuestions(d.MsgOut.Question)))
}
details = append(details, "code="+dns.RcodeToString[d.MsgOut.Rcode])
}
msg := "DNS error"
if d.Message != "" {
msg = d.Message
}
if d.Err != nil {
msg += ": " + d.Err.Error()
}
if len(details) > 0 {
msg += " [" + strings.Join(details, ", ") + "]"
}
return msg
}
func (d *DNSError) Unwrap() error {
return d.Err
}
func formatQuestions(questions []dns.Question) string {
var parts []string
for _, question := range questions {
parts = append(parts, strings.ReplaceAll(strings.TrimPrefix(question.String(), ";"), "\t", " "))
}
return strings.Join(parts, ";")
}

View file

@ -90,7 +90,7 @@ func (c *Client) lookupTXT(ctx context.Context, fqdn string, nameservers []strin
case dns.RcodeNameError:
return result, nil
default:
return result, &DNSError{Message: fmt.Sprintf("unexpected response for '%s'", name), MsgOut: msg}
return result, &internal.DNSError{Message: fmt.Sprintf("unexpected response for '%s'", name), MsgOut: msg}
}
}
}

View file

@ -75,9 +75,9 @@ func (c *Client) fetchSoaByFqdn(ctx context.Context, fqdn string, nameservers []
// NXDOMAIN
default:
// Any response code other than NOERROR and NXDOMAIN is treated as error
return "", &DNSError{Message: fmt.Sprintf("unexpected response for '%s'", domain), MsgOut: r}
return "", &internal.DNSError{Message: fmt.Sprintf("unexpected response for '%s'", domain), MsgOut: r}
}
}
return "", &DNSError{Message: fmt.Sprintf("could not find the start of authority for '%s'", fqdn), MsgOut: r, Err: err}
return "", &internal.DNSError{Message: fmt.Sprintf("could not find the start of authority for '%s'", fqdn), MsgOut: r, Err: err}
}

View file

@ -1,4 +1,4 @@
package dns01
package internal
import (
"fmt"

View file

@ -1,4 +1,4 @@
package dns01
package internal
import (
"errors"
@ -9,9 +9,9 @@ import (
)
func TestDNSError_Error(t *testing.T) {
msgIn := createDNSMsg("example.com.", dns.TypeTXT, true)
msgIn := CreateDNSMsg("example.com.", dns.TypeTXT, true)
msgOut := createDNSMsg("example.org.", dns.TypeSOA, true)
msgOut := CreateDNSMsg("example.org.", dns.TypeSOA, true)
msgOut.Rcode = dns.RcodeNameError
testCases := []struct {

15
challenge/internal/dns.go Normal file
View file

@ -0,0 +1,15 @@
package internal
import "github.com/miekg/dns"
func CreateDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {
m := new(dns.Msg)
m.SetQuestion(fqdn, rtype)
m.SetEdns0(4096, false)
if !recursive {
m.RecursionDesired = false
}
return m
}