diff --git a/agentplatform/_genai/datasets.py b/agentplatform/_genai/datasets.py
index 57528681fb..77cc3277e8 100644
--- a/agentplatform/_genai/datasets.py
+++ b/agentplatform/_genai/datasets.py
@@ -1112,6 +1112,116 @@ def create_from_bigframes(
             multimodal_dataset=multimodal_dataset, config=config
         )
 
+    def create_from_gemini_request_jsonl(
+        self,
+        *,
+        gcs_uri: str,
+        multimodal_dataset: Optional[types.MultimodalDatasetOrDict] = None,
+        target_table_id: Optional[str] = None,
+        config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None,
+    ) -> types.MultimodalDataset:
+        """Creates a multimodal dataset from a JSONL file stored on GCS.
+
+        The JSONL file should contain instances of Gemini
+        `GenerateContentRequest` on each line. The data will be stored in a
+        BigQuery table with a single JSON column called "requests", and
+        `gemini_request_read_config.assembled_request_column_name` in the
+        dataset metadata will be set to "requests".
+
+        Args:
+            gcs_uri (str):
+                The Google Cloud Storage URI of the JSONL file to import.
+                For example, 'gs://my-bucket/path/to/data.jsonl'
+            multimodal_dataset:
+                Optional. A representation of a multimodal dataset.
+            target_table_id (str):
+                Optional. The BigQuery table id where the JSONL data will be
+                loaded. The table id can be in the format of "dataset.table"
+                or "project.dataset.table". Note that the BigQuery
+                dataset must already exist and be in the same location as the
+                multimodal dataset. If the table already exists, its contents
+                are overwritten. If not provided, a generated table id will
+                be created in the `vertex_datasets` dataset (e.g.
+                `project.vertex_datasets_us_central1.multimodal_dataset_4cbf7ffd`).
+            config:
+                Optional. A configuration for creating the multimodal dataset. If not
+                provided, the default configuration will be used.
+
+        Returns:
+            The created multimodal dataset.
+
+        Raises:
+            ValueError: If `gcs_uri` does not start with "gs://".
+        """
+        # Fail fast on a malformed URI before importing BigQuery or touching
+        # any cloud resources.
+        if not gcs_uri.startswith("gs://"):
+            raise ValueError(
+                "Invalid GCS URI format. Expected: gs://bucket-name/object-path"
+            )
+
+        bigquery = _datasets_utils._try_import_bigquery()
+
+        if isinstance(multimodal_dataset, dict):
+            multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
+        elif not multimodal_dataset:
+            multimodal_dataset = types.MultimodalDataset()
+
+        project = self._api_client.project
+        location = self._api_client.location
+        credentials = self._api_client._credentials
+
+        if target_table_id:
+            target_table_id = _datasets_utils._normalize_and_validate_table_id(
+                table_id=target_table_id,
+                project=project,
+                location=location,
+                credentials=credentials,
+            )
+        else:
+            dataset_id = _datasets_utils._create_default_bigquery_dataset_if_not_exists(
+                project=project, location=location, credentials=credentials
+            )
+            target_table_id = _datasets_utils._generate_target_table_id(dataset_id)
+
+        request_column_name = "requests"
+
+        # Setup LoadJobConfig to load the JSONL file as a CSV directly from GCS.
+        # We use an unused character (unit separator \x1f) as the field delimiter
+        # and an empty string as the quote character. This forces BigQuery to
+        # treat each line (a valid JSON string) as a single CSV row.
+        job_config = bigquery.LoadJobConfig(
+            source_format=bigquery.SourceFormat.CSV,
+            field_delimiter="\x1f",
+            quote_character="",
+            schema=[bigquery.SchemaField(request_column_name, "JSON")],
+            write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
+        )
+
+        client = bigquery.Client(project=project, credentials=credentials)
+        load_job = client.load_table_from_uri(
+            gcs_uri,
+            target_table_id,
+            job_config=job_config,
+        )
+        load_job.result()
+
+        multimodal_dataset = multimodal_dataset.model_copy(deep=True)
+        metadata = multimodal_dataset.metadata or types.SchemaTablesDatasetMetadata()
+
+        read_config = (
+            metadata.gemini_request_read_config or types.GeminiRequestReadConfig()
+        )
+        read_config.assembled_request_column_name = request_column_name
+        metadata.gemini_request_read_config = read_config
+
+        multimodal_dataset.metadata = metadata
+        multimodal_dataset.set_bigquery_uri(f"bq://{target_table_id}")
+
+        return self.create_from_bigquery(
+            multimodal_dataset=multimodal_dataset, config=config
+        )
+
     def update_multimodal_dataset(
         self,
         *,
@@ -2400,6 +2510,121 @@ async def create_from_bigframes(
             multimodal_dataset=multimodal_dataset, config=config
         )
 
+    async def create_from_gemini_request_jsonl(
+        self,
+        *,
+        gcs_uri: str,
+        multimodal_dataset: Optional[types.MultimodalDatasetOrDict] = None,
+        target_table_id: Optional[str] = None,
+        config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None,
+    ) -> types.MultimodalDataset:
+        """Creates a multimodal dataset from a JSONL file stored on GCS.
+
+        The JSONL file should contain instances of Gemini
+        `GenerateContentRequest` on each line. The data will be stored in a
+        BigQuery table with a single JSON column called "requests", and
+        `gemini_request_read_config.assembled_request_column_name` in the
+        dataset metadata will be set to "requests".
+
+        Args:
+            gcs_uri (str):
+                The Google Cloud Storage URI of the JSONL file to import.
+                For example, 'gs://my-bucket/path/to/data.jsonl'
+            multimodal_dataset:
+                Optional. A representation of a multimodal dataset.
+            target_table_id (str):
+                Optional. The BigQuery table id where the JSONL data will be
+                loaded. The table id can be in the format of "dataset.table"
+                or "project.dataset.table". Note that the BigQuery
+                dataset must already exist and be in the same location as the
+                multimodal dataset. If the table already exists, its contents
+                are overwritten. If not provided, a generated table id will
+                be created in the `vertex_datasets` dataset (e.g.
+                `project.vertex_datasets_us_central1.multimodal_dataset_4cbf7ffd`).
+            config:
+                Optional. A configuration for creating the multimodal dataset. If not
+                provided, the default configuration will be used.
+
+        Returns:
+            The created multimodal dataset.
+
+        Raises:
+            ValueError: If `gcs_uri` does not start with "gs://".
+        """
+        # Fail fast on a malformed URI before importing BigQuery or touching
+        # any cloud resources.
+        if not gcs_uri.startswith("gs://"):
+            raise ValueError(
+                "Invalid GCS URI format. Expected: gs://bucket-name/object-path"
+            )
+
+        bigquery = _datasets_utils._try_import_bigquery()
+
+        if isinstance(multimodal_dataset, dict):
+            multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
+        elif not multimodal_dataset:
+            multimodal_dataset = types.MultimodalDataset()
+
+        project = self._api_client.project
+        location = self._api_client.location
+        credentials = self._api_client._credentials
+
+        if target_table_id:
+            target_table_id = (
+                await _datasets_utils._normalize_and_validate_table_id_async(
+                    table_id=target_table_id,
+                    project=project,
+                    location=location,
+                    credentials=credentials,
+                )
+            )
+        else:
+            dataset_id = await _datasets_utils._create_default_bigquery_dataset_if_not_exists_async(
+                project=project, location=location, credentials=credentials
+            )
+            target_table_id = _datasets_utils._generate_target_table_id(dataset_id)
+
+        request_column_name = "requests"
+
+        # Setup LoadJobConfig to load the JSONL file as a CSV directly from GCS.
+        # We use an unused character (unit separator \x1f) as the field delimiter
+        # and an empty string as the quote character. This forces BigQuery to
+        # treat each line (a valid JSON string) as a single CSV row.
+        job_config = bigquery.LoadJobConfig(
+            source_format=bigquery.SourceFormat.CSV,
+            field_delimiter="\x1f",
+            quote_character="",
+            schema=[bigquery.SchemaField(request_column_name, "JSON")],
+            write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
+        )
+
+        client = bigquery.Client(project=project, credentials=credentials)
+        # load_table_from_uri() submits the job over HTTP and result() waits
+        # for completion; both block, so run them off the event loop.
+        load_job = await asyncio.to_thread(
+            client.load_table_from_uri,
+            gcs_uri,
+            target_table_id,
+            job_config=job_config,
+        )
+        await asyncio.to_thread(load_job.result)
+
+        multimodal_dataset = multimodal_dataset.model_copy(deep=True)
+        metadata = multimodal_dataset.metadata or types.SchemaTablesDatasetMetadata()
+
+        read_config = (
+            metadata.gemini_request_read_config or types.GeminiRequestReadConfig()
+        )
+        read_config.assembled_request_column_name = request_column_name
+        metadata.gemini_request_read_config = read_config
+
+        multimodal_dataset.metadata = metadata
+        multimodal_dataset.set_bigquery_uri(f"bq://{target_table_id}")
+
+        return await self.create_from_bigquery(
+            multimodal_dataset=multimodal_dataset, config=config
+        )
+
     async def update_multimodal_dataset(
         self,
         *,
diff --git a/tests/unit/agentplatform/genai/replays/test_create_multimodal_datasets.py b/tests/unit/agentplatform/genai/replays/test_create_multimodal_datasets.py
index 3a6001d455..6dbd718c01 100644
--- a/tests/unit/agentplatform/genai/replays/test_create_multimodal_datasets.py
+++ b/tests/unit/agentplatform/genai/replays/test_create_multimodal_datasets.py
@@ -295,6 +295,47 @@ def test_create_dataset_from_bigframes_preserves_other_metadata(client, is_repla
     )
 
 
+@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes")
+def test_create_from_gemini_request_jsonl(client, is_replay_mode):
+    gcs_uri = (
+        "gs://cloud-samples-data/ai-platform/generative_ai/gemini-2_0/text/sft_train_data.jsonl"
+    )
+
+    dataset = client.datasets.create_from_gemini_request_jsonl(
+        gcs_uri=gcs_uri,
+        target_table_id=BIGQUERY_TABLE_NAME,
+        multimodal_dataset={
+            "display_name": "test-from-gemini-jsonl",
+        },
+    )
+    assert dataset.display_name == "test-from-gemini-jsonl"
+    assert (
+        dataset.metadata.gemini_request_read_config.assembled_request_column_name
+        == "requests"
+    )
+    if not is_replay_mode:
+        bigquery_client = bigquery.Client(
+            project=client._api_client.project,
+            location=client._api_client.location,
+            credentials=client._api_client._credentials,
+        )
+        rows = bigquery_client.list_rows(
+            dataset.metadata.input_config.bigquery_source.uri[5:]
+        )
+        df = rows.to_dataframe()
+        assert len(df) > 0
+        assert "requests" in df.columns
+        assert "contents" in str(df["requests"].iloc[0])
+
+
+def test_create_from_gemini_request_jsonl_raises_invalid_gcs_uri(client):
+    with pytest.raises(
+        ValueError,
+        match="Invalid GCS URI format. Expected: gs://bucket-name/object-path",
+    ):
+        client.datasets.create_from_gemini_request_jsonl(gcs_uri="invalid_uri")
+
+
 pytestmark = pytest_helper.setup(
     file=__file__,
     globals_for_file=globals(),
@@ -549,3 +590,48 @@ async def test_create_dataset_from_bigframes_preserves_other_metadata_async(
     assert dataset.metadata.input_config.bigquery_source.uri == (
         f"bq://{BIGQUERY_TABLE_NAME}"
     )
+
+
+@pytest.mark.asyncio
+@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes")
+async def test_create_from_gemini_request_jsonl_async(client, is_replay_mode):
+    gcs_uri = (
+        "gs://cloud-samples-data/ai-platform/generative_ai/gemini-2_0/text/sft_train_data.jsonl"
+    )
+
+    dataset = await client.aio.datasets.create_from_gemini_request_jsonl(
+        gcs_uri=gcs_uri,
+        target_table_id=BIGQUERY_TABLE_NAME,
+        multimodal_dataset={
+            "display_name": "test-from-gemini-jsonl-async",
+        },
+    )
+    assert dataset.display_name == "test-from-gemini-jsonl-async"
+    assert (
+        dataset.metadata.gemini_request_read_config.assembled_request_column_name
+        == "requests"
+    )
+    if not is_replay_mode:
+        bigquery_client = bigquery.Client(
+            project=client._api_client.project,
+            location=client._api_client.location,
+            credentials=client._api_client._credentials,
+        )
+        rows = bigquery_client.list_rows(
+            dataset.metadata.input_config.bigquery_source.uri[5:]
+        )
+        df = rows.to_dataframe()
+        assert len(df) > 0
+        assert "requests" in df.columns
+        assert "contents" in str(df["requests"].iloc[0])
+
+
+@pytest.mark.asyncio
+async def test_create_from_gemini_request_jsonl_raises_invalid_gcs_uri_async(client):
+    with pytest.raises(
+        ValueError,
+        match="Invalid GCS URI format. Expected: gs://bucket-name/object-path",
+    ):
+        await client.aio.datasets.create_from_gemini_request_jsonl(
+            gcs_uri="invalid_uri"
+        )