mirror of
https://github.com/strukturag/nextcloud-spreed-signaling
synced 2026-03-14 14:35:44 +01:00
Merge pull request #1127 from strukturag/initial-transient-data-clustered
Fix initial transient data in clustered setups
This commit is contained in:
commit
ba1af553e0
10 changed files with 605 additions and 83 deletions
|
|
@ -1342,7 +1342,9 @@ Message format (Server -> Client):
|
|||
### Initial data
|
||||
|
||||
When sessions initially join a room, they receive the current state of the
|
||||
transient data.
|
||||
transient data. Please note that the initial data can be sent in multiple
|
||||
events of type `initial` which must be combined to generate the total initial
|
||||
data.
|
||||
|
||||
Message format (Server -> Client):
|
||||
|
||||
|
|
|
|||
|
|
@ -326,6 +326,40 @@ func (c *GrpcClient) GetSessionCount(ctx context.Context, url string) (uint32, e
|
|||
return response.GetCount(), nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) GetTransientData(ctx context.Context, room *Room) (TransientDataEntries, error) {
|
||||
statsGrpcClientCalls.WithLabelValues("GetTransientData").Inc()
|
||||
// TODO: Remove debug logging
|
||||
c.logger.Printf("Get transient data for %s@%s on %s", room.Id(), room.Backend().Id(), c.Target())
|
||||
response, err := c.impl.GetTransientData(ctx, &GetTransientDataRequest{
|
||||
RoomId: room.Id(),
|
||||
BackendUrls: room.Backend().Urls(),
|
||||
}, grpc.WaitForReady(true))
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
entries := response.GetEntries()
|
||||
if len(entries) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
result := make(TransientDataEntries, len(entries))
|
||||
for k, v := range entries {
|
||||
var value any
|
||||
if err := json.Unmarshal(v.Value, &value); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if v.Expires > 0 {
|
||||
result[k] = NewTransientDataEntryWithExpires(value, time.UnixMicro(v.Expires))
|
||||
} else {
|
||||
result[k] = NewTransientDataEntry(value, 0)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type ProxySessionReceiver interface {
|
||||
RemoteAddr() string
|
||||
Country() string
|
||||
|
|
|
|||
|
|
@ -127,6 +127,154 @@ func (x *GetServerIdReply) GetVersion() string {
|
|||
return ""
|
||||
}
|
||||
|
||||
type GetTransientDataRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
RoomId string `protobuf:"bytes,1,opt,name=roomId,proto3" json:"roomId,omitempty"`
|
||||
BackendUrls []string `protobuf:"bytes,2,rep,name=backendUrls,proto3" json:"backendUrls,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *GetTransientDataRequest) Reset() {
|
||||
*x = GetTransientDataRequest{}
|
||||
mi := &file_grpc_internal_proto_msgTypes[2]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *GetTransientDataRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*GetTransientDataRequest) ProtoMessage() {}
|
||||
|
||||
func (x *GetTransientDataRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_grpc_internal_proto_msgTypes[2]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use GetTransientDataRequest.ProtoReflect.Descriptor instead.
|
||||
func (*GetTransientDataRequest) Descriptor() ([]byte, []int) {
|
||||
return file_grpc_internal_proto_rawDescGZIP(), []int{2}
|
||||
}
|
||||
|
||||
func (x *GetTransientDataRequest) GetRoomId() string {
|
||||
if x != nil {
|
||||
return x.RoomId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *GetTransientDataRequest) GetBackendUrls() []string {
|
||||
if x != nil {
|
||||
return x.BackendUrls
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type GrpcTransientDataEntry struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Value []byte `protobuf:"bytes,1,opt,name=value,proto3" json:"value,omitempty"`
|
||||
Expires int64 `protobuf:"varint,2,opt,name=expires,proto3" json:"expires,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *GrpcTransientDataEntry) Reset() {
|
||||
*x = GrpcTransientDataEntry{}
|
||||
mi := &file_grpc_internal_proto_msgTypes[3]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *GrpcTransientDataEntry) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*GrpcTransientDataEntry) ProtoMessage() {}
|
||||
|
||||
func (x *GrpcTransientDataEntry) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_grpc_internal_proto_msgTypes[3]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use GrpcTransientDataEntry.ProtoReflect.Descriptor instead.
|
||||
func (*GrpcTransientDataEntry) Descriptor() ([]byte, []int) {
|
||||
return file_grpc_internal_proto_rawDescGZIP(), []int{3}
|
||||
}
|
||||
|
||||
func (x *GrpcTransientDataEntry) GetValue() []byte {
|
||||
if x != nil {
|
||||
return x.Value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *GrpcTransientDataEntry) GetExpires() int64 {
|
||||
if x != nil {
|
||||
return x.Expires
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
type GetTransientDataReply struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Entries map[string]*GrpcTransientDataEntry `protobuf:"bytes,1,rep,name=entries,proto3" json:"entries,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *GetTransientDataReply) Reset() {
|
||||
*x = GetTransientDataReply{}
|
||||
mi := &file_grpc_internal_proto_msgTypes[4]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *GetTransientDataReply) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*GetTransientDataReply) ProtoMessage() {}
|
||||
|
||||
func (x *GetTransientDataReply) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_grpc_internal_proto_msgTypes[4]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use GetTransientDataReply.ProtoReflect.Descriptor instead.
|
||||
func (*GetTransientDataReply) Descriptor() ([]byte, []int) {
|
||||
return file_grpc_internal_proto_rawDescGZIP(), []int{4}
|
||||
}
|
||||
|
||||
func (x *GetTransientDataReply) GetEntries() map[string]*GrpcTransientDataEntry {
|
||||
if x != nil {
|
||||
return x.Entries
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var File_grpc_internal_proto protoreflect.FileDescriptor
|
||||
|
||||
const file_grpc_internal_proto_rawDesc = "" +
|
||||
|
|
@ -135,9 +283,21 @@ const file_grpc_internal_proto_rawDesc = "" +
|
|||
"\x12GetServerIdRequest\"H\n" +
|
||||
"\x10GetServerIdReply\x12\x1a\n" +
|
||||
"\bserverId\x18\x01 \x01(\tR\bserverId\x12\x18\n" +
|
||||
"\aversion\x18\x02 \x01(\tR\aversion2Z\n" +
|
||||
"\aversion\x18\x02 \x01(\tR\aversion\"S\n" +
|
||||
"\x17GetTransientDataRequest\x12\x16\n" +
|
||||
"\x06roomId\x18\x01 \x01(\tR\x06roomId\x12 \n" +
|
||||
"\vbackendUrls\x18\x02 \x03(\tR\vbackendUrls\"H\n" +
|
||||
"\x16GrpcTransientDataEntry\x12\x14\n" +
|
||||
"\x05value\x18\x01 \x01(\fR\x05value\x12\x18\n" +
|
||||
"\aexpires\x18\x02 \x01(\x03R\aexpires\"\xbf\x01\n" +
|
||||
"\x15GetTransientDataReply\x12G\n" +
|
||||
"\aentries\x18\x01 \x03(\v2-.signaling.GetTransientDataReply.EntriesEntryR\aentries\x1a]\n" +
|
||||
"\fEntriesEntry\x12\x10\n" +
|
||||
"\x03key\x18\x01 \x01(\tR\x03key\x127\n" +
|
||||
"\x05value\x18\x02 \x01(\v2!.signaling.GrpcTransientDataEntryR\x05value:\x028\x012\xb6\x01\n" +
|
||||
"\vRpcInternal\x12K\n" +
|
||||
"\vGetServerId\x12\x1d.signaling.GetServerIdRequest\x1a\x1b.signaling.GetServerIdReply\"\x00B<Z:github.com/strukturag/nextcloud-spreed-signaling;signalingb\x06proto3"
|
||||
"\vGetServerId\x12\x1d.signaling.GetServerIdRequest\x1a\x1b.signaling.GetServerIdReply\"\x00\x12Z\n" +
|
||||
"\x10GetTransientData\x12\".signaling.GetTransientDataRequest\x1a .signaling.GetTransientDataReply\"\x00B<Z:github.com/strukturag/nextcloud-spreed-signaling;signalingb\x06proto3"
|
||||
|
||||
var (
|
||||
file_grpc_internal_proto_rawDescOnce sync.Once
|
||||
|
|
@ -151,19 +311,27 @@ func file_grpc_internal_proto_rawDescGZIP() []byte {
|
|||
return file_grpc_internal_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_grpc_internal_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
|
||||
var file_grpc_internal_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
|
||||
var file_grpc_internal_proto_goTypes = []any{
|
||||
(*GetServerIdRequest)(nil), // 0: signaling.GetServerIdRequest
|
||||
(*GetServerIdReply)(nil), // 1: signaling.GetServerIdReply
|
||||
(*GetServerIdRequest)(nil), // 0: signaling.GetServerIdRequest
|
||||
(*GetServerIdReply)(nil), // 1: signaling.GetServerIdReply
|
||||
(*GetTransientDataRequest)(nil), // 2: signaling.GetTransientDataRequest
|
||||
(*GrpcTransientDataEntry)(nil), // 3: signaling.GrpcTransientDataEntry
|
||||
(*GetTransientDataReply)(nil), // 4: signaling.GetTransientDataReply
|
||||
nil, // 5: signaling.GetTransientDataReply.EntriesEntry
|
||||
}
|
||||
var file_grpc_internal_proto_depIdxs = []int32{
|
||||
0, // 0: signaling.RpcInternal.GetServerId:input_type -> signaling.GetServerIdRequest
|
||||
1, // 1: signaling.RpcInternal.GetServerId:output_type -> signaling.GetServerIdReply
|
||||
1, // [1:2] is the sub-list for method output_type
|
||||
0, // [0:1] is the sub-list for method input_type
|
||||
0, // [0:0] is the sub-list for extension type_name
|
||||
0, // [0:0] is the sub-list for extension extendee
|
||||
0, // [0:0] is the sub-list for field type_name
|
||||
5, // 0: signaling.GetTransientDataReply.entries:type_name -> signaling.GetTransientDataReply.EntriesEntry
|
||||
3, // 1: signaling.GetTransientDataReply.EntriesEntry.value:type_name -> signaling.GrpcTransientDataEntry
|
||||
0, // 2: signaling.RpcInternal.GetServerId:input_type -> signaling.GetServerIdRequest
|
||||
2, // 3: signaling.RpcInternal.GetTransientData:input_type -> signaling.GetTransientDataRequest
|
||||
1, // 4: signaling.RpcInternal.GetServerId:output_type -> signaling.GetServerIdReply
|
||||
4, // 5: signaling.RpcInternal.GetTransientData:output_type -> signaling.GetTransientDataReply
|
||||
4, // [4:6] is the sub-list for method output_type
|
||||
2, // [2:4] is the sub-list for method input_type
|
||||
2, // [2:2] is the sub-list for extension type_name
|
||||
2, // [2:2] is the sub-list for extension extendee
|
||||
0, // [0:2] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_grpc_internal_proto_init() }
|
||||
|
|
@ -177,7 +345,7 @@ func file_grpc_internal_proto_init() {
|
|||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_grpc_internal_proto_rawDesc), len(file_grpc_internal_proto_rawDesc)),
|
||||
NumEnums: 0,
|
||||
NumMessages: 2,
|
||||
NumMessages: 6,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ package signaling;
|
|||
|
||||
service RpcInternal {
|
||||
rpc GetServerId(GetServerIdRequest) returns (GetServerIdReply) {}
|
||||
rpc GetTransientData(GetTransientDataRequest) returns (GetTransientDataReply) {}
|
||||
}
|
||||
|
||||
message GetServerIdRequest {
|
||||
|
|
@ -36,3 +37,17 @@ message GetServerIdReply {
|
|||
string serverId = 1;
|
||||
string version = 2;
|
||||
}
|
||||
|
||||
message GetTransientDataRequest {
|
||||
string roomId = 1;
|
||||
repeated string backendUrls = 2;
|
||||
}
|
||||
|
||||
message GrpcTransientDataEntry {
|
||||
bytes value = 1;
|
||||
int64 expires = 2;
|
||||
}
|
||||
|
||||
message GetTransientDataReply {
|
||||
map<string, GrpcTransientDataEntry> entries = 1;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,7 +37,8 @@ import (
|
|||
const _ = grpc.SupportPackageIsVersion9
|
||||
|
||||
const (
|
||||
RpcInternal_GetServerId_FullMethodName = "/signaling.RpcInternal/GetServerId"
|
||||
RpcInternal_GetServerId_FullMethodName = "/signaling.RpcInternal/GetServerId"
|
||||
RpcInternal_GetTransientData_FullMethodName = "/signaling.RpcInternal/GetTransientData"
|
||||
)
|
||||
|
||||
// RpcInternalClient is the client API for RpcInternal service.
|
||||
|
|
@ -45,6 +46,7 @@ const (
|
|||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
type RpcInternalClient interface {
|
||||
GetServerId(ctx context.Context, in *GetServerIdRequest, opts ...grpc.CallOption) (*GetServerIdReply, error)
|
||||
GetTransientData(ctx context.Context, in *GetTransientDataRequest, opts ...grpc.CallOption) (*GetTransientDataReply, error)
|
||||
}
|
||||
|
||||
type rpcInternalClient struct {
|
||||
|
|
@ -65,11 +67,22 @@ func (c *rpcInternalClient) GetServerId(ctx context.Context, in *GetServerIdRequ
|
|||
return out, nil
|
||||
}
|
||||
|
||||
func (c *rpcInternalClient) GetTransientData(ctx context.Context, in *GetTransientDataRequest, opts ...grpc.CallOption) (*GetTransientDataReply, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(GetTransientDataReply)
|
||||
err := c.cc.Invoke(ctx, RpcInternal_GetTransientData_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// RpcInternalServer is the server API for RpcInternal service.
|
||||
// All implementations must embed UnimplementedRpcInternalServer
|
||||
// for forward compatibility.
|
||||
type RpcInternalServer interface {
|
||||
GetServerId(context.Context, *GetServerIdRequest) (*GetServerIdReply, error)
|
||||
GetTransientData(context.Context, *GetTransientDataRequest) (*GetTransientDataReply, error)
|
||||
mustEmbedUnimplementedRpcInternalServer()
|
||||
}
|
||||
|
||||
|
|
@ -83,6 +96,9 @@ type UnimplementedRpcInternalServer struct{}
|
|||
func (UnimplementedRpcInternalServer) GetServerId(context.Context, *GetServerIdRequest) (*GetServerIdReply, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetServerId not implemented")
|
||||
}
|
||||
func (UnimplementedRpcInternalServer) GetTransientData(context.Context, *GetTransientDataRequest) (*GetTransientDataReply, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetTransientData not implemented")
|
||||
}
|
||||
func (UnimplementedRpcInternalServer) mustEmbedUnimplementedRpcInternalServer() {}
|
||||
func (UnimplementedRpcInternalServer) testEmbeddedByValue() {}
|
||||
|
||||
|
|
@ -122,6 +138,24 @@ func _RpcInternal_GetServerId_Handler(srv interface{}, ctx context.Context, dec
|
|||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _RpcInternal_GetTransientData_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(GetTransientDataRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(RpcInternalServer).GetTransientData(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: RpcInternal_GetTransientData_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(RpcInternalServer).GetTransientData(ctx, req.(*GetTransientDataRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// RpcInternal_ServiceDesc is the grpc.ServiceDesc for RpcInternal service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
|
|
@ -133,6 +167,10 @@ var RpcInternal_ServiceDesc = grpc.ServiceDesc{
|
|||
MethodName: "GetServerId",
|
||||
Handler: _RpcInternal_GetServerId_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "GetTransientData",
|
||||
Handler: _RpcInternal_GetTransientData_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "grpc_internal.proto",
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ import (
|
|||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
|
@ -304,6 +305,67 @@ func (s *GrpcServer) GetServerId(ctx context.Context, request *GetServerIdReques
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (s *GrpcServer) GetTransientData(ctx context.Context, request *GetTransientDataRequest) (*GetTransientDataReply, error) {
|
||||
statsGrpcServerCalls.WithLabelValues("GetTransientData").Inc()
|
||||
|
||||
backendUrls := request.BackendUrls
|
||||
if len(backendUrls) == 0 {
|
||||
// Only compat backend.
|
||||
backendUrls = []string{""}
|
||||
}
|
||||
|
||||
result := &GetTransientDataReply{}
|
||||
processed := make(map[string]bool)
|
||||
for _, bu := range backendUrls {
|
||||
var parsed *url.URL
|
||||
if bu != "" {
|
||||
var err error
|
||||
parsed, err = url.Parse(bu)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.InvalidArgument, "invalid url")
|
||||
}
|
||||
}
|
||||
|
||||
backend := s.hub.GetBackend(parsed)
|
||||
if backend == nil {
|
||||
return nil, status.Error(codes.NotFound, "no such backend")
|
||||
}
|
||||
|
||||
// Only process each backend once.
|
||||
if processed[backend.Id()] {
|
||||
continue
|
||||
}
|
||||
processed[backend.Id()] = true
|
||||
|
||||
room := s.hub.GetRoomForBackend(request.RoomId, backend)
|
||||
if room == nil {
|
||||
return nil, status.Error(codes.NotFound, "no such room")
|
||||
}
|
||||
|
||||
entries := room.transientData.GetEntries()
|
||||
if len(entries) == 0 {
|
||||
return nil, status.Error(codes.NotFound, "room has no transient data")
|
||||
}
|
||||
|
||||
if result.Entries == nil {
|
||||
result.Entries = make(map[string]*GrpcTransientDataEntry)
|
||||
}
|
||||
for k, v := range entries {
|
||||
e := &GrpcTransientDataEntry{}
|
||||
var err error
|
||||
if e.Value, err = json.Marshal(v.Value); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "error marshalling data: %s", err)
|
||||
}
|
||||
if !v.Expires.IsZero() {
|
||||
e.Expires = v.Expires.UnixMicro()
|
||||
}
|
||||
result.Entries[k] = e
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *GrpcServer) GetSessionCount(ctx context.Context, request *GetSessionCountRequest) (*GetSessionCountReply, error) {
|
||||
statsGrpcServerCalls.WithLabelValues("SessionCount").Inc()
|
||||
|
||||
|
|
|
|||
39
hub.go
39
hub.go
|
|
@ -2676,33 +2676,7 @@ func (h *Hub) processTransientMsg(session Session, message *ClientMessage) {
|
|||
return
|
||||
}
|
||||
|
||||
var err error
|
||||
if msg.Value == nil {
|
||||
err = h.events.PublishBackendRoomMessage(room.Id(), room.Backend(), &AsyncMessage{
|
||||
Type: "room",
|
||||
Room: &BackendServerRoomRequest{
|
||||
Type: "transient",
|
||||
Transient: &BackendRoomTransientRequest{
|
||||
Action: TransientActionDelete,
|
||||
Key: msg.Key,
|
||||
},
|
||||
},
|
||||
})
|
||||
} else {
|
||||
err = h.events.PublishBackendRoomMessage(room.Id(), room.Backend(), &AsyncMessage{
|
||||
Type: "room",
|
||||
Room: &BackendServerRoomRequest{
|
||||
Type: "transient",
|
||||
Transient: &BackendRoomTransientRequest{
|
||||
Action: TransientActionSet,
|
||||
Key: msg.Key,
|
||||
Value: msg.Value,
|
||||
TTL: msg.TTL,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
if err := room.SetTransientDataTTL(msg.Key, msg.Value, msg.TTL); err != nil {
|
||||
response := message.NewWrappedErrorServerMessage(err)
|
||||
session.SendMessage(response)
|
||||
return
|
||||
|
|
@ -2713,16 +2687,7 @@ func (h *Hub) processTransientMsg(session Session, message *ClientMessage) {
|
|||
return
|
||||
}
|
||||
|
||||
if err := h.events.PublishBackendRoomMessage(room.Id(), room.Backend(), &AsyncMessage{
|
||||
Type: "room",
|
||||
Room: &BackendServerRoomRequest{
|
||||
Type: "transient",
|
||||
Transient: &BackendRoomTransientRequest{
|
||||
Action: TransientActionDelete,
|
||||
Key: msg.Key,
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
if err := room.RemoveTransientData(msg.Key); err != nil {
|
||||
response := message.NewWrappedErrorServerMessage(err)
|
||||
session.SendMessage(response)
|
||||
return
|
||||
|
|
|
|||
112
room.go
112
room.go
|
|
@ -258,9 +258,13 @@ func (r *Room) processBackendRoomRequestRoom(message *BackendServerRoomRequest)
|
|||
case "transient":
|
||||
switch message.Transient.Action {
|
||||
case TransientActionSet:
|
||||
r.SetTransientDataTTL(message.Transient.Key, message.Transient.Value, message.Transient.TTL)
|
||||
if message.Transient.TTL == 0 {
|
||||
r.doSetTransientData(message.Transient.Key, message.Transient.Value)
|
||||
} else {
|
||||
r.doSetTransientDataTTL(message.Transient.Key, message.Transient.Value, message.Transient.TTL)
|
||||
}
|
||||
case TransientActionDelete:
|
||||
r.RemoveTransientData(message.Transient.Key)
|
||||
r.doRemoveTransientData(message.Transient.Key)
|
||||
default:
|
||||
r.logger.Printf("Unsupported transient action in room %s: %+v", r.Id(), message.Transient)
|
||||
}
|
||||
|
|
@ -293,6 +297,7 @@ func (r *Room) AddSession(session Session, sessionData json.RawMessage) {
|
|||
|
||||
sid := session.PublicId()
|
||||
r.mu.Lock()
|
||||
isFirst := len(r.sessions) == 0
|
||||
_, found := r.sessions[sid]
|
||||
r.sessions[sid] = session
|
||||
if !found {
|
||||
|
|
@ -334,6 +339,9 @@ func (r *Room) AddSession(session Session, sessionData json.RawMessage) {
|
|||
if clientSession, ok := session.(*ClientSession); ok {
|
||||
r.transientData.AddListener(clientSession)
|
||||
}
|
||||
if isFirst {
|
||||
r.fetchInitialTransientData()
|
||||
}
|
||||
}
|
||||
|
||||
// Trigger notifications that the session joined.
|
||||
|
|
@ -1202,14 +1210,108 @@ func (r *Room) notifyInternalRoomDeleted() {
|
|||
}
|
||||
}
|
||||
|
||||
func (r *Room) SetTransientData(key string, value any) {
|
||||
func (r *Room) SetTransientData(key string, value any) error {
|
||||
if value == nil {
|
||||
return r.RemoveTransientData(key)
|
||||
}
|
||||
|
||||
return r.events.PublishBackendRoomMessage(r.Id(), r.Backend(), &AsyncMessage{
|
||||
Type: "room",
|
||||
Room: &BackendServerRoomRequest{
|
||||
Type: "transient",
|
||||
Transient: &BackendRoomTransientRequest{
|
||||
Action: TransientActionSet,
|
||||
Key: key,
|
||||
Value: value,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (r *Room) doSetTransientData(key string, value any) {
|
||||
r.transientData.Set(key, value)
|
||||
}
|
||||
|
||||
func (r *Room) SetTransientDataTTL(key string, value any, ttl time.Duration) {
|
||||
func (r *Room) SetTransientDataTTL(key string, value any, ttl time.Duration) error {
|
||||
if value == nil {
|
||||
return r.RemoveTransientData(key)
|
||||
} else if ttl == 0 {
|
||||
return r.SetTransientData(key, value)
|
||||
}
|
||||
|
||||
return r.events.PublishBackendRoomMessage(r.Id(), r.Backend(), &AsyncMessage{
|
||||
Type: "room",
|
||||
Room: &BackendServerRoomRequest{
|
||||
Type: "transient",
|
||||
Transient: &BackendRoomTransientRequest{
|
||||
Action: TransientActionSet,
|
||||
Key: key,
|
||||
Value: value,
|
||||
TTL: ttl,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (r *Room) doSetTransientDataTTL(key string, value any, ttl time.Duration) {
|
||||
r.transientData.SetTTL(key, value, ttl)
|
||||
}
|
||||
|
||||
func (r *Room) RemoveTransientData(key string) {
|
||||
func (r *Room) RemoveTransientData(key string) error {
|
||||
return r.events.PublishBackendRoomMessage(r.Id(), r.Backend(), &AsyncMessage{
|
||||
Type: "room",
|
||||
Room: &BackendServerRoomRequest{
|
||||
Type: "transient",
|
||||
Transient: &BackendRoomTransientRequest{
|
||||
Action: TransientActionDelete,
|
||||
Key: key,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (r *Room) doRemoveTransientData(key string) {
|
||||
r.transientData.Remove(key)
|
||||
}
|
||||
|
||||
func (r *Room) fetchInitialTransientData() {
|
||||
if r.hub.rpcClients == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := NewLoggerContext(context.Background(), r.logger)
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
// +checklocks:mu
|
||||
var initial TransientDataEntries
|
||||
for _, client := range r.hub.rpcClients.GetClients() {
|
||||
wg.Add(1)
|
||||
go func(c *GrpcClient) {
|
||||
defer wg.Done()
|
||||
|
||||
data, err := c.GetTransientData(ctx, r)
|
||||
if err != nil {
|
||||
r.logger.Printf("Received error while getting transient data for %s@%s from %s: %s", r.Id(), r.Backend().Id(), c.Target(), err)
|
||||
return
|
||||
}
|
||||
|
||||
r.logger.Printf("Received initial transient data %+v from %s", data, c.Target())
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if initial == nil {
|
||||
initial = make(TransientDataEntries)
|
||||
}
|
||||
maps.Copy(initial, data)
|
||||
}(client)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if len(initial) > 0 {
|
||||
r.transientData.SetInitial(initial)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,7 +22,8 @@
|
|||
package signaling
|
||||
|
||||
import (
|
||||
"maps"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
|
@ -34,10 +35,58 @@ type TransientListener interface {
|
|||
SendMessage(message *ServerMessage) bool
|
||||
}
|
||||
|
||||
type TransientDataEntry struct {
|
||||
Value any `json:"value"`
|
||||
Expires time.Time `json:"expires,omitzero"`
|
||||
}
|
||||
|
||||
func NewTransientDataEntry(value any, ttl time.Duration) *TransientDataEntry {
|
||||
entry := &TransientDataEntry{
|
||||
Value: value,
|
||||
}
|
||||
if ttl > 0 {
|
||||
entry.Expires = time.Now().Add(ttl)
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
func NewTransientDataEntryWithExpires(value any, expires time.Time) *TransientDataEntry {
|
||||
entry := &TransientDataEntry{
|
||||
Value: value,
|
||||
Expires: expires,
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
func (e *TransientDataEntry) clone() *TransientDataEntry {
|
||||
result := *e
|
||||
return &result
|
||||
}
|
||||
|
||||
func (e *TransientDataEntry) update(value any, ttl time.Duration) {
|
||||
e.Value = value
|
||||
if ttl > 0 {
|
||||
e.Expires = time.Now().Add(ttl)
|
||||
} else {
|
||||
e.Expires = time.Time{}
|
||||
}
|
||||
}
|
||||
|
||||
type TransientDataEntries map[string]*TransientDataEntry
|
||||
|
||||
func (e TransientDataEntries) String() string {
|
||||
data, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not serialize %#v: %s", e, err)
|
||||
}
|
||||
|
||||
return string(data)
|
||||
}
|
||||
|
||||
type TransientData struct {
|
||||
mu sync.Mutex
|
||||
// +checklocks:mu
|
||||
data api.StringMap
|
||||
data TransientDataEntries
|
||||
// +checklocks:mu
|
||||
listeners map[TransientListener]bool
|
||||
// +checklocks:mu
|
||||
|
|
@ -66,8 +115,8 @@ func (t *TransientData) notifySet(key string, prev, value any) {
|
|||
TransientData: &TransientDataServerMessage{
|
||||
Type: "set",
|
||||
Key: key,
|
||||
OldValue: prev,
|
||||
Value: value,
|
||||
OldValue: prev,
|
||||
},
|
||||
}
|
||||
for listener := range t.listeners {
|
||||
|
|
@ -76,15 +125,17 @@ func (t *TransientData) notifySet(key string, prev, value any) {
|
|||
}
|
||||
|
||||
// +checklocks:t.mu
|
||||
func (t *TransientData) notifyDeleted(key string, prev any) {
|
||||
func (t *TransientData) notifyDeleted(key string, prev *TransientDataEntry) {
|
||||
msg := &ServerMessage{
|
||||
Type: "transient",
|
||||
TransientData: &TransientDataServerMessage{
|
||||
Type: "remove",
|
||||
Key: key,
|
||||
OldValue: prev,
|
||||
Type: "remove",
|
||||
Key: key,
|
||||
},
|
||||
}
|
||||
if prev != nil {
|
||||
msg.TransientData.OldValue = prev.Value
|
||||
}
|
||||
for listener := range t.listeners {
|
||||
t.sendMessageToListener(listener, msg)
|
||||
}
|
||||
|
|
@ -100,11 +151,15 @@ func (t *TransientData) AddListener(listener TransientListener) {
|
|||
}
|
||||
t.listeners[listener] = true
|
||||
if len(t.data) > 0 {
|
||||
data := make(api.StringMap, len(t.data))
|
||||
for k, v := range t.data {
|
||||
data[k] = v.Value
|
||||
}
|
||||
msg := &ServerMessage{
|
||||
Type: "transient",
|
||||
TransientData: &TransientDataServerMessage{
|
||||
Type: "initial",
|
||||
Data: t.data,
|
||||
Data: data,
|
||||
},
|
||||
}
|
||||
t.sendMessageToListener(listener, msg)
|
||||
|
|
@ -157,12 +212,19 @@ func (t *TransientData) removeAfterTTL(key string, value any, ttl time.Duration)
|
|||
}
|
||||
|
||||
// +checklocks:t.mu
|
||||
func (t *TransientData) doSet(key string, value any, prev any, ttl time.Duration) {
|
||||
func (t *TransientData) doSet(key string, value any, prev *TransientDataEntry, ttl time.Duration) {
|
||||
if t.data == nil {
|
||||
t.data = make(api.StringMap)
|
||||
t.data = make(TransientDataEntries)
|
||||
}
|
||||
t.data[key] = value
|
||||
t.notifySet(key, prev, value)
|
||||
var oldValue any
|
||||
if prev == nil {
|
||||
entry := NewTransientDataEntry(value, ttl)
|
||||
t.data[key] = entry
|
||||
} else {
|
||||
oldValue = prev.Value
|
||||
prev.update(value, ttl)
|
||||
}
|
||||
t.notifySet(key, oldValue, value)
|
||||
t.removeAfterTTL(key, value, ttl)
|
||||
}
|
||||
|
||||
|
|
@ -183,7 +245,7 @@ func (t *TransientData) SetTTL(key string, value any, ttl time.Duration) bool {
|
|||
defer t.mu.Unlock()
|
||||
|
||||
prev, found := t.data[key]
|
||||
if found && reflect.DeepEqual(prev, value) {
|
||||
if found && reflect.DeepEqual(prev.Value, value) {
|
||||
t.updateTTL(key, value, ttl)
|
||||
return false
|
||||
}
|
||||
|
|
@ -210,7 +272,7 @@ func (t *TransientData) CompareAndSetTTL(key string, old, value any, ttl time.Du
|
|||
defer t.mu.Unlock()
|
||||
|
||||
prev, found := t.data[key]
|
||||
if old != nil && (!found || !reflect.DeepEqual(prev, old)) {
|
||||
if old != nil && (!found || !reflect.DeepEqual(prev.Value, old)) {
|
||||
return false
|
||||
} else if old == nil && found {
|
||||
return false
|
||||
|
|
@ -221,7 +283,7 @@ func (t *TransientData) CompareAndSetTTL(key string, old, value any, ttl time.Du
|
|||
}
|
||||
|
||||
// +checklocks:t.mu
|
||||
func (t *TransientData) doRemove(key string, prev any) {
|
||||
func (t *TransientData) doRemove(key string, prev *TransientDataEntry) {
|
||||
delete(t.data, key)
|
||||
if old, found := t.timers[key]; found {
|
||||
old.Stop()
|
||||
|
|
@ -257,7 +319,7 @@ func (t *TransientData) CompareAndRemove(key string, old any) bool {
|
|||
// +checklocks:t.mu
|
||||
func (t *TransientData) compareAndRemove(key string, old any) bool {
|
||||
prev, found := t.data[key]
|
||||
if !found || !reflect.DeepEqual(prev, old) {
|
||||
if !found || !reflect.DeepEqual(prev.Value, old) {
|
||||
return false
|
||||
}
|
||||
|
||||
|
|
@ -270,7 +332,66 @@ func (t *TransientData) GetData() api.StringMap {
|
|||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
result := make(api.StringMap)
|
||||
maps.Copy(result, t.data)
|
||||
if len(t.data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make(api.StringMap, len(t.data))
|
||||
for k, entry := range t.data {
|
||||
result[k] = entry.Value
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetEntries returns a copy of the internal data entries.
|
||||
func (t *TransientData) GetEntries() TransientDataEntries {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if len(t.data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make(TransientDataEntries, len(t.data))
|
||||
for k, e := range t.data {
|
||||
result[k] = e.clone()
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// SetInitial sets the initial data and notifies listeners.
|
||||
func (t *TransientData) SetInitial(data TransientDataEntries) {
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if t.data == nil {
|
||||
t.data = make(TransientDataEntries)
|
||||
}
|
||||
|
||||
msgData := make(api.StringMap, len(data))
|
||||
for k, v := range data {
|
||||
if _, found := t.data[k]; found {
|
||||
// Entry already present (i.e. was set by regular event).
|
||||
continue
|
||||
}
|
||||
|
||||
msgData[k] = v.Value
|
||||
}
|
||||
if len(msgData) == 0 {
|
||||
return
|
||||
}
|
||||
msg := &ServerMessage{
|
||||
Type: "transient",
|
||||
TransientData: &TransientDataServerMessage{
|
||||
Type: "initial",
|
||||
Data: msgData,
|
||||
},
|
||||
}
|
||||
for listener := range t.listeners {
|
||||
t.sendMessageToListener(listener, msg)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -143,6 +143,7 @@ func Test_TransientMessages(t *testing.T) {
|
|||
t.Run(subtest, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
var hub1 *Hub
|
||||
var hub2 *Hub
|
||||
var server1 *httptest.Server
|
||||
|
|
@ -245,13 +246,24 @@ func Test_TransientMessages(t *testing.T) {
|
|||
|
||||
client1.RunUntilErrorIs(ctx3, context.DeadlineExceeded)
|
||||
|
||||
require.NoError(client1.SetTransientData("abc", data, 10*time.Millisecond))
|
||||
ttl := 200 * time.Millisecond
|
||||
require.NoError(client1.SetTransientData("abc", data, ttl))
|
||||
setAt := time.Now()
|
||||
if msg, ok := client2.RunUntilMessage(ctx); ok {
|
||||
checkMessageTransientSet(t, msg, "abc", data, nil)
|
||||
}
|
||||
|
||||
client1.CloseWithBye()
|
||||
require.NoError(client1.WaitForClientRemoved(ctx))
|
||||
client2.RunUntilLeft(ctx, hello1.Hello)
|
||||
|
||||
client3, hello3 := NewTestClientWithHello(ctx, t, server1, hub1, testDefaultUserId+"3")
|
||||
roomMsg = MustSucceed2(t, client3.JoinRoom, ctx, roomId)
|
||||
require.Equal(roomId, roomMsg.Room.RoomId)
|
||||
|
||||
_, ignored, ok := client3.RunUntilJoinedAndReturn(ctx, hello1.Hello, hello2.Hello, hello3.Hello)
|
||||
client2.RunUntilJoined(ctx, hello3.Hello)
|
||||
|
||||
_, ignored, ok := client3.RunUntilJoinedAndReturn(ctx, hello2.Hello, hello3.Hello)
|
||||
require.True(ok)
|
||||
|
||||
var msg *ServerMessage
|
||||
|
|
@ -263,13 +275,16 @@ func Test_TransientMessages(t *testing.T) {
|
|||
require.LessOrEqual(len(ignored), 1, "Received too many messages: %+v", ignored)
|
||||
}
|
||||
|
||||
checkMessageTransientInitial(t, msg, api.StringMap{
|
||||
"abc": data,
|
||||
})
|
||||
delta := time.Until(setAt.Add(ttl))
|
||||
if assert.Greater(delta, time.Duration(0), "test runner too slow?") {
|
||||
checkMessageTransientInitial(t, msg, api.StringMap{
|
||||
"abc": data,
|
||||
})
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
if msg, ok = client3.RunUntilMessage(ctx); ok {
|
||||
checkMessageTransientRemove(t, msg, "abc", data)
|
||||
time.Sleep(delta)
|
||||
if msg, ok = client2.RunUntilMessage(ctx); ok {
|
||||
checkMessageTransientRemove(t, msg, "abc", data)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue