From e75544f0a4d0813693572aabdfb025e6caa88c93 Mon Sep 17 00:00:00 2001 From: clowzed Date: Mon, 6 Nov 2023 01:05:48 +0300 Subject: [PATCH] Added cors layer. Added entity and migrations to store origins. Added endpoints for manipulation origins. Automatically clear origins on teardown and upload implemented. Added endpoints spec to openapi.yml. Untested --- Cargo.lock | 25 +++ Cargo.toml | 1 + entity/src/cors.rs | 38 +++++ entity/src/lib.rs | 1 + entity/src/mod.rs | 1 + entity/src/prelude.rs | 4 + entity/src/subdomain.rs | 8 + entity/src/user.rs | 14 +- migration/src/lib.rs | 2 + migration/src/m20231105_171000_create_cors.rs | 49 ++++++ openapi.yml | 79 ++++++++- src/apperror.rs | 158 +++++++----------- src/config/mod.rs | 2 +- src/extractors/mod.rs | 47 ++++-- src/handlers/cors.rs | 66 ++++++++ src/handlers/mod.rs | 1 + src/main.rs | 48 +++++- src/services/cors.rs | 22 +++ src/services/mod.rs | 1 + src/services/sites.rs | 5 + 20 files changed, 433 insertions(+), 139 deletions(-) create mode 100644 entity/src/cors.rs create mode 100644 migration/src/m20231105_171000_create_cors.rs create mode 100644 src/handlers/cors.rs create mode 100644 src/services/cors.rs diff --git a/Cargo.lock b/Cargo.lock index efc031a..1708a42 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1357,6 +1357,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-range-header" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "add0ab9360ddbd88cfeb3bd9574a1d85cfdfa14db10b3e21d3700dbc4328758f" + [[package]] name = "httparse" version = "1.8.0" @@ -2625,6 +2631,7 @@ dependencies = [ "tokio", "tokio-postgres", "tokio-util", + "tower-http", "tracing", "tracing-subscriber", "uuid", @@ -3267,6 +3274,24 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower-http" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c5bb1d698276a2443e5ecfabc1008bf15a36c12e6a7176e7bf089ea9131140" +dependencies = [ + "bitflags 2.4.0", + "bytes", + "futures-core", + "futures-util", + "http", + "http-body", + "http-range-header", + "pin-project-lite", + "tower-layer", + "tower-service", +] + [[package]] name = "tower-layer" version = "0.3.2" diff --git a/Cargo.toml b/Cargo.toml index bc6c532..1b40851 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ bytes = "1.5.0" async-trait = "0.1.73" tokio-util = { version = "0.7.9", features = ["io"] } dotenv = "0.15.0" +tower-http = { version = "0.4.4", features = ["cors"] } [workspace] members = [".", "entity", "migration"] diff --git a/entity/src/cors.rs b/entity/src/cors.rs new file mode 100644 index 0000000..229fcf0 --- /dev/null +++ b/entity/src/cors.rs @@ -0,0 +1,38 @@ +//! `SeaORM` Entity. Generated by sea-orm-codegen 0.12.3 + +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq)] +#[sea_orm(table_name = "cors")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub subdomain_id: i32, + pub origin: String, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::subdomain::Entity", + from = "Column::SubdomainId", + to = "super::subdomain::Column::Id", + on_update = "NoAction", + on_delete = "Cascade" + )] + Subdomain, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Subdomain.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} + +impl Model { + pub fn matches(&self, origin: &str) -> bool { + self.origin == "*" || self.origin == origin + } +} diff --git a/entity/src/lib.rs b/entity/src/lib.rs index 51a1bfb..e78438c 100644 --- a/entity/src/lib.rs +++ b/entity/src/lib.rs @@ -1,5 +1,6 @@ pub mod prelude; +pub mod cors; pub mod file; pub mod subdomain; pub mod user; diff --git a/entity/src/mod.rs b/entity/src/mod.rs index 302ec71..40aed17 100644 --- a/entity/src/mod.rs +++ b/entity/src/mod.rs @@ -2,6 +2,7 @@ pub mod prelude; +pub mod cors; pub mod file; pub mod subdomain; pub mod user; diff --git a/entity/src/prelude.rs b/entity/src/prelude.rs index 4191c1a..c0b43d3 100644 --- a/entity/src/prelude.rs +++ b/entity/src/prelude.rs @@ -1,15 +1,19 @@ +pub use super::cors::Entity as CorsEntity; pub use super::file::Entity as FileEntity; pub use super::subdomain::Entity as SubdomainEntity; pub use super::user::Entity as UserEntity; +pub use super::cors::ActiveModel as ActiveCors; pub use super::file::ActiveModel as ActiveFile; pub use super::subdomain::ActiveModel as ActiveSubdomain; pub use super::user::ActiveModel as ActiveUser; +pub use super::cors::Model as Cors; pub use super::file::Model as File; pub use super::subdomain::Model as Subdomain; pub use super::user::Model as User; +pub use super::cors::Column as CorsColumn; pub use super::file::Column as FileColumn; pub use super::subdomain::Column as SubdomainColumn; pub use super::user::Column as UserColumn; diff --git a/entity/src/subdomain.rs b/entity/src/subdomain.rs index dafb0ff..6230909 100644 --- a/entity/src/subdomain.rs +++ b/entity/src/subdomain.rs @@ -17,6 +17,8 @@ pub struct Model { #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] pub enum Relation { + #[sea_orm(has_many = "super::cors::Entity")] + Cors, #[sea_orm(has_many = "super::file::Entity")] File, #[sea_orm( @@ -29,6 +31,12 @@ pub enum Relation { User, } +impl Related for Entity { + fn to() -> RelationDef { + Relation::Cors.def() + } +} + impl Related for Entity { fn to() -> RelationDef { Relation::File.def() diff --git a/entity/src/user.rs b/entity/src/user.rs index c1dc066..89a6932 100644 --- a/entity/src/user.rs +++ b/entity/src/user.rs @@ -1,10 +1,8 @@ //! `SeaORM` Entity. Generated by sea-orm-codegen 0.12.3 -use std::fmt::Debug; - use sea_orm::entity::prelude::*; -#[derive(Clone, PartialEq, DeriveEntityModel, Eq)] +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq)] #[sea_orm(table_name = "user")] pub struct Model { #[sea_orm(primary_key)] @@ -27,13 +25,3 @@ impl Related for Entity { } impl ActiveModelBehavior for ActiveModel {} - -impl Debug for Model { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Model") - .field("id", &self.id) - .field("username", &self.username) - .field("password", &"...") - .finish() - } -} diff --git a/migration/src/lib.rs b/migration/src/lib.rs index a3447e9..605ae38 100644 --- a/migration/src/lib.rs +++ b/migration/src/lib.rs @@ -3,6 +3,7 @@ pub use sea_orm_migration::prelude::*; mod m20230927_162921_create_users; mod m20230929_081415_create_subdomains; mod m20230929_152215_create_files; +mod m20231105_171000_create_cors; pub struct Migrator; @@ -13,6 +14,7 @@ impl MigratorTrait for Migrator { Box::new(m20230927_162921_create_users::Migration), Box::new(m20230929_081415_create_subdomains::Migration), Box::new(m20230929_152215_create_files::Migration), + Box::new(m20231105_171000_create_cors::Migration), ] } } diff --git a/migration/src/m20231105_171000_create_cors.rs b/migration/src/m20231105_171000_create_cors.rs new file mode 100644 index 0000000..a5421dd --- /dev/null +++ b/migration/src/m20231105_171000_create_cors.rs @@ -0,0 +1,49 @@ +use sea_orm_migration::prelude::*; + +use crate::m20230929_081415_create_subdomains::Subdomain; + +#[derive(DeriveMigrationName)] +pub struct Migration; + +#[async_trait::async_trait] +impl MigrationTrait for Migration { + async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .create_table( + Table::create() + .table(Cors::Table) + .if_not_exists() + .col( + ColumnDef::new(Cors::Id) + .integer() + .not_null() + .auto_increment() + .primary_key(), + ) + .col(ColumnDef::new(Cors::SubdomainId).integer().not_null()) + .foreign_key( + ForeignKey::create() + .from(Cors::Table, Cors::SubdomainId) + .to(Subdomain::Table, Subdomain::Id) + .on_delete(ForeignKeyAction::Cascade), + ) + .col(ColumnDef::new(Cors::Origin).string().not_null()) + .to_owned(), + ) + .await + } + + async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .drop_table(Table::drop().table(Cors::Table).to_owned()) + .await + } +} + +#[derive(DeriveIden)] +pub enum Cors { + Table, + Id, + SubdomainId, + Origin, +} diff --git a/openapi.yml b/openapi.yml index a09fc6c..55eccac 100644 --- a/openapi.yml +++ b/openapi.yml @@ -294,11 +294,6 @@ paths: responses: "200": description: Site was successfully uploaded - content: - application/octet-stream: - schema: - type: string - format: binary "403": description: Subdomain is owned by another user @@ -324,6 +319,74 @@ paths: application/json: schema: $ref: "#/components/schemas/Details" + /api/cors/add: + post: + summary: Add origin + requestBody: + required: true + content: + application/x-www-form-urlencoded: + schema: + $ref: "#/components/schemas/Origin" + responses: + "200": + description: Origin was successfully added + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/Details" + "401": + description: Authentication failed, unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/Details" + "403": + description: Subdomain is owned by another user + content: + application/json: + schema: + $ref: "#/components/schemas/Details" + "500": + description: Server error + content: + application/json: + schema: + $ref: "#/components/schemas/Details" + /api/cors/clear: + post: + summary: Clear related origins + + responses: + "200": + description: Origins were successfully removed + + "400": + description: Bad request + content: + application/json: + schema: + $ref: "#/components/schemas/Details" + "401": + description: Authentication failed, unauthorized + content: + application/json: + schema: + $ref: "#/components/schemas/Details" + "403": + description: Subdomain is owned by another user + content: + application/json: + schema: + $ref: "#/components/schemas/Details" + "500": + description: Server error + content: + application/json: + schema: + $ref: "#/components/schemas/Details" components: schemas: @@ -344,6 +407,12 @@ components: properties: details: type: string + Origin: + type: object + properties: + origin: + type: string + securitySchemes: bearerAuth: type: http diff --git a/src/apperror.rs b/src/apperror.rs index f608098..eaf3fb3 100644 --- a/src/apperror.rs +++ b/src/apperror.rs @@ -25,109 +25,67 @@ pub enum SeroError { EmptyCredentials, } -impl IntoResponse for SeroError { - fn into_response(self) -> axum::response::Response { - let response = match self { - SeroError::XSubdomainHeaderMissing => ( - StatusCode::BAD_REQUEST, - Json(Details { - details: "X-Subdomain header is missing!".into(), - }), - ), - SeroError::AuthorizationHeaderMissing => ( - StatusCode::BAD_REQUEST, - Json(Details { - details: "Authorization header is missing!".into(), - }), - ), - SeroError::AuthorizationHeaderBadSchema => ( - StatusCode::BAD_REQUEST, - Json(Details { - details: "Authorization header does not match schema! - Required schema: Authorization: Bearer " - .into(), - }), - ), - SeroError::SubdomainIsOwnedByAnotherUser(subdomain_name) => ( - StatusCode::FORBIDDEN, - Json(Details { - details: format!( - "Subdomain with name {} is owned by another user!", - subdomain_name - ), - }), - ), - SeroError::AuthorizationHeaderBabChars => ( - StatusCode::BAD_REQUEST, - Json(Details { - details: "Authorization header contains invalid characters!".into(), - }), - ), +impl std::fmt::Debug for SeroError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", match self { + SeroError::XSubdomainHeaderMissing => "X-Subdomain header is missing!".to_string(), + SeroError::AuthorizationHeaderMissing => "Authorization header is missing!".to_string(), + SeroError::AuthorizationHeaderBadSchema => "Authorization header does not match schema! Required schema: Authorization: Bearer ".to_string(), + SeroError::SubdomainIsOwnedByAnotherUser(subdomain_name) => format!("Subdomain with name {} is owned by another user!", subdomain_name), + SeroError::AuthorizationHeaderBabChars => "Authorization header contains invalid characters!".to_string(), SeroError::InternalServerError(cause) => { tracing::error!(%cause, "Error!"); - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(Details { - details: "Some error occurred on the server!".into(), - }), - ) - } - SeroError::UserWasNotFoundUsingJwt => ( - StatusCode::UNAUTHORIZED, - Json(Details { - details: "User with id from jwt token was not found!".into(), - }), - ), - SeroError::RegisteredUserLimitExceeded => ( - StatusCode::FORBIDDEN, - Json(Details { - details: "Registered user limit exceeded!".into(), - }), - ), - SeroError::Unauthorized => ( - StatusCode::UNAUTHORIZED, - Json(Details { - details: "Unauthorized! Bad credentials were provided!".into(), - }), - ), - SeroError::UserHasAlreadyBeenRegistered => ( - StatusCode::CONFLICT, - Json(Details { - details: "User with this username has already been registered!".into(), - }), - ), - SeroError::SubdomainWasNotFound(subdomain_name) => ( - StatusCode::NOT_FOUND, - Json(Details { - details: format!("Subdomain with name {subdomain_name} was not found!"), - }), - ), - SeroError::ArchiveFileWasNotFoundForSubdomain(subdomain_name) => ( - StatusCode::NOT_FOUND, - Json(Details { - details: format!("Archive file was not found for subdomain {subdomain_name}"), - }), - ), - SeroError::MaxSitesPerUserLimitExceeded => ( - StatusCode::FORBIDDEN, - Json(Details { - details: "Max sites per this user limit exceeded!".into(), - }), - ), - SeroError::SiteDisabled => ( - StatusCode::SERVICE_UNAVAILABLE, - Json(Details { - details: "Service is currently unavailable!".into(), - }), - ), - SeroError::EmptyCredentials => ( - StatusCode::BAD_REQUEST, - Json(Details { - details: "Username or password is empty!".into(), - }), - ), - }; + "Some error occurred on the server!".to_string() + }, + SeroError::UserWasNotFoundUsingJwt => "User with id from jwt token was not found!".to_string(), + SeroError::RegisteredUserLimitExceeded => "Registered user limit exceeded!".to_string(), + SeroError::Unauthorized => "Unauthorized! Bad credentials were provided!".to_string(), + SeroError::UserHasAlreadyBeenRegistered => "User with this username has already been registered!".to_string(), + SeroError::SubdomainWasNotFound(subdomain_name) => format!("Subdomain with name {} was not found!", subdomain_name), + SeroError::ArchiveFileWasNotFoundForSubdomain(subdomain_name) => format!("Archive file was not found for subdomain {}", subdomain_name), + SeroError::MaxSitesPerUserLimitExceeded => "Max sites per this user limit exceeded!".to_string(), + SeroError::SiteDisabled => "Site is disabled!".to_string(), + SeroError::EmptyCredentials => "Empty credentials were provided!".to_string(), + }) + } +} +impl std::fmt::Display for SeroError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +impl From<&SeroError> for StatusCode { + fn from(val: &SeroError) -> Self { + match val { + SeroError::XSubdomainHeaderMissing => StatusCode::BAD_REQUEST, + SeroError::AuthorizationHeaderMissing => StatusCode::BAD_REQUEST, + SeroError::AuthorizationHeaderBadSchema => StatusCode::BAD_REQUEST, + SeroError::SubdomainIsOwnedByAnotherUser(_) => StatusCode::FORBIDDEN, + SeroError::AuthorizationHeaderBabChars => StatusCode::BAD_REQUEST, + SeroError::InternalServerError(_) => StatusCode::INTERNAL_SERVER_ERROR, + SeroError::UserWasNotFoundUsingJwt => StatusCode::UNAUTHORIZED, + SeroError::RegisteredUserLimitExceeded => StatusCode::FORBIDDEN, + SeroError::Unauthorized => StatusCode::UNAUTHORIZED, + SeroError::UserHasAlreadyBeenRegistered => StatusCode::CONFLICT, + SeroError::SubdomainWasNotFound(_) => StatusCode::NOT_FOUND, + SeroError::ArchiveFileWasNotFoundForSubdomain(_) => StatusCode::NOT_FOUND, + SeroError::MaxSitesPerUserLimitExceeded => StatusCode::FORBIDDEN, + SeroError::SiteDisabled => StatusCode::SERVICE_UNAVAILABLE, + SeroError::EmptyCredentials => StatusCode::BAD_REQUEST, + } + } +} + +impl IntoResponse for SeroError { + fn into_response(self) -> axum::response::Response { + let response = ( + Into::::into(&self), + Json(Details { + details: format!("{:?}", self), + }), + ); tracing::error!(cause = response.1.details, "Response with error!"); response.into_response() } diff --git a/src/config/mod.rs b/src/config/mod.rs index 93f7bc8..2f37925 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -13,7 +13,7 @@ impl Default for Config { let mut config: Self = envy::from_env().expect("Failed to read config from environment variables!"); if config.jwt_secret.is_none() { - config.jwt_secret = Some(uuid::Uuid::new_v4().to_string()) + config.jwt_secret = Some(uuid::Uuid::new_v4().to_string()); } config } diff --git a/src/extractors/mod.rs b/src/extractors/mod.rs index b9ad79d..5504af2 100644 --- a/src/extractors/mod.rs +++ b/src/extractors/mod.rs @@ -1,7 +1,7 @@ use axum::{ async_trait, extract::{FromRef, FromRequestParts}, - http::request::Parts, + http::{request::Parts, HeaderMap}, }; use sea_orm::prelude::*; use std::sync::Arc; @@ -59,19 +59,10 @@ where } } -#[async_trait] -impl FromRequestParts for Subdomain -where - Arc: FromRef, - - S: Send + Sync, -{ - type Rejection = SeroError; - - async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { +impl Subdomain { + pub fn from_headers(headers: &HeaderMap) -> Result { Ok(Self({ - let header = parts - .headers + let header = headers .get("X-Subdomain") .ok_or(SeroError::XSubdomainHeaderMissing)? .to_str() @@ -86,6 +77,20 @@ where } } +#[async_trait] +impl FromRequestParts for Subdomain +where + Arc: FromRef, + + S: Send + Sync, +{ + type Rejection = SeroError; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + Self::from_headers(&parts.headers) + } +} + #[async_trait] impl FromRequestParts for SubdomainModel where @@ -96,10 +101,20 @@ where #[tracing::instrument(skip(parts, state))] async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + Self::from_headers(&parts.headers, state).await + } +} + +impl SubdomainModel { + pub async fn from_headers(headers: &HeaderMap, state: &S) -> Result + where + Arc: FromRef, + S: Send + Sync, + { let app_state = Arc::from_ref(state); - let subdomain_name = Subdomain::from_request_parts(parts, state).await?.0; - Ok(match entity::prelude::SubdomainEntity::find() + let subdomain_name = Subdomain::from_headers(headers)?.0; + match entity::prelude::SubdomainEntity::find() .filter(entity::prelude::SubdomainColumn::Name.eq(&subdomain_name)) .one(&app_state.connection) .await @@ -107,7 +122,7 @@ where Ok(Some(subdomain)) => Ok(Self(subdomain)), Ok(None) => Err(SeroError::SubdomainWasNotFound(subdomain_name)), Err(cause) => Err(SeroError::InternalServerError(Box::new(cause))), - }?) + } } } diff --git a/src/handlers/cors.rs b/src/handlers/cors.rs new file mode 100644 index 0000000..d336a61 --- /dev/null +++ b/src/handlers/cors.rs @@ -0,0 +1,66 @@ +use crate::{ + apperror::SeroError, + extractors::{AuthJWT, SubdomainModel}, + AppState, +}; + +use axum::{ + extract::State, + http::StatusCode, + response::{IntoResponse, Response}, + Form, +}; +use entity::prelude::*; +use sea_orm::{prelude::*, Set}; +use std::sync::Arc; + +#[derive(Debug, serde::Deserialize)] +pub struct OriginForm { + origin: String, +} + +#[tracing::instrument(skip(state))] +pub async fn add_origin( + State(state): State>, + SubdomainModel(subdomain_model): SubdomainModel, + AuthJWT(user): AuthJWT, + Form(origin_form): Form, +) -> Response { + if subdomain_model.owner_id != user.id { + return SeroError::SubdomainIsOwnedByAnotherUser(subdomain_model.name).into_response(); + } + + let active_cors_origin = ActiveCors { + origin: Set(origin_form.origin), + subdomain_id: Set(subdomain_model.id), + ..Default::default() + }; + + match CorsEntity::insert(active_cors_origin) + .exec(&state.connection) + .await + { + Ok(_) => StatusCode::NO_CONTENT.into_response(), + Err(cause) => SeroError::InternalServerError(Box::new(cause)).into_response(), + } +} + +#[tracing::instrument(skip(state))] +pub async fn clear_all( + State(state): State>, + SubdomainModel(subdomain_model): SubdomainModel, + AuthJWT(user): AuthJWT, +) -> Response { + if subdomain_model.owner_id != user.id { + return SeroError::SubdomainIsOwnedByAnotherUser(subdomain_model.name).into_response(); + } + + match CorsEntity::delete_many() + .filter(CorsColumn::SubdomainId.eq(subdomain_model.id)) + .exec(&state.connection) + .await + { + Ok(_) => StatusCode::NO_CONTENT.into_response(), + Err(cause) => SeroError::InternalServerError(Box::new(cause)).into_response(), + } +} diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs index e96eb0f..00a30fe 100644 --- a/src/handlers/mod.rs +++ b/src/handlers/mod.rs @@ -1,2 +1,3 @@ pub mod auth; +pub mod cors; pub mod sites; diff --git a/src/main.rs b/src/main.rs index 776a51e..cf6c469 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,10 +2,14 @@ use std::{fmt::Debug, net::SocketAddr}; use axum::{ extract::DefaultBodyLimit, - http::StatusCode, + http::{request::Parts, HeaderName, HeaderValue, Method, StatusCode}, routing::{get, post}, Router, }; +use extractors::SubdomainModel; +use services::cors::CorsService; +use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin, CorsLayer}; + use migration::{Migrator, MigratorTrait}; use sea_orm::{ConnectOptions, Database}; @@ -47,17 +51,53 @@ async fn main() { .route("/teardown", post(handlers::sites::teardown)) .route("/download", post(handlers::sites::download)) .route("/enable", post(handlers::sites::enable)) - .route("/disable", post(handlers::sites::disable)); + .route("/disable", post(handlers::sites::disable)) + .route("/cors/add", post(handlers::cors::add_origin)) + .route("/cors/clear", post(handlers::cors::clear_all)); let state = std::sync::Arc::new(AppState { connection, config: Default::default(), }); - let mut app = Router::new() - .nest("/api", api_router) + let cloned_state = state.clone(); + + let files_router = Router::new() .route("/*path", get(handlers::sites::file)) .route("/", get(handlers::sites::index_redirect)) + .layer( + CorsLayer::new() + .allow_methods(AllowMethods::exact(Method::GET)) + .allow_headers(AllowHeaders::list([HeaderName::from_static("X-Subdomain")])) + .allow_origin(AllowOrigin::predicate( + move |origin: &HeaderValue, parts: &Parts| { + let origin = origin.to_str().unwrap_or_default(); + let subdomain_model_future = + SubdomainModel::from_headers(&parts.headers, &cloned_state); + + tokio::runtime::Handle::current().block_on(async { + let subdomain_model = subdomain_model_future.await; + match subdomain_model { + Ok(model) => match CorsService::check(model.0, origin, &cloned_state.connection).await{ + Ok(result) => result, + Err(cause) => { + tracing::error!(%cause, "Failed to check origin for cors filtering!"); + false + } + }, + Err(cause) => { + tracing::error!(%cause, "Failed to find subdomain model for cors filtering!"); + false + }, + } + }) + }, + )), + ); + + let mut app = Router::new() + .nest("/api", api_router) + .nest("/", files_router) .with_state(state.clone()); if config.max_body_limit_size.is_some() { diff --git a/src/services/cors.rs b/src/services/cors.rs new file mode 100644 index 0000000..3a995d0 --- /dev/null +++ b/src/services/cors.rs @@ -0,0 +1,22 @@ +use entity::prelude::*; +use sea_orm::{ConnectionTrait, DbErr, ModelTrait}; + +use sea_orm::TransactionTrait; + +pub struct CorsService; + +impl CorsService { + #[tracing::instrument(skip(connection))] + pub async fn check( + subdomain: Subdomain, + origin: &str, + connection: &T, + ) -> Result { + Ok(subdomain + .find_related(CorsEntity) + .all(connection) + .await? + .iter() + .any(|origin_model| origin_model.matches(origin))) + } +} diff --git a/src/services/mod.rs b/src/services/mod.rs index ec40696..5e18e7f 100644 --- a/src/services/mod.rs +++ b/src/services/mod.rs @@ -1,4 +1,5 @@ pub mod archive; pub mod auth; +pub mod cors; pub mod sites; pub mod users; diff --git a/src/services/sites.rs b/src/services/sites.rs index 62d2844..587f567 100644 --- a/src/services/sites.rs +++ b/src/services/sites.rs @@ -88,6 +88,11 @@ impl SitesService { } }; + CorsEntity::delete_many() + .filter(CorsColumn::SubdomainId.eq(subdomain.id)) + .exec(connection) + .await?; + match new_archive_file.write_all(&contents).await { Ok(()) => { let mut active: ActiveSubdomain = subdomain.clone().into();