Skip to content

Commit 8c596e8

Browse files
author
Thordata
committed
fix: monotonic wait_for_task and raise errors in get_task_status
1 parent f260a0d commit 8c596e8

File tree

3 files changed

+160
-20
lines changed

3 files changed

+160
-20
lines changed

src/thordata/async_client.py

Lines changed: 48 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -564,6 +564,11 @@ async def create_scraper_task_advanced(self, config: ScraperTaskConfig) -> str:
564564
async def get_task_status(self, task_id: str) -> str:
565565
"""
566566
Check async task status.
567+
568+
Raises:
569+
ThordataConfigError: If public credentials are missing.
570+
ThordataAPIError: If API returns a non-200 code in JSON payload.
571+
ThordataNetworkError: If network/HTTP request fails.
567572
"""
568573
self._require_public_credentials()
569574
session = self._get_session()
@@ -577,17 +582,50 @@ async def get_task_status(self, task_id: str) -> str:
577582
async with session.post(
578583
self._status_url, data=payload, headers=headers
579584
) as response:
585+
response.raise_for_status()
580586
data = await response.json()
581587

582-
if data.get("code") == 200 and data.get("data"):
583-
for item in data["data"]:
588+
if isinstance(data, dict):
589+
code = data.get("code")
590+
if code is not None and code != 200:
591+
msg = extract_error_message(data)
592+
raise_for_code(
593+
f"Task status API Error: {msg}",
594+
code=code,
595+
payload=data,
596+
)
597+
598+
items = data.get("data") or []
599+
for item in items:
584600
if str(item.get("task_id")) == str(task_id):
585601
return item.get("status", "unknown")
586602

587-
return "unknown"
603+
return "unknown"
604+
605+
raise ThordataNetworkError(
606+
f"Unexpected task status response type: {type(data).__name__}",
607+
original_error=None,
608+
)
588609

589-
except Exception as e:
590-
logger.error(f"Async status check failed: {e}")
610+
except asyncio.TimeoutError as e:
611+
raise ThordataTimeoutError(
612+
f"Async status check timed out: {e}", original_error=e
613+
)
614+
except aiohttp.ClientError as e:
615+
raise ThordataNetworkError(
616+
f"Async status check failed: {e}", original_error=e
617+
)
618+
619+
async def safe_get_task_status(self, task_id: str) -> str:
620+
"""
621+
Backward-compatible status check.
622+
623+
Returns:
624+
Status string, or "error" on any exception.
625+
"""
626+
try:
627+
return await self.get_task_status(task_id)
628+
except Exception:
591629
return "error"
592630

593631
async def get_task_result(self, task_id: str, file_type: str = "json") -> str:
@@ -632,9 +670,12 @@ async def wait_for_task(
632670
"""
633671
Wait for a task to complete.
634672
"""
635-
elapsed = 0.0
636673

637-
while elapsed < max_wait:
674+
import time
675+
676+
start = time.monotonic()
677+
678+
while (time.monotonic() - start) < max_wait:
638679
status = await self.get_task_status(task_id)
639680

640681
logger.debug(f"Task {task_id} status: {status}")
@@ -652,7 +693,6 @@ async def wait_for_task(
652693
return status
653694

654695
await asyncio.sleep(poll_interval)
655-
elapsed += poll_interval
656696

657697
raise TimeoutError(f"Task {task_id} did not complete within {max_wait} seconds")
658698

src/thordata/client.py

Lines changed: 41 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -697,11 +697,13 @@ def get_task_status(self, task_id: str) -> str:
697697
"""
698698
Check the status of an asynchronous scraping task.
699699
700-
Args:
701-
task_id: The task ID from create_scraper_task.
702-
703700
Returns:
704701
Status string (e.g., "running", "ready", "failed").
702+
703+
Raises:
704+
ThordataConfigError: If public credentials are missing.
705+
ThordataAPIError: If API returns a non-200 code in JSON payload.
706+
ThordataNetworkError: If network/HTTP request fails.
705707
"""
706708
self._require_public_credentials()
707709

@@ -718,18 +720,46 @@ def get_task_status(self, task_id: str) -> str:
718720
timeout=30,
719721
)
720722
response.raise_for_status()
721-
722723
data = response.json()
723724

724-
if data.get("code") == 200 and data.get("data"):
725-
for item in data["data"]:
725+
if isinstance(data, dict):
726+
code = data.get("code")
727+
if code is not None and code != 200:
728+
msg = extract_error_message(data)
729+
raise_for_code(
730+
f"Task status API Error: {msg}",
731+
code=code,
732+
payload=data,
733+
)
734+
735+
items = data.get("data") or []
736+
for item in items:
726737
if str(item.get("task_id")) == str(task_id):
727738
return item.get("status", "unknown")
728739

729-
return "unknown"
740+
return "unknown"
741+
742+
# Unexpected payload type
743+
raise ThordataNetworkError(
744+
f"Unexpected task status response type: {type(data).__name__}",
745+
original_error=None,
746+
)
747+
748+
except requests.Timeout as e:
749+
raise ThordataTimeoutError(f"Status check timed out: {e}", original_error=e)
750+
except requests.RequestException as e:
751+
raise ThordataNetworkError(f"Status check failed: {e}", original_error=e)
752+
753+
def safe_get_task_status(self, task_id: str) -> str:
754+
"""
755+
Backward-compatible status check.
730756
731-
except Exception as e:
732-
logger.error(f"Status check failed: {e}")
757+
Returns:
758+
Status string, or "error" on any exception.
759+
"""
760+
try:
761+
return self.get_task_status(task_id)
762+
except Exception:
733763
return "error"
734764

735765
def get_task_result(self, task_id: str, file_type: str = "json") -> str:
@@ -797,9 +827,9 @@ def wait_for_task(
797827
"""
798828
import time
799829

800-
elapsed = 0.0
830+
start = time.monotonic()
801831

802-
while elapsed < max_wait:
832+
while (time.monotonic() - start) < max_wait:
803833
status = self.get_task_status(task_id)
804834

805835
logger.debug(f"Task {task_id} status: {status}")
@@ -817,7 +847,6 @@ def wait_for_task(
817847
return status
818848

819849
time.sleep(poll_interval)
820-
elapsed += poll_interval
821850

822851
raise TimeoutError(f"Task {task_id} did not complete within {max_wait} seconds")
823852

tests/test_task_status_and_wait.py

Lines changed: 71 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,71 @@
1+
import pytest
2+
from pytest_httpserver import HTTPServer
3+
4+
from thordata import (
5+
AsyncThordataClient,
6+
ThordataAuthError,
7+
ThordataClient,
8+
)
9+
10+
11+
def test_wait_for_task_timeout_uses_monotonic(monkeypatch) -> None:
12+
client = ThordataClient(scraper_token="dummy", public_token="p", public_key="k")
13+
14+
# Always "running" so it must time out quickly.
15+
monkeypatch.setattr(client, "get_task_status", lambda task_id: "running")
16+
17+
with pytest.raises(TimeoutError):
18+
client.wait_for_task("t1", poll_interval=0.01, max_wait=0.05)
19+
20+
21+
@pytest.mark.asyncio
22+
async def test_async_wait_for_task_timeout_uses_monotonic(monkeypatch) -> None:
23+
async with AsyncThordataClient(
24+
scraper_token="dummy", public_token="p", public_key="k"
25+
) as client:
26+
27+
async def _always_running(task_id: str) -> str:
28+
return "running"
29+
30+
monkeypatch.setattr(client, "get_task_status", _always_running)
31+
32+
with pytest.raises(TimeoutError):
33+
await client.wait_for_task("t1", poll_interval=0.01, max_wait=0.05)
34+
35+
36+
def test_get_task_status_raises_on_non_200_code(httpserver: HTTPServer) -> None:
37+
httpserver.expect_request("/tasks-status", method="POST").respond_with_json(
38+
{"code": 401, "msg": "Unauthorized"}
39+
)
40+
41+
base_url = httpserver.url_for("/").rstrip("/").replace("localhost", "127.0.0.1")
42+
43+
client = ThordataClient(
44+
scraper_token="dummy",
45+
public_token="p",
46+
public_key="k",
47+
web_scraper_api_base_url=base_url,
48+
)
49+
50+
with pytest.raises(ThordataAuthError):
51+
client.get_task_status("t1")
52+
53+
54+
@pytest.mark.asyncio
55+
async def test_async_get_task_status_raises_on_non_200_code(
56+
httpserver: HTTPServer,
57+
) -> None:
58+
httpserver.expect_request("/tasks-status", method="POST").respond_with_json(
59+
{"code": 401, "msg": "Unauthorized"}
60+
)
61+
62+
base_url = httpserver.url_for("/").rstrip("/").replace("localhost", "127.0.0.1")
63+
64+
async with AsyncThordataClient(
65+
scraper_token="dummy",
66+
public_token="p",
67+
public_key="k",
68+
web_scraper_api_base_url=base_url,
69+
) as client:
70+
with pytest.raises(ThordataAuthError):
71+
await client.get_task_status("t1")

0 commit comments

Comments
 (0)