Skip to content

Commit

Permalink
feat(planner): match vector indexes
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Chi <[email protected]>
  • Loading branch information
skyzh committed Jan 26, 2025
1 parent 1d646b8 commit bf8b849
Show file tree
Hide file tree
Showing 17 changed files with 292 additions and 8 deletions.
111 changes: 111 additions & 0 deletions src/binder/create_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,47 @@ use serde::{Deserialize, Serialize};
use super::*;
use crate::catalog::{ColumnId, SchemaId, TableId};

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub enum VectorDistance {
Cosine,
L2,
NegativeDotProduct,
}

impl FromStr for VectorDistance {
type Err = String;

fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s {
"cosine" => Ok(VectorDistance::Cosine),
"<=>" => Ok(VectorDistance::Cosine),
"l2" => Ok(VectorDistance::L2),
"<->" => Ok(VectorDistance::L2),
"dotproduct" => Ok(VectorDistance::NegativeDotProduct),
"<#>" => Ok(VectorDistance::NegativeDotProduct),
_ => Err(format!("invalid vector distance: {}", s)),
}
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub enum IndexType {
Hnsw,
IvfFlat {
distance: VectorDistance,
nlists: usize,
nprobe: usize,
},
Btree,
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct CreateIndex {
pub schema_id: SchemaId,
pub index_name: String,
pub table_id: TableId,
pub columns: Vec<ColumnId>,
pub index_type: IndexType,
}

impl fmt::Display for CreateIndex {
Expand Down Expand Up @@ -48,6 +83,79 @@ impl FromStr for Box<CreateIndex> {
}

impl Binder {
fn parse_index_type(&self, using: Option<Ident>, with: Vec<Expr>) -> Result<IndexType> {
let Some(using) = using else {
return Err(ErrorKind::InvalidIndex("using clause is required".to_string()).into());
};
match using.to_string().to_lowercase().as_str() {
"hnsw" => Ok(IndexType::Hnsw),
"ivfflat" => {
let mut distfn = None;
let mut nlists = None;
let mut nprobe = None;
for expr in with {
let Expr::BinaryOp { left, op, right } = expr else {
return Err(
ErrorKind::InvalidIndex("invalid with clause".to_string()).into()
);
};
if op != BinaryOperator::Eq {
return Err(
ErrorKind::InvalidIndex("invalid with clause".to_string()).into()
);
}
let Expr::Identifier(Ident { value: key, .. }) = *left else {
return Err(
ErrorKind::InvalidIndex("invalid with clause".to_string()).into()
);
};
let key = key.to_lowercase();
let Expr::Value(v) = *right else {
return Err(
ErrorKind::InvalidIndex("invalid with clause".to_string()).into()
);
};
let v: DataValue = v.into();
match key.as_str() {
"distfn" => {
let v = v.as_str();
distfn = Some(v.to_lowercase());
}
"nlists" => {
let Some(v) = v.as_usize().unwrap() else {
return Err(ErrorKind::InvalidIndex(
"invalid with clause".to_string(),
)
.into());
};
nlists = Some(v);
}
"nprobe" => {
let Some(v) = v.as_usize().unwrap() else {
return Err(ErrorKind::InvalidIndex(
"invalid with clause".to_string(),
)
.into());
};
nprobe = Some(v);
}
_ => {
return Err(
ErrorKind::InvalidIndex("invalid with clause".to_string()).into()
);
}
}
}
Ok(IndexType::IvfFlat {
distance: VectorDistance::from_str(distfn.unwrap().as_str()).unwrap(),
nlists: nlists.unwrap(),
nprobe: nprobe.unwrap(),
})
}
_ => Err(ErrorKind::InvalidIndex("invalid index type".to_string()).into()),
}
}

pub(super) fn bind_create_index(&mut self, stat: crate::parser::CreateIndex) -> Result {
let Some(ref name) = stat.name else {
return Err(
Expand All @@ -57,6 +165,8 @@ impl Binder {
let crate::parser::CreateIndex {
table_name,
columns,
using,
with,
..
} = stat;
let index_name = lower_case_name(name);
Expand Down Expand Up @@ -94,6 +204,7 @@ impl Binder {
index_name: index_name.into(),
table_id: table.id(),
columns: column_ids,
index_type: self.parse_index_type(using, with)?,
})));
Ok(create)
}
Expand Down
2 changes: 1 addition & 1 deletion src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ mod select;
mod table;

pub use self::create_function::CreateFunction;
pub use self::create_index::CreateIndex;
pub use self::create_index::{CreateIndex, IndexType, VectorDistance};
pub use self::create_table::CreateTable;
pub use self::error::BindError;
use self::error::ErrorKind;
Expand Down
15 changes: 14 additions & 1 deletion src/catalog/index.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,31 @@
// Copyright 2025 RisingLight Project Authors. Licensed under Apache-2.0.

use super::*;
use crate::binder::IndexType;

/// The catalog of an index.
pub struct IndexCatalog {
id: IndexId,
name: String,
table_id: TableId,
column_idxs: Vec<ColumnId>,
index_type: IndexType,
}

impl IndexCatalog {
pub fn new(id: IndexId, name: String, table_id: TableId, column_idxs: Vec<ColumnId>) -> Self {
pub fn new(
id: IndexId,
name: String,
table_id: TableId,
column_idxs: Vec<ColumnId>,
index_type: IndexType,
) -> Self {
Self {
id,
name,
table_id,
column_idxs,
index_type,
}
}

Expand All @@ -35,4 +44,8 @@ impl IndexCatalog {
pub fn name(&self) -> &str {
&self.name
}

pub fn index_type(&self) -> IndexType {
self.index_type.clone()
}
}
4 changes: 3 additions & 1 deletion src/catalog/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::sync::{Arc, Mutex};

use super::function::FunctionCatalog;
use super::*;
use crate::binder::IndexType;
use crate::parser;
use crate::planner::RecExpr;

Expand Down Expand Up @@ -104,10 +105,11 @@ impl RootCatalog {
index_name: String,
table_id: TableId,
column_idxs: &[ColumnId],
index_type: &IndexType,
) -> Result<IndexId, CatalogError> {
let mut inner = self.inner.lock().unwrap();
let schema = inner.schemas.get_mut(&schema_id).unwrap();
schema.add_index(index_name, table_id, column_idxs.to_vec())
schema.add_index(index_name, table_id, column_idxs.to_vec(), index_type)
}

pub fn get_index_on_table(&self, schema_id: SchemaId, table_id: TableId) -> Vec<IndexId> {
Expand Down
10 changes: 9 additions & 1 deletion src/catalog/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::sync::Arc;

use super::function::FunctionCatalog;
use super::*;
use crate::binder::IndexType;
use crate::planner::RecExpr;

/// The catalog of a schema.
Expand Down Expand Up @@ -62,13 +63,20 @@ impl SchemaCatalog {
name: String,
table_id: TableId,
columns: Vec<ColumnId>,
index_type: &IndexType,
) -> Result<IndexId, CatalogError> {
if self.indexes_idxs.contains_key(&name) {
return Err(CatalogError::Duplicated("index", name));
}
let index_id = self.next_id;
self.next_id += 1;
let index_catalog = Arc::new(IndexCatalog::new(index_id, name.clone(), table_id, columns));
let index_catalog = Arc::new(IndexCatalog::new(
index_id,
name.clone(),
table_id,
columns,
index_type.clone(),
));
self.indexes_idxs.insert(name, index_id);
self.indexes.insert(index_id, index_catalog);
Ok(index_id)
Expand Down
1 change: 1 addition & 0 deletions src/executor/create_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ impl<S: Storage> CreateIndexExecutor<S> {
&self.index.index_name,
self.index.table_id,
&self.index.columns,
&self.index.index_type,
)
.await?;

Expand Down
2 changes: 1 addition & 1 deletion src/planner/cost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl egg::CostFunction<Expr> for CostFn<'_> {

let c = match enode {
// plan nodes
Scan(_) | Values(_) => build(),
Scan(_) | Values(_) | IndexScan(_) => build(),
Order([_, c]) => nlogn(rows(c)) + build() + costs(c),
Filter([exprs, c]) => costs(exprs) * rows(c) + build() + costs(c),
Proj([exprs, c]) | Window([exprs, c]) => costs(exprs) * rows(c) + costs(c),
Expand Down
11 changes: 11 additions & 0 deletions src/planner/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,17 @@ impl<'a> Explain<'a> {
("filter", self.expr(filter).pretty()),
]),
),
IndexScan([table, columns, filter, op, key, vector]) => Pretty::childless_record(
"IndexScan",
with_meta(vec![
("table", self.expr(table).pretty()),
("columns", self.expr(columns).pretty()),
("filter", self.expr(filter).pretty()),
("op", self.expr(op).pretty()),
("key", self.expr(key).pretty()),
("vector", self.expr(vector).pretty()),
]),
),
Values(values) => Pretty::simple_record(
"Values",
with_meta(vec![("rows", Pretty::display(&values.len()))]),
Expand Down
1 change: 1 addition & 0 deletions src/planner/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ define_language! {

// plans
"scan" = Scan([Id; 3]), // (scan table [column..] filter)
"vector_index_scan" = IndexScan([Id; 6]), // (vector_index_scan table [column..] filter <op> key vector)
"values" = Values(Box<[Id]>), // (values [expr..]..)
"proj" = Proj([Id; 2]), // (proj [expr..] child)
"filter" = Filter([Id; 2]), // (filter expr child)
Expand Down
3 changes: 2 additions & 1 deletion src/planner/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,14 @@ static STAGE1_RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(|| {
});

/// Stage2 rules in the optimizer.
/// - pushdown predicate and projection
/// - pushdown predicate, projection, and index scan
static STAGE2_RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(|| {
let mut rules = vec![];
rules.append(&mut rules::expr::rules());
rules.append(&mut rules::plan::always_better_rules());
rules.append(&mut rules::plan::predicate_pushdown_rules());
rules.append(&mut rules::plan::projection_pushdown_rules());
rules.append(&mut rules::plan::index_scan_rules());
rules
});

Expand Down
67 changes: 67 additions & 0 deletions src/planner/rules/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ use itertools::Itertools;

use super::schema::schema_is_eq;
use super::*;
use crate::binder::{IndexType, VectorDistance};
use crate::planner::ExprExt;
use crate::types::DataValue;

/// Returns the rules that always improve the plan.
pub fn always_better_rules() -> Vec<Rewrite> {
Expand Down Expand Up @@ -398,6 +400,71 @@ pub fn projection_pushdown_rules() -> Vec<Rewrite> { vec![
),
]}

/// Pushdown projections and prune unused columns.
#[rustfmt::skip]
pub fn index_scan_rules() -> Vec<Rewrite> { vec![
rw!("vector-index-scan-1";
"(order (list (<-> ?column ?vector)) (scan ?table ?columns ?filter))" => "(vector_index_scan ?table ?columns ?filter <-> ?column ?vector)"
if has_vector_index("?column", "<->", "?vector", "?filter")
),
rw!("vector-index-scan-2";
"(order (list (<#> ?column ?vector)) (scan ?table ?columns ?filter))" => "(vector_index_scan ?table ?columns ?filter <#> ?column ?vector)"
if has_vector_index("?column", "<#>", "?vector", "?filter")
),
rw!("vector-index-scan-3";
"(order (list (<=> ?column ?vector)) (scan ?table ?columns ?filter))" => "(vector_index_scan ?table ?columns ?filter <=> ?column ?vector)"
if has_vector_index("?column", "<=>", "?vector", "?filter")
),
]}

fn has_vector_index(
column: &str,
op: &str,
vector: &str,
filter: &str,
) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
let column = var(column);
let vector = var(vector);
let filter = var(filter);
let op = op.to_string();
move |egraph, _, subst| {
let filter = &egraph[subst[filter]].data;
let vector = &egraph[subst[vector]].data;
let column = &egraph[subst[column]].data;
let Ok(vector_op) = op.parse::<VectorDistance>() else {
return false;
};
// Only support null filter or always true filter for now
if !matches!(filter.constant, Some(DataValue::Bool(true)) | None) {
return false;
}
if !matches!(vector.constant, Some(DataValue::Vector(_))) {
return false;
}
if column.columns.len() != 1 {
return false;
}
let column = column.columns.iter().next().unwrap();
let Expr::Column(col) = column else {
return false;
};
let catalog = &egraph.analysis.catalog;
let indexes = catalog.get_index_on_table(col.schema_id, col.table_id);
for index_id in indexes {
let index = catalog.get_index_by_id(col.schema_id, index_id).unwrap();
if index.column_idxs() != [col.column_id] {
continue;
}
if let IndexType::IvfFlat { distance, .. } = index.index_type() {
if distance == vector_op {
return true;
}
}
}
false
}
}

/// Returns true if the columns used in `expr` is disjoint from columns produced by `plan`.
fn not_depend_on(expr: &str, plan: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
let expr = var(expr);
Expand Down
Loading

0 comments on commit bf8b849

Please sign in to comment.