diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go index 275b2051..9ad5f77c 100644 --- a/bridgev2/commands/startchat.go +++ b/bridgev2/commands/startchat.go @@ -101,7 +101,7 @@ func fnResolveIdentifier(ce *Event) { } portal := resp.Chat.Portal if portal == nil { - portal, err = ce.Bridge.GetPortalByID(ce.Ctx, resp.Chat.PortalID) + portal, err = ce.Bridge.GetPortalByKey(ce.Ctx, resp.Chat.PortalKey) if err != nil { ce.Log.Err(err).Msg("Failed to get portal") ce.Reply("Failed to get portal: %v", err) diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index 503d8a62..417035f0 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -64,7 +64,7 @@ const ( metadata FROM portal ` - getPortalByIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND receiver=$3` + getPortalByKeyQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND receiver=$3` getPortalByIDWithUncertainReceiverQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='')` getPortalByMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` getChildPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND parent_id=$2` @@ -99,8 +99,8 @@ const ( reIDPortalQuery = `UPDATE portal SET id=$4, receiver=$5 WHERE bridge_id=$1 AND id=$2 AND receiver=$3` ) -func (pq *PortalQuery) GetByID(ctx context.Context, key networkid.PortalKey) (*Portal, error) { - return pq.QueryOne(ctx, getPortalByIDQuery, pq.BridgeID, key.ID, key.Receiver) +func (pq *PortalQuery) GetByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) { + return pq.QueryOne(ctx, getPortalByKeyQuery, pq.BridgeID, key.ID, key.Receiver) } func (pq *PortalQuery) FindReceiver(ctx context.Context, id networkid.PortalID, maybeReceiver networkid.UserLoginID) (key networkid.PortalKey, err error) { diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 197670d4..c0eb16e4 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -416,7 +416,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. } if resp.Chat != nil { if resp.Chat.Portal == nil { - resp.Chat.Portal, err = prov.br.Bridge.GetPortalByID(r.Context(), resp.Chat.PortalID) + resp.Chat.Portal, err = prov.br.Bridge.GetPortalByKey(r.Context(), resp.Chat.PortalKey) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get portal") jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ @@ -498,7 +498,7 @@ func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Reque } if contact.Chat != nil { if contact.Chat.Portal == nil { - contact.Chat.Portal, err = prov.br.Bridge.GetPortalByID(r.Context(), contact.Chat.PortalID) + contact.Chat.Portal, err = prov.br.Bridge.GetPortalByKey(r.Context(), contact.Chat.PortalKey) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get portal") jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 6ee84e4a..43842bdd 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -504,8 +504,8 @@ type ResolveIdentifierResponse struct { } type CreateChatResponse struct { - PortalID networkid.PortalKey - // Portal and PortalInfo are not required, the caller will fetch them automatically based on PortalID if necessary. + PortalKey networkid.PortalKey + // Portal and PortalInfo are not required, the caller will fetch them automatically based on PortalKey if necessary. Portal *Portal PortalInfo *ChatInfo } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index c5054022..81b65ad2 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -95,7 +95,7 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que } var err error if portal.ParentID != "" { - portal.Parent, err = br.UnlockedGetPortalByID(ctx, networkid.PortalKey{ID: portal.ParentID}, false) + portal.Parent, err = br.UnlockedGetPortalByKey(ctx, networkid.PortalKey{ID: portal.ParentID}, false) if err != nil { return nil, fmt.Errorf("failed to load parent portal (%s): %w", portal.ParentID, err) } @@ -119,17 +119,17 @@ func (portal *Portal) updateLogger() { portal.Log = logWith.Logger() } -func (br *Bridge) UnlockedGetPortalByID(ctx context.Context, id networkid.PortalKey, onlyIfExists bool) (*Portal, error) { - cached, ok := br.portalsByKey[id] +func (br *Bridge) UnlockedGetPortalByKey(ctx context.Context, key networkid.PortalKey, onlyIfExists bool) (*Portal, error) { + cached, ok := br.portalsByKey[key] if ok { return cached, nil } - idPtr := &id + keyPtr := &key if onlyIfExists { - idPtr = nil + keyPtr = nil } - db, err := br.DB.Portal.GetByID(ctx, id) - return br.loadPortal(ctx, db, err, idPtr) + db, err := br.DB.Portal.GetByKey(ctx, key) + return br.loadPortal(ctx, db, err, keyPtr) } func (br *Bridge) FindPortalReceiver(ctx context.Context, id networkid.PortalID, maybeReceiver networkid.UserLoginID) (networkid.PortalKey, error) { @@ -172,27 +172,27 @@ func (br *Bridge) GetPortalByMXID(ctx context.Context, mxid id.RoomID) (*Portal, return br.loadPortal(ctx, db, err, nil) } -func (br *Bridge) GetPortalByID(ctx context.Context, id networkid.PortalKey) (*Portal, error) { +func (br *Bridge) GetPortalByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) { br.cacheLock.Lock() defer br.cacheLock.Unlock() - return br.UnlockedGetPortalByID(ctx, id, false) + return br.UnlockedGetPortalByKey(ctx, key, false) } -func (br *Bridge) GetExistingPortalByID(ctx context.Context, id networkid.PortalKey) (*Portal, error) { +func (br *Bridge) GetExistingPortalByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) { br.cacheLock.Lock() defer br.cacheLock.Unlock() - if id.Receiver == "" { - return br.UnlockedGetPortalByID(ctx, id, true) + if key.Receiver == "" { + return br.UnlockedGetPortalByKey(ctx, key, true) } - cached, ok := br.portalsByKey[id] + cached, ok := br.portalsByKey[key] if ok { return cached, nil } - cached, ok = br.portalsByKey[networkid.PortalKey{ID: id.ID}] + cached, ok = br.portalsByKey[networkid.PortalKey{ID: key.ID}] if ok { return cached, nil } - db, err := br.DB.Portal.GetByIDWithUncertainReceiver(ctx, id) + db, err := br.DB.Portal.GetByIDWithUncertainReceiver(ctx, key) return br.loadPortal(ctx, db, err, nil) } @@ -2178,7 +2178,7 @@ func (portal *Portal) UpdateParent(ctx context.Context, newParent networkid.Port portal.ParentID = newParent portal.InSpace = false if newParent != "" { - portal.Parent, err = portal.Bridge.GetPortalByID(ctx, networkid.PortalKey{ID: newParent}) + portal.Parent, err = portal.Bridge.GetPortalByKey(ctx, networkid.PortalKey{ID: newParent}) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to get new parent portal") } diff --git a/bridgev2/portalreid.go b/bridgev2/portalreid.go index c4f7a69b..a25fe820 100644 --- a/bridgev2/portalreid.go +++ b/bridgev2/portalreid.go @@ -39,7 +39,7 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta }() br.cacheLock.Lock() defer br.cacheLock.Unlock() - sourcePortal, err := br.UnlockedGetPortalByID(ctx, source, true) + sourcePortal, err := br.UnlockedGetPortalByKey(ctx, source, true) if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to get source portal: %w", err) } else if sourcePortal == nil { @@ -59,7 +59,7 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Stringer("source_portal_mxid", sourcePortal.MXID) }) - targetPortal, err := br.UnlockedGetPortalByID(ctx, target, true) + targetPortal, err := br.UnlockedGetPortalByKey(ctx, target, true) if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to get target portal: %w", err) } diff --git a/bridgev2/queue.go b/bridgev2/queue.go index ec60cbb8..6254fd62 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -139,7 +139,7 @@ func (br *Bridge) handleBotInvite(ctx context.Context, evt *event.Event, sender func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) { log := login.Log ctx := log.WithContext(context.TODO()) - portal, err := br.GetPortalByID(ctx, evt.GetPortalKey()) + portal, err := br.GetPortalByKey(ctx, evt.GetPortalKey()) if err != nil { log.Err(err).Object("portal_id", evt.GetPortalKey()). Msg("Failed to get portal to handle remote event")