Skip to content
2 changes: 1 addition & 1 deletion clients/python/src/examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ With all of those pre-requisites complete, you can run the example application:

```bash
# Generate 5 tasks
python src/examples/cli.py generate --count 5
python src/examples/cli.py spawn --count 5
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated docs fix, spawn is the correct command


# Run the scheduler which emits a task every 1m
python src/examples/cli.py scheduler
Expand Down
6 changes: 6 additions & 0 deletions clients/python/src/taskbroker_client/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@
This can be mutated by application test harnesses to run tasks without Kafka.
"""

WORKER_CHILD_JOIN_TIMEOUT_SEC = 5
"""
How long the parent worker process should allow child processes
to drain pending produce futures on shutdown before sending SIGKILL.
"""


class CompressionType(Enum):
"""
Expand Down
2 changes: 0 additions & 2 deletions clients/python/src/taskbroker_client/worker/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ class TaskProducer:
producer futures tracked by TaskProducer, and will only register the activation as
a success if all producer futures from that task were successful.
Otherwise, the activation will be retried.

TODO: actually have the TaskWorker child check TaskProducer futures.
"""

def __init__(self, producer: ProducerProtocol) -> None:
Expand Down
6 changes: 5 additions & 1 deletion clients/python/src/taskbroker_client/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
DEFAULT_WORKER_HEALTH_CHECK_SEC_PER_TOUCH,
DEFAULT_WORKER_QUEUE_SIZE,
MAX_BACKOFF_SECONDS_WHEN_HOST_UNAVAILABLE,
WORKER_CHILD_JOIN_TIMEOUT_SEC,
)
from taskbroker_client.types import InflightTaskActivation, ProcessingResult
from taskbroker_client.worker.client import (
Expand Down Expand Up @@ -735,7 +736,10 @@ def shutdown(self) -> None:
for child in self._children:
child.terminate()
for child in self._children:
child.join()
child.join(WORKER_CHILD_JOIN_TIMEOUT_SEC)
if child.is_alive():
child.kill()
child.join()
Comment thread
cursor[bot] marked this conversation as resolved.

logger.info("taskworker.worker.shutdown.result")
if self._result_thread:
Expand Down
213 changes: 176 additions & 37 deletions clients/python/src/taskbroker_client/worker/workerchild.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
import logging
import queue
import signal
import threading
import time
from collections.abc import Callable, Generator, Sequence
from dataclasses import dataclass
from multiprocessing.synchronize import Event
from types import FrameType
from typing import Any
Expand All @@ -16,6 +18,9 @@
import orjson
import sentry_sdk
import zstandard as zstd
from arroyo.backends.abstract import ProducerFuture
from arroyo.backends.kafka import KafkaPayload
from arroyo.types import BrokerValue
from sentry_protos.taskbroker.v1.taskbroker_pb2 import (
TASK_ACTIVATION_STATUS_COMPLETE,
TASK_ACTIVATION_STATUS_FAILURE,
Expand All @@ -32,6 +37,7 @@
from taskbroker_client.state import clear_current_task, current_task, set_current_task
from taskbroker_client.task import Task
from taskbroker_client.types import ContextHook, InflightTaskActivation, ProcessingResult
from taskbroker_client.worker.producer import TaskProducer

logger = logging.getLogger(__name__)

Expand All @@ -40,6 +46,30 @@ class ProcessingDeadlineExceeded(BaseException):
pass


@dataclass
class ActivationWithPendingFutures:
"""
Represents an executed inflight activation that
produced messages, and is pending results from the producer
futures.

Args:
inflight: The inflight activation.
status: The status of the activation after execution.
execution_start_time: Timestamp of task execution start.
futures_start_time: Timestamp of when the pending futures were enqueued.
pending_futures: Set of pending futures generated by executing this activation.
task_func: The Task object related to this activation.
"""

inflight: InflightTaskActivation
status: TaskActivationStatus.ValueType
execution_start_time: float
futures_start_time: float
pending_futures: set[ProducerFuture[BrokerValue[KafkaPayload]]]
task_func: Task[Any, Any]


@contextlib.contextmanager
def timeout_alarm(
seconds: int, handler: Callable[[int, FrameType | None], None]
Expand Down Expand Up @@ -167,6 +197,18 @@ def child_process(
app = import_app(app_module)
app.load_modules()
metrics = app.metrics
# Signals when the parent worker pool terminates the child
local_shutdown = threading.Event()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the difference to shutdown_event?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm understanding this right, shutdown_event is shared across all worker child processes in the workerpool: https://github.com/getsentry/taskbroker/blob/main/clients/python/src/taskbroker_client/worker/worker.py#L633

I didn't want to trigger shutdown on all child processes at once by setting shutdown_event in a single child process, so each worker child uses local_shutdown when getting SIGTERM.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah shutdown_event is a multiprocessing event, nevermind


def handle_sigterm(*args: Any) -> None:
logger.info(
"taskworker.worker.sigterm_received",
extra={"processing_pool": processing_pool_name},
)
local_shutdown.set()

signal.signal(signal.SIGTERM, handle_sigterm)
signal.signal(signal.SIGINT, signal.SIG_IGN)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SIGINT is already handled in the worker parent process, which uses the standard shutdown path for child processes, ensuring they drain pending futures. This means we can ignore SIGINT in the worker child processes.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I follow the rationale. If the parent process is already managing the shutdown triggering the shutdown on the children why do we need to intercept sigterm here ?
It should be up to the main process to tell the children when to shut down. That would make the process a lot easier to follow and more deterministic.

Copy link
Copy Markdown
Member Author

@bmckerry bmckerry May 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be up to the main process to tell the children when to shut down

I agree, my understanding is that when the main process receives SIGINT it propagates the signal to all child processes as well (if I'm understanding this thread correctly). By ignoring the signal in child processes, we allow the main process to shut down all child processes when getting a SIGINT (via the standard shutdown process).

I'm not sure when SIGINT is sent other than when using ctrl + c in a dev environment, so this probably isn't a huge deal and I can remove this if you'd prefer.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure when SIGINT is sent other than when using ctrl + c in a dev environment, so this probably isn't a huge deal and I can remove this if you'd prefer.

I think that is the main scenario it gets used. K8s uses SIGTERM.


def _get_known_task(activation: TaskActivation) -> Task[Any, Any] | None:
try:
Expand Down Expand Up @@ -195,6 +237,7 @@ def run_worker(
process_type: str,
) -> None:
processed_task_count = 0
pending_task_futures: list[ActivationWithPendingFutures] = []

def handle_alarm(signum: int, frame: FrameType | None) -> None:
"""
Expand All @@ -213,7 +256,40 @@ def handle_alarm(signum: int, frame: FrameType | None) -> None:
f"execution deadline of {deadline} seconds exceeded by {taskname}"
)

while not shutdown_event.is_set():
def await_task_futures(task: ActivationWithPendingFutures) -> None:
"""
Blocks on the result of each producer futures in the given task.
If any futures raise an exception, the task is considered failed
and submitted for retry (if the policy allows).
"""
RESULT_TIMEOUT_SEC = 1
Comment thread
bmckerry marked this conversation as resolved.
try:
# We don't care about the actual result value,
# we just care if result() raises or not
[f.result(RESULT_TIMEOUT_SEC) for f in task.pending_futures]
Comment thread
bmckerry marked this conversation as resolved.
Comment thread
sentry[bot] marked this conversation as resolved.
# If any pending producer futures failed, retry the task
except Exception:
task.status = TASK_ACTIVATION_STATUS_FAILURE
if task.task_func.retry:
retry_state = task.inflight.activation.retry_state
if not task.task_func.retry.max_attempts_reached(retry_state):
task.status = TASK_ACTIVATION_STATUS_RETRY
Comment thread
cursor[bot] marked this conversation as resolved.
pending_task_futures.remove(task)
_task_execution_complete(
inflight=task.inflight,
next_state=task.status,
execution_start_time=task.execution_start_time,
task_func=task.task_func,
futures_start_time=task.futures_start_time,
)

def check_task_future_completion() -> None:
if len(pending_task_futures) > 0:
for task in pending_task_futures.copy():
if all([f.done() for f in task.pending_futures]):
Comment thread
bmckerry marked this conversation as resolved.
await_task_futures(task)
Comment thread
cursor[bot] marked this conversation as resolved.

while not shutdown_event.is_set() and not local_shutdown.is_set():
Comment thread
cursor[bot] marked this conversation as resolved.
if max_task_count and processed_task_count >= max_task_count:
metrics.incr(
"taskworker.worker.max_task_count_reached",
Comment thread
bmckerry marked this conversation as resolved.
Expand All @@ -231,6 +307,7 @@ def handle_alarm(signum: int, frame: FrameType | None) -> None:
"taskworker.worker.child_task_queue_empty",
tags={"processing_pool": processing_pool_name},
)
check_task_future_completion()
continue

task_func = _get_known_task(inflight.activation)
Expand Down Expand Up @@ -364,44 +441,35 @@ def handle_alarm(signum: int, frame: FrameType | None) -> None:
clear_current_task()
processed_task_count += 1

# Get completion time before pushing to queue, so we can measure queue append time
execution_complete_time = time.time()
with metrics.timer(
"taskworker.worker.processed_tasks.put.duration",
tags={
"processing_pool": processing_pool_name,
},
):
processed_tasks.put(
ProcessingResult(
task_id=inflight.activation.id,
status=next_state,
host=inflight.host,
receive_timestamp=inflight.receive_timestamp,
# Send max_attempts and delay_on_retry if this is a retry.
# Don't send it on every task as this codepath is relatively
# unoptimized on the broker side.
max_attempts=(
task_func.retry._times + 1
if task_func.retry and next_state == TASK_ACTIVATION_STATUS_RETRY
else None
),
delay_on_retry=(
task_func.retry._delay
if task_func.retry and next_state == TASK_ACTIVATION_STATUS_RETRY
else None
),
)
task_produced_futures = TaskProducer.collect_futures()
# If the task function itself failed, we don't need to await any
# producer futures since it'll be retried anyways
if next_state != TASK_ACTIVATION_STATUS_COMPLETE:
task_produced_futures = set()

if len(task_produced_futures) == 0:
_task_execution_complete(
inflight,
next_state,
execution_start_time,
Comment thread
sentry[bot] marked this conversation as resolved.
task_func,
)
Comment thread
cursor[bot] marked this conversation as resolved.
else:
pending_task = ActivationWithPendingFutures(
inflight=inflight,
status=next_state,
execution_start_time=execution_start_time,
futures_start_time=time.time(),
pending_futures=task_produced_futures,
task_func=task_func,
)
pending_task_futures.append(pending_task)

record_task_execution(
inflight.activation,
next_state,
execution_start_time,
execution_complete_time,
processing_pool_name,
inflight.host,
)
check_task_future_completion()
Comment thread
bmckerry marked this conversation as resolved.

# Once we get the shutdown signal, drain any pending futures
for task in pending_task_futures.copy():
await_task_futures(task)
Comment thread
bmckerry marked this conversation as resolved.
Comment thread
bmckerry marked this conversation as resolved.

def _execute_activation(
task_func: Task[Any, Any],
Expand Down Expand Up @@ -493,10 +561,14 @@ def record_task_execution(
completion_time: float,
processing_pool_name: str,
taskbroker_host: str,
futures_enqueued_time: float | None = None,
) -> None:
task_added_time = activation.received_at.ToDatetime().timestamp()
execution_duration = completion_time - start_time
execution_latency = completion_time - task_added_time
futures_duration = (
completion_time - futures_enqueued_time if futures_enqueued_time else None
)

logger.debug(
"taskworker.task_execution",
Expand Down Expand Up @@ -537,6 +609,17 @@ def record_task_execution(
"taskbroker_host": taskbroker_host,
},
)
if futures_duration:
metrics.distribution(
"taskworker.worker.future_completion_duration",
futures_duration,
Comment thread
sentry[bot] marked this conversation as resolved.
tags={
"namespace": activation.namespace,
"taskname": activation.taskname,
"processing_pool": processing_pool_name,
"taskbroker_host": taskbroker_host,
},
)

namespace = app.get_namespace(activation.namespace)
metrics.incr(
Expand All @@ -560,6 +643,62 @@ def record_task_execution(
status=monitor_status,
)

def _task_execution_complete(
inflight: InflightTaskActivation,
next_state: TaskActivationStatus.ValueType,
execution_start_time: float,
task_func: Task[Any, Any] | None,
futures_start_time: float | None = None,
) -> None:
# Get completion time before pushing to queue, so we can measure queue append time
execution_complete_time = time.time()
with metrics.timer(
"taskworker.worker.processed_tasks.put.duration",
tags={
"processing_pool": processing_pool_name,
},
):
if task_func and task_func.retry and next_state == TASK_ACTIVATION_STATUS_RETRY:
processed_tasks.put(
ProcessingResult(
task_id=inflight.activation.id,
status=next_state,
host=inflight.host,
receive_timestamp=inflight.receive_timestamp,
# Send max_attempts and delay_on_retry if this is a retry.
# Don't send it on every task as this codepath is relatively
# unoptimized on the broker side.
max_attempts=(
task_func.retry._times + 1
if task_func.retry and next_state == TASK_ACTIVATION_STATUS_RETRY
else None
),
delay_on_retry=(
task_func.retry._delay
if task_func.retry and next_state == TASK_ACTIVATION_STATUS_RETRY
else None
),
)
)
else:
processed_tasks.put(
ProcessingResult(
task_id=inflight.activation.id,
status=next_state,
host=inflight.host,
receive_timestamp=inflight.receive_timestamp,
)
)
record_task_execution(
inflight.activation,
next_state,
execution_start_time,
execution_complete_time,
processing_pool_name,
inflight.host,
futures_start_time,
)

# Run the worker loop
run_worker(
child_tasks,
Expand Down
Loading
Loading