From 7383c0594098d994616bb6da3e54138dd7219589 Mon Sep 17 00:00:00 2001 From: erhant Date: Wed, 27 Nov 2024 19:27:00 +0300 Subject: [PATCH 01/16] node refactors for parallelization --- compute/src/handlers/mod.rs | 27 --- compute/src/handlers/pingpong.rs | 20 ++- compute/src/handlers/workflow.rs | 60 +++---- compute/src/main.rs | 2 +- compute/src/node.rs | 271 ++++++++++++++++--------------- compute/src/payloads/stats.rs | 5 +- compute/src/utils/message.rs | 4 +- p2p/src/client.rs | 7 + 8 files changed, 190 insertions(+), 206 deletions(-) diff --git a/compute/src/handlers/mod.rs b/compute/src/handlers/mod.rs index 00ccf51..27d6f60 100644 --- a/compute/src/handlers/mod.rs +++ b/compute/src/handlers/mod.rs @@ -1,32 +1,5 @@ -use crate::{utils::DKNMessage, DriaComputeNode}; -use async_trait::async_trait; -use dkn_p2p::libp2p::gossipsub::MessageAcceptance; -use eyre::Result; - mod pingpong; pub use pingpong::PingpongHandler; mod workflow; pub use workflow::WorkflowHandler; - -/// A DKN task is to be handled by the compute node, respecting this trait. -/// -/// It is expected for the implemented handler to handle messages coming from `LISTEN_TOPIC`, -/// and then respond back to the `RESPONSE_TOPIC`. -#[async_trait] -pub trait ComputeHandler { - /// Gossipsub topic name to listen for incoming messages from the network. - const LISTEN_TOPIC: &'static str; - /// Gossipsub topic name to respond with messages to the network. - const RESPONSE_TOPIC: &'static str; - - /// A generic handler for DKN tasks. - /// - /// Returns a `MessageAcceptance` value that tells the P2P client to accept the incoming message. - /// - /// The handler has mutable reference to the compute node, and therefore can respond within the handler itself in any way it would like. 
- async fn handle_compute( - node: &mut DriaComputeNode, - message: DKNMessage, - ) -> Result; -} diff --git a/compute/src/handlers/pingpong.rs b/compute/src/handlers/pingpong.rs index eef930b..9f31fc8 100644 --- a/compute/src/handlers/pingpong.rs +++ b/compute/src/handlers/pingpong.rs @@ -1,9 +1,7 @@ -use super::ComputeHandler; use crate::{ utils::{get_current_time_nanos, DKNMessage}, DriaComputeNode, }; -use async_trait::async_trait; use dkn_p2p::libp2p::gossipsub::MessageAcceptance; use dkn_workflows::{Model, ModelProvider}; use eyre::{Context, Result}; @@ -24,12 +22,20 @@ struct PingpongResponse { pub(crate) timestamp: u128, } -#[async_trait] -impl ComputeHandler for PingpongHandler { - const LISTEN_TOPIC: &'static str = "ping"; - const RESPONSE_TOPIC: &'static str = "pong"; +impl PingpongHandler { + pub(crate) const LISTEN_TOPIC: &'static str = "ping"; + pub(crate) const RESPONSE_TOPIC: &'static str = "pong"; - async fn handle_compute( + /// Handles the ping message and responds with a pong message. + /// + /// 1. Parses the payload of the incoming message into a `PingpongPayload`. + /// 2. Checks if the current time is past the deadline specified in the ping request. + /// 3. If the current time is past the deadline, logs a debug message and ignores the ping request. + /// 4. If the current time is within the deadline, constructs a `PingpongResponse` with the UUID from the ping request, the models from the node's configuration, and the current timestamp. + /// 5. Creates a new signed `DKNMessage` with the response body and the `RESPONSE_TOPIC`. + /// 6. Publishes the response message. + /// 7. Returns `MessageAcceptance::Accept` so that ping is propagated to others as well. 
+ pub(crate) async fn handle_ping( node: &mut DriaComputeNode, message: DKNMessage, ) -> Result { diff --git a/compute/src/handlers/workflow.rs b/compute/src/handlers/workflow.rs index 793eda2..c405e09 100644 --- a/compute/src/handlers/workflow.rs +++ b/compute/src/handlers/workflow.rs @@ -1,23 +1,19 @@ -use std::time::Instant; - -use async_trait::async_trait; use dkn_p2p::libp2p::gossipsub::MessageAcceptance; use dkn_workflows::{Entry, Executor, ModelProvider, ProgramMemory, Workflow}; use eyre::{eyre, Context, Result}; use libsecp256k1::PublicKey; use serde::Deserialize; +use std::time::Instant; use crate::payloads::{TaskErrorPayload, TaskRequestPayload, TaskResponsePayload, TaskStats}; use crate::utils::{get_current_time_nanos, DKNMessage}; use crate::DriaComputeNode; -use super::ComputeHandler; - pub struct WorkflowHandler; #[derive(Debug, Deserialize)] struct WorkflowPayload { - /// [Workflow](https://github.com/andthattoo/ollama-workflows/) object to be parsed. + /// [Workflow](https://github.com/andthattoo/ollama-workflows/blob/main/src/program/workflow.rs) object to be parsed. pub(crate) workflow: Workflow, /// A lıst of model (that can be parsed into `Model`) or model provider names. /// If model provider is given, the first matching model in the node config is used for that. 
@@ -28,12 +24,11 @@ struct WorkflowPayload { pub(crate) prompt: Option, } -#[async_trait] -impl ComputeHandler for WorkflowHandler { - const LISTEN_TOPIC: &'static str = "task"; - const RESPONSE_TOPIC: &'static str = "results"; +impl WorkflowHandler { + pub(crate) const LISTEN_TOPIC: &'static str = "task"; + pub(crate) const RESPONSE_TOPIC: &'static str = "results"; - async fn handle_compute( + pub(crate) async fn handle_compute( node: &mut DriaComputeNode, message: DKNMessage, ) -> Result { @@ -85,26 +80,29 @@ impl ComputeHandler for WorkflowHandler { } else { Executor::new(model) }; - let mut memory = ProgramMemory::new(); let entry: Option = task .input .prompt .map(|prompt| Entry::try_value_or_str(&prompt)); // execute workflow with cancellation - let exec_result: Result; + let mut memory = ProgramMemory::new(); + let exec_started_at = Instant::now(); - tokio::select! { - _ = node.cancellation.cancelled() => { - log::info!("Received cancellation, quitting all tasks."); - return Ok(MessageAcceptance::Accept); - }, - exec_result_inner = executor.execute(entry.as_ref(), &task.input.workflow, &mut memory) => { - exec_result = exec_result_inner.map_err(|e| eyre!("Execution error: {}", e.to_string())); - } - } + let exec_result = executor + .execute(entry.as_ref(), &task.input.workflow, &mut memory) + .await + .map_err(|e| eyre!("Execution error: {}", e.to_string())); task_stats = task_stats.record_execution_time(exec_started_at); + Ok(MessageAcceptance::Accept) + } + + async fn handle_publish( + node: &mut DriaComputeNode, + result: String, + task_id: String, + ) -> Result<()> { let (message, acceptance) = match exec_result { Ok(result) => { // obtain public key from the payload @@ -115,7 +113,7 @@ impl ComputeHandler for WorkflowHandler { // prepare signed and encrypted payload let payload = TaskResponsePayload::new( result, - &task.task_id, + &task_id, &task_public_key, &node.config.secret_key, model_name, @@ -125,11 +123,7 @@ impl ComputeHandler for 
WorkflowHandler { .wrap_err("Could not serialize response payload")?; // prepare signed message - log::debug!( - "Publishing result for task {}\n{}", - task.task_id, - payload_str - ); + log::debug!("Publishing result for task {}\n{}", task_id, payload_str); let message = DKNMessage::new(payload_str, Self::RESPONSE_TOPIC); // accept so that if there are others included in filter they can do the task (message, MessageAcceptance::Accept) @@ -137,11 +131,11 @@ impl ComputeHandler for WorkflowHandler { Err(err) => { // use pretty display string for error logging with causes let err_string = format!("{:#}", err); - log::error!("Task {} failed: {}", task.task_id, err_string); + log::error!("Task {} failed: {}", task_id, err_string); // prepare error payload let error_payload = TaskErrorPayload { - task_id: task.task_id.clone(), + task_id, error: err_string, model: model_name, stats: task_stats.record_published_at(), @@ -166,7 +160,7 @@ impl ComputeHandler for WorkflowHandler { log::error!("{}", err_msg); let payload = serde_json::json!({ - "taskId": task.task_id, + "taskId": task_id, "error": err_msg, }); let message = DKNMessage::new_signed( @@ -175,8 +169,8 @@ impl ComputeHandler for WorkflowHandler { &node.config.secret_key, ); node.publish(message).await?; - } + }; - Ok(acceptance) + Ok(()) } } diff --git a/compute/src/main.rs b/compute/src/main.rs index 76cc6a9..16b7bf4 100644 --- a/compute/src/main.rs +++ b/compute/src/main.rs @@ -84,7 +84,7 @@ async fn main() -> Result<()> { // launch the node in a separate thread log::info!("Spawning compute node thread."); let node_handle = tokio::spawn(async move { - if let Err(err) = node.launch().await { + if let Err(err) = node.run().await { log::error!("Node launch error: {}", err); panic!("Node failed.") }; diff --git a/compute/src/node.rs b/compute/src/node.rs index 1ebd802..cdc8198 100644 --- a/compute/src/node.rs +++ b/compute/src/node.rs @@ -1,6 +1,6 @@ use dkn_p2p::{ libp2p::{ - gossipsub::{self, Message, MessageId}, 
+ gossipsub::{Message, MessageAcceptance, MessageId}, PeerId, }, DriaP2PClient, DriaP2PCommander, DriaP2PProtocol, @@ -88,22 +88,6 @@ impl DriaComputeNode { Ok(()) } - /// Validates a message, but only logs the error. - pub async fn validate_safe( - &mut self, - msg_id: &gossipsub::MessageId, - propagation_source: &PeerId, - acceptance: gossipsub::MessageAcceptance, - ) { - if let Err(e) = self - .p2p - .validate_message(msg_id, propagation_source, acceptance) - .await - { - log::error!("Error validating message: {:?}", e); - } - } - /// Unsubscribe from a certain task with its topic. pub async fn unsubscribe(&mut self, topic: &str) -> Result<()> { let ok = self.p2p.unsubscribe(topic).await?; @@ -130,14 +114,133 @@ impl DriaComputeNode { } /// Returns the list of connected peers, `mesh` and `all`. - #[inline] + #[inline(always)] pub async fn peers(&self) -> Result<(Vec, Vec)> { self.p2p.peers().await } - /// Launches the main loop of the compute node. + /// Handles a GossipSub message received from the network. 
+ async fn handle_message( + &mut self, + (peer_id, message_id, message): (PeerId, &MessageId, Message), + ) -> MessageAcceptance { + // refresh admin rpc peer ids + // TODO: move this to main loop with tokio select + if self.available_nodes.can_refresh() { + log::info!("Refreshing available nodes."); + + if let Err(e) = self.available_nodes.populate_with_api().await { + log::error!("Error refreshing available nodes: {:?}", e); + }; + + // dial all rpc nodes for better connectivity + for rpc_addr in self.available_nodes.rpc_addrs.iter() { + log::debug!("Dialling RPC node: {}", rpc_addr); + if let Err(e) = self.p2p.dial(rpc_addr.clone()).await { + log::warn!("Error dialling RPC node: {:?}", e); + }; + } + + // print network info + log::debug!("{:?}", self.p2p.network_info().await); + } + + // check peer count + // TODO: move this to main loop with tokio select + if self.peer_last_refreshed.elapsed() > Duration::from_secs(PEER_REFRESH_INTERVAL_SECS) { + match self.p2p.peer_counts().await { + Ok((mesh, all)) => log::info!("Peer Count (mesh/all): {} / {}", mesh, all), + Err(e) => { + log::error!("Error getting peer counts: {:?}", e); + } + } + + self.peer_last_refreshed = Instant::now(); + + // TODO: add peer list as well + } + + // handle message with respect to its topic + let topic_str = message.topic.as_str(); + if std::matches!( + topic_str, + PingpongHandler::LISTEN_TOPIC | WorkflowHandler::LISTEN_TOPIC + ) { + // ensure that the message is from a valid source (origin) + let Some(source_peer_id) = message.source else { + log::warn!( + "Received {} message from {} without source.", + topic_str, + peer_id + ); + return MessageAcceptance::Ignore; + }; + + // log the received message + log::info!( + "Received {} message ({}) from {}", + topic_str, + message_id, + peer_id, + ); + + // ensure that message is from the known RPCs + if !self.available_nodes.rpc_nodes.contains(&source_peer_id) { + log::warn!( + "Received message from unauthorized source: {}", + 
source_peer_id + ); + log::debug!("Allowed sources: {:#?}", self.available_nodes.rpc_nodes); + return MessageAcceptance::Ignore; + } + + // first, parse the raw gossipsub message to a prepared message + let message = match self.parse_message_to_prepared_message(message.clone()) { + Ok(message) => message, + Err(e) => { + log::error!("Error parsing message: {:?}", e); + log::debug!("Message: {}", String::from_utf8_lossy(&message.data)); + return MessageAcceptance::Ignore; + } + }; + + // then handle the prepared message + let handler_result = match topic_str { + WorkflowHandler::LISTEN_TOPIC => { + WorkflowHandler::handle_compute(self, message).await + } + PingpongHandler::LISTEN_TOPIC => PingpongHandler::handle_ping(self, message).await, + _ => unreachable!(), // unreachable because of the if condition + }; + + // validate the message based on the result + match handler_result { + Ok(acceptance) => { + return acceptance; + } + Err(err) => { + log::error!("Error handling {} message: {:?}", topic_str, err); + return MessageAcceptance::Ignore; + } + } + } else if std::matches!( + topic_str, + PingpongHandler::RESPONSE_TOPIC | WorkflowHandler::RESPONSE_TOPIC + ) { + // since we are responding to these topics, we might receive messages from other compute nodes + // we can gracefully ignore them and propagate it to to others + log::trace!("Ignoring message for topic: {}", topic_str); + return MessageAcceptance::Accept; + } else { + // reject this message as its from a foreign topic + log::warn!("Received message from unexpected topic: {}", topic_str); + return MessageAcceptance::Reject; + } + } + + /// Runs the main loop of the compute node. /// This method is not expected to return until cancellation occurs. 
- pub async fn launch(&mut self) -> Result<()> { + pub async fn run(&mut self) -> Result<()> { // subscribe to topics self.subscribe(PingpongHandler::LISTEN_TOPIC).await?; self.subscribe(PingpongHandler::RESPONSE_TOPIC).await?; @@ -148,117 +251,20 @@ impl DriaComputeNode { // the underlying p2p client is expected to handle the rest within its own loop loop { tokio::select! { - event = self.msg_rx.recv() => { - // refresh admin rpc peer ids - if self.available_nodes.can_refresh() { - log::info!("Refreshing available nodes."); - - if let Err(e) = self.available_nodes.populate_with_api().await { - log::error!("Error refreshing available nodes: {:?}", e); - }; - - // dial all rpc nodes for better connectivity - for rpc_addr in self.available_nodes.rpc_addrs.iter() { - log::debug!("Dialling RPC node: {}", rpc_addr); - if let Err(e) = self.p2p.dial(rpc_addr.clone()).await { - log::warn!("Error dialling RPC node: {:?}", e); - }; - } - - // print network info - log::debug!("{:?}", self.p2p.network_info().await); - } - - // check peer count - if self.peer_last_refreshed.elapsed() > Duration::from_secs(PEER_REFRESH_INTERVAL_SECS) { - let (mesh_cnt, all_cnt) = self.p2p.peer_counts().await.unwrap_or((0, 0)); - log::info!("Peer Count (mesh/all): {} / {}", mesh_cnt, all_cnt); - - self.peer_last_refreshed = Instant::now(); - - // TODO: add peer list as well - } - - // check if there was any event at all - let Some((peer_id, message_id, message)) = event else { - continue; - }; - - let topic = message.topic.clone(); - let topic_str = topic.as_str(); - - // handle message w.r.t topic - if std::matches!(topic_str, PingpongHandler::LISTEN_TOPIC | WorkflowHandler::LISTEN_TOPIC) { - // ensure that the message is from a valid source (origin) - let source_peer_id = match message.source { - Some(peer) => peer, - None => { - log::warn!("Received {} message from {} without source.", topic_str, peer_id); - self.validate_safe(&message_id, &peer_id, gossipsub::MessageAcceptance::Ignore).await; 
- continue; - } - }; - - // log the received message - log::info!( - "Received {} message ({}) from {}", - topic_str, - message_id, - peer_id, - ); - - // ensure that message is from the static RPCs - if !self.available_nodes.rpc_nodes.contains(&source_peer_id) { - log::warn!("Received message from unauthorized source: {}", source_peer_id); - log::debug!("Allowed sources: {:#?}", self.available_nodes.rpc_nodes); - self.validate_safe(&message_id, &peer_id, gossipsub::MessageAcceptance::Ignore).await; - continue; - } - - // first, parse the raw gossipsub message to a prepared message - // if unparseable, - let message = match self.parse_message_to_prepared_message(message.clone()) { - Ok(message) => message, - Err(e) => { - log::error!("Error parsing message: {:?}", e); - log::debug!("Message: {}", String::from_utf8_lossy(&message.data)); - self.validate_safe(&message_id, &peer_id, gossipsub::MessageAcceptance::Ignore).await; - continue; - } - }; - - // then handle the prepared message - let handler_result = match topic_str { - WorkflowHandler::LISTEN_TOPIC => { - WorkflowHandler::handle_compute(self, message).await - } - PingpongHandler::LISTEN_TOPIC => { - PingpongHandler::handle_compute(self, message).await - } - // TODO: can we do this in a nicer way? 
yes, cast to enum above and let type-casting do the work - _ => unreachable!() // unreachable because of the if condition - }; - - // validate the message based on the result - match handler_result { - Ok(acceptance) => { - self.validate_safe(&message_id, &peer_id, acceptance).await; - }, - Err(err) => { - log::error!("Error handling {} message: {:?}", topic_str, err); - self.validate_safe(&message_id, &peer_id, gossipsub::MessageAcceptance::Ignore).await; - } + gossipsub_msg = self.msg_rx.recv() => { + if let Some((peer_id, message_id, message)) = gossipsub_msg { + // handle the message, returning a message acceptance for the received one + let acceptance = self.handle_message((peer_id, &message_id, message)).await; + + // validate the message based on the acceptance + // cant do anything but log if this gives an error as well + if let Err(e) = self.p2p.validate_message(&message_id, &peer_id, acceptance).await { + log::error!("Error validating message {}: {:?}", message_id, e); } - } else if std::matches!(topic_str, PingpongHandler::RESPONSE_TOPIC | WorkflowHandler::RESPONSE_TOPIC) { - // since we are responding to these topics, we might receive messages from other compute nodes - // we can gracefully ignore them and propagate it to to others - log::trace!("Ignoring message for topic: {}", topic_str); - self.validate_safe(&message_id, &peer_id, gossipsub::MessageAcceptance::Accept).await; } else { - // reject this message as its from a foreign topic - log::warn!("Received message from unexpected topic: {}", topic_str); - self.validate_safe(&message_id, &peer_id, gossipsub::MessageAcceptance::Reject).await; - } + log::warn!("Message channel closed."); + break; + }; }, _ = self.cancellation.cancelled() => break, } @@ -293,10 +299,7 @@ impl DriaComputeNode { /// /// This also checks the signature of the message, expecting a valid signature from admin node. // TODO: move this somewhere? 
- pub fn parse_message_to_prepared_message( - &self, - message: gossipsub::Message, - ) -> Result { + pub fn parse_message_to_prepared_message(&self, message: Message) -> Result { // the received message is expected to use IdentHash for the topic, so we can see the name of the topic immediately. log::debug!("Parsing {} message.", message.topic.as_str()); let message = DKNMessage::try_from(message)?; @@ -337,7 +340,7 @@ mod tests { // launch & wait for a while for connections log::info!("Waiting a bit for peer setup."); tokio::select! { - _ = node.launch() => (), + _ = node.run() => (), _ = tokio::time::sleep(tokio::time::Duration::from_secs(20)) => cancellation.cancel(), } log::info!("Connected Peers:\n{:#?}", node.peers().await?); diff --git a/compute/src/payloads/stats.rs b/compute/src/payloads/stats.rs index 86a2d66..7a16a2b 100644 --- a/compute/src/payloads/stats.rs +++ b/compute/src/payloads/stats.rs @@ -3,8 +3,8 @@ use std::time::Instant; use crate::utils::get_current_time_nanos; -/// A task stat. -/// Returning this as the payload helps to debug the errors received at client side. +/// Task stats for diagnostics. +/// Returning this as the payload helps to debug the errors received at client side, and latencies. #[derive(Default, Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct TaskStats { @@ -30,6 +30,7 @@ impl TaskStats { self } + /// Records the execution time of the task. 
pub fn record_execution_time(mut self, started_at: Instant) -> Self { self.execution_time = Instant::now().duration_since(started_at).as_nanos(); self diff --git a/compute/src/utils/message.rs b/compute/src/utils/message.rs index b0985da..8beab5d 100644 --- a/compute/src/utils/message.rs +++ b/compute/src/utils/message.rs @@ -17,7 +17,7 @@ pub struct DKNMessage { pub(crate) payload: String, /// The topic of the message, derived from `TopicHash` /// - /// NOTE: This can be obtained via TopicHash in GossipSub + /// NOTE: This can be obtained via `TopicHash` in GossipSub pub(crate) topic: String, /// The version of the Dria Compute Node /// @@ -28,7 +28,7 @@ pub struct DKNMessage { pub(crate) identity: String, /// The timestamp of the message, in nanoseconds /// - /// NOTE: This can be obtained via DataTransform in GossipSub + /// NOTE: This can be obtained via `DataTransform` in GossipSub pub(crate) timestamp: u128, } diff --git a/p2p/src/client.rs b/p2p/src/client.rs index 74c499b..4f48d5f 100644 --- a/p2p/src/client.rs +++ b/p2p/src/client.rs @@ -291,6 +291,13 @@ impl DriaP2PClient { // this is usually the external address via relay log::info!("External address confirmed: {}", address); } + // SwarmEvent::OutgoingConnectionError { peer_id, error, .. 
} => { + // if let Some(peer_id) = peer_id { + // log::warn!("Could not connect to peer {}: {:?}", peer_id, error); + // } else { + // log::warn!("Outgoing connection error: {:?}", error); + // } + // } event => log::trace!("Unhandled Swarm Event: {:?}", event), } } From 601e0e87713c96e36478bbfdba8693cc844a5f5a Mon Sep 17 00:00:00 2001 From: erhant Date: Wed, 27 Nov 2024 23:25:52 +0300 Subject: [PATCH 02/16] parallel workflows first version works --- compute/src/handlers/workflow.rs | 88 ++++++++++--------- compute/src/lib.rs | 1 + compute/src/main.rs | 9 +- compute/src/node.rs | 53 +++++++++--- compute/src/workers/mod.rs | 1 + compute/src/workers/workflow.rs | 141 +++++++++++++++++++++++++++++++ 6 files changed, 241 insertions(+), 52 deletions(-) create mode 100644 compute/src/workers/mod.rs create mode 100644 compute/src/workers/workflow.rs diff --git a/compute/src/handlers/workflow.rs b/compute/src/handlers/workflow.rs index c405e09..cf11c16 100644 --- a/compute/src/handlers/workflow.rs +++ b/compute/src/handlers/workflow.rs @@ -1,12 +1,13 @@ use dkn_p2p::libp2p::gossipsub::MessageAcceptance; -use dkn_workflows::{Entry, Executor, ModelProvider, ProgramMemory, Workflow}; -use eyre::{eyre, Context, Result}; +use dkn_workflows::{Entry, Executor, ModelProvider, Workflow}; +use eyre::{Context, Result}; use libsecp256k1::PublicKey; use serde::Deserialize; -use std::time::Instant; +use tokio_util::either::Either; use crate::payloads::{TaskErrorPayload, TaskRequestPayload, TaskResponsePayload, TaskStats}; use crate::utils::{get_current_time_nanos, DKNMessage}; +use crate::workers::workflow::*; use crate::DriaComputeNode; pub struct WorkflowHandler; @@ -31,11 +32,13 @@ impl WorkflowHandler { pub(crate) async fn handle_compute( node: &mut DriaComputeNode, message: DKNMessage, - ) -> Result { + ) -> Result> { let task = message .parse_payload::>(true) .wrap_err("Could not parse workflow task")?; - let mut task_stats = TaskStats::default().record_received_at(); + + // TODO: 
!!! + let task_stats = TaskStats::default().record_received_at(); // check if deadline is past or not let current_time = get_current_time_nanos(); @@ -48,7 +51,7 @@ impl WorkflowHandler { ); // ignore the message - return Ok(MessageAcceptance::Ignore); + return Ok(Either::Left(MessageAcceptance::Ignore)); } // check task inclusion via the bloom filter @@ -59,9 +62,15 @@ impl WorkflowHandler { ); // accept the message, someone else may be included in filter - return Ok(MessageAcceptance::Accept); + return Ok(Either::Left(MessageAcceptance::Accept)); } + // obtain public key from the payload + // do this early to avoid unnecessary processing + let task_public_key_bytes = + hex::decode(&task.public_key).wrap_err("could not decode public key")?; + let task_public_key = PublicKey::parse_slice(&task_public_key_bytes, None)?; + // read model / provider from the task let (model_provider, model) = node .config @@ -80,50 +89,51 @@ impl WorkflowHandler { } else { Executor::new(model) }; + + // prepare entry from prompt let entry: Option = task .input .prompt .map(|prompt| Entry::try_value_or_str(&prompt)); - // execute workflow with cancellation - let mut memory = ProgramMemory::new(); - - let exec_started_at = Instant::now(); - let exec_result = executor - .execute(entry.as_ref(), &task.input.workflow, &mut memory) - .await - .map_err(|e| eyre!("Execution error: {}", e.to_string())); - task_stats = task_stats.record_execution_time(exec_started_at); - - Ok(MessageAcceptance::Accept) + // get workflow as well + let workflow = task.input.workflow; + + Ok(Either::Right(WorkflowsWorkerInput { + entry, + executor, + workflow, + model_name, + task_id: task.task_id, + public_key: task_public_key, + stats: task_stats, + })) } - async fn handle_publish( + pub(crate) async fn handle_publish( node: &mut DriaComputeNode, - result: String, - task_id: String, - ) -> Result<()> { - let (message, acceptance) = match exec_result { + task: WorkflowsWorkerOutput, + ) -> Result { + let (message, 
acceptance) = match task.result { Ok(result) => { - // obtain public key from the payload - let task_public_key_bytes = - hex::decode(&task.public_key).wrap_err("Could not decode public key")?; - let task_public_key = PublicKey::parse_slice(&task_public_key_bytes, None)?; - // prepare signed and encrypted payload let payload = TaskResponsePayload::new( result, - &task_id, - &task_public_key, + &task.task_id, + &task.public_key, &node.config.secret_key, - model_name, - task_stats.record_published_at(), + task.model_name, + task.stats.record_published_at(), )?; let payload_str = serde_json::to_string(&payload) .wrap_err("Could not serialize response payload")?; // prepare signed message - log::debug!("Publishing result for task {}\n{}", task_id, payload_str); + log::debug!( + "Publishing result for task {}\n{}", + task.task_id, + payload_str + ); let message = DKNMessage::new(payload_str, Self::RESPONSE_TOPIC); // accept so that if there are others included in filter they can do the task (message, MessageAcceptance::Accept) @@ -131,14 +141,14 @@ impl WorkflowHandler { Err(err) => { // use pretty display string for error logging with causes let err_string = format!("{:#}", err); - log::error!("Task {} failed: {}", task_id, err_string); + log::error!("Task {} failed: {}", task.task_id, err_string); // prepare error payload let error_payload = TaskErrorPayload { - task_id, + task_id: task.task_id.clone(), error: err_string, - model: model_name, - stats: task_stats.record_published_at(), + model: task.model_name, + stats: task.stats.record_published_at(), }; let error_payload_str = serde_json::to_string(&error_payload) .wrap_err("Could not serialize error payload")?; @@ -160,7 +170,7 @@ impl WorkflowHandler { log::error!("{}", err_msg); let payload = serde_json::json!({ - "taskId": task_id, + "taskId": task.task_id, "error": err_msg, }); let message = DKNMessage::new_signed( @@ -171,6 +181,6 @@ impl WorkflowHandler { node.publish(message).await?; }; - Ok(()) + 
Ok(acceptance) } } diff --git a/compute/src/lib.rs b/compute/src/lib.rs index e54feb7..cebc37a 100644 --- a/compute/src/lib.rs +++ b/compute/src/lib.rs @@ -6,6 +6,7 @@ pub(crate) mod handlers; pub(crate) mod node; pub(crate) mod payloads; pub(crate) mod utils; +pub(crate) mod workers; /// Crate version of the compute node. /// This value is attached within the published messages. diff --git a/compute/src/main.rs b/compute/src/main.rs index 16b7bf4..3812b37 100644 --- a/compute/src/main.rs +++ b/compute/src/main.rs @@ -75,12 +75,16 @@ async fn main() -> Result<()> { } let node_token = token.clone(); - let (mut node, p2p) = DriaComputeNode::new(config, node_token).await?; + let (mut node, p2p, mut workflows) = DriaComputeNode::new(config, node_token).await?; // launch the p2p in a separate thread log::info!("Spawning peer-to-peer client thread."); let p2p_handle = tokio::spawn(async move { p2p.run().await }); + // launch the workflows in a separate thread + log::info!("Spawning workflows worker thread."); + let workflows_handle = tokio::spawn(async move { workflows.run().await }); + // launch the node in a separate thread log::info!("Spawning compute node thread."); let node_handle = tokio::spawn(async move { @@ -94,6 +98,9 @@ async fn main() -> Result<()> { if let Err(err) = node_handle.await { log::error!("Node handle error: {}", err); }; + if let Err(err) = workflows_handle.await { + log::error!("Workflows handle error: {}", err); + }; if let Err(err) = p2p_handle.await { log::error!("P2P handle error: {}", err); }; diff --git a/compute/src/node.rs b/compute/src/node.rs index cdc8198..b832101 100644 --- a/compute/src/node.rs +++ b/compute/src/node.rs @@ -10,12 +10,13 @@ use tokio::{ sync::mpsc, time::{Duration, Instant}, }; -use tokio_util::sync::CancellationToken; +use tokio_util::{either::Either, sync::CancellationToken}; use crate::{ config::*, handlers::*, utils::{crypto::secret_to_keypair, AvailableNodes, DKNMessage}, + workers::workflow::{WorkflowsWorker, 
WorkflowsWorkerInput, WorkflowsWorkerOutput}, }; /// Number of seconds between refreshing the Kademlia DHT. @@ -27,7 +28,10 @@ pub struct DriaComputeNode { pub available_nodes: AvailableNodes, pub cancellation: CancellationToken, peer_last_refreshed: Instant, - msg_rx: mpsc::Receiver<(PeerId, MessageId, Message)>, + // channels + message_rx: mpsc::Receiver<(PeerId, MessageId, Message)>, + worklow_tx: mpsc::Sender, + publish_rx: mpsc::Receiver, } impl DriaComputeNode { @@ -37,7 +41,7 @@ impl DriaComputeNode { pub async fn new( config: DriaComputeNodeConfig, cancellation: CancellationToken, - ) -> Result<(DriaComputeNode, DriaP2PClient)> { + ) -> Result<(DriaComputeNode, DriaP2PClient, WorkflowsWorker)> { // create the keypair from secret key let keypair = secret_to_keypair(&config.secret_key); @@ -55,7 +59,7 @@ impl DriaComputeNode { log::info!("Using identity: {}", protocol); // create p2p client - let (p2p_client, p2p_commander, msg_rx) = DriaP2PClient::new( + let (p2p_client, p2p_commander, message_rx) = DriaP2PClient::new( keypair, config.p2p_listen_addr.clone(), available_nodes.bootstrap_nodes.clone().into_iter(), @@ -64,16 +68,24 @@ impl DriaComputeNode { protocol, )?; + // create workflow worker + let (worklow_tx, workflow_rx) = mpsc::channel(256); + let (publish_tx, publish_rx) = mpsc::channel(256); + let workflows_worker = WorkflowsWorker::new(workflow_rx, publish_tx); + Ok(( DriaComputeNode { config, p2p: p2p_commander, cancellation, available_nodes, - msg_rx, + message_rx, + worklow_tx, + publish_rx, peer_last_refreshed: Instant::now(), }, p2p_client, + workflows_worker, )) } @@ -207,7 +219,18 @@ impl DriaComputeNode { // then handle the prepared message let handler_result = match topic_str { WorkflowHandler::LISTEN_TOPIC => { - WorkflowHandler::handle_compute(self, message).await + let compute_result = WorkflowHandler::handle_compute(self, message).await; + match compute_result { + Ok(Either::Left(acceptance)) => Ok(acceptance), + 
Ok(Either::Right(workflow_message)) => { + if let Err(e) = self.worklow_tx.send(workflow_message).await { + log::error!("Error sending workflow message: {:?}", e); + }; + + Ok(MessageAcceptance::Accept) + } + Err(err) => Err(err), + } } PingpongHandler::LISTEN_TOPIC => PingpongHandler::handle_ping(self, message).await, _ => unreachable!(), // unreachable because of the if condition @@ -251,7 +274,12 @@ impl DriaComputeNode { // the underlying p2p client is expected to handle the rest within its own loop loop { tokio::select! { - gossipsub_msg = self.msg_rx.recv() => { + publish_msg = self.publish_rx.recv() => { + if let Some(result) = publish_msg { + WorkflowHandler::handle_publish(self, result).await?; + } + }, + gossipsub_msg = self.message_rx.recv() => { if let Some((peer_id, message_id, message)) = gossipsub_msg { // handle the message, returning a message acceptance for the received one let acceptance = self.handle_message((peer_id, &message_id, message)).await; @@ -282,15 +310,16 @@ impl DriaComputeNode { Ok(()) } - /// Shutdown channels between p2p and yourself. + /// Shutdown channels between p2p, worker and yourself. 
pub async fn shutdown(&mut self) -> Result<()> { - // send shutdown signal log::debug!("Sending shutdown command to p2p client."); self.p2p.shutdown().await?; - // close message channel log::debug!("Closing message channel."); - self.msg_rx.close(); + self.message_rx.close(); + + log::debug!("Closing publish channel."); + self.publish_rx.close(); Ok(()) } @@ -329,7 +358,7 @@ mod tests { // create node let cancellation = CancellationToken::new(); - let (mut node, p2p) = + let (mut node, p2p, _) = DriaComputeNode::new(DriaComputeNodeConfig::default(), cancellation.clone()) .await .expect("should create node"); diff --git a/compute/src/workers/mod.rs b/compute/src/workers/mod.rs new file mode 100644 index 0000000..a4283a6 --- /dev/null +++ b/compute/src/workers/mod.rs @@ -0,0 +1 @@ +pub mod workflow; diff --git a/compute/src/workers/workflow.rs b/compute/src/workers/workflow.rs new file mode 100644 index 0000000..74a8e0c --- /dev/null +++ b/compute/src/workers/workflow.rs @@ -0,0 +1,141 @@ +use dkn_workflows::{Entry, ExecutionError, Executor, ProgramMemory, Workflow}; +use libsecp256k1::PublicKey; +use tokio::sync::mpsc; + +use crate::payloads::TaskStats; + +pub struct WorkflowsWorkerInput { + pub entry: Option, + pub executor: Executor, + pub workflow: Workflow, + // piggybacked + pub public_key: PublicKey, + pub task_id: String, + pub model_name: String, + pub stats: TaskStats, +} + +pub struct WorkflowsWorkerOutput { + pub result: Result, + // piggybacked + pub public_key: PublicKey, + pub task_id: String, + pub model_name: String, + pub stats: TaskStats, +} + +pub struct WorkflowsWorker { + worklow_rx: mpsc::Receiver, + publish_tx: mpsc::Sender, +} + +impl WorkflowsWorker { + /// Batch size that defines how many tasks can be executed in parallel at once. + /// IMPORTANT NOTE: `run` function is designed to handle the batch size here specifically, + /// if there are more tasks than the batch size, the function will panic. 
+ const BATCH_SIZE: usize = 5; + + pub fn new( + worklow_rx: mpsc::Receiver, + publish_tx: mpsc::Sender, + ) -> Self { + Self { + worklow_rx, + publish_tx, + } + } + + pub async fn run(&mut self) { + loop { + // get tasks in batch from the channel + let mut batch_vec = Vec::new(); + let num_tasks = self + .worklow_rx + .recv_many(&mut batch_vec, Self::BATCH_SIZE) + .await; + debug_assert!( + num_tasks <= Self::BATCH_SIZE, + "drain cant be larger than batch size" + ); + // TODO: just to be sure, can be removed later + debug_assert_eq!(num_tasks, batch_vec.len()); + + if num_tasks == 0 { + self.worklow_rx.close(); + return; + } + + // process the batch + let mut batch = batch_vec.into_iter(); + log::info!("Processing {} workflows in batch", num_tasks); + let results = match num_tasks { + 1 => vec![WorkflowsWorker::execute(batch.next().unwrap()).await], + 2 => { + let (r0, r1) = tokio::join!( + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()) + ); + vec![r0, r1] + } + 3 => { + let (r0, r1, r2) = tokio::join!( + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()) + ); + vec![r0, r1, r2] + } + 4 => { + let (r0, r1, r2, r3) = tokio::join!( + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()) + ); + vec![r0, r1, r2, r3] + } + 5 => { + let (r0, r1, r2, r3, r4) = tokio::join!( + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()) + ); + vec![r0, r1, r2, r3, r4] + } + _ => { + unreachable!("drain cant be larger than batch size"); + } + }; + + // publish all results + // TODO: make this 
a part of executor as well + log::info!("Publishing {} workflow results", results.len()); + for result in results { + if let Err(e) = self.publish_tx.send(result).await { + log::error!("Error sending workflow result: {}", e); + } + } + } + } + + /// A single task execution. + pub async fn execute(input: WorkflowsWorkerInput) -> WorkflowsWorkerOutput { + let mut memory = ProgramMemory::new(); + + let started_at = std::time::Instant::now(); + let result = input + .executor + .execute(input.entry.as_ref(), &input.workflow, &mut memory) + .await; + + WorkflowsWorkerOutput { + result, + public_key: input.public_key, + task_id: input.task_id, + model_name: input.model_name, + stats: input.stats.record_execution_time(started_at), + } + } +} From 572f5f8901b8893ebe637285fe06ecae63286909 Mon Sep 17 00:00:00 2001 From: erhant Date: Wed, 27 Nov 2024 23:57:26 +0300 Subject: [PATCH 03/16] batch size 8 --- compute/src/workers/workflow.rs | 44 +++++++++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/compute/src/workers/workflow.rs b/compute/src/workers/workflow.rs index 74a8e0c..8053af2 100644 --- a/compute/src/workers/workflow.rs +++ b/compute/src/workers/workflow.rs @@ -33,7 +33,7 @@ impl WorkflowsWorker { /// Batch size that defines how many tasks can be executed in parallel at once. /// IMPORTANT NOTE: `run` function is designed to handle the batch size here specifically, /// if there are more tasks than the batch size, the function will panic. 
- const BATCH_SIZE: usize = 5; + const BATCH_SIZE: usize = 8; pub fn new( worklow_rx: mpsc::Receiver, @@ -61,6 +61,7 @@ impl WorkflowsWorker { debug_assert_eq!(num_tasks, batch_vec.len()); if num_tasks == 0 { + log::warn!("Closing workflows worker."); self.worklow_rx.close(); return; } @@ -69,7 +70,10 @@ impl WorkflowsWorker { let mut batch = batch_vec.into_iter(); log::info!("Processing {} workflows in batch", num_tasks); let results = match num_tasks { - 1 => vec![WorkflowsWorker::execute(batch.next().unwrap()).await], + 1 => { + let r0 = WorkflowsWorker::execute(batch.next().unwrap()).await; + vec![r0] + } 2 => { let (r0, r1) = tokio::join!( WorkflowsWorker::execute(batch.next().unwrap()), @@ -104,6 +108,42 @@ impl WorkflowsWorker { ); vec![r0, r1, r2, r3, r4] } + 6 => { + let (r0, r1, r2, r3, r4, r5) = tokio::join!( + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()) + ); + vec![r0, r1, r2, r3, r4, r5] + } + 7 => { + let (r0, r1, r2, r3, r4, r5, r6) = tokio::join!( + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()) + ); + vec![r0, r1, r2, r3, r4, r5, r6] + } + 8 => { + let (r0, r1, r2, r3, r4, r5, r6, r7) = tokio::join!( + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()), + 
WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()), + WorkflowsWorker::execute(batch.next().unwrap()) + ); + vec![r0, r1, r2, r3, r4, r5, r6, r7] + } _ => { unreachable!("drain cant be larger than batch size"); } From b74f09c34e4ee03f3582e0bb34445cf8c9dc9a12 Mon Sep 17 00:00:00 2001 From: erhant Date: Thu, 28 Nov 2024 12:53:45 +0300 Subject: [PATCH 04/16] msg refactors, fix few clones --- Cargo.lock | 6 +- Cargo.toml | 2 +- compute/Cargo.toml | 4 +- compute/src/config.rs | 2 +- compute/src/handlers/pingpong.rs | 8 +- compute/src/handlers/workflow.rs | 35 ++- compute/src/lib.rs | 3 - compute/src/main.rs | 2 +- compute/src/node.rs | 240 +++++++++---------- compute/src/payloads/response.rs | 13 +- compute/src/payloads/stats.rs | 4 + compute/src/utils/available_nodes/mod.rs | 16 -- compute/src/utils/available_nodes/statics.rs | 3 + compute/src/utils/crypto.rs | 12 +- compute/src/utils/filter.rs | 2 +- compute/src/utils/message.rs | 4 +- compute/src/utils/misc.rs | 2 +- p2p/src/client.rs | 1 - p2p/tests/listen_test.rs | 1 - workflows/src/apis/jina.rs | 4 +- workflows/src/apis/serper.rs | 8 +- workflows/src/bin/tps.rs | 2 +- workflows/src/providers/gemini.rs | 2 +- workflows/src/providers/ollama.rs | 19 +- workflows/src/providers/openai.rs | 4 +- workflows/src/providers/openrouter.rs | 2 +- 26 files changed, 183 insertions(+), 218 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6161931..8e8f839 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -974,7 +974,7 @@ dependencies = [ [[package]] name = "dkn-compute" -version = "0.2.24" +version = "0.2.25" dependencies = [ "async-trait", "base64 0.22.1", @@ -1006,7 +1006,7 @@ dependencies = [ [[package]] name = "dkn-p2p" -version = "0.2.24" +version = "0.2.25" dependencies = [ "env_logger 0.11.5", "eyre", @@ -1019,7 +1019,7 @@ dependencies = [ [[package]] name = "dkn-workflows" -version = "0.2.24" +version = "0.2.25" dependencies = [ "dotenvy", "env_logger 0.11.5", diff --git 
a/Cargo.toml b/Cargo.toml index cd3a529..69ab2fe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ default-members = ["compute"] [workspace.package] edition = "2021" -version = "0.2.24" +version = "0.2.25" license = "Apache-2.0" readme = "README.md" diff --git a/compute/Cargo.toml b/compute/Cargo.toml index 588a5f9..2cc1c97 100644 --- a/compute/Cargo.toml +++ b/compute/Cargo.toml @@ -34,8 +34,6 @@ rand.workspace = true env_logger.workspace = true log.workspace = true eyre.workspace = true -# tracing = { version = "0.1.40" } -# tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } # encryption (ecies) & signatures (ecdsa) & hashing & bloom-filters ecies = { version = "0.2", default-features = false, features = ["pure"] } @@ -48,7 +46,7 @@ fastbloom-rs = "0.5.9" dkn-p2p = { path = "../p2p" } dkn-workflows = { path = "../workflows" } -# Vendor OpenSSL so that its easier to build cross-platform packages +# vendor OpenSSL so that its easier to build cross-platform packages [dependencies.openssl] version = "*" features = ["vendored"] diff --git a/compute/src/config.rs b/compute/src/config.rs index 25d5ebc..8336d97 100644 --- a/compute/src/config.rs +++ b/compute/src/config.rs @@ -131,7 +131,7 @@ impl DriaComputeNodeConfig { .map(|addr| addr.trim_matches('"').to_string()) .unwrap_or(DEFAULT_P2P_LISTEN_ADDR.to_string()); let p2p_listen_addr = Multiaddr::from_str(&p2p_listen_addr_str) - .expect("Could not parse the given P2P listen address."); + .expect("could not parse the given P2P listen address."); // parse network type let network_type = env::var("DKN_NETWORK") diff --git a/compute/src/handlers/pingpong.rs b/compute/src/handlers/pingpong.rs index 9f31fc8..02831c2 100644 --- a/compute/src/handlers/pingpong.rs +++ b/compute/src/handlers/pingpong.rs @@ -11,7 +11,9 @@ pub struct PingpongHandler; #[derive(Serialize, Deserialize, Debug, Clone)] struct PingpongPayload { + /// UUID of the ping request, prevents replay attacks. 
uuid: String, + /// Deadline for the ping request. deadline: u128, } @@ -37,11 +39,11 @@ impl PingpongHandler { /// 7. Returns `MessageAcceptance::Accept` so that ping is propagated to others as well. pub(crate) async fn handle_ping( node: &mut DriaComputeNode, - message: DKNMessage, + ping_message: &DKNMessage, ) -> Result { - let pingpong = message + let pingpong = ping_message .parse_payload::(true) - .wrap_err("Could not parse ping request")?; + .wrap_err("could not parse ping request")?; // check deadline let current_time = get_current_time_nanos(); diff --git a/compute/src/handlers/workflow.rs b/compute/src/handlers/workflow.rs index cf11c16..c0ae7fe 100644 --- a/compute/src/handlers/workflow.rs +++ b/compute/src/handlers/workflow.rs @@ -31,14 +31,12 @@ impl WorkflowHandler { pub(crate) async fn handle_compute( node: &mut DriaComputeNode, - message: DKNMessage, + compute_message: &DKNMessage, ) -> Result> { - let task = message + let stats = TaskStats::new().record_received_at(); + let task = compute_message .parse_payload::>(true) - .wrap_err("Could not parse workflow task")?; - - // TODO: !!! 
- let task_stats = TaskStats::default().record_received_at(); + .wrap_err("could not parse workflow task")?; // check if deadline is past or not let current_time = get_current_time_nanos(); @@ -106,15 +104,15 @@ impl WorkflowHandler { model_name, task_id: task.task_id, public_key: task_public_key, - stats: task_stats, + stats, })) } pub(crate) async fn handle_publish( node: &mut DriaComputeNode, task: WorkflowsWorkerOutput, - ) -> Result { - let (message, acceptance) = match task.result { + ) -> Result<()> { + let message = match task.result { Ok(result) => { // prepare signed and encrypted payload let payload = TaskResponsePayload::new( @@ -126,7 +124,7 @@ impl WorkflowHandler { task.stats.record_published_at(), )?; let payload_str = serde_json::to_string(&payload) - .wrap_err("Could not serialize response payload")?; + .wrap_err("could not serialize response payload")?; // prepare signed message log::debug!( @@ -134,9 +132,8 @@ impl WorkflowHandler { task.task_id, payload_str ); - let message = DKNMessage::new(payload_str, Self::RESPONSE_TOPIC); - // accept so that if there are others included in filter they can do the task - (message, MessageAcceptance::Accept) + + DKNMessage::new(payload_str, Self::RESPONSE_TOPIC) } Err(err) => { // use pretty display string for error logging with causes @@ -151,22 +148,20 @@ impl WorkflowHandler { stats: task.stats.record_published_at(), }; let error_payload_str = serde_json::to_string(&error_payload) - .wrap_err("Could not serialize error payload")?; + .wrap_err("could not serialize error payload")?; // prepare signed message - let message = DKNMessage::new_signed( + DKNMessage::new_signed( error_payload_str, Self::RESPONSE_TOPIC, &node.config.secret_key, - ); - // ignore just in case, workflow may be bugged - (message, MessageAcceptance::Ignore) + ) } }; // try publishing the result if let Err(publish_err) = node.publish(message).await { - let err_msg = format!("Could not publish result: {:?}", publish_err); + let err_msg = 
format!("could not publish result: {:?}", publish_err); log::error!("{}", err_msg); let payload = serde_json::json!({ @@ -181,6 +176,6 @@ impl WorkflowHandler { node.publish(message).await?; }; - Ok(acceptance) + Ok(()) } } diff --git a/compute/src/lib.rs b/compute/src/lib.rs index cebc37a..696eb29 100644 --- a/compute/src/lib.rs +++ b/compute/src/lib.rs @@ -1,6 +1,3 @@ -// #![doc = include_str!("../README.md")] -// TODO: this line breaks docker, find a way to ignore during compose? - pub(crate) mod config; pub(crate) mod handlers; pub(crate) mod node; diff --git a/compute/src/main.rs b/compute/src/main.rs index 3812b37..7988b52 100644 --- a/compute/src/main.rs +++ b/compute/src/main.rs @@ -13,7 +13,7 @@ async fn main() -> Result<()> { .format_timestamp(Some(env_logger::TimestampPrecision::Millis)) .init(); if let Err(e) = dotenv_result { - log::warn!("Could not load .env file: {}", e); + log::warn!("could not load .env file: {}", e); } log::info!( diff --git a/compute/src/node.rs b/compute/src/node.rs index b832101..e6ca858 100644 --- a/compute/src/node.rs +++ b/compute/src/node.rs @@ -6,10 +6,7 @@ use dkn_p2p::{ DriaP2PClient, DriaP2PCommander, DriaP2PProtocol, }; use eyre::{eyre, Result}; -use tokio::{ - sync::mpsc, - time::{Duration, Instant}, -}; +use tokio::{sync::mpsc, time::Duration}; use tokio_util::{either::Either, sync::CancellationToken}; use crate::{ @@ -21,13 +18,14 @@ use crate::{ /// Number of seconds between refreshing the Kademlia DHT. const PEER_REFRESH_INTERVAL_SECS: u64 = 30; +/// Number of seconds between refreshing the available nodes. 
+const AVAILABLE_NODES_REFRESH_INTERVAL_SECS: u64 = 30 * 60; // 30 minutes pub struct DriaComputeNode { pub config: DriaComputeNodeConfig, pub p2p: DriaP2PCommander, pub available_nodes: AvailableNodes, pub cancellation: CancellationToken, - peer_last_refreshed: Instant, // channels message_rx: mpsc::Receiver<(PeerId, MessageId, Message)>, worklow_tx: mpsc::Sender, @@ -82,7 +80,6 @@ impl DriaComputeNode { message_rx, worklow_tx, publish_rx, - peer_last_refreshed: Instant::now(), }, p2p_client, workflows_worker, @@ -136,128 +133,113 @@ impl DriaComputeNode { &mut self, (peer_id, message_id, message): (PeerId, &MessageId, Message), ) -> MessageAcceptance { - // refresh admin rpc peer ids - // TODO: move this to main loop with tokio select - if self.available_nodes.can_refresh() { - log::info!("Refreshing available nodes."); - - if let Err(e) = self.available_nodes.populate_with_api().await { - log::error!("Error refreshing available nodes: {:?}", e); - }; - - // dial all rpc nodes for better connectivity - for rpc_addr in self.available_nodes.rpc_addrs.iter() { - log::debug!("Dialling RPC node: {}", rpc_addr); - if let Err(e) = self.p2p.dial(rpc_addr.clone()).await { - log::warn!("Error dialling RPC node: {:?}", e); + // handle message with respect to its topic + match message.topic.as_str() { + PingpongHandler::LISTEN_TOPIC | WorkflowHandler::LISTEN_TOPIC => { + // ensure that the message is from a valid source (origin) + let Some(source_peer_id) = message.source else { + log::warn!( + "Received {} message from {} without source.", + message.topic, + peer_id + ); + return MessageAcceptance::Ignore; }; - } - // print network info - log::debug!("{:?}", self.p2p.network_info().await); - } + // log the received message + log::info!( + "Received {} message ({}) from {}", + message.topic, + message_id, + peer_id, + ); - // check peer count - // TODO: move this to main loop with tokio select - if self.peer_last_refreshed.elapsed() > 
Duration::from_secs(PEER_REFRESH_INTERVAL_SECS) { - match self.p2p.peer_counts().await { - Ok((mesh, all)) => log::info!("Peer Count (mesh/all): {} / {}", mesh, all), - Err(e) => { - log::error!("Error getting peer counts: {:?}", e); + // ensure that message is from the known RPCs + if !self.available_nodes.rpc_nodes.contains(&source_peer_id) { + log::warn!( + "Received message from unauthorized source: {}", + source_peer_id + ); + log::debug!("Allowed sources: {:#?}", self.available_nodes.rpc_nodes); + return MessageAcceptance::Ignore; } - } - self.peer_last_refreshed = Instant::now(); + // first, parse the raw gossipsub message to a prepared message + let message = match self.parse_message_to_prepared_message(&message) { + Ok(message) => message, + Err(e) => { + log::error!("Error parsing message: {:?}", e); + log::debug!("Message: {}", String::from_utf8_lossy(&message.data)); + return MessageAcceptance::Ignore; + } + }; + + // then handle the prepared message + let handler_result = match message.topic.as_str() { + WorkflowHandler::LISTEN_TOPIC => { + match WorkflowHandler::handle_compute(self, &message).await { + Ok(Either::Left(acceptance)) => Ok(acceptance), + Ok(Either::Right(workflow_message)) => { + if let Err(e) = self.worklow_tx.send(workflow_message).await { + log::error!("Error sending workflow message: {:?}", e); + }; + + // accept the message in case others may be included in the filter as well + Ok(MessageAcceptance::Accept) + } + Err(err) => Err(err), + } + } + PingpongHandler::LISTEN_TOPIC => { + PingpongHandler::handle_ping(self, &message).await + } + _ => unreachable!(), // unreachable because of the `match` above + }; - // TODO: add peer list as well + // validate the message based on the result + handler_result.unwrap_or_else(|err| { + log::error!("Error handling {} message: {:?}", message.topic, err); + MessageAcceptance::Ignore + }) + } + PingpongHandler::RESPONSE_TOPIC | WorkflowHandler::RESPONSE_TOPIC => { + // since we are responding to 
these topics, we might receive messages from other compute nodes + // we can gracefully ignore them and propagate it to to others + log::trace!("Ignoring message for topic: {}", message.topic); + MessageAcceptance::Accept + } + other => { + // reject this message as its from a foreign topic + log::warn!("Received message from unexpected topic: {}", other); + MessageAcceptance::Reject + } } + } - // handle message with respect to its topic - let topic_str = message.topic.as_str(); - if std::matches!( - topic_str, - PingpongHandler::LISTEN_TOPIC | WorkflowHandler::LISTEN_TOPIC - ) { - // ensure that the message is from a valid source (origin) - let Some(source_peer_id) = message.source else { - log::warn!( - "Received {} message from {} without source.", - topic_str, - peer_id - ); - return MessageAcceptance::Ignore; - }; + /// Peer refresh simply reports the peer count to the user. + async fn handle_peer_refresh(&self) { + match self.p2p.peer_counts().await { + Ok((mesh, all)) => log::info!("Peer Count (mesh/all): {} / {}", mesh, all), + Err(e) => log::error!("Error getting peer counts: {:?}", e), + } + } - // log the received message - log::info!( - "Received {} message ({}) from {}", - topic_str, - message_id, - peer_id, - ); - - // ensure that message is from the known RPCs - if !self.available_nodes.rpc_nodes.contains(&source_peer_id) { - log::warn!( - "Received message from unauthorized source: {}", - source_peer_id - ); - log::debug!("Allowed sources: {:#?}", self.available_nodes.rpc_nodes); - return MessageAcceptance::Ignore; - } + /// Updates the local list of available nodes by refreshing it. + /// Dials the RPC nodes again for better connectivity. 
+ async fn handle_available_nodes_refresh(&mut self) { + log::info!("Refreshing available nodes."); - // first, parse the raw gossipsub message to a prepared message - let message = match self.parse_message_to_prepared_message(message.clone()) { - Ok(message) => message, - Err(e) => { - log::error!("Error parsing message: {:?}", e); - log::debug!("Message: {}", String::from_utf8_lossy(&message.data)); - return MessageAcceptance::Ignore; - } - }; + // refresh available nodes + if let Err(e) = self.available_nodes.populate_with_api().await { + log::error!("Error refreshing available nodes: {:?}", e); + }; - // then handle the prepared message - let handler_result = match topic_str { - WorkflowHandler::LISTEN_TOPIC => { - let compute_result = WorkflowHandler::handle_compute(self, message).await; - match compute_result { - Ok(Either::Left(acceptance)) => Ok(acceptance), - Ok(Either::Right(workflow_message)) => { - if let Err(e) = self.worklow_tx.send(workflow_message).await { - log::error!("Error sending workflow message: {:?}", e); - }; - - Ok(MessageAcceptance::Accept) - } - Err(err) => Err(err), - } - } - PingpongHandler::LISTEN_TOPIC => PingpongHandler::handle_ping(self, message).await, - _ => unreachable!(), // unreachable because of the if condition + // dial all rpc nodes + for rpc_addr in self.available_nodes.rpc_addrs.iter() { + log::debug!("Dialling RPC node: {}", rpc_addr); + if let Err(e) = self.p2p.dial(rpc_addr.clone()).await { + log::warn!("Error dialling RPC node: {:?}", e); }; - - // validate the message based on the result - match handler_result { - Ok(acceptance) => { - return acceptance; - } - Err(err) => { - log::error!("Error handling {} message: {:?}", topic_str, err); - return MessageAcceptance::Ignore; - } - } - } else if std::matches!( - topic_str, - PingpongHandler::RESPONSE_TOPIC | WorkflowHandler::RESPONSE_TOPIC - ) { - // since we are responding to these topics, we might receive messages from other compute nodes - // we can gracefully 
ignore them and propagate it to to others - log::trace!("Ignoring message for topic: {}", topic_str); - return MessageAcceptance::Accept; - } else { - // reject this message as its from a foreign topic - log::warn!("Received message from unexpected topic: {}", topic_str); - return MessageAcceptance::Reject; } } @@ -270,15 +252,28 @@ impl DriaComputeNode { self.subscribe(WorkflowHandler::LISTEN_TOPIC).await?; self.subscribe(WorkflowHandler::RESPONSE_TOPIC).await?; - // main loop, listens for message events in particular - // the underlying p2p client is expected to handle the rest within its own loop + let peer_refresh_duration = Duration::from_secs(PEER_REFRESH_INTERVAL_SECS); + let available_node_refresh_duration = + Duration::from_secs(AVAILABLE_NODES_REFRESH_INTERVAL_SECS); + loop { tokio::select! { + // check peer count every now and then + _ = tokio::time::sleep(peer_refresh_duration) => self.handle_peer_refresh().await, + // available nodes are refreshed every now and then + _ = tokio::time::sleep(available_node_refresh_duration) => self.handle_available_nodes_refresh().await, + // a Workflow message to be published is received from the channel + // this is expected to be sent by the workflow worker publish_msg = self.publish_rx.recv() => { if let Some(result) = publish_msg { WorkflowHandler::handle_publish(self, result).await?; - } + } else { + log::error!("Publish channel closed unexpectedly."); + break; + }; }, + // a GossipSub message is received from the channel + // this is expected to be sent by the p2p client gossipsub_msg = self.message_rx.recv() => { if let Some((peer_id, message_id, message)) = gossipsub_msg { // handle the message, returning a message acceptance for the received one @@ -290,10 +285,12 @@ impl DriaComputeNode { log::error!("Error validating message {}: {:?}", message_id, e); } } else { - log::warn!("Message channel closed."); + log::error!("Message channel closed unexpectedly."); break; }; }, + // check if the cancellation token is 
cancelled + // this is expected to be cancelled by the main thread with signal handling _ = self.cancellation.cancelled() => break, } } @@ -323,12 +320,13 @@ impl DriaComputeNode { Ok(()) } + /// Parses a given raw Gossipsub message to a prepared P2PMessage object. /// This prepared message includes the topic, payload, version and timestamp. /// /// This also checks the signature of the message, expecting a valid signature from admin node. // TODO: move this somewhere? - pub fn parse_message_to_prepared_message(&self, message: Message) -> Result { + pub fn parse_message_to_prepared_message(&self, message: &Message) -> Result { // the received message is expected to use IdentHash for the topic, so we can see the name of the topic immediately. log::debug!("Parsing {} message.", message.topic.as_str()); let message = DKNMessage::try_from(message)?; diff --git a/compute/src/payloads/response.rs b/compute/src/payloads/response.rs index d023c64..b32e5f7 100644 --- a/compute/src/payloads/response.rs +++ b/compute/src/payloads/response.rs @@ -88,26 +88,25 @@ mod tests { MODEL.to_string(), Default::default(), ) - .expect("Should create payload"); + .expect("to create payload"); // decrypt result and compare it to plaintext let ciphertext_bytes = hex::decode(payload.ciphertext).unwrap(); - let result = decrypt(&task_sk.serialize(), &ciphertext_bytes).expect("Could not decrypt"); + let result = decrypt(&task_sk.serialize(), &ciphertext_bytes).expect("to decrypt"); assert_eq!(result, RESULT, "Result mismatch"); // verify signature - let signature_bytes = hex::decode(payload.signature).expect("Should decode"); + let signature_bytes = hex::decode(payload.signature).expect("to decode"); let signature = Signature::parse_standard_slice(&signature_bytes[..64]).unwrap(); let recid = RecoveryId::parse(signature_bytes[64]).unwrap(); let mut preimage = vec![]; preimage.extend_from_slice(task_id.as_bytes()); preimage.extend_from_slice(&result); let message = 
Message::parse(&sha256hash(preimage)); - assert!(verify(&message, &signature, &signer_pk), "Could not verify"); + assert!(verify(&message, &signature, &signer_pk), "could not verify"); // recover verifying key (public key) from signature - let recovered_public_key = - recover(&message, &signature, &recid).expect("Could not recover"); - assert_eq!(signer_pk, recovered_public_key, "Public key mismatch"); + let recovered_public_key = recover(&message, &signature, &recid).expect("to recover"); + assert_eq!(signer_pk, recovered_public_key, "public key mismatch"); } } diff --git a/compute/src/payloads/stats.rs b/compute/src/payloads/stats.rs index 7a16a2b..3263cb8 100644 --- a/compute/src/payloads/stats.rs +++ b/compute/src/payloads/stats.rs @@ -17,6 +17,10 @@ pub struct TaskStats { } impl TaskStats { + pub fn new() -> Self { + Self::default() + } + /// Records the current timestamp within `received_at`. pub fn record_received_at(mut self) -> Self { // can unwrap safely here as UNIX_EPOCH is always smaller than now diff --git a/compute/src/utils/available_nodes/mod.rs b/compute/src/utils/available_nodes/mod.rs index cb0cdc8..620870e 100644 --- a/compute/src/utils/available_nodes/mod.rs +++ b/compute/src/utils/available_nodes/mod.rs @@ -8,9 +8,6 @@ mod statics; use crate::DriaNetworkType; -/// Number of seconds between refreshing the available nodes. -const DEFAULT_REFRESH_INTERVAL_SECS: u64 = 30 * 60; // 30 minutes - impl DriaNetworkType { /// Returns the URL for fetching available nodes w.r.t network type. pub fn get_available_nodes_url(&self) -> &str { @@ -36,7 +33,6 @@ pub struct AvailableNodes { pub rpc_addrs: HashSet, pub last_refreshed: Instant, pub network_type: DriaNetworkType, - pub refresh_interval_secs: u64, } impl AvailableNodes { @@ -49,16 +45,9 @@ impl AvailableNodes { rpc_addrs: HashSet::new(), last_refreshed: Instant::now(), network_type: network, - refresh_interval_secs: DEFAULT_REFRESH_INTERVAL_SECS, } } - /// Sets the refresh interval in seconds. 
- pub fn with_refresh_interval(mut self, interval_secs: u64) -> Self { - self.refresh_interval_secs = interval_secs; - self - } - /// Parses static bootstrap & relay nodes from environment variables. /// /// The environment variables are: @@ -118,11 +107,6 @@ impl AvailableNodes { Ok(()) } - - /// Returns whether enough time has passed since the last refresh. - pub fn can_refresh(&self) -> bool { - self.last_refreshed.elapsed().as_secs() > self.refresh_interval_secs - } } /// Like `parse` of `str` but for vectors. diff --git a/compute/src/utils/available_nodes/statics.rs b/compute/src/utils/available_nodes/statics.rs index b0e3739..d26f079 100644 --- a/compute/src/utils/available_nodes/statics.rs +++ b/compute/src/utils/available_nodes/statics.rs @@ -3,6 +3,7 @@ use dkn_p2p::libp2p::{Multiaddr, PeerId}; impl DriaNetworkType { /// Static bootstrap nodes for Kademlia. + #[inline(always)] pub fn get_static_bootstrap_nodes(&self) -> Vec { match self { DriaNetworkType::Community => [ @@ -18,6 +19,7 @@ impl DriaNetworkType { } /// Static relay nodes for the `P2pCircuit`. + #[inline(always)] pub fn get_static_relay_nodes(&self) -> Vec { match self { DriaNetworkType::Community => [ @@ -33,6 +35,7 @@ impl DriaNetworkType { } /// Static RPC Peer IDs for the Admin RPC. 
+ #[inline(always)] pub fn get_static_rpc_peer_ids(&self) -> Vec { // match self { // DriaNetworkType::Community => [].iter(), diff --git a/compute/src/utils/crypto.rs b/compute/src/utils/crypto.rs index 21a69f1..a5437e9 100644 --- a/compute/src/utils/crypto.rs +++ b/compute/src/utils/crypto.rs @@ -106,24 +106,24 @@ mod tests { #[test] fn test_sign_verify() { let secret_key = - SecretKey::parse_slice(DUMMY_SECRET_KEY).expect("Should parse private key slice."); + SecretKey::parse_slice(DUMMY_SECRET_KEY).expect("to parse private key slice"); // sign the message using the secret key let digest = sha256hash(MESSAGE); - let message = Message::parse_slice(&digest).expect("Should parse message."); + let message = Message::parse_slice(&digest).expect("to parse message"); let (signature, recid) = sign(&message, &secret_key); // recover verifying key (public key) from signature let expected_public_key = PublicKey::from_secret_key(&secret_key); let recovered_public_key = - recover(&message, &signature, &recid).expect("Should recover public key."); + recover(&message, &signature, &recid).expect("to recover public key"); assert_eq!(expected_public_key, recovered_public_key); // verify the signature let public_key = recovered_public_key; assert!( verify(&message, &signature, &public_key), - "Could not verify signature." 
+ "could not verify signature" ); } @@ -131,12 +131,12 @@ mod tests { #[ignore = "run only with profiler if wanted"] fn test_memory_usage() { let secret_key = - SecretKey::parse_slice(DUMMY_SECRET_KEY).expect("Should parse private key slice."); + SecretKey::parse_slice(DUMMY_SECRET_KEY).expect("to parse private key slice"); let public_key = PublicKey::from_secret_key(&secret_key); // sign the message using the secret key let digest = sha256hash(MESSAGE); - let message = Message::parse_slice(&digest).expect("Should parse message."); + let message = Message::parse_slice(&digest).expect("to parse message"); let (signature, _) = sign(&message, &secret_key); // verify signature with context diff --git a/compute/src/utils/filter.rs b/compute/src/utils/filter.rs index 7b4a024..371a066 100644 --- a/compute/src/utils/filter.rs +++ b/compute/src/utils/filter.rs @@ -17,7 +17,7 @@ impl TaskFilter { pub fn contains(&self, address: &[u8]) -> Result { BloomFilter::try_from(self) .map(|filter| filter.contains(address)) - .wrap_err("Could not create filter.") + .wrap_err("could not create filter") } } diff --git a/compute/src/utils/message.rs b/compute/src/utils/message.rs index 8beab5d..a587166 100644 --- a/compute/src/utils/message.rs +++ b/compute/src/utils/message.rs @@ -133,10 +133,10 @@ impl fmt::Display for DKNMessage { } } -impl TryFrom for DKNMessage { +impl TryFrom<&dkn_p2p::libp2p::gossipsub::Message> for DKNMessage { type Error = serde_json::Error; - fn try_from(value: dkn_p2p::libp2p::gossipsub::Message) -> Result { + fn try_from(value: &dkn_p2p::libp2p::gossipsub::Message) -> Result { serde_json::from_slice(&value.data) } } diff --git a/compute/src/utils/misc.rs b/compute/src/utils/misc.rs index 521106e..d7179ef 100644 --- a/compute/src/utils/misc.rs +++ b/compute/src/utils/misc.rs @@ -40,7 +40,7 @@ pub fn address_in_use(addr: &Multiaddr) -> bool { .map(|port| is_port_reachable(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port))) .unwrap_or_else(|| { log::error!( - "Could 
not find any TCP port in the given address: {:?}", + "could not find any TCP port in the given address: {:?}", addr ); false diff --git a/p2p/src/client.rs b/p2p/src/client.rs index 4f48d5f..0c1a098 100644 --- a/p2p/src/client.rs +++ b/p2p/src/client.rs @@ -147,7 +147,6 @@ impl DriaP2PClient { /// /// To terminate, the command channel must be closed. pub async fn run(mut self) { - // FIXME: use refresh peers somewhere loop { tokio::select! { event = self.swarm.select_next_some() => self.handle_event(event).await, diff --git a/p2p/tests/listen_test.rs b/p2p/tests/listen_test.rs index 4fdcd55..d24dd9e 100644 --- a/p2p/tests/listen_test.rs +++ b/p2p/tests/listen_test.rs @@ -4,7 +4,6 @@ use libp2p::Multiaddr; use libp2p_identity::Keypair; use std::{env, str::FromStr}; -// FIXME: not working!!! #[tokio::test] #[ignore = "run this manually"] async fn test_listen_topic_once() -> Result<()> { diff --git a/workflows/src/apis/jina.rs b/workflows/src/apis/jina.rs index 1b664a7..7492c30 100644 --- a/workflows/src/apis/jina.rs +++ b/workflows/src/apis/jina.rs @@ -39,10 +39,10 @@ impl JinaConfig { pub async fn check_optional(&self) -> Result<()> { // check API key let Some(api_key) = &self.api_key else { - log::debug!("Jina API key not found, skipping Jina check"); + log::info!("Jina API key not found, skipping"); return Ok(()); }; - log::info!("Jina API key found, checking Jina service"); + log::info!("Jina API key found {api_key}, checking service"); // make a dummy request to "example.com" let client = Client::new(); diff --git a/workflows/src/apis/serper.rs b/workflows/src/apis/serper.rs index 76340b0..0fb5cde 100644 --- a/workflows/src/apis/serper.rs +++ b/workflows/src/apis/serper.rs @@ -4,8 +4,6 @@ use std::env; use crate::utils::safe_read_env; -/// Makes a search request. -const SERPER_EXAMPLE_ENDPOINT: &str = "https://google.serper.dev/search"; const ENV_VAR_NAME: &str = "SERPER_API_KEY"; /// Serper-specific configurations. 
@@ -45,15 +43,15 @@ impl SerperConfig { pub async fn check_optional(&self) -> Result<()> { // check API key let Some(api_key) = &self.api_key else { - log::debug!("Serper API key not found, skipping Serper check"); + log::info!("Serper API key not found, skipping"); return Ok(()); }; - log::info!("Serper API key found, checking Serper service"); + log::info!("Serper API key found, checking service"); // make a dummy request let client = Client::new(); let request = client - .post(SERPER_EXAMPLE_ENDPOINT) + .post("https://google.serper.dev/search") .header("X-API-KEY", api_key) .header("Content-Type", "application/json") .body("{\"q\": \"Your search query here\"}") diff --git a/workflows/src/bin/tps.rs b/workflows/src/bin/tps.rs index 9ea255c..8284278 100644 --- a/workflows/src/bin/tps.rs +++ b/workflows/src/bin/tps.rs @@ -1,6 +1,6 @@ #[cfg(not(feature = "profiling"))] fn main() { - unimplemented!("This binary requires the 'profiling' feature to be enabled"); + unimplemented!("this binary requires the 'profiling' feature to be enabled"); } #[cfg(feature = "profiling")] diff --git a/workflows/src/providers/gemini.rs b/workflows/src/providers/gemini.rs index 8e97daf..4f694bf 100644 --- a/workflows/src/providers/gemini.rs +++ b/workflows/src/providers/gemini.rs @@ -166,7 +166,7 @@ impl GeminiConfig { response .text() .await - .unwrap_or("Could not get error text as well".to_string()) + .unwrap_or("could not get error text as well".to_string()) )); } log::debug!("Dummy request successful for model {}", model); diff --git a/workflows/src/providers/ollama.rs b/workflows/src/providers/ollama.rs index 905fea2..b4c8be8 100644 --- a/workflows/src/providers/ollama.rs +++ b/workflows/src/providers/ollama.rs @@ -107,6 +107,7 @@ impl OllamaConfig { ); let ollama = Ollama::new(&self.host, self.port); + log::info!("Connecting to Ollama at {}", ollama.url_str()); // fetch local models let local_models = match ollama.list_local_models().await { @@ -128,7 +129,7 @@ impl OllamaConfig { 
if !&local_models.iter().any(|s| s == model) { self.try_pull(&ollama, model.to_owned()) .await - .wrap_err("Could not pull model")?; + .wrap_err("could not pull model")?; } } @@ -139,7 +140,7 @@ impl OllamaConfig { if !local_models.contains(&model.to_string()) { self.try_pull(&ollama, model.to_string()) .await - .wrap_err("Could not pull model")?; + .wrap_err("could not pull model")?; } if self.test_performance(&ollama, &model).await { @@ -191,19 +192,7 @@ impl OllamaConfig { return false; }; - let mut generation_request = - GenerationRequest::new(model.to_string(), TEST_PROMPT.to_string()); - - // FIXME: temporary workaround, can take num threads from outside - if let Ok(num_thread) = std::env::var("OLLAMA_NUM_THREAD") { - generation_request = generation_request.options( - GenerationOptions::default().num_thread( - num_thread - .parse() - .expect("num threads should be a positive integer"), - ), - ); - } + let generation_request = GenerationRequest::new(model.to_string(), TEST_PROMPT.to_string()); // then, run a sample generation with timeout and measure tps tokio::select! 
{ diff --git a/workflows/src/providers/openai.rs b/workflows/src/providers/openai.rs index 33f427e..0ff1c11 100644 --- a/workflows/src/providers/openai.rs +++ b/workflows/src/providers/openai.rs @@ -110,7 +110,7 @@ impl OpenAIConfig { response .text() .await - .unwrap_or("Could not get error text as well".to_string()) + .unwrap_or("could not get error text as well".to_string()) )) } else { let openai_models = response.json::().await?; @@ -153,7 +153,7 @@ impl OpenAIConfig { response .text() .await - .unwrap_or("Could not get error text as well".to_string()) + .unwrap_or("could not get error text as well".to_string()) )); } log::debug!("Dummy request successful for model {}", model); diff --git a/workflows/src/providers/openrouter.rs b/workflows/src/providers/openrouter.rs index 5f56d02..a1c5cd8 100644 --- a/workflows/src/providers/openrouter.rs +++ b/workflows/src/providers/openrouter.rs @@ -101,7 +101,7 @@ impl OpenRouterConfig { response .text() .await - .unwrap_or("Could not get error text as well".to_string()) + .unwrap_or("could not get error text as well".to_string()) )); } log::debug!("Dummy request successful for model {}", model); From 304edc6b8c2eb55429c9156ac6d62f49ef3478aa Mon Sep 17 00:00:00 2001 From: erhant Date: Thu, 28 Nov 2024 13:13:06 +0300 Subject: [PATCH 05/16] small rfks --- compute/src/node.rs | 85 +++++++++++++------------------ compute/src/utils/message.rs | 37 +++++++++++--- workflows/src/providers/ollama.rs | 3 +- 3 files changed, 65 insertions(+), 60 deletions(-) diff --git a/compute/src/node.rs b/compute/src/node.rs index e6ca858..94121aa 100644 --- a/compute/src/node.rs +++ b/compute/src/node.rs @@ -5,7 +5,7 @@ use dkn_p2p::{ }, DriaP2PClient, DriaP2PCommander, DriaP2PProtocol, }; -use eyre::{eyre, Result}; +use eyre::Result; use tokio::{sync::mpsc, time::Duration}; use tokio_util::{either::Either, sync::CancellationToken}; @@ -165,7 +165,10 @@ impl DriaComputeNode { } // first, parse the raw gossipsub message to a prepared message - 
let message = match self.parse_message_to_prepared_message(&message) { + let message = match DKNMessage::try_from_gossipsub_message( + &message, + &self.config.admin_public_key, + ) { Ok(message) => message, Err(e) => { log::error!("Error parsing message: {:?}", e); @@ -193,7 +196,7 @@ impl DriaComputeNode { PingpongHandler::LISTEN_TOPIC => { PingpongHandler::handle_ping(self, &message).await } - _ => unreachable!(), // unreachable because of the `match` above + _ => unreachable!("unreachable due to match expression"), }; // validate the message based on the result @@ -216,46 +219,20 @@ impl DriaComputeNode { } } - /// Peer refresh simply reports the peer count to the user. - async fn handle_peer_refresh(&self) { - match self.p2p.peer_counts().await { - Ok((mesh, all)) => log::info!("Peer Count (mesh/all): {} / {}", mesh, all), - Err(e) => log::error!("Error getting peer counts: {:?}", e), - } - } - - /// Updates the local list of available nodes by refreshing it. - /// Dials the RPC nodes again for better connectivity. - async fn handle_available_nodes_refresh(&mut self) { - log::info!("Refreshing available nodes."); - - // refresh available nodes - if let Err(e) = self.available_nodes.populate_with_api().await { - log::error!("Error refreshing available nodes: {:?}", e); - }; - - // dial all rpc nodes - for rpc_addr in self.available_nodes.rpc_addrs.iter() { - log::debug!("Dialling RPC node: {}", rpc_addr); - if let Err(e) = self.p2p.dial(rpc_addr.clone()).await { - log::warn!("Error dialling RPC node: {:?}", e); - }; - } - } - /// Runs the main loop of the compute node. /// This method is not expected to return until cancellation occurs. 
pub async fn run(&mut self) -> Result<()> { + // prepare durations for sleeps + let peer_refresh_duration = Duration::from_secs(PEER_REFRESH_INTERVAL_SECS); + let available_node_refresh_duration = + Duration::from_secs(AVAILABLE_NODES_REFRESH_INTERVAL_SECS); + // subscribe to topics self.subscribe(PingpongHandler::LISTEN_TOPIC).await?; self.subscribe(PingpongHandler::RESPONSE_TOPIC).await?; self.subscribe(WorkflowHandler::LISTEN_TOPIC).await?; self.subscribe(WorkflowHandler::RESPONSE_TOPIC).await?; - let peer_refresh_duration = Duration::from_secs(PEER_REFRESH_INTERVAL_SECS); - let available_node_refresh_duration = - Duration::from_secs(AVAILABLE_NODES_REFRESH_INTERVAL_SECS); - loop { tokio::select! { // check peer count every now and then @@ -321,25 +298,31 @@ impl DriaComputeNode { Ok(()) } - /// Parses a given raw Gossipsub message to a prepared P2PMessage object. - /// This prepared message includes the topic, payload, version and timestamp. - /// - /// This also checks the signature of the message, expecting a valid signature from admin node. - // TODO: move this somewhere? - pub fn parse_message_to_prepared_message(&self, message: &Message) -> Result { - // the received message is expected to use IdentHash for the topic, so we can see the name of the topic immediately. - log::debug!("Parsing {} message.", message.topic.as_str()); - let message = DKNMessage::try_from(message)?; - log::debug!("Parsed: {}", message); - - // check dria signature - // NOTE: when we have many public keys, we should check the signature against all of them - // TODO: public key here will be given dynamically - if !message.is_signed(&self.config.admin_public_key)? { - return Err(eyre!("Invalid signature.")); + /// Peer refresh simply reports the peer count to the user. 
+ async fn handle_peer_refresh(&self) { + match self.p2p.peer_counts().await { + Ok((mesh, all)) => log::info!("Peer Count (mesh/all): {} / {}", mesh, all), + Err(e) => log::error!("Error getting peer counts: {:?}", e), } + } - Ok(message) + /// Updates the local list of available nodes by refreshing it. + /// Dials the RPC nodes again for better connectivity. + async fn handle_available_nodes_refresh(&mut self) { + log::info!("Refreshing available nodes."); + + // refresh available nodes + if let Err(e) = self.available_nodes.populate_with_api().await { + log::error!("Error refreshing available nodes: {:?}", e); + }; + + // dial all rpc nodes + for rpc_addr in self.available_nodes.rpc_addrs.iter() { + log::debug!("Dialling RPC node: {}", rpc_addr); + if let Err(e) = self.p2p.dial(rpc_addr.clone()).await { + log::warn!("Error dialling RPC node: {:?}", e); + }; + } } } diff --git a/compute/src/utils/message.rs b/compute/src/utils/message.rs index a587166..0c1ba05 100644 --- a/compute/src/utils/message.rs +++ b/compute/src/utils/message.rs @@ -6,7 +6,7 @@ use crate::DRIA_COMPUTE_NODE_VERSION; use base64::{prelude::BASE64_STANDARD, Engine}; use core::fmt; use ecies::PublicKey; -use eyre::{Context, Result}; +use eyre::{eyre, Context, Result}; use libsecp256k1::{verify, Message, SecretKey, Signature}; use serde::{Deserialize, Serialize}; @@ -44,7 +44,7 @@ impl DKNMessage { /// /// - `data` is given as bytes, it is encoded into base64 to make up the `payload` within. /// - `topic` is the name of the [gossipsub topic](https://docs.libp2p.io/concepts/pubsub/overview/). - pub fn new(data: impl AsRef<[u8]>, topic: &str) -> Self { + pub(crate) fn new(data: impl AsRef<[u8]>, topic: &str) -> Self { Self { payload: BASE64_STANDARD.encode(data), topic: topic.to_string(), @@ -55,7 +55,7 @@ impl DKNMessage { } /// Creates a new Message by signing the SHA256 of the payload, and prepending the signature. 
- pub fn new_signed(data: impl AsRef<[u8]>, topic: &str, signing_key: &SecretKey) -> Self { + pub(crate) fn new_signed(data: impl AsRef<[u8]>, topic: &str, signing_key: &SecretKey) -> Self { // sign the SHA256 hash of the data let signature_bytes = sign_bytes_recoverable(&sha256hash(data.as_ref()), signing_key); @@ -69,19 +69,19 @@ impl DKNMessage { } /// Sets the identity of the message. - pub fn with_identity(mut self, identity: String) -> Self { + pub(crate) fn with_identity(mut self, identity: String) -> Self { self.identity = identity; self } /// Decodes the base64 payload into bytes. #[inline(always)] - pub fn decode_payload(&self) -> Result, base64::DecodeError> { + pub(crate) fn decode_payload(&self) -> Result, base64::DecodeError> { BASE64_STANDARD.decode(&self.payload) } /// Decodes and parses the base64 payload into JSON for the provided type `T`. - pub fn parse_payload Deserialize<'a>>(&self, signed: bool) -> Result { + pub(crate) fn parse_payload Deserialize<'a>>(&self, signed: bool) -> Result { let payload = self.decode_payload()?; let body = if signed { @@ -96,7 +96,7 @@ impl DKNMessage { } /// Checks if the payload is signed by the given public key. - pub fn is_signed(&self, public_key: &PublicKey) -> Result { + pub(crate) fn is_signed(&self, public_key: &PublicKey) -> Result { // decode base64 payload let data = self.decode_payload()?; @@ -116,6 +116,29 @@ impl DKNMessage { let digest = Message::parse(&sha256hash(body)); Ok(verify(&digest, &signature, public_key)) } + + /// Tries to parse the given gossipsub message into a DKNMessage. + /// + /// This prepared message includes the topic, payload, version and timestamp. + /// It also checks the signature of the message, expecting a valid signature from admin node. 
+ pub(crate) fn try_from_gossipsub_message( + gossipsub_message: &dkn_p2p::libp2p::gossipsub::Message, + public_key: &libsecp256k1::PublicKey, + ) -> Result { + // the received message is expected to use IdentHash for the topic, so we can see the name of the topic immediately. + log::debug!("Parsing {} message.", gossipsub_message.topic.as_str()); + let message = serde_json::from_slice::(&gossipsub_message.data) + .wrap_err("could not parse message")?; + log::debug!("Parsed: {}", message); + + // check dria signature + // NOTE: when we have many public keys, we should check the signature against all of them + if !message.is_signed(&public_key)? { + return Err(eyre!("Invalid signature.")); + } + + Ok(message) + } } impl fmt::Display for DKNMessage { diff --git a/workflows/src/providers/ollama.rs b/workflows/src/providers/ollama.rs index b4c8be8..b17ff39 100644 --- a/workflows/src/providers/ollama.rs +++ b/workflows/src/providers/ollama.rs @@ -4,7 +4,6 @@ use ollama_workflows::{ generation::{ completion::request::GenerationRequest, embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest}, - options::GenerationOptions, }, Ollama, }, @@ -171,7 +170,7 @@ impl OllamaConfig { // otherwise, give error log::error!("Please download missing model with: ollama pull {}", model); log::error!("Or, set OLLAMA_AUTO_PULL=true to pull automatically."); - Err(eyre!("Required model not pulled in Ollama.")) + Err(eyre!("required model not pulled in Ollama")) } } From a17b54e8208cd795b0a8c1e3b8e3e8ae1958bc2c Mon Sep 17 00:00:00 2001 From: erhant Date: Thu, 28 Nov 2024 15:08:33 +0300 Subject: [PATCH 06/16] added separate ollama worker, small rfks --- .github/workflows/tests.yml | 3 ++ compute/src/handlers/workflow.rs | 43 +++++++++-------- compute/src/main.rs | 20 +++++--- compute/src/node.rs | 65 ++++++++++++++++++------- compute/src/utils/message.rs | 2 +- compute/src/workers/workflow.rs | 81 +++++++++++++++++++++++++------- p2p/src/client.rs | 4 +- 7 files changed, 153 
insertions(+), 65 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index cee5eff..0b1d4ea 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -32,3 +32,6 @@ jobs: - name: Run tests run: cargo test --workspace + + - name: Run linter + run: cargo clippy --workspace diff --git a/compute/src/handlers/workflow.rs b/compute/src/handlers/workflow.rs index c0ae7fe..e480a8a 100644 --- a/compute/src/handlers/workflow.rs +++ b/compute/src/handlers/workflow.rs @@ -32,7 +32,7 @@ impl WorkflowHandler { pub(crate) async fn handle_compute( node: &mut DriaComputeNode, compute_message: &DKNMessage, - ) -> Result> { + ) -> Result> { let stats = TaskStats::new().record_received_at(); let task = compute_message .parse_payload::>(true) @@ -78,14 +78,17 @@ impl WorkflowHandler { log::info!("Using model {} for task {}", model_name, task.task_id); // prepare workflow executor - let executor = if model_provider == ModelProvider::Ollama { - Executor::new_at( - model, - &node.config.workflows.ollama.host, - node.config.workflows.ollama.port, + let (executor, batchable) = if model_provider == ModelProvider::Ollama { + ( + Executor::new_at( + model, + &node.config.workflows.ollama.host, + node.config.workflows.ollama.port, + ), + false, ) } else { - Executor::new(model) + (Executor::new(model), true) }; // prepare entry from prompt @@ -97,15 +100,18 @@ impl WorkflowHandler { // get workflow as well let workflow = task.input.workflow; - Ok(Either::Right(WorkflowsWorkerInput { - entry, - executor, - workflow, - model_name, - task_id: task.task_id, - public_key: task_public_key, - stats, - })) + Ok(Either::Right(( + WorkflowsWorkerInput { + entry, + executor, + workflow, + model_name, + task_id: task.task_id, + public_key: task_public_key, + stats, + }, + batchable, + ))) } pub(crate) async fn handle_publish( @@ -123,16 +129,15 @@ impl WorkflowHandler { task.model_name, task.stats.record_published_at(), )?; + + // convert payload to 
message let payload_str = serde_json::to_string(&payload) .wrap_err("could not serialize response payload")?; - - // prepare signed message log::debug!( "Publishing result for task {}\n{}", task.task_id, payload_str ); - DKNMessage::new(payload_str, Self::RESPONSE_TOPIC) } Err(err) => { diff --git a/compute/src/main.rs b/compute/src/main.rs index 7988b52..0a0e1ca 100644 --- a/compute/src/main.rs +++ b/compute/src/main.rs @@ -31,6 +31,7 @@ async fn main() -> Result<()> { let token = CancellationToken::new(); let cancellation_token = token.clone(); tokio::spawn(async move { + // the timeout is done for profiling only, and should not be used in production if let Ok(Ok(duration_secs)) = env::var("DKN_EXIT_TIMEOUT").map(|s| s.to_string().parse::()) { @@ -75,15 +76,17 @@ async fn main() -> Result<()> { } let node_token = token.clone(); - let (mut node, p2p, mut workflows) = DriaComputeNode::new(config, node_token).await?; + let (mut node, p2p, mut worker_batch, mut worker_single) = + DriaComputeNode::new(config, node_token).await?; - // launch the p2p in a separate thread log::info!("Spawning peer-to-peer client thread."); let p2p_handle = tokio::spawn(async move { p2p.run().await }); - // launch the workflows in a separate thread - log::info!("Spawning workflows worker thread."); - let workflows_handle = tokio::spawn(async move { workflows.run().await }); + log::info!("Spawning workflows batch worker thread."); + let worker_batch_handle = tokio::spawn(async move { worker_batch.run_batch().await }); + + log::info!("Spawning workflows single worker thread."); + let worker_single_handle = tokio::spawn(async move { worker_single.run().await }); // launch the node in a separate thread log::info!("Spawning compute node thread."); @@ -98,8 +101,11 @@ async fn main() -> Result<()> { if let Err(err) = node_handle.await { log::error!("Node handle error: {}", err); }; - if let Err(err) = workflows_handle.await { - log::error!("Workflows handle error: {}", err); + if let Err(err) = 
worker_single_handle.await { + log::error!("Workflows single worker handle error: {}", err); + }; + if let Err(err) = worker_batch_handle.await { + log::error!("Workflows batch worker handle error: {}", err); }; if let Err(err) = p2p_handle.await { log::error!("P2P handle error: {}", err); diff --git a/compute/src/node.rs b/compute/src/node.rs index 94121aa..7029274 100644 --- a/compute/src/node.rs +++ b/compute/src/node.rs @@ -26,10 +26,16 @@ pub struct DriaComputeNode { pub p2p: DriaP2PCommander, pub available_nodes: AvailableNodes, pub cancellation: CancellationToken, - // channels + /// Gossipsub message receiver. message_rx: mpsc::Receiver<(PeerId, MessageId, Message)>, - worklow_tx: mpsc::Sender, - publish_rx: mpsc::Receiver, + /// Workflow transmitter to send batchable tasks. + workflow_batch_tx: mpsc::Sender, + /// Publish receiver to receive messages to be published. + publish_batch_rx: mpsc::Receiver, + /// Workflow transmitter to send single tasks. + workflow_single_tx: mpsc::Sender, + /// Publish receiver to receive messages to be published. 
+ publish_single_rx: mpsc::Receiver, } impl DriaComputeNode { @@ -39,7 +45,12 @@ impl DriaComputeNode { pub async fn new( config: DriaComputeNodeConfig, cancellation: CancellationToken, - ) -> Result<(DriaComputeNode, DriaP2PClient, WorkflowsWorker)> { + ) -> Result<( + DriaComputeNode, + DriaP2PClient, + WorkflowsWorker, + WorkflowsWorker, + )> { // create the keypair from secret key let keypair = secret_to_keypair(&config.secret_key); @@ -66,10 +77,10 @@ impl DriaComputeNode { protocol, )?; - // create workflow worker - let (worklow_tx, workflow_rx) = mpsc::channel(256); - let (publish_tx, publish_rx) = mpsc::channel(256); - let workflows_worker = WorkflowsWorker::new(workflow_rx, publish_tx); + // create workflow workers + let (workflows_batch_worker, workflow_batch_tx, publish_batch_rx) = WorkflowsWorker::new(); + let (workflows_single_worker, workflow_single_tx, publish_single_rx) = + WorkflowsWorker::new(); Ok(( DriaComputeNode { @@ -78,11 +89,14 @@ impl DriaComputeNode { cancellation, available_nodes, message_rx, - worklow_tx, - publish_rx, + workflow_batch_tx, + publish_batch_rx, + workflow_single_tx, + publish_single_rx, }, p2p_client, - workflows_worker, + workflows_batch_worker, + workflows_single_worker, )) } @@ -164,7 +178,7 @@ impl DriaComputeNode { return MessageAcceptance::Ignore; } - // first, parse the raw gossipsub message to a prepared message + // parse the raw gossipsub message to a prepared DKN message let message = match DKNMessage::try_from_gossipsub_message( &message, &self.config.admin_public_key, @@ -177,19 +191,25 @@ impl DriaComputeNode { } }; - // then handle the prepared message + // handle the DKN message with respect to the topic let handler_result = match message.topic.as_str() { WorkflowHandler::LISTEN_TOPIC => { match WorkflowHandler::handle_compute(self, &message).await { + // we got acceptance, so something was not right about the workflow and we can ignore it Ok(Either::Left(acceptance)) => Ok(acceptance), - 
Ok(Either::Right(workflow_message)) => { - if let Err(e) = self.worklow_tx.send(workflow_message).await { + // we got the parsed workflow itself, send to a worker thread w.r.t batchable + Ok(Either::Right((workflow_message, batchable))) => { + if let Err(e) = match batchable { + true => self.workflow_batch_tx.send(workflow_message).await, + false => self.workflow_single_tx.send(workflow_message).await, + } { log::error!("Error sending workflow message: {:?}", e); }; // accept the message in case others may be included in the filter as well Ok(MessageAcceptance::Accept) } + // something went wrong, handle this outside Err(err) => Err(err), } } @@ -241,7 +261,16 @@ impl DriaComputeNode { _ = tokio::time::sleep(available_node_refresh_duration) => self.handle_available_nodes_refresh().await, // a Workflow message to be published is received from the channel // this is expected to be sent by the workflow worker - publish_msg = self.publish_rx.recv() => { + publish_msg = self.publish_batch_rx.recv() => { + if let Some(result) = publish_msg { + WorkflowHandler::handle_publish(self, result).await?; + } else { + log::error!("Publish channel closed unexpectedly."); + break; + }; + }, + // TODO: make the both receivers handled together somehow + publish_msg = self.publish_single_rx.recv() => { if let Some(result) = publish_msg { WorkflowHandler::handle_publish(self, result).await?; } else { @@ -293,7 +322,7 @@ impl DriaComputeNode { self.message_rx.close(); log::debug!("Closing publish channel."); - self.publish_rx.close(); + self.publish_batch_rx.close(); Ok(()) } @@ -339,7 +368,7 @@ mod tests { // create node let cancellation = CancellationToken::new(); - let (mut node, p2p, _) = + let (mut node, p2p, _, _) = DriaComputeNode::new(DriaComputeNodeConfig::default(), cancellation.clone()) .await .expect("should create node"); diff --git a/compute/src/utils/message.rs b/compute/src/utils/message.rs index 0c1ba05..bbc664d 100644 --- a/compute/src/utils/message.rs +++ 
b/compute/src/utils/message.rs @@ -133,7 +133,7 @@ impl DKNMessage { // check dria signature // NOTE: when we have many public keys, we should check the signature against all of them - if !message.is_signed(&public_key)? { + if !message.is_signed(public_key)? { return Err(eyre!("Invalid signature.")); } diff --git a/compute/src/workers/workflow.rs b/compute/src/workers/workflow.rs index 8053af2..cc02a1e 100644 --- a/compute/src/workers/workflow.rs +++ b/compute/src/workers/workflow.rs @@ -25,50 +25,91 @@ pub struct WorkflowsWorkerOutput { } pub struct WorkflowsWorker { - worklow_rx: mpsc::Receiver, + workflow_rx: mpsc::Receiver, publish_tx: mpsc::Sender, } +const WORKFLOW_CHANNEL_BUFSIZE: usize = 1024; +const PUBLISH_CHANNEL_BUFSIZE: usize = 1024; + impl WorkflowsWorker { /// Batch size that defines how many tasks can be executed in parallel at once. /// IMPORTANT NOTE: `run` function is designed to handle the batch size here specifically, /// if there are more tasks than the batch size, the function will panic. const BATCH_SIZE: usize = 8; - pub fn new( - worklow_rx: mpsc::Receiver, - publish_tx: mpsc::Sender, - ) -> Self { - Self { - worklow_rx, - publish_tx, - } + /// Creates a worker and returns the sender and receiver for the worker. + pub fn new() -> ( + WorkflowsWorker, + mpsc::Sender, + mpsc::Receiver, + ) { + let (workflow_tx, workflow_rx) = mpsc::channel(WORKFLOW_CHANNEL_BUFSIZE); + let (publish_tx, publish_rx) = mpsc::channel(PUBLISH_CHANNEL_BUFSIZE); + + ( + Self { + workflow_rx, + publish_tx, + }, + workflow_tx, + publish_rx, + ) } + fn shutdown(&mut self) { + log::warn!("Closing workflows worker."); + self.workflow_rx.close(); + } + + /// Launches the thread that can process tasks one by one. + /// This function will block until the channel is closed. + /// + /// It is suitable for task streams that consume local resources, unlike API calls. 
pub async fn run(&mut self) { + loop { + let task = self.workflow_rx.recv().await; + + let result = if let Some(task) = task { + log::info!("Processing single workflow for task {}", task.task_id); + WorkflowsWorker::execute(task).await + } else { + return self.shutdown(); + }; + + if let Err(e) = self.publish_tx.send(result).await { + log::error!("Error sending workflow result: {}", e); + } + } + } + + /// Launches the thread that can process tasks in batches. + /// This function will block until the channel is closed. + /// + /// It is suitable for task streams that make use of API calls, unlike Ollama-like + /// tasks that consumes local resources and would not make sense to run in parallel. + pub async fn run_batch(&mut self) { loop { // get tasks in batch from the channel - let mut batch_vec = Vec::new(); + let mut task_buffer = Vec::new(); let num_tasks = self - .worklow_rx - .recv_many(&mut batch_vec, Self::BATCH_SIZE) + .workflow_rx + .recv_many(&mut task_buffer, Self::BATCH_SIZE) .await; debug_assert!( num_tasks <= Self::BATCH_SIZE, "drain cant be larger than batch size" ); // TODO: just to be sure, can be removed later - debug_assert_eq!(num_tasks, batch_vec.len()); + debug_assert_eq!(num_tasks, task_buffer.len()); if num_tasks == 0 { - log::warn!("Closing workflows worker."); - self.worklow_rx.close(); - return; + return self.shutdown(); } // process the batch - let mut batch = batch_vec.into_iter(); log::info!("Processing {} workflows in batch", num_tasks); + let mut batch = task_buffer.into_iter(); let results = match num_tasks { 1 => { let r0 = WorkflowsWorker::execute(batch.next().unwrap()).await; @@ -145,7 +186,11 @@ impl WorkflowsWorker { vec![r0, r1, r2, r3, r4, r5, r6, r7] } _ => { - unreachable!("drain cant be larger than batch size"); + unreachable!( + "drain cant be larger than batch size ({} > {})", + num_tasks, + Self::BATCH_SIZE + ); } }; diff --git a/p2p/src/client.rs b/p2p/src/client.rs index 0c1a098..561d52b 100644 --- a/p2p/src/client.rs 
+++ b/p2p/src/client.rs @@ -31,9 +31,9 @@ pub struct DriaP2PClient { /// Number of seconds before an idle connection is closed. const IDLE_CONNECTION_TIMEOUT_SECS: u64 = 60; /// Buffer size for command channel. -const COMMAND_CHANNEL_BUFSIZE: usize = 256; +const COMMAND_CHANNEL_BUFSIZE: usize = 1024; /// Buffer size for events channel. -const MSG_CHANNEL_BUFSIZE: usize = 256; +const MSG_CHANNEL_BUFSIZE: usize = 1024; impl DriaP2PClient { /// Creates a new P2P client with the given keypair and listen address. From e5b26715067aa059f2fb90d0bd9e757fd1275fcc Mon Sep 17 00:00:00 2001 From: erhant Date: Thu, 28 Nov 2024 15:53:19 +0300 Subject: [PATCH 07/16] add task report to heartbeat, doc fixes --- Makefile | 2 +- compute/src/handlers/pingpong.rs | 3 + compute/src/node.rs | 8 +++ p2p/README.md | 119 ++++++++++++++++++++++--------- p2p/tests/listen_test.rs | 4 +- 5 files changed, 100 insertions(+), 36 deletions(-) diff --git a/Makefile b/Makefile index 2091c87..f0a7baa 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ # load .env ifneq (,$(wildcard ./.env)) - include .env + include ./.env export endif diff --git a/compute/src/handlers/pingpong.rs b/compute/src/handlers/pingpong.rs index 02831c2..b344768 100644 --- a/compute/src/handlers/pingpong.rs +++ b/compute/src/handlers/pingpong.rs @@ -22,6 +22,8 @@ struct PingpongResponse { pub(crate) uuid: String, pub(crate) models: Vec<(ModelProvider, Model)>, pub(crate) timestamp: u128, + /// Number of tasks in the channel currently, `single` and `batch`. 
+ pub(crate) tasks: [usize; 2], } impl PingpongHandler { @@ -64,6 +66,7 @@ impl PingpongHandler { uuid: pingpong.uuid.clone(), models: node.config.workflows.models.clone(), timestamp: get_current_time_nanos(), + tasks: node.task_count(), }; // publish message diff --git a/compute/src/node.rs b/compute/src/node.rs index 7029274..33a70cc 100644 --- a/compute/src/node.rs +++ b/compute/src/node.rs @@ -122,6 +122,14 @@ impl DriaComputeNode { Ok(()) } + /// Returns the task count within the channels, `single` and `batch`. + pub fn task_count(&self) -> [usize; 2] { + [ + self.workflow_single_tx.max_capacity() - self.workflow_single_tx.capacity(), + self.workflow_batch_tx.max_capacity() - self.workflow_batch_tx.capacity(), + ] + } + /// Publishes a given message to the network w.r.t the topic of it. /// /// Internally, identity is attached to the the message which is then JSON serialized to bytes diff --git a/p2p/README.md b/p2p/README.md index 72e0aaf..fb5463b 100644 --- a/p2p/README.md +++ b/p2p/README.md @@ -12,7 +12,92 @@ dkn-p2p = { git = "https://github.com/firstbatchxyz/dkn-compute-node" } ## Usage -The P2P client is expected to be run within a separate thread, and it has two types of interactions: +The peer-to-peer client, when created, returns 3 things: + +- `client`: the actual peer-to-peer client that should be run in a **separate thread**. +- `commander`: a small client that exposes peer-to-peer functions with oneshot channels, so that we can communicate with the client in another thread. +- `channel`: a message channel receiver, it is expected to handle GossipSub messages that are handled & sent by the client. 
+ +### Client + +Here is an example where we create the said entities: + +```rs +use dkn_p2p::{DriaP2PClient, DriaP2PProtocol}; + +// your wallet, or something random maybe +let keypair = Keypair::generate_secp256k1(); + +// your listen address +let addr = Multiaddr::from_str("/ip4/0.0.0.0/tcp/4001")?; + +// static bootstrap & relay & rpc addresses +let bootstraps = vec![Multiaddr::from_str( + "some-multiaddrs-here" +)?]; +let relays = vec![Multiaddr::from_str( + "some-multiaddrs-here" +)?]; +let rpcs = vec![Multiaddr::from_str( + "some-multiaddrs-here" +)?]; + +let protocol = "0.2"; + +// `new` returns 3 things: +// - p2p client itself, to be given to a thread +// - p2p commander, a small client to be able to speak with the p2p in another thread +// - `msg_rx`, the channel to listen for gossipsub messages +let (client, mut commander, mut msg_rx) = DriaP2PClient::new( + keypair, + addr, + bootstraps, + relays, + rpc, + protocol +)?; +``` + +Now, you can give the peer-to-peer client to a thread and store its handle: + +```rs +let task_handle = tokio::spawn(async move { client.run().await }); +``` + +This task handle should be `await`'ed at the end of the program to ensure thread has exited correctly. + +### Commander + +You can communicate with this thread using the `commander` entity. For example, here is how one would subscribe to a topic: + +```rs +commander + .subscribe("your-topic") + .await + .expect("could not subscribe"); +``` + +### Channel + +The message channel should be handled with `recv` (or `recv_many` to process in batches) to process the GossipSub messages. + +```rs +loop { + match msg_rx.recv().await { + Some(msg) => { + todo!("handle stuff") + } + None => { + todo!("channel closed"); + break + } + } +} +``` + +### Interactions + +Here is how the whole thing works in a bit more detail: - **Events**: When a message is received within the Swarm event handler, it is returned via a `mpsc` channel. 
Here, the p2p is `Sender` and your application must be the `Receiver`. The client handles many events, and only sends GossipSub message receipts via this channel so that the application can handle them however they would like. @@ -54,35 +139,3 @@ sequenceDiagram P ->> C: o_tx.send(output) deactivate P ``` - - diff --git a/p2p/tests/listen_test.rs b/p2p/tests/listen_test.rs index d24dd9e..c67db71 100644 --- a/p2p/tests/listen_test.rs +++ b/p2p/tests/listen_test.rs @@ -29,7 +29,7 @@ async fn test_listen_topic_once() -> Result<()> { .expect("could not create p2p client"); // spawn task - let p2p_task = tokio::spawn(async move { client.run().await }); + let task_handle = tokio::spawn(async move { client.run().await }); // subscribe to the given topic commander @@ -61,7 +61,7 @@ async fn test_listen_topic_once() -> Result<()> { msg_rx.close(); log::info!("Waiting for p2p task to finish..."); - p2p_task.await?; + task_handle.await?; log::info!("Done!"); Ok(()) From 5629e1ae361e1146f1750179d8ae0d50c4b061e1 Mon Sep 17 00:00:00 2001 From: erhant Date: Thu, 28 Nov 2024 16:34:45 +0300 Subject: [PATCH 08/16] slight renames, fix `sleep` with `interval` --- compute/src/handlers/pingpong.rs | 2 +- compute/src/node.rs | 24 +++++++++++++++--------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/compute/src/handlers/pingpong.rs b/compute/src/handlers/pingpong.rs index b344768..1c83779 100644 --- a/compute/src/handlers/pingpong.rs +++ b/compute/src/handlers/pingpong.rs @@ -66,7 +66,7 @@ impl PingpongHandler { uuid: pingpong.uuid.clone(), models: node.config.workflows.models.clone(), timestamp: get_current_time_nanos(), - tasks: node.task_count(), + tasks: node.get_active_task_count(), }; // publish message diff --git a/compute/src/node.rs b/compute/src/node.rs index 33a70cc..3f1ec7d 100644 --- a/compute/src/node.rs +++ b/compute/src/node.rs @@ -16,8 +16,8 @@ use crate::{ workers::workflow::{WorkflowsWorker, WorkflowsWorkerInput, WorkflowsWorkerOutput}, }; -/// Number 
of seconds between refreshing the Kademlia DHT. -const PEER_REFRESH_INTERVAL_SECS: u64 = 30; +/// Number of seconds between refreshing for diagnostic prints. +const DIAGNOSTIC_REFRESH_INTERVAL_SECS: u64 = 30; /// Number of seconds between refreshing the available nodes. const AVAILABLE_NODES_REFRESH_INTERVAL_SECS: u64 = 30 * 60; // 30 minutes @@ -123,7 +123,7 @@ impl DriaComputeNode { } /// Returns the task count within the channels, `single` and `batch`. - pub fn task_count(&self) -> [usize; 2] { + pub fn get_active_task_count(&self) -> [usize; 2] { [ self.workflow_single_tx.max_capacity() - self.workflow_single_tx.capacity(), self.workflow_batch_tx.max_capacity() - self.workflow_batch_tx.capacity(), @@ -251,9 +251,10 @@ impl DriaComputeNode { /// This method is not expected to return until cancellation occurs. pub async fn run(&mut self) -> Result<()> { // prepare durations for sleeps - let peer_refresh_duration = Duration::from_secs(PEER_REFRESH_INTERVAL_SECS); - let available_node_refresh_duration = - Duration::from_secs(AVAILABLE_NODES_REFRESH_INTERVAL_SECS); + let mut peer_refresh_interval = + tokio::time::interval(Duration::from_secs(DIAGNOSTIC_REFRESH_INTERVAL_SECS)); + let mut available_node_refresh_interval = + tokio::time::interval(Duration::from_secs(AVAILABLE_NODES_REFRESH_INTERVAL_SECS)); // subscribe to topics self.subscribe(PingpongHandler::LISTEN_TOPIC).await?; @@ -264,9 +265,9 @@ impl DriaComputeNode { loop { tokio::select! 
{ // check peer count every now and then - _ = tokio::time::sleep(peer_refresh_duration) => self.handle_peer_refresh().await, + _ = peer_refresh_interval.tick() => self.handle_diagnostic_refresh().await, // available nodes are refreshed every now and then - _ = tokio::time::sleep(available_node_refresh_duration) => self.handle_available_nodes_refresh().await, + _ = available_node_refresh_interval.tick() => self.handle_available_nodes_refresh().await, // a Workflow message to be published is received from the channel // this is expected to be sent by the workflow worker publish_msg = self.publish_batch_rx.recv() => { @@ -336,11 +337,16 @@ impl DriaComputeNode { } /// Peer refresh simply reports the peer count to the user. - async fn handle_peer_refresh(&self) { + async fn handle_diagnostic_refresh(&self) { + // print peer counts match self.p2p.peer_counts().await { Ok((mesh, all)) => log::info!("Peer Count (mesh/all): {} / {}", mesh, all), Err(e) => log::error!("Error getting peer counts: {:?}", e), } + + // print task counts + let [single, batch] = self.get_active_task_count(); + log::info!("Active Task Count (single/batch): {} / {}", single, batch); } /// Updates the local list of available nodes by refreshing it. 
From aa1d995589b8232a75dd53665f8a6ba772d1bca3 Mon Sep 17 00:00:00 2001 From: erhant Date: Thu, 28 Nov 2024 18:20:07 +0300 Subject: [PATCH 09/16] combine publish channels --- compute/src/handlers/pingpong.rs | 8 ++--- compute/src/handlers/workflow.rs | 8 ++--- compute/src/main.rs | 8 ++--- compute/src/node.rs | 55 ++++++++++++-------------------- compute/src/workers/workflow.rs | 13 +++----- 5 files changed, 38 insertions(+), 54 deletions(-) diff --git a/compute/src/handlers/pingpong.rs b/compute/src/handlers/pingpong.rs index 1c83779..3ca6479 100644 --- a/compute/src/handlers/pingpong.rs +++ b/compute/src/handlers/pingpong.rs @@ -19,11 +19,12 @@ struct PingpongPayload { #[derive(Serialize, Deserialize, Debug, Clone)] struct PingpongResponse { + /// UUID as given in the ping payload. pub(crate) uuid: String, + /// Models available in the node. pub(crate) models: Vec<(ModelProvider, Model)>, - pub(crate) timestamp: u128, /// Number of tasks in the channel currently, `single` and `batch`. - pub(crate) tasks: [usize; 2], + pub(crate) active_task_count: [usize; 2], } impl PingpongHandler { @@ -65,8 +66,7 @@ impl PingpongHandler { let response_body = PingpongResponse { uuid: pingpong.uuid.clone(), models: node.config.workflows.models.clone(), - timestamp: get_current_time_nanos(), - tasks: node.get_active_task_count(), + active_task_count: node.get_active_task_count(), }; // publish message diff --git a/compute/src/handlers/workflow.rs b/compute/src/handlers/workflow.rs index e480a8a..cc93d7c 100644 --- a/compute/src/handlers/workflow.rs +++ b/compute/src/handlers/workflow.rs @@ -114,6 +114,7 @@ impl WorkflowHandler { ))) } + /// Handles the result of a workflow task. 
pub(crate) async fn handle_publish( node: &mut DriaComputeNode, task: WorkflowsWorkerOutput, @@ -131,8 +132,7 @@ impl WorkflowHandler { )?; // convert payload to message - let payload_str = serde_json::to_string(&payload) - .wrap_err("could not serialize response payload")?; + let payload_str = serde_json::json!(payload).to_string(); log::debug!( "Publishing result for task {}\n{}", task.task_id, @@ -152,8 +152,7 @@ impl WorkflowHandler { model: task.model_name, stats: task.stats.record_published_at(), }; - let error_payload_str = serde_json::to_string(&error_payload) - .wrap_err("could not serialize error payload")?; + let error_payload_str = serde_json::json!(error_payload).to_string(); // prepare signed message DKNMessage::new_signed( @@ -178,6 +177,7 @@ impl WorkflowHandler { Self::RESPONSE_TOPIC, &node.config.secret_key, ); + node.publish(message).await?; }; diff --git a/compute/src/main.rs b/compute/src/main.rs index 0a0e1ca..daf6a28 100644 --- a/compute/src/main.rs +++ b/compute/src/main.rs @@ -75,9 +75,8 @@ async fn main() -> Result<()> { return Ok(()); } - let node_token = token.clone(); - let (mut node, p2p, mut worker_batch, mut worker_single) = - DriaComputeNode::new(config, node_token).await?; + // create the node + let (mut node, p2p, mut worker_batch, mut worker_single) = DriaComputeNode::new(config).await?; log::info!("Spawning peer-to-peer client thread."); let p2p_handle = tokio::spawn(async move { p2p.run().await }); @@ -90,8 +89,9 @@ async fn main() -> Result<()> { // launch the node in a separate thread log::info!("Spawning compute node thread."); + let node_token = token.clone(); let node_handle = tokio::spawn(async move { - if let Err(err) = node.run().await { + if let Err(err) = node.run(node_token).await { log::error!("Node launch error: {}", err); panic!("Node failed.") }; diff --git a/compute/src/node.rs b/compute/src/node.rs index 3f1ec7d..064d0ea 100644 --- a/compute/src/node.rs +++ b/compute/src/node.rs @@ -20,22 +20,21 @@ use crate::{ 
const DIAGNOSTIC_REFRESH_INTERVAL_SECS: u64 = 30; /// Number of seconds between refreshing the available nodes. const AVAILABLE_NODES_REFRESH_INTERVAL_SECS: u64 = 30 * 60; // 30 minutes +/// Buffer size for message publishes. +const PUBLISH_CHANNEL_BUFSIZE: usize = 1024; pub struct DriaComputeNode { pub config: DriaComputeNodeConfig, pub p2p: DriaP2PCommander, pub available_nodes: AvailableNodes, - pub cancellation: CancellationToken, /// Gossipsub message receiver. message_rx: mpsc::Receiver<(PeerId, MessageId, Message)>, + /// Publish receiver to receive messages to be published. + publish_rx: mpsc::Receiver, /// Workflow transmitter to send batchable tasks. workflow_batch_tx: mpsc::Sender, - /// Publish receiver to receive messages to be published. - publish_batch_rx: mpsc::Receiver, /// Workflow transmitter to send single tasks. workflow_single_tx: mpsc::Sender, - /// Publish receiver to receive messages to be published. - publish_single_rx: mpsc::Receiver, } impl DriaComputeNode { @@ -44,7 +43,6 @@ impl DriaComputeNode { /// Returns the node instance and p2p client together. P2p MUST be run in a separate task before this node is used at all. 
pub async fn new( config: DriaComputeNodeConfig, - cancellation: CancellationToken, ) -> Result<( DriaComputeNode, DriaP2PClient, @@ -77,22 +75,20 @@ impl DriaComputeNode { protocol, )?; - // create workflow workers - let (workflows_batch_worker, workflow_batch_tx, publish_batch_rx) = WorkflowsWorker::new(); - let (workflows_single_worker, workflow_single_tx, publish_single_rx) = - WorkflowsWorker::new(); + // create workflow workers, all workers use the same publish channel + let (publish_tx, publish_rx) = mpsc::channel(PUBLISH_CHANNEL_BUFSIZE); + let (workflows_batch_worker, workflow_batch_tx) = WorkflowsWorker::new(publish_tx.clone()); + let (workflows_single_worker, workflow_single_tx) = WorkflowsWorker::new(publish_tx); Ok(( DriaComputeNode { config, p2p: p2p_commander, - cancellation, available_nodes, message_rx, + publish_rx, workflow_batch_tx, - publish_batch_rx, workflow_single_tx, - publish_single_rx, }, p2p_client, workflows_batch_worker, @@ -248,8 +244,8 @@ impl DriaComputeNode { } /// Runs the main loop of the compute node. - /// This method is not expected to return until cancellation occurs. - pub async fn run(&mut self) -> Result<()> { + /// This method is not expected to return until cancellation occurs for the given token. 
+ pub async fn run(&mut self, cancellation: CancellationToken) -> Result<()> { // prepare durations for sleeps let mut peer_refresh_interval = tokio::time::interval(Duration::from_secs(DIAGNOSTIC_REFRESH_INTERVAL_SECS)); @@ -270,16 +266,7 @@ impl DriaComputeNode { _ = available_node_refresh_interval.tick() => self.handle_available_nodes_refresh().await, // a Workflow message to be published is received from the channel // this is expected to be sent by the workflow worker - publish_msg = self.publish_batch_rx.recv() => { - if let Some(result) = publish_msg { - WorkflowHandler::handle_publish(self, result).await?; - } else { - log::error!("Publish channel closed unexpectedly."); - break; - }; - }, - // TODO: make the both receivers handled together somehow - publish_msg = self.publish_single_rx.recv() => { + publish_msg = self.publish_rx.recv() => { if let Some(result) = publish_msg { WorkflowHandler::handle_publish(self, result).await?; } else { @@ -306,7 +293,7 @@ impl DriaComputeNode { }, // check if the cancellation token is cancelled // this is expected to be cancelled by the main thread with signal handling - _ = self.cancellation.cancelled() => break, + _ = cancellation.cancelled() => break, } } @@ -331,7 +318,7 @@ impl DriaComputeNode { self.message_rx.close(); log::debug!("Closing publish channel."); - self.publish_batch_rx.close(); + self.publish_rx.close(); Ok(()) } @@ -345,8 +332,8 @@ impl DriaComputeNode { } // print task counts - let [single, batch] = self.get_active_task_count(); - log::info!("Active Task Count (single/batch): {} / {}", single, batch); + // let [single, batch] = self.get_active_task_count(); + // log::info!("Active Task Count (single/batch): {} / {}", single, batch); } /// Updates the local list of available nodes by refreshing it. 
@@ -382,18 +369,18 @@ mod tests { // create node let cancellation = CancellationToken::new(); - let (mut node, p2p, _, _) = - DriaComputeNode::new(DriaComputeNodeConfig::default(), cancellation.clone()) - .await - .expect("should create node"); + let (mut node, p2p, _, _) = DriaComputeNode::new(DriaComputeNodeConfig::default()) + .await + .expect("should create node"); // spawn p2p task let p2p_task = tokio::spawn(async move { p2p.run().await }); // launch & wait for a while for connections log::info!("Waiting a bit for peer setup."); + let run_cancellation = cancellation.clone(); tokio::select! { - _ = node.run() => (), + _ = node.run(run_cancellation) => (), _ = tokio::time::sleep(tokio::time::Duration::from_secs(20)) => cancellation.cancel(), } log::info!("Connected Peers:\n{:#?}", node.peers().await?); diff --git a/compute/src/workers/workflow.rs b/compute/src/workers/workflow.rs index cc02a1e..e4584ac 100644 --- a/compute/src/workers/workflow.rs +++ b/compute/src/workers/workflow.rs @@ -29,8 +29,8 @@ pub struct WorkflowsWorker { publish_tx: mpsc::Sender, } +/// Buffer size for workflow tasks (per worker). const WORKFLOW_CHANNEL_BUFSIZE: usize = 1024; -const PUBLISH_CHANNEL_BUFSIZE: usize = 1024; impl WorkflowsWorker { /// Batch size that defines how many tasks can be executed in parallel at once. @@ -39,13 +39,10 @@ impl WorkflowsWorker { const BATCH_SIZE: usize = 8; /// Creates a worker and returns the sender and receiver for the worker. - pub fn new() -> ( - WorkflowsWorker, - mpsc::Sender, - mpsc::Receiver, - ) { + pub fn new( + publish_tx: mpsc::Sender, + ) -> (WorkflowsWorker, mpsc::Sender) { let (workflow_tx, workflow_rx) = mpsc::channel(WORKFLOW_CHANNEL_BUFSIZE); - let (publish_tx, publish_rx) = mpsc::channel(PUBLISH_CHANNEL_BUFSIZE); ( Self { @@ -53,10 +50,10 @@ impl WorkflowsWorker { publish_tx, }, workflow_tx, - publish_rx, ) } + /// Closes the workflow receiver channel. 
fn shutdown(&mut self) { log::warn!("Closing workflows worker."); self.workflow_rx.close(); From 0365ded86858aef9c7cc7a108aa4abc3dd6bb652 Mon Sep 17 00:00:00 2001 From: erhant Date: Fri, 29 Nov 2024 13:58:59 +0300 Subject: [PATCH 10/16] optional threads w.r.t model --- compute/src/handlers/pingpong.rs | 4 +- compute/src/handlers/workflow.rs | 22 ++++---- compute/src/main.rs | 91 ++++++++++++++----------------- compute/src/node.rs | 94 +++++++++++++++++++++++++------- compute/src/workers/workflow.rs | 3 + workflows/src/config.rs | 10 ++++ 6 files changed, 141 insertions(+), 83 deletions(-) diff --git a/compute/src/handlers/pingpong.rs b/compute/src/handlers/pingpong.rs index 3ca6479..3ce12a7 100644 --- a/compute/src/handlers/pingpong.rs +++ b/compute/src/handlers/pingpong.rs @@ -24,7 +24,7 @@ struct PingpongResponse { /// Models available in the node. pub(crate) models: Vec<(ModelProvider, Model)>, /// Number of tasks in the channel currently, `single` and `batch`. - pub(crate) active_task_count: [usize; 2], + pub(crate) pending_tasks: [usize; 2], } impl PingpongHandler { @@ -66,7 +66,7 @@ impl PingpongHandler { let response_body = PingpongResponse { uuid: pingpong.uuid.clone(), models: node.config.workflows.models.clone(), - active_task_count: node.get_active_task_count(), + pending_tasks: node.get_pending_task_count(), }; // publish message diff --git a/compute/src/handlers/workflow.rs b/compute/src/handlers/workflow.rs index cc93d7c..f7d7899 100644 --- a/compute/src/handlers/workflow.rs +++ b/compute/src/handlers/workflow.rs @@ -32,7 +32,7 @@ impl WorkflowHandler { pub(crate) async fn handle_compute( node: &mut DriaComputeNode, compute_message: &DKNMessage, - ) -> Result> { + ) -> Result> { let stats = TaskStats::new().record_received_at(); let task = compute_message .parse_payload::>(true) @@ -100,18 +100,16 @@ impl WorkflowHandler { // get workflow as well let workflow = task.input.workflow; - Ok(Either::Right(( - WorkflowsWorkerInput { - entry, - executor, - 
workflow, - model_name, - task_id: task.task_id, - public_key: task_public_key, - stats, - }, + Ok(Either::Right(WorkflowsWorkerInput { + entry, + executor, + workflow, + model_name, + task_id: task.task_id, + public_key: task_public_key, + stats, batchable, - ))) + })) } /// Handles the result of a workflow task. diff --git a/compute/src/main.rs b/compute/src/main.rs index daf6a28..c6a1b29 100644 --- a/compute/src/main.rs +++ b/compute/src/main.rs @@ -1,11 +1,12 @@ use dkn_compute::*; use eyre::Result; use std::env; -use tokio_util::sync::CancellationToken; +use tokio_util::{sync::CancellationToken, task::TaskTracker}; #[tokio::main] async fn main() -> Result<()> { let dotenv_result = dotenvy::dotenv(); + // TODO: remove me later when the launcher is fixed amend_log_levels(); @@ -28,88 +29,80 @@ async fn main() -> Result<()> { "# ); - let token = CancellationToken::new(); - let cancellation_token = token.clone(); + // task tracker for multiple threads + let task_tracker = TaskTracker::new(); + let cancellation = CancellationToken::new(); + + // spawn the background task to wait for termination signals + let task_tracker_to_close = task_tracker.clone(); + let cancellation_token = cancellation.clone(); tokio::spawn(async move { - // the timeout is done for profiling only, and should not be used in production if let Ok(Ok(duration_secs)) = env::var("DKN_EXIT_TIMEOUT").map(|s| s.to_string().parse::()) { + // the timeout is done for profiling only, and should not be used in production log::warn!("Waiting for {} seconds before exiting.", duration_secs); tokio::time::sleep(tokio::time::Duration::from_secs(duration_secs)).await; log::warn!("Exiting due to DKN_EXIT_TIMEOUT."); + cancellation_token.cancel(); } else if let Err(err) = wait_for_termination(cancellation_token.clone()).await { + // if there is no timeout, we wait for termination signals here log::error!("Error waiting for termination: {:?}", err); log::error!("Cancelling due to unexpected error."); 
cancellation_token.cancel(); }; + + // close tracker in any case + task_tracker_to_close.close(); }); // create configurations & check required services & address in use let mut config = DriaComputeNodeConfig::new(); config.assert_address_not_in_use()?; - let service_check_token = token.clone(); - let config = tokio::spawn(async move { - tokio::select! { - result = config.workflows.check_services() => { - if let Err(err) = result { - log::error!("Error checking services: {:?}", err); - panic!("Service check failed.") - } - log::warn!("Using models: {:#?}", config.workflows.models); - config - } - _ = service_check_token.cancelled() => { - log::info!("Service check cancelled."); - config - } + // check services & models, will exit if there is an error + // since service check can take time, we allow early-exit here as well + tokio::select! { + result = config.workflows.check_services() => result, + _ = cancellation.cancelled() => { + log::info!("Service check cancelled, exiting."); + return Ok(()); } - }) - .await?; - - // check early exit due to failed service check - if token.is_cancelled() { - log::warn!("Not launching node due to early exit, bye!"); - return Ok(()); - } + }?; + log::warn!("Using models: {:#?}", config.workflows.models); // create the node - let (mut node, p2p, mut worker_batch, mut worker_single) = DriaComputeNode::new(config).await?; + let (mut node, p2p, worker_batch, worker_single) = DriaComputeNode::new(config).await?; + // spawn threads log::info!("Spawning peer-to-peer client thread."); - let p2p_handle = tokio::spawn(async move { p2p.run().await }); + task_tracker.spawn(async move { p2p.run().await }); - log::info!("Spawning workflows batch worker thread."); - let worker_batch_handle = tokio::spawn(async move { worker_batch.run_batch().await }); + if let Some(mut worker_batch) = worker_batch { + log::info!("Spawning workflows batch worker thread."); + task_tracker.spawn(async move { worker_batch.run_batch().await }); + } - 
log::info!("Spawning workflows single worker thread."); - let worker_single_handle = tokio::spawn(async move { worker_single.run().await }); + if let Some(mut worker_single) = worker_single { + log::info!("Spawning workflows single worker thread."); + task_tracker.spawn(async move { worker_single.run().await }); + } // launch the node in a separate thread log::info!("Spawning compute node thread."); - let node_token = token.clone(); - let node_handle = tokio::spawn(async move { + let node_token = cancellation.clone(); + task_tracker.spawn(async move { if let Err(err) = node.run(node_token).await { log::error!("Node launch error: {}", err); panic!("Node failed.") }; + log::info!("Closing node.") }); - // wait for tasks to complete - if let Err(err) = node_handle.await { - log::error!("Node handle error: {}", err); - }; - if let Err(err) = worker_single_handle.await { - log::error!("Workflows single worker handle error: {}", err); - }; - if let Err(err) = worker_batch_handle.await { - log::error!("Workflows batch worker handle error: {}", err); - }; - if let Err(err) = p2p_handle.await { - log::error!("P2P handle error: {}", err); - }; + // wait for all tasks to finish + task_tracker.wait().await; + log::info!("All tasks have exited succesfully."); log::info!("Bye!"); Ok(()) @@ -168,7 +161,7 @@ async fn wait_for_termination(cancellation: CancellationToken) -> Result<()> { cancellation.cancel(); } - log::info!("Terminating the node..."); + log::info!("Terminating the application..."); Ok(()) } diff --git a/compute/src/node.rs b/compute/src/node.rs index 064d0ea..4dfe8a1 100644 --- a/compute/src/node.rs +++ b/compute/src/node.rs @@ -6,6 +6,7 @@ use dkn_p2p::{ DriaP2PClient, DriaP2PCommander, DriaP2PProtocol, }; use eyre::Result; +use std::collections::HashSet; use tokio::{sync::mpsc, time::Duration}; use tokio_util::{either::Either, sync::CancellationToken}; @@ -32,9 +33,15 @@ pub struct DriaComputeNode { /// Publish receiver to receive messages to be published. 
publish_rx: mpsc::Receiver, /// Workflow transmitter to send batchable tasks. - workflow_batch_tx: mpsc::Sender, + workflow_batch_tx: Option>, /// Workflow transmitter to send single tasks. - workflow_single_tx: mpsc::Sender, + workflow_single_tx: Option>, + // TODO: instead of piggybacking task metadata within channels, we can store them here + // in a hashmap alone, and then use the task_id to get the metadata when needed + // Single tasks hash-map + pending_tasks_single: HashSet, + // Batch tasks hash-map + pending_tasks_batch: HashSet, } impl DriaComputeNode { @@ -46,8 +53,8 @@ impl DriaComputeNode { ) -> Result<( DriaComputeNode, DriaP2PClient, - WorkflowsWorker, - WorkflowsWorker, + Option, + Option, )> { // create the keypair from secret key let keypair = secret_to_keypair(&config.secret_key); @@ -77,8 +84,24 @@ impl DriaComputeNode { // create workflow workers, all workers use the same publish channel let (publish_tx, publish_rx) = mpsc::channel(PUBLISH_CHANNEL_BUFSIZE); - let (workflows_batch_worker, workflow_batch_tx) = WorkflowsWorker::new(publish_tx.clone()); - let (workflows_single_worker, workflow_single_tx) = WorkflowsWorker::new(publish_tx); + + // check if we should create a worker for batchable workflows + let (workflows_batch_worker, workflow_batch_tx) = if config.workflows.has_batchable_models() + { + let worker = WorkflowsWorker::new(publish_tx.clone()); + (Some(worker.0), Some(worker.1)) + } else { + (None, None) + }; + + // check if we should create a worker for single workflows + let (workflows_single_worker, workflow_single_tx) = + if config.workflows.has_non_batchable_models() { + let worker = WorkflowsWorker::new(publish_tx); + (Some(worker.0), Some(worker.1)) + } else { + (None, None) + }; Ok(( DriaComputeNode { @@ -89,6 +112,8 @@ impl DriaComputeNode { publish_rx, workflow_batch_tx, workflow_single_tx, + pending_tasks_single: HashSet::new(), + pending_tasks_batch: HashSet::new(), }, p2p_client, workflows_batch_worker, @@ -119,10 +144,10 
@@ impl DriaComputeNode { } /// Returns the task count within the channels, `single` and `batch`. - pub fn get_active_task_count(&self) -> [usize; 2] { + pub fn get_pending_task_count(&self) -> [usize; 2] { [ - self.workflow_single_tx.max_capacity() - self.workflow_single_tx.capacity(), - self.workflow_batch_tx.max_capacity() - self.workflow_batch_tx.capacity(), + self.pending_tasks_single.len(), + self.pending_tasks_batch.len(), ] } @@ -202,10 +227,32 @@ impl DriaComputeNode { // we got acceptance, so something was not right about the workflow and we can ignore it Ok(Either::Left(acceptance)) => Ok(acceptance), // we got the parsed workflow itself, send to a worker thread w.r.t batchable - Ok(Either::Right((workflow_message, batchable))) => { - if let Err(e) = match batchable { - true => self.workflow_batch_tx.send(workflow_message).await, - false => self.workflow_single_tx.send(workflow_message).await, + Ok(Either::Right(workflow_message)) => { + if let Err(e) = match workflow_message.batchable { + // this is a batchable task, send it to batch worker + // and keep track of the task id in pending tasks + true => match self.workflow_batch_tx { + Some(ref mut tx) => { + self.pending_tasks_batch + .insert(workflow_message.task_id.clone()); + tx.send(workflow_message).await + } + None => unreachable!( + "Batchable workflow received but no worker available." + ), + }, + // this is a single task, send it to single worker + // and keep track of the task id in pending tasks + false => match self.workflow_single_tx { + Some(ref mut tx) => { + self.pending_tasks_single + .insert(workflow_message.task_id.clone()); + tx.send(workflow_message).await + } + None => unreachable!( + "Single workflow received but no worker available." 
+ ), + }, } { log::error!("Error sending workflow message: {:?}", e); }; @@ -266,9 +313,16 @@ impl DriaComputeNode { _ = available_node_refresh_interval.tick() => self.handle_available_nodes_refresh().await, // a Workflow message to be published is received from the channel // this is expected to be sent by the workflow worker - publish_msg = self.publish_rx.recv() => { - if let Some(result) = publish_msg { - WorkflowHandler::handle_publish(self, result).await?; + publish_msg_opt = self.publish_rx.recv() => { + if let Some(publish_msg) = publish_msg_opt { + // remove the task from pending tasks based on its batchability + match publish_msg.batchable { + true => self.pending_tasks_batch.remove(&publish_msg.task_id), + false => self.pending_tasks_single.remove(&publish_msg.task_id), + }; + + // publish the message + WorkflowHandler::handle_publish(self, publish_msg).await?; } else { log::error!("Publish channel closed unexpectedly."); break; @@ -276,8 +330,8 @@ impl DriaComputeNode { }, // a GossipSub message is received from the channel // this is expected to be sent by the p2p client - gossipsub_msg = self.message_rx.recv() => { - if let Some((peer_id, message_id, message)) = gossipsub_msg { + gossipsub_msg_opt = self.message_rx.recv() => { + if let Some((peer_id, message_id, message)) = gossipsub_msg_opt { // handle the message, returning a message acceptance for the received one let acceptance = self.handle_message((peer_id, &message_id, message)).await; @@ -332,8 +386,8 @@ impl DriaComputeNode { } // print task counts - // let [single, batch] = self.get_active_task_count(); - // log::info!("Active Task Count (single/batch): {} / {}", single, batch); + let [single, batch] = self.get_pending_task_count(); + log::info!("Pending Task Count (single/batch): {} / {}", single, batch); } /// Updates the local list of available nodes by refreshing it. 
diff --git a/compute/src/workers/workflow.rs b/compute/src/workers/workflow.rs index e4584ac..626907f 100644 --- a/compute/src/workers/workflow.rs +++ b/compute/src/workers/workflow.rs @@ -13,6 +13,7 @@ pub struct WorkflowsWorkerInput { pub task_id: String, pub model_name: String, pub stats: TaskStats, + pub batchable: bool, } pub struct WorkflowsWorkerOutput { @@ -22,6 +23,7 @@ pub struct WorkflowsWorkerOutput { pub task_id: String, pub model_name: String, pub stats: TaskStats, + pub batchable: bool, } pub struct WorkflowsWorker { @@ -217,6 +219,7 @@ impl WorkflowsWorker { public_key: input.public_key, task_id: input.task_id, model_name: input.model_name, + batchable: input.batchable, stats: input.stats.record_execution_time(started_at), } } diff --git a/workflows/src/config.rs b/workflows/src/config.rs index 5a4f199..e26f1df 100644 --- a/workflows/src/config.rs +++ b/workflows/src/config.rs @@ -87,6 +87,16 @@ impl DriaWorkflowsConfig { .collect() } + /// Returns `true` if the configuration contains models that can be processed in parallel, e.g. API calls. + pub fn has_batchable_models(&self) -> bool { + self.models.iter().any(|(p, _)| *p != ModelProvider::Ollama) + } + + /// Returns `true` if the configuration contains a model that cant be run in parallel, e.g. a Ollama model. + pub fn has_non_batchable_models(&self) -> bool { + self.models.iter().any(|(p, _)| *p == ModelProvider::Ollama) + } + /// Given a raw model name or provider (as a string), returns the first matching model & provider. /// /// - If input is `*` or `all`, a random model is returned. 
From f8f431858064fbf2b4559c2439665e316b9e9edc Mon Sep 17 00:00:00 2001 From: erhant Date: Fri, 29 Nov 2024 15:08:31 +0300 Subject: [PATCH 11/16] task completion reports in the end, some readme rfks & removals --- .github/ISSUE_TEMPLATE/bug_report.md | 34 ------------------ .github/ISSUE_TEMPLATE/feature_request.md | 28 --------------- README.md | 8 ++--- compute/src/handlers/workflow.rs | 5 +-- compute/src/main.rs | 43 +++-------------------- compute/src/node.rs | 30 +++++++++++++--- compute/src/workers/workflow.rs | 12 +++---- p2p/src/client.rs | 8 +++++ workflows/README.md | 10 +++--- workflows/tests/models_test.rs | 41 +++++++++++++++------ 10 files changed, 80 insertions(+), 139 deletions(-) delete mode 100644 .github/ISSUE_TEMPLATE/bug_report.md delete mode 100644 .github/ISSUE_TEMPLATE/feature_request.md diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md deleted file mode 100644 index f343dc8..0000000 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ /dev/null @@ -1,34 +0,0 @@ ---- -name: Bug Report -about: Report any bugs or unexpected behavior -title: "bug: " -labels: bug -assignees: "" ---- - -### Problem - -A clear and concise description of what the bug is. - -### How to Reproduce - -If you can reproduce the behavior, steps to reproduce: - -1. Go to '...' -2. Click on '....' -3. Scroll down to '....' -4. See error - -### Expected Behaviour - -A clear and concise description of what you expected to happen. - -### Version - -Please note down the version, it can be seen in the the first log of Dria Compute Node. - -- e.g. `v0.1.0` - -### Additional context - -Add any other context about the problem here. 
diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md deleted file mode 100644 index 3bc4af0..0000000 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ /dev/null @@ -1,28 +0,0 @@ ---- -name: Feature Request -about: Suggest a new computation idea, or any other feature -title: "feat: " -labels: enhancement -assignees: "" ---- - -### Motivation - -Describe your motivation on requesting this feature. How does it extend the Dria Knowledge Network, or the compute node itself? - -### Technical Requirements - -What is required for the node to expect the requirements of this request? Some examples: - -- Does it require a GPU? -- Does it require an API key for some third party service? -- How much RAM will this feauture require? -- Does this feature require an additional container within the compose file? - -### Task Input - -Describe clearly what the input for this task is. For example, Synthesis tasks take in a prompt as an input, and this is instantiated as: - -```rs -type SynthesisPayload = TaskRequestPayload; -``` diff --git a/README.md b/README.md index d82a688..decf209 100644 --- a/README.md +++ b/README.md @@ -28,15 +28,11 @@ ## About -A **Dria Compute Node** is a unit of computation within the Dria Knowledge Network. It's purpose is to process tasks given by the **Dria Admin Node**. To get started, see [node guide](./docs/NODE_GUIDE.md)! - -### Tasks - Compute nodes can technically do any arbitrary task, from computing the square root of a given number to finding LLM outputs from a given prompt, or validating an LLM's output with respect to knowledge available on the web accessed via tools. -- **Ping/Pong**: Dria Admin Node broadcasts **ping** messages at a set interval, it is a required duty of the compute node to respond with a **pong** to these so that they can be included in the list of available nodes for task assignment. These tasks will respect the type of model provided within the pong message, e.g. 
if a task requires `gpt-4o` and you are running `phi3`, you won't be selected for that task. +- **Heartbeats**: Every few seconds, a heartbeat ping is published into the network, and every compute node responds with a digitally-signed pong message to indicate that they are alive, along with additional information such as which nodes they are running & how many tasks they have so far. -- **Workflows**: Each task is given in the form of a workflow, based on [Ollama Workflows](https://github.com/andthattoo/ollama-workflows). In simple terms, each workflow defines the agentic behavior of an LLM, all captured in a single JSON file, and can represent things ranging from simple LLM generations to iterative web searching. +- **Workflows**: Each task is given in the form of a [workflow](https://github.com/andthattoo/ollama-workflows). Every workflow defines an agentic behavior for the chosen LLM, all captured in a single JSON file, and can represent things ranging from simple LLM generations to iterative web searching & reasoning. ## Node Running diff --git a/compute/src/handlers/workflow.rs b/compute/src/handlers/workflow.rs index f7d7899..af368ea 100644 --- a/compute/src/handlers/workflow.rs +++ b/compute/src/handlers/workflow.rs @@ -54,10 +54,7 @@ impl WorkflowHandler { // check task inclusion via the bloom filter if !task.filter.contains(&node.config.address)? 
{ - log::info!( - "Task {} does not include this node within the filter.", - task.task_id - ); + log::info!("Task {} ignored due to filter.", task.task_id); // accept the message, someone else may be included in filter return Ok(Either::Left(MessageAcceptance::Accept)); diff --git a/compute/src/main.rs b/compute/src/main.rs index c6a1b29..c0028c3 100644 --- a/compute/src/main.rs +++ b/compute/src/main.rs @@ -7,9 +7,6 @@ use tokio_util::{sync::CancellationToken, task::TaskTracker}; async fn main() -> Result<()> { let dotenv_result = dotenvy::dotenv(); - // TODO: remove me later when the launcher is fixed - amend_log_levels(); - env_logger::builder() .format_timestamp(Some(env_logger::TimestampPrecision::Millis)) .init(); @@ -75,21 +72,23 @@ async fn main() -> Result<()> { // create the node let (mut node, p2p, worker_batch, worker_single) = DriaComputeNode::new(config).await?; - // spawn threads + // spawn p2p client first log::info!("Spawning peer-to-peer client thread."); task_tracker.spawn(async move { p2p.run().await }); + // spawn batch worker thread if we are using such models (e.g. OpenAI, Gemini, OpenRouter) if let Some(mut worker_batch) = worker_batch { log::info!("Spawning workflows batch worker thread."); task_tracker.spawn(async move { worker_batch.run_batch().await }); } + // spawn single worker thread if we are using such models (e.g. 
Ollama) if let Some(mut worker_single) = worker_single { log::info!("Spawning workflows single worker thread."); task_tracker.spawn(async move { worker_single.run().await }); } - // launch the node in a separate thread + // spawn compute node thread log::info!("Spawning compute node thread."); let node_token = cancellation.clone(); task_tracker.spawn(async move { @@ -165,37 +164,3 @@ async fn wait_for_termination(cancellation: CancellationToken) -> Result<()> { Ok(()) } - -// #[deprecated] -/// Very CRUDE fix due to launcher log level bug -/// -/// TODO: remove me later when the launcher is fixed -pub fn amend_log_levels() { - if let Ok(rust_log) = std::env::var("RUST_LOG") { - let log_level = if rust_log.contains("dkn_compute=info") { - "info" - } else if rust_log.contains("dkn_compute=debug") { - "debug" - } else if rust_log.contains("dkn_compute=trace") { - "trace" - } else { - return; - }; - - // check if it contains other log levels - let mut new_rust_log = rust_log.clone(); - if !rust_log.contains("dkn_p2p") { - new_rust_log = format!("{},{}={}", new_rust_log, "dkn_p2p", log_level); - } - if !rust_log.contains("dkn_workflows") { - new_rust_log = format!("{},{}={}", new_rust_log, "dkn_workflows", log_level); - } - std::env::set_var("RUST_LOG", new_rust_log); - } else { - // TODO: use env_logger default function instead of this - std::env::set_var( - "RUST_LOG", - "none,dkn_compute=info,dkn_p2p=info,dkn_workflows=info", - ); - } -} diff --git a/compute/src/node.rs b/compute/src/node.rs index 4dfe8a1..0211169 100644 --- a/compute/src/node.rs +++ b/compute/src/node.rs @@ -42,6 +42,10 @@ pub struct DriaComputeNode { pending_tasks_single: HashSet, // Batch tasks hash-map pending_tasks_batch: HashSet, + /// Completed single tasks count + completed_tasks_single: usize, + /// Completed batch tasks count + completed_tasks_batch: usize, } impl DriaComputeNode { @@ -114,6 +118,8 @@ impl DriaComputeNode { workflow_single_tx, pending_tasks_single: HashSet::new(), 
pending_tasks_batch: HashSet::new(), + completed_tasks_single: 0, + completed_tasks_batch: 0, }, p2p_client, workflows_batch_worker, @@ -317,8 +323,14 @@ impl DriaComputeNode { if let Some(publish_msg) = publish_msg_opt { // remove the task from pending tasks based on its batchability match publish_msg.batchable { - true => self.pending_tasks_batch.remove(&publish_msg.task_id), - false => self.pending_tasks_single.remove(&publish_msg.task_id), + true => { + self.completed_tasks_batch += 1; + self.pending_tasks_batch.remove(&publish_msg.task_id); + }, + false => { + self.completed_tasks_single += 1; + self.pending_tasks_single.remove(&publish_msg.task_id); + } }; // publish the message @@ -357,6 +369,9 @@ impl DriaComputeNode { self.unsubscribe(WorkflowHandler::LISTEN_TOPIC).await?; self.unsubscribe(WorkflowHandler::RESPONSE_TOPIC).await?; + // print one final diagnostic as a summary + self.handle_diagnostic_refresh().await; + // shutdown channels self.shutdown().await?; @@ -385,9 +400,16 @@ impl DriaComputeNode { Err(e) => log::error!("Error getting peer counts: {:?}", e), } - // print task counts + // print tasks count let [single, batch] = self.get_pending_task_count(); - log::info!("Pending Task Count (single/batch): {} / {}", single, batch); + log::info!("Pending Tasks (single/batch): {} / {}", single, batch); + + // completed tasks count + log::debug!( + "Completed Tasks (single/batch): {} / {}", + self.completed_tasks_single, + self.completed_tasks_batch + ); } /// Updates the local list of available nodes by refreshing it. diff --git a/compute/src/workers/workflow.rs b/compute/src/workers/workflow.rs index 626907f..7016bd8 100644 --- a/compute/src/workers/workflow.rs +++ b/compute/src/workers/workflow.rs @@ -26,6 +26,9 @@ pub struct WorkflowsWorkerOutput { pub batchable: bool, } +/// Workflows worker is a task executor that can process workflows in parallel / series. 
+/// +/// It is expected to be spawned in another thread, with `run_batch` for batch processing and `run` for single processing. pub struct WorkflowsWorker { workflow_rx: mpsc::Receiver, publish_tx: mpsc::Sender, @@ -95,12 +98,6 @@ impl WorkflowsWorker { .workflow_rx .recv_many(&mut task_buffer, Self::BATCH_SIZE) .await; - debug_assert!( - num_tasks <= Self::BATCH_SIZE, - "drain cant be larger than batch size" - ); - // TODO: just to be sure, can be removed later - debug_assert_eq!(num_tasks, task_buffer.len()); if num_tasks == 0 { return self.shutdown(); @@ -186,7 +183,7 @@ impl WorkflowsWorker { } _ => { unreachable!( - "drain cant be larger than batch size ({} > {})", + "number of tasks cant be larger than batch size ({} > {})", num_tasks, Self::BATCH_SIZE ); @@ -194,7 +191,6 @@ impl WorkflowsWorker { }; // publish all results - // TODO: make this a part of executor as well log::info!("Publishing {} workflow results", results.len()); for result in results { if let Err(e) = self.publish_tx.send(result).await { diff --git a/p2p/src/client.rs b/p2p/src/client.rs index 561d52b..dedad87 100644 --- a/p2p/src/client.rs +++ b/p2p/src/client.rs @@ -244,7 +244,15 @@ impl DriaP2PClient { let _ = sender.send((mesh, all)); } DriaP2PCommand::Shutdown { sender } => { + // close the command channel self.cmd_rx.close(); + + // remove own peerId from Kademlia DHT + let peer_id = self.swarm.local_peer_id().clone(); + self.swarm.behaviour_mut().kademlia.remove_peer(&peer_id); + + // remove own peerId from Autonat server list + self.swarm.behaviour_mut().autonat.remove_server(&peer_id); let _ = sender.send(()); } } diff --git a/workflows/README.md b/workflows/README.md index 03dd841..618fbc9 100644 --- a/workflows/README.md +++ b/workflows/README.md @@ -20,13 +20,13 @@ Note that the underlying [Ollama Workflows](https://github.com/andthattoo/ollama ## Usage -DKN Workflows make use of several environment variables, respecting the providers. 
+DKN Workflows make use of several environment variables, with respect to several model providers. -- `OLLAMA_HOST` is used to connect to Ollama server -- `OLLAMA_PORT` is used to connect to Ollama server +- `OLLAMA_HOST` is used to connect to **Ollama** server +- `OLLAMA_PORT` is used to connect to **Ollama** server - `OLLAMA_AUTO_PULL` indicates whether we should pull missing models automatically or not -- `OPENAI_API_KEY` is used for OpenAI requests -- `GEMINI_API_KEY` is used for Gemini requests +- `OPENAI_API_KEY` is used for **OpenAI** requests +- `GEMINI_API_KEY` is used for **Gemini** requests - `SERPER_API_KEY` is optional API key to use **Serper**, for better Workflow executions - `JINA_API_KEY` is optional API key to use **Jina**, for better Workflow executions diff --git a/workflows/tests/models_test.rs b/workflows/tests/models_test.rs index 2af73b8..a034b93 100644 --- a/workflows/tests/models_test.rs +++ b/workflows/tests/models_test.rs @@ -1,14 +1,21 @@ use dkn_workflows::{DriaWorkflowsConfig, Model, ModelProvider}; use eyre::Result; -use std::env; -const LOG_LEVEL: &str = "none,dkn_workflows=debug"; +fn setup() { + // read api key from .env + let _ = dotenvy::dotenv(); + + // set logger + let _ = env_logger::builder() + .parse_filters("none,dkn_workflows=debug") + .is_test(true) + .try_init(); +} #[tokio::test] #[ignore = "requires Ollama"] async fn test_ollama_check() -> Result<()> { - env::set_var("RUST_LOG", LOG_LEVEL); - let _ = env_logger::builder().is_test(true).try_init(); + setup(); let models = vec![Model::Phi3_5Mini]; let mut model_config = DriaWorkflowsConfig::new(models); @@ -25,9 +32,7 @@ async fn test_ollama_check() -> Result<()> { #[tokio::test] #[ignore = "requires OpenAI"] async fn test_openai_check() -> Result<()> { - let _ = dotenvy::dotenv(); // read api key - env::set_var("RUST_LOG", LOG_LEVEL); - let _ = env_logger::builder().is_test(true).try_init(); + setup(); let models = vec![Model::GPT4Turbo]; let mut model_config = 
DriaWorkflowsConfig::new(models); @@ -41,11 +46,25 @@ async fn test_openai_check() -> Result<()> { } #[tokio::test] -async fn test_empty() -> Result<()> { - let mut model_config = DriaWorkflowsConfig::new(vec![]); +#[ignore = "requires Gemini"] +async fn test_gemini_check() -> Result<()> { + setup(); - let result = model_config.check_services().await; - assert!(result.is_err()); + let models = vec![Model::Gemini15Flash]; + let mut model_config = DriaWorkflowsConfig::new(models); + model_config.check_services().await?; + assert_eq!( + model_config.models[0], + (ModelProvider::Gemini, Model::Gemini15Flash) + ); Ok(()) } + +#[tokio::test] +async fn test_empty() { + assert!(DriaWorkflowsConfig::new(vec![]) + .check_services() + .await + .is_err()); +} From 3f4a331ef1d100a6a2965eb23c8519d6f04fe4e6 Mon Sep 17 00:00:00 2001 From: erhant Date: Fri, 29 Nov 2024 15:23:56 +0300 Subject: [PATCH 12/16] very small refactors --- Cargo.lock | 114 ++++++++++++++++---------------- compute/src/workers/workflow.rs | 13 ++-- p2p/src/client.rs | 2 +- p2p/tests/listen_test.rs | 22 +++--- 4 files changed, 73 insertions(+), 78 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8e8f839..dedd7ea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -276,7 +276,7 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a860072022177f903e59730004fb5dc13db9275b79bb2aef7ba8ce831956c233" dependencies = [ - "bytes 1.8.0", + "bytes 1.9.0", "futures-sink", "futures-util", "memchr", @@ -451,9 +451,9 @@ checksum = "0e4cec68f03f32e44924783795810fa50a7035d8c8ebe78580ad7e6c703fba38" [[package]] name = "bytes" -version = "1.8.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" +checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" dependencies = [ "serde", ] @@ -466,9 +466,9 @@ checksum = 
"7b02b629252fe8ef6460461409564e2c21d0c8e77e0944f3d189ff06c4e932ad" [[package]] name = "cc" -version = "1.2.1" +version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" +checksum = "f34d93e62b03caf570cccc334cbc6c2fceca82f39211051345108adcba3eebdc" dependencies = [ "shlex", ] @@ -1202,12 +1202,12 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.9" +version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -1584,7 +1584,7 @@ version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8" dependencies = [ - "bytes 1.8.0", + "bytes 1.9.0", "fnv", "futures-core", "futures-sink", @@ -1604,7 +1604,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccae279728d634d083c00f6099cb58f01cc99c145b84b8be2f6c74618d79922e" dependencies = [ "atomic-waker", - "bytes 1.8.0", + "bytes 1.9.0", "fnv", "futures-core", "futures-sink", @@ -1701,7 +1701,7 @@ dependencies = [ "ipnet", "once_cell", "rand 0.8.5", - "socket2 0.5.7", + "socket2 0.5.8", "thiserror 1.0.69", "tinyvec", "tokio 1.41.1", @@ -1827,7 +1827,7 @@ version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" dependencies = [ - "bytes 1.8.0", + "bytes 1.9.0", "fnv", "itoa 1.0.14", ] @@ -1838,7 +1838,7 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" dependencies = [ - "bytes 1.8.0", + "bytes 1.9.0", "fnv", "itoa 1.0.14", ] @@ -1859,7 +1859,7 @@ version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ - "bytes 1.8.0", + "bytes 1.9.0", "http 0.2.12", "pin-project-lite 0.2.15", ] @@ -1870,7 +1870,7 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ - "bytes 1.8.0", + "bytes 1.9.0", "http 1.1.0", ] @@ -1880,7 +1880,7 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" dependencies = [ - "bytes 1.8.0", + "bytes 1.9.0", "futures-util", "http 1.1.0", "http-body 1.0.1", @@ -1941,7 +1941,7 @@ version = "0.14.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c08302e8fa335b151b788c775ff56e7a03ae64ff85c548ee820fecb70356e85" dependencies = [ - "bytes 1.8.0", + "bytes 1.9.0", "futures-channel", "futures-core", "futures-util", @@ -1952,7 +1952,7 @@ dependencies = [ "httpdate 1.0.3", "itoa 1.0.14", "pin-project-lite 0.2.15", - "socket2 0.5.7", + "socket2 0.5.8", "tokio 1.41.1", "tower-service", "tracing", @@ -1965,7 +1965,7 @@ version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97818827ef4f364230e16705d4706e2897df2bb60617d6ca15d598025a3c481f" dependencies = [ - "bytes 1.8.0", + "bytes 1.9.0", "futures-channel", "futures-util", "h2 0.4.7", @@ -2016,7 +2016,7 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" dependencies = [ - "bytes 1.8.0", + "bytes 1.9.0", "http-body-util", "hyper 1.5.1", "hyper-util", @@ -2032,14 
+2032,14 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" dependencies = [ - "bytes 1.8.0", + "bytes 1.9.0", "futures-channel", "futures-util", "http 1.1.0", "http-body 1.0.1", "hyper 1.5.1", "pin-project-lite 0.2.15", - "socket2 0.5.7", + "socket2 0.5.8", "tokio 1.41.1", "tower-service", "tracing", @@ -2264,7 +2264,7 @@ checksum = "064d90fec10d541084e7b39ead8875a5a80d9114a2b18791565253bae25f49e4" dependencies = [ "async-trait", "attohttpc", - "bytes 1.8.0", + "bytes 1.9.0", "futures", "http 0.2.12", "hyper 0.14.31", @@ -2334,7 +2334,7 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b58db92f96b720de98181bbbe63c831e87005ab460c1bf306eb2622b4707997f" dependencies = [ - "socket2 0.5.7", + "socket2 0.5.8", "widestring", "windows-sys 0.48.0", "winreg 0.50.0", @@ -2420,16 +2420,16 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.166" +version = "0.2.167" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2ccc108bbc0b1331bd061864e7cd823c0cab660bbe6970e66e2c0614decde36" +checksum = "09d6582e104315a817dff97f75133544b2e094ee22447d2acf4a74e189ba06fc" [[package]] name = "libp2p" version = "0.54.1" source = "git+https://github.com/anilaltuner/rust-libp2p.git?rev=7ce9f9e#7ce9f9e65ddbe1fdac3913f0f3c1d94edc1de25e" dependencies = [ - "bytes 1.8.0", + "bytes 1.9.0", "either", "futures", "futures-timer", @@ -2478,7 +2478,7 @@ source = "git+https://github.com/anilaltuner/rust-libp2p.git?rev=7ce9f9e#7ce9f9e dependencies = [ "async-trait", "asynchronous-codec", - "bytes 1.8.0", + "bytes 1.9.0", "either", "futures", "futures-bounded", @@ -2580,7 +2580,7 @@ dependencies = [ "asynchronous-codec", "base64 0.22.1", "byteorder", - "bytes 1.8.0", + "bytes 1.9.0", "either", "fnv", "futures", @@ -2652,7 +2652,7 @@ source = 
"git+https://github.com/anilaltuner/rust-libp2p.git?rev=7ce9f9e#7ce9f9e dependencies = [ "arrayvec", "asynchronous-codec", - "bytes 1.8.0", + "bytes 1.9.0", "either", "fnv", "futures", @@ -2687,7 +2687,7 @@ dependencies = [ "libp2p-swarm", "rand 0.8.5", "smallvec", - "socket2 0.5.7", + "socket2 0.5.8", "tokio 1.41.1", "tracing", "void", @@ -2719,7 +2719,7 @@ version = "0.45.0" source = "git+https://github.com/anilaltuner/rust-libp2p.git?rev=7ce9f9e#7ce9f9e65ddbe1fdac3913f0f3c1d94edc1de25e" dependencies = [ "asynchronous-codec", - "bytes 1.8.0", + "bytes 1.9.0", "curve25519-dalek", "futures", "libp2p-core", @@ -2760,7 +2760,7 @@ name = "libp2p-quic" version = "0.11.1" source = "git+https://github.com/anilaltuner/rust-libp2p.git?rev=7ce9f9e#7ce9f9e65ddbe1fdac3913f0f3c1d94edc1de25e" dependencies = [ - "bytes 1.8.0", + "bytes 1.9.0", "futures", "futures-timer", "if-watch", @@ -2772,7 +2772,7 @@ dependencies = [ "rand 0.8.5", "ring 0.17.8", "rustls", - "socket2 0.5.7", + "socket2 0.5.8", "thiserror 1.0.69", "tokio 1.41.1", "tracing", @@ -2784,7 +2784,7 @@ version = "0.18.0" source = "git+https://github.com/anilaltuner/rust-libp2p.git?rev=7ce9f9e#7ce9f9e65ddbe1fdac3913f0f3c1d94edc1de25e" dependencies = [ "asynchronous-codec", - "bytes 1.8.0", + "bytes 1.9.0", "either", "futures", "futures-bounded", @@ -2866,7 +2866,7 @@ dependencies = [ "libc", "libp2p-core", "libp2p-identity", - "socket2 0.5.7", + "socket2 0.5.8", "tokio 1.41.1", "tracing", ] @@ -3217,7 +3217,7 @@ name = "multistream-select" version = "0.13.0" source = "git+https://github.com/anilaltuner/rust-libp2p.git?rev=7ce9f9e#7ce9f9e65ddbe1fdac3913f0f3c1d94edc1de25e" dependencies = [ - "bytes 1.8.0", + "bytes 1.9.0", "futures", "pin-project", "smallvec", @@ -3296,7 +3296,7 @@ version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "86b33524dc0968bfad349684447bfce6db937a9ac3332a1fe60c0c5a5ce63f21" dependencies = [ - "bytes 1.8.0", + "bytes 1.9.0", "futures", "log", 
"netlink-packet-core", @@ -3311,7 +3311,7 @@ version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "416060d346fbaf1f23f9512963e3e878f1a78e707cb699ba9215761754244307" dependencies = [ - "bytes 1.8.0", + "bytes 1.9.0", "futures", "libc", "log", @@ -3492,7 +3492,7 @@ version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3145b6053780214d0d872f204c92e2cf65706b8b78aa304d76567a8d3764d15" dependencies = [ - "bytes 1.8.0", + "bytes 1.9.0", "derive_builder", "reqwest 0.12.9", "serde", @@ -3952,7 +3952,7 @@ version = "0.3.1" source = "git+https://github.com/anilaltuner/rust-libp2p.git?rev=7ce9f9e#7ce9f9e65ddbe1fdac3913f0f3c1d94edc1de25e" dependencies = [ "asynchronous-codec", - "bytes 1.8.0", + "bytes 1.9.0", "quick-protobuf", "thiserror 1.0.69", "unsigned-varint", @@ -3964,14 +3964,14 @@ version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef" dependencies = [ - "bytes 1.8.0", + "bytes 1.9.0", "futures-io", "pin-project-lite 0.2.15", "quinn-proto", "quinn-udp", "rustc-hash", "rustls", - "socket2 0.5.7", + "socket2 0.5.8", "thiserror 2.0.3", "tokio 1.41.1", "tracing", @@ -3983,7 +3983,7 @@ version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" dependencies = [ - "bytes 1.8.0", + "bytes 1.9.0", "getrandom 0.2.15", "rand 0.8.5", "ring 0.17.8", @@ -4006,7 +4006,7 @@ dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2 0.5.7", + "socket2 0.5.8", "tracing", "windows-sys 0.59.0", ] @@ -4224,7 +4224,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a77c62af46e79de0a562e1a9849205ffcb7fc1238876e9bd743357570e04046f" dependencies = [ "base64 0.22.1", - "bytes 1.8.0", + "bytes 1.9.0", "encoding_rs", "futures-core", "futures-util", @@ -4275,7 +4275,7 @@ 
source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ee84cc47a7a0ac7562173c8f421c058e4c72089d6e662f32e2cb4bcc8e6e9201" dependencies = [ "async-trait", - "bytes 1.8.0", + "bytes 1.9.0", "cargo-husky", "futures", "reqwest 0.12.9", @@ -4388,9 +4388,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.18" +version = "0.23.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c9cc1d47e243d655ace55ed38201c19ae02c148ae56412ab8750e8f0166ab7f" +checksum = "934b404430bb06b3fae2cba809eb45a1ab1aecd64491213d7c3301b88393f8d1" dependencies = [ "once_cell", "ring 0.17.8", @@ -4694,7 +4694,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "18278f6a914fa3070aa316493f7d2ddfb9ac86ebc06fa3b83bffda487e9065b0" dependencies = [ "async-trait", - "bytes 1.8.0", + "bytes 1.9.0", "hex", "sha2 0.10.8", "tokio 1.41.1", @@ -4794,9 +4794,9 @@ dependencies = [ [[package]] name = "socket2" -version = "0.5.7" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" +checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" dependencies = [ "libc", "windows-sys 0.52.0", @@ -5171,13 +5171,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33" dependencies = [ "backtrace", - "bytes 1.8.0", + "bytes 1.9.0", "libc", "mio 1.0.2", "parking_lot", "pin-project-lite 0.2.15", "signal-hook-registry", - "socket2 0.5.7", + "socket2 0.5.8", "tokio-macros", "windows-sys 0.52.0", ] @@ -5244,7 +5244,7 @@ version = "0.7.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" dependencies = [ - "bytes 1.8.0", + "bytes 1.9.0", "futures-core", "futures-sink", "futures-util", @@ -5261,9 +5261,9 @@ checksum = 
"8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" -version = "0.1.40" +version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ "log", "pin-project-lite 0.2.15", diff --git a/compute/src/workers/workflow.rs b/compute/src/workers/workflow.rs index 7016bd8..5513e38 100644 --- a/compute/src/workers/workflow.rs +++ b/compute/src/workers/workflow.rs @@ -49,13 +49,12 @@ impl WorkflowsWorker { ) -> (WorkflowsWorker, mpsc::Sender) { let (workflow_tx, workflow_rx) = mpsc::channel(WORKFLOW_CHANNEL_BUFSIZE); - ( - Self { - workflow_rx, - publish_tx, - }, - workflow_tx, - ) + let worker = WorkflowsWorker { + workflow_rx, + publish_tx, + }; + + (worker, workflow_tx) } /// Closes the workflow receiver channel. diff --git a/p2p/src/client.rs b/p2p/src/client.rs index dedad87..a956146 100644 --- a/p2p/src/client.rs +++ b/p2p/src/client.rs @@ -248,7 +248,7 @@ impl DriaP2PClient { self.cmd_rx.close(); // remove own peerId from Kademlia DHT - let peer_id = self.swarm.local_peer_id().clone(); + let peer_id = *self.swarm.local_peer_id(); self.swarm.behaviour_mut().kademlia.remove_peer(&peer_id); // remove own peerId from Autonat server list diff --git a/p2p/tests/listen_test.rs b/p2p/tests/listen_test.rs index c67db71..4b52619 100644 --- a/p2p/tests/listen_test.rs +++ b/p2p/tests/listen_test.rs @@ -1,30 +1,25 @@ use dkn_p2p::{DriaP2PClient, DriaP2PProtocol}; use eyre::Result; -use libp2p::Multiaddr; use libp2p_identity::Keypair; -use std::{env, str::FromStr}; #[tokio::test] #[ignore = "run this manually"] async fn test_listen_topic_once() -> Result<()> { const TOPIC: &str = "pong"; - env::set_var("RUST_LOG", "none,listen_test=debug,dkn_p2p=debug"); - let _ = env_logger::builder().is_test(true).try_init(); + let _ = env_logger::builder() + 
.parse_filters("none,listen_test=debug,dkn_p2p=debug") + .is_test(true) + .try_init(); // spawn P2P client in another task let (client, mut commander, mut msg_rx) = DriaP2PClient::new( Keypair::generate_secp256k1(), - Multiaddr::from_str("/ip4/0.0.0.0/tcp/4001")?, - vec![Multiaddr::from_str( - "/ip4/44.206.245.139/tcp/4001/p2p/16Uiu2HAm4q3LZU2T9kgjKK4ysy6KZYKLq8KiXQyae4RHdF7uqSt4", - )?].into_iter(), - vec![Multiaddr::from_str( - "/ip4/34.201.33.141/tcp/4001/p2p/16Uiu2HAkuXiV2CQkC9eJgU6cMnJ9SMARa85FZ6miTkvn5fuHNufa", - )?] - .into_iter(), + "/ip4/0.0.0.0/tcp/4001".parse()?, + vec!["/ip4/44.206.245.139/tcp/4001/p2p/16Uiu2HAm4q3LZU2T9kgjKK4ysy6KZYKLq8KiXQyae4RHdF7uqSt4".parse()?].into_iter(), + vec!["/ip4/34.201.33.141/tcp/4001/p2p/16Uiu2HAkuXiV2CQkC9eJgU6cMnJ9SMARa85FZ6miTkvn5fuHNufa".parse()?].into_iter(), vec![].into_iter(), - DriaP2PProtocol::new_major_minor("dria"), + DriaP2PProtocol::default(), ) .expect("could not create p2p client"); @@ -57,6 +52,7 @@ async fn test_listen_topic_once() -> Result<()> { // close command channel commander.shutdown().await.expect("could not shutdown"); + // close message channel msg_rx.close(); From 5a3600dd26a8b999aa815760c1557b2a1b0cd7eb Mon Sep 17 00:00:00 2001 From: erhant Date: Fri, 29 Nov 2024 23:08:54 +0300 Subject: [PATCH 13/16] some kademlia fixes --- README.md | 2 +- compute/src/node.rs | 11 ++++---- compute/src/workers/workflow.rs | 4 ++- p2p/README.md | 2 +- p2p/src/behaviour.rs | 9 +++--- p2p/src/client.rs | 49 ++++++++++++++++++++++++++++++++- workflows/README.md | 5 ++-- 7 files changed, 65 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index decf209..b477daf 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ Compute nodes can technically do any arbitrary task, from computing the square r - **Workflows**: Each task is given in the form of a [workflow](https://github.com/andthattoo/ollama-workflows). 
Every workflow defines an agentic behavior for the chosen LLM, all captured in a single JSON file, and can represent things ranging from simple LLM generations to iterative web searching & reasoning. -## Node Running +### Running a Node Refer to [node guide](./docs/NODE_GUIDE.md) to quickly get started and run your own node! diff --git a/compute/src/node.rs b/compute/src/node.rs index 0211169..bc8851b 100644 --- a/compute/src/node.rs +++ b/compute/src/node.rs @@ -26,18 +26,17 @@ const PUBLISH_CHANNEL_BUFSIZE: usize = 1024; pub struct DriaComputeNode { pub config: DriaComputeNodeConfig, - pub p2p: DriaP2PCommander, pub available_nodes: AvailableNodes, - /// Gossipsub message receiver. + /// Peer-to-peer client commander to interact with the network. + pub p2p: DriaP2PCommander, + /// Gossipsub message receiver, used by peer-to-peer client in a separate thread. message_rx: mpsc::Receiver<(PeerId, MessageId, Message)>, - /// Publish receiver to receive messages to be published. + /// Publish receiver to receive messages to be published. publish_rx: mpsc::Receiver, /// Workflow transmitter to send batchable tasks. workflow_batch_tx: Option>, /// Workflow transmitter to send single tasks.
workflow_single_tx: Option>, - // TODO: instead of piggybacking task metadata within channels, we can store them here - // in a hashmap alone, and then use the task_id to get the metadata when needed // Single tasks hash-map pending_tasks_single: HashSet, // Batch tasks hash-map @@ -402,7 +401,7 @@ impl DriaComputeNode { // print tasks count let [single, batch] = self.get_pending_task_count(); - log::info!("Pending Tasks (single/batch): {} / {}", single, batch); + log::info!("Pending Tasks (single/batch): {} / {}", single, batch); // completed tasks count log::debug!( diff --git a/compute/src/workers/workflow.rs b/compute/src/workers/workflow.rs index 5513e38..847332f 100644 --- a/compute/src/workers/workflow.rs +++ b/compute/src/workers/workflow.rs @@ -4,6 +4,8 @@ use tokio::sync::mpsc; use crate::payloads::TaskStats; +// TODO: instead of piggybacking stuff here, maybe node can hold it in a hashmap w.r.t taskId + pub struct WorkflowsWorkerInput { pub entry: Option, pub executor: Executor, @@ -59,7 +61,7 @@ impl WorkflowsWorker { /// Closes the workflow receiver channel. fn shutdown(&mut self) { - log::warn!("Closing workflows worker."); + log::info!("Closing workflows worker."); self.workflow_rx.close(); } diff --git a/p2p/README.md b/p2p/README.md index fb5463b..88b6f73 100644 --- a/p2p/README.md +++ b/p2p/README.md @@ -1,4 +1,4 @@ -# DKN Peer-to-Peer Client +# Dria Peer-to-Peer Client Dria Knowledge Network is a peer-to-peer network, built over libp2p. This crate is a wrapper client to easily interact with DKN. 
diff --git a/p2p/src/behaviour.rs b/p2p/src/behaviour.rs index f84dbfb..e9694b7 100644 --- a/p2p/src/behaviour.rs +++ b/p2p/src/behaviour.rs @@ -65,12 +65,14 @@ fn create_kademlia_behaviour( ) -> kad::Behaviour { use kad::{Behaviour, Config}; - const QUERY_TIMEOUT_SECS: u64 = 5 * 60; - const RECORD_TTL_SECS: u64 = 30; + const KADEMLIA_BOOTSTRAP_INTERVAL_SECS: u64 = 5 * 60; // default is 5 minutes + const QUERY_TIMEOUT_SECS: u64 = 3 * 60; // default is 1 minute let mut cfg = Config::new(protocol_name); cfg.set_query_timeout(Duration::from_secs(QUERY_TIMEOUT_SECS)) - .set_record_ttl(Some(Duration::from_secs(RECORD_TTL_SECS))); + .set_periodic_bootstrap_interval(Some(Duration::from_secs( + KADEMLIA_BOOTSTRAP_INTERVAL_SECS, + ))); Behaviour::with_config(local_peer_id, MemoryStore::new(local_peer_id), cfg) } @@ -157,7 +159,6 @@ fn create_gossipsub_behaviour(author: PeerId) -> Result { MessageId::from(digest.to_be_bytes()) }; - // TODO: add data transform here later Behaviour::new( MessageAuthenticity::Author(author), ConfigBuilder::default() diff --git a/p2p/src/client.rs b/p2p/src/client.rs index a956146..c4a8eac 100644 --- a/p2p/src/client.rs +++ b/p2p/src/client.rs @@ -154,7 +154,7 @@ impl DriaP2PClient { Some(c) => self.handle_command(c).await, // channel closed, thus shutting down the network event loop None=> { - log::warn!("Closing P2P client."); + log::info!("Closing peer-to-peer client."); return }, }, @@ -253,6 +253,7 @@ impl DriaP2PClient { // remove own peerId from Autonat server list self.swarm.behaviour_mut().autonat.remove_server(&peer_id); + let _ = sender.send(()); } } @@ -272,18 +273,22 @@ impl DriaP2PClient { } } + // kademlia events SwarmEvent::Behaviour(DriaBehaviourEvent::Kademlia( kad::Event::OutboundQueryProgressed { result: QueryResult::GetClosestPeers(result), .. }, )) => self.handle_closest_peers_result(result), + + // identify events SwarmEvent::Behaviour(DriaBehaviourEvent::Identify(identify::Event::Received { peer_id, info, .. 
})) => self.handle_identify_event(peer_id, info), + // autonat events SwarmEvent::Behaviour(DriaBehaviourEvent::Autonat(autonat::Event::StatusChanged { old, new, @@ -291,13 +296,55 @@ impl DriaP2PClient { log::warn!("AutoNAT status changed from {:?} to {:?}", old, new); } + // log listen addresses SwarmEvent::NewListenAddr { address, .. } => { log::warn!("Local node is listening on {}", address); } + + // add external address of peers to Kademlia routing table + SwarmEvent::NewExternalAddrOfPeer { peer_id, address } => { + self.swarm + .behaviour_mut() + .kademlia + .add_address(&peer_id, address); + } + // add your own peer_id to kademlia as well SwarmEvent::ExternalAddrConfirmed { address } => { // this is usually the external address via relay log::info!("External address confirmed: {}", address); + let peer_id = *self.swarm.local_peer_id(); + self.swarm + .behaviour_mut() + .kademlia + .add_address(&peer_id, address); } + + // SwarmEvent::IncomingConnectionError { + // local_addr, + // send_back_addr, + // error, + // connection_id, + // } => { + // log::debug!( + // "Incoming connection {} error: from {} to {} - {:?}", + // connection_id, + // local_addr, + // send_back_addr, + // error + // ); + // } + // SwarmEvent::IncomingConnection { + // connection_id, + // local_addr, + // send_back_addr, + // } => { + // log::debug!( + // "Incoming connection {} attempt: from {} to {}", + // connection_id, + // local_addr, + // send_back_addr + // ); + // } // SwarmEvent::OutgoingConnectionError { peer_id, error, .. } => { // if let Some(peer_id) = peer_id { // log::warn!("Could not connect to peer {}: {:?}", peer_id, error); diff --git a/workflows/README.md b/workflows/README.md index 618fbc9..3c83296 100644 --- a/workflows/README.md +++ b/workflows/README.md @@ -1,7 +1,6 @@ -# DKN Workflows +# Dria Workflows -We make use of Ollama Workflows in DKN; however, we also want to make sure that the chosen models are valid and is performant enough (i.e. have enough TPS).
-This crate handles the configurations of models to be used, and implements various service checks. +We make use of [Ollama Workflows](https://github.com/andthattoo/ollama-workflows) in Dria Knowledge Network; however, we also want to make sure that the chosen models are valid and are performant enough (i.e. have enough TPS). This crate handles the configurations of models to be used, and implements various service checks. There are two types of services: From f1067e9522b6cc849c53c86e51e1ac3583d4c59e Mon Sep 17 00:00:00 2001 From: erhant Date: Mon, 2 Dec 2024 19:11:51 +0300 Subject: [PATCH 14/16] utils as package, code tidyups and refactors --- Cargo.lock | 150 ++++++++++--------- Cargo.toml | 2 +- compute/Cargo.toml | 1 + compute/src/config.rs | 29 +--- compute/src/handlers/pingpong.rs | 10 +- compute/src/handlers/workflow.rs | 11 +- compute/src/lib.rs | 2 +- compute/src/node.rs | 39 ++--- compute/src/payloads/stats.rs | 3 +- compute/src/utils/available_nodes/mod.rs | 143 ------------------ compute/src/utils/available_nodes/statics.rs | 50 ------- compute/src/utils/message.rs | 22 ++- compute/src/utils/misc.rs | 19 +-- compute/src/utils/mod.rs | 8 +- compute/src/utils/nodes.rs | 69 +++++++++ p2p/Cargo.toml | 2 + p2p/src/client.rs | 3 +- p2p/src/lib.rs | 6 + p2p/src/network.rs | 78 ++++++++++ p2p/src/nodes.rs | 72 +++++++++ utils/Cargo.toml | 9 ++ utils/README.md | 19 +++ workflows/src/utils.rs => utils/src/lib.rs | 25 ++++ workflows/Cargo.toml | 3 +- workflows/src/apis/jina.rs | 3 +- workflows/src/apis/serper.rs | 3 +- workflows/src/config.rs | 3 +- workflows/src/lib.rs | 3 - workflows/src/providers/gemini.rs | 3 +- workflows/src/providers/openai.rs | 3 +- workflows/src/providers/openrouter.rs | 3 +- 31 files changed, 421 insertions(+), 375 deletions(-) delete mode 100644 compute/src/utils/available_nodes/mod.rs delete mode 100644 compute/src/utils/available_nodes/statics.rs create mode 100644 compute/src/utils/nodes.rs create mode 100644 p2p/src/network.rs create
mode 100644 p2p/src/nodes.rs create mode 100644 utils/Cargo.toml create mode 100644 utils/README.md rename workflows/src/utils.rs => utils/src/lib.rs (71%) diff --git a/Cargo.lock b/Cargo.lock index dedd7ea..af06e03 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -76,9 +76,9 @@ dependencies = [ [[package]] name = "allocator-api2" -version = "0.2.20" +version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" [[package]] name = "android-tzdata" @@ -186,7 +186,7 @@ checksum = "965c2d33e53cb6b267e148a4cb0760bc01f4904c1cd4bb4002a085bb016d1490" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", "synstructure", ] @@ -198,7 +198,7 @@ checksum = "7b18050c2cd6fe86c3a76584ef5e0baf286d038cda203eb6223df2cc413565f7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -256,7 +256,7 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -267,7 +267,7 @@ checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -320,7 +320,7 @@ dependencies = [ "derive_utils", "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -700,7 +700,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13b588ba4ac1a99f7f2964d24b3d896ddc6bf847ee3855dbd4366f058cfcd331" dependencies = [ "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -768,7 +768,7 @@ checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -792,7 +792,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 
2.0.89", + "syn 2.0.90", ] [[package]] @@ -803,7 +803,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -883,7 +883,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -893,7 +893,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" dependencies = [ "derive_builder_core", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -906,7 +906,7 @@ dependencies = [ "proc-macro2", "quote", "rustc_version", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -917,7 +917,7 @@ checksum = "65f152f4b8559c4da5d574bafc7af85454d706b4c5fe8b530d508cacbb6807ea" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -969,7 +969,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -979,6 +979,7 @@ dependencies = [ "async-trait", "base64 0.22.1", "dkn-p2p", + "dkn-utils", "dkn-workflows", "dotenvy", "ecies", @@ -1008,6 +1009,7 @@ dependencies = [ name = "dkn-p2p" version = "0.2.25" dependencies = [ + "dkn-utils", "env_logger 0.11.5", "eyre", "libp2p", @@ -1017,10 +1019,15 @@ dependencies = [ "tokio-util 0.7.12", ] +[[package]] +name = "dkn-utils" +version = "0.2.25" + [[package]] name = "dkn-workflows" version = "0.2.25" dependencies = [ + "dkn-utils", "dotenvy", "env_logger 0.11.5", "eyre", @@ -1142,7 +1149,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -1223,9 +1230,9 @@ dependencies = [ [[package]] name = "event-listener-strategy" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f214dc438f977e6d4e3500aaa277f5ad94ca83fbbd9b1a15713ce2344ccc5a1" 
+checksum = "3c3e4e0dd3673c1139bf041f3008816d9cf2946bbfac2945c09e523b8d7b05b2" dependencies = [ "event-listener", "pin-project-lite 0.2.15", @@ -1410,7 +1417,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -1590,7 +1597,7 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.12", - "indexmap 2.6.0", + "indexmap 2.7.0", "slab", "tokio 1.41.1", "tokio-util 0.7.12", @@ -1609,7 +1616,7 @@ dependencies = [ "futures-core", "futures-sink", "http 1.1.0", - "indexmap 2.6.0", + "indexmap 2.7.0", "slab", "tokio 1.41.1", "tokio-util 0.7.12", @@ -1818,7 +1825,7 @@ dependencies = [ "markup5ever 0.12.1", "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -2183,7 +2190,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -2293,9 +2300,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.6.0" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" +checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" dependencies = [ "equivalent", "hashbrown 0.15.2", @@ -2386,10 +2393,11 @@ checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" [[package]] name = "js-sys" -version = "0.3.72" +version = "0.3.74" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" +checksum = "a865e038f7f6ed956f788f0d7d60c541fff74c7bd74272c5d4cf15c63743e705" dependencies = [ + "once_cell", "wasm-bindgen", ] @@ -2852,7 +2860,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -3150,11 +3158,10 @@ dependencies = [ [[package]] name = "mio" 
-version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" +checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" dependencies = [ - "hermit-abi 0.3.9", "libc", "wasi 0.11.0+wasi-snapshot-preview1", "windows-sys 0.52.0", @@ -3524,7 +3531,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -3720,7 +3727,7 @@ dependencies = [ "phf_shared 0.11.2", "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -3767,7 +3774,7 @@ checksum = "3c0f5fad0874fc7abcd4d750e76917eaebbecaa2c20bde22e1dbeeba8beb758c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -3928,7 +3935,7 @@ checksum = "440f724eba9f6996b75d63681b0a92b06947f1457076d503a4d2e2c8f56442b8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -4351,9 +4358,9 @@ checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" [[package]] name = "rustc-hash" -version = "2.0.0" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152" +checksum = "c7fb8039b3032c191086b10f11f319a6e99e1e82889c5cc6046f515c9db1d497" [[package]] name = "rustc_version" @@ -4617,7 +4624,7 @@ checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -4887,7 +4894,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -4909,9 +4916,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.89" +version = "2.0.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"44d46482f1c1c87acd84dea20c1bf5ebff4c757009ed6bf19cfd36fb10e92c4e" +checksum = "919d3b74a5dd0ccd15aeb8f93e7006bd9e14c295087c9896a110f490752bcf31" dependencies = [ "proc-macro2", "quote", @@ -4935,7 +4942,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -5066,7 +5073,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -5077,7 +5084,7 @@ checksum = "f077553d607adc1caf65430528a576c757a71ed73944b66ebb58ef2bbd243568" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -5173,7 +5180,7 @@ dependencies = [ "backtrace", "bytes 1.9.0", "libc", - "mio 1.0.2", + "mio 1.0.3", "parking_lot", "pin-project-lite 0.2.15", "signal-hook-registry", @@ -5190,7 +5197,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -5279,7 +5286,7 @@ checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -5489,9 +5496,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.95" +version = "0.2.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" +checksum = "d15e63b4482863c109d70a7b8706c1e364eb6ea449b201a76c5b89cedcec2d5c" dependencies = [ "cfg-if 1.0.0", "once_cell", @@ -5502,36 +5509,37 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.95" +version = "0.2.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" +checksum = "8d36ef12e3aaca16ddd3f67922bc63e48e953f126de60bd33ccc0101ef9998cd" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.45" +version = "0.4.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc7ec4f8827a71586374db3e87abdb5a2bb3a15afed140221307c3ec06b1f63b" +checksum = "9dfaf8f50e5f293737ee323940c7d8b08a66a95a419223d9f41610ca08b0833d" dependencies = [ "cfg-if 1.0.0", "js-sys", + "once_cell", "wasm-bindgen", "web-sys", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.95" +version = "0.2.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" +checksum = "705440e08b42d3e4b36de7d66c944be628d579796b8090bfa3471478a2260051" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -5539,22 +5547,22 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.95" +version = "0.2.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" +checksum = "98c9ae5a76e46f4deecd0f0255cc223cfa18dc9b261213b8aa0c7b36f61b3f1d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.95" +version = "0.2.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" +checksum = "6ee99da9c5ba11bd675621338ef6fa52296b76b83305e9b6e5c77d4c286d6d49" [[package]] name = "wasm-streams" @@ -5571,9 +5579,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.72" +version = "0.3.74" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" +checksum = "a98bc3c33f0fe7e59ad7cd041b89034fa82a7c2d4365ca538dda6cdaf513863c" dependencies = [ "js-sys", "wasm-bindgen", @@ -5706,7 +5714,7 @@ checksum = "9107ddc059d5b6fbfbffdfa7a7fe3e22a226def0b2608f72e9d552763d3e1ad7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -5717,7 +5725,7 @@ checksum = "29bee4b38ea3cde66011baa44dba677c432a78593e202392d1e9070cf2a7fca7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -6076,7 +6084,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", "synstructure", ] @@ -6098,7 +6106,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -6118,7 +6126,7 @@ checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", "synstructure", ] @@ -6139,7 +6147,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] [[package]] @@ -6161,5 +6169,5 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.90", ] diff --git a/Cargo.toml b/Cargo.toml index 69ab2fe..10e060b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" -members = ["compute", "p2p", "workflows"] +members = ["compute", "p2p", "workflows", "utils"] # compute node is the default member, until Oracle comes in # then, a Launcher will be the default member default-members = ["compute"] diff --git a/compute/Cargo.toml b/compute/Cargo.toml index 
2cc1c97..a18c6bf 100644 --- a/compute/Cargo.toml +++ b/compute/Cargo.toml @@ -44,6 +44,7 @@ fastbloom-rs = "0.5.9" # dria subcrates dkn-p2p = { path = "../p2p" } +dkn-utils = { path = "../utils" } dkn-workflows = { path = "../workflows" } # vendor OpenSSL so that its easier to build cross-platform packages diff --git a/compute/src/config.rs b/compute/src/config.rs index 8336d97..8130689 100644 --- a/compute/src/config.rs +++ b/compute/src/config.rs @@ -2,40 +2,13 @@ use crate::utils::{ address_in_use, crypto::{secret_to_keypair, to_address}, }; -use dkn_p2p::libp2p::Multiaddr; +use dkn_p2p::{libp2p::Multiaddr, DriaNetworkType}; use dkn_workflows::DriaWorkflowsConfig; use eyre::{eyre, Result}; use libsecp256k1::{PublicKey, SecretKey}; use std::{env, str::FromStr}; -/// Network type. -#[derive(Default, Debug, Clone, Copy)] -pub enum DriaNetworkType { - #[default] - Community, - Pro, -} - -impl From<&str> for DriaNetworkType { - fn from(s: &str) -> Self { - match s { - "community" => DriaNetworkType::Community, - "pro" => DriaNetworkType::Pro, - _ => Default::default(), - } - } -} - -impl DriaNetworkType { - pub fn protocol_name(&self) -> &str { - match self { - DriaNetworkType::Community => "dria", - DriaNetworkType::Pro => "dria-sdk", - } - } -} - #[derive(Debug, Clone)] pub struct DriaComputeNodeConfig { /// Wallet secret/private key. diff --git a/compute/src/handlers/pingpong.rs b/compute/src/handlers/pingpong.rs index 3ce12a7..fbfe5c2 100644 --- a/compute/src/handlers/pingpong.rs +++ b/compute/src/handlers/pingpong.rs @@ -1,8 +1,6 @@ -use crate::{ - utils::{get_current_time_nanos, DKNMessage}, - DriaComputeNode, -}; +use crate::{utils::DriaMessage, DriaComputeNode}; use dkn_p2p::libp2p::gossipsub::MessageAcceptance; +use dkn_utils::get_current_time_nanos; use dkn_workflows::{Model, ModelProvider}; use eyre::{Context, Result}; use serde::{Deserialize, Serialize}; @@ -42,7 +40,7 @@ impl PingpongHandler { /// 7. 
Returns `MessageAcceptance::Accept` so that ping is propagated to others as well. pub(crate) async fn handle_ping( node: &mut DriaComputeNode, - ping_message: &DKNMessage, + ping_message: &DriaMessage, ) -> Result { let pingpong = ping_message .parse_payload::(true) @@ -70,7 +68,7 @@ impl PingpongHandler { }; // publish message - let message = DKNMessage::new_signed( + let message = DriaMessage::new_signed( serde_json::json!(response_body).to_string(), Self::RESPONSE_TOPIC, &node.config.secret_key, diff --git a/compute/src/handlers/workflow.rs b/compute/src/handlers/workflow.rs index af368ea..ec3e889 100644 --- a/compute/src/handlers/workflow.rs +++ b/compute/src/handlers/workflow.rs @@ -1,4 +1,5 @@ use dkn_p2p::libp2p::gossipsub::MessageAcceptance; +use dkn_utils::get_current_time_nanos; use dkn_workflows::{Entry, Executor, ModelProvider, Workflow}; use eyre::{Context, Result}; use libsecp256k1::PublicKey; @@ -6,7 +7,7 @@ use serde::Deserialize; use tokio_util::either::Either; use crate::payloads::{TaskErrorPayload, TaskRequestPayload, TaskResponsePayload, TaskStats}; -use crate::utils::{get_current_time_nanos, DKNMessage}; +use crate::utils::DriaMessage; use crate::workers::workflow::*; use crate::DriaComputeNode; @@ -31,7 +32,7 @@ impl WorkflowHandler { pub(crate) async fn handle_compute( node: &mut DriaComputeNode, - compute_message: &DKNMessage, + compute_message: &DriaMessage, ) -> Result> { let stats = TaskStats::new().record_received_at(); let task = compute_message @@ -133,7 +134,7 @@ impl WorkflowHandler { task.task_id, payload_str ); - DKNMessage::new(payload_str, Self::RESPONSE_TOPIC) + DriaMessage::new(payload_str, Self::RESPONSE_TOPIC) } Err(err) => { // use pretty display string for error logging with causes @@ -150,7 +151,7 @@ impl WorkflowHandler { let error_payload_str = serde_json::json!(error_payload).to_string(); // prepare signed message - DKNMessage::new_signed( + DriaMessage::new_signed( error_payload_str, Self::RESPONSE_TOPIC, 
&node.config.secret_key, @@ -167,7 +168,7 @@ impl WorkflowHandler { "taskId": task.task_id, "error": err_msg, }); - let message = DKNMessage::new_signed( + let message = DriaMessage::new_signed( payload.to_string(), Self::RESPONSE_TOPIC, &node.config.secret_key, diff --git a/compute/src/lib.rs b/compute/src/lib.rs index 696eb29..c554054 100644 --- a/compute/src/lib.rs +++ b/compute/src/lib.rs @@ -9,5 +9,5 @@ pub(crate) mod workers; /// This value is attached within the published messages. pub const DRIA_COMPUTE_NODE_VERSION: &str = env!("CARGO_PKG_VERSION"); -pub use config::{DriaComputeNodeConfig, DriaNetworkType}; +pub use config::DriaComputeNodeConfig; pub use node::DriaComputeNode; diff --git a/compute/src/node.rs b/compute/src/node.rs index bc8851b..eaed9db 100644 --- a/compute/src/node.rs +++ b/compute/src/node.rs @@ -3,7 +3,7 @@ use dkn_p2p::{ gossipsub::{Message, MessageAcceptance, MessageId}, PeerId, }, - DriaP2PClient, DriaP2PCommander, DriaP2PProtocol, + DriaNodes, DriaP2PClient, DriaP2PCommander, DriaP2PProtocol, }; use eyre::Result; use std::collections::HashSet; @@ -13,7 +13,7 @@ use tokio_util::{either::Either, sync::CancellationToken}; use crate::{ config::*, handlers::*, - utils::{crypto::secret_to_keypair, AvailableNodes, DKNMessage}, + utils::{crypto::secret_to_keypair, refresh_dria_nodes, DriaMessage}, workers::workflow::{WorkflowsWorker, WorkflowsWorkerInput, WorkflowsWorkerOutput}, }; @@ -26,7 +26,8 @@ const PUBLISH_CHANNEL_BUFSIZE: usize = 1024; pub struct DriaComputeNode { pub config: DriaComputeNodeConfig, - pub available_nodes: AvailableNodes, + /// Pre-defined nodes that belong to Dria, e.g. bootstraps, relays and RPCs. + pub dria_nodes: DriaNodes, /// Peer-to-peer client commander to interact with the network. pub p2p: DriaP2PCommander, /// Gossipsub message receiver, used by peer-to-peer client in a separate thread. 
@@ -63,10 +64,10 @@ impl DriaComputeNode { let keypair = secret_to_keypair(&config.secret_key); // get available nodes (bootstrap, relay, rpc) for p2p - let mut available_nodes = AvailableNodes::new(config.network_type); - available_nodes.populate_with_statics(); - available_nodes.populate_with_env(); - if let Err(e) = available_nodes.populate_with_api().await { + let mut available_nodes = DriaNodes::new(config.network_type) + .with_statics() + .with_envs(); + if let Err(e) = refresh_dria_nodes(&mut available_nodes).await { log::error!("Error populating available nodes: {:?}", e); }; @@ -81,7 +82,7 @@ impl DriaComputeNode { config.p2p_listen_addr.clone(), available_nodes.bootstrap_nodes.clone().into_iter(), available_nodes.relay_nodes.clone().into_iter(), - available_nodes.rpc_addrs.clone().into_iter(), + available_nodes.rpc_nodes.clone().into_iter(), protocol, )?; @@ -110,7 +111,7 @@ impl DriaComputeNode { DriaComputeNode { config, p2p: p2p_commander, - available_nodes, + dria_nodes: available_nodes, message_rx, publish_rx, workflow_batch_tx, @@ -160,7 +161,7 @@ impl DriaComputeNode { /// /// Internally, identity is attached to the the message which is then JSON serialized to bytes /// and then published to the network as is. 
- pub async fn publish(&mut self, mut message: DKNMessage) -> Result<()> { + pub async fn publish(&mut self, mut message: DriaMessage) -> Result<()> { // attach protocol name to the message message = message.with_identity(self.p2p.protocol().name.clone()); @@ -203,17 +204,17 @@ impl DriaComputeNode { ); // ensure that message is from the known RPCs - if !self.available_nodes.rpc_nodes.contains(&source_peer_id) { + if !self.dria_nodes.rpc_peerids.contains(&source_peer_id) { log::warn!( "Received message from unauthorized source: {}", source_peer_id ); - log::debug!("Allowed sources: {:#?}", self.available_nodes.rpc_nodes); + log::debug!("Allowed sources: {:#?}", self.dria_nodes.rpc_peerids); return MessageAcceptance::Ignore; } // parse the raw gossipsub message to a prepared DKN message - let message = match DKNMessage::try_from_gossipsub_message( + let message = match DriaMessage::try_from_gossipsub_message( &message, &self.config.admin_public_key, ) { @@ -414,20 +415,22 @@ impl DriaComputeNode { /// Updates the local list of available nodes by refreshing it. /// Dials the RPC nodes again for better connectivity. 
async fn handle_available_nodes_refresh(&mut self) { - log::info!("Refreshing available nodes."); + log::info!("Refreshing available Dria nodes."); // refresh available nodes - if let Err(e) = self.available_nodes.populate_with_api().await { + if let Err(e) = refresh_dria_nodes(&mut self.dria_nodes).await { log::error!("Error refreshing available nodes: {:?}", e); }; // dial all rpc nodes - for rpc_addr in self.available_nodes.rpc_addrs.iter() { - log::debug!("Dialling RPC node: {}", rpc_addr); + for rpc_addr in self.dria_nodes.rpc_nodes.iter() { + log::info!("Dialling RPC node: {}", rpc_addr); if let Err(e) = self.p2p.dial(rpc_addr.clone()).await { log::warn!("Error dialling RPC node: {:?}", e); }; } + + log::info!("Finished refreshing!"); } } @@ -462,7 +465,7 @@ mod tests { // publish a dummy message let topic = "foo"; - let message = DKNMessage::new("hello from the other side", topic); + let message = DriaMessage::new("hello from the other side", topic); node.subscribe(topic).await.expect("should subscribe"); node.publish(message).await.expect("should publish"); node.unsubscribe(topic).await.expect("should unsubscribe"); diff --git a/compute/src/payloads/stats.rs b/compute/src/payloads/stats.rs index 3263cb8..9fcae51 100644 --- a/compute/src/payloads/stats.rs +++ b/compute/src/payloads/stats.rs @@ -1,8 +1,7 @@ +use dkn_utils::get_current_time_nanos; use serde::{Deserialize, Serialize}; use std::time::Instant; -use crate::utils::get_current_time_nanos; - /// Task stats for diagnostics. /// Returning this as the payload helps to debug the errors received at client side, and latencies. 
#[derive(Default, Debug, Clone, Serialize, Deserialize)] diff --git a/compute/src/utils/available_nodes/mod.rs b/compute/src/utils/available_nodes/mod.rs deleted file mode 100644 index 620870e..0000000 --- a/compute/src/utils/available_nodes/mod.rs +++ /dev/null @@ -1,143 +0,0 @@ -use dkn_p2p::libp2p::{Multiaddr, PeerId}; -use dkn_workflows::split_csv_line; -use eyre::Result; -use std::{collections::HashSet, env, fmt::Debug, str::FromStr}; -use tokio::time::Instant; - -mod statics; - -use crate::DriaNetworkType; - -impl DriaNetworkType { - /// Returns the URL for fetching available nodes w.r.t network type. - pub fn get_available_nodes_url(&self) -> &str { - match self { - DriaNetworkType::Community => "https://dkn.dria.co/available-nodes", - DriaNetworkType::Pro => "https://dkn.dria.co/sdk/available-nodes", - } - } -} -/// Available nodes within the hybrid P2P network. -/// -/// - Bootstrap: used for Kademlia DHT bootstrap. -/// - Relay: used for DCutR relay protocol. -/// - RPC: used for RPC nodes for task & ping messages. -/// -/// Note that while bootstrap & relay nodes are `Multiaddr`, RPC nodes are `PeerId` because we communicate -/// with them via GossipSub only. -#[derive(Debug, Clone)] -pub struct AvailableNodes { - pub bootstrap_nodes: HashSet, - pub relay_nodes: HashSet, - pub rpc_nodes: HashSet, - pub rpc_addrs: HashSet, - pub last_refreshed: Instant, - pub network_type: DriaNetworkType, -} - -impl AvailableNodes { - /// Creates a new `AvailableNodes` struct for the given network type. - pub fn new(network: DriaNetworkType) -> Self { - Self { - bootstrap_nodes: HashSet::new(), - relay_nodes: HashSet::new(), - rpc_nodes: HashSet::new(), - rpc_addrs: HashSet::new(), - last_refreshed: Instant::now(), - network_type: network, - } - } - - /// Parses static bootstrap & relay nodes from environment variables. 
- /// - /// The environment variables are: - /// - `DRIA_BOOTSTRAP_NODES`: comma-separated list of bootstrap nodes - /// - `DRIA_RELAY_NODES`: comma-separated list of relay nodes - pub fn populate_with_env(&mut self) { - // parse bootstrap nodes - let bootstrap_nodes = split_csv_line(&env::var("DKN_BOOTSTRAP_NODES").unwrap_or_default()); - if bootstrap_nodes.is_empty() { - log::debug!("No additional bootstrap nodes provided."); - } else { - log::debug!("Using additional bootstrap nodes: {:#?}", bootstrap_nodes); - } - self.bootstrap_nodes.extend(parse_vec(bootstrap_nodes)); - - // parse relay nodes - let relay_nodes = split_csv_line(&env::var("DKN_RELAY_NODES").unwrap_or_default()); - if relay_nodes.is_empty() { - log::debug!("No additional relay nodes provided."); - } else { - log::debug!("Using additional relay nodes: {:#?}", relay_nodes); - } - self.relay_nodes.extend(parse_vec(relay_nodes)); - } - - /// Adds the static nodes to the struct, with respect to network type. - pub fn populate_with_statics(&mut self) { - self.bootstrap_nodes - .extend(self.network_type.get_static_bootstrap_nodes()); - self.relay_nodes - .extend(self.network_type.get_static_relay_nodes()); - self.rpc_nodes - .extend(self.network_type.get_static_rpc_peer_ids()); - } - - /// Refresh available nodes using the API. 
- pub async fn populate_with_api(&mut self) -> Result<()> { - #[derive(serde::Deserialize, Default, Debug)] - struct AvailableNodesApiResponse { - pub bootstraps: Vec, - pub relays: Vec, - pub rpcs: Vec, - #[serde(rename = "rpcAddrs")] - pub rpc_addrs: Vec, - } - - // make the request w.r.t network type - let response = reqwest::get(self.network_type.get_available_nodes_url()).await?; - let response_body = response.json::().await?; - self.bootstrap_nodes - .extend(parse_vec(response_body.bootstraps)); - self.relay_nodes.extend(parse_vec(response_body.relays)); - self.rpc_addrs.extend(parse_vec(response_body.rpc_addrs)); - self.rpc_nodes - .extend(parse_vec::(response_body.rpcs)); - self.last_refreshed = Instant::now(); - - Ok(()) - } -} - -/// Like `parse` of `str` but for vectors. -fn parse_vec(input: Vec + Debug>) -> Vec -where - T: FromStr, -{ - let parsed = input - .iter() - .filter_map(|s| s.as_ref().parse::().ok()) - .collect::>(); - - if parsed.len() != input.len() { - log::warn!("Some inputs could not be parsed: {:?}", input); - } - parsed -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - #[ignore = "run this manually"] - async fn test_get_available_nodes() { - let mut available_nodes = AvailableNodes::new(DriaNetworkType::Community); - available_nodes.populate_with_api().await.unwrap(); - println!("Community: {:#?}", available_nodes); - - let mut available_nodes = AvailableNodes::new(DriaNetworkType::Pro); - available_nodes.populate_with_api().await.unwrap(); - println!("Pro: {:#?}", available_nodes); - } -} diff --git a/compute/src/utils/available_nodes/statics.rs b/compute/src/utils/available_nodes/statics.rs deleted file mode 100644 index d26f079..0000000 --- a/compute/src/utils/available_nodes/statics.rs +++ /dev/null @@ -1,50 +0,0 @@ -use crate::DriaNetworkType; -use dkn_p2p::libp2p::{Multiaddr, PeerId}; - -impl DriaNetworkType { - /// Static bootstrap nodes for Kademlia. 
- #[inline(always)] - pub fn get_static_bootstrap_nodes(&self) -> Vec { - match self { - DriaNetworkType::Community => [ - "/ip4/44.206.245.139/tcp/4001/p2p/16Uiu2HAm4q3LZU2T9kgjKK4ysy6KZYKLq8KiXQyae4RHdF7uqSt4", - "/ip4/18.234.39.91/tcp/4001/p2p/16Uiu2HAmJqegPzwuGKWzmb5m3RdSUJ7NhEGWB5jNCd3ca9zdQ9dU", - "/ip4/54.242.44.217/tcp/4001/p2p/16Uiu2HAmR2sAoh9F8jT9AZup9y79Mi6NEFVUbwRvahqtWamfabkz", - "/ip4/52.201.242.227/tcp/4001/p2p/16Uiu2HAmFEUCy1s1gjyHfc8jey4Wd9i5bSDnyFDbWTnbrF2J3KFb", - ].iter(), - DriaNetworkType::Pro => [].iter(), - } - .map(|s| s.parse().expect("could not parse static bootstrap address")) - .collect() - } - - /// Static relay nodes for the `P2pCircuit`. - #[inline(always)] - pub fn get_static_relay_nodes(&self) -> Vec { - match self { - DriaNetworkType::Community => [ - "/ip4/34.201.33.141/tcp/4001/p2p/16Uiu2HAkuXiV2CQkC9eJgU6cMnJ9SMARa85FZ6miTkvn5fuHNufa", - "/ip4/18.232.93.227/tcp/4001/p2p/16Uiu2HAmHeGKhWkXTweHJTA97qwP81ww1W2ntGaebeZ25ikDhd4z", - "/ip4/54.157.219.194/tcp/4001/p2p/16Uiu2HAm7A5QVSy5FwrXAJdNNsdfNAcaYahEavyjnFouaEi22dcq", - "/ip4/54.88.171.104/tcp/4001/p2p/16Uiu2HAm5WP1J6bZC3aHxd7XCUumMt9txAystmbZSaMS2omHepXa", - ].iter(), - DriaNetworkType::Pro => [].iter(), - } - .map(|s| s.parse().expect("could not parse static relay address")) - .collect() - } - - /// Static RPC Peer IDs for the Admin RPC. 
- #[inline(always)] - pub fn get_static_rpc_peer_ids(&self) -> Vec { - // match self { - // DriaNetworkType::Community => [].iter(), - // DriaNetworkType::Pro => [].iter(), - // } - // .filter_map(|s| s.parse().ok()) - // .collect() - vec![] - } -} - -// help me diff --git a/compute/src/utils/message.rs b/compute/src/utils/message.rs index bbc664d..c3f76f9 100644 --- a/compute/src/utils/message.rs +++ b/compute/src/utils/message.rs @@ -1,10 +1,8 @@ -use crate::utils::{ - crypto::{sha256hash, sign_bytes_recoverable}, - get_current_time_nanos, -}; +use crate::utils::crypto::{sha256hash, sign_bytes_recoverable}; use crate::DRIA_COMPUTE_NODE_VERSION; use base64::{prelude::BASE64_STANDARD, Engine}; use core::fmt; +use dkn_utils::get_current_time_nanos; use ecies::PublicKey; use eyre::{eyre, Context, Result}; use libsecp256k1::{verify, Message, SecretKey, Signature}; @@ -12,7 +10,7 @@ use serde::{Deserialize, Serialize}; /// A message within Dria Knowledge Network. #[derive(Serialize, Deserialize, Debug, Clone)] -pub struct DKNMessage { +pub struct DriaMessage { /// Base64 encoded payload, stores the main result. pub(crate) payload: String, /// The topic of the message, derived from `TopicHash` @@ -39,7 +37,7 @@ pub struct DKNMessage { /// and therefore use 128 characters: SIGNATURE_SIZE - 2. const SIGNATURE_SIZE_HEX: usize = 130; -impl DKNMessage { +impl DriaMessage { /// Creates a new message with current timestamp and version equal to the crate version. /// /// - `data` is given as bytes, it is encoded into base64 to make up the `payload` within. @@ -127,7 +125,7 @@ impl DKNMessage { ) -> Result { // the received message is expected to use IdentHash for the topic, so we can see the name of the topic immediately. 
log::debug!("Parsing {} message.", gossipsub_message.topic.as_str()); - let message = serde_json::from_slice::(&gossipsub_message.data) + let message = serde_json::from_slice::(&gossipsub_message.data) .wrap_err("could not parse message")?; log::debug!("Parsed: {}", message); @@ -141,7 +139,7 @@ impl DKNMessage { } } -impl fmt::Display for DKNMessage { +impl fmt::Display for DriaMessage { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let payload_decoded = self .decode_payload() @@ -156,7 +154,7 @@ impl fmt::Display for DKNMessage { } } -impl TryFrom<&dkn_p2p::libp2p::gossipsub::Message> for DKNMessage { +impl TryFrom<&dkn_p2p::libp2p::gossipsub::Message> for DriaMessage { type Error = serde_json::Error; fn try_from(value: &dkn_p2p::libp2p::gossipsub::Message) -> Result { @@ -187,7 +185,7 @@ mod tests { #[test] #[ignore = "run manually"] fn test_display_message() { - let message = DKNMessage::new(b"hello world", TOPIC); + let message = DriaMessage::new(b"hello world", TOPIC); println!("{}", message); } @@ -196,7 +194,7 @@ mod tests { // create payload & message let body = TestStruct::default(); let data = serde_json::to_vec(&body).expect("Should serialize"); - let message = DKNMessage::new(data, TOPIC); + let message = DriaMessage::new(data, TOPIC); // decode message let message_body = message.decode_payload().expect("Should decode"); @@ -223,7 +221,7 @@ mod tests { // create payload & message with signature & body let body = TestStruct::default(); let body_str = serde_json::to_string(&body).unwrap(); - let message = DKNMessage::new_signed(body_str, TOPIC, &sk); + let message = DriaMessage::new_signed(body_str, TOPIC, &sk); // decode message let message_body = message.decode_payload().expect("Should decode"); diff --git a/compute/src/utils/misc.rs b/compute/src/utils/misc.rs index d7179ef..7914caf 100644 --- a/compute/src/utils/misc.rs +++ b/compute/src/utils/misc.rs @@ -1,23 +1,6 @@ use dkn_p2p::libp2p::{multiaddr::Protocol, Multiaddr}; use 
port_check::is_port_reachable; -use std::{ - net::{Ipv4Addr, SocketAddrV4}, - time::SystemTime, -}; - -/// Returns the current time in nanoseconds since the Unix epoch. -/// -/// If a `SystemTimeError` occurs, will return 0 just to keep things running. -#[inline(always)] -pub fn get_current_time_nanos() -> u128 { - SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .unwrap_or_else(|e| { - log::error!("Error getting current time: {}", e); - Default::default() - }) - .as_nanos() -} +use std::net::{Ipv4Addr, SocketAddrV4}; /// Checks if a given address is already in use locally. /// This is mostly used to see if the P2P address is already in use. diff --git a/compute/src/utils/mod.rs b/compute/src/utils/mod.rs index 5002b14..9485dad 100644 --- a/compute/src/utils/mod.rs +++ b/compute/src/utils/mod.rs @@ -2,10 +2,10 @@ pub mod crypto; pub mod filter; mod message; -pub use message::DKNMessage; - -mod available_nodes; -pub use available_nodes::AvailableNodes; +pub use message::DriaMessage; mod misc; pub use misc::*; + +mod nodes; +pub use nodes::*; diff --git a/compute/src/utils/nodes.rs b/compute/src/utils/nodes.rs new file mode 100644 index 0000000..b00e4ac --- /dev/null +++ b/compute/src/utils/nodes.rs @@ -0,0 +1,69 @@ +use dkn_p2p::{libp2p::PeerId, DriaNetworkType, DriaNodes}; +use dkn_utils::parse_vec; +use eyre::Result; + +/// Refresh available nodes using the API. 
+pub async fn refresh_dria_nodes(nodes: &mut DriaNodes) -> Result<()> { + #[derive(serde::Deserialize, Default, Debug)] + struct AvailableNodesApiResponse { + pub bootstraps: Vec, + pub relays: Vec, + pub rpcs: Vec, + #[serde(rename = "rpcAddrs")] + pub rpc_addrs: Vec, + } + + // url to be used is determined by the network type + let url = match nodes.network_type { + DriaNetworkType::Community => "https://dkn.dria.co/available-nodes", + DriaNetworkType::Pro => "https://dkn.dria.co/sdk/available-nodes", + DriaNetworkType::Test => "https://dkn.dria.co/test/available-nodes", + }; + + // make the request + let response = reqwest::get(url).await?; + let response_body = response.json::().await?; + nodes + .bootstrap_nodes + .extend(parse_vec(response_body.bootstraps).unwrap_or_else(|e| { + log::error!("Failed to parse bootstrap nodes: {}", e); + vec![] + })); + nodes + .relay_nodes + .extend(parse_vec(response_body.relays).unwrap_or_else(|e| { + log::error!("Failed to parse relay nodes: {}", e); + vec![] + })); + nodes + .rpc_nodes + .extend(parse_vec(response_body.rpc_addrs).unwrap_or_else(|e| { + log::error!("Failed to parse rpc nodes: {}", e); + vec![] + })); + nodes + .rpc_peerids + .extend(parse_vec::(response_body.rpcs).unwrap_or_else(|e| { + log::error!("Failed to parse rpc peerids: {}", e); + vec![] + })); + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + #[ignore = "run this manually"] + async fn test_refresh_dria_nodes() { + let mut nodes = DriaNodes::new(DriaNetworkType::Community); + refresh_dria_nodes(&mut nodes).await.unwrap(); + println!("Community: {:#?}", nodes); + + let mut nodes = DriaNodes::new(DriaNetworkType::Pro); + refresh_dria_nodes(&mut nodes).await.unwrap(); + println!("Pro: {:#?}", nodes); + } +} diff --git a/p2p/Cargo.toml b/p2p/Cargo.toml index 7857512..caae8d5 100644 --- a/p2p/Cargo.toml +++ b/p2p/Cargo.toml @@ -34,5 +34,7 @@ eyre.workspace = true tokio-util.workspace = true tokio.workspace = true +dkn-utils = 
{ path = "../utils" } + [dev-dependencies] env_logger.workspace = true diff --git a/p2p/src/client.rs b/p2p/src/client.rs index c4a8eac..f54d222 100644 --- a/p2p/src/client.rs +++ b/p2p/src/client.rs @@ -95,7 +95,6 @@ impl DriaP2PClient { Protocol::P2p(peer_id) => Some(peer_id), _ => None, }) { - log::info!("Dialling peer: {}", addr); swarm.dial(addr.clone())?; log::info!("Adding {} to Kademlia routing table", addr); swarm.behaviour_mut().kademlia.add_address(&peer_id, addr); @@ -333,6 +332,7 @@ impl DriaP2PClient { // error // ); // } + // SwarmEvent::IncomingConnection { // connection_id, // local_addr, @@ -345,6 +345,7 @@ impl DriaP2PClient { // send_back_addr // ); // } + // SwarmEvent::OutgoingConnectionError { peer_id, error, .. } => { // if let Some(peer_id) = peer_id { // log::warn!("Could not connect to peer {}: {:?}", peer_id, error); diff --git a/p2p/src/lib.rs b/p2p/src/lib.rs index 2415c6c..74b8913 100644 --- a/p2p/src/lib.rs +++ b/p2p/src/lib.rs @@ -11,6 +11,12 @@ pub use commands::{DriaP2PCommand, DriaP2PCommander}; mod protocol; pub use protocol::DriaP2PProtocol; +mod network; +pub use network::DriaNetworkType; + +mod nodes; +pub use nodes::DriaNodes; + // re-exports pub use libp2p; pub use libp2p_identity; diff --git a/p2p/src/network.rs b/p2p/src/network.rs new file mode 100644 index 0000000..387f662 --- /dev/null +++ b/p2p/src/network.rs @@ -0,0 +1,78 @@ +use libp2p::{Multiaddr, PeerId}; + +/// Network type. +#[derive(Default, Debug, Clone, Copy)] +pub enum DriaNetworkType { + #[default] + Community, + Pro, + Test, +} + +impl From<&str> for DriaNetworkType { + fn from(s: &str) -> Self { + match s { + "community" => DriaNetworkType::Community, + "pro" => DriaNetworkType::Pro, + "test" => DriaNetworkType::Test, + _ => Default::default(), + } + } +} + +impl DriaNetworkType { + /// Returns the protocol name. 
+ pub fn protocol_name(&self) -> &str { + match self { + DriaNetworkType::Community => "dria", + DriaNetworkType::Pro => "dria-sdk", + DriaNetworkType::Test => "dria-test", + } + } + + /// Static bootstrap nodes for Kademlia. + #[inline(always)] + pub fn get_static_bootstrap_nodes(&self) -> Vec { + match self { + DriaNetworkType::Community => [ + "/ip4/44.206.245.139/tcp/4001/p2p/16Uiu2HAm4q3LZU2T9kgjKK4ysy6KZYKLq8KiXQyae4RHdF7uqSt4", + "/ip4/18.234.39.91/tcp/4001/p2p/16Uiu2HAmJqegPzwuGKWzmb5m3RdSUJ7NhEGWB5jNCd3ca9zdQ9dU", + "/ip4/54.242.44.217/tcp/4001/p2p/16Uiu2HAmR2sAoh9F8jT9AZup9y79Mi6NEFVUbwRvahqtWamfabkz", + "/ip4/52.201.242.227/tcp/4001/p2p/16Uiu2HAmFEUCy1s1gjyHfc8jey4Wd9i5bSDnyFDbWTnbrF2J3KFb", + ].iter(), + DriaNetworkType::Pro => [].iter(), + DriaNetworkType::Test => [].iter(), + } + .map(|s| s.parse().expect("could not parse static bootstrap address")) + .collect() + } + + /// Static relay nodes for the `P2pCircuit`. + #[inline(always)] + pub fn get_static_relay_nodes(&self) -> Vec { + match self { + DriaNetworkType::Community => [ + "/ip4/34.201.33.141/tcp/4001/p2p/16Uiu2HAkuXiV2CQkC9eJgU6cMnJ9SMARa85FZ6miTkvn5fuHNufa", + "/ip4/18.232.93.227/tcp/4001/p2p/16Uiu2HAmHeGKhWkXTweHJTA97qwP81ww1W2ntGaebeZ25ikDhd4z", + "/ip4/54.157.219.194/tcp/4001/p2p/16Uiu2HAm7A5QVSy5FwrXAJdNNsdfNAcaYahEavyjnFouaEi22dcq", + "/ip4/54.88.171.104/tcp/4001/p2p/16Uiu2HAm5WP1J6bZC3aHxd7XCUumMt9txAystmbZSaMS2omHepXa", + ].iter(), + DriaNetworkType::Pro => [].iter(), + DriaNetworkType::Test => [].iter(), + } + .map(|s| s.parse().expect("could not parse static relay address")) + .collect() + } + + /// Static RPC Peer IDs. 
+ #[inline(always)] + pub fn get_static_rpc_peer_ids(&self) -> Vec { + // match self { + // DriaNetworkType::Community => [].iter(), + // DriaNetworkType::Pro => [].iter(), + // } + // .filter_map(|s| s.parse().ok()) + // .collect() + vec![] + } +} diff --git a/p2p/src/nodes.rs b/p2p/src/nodes.rs new file mode 100644 index 0000000..627ebda --- /dev/null +++ b/p2p/src/nodes.rs @@ -0,0 +1,72 @@ +use crate::DriaNetworkType; +use dkn_utils::{parse_vec, split_csv_line}; +use libp2p::{Multiaddr, PeerId}; +use std::{collections::HashSet, env, fmt::Debug}; + +/// Dria-owned nodes within the hybrid P2P network. +/// +/// - Bootstrap: used for Kademlia DHT bootstrap. +/// - Relay: used for DCutR relay protocol. +/// - RPC: used for RPC nodes for task & ping messages. +#[derive(Debug, Clone)] +pub struct DriaNodes { + pub bootstrap_nodes: HashSet, + pub relay_nodes: HashSet, + pub rpc_nodes: HashSet, + pub rpc_peerids: HashSet, + pub network_type: DriaNetworkType, +} + +impl DriaNodes { + /// Creates a new `AvailableNodes` struct for the given network type. + pub fn new(network: DriaNetworkType) -> Self { + Self { + bootstrap_nodes: HashSet::new(), + relay_nodes: HashSet::new(), + rpc_peerids: HashSet::new(), + rpc_nodes: HashSet::new(), + network_type: network, + } + } + + /// Parses static bootstrap & relay nodes from environment variables. 
+ /// + /// The environment variables are: + /// - `DRIA_BOOTSTRAP_NODES`: comma-separated list of bootstrap nodes + /// - `DRIA_RELAY_NODES`: comma-separated list of relay nodes + pub fn with_envs(mut self) -> Self { + // parse bootstrap nodes + let bootstrap_nodes = split_csv_line(&env::var("DKN_BOOTSTRAP_NODES").unwrap_or_default()); + if bootstrap_nodes.is_empty() { + log::debug!("No additional bootstrap nodes provided."); + } else { + log::debug!("Using additional bootstrap nodes: {:#?}", bootstrap_nodes); + } + self.bootstrap_nodes + .extend(parse_vec(bootstrap_nodes).expect("could not parse bootstrap nodes")); + + // parse relay nodes + let relay_nodes = split_csv_line(&env::var("DKN_RELAY_NODES").unwrap_or_default()); + if relay_nodes.is_empty() { + log::debug!("No additional relay nodes provided."); + } else { + log::debug!("Using additional relay nodes: {:#?}", relay_nodes); + } + self.relay_nodes + .extend(parse_vec(relay_nodes).expect("could not parse relay nodes")); + + self + } + + /// Adds the static nodes to the struct, with respect to network type. + pub fn with_statics(mut self) -> Self { + self.bootstrap_nodes + .extend(self.network_type.get_static_bootstrap_nodes()); + self.relay_nodes + .extend(self.network_type.get_static_relay_nodes()); + self.rpc_peerids + .extend(self.network_type.get_static_rpc_peer_ids()); + + self + } +} diff --git a/utils/Cargo.toml b/utils/Cargo.toml new file mode 100644 index 0000000..0f99c15 --- /dev/null +++ b/utils/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "dkn-utils" +version.workspace = true +edition.workspace = true +license.workspace = true +readme = "README.md" +authors = ["Erhan Tezcan "] + +[dependencies] diff --git a/utils/README.md b/utils/README.md new file mode 100644 index 0000000..9131908 --- /dev/null +++ b/utils/README.md @@ -0,0 +1,19 @@ +# Dria Utils + +Just small utility functions such as reading environment variables or splitting strings etc. 
+ +## Installation + +Add the package via `git` within your Cargo dependencies: + +```toml +dkn-utils = { git = "https://github.com/firstbatchxyz/dkn-compute-node" } +``` + +## Usage + +```rs +use dkn_utils::*; + +// use whatever you like! +``` diff --git a/workflows/src/utils.rs b/utils/src/lib.rs similarity index 71% rename from workflows/src/utils.rs rename to utils/src/lib.rs index 2e7b800..9d4c4cd 100644 --- a/workflows/src/utils.rs +++ b/utils/src/lib.rs @@ -1,3 +1,5 @@ +use std::{fmt::Debug, str::FromStr, time::SystemTime}; + /// Utility to parse comma-separated string value line. /// /// - Trims `"` from both ends for the input @@ -25,6 +27,29 @@ pub fn safe_read_env(var: Result) -> Option .filter(|s| !s.is_empty()) } +/// Like `parse` of `str` but for vectors. +pub fn parse_vec(input: Vec + Debug>) -> Result, T::Err> +where + T: FromStr, +{ + let parsed = input + .iter() + .map(|s| s.as_ref().parse::()) + .collect::, _>>()?; + + Ok(parsed) +} + +/// Returns the current time in nanoseconds since the Unix epoch. +/// +/// If a `SystemTimeError` occurs, will return 0 just to keep things running. 
+pub fn get_current_time_nanos() -> u128 { + SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap_or_default() + .as_nanos() +} + #[cfg(test)] mod tests { use super::*; diff --git a/workflows/Cargo.toml b/workflows/Cargo.toml index a6f26ae..fcf946a 100644 --- a/workflows/Cargo.toml +++ b/workflows/Cargo.toml @@ -25,6 +25,7 @@ reqwest.workspace = true # utilities rand.workspace = true +dkn-utils = { path = "../utils" } # logging & errors log.workspace = true @@ -42,4 +43,4 @@ dotenvy.workspace = true [[bin]] name = "tps" -path = "src/bin/tps.rs" \ No newline at end of file +path = "src/bin/tps.rs" diff --git a/workflows/src/apis/jina.rs b/workflows/src/apis/jina.rs index 7492c30..2949026 100644 --- a/workflows/src/apis/jina.rs +++ b/workflows/src/apis/jina.rs @@ -1,9 +1,8 @@ +use dkn_utils::safe_read_env; use eyre::{eyre, Context, Result}; use reqwest::Client; use std::env; -use crate::utils::safe_read_env; - const ENV_VAR_NAME: &str = "JINA_API_KEY"; /// Jina-specific configurations. diff --git a/workflows/src/apis/serper.rs b/workflows/src/apis/serper.rs index 0fb5cde..f2ebefe 100644 --- a/workflows/src/apis/serper.rs +++ b/workflows/src/apis/serper.rs @@ -1,9 +1,8 @@ +use dkn_utils::safe_read_env; use eyre::{eyre, Context, Result}; use reqwest::Client; use std::env; -use crate::utils::safe_read_env; - const ENV_VAR_NAME: &str = "SERPER_API_KEY"; /// Serper-specific configurations. 
diff --git a/workflows/src/config.rs b/workflows/src/config.rs index e26f1df..0b38fd8 100644 --- a/workflows/src/config.rs +++ b/workflows/src/config.rs @@ -1,8 +1,9 @@ use crate::{ apis::{JinaConfig, SerperConfig}, providers::{GeminiConfig, OllamaConfig, OpenAIConfig, OpenRouterConfig}, - split_csv_line, Model, ModelProvider, + Model, ModelProvider, }; +use dkn_utils::split_csv_line; use eyre::{eyre, Result}; use rand::seq::IteratorRandom; // provides Vec<_>.choose diff --git a/workflows/src/lib.rs b/workflows/src/lib.rs index 76ecd2e..91d006c 100644 --- a/workflows/src/lib.rs +++ b/workflows/src/lib.rs @@ -3,9 +3,6 @@ pub use providers::OllamaConfig; mod apis; -mod utils; -pub use utils::split_csv_line; - mod config; pub use config::DriaWorkflowsConfig; diff --git a/workflows/src/providers/gemini.rs b/workflows/src/providers/gemini.rs index 4f694bf..e1de846 100644 --- a/workflows/src/providers/gemini.rs +++ b/workflows/src/providers/gemini.rs @@ -1,11 +1,10 @@ +use dkn_utils::safe_read_env; use eyre::{eyre, Context, Result}; use ollama_workflows::Model; use reqwest::Client; use serde::Deserialize; use std::env; -use crate::utils::safe_read_env; - const ENV_VAR_NAME: &str = "GEMINI_API_KEY"; /// OpenAI-specific configurations. diff --git a/workflows/src/providers/openai.rs b/workflows/src/providers/openai.rs index 0ff1c11..72fc581 100644 --- a/workflows/src/providers/openai.rs +++ b/workflows/src/providers/openai.rs @@ -1,11 +1,10 @@ +use dkn_utils::safe_read_env; use eyre::{eyre, Context, Result}; use ollama_workflows::Model; use reqwest::Client; use serde::Deserialize; use std::env; -use crate::utils::safe_read_env; - const ENV_VAR_NAME: &str = "OPENAI_API_KEY"; /// OpenAI-specific configurations. 
diff --git a/workflows/src/providers/openrouter.rs b/workflows/src/providers/openrouter.rs index a1c5cd8..2b7e278 100644 --- a/workflows/src/providers/openrouter.rs +++ b/workflows/src/providers/openrouter.rs @@ -1,10 +1,9 @@ +use dkn_utils::safe_read_env; use eyre::{eyre, Context, Result}; use ollama_workflows::Model; use reqwest::Client; use std::env; -use crate::utils::safe_read_env; - const ENV_VAR_NAME: &str = "OPENROUTER_API_KEY"; /// OpenRouter-specific configurations. From 6607cef3f5eac3fccb52dd17b8c47b48891a4c1c Mon Sep 17 00:00:00 2001 From: erhant Date: Mon, 2 Dec 2024 19:19:53 +0300 Subject: [PATCH 15/16] fix env logger init --- Makefile | 2 +- compute/src/main.rs | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index f0a7baa..36e911e 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,7 @@ launch: .PHONY: run # | Run with INFO logs run: - RUST_LOG=none,dkn_compute=info,dkn_workflows=info,dkn_p2p=info cargo run + cargo run .PHONY: debug # | Run with DEBUG logs with INFO log-level workflows debug: diff --git a/compute/src/main.rs b/compute/src/main.rs index c0028c3..7d9e24d 100644 --- a/compute/src/main.rs +++ b/compute/src/main.rs @@ -9,6 +9,11 @@ async fn main() -> Result<()> { env_logger::builder() .format_timestamp(Some(env_logger::TimestampPrecision::Millis)) + .filter(None, log::LevelFilter::Off) + .filter_module("dkn_compute", log::LevelFilter::Info) + .filter_module("dkn_p2p", log::LevelFilter::Info) + .filter_module("dkn_workflows", log::LevelFilter::Info) + .parse_default_env() // reads RUST_LOG variable .init(); if let Err(e) = dotenv_result { log::warn!("could not load .env file: {}", e); From 113aa65a44d334a9d67b107e0080f512d20b882e Mon Sep 17 00:00:00 2001 From: erhant Date: Mon, 2 Dec 2024 21:33:08 +0300 Subject: [PATCH 16/16] update workflows --- Cargo.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index af06e03..7ae991a 100644 --- a/Cargo.lock +++ 
b/Cargo.lock @@ -3455,7 +3455,7 @@ dependencies = [ [[package]] name = "ollama-workflows" version = "0.1.0" -source = "git+https://github.com/andthattoo/ollama-workflows#3f364cd92501d8fb065ef9b585f237809d0bddf5" +source = "git+https://github.com/andthattoo/ollama-workflows#5b26cb4fdd5278295534cbedf79acdf4f9cd2b37" dependencies = [ "async-trait", "base64 0.22.1",