diff --git a/README.md b/README.md index 366358efe..f9598afcc 100644 --- a/README.md +++ b/README.md @@ -1257,6 +1257,7 @@ calls in the `temporalio.activity` package make use of it. Specifically: * `in_activity()` - Whether an activity context is present * `info()` - Returns the immutable info of the currently running activity +* `client()` - Returns the Temporal client used by this worker. Only available in `async def` activities. * `heartbeat(*details)` - Record a heartbeat * `is_cancelled()` - Whether a cancellation has been requested on this activity * `wait_for_cancelled()` - `async` call to wait for cancellation request diff --git a/temporalio/activity.py b/temporalio/activity.py index 4a0914bc2..35fae36b0 100644 --- a/temporalio/activity.py +++ b/temporalio/activity.py @@ -19,6 +19,7 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import ( + TYPE_CHECKING, Any, Callable, Iterator, @@ -42,6 +43,9 @@ from .types import CallableType +if TYPE_CHECKING: + from temporalio.client import Client + @overload def defn(fn: CallableType) -> CallableType: ... @@ -179,6 +183,7 @@ class _Context: temporalio.converter.PayloadConverter, ] runtime_metric_meter: Optional[temporalio.common.MetricMeter] + client: Optional[Client] cancellation_details: _ActivityCancellationDetailsHolder _logger_details: Optional[Mapping[str, Any]] = None _payload_converter: Optional[temporalio.converter.PayloadConverter] = None @@ -271,13 +276,32 @@ def wait_sync(self, timeout: Optional[float] = None) -> None: self.thread_event.wait(timeout) +def client() -> Client: + """Return a Temporal Client for use in the current activity. + + Returns: + :py:class:`temporalio.client.Client` for use in the current activity. + + Raises: + RuntimeError: When not in an activity. + """ + client = _Context.current().client + if not client: + raise RuntimeError( + "No client available. The client is only available in async " + "(i.e. `async def`) activities; not in sync (i.e. `def`) activities. " + "In tests you can pass a client when creating ActivityEnvironment." + ) + return client + + def in_activity() -> bool: """Whether the current code is inside an activity. Returns: True if in an activity, False otherwise. """ - return not _current_context.get(None) is None + return _current_context.get(None) is not None def info() -> Info: @@ -574,8 +598,10 @@ def _apply_to_callable( fn=fn, # iscoroutinefunction does not return true for async __call__ # TODO(cretz): Why can't MyPy handle this? - is_async=inspect.iscoroutinefunction(fn) - or inspect.iscoroutinefunction(fn.__call__), # type: ignore + is_async=( + inspect.iscoroutinefunction(fn) + or inspect.iscoroutinefunction(fn.__call__) # type: ignore + ), no_thread_cancel_exception=no_thread_cancel_exception, ), ) diff --git a/temporalio/testing/_activity.py b/temporalio/testing/_activity.py index 3694dfdc7..d1441b012 100644 --- a/temporalio/testing/_activity.py +++ b/temporalio/testing/_activity.py @@ -16,6 +16,7 @@ import temporalio.converter import temporalio.exceptions import temporalio.worker._activity +from temporalio.client import Client _Params = ParamSpec("_Params") _Return = TypeVar("_Return") @@ -63,7 +64,7 @@ class ActivityEnvironment: take effect. Default is noop. """ - def __init__(self) -> None: + def __init__(self, client: Optional[Client] = None) -> None: """Create an ActivityEnvironment for running activity code.""" self.info = _default_info self.on_heartbeat: Callable[..., None] = lambda *args: None @@ -74,6 +75,7 @@ def __init__(self) -> None: self._cancelled = False self._worker_shutdown = False self._activities: Set[_Activity] = set() + self._client = client self._cancellation_details = ( temporalio.activity._ActivityCancellationDetailsHolder() ) @@ -128,7 +130,7 @@ def run( The callable's result. """ # Create an activity and run it - return _Activity(self, fn).run(*args, **kwargs) + return _Activity(self, fn, self._client).run(*args, **kwargs) class _Activity: @@ -136,10 +138,13 @@ def __init__( self, env: ActivityEnvironment, fn: Callable, + client: Optional[Client], ) -> None: self.env = env self.fn = fn - self.is_async = inspect.iscoroutinefunction(fn) + self.is_async = inspect.iscoroutinefunction(fn) or inspect.iscoroutinefunction( + fn.__call__ # type: ignore + ) self.cancel_thread_raiser: Optional[ temporalio.worker._activity._ThreadExceptionRaiser ] = None @@ -163,11 +168,14 @@ def __init__( thread_event=threading.Event(), async_event=asyncio.Event() if self.is_async else None, ), - shield_thread_cancel_exception=None - if not self.cancel_thread_raiser - else self.cancel_thread_raiser.shielded, + shield_thread_cancel_exception=( + None + if not self.cancel_thread_raiser + else self.cancel_thread_raiser.shielded + ), payload_converter_class_or_instance=env.payload_converter, runtime_metric_meter=env.metric_meter, + client=client if self.is_async else None, cancellation_details=env._cancellation_details, ) self.task: Optional[asyncio.Task] = None diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 741dac510..61a072186 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -69,6 +69,7 @@ def __init__( data_converter: temporalio.converter.DataConverter, interceptors: Sequence[Interceptor], metric_meter: temporalio.common.MetricMeter, + client: temporalio.client.Client, encode_headers: bool, ) -> None: self._bridge_worker = bridge_worker @@ -86,6 +87,7 @@ def __init__( None ) self._seen_sync_activity = False + self._client = client # Validate and build activity dict self._activities: Dict[str, temporalio.activity._Definition] = {} @@ -569,11 +571,14 @@ async def _execute_activity( heartbeat=None, cancelled_event=running_activity.cancelled_event, worker_shutdown_event=self._worker_shutdown_event, - shield_thread_cancel_exception=None - if not running_activity.cancel_thread_raiser - else running_activity.cancel_thread_raiser.shielded, + shield_thread_cancel_exception=( + None + if not running_activity.cancel_thread_raiser + else running_activity.cancel_thread_raiser.shielded + ), payload_converter_class_or_instance=self._data_converter.payload_converter, runtime_metric_meter=None if sync_non_threaded else self._metric_meter, + client=self._client if not running_activity.sync else None, cancellation_details=running_activity.cancellation_details, ) ) @@ -679,7 +684,7 @@ def _raise_in_thread_if_pending_unlocked(self) -> None: class _ActivityInboundImpl(ActivityInboundInterceptor): - def __init__( + def __init__( # type: ignore[reportMissingSuperCall] self, worker: _ActivityWorker, running_activity: _RunningActivity ) -> None: # We are intentionally not calling the base class's __init__ here @@ -786,7 +791,7 @@ async def heartbeat_with_context(*details: Any) -> None: class _ActivityOutboundImpl(ActivityOutboundInterceptor): - def __init__(self, worker: _ActivityWorker, info: temporalio.activity.Info) -> None: + def __init__(self, worker: _ActivityWorker, info: temporalio.activity.Info) -> None: # type: ignore[reportMissingSuperCall] # We are intentionally not calling the base class's __init__ here self._worker = worker self._info = info @@ -838,11 +843,12 @@ def _execute_sync_activity( worker_shutdown_event=temporalio.activity._CompositeEvent( thread_event=worker_shutdown_event, async_event=None ), - shield_thread_cancel_exception=None - if not cancel_thread_raiser - else cancel_thread_raiser.shielded, + shield_thread_cancel_exception=( + None if not cancel_thread_raiser else cancel_thread_raiser.shielded + ), payload_converter_class_or_instance=payload_converter_class_or_instance, runtime_metric_meter=runtime_metric_meter, + client=None, cancellation_details=cancellation_details, ) ) diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index ca35b9a88..8520a1656 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -411,8 +411,10 @@ def __init__( data_converter=client_config["data_converter"], interceptors=interceptors, metric_meter=self._runtime.metric_meter, - encode_headers=client_config["header_codec_behavior"] - == HeaderCodecBehavior.CODEC, + client=client, + encode_headers=( + client_config["header_codec_behavior"] == HeaderCodecBehavior.CODEC + ), ) self._nexus_worker: Optional[_NexusWorker] = None if nexus_service_handlers: @@ -577,12 +579,12 @@ def config(self) -> WorkerConfig: @property def task_queue(self) -> str: """Task queue this worker is on.""" - return self._config["task_queue"] + return self._config["task_queue"] # type: ignore[reportTypedDictNotRequiredAccess] @property def client(self) -> temporalio.client.Client: """Client currently set on the worker.""" - return self._config["client"] + return self._config["client"] # type: ignore[reportTypedDictNotRequiredAccess] @client.setter def client(self, value: temporalio.client.Client) -> None: @@ -679,9 +681,9 @@ async def raise_on_shutdown(): ) if exception: logger.error("Worker failed, shutting down", exc_info=exception) - if self._config["on_fatal_error"]: + if self._config["on_fatal_error"]: # type: ignore[reportTypedDictNotRequiredAccess] try: - await self._config["on_fatal_error"](exception) + await self._config["on_fatal_error"](exception) # type: ignore[reportTypedDictNotRequiredAccess] except: logger.warning("Fatal error handler failed") @@ -692,7 +694,7 @@ async def raise_on_shutdown(): # Cancel the shutdown task (safe if already done) tasks[None].cancel() - graceful_timeout = self._config["graceful_shutdown_timeout"] + graceful_timeout = self._config["graceful_shutdown_timeout"] # type: ignore[reportTypedDictNotRequiredAccess] logger.info( f"Beginning worker shutdown, will wait {graceful_timeout} before cancelling activities" ) diff --git a/tests/testing/test_activity.py b/tests/testing/test_activity.py index ff281d722..2acf93639 100644 --- a/tests/testing/test_activity.py +++ b/tests/testing/test_activity.py @@ -2,8 +2,12 @@ import threading import time from contextvars import copy_context +from unittest.mock import Mock + +import pytest from temporalio import activity +from temporalio.client import Client from temporalio.exceptions import CancelledError from temporalio.testing import ActivityEnvironment @@ -122,3 +126,44 @@ async def assert_equals(a: str, b: str) -> None: assert type(expected_err) == type(actual_err) assert str(expected_err) == str(actual_err) + + +async def test_error_on_access_client_in_activity_environment_without_client(): + saw_error: bool = False + + async def my_activity() -> None: + with pytest.raises(RuntimeError, match="No client available"): + activity.client() + nonlocal saw_error + saw_error = True + + env = ActivityEnvironment() + await env.run(my_activity) + assert saw_error + + +async def test_access_client_in_activity_environment_with_client(): + got_client: bool = False + + async def my_activity() -> None: + nonlocal got_client + if activity.client(): + got_client = True + + env = ActivityEnvironment(client=Mock(spec=Client)) + await env.run(my_activity) + assert got_client + + +async def test_error_on_access_client_in_sync_activity_in_environment_with_client(): + saw_error: bool = False + + def my_activity() -> None: + with pytest.raises(RuntimeError, match="No client available"): + activity.client() + nonlocal saw_error + saw_error = True + + env = ActivityEnvironment(client=Mock(spec=Client)) + env.run(my_activity) + assert saw_error diff --git a/tests/worker/test_activity.py b/tests/worker/test_activity.py index 66380846a..20af57504 100644 --- a/tests/worker/test_activity.py +++ b/tests/worker/test_activity.py @@ -94,6 +94,50 @@ async def get_name(name: str) -> str: assert result.result == "Name: my custom activity name!" +async def test_client_available_in_async_activities( + client: Client, worker: ExternalWorker +): + with pytest.raises(RuntimeError) as err: + activity.client() + assert str(err.value) == "Not in activity context" + + captured_client: Optional[Client] = None + + @activity.defn + async def capture_client() -> None: + nonlocal captured_client + captured_client = activity.client() + + await _execute_workflow_with_activity(client, worker, capture_client) + assert captured_client is client + + +async def test_client_not_available_in_sync_activities( + client: Client, worker: ExternalWorker +): + saw_error = False + + @activity.defn + def some_activity() -> None: + with pytest.raises( + RuntimeError, match="The client is only available in async" + ) as err: + activity.client() + nonlocal saw_error + saw_error = True + + await _execute_workflow_with_activity( + client, + worker, + some_activity, + worker_config={ + "activity_executor": concurrent.futures.ThreadPoolExecutor(1), + "max_concurrent_activities": 1, + }, + ) + assert saw_error + + async def test_activity_info( client: Client, worker: ExternalWorker, env: WorkflowEnvironment ): @@ -612,7 +656,7 @@ async def some_activity(param1: SomeClass2, param2: str) -> str: result.result == "param1: , param2: " ) - assert activity_param1 == SomeClass2(foo="str1", bar=SomeClass1(foo=123)) + assert activity_param1 == SomeClass2(foo="str1", bar=SomeClass1(foo=123)) # type: ignore[reportUnboundVariable] # noqa async def test_activity_heartbeat_details(