diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 13940d79..ca75b3f6 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -664,6 +664,20 @@ func (store *SQLCryptoStore) IsOutboundGroupSessionShared(ctx context.Context, u // ValidateMessageIndex returns whether the given event information match the ones stored in the database // for the given sender key, session ID and index. If the index hasn't been stored, this will store it. func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) { + if eventID == "" && timestamp == 0 { + var notOK bool + const validateEmptyQuery = ` + SELECT EXISTS(SELECT 1 FROM crypto_message_index WHERE sender_key=$1 AND session_id=$2 AND "index"=$3) + ` + err := store.DB.QueryRow(ctx, validateEmptyQuery, senderKey, sessionID, index).Scan(¬OK) + if notOK { + zerolog.Ctx(ctx).Debug(). + Uint("message_index", index). + Msg("Rejecting event without event ID and timestamp due to already knowing them") + } + return !notOK, err + } + const validateQuery = ` INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp) VALUES ($1, $2, $3, $4, $5) diff --git a/crypto/store.go b/crypto/store.go index 8b7c0a96..7620cf35 100644 --- a/crypto/store.go +++ b/crypto/store.go @@ -525,6 +525,9 @@ func (gs *MemoryStore) ValidateMessageIndex(_ context.Context, senderKey id.Send } val, ok := gs.MessageIndices[key] if !ok { + if eventID == "" && timestamp == 0 { + return true, nil + } gs.MessageIndices[key] = messageIndexValue{ EventID: eventID, Timestamp: timestamp, diff --git a/crypto/store_test.go b/crypto/store_test.go index 8aeae7af..7a47243e 100644 --- a/crypto/store_test.go +++ b/crypto/store_test.go @@ -75,8 +75,13 @@ func TestValidateMessageIndex(t *testing.T) { t.Run(storeName, func(t *testing.T) { acc := NewOlmAccount() + // Validating without event ID and timestamp before we have them should work + ok, err := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "", 0, 0) + require.NoError(t, err, "Error validating message index") + assert.True(t, ok, "First message validation should be valid") + // First message should validate successfully - ok, err := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000) + ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000) require.NoError(t, err, "Error validating message index") assert.True(t, ok, "First message validation should be valid") @@ -94,6 +99,11 @@ func TestValidateMessageIndex(t *testing.T) { ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000) require.NoError(t, err, "Error validating message index") assert.True(t, ok, "First message validation should be valid") + + // Validating without event ID and timestamp must fail if we already know them + ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "", 0, 0) + require.NoError(t, err, "Error validating message index") + assert.False(t, ok, "First message validation should be invalid") }) } }