Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 208 additions & 0 deletions agentplatform/_genai/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,6 +1112,109 @@ def create_from_bigframes(
multimodal_dataset=multimodal_dataset, config=config
)

def create_from_gemini_request_jsonl(
    self,
    *,
    gcs_uri: str,
    multimodal_dataset: Optional[types.MultimodalDatasetOrDict] = None,
    target_table_id: Optional[str] = None,
    config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None,
) -> types.MultimodalDataset:
    """Creates a multimodal dataset from a JSONL file stored on GCS.

    The JSONL file should contain instances of Gemini
    `GenerateContentRequest` on each line. The data will be stored in a
    BigQuery table with a single JSON column called "requests". The
    `assembled_request_column_name` of the dataset's
    `gemini_request_read_config` metadata will be set to "requests".

    The load job uses the `WRITE_TRUNCATE` write disposition, so an
    existing table with the same id is overwritten.

    Args:
      gcs_uri (str):
        The Google Cloud Storage URI of the JSONL file to import.
        For example, 'gs://my-bucket/path/to/data.jsonl'
      multimodal_dataset:
        Optional. A representation of a multimodal dataset. The passed
        object is not modified; a deep copy is updated and sent.
      target_table_id (str):
        Optional. The BigQuery table id where the JSONL data will be
        loaded. The table id can be in the format of "dataset.table"
        or "project.dataset.table". Note that the BigQuery
        dataset must already exist and be in the same location as the
        multimodal dataset. If not provided, a generated table id will
        be created in the `vertex_datasets` dataset (e.g.
        `project.vertex_datasets_us_central1.multimodal_dataset_4cbf7ffd`).
      config:
        Optional. A configuration for creating the multimodal dataset. If not
        provided, the default configuration will be used.

    Returns:
      The created multimodal dataset.

    Raises:
      ValueError: If `gcs_uri` does not start with "gs://".
    """
    # Validate the purely local argument first so an invalid URI fails fast,
    # before the optional BigQuery dependency is imported or any other work
    # is done.
    if not gcs_uri.startswith("gs://"):
        raise ValueError(
            "Invalid GCS URI format. Expected: gs://bucket-name/object-path"
        )

    bigquery = _datasets_utils._try_import_bigquery()

    if isinstance(multimodal_dataset, dict):
        multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
    elif not multimodal_dataset:
        multimodal_dataset = types.MultimodalDataset()

    project = self._api_client.project
    location = self._api_client.location
    credentials = self._api_client._credentials

    if target_table_id:
        target_table_id = _datasets_utils._normalize_and_validate_table_id(
            table_id=target_table_id,
            project=project,
            location=location,
            credentials=credentials,
        )
    else:
        dataset_id = _datasets_utils._create_default_bigquery_dataset_if_not_exists(
            project=project, location=location, credentials=credentials
        )
        target_table_id = _datasets_utils._generate_target_table_id(dataset_id)

    request_column_name = "requests"

    # Setup LoadJobConfig to load the JSONL file as a CSV directly from GCS.
    # We use an unused character (unit separator \x1f) as the field delimiter
    # and an empty string as the quote character. This forces BigQuery to
    # treat each line (a valid JSON string) as a single CSV row.
    job_config = bigquery.LoadJobConfig(
        source_format=bigquery.SourceFormat.CSV,
        field_delimiter="\x1f",
        quote_character="",
        schema=[bigquery.SchemaField(request_column_name, "JSON")],
        write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
    )

    client = bigquery.Client(project=project, credentials=credentials)
    load_job = client.load_table_from_uri(
        gcs_uri,
        target_table_id,
        job_config=job_config,
    )
    # Block until the load job has finished before registering the dataset.
    load_job.result()

    # Work on a deep copy so the object passed by the caller is not mutated.
    multimodal_dataset = multimodal_dataset.model_copy(deep=True)
    metadata = multimodal_dataset.metadata or types.SchemaTablesDatasetMetadata()

    # Keep any read config already present and only point it at the column
    # that was just loaded.
    read_config = (
        metadata.gemini_request_read_config or types.GeminiRequestReadConfig()
    )
    read_config.assembled_request_column_name = request_column_name
    metadata.gemini_request_read_config = read_config

    multimodal_dataset.metadata = metadata
    multimodal_dataset.set_bigquery_uri(f"bq://{target_table_id}")

    return self.create_from_bigquery(
        multimodal_dataset=multimodal_dataset, config=config
    )

def update_multimodal_dataset(
self,
*,
Expand Down Expand Up @@ -2400,6 +2503,111 @@ async def create_from_bigframes(
multimodal_dataset=multimodal_dataset, config=config
)

async def create_from_gemini_request_jsonl(
    self,
    *,
    gcs_uri: str,
    multimodal_dataset: Optional[types.MultimodalDatasetOrDict] = None,
    target_table_id: Optional[str] = None,
    config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None,
) -> types.MultimodalDataset:
    """Creates a multimodal dataset from a JSONL file stored on GCS.

    The JSONL file should contain instances of Gemini
    `GenerateContentRequest` on each line. The data will be stored in a
    BigQuery table with a single JSON column called "requests". The
    `assembled_request_column_name` of the dataset's
    `gemini_request_read_config` metadata will be set to "requests".

    The load job uses the `WRITE_TRUNCATE` write disposition, so an
    existing table with the same id is overwritten.

    Args:
      gcs_uri (str):
        The Google Cloud Storage URI of the JSONL file to import.
        For example, 'gs://my-bucket/path/to/data.jsonl'
      multimodal_dataset:
        Optional. A representation of a multimodal dataset. The passed
        object is not modified; a deep copy is updated and sent.
      target_table_id (str):
        Optional. The BigQuery table id where the JSONL data will be
        loaded. The table id can be in the format of "dataset.table"
        or "project.dataset.table". Note that the BigQuery
        dataset must already exist and be in the same location as the
        multimodal dataset. If not provided, a generated table id will
        be created in the `vertex_datasets` dataset (e.g.
        `project.vertex_datasets_us_central1.multimodal_dataset_4cbf7ffd`).
      config:
        Optional. A configuration for creating the multimodal dataset. If not
        provided, the default configuration will be used.

    Returns:
      The created multimodal dataset.

    Raises:
      ValueError: If `gcs_uri` does not start with "gs://".
    """
    # Validate the purely local argument first so an invalid URI fails fast,
    # before the optional BigQuery dependency is imported or any other work
    # is done.
    if not gcs_uri.startswith("gs://"):
        raise ValueError(
            "Invalid GCS URI format. Expected: gs://bucket-name/object-path"
        )

    bigquery = _datasets_utils._try_import_bigquery()

    if isinstance(multimodal_dataset, dict):
        multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
    elif not multimodal_dataset:
        multimodal_dataset = types.MultimodalDataset()

    project = self._api_client.project
    location = self._api_client.location
    credentials = self._api_client._credentials

    if target_table_id:
        target_table_id = (
            await _datasets_utils._normalize_and_validate_table_id_async(
                table_id=target_table_id,
                project=project,
                location=location,
                credentials=credentials,
            )
        )
    else:
        dataset_id = await _datasets_utils._create_default_bigquery_dataset_if_not_exists_async(
            project=project, location=location, credentials=credentials
        )
        target_table_id = _datasets_utils._generate_target_table_id(dataset_id)

    request_column_name = "requests"

    # Setup LoadJobConfig to load the JSONL file as a CSV directly from GCS.
    # We use an unused character (unit separator \x1f) as the field delimiter
    # and an empty string as the quote character. This forces BigQuery to
    # treat each line (a valid JSON string) as a single CSV row.
    job_config = bigquery.LoadJobConfig(
        source_format=bigquery.SourceFormat.CSV,
        field_delimiter="\x1f",
        quote_character="",
        schema=[bigquery.SchemaField(request_column_name, "JSON")],
        write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
    )

    client = bigquery.Client(project=project, credentials=credentials)
    # The BigQuery client is synchronous: load_table_from_uri performs a
    # blocking API request to start the job, and result() blocks until the
    # job completes. Run both in a worker thread so the event loop stays
    # responsive.
    load_job = await asyncio.to_thread(
        client.load_table_from_uri,
        gcs_uri,
        target_table_id,
        job_config=job_config,
    )
    await asyncio.to_thread(load_job.result)

    # Work on a deep copy so the object passed by the caller is not mutated.
    multimodal_dataset = multimodal_dataset.model_copy(deep=True)
    metadata = multimodal_dataset.metadata or types.SchemaTablesDatasetMetadata()

    # Keep any read config already present and only point it at the column
    # that was just loaded.
    read_config = (
        metadata.gemini_request_read_config or types.GeminiRequestReadConfig()
    )
    read_config.assembled_request_column_name = request_column_name
    metadata.gemini_request_read_config = read_config

    multimodal_dataset.metadata = metadata
    multimodal_dataset.set_bigquery_uri(f"bq://{target_table_id}")

    return await self.create_from_bigquery(
        multimodal_dataset=multimodal_dataset, config=config
    )

async def update_multimodal_dataset(
self,
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,50 @@ def test_create_dataset_from_bigframes_preserves_other_metadata(client, is_repla
)


@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes")
def test_create_from_gemini_request_jsonl(client, is_replay_mode):
    """Creates a dataset from a Gemini-request JSONL file and checks it.

    In replay mode only the returned dataset is inspected. Otherwise the
    BigQuery table referenced by the dataset is also read back and its
    "requests" column is checked.
    """
    gcs_uri = (
        "gs://test-bucket/test-blob.jsonl"
        if is_replay_mode
        else "gs://cloud-samples-data/ai-platform/generative_ai/gemini-2_0/text/sft_train_data.jsonl"
    )
    display_name = "test-from-gemini-jsonl"

    dataset = client.datasets.create_from_gemini_request_jsonl(
        gcs_uri=gcs_uri,
        target_table_id=BIGQUERY_TABLE_NAME,
        multimodal_dataset={"display_name": display_name},
    )

    assert dataset.display_name == display_name
    read_config = dataset.metadata.gemini_request_read_config
    assert read_config.assembled_request_column_name == "requests"

    if is_replay_mode:
        return

    api_client = client._api_client
    bigquery_client = bigquery.Client(
        project=api_client.project,
        location=api_client.location,
        credentials=api_client._credentials,
    )
    # Strip the "bq://" scheme to obtain the plain table id.
    table_uri = dataset.metadata.input_config.bigquery_source.uri
    loaded = bigquery_client.list_rows(table_uri[len("bq://") :]).to_dataframe()
    assert len(loaded) > 0
    assert "requests" in loaded.columns
    assert "contents" in str(loaded["requests"].iloc[0])


def test_create_from_gemini_request_jsonl_raises_invalid_gcs_uri(client):
    """A URI that does not start with gs:// is rejected with ValueError."""
    # `match` is applied as a regular expression, so the literal "." in the
    # message is escaped to keep the check exact.
    with pytest.raises(
        ValueError,
        match=r"Invalid GCS URI format\. Expected: gs://bucket-name/object-path",
    ):
        client.datasets.create_from_gemini_request_jsonl(gcs_uri="invalid_uri")


pytestmark = pytest_helper.setup(
file=__file__,
globals_for_file=globals(),
Expand Down Expand Up @@ -549,3 +593,51 @@ async def test_create_dataset_from_bigframes_preserves_other_metadata_async(
assert dataset.metadata.input_config.bigquery_source.uri == (
f"bq://{BIGQUERY_TABLE_NAME}"
)


@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes")
async def test_create_from_gemini_request_jsonl_async(client, is_replay_mode):
    """Async variant: creates a dataset from a Gemini-request JSONL file.

    In replay mode only the returned dataset is inspected. Otherwise the
    BigQuery table referenced by the dataset is also read back and its
    "requests" column is checked.
    """
    gcs_uri = (
        "gs://test-bucket/test-blob-async.jsonl"
        if is_replay_mode
        else "gs://cloud-samples-data/ai-platform/generative_ai/gemini-2_0/text/sft_train_data.jsonl"
    )
    display_name = "test-from-gemini-jsonl-async"

    dataset = await client.aio.datasets.create_from_gemini_request_jsonl(
        gcs_uri=gcs_uri,
        target_table_id=BIGQUERY_TABLE_NAME,
        multimodal_dataset={"display_name": display_name},
    )

    assert dataset.display_name == display_name
    read_config = dataset.metadata.gemini_request_read_config
    assert read_config.assembled_request_column_name == "requests"

    if is_replay_mode:
        return

    api_client = client._api_client
    bigquery_client = bigquery.Client(
        project=api_client.project,
        location=api_client.location,
        credentials=api_client._credentials,
    )
    # Strip the "bq://" scheme to obtain the plain table id.
    table_uri = dataset.metadata.input_config.bigquery_source.uri
    loaded = bigquery_client.list_rows(table_uri[len("bq://") :]).to_dataframe()
    assert len(loaded) > 0
    assert "requests" in loaded.columns
    assert "contents" in str(loaded["requests"].iloc[0])


@pytest.mark.asyncio
async def test_create_from_gemini_request_jsonl_raises_invalid_gcs_uri_async(client):
    """Async variant: a URI without the gs:// scheme raises ValueError."""
    # `match` is applied as a regular expression, so the literal "." in the
    # message is escaped to keep the check exact.
    with pytest.raises(
        ValueError,
        match=r"Invalid GCS URI format\. Expected: gs://bucket-name/object-path",
    ):
        await client.aio.datasets.create_from_gemini_request_jsonl(
            gcs_uri="invalid_uri"
        )
Loading