|
12 | 12 | import asyncio
|
13 | 13 | import pydantic
|
14 | 14 | import yaml
|
| 15 | +import multiprocessing |
15 | 16 | from pathlib import Path
|
| 17 | +from multiprocessing import Process |
| 18 | + |
try:
    multiprocessing.set_start_method("spawn")
except RuntimeError:
    # This runs as a side effect of importing the module, so we cannot guarantee
    # that it happens before the first Process is created, or that the start
    # method has not already been set elsewhere. In that case we deliberately
    # ignore the error - the processes should still work fine.
    pass
16 | 28 | from .utils import ENDPOINTS, REGISTRIES
|
17 | 29 |
|
18 | 30 | S = TypeVar("S", bound="SessionID")
|
@@ -84,11 +96,14 @@ def pop_fields(self) -> dict[str, list[str]]:
|
84 | 96 | return self._pop_fields
|
85 | 97 |
|
86 | 98 | @pydantic.validate_call
|
87 |
| - def remove_fields(self, service: str, fields: Iterable[str]) -> None: |
| 99 | + def remove_fields(self, service: str, fields: str | Iterable[str]) -> None: |
88 | 100 | """
|
89 | 101 | Set the fields to remove from the telemetry data for a given service. Useful for excluding default
|
90 | 102 | fields that are not needed for a particular telemetry call: eg, removing
|
91 | 103 | Session tracking if a CLI is being used.
|
| 104 | +
|
| 105 | + Note: This does not use a set union, so you must specify all fields you want to remove in one call. |
| 106 | + # TODO: Maybe make this easier to use? |
92 | 107 | """
|
93 | 108 | if isinstance(fields, str):
|
94 | 109 | fields = [fields]
|
@@ -244,44 +259,98 @@ async def send_telemetry(endpoint: str, data: dict[str, Any]) -> None:
|
244 | 259 | return None
|
245 | 260 |
|
246 | 261 |
|
# Strong references to in-flight telemetry tasks. The event loop only keeps weak
# references to its tasks, so a task nobody else references could be garbage
# collected before it finishes.
_BACKGROUND_TASKS: set[asyncio.Task[None]] = set()


def send_in_loop(
    endpoint: str, telemetry_data: dict[str, Any], timeout: float | None = None
) -> None:
    """
    Wraps the send_telemetry function in an event loop. This function will:
    - Check if an event loop is already running
    - If an event loop is running, schedule send_telemetry on it as a background
      task and return immediately
    - If an event loop is not running, hand the request to _run_in_proc, which
      sends the telemetry data from a separate process.

    Parameters
    ----------
    endpoint : str
        The URL to send the telemetry data to.
    telemetry_data : dict
        The telemetry data to send.
    timeout : float, optional
        Only used when no event loop is running: it is forwarded to _run_in_proc
        as the maximum time, in seconds, to wait for the sending process. If None,
        60 seconds is used. The timeout is not applied to a task scheduled on an
        already running loop.

    Returns
    -------
    None

    """
    # Compare against None explicitly so that an explicit 0 is not silently
    # replaced by the default.
    if timeout is None:
        timeout = 60

    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread.
        _run_in_proc(endpoint, telemetry_data, timeout)
        return None

    task = loop.create_task(send_telemetry(endpoint, telemetry_data))
    # Keep the task alive until it completes, then drop the reference.
    _BACKGROUND_TASKS.add(task)
    task.add_done_callback(_BACKGROUND_TASKS.discard)
    return None
| 297 | + |
| 298 | + |
def _run_event_loop(endpoint: str, telemetry_data: dict[str, Any]) -> None:
    """
    Handles the creation and running of an event loop for sending telemetry data.
    This function is intended to be run in a separate process, and will:
    - Create a new event loop
    - Run send_telemetry on it until the coroutine completes
    - Close the event loop again

    Parameters
    ----------
    endpoint : str
        The URL to send the telemetry data to.
    telemetry_data : dict
        The telemetry data to send.

    Returns
    -------
    None
    """
    # asyncio.run creates a fresh loop, runs the coroutine to completion and
    # always closes the loop afterwards, so the loop is not leaked.
    asyncio.run(send_telemetry(endpoint, telemetry_data))
| 321 | + |
| 322 | + |
def _run_in_proc(endpoint: str, telemetry_data: dict[str, Any], timeout: float) -> None:
    """
    Handles the creation and running of a separate process for sending telemetry data.
    This function will:
    - Create a new process and run the _run_event_loop function in that process
    - Wait (blocking the caller) for up to ``timeout`` seconds for the process to finish
    - If the process is still alive after that, terminate it, reap it, and
      emit a RuntimeWarning

    Parameters
    ----------
    endpoint : str
        The URL to send the telemetry data to.
    telemetry_data : dict
        The telemetry data to send.
    timeout : float
        The maximum time, in seconds, to wait for the process to finish.

    Returns
    -------
    None

    Warns
    -----
    RuntimeWarning
        If the process did not finish within ``timeout`` seconds.
    """
    proc = Process(target=_run_event_loop, args=(endpoint, telemetry_data))
    proc.start()
    proc.join(timeout)

    if not proc.is_alive():
        return None

    proc.terminate()
    # Join after terminating so the child is reaped instead of lingering as a
    # zombie. The wait is bounded; fall back to kill() if terminate() was not
    # enough to stop the process.
    proc.join(1)
    if proc.is_alive():
        proc.kill()
        proc.join()

    warnings.warn(
        f"Telemetry data not sent within {timeout} seconds",
        category=RuntimeWarning,
        stacklevel=2,
    )
    return None
0 commit comments