bridgev2/provisioning: use exhttp utilities for writing responses

This commit is contained in:
Tulir Asokan 2025-05-28 21:24:15 +03:00
commit 64f55ac3a7

View file

@ -21,6 +21,7 @@ import (
"github.com/rs/xid"
"github.com/rs/zerolog"
"github.com/rs/zerolog/hlog"
"go.mau.fi/util/exhttp"
"go.mau.fi/util/exstrings"
"go.mau.fi/util/jsontime"
"go.mau.fi/util/requestlog"
@ -118,7 +119,7 @@ func (prov *ProvisioningAPI) Init() {
prov.Router = prov.br.AS.Router.PathPrefix(prov.br.Config.Provisioning.Prefix).Subrouter()
prov.Router.Use(hlog.NewHandler(prov.log))
prov.Router.Use(hlog.RequestIDHandler("request_id", "Request-Id"))
prov.Router.Use(corsMiddleware)
prov.Router.Use(exhttp.CORSMiddleware)
prov.Router.Use(requestlog.AccessLogger(false))
prov.Router.Use(prov.AuthMiddleware)
prov.Router.Path("/v3/whoami").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetWhoami)
@ -152,25 +153,6 @@ func (prov *ProvisioningAPI) Init() {
}
}
func corsMiddleware(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization")
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusOK)
return
}
handler.ServeHTTP(w, r)
})
}
func jsonResponse(w http.ResponseWriter, status int, response any) {
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(response)
}
func (prov *ProvisioningAPI) checkMatrixAuth(ctx context.Context, userID id.UserID, token string) error {
prov.matrixAuthCacheLock.Lock()
defer prov.matrixAuthCacheLock.Unlock()
@ -216,15 +198,9 @@ func (prov *ProvisioningAPI) DebugAuthMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
if auth == "" {
jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{
Err: "Missing auth token",
ErrCode: mautrix.MMissingToken.ErrCode,
})
mautrix.MMissingToken.WithMessage("Missing auth token").Write(w)
} else if !exstrings.ConstantTimeEqual(auth, prov.br.Config.Provisioning.SharedSecret) {
jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{
Err: "Invalid auth token",
ErrCode: mautrix.MUnknownToken.ErrCode,
})
mautrix.MUnknownToken.WithMessage("Invalid auth token").Write(w)
} else {
h.ServeHTTP(w, r)
}
@ -238,10 +214,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
auth = prov.GetAuthFromRequest(r)
}
if auth == "" {
jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{
Err: "Missing auth token",
ErrCode: mautrix.MMissingToken.ErrCode,
})
mautrix.MMissingToken.WithMessage("Missing auth token").Write(w)
return
}
userID := id.UserID(r.URL.Query().Get("user_id"))
@ -258,29 +231,20 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
if err != nil {
zerolog.Ctx(r.Context()).Warn().Err(err).
Msg("Provisioning API request contained invalid auth")
jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{
Err: "Invalid auth token",
ErrCode: mautrix.MUnknownToken.ErrCode,
})
mautrix.MUnknownToken.WithMessage("Invalid auth token").Write(w)
return
}
}
user, err := prov.br.Bridge.GetUserByMXID(r.Context(), userID)
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get user")
jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{
Err: "Failed to get user",
ErrCode: "M_UNKNOWN",
})
mautrix.MUnknown.WithMessage("Failed to get user").Write(w)
return
}
// TODO handle user being nil?
// TODO per-endpoint permissions?
if !user.Permissions.Login {
jsonResponse(w, http.StatusForbidden, &mautrix.RespError{
Err: "User does not have login permissions",
ErrCode: mautrix.MForbidden.ErrCode,
})
mautrix.MForbidden.WithMessage("User does not have login permissions").Write(w)
return
}
@ -292,10 +256,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
prov.loginsLock.RUnlock()
if !ok {
zerolog.Ctx(r.Context()).Warn().Str("login_id", loginID).Msg("Login not found")
jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
Err: "Login not found",
ErrCode: mautrix.MNotFound.ErrCode,
})
mautrix.MNotFound.WithMessage("Login not found").Write(w)
return
}
login.Lock.Lock()
@ -307,10 +268,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
Str("request_step_id", stepID).
Str("expected_step_id", login.NextStep.StepID).
Msg("Step ID does not match")
jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{
Err: "Step ID does not match",
ErrCode: mautrix.MBadState.ErrCode,
})
mautrix.MBadState.WithMessage("Step ID does not match").Write(w)
return
}
stepType := mux.Vars(r)["stepType"]
@ -319,10 +277,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
Str("request_step_type", stepType).
Str("expected_step_type", string(login.NextStep.Type)).
Msg("Step type does not match")
jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{
Err: "Step type does not match",
ErrCode: mautrix.MBadState.ErrCode,
})
mautrix.MBadState.WithMessage("Step type does not match").Write(w)
return
}
ctx = context.WithValue(ctx, provisioningLoginProcessKey, login)
@ -391,7 +346,7 @@ func (prov *ProvisioningAPI) GetWhoami(w http.ResponseWriter, r *http.Request) {
SpaceRoom: login.SpaceRoom,
}
}
jsonResponse(w, http.StatusOK, resp)
exhttp.WriteJSONResponse(w, http.StatusOK, resp)
}
type RespLoginFlows struct {
@ -404,7 +359,7 @@ type RespSubmitLogin struct {
}
func (prov *ProvisioningAPI) GetLoginFlows(w http.ResponseWriter, r *http.Request) {
jsonResponse(w, http.StatusOK, &RespLoginFlows{
exhttp.WriteJSONResponse(w, http.StatusOK, &RespLoginFlows{
Flows: prov.net.GetLoginFlows(),
})
}
@ -445,7 +400,7 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque
Override: overrideLogin,
}
prov.loginsLock.Unlock()
jsonResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: loginID, LoginStep: firstStep})
exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: loginID, LoginStep: firstStep})
}
func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *ProvLogin, step *bridgev2.LoginStep) {
@ -467,10 +422,7 @@ func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http
err := json.NewDecoder(r.Body).Decode(&params)
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body")
jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{
Err: "Failed to decode request body",
ErrCode: mautrix.MNotJSON.ErrCode,
})
mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w)
return
}
login := r.Context().Value(provisioningLoginProcessKey).(*ProvLogin)
@ -492,7 +444,7 @@ func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http
if nextStep.Type == bridgev2.LoginStepTypeComplete {
prov.handleCompleteStep(r.Context(), login, nextStep)
}
jsonResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep})
exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep})
}
func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Request) {
@ -507,7 +459,7 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques
if nextStep.Type == bridgev2.LoginStepTypeComplete {
prov.handleCompleteStep(r.Context(), login, nextStep)
}
jsonResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep})
exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep})
}
func (prov *ProvisioningAPI) PostLogout(w http.ResponseWriter, r *http.Request) {
@ -524,15 +476,12 @@ func (prov *ProvisioningAPI) PostLogout(w http.ResponseWriter, r *http.Request)
} else {
userLogin := prov.br.Bridge.GetCachedUserLoginByID(userLoginID)
if userLogin == nil || userLogin.UserMXID != user.MXID {
jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
Err: "Login not found",
ErrCode: mautrix.MNotFound.ErrCode,
})
mautrix.MNotFound.WithMessage("Login not found").Write(w)
return
}
userLogin.Logout(r.Context())
}
jsonResponse(w, http.StatusOK, json.RawMessage("{}"))
exhttp.WriteEmptyJSONResponse(w, http.StatusOK)
}
type RespGetLogins struct {
@ -541,7 +490,7 @@ type RespGetLogins struct {
func (prov *ProvisioningAPI) GetLogins(w http.ResponseWriter, r *http.Request) {
user := prov.GetUser(r)
jsonResponse(w, http.StatusOK, &RespGetLogins{LoginIDs: user.GetUserLoginIDs()})
exhttp.WriteJSONResponse(w, http.StatusOK, &RespGetLogins{LoginIDs: user.GetUserLoginIDs()})
}
func (prov *ProvisioningAPI) GetExplicitLoginForRequest(w http.ResponseWriter, r *http.Request) (*bridgev2.UserLogin, bool) {
@ -551,15 +500,18 @@ func (prov *ProvisioningAPI) GetExplicitLoginForRequest(w http.ResponseWriter, r
}
userLogin := prov.br.Bridge.GetCachedUserLoginByID(userLoginID)
if userLogin == nil || userLogin.UserMXID != prov.GetUser(r).MXID {
jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
Err: "Login not found",
ErrCode: mautrix.MNotFound.ErrCode,
})
mautrix.MNotFound.WithMessage("Login not found").Write(w)
return nil, true
}
return userLogin, false
}
var ErrNotLoggedIn = mautrix.RespError{
Err: "Not logged in",
ErrCode: "FI.MAU.NOT_LOGGED_IN",
StatusCode: http.StatusBadRequest,
}
func (prov *ProvisioningAPI) GetLoginForRequest(w http.ResponseWriter, r *http.Request) *bridgev2.UserLogin {
userLogin, failed := prov.GetExplicitLoginForRequest(w, r)
if userLogin != nil || failed {
@ -567,10 +519,7 @@ func (prov *ProvisioningAPI) GetLoginForRequest(w http.ResponseWriter, r *http.R
}
userLogin = prov.GetUser(r).GetDefaultLogin()
if userLogin == nil {
jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{
Err: "Not logged in",
ErrCode: "FI.MAU.NOT_LOGGED_IN",
})
ErrNotLoggedIn.Write(w)
return nil
}
return userLogin
@ -585,11 +534,7 @@ func RespondWithError(w http.ResponseWriter, err error, message string) {
if errors.As(err, &we) {
we.Write(w)
} else {
mautrix.RespError{
Err: message,
ErrCode: "M_UNKNOWN",
StatusCode: http.StatusInternalServerError,
}.Write(w)
mautrix.MUnknown.WithMessage(message).Write(w)
}
}
@ -609,10 +554,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http.
}
api, ok := login.Client.(bridgev2.IdentifierResolvingNetworkAPI)
if !ok {
jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{
Err: "This bridge does not support resolving identifiers",
ErrCode: mautrix.MUnrecognized.ErrCode,
})
mautrix.MUnrecognized.WithMessage("This bridge does not support resolving identifiers").Write(w)
return
}
resp, err := api.ResolveIdentifier(r.Context(), mux.Vars(r)["identifier"], createChat)
@ -621,10 +563,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http.
RespondWithError(w, err, "Internal error resolving identifier")
return
} else if resp == nil {
jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
ErrCode: mautrix.MNotFound.ErrCode,
Err: "Identifier not found",
})
mautrix.MNotFound.WithMessage("Identifier not found").Write(w)
return
}
apiResp := &RespResolveIdentifier{
@ -647,10 +586,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http.
resp.Chat.Portal, err = prov.br.Bridge.GetPortalByKey(r.Context(), resp.Chat.PortalKey)
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get portal")
jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{
Err: "Failed to get portal",
ErrCode: "M_UNKNOWN",
})
mautrix.MUnknown.WithMessage("Failed to get portal").Write(w)
return
}
}
@ -659,16 +595,13 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http.
err = resp.Chat.Portal.CreateMatrixRoom(r.Context(), login, resp.Chat.PortalInfo)
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create portal room")
jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{
Err: "Failed to create portal room",
ErrCode: "M_UNKNOWN",
})
mautrix.MUnknown.WithMessage("Failed to create portal room").Write(w)
return
}
}
apiResp.DMRoomID = resp.Chat.Portal.MXID
}
jsonResponse(w, status, apiResp)
exhttp.WriteJSONResponse(w, status, apiResp)
}
type RespGetContactList struct {
@ -723,10 +656,7 @@ func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Reque
}
api, ok := login.Client.(bridgev2.ContactListingNetworkAPI)
if !ok {
jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{
Err: "This bridge does not support listing contacts",
ErrCode: mautrix.MUnrecognized.ErrCode,
})
mautrix.MUnrecognized.WithMessage("This bridge does not support listing contacts").Write(w)
return
}
resp, err := api.GetContactList(r.Context())
@ -735,7 +665,7 @@ func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Reque
RespondWithError(w, err, "Internal error fetching contact list")
return
}
jsonResponse(w, http.StatusOK, &RespGetContactList{
exhttp.WriteJSONResponse(w, http.StatusOK, &RespGetContactList{
Contacts: prov.processResolveIdentifiers(r.Context(), resp),
})
}
@ -753,10 +683,7 @@ func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Requ
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body")
jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{
Err: "Failed to decode request body",
ErrCode: mautrix.MNotJSON.ErrCode,
})
mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w)
return
}
login := prov.GetLoginForRequest(w, r)
@ -765,10 +692,7 @@ func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Requ
}
api, ok := login.Client.(bridgev2.UserSearchingNetworkAPI)
if !ok {
jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{
Err: "This bridge does not support searching for users",
ErrCode: mautrix.MUnrecognized.ErrCode,
})
mautrix.MUnrecognized.WithMessage("This bridge does not support searching for users").Write(w)
return
}
resp, err := api.SearchUsers(r.Context(), req.Query)
@ -777,7 +701,7 @@ func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Requ
RespondWithError(w, err, "Internal error fetching contact list")
return
}
jsonResponse(w, http.StatusOK, &RespSearchUsers{
exhttp.WriteJSONResponse(w, http.StatusOK, &RespSearchUsers{
Results: prov.processResolveIdentifiers(r.Context(), resp),
})
}
@ -795,10 +719,7 @@ func (prov *ProvisioningAPI) PostCreateGroup(w http.ResponseWriter, r *http.Requ
if login == nil {
return
}
jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{
Err: "Creating groups is not yet implemented",
ErrCode: mautrix.MUnrecognized.ErrCode,
})
mautrix.MUnrecognized.WithMessage("Creating groups is not yet implemented").Write(w)
}
type ReqExportCredentials struct {
@ -817,10 +738,7 @@ func (prov *ProvisioningAPI) PostInitSessionTransfer(w http.ResponseWriter, r *h
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body")
jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{
Err: "Failed to decode request body",
ErrCode: mautrix.MNotJSON.ErrCode,
})
mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w)
return
}
@ -834,19 +752,13 @@ func (prov *ProvisioningAPI) PostInitSessionTransfer(w http.ResponseWriter, r *h
}
}
if loginToExport == nil {
jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
Err: "No matching user login found",
ErrCode: mautrix.MNotFound.ErrCode,
})
mautrix.MNotFound.WithMessage("No matching user login found").Write(w)
return
}
client, ok := loginToExport.Client.(bridgev2.CredentialExportingNetworkAPI)
if !ok {
jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{
Err: "Client does not support credential exporting",
ErrCode: mautrix.MInvalidParam.ErrCode,
})
mautrix.MUnrecognized.WithMessage("This bridge does not support exporting credentials").Write(w)
return
}
@ -858,10 +770,9 @@ func (prov *ProvisioningAPI) PostInitSessionTransfer(w http.ResponseWriter, r *h
// Disconnect now so we don't use the same network session in two places at once
client.Disconnect()
resp := RespExportCredentials{
exhttp.WriteJSONResponse(w, http.StatusOK, &RespExportCredentials{
Credentials: client.ExportCredentials(r.Context()),
}
jsonResponse(w, http.StatusOK, resp)
})
}
func (prov *ProvisioningAPI) PostFinishSessionTransfer(w http.ResponseWriter, r *http.Request) {
@ -872,10 +783,7 @@ func (prov *ProvisioningAPI) PostFinishSessionTransfer(w http.ResponseWriter, r
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body")
jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{
Err: "Failed to decode request body",
ErrCode: mautrix.MNotJSON.ErrCode,
})
mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w)
return
}
@ -889,16 +797,10 @@ func (prov *ProvisioningAPI) PostFinishSessionTransfer(w http.ResponseWriter, r
}
}
if loginToExport == nil {
jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
Err: "No matching user login found",
ErrCode: mautrix.MNotFound.ErrCode,
})
mautrix.MNotFound.WithMessage("No matching user login found").Write(w)
return
} else if _, ok := prov.sessionTransfers[loginToExport.ID]; !ok {
jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{
Err: "No matching credential export found",
ErrCode: mautrix.MNotJSON.ErrCode,
})
mautrix.MBadState.WithMessage("No matching credential export found").Write(w)
return
}
@ -909,5 +811,5 @@ func (prov *ProvisioningAPI) PostFinishSessionTransfer(w http.ResponseWriter, r
loginToExport.Client.LogoutRemote(r.Context())
delete(prov.sessionTransfers, req.RemoteID)
jsonResponse(w, http.StatusOK, struct{}{})
exhttp.WriteEmptyJSONResponse(w, http.StatusOK)
}