federation: add wrappers for some federation endpoints

This commit is contained in:
Tulir Asokan 2024-07-28 23:43:07 +03:00
commit 593ad86b80
6 changed files with 406 additions and 82 deletions

View file

@ -11,7 +11,6 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"time"
@ -43,7 +42,7 @@ type ProvisioningAPI struct {
log zerolog.Logger
net bridgev2.NetworkConnector
fedClient *http.Client
fedClient *federation.Client
logins map[string]*ProvLogin
loginsLock sync.RWMutex
@ -89,9 +88,9 @@ func (prov *ProvisioningAPI) Init() {
prov.logins = make(map[string]*ProvLogin)
prov.net = prov.br.Bridge.Network
prov.log = prov.br.Log.With().Str("component", "provisioning").Logger()
prov.fedClient = federation.NewFederationHTTPClient()
tp := prov.fedClient.Transport.(*federation.ServerResolvingTransport)
prov.fedClient.Timeout = 20 * time.Second
prov.fedClient = federation.NewClient("", nil)
prov.fedClient.HTTP.Timeout = 20 * time.Second
tp := prov.fedClient.HTTP.Transport.(*federation.ServerResolvingTransport)
tp.Dialer.Timeout = 10 * time.Second
tp.Transport.ResponseHeaderTimeout = 10 * time.Second
tp.Transport.TLSHandshakeTimeout = 10 * time.Second
@ -159,10 +158,6 @@ func (prov *ProvisioningAPI) checkMatrixAuth(ctx context.Context, userID id.User
}
}
type respOpenIDUserInfo struct {
Sub id.UserID `json:"sub"`
}
func (prov *ProvisioningAPI) checkFederatedMatrixAuth(ctx context.Context, userID id.UserID, token string) error {
homeserver := userID.Homeserver()
wrappedToken := fmt.Sprintf("%s:%s", homeserver, token)
@ -171,9 +166,9 @@ func (prov *ProvisioningAPI) checkFederatedMatrixAuth(ctx context.Context, userI
defer prov.matrixAuthCacheLock.Unlock()
if cached, ok := prov.matrixAuthCache[wrappedToken]; ok && cached.Expires.After(time.Now()) && cached.UserID == userID {
return nil
} else if validationResult, err := prov.validateOpenIDToken(ctx, homeserver, token); err != nil {
} else if validationResult, err := prov.fedClient.GetOpenIDUserInfo(ctx, homeserver, token); err != nil {
return fmt.Errorf("failed to validate OpenID token: %w", err)
} else if validationResult != userID {
} else if validationResult.Sub != userID {
return fmt.Errorf("mismatching user ID (%q != %q)", validationResult, userID)
} else {
prov.matrixAuthCache[wrappedToken] = matrixAuthCacheEntry{
@ -184,35 +179,6 @@ func (prov *ProvisioningAPI) checkFederatedMatrixAuth(ctx context.Context, userI
}
}
func (prov *ProvisioningAPI) validateOpenIDToken(ctx context.Context, server string, token string) (id.UserID, error) {
reqURL := url.URL{
Scheme: "matrix-federation",
Host: server,
Path: "/_matrix/federation/v1/openid/userinfo",
RawQuery: (&url.Values{
"access_token": {token},
}).Encode(),
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil)
if err != nil {
return "", fmt.Errorf("failed to prepare request: %w", err)
}
resp, err := prov.fedClient.Do(req)
if err != nil {
return "", fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("unexpected status code %d", resp.StatusCode)
}
var respData respOpenIDUserInfo
err = json.NewDecoder(resp.Body).Decode(&respData)
if err != nil {
return "", fmt.Errorf("failed to decode response: %w", err)
}
return respData.Sub, nil
}
func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")

373
federation/client.go Normal file
View file

@ -0,0 +1,373 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package federation
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"time"
"go.mau.fi/util/exslices"
"go.mau.fi/util/jsontime"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/id"
)
type Client struct {
HTTP *http.Client
ServerName string
UserAgent string
Key *SigningKey
}
func NewClient(serverName string, key *SigningKey) *Client {
return &Client{
HTTP: &http.Client{
Transport: NewServerResolvingTransport(),
Timeout: 120 * time.Second,
},
UserAgent: mautrix.DefaultUserAgent,
ServerName: serverName,
Key: key,
}
}
func (c *Client) Version(ctx context.Context, serverName string) (resp *RespServerVersion, err error) {
err = c.MakeRequest(ctx, serverName, false, http.MethodGet, URLPath{"v1", "version"}, nil, &resp)
return
}
func (c *Client) ServerKeys(ctx context.Context, serverName string) (resp *ServerKeyResponse, err error) {
err = c.MakeRequest(ctx, serverName, false, http.MethodGet, KeyURLPath{"v2", "server"}, nil, &resp)
return
}
func (c *Client) QueryKeys(ctx context.Context, serverName string, req *ReqQueryKeys) (resp *ServerKeyResponse, err error) {
err = c.MakeRequest(ctx, serverName, false, http.MethodPost, KeyURLPath{"v2", "query"}, req, &resp)
return
}
type PDU = json.RawMessage
type EDU = json.RawMessage
type ReqSendTransaction struct {
Destination string `json:"destination"`
TxnID string `json:"-"`
Origin string `json:"origin"`
OriginServerTS jsontime.UnixMilli `json:"origin_server_ts"`
PDUs []PDU `json:"pdus"`
EDUs []EDU `json:"edus,omitempty"`
}
type PDUProcessingResult struct {
Error string `json:"error,omitempty"`
}
type RespSendTransaction struct {
PDUs map[id.EventID]PDUProcessingResult `json:"pdus"`
}
func (c *Client) SendTransaction(ctx context.Context, req *ReqSendTransaction) (resp *RespSendTransaction, err error) {
err = c.MakeRequest(ctx, req.Destination, true, http.MethodPost, URLPath{"v1", "send", req.TxnID}, req, &resp)
return
}
type RespGetEventAuthChain struct {
AuthChain []PDU `json:"auth_chain"`
}
func (c *Client) GetEventAuthChain(ctx context.Context, serverName string, roomID id.RoomID, eventID id.EventID) (resp *RespGetEventAuthChain, err error) {
err = c.MakeRequest(ctx, serverName, true, http.MethodGet, URLPath{"v1", "event_auth", roomID, eventID}, nil, &resp)
return
}
type ReqBackfill struct {
ServerName string
RoomID id.RoomID
Limit int
BackfillFrom []id.EventID
}
type RespBackfill struct {
Origin string `json:"origin"`
OriginServerTS jsontime.UnixMilli `json:"origin_server_ts"`
PDUs []PDU `json:"pdus"`
}
func (c *Client) Backfill(ctx context.Context, req *ReqBackfill) (resp *RespBackfill, err error) {
_, _, err = c.MakeFullRequest(ctx, RequestParams{
ServerName: req.ServerName,
Method: http.MethodGet,
Path: URLPath{"v1", "backfill", req.RoomID},
Query: url.Values{
"limit": {strconv.Itoa(req.Limit)},
"v": exslices.CastToString[string](req.BackfillFrom),
},
Authenticate: true,
ResponseJSON: &resp,
})
return
}
type ReqGetMissingEvents struct {
ServerName string `json:"-"`
RoomID id.RoomID `json:"-"`
EarliestEvents []id.EventID `json:"earliest_events"`
LatestEvents []id.EventID `json:"latest_events"`
Limit int `json:"limit,omitempty"`
MinDepth int `json:"min_depth,omitempty"`
}
type RespGetMissingEvents struct {
Events []PDU `json:"events"`
}
func (c *Client) GetMissingEvents(ctx context.Context, req *ReqGetMissingEvents) (resp *RespGetMissingEvents, err error) {
err = c.MakeRequest(ctx, req.ServerName, true, http.MethodPost, URLPath{"v1", "get_missing_events", req.RoomID}, req, &resp)
return
}
func (c *Client) GetEvent(ctx context.Context, serverName string, eventID id.EventID) (resp *RespBackfill, err error) {
err = c.MakeRequest(ctx, serverName, true, http.MethodGet, URLPath{"v1", "event", eventID}, nil, &resp)
return
}
type RespGetState struct {
AuthChain []PDU `json:"auth_chain"`
PDUs []PDU `json:"pdus"`
}
func (c *Client) GetState(ctx context.Context, serverName string, roomID id.RoomID, eventID id.EventID) (resp *RespGetState, err error) {
_, _, err = c.MakeFullRequest(ctx, RequestParams{
ServerName: serverName,
Method: http.MethodGet,
Path: URLPath{"v1", "state", roomID},
Query: url.Values{
"event_id": {string(eventID)},
},
Authenticate: true,
ResponseJSON: &resp,
})
return
}
type RespGetStateIDs struct {
AuthChain []id.EventID `json:"auth_chain_ids"`
PDUs []id.EventID `json:"pdu_ids"`
}
func (c *Client) GetStateIDs(ctx context.Context, serverName string, roomID id.RoomID, eventID id.EventID) (resp *RespGetStateIDs, err error) {
_, _, err = c.MakeFullRequest(ctx, RequestParams{
ServerName: serverName,
Method: http.MethodGet,
Path: URLPath{"v1", "state_ids", roomID},
Query: url.Values{
"event_id": {string(eventID)},
},
Authenticate: true,
ResponseJSON: &resp,
})
return
}
func (c *Client) TimestampToEvent(ctx context.Context, serverName string, roomID id.RoomID, timestamp time.Time, dir mautrix.Direction) (resp *mautrix.RespTimestampToEvent, err error) {
_, _, err = c.MakeFullRequest(ctx, RequestParams{
ServerName: serverName,
Method: http.MethodGet,
Path: URLPath{"v1", "timestamp_to_event", roomID},
Query: url.Values{
"dir": {string(dir)},
"ts": {strconv.FormatInt(timestamp.UnixMilli(), 10)},
},
Authenticate: true,
ResponseJSON: &resp,
})
return
}
type RespOpenIDUserInfo struct {
Sub id.UserID `json:"sub"`
}
func (c *Client) GetOpenIDUserInfo(ctx context.Context, serverName, accessToken string) (resp *RespOpenIDUserInfo, err error) {
_, _, err = c.MakeFullRequest(ctx, RequestParams{
ServerName: serverName,
Method: http.MethodGet,
Path: URLPath{"v1", "openid", "userinfo"},
Query: url.Values{"access_token": {accessToken}},
ResponseJSON: &resp,
})
return
}
type URLPath []any
func (fup URLPath) FullPath() []any {
return append([]any{"_matrix", "federation"}, []any(fup)...)
}
type KeyURLPath []any
func (fkup KeyURLPath) FullPath() []any {
return append([]any{"_matrix", "key"}, []any(fkup)...)
}
type RequestParams struct {
ServerName string
Method string
Path mautrix.PrefixableURLPath
Query url.Values
Authenticate bool
RequestJSON any
ResponseJSON any
DontReadBody bool
}
func (c *Client) MakeRequest(ctx context.Context, serverName string, authenticate bool, method string, path mautrix.PrefixableURLPath, reqJSON, respJSON any) error {
_, _, err := c.MakeFullRequest(ctx, RequestParams{
ServerName: serverName,
Method: method,
Path: path,
Authenticate: authenticate,
RequestJSON: reqJSON,
ResponseJSON: respJSON,
})
return err
}
func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]byte, *http.Response, error) {
req, err := c.compileRequest(ctx, params)
if err != nil {
return nil, nil, err
}
resp, err := c.HTTP.Do(req)
if err != nil {
return nil, nil, mautrix.HTTPError{
Request: req,
Response: resp,
Message: "request error",
WrappedError: err,
}
}
defer func() {
_ = resp.Body.Close()
}()
var body []byte
if resp.StatusCode >= 400 {
body, err = mautrix.ParseErrorResponse(req, resp)
return body, resp, err
} else if params.ResponseJSON != nil || !params.DontReadBody {
body, err = io.ReadAll(resp.Body)
if err != nil {
return body, resp, mautrix.HTTPError{
Request: req,
Response: resp,
Message: "failed to read response body",
WrappedError: err,
}
}
if params.ResponseJSON != nil {
err = json.Unmarshal(body, params.ResponseJSON)
if err != nil {
return body, resp, mautrix.HTTPError{
Request: req,
Response: resp,
Message: "failed to unmarshal response JSON",
ResponseBody: string(body),
WrappedError: err,
}
}
}
}
return body, resp, nil
}
func (c *Client) compileRequest(ctx context.Context, params RequestParams) (*http.Request, error) {
reqURL := mautrix.BuildURL(&url.URL{
Scheme: "matrix-federation",
Host: params.ServerName,
}, params.Path.FullPath()...)
reqURL.RawQuery = params.Query.Encode()
var reqJSON json.RawMessage
var reqBody io.Reader
if params.RequestJSON != nil {
var err error
reqJSON, err = json.Marshal(params.RequestJSON)
if err != nil {
return nil, mautrix.HTTPError{
Message: "failed to marshal JSON",
WrappedError: err,
}
}
reqBody = bytes.NewReader(reqJSON)
}
req, err := http.NewRequestWithContext(ctx, params.Method, reqURL.String(), reqBody)
if err != nil {
return nil, mautrix.HTTPError{
Message: "failed to create request",
WrappedError: err,
}
}
req.Header.Set("User-Agent", c.UserAgent)
if params.Authenticate {
if c.ServerName == "" || c.Key == nil {
return nil, mautrix.HTTPError{
Message: "client not configured for authentication",
}
}
auth, err := (&signableRequest{
Method: req.Method,
URI: reqURL.RequestURI(),
Origin: c.ServerName,
Destination: params.ServerName,
Content: reqJSON,
}).Sign(c.Key)
if err != nil {
return nil, mautrix.HTTPError{
Message: "failed to sign request",
WrappedError: err,
}
}
req.Header.Set("Authorization", auth)
}
return req, nil
}
type signableRequest struct {
Method string `json:"method"`
URI string `json:"uri"`
Origin string `json:"origin"`
Destination string `json:"destination"`
Content any `json:"content,omitempty"`
}
func (r *signableRequest) Sign(key *SigningKey) (string, error) {
sig, err := key.SignJSON(r)
if err != nil {
return "", err
}
return fmt.Sprintf(
`X-Matrix origin="%s",destination="%s",key="%s",sig="%s"`,
r.Origin,
r.Destination,
key.ID,
base64.RawURLEncoding.EncodeToString(sig),
), nil
}

23
federation/client_test.go Normal file
View file

@ -0,0 +1,23 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package federation_test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"maunium.net/go/mautrix/federation"
)
func TestClient_Version(t *testing.T) {
cli := federation.NewClient("", nil)
resp, err := cli.Version(context.TODO(), "maunium.net")
require.NoError(t, err)
require.Equal(t, "Synapse", resp.Server.Name)
}

View file

@ -40,13 +40,6 @@ func NewServerResolvingTransport() *ServerResolvingTransport {
return srt
}
func NewFederationHTTPClient() *http.Client {
return &http.Client{
Transport: NewServerResolvingTransport(),
Timeout: 120 * time.Second,
}
}
var _ http.RoundTripper = (*ServerResolvingTransport)(nil)
func (srt *ServerResolvingTransport) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {

View file

@ -1,35 +0,0 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package federation_test
import (
"encoding/json"
"net/http"
"testing"
"github.com/stretchr/testify/require"
"maunium.net/go/mautrix/federation"
)
type serverVersionResp struct {
Server struct {
Name string `json:"name"`
Version string `json:"version"`
} `json:"server"`
}
func TestNewFederationClient(t *testing.T) {
cli := federation.NewFederationHTTPClient()
resp, err := cli.Get("matrix-federation://maunium.net/_matrix/federation/v1/version")
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
var respData serverVersionResp
err = json.NewDecoder(resp.Body).Decode(&respData)
require.NoError(t, err)
require.Equal(t, "Synapse", respData.Server.Name)
}

View file

@ -83,6 +83,10 @@ type ServerVerifyKey struct {
Key id.SigningKey `json:"key"`
}
func (svk *ServerVerifyKey) Decode() (ed25519.PublicKey, error) {
return base64.RawStdEncoding.DecodeString(string(svk.Key))
}
type OldVerifyKey struct {
Key id.SigningKey `json:"key"`
ExpiredTS jsontime.UnixMilli `json:"expired_ts"`