diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go index 32556de1..d908483e 100644 --- a/bridgev2/matrix/mxmain/legacymigrate.go +++ b/bridgev2/matrix/mxmain/legacymigrate.go @@ -14,6 +14,7 @@ import ( "fmt" "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" "maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/bridgev2" @@ -23,7 +24,7 @@ import ( "maunium.net/go/mautrix/id" ) -func (br *BridgeMain) LegacyMigrateSimple(renameTablesQuery, copyDataQuery string, newDBVersion int) func(ctx context.Context) error { +func (br *BridgeMain) LegacyMigrateWithAnotherUpgrader(renameTablesQuery, copyDataQuery string, newDBVersion int, otherTable dbutil.UpgradeTable, otherTableName string, otherNewVersion int) func(ctx context.Context) error { return func(ctx context.Context) error { _, err := br.DB.Exec(ctx, renameTablesQuery) if err != nil { @@ -36,6 +37,22 @@ func (br *BridgeMain) LegacyMigrateSimple(renameTablesQuery, copyDataQuery strin if upgradesTo < newDBVersion || compat > newDBVersion { return fmt.Errorf("unexpected new database version (%d/c:%d, expected %d)", upgradesTo, compat, newDBVersion) } + if otherTable != nil { + _, err = br.DB.Exec(ctx, fmt.Sprintf("CREATE TABLE %s (version INTEGER, compat INTEGER)", otherTableName)) + if err != nil { + return err + } + upgradesTo, compat, err = otherTable[0].DangerouslyRun(ctx, br.DB) + if err != nil { + return err + } else if upgradesTo < otherNewVersion || compat > otherNewVersion { + return fmt.Errorf("unexpected new database version for %s (%d/c:%d, expected %d)", otherTableName, upgradesTo, compat, newDBVersion) + } + _, err = br.DB.Exec(ctx, fmt.Sprintf("INSERT INTO %s (version, compat) VALUES ($1, $2)", otherTableName), upgradesTo, compat) + if err != nil { + return err + } + } copyDataQuery, err = br.DB.Internals().FilterSQLUpgrade(bytes.Split([]byte(copyDataQuery), []byte("\n"))) if err != nil { return err @@ -61,7 +78,17 @@ func (br *BridgeMain) LegacyMigrateSimple(renameTablesQuery, copyDataQuery strin } } -func (br *BridgeMain) CheckLegacyDB(expectedVersion int, minBridgeVersion, firstMegaVersion string, migrator func(context.Context) error, transaction bool) { +func (br *BridgeMain) LegacyMigrateSimple(renameTablesQuery, copyDataQuery string, newDBVersion int) func(ctx context.Context) error { + return br.LegacyMigrateWithAnotherUpgrader(renameTablesQuery, copyDataQuery, newDBVersion, nil, "", 0) +} + +func (br *BridgeMain) CheckLegacyDB( + expectedVersion int, + minBridgeVersion, + firstMegaVersion string, + migrator func(context.Context) error, + transaction bool, +) { log := br.Log.With().Str("action", "migrate legacy db").Logger() ctx := log.WithContext(context.Background()) exists, err := br.DB.TableExists(ctx, "database_owner")