Skip to content

Provide client in activity context #740

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion temporalio/activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

import temporalio.common
import temporalio.converter
from temporalio.client import Client

from .types import CallableType

Expand Down Expand Up @@ -147,6 +148,7 @@ class _Context:
temporalio.converter.PayloadConverter,
]
runtime_metric_meter: Optional[temporalio.common.MetricMeter]
client: Optional[Client]
_logger_details: Optional[Mapping[str, Any]] = None
_payload_converter: Optional[temporalio.converter.PayloadConverter] = None
_metric_meter: Optional[temporalio.common.MetricMeter] = None
Expand Down Expand Up @@ -238,13 +240,30 @@ def wait_sync(self, timeout: Optional[float] = None) -> None:
self.thread_event.wait(timeout)


def client() -> Client:
"""Return a Temporal Client for use in the current activity.

Returns:
:py:class:`temporalio.client.Client` for use in the current activity.

Raises:
RuntimeError: When not in an activity.
"""
client = _Context.current().client
if not client:
raise RuntimeError(
"No client available. In tests you can pass a client when creating ActivityEnvironment."
)
return client


def in_activity() -> bool:
"""Whether the current code is inside an activity.

Returns:
True if in an activity, False otherwise.
"""
return not _current_context.get(None) is None
return _current_context.get(None) is not None


def info() -> Info:
Expand Down
16 changes: 11 additions & 5 deletions temporalio/testing/_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import temporalio.converter
import temporalio.exceptions
import temporalio.worker._activity
from temporalio.client import Client

_Params = ParamSpec("_Params")
_Return = TypeVar("_Return")
Expand Down Expand Up @@ -62,7 +63,7 @@ class ActivityEnvironment:
take effect. Default is noop.
"""

def __init__(self) -> None:
def __init__(self, client: Optional[Client] = None) -> None:
"""Create an ActivityEnvironment for running activity code."""
self.info = _default_info
self.on_heartbeat: Callable[..., None] = lambda *args: None
Expand All @@ -73,6 +74,7 @@ def __init__(self) -> None:
self._cancelled = False
self._worker_shutdown = False
self._activities: Set[_Activity] = set()
self._client = client

def cancel(self) -> None:
"""Cancel the activity.
Expand Down Expand Up @@ -113,14 +115,15 @@ def run(
The callable's result.
"""
# Create an activity and run it
return _Activity(self, fn).run(*args, **kwargs)
return _Activity(self, fn, self._client).run(*args, **kwargs)


class _Activity:
def __init__(
self,
env: ActivityEnvironment,
fn: Callable,
client: Optional[Client],
) -> None:
self.env = env
self.fn = fn
Expand Down Expand Up @@ -148,11 +151,14 @@ def __init__(
thread_event=threading.Event(),
async_event=asyncio.Event() if self.is_async else None,
),
shield_thread_cancel_exception=None
if not self.cancel_thread_raiser
else self.cancel_thread_raiser.shielded,
shield_thread_cancel_exception=(
None
if not self.cancel_thread_raiser
else self.cancel_thread_raiser.shielded
),
payload_converter_class_or_instance=env.payload_converter,
runtime_metric_meter=env.metric_meter,
client=client,
)
self.task: Optional[asyncio.Task] = None

Expand Down
20 changes: 14 additions & 6 deletions temporalio/worker/_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(
data_converter: temporalio.converter.DataConverter,
interceptors: Sequence[Interceptor],
metric_meter: temporalio.common.MetricMeter,
client: temporalio.client.Client,
) -> None:
self._bridge_worker = bridge_worker
self._task_queue = task_queue
Expand All @@ -84,6 +85,7 @@ def __init__(
None
)
self._seen_sync_activity = False
self._client = client

# Validate and build activity dict
self._activities: Dict[str, temporalio.activity._Definition] = {}
Expand Down Expand Up @@ -428,13 +430,16 @@ async def _run_activity(
heartbeat=None,
cancelled_event=running_activity.cancelled_event,
worker_shutdown_event=self._worker_shutdown_event,
shield_thread_cancel_exception=None
if not running_activity.cancel_thread_raiser
else running_activity.cancel_thread_raiser.shielded,
shield_thread_cancel_exception=(
None
if not running_activity.cancel_thread_raiser
else running_activity.cancel_thread_raiser.shielded
),
payload_converter_class_or_instance=self._data_converter.payload_converter,
runtime_metric_meter=None
if sync_non_threaded
else self._metric_meter,
runtime_metric_meter=(
None if sync_non_threaded else self._metric_meter
),
client=self._client,
)
)
temporalio.activity.logger.debug("Starting activity")
Expand Down Expand Up @@ -692,6 +697,7 @@ async def heartbeat_with_context(*details: Any) -> None:
worker_shutdown_event.thread_event,
payload_converter_class_or_instance,
ctx.runtime_metric_meter,
ctx.client,
input.fn,
*input.args,
]
Expand Down Expand Up @@ -739,6 +745,7 @@ def _execute_sync_activity(
temporalio.converter.PayloadConverter,
],
runtime_metric_meter: Optional[temporalio.common.MetricMeter],
client: temporalio.client.Client,
fn: Callable[..., Any],
*args: Any,
) -> Any:
Expand Down Expand Up @@ -770,6 +777,7 @@ def _execute_sync_activity(
else cancel_thread_raiser.shielded,
payload_converter_class_or_instance=payload_converter_class_or_instance,
runtime_metric_meter=runtime_metric_meter,
client=client,
)
)
return fn(*args)
Expand Down
3 changes: 2 additions & 1 deletion temporalio/worker/_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from ._activity import SharedStateManager, _ActivityWorker
from ._interceptor import Interceptor
from ._tuning import WorkerTuner, _to_bridge_slot_supplier
from ._tuning import WorkerTuner
from ._workflow import _WorkflowWorker
from ._workflow_instance import UnsandboxedWorkflowRunner, WorkflowRunner
from .workflow_sandbox import SandboxedWorkflowRunner
Expand Down Expand Up @@ -303,6 +303,7 @@ def __init__(
data_converter=client_config["data_converter"],
interceptors=interceptors,
metric_meter=self._runtime.metric_meter,
client=client,
)
self._workflow_worker: Optional[_WorkflowWorker] = None
if workflows:
Expand Down
31 changes: 31 additions & 0 deletions tests/testing/test_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@
import threading
import time
from contextvars import copy_context
from unittest.mock import Mock

import pytest

from temporalio import activity
from temporalio.client import Client
from temporalio.exceptions import CancelledError
from temporalio.testing import ActivityEnvironment

Expand Down Expand Up @@ -110,3 +114,30 @@ async def assert_equals(a: str, b: str) -> None:

assert type(expected_err) == type(actual_err)
assert str(expected_err) == str(actual_err)


async def test_activity_env_without_client():
saw_error: bool = False

def my_activity() -> None:
with pytest.raises(RuntimeError):
activity.client()
nonlocal saw_error
saw_error = True

env = ActivityEnvironment()
env.run(my_activity)
assert saw_error


async def test_activity_env_with_client():
got_client: bool = False

def my_activity() -> None:
nonlocal got_client
if activity.client():
got_client = True

env = ActivityEnvironment(client=Mock(spec=Client))
env.run(my_activity)
assert got_client
16 changes: 16 additions & 0 deletions tests/worker/test_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,22 @@ async def get_name(name: str) -> str:
assert result.result == "Name: my custom activity name!"


async def test_activity_client(client: Client, worker: ExternalWorker):
with pytest.raises(RuntimeError) as err:
activity.client()
assert str(err.value) == "Not in activity context"

captured_client: Optional[Client] = None

@activity.defn
async def capture_client() -> None:
nonlocal captured_client
captured_client = activity.client()

await _execute_workflow_with_activity(client, worker, capture_client)
assert captured_client is client


async def test_activity_info(
client: Client, worker: ExternalWorker, env: WorkflowEnvironment
):
Expand Down