diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 0a298a032..0b741a927 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -1361,9 +1361,9 @@ def process( if self.workflow_info_on_extra: extra.setdefault("temporal_workflow", {}).update(update_details) - kwargs["extra"] = {**extra, **(kwargs.get("extra") or {})} - if msg_extra: - msg = f"{msg} ({msg_extra})" + kwargs["extra"] = {**extra, **(kwargs.get("extra") or {})} + if msg_extra: + msg = f"{msg} ({msg_extra})" return (msg, kwargs) def isEnabledFor(self, level: int) -> bool: diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 665a5393e..86acca464 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -1948,8 +1948,29 @@ def find_log(self, starts_with: str) -> Optional[logging.LogRecord]: return None -async def test_workflow_logging(client: Client, env: WorkflowEnvironment): - workflow.logger.full_workflow_info_on_extra = True +@pytest.mark.parametrize( + "with_workflow_info", + [True, False], +) +async def test_workflow_logging( + client: Client, env: WorkflowEnvironment, with_workflow_info: bool +): + orig_on_message = workflow.logger.workflow_info_on_message + orig_on_extra = workflow.logger.workflow_info_on_extra + orig_full_on_extra = workflow.logger.full_workflow_info_on_extra + + try: + workflow.logger.workflow_info_on_message = with_workflow_info + workflow.logger.workflow_info_on_extra = with_workflow_info + workflow.logger.full_workflow_info_on_extra = with_workflow_info + await _do_workflow_logging_test(client, with_workflow_info) + finally: + workflow.logger.workflow_info_on_message = orig_on_message + workflow.logger.workflow_info_on_extra = orig_on_extra + workflow.logger.full_workflow_info_on_extra = orig_full_on_extra + + +async def _do_workflow_logging_test(client: Client, with_workflow_info: bool): with LogCapturer().logs_captured( workflow.logger.base_logger, activity.logger.base_logger ) as capturer: @@ -1976,30 +1997,42 @@ async def test_workflow_logging(client: Client, env: WorkflowEnvironment): assert "signal 2" == await handle.query(LoggingWorkflow.last_signal) # Confirm logs were produced - assert capturer.find_log("Signal: signal 1 ({'attempt':") + assert capturer.find_log("Signal: signal 1") assert capturer.find_log("Signal: signal 2") assert capturer.find_log("Update: update 1") assert capturer.find_log("Update: update 2") assert not capturer.find_log("Signal: signal 3") - # Also make sure it has some workflow info and correct funcName - record = capturer.find_log("Signal: signal 1") - assert ( - record - and record.__dict__["temporal_workflow"]["workflow_type"] - == "LoggingWorkflow" - and record.funcName == "my_signal" - ) - # Since we enabled full info, make sure it's there - assert isinstance(record.__dict__["workflow_info"], workflow.Info) - # Check the log emitted by the update execution. - record = capturer.find_log("Update: update 1") - assert ( - record - and record.__dict__["temporal_workflow"]["update_id"] == "update-1" - and record.__dict__["temporal_workflow"]["update_name"] == "my_update" - and "'update_id': 'update-1'" in record.message - and "'update_name': 'my_update'" in record.message - ) + + if with_workflow_info: + record = capturer.find_log("Signal: signal 1 ({'attempt':") + assert ( + record + and record.__dict__["temporal_workflow"]["workflow_type"] + == "LoggingWorkflow" + and record.funcName == "my_signal" + ) + # Since we enabled full info, make sure it's there + assert isinstance(record.__dict__["workflow_info"], workflow.Info) + + # Check the log emitted by the update execution. + record = capturer.find_log("Update: update 1") + assert ( + record + and record.__dict__["temporal_workflow"]["update_id"] == "update-1" + and record.__dict__["temporal_workflow"]["update_name"] == "my_update" + and "'update_id': 'update-1'" in record.message + and "'update_name': 'my_update'" in record.message + ) + else: + record = capturer.find_log("Signal: signal 1") + assert record and "temporal_workflow" not in record.__dict__ + assert record and "workflow_info" not in record.__dict__ + + record = capturer.find_log("Update: update 1") + assert record and "temporal_workflow" not in record.__dict__ + assert record and "workflow_info" not in record.__dict__ + assert "'update_id': 'update-1'" not in record.message + assert "'update_name': 'my_update'" not in record.message # Clear queue and start a new one with more signals capturer.log_queue.queue.clear()