From ed5415a8eea2513f77e6baecce857ca552c6a5b8 Mon Sep 17 00:00:00 2001 From: Luca Weiss Date: Sun, 9 Dec 2018 17:52:34 +0100 Subject: [PATCH 01/17] Add simple example --- example/go.mod | 6 ++++++ example/go.sum | 4 ++++ example/main.go | 55 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+) create mode 100644 example/go.mod create mode 100644 example/go.sum create mode 100644 example/main.go diff --git a/example/go.mod b/example/go.mod new file mode 100644 index 00000000..0f078fed --- /dev/null +++ b/example/go.mod @@ -0,0 +1,6 @@ +module mautrix-example + +require ( + golang.org/x/net v0.0.0-20181207154023-610586996380 // indirect + maunium.net/go/mautrix v0.0.0-20181114121347-3b909a424d14 +) diff --git a/example/go.sum b/example/go.sum new file mode 100644 index 00000000..7abc90e0 --- /dev/null +++ b/example/go.sum @@ -0,0 +1,4 @@ +golang.org/x/net v0.0.0-20181207154023-610586996380 h1:zPQexyRtNYBc7bcHmehl1dH6TB3qn8zytv8cBGLDNY0= +golang.org/x/net v0.0.0-20181207154023-610586996380/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +maunium.net/go/mautrix v0.0.0-20181114121347-3b909a424d14 h1:BU2payokgLMTgVLubdEmHhzjtoe3dR5eLpNLQXVblPY= +maunium.net/go/mautrix v0.0.0-20181114121347-3b909a424d14/go.mod h1:+thZeequb2CuDc9BwbdsGu0RmALOBavZl3aRHsWm4xQ= diff --git a/example/main.go b/example/main.go new file mode 100644 index 00000000..91a15f1c --- /dev/null +++ b/example/main.go @@ -0,0 +1,55 @@ +// Copyright (C) 2017 Tulir Asokan +// Copyright (C) 2018 Luca Weiss +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package main + +import ( + "flag" + "fmt" + "maunium.net/go/mautrix" +) + +var homeserver = flag.String("homeserver", "https://matrix.org", "Matrix homeserver") +var username = flag.String("username", "", "Matrix username localpart") +var password = flag.String("password", "", "Matrix password") + +func main() { + flag.Parse() + fmt.Println("Logging to", *homeserver, "as", *username) + client, err := mautrix.NewClient(*homeserver, "", "") + if err != nil { + fmt.Println(err) + return + } + resp, err := client.Login(&mautrix.ReqLogin{Type: "m.login.password", User: *username, Password: *password}) + if err != nil { + fmt.Println(err) + return + } + client.SetCredentials(resp.UserID, resp.AccessToken) + + fmt.Println("Login successful") + + syncer := client.Syncer.(*mautrix.DefaultSyncer) + syncer.OnEventType(mautrix.EventMessage, func(evt *mautrix.Event) { + fmt.Printf("<%[1]s> %[4]s (%[2]s/%[3]s)\n", evt.Sender, evt.Type.String(), evt.ID, evt.Content.Body) + }) + + err = client.Sync() + if err != nil { + fmt.Println(err) + } +} From 7ac5924af9784c4f5eb71c665f958f7254769d70 Mon Sep 17 00:00:00 2001 From: Luca Weiss Date: Tue, 8 Sep 2020 17:12:19 +0200 Subject: [PATCH 02/17] Update example for 2020 --- example/go.mod | 7 +++---- example/go.sum | 37 +++++++++++++++++++++++++++++++++---- example/main.go | 23 +++++++++++++++++------ 3 files changed, 53 insertions(+), 14 deletions(-) diff --git a/example/go.mod b/example/go.mod index 0f078fed..ff9be9bd 100644 --- a/example/go.mod +++ b/example/go.mod @@ -1,6 +1,5 @@ module mautrix-example -require ( - golang.org/x/net v0.0.0-20181207154023-610586996380 // indirect - maunium.net/go/mautrix v0.0.0-20181114121347-3b909a424d14 -) +go 1.15 + +require maunium.net/go/mautrix v0.7.6 diff --git a/example/go.sum b/example/go.sum index 7abc90e0..4df7e47b 100644 --- a/example/go.sum +++ b/example/go.sum @@ -1,4 +1,33 @@ -golang.org/x/net v0.0.0-20181207154023-610586996380 h1:zPQexyRtNYBc7bcHmehl1dH6TB3qn8zytv8cBGLDNY0= -golang.org/x/net v0.0.0-20181207154023-610586996380/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -maunium.net/go/mautrix v0.0.0-20181114121347-3b909a424d14 h1:BU2payokgLMTgVLubdEmHhzjtoe3dR5eLpNLQXVblPY= -maunium.net/go/mautrix v0.0.0-20181114121347-3b909a424d14/go.mod h1:+thZeequb2CuDc9BwbdsGu0RmALOBavZl3aRHsWm4xQ= +github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc= +github.com/andybalholm/cascadia v1.1.0/go.mod h1:GsXiBklL0woXo1j/WYWtSYYC4ouU9PqHO0sqidkEA4Y= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gorilla/mux v1.7.4/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= +github.com/lib/pq v1.7.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-sqlite3 v1.14.0/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/tidwall/gjson v1.6.0/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls= +github.com/tidwall/match v1.0.1/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E= +github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= +github.com/tidwall/pretty v1.0.1/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= +github.com/tidwall/sjson v1.1.1/go.mod h1:yvVuSnpEQv5cYIrO+AT6kw4QVfd5SDZoGIS7/5+fZFs= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200602114024-627f9648deb9 h1:pNX+40auqi2JqRfOP1akLGtYcn15TUbkhwuCO3foqqM= +golang.org/x/net v0.0.0-20200602114024-627f9648deb9/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +maunium.net/go/maulogger/v2 v2.1.1/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A= +maunium.net/go/mautrix v0.7.6 h1:jB9oCimPq0mVyolwQBC/9N1fu21AU+Ryq837cLf4gOo= +maunium.net/go/mautrix v0.7.6/go.mod h1:Va/74MijqaS0DQ3aUqxmFO54/PMfr1LVsCOcGRHbYmo= diff --git a/example/main.go b/example/main.go index 91a15f1c..2b125778 100644 --- a/example/main.go +++ b/example/main.go @@ -1,5 +1,5 @@ // Copyright (C) 2017 Tulir Asokan -// Copyright (C) 2018 Luca Weiss +// Copyright (C) 2018-2020 Luca Weiss // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by @@ -20,6 +20,8 @@ import ( "flag" "fmt" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "os" ) var homeserver = flag.String("homeserver", "https://matrix.org", "Matrix homeserver") @@ -28,24 +30,33 @@ var password = flag.String("password", "", "Matrix password") func main() { flag.Parse() + if *username == "" || *password == "" { + _, _ = fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) + flag.PrintDefaults() + os.Exit(1) + } + fmt.Println("Logging to", *homeserver, "as", *username) client, err := mautrix.NewClient(*homeserver, "", "") if err != nil { fmt.Println(err) return } - resp, err := client.Login(&mautrix.ReqLogin{Type: "m.login.password", User: *username, Password: *password}) + _, err = client.Login(&mautrix.ReqLogin{ + Type: "m.login.password", + Identifier: mautrix.UserIdentifier{Type: mautrix.IdentifierTypeUser, User: *username}, + Password: *password, + StoreCredentials: true, + }) if err != nil { fmt.Println(err) return } - client.SetCredentials(resp.UserID, resp.AccessToken) - fmt.Println("Login successful") syncer := client.Syncer.(*mautrix.DefaultSyncer) - syncer.OnEventType(mautrix.EventMessage, func(evt *mautrix.Event) { - fmt.Printf("<%[1]s> %[4]s (%[2]s/%[3]s)\n", evt.Sender, evt.Type.String(), evt.ID, evt.Content.Body) + syncer.OnEventType(event.EventMessage, func(source mautrix.EventSource, evt *event.Event) { + fmt.Printf("<%[1]s> %[4]s (%[2]s/%[3]s)\n", evt.Sender, evt.Type.String(), evt.ID, evt.Content.AsMessage().Body) }) err = client.Sync() From 7a128af68065db65fae4e5f7e05a0cc0ff354dcf Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 12 Sep 2020 21:05:55 +0300 Subject: [PATCH 03/17] Update example a bit --- example/main.go | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/example/main.go b/example/main.go index 2b125778..fa7cc3ee 100644 --- a/example/main.go +++ b/example/main.go @@ -24,23 +24,22 @@ import ( "os" ) -var homeserver = flag.String("homeserver", "https://matrix.org", "Matrix homeserver") +var homeserver = flag.String("homeserver", "", "Matrix homeserver") var username = flag.String("username", "", "Matrix username localpart") var password = flag.String("password", "", "Matrix password") func main() { flag.Parse() - if *username == "" || *password == "" { + if *username == "" || *password == "" || *homeserver == "" { _, _ = fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) flag.PrintDefaults() os.Exit(1) } - fmt.Println("Logging to", *homeserver, "as", *username) + fmt.Println("Logging into", *homeserver, "as", *username) client, err := mautrix.NewClient(*homeserver, "", "") if err != nil { - fmt.Println(err) - return + panic(err) } _, err = client.Login(&mautrix.ReqLogin{ Type: "m.login.password", @@ -49,8 +48,7 @@ func main() { StoreCredentials: true, }) if err != nil { - fmt.Println(err) - return + panic(err) } fmt.Println("Login successful") @@ -61,6 +59,6 @@ func main() { err = client.Sync() if err != nil { - fmt.Println(err) + panic(err) } } From b0713d0ddca66839b01a4aba09e0b16bea1e5feb Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 15 Sep 2020 16:40:22 +0300 Subject: [PATCH 04/17] Update README.md --- README.md | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 06e410e1..5a893ddf 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,23 @@ # mautrix-go [![GoDoc](https://godoc.org/maunium.net/go/mautrix?status.svg)](https://godoc.org/maunium.net/go/mautrix) -A Golang Matrix framework. +A Golang Matrix framework. Used by [gomuks](https://matrix.org/docs/projects/client/gomuks), +[go-neb](https://github.com/matrix-org/go-neb), [mautrix-whatsapp](https://github.com/tulir/mautrix-whatsapp) +and others. + +Matrix room: [`#maunium:maunium.net`](https://matrix.to/#/#maunium:maunium.net) This project is based on [matrix-org/gomatrix](https://github.com/matrix-org/gomatrix). The original project is licensed under [Apache 2.0](https://github.com/matrix-org/gomatrix/blob/master/LICENSE). +In addition to the basic client API features the original project has, this framework also has: + +* Appservice support (Intent API like mautrix-python, room state storage, etc) +* End-to-end encryption support (incl. interactive SAS verification) +* Structs for parsing event content +* Helpers for parsing and generating Matrix HTML +* Helpers for handling push rules + This project contains modules that are licensed under Apache 2.0: * [maunium.net/go/mautrix/crypto/canonicaljson](crypto/canonicaljson) From 8c3d0bb2c1cf9cd5399b59e98a241a16ce4a272a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 24 Sep 2020 14:26:56 +0300 Subject: [PATCH 05/17] Add utility function for waiting for Megolm session to arrive --- crypto/keyimport.go | 1 + crypto/keysharing.go | 1 + crypto/machine.go | 35 +++++++++++++++++++++++++++++++++++ crypto/sql_store.go | 2 +- crypto/store.go | 5 +++-- 5 files changed, 41 insertions(+), 3 deletions(-) diff --git a/crypto/keyimport.go b/crypto/keyimport.go index 4c8fc163..6d4f88a1 100644 --- a/crypto/keyimport.go +++ b/crypto/keyimport.go @@ -118,6 +118,7 @@ func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, er if err != nil { return false, errors.Wrap(err, "failed to store imported session") } + mach.markSessionReceived(igs.ID()) return true, nil } diff --git a/crypto/keysharing.go b/crypto/keysharing.go index 8c481112..51f9896d 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -134,6 +134,7 @@ func (mach *OlmMachine) importForwardedRoomKey(evt *DecryptedOlmEvent, content * mach.Log.Error("Failed to store new inbound group session: %v", err) return false } + mach.markSessionReceived(content.SessionID) mach.Log.Trace("Created inbound group session %s/%s/%s", content.RoomID, content.SenderKey, content.SessionID) return true } diff --git a/crypto/machine.go b/crypto/machine.go index f1537d7b..ddb39f9f 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -49,6 +49,9 @@ type OlmMachine struct { roomKeyRequestFilled *sync.Map keyVerificationTransactionState *sync.Map + + keyWaiters map[id.SessionID]chan struct{} + keyWaitersLock sync.Mutex } // StateStore is used by OlmMachine to get room state information that's needed for encryption. @@ -311,10 +314,42 @@ func (mach *OlmMachine) createGroupSession(senderKey id.SenderKey, signingKey id err = mach.CryptoStore.PutGroupSession(roomID, senderKey, sessionID, igs) if err != nil { mach.Log.Error("Failed to store new inbound group session: %v", err) + return } + mach.markSessionReceived(sessionID) mach.Log.Trace("Created inbound group session %s/%s/%s", roomID, senderKey, sessionID) } +func (mach *OlmMachine) markSessionReceived(id id.SessionID) { + mach.keyWaitersLock.Lock() + ch, ok := mach.keyWaiters[id] + if ok { + close(ch) + delete(mach.keyWaiters, id) + } + mach.keyWaitersLock.Unlock() +} + +// WaitForSession waits for the given Megolm session to arrive. +func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { + mach.keyWaitersLock.Lock() + ch, ok := mach.keyWaiters[sessionID] + if !ok { + ch := make(chan struct{}) + mach.keyWaiters[sessionID] = ch + } + mach.keyWaitersLock.Unlock() + select { + case <-ch: + return true + case <-time.After(timeout): + sess, err := mach.CryptoStore.GetGroupSession(roomID, senderKey, sessionID) + // Check if the session somehow appeared in the store without telling us + // We accept withheld sessions as received, as then the decryption attempt will show the error. + return sess != nil || errors.Is(err, ErrGroupSessionWithheld) + } +} + func (mach *OlmMachine) receiveRoomKey(evt *DecryptedOlmEvent, content *event.RoomKeyEventContent) { // TODO nio had a comment saying "handle this better" for the case where evt.Keys.Ed25519 is none? if content.Algorithm != id.AlgorithmMegolmV1 || evt.Keys.Ed25519 == "" { diff --git a/crypto/sql_store.go b/crypto/sql_store.go index d1cb7f2a..82c6f4a4 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -224,7 +224,7 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send } else if err != nil { return nil, err } else if withheldCode.Valid { - return nil, ErrGroupSessionWithheld + return nil, fmt.Errorf("%w (%s)", ErrGroupSessionWithheld, withheldCode.String) } igs := olm.NewBlankInboundGroupSession() err = igs.Unpickle(sessionBytes, store.PickleKey) diff --git a/crypto/store.go b/crypto/store.go index b0ede8dd..3cf62086 100644 --- a/crypto/store.go +++ b/crypto/store.go @@ -9,6 +9,7 @@ package crypto import ( "encoding/gob" "errors" + "fmt" "os" "sort" "sync" @@ -294,10 +295,10 @@ func (gs *GobStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, se gs.lock.Lock() session, ok := gs.getGroupSessions(roomID, senderKey)[sessionID] if !ok { - _, ok := gs.getWithheldGroupSessions(roomID, senderKey)[sessionID] + withheld, ok := gs.getWithheldGroupSessions(roomID, senderKey)[sessionID] gs.lock.Unlock() if ok { - return nil, ErrGroupSessionWithheld + return nil, fmt.Errorf("%w (%s)", ErrGroupSessionWithheld, withheld.Code) } return nil, nil } From 72fc4c16439b5b1522b44aad8945bf3258401af0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 24 Sep 2020 14:41:29 +0300 Subject: [PATCH 06/17] Stop using github.com/pkg/errors --- crypto/decryptmegolm.go | 12 ++++++------ crypto/decryptolm.go | 18 +++++++++--------- crypto/devicelist.go | 5 +++-- crypto/encryptmegolm.go | 14 +++++++------- crypto/encryptolm.go | 5 ++--- crypto/keyexport.go | 3 +-- crypto/keyimport.go | 10 +++++----- crypto/machine.go | 14 +++++++------- crypto/sessions.go | 3 +-- crypto/sql_store.go | 13 ++++++------- crypto/sql_store_upgrade/upgrade.go | 9 +++++---- crypto/verification.go | 20 +++++++++++--------- event/content.go | 3 +-- event/encryption.go | 4 +--- go.mod | 1 - 15 files changed, 65 insertions(+), 69 deletions(-) diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index d410d30c..fd07a301 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -8,8 +8,8 @@ package crypto import ( "encoding/json" - - "github.com/pkg/errors" + "errors" + "fmt" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -39,14 +39,14 @@ func (mach *OlmMachine) DecryptMegolmEvent(evt *event.Event) (*event.Event, erro } sess, err := mach.CryptoStore.GetGroupSession(evt.RoomID, content.SenderKey, content.SessionID) if err != nil { - return nil, errors.Wrap(err, "failed to get group session") + return nil, fmt.Errorf("failed to get group session: %w", err) } else if sess == nil { mach.checkIfWedged(evt) return nil, NoSessionFound } plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext) if err != nil { - return nil, errors.Wrap(err, "failed to decrypt megolm event") + return nil, fmt.Errorf("failed to decrypt megolm event: %w", err) } else if !mach.CryptoStore.ValidateMessageIndex(content.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp) { return nil, DuplicateMessageIndex } @@ -72,13 +72,13 @@ func (mach *OlmMachine) DecryptMegolmEvent(evt *event.Event) (*event.Event, erro megolmEvt := &megolmEvent{} err = json.Unmarshal(plaintext, &megolmEvt) if err != nil { - return nil, errors.Wrap(err, "failed to parse megolm payload") + return nil, fmt.Errorf("failed to parse megolm payload: %w", err) } else if megolmEvt.RoomID != evt.RoomID { return nil, WrongRoom } err = megolmEvt.Content.ParseRaw(megolmEvt.Type) if err != nil && !event.IsUnsupportedContentType(err) { - return nil, errors.Wrap(err, "failed to parse content of megolm payload event") + return nil, fmt.Errorf("failed to parse content of megolm payload event: %w", err) } relatable, ok := megolmEvt.Content.Parsed.(event.Relatable) if ok && content.RelatesTo != nil && relatable.OptionalGetRelatesTo() == nil { diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index ce09e37e..80282dbd 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -8,8 +8,8 @@ package crypto import ( "encoding/json" - - "github.com/pkg/errors" + "errors" + "fmt" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -76,7 +76,7 @@ func (mach *OlmMachine) decryptOlmCiphertext(sender id.UserID, deviceID id.Devic mach.Log.Warn("Found matching session yet decryption failed for sender %s with key %s", sender, senderKey) mach.markDeviceForUnwedging(sender, senderKey) } - return nil, errors.Wrap(err, "failed to decrypt olm event") + return nil, fmt.Errorf("failed to decrypt olm event: %w", err) } // Decryption failed with every known session or no known sessions, let's try to create a new session. @@ -92,13 +92,13 @@ func (mach *OlmMachine) decryptOlmCiphertext(sender id.UserID, deviceID id.Devic session, err := mach.createInboundSession(senderKey, ciphertext) if err != nil { mach.markDeviceForUnwedging(sender, senderKey) - return nil, errors.Wrap(err, "failed to create new session from prekey message") + return nil, fmt.Errorf("failed to create new session from prekey message: %w", err) } mach.Log.Trace("Created inbound session %s for %s/%s (sender key: %s)", session.ID(), sender, deviceID, senderKey) plaintext, err = session.Decrypt(ciphertext, olmType) if err != nil { - return nil, errors.Wrap(err, "failed to decrypt olm event with session created from prekey message") + return nil, fmt.Errorf("failed to decrypt olm event with session created from prekey message: %w", err) } err = mach.CryptoStore.UpdateSession(senderKey, session) @@ -110,7 +110,7 @@ func (mach *OlmMachine) decryptOlmCiphertext(sender id.UserID, deviceID id.Devic var olmEvt DecryptedOlmEvent err = json.Unmarshal(plaintext, &olmEvt) if err != nil { - return nil, errors.Wrap(err, "failed to parse olm payload") + return nil, fmt.Errorf("failed to parse olm payload: %w", err) } if sender != olmEvt.Sender { return nil, SenderMismatch @@ -122,7 +122,7 @@ func (mach *OlmMachine) decryptOlmCiphertext(sender id.UserID, deviceID id.Devic err = olmEvt.Content.ParseRaw(olmEvt.Type) if err != nil && !event.IsUnsupportedContentType(err) { - return nil, errors.Wrap(err, "failed to parse content of olm payload event") + return nil, fmt.Errorf("failed to parse content of olm payload event: %w", err) } olmEvt.SenderKey = senderKey @@ -133,13 +133,13 @@ func (mach *OlmMachine) decryptOlmCiphertext(sender id.UserID, deviceID id.Devic func (mach *OlmMachine) tryDecryptOlmCiphertext(senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) { sessions, err := mach.CryptoStore.GetSessions(senderKey) if err != nil { - return nil, errors.Wrapf(err, "failed to get session for %s", senderKey) + return nil, fmt.Errorf("failed to get session for %s: %w", senderKey, err) } for _, session := range sessions { if olmType == id.OlmMsgTypePreKey { matches, err := session.Internal.MatchesInboundSession(ciphertext) if err != nil { - return nil, errors.Wrap(err, "failed to check if ciphertext matches inbound session") + return nil, fmt.Errorf("failed to check if ciphertext matches inbound session: %w", err) } else if !matches { continue } diff --git a/crypto/devicelist.go b/crypto/devicelist.go index cc9c7b3b..a23df356 100644 --- a/crypto/devicelist.go +++ b/crypto/devicelist.go @@ -7,7 +7,8 @@ package crypto import ( - "github.com/pkg/errors" + "errors" + "fmt" "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/olm" @@ -131,7 +132,7 @@ func (mach *OlmMachine) validateDevice(userID id.UserID, deviceID id.DeviceID, d ok, err := olm.VerifySignatureJSON(deviceKeys, userID, deviceID, signingKey) if err != nil { - return existing, errors.Wrap(err, "failed to verify signature") + return existing, fmt.Errorf("failed to verify signature: %w", err) } else if !ok { return existing, InvalidKeySignature } diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index 8497572a..60cdc8d1 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -8,8 +8,8 @@ package crypto import ( "encoding/json" - - "github.com/pkg/errors" + "errors" + "fmt" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" @@ -34,8 +34,8 @@ func getRelatesTo(content interface{}) *event.RelatesTo { } type rawMegolmEvent struct { - RoomID id.RoomID `json:"room_id"` - Type event.Type `json:"type"` + RoomID id.RoomID `json:"room_id"` + Type event.Type `json:"type"` Content interface{} `json:"content"` } @@ -52,7 +52,7 @@ func (mach *OlmMachine) EncryptMegolmEvent(roomID id.RoomID, evtType event.Type, mach.Log.Trace("Encrypting event of type %s for %s", evtType.Type, roomID) session, err := mach.CryptoStore.GetOutboundGroupSession(roomID) if err != nil { - return nil, errors.Wrap(err, "failed to get outbound group session") + return nil, fmt.Errorf("failed to get outbound group session: %w", err) } else if session == nil { return nil, NoGroupSession } @@ -97,7 +97,7 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e mach.Log.Debug("Sharing group session for room %s to %v", roomID, users) session, err := mach.CryptoStore.GetOutboundGroupSession(roomID) if err != nil { - return errors.Wrap(err, "failed to get previous outbound group session") + return fmt.Errorf("failed to get previous outbound group session: %w", err) } else if session != nil && session.Shared && !session.Expired() { return AlreadyShared } @@ -180,7 +180,7 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e mach.Log.Trace("Sending to-device to %d users to share group session for %s", len(toDevice.Messages), roomID) _, err = mach.Client.SendToDevice(event.ToDeviceEncrypted, toDevice) if err != nil { - return errors.Wrap(err, "failed to share group session") + return fmt.Errorf("failed to share group session: %w", err) } mach.Log.Trace("Sending to-device messages to %d users to report withheld keys in %s", len(toDeviceWithheld.Messages), roomID) diff --git a/crypto/encryptolm.go b/crypto/encryptolm.go index 09218a39..3408f742 100644 --- a/crypto/encryptolm.go +++ b/crypto/encryptolm.go @@ -8,8 +8,7 @@ package crypto import ( "encoding/json" - - "github.com/pkg/errors" + "fmt" "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/olm" @@ -69,7 +68,7 @@ func (mach *OlmMachine) createOutboundSessions(input map[id.UserID]map[id.Device Timeout: 10 * 1000, }) if err != nil { - return errors.Wrap(err, "failed to claim keys") + return fmt.Errorf("failed to claim keys: %w", err) } for userID, user := range resp.OneTimeKeys { for deviceID, oneTimeKeys := range user { diff --git a/crypto/keyexport.go b/crypto/keyexport.go index bebdc32b..511a56dd 100644 --- a/crypto/keyexport.go +++ b/crypto/keyexport.go @@ -20,7 +20,6 @@ import ( "fmt" "math" - "github.com/pkg/errors" "golang.org/x/crypto/pbkdf2" "maunium.net/go/mautrix/crypto/olm" @@ -91,7 +90,7 @@ func exportSessions(sessions []*InboundGroupSession) ([]ExportedSession, error) for i, session := range sessions { key, err := session.Internal.Export(session.Internal.FirstKnownIndex()) if err != nil { - return nil, errors.Wrap(err, "failed to export session") + return nil, fmt.Errorf("failed to export session: %w", err) } export[i] = ExportedSession{ Algorithm: id.AlgorithmMegolmV1, diff --git a/crypto/keyimport.go b/crypto/keyimport.go index 6d4f88a1..4144f070 100644 --- a/crypto/keyimport.go +++ b/crypto/keyimport.go @@ -15,8 +15,8 @@ import ( "encoding/base64" "encoding/binary" "encoding/json" - - "github.com/pkg/errors" + "errors" + "fmt" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" @@ -85,7 +85,7 @@ func decryptKeyExport(passphrase string, exportData []byte) ([]ExportedSession, var sessionsJSON []ExportedSession err := json.Unmarshal(unencryptedData, &sessionsJSON) if err != nil { - return nil, errors.Wrap(err, "invalid export json") + return nil, fmt.Errorf("invalid export json: %w", err) } return sessionsJSON, nil } @@ -97,7 +97,7 @@ func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, er igsInternal, err := olm.InboundGroupSessionImport([]byte(session.SessionKey)) if err != nil { - return false, errors.Wrap(err, "failed to import session") + return false, fmt.Errorf("failed to import session: %w", err) } else if igsInternal.ID() != session.SessionID { return false, ErrMismatchingExportedSessionID } @@ -116,7 +116,7 @@ func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, er } err = mach.CryptoStore.PutGroupSession(igs.RoomID, igs.SenderKey, igs.ID(), igs) if err != nil { - return false, errors.Wrap(err, "failed to store imported session") + return false, fmt.Errorf("failed to store imported session: %w", err) } mach.markSessionReceived(igs.ID()) return true, nil diff --git a/crypto/machine.go b/crypto/machine.go index ddb39f9f..3b34b039 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -7,11 +7,11 @@ package crypto import ( + "errors" + "fmt" "sync" "time" - "github.com/pkg/errors" - "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" @@ -50,7 +50,7 @@ type OlmMachine struct { roomKeyRequestFilled *sync.Map keyVerificationTransactionState *sync.Map - keyWaiters map[id.SessionID]chan struct{} + keyWaiters map[id.SessionID]chan struct{} keyWaitersLock sync.Mutex } @@ -251,7 +251,7 @@ func (mach *OlmMachine) GetOrFetchDevice(userID id.UserID, deviceID id.DeviceID) // get device identity device, err := mach.CryptoStore.GetDevice(userID, deviceID) if err != nil { - return nil, errors.Wrap(err, "failed to get sender device from store") + return nil, fmt.Errorf("failed to get sender device from store: %w", err) } else if device != nil { return device, nil } @@ -261,9 +261,9 @@ func (mach *OlmMachine) GetOrFetchDevice(userID id.UserID, deviceID id.DeviceID) if device, ok = devices[deviceID]; ok { return device, nil } - return nil, errors.Errorf("Failed to get identity for device %v", deviceID) + return nil, fmt.Errorf("didn't get identity for device %s of %s", deviceID, userID) } - return nil, errors.Errorf("Error fetching devices for user %v", userID) + return nil, fmt.Errorf("didn't get any devices for %s", userID) } // SendEncryptedToDevice sends an Olm-encrypted event to the given user device. @@ -283,7 +283,7 @@ func (mach *OlmMachine) SendEncryptedToDevice(device *DeviceIdentity, content ev return err } if olmSess == nil { - return errors.Errorf("Did not find created outbound session for device %v", device.DeviceID) + return fmt.Errorf("didn't find created outbound session for device %s of %s", device.DeviceID, device.UserID) } encrypted := mach.encryptOlmEvent(olmSess, device, event.ToDeviceForwardedRoomKey, content) diff --git a/crypto/sessions.go b/crypto/sessions.go index c84d64e6..1cfc8296 100644 --- a/crypto/sessions.go +++ b/crypto/sessions.go @@ -7,11 +7,10 @@ package crypto import ( + "errors" "strings" "time" - "github.com/pkg/errors" - "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/event" diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 82c6f4a4..87cc5eb9 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -12,7 +12,6 @@ import ( "strings" "github.com/lib/pq" - "github.com/pkg/errors" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/sql_store_upgrade" @@ -183,7 +182,7 @@ func (store *SQLCryptoStore) AddSession(key id.SenderKey, session *OlmSession) e } // UpdateSession replaces the Olm session for a sender in the database. -func (store *SQLCryptoStore) UpdateSession(key id.SenderKey, session *OlmSession) error { +func (store *SQLCryptoStore) UpdateSession(_ id.SenderKey, session *OlmSession) error { sessionBytes := session.Internal.Pickle(store.PickleKey) _, err := store.DB.Exec("UPDATE crypto_olm_session SET session=$1, last_used=$2 WHERE session_id=$3 AND account_id=$4", sessionBytes, session.UseTime, session.ID(), store.AccountID) @@ -492,18 +491,18 @@ func (store *SQLCryptoStore) PutDevices(userID id.UserID, devices map[id.DeviceI err = fmt.Errorf("unsupported dialect %s", store.Dialect) } if err != nil { - return errors.Wrap(err, "failed to add user to tracked users list") + return fmt.Errorf("failed to add user to tracked users list: %w", err) } _, err = tx.Exec("DELETE FROM crypto_device WHERE user_id=$1", userID) if err != nil { _ = tx.Rollback() - return errors.Wrap(err, "failed to delete old devices") + return fmt.Errorf("failed to delete old devices: %w", err) } if len(devices) == 0 { err = tx.Commit() if err != nil { - return errors.Wrap(err, "failed to commit changes (no devices added)") + return fmt.Errorf("failed to commit changes (no devices added): %w", err) } return nil } @@ -533,12 +532,12 @@ func (store *SQLCryptoStore) PutDevices(userID id.UserID, devices map[id.DeviceI _, err = tx.Exec("INSERT INTO crypto_device (user_id, device_id, identity_key, signing_key, trust, deleted, name) VALUES "+valueString, values...) if err != nil { _ = tx.Rollback() - return errors.Wrap(err, "failed to insert new devices") + return fmt.Errorf("failed to insert new devices: %w", err) } } err = tx.Commit() if err != nil { - return errors.Wrap(err, "failed to commit changes") + return fmt.Errorf("failed to commit changes: %w", err) } return nil } diff --git a/crypto/sql_store_upgrade/upgrade.go b/crypto/sql_store_upgrade/upgrade.go index aa982bae..80ced0f7 100644 --- a/crypto/sql_store_upgrade/upgrade.go +++ b/crypto/sql_store_upgrade/upgrade.go @@ -2,14 +2,15 @@ package sql_store_upgrade import ( "database/sql" + "errors" "fmt" "strings" - - "github.com/pkg/errors" ) type upgradeFunc func(*sql.Tx, string) error +var ErrUnknownDialect = errors.New("unknown dialect") + var Upgrades = [...]upgradeFunc{ func(tx *sql.Tx, _ string) error { for _, query := range []string{ @@ -153,7 +154,7 @@ var Upgrades = [...]upgradeFunc{ } } } else { - return errors.New("unknown dialect: " + dialect) + return fmt.Errorf("%w (%s)", ErrUnknownDialect, dialect) } return nil }, @@ -203,7 +204,7 @@ var Upgrades = [...]upgradeFunc{ return err } } else { - return errors.New("unknown dialect: " + dialect) + return fmt.Errorf("%w (%s)", ErrUnknownDialect, dialect) } return nil }, diff --git a/crypto/verification.go b/crypto/verification.go index 5f116f99..e542a6d2 100644 --- a/crypto/verification.go +++ b/crypto/verification.go @@ -11,6 +11,7 @@ package crypto import ( "context" "encoding/json" + "errors" "fmt" "math/rand" "sort" @@ -19,8 +20,6 @@ import ( "sync" "time" - "github.com/pkg/errors" - "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/canonicaljson" "maunium.net/go/mautrix/crypto/olm" @@ -28,11 +27,14 @@ import ( "maunium.net/go/mautrix/id" ) -// ErrUnknownTransaction is returned when a key verification message is received with an unknown transaction ID. -var ErrUnknownTransaction = errors.New("Unknown transaction") - -// ErrUnknownVerificationMethod is returned when the verification method in a received m.key.verification.start is unknown. -var ErrUnknownVerificationMethod = errors.New("Unknown verification method") +var ( + ErrUnknownUserForTransaction = errors.New("unknown user for transaction") + ErrTransactionAlreadyExists = errors.New("transaction already exists") + // ErrUnknownTransaction is returned when a key verification message is received with an unknown transaction ID. + ErrUnknownTransaction = errors.New("unknown transaction") + // ErrUnknownVerificationMethod is returned when the verification method in a received m.key.verification.start is unknown. + ErrUnknownVerificationMethod = errors.New("unknown verification method") +) type VerificationHooks interface { // VerifySASMatch receives the generated SAS and its method, as well as the device that is being verified. @@ -130,7 +132,7 @@ func (mach *OlmMachine) getTransactionState(transactionID string, userID id.User reason := fmt.Sprintf("Unknown user for transaction %v: %v", transactionID, userID) _ = mach.SendSASVerificationCancel(userID, id.DeviceID("*"), transactionID, reason, event.VerificationCancelUserMismatch) mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID) - return nil, errors.New(reason) + return nil, fmt.Errorf("%w %s: %s", ErrUnknownUserForTransaction, transactionID, userID) } return verState, nil } @@ -544,7 +546,7 @@ func (mach *OlmMachine) NewSASVerificationWith(device *DeviceIdentity, hooks Ver verState.startEventCanonical = string(canonical) _, loaded := mach.keyVerificationTransactionState.LoadOrStore(device.UserID.String()+":"+transactionID, verState) if loaded { - return "", errors.New("Transaction already exists") + return "", ErrTransactionAlreadyExists } mach.timeoutAfter(verState, transactionID, timeout) diff --git a/event/content.go b/event/content.go index 5f3abe6b..37831334 100644 --- a/event/content.go +++ b/event/content.go @@ -9,11 +9,10 @@ package event import ( "encoding/gob" "encoding/json" + "errors" "fmt" "reflect" "strings" - - "github.com/pkg/errors" ) // TypeMap is a mapping from event type to the content struct type. diff --git a/event/encryption.go b/event/encryption.go index 58539906..4c9bdac3 100644 --- a/event/encryption.go +++ b/event/encryption.go @@ -9,8 +9,6 @@ package event import ( "encoding/json" - "github.com/pkg/errors" - "maunium.net/go/mautrix/id" ) @@ -58,7 +56,7 @@ func (content *EncryptedEventContent) UnmarshalJSON(data []byte) error { return json.Unmarshal(content.Ciphertext, &content.OlmCiphertext) case id.AlgorithmMegolmV1: if len(content.Ciphertext) == 0 || content.Ciphertext[0] != '"' || content.Ciphertext[len(content.Ciphertext)-1] != '"' { - return errors.New("input doesn't look like a JSON string") + return id.InputNotJSONString } content.MegolmCiphertext = content.Ciphertext[1 : len(content.Ciphertext)-1] } diff --git a/go.mod b/go.mod index ff27c35f..6c8c68d5 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,6 @@ require ( github.com/gorilla/mux v1.7.4 github.com/lib/pq v1.7.0 github.com/mattn/go-sqlite3 v1.14.0 - github.com/pkg/errors v0.9.1 github.com/russross/blackfriday/v2 v2.0.1 github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect github.com/stretchr/testify v1.6.1 From 702eaeade978bf18f48713ade47d78836ed5d13c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 24 Sep 2020 15:23:32 +0300 Subject: [PATCH 07/17] Bump version to v0.7.7 --- version.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.go b/version.go index 43c72fa5..c39c0b8a 100644 --- a/version.go +++ b/version.go @@ -1,3 +1,3 @@ package mautrix -const Version = "v0.7.6" +const Version = "v0.7.7" From 819fedddbba5c3955205ef3c42519481e00800cb Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 29 Sep 2020 11:16:56 +0300 Subject: [PATCH 08/17] Fix nil map assingment when waiting for sessions --- crypto/encryptmegolm.go | 2 +- crypto/machine.go | 2 ++ version.go | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index 60cdc8d1..e817a2c1 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -191,7 +191,7 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e mach.Log.Warn("Failed to report withheld keys in %s: %v", roomID, err) } - mach.Log.Debug("Group session for %s successfully shared", roomID) + mach.Log.Debug("Group session %s for %s successfully shared", session.ID(), roomID) session.Shared = true return mach.CryptoStore.AddOutboundGroupSession(session) } diff --git a/crypto/machine.go b/crypto/machine.go index 3b34b039..6992bb11 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -83,6 +83,8 @@ func NewOlmMachine(client *mautrix.Client, log Logger, cryptoStore Store, stateS roomKeyRequestFilled: &sync.Map{}, keyVerificationTransactionState: &sync.Map{}, + + keyWaiters: make(map[id.SessionID]chan struct{}), } mach.AllowKeyShare = mach.defaultAllowKeyShare return mach diff --git a/version.go b/version.go index c39c0b8a..ae7481e7 100644 --- a/version.go +++ b/version.go @@ -1,3 +1,3 @@ package mautrix -const Version = "v0.7.7" +const Version = "v0.7.8" From 575e242018a1940b25e11d7ce12e8d8fc0978935 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 2 Oct 2020 01:07:53 +0300 Subject: [PATCH 09/17] Lock olm sessions between encrypting and sending --- crypto/encryptmegolm.go | 92 ++++++++++++++++++++++++++++------------- crypto/machine.go | 5 ++- crypto/sessions.go | 16 +++++++ crypto/sql_store.go | 72 ++++++++++++++++++++++++++------ crypto/store.go | 1 + version.go | 2 +- 6 files changed, 144 insertions(+), 44 deletions(-) diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index e817a2c1..c759e8a5 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -89,6 +89,11 @@ func (mach *OlmMachine) newOutboundGroupSession(roomID id.RoomID) *OutboundGroup return session } +type deviceSessionWrapper struct { + session *OlmSession + identity *DeviceIdentity +} + // ShareGroupSession shares a group session for a specific room with all the devices of the given user list. // // For devices with TrustStateBlacklisted, a m.room_key.withheld event with code=m.blacklisted is sent. @@ -105,8 +110,9 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e session = mach.newOutboundGroupSession(roomID) } - toDevice := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)} + withheldCount := 0 toDeviceWithheld := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)} + olmSessions := make(map[id.UserID]map[id.DeviceID]deviceSessionWrapper) missingSessions := make(map[id.UserID]map[id.DeviceID]*DeviceIdentity) missingUserSessions := make(map[id.DeviceID]*DeviceIdentity) var fetchKeys []id.UserID @@ -122,9 +128,10 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e mach.Log.Trace("%s has no devices, skipping", userID) } else { mach.Log.Trace("Trying to encrypt group session %s for %s", session.ID(), userID) - toDevice.Messages[userID] = make(map[id.DeviceID]*event.Content) toDeviceWithheld.Messages[userID] = make(map[id.DeviceID]*event.Content) - mach.encryptGroupSessionForUser(session, userID, devices, toDevice.Messages[userID], toDeviceWithheld.Messages[userID], missingUserSessions) + olmSessions[userID] = make(map[id.DeviceID]deviceSessionWrapper) + mach.findOlmSessionsForUser(session, userID, devices, olmSessions[userID], toDeviceWithheld.Messages[userID], missingUserSessions) + withheldCount += len(toDeviceWithheld.Messages[userID]) if len(missingUserSessions) > 0 { missingSessions[userID] = missingUserSessions missingUserSessions = make(map[id.DeviceID]*DeviceIdentity) @@ -132,9 +139,6 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e if len(toDeviceWithheld.Messages[userID]) == 0 { delete(toDeviceWithheld.Messages, userID) } - if len(toDevice.Messages[userID]) == 0 { - delete(toDevice.Messages, userID) - } } } @@ -146,10 +150,12 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e } } - mach.Log.Trace("Creating missing outbound sessions") - err = mach.createOutboundSessions(missingSessions) - if err != nil { - mach.Log.Error("Failed to create missing outbound sessions: %v", err) + if len(missingSessions) > 0 { + mach.Log.Trace("Creating missing outbound sessions") + err = mach.createOutboundSessions(missingSessions) + if err != nil { + mach.Log.Error("Failed to create missing outbound sessions: %v", err) + } } for userID, devices := range missingSessions { @@ -157,10 +163,10 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e // No missing sessions continue } - output, ok := toDevice.Messages[userID] + output, ok := olmSessions[userID] if !ok { - output = make(map[id.DeviceID]*event.Content) - toDevice.Messages[userID] = output + output = make(map[id.DeviceID]deviceSessionWrapper) + olmSessions[userID] = output } withheld, ok := toDeviceWithheld.Messages[userID] if !ok { @@ -168,27 +174,29 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e toDeviceWithheld.Messages[userID] = withheld } mach.Log.Trace("Trying to encrypt group session %s for %s (post-fetch retry)", session.ID(), userID) - mach.encryptGroupSessionForUser(session, userID, devices, output, withheld, nil) + mach.findOlmSessionsForUser(session, userID, devices, output, withheld, nil) + withheldCount += len(toDeviceWithheld.Messages[userID]) if len(toDeviceWithheld.Messages[userID]) == 0 { delete(toDeviceWithheld.Messages, userID) } - if len(toDevice.Messages[userID]) == 0 { - delete(toDevice.Messages, userID) - } } - mach.Log.Trace("Sending to-device to %d users to share group session for %s", len(toDevice.Messages), roomID) - _, err = mach.Client.SendToDevice(event.ToDeviceEncrypted, toDevice) + err = mach.encryptAndSendGroupSession(session, olmSessions) if err != nil { return fmt.Errorf("failed to share group session: %w", err) } - mach.Log.Trace("Sending to-device messages to %d users to report withheld keys in %s", len(toDeviceWithheld.Messages), roomID) - // TODO remove the next line once clients support m.room_key.withheld - _, _ = mach.Client.SendToDevice(event.ToDeviceOrgMatrixRoomKeyWithheld, toDeviceWithheld) - _, err = mach.Client.SendToDevice(event.ToDeviceRoomKeyWithheld, toDeviceWithheld) - if err != nil { - mach.Log.Warn("Failed to report withheld keys in %s: %v", roomID, err) + if len(toDeviceWithheld.Messages) > 0 { + mach.Log.Trace("Sending to-device messages to %d devices of %d users to report withheld keys in %s", withheldCount, len(toDeviceWithheld.Messages), roomID) + // TODO remove the next 4 lines once clients support m.room_key.withheld + _, err = mach.Client.SendToDevice(event.ToDeviceOrgMatrixRoomKeyWithheld, toDeviceWithheld) + if err != nil { + mach.Log.Warn("Failed to report withheld keys in %s (legacy event type): %v", roomID, err) + } + _, err = mach.Client.SendToDevice(event.ToDeviceRoomKeyWithheld, toDeviceWithheld) + if err != nil { + mach.Log.Warn("Failed to report withheld keys in %s: %v", roomID, err) + } } mach.Log.Debug("Group session %s for %s successfully shared", session.ID(), roomID) @@ -196,7 +204,32 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e return mach.CryptoStore.AddOutboundGroupSession(session) } -func (mach *OlmMachine) encryptGroupSessionForUser(session *OutboundGroupSession, userID id.UserID, devices map[id.DeviceID]*DeviceIdentity, output, withheld map[id.DeviceID]*event.Content, missingOutput map[id.DeviceID]*DeviceIdentity) { +func (mach *OlmMachine) encryptAndSendGroupSession(session *OutboundGroupSession, olmSessions map[id.UserID]map[id.DeviceID]deviceSessionWrapper) error { + deviceCount := 0 + toDevice := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)} + for userID, sessions := range olmSessions { + if len(sessions) == 0 { + continue + } + output := make(map[id.DeviceID]*event.Content) + toDevice.Messages[userID] = output + for deviceID, device := range sessions { + device.session.Lock() + // We intentionally defer in a loop as it's the safest way of making sure nothing gets locked permanently. + defer device.session.Unlock() + content := mach.encryptOlmEvent(device.session, device.identity, event.ToDeviceRoomKey, session.ShareContent()) + output[deviceID] = &event.Content{Parsed: content} + deviceCount++ + mach.Log.Trace("Encrypted group session %s for %s of %s", session.ID(), deviceID, userID) + } + } + + mach.Log.Trace("Sending to-device to %d devices of %d users to share group session %s", deviceCount, len(toDevice.Messages), session.ID()) + _, err := mach.Client.SendToDevice(event.ToDeviceEncrypted, toDevice) + return err +} + +func (mach *OlmMachine) findOlmSessionsForUser(session *OutboundGroupSession, userID id.UserID, devices map[id.DeviceID]*DeviceIdentity, output map[id.DeviceID]deviceSessionWrapper, withheld map[id.DeviceID]*event.Content, missingOutput map[id.DeviceID]*DeviceIdentity) { for deviceID, device := range devices { userKey := UserDevice{UserID: userID, DeviceID: deviceID} if state := session.Users[userKey]; state != OGSNotShared { @@ -233,10 +266,11 @@ func (mach *OlmMachine) encryptGroupSessionForUser(session *OutboundGroupSession missingOutput[deviceID] = device } } else { - content := mach.encryptOlmEvent(deviceSession, device, event.ToDeviceRoomKey, session.ShareContent()) - output[deviceID] = &event.Content{Parsed: content} + output[deviceID] = deviceSessionWrapper{ + session: deviceSession, + identity: device, + } session.Users[userKey] = OGSAlreadyShared - mach.Log.Trace("Encrypted group session %s for %s of %s", session.ID(), deviceID, userID) } } } diff --git a/crypto/machine.go b/crypto/machine.go index 6992bb11..1a6f8d35 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -288,6 +288,9 @@ func (mach *OlmMachine) SendEncryptedToDevice(device *DeviceIdentity, content ev return fmt.Errorf("didn't find created outbound session for device %s of %s", device.DeviceID, device.UserID) } + olmSess.Lock() + defer olmSess.Unlock() + encrypted := mach.encryptOlmEvent(olmSess, device, event.ToDeviceForwardedRoomKey, content) encryptedContent := &event.Content{Parsed: &encrypted} @@ -319,7 +322,7 @@ func (mach *OlmMachine) createGroupSession(senderKey id.SenderKey, signingKey id return } mach.markSessionReceived(sessionID) - mach.Log.Trace("Created inbound group session %s/%s/%s", roomID, senderKey, sessionID) + mach.Log.Debug("Received inbound group session %s / %s / %s", roomID, senderKey, sessionID) } func (mach *OlmMachine) markSessionReceived(id id.SessionID) { diff --git a/crypto/sessions.go b/crypto/sessions.go index 1cfc8296..57af43ba 100644 --- a/crypto/sessions.go +++ b/crypto/sessions.go @@ -9,6 +9,7 @@ package crypto import ( "errors" "strings" + "sync" "time" "maunium.net/go/mautrix/crypto/olm" @@ -42,6 +43,20 @@ type OlmSession struct { Internal olm.Session ExpirationMixin id id.SessionID + // This is unexported so gob wouldn't insist on trying to marshaling it + lock sync.Locker +} + +func (session *OlmSession) SetLock(lock sync.Locker) { + session.lock = lock +} + +func (session *OlmSession) Lock() { + session.lock.Lock() +} + +func (session *OlmSession) Unlock() { + session.lock.Unlock() } func (session *OlmSession) ID() id.SessionID { @@ -54,6 +69,7 @@ func (session *OlmSession) ID() id.SessionID { func wrapSession(session *olm.Session) *OlmSession { return &OlmSession{ Internal: *session, + lock: &sync.Mutex{}, ExpirationMixin: ExpirationMixin{ TimeMixin: TimeMixin{ CreationTime: time.Now(), diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 87cc5eb9..959ecc66 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -10,6 +10,7 @@ import ( "database/sql" "fmt" "strings" + "sync" "github.com/lib/pq" @@ -30,6 +31,9 @@ type SQLCryptoStore struct { SyncToken string PickleKey []byte Account *OlmAccount + + olmSessionCache map[id.SenderKey]map[id.SessionID]*OlmSession + olmSessionCacheLock sync.Mutex } var _ Store = (*SQLCryptoStore)(nil) @@ -44,6 +48,8 @@ func NewSQLCryptoStore(db *sql.DB, dialect string, accountID string, deviceID id PickleKey: pickleKey, AccountID: accountID, DeviceID: deviceID, + + olmSessionCache: make(map[id.SenderKey]map[id.SessionID]*OlmSession), } } @@ -124,7 +130,12 @@ func (store *SQLCryptoStore) GetAccount() (*OlmAccount, error) { // HasSession returns whether there is an Olm session for the given sender key. func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool { - // TODO this may need to be changed if olm sessions start expiring + store.olmSessionCacheLock.Lock() + cache, ok := store.olmSessionCache[key] + store.olmSessionCacheLock.Unlock() + if ok && len(cache) > 0 { + return true + } var sessionID id.SessionID err := store.DB.QueryRow("SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 LIMIT 1", key, store.AccountID).Scan(&sessionID) @@ -136,48 +147,83 @@ func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool { // GetSessions returns all the known Olm sessions for a sender key. func (store *SQLCryptoStore) GetSessions(key id.SenderKey) (OlmSessionList, error) { - rows, err := store.DB.Query("SELECT session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY session_id", + rows, err := store.DB.Query("SELECT session_id, session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY session_id", key, store.AccountID) if err != nil { return nil, err } list := OlmSessionList{} + store.olmSessionCacheLock.Lock() + defer store.olmSessionCacheLock.Unlock() + cache := store.getOlmSessionCache(key) for rows.Next() { - sess := OlmSession{Internal: *olm.NewBlankSession()} + sess := OlmSession{Internal: *olm.NewBlankSession(), lock: &sync.Mutex{}} var sessionBytes []byte - err := rows.Scan(&sessionBytes, &sess.CreationTime, &sess.UseTime) + var sessionID id.SessionID + err := rows.Scan(&sessionID, &sessionBytes, &sess.CreationTime, &sess.UseTime) if err != nil { return nil, err + } else if existing, ok := cache[sessionID]; ok { + list = append(list, existing) + } else { + err = sess.Internal.Unpickle(sessionBytes, store.PickleKey) + if err != nil { + return nil, err + } + list = append(list, &sess) + cache[sess.ID()] = &sess } - err = sess.Internal.Unpickle(sessionBytes, store.PickleKey) - if err != nil { - return nil, err - } - list = append(list, &sess) } return list, nil } +func (store *SQLCryptoStore) getOlmSessionCache(key id.SenderKey) map[id.SessionID]*OlmSession { + data, ok := store.olmSessionCache[key] + if !ok { + data = make(map[id.SessionID]*OlmSession) + store.olmSessionCache[key] = data + } + return data +} + // GetLatestSession retrieves the Olm session for a given sender key from the database that has the largest ID. func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*OlmSession, error) { - row := store.DB.QueryRow("SELECT session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY session_id DESC LIMIT 1", + store.olmSessionCacheLock.Lock() + defer store.olmSessionCacheLock.Unlock() + + row := store.DB.QueryRow("SELECT session_id, session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY session_id DESC LIMIT 1", key, store.AccountID) - sess := OlmSession{Internal: *olm.NewBlankSession()} + + sess := OlmSession{Internal: *olm.NewBlankSession(), lock: &sync.Mutex{}} var sessionBytes []byte - err := row.Scan(&sessionBytes, &sess.CreationTime, &sess.UseTime) + var sessionID id.SessionID + + err := row.Scan(&sessionID, &sessionBytes, &sess.CreationTime, &sess.UseTime) if err == sql.ErrNoRows { return nil, nil } else if err != nil { return nil, err } - return &sess, sess.Internal.Unpickle(sessionBytes, store.PickleKey) + + cache := store.getOlmSessionCache(key) + if oldSess, ok := cache[sessionID]; ok { + return oldSess, nil + } else if err = sess.Internal.Unpickle(sessionBytes, store.PickleKey); err != nil { + return nil, err + } else { + cache[sessionID] = &sess + return &sess, nil + } } // AddSession persists an Olm session for a sender in the database. func (store *SQLCryptoStore) AddSession(key id.SenderKey, session *OlmSession) error { + store.olmSessionCacheLock.Lock() + defer store.olmSessionCacheLock.Unlock() sessionBytes := session.Internal.Pickle(store.PickleKey) _, err := store.DB.Exec("INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_used, account_id) VALUES ($1, $2, $3, $4, $5, $6)", session.ID(), key, sessionBytes, session.CreationTime, session.UseTime, store.AccountID) + store.getOlmSessionCache(key)[session.ID()] = session return err } diff --git a/crypto/store.go b/crypto/store.go index 3cf62086..bbec7da7 100644 --- a/crypto/store.go +++ b/crypto/store.go @@ -66,6 +66,7 @@ var ErrGroupSessionWithheld = errors.New("group session has been withheld") // General implementation details: // * Get methods should not return errors if the requested data does not exist in the store, they should simply return nil. // * Update methods may assume that the pointer is the same as what has earlier been added to or fetched from the store. +// * OlmSessions should be cached so that the mutex works. Alternatively, implementations can use OlmSession.SetLock to provide a custom mutex implementation. type Store interface { // Flush ensures that everything in the store is persisted to disk. // This doesn't have to do anything, e.g. for database-backed implementations that persist everything immediately. diff --git a/version.go b/version.go index ae7481e7..18e1925a 100644 --- a/version.go +++ b/version.go @@ -1,3 +1,3 @@ package mautrix -const Version = "v0.7.8" +const Version = "v0.7.9" From 1b3818a2b4ecb1fe0c05f6870519ad6aee4024ac Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 5 Oct 2020 22:36:00 +0300 Subject: [PATCH 10/17] Add session ID to decrypt error --- crypto/decryptmegolm.go | 4 ++-- version.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index fd07a301..177da11e 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -18,7 +18,7 @@ import ( var ( IncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent") NoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found") - DuplicateMessageIndex = errors.New("duplicate message index") + DuplicateMessageIndex = errors.New("duplicate megolm message index") WrongRoom = errors.New("encrypted megolm event is not intended for this room") DeviceKeyMismatch = errors.New("device keys in event and verified device info do not match") ) @@ -42,7 +42,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(evt *event.Event) (*event.Event, erro return nil, fmt.Errorf("failed to get group session: %w", err) } else if sess == nil { mach.checkIfWedged(evt) - return nil, NoSessionFound + return nil, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID) } plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext) if err != nil { diff --git a/version.go b/version.go index 18e1925a..ff66310a 100644 --- a/version.go +++ b/version.go @@ -1,3 +1,3 @@ package mautrix -const Version = "v0.7.9" +const Version = "v0.7.10" From a31302832e9d5494e505c1ba50164c58e70c19b7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 6 Oct 2020 21:38:30 +0300 Subject: [PATCH 11/17] Move creating inbound session log to debug level --- crypto/decryptolm.go | 2 +- version.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index 80282dbd..6dd6f38e 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -94,7 +94,7 @@ func (mach *OlmMachine) decryptOlmCiphertext(sender id.UserID, deviceID id.Devic mach.markDeviceForUnwedging(sender, senderKey) return nil, fmt.Errorf("failed to create new session from prekey message: %w", err) } - mach.Log.Trace("Created inbound session %s for %s/%s (sender key: %s)", session.ID(), sender, deviceID, senderKey) + mach.Log.Debug("Created inbound olm session %s for %s/%s (sender key: %s)", session.ID(), sender, deviceID, senderKey) plaintext, err = session.Decrypt(ciphertext, olmType) if err != nil { diff --git a/version.go b/version.go index ff66310a..991817e1 100644 --- a/version.go +++ b/version.go @@ -1,3 +1,3 @@ package mautrix -const Version = "v0.7.10" +const Version = "v0.7.11" From af317828466d6cbb0cd703d728ec21ca5677f64b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 9 Oct 2020 18:22:44 +0300 Subject: [PATCH 12/17] Update go.sum --- go.mod | 1 + go.sum | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index 6c8c68d5..78fea472 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/stretchr/testify v1.6.1 github.com/tidwall/gjson v1.6.0 github.com/tidwall/sjson v1.1.1 + golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 golang.org/x/net v0.0.0-20200602114024-627f9648deb9 gopkg.in/yaml.v2 v2.3.0 maunium.net/go/maulogger/v2 v2.1.1 diff --git a/go.sum b/go.sum index d8dd1347..d36b3394 100644 --- a/go.sum +++ b/go.sum @@ -8,8 +8,6 @@ github.com/lib/pq v1.7.0 h1:h93mCPfUSkaul3Ka/VG8uZdmW1uMHDGxzu0NWHuJmHY= github.com/lib/pq v1.7.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-sqlite3 v1.14.0 h1:mLyGNKR8+Vv9CAU7PphKa2hkEqxxhn8i32J6FPj1/QA= github.com/mattn/go-sqlite3 v1.14.0/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus= -github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q= @@ -28,6 +26,7 @@ github.com/tidwall/pretty v1.0.1 h1:WE4RBSZ1x6McVVC8S/Md+Qse8YUv6HRObAx6ke00NY8= github.com/tidwall/pretty v1.0.1/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/tidwall/sjson v1.1.1 h1:7h1vk049Jnd5EH9NyzNiEuwYW4b5qgreBbqRC19AS3U= github.com/tidwall/sjson v1.1.1/go.mod h1:yvVuSnpEQv5cYIrO+AT6kw4QVfd5SDZoGIS7/5+fZFs= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= From 4da8e38b4a1c5d5439b54c0d3ab8c762c2ad3093 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 12 Oct 2020 18:21:39 +0300 Subject: [PATCH 13/17] Add nil check to HTTPError.Is --- error.go | 2 +- version.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/error.go b/error.go index 42a7a071..4d8c9db7 100644 --- a/error.go +++ b/error.go @@ -69,7 +69,7 @@ type HTTPError struct { } func (e HTTPError) Is(err error) bool { - return errors.Is(e.RespError, err) || errors.Is(e.WrappedError, err) + return (e.RespError != nil && errors.Is(e.RespError, err)) || (e.WrappedError != nil && errors.Is(e.WrappedError, err)) } func (e HTTPError) IsStatus(code int) bool { diff --git a/version.go b/version.go index 991817e1..5a1b7071 100644 --- a/version.go +++ b/version.go @@ -1,3 +1,3 @@ package mautrix -const Version = "v0.7.11" +const Version = "v0.7.12" From 03c15a42973bc0d349933acd87ca33d769589b24 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 16 Oct 2020 16:51:13 +0300 Subject: [PATCH 14/17] Add appservice auth type constant --- requests.go | 2 ++ responses.go | 11 ++++++++++- version.go | 2 +- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/requests.go b/requests.go index e36e2fe5..2c86dc03 100644 --- a/requests.go +++ b/requests.go @@ -19,6 +19,8 @@ const ( AuthTypeMSISDN = "m.login.msisdn" AuthTypeToken = "m.login.token" AuthTypeDummy = "m.login.dummy" + + AuthTypeAppservice = "uk.half-shot.msc2778.login.application_service" ) type IdentifierType string diff --git a/responses.go b/responses.go index dd9b245b..a7b8108e 100644 --- a/responses.go +++ b/responses.go @@ -119,10 +119,19 @@ type RespRegister struct { type RespLoginFlows struct { Flows []struct { - Type string `json:"type"` + Type AuthType `json:"type"` } `json:"flows"` } +func (rlf *RespLoginFlows) HasFlow(flowType AuthType) bool { + for _, flow := range rlf.Flows { + if flow.Type == flowType { + return true + } + } + return false +} + // RespLogin is the JSON response for http://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-client-r0-login type RespLogin struct { AccessToken string `json:"access_token"` diff --git a/version.go b/version.go index 5a1b7071..5859cde9 100644 --- a/version.go +++ b/version.go @@ -1,3 +1,3 @@ package mautrix -const Version = "v0.7.12" +const Version = "v0.7.13" From 1e482c2e20733f9f1b4ce7c6089e7826b0cf5400 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 26 Oct 2020 14:45:57 +0200 Subject: [PATCH 15/17] Support authorization header auth in appservices (MSC2832) --- appservice/http.go | 46 ++++++++++++++++++++++++++---------------- appservice/protocol.go | 8 ++++---- 2 files changed, 33 insertions(+), 21 deletions(-) diff --git a/appservice/http.go b/appservice/http.go index bcdbbc1e..dec75597 100644 --- a/appservice/http.go +++ b/appservice/http.go @@ -11,6 +11,7 @@ import ( "encoding/json" "io/ioutil" "net/http" + "strings" "time" "github.com/gorilla/mux" @@ -51,27 +52,38 @@ func (as *AppService) Stop() { return } - ctx, _ := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() _ = as.server.Shutdown(ctx) as.server = nil } // CheckServerToken checks if the given request originated from the Matrix homeserver. -func (as *AppService) CheckServerToken(w http.ResponseWriter, r *http.Request) bool { - query := r.URL.Query() - val, ok := query["access_token"] - if !ok { +func (as *AppService) CheckServerToken(w http.ResponseWriter, r *http.Request) (isValid bool) { + authHeader := r.Header.Get("Authorization") + if len(authHeader) > 0 && strings.HasPrefix(authHeader, "Bearer ") { + isValid = authHeader[len("Bearer "):] == as.Registration.ServerToken + } else { + queryToken := r.URL.Query().Get("access_token") + if len(queryToken) > 0 { + isValid = queryToken == as.Registration.ServerToken + } else { + Error{ + ErrorCode: ErrUnknownToken, + HTTPStatus: http.StatusForbidden, + Message: "Missing access token", + }.Write(w) + return + } + } + if !isValid { Error{ - ErrorCode: ErrForbidden, + ErrorCode: ErrUnknownToken, HTTPStatus: http.StatusForbidden, - Message: "Bad token supplied.", + Message: "Incorrect access token", }.Write(w) - return false } - for _, str := range val { - return str == as.Registration.ServerToken - } - return false + return } // PutTransaction handles a /transactions PUT call from the homeserver. @@ -86,7 +98,7 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) { Error{ ErrorCode: ErrNoTransactionID, HTTPStatus: http.StatusBadRequest, - Message: "Missing transaction ID.", + Message: "Missing transaction ID", }.Write(w) return } @@ -94,9 +106,9 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) { body, err := ioutil.ReadAll(r.Body) if err != nil || len(body) == 0 { Error{ - ErrorCode: ErrNoBody, + ErrorCode: ErrNotJSON, HTTPStatus: http.StatusBadRequest, - Message: "Missing request body.", + Message: "Missing request body", }.Write(w) return } @@ -111,9 +123,9 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) { if err != nil { as.Log.Warnfln("Failed to parse JSON of transaction %s: %v", txnID, err) Error{ - ErrorCode: ErrInvalidJSON, + ErrorCode: ErrBadJSON, HTTPStatus: http.StatusBadRequest, - Message: "Failed to parse body JSON.", + Message: "Failed to parse body JSON", }.Write(w) } else { for _, evt := range eventList.Events { diff --git a/appservice/protocol.go b/appservice/protocol.go index e6ec1f4c..8646594b 100644 --- a/appservice/protocol.go +++ b/appservice/protocol.go @@ -54,13 +54,13 @@ type ErrorCode string // Native ErrorCodes const ( - ErrForbidden ErrorCode = "M_FORBIDDEN" - ErrUnknown ErrorCode = "M_UNKNOWN" + ErrUnknownToken ErrorCode = "M_UNKNOWN_TOKEN" + ErrBadJSON ErrorCode = "M_BAD_JSON" + ErrNotJSON ErrorCode = "M_NOT_JSON" + ErrUnknown ErrorCode = "M_UNKNOWN" ) // Custom ErrorCodes const ( ErrNoTransactionID ErrorCode = "NET.MAUNIUM.NO_TRANSACTION_ID" - ErrNoBody ErrorCode = "NET.MAUNIUM.NO_REQUEST_BODY" - ErrInvalidJSON ErrorCode = "NET.MAUNIUM.INVALID_JSON" ) From 984d218436794172c7a563ce0dcfbc0ebfb2c243 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 26 Oct 2020 14:55:08 +0200 Subject: [PATCH 16/17] Add content-type headers --- appservice/protocol.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/appservice/protocol.go b/appservice/protocol.go index 8646594b..fa465000 100644 --- a/appservice/protocol.go +++ b/appservice/protocol.go @@ -23,12 +23,14 @@ type EventListener func(evt *event.Event) // WriteBlankOK writes a blank OK message as a reply to a HTTP request. func WriteBlankOK(w http.ResponseWriter) { + w.Header().Add("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("{}")) } // Respond responds to a HTTP request with a JSON object. func Respond(w http.ResponseWriter, data interface{}) error { + w.Header().Add("Content-Type", "application/json") dataStr, err := json.Marshal(data) if err != nil { return err @@ -45,6 +47,7 @@ type Error struct { } func (err Error) Write(w http.ResponseWriter) { + w.Header().Add("Content-Type", "application/json") w.WriteHeader(err.HTTPStatus) _ = Respond(w, &err) } From 488e1811bbb513e9c5404aa85c74ef28f634a983 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 26 Oct 2020 18:31:14 +0200 Subject: [PATCH 17/17] Add support for receiving MSC2409 ephemeral events --- appservice/http.go | 32 ++++++++++++++++++++++++++++++++ appservice/protocol.go | 4 +++- appservice/registration.go | 1 + event/type.go | 6 +++--- 4 files changed, 39 insertions(+), 4 deletions(-) diff --git a/appservice/http.go b/appservice/http.go index dec75597..ad4b2392 100644 --- a/appservice/http.go +++ b/appservice/http.go @@ -128,6 +128,14 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) { Message: "Failed to parse body JSON", }.Write(w) } else { + if as.Registration.EphemeralEvents { + if eventList.EphemeralEvents != nil { + as.handleEvents(eventList.EphemeralEvents, event.EphemeralEventType) + } else if eventList.SoruEphemeralEvents != nil { + as.handleEvents(eventList.SoruEphemeralEvents, event.EphemeralEventType) + } + } + as.handleEvents(eventList.Events, event.UnknownEventType) for _, evt := range eventList.Events { if evt.StateKey != nil { evt.Type.Class = event.StateEventType @@ -146,6 +154,30 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) { as.lastProcessedTransaction = txnID } +func (as *AppService) handleEvents(evts []*event.Event, typeClass event.TypeClass) { + for _, evt := range evts { + if typeClass != event.UnknownEventType { + evt.Type.Class = typeClass + } else if evt.StateKey != nil { + evt.Type.Class = event.StateEventType + } else { + evt.Type.Class = event.MessageEventType + } + err := evt.Content.ParseRaw(evt.Type) + if err != nil { + if evt.ID != "" { + as.Log.Debugfln("Failed to parse content of %s (%s): %v", evt.ID, evt.Type.Type, err) + } else { + as.Log.Debugfln("Failed to parse content of a %s: %v", evt.Type.Type, err) + } + } + if evt.Type.IsState() { + as.UpdateState(evt) + } + as.Events <- evt + } +} + // GetRoom handles a /rooms GET call from the homeserver. func (as *AppService) GetRoom(w http.ResponseWriter, r *http.Request) { if !as.CheckServerToken(w, r) { diff --git a/appservice/protocol.go b/appservice/protocol.go index fa465000..b6cc13e6 100644 --- a/appservice/protocol.go +++ b/appservice/protocol.go @@ -15,7 +15,9 @@ import ( // EventList contains a list of events. type EventList struct { - Events []*event.Event `json:"events"` + Events []*event.Event `json:"events"` + EphemeralEvents []*event.Event `json:"ephemeral"` + SoruEphemeralEvents []*event.Event `json:"de.sorunome.msc2409.ephemeral"` } // EventListener is a function that receives events. diff --git a/appservice/registration.go b/appservice/registration.go index c7a9f1ba..729c951c 100644 --- a/appservice/registration.go +++ b/appservice/registration.go @@ -23,6 +23,7 @@ type Registration struct { SenderLocalpart string `yaml:"sender_localpart"` RateLimited bool `yaml:"rate_limited"` Namespaces Namespaces `yaml:"namespaces"` + EphemeralEvents bool `yaml:"de.sorunome.msc2409.push_ephemeral,omitempty"` } // CreateRegistration creates a Registration with random appservice and homeserver tokens. diff --git a/event/type.go b/event/type.go index 3f608056..778f46c1 100644 --- a/event/type.go +++ b/event/type.go @@ -32,8 +32,10 @@ func (tc TypeClass) Name() string { } const ( + // Unknown events + UnknownEventType TypeClass = iota // Normal message events - MessageEventType TypeClass = iota + MessageEventType // State events StateEventType // Ephemeral events @@ -42,8 +44,6 @@ const ( AccountDataEventType // Device-to-device events ToDeviceEventType - // Unknown events - UnknownEventType ) type Type struct {