73 changes: 54 additions & 19 deletions rust/README.md
@@ -1,9 +1,19 @@
# Rust extension for pymde

This directory contains a Rust implementation of brute-force exact L2
k-nearest neighbor search, exposed to Python via [PyO3](https://pyo3.rs).
It is a self-contained solution that uses platform-native BLAS
(Accelerate on macOS, OpenBLAS on Linux) for accelerated matrix operations.
This directory contains Rust implementations of core pymde algorithms,
exposed to Python via [PyO3](https://pyo3.rs):

- **Exact kNN** (`knn_l2`): brute-force L2 k-nearest neighbor search using
platform-native BLAS (Accelerate on macOS, OpenBLAS on Linux). Can be
extremely fast and competitive with approximate kNN algorithms on
machines with enough cores.
- **Approximate kNN** (`nn_descent`): NN-Descent algorithm for building
approximate k-nearest neighbor graphs. Uses RP-tree initialization and
iterative local joins to converge on high-recall neighbor graphs. Originally
based on the pynndescent implementation. For machines with only a few cores,
or on very large datasets, this can be much faster than the exact alternative.
- **BFS** (`breadth_first_directed`): breadth-first search on directed CSR
graphs.
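
As a rough illustration of the BFS entry above, here is a minimal sketch of breadth-first search over a directed graph stored in CSR form. The function name `bfs_csr`, its signature, and the `indptr`/`indices` names are assumptions made for this example; the actual interface of `breadth_first_directed` is not shown in this README.

```rust
use std::collections::VecDeque;

/// Breadth-first search over a directed graph in CSR form (hypothetical helper).
///
/// `indptr` has `n + 1` entries; the out-neighbors of node `u` are
/// `indices[indptr[u]..indptr[u + 1]]`. Returns the reachable nodes in
/// the order they are visited.
pub fn bfs_csr(indptr: &[usize], indices: &[usize], source: usize) -> Vec<usize> {
    let n = indptr.len().saturating_sub(1);
    let mut order = Vec::new();
    if source >= n {
        return order; // source is not a valid node
    }
    let mut visited = vec![false; n];
    let mut queue = VecDeque::new();
    visited[source] = true;
    queue.push_back(source);
    while let Some(u) = queue.pop_front() {
        order.push(u);
        // Follow outgoing edges only, since the graph is directed.
        for &v in &indices[indptr[u]..indptr[u + 1]] {
            if !visited[v] {
                visited[v] = true;
                queue.push_back(v);
            }
        }
    }
    order
}
```

For the edges `0→1`, `0→2`, `1→3`, `2→3` plus an isolated node `4`, the CSR arrays would be `indptr = [0, 2, 3, 4, 4, 4]` and `indices = [1, 2, 3, 3]`.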

## Prerequisites

@@ -33,7 +43,7 @@ pip install -e '.[dev]'
```

This compiles the Rust code in release mode and places the resulting shared
library (`_knn.*.so` / `_knn.*.pyd`) into `pymde/`.
library (`_native.*.so` / `_native.*.pyd`) into `pymde/`.

To rebuild after editing Rust code, run the same command again. Only changed
files are recompiled.
@@ -46,48 +56,73 @@ files are recompiled.
cd rust && cargo test
```

This runs the native Rust tests (for `insert_topk`, `sgemm_nn_t`, and
`knn_blas_tiled`) without needing Python. No extra setup beyond the Rust
toolchain is required.
This runs the native Rust tests without needing Python. No extra setup
beyond the Rust toolchain is required.

### Python integration tests

```sh
pytest pymde/test_knn.py -v
pytest pymde/test_knn.py -v # exact kNN
pytest pymde/preprocess/test_nndescent.py -v # approximate kNN
```

## Project layout

```
rust/
├── Cargo.toml # Package metadata and dependencies
├── Cargo.lock # Pinned dependency versions (committed for reproducibility)
├── Cargo.toml # Package metadata and dependencies
├── Cargo.lock # Pinned dependency versions (committed for reproducibility)
└── src/
└── lib.rs # All Rust source code (single file)
├── lib.rs # PyO3 module definition and exports
├── knn.rs # Exact kNN (BLAS-accelerated brute force)
├── blas.rs # BLAS FFI bindings (sgemm)
├── nndescent.rs # NN-Descent approximate kNN algorithm
├── heap.rs # Thread-safe neighbor heaps with AtomicBool try-locks
├── candidates.rs # Candidate tracking for NN-Descent iterations
├── distance.rs # L2 distance kernels (with NEON intrinsics on aarch64)
├── rng.rs # Fast deterministic PRNG (SplitMix64)
└── bfs.rs # Breadth-first search on directed CSR graphs
```

## How it works

The module exposes one Python function: `pymde._knn.knn_l2(data, k)`.
### Exact kNN (`knn_l2`)

The algorithm:
`pymde._native.knn_l2(data, k)` — brute-force exact search.

1. Precompute squared norms `||x_i||^2` for every row.
2. Tile the data matrix into query blocks and database blocks.
3. For each tile pair, compute pairwise inner products using BLAS `sgemm`
(the fastest way to do dense matrix multiply).
3. For each tile pair, compute pairwise inner products using BLAS `sgemm`.
4. Recover squared distances via `||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a · b`.
5. Maintain a sorted top-k list per query row, keeping only the closest neighbors.

Query tiles are processed in parallel using [rayon](https://docs.rs/rayon).
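
The steps above can be sketched in plain Rust. This is a single-threaded illustration under stated assumptions, not the code in `knn.rs`: a naive dot product stands in for the tiled BLAS `sgemm` call, there is no rayon parallelism, each point is excluded from its own neighbor list, and the names `push_topk` and `knn_l2_naive` are hypothetical.

```rust
/// Insert `(dist, idx)` into a list kept sorted by ascending distance
/// and capped at `k` entries (hypothetical helper).
fn push_topk(best: &mut Vec<(f32, usize)>, k: usize, dist: f32, idx: usize) {
    if best.len() == k {
        let worst = match best.last() {
            Some(&(w, _)) => w,
            None => return, // k == 0: nothing to keep
        };
        if dist >= worst {
            return; // not closer than the current k-th neighbor
        }
        best.pop();
    }
    let pos = best.partition_point(|&(d, _)| d <= dist);
    best.insert(pos, (dist, idx));
}

/// Brute-force k nearest neighbors of every row of a row-major `n x d`
/// matrix. Returns, per row, `(squared_distance, index)` pairs, closest first.
pub fn knn_l2_naive(data: &[f32], n: usize, d: usize, k: usize) -> Vec<Vec<(f32, usize)>> {
    assert!(d > 0 && data.len() == n * d);
    // Step 1: squared norms ||x_i||^2 for every row.
    let norms: Vec<f32> = data
        .chunks_exact(d)
        .map(|row| row.iter().map(|x| x * x).sum::<f32>())
        .collect();
    let mut out = Vec::with_capacity(n);
    for i in 0..n {
        let a = &data[i * d..(i + 1) * d];
        let mut best = Vec::with_capacity(k);
        for j in 0..n {
            if i == j {
                continue; // assumption: a point is not its own neighbor
            }
            let b = &data[j * d..(j + 1) * d];
            // Step 3 stand-in: one inner product instead of a BLAS tile.
            let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
            // Step 4: ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, clamped at 0
            // because rounding can make the result slightly negative.
            let d2 = (norms[i] + norms[j] - 2.0 * dot).max(0.0);
            // Step 5: keep only the k closest candidates.
            push_topk(&mut best, k, d2, j);
        }
        out.push(best);
    }
    out
}
```

The norm expansion is what makes BLAS useful here: the only `O(n² d)` work is the matrix of inner products, which is a single dense matrix multiply per tile pair.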

### Approximate kNN (`nn_descent`)

`pymde._native.nn_descent(data, n_neighbors)` — approximate search via
NN-Descent, much faster than exact search for large datasets.

1. **RP-Tree Init**: Build random projection trees to get an initial neighbor
graph. Points in the same leaf node become candidate neighbors.
2. **NN-Descent Loop**: Iteratively refine the graph using local joins — for
each point, compare its neighbors' neighbors as potential new neighbors.
Repeat until convergence (few updates per iteration).
3. **Finalize**: Sort heaps, apply sqrt to distances, return
`(neighbors, distances)`.
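
To make the loop in step 2 concrete, here is a heavily simplified, serial sketch of the local-join idea. It is an assumption-laden illustration rather than the code in `nndescent.rs`: a cyclic "next k points" initialization replaces the RP-trees, there is no candidate sampling or parallelism, the loop stops only when an iteration makes no updates, and the names `sq_dist`, `try_insert`, and `nn_descent_sketch` are hypothetical.

```rust
/// Squared Euclidean distance between two equal-length vectors.
fn sq_dist(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum()
}

/// Try to add `(dist, idx)` to a sorted neighbor list capped at `k`.
/// Returns true if the list changed.
fn try_insert(list: &mut Vec<(f32, usize)>, k: usize, dist: f32, idx: usize) -> bool {
    if list.iter().any(|&(_, j)| j == idx) {
        return false; // already a neighbor
    }
    if list.len() == k {
        match list.last() {
            Some(&(worst, _)) if dist < worst => {}
            _ => return false, // not an improvement (or k == 0)
        }
        list.pop();
    }
    let pos = list.partition_point(|&(d, _)| d <= dist);
    list.insert(pos, (dist, idx));
    true
}

pub fn nn_descent_sketch(points: &[Vec<f32>], k: usize, max_iters: usize) -> Vec<Vec<(f32, usize)>> {
    let n = points.len();
    let mut graph: Vec<Vec<(f32, usize)>> = vec![Vec::new(); n];
    // Stand-in for RP-tree init: link each point to the next k points, cyclically.
    for i in 0..n {
        for step in 1..=k.min(n.saturating_sub(1)) {
            let j = (i + step) % n;
            try_insert(&mut graph[i], k, sq_dist(&points[i], &points[j]), j);
        }
    }
    for _ in 0..max_iters {
        // Candidates of p: its current neighbors plus the points that list p.
        let mut cand: Vec<Vec<usize>> = vec![Vec::new(); n];
        for p in 0..n {
            for &(_, j) in &graph[p] {
                cand[p].push(j);
                cand[j].push(p);
            }
        }
        let mut updates = 0usize;
        for c in cand.iter_mut() {
            c.sort_unstable();
            c.dedup();
            // Local join: every pair of candidates of p may be neighbors of each other.
            for (a, &u) in c.iter().enumerate() {
                for &v in &c[a + 1..] {
                    let d = sq_dist(&points[u], &points[v]);
                    updates += try_insert(&mut graph[u], k, d, v) as usize;
                    updates += try_insert(&mut graph[v], k, d, u) as usize;
                }
            }
        }
        if updates == 0 {
            break; // converged: no list changed in this pass
        }
    }
    // Finalize: lists are already sorted; convert squared distances to distances.
    for list in graph.iter_mut() {
        for entry in list.iter_mut() {
            entry.0 = entry.0.sqrt();
        }
    }
    graph
}
```

Because the result is approximate, a reasonable check on real data is recall against the exact `knn_l2` output rather than equality.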

Thread safety uses per-point `AtomicBool` try-locks for concurrent heap
updates, skipping on contention rather than blocking.
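
A minimal sketch of that try-lock pattern, assuming a wrapper type (here called `TryLocked`, a hypothetical name) that pairs an `AtomicBool` flag with the data it guards. The layout of the real heaps in `heap.rs` is not shown in this README, so treat this only as an illustration of "skip on contention":

```rust
use std::cell::UnsafeCell;
use std::sync::atomic::{AtomicBool, Ordering};

/// A value guarded by an `AtomicBool` flag that is never waited on.
pub struct TryLocked<T> {
    locked: AtomicBool,
    value: UnsafeCell<T>,
}

// SAFETY: `value` is only accessed while `locked` is held by the caller,
// so at most one thread touches it at a time.
unsafe impl<T: Send> Sync for TryLocked<T> {}

impl<T> TryLocked<T> {
    pub fn new(value: T) -> Self {
        Self { locked: AtomicBool::new(false), value: UnsafeCell::new(value) }
    }

    /// Run `f` on the value if the lock is free and return `Some(result)`.
    /// If another caller holds the lock, return `None` immediately.
    pub fn try_update<R>(&self, f: impl FnOnce(&mut T) -> R) -> Option<R> {
        if self
            .locked
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            return None; // contended: skip this update instead of blocking
        }
        // SAFETY: we set the flag above, so no other reference to `value` exists.
        let result = f(unsafe { &mut *self.value.get() });
        self.locked.store(false, Ordering::Release);
        Some(result)
    }

    pub fn into_inner(self) -> T {
        self.value.into_inner()
    }
}
```

Skipping is acceptable in an iterative algorithm like NN-Descent because a missed update can be rediscovered in a later pass. Note that in this sketch a panic inside `f` would leave the flag set; a production version would likely release it from a guard's `Drop`.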

## Key dependencies

| Crate | Purpose |
|-------|---------|
| [pyo3](https://pyo3.rs) | Rust ↔ Python bindings (function signatures, type conversions, GIL management) |
| [numpy](https://docs.rs/numpy) | Zero-copy access to NumPy arrays from Rust |
| [rayon](https://docs.rs/rayon) | Data-parallel iteration (parallelizes across query tiles) |
| [rayon](https://docs.rs/rayon) | Data-parallel iteration (parallelizes across query tiles and NN-Descent joins) |
| [rand](https://docs.rs/rand) | Random number generation (RP-tree construction) |
| [rand_chacha](https://docs.rs/rand_chacha) | Deterministic seeded RNG for reproducibility |

BLAS is linked directly via `extern "C"` — no Rust BLAS crate is used.

@@ -104,11 +139,11 @@ BLAS is linked directly via `extern "C"` — no Rust BLAS crate is used.
}
```

2. Export it from the module at the bottom of `lib.rs`:
2. Export it from the module in `lib.rs`:

```rust
#[pymodule]
mod _knn {
mod _native {
#[pymodule_export]
use super::my_function;
}