diff --git a/testclient_test.go b/testclient_test.go index 9e9a3b9..92bcf10 100644 --- a/testclient_test.go +++ b/testclient_test.go @@ -566,19 +566,30 @@ func (c *TestClient) RunUntilJoined(ctx context.Context, hello ...*HelloServerMe return err } - if err := c.checkSingleMessageJoined(message); err != nil { + if err := checkMessageType(message, "event"); err != nil { return err + } else if message.Event.Target != "room" { + return fmt.Errorf("Expected event target room, got %+v", message.Event) + } else if message.Event.Type != "join" { + return fmt.Errorf("Expected event type join, got %+v", message.Event) } - found := false - for idx, h := range hello { - if err := c.checkMessageJoined(message, h); err == nil { - hello = append(hello[:idx], hello[idx+1:]...) - found = true - break + + for len(message.Event.Join) > 0 { + found := false + loop: + for idx, h := range hello { + for idx2, evt := range message.Event.Join { + if evt.SessionId == h.SessionId && evt.UserId == h.UserId { + hello = append(hello[:idx], hello[idx+1:]...) + message.Event.Join = append(message.Event.Join[:idx2], message.Event.Join[idx2+1:]...) + found = true + break loop + } + } + } + if !found { + return fmt.Errorf("expected one of the passed hello sessions, got %+v", message.Event.Join[0]) } - } - if !found { - return fmt.Errorf("expected one of the passed hello sessions, got %+v", message.Event.Join[0]) } } return nil