-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Expand file tree
/
Copy pathcosmos_workflow_checkpointing.py
More file actions
201 lines (159 loc) · 6.64 KB
/
cosmos_workflow_checkpointing.py
File metadata and controls
201 lines (159 loc) · 6.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
# Copyright (c) Microsoft. All rights reserved.
# ruff: noqa: T201
"""Sample: Workflow Checkpointing with Cosmos DB NoSQL.
Purpose:
This sample shows how to use Azure Cosmos DB NoSQL as a persistent checkpoint
storage backend for workflows, enabling durable pause-and-resume across
process restarts.
What you learn:
- How to configure CosmosCheckpointStorage for workflow checkpointing
- How to run a workflow that automatically persists checkpoints to Cosmos DB
- How to resume a workflow from a Cosmos DB checkpoint
- How to list and inspect available checkpoints
Prerequisites:
- An Azure Cosmos DB account (or local emulator)
- Environment variables set (see below)
Environment variables:
AZURE_COSMOS_ENDPOINT - Cosmos DB account endpoint
AZURE_COSMOS_DATABASE_NAME - Database name
AZURE_COSMOS_CONTAINER_NAME - Container name for checkpoints
Optional:
AZURE_COSMOS_KEY - Account key (if not using Azure credentials)
"""
import asyncio
import os
import sys
from dataclasses import dataclass
from typing import Any
from agent_framework import (
Executor,
WorkflowBuilder,
WorkflowCheckpoint,
WorkflowContext,
handler,
)
if sys.version_info >= (3, 12):
from typing import override # type: ignore # pragma: no cover
else:
from typing_extensions import override # type: ignore[import] # pragma: no cover
from agent_framework_azure_cosmos import CosmosCheckpointStorage
@dataclass
class ComputeTask:
    """A unit of work: the numbers still awaiting factor-pair computation."""

    # Not-yet-processed numbers; the worker consumes this list front-to-back,
    # so its shrinking length is effectively the workflow's progress marker.
    remaining_numbers: list[int]
class StartExecutor(Executor):
    """Entry executor: seeds the workflow with the full range of numbers."""

    @handler
    async def start(self, upper_limit: int, ctx: WorkflowContext[ComputeTask]) -> None:
        """Emit a ComputeTask covering 1..upper_limit to the downstream worker."""
        print(f"StartExecutor: Starting computation up to {upper_limit}")
        numbers = list(range(1, upper_limit + 1))
        task = ComputeTask(remaining_numbers=numbers)
        await ctx.send_message(task)
class WorkerExecutor(Executor):
    """Processes numbers one at a time and manages executor state for checkpointing.

    The accumulated results are persisted/restored through the checkpoint hooks,
    so a resumed run continues from where the previous run stopped.
    """

    def __init__(self, id: str) -> None:
        """Initialize the worker executor with an empty result map."""
        super().__init__(id=id)
        # Maps each processed number to its list of (divisor, quotient) pairs.
        self._results: dict[int, list[tuple[int, int]]] = {}

    @handler
    async def compute(
        self,
        task: ComputeTask,
        ctx: WorkflowContext[ComputeTask, dict[int, list[tuple[int, int]]]],
    ) -> None:
        """Process the next number, computing its factor pairs.

        Yields the accumulated results as workflow output once no numbers
        remain; otherwise sends the (shrunken) task back to this executor.
        """
        # Guard: a task with no remaining numbers (e.g. upper_limit == 0)
        # would otherwise raise IndexError on pop(0). Emit what we have.
        if not task.remaining_numbers:
            await ctx.yield_output(self._results)
            return
        next_number = task.remaining_numbers.pop(0)
        print(f"WorkerExecutor: Processing {next_number}")
        # Collect every (i, n // i) divisor pair with 1 <= i < n.
        pairs: list[tuple[int, int]] = []
        for i in range(1, next_number):
            if next_number % i == 0:
                pairs.append((i, next_number // i))
        self._results[next_number] = pairs
        if not task.remaining_numbers:
            await ctx.yield_output(self._results)
        else:
            await ctx.send_message(task)

    @override
    async def on_checkpoint_save(self) -> dict[str, Any]:
        """Return the executor state to persist alongside the checkpoint."""
        return {"results": self._results}

    @override
    async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
        """Reinstall previously persisted state when resuming from a checkpoint."""
        # NOTE(review): after a JSON round-trip the int keys may come back as
        # strings — confirm CosmosCheckpointStorage preserves key types.
        self._results = state.get("results", {})
async def main() -> None:
    """Run the workflow checkpointing sample with Cosmos DB."""
    endpoint = os.getenv("AZURE_COSMOS_ENDPOINT")
    database = os.getenv("AZURE_COSMOS_DATABASE_NAME")
    container = os.getenv("AZURE_COSMOS_CONTAINER_NAME")
    key = os.getenv("AZURE_COSMOS_KEY")

    # All three connection settings are required; bail out early otherwise.
    if not (endpoint and database and container):
        print("Please set AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, and AZURE_COSMOS_CONTAINER_NAME.")
        return

    # Authentication: key-based when AZURE_COSMOS_KEY is set, otherwise
    # DefaultAzureCredential (managed identity / RBAC), closed via async with.
    if key:
        async with CosmosCheckpointStorage(
            endpoint=endpoint,
            credential=key,
            database_name=database,
            container_name=container,
        ) as checkpoint_storage:
            await _run_workflow(checkpoint_storage)
    else:
        from azure.identity.aio import DefaultAzureCredential

        async with DefaultAzureCredential() as credential:
            async with CosmosCheckpointStorage(
                endpoint=endpoint,
                credential=credential,
                database_name=database,
                container_name=container,
            ) as checkpoint_storage:
                await _run_workflow(checkpoint_storage)
async def _run_workflow(checkpoint_storage: CosmosCheckpointStorage) -> None:
    """Build and run the workflow with Cosmos DB checkpointing."""
    start_executor = StartExecutor(id="start")
    worker_executor = WorkerExecutor(id="worker")

    # worker -> worker self-edge lets the worker loop until the task is drained.
    builder = WorkflowBuilder(start_executor=start_executor, checkpoint_storage=checkpoint_storage)
    builder = builder.add_edge(start_executor, worker_executor)
    builder = builder.add_edge(worker_executor, worker_executor)

    # --- First run: execute the workflow ---
    print("\n=== First Run ===\n")
    first_workflow = builder.build()
    output = None
    async for event in first_workflow.run(message=8, stream=True):
        if event.type == "output":
            output = event.data
    print(f"Factor pairs computed: {output}")

    # Show what was persisted to Cosmos DB during the run.
    checkpoint_ids = await checkpoint_storage.list_checkpoint_ids(
        workflow_name=first_workflow.name,
    )
    print(f"\nCheckpoints in Cosmos DB: {len(checkpoint_ids)}")
    for cid in checkpoint_ids:
        print(f" - {cid}")

    # Pick the most recent checkpoint as the resume point.
    latest: WorkflowCheckpoint | None = await checkpoint_storage.get_latest(
        workflow_name=first_workflow.name,
    )
    if latest is None:
        print("No checkpoint found to resume from.")
        return
    print(f"\nLatest checkpoint: {latest.checkpoint_id}")
    print(f" iteration_count: {latest.iteration_count}")
    print(f" timestamp: {latest.timestamp}")

    # --- Second run: resume from the latest checkpoint ---
    print("\n=== Resuming from Checkpoint ===\n")
    resumed_workflow = builder.build()
    resumed_output = None
    async for event in resumed_workflow.run(checkpoint_id=latest.checkpoint_id, stream=True):
        if event.type == "output":
            resumed_output = event.data
    if resumed_output:
        print(f"Resumed workflow produced: {resumed_output}")
    else:
        print("Resumed workflow completed (no remaining work — already finished).")
# Script entry point: drive the async sample to completion.
if __name__ == "__main__":
    asyncio.run(main())