Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion clients/python/src/examples/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,10 @@ def task_that_produces(
production_count: int = 1,
random_count: bool = False,
) -> None:
producer = TaskProducer(KafkaProducer({"bootstrap.servers": bootstrap_servers}))
def producer_factory() -> KafkaProducer:
return KafkaProducer({"bootstrap.servers": bootstrap_servers})

producer = TaskProducer(producer_factory)
production_count = random.randint(1, 50) if random_count else production_count
for i in range(production_count):
logger.debug(f"Producing message {i} onto topic {destination_topic}...")
Expand Down
20 changes: 13 additions & 7 deletions clients/python/src/taskbroker_client/worker/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,21 @@
class TaskProducer:
"""
TaskProducer is a producer abstraction that should be used by tasks
when producing to Kafka is a side effect of a task function.
After a TaskWorker child process executes a task function, it will collect all
producer futures tracked by TaskProducer, and will only register the activation as
a success if all producer futures from that task were successful.
that produce to Kafka as a side effect of their task function.
After a TaskWorker child process executes a task activation, it will collect all
producer futures tracked by TaskProducer, and will only register the task activation as
a success if all producer futures from that activation were successful.
Otherwise, the activation will be retried.
"""

def __init__(self, producer: ProducerProtocol) -> None:
self._inner_producer = producer
def __init__(self, producer_factory: Callable[[], ProducerProtocol]) -> None:
Comment thread
bmckerry marked this conversation as resolved.
self._producer_factory = producer_factory
self._inner_producer: ProducerProtocol | None = None

def _get(self) -> ProducerProtocol:
if self._inner_producer is None:
self._inner_producer = self._producer_factory()
return self._inner_producer

def track_future(self, future: ProducerFuture[BrokerValue[KafkaPayload]]) -> None:
_pending_futures.add(future)
Expand Down Expand Up @@ -53,7 +59,7 @@ def produce(
callbacks: List of Callables to add to the future as done callbacks. The future itself
is the only arg passed to the callback.
"""
future = self._inner_producer.produce(topic, payload)
future = self._get().produce(topic, payload)
self.track_future(future)
if callbacks:
# Arroyo producers can return a SimpleProducerFuture,
Expand Down
11 changes: 8 additions & 3 deletions clients/python/tests/worker/test_producer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import Iterator
from concurrent.futures import Future
from datetime import datetime
from functools import partial

import pytest
from arroyo.backends.abstract import ProducerFuture, SimpleProducerFuture
Expand Down Expand Up @@ -36,6 +37,10 @@ def produce(
return future


def get_dummy_producer(use_simple_futures: bool) -> DummyProducer:
return DummyProducer(use_simple_futures=use_simple_futures)


@pytest.fixture(autouse=True)
def clear_pending_futures() -> Iterator[None]:
_pending_futures.clear()
Expand All @@ -44,7 +49,7 @@ def clear_pending_futures() -> Iterator[None]:


def test_producer_tracks_futures() -> None:
producer = TaskProducer(DummyProducer(use_simple_futures=True))
producer = TaskProducer(partial(get_dummy_producer, use_simple_futures=True))
producer.produce(Topic("test"), make_kafka_payload())
assert len(_pending_futures) == 1
future = next(iter(TaskProducer.collect_futures()))
Expand All @@ -53,7 +58,7 @@ def test_producer_tracks_futures() -> None:


def test_producer_executes_callbacks() -> None:
producer = TaskProducer(DummyProducer(use_simple_futures=False))
producer = TaskProducer(partial(get_dummy_producer, use_simple_futures=False))
received: list[Future[BrokerValue[KafkaPayload]]] = []

def callback(future: Future[BrokerValue[KafkaPayload]]) -> None:
Expand All @@ -68,7 +73,7 @@ def callback(future: Future[BrokerValue[KafkaPayload]]) -> None:


def test_producer_rejects_callbacks_for_simple_futures() -> None:
producer = TaskProducer(DummyProducer(use_simple_futures=True))
producer = TaskProducer(partial(get_dummy_producer, use_simple_futures=True))

def callback(future: Future[BrokerValue[KafkaPayload]]) -> None:
pass
Expand Down
Loading