From e584822bac172b95ac208c5a2a7fb75c50e63bce Mon Sep 17 00:00:00 2001 From: Junqiu Lei Date: Thu, 16 May 2024 13:02:30 -0700 Subject: [PATCH] Fix flaky test in Faiss JNI range search (#1705) Signed-off-by: Junqiu Lei --- jni/tests/faiss_wrapper_test.cpp | 33 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index e9316dcc2..4cd3b319e 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -25,6 +25,9 @@ using ::testing::Return; float randomDataMin = -500.0; float randomDataMax = 500.0; +float rangeSearchRandomDataMin = -50; +float rangeSearchRandomDataMax = 50; +float rangeSearchRadius = 20000; TEST(FaissCreateIndexTest, BasicAssertions) { // Define the data @@ -621,13 +624,12 @@ TEST(FaissRangeSearchQueryIndexTest, BasicAssertions) { faiss::idx_t numIds = 200; int dim = 2; std::vector ids = test_util::Range(numIds); - std::vector vectors = test_util::RandomVectors(dim, numIds, randomDataMin, randomDataMax); + std::vector vectors = test_util::RandomVectors(dim, numIds, rangeSearchRandomDataMin, rangeSearchRandomDataMax); faiss::MetricType metricType = faiss::METRIC_L2; std::string method = "HNSW32,Flat"; // Define query data - float radius = 100000.0; int numQueries = 100; std::vector> queries; @@ -635,7 +637,7 @@ TEST(FaissRangeSearchQueryIndexTest, BasicAssertions) { std::vector query; query.reserve(dim); for (int j = 0; j < dim; j++) { - query.push_back(test_util::RandomFloat(randomDataMin, randomDataMax)); + query.push_back(test_util::RandomFloat(rangeSearchRandomDataMin, rangeSearchRandomDataMax)); } queries.push_back(query); } @@ -659,7 +661,7 @@ TEST(FaissRangeSearchQueryIndexTest, BasicAssertions) { knn_jni::faiss_wrapper::RangeSearch( &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), - reinterpret_cast(&query), radius, maxResultWindow, nullptr))); + reinterpret_cast(&query), rangeSearchRadius, maxResultWindow, nullptr))); // assert result size is not 0 ASSERT_NE(0, results->size()); @@ -677,13 +679,12 @@ TEST(FaissRangeSearchQueryIndexTest_WhenHitMaxWindowResult, BasicAssertions){ faiss::idx_t numIds = 200; int dim = 2; std::vector ids = test_util::Range(numIds); - std::vector vectors = test_util::RandomVectors(dim, numIds, randomDataMin, randomDataMax); + std::vector vectors = test_util::RandomVectors(dim, numIds, rangeSearchRandomDataMin, rangeSearchRandomDataMax); faiss::MetricType metricType = faiss::METRIC_L2; std::string method = "HNSW32,Flat"; // Define query data - float radius = 100000.0; int numQueries = 100; std::vector> queries; @@ -691,7 +692,7 @@ TEST(FaissRangeSearchQueryIndexTest_WhenHitMaxWindowResult, BasicAssertions){ std::vector query; query.reserve(dim); for (int j = 0; j < dim; j++) { - query.push_back(test_util::RandomFloat(randomDataMin, randomDataMax)); + query.push_back(test_util::RandomFloat(rangeSearchRandomDataMin, rangeSearchRandomDataMax)); } queries.push_back(query); } @@ -715,7 +716,7 @@ TEST(FaissRangeSearchQueryIndexTest_WhenHitMaxWindowResult, BasicAssertions){ knn_jni::faiss_wrapper::RangeSearch( &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), - reinterpret_cast(&query), radius, maxResultWindow, nullptr))); + reinterpret_cast(&query), rangeSearchRadius, maxResultWindow, nullptr))); // assert result size is not 0 ASSERT_NE(0, results->size()); @@ -734,13 +735,12 @@ TEST(FaissRangeSearchQueryIndexTestWithFilterTest, BasicAssertions) { faiss::idx_t numIds = 200; int dim = 2; std::vector ids = test_util::Range(numIds); - std::vector vectors = test_util::RandomVectors(dim, numIds, randomDataMin, randomDataMax); + std::vector vectors = test_util::RandomVectors(dim, numIds, rangeSearchRandomDataMin, rangeSearchRandomDataMax); faiss::MetricType metricType = faiss::METRIC_L2; std::string method = "HNSW32,Flat"; // Define query data - float radius = 100000.0; int numQueries = 100; std::vector> queries; @@ -748,7 +748,7 @@ TEST(FaissRangeSearchQueryIndexTestWithFilterTest, BasicAssertions) { std::vector query; query.reserve(dim); for (int j = 0; j < dim; j++) { - query.push_back(test_util::RandomFloat(randomDataMin, randomDataMax)); + query.push_back(test_util::RandomFloat(rangeSearchRandomDataMin, rangeSearchRandomDataMax)); } queries.push_back(query); } @@ -767,7 +767,7 @@ TEST(FaissRangeSearchQueryIndexTestWithFilterTest, BasicAssertions) { std::vector bitmap(num_bits,0); std::vector filterIds; - for (int64_t i = 154; i < 163; i++) { + for (int64_t i = 1; i < 50; i++) { filterIds.push_back(i); test_util::setBitSet(i, bitmap.data(), bitmap.size()); } @@ -782,7 +782,7 @@ TEST(FaissRangeSearchQueryIndexTestWithFilterTest, BasicAssertions) { knn_jni::faiss_wrapper::RangeSearchWithFilter( &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), - reinterpret_cast(&query), radius, maxResultWindow, + reinterpret_cast(&query), rangeSearchRadius, maxResultWindow, reinterpret_cast(&bitmap), 0, nullptr))); // assert result size is not 0 @@ -814,7 +814,7 @@ TEST(FaissRangeSearchQueryIndexTestWithParentFilterTest, BasicAssertions) { } ids.push_back(i); for (int j = 0; j < dim; j++) { - vectors.push_back(test_util::RandomFloat(-500.0, 500.0)); + vectors.push_back(test_util::RandomFloat(rangeSearchRandomDataMin, rangeSearchRandomDataMax)); } } @@ -822,7 +822,6 @@ TEST(FaissRangeSearchQueryIndexTestWithParentFilterTest, BasicAssertions) { std::string method = "HNSW32,Flat"; // Define query data - float radius = 100000.0; int numQueries = 1; std::vector> queries; @@ -830,7 +829,7 @@ TEST(FaissRangeSearchQueryIndexTestWithParentFilterTest, BasicAssertions) { std::vector query; query.reserve(dim); for (int j = 0; j < dim; j++) { - query.push_back(test_util::RandomFloat(-500.0, 500.0)); + query.push_back(test_util::RandomFloat(rangeSearchRandomDataMin, rangeSearchRandomDataMax)); } queries.push_back(query); } @@ -858,7 +857,7 @@ TEST(FaissRangeSearchQueryIndexTestWithParentFilterTest, BasicAssertions) { knn_jni::faiss_wrapper::RangeSearchWithFilter( &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), - reinterpret_cast(&query), radius, maxResultWindow, nullptr, 0, + reinterpret_cast(&query), rangeSearchRadius, maxResultWindow, nullptr, 0, reinterpret_cast(&parentIds)))); // assert result size is not 0