diff --git a/grpc_server.go b/grpc_server.go index 236467d..0e1d30e 100644 --- a/grpc_server.go +++ b/grpc_server.go @@ -55,6 +55,14 @@ func init() { GrpcServerId = hex.EncodeToString(md.Sum(nil)) } +type GrpcServerHub interface { + GetSessionByResumeId(resumeId string) Session + GetSessionByPublicId(sessionId string) Session + GetSessionIdByRoomSessionId(roomSessionId string) (string, error) + + GetBackend(u *url.URL) *Backend +} + type GrpcServer struct { UnimplementedRpcBackendServer UnimplementedRpcInternalServer @@ -66,7 +74,7 @@ type GrpcServer struct { listener net.Listener serverId string // can be overwritten from tests - hub *Hub + hub GrpcServerHub } func NewGrpcServer(config *goconf.ConfigFile) (*GrpcServer, error) { @@ -131,7 +139,7 @@ func (s *GrpcServer) LookupSessionId(ctx context.Context, request *LookupSession statsGrpcServerCalls.WithLabelValues("LookupSessionId").Inc() // TODO: Remove debug logging log.Printf("Lookup session id for room session id %s", request.RoomSessionId) - sid, err := s.hub.roomSessions.GetSessionId(request.RoomSessionId) + sid, err := s.hub.GetSessionIdByRoomSessionId(request.RoomSessionId) if errors.Is(err, ErrNoSuchRoomSession) { return nil, status.Error(codes.NotFound, "no such room session id") } else if err != nil { @@ -221,7 +229,7 @@ func (s *GrpcServer) GetSessionCount(ctx context.Context, request *GetSessionCou return nil, status.Error(codes.InvalidArgument, "invalid url") } - backend := s.hub.backend.GetBackend(u) + backend := s.hub.GetBackend(u) if backend == nil { return nil, status.Error(codes.NotFound, "no such backend") } @@ -233,13 +241,18 @@ func (s *GrpcServer) GetSessionCount(ctx context.Context, request *GetSessionCou func (s *GrpcServer) ProxySession(request RpcSessions_ProxySessionServer) error { statsGrpcServerCalls.WithLabelValues("ProxySession").Inc() - client, err := newRemoteGrpcClient(s.hub, request) + hub, ok := s.hub.(*Hub) + if !ok { + return status.Error(codes.Internal, "invalid hub type") + + } + client, err := newRemoteGrpcClient(hub, request) if err != nil { return err } - sid := s.hub.registerClient(client) - defer s.hub.unregisterClient(sid) + sid := hub.registerClient(client) + defer hub.unregisterClient(sid) return client.run() } diff --git a/hub.go b/hub.go index 5684f3c..77a86c7 100644 --- a/hub.go +++ b/hub.go @@ -38,6 +38,7 @@ import ( "log" "net" "net/http" + "net/url" "strings" "sync" "sync/atomic" @@ -623,6 +624,10 @@ func (h *Hub) GetSessionByResumeId(resumeId string) Session { return session } +func (h *Hub) GetSessionIdByRoomSessionId(roomSessionId string) (string, error) { + return h.roomSessions.GetSessionId(roomSessionId) +} + func (h *Hub) GetDialoutSession(roomId string, backend *Backend) *ClientSession { url := backend.Url() @@ -641,6 +646,10 @@ func (h *Hub) GetDialoutSession(roomId string, backend *Backend) *ClientSession return nil } +func (h *Hub) GetBackend(u *url.URL) *Backend { + return h.backend.GetBackend(u) +} + func (h *Hub) checkExpiredSessions(now time.Time) { for session, expires := range h.expiredSessions { if now.After(expires) {