From f70c810bbca3d10fa9edfd192f602014f87ccb3d Mon Sep 17 00:00:00 2001 From: JWSong Date: Sat, 18 May 2024 19:03:54 +0900 Subject: [PATCH 1/6] fix: result search all boundaries --- src/lib.rs | 330 ++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 277 insertions(+), 53 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index c355d71..ec1ad34 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,16 +1,153 @@ +use std::{ + cmp::Reverse, + collections::BinaryHeap, + fmt::Debug, + iter::zip, + ops::{Add, Deref, Div, Mul, Sub}, +}; + use serde::{Deserialize, Serialize}; +#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)] +#[repr(transparent)] +pub struct NonNanFloat(f64); + +impl PartialEq for NonNanFloat { + #[inline] + fn eq(&self, other: &Self) -> bool { + if self.0.is_nan() { + other.0.is_nan() + } else { + self.0 == other.0 + } + } +} + +impl PartialOrd for NonNanFloat { + fn partial_cmp(&self, other: &Self) -> Option { + self.0.partial_cmp(&other.0) + } +} + +impl Eq for NonNanFloat {} + +impl Ord for NonNanFloat { + #[inline] + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.partial_cmp(other).unwrap() + } +} + +impl Sub for NonNanFloat { + type Output = Self; + + fn sub(self, other: Self) -> Self { + Self(self.0 - other.0) + } +} + +impl Sub for NonNanFloat { + type Output = Self; + + fn sub(self, other: f64) -> Self { + Self(self.0 - other) + } +} + +impl Add for NonNanFloat { + type Output = Self; + + fn add(self, other: Self) -> Self { + NonNanFloat(self.0 + other.0) + } +} + +impl Add for NonNanFloat { + type Output = Self; + + fn add(self, other: f64) -> Self { + NonNanFloat(self.0 + other) + } +} + +impl Mul for NonNanFloat { + type Output = Self; + + fn mul(self, other: Self) -> Self { + NonNanFloat(self.0 * other.0) + } +} + +impl Mul for NonNanFloat { + type Output = Self; + + fn mul(self, other: f64) -> Self { + NonNanFloat(self.0 * other) + } +} + +impl Div for NonNanFloat { + type Output = Self; + + fn div(self, other: Self) -> Self { + NonNanFloat(self.0 / other.0) + } +} + +impl Div for NonNanFloat { + type Output = Self; + + fn div(self, other: f64) -> Self { + NonNanFloat(self.0 / other) + } +} + +impl Deref for NonNanFloat { + type Target = f64; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl NonNanFloat { + fn new(value: f64) -> Self { + assert!(!value.is_nan()); + NonNanFloat(value) + } +} + +impl From for NonNanFloat { + fn from(value: f64) -> Self { + NonNanFloat::new(value) + } +} + #[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)] pub struct Coordinate { - pub longitude: f64, - pub latitude: f64, + pub longitude: NonNanFloat, + pub latitude: NonNanFloat, +} + +pub trait TCoordianteFloat { + fn into(self) -> NonNanFloat; +} +impl TCoordianteFloat for f64 { + fn into(self) -> NonNanFloat { + NonNanFloat::new(self) + } +} +impl TCoordianteFloat for NonNanFloat { + fn into(self) -> NonNanFloat { + self + } } impl Coordinate { - pub fn new(longitude: f64, latitude: f64) -> Self { + pub fn new(longitude: F, latitude: F) -> Self { Coordinate { - longitude, - latitude, + longitude: longitude.into(), + latitude: latitude.into(), } } @@ -68,10 +205,34 @@ impl Boundary { (nw, ne, sw, se) } + + fn distance(&self, point: &Coordinate) -> NonNanFloat { + if self.contains(point) { + return NonNanFloat(0.0); + } + + let dx = if point.longitude < self.top_left_coor.longitude { + self.top_left_coor.longitude - point.longitude + } else if point.longitude >= self.bottom_right_coor.longitude { + point.longitude - self.bottom_right_coor.longitude + } else { + NonNanFloat(0.0) + }; + + let dy = if point.latitude < self.top_left_coor.latitude { + self.top_left_coor.latitude - point.latitude + } else if point.latitude >= self.bottom_right_coor.latitude { + point.latitude - self.bottom_right_coor.latitude + } else { + NonNanFloat(0.0) + }; + + NonNanFloat((dx.powi(2) + dy.powi(2)).sqrt()) + } } #[derive(Debug, Serialize, Deserialize)] -pub struct Quadtree { +pub struct Quadtree { pub boundary: Boundary, pub capacity: usize, pub coordinates: Vec, @@ -80,7 +241,7 @@ pub struct Quadtree { children: Vec>, } -impl Quadtree { +impl Quadtree { pub fn new(boundary: Boundary, capacity: usize) -> Quadtree { Quadtree { boundary, @@ -158,33 +319,49 @@ impl Quadtree { } // recursive function - fn search<'a: 'b, 'b>(&'a self, distances: &mut Vec<(&'b T, f64)>, query_point: &Coordinate) { - if !self.boundary.contains(query_point) { - return; + fn search<'a>( + &'a self, + nearest_neighbers: &mut BinaryHeap<(NonNanFloat, &'a T)>, + query_point: &Coordinate, + k: usize, + ) { + if nearest_neighbers.len() == k { + if let Some((max_distance, _)) = nearest_neighbers.peek() { + if self.boundary.distance(query_point) >= *max_distance { + return; + } + } } - if self.children.is_empty() { - for (coordinate, interest) in self.coordinates.iter().zip(self.interests.iter()) { - let distance = coordinate.distance(query_point); - distances.push((interest, distance)); - } - } else { - for child in self.children.iter() { - child.search(distances, query_point); + for (coord, v) in zip(&self.coordinates, &self.interests) { + let distance = coord.distance(query_point); + if nearest_neighbers.len() < k { + nearest_neighbers.push((NonNanFloat::new(distance), v)); + } else if let Some((max_distance, _)) = nearest_neighbers.peek() { + if distance >= **max_distance { + continue; + } + if nearest_neighbers.len() == k { + nearest_neighbers.pop(); + } + nearest_neighbers.push((NonNanFloat::new(distance), v)); } } + + for child in self.children.iter() { + child.search(nearest_neighbers, query_point, k); + } } pub fn find_nearest_neighbors(&self, query_point: &Coordinate, k: usize) -> Vec<&T> { - let mut distances = vec![]; - - self.search(&mut distances, query_point); - distances.sort_by(|(_, dis1), (_, dis2)| dis1.partial_cmp(dis2).unwrap()); - distances - .into_iter() - .take(k) - .map(|(c, _)| c) - .collect::>() + let mut nearest_neighbors = BinaryHeap::new(); + self.search(&mut nearest_neighbors, query_point, k); + + nearest_neighbors + .into_sorted_vec() + .iter() + .map(|(_, v)| *v) + .collect() } } @@ -193,18 +370,80 @@ mod test { use crate::{Boundary, Coordinate, Quadtree}; #[test] - fn test_contains() { + fn non_nan_float() { + //GIVEN + let a = 10.0; + let b = 20.0; + let c = 30.0; + let d = 40.0; + + //WHEN + let res = a + b; + let res2 = c - d; + let res3 = a * b; + let res4 = c / d; + let res5 = a > b; + let res6 = a < b; + let res7 = a == b; + + //THEN + assert_eq!(res, 30.0); + assert_eq!(res2, -10.0); + assert_eq!(res3, 200.0); + assert_eq!(res4, 0.75); + assert!(!res5); + assert!(res6); + assert!(!res7); + } + + #[test] + fn test_boundary_contains() { //GIVEN let boundary = Boundary::new(Coordinate::new(0.0, 0.0), Coordinate::new(100.0, 100.0)); //WHEN - let res = boundary.contains(&Coordinate { - longitude: 0.5, - latitude: 2.0, - }); + let res = boundary.contains(&Coordinate::new(0.5, 2.0)); //THEN assert!(res); } + #[test] + fn test_boundary_distance() { + //GIVEN + let boundary = Boundary::new(Coordinate::new(0.0, 0.0), Coordinate::new(100.0, 100.0)); + + // WHEN + // The pair means (coordinate, expected distance) + let cases = vec![ + // point is inside the boundary + (Coordinate::new(0.5, 2.0), 0.0), + (Coordinate::new(0.0, 0.0), 0.0), + (Coordinate::new(100.0, 100.0), 0.0), + (Coordinate::new(0.0, 100.0), 0.0), + (Coordinate::new(100.0, 0.0), 0.0), + (Coordinate::new(50.0, 50.0), 0.0), + (Coordinate::new(50.0, 0.0), 0.0), + (Coordinate::new(0.0, 50.0), 0.0), + (Coordinate::new(100.0, 50.0), 0.0), + (Coordinate::new(50.0, 100.0), 0.0), + // point is outside the boundary + (Coordinate::new(0.0, 200.0), 100.0), + (Coordinate::new(200.0, 0.0), 100.0), + (Coordinate::new(200.0, 200.0), f64::sqrt(20000.0)), + (Coordinate::new(50.0, 200.0), 100.0), + (Coordinate::new(200.0, 50.0), 100.0), + (Coordinate::new(200.0, 100.0), 100.0), + (Coordinate::new(100.0, 200.0), 100.0), + (Coordinate::new(150.0, 150.0), f64::sqrt(5000.0)), + (Coordinate::new(101.0, 103.0), f64::sqrt(10.0)), + ]; + + //THEN + for (coor, expected) in cases { + let res = boundary.distance(&coor); + assert_eq!(*res, expected); + } + } + #[test] fn test_insert() { //GIVEN @@ -277,11 +516,11 @@ mod test { //WHEN let query_point = Coordinate::new(4999.0, 4950.0); - let interests = quadtree.find_nearest_neighbors(&query_point, 3); + let interests = quadtree.find_nearest_neighbors(&query_point, 8); //THEN - assert_eq!(interests.len(), 3); - let mut expected = vec!["I", "J", "K"]; + assert_eq!(interests.len(), 8); + let mut expected = vec!["I", "J", "K", "G", "H", "E", "D", "F"]; for interest in interests { assert!(expected.contains(interest)); let pos = expected.iter().position(|x| x == interest); @@ -309,13 +548,7 @@ mod test { // WHEN for (longitude, latitude) in million_record { - quadtree.insert( - Coordinate { - longitude, - latitude, - }, - "A", - ); + quadtree.insert(Coordinate::new(longitude, latitude), "A"); } let second_instant = std::time::Instant::now(); @@ -344,10 +577,7 @@ mod test { for (longitude, latitude) in million_record { quadtree.insert( - Coordinate { - longitude, - latitude, - }, + Coordinate::new(longitude, latitude), format!("long : {longitude}, lat: {latitude}"), ); } @@ -355,13 +585,7 @@ mod test { // WHEN let instance = std::time::Instant::now(); - let k_business = quadtree.find_nearest_neighbors( - &Coordinate { - longitude: 127.0, - latitude: 38.1, - }, - 5, - ); + let k_business = quadtree.find_nearest_neighbors(&Coordinate::new(127.0, 38.1), 5); // THEN let second_instant = std::time::Instant::now(); From ce8d24406c2c08856c54f9357ae8e9243c96dd90 Mon Sep 17 00:00:00 2001 From: JWSong Date: Sat, 18 May 2024 19:08:02 +0900 Subject: [PATCH 2/6] fix: lint --- src/lib.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index ec1ad34..8374452 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,4 @@ use std::{ - cmp::Reverse, collections::BinaryHeap, fmt::Debug, iter::zip, @@ -25,7 +24,7 @@ impl PartialEq for NonNanFloat { impl PartialOrd for NonNanFloat { fn partial_cmp(&self, other: &Self) -> Option { - self.0.partial_cmp(&other.0) + Some(self.cmp(other)) } } From e1bcfec9e9ce7a13c9a8a962f8fcc6f8e2c677b4 Mon Sep 17 00:00:00 2001 From: JWSong Date: Sat, 18 May 2024 19:11:14 +0900 Subject: [PATCH 3/6] fix: stack overflow --- src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 8374452..cedfb49 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,7 +24,7 @@ impl PartialEq for NonNanFloat { impl PartialOrd for NonNanFloat { fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) + self.0.partial_cmp(&other.0) } } From 2206967bc2a10ecfbe80a8475c7a1206dab0960e Mon Sep 17 00:00:00 2001 From: JWSong Date: Sat, 18 May 2024 19:14:02 +0900 Subject: [PATCH 4/6] fix: test --- src/lib.rs | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index cedfb49..2566001 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,6 +23,7 @@ impl PartialEq for NonNanFloat { } impl PartialOrd for NonNanFloat { + #[inline] fn partial_cmp(&self, other: &Self) -> Option { self.0.partial_cmp(&other.0) } @@ -366,15 +367,15 @@ impl Quadtree { #[cfg(test)] mod test { - use crate::{Boundary, Coordinate, Quadtree}; + use crate::{Boundary, Coordinate, NonNanFloat, Quadtree}; #[test] fn non_nan_float() { //GIVEN - let a = 10.0; - let b = 20.0; - let c = 30.0; - let d = 40.0; + let a = NonNanFloat(10.0); + let b = NonNanFloat(20.0); + let c = NonNanFloat(30.0); + let d = NonNanFloat(40.0); //WHEN let res = a + b; @@ -386,10 +387,10 @@ mod test { let res7 = a == b; //THEN - assert_eq!(res, 30.0); - assert_eq!(res2, -10.0); - assert_eq!(res3, 200.0); - assert_eq!(res4, 0.75); + assert_eq!(res, NonNanFloat(30.0)); + assert_eq!(res2, NonNanFloat(-10.0)); + assert_eq!(res3, NonNanFloat(200.0)); + assert_eq!(res4, NonNanFloat(0.75)); assert!(!res5); assert!(res6); assert!(!res7); From 0efaff78a58d56cb6ccbd1b511e653c5c2dbcb5a Mon Sep 17 00:00:00 2001 From: JWSong Date: Sat, 18 May 2024 22:26:50 +0900 Subject: [PATCH 5/6] chore: add more test cases --- src/lib.rs | 151 +++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 124 insertions(+), 27 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 2566001..bd643b6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,6 +22,17 @@ impl PartialEq for NonNanFloat { } } +impl PartialEq for NonNanFloat { + #[inline] + fn eq(&self, other: &f64) -> bool { + if self.0.is_nan() { + other.is_nan() + } else { + self.0 == *other + } + } +} + impl PartialOrd for NonNanFloat { #[inline] fn partial_cmp(&self, other: &Self) -> Option { @@ -376,6 +387,9 @@ mod test { let b = NonNanFloat(20.0); let c = NonNanFloat(30.0); let d = NonNanFloat(40.0); + let nan = NonNanFloat(f64::NAN); + let another_nan = NonNanFloat(f64::NAN); + let plain_float = 10.0; //WHEN let res = a + b; @@ -385,6 +399,12 @@ mod test { let res5 = a > b; let res6 = a < b; let res7 = a == b; + let res8 = nan == another_nan; + let res9 = a == plain_float; + let res10 = a + plain_float; + let res11 = a - plain_float; + let res12 = a * plain_float; + let res13 = a / plain_float; //THEN assert_eq!(res, NonNanFloat(30.0)); @@ -394,6 +414,12 @@ mod test { assert!(!res5); assert!(res6); assert!(!res7); + assert!(res8); + assert!(res9); + assert_eq!(res10, NonNanFloat(20.0)); + assert_eq!(res11, NonNanFloat(0.0)); + assert_eq!(res12, NonNanFloat(100.0)); + assert_eq!(res13, NonNanFloat(1.0)); } #[test] @@ -435,6 +461,8 @@ mod test { (Coordinate::new(100.0, 200.0), 100.0), (Coordinate::new(150.0, 150.0), f64::sqrt(5000.0)), (Coordinate::new(101.0, 103.0), f64::sqrt(10.0)), + // real data + (Coordinate::new(37.512428, 127.054513), 27.054513), ]; //THEN @@ -495,38 +523,107 @@ mod test { #[test] fn test_k_nearest() { + struct TestCase<'a> { + points: Vec<(f64, f64, &'a str)>, + query_point: (f64, f64), + k: usize, + expected: Vec<&'a str>, + } + //GIVEN - let mut quadtree: Quadtree<&str> = Quadtree::new( - Boundary::new(Coordinate::new(0.0, 0.0), Coordinate::new(5000.0, 5000.0)), - 4, - ); + let mut cases = vec![ + TestCase { + points: vec![ + (30.0, 30.0, "A"), + (10.0, 50.0, "B"), + (70.0, 20.0, "C"), + (80.0, 80.0, "D"), + (80.0, 90.0, "E"), + (60.0, 90.0, "F"), + (2500.0, 2700.0, "G"), + (1700.0, 1500.0, "H"), + (4993.0, 4999.0, "I"), + (4993.0, 4330.0, "J"), + (4993.0, 4500.0, "K"), + ], + query_point: (4999.0, 4950.0), + k: 8, + expected: vec!["I", "J", "K", "G", "H", "E", "D", "F"], + }, + TestCase { + points: vec![ + (30.0, 30.0, "A"), + (10.0, 50.0, "B"), + (70.0, 20.0, "C"), + (80.0, 80.0, "D"), + (80.0, 90.0, "E"), + (60.0, 90.0, "F"), + (2500.0, 2700.0, "G"), + (1700.0, 1500.0, "H"), + (4993.0, 4999.0, "I"), + (4993.0, 4330.0, "J"), + (4993.0, 4500.0, "K"), + ], + query_point: (4999.0, 4950.0), + k: 5, + expected: vec!["I", "J", "K", "G", "H"], + }, + TestCase { + points: vec![ + (30.0, 30.0, "A"), + (10.0, 50.0, "B"), + (70.0, 20.0, "C"), + (80.0, 80.0, "D"), + (80.0, 90.0, "E"), + (60.0, 90.0, "F"), + (2500.0, 2700.0, "G"), + (1700.0, 1500.0, "H"), + (4993.0, 4999.0, "I"), + (4993.0, 4330.0, "J"), + (4993.0, 4500.0, "K"), + ], + query_point: (4999.0, 4950.0), + k: 3, + expected: vec!["I", "J", "K"], + }, + TestCase { + points: vec![(30.0, 30.0, "A"), (10.0, 50.0, "B"), (70.0, 20.0, "C")], + query_point: (4999.0, 4950.0), + k: 5, + expected: vec!["A", "B", "C"], + }, + TestCase { + points: Vec::new(), + query_point: (4999.0, 4950.0), + k: 5, + expected: Vec::new(), + }, + ]; - // Inserting some points - quadtree.insert(Coordinate::new(30.0, 30.0), "A"); - quadtree.insert(Coordinate::new(10.0, 50.0), "B"); - quadtree.insert(Coordinate::new(70.0, 20.0), "C"); - quadtree.insert(Coordinate::new(80.0, 80.0), "D"); - quadtree.insert(Coordinate::new(80.0, 90.0), "E"); - quadtree.insert(Coordinate::new(60.0, 90.0), "F"); - quadtree.insert(Coordinate::new(2500.0, 2700.0), "G"); - quadtree.insert(Coordinate::new(1700.0, 1500.0), "H"); - quadtree.insert(Coordinate::new(4993.0, 4999.0), "I"); - quadtree.insert(Coordinate::new(4993.0, 4330.0), "J"); - quadtree.insert(Coordinate::new(4993.0, 4500.0), "K"); + //THEN + for case in cases.iter_mut() { + let mut quadtree: Quadtree<&str> = Quadtree::new( + Boundary::new(Coordinate::new(0.0, 0.0), Coordinate::new(5000.0, 5000.0)), + 4, + ); - //WHEN - let query_point = Coordinate::new(4999.0, 4950.0); - let interests = quadtree.find_nearest_neighbors(&query_point, 8); + for (longitude, latitude, interest) in &case.points { + quadtree.insert(Coordinate::new(*longitude, *latitude), interest); + } - //THEN - assert_eq!(interests.len(), 8); - let mut expected = vec!["I", "J", "K", "G", "H", "E", "D", "F"]; - for interest in interests { - assert!(expected.contains(interest)); - let pos = expected.iter().position(|x| x == interest); - expected.remove(pos.unwrap()); + let mut interests = quadtree.find_nearest_neighbors( + &Coordinate::new(case.query_point.0, case.query_point.1), + case.k, + ); + + interests.sort(); + case.expected.sort(); + + assert_eq!(interests.len(), case.expected.len()); + for (interest, expected) in interests.iter().zip(case.expected.iter()) { + assert_eq!(*interest, expected); + } } - assert!(expected.is_empty()); } #[test] From 3227e9c10a5c15fce80a44a5dc17155b8c12fa1e Mon Sep 17 00:00:00 2001 From: JWSong Date: Sat, 18 May 2024 22:28:35 +0900 Subject: [PATCH 6/6] chore: ignore clippy --- src/lib.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lib.rs b/src/lib.rs index bd643b6..01c2ead 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,6 +35,7 @@ impl PartialEq for NonNanFloat { impl PartialOrd for NonNanFloat { #[inline] + #[allow(clippy::all)] fn partial_cmp(&self, other: &Self) -> Option { self.0.partial_cmp(&other.0) }