Skip to content

Commit 962135c

Browse files
committed
Allow client to be set in test env
1 parent 5c18ee1 commit 962135c

File tree

3 files changed

+49
-7
lines changed

3 files changed

+49
-7
lines changed

temporalio/activity.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ class _Context:
148148
temporalio.converter.PayloadConverter,
149149
]
150150
runtime_metric_meter: Optional[temporalio.common.MetricMeter]
151-
client: Client
151+
client: Optional[Client]
152152
_logger_details: Optional[Mapping[str, Any]] = None
153153
_payload_converter: Optional[temporalio.converter.PayloadConverter] = None
154154
_metric_meter: Optional[temporalio.common.MetricMeter] = None
@@ -249,7 +249,12 @@ def client() -> Client:
249249
Raises:
250250
RuntimeError: When not in an activity.
251251
"""
252-
return _Context.current().client
252+
client = _Context.current().client
253+
if not client:
254+
raise RuntimeError(
255+
"No client available. In tests you can pass a client when creating ActivityEnvironment."
256+
)
257+
return client
253258

254259

255260
def in_activity() -> bool:

temporalio/testing/_activity.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import temporalio.converter
1717
import temporalio.exceptions
1818
import temporalio.worker._activity
19+
from temporalio.client import Client
1920

2021
_Params = ParamSpec("_Params")
2122
_Return = TypeVar("_Return")
@@ -62,7 +63,7 @@ class ActivityEnvironment:
6263
take effect. Default is noop.
6364
"""
6465

65-
def __init__(self) -> None:
66+
def __init__(self, client: Optional[Client] = None) -> None:
6667
"""Create an ActivityEnvironment for running activity code."""
6768
self.info = _default_info
6869
self.on_heartbeat: Callable[..., None] = lambda *args: None
@@ -73,6 +74,7 @@ def __init__(self) -> None:
7374
self._cancelled = False
7475
self._worker_shutdown = False
7576
self._activities: Set[_Activity] = set()
77+
self._client = client
7678

7779
def cancel(self) -> None:
7880
"""Cancel the activity.
@@ -113,14 +115,15 @@ def run(
113115
The callable's result.
114116
"""
115117
# Create an activity and run it
116-
return _Activity(self, fn).run(*args, **kwargs)
118+
return _Activity(self, fn, self._client).run(*args, **kwargs)
117119

118120

119121
class _Activity:
120122
def __init__(
121123
self,
122124
env: ActivityEnvironment,
123125
fn: Callable,
126+
client: Optional[Client],
124127
) -> None:
125128
self.env = env
126129
self.fn = fn
@@ -148,11 +151,14 @@ def __init__(
148151
thread_event=threading.Event(),
149152
async_event=asyncio.Event() if self.is_async else None,
150153
),
151-
shield_thread_cancel_exception=None
152-
if not self.cancel_thread_raiser
153-
else self.cancel_thread_raiser.shielded,
154+
shield_thread_cancel_exception=(
155+
None
156+
if not self.cancel_thread_raiser
157+
else self.cancel_thread_raiser.shielded
158+
),
154159
payload_converter_class_or_instance=env.payload_converter,
155160
runtime_metric_meter=env.metric_meter,
161+
client=client,
156162
)
157163
self.task: Optional[asyncio.Task] = None
158164

tests/testing/test_activity.py

+31
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@
22
import threading
33
import time
44
from contextvars import copy_context
5+
from unittest.mock import Mock
6+
7+
import pytest
58

69
from temporalio import activity
10+
from temporalio.client import Client
711
from temporalio.exceptions import CancelledError
812
from temporalio.testing import ActivityEnvironment
913

@@ -110,3 +114,30 @@ async def assert_equals(a: str, b: str) -> None:
110114

111115
assert type(expected_err) == type(actual_err)
112116
assert str(expected_err) == str(actual_err)
117+
118+
119+
async def test_activity_env_without_client():
120+
saw_error: bool = False
121+
122+
def my_activity() -> None:
123+
with pytest.raises(RuntimeError):
124+
activity.client()
125+
nonlocal saw_error
126+
saw_error = True
127+
128+
env = ActivityEnvironment()
129+
env.run(my_activity)
130+
assert saw_error
131+
132+
133+
async def test_activity_env_with_client():
134+
got_client: bool = False
135+
136+
def my_activity() -> None:
137+
nonlocal got_client
138+
if activity.client():
139+
got_client = True
140+
141+
env = ActivityEnvironment(client=Mock(spec=Client))
142+
env.run(my_activity)
143+
assert got_client

0 commit comments

Comments
 (0)