From 5443682fe49faee8202de09f2ef3f363bab4edc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Sat, 21 Sep 2024 14:31:36 +0200 Subject: [PATCH] Expose a better arg interface to run the benchmarks --- Cargo.lock | 34 +++++++++++++--- benchmarks/Cargo.toml | 2 + benchmarks/src/main.rs | 92 ++++++++++++++++++++++++++---------------- 3 files changed, 87 insertions(+), 41 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index dbcd799..99d4179 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -375,6 +375,8 @@ dependencies = [ "arroy", "byte-unit", "bytemuck", + "clap", + "enum-iterator", "futures-util", "heed", "memmap2", @@ -579,9 +581,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.17" +version = "4.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e5a21b8495e732f1b3c364c9949b201ca7bae518c502c80256c96ad79eaf6ac" +checksum = "b0956a43b323ac1afaffc053ed5c4b7c1f1800bacd1683c353aabbb752515dd3" dependencies = [ "clap_builder", "clap_derive", @@ -589,9 +591,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.17" +version = "4.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cf2dd12af7a047ad9d6da2b6b249759a22a7abc0f474c1dae1777afa4b21a73" +checksum = "4d72166dd41634086d5803a47eb71ae740e61d84709c36f3c34110173db3961b" dependencies = [ "anstream", "anstyle", @@ -601,9 +603,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.13" +version = "4.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "501d359d5f3dcaf6ecdeee48833ae73ec6e42723a1e52419c79abf9507eec0a0" +checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" dependencies = [ "heck", "proc-macro2", @@ -798,6 +800,26 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "enum-iterator" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c280b9e6b3ae19e152d8e31cf47f18389781e119d4013a2a2bb0180e5facc635" +dependencies = [ + "enum-iterator-derive", +] + +[[package]] +name = "enum-iterator-derive" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1ab991c1362ac86c61ab6f556cff143daa22e5a15e4e189df818b2fd19fe65b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.70", +] + [[package]] name = "equivalent" version = "1.0.1" diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index 57e4277..5b980e2 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -8,6 +8,8 @@ anyhow = "1.0.86" arroy = { git = "https://github.com/meilisearch/arroy", rev = "2386594" } byte-unit = "5.1.4" bytemuck = "1.16.1" +clap = { version = "4.5.18", features = ["derive"] } +enum-iterator = "2.1.0" futures-util = "0.3.30" heed = "0.20.3" memmap2 = "0.9.4" diff --git a/benchmarks/src/main.rs b/benchmarks/src/main.rs index 95c2b71..014624d 100644 --- a/benchmarks/src/main.rs +++ b/benchmarks/src/main.rs @@ -1,49 +1,71 @@ use benchmarks::{bench_over_all_distances, MatLEView}; +use clap::{Parser, ValueEnum}; +use enum_iterator::Sequence; -fn hn_top_post() -> MatLEView { - MatLEView::new("Hackernews top posts", "assets/hn-top-posts.mat", 1024) +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Sequence)] +enum Dataset { + /// Hackernews posts (512) + HnPosts, + /// Wikipedia (768) + Wikipedia, + /// Hackernews top posts (1024) + HnTopPost, + /// db pedia OpenAI text-embedding ada 002 (1536) + DbPediaAda002, + /// db pedia OpenAI text-embedding 3 large (3072) + DbPedia3Large, } -fn hn_posts() -> MatLEView { - MatLEView::new("Hackernews posts", "assets/hn-posts.mat", 512) +impl From for MatLEView { + fn from(dataset: Dataset) -> Self { + match dataset { + Dataset::HnPosts => MatLEView::new("Hackernews posts", "assets/hn-posts.mat", 512), + Dataset::Wikipedia => MatLEView::new( + "wikipedia 22 12 simple embeddings", + "assets/wikipedia-22-12-simple-embeddings.mat", + 768, + ), + Dataset::HnTopPost => { + MatLEView::new("Hackernews top posts", "assets/hn-top-posts.mat", 1024) + } + Dataset::DbPediaAda002 => MatLEView::new( + "db pedia OpenAI text-embedding ada 002", + "assets/db-pedia-OpenAI-text-embedding-ada-002.mat", + 1536, + ), + Dataset::DbPedia3Large => MatLEView::new( + "db pedia OpenAI text-embedding 3 large", + "assets/db-pedia-OpenAI-text-embedding-3-large.mat", + 3072, + ), + } + } } -fn db_pedia_3_large() -> MatLEView { - MatLEView::new( - "db pedia OpenAI text-embedding 3 large", - "assets/db-pedia-OpenAI-text-embedding-3-large.mat", - 3072, - ) -} +#[derive(Parser, Debug)] +#[command(version, about, long_about = None)] +struct Args { + /// The datasets to run and all of them are ran if empty. + #[arg(value_enum)] + datasets: Vec, -fn db_pedia_ada_002_large() -> MatLEView { - MatLEView::new( - "db pedia OpenAI text-embedding ada 002", - "assets/db-pedia-OpenAI-text-embedding-ada-002.mat", - 1536, - ) -} - -fn wikipedia_768() -> MatLEView { - MatLEView::new( - "wikipedia 22 12 simple embeddings", - "assets/wikipedia-22-12-simple-embeddings.mat", - 768, - ) + /// Number of vectors to evaluate from the datasets. + #[arg(long, default_value_t = 100_000)] + count: usize, } fn main() { - let take = 100_000; - for dataset in [ - &hn_posts(), - &hn_top_post(), - &db_pedia_3_large(), - &db_pedia_ada_002_large(), - &wikipedia_768(), - ] { - let vectors: Vec<(u32, &[f32])> = - dataset.iter().enumerate().map(|(i, v)| (i as u32, v)).take(take).collect(); + let Args { datasets, count } = Args::parse(); + + let datasets: Vec> = if datasets.is_empty() { + enum_iterator::all::().map(Into::into).collect() + } else { + datasets.into_iter().map(Into::into).collect() + }; + for dataset in datasets { + let vectors: Vec<_> = + dataset.iter().enumerate().map(|(i, v)| (i as u32, v)).take(count).collect(); dataset.header(); bench_over_all_distances(dataset.dimensions(), vectors.as_slice()); println!();