From e618c5bcd16502df52f524519d85c72a325e3d6d Mon Sep 17 00:00:00 2001 From: ppom Date: Wed, 28 Jan 2026 12:00:00 +0100 Subject: [PATCH] ipset: so much ~~waow~~ code --- plugins/reaction-plugin-ipset/src/action.rs | 258 ++++++++++++++++++++ plugins/reaction-plugin-ipset/src/ipset.rs | 194 +++++++++++---- plugins/reaction-plugin-ipset/src/main.rs | 148 ++++------- shell.nix | 1 + 4 files changed, 452 insertions(+), 149 deletions(-) create mode 100644 plugins/reaction-plugin-ipset/src/action.rs diff --git a/plugins/reaction-plugin-ipset/src/action.rs b/plugins/reaction-plugin-ipset/src/action.rs new file mode 100644 index 0000000..f4277f4 --- /dev/null +++ b/plugins/reaction-plugin-ipset/src/action.rs @@ -0,0 +1,258 @@ +use std::u32; + +use reaction_plugin::{Exec, shutdown::ShutdownToken, time::parse_duration}; +use remoc::rch::mpsc as remocMpsc; +use serde::{Deserialize, Serialize, de::Deserializer, de::Error}; + +use crate::ipset::{IpSet, Order, SetChain, SetOptions, Version}; +pub enum IpVersion { + V4, + V6, + V46, +} +impl<'de> Deserialize<'de> for IpVersion { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + match Option::::deserialize(deserializer)? { + None => Ok(IpVersion::V46), + Some(version) => match version { + 4 => Ok(IpVersion::V4), + 6 => Ok(IpVersion::V6), + 46 => Ok(IpVersion::V46), + _ => Err(D::Error::custom("version must be one of 4, 6 or 46")), + }, + } + } +} +impl Serialize for IpVersion { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_u8(match self { + IpVersion::V4 => 4, + IpVersion::V6 => 6, + IpVersion::V46 => 46, + }) + } +} + +// FIXME block configs that have different set options for the same name +// treat default values as none? + +#[derive(Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct ActionOptions { + /// The set that should be used by this action + set: String, + /// The pattern name of the IP. + /// Defaults to "ip" + #[serde(default = "serde_ip")] + pub pattern: String, + /// The IP type. + /// Defaults to `46`. + /// If `4`: creates an IPv4 set with this name + /// If `6`: creates an IPv6 set with this name + /// If `46`: creates an IPv4 set with its name suffixed by 'v4' AND an IPv6 set with its name suffixed by 'v6' + version: IpVersion, + /// Chains where the IP set should be inserted. + /// Defaults to `["INPUT", "FORWARD"]` + #[serde(default = "serde_chains")] + chains: Vec, + // Optional timeout, letting linux/netfilter handle set removal instead of reaction + // Note that `reaction show` and `reaction flush` won't work if set instead of an `after` action + #[serde(skip_serializing_if = "Option::is_none")] + timeout: Option, + // Target that iptables should use when the IP is encountered. + // Defaults to DROP, but can also be ACCEPT, RETURN or any user-defined chain + #[serde(default = "serde_drop")] + target: String, + // TODO add `add`//`remove` option +} + +fn serde_ip() -> String { + "ip".into() +} +fn serde_drop() -> String { + "DROP".into() +} +fn serde_chains() -> Vec { + vec!["INPUT".into(), "FORWARD".into()] +} + +pub struct Action { + ipset: IpSet, + rx: remocMpsc::Receiver, + shutdown: ShutdownToken, + ipv4_set: Option, + ipv6_set: Option, + // index of pattern ip in match vec + ip_index: usize, + chains: Vec, + timeout: Option, + target: String, +} + +impl Action { + pub fn new( + ipset: IpSet, + shutdown: ShutdownToken, + ip_index: usize, + rx: remocMpsc::Receiver, + options: ActionOptions, + ) -> Result { + Ok(Action { + ipset, + rx, + shutdown, + ip_index, + target: options.target, + chains: options.chains, + timeout: if let Some(timeout) = options.timeout { + let duration = parse_duration(&timeout) + .map_err(|err| format!("failed to parse timeout: {}", err))? + .as_secs(); + if duration > u32::MAX as u64 { + return Err(format!( + "timeout is limited to {} seconds (approx {} days)", + u32::MAX, + 49_000 + )); + } + Some(duration as u32) + } else { + None + }, + ipv4_set: match options.version { + IpVersion::V4 => Some(options.set.clone()), + IpVersion::V6 => None, + IpVersion::V46 => Some(format!("{}v4", options.set)), + }, + ipv6_set: match options.version { + IpVersion::V4 => None, + IpVersion::V6 => Some(options.set), + IpVersion::V46 => Some(format!("{}v6", options.set)), + }, + }) + } + + pub async fn init(&mut self) -> Result<(), String> { + for (set, version) in [ + (&self.ipv4_set, Version::IPv4), + (&self.ipv6_set, Version::IPv6), + ] { + if let Some(set) = set { + println!("INFO creating {version} set {set}"); + // create set + self.ipset + .order(Order::CreateSet(SetOptions { + name: set.clone(), + version, + timeout: self.timeout, + })) + .await?; + // insert set in chains + for chain in &self.chains { + println!("INFO inserting {version} set {set} in chain {chain}"); + self.ipset + .order(Order::InsertSet(SetChain { + set: set.clone(), + chain: chain.clone(), + target: self.target.clone(), + })) + .await?; + } + } + } + Ok(()) + } + + pub async fn destroy(&mut self) { + for (set, version) in [ + (&self.ipv4_set, Version::IPv4), + (&self.ipv6_set, Version::IPv6), + ] { + if let Some(set) = set { + for chain in &self.chains { + println!("INFO removing {version} set {set} from chain {chain}"); + if let Err(err) = self + .ipset + .order(Order::RemoveSet(SetChain { + set: set.clone(), + chain: chain.clone(), + target: self.target.clone(), + })) + .await + { + println!( + "ERROR while removing {version} set {set} from chain {chain}: {err}" + ); + } + } + println!("INFO destroying {version} set {set}"); + if let Err(err) = self.ipset.order(Order::DestroySet(set.clone())).await { + println!("ERROR while destroying {version} set {set}: {err}"); + } + } + } + } + + pub async fn serve(mut self) { + loop { + let event = tokio::select! { + exec = self.rx.recv() => Some(exec), + _ = self.shutdown.wait() => None, + }; + match event { + // shutdown asked + None => break, + // channel closed + Some(Ok(None)) => break, + // error from channel + Some(Err(err)) => { + println!("ERROR {err}"); + break; + } + // ok + Some(Ok(Some(exec))) => { + if let Err(err) = self.handle_exec(exec).await { + println!("ERROR {err}"); + break; + } + } + } + } + self.shutdown.ask_shutdown(); + self.destroy().await; + } + + async fn handle_exec(&mut self, mut exec: Exec) -> Result<(), String> { + // safeguard against Vec::remove's panic + if exec.match_.len() <= self.ip_index { + return Err(format!( + "match received from reaction is smaller than expected. looking for index {} but size is {}. this is a bug!", + self.ip_index, + exec.match_.len() + )); + } + let ip = exec.match_.remove(self.ip_index); + // select set + let set = match (&self.ipv4_set, &self.ipv6_set) { + (None, None) => return Err(format!("action is neither IPv4 nor IPv6, this is a bug!")), + (None, Some(set)) => set, + (Some(set), None) => set, + (Some(set4), Some(set6)) => { + if ip.contains(':') { + set6 + } else { + set4 + } + } + }; + // add ip to set + self.ipset.order(Order::Insert(set.clone(), ip)).await?; + Ok(()) + } +} diff --git a/plugins/reaction-plugin-ipset/src/ipset.rs b/plugins/reaction-plugin-ipset/src/ipset.rs index 6f841fb..263addb 100644 --- a/plugins/reaction-plugin-ipset/src/ipset.rs +++ b/plugins/reaction-plugin-ipset/src/ipset.rs @@ -1,27 +1,35 @@ -use std::{collections::BTreeMap, process::Command, thread}; +use std::{collections::BTreeMap, fmt::Display, net::Ipv4Addr, process::Command, thread}; use ipset::{ Session, - types::{Error, HashNet}, + types::{HashNet, NetDataType, Parse}, }; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, oneshot}; -#[derive(PartialEq, Eq)] +#[derive(PartialEq, Eq, Copy, Clone)] pub enum Version { IPv4, IPv6, } +impl Display for Version { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(match self { + Version::IPv4 => "IPv4", + Version::IPv6 => "IPv6", + }) + } +} pub struct SetOptions { - name: String, - version: Version, - timeout: Option, + pub name: String, + pub version: Version, + pub timeout: Option, } pub struct SetChain { - set: String, - chain: String, - action: String, + pub set: String, + pub chain: String, + pub target: String, } pub enum Order { @@ -29,14 +37,63 @@ pub enum Order { DestroySet(String), InsertSet(SetChain), RemoveSet(SetChain), + Insert(String, String), + Remove(String, String), } -pub fn ipset_thread() -> Result, String> { - let (tx, rx) = mpsc::channel(1); - thread::spawn(move || IPsetManager::default().serve(rx)); - Ok(tx) +#[derive(Clone)] +pub struct IpSet { + tx: mpsc::Sender, } +impl Default for IpSet { + fn default() -> Self { + let (tx, rx) = mpsc::channel(1); + thread::spawn(move || IPsetManager::default().serve(rx)); + Self { tx } + } +} + +impl IpSet { + pub async fn order(&mut self, order: Order) -> Result<(), IpSetError> { + let (tx, rx) = oneshot::channel(); + self.tx + .send((order, tx)) + .await + .map_err(|err| IpSetError::Thread(format!("ipset thread has quit: {err}")))?; + rx.await + .map_err(|err| IpSetError::Thread(format!("ipset thread didn't respond: {err}")))? + .map_err(IpSetError::IpSet) + } +} + +pub enum IpSetError { + Thread(String), + IpSet(String), +} +impl Display for IpSetError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + IpSetError::Thread(err) => err, + IpSetError::IpSet(err) => err, + } + ) + } +} +impl From for String { + fn from(value: IpSetError) -> Self { + match value { + IpSetError::Thread(err) => err, + IpSetError::IpSet(err) => err, + } + } +} + +pub type OrderType = (Order, oneshot::Sender>); + struct Set { session: Session, version: Version, @@ -48,12 +105,13 @@ struct IPsetManager { } impl IPsetManager { - fn serve(&mut self, mut rx: mpsc::Receiver) { + fn serve(mut self, mut rx: mpsc::Receiver) { loop { match rx.blocking_recv() { None => break, - Some(order) => { + Some((order, response)) => { let result = self.handle_order(order); + let _ = response.send(result); } } } @@ -88,35 +146,83 @@ impl IPsetManager { .map_err(|err| format!("Could not destroy set {set}: {err}"))?; } } - Order::InsertSet(SetChain { set, chain, action }) => { - let child = Command::new("iptables") - .args([ - "-w", - "-I", - &chain, - "-m", - "set", - "--match-set", - &set, - "src", - "-j", - &action, - ]) - .spawn() - .map_err(|err| { - format!("Could not insert ipset {set} in chain {chain}: {err}") - })?; - match child.wait() { - Ok(exit) => { - if !exit.success() { - return Err(format!("Could not insert ipset")); - } - } - Err(_) => todo!(), - }; - } - Order::RemoveSet(options) => {} + + Order::InsertSet(options) => insert_remove_set(options, true)?, + Order::RemoveSet(options) => insert_remove_set(options, false)?, + + Order::Insert(set, ip) => self.insert_remove_ip(set, ip, true)?, + Order::Remove(set, ip) => self.insert_remove_ip(set, ip, false)?, }; Ok(()) } + + fn insert_remove_ip(&mut self, set: String, ip: String, insert: bool) -> Result<(), String> { + let session = self + .sessions + .get_mut(&set) + .ok_or(format!("No set managed by us with this name: {set}"))?; + + let mut net_data = NetDataType::new(Ipv4Addr::LOCALHOST, 0); + net_data + .parse(&ip) + .map_err(|err| format!("`{ip}` is not recognized as an IP: {err}"))?; + + if insert { + session.session.add(net_data, &[]) + } else { + session.session.del(net_data) + } + .map_err(|err| format!("Could not add `{ip}` to set {set}: {err}"))?; + + Ok(()) + } + + fn insert_remove_set(&self, options: SetChain, insert: bool) -> Result<(), String> { + let SetChain { + set, + chain, + target: action, + } = options; + + let command = match self + .sessions + .get(&set) + .ok_or(format!("No set managed by us with this name: {set}"))? + .version + { + Version::IPv4 => "iptables", + Version::IPv6 => "ip6tables", + }; + + let mut child = Command::new(command) + .args([ + "-w", + if insert { "-I" } else { "-D" }, + &chain, + "-m", + "set", + "--match-set", + &set, + "src", + "-j", + &action, + ]) + .spawn() + .map_err(|err| format!("Could not insert ipset {set} in chain {chain}: {err}"))?; + + let exit = child + .wait() + .map_err(|err| format!("Could not insert ipset: {err}"))?; + + if exit.success() { + Ok(()) + } else { + Err(format!( + "Could not insert ipset: exit code {}", + exit.code() + .map(|c| c.to_string()) + .unwrap_or_else(|| "".to_string()) + )) + } + } } diff --git a/plugins/reaction-plugin-ipset/src/main.rs b/plugins/reaction-plugin-ipset/src/main.rs index 47b4e3d..74e494e 100644 --- a/plugins/reaction-plugin-ipset/src/main.rs +++ b/plugins/reaction-plugin-ipset/src/main.rs @@ -1,16 +1,20 @@ use std::collections::BTreeSet; use reaction_plugin::{ - ActionImpl, Exec, Hello, Manifest, PluginInfo, RemoteResult, StreamImpl, Value, + ActionImpl, Hello, Manifest, PluginInfo, RemoteError, RemoteResult, StreamImpl, Value, + shutdown::ShutdownController, }; -use remoc::{rch::mpsc, rtc}; -use serde::{Deserialize, Serialize, de::Deserializer, de::Error}; +use remoc::rtc; -use crate::ipset::ipset_thread; +use crate::{ + action::{Action, ActionOptions}, + ipset::IpSet, +}; #[cfg(test)] mod tests; +mod action; mod ipset; #[tokio::main] @@ -21,8 +25,9 @@ async fn main() { #[derive(Default)] struct Plugin { - // ipset: Arc>, + ipset: IpSet, actions: Vec, + shutdown: ShutdownController, } impl PluginInfo for Plugin { @@ -56,7 +61,7 @@ impl PluginInfo for Plugin { return Err("This plugin can't handle other action types than ipset".into()); } - let mut options: ActionOptions = serde_json::from_value(config.into()).map_err(|err| { + let options: ActionOptions = serde_json::from_value(config.into()).map_err(|err| { format!("invalid options for action {stream_name}.{filter_name}.{action_name}: {err}") })?; @@ -67,121 +72,54 @@ impl PluginInfo for Plugin { .next() .ok_or_else(|| { format!( - "No pattern with name {} in filter {stream_name}.{filter_name}", + "No pattern with name {} in filter {stream_name}.{filter_name}. Try setting the option `pattern` to your pattern name of type 'ip'", options.pattern ) })? .0; let (tx, rx) = remoc::rch::mpsc::channel(1); - self.actions.push(Action { - chains: options.chains, - ipv4_set: match options.version { - IpVersion::V4 => Some(options.set.clone()), - IpVersion::V6 => None, - IpVersion::V46 => Some(format!("{}v4", options.set)), - }, - ipv6_set: match options.version { - IpVersion::V4 => None, - IpVersion::V6 => Some(options.set), - IpVersion::V46 => Some(format!("{}v6", options.set)), - }, + self.actions.push(Action::new( + self.ipset.clone(), + self.shutdown.token(), ip_index, rx, - }); + options, + )?); Ok(ActionImpl { tx }) } async fn finish_setup(&mut self) -> RemoteResult<()> { - ipset_thread()?; + // Init all sets + let mut first_error = None; + for (i, action) in self.actions.iter_mut().enumerate() { + // Retain if error + if let Err(err) = action.init().await { + first_error = Some((i, RemoteError::Plugin(err))); + break; + } + } + // Destroy initialized sets if error + if let Some((i, err)) = first_error { + for action in self.actions.iter_mut().take(i + 1) { + let _ = action.destroy().await; + } + return Err(err); + } - todo!(); + // Launch all actions + while let Some(action) = self.actions.pop() { + tokio::spawn(async move { action.serve() }); + } + + self.actions = Default::default(); + Ok(()) } async fn close(self) -> RemoteResult<()> { - todo!(); - } -} - -enum IpVersion { - V4, - V6, - V46, -} -impl<'de> Deserialize<'de> for IpVersion { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - match Option::::deserialize(deserializer)? { - None => Ok(IpVersion::V46), - Some(version) => match version { - 4 => Ok(IpVersion::V4), - 6 => Ok(IpVersion::V6), - 46 => Ok(IpVersion::V46), - _ => Err(D::Error::custom("version must be one of 4, 6 or 46")), - }, - } - } -} -impl Serialize for IpVersion { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - serializer.serialize_u8(match self { - IpVersion::V4 => 4, - IpVersion::V6 => 6, - IpVersion::V46 => 46, - }) - } -} - -#[derive(Serialize, Deserialize)] -#[serde(deny_unknown_fields)] -struct ActionOptions { - /// The set that should be used by this action - set: String, - /// The pattern name of the IP. - /// Defaults to "ip" - #[serde(default = "serde_ip")] - pattern: String, - /// The IP type. - /// Defaults to `46`. - /// If `4`: creates an IPv4 set with this name - /// If `6`: creates an IPv6 set with this name - /// If `46`: creates an IPv4 set with its name suffixed by 'v4' AND an IPv6 set with its name suffixed by 'v6' - version: IpVersion, - /// Chains where the IP set should be inserted. - /// Defaults to `["INPUT", "FORWARD"]` - #[serde(default = "serde_chains")] - chains: Vec, - // Optional timeout, letting linux/netfilter handle set removal instead of reaction - // Note that `reaction show` and `reaction flush` won't work if set instead of an `after` action -} - -fn serde_ip() -> String { - "ip".into() -} -fn serde_chains() -> Vec { - vec!["INPUT".into(), "FORWARD".into()] -} - -struct Action { - ipv4_set: Option, - ipv6_set: Option, - // index of pattern ip in match vec - ip_index: usize, - chains: Vec, - rx: mpsc::Receiver, -} - -impl Action { - async fn serve(&mut self) { - // while let Ok(Some(exec)) = self.rx.recv().await { - // let line = self.send.line(exec.match_); - // self.to.tx.send((line, exec.time)).await.unwrap(); - // } + self.shutdown.ask_shutdown(); + self.shutdown.wait_all_task_shutdown().await; + Ok(()) } } diff --git a/shell.nix b/shell.nix index e761b07..27dac77 100644 --- a/shell.nix +++ b/shell.nix @@ -1,3 +1,4 @@ +# This shell.nix for NixOS users is only needed when building reaction-plugin-ipset with import {}; pkgs.mkShell { name = "libipset";