diff --git a/src/google/adk/artifacts/file_artifact_service.py b/src/google/adk/artifacts/file_artifact_service.py index 9c3870b6e3..6c74c53572 100644 --- a/src/google/adk/artifacts/file_artifact_service.py +++ b/src/google/adk/artifacts/file_artifact_service.py @@ -210,6 +210,12 @@ class FileArtifactVersion(ArtifactVersion): file_name: str = Field( description="Original filename supplied by the caller." ) + display_name: Optional[str] = Field( + default=None, + description=( + "User-facing filename from inline_data.display_name when persisted." + ), + ) class FileArtifactService(BaseArtifactService): @@ -391,6 +397,7 @@ def _save_artifact_sync( stored_filename = artifact_dir.name content_path = version_dir / stored_filename + display_name: Optional[str] = None if artifact.inline_data: content_path.write_bytes(artifact.inline_data.data) mime_type = ( @@ -398,6 +405,7 @@ def _save_artifact_sync( if artifact.inline_data.mime_type else "application/octet-stream" ) + display_name = artifact.inline_data.display_name elif artifact.text is not None: content_path.write_text(artifact.text, encoding="utf-8") mime_type = None @@ -419,6 +427,7 @@ def _save_artifact_sync( version=next_version, canonical_uri=canonical_uri, custom_metadata=custom_metadata, + display_name=display_name, ) logger.debug( @@ -491,7 +500,13 @@ def _load_artifact_sync( ) return None data = content_path.read_bytes() - return types.Part(inline_data=types.Blob(mime_type=mime_type, data=data)) + return types.Part( + inline_data=types.Blob( + mime_type=mime_type, + data=data, + display_name=metadata.display_name if metadata else None, + ) + ) if not content_path.exists(): logger.warning("Text artifact %s missing at %s", filename, content_path) @@ -719,6 +734,7 @@ def _write_metadata( version: int, canonical_uri: str, custom_metadata: Optional[dict[str, Any]], + display_name: Optional[str] = None, ) -> None: """Persists metadata describing an artifact version.""" metadata = FileArtifactVersion( @@ -726,6 +742,7 @@ def _write_metadata( mime_type=mime_type, canonical_uri=canonical_uri, version=version, + display_name=display_name, # Persist caller supplied metadata for feature parity with other # artifact services (e.g. GCS). custom_metadata=dict(custom_metadata or {}), diff --git a/src/google/adk/artifacts/gcs_artifact_service.py b/src/google/adk/artifacts/gcs_artifact_service.py index f8706dedbd..d3dda75734 100644 --- a/src/google/adk/artifacts/gcs_artifact_service.py +++ b/src/google/adk/artifacts/gcs_artifact_service.py @@ -39,6 +39,8 @@ logger = logging.getLogger("google_adk." + __name__) +_GCS_DISPLAY_NAME_METADATA_KEY = "adkDisplayName" + class GcsArtifactService(BaseArtifactService): """An artifact service implementation using Google Cloud Storage (GCS).""" @@ -216,8 +218,13 @@ def _save_artifact( app_name, user_id, filename, version, session_id ) blob = self.bucket.blob(blob_name) - if custom_metadata: - blob.metadata = {k: str(v) for k, v in custom_metadata.items()} + blob_metadata = {k: str(v) for k, v in (custom_metadata or {}).items()} + if artifact.inline_data and artifact.inline_data.display_name: + blob_metadata[_GCS_DISPLAY_NAME_METADATA_KEY] = ( + artifact.inline_data.display_name + ) + if blob_metadata: + blob.metadata = blob_metadata if artifact.inline_data: blob.upload_from_string( @@ -268,10 +275,20 @@ def _load_artifact( artifact_bytes = blob.download_as_bytes() if not artifact_bytes: return None - artifact = types.Part.from_bytes( + display_name = None + if blob.metadata: + display_name = blob.metadata.get(_GCS_DISPLAY_NAME_METADATA_KEY) + if display_name: + return types.Part( + inline_data=types.Blob( + mime_type=blob.content_type, + data=artifact_bytes, + display_name=display_name, + ) + ) + return types.Part.from_bytes( data=artifact_bytes, mime_type=blob.content_type ) - return artifact def _list_artifact_keys( self, app_name: str, user_id: str, session_id: Optional[str] diff --git a/tests/unittests/artifacts/test_artifact_service.py b/tests/unittests/artifacts/test_artifact_service.py index 8b82397097..6f4b681321 100644 --- a/tests/unittests/artifacts/test_artifact_service.py +++ b/tests/unittests/artifacts/test_artifact_service.py @@ -959,3 +959,49 @@ async def test_save_artifact_with_snake_case_dict( assert loaded is not None assert loaded.inline_data is not None assert loaded.inline_data.mime_type == "text/plain" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "service_type", + [ + ArtifactServiceType.IN_MEMORY, + ArtifactServiceType.GCS, + ArtifactServiceType.FILE, + ], +) +async def test_load_artifact_preserves_inline_data_display_name( + service_type, artifact_service_factory +): + """Binary artifact load restores inline_data.display_name after save.""" + artifact_service = artifact_service_factory(service_type) + app_name = "app0" + user_id = "user0" + session_id = "sess0" + filename = "artifact.bin" + display_name = "My Report (final).png" + artifact = types.Part( + inline_data=types.Blob( + mime_type="image/png", + data=b"\x89PNG\r\n\x1a\n", + display_name=display_name, + ) + ) + + await artifact_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + artifact=artifact, + ) + loaded = await artifact_service.load_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + ) + + assert loaded is not None + assert loaded.inline_data is not None + assert loaded.inline_data.display_name == display_name