diff --git a/federation.go b/federation.go index 3aab6e5..9b609c8 100644 --- a/federation.go +++ b/federation.go @@ -454,7 +454,8 @@ func (c *FederationClient) sendHelloLocked(auth *FederationAuthParams) error { Id: c.helloMsgId, Type: "hello", Hello: &HelloClientMessage{ - Version: HelloVersionV2, + Version: HelloVersionV2, + Features: c.session.GetFeatures(), }, } if resumeId := c.resumeId; resumeId != "" { diff --git a/federation_test.go b/federation_test.go index ea4a504..c185b5c 100644 --- a/federation_test.go +++ b/federation_test.go @@ -82,11 +82,13 @@ func Test_Federation(t *testing.T) { client1 := NewTestClient(t, server1, hub1) defer client1.CloseWithBye() - require.NoError(client1.SendHelloV2(testDefaultUserId + "1")) + features1 := []string{"one", "two", "three"} + require.NoError(client1.SendHelloV2WithFeatures(testDefaultUserId+"1", features1)) client2 := NewTestClient(t, server2, hub2) defer client2.CloseWithBye() - require.NoError(client2.SendHelloV2(testDefaultUserId + "2")) + features2 := []string{"1", "2", "3"} + require.NoError(client2.SendHelloV2WithFeatures(testDefaultUserId+"2", features2)) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() @@ -148,6 +150,7 @@ func Test_Federation(t *testing.T) { assert.NotEqual(hello2.Hello.SessionId, remoteSessionId) assert.Equal(testDefaultUserId+"2", evt.UserId) assert.True(evt.Federated) + assert.Equal(features2, evt.Features) } // The client2 will see its own session id, not the one from the remote server. @@ -252,6 +255,7 @@ func Test_Federation(t *testing.T) { assert.NotEqual(hello2.Hello.SessionId, remoteSessionId) assert.Equal(testDefaultUserId+"2", evt.UserId) assert.True(evt.Federated) + assert.Equal(features2, evt.Features) } assert.NoError(client2.RunUntilJoined(ctx, hello1.Hello, hello2.Hello)) @@ -424,7 +428,7 @@ func Test_Federation(t *testing.T) { client4 := NewTestClient(t, server2, hub1) defer client4.CloseWithBye() - require.NoError(client4.SendHelloV2(testDefaultUserId + "4")) + require.NoError(client4.SendHelloV2WithFeatures(testDefaultUserId+"4", features2)) hello4, err := client4.RunUntilHello(ctx) require.NoError(err) @@ -468,6 +472,7 @@ func Test_Federation(t *testing.T) { assert.NotEqual(hello4.Hello.SessionId, remoteSessionId) assert.Equal(testDefaultUserId+"4", evt.UserId) assert.True(evt.Federated) + assert.Equal(features2, evt.Features) } assert.NoError(client2.RunUntilJoined(ctx, &HelloServerMessage{ diff --git a/testclient_test.go b/testclient_test.go index 464d986..9078bcc 100644 --- a/testclient_test.go +++ b/testclient_test.go @@ -389,8 +389,12 @@ func (c *TestClient) SendHelloV1(userid string) error { } func (c *TestClient) SendHelloV2(userid string) error { + return c.SendHelloV2WithFeatures(userid, nil) +} + +func (c *TestClient) SendHelloV2WithFeatures(userid string, features []string) error { now := time.Now() - return c.SendHelloV2WithTimes(userid, now, now.Add(time.Minute)) + return c.SendHelloV2WithTimesAndFeatures(userid, now, now.Add(time.Minute), features) } func (c *TestClient) CreateHelloV2TokenWithUserdata(userid string, issuedAt time.Time, expiresAt time.Time, userdata map[string]interface{}) (string, error) { @@ -434,13 +438,17 @@ func (c *TestClient) CreateHelloV2Token(userid string, issuedAt time.Time, expir } func (c *TestClient) SendHelloV2WithTimes(userid string, issuedAt time.Time, expiresAt time.Time) error { + return c.SendHelloV2WithTimesAndFeatures(userid, issuedAt, expiresAt, nil) +} + +func (c *TestClient) SendHelloV2WithTimesAndFeatures(userid string, issuedAt time.Time, expiresAt time.Time, features []string) error { tokenString, err := c.CreateHelloV2Token(userid, issuedAt, expiresAt) require.NoError(c.t, err) params := HelloV2AuthParams{ Token: tokenString, } - return c.SendHelloParams(c.server.URL, HelloVersionV2, "", nil, params) + return c.SendHelloParams(c.server.URL, HelloVersionV2, "", features, params) } func (c *TestClient) SendHelloResume(resumeId string) error {