crypto/goolm/message: use buffers for encode/decode functions

Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
This commit is contained in:
Sumner Evans 2025-01-17 11:21:24 -07:00
commit 976e11ad11
No known key found for this signature in database
6 changed files with 170 additions and 262 deletions

View file

@ -1,70 +1,28 @@
package message
import (
"bytes"
"encoding/binary"
"maunium.net/go/mautrix/crypto/olm"
)
// checkDecodeErr checks if there was an error during decode.
func checkDecodeErr(readBytes int) error {
if readBytes == 0 {
//end reached
return olm.ErrInputToSmall
type Decoder struct {
*bytes.Buffer
}
func NewDecoder(buf []byte) *Decoder {
return &Decoder{bytes.NewBuffer(buf)}
}
func (d *Decoder) ReadVarInt() (uint64, error) {
return binary.ReadUvarint(d)
}
func (d *Decoder) ReadVarBytes() ([]byte, error) {
if n, err := d.ReadVarInt(); err != nil {
return nil, err
} else {
out := make([]byte, n)
_, err = d.Read(out)
return out, err
}
if readBytes < 0 {
return olm.ErrOverflow
}
return nil
}
// decodeVarInt decodes a single big-endian encoded varint.
func decodeVarInt(input []byte) (uint32, int) {
value, readBytes := binary.Uvarint(input)
return uint32(value), readBytes
}
// decodeVarString decodes the length of the string (varint) and returns the actual string
func decodeVarString(input []byte) ([]byte, int) {
stringLen, readBytes := decodeVarInt(input)
if readBytes <= 0 {
return nil, readBytes
}
input = input[readBytes:]
value := input[:stringLen]
readBytes += int(stringLen)
return value, readBytes
}
// encodeVarIntByteLength returns the number of bytes needed to encode the uint32.
func encodeVarIntByteLength(input uint32) int {
result := 1
for input >= 128 {
result++
input >>= 7
}
return result
}
// encodeVarStringByteLength returns the number of bytes needed to encode the input.
func encodeVarStringByteLength(input []byte) int {
result := encodeVarIntByteLength(uint32(len(input)))
result += len(input)
return result
}
// encodeVarInt encodes a single uint32
func encodeVarInt(input uint32) []byte {
out := make([]byte, encodeVarIntByteLength(input))
binary.PutUvarint(out, uint64(input))
return out
}
// encodeVarString encodes the length of the input (varint) and appends the actual input
func encodeVarString(input []byte) []byte {
out := make([]byte, encodeVarStringByteLength(input))
length := encodeVarInt(uint32(len(input)))
copy(out, length)
copy(out[len(length):], input)
return out
}

View file

@ -0,0 +1,24 @@
package message
import "encoding/binary"
type Encoder struct {
buf []byte
}
func (e *Encoder) Bytes() []byte {
return e.buf
}
func (e *Encoder) PutByte(val byte) {
e.buf = append(e.buf, val)
}
func (e *Encoder) PutVarInt(val uint64) {
e.buf = binary.AppendUvarint(e.buf, val)
}
func (e *Encoder) PutVarBytes(data []byte) {
e.PutVarInt(uint64(len(data)))
e.buf = append(e.buf, data...)
}

View file

@ -1,33 +1,13 @@
package message
package message_test
import (
"testing"
"github.com/stretchr/testify/assert"
"maunium.net/go/mautrix/crypto/goolm/message"
)
func TestEncodeLengthInt(t *testing.T) {
numbers := []uint32{127, 128, 16383, 16384, 32767}
expected := []int{1, 2, 2, 3, 3}
for curIndex := range numbers {
assert.Equal(t, expected[curIndex], encodeVarIntByteLength(numbers[curIndex]))
}
}
func TestEncodeLengthString(t *testing.T) {
var strings [][]byte
var expected []int
strings = append(strings, []byte("test"))
expected = append(expected, 1+4)
strings = append(strings, []byte("this is a long message with a length of 127 so that the varint of the length is just one byte. just needs some padding---------"))
expected = append(expected, 1+127)
strings = append(strings, []byte("this is an even longer message with a length between 128 and 16383 so that the varint of the length needs two byte. just needs some padding again ---------"))
expected = append(expected, 2+155)
for curIndex := range strings {
assert.Equal(t, expected[curIndex], encodeVarStringByteLength(strings[curIndex]))
}
}
func TestEncodeInt(t *testing.T) {
var ints []uint32
var expected [][]byte
@ -40,7 +20,9 @@ func TestEncodeInt(t *testing.T) {
ints = append(ints, 16383)
expected = append(expected, []byte{0b11111111, 0b01111111})
for curIndex := range ints {
assert.Equal(t, expected[curIndex], encodeVarInt(ints[curIndex]))
var encoder message.Encoder
encoder.PutVarInt(uint64(ints[curIndex]))
assert.Equal(t, expected[curIndex], encoder.Bytes())
}
}
@ -70,6 +52,8 @@ func TestEncodeString(t *testing.T) {
res = append(res, curTest...) //Add string itself
expected = append(expected, res)
for curIndex := range strings {
assert.Equal(t, expected[curIndex], encodeVarString(strings[curIndex]))
var encoder message.Encoder
encoder.PutVarBytes(strings[curIndex])
assert.Equal(t, expected[curIndex], encoder.Bytes())
}
}

View file

@ -2,6 +2,7 @@ package message
import (
"bytes"
"io"
"maunium.net/go/mautrix/crypto/goolm/aessha2"
"maunium.net/go/mautrix/crypto/goolm/crypto"
@ -22,85 +23,63 @@ type GroupMessage struct {
}
// Decodes decodes the input and populates the corresponding fileds. MAC and signature are ignored but have to be present.
func (r *GroupMessage) Decode(input []byte) error {
func (r *GroupMessage) Decode(input []byte) (err error) {
r.Version = 0
r.MessageIndex = 0
r.Ciphertext = nil
if len(input) == 0 {
return nil
}
//first Byte is always version
r.Version = input[0]
curPos := 1
for curPos < len(input)-countMACBytesGroupMessage-crypto.Ed25519SignatureSize {
//Read Key
curKey, readBytes := decodeVarInt(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
return err
}
curPos += readBytes
if (curKey & 0b111) == 0 {
//The value is of type varint
value, readBytes := decodeVarInt(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
return err
decoder := NewDecoder(input[:len(input)-countMACBytesGroupMessage-crypto.Ed25519SignatureSize])
r.Version, err = decoder.ReadByte() // First byte is the version
if err != nil {
return
}
for {
// Read Key
if curKey, err := decoder.ReadVarInt(); err != nil {
if err == io.EOF {
// No more keys to read
return nil
}
curPos += readBytes
switch curKey {
case messageIndexTag:
r.MessageIndex = value
return err
} else if (curKey & 0b111) == 0 {
// The value is of type varint
if value, err := decoder.ReadVarInt(); err != nil {
return err
} else if curKey == messageIndexTag {
r.MessageIndex = uint32(value)
r.HasMessageIndex = true
}
} else if (curKey & 0b111) == 2 {
//The value is of type string
value, readBytes := decodeVarString(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
// The value is of type string
if value, err := decoder.ReadVarBytes(); err != nil {
return err
}
curPos += readBytes
switch curKey {
case cipherTextTag:
} else if curKey == cipherTextTag {
r.Ciphertext = value
}
}
}
return nil
}
// EncodeAndMACAndSign encodes the message, creates the mac with the key and the cipher and signs the message.
// If macKey or cipher is nil, no mac is appended. If signKey is nil, no signature is appended.
func (r *GroupMessage) EncodeAndMACAndSign(cipher aessha2.AESSHA2, signKey crypto.Ed25519KeyPair) ([]byte, error) {
var lengthOfMessage int
lengthOfMessage += 1 //Version
lengthOfMessage += encodeVarIntByteLength(messageIndexTag) + encodeVarIntByteLength(r.MessageIndex)
lengthOfMessage += encodeVarIntByteLength(cipherTextTag) + encodeVarStringByteLength(r.Ciphertext)
out := make([]byte, lengthOfMessage)
out[0] = r.Version
curPos := 1
encodedTag := encodeVarInt(messageIndexTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue := encodeVarInt(r.MessageIndex)
copy(out[curPos:], encodedValue)
curPos += len(encodedValue)
encodedTag = encodeVarInt(cipherTextTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue = encodeVarString(r.Ciphertext)
copy(out[curPos:], encodedValue)
curPos += len(encodedValue)
mac, err := r.MAC(cipher, out)
var encoder Encoder
encoder.PutByte(r.Version)
encoder.PutVarInt(messageIndexTag)
encoder.PutVarInt(uint64(r.MessageIndex))
encoder.PutVarInt(cipherTextTag)
encoder.PutVarBytes(r.Ciphertext)
mac, err := r.MAC(cipher, encoder.Bytes())
if err != nil {
return nil, err
}
out = append(out, mac[:countMACBytesGroupMessage]...)
signature, err := signKey.Sign(out)
if err != nil {
return nil, err
}
out = append(out, signature...)
return out, nil
ciphertextWithMAC := append(encoder.Bytes(), mac[:countMACBytesGroupMessage]...)
signature, err := signKey.Sign(ciphertextWithMAC)
return append(ciphertextWithMAC, signature...), err
}
// MAC returns the MAC of the message calculated with cipher and key. The length of the MAC is truncated to the correct length.

View file

@ -2,6 +2,7 @@ package message
import (
"bytes"
"io"
"maunium.net/go/mautrix/crypto/goolm/aessha2"
"maunium.net/go/mautrix/crypto/goolm/crypto"
@ -24,7 +25,7 @@ type Message struct {
}
// Decodes decodes the input and populates the corresponding fileds. MAC is ignored but has to be present.
func (r *Message) Decode(input []byte) error {
func (r *Message) Decode(input []byte) (err error) {
r.Version = 0
r.HasCounter = false
r.Counter = 0
@ -33,82 +34,55 @@ func (r *Message) Decode(input []byte) error {
if len(input) == 0 {
return nil
}
//first Byte is always version
r.Version = input[0]
curPos := 1
for curPos < len(input)-countMACBytesMessage {
//Read Key
curKey, readBytes := decodeVarInt(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
return err
}
curPos += readBytes
if (curKey & 0b111) == 0 {
//The value is of type varint
value, readBytes := decodeVarInt(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
return err
decoder := NewDecoder(input[:len(input)-countMACBytesMessage])
r.Version, err = decoder.ReadByte() // first byte is always version
if err != nil {
return
}
for {
// Read Key
if curKey, err := decoder.ReadVarInt(); err != nil {
if err == io.EOF {
// No more keys to read
return nil
}
curPos += readBytes
switch curKey {
case counterTag:
return err
} else if (curKey & 0b111) == 0 {
// The value is of type varint
if value, err := decoder.ReadVarInt(); err != nil {
return err
} else if curKey == counterTag {
r.Counter = uint32(value)
r.HasCounter = true
r.Counter = value
}
} else if (curKey & 0b111) == 2 {
//The value is of type string
value, readBytes := decodeVarString(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
// The value is of type string
if value, err := decoder.ReadVarBytes(); err != nil {
return err
}
curPos += readBytes
switch curKey {
case ratchetKeyTag:
} else if curKey == ratchetKeyTag {
r.RatchetKey = value
case cipherTextKeyTag:
} else if curKey == cipherTextKeyTag {
r.Ciphertext = value
}
}
}
return nil
}
// EncodeAndMAC encodes the message and creates the MAC with the key and the cipher.
// If key or cipher is nil, no MAC is appended.
func (r *Message) EncodeAndMAC(cipher aessha2.AESSHA2) ([]byte, error) {
var lengthOfMessage int
lengthOfMessage += 1 //Version
lengthOfMessage += encodeVarIntByteLength(ratchetKeyTag) + encodeVarStringByteLength(r.RatchetKey)
lengthOfMessage += encodeVarIntByteLength(counterTag) + encodeVarIntByteLength(r.Counter)
lengthOfMessage += encodeVarIntByteLength(cipherTextKeyTag) + encodeVarStringByteLength(r.Ciphertext)
out := make([]byte, lengthOfMessage)
out[0] = r.Version
curPos := 1
encodedTag := encodeVarInt(ratchetKeyTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue := encodeVarString(r.RatchetKey)
copy(out[curPos:], encodedValue)
curPos += len(encodedValue)
encodedTag = encodeVarInt(counterTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue = encodeVarInt(r.Counter)
copy(out[curPos:], encodedValue)
curPos += len(encodedValue)
encodedTag = encodeVarInt(cipherTextKeyTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue = encodeVarString(r.Ciphertext)
copy(out[curPos:], encodedValue)
curPos += len(encodedValue)
mac, err := cipher.MAC(out)
if err != nil {
return nil, err
}
out = append(out, mac[:countMACBytesMessage]...)
return out, nil
var encoder Encoder
encoder.PutByte(r.Version)
encoder.PutVarInt(ratchetKeyTag)
encoder.PutVarBytes(r.RatchetKey)
encoder.PutVarInt(counterTag)
encoder.PutVarInt(uint64(r.Counter))
encoder.PutVarInt(cipherTextKeyTag)
encoder.PutVarBytes(r.Ciphertext)
mac, err := cipher.MAC(encoder.Bytes())
return append(encoder.Bytes(), mac[:countMACBytesMessage]...), err
}
// VerifyMAC verifies the givenMAC to the calculated MAC of the message.

View file

@ -1,11 +1,14 @@
package message
import (
"io"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/olm"
)
const (
oneTimeKeyIdTag = 0x0A
oneTimeKeyIDTag = 0x0A
baseKeyTag = 0x12
identityKeyTag = 0x1A
messageTag = 0x22
@ -20,7 +23,7 @@ type PreKeyMessage struct {
}
// Decodes decodes the input and populates the corresponding fileds.
func (r *PreKeyMessage) Decode(input []byte) error {
func (r *PreKeyMessage) Decode(input []byte) (err error) {
r.Version = 0
r.IdentityKey = nil
r.BaseKey = nil
@ -29,44 +32,52 @@ func (r *PreKeyMessage) Decode(input []byte) error {
if len(input) == 0 {
return nil
}
//first Byte is always version
r.Version = input[0]
curPos := 1
for curPos < len(input) {
//Read Key
curKey, readBytes := decodeVarInt(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
return err
decoder := NewDecoder(input)
r.Version, err = decoder.ReadByte() // first byte is always version
if err != nil {
if err == io.EOF {
return olm.ErrInputToSmall
}
curPos += readBytes
if (curKey & 0b111) == 0 {
//The value is of type varint
_, readBytes := decodeVarInt(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
return
}
for {
// Read Key
if curKey, err := decoder.ReadVarInt(); err != nil {
if err == io.EOF {
return nil
}
return err
} else if (curKey & 0b111) == 0 {
// The value is of type varint
if _, err = decoder.ReadVarInt(); err != nil {
if err == io.EOF {
return olm.ErrInputToSmall
}
return err
}
curPos += readBytes
} else if (curKey & 0b111) == 2 {
//The value is of type string
value, readBytes := decodeVarString(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
// The value is of type string
if value, err := decoder.ReadVarBytes(); err != nil {
if err == io.EOF {
return olm.ErrInputToSmall
}
return err
}
curPos += readBytes
switch curKey {
case oneTimeKeyIdTag:
r.OneTimeKey = value
case baseKeyTag:
r.BaseKey = value
case identityKeyTag:
r.IdentityKey = value
case messageTag:
r.Message = value
} else {
switch curKey {
case oneTimeKeyIDTag:
r.OneTimeKey = value
case baseKeyTag:
r.BaseKey = value
case identityKeyTag:
r.IdentityKey = value
case messageTag:
r.Message = value
}
}
}
}
return nil
}
// CheckField verifies the fields. If theirIdentityKey is nil, it is not compared to the key in the message.
@ -84,37 +95,15 @@ func (r *PreKeyMessage) CheckFields(theirIdentityKey *crypto.Curve25519PublicKey
// Encode encodes the message.
func (r *PreKeyMessage) Encode() ([]byte, error) {
var lengthOfMessage int
lengthOfMessage += 1 //Version
lengthOfMessage += encodeVarIntByteLength(oneTimeKeyIdTag) + encodeVarStringByteLength(r.OneTimeKey)
lengthOfMessage += encodeVarIntByteLength(identityKeyTag) + encodeVarStringByteLength(r.IdentityKey)
lengthOfMessage += encodeVarIntByteLength(baseKeyTag) + encodeVarStringByteLength(r.BaseKey)
lengthOfMessage += encodeVarIntByteLength(messageTag) + encodeVarStringByteLength(r.Message)
out := make([]byte, lengthOfMessage)
out[0] = r.Version
curPos := 1
encodedTag := encodeVarInt(oneTimeKeyIdTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue := encodeVarString(r.OneTimeKey)
copy(out[curPos:], encodedValue)
curPos += len(encodedValue)
encodedTag = encodeVarInt(identityKeyTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue = encodeVarString(r.IdentityKey)
copy(out[curPos:], encodedValue)
curPos += len(encodedValue)
encodedTag = encodeVarInt(baseKeyTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue = encodeVarString(r.BaseKey)
copy(out[curPos:], encodedValue)
curPos += len(encodedValue)
encodedTag = encodeVarInt(messageTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue = encodeVarString(r.Message)
copy(out[curPos:], encodedValue)
return out, nil
var encoder Encoder
encoder.PutByte(r.Version)
encoder.PutVarInt(oneTimeKeyIDTag)
encoder.PutVarBytes(r.OneTimeKey)
encoder.PutVarInt(identityKeyTag)
encoder.PutVarBytes(r.IdentityKey)
encoder.PutVarInt(baseKeyTag)
encoder.PutVarBytes(r.BaseKey)
encoder.PutVarInt(messageTag)
encoder.PutVarBytes(r.Message)
return encoder.Bytes(), nil
}