From 64f55ac3a7eb9fba7fc0ad74e2900135253fea52 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 28 May 2025 21:24:15 +0300 Subject: [PATCH] bridgev2/provisioning: use exhttp utilities for writing responses --- bridgev2/matrix/provisioning.go | 196 ++++++++------------------------ 1 file changed, 49 insertions(+), 147 deletions(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 83f56fa0..2a84bdf2 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -21,6 +21,7 @@ import ( "github.com/rs/xid" "github.com/rs/zerolog" "github.com/rs/zerolog/hlog" + "go.mau.fi/util/exhttp" "go.mau.fi/util/exstrings" "go.mau.fi/util/jsontime" "go.mau.fi/util/requestlog" @@ -118,7 +119,7 @@ func (prov *ProvisioningAPI) Init() { prov.Router = prov.br.AS.Router.PathPrefix(prov.br.Config.Provisioning.Prefix).Subrouter() prov.Router.Use(hlog.NewHandler(prov.log)) prov.Router.Use(hlog.RequestIDHandler("request_id", "Request-Id")) - prov.Router.Use(corsMiddleware) + prov.Router.Use(exhttp.CORSMiddleware) prov.Router.Use(requestlog.AccessLogger(false)) prov.Router.Use(prov.AuthMiddleware) prov.Router.Path("/v3/whoami").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetWhoami) @@ -152,25 +153,6 @@ func (prov *ProvisioningAPI) Init() { } } -func corsMiddleware(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization") - if r.Method == http.MethodOptions { - w.WriteHeader(http.StatusOK) - return - } - handler.ServeHTTP(w, r) - }) -} - -func jsonResponse(w http.ResponseWriter, status int, response any) { - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(status) - _ = json.NewEncoder(w).Encode(response) -} - func (prov *ProvisioningAPI) checkMatrixAuth(ctx context.Context, userID id.UserID, token string) error { prov.matrixAuthCacheLock.Lock() defer prov.matrixAuthCacheLock.Unlock() @@ -216,15 +198,9 @@ func (prov *ProvisioningAPI) DebugAuthMiddleware(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") if auth == "" { - jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{ - Err: "Missing auth token", - ErrCode: mautrix.MMissingToken.ErrCode, - }) + mautrix.MMissingToken.WithMessage("Missing auth token").Write(w) } else if !exstrings.ConstantTimeEqual(auth, prov.br.Config.Provisioning.SharedSecret) { - jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{ - Err: "Invalid auth token", - ErrCode: mautrix.MUnknownToken.ErrCode, - }) + mautrix.MUnknownToken.WithMessage("Invalid auth token").Write(w) } else { h.ServeHTTP(w, r) } @@ -238,10 +214,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { auth = prov.GetAuthFromRequest(r) } if auth == "" { - jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{ - Err: "Missing auth token", - ErrCode: mautrix.MMissingToken.ErrCode, - }) + mautrix.MMissingToken.WithMessage("Missing auth token").Write(w) return } userID := id.UserID(r.URL.Query().Get("user_id")) @@ -258,29 +231,20 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { if err != nil { zerolog.Ctx(r.Context()).Warn().Err(err). Msg("Provisioning API request contained invalid auth") - jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{ - Err: "Invalid auth token", - ErrCode: mautrix.MUnknownToken.ErrCode, - }) + mautrix.MUnknownToken.WithMessage("Invalid auth token").Write(w) return } } user, err := prov.br.Bridge.GetUserByMXID(r.Context(), userID) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get user") - jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ - Err: "Failed to get user", - ErrCode: "M_UNKNOWN", - }) + mautrix.MUnknown.WithMessage("Failed to get user").Write(w) return } // TODO handle user being nil? // TODO per-endpoint permissions? if !user.Permissions.Login { - jsonResponse(w, http.StatusForbidden, &mautrix.RespError{ - Err: "User does not have login permissions", - ErrCode: mautrix.MForbidden.ErrCode, - }) + mautrix.MForbidden.WithMessage("User does not have login permissions").Write(w) return } @@ -292,10 +256,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { prov.loginsLock.RUnlock() if !ok { zerolog.Ctx(r.Context()).Warn().Str("login_id", loginID).Msg("Login not found") - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - Err: "Login not found", - ErrCode: mautrix.MNotFound.ErrCode, - }) + mautrix.MNotFound.WithMessage("Login not found").Write(w) return } login.Lock.Lock() @@ -307,10 +268,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { Str("request_step_id", stepID). Str("expected_step_id", login.NextStep.StepID). Msg("Step ID does not match") - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "Step ID does not match", - ErrCode: mautrix.MBadState.ErrCode, - }) + mautrix.MBadState.WithMessage("Step ID does not match").Write(w) return } stepType := mux.Vars(r)["stepType"] @@ -319,10 +277,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { Str("request_step_type", stepType). Str("expected_step_type", string(login.NextStep.Type)). Msg("Step type does not match") - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "Step type does not match", - ErrCode: mautrix.MBadState.ErrCode, - }) + mautrix.MBadState.WithMessage("Step type does not match").Write(w) return } ctx = context.WithValue(ctx, provisioningLoginProcessKey, login) @@ -391,7 +346,7 @@ func (prov *ProvisioningAPI) GetWhoami(w http.ResponseWriter, r *http.Request) { SpaceRoom: login.SpaceRoom, } } - jsonResponse(w, http.StatusOK, resp) + exhttp.WriteJSONResponse(w, http.StatusOK, resp) } type RespLoginFlows struct { @@ -404,7 +359,7 @@ type RespSubmitLogin struct { } func (prov *ProvisioningAPI) GetLoginFlows(w http.ResponseWriter, r *http.Request) { - jsonResponse(w, http.StatusOK, &RespLoginFlows{ + exhttp.WriteJSONResponse(w, http.StatusOK, &RespLoginFlows{ Flows: prov.net.GetLoginFlows(), }) } @@ -445,7 +400,7 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque Override: overrideLogin, } prov.loginsLock.Unlock() - jsonResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: loginID, LoginStep: firstStep}) + exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: loginID, LoginStep: firstStep}) } func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *ProvLogin, step *bridgev2.LoginStep) { @@ -467,10 +422,7 @@ func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http err := json.NewDecoder(r.Body).Decode(¶ms) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "Failed to decode request body", - ErrCode: mautrix.MNotJSON.ErrCode, - }) + mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w) return } login := r.Context().Value(provisioningLoginProcessKey).(*ProvLogin) @@ -492,7 +444,7 @@ func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http if nextStep.Type == bridgev2.LoginStepTypeComplete { prov.handleCompleteStep(r.Context(), login, nextStep) } - jsonResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) + exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) } func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Request) { @@ -507,7 +459,7 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques if nextStep.Type == bridgev2.LoginStepTypeComplete { prov.handleCompleteStep(r.Context(), login, nextStep) } - jsonResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) + exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) } func (prov *ProvisioningAPI) PostLogout(w http.ResponseWriter, r *http.Request) { @@ -524,15 +476,12 @@ func (prov *ProvisioningAPI) PostLogout(w http.ResponseWriter, r *http.Request) } else { userLogin := prov.br.Bridge.GetCachedUserLoginByID(userLoginID) if userLogin == nil || userLogin.UserMXID != user.MXID { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - Err: "Login not found", - ErrCode: mautrix.MNotFound.ErrCode, - }) + mautrix.MNotFound.WithMessage("Login not found").Write(w) return } userLogin.Logout(r.Context()) } - jsonResponse(w, http.StatusOK, json.RawMessage("{}")) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) } type RespGetLogins struct { @@ -541,7 +490,7 @@ type RespGetLogins struct { func (prov *ProvisioningAPI) GetLogins(w http.ResponseWriter, r *http.Request) { user := prov.GetUser(r) - jsonResponse(w, http.StatusOK, &RespGetLogins{LoginIDs: user.GetUserLoginIDs()}) + exhttp.WriteJSONResponse(w, http.StatusOK, &RespGetLogins{LoginIDs: user.GetUserLoginIDs()}) } func (prov *ProvisioningAPI) GetExplicitLoginForRequest(w http.ResponseWriter, r *http.Request) (*bridgev2.UserLogin, bool) { @@ -551,15 +500,18 @@ func (prov *ProvisioningAPI) GetExplicitLoginForRequest(w http.ResponseWriter, r } userLogin := prov.br.Bridge.GetCachedUserLoginByID(userLoginID) if userLogin == nil || userLogin.UserMXID != prov.GetUser(r).MXID { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - Err: "Login not found", - ErrCode: mautrix.MNotFound.ErrCode, - }) + mautrix.MNotFound.WithMessage("Login not found").Write(w) return nil, true } return userLogin, false } +var ErrNotLoggedIn = mautrix.RespError{ + Err: "Not logged in", + ErrCode: "FI.MAU.NOT_LOGGED_IN", + StatusCode: http.StatusBadRequest, +} + func (prov *ProvisioningAPI) GetLoginForRequest(w http.ResponseWriter, r *http.Request) *bridgev2.UserLogin { userLogin, failed := prov.GetExplicitLoginForRequest(w, r) if userLogin != nil || failed { @@ -567,10 +519,7 @@ func (prov *ProvisioningAPI) GetLoginForRequest(w http.ResponseWriter, r *http.R } userLogin = prov.GetUser(r).GetDefaultLogin() if userLogin == nil { - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "Not logged in", - ErrCode: "FI.MAU.NOT_LOGGED_IN", - }) + ErrNotLoggedIn.Write(w) return nil } return userLogin @@ -585,11 +534,7 @@ func RespondWithError(w http.ResponseWriter, err error, message string) { if errors.As(err, &we) { we.Write(w) } else { - mautrix.RespError{ - Err: message, - ErrCode: "M_UNKNOWN", - StatusCode: http.StatusInternalServerError, - }.Write(w) + mautrix.MUnknown.WithMessage(message).Write(w) } } @@ -609,10 +554,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. } api, ok := login.Client.(bridgev2.IdentifierResolvingNetworkAPI) if !ok { - jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ - Err: "This bridge does not support resolving identifiers", - ErrCode: mautrix.MUnrecognized.ErrCode, - }) + mautrix.MUnrecognized.WithMessage("This bridge does not support resolving identifiers").Write(w) return } resp, err := api.ResolveIdentifier(r.Context(), mux.Vars(r)["identifier"], createChat) @@ -621,10 +563,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. RespondWithError(w, err, "Internal error resolving identifier") return } else if resp == nil { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - ErrCode: mautrix.MNotFound.ErrCode, - Err: "Identifier not found", - }) + mautrix.MNotFound.WithMessage("Identifier not found").Write(w) return } apiResp := &RespResolveIdentifier{ @@ -647,10 +586,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. resp.Chat.Portal, err = prov.br.Bridge.GetPortalByKey(r.Context(), resp.Chat.PortalKey) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get portal") - jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ - Err: "Failed to get portal", - ErrCode: "M_UNKNOWN", - }) + mautrix.MUnknown.WithMessage("Failed to get portal").Write(w) return } } @@ -659,16 +595,13 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. err = resp.Chat.Portal.CreateMatrixRoom(r.Context(), login, resp.Chat.PortalInfo) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create portal room") - jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ - Err: "Failed to create portal room", - ErrCode: "M_UNKNOWN", - }) + mautrix.MUnknown.WithMessage("Failed to create portal room").Write(w) return } } apiResp.DMRoomID = resp.Chat.Portal.MXID } - jsonResponse(w, status, apiResp) + exhttp.WriteJSONResponse(w, status, apiResp) } type RespGetContactList struct { @@ -723,10 +656,7 @@ func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Reque } api, ok := login.Client.(bridgev2.ContactListingNetworkAPI) if !ok { - jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ - Err: "This bridge does not support listing contacts", - ErrCode: mautrix.MUnrecognized.ErrCode, - }) + mautrix.MUnrecognized.WithMessage("This bridge does not support listing contacts").Write(w) return } resp, err := api.GetContactList(r.Context()) @@ -735,7 +665,7 @@ func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Reque RespondWithError(w, err, "Internal error fetching contact list") return } - jsonResponse(w, http.StatusOK, &RespGetContactList{ + exhttp.WriteJSONResponse(w, http.StatusOK, &RespGetContactList{ Contacts: prov.processResolveIdentifiers(r.Context(), resp), }) } @@ -753,10 +683,7 @@ func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Requ err := json.NewDecoder(r.Body).Decode(&req) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "Failed to decode request body", - ErrCode: mautrix.MNotJSON.ErrCode, - }) + mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w) return } login := prov.GetLoginForRequest(w, r) @@ -765,10 +692,7 @@ func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Requ } api, ok := login.Client.(bridgev2.UserSearchingNetworkAPI) if !ok { - jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ - Err: "This bridge does not support searching for users", - ErrCode: mautrix.MUnrecognized.ErrCode, - }) + mautrix.MUnrecognized.WithMessage("This bridge does not support searching for users").Write(w) return } resp, err := api.SearchUsers(r.Context(), req.Query) @@ -777,7 +701,7 @@ func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Requ RespondWithError(w, err, "Internal error fetching contact list") return } - jsonResponse(w, http.StatusOK, &RespSearchUsers{ + exhttp.WriteJSONResponse(w, http.StatusOK, &RespSearchUsers{ Results: prov.processResolveIdentifiers(r.Context(), resp), }) } @@ -795,10 +719,7 @@ func (prov *ProvisioningAPI) PostCreateGroup(w http.ResponseWriter, r *http.Requ if login == nil { return } - jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ - Err: "Creating groups is not yet implemented", - ErrCode: mautrix.MUnrecognized.ErrCode, - }) + mautrix.MUnrecognized.WithMessage("Creating groups is not yet implemented").Write(w) } type ReqExportCredentials struct { @@ -817,10 +738,7 @@ func (prov *ProvisioningAPI) PostInitSessionTransfer(w http.ResponseWriter, r *h err := json.NewDecoder(r.Body).Decode(&req) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "Failed to decode request body", - ErrCode: mautrix.MNotJSON.ErrCode, - }) + mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w) return } @@ -834,19 +752,13 @@ func (prov *ProvisioningAPI) PostInitSessionTransfer(w http.ResponseWriter, r *h } } if loginToExport == nil { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - Err: "No matching user login found", - ErrCode: mautrix.MNotFound.ErrCode, - }) + mautrix.MNotFound.WithMessage("No matching user login found").Write(w) return } client, ok := loginToExport.Client.(bridgev2.CredentialExportingNetworkAPI) if !ok { - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "Client does not support credential exporting", - ErrCode: mautrix.MInvalidParam.ErrCode, - }) + mautrix.MUnrecognized.WithMessage("This bridge does not support exporting credentials").Write(w) return } @@ -858,10 +770,9 @@ func (prov *ProvisioningAPI) PostInitSessionTransfer(w http.ResponseWriter, r *h // Disconnect now so we don't use the same network session in two places at once client.Disconnect() - resp := RespExportCredentials{ + exhttp.WriteJSONResponse(w, http.StatusOK, &RespExportCredentials{ Credentials: client.ExportCredentials(r.Context()), - } - jsonResponse(w, http.StatusOK, resp) + }) } func (prov *ProvisioningAPI) PostFinishSessionTransfer(w http.ResponseWriter, r *http.Request) { @@ -872,10 +783,7 @@ func (prov *ProvisioningAPI) PostFinishSessionTransfer(w http.ResponseWriter, r err := json.NewDecoder(r.Body).Decode(&req) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "Failed to decode request body", - ErrCode: mautrix.MNotJSON.ErrCode, - }) + mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w) return } @@ -889,16 +797,10 @@ func (prov *ProvisioningAPI) PostFinishSessionTransfer(w http.ResponseWriter, r } } if loginToExport == nil { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - Err: "No matching user login found", - ErrCode: mautrix.MNotFound.ErrCode, - }) + mautrix.MNotFound.WithMessage("No matching user login found").Write(w) return } else if _, ok := prov.sessionTransfers[loginToExport.ID]; !ok { - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "No matching credential export found", - ErrCode: mautrix.MNotJSON.ErrCode, - }) + mautrix.MBadState.WithMessage("No matching credential export found").Write(w) return } @@ -909,5 +811,5 @@ func (prov *ProvisioningAPI) PostFinishSessionTransfer(w http.ResponseWriter, r loginToExport.Client.LogoutRemote(r.Context()) delete(prov.sessionTransfers, req.RemoteID) - jsonResponse(w, http.StatusOK, struct{}{}) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) }