diff --git a/faiss_vector_cache.go b/faiss_vector_cache.go
index aeba58a..e86550e 100644
--- a/faiss_vector_cache.go
+++ b/faiss_vector_cache.go
@@ -57,28 +57,29 @@ func (vc *vectorIndexCache) Clear() {
 // map. It's false otherwise.
 func (vc *vectorIndexCache) loadOrCreate(fieldID uint16, mem []byte,
 	loadDocVecIDMap bool, except *roaring.Bitmap) (
-	index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, docVecIDMap map[uint32][]uint32,
-	vecIDsToExclude []int64, err error) {
-	index, vecDocIDMap, docVecIDMap, vecIDsToExclude, err = vc.loadFromCache(
+	index *faiss.IndexImpl, clusterAssignment map[int64]*roaring.Bitmap,
+	vecDocIDMap map[int64]uint32, docVecIDMap map[uint32][]uint32, vecIDsToExclude []int64, err error) {
+	index, clusterAssignment, vecDocIDMap, docVecIDMap, vecIDsToExclude, err = vc.loadFromCache(
 		fieldID, loadDocVecIDMap, mem, except)
-	return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, err
+	return index, clusterAssignment, vecDocIDMap, docVecIDMap, vecIDsToExclude, err
 }
 
 // function to load the vectorDocIDMap and if required, docVecIDMap from cache
 // If not, it will create these and add them to the cache.
 func (vc *vectorIndexCache) loadFromCache(fieldID uint16, loadDocVecIDMap bool,
-	mem []byte, except *roaring.Bitmap) (index *faiss.IndexImpl, vecDocIDMap map[int64]uint32,
-	docVecIDMap map[uint32][]uint32, vecIDsToExclude []int64, err error) {
+	mem []byte, except *roaring.Bitmap) (index *faiss.IndexImpl, clusterAssignment map[int64]*roaring.Bitmap,
+	vecDocIDMap map[int64]uint32, docVecIDMap map[uint32][]uint32,
+	vecIDsToExclude []int64, err error) {
 	vc.m.RLock()
 
 	entry, ok := vc.cache[fieldID]
 	if ok {
-		index, vecDocIDMap, docVecIDMap = entry.load()
+		index, clusterAssignment, vecDocIDMap, docVecIDMap = entry.load()
 		vecIDsToExclude = getVecIDsToExclude(vecDocIDMap, except)
 		if !loadDocVecIDMap || (loadDocVecIDMap && len(entry.docVecIDMap) > 0) {
 			vc.m.RUnlock()
-			return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
+			return index, clusterAssignment, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
 		}
 		vc.m.RUnlock()
@@ -88,7 +89,7 @@ func (vc *vectorIndexCache) loadFromCache(fieldID uint16, loadDocVecIDMap bool,
 		// typically seen for the first filtered query.
 		docVecIDMap = vc.addDocVecIDMapToCacheLOCKED(entry)
 		vc.m.Unlock()
-		return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
+		return index, clusterAssignment, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
 	}
 	vc.m.RUnlock()
@@ -117,20 +118,21 @@ func (vc *vectorIndexCache) addDocVecIDMapToCacheLOCKED(ce *cacheEntry) map[uint
 // Rebuilding the cache on a miss.
 func (vc *vectorIndexCache) createAndCacheLOCKED(fieldID uint16, mem []byte,
 	loadDocVecIDMap bool, except *roaring.Bitmap) (
-	index *faiss.IndexImpl, vecDocIDMap map[int64]uint32,
-	docVecIDMap map[uint32][]uint32, vecIDsToExclude []int64, err error) {
+	index *faiss.IndexImpl, centroidVecIDMap map[int64]*roaring.Bitmap,
+	vecDocIDMap map[int64]uint32, docVecIDMap map[uint32][]uint32,
+	vecIDsToExclude []int64, err error) {
 
 	// Handle concurrent accesses (to avoid unnecessary work) by adding a
 	// check within the write lock here.
 	entry := vc.cache[fieldID]
 	if entry != nil {
-		index, vecDocIDMap, docVecIDMap = entry.load()
+		index, centroidVecIDMap, vecDocIDMap, docVecIDMap = entry.load()
 		vecIDsToExclude = getVecIDsToExclude(vecDocIDMap, except)
 		if !loadDocVecIDMap || (loadDocVecIDMap && len(entry.docVecIDMap) > 0) {
-			return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
+			return index, centroidVecIDMap, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
 		}
 		docVecIDMap = vc.addDocVecIDMapToCacheLOCKED(entry)
-		return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
+		return index, centroidVecIDMap, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
 	}
 
 	// if the cache doesn't have the entry, construct the vector to doc id map and
@@ -166,11 +168,24 @@ func (vc *vectorIndexCache) createAndCacheLOCKED(fieldID uint16, mem []byte,
 	index, err = faiss.ReadIndexFromBuffer(mem[pos:pos+int(indexSize)],
 		faissIOFlags)
 	if err != nil {
-		return nil, nil, nil, nil, err
+		return nil, nil, nil, nil, nil, err
+	}
+
+	clusterAssignment, _ := index.ObtainClusterToVecIDsFromIVFIndex()
+	centroidVecIDMap = make(map[int64]*roaring.Bitmap)
+	for centroidID, vecIDs := range clusterAssignment {
+		if _, exists := centroidVecIDMap[centroidID]; !exists {
+			centroidVecIDMap[centroidID] = roaring.NewBitmap()
+		}
+		vecIDsUint32 := make([]uint32, len(vecIDs))
+		for i, vecID := range vecIDs {
+			vecIDsUint32[i] = uint32(vecID)
+		}
+		centroidVecIDMap[centroidID].AddMany(vecIDsUint32)
 	}
 
 	vc.insertLOCKED(fieldID, index, vecDocIDMap, loadDocVecIDMap, docVecIDMap)
-	return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
+	return index, centroidVecIDMap, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
 }
 
 func (vc *vectorIndexCache) insertLOCKED(fieldIDPlus1 uint16,
@@ -308,9 +323,10 @@ type cacheEntry struct {
 	// threshold we close/cleanup only if the live refs to the cache entry is 0.
 	refs int64
 
-	index       *faiss.IndexImpl
-	vecDocIDMap map[int64]uint32
-	docVecIDMap map[uint32][]uint32
+	index             *faiss.IndexImpl
+	vecDocIDMap       map[int64]uint32
+	docVecIDMap       map[uint32][]uint32
+	clusterAssignment map[int64]*roaring.Bitmap
 }
 
 func (ce *cacheEntry) incHit() {
@@ -325,10 +341,11 @@ func (ce *cacheEntry) decRef() {
 	atomic.AddInt64(&ce.refs, -1)
 }
 
-func (ce *cacheEntry) load() (*faiss.IndexImpl, map[int64]uint32, map[uint32][]uint32) {
+func (ce *cacheEntry) load() (*faiss.IndexImpl, map[int64]*roaring.Bitmap,
+	map[int64]uint32, map[uint32][]uint32) {
 	ce.incHit()
 	ce.addRef()
-	return ce.index, ce.vecDocIDMap, ce.docVecIDMap
+	return ce.index, ce.clusterAssignment, ce.vecDocIDMap, ce.docVecIDMap
 }
 
 func (ce *cacheEntry) close() {
diff --git a/faiss_vector_posting.go b/faiss_vector_posting.go
index 9c6aa1b..4920ff4 100644
--- a/faiss_vector_posting.go
+++ b/faiss_vector_posting.go
@@ -309,6 +309,7 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, requiresFiltering bool
 	segment.VectorIndex, error) {
 	// Params needed for the closures
 	var vecIndex *faiss.IndexImpl
+	var centroidVecIDMap map[int64]*roaring.Bitmap
 	var vecDocIDMap map[int64]uint32
 	var docVecIDMap map[uint32][]uint32
 	var vectorIDsToExclude []int64
@@ -409,10 +410,7 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, requiresFiltering bool
 				return rv, nil
 			}
 
-			// Retrieve the mapping of centroid IDs to vectors within
-			// the cluster.
-			clusterAssignment, _ := vecIndex.ObtainClusterToVecIDsFromIVFIndex()
-			if len(clusterAssignment) == 0 {
+			if len(centroidVecIDMap) == 0 {
 				// Accounting for a flat index
 				scores, ids, err := vecIndex.SearchWithIDs(qVector, k, vectorIDsToInclude, params)
 				if err != nil {
@@ -422,21 +420,6 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, requiresFiltering bool
 				return rv, nil
 			}
 
-			// TODO: WHY NOT CACHE THIS?
-			// Converting to roaring bitmap for ease of intersect ops with
-			// the set of eligible doc IDs.
-			centroidVecIDMap := make(map[int64]*roaring.Bitmap)
-			for centroidID, vecIDs := range clusterAssignment {
-				if _, exists := centroidVecIDMap[centroidID]; !exists {
-					centroidVecIDMap[centroidID] = roaring.NewBitmap()
-				}
-				vecIDsUint32 := make([]uint32, len(vecIDs))
-				for i, vecID := range vecIDs {
-					vecIDsUint32[i] = uint32(vecID)
-				}
-				centroidVecIDMap[centroidID].AddMany(vecIDsUint32)
-			}
-
 			// Determining which clusters, identified by centroid ID,
 			// have at least one eligible vector and hence, ought to be
 			// probed.
@@ -584,7 +567,7 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, requiresFiltering bool
 		pos += n
 	}
 
-	vecIndex, vecDocIDMap, docVecIDMap, vectorIDsToExclude, err =
+	vecIndex, centroidVecIDMap, vecDocIDMap, docVecIDMap, vectorIDsToExclude, err =
 		sb.vecIndexCache.loadOrCreate(fieldIDPlus1, sb.mem[pos:],
 			requiresFiltering, except)
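
Note on the change: this diff resolves the removed `// TODO: WHY NOT CACHE THIS?` by moving the centroid-to-vector-ID bitmap conversion out of the query path in `faiss_vector_posting.go` and into `createAndCacheLOCKED`, so it runs once per cache entry instead of on every filtered query. For reference, the conversion can be sketched standalone as below. This is a minimal sketch, not the PR's code: it assumes `clusterAssignment` has the `map[int64][]int64` shape implied by the diff's use of `ObtainClusterToVecIDsFromIVFIndex`, and `toCentroidVecIDMap` is a hypothetical helper name. Because the loop ranges over map keys, each `centroidID` is visited exactly once, which is why the sketch can drop the `exists` guard the diff carries over.

```go
package main

import (
	"fmt"

	"github.com/RoaringBitmap/roaring"
)

// toCentroidVecIDMap converts an IVF cluster assignment (centroid ID ->
// vector IDs) into roaring bitmaps, so that later intersections with the
// set of eligible vector IDs are cheap.
func toCentroidVecIDMap(clusterAssignment map[int64][]int64) map[int64]*roaring.Bitmap {
	rv := make(map[int64]*roaring.Bitmap, len(clusterAssignment))
	for centroidID, vecIDs := range clusterAssignment {
		// Narrowing int64 -> uint32 mirrors the diff, which assumes
		// per-segment vector IDs fit in 32 bits.
		vecIDsUint32 := make([]uint32, len(vecIDs))
		for i, vecID := range vecIDs {
			vecIDsUint32[i] = uint32(vecID)
		}
		bm := roaring.NewBitmap()
		bm.AddMany(vecIDsUint32)
		rv[centroidID] = bm
	}
	return rv
}

func main() {
	m := toCentroidVecIDMap(map[int64][]int64{7: {1, 5, 9}, 8: {2, 3}})
	fmt.Println(m[7].Contains(5), m[8].GetCardinality()) // true 2
}
```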
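On the consumer side, the cached map feeds the step described by the surviving comment "Determining which clusters, identified by centroid ID, have at least one eligible vector and hence, ought to be probed." That selection logic is outside the hunks shown, so the sketch below is only an illustration of what the bitmap representation buys: `eligibleCentroids` and `eligibleVecIDs` are hypothetical names, not identifiers from this PR.

```go
// eligibleCentroids returns the centroid IDs whose clusters contain at
// least one eligible vector, i.e. the clusters worth probing. Bitmap
// intersection makes this check cheap relative to scanning ID slices.
func eligibleCentroids(centroidVecIDMap map[int64]*roaring.Bitmap,
	eligibleVecIDs *roaring.Bitmap) []int64 {
	rv := make([]int64, 0, len(centroidVecIDMap))
	for centroidID, vecIDs := range centroidVecIDMap {
		if eligibleVecIDs.Intersects(vecIDs) {
			rv = append(rv, centroidID)
		}
	}
	return rv
}
```

Storing bitmaps rather than the raw `[]int64` slices also means the per-query work is a set intersection instead of a rebuild-then-intersect, which is the point of caching the converted form.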