diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 83418290..2ad6a614 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -289,10 +289,17 @@ func (br *Bridge) MigrateToSplitPortals(ctx context.Context) (bool, func()) { return false, nil } log.Info().Int64("rows_affected", affected).Msg("Migrated to split portals") + affected2, err := br.DB.Portal.FixParentsAfterSplitPortalMigration(ctx) + if err != nil { + log.Err(err).Msg("Failed to fix parent portals after split portal migration") + os.Exit(31) + return false, nil + } + log.Info().Int64("rows_affected", affected2).Msg("Updated parent receivers after split portal migration") withoutReceiver, err := br.DB.Portal.GetAllWithoutReceiver(ctx) if err != nil { log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to get portals that failed to migrate") - os.Exit(32) + os.Exit(31) return false, nil } var roomsToDelete []id.RoomID diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index 8570d840..e02b9e44 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -131,6 +131,11 @@ const ( FROM user_portal WHERE bridge_id=portal.bridge_id AND portal_id=portal.id AND portal_receiver='' LIMIT 1 + ), ( + SELECT login_id + FROM user_portal + WHERE portal.parent_id<>'' AND bridge_id=portal.bridge_id AND portal_id=portal.parent_id + LIMIT 1 ), ( SELECT id FROM user_login WHERE bridge_id=portal.bridge_id LIMIT 1 ), '') AS new_receiver @@ -141,6 +146,9 @@ const ( SELECT 1 FROM portal p2 WHERE p2.bridge_id=updates.bridge_id AND p2.id=updates.id AND p2.receiver=updates.new_receiver ) ` + fixParentsAfterSplitPortalMigrationQuery = ` + UPDATE portal SET parent_receiver=receiver WHERE parent_receiver='' AND receiver<>'' AND parent_id<>''; + ` ) func (pq *PortalQuery) GetByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) { @@ -209,6 +217,14 @@ func (pq *PortalQuery) MigrateToSplitPortals(ctx context.Context) (int64, error) return res.RowsAffected() } +func (pq *PortalQuery) FixParentsAfterSplitPortalMigration(ctx context.Context) (int64, error) { + res, err := pq.GetDB().Exec(ctx, fixParentsAfterSplitPortalMigrationQuery, pq.BridgeID) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { var mxid, parentID, parentReceiver, relayLoginID, otherUserID, disappearType sql.NullString var disappearTimer sql.NullInt64