Skip to content

Commit 0eb619c

Browse files
committed
Feat: add concurrency to findMatches methods
1 parent a3ed2ec commit 0eb619c

File tree

3 files changed

+88
-52
lines changed

3 files changed

+88
-52
lines changed

CHANGELOG.md

+5
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@ All notable changes to this project will be documented in this file.
33

44
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
55

6+
### [1.6.4] - 2023-01-19
7+
8+
### Changed
9+
- Improve speed by adding concurrency in the findMatches method, using half the number of cpu cores available.
10+
611
### [1.6.3] - 2021-11-24
712

813
### Changed

MTM/__init__.py

+82-51
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
"""Main code for Multi-Template-Matching (MTM)."""
2+
import os
3+
import warnings
4+
from concurrent.futures import ThreadPoolExecutor, as_completed
5+
26
import cv2
3-
import numpy as np
7+
import numpy as np
48
import pandas as pd
5-
import warnings
9+
from scipy.signal import find_peaks
610
from skimage.feature import peak_local_max
7-
from scipy.signal import find_peaks
8-
from .version import __version__
911

1012
from .NMS import NMS
13+
from .version import __version__
1114

1215
__all__ = ['NMS']
1316

@@ -33,7 +36,7 @@ def _findLocalMax_(corrMap, score_threshold=0.6):
3336
peaks = [[i,0] for i in peaks[0]]
3437

3538

36-
else: # Correlatin map is 2D
39+
else: # Correlation map is 2D
3740
peaks = peak_local_max(corrMap, threshold_abs=score_threshold, exclude_border=False).tolist()
3841

3942
return peaks
@@ -116,82 +119,110 @@ def findMatches(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=floa
116119
-------
117120
- Pandas DataFrame with 1 row per hit and column "TemplateName"(string), "BBox":(X, Y, Width, Height), "Score":float
118121
"""
119-
if N_object != float("inf") and type(N_object) != int:
122+
if N_object != float("inf") and not isinstance(N_object, int):
120123
raise TypeError("N_object must be an integer")
121124

122125
## Crop image to search region if provided
123126
if searchBox is not None:
124127
xOffset, yOffset, searchWidth, searchHeight = searchBox
125128
image = image[yOffset : yOffset+searchHeight, xOffset : xOffset+searchWidth]
126-
127129
else:
128130
xOffset=yOffset=0
129-
131+
130132
# Check that the template are all smaller are equal to the image (original, or cropped if there is a search region)
131133
for index, tempTuple in enumerate(listTemplates):
132-
134+
133135
if not isinstance(tempTuple, tuple) or len(tempTuple)==1:
134136
raise ValueError("listTemplates should be a list of tuples as ('name','array') or ('name', 'array', 'mask')")
135-
137+
136138
templateSmallerThanImage = all(templateDim <= imageDim for templateDim, imageDim in zip(tempTuple[1].shape, image.shape))
137-
139+
138140
if not templateSmallerThanImage :
139141
fitIn = "searchBox" if (searchBox is not None) else "image"
140142
raise ValueError("Template '{}' at index {} in the list of templates is larger than {}.".format(tempTuple[0], index, fitIn) )
141-
143+
142144
listHit = []
143-
for tempTuple in listTemplates:
145+
## Use multi-threading to iterate through all templates, using half the number of cpu cores available.
146+
with ThreadPoolExecutor(max_workers=round(os.cpu_count()*.5)) as executor:
147+
futures = [executor.submit(_multi_compute, tempTuple, image, method, N_object, score_threshold, xOffset, yOffset, listHit) for tempTuple in listTemplates]
148+
for future in as_completed(futures):
149+
_ = future.result()
144150

145-
templateName, template = tempTuple[:2]
146-
mask = None
151+
if listHit:
152+
return pd.DataFrame(listHit) # All possible hits before Non-Maxima Supression
153+
else:
154+
return pd.DataFrame(columns=["TemplateName", "BBox", "Score"])
147155

148-
if len(tempTuple)>=3: # ie a mask is also provided
149-
if method in (0,3):
150-
mask = tempTuple[2]
151-
else:
152-
warnings.warn("Template matching method not supporting the use of Mask. Use 0/TM_SQDIFF or 3/TM_CCORR_NORMED.")
153156

154-
#print('\nSearch with template : ',templateName)
155-
corrMap = computeScoreMap(template, image, method, mask=mask)
157+
def _multi_compute(tempTuple, image, method, N_object, score_threshold, xOffset, yOffset, listHit):
158+
"""
159+
Find all possible template locations satisfying the score threshold provided a template to search and an image.
160+
Add the hits in the list of hits.
161+
162+
Parameters
163+
----------
164+
- tempTuple : a tuple (LabelString, template, mask (optional))
165+
template to search in each image, associated to a label
166+
labelstring : string
167+
template : numpy array (grayscale or RGB)
168+
mask (optional): numpy array, should have the same dimensions and type than the template
156169
157-
## Find possible location of the object
158-
if N_object==1: # Detect global Min/Max
159-
minVal, maxVal, minLoc, maxLoc = cv2.minMaxLoc(corrMap)
170+
- image : Grayscale or RGB numpy array
171+
image in which to perform the search, it should be the same bitDepth and number of channels than the templates
160172
161-
if method in (0,1):
162-
peaks = [minLoc[::-1]] # opposite sorting than in the multiple detection
173+
- method : int
174+
one of OpenCV template matching method (0 to 5), default 5=0-mean cross-correlation
163175
164-
else:
165-
peaks = [maxLoc[::-1]]
176+
- N_object: int or float("inf")
177+
expected number of objects in the image, default to infinity if unknown
166178
179+
- score_threshold: float in range [0,1]
180+
if N_object>1, returns local minima/maxima respectively below/above the score_threshold
167181
168-
else:# Detect local max or min
169-
if method in (0,1): # Difference => look for local minima
170-
peaks = _findLocalMin_(corrMap, score_threshold)
182+
- xOffset : int
183+
optional the x offset if the search area is provided
171184
172-
else:
173-
peaks = _findLocalMax_(corrMap, score_threshold)
185+
- yOffset : int
186+
optional the y offset if the search area is provided
174187
188+
- listHit : the list of hits which we want to add the discovered hit
189+
expected array of hits
190+
"""
191+
templateName, template = tempTuple[:2]
192+
mask = None
175193

176-
#print('Initially found',len(peaks),'hit with this template')
194+
if len(tempTuple)>=3: # ie a mask is also provided
195+
if method in (0,3):
196+
mask = tempTuple[2]
197+
else:
198+
warnings.warn("Template matching method not supporting the use of Mask. Use 0/TM_SQDIFF or 3/TM_CCORR_NORMED.")
177199

200+
#print('\nSearch with template : ',templateName)
201+
corrMap = computeScoreMap(template, image, method, mask=mask)
178202

179-
# Once every peak was detected for this given template
180-
## Create a dictionnary for each hit with {'TemplateName':, 'BBox': (x,y,Width, Height), 'Score':coeff}
203+
## Find possible location of the object
204+
if N_object==1: # Detect global Min/Max
205+
_, _, minLoc, maxLoc = cv2.minMaxLoc(corrMap)
206+
if method in (0,1):
207+
peaks = [minLoc[::-1]] # opposite sorting than in the multiple detection
208+
else:
209+
peaks = [maxLoc[::-1]]
210+
else:# Detect local max or min
211+
if method in (0,1): # Difference => look for local minima
212+
peaks = _findLocalMin_(corrMap, score_threshold)
213+
else:
214+
peaks = _findLocalMax_(corrMap, score_threshold)
181215

182-
height, width = template.shape[0:2] # slicing make sure it works for RGB too
216+
#print('Initially found',len(peaks),'hit with this template')
183217

184-
for peak in peaks :
185-
coeff = corrMap[tuple(peak)]
186-
newHit = {'TemplateName':templateName, 'BBox': ( int(peak[1])+xOffset, int(peak[0])+yOffset, width, height ) , 'Score':coeff}
218+
# Once every peak was detected for this given template
219+
## Create a dictionnary for each hit with {'TemplateName':, 'BBox': (x,y,Width, Height), 'Score':coeff}
187220

188-
# append to list of potential hit before Non maxima suppression
189-
listHit.append(newHit)
221+
height, width = template.shape[0:2] # slicing make sure it works for RGB too
190222

191-
if listHit:
192-
return pd.DataFrame(listHit) # All possible hits before Non-Maxima Supression
193-
else:
194-
return pd.DataFrame(columns=["TemplateName", "BBox", "Score"]) # empty df with correct column header
223+
for peak in peaks :
224+
# append to list of potential hit before Non maxima suppression
225+
listHit.append({'TemplateName':templateName, 'BBox': ( int(peak[1])+xOffset, int(peak[0])+yOffset, width, height ) , 'Score':corrMap[tuple(peak)]}) # empty df with correct column header
195226

196227

197228
def matchTemplates(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=float("inf"), score_threshold=0.5, maxOverlap=0.25, searchBox=None):
@@ -239,7 +270,7 @@ def matchTemplates(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=f
239270
tableHit = findMatches(listTemplates, image, method, N_object, score_threshold, searchBox)
240271

241272
if method == 0: raise ValueError("The method TM_SQDIFF is not supported. Use TM_SQDIFF_NORMED instead.")
242-
sortAscending = True if method==1 else False
273+
sortAscending = (method==1)
243274

244275
return NMS(tableHit, score_threshold, sortAscending, N_object, maxOverlap)
245276

@@ -275,7 +306,7 @@ def drawBoxesOnRGB(image, tableHit, boxThickness=2, boxColor=(255, 255, 00), sho
275306
if image.ndim == 2: outImage = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) # convert to RGB to be able to show detections as color box on grayscale image
276307
else: outImage = image.copy()
277308

278-
for index, row in tableHit.iterrows():
309+
for _, row in tableHit.iterrows():
279310
x,y,w,h = row['BBox']
280311
cv2.rectangle(outImage, (x, y), (x+w, y+h), color=boxColor, thickness=boxThickness)
281312
if showLabel: cv2.putText(outImage, text=row['TemplateName'], org=(x, y), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=labelScale, color=labelColor, lineType=cv2.LINE_AA)
@@ -315,9 +346,9 @@ def drawBoxesOnGray(image, tableHit, boxThickness=2, boxColor=255, showLabel=Fal
315346
if image.ndim == 3: outImage = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) # convert to RGB to be able to show detections as color box on grayscale image
316347
else: outImage = image.copy()
317348

318-
for index, row in tableHit.iterrows():
349+
for _, row in tableHit.iterrows():
319350
x,y,w,h = row['BBox']
320351
cv2.rectangle(outImage, (x, y), (x+w, y+h), color=boxColor, thickness=boxThickness)
321352
if showLabel: cv2.putText(outImage, text=row['TemplateName'], org=(x, y), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=labelScale, color=labelColor, lineType=cv2.LINE_AA)
322353

323-
return outImage
354+
return outImage

MTM/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
# 1) we don't load dependencies by storing it in __init__.py
33
# 2) we can import it in setup.py for the same reason
44
# 3) we can import it into your module module
5-
__version__ = '1.6.3'
5+
__version__ = '1.6.4'

0 commit comments

Comments
 (0)