From 63b658574a748190743df41825bf574ee059afe4 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Mon, 29 Sep 2025 10:56:03 +0200 Subject: [PATCH] Rewrite chat actor information for federation. --- api/stringmap.go | 9 ++++ api/stringmap_test.go | 25 +++++++++ clientsession_test.go | 60 ++++++++++++++++++--- federation.go | 121 ++++++++++++++++++++++++++++++++++-------- testutils_test.go | 18 +++++++ 5 files changed, 204 insertions(+), 29 deletions(-) diff --git a/api/stringmap.go b/api/stringmap.go index d4f7a73..749a3dc 100644 --- a/api/stringmap.go +++ b/api/stringmap.go @@ -24,6 +24,15 @@ package api // StringMap maps string keys to arbitrary values. type StringMap map[string]any +func (m StringMap) GetStringMap(key string) (StringMap, bool) { + v, found := m[key] + if !found { + return nil, false + } + + return ConvertStringMap(v) +} + func ConvertStringMap(ob any) (StringMap, bool) { if ob == nil { return nil, true diff --git a/api/stringmap_test.go b/api/stringmap_test.go index 1d45a20..941c7f9 100644 --- a/api/stringmap_test.go +++ b/api/stringmap_test.go @@ -88,3 +88,28 @@ func TestGetStringMapString(t *testing.T) { _, ok = GetStringMapString[StringMapTestString](m, "invalid") assert.False(ok) } + +func TestGetStringMapStringMap(t *testing.T) { + assert := assert.New(t) + + m := StringMap{ + "foo": map[string]any{ + "bar": 1, + }, + "bar": StringMap{ + "baz": 2, + }, + } + if v, ok := m.GetStringMap("foo"); assert.True(ok) { + assert.EqualValues(map[string]any{ + "bar": 1, + }, v) + } + if v, ok := m.GetStringMap("bar"); assert.True(ok) { + assert.EqualValues(map[string]any{ + "baz": 2, + }, v) + } + v, ok := m.GetStringMap("baz") + assert.False(ok, "expected missing entry, got %+v", v) +} diff --git a/clientsession_test.go b/clientsession_test.go index 7dc6205..12ef1f2 100644 --- a/clientsession_test.go +++ b/clientsession_test.go @@ -339,11 +339,57 @@ func TestFeatureChatRelayFederation(t *testing.T) { room := hub1.getRoom(roomId) require.NotNil(room) - chatComment := api.StringMap{ - "foo": "bar", - "baz": true, - "lala": map[string]any{ - "one": "eins", + chatComment := map[string]any{ + "actorId": hello1.Hello.UserId, + "actorType": "users", + "lastEditActorId": hello1.Hello.UserId, + "lastEditActorType": "users", + "parent": map[string]any{ + "actorId": hello1.Hello.UserId, + "actorType": "users", + "lastEditActorId": hello1.Hello.UserId, + "lastEditActorType": "users", + }, + "messageParameters": map[string]map[string]any{ + "mention-local-user": { + "type": "user", + "id": hello1.Hello.UserId, + "name": "User 1", + }, + "mention-remote-user": { + "type": "user", + "id": hello2.Hello.UserId, + "name": "User 2", + "mention-id": "federated_user/" + hello2.Hello.UserId + "@" + getCloudUrl(server2.URL), + "server": server2.URL, + }, + }, + } + federatedChatComment := map[string]any{ + "actorId": hello1.Hello.UserId + "@" + getCloudUrl(server1.URL), + "actorType": "federated_users", + "lastEditActorId": hello1.Hello.UserId + "@" + getCloudUrl(server1.URL), + "lastEditActorType": "federated_users", + "parent": map[string]any{ + "actorId": hello1.Hello.UserId + "@" + getCloudUrl(server1.URL), + "actorType": "federated_users", + "lastEditActorId": hello1.Hello.UserId + "@" + getCloudUrl(server1.URL), + "lastEditActorType": "federated_users", + }, + "messageParameters": map[string]map[string]any{ + "mention-local-user": { + "type": "user", + "id": hello1.Hello.UserId, + "mention-id": hello1.Hello.UserId, + "name": "User 1", + "server": server1.URL, + }, + "mention-remote-user": { + "type": "user", + "id": hello2.Hello.UserId, + "name": "User 2", + "mention-id": "federated_user/" + hello2.Hello.UserId + "@" + getCloudUrl(server2.URL), + }, }, } message := api.StringMap{ @@ -374,7 +420,7 @@ func TestFeatureChatRelayFederation(t *testing.T) { if err := json.Unmarshal(msg.Data, &data); assert.NoError(err) { assert.Equal("chat", data["type"], "invalid type entry in %+v", data) if chat, found := api.GetStringMapEntry[map[string]any](data, "chat"); assert.True(found, "chat entry is missing in %+v", data) { - assert.EqualValues(chatComment, chat["comment"]) + AssertEqualSerialized(t, chatComment, chat["comment"]) _, found := chat["refresh"] assert.False(found, "refresh should not be included") } @@ -389,7 +435,7 @@ func TestFeatureChatRelayFederation(t *testing.T) { assert.Equal("chat", data["type"], "invalid type entry in %+v", data) if chat, found := api.GetStringMapEntry[map[string]any](data, "chat"); assert.True(found, "chat entry is missing in %+v", data) { if feature { - assert.EqualValues(chatComment, chat["comment"]) + AssertEqualSerialized(t, federatedChatComment, chat["comment"]) _, found := chat["refresh"] assert.False(found, "refresh should not be included") } else { diff --git a/federation.go b/federation.go index 1f9747d..2d5cb4a 100644 --- a/federation.go +++ b/federation.go @@ -59,11 +59,7 @@ func isClosedError(err error) bool { strings.Contains(err.Error(), net.ErrClosed.Error()) } -func getCloudUrl(s string) string { - var found bool - if s, found = strings.CutPrefix(s, "https://"); !found { - s = strings.TrimPrefix(s, "http://") - } +func getCloudUrlWithoutPath(s string) string { if pos := strings.Index(s, "/ocs/v"); pos != -1 { s = s[:pos] } else { @@ -72,6 +68,14 @@ func getCloudUrl(s string) string { return s } +func getCloudUrl(s string) string { + var found bool + if s, found = strings.CutPrefix(s, "https://"); !found { + s = strings.TrimPrefix(s, "http://") + } + return getCloudUrlWithoutPath(s) +} + type FederationClient struct { hub *Hub session *ClientSession @@ -604,26 +608,73 @@ func (c *FederationClient) joinRoom() error { }) } -func (c *FederationClient) updateEventUsers(users []api.StringMap, localSessionId PublicSessionId, remoteSessionId PublicSessionId) { - localCloudUrl := "@" + getCloudUrl(c.session.BackendUrl()) - localCloudUrlLen := len(localCloudUrl) - remoteCloudUrl := "@" + getCloudUrl(c.federation.Load().NextcloudUrl) - checkSessionId := true - for _, u := range users { - if actorType, found := api.GetStringMapEntry[string](u, "actorType"); found { - if actorId, found := api.GetStringMapEntry[string](u, "actorId"); found { - switch actorType { - case ActorTypeFederatedUsers: - if strings.HasSuffix(actorId, localCloudUrl) { - u["actorId"] = actorId[:len(actorId)-localCloudUrlLen] - u["actorType"] = ActorTypeUsers - } - case ActorTypeUsers: - u["actorId"] = actorId + remoteCloudUrl - u["actorType"] = ActorTypeFederatedUsers +func (c *FederationClient) updateActor(u api.StringMap, actorIdKey, actorTypeKey string, localCloudUrl string, remoteCloudUrl string) (changed bool) { + if actorType, found := api.GetStringMapEntry[string](u, actorTypeKey); found { + if actorId, found := api.GetStringMapEntry[string](u, actorIdKey); found { + switch actorType { + case ActorTypeFederatedUsers: + if strings.HasSuffix(actorId, localCloudUrl) { + u[actorIdKey] = actorId[:len(actorId)-len(localCloudUrl)] + u[actorTypeKey] = ActorTypeUsers + changed = true + } + case ActorTypeUsers: + u[actorIdKey] = actorId + remoteCloudUrl + u[actorTypeKey] = ActorTypeFederatedUsers + changed = true + } + } + } + return +} + +func (c *FederationClient) updateComment(comment api.StringMap, localCloudUrl string, remoteCloudUrl string) bool { + changed := c.updateActor(comment, "actorId", "actorType", localCloudUrl, remoteCloudUrl) + if c.updateActor(comment, "lastEditActorId", "lastEditActorType", localCloudUrl, remoteCloudUrl) { + changed = true + } + + if params, found := api.GetStringMapEntry[map[string]any](comment, "messageParameters"); found { + localUrl := getCloudUrlWithoutPath(c.session.BackendUrl()) + remoteUrl := getCloudUrlWithoutPath(c.federation.Load().NextcloudUrl) + for key, paramOb := range params { + if !strings.HasPrefix(key, "mention-") { + // Only need to process mentions. + continue + } + + param, ok := api.ConvertStringMap(paramOb) + if !ok { + continue + } + + if ptype, found := api.GetStringMapString[string](param, "type"); found && ptype == "user" { + if server, found := api.GetStringMapString[string](param, "server"); found && server == localUrl { + delete(param, "server") + params[key] = param + changed = true + continue + } + + if _, found := api.GetStringMapString[string](param, "mention-id"); !found { + param["mention-id"] = param["id"] + param["server"] = remoteUrl + params[key] = param + changed = true + continue } } } + } + return changed +} + +func (c *FederationClient) updateEventUsers(users []api.StringMap, localSessionId PublicSessionId, remoteSessionId PublicSessionId) { + localCloudUrl := "@" + getCloudUrl(c.session.BackendUrl()) + remoteCloudUrl := "@" + getCloudUrl(c.federation.Load().NextcloudUrl) + checkSessionId := true + for _, u := range users { + c.updateActor(u, "actorId", "actorType", localCloudUrl, remoteCloudUrl) if checkSessionId { key := "sessionId" @@ -732,6 +783,32 @@ func (c *FederationClient) processMessage(msg *ServerMessage) { if c.changeRoomId.Load() && msg.Event.Message.RoomId == remoteRoomId { msg.Event.Message.RoomId = roomId } + if msg.Event.Type == "message" && msg.Event.Message != nil { + if data, err := msg.Event.Message.GetData(); err == nil { + if data.Type == "chat" && data.Chat != nil && len(data.Chat.Comment) > 0 { + var comment api.StringMap + if err := json.Unmarshal(data.Chat.Comment, &comment); err == nil { + localCloudUrl := "@" + getCloudUrl(c.session.BackendUrl()) + remoteCloudUrl := "@" + getCloudUrl(c.federation.Load().NextcloudUrl) + changed := c.updateComment(comment, localCloudUrl, remoteCloudUrl) + if parent, found := comment.GetStringMap("parent"); found { + if c.updateComment(parent, localCloudUrl, remoteCloudUrl) { + comment["parent"] = parent + changed = true + } + } + if changed { + if encoded, err := json.Marshal(comment); err == nil { + data.Chat.Comment = encoded + if encoded, err = json.Marshal(data); err == nil { + msg.Event.Message.Data = encoded + } + } + } + } + } + } + } } case "roomlist": switch msg.Event.Type { diff --git a/testutils_test.go b/testutils_test.go index dbf91e3..60932af 100644 --- a/testutils_test.go +++ b/testutils_test.go @@ -24,6 +24,7 @@ package signaling import ( "bytes" "context" + "encoding/json" "io" "os" "os/signal" @@ -32,6 +33,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -132,3 +134,19 @@ func MustSucceed3[T any, A1 any, A2 any, A3 any](t *testing.T, f func(a1 A1, a2 } return result } + +func AssertEqualSerialized(t *testing.T, expected any, actual any, msgAndArgs ...any) bool { + t.Helper() + + e, err := json.MarshalIndent(expected, "", " ") + if !assert.NoError(t, err) { + return false + } + + a, err := json.MarshalIndent(actual, "", " ") + if !assert.NoError(t, err) { + return false + } + + return assert.Equal(t, string(a), string(e), msgAndArgs...) +}