From 3d4ddf716193ed61d9e7301edd23df9b9afef94d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 21 Apr 2020 02:56:09 +0300 Subject: [PATCH] Add support for olm events in EncryptedEventContent struct --- event/content.go | 5 +++ event/encryption.go | 84 ++++++++++++++++++++++++++++----------------- event/type.go | 1 + 3 files changed, 59 insertions(+), 31 deletions(-) diff --git a/event/content.go b/event/content.go index e9663cc7..b351e7aa 100644 --- a/event/content.go +++ b/event/content.go @@ -11,6 +11,7 @@ import ( "encoding/json" "fmt" "reflect" + "strings" "github.com/fatih/structs" ) @@ -101,6 +102,10 @@ func mergeMaps(into, from map[string]interface{}) { } } +func IsUnsupportedContentType(err error) bool { + return strings.HasPrefix(err.Error(), "unsupported content type ") +} + func (content *Content) ParseRaw(evtType Type) error { structType, ok := TypeMap[evtType] if !ok { diff --git a/event/encryption.go b/event/encryption.go index dac16a6a..c8a0deb0 100644 --- a/event/encryption.go +++ b/event/encryption.go @@ -7,8 +7,7 @@ package event import ( - "encoding/base64" - "errors" + "encoding/json" "maunium.net/go/mautrix/id" ) @@ -22,29 +21,6 @@ const ( AlgorithmMegolmV1 Algorithm = "m.megolm.v1.aes-sha2" ) -var unpaddedBase64 = base64.StdEncoding.WithPadding(base64.NoPadding) - -// UnpaddedBase64 is a byte array that implements the JSON Marshaler and Unmarshaler interfaces -// to encode and decode the byte array as unpadded base64. -type UnpaddedBase64 []byte - -func (ub64 *UnpaddedBase64) UnmarshalJSON(data []byte) error { - if data[0] != '"' || data[len(data)-1] != '"' { - return errors.New("failed to decode data into bytes: input doesn't look like a JSON string") - } - *ub64 = make([]byte, unpaddedBase64.DecodedLen(len(data)-2)) - _, err := unpaddedBase64.Decode(*ub64, data[1:len(data)-1]) - return err -} - -func (ub64 *UnpaddedBase64) MarshalJSON() ([]byte, error) { - data := make([]byte, unpaddedBase64.EncodedLen(len(*ub64))+2) - data[0] = '"' - data[len(data)-1] = '"' - unpaddedBase64.Encode(data[1:len(data)-1], *ub64) - return data, nil -} - // EncryptionEventContent represents the content of a m.room.encryption state event. // https://matrix.org/docs/spec/client_server/r0.6.0#m-room-encryption type EncryptionEventContent struct { @@ -57,14 +33,60 @@ type EncryptionEventContent struct { } // EncryptedEventContent represents the content of a m.room.encrypted message event. -// This struct only supports the m.megolm.v1.aes-sha2 algorithm. The legacy m.olm.v1 algorithm is not supported. // https://matrix.org/docs/spec/client_server/r0.6.0#m-room-encrypted type EncryptedEventContent struct { - Algorithm Algorithm `json:"algorithm"` - SenderKey string `json:"sender_key"` - DeviceID id.DeviceID `json:"device_id"` - SessionID string `json:"session_id"` - Ciphertext UnpaddedBase64 `json:"ciphertext"` + Algorithm Algorithm `json:"algorithm"` + SenderKey string `json:"sender_key"` + DeviceID id.DeviceID `json:"device_id"` + SessionID string `json:"session_id"` + Ciphertext json.RawMessage `json:"ciphertext"` + + MegolmCiphertext string `json:"-"` + OlmCiphertext OlmCiphertexts `json:"-"` +} + +type OlmMessageType int + +const ( + OlmPreKeyMessage OlmMessageType = 0 + OlmNormalMessage OlmMessageType = 1 +) + +type OlmCiphertexts map[string]struct { + Body string `json:"body"` + Type OlmMessageType `json:"type"` +} + +type serializableEncryptedEventContent EncryptedEventContent + +func (content *EncryptedEventContent) UnmarshalJSON(data []byte) error { + err := json.Unmarshal(data, (*serializableEncryptedEventContent)(content)) + if err != nil { + return err + } + switch content.Algorithm { + case AlgorithmOlmV1: + content.OlmCiphertext = make(OlmCiphertexts) + return json.Unmarshal(content.Ciphertext, &content.OlmCiphertext) + case AlgorithmMegolmV1: + return json.Unmarshal(content.Ciphertext, &content.MegolmCiphertext) + default: + return nil + } +} + +func (content *EncryptedEventContent) MarshalJSON() ([]byte, error) { + var err error + switch content.Algorithm { + case AlgorithmOlmV1: + content.Ciphertext, err = json.Marshal(content.OlmCiphertext) + case AlgorithmMegolmV1: + content.Ciphertext, err = json.Marshal(content.MegolmCiphertext) + } + if err != nil { + return nil, err + } + return json.Marshal((*serializableEncryptedEventContent)(content)) } // RoomKeyEventContent represents the content of a m.room_key to_device event. diff --git a/event/type.go b/event/type.go index 6a80edb8..4b151fb6 100644 --- a/event/type.go +++ b/event/type.go @@ -165,4 +165,5 @@ var ( ToDeviceRoomKey = Type{"m.room_key", ToDeviceEventType} ToDeviceRoomKeyRequest = Type{"m.room_key_request", ToDeviceEventType} ToDeviceForwardedRoomKey = Type{"m.forwarded_room_key", ToDeviceEventType} + ToDeviceEncrypted = Type{"m.room.encrypted", ToDeviceEventType} )