diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index f250ccd7..4c87af0f 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -11,7 +11,6 @@ import ( "encoding/json" "fmt" "net/http" - "net/url" "strings" "sync" "time" @@ -43,7 +42,7 @@ type ProvisioningAPI struct { log zerolog.Logger net bridgev2.NetworkConnector - fedClient *http.Client + fedClient *federation.Client logins map[string]*ProvLogin loginsLock sync.RWMutex @@ -89,9 +88,9 @@ func (prov *ProvisioningAPI) Init() { prov.logins = make(map[string]*ProvLogin) prov.net = prov.br.Bridge.Network prov.log = prov.br.Log.With().Str("component", "provisioning").Logger() - prov.fedClient = federation.NewFederationHTTPClient() - tp := prov.fedClient.Transport.(*federation.ServerResolvingTransport) - prov.fedClient.Timeout = 20 * time.Second + prov.fedClient = federation.NewClient("", nil) + prov.fedClient.HTTP.Timeout = 20 * time.Second + tp := prov.fedClient.HTTP.Transport.(*federation.ServerResolvingTransport) tp.Dialer.Timeout = 10 * time.Second tp.Transport.ResponseHeaderTimeout = 10 * time.Second tp.Transport.TLSHandshakeTimeout = 10 * time.Second @@ -159,10 +158,6 @@ func (prov *ProvisioningAPI) checkMatrixAuth(ctx context.Context, userID id.User } } -type respOpenIDUserInfo struct { - Sub id.UserID `json:"sub"` -} - func (prov *ProvisioningAPI) checkFederatedMatrixAuth(ctx context.Context, userID id.UserID, token string) error { homeserver := userID.Homeserver() wrappedToken := fmt.Sprintf("%s:%s", homeserver, token) @@ -171,9 +166,9 @@ func (prov *ProvisioningAPI) checkFederatedMatrixAuth(ctx context.Context, userI defer prov.matrixAuthCacheLock.Unlock() if cached, ok := prov.matrixAuthCache[wrappedToken]; ok && cached.Expires.After(time.Now()) && cached.UserID == userID { return nil - } else if validationResult, err := prov.validateOpenIDToken(ctx, homeserver, token); err != nil { + } else if validationResult, err := prov.fedClient.GetOpenIDUserInfo(ctx, homeserver, token); err != nil { return fmt.Errorf("failed to validate OpenID token: %w", err) - } else if validationResult != userID { + } else if validationResult.Sub != userID { return fmt.Errorf("mismatching user ID (%q != %q)", validationResult, userID) } else { prov.matrixAuthCache[wrappedToken] = matrixAuthCacheEntry{ @@ -184,35 +179,6 @@ func (prov *ProvisioningAPI) checkFederatedMatrixAuth(ctx context.Context, userI } } -func (prov *ProvisioningAPI) validateOpenIDToken(ctx context.Context, server string, token string) (id.UserID, error) { - reqURL := url.URL{ - Scheme: "matrix-federation", - Host: server, - Path: "/_matrix/federation/v1/openid/userinfo", - RawQuery: (&url.Values{ - "access_token": {token}, - }).Encode(), - } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil) - if err != nil { - return "", fmt.Errorf("failed to prepare request: %w", err) - } - resp, err := prov.fedClient.Do(req) - if err != nil { - return "", fmt.Errorf("failed to send request: %w", err) - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("unexpected status code %d", resp.StatusCode) - } - var respData respOpenIDUserInfo - err = json.NewDecoder(resp.Body).Decode(&respData) - if err != nil { - return "", fmt.Errorf("failed to decode response: %w", err) - } - return respData.Sub, nil -} - func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") diff --git a/federation/client.go b/federation/client.go new file mode 100644 index 00000000..dc8c139c --- /dev/null +++ b/federation/client.go @@ -0,0 +1,373 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "time" + + "go.mau.fi/util/exslices" + "go.mau.fi/util/jsontime" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/id" +) + +type Client struct { + HTTP *http.Client + ServerName string + UserAgent string + Key *SigningKey +} + +func NewClient(serverName string, key *SigningKey) *Client { + return &Client{ + HTTP: &http.Client{ + Transport: NewServerResolvingTransport(), + Timeout: 120 * time.Second, + }, + UserAgent: mautrix.DefaultUserAgent, + ServerName: serverName, + Key: key, + } +} + +func (c *Client) Version(ctx context.Context, serverName string) (resp *RespServerVersion, err error) { + err = c.MakeRequest(ctx, serverName, false, http.MethodGet, URLPath{"v1", "version"}, nil, &resp) + return +} + +func (c *Client) ServerKeys(ctx context.Context, serverName string) (resp *ServerKeyResponse, err error) { + err = c.MakeRequest(ctx, serverName, false, http.MethodGet, KeyURLPath{"v2", "server"}, nil, &resp) + return +} + +func (c *Client) QueryKeys(ctx context.Context, serverName string, req *ReqQueryKeys) (resp *ServerKeyResponse, err error) { + err = c.MakeRequest(ctx, serverName, false, http.MethodPost, KeyURLPath{"v2", "query"}, req, &resp) + return +} + +type PDU = json.RawMessage +type EDU = json.RawMessage + +type ReqSendTransaction struct { + Destination string `json:"destination"` + TxnID string `json:"-"` + + Origin string `json:"origin"` + OriginServerTS jsontime.UnixMilli `json:"origin_server_ts"` + PDUs []PDU `json:"pdus"` + EDUs []EDU `json:"edus,omitempty"` +} + +type PDUProcessingResult struct { + Error string `json:"error,omitempty"` +} + +type RespSendTransaction struct { + PDUs map[id.EventID]PDUProcessingResult `json:"pdus"` +} + +func (c *Client) SendTransaction(ctx context.Context, req *ReqSendTransaction) (resp *RespSendTransaction, err error) { + err = c.MakeRequest(ctx, req.Destination, true, http.MethodPost, URLPath{"v1", "send", req.TxnID}, req, &resp) + return +} + +type RespGetEventAuthChain struct { + AuthChain []PDU `json:"auth_chain"` +} + +func (c *Client) GetEventAuthChain(ctx context.Context, serverName string, roomID id.RoomID, eventID id.EventID) (resp *RespGetEventAuthChain, err error) { + err = c.MakeRequest(ctx, serverName, true, http.MethodGet, URLPath{"v1", "event_auth", roomID, eventID}, nil, &resp) + return +} + +type ReqBackfill struct { + ServerName string + RoomID id.RoomID + Limit int + BackfillFrom []id.EventID +} + +type RespBackfill struct { + Origin string `json:"origin"` + OriginServerTS jsontime.UnixMilli `json:"origin_server_ts"` + PDUs []PDU `json:"pdus"` +} + +func (c *Client) Backfill(ctx context.Context, req *ReqBackfill) (resp *RespBackfill, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.ServerName, + Method: http.MethodGet, + Path: URLPath{"v1", "backfill", req.RoomID}, + Query: url.Values{ + "limit": {strconv.Itoa(req.Limit)}, + "v": exslices.CastToString[string](req.BackfillFrom), + }, + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + +type ReqGetMissingEvents struct { + ServerName string `json:"-"` + RoomID id.RoomID `json:"-"` + EarliestEvents []id.EventID `json:"earliest_events"` + LatestEvents []id.EventID `json:"latest_events"` + Limit int `json:"limit,omitempty"` + MinDepth int `json:"min_depth,omitempty"` +} + +type RespGetMissingEvents struct { + Events []PDU `json:"events"` +} + +func (c *Client) GetMissingEvents(ctx context.Context, req *ReqGetMissingEvents) (resp *RespGetMissingEvents, err error) { + err = c.MakeRequest(ctx, req.ServerName, true, http.MethodPost, URLPath{"v1", "get_missing_events", req.RoomID}, req, &resp) + return +} + +func (c *Client) GetEvent(ctx context.Context, serverName string, eventID id.EventID) (resp *RespBackfill, err error) { + err = c.MakeRequest(ctx, serverName, true, http.MethodGet, URLPath{"v1", "event", eventID}, nil, &resp) + return +} + +type RespGetState struct { + AuthChain []PDU `json:"auth_chain"` + PDUs []PDU `json:"pdus"` +} + +func (c *Client) GetState(ctx context.Context, serverName string, roomID id.RoomID, eventID id.EventID) (resp *RespGetState, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: serverName, + Method: http.MethodGet, + Path: URLPath{"v1", "state", roomID}, + Query: url.Values{ + "event_id": {string(eventID)}, + }, + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + +type RespGetStateIDs struct { + AuthChain []id.EventID `json:"auth_chain_ids"` + PDUs []id.EventID `json:"pdu_ids"` +} + +func (c *Client) GetStateIDs(ctx context.Context, serverName string, roomID id.RoomID, eventID id.EventID) (resp *RespGetStateIDs, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: serverName, + Method: http.MethodGet, + Path: URLPath{"v1", "state_ids", roomID}, + Query: url.Values{ + "event_id": {string(eventID)}, + }, + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) TimestampToEvent(ctx context.Context, serverName string, roomID id.RoomID, timestamp time.Time, dir mautrix.Direction) (resp *mautrix.RespTimestampToEvent, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: serverName, + Method: http.MethodGet, + Path: URLPath{"v1", "timestamp_to_event", roomID}, + Query: url.Values{ + "dir": {string(dir)}, + "ts": {strconv.FormatInt(timestamp.UnixMilli(), 10)}, + }, + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + +type RespOpenIDUserInfo struct { + Sub id.UserID `json:"sub"` +} + +func (c *Client) GetOpenIDUserInfo(ctx context.Context, serverName, accessToken string) (resp *RespOpenIDUserInfo, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: serverName, + Method: http.MethodGet, + Path: URLPath{"v1", "openid", "userinfo"}, + Query: url.Values{"access_token": {accessToken}}, + ResponseJSON: &resp, + }) + return +} + +type URLPath []any + +func (fup URLPath) FullPath() []any { + return append([]any{"_matrix", "federation"}, []any(fup)...) +} + +type KeyURLPath []any + +func (fkup KeyURLPath) FullPath() []any { + return append([]any{"_matrix", "key"}, []any(fkup)...) +} + +type RequestParams struct { + ServerName string + Method string + Path mautrix.PrefixableURLPath + Query url.Values + Authenticate bool + RequestJSON any + + ResponseJSON any + DontReadBody bool +} + +func (c *Client) MakeRequest(ctx context.Context, serverName string, authenticate bool, method string, path mautrix.PrefixableURLPath, reqJSON, respJSON any) error { + _, _, err := c.MakeFullRequest(ctx, RequestParams{ + ServerName: serverName, + Method: method, + Path: path, + Authenticate: authenticate, + RequestJSON: reqJSON, + ResponseJSON: respJSON, + }) + return err +} + +func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]byte, *http.Response, error) { + req, err := c.compileRequest(ctx, params) + if err != nil { + return nil, nil, err + } + resp, err := c.HTTP.Do(req) + if err != nil { + return nil, nil, mautrix.HTTPError{ + Request: req, + Response: resp, + + Message: "request error", + WrappedError: err, + } + } + defer func() { + _ = resp.Body.Close() + }() + var body []byte + if resp.StatusCode >= 400 { + body, err = mautrix.ParseErrorResponse(req, resp) + return body, resp, err + } else if params.ResponseJSON != nil || !params.DontReadBody { + body, err = io.ReadAll(resp.Body) + if err != nil { + return body, resp, mautrix.HTTPError{ + Request: req, + Response: resp, + + Message: "failed to read response body", + WrappedError: err, + } + } + if params.ResponseJSON != nil { + err = json.Unmarshal(body, params.ResponseJSON) + if err != nil { + return body, resp, mautrix.HTTPError{ + Request: req, + Response: resp, + + Message: "failed to unmarshal response JSON", + ResponseBody: string(body), + WrappedError: err, + } + } + } + } + return body, resp, nil +} + +func (c *Client) compileRequest(ctx context.Context, params RequestParams) (*http.Request, error) { + reqURL := mautrix.BuildURL(&url.URL{ + Scheme: "matrix-federation", + Host: params.ServerName, + }, params.Path.FullPath()...) + reqURL.RawQuery = params.Query.Encode() + var reqJSON json.RawMessage + var reqBody io.Reader + if params.RequestJSON != nil { + var err error + reqJSON, err = json.Marshal(params.RequestJSON) + if err != nil { + return nil, mautrix.HTTPError{ + Message: "failed to marshal JSON", + WrappedError: err, + } + } + reqBody = bytes.NewReader(reqJSON) + } + req, err := http.NewRequestWithContext(ctx, params.Method, reqURL.String(), reqBody) + if err != nil { + return nil, mautrix.HTTPError{ + Message: "failed to create request", + WrappedError: err, + } + } + req.Header.Set("User-Agent", c.UserAgent) + if params.Authenticate { + if c.ServerName == "" || c.Key == nil { + return nil, mautrix.HTTPError{ + Message: "client not configured for authentication", + } + } + auth, err := (&signableRequest{ + Method: req.Method, + URI: reqURL.RequestURI(), + Origin: c.ServerName, + Destination: params.ServerName, + Content: reqJSON, + }).Sign(c.Key) + if err != nil { + return nil, mautrix.HTTPError{ + Message: "failed to sign request", + WrappedError: err, + } + } + req.Header.Set("Authorization", auth) + } + return req, nil +} + +type signableRequest struct { + Method string `json:"method"` + URI string `json:"uri"` + Origin string `json:"origin"` + Destination string `json:"destination"` + Content any `json:"content,omitempty"` +} + +func (r *signableRequest) Sign(key *SigningKey) (string, error) { + sig, err := key.SignJSON(r) + if err != nil { + return "", err + } + return fmt.Sprintf( + `X-Matrix origin="%s",destination="%s",key="%s",sig="%s"`, + r.Origin, + r.Destination, + key.ID, + base64.RawURLEncoding.EncodeToString(sig), + ), nil +} diff --git a/federation/client_test.go b/federation/client_test.go new file mode 100644 index 00000000..ba3c3ed4 --- /dev/null +++ b/federation/client_test.go @@ -0,0 +1,23 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/federation" +) + +func TestClient_Version(t *testing.T) { + cli := federation.NewClient("", nil) + resp, err := cli.Version(context.TODO(), "maunium.net") + require.NoError(t, err) + require.Equal(t, "Synapse", resp.Server.Name) +} diff --git a/federation/request.go b/federation/httpclient.go similarity index 95% rename from federation/request.go rename to federation/httpclient.go index faeb16ad..d6d97280 100644 --- a/federation/request.go +++ b/federation/httpclient.go @@ -40,13 +40,6 @@ func NewServerResolvingTransport() *ServerResolvingTransport { return srt } -func NewFederationHTTPClient() *http.Client { - return &http.Client{ - Transport: NewServerResolvingTransport(), - Timeout: 120 * time.Second, - } -} - var _ http.RoundTripper = (*ServerResolvingTransport)(nil) func (srt *ServerResolvingTransport) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { diff --git a/federation/request_test.go b/federation/request_test.go deleted file mode 100644 index e9037f2d..00000000 --- a/federation/request_test.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package federation_test - -import ( - "encoding/json" - "net/http" - "testing" - - "github.com/stretchr/testify/require" - - "maunium.net/go/mautrix/federation" -) - -type serverVersionResp struct { - Server struct { - Name string `json:"name"` - Version string `json:"version"` - } `json:"server"` -} - -func TestNewFederationClient(t *testing.T) { - cli := federation.NewFederationHTTPClient() - resp, err := cli.Get("matrix-federation://maunium.net/_matrix/federation/v1/version") - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - var respData serverVersionResp - err = json.NewDecoder(resp.Body).Decode(&respData) - require.NoError(t, err) - require.Equal(t, "Synapse", respData.Server.Name) -} diff --git a/federation/signingkey.go b/federation/signingkey.go index 3d118233..67751b48 100644 --- a/federation/signingkey.go +++ b/federation/signingkey.go @@ -83,6 +83,10 @@ type ServerVerifyKey struct { Key id.SigningKey `json:"key"` } +func (svk *ServerVerifyKey) Decode() (ed25519.PublicKey, error) { + return base64.RawStdEncoding.DecodeString(string(svk.Key)) +} + type OldVerifyKey struct { Key id.SigningKey `json:"key"` ExpiredTS jsontime.UnixMilli `json:"expired_ts"`