diff --git a/temporalio/activity.py b/temporalio/activity.py index 281cfcb8d..07b58dce8 100644 --- a/temporalio/activity.py +++ b/temporalio/activity.py @@ -36,6 +36,7 @@ import temporalio.common import temporalio.converter +from temporalio.client import Client from .types import CallableType @@ -147,6 +148,7 @@ class _Context: temporalio.converter.PayloadConverter, ] runtime_metric_meter: Optional[temporalio.common.MetricMeter] + client: Optional[Client] _logger_details: Optional[Mapping[str, Any]] = None _payload_converter: Optional[temporalio.converter.PayloadConverter] = None _metric_meter: Optional[temporalio.common.MetricMeter] = None @@ -238,13 +240,30 @@ def wait_sync(self, timeout: Optional[float] = None) -> None: self.thread_event.wait(timeout) +def client() -> Client: + """Return a Temporal Client for use in the current activity. + + Returns: + :py:class:`temporalio.client.Client` for use in the current activity. + + Raises: + RuntimeError: When not in an activity. + """ + client = _Context.current().client + if not client: + raise RuntimeError( + "No client available. In tests you can pass a client when creating ActivityEnvironment." + ) + return client + + def in_activity() -> bool: """Whether the current code is inside an activity. Returns: True if in an activity, False otherwise. """ - return not _current_context.get(None) is None + return _current_context.get(None) is not None def info() -> Info: diff --git a/temporalio/testing/_activity.py b/temporalio/testing/_activity.py index 35ebce6ee..4dcd545bb 100644 --- a/temporalio/testing/_activity.py +++ b/temporalio/testing/_activity.py @@ -16,6 +16,7 @@ import temporalio.converter import temporalio.exceptions import temporalio.worker._activity +from temporalio.client import Client _Params = ParamSpec("_Params") _Return = TypeVar("_Return") @@ -62,7 +63,7 @@ class ActivityEnvironment: take effect. Default is noop. """ - def __init__(self) -> None: + def __init__(self, client: Optional[Client] = None) -> None: """Create an ActivityEnvironment for running activity code.""" self.info = _default_info self.on_heartbeat: Callable[..., None] = lambda *args: None @@ -73,6 +74,7 @@ def __init__(self) -> None: self._cancelled = False self._worker_shutdown = False self._activities: Set[_Activity] = set() + self._client = client def cancel(self) -> None: """Cancel the activity. @@ -113,7 +115,7 @@ def run( The callable's result. """ # Create an activity and run it - return _Activity(self, fn).run(*args, **kwargs) + return _Activity(self, fn, self._client).run(*args, **kwargs) class _Activity: @@ -121,6 +123,7 @@ def __init__( self, env: ActivityEnvironment, fn: Callable, + client: Optional[Client], ) -> None: self.env = env self.fn = fn @@ -148,11 +151,14 @@ def __init__( thread_event=threading.Event(), async_event=asyncio.Event() if self.is_async else None, ), - shield_thread_cancel_exception=None - if not self.cancel_thread_raiser - else self.cancel_thread_raiser.shielded, + shield_thread_cancel_exception=( + None + if not self.cancel_thread_raiser + else self.cancel_thread_raiser.shielded + ), payload_converter_class_or_instance=env.payload_converter, runtime_metric_meter=env.metric_meter, + client=client, ) self.task: Optional[asyncio.Task] = None diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 9c4889c94..6bebe466c 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -69,6 +69,7 @@ def __init__( data_converter: temporalio.converter.DataConverter, interceptors: Sequence[Interceptor], metric_meter: temporalio.common.MetricMeter, + client: temporalio.client.Client, ) -> None: self._bridge_worker = bridge_worker self._task_queue = task_queue @@ -84,6 +85,7 @@ def __init__( None ) self._seen_sync_activity = False + self._client = client # Validate and build activity dict self._activities: Dict[str, temporalio.activity._Definition] = {} @@ -428,13 +430,16 @@ async def _run_activity( heartbeat=None, cancelled_event=running_activity.cancelled_event, worker_shutdown_event=self._worker_shutdown_event, - shield_thread_cancel_exception=None - if not running_activity.cancel_thread_raiser - else running_activity.cancel_thread_raiser.shielded, + shield_thread_cancel_exception=( + None + if not running_activity.cancel_thread_raiser + else running_activity.cancel_thread_raiser.shielded + ), payload_converter_class_or_instance=self._data_converter.payload_converter, - runtime_metric_meter=None - if sync_non_threaded - else self._metric_meter, + runtime_metric_meter=( + None if sync_non_threaded else self._metric_meter + ), + client=self._client, ) ) temporalio.activity.logger.debug("Starting activity") @@ -692,6 +697,7 @@ async def heartbeat_with_context(*details: Any) -> None: worker_shutdown_event.thread_event, payload_converter_class_or_instance, ctx.runtime_metric_meter, + ctx.client, input.fn, *input.args, ] @@ -739,6 +745,7 @@ def _execute_sync_activity( temporalio.converter.PayloadConverter, ], runtime_metric_meter: Optional[temporalio.common.MetricMeter], + client: temporalio.client.Client, fn: Callable[..., Any], *args: Any, ) -> Any: @@ -770,6 +777,7 @@ def _execute_sync_activity( else cancel_thread_raiser.shielded, payload_converter_class_or_instance=payload_converter_class_or_instance, runtime_metric_meter=runtime_metric_meter, + client=client, ) ) return fn(*args) diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index 4c34a950e..b6c8ad60a 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -29,7 +29,7 @@ from ._activity import SharedStateManager, _ActivityWorker from ._interceptor import Interceptor -from ._tuning import WorkerTuner, _to_bridge_slot_supplier +from ._tuning import WorkerTuner from ._workflow import _WorkflowWorker from ._workflow_instance import UnsandboxedWorkflowRunner, WorkflowRunner from .workflow_sandbox import SandboxedWorkflowRunner @@ -303,6 +303,7 @@ def __init__( data_converter=client_config["data_converter"], interceptors=interceptors, metric_meter=self._runtime.metric_meter, + client=client, ) self._workflow_worker: Optional[_WorkflowWorker] = None if workflows: diff --git a/tests/testing/test_activity.py b/tests/testing/test_activity.py index 29b66c772..cb6f588c4 100644 --- a/tests/testing/test_activity.py +++ b/tests/testing/test_activity.py @@ -2,8 +2,12 @@ import threading import time from contextvars import copy_context +from unittest.mock import Mock + +import pytest from temporalio import activity +from temporalio.client import Client from temporalio.exceptions import CancelledError from temporalio.testing import ActivityEnvironment @@ -110,3 +114,30 @@ async def assert_equals(a: str, b: str) -> None: assert type(expected_err) == type(actual_err) assert str(expected_err) == str(actual_err) + + +async def test_activity_env_without_client(): + saw_error: bool = False + + def my_activity() -> None: + with pytest.raises(RuntimeError): + activity.client() + nonlocal saw_error + saw_error = True + + env = ActivityEnvironment() + env.run(my_activity) + assert saw_error + + +async def test_activity_env_with_client(): + got_client: bool = False + + def my_activity() -> None: + nonlocal got_client + if activity.client(): + got_client = True + + env = ActivityEnvironment(client=Mock(spec=Client)) + env.run(my_activity) + assert got_client diff --git a/tests/worker/test_activity.py b/tests/worker/test_activity.py index b17a0650f..6cdbc708f 100644 --- a/tests/worker/test_activity.py +++ b/tests/worker/test_activity.py @@ -90,6 +90,22 @@ async def get_name(name: str) -> str: assert result.result == "Name: my custom activity name!" +async def test_activity_client(client: Client, worker: ExternalWorker): + with pytest.raises(RuntimeError) as err: + activity.client() + assert str(err.value) == "Not in activity context" + + captured_client: Optional[Client] = None + + @activity.defn + async def capture_client() -> None: + nonlocal captured_client + captured_client = activity.client() + + await _execute_workflow_with_activity(client, worker, capture_client) + assert captured_client is client + + async def test_activity_info( client: Client, worker: ExternalWorker, env: WorkflowEnvironment ):