diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index 93fe6409..7c8a7542 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -15,6 +15,8 @@ import ( "fmt" "github.com/rs/zerolog" + "github.com/tidwall/gjson" + "go.mau.fi/util/exgjson" "go.mau.fi/util/exzerolog" "maunium.net/go/mautrix" @@ -27,7 +29,24 @@ var ( NoGroupSession = errors.New("no group session created") ) -func getRelatesTo(content interface{}) *event.RelatesTo { +func getRawJSON[T any](content json.RawMessage, path ...string) *T { + value := gjson.GetBytes(content, exgjson.Path(path...)) + if !value.IsObject() { + return nil + } + var result T + err := json.Unmarshal([]byte(value.Raw), &result) + if err != nil { + return nil + } + return &result +} + +func getRelatesTo(content any) *event.RelatesTo { + contentJSON, ok := content.(json.RawMessage) + if ok { + return getRawJSON[event.RelatesTo](contentJSON, "m.relates_to") + } contentStruct, ok := content.(*event.Content) if ok { content = contentStruct.Parsed @@ -39,7 +58,11 @@ func getRelatesTo(content interface{}) *event.RelatesTo { return nil } -func getMentions(content interface{}) *event.Mentions { +func getMentions(content any) *event.Mentions { + contentJSON, ok := content.(json.RawMessage) + if ok { + return getRawJSON[event.Mentions](contentJSON, "m.mentions") + } contentStruct, ok := content.(*event.Content) if ok { content = contentStruct.Parsed