diff --git a/src/account.rs b/src/account.rs index 0241480..47438df 100644 --- a/src/account.rs +++ b/src/account.rs @@ -35,10 +35,18 @@ pub fn is_logged_in(req: &HttpRequest) -> Option<&str> { } } -pub fn has_admintoken(req: &HttpRequest) -> Option<&str> { +pub fn has_admintoken(req: &HttpRequest) -> Option { + get_cookie(req, "sncf_admin_token") +} + +pub fn has_csrftoken(req: &HttpRequest) -> Option { + get_cookie(req, "sncf_csrf_cookie") +} + +fn get_cookie(req: &HttpRequest, cookie_name: &str) -> Option { let c = req.headers().get("Cookie")?.to_str().ok()?; - if c.contains("sncf_admin_token") { - Some(c) + if c.contains(cookie_name) { + Some(c.to_string()) } else { None } diff --git a/src/config.rs b/src/config.rs index 07c9114..34fc4a8 100644 --- a/src/config.rs +++ b/src/config.rs @@ -77,3 +77,9 @@ impl Config { } } } + +pub fn get_csrf_key() -> [u8; 32] { + let mut key: [u8; 32] = Default::default(); + key.copy_from_slice(&CONFIG.cookie_key.clone().into_bytes()[..32]); + key +} diff --git a/src/forward.rs b/src/forward.rs index 0c2ff6d..f824561 100644 --- a/src/forward.rs +++ b/src/forward.rs @@ -5,7 +5,9 @@ use chrono::Utc; use regex::Regex; use std::time::Duration; use url::Url; +use csrf::{AesGcmCsrfProtection, CsrfProtection}; +use crate::config::get_csrf_key; use crate::account::*; use crate::config::PAYLOAD_LIMIT; use crate::config::PROXY_TIMEOUT; @@ -129,6 +131,11 @@ pub struct LoginToken { pub token: String, } +#[derive(Deserialize)] +pub struct CsrfToken { + pub csrf_token: String, +} + pub async fn forward_login( req: HttpRequest, params: web::Path, @@ -175,6 +182,7 @@ pub async fn forward_login( // the account gets associated with a token in sqlite DB. pub async fn forward_register( req: HttpRequest, + csrf_post: web::Form, client: web::Data, dbpool: web::Data, ) -> Result { @@ -223,6 +231,48 @@ pub async fn forward_register( } } + // check if the csrf token is OK + if let Some(cookie_token) = has_csrftoken(&req) { + lazy_static! { + static ref RE: Regex = Regex::new(r#"sncf_csrf_cookie=(?P[0-9A-Za-z_\-]*)"#) + .expect("Error while parsing the sncf_csrf_cookie regex"); + } + let cookie_csrf_token = RE + .captures(&cookie_token) + .ok_or_else(|| { + eprintln!("error_csrf_cookie: no capture"); + crash(get_lang(&req), "error_csrf_cookie") + })? + .name("token") + .ok_or_else(|| { + eprintln!("error_csrf_cookie: no capture named token"); + crash(get_lang(&req), "error_csrf_cookie") + })? + .as_str(); + + let raw_ctoken = base64::decode_config(cookie_csrf_token.as_bytes(), base64::URL_SAFE_NO_PAD).map_err(|e| { + eprintln!("error_csrf_cookie (base64): {}", e); + crash(get_lang(&req), "error_csrf_cookie") + })?; + + let raw_token = base64::decode_config(csrf_post.csrf_token.as_bytes(), base64::URL_SAFE_NO_PAD).map_err(|e| { + eprintln!("error_csrf_token (base64): {}", e); + crash(get_lang(&req), "error_csrf_token") + })?; + + let seed = AesGcmCsrfProtection::from_key(get_csrf_key()); + let parsed_token = seed.parse_token(&raw_token).expect("token not parsed"); + let parsed_cookie = seed.parse_cookie(&raw_ctoken).expect("cookie not parsed"); + if !seed.verify_token_pair(&parsed_token, &parsed_cookie) { + debug("warn: CSRF token doesn't match."); + return Err(crash(lang, "error_csrf_token")); + } + } + else { + debug("warn: missing CSRF token."); + return Err(crash(lang, "error_csrf_cookie")); + } + let nc_username = gen_name(); println!("gen_name: {}", nc_username); let nc_password = gen_token(45); @@ -315,11 +365,21 @@ fn web_redir(location: &str) -> HttpResponse { } pub async fn index(req: HttpRequest) -> Result { + + let seed = AesGcmCsrfProtection::from_key(get_csrf_key()); + let (csrf_token, csrf_cookie) = seed.generate_token_pair(None, 300) + .expect("couldn't generate token/cookie pair"); + Ok(HttpResponse::Ok() .content_type("text/html") + .set_header( + "Set-Cookie", + format!("sncf_csrf_cookie={}; HttpOnly; SameSite=Strict", + base64::encode_config(&csrf_cookie.value(), base64::URL_SAFE_NO_PAD))) .body( TplIndex { lang: &get_lang(&req), + csrf_token: &base64::encode_config(&csrf_token.value(), base64::URL_SAFE_NO_PAD), } .render() .map_err(|e| { diff --git a/src/main.rs b/src/main.rs index 3bdcf25..82c1053 100644 --- a/src/main.rs +++ b/src/main.rs @@ -63,12 +63,7 @@ async fn main() -> std::io::Result<()> { println!("No database specified. Please enter a MySQL, PostgreSQL or SQLite connection string in config.toml."); } - if CONFIG.debug_mode { - println!("Opening database {}", CONFIG.database_path); - } - else { - println!("Opening database..."); - } + debug(&format!("Opening database {}", CONFIG.database_path)); let manager = ConnectionManager::::new(&CONFIG.database_path); let pool = r2d2::Pool::builder() @@ -99,7 +94,7 @@ async fn main() -> std::io::Result<()> { /*.wrap(middleware::Compress::default())*/ .service(Files::new("/assets/", "./templates/assets/").index_file("index.html")) .route("/", web::get().to(index)) - .route("/link", web::get().to(forward_register)) + .route("/link", web::post().to(forward_register)) .route("/admin/{token}", web::get().to(forward_login)) .default_service(web::route().to(forward)) .data(String::configure(|cfg| cfg.limit(PAYLOAD_LIMIT))) diff --git a/src/templates.rs b/src/templates.rs index 23cc4c6..34bf1e5 100644 --- a/src/templates.rs +++ b/src/templates.rs @@ -7,6 +7,7 @@ use crate::config::Config; #[template(path = "index.html")] pub struct TplIndex<'a> { pub lang: &'a str, + pub csrf_token: &'a str, } #[derive(Template)]