diff --git a/libs/opsqueue_python/python/opsqueue/producer.py b/libs/opsqueue_python/python/opsqueue/producer.py index 5c6a7f0..307a264 100644 --- a/libs/opsqueue_python/python/opsqueue/producer.py +++ b/libs/opsqueue_python/python/opsqueue/producer.py @@ -42,6 +42,10 @@ ] +class LookupIdsWithEmptyStrategicMetadataError(Exception): + pass + + class ProducerClient: """ Opsqueue producer client. Allows sending of large collections of operations ('submissions') @@ -367,6 +371,25 @@ def lookup_submission_id_by_prefix(self, prefix: str) -> SubmissionId | None: """ return self.inner.lookup_submission_id_by_prefix(prefix) + def lookup_submission_ids_by_strategic_metadata( + self, strategic_metadata: dict[str, int] + ) -> list[SubmissionId]: + """Attempts to find in-progress submissions where the strategic metadata + of that submission includes all of the key-value pairs of the given + 'strategic_metadata'. A matching submission must include all of the + given key-value pairs, but it may also contain other key-value pairs. + + Raises: + - `LookupIdsWithEmptyStrategicMetadataError` if the provided + 'strategic_metadata' contained no key-value pairs to look for. + + """ + if len(strategic_metadata) == 0: + raise LookupIdsWithEmptyStrategicMetadataError() + return self.inner.lookup_submission_ids_by_strategic_metadata( # type: ignore[no-any-return] + strategic_metadata + ) + def is_completed(self, submission_id: SubmissionId) -> bool: raise NotImplementedError diff --git a/libs/opsqueue_python/src/producer.rs b/libs/opsqueue_python/src/producer.rs index 7d4266c..a0a70ed 100644 --- a/libs/opsqueue_python/src/producer.rs +++ b/libs/opsqueue_python/src/producer.rs @@ -189,6 +189,24 @@ impl ProducerClient { }) } + /// Attempts to find the IDs of submission matching ALL key-values pairs of + /// the given strategic metadata. + pub fn lookup_submission_ids_by_strategic_metadata( + &self, + py: Python<'_>, + strategic_metadata: StrategicMetadataMap, + ) -> CPyResult, E> { + py.allow_threads(|| { + self.block_unless_interrupted(async { + self.producer_client + .lookup_submission_ids_by_strategic_metadata(&strategic_metadata) + .await + .map(|res| res.into_iter().map(Into::into).collect()) + .map_err(|e| CError(R(e))) + }) + }) + } + /// Directly inserts a submission without sending the chunks to GCS /// (but immediately embedding them in the DB). /// NOTE: This does not support StrategicMetadata currently diff --git a/libs/opsqueue_python/tests/test_roundtrip.py b/libs/opsqueue_python/tests/test_roundtrip.py index ad10541..ddb97d3 100644 --- a/libs/opsqueue_python/tests/test_roundtrip.py +++ b/libs/opsqueue_python/tests/test_roundtrip.py @@ -4,6 +4,7 @@ from collections.abc import Iterator, Sequence from opsqueue.producer import ( + LookupIdsWithEmptyStrategicMetadataError, SubmissionId, ProducerClient, SubmissionCompleted, @@ -508,3 +509,54 @@ def consume(x: int) -> int | None: with pytest.raises(SubmissionFailedError) as exc_info: producer_client.blocking_stream_completed_submission(submission_id) assert exc_info.value.submission.chunks_done == len(chunks) - 1 + + +def test_lookup_submission_ids_by_strategic_metadata(opsqueue: OpsqueueProcess) -> None: + """Lookup of submission IDs should only match in progress submissions with + all pieces of strategic metadata. + + """ + url = "file:///tmp/opsqueue/test_lookup_submission_ids_by_strategic_metadata" + producer_client = ProducerClient(f"localhost:{opsqueue.port}", url) + id_1 = producer_client.insert_submission( + [1], chunk_size=1, strategic_metadata={"foo": 1, "bar": 2, "wow": 3} + ) + id_2 = producer_client.insert_submission( + [1], chunk_size=1, strategic_metadata={"foo": 1, "bar": 2, "moo": 3} + ) + # Inserting some similar data to that above, which shouldn't get matched. + producer_client.insert_submission( + [1], chunk_size=1, strategic_metadata={"foo": 2, "bar": 1} + ) + + def test_lookup( + strategic_metadata: dict[str, int], expected_ids: list[int] + ) -> None: + found_ids = producer_client.lookup_submission_ids_by_strategic_metadata( + strategic_metadata + ) + assert isinstance(found_ids, list) + assert all(map(lambda x: isinstance(x, SubmissionId), found_ids)) + assert found_ids == expected_ids + + test_lookup({"foo": 1}, [id_1, id_2]) + test_lookup({"foo": 1, "bar": 2}, [id_1, id_2]) + test_lookup({"foo": 1, "MISS": 2}, []) + test_lookup({"wow": 3}, [id_1]) + + # Should only match in-progress submission. + producer_client.cancel_submission(id_1) + test_lookup({"foo": 1}, [id_2]) + + +def test_lookup_submission_ids_by_empty_strategic_metadata( + opsqueue: OpsqueueProcess, +) -> None: + """Lookup of submission IDs with empty strategic_metadata should throw a + LookupIdsWithEmptyStrategicMetadataError. + + """ + url = "file:///tmp/opsqueue/test_lookup_submission_ids_by_empty_strategic_metadata" + producer_client = ProducerClient(f"localhost:{opsqueue.port}", url) + with pytest.raises(LookupIdsWithEmptyStrategicMetadataError): + producer_client.lookup_submission_ids_by_strategic_metadata({}) diff --git a/opsqueue/src/common/submission.rs b/opsqueue/src/common/submission.rs index 69040f1..2be455c 100644 --- a/opsqueue/src/common/submission.rs +++ b/opsqueue/src/common/submission.rs @@ -268,7 +268,7 @@ pub mod db { db::{Connection, True, WriterConnection, WriterPool}, }; use chunk::ChunkSize; - use sqlx::{query, query_as, Sqlite}; + use sqlx::{query, query_as, QueryBuilder, Row, Sqlite}; use axum_prometheus::metrics::{counter, histogram}; @@ -527,6 +527,35 @@ pub mod db { Ok(row.map(|row| row.id)) } + pub async fn lookup_ids_by_strategic_metadata( + strategic_metadata: StrategicMetadataMap, + mut conn: impl Connection, + ) -> Result, DatabaseError> { + // SQLx currently only supports "WHERE X IN (a, ...)" queries for postgres: + // https://github.com/transact-rs/sqlx/blob/main/FAQ.md#how-can-i-do-a-select--where-foo-in--query + // So we workaround this by manually building the query, foregoing + // sqlx's nice type-checking. + let mut query_builder: QueryBuilder = QueryBuilder::new( + " + SELECT submission_id + FROM submissions_metadata + INNER JOIN submissions on submissions.id = submission_id + WHERE (metadata_key, metadata_value) IN ( + ", + ); + query_builder.push_values(strategic_metadata.iter(), |mut b, sm| { + b.push_bind(sm.0).push_bind(sm.1); + }); + query_builder.push(") GROUP BY submission_id HAVING count(*) = "); + query_builder.push_bind(strategic_metadata.len() as i64); + query_builder.push(" ORDER BY submission_id"); + let rows = query_builder.build().fetch_all(conn.get_inner()).await?; + Ok(rows + .into_iter() + .map(|row| row.get("submission_id")) + .collect()) + } + #[tracing::instrument(skip(conn))] pub async fn submission_status( id: SubmissionId, diff --git a/opsqueue/src/producer/client.rs b/opsqueue/src/producer/client.rs index b09e20b..2bfa905 100644 --- a/opsqueue/src/producer/client.rs +++ b/opsqueue/src/producer/client.rs @@ -10,6 +10,7 @@ use crate::{ errors::E::{L, R}, errors::{SubmissionNotCancellable, SubmissionNotFound}, submission::{SubmissionId, SubmissionStatus}, + StrategicMetadataMap, }, tracing::CarrierMap, E, @@ -226,6 +227,33 @@ impl Client { .await } + pub async fn lookup_submission_ids_by_strategic_metadata( + &self, + strategic_metadata: &StrategicMetadataMap, + ) -> Result, InternalProducerClientError> { + (|| async { + let base_url = &self.base_url; + let resp = self + .http_client + .post(format!( + "{base_url}/submissions/lookup_ids_by_strategic_metadata" + )) + .json(strategic_metadata) + .send() + .await? + .error_for_status()?; + let bytes = resp.bytes().await?; + let body = serde_json::from_slice(&bytes)?; + Ok(body) + }) + .retry(retry_policy()) + .when(InternalProducerClientError::is_ephemeral) + .notify(|err, dur| { + tracing::debug!("retrying error {err:?} with sleeping {dur:?}"); + }) + .await + } + /// Get the server's version from the `/version` endpoint. /// /// A successful result will be the value of [`VERSION_CARGO_SEMVER`][crate::VERSION_CARGO_SEMVER] diff --git a/opsqueue/src/producer/server.rs b/opsqueue/src/producer/server.rs index 0298a2b..4986505 100644 --- a/opsqueue/src/producer/server.rs +++ b/opsqueue/src/producer/server.rs @@ -2,7 +2,9 @@ use std::sync::Arc; use crate::common::errors::E::{L, R}; use crate::common::submission::{self, SubmissionId}; +use crate::common::StrategicMetadataMap; use crate::db::{self, DBPools}; +use axum::extract; use axum::extract::{Path, State}; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; @@ -60,6 +62,10 @@ impl ServerState { "/submissions/lookup_id_by_prefix/{prefix}", get(lookup_submission_id_by_prefix), ) + .route( + "/submissions/lookup_ids_by_strategic_metadata", + post(lookup_submission_ids_by_strategic_metadata), + ) .route("/submissions/{submission_id}", get(submission_status)) .route("/version", get(crate::server::version_endpoint)) // We're also exposing it here so the producer client can view it .with_state(self) @@ -133,6 +139,16 @@ async fn lookup_submission_id_by_prefix( Ok(Json(submission_id)) } +async fn lookup_submission_ids_by_strategic_metadata( + State(state): State, + extract::Json(strategic_metadata): extract::Json, +) -> Result>, ServerError> { + let mut conn = state.pool.reader_conn().await?; + let submission_ids = + submission::db::lookup_ids_by_strategic_metadata(strategic_metadata, &mut conn).await?; + Ok(Json(submission_ids)) +} + #[tracing::instrument(level = "debug", skip(state))] async fn insert_submission( State(state): State,