From 6fa4a8b4344437b0f9edcb36bb509254ea4d1158 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Mon, 11 Apr 2022 11:05:54 +0200 Subject: [PATCH] Also return in-order list of "join" events. --- testclient_test.go | 26 ++++++++++++++++---------- transient_data_test.go | 2 +- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/testclient_test.go b/testclient_test.go index cd54ce5..264d818 100644 --- a/testclient_test.go +++ b/testclient_test.go @@ -596,30 +596,36 @@ func (c *TestClient) checkMessageJoinedSession(message *ServerMessage, sessionId return nil } -func (c *TestClient) RunUntilJoinedAndReturnIgnored(ctx context.Context, hello ...*HelloServerMessage) ([]*ServerMessage, error) { +func (c *TestClient) RunUntilJoinedAndReturn(ctx context.Context, hello ...*HelloServerMessage) ([]*EventServerMessageSessionEntry, []*ServerMessage, error) { + received := make([]*EventServerMessageSessionEntry, len(hello)) var ignored []*ServerMessage - for len(hello) > 0 { + hellos := make(map[*HelloServerMessage]int, len(hello)) + for idx, h := range hello { + hellos[h] = idx + } + for len(hellos) > 0 { message, err := c.RunUntilMessage(ctx) if err != nil { - return nil, err + return nil, nil, err } if err := checkMessageType(message, "event"); err != nil { ignored = append(ignored, message) continue } else if message.Event.Target != "room" { - return nil, fmt.Errorf("Expected event target room, got %+v", message.Event) + return nil, nil, fmt.Errorf("Expected event target room, got %+v", message.Event) } else if message.Event.Type != "join" { - return nil, fmt.Errorf("Expected event type join, got %+v", message.Event) + return nil, nil, fmt.Errorf("Expected event type join, got %+v", message.Event) } for len(message.Event.Join) > 0 { found := false loop: - for idx, h := range hello { + for h, idx := range hellos { for idx2, evt := range message.Event.Join { if evt.SessionId == h.SessionId && evt.UserId == h.UserId { - hello = append(hello[:idx], hello[idx+1:]...) + received[idx] = evt + delete(hellos, h) message.Event.Join = append(message.Event.Join[:idx2], message.Event.Join[idx2+1:]...) found = true break loop @@ -627,15 +633,15 @@ func (c *TestClient) RunUntilJoinedAndReturnIgnored(ctx context.Context, hello . } } if !found { - return nil, fmt.Errorf("expected one of the passed hello sessions, got %+v", message.Event.Join[0]) + return nil, nil, fmt.Errorf("expected one of the passed hello sessions, got %+v", message.Event.Join[0]) } } } - return ignored, nil + return received, ignored, nil } func (c *TestClient) RunUntilJoined(ctx context.Context, hello ...*HelloServerMessage) error { - unexpected, err := c.RunUntilJoinedAndReturnIgnored(ctx, hello...) + _, unexpected, err := c.RunUntilJoinedAndReturn(ctx, hello...) if err != nil { return err } diff --git a/transient_data_test.go b/transient_data_test.go index fb8e89f..598fac2 100644 --- a/transient_data_test.go +++ b/transient_data_test.go @@ -269,7 +269,7 @@ func Test_TransientMessages(t *testing.T) { t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) } - ignored, err := client3.RunUntilJoinedAndReturnIgnored(ctx, hello1.Hello, hello2.Hello, hello3.Hello) + _, ignored, err := client3.RunUntilJoinedAndReturn(ctx, hello1.Hello, hello2.Hello, hello3.Hello) if err != nil { t.Fatal(err) }