diff --git a/federation/request.go b/federation/request.go new file mode 100644 index 00000000..faeb16ad --- /dev/null +++ b/federation/request.go @@ -0,0 +1,115 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation + +import ( + "context" + "fmt" + "net" + "net/http" + "sync" + "time" +) + +// ServerResolvingTransport is an http.RoundTripper that resolves Matrix server names before sending requests. +// It only allows requests using the "matrix-federation" scheme. +type ServerResolvingTransport struct { + ResolveOpts *ResolveServerNameOpts + Transport *http.Transport + Dialer *net.Dialer + + cache map[string]*ResolvedServerName + resolveLocks map[string]*sync.Mutex + cacheLock sync.Mutex +} + +func NewServerResolvingTransport() *ServerResolvingTransport { + srt := &ServerResolvingTransport{ + cache: make(map[string]*ResolvedServerName), + resolveLocks: make(map[string]*sync.Mutex), + + Dialer: &net.Dialer{}, + } + srt.Transport = &http.Transport{ + DialContext: srt.DialContext, + } + return srt +} + +func NewFederationHTTPClient() *http.Client { + return &http.Client{ + Transport: NewServerResolvingTransport(), + Timeout: 120 * time.Second, + } +} + +var _ http.RoundTripper = (*ServerResolvingTransport)(nil) + +func (srt *ServerResolvingTransport) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + addrs, ok := ctx.Value(contextKeyIPPort).([]string) + if !ok { + return nil, fmt.Errorf("no IP:port in context") + } + return srt.Dialer.DialContext(ctx, network, addrs[0]) +} + +type contextKey int + +const ( + contextKeyIPPort contextKey = iota +) + +func (srt *ServerResolvingTransport) RoundTrip(request *http.Request) (*http.Response, error) { + if request.URL.Scheme != "matrix-federation" { + return nil, fmt.Errorf("unsupported scheme: %s", request.URL.Scheme) + } + resolved, err := srt.resolve(request.Context(), request.URL.Host) + if err != nil { + return nil, fmt.Errorf("failed to resolve server name: %w", err) + } + request = request.WithContext(context.WithValue(request.Context(), contextKeyIPPort, resolved.IPPort)) + request.URL.Scheme = "https" + request.URL.Host = resolved.HostHeader + request.Host = resolved.HostHeader + return srt.Transport.RoundTrip(request) +} + +func (srt *ServerResolvingTransport) resolve(ctx context.Context, serverName string) (*ResolvedServerName, error) { + res, lock := srt.getResolveCache(serverName) + if res != nil { + return res, nil + } + lock.Lock() + defer lock.Unlock() + res, _ = srt.getResolveCache(serverName) + if res != nil { + return res, nil + } + var err error + res, err = ResolveServerName(ctx, serverName, srt.ResolveOpts) + if err != nil { + return nil, err + } + srt.cacheLock.Lock() + srt.cache[serverName] = res + srt.cacheLock.Unlock() + return res, nil +} + +func (srt *ServerResolvingTransport) getResolveCache(serverName string) (*ResolvedServerName, *sync.Mutex) { + srt.cacheLock.Lock() + defer srt.cacheLock.Unlock() + if val, ok := srt.cache[serverName]; ok && time.Until(val.Expires) > 0 { + return val, nil + } + rl, ok := srt.resolveLocks[serverName] + if !ok { + rl = &sync.Mutex{} + srt.resolveLocks[serverName] = rl + } + return nil, rl +} diff --git a/federation/request_test.go b/federation/request_test.go new file mode 100644 index 00000000..e9037f2d --- /dev/null +++ b/federation/request_test.go @@ -0,0 +1,35 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation_test + +import ( + "encoding/json" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/federation" +) + +type serverVersionResp struct { + Server struct { + Name string `json:"name"` + Version string `json:"version"` + } `json:"server"` +} + +func TestNewFederationClient(t *testing.T) { + cli := federation.NewFederationHTTPClient() + resp, err := cli.Get("matrix-federation://maunium.net/_matrix/federation/v1/version") + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + var respData serverVersionResp + err = json.NewDecoder(resp.Body).Decode(&respData) + require.NoError(t, err) + require.Equal(t, "Synapse", respData.Server.Name) +} diff --git a/federation/resolution.go b/federation/resolution.go new file mode 100644 index 00000000..e6785988 --- /dev/null +++ b/federation/resolution.go @@ -0,0 +1,151 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/rs/zerolog" +) + +type ResolvedServerName struct { + ServerName string `json:"server_name"` + HostHeader string `json:"host_header"` + IPPort []string `json:"ip_port"` + Expires time.Time `json:"expires"` +} + +type ResolveServerNameOpts struct { + HTTPClient *http.Client + DNSClient *net.Resolver +} + +var ( + ErrInvalidServerName = errors.New("invalid server name") +) + +// ResolveServerName implements the full server discovery algorithm as specified in https://spec.matrix.org/v1.11/server-server-api/#resolving-server-names +func ResolveServerName(ctx context.Context, serverName string, opts ...*ResolveServerNameOpts) (*ResolvedServerName, error) { + var opt ResolveServerNameOpts + if len(opts) > 0 && opts[0] != nil { + opt = *opts[0] + } + if opt.HTTPClient == nil { + opt.HTTPClient = http.DefaultClient + } + if opt.DNSClient == nil { + opt.DNSClient = net.DefaultResolver + } + output := ResolvedServerName{ + ServerName: serverName, + HostHeader: serverName, + IPPort: []string{serverName}, + Expires: time.Now().Add(24 * time.Hour), + } + hostname, port, ok := ParseServerName(serverName) + if !ok { + return nil, ErrInvalidServerName + } + // Steps 1 and 2: handle IP literals and hostnames with port + if net.ParseIP(hostname) != nil || port != 0 { + if port == 0 { + port = 8448 + } + output.IPPort = []string{net.JoinHostPort(hostname, strconv.Itoa(int(port)))} + return &output, nil + } + // Step 3: resolve .well-known + wellKnown, expiry, err := RequestWellKnown(ctx, opt.HTTPClient, hostname) + if err != nil { + zerolog.Ctx(ctx).Trace(). + Str("server_name", serverName). + Err(err). + Msg("Failed to get well-known data") + } else if wellKnown != nil { + output.Expires = expiry + output.HostHeader = wellKnown.Server + hostname, port, ok = ParseServerName(wellKnown.Server) + // Step 3.1 and 3.2: IP literals and hostnames with port inside .well-known + if net.ParseIP(hostname) != nil || port != 0 { + if port == 0 { + port = 8448 + } + output.IPPort = []string{net.JoinHostPort(hostname, strconv.Itoa(int(port)))} + return &output, nil + } + } + // Step 3.3, 3.4, 4 and 5: resolve SRV records + srv, err := RequestSRV(ctx, opt.DNSClient, hostname) + if err != nil { + // TODO log more noisily for abnormal errors? + zerolog.Ctx(ctx).Trace(). + Str("server_name", serverName). + Str("hostname", hostname). + Err(err). + Msg("Failed to get SRV record") + } else if len(srv) > 0 { + output.IPPort = make([]string, len(srv)) + for i, record := range srv { + output.IPPort[i] = net.JoinHostPort(strings.TrimRight(record.Target, "."), strconv.Itoa(int(record.Port))) + } + return &output, nil + } + // Step 6 or 3.5: no SRV records were found, so default to port 8448 + output.IPPort = []string{net.JoinHostPort(hostname, "8448")} + return &output, nil +} + +// RequestSRV resolves the `_matrix-fed._tcp` SRV record for the given hostname. +// If the new matrix-fed record is not found, it falls back to the old `_matrix._tcp` record. +func RequestSRV(ctx context.Context, cli *net.Resolver, hostname string) ([]*net.SRV, error) { + _, target, err := cli.LookupSRV(ctx, "matrix-fed", "tcp", hostname) + var dnsErr *net.DNSError + if err != nil && errors.As(err, &dnsErr) && dnsErr.IsNotFound { + _, target, err = cli.LookupSRV(ctx, "matrix", "tcp", hostname) + } + return target, err +} + +// RequestWellKnown sends a request to the well-known endpoint of a server and returns the response, +// plus the time when the cache should expire. +func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (*RespWellKnown, time.Time, error) { + wellKnownURL := url.URL{ + Scheme: "https", + Host: hostname, + Path: "/.well-known/matrix/server", + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnownURL.String(), nil) + if err != nil { + return nil, time.Time{}, fmt.Errorf("failed to prepare request: %w", err) + } + resp, err := cli.Do(req) + if err != nil { + return nil, time.Time{}, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, time.Time{}, fmt.Errorf("unexpected status code %d", resp.StatusCode) + } + var respData RespWellKnown + err = json.NewDecoder(resp.Body).Decode(&respData) + if err != nil { + return nil, time.Time{}, fmt.Errorf("failed to decode response: %w", err) + } else if respData.Server == "" { + return nil, time.Time{}, errors.New("server name not found in response") + } + // TODO parse cache-control header + return &respData, time.Now().Add(24 * time.Hour), nil +} diff --git a/federation/resolution_test.go b/federation/resolution_test.go new file mode 100644 index 00000000..62200454 --- /dev/null +++ b/federation/resolution_test.go @@ -0,0 +1,115 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/federation" +) + +type resolveTestCase struct { + name string + serverName string + expected federation.ResolvedServerName +} + +func TestResolveServerName(t *testing.T) { + // See https://t2bot.io/docs/resolvematrix/ for more info on the RM test cases + testCases := []resolveTestCase{{ + "maunium", + "maunium.net", + federation.ResolvedServerName{ + HostHeader: "federation.mau.chat", + IPPort: []string{"meow.host.mau.fi:443"}, + }, + }, { + "IP literal", + "135.181.208.158", + federation.ResolvedServerName{ + HostHeader: "135.181.208.158", + IPPort: []string{"135.181.208.158:8448"}, + }, + }, { + "IP literal with port", + "135.181.208.158:8447", + federation.ResolvedServerName{ + HostHeader: "135.181.208.158:8447", + IPPort: []string{"135.181.208.158:8447"}, + }, + }, { + "RM Step 2", + "2.s.resolvematrix.dev:7652", + federation.ResolvedServerName{ + HostHeader: "2.s.resolvematrix.dev:7652", + IPPort: []string{"2.s.resolvematrix.dev:7652"}, + }, + }, { + "RM Step 3B", + "3b.s.resolvematrix.dev", + federation.ResolvedServerName{ + HostHeader: "wk.3b.s.resolvematrix.dev:7753", + IPPort: []string{"wk.3b.s.resolvematrix.dev:7753"}, + }, + }, { + "RM Step 3C", + "3c.s.resolvematrix.dev", + federation.ResolvedServerName{ + HostHeader: "wk.3c.s.resolvematrix.dev", + IPPort: []string{"srv.wk.3c.s.resolvematrix.dev:7754"}, + }, + }, { + "RM Step 3C MSC4040", + "3c.msc4040.s.resolvematrix.dev", + federation.ResolvedServerName{ + HostHeader: "wk.3c.msc4040.s.resolvematrix.dev", + IPPort: []string{"srv.wk.3c.msc4040.s.resolvematrix.dev:7053"}, + }, + }, { + "RM Step 3D", + "3d.s.resolvematrix.dev", + federation.ResolvedServerName{ + HostHeader: "wk.3d.s.resolvematrix.dev", + IPPort: []string{"wk.3d.s.resolvematrix.dev:8448"}, + }, + }, { + "RM Step 4", + "4.s.resolvematrix.dev", + federation.ResolvedServerName{ + HostHeader: "4.s.resolvematrix.dev", + IPPort: []string{"srv.4.s.resolvematrix.dev:7855"}, + }, + }, { + "RM Step 4 MSC4040", + "4.msc4040.s.resolvematrix.dev", + federation.ResolvedServerName{ + HostHeader: "4.msc4040.s.resolvematrix.dev", + IPPort: []string{"srv.4.msc4040.s.resolvematrix.dev:7054"}, + }, + }, { + "RM Step 5", + "5.s.resolvematrix.dev", + federation.ResolvedServerName{ + HostHeader: "5.s.resolvematrix.dev", + IPPort: []string{"5.s.resolvematrix.dev:8448"}, + }, + }} + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.expected.ServerName = tc.serverName + resp, err := federation.ResolveServerName(context.TODO(), tc.serverName) + require.NoError(t, err) + resp.Expires = time.Time{} + assert.Equal(t, tc.expected, *resp) + }) + } +} diff --git a/federation/servername.go b/federation/servername.go new file mode 100644 index 00000000..33590712 --- /dev/null +++ b/federation/servername.go @@ -0,0 +1,95 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation + +import ( + "net" + "strconv" + "strings" +) + +func isSpecCompliantIPv6(host string) bool { + // IPv6address = 2*45IPv6char + // IPv6char = DIGIT / %x41-46 / %x61-66 / ":" / "." + // ; 0-9, A-F, a-f, :, . + if len(host) < 2 || len(host) > 45 { + return false + } + for _, ch := range host { + if (ch < '0' || ch > '9') && (ch < 'a' || ch > 'f') && (ch < 'A' || ch > 'F') && ch != ':' && ch != '.' { + return false + } + } + return true +} + +func isValidIPv4Chunk(str string) bool { + if len(str) == 0 || len(str) > 3 { + return false + } + for _, ch := range str { + if ch < '0' || ch > '9' { + return false + } + } + return true + +} + +func isSpecCompliantIPv4(host string) bool { + // IPv4address = 1*3DIGIT "." 1*3DIGIT "." 1*3DIGIT "." 1*3DIGIT + if len(host) < 7 || len(host) > 15 { + return false + } + parts := strings.Split(host, ".") + return len(parts) == 4 && + isValidIPv4Chunk(parts[0]) && + isValidIPv4Chunk(parts[1]) && + isValidIPv4Chunk(parts[2]) && + isValidIPv4Chunk(parts[3]) +} + +func isSpecCompliantDNSName(host string) bool { + // dns-name = 1*255dns-char + // dns-char = DIGIT / ALPHA / "-" / "." + if len(host) == 0 || len(host) > 255 { + return false + } + for _, ch := range host { + if (ch < '0' || ch > '9') && (ch < 'a' || ch > 'z') && (ch < 'A' || ch > 'Z') && ch != '-' && ch != '.' { + return false + } + } + return true +} + +// ParseServerName parses the port and hostname from a Matrix server name and validates that +// it matches the grammar specified in https://spec.matrix.org/v1.11/appendices/#server-name +func ParseServerName(serverName string) (host string, port uint16, ok bool) { + if len(serverName) == 0 || len(serverName) > 255 { + return + } + colonIdx := strings.LastIndexByte(serverName, ':') + if colonIdx > 0 { + u64Port, err := strconv.ParseUint(serverName[colonIdx+1:], 10, 16) + if err == nil { + port = uint16(u64Port) + serverName = serverName[:colonIdx] + } + } + if serverName[0] == '[' { + if serverName[len(serverName)-1] != ']' { + return + } + host = serverName[1 : len(serverName)-1] + ok = isSpecCompliantIPv6(host) && net.ParseIP(host) != nil + } else { + host = serverName + ok = isSpecCompliantDNSName(host) || isSpecCompliantIPv4(host) + } + return +} diff --git a/federation/servername_test.go b/federation/servername_test.go new file mode 100644 index 00000000..156d692f --- /dev/null +++ b/federation/servername_test.go @@ -0,0 +1,64 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "maunium.net/go/mautrix/federation" +) + +type parseTestCase struct { + name string + serverName string + hostname string + port uint16 +} + +func TestParseServerName(t *testing.T) { + testCases := []parseTestCase{{ + "Domain", + "matrix.org", + "matrix.org", + 0, + }, { + "Domain with port", + "matrix.org:8448", + "matrix.org", + 8448, + }, { + "IPv4 literal", + "1.2.3.4", + "1.2.3.4", + 0, + }, { + "IPv4 literal with port", + "1.2.3.4:8448", + "1.2.3.4", + 8448, + }, { + "IPv6 literal", + "[1234:5678::abcd]", + "1234:5678::abcd", + 0, + }, { + "IPv6 literal with port", + "[1234:5678::abcd]:8448", + "1234:5678::abcd", + 8448, + }} + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + hostname, port, ok := federation.ParseServerName(tc.serverName) + assert.True(t, ok) + assert.Equal(t, tc.hostname, hostname) + assert.Equal(t, tc.port, port) + }) + } +}