Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benches/store_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async fn get_pending_activations(num_activations: u32, num_workers: u32) {
let namespace = generate_unique_namespace();

for chunk in make_activations_with_namespace(namespace.clone(), num_activations).chunks(1024) {
store.store(chunk.to_vec()).await.unwrap();
store.store(chunk).await.unwrap();
}

assert_eq!(
Expand Down Expand Up @@ -106,7 +106,7 @@ async fn set_status(num_activations: u32, num_workers: u32) {
let namespace = generate_unique_namespace();

for chunk in make_activations_with_namespace(namespace, num_activations).chunks(1024) {
store.store(chunk.to_vec()).await.unwrap();
store.store(chunk).await.unwrap();
}

assert_eq!(
Expand Down
2 changes: 1 addition & 1 deletion src/fetch/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ impl ActivationStore for MockStore {
unimplemented!()
}

async fn store(&self, _batch: Vec<Activation>) -> Result<u64, Error> {
async fn store(&self, _batch: &[Activation]) -> Result<u64, Error> {
unimplemented!()
}

Expand Down
18 changes: 9 additions & 9 deletions src/grpc/server_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ async fn test_get_task_success(#[case] adapter: &str) {
let config = create_config();

let activations = make_activations(1);
store.store(activations).await.unwrap();
store.store(&activations).await.unwrap();

let service = TaskbrokerServer {
store: store.clone(),
Expand Down Expand Up @@ -179,7 +179,7 @@ async fn test_get_task_with_application_success(#[case] adapter: &str) {
activations[1].activation = payload.encode_to_vec();
activations[1].application = "hammers".into();

store.store(activations).await.unwrap();
store.store(&activations).await.unwrap();

let service = TaskbrokerServer {
store,
Expand Down Expand Up @@ -213,7 +213,7 @@ async fn test_get_task_with_namespace_requires_application(#[case] adapter: &str
let activations = make_activations(2);
let namespace = activations[0].namespace.clone();

store.store(activations).await.unwrap();
store.store(&activations).await.unwrap();

let service = TaskbrokerServer {
store,
Expand Down Expand Up @@ -243,7 +243,7 @@ async fn test_set_task_status_success(#[case] adapter: &str) {
let config = create_config();

let activations = make_activations(2);
store.store(activations).await.unwrap();
store.store(&activations).await.unwrap();

let service = TaskbrokerServer {
store,
Expand Down Expand Up @@ -297,7 +297,7 @@ async fn test_set_task_status_with_application(#[case] adapter: &str) {
activations[1].activation = payload.encode_to_vec();
activations[1].application = "hammers".into();

store.store(activations).await.unwrap();
store.store(&activations).await.unwrap();

let service = TaskbrokerServer {
store,
Expand Down Expand Up @@ -344,7 +344,7 @@ async fn test_set_task_status_with_application_no_match(#[case] adapter: &str) {
activations[1].activation = payload.encode_to_vec();
activations[1].application = "hammers".into();

store.store(activations).await.unwrap();
store.store(&activations).await.unwrap();

let service = TaskbrokerServer {
store,
Expand Down Expand Up @@ -381,7 +381,7 @@ async fn test_set_task_status_with_namespace_requires_application(#[case] adapte
let activations = make_activations(2);
let namespace = activations[0].namespace.clone();

store.store(activations).await.unwrap();
store.store(&activations).await.unwrap();

let service = TaskbrokerServer {
store,
Expand Down Expand Up @@ -420,7 +420,7 @@ async fn test_set_task_status_forwards_to_update_channel(#[case] adapter: &str)
let (update_tx, mut update_rx) = mpsc::channel::<StatusUpdate>(8);

let activations = make_activations(2);
store.store(activations).await.unwrap();
store.store(&activations).await.unwrap();

let service = TaskbrokerServer {
store: store.clone(),
Expand Down Expand Up @@ -477,7 +477,7 @@ async fn test_set_task_status_update_channel_closed_returns_internal() {
drop(update_rx);

let activations = make_activations(1);
store.store(activations).await.unwrap();
store.store(&activations).await.unwrap();

let service = TaskbrokerServer {
store,
Expand Down
2 changes: 1 addition & 1 deletion src/kafka/activation_batcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ impl Reducer for ActivationBatcher {
"taskname" => task_name.clone(),
)
.increment(1);
self.forward_batch.push(t.activation.clone());
self.forward_batch.push(t.activation);
return Ok(());
}
}
Expand Down
9 changes: 4 additions & 5 deletions src/kafka/activation_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,15 @@ impl Reducer for ActivationWriter {
let insert_id = Utc::now().timestamp_millis();
debug!("Preparing insert {:?}", insert_id);

let batch = self.batch.clone().unwrap();
let write_to_store_start = Instant::now();
let res = self.store.store(batch.clone()).await;
let res = self.store.store(batch).await;

// If every "preparing" has a matching "completed" we are good
debug!("Completed insert {:?}", insert_id);

match res {
Ok(entries) => {
self.batch.take();
let batch = self.batch.take().unwrap();
let lag = Utc::now()
- batch
.iter()
Expand Down Expand Up @@ -478,7 +477,7 @@ mod tests {
.status(ActivationStatus::Processing)
.build(TaskActivationBuilder::new());

store.store(vec![existing_activation]).await.unwrap();
store.store(&[existing_activation]).await.unwrap();

let mut writer = ActivationWriter::new(store.clone(), writer_config);
let batch = vec![
Expand Down Expand Up @@ -538,7 +537,7 @@ mod tests {
write_failure_backoff_ms: 4000,
};
let first_round = make_activations(200);
store.store(first_round).await.unwrap();
store.store(&first_round).await.unwrap();
assert!(store.db_size().await.unwrap() > 50_000);

// Make more activations that won't be stored.
Expand Down
26 changes: 9 additions & 17 deletions src/kafka/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ impl MessageQueue for StreamPartitionQueue<KafkaContext> {
#[instrument(skip_all)]
pub async fn map<T>(
queue: impl MessageQueue,
transform: impl Fn(Arc<OwnedMessage>) -> Result<T, Error>,
transform: impl Fn(&OwnedMessage) -> Result<T, Error>,
ok: mpsc::Sender<(iter::Once<OwnedMessage>, T)>,
err: mpsc::Sender<OwnedMessage>,
shutdown: CancellationToken,
Expand All @@ -475,16 +475,10 @@ pub async fn map<T>(
let Some(msg) = val else {
break;
};
let msg = Arc::new(msg.detach()?);
match transform(msg.clone()) {
let msg = msg.detach()?;
match transform(&msg) {
Ok(transformed) => {
if ok.send((
iter::once(
Arc::try_unwrap(msg)
.expect("msg should only have a single strong ref"),
),
transformed,
)).await.is_err() {
if ok.send((iter::once(msg), transformed)).await.is_err() {
debug!("Receive half of ok channel is closed, shutting down...");
break;
}
Expand All @@ -497,11 +491,9 @@ pub async fn map<T>(
"Failed to map message: {:?}",
e,
);
err.send(
Arc::try_unwrap(msg).expect("msg should only have a single strong ref"),
)
.await
.expect("reduce_err is not available");
err.send(msg)
.await
.expect("reduce_err is not available");
}
}
}
Expand Down Expand Up @@ -1836,7 +1828,7 @@ mod tests {
),

map:
|_: Arc<OwnedMessage>| Ok(()),
|_: &OwnedMessage| Ok(()),
reduce:
NoopReducer::new(),
NoopReducer::new(),
Expand All @@ -1853,7 +1845,7 @@ mod tests {
),

map:
|_: Arc<OwnedMessage>| Ok(()),
|_: &OwnedMessage| Ok(()),
reduce:
NoopReducer::new(),
});
Expand Down
6 changes: 2 additions & 4 deletions src/kafka/deserialize.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::sync::Arc;

use anyhow::Error;
use rdkafka::Message;
use rdkafka::message::OwnedMessage;
Expand Down Expand Up @@ -40,12 +38,12 @@ impl DeserializeConfig {
/// In raw mode, raw Kafka bytes are wrapped into a TaskActivation.
/// In normal mode, Kafka messages are expected to contain encoded TaskActivation protos.
/// Messages from the retry topic are always deserialized as activations.
pub fn new(config: DeserializeConfig) -> impl Fn(Arc<OwnedMessage>) -> Result<Activation, Error> {
pub fn new(config: DeserializeConfig) -> impl Fn(&OwnedMessage) -> Result<Activation, Error> {
let raw_deserializer = config.raw_config.map(deserialize_raw::new);
let activation_deserializer = deserialize_activation::new(config.activation_config);
let retry_topic = config.retry_topic;

move |msg: Arc<OwnedMessage>| {
move |msg: &OwnedMessage| {
// Messages from the retry topic are always activations
if let Some(ref retry_topic) = retry_topic
&& msg.topic() == retry_topic
Expand Down
20 changes: 7 additions & 13 deletions src/kafka/deserialize_activation.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::sync::Arc;
use std::time::Duration;

use anyhow::{Error, anyhow};
Expand Down Expand Up @@ -35,8 +34,8 @@ pub fn bucket_from_id(id: &str) -> i16 {

pub fn new(
config: DeserializeActivationConfig,
) -> impl Fn(Arc<OwnedMessage>) -> Result<Activation, Error> {
move |msg: Arc<OwnedMessage>| {
) -> impl Fn(&OwnedMessage) -> Result<Activation, Error> {
move |msg: &OwnedMessage| {
let Some(payload) = msg.payload() else {
return Err(anyhow!("Message has no payload"));
};
Expand Down Expand Up @@ -169,8 +168,7 @@ mod tests {
0,
None,
);
let arc_message = Arc::new(message);
let inflight_opt = deserializer(arc_message);
let inflight_opt = deserializer(&message);

assert!(inflight_opt.is_ok());
let inflight = inflight_opt.unwrap();
Expand Down Expand Up @@ -215,8 +213,7 @@ mod tests {
0,
None,
);
let arc_message = Arc::new(message);
let inflight_opt = deserializer(arc_message);
let inflight_opt = deserializer(&message);

assert!(inflight_opt.is_ok());
let inflight = inflight_opt.unwrap();
Expand Down Expand Up @@ -262,8 +259,7 @@ mod tests {
0,
None,
);
let arc_message = Arc::new(message);
let inflight_opt = deserializer(arc_message);
let inflight_opt = deserializer(&message);

assert!(inflight_opt.is_ok());
let inflight = inflight_opt.unwrap();
Expand Down Expand Up @@ -309,8 +305,7 @@ mod tests {
0,
None,
);
let arc_message = Arc::new(message);
let inflight_opt = deserializer(arc_message);
let inflight_opt = deserializer(&message);

assert!(inflight_opt.is_ok());
let inflight = inflight_opt.unwrap();
Expand Down Expand Up @@ -357,8 +352,7 @@ mod tests {
0,
None,
);
let arc_message = Arc::new(message);
let inflight_opt = deserializer(arc_message);
let inflight_opt = deserializer(&message);

assert!(inflight_opt.is_ok());
let inflight = inflight_opt.unwrap();
Expand Down
14 changes: 6 additions & 8 deletions src/kafka/deserialize_raw.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::collections::HashMap;
use std::sync::Arc;

use anyhow::Error;
use chrono::{DateTime, Utc};
Expand Down Expand Up @@ -102,8 +101,8 @@ fn encode_raw_params(raw_bytes: &[u8]) -> Vec<u8> {

/// Create a deserializer closure for raw mode.
/// Wraps raw Kafka message bytes into a TaskActivation with msgpack-encoded parameters_bytes.
pub fn new(config: RawConfig) -> impl Fn(Arc<OwnedMessage>) -> Result<Activation, Error> {
move |msg: Arc<OwnedMessage>| {
pub fn new(config: RawConfig) -> impl Fn(&OwnedMessage) -> Result<Activation, Error> {
move |msg: &OwnedMessage| {
// Whether a message without payload is valid is technically not up to taskbroker, and we
// can't DLQ messages here. It's easier to convert it to an empty bytestring and let the
// task fail. Failed tasks can be DLQed in upkeep.rs
Expand All @@ -130,7 +129,7 @@ pub fn new(config: RawConfig) -> impl Fn(Arc<OwnedMessage>) -> Result<Activation
#[allow(deprecated)]
parameters: String::new(),
parameters_bytes,
headers: extract_headers(&msg),
headers: extract_headers(msg),
received_at: Some(received_at),
retry_state: None,
processing_deadline_duration: config.processing_deadline_duration.into(),
Expand Down Expand Up @@ -175,7 +174,6 @@ pub fn new(config: RawConfig) -> impl Fn(Arc<OwnedMessage>) -> Result<Activation

#[cfg(test)]
mod tests {
use std::sync::Arc;

use rdkafka::Timestamp;
use rdkafka::message::{Header, OwnedHeaders, OwnedMessage};
Expand Down Expand Up @@ -233,7 +231,7 @@ mod tests {
None,
);

let result = deserializer(Arc::new(message));
let result = deserializer(&message);
assert!(result.is_ok());

let inflight = result.unwrap();
Expand Down Expand Up @@ -273,7 +271,7 @@ mod tests {
None,
);

let result = deserializer(Arc::new(message));
let result = deserializer(&message);
assert!(result.is_ok());

let inflight = result.unwrap();
Expand Down Expand Up @@ -321,7 +319,7 @@ mod tests {
Some(headers),
);

let result = deserializer(Arc::new(message));
let result = deserializer(&message);
assert!(result.is_ok());

let inflight = result.unwrap();
Expand Down
2 changes: 1 addition & 1 deletion src/push/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl MockStore {

#[async_trait]
impl ActivationStore for MockStore {
async fn store(&self, _batch: Vec<Activation>) -> Result<u64> {
async fn store(&self, _batch: &[Activation]) -> Result<u64> {
Ok(0)
}

Expand Down
Loading
Loading