From 363aa943895876bff46f2cae3a310ace981058f5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 23 Aug 2025 03:13:10 +0300 Subject: [PATCH] federation/pdu: add server name parameter to GetKeyFunc --- federation/pdu/pdu.go | 2 +- federation/pdu/pdu_test.go | 5 ++++- federation/pdu/signature.go | 2 +- federation/pdu/signature_test.go | 10 +++++----- federation/pdu/v1.go | 2 +- federation/pdu/v1_test.go | 2 +- 6 files changed, 13 insertions(+), 10 deletions(-) diff --git a/federation/pdu/pdu.go b/federation/pdu/pdu.go index 0e63ea7c..c6faf3d0 100644 --- a/federation/pdu/pdu.go +++ b/federation/pdu/pdu.go @@ -33,7 +33,7 @@ import ( // valid at or after this time, but if that is not possible, the latest available key should be // returned without an error. The verify function will do its own validity checking based on the // returned valid until timestamp. -type GetKeyFunc = func(keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) +type GetKeyFunc = func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) type AnyPDU interface { GetRoomID() (id.RoomID, error) diff --git a/federation/pdu/pdu_test.go b/federation/pdu/pdu_test.go index 93244741..59d7c3a6 100644 --- a/federation/pdu/pdu_test.go +++ b/federation/pdu/pdu_test.go @@ -28,7 +28,10 @@ type serverDetails struct { keys map[id.KeyID]serverKey } -func (sd serverDetails) getKey(keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) { +func (sd serverDetails) getKey(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) { + if serverName != sd.serverName { + return "", time.Time{}, nil + } key, ok := sd.keys[keyID] if ok { return key.key, key.validUntilTS, nil diff --git a/federation/pdu/signature.go b/federation/pdu/signature.go index 1f8ae0b5..a7685cc6 100644 --- a/federation/pdu/signature.go +++ b/federation/pdu/signature.go @@ -46,7 +46,7 @@ func (pdu *PDU) VerifySignature(roomVersion id.RoomVersion, serverName string, g verified := false for keyID, sig := range pdu.Signatures[serverName] { originServerTS := time.UnixMilli(pdu.OriginServerTS) - key, validUntil, err := getKey(keyID, originServerTS) + key, validUntil, err := getKey(serverName, keyID, originServerTS) if err != nil { return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err) } else if key == "" { diff --git a/federation/pdu/signature_test.go b/federation/pdu/signature_test.go index 68e7a773..01df5076 100644 --- a/federation/pdu/signature_test.go +++ b/federation/pdu/signature_test.go @@ -36,7 +36,7 @@ func TestPDU_VerifySignature(t *testing.T) { func TestPDU_VerifySignature_Fail_NoKey(t *testing.T) { test := roomV12MessageTestPDU parsed := parsePDU(test.pdu) - err := parsed.VerifySignature(test.roomVersion, test.serverName, func(keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { + err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { return }) assert.Error(t, err) @@ -45,7 +45,7 @@ func TestPDU_VerifySignature_Fail_NoKey(t *testing.T) { func TestPDU_VerifySignature_V4ExpiredKey(t *testing.T) { test := roomV4MessageTestPDU parsed := parsePDU(test.pdu) - err := parsed.VerifySignature(test.roomVersion, test.serverName, func(keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { + err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { key = test.keys[keyID].key validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) return @@ -56,7 +56,7 @@ func TestPDU_VerifySignature_V4ExpiredKey(t *testing.T) { func TestPDU_VerifySignature_V12ExpiredKey(t *testing.T) { test := roomV12MessageTestPDU parsed := parsePDU(test.pdu) - err := parsed.VerifySignature(test.roomVersion, test.serverName, func(keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { + err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { key = test.keys[keyID].key validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) return @@ -90,8 +90,8 @@ func TestPDU_Sign(t *testing.T) { } err := evt.Sign(id.RoomV12, "example.com", "ed25519:rand", privKey) require.NoError(t, err) - err = evt.VerifySignature(id.RoomV11, "example.com", func(keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { - if keyID == "ed25519:rand" { + err = evt.VerifySignature(id.RoomV11, "example.com", func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { + if serverName == "example.com" && keyID == "ed25519:rand" { key = id.SigningKey(base64.RawStdEncoding.EncodeToString(pubKey)) validUntil = time.Now() } diff --git a/federation/pdu/v1.go b/federation/pdu/v1.go index fc958b03..0e4c95e9 100644 --- a/federation/pdu/v1.go +++ b/federation/pdu/v1.go @@ -202,7 +202,7 @@ func (pdu *RoomV1PDU) VerifySignature(roomVersion id.RoomVersion, serverName str verified := false for keyID, sig := range pdu.Signatures[serverName] { originServerTS := time.UnixMilli(pdu.OriginServerTS) - key, _, err := getKey(keyID, originServerTS) + key, _, err := getKey(serverName, keyID, originServerTS) if err != nil { return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err) } else if key == "" { diff --git a/federation/pdu/v1_test.go b/federation/pdu/v1_test.go index e5531b0b..ecf2dbd2 100644 --- a/federation/pdu/v1_test.go +++ b/federation/pdu/v1_test.go @@ -73,7 +73,7 @@ func TestRoomV1PDU_VerifySignature(t *testing.T) { for _, test := range testV1PDUs { t.Run(test.name, func(t *testing.T) { parsed := parseV1PDU(test.pdu) - err := parsed.VerifySignature(test.roomVersion, test.serverName, func(keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) { + err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) { key, ok := test.keys[keyID] if ok { return key.key, key.validUntilTS, nil