federation/pdu: add server name parameter to GetKeyFunc
Some checks are pending
Go / Lint (latest) (push) Waiting to run
Go / Build (old, libolm) (push) Waiting to run
Go / Build (latest, libolm) (push) Waiting to run
Go / Build (old, goolm) (push) Waiting to run
Go / Build (latest, goolm) (push) Waiting to run

This commit is contained in:
Tulir Asokan 2025-08-23 03:13:10 +03:00
commit 363aa94389
6 changed files with 13 additions and 10 deletions

View file

@ -33,7 +33,7 @@ import (
// valid at or after this time, but if that is not possible, the latest available key should be
// returned without an error. The verify function will do its own validity checking based on the
// returned valid until timestamp.
type GetKeyFunc = func(keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error)
type GetKeyFunc = func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error)
type AnyPDU interface {
GetRoomID() (id.RoomID, error)

View file

@ -28,7 +28,10 @@ type serverDetails struct {
keys map[id.KeyID]serverKey
}
func (sd serverDetails) getKey(keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) {
func (sd serverDetails) getKey(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) {
if serverName != sd.serverName {
return "", time.Time{}, nil
}
key, ok := sd.keys[keyID]
if ok {
return key.key, key.validUntilTS, nil

View file

@ -46,7 +46,7 @@ func (pdu *PDU) VerifySignature(roomVersion id.RoomVersion, serverName string, g
verified := false
for keyID, sig := range pdu.Signatures[serverName] {
originServerTS := time.UnixMilli(pdu.OriginServerTS)
key, validUntil, err := getKey(keyID, originServerTS)
key, validUntil, err := getKey(serverName, keyID, originServerTS)
if err != nil {
return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err)
} else if key == "" {

View file

@ -36,7 +36,7 @@ func TestPDU_VerifySignature(t *testing.T) {
func TestPDU_VerifySignature_Fail_NoKey(t *testing.T) {
test := roomV12MessageTestPDU
parsed := parsePDU(test.pdu)
err := parsed.VerifySignature(test.roomVersion, test.serverName, func(keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
return
})
assert.Error(t, err)
@ -45,7 +45,7 @@ func TestPDU_VerifySignature_Fail_NoKey(t *testing.T) {
func TestPDU_VerifySignature_V4ExpiredKey(t *testing.T) {
test := roomV4MessageTestPDU
parsed := parsePDU(test.pdu)
err := parsed.VerifySignature(test.roomVersion, test.serverName, func(keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
key = test.keys[keyID].key
validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
return
@ -56,7 +56,7 @@ func TestPDU_VerifySignature_V4ExpiredKey(t *testing.T) {
func TestPDU_VerifySignature_V12ExpiredKey(t *testing.T) {
test := roomV12MessageTestPDU
parsed := parsePDU(test.pdu)
err := parsed.VerifySignature(test.roomVersion, test.serverName, func(keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
key = test.keys[keyID].key
validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
return
@ -90,8 +90,8 @@ func TestPDU_Sign(t *testing.T) {
}
err := evt.Sign(id.RoomV12, "example.com", "ed25519:rand", privKey)
require.NoError(t, err)
err = evt.VerifySignature(id.RoomV11, "example.com", func(keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
if keyID == "ed25519:rand" {
err = evt.VerifySignature(id.RoomV11, "example.com", func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
if serverName == "example.com" && keyID == "ed25519:rand" {
key = id.SigningKey(base64.RawStdEncoding.EncodeToString(pubKey))
validUntil = time.Now()
}

View file

@ -202,7 +202,7 @@ func (pdu *RoomV1PDU) VerifySignature(roomVersion id.RoomVersion, serverName str
verified := false
for keyID, sig := range pdu.Signatures[serverName] {
originServerTS := time.UnixMilli(pdu.OriginServerTS)
key, _, err := getKey(keyID, originServerTS)
key, _, err := getKey(serverName, keyID, originServerTS)
if err != nil {
return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err)
} else if key == "" {

View file

@ -73,7 +73,7 @@ func TestRoomV1PDU_VerifySignature(t *testing.T) {
for _, test := range testV1PDUs {
t.Run(test.name, func(t *testing.T) {
parsed := parseV1PDU(test.pdu)
err := parsed.VerifySignature(test.roomVersion, test.serverName, func(keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) {
err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) {
key, ok := test.keys[keyID]
if ok {
return key.key, key.validUntilTS, nil