mautrix-go/federation/resolution.go
Tulir Asokan 315d2ab17d
Some checks are pending
Go / Lint (latest) (push) Waiting to run
Go / Build (old, libolm) (push) Waiting to run
Go / Build (latest, libolm) (push) Waiting to run
Go / Build (old, goolm) (push) Waiting to run
Go / Build (latest, goolm) (push) Waiting to run
all: fix staticcheck issues
2025-12-08 00:07:25 +02:00

198 lines
6 KiB
Go

// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package federation
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
)
type ResolvedServerName struct {
ServerName string `json:"server_name"`
HostHeader string `json:"host_header"`
IPPort []string `json:"ip_port"`
Expires time.Time `json:"expires"`
}
type ResolveServerNameOpts struct {
HTTPClient *http.Client
DNSClient *net.Resolver
}
var (
ErrInvalidServerName = errors.New("invalid server name")
)
// ResolveServerName implements the full server discovery algorithm as specified in https://spec.matrix.org/v1.11/server-server-api/#resolving-server-names
func ResolveServerName(ctx context.Context, serverName string, opts ...*ResolveServerNameOpts) (*ResolvedServerName, error) {
var opt ResolveServerNameOpts
if len(opts) > 0 && opts[0] != nil {
opt = *opts[0]
}
if opt.HTTPClient == nil {
opt.HTTPClient = http.DefaultClient
}
if opt.DNSClient == nil {
opt.DNSClient = net.DefaultResolver
}
output := ResolvedServerName{
ServerName: serverName,
HostHeader: serverName,
IPPort: []string{serverName},
Expires: time.Now().Add(24 * time.Hour),
}
hostname, port, ok := ParseServerName(serverName)
if !ok {
return nil, ErrInvalidServerName
}
// Steps 1 and 2: handle IP literals and hostnames with port
if net.ParseIP(hostname) != nil || port != 0 {
if port == 0 {
port = 8448
}
output.IPPort = []string{net.JoinHostPort(hostname, strconv.Itoa(int(port)))}
return &output, nil
}
// Step 3: resolve .well-known
wellKnown, expiry, err := RequestWellKnown(ctx, opt.HTTPClient, hostname)
if err != nil {
zerolog.Ctx(ctx).Trace().
Str("server_name", serverName).
Err(err).
Msg("Failed to get well-known data")
} else if wellKnown != nil {
output.Expires = expiry
output.HostHeader = wellKnown.Server
wkHost, wkPort, ok := ParseServerName(wellKnown.Server)
if ok {
hostname, port = wkHost, wkPort
}
// Step 3.1 and 3.2: IP literals and hostnames with port inside .well-known
if net.ParseIP(hostname) != nil || port != 0 {
if port == 0 {
port = 8448
}
output.IPPort = []string{net.JoinHostPort(hostname, strconv.Itoa(int(port)))}
return &output, nil
}
}
// Step 3.3, 3.4, 4 and 5: resolve SRV records
srv, err := RequestSRV(ctx, opt.DNSClient, hostname)
if err != nil {
// TODO log more noisily for abnormal errors?
zerolog.Ctx(ctx).Trace().
Str("server_name", serverName).
Str("hostname", hostname).
Err(err).
Msg("Failed to get SRV record")
} else if len(srv) > 0 {
output.IPPort = make([]string, len(srv))
for i, record := range srv {
output.IPPort[i] = net.JoinHostPort(strings.TrimRight(record.Target, "."), strconv.Itoa(int(record.Port)))
}
return &output, nil
}
// Step 6 or 3.5: no SRV records were found, so default to port 8448
output.IPPort = []string{net.JoinHostPort(hostname, "8448")}
return &output, nil
}
// RequestSRV resolves the `_matrix-fed._tcp` SRV record for the given hostname.
// If the new matrix-fed record is not found, it falls back to the old `_matrix._tcp` record.
func RequestSRV(ctx context.Context, cli *net.Resolver, hostname string) ([]*net.SRV, error) {
_, target, err := cli.LookupSRV(ctx, "matrix-fed", "tcp", hostname)
var dnsErr *net.DNSError
if err != nil && errors.As(err, &dnsErr) && dnsErr.IsNotFound {
_, target, err = cli.LookupSRV(ctx, "matrix", "tcp", hostname)
}
return target, err
}
func parseCacheControl(resp *http.Response) time.Duration {
cc := resp.Header.Get("Cache-Control")
if cc == "" {
return 0
}
parts := strings.Split(cc, ",")
for _, part := range parts {
kv := strings.SplitN(strings.TrimSpace(part), "=", 1)
switch kv[0] {
case "no-cache", "no-store":
return 0
case "max-age":
if len(kv) < 2 {
continue
}
maxAge, err := strconv.Atoi(kv[1])
if err != nil || maxAge < 0 {
continue
}
age, _ := strconv.Atoi(resp.Header.Get("Age"))
return time.Duration(maxAge-age) * time.Second
}
}
return 0
}
const (
MinCacheDuration = 1 * time.Hour
MaxCacheDuration = 72 * time.Hour
DefaultCacheDuration = 24 * time.Hour
)
// RequestWellKnown sends a request to the well-known endpoint of a server and returns the response,
// plus the time when the cache should expire.
func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (*RespWellKnown, time.Time, error) {
wellKnownURL := url.URL{
Scheme: "https",
Host: hostname,
Path: "/.well-known/matrix/server",
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnownURL.String(), nil)
if err != nil {
return nil, time.Time{}, fmt.Errorf("failed to prepare request: %w", err)
}
resp, err := cli.Do(req)
if err != nil {
return nil, time.Time{}, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, time.Time{}, fmt.Errorf("unexpected status code %d", resp.StatusCode)
} else if resp.ContentLength > mautrix.WellKnownMaxSize {
return nil, time.Time{}, fmt.Errorf("response too large: %d bytes", resp.ContentLength)
}
var respData RespWellKnown
err = json.NewDecoder(io.LimitReader(resp.Body, mautrix.WellKnownMaxSize)).Decode(&respData)
if err != nil {
return nil, time.Time{}, fmt.Errorf("failed to decode response: %w", err)
} else if respData.Server == "" {
return nil, time.Time{}, errors.New("server name not found in response")
}
cacheDuration := parseCacheControl(resp)
if cacheDuration <= 0 {
cacheDuration = DefaultCacheDuration
} else if cacheDuration < MinCacheDuration {
cacheDuration = MinCacheDuration
} else if cacheDuration > MaxCacheDuration {
cacheDuration = MaxCacheDuration
}
return &respData, time.Now().Add(24 * time.Hour), nil
}