mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
federation/pdu: add server name parameter to GetKeyFunc
This commit is contained in:
parent
fd20a61d87
commit
363aa94389
6 changed files with 13 additions and 10 deletions
|
|
@ -33,7 +33,7 @@ import (
|
|||
// valid at or after this time, but if that is not possible, the latest available key should be
|
||||
// returned without an error. The verify function will do its own validity checking based on the
|
||||
// returned valid until timestamp.
|
||||
type GetKeyFunc = func(keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error)
|
||||
type GetKeyFunc = func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error)
|
||||
|
||||
type AnyPDU interface {
|
||||
GetRoomID() (id.RoomID, error)
|
||||
|
|
|
|||
|
|
@ -28,7 +28,10 @@ type serverDetails struct {
|
|||
keys map[id.KeyID]serverKey
|
||||
}
|
||||
|
||||
func (sd serverDetails) getKey(keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) {
|
||||
func (sd serverDetails) getKey(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) {
|
||||
if serverName != sd.serverName {
|
||||
return "", time.Time{}, nil
|
||||
}
|
||||
key, ok := sd.keys[keyID]
|
||||
if ok {
|
||||
return key.key, key.validUntilTS, nil
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ func (pdu *PDU) VerifySignature(roomVersion id.RoomVersion, serverName string, g
|
|||
verified := false
|
||||
for keyID, sig := range pdu.Signatures[serverName] {
|
||||
originServerTS := time.UnixMilli(pdu.OriginServerTS)
|
||||
key, validUntil, err := getKey(keyID, originServerTS)
|
||||
key, validUntil, err := getKey(serverName, keyID, originServerTS)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err)
|
||||
} else if key == "" {
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ func TestPDU_VerifySignature(t *testing.T) {
|
|||
func TestPDU_VerifySignature_Fail_NoKey(t *testing.T) {
|
||||
test := roomV12MessageTestPDU
|
||||
parsed := parsePDU(test.pdu)
|
||||
err := parsed.VerifySignature(test.roomVersion, test.serverName, func(keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
|
||||
err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
|
||||
return
|
||||
})
|
||||
assert.Error(t, err)
|
||||
|
|
@ -45,7 +45,7 @@ func TestPDU_VerifySignature_Fail_NoKey(t *testing.T) {
|
|||
func TestPDU_VerifySignature_V4ExpiredKey(t *testing.T) {
|
||||
test := roomV4MessageTestPDU
|
||||
parsed := parsePDU(test.pdu)
|
||||
err := parsed.VerifySignature(test.roomVersion, test.serverName, func(keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
|
||||
err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
|
||||
key = test.keys[keyID].key
|
||||
validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
return
|
||||
|
|
@ -56,7 +56,7 @@ func TestPDU_VerifySignature_V4ExpiredKey(t *testing.T) {
|
|||
func TestPDU_VerifySignature_V12ExpiredKey(t *testing.T) {
|
||||
test := roomV12MessageTestPDU
|
||||
parsed := parsePDU(test.pdu)
|
||||
err := parsed.VerifySignature(test.roomVersion, test.serverName, func(keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
|
||||
err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
|
||||
key = test.keys[keyID].key
|
||||
validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
return
|
||||
|
|
@ -90,8 +90,8 @@ func TestPDU_Sign(t *testing.T) {
|
|||
}
|
||||
err := evt.Sign(id.RoomV12, "example.com", "ed25519:rand", privKey)
|
||||
require.NoError(t, err)
|
||||
err = evt.VerifySignature(id.RoomV11, "example.com", func(keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
|
||||
if keyID == "ed25519:rand" {
|
||||
err = evt.VerifySignature(id.RoomV11, "example.com", func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
|
||||
if serverName == "example.com" && keyID == "ed25519:rand" {
|
||||
key = id.SigningKey(base64.RawStdEncoding.EncodeToString(pubKey))
|
||||
validUntil = time.Now()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -202,7 +202,7 @@ func (pdu *RoomV1PDU) VerifySignature(roomVersion id.RoomVersion, serverName str
|
|||
verified := false
|
||||
for keyID, sig := range pdu.Signatures[serverName] {
|
||||
originServerTS := time.UnixMilli(pdu.OriginServerTS)
|
||||
key, _, err := getKey(keyID, originServerTS)
|
||||
key, _, err := getKey(serverName, keyID, originServerTS)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err)
|
||||
} else if key == "" {
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ func TestRoomV1PDU_VerifySignature(t *testing.T) {
|
|||
for _, test := range testV1PDUs {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
parsed := parseV1PDU(test.pdu)
|
||||
err := parsed.VerifySignature(test.roomVersion, test.serverName, func(keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) {
|
||||
err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) {
|
||||
key, ok := test.keys[keyID]
|
||||
if ok {
|
||||
return key.key, key.validUntilTS, nil
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue