Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sanity test #14

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
lint:
uv run ruff check
uv run ruff format --check .
uv run --index-strategy unsafe-best-match -- ruff check
uv run --index-strategy unsafe-best-match -- ruff format --check .
.PHONY: lint

fix:
uv run ruff check --fix
uv run ruff format .
uv run --index-strategy unsafe-best-match -- ruff check --fix
uv run --index-strategy unsafe-best-match -- ruff format .
.PHONY: fix
11 changes: 10 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,16 @@ export LANCEDB_DB_URI=<your db uri from lancedb cloud console, i.e. "db://mydb-d
export LANCEDB_HOST_OVERRIDE=<optional uri if using lancedb enterprise>`
```

3. Run the benchmark
3. Install dependencies
```
uv pip install . --index-strategy unsafe-best-match
```

4. Run the sanity test

`uv run sanity.py`

5. Run the benchmark

`uv run bench.py`

Expand Down
69 changes: 63 additions & 6 deletions bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import traceback
from typing import Iterable, List, Tuple
import multiprocessing as mp
import backoff

from lancedb.remote.errors import LanceDBClientError
from lancedb.remote.table import RemoteTable
Expand Down Expand Up @@ -86,6 +87,13 @@ def add_benchmark_args(parser: argparse.ArgumentParser):
action=argparse.BooleanOptionalAction,
help="drop tables before starting",
)
parser.add_argument(
"-s",
"--size",
type=int,
default=None,
help="number of rows to ingest (no limit if not specified)",
)


class Benchmark:
Expand All @@ -100,6 +108,7 @@ def __init__(
index: bool,
prefix: str,
reset: bool,
size: int = None,
):
self.dataset = dataset
self.num_tables = num_tables
Expand All @@ -109,12 +118,19 @@ def __init__(
self.index = index
self.prefix = prefix
self.reset = reset
self.size = size

azure_account_name = os.getenv("AZURE_STORAGE_ACCOUNT_NAME")
storage_options = None
if azure_account_name:
storage_options = {"azure_storage_account_name": azure_account_name}

self.db = lancedb.connect(
uri=os.environ["LANCEDB_DB_URI"],
api_key=os.environ["LANCEDB_API_KEY"],
host_override=os.getenv("LANCEDB_HOST_OVERRIDE"),
region=os.getenv("LANCEDB_REGION", "us-east-1"),
storage_options=storage_options,
)

if query_type == QueryType.VECTOR.value:
Expand Down Expand Up @@ -271,15 +287,48 @@ def _await_index(self, table: RemoteTable, index_type: str, start_time):
f"{table.name}: {index_type} indexing completed in {int(time.time() - start_time)}s."
)

@backoff.on_exception(
    backoff.constant,
    LanceDBClientError,
    interval=10,
    max_time=6000,
    logger=None,
    # Only retry commit conflicts; any other client error is fatal.
    giveup=lambda exc: "Commit conflict for version" not in str(exc),
)
def create_vector_index(self, table):
    """Create the IVF_PQ vector index on the `openai` embedding column.

    Retries every 10s (up to 6000s total) while the server reports a
    commit conflict, which happens when concurrent writes race the
    index-creation commit.
    """
    table.create_index(
        metric="cosine", vector_column_name="openai", index_type="IVF_PQ"
    )

@backoff.on_exception(
    backoff.constant,
    LanceDBClientError,
    interval=10,
    max_time=6000,
    logger=None,
    # Only retry commit conflicts; any other client error is fatal.
    giveup=lambda exc: "Commit conflict for version" not in str(exc),
)
def create_scalar_index(self, table):
    """Create the BTREE scalar index on the `id` column.

    Same constant-backoff retry policy as the vector index: keep
    retrying while the server reports a commit conflict.
    """
    table.create_scalar_index("id", index_type="BTREE")

@backoff.on_exception(
    backoff.constant,
    LanceDBClientError,
    interval=10,
    max_time=6000,
    logger=None,
    # Only retry commit conflicts; any other client error is fatal.
    giveup=lambda exc: "Commit conflict for version" not in str(exc),
)
def create_fts_index(self, table: RemoteTable):
    """Create the full-text-search index on the `title` column.

    Same constant-backoff retry policy as the other index creators:
    keep retrying while the server reports a commit conflict.
    """
    table.create_fts_index("title")

def _create_indices(self):
# create the indices - these will be created async
table_indices = {}
for t in self.tables:
t.create_index(
metric="cosine", vector_column_name="openai", index_type="IVF_PQ"
)
t.create_scalar_index("id", index_type="BTREE")
t.create_fts_index("title")
self.create_vector_index(t)
self.create_scalar_index(t)
self.create_fts_index(t)
table_indices[t] = ["IVF_PQ", "FTS", "BTREE"]

print("waiting for index completion...")
Expand Down Expand Up @@ -316,7 +365,10 @@ def _convert_dataset(self, schema) -> Iterable[pa.RecordBatch]:

buffer = []
buffer_rows = 0
total_converted_rows = 0
for batch in batch_iterator:
if self.size is not None and total_converted_rows >= self.size:
break
rb = pa.RecordBatch.from_arrays(
[
batch["_id"],
Expand All @@ -334,10 +386,12 @@ def _convert_dataset(self, schema) -> Iterable[pa.RecordBatch]:
)[0]
buffer.clear()
buffer_rows = 0
total_converted_rows += len(combined)
yield combined
else:
buffer.append(rb)
buffer_rows += len(rb)
total_converted_rows += len(rb)

for b in buffer:
yield b
Expand All @@ -346,7 +400,7 @@ def _query_table(self, table: RemoteTable, warmup_queries=100):
# log a warning if data is not fully indexed
try:
total_rows = table.count_rows()
for idx in table.list_indices()["indexes"]:
for idx in table.list_indices():
stats = table.index_stats(idx["index_name"])
if total_rows != stats["num_indexed_rows"]:
print(
Expand Down Expand Up @@ -427,6 +481,7 @@ def run_multi_benchmark(
index: bool,
prefix: str,
reset: bool,
size: int,
) -> BenchmarkResults:
total_processes = num_processes * (
query_processes if not ingest and not index else 1
Expand All @@ -443,6 +498,7 @@ def run_multi_benchmark(
"index": index,
"prefix": prefix, # Base prefix, will be modified per process
"reset": reset,
"size": size,
}

process_args = []
Expand Down Expand Up @@ -522,6 +578,7 @@ def main():
args.index,
args.prefix,
args.reset,
args.size,
)

result.print()
Expand Down
11 changes: 10 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "lancedb-cloud-benchmarks"
version = "0.1.2"
version = "0.1.3"
description = ""
authors = [{ name = "LanceDB Devs", email = "[email protected]" }]
readme = "README.md"
Expand All @@ -14,5 +14,14 @@ dependencies = [
]

[tool.uv]
prerelease = "allow"
dev-dependencies = [
]

[[tool.uv.index]]
name = "fury"
url = "https://pypi.fury.io/lancedb/"

[[tool.uv.index]]
name = "pypi"
url = "https://pypi.org/simple"
98 changes: 98 additions & 0 deletions sanity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from bench import Benchmark, QueryType, add_benchmark_args
import argparse
import time


def get_default_args():
    """Return Benchmark constructor kwargs built from the CLI parser defaults.

    Parses an empty argument list so every value is the parser default,
    then renames CLI option names to Benchmark.__init__ parameter names.
    """
    parser = argparse.ArgumentParser()
    add_benchmark_args(parser)
    defaults = vars(parser.parse_args([]))

    # CLI option name -> Benchmark.__init__ parameter name.
    option_to_param = {
        "dataset": "dataset",
        "tables": "num_tables",
        "batch": "batch_size",
        "queries": "num_queries",
        "query_type": "query_type",
        "ingest": "ingest",
        "index": "index",
        "prefix": "prefix",
        "reset": "reset",
        "size": "size",
    }
    return {param: defaults[opt] for opt, param in option_to_param.items()}


def run_benchmark(benchmark_args: dict) -> None:
    """Construct a Benchmark from *benchmark_args* and execute one full run."""
    print(f"\nRunning sanity test with args: {benchmark_args}")
    Benchmark(**benchmark_args).run()
    print("Sanity test passed")


def main():
    """Run the end-to-end sanity check.

    Step 1 ingests and indexes a small dataset into a single table;
    step 2 runs a handful of queries for every supported query type.
    Any failure is reported and re-raised so the process exits non-zero.
    """
    num_queries = 3  # number of queries per run

    # Small, fast settings for sanity runs; the timestamped prefix keeps
    # repeated runs from colliding on table names.
    base_args = {
        **get_default_args(),
        "num_tables": 1,
        "batch_size": 100,
        "size": 1000,
        "prefix": f"sanity-test-{int(time.time())}",
    }

    print("=== Starting Sanity Test ===")
    print(f"Using prefix: {base_args['prefix']}")

    try:
        # Step 1: populate and index the table; no queries during ingestion.
        print("\n=== Step 1: Data Ingestion and Indexing ===")
        run_benchmark(
            {
                **base_args,
                "num_queries": 0,
                "ingest": True,
                "index": True,
                "reset": True,
            }
        )

        # Step 2: query-only runs against the data ingested above.
        print("\n=== Step 2: Running Queries ===")
        query_args = {
            **base_args,
            "num_queries": num_queries,
            "ingest": False,  # skip ingestion
            "index": False,  # skip indexing
            "reset": False,
        }
        for query_type in QueryType:
            print(f"\n--- Testing {query_type.value} queries ---")
            run_benchmark({**query_args, "query_type": query_type.value})

        print("\n=== Sanity Test Completed ===")
    except Exception as e:
        print("\n=== Sanity Test Failed ===")
        print(f"Error: {e}")
        raise


# Script entry point: run the sanity test when executed directly.
if __name__ == "__main__":
    main()
26 changes: 12 additions & 14 deletions src/cloud/benchmark/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,20 @@
import json


@backoff.on_exception(
backoff.constant, ValueError, max_time=600, interval=10, logger=None
)
@backoff.on_exception(backoff.constant, ValueError, max_time=600, interval=10)
def await_indices(
table: RemoteTable, count: int = 1, index_types: Optional[list[str]] = []
table: RemoteTable,
count: int = 1,
index_types: Optional[list[str]] = [],
) -> list[dict]:
"""poll for all indices to be created on the table"""
indices = table.list_indices()
# print(f"current indices for table {table}: {indices}")
# The old SDK returns a dict with a key "indexes" containing the list of indices
if isinstance(indices, dict):
indices = indices["indexes"]

result_indices = []
for index in indices["indexes"]:
for index in indices:
if not index["index_name"]:
raise ValueError("still waiting for index creation")
result_indices.append(index)
Expand All @@ -33,22 +36,17 @@ def await_indices(
f"(current: {len(result_indices)}, desired: {count})"
)

index_names = [n["index_name"] for n in result_indices]
stats = [table.index_stats(n) for n in index_names]
if index_types:
types = [s["index_type"] for s in stats]
index_names = [n["index_name"] for n in result_indices]
stats = [table.index_stats(n) for n in index_names]
types = [stat["index_type"] for stat in stats]
for t in index_types:
if t not in types:
raise ValueError(
f"still waiting for correct index type "
f"(current: {types}, desired: {index_types})"
)

unindexed_rows = [s["num_unindexed_rows"] for s in stats]
for u in unindexed_rows:
if u != 0:
raise ValueError(f"still waiting for unindexed rows to be 0 (current: {u})")

return result_indices


Expand Down