From 26cf3a96e74d856447eb96625c8ad64b5a6ee999 Mon Sep 17 00:00:00 2001 From: ppom Date: Fri, 20 Feb 2026 12:00:00 +0100 Subject: [PATCH] First draft of an nftables plugin Not compiling yet but I'm getting there. Must be careful on the unsafe, C-wrapping code! --- Cargo.lock | 105 ++++ Cargo.toml | 1 + plugins/reaction-plugin-nftables/Cargo.toml | 13 + .../reaction-plugin-nftables/src/action.rs | 493 ++++++++++++++++++ .../reaction-plugin-nftables/src/helpers.rs | 15 + plugins/reaction-plugin-nftables/src/main.rs | 169 ++++++ plugins/reaction-plugin-nftables/src/nft.rs | 69 +++ plugins/reaction-plugin-nftables/src/tests.rs | 253 +++++++++ shell.nix | 1 + 9 files changed, 1119 insertions(+) create mode 100644 plugins/reaction-plugin-nftables/Cargo.toml create mode 100644 plugins/reaction-plugin-nftables/src/action.rs create mode 100644 plugins/reaction-plugin-nftables/src/helpers.rs create mode 100644 plugins/reaction-plugin-nftables/src/main.rs create mode 100644 plugins/reaction-plugin-nftables/src/nft.rs create mode 100644 plugins/reaction-plugin-nftables/src/tests.rs diff --git a/Cargo.lock b/Cargo.lock index f5b71b2..f138507 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1980,6 +1980,15 @@ dependencies = [ "windows-link", ] +[[package]] +name = "libnftables1-sys" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b290d0d41f0ad578660aeed371bcae4cf85f129a6fe31350dbd2e097518cd7f" +dependencies = [ + "bindgen", +] + [[package]] name = "linux-raw-sys" version = "0.11.0" @@ -2248,6 +2257,22 @@ dependencies = [ "wmi", ] +[[package]] +name = "nftables" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c57e7343eed9e9330e084eef12651b15be3c8ed7825915a0ffa33736b852bed" +dependencies = [ + "schemars", + "serde", + "serde_json", + "serde_path_to_error", + "strum", + "strum_macros", + "thiserror 2.0.18", + "tokio", +] + [[package]] name = "nix" version = "0.29.0" @@ -2899,6 +2924,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "reaction-plugin-nftables" +version = "0.1.0" +dependencies = [ + "libnftables1-sys", + "nftables", + "reaction-plugin", + "remoc", + "serde", + "serde_json", + "tokio", +] + [[package]] name = "reaction-plugin-virtual" version = "1.0.0" @@ -2919,6 +2957,26 @@ dependencies = [ "bitflags", ] +[[package]] +name = "ref-cast" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f354300ae66f76f1c85c5f84693f0ce81d747e2c3f21a45fef496d89c960bf7d" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7186006dcb21920990093f30e3dea63b7d6e977bf1256be20c3563a5db070da" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", +] + [[package]] name = "regex" version = "1.12.2" @@ -3194,6 +3252,31 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "schemars" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2b42f36aa1cd011945615b92222f6bf73c599a102a300334cd7f8dbeec726cc" +dependencies = [ + "dyn-clone", + "ref-cast", + "schemars_derive", + "serde", + "serde_json", +] + +[[package]] +name = "schemars_derive" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d115b50f4aaeea07e79c1912f645c7513d81715d0420f8bc77a18c6260b307f" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals", + "syn 2.0.114", +] + [[package]] name = "scoped-tls" version = "1.0.1" @@ -3287,6 +3370,17 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "serde_derive_internals" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", +] + [[package]] name = "serde_json" version = "1.0.149" @@ -3300,6 +3394,17 @@ dependencies = [ "zmij", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" diff --git a/Cargo.toml b/Cargo.toml index dd537aa..39be1a6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -83,6 +83,7 @@ members = [ "plugins/reaction-plugin", "plugins/reaction-plugin-cluster", "plugins/reaction-plugin-ipset", + "plugins/reaction-plugin-nftables", "plugins/reaction-plugin-virtual" ] diff --git a/plugins/reaction-plugin-nftables/Cargo.toml b/plugins/reaction-plugin-nftables/Cargo.toml new file mode 100644 index 0000000..1de8e6b --- /dev/null +++ b/plugins/reaction-plugin-nftables/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "reaction-plugin-nftables" +version = "0.1.0" +edition = "2024" + +[dependencies] +tokio = { workspace = true, features = ["rt-multi-thread"] } +remoc.workspace = true +reaction-plugin.path = "../reaction-plugin" +serde.workspace = true +serde_json.workspace = true +nftables = { version = "0.6.3", features = ["tokio"] } +libnftables1-sys = { version = "0.1.1" } diff --git a/plugins/reaction-plugin-nftables/src/action.rs b/plugins/reaction-plugin-nftables/src/action.rs new file mode 100644 index 0000000..6649bd1 --- /dev/null +++ b/plugins/reaction-plugin-nftables/src/action.rs @@ -0,0 +1,493 @@ +use std::{ + borrow::Cow, + collections::HashSet, + fmt::{Debug, Display}, + u32, +}; + +use nftables::{ + batch::Batch, + expr::Expression, + helper::apply_ruleset_async, + schema::{Element, NfListObject, Rule, SetFlag, SetType, SetTypeValue}, + stmt::Statement, + types::{NfFamily, NfHook}, +}; +use reaction_plugin::{Exec, shutdown::ShutdownToken, time::parse_duration}; +use remoc::rch::mpsc as remocMpsc; +use serde::{Deserialize, Serialize}; + +use crate::{helpers::Version, nft::NftClient}; + +#[derive(Default, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)] +pub enum IpVersion { + #[default] + #[serde(rename = "ip")] + Ip, + #[serde(rename = "ipv4")] + Ipv4, + #[serde(rename = "ipv6")] + Ipv6, +} +impl Debug for IpVersion { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + IpVersion::Ipv4 => "ipv4", + IpVersion::Ipv6 => "ipv6", + IpVersion::Ip => "ip", + } + ) + } +} + +#[derive(Default, Debug, Serialize, Deserialize)] +pub enum AddDel { + #[default] + #[serde(alias = "add")] + Add, + #[serde(alias = "delete")] + Delete, +} + +/// User-facing action options +#[derive(Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct ActionOptions { + /// The set that should be used by this action + pub set: String, + /// The pattern name of the IP. + /// Defaults to "ip" + #[serde(default = "serde_ip")] + pub pattern: String, + #[serde(skip)] + ip_index: usize, + // Whether the action is to "add" or "del" the ip from the set + #[serde(default)] + action: AddDel, + + #[serde(flatten)] + pub set_options: SetOptions, +} + +fn serde_ip() -> String { + "ip".into() +} + +impl ActionOptions { + pub fn set_ip_index(&mut self, patterns: Vec) -> Result<(), ()> { + self.ip_index = patterns + .into_iter() + .enumerate() + .filter(|(_, name)| name == &self.pattern) + .next() + .ok_or(())? + .0; + Ok(()) + } +} + +/// Merged set options +#[derive(Default, Clone, Deserialize, Serialize, Debug, PartialEq, Eq)] +pub struct SetOptions { + /// The IP type. + /// Defaults to `46`. + /// If `ipv4`: creates an IPv4 set with this name + /// If `ipv6`: creates an IPv6 set with this name + /// If `ip`: creates an IPv4 set with its name suffixed by 'v4' AND an IPv6 set with its name suffixed by 'v6' + /// *Merged set-wise*. + #[serde(default)] + version: Option, + /// Chains where the IP set should be inserted. + /// Defaults to `["input", "forward"]` + /// *Merged set-wise*. + #[serde(default)] + hooks: Option>, + // Optional timeout, letting linux/netfilter handle set removal instead of reaction + // Note that `reaction show` and `reaction flush` won't work if set instead of an `after` action + // Same syntax as after and retryperiod in reaction. + /// *Merged set-wise*. + #[serde(skip_serializing_if = "Option::is_none")] + timeout: Option, + #[serde(skip)] + timeout_u32: Option, + // Target that iptables should use when the IP is encountered. + // Defaults to DROP, but can also be ACCEPT, RETURN or any user-defined chain + /// *Merged set-wise*. + #[serde(default)] + target: Option, +} + +impl SetOptions { + pub fn merge(&mut self, options: &SetOptions) -> Result<(), String> { + // merge two Option and fail if there is conflict + fn inner_merge( + a: &mut Option, + b: &Option, + name: &str, + ) -> Result<(), String> { + match (&a, &b) { + (Some(aa), Some(bb)) => { + if aa != bb { + return Err(format!( + "Conflicting options for {name}: `{aa:?}` and `{bb:?}`" + )); + } + } + (None, Some(_)) => { + *a = b.clone(); + } + _ => (), + }; + Ok(()) + } + + inner_merge(&mut self.version, &options.version, "version")?; + inner_merge(&mut self.timeout, &options.timeout, "timeout")?; + inner_merge(&mut self.hooks, &options.hooks, "chains")?; + inner_merge(&mut self.target, &options.target, "target")?; + + if let Some(timeout) = &self.timeout { + let duration = parse_duration(timeout) + .map_err(|err| format!("failed to parse timeout: {}", err))? + .as_secs(); + if duration > u32::MAX as u64 { + return Err(format!( + "timeout is limited to {} seconds (approx {} days)", + u32::MAX, + 49_000 + )); + } + self.timeout_u32 = Some(duration as u32); + } + + Ok(()) + } +} + +#[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Ord, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum RHook { + Ingress, + Prerouting, + Forward, + Input, + Output, + Postrouting, + Egress, +} + +impl RHook { + pub fn as_str(&self) -> &'static str { + match self { + RHook::Ingress => "ingress", + RHook::Prerouting => "prerouting", + RHook::Forward => "forward", + RHook::Input => "input", + RHook::Output => "output", + RHook::Postrouting => "postrouting", + RHook::Egress => "egress", + } + } +} + +impl Display for RHook { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.as_str()) + } +} + +impl From<&RHook> for NfHook { + fn from(value: &RHook) -> Self { + match value { + RHook::Ingress => Self::Ingress, + RHook::Prerouting => Self::Prerouting, + RHook::Forward => Self::Forward, + RHook::Input => Self::Input, + RHook::Output => Self::Output, + RHook::Postrouting => Self::Postrouting, + RHook::Egress => Self::Egress, + } + } +} + +#[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Ord, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum RStatement { + Accept, + Drop, + Continue, + Return, +} + +pub struct Set { + pub sets: SetNames, + pub hooks: Vec, + pub timeout: Option, + pub target: RStatement, +} + +impl Set { + pub fn from(name: String, options: SetOptions) -> Self { + Self { + sets: SetNames::new(name, options.version), + timeout: options.timeout_u32, + target: options.target.unwrap_or(RStatement::Drop), + hooks: options.hooks.unwrap_or(vec![RHook::Input, RHook::Forward]), + } + } + + pub fn init<'a>(&self, batch: &mut Batch<'a>) -> Result<(), String> { + for (set, version) in [ + (&self.sets.ipv4, Version::IPv4), + (&self.sets.ipv6, Version::IPv6), + ] { + if let Some(set) = set { + let family = NfFamily::INet; + let table = Cow::from("reaction"); + let name = Cow::from(set.as_str()); + + // create set + batch.add(NfListObject::<'a>::Set(Box::new(nftables::schema::Set { + family, + table: table.to_owned(), + name, + // TODO Try a set which is both ipv4 and ipv6? + set_type: SetTypeValue::Single(match version { + Version::IPv4 => SetType::Ipv4Addr, + Version::IPv6 => SetType::Ipv6Addr, + }), + flags: Some({ + let mut flags = HashSet::from([SetFlag::Interval]); + if self.timeout.is_some() { + flags.insert(SetFlag::Timeout); + } + flags + }), + timeout: self.timeout.clone(), + ..Default::default() + }))); + // insert set in chains + let expr = vec![match self.target { + RStatement::Accept => Statement::Accept(None), + RStatement::Drop => Statement::Drop(None), + RStatement::Continue => Statement::Continue(None), + RStatement::Return => Statement::Return(None), + }]; + for hook in &self.hooks { + batch.add(NfListObject::Rule(Rule { + family, + table: table.to_owned(), + chain: Cow::from(hook.to_string()), + expr: Cow::Owned(expr.clone()), + ..Default::default() + })); + } + } + } + Ok(()) + } +} + +pub struct SetNames { + pub ipv4: Option, + pub ipv6: Option, +} + +impl SetNames { + pub fn new(name: String, version: Option) -> Self { + Self { + ipv4: match version { + Some(IpVersion::Ipv4) => Some(name.clone()), + Some(IpVersion::Ipv6) => None, + None | Some(IpVersion::Ip) => Some(format!("{}v4", name)), + }, + ipv6: match version { + Some(IpVersion::Ipv4) => None, + Some(IpVersion::Ipv6) => Some(name), + None | Some(IpVersion::Ip) => Some(format!("{}v6", name)), + }, + } + } +} + +pub struct Action { + nft: NftClient, + rx: remocMpsc::Receiver, + shutdown: ShutdownToken, + sets: SetNames, + // index of pattern ip in match vec + ip_index: usize, + action: AddDel, +} + +impl Action { + pub fn new( + nft: NftClient, + shutdown: ShutdownToken, + rx: remocMpsc::Receiver, + options: ActionOptions, + ) -> Result { + Ok(Action { + nft, + rx, + shutdown, + sets: SetNames::new(options.set, options.set_options.version), + ip_index: options.ip_index, + action: options.action, + }) + } + + pub async fn serve(mut self) { + loop { + let event = tokio::select! { + exec = self.rx.recv() => Some(exec), + _ = self.shutdown.wait() => None, + }; + match event { + // shutdown asked + None => break, + // channel closed + Some(Ok(None)) => break, + // error from channel + Some(Err(err)) => { + eprintln!("ERROR {err}"); + break; + } + // ok + Some(Ok(Some(exec))) => { + if let Err(err) = self.handle_exec(exec).await { + eprintln!("ERROR {err}"); + break; + } + } + } + } + // eprintln!("DEBUG Asking for shutdown"); + // self.shutdown.ask_shutdown(); + } + + async fn handle_exec(&mut self, mut exec: Exec) -> Result<(), String> { + // safeguard against Vec::remove's panic + if exec.match_.len() <= self.ip_index { + return Err(format!( + "match received from reaction is smaller than expected. looking for index {} but size is {}. this is a bug!", + self.ip_index, + exec.match_.len() + )); + } + let ip = exec.match_.remove(self.ip_index); + // select set + let set = match (&self.sets.ipv4, &self.sets.ipv6) { + (None, None) => return Err(format!("action is neither IPv4 nor IPv6, this is a bug!")), + (None, Some(set)) => set, + (Some(set), None) => set, + (Some(set4), Some(set6)) => { + if ip.contains(':') { + set6 + } else { + set4 + } + } + }; + // add/remove ip to set + let element = NfListObject::Element(Element { + family: NfFamily::INet, + table: Cow::from("reaction"), + name: Cow::from(set), + elem: Cow::from(vec![Expression::String(Cow::from(ip.clone()))]), + }); + let mut batch = Batch::new(); + match self.action { + AddDel::Add => batch.add(element), + AddDel::Delete => batch.delete(element), + }; + match self.nft.send(batch).await { + Ok(ok) => { + eprintln!("DEBUG action ok {:?} {ip}: {ok}", self.action); + Ok(()) + } + Err(err) => Err(format!("action ko {:?} {ip}: {err}", self.action)), + } + } +} + +#[cfg(test)] +mod tests { + use crate::action::{IpVersion, RHook, RStatement, SetOptions}; + + #[tokio::test] + async fn set_options_merge() { + let s1 = SetOptions { + version: None, + hooks: None, + timeout: None, + timeout_u32: None, + target: None, + }; + let s2 = SetOptions { + version: Some(IpVersion::Ipv4), + hooks: Some(vec![RHook::Input]), + timeout: Some("3h".into()), + timeout_u32: Some(3 * 3600), + target: Some(RStatement::Drop), + }; + assert_ne!(s1, s2); + assert_eq!(s1, SetOptions::default()); + + { + // s2 can be merged in s1 + let mut s1 = s1.clone(); + assert!(s1.merge(&s2).is_ok()); + assert_eq!(s1, s2); + } + + { + // s1 can be merged in s2 + let mut s2 = s2.clone(); + assert!(s2.merge(&s1).is_ok()); + } + + { + // s1 can be merged in itself + let mut s3 = s1.clone(); + assert!(s3.merge(&s1).is_ok()); + assert_eq!(s1, s3); + } + + { + // s2 can be merged in itself + let mut s3 = s2.clone(); + assert!(s3.merge(&s2).is_ok()); + assert_eq!(s2, s3); + } + + for s3 in [ + SetOptions { + version: Some(IpVersion::Ipv6), + ..Default::default() + }, + SetOptions { + hooks: Some(vec![RHook::Output]), + ..Default::default() + }, + SetOptions { + timeout: Some("30min".into()), + ..Default::default() + }, + SetOptions { + target: Some(RStatement::Continue), + ..Default::default() + }, + ] { + // none with some is ok + assert!(s3.clone().merge(&s1).is_ok(), "s3: {s3:?}"); + assert!(s1.clone().merge(&s3).is_ok(), "s3: {s3:?}"); + // different some is ko + assert!(s3.clone().merge(&s2).is_err(), "s3: {s3:?}"); + assert!(s2.clone().merge(&s3).is_err(), "s3: {s3:?}"); + } + } +} diff --git a/plugins/reaction-plugin-nftables/src/helpers.rs b/plugins/reaction-plugin-nftables/src/helpers.rs new file mode 100644 index 0000000..b8b97b2 --- /dev/null +++ b/plugins/reaction-plugin-nftables/src/helpers.rs @@ -0,0 +1,15 @@ +use std::fmt::Display; + +#[derive(PartialEq, Eq, PartialOrd, Ord, Copy, Clone)] +pub enum Version { + IPv4, + IPv6, +} +impl Display for Version { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(match self { + Version::IPv4 => "IPv4", + Version::IPv6 => "IPv6", + }) + } +} diff --git a/plugins/reaction-plugin-nftables/src/main.rs b/plugins/reaction-plugin-nftables/src/main.rs new file mode 100644 index 0000000..ac93eba --- /dev/null +++ b/plugins/reaction-plugin-nftables/src/main.rs @@ -0,0 +1,169 @@ +use std::{ + borrow::Cow, + collections::{BTreeMap, BTreeSet}, +}; + +use libnftables1_sys::Nftables; +use nftables::{ + batch::Batch, + schema::{Chain, NfListObject, Table}, + types::{NfChainType, NfFamily}, +}; +use reaction_plugin::{ + ActionConfig, ActionImpl, Hello, Manifest, PluginInfo, RemoteResult, StreamConfig, StreamImpl, + shutdown::ShutdownController, +}; +use remoc::rtc; + +use crate::action::{Action, ActionOptions, Set, SetOptions}; + +#[cfg(test)] +mod tests; + +mod action; +pub mod helpers; +mod nft; + +#[tokio::main] +async fn main() { + let plugin = Plugin::default(); + reaction_plugin::main_loop(plugin).await; +} + +#[derive(Default)] +struct Plugin { + sets: Vec, + actions: Vec, + shutdown: ShutdownController, +} + +impl PluginInfo for Plugin { + async fn manifest(&mut self) -> Result { + Ok(Manifest { + hello: Hello::new(), + streams: BTreeSet::default(), + actions: BTreeSet::from(["ipset".into()]), + }) + } + + async fn load_config( + &mut self, + streams: Vec, + actions: Vec, + ) -> RemoteResult<(Vec, Vec)> { + if !streams.is_empty() { + return Err("This plugin can't handle any stream type".into()); + } + + let mut ret_actions = Vec::with_capacity(actions.len()); + let mut set_options: BTreeMap = BTreeMap::new(); + + for ActionConfig { + stream_name, + filter_name, + action_name, + action_type, + config, + patterns, + } in actions + { + if &action_type != "nftables" { + return Err("This plugin can't handle other action types than nftables".into()); + } + + let mut options: ActionOptions = serde_json::from_value(config.into()).map_err(|err| { + format!("invalid options for action {stream_name}.{filter_name}.{action_name}: {err}") + })?; + + options.set_ip_index(patterns).map_err(|_| + format!( + "No pattern with name {} in filter {stream_name}.{filter_name}. Try setting the option `pattern` to your pattern name of type 'ip'", + &options.pattern + ) + )?; + + // Merge option + set_options + .entry(options.set.clone()) + .or_default() + .merge(&options.set_options) + .map_err(|err| format!("set {}: {err}", options.set))?; + + let (tx, rx) = remoc::rch::mpsc::channel(1); + self.actions + .push(Action::new(self.shutdown.token(), rx, options)?); + + ret_actions.push(ActionImpl { tx }); + } + + // Init all sets + while let Some((name, options)) = set_options.pop_first() { + self.sets.push(Set::from(name, options)); + } + + Ok((vec![], ret_actions)) + } + + async fn start(&mut self) -> RemoteResult<()> { + self.shutdown.delegate().handle_quit_signals()?; + + let mut batch = Batch::new(); + batch.add(reaction_table()); + + // Create a chain for each registered netfilter hook + for hook in self + .sets + .iter() + .flat_map(|set| &set.hooks) + .collect::>() + { + batch.add(NfListObject::Chain(Chain { + family: NfFamily::INet, + table: Cow::Borrowed("reaction"), + name: Cow::from(hook.as_str()), + _type: Some(NfChainType::Filter), + hook: Some(hook.into()), + prio: Some(0), + ..Default::default() + })); + } + + for set in &self.sets { + set.init(&mut batch)?; + } + + // TODO apply batch + Nftables::new(); + + // Launch a task that will destroy the table on shutdown + { + let token = self.shutdown.token(); + tokio::spawn(async move { + token.wait().await; + Batch::new().delete(reaction_table()); + }); + } + + // Launch all actions + while let Some(action) = self.actions.pop() { + tokio::spawn(async move { action.serve().await }); + } + self.actions = Default::default(); + + Ok(()) + } + + async fn close(self) -> RemoteResult<()> { + self.shutdown.ask_shutdown(); + self.shutdown.wait_all_task_shutdown().await; + Ok(()) + } +} + +fn reaction_table() -> NfListObject<'static> { + NfListObject::Table(Table { + family: NfFamily::INet, + name: Cow::Borrowed("reaction"), + handle: None, + }) +} diff --git a/plugins/reaction-plugin-nftables/src/nft.rs b/plugins/reaction-plugin-nftables/src/nft.rs new file mode 100644 index 0000000..45c4f5c --- /dev/null +++ b/plugins/reaction-plugin-nftables/src/nft.rs @@ -0,0 +1,69 @@ +use std::{ + ffi::{CStr, CString}, + thread, +}; + +use libnftables1_sys::Nftables; +use nftables::batch::Batch; +use tokio::sync::{mpsc, oneshot}; + +pub fn nftables_thread() -> NftClient { + let (tx, mut rx) = mpsc::channel(10); + + thread::spawn(move || { + let mut conn = Nftables::new(); + + while let Some(NftCommand { json, ret }) = rx.blocking_recv() { + let (rc, output, error) = conn.run_cmd(json.as_ptr()); + let res = match rc { + 0 => to_rust_string(output) + .ok_or_else(|| "unknown ok (rc = 0 but no output buffer)".into()), + _ => to_rust_string(error) + .map(|err| format!("error (rc = {rc}: {err})")) + .ok_or_else(|| format!("unknown error (rc = {rc} but no error buffer)")), + }; + ret.send(res); + } + }); + + NftClient { tx } +} + +fn to_rust_string(c_ptr: *const i8) -> Option { + if c_ptr.is_null() { + None + } else { + Some( + unsafe { CStr::from_ptr(c_ptr) } + .to_string_lossy() + .into_owned(), + ) + } +} + +pub struct NftClient { + tx: mpsc::Sender, +} + +impl NftClient { + pub async fn send(&self, batch: Batch<'_>) -> Result { + // convert JSON to CString + let mut json = serde_json::to_vec(&batch.to_nftables()) + .map_err(|err| format!("couldn't build json to send to nftables: {err}"))?; + json.push('\0' as u8); + let json = CString::from_vec_with_nul(json) + .map_err(|err| format!("invalid json with null char: {err}"))?; + // Send command + let (tx, rx) = oneshot::channel(); + let command = NftCommand { json, ret: tx }; + self.tx.send(command).await; + // Wait for result + rx.await + .map_err(|err| format!("nftables thread has quit: {err}"))? + } +} + +struct NftCommand { + json: CString, + ret: oneshot::Sender>, +} diff --git a/plugins/reaction-plugin-nftables/src/tests.rs b/plugins/reaction-plugin-nftables/src/tests.rs new file mode 100644 index 0000000..09ce6cf --- /dev/null +++ b/plugins/reaction-plugin-nftables/src/tests.rs @@ -0,0 +1,253 @@ +use reaction_plugin::{ActionConfig, PluginInfo, StreamConfig, Value}; +use serde_json::json; + +use crate::Plugin; + +#[tokio::test] +async fn conf_stream() { + // No stream is supported by ipset + assert!( + Plugin::default() + .load_config( + vec![StreamConfig { + stream_name: "stream".into(), + stream_type: "ipset".into(), + config: Value::Null + }], + vec![] + ) + .await + .is_err() + ); + + // Nothing is ok + assert!(Plugin::default().load_config(vec![], vec![]).await.is_ok()); +} + +#[tokio::test] +async fn conf_action_standalone() { + let p = vec!["name".into(), "ip".into(), "ip2".into()]; + let p_noip = vec!["name".into(), "ip2".into()]; + + for (is_ok, conf, patterns) in [ + // minimal set + (true, json!({ "set": "test" }), &p), + // missing set key + (false, json!({}), &p), + (false, json!({ "version": "ipv4" }), &p), + // unknown key + (false, json!({ "set": "test", "unknown": "yes" }), &p), + (false, json!({ "set": "test", "ip_index": 1 }), &p), + (false, json!({ "set": "test", "timeout_u32": 1 }), &p), + // pattern // + (true, json!({ "set": "test" }), &p), + (true, json!({ "set": "test", "pattern": "ip" }), &p), + (true, json!({ "set": "test", "pattern": "ip2" }), &p), + (true, json!({ "set": "test", "pattern": "ip2" }), &p_noip), + // unknown pattern "ip" + (false, json!({ "set": "test" }), &p_noip), + (false, json!({ "set": "test", "pattern": "ip" }), &p_noip), + // unknown pattern + (false, json!({ "set": "test", "pattern": "unknown" }), &p), + (false, json!({ "set": "test", "pattern": "uwu" }), &p_noip), + // bad type + (false, json!({ "set": "test", "pattern": 0 }), &p_noip), + (false, json!({ "set": "test", "pattern": true }), &p_noip), + // action // + (true, json!({ "set": "test", "action": "add" }), &p), + (true, json!({ "set": "test", "action": "del" }), &p), + // unknown action + (false, json!({ "set": "test", "action": "create" }), &p), + (false, json!({ "set": "test", "action": "insert" }), &p), + (false, json!({ "set": "test", "action": "delete" }), &p), + (false, json!({ "set": "test", "action": "destroy" }), &p), + // bad type + (false, json!({ "set": "test", "action": true }), &p), + (false, json!({ "set": "test", "action": 1 }), &p), + // ip version // + // ok + (true, json!({ "set": "test", "version": "ipv4" }), &p), + (true, json!({ "set": "test", "version": "ipv6" }), &p), + (true, json!({ "set": "test", "version": "ip" }), &p), + // unknown version + (false, json!({ "set": "test", "version": 4 }), &p), + (false, json!({ "set": "test", "version": 6 }), &p), + (false, json!({ "set": "test", "version": 46 }), &p), + (false, json!({ "set": "test", "version": "5" }), &p), + (false, json!({ "set": "test", "version": "ipv5" }), &p), + (false, json!({ "set": "test", "version": "4" }), &p), + (false, json!({ "set": "test", "version": "6" }), &p), + (false, json!({ "set": "test", "version": "46" }), &p), + // bad type + (false, json!({ "set": "test", "version": true }), &p), + // chains // + // everything is fine really + (true, json!({ "set": "test", "chains": [] }), &p), + (true, json!({ "set": "test", "chains": ["INPUT"] }), &p), + (true, json!({ "set": "test", "chains": ["FORWARD"] }), &p), + ( + true, + json!({ "set": "test", "chains": ["custom_chain"] }), + &p, + ), + ( + true, + json!({ "set": "test", "chains": ["INPUT", "FORWARD"] }), + &p, + ), + ( + true, + json!({ + "set": "test", + "chains": ["INPUT", "FORWARD", "my_iptables_chain"] + }), + &p, + ), + // timeout // + (true, json!({ "set": "test", "timeout": "1m" }), &p), + (true, json!({ "set": "test", "timeout": "3 days" }), &p), + // bad + (false, json!({ "set": "test", "timeout": "3 dayz"}), &p), + (false, json!({ "set": "test", "timeout": 12 }), &p), + // target // + // anything is fine too + (true, json!({ "set": "test", "target": "DROP" }), &p), + (true, json!({ "set": "test", "target": "ACCEPT" }), &p), + (true, json!({ "set": "test", "target": "RETURN" }), &p), + (true, json!({ "set": "test", "target": "custom_chain" }), &p), + // bad + (false, json!({ "set": "test", "target": 11 }), &p), + (false, json!({ "set": "test", "target": ["DROP"] }), &p), + ] { + let res = Plugin::default() + .load_config( + vec![], + vec![ActionConfig { + stream_name: "stream".into(), + filter_name: "filter".into(), + action_name: "action".into(), + action_type: "ipset".into(), + config: conf.clone().into(), + patterns: patterns.clone(), + }], + ) + .await; + + assert!( + res.is_ok() == is_ok, + "conf: {:?}, must be ok: {is_ok}, result: {:?}", + conf, + // empty Result::Ok because ActionImpl is not Debug + res.map(|_| ()) + ); + } +} + +// TODO +#[tokio::test] +async fn conf_action_merge() { + let mut plugin = Plugin::default(); + + let set1 = ActionConfig { + stream_name: "stream".into(), + filter_name: "filter".into(), + action_name: "action1".into(), + action_type: "ipset".into(), + config: json!({ + "set": "test", + "target": "DROP", + "chains": ["INPUT"], + "action": "add", + }) + .into(), + patterns: vec!["ip".into()], + }; + + let set2 = ActionConfig { + stream_name: "stream".into(), + filter_name: "filter".into(), + action_name: "action2".into(), + action_type: "ipset".into(), + config: json!({ + "set": "test", + "target": "DROP", + "version": "ip", + "action": "add", + }) + .into(), + patterns: vec!["ip".into()], + }; + + let set3 = ActionConfig { + stream_name: "stream".into(), + filter_name: "filter".into(), + action_name: "action2".into(), + action_type: "ipset".into(), + config: json!({ + "set": "test", + "action": "del", + }) + .into(), + patterns: vec!["ip".into()], + }; + + let res = plugin + .load_config( + vec![], + vec![ + // First set + set1.clone(), + // Same set, adding options, no conflict + set2.clone(), + // Same set, no new options, no conflict + set3.clone(), + // Unrelated set, so no conflict + ActionConfig { + stream_name: "stream".into(), + filter_name: "filter".into(), + action_name: "action3".into(), + action_type: "ipset".into(), + config: json!({ + "set": "test2", + "target": "target1", + "version": "ipv6", + }) + .into(), + patterns: vec!["ip".into()], + }, + ], + ) + .await; + + assert!(res.is_ok(), "res: {:?}", res.map(|_| ())); + + // Another set with conflict is not ok + let res = plugin + .load_config( + vec![], + vec![ + // First set + set1, + // Same set, adding options, no conflict + set2, + // Same set, no new options, no conflict + set3, + // Another set with conflict + ActionConfig { + stream_name: "stream".into(), + filter_name: "filter".into(), + action_name: "action3".into(), + action_type: "ipset".into(), + config: json!({ + "set": "test", + "target": "target2", + "action": "del", + }) + .into(), + patterns: vec!["ip".into()], + }, + ], + ) + .await; + assert!(res.is_err(), "res: {:?}", res.map(|_| ())); +} diff --git a/shell.nix b/shell.nix index 27dac77..ecb4318 100644 --- a/shell.nix +++ b/shell.nix @@ -4,6 +4,7 @@ pkgs.mkShell { name = "libipset"; buildInputs = [ ipset + nftables clang ]; src = null;