Skip to content

Commit 2792abf

Browse files
Add ChangeCounter example! 💰 💰
1 parent d132ae4 commit 2792abf

File tree

1 file changed

+199
-0
lines changed

1 file changed

+199
-0
lines changed

examples/ChangeCounter.py

+199
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
######## Count Change Using Object Detection #########
2+
#
3+
# Author: Evan Juras, EJ Technology Consultants (www.ejtech.io)
4+
# Date: 10/29/22
5+
#
6+
# Description:
7+
# This program uses a TFLite coin detection model to locate and identify coins in
8+
# a live camera feed. It calculates the total value of the coins in the camera's view.
9+
# (Works on US currency, but can be modified to work with coins from other countries!)
10+
11+
# Import packages
12+
import os
13+
import argparse
14+
import cv2
15+
import numpy as np
16+
import sys
17+
import time
18+
from threading import Thread
19+
import importlib.util
20+
21+
### User-defined variables
22+
23+
# Model info
24+
MODEL_NAME = 'change_counter'
25+
GRAPH_NAME = 'detect.tflite'
26+
LABELMAP_NAME = 'labelmap.txt'
27+
use_TPU = False
28+
29+
# Program settings
30+
min_conf_threshold = 0.50
31+
resW, resH = 1280, 720 # Resolution to run camera at
32+
imW, imH = resW, resH
33+
34+
### Set up model parameters
35+
36+
# Import TensorFlow libraries
37+
# If tflite_runtime is installed, import interpreter from tflite_runtime, else import from regular tensorflow
38+
# If using Coral Edge TPU, import the load_delegate library
39+
pkg = importlib.util.find_spec('tflite_runtime')
40+
if pkg:
41+
from tflite_runtime.interpreter import Interpreter
42+
if use_TPU:
43+
from tflite_runtime.interpreter import load_delegate
44+
else:
45+
from tensorflow.lite.python.interpreter import Interpreter
46+
if use_TPU:
47+
from tensorflow.lite.python.interpreter import load_delegate
48+
49+
# If using Edge TPU, assign filename for Edge TPU model
50+
if use_TPU:
51+
# If user has specified the name of the .tflite file, use that name, otherwise use default 'edgetpu.tflite'
52+
if (GRAPH_NAME == 'detect.tflite'):
53+
GRAPH_NAME = 'edgetpu.tflite'
54+
55+
# Get path to current working directory
56+
CWD_PATH = os.getcwd()
57+
58+
# Path to .tflite file, which contains the model that is used for object detection
59+
PATH_TO_CKPT = os.path.join(CWD_PATH,MODEL_NAME,GRAPH_NAME)
60+
61+
# Path to label map file
62+
PATH_TO_LABELS = os.path.join(CWD_PATH,MODEL_NAME,LABELMAP_NAME)
63+
64+
# Load the label map
65+
with open(PATH_TO_LABELS, 'r') as f:
66+
labels = [line.strip() for line in f.readlines()]
67+
68+
### Load Tensorflow Lite model
69+
# If using Edge TPU, use special load_delegate argument
70+
if use_TPU:
71+
interpreter = Interpreter(model_path=PATH_TO_CKPT,
72+
experimental_delegates=[load_delegate('libedgetpu.so.1.0')])
73+
else:
74+
interpreter = Interpreter(model_path=PATH_TO_CKPT)
75+
76+
interpreter.allocate_tensors()
77+
78+
# Get model details
79+
input_details = interpreter.get_input_details()
80+
output_details = interpreter.get_output_details()
81+
height = input_details[0]['shape'][1]
82+
width = input_details[0]['shape'][2]
83+
84+
floating_model = (input_details[0]['dtype'] == np.float32)
85+
86+
input_mean = 127.5
87+
input_std = 127.5
88+
89+
# Check output layer name to determine if this model was created with TF2 or TF1,
90+
# because outputs are ordered differently for TF2 and TF1 models
91+
outname = output_details[0]['name']
92+
93+
if ('StatefulPartitionedCall' in outname): # This is a TF2 model
94+
boxes_idx, classes_idx, scores_idx = 1, 3, 0
95+
else: # This is a TF1 model
96+
boxes_idx, classes_idx, scores_idx = 0, 1, 2
97+
98+
# Initialize camera
99+
cap = cv2.VideoCapture(0)
100+
ret = cap.set(3, resW)
101+
ret = cap.set(4, resH)
102+
103+
# Initialize frame rate calculation
104+
frame_rate_calc = 1
105+
freq = cv2.getTickFrequency()
106+
107+
### Continuously process frames from camera
108+
while True:
109+
110+
# Start timer (for calculating frame rate)
111+
t1 = cv2.getTickCount()
112+
113+
# Reset coin value count for this frame
114+
total_coin_value = 0
115+
116+
# Grab frame from camera
117+
hasFrame, frame1 = cap.read()
118+
119+
# Acquire frame and resize to input shape expected by model [1xHxWx3]
120+
frame = frame1.copy()
121+
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
122+
frame_resized = cv2.resize(frame_rgb, (width, height))
123+
input_data = np.expand_dims(frame_resized, axis=0)
124+
125+
# Normalize pixel values if using a floating model (i.e. if model is non-quantized)
126+
if floating_model:
127+
input_data = (np.float32(input_data) - input_mean) / input_std
128+
129+
# Perform detection by running the model with the image as input
130+
interpreter.set_tensor(input_details[0]['index'],input_data)
131+
interpreter.invoke()
132+
133+
# Retrieve detection results
134+
boxes = interpreter.get_tensor(output_details[boxes_idx]['index'])[0] # Bounding box coordinates of detected objects
135+
classes = interpreter.get_tensor(output_details[classes_idx]['index'])[0] # Class index of detected objects
136+
scores = interpreter.get_tensor(output_details[scores_idx]['index'])[0] # Confidence of detected objects
137+
138+
# Loop over all detections and process each detection if its confidence is above minimum threshold
139+
for i in range(len(scores)):
140+
if ((scores[i] > min_conf_threshold) and (scores[i] <= 1.0)):
141+
142+
# Get bounding box coordinates
143+
# Interpreter can return coordinates that are outside of image dimensions, need to force them to be within image using max() and min()
144+
ymin = int(max(1,(boxes[i][0] * imH)))
145+
xmin = int(max(1,(boxes[i][1] * imW)))
146+
ymax = int(min(imH,(boxes[i][2] * imH)))
147+
xmax = int(min(imW,(boxes[i][3] * imW)))
148+
149+
# Draw bounding box
150+
cv2.rectangle(frame, (xmin,ymin), (xmax,ymax), (10, 255, 0), 2)
151+
152+
# Get object's name and draw label
153+
object_name = labels[int(classes[i])] # Look up object name from "labels" array using class index
154+
label = '%s: %d%%' % (object_name, int(scores[i]*100)) # Example: 'quarter: 72%'
155+
labelSize, baseLine = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2) # Get font size
156+
label_ymin = max(ymin, labelSize[1] + 10) # Make sure not to draw label too close to top of window
157+
cv2.rectangle(frame, (xmin, label_ymin-labelSize[1]-10), (xmin+labelSize[0], label_ymin+baseLine-10), (255, 255, 255), cv2.FILLED) # Draw white box to put label text in
158+
cv2.putText(frame, label, (xmin, label_ymin-7), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2) # Draw label text
159+
160+
# Assign the value of this coin based on the class name of the detected object
161+
# (There are more efficient ways to do this, but this shows an example of how to trigger an action when a certain class is detected)
162+
if object_name == 'penny':
163+
this_coin_value = 0.01
164+
elif object_name == 'nickel':
165+
this_coin_value = 0.05
166+
elif object_name == 'dime':
167+
this_coin_value = 0.10
168+
elif object_name == 'quarter':
169+
this_coin_value = 0.25
170+
171+
# Add this coin's value to the running total
172+
total_coin_value = total_coin_value + this_coin_value
173+
174+
175+
# Now that we've gone through every detection, we know the total value of all coins in the frame. Let's display it in the corner of the frame.
176+
cv2.putText(frame,'Total change:',(20,80),cv2.FONT_HERSHEY_PLAIN,2,(0,0,0),4,cv2.LINE_AA)
177+
cv2.putText(frame,'Total change:',(20,80),cv2.FONT_HERSHEY_PLAIN,2,(230,230,230),2,cv2.LINE_AA)
178+
cv2.putText(frame,'$%.2f' % total_coin_value,(260,85),cv2.FONT_HERSHEY_PLAIN,2.5,(0,0,0),4,cv2.LINE_AA)
179+
cv2.putText(frame,'$%.2f' % total_coin_value,(260,85),cv2.FONT_HERSHEY_PLAIN,2.5,(85,195,105),2,cv2.LINE_AA)
180+
181+
# Draw framerate in corner of frame
182+
cv2.putText(frame,'FPS: %.2f' % frame_rate_calc,(20,50),cv2.FONT_HERSHEY_PLAIN,2,(0,0,0),4,cv2.LINE_AA)
183+
cv2.putText(frame,'FPS: %.2f' % frame_rate_calc,(20,50),cv2.FONT_HERSHEY_PLAIN,2,(230,230,230),2,cv2.LINE_AA)
184+
185+
# All the results have been drawn on the frame, so it's time to display it.
186+
cv2.imshow('Object detector', frame)
187+
188+
# Calculate framerate
189+
t2 = cv2.getTickCount()
190+
time1 = (t2-t1)/freq
191+
frame_rate_calc= 1/time1
192+
193+
# Press 'q' to quit
194+
if cv2.waitKey(1) == ord('q'):
195+
break
196+
197+
# Clean up
198+
cv2.destroyAllWindows()
199+
cap.release()

0 commit comments

Comments
 (0)