Skip to content

Commit dbf913d

Browse files
jeremyfowersapsonawane
authored andcommitted
Allow tools to display percent progress in the monitor (onnx#258)
1 parent 2a9893d commit dbf913d

File tree

2 files changed

+49
-7
lines changed
  • examples/cli/plugins/example_tool/turnkeyml_plugin_example_tool
  • src/turnkeyml/tools

2 files changed

+49
-7
lines changed

examples/cli/plugins/example_tool/turnkeyml_plugin_example_tool/tool.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44
and build_model() to have any custom behavior.
55
66
In this the tool simply passes the build state to the next
7-
tool in the sequence (i.e., this example is a no-op).
7+
tool in the sequence (i.e., this example is a no-op). It also
8+
spends a few seconds updating the monitor's percent progress indicator.
89
910
After you install the plugin, you can tell `turnkey` to use this sequence with:
1011
1112
turnkey -i INPUT_SCRIPT export-pytorch example-plugin-tool
1213
"""
1314

1415
import argparse
16+
from time import sleep
1517
from turnkeyml.tools import Tool
1618
from turnkeyml.state import State
1719

@@ -39,4 +41,10 @@ def parser(add_help: bool = True) -> argparse.ArgumentParser:
3941
return parser
4042

4143
def run(self, state: State):
44+
self.set_percent_progress(0.0)
45+
total = 15 # seconds
46+
for i in range(total):
47+
sleep(1)
48+
percent_progress = (i + 1) / float(total) * 100
49+
self.set_percent_progress(percent_progress)
4250
return state

src/turnkeyml/tools/tool.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import textwrap as _textwrap
77
import re
88
from typing import Tuple, Dict
9-
from multiprocessing import Process
9+
from multiprocessing import Process, Queue
1010
import psutil
1111
import turnkeyml.common.printing as printing
1212
import turnkeyml.common.exceptions as exp
@@ -15,13 +15,25 @@
1515
from turnkeyml.state import State
1616

1717

18-
def _spinner(message):
18+
def _spinner(message, q: Queue):
19+
"""
20+
Displays a moving "..." indicator so that the user knows that the
21+
Tool is still working. Tools can optionally use a multiprocessing
22+
Queue to display the percent progress of the Tool.
23+
"""
24+
percent_complete = None
25+
1926
try:
2027
parent_process = psutil.Process(pid=os.getppid())
2128
while parent_process.status() == psutil.STATUS_RUNNING:
2229
for cursor in [" ", ". ", ".. ", "..."]:
2330
time.sleep(0.5)
24-
status = f" {message}{cursor}\r"
31+
if not q.empty():
32+
percent_complete = q.get()
33+
if percent_complete is not None:
34+
status = f" {message} ({percent_complete:.1f}%){cursor}\r"
35+
else:
36+
status = f" {message}{cursor}\r"
2537
sys.stdout.write(status)
2638
sys.stdout.flush()
2739
except psutil.NoSuchProcess:
@@ -146,17 +158,22 @@ def status_line(self, successful, verbosity):
146158
success_tick = "+"
147159
fail_tick = "x"
148160

161+
if self.percent_progress is None:
162+
progress_indicator = ""
163+
else:
164+
progress_indicator = f" ({self.percent_progress:.1f}%)"
165+
149166
if successful is None:
150167
# Initialize the message
151168
printing.logn(f" {self.monitor_message} ")
152169
elif successful:
153170
# Print success message
154171
printing.log(f" {success_tick} ", c=printing.Colors.OKGREEN)
155-
printing.logn(self.monitor_message + " ")
172+
printing.logn(self.monitor_message + progress_indicator + " ")
156173
else:
157174
# successful == False, print failure message
158175
printing.log(f" {fail_tick} ", c=printing.Colors.FAIL)
159-
printing.logn(self.monitor_message + " ")
176+
printing.logn(self.monitor_message + progress_indicator + " ")
160177

161178
def __init__(
162179
self,
@@ -169,6 +186,8 @@ def __init__(
169186
self.duration_key = f"{fs.Keys.TOOL_DURATION}:{self.__class__.unique_name}"
170187
self.monitor_message = monitor_message
171188
self.progress = None
189+
self.progress_queue = None
190+
self.percent_progress = None
172191
self.logfile_path = None
173192
# Tools can disable build.Logger, which captures all stdout and stderr from
174193
# the Tool, by setting enable_logger=False
@@ -192,6 +211,18 @@ def parser() -> argparse.ArgumentParser:
192211
line interface for this Tool.
193212
"""
194213

214+
def set_percent_progress(self, percent_progress: float):
215+
"""
216+
Update the progress monitor with a percent progress to let the user
217+
know how much progress the Tool has made.
218+
"""
219+
220+
if not isinstance(percent_progress, float):
221+
raise ValueError(f"Input argument must be a float, got {percent_progress}")
222+
223+
self.progress_queue.put(percent_progress)
224+
self.percent_progress = percent_progress
225+
195226
# pylint: disable=unused-argument
196227
def parse(self, state: State, args, known_only=True) -> argparse.Namespace:
197228
"""
@@ -250,7 +281,10 @@ def run_helper(
250281
)
251282

252283
if monitor:
253-
self.progress = Process(target=_spinner, args=[self.monitor_message])
284+
self.progress_queue = Queue()
285+
self.progress = Process(
286+
target=_spinner, args=(self.monitor_message, self.progress_queue)
287+
)
254288
self.progress.start()
255289

256290
try:

0 commit comments

Comments
 (0)