diff --git a/testclient_test.go b/testclient_test.go index 92bcf10..dbb6b1c 100644 --- a/testclient_test.go +++ b/testclient_test.go @@ -559,19 +559,21 @@ func (c *TestClient) checkMessageJoinedSession(message *ServerMessage, sessionId return nil } -func (c *TestClient) RunUntilJoined(ctx context.Context, hello ...*HelloServerMessage) error { +func (c *TestClient) RunUntilJoinedAndReturnIgnored(ctx context.Context, hello ...*HelloServerMessage) ([]*ServerMessage, error) { + var ignored []*ServerMessage for len(hello) > 0 { message, err := c.RunUntilMessage(ctx) if err != nil { - return err + return nil, err } if err := checkMessageType(message, "event"); err != nil { - return err + ignored = append(ignored, message) + continue } else if message.Event.Target != "room" { - return fmt.Errorf("Expected event target room, got %+v", message.Event) + return nil, fmt.Errorf("Expected event target room, got %+v", message.Event) } else if message.Event.Type != "join" { - return fmt.Errorf("Expected event type join, got %+v", message.Event) + return nil, fmt.Errorf("Expected event type join, got %+v", message.Event) } for len(message.Event.Join) > 0 { @@ -588,10 +590,21 @@ func (c *TestClient) RunUntilJoined(ctx context.Context, hello ...*HelloServerMe } } if !found { - return fmt.Errorf("expected one of the passed hello sessions, got %+v", message.Event.Join[0]) + return nil, fmt.Errorf("expected one of the passed hello sessions, got %+v", message.Event.Join[0]) } } } + return ignored, nil +} + +func (c *TestClient) RunUntilJoined(ctx context.Context, hello ...*HelloServerMessage) error { + unexpected, err := c.RunUntilJoinedAndReturnIgnored(ctx, hello...) + if err != nil { + return err + } + if len(unexpected) > 0 { + return fmt.Errorf("Received unexpected messages: %+v", unexpected) + } return nil }