mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
Move a bunch of stuff from mautrix-whatsapp
Moved parts: * Appservice SQL state store * Bridge crypto helper * Database upgrade framework * Bridge startup flow Other changes: * Improved database upgrade framework * Now primarily using static SQL files compiled with go:embed * Moved appservice SQL state store to using membership enum on Postgres
This commit is contained in:
parent
915aa9dd1f
commit
d578d1a610
24 changed files with 1925 additions and 393 deletions
|
|
@ -23,7 +23,7 @@ import (
|
|||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/websocket"
|
||||
"golang.org/x/net/publicsuffix"
|
||||
"gopkg.in/yaml.v2"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"maunium.net/go/maulogger/v2"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2019 Tulir Asokan
|
||||
// Copyright (c) 2022 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
|
@ -10,11 +10,11 @@ import (
|
|||
"io/ioutil"
|
||||
"regexp"
|
||||
|
||||
"gopkg.in/yaml.v2"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Registration contains the data in a Matrix appservice registration.
|
||||
// See https://matrix.org/docs/spec/application_service/unstable.html#registration
|
||||
// See https://spec.matrix.org/v1.2/application-service-api/#registration
|
||||
type Registration struct {
|
||||
ID string `yaml:"id"`
|
||||
URL string `yaml:"url"`
|
||||
|
|
@ -23,8 +23,10 @@ type Registration struct {
|
|||
SenderLocalpart string `yaml:"sender_localpart"`
|
||||
RateLimited *bool `yaml:"rate_limited,omitempty"`
|
||||
Namespaces Namespaces `yaml:"namespaces"`
|
||||
EphemeralEvents bool `yaml:"de.sorunome.msc2409.push_ephemeral,omitempty"`
|
||||
Protocols []string `yaml:"protocols,omitempty"`
|
||||
|
||||
SoruEphemeralEvents bool `yaml:"de.sorunome.msc2409.push_ephemeral,omitempty"`
|
||||
EphemeralEvents bool `yaml:"push_ephemeral,omitempty"`
|
||||
}
|
||||
|
||||
// CreateRegistration creates a Registration with random appservice and homeserver tokens.
|
||||
|
|
@ -70,9 +72,9 @@ func (reg *Registration) YAML() (string, error) {
|
|||
|
||||
// Namespaces contains the three areas that appservices can reserve parts of.
|
||||
type Namespaces struct {
|
||||
UserIDs []Namespace `yaml:"users,omitempty"`
|
||||
RoomAliases []Namespace `yaml:"aliases,omitempty"`
|
||||
RoomIDs []Namespace `yaml:"rooms,omitempty"`
|
||||
UserIDs NamespaceList `yaml:"users,omitempty"`
|
||||
RoomAliases NamespaceList `yaml:"aliases,omitempty"`
|
||||
RoomIDs NamespaceList `yaml:"rooms,omitempty"`
|
||||
}
|
||||
|
||||
// Namespace is a reserved namespace in any area.
|
||||
|
|
@ -81,26 +83,16 @@ type Namespace struct {
|
|||
Exclusive bool `yaml:"exclusive"`
|
||||
}
|
||||
|
||||
// RegisterUserIDs creates an user ID namespace registration.
|
||||
func (nslist *Namespaces) RegisterUserIDs(regex *regexp.Regexp, exclusive bool) {
|
||||
nslist.UserIDs = append(nslist.UserIDs, Namespace{
|
||||
Regex: regex.String(),
|
||||
Exclusive: exclusive,
|
||||
})
|
||||
}
|
||||
type NamespaceList []Namespace
|
||||
|
||||
// RegisterRoomAliases creates an room alias namespace registration.
|
||||
func (nslist *Namespaces) RegisterRoomAliases(regex *regexp.Regexp, exclusive bool) {
|
||||
nslist.RoomAliases = append(nslist.RoomAliases, Namespace{
|
||||
func (nsl *NamespaceList) Register(regex *regexp.Regexp, exclusive bool) {
|
||||
ns := Namespace{
|
||||
Regex: regex.String(),
|
||||
Exclusive: exclusive,
|
||||
})
|
||||
}
|
||||
|
||||
// RegisterRoomIDs creates an room ID namespace registration.
|
||||
func (nslist *Namespaces) RegisterRoomIDs(regex *regexp.Regexp, exclusive bool) {
|
||||
nslist.RoomIDs = append(nslist.RoomIDs, Namespace{
|
||||
Regex: regex.String(),
|
||||
Exclusive: exclusive,
|
||||
})
|
||||
}
|
||||
if nsl == nil {
|
||||
*nsl = []Namespace{ns}
|
||||
} else {
|
||||
*nsl = append(*nsl, ns)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
284
appservice/sqlstatestore/statestore.go
Normal file
284
appservice/sqlstatestore/statestore.go
Normal file
|
|
@ -0,0 +1,284 @@
|
|||
// Copyright (c) 2022 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package sqlstatestore
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/appservice"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
)
|
||||
|
||||
//go:embed *.sql
|
||||
var rawUpgrades embed.FS
|
||||
|
||||
var UpgradeTable dbutil.UpgradeTable
|
||||
|
||||
func init() {
|
||||
UpgradeTable.RegisterFS(rawUpgrades)
|
||||
}
|
||||
|
||||
const VersionTableName = "mx_version"
|
||||
|
||||
type SQLStateStore struct {
|
||||
*dbutil.Database
|
||||
*appservice.TypingStateStore
|
||||
|
||||
log log.Logger
|
||||
|
||||
Typing map[id.RoomID]map[id.UserID]int64
|
||||
typingLock sync.RWMutex
|
||||
}
|
||||
|
||||
var _ appservice.StateStore = (*SQLStateStore)(nil)
|
||||
|
||||
func NewSQLStateStore(db *dbutil.Database) *SQLStateStore {
|
||||
return &SQLStateStore{
|
||||
Database: db.Child("StateStore", VersionTableName, UpgradeTable),
|
||||
TypingStateStore: appservice.NewTypingStateStore(),
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsRegistered(userID id.UserID) bool {
|
||||
var isRegistered bool
|
||||
err := store.
|
||||
QueryRow("SELECT EXISTS(SELECT 1 FROM mx_registrations WHERE user_id=$1)", userID).
|
||||
Scan(&isRegistered)
|
||||
if err != nil {
|
||||
store.log.Warnfln("Failed to scan registration existence for %s: %v", userID, err)
|
||||
}
|
||||
return isRegistered
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) MarkRegistered(userID id.UserID) {
|
||||
_, err := store.Exec("INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
|
||||
if err != nil {
|
||||
store.log.Warnfln("Failed to mark %s as registered: %v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*event.MemberEventContent {
|
||||
members := make(map[id.UserID]*event.MemberEventContent)
|
||||
rows, err := store.Query("SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1", roomID)
|
||||
if err != nil {
|
||||
return members
|
||||
}
|
||||
var userID id.UserID
|
||||
var member event.MemberEventContent
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&userID, &member.Membership, &member.Displayname, &member.AvatarURL)
|
||||
if err != nil {
|
||||
store.log.Warnfln("Failed to scan member in %s: %v", roomID, err)
|
||||
} else {
|
||||
members[userID] = &member
|
||||
}
|
||||
}
|
||||
return members
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership {
|
||||
membership := event.MembershipLeave
|
||||
err := store.
|
||||
QueryRow("SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
|
||||
Scan(&membership)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
store.log.Warnfln("Failed to scan membership of %s in %s: %v", userID, roomID, err)
|
||||
}
|
||||
return membership
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
|
||||
member, ok := store.TryGetMember(roomID, userID)
|
||||
if !ok {
|
||||
member.Membership = event.MembershipLeave
|
||||
}
|
||||
return member
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool) {
|
||||
var member event.MemberEventContent
|
||||
err := store.
|
||||
QueryRow("SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
|
||||
Scan(&member.Membership, &member.Displayname, &member.AvatarURL)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
store.log.Warnfln("Failed to scan member info of %s in %s: %v", userID, roomID, err)
|
||||
}
|
||||
return &member, err == nil
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID) {
|
||||
rows, err := store.Query(`
|
||||
SELECT room_id FROM mx_user_profile
|
||||
LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id
|
||||
WHERE user_id=$1 AND portal.encrypted=true
|
||||
`, userID)
|
||||
if err != nil {
|
||||
store.log.Warnfln("Failed to query shared rooms with %s: %v", userID, err)
|
||||
return
|
||||
}
|
||||
for rows.Next() {
|
||||
var roomID id.RoomID
|
||||
err = rows.Scan(&roomID)
|
||||
if err != nil {
|
||||
store.log.Warnfln("Failed to scan room ID: %v", err)
|
||||
} else {
|
||||
rooms = append(rooms, roomID)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(roomID, userID, "join")
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(roomID, userID, "join", "invite")
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
|
||||
membership := store.GetMembership(roomID, userID)
|
||||
for _, allowedMembership := range allowedMemberships {
|
||||
if allowedMembership == membership {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) {
|
||||
_, err := store.Exec(`
|
||||
INSERT INTO mx_user_profile (room_id, user_id, membership) VALUES ($1, $2, $3)
|
||||
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership
|
||||
`, roomID, userID, membership)
|
||||
if err != nil {
|
||||
store.log.Warnfln("Failed to set membership of %s in %s to %s: %v", userID, roomID, membership, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) {
|
||||
_, err := store.Exec(`
|
||||
INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership, displayname=excluded.displayname, avatar_url=excluded.avatar_url
|
||||
`, roomID, userID, member.Membership, member.Displayname, member.AvatarURL)
|
||||
if err != nil {
|
||||
store.log.Warnfln("Failed to set membership of %s in %s to %s: %v", userID, roomID, member, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) {
|
||||
levelsBytes, err := json.Marshal(levels)
|
||||
if err != nil {
|
||||
store.log.Errorfln("Failed to marshal power levels of %s: %v", roomID, err)
|
||||
return
|
||||
}
|
||||
_, err = store.Exec(`
|
||||
INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2)
|
||||
ON CONFLICT (room_id) DO UPDATE SET power_levels=excluded.power_levels
|
||||
`, roomID, levelsBytes)
|
||||
if err != nil {
|
||||
store.log.Warnfln("Failed to store power levels of %s: %v", roomID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) {
|
||||
var data []byte
|
||||
err := store.
|
||||
QueryRow("SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID).
|
||||
Scan(&data)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
store.log.Errorfln("Failed to scan power levels of %s: %v", roomID, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
levels = &event.PowerLevelsEventContent{}
|
||||
err = json.Unmarshal(data, levels)
|
||||
if err != nil {
|
||||
store.log.Errorfln("Failed to parse power levels of %s: %v", roomID, err)
|
||||
return nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int {
|
||||
if store.Dialect == dbutil.Postgres {
|
||||
var powerLevel int
|
||||
err := store.
|
||||
QueryRow(`
|
||||
SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
|
||||
FROM mx_room_state WHERE room_id=$1
|
||||
`, roomID, userID).
|
||||
Scan(&powerLevel)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
store.log.Errorfln("Failed to scan power level of %s in %s: %v", userID, roomID, err)
|
||||
}
|
||||
return powerLevel
|
||||
}
|
||||
return store.GetPowerLevels(roomID).GetUserLevel(userID)
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int {
|
||||
if store.Dialect == dbutil.Postgres {
|
||||
defaultType := "events_default"
|
||||
defaultValue := 0
|
||||
if eventType.IsState() {
|
||||
defaultType = "state_default"
|
||||
defaultValue = 50
|
||||
}
|
||||
var powerLevel int
|
||||
err := store.
|
||||
QueryRow(`
|
||||
SELECT COALESCE((power_levels->'events'->$2)::int, (power_levels->'$3')::int, $4)
|
||||
FROM mx_room_state WHERE room_id=$1
|
||||
`, roomID, eventType.Type, defaultType, defaultValue).
|
||||
Scan(&powerLevel)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
store.log.Errorfln("Failed to scan power level for %s in %s: %v", eventType, roomID, err)
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
return powerLevel
|
||||
}
|
||||
return store.GetPowerLevels(roomID).GetEventLevel(eventType)
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool {
|
||||
if store.Dialect == dbutil.Postgres {
|
||||
defaultType := "events_default"
|
||||
defaultValue := 0
|
||||
if eventType.IsState() {
|
||||
defaultType = "state_default"
|
||||
defaultValue = 50
|
||||
}
|
||||
var hasPower bool
|
||||
err := store.
|
||||
QueryRow(`SELECT
|
||||
COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
|
||||
>=
|
||||
COALESCE((power_levels->'events'->$3)::int, (power_levels->'$4')::int, $5)
|
||||
FROM mx_room_state WHERE room_id=$1`, roomID, userID, eventType.Type, defaultType, defaultValue).
|
||||
Scan(&hasPower)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
store.log.Errorfln("Failed to scan power level for %s in %s: %v", eventType, roomID, err)
|
||||
}
|
||||
return defaultValue == 0
|
||||
}
|
||||
return hasPower
|
||||
}
|
||||
return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType)
|
||||
}
|
||||
19
appservice/sqlstatestore/v01-initial-revision.sql
Normal file
19
appservice/sqlstatestore/v01-initial-revision.sql
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
-- v1: Initial revision
|
||||
|
||||
CREATE TABLE mx_registrations (
|
||||
user_id TEXT PRIMARY KEY
|
||||
);
|
||||
|
||||
CREATE TABLE mx_user_profile (
|
||||
room_id TEXT,
|
||||
user_id TEXT,
|
||||
membership TEXT NOT NULL,
|
||||
displayname TEXT,
|
||||
avatar_url TEXT,
|
||||
PRIMARY KEY (room_id, user_id)
|
||||
);
|
||||
|
||||
CREATE TABLE mx_room_state (
|
||||
room_id TEXT PRIMARY KEY,
|
||||
power_levels jsonb
|
||||
);
|
||||
5
appservice/sqlstatestore/v02-membership-enum.sql
Normal file
5
appservice/sqlstatestore/v02-membership-enum.sql
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
-- v2: Use enum for membership field on Postgres
|
||||
-- only: postgres
|
||||
|
||||
CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock');
|
||||
ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE membership USING LOWER(membership)::membership;
|
||||
353
bridge/bridge.go
Normal file
353
bridge/bridge.go
Normal file
|
|
@ -0,0 +1,353 @@
|
|||
// Copyright (c) 2022 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
flag "maunium.net/go/mauflag"
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/appservice"
|
||||
"maunium.net/go/mautrix/appservice/sqlstatestore"
|
||||
"maunium.net/go/mautrix/bridge/bridgeconfig"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/util/configupgrade"
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
)
|
||||
|
||||
var configPath = flag.MakeFull("c", "config", "The path to your config file.", "config.yaml").String()
|
||||
var dontSaveConfig = flag.MakeFull("n", "no-update", "Don't save updated config to disk.", "false").Bool()
|
||||
var registrationPath = flag.MakeFull("r", "registration", "The path where to save the appservice registration.", "registration.yaml").String()
|
||||
var generateRegistration = flag.MakeFull("g", "generate-registration", "Generate registration and quit.", "false").Bool()
|
||||
var version = flag.MakeFull("v", "version", "View bridge version and quit.", "false").Bool()
|
||||
var ignoreUnsupportedDatabase = flag.Make().LongKey("ignore-unsupported-database").Usage("Run even if the database schema is too new").Default("false").Bool()
|
||||
var ignoreForeignTables = flag.Make().LongKey("ignore-foreign-tables").Usage("Run even if the database contains tables from other programs (like Synapse)").Default("false").Bool()
|
||||
var wantHelp, _ = flag.MakeHelpFlag()
|
||||
|
||||
type Portal interface {
|
||||
IsEncrypted() bool
|
||||
}
|
||||
|
||||
type User interface {
|
||||
IsAdmin() bool
|
||||
}
|
||||
|
||||
type ChildOverride interface {
|
||||
GetExampleConfig() string
|
||||
GetConfigPtr() interface{}
|
||||
|
||||
Init()
|
||||
Start()
|
||||
Stop()
|
||||
|
||||
GetIPortalByMXID(id id.RoomID) Portal
|
||||
GetIUserByMXID(id id.UserID) User
|
||||
}
|
||||
|
||||
type Bridge struct {
|
||||
Name string
|
||||
URL string
|
||||
Description string
|
||||
Version string
|
||||
ProtocolName string
|
||||
|
||||
VersionDesc string
|
||||
LinkifiedVersion string
|
||||
|
||||
AS *appservice.AppService
|
||||
EventProcessor *appservice.EventProcessor
|
||||
Bot *appservice.IntentAPI
|
||||
Config bridgeconfig.BaseConfig
|
||||
ConfigUpgrader configupgrade.BaseUpgrader
|
||||
Log log.Logger
|
||||
DB *dbutil.Database
|
||||
StateStore *sqlstatestore.SQLStateStore
|
||||
Crypto Crypto
|
||||
|
||||
Child ChildOverride
|
||||
}
|
||||
|
||||
type Crypto interface {
|
||||
HandleMemberEvent(*event.Event)
|
||||
Decrypt(*event.Event) (*event.Event, error)
|
||||
Encrypt(id.RoomID, event.Type, event.Content) (*event.EncryptedEventContent, error)
|
||||
WaitForSession(id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool
|
||||
RequestSession(id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID)
|
||||
ResetSession(id.RoomID)
|
||||
Init() error
|
||||
Start()
|
||||
Stop()
|
||||
}
|
||||
|
||||
func (br *Bridge) GenerateRegistration() {
|
||||
if *dontSaveConfig {
|
||||
// We need to save the generated as_token and hs_token in the config
|
||||
_, _ = fmt.Fprintln(os.Stderr, "--no-update is not compatible with --generate-registration")
|
||||
os.Exit(5)
|
||||
}
|
||||
reg := br.Config.GenerateRegistration()
|
||||
err := reg.Save(*registrationPath)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintln(os.Stderr, "Failed to save registration:", err)
|
||||
os.Exit(21)
|
||||
}
|
||||
|
||||
updateTokens := func(helper *configupgrade.Helper) {
|
||||
helper.Set(configupgrade.Str, reg.AppToken, "appservice", "as_token")
|
||||
helper.Set(configupgrade.Str, reg.ServerToken, "appservice", "hs_token")
|
||||
}
|
||||
_, _, err = configupgrade.Do(*configPath, true, br.ConfigUpgrader, configupgrade.SimpleUpgrader(updateTokens))
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintln(os.Stderr, "Failed to save config:", err)
|
||||
os.Exit(22)
|
||||
}
|
||||
fmt.Println("Registration generated. See https://docs.mau.fi/bridges/general/registering-appservices.html for instructions on installing the registration.")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func (br *Bridge) InitVersion(tag, commit, buildTime string) {
|
||||
if len(tag) > 0 && tag[0] == 'v' {
|
||||
tag = tag[1:]
|
||||
}
|
||||
if tag != br.Version {
|
||||
suffix := ""
|
||||
if !strings.HasSuffix(br.Version, "+dev") {
|
||||
suffix = "+dev"
|
||||
}
|
||||
if len(commit) > 8 {
|
||||
br.Version = fmt.Sprintf("%s%s.%s", br.Version, suffix, commit[:8])
|
||||
} else {
|
||||
br.Version = fmt.Sprintf("%s%s.unknown", br.Version, suffix)
|
||||
}
|
||||
}
|
||||
|
||||
linkifiedVersion := fmt.Sprintf("v%s", br.Version)
|
||||
if tag == br.Version {
|
||||
linkifiedVersion = fmt.Sprintf("[v%s](%s/releases/v%s)", br.Version, br.URL, tag)
|
||||
} else if len(commit) > 8 {
|
||||
linkifiedVersion = strings.Replace(linkifiedVersion, commit[:8], fmt.Sprintf("[%s](%s/commit/%s)", commit[:8], br.URL, commit), 1)
|
||||
}
|
||||
mautrix.DefaultUserAgent = fmt.Sprintf("%s/%s %s", br.Name, br.Version, mautrix.DefaultUserAgent)
|
||||
br.VersionDesc = fmt.Sprintf("%s %s (%s)", br.Name, br.Version, buildTime)
|
||||
}
|
||||
|
||||
func (br *Bridge) ensureConnection() {
|
||||
for {
|
||||
versions, err := br.Bot.Versions()
|
||||
if err != nil {
|
||||
br.Log.Errorfln("Failed to connect to homeserver: %v. Retrying in 10 seconds...", err)
|
||||
time.Sleep(10 * time.Second)
|
||||
continue
|
||||
}
|
||||
if !versions.ContainsGreaterOrEqual(mautrix.SpecV11) {
|
||||
br.Log.Warnfln("Server isn't advertising modern spec versions")
|
||||
}
|
||||
resp, err := br.Bot.Whoami()
|
||||
if err != nil {
|
||||
if errors.Is(err, mautrix.MUnknownToken) {
|
||||
br.Log.Fatalln("The as_token was not accepted. Is the registration file installed in your homeserver correctly?")
|
||||
os.Exit(16)
|
||||
} else if errors.Is(err, mautrix.MExclusive) {
|
||||
br.Log.Fatalln("The as_token was accepted, but the /register request was not. Are the homeserver domain and username template in the config correct, and do they match the values in the registration?")
|
||||
os.Exit(16)
|
||||
}
|
||||
br.Log.Errorfln("Failed to connect to homeserver: %v. Retrying in 10 seconds...", err)
|
||||
time.Sleep(10 * time.Second)
|
||||
} else if resp.UserID != br.Bot.UserID {
|
||||
br.Log.Fatalln("Unexpected user ID in whoami call: got %s, expected %s", resp.UserID, br.Bot.UserID)
|
||||
os.Exit(17)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (br *Bridge) UpdateBotProfile() {
|
||||
br.Log.Debugln("Updating bot profile")
|
||||
botConfig := &br.Config.AppService.Bot
|
||||
|
||||
var err error
|
||||
var mxc id.ContentURI
|
||||
if botConfig.Avatar == "remove" {
|
||||
err = br.Bot.SetAvatarURL(mxc)
|
||||
} else if !botConfig.ParsedAvatar.IsEmpty() {
|
||||
err = br.Bot.SetAvatarURL(botConfig.ParsedAvatar)
|
||||
}
|
||||
if err != nil {
|
||||
br.Log.Warnln("Failed to update bot avatar:", err)
|
||||
}
|
||||
|
||||
if botConfig.Displayname == "remove" {
|
||||
err = br.Bot.SetDisplayName("")
|
||||
} else if len(botConfig.Displayname) > 0 {
|
||||
err = br.Bot.SetDisplayName(botConfig.Displayname)
|
||||
}
|
||||
if err != nil {
|
||||
br.Log.Warnln("Failed to update bot displayname:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (br *Bridge) loadConfig() {
|
||||
configData, upgraded, err := configupgrade.Do(*configPath, !*dontSaveConfig, br.ConfigUpgrader)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintln(os.Stderr, "Error updating config:", err)
|
||||
if configData == nil {
|
||||
os.Exit(10)
|
||||
}
|
||||
}
|
||||
|
||||
target := br.Child.GetConfigPtr()
|
||||
if !upgraded {
|
||||
// Fallback: if config upgrading failed, load example config for base values
|
||||
err = yaml.Unmarshal([]byte(br.Child.GetExampleConfig()), &target)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintln(os.Stderr, "Failed to unmarshal example config:", err)
|
||||
os.Exit(10)
|
||||
}
|
||||
}
|
||||
err = yaml.Unmarshal(configData, target)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintln(os.Stderr, "Failed to parse config:", err)
|
||||
os.Exit(10)
|
||||
}
|
||||
}
|
||||
|
||||
func (br *Bridge) init() {
|
||||
var err error
|
||||
|
||||
br.AS = br.Config.MakeAppService()
|
||||
_, _ = br.AS.Init()
|
||||
|
||||
br.Log = log.Create()
|
||||
br.Config.Logging.Configure(br.Log)
|
||||
log.DefaultLogger = br.Log.(*log.BasicLogger)
|
||||
if len(br.Config.Logging.FileNameFormat) > 0 {
|
||||
err = log.OpenFile()
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintln(os.Stderr, "Failed to open log file:", err)
|
||||
os.Exit(12)
|
||||
}
|
||||
}
|
||||
br.AS.Log = log.Sub("Matrix")
|
||||
br.Bot = br.AS.BotIntent()
|
||||
br.Log.Infoln("Initializing", br.VersionDesc)
|
||||
|
||||
br.Log.Debugln("Initializing database connection")
|
||||
br.DB, err = dbutil.NewFromConfig(br.Name, br.Config.AppService.Database, br.Log.Sub("Database"))
|
||||
if err != nil {
|
||||
br.Log.Fatalln("Failed to initialize database connection:", err)
|
||||
os.Exit(14)
|
||||
}
|
||||
br.DB.IgnoreUnsupportedDatabase = *ignoreUnsupportedDatabase
|
||||
br.DB.IgnoreForeignTables = *ignoreForeignTables
|
||||
|
||||
br.Log.Debugln("Initializing state store")
|
||||
br.StateStore = sqlstatestore.NewSQLStateStore(br.DB)
|
||||
br.AS.StateStore = br.StateStore
|
||||
|
||||
br.Log.Debugln("Initializing Matrix event processor")
|
||||
br.EventProcessor = appservice.NewEventProcessor(br.AS)
|
||||
|
||||
br.Crypto = NewCryptoHelper(br)
|
||||
|
||||
br.Child.Init()
|
||||
}
|
||||
|
||||
func (br *Bridge) LogDBUpgradeErrorAndExit(name string, err error) {
|
||||
br.Log.Fatalfln("Failed to initialize %s: %v", name, err)
|
||||
if errors.Is(err, dbutil.ErrForeignTables) {
|
||||
br.Log.Infoln("You can use --ignore-foreign-tables to ignore this error")
|
||||
} else if errors.Is(err, dbutil.ErrNotOwned) {
|
||||
br.Log.Infoln("Sharing the same database with different programs is not supported")
|
||||
} else if errors.Is(err, dbutil.ErrUnsupportedDatabaseVersion) {
|
||||
br.Log.Infoln("Downgrading the bridge is not supported")
|
||||
}
|
||||
os.Exit(15)
|
||||
}
|
||||
|
||||
func (br *Bridge) start() {
|
||||
br.Log.Debugln("Running database upgrades")
|
||||
err := br.DB.Upgrade()
|
||||
if err != nil {
|
||||
br.LogDBUpgradeErrorAndExit("main database", err)
|
||||
} else if err = br.StateStore.Upgrade(); err != nil {
|
||||
br.LogDBUpgradeErrorAndExit("matrix state store", err)
|
||||
}
|
||||
|
||||
br.Log.Debugln("Checking connection to homeserver")
|
||||
br.ensureConnection()
|
||||
|
||||
if br.Crypto != nil {
|
||||
err = br.Crypto.Init()
|
||||
if err != nil {
|
||||
br.Log.Fatalln("Error initializing end-to-bridge encryption:", err)
|
||||
os.Exit(19)
|
||||
}
|
||||
}
|
||||
|
||||
br.Log.Debugln("Starting application service HTTP server")
|
||||
go br.AS.Start()
|
||||
br.Log.Debugln("Starting event processor")
|
||||
go br.EventProcessor.Start()
|
||||
|
||||
go br.UpdateBotProfile()
|
||||
if br.Crypto != nil {
|
||||
go br.Crypto.Start()
|
||||
}
|
||||
|
||||
br.Child.Start()
|
||||
br.AS.Ready = true
|
||||
}
|
||||
|
||||
func (br *Bridge) Main() {
|
||||
flag.SetHelpTitles(
|
||||
fmt.Sprintf("%s - %s", br.Name, br.Description),
|
||||
fmt.Sprintf("%s [-hgvn] [-c <path>] [-r <path>]", br.Name))
|
||||
err := flag.Parse()
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintln(os.Stderr, err)
|
||||
flag.PrintHelp()
|
||||
os.Exit(1)
|
||||
} else if *wantHelp {
|
||||
flag.PrintHelp()
|
||||
os.Exit(0)
|
||||
} else if *version {
|
||||
fmt.Println(br.VersionDesc)
|
||||
return
|
||||
}
|
||||
|
||||
br.loadConfig()
|
||||
|
||||
if *generateRegistration {
|
||||
br.GenerateRegistration()
|
||||
return
|
||||
}
|
||||
|
||||
br.init()
|
||||
br.Log.Infoln("Bridge initialization complete, starting...")
|
||||
br.start()
|
||||
br.Log.Infoln("Bridge started!")
|
||||
|
||||
c := make(chan os.Signal)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
<-c
|
||||
|
||||
br.Log.Infoln("Interrupt received, stopping...")
|
||||
br.Child.Stop()
|
||||
br.Log.Infoln("Bridge stopped.")
|
||||
os.Exit(0)
|
||||
}
|
||||
200
bridge/bridgeconfig/config.go
Normal file
200
bridge/bridgeconfig/config.go
Normal file
|
|
@ -0,0 +1,200 @@
|
|||
// Copyright (c) 2022 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package bridgeconfig
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"maunium.net/go/mautrix/appservice"
|
||||
"maunium.net/go/mautrix/id"
|
||||
up "maunium.net/go/mautrix/util/configupgrade"
|
||||
)
|
||||
|
||||
type HomeserverConfig struct {
|
||||
Address string `yaml:"address"`
|
||||
Domain string `yaml:"domain"`
|
||||
AsyncMedia bool `yaml:"async_media"`
|
||||
|
||||
Asmux bool `yaml:"asmux"`
|
||||
StatusEndpoint string `yaml:"status_endpoint"`
|
||||
MessageSendCheckpointEndpoint string `yaml:"message_send_checkpoint_endpoint"`
|
||||
}
|
||||
|
||||
type AppserviceConfig struct {
|
||||
Address string `yaml:"address"`
|
||||
Hostname string `yaml:"hostname"`
|
||||
Port uint16 `yaml:"port"`
|
||||
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
|
||||
ID string `yaml:"id"`
|
||||
Bot BotUserConfig `yaml:"bot"`
|
||||
|
||||
ASToken string `yaml:"as_token"`
|
||||
HSToken string `yaml:"hs_token"`
|
||||
|
||||
EphemeralEvents bool `yaml:"ephemeral_events"`
|
||||
}
|
||||
|
||||
func (config *BaseConfig) MakeUserIDRegex() *regexp.Regexp {
|
||||
usernamePlaceholder := appservice.RandomString(16)
|
||||
usernameTemplate := fmt.Sprintf("@%s:%s",
|
||||
config.Bridge.FormatUsername(usernamePlaceholder),
|
||||
config.Homeserver.Domain)
|
||||
usernameTemplate = regexp.QuoteMeta(usernameTemplate)
|
||||
usernameTemplate = strings.Replace(usernameTemplate, usernamePlaceholder, "[0-9]+", 1)
|
||||
usernameTemplate = fmt.Sprintf("^%s$", usernameTemplate)
|
||||
return regexp.MustCompile(usernameTemplate)
|
||||
}
|
||||
|
||||
// GenerateRegistration generates a registration file for the homeserver.
|
||||
func (config *BaseConfig) GenerateRegistration() *appservice.Registration {
|
||||
registration := appservice.CreateRegistration()
|
||||
config.AppService.HSToken = registration.ServerToken
|
||||
config.AppService.ASToken = registration.AppToken
|
||||
config.AppService.copyToRegistration(registration)
|
||||
|
||||
registration.SenderLocalpart = appservice.RandomString(32)
|
||||
botRegex := regexp.MustCompile(fmt.Sprintf("^@%s:%s$",
|
||||
regexp.QuoteMeta(config.AppService.Bot.Username),
|
||||
regexp.QuoteMeta(config.Homeserver.Domain)))
|
||||
registration.Namespaces.UserIDs.Register(botRegex, true)
|
||||
registration.Namespaces.UserIDs.Register(config.MakeUserIDRegex(), true)
|
||||
|
||||
return registration
|
||||
}
|
||||
|
||||
func (config *BaseConfig) MakeAppService() *appservice.AppService {
|
||||
as := appservice.Create()
|
||||
as.HomeserverDomain = config.Homeserver.Domain
|
||||
as.HomeserverURL = config.Homeserver.Address
|
||||
as.Host.Hostname = config.AppService.Hostname
|
||||
as.Host.Port = config.AppService.Port
|
||||
as.MessageSendCheckpointEndpoint = config.Homeserver.MessageSendCheckpointEndpoint
|
||||
as.DefaultHTTPRetries = 4
|
||||
as.Registration = config.AppService.GetRegistration()
|
||||
return as
|
||||
}
|
||||
|
||||
// GetRegistration copies the data from the bridge config into an *appservice.Registration struct.
|
||||
// This can't be used with the homeserver, see GenerateRegistration for generating files for the homeserver.
|
||||
func (asc *AppserviceConfig) GetRegistration() *appservice.Registration {
|
||||
reg := &appservice.Registration{}
|
||||
asc.copyToRegistration(reg)
|
||||
reg.SenderLocalpart = asc.Bot.Username
|
||||
reg.ServerToken = asc.HSToken
|
||||
reg.AppToken = asc.ASToken
|
||||
return reg
|
||||
}
|
||||
|
||||
func (asc *AppserviceConfig) copyToRegistration(registration *appservice.Registration) {
|
||||
registration.ID = asc.ID
|
||||
registration.URL = asc.Address
|
||||
falseVal := false
|
||||
registration.RateLimited = &falseVal
|
||||
registration.EphemeralEvents = asc.EphemeralEvents
|
||||
registration.SoruEphemeralEvents = asc.EphemeralEvents
|
||||
}
|
||||
|
||||
type BotUserConfig struct {
|
||||
Username string `yaml:"username"`
|
||||
Displayname string `yaml:"displayname"`
|
||||
Avatar string `yaml:"avatar"`
|
||||
|
||||
ParsedAvatar id.ContentURI `yaml:"-"`
|
||||
}
|
||||
|
||||
type serializableBUC BotUserConfig
|
||||
|
||||
func (buc *BotUserConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
var sbuc serializableBUC
|
||||
err := unmarshal(&sbuc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*buc = (BotUserConfig)(sbuc)
|
||||
if buc.Avatar != "" && buc.Avatar != "remove" {
|
||||
buc.ParsedAvatar, err = id.ParseContentURI(buc.Avatar)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w in bot avatar", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Type string `yaml:"type"`
|
||||
URI string `yaml:"uri"`
|
||||
|
||||
MaxOpenConns int `yaml:"max_open_conns"`
|
||||
MaxIdleConns int `yaml:"max_idle_conns"`
|
||||
|
||||
ConnMaxIdleTime string `yaml:"conn_max_idle_time"`
|
||||
ConnMaxLifetime string `yaml:"conn_max_lifetime"`
|
||||
}
|
||||
|
||||
type BridgeConfig interface {
|
||||
FormatUsername(username string) string
|
||||
GetEncryptionConfig() EncryptionConfig
|
||||
}
|
||||
|
||||
type EncryptionConfig struct {
|
||||
Allow bool `yaml:"allow"`
|
||||
Default bool `yaml:"default"`
|
||||
|
||||
KeySharing struct {
|
||||
Allow bool `yaml:"allow"`
|
||||
RequireCrossSigning bool `yaml:"require_cross_signing"`
|
||||
RequireVerification bool `yaml:"require_verification"`
|
||||
} `yaml:"key_sharing"`
|
||||
}
|
||||
|
||||
type BaseConfig struct {
|
||||
Homeserver HomeserverConfig `yaml:"homeserver"`
|
||||
AppService AppserviceConfig `yaml:"appservice"`
|
||||
Bridge BridgeConfig `yaml:"-"`
|
||||
Logging appservice.LogConfig `yaml:"logging"`
|
||||
}
|
||||
|
||||
type configUpgrader struct{}
|
||||
|
||||
var Upgrader = configUpgrader{}
|
||||
|
||||
func (upg configUpgrader) DoUpgrade(helper *up.Helper) {
|
||||
helper.Copy(up.Str, "homeserver", "address")
|
||||
helper.Copy(up.Str, "homeserver", "domain")
|
||||
helper.Copy(up.Bool, "homeserver", "asmux")
|
||||
helper.Copy(up.Str|up.Null, "homeserver", "status_endpoint")
|
||||
helper.Copy(up.Str|up.Null, "homeserver", "message_send_checkpoint_endpoint")
|
||||
helper.Copy(up.Bool, "homeserver", "async_media")
|
||||
|
||||
helper.Copy(up.Str, "appservice", "address")
|
||||
helper.Copy(up.Str, "appservice", "hostname")
|
||||
helper.Copy(up.Int, "appservice", "port")
|
||||
helper.Copy(up.Str, "appservice", "database", "type")
|
||||
helper.Copy(up.Str, "appservice", "database", "uri")
|
||||
helper.Copy(up.Int, "appservice", "database", "max_open_conns")
|
||||
helper.Copy(up.Int, "appservice", "database", "max_idle_conns")
|
||||
helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_idle_time")
|
||||
helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_lifetime")
|
||||
helper.Copy(up.Str, "appservice", "id")
|
||||
helper.Copy(up.Str, "appservice", "bot", "username")
|
||||
helper.Copy(up.Str, "appservice", "bot", "displayname")
|
||||
helper.Copy(up.Str, "appservice", "bot", "avatar")
|
||||
helper.Copy(up.Bool, "appservice", "ephemeral_events")
|
||||
helper.Copy(up.Str, "appservice", "as_token")
|
||||
helper.Copy(up.Str, "appservice", "hs_token")
|
||||
|
||||
helper.Copy(up.Str, "logging", "directory")
|
||||
helper.Copy(up.Str|up.Null, "logging", "file_name_format")
|
||||
helper.Copy(up.Str|up.Timestamp, "logging", "file_date_format")
|
||||
helper.Copy(up.Int, "logging", "file_mode")
|
||||
helper.Copy(up.Str|up.Timestamp, "logging", "timestamp_format")
|
||||
helper.Copy(up.Str, "logging", "print_level")
|
||||
}
|
||||
309
bridge/crypto.go
Normal file
309
bridge/crypto.go
Normal file
|
|
@ -0,0 +1,309 @@
|
|||
package bridge
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
|
||||
"maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/crypto"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
var NoSessionFound = crypto.NoSessionFound
|
||||
|
||||
var levelTrace = maulogger.Level{
|
||||
Name: "TRACE",
|
||||
Severity: -10,
|
||||
Color: -1,
|
||||
}
|
||||
|
||||
type CryptoHelper struct {
|
||||
bridge *Bridge
|
||||
client *mautrix.Client
|
||||
mach *crypto.OlmMachine
|
||||
store *SQLCryptoStore
|
||||
log maulogger.Logger
|
||||
baseLog maulogger.Logger
|
||||
}
|
||||
|
||||
func NewCryptoHelper(bridge *Bridge) Crypto {
|
||||
if !bridge.Config.Bridge.GetEncryptionConfig().Allow {
|
||||
bridge.Log.Debugln("Bridge built with end-to-bridge encryption, but disabled in config")
|
||||
return nil
|
||||
}
|
||||
baseLog := bridge.Log.Sub("Crypto")
|
||||
return &CryptoHelper{
|
||||
bridge: bridge,
|
||||
log: baseLog.Sub("Helper"),
|
||||
baseLog: baseLog,
|
||||
}
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Init() error {
|
||||
helper.log.Debugln("Initializing end-to-bridge encryption...")
|
||||
|
||||
helper.store = NewSQLCryptoStore(helper.bridge.DB, helper.bridge.AS.BotMXID(),
|
||||
fmt.Sprintf("@%s:%s", helper.bridge.Config.Bridge.FormatUsername("%"), helper.bridge.AS.HomeserverDomain))
|
||||
|
||||
err := helper.store.Upgrade()
|
||||
if err != nil {
|
||||
helper.bridge.LogDBUpgradeErrorAndExit("crypto store", err)
|
||||
}
|
||||
|
||||
helper.client, err = helper.loginBot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
helper.log.Debugln("Logged in as bridge bot with device ID", helper.client.DeviceID)
|
||||
logger := &cryptoLogger{helper.baseLog}
|
||||
stateStore := &cryptoStateStore{helper.bridge}
|
||||
helper.mach = crypto.NewOlmMachine(helper.client, logger, helper.store, stateStore)
|
||||
helper.mach.AllowKeyShare = helper.allowKeyShare
|
||||
|
||||
helper.client.Syncer = &cryptoSyncer{helper.mach}
|
||||
helper.client.Store = &cryptoClientStore{helper.store}
|
||||
|
||||
return helper.mach.Load()
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) allowKeyShare(device *crypto.DeviceIdentity, info event.RequestedKeyInfo) *crypto.KeyShareRejection {
|
||||
cfg := helper.bridge.Config.Bridge.GetEncryptionConfig().KeySharing
|
||||
if !cfg.Allow {
|
||||
return &crypto.KeyShareRejectNoResponse
|
||||
} else if device.Trust == crypto.TrustStateBlacklisted {
|
||||
return &crypto.KeyShareRejectBlacklisted
|
||||
} else if device.Trust == crypto.TrustStateVerified || !cfg.RequireVerification {
|
||||
portal := helper.bridge.Child.GetIPortalByMXID(info.RoomID)
|
||||
if portal == nil {
|
||||
helper.log.Debugfln("Rejecting key request for %s from %s/%s: room is not a portal", info.SessionID, device.UserID, device.DeviceID)
|
||||
return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnavailable, Reason: "Requested room is not a portal room"}
|
||||
}
|
||||
user := helper.bridge.Child.GetIUserByMXID(device.UserID)
|
||||
// FIXME reimplement IsInPortal
|
||||
if !user.IsAdmin() /*&& !user.IsInPortal(portal.Key)*/ {
|
||||
helper.log.Debugfln("Rejecting key request for %s from %s/%s: user is not in portal", info.SessionID, device.UserID, device.DeviceID)
|
||||
return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnauthorized, Reason: "You're not in that portal"}
|
||||
}
|
||||
helper.log.Debugfln("Accepting key request for %s from %s/%s", info.SessionID, device.UserID, device.DeviceID)
|
||||
return nil
|
||||
} else {
|
||||
return &crypto.KeyShareRejectUnverified
|
||||
}
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) loginBot() (*mautrix.Client, error) {
|
||||
deviceID := helper.store.FindDeviceID()
|
||||
if len(deviceID) > 0 {
|
||||
helper.log.Debugln("Found existing device ID for bot in database:", deviceID)
|
||||
}
|
||||
client, err := mautrix.NewClient(helper.bridge.AS.HomeserverURL, "", "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize client: %w", err)
|
||||
}
|
||||
client.Logger = helper.baseLog.Sub("Bot")
|
||||
client.Client = helper.bridge.AS.HTTPClient
|
||||
client.DefaultHTTPRetries = helper.bridge.AS.DefaultHTTPRetries
|
||||
flows, err := client.GetLoginFlows()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get supported login flows: %w", err)
|
||||
} else if !flows.HasFlow(mautrix.AuthTypeAppservice) {
|
||||
return nil, fmt.Errorf("homeserver does not support appservice login")
|
||||
}
|
||||
// We set the API token to the AS token here to authenticate the appservice login
|
||||
// It'll get overridden after the login
|
||||
client.AccessToken = helper.bridge.AS.Registration.AppToken
|
||||
resp, err := client.Login(&mautrix.ReqLogin{
|
||||
Type: mautrix.AuthTypeAppservice,
|
||||
Identifier: mautrix.UserIdentifier{
|
||||
Type: mautrix.IdentifierTypeUser,
|
||||
User: string(helper.bridge.AS.BotMXID()),
|
||||
},
|
||||
DeviceID: deviceID,
|
||||
StoreCredentials: true,
|
||||
|
||||
InitialDeviceDisplayName: fmt.Sprintf("%s bridge", helper.bridge.ProtocolName),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to log in as bridge bot: %w", err)
|
||||
}
|
||||
helper.store.DeviceID = resp.DeviceID
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Start() {
|
||||
helper.log.Debugln("Starting syncer for receiving to-device messages")
|
||||
err := helper.client.Sync()
|
||||
if err != nil {
|
||||
helper.log.Errorln("Fatal error syncing:", err)
|
||||
} else {
|
||||
helper.log.Infoln("Bridge bot to-device syncer stopped without error")
|
||||
}
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Stop() {
|
||||
helper.log.Debugln("CryptoHelper.Stop() called, stopping bridge bot sync")
|
||||
helper.client.StopSync()
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) {
|
||||
return helper.mach.DecryptMegolmEvent(evt)
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, content event.Content) (*event.EncryptedEventContent, error) {
|
||||
encrypted, err := helper.mach.EncryptMegolmEvent(roomID, evtType, &content)
|
||||
if err != nil {
|
||||
if err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession {
|
||||
return nil, err
|
||||
}
|
||||
helper.log.Debugfln("Got %v while encrypting event for %s, sharing group session and trying again...", err, roomID)
|
||||
users, err := helper.store.GetRoomMembers(roomID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get room member list: %w", err)
|
||||
}
|
||||
err = helper.mach.ShareGroupSession(roomID, users)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to share group session: %w", err)
|
||||
}
|
||||
encrypted, err = helper.mach.EncryptMegolmEvent(roomID, evtType, &content)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err)
|
||||
}
|
||||
}
|
||||
return encrypted, nil
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
|
||||
return helper.mach.WaitForSession(roomID, senderKey, sessionID, timeout)
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) RequestSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) {
|
||||
err := helper.mach.SendRoomKeyRequest(roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{userID: {deviceID}})
|
||||
if err != nil {
|
||||
helper.log.Warnfln("Failed to send key request to %s/%s for %s in %s: %v", userID, deviceID, sessionID, roomID, err)
|
||||
} else {
|
||||
helper.log.Debugfln("Sent key request to %s/%s for %s in %s", userID, deviceID, sessionID, roomID)
|
||||
}
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) ResetSession(roomID id.RoomID) {
|
||||
err := helper.mach.CryptoStore.RemoveOutboundGroupSession(roomID)
|
||||
if err != nil {
|
||||
helper.log.Debugfln("Error manually removing outbound group session in %s: %v", roomID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) HandleMemberEvent(evt *event.Event) {
|
||||
helper.mach.HandleMemberEvent(evt)
|
||||
}
|
||||
|
||||
type cryptoSyncer struct {
|
||||
*crypto.OlmMachine
|
||||
}
|
||||
|
||||
func (syncer *cryptoSyncer) ProcessResponse(resp *mautrix.RespSync, since string) error {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
syncer.Log.Error("Processing sync response (%s) panicked: %v\n%s", since, err, debug.Stack())
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
syncer.Log.Trace("Starting sync response handling (%s)", since)
|
||||
syncer.ProcessSyncResponse(resp, since)
|
||||
syncer.Log.Trace("Successfully handled sync response (%s)", since)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(30 * time.Second):
|
||||
syncer.Log.Warn("Handling sync response (%s) is taking unusually long", since)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (syncer *cryptoSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.Duration, error) {
|
||||
syncer.Log.Error("Error /syncing, waiting 10 seconds: %v", err)
|
||||
return 10 * time.Second, nil
|
||||
}
|
||||
|
||||
func (syncer *cryptoSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter {
|
||||
everything := []event.Type{{Type: "*"}}
|
||||
return &mautrix.Filter{
|
||||
Presence: mautrix.FilterPart{NotTypes: everything},
|
||||
AccountData: mautrix.FilterPart{NotTypes: everything},
|
||||
Room: mautrix.RoomFilter{
|
||||
IncludeLeave: false,
|
||||
Ephemeral: mautrix.FilterPart{NotTypes: everything},
|
||||
AccountData: mautrix.FilterPart{NotTypes: everything},
|
||||
State: mautrix.FilterPart{NotTypes: everything},
|
||||
Timeline: mautrix.FilterPart{NotTypes: everything},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type cryptoLogger struct {
|
||||
int maulogger.Logger
|
||||
}
|
||||
|
||||
func (c *cryptoLogger) Error(message string, args ...interface{}) {
|
||||
c.int.Errorfln(message, args...)
|
||||
}
|
||||
|
||||
func (c *cryptoLogger) Warn(message string, args ...interface{}) {
|
||||
c.int.Warnfln(message, args...)
|
||||
}
|
||||
|
||||
func (c *cryptoLogger) Debug(message string, args ...interface{}) {
|
||||
c.int.Debugfln(message, args...)
|
||||
}
|
||||
|
||||
func (c *cryptoLogger) Trace(message string, args ...interface{}) {
|
||||
c.int.Logfln(levelTrace, message, args...)
|
||||
}
|
||||
|
||||
type cryptoClientStore struct {
|
||||
int *SQLCryptoStore
|
||||
}
|
||||
|
||||
func (c cryptoClientStore) SaveFilterID(_ id.UserID, _ string) {}
|
||||
func (c cryptoClientStore) LoadFilterID(_ id.UserID) string { return "" }
|
||||
func (c cryptoClientStore) SaveRoom(_ *mautrix.Room) {}
|
||||
func (c cryptoClientStore) LoadRoom(_ id.RoomID) *mautrix.Room { return nil }
|
||||
|
||||
func (c cryptoClientStore) SaveNextBatch(_ id.UserID, nextBatchToken string) {
|
||||
c.int.PutNextBatch(nextBatchToken)
|
||||
}
|
||||
|
||||
func (c cryptoClientStore) LoadNextBatch(_ id.UserID) string {
|
||||
return c.int.GetNextBatch()
|
||||
}
|
||||
|
||||
var _ mautrix.Storer = (*cryptoClientStore)(nil)
|
||||
|
||||
type cryptoStateStore struct {
|
||||
bridge *Bridge
|
||||
}
|
||||
|
||||
var _ crypto.StateStore = (*cryptoStateStore)(nil)
|
||||
|
||||
func (c *cryptoStateStore) IsEncrypted(id id.RoomID) bool {
|
||||
portal := c.bridge.Child.GetIPortalByMXID(id)
|
||||
if portal != nil {
|
||||
return portal.IsEncrypted()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *cryptoStateStore) FindSharedRooms(id id.UserID) []id.RoomID {
|
||||
return c.bridge.StateStore.FindSharedRooms(id)
|
||||
}
|
||||
|
||||
func (c *cryptoStateStore) GetEncryptionEvent(id.RoomID) *event.EncryptionEventContent {
|
||||
// TODO implement
|
||||
return nil
|
||||
}
|
||||
65
bridge/cryptostore.go
Normal file
65
bridge/cryptostore.go
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
// Copyright (c) 2022 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
//go:build cgo && !nocrypto
|
||||
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"maunium.net/go/mautrix/crypto"
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
)
|
||||
|
||||
type SQLCryptoStore struct {
|
||||
*crypto.SQLCryptoStore
|
||||
UserID id.UserID
|
||||
GhostIDFormat string
|
||||
}
|
||||
|
||||
var _ crypto.Store = (*SQLCryptoStore)(nil)
|
||||
|
||||
func NewSQLCryptoStore(db *dbutil.Database, userID id.UserID, ghostIDFormat string) *SQLCryptoStore {
|
||||
return &SQLCryptoStore{
|
||||
SQLCryptoStore: crypto.NewSQLCryptoStore(db, "", "", []byte("maunium.net/go/mautrix-whatsapp")),
|
||||
UserID: userID,
|
||||
GhostIDFormat: ghostIDFormat,
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) FindDeviceID() (deviceID id.DeviceID) {
|
||||
err := store.DB.QueryRow("SELECT device_id FROM crypto_account WHERE account_id=$1", store.AccountID).Scan(&deviceID)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
store.Log.Warn("Failed to scan device ID: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) GetRoomMembers(roomID id.RoomID) (members []id.UserID, err error) {
|
||||
var rows *sql.Rows
|
||||
rows, err = store.DB.Query(`
|
||||
SELECT user_id FROM mx_user_profile
|
||||
WHERE room_id=$1
|
||||
AND (membership='join' OR membership='invite')
|
||||
AND user_id<>$2
|
||||
AND user_id NOT LIKE $3
|
||||
`, roomID, store.UserID, store.GhostIDFormat)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for rows.Next() {
|
||||
var userID id.UserID
|
||||
err = rows.Scan(&userID)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to scan member in %s: %v", roomID, err)
|
||||
} else {
|
||||
members = append(members, userID)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
17
bridge/no-crypto.go
Normal file
17
bridge/no-crypto.go
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
//go:build !cgo || nocrypto
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
func NewCryptoHelper(bridge *Bridge) Crypto {
|
||||
if !bridge.Config.Bridge.Encryption.Allow {
|
||||
bridge.Log.Warnln("Bridge built without end-to-bridge encryption, but encryption is enabled in config")
|
||||
}
|
||||
bridge.Log.Debugln("Bridge built without end-to-bridge encryption")
|
||||
return nil
|
||||
}
|
||||
|
||||
var NoSessionFound = errors.New("nil")
|
||||
|
|
@ -12,17 +12,23 @@ import (
|
|||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/crypto/olm"
|
||||
sqlUpgrade "maunium.net/go/mautrix/crypto/sql_store_upgrade"
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
)
|
||||
|
||||
func getOlmMachine(t *testing.T) *OlmMachine {
|
||||
db, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000")
|
||||
rawDB, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000")
|
||||
if err != nil {
|
||||
t.Fatalf("Error opening db: %v", err)
|
||||
}
|
||||
sqlUpgrade.Upgrade(db, "sqlite3")
|
||||
sqlStore := NewSQLCryptoStore(db, "sqlite3", "accid", id.DeviceID("dev"), []byte("test"), emptyLogger{})
|
||||
db, err := dbutil.NewWithDB(rawDB, "sqlite3")
|
||||
if err != nil {
|
||||
t.Fatalf("Error opening db: %v", err)
|
||||
}
|
||||
sqlStore := NewSQLCryptoStore(db, "accid", id.DeviceID("dev"), []byte("test"))
|
||||
if err = sqlStore.Upgrade(); err != nil {
|
||||
t.Fatalf("Error creating tables: %v", err)
|
||||
}
|
||||
|
||||
userID := id.UserID("@mautrix")
|
||||
mk, _ := olm.NewPkSigning()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2020 Tulir Asokan
|
||||
// Copyright (c) 2022 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
|
@ -17,6 +17,7 @@ import (
|
|||
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
)
|
||||
|
||||
var PostgresArrayWrapper func(interface{}) interface {
|
||||
|
|
@ -26,9 +27,7 @@ var PostgresArrayWrapper func(interface{}) interface {
|
|||
|
||||
// SQLCryptoStore is an implementation of a crypto Store for a database backend.
|
||||
type SQLCryptoStore struct {
|
||||
DB *sql.DB
|
||||
Log Logger
|
||||
Dialect string
|
||||
*dbutil.Database
|
||||
|
||||
AccountID string
|
||||
DeviceID id.DeviceID
|
||||
|
|
@ -44,11 +43,9 @@ var _ Store = (*SQLCryptoStore)(nil)
|
|||
|
||||
// NewSQLCryptoStore initializes a new crypto Store using the given database, for a device's crypto material.
|
||||
// The stored material will be encrypted with the given key.
|
||||
func NewSQLCryptoStore(db *sql.DB, dialect string, accountID string, deviceID id.DeviceID, pickleKey []byte, log Logger) *SQLCryptoStore {
|
||||
func NewSQLCryptoStore(db *dbutil.Database, accountID string, deviceID id.DeviceID, pickleKey []byte) *SQLCryptoStore {
|
||||
return &SQLCryptoStore{
|
||||
DB: db,
|
||||
Dialect: dialect,
|
||||
Log: log,
|
||||
Database: db.Child("CryptoStore", sql_store_upgrade.VersionTableName, sql_store_upgrade.Table),
|
||||
PickleKey: pickleKey,
|
||||
AccountID: accountID,
|
||||
DeviceID: deviceID,
|
||||
|
|
@ -58,8 +55,10 @@ func NewSQLCryptoStore(db *sql.DB, dialect string, accountID string, deviceID id
|
|||
}
|
||||
|
||||
// CreateTables applies all the pending database migrations.
|
||||
//
|
||||
// Deprecated: The Upgrade method (inherited from dbutil.Database) should be used instead
|
||||
func (store *SQLCryptoStore) CreateTables() error {
|
||||
return sql_store_upgrade.Upgrade(store.DB, store.Dialect)
|
||||
return store.Upgrade()
|
||||
}
|
||||
|
||||
// Flush does nothing for this implementation as data is already persisted in the database.
|
||||
|
|
@ -581,7 +580,7 @@ func (store *SQLCryptoStore) PutDevices(userID id.UserID, devices map[id.DeviceI
|
|||
func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) []id.UserID {
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if store.Dialect == "postgres" && PostgresArrayWrapper != nil {
|
||||
if store.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil {
|
||||
rows, err = store.DB.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", PostgresArrayWrapper(users))
|
||||
} else {
|
||||
queryString := make([]string, len(users))
|
||||
|
|
|
|||
86
crypto/sql_store_upgrade/00-latest-revision.sql
Normal file
86
crypto/sql_store_upgrade/00-latest-revision.sql
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
-- v0 -> v6: Latest revision
|
||||
CREATE TABLE IF NOT EXISTS crypto_account (
|
||||
account_id TEXT PRIMARY KEY,
|
||||
device_id TEXT PRIMARY KEY,
|
||||
shared BOOLEAN NOT NULL,
|
||||
sync_token TEXT NOT NULL,
|
||||
account bytea NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS crypto_message_index (
|
||||
sender_key CHAR(43),
|
||||
session_id CHAR(43),
|
||||
"index" INTEGER,
|
||||
event_id TEXT NOT NULL,
|
||||
timestamp BIGINT NOT NULL,
|
||||
PRIMARY KEY (sender_key, session_id, "index")
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS crypto_tracked_user (
|
||||
user_id TEXT PRIMARY KEY
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS crypto_device (
|
||||
user_id TEXT,
|
||||
device_id TEXT,
|
||||
identity_key CHAR(43) NOT NULL,
|
||||
signing_key CHAR(43) NOT NULL,
|
||||
trust SMALLINT NOT NULL,
|
||||
deleted BOOLEAN NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, device_id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS crypto_olm_session (
|
||||
account_id TEXT,
|
||||
session_id CHAR(43),
|
||||
sender_key CHAR(43) NOT NULL,
|
||||
session bytea NOT NULL,
|
||||
created_at timestamp NOT NULL,
|
||||
last_decrypted timestamp NOT NULL,
|
||||
last_encrypted timestamp NOT NULL,
|
||||
PRIMARY KEY (account_id, session_id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
|
||||
account_id TEXT,
|
||||
session_id CHAR(43),
|
||||
sender_key CHAR(43) NOT NULL,
|
||||
signing_key CHAR(43),
|
||||
room_id TEXT NOT NULL,
|
||||
session bytea,
|
||||
forwarding_chains bytea,
|
||||
withheld_code TEXT,
|
||||
withheld_reason TEXT,
|
||||
PRIMARY KEY (account_id, session)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session (
|
||||
account_id TEXT,
|
||||
room_id TEXT,
|
||||
session_id CHAR(43) NOT NULL UNIQUE,
|
||||
session bytea NOT NULL,
|
||||
shared BOOLEAN NOT NULL,
|
||||
max_messages INTEGER NOT NULL,
|
||||
message_count INTEGER NOT NULL,
|
||||
max_age BIGINT NOT NULL,
|
||||
created_at timestamp NOT NULL,
|
||||
last_used timestamp NOT NULL,
|
||||
PRIMARY KEY (account_id, room_id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS crypto_cross_signing_keys (
|
||||
user_id TEXT,
|
||||
usage TEXT,
|
||||
key CHAR(43) NOT NULL,
|
||||
PRIMARY KEY (user_id, usage)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS crypto_cross_signing_signatures (
|
||||
signed_user_id TEXT,
|
||||
signed_key TEXT,
|
||||
signer_user_id TEXT,
|
||||
signer_key TEXT,
|
||||
signature CHAR(88) NOT NULL,
|
||||
PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key)
|
||||
);
|
||||
16
crypto/sql_store_upgrade/04-cross-signing-keys.sql
Normal file
16
crypto/sql_store_upgrade/04-cross-signing-keys.sql
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
-- v4: Add tables for cross-signing keys
|
||||
CREATE TABLE IF NOT EXISTS crypto_cross_signing_keys (
|
||||
user_id VARCHAR(255) NOT NULL,
|
||||
usage VARCHAR(20) NOT NULL,
|
||||
key CHAR(43) NOT NULL,
|
||||
PRIMARY KEY (user_id, usage)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS crypto_cross_signing_signatures (
|
||||
signed_user_id VARCHAR(255) NOT NULL,
|
||||
signed_key VARCHAR(255) NOT NULL,
|
||||
signer_user_id VARCHAR(255) NOT NULL,
|
||||
signer_key VARCHAR(255) NOT NULL,
|
||||
signature CHAR(88) NOT NULL,
|
||||
PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key)
|
||||
)
|
||||
31
crypto/sql_store_upgrade/05-varchar-to-text.sql
Normal file
31
crypto/sql_store_upgrade/05-varchar-to-text.sql
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
-- v5: Switch from VARCHAR(255) to TEXT
|
||||
-- only: postgres
|
||||
|
||||
ALTER TABLE crypto_account ALTER COLUMN device_id TYPE TEXT;
|
||||
ALTER TABLE crypto_account ALTER COLUMN account_id TYPE TEXT;
|
||||
|
||||
ALTER TABLE crypto_device ALTER COLUMN user_id TYPE TEXT;
|
||||
ALTER TABLE crypto_device ALTER COLUMN device_id TYPE TEXT;
|
||||
ALTER TABLE crypto_device ALTER COLUMN name TYPE TEXT;
|
||||
|
||||
ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN room_id TYPE TEXT;
|
||||
ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN account_id TYPE TEXT;
|
||||
ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN withheld_code TYPE TEXT;
|
||||
|
||||
ALTER TABLE crypto_megolm_outbound_session ALTER COLUMN room_id TYPE TEXT;
|
||||
ALTER TABLE crypto_megolm_outbound_session ALTER COLUMN account_id TYPE TEXT;
|
||||
|
||||
ALTER TABLE crypto_message_index ALTER COLUMN event_id TYPE TEXT;
|
||||
|
||||
ALTER TABLE crypto_olm_session ALTER COLUMN account_id TYPE TEXT;
|
||||
|
||||
ALTER TABLE crypto_tracked_user ALTER COLUMN user_id TYPE TEXT;
|
||||
|
||||
ALTER TABLE crypto_cross_signing_keys ALTER COLUMN user_id TYPE TEXT;
|
||||
ALTER TABLE crypto_cross_signing_keys ALTER COLUMN usage TYPE TEXT;
|
||||
|
||||
ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signed_user_id TYPE TEXT;
|
||||
ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signed_key TYPE TEXT;
|
||||
ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signer_user_id TYPE TEXT;
|
||||
ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signer_key TYPE TEXT;
|
||||
ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signature TYPE TEXT;
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
-- v6: Split last_used into last_encrypted and last_decrypted for Olm sessions
|
||||
ALTER TABLE crypto_olm_session RENAME COLUMN last_used TO last_decrypted;
|
||||
ALTER TABLE crypto_olm_session ADD COLUMN last_encrypted timestamp;
|
||||
UPDATE crypto_olm_session SET last_encrypted=last_decrypted;
|
||||
-- only: postgres (too complicated on SQLite)
|
||||
ALTER TABLE crypto_olm_session ALTER COLUMN last_encrypted SET NOT NULL;
|
||||
|
|
@ -1,361 +1,40 @@
|
|||
// Copyright (c) 2022 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package sql_store_upgrade
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"embed"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
)
|
||||
|
||||
type upgradeFunc func(*sql.Tx, string) error
|
||||
var Table dbutil.UpgradeTable
|
||||
|
||||
var ErrUnknownDialect = errors.New("unknown dialect")
|
||||
const VersionTableName = "crypto_version"
|
||||
|
||||
var Upgrades = [...]upgradeFunc{
|
||||
func(tx *sql.Tx, _ string) error {
|
||||
for _, query := range []string{
|
||||
`CREATE TABLE IF NOT EXISTS crypto_account (
|
||||
device_id VARCHAR(255) PRIMARY KEY,
|
||||
shared BOOLEAN NOT NULL,
|
||||
sync_token TEXT NOT NULL,
|
||||
account bytea NOT NULL
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS crypto_message_index (
|
||||
sender_key CHAR(43),
|
||||
session_id CHAR(43),
|
||||
"index" INTEGER,
|
||||
event_id VARCHAR(255) NOT NULL,
|
||||
timestamp BIGINT NOT NULL,
|
||||
PRIMARY KEY (sender_key, session_id, "index")
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS crypto_tracked_user (
|
||||
user_id VARCHAR(255) PRIMARY KEY
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS crypto_device (
|
||||
user_id VARCHAR(255),
|
||||
device_id VARCHAR(255),
|
||||
identity_key CHAR(43) NOT NULL,
|
||||
signing_key CHAR(43) NOT NULL,
|
||||
trust SMALLINT NOT NULL,
|
||||
deleted BOOLEAN NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
PRIMARY KEY (user_id, device_id)
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS crypto_olm_session (
|
||||
session_id CHAR(43) PRIMARY KEY,
|
||||
sender_key CHAR(43) NOT NULL,
|
||||
session bytea NOT NULL,
|
||||
created_at timestamp NOT NULL,
|
||||
last_used timestamp NOT NULL
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
|
||||
session_id CHAR(43) PRIMARY KEY,
|
||||
sender_key CHAR(43) NOT NULL,
|
||||
signing_key CHAR(43) NOT NULL,
|
||||
room_id VARCHAR(255) NOT NULL,
|
||||
session bytea NOT NULL,
|
||||
forwarding_chains bytea NOT NULL
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session (
|
||||
room_id VARCHAR(255) PRIMARY KEY,
|
||||
session_id CHAR(43) NOT NULL UNIQUE,
|
||||
session bytea NOT NULL,
|
||||
shared BOOLEAN NOT NULL,
|
||||
max_messages INTEGER NOT NULL,
|
||||
message_count INTEGER NOT NULL,
|
||||
max_age BIGINT NOT NULL,
|
||||
created_at timestamp NOT NULL,
|
||||
last_used timestamp NOT NULL
|
||||
)`,
|
||||
} {
|
||||
if _, err := tx.Exec(query); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
func(tx *sql.Tx, dialect string) error {
|
||||
if dialect == "postgres" {
|
||||
tablesToPkeys := map[string][]string{
|
||||
"crypto_account": {},
|
||||
"crypto_olm_session": {"session_id"},
|
||||
"crypto_megolm_inbound_session": {"session_id"},
|
||||
"crypto_megolm_outbound_session": {"room_id"},
|
||||
}
|
||||
for tableName, pkeys := range tablesToPkeys {
|
||||
// add account_id to primary key
|
||||
pkeyStr := strings.Join(append(pkeys, "account_id"), ", ")
|
||||
for _, query := range []string{
|
||||
fmt.Sprintf("ALTER TABLE %s ADD COLUMN account_id VARCHAR(255)", tableName),
|
||||
fmt.Sprintf("UPDATE %s SET account_id=''", tableName),
|
||||
fmt.Sprintf("ALTER TABLE %s ALTER COLUMN account_id SET NOT NULL", tableName),
|
||||
fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s_pkey", tableName, tableName),
|
||||
fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s_pkey PRIMARY KEY (%s)", tableName, tableName, pkeyStr),
|
||||
} {
|
||||
if _, err := tx.Exec(query); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if dialect == "sqlite3" {
|
||||
tableCols := map[string]string{
|
||||
"crypto_account": `
|
||||
account_id VARCHAR(255) NOT NULL,
|
||||
device_id VARCHAR(255) NOT NULL,
|
||||
shared BOOLEAN NOT NULL,
|
||||
sync_token TEXT NOT NULL,
|
||||
account BLOB NOT NULL,
|
||||
PRIMARY KEY (account_id)
|
||||
`,
|
||||
"crypto_olm_session": `
|
||||
account_id VARCHAR(255) NOT NULL,
|
||||
session_id CHAR(43) NOT NULL,
|
||||
sender_key CHAR(43) NOT NULL,
|
||||
session BLOB NOT NULL,
|
||||
created_at timestamp NOT NULL,
|
||||
last_used timestamp NOT NULL,
|
||||
PRIMARY KEY (account_id, session_id)
|
||||
`,
|
||||
"crypto_megolm_inbound_session": `
|
||||
account_id VARCHAR(255) NOT NULL,
|
||||
session_id CHAR(43) NOT NULL,
|
||||
sender_key CHAR(43) NOT NULL,
|
||||
signing_key CHAR(43) NOT NULL,
|
||||
room_id VARCHAR(255) NOT NULL,
|
||||
session BLOB NOT NULL,
|
||||
forwarding_chains BLOB NOT NULL,
|
||||
PRIMARY KEY (account_id, session_id)
|
||||
`,
|
||||
"crypto_megolm_outbound_session": `
|
||||
account_id VARCHAR(255) NOT NULL,
|
||||
room_id VARCHAR(255) NOT NULL,
|
||||
session_id CHAR(43) NOT NULL UNIQUE,
|
||||
session BLOB NOT NULL,
|
||||
shared BOOLEAN NOT NULL,
|
||||
max_messages INTEGER NOT NULL,
|
||||
message_count INTEGER NOT NULL,
|
||||
max_age BIGINT NOT NULL,
|
||||
created_at timestamp NOT NULL,
|
||||
last_used timestamp NOT NULL,
|
||||
PRIMARY KEY (account_id, room_id)
|
||||
`,
|
||||
}
|
||||
for tableName, cols := range tableCols {
|
||||
// re-create tables with account_id column and new pkey and re-insert rows
|
||||
for _, query := range []string{
|
||||
fmt.Sprintf("ALTER TABLE %s RENAME TO old_%s", tableName, tableName),
|
||||
fmt.Sprintf("CREATE TABLE %s (%s)", tableName, cols),
|
||||
fmt.Sprintf("INSERT INTO %s SELECT '', * FROM old_%s", tableName, tableName),
|
||||
fmt.Sprintf("DROP TABLE old_%s", tableName),
|
||||
} {
|
||||
if _, err := tx.Exec(query); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("%w (%s)", ErrUnknownDialect, dialect)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
func(tx *sql.Tx, dialect string) error {
|
||||
if dialect == "postgres" {
|
||||
alters := [...]string{
|
||||
"ADD COLUMN withheld_code VARCHAR(255)",
|
||||
"ADD COLUMN withheld_reason TEXT",
|
||||
"ALTER COLUMN signing_key DROP NOT NULL",
|
||||
"ALTER COLUMN session DROP NOT NULL",
|
||||
"ALTER COLUMN forwarding_chains DROP NOT NULL",
|
||||
}
|
||||
for _, alter := range alters {
|
||||
_, err := tx.Exec(fmt.Sprintf("ALTER TABLE crypto_megolm_inbound_session %s", alter))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else if dialect == "sqlite3" {
|
||||
_, err := tx.Exec("ALTER TABLE crypto_megolm_inbound_session RENAME TO old_crypto_megolm_inbound_session")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`CREATE TABLE crypto_megolm_inbound_session (
|
||||
account_id VARCHAR(255) NOT NULL,
|
||||
session_id CHAR(43) NOT NULL,
|
||||
sender_key CHAR(43) NOT NULL,
|
||||
signing_key CHAR(43),
|
||||
room_id VARCHAR(255) NOT NULL,
|
||||
session BLOB,
|
||||
forwarding_chains BLOB,
|
||||
withheld_code VARCHAR(255),
|
||||
withheld_reason TEXT,
|
||||
PRIMARY KEY (account_id, session_id)
|
||||
)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`INSERT INTO crypto_megolm_inbound_session
|
||||
(session_id, sender_key, signing_key, room_id, session, forwarding_chains, account_id)
|
||||
SELECT * FROM old_crypto_megolm_inbound_session`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec("DROP TABLE old_crypto_megolm_inbound_session")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("%w (%s)", ErrUnknownDialect, dialect)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
func(tx *sql.Tx, dialect string) error {
|
||||
if _, err := tx.Exec(
|
||||
`CREATE TABLE IF NOT EXISTS crypto_cross_signing_keys (
|
||||
user_id VARCHAR(255) NOT NULL,
|
||||
usage VARCHAR(20) NOT NULL,
|
||||
key CHAR(43) NOT NULL,
|
||||
PRIMARY KEY (user_id, usage)
|
||||
)`,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(
|
||||
`CREATE TABLE IF NOT EXISTS crypto_cross_signing_signatures (
|
||||
signed_user_id VARCHAR(255) NOT NULL,
|
||||
signed_key VARCHAR(255) NOT NULL,
|
||||
signer_user_id VARCHAR(255) NOT NULL,
|
||||
signer_key VARCHAR(255) NOT NULL,
|
||||
signature CHAR(88) NOT NULL,
|
||||
PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key)
|
||||
)`,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
func(tx *sql.Tx, dialect string) error {
|
||||
if dialect == "sqlite3" {
|
||||
// SQLite doesn't enforce varchar sizes anyway
|
||||
return nil
|
||||
}
|
||||
alters := [...]string{
|
||||
`ALTER TABLE crypto_account ALTER COLUMN device_id TYPE TEXT`,
|
||||
`ALTER TABLE crypto_account ALTER COLUMN account_id TYPE TEXT`,
|
||||
//go:embed *.sql
|
||||
var fs embed.FS
|
||||
|
||||
`ALTER TABLE crypto_device ALTER COLUMN user_id TYPE TEXT`,
|
||||
`ALTER TABLE crypto_device ALTER COLUMN device_id TYPE TEXT`,
|
||||
`ALTER TABLE crypto_device ALTER COLUMN name TYPE TEXT`,
|
||||
|
||||
`ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN room_id TYPE TEXT`,
|
||||
`ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN account_id TYPE TEXT`,
|
||||
`ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN withheld_code TYPE TEXT`,
|
||||
|
||||
`ALTER TABLE crypto_megolm_outbound_session ALTER COLUMN room_id TYPE TEXT`,
|
||||
`ALTER TABLE crypto_megolm_outbound_session ALTER COLUMN account_id TYPE TEXT`,
|
||||
|
||||
`ALTER TABLE crypto_message_index ALTER COLUMN event_id TYPE TEXT`,
|
||||
|
||||
`ALTER TABLE crypto_olm_session ALTER COLUMN account_id TYPE TEXT`,
|
||||
|
||||
`ALTER TABLE crypto_tracked_user ALTER COLUMN user_id TYPE TEXT`,
|
||||
|
||||
`ALTER TABLE crypto_cross_signing_keys ALTER COLUMN user_id TYPE TEXT`,
|
||||
`ALTER TABLE crypto_cross_signing_keys ALTER COLUMN usage TYPE TEXT`,
|
||||
|
||||
`ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signed_user_id TYPE TEXT`,
|
||||
`ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signed_key TYPE TEXT`,
|
||||
`ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signer_user_id TYPE TEXT`,
|
||||
`ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signer_key TYPE TEXT`,
|
||||
`ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signature TYPE TEXT`,
|
||||
}
|
||||
for _, alter := range alters {
|
||||
_, err := tx.Exec(alter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
func(tx *sql.Tx, dialect string) error {
|
||||
_, err := tx.Exec("ALTER TABLE crypto_olm_session RENAME COLUMN last_used TO last_decrypted")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec("ALTER TABLE crypto_olm_session ADD COLUMN last_encrypted timestamp")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec("UPDATE crypto_olm_session SET last_encrypted=last_decrypted")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if dialect == "postgres" {
|
||||
// This is too hard to do on sqlite, so let's just do it on postgres
|
||||
_, err = tx.Exec("ALTER TABLE crypto_olm_session ALTER COLUMN last_encrypted SET NOT NULL")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// GetVersion returns the current version of the DB schema.
|
||||
func GetVersion(db *sql.DB) (int, error) {
|
||||
_, err := db.Exec("CREATE TABLE IF NOT EXISTS crypto_version (version INTEGER)")
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
version := 0
|
||||
row := db.QueryRow("SELECT version FROM crypto_version LIMIT 1")
|
||||
if row != nil {
|
||||
_ = row.Scan(&version)
|
||||
}
|
||||
return version, nil
|
||||
}
|
||||
|
||||
// SetVersion sets the schema version in a running DB transaction.
|
||||
func SetVersion(tx *sql.Tx, version int) error {
|
||||
_, err := tx.Exec("DELETE FROM crypto_version")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec("INSERT INTO crypto_version (version) VALUES ($1)", version)
|
||||
return err
|
||||
func init() {
|
||||
Table.Register(-1, 3, "Unsupported version", func(tx *sql.Tx, database *dbutil.Database) error {
|
||||
return fmt.Errorf("upgrading from versions 1 and 2 of the crypto store is no longer supported in mautrix-go v0.12+")
|
||||
})
|
||||
Table.RegisterFS(fs)
|
||||
}
|
||||
|
||||
// Upgrade upgrades the database from the current to the latest version available.
|
||||
func Upgrade(db *sql.DB, dialect string) error {
|
||||
version, err := GetVersion(db)
|
||||
func Upgrade(sqlDB *sql.DB, dialect string) error {
|
||||
db, err := dbutil.NewWithDB(sqlDB, dialect)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// perform migrations starting with #version
|
||||
for ; version < len(Upgrades); version++ {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// run each migrate func
|
||||
migrateFunc := Upgrades[version]
|
||||
err = migrateFunc(tx, dialect)
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
// also update the version in this tx
|
||||
if err = SetVersion(tx, version+1); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
db.VersionTable = VersionTableName
|
||||
db.UpgradeTable = Table
|
||||
return db.Upgrade()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import (
|
|||
|
||||
"maunium.net/go/mautrix/crypto/olm"
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
)
|
||||
|
||||
const olmSessID = "sJlikQQKXp7UQjmS9/lyZCNUVJ2AmKyHbufPBaC7tpk"
|
||||
|
|
@ -26,12 +27,16 @@ const olmPickled = "L6cdv3JYO9OzhXbcjNSwl7ldN5bDvwmGyin+hISePETE6bO71DIlhqTC9YIh
|
|||
const groupSession = "9ZbsRqJuETbjnxPpKv29n3dubP/m5PSLbr9I9CIWS2O86F/Og1JZXhqT+4fA5tovoPfdpk5QLh7PfDyjmgOcO9sSA37maJyzCy6Ap+uBZLAXp6VLJ0mjSvxi+PAbzGKDMqpn+pa+oeEIH6SFPG/2GGDSRoXVi5fttAClCIoav5RflWiMypKqnQRfkZR2Gx8glOaBiTzAd7m0X6XGfYIPol41JUIHfBLuJBfXQ0Uu5GScV4eKUWdJP2J6zzC2Hx8cZAhiBBzAza0CbGcnUK+YJXMYaJg92HiIo++l317LlsYUJ/P+gKOLafYR9/l8bAzxH7j5s31PnRs7mD1Bl6G1LFM+dPsGXUOLx6PlvlTlYYM/opai0uKKzT0Wk6zPoq9fN/smlXEPBtKlw2fqcytL4gOF0MrBPEca"
|
||||
|
||||
func getCryptoStores(t *testing.T) (map[string]Store, func()) {
|
||||
db, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000")
|
||||
rawDB, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000")
|
||||
if err != nil {
|
||||
t.Fatalf("Error opening db: %v", err)
|
||||
}
|
||||
sqlStore := NewSQLCryptoStore(db, "sqlite3", "accid", id.DeviceID("dev"), []byte("test"), emptyLogger{})
|
||||
if err = sqlStore.CreateTables(); err != nil {
|
||||
db, err := dbutil.NewWithDB(rawDB, "sqlite3")
|
||||
if err != nil {
|
||||
t.Fatalf("Error opening db: %v", err)
|
||||
}
|
||||
sqlStore := NewSQLCryptoStore(db, "accid", id.DeviceID("dev"), []byte("test"))
|
||||
if err = sqlStore.Upgrade(); err != nil {
|
||||
t.Fatalf("Error creating tables: %v", err)
|
||||
}
|
||||
|
||||
|
|
|
|||
4
go.mod
4
go.mod
|
|
@ -12,7 +12,8 @@ require (
|
|||
github.com/yuin/goldmark v1.4.12
|
||||
golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9
|
||||
golang.org/x/net v0.0.0-20220513224357-95641704303c
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
gopkg.in/yaml.v3 v3.0.0
|
||||
maunium.net/go/mauflag v1.0.0
|
||||
maunium.net/go/maulogger/v2 v2.3.2
|
||||
)
|
||||
|
||||
|
|
@ -21,5 +22,4 @@ require (
|
|||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
|
||||
)
|
||||
|
|
|
|||
7
go.sum
7
go.sum
|
|
@ -38,9 +38,10 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
|||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.0 h1:hjy8E9ON/egN1tAYqKb61G10WtihqetD4sz2H+8nIeA=
|
||||
gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
|
||||
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
|
||||
maunium.net/go/maulogger/v2 v2.3.2 h1:1XmIYmMd3PoQfp9J+PaHhpt80zpfmMqaShzUTC7FwY0=
|
||||
maunium.net/go/maulogger/v2 v2.3.2/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A=
|
||||
|
|
|
|||
|
|
@ -28,6 +28,20 @@ type BaseUpgrader interface {
|
|||
GetBase() string
|
||||
}
|
||||
|
||||
type StructUpgrader struct {
|
||||
SimpleUpgrader
|
||||
Blocks [][]string
|
||||
Base string
|
||||
}
|
||||
|
||||
func (su *StructUpgrader) SpacedBlocks() [][]string {
|
||||
return su.Blocks
|
||||
}
|
||||
|
||||
func (su *StructUpgrader) GetBase() string {
|
||||
return su.Base
|
||||
}
|
||||
|
||||
type SimpleUpgrader func(helper *Helper)
|
||||
|
||||
func (su SimpleUpgrader) DoUpgrade(helper *Helper) {
|
||||
|
|
|
|||
132
util/dbutil/database.go
Normal file
132
util/dbutil/database.go
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
// Copyright (c) 2022 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package dbutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/bridge/bridgeconfig"
|
||||
)
|
||||
|
||||
type Dialect int
|
||||
|
||||
const (
|
||||
DialectUnknown Dialect = iota
|
||||
Postgres
|
||||
SQLite
|
||||
)
|
||||
|
||||
func (dialect Dialect) String() string {
|
||||
switch dialect {
|
||||
case Postgres:
|
||||
return "postgres"
|
||||
case SQLite:
|
||||
return "sqlite3"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func ParseDialect(engine string) (Dialect, error) {
|
||||
switch strings.ToLower(engine) {
|
||||
case "postgres", "postgresql":
|
||||
return Postgres, nil
|
||||
case "sqlite3", "sqlite":
|
||||
return SQLite, nil
|
||||
default:
|
||||
return DialectUnknown, fmt.Errorf("unknown dialect '%s'", engine)
|
||||
}
|
||||
}
|
||||
|
||||
type Scannable interface {
|
||||
Scan(...interface{}) error
|
||||
}
|
||||
|
||||
type Database struct {
|
||||
*sql.DB
|
||||
Owner string
|
||||
VersionTable string
|
||||
Log log.Logger
|
||||
Dialect Dialect
|
||||
UpgradeTable UpgradeTable
|
||||
|
||||
IgnoreForeignTables bool
|
||||
IgnoreUnsupportedDatabase bool
|
||||
}
|
||||
|
||||
func (db *Database) Child(logName, versionTable string, upgradeTable UpgradeTable) *Database {
|
||||
return &Database{
|
||||
DB: db.DB,
|
||||
Owner: "",
|
||||
VersionTable: versionTable,
|
||||
UpgradeTable: upgradeTable,
|
||||
Log: db.Log.Sub(logName),
|
||||
Dialect: db.Dialect,
|
||||
|
||||
IgnoreForeignTables: true,
|
||||
IgnoreUnsupportedDatabase: db.IgnoreUnsupportedDatabase,
|
||||
}
|
||||
}
|
||||
|
||||
func NewWithDB(db *sql.DB, rawDialect string) (*Database, error) {
|
||||
dialect, err := ParseDialect(rawDialect)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Database{
|
||||
DB: db,
|
||||
Dialect: dialect,
|
||||
Log: log.Sub("Database"),
|
||||
|
||||
IgnoreForeignTables: true,
|
||||
VersionTable: "version",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewFromConfig(owner string, cfg bridgeconfig.DatabaseConfig, dbLog log.Logger) (*Database, error) {
|
||||
dialect, err := ParseDialect(cfg.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := sql.Open(cfg.Type, cfg.URI)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
conn.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||
if len(cfg.ConnMaxIdleTime) > 0 {
|
||||
maxIdleTimeDuration, err := time.ParseDuration(cfg.ConnMaxIdleTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse max_conn_idle_time: %w", err)
|
||||
}
|
||||
conn.SetConnMaxIdleTime(maxIdleTimeDuration)
|
||||
}
|
||||
if len(cfg.ConnMaxLifetime) > 0 {
|
||||
maxLifetimeDuration, err := time.ParseDuration(cfg.ConnMaxLifetime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse max_conn_idle_time: %w", err)
|
||||
}
|
||||
conn.SetConnMaxLifetime(maxLifetimeDuration)
|
||||
}
|
||||
if dbLog == nil {
|
||||
dbLog = log.Sub("Database")
|
||||
}
|
||||
return &Database{
|
||||
DB: conn,
|
||||
Owner: owner,
|
||||
Log: dbLog,
|
||||
Dialect: dialect,
|
||||
|
||||
IgnoreForeignTables: true,
|
||||
VersionTable: "version",
|
||||
}, nil
|
||||
}
|
||||
154
util/dbutil/upgrades.go
Normal file
154
util/dbutil/upgrades.go
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
// Copyright (c) 2022 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package dbutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
)
|
||||
|
||||
type upgradeFunc func(*sql.Tx, *Database) error
|
||||
|
||||
type upgrade struct {
|
||||
message string
|
||||
fn upgradeFunc
|
||||
|
||||
upgradesTo int
|
||||
}
|
||||
|
||||
type Upgrader struct {
|
||||
*sql.DB
|
||||
Log log.Logger
|
||||
Dialect Dialect
|
||||
|
||||
upgrades []upgrade
|
||||
}
|
||||
|
||||
var ErrUnsupportedDatabaseVersion = fmt.Errorf("unsupported database schema version")
|
||||
var ErrForeignTables = fmt.Errorf("the database contains foreign tables")
|
||||
var ErrNotOwned = fmt.Errorf("the database is owned by")
|
||||
|
||||
func (db *Database) getVersion() (int, error) {
|
||||
_, err := db.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version INTEGER)", db.VersionTable))
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
version := 0
|
||||
err = db.QueryRow(fmt.Sprintf("SELECT version FROM %s LIMIT 1", db.VersionTable)).Scan(&version)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return -1, err
|
||||
}
|
||||
return version, nil
|
||||
}
|
||||
|
||||
const tableExistsPostgres = "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name=$1)"
|
||||
const tableExistsSQLite = "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND table_name=$1)"
|
||||
|
||||
func (db *Database) tableExists(table string) (exists bool) {
|
||||
if db.Dialect == SQLite {
|
||||
_ = db.QueryRow(tableExistsSQLite, table).Scan(&exists)
|
||||
} else if db.Dialect == Postgres {
|
||||
_ = db.QueryRow(tableExistsPostgres, table).Scan(&exists)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
const createOwnerTable = `
|
||||
CREATE TABLE IF NOT EXISTS database_owner (
|
||||
key INTEGER PRIMARY KEY DEFAULT 0,
|
||||
owner TEXT NOT NULL
|
||||
)
|
||||
`
|
||||
|
||||
func (db *Database) checkDatabaseOwner() error {
|
||||
var owner string
|
||||
if !db.IgnoreForeignTables {
|
||||
if db.tableExists("state_groups_state") {
|
||||
return fmt.Errorf("%w (found state_groups_state, likely belonging to Synapse)", ErrForeignTables)
|
||||
} else if db.tableExists("goose_db_version") {
|
||||
return fmt.Errorf("%w (found goose_db_version, possibly belonging to Dendrite)", ErrForeignTables)
|
||||
}
|
||||
}
|
||||
if db.Owner == "" {
|
||||
return nil
|
||||
}
|
||||
if _, err := db.Exec(createOwnerTable); err != nil {
|
||||
return fmt.Errorf("failed to ensure database owner table exists: %w", err)
|
||||
} else if err = db.QueryRow("SELECT owner FROM database_owner WHERE key=0").Scan(&owner); errors.Is(err, sql.ErrNoRows) {
|
||||
_, err = db.Exec("INSERT INTO database_owner (owner) VALUES ($1)", db.Owner)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert database owner: %w", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to check database owner: %w", err)
|
||||
} else if owner != db.Owner {
|
||||
return fmt.Errorf("%w %s", ErrNotOwned, owner)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *Database) setVersion(tx *sql.Tx, version int) error {
|
||||
_, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", db.VersionTable))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(fmt.Sprintf("INSERT INTO %s (version) VALUES ($1)", db.VersionTable), version)
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *Database) Upgrade() error {
|
||||
err := db.checkDatabaseOwner()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
version, err := db.getVersion()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if version > len(db.UpgradeTable) {
|
||||
warning := fmt.Sprintf("currently on v%d, latest known: v%d", version, len(db.UpgradeTable))
|
||||
if db.IgnoreUnsupportedDatabase {
|
||||
db.Log.Warnfln("Unsupported database schema version: %s - continuing anyway", warning)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("%w: %s", ErrUnsupportedDatabaseVersion, warning)
|
||||
}
|
||||
|
||||
db.Log.Infofln("Database currently on v%d, latest: v%d", version, len(db.UpgradeTable))
|
||||
upgradesToApply := db.UpgradeTable[version:]
|
||||
for _, upgradeItem := range upgradesToApply {
|
||||
if upgradeItem.fn == nil {
|
||||
continue
|
||||
}
|
||||
db.Log.Infofln("Upgrading database from v%d to v%d: %s", version, upgradeItem.upgradesTo, upgradeItem.message)
|
||||
var tx *sql.Tx
|
||||
tx, err = db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = upgradeItem.fn(tx, db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
version = upgradeItem.upgradesTo
|
||||
err = db.setVersion(tx, version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
159
util/dbutil/upgradetable.go
Normal file
159
util/dbutil/upgradetable.go
Normal file
|
|
@ -0,0 +1,159 @@
|
|||
// Copyright (c) 2022 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package dbutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"regexp"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type UpgradeTable []upgrade
|
||||
|
||||
func (ut *UpgradeTable) extend(toSize int) {
|
||||
if cap(*ut) >= toSize {
|
||||
*ut = (*ut)[:toSize]
|
||||
} else {
|
||||
resized := make([]upgrade, toSize)
|
||||
copy(resized, *ut)
|
||||
*ut = resized
|
||||
}
|
||||
}
|
||||
|
||||
func (ut *UpgradeTable) Register(from, to int, message string, fn upgradeFunc) {
|
||||
if from < 0 {
|
||||
from += to
|
||||
}
|
||||
if from < 0 {
|
||||
panic("invalid from value in UpgradeTable.Register() call")
|
||||
}
|
||||
upg := upgrade{message: message, fn: fn, upgradesTo: to}
|
||||
if len(*ut) == from {
|
||||
*ut = append(*ut, upg)
|
||||
return
|
||||
} else if len(*ut) < from {
|
||||
ut.extend(from + 1)
|
||||
} else if (*ut)[from].fn != nil {
|
||||
panic(fmt.Errorf("tried to override upgrade at %d ('%s') with '%s'", from, (*ut)[from].message, upg.message))
|
||||
}
|
||||
(*ut)[from] = upg
|
||||
}
|
||||
|
||||
// Syntax is either
|
||||
// -- v0 -> v1: Message
|
||||
// or
|
||||
// -- v1: Message
|
||||
var upgradeHeaderRegex = regexp.MustCompile(`^-- (?:v(\d+) -> )?v(\d+): (.+)$`)
|
||||
|
||||
func parseFileHeader(file []byte) (from, to int, message string, lines [][]byte, err error) {
|
||||
lines = bytes.Split(file, []byte("\n"))
|
||||
if len(lines) < 2 {
|
||||
err = errors.New("upgrade file too short")
|
||||
return
|
||||
}
|
||||
var maybeFrom int
|
||||
match := upgradeHeaderRegex.FindSubmatch(lines[0])
|
||||
lines = lines[1:]
|
||||
if match == nil {
|
||||
err = errors.New("header not found")
|
||||
} else if len(match) != 4 {
|
||||
err = errors.New("unexpected number of items in regex match")
|
||||
} else if maybeFrom, err = strconv.Atoi(string(match[1])); len(match[1]) > 0 && err != nil {
|
||||
err = fmt.Errorf("invalid source version: %w", err)
|
||||
} else if to, err = strconv.Atoi(string(match[2])); err != nil {
|
||||
err = fmt.Errorf("invalid target version: %w", err)
|
||||
} else {
|
||||
if len(match[1]) > 0 {
|
||||
from = maybeFrom
|
||||
} else {
|
||||
from = -1
|
||||
}
|
||||
message = string(match[3])
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// To limit the next line to one dialect:
|
||||
// -- only: postgres
|
||||
// To limit the next N lines:
|
||||
// -- only: sqlite for next 123 lines
|
||||
// If the single-line limit is on the second line of the file, the whole file is limited to that dialect.
|
||||
var dialectLineFilter = regexp.MustCompile(`^\s*-- only: (postgres|sqlite)(?: for next (\d+) lines)?`)
|
||||
|
||||
func (db *Database) parseDialectFilter(line []byte) (int, error) {
|
||||
match := dialectLineFilter.FindSubmatch(line)
|
||||
if match != nil {
|
||||
dialect, err := ParseDialect(string(match[1]))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
} else if dialect != db.Dialect {
|
||||
if len(match[2]) == 0 {
|
||||
return 1, nil
|
||||
} else {
|
||||
lineCount, err := strconv.Atoi(string(match[2]))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid line count '%s': %w", match[2], err)
|
||||
}
|
||||
return lineCount, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (db *Database) mutateSQLUpgrade(lines [][]byte) (string, error) {
|
||||
output := lines[:0]
|
||||
for i := 0; i < len(lines); i++ {
|
||||
skipLines, err := db.parseDialectFilter(lines[i])
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else if skipLines > 0 {
|
||||
i += skipLines
|
||||
} else {
|
||||
output = append(output, lines[i])
|
||||
}
|
||||
}
|
||||
return string(bytes.Join(output, []byte("\n"))), nil
|
||||
}
|
||||
|
||||
func sqlUpgradeFunc(fileName string, lines [][]byte) upgradeFunc {
|
||||
return func(tx *sql.Tx, db *Database) error {
|
||||
if skip, err := db.parseDialectFilter(lines[0]); err == nil && skip == 1 {
|
||||
return nil
|
||||
} else if upgradeSQL, err := db.mutateSQLUpgrade(lines); err != nil {
|
||||
panic(fmt.Errorf("failed to parse upgrade %s: %w", fileName, err))
|
||||
} else {
|
||||
_, err = tx.Exec(upgradeSQL)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type fullFS interface {
|
||||
fs.ReadFileFS
|
||||
fs.ReadDirFS
|
||||
}
|
||||
|
||||
func (ut *UpgradeTable) RegisterFS(fs fullFS) {
|
||||
files, err := fs.ReadDir(".")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for _, file := range files {
|
||||
if data, err := fs.ReadFile(file.Name()); err != nil {
|
||||
panic(err)
|
||||
} else if from, to, message, lines, err := parseFileHeader(data); err != nil {
|
||||
panic(fmt.Errorf("failed to parse header in %s: %w", file.Name(), err))
|
||||
} else {
|
||||
ut.Register(from, to, message, sqlUpgradeFunc(file.Name(), lines))
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue