From 2ba76400c451d72d2389377910134a30bd27c265 Mon Sep 17 00:00:00 2001 From: clowzed Date: Wed, 8 Nov 2023 16:50:50 +0300 Subject: [PATCH] Rewritten cors check closure to tokio spawn with channels --- openapi.yml | 17 +++++++++++++++ src/main.rs | 63 +++++++++++++++++++++++++++++++---------------------- 2 files changed, 54 insertions(+), 26 deletions(-) diff --git a/openapi.yml b/openapi.yml index 55eccac..c1dd5c9 100644 --- a/openapi.yml +++ b/openapi.yml @@ -322,6 +322,14 @@ paths: /api/cors/add: post: summary: Add origin + parameters: + - in: header + name: X-Subdomain + schema: + type: string + required: true + security: + - bearerAuth: [] requestBody: required: true content: @@ -358,6 +366,15 @@ paths: /api/cors/clear: post: summary: Clear related origins + parameters: + - in: header + name: X-Subdomain + schema: + type: string + required: true + security: + - bearerAuth: [] + responses: "200": diff --git a/src/main.rs b/src/main.rs index cf6c469..5b6d103 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,8 @@ -use std::{fmt::Debug, net::SocketAddr}; +use std::{fmt::Debug, net::SocketAddr, sync::mpsc}; use axum::{ extract::DefaultBodyLimit, - http::{request::Parts, HeaderName, HeaderValue, Method, StatusCode}, + http::{HeaderName, Method, StatusCode}, routing::{get, post}, Router, }; @@ -68,31 +68,42 @@ async fn main() { .layer( CorsLayer::new() .allow_methods(AllowMethods::exact(Method::GET)) - .allow_headers(AllowHeaders::list([HeaderName::from_static("X-Subdomain")])) - .allow_origin(AllowOrigin::predicate( - move |origin: &HeaderValue, parts: &Parts| { - let origin = origin.to_str().unwrap_or_default(); - let subdomain_model_future = - SubdomainModel::from_headers(&parts.headers, &cloned_state); + .allow_headers(AllowHeaders::list([HeaderName::from_static("x-subdomain")])) + .allow_origin(AllowOrigin::predicate(move |origin, parts| { + let cloned_state = cloned_state.clone(); + let cloned_origin = origin + .clone() + .to_str() + .map(|s| s.to_string()) + .unwrap_or_default(); + let cloned_headers = parts.headers.clone(); + let (tx, rx) = mpsc::channel(); - tokio::runtime::Handle::current().block_on(async { - let subdomain_model = subdomain_model_future.await; - match subdomain_model { - Ok(model) => match CorsService::check(model.0, origin, &cloned_state.connection).await{ - Ok(result) => result, - Err(cause) => { - tracing::error!(%cause, "Failed to check origin for cors filtering!"); - false - } - }, - Err(cause) => { - tracing::error!(%cause, "Failed to find subdomain model for cors filtering!"); - false - }, - } - }) - }, - )), + tokio::spawn(async move { + let subdomain_model_extractor = + SubdomainModel::from_headers(&cloned_headers, &cloned_state) + .await + .map_err(|cause| { + tracing::error!(%cause, "Failed to extract subdomain model from headers for cors!"); + }); + if subdomain_model_extractor.is_err() { + tx.send(false).ok(); + return; + } + + let res = CorsService::check( + subdomain_model_extractor.unwrap().0, + &cloned_origin, + &cloned_state.connection, + ) + .await + .unwrap_or(false); + + tx.send(res).ok(); + }); + + rx.recv().unwrap_or(false) + })), ); let mut app = Router::new()