federation: add utilities for server name resolution

This commit is contained in:
Tulir Asokan 2024-07-28 15:18:04 +03:00
commit b5c26a2fdb
6 changed files with 575 additions and 0 deletions

115
federation/request.go Normal file
View file

@ -0,0 +1,115 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package federation
import (
"context"
"fmt"
"net"
"net/http"
"sync"
"time"
)
// ServerResolvingTransport is an http.RoundTripper that resolves Matrix server names before sending requests.
// It only allows requests using the "matrix-federation" scheme.
type ServerResolvingTransport struct {
ResolveOpts *ResolveServerNameOpts
Transport *http.Transport
Dialer *net.Dialer
cache map[string]*ResolvedServerName
resolveLocks map[string]*sync.Mutex
cacheLock sync.Mutex
}
func NewServerResolvingTransport() *ServerResolvingTransport {
srt := &ServerResolvingTransport{
cache: make(map[string]*ResolvedServerName),
resolveLocks: make(map[string]*sync.Mutex),
Dialer: &net.Dialer{},
}
srt.Transport = &http.Transport{
DialContext: srt.DialContext,
}
return srt
}
func NewFederationHTTPClient() *http.Client {
return &http.Client{
Transport: NewServerResolvingTransport(),
Timeout: 120 * time.Second,
}
}
var _ http.RoundTripper = (*ServerResolvingTransport)(nil)
func (srt *ServerResolvingTransport) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
addrs, ok := ctx.Value(contextKeyIPPort).([]string)
if !ok {
return nil, fmt.Errorf("no IP:port in context")
}
return srt.Dialer.DialContext(ctx, network, addrs[0])
}
type contextKey int
const (
contextKeyIPPort contextKey = iota
)
func (srt *ServerResolvingTransport) RoundTrip(request *http.Request) (*http.Response, error) {
if request.URL.Scheme != "matrix-federation" {
return nil, fmt.Errorf("unsupported scheme: %s", request.URL.Scheme)
}
resolved, err := srt.resolve(request.Context(), request.URL.Host)
if err != nil {
return nil, fmt.Errorf("failed to resolve server name: %w", err)
}
request = request.WithContext(context.WithValue(request.Context(), contextKeyIPPort, resolved.IPPort))
request.URL.Scheme = "https"
request.URL.Host = resolved.HostHeader
request.Host = resolved.HostHeader
return srt.Transport.RoundTrip(request)
}
func (srt *ServerResolvingTransport) resolve(ctx context.Context, serverName string) (*ResolvedServerName, error) {
res, lock := srt.getResolveCache(serverName)
if res != nil {
return res, nil
}
lock.Lock()
defer lock.Unlock()
res, _ = srt.getResolveCache(serverName)
if res != nil {
return res, nil
}
var err error
res, err = ResolveServerName(ctx, serverName, srt.ResolveOpts)
if err != nil {
return nil, err
}
srt.cacheLock.Lock()
srt.cache[serverName] = res
srt.cacheLock.Unlock()
return res, nil
}
func (srt *ServerResolvingTransport) getResolveCache(serverName string) (*ResolvedServerName, *sync.Mutex) {
srt.cacheLock.Lock()
defer srt.cacheLock.Unlock()
if val, ok := srt.cache[serverName]; ok && time.Until(val.Expires) > 0 {
return val, nil
}
rl, ok := srt.resolveLocks[serverName]
if !ok {
rl = &sync.Mutex{}
srt.resolveLocks[serverName] = rl
}
return nil, rl
}

View file

@ -0,0 +1,35 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package federation_test
import (
"encoding/json"
"net/http"
"testing"
"github.com/stretchr/testify/require"
"maunium.net/go/mautrix/federation"
)
type serverVersionResp struct {
Server struct {
Name string `json:"name"`
Version string `json:"version"`
} `json:"server"`
}
func TestNewFederationClient(t *testing.T) {
cli := federation.NewFederationHTTPClient()
resp, err := cli.Get("matrix-federation://maunium.net/_matrix/federation/v1/version")
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
var respData serverVersionResp
err = json.NewDecoder(resp.Body).Decode(&respData)
require.NoError(t, err)
require.Equal(t, "Synapse", respData.Server.Name)
}

151
federation/resolution.go Normal file
View file

@ -0,0 +1,151 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package federation
import (
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/rs/zerolog"
)
type ResolvedServerName struct {
ServerName string `json:"server_name"`
HostHeader string `json:"host_header"`
IPPort []string `json:"ip_port"`
Expires time.Time `json:"expires"`
}
type ResolveServerNameOpts struct {
HTTPClient *http.Client
DNSClient *net.Resolver
}
var (
ErrInvalidServerName = errors.New("invalid server name")
)
// ResolveServerName implements the full server discovery algorithm as specified in https://spec.matrix.org/v1.11/server-server-api/#resolving-server-names
func ResolveServerName(ctx context.Context, serverName string, opts ...*ResolveServerNameOpts) (*ResolvedServerName, error) {
var opt ResolveServerNameOpts
if len(opts) > 0 && opts[0] != nil {
opt = *opts[0]
}
if opt.HTTPClient == nil {
opt.HTTPClient = http.DefaultClient
}
if opt.DNSClient == nil {
opt.DNSClient = net.DefaultResolver
}
output := ResolvedServerName{
ServerName: serverName,
HostHeader: serverName,
IPPort: []string{serverName},
Expires: time.Now().Add(24 * time.Hour),
}
hostname, port, ok := ParseServerName(serverName)
if !ok {
return nil, ErrInvalidServerName
}
// Steps 1 and 2: handle IP literals and hostnames with port
if net.ParseIP(hostname) != nil || port != 0 {
if port == 0 {
port = 8448
}
output.IPPort = []string{net.JoinHostPort(hostname, strconv.Itoa(int(port)))}
return &output, nil
}
// Step 3: resolve .well-known
wellKnown, expiry, err := RequestWellKnown(ctx, opt.HTTPClient, hostname)
if err != nil {
zerolog.Ctx(ctx).Trace().
Str("server_name", serverName).
Err(err).
Msg("Failed to get well-known data")
} else if wellKnown != nil {
output.Expires = expiry
output.HostHeader = wellKnown.Server
hostname, port, ok = ParseServerName(wellKnown.Server)
// Step 3.1 and 3.2: IP literals and hostnames with port inside .well-known
if net.ParseIP(hostname) != nil || port != 0 {
if port == 0 {
port = 8448
}
output.IPPort = []string{net.JoinHostPort(hostname, strconv.Itoa(int(port)))}
return &output, nil
}
}
// Step 3.3, 3.4, 4 and 5: resolve SRV records
srv, err := RequestSRV(ctx, opt.DNSClient, hostname)
if err != nil {
// TODO log more noisily for abnormal errors?
zerolog.Ctx(ctx).Trace().
Str("server_name", serverName).
Str("hostname", hostname).
Err(err).
Msg("Failed to get SRV record")
} else if len(srv) > 0 {
output.IPPort = make([]string, len(srv))
for i, record := range srv {
output.IPPort[i] = net.JoinHostPort(strings.TrimRight(record.Target, "."), strconv.Itoa(int(record.Port)))
}
return &output, nil
}
// Step 6 or 3.5: no SRV records were found, so default to port 8448
output.IPPort = []string{net.JoinHostPort(hostname, "8448")}
return &output, nil
}
// RequestSRV resolves the `_matrix-fed._tcp` SRV record for the given hostname.
// If the new matrix-fed record is not found, it falls back to the old `_matrix._tcp` record.
func RequestSRV(ctx context.Context, cli *net.Resolver, hostname string) ([]*net.SRV, error) {
_, target, err := cli.LookupSRV(ctx, "matrix-fed", "tcp", hostname)
var dnsErr *net.DNSError
if err != nil && errors.As(err, &dnsErr) && dnsErr.IsNotFound {
_, target, err = cli.LookupSRV(ctx, "matrix", "tcp", hostname)
}
return target, err
}
// RequestWellKnown sends a request to the well-known endpoint of a server and returns the response,
// plus the time when the cache should expire.
func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (*RespWellKnown, time.Time, error) {
wellKnownURL := url.URL{
Scheme: "https",
Host: hostname,
Path: "/.well-known/matrix/server",
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnownURL.String(), nil)
if err != nil {
return nil, time.Time{}, fmt.Errorf("failed to prepare request: %w", err)
}
resp, err := cli.Do(req)
if err != nil {
return nil, time.Time{}, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, time.Time{}, fmt.Errorf("unexpected status code %d", resp.StatusCode)
}
var respData RespWellKnown
err = json.NewDecoder(resp.Body).Decode(&respData)
if err != nil {
return nil, time.Time{}, fmt.Errorf("failed to decode response: %w", err)
} else if respData.Server == "" {
return nil, time.Time{}, errors.New("server name not found in response")
}
// TODO parse cache-control header
return &respData, time.Now().Add(24 * time.Hour), nil
}

View file

@ -0,0 +1,115 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package federation_test
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"maunium.net/go/mautrix/federation"
)
type resolveTestCase struct {
name string
serverName string
expected federation.ResolvedServerName
}
func TestResolveServerName(t *testing.T) {
// See https://t2bot.io/docs/resolvematrix/ for more info on the RM test cases
testCases := []resolveTestCase{{
"maunium",
"maunium.net",
federation.ResolvedServerName{
HostHeader: "federation.mau.chat",
IPPort: []string{"meow.host.mau.fi:443"},
},
}, {
"IP literal",
"135.181.208.158",
federation.ResolvedServerName{
HostHeader: "135.181.208.158",
IPPort: []string{"135.181.208.158:8448"},
},
}, {
"IP literal with port",
"135.181.208.158:8447",
federation.ResolvedServerName{
HostHeader: "135.181.208.158:8447",
IPPort: []string{"135.181.208.158:8447"},
},
}, {
"RM Step 2",
"2.s.resolvematrix.dev:7652",
federation.ResolvedServerName{
HostHeader: "2.s.resolvematrix.dev:7652",
IPPort: []string{"2.s.resolvematrix.dev:7652"},
},
}, {
"RM Step 3B",
"3b.s.resolvematrix.dev",
federation.ResolvedServerName{
HostHeader: "wk.3b.s.resolvematrix.dev:7753",
IPPort: []string{"wk.3b.s.resolvematrix.dev:7753"},
},
}, {
"RM Step 3C",
"3c.s.resolvematrix.dev",
federation.ResolvedServerName{
HostHeader: "wk.3c.s.resolvematrix.dev",
IPPort: []string{"srv.wk.3c.s.resolvematrix.dev:7754"},
},
}, {
"RM Step 3C MSC4040",
"3c.msc4040.s.resolvematrix.dev",
federation.ResolvedServerName{
HostHeader: "wk.3c.msc4040.s.resolvematrix.dev",
IPPort: []string{"srv.wk.3c.msc4040.s.resolvematrix.dev:7053"},
},
}, {
"RM Step 3D",
"3d.s.resolvematrix.dev",
federation.ResolvedServerName{
HostHeader: "wk.3d.s.resolvematrix.dev",
IPPort: []string{"wk.3d.s.resolvematrix.dev:8448"},
},
}, {
"RM Step 4",
"4.s.resolvematrix.dev",
federation.ResolvedServerName{
HostHeader: "4.s.resolvematrix.dev",
IPPort: []string{"srv.4.s.resolvematrix.dev:7855"},
},
}, {
"RM Step 4 MSC4040",
"4.msc4040.s.resolvematrix.dev",
federation.ResolvedServerName{
HostHeader: "4.msc4040.s.resolvematrix.dev",
IPPort: []string{"srv.4.msc4040.s.resolvematrix.dev:7054"},
},
}, {
"RM Step 5",
"5.s.resolvematrix.dev",
federation.ResolvedServerName{
HostHeader: "5.s.resolvematrix.dev",
IPPort: []string{"5.s.resolvematrix.dev:8448"},
},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tc.expected.ServerName = tc.serverName
resp, err := federation.ResolveServerName(context.TODO(), tc.serverName)
require.NoError(t, err)
resp.Expires = time.Time{}
assert.Equal(t, tc.expected, *resp)
})
}
}

95
federation/servername.go Normal file
View file

@ -0,0 +1,95 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package federation
import (
"net"
"strconv"
"strings"
)
func isSpecCompliantIPv6(host string) bool {
// IPv6address = 2*45IPv6char
// IPv6char = DIGIT / %x41-46 / %x61-66 / ":" / "."
// ; 0-9, A-F, a-f, :, .
if len(host) < 2 || len(host) > 45 {
return false
}
for _, ch := range host {
if (ch < '0' || ch > '9') && (ch < 'a' || ch > 'f') && (ch < 'A' || ch > 'F') && ch != ':' && ch != '.' {
return false
}
}
return true
}
func isValidIPv4Chunk(str string) bool {
if len(str) == 0 || len(str) > 3 {
return false
}
for _, ch := range str {
if ch < '0' || ch > '9' {
return false
}
}
return true
}
func isSpecCompliantIPv4(host string) bool {
// IPv4address = 1*3DIGIT "." 1*3DIGIT "." 1*3DIGIT "." 1*3DIGIT
if len(host) < 7 || len(host) > 15 {
return false
}
parts := strings.Split(host, ".")
return len(parts) == 4 &&
isValidIPv4Chunk(parts[0]) &&
isValidIPv4Chunk(parts[1]) &&
isValidIPv4Chunk(parts[2]) &&
isValidIPv4Chunk(parts[3])
}
func isSpecCompliantDNSName(host string) bool {
// dns-name = 1*255dns-char
// dns-char = DIGIT / ALPHA / "-" / "."
if len(host) == 0 || len(host) > 255 {
return false
}
for _, ch := range host {
if (ch < '0' || ch > '9') && (ch < 'a' || ch > 'z') && (ch < 'A' || ch > 'Z') && ch != '-' && ch != '.' {
return false
}
}
return true
}
// ParseServerName parses the port and hostname from a Matrix server name and validates that
// it matches the grammar specified in https://spec.matrix.org/v1.11/appendices/#server-name
func ParseServerName(serverName string) (host string, port uint16, ok bool) {
if len(serverName) == 0 || len(serverName) > 255 {
return
}
colonIdx := strings.LastIndexByte(serverName, ':')
if colonIdx > 0 {
u64Port, err := strconv.ParseUint(serverName[colonIdx+1:], 10, 16)
if err == nil {
port = uint16(u64Port)
serverName = serverName[:colonIdx]
}
}
if serverName[0] == '[' {
if serverName[len(serverName)-1] != ']' {
return
}
host = serverName[1 : len(serverName)-1]
ok = isSpecCompliantIPv6(host) && net.ParseIP(host) != nil
} else {
host = serverName
ok = isSpecCompliantDNSName(host) || isSpecCompliantIPv4(host)
}
return
}

View file

@ -0,0 +1,64 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package federation_test
import (
"testing"
"github.com/stretchr/testify/assert"
"maunium.net/go/mautrix/federation"
)
type parseTestCase struct {
name string
serverName string
hostname string
port uint16
}
func TestParseServerName(t *testing.T) {
testCases := []parseTestCase{{
"Domain",
"matrix.org",
"matrix.org",
0,
}, {
"Domain with port",
"matrix.org:8448",
"matrix.org",
8448,
}, {
"IPv4 literal",
"1.2.3.4",
"1.2.3.4",
0,
}, {
"IPv4 literal with port",
"1.2.3.4:8448",
"1.2.3.4",
8448,
}, {
"IPv6 literal",
"[1234:5678::abcd]",
"1234:5678::abcd",
0,
}, {
"IPv6 literal with port",
"[1234:5678::abcd]:8448",
"1234:5678::abcd",
8448,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
hostname, port, ok := federation.ParseServerName(tc.serverName)
assert.True(t, ok)
assert.Equal(t, tc.hostname, hostname)
assert.Equal(t, tc.port, port)
})
}
}