Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(planner): match vector indexes #874

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions src/binder/create_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,44 @@ use serde::{Deserialize, Serialize};
use super::*;
use crate::catalog::{ColumnId, SchemaId, TableId};

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub enum VectorDistance {
Cosine,
L2,
NegativeDotProduct,
}

impl FromStr for VectorDistance {
type Err = String;

fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s {
"cosine" | "<=>" => Ok(VectorDistance::Cosine),
"l2" | "<->" => Ok(VectorDistance::L2),
"dotproduct" | "<#>" => Ok(VectorDistance::NegativeDotProduct),
_ => Err(format!("invalid vector distance: {}", s)),
}
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub enum IndexType {
Hnsw,
IvfFlat {
distance: VectorDistance,
nlists: usize,
nprobe: usize,
},
Btree,
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct CreateIndex {
pub schema_id: SchemaId,
pub index_name: String,
pub table_id: TableId,
pub columns: Vec<ColumnId>,
pub index_type: IndexType,
}

impl fmt::Display for CreateIndex {
Expand Down Expand Up @@ -48,6 +80,80 @@ impl FromStr for Box<CreateIndex> {
}

impl Binder {
fn parse_index_type(&self, using: Option<Ident>, with: Vec<Expr>) -> Result<IndexType> {
let Some(using) = using else {
return Err(ErrorKind::InvalidIndex("using clause is required".to_string()).into());
};
match using.to_string().to_lowercase().as_str() {
"btree" => Ok(IndexType::Btree),
"hnsw" => Ok(IndexType::Hnsw),
"ivfflat" => {
let mut distfn = None;
let mut nlists = None;
let mut nprobe = None;
for expr in with {
let Expr::BinaryOp { left, op, right } = expr else {
return Err(
ErrorKind::InvalidIndex("invalid with clause".to_string()).into()
);
};
if op != BinaryOperator::Eq {
return Err(
ErrorKind::InvalidIndex("invalid with clause".to_string()).into()
);
}
let Expr::Identifier(Ident { value: key, .. }) = *left else {
return Err(
ErrorKind::InvalidIndex("invalid with clause".to_string()).into()
);
};
let key = key.to_lowercase();
let Expr::Value(v) = *right else {
return Err(
ErrorKind::InvalidIndex("invalid with clause".to_string()).into()
);
};
let v: DataValue = v.into();
match key.as_str() {
"distfn" => {
let v = v.as_str();
distfn = Some(v.to_lowercase());
}
"nlists" => {
let Some(v) = v.as_usize().unwrap() else {
return Err(ErrorKind::InvalidIndex(
"invalid with clause".to_string(),
)
.into());
};
nlists = Some(v);
}
"nprobe" => {
let Some(v) = v.as_usize().unwrap() else {
return Err(ErrorKind::InvalidIndex(
"invalid with clause".to_string(),
)
.into());
};
nprobe = Some(v);
}
_ => {
return Err(
ErrorKind::InvalidIndex("invalid with clause".to_string()).into()
);
}
}
}
Ok(IndexType::IvfFlat {
distance: VectorDistance::from_str(distfn.unwrap().as_str()).unwrap(),
nlists: nlists.unwrap(),
nprobe: nprobe.unwrap(),
})
}
_ => Err(ErrorKind::InvalidIndex("invalid index type".to_string()).into()),
}
}

pub(super) fn bind_create_index(&mut self, stat: crate::parser::CreateIndex) -> Result {
let Some(ref name) = stat.name else {
return Err(
Expand All @@ -57,6 +163,8 @@ impl Binder {
let crate::parser::CreateIndex {
table_name,
columns,
using,
with,
..
} = stat;
let index_name = lower_case_name(name);
Expand Down Expand Up @@ -94,6 +202,7 @@ impl Binder {
index_name: index_name.into(),
table_id: table.id(),
columns: column_ids,
index_type: self.parse_index_type(using, with)?,
})));
Ok(create)
}
Expand Down
2 changes: 1 addition & 1 deletion src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ mod select;
mod table;

pub use self::create_function::CreateFunction;
pub use self::create_index::CreateIndex;
pub use self::create_index::{CreateIndex, IndexType, VectorDistance};
pub use self::create_table::CreateTable;
pub use self::error::BindError;
use self::error::ErrorKind;
Expand Down
15 changes: 14 additions & 1 deletion src/catalog/index.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,31 @@
// Copyright 2025 RisingLight Project Authors. Licensed under Apache-2.0.

use super::*;
use crate::binder::IndexType;

/// The catalog of an index.
pub struct IndexCatalog {
id: IndexId,
name: String,
table_id: TableId,
column_idxs: Vec<ColumnId>,
index_type: IndexType,
}

impl IndexCatalog {
pub fn new(id: IndexId, name: String, table_id: TableId, column_idxs: Vec<ColumnId>) -> Self {
pub fn new(
id: IndexId,
name: String,
table_id: TableId,
column_idxs: Vec<ColumnId>,
index_type: IndexType,
) -> Self {
Self {
id,
name,
table_id,
column_idxs,
index_type,
}
}

Expand All @@ -35,4 +44,8 @@ impl IndexCatalog {
pub fn name(&self) -> &str {
&self.name
}

pub fn index_type(&self) -> IndexType {
self.index_type.clone()
}
}
4 changes: 3 additions & 1 deletion src/catalog/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::sync::{Arc, Mutex};

use super::function::FunctionCatalog;
use super::*;
use crate::binder::IndexType;
use crate::parser;
use crate::planner::RecExpr;

Expand Down Expand Up @@ -104,10 +105,11 @@ impl RootCatalog {
index_name: String,
table_id: TableId,
column_idxs: &[ColumnId],
index_type: &IndexType,
) -> Result<IndexId, CatalogError> {
let mut inner = self.inner.lock().unwrap();
let schema = inner.schemas.get_mut(&schema_id).unwrap();
schema.add_index(index_name, table_id, column_idxs.to_vec())
schema.add_index(index_name, table_id, column_idxs.to_vec(), index_type)
}

pub fn get_index_on_table(&self, schema_id: SchemaId, table_id: TableId) -> Vec<IndexId> {
Expand Down
10 changes: 9 additions & 1 deletion src/catalog/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::sync::Arc;

use super::function::FunctionCatalog;
use super::*;
use crate::binder::IndexType;
use crate::planner::RecExpr;

/// The catalog of a schema.
Expand Down Expand Up @@ -62,13 +63,20 @@ impl SchemaCatalog {
name: String,
table_id: TableId,
columns: Vec<ColumnId>,
index_type: &IndexType,
) -> Result<IndexId, CatalogError> {
if self.indexes_idxs.contains_key(&name) {
return Err(CatalogError::Duplicated("index", name));
}
let index_id = self.next_id;
self.next_id += 1;
let index_catalog = Arc::new(IndexCatalog::new(index_id, name.clone(), table_id, columns));
let index_catalog = Arc::new(IndexCatalog::new(
index_id,
name.clone(),
table_id,
columns,
index_type.clone(),
));
self.indexes_idxs.insert(name, index_id);
self.indexes.insert(index_id, index_catalog);
Ok(index_id)
Expand Down
1 change: 1 addition & 0 deletions src/executor/create_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ impl<S: Storage> CreateIndexExecutor<S> {
&self.index.index_name,
self.index.table_id,
&self.index.columns,
&self.index.index_type,
)
.await?;

Expand Down
2 changes: 1 addition & 1 deletion src/planner/cost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl egg::CostFunction<Expr> for CostFn<'_> {

let c = match enode {
// plan nodes
Scan(_) | Values(_) => build(),
Scan(_) | Values(_) | IndexScan(_) => build(),
Order([_, c]) => nlogn(rows(c)) + build() + costs(c),
Filter([exprs, c]) => costs(exprs) * rows(c) + build() + costs(c),
Proj([exprs, c]) | Window([exprs, c]) => costs(exprs) * rows(c) + costs(c),
Expand Down
10 changes: 10 additions & 0 deletions src/planner/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,16 @@ impl<'a> Explain<'a> {
("filter", self.expr(filter).pretty()),
]),
),
IndexScan([table, columns, filter, key, vector]) => Pretty::childless_record(
"IndexScan",
with_meta(vec![
("table", self.expr(table).pretty()),
("columns", self.expr(columns).pretty()),
("filter", self.expr(filter).pretty()),
("key", self.expr(key).pretty()),
("vector", self.expr(vector).pretty()),
]),
),
Values(values) => Pretty::simple_record(
"Values",
with_meta(vec![("rows", Pretty::display(&values.len()))]),
Expand Down
1 change: 1 addition & 0 deletions src/planner/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ define_language! {

// plans
"scan" = Scan([Id; 3]), // (scan table [column..] filter)
"index_scan" = IndexScan([Id; 5]), // (index_scan table [column..] filter key value)
"values" = Values(Box<[Id]>), // (values [expr..]..)
"proj" = Proj([Id; 2]), // (proj [expr..] child)
"filter" = Filter([Id; 2]), // (filter expr child)
Expand Down
3 changes: 2 additions & 1 deletion src/planner/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,14 @@ static STAGE1_RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(|| {
});

/// Stage2 rules in the optimizer.
/// - pushdown predicate and projection
/// - pushdown predicate, projection, and index scan
static STAGE2_RULES: LazyLock<Vec<Rewrite>> = LazyLock::new(|| {
let mut rules = vec![];
rules.append(&mut rules::expr::rules());
rules.append(&mut rules::plan::always_better_rules());
rules.append(&mut rules::plan::predicate_pushdown_rules());
rules.append(&mut rules::plan::projection_pushdown_rules());
rules.append(&mut rules::plan::index_scan_rules());
rules
});

Expand Down
76 changes: 76 additions & 0 deletions src/planner/rules/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ use itertools::Itertools;

use super::schema::schema_is_eq;
use super::*;
use crate::binder::{IndexType, VectorDistance};
use crate::planner::ExprExt;
use crate::types::DataValue;

/// Returns the rules that always improve the plan.
pub fn always_better_rules() -> Vec<Rewrite> {
Expand Down Expand Up @@ -398,6 +400,80 @@ pub fn projection_pushdown_rules() -> Vec<Rewrite> { vec![
),
]}

/// Pushdown projections and prune unused columns.
#[rustfmt::skip]
pub fn index_scan_rules() -> Vec<Rewrite> { vec![
rw!("vector-index-scan-1";
"(order (list (<-> ?column ?vector)) (scan ?table ?columns ?filter))" =>
"(index_scan ?table ?columns ?filter ?column ?vector)"
if has_vector_index("?column", "<->", "?vector", "?filter")
),
rw!("vector-index-scan-2";
"(order (list (<#> ?column ?vector)) (scan ?table ?columns ?filter))" =>
"(index_scan ?table ?columns ?filter ?column ?vector)"
if has_vector_index("?column", "<#>", "?vector", "?filter")
),
rw!("vector-index-scan-3";
"(order (list (<=> ?column ?vector)) (scan ?table ?columns ?filter))" =>
"(index_scan ?table ?columns ?filter ?column ?vector)"
if has_vector_index("?column", "<=>", "?vector", "?filter")
),
]}

/// Check if there is a vector index matching the statement. i.e.,
/// `SELECT * FROM t ORDER BY v <-> constant_vector` will match the index
/// on the table t with the vector column v and using the `<->` distance function.
fn has_vector_index(
column: &str,
op: &str,
vector: &str,
filter: &str,
) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
skyzh marked this conversation as resolved.
Show resolved Hide resolved
let column = var(column);
let vector = var(vector);
let filter = var(filter);
let op = op.to_string();
move |egraph, _, subst| {
let filter = &egraph[subst[filter]].data;
let vector = &egraph[subst[vector]].data;
let column = &egraph[subst[column]].data;
let Ok(vector_op) = op.parse::<VectorDistance>() else {
return false;
};
// Only support null filter or always true filter for now. Check if the filter is null or
// true.
if !matches!(filter.constant, Some(DataValue::Bool(true)) | None) {
return false;
}
if !matches!(vector.constant, Some(DataValue::Vector(_))) {
return false;
}
// Check if the order by statement is in the form of vector column <-> constant vector
if column.columns.len() != 1 {
return false;
}
let column = column.columns.iter().next().unwrap();
let Expr::Column(col) = column else {
return false;
};
let catalog = &egraph.analysis.catalog;
let indexes = catalog.get_index_on_table(col.schema_id, col.table_id);
for index_id in indexes {
// Check if any index matches the exact op and the column
let index = catalog.get_index_by_id(col.schema_id, index_id).unwrap();
if index.column_idxs() != [col.column_id] {
continue;
}
if let IndexType::IvfFlat { distance, .. } = index.index_type() {
if distance == vector_op {
return true;
}
}
}
false
}
}

/// Returns true if the columns used in `expr` is disjoint from columns produced by `plan`.
fn not_depend_on(expr: &str, plan: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
let expr = var(expr);
Expand Down
Loading
Loading