diff --git a/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py b/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py
index 9fa1711ac9..33b84b6ad7 100644
--- a/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py
+++ b/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py
@@ -115,6 +115,40 @@ def test_create_dataset_from_bigquery(client):
     )
 
 
+@pytest.mark.usefixtures("mock_generate_multimodal_dataset_display_name")
+def test_create_dataset_from_bigquery_with_uri(client):
+    dataset = client.datasets.create_from_bigquery(
+        bigquery_uri=f"bq://{BIGQUERY_TABLE_NAME}",
+    )
+    assert isinstance(dataset, types.MultimodalDataset)
+    assert dataset.metadata.input_config.bigquery_source.uri == (
+        f"bq://{BIGQUERY_TABLE_NAME}"
+    )
+
+
+def test_create_dataset_from_bigquery_preserves_other_metadata(client):
+    dataset = client.datasets.create_from_bigquery(
+        bigquery_uri=f"bq://{BIGQUERY_TABLE_NAME}",
+        multimodal_dataset={
+            "display_name": "test-from-bigquery-uri",
+            "metadata": {
+                "gemini_request_read_config": {
+                    "assembled_request_column_name": "test_column"
+                }
+            },
+        },
+    )
+    assert isinstance(dataset, types.MultimodalDataset)
+    assert dataset.display_name == "test-from-bigquery-uri"
+    assert (
+        dataset.metadata.gemini_request_read_config.assembled_request_column_name
+        == "test_column"
+    )
+    assert dataset.metadata.input_config.bigquery_source.uri == (
+        f"bq://{BIGQUERY_TABLE_NAME}"
+    )
+
+
 @pytest.mark.usefixtures("mock_generate_multimodal_dataset_display_name")
 def test_create_dataset_from_bigquery_no_display_name(client):
     dataset = client.datasets.create_from_bigquery(
@@ -254,6 +288,42 @@ async def test_create_dataset_from_bigquery_async(client):
     )
 
 
+@pytest.mark.asyncio
+@pytest.mark.usefixtures("mock_generate_multimodal_dataset_display_name")
+async def test_create_dataset_from_bigquery_with_uri_async(client):
+    dataset = await client.aio.datasets.create_from_bigquery(
+        bigquery_uri=f"bq://{BIGQUERY_TABLE_NAME}",
+    )
+    assert isinstance(dataset, types.MultimodalDataset)
+    assert dataset.metadata.input_config.bigquery_source.uri == (
+        f"bq://{BIGQUERY_TABLE_NAME}"
+    )
+
+
+@pytest.mark.asyncio
+async def test_create_dataset_from_bigquery_preserves_other_metadata_async(client):
+    dataset = await client.aio.datasets.create_from_bigquery(
+        bigquery_uri=f"bq://{BIGQUERY_TABLE_NAME}",
+        multimodal_dataset={
+            "display_name": "test-from-bigquery-uri",
+            "metadata": {
+                "gemini_request_read_config": {
+                    "assembled_request_column_name": "test_column"
+                }
+            },
+        },
+    )
+    assert isinstance(dataset, types.MultimodalDataset)
+    assert dataset.display_name == "test-from-bigquery-uri"
+    assert (
+        dataset.metadata.gemini_request_read_config.assembled_request_column_name
+        == "test_column"
+    )
+    assert dataset.metadata.input_config.bigquery_source.uri == (
+        f"bq://{BIGQUERY_TABLE_NAME}"
+    )
+
+
 @pytest.mark.asyncio
 @pytest.mark.usefixtures("mock_generate_multimodal_dataset_display_name")
 async def test_create_dataset_from_bigquery_no_display_name_async(client):
diff --git a/vertexai/_genai/_datasets_utils.py b/vertexai/_genai/_datasets_utils.py
index a913523e7e..b5c2dd4445 100644
--- a/vertexai/_genai/_datasets_utils.py
+++ b/vertexai/_genai/_datasets_utils.py
@@ -16,7 +16,7 @@
 
 import asyncio
 import datetime
-from typing import Any, Type, TypeVar
+from typing import Any, TypeVar
 import uuid
 
 import google.auth.credentials
@@ -34,17 +34,6 @@
 T = TypeVar("T", bound=BaseModel)
 
 
-def create_from_response(model_type: Type[T], response: dict[str, Any]) -> T:
-    """Creates a model from a response."""
-    model_field_names = model_type.model_fields.keys()
-    filtered_response = {}
-    for key, value in response.items():
-        snake_key = common.camel_to_snake(key)
-        if snake_key in model_field_names:
-            filtered_response[snake_key] = value
-    return model_type(**filtered_response)
-
-
 def validate_multimodal_dataset_bigquery_uri(
     multimodal_dataset: common.MultimodalDataset,
 ) -> None:
diff --git a/vertexai/_genai/datasets.py b/vertexai/_genai/datasets.py
index e9febf01b0..b086762523 100644
--- a/vertexai/_genai/datasets.py
+++ b/vertexai/_genai/datasets.py
@@ -908,14 +908,18 @@ def _wait_for_operation(
     def create_from_bigquery(
         self,
         *,
-        multimodal_dataset: types.MultimodalDatasetOrDict,
+        bigquery_uri: Optional[str] = None,
+        multimodal_dataset: Optional[types.MultimodalDatasetOrDict] = None,
         config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None,
     ) -> types.MultimodalDataset:
         """Creates a multimodal dataset from a BigQuery table.
 
         Args:
+            bigquery_uri:
+                Optional. The BigQuery URI of the table to create the dataset from.
+                e.g. "bq://project.dataset.table".
             multimodal_dataset:
-                Required. A representation of a multimodal dataset.
+                Optional. A representation of a multimodal dataset.
             config:
                 Optional. A configuration for creating the multimodal dataset.
                 If not provided, the default configuration will be used.
@@ -923,8 +927,15 @@ def create_from_bigquery(
         Returns:
             A types.MultimodalDataset object representing a multimodal dataset.
         """
-        if isinstance(multimodal_dataset, dict):
+        if multimodal_dataset is None:
+            multimodal_dataset = types.MultimodalDataset()
+        elif isinstance(multimodal_dataset, dict):
             multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
+
+        if bigquery_uri:
+            multimodal_dataset = multimodal_dataset.model_copy(deep=True)
+            multimodal_dataset.set_bigquery_uri(bigquery_uri)
+
         _datasets_utils.validate_multimodal_dataset_bigquery_uri(multimodal_dataset)
 
         if isinstance(config, dict):
@@ -947,7 +958,24 @@ def create_from_bigquery(
             operation=multimodal_dataset_operation,
             timeout_seconds=config.timeout,
         )
-        return _datasets_utils.create_from_response(types.MultimodalDataset, response)
+        return types.MultimodalDataset._from_response(
+            response=response,
+            kwargs=(
+                {
+                    "config": {
+                        "response_schema": getattr(config, "response_schema", None),
+                        "response_json_schema": getattr(
+                            config, "response_json_schema", None
+                        ),
+                        "include_all_fields": getattr(
+                            config, "include_all_fields", None
+                        ),
+                    }
+                }
+                if config
+                else {}
+            ),
+        )
 
     def create_from_pandas(
         self,
@@ -1267,9 +1295,23 @@ def assess_tuning_resources(
             operation=operation,
             timeout_seconds=config.timeout,
         )
-        return _datasets_utils.create_from_response(
-            types.TuningResourceUsageAssessmentResult,
-            response["tuningResourceUsageAssessmentResult"],
+        return types.TuningResourceUsageAssessmentResult._from_response(
+            response=response["tuningResourceUsageAssessmentResult"],
+            kwargs=(
+                {
+                    "config": {
+                        "response_schema": getattr(config, "response_schema", None),
+                        "response_json_schema": getattr(
+                            config, "response_json_schema", None
+                        ),
+                        "include_all_fields": getattr(
+                            config, "include_all_fields", None
+                        ),
+                    }
+                }
+                if config
+                else {}
+            ),
         )
 
     def assess_tuning_validity(
@@ -1329,9 +1371,23 @@ def assess_tuning_validity(
             operation=operation,
             timeout_seconds=config.timeout,
         )
-        return _datasets_utils.create_from_response(
-            types.TuningValidationAssessmentResult,
-            response["tuningValidationAssessmentResult"],
+        return types.TuningValidationAssessmentResult._from_response(
+            response=response["tuningValidationAssessmentResult"],
+            kwargs=(
+                {
+                    "config": {
+                        "response_schema": getattr(config, "response_schema", None),
+                        "response_json_schema": getattr(
+                            config, "response_json_schema", None
+                        ),
+                        "include_all_fields": getattr(
+                            config, "include_all_fields", None
+                        ),
+                    }
+                }
+                if config
+                else {}
+            ),
         )
 
     def assess_batch_prediction_resources(
@@ -1389,8 +1445,23 @@ def assess_batch_prediction_resources(
             timeout_seconds=config.timeout,
         )
         result = response["batchPredictionResourceUsageAssessmentResult"]
-        return _datasets_utils.create_from_response(
-            types.BatchPredictionResourceUsageAssessmentResult, result
+        return types.BatchPredictionResourceUsageAssessmentResult._from_response(
+            response=result,
+            kwargs=(
+                {
+                    "config": {
+                        "response_schema": getattr(config, "response_schema", None),
+                        "response_json_schema": getattr(
+                            config, "response_json_schema", None
+                        ),
+                        "include_all_fields": getattr(
+                            config, "include_all_fields", None
+                        ),
+                    }
+                }
+                if config
+                else {}
+            ),
         )
 
     def assess_batch_prediction_validity(
@@ -1448,8 +1519,23 @@ def assess_batch_prediction_validity(
             timeout_seconds=config.timeout,
         )
         result = response["batchPredictionValidationAssessmentResult"]
-        return _datasets_utils.create_from_response(
-            types.BatchPredictionValidationAssessmentResult, result
+        return types.BatchPredictionValidationAssessmentResult._from_response(
+            response=result,
+            kwargs=(
+                {
+                    "config": {
+                        "response_schema": getattr(config, "response_schema", None),
+                        "response_json_schema": getattr(
+                            config, "response_json_schema", None
+                        ),
+                        "include_all_fields": getattr(
+                            config, "include_all_fields", None
+                        ),
+                    }
+                }
+                if config
+                else {}
+            ),
         )
 
 
@@ -2132,14 +2218,18 @@ async def _wait_for_operation(
     async def create_from_bigquery(
         self,
         *,
-        multimodal_dataset: types.MultimodalDatasetOrDict,
+        bigquery_uri: Optional[str] = None,
+        multimodal_dataset: Optional[types.MultimodalDatasetOrDict] = None,
         config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None,
     ) -> types.MultimodalDataset:
         """Creates a multimodal dataset from a BigQuery table.
 
         Args:
+            bigquery_uri:
+                Optional. The BigQuery URI of the table to create the dataset from.
+                e.g. "bq://project.dataset.table".
             multimodal_dataset:
-                Required. A representation of a multimodal dataset.
+                Optional. A representation of a multimodal dataset.
             config:
                 Optional. A configuration for creating the multimodal dataset.
                 If not provided, the default configuration will be used.
@@ -2147,8 +2237,15 @@ async def create_from_bigquery(
         Returns:
             A types.MultimodalDataset object representing a multimodal dataset.
         """
-        if isinstance(multimodal_dataset, dict):
+        if multimodal_dataset is None:
+            multimodal_dataset = types.MultimodalDataset()
+        elif isinstance(multimodal_dataset, dict):
             multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
+
+        if bigquery_uri:
+            multimodal_dataset = multimodal_dataset.model_copy(deep=True)
+            multimodal_dataset.set_bigquery_uri(bigquery_uri)
+
         _datasets_utils.validate_multimodal_dataset_bigquery_uri(multimodal_dataset)
 
         if isinstance(config, dict):
@@ -2171,7 +2268,24 @@ async def create_from_bigquery(
             operation=multimodal_dataset_operation,
             timeout_seconds=config.timeout,
         )
-        return _datasets_utils.create_from_response(types.MultimodalDataset, response)
+        return types.MultimodalDataset._from_response(
+            response=response,
+            kwargs=(
+                {
+                    "config": {
+                        "response_schema": getattr(config, "response_schema", None),
+                        "response_json_schema": getattr(
+                            config, "response_json_schema", None
+                        ),
+                        "include_all_fields": getattr(
+                            config, "include_all_fields", None
+                        ),
+                    }
+                }
+                if config
+                else {}
+            ),
+        )
 
     async def create_from_pandas(
         self,
@@ -2489,9 +2603,23 @@ async def assess_tuning_resources(
             operation=operation,
             timeout_seconds=config.timeout,
         )
-        return _datasets_utils.create_from_response(
-            types.TuningResourceUsageAssessmentResult,
-            response["tuningResourceUsageAssessmentResult"],
+        return types.TuningResourceUsageAssessmentResult._from_response(
+            response=response["tuningResourceUsageAssessmentResult"],
+            kwargs=(
+                {
+                    "config": {
+                        "response_schema": getattr(config, "response_schema", None),
+                        "response_json_schema": getattr(
+                            config, "response_json_schema", None
+                        ),
+                        "include_all_fields": getattr(
+                            config, "include_all_fields", None
+                        ),
+                    }
+                }
+                if config
+                else {}
+            ),
         )
 
     async def assess_tuning_validity(
@@ -2551,9 +2679,23 @@ async def assess_tuning_validity(
             operation=operation,
             timeout_seconds=config.timeout,
         )
-        return _datasets_utils.create_from_response(
-            types.TuningValidationAssessmentResult,
-            response["tuningValidationAssessmentResult"],
+        return types.TuningValidationAssessmentResult._from_response(
+            response=response["tuningValidationAssessmentResult"],
+            kwargs=(
+                {
+                    "config": {
+                        "response_schema": getattr(config, "response_schema", None),
+                        "response_json_schema": getattr(
+                            config, "response_json_schema", None
+                        ),
+                        "include_all_fields": getattr(
+                            config, "include_all_fields", None
+                        ),
+                    }
+                }
+                if config
+                else {}
+            ),
         )
 
     async def assess_batch_prediction_resources(
@@ -2611,8 +2753,23 @@ async def assess_batch_prediction_resources(
             timeout_seconds=config.timeout,
         )
         result = response["batchPredictionResourceUsageAssessmentResult"]
-        return _datasets_utils.create_from_response(
-            types.BatchPredictionResourceUsageAssessmentResult, result
+        return types.BatchPredictionResourceUsageAssessmentResult._from_response(
+            response=result,
+            kwargs=(
+                {
+                    "config": {
+                        "response_schema": getattr(config, "response_schema", None),
+                        "response_json_schema": getattr(
+                            config, "response_json_schema", None
+                        ),
+                        "include_all_fields": getattr(
+                            config, "include_all_fields", None
+                        ),
+                    }
+                }
+                if config
+                else {}
+            ),
         )
 
     async def assess_batch_prediction_validity(
@@ -2670,6 +2827,21 @@ async def assess_batch_prediction_validity(
             timeout_seconds=config.timeout,
        )
         result = response["batchPredictionValidationAssessmentResult"]
-        return _datasets_utils.create_from_response(
-            types.BatchPredictionValidationAssessmentResult, result
+        return types.BatchPredictionValidationAssessmentResult._from_response(
+            response=result,
+            kwargs=(
+                {
+                    "config": {
+                        "response_schema": getattr(config, "response_schema", None),
+                        "response_json_schema": getattr(
+                            config, "response_json_schema", None
+                        ),
+                        "include_all_fields": getattr(
+                            config, "include_all_fields", None
+                        ),
+                    }
+                }
+                if config
+                else {}
+            ),
         )
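Usage sketch (not part of the patch): the call below mirrors the new replay tests and shows how the added bigquery_uri parameter can be passed on its own or combined with an explicit multimodal_dataset. The client construction, project, location, and table names are illustrative placeholders, not values taken from this change.

import vertexai

# Hypothetical project/location and BigQuery table, for illustration only.
client = vertexai.Client(project="my-project", location="us-central1")

# Create the dataset directly from a BigQuery URI; metadata supplied via
# multimodal_dataset (e.g. display_name) is preserved, and the BigQuery
# source URI is filled in from bigquery_uri, as exercised by the new tests.
dataset = client.datasets.create_from_bigquery(
    bigquery_uri="bq://my-project.my_dataset.my_table",
    multimodal_dataset={"display_name": "my-multimodal-dataset"},
)
print(dataset.metadata.input_config.bigquery_source.uri)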