From bf8b849e0f3d919e7c6c72a452f9e8aa092e11b6 Mon Sep 17 00:00:00 2001 From: "Alex Chi Z." <4198311+skyzh@users.noreply.github.com> Date: Sun, 19 Jan 2025 17:45:59 -0500 Subject: [PATCH] feat(planner): match vector indexes Signed-off-by: Alex Chi --- src/binder/create_index.rs | 111 ++++++++++++++++++++++++++ src/binder/mod.rs | 2 +- src/catalog/index.rs | 15 +++- src/catalog/root.rs | 4 +- src/catalog/schema.rs | 10 ++- src/executor/create_index.rs | 1 + src/planner/cost.rs | 2 +- src/planner/explain.rs | 11 +++ src/planner/mod.rs | 1 + src/planner/optimizer.rs | 3 +- src/planner/rules/plan.rs | 67 ++++++++++++++++ src/storage/memory/mod.rs | 10 ++- src/storage/mod.rs | 2 + src/storage/secondary/mod.rs | 10 ++- tests/planner_test/vector.planner.sql | 15 ++++ tests/planner_test/vector.yml | 18 +++++ tests/sql/vector_index.slt | 18 +++++ 17 files changed, 292 insertions(+), 8 deletions(-) create mode 100644 tests/planner_test/vector.planner.sql create mode 100644 tests/planner_test/vector.yml create mode 100644 tests/sql/vector_index.slt diff --git a/src/binder/create_index.rs b/src/binder/create_index.rs index b3e93492..4b0cae93 100644 --- a/src/binder/create_index.rs +++ b/src/binder/create_index.rs @@ -10,12 +10,47 @@ use serde::{Deserialize, Serialize}; use super::*; use crate::catalog::{ColumnId, SchemaId, TableId}; +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)] +pub enum VectorDistance { + Cosine, + L2, + NegativeDotProduct, +} + +impl FromStr for VectorDistance { + type Err = String; + + fn from_str(s: &str) -> std::result::Result { + match s { + "cosine" => Ok(VectorDistance::Cosine), + "<=>" => Ok(VectorDistance::Cosine), + "l2" => Ok(VectorDistance::L2), + "<->" => Ok(VectorDistance::L2), + "dotproduct" => Ok(VectorDistance::NegativeDotProduct), + "<#>" => Ok(VectorDistance::NegativeDotProduct), + _ => Err(format!("invalid vector distance: {}", s)), + } + } +} + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, 
Clone, Serialize, Deserialize)] +pub enum IndexType { + Hnsw, + IvfFlat { + distance: VectorDistance, + nlists: usize, + nprobe: usize, + }, + Btree, +} + #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)] pub struct CreateIndex { pub schema_id: SchemaId, pub index_name: String, pub table_id: TableId, pub columns: Vec, + pub index_type: IndexType, } impl fmt::Display for CreateIndex { @@ -48,6 +83,79 @@ impl FromStr for Box { } impl Binder { + fn parse_index_type(&self, using: Option, with: Vec) -> Result { + let Some(using) = using else { + return Err(ErrorKind::InvalidIndex("using clause is required".to_string()).into()); + }; + match using.to_string().to_lowercase().as_str() { + "hnsw" => Ok(IndexType::Hnsw), + "ivfflat" => { + let mut distfn = None; + let mut nlists = None; + let mut nprobe = None; + for expr in with { + let Expr::BinaryOp { left, op, right } = expr else { + return Err( + ErrorKind::InvalidIndex("invalid with clause".to_string()).into() + ); + }; + if op != BinaryOperator::Eq { + return Err( + ErrorKind::InvalidIndex("invalid with clause".to_string()).into() + ); + } + let Expr::Identifier(Ident { value: key, .. 
}) = *left else { + return Err( + ErrorKind::InvalidIndex("invalid with clause".to_string()).into() + ); + }; + let key = key.to_lowercase(); + let Expr::Value(v) = *right else { + return Err( + ErrorKind::InvalidIndex("invalid with clause".to_string()).into() + ); + }; + let v: DataValue = v.into(); + match key.as_str() { + "distfn" => { + let v = v.as_str(); + distfn = Some(v.to_lowercase()); + } + "nlists" => { + let Some(v) = v.as_usize().unwrap() else { + return Err(ErrorKind::InvalidIndex( + "invalid with clause".to_string(), + ) + .into()); + }; + nlists = Some(v); + } + "nprobe" => { + let Some(v) = v.as_usize().unwrap() else { + return Err(ErrorKind::InvalidIndex( + "invalid with clause".to_string(), + ) + .into()); + }; + nprobe = Some(v); + } + _ => { + return Err( + ErrorKind::InvalidIndex("invalid with clause".to_string()).into() + ); + } + } + } + Ok(IndexType::IvfFlat { + distance: VectorDistance::from_str(distfn.unwrap().as_str()).unwrap(), + nlists: nlists.unwrap(), + nprobe: nprobe.unwrap(), + }) + } + _ => Err(ErrorKind::InvalidIndex("invalid index type".to_string()).into()), + } + } + pub(super) fn bind_create_index(&mut self, stat: crate::parser::CreateIndex) -> Result { let Some(ref name) = stat.name else { return Err( @@ -57,6 +165,8 @@ impl Binder { let crate::parser::CreateIndex { table_name, columns, + using, + with, .. 
} = stat; let index_name = lower_case_name(name); @@ -94,6 +204,7 @@ impl Binder { index_name: index_name.into(), table_id: table.id(), columns: column_ids, + index_type: self.parse_index_type(using, with)?, }))); Ok(create) } diff --git a/src/binder/mod.rs b/src/binder/mod.rs index 4b5d6fce..edb901df 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -29,7 +29,7 @@ mod select; mod table; pub use self::create_function::CreateFunction; -pub use self::create_index::CreateIndex; +pub use self::create_index::{CreateIndex, IndexType, VectorDistance}; pub use self::create_table::CreateTable; pub use self::error::BindError; use self::error::ErrorKind; diff --git a/src/catalog/index.rs b/src/catalog/index.rs index 42574378..161efa32 100644 --- a/src/catalog/index.rs +++ b/src/catalog/index.rs @@ -1,6 +1,7 @@ // Copyright 2025 RisingLight Project Authors. Licensed under Apache-2.0. use super::*; +use crate::binder::IndexType; /// The catalog of an index. pub struct IndexCatalog { @@ -8,15 +9,23 @@ pub struct IndexCatalog { name: String, table_id: TableId, column_idxs: Vec, + index_type: IndexType, } impl IndexCatalog { - pub fn new(id: IndexId, name: String, table_id: TableId, column_idxs: Vec) -> Self { + pub fn new( + id: IndexId, + name: String, + table_id: TableId, + column_idxs: Vec, + index_type: IndexType, + ) -> Self { Self { id, name, table_id, column_idxs, + index_type, } } @@ -35,4 +44,8 @@ impl IndexCatalog { pub fn name(&self) -> &str { &self.name } + + pub fn index_type(&self) -> IndexType { + self.index_type.clone() + } } diff --git a/src/catalog/root.rs b/src/catalog/root.rs index d54a35f1..a5c47e41 100644 --- a/src/catalog/root.rs +++ b/src/catalog/root.rs @@ -5,6 +5,7 @@ use std::sync::{Arc, Mutex}; use super::function::FunctionCatalog; use super::*; +use crate::binder::IndexType; use crate::parser; use crate::planner::RecExpr; @@ -104,10 +105,11 @@ impl RootCatalog { index_name: String, table_id: TableId, column_idxs: &[ColumnId], + index_type: 
&IndexType, ) -> Result { let mut inner = self.inner.lock().unwrap(); let schema = inner.schemas.get_mut(&schema_id).unwrap(); - schema.add_index(index_name, table_id, column_idxs.to_vec()) + schema.add_index(index_name, table_id, column_idxs.to_vec(), index_type) } pub fn get_index_on_table(&self, schema_id: SchemaId, table_id: TableId) -> Vec { diff --git a/src/catalog/schema.rs b/src/catalog/schema.rs index 5a6a81d1..e6d12ae5 100644 --- a/src/catalog/schema.rs +++ b/src/catalog/schema.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use super::function::FunctionCatalog; use super::*; +use crate::binder::IndexType; use crate::planner::RecExpr; /// The catalog of a schema. @@ -62,13 +63,20 @@ impl SchemaCatalog { name: String, table_id: TableId, columns: Vec, + index_type: &IndexType, ) -> Result { if self.indexes_idxs.contains_key(&name) { return Err(CatalogError::Duplicated("index", name)); } let index_id = self.next_id; self.next_id += 1; - let index_catalog = Arc::new(IndexCatalog::new(index_id, name.clone(), table_id, columns)); + let index_catalog = Arc::new(IndexCatalog::new( + index_id, + name.clone(), + table_id, + columns, + index_type.clone(), + )); self.indexes_idxs.insert(name, index_id); self.indexes.insert(index_id, index_catalog); Ok(index_id) diff --git a/src/executor/create_index.rs b/src/executor/create_index.rs index a1ac1576..2189e730 100644 --- a/src/executor/create_index.rs +++ b/src/executor/create_index.rs @@ -21,6 +21,7 @@ impl CreateIndexExecutor { &self.index.index_name, self.index.table_id, &self.index.columns, + &self.index.index_type, ) .await?; diff --git a/src/planner/cost.rs b/src/planner/cost.rs index 019f450b..722b9400 100644 --- a/src/planner/cost.rs +++ b/src/planner/cost.rs @@ -31,7 +31,7 @@ impl egg::CostFunction for CostFn<'_> { let c = match enode { // plan nodes - Scan(_) | Values(_) => build(), + Scan(_) | Values(_) | IndexScan(_) => build(), Order([_, c]) => nlogn(rows(c)) + build() + costs(c), Filter([exprs, c]) => costs(exprs) 
* rows(c) + build() + costs(c), Proj([exprs, c]) | Window([exprs, c]) => costs(exprs) * rows(c) + costs(c), diff --git a/src/planner/explain.rs b/src/planner/explain.rs index 8c19f666..3a396045 100644 --- a/src/planner/explain.rs +++ b/src/planner/explain.rs @@ -248,6 +248,17 @@ impl<'a> Explain<'a> { ("filter", self.expr(filter).pretty()), ]), ), + IndexScan([table, columns, filter, op, key, vector]) => Pretty::childless_record( + "IndexScan", + with_meta(vec![ + ("table", self.expr(table).pretty()), + ("columns", self.expr(columns).pretty()), + ("filter", self.expr(filter).pretty()), + ("op", self.expr(op).pretty()), + ("key", self.expr(key).pretty()), + ("vector", self.expr(vector).pretty()), + ]), + ), Values(values) => Pretty::simple_record( "Values", with_meta(vec![("rows", Pretty::display(&values.len()))]), diff --git a/src/planner/mod.rs b/src/planner/mod.rs index 59ca6d94..43b67aff 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -98,6 +98,7 @@ define_language! { // plans "scan" = Scan([Id; 3]), // (scan table [column..] filter) + "vector_index_scan" = IndexScan([Id; 6]), // (vector_index_scan table [column..] filter op key vector) "values" = Values(Box<[Id]>), // (values [expr..]..) "proj" = Proj([Id; 2]), // (proj [expr..] child) "filter" = Filter([Id; 2]), // (filter expr child) diff --git a/src/planner/optimizer.rs b/src/planner/optimizer.rs index 202509bc..4d457766 100644 --- a/src/planner/optimizer.rs +++ b/src/planner/optimizer.rs @@ -121,13 +121,14 @@ static STAGE1_RULES: LazyLock> = LazyLock::new(|| { }); /// Stage2 rules in the optimizer. 
-/// - pushdown predicate and projection +/// - pushdown predicate, projection, and index scan static STAGE2_RULES: LazyLock> = LazyLock::new(|| { let mut rules = vec![]; rules.append(&mut rules::expr::rules()); rules.append(&mut rules::plan::always_better_rules()); rules.append(&mut rules::plan::predicate_pushdown_rules()); rules.append(&mut rules::plan::projection_pushdown_rules()); + rules.append(&mut rules::plan::index_scan_rules()); rules }); diff --git a/src/planner/rules/plan.rs b/src/planner/rules/plan.rs index 1dcd3be6..90ccb219 100644 --- a/src/planner/rules/plan.rs +++ b/src/planner/rules/plan.rs @@ -6,7 +6,9 @@ use itertools::Itertools; use super::schema::schema_is_eq; use super::*; +use crate::binder::{IndexType, VectorDistance}; use crate::planner::ExprExt; +use crate::types::DataValue; /// Returns the rules that always improve the plan. pub fn always_better_rules() -> Vec { @@ -398,6 +400,71 @@ pub fn projection_pushdown_rules() -> Vec { vec![ ), ]} +/// Returns the rules that rewrite an order by vector distance over a scan into a vector index scan. 
+#[rustfmt::skip] +pub fn index_scan_rules() -> Vec { vec![ + rw!("vector-index-scan-1"; + "(order (list (<-> ?column ?vector)) (scan ?table ?columns ?filter))" => "(vector_index_scan ?table ?columns ?filter <-> ?column ?vector)" + if has_vector_index("?column", "<->", "?vector", "?filter") + ), + rw!("vector-index-scan-2"; + "(order (list (<#> ?column ?vector)) (scan ?table ?columns ?filter))" => "(vector_index_scan ?table ?columns ?filter <#> ?column ?vector)" + if has_vector_index("?column", "<#>", "?vector", "?filter") + ), + rw!("vector-index-scan-3"; + "(order (list (<=> ?column ?vector)) (scan ?table ?columns ?filter))" => "(vector_index_scan ?table ?columns ?filter <=> ?column ?vector)" + if has_vector_index("?column", "<=>", "?vector", "?filter") + ), +]} + +fn has_vector_index( + column: &str, + op: &str, + vector: &str, + filter: &str, +) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { + let column = var(column); + let vector = var(vector); + let filter = var(filter); + let op = op.to_string(); + move |egraph, _, subst| { + let filter = &egraph[subst[filter]].data; + let vector = &egraph[subst[vector]].data; + let column = &egraph[subst[column]].data; + let Ok(vector_op) = op.parse::() else { + return false; + }; + // Only support null filter or always true filter for now + if !matches!(filter.constant, Some(DataValue::Bool(true)) | None) { + return false; + } + if !matches!(vector.constant, Some(DataValue::Vector(_))) { + return false; + } + if column.columns.len() != 1 { + return false; + } + let column = column.columns.iter().next().unwrap(); + let Expr::Column(col) = column else { + return false; + }; + let catalog = &egraph.analysis.catalog; + let indexes = catalog.get_index_on_table(col.schema_id, col.table_id); + for index_id in indexes { + let index = catalog.get_index_by_id(col.schema_id, index_id).unwrap(); + if index.column_idxs() != [col.column_id] { + continue; + } + if let IndexType::IvfFlat { distance, .. 
} = index.index_type() { + if distance == vector_op { + return true; + } + } + } + false + } +} + /// Returns true if the columns used in `expr` is disjoint from columns produced by `plan`. fn not_depend_on(expr: &str, plan: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { let expr = var(expr); diff --git a/src/storage/memory/mod.rs b/src/storage/memory/mod.rs index 6d7a37f7..b2eeb6f0 100644 --- a/src/storage/memory/mod.rs +++ b/src/storage/memory/mod.rs @@ -22,6 +22,7 @@ use std::sync::{Arc, Mutex}; use super::index::InMemoryIndexes; use super::{InMemoryIndex, Storage, StorageError, StorageResult, TracedStorageError}; +use crate::binder::IndexType; use crate::catalog::{ ColumnCatalog, ColumnId, IndexId, RootCatalog, RootCatalogRef, SchemaId, TableId, TableRefId, }; @@ -133,10 +134,17 @@ impl Storage for InMemoryStorage { index_name: &str, table_id: TableId, column_idxs: &[ColumnId], + index_type: &IndexType, ) -> StorageResult { let idx_id = self .catalog - .add_index(schema_id, index_name.to_string(), table_id, column_idxs) + .add_index( + schema_id, + index_name.to_string(), + table_id, + column_idxs, + index_type, + ) .map_err(|_| StorageError::Duplicated("index", index_name.into()))?; self.indexes .lock() diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 8eacafab..5e958772 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -24,6 +24,7 @@ pub use chunk::*; use enum_dispatch::enum_dispatch; use crate::array::{ArrayImpl, DataChunk}; +use crate::binder::IndexType; use crate::catalog::{ ColumnCatalog, ColumnId, IndexId, RootCatalog, SchemaId, TableId, TableRefId, }; @@ -93,6 +94,7 @@ pub trait Storage: Sync + Send + 'static { index_name: &str, table_id: TableId, column_idxs: &[ColumnId], + index_type: &IndexType, ) -> impl Future> + Send; /// Get the catalog of the storage engine. 
diff --git a/src/storage/secondary/mod.rs b/src/storage/secondary/mod.rs index c6d4705d..0a7a5242 100644 --- a/src/storage/secondary/mod.rs +++ b/src/storage/secondary/mod.rs @@ -34,6 +34,7 @@ use version_manager::*; use super::index::InMemoryIndexes; use super::{InMemoryIndex, Storage, StorageError, StorageResult, TracedStorageError}; +use crate::binder::IndexType; use crate::catalog::{ ColumnCatalog, ColumnId, IndexId, RootCatalog, RootCatalogRef, SchemaId, TableId, TableRefId, }; @@ -200,10 +201,17 @@ impl Storage for SecondaryStorage { index_name: &str, table_id: TableId, column_idxs: &[ColumnId], + index_type: &IndexType, ) -> StorageResult { let idx_id = self .catalog - .add_index(schema_id, index_name.to_string(), table_id, column_idxs) + .add_index( + schema_id, + index_name.to_string(), + table_id, + column_idxs, + index_type, + ) .map_err(|_| StorageError::Duplicated("index", index_name.into()))?; self.indexes .lock() diff --git a/tests/planner_test/vector.planner.sql b/tests/planner_test/vector.planner.sql new file mode 100644 index 00000000..e2294aea --- /dev/null +++ b/tests/planner_test/vector.planner.sql @@ -0,0 +1,15 @@ +-- match the index +explain select * from t order by a <-> '[0, 0, 1]'::VECTOR(3); + +/* +IndexScan { table: t, columns: [ a, b ], filter: true, op: <->, key: a, vector: [0,0,1], cost: 0, rows: 1 } +*/ + +-- do not match the index when the distance function differs +explain select * from t order by a <=> '[0, 0, 1]'::VECTOR(3); + +/* +Order { by: [ VectorCosineDistance { lhs: a, rhs: [0,0,1] } ], cost: 18, rows: 3 } +└── Scan { table: t, list: [ a, b ], filter: true, cost: 6, rows: 3 } +*/ + diff --git a/tests/planner_test/vector.yml b/tests/planner_test/vector.yml new file mode 100644 index 00000000..39508573 --- /dev/null +++ b/tests/planner_test/vector.yml @@ -0,0 +1,18 @@ +- sql: | + explain select * from t order by a <-> '[0, 0, 1]'::VECTOR(3); + desc: match the index + before: + - CREATE TABLE t (a vector(3) not null, b text not null); + INSERT INTO t VALUES ('[0, 0, 
1]', 'a'), ('[0, 0, 2]', 'b'), ('[0, 0, 3]', 'c'); + CREATE INDEX t_ivfflat ON t USING ivfflat (a) WITH (distfn = '<->', nlists = 3, nprobe = 2); + tasks: + - print +- sql: | + explain select * from t order by a <=> '[0, 0, 1]'::VECTOR(3); + desc: do not match the index when the distance function differs + before: + - CREATE TABLE t (a vector(3) not null, b text not null); + INSERT INTO t VALUES ('[0, 0, 1]', 'a'), ('[0, 0, 2]', 'b'), ('[0, 0, 3]', 'c'); + CREATE INDEX t_ivfflat ON t USING ivfflat (a) WITH (distfn = '<->', nlists = 3, nprobe = 2); + tasks: + - print diff --git a/tests/sql/vector_index.slt b/tests/sql/vector_index.slt new file mode 100644 index 00000000..bc144723 --- /dev/null +++ b/tests/sql/vector_index.slt @@ -0,0 +1,18 @@ +# vector_index +statement ok +create table t (a vector(3) not null, b text not null); + +statement ok +insert into t values ('[-1, -2.0, -3]', 'a'), ('[1, 2.0, 3]', 'b'); + +query RRR +select * from t order by a <-> '[0, 0, 1]'::VECTOR(3); +---- +[1,2,3] b +[-1,-2,-3] a + +statement ok +CREATE INDEX t_ivfflat ON t USING ivfflat (a) WITH (distfn = 'l2', nlists = 3, nprobe = 2); + +statement ok +drop table t