From 976e11ad112afc19ce47fa185e326a4f33726249 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 17 Jan 2025 11:21:24 -0700 Subject: [PATCH] crypto/goolm/message: use buffers for encode/decode functions Signed-off-by: Sumner Evans --- crypto/goolm/message/decoder.go | 82 +++--------- crypto/goolm/message/encoder.go | 24 ++++ .../{decoder_test.go => encoder_test.go} | 34 ++--- crypto/goolm/message/group_message.go | 91 ++++++-------- crypto/goolm/message/message.go | 98 ++++++--------- crypto/goolm/message/prekey_message.go | 119 ++++++++---------- 6 files changed, 178 insertions(+), 270 deletions(-) create mode 100644 crypto/goolm/message/encoder.go rename crypto/goolm/message/{decoder_test.go => encoder_test.go} (58%) diff --git a/crypto/goolm/message/decoder.go b/crypto/goolm/message/decoder.go index 9ce426b5..a71cf302 100644 --- a/crypto/goolm/message/decoder.go +++ b/crypto/goolm/message/decoder.go @@ -1,70 +1,28 @@ package message import ( + "bytes" "encoding/binary" - - "maunium.net/go/mautrix/crypto/olm" ) -// checkDecodeErr checks if there was an error during decode. -func checkDecodeErr(readBytes int) error { - if readBytes == 0 { - //end reached - return olm.ErrInputToSmall +type Decoder struct { + *bytes.Buffer +} + +func NewDecoder(buf []byte) *Decoder { + return &Decoder{bytes.NewBuffer(buf)} +} + +func (d *Decoder) ReadVarInt() (uint64, error) { + return binary.ReadUvarint(d) +} + +func (d *Decoder) ReadVarBytes() ([]byte, error) { + if n, err := d.ReadVarInt(); err != nil { + return nil, err + } else { + out := make([]byte, n) + _, err = d.Read(out) + return out, err } - if readBytes < 0 { - return olm.ErrOverflow - } - return nil -} - -// decodeVarInt decodes a single big-endian encoded varint. -func decodeVarInt(input []byte) (uint32, int) { - value, readBytes := binary.Uvarint(input) - return uint32(value), readBytes -} - -// decodeVarString decodes the length of the string (varint) and returns the actual string -func decodeVarString(input []byte) ([]byte, int) { - stringLen, readBytes := decodeVarInt(input) - if readBytes <= 0 { - return nil, readBytes - } - input = input[readBytes:] - value := input[:stringLen] - readBytes += int(stringLen) - return value, readBytes -} - -// encodeVarIntByteLength returns the number of bytes needed to encode the uint32. -func encodeVarIntByteLength(input uint32) int { - result := 1 - for input >= 128 { - result++ - input >>= 7 - } - return result -} - -// encodeVarStringByteLength returns the number of bytes needed to encode the input. -func encodeVarStringByteLength(input []byte) int { - result := encodeVarIntByteLength(uint32(len(input))) - result += len(input) - return result -} - -// encodeVarInt encodes a single uint32 -func encodeVarInt(input uint32) []byte { - out := make([]byte, encodeVarIntByteLength(input)) - binary.PutUvarint(out, uint64(input)) - return out -} - -// encodeVarString encodes the length of the input (varint) and appends the actual input -func encodeVarString(input []byte) []byte { - out := make([]byte, encodeVarStringByteLength(input)) - length := encodeVarInt(uint32(len(input))) - copy(out, length) - copy(out[len(length):], input) - return out } diff --git a/crypto/goolm/message/encoder.go b/crypto/goolm/message/encoder.go new file mode 100644 index 00000000..95ab6d41 --- /dev/null +++ b/crypto/goolm/message/encoder.go @@ -0,0 +1,24 @@ +package message + +import "encoding/binary" + +type Encoder struct { + buf []byte +} + +func (e *Encoder) Bytes() []byte { + return e.buf +} + +func (e *Encoder) PutByte(val byte) { + e.buf = append(e.buf, val) +} + +func (e *Encoder) PutVarInt(val uint64) { + e.buf = binary.AppendUvarint(e.buf, val) +} + +func (e *Encoder) PutVarBytes(data []byte) { + e.PutVarInt(uint64(len(data))) + e.buf = append(e.buf, data...) +} diff --git a/crypto/goolm/message/decoder_test.go b/crypto/goolm/message/encoder_test.go similarity index 58% rename from crypto/goolm/message/decoder_test.go rename to crypto/goolm/message/encoder_test.go index 8b7561ad..1fe2ebdb 100644 --- a/crypto/goolm/message/decoder_test.go +++ b/crypto/goolm/message/encoder_test.go @@ -1,33 +1,13 @@ -package message +package message_test import ( "testing" "github.com/stretchr/testify/assert" + + "maunium.net/go/mautrix/crypto/goolm/message" ) -func TestEncodeLengthInt(t *testing.T) { - numbers := []uint32{127, 128, 16383, 16384, 32767} - expected := []int{1, 2, 2, 3, 3} - for curIndex := range numbers { - assert.Equal(t, expected[curIndex], encodeVarIntByteLength(numbers[curIndex])) - } -} - -func TestEncodeLengthString(t *testing.T) { - var strings [][]byte - var expected []int - strings = append(strings, []byte("test")) - expected = append(expected, 1+4) - strings = append(strings, []byte("this is a long message with a length of 127 so that the varint of the length is just one byte. just needs some padding---------")) - expected = append(expected, 1+127) - strings = append(strings, []byte("this is an even longer message with a length between 128 and 16383 so that the varint of the length needs two byte. just needs some padding again ---------")) - expected = append(expected, 2+155) - for curIndex := range strings { - assert.Equal(t, expected[curIndex], encodeVarStringByteLength(strings[curIndex])) - } -} - func TestEncodeInt(t *testing.T) { var ints []uint32 var expected [][]byte @@ -40,7 +20,9 @@ func TestEncodeInt(t *testing.T) { ints = append(ints, 16383) expected = append(expected, []byte{0b11111111, 0b01111111}) for curIndex := range ints { - assert.Equal(t, expected[curIndex], encodeVarInt(ints[curIndex])) + var encoder message.Encoder + encoder.PutVarInt(uint64(ints[curIndex])) + assert.Equal(t, expected[curIndex], encoder.Bytes()) } } @@ -70,6 +52,8 @@ func TestEncodeString(t *testing.T) { res = append(res, curTest...) //Add string itself expected = append(expected, res) for curIndex := range strings { - assert.Equal(t, expected[curIndex], encodeVarString(strings[curIndex])) + var encoder message.Encoder + encoder.PutVarBytes(strings[curIndex]) + assert.Equal(t, expected[curIndex], encoder.Bytes()) } } diff --git a/crypto/goolm/message/group_message.go b/crypto/goolm/message/group_message.go index 411e0879..c2a43b1f 100644 --- a/crypto/goolm/message/group_message.go +++ b/crypto/goolm/message/group_message.go @@ -2,6 +2,7 @@ package message import ( "bytes" + "io" "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" @@ -22,85 +23,63 @@ type GroupMessage struct { } // Decodes decodes the input and populates the corresponding fileds. MAC and signature are ignored but have to be present. -func (r *GroupMessage) Decode(input []byte) error { +func (r *GroupMessage) Decode(input []byte) (err error) { r.Version = 0 r.MessageIndex = 0 r.Ciphertext = nil if len(input) == 0 { return nil } - //first Byte is always version - r.Version = input[0] - curPos := 1 - for curPos < len(input)-countMACBytesGroupMessage-crypto.Ed25519SignatureSize { - //Read Key - curKey, readBytes := decodeVarInt(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { - return err - } - curPos += readBytes - if (curKey & 0b111) == 0 { - //The value is of type varint - value, readBytes := decodeVarInt(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { - return err + + decoder := NewDecoder(input[:len(input)-countMACBytesGroupMessage-crypto.Ed25519SignatureSize]) + r.Version, err = decoder.ReadByte() // First byte is the version + if err != nil { + return + } + + for { + // Read Key + if curKey, err := decoder.ReadVarInt(); err != nil { + if err == io.EOF { + // No more keys to read + return nil } - curPos += readBytes - switch curKey { - case messageIndexTag: - r.MessageIndex = value + return err + } else if (curKey & 0b111) == 0 { + // The value is of type varint + if value, err := decoder.ReadVarInt(); err != nil { + return err + } else if curKey == messageIndexTag { + r.MessageIndex = uint32(value) r.HasMessageIndex = true } } else if (curKey & 0b111) == 2 { - //The value is of type string - value, readBytes := decodeVarString(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { + // The value is of type string + if value, err := decoder.ReadVarBytes(); err != nil { return err - } - curPos += readBytes - switch curKey { - case cipherTextTag: + } else if curKey == cipherTextTag { r.Ciphertext = value } } } - - return nil } // EncodeAndMACAndSign encodes the message, creates the mac with the key and the cipher and signs the message. // If macKey or cipher is nil, no mac is appended. If signKey is nil, no signature is appended. func (r *GroupMessage) EncodeAndMACAndSign(cipher aessha2.AESSHA2, signKey crypto.Ed25519KeyPair) ([]byte, error) { - var lengthOfMessage int - lengthOfMessage += 1 //Version - lengthOfMessage += encodeVarIntByteLength(messageIndexTag) + encodeVarIntByteLength(r.MessageIndex) - lengthOfMessage += encodeVarIntByteLength(cipherTextTag) + encodeVarStringByteLength(r.Ciphertext) - out := make([]byte, lengthOfMessage) - out[0] = r.Version - curPos := 1 - encodedTag := encodeVarInt(messageIndexTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue := encodeVarInt(r.MessageIndex) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - encodedTag = encodeVarInt(cipherTextTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue = encodeVarString(r.Ciphertext) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - mac, err := r.MAC(cipher, out) + var encoder Encoder + encoder.PutByte(r.Version) + encoder.PutVarInt(messageIndexTag) + encoder.PutVarInt(uint64(r.MessageIndex)) + encoder.PutVarInt(cipherTextTag) + encoder.PutVarBytes(r.Ciphertext) + mac, err := r.MAC(cipher, encoder.Bytes()) if err != nil { return nil, err } - out = append(out, mac[:countMACBytesGroupMessage]...) - signature, err := signKey.Sign(out) - if err != nil { - return nil, err - } - out = append(out, signature...) - return out, nil + ciphertextWithMAC := append(encoder.Bytes(), mac[:countMACBytesGroupMessage]...) + signature, err := signKey.Sign(ciphertextWithMAC) + return append(ciphertextWithMAC, signature...), err } // MAC returns the MAC of the message calculated with cipher and key. The length of the MAC is truncated to the correct length. diff --git a/crypto/goolm/message/message.go b/crypto/goolm/message/message.go index 88efdc14..8bb6e0cd 100644 --- a/crypto/goolm/message/message.go +++ b/crypto/goolm/message/message.go @@ -2,6 +2,7 @@ package message import ( "bytes" + "io" "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" @@ -24,7 +25,7 @@ type Message struct { } // Decodes decodes the input and populates the corresponding fileds. MAC is ignored but has to be present. -func (r *Message) Decode(input []byte) error { +func (r *Message) Decode(input []byte) (err error) { r.Version = 0 r.HasCounter = false r.Counter = 0 @@ -33,82 +34,55 @@ func (r *Message) Decode(input []byte) error { if len(input) == 0 { return nil } - //first Byte is always version - r.Version = input[0] - curPos := 1 - for curPos < len(input)-countMACBytesMessage { - //Read Key - curKey, readBytes := decodeVarInt(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { - return err - } - curPos += readBytes - if (curKey & 0b111) == 0 { - //The value is of type varint - value, readBytes := decodeVarInt(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { - return err + + decoder := NewDecoder(input[:len(input)-countMACBytesMessage]) + r.Version, err = decoder.ReadByte() // first byte is always version + if err != nil { + return + } + + for { + // Read Key + if curKey, err := decoder.ReadVarInt(); err != nil { + if err == io.EOF { + // No more keys to read + return nil } - curPos += readBytes - switch curKey { - case counterTag: + return err + } else if (curKey & 0b111) == 0 { + // The value is of type varint + if value, err := decoder.ReadVarInt(); err != nil { + return err + } else if curKey == counterTag { + r.Counter = uint32(value) r.HasCounter = true - r.Counter = value } } else if (curKey & 0b111) == 2 { - //The value is of type string - value, readBytes := decodeVarString(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { + // The value is of type string + if value, err := decoder.ReadVarBytes(); err != nil { return err - } - curPos += readBytes - switch curKey { - case ratchetKeyTag: + } else if curKey == ratchetKeyTag { r.RatchetKey = value - case cipherTextKeyTag: + } else if curKey == cipherTextKeyTag { r.Ciphertext = value } } } - - return nil } // EncodeAndMAC encodes the message and creates the MAC with the key and the cipher. // If key or cipher is nil, no MAC is appended. func (r *Message) EncodeAndMAC(cipher aessha2.AESSHA2) ([]byte, error) { - var lengthOfMessage int - lengthOfMessage += 1 //Version - lengthOfMessage += encodeVarIntByteLength(ratchetKeyTag) + encodeVarStringByteLength(r.RatchetKey) - lengthOfMessage += encodeVarIntByteLength(counterTag) + encodeVarIntByteLength(r.Counter) - lengthOfMessage += encodeVarIntByteLength(cipherTextKeyTag) + encodeVarStringByteLength(r.Ciphertext) - out := make([]byte, lengthOfMessage) - out[0] = r.Version - curPos := 1 - encodedTag := encodeVarInt(ratchetKeyTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue := encodeVarString(r.RatchetKey) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - encodedTag = encodeVarInt(counterTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue = encodeVarInt(r.Counter) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - encodedTag = encodeVarInt(cipherTextKeyTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue = encodeVarString(r.Ciphertext) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - mac, err := cipher.MAC(out) - if err != nil { - return nil, err - } - out = append(out, mac[:countMACBytesMessage]...) - return out, nil + var encoder Encoder + encoder.PutByte(r.Version) + encoder.PutVarInt(ratchetKeyTag) + encoder.PutVarBytes(r.RatchetKey) + encoder.PutVarInt(counterTag) + encoder.PutVarInt(uint64(r.Counter)) + encoder.PutVarInt(cipherTextKeyTag) + encoder.PutVarBytes(r.Ciphertext) + mac, err := cipher.MAC(encoder.Bytes()) + return append(encoder.Bytes(), mac[:countMACBytesMessage]...), err } // VerifyMAC verifies the givenMAC to the calculated MAC of the message. diff --git a/crypto/goolm/message/prekey_message.go b/crypto/goolm/message/prekey_message.go index 1238a9a5..22ebf9c3 100644 --- a/crypto/goolm/message/prekey_message.go +++ b/crypto/goolm/message/prekey_message.go @@ -1,11 +1,14 @@ package message import ( + "io" + "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/olm" ) const ( - oneTimeKeyIdTag = 0x0A + oneTimeKeyIDTag = 0x0A baseKeyTag = 0x12 identityKeyTag = 0x1A messageTag = 0x22 @@ -20,7 +23,7 @@ type PreKeyMessage struct { } // Decodes decodes the input and populates the corresponding fileds. -func (r *PreKeyMessage) Decode(input []byte) error { +func (r *PreKeyMessage) Decode(input []byte) (err error) { r.Version = 0 r.IdentityKey = nil r.BaseKey = nil @@ -29,44 +32,52 @@ func (r *PreKeyMessage) Decode(input []byte) error { if len(input) == 0 { return nil } - //first Byte is always version - r.Version = input[0] - curPos := 1 - for curPos < len(input) { - //Read Key - curKey, readBytes := decodeVarInt(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { - return err + + decoder := NewDecoder(input) + r.Version, err = decoder.ReadByte() // first byte is always version + if err != nil { + if err == io.EOF { + return olm.ErrInputToSmall } - curPos += readBytes - if (curKey & 0b111) == 0 { - //The value is of type varint - _, readBytes := decodeVarInt(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { + return + } + + for { + // Read Key + if curKey, err := decoder.ReadVarInt(); err != nil { + if err == io.EOF { + return nil + } + return err + } else if (curKey & 0b111) == 0 { + // The value is of type varint + if _, err = decoder.ReadVarInt(); err != nil { + if err == io.EOF { + return olm.ErrInputToSmall + } return err } - curPos += readBytes } else if (curKey & 0b111) == 2 { - //The value is of type string - value, readBytes := decodeVarString(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { + // The value is of type string + if value, err := decoder.ReadVarBytes(); err != nil { + if err == io.EOF { + return olm.ErrInputToSmall + } return err - } - curPos += readBytes - switch curKey { - case oneTimeKeyIdTag: - r.OneTimeKey = value - case baseKeyTag: - r.BaseKey = value - case identityKeyTag: - r.IdentityKey = value - case messageTag: - r.Message = value + } else { + switch curKey { + case oneTimeKeyIDTag: + r.OneTimeKey = value + case baseKeyTag: + r.BaseKey = value + case identityKeyTag: + r.IdentityKey = value + case messageTag: + r.Message = value + } } } } - - return nil } // CheckField verifies the fields. If theirIdentityKey is nil, it is not compared to the key in the message. @@ -84,37 +95,15 @@ func (r *PreKeyMessage) CheckFields(theirIdentityKey *crypto.Curve25519PublicKey // Encode encodes the message. func (r *PreKeyMessage) Encode() ([]byte, error) { - var lengthOfMessage int - lengthOfMessage += 1 //Version - lengthOfMessage += encodeVarIntByteLength(oneTimeKeyIdTag) + encodeVarStringByteLength(r.OneTimeKey) - lengthOfMessage += encodeVarIntByteLength(identityKeyTag) + encodeVarStringByteLength(r.IdentityKey) - lengthOfMessage += encodeVarIntByteLength(baseKeyTag) + encodeVarStringByteLength(r.BaseKey) - lengthOfMessage += encodeVarIntByteLength(messageTag) + encodeVarStringByteLength(r.Message) - out := make([]byte, lengthOfMessage) - out[0] = r.Version - curPos := 1 - encodedTag := encodeVarInt(oneTimeKeyIdTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue := encodeVarString(r.OneTimeKey) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - encodedTag = encodeVarInt(identityKeyTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue = encodeVarString(r.IdentityKey) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - encodedTag = encodeVarInt(baseKeyTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue = encodeVarString(r.BaseKey) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - encodedTag = encodeVarInt(messageTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue = encodeVarString(r.Message) - copy(out[curPos:], encodedValue) - return out, nil + var encoder Encoder + encoder.PutByte(r.Version) + encoder.PutVarInt(oneTimeKeyIDTag) + encoder.PutVarBytes(r.OneTimeKey) + encoder.PutVarInt(identityKeyTag) + encoder.PutVarBytes(r.IdentityKey) + encoder.PutVarInt(baseKeyTag) + encoder.PutVarBytes(r.BaseKey) + encoder.PutVarInt(messageTag) + encoder.PutVarBytes(r.Message) + return encoder.Bytes(), nil }