cluster: created ConnectionManager

Reorganized code.
Moved some functionality from EndpointManager to ConnectionManager.
Still a lot to do there, but little left in the rest of the code.
This commit is contained in:
ppom 2025-11-15 12:00:00 +01:00
commit 0635bae544
No known key found for this signature in database
7 changed files with 378 additions and 237 deletions

1
Cargo.lock generated
View file

@ -2861,6 +2861,7 @@ version = "0.1.0"
dependencies = [
"chrono",
"data-encoding",
"futures",
"iroh",
"rand 0.9.2",
"reaction-plugin",

View file

@ -50,7 +50,7 @@ jrsonnet-evaluator = "0.4.2"
# Error macro
thiserror = "1.0.63"
# Async runtime & helpers
futures = "0.3.30"
futures = { workspace = true }
tokio = { workspace = true, features = ["full", "tracing"] }
tokio-util = { workspace = true, features = ["codec"] }
# Async logging
@ -79,6 +79,7 @@ members = ["plugins/reaction-plugin", "plugins/reaction-plugin-cluster", "plugin
[workspace.dependencies]
chrono = { version = "0.4.38", features = ["std", "clock", "serde"] }
futures = "0.3.30"
remoc = { version = "0.18.3" }
serde = { version = "1.0.203", features = ["derive"] }
serde_json = "1.0.117"

View file

@ -7,6 +7,7 @@ edition = "2024"
reaction-plugin.workspace = true
chrono.workspace = true
futures.workspace = true
remoc.workspace = true
serde.workspace = true
serde_json.workspace = true

View file

@ -1,19 +1,21 @@
use std::{
collections::VecDeque,
collections::BTreeMap,
net::{SocketAddrV4, SocketAddrV6},
sync::Arc,
};
use chrono::{DateTime, Local, Utc};
use iroh::{Endpoint, EndpointAddr, endpoint::Connection};
use futures::future::join_all;
use iroh::{Endpoint, PublicKey, endpoint::Connection};
use reaction_plugin::{Line, shutdown::ShutdownToken};
use tokio::sync::mpsc;
use remoc::rch::mpsc as remocMpsc;
use tokio::sync::mpsc as tokioMpsc;
use crate::{ActionInit, StreamInit, endpoint::EndpointManager};
use crate::{ActionInit, StreamInit, connection::ConnectionManager, endpoint::EndpointManager};
pub const ALPN: [&[u8]; 1] = ["reaction_cluster_1".as_bytes()];
type UtcLine = Arc<(String, DateTime<Utc>)>;
pub type UtcLine = Arc<(String, DateTime<Utc>)>;
pub async fn bind(stream: &StreamInit) -> Result<Endpoint, String> {
let mut builder = Endpoint::builder()
@ -39,36 +41,81 @@ pub async fn bind(stream: &StreamInit) -> Result<Endpoint, String> {
pub fn cluster_tasks(
endpoint: Endpoint,
stream: StreamInit,
actions: Vec<ActionInit>,
mut stream: StreamInit,
mut actions: Vec<ActionInit>,
shutdown: ShutdownToken,
) {
let messages_from_actions = spawn_actions(actions, stream.tx.clone());
let (message_action2connection_txs, mut message_action2connection_rxs): (
Vec<tokioMpsc::Sender<UtcLine>>,
Vec<tokioMpsc::Receiver<UtcLine>>,
) = (0..stream.nodes.len())
.map(|_| tokioMpsc::channel(1))
.unzip();
let (endpoint_addr_tx, connection_rx) =
EndpointManager::new(endpoint, stream.name.clone(), stream.nodes.len());
// TODO create ConnectionManagers and connect them to EndpointManager
}
fn spawn_actions(
mut actions: Vec<ActionInit>,
own_cluster_tx: remoc::rch::mpsc::Sender<Line>,
) -> mpsc::Receiver<UtcLine> {
let (nodes_tx, nodes_rx) = mpsc::channel(1);
// Spawn action tasks
while let Some(mut action) = actions.pop() {
let nodes_tx = nodes_tx.clone();
let own_cluster_tx = own_cluster_tx.clone();
tokio::spawn(async move { action.serve(nodes_tx, own_cluster_tx).await });
let message_action2connection_txs = message_action2connection_txs.clone();
let own_cluster_tx = stream.tx.clone();
tokio::spawn(async move {
action
.serve(message_action2connection_txs, own_cluster_tx)
.await
});
}
let endpoint = Arc::new(endpoint);
let (connection_endpoint2connection_txs, mut connection_endpoint2connection_rxs): (
BTreeMap<PublicKey, tokioMpsc::Sender<Connection>>,
Vec<(PublicKey, tokioMpsc::Receiver<Connection>)>,
) = stream
.nodes
.keys()
.map(|pk| {
let (tx, rx) = tokioMpsc::channel(1);
((pk.clone(), tx), (pk.clone(), rx))
})
.unzip();
// Spawn connection accepter
EndpointManager::new(
endpoint.clone(),
stream.name.clone(),
connection_endpoint2connection_txs,
shutdown.clone(),
);
// Spawn connection managers
while let Some((pk, connection_endpoint2connection_rx)) =
connection_endpoint2connection_rxs.pop()
{
let cluster_name = stream.name.clone();
let endpoint_addr = stream.nodes.remove(&pk).unwrap();
let endpoint = endpoint.clone();
let message_action2connection_rx = message_action2connection_rxs.pop().unwrap();
let stream_tx = stream.tx.clone();
tokio::spawn(async move {
ConnectionManager::new(
cluster_name,
endpoint_addr,
endpoint,
connection_endpoint2connection_rx,
stream.message_timeout,
message_action2connection_rx,
stream_tx,
)
.task()
.await
});
}
nodes_rx
}
impl ActionInit {
// Receive messages from its reaction action and dispatch them to all connections and to the reaction stream
async fn serve(
&mut self,
nodes_tx: mpsc::Sender<UtcLine>,
own_stream_tx: remoc::rch::mpsc::Sender<Line>,
nodes_tx: Vec<tokioMpsc::Sender<UtcLine>>,
own_stream_tx: remocMpsc::Sender<Line>,
) {
while let Ok(Some(m)) = self.rx.recv().await {
let line = if m.match_.is_empty() {
@ -88,9 +135,11 @@ impl ActionInit {
}
let line = Arc::new((line, now.to_utc()));
if let Err(err) = nodes_tx.send(line).await {
eprintln!("ERROR while queueing message to be sent to cluster nodes: {err}");
};
for result in join_all(nodes_tx.iter().map(|tx| tx.send(line.clone()))).await {
if let Err(err) = result {
eprintln!("ERROR while queueing message to be sent to cluster nodes: {err}");
};
}
if let Err(err) = m.result.send(Ok(())) {
eprintln!("ERROR while responding to reaction action: {err}");
@ -99,30 +148,18 @@ impl ActionInit {
}
}
pub struct ConnectionManager {
endpoint: EndpointAddr,
// Ask the EndpointManager to connect
ask_connection: mpsc::Sender<EndpointAddr>,
// Our own connection (when we have one)
connection: Option<Connection>,
// The EndpointManager sending us a connection (whether we asked for it or not)
connection_rx: mpsc::Receiver<Connection>,
// Our queue of messages to send
queue: VecDeque<UtcLine>,
// Messages we send from remote nodes to our own stream
own_cluster_tx: remoc::rch::mpsc::Sender<String>,
}
#[cfg(test)]
mod tests {
use chrono::{DateTime, Local};
// As long as nodes communicate with UTC datetimes, them having different local timezones is not an issue!
#[test]
fn different_local_tz_is_ok() {
// The same instant expressed with two different fixed offsets (+01:00 / +02:00).
let date1: DateTime<Local> =
serde_json::from_str("2025-11-02T17:47:21.716229569+01:00").unwrap();
let date2: DateTime<Local> =
serde_json::from_str("2025-11-02T18:47:21.716229569+02:00").unwrap();
// Also deserialize them as a list, to check per-element offsets survive.
let dates: Vec<DateTime<Local>> = serde_json::from_str(
"[\"2025-11-02T17:47:21.716229569+01:00\",\"2025-11-02T18:47:21.716229569+02:00\"]",
)
.unwrap();
// Once normalized to UTC, both representations must compare equal.
assert_eq!(date1.to_utc(), date2.to_utc());
assert_eq!(dates[0].to_utc(), dates[1].to_utc());
}
}

View file

@ -0,0 +1,230 @@
use std::{collections::VecDeque, sync::Arc, time::Duration};
use chrono::TimeDelta;
use iroh::{Endpoint, EndpointAddr, endpoint::Connection};
use reaction_plugin::Line;
use remoc::{Connect, rch::base};
use serde::{Deserialize, Serialize};
use tokio::{
sync::mpsc,
time::{Sleep, sleep},
};
use crate::cluster::{ALPN, UtcLine};
/// Initial back-off before the first reconnection retry.
const START_TIMEOUT: Duration = Duration::from_secs(5);
/// Upper bound for the reconnection back-off.
const MAX_TIMEOUT: Duration = Duration::from_secs(60 * 60); // 1 hour
/// Multiplier applied to the back-off after each failed attempt.
const TIMEOUT_FACTOR: f64 = 1.5;
/// Version sent as the first message on a new connection (see [`RemoteMessage::Version`]).
const PROTOCOL_VERSION: u32 = 1;
/// Internal events driving [`ConnectionManager::handle_event`].
enum Event {
/// The retry timer fired: time to attempt (re)connecting.
Tick,
/// An action on this node produced a line to forward to the remote node.
LocalMessageReceived(UtcLine),
/// The remote node sent us a message.
RemoteMessageReceived(RemoteMessage),
/// The EndpointManager handed us a connection (whether we asked for it or not).
ConnectionReceived(Connection),
}
/// An established connection together with its typed remoc channel endpoints
/// (both produced by `open_channels`).
struct OwnConnection {
connection: Connection,
tx: base::Sender<RemoteMessage>,
rx: base::Receiver<RemoteMessage>,
}
/// Manages the link to one remote cluster node: (re)connecting with exponential
/// back-off, queueing messages coming from local actions, and forwarding lines
/// received from the remote node to our own stream.
pub struct ConnectionManager {
/// Cluster's name (for logging)
cluster_name: String,
/// The remote node we're communicating with
remote: EndpointAddr,
/// Endpoint
endpoint: Arc<Endpoint>,
/// The EndpointManager sending us a connection (whether we asked for it or not)
connection_rx: mpsc::Receiver<Connection>,
/// Our own connection (when we have one)
connection: Option<OwnConnection>,
/// Delta we'll use next time we'll try to connect to remote
delta: Duration,
/// When this Future resolves, we'll retry connecting to remote
tick: Option<Sleep>,
/// Max duration before we drop pending messages to a node we can't connect to.
message_timeout: TimeDelta,
/// Messages we receive from actions
message_rx: mpsc::Receiver<UtcLine>,
/// Our queue of messages to send
message_queue: VecDeque<UtcLine>,
/// Messages we send from remote nodes to our own stream
own_cluster_tx: remoc::rch::mpsc::Sender<Line>,
}
impl ConnectionManager {
    /// Build a new manager for one remote node.
    ///
    /// Does not connect yet; run [`ConnectionManager::task`] to start.
    pub fn new(
        cluster_name: String,
        remote: EndpointAddr,
        endpoint: Arc<Endpoint>,
        connection_rx: mpsc::Receiver<Connection>,
        message_timeout: TimeDelta,
        message_rx: mpsc::Receiver<UtcLine>,
        own_cluster_tx: remoc::rch::mpsc::Sender<Line>,
    ) -> Self {
        Self {
            cluster_name,
            remote,
            endpoint,
            connection: None,
            delta: Duration::default(),
            tick: None,
            connection_rx,
            message_timeout,
            message_rx,
            message_queue: VecDeque::default(),
            own_cluster_tx,
        }
    }
    /// Run the manager: connect once, then process events forever.
    pub async fn task(mut self) {
        self.try_connect().await;
        loop {
            // TODO select! over tick / connection_rx / message_rx / remote messages
            let event = Event::Tick;
            self.handle_event(event).await;
        }
    }
    /// Main loop
    async fn handle_event(&mut self, event: Event) {
        match event {
            Event::Tick => {
                // TODO
                self.try_connect().await;
            }
            Event::ConnectionReceived(connection) => {
                // TODO
            }
            Event::LocalMessageReceived(utc_line) => {
                // TODO
            }
            Event::RemoteMessageReceived(remote_message) => {
                // TODO
            }
        }
    }
    /// Try connecting to a remote endpoint
    /// Returns true if we have a valid connection now
    async fn try_connect(&mut self) -> bool {
        if self.connection.is_some() {
            return true;
        }
        match self.endpoint.connect(self.remote.clone(), ALPN[0]).await {
            Ok(connection) => self.handle_connection(connection).await,
            Err(err) => {
                // BUG FIX: this call was previously not awaited (the method was
                // async), so the returned future was silently dropped and the
                // back-off state was never updated nor the error logged.
                // try_connect_error is now synchronous.
                self.try_connect_error(err.to_string());
                false
            }
        }
    }
    /// Bootstrap a new Connection
    /// Returns true if we have a valid connection now
    async fn handle_connection(&mut self, connection: Connection) -> bool {
        match open_channels(&connection).await {
            Ok((tx, rx)) => {
                // Reset the back-off only once the handshake fully succeeded,
                // so repeated handshake failures keep backing off.
                self.delta = Duration::default();
                self.tick = None;
                self.connection = Some(OwnConnection { connection, tx, rx });
                true
            }
            Err(err) => {
                self.try_connect_error(err);
                false
            }
        }
    }
    /// Update the back-off state and log an error when bootstraping a new
    /// Connection fails.
    /// Deliberately synchronous: it performs no awaits, and an async version
    /// led to un-awaited (hence dropped) futures at call sites.
    fn try_connect_error(&mut self, err: String) {
        self.delta = next_delta(self.delta);
        self.tick = Some(sleep(self.delta));
        eprintln!(
            "ERROR cluster {}: node {}: {err}",
            self.cluster_name, self.remote.id
        );
        eprintln!(
            "INFO cluster {}: retry connecting to node {} in {:?}",
            self.cluster_name, self.remote.id, self.delta
        );
    }
}
/// Compute the next wait Duration.
/// We're multiplying the Duration by [`TIMEOUT_FACTOR`] and cap it to [`MAX_TIMEOUT`].
fn next_delta(delta: Duration) -> Duration {
// Multiply timeout by TIMEOUT_FACTOR
let delta = Duration::from_millis(((delta.as_millis() as f64) * TIMEOUT_FACTOR) as u64);
// Cap to MAX_TIMEOUT
if delta > MAX_TIMEOUT {
MAX_TIMEOUT
} else {
delta
}
}
/// All possible communication messages
/// Set as an enum for forward compatibility
/// All possible communication messages exchanged between two cluster nodes.
/// Set as an enum for forward compatibility.
#[derive(Serialize, Deserialize)]
pub enum RemoteMessage {
/// Must be the first message sent over, then should not be sent again
Version(u32),
/// A line to transmit to your stream
Line(UtcLine),
/// Announce the node is closing
Quitting,
}
/// Open accept one stream and create one stream.
/// This way, there is no need to know if we created or accepted the connection.
/// Open one outgoing uni stream and accept one incoming uni stream, then run
/// the protocol-version handshake over the resulting remoc channels.
/// Opening one stream in each direction means neither side needs to know
/// whether it created or accepted the connection.
///
/// Returns the typed sender/receiver pair on success, a human-readable error
/// otherwise.
async fn open_channels(
    connection: &Connection,
) -> Result<(base::Sender<RemoteMessage>, base::Receiver<RemoteMessage>), String> {
    let output = connection
        .open_uni()
        .await
        .map_err(|err| format!("{err}"))?;
    let input = connection
        .accept_uni()
        .await
        .map_err(|err| format!("{err}"))?;
    let (conn, mut tx, mut rx) = Connect::io_buffered(remoc::Cfg::default(), input, output, 1024)
        .await
        .map_err(|err| format!("{err}"))?;
    // The remoc connection must be polled for the channels to make progress.
    tokio::spawn(conn);
    // BUG FIX: the send Result was previously ignored, so a failed version
    // announcement silently proceeded to the handshake; surface it instead.
    tx.send(RemoteMessage::Version(PROTOCOL_VERSION))
        .await
        .map_err(|err| format!("could not send protocol version: {err}"))?;
    match rx.recv().await {
        // Good protocol version!
        Ok(Some(RemoteMessage::Version(PROTOCOL_VERSION))) => Ok((tx, rx)),
        // Errors
        Ok(Some(RemoteMessage::Version(other))) => Err(format!(
            "incompatible version: {other}. We use {PROTOCOL_VERSION}. Consider upgrading the node with the older version."
        )),
        Ok(Some(RemoteMessage::Line(_))) => {
            Err("incorrect protocol message: remote did not send its protocol version.".into())
        }
        Ok(Some(RemoteMessage::Quitting)) => Err("remote unexpectedly quit".into()),
        Ok(None) => Err("remote unexpectedly closed its channel".into()),
        Err(err) => Err(format!("could not receive message: {err}")),
    }
}

View file

@ -1,26 +1,12 @@
use std::collections::BTreeMap;
use std::time::Duration;
use std::sync::Arc;
use iroh::endpoint::Incoming;
use iroh::{Endpoint, PublicKey};
use iroh::{EndpointAddr, endpoint::Connection};
use tokio::{
sync::mpsc,
time::{Instant, sleep, sleep_until},
use iroh::{
Endpoint, PublicKey,
endpoint::{Connection, Incoming},
};
use crate::cluster::ALPN;
const START_TIMEOUT: Duration = Duration::from_secs(5);
const MAX_TIMEOUT: Duration = Duration::from_secs(60 * 60); // 1 hour
const TIMEOUT_FACTOR: f64 = 1.5;
enum Event {
TryConnect(EndpointAddr),
Quit,
Tick,
Incoming(Option<Incoming>),
}
use reaction_plugin::shutdown::ShutdownToken;
use tokio::sync::mpsc;
enum Break {
Yes,
@ -29,176 +15,75 @@ enum Break {
pub struct EndpointManager {
/// The [`iroh::Endpoint`] to manage
endpoint: Endpoint,
endpoint: Arc<Endpoint>,
/// Cluster's name (for logging)
cluster_name: String,
/// Map of remote Endpoints to try to connect to
retry_connections: BTreeMap<Instant, (EndpointAddr, Duration)>,
/// Set of PublicKeys we're trying to connect to
all_connections: BTreeMap<PublicKey, Instant>,
/// Connection requests from the [`crate::Cluster`]
endpoint_addr_rx: mpsc::Receiver<EndpointAddr>,
/// Connection sender to the [`crate::Cluster`]
connection_tx: mpsc::Sender<Connection>,
/// Connection sender to the Connection Managers
connections_tx: BTreeMap<PublicKey, mpsc::Sender<Connection>>,
/// shutdown
shutdown: ShutdownToken,
}
impl EndpointManager {
pub fn new(
endpoint: Endpoint,
endpoint: Arc<Endpoint>,
cluster_name: String,
cluster_size: usize,
) -> (mpsc::Sender<EndpointAddr>, mpsc::Receiver<Connection>) {
let (tx1, rx1) = mpsc::channel(cluster_size);
let (tx2, rx2) = mpsc::channel(cluster_size);
connections_tx: BTreeMap<PublicKey, mpsc::Sender<Connection>>,
shutdown: ShutdownToken,
) {
tokio::spawn(async move {
Self {
endpoint,
cluster_name,
retry_connections: Default::default(),
all_connections: Default::default(),
endpoint_addr_rx: rx1,
connection_tx: tx2,
connections_tx,
shutdown,
}
.task()
.await
});
(tx1, rx2)
}
async fn task(&mut self) {
let mut tick = sleep(Duration::default());
loop {
// Uncomment this line and comment the select! for faster development in this function
// let event = Event::TryConnect(self.endpoint_addr_rx.recv().await);
let event = tokio::select! {
received = self.endpoint_addr_rx.recv() => {
match received {
Some(endpoint_addr) => Event::TryConnect(endpoint_addr),
None => Event::Quit,
}
}
incoming = self.endpoint.accept() => Event::Incoming(incoming),
_ = tick => Event::Tick,
let incoming = tokio::select! {
incoming = self.endpoint.accept() => incoming,
_ = self.shutdown.wait() => None,
};
if let Break::Yes = self.handle_event(event).await {
break;
match incoming {
Some(incoming) => {
if let Break::Yes = self.handle_incoming(incoming).await {
break;
}
}
None => break,
}
// Tick at next deadline
tick = sleep_until(
self.retry_connections
.keys()
.next()
.map(ToOwned::to_owned)
.unwrap_or_else(|| Instant::now() + MAX_TIMEOUT),
);
}
self.endpoint.close().await
}
async fn handle_event(&mut self, event: Event) -> Break {
match event {
Event::Quit => return Break::Yes,
Event::TryConnect(endpoint_addr) => match self.try_connect(endpoint_addr).await {
Ok(connection) => return self.check_and_send_connection(connection).await,
Err(endpoint_addr) => {
self.insert_endpoint(endpoint_addr, START_TIMEOUT);
}
},
Event::Tick => {
if let Some((endpoint_addr, delta)) = self.pop_next_endpoint() {
match self.try_connect(endpoint_addr).await {
Ok(connection) => {
return self.check_and_send_connection(connection).await;
}
Err(endpoint_addr) => {
let delta = next_delta(delta);
self.insert_endpoint(endpoint_addr, delta);
}
}
}
}
Event::Incoming(incoming) => {
// FIXME a malicious actor could maybe prevent a node from connecting to
// its cluster by sending lots of invalid slow connection requests?
// We could lower its priority https://docs.rs/tokio/latest/tokio/macro.select.html#fairness
// And/or moving the handshake to another task
if let Some(incoming) = incoming {
let remote_address = incoming.remote_address();
let remote_address_validated = incoming.remote_address_validated();
match incoming.await {
Ok(connection) => {
return self.check_and_send_connection(connection).await;
}
Err(err) => {
if remote_address_validated {
eprintln!("INFO refused connection from {}: {err}", remote_address)
} else {
eprintln!("INFO refused connection: {err}")
}
}
}
}
}
}
Break::No
}
/// Schedule an endpoint to try to connect to later
fn insert_endpoint(&mut self, endpoint_addr: EndpointAddr, delta: Duration) {
if !delta.is_zero() {
eprintln!(
"INFO cluster {}: retry connecting to node {} in {:?}",
self.cluster_name, endpoint_addr.id, delta
);
}
let next = Instant::now() + delta;
// Schedule this address for later
self.all_connections.insert(endpoint_addr.id, next);
self.retry_connections.insert(next, (endpoint_addr, delta));
}
/// Returns the next endpoint we should try to connect to
fn pop_next_endpoint(&mut self) -> Option<(EndpointAddr, Duration)> {
if self
.retry_connections
.keys()
.next()
.is_some_and(|time| time < &Instant::now())
{
let (_, tuple) = self.retry_connections.pop_first().unwrap();
self.all_connections.remove(&tuple.0.id);
Some(tuple)
} else {
None
}
}
/// Try connecting to a remote endpoint
async fn try_connect(&self, addr: EndpointAddr) -> Result<Connection, EndpointAddr> {
match self.endpoint.connect(addr.clone(), ALPN[0]).await {
Ok(connection) => Ok(connection),
async fn handle_incoming(&mut self, incoming: Incoming) -> Break {
// FIXME a malicious actor could maybe prevent a node from connecting to
// its cluster by sending lots of invalid slow connection requests?
// We could lower its priority https://docs.rs/tokio/latest/tokio/macro.select.html#fairness
// And/or moving the handshake to another task
let remote_address = incoming.remote_address();
let remote_address_validated = incoming.remote_address_validated();
let connection = match incoming.await {
Ok(connection) => connection,
Err(err) => {
eprintln!(
"ERROR cluster {}: node {}: {err}",
self.cluster_name, addr.id
);
Err(addr)
if remote_address_validated {
eprintln!("INFO refused connection from {}: {err}", remote_address)
} else {
eprintln!("INFO refused connection: {err}")
}
return Break::No;
}
}
}
};
/// Check that an incoming connection is an endpoint we're trying to connect,
/// and send it to the [`Cluster`]
async fn check_and_send_connection(&mut self, connection: Connection) -> Break {
let remote_id = match connection.remote_id() {
Ok(id) => id,
Err(err) => {
@ -210,44 +95,30 @@ impl EndpointManager {
}
};
match self.all_connections.remove(&remote_id) {
match self.connections_tx.get(&remote_id) {
None => {
eprintln!(
"WARN cluster {}: new peer's id '{remote_id}' is not in our list, refusing incoming connection.",
self.cluster_name
"WARN cluster {}: incoming connection from node '{remote_id}', ip: {} is not in our list, refusing incoming connection.",
self.cluster_name, remote_address
);
eprintln!(
"INFO cluster {}: {}, {}",
self.cluster_name,
"maybe we're already connected to it, maybe it's not from our cluster",
"maybe it is new and it has not been configured yet on this node"
"maybe it's not from our cluster,",
"maybe this node's configuration has not yet been updated to add this new node."
);
return Break::No;
}
Some(time) => {
self.retry_connections.remove(&time);
Some(tx) => {
if let Err(_) = tx.send(connection).await {
// This means the main cluster loop has exited, so let's quit
return Break::Yes;
}
}
}
// TODO persist the incoming address, so that we don't forget this address
if let Err(_) = self.connection_tx.send(connection).await {
// This means the main cluster loop has exited, so let's quit
return Break::Yes;
}
Break::No
}
}
/// Compute the next wait Duration.
/// We're multiplying the Duration by [`TIMEOUT_FACTOR`] and cap it to [`MAX_TIMEOUT`].
fn next_delta(delta: Duration) -> Duration {
// Multiply timeout by TIMEOUT_FACTOR
let delta = Duration::from_millis(((delta.as_millis() as f64) * TIMEOUT_FACTOR) as u64);
// Cap to MAX_TIMEOUT
if delta > MAX_TIMEOUT {
MAX_TIMEOUT
} else {
delta
}
}

View file

@ -11,9 +11,9 @@ use reaction_plugin::{
};
use remoc::{rch::mpsc, rtc};
use serde::{Deserialize, Serialize};
use tokio::sync::oneshot;
mod cluster;
mod connection;
mod endpoint;
mod secret_key;
@ -60,6 +60,12 @@ fn ipv6_unspecified() -> Option<Ipv6Addr> {
Some(Ipv6Addr::UNSPECIFIED)
}
#[derive(Serialize, Deserialize)]
struct NodeOption {
public_key: String,
addresses: Vec<SocketAddr>,
}
/// Stream information before start
struct StreamInit {
name: String,
@ -72,12 +78,6 @@ struct StreamInit {
tx: mpsc::Sender<Line>,
}
#[derive(Serialize, Deserialize)]
struct NodeOption {
public_key: String,
addresses: Vec<SocketAddr>,
}
#[derive(Serialize, Deserialize)]
struct ActionOptions {
/// The line to send to the corresponding cluster, example: "ban \<ip\>"