From 2e515a94552559d6ef8acde08a021e9de89b7c9d Mon Sep 17 00:00:00 2001 From: Ben McKerry <110857332+bmckerry@users.noreply.github.com> Date: Wed, 3 Jun 2026 10:33:45 -0400 Subject: [PATCH] feat(TaskProducer): lazy load inner producer --- clients/python/src/examples/tasks.py | 5 ++++- .../src/taskbroker_client/worker/producer.py | 20 ++++++++++++------- clients/python/tests/worker/test_producer.py | 11 +++++++--- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/clients/python/src/examples/tasks.py b/clients/python/src/examples/tasks.py index 859863bf..190fa0d7 100644 --- a/clients/python/src/examples/tasks.py +++ b/clients/python/src/examples/tasks.py @@ -132,7 +132,10 @@ def task_that_produces( production_count: int = 1, random_count: bool = False, ) -> None: - producer = TaskProducer(KafkaProducer({"bootstrap.servers": bootstrap_servers})) + def producer_factory() -> KafkaProducer: + return KafkaProducer({"bootstrap.servers": bootstrap_servers}) + + producer = TaskProducer(producer_factory) production_count = random.randint(1, 50) if random_count else production_count for i in range(production_count): logger.debug(f"Producing message {i} onto topic {destination_topic}...") diff --git a/clients/python/src/taskbroker_client/worker/producer.py b/clients/python/src/taskbroker_client/worker/producer.py index 6c3f5150..2e741d9b 100644 --- a/clients/python/src/taskbroker_client/worker/producer.py +++ b/clients/python/src/taskbroker_client/worker/producer.py @@ -16,15 +16,21 @@ class TaskProducer: """ TaskProducer is a producer abstraction that should be used by tasks - when producing to Kafka is a side effect of a task function. - After a TaskWorker child process executes a task function, it will collect all - producer futures tracked by TaskProducer, and will only register the activation as - a success if all producer futures from that task were successful. + that produce to Kafka as a side effect of their task function. + After a TaskWorker child process executes a task activation, it will collect all + producer futures tracked by TaskProducer, and will only register the task activation as + a success if all producer futures from that activation were successful. Otherwise, the activation will be retried. """ - def __init__(self, producer: ProducerProtocol) -> None: - self._inner_producer = producer + def __init__(self, producer_factory: Callable[[], ProducerProtocol]) -> None: + self._producer_factory = producer_factory + self._inner_producer: ProducerProtocol | None = None + + def _get(self) -> ProducerProtocol: + if self._inner_producer is None: + self._inner_producer = self._producer_factory() + return self._inner_producer def track_future(self, future: ProducerFuture[BrokerValue[KafkaPayload]]) -> None: _pending_futures.add(future) @@ -53,7 +59,7 @@ def produce( callbacks: List of Callables to add to the future as done callbacks. The future itself is the only arg passed to the callback. """ - future = self._inner_producer.produce(topic, payload) + future = self._get().produce(topic, payload) self.track_future(future) if callbacks: # Arroyo producers can return a SimpleProducerFuture, diff --git a/clients/python/tests/worker/test_producer.py b/clients/python/tests/worker/test_producer.py index 74689b8c..10a16a1a 100644 --- a/clients/python/tests/worker/test_producer.py +++ b/clients/python/tests/worker/test_producer.py @@ -1,6 +1,7 @@ from collections.abc import Iterator from concurrent.futures import Future from datetime import datetime +from functools import partial import pytest from arroyo.backends.abstract import ProducerFuture, SimpleProducerFuture @@ -36,6 +37,10 @@ def produce( return future +def get_dummy_producer(use_simple_futures: bool) -> DummyProducer: + return DummyProducer(use_simple_futures=use_simple_futures) + + @pytest.fixture(autouse=True) def clear_pending_futures() -> Iterator[None]: _pending_futures.clear() @@ -44,7 +49,7 @@ def clear_pending_futures() -> Iterator[None]: def test_producer_tracks_futures() -> None: - producer = TaskProducer(DummyProducer(use_simple_futures=True)) + producer = TaskProducer(partial(get_dummy_producer, use_simple_futures=True)) producer.produce(Topic("test"), make_kafka_payload()) assert len(_pending_futures) == 1 future = next(iter(TaskProducer.collect_futures())) @@ -53,7 +58,7 @@ def test_producer_tracks_futures() -> None: def test_producer_executes_callbacks() -> None: - producer = TaskProducer(DummyProducer(use_simple_futures=False)) + producer = TaskProducer(partial(get_dummy_producer, use_simple_futures=False)) received: list[Future[BrokerValue[KafkaPayload]]] = [] def callback(future: Future[BrokerValue[KafkaPayload]]) -> None: @@ -68,7 +73,7 @@ def callback(future: Future[BrokerValue[KafkaPayload]]) -> None: def test_producer_rejects_callbacks_for_simple_futures() -> None: - producer = TaskProducer(DummyProducer(use_simple_futures=True)) + producer = TaskProducer(partial(get_dummy_producer, use_simple_futures=True)) def callback(future: Future[BrokerValue[KafkaPayload]]) -> None: pass