From 9cf796640c6aed253234851044d7f5e52d62795d Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Tue, 27 Apr 2021 08:39:19 +0200 Subject: [PATCH] Don't require certain order of own/other "joined" events. --- hub_test.go | 11 ++--------- testclient_test.go | 29 +++++++++++++++++++++++------ 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/hub_test.go b/hub_test.go index 9de528e..ca3afa0 100644 --- a/hub_test.go +++ b/hub_test.go @@ -1700,10 +1700,7 @@ func TestJoinMultiple(t *testing.T) { } // We will receive a "joined" event for the first and the second client. - if err := client2.RunUntilJoined(ctx, hello1.Hello); err != nil { - t.Error(err) - } - if err := client2.RunUntilJoined(ctx, hello2.Hello); err != nil { + if err := client2.RunUntilJoined(ctx, hello1.Hello, hello2.Hello); err != nil { t.Error(err) } // The first client will also receive a "joined" event from the second client. @@ -2089,11 +2086,7 @@ func TestClientTakeoverRoomSession(t *testing.T) { // The new client will receive "joined" events for the existing client3 and // himself. - if err := client2.RunUntilJoined(ctx, hello3.Hello); err != nil { - t.Error(err) - } - - if err := client2.RunUntilJoined(ctx, hello2.Hello); err != nil { + if err := client2.RunUntilJoined(ctx, hello3.Hello, hello2.Hello); err != nil { t.Error(err) } diff --git a/testclient_test.go b/testclient_test.go index 93a06b7..f67890c 100644 --- a/testclient_test.go +++ b/testclient_test.go @@ -498,7 +498,7 @@ func (c *TestClient) checkMessageJoined(message *ServerMessage, hello *HelloServ return c.checkMessageJoinedSession(message, hello.SessionId, hello.UserId) } -func (c *TestClient) checkMessageJoinedSession(message *ServerMessage, sessionId string, userId string) error { +func (c *TestClient) checkSingleMessageJoined(message *ServerMessage) error { if err := checkMessageType(message, "event"); err != nil { return err } else if message.Event.Target != "room" { @@ -507,6 +507,13 @@ func (c *TestClient) checkMessageJoinedSession(message *ServerMessage, sessionId return fmt.Errorf("Expected event type join, got %+v", message.Event) } else if len(message.Event.Join) != 1 { return fmt.Errorf("Expected one join event entry, got %+v", message.Event) + } + return nil +} + +func (c *TestClient) checkMessageJoinedSession(message *ServerMessage, sessionId string, userId string) error { + if err := c.checkSingleMessageJoined(message); err != nil { + return err } else { evt := message.Event.Join[0] if sessionId != "" && evt.SessionId != sessionId { @@ -520,12 +527,22 @@ func (c *TestClient) checkMessageJoinedSession(message *ServerMessage, sessionId return nil } -func (c *TestClient) RunUntilJoined(ctx context.Context, hello *HelloServerMessage) error { - if message, err := c.RunUntilMessage(ctx); err != nil { - return err - } else { - return c.checkMessageJoined(message, hello) +func (c *TestClient) RunUntilJoined(ctx context.Context, hello ...*HelloServerMessage) error { + for len(hello) > 0 { + if message, err := c.RunUntilMessage(ctx); err != nil { + return err + } else { + if err := c.checkSingleMessageJoined(message); err != nil { + return err + } + for idx, h := range hello { + if err := c.checkMessageJoined(message, h); err == nil { + hello = append(hello[:idx], hello[idx+1:]...) + } + } + } } + return nil } func (c *TestClient) checkMessageRoomLeave(message *ServerMessage, hello *HelloServerMessage) error {