Merge pull request #1127 from strukturag/initial-transient-data-clustered

Fix initial transient data in clustered setups
This commit is contained in:
Joachim Bauch 2025-11-24 09:41:28 +01:00 committed by GitHub
commit ba1af553e0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 605 additions and 83 deletions

View file

@ -1342,7 +1342,9 @@ Message format (Server -> Client):
### Initial data
When sessions initially join a room, they receive the current state of the
transient data.
transient data. Please note that the initial data can be sent in multiple
events of type `initial` which must be combined to generate the total initial
data.
Message format (Server -> Client):

View file

@ -326,6 +326,40 @@ func (c *GrpcClient) GetSessionCount(ctx context.Context, url string) (uint32, e
return response.GetCount(), nil
}
func (c *GrpcClient) GetTransientData(ctx context.Context, room *Room) (TransientDataEntries, error) {
statsGrpcClientCalls.WithLabelValues("GetTransientData").Inc()
// TODO: Remove debug logging
c.logger.Printf("Get transient data for %s@%s on %s", room.Id(), room.Backend().Id(), c.Target())
response, err := c.impl.GetTransientData(ctx, &GetTransientDataRequest{
RoomId: room.Id(),
BackendUrls: room.Backend().Urls(),
}, grpc.WaitForReady(true))
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
return nil, nil
} else if err != nil {
return nil, err
}
entries := response.GetEntries()
if len(entries) == 0 {
return nil, nil
}
result := make(TransientDataEntries, len(entries))
for k, v := range entries {
var value any
if err := json.Unmarshal(v.Value, &value); err != nil {
return nil, err
}
if v.Expires > 0 {
result[k] = NewTransientDataEntryWithExpires(value, time.UnixMicro(v.Expires))
} else {
result[k] = NewTransientDataEntry(value, 0)
}
}
return result, nil
}
type ProxySessionReceiver interface {
RemoteAddr() string
Country() string

View file

@ -127,6 +127,154 @@ func (x *GetServerIdReply) GetVersion() string {
return ""
}
type GetTransientDataRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
RoomId string `protobuf:"bytes,1,opt,name=roomId,proto3" json:"roomId,omitempty"`
BackendUrls []string `protobuf:"bytes,2,rep,name=backendUrls,proto3" json:"backendUrls,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *GetTransientDataRequest) Reset() {
*x = GetTransientDataRequest{}
mi := &file_grpc_internal_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *GetTransientDataRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GetTransientDataRequest) ProtoMessage() {}
func (x *GetTransientDataRequest) ProtoReflect() protoreflect.Message {
mi := &file_grpc_internal_proto_msgTypes[2]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GetTransientDataRequest.ProtoReflect.Descriptor instead.
func (*GetTransientDataRequest) Descriptor() ([]byte, []int) {
return file_grpc_internal_proto_rawDescGZIP(), []int{2}
}
func (x *GetTransientDataRequest) GetRoomId() string {
if x != nil {
return x.RoomId
}
return ""
}
func (x *GetTransientDataRequest) GetBackendUrls() []string {
if x != nil {
return x.BackendUrls
}
return nil
}
type GrpcTransientDataEntry struct {
state protoimpl.MessageState `protogen:"open.v1"`
Value []byte `protobuf:"bytes,1,opt,name=value,proto3" json:"value,omitempty"`
Expires int64 `protobuf:"varint,2,opt,name=expires,proto3" json:"expires,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *GrpcTransientDataEntry) Reset() {
*x = GrpcTransientDataEntry{}
mi := &file_grpc_internal_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *GrpcTransientDataEntry) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GrpcTransientDataEntry) ProtoMessage() {}
func (x *GrpcTransientDataEntry) ProtoReflect() protoreflect.Message {
mi := &file_grpc_internal_proto_msgTypes[3]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GrpcTransientDataEntry.ProtoReflect.Descriptor instead.
func (*GrpcTransientDataEntry) Descriptor() ([]byte, []int) {
return file_grpc_internal_proto_rawDescGZIP(), []int{3}
}
func (x *GrpcTransientDataEntry) GetValue() []byte {
if x != nil {
return x.Value
}
return nil
}
func (x *GrpcTransientDataEntry) GetExpires() int64 {
if x != nil {
return x.Expires
}
return 0
}
type GetTransientDataReply struct {
state protoimpl.MessageState `protogen:"open.v1"`
Entries map[string]*GrpcTransientDataEntry `protobuf:"bytes,1,rep,name=entries,proto3" json:"entries,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *GetTransientDataReply) Reset() {
*x = GetTransientDataReply{}
mi := &file_grpc_internal_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *GetTransientDataReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GetTransientDataReply) ProtoMessage() {}
func (x *GetTransientDataReply) ProtoReflect() protoreflect.Message {
mi := &file_grpc_internal_proto_msgTypes[4]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GetTransientDataReply.ProtoReflect.Descriptor instead.
func (*GetTransientDataReply) Descriptor() ([]byte, []int) {
return file_grpc_internal_proto_rawDescGZIP(), []int{4}
}
func (x *GetTransientDataReply) GetEntries() map[string]*GrpcTransientDataEntry {
if x != nil {
return x.Entries
}
return nil
}
var File_grpc_internal_proto protoreflect.FileDescriptor
const file_grpc_internal_proto_rawDesc = "" +
@ -135,9 +283,21 @@ const file_grpc_internal_proto_rawDesc = "" +
"\x12GetServerIdRequest\"H\n" +
"\x10GetServerIdReply\x12\x1a\n" +
"\bserverId\x18\x01 \x01(\tR\bserverId\x12\x18\n" +
"\aversion\x18\x02 \x01(\tR\aversion2Z\n" +
"\aversion\x18\x02 \x01(\tR\aversion\"S\n" +
"\x17GetTransientDataRequest\x12\x16\n" +
"\x06roomId\x18\x01 \x01(\tR\x06roomId\x12 \n" +
"\vbackendUrls\x18\x02 \x03(\tR\vbackendUrls\"H\n" +
"\x16GrpcTransientDataEntry\x12\x14\n" +
"\x05value\x18\x01 \x01(\fR\x05value\x12\x18\n" +
"\aexpires\x18\x02 \x01(\x03R\aexpires\"\xbf\x01\n" +
"\x15GetTransientDataReply\x12G\n" +
"\aentries\x18\x01 \x03(\v2-.signaling.GetTransientDataReply.EntriesEntryR\aentries\x1a]\n" +
"\fEntriesEntry\x12\x10\n" +
"\x03key\x18\x01 \x01(\tR\x03key\x127\n" +
"\x05value\x18\x02 \x01(\v2!.signaling.GrpcTransientDataEntryR\x05value:\x028\x012\xb6\x01\n" +
"\vRpcInternal\x12K\n" +
"\vGetServerId\x12\x1d.signaling.GetServerIdRequest\x1a\x1b.signaling.GetServerIdReply\"\x00B<Z:github.com/strukturag/nextcloud-spreed-signaling;signalingb\x06proto3"
"\vGetServerId\x12\x1d.signaling.GetServerIdRequest\x1a\x1b.signaling.GetServerIdReply\"\x00\x12Z\n" +
"\x10GetTransientData\x12\".signaling.GetTransientDataRequest\x1a .signaling.GetTransientDataReply\"\x00B<Z:github.com/strukturag/nextcloud-spreed-signaling;signalingb\x06proto3"
var (
file_grpc_internal_proto_rawDescOnce sync.Once
@ -151,19 +311,27 @@ func file_grpc_internal_proto_rawDescGZIP() []byte {
return file_grpc_internal_proto_rawDescData
}
var file_grpc_internal_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_grpc_internal_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
var file_grpc_internal_proto_goTypes = []any{
(*GetServerIdRequest)(nil), // 0: signaling.GetServerIdRequest
(*GetServerIdReply)(nil), // 1: signaling.GetServerIdReply
(*GetServerIdRequest)(nil), // 0: signaling.GetServerIdRequest
(*GetServerIdReply)(nil), // 1: signaling.GetServerIdReply
(*GetTransientDataRequest)(nil), // 2: signaling.GetTransientDataRequest
(*GrpcTransientDataEntry)(nil), // 3: signaling.GrpcTransientDataEntry
(*GetTransientDataReply)(nil), // 4: signaling.GetTransientDataReply
nil, // 5: signaling.GetTransientDataReply.EntriesEntry
}
var file_grpc_internal_proto_depIdxs = []int32{
0, // 0: signaling.RpcInternal.GetServerId:input_type -> signaling.GetServerIdRequest
1, // 1: signaling.RpcInternal.GetServerId:output_type -> signaling.GetServerIdReply
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
5, // 0: signaling.GetTransientDataReply.entries:type_name -> signaling.GetTransientDataReply.EntriesEntry
3, // 1: signaling.GetTransientDataReply.EntriesEntry.value:type_name -> signaling.GrpcTransientDataEntry
0, // 2: signaling.RpcInternal.GetServerId:input_type -> signaling.GetServerIdRequest
2, // 3: signaling.RpcInternal.GetTransientData:input_type -> signaling.GetTransientDataRequest
1, // 4: signaling.RpcInternal.GetServerId:output_type -> signaling.GetServerIdReply
4, // 5: signaling.RpcInternal.GetTransientData:output_type -> signaling.GetTransientDataReply
4, // [4:6] is the sub-list for method output_type
2, // [2:4] is the sub-list for method input_type
2, // [2:2] is the sub-list for extension type_name
2, // [2:2] is the sub-list for extension extendee
0, // [0:2] is the sub-list for field type_name
}
func init() { file_grpc_internal_proto_init() }
@ -177,7 +345,7 @@ func file_grpc_internal_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_grpc_internal_proto_rawDesc), len(file_grpc_internal_proto_rawDesc)),
NumEnums: 0,
NumMessages: 2,
NumMessages: 6,
NumExtensions: 0,
NumServices: 1,
},

View file

@ -27,6 +27,7 @@ package signaling;
service RpcInternal {
rpc GetServerId(GetServerIdRequest) returns (GetServerIdReply) {}
rpc GetTransientData(GetTransientDataRequest) returns (GetTransientDataReply) {}
}
message GetServerIdRequest {
@ -36,3 +37,17 @@ message GetServerIdReply {
string serverId = 1;
string version = 2;
}
message GetTransientDataRequest {
string roomId = 1;
repeated string backendUrls = 2;
}
message GrpcTransientDataEntry {
bytes value = 1;
int64 expires = 2;
}
message GetTransientDataReply {
map<string, GrpcTransientDataEntry> entries = 1;
}

View file

@ -37,7 +37,8 @@ import (
const _ = grpc.SupportPackageIsVersion9
const (
RpcInternal_GetServerId_FullMethodName = "/signaling.RpcInternal/GetServerId"
RpcInternal_GetServerId_FullMethodName = "/signaling.RpcInternal/GetServerId"
RpcInternal_GetTransientData_FullMethodName = "/signaling.RpcInternal/GetTransientData"
)
// RpcInternalClient is the client API for RpcInternal service.
@ -45,6 +46,7 @@ const (
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type RpcInternalClient interface {
GetServerId(ctx context.Context, in *GetServerIdRequest, opts ...grpc.CallOption) (*GetServerIdReply, error)
GetTransientData(ctx context.Context, in *GetTransientDataRequest, opts ...grpc.CallOption) (*GetTransientDataReply, error)
}
type rpcInternalClient struct {
@ -65,11 +67,22 @@ func (c *rpcInternalClient) GetServerId(ctx context.Context, in *GetServerIdRequ
return out, nil
}
func (c *rpcInternalClient) GetTransientData(ctx context.Context, in *GetTransientDataRequest, opts ...grpc.CallOption) (*GetTransientDataReply, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(GetTransientDataReply)
err := c.cc.Invoke(ctx, RpcInternal_GetTransientData_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
// RpcInternalServer is the server API for RpcInternal service.
// All implementations must embed UnimplementedRpcInternalServer
// for forward compatibility.
type RpcInternalServer interface {
GetServerId(context.Context, *GetServerIdRequest) (*GetServerIdReply, error)
GetTransientData(context.Context, *GetTransientDataRequest) (*GetTransientDataReply, error)
mustEmbedUnimplementedRpcInternalServer()
}
@ -83,6 +96,9 @@ type UnimplementedRpcInternalServer struct{}
func (UnimplementedRpcInternalServer) GetServerId(context.Context, *GetServerIdRequest) (*GetServerIdReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetServerId not implemented")
}
func (UnimplementedRpcInternalServer) GetTransientData(context.Context, *GetTransientDataRequest) (*GetTransientDataReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetTransientData not implemented")
}
func (UnimplementedRpcInternalServer) mustEmbedUnimplementedRpcInternalServer() {}
func (UnimplementedRpcInternalServer) testEmbeddedByValue() {}
@ -122,6 +138,24 @@ func _RpcInternal_GetServerId_Handler(srv interface{}, ctx context.Context, dec
return interceptor(ctx, in, info, handler)
}
func _RpcInternal_GetTransientData_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(GetTransientDataRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(RpcInternalServer).GetTransientData(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: RpcInternal_GetTransientData_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(RpcInternalServer).GetTransientData(ctx, req.(*GetTransientDataRequest))
}
return interceptor(ctx, in, info, handler)
}
// RpcInternal_ServiceDesc is the grpc.ServiceDesc for RpcInternal service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@ -133,6 +167,10 @@ var RpcInternal_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetServerId",
Handler: _RpcInternal_GetServerId_Handler,
},
{
MethodName: "GetTransientData",
Handler: _RpcInternal_GetTransientData_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "grpc_internal.proto",

View file

@ -25,6 +25,7 @@ import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net"
@ -304,6 +305,67 @@ func (s *GrpcServer) GetServerId(ctx context.Context, request *GetServerIdReques
}, nil
}
func (s *GrpcServer) GetTransientData(ctx context.Context, request *GetTransientDataRequest) (*GetTransientDataReply, error) {
statsGrpcServerCalls.WithLabelValues("GetTransientData").Inc()
backendUrls := request.BackendUrls
if len(backendUrls) == 0 {
// Only compat backend.
backendUrls = []string{""}
}
result := &GetTransientDataReply{}
processed := make(map[string]bool)
for _, bu := range backendUrls {
var parsed *url.URL
if bu != "" {
var err error
parsed, err = url.Parse(bu)
if err != nil {
return nil, status.Error(codes.InvalidArgument, "invalid url")
}
}
backend := s.hub.GetBackend(parsed)
if backend == nil {
return nil, status.Error(codes.NotFound, "no such backend")
}
// Only process each backend once.
if processed[backend.Id()] {
continue
}
processed[backend.Id()] = true
room := s.hub.GetRoomForBackend(request.RoomId, backend)
if room == nil {
return nil, status.Error(codes.NotFound, "no such room")
}
entries := room.transientData.GetEntries()
if len(entries) == 0 {
return nil, status.Error(codes.NotFound, "room has no transient data")
}
if result.Entries == nil {
result.Entries = make(map[string]*GrpcTransientDataEntry)
}
for k, v := range entries {
e := &GrpcTransientDataEntry{}
var err error
if e.Value, err = json.Marshal(v.Value); err != nil {
return nil, status.Errorf(codes.Internal, "error marshalling data: %s", err)
}
if !v.Expires.IsZero() {
e.Expires = v.Expires.UnixMicro()
}
result.Entries[k] = e
}
}
return result, nil
}
func (s *GrpcServer) GetSessionCount(ctx context.Context, request *GetSessionCountRequest) (*GetSessionCountReply, error) {
statsGrpcServerCalls.WithLabelValues("SessionCount").Inc()

39
hub.go
View file

@ -2676,33 +2676,7 @@ func (h *Hub) processTransientMsg(session Session, message *ClientMessage) {
return
}
var err error
if msg.Value == nil {
err = h.events.PublishBackendRoomMessage(room.Id(), room.Backend(), &AsyncMessage{
Type: "room",
Room: &BackendServerRoomRequest{
Type: "transient",
Transient: &BackendRoomTransientRequest{
Action: TransientActionDelete,
Key: msg.Key,
},
},
})
} else {
err = h.events.PublishBackendRoomMessage(room.Id(), room.Backend(), &AsyncMessage{
Type: "room",
Room: &BackendServerRoomRequest{
Type: "transient",
Transient: &BackendRoomTransientRequest{
Action: TransientActionSet,
Key: msg.Key,
Value: msg.Value,
TTL: msg.TTL,
},
},
})
}
if err != nil {
if err := room.SetTransientDataTTL(msg.Key, msg.Value, msg.TTL); err != nil {
response := message.NewWrappedErrorServerMessage(err)
session.SendMessage(response)
return
@ -2713,16 +2687,7 @@ func (h *Hub) processTransientMsg(session Session, message *ClientMessage) {
return
}
if err := h.events.PublishBackendRoomMessage(room.Id(), room.Backend(), &AsyncMessage{
Type: "room",
Room: &BackendServerRoomRequest{
Type: "transient",
Transient: &BackendRoomTransientRequest{
Action: TransientActionDelete,
Key: msg.Key,
},
},
}); err != nil {
if err := room.RemoveTransientData(msg.Key); err != nil {
response := message.NewWrappedErrorServerMessage(err)
session.SendMessage(response)
return

112
room.go
View file

@ -258,9 +258,13 @@ func (r *Room) processBackendRoomRequestRoom(message *BackendServerRoomRequest)
case "transient":
switch message.Transient.Action {
case TransientActionSet:
r.SetTransientDataTTL(message.Transient.Key, message.Transient.Value, message.Transient.TTL)
if message.Transient.TTL == 0 {
r.doSetTransientData(message.Transient.Key, message.Transient.Value)
} else {
r.doSetTransientDataTTL(message.Transient.Key, message.Transient.Value, message.Transient.TTL)
}
case TransientActionDelete:
r.RemoveTransientData(message.Transient.Key)
r.doRemoveTransientData(message.Transient.Key)
default:
r.logger.Printf("Unsupported transient action in room %s: %+v", r.Id(), message.Transient)
}
@ -293,6 +297,7 @@ func (r *Room) AddSession(session Session, sessionData json.RawMessage) {
sid := session.PublicId()
r.mu.Lock()
isFirst := len(r.sessions) == 0
_, found := r.sessions[sid]
r.sessions[sid] = session
if !found {
@ -334,6 +339,9 @@ func (r *Room) AddSession(session Session, sessionData json.RawMessage) {
if clientSession, ok := session.(*ClientSession); ok {
r.transientData.AddListener(clientSession)
}
if isFirst {
r.fetchInitialTransientData()
}
}
// Trigger notifications that the session joined.
@ -1202,14 +1210,108 @@ func (r *Room) notifyInternalRoomDeleted() {
}
}
func (r *Room) SetTransientData(key string, value any) {
func (r *Room) SetTransientData(key string, value any) error {
if value == nil {
return r.RemoveTransientData(key)
}
return r.events.PublishBackendRoomMessage(r.Id(), r.Backend(), &AsyncMessage{
Type: "room",
Room: &BackendServerRoomRequest{
Type: "transient",
Transient: &BackendRoomTransientRequest{
Action: TransientActionSet,
Key: key,
Value: value,
},
},
})
}
func (r *Room) doSetTransientData(key string, value any) {
r.transientData.Set(key, value)
}
func (r *Room) SetTransientDataTTL(key string, value any, ttl time.Duration) {
func (r *Room) SetTransientDataTTL(key string, value any, ttl time.Duration) error {
if value == nil {
return r.RemoveTransientData(key)
} else if ttl == 0 {
return r.SetTransientData(key, value)
}
return r.events.PublishBackendRoomMessage(r.Id(), r.Backend(), &AsyncMessage{
Type: "room",
Room: &BackendServerRoomRequest{
Type: "transient",
Transient: &BackendRoomTransientRequest{
Action: TransientActionSet,
Key: key,
Value: value,
TTL: ttl,
},
},
})
}
func (r *Room) doSetTransientDataTTL(key string, value any, ttl time.Duration) {
r.transientData.SetTTL(key, value, ttl)
}
func (r *Room) RemoveTransientData(key string) {
func (r *Room) RemoveTransientData(key string) error {
return r.events.PublishBackendRoomMessage(r.Id(), r.Backend(), &AsyncMessage{
Type: "room",
Room: &BackendServerRoomRequest{
Type: "transient",
Transient: &BackendRoomTransientRequest{
Action: TransientActionDelete,
Key: key,
},
},
})
}
func (r *Room) doRemoveTransientData(key string) {
r.transientData.Remove(key)
}
func (r *Room) fetchInitialTransientData() {
if r.hub.rpcClients == nil {
return
}
ctx := NewLoggerContext(context.Background(), r.logger)
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
var wg sync.WaitGroup
var mu sync.Mutex
// +checklocks:mu
var initial TransientDataEntries
for _, client := range r.hub.rpcClients.GetClients() {
wg.Add(1)
go func(c *GrpcClient) {
defer wg.Done()
data, err := c.GetTransientData(ctx, r)
if err != nil {
r.logger.Printf("Received error while getting transient data for %s@%s from %s: %s", r.Id(), r.Backend().Id(), c.Target(), err)
return
}
r.logger.Printf("Received initial transient data %+v from %s", data, c.Target())
mu.Lock()
defer mu.Unlock()
if initial == nil {
initial = make(TransientDataEntries)
}
maps.Copy(initial, data)
}(client)
}
wg.Wait()
mu.Lock()
defer mu.Unlock()
if len(initial) > 0 {
r.transientData.SetInitial(initial)
}
}

View file

@ -22,7 +22,8 @@
package signaling
import (
"maps"
"encoding/json"
"fmt"
"reflect"
"sync"
"time"
@ -34,10 +35,58 @@ type TransientListener interface {
SendMessage(message *ServerMessage) bool
}
type TransientDataEntry struct {
Value any `json:"value"`
Expires time.Time `json:"expires,omitzero"`
}
func NewTransientDataEntry(value any, ttl time.Duration) *TransientDataEntry {
entry := &TransientDataEntry{
Value: value,
}
if ttl > 0 {
entry.Expires = time.Now().Add(ttl)
}
return entry
}
func NewTransientDataEntryWithExpires(value any, expires time.Time) *TransientDataEntry {
entry := &TransientDataEntry{
Value: value,
Expires: expires,
}
return entry
}
func (e *TransientDataEntry) clone() *TransientDataEntry {
result := *e
return &result
}
func (e *TransientDataEntry) update(value any, ttl time.Duration) {
e.Value = value
if ttl > 0 {
e.Expires = time.Now().Add(ttl)
} else {
e.Expires = time.Time{}
}
}
type TransientDataEntries map[string]*TransientDataEntry
func (e TransientDataEntries) String() string {
data, err := json.Marshal(e)
if err != nil {
return fmt.Sprintf("Could not serialize %#v: %s", e, err)
}
return string(data)
}
type TransientData struct {
mu sync.Mutex
// +checklocks:mu
data api.StringMap
data TransientDataEntries
// +checklocks:mu
listeners map[TransientListener]bool
// +checklocks:mu
@ -66,8 +115,8 @@ func (t *TransientData) notifySet(key string, prev, value any) {
TransientData: &TransientDataServerMessage{
Type: "set",
Key: key,
OldValue: prev,
Value: value,
OldValue: prev,
},
}
for listener := range t.listeners {
@ -76,15 +125,17 @@ func (t *TransientData) notifySet(key string, prev, value any) {
}
// +checklocks:t.mu
func (t *TransientData) notifyDeleted(key string, prev any) {
func (t *TransientData) notifyDeleted(key string, prev *TransientDataEntry) {
msg := &ServerMessage{
Type: "transient",
TransientData: &TransientDataServerMessage{
Type: "remove",
Key: key,
OldValue: prev,
Type: "remove",
Key: key,
},
}
if prev != nil {
msg.TransientData.OldValue = prev.Value
}
for listener := range t.listeners {
t.sendMessageToListener(listener, msg)
}
@ -100,11 +151,15 @@ func (t *TransientData) AddListener(listener TransientListener) {
}
t.listeners[listener] = true
if len(t.data) > 0 {
data := make(api.StringMap, len(t.data))
for k, v := range t.data {
data[k] = v.Value
}
msg := &ServerMessage{
Type: "transient",
TransientData: &TransientDataServerMessage{
Type: "initial",
Data: t.data,
Data: data,
},
}
t.sendMessageToListener(listener, msg)
@ -157,12 +212,19 @@ func (t *TransientData) removeAfterTTL(key string, value any, ttl time.Duration)
}
// +checklocks:t.mu
func (t *TransientData) doSet(key string, value any, prev any, ttl time.Duration) {
func (t *TransientData) doSet(key string, value any, prev *TransientDataEntry, ttl time.Duration) {
if t.data == nil {
t.data = make(api.StringMap)
t.data = make(TransientDataEntries)
}
t.data[key] = value
t.notifySet(key, prev, value)
var oldValue any
if prev == nil {
entry := NewTransientDataEntry(value, ttl)
t.data[key] = entry
} else {
oldValue = prev.Value
prev.update(value, ttl)
}
t.notifySet(key, oldValue, value)
t.removeAfterTTL(key, value, ttl)
}
@ -183,7 +245,7 @@ func (t *TransientData) SetTTL(key string, value any, ttl time.Duration) bool {
defer t.mu.Unlock()
prev, found := t.data[key]
if found && reflect.DeepEqual(prev, value) {
if found && reflect.DeepEqual(prev.Value, value) {
t.updateTTL(key, value, ttl)
return false
}
@ -210,7 +272,7 @@ func (t *TransientData) CompareAndSetTTL(key string, old, value any, ttl time.Du
defer t.mu.Unlock()
prev, found := t.data[key]
if old != nil && (!found || !reflect.DeepEqual(prev, old)) {
if old != nil && (!found || !reflect.DeepEqual(prev.Value, old)) {
return false
} else if old == nil && found {
return false
@ -221,7 +283,7 @@ func (t *TransientData) CompareAndSetTTL(key string, old, value any, ttl time.Du
}
// +checklocks:t.mu
func (t *TransientData) doRemove(key string, prev any) {
func (t *TransientData) doRemove(key string, prev *TransientDataEntry) {
delete(t.data, key)
if old, found := t.timers[key]; found {
old.Stop()
@ -257,7 +319,7 @@ func (t *TransientData) CompareAndRemove(key string, old any) bool {
// +checklocks:t.mu
func (t *TransientData) compareAndRemove(key string, old any) bool {
prev, found := t.data[key]
if !found || !reflect.DeepEqual(prev, old) {
if !found || !reflect.DeepEqual(prev.Value, old) {
return false
}
@ -270,7 +332,66 @@ func (t *TransientData) GetData() api.StringMap {
t.mu.Lock()
defer t.mu.Unlock()
result := make(api.StringMap)
maps.Copy(result, t.data)
if len(t.data) == 0 {
return nil
}
result := make(api.StringMap, len(t.data))
for k, entry := range t.data {
result[k] = entry.Value
}
return result
}
// GetEntries returns a copy of the internal data entries.
func (t *TransientData) GetEntries() TransientDataEntries {
t.mu.Lock()
defer t.mu.Unlock()
if len(t.data) == 0 {
return nil
}
result := make(TransientDataEntries, len(t.data))
for k, e := range t.data {
result[k] = e.clone()
}
return result
}
// SetInitial sets the initial data and notifies listeners.
func (t *TransientData) SetInitial(data TransientDataEntries) {
if len(data) == 0 {
return
}
t.mu.Lock()
defer t.mu.Unlock()
if t.data == nil {
t.data = make(TransientDataEntries)
}
msgData := make(api.StringMap, len(data))
for k, v := range data {
if _, found := t.data[k]; found {
// Entry already present (i.e. was set by regular event).
continue
}
msgData[k] = v.Value
}
if len(msgData) == 0 {
return
}
msg := &ServerMessage{
Type: "transient",
TransientData: &TransientDataServerMessage{
Type: "initial",
Data: msgData,
},
}
for listener := range t.listeners {
t.sendMessageToListener(listener, msg)
}
}

View file

@ -143,6 +143,7 @@ func Test_TransientMessages(t *testing.T) {
t.Run(subtest, func(t *testing.T) {
t.Parallel()
require := require.New(t)
assert := assert.New(t)
var hub1 *Hub
var hub2 *Hub
var server1 *httptest.Server
@ -245,13 +246,24 @@ func Test_TransientMessages(t *testing.T) {
client1.RunUntilErrorIs(ctx3, context.DeadlineExceeded)
require.NoError(client1.SetTransientData("abc", data, 10*time.Millisecond))
ttl := 200 * time.Millisecond
require.NoError(client1.SetTransientData("abc", data, ttl))
setAt := time.Now()
if msg, ok := client2.RunUntilMessage(ctx); ok {
checkMessageTransientSet(t, msg, "abc", data, nil)
}
client1.CloseWithBye()
require.NoError(client1.WaitForClientRemoved(ctx))
client2.RunUntilLeft(ctx, hello1.Hello)
client3, hello3 := NewTestClientWithHello(ctx, t, server1, hub1, testDefaultUserId+"3")
roomMsg = MustSucceed2(t, client3.JoinRoom, ctx, roomId)
require.Equal(roomId, roomMsg.Room.RoomId)
_, ignored, ok := client3.RunUntilJoinedAndReturn(ctx, hello1.Hello, hello2.Hello, hello3.Hello)
client2.RunUntilJoined(ctx, hello3.Hello)
_, ignored, ok := client3.RunUntilJoinedAndReturn(ctx, hello2.Hello, hello3.Hello)
require.True(ok)
var msg *ServerMessage
@ -263,13 +275,16 @@ func Test_TransientMessages(t *testing.T) {
require.LessOrEqual(len(ignored), 1, "Received too many messages: %+v", ignored)
}
checkMessageTransientInitial(t, msg, api.StringMap{
"abc": data,
})
delta := time.Until(setAt.Add(ttl))
if assert.Greater(delta, time.Duration(0), "test runner too slow?") {
checkMessageTransientInitial(t, msg, api.StringMap{
"abc": data,
})
time.Sleep(10 * time.Millisecond)
if msg, ok = client3.RunUntilMessage(ctx); ok {
checkMessageTransientRemove(t, msg, "abc", data)
time.Sleep(delta)
if msg, ok = client2.RunUntilMessage(ctx); ok {
checkMessageTransientRemove(t, msg, "abc", data)
}
}
})
}