Also return in-order list of "join" events.

This commit is contained in:
Joachim Bauch 2022-04-11 11:05:54 +02:00
parent ef8d5ff628
commit 6fa4a8b434
No known key found for this signature in database
GPG Key ID: 77C1D22D53E15F02
2 changed files with 17 additions and 11 deletions

View File

@ -596,30 +596,36 @@ func (c *TestClient) checkMessageJoinedSession(message *ServerMessage, sessionId
return nil return nil
} }
func (c *TestClient) RunUntilJoinedAndReturnIgnored(ctx context.Context, hello ...*HelloServerMessage) ([]*ServerMessage, error) { func (c *TestClient) RunUntilJoinedAndReturn(ctx context.Context, hello ...*HelloServerMessage) ([]*EventServerMessageSessionEntry, []*ServerMessage, error) {
received := make([]*EventServerMessageSessionEntry, len(hello))
var ignored []*ServerMessage var ignored []*ServerMessage
for len(hello) > 0 { hellos := make(map[*HelloServerMessage]int, len(hello))
for idx, h := range hello {
hellos[h] = idx
}
for len(hellos) > 0 {
message, err := c.RunUntilMessage(ctx) message, err := c.RunUntilMessage(ctx)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
if err := checkMessageType(message, "event"); err != nil { if err := checkMessageType(message, "event"); err != nil {
ignored = append(ignored, message) ignored = append(ignored, message)
continue continue
} else if message.Event.Target != "room" { } else if message.Event.Target != "room" {
return nil, fmt.Errorf("Expected event target room, got %+v", message.Event) return nil, nil, fmt.Errorf("Expected event target room, got %+v", message.Event)
} else if message.Event.Type != "join" { } else if message.Event.Type != "join" {
return nil, fmt.Errorf("Expected event type join, got %+v", message.Event) return nil, nil, fmt.Errorf("Expected event type join, got %+v", message.Event)
} }
for len(message.Event.Join) > 0 { for len(message.Event.Join) > 0 {
found := false found := false
loop: loop:
for idx, h := range hello { for h, idx := range hellos {
for idx2, evt := range message.Event.Join { for idx2, evt := range message.Event.Join {
if evt.SessionId == h.SessionId && evt.UserId == h.UserId { if evt.SessionId == h.SessionId && evt.UserId == h.UserId {
hello = append(hello[:idx], hello[idx+1:]...) received[idx] = evt
delete(hellos, h)
message.Event.Join = append(message.Event.Join[:idx2], message.Event.Join[idx2+1:]...) message.Event.Join = append(message.Event.Join[:idx2], message.Event.Join[idx2+1:]...)
found = true found = true
break loop break loop
@ -627,15 +633,15 @@ func (c *TestClient) RunUntilJoinedAndReturnIgnored(ctx context.Context, hello .
} }
} }
if !found { if !found {
return nil, fmt.Errorf("expected one of the passed hello sessions, got %+v", message.Event.Join[0]) return nil, nil, fmt.Errorf("expected one of the passed hello sessions, got %+v", message.Event.Join[0])
} }
} }
} }
return ignored, nil return received, ignored, nil
} }
func (c *TestClient) RunUntilJoined(ctx context.Context, hello ...*HelloServerMessage) error { func (c *TestClient) RunUntilJoined(ctx context.Context, hello ...*HelloServerMessage) error {
unexpected, err := c.RunUntilJoinedAndReturnIgnored(ctx, hello...) _, unexpected, err := c.RunUntilJoinedAndReturn(ctx, hello...)
if err != nil { if err != nil {
return err return err
} }

View File

@ -269,7 +269,7 @@ func Test_TransientMessages(t *testing.T) {
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
} }
ignored, err := client3.RunUntilJoinedAndReturnIgnored(ctx, hello1.Hello, hello2.Hello, hello3.Hello) _, ignored, err := client3.RunUntilJoinedAndReturn(ctx, hello1.Hello, hello2.Hello, hello3.Hello)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }