Skip to content

Commit c0d31fd

Browse files
Send in nonblocking loop (#15)
* Replace blocking event loop with multiprocessing to spawn new process to run event loop if needed. * Add a couple of helper functions which handle the event loop sending telemetry data in the background. * Update docstrings * Add test which times sending a few requests to ensure that they are non blocking. * Update wheel version (security vulnerability) * Remove redundant test (didn't actually test anything or improve coverage)
1 parent 2de8315 commit c0d31fd

File tree

4 files changed

+131
-41
lines changed

4 files changed

+131
-41
lines changed

pyproject.toml

+3
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,6 @@ max-complexity = 18
131131

132132
[tool.ruff.isort]
133133
known-first-party = ["access_py_telemetry"]
134+
135+
[tool.pytest.ini_options]
136+
asyncio_default_fixture_loop_scope = "function"

requirements_dev.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
pip>=23.3
22
bump2version==0.5.11
3-
wheel==0.33.6
3+
wheel>=0.38.1
44
watchdog==0.9.0
55
tox==3.14.0
66
coverage==4.5.4

src/access_py_telemetry/api.py

+87-18
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,19 @@
1212
import asyncio
1313
import pydantic
1414
import yaml
15+
import multiprocessing
1516
from pathlib import Path
17+
from multiprocessing import Process
18+
19+
try:
20+
multiprocessing.set_start_method("spawn")
21+
except RuntimeError:
22+
"""
23+
Since this is a side effect on module import, we can't guarantee that it will
24+
always be called before the first Process is created, or that the start method
25+
isn't already set. In this case, we just ignore the error - the processes should
26+
still work fine.
27+
"""
1628
from .utils import ENDPOINTS, REGISTRIES
1729

1830
S = TypeVar("S", bound="SessionID")
@@ -84,11 +96,14 @@ def pop_fields(self) -> dict[str, list[str]]:
8496
return self._pop_fields
8597

8698
@pydantic.validate_call
87-
def remove_fields(self, service: str, fields: Iterable[str]) -> None:
99+
def remove_fields(self, service: str, fields: str | Iterable[str]) -> None:
88100
"""
89101
Set the fields to remove from the telemetry data for a given service. Useful for excluding default
90102
fields that are not needed for a particular telemetry call: eg, removing
91103
Session tracking if a CLI is being used.
104+
105+
Note: This does not use a set union, so you must specify all fields you want to remove in one call.
106+
# TODO: Maybe make this easier to use?
92107
"""
93108
if isinstance(fields, str):
94109
fields = [fields]
@@ -244,44 +259,98 @@ async def send_telemetry(endpoint: str, data: dict[str, Any]) -> None:
244259
return None
245260

246261

247-
def send_in_loop(endpoint: str, telemetry_data: dict[str, Any]) -> None:
262+
def send_in_loop(
263+
endpoint: str, telemetry_data: dict[str, Any], timeout: float | None = None
264+
) -> None:
248265
"""
249266
Wraps the send_telemetry function in an event loop. This function will:
250267
- Check if an event loop is already running
251-
- Create a new event loop if one is not running
252-
- Send the telemetry data
268+
- If an event loop is running, send the telemetry data in the background
269+
- If an event loop is not running, create a new event loop in a separate process
270+
and send the telemetry data in the background using that loop.
253271
254272
Parameters
255273
----------
256274
endpoint : str
257275
The URL to send the telemetry data to.
258276
telemetry_data : dict
259277
The telemetry data to send.
278+
timeout : float, optional
279+
The maximum time to wait for the coroutine to finish. If the coroutine takes
280+
longer than this time, a TimeoutError will be raised. If None, the coroutine
281+
will terminate after 60 seconds.
260282
261283
Returns
262284
-------
263285
None
264286
265-
Warnings
266-
--------
267-
RuntimeWarning
268-
If the event loop is not running, telemetry will block execution.
269287
"""
288+
timeout = timeout or 60
270289

271-
# Check if there's an existing event loop, otherwise create a new one
272290
try:
273291
loop = asyncio.get_running_loop()
274292
except RuntimeError:
275-
loop = asyncio.new_event_loop()
276-
asyncio.set_event_loop(loop)
277-
278-
if loop.is_running():
279-
loop.create_task(send_telemetry(endpoint, telemetry_data))
293+
_run_in_proc(endpoint, telemetry_data, timeout)
280294
else:
281-
# breakpoint()
282-
# loop.create_task(send_telemetry(telemetry_data))
283-
loop.run_until_complete(send_telemetry(endpoint, telemetry_data))
295+
loop.create_task(send_telemetry(endpoint, telemetry_data))
296+
return None
297+
298+
299+
def _run_event_loop(endpoint: str, telemetry_data: dict[str, Any]) -> None:
300+
"""
301+
Handles the creation and running of an event loop for sending telemetry data.
302+
This function is intended to be run in a separate process, and will:
303+
- Create a new event loop
304+
- Send the telemetry data
305+
- Run the event loop until the telemetry data is sent
306+
307+
Parameters
308+
----------
309+
endpoint : str
310+
The URL to send the telemetry data to.
311+
telemetry_data : dict
312+
The telemetry data to send.
313+
314+
Returns
315+
-------
316+
None
317+
"""
318+
loop = asyncio.new_event_loop()
319+
asyncio.set_event_loop(loop)
320+
loop.run_until_complete(send_telemetry(endpoint, telemetry_data))
321+
322+
323+
def _run_in_proc(endpoint: str, telemetry_data: dict[str, Any], timeout: float) -> None:
324+
"""
325+
Handles the creation and running of a separate process for sending telemetry data.
326+
This function will:
327+
- Create a new process and run the _run_event_loop function in that process
328+
- Wait for the process to finish
329+
- If the process takes longer than the specified timeout, terminate the process
330+
and raise a warning
331+
332+
Parameters
333+
----------
334+
endpoint : str
335+
The URL to send the telemetry data to.
336+
telemetry_data : dict
337+
The telemetry data to send.
338+
timeout : float
339+
The maximum time to wait for the process to finish.
340+
341+
Returns
342+
-------
343+
None
344+
345+
"""
346+
proc = Process(target=_run_event_loop, args=(endpoint, telemetry_data))
347+
proc.start()
348+
proc.join(timeout)
349+
if proc.is_alive():
350+
proc.terminate()
284351
warnings.warn(
285-
"Event loop not running, telemetry will block execution",
352+
f"Telemetry data not sent within {timeout} seconds",
286353
category=RuntimeWarning,
354+
stacklevel=2,
287355
)
356+
return None

tests/test_api.py

+40-22
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
"""Tests for `access_py_telemetry` package."""
55

66
import access_py_telemetry.api
7-
from access_py_telemetry.api import SessionID, ApiHandler
8-
import warnings
7+
from access_py_telemetry.api import SessionID, ApiHandler, send_in_loop
98
from pydantic import ValidationError
109
import pytest
1110

11+
import time
12+
1213

1314
@pytest.fixture
1415
def local_host():
@@ -164,13 +165,16 @@ def test_api_handler_remove_fields(api_handler):
164165

165166
assert api_handler._pop_fields == {"payu": ["session_id"]}
166167

168+
# Now remove the 'model' field from the payu record, as a string.
169+
api_handler.remove_fields("payu", "model")
170+
167171

168-
def test_api_handler_send_api_request_no_loop(local_host, api_handler):
172+
def test_api_handler_send_api_request(api_handler, capsys):
169173
"""
170-
Create and send an API request with telemetry data.
174+
Create and send an API request with telemetry data - just to make sure that
175+
the request is being sent correctly.
171176
"""
172-
173-
api_handler.server_url = local_host
177+
api_handler.server_url = "http://dud/host/endpoint"
174178

175179
# Pretend we only have catalog & payu services and then mock the initialisation
176180
# of the _extra_fields attribute
@@ -189,17 +193,15 @@ def test_api_handler_send_api_request_no_loop(local_host, api_handler):
189193
# Remove indeterminate fields
190194
api_handler.remove_fields("payu", ["session_id", "name"])
191195

192-
with pytest.warns(RuntimeWarning) as warnings_record:
193-
api_handler.send_api_request(
194-
service_name="payu",
195-
function_name="_test",
196-
args=[1, 2, 3],
197-
kwargs={"name": "test_username"},
198-
)
199-
200-
# This should contain two warnings - one for the failed request and one for the
201-
# event loop. Sometimes we get a third, which I need to find.
202-
assert len(warnings_record) >= 2
196+
# We should get a warning because we've used a dud url, but pytest doesn't
197+
# seem to capture subprocess warnings. I'm not sure there is really a good
198+
# way test for this.
199+
api_handler.send_api_request(
200+
service_name="payu",
201+
function_name="_test",
202+
args=[1, 2, 3],
203+
kwargs={"name": "test_username"},
204+
)
203205

204206
assert api_handler._last_record == {
205207
"function": "_test",
@@ -209,12 +211,28 @@ def test_api_handler_send_api_request_no_loop(local_host, api_handler):
209211
"random_number": 2,
210212
}
211213

212-
if len(warnings_record) == 3:
213-
# Just reraise all the warnings if we get an unexpected one so we can come
214-
# back and track it down
215214

216-
for warning in warnings_record:
217-
warnings.warn(warning.message, warning.category, stacklevel=2)
215+
def test_send_in_loop_is_bg():
216+
"""
217+
Send a request, but make sure that it runs in the background (ie. is non-blocking).
218+
219+
There will be some overhead associated with the processes startup and teardown,
220+
but we shouldn't be waiting for the requests to finish. Using a long timeout
221+
and only sending 3 requests should be enough to ensure that we're not accidentally
222+
testing the process startup/teardown time.
223+
"""
224+
start_time = time.time()
225+
226+
for _ in range(3):
227+
send_in_loop(endpoint="https://dud/endpoint", telemetry_data={}, timeout=3)
228+
229+
print("Requests sent")
230+
231+
end_time = time.time()
232+
233+
dt = end_time - start_time
234+
235+
assert dt < 4
218236

219237

220238
def test_api_handler_invalid_endpoint(api_handler):

0 commit comments

Comments
 (0)