From 2fd8ca098677e0b430f4246940821ddc6f0de798 Mon Sep 17 00:00:00 2001 From: Christian Leopoldseder Date: Fri, 24 Apr 2026 09:20:21 -0700 Subject: [PATCH] feat: GenAI SDK client(multimodal) - Allow passing dataset ID in addition to full resource name in dataset methods. PiperOrigin-RevId: 905067692 --- .../replays/test_get_multimodal_datasets.py | 19 +++ vertexai/_genai/_datasets_utils.py | 7 ++ vertexai/_genai/datasets.py | 114 +++++++++++++----- 3 files changed, 111 insertions(+), 29 deletions(-) diff --git a/tests/unit/vertexai/genai/replays/test_get_multimodal_datasets.py b/tests/unit/vertexai/genai/replays/test_get_multimodal_datasets.py index dbc9da776e..09769040e9 100644 --- a/tests/unit/vertexai/genai/replays/test_get_multimodal_datasets.py +++ b/tests/unit/vertexai/genai/replays/test_get_multimodal_datasets.py @@ -41,6 +41,15 @@ def test_get_dataset_from_public_method(client): assert dataset.display_name == "test-display-name" +def test_get_dataset_by_id(client): + dataset = client.datasets.get_multimodal_dataset( + name="8810841321427173376", + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.name == DATASET + assert dataset.display_name == "test-display-name" + + pytestmark = pytest_helper.setup( file=__file__, globals_for_file=globals(), @@ -67,3 +76,13 @@ async def test_get_dataset_from_public_method_async(client): assert isinstance(dataset, types.MultimodalDataset) assert dataset.name == DATASET assert dataset.display_name == "test-display-name" + + +@pytest.mark.asyncio +async def test_get_dataset_by_id_async(client): + dataset = await client.aio.datasets.get_multimodal_dataset( + name="8810841321427173376", + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.name == DATASET + assert dataset.display_name == "test-display-name" diff --git a/vertexai/_genai/_datasets_utils.py b/vertexai/_genai/_datasets_utils.py index a913523e7e..e063e6802a 100644 --- a/vertexai/_genai/_datasets_utils.py +++ b/vertexai/_genai/_datasets_utils.py @@ -262,3 +262,10 @@ async def save_dataframe_to_bigquery_async( ) await asyncio.to_thread(copy_job.result) await asyncio.to_thread(bq_client.delete_table, temp_table_id) + + +def resolve_dataset_name(resource_name_or_id: str, project: str, location: str) -> str: + """Resolves a dataset name or ID to a full resource name.""" + if "/" not in resource_name_or_id: + return f"projects/{project}/locations/{location}/datasets/{resource_name_or_id}" + return resource_name_or_id diff --git a/vertexai/_genai/datasets.py b/vertexai/_genai/datasets.py index c8a25b9501..046803edf0 100644 --- a/vertexai/_genai/datasets.py +++ b/vertexai/_genai/datasets.py @@ -1146,8 +1146,8 @@ def get_multimodal_dataset( Args: name: - Required. name of a multimodal dataset. The name should be in - the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". config: Optional. A configuration for getting the multimodal dataset. If not provided, the default configuration will be used. @@ -1161,6 +1161,10 @@ def get_multimodal_dataset( elif not config: config = types.VertexBaseConfig() + name = _datasets_utils.resolve_dataset_name( + name, self._api_client.project, self._api_client.location + ) + return self._get_multimodal_dataset(config=config, name=name) def delete_multimodal_dataset( @@ -1173,8 +1177,8 @@ def delete_multimodal_dataset( Args: name: - Required. name of a multimodal dataset. The name should be in - the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". config: Optional. A configuration for deleting the multimodal dataset. If not provided, the default configuration will be used. @@ -1188,6 +1192,10 @@ def delete_multimodal_dataset( elif not config: config = types.VertexBaseConfig() + name = _datasets_utils.resolve_dataset_name( + name, self._api_client.project, self._api_client.location + ) + return self._delete_multimodal_dataset(config=config, name=name) def assemble( @@ -1205,8 +1213,8 @@ def assemble( Args: name: - Required. The name of the dataset to assemble. The name should be in - the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". gemini_request_read_config: Optional. The read config to use to assemble the dataset. If not provided, the read config attached to the dataset will be @@ -1223,6 +1231,10 @@ def assemble( elif not config: config = types.AssembleDatasetConfig() + name = _datasets_utils.resolve_dataset_name( + name, self._api_client.project, self._api_client.location + ) + operation = self._assemble_multimodal_dataset( name=name, gemini_request_read_config=gemini_request_read_config, @@ -1248,8 +1260,8 @@ def assess_tuning_resources( Args: dataset_name: - Required. The name of the dataset to assess the tuning resources - for. The name should be in the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". model_name: Required. The name of the model to assess the tuning resources for. @@ -1271,6 +1283,10 @@ def assess_tuning_resources( elif not config: config = types.AssessDatasetConfig() + dataset_name = _datasets_utils.resolve_dataset_name( + dataset_name, self._api_client.project, self._api_client.location + ) + operation = self._assess_multimodal_dataset( name=dataset_name, tuning_resource_usage_assessment_config=types.TuningResourceUsageAssessmentConfig( @@ -1304,8 +1320,8 @@ def assess_tuning_validity( Args: dataset_name: - Required. The name of the dataset to assess the tuning validity - for. The name should be in the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". model_name: Required. The name of the model to assess the tuning validity for. @@ -1332,6 +1348,10 @@ def assess_tuning_validity( elif not config: config = types.AssessDatasetConfig() + dataset_name = _datasets_utils.resolve_dataset_name( + dataset_name, self._api_client.project, self._api_client.location + ) + operation = self._assess_multimodal_dataset( name=dataset_name, tuning_validation_assessment_config=types.TuningValidationAssessmentConfig( @@ -1364,8 +1384,8 @@ def assess_batch_prediction_resources( Args: dataset_name: - Required. The name of the dataset to assess the batch prediction - resources. The name should be in the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". model_name: Required. The name of the model to assess the batch prediction resources. @@ -1392,6 +1412,10 @@ def assess_batch_prediction_resources( elif not config: config = types.AssessDatasetConfig() + dataset_name = _datasets_utils.resolve_dataset_name( + dataset_name, self._api_client.project, self._api_client.location + ) + operation = self._assess_multimodal_dataset( name=dataset_name, batch_prediction_resource_usage_assessment_config=types.BatchPredictionResourceUsageAssessmentConfig( @@ -1425,8 +1449,8 @@ def assess_batch_prediction_validity( Args: dataset_name: - Required. The name of the dataset to assess the batch prediction - validity for. The name should be in the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". model_name: Required. The name of the model to assess the batch prediction validity for. @@ -1451,6 +1475,10 @@ def assess_batch_prediction_validity( elif not config: config = types.AssessDatasetConfig() + dataset_name = _datasets_utils.resolve_dataset_name( + dataset_name, self._api_client.project, self._api_client.location + ) + operation = self._assess_multimodal_dataset( name=dataset_name, batch_prediction_validation_assessment_config=types.BatchPredictionValidationAssessmentConfig( @@ -2384,14 +2412,14 @@ async def get_multimodal_dataset( Args: name: - Required. name of a multimodal dataset. The name should be in - the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". config: Optional. A configuration for getting the multimodal dataset. If not provided, the default configuration will be used. Returns: - A types.MultimodalDataset object representing the updated multimodal + A types.MultimodalDataset object representing the retrieved multimodal dataset. """ if isinstance(config, dict): @@ -2399,6 +2427,10 @@ async def get_multimodal_dataset( elif not config: config = types.VertexBaseConfig() + name = _datasets_utils.resolve_dataset_name( + name, self._api_client.project, self._api_client.location + ) + return await self._get_multimodal_dataset(config=config, name=name) async def delete_multimodal_dataset( @@ -2411,8 +2443,8 @@ async def delete_multimodal_dataset( Args: name: - Required. name of a multimodal dataset. The name should be in - the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". config: Optional. A configuration for deleting the multimodal dataset. If not provided, the default configuration will be used. @@ -2426,6 +2458,10 @@ async def delete_multimodal_dataset( elif not config: config = types.VertexBaseConfig() + name = _datasets_utils.resolve_dataset_name( + name, self._api_client.project, self._api_client.location + ) + return await self._delete_multimodal_dataset(config=config, name=name) async def assemble( @@ -2443,8 +2479,8 @@ async def assemble( Args: name: - Required. The name of the dataset to assemble. The name should be in - the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". gemini_request_read_config: Optional. The read config to use to assemble the dataset. If not provided, the read config attached to the dataset will be @@ -2461,6 +2497,10 @@ async def assemble( elif not config: config = types.AssembleDatasetConfig() + name = _datasets_utils.resolve_dataset_name( + name, self._api_client.project, self._api_client.location + ) + operation = await self._assemble_multimodal_dataset( name=name, gemini_request_read_config=gemini_request_read_config, @@ -2486,8 +2526,8 @@ async def assess_tuning_resources( Args: dataset_name: - Required. The name of the dataset to assess the tuning resources - for. The name should be in the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". model_name: Required. The name of the model to assess the tuning resources for. @@ -2509,6 +2549,10 @@ async def assess_tuning_resources( elif not config: config = types.AssessDatasetConfig() + dataset_name = _datasets_utils.resolve_dataset_name( + dataset_name, self._api_client.project, self._api_client.location + ) + operation = await self._assess_multimodal_dataset( name=dataset_name, tuning_resource_usage_assessment_config=types.TuningResourceUsageAssessmentConfig( @@ -2542,8 +2586,8 @@ async def assess_tuning_validity( Args: dataset_name: - Required. The name of the dataset to assess the tuning validity - for. The name should be in the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". model_name: Required. The name of the model to assess the tuning validity for. @@ -2570,6 +2614,10 @@ async def assess_tuning_validity( elif not config: config = types.AssessDatasetConfig() + dataset_name = _datasets_utils.resolve_dataset_name( + dataset_name, self._api_client.project, self._api_client.location + ) + operation = await self._assess_multimodal_dataset( name=dataset_name, tuning_validation_assessment_config=types.TuningValidationAssessmentConfig( @@ -2602,8 +2650,8 @@ async def assess_batch_prediction_resources( Args: dataset_name: - Required. The name of the dataset to assess the batch prediction - resources. The name should be in the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". model_name: Required. The name of the model to assess the batch prediction resources. @@ -2630,6 +2678,10 @@ async def assess_batch_prediction_resources( elif not config: config = types.AssessDatasetConfig() + dataset_name = _datasets_utils.resolve_dataset_name( + dataset_name, self._api_client.project, self._api_client.location + ) + operation = await self._assess_multimodal_dataset( name=dataset_name, batch_prediction_resource_usage_assessment_config=types.BatchPredictionResourceUsageAssessmentConfig( @@ -2663,8 +2715,8 @@ async def assess_batch_prediction_validity( Args: dataset_name: - Required. The name of the dataset to assess the batch prediction - validity for. The name should be in the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". model_name: Required. The name of the model to assess the batch prediction validity for. @@ -2689,6 +2741,10 @@ async def assess_batch_prediction_validity( elif not config: config = types.AssessDatasetConfig() + dataset_name = _datasets_utils.resolve_dataset_name( + dataset_name, self._api_client.project, self._api_client.location + ) + operation = await self._assess_multimodal_dataset( name=dataset_name, batch_prediction_validation_assessment_config=types.BatchPredictionValidationAssessmentConfig(