Move a bunch of stuff from mautrix-whatsapp

Moved parts:
* Appservice SQL state store
* Bridge crypto helper
* Database upgrade framework
* Bridge startup flow

Other changes:
* Improved database upgrade framework
  * Now primarily using static SQL files compiled with go:embed
* Moved appservice SQL state store to using membership enum on Postgres
This commit is contained in:
Tulir Asokan 2022-05-22 00:50:33 +03:00
commit d578d1a610
24 changed files with 1925 additions and 393 deletions

View file

@ -23,7 +23,7 @@ import (
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
"golang.org/x/net/publicsuffix"
"gopkg.in/yaml.v2"
"gopkg.in/yaml.v3"
"maunium.net/go/maulogger/v2"

View file

@ -1,4 +1,4 @@
// Copyright (c) 2019 Tulir Asokan
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -10,11 +10,11 @@ import (
"io/ioutil"
"regexp"
"gopkg.in/yaml.v2"
"gopkg.in/yaml.v3"
)
// Registration contains the data in a Matrix appservice registration.
// See https://matrix.org/docs/spec/application_service/unstable.html#registration
// See https://spec.matrix.org/v1.2/application-service-api/#registration
type Registration struct {
ID string `yaml:"id"`
URL string `yaml:"url"`
@ -23,8 +23,10 @@ type Registration struct {
SenderLocalpart string `yaml:"sender_localpart"`
RateLimited *bool `yaml:"rate_limited,omitempty"`
Namespaces Namespaces `yaml:"namespaces"`
EphemeralEvents bool `yaml:"de.sorunome.msc2409.push_ephemeral,omitempty"`
Protocols []string `yaml:"protocols,omitempty"`
SoruEphemeralEvents bool `yaml:"de.sorunome.msc2409.push_ephemeral,omitempty"`
EphemeralEvents bool `yaml:"push_ephemeral,omitempty"`
}
// CreateRegistration creates a Registration with random appservice and homeserver tokens.
@ -70,9 +72,9 @@ func (reg *Registration) YAML() (string, error) {
// Namespaces contains the three areas that appservices can reserve parts of.
type Namespaces struct {
UserIDs []Namespace `yaml:"users,omitempty"`
RoomAliases []Namespace `yaml:"aliases,omitempty"`
RoomIDs []Namespace `yaml:"rooms,omitempty"`
UserIDs NamespaceList `yaml:"users,omitempty"`
RoomAliases NamespaceList `yaml:"aliases,omitempty"`
RoomIDs NamespaceList `yaml:"rooms,omitempty"`
}
// Namespace is a reserved namespace in any area.
@ -81,26 +83,16 @@ type Namespace struct {
Exclusive bool `yaml:"exclusive"`
}
// RegisterUserIDs creates an user ID namespace registration.
func (nslist *Namespaces) RegisterUserIDs(regex *regexp.Regexp, exclusive bool) {
nslist.UserIDs = append(nslist.UserIDs, Namespace{
Regex: regex.String(),
Exclusive: exclusive,
})
}
type NamespaceList []Namespace
// RegisterRoomAliases creates an room alias namespace registration.
func (nslist *Namespaces) RegisterRoomAliases(regex *regexp.Regexp, exclusive bool) {
nslist.RoomAliases = append(nslist.RoomAliases, Namespace{
func (nsl *NamespaceList) Register(regex *regexp.Regexp, exclusive bool) {
ns := Namespace{
Regex: regex.String(),
Exclusive: exclusive,
})
}
// RegisterRoomIDs creates an room ID namespace registration.
func (nslist *Namespaces) RegisterRoomIDs(regex *regexp.Regexp, exclusive bool) {
nslist.RoomIDs = append(nslist.RoomIDs, Namespace{
Regex: regex.String(),
Exclusive: exclusive,
})
}
if nsl == nil {
*nsl = []Namespace{ns}
} else {
*nsl = append(*nsl, ns)
}
}

View file

@ -0,0 +1,284 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package sqlstatestore
import (
"database/sql"
"embed"
"encoding/json"
"errors"
"sync"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
)
//go:embed *.sql
var rawUpgrades embed.FS
var UpgradeTable dbutil.UpgradeTable
func init() {
UpgradeTable.RegisterFS(rawUpgrades)
}
const VersionTableName = "mx_version"
type SQLStateStore struct {
*dbutil.Database
*appservice.TypingStateStore
log log.Logger
Typing map[id.RoomID]map[id.UserID]int64
typingLock sync.RWMutex
}
var _ appservice.StateStore = (*SQLStateStore)(nil)
func NewSQLStateStore(db *dbutil.Database) *SQLStateStore {
return &SQLStateStore{
Database: db.Child("StateStore", VersionTableName, UpgradeTable),
TypingStateStore: appservice.NewTypingStateStore(),
}
}
func (store *SQLStateStore) IsRegistered(userID id.UserID) bool {
var isRegistered bool
err := store.
QueryRow("SELECT EXISTS(SELECT 1 FROM mx_registrations WHERE user_id=$1)", userID).
Scan(&isRegistered)
if err != nil {
store.log.Warnfln("Failed to scan registration existence for %s: %v", userID, err)
}
return isRegistered
}
func (store *SQLStateStore) MarkRegistered(userID id.UserID) {
_, err := store.Exec("INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
if err != nil {
store.log.Warnfln("Failed to mark %s as registered: %v", userID, err)
}
}
func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*event.MemberEventContent {
members := make(map[id.UserID]*event.MemberEventContent)
rows, err := store.Query("SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1", roomID)
if err != nil {
return members
}
var userID id.UserID
var member event.MemberEventContent
for rows.Next() {
err = rows.Scan(&userID, &member.Membership, &member.Displayname, &member.AvatarURL)
if err != nil {
store.log.Warnfln("Failed to scan member in %s: %v", roomID, err)
} else {
members[userID] = &member
}
}
return members
}
func (store *SQLStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership {
membership := event.MembershipLeave
err := store.
QueryRow("SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
Scan(&membership)
if err != nil && err != sql.ErrNoRows {
store.log.Warnfln("Failed to scan membership of %s in %s: %v", userID, roomID, err)
}
return membership
}
func (store *SQLStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
member, ok := store.TryGetMember(roomID, userID)
if !ok {
member.Membership = event.MembershipLeave
}
return member
}
func (store *SQLStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool) {
var member event.MemberEventContent
err := store.
QueryRow("SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
Scan(&member.Membership, &member.Displayname, &member.AvatarURL)
if err != nil && err != sql.ErrNoRows {
store.log.Warnfln("Failed to scan member info of %s in %s: %v", userID, roomID, err)
}
return &member, err == nil
}
func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID) {
rows, err := store.Query(`
SELECT room_id FROM mx_user_profile
LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id
WHERE user_id=$1 AND portal.encrypted=true
`, userID)
if err != nil {
store.log.Warnfln("Failed to query shared rooms with %s: %v", userID, err)
return
}
for rows.Next() {
var roomID id.RoomID
err = rows.Scan(&roomID)
if err != nil {
store.log.Warnfln("Failed to scan room ID: %v", err)
} else {
rooms = append(rooms, roomID)
}
}
return
}
func (store *SQLStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(roomID, userID, "join")
}
func (store *SQLStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(roomID, userID, "join", "invite")
}
func (store *SQLStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
membership := store.GetMembership(roomID, userID)
for _, allowedMembership := range allowedMemberships {
if allowedMembership == membership {
return true
}
}
return false
}
func (store *SQLStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) {
_, err := store.Exec(`
INSERT INTO mx_user_profile (room_id, user_id, membership) VALUES ($1, $2, $3)
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership
`, roomID, userID, membership)
if err != nil {
store.log.Warnfln("Failed to set membership of %s in %s to %s: %v", userID, roomID, membership, err)
}
}
func (store *SQLStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) {
_, err := store.Exec(`
INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership, displayname=excluded.displayname, avatar_url=excluded.avatar_url
`, roomID, userID, member.Membership, member.Displayname, member.AvatarURL)
if err != nil {
store.log.Warnfln("Failed to set membership of %s in %s to %s: %v", userID, roomID, member, err)
}
}
func (store *SQLStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) {
levelsBytes, err := json.Marshal(levels)
if err != nil {
store.log.Errorfln("Failed to marshal power levels of %s: %v", roomID, err)
return
}
_, err = store.Exec(`
INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2)
ON CONFLICT (room_id) DO UPDATE SET power_levels=excluded.power_levels
`, roomID, levelsBytes)
if err != nil {
store.log.Warnfln("Failed to store power levels of %s: %v", roomID, err)
}
}
func (store *SQLStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) {
var data []byte
err := store.
QueryRow("SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID).
Scan(&data)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
store.log.Errorfln("Failed to scan power levels of %s: %v", roomID, err)
}
return
}
levels = &event.PowerLevelsEventContent{}
err = json.Unmarshal(data, levels)
if err != nil {
store.log.Errorfln("Failed to parse power levels of %s: %v", roomID, err)
return nil
}
return
}
func (store *SQLStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int {
if store.Dialect == dbutil.Postgres {
var powerLevel int
err := store.
QueryRow(`
SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
FROM mx_room_state WHERE room_id=$1
`, roomID, userID).
Scan(&powerLevel)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
store.log.Errorfln("Failed to scan power level of %s in %s: %v", userID, roomID, err)
}
return powerLevel
}
return store.GetPowerLevels(roomID).GetUserLevel(userID)
}
func (store *SQLStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int {
if store.Dialect == dbutil.Postgres {
defaultType := "events_default"
defaultValue := 0
if eventType.IsState() {
defaultType = "state_default"
defaultValue = 50
}
var powerLevel int
err := store.
QueryRow(`
SELECT COALESCE((power_levels->'events'->$2)::int, (power_levels->'$3')::int, $4)
FROM mx_room_state WHERE room_id=$1
`, roomID, eventType.Type, defaultType, defaultValue).
Scan(&powerLevel)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
store.log.Errorfln("Failed to scan power level for %s in %s: %v", eventType, roomID, err)
}
return defaultValue
}
return powerLevel
}
return store.GetPowerLevels(roomID).GetEventLevel(eventType)
}
func (store *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool {
if store.Dialect == dbutil.Postgres {
defaultType := "events_default"
defaultValue := 0
if eventType.IsState() {
defaultType = "state_default"
defaultValue = 50
}
var hasPower bool
err := store.
QueryRow(`SELECT
COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
>=
COALESCE((power_levels->'events'->$3)::int, (power_levels->'$4')::int, $5)
FROM mx_room_state WHERE room_id=$1`, roomID, userID, eventType.Type, defaultType, defaultValue).
Scan(&hasPower)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
store.log.Errorfln("Failed to scan power level for %s in %s: %v", eventType, roomID, err)
}
return defaultValue == 0
}
return hasPower
}
return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType)
}

View file

@ -0,0 +1,19 @@
-- v1: Initial revision
CREATE TABLE mx_registrations (
user_id TEXT PRIMARY KEY
);
CREATE TABLE mx_user_profile (
room_id TEXT,
user_id TEXT,
membership TEXT NOT NULL,
displayname TEXT,
avatar_url TEXT,
PRIMARY KEY (room_id, user_id)
);
CREATE TABLE mx_room_state (
room_id TEXT PRIMARY KEY,
power_levels jsonb
);

View file

@ -0,0 +1,5 @@
-- v2: Use enum for membership field on Postgres
-- only: postgres
CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock');
ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE membership USING LOWER(membership)::membership;

353
bridge/bridge.go Normal file
View file

@ -0,0 +1,353 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package bridge
import (
"errors"
"fmt"
"os"
"os/signal"
"strings"
"syscall"
"time"
"gopkg.in/yaml.v3"
flag "maunium.net/go/mauflag"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/appservice/sqlstatestore"
"maunium.net/go/mautrix/bridge/bridgeconfig"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/configupgrade"
"maunium.net/go/mautrix/util/dbutil"
)
var configPath = flag.MakeFull("c", "config", "The path to your config file.", "config.yaml").String()
var dontSaveConfig = flag.MakeFull("n", "no-update", "Don't save updated config to disk.", "false").Bool()
var registrationPath = flag.MakeFull("r", "registration", "The path where to save the appservice registration.", "registration.yaml").String()
var generateRegistration = flag.MakeFull("g", "generate-registration", "Generate registration and quit.", "false").Bool()
var version = flag.MakeFull("v", "version", "View bridge version and quit.", "false").Bool()
var ignoreUnsupportedDatabase = flag.Make().LongKey("ignore-unsupported-database").Usage("Run even if the database schema is too new").Default("false").Bool()
var ignoreForeignTables = flag.Make().LongKey("ignore-foreign-tables").Usage("Run even if the database contains tables from other programs (like Synapse)").Default("false").Bool()
var wantHelp, _ = flag.MakeHelpFlag()
type Portal interface {
IsEncrypted() bool
}
type User interface {
IsAdmin() bool
}
type ChildOverride interface {
GetExampleConfig() string
GetConfigPtr() interface{}
Init()
Start()
Stop()
GetIPortalByMXID(id id.RoomID) Portal
GetIUserByMXID(id id.UserID) User
}
type Bridge struct {
Name string
URL string
Description string
Version string
ProtocolName string
VersionDesc string
LinkifiedVersion string
AS *appservice.AppService
EventProcessor *appservice.EventProcessor
Bot *appservice.IntentAPI
Config bridgeconfig.BaseConfig
ConfigUpgrader configupgrade.BaseUpgrader
Log log.Logger
DB *dbutil.Database
StateStore *sqlstatestore.SQLStateStore
Crypto Crypto
Child ChildOverride
}
type Crypto interface {
HandleMemberEvent(*event.Event)
Decrypt(*event.Event) (*event.Event, error)
Encrypt(id.RoomID, event.Type, event.Content) (*event.EncryptedEventContent, error)
WaitForSession(id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool
RequestSession(id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID)
ResetSession(id.RoomID)
Init() error
Start()
Stop()
}
func (br *Bridge) GenerateRegistration() {
if *dontSaveConfig {
// We need to save the generated as_token and hs_token in the config
_, _ = fmt.Fprintln(os.Stderr, "--no-update is not compatible with --generate-registration")
os.Exit(5)
}
reg := br.Config.GenerateRegistration()
err := reg.Save(*registrationPath)
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Failed to save registration:", err)
os.Exit(21)
}
updateTokens := func(helper *configupgrade.Helper) {
helper.Set(configupgrade.Str, reg.AppToken, "appservice", "as_token")
helper.Set(configupgrade.Str, reg.ServerToken, "appservice", "hs_token")
}
_, _, err = configupgrade.Do(*configPath, true, br.ConfigUpgrader, configupgrade.SimpleUpgrader(updateTokens))
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Failed to save config:", err)
os.Exit(22)
}
fmt.Println("Registration generated. See https://docs.mau.fi/bridges/general/registering-appservices.html for instructions on installing the registration.")
os.Exit(0)
}
func (br *Bridge) InitVersion(tag, commit, buildTime string) {
if len(tag) > 0 && tag[0] == 'v' {
tag = tag[1:]
}
if tag != br.Version {
suffix := ""
if !strings.HasSuffix(br.Version, "+dev") {
suffix = "+dev"
}
if len(commit) > 8 {
br.Version = fmt.Sprintf("%s%s.%s", br.Version, suffix, commit[:8])
} else {
br.Version = fmt.Sprintf("%s%s.unknown", br.Version, suffix)
}
}
linkifiedVersion := fmt.Sprintf("v%s", br.Version)
if tag == br.Version {
linkifiedVersion = fmt.Sprintf("[v%s](%s/releases/v%s)", br.Version, br.URL, tag)
} else if len(commit) > 8 {
linkifiedVersion = strings.Replace(linkifiedVersion, commit[:8], fmt.Sprintf("[%s](%s/commit/%s)", commit[:8], br.URL, commit), 1)
}
mautrix.DefaultUserAgent = fmt.Sprintf("%s/%s %s", br.Name, br.Version, mautrix.DefaultUserAgent)
br.VersionDesc = fmt.Sprintf("%s %s (%s)", br.Name, br.Version, buildTime)
}
func (br *Bridge) ensureConnection() {
for {
versions, err := br.Bot.Versions()
if err != nil {
br.Log.Errorfln("Failed to connect to homeserver: %v. Retrying in 10 seconds...", err)
time.Sleep(10 * time.Second)
continue
}
if !versions.ContainsGreaterOrEqual(mautrix.SpecV11) {
br.Log.Warnfln("Server isn't advertising modern spec versions")
}
resp, err := br.Bot.Whoami()
if err != nil {
if errors.Is(err, mautrix.MUnknownToken) {
br.Log.Fatalln("The as_token was not accepted. Is the registration file installed in your homeserver correctly?")
os.Exit(16)
} else if errors.Is(err, mautrix.MExclusive) {
br.Log.Fatalln("The as_token was accepted, but the /register request was not. Are the homeserver domain and username template in the config correct, and do they match the values in the registration?")
os.Exit(16)
}
br.Log.Errorfln("Failed to connect to homeserver: %v. Retrying in 10 seconds...", err)
time.Sleep(10 * time.Second)
} else if resp.UserID != br.Bot.UserID {
br.Log.Fatalln("Unexpected user ID in whoami call: got %s, expected %s", resp.UserID, br.Bot.UserID)
os.Exit(17)
} else {
break
}
}
}
func (br *Bridge) UpdateBotProfile() {
br.Log.Debugln("Updating bot profile")
botConfig := &br.Config.AppService.Bot
var err error
var mxc id.ContentURI
if botConfig.Avatar == "remove" {
err = br.Bot.SetAvatarURL(mxc)
} else if !botConfig.ParsedAvatar.IsEmpty() {
err = br.Bot.SetAvatarURL(botConfig.ParsedAvatar)
}
if err != nil {
br.Log.Warnln("Failed to update bot avatar:", err)
}
if botConfig.Displayname == "remove" {
err = br.Bot.SetDisplayName("")
} else if len(botConfig.Displayname) > 0 {
err = br.Bot.SetDisplayName(botConfig.Displayname)
}
if err != nil {
br.Log.Warnln("Failed to update bot displayname:", err)
}
}
func (br *Bridge) loadConfig() {
configData, upgraded, err := configupgrade.Do(*configPath, !*dontSaveConfig, br.ConfigUpgrader)
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Error updating config:", err)
if configData == nil {
os.Exit(10)
}
}
target := br.Child.GetConfigPtr()
if !upgraded {
// Fallback: if config upgrading failed, load example config for base values
err = yaml.Unmarshal([]byte(br.Child.GetExampleConfig()), &target)
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Failed to unmarshal example config:", err)
os.Exit(10)
}
}
err = yaml.Unmarshal(configData, target)
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Failed to parse config:", err)
os.Exit(10)
}
}
func (br *Bridge) init() {
var err error
br.AS = br.Config.MakeAppService()
_, _ = br.AS.Init()
br.Log = log.Create()
br.Config.Logging.Configure(br.Log)
log.DefaultLogger = br.Log.(*log.BasicLogger)
if len(br.Config.Logging.FileNameFormat) > 0 {
err = log.OpenFile()
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Failed to open log file:", err)
os.Exit(12)
}
}
br.AS.Log = log.Sub("Matrix")
br.Bot = br.AS.BotIntent()
br.Log.Infoln("Initializing", br.VersionDesc)
br.Log.Debugln("Initializing database connection")
br.DB, err = dbutil.NewFromConfig(br.Name, br.Config.AppService.Database, br.Log.Sub("Database"))
if err != nil {
br.Log.Fatalln("Failed to initialize database connection:", err)
os.Exit(14)
}
br.DB.IgnoreUnsupportedDatabase = *ignoreUnsupportedDatabase
br.DB.IgnoreForeignTables = *ignoreForeignTables
br.Log.Debugln("Initializing state store")
br.StateStore = sqlstatestore.NewSQLStateStore(br.DB)
br.AS.StateStore = br.StateStore
br.Log.Debugln("Initializing Matrix event processor")
br.EventProcessor = appservice.NewEventProcessor(br.AS)
br.Crypto = NewCryptoHelper(br)
br.Child.Init()
}
func (br *Bridge) LogDBUpgradeErrorAndExit(name string, err error) {
br.Log.Fatalfln("Failed to initialize %s: %v", name, err)
if errors.Is(err, dbutil.ErrForeignTables) {
br.Log.Infoln("You can use --ignore-foreign-tables to ignore this error")
} else if errors.Is(err, dbutil.ErrNotOwned) {
br.Log.Infoln("Sharing the same database with different programs is not supported")
} else if errors.Is(err, dbutil.ErrUnsupportedDatabaseVersion) {
br.Log.Infoln("Downgrading the bridge is not supported")
}
os.Exit(15)
}
func (br *Bridge) start() {
br.Log.Debugln("Running database upgrades")
err := br.DB.Upgrade()
if err != nil {
br.LogDBUpgradeErrorAndExit("main database", err)
} else if err = br.StateStore.Upgrade(); err != nil {
br.LogDBUpgradeErrorAndExit("matrix state store", err)
}
br.Log.Debugln("Checking connection to homeserver")
br.ensureConnection()
if br.Crypto != nil {
err = br.Crypto.Init()
if err != nil {
br.Log.Fatalln("Error initializing end-to-bridge encryption:", err)
os.Exit(19)
}
}
br.Log.Debugln("Starting application service HTTP server")
go br.AS.Start()
br.Log.Debugln("Starting event processor")
go br.EventProcessor.Start()
go br.UpdateBotProfile()
if br.Crypto != nil {
go br.Crypto.Start()
}
br.Child.Start()
br.AS.Ready = true
}
func (br *Bridge) Main() {
flag.SetHelpTitles(
fmt.Sprintf("%s - %s", br.Name, br.Description),
fmt.Sprintf("%s [-hgvn] [-c <path>] [-r <path>]", br.Name))
err := flag.Parse()
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, err)
flag.PrintHelp()
os.Exit(1)
} else if *wantHelp {
flag.PrintHelp()
os.Exit(0)
} else if *version {
fmt.Println(br.VersionDesc)
return
}
br.loadConfig()
if *generateRegistration {
br.GenerateRegistration()
return
}
br.init()
br.Log.Infoln("Bridge initialization complete, starting...")
br.start()
br.Log.Infoln("Bridge started!")
c := make(chan os.Signal)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
<-c
br.Log.Infoln("Interrupt received, stopping...")
br.Child.Stop()
br.Log.Infoln("Bridge stopped.")
os.Exit(0)
}

View file

@ -0,0 +1,200 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package bridgeconfig
import (
"fmt"
"regexp"
"strings"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/id"
up "maunium.net/go/mautrix/util/configupgrade"
)
type HomeserverConfig struct {
Address string `yaml:"address"`
Domain string `yaml:"domain"`
AsyncMedia bool `yaml:"async_media"`
Asmux bool `yaml:"asmux"`
StatusEndpoint string `yaml:"status_endpoint"`
MessageSendCheckpointEndpoint string `yaml:"message_send_checkpoint_endpoint"`
}
type AppserviceConfig struct {
Address string `yaml:"address"`
Hostname string `yaml:"hostname"`
Port uint16 `yaml:"port"`
Database DatabaseConfig `yaml:"database"`
ID string `yaml:"id"`
Bot BotUserConfig `yaml:"bot"`
ASToken string `yaml:"as_token"`
HSToken string `yaml:"hs_token"`
EphemeralEvents bool `yaml:"ephemeral_events"`
}
func (config *BaseConfig) MakeUserIDRegex() *regexp.Regexp {
usernamePlaceholder := appservice.RandomString(16)
usernameTemplate := fmt.Sprintf("@%s:%s",
config.Bridge.FormatUsername(usernamePlaceholder),
config.Homeserver.Domain)
usernameTemplate = regexp.QuoteMeta(usernameTemplate)
usernameTemplate = strings.Replace(usernameTemplate, usernamePlaceholder, "[0-9]+", 1)
usernameTemplate = fmt.Sprintf("^%s$", usernameTemplate)
return regexp.MustCompile(usernameTemplate)
}
// GenerateRegistration generates a registration file for the homeserver.
func (config *BaseConfig) GenerateRegistration() *appservice.Registration {
registration := appservice.CreateRegistration()
config.AppService.HSToken = registration.ServerToken
config.AppService.ASToken = registration.AppToken
config.AppService.copyToRegistration(registration)
registration.SenderLocalpart = appservice.RandomString(32)
botRegex := regexp.MustCompile(fmt.Sprintf("^@%s:%s$",
regexp.QuoteMeta(config.AppService.Bot.Username),
regexp.QuoteMeta(config.Homeserver.Domain)))
registration.Namespaces.UserIDs.Register(botRegex, true)
registration.Namespaces.UserIDs.Register(config.MakeUserIDRegex(), true)
return registration
}
func (config *BaseConfig) MakeAppService() *appservice.AppService {
as := appservice.Create()
as.HomeserverDomain = config.Homeserver.Domain
as.HomeserverURL = config.Homeserver.Address
as.Host.Hostname = config.AppService.Hostname
as.Host.Port = config.AppService.Port
as.MessageSendCheckpointEndpoint = config.Homeserver.MessageSendCheckpointEndpoint
as.DefaultHTTPRetries = 4
as.Registration = config.AppService.GetRegistration()
return as
}
// GetRegistration copies the data from the bridge config into an *appservice.Registration struct.
// This can't be used with the homeserver, see GenerateRegistration for generating files for the homeserver.
func (asc *AppserviceConfig) GetRegistration() *appservice.Registration {
reg := &appservice.Registration{}
asc.copyToRegistration(reg)
reg.SenderLocalpart = asc.Bot.Username
reg.ServerToken = asc.HSToken
reg.AppToken = asc.ASToken
return reg
}
func (asc *AppserviceConfig) copyToRegistration(registration *appservice.Registration) {
registration.ID = asc.ID
registration.URL = asc.Address
falseVal := false
registration.RateLimited = &falseVal
registration.EphemeralEvents = asc.EphemeralEvents
registration.SoruEphemeralEvents = asc.EphemeralEvents
}
type BotUserConfig struct {
Username string `yaml:"username"`
Displayname string `yaml:"displayname"`
Avatar string `yaml:"avatar"`
ParsedAvatar id.ContentURI `yaml:"-"`
}
type serializableBUC BotUserConfig
func (buc *BotUserConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
var sbuc serializableBUC
err := unmarshal(&sbuc)
if err != nil {
return err
}
*buc = (BotUserConfig)(sbuc)
if buc.Avatar != "" && buc.Avatar != "remove" {
buc.ParsedAvatar, err = id.ParseContentURI(buc.Avatar)
if err != nil {
return fmt.Errorf("%w in bot avatar", err)
}
}
return nil
}
type DatabaseConfig struct {
Type string `yaml:"type"`
URI string `yaml:"uri"`
MaxOpenConns int `yaml:"max_open_conns"`
MaxIdleConns int `yaml:"max_idle_conns"`
ConnMaxIdleTime string `yaml:"conn_max_idle_time"`
ConnMaxLifetime string `yaml:"conn_max_lifetime"`
}
type BridgeConfig interface {
FormatUsername(username string) string
GetEncryptionConfig() EncryptionConfig
}
type EncryptionConfig struct {
Allow bool `yaml:"allow"`
Default bool `yaml:"default"`
KeySharing struct {
Allow bool `yaml:"allow"`
RequireCrossSigning bool `yaml:"require_cross_signing"`
RequireVerification bool `yaml:"require_verification"`
} `yaml:"key_sharing"`
}
type BaseConfig struct {
Homeserver HomeserverConfig `yaml:"homeserver"`
AppService AppserviceConfig `yaml:"appservice"`
Bridge BridgeConfig `yaml:"-"`
Logging appservice.LogConfig `yaml:"logging"`
}
type configUpgrader struct{}
var Upgrader = configUpgrader{}
func (upg configUpgrader) DoUpgrade(helper *up.Helper) {
helper.Copy(up.Str, "homeserver", "address")
helper.Copy(up.Str, "homeserver", "domain")
helper.Copy(up.Bool, "homeserver", "asmux")
helper.Copy(up.Str|up.Null, "homeserver", "status_endpoint")
helper.Copy(up.Str|up.Null, "homeserver", "message_send_checkpoint_endpoint")
helper.Copy(up.Bool, "homeserver", "async_media")
helper.Copy(up.Str, "appservice", "address")
helper.Copy(up.Str, "appservice", "hostname")
helper.Copy(up.Int, "appservice", "port")
helper.Copy(up.Str, "appservice", "database", "type")
helper.Copy(up.Str, "appservice", "database", "uri")
helper.Copy(up.Int, "appservice", "database", "max_open_conns")
helper.Copy(up.Int, "appservice", "database", "max_idle_conns")
helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_idle_time")
helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_lifetime")
helper.Copy(up.Str, "appservice", "id")
helper.Copy(up.Str, "appservice", "bot", "username")
helper.Copy(up.Str, "appservice", "bot", "displayname")
helper.Copy(up.Str, "appservice", "bot", "avatar")
helper.Copy(up.Bool, "appservice", "ephemeral_events")
helper.Copy(up.Str, "appservice", "as_token")
helper.Copy(up.Str, "appservice", "hs_token")
helper.Copy(up.Str, "logging", "directory")
helper.Copy(up.Str|up.Null, "logging", "file_name_format")
helper.Copy(up.Str|up.Timestamp, "logging", "file_date_format")
helper.Copy(up.Int, "logging", "file_mode")
helper.Copy(up.Str|up.Timestamp, "logging", "timestamp_format")
helper.Copy(up.Str, "logging", "print_level")
}

309
bridge/crypto.go Normal file
View file

@ -0,0 +1,309 @@
package bridge
import (
"fmt"
"runtime/debug"
"time"
"maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
var NoSessionFound = crypto.NoSessionFound
var levelTrace = maulogger.Level{
Name: "TRACE",
Severity: -10,
Color: -1,
}
type CryptoHelper struct {
bridge *Bridge
client *mautrix.Client
mach *crypto.OlmMachine
store *SQLCryptoStore
log maulogger.Logger
baseLog maulogger.Logger
}
func NewCryptoHelper(bridge *Bridge) Crypto {
if !bridge.Config.Bridge.GetEncryptionConfig().Allow {
bridge.Log.Debugln("Bridge built with end-to-bridge encryption, but disabled in config")
return nil
}
baseLog := bridge.Log.Sub("Crypto")
return &CryptoHelper{
bridge: bridge,
log: baseLog.Sub("Helper"),
baseLog: baseLog,
}
}
func (helper *CryptoHelper) Init() error {
helper.log.Debugln("Initializing end-to-bridge encryption...")
helper.store = NewSQLCryptoStore(helper.bridge.DB, helper.bridge.AS.BotMXID(),
fmt.Sprintf("@%s:%s", helper.bridge.Config.Bridge.FormatUsername("%"), helper.bridge.AS.HomeserverDomain))
err := helper.store.Upgrade()
if err != nil {
helper.bridge.LogDBUpgradeErrorAndExit("crypto store", err)
}
helper.client, err = helper.loginBot()
if err != nil {
return err
}
helper.log.Debugln("Logged in as bridge bot with device ID", helper.client.DeviceID)
logger := &cryptoLogger{helper.baseLog}
stateStore := &cryptoStateStore{helper.bridge}
helper.mach = crypto.NewOlmMachine(helper.client, logger, helper.store, stateStore)
helper.mach.AllowKeyShare = helper.allowKeyShare
helper.client.Syncer = &cryptoSyncer{helper.mach}
helper.client.Store = &cryptoClientStore{helper.store}
return helper.mach.Load()
}
func (helper *CryptoHelper) allowKeyShare(device *crypto.DeviceIdentity, info event.RequestedKeyInfo) *crypto.KeyShareRejection {
cfg := helper.bridge.Config.Bridge.GetEncryptionConfig().KeySharing
if !cfg.Allow {
return &crypto.KeyShareRejectNoResponse
} else if device.Trust == crypto.TrustStateBlacklisted {
return &crypto.KeyShareRejectBlacklisted
} else if device.Trust == crypto.TrustStateVerified || !cfg.RequireVerification {
portal := helper.bridge.Child.GetIPortalByMXID(info.RoomID)
if portal == nil {
helper.log.Debugfln("Rejecting key request for %s from %s/%s: room is not a portal", info.SessionID, device.UserID, device.DeviceID)
return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnavailable, Reason: "Requested room is not a portal room"}
}
user := helper.bridge.Child.GetIUserByMXID(device.UserID)
// FIXME reimplement IsInPortal
if !user.IsAdmin() /*&& !user.IsInPortal(portal.Key)*/ {
helper.log.Debugfln("Rejecting key request for %s from %s/%s: user is not in portal", info.SessionID, device.UserID, device.DeviceID)
return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnauthorized, Reason: "You're not in that portal"}
}
helper.log.Debugfln("Accepting key request for %s from %s/%s", info.SessionID, device.UserID, device.DeviceID)
return nil
} else {
return &crypto.KeyShareRejectUnverified
}
}
func (helper *CryptoHelper) loginBot() (*mautrix.Client, error) {
deviceID := helper.store.FindDeviceID()
if len(deviceID) > 0 {
helper.log.Debugln("Found existing device ID for bot in database:", deviceID)
}
client, err := mautrix.NewClient(helper.bridge.AS.HomeserverURL, "", "")
if err != nil {
return nil, fmt.Errorf("failed to initialize client: %w", err)
}
client.Logger = helper.baseLog.Sub("Bot")
client.Client = helper.bridge.AS.HTTPClient
client.DefaultHTTPRetries = helper.bridge.AS.DefaultHTTPRetries
flows, err := client.GetLoginFlows()
if err != nil {
return nil, fmt.Errorf("failed to get supported login flows: %w", err)
} else if !flows.HasFlow(mautrix.AuthTypeAppservice) {
return nil, fmt.Errorf("homeserver does not support appservice login")
}
// We set the API token to the AS token here to authenticate the appservice login
// It'll get overridden after the login
client.AccessToken = helper.bridge.AS.Registration.AppToken
resp, err := client.Login(&mautrix.ReqLogin{
Type: mautrix.AuthTypeAppservice,
Identifier: mautrix.UserIdentifier{
Type: mautrix.IdentifierTypeUser,
User: string(helper.bridge.AS.BotMXID()),
},
DeviceID: deviceID,
StoreCredentials: true,
InitialDeviceDisplayName: fmt.Sprintf("%s bridge", helper.bridge.ProtocolName),
})
if err != nil {
return nil, fmt.Errorf("failed to log in as bridge bot: %w", err)
}
helper.store.DeviceID = resp.DeviceID
return client, nil
}
func (helper *CryptoHelper) Start() {
helper.log.Debugln("Starting syncer for receiving to-device messages")
err := helper.client.Sync()
if err != nil {
helper.log.Errorln("Fatal error syncing:", err)
} else {
helper.log.Infoln("Bridge bot to-device syncer stopped without error")
}
}
func (helper *CryptoHelper) Stop() {
helper.log.Debugln("CryptoHelper.Stop() called, stopping bridge bot sync")
helper.client.StopSync()
}
func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) {
return helper.mach.DecryptMegolmEvent(evt)
}
func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, content event.Content) (*event.EncryptedEventContent, error) {
encrypted, err := helper.mach.EncryptMegolmEvent(roomID, evtType, &content)
if err != nil {
if err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession {
return nil, err
}
helper.log.Debugfln("Got %v while encrypting event for %s, sharing group session and trying again...", err, roomID)
users, err := helper.store.GetRoomMembers(roomID)
if err != nil {
return nil, fmt.Errorf("failed to get room member list: %w", err)
}
err = helper.mach.ShareGroupSession(roomID, users)
if err != nil {
return nil, fmt.Errorf("failed to share group session: %w", err)
}
encrypted, err = helper.mach.EncryptMegolmEvent(roomID, evtType, &content)
if err != nil {
return nil, fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err)
}
}
return encrypted, nil
}
func (helper *CryptoHelper) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
return helper.mach.WaitForSession(roomID, senderKey, sessionID, timeout)
}
func (helper *CryptoHelper) RequestSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) {
err := helper.mach.SendRoomKeyRequest(roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{userID: {deviceID}})
if err != nil {
helper.log.Warnfln("Failed to send key request to %s/%s for %s in %s: %v", userID, deviceID, sessionID, roomID, err)
} else {
helper.log.Debugfln("Sent key request to %s/%s for %s in %s", userID, deviceID, sessionID, roomID)
}
}
func (helper *CryptoHelper) ResetSession(roomID id.RoomID) {
err := helper.mach.CryptoStore.RemoveOutboundGroupSession(roomID)
if err != nil {
helper.log.Debugfln("Error manually removing outbound group session in %s: %v", roomID, err)
}
}
func (helper *CryptoHelper) HandleMemberEvent(evt *event.Event) {
helper.mach.HandleMemberEvent(evt)
}
type cryptoSyncer struct {
*crypto.OlmMachine
}
func (syncer *cryptoSyncer) ProcessResponse(resp *mautrix.RespSync, since string) error {
done := make(chan struct{})
go func() {
defer func() {
if err := recover(); err != nil {
syncer.Log.Error("Processing sync response (%s) panicked: %v\n%s", since, err, debug.Stack())
}
done <- struct{}{}
}()
syncer.Log.Trace("Starting sync response handling (%s)", since)
syncer.ProcessSyncResponse(resp, since)
syncer.Log.Trace("Successfully handled sync response (%s)", since)
}()
select {
case <-done:
case <-time.After(30 * time.Second):
syncer.Log.Warn("Handling sync response (%s) is taking unusually long", since)
}
return nil
}
func (syncer *cryptoSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.Duration, error) {
syncer.Log.Error("Error /syncing, waiting 10 seconds: %v", err)
return 10 * time.Second, nil
}
func (syncer *cryptoSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter {
everything := []event.Type{{Type: "*"}}
return &mautrix.Filter{
Presence: mautrix.FilterPart{NotTypes: everything},
AccountData: mautrix.FilterPart{NotTypes: everything},
Room: mautrix.RoomFilter{
IncludeLeave: false,
Ephemeral: mautrix.FilterPart{NotTypes: everything},
AccountData: mautrix.FilterPart{NotTypes: everything},
State: mautrix.FilterPart{NotTypes: everything},
Timeline: mautrix.FilterPart{NotTypes: everything},
},
}
}
type cryptoLogger struct {
int maulogger.Logger
}
func (c *cryptoLogger) Error(message string, args ...interface{}) {
c.int.Errorfln(message, args...)
}
func (c *cryptoLogger) Warn(message string, args ...interface{}) {
c.int.Warnfln(message, args...)
}
func (c *cryptoLogger) Debug(message string, args ...interface{}) {
c.int.Debugfln(message, args...)
}
func (c *cryptoLogger) Trace(message string, args ...interface{}) {
c.int.Logfln(levelTrace, message, args...)
}
type cryptoClientStore struct {
int *SQLCryptoStore
}
func (c cryptoClientStore) SaveFilterID(_ id.UserID, _ string) {}
func (c cryptoClientStore) LoadFilterID(_ id.UserID) string { return "" }
func (c cryptoClientStore) SaveRoom(_ *mautrix.Room) {}
func (c cryptoClientStore) LoadRoom(_ id.RoomID) *mautrix.Room { return nil }
func (c cryptoClientStore) SaveNextBatch(_ id.UserID, nextBatchToken string) {
c.int.PutNextBatch(nextBatchToken)
}
func (c cryptoClientStore) LoadNextBatch(_ id.UserID) string {
return c.int.GetNextBatch()
}
var _ mautrix.Storer = (*cryptoClientStore)(nil)
type cryptoStateStore struct {
bridge *Bridge
}
var _ crypto.StateStore = (*cryptoStateStore)(nil)
func (c *cryptoStateStore) IsEncrypted(id id.RoomID) bool {
portal := c.bridge.Child.GetIPortalByMXID(id)
if portal != nil {
return portal.IsEncrypted()
}
return false
}
func (c *cryptoStateStore) FindSharedRooms(id id.UserID) []id.RoomID {
return c.bridge.StateStore.FindSharedRooms(id)
}
func (c *cryptoStateStore) GetEncryptionEvent(id.RoomID) *event.EncryptionEventContent {
// TODO implement
return nil
}

65
bridge/cryptostore.go Normal file
View file

@ -0,0 +1,65 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build cgo && !nocrypto
package bridge
import (
"database/sql"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
)
type SQLCryptoStore struct {
*crypto.SQLCryptoStore
UserID id.UserID
GhostIDFormat string
}
var _ crypto.Store = (*SQLCryptoStore)(nil)
func NewSQLCryptoStore(db *dbutil.Database, userID id.UserID, ghostIDFormat string) *SQLCryptoStore {
return &SQLCryptoStore{
SQLCryptoStore: crypto.NewSQLCryptoStore(db, "", "", []byte("maunium.net/go/mautrix-whatsapp")),
UserID: userID,
GhostIDFormat: ghostIDFormat,
}
}
func (store *SQLCryptoStore) FindDeviceID() (deviceID id.DeviceID) {
err := store.DB.QueryRow("SELECT device_id FROM crypto_account WHERE account_id=$1", store.AccountID).Scan(&deviceID)
if err != nil && err != sql.ErrNoRows {
store.Log.Warn("Failed to scan device ID: %v", err)
}
return
}
func (store *SQLCryptoStore) GetRoomMembers(roomID id.RoomID) (members []id.UserID, err error) {
var rows *sql.Rows
rows, err = store.DB.Query(`
SELECT user_id FROM mx_user_profile
WHERE room_id=$1
AND (membership='join' OR membership='invite')
AND user_id<>$2
AND user_id NOT LIKE $3
`, roomID, store.UserID, store.GhostIDFormat)
if err != nil {
return
}
for rows.Next() {
var userID id.UserID
err = rows.Scan(&userID)
if err != nil {
store.Log.Warn("Failed to scan member in %s: %v", roomID, err)
} else {
members = append(members, userID)
}
}
return
}

17
bridge/no-crypto.go Normal file
View file

@ -0,0 +1,17 @@
//go:build !cgo || nocrypto
package main
import (
"errors"
)
func NewCryptoHelper(bridge *Bridge) Crypto {
if !bridge.Config.Bridge.Encryption.Allow {
bridge.Log.Warnln("Bridge built without end-to-bridge encryption, but encryption is enabled in config")
}
bridge.Log.Debugln("Bridge built without end-to-bridge encryption")
return nil
}
var NoSessionFound = errors.New("nil")

View file

@ -12,17 +12,23 @@ import (
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/olm"
sqlUpgrade "maunium.net/go/mautrix/crypto/sql_store_upgrade"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
)
func getOlmMachine(t *testing.T) *OlmMachine {
db, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000")
rawDB, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000")
if err != nil {
t.Fatalf("Error opening db: %v", err)
}
sqlUpgrade.Upgrade(db, "sqlite3")
sqlStore := NewSQLCryptoStore(db, "sqlite3", "accid", id.DeviceID("dev"), []byte("test"), emptyLogger{})
db, err := dbutil.NewWithDB(rawDB, "sqlite3")
if err != nil {
t.Fatalf("Error opening db: %v", err)
}
sqlStore := NewSQLCryptoStore(db, "accid", id.DeviceID("dev"), []byte("test"))
if err = sqlStore.Upgrade(); err != nil {
t.Fatalf("Error creating tables: %v", err)
}
userID := id.UserID("@mautrix")
mk, _ := olm.NewPkSigning()

View file

@ -1,4 +1,4 @@
// Copyright (c) 2020 Tulir Asokan
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -17,6 +17,7 @@ import (
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
)
var PostgresArrayWrapper func(interface{}) interface {
@ -26,9 +27,7 @@ var PostgresArrayWrapper func(interface{}) interface {
// SQLCryptoStore is an implementation of a crypto Store for a database backend.
type SQLCryptoStore struct {
DB *sql.DB
Log Logger
Dialect string
*dbutil.Database
AccountID string
DeviceID id.DeviceID
@ -44,11 +43,9 @@ var _ Store = (*SQLCryptoStore)(nil)
// NewSQLCryptoStore initializes a new crypto Store using the given database, for a device's crypto material.
// The stored material will be encrypted with the given key.
func NewSQLCryptoStore(db *sql.DB, dialect string, accountID string, deviceID id.DeviceID, pickleKey []byte, log Logger) *SQLCryptoStore {
func NewSQLCryptoStore(db *dbutil.Database, accountID string, deviceID id.DeviceID, pickleKey []byte) *SQLCryptoStore {
return &SQLCryptoStore{
DB: db,
Dialect: dialect,
Log: log,
Database: db.Child("CryptoStore", sql_store_upgrade.VersionTableName, sql_store_upgrade.Table),
PickleKey: pickleKey,
AccountID: accountID,
DeviceID: deviceID,
@ -58,8 +55,10 @@ func NewSQLCryptoStore(db *sql.DB, dialect string, accountID string, deviceID id
}
// CreateTables applies all the pending database migrations.
//
// Deprecated: The Upgrade method (inherited from dbutil.Database) should be used instead
func (store *SQLCryptoStore) CreateTables() error {
return sql_store_upgrade.Upgrade(store.DB, store.Dialect)
return store.Upgrade()
}
// Flush does nothing for this implementation as data is already persisted in the database.
@ -581,7 +580,7 @@ func (store *SQLCryptoStore) PutDevices(userID id.UserID, devices map[id.DeviceI
func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) []id.UserID {
var rows *sql.Rows
var err error
if store.Dialect == "postgres" && PostgresArrayWrapper != nil {
if store.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil {
rows, err = store.DB.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", PostgresArrayWrapper(users))
} else {
queryString := make([]string, len(users))

View file

@ -0,0 +1,86 @@
-- v0 -> v6: Latest revision
CREATE TABLE IF NOT EXISTS crypto_account (
account_id TEXT PRIMARY KEY,
device_id TEXT PRIMARY KEY,
shared BOOLEAN NOT NULL,
sync_token TEXT NOT NULL,
account bytea NOT NULL
);
CREATE TABLE IF NOT EXISTS crypto_message_index (
sender_key CHAR(43),
session_id CHAR(43),
"index" INTEGER,
event_id TEXT NOT NULL,
timestamp BIGINT NOT NULL,
PRIMARY KEY (sender_key, session_id, "index")
);
CREATE TABLE IF NOT EXISTS crypto_tracked_user (
user_id TEXT PRIMARY KEY
);
CREATE TABLE IF NOT EXISTS crypto_device (
user_id TEXT,
device_id TEXT,
identity_key CHAR(43) NOT NULL,
signing_key CHAR(43) NOT NULL,
trust SMALLINT NOT NULL,
deleted BOOLEAN NOT NULL,
name TEXT NOT NULL,
PRIMARY KEY (user_id, device_id)
);
CREATE TABLE IF NOT EXISTS crypto_olm_session (
account_id TEXT,
session_id CHAR(43),
sender_key CHAR(43) NOT NULL,
session bytea NOT NULL,
created_at timestamp NOT NULL,
last_decrypted timestamp NOT NULL,
last_encrypted timestamp NOT NULL,
PRIMARY KEY (account_id, session_id)
);
CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
account_id TEXT,
session_id CHAR(43),
sender_key CHAR(43) NOT NULL,
signing_key CHAR(43),
room_id TEXT NOT NULL,
session bytea,
forwarding_chains bytea,
withheld_code TEXT,
withheld_reason TEXT,
PRIMARY KEY (account_id, session)
);
CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session (
account_id TEXT,
room_id TEXT,
session_id CHAR(43) NOT NULL UNIQUE,
session bytea NOT NULL,
shared BOOLEAN NOT NULL,
max_messages INTEGER NOT NULL,
message_count INTEGER NOT NULL,
max_age BIGINT NOT NULL,
created_at timestamp NOT NULL,
last_used timestamp NOT NULL,
PRIMARY KEY (account_id, room_id)
);
CREATE TABLE IF NOT EXISTS crypto_cross_signing_keys (
user_id TEXT,
usage TEXT,
key CHAR(43) NOT NULL,
PRIMARY KEY (user_id, usage)
);
CREATE TABLE IF NOT EXISTS crypto_cross_signing_signatures (
signed_user_id TEXT,
signed_key TEXT,
signer_user_id TEXT,
signer_key TEXT,
signature CHAR(88) NOT NULL,
PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key)
);

View file

@ -0,0 +1,16 @@
-- v4: Add tables for cross-signing keys
CREATE TABLE IF NOT EXISTS crypto_cross_signing_keys (
user_id VARCHAR(255) NOT NULL,
usage VARCHAR(20) NOT NULL,
key CHAR(43) NOT NULL,
PRIMARY KEY (user_id, usage)
);
CREATE TABLE IF NOT EXISTS crypto_cross_signing_signatures (
signed_user_id VARCHAR(255) NOT NULL,
signed_key VARCHAR(255) NOT NULL,
signer_user_id VARCHAR(255) NOT NULL,
signer_key VARCHAR(255) NOT NULL,
signature CHAR(88) NOT NULL,
PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key)
)

View file

@ -0,0 +1,31 @@
-- v5: Switch from VARCHAR(255) to TEXT
-- only: postgres
ALTER TABLE crypto_account ALTER COLUMN device_id TYPE TEXT;
ALTER TABLE crypto_account ALTER COLUMN account_id TYPE TEXT;
ALTER TABLE crypto_device ALTER COLUMN user_id TYPE TEXT;
ALTER TABLE crypto_device ALTER COLUMN device_id TYPE TEXT;
ALTER TABLE crypto_device ALTER COLUMN name TYPE TEXT;
ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN room_id TYPE TEXT;
ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN account_id TYPE TEXT;
ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN withheld_code TYPE TEXT;
ALTER TABLE crypto_megolm_outbound_session ALTER COLUMN room_id TYPE TEXT;
ALTER TABLE crypto_megolm_outbound_session ALTER COLUMN account_id TYPE TEXT;
ALTER TABLE crypto_message_index ALTER COLUMN event_id TYPE TEXT;
ALTER TABLE crypto_olm_session ALTER COLUMN account_id TYPE TEXT;
ALTER TABLE crypto_tracked_user ALTER COLUMN user_id TYPE TEXT;
ALTER TABLE crypto_cross_signing_keys ALTER COLUMN user_id TYPE TEXT;
ALTER TABLE crypto_cross_signing_keys ALTER COLUMN usage TYPE TEXT;
ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signed_user_id TYPE TEXT;
ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signed_key TYPE TEXT;
ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signer_user_id TYPE TEXT;
ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signer_key TYPE TEXT;
ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signature TYPE TEXT;

View file

@ -0,0 +1,6 @@
-- v6: Split last_used into last_encrypted and last_decrypted for Olm sessions
ALTER TABLE crypto_olm_session RENAME COLUMN last_used TO last_decrypted;
ALTER TABLE crypto_olm_session ADD COLUMN last_encrypted timestamp;
UPDATE crypto_olm_session SET last_encrypted=last_decrypted;
-- only: postgres (too complicated on SQLite)
ALTER TABLE crypto_olm_session ALTER COLUMN last_encrypted SET NOT NULL;

View file

@ -1,361 +1,40 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package sql_store_upgrade
import (
"database/sql"
"errors"
"embed"
"fmt"
"strings"
"maunium.net/go/mautrix/util/dbutil"
)
type upgradeFunc func(*sql.Tx, string) error
var Table dbutil.UpgradeTable
var ErrUnknownDialect = errors.New("unknown dialect")
const VersionTableName = "crypto_version"
var Upgrades = [...]upgradeFunc{
func(tx *sql.Tx, _ string) error {
for _, query := range []string{
`CREATE TABLE IF NOT EXISTS crypto_account (
device_id VARCHAR(255) PRIMARY KEY,
shared BOOLEAN NOT NULL,
sync_token TEXT NOT NULL,
account bytea NOT NULL
)`,
`CREATE TABLE IF NOT EXISTS crypto_message_index (
sender_key CHAR(43),
session_id CHAR(43),
"index" INTEGER,
event_id VARCHAR(255) NOT NULL,
timestamp BIGINT NOT NULL,
PRIMARY KEY (sender_key, session_id, "index")
)`,
`CREATE TABLE IF NOT EXISTS crypto_tracked_user (
user_id VARCHAR(255) PRIMARY KEY
)`,
`CREATE TABLE IF NOT EXISTS crypto_device (
user_id VARCHAR(255),
device_id VARCHAR(255),
identity_key CHAR(43) NOT NULL,
signing_key CHAR(43) NOT NULL,
trust SMALLINT NOT NULL,
deleted BOOLEAN NOT NULL,
name VARCHAR(255) NOT NULL,
PRIMARY KEY (user_id, device_id)
)`,
`CREATE TABLE IF NOT EXISTS crypto_olm_session (
session_id CHAR(43) PRIMARY KEY,
sender_key CHAR(43) NOT NULL,
session bytea NOT NULL,
created_at timestamp NOT NULL,
last_used timestamp NOT NULL
)`,
`CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
session_id CHAR(43) PRIMARY KEY,
sender_key CHAR(43) NOT NULL,
signing_key CHAR(43) NOT NULL,
room_id VARCHAR(255) NOT NULL,
session bytea NOT NULL,
forwarding_chains bytea NOT NULL
)`,
`CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session (
room_id VARCHAR(255) PRIMARY KEY,
session_id CHAR(43) NOT NULL UNIQUE,
session bytea NOT NULL,
shared BOOLEAN NOT NULL,
max_messages INTEGER NOT NULL,
message_count INTEGER NOT NULL,
max_age BIGINT NOT NULL,
created_at timestamp NOT NULL,
last_used timestamp NOT NULL
)`,
} {
if _, err := tx.Exec(query); err != nil {
return err
}
}
return nil
},
func(tx *sql.Tx, dialect string) error {
if dialect == "postgres" {
tablesToPkeys := map[string][]string{
"crypto_account": {},
"crypto_olm_session": {"session_id"},
"crypto_megolm_inbound_session": {"session_id"},
"crypto_megolm_outbound_session": {"room_id"},
}
for tableName, pkeys := range tablesToPkeys {
// add account_id to primary key
pkeyStr := strings.Join(append(pkeys, "account_id"), ", ")
for _, query := range []string{
fmt.Sprintf("ALTER TABLE %s ADD COLUMN account_id VARCHAR(255)", tableName),
fmt.Sprintf("UPDATE %s SET account_id=''", tableName),
fmt.Sprintf("ALTER TABLE %s ALTER COLUMN account_id SET NOT NULL", tableName),
fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s_pkey", tableName, tableName),
fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s_pkey PRIMARY KEY (%s)", tableName, tableName, pkeyStr),
} {
if _, err := tx.Exec(query); err != nil {
return err
}
}
}
} else if dialect == "sqlite3" {
tableCols := map[string]string{
"crypto_account": `
account_id VARCHAR(255) NOT NULL,
device_id VARCHAR(255) NOT NULL,
shared BOOLEAN NOT NULL,
sync_token TEXT NOT NULL,
account BLOB NOT NULL,
PRIMARY KEY (account_id)
`,
"crypto_olm_session": `
account_id VARCHAR(255) NOT NULL,
session_id CHAR(43) NOT NULL,
sender_key CHAR(43) NOT NULL,
session BLOB NOT NULL,
created_at timestamp NOT NULL,
last_used timestamp NOT NULL,
PRIMARY KEY (account_id, session_id)
`,
"crypto_megolm_inbound_session": `
account_id VARCHAR(255) NOT NULL,
session_id CHAR(43) NOT NULL,
sender_key CHAR(43) NOT NULL,
signing_key CHAR(43) NOT NULL,
room_id VARCHAR(255) NOT NULL,
session BLOB NOT NULL,
forwarding_chains BLOB NOT NULL,
PRIMARY KEY (account_id, session_id)
`,
"crypto_megolm_outbound_session": `
account_id VARCHAR(255) NOT NULL,
room_id VARCHAR(255) NOT NULL,
session_id CHAR(43) NOT NULL UNIQUE,
session BLOB NOT NULL,
shared BOOLEAN NOT NULL,
max_messages INTEGER NOT NULL,
message_count INTEGER NOT NULL,
max_age BIGINT NOT NULL,
created_at timestamp NOT NULL,
last_used timestamp NOT NULL,
PRIMARY KEY (account_id, room_id)
`,
}
for tableName, cols := range tableCols {
// re-create tables with account_id column and new pkey and re-insert rows
for _, query := range []string{
fmt.Sprintf("ALTER TABLE %s RENAME TO old_%s", tableName, tableName),
fmt.Sprintf("CREATE TABLE %s (%s)", tableName, cols),
fmt.Sprintf("INSERT INTO %s SELECT '', * FROM old_%s", tableName, tableName),
fmt.Sprintf("DROP TABLE old_%s", tableName),
} {
if _, err := tx.Exec(query); err != nil {
return err
}
}
}
} else {
return fmt.Errorf("%w (%s)", ErrUnknownDialect, dialect)
}
return nil
},
func(tx *sql.Tx, dialect string) error {
if dialect == "postgres" {
alters := [...]string{
"ADD COLUMN withheld_code VARCHAR(255)",
"ADD COLUMN withheld_reason TEXT",
"ALTER COLUMN signing_key DROP NOT NULL",
"ALTER COLUMN session DROP NOT NULL",
"ALTER COLUMN forwarding_chains DROP NOT NULL",
}
for _, alter := range alters {
_, err := tx.Exec(fmt.Sprintf("ALTER TABLE crypto_megolm_inbound_session %s", alter))
if err != nil {
return err
}
}
} else if dialect == "sqlite3" {
_, err := tx.Exec("ALTER TABLE crypto_megolm_inbound_session RENAME TO old_crypto_megolm_inbound_session")
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE crypto_megolm_inbound_session (
account_id VARCHAR(255) NOT NULL,
session_id CHAR(43) NOT NULL,
sender_key CHAR(43) NOT NULL,
signing_key CHAR(43),
room_id VARCHAR(255) NOT NULL,
session BLOB,
forwarding_chains BLOB,
withheld_code VARCHAR(255),
withheld_reason TEXT,
PRIMARY KEY (account_id, session_id)
)`)
if err != nil {
return err
}
_, err = tx.Exec(`INSERT INTO crypto_megolm_inbound_session
(session_id, sender_key, signing_key, room_id, session, forwarding_chains, account_id)
SELECT * FROM old_crypto_megolm_inbound_session`)
if err != nil {
return err
}
_, err = tx.Exec("DROP TABLE old_crypto_megolm_inbound_session")
if err != nil {
return err
}
} else {
return fmt.Errorf("%w (%s)", ErrUnknownDialect, dialect)
}
return nil
},
func(tx *sql.Tx, dialect string) error {
if _, err := tx.Exec(
`CREATE TABLE IF NOT EXISTS crypto_cross_signing_keys (
user_id VARCHAR(255) NOT NULL,
usage VARCHAR(20) NOT NULL,
key CHAR(43) NOT NULL,
PRIMARY KEY (user_id, usage)
)`,
); err != nil {
return err
}
if _, err := tx.Exec(
`CREATE TABLE IF NOT EXISTS crypto_cross_signing_signatures (
signed_user_id VARCHAR(255) NOT NULL,
signed_key VARCHAR(255) NOT NULL,
signer_user_id VARCHAR(255) NOT NULL,
signer_key VARCHAR(255) NOT NULL,
signature CHAR(88) NOT NULL,
PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key)
)`,
); err != nil {
return err
}
return nil
},
func(tx *sql.Tx, dialect string) error {
if dialect == "sqlite3" {
// SQLite doesn't enforce varchar sizes anyway
return nil
}
alters := [...]string{
`ALTER TABLE crypto_account ALTER COLUMN device_id TYPE TEXT`,
`ALTER TABLE crypto_account ALTER COLUMN account_id TYPE TEXT`,
//go:embed *.sql
var fs embed.FS
`ALTER TABLE crypto_device ALTER COLUMN user_id TYPE TEXT`,
`ALTER TABLE crypto_device ALTER COLUMN device_id TYPE TEXT`,
`ALTER TABLE crypto_device ALTER COLUMN name TYPE TEXT`,
`ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN room_id TYPE TEXT`,
`ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN account_id TYPE TEXT`,
`ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN withheld_code TYPE TEXT`,
`ALTER TABLE crypto_megolm_outbound_session ALTER COLUMN room_id TYPE TEXT`,
`ALTER TABLE crypto_megolm_outbound_session ALTER COLUMN account_id TYPE TEXT`,
`ALTER TABLE crypto_message_index ALTER COLUMN event_id TYPE TEXT`,
`ALTER TABLE crypto_olm_session ALTER COLUMN account_id TYPE TEXT`,
`ALTER TABLE crypto_tracked_user ALTER COLUMN user_id TYPE TEXT`,
`ALTER TABLE crypto_cross_signing_keys ALTER COLUMN user_id TYPE TEXT`,
`ALTER TABLE crypto_cross_signing_keys ALTER COLUMN usage TYPE TEXT`,
`ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signed_user_id TYPE TEXT`,
`ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signed_key TYPE TEXT`,
`ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signer_user_id TYPE TEXT`,
`ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signer_key TYPE TEXT`,
`ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signature TYPE TEXT`,
}
for _, alter := range alters {
_, err := tx.Exec(alter)
if err != nil {
return err
}
}
return nil
},
func(tx *sql.Tx, dialect string) error {
_, err := tx.Exec("ALTER TABLE crypto_olm_session RENAME COLUMN last_used TO last_decrypted")
if err != nil {
return err
}
_, err = tx.Exec("ALTER TABLE crypto_olm_session ADD COLUMN last_encrypted timestamp")
if err != nil {
return err
}
_, err = tx.Exec("UPDATE crypto_olm_session SET last_encrypted=last_decrypted")
if err != nil {
return err
}
if dialect == "postgres" {
// This is too hard to do on sqlite, so let's just do it on postgres
_, err = tx.Exec("ALTER TABLE crypto_olm_session ALTER COLUMN last_encrypted SET NOT NULL")
if err != nil {
return err
}
}
return nil
},
}
// GetVersion returns the current version of the DB schema.
func GetVersion(db *sql.DB) (int, error) {
_, err := db.Exec("CREATE TABLE IF NOT EXISTS crypto_version (version INTEGER)")
if err != nil {
return -1, err
}
version := 0
row := db.QueryRow("SELECT version FROM crypto_version LIMIT 1")
if row != nil {
_ = row.Scan(&version)
}
return version, nil
}
// SetVersion sets the schema version in a running DB transaction.
func SetVersion(tx *sql.Tx, version int) error {
_, err := tx.Exec("DELETE FROM crypto_version")
if err != nil {
return err
}
_, err = tx.Exec("INSERT INTO crypto_version (version) VALUES ($1)", version)
return err
func init() {
Table.Register(-1, 3, "Unsupported version", func(tx *sql.Tx, database *dbutil.Database) error {
return fmt.Errorf("upgrading from versions 1 and 2 of the crypto store is no longer supported in mautrix-go v0.12+")
})
Table.RegisterFS(fs)
}
// Upgrade upgrades the database from the current to the latest version available.
func Upgrade(db *sql.DB, dialect string) error {
version, err := GetVersion(db)
func Upgrade(sqlDB *sql.DB, dialect string) error {
db, err := dbutil.NewWithDB(sqlDB, dialect)
if err != nil {
return err
}
// perform migrations starting with #version
for ; version < len(Upgrades); version++ {
tx, err := db.Begin()
if err != nil {
return err
}
// run each migrate func
migrateFunc := Upgrades[version]
err = migrateFunc(tx, dialect)
if err != nil {
_ = tx.Rollback()
return err
}
// also update the version in this tx
if err = SetVersion(tx, version+1); err != nil {
return err
}
if err = tx.Commit(); err != nil {
return err
}
}
return nil
db.VersionTable = VersionTableName
db.UpgradeTable = Table
return db.Upgrade()
}

View file

@ -16,6 +16,7 @@ import (
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
)
const olmSessID = "sJlikQQKXp7UQjmS9/lyZCNUVJ2AmKyHbufPBaC7tpk"
@ -26,12 +27,16 @@ const olmPickled = "L6cdv3JYO9OzhXbcjNSwl7ldN5bDvwmGyin+hISePETE6bO71DIlhqTC9YIh
const groupSession = "9ZbsRqJuETbjnxPpKv29n3dubP/m5PSLbr9I9CIWS2O86F/Og1JZXhqT+4fA5tovoPfdpk5QLh7PfDyjmgOcO9sSA37maJyzCy6Ap+uBZLAXp6VLJ0mjSvxi+PAbzGKDMqpn+pa+oeEIH6SFPG/2GGDSRoXVi5fttAClCIoav5RflWiMypKqnQRfkZR2Gx8glOaBiTzAd7m0X6XGfYIPol41JUIHfBLuJBfXQ0Uu5GScV4eKUWdJP2J6zzC2Hx8cZAhiBBzAza0CbGcnUK+YJXMYaJg92HiIo++l317LlsYUJ/P+gKOLafYR9/l8bAzxH7j5s31PnRs7mD1Bl6G1LFM+dPsGXUOLx6PlvlTlYYM/opai0uKKzT0Wk6zPoq9fN/smlXEPBtKlw2fqcytL4gOF0MrBPEca"
func getCryptoStores(t *testing.T) (map[string]Store, func()) {
db, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000")
rawDB, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000")
if err != nil {
t.Fatalf("Error opening db: %v", err)
}
sqlStore := NewSQLCryptoStore(db, "sqlite3", "accid", id.DeviceID("dev"), []byte("test"), emptyLogger{})
if err = sqlStore.CreateTables(); err != nil {
db, err := dbutil.NewWithDB(rawDB, "sqlite3")
if err != nil {
t.Fatalf("Error opening db: %v", err)
}
sqlStore := NewSQLCryptoStore(db, "accid", id.DeviceID("dev"), []byte("test"))
if err = sqlStore.Upgrade(); err != nil {
t.Fatalf("Error creating tables: %v", err)
}

4
go.mod
View file

@ -12,7 +12,8 @@ require (
github.com/yuin/goldmark v1.4.12
golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9
golang.org/x/net v0.0.0-20220513224357-95641704303c
gopkg.in/yaml.v2 v2.4.0
gopkg.in/yaml.v3 v3.0.0
maunium.net/go/mauflag v1.0.0
maunium.net/go/maulogger/v2 v2.3.2
)
@ -21,5 +22,4 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
)

7
go.sum
View file

@ -38,9 +38,10 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0 h1:hjy8E9ON/egN1tAYqKb61G10WtihqetD4sz2H+8nIeA=
gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
maunium.net/go/maulogger/v2 v2.3.2 h1:1XmIYmMd3PoQfp9J+PaHhpt80zpfmMqaShzUTC7FwY0=
maunium.net/go/maulogger/v2 v2.3.2/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A=

View file

@ -28,6 +28,20 @@ type BaseUpgrader interface {
GetBase() string
}
type StructUpgrader struct {
SimpleUpgrader
Blocks [][]string
Base string
}
func (su *StructUpgrader) SpacedBlocks() [][]string {
return su.Blocks
}
func (su *StructUpgrader) GetBase() string {
return su.Base
}
type SimpleUpgrader func(helper *Helper)
func (su SimpleUpgrader) DoUpgrade(helper *Helper) {

132
util/dbutil/database.go Normal file
View file

@ -0,0 +1,132 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package dbutil
import (
"database/sql"
"fmt"
"strings"
"time"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/bridge/bridgeconfig"
)
type Dialect int
const (
DialectUnknown Dialect = iota
Postgres
SQLite
)
func (dialect Dialect) String() string {
switch dialect {
case Postgres:
return "postgres"
case SQLite:
return "sqlite3"
default:
return ""
}
}
func ParseDialect(engine string) (Dialect, error) {
switch strings.ToLower(engine) {
case "postgres", "postgresql":
return Postgres, nil
case "sqlite3", "sqlite":
return SQLite, nil
default:
return DialectUnknown, fmt.Errorf("unknown dialect '%s'", engine)
}
}
type Scannable interface {
Scan(...interface{}) error
}
type Database struct {
*sql.DB
Owner string
VersionTable string
Log log.Logger
Dialect Dialect
UpgradeTable UpgradeTable
IgnoreForeignTables bool
IgnoreUnsupportedDatabase bool
}
func (db *Database) Child(logName, versionTable string, upgradeTable UpgradeTable) *Database {
return &Database{
DB: db.DB,
Owner: "",
VersionTable: versionTable,
UpgradeTable: upgradeTable,
Log: db.Log.Sub(logName),
Dialect: db.Dialect,
IgnoreForeignTables: true,
IgnoreUnsupportedDatabase: db.IgnoreUnsupportedDatabase,
}
}
func NewWithDB(db *sql.DB, rawDialect string) (*Database, error) {
dialect, err := ParseDialect(rawDialect)
if err != nil {
return nil, err
}
return &Database{
DB: db,
Dialect: dialect,
Log: log.Sub("Database"),
IgnoreForeignTables: true,
VersionTable: "version",
}, nil
}
func NewFromConfig(owner string, cfg bridgeconfig.DatabaseConfig, dbLog log.Logger) (*Database, error) {
dialect, err := ParseDialect(cfg.Type)
if err != nil {
return nil, err
}
conn, err := sql.Open(cfg.Type, cfg.URI)
if err != nil {
return nil, err
}
conn.SetMaxOpenConns(cfg.MaxOpenConns)
conn.SetMaxIdleConns(cfg.MaxIdleConns)
if len(cfg.ConnMaxIdleTime) > 0 {
maxIdleTimeDuration, err := time.ParseDuration(cfg.ConnMaxIdleTime)
if err != nil {
return nil, fmt.Errorf("failed to parse max_conn_idle_time: %w", err)
}
conn.SetConnMaxIdleTime(maxIdleTimeDuration)
}
if len(cfg.ConnMaxLifetime) > 0 {
maxLifetimeDuration, err := time.ParseDuration(cfg.ConnMaxLifetime)
if err != nil {
return nil, fmt.Errorf("failed to parse max_conn_idle_time: %w", err)
}
conn.SetConnMaxLifetime(maxLifetimeDuration)
}
if dbLog == nil {
dbLog = log.Sub("Database")
}
return &Database{
DB: conn,
Owner: owner,
Log: dbLog,
Dialect: dialect,
IgnoreForeignTables: true,
VersionTable: "version",
}, nil
}

154
util/dbutil/upgrades.go Normal file
View file

@ -0,0 +1,154 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package dbutil
import (
"database/sql"
"errors"
"fmt"
log "maunium.net/go/maulogger/v2"
)
type upgradeFunc func(*sql.Tx, *Database) error
type upgrade struct {
message string
fn upgradeFunc
upgradesTo int
}
type Upgrader struct {
*sql.DB
Log log.Logger
Dialect Dialect
upgrades []upgrade
}
var ErrUnsupportedDatabaseVersion = fmt.Errorf("unsupported database schema version")
var ErrForeignTables = fmt.Errorf("the database contains foreign tables")
var ErrNotOwned = fmt.Errorf("the database is owned by")
func (db *Database) getVersion() (int, error) {
_, err := db.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version INTEGER)", db.VersionTable))
if err != nil {
return -1, err
}
version := 0
err = db.QueryRow(fmt.Sprintf("SELECT version FROM %s LIMIT 1", db.VersionTable)).Scan(&version)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return -1, err
}
return version, nil
}
const tableExistsPostgres = "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name=$1)"
const tableExistsSQLite = "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND table_name=$1)"
func (db *Database) tableExists(table string) (exists bool) {
if db.Dialect == SQLite {
_ = db.QueryRow(tableExistsSQLite, table).Scan(&exists)
} else if db.Dialect == Postgres {
_ = db.QueryRow(tableExistsPostgres, table).Scan(&exists)
}
return
}
const createOwnerTable = `
CREATE TABLE IF NOT EXISTS database_owner (
key INTEGER PRIMARY KEY DEFAULT 0,
owner TEXT NOT NULL
)
`
func (db *Database) checkDatabaseOwner() error {
var owner string
if !db.IgnoreForeignTables {
if db.tableExists("state_groups_state") {
return fmt.Errorf("%w (found state_groups_state, likely belonging to Synapse)", ErrForeignTables)
} else if db.tableExists("goose_db_version") {
return fmt.Errorf("%w (found goose_db_version, possibly belonging to Dendrite)", ErrForeignTables)
}
}
if db.Owner == "" {
return nil
}
if _, err := db.Exec(createOwnerTable); err != nil {
return fmt.Errorf("failed to ensure database owner table exists: %w", err)
} else if err = db.QueryRow("SELECT owner FROM database_owner WHERE key=0").Scan(&owner); errors.Is(err, sql.ErrNoRows) {
_, err = db.Exec("INSERT INTO database_owner (owner) VALUES ($1)", db.Owner)
if err != nil {
return fmt.Errorf("failed to insert database owner: %w", err)
}
} else if err != nil {
return fmt.Errorf("failed to check database owner: %w", err)
} else if owner != db.Owner {
return fmt.Errorf("%w %s", ErrNotOwned, owner)
}
return nil
}
func (db *Database) setVersion(tx *sql.Tx, version int) error {
_, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", db.VersionTable))
if err != nil {
return err
}
_, err = tx.Exec(fmt.Sprintf("INSERT INTO %s (version) VALUES ($1)", db.VersionTable), version)
return err
}
func (db *Database) Upgrade() error {
err := db.checkDatabaseOwner()
if err != nil {
return err
}
version, err := db.getVersion()
if err != nil {
return err
}
if version > len(db.UpgradeTable) {
warning := fmt.Sprintf("currently on v%d, latest known: v%d", version, len(db.UpgradeTable))
if db.IgnoreUnsupportedDatabase {
db.Log.Warnfln("Unsupported database schema version: %s - continuing anyway", warning)
return nil
}
return fmt.Errorf("%w: %s", ErrUnsupportedDatabaseVersion, warning)
}
db.Log.Infofln("Database currently on v%d, latest: v%d", version, len(db.UpgradeTable))
upgradesToApply := db.UpgradeTable[version:]
for _, upgradeItem := range upgradesToApply {
if upgradeItem.fn == nil {
continue
}
db.Log.Infofln("Upgrading database from v%d to v%d: %s", version, upgradeItem.upgradesTo, upgradeItem.message)
var tx *sql.Tx
tx, err = db.Begin()
if err != nil {
return err
}
err = upgradeItem.fn(tx, db)
if err != nil {
return err
}
version = upgradeItem.upgradesTo
err = db.setVersion(tx, version)
if err != nil {
return err
}
err = tx.Commit()
if err != nil {
return err
}
}
return nil
}

159
util/dbutil/upgradetable.go Normal file
View file

@ -0,0 +1,159 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package dbutil
import (
"bytes"
"database/sql"
"errors"
"fmt"
"io/fs"
"regexp"
"strconv"
)
type UpgradeTable []upgrade
func (ut *UpgradeTable) extend(toSize int) {
if cap(*ut) >= toSize {
*ut = (*ut)[:toSize]
} else {
resized := make([]upgrade, toSize)
copy(resized, *ut)
*ut = resized
}
}
func (ut *UpgradeTable) Register(from, to int, message string, fn upgradeFunc) {
if from < 0 {
from += to
}
if from < 0 {
panic("invalid from value in UpgradeTable.Register() call")
}
upg := upgrade{message: message, fn: fn, upgradesTo: to}
if len(*ut) == from {
*ut = append(*ut, upg)
return
} else if len(*ut) < from {
ut.extend(from + 1)
} else if (*ut)[from].fn != nil {
panic(fmt.Errorf("tried to override upgrade at %d ('%s') with '%s'", from, (*ut)[from].message, upg.message))
}
(*ut)[from] = upg
}
// Syntax is either
// -- v0 -> v1: Message
// or
// -- v1: Message
var upgradeHeaderRegex = regexp.MustCompile(`^-- (?:v(\d+) -> )?v(\d+): (.+)$`)
func parseFileHeader(file []byte) (from, to int, message string, lines [][]byte, err error) {
lines = bytes.Split(file, []byte("\n"))
if len(lines) < 2 {
err = errors.New("upgrade file too short")
return
}
var maybeFrom int
match := upgradeHeaderRegex.FindSubmatch(lines[0])
lines = lines[1:]
if match == nil {
err = errors.New("header not found")
} else if len(match) != 4 {
err = errors.New("unexpected number of items in regex match")
} else if maybeFrom, err = strconv.Atoi(string(match[1])); len(match[1]) > 0 && err != nil {
err = fmt.Errorf("invalid source version: %w", err)
} else if to, err = strconv.Atoi(string(match[2])); err != nil {
err = fmt.Errorf("invalid target version: %w", err)
} else {
if len(match[1]) > 0 {
from = maybeFrom
} else {
from = -1
}
message = string(match[3])
}
return
}
// To limit the next line to one dialect:
// -- only: postgres
// To limit the next N lines:
// -- only: sqlite for next 123 lines
// If the single-line limit is on the second line of the file, the whole file is limited to that dialect.
var dialectLineFilter = regexp.MustCompile(`^\s*-- only: (postgres|sqlite)(?: for next (\d+) lines)?`)
func (db *Database) parseDialectFilter(line []byte) (int, error) {
match := dialectLineFilter.FindSubmatch(line)
if match != nil {
dialect, err := ParseDialect(string(match[1]))
if err != nil {
return 0, err
} else if dialect != db.Dialect {
if len(match[2]) == 0 {
return 1, nil
} else {
lineCount, err := strconv.Atoi(string(match[2]))
if err != nil {
return 0, fmt.Errorf("invalid line count '%s': %w", match[2], err)
}
return lineCount, nil
}
}
}
return 0, nil
}
func (db *Database) mutateSQLUpgrade(lines [][]byte) (string, error) {
output := lines[:0]
for i := 0; i < len(lines); i++ {
skipLines, err := db.parseDialectFilter(lines[i])
if err != nil {
return "", err
} else if skipLines > 0 {
i += skipLines
} else {
output = append(output, lines[i])
}
}
return string(bytes.Join(output, []byte("\n"))), nil
}
func sqlUpgradeFunc(fileName string, lines [][]byte) upgradeFunc {
return func(tx *sql.Tx, db *Database) error {
if skip, err := db.parseDialectFilter(lines[0]); err == nil && skip == 1 {
return nil
} else if upgradeSQL, err := db.mutateSQLUpgrade(lines); err != nil {
panic(fmt.Errorf("failed to parse upgrade %s: %w", fileName, err))
} else {
_, err = tx.Exec(upgradeSQL)
return err
}
}
}
type fullFS interface {
fs.ReadFileFS
fs.ReadDirFS
}
func (ut *UpgradeTable) RegisterFS(fs fullFS) {
files, err := fs.ReadDir(".")
if err != nil {
panic(err)
}
for _, file := range files {
if data, err := fs.ReadFile(file.Name()); err != nil {
panic(err)
} else if from, to, message, lines, err := parseFileHeader(data); err != nil {
panic(fmt.Errorf("failed to parse header in %s: %w", file.Name(), err))
} else {
ut.Register(from, to, message, sqlUpgradeFunc(file.Name(), lines))
}
}
}