diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 6b594deb..e3ec21dd 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -258,38 +258,6 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { ctx := context.WithValue(r.Context(), ProvisioningKeyRequest, r) ctx = context.WithValue(ctx, provisioningUserKey, user) - if loginID := r.PathValue("loginProcessID"); loginID != "" { - prov.loginsLock.RLock() - login, ok := prov.logins[loginID] - prov.loginsLock.RUnlock() - if !ok { - zerolog.Ctx(r.Context()).Warn().Str("login_id", loginID).Msg("Login not found") - mautrix.MNotFound.WithMessage("Login not found").Write(w) - return - } - login.Lock.Lock() - // This will only unlock after the handler runs - defer login.Lock.Unlock() - stepID := r.PathValue("stepID") - if login.NextStep.StepID != stepID { - zerolog.Ctx(r.Context()).Warn(). - Str("request_step_id", stepID). - Str("expected_step_id", login.NextStep.StepID). - Msg("Step ID does not match") - mautrix.MBadState.WithMessage("Step ID does not match").Write(w) - return - } - stepType := r.PathValue("stepType") - if login.NextStep.Type != bridgev2.LoginStepType(stepType) { - zerolog.Ctx(r.Context()).Warn(). - Str("request_step_type", stepType). - Str("expected_step_type", string(login.NextStep.Type)). - Msg("Step type does not match") - mautrix.MBadState.WithMessage("Step type does not match").Write(w) - return - } - ctx = context.WithValue(ctx, provisioningLoginProcessKey, login) - } h.ServeHTTP(w, r.WithContext(ctx)) }) } @@ -431,6 +399,38 @@ func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *Prov } func (prov *ProvisioningAPI) PostLoginStep(w http.ResponseWriter, r *http.Request) { + loginID := r.PathValue("loginProcessID") + prov.loginsLock.RLock() + login, ok := prov.logins[loginID] + prov.loginsLock.RUnlock() + if !ok { + zerolog.Ctx(r.Context()).Warn().Str("login_id", loginID).Msg("Login not found") + mautrix.MNotFound.WithMessage("Login not found").Write(w) + return + } + login.Lock.Lock() + // This will only unlock after the handler runs + defer login.Lock.Unlock() + stepID := r.PathValue("stepID") + if login.NextStep.StepID != stepID { + zerolog.Ctx(r.Context()).Warn(). + Str("request_step_id", stepID). + Str("expected_step_id", login.NextStep.StepID). + Msg("Step ID does not match") + mautrix.MBadState.WithMessage("Step ID does not match").Write(w) + return + } + stepType := r.PathValue("stepType") + if login.NextStep.Type != bridgev2.LoginStepType(stepType) { + zerolog.Ctx(r.Context()).Warn(). + Str("request_step_type", stepType). + Str("expected_step_type", string(login.NextStep.Type)). + Msg("Step type does not match") + mautrix.MBadState.WithMessage("Step type does not match").Write(w) + return + } + ctx := context.WithValue(r.Context(), provisioningLoginProcessKey, login) + r = r.WithContext(ctx) switch bridgev2.LoginStepType(r.PathValue("stepType")) { case bridgev2.LoginStepTypeUserInput, bridgev2.LoginStepTypeCookies: prov.PostLoginSubmitInput(w, r) @@ -439,7 +439,7 @@ func (prov *ProvisioningAPI) PostLoginStep(w http.ResponseWriter, r *http.Reques case bridgev2.LoginStepTypeComplete: fallthrough default: - // This is probably impossible because AuthMiddleware checks that the next step type matches the request. + // This is probably impossible because of the above check that the next step type matches the request. mautrix.MUnrecognized.WithMessage("Invalid step type %q", r.PathValue("stepType")).Write(w) } }