diff --git a/Cargo.lock b/Cargo.lock index c7ab058..2f4e246 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -40,6 +40,18 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "aho-corasick" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.3" @@ -201,6 +213,12 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "bitflags" version = "1.3.2" @@ -478,26 +496,6 @@ dependencies = [ "typenum", ] -[[package]] -name = "cstr" -version = "0.2.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68523903c8ae5aacfa32a0d9ae60cadeb764e1da14ee0d26b1f3089f13a54636" -dependencies = [ - "proc-macro2", - "quote", -] - -[[package]] -name = "ctrlc" -version = "3.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "672465ae37dc1bc6380a6547a8883d5dd397b0f1faaad4f265726cc7042a5345" -dependencies = [ - "nix", - "windows-sys 0.52.0", -] - [[package]] name = "darling" version = "0.20.9" @@ -882,7 +880,7 @@ dependencies = [ name = "hashbrown" version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" dependencies = [ "ahash 0.8.11", "allocator-api2", @@ -894,7 +892,7 @@ version = "7.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "765c9198f173dd59ce26ff9f95ef0aafd0a0fe01fb9d72841bc5066a4c06511d" dependencies = [ - "base64", + "base64 0.21.7", "byteorder", "crossbeam-channel", "flate2", @@ -904,9 +902,9 @@ dependencies = [ [[package]] name = "heck" -version = "0.5.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "hermit-abi" @@ -985,12 +983,6 @@ dependencies = [ "png", ] -[[package]] -name = "is_terminal_polyfill" -version = "1.70.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800" - [[package]] name = "itertools" version = "0.11.0" @@ -1002,9 +994,9 @@ dependencies = [ [[package]] name = "itertools" -version = "0.12.1" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" dependencies = [ "either", ] @@ -1058,20 +1050,19 @@ dependencies = [ [[package]] name = "latte-cli" -version = "0.25.2-scylladb" +version = "0.26.1-scylladb" dependencies = [ "anyhow", - "base64", + "base64 0.22.1", "chrono", "clap", "console", "cpu-time", - "ctrlc", "err-derive", "futures", "hdrhistogram", "hytra", - "itertools 0.12.1", + "itertools 0.13.0", "jemallocator", "lazy_static", "metrohash", @@ -1224,9 +1215,9 @@ dependencies = [ [[package]] name = "nalgebra" -version = "0.29.0" +version = "0.32.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d506eb7e08d6329505faa8a3a00a5dcc6de9f76e0c77e4b75763ae3c770831ff" +checksum = "7b5c17de023a86f59ed79891b2e5d5a94c705dbe904a5b5c9c952ea6221b03e4" dependencies = [ "approx", "matrixmultiply", @@ -1242,27 +1233,15 @@ dependencies = [ [[package]] name = "nalgebra-macros" -version = "0.1.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01fcc0b8149b4632adc89ac3b7b31a12fb6099a0317a4eb2ebff574ef7de7218" +checksum = "91761aed67d03ad966ef783ae962ef9bbaca728d2dd7ceb7939ec110fffad998" dependencies = [ "proc-macro2", "quote", "syn 1.0.109", ] -[[package]] -name = "nix" -version = "0.28.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab2156c4fce2f8df6c499cc1c763e4394b7482525bf2a9701c9d79d215f519e4" -dependencies = [ - "bitflags 2.5.0", - "cfg-if", - "cfg_aliases", - "libc", -] - [[package]] name = "nom" version = "7.1.3" @@ -1971,7 +1950,7 @@ dependencies = [ "chrono", "dashmap", "futures", - "hashbrown 0.14.5", + "hashbrown 0.14.3", "histogram", "itertools 0.11.0", "lazy_static", @@ -2092,11 +2071,20 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "signal-hook-registry" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +dependencies = [ + "libc", +] + [[package]] name = "simba" -version = "0.6.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0b7840f121a46d63066ee7a99fc81dcabbc6105e437cae43528cea199b5a05f" +checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae" dependencies = [ "approx", "num-complex 0.4.6", @@ -2153,12 +2141,11 @@ checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" [[package]] name = "statrs" -version = "0.16.1" +version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b35a062dbadac17a42e0fc64c27f419b25d6fae98572eb43c8814c9e873d7721" +checksum = "f697a07e4606a0a25c044de247e583a330dbb1731d11bc7350b81f48ad567255" dependencies = [ "approx", - "lazy_static", "nalgebra", "num-traits", "rand", @@ -2182,24 +2169,24 @@ checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" [[package]] name = "strum" -version = "0.26.2" +version = "0.26.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29" +checksum = "723b93e8addf9aa965ebe2d11da6d7540fa2283fcea14b3371ff055f7ba13f5f" dependencies = [ "strum_macros", ] [[package]] name = "strum_macros" -version = "0.26.4" +version = "0.26.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" +checksum = "7a3417fc93d76740d974a01654a09777cb500428cc874ca9f45edfe0c4d4cd18" dependencies = [ "heck", "proc-macro2", "quote", "rustversion", - "syn 2.0.66", + "syn 2.0.50", ] [[package]] @@ -2307,6 +2294,7 @@ dependencies = [ "num_cpus", "parking_lot", "pin-project-lite", + "signal-hook-registry", "socket2", "tokio-macros", "windows-sys 0.48.0", @@ -2755,7 +2743,7 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" name = "windows_x86_64_msvc" version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" +checksum = "0770833d60a970638e989b3fa9fd2bb1aaadcf88963d1659fd7d9990196ed2d6" [[package]] name = "wio" @@ -2795,5 +2783,5 @@ checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.50", ] diff --git a/Cargo.toml b/Cargo.toml index e5093a2..1d40286 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "latte-cli" description = "A database benchmarking tool for Apache Cassandra" -version = "0.25.2-scylladb" +version = "0.26.1-scylladb" authors = ["Piotr Kołaczkowski "] edition = "2021" readme = "README.md" @@ -14,19 +14,18 @@ path = "src/main.rs" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] anyhow = "1.0" -base64 = "0.21" +base64 = "0.22" rmp = "0.8.10" rmp-serde = "1.0.0-beta.2" chrono = { version = "0.4.18", features = ["serde"] } clap = { version = "4", features = ["derive", "cargo", "env"] } console = "0.15.0" cpu-time = "1.0.0" -ctrlc = "3.2.1" err-derive = "0.3" futures = "0.3" hdrhistogram = "7.1.0" hytra = "0.1.2" -itertools = "0.12" +itertools = "0.13" jemallocator = "0.5" lazy_static = "1.4.0" metrohash = "1.0" @@ -43,13 +42,13 @@ scylla = { version = "0.13", features = ["ssl"] } search_path = "0.1" serde = { version = "1.0.116", features = ["derive"] } serde_json = "1.0.57" -statrs = "0.16" +statrs = "0.17" status-line = "0.2.0" strum = { version = "0.26", features = ["derive"] } strum_macros = "0.26" time = "0.3" thiserror = "1.0.26" -tokio = { version = "1", features = ["rt", "rt-multi-thread", "time", "parking_lot"] } +tokio = { version = "1", features = ["rt", "rt-multi-thread", "time", "parking_lot", "signal"] } tokio-stream = "0.1" tracing = "0.1" tracing-subscriber = "0.3" diff --git a/README.md b/README.md index 84d89d3..cde656b 100644 --- a/README.md +++ b/README.md @@ -170,6 +170,19 @@ pub async fn run(ctx, i) { } ``` +Query parameters can be bound and passed by names as well: +```rust +const INSERT = "my_insert"; + +pub async fn prepare(ctx) { + ctx.prepare(INSERT, "INSERT INTO test.test(id, data) VALUES (:id, :data)").await?; +} + +pub async fn run(ctx, i) { + ctx.execute_prepared(INSERT, #{id: 5, data: "foo"}).await +} +``` + ### Populating the database Read queries are more interesting when they return non-empty result sets. @@ -209,18 +222,20 @@ are pure, i.e. invoking them multiple times with the same parameters yields alwa - `latte::hash_select(i, vector)` – selects an item from a vector based on a hash - `latte::blob(i, len)` – generates a random binary blob of length `len` - `latte::normal(i, mean, std_dev)` – generates a floating point number from a normal distribution - -#### Numeric conversions - -Rune represents integers as 64-bit signed values. Therefore, it is possible to directly pass a Rune integer to -a Cassandra column of type `bigint`. However, binding a 64-bit value to smaller integer column types, like -`int`, `smallint` or `tinyint` will result in a runtime error. As long as an integer value does not exceed the bounds, -you can convert it to smaller signed integer types by using the following instance functions: - -- `x.to_i32()` – converts a float or integer to a 32-bit signed integer, compatible with Cassandra `int` type -- `x.to_i16()` – converts a float or integer to a 16-bit signed integer, compatible with Cassandra `smallint` type -- `x.to_i8()` – converts a float or integer to an 8-bit signed integer, compatible with Cassandra `tinyint` type -- `x.clamp(min, max)` – restricts the range of an integer or a float value to given range +- `latte::uniform(i, min, max)` – generates a floating point number from a uniform distribution + +#### Type conversions +Rune uses 64-bit representation for integers and floats. +Since version 0.28 Rune numbers are automatically converted to proper target query parameter type, +therefore you don't need to do explicit conversions. E.g. you can pass an integer as a parameter +of Cassandra type `smallint`. If the number is too big to fit into the range allowed by the target +type, a runtime error will be signalled. + +The following methods are available: +- `x.to_integer()` – converts a float to an integer +- `x.to_float()` – converts an integer to a float +- `x.to_string()` – converts a float or integer to a string +- `x.clamp(min, max)` – restricts the range of an integer or a float value to given range You can also convert between floats and integers by calling `to_integer` or `to_float` instance functions. diff --git a/src/config.rs b/src/config.rs index d50f3dd..66256f2 100644 --- a/src/config.rs +++ b/src/config.rs @@ -97,7 +97,7 @@ impl RetryInterval { if values.len() > 2 { return None; } - let min_ms = RetryInterval::parse_time(values.get(0).unwrap_or(&""))?; + let min_ms = RetryInterval::parse_time(values.first().unwrap_or(&""))?; let max_ms = RetryInterval::parse_time(values.get(1).unwrap_or(&"")).unwrap_or(min_ms); if min_ms > max_ms { None @@ -136,7 +136,8 @@ impl FromStr for RetryInterval { Err(concat!( "Expected 1 or 2 parts separated by comma such as '500ms' or '200ms,5s' or '1s'.", " First value cannot be bigger than second one.", - ).to_string()) + ) + .to_string()) } } } @@ -194,9 +195,12 @@ pub struct ConnectionConf { #[clap(long("retry-number"), default_value = "10", value_name = "COUNT")] pub retry_number: u64, - #[clap(long("retry-interval"), default_value = "100ms,5s", value_name = "TIME[,TIME]")] + #[clap( + long("retry-interval"), + default_value = "100ms,5s", + value_name = "TIME[,TIME]" + )] pub retry_interval: RetryInterval, - } #[derive(Clone, Copy, Default, Debug, Eq, PartialEq, Serialize, Deserialize)] @@ -425,8 +429,8 @@ impl RunCommand { } /// Returns the value of parameter under given key. - /// If key doesn't exist, or parameter is not an integer, returns `None`. - pub fn get_param(&self, key: &str) -> Option { + /// If key doesn't exist, or parameter is not a number, returns `None`. + pub fn get_param(&self, key: &str) -> Option { self.params .iter() .find(|(k, _)| k == key) @@ -554,6 +558,7 @@ pub struct AppConfig { } #[derive(Debug, Deserialize, Default)] +#[allow(unused)] pub struct SchemaConfig { #[serde(default)] pub script: Vec, @@ -562,6 +567,7 @@ pub struct SchemaConfig { } #[derive(Debug, Deserialize)] +#[allow(unused)] pub struct LoadConfig { pub count: u64, #[serde(default)] @@ -577,6 +583,7 @@ mod defaults { } #[derive(Debug, Deserialize)] +#[allow(unused)] pub struct RunConfig { #[serde(default = "defaults::ratio")] pub ratio: f64, @@ -587,6 +594,7 @@ pub struct RunConfig { } #[derive(Debug, Deserialize)] +#[allow(unused)] pub struct WorkloadConfig { #[serde(default)] pub schema: SchemaConfig, diff --git a/src/context.rs b/src/context.rs index 6a05d56..9fed1e5 100644 --- a/src/context.rs +++ b/src/context.rs @@ -4,6 +4,8 @@ use std::fs::File; use std::hash::{Hash, Hasher}; use std::io; use std::io::{BufRead, BufReader, ErrorKind, Read}; +use std::net::IpAddr; +use std::str::FromStr; use std::sync::Arc; use anyhow::anyhow; @@ -15,7 +17,7 @@ use openssl::error::ErrorStack; use openssl::ssl::{SslContext, SslContextBuilder, SslFiletype, SslMethod}; use rand::distributions::Distribution; use rand::rngs::StdRng; -use rand::{Rng, SeedableRng}; +use rand::{random, Rng, SeedableRng}; use rune::ast; use rune::ast::Kind; use rune::macros::{quote, MacroContext, TokenStream}; @@ -23,18 +25,20 @@ use rune::parse::Parser; use rune::runtime::{Object, Shared, TypeInfo, VmError}; use rune::{Any, Value}; use rust_embed::RustEmbed; +use scylla::_macro_internal::ColumnType; use scylla::frame::response::result::CqlValue; +use scylla::frame::value::CqlTimeuuid; use scylla::load_balancing::DefaultPolicy; use scylla::prepared_statement::PreparedStatement; use scylla::transport::errors::{DbError, NewSessionError, QueryError}; use scylla::transport::session::PoolSize; use scylla::{ExecutionProfile, QueryResult, SessionBuilder}; -use statrs::distribution::Normal; +use statrs::distribution::{Normal, Uniform}; use tokio::time::{Duration, Instant}; use try_lock::TryLock; use uuid::{Variant, Version}; -use crate::config::{ConnectionConf, PRINT_RETRY_ERROR_LIMIT, RetryInterval}; +use crate::config::{ConnectionConf, RetryInterval, PRINT_RETRY_ERROR_LIMIT}; use crate::LatteError; fn ssl_context(conf: &&ConnectionConf) -> Result, CassError> { @@ -59,7 +63,7 @@ fn ssl_context(conf: &&ConnectionConf) -> Result, CassError> pub async fn connect(conf: &ConnectionConf) -> Result { let mut policy_builder = DefaultPolicy::builder().token_aware(true); let dc = &conf.datacenter; - if dc.len() > 0 { + if !dc.is_empty() { policy_builder = policy_builder.prefer_datacenter(dc.to_owned()).permit_dc_failover(true); } let profile = ExecutionProfile::builder() @@ -77,7 +81,11 @@ pub async fn connect(conf: &ConnectionConf) -> Result { .build() .await .map_err(|e| CassError(CassErrorKind::FailedToConnect(conf.addresses.clone(), e)))?; - Ok(Context::new(scylla_session, conf.retry_number, conf.retry_interval)) + Ok(Context::new( + scylla_session, + conf.retry_number, + conf.retry_interval, + )) } pub struct ClusterInfo { @@ -89,17 +97,21 @@ pub struct ClusterInfo { pub fn cql_value_obj_to_string(v: &CqlValue) -> String { let no_transformation_size_limit = 32; match v { - // Replace big string- and bytes-alike object values with it's size labels + // Replace big string- and bytes-alike object values with its size labels CqlValue::Text(param) if param.len() > no_transformation_size_limit => { format!("Text(={})", param.len()) - }, + } CqlValue::Ascii(param) if param.len() > no_transformation_size_limit => { format!("Ascii(={})", param.len()) - }, + } CqlValue::Blob(param) if param.len() > no_transformation_size_limit => { format!("Blob(={})", param.len()) - }, - CqlValue::UserDefinedType { keyspace, type_name, fields } => { + } + CqlValue::UserDefinedType { + keyspace, + type_name, + fields, + } => { let mut result = format!( "UDT {{ keyspace: \"{}\", type_name: \"{}\", fields: [", keyspace, type_name, @@ -114,9 +126,9 @@ pub fn cql_value_obj_to_string(v: &CqlValue) -> String { if result.len() >= 2 { result.truncate(result.len() - 2); } - result.push_str(&format!("] }}")); + result.push_str("] }"); result - }, + } CqlValue::List(elements) => { let mut result = String::from("List(["); for element in elements { @@ -129,7 +141,7 @@ pub fn cql_value_obj_to_string(v: &CqlValue) -> String { } result.push_str("])"); result - }, + } CqlValue::Set(elements) => { let mut result = String::from("Set(["); for element in elements { @@ -142,7 +154,7 @@ pub fn cql_value_obj_to_string(v: &CqlValue) -> String { } result.push_str("])"); result - }, + } CqlValue::Map(pairs) => { let mut result = String::from("Map({"); for (key, value) in pairs { @@ -155,7 +167,7 @@ pub fn cql_value_obj_to_string(v: &CqlValue) -> String { } result.push_str("})"); result - }, + } _ => format!("{v:?}"), } } @@ -171,7 +183,7 @@ impl CassError { fn query_execution_error(cql: &str, params: &[CqlValue], err: QueryError) -> CassError { let query = QueryInfo { cql: cql.to_string(), - params: params.iter().map(|v| cql_value_obj_to_string(v)).collect(), + params: params.iter().map(cql_value_obj_to_string).collect(), }; let kind = match err { QueryError::RequestTimeout(_) @@ -186,9 +198,10 @@ impl CassError { } fn query_retries_exceeded(retry_number: u64) -> CassError { - CassError(CassErrorKind::QueryRetriesExceeded( - format!("Max retry attempts ({}) reached", retry_number) - )) + CassError(CassErrorKind::QueryRetriesExceeded(format!( + "Max retry attempts ({}) reached", + retry_number + ))) } } @@ -215,8 +228,11 @@ pub enum CassErrorKind { FailedToConnect(Vec, NewSessionError), PreparedStatementNotFound(String), QueryRetriesExceeded(String), + QueryParamConversion(TypeInfo, ColumnType), + ValueOutOfRange(String, ColumnType), + InvalidNumberOfQueryParams, + InvalidQueryParamsObject(TypeInfo), WrongDataStructure(String), - UnsupportedType(TypeInfo), Prepare(String, QueryError), Overloaded(QueryInfo, QueryError), QueryExecution(QueryInfo, QueryError), @@ -238,12 +254,24 @@ impl CassError { CassErrorKind::QueryRetriesExceeded(s) => { write!(buf, "QueryRetriesExceeded: {s}") } + CassErrorKind::ValueOutOfRange(v, t) => { + write!(buf, "Value {v} out of range for Cassandra type {t:?}") + } + CassErrorKind::QueryParamConversion(s, t) => { + write!( + buf, + "Cannot convert value of type {s} to Cassandra type {t:?}" + ) + } + CassErrorKind::InvalidNumberOfQueryParams => { + write!(buf, "Incorrect number of query parameters") + } + CassErrorKind::InvalidQueryParamsObject(t) => { + write!(buf, "Value of type {t} cannot by used as query parameters; expected a list or object") + } CassErrorKind::WrongDataStructure(s) => { write!(buf, "Wrong data structure: {s}") } - CassErrorKind::UnsupportedType(s) => { - write!(buf, "Unsupported type: {s}") - } CassErrorKind::Prepare(q, e) => { write!(buf, "Failed to prepare query \"{q}\": {e}") } @@ -353,35 +381,45 @@ impl Default for SessionStats { } } -pub fn get_expoinential_retry_interval(min_interval: u64, - max_interval: u64, - current_attempt_num: u64) -> u64 { +pub fn get_exponential_retry_interval( + min_interval: u64, + max_interval: u64, + current_attempt_num: u64, +) -> u64 { let min_interval_float: f64 = min_interval as f64; - let mut current_interval: f64 = min_interval_float * ( - 2u64.pow((current_attempt_num - 1).try_into().unwrap_or(0)) as f64 - ); + let mut current_interval: f64 = + min_interval_float * (2u64.pow(current_attempt_num.try_into().unwrap_or(0)) as f64); // Add jitter - current_interval += rand::thread_rng().gen::() * min_interval_float; + current_interval += random::() * min_interval_float; current_interval -= min_interval_float / 2.0; - std::cmp::min(current_interval as u64, max_interval as u64) as u64 + std::cmp::min(current_interval as u64, max_interval) } -pub async fn handle_retry_error(ctxt: &Context, current_attempt_num: u64, current_error: CassError) { - let current_retry_interval = get_expoinential_retry_interval( - ctxt.retry_interval.min_ms, ctxt.retry_interval.max_ms, current_attempt_num, +pub async fn handle_retry_error( + ctxt: &Context, + current_attempt_num: u64, + current_error: CassError, +) { + let current_retry_interval = get_exponential_retry_interval( + ctxt.retry_interval.min_ms, + ctxt.retry_interval.max_ms, + current_attempt_num, ); let mut next_attempt_str = String::new(); let is_last_attempt = current_attempt_num == ctxt.retry_number; if !is_last_attempt { - next_attempt_str += &format!("[Retry in {}ms]", current_retry_interval); + next_attempt_str += &format!("[Retry in {} ms]", current_retry_interval); } let err_msg = format!( "{}: [ERROR][Attempt {}/{}]{} {}", - Utc::now().format("%Y-%m-%d %H:%M:%S%.3f").to_string(), - current_attempt_num, ctxt.retry_number, next_attempt_str, current_error, + Utc::now().format("%Y-%m-%d %H:%M:%S%.3f"), + current_attempt_num, + ctxt.retry_number, + next_attempt_str, + current_error, ); if !is_last_attempt { ctxt.stats.try_lock().unwrap().store_retry_error(err_msg); @@ -407,22 +445,26 @@ pub struct Context { } // Needed, because Rune `Value` is !Send, as it may contain some internal pointers. -// Therefore it is not safe to pass a `Value` to another thread by cloning it, because +// Therefore, it is not safe to pass a `Value` to another thread by cloning it, because // both objects could accidentally share some unprotected, `!Sync` data. -// To make it safe, the same `Context` is never used by more than one thread at once and +// To make it safe, the same `Context` is never used by more than one thread at once, and // we make sure in `clone` to make a deep copy of the `data` field by serializing // and deserializing it, so no pointers could get through. unsafe impl Send for Context {} unsafe impl Sync for Context {} impl Context { - pub fn new(session: scylla::Session, retry_number: u64, retry_interval: RetryInterval) -> Context { + pub fn new( + session: scylla::Session, + retry_number: u64, + retry_interval: RetryInterval, + ) -> Context { Context { session: Arc::new(session), statements: HashMap::new(), stats: TryLock::new(SessionStats::new()), - retry_number: retry_number, - retry_interval: retry_interval, + retry_number, + retry_interval, load_cycle_count: 0, data: Value::Object(Shared::new(Object::new())), } @@ -480,7 +522,7 @@ impl Context { /// Executes an ad-hoc CQL statement with no parameters. Does not prepare. pub async fn execute(&self, cql: &str) -> Result<(), CassError> { - for current_attempt_num in 0..self.retry_number+1 { + for current_attempt_num in 0..self.retry_number + 1 { let start_time = self.stats.try_lock().unwrap().start_request(); let rs = self.session.query(cql, ()).await; let duration = Instant::now() - start_time; @@ -489,12 +531,15 @@ impl Context { Err(e) => { let current_error = CassError::query_execution_error(cql, &[], e.clone()); handle_retry_error(self, current_attempt_num, current_error).await; - continue + continue; } } - self.stats.try_lock().unwrap().complete_request(duration, &rs); + self.stats + .try_lock() + .unwrap() + .complete_request(duration, &rs); rs.map_err(|e| CassError::query_execution_error(cql, &[], e.clone()))?; - return Ok(()) + return Ok(()); } Err(CassError::query_retries_exceeded(self.retry_number)) } @@ -505,8 +550,9 @@ impl Context { .statements .get(key) .ok_or_else(|| CassError(CassErrorKind::PreparedStatementNotFound(key.to_string())))?; - let params = bind::to_scylla_query_params(¶ms)?; - for current_attempt_num in 0..self.retry_number+1 { + + let params = bind::to_scylla_query_params(¶ms, statement.get_variable_col_specs())?; + for current_attempt_num in 0..self.retry_number + 1 { let start_time = self.stats.try_lock().unwrap().start_request(); let rs = self.session.execute(statement, params.clone()).await; let duration = Instant::now() - start_time; @@ -514,14 +560,21 @@ impl Context { Ok(_) => {} Err(e) => { let current_error = CassError::query_execution_error( - statement.get_statement(), ¶ms, e.clone() + statement.get_statement(), + ¶ms, + e.clone(), ); handle_retry_error(self, current_attempt_num, current_error).await; - continue + continue; } } - self.stats.try_lock().unwrap().complete_request(duration, &rs); - rs.map_err(|e| CassError::query_execution_error(statement.get_statement(), ¶ms, e))?; + self.stats + .try_lock() + .unwrap() + .complete_request(duration, &rs); + rs.map_err(|e| { + CassError::query_execution_error(statement.get_statement(), ¶ms, e) + })?; return Ok(()); } Err(CassError::query_retries_exceeded(self.retry_number)) @@ -544,214 +597,284 @@ impl Context { /// Functions for binding rune values to CQL parameters mod bind { use crate::CassErrorKind; - use scylla::frame::response::result::CqlValue; + use scylla::_macro_internal::ColumnType; + use scylla::frame::response::result::{ColumnSpec, CqlValue}; use super::*; - fn to_scylla_value(v: &Value) -> Result { - match v { - // TODO: add support for the following native CQL types: - // 'counter', 'date', 'decimal', 'duration', 'float', 'inet', 'time', 'timeuuid' - // and 'variant'. - // Also, for the 'tuple'. - Value::Bool(v) => Ok(CqlValue::Boolean(*v)), - Value::Byte(v) => Ok(CqlValue::TinyInt(*v as i8)), - Value::Integer(v) => Ok(CqlValue::BigInt(*v)), - Value::Float(v) => Ok(CqlValue::Double(*v)), - Value::StaticString(v) => Ok(CqlValue::Text(v.as_str().to_string())), - Value::String(v) => Ok(CqlValue::Text(v.borrow_ref().unwrap().as_str().to_string())), - Value::Bytes(v) => Ok(CqlValue::Blob(v.borrow_ref().unwrap().to_vec())), - Value::Option(v) => match v.borrow_ref().unwrap().as_ref() { - Some(v) => to_scylla_value(v), - None => Ok(CqlValue::Empty), - }, - Value::Vec(v) => { - let v = v.borrow_ref().unwrap(); - let elements = v.as_ref().iter().map(to_scylla_value).try_collect()?; - Ok(CqlValue::List(elements)) + fn to_scylla_value(v: &Value, typ: &ColumnType) -> Result { + // TODO: add support for the following native CQL types: + // 'counter', 'date', 'decimal', 'duration', 'time' and 'variant'. + // Also, for the 'tuple'. + match (v, typ) { + (Value::Bool(v), ColumnType::Boolean) => Ok(CqlValue::Boolean(*v)), + + (Value::Byte(v), ColumnType::TinyInt) => Ok(CqlValue::TinyInt(*v as i8)), + (Value::Byte(v), ColumnType::SmallInt) => Ok(CqlValue::SmallInt(*v as i16)), + (Value::Byte(v), ColumnType::Int) => Ok(CqlValue::Int(*v as i32)), + (Value::Byte(v), ColumnType::BigInt) => Ok(CqlValue::BigInt(*v as i64)), + + (Value::Integer(v), ColumnType::TinyInt) => { + convert_int(*v, ColumnType::TinyInt, CqlValue::TinyInt) + } + (Value::Integer(v), ColumnType::SmallInt) => { + convert_int(*v, ColumnType::SmallInt, CqlValue::SmallInt) + } + (Value::Integer(v), ColumnType::Int) => convert_int(*v, ColumnType::Int, CqlValue::Int), + (Value::Integer(v), ColumnType::BigInt) => Ok(CqlValue::BigInt(*v)), + (Value::Integer(v), ColumnType::Timestamp) => { + Ok(CqlValue::Timestamp(scylla::frame::value::CqlTimestamp(*v))) } - Value::Object(v) => { - let borrowed = v.borrow_ref().unwrap(); - let set_key_name = "_set"; - let list_key_name = "_list"; - let map_key_name = "_map"; - let timestamp_key_name = "_timestamp"; - let udt_keyspace = "_keyspace"; - let udt_key_name = "_type_name"; - - // Check that we don't have a mess of different data types in scope of single object - let mutually_exclusive_keys = vec![ - set_key_name, list_key_name, map_key_name, udt_key_name, timestamp_key_name, - ]; - let mut found_mutually_exclusive_keys = HashSet::new(); - for mutually_exclusive_key in &mutually_exclusive_keys { - if borrowed.contains_key(&mutually_exclusive_key as &str) { - found_mutually_exclusive_keys.insert(mutually_exclusive_key); + + (Value::Float(v), ColumnType::Float) => Ok(CqlValue::Float(*v as f32)), + (Value::Float(v), ColumnType::Double) => Ok(CqlValue::Double(*v)), + + (Value::StaticString(v), ColumnType::Timeuuid) => { + let timeuuid = CqlTimeuuid::from_str(v); + match timeuuid { + Ok(timeuuid) => Ok(CqlValue::Timeuuid(timeuuid)), + Err(e) => { + Err(CassError(CassErrorKind::WrongDataStructure( + format!("Failed to parse '{}' StaticString as Timeuuid: {}", v.as_str(), e), + ))) } } - if found_mutually_exclusive_keys.len() > 1 { - return Err(CassError(CassErrorKind::WrongDataStructure(format!( - "Following mutually exclusive keys were found: {:?}", - found_mutually_exclusive_keys, - )))); - } else if found_mutually_exclusive_keys.len() == 0 { - return Err(CassError(CassErrorKind::WrongDataStructure(format!( - "None of the expected keys were provided: {:?}", - mutually_exclusive_keys, - )))); - } - - // Check if "_timestamp" field exists and is of integer type - if let Some(timestamp_value) = borrowed.get(timestamp_key_name) { - if let Value::Integer(timestamp) = timestamp_value { - return Ok(CqlValue::Timestamp(scylla::frame::value::CqlTimestamp(*timestamp))); - } else { - return Err(CassError(CassErrorKind::WrongDataStructure(format!( - "Unexpected data type provided for the 'timestamp': {:?}", - timestamp_value.type_info().unwrap(), - )))); + } + (Value::String(v), ColumnType::Timeuuid) => { + let timeuuid_str = v.borrow_ref().unwrap(); + let timeuuid = CqlTimeuuid::from_str(timeuuid_str.as_str()); + match timeuuid { + Ok(timeuuid) => Ok(CqlValue::Timeuuid(timeuuid)), + Err(e) => { + Err(CassError(CassErrorKind::WrongDataStructure( + format!("Failed to parse '{}' String as Timeuuid: {}", timeuuid_str.as_str(), e), + ))) } } - - // Check if "_set" field exists and is a vector of values - if let Some(set_value) = borrowed.get(set_key_name) { - if let Value::Vec(elements) = set_value { - let elements = elements.borrow_ref().unwrap().as_ref() - .iter().map(to_scylla_value).try_collect()?; - return Ok(CqlValue::Set(elements)); - } else { - return Err(CassError(CassErrorKind::WrongDataStructure(format!( - "Unexpected data type provided for the 'set': {:?}", - set_value.type_info().unwrap(), - )))); + } + (Value::StaticString(v), ColumnType::Text | ColumnType::Ascii) => { + Ok(CqlValue::Text(v.as_str().to_string())) + } + (Value::String(v), ColumnType::Text | ColumnType::Ascii) => { + Ok(CqlValue::Text(v.borrow_ref().unwrap().as_str().to_string())) + } + (Value::StaticString(v), ColumnType::Inet) => { + let ipaddr = IpAddr::from_str(v); + match ipaddr { + Ok(ipaddr) => Ok(CqlValue::Inet(ipaddr)), + Err(e) => { + Err(CassError(CassErrorKind::WrongDataStructure( + format!("Failed to parse '{}' StaticString as IP address: {}", v.as_str(), e), + ))) } } - - // Check for "_list" field exists and is a vector of values - if let Some(list_value) = borrowed.get(list_key_name) { - if let Value::Vec(elements) = list_value { - let elements = elements.borrow_ref().unwrap().as_ref() - .iter().map(to_scylla_value).try_collect()?; - return Ok(CqlValue::List(elements)); - } else { - return Err(CassError(CassErrorKind::WrongDataStructure(format!( - "Unexpected data type provided for the 'list': {:?}", - list_value.type_info().unwrap(), - )))); + } + (Value::String(v), ColumnType::Inet) => { + let ipaddr_str = v.borrow_ref().unwrap(); + let ipaddr = IpAddr::from_str(ipaddr_str.as_str()); + match ipaddr { + Ok(ipaddr) => Ok(CqlValue::Inet(ipaddr)), + Err(e) => { + Err(CassError(CassErrorKind::WrongDataStructure( + format!("Failed to parse '{}' String as IP address: {}", ipaddr_str.as_str(), e), + ))) } } + } - // Check for "_map" field exists and is a vector of tuples - if let Some(map_value) = borrowed.get(map_key_name) { - if let Value::Vec(vec_value) = map_value { - let vec_unwrapped = vec_value.borrow_ref().unwrap(); - if vec_unwrapped.len() > 0 { - if let Value::Tuple(first_tuple) = &vec_unwrapped[0] { - if first_tuple.borrow_ref().unwrap().len() == 2 { - let map_values: Vec<(CqlValue, CqlValue)> = vec_unwrapped.iter() - .filter_map(|tuple_wrapped| { - if let Value::Tuple(tuple_wrapped) = &tuple_wrapped { - let tuple = tuple_wrapped.borrow_ref().unwrap(); - let key = to_scylla_value(tuple.get(0).unwrap()).unwrap(); - let value = to_scylla_value(tuple.get(1).unwrap()).unwrap(); - Some((key, value)) - } else { None } - }).collect(); - return Ok(CqlValue::Map(map_values)); - } else { - return Err(CassError(CassErrorKind::WrongDataStructure( - "Vector's tuple must have exactly 2 elements".to_string(), - ))) - } - } else { - return Err(CassError(CassErrorKind::WrongDataStructure( - "'_map' is expected to contain vector of tuples only".to_string(), - ))) - } + (Value::Bytes(v), ColumnType::Blob) => { + Ok(CqlValue::Blob(v.borrow_ref().unwrap().to_vec())) + } + (Value::Option(v), typ) => match v.borrow_ref().unwrap().as_ref() { + Some(v) => to_scylla_value(v, typ), + None => Ok(CqlValue::Empty), + } + (Value::Vec(v), ColumnType::List(elt)) => { + let v = v.borrow_ref().unwrap(); + let elements = v + .as_ref() + .iter() + .map(|v| to_scylla_value(v, elt)) + .try_collect()?; + Ok(CqlValue::List(elements)) + } + (Value::Vec(v), ColumnType::Set(elt)) => { + let v = v.borrow_ref().unwrap(); + let elements = v + .as_ref() + .iter() + .map(|v| to_scylla_value(v, elt)) + .try_collect()?; + Ok(CqlValue::Set(elements)) + } + (Value::Vec(v), ColumnType::Map(key_elt, value_elt)) => { + let v = v.borrow_ref().unwrap(); + if v.len() > 0 { + if let Value::Tuple(first_tuple) = &v[0] { + if first_tuple.borrow_ref().unwrap().len() == 2 { + let map_values: Vec<(CqlValue, CqlValue)> = v + .iter() + .filter_map(|tuple_wrapped| { + if let Value::Tuple(tuple_wrapped) = &tuple_wrapped { + let tuple = tuple_wrapped.borrow_ref().unwrap(); + let key = to_scylla_value(tuple.get(0).unwrap(), key_elt).unwrap(); + let value = to_scylla_value(tuple.get(1).unwrap(), value_elt).unwrap(); + Some((key, value)) + } else { + None + } + }) + .collect(); + Ok(CqlValue::Map(map_values)) } else { - return Ok(CqlValue::Map(vec![])); + Err(CassError(CassErrorKind::WrongDataStructure( + "Vector's tuple must have exactly 2 elements".to_string(), + ))) } } else { - return Err(CassError(CassErrorKind::WrongDataStructure( - "'_map' field is expected to contain only Vector type of data".to_string(), + Err(CassError(CassErrorKind::WrongDataStructure( + "ColumnType::Map expects only vector of tuples".to_string(), ))) } + } else { + Ok(CqlValue::Map(vec![])) } + } + ( + Value::Object(v), + ColumnType::UserDefinedType { + keyspace, + type_name, + field_types, + }, + ) => { + let obj = v.borrow_ref().unwrap(); + let fields = read_fields(|s| obj.get(s), field_types)?; + Ok(CqlValue::UserDefinedType { + keyspace: keyspace.to_string(), + type_name: type_name.to_string(), + fields, + }) + } + ( + Value::Struct(v), + ColumnType::UserDefinedType { + keyspace, + type_name, + field_types, + }, + ) => { + let obj = v.borrow_ref().unwrap(); + let fields = read_fields(|s| obj.get(s), field_types)?; + Ok(CqlValue::UserDefinedType { + keyspace: keyspace.to_string(), + type_name: type_name.to_string(), + fields, + }) + } - // Handle last supported case - User Defined Type (UDT) - let keyspace = match borrowed.get_value::(udt_keyspace) { - Ok(Some(value)) => value, - _ => "unknown".to_string(), - }; - let type_name = match borrowed.get_value::(udt_key_name) { - Ok(Some(value)) => value, - _ => "unknown".to_string(), - }; - let keys = borrowed.keys(); - let values: Result>, _> = borrowed.values() - .map(|value| to_scylla_value(&value.clone()) - .map(Some)).collect(); - let fields: Vec<(String, Option)> = keys.into_iter() - .zip(values?.into_iter()) - .filter(|&(key, _)| key != udt_keyspace && key != udt_key_name) - .map(|(key, value)| (key.to_string(), value)) - .collect(); - let udt = CqlValue::UserDefinedType{ - keyspace: keyspace, - type_name: type_name, - fields: fields, - }; - Ok(udt) - }, - Value::Any(obj) => { + (Value::Any(obj), ColumnType::Uuid) => { let obj = obj.borrow_ref().unwrap(); let h = obj.type_hash(); if h == Uuid::type_hash() { let uuid: &Uuid = obj.downcast_borrow_ref().unwrap(); Ok(CqlValue::Uuid(uuid.0)) - } else if h == Int32::type_hash() { - let int32: &Int32 = obj.downcast_borrow_ref().unwrap(); - Ok(CqlValue::Int(int32.0)) - } else if h == Int16::type_hash() { - let int16: &Int16 = obj.downcast_borrow_ref().unwrap(); - Ok(CqlValue::SmallInt(int16.0)) - } else if h == Int8::type_hash() { - let int8: &Int8 = obj.downcast_borrow_ref().unwrap(); - Ok(CqlValue::TinyInt(int8.0)) } else { - Err(CassError(CassErrorKind::UnsupportedType( + Err(CassError(CassErrorKind::QueryParamConversion( v.type_info().unwrap(), + ColumnType::Uuid, ))) } } - other => Err(CassError(CassErrorKind::UnsupportedType( - other.type_info().unwrap(), + (value, typ) => Err(CassError(CassErrorKind::QueryParamConversion( + value.type_info().unwrap(), + typ.clone(), ))), } } + fn convert_int, R>( + value: i64, + typ: ColumnType, + f: impl Fn(T) -> R, + ) -> Result { + let converted = value.try_into().map_err(|_| { + CassError(CassErrorKind::ValueOutOfRange( + value.to_string(), + typ.clone(), + )) + })?; + Ok(f(converted)) + } + /// Binds parameters passed as a single rune value to the arguments of the statement. /// The `params` value can be a tuple, a vector, a struct or an object. - pub fn to_scylla_query_params(params: &Value) -> Result, CassError> { - let mut values = Vec::new(); - match params { + pub fn to_scylla_query_params( + params: &Value, + types: &[ColumnSpec], + ) -> Result, CassError> { + Ok(match params { Value::Tuple(tuple) => { + let mut values = Vec::new(); let tuple = tuple.borrow_ref().unwrap(); - for v in tuple.iter() { - values.push(to_scylla_value(v)?); + if tuple.len() != types.len() { + return Err(CassError(CassErrorKind::InvalidNumberOfQueryParams)); + } + for (v, t) in tuple.iter().zip(types) { + values.push(to_scylla_value(v, &t.typ)?); } + values } Value::Vec(vec) => { + let mut values = Vec::new(); + let vec = vec.borrow_ref().unwrap(); - for v in vec.iter() { - values.push(to_scylla_value(v)?); + for (v, t) in vec.iter().zip(types) { + values.push(to_scylla_value(v, &t.typ)?); } + values + } + Value::Object(obj) => { + let obj = obj.borrow_ref().unwrap(); + read_params(|f| obj.get(f), types)? + } + Value::Struct(obj) => { + let obj = obj.borrow_ref().unwrap(); + read_params(|f| obj.get(f), types)? } other => { - return Err(CassError(CassErrorKind::UnsupportedType( + return Err(CassError(CassErrorKind::InvalidQueryParamsObject( other.type_info().unwrap(), ))); } + }) + } + + fn read_params<'a, 'b>( + get_value: impl Fn(&String) -> Option<&'a Value>, + params: &[ColumnSpec], + ) -> Result, CassError> { + let mut values = Vec::with_capacity(params.len()); + for column in params { + let value = match get_value(&column.name) { + Some(value) => to_scylla_value(value, &column.typ)?, + None => CqlValue::Empty, + }; + values.push(value) + } + Ok(values) + } + + fn read_fields<'a, 'b>( + get_value: impl Fn(&String) -> Option<&'a Value>, + fields: &[(String, ColumnType)], + ) -> Result)>, CassError> { + let mut values = Vec::with_capacity(fields.len()); + for (field_name, field_type) in fields { + if let Some(value) = get_value(field_name) { + let value = Some(to_scylla_value(value, field_type)?); + values.push((field_name.to_string(), value)) + }; } Ok(values) } @@ -791,6 +914,9 @@ pub struct Int16(pub i16); #[derive(Clone, Debug, Any)] pub struct Int32(pub i32); +#[derive(Clone, Debug, Any)] +pub struct Float32(pub f32); + /// Returns the literal value stored in the `params` map under the key given as the first /// macro arg, and if not found, returns the expression from the second arg. pub fn param( @@ -844,6 +970,22 @@ pub fn float_to_i32(value: f64) -> Option { int_to_i32(value as i64) } +pub fn int_to_f32(value: i64) -> Option { + Some(Float32(value as f32)) +} + +pub fn float_to_f32(value: f64) -> Option { + Some(Float32(value as f32)) +} + +pub fn int_to_string(value: i64) -> Option { + Some(value.to_string()) +} + +pub fn float_to_string(value: f64) -> Option { + Some(value.to_string()) +} + /// Computes a hash of an integer value `i`. /// Returns a value in range `0..i64::MAX`. pub fn hash(i: i64) -> i64 { @@ -873,6 +1015,12 @@ pub fn normal(i: i64, mean: f64, std_dev: f64) -> Result { Ok(distribution.sample(&mut rng)) } +pub fn uniform(i: i64, min: f64, max: f64) -> Result { + let mut rng = StdRng::seed_from_u64(i as u64); + let distribution = Uniform::new(min, max).map_err(|e| VmError::panic(format!("{e}")))?; + Ok(distribution.sample(&mut rng)) +} + /// Restricts a value to a certain interval unless it is NaN. pub fn clamp_float(value: f64, min: f64, max: f64) -> f64 { value.clamp(min, max) @@ -895,10 +1043,12 @@ pub fn blob(seed: i64, len: usize) -> rune::runtime::Bytes { /// Parameter `seed` is used to seed the RNG. pub fn text(seed: i64, len: usize) -> rune::runtime::StaticString { let mut rng = StdRng::seed_from_u64(seed as u64); - let s: String = (0..len).map(|_| { - let code_point = rng.gen_range(0x0061u32..=0x007Au32); // Unicode range for 'a-z' - std::char::from_u32(code_point).unwrap() - }).collect(); + let s: String = (0..len) + .map(|_| { + let code_point = rng.gen_range(0x0061u32..=0x007Au32); // Unicode range for 'a-z' + std::char::from_u32(code_point).unwrap() + }) + .collect(); rune::runtime::StaticString::new(s) } diff --git a/src/error.rs b/src/error.rs index 88e8e8e..5e7bbb4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,4 +1,5 @@ use crate::context::CassError; +use crate::stats::BenchmarkStats; use err_derive::*; use hdrhistogram::serialization::interval_log::IntervalLogWriterError; use hdrhistogram::serialization::V2DeflateSerializeError; @@ -37,7 +38,7 @@ pub enum LatteError { HdrLogWrite(#[source] IntervalLogWriterError), #[error(display = "Interrupted")] - Interrupted, + Interrupted(Box), } pub type Result = std::result::Result; diff --git a/src/exec.rs b/src/exec.rs index 4c626c6..41d8b43 100644 --- a/src/exec.rs +++ b/src/exec.rs @@ -11,10 +11,10 @@ use std::sync::Arc; use std::time::Instant; use tokio_stream::wrappers::IntervalStream; -use crate::error::Result; +use crate::error::{LatteError, Result}; use crate::{ - BenchmarkStats, BoundedCycleCounter, InterruptHandler, Interval, Progress, Recorder, Sampler, - Workload, WorkloadStats, + BenchmarkStats, BoundedCycleCounter, Interval, Progress, Recorder, Sampler, Workload, + WorkloadStats, }; /// Returns a stream emitting `rate` events per second. @@ -43,7 +43,6 @@ async fn run_stream( cycle_counter: BoundedCycleCounter, concurrency: NonZeroUsize, sampling: Interval, - interrupt: Arc, progress: Arc>, mut out: Sender>, ) { @@ -68,9 +67,6 @@ async fn run_stream( return; } } - if interrupt.is_interrupted() { - break; - } } // Send the statistics of remaining requests sampler.finish().await; @@ -88,7 +84,6 @@ fn spawn_stream( sampling: Interval, workload: Workload, iter_counter: BoundedCycleCounter, - interrupt: Arc, progress: Arc>, ) -> Receiver> { let (tx, rx) = channel(1); @@ -103,7 +98,6 @@ fn spawn_stream( iter_counter, concurrency, sampling, - interrupt, progress, tx, ) @@ -117,7 +111,6 @@ fn spawn_stream( iter_counter, concurrency, sampling, - interrupt, progress, tx, ) @@ -170,7 +163,6 @@ pub async fn par_execute( sampling: Interval, store_samples: bool, workload: Workload, - signals: Arc, show_progress: bool, ) -> Result { let thread_count = exec_options.threads.get(); @@ -197,29 +189,31 @@ pub async fn par_execute( sampling, workload.clone()?, deadline.share(), - signals.clone(), progress.clone(), ); streams.push(s); } loop { - let partial_stats: Vec<_> = receive_one_of_each(&mut streams) - .await - .into_iter() - .try_collect()?; - - if partial_stats.is_empty() { - break; - } + tokio::select! { + partial_stats = receive_one_of_each(&mut streams) => { + let partial_stats: Vec<_> = partial_stats.into_iter().try_collect()?; + if partial_stats.is_empty() { + break Ok(stats.finish()); + } + + let aggregate = stats.record(&partial_stats); + if sampling.is_bounded() { + progress.set_visible(false); + println!("{aggregate}"); + progress.set_visible(show_progress); + } + } - let aggregate = stats.record(&partial_stats); - if sampling.is_bounded() { - progress.set_visible(false); - println!("{aggregate}"); - progress.set_visible(show_progress); + _ = tokio::signal::ctrl_c() => { + progress.set_visible(false); + break Err(LatteError::Interrupted(Box::new(stats.finish()))); + } } } - - Ok(stats.finish()) } diff --git a/src/histogram.rs b/src/histogram.rs index ba90630..6a320a9 100644 --- a/src/histogram.rs +++ b/src/histogram.rs @@ -9,6 +9,7 @@ use serde::{Deserialize, Deserializer, Serialize}; /// A wrapper for HDR histogram that allows us to serialize/deserialize it to/from /// a base64 encoded string we can store in JSON report. +#[derive(Debug)] pub struct SerializableHistogram(pub Histogram); impl Serialize for SerializableHistogram { diff --git a/src/interrupt.rs b/src/interrupt.rs deleted file mode 100644 index 9ce27dd..0000000 --- a/src/interrupt.rs +++ /dev/null @@ -1,21 +0,0 @@ -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; - -/// Notifies about received Ctrl-C signal -pub struct InterruptHandler { - interrupted: Arc, -} - -impl InterruptHandler { - pub fn install() -> InterruptHandler { - let cell = Arc::new(AtomicBool::new(false)); - let cell_ref = cell.clone(); - let _ = ctrlc::set_handler(move || cell_ref.store(true, Ordering::Relaxed)); - InterruptHandler { interrupted: cell } - } - - /// Returns true if Ctrl-C was pressed - pub fn is_interrupted(&self) -> bool { - self.interrupted.load(Ordering::Relaxed) - } -} diff --git a/src/main.rs b/src/main.rs index a45ab15..3d7eb00 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,7 +2,6 @@ use std::fs::File; use std::io::{stdout, Write}; use std::path::{Path, PathBuf}; use std::process::exit; -use std::sync::Arc; use std::time::Duration; use clap::Parser; @@ -24,7 +23,6 @@ use crate::context::{CassError, CassErrorKind, Context, SessionStats}; use crate::cycle::BoundedCycleCounter; use crate::error::{LatteError, Result}; use crate::exec::{par_execute, ExecutionOptions}; -use crate::interrupt::InterruptHandler; use crate::plot::plot_graph; use crate::progress::Progress; use crate::report::{Report, RunConfigCmp}; @@ -38,7 +36,6 @@ mod cycle; mod error; mod exec; mod histogram; -mod interrupt; mod plot; mod progress; mod report; @@ -170,7 +167,6 @@ async fn load(conf: LoadCommand) -> Result<()> { } } - let interrupt = Arc::new(InterruptHandler::install()); eprintln!("info: Loading data..."); let loader = Workload::new(session.clone()?, program.clone(), FnRef::new(LOAD_FN)); let load_options = ExecutionOptions { @@ -185,7 +181,6 @@ async fn load(conf: LoadCommand) -> Result<()> { config::Interval::Unbounded, false, loader, - interrupt.clone(), !conf.quiet, ) .await?; @@ -229,7 +224,6 @@ async fn run(conf: RunCommand) -> Result<()> { } let runner = Workload::new(session.clone()?, program.clone(), function); - let interrupt = Arc::new(InterruptHandler::install()); if conf.warmup_duration.is_not_zero() { eprintln!("info: Warming up..."); let warmup_options = ExecutionOptions { @@ -244,16 +238,11 @@ async fn run(conf: RunCommand) -> Result<()> { Interval::Unbounded, conf.generate_report, runner.clone()?, - interrupt.clone(), !conf.quiet, ) .await?; } - if interrupt.is_interrupted() { - return Err(LatteError::Interrupted); - } - eprintln!("info: Running benchmark..."); println!( @@ -272,16 +261,22 @@ async fn run(conf: RunCommand) -> Result<()> { }; report::print_log_header(); - let stats = par_execute( + let stats = match par_execute( "Running...", &exec_options, conf.sampling_interval, conf.generate_report, runner, - interrupt.clone(), !conf.quiet, ) - .await?; + .await + { + Ok(stats) => stats, + Err(LatteError::Interrupted(stats)) => *stats, + Err(e) => { + return Err(e); + } + }; let stats_cmp = BenchmarkCmp { v1: &stats, diff --git a/src/plot.rs b/src/plot.rs index 85c2d00..2b4a7b8 100644 --- a/src/plot.rs +++ b/src/plot.rs @@ -144,7 +144,7 @@ pub async fn plot_graph(conf: PlotCommand) -> Result<()> { let output_path = conf .output - .unwrap_or(reports[0].conf.default_output_file_name("png")); + .unwrap_or(reports[0].conf.default_output_file_name("svg")); let root = SVGBackend::new(&output_path, (2000, 1000)).into_drawing_area(); root.fill(&WHITE).unwrap(); diff --git a/src/report.rs b/src/report.rs index 222e702..e1adadb 100644 --- a/src/report.rs +++ b/src/report.rs @@ -5,7 +5,7 @@ use std::num::NonZeroUsize; use std::path::Path; use std::{fs, io}; -use chrono::{Local, NaiveDateTime, TimeZone}; +use chrono::{Local, TimeZone}; use console::{pad_str, style, Alignment}; use err_derive::*; use itertools::Itertools; @@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize}; use statrs::statistics::Statistics; use strum::IntoEnumIterator; -use crate::config::{PRINT_RETRY_ERROR_LIMIT, RunCommand}; +use crate::config::{RunCommand, PRINT_RETRY_ERROR_LIMIT}; use crate::stats::{ BenchmarkCmp, BenchmarkStats, Bucket, Mean, Percentile, Sample, Significance, TimeDistribution, }; @@ -71,7 +71,7 @@ impl Report { pub struct Quantity { pub value: Option, pub error: Option, - pub precision: usize, + pub precision: Option, } impl Quantity { @@ -79,12 +79,12 @@ impl Quantity { Quantity { value, error: None, - precision: 0, + precision: None, } } pub fn with_precision(mut self, precision: usize) -> Self { - self.precision = precision; + self.precision = Some(precision); self } @@ -96,9 +96,10 @@ impl Quantity { impl Quantity { fn format_error(&self) -> String { + let prec = self.precision.unwrap_or_default(); match &self.error { None => "".to_owned(), - Some(e) => format!("± {:<6.prec$}", e, prec = self.precision), + Some(e) => format!("± {:<6.prec$}", e, prec = prec), } } } @@ -142,13 +143,19 @@ impl From<&Option> for Quantity { impl Display for Quantity { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match &self.value { - None => write!(f, "{}", " ".repeat(18)), - Some(v) => write!( + match (&self.value, self.precision) { + (None, _) => write!(f, "{}", " ".repeat(18)), + (Some(v), None) => write!( + f, + "{value:9} {error:8}", + value = style(v).bright().for_stdout(), + error = style(self.format_error()).dim().for_stdout(), + ), + (Some(v), Some(prec)) => write!( f, "{value:9.prec$} {error:8}", value = style(v).bright().for_stdout(), - prec = self.precision, + prec = prec, error = style(self.format_error()).dim().for_stdout(), ), } @@ -431,8 +438,10 @@ impl RunConfigCmp<'_> { fn format_time(&self, conf: &RunCommand, format: &str) -> String { conf.timestamp .and_then(|ts| { - NaiveDateTime::from_timestamp_opt(ts, 0) - .map(|utc| Local.from_utc_datetime(&utc).format(format).to_string()) + Local + .timestamp_opt(ts, 0) + .latest() + .map(|l| l.format(format).to_string()) }) .unwrap_or_default() } @@ -534,7 +543,7 @@ impl<'a> Display for RunConfigCmp<'a> { self.line("Request timeout", "", |conf| { Quantity::from(conf.connection.request_timeout) }), - self.line("Retries", "", |_| {Quantity::from("")}), + self.line("Retries", "", |_| Quantity::from("")), self.line("┌──────┴number", "", |conf| { Quantity::from(conf.connection.retry_number) }), @@ -568,7 +577,9 @@ impl Display for Sample { if num_of_printed_errors < PRINT_RETRY_ERROR_LIMIT { error_msg_bunch += &format!("{}\n", retry_error); num_of_printed_errors += 1; - } else { break } + } else { + break; + } } let num_of_dropped_errors = self.retry_error_count - num_of_printed_errors; if num_of_dropped_errors > 0 { @@ -577,7 +588,7 @@ impl Display for Sample { num_of_dropped_errors, ); } - eprintln!("{}", error_msg_bunch); + writeln!(f, "{}", error_msg_bunch)?; } write!( f, @@ -653,24 +664,36 @@ impl<'a> Display for BenchmarkCmp<'a> { let summary_part2: Vec> = vec![ self.line("Mean sample size", "op", |s| { Quantity::from(s.log.iter().map(|s| s.cycle_count as f64).mean()) + .with_precision(0) }), self.line("└─", "req", |s| { Quantity::from(s.log.iter().map(|s| s.request_count as f64).mean()) + .with_precision(0) + }), + self.line("Concurrency", "req", |s| { + Quantity::from(s.concurrency).with_precision(0) }), - self.line("Concurrency", "req", |s| Quantity::from(s.concurrency)), - self.line("└─", "%", |s| Quantity::from(s.concurrency_ratio)), - self.line("Throughput", "op/s", |s| Quantity::from(s.cycle_throughput)) - .with_significance(self.cmp_cycle_throughput()) - .with_orientation(1) - .into_box(), - self.line("├─", "req/s", |s| Quantity::from(s.req_throughput)) - .with_significance(self.cmp_req_throughput()) - .with_orientation(1) - .into_box(), - self.line("└─", "row/s", |s| Quantity::from(s.row_throughput)) - .with_significance(self.cmp_row_throughput()) - .with_orientation(1) - .into_box(), + self.line("└─", "%", |s| { + Quantity::from(s.concurrency_ratio).with_precision(0) + }), + self.line("Throughput", "op/s", |s| { + Quantity::from(s.cycle_throughput).with_precision(0) + }) + .with_significance(self.cmp_cycle_throughput()) + .with_orientation(1) + .into_box(), + self.line("├─", "req/s", |s| { + Quantity::from(s.req_throughput).with_precision(0) + }) + .with_significance(self.cmp_req_throughput()) + .with_orientation(1) + .into_box(), + self.line("└─", "row/s", |s| { + Quantity::from(s.row_throughput).with_precision(0) + }) + .with_significance(self.cmp_row_throughput()) + .with_orientation(1) + .into_box(), self.line("Mean cycle time", "ms", |s| { Quantity::from(&s.cycle_time_ms).with_precision(3) }) diff --git a/src/stats.rs b/src/stats.rs index 685b4a5..b9977c6 100644 --- a/src/stats.rs +++ b/src/stats.rs @@ -259,7 +259,7 @@ impl Percentile { } /// Records basic statistics for a sample (a group) of requests -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug)] pub struct Sample { pub time_s: f32, pub duration_s: f32, @@ -360,7 +360,7 @@ impl Log { Log { samples: Vec::new(), samples_counter: 0, - store_samples: store_samples, + store_samples, } } @@ -451,7 +451,7 @@ impl Log { } } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug)] pub struct Bucket { pub percentile: f64, pub duration_ms: f64, @@ -459,7 +459,7 @@ pub struct Bucket { pub cumulative_count: u64, } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug)] pub struct TimeDistribution { pub mean: Mean, pub percentiles: Vec, @@ -467,7 +467,7 @@ pub struct TimeDistribution { } /// Stores the final statistics of the test run. -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug)] pub struct BenchmarkStats { pub start_time: DateTime, pub end_time: DateTime, @@ -570,7 +570,6 @@ pub struct Recorder { pub row_count: u64, pub cycle_times_ns: Histogram, pub resp_times_ns: Histogram, - pub queue_len_sum: u64, log: Log, rate_limit: Option, concurrency_limit: NonZeroUsize, @@ -600,7 +599,6 @@ impl Recorder { error_count: 0, cycle_times_ns: Histogram::new(3).unwrap(), resp_times_ns: Histogram::new(3).unwrap(), - queue_len_sum: 0, } } diff --git a/src/workload.rs b/src/workload.rs index ebefc0e..a9cfc6f 100644 --- a/src/workload.rs +++ b/src/workload.rs @@ -129,7 +129,9 @@ impl Program { let mut latte_module = Module::with_crate("latte"); latte_module.function(&["blob"], context::blob).unwrap(); latte_module.function(&["text"], context::text).unwrap(); - latte_module.function(&["now_timestamp"], context::now_timestamp).unwrap(); + latte_module + .function(&["now_timestamp"], context::now_timestamp) + .unwrap(); latte_module.function(&["hash"], context::hash).unwrap(); latte_module.function(&["hash2"], context::hash2).unwrap(); latte_module @@ -142,10 +144,20 @@ impl Program { .function(&["uuid"], context::Uuid::new) .unwrap(); latte_module.function(&["normal"], context::normal).unwrap(); + latte_module + .function(&["uniform"], context::uniform) + .unwrap(); latte_module .macro_(&["param"], move |ctx, ts| context::param(ctx, ¶ms, ts)) .unwrap(); + latte_module + .inst_fn("to_string", context::int_to_string) + .unwrap(); + latte_module + .inst_fn("to_string", context::float_to_string) + .unwrap(); + latte_module.inst_fn("to_i32", context::int_to_i32).unwrap(); latte_module .inst_fn("to_i32", context::float_to_i32) @@ -156,6 +168,11 @@ impl Program { .unwrap(); latte_module.inst_fn("to_i8", context::int_to_i8).unwrap(); latte_module.inst_fn("to_i8", context::float_to_i8).unwrap(); + latte_module.inst_fn("to_f32", context::int_to_f32).unwrap(); + latte_module + .inst_fn("to_f32", context::float_to_f32) + .unwrap(); + latte_module.inst_fn("clamp", context::clamp_float).unwrap(); latte_module.inst_fn("clamp", context::clamp_int).unwrap();