diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index fea97564b..43091d100 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -2025,6 +2025,7 @@ async def run_child() -> Any: return handle except asyncio.CancelledError: apply_child_cancel_error() + raise async def _outbound_start_nexus_operation( self, input: StartNexusOperationInput[Any, OutputT] diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 71f48cc63..214e8d3bb 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -1216,6 +1216,89 @@ async def test_workflow_cancel_child_unstarted(_client: Client): raise NotImplementedError +@workflow.defn +class CancelDuringChildStartWorkflow: + @workflow.run + async def run(self) -> None: + await workflow.start_child_workflow("child-workflow", id="child-id") + await workflow.wait_condition(lambda: False) + + +async def test_workflow_cancelled_during_child_workflow_start(): + now = datetime.now(timezone.utc) + timestamp = Timestamp() + timestamp.FromDatetime(now) + workflow_id = f"workflow-{uuid.uuid4()}" + run_id = str(uuid.uuid4()) + instance = UnsandboxedWorkflowRunner().create_instance( + WorkflowInstanceDetails( + payload_converter_class=DefaultPayloadConverter, + failure_converter_class=DefaultFailureConverter, + interceptor_classes=[], + defn=workflow._Definition.must_from_class(CancelDuringChildStartWorkflow), + info=workflow.Info( + attempt=1, + continued_run_id=None, + cron_schedule=None, + execution_timeout=None, + first_execution_run_id=run_id, + headers={}, + namespace="default", + parent=None, + root=None, + priority=Priority.default, + raw_memo={}, + retry_policy=None, + run_id=run_id, + run_timeout=None, + search_attributes={}, + start_time=now, + task_queue="task-queue", + task_timeout=timedelta(seconds=10), + typed_search_attributes=TypedSearchAttributes([]), + workflow_id=workflow_id, + workflow_start_time=now, + workflow_type="CancelDuringChildStartWorkflow", + ), + randomness_seed=123, + extern_functions={}, + disable_eager_activity_execution=False, + worker_level_failure_exception_types=[], + last_completion_result=Payloads(), + last_failure=None, + ) + ) + + start_activation = WorkflowActivation(run_id=run_id, timestamp=timestamp) + initialize = start_activation.jobs.add().initialize_workflow + initialize.workflow_type = "CancelDuringChildStartWorkflow" + initialize.workflow_id = workflow_id + initialize.randomness_seed = 123 + start_completion = instance.activate(start_activation) + start_commands = start_completion.successful.commands + + assert not start_completion.HasField("failed") + assert len(start_commands) == 1 + assert start_commands[0].HasField("start_child_workflow_execution") + child_seq = start_commands[0].start_child_workflow_execution.seq + + cancel_activation = WorkflowActivation(run_id=run_id, timestamp=timestamp) + cancel_activation.jobs.add().cancel_workflow.SetInParent() + child_start = cancel_activation.jobs.add().resolve_child_workflow_execution_start + child_start.seq = child_seq + child_start.succeeded.run_id = str(uuid.uuid4()) + cancel_completion = instance.activate(cancel_activation) + cancel_command_variants = [ + command.WhichOneof("variant") + for command in cancel_completion.successful.commands + ] + + assert not cancel_completion.HasField("failed") + assert "cancel_child_workflow_execution" in cancel_command_variants + assert "cancel_workflow_execution" in cancel_command_variants + assert "start_timer" not in cancel_command_variants + + @workflow.defn class ReturnSignalWorkflow: def __init__(self) -> None: