def _GetPublisherModelConfig_to_vertex(
    from_object: Union[dict[str, Any], object],
    parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Writes the set GetPublisherModel config fields into the parent's query.

    Every config field that is not None is stored under the ``_query`` key of
    ``parent_object`` using its camelCase name. Nothing is added to the
    returned dict.

    Args:
        from_object: The config (dict or object) whose fields are read.
        parent_object: The request dict that receives the ``_query`` entries.

    Returns:
        An empty dict; all output is written to ``parent_object``.
    """
    # (config field, query parameter name), in emission order.
    query_fields = (
        ("hugging_face_token", "huggingFaceToken"),
        (
            "include_equivalent_model_garden_model_deployment_configs",
            "includeEquivalentModelGardenModelDeploymentConfigs",
        ),
        ("is_hugging_face_model", "isHuggingFaceModel"),
    )
    for source_key, query_key in query_fields:
        field_value = getv(from_object, [source_key])
        if field_value is not None:
            setv(parent_object, ["_query", query_key], field_value)

    converted: dict[str, Any] = {}
    return converted


def _GetPublisherModelRequestParameters_to_vertex(
    from_object: Union[dict[str, Any], object],
    parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Converts GetPublisherModel request parameters into a request dict.

    Args:
        from_object: The request parameters (dict or object) to convert.
        parent_object: Unused by this converter.

    Returns:
        A dict holding ``_url.name`` when ``name`` is set, plus whatever
        ``_query`` entries the config converter adds when ``config`` is set.
    """
    converted: dict[str, Any] = {}

    resource_name = getv(from_object, ["name"])
    if resource_name is not None:
        setv(converted, ["_url", "name"], resource_name)

    request_config = getv(from_object, ["config"])
    if request_config is not None:
        # The config converter writes its query params straight into `converted`.
        _GetPublisherModelConfig_to_vertex(request_config, converted)

    return converted
Optional[types.ListPublisherModelsConfigOrDict] = None, - ) -> types.ListPublisherModelsResponse: - """ - Lists publisher models (internal). - """ - - parameter_model = types._ListPublisherModelsRequestParameters( - parent=parent, - config=config, - ) - - request_url_dict: Optional[dict[str, str]] - if not self._api_client.vertexai: - raise ValueError( - "This method is only supported in Gemini Enterprise Agent Platform mode, not in Gemini Developer API mode." - ) - else: - request_dict = _ListPublisherModelsRequestParameters_to_vertex( - parameter_model - ) - request_url_dict = request_dict.get("_url") - if request_url_dict: - path = "{parent}/models".format_map(request_url_dict) - else: - path = "{parent}/models" - - query_params = request_dict.get("_query") - if query_params: - path = f"{path}?{urlencode(query_params)}" - # TODO: remove the hack that pops config. - request_dict.pop("config", None) - - http_options: Optional[types.HttpOptions] = None - if ( - parameter_model.config is not None - and parameter_model.config.http_options is not None - ): - http_options = parameter_model.config.http_options - - request_dict = _common.convert_to_dict(request_dict) - request_dict = _common.encode_unserializable_types(request_dict) - - response = self._api_client.request("get", path, request_dict, http_options) - - response_dict = {} if not response.body else json.loads(response.body) - - return_value = types.ListPublisherModelsResponse._from_response( - response=response_dict, - kwargs=( - { - "config": { - "response_schema": getattr( - parameter_model.config, "response_schema", None - ), - "response_json_schema": getattr( - parameter_model.config, "response_json_schema", None - ), - "include_all_fields": getattr( - parameter_model.config, "include_all_fields", None - ), - } + """Model Garden module.""" + + def _list_publisher_models( + self, + *, + parent: Optional[str] = None, + config: Optional[types.ListPublisherModelsConfigOrDict] = None, + ) -> 
def _get_publisher_model(
    self,
    *,
    name: str,
    config: Optional[types.GetPublisherModelConfigOrDict] = None,
) -> types.PublisherModel:
    """Fetches one publisher model with a GET request (internal).

    Args:
        name: Value substituted for ``{name}`` in the request path.
        config: Optional request configuration. Its converted fields become
            query parameters and its ``http_options`` are forwarded to the
            API client.

    Returns:
        The ``PublisherModel`` built from the response body.

    Raises:
        ValueError: If ``self._api_client.vertexai`` is falsy.
    """
    # Built before the mode check, matching the original statement order.
    params = types._GetPublisherModelRequestParameters(
        name=name,
        config=config,
    )

    if not self._api_client.vertexai:
        raise ValueError(
            "This method is only supported in Gemini Enterprise Agent Platform"
            " mode, not in Gemini Developer API mode."
        )

    payload = _GetPublisherModelRequestParameters_to_vertex(params)
    url_fields: Optional[dict[str, str]] = payload.get("_url")
    path = "{name}".format_map(url_fields) if url_fields else "{name}"

    query = payload.get("_query")
    if query:
        path = f"{path}?{urlencode(query)}"
    # TODO: remove the hack that pops config.
    payload.pop("config", None)

    request_config = params.config
    http_options: Optional[types.HttpOptions] = None
    if request_config is not None and request_config.http_options is not None:
        http_options = request_config.http_options

    payload = _common.convert_to_dict(payload)
    payload = _common.encode_unserializable_types(payload)

    response = self._api_client.request("get", path, payload, http_options)
    body = json.loads(response.body) if response.body else {}

    if getattr(params, "config", None):
        # Forward only the response-shaping options to the parser.
        parse_kwargs = {
            "config": {
                option: getattr(request_config, option, None)
                for option in (
                    "response_schema",
                    "response_json_schema",
                    "include_all_fields",
                )
            }
        }
    else:
        parse_kwargs = {}

    publisher_model = types.PublisherModel._from_response(
        response=body,
        kwargs=parse_kwargs,
    )
    self._api_client._verify_response(publisher_model)
    return publisher_model
include_hugging_face_models: bool, + deployable_only: bool, + ) -> str: + """Builds the filter string for the ListPublisherModels API. + + Args: + model_filter: Optional substring to match against model IDs and display + names (case-insensitive). + include_hugging_face_models: Whether to include HuggingFace models. If + True, uses ``is_hf_wildcard(true)``; otherwise + ``is_hf_wildcard(false)``. + deployable_only: Whether to restrict to models with verified deployment + configurations via the ``VERIFIED_DEPLOYMENT_SUCCEED`` label. + + Returns: + A filter string suitable for the ``filter`` parameter of the + ListPublisherModels API. + """ + import re + + if include_hugging_face_models: + filter_str = "is_hf_wildcard(true)" + if deployable_only: + filter_str += ( + " AND labels.VERIFIED_DEPLOYMENT_CONFIG=VERIFIED_DEPLOYMENT_SUCCEED" ) - - self._api_client._verify_response(return_value) - return return_value - - @staticmethod - def _build_filter_str( - model_filter: Optional[str], - include_hugging_face_models: bool, - deployable_only: bool, - ) -> str: - """Builds the filter string for the ListPublisherModels API. - - Args: - model_filter: Optional substring to match against model IDs and display - names (case-insensitive). - include_hugging_face_models: Whether to include HuggingFace models. If - True, uses ``is_hf_wildcard(true)``; otherwise ``is_hf_wildcard(false)``. - deployable_only: Whether to restrict to models with verified deployment - configurations via the ``VERIFIED_DEPLOYMENT_SUCCEED`` label. - - Returns: - A filter string suitable for the ``filter`` parameter of the - ListPublisherModels API. 
- """ - import re - - if include_hugging_face_models: - filter_str = "is_hf_wildcard(true)" - if deployable_only: - filter_str += ( - " AND labels.VERIFIED_DEPLOYMENT_CONFIG=VERIFIED_DEPLOYMENT_SUCCEED" - ) - else: - filter_str = "is_hf_wildcard(false)" - - if model_filter: - escaped = re.escape(model_filter) - filter_str = ( - f'{filter_str} AND (model_user_id=~"(?i).*{escaped}.*"' - f' OR display_name=~"(?i).*{escaped}.*")' - ) - - return filter_str - - @staticmethod - def _format_model_name( - model: types.PublisherModel, - include_hugging_face_models: bool, - ) -> str: - """Formats a PublisherModel into a human-readable model name string. - - Args: - model: The PublisherModel to format. - include_hugging_face_models: Whether HuggingFace models are included in - the listing. Controls whether the ``@version`` suffix is appended. - - Returns: - A formatted model name string in one of the following formats: - - - ``'{publisher}/{model}@{version}'`` when - ``include_hugging_face_models`` is False. - - ``'{publisher}/{model}'`` when ``include_hugging_face_models`` is True. - """ - import re - - name = model.name or "" - formatted = re.sub(r"publishers/(hf-|)|models/", "", name) - if include_hugging_face_models: - return formatted - return formatted + "@" + (model.version_id or "") - - @staticmethod - def _has_deploy_config(model: types.PublisherModel) -> bool: - """Checks whether a model has verified deployment configurations. - - Args: - model: The PublisherModel to check. - - Returns: - True if the model has at least one entry in - ``supported_actions.multi_deploy_vertex``. 
- """ - return bool( - model.supported_actions - and model.supported_actions.multi_deploy_vertex - and model.supported_actions.multi_deploy_vertex.multi_deploy_vertex + else: + filter_str = "is_hf_wildcard(false)" + + if model_filter: + escaped = re.escape(model_filter) + filter_str = ( + f'{filter_str} AND (model_user_id=~"(?i).*{escaped}.*"' + f' OR display_name=~"(?i).*{escaped}.*")' + ) + + return filter_str + + @staticmethod + def _format_model_name( + model: types.PublisherModel, + include_hugging_face_models: bool, + ) -> str: + """Formats a PublisherModel into a human-readable model name string. + + Args: + model: The PublisherModel to format. + include_hugging_face_models: Whether HuggingFace models are included in + the listing. Controls whether the ``@version`` suffix is appended. + + Returns: + A formatted model name string in one of the following formats: + + - ``'{publisher}/{model}@{version}'`` when + ``include_hugging_face_models`` is False. + - ``'{publisher}/{model}'`` when ``include_hugging_face_models`` is True. + """ + import re + + name = model.name or "" + formatted = re.sub(r"publishers/(hf-|)|models/", "", name) + if include_hugging_face_models: + return formatted + return formatted + "@" + (model.version_id or "") + + @staticmethod + def _has_deploy_config(model: types.PublisherModel) -> bool: + """Checks whether a model has verified deployment configurations. + + Args: + model: The PublisherModel to check. + + Returns: + True if the model has at least one entry in + ``supported_actions.multi_deploy_vertex``. + """ + return bool( + model.supported_actions + and model.supported_actions.multi_deploy_vertex + and model.supported_actions.multi_deploy_vertex.multi_deploy_vertex + ) + + @staticmethod + def _reconcile_model_name(model_name: str) -> str: + """Normalizes a model name into a publisher model resource name. 
+ + Args: + model_name: A Model Garden model resource name in the format + ``'publishers/{publisher}/models/{model}@{version}'``, a simplified name + in the format ``'{publisher}/{model}@{version}'`` (or without + ``@{version}``), or a Hugging Face model ID + ``'{organization}/{model}'``. + + Returns: + The resource name in the format + ``'publishers/{publisher}/models/{model}@{version}'``. + + Raises: + ValueError: If ``model_name`` is not a valid publisher model name. + """ + import re + + model_name = model_name.lower() # Hugging Face IDs are lower-case. + # @version is optional so versionless full resource names parse on this + # branch instead of mangling on the simplified branch below. + full_match = re.match( + r"^publishers/(?P[^/]+)/models/(?P[^@]+)(?:@(?P[^@]+))?$", + model_name, + ) + if full_match: + model = full_match.group("model") + if full_match.group("version"): + model = f"{model}@{full_match.group('version')}" + return f"publishers/{full_match.group('publisher')}/models/{model}" + # Reject Model Registry names; they would otherwise match the simplified + # branch and be silently mangled. + if re.match(r"^projects/.+/locations/.+/models/.+$", model_name): + raise ValueError(f"`{model_name}` is not a valid publisher model name") + simplified_match = re.match( + r"^(?P[^/]+)/(?P[^@]+)(?:@(?P.+))?$", + model_name, + ) + if simplified_match: + model = simplified_match.group("model") + if simplified_match.group("version"): + model = f"{model}@{simplified_match.group('version')}" + return f"publishers/{simplified_match.group('publisher')}/models/{model}" + raise ValueError(f"`{model_name}` is not a valid publisher model name") + + @staticmethod + def _is_hugging_face_model(model_name: str) -> bool: + """Returns whether a model name looks like a Hugging Face model ID. + + Matches the bare ``'{organization}/{model}'`` shape (a single slash and no + ``@version``), e.g. ``'meta-llama/Llama-3.3-70B-Instruct'``. 
+ + Args: + model_name: The model name to inspect. + + Returns: + True if ``model_name`` matches the Hugging Face ID shape. + """ + import re + + return bool( + re.match(r"^(?P[^/]+)/(?P[^/@]+)$", model_name) + ) + + @staticmethod + def _matches_filter( + value: Optional[str], + model_filter: Optional[Union[str, list[str]]], + ) -> bool: + """Returns whether ``value`` matches the (optional) keyword filter. + + Mirrors the legacy SDK: the filter may be a single keyword or a list of + keywords, and matching is a case-insensitive substring test where the value + matches if it contains *any* of the keywords. + + Args: + value: The field value to test (e.g. a machine type), or None. + model_filter: A keyword, a list of keywords, or None (no filtering). + + Returns: + True if there is no filter, or if ``value`` contains any of the keywords. + """ + if not model_filter: + return True + if value is None: + return False + keywords = [model_filter] if isinstance(model_filter, str) else model_filter + value_lower = value.lower() + return any(keyword.lower() in value_lower for keyword in keywords) + + @staticmethod + def _extract_and_filter_deploy_options( + publisher_model: types.PublisherModel, + machine_type_filter: Optional[Union[str, list[str]]] = None, + accelerator_type_filter: Optional[Union[str, list[str]]] = None, + serving_container_image_uri_filter: Optional[ + Union[str, list[str]] + ] = None, + ) -> list[types.DeployOption]: + """Extracts and filters deploy options from a publisher model. + + Args: + publisher_model: The publisher model to extract deploy options from. + machine_type_filter: Optional case-insensitive keyword (or list of + keywords) matched against the machine type; an option is kept if its + machine type contains any of them (e.g. ``'g2'`` or ``['n1', 'g2']``). + accelerator_type_filter: Optional case-insensitive keyword (or list of + keywords) matched against the accelerator type (e.g. ``'L4'`` or + ``['T4', 'L4']``). 
+ serving_container_image_uri_filter: Optional case-insensitive keyword (or + list of keywords) matched against the serving container image URI (e.g. + ``'vllm'`` or ``['vllm', 'tgi']``). + + Returns: + A list of ``DeployOption`` objects matching the provided filters. + + Raises: + ValueError: If the model does not support deployment, or if no deploy + options remain after applying the filters. + """ + if not ( + publisher_model.supported_actions + and publisher_model.supported_actions.multi_deploy_vertex + and publisher_model.supported_actions.multi_deploy_vertex.multi_deploy_vertex + ): + raise ValueError( + "Model does not support deployment. " + "Use `list_deployable_models()` to find supported models." + ) + + options = ( + publisher_model.supported_actions.multi_deploy_vertex.multi_deploy_vertex + ) + result = [] + for opt in options: + container = opt.container_spec.image_uri if opt.container_spec else None + machine = ( + opt.dedicated_resources.machine_spec + if opt.dedicated_resources + else None + ) + machine_type = machine.machine_type if machine else None + + # Restore the proto3 defaults the JSON transport drops, so structured + # output matches the gRPC SDK on CPU/TPU machines. + accelerator_enum = machine.accelerator_type if machine else None + accelerator_value = accelerator_enum.value if accelerator_enum else None + has_accelerator = ( + accelerator_value is not None + and accelerator_value != "ACCELERATOR_TYPE_UNSPECIFIED" + ) + if machine: + accelerator_type = ( + accelerator_value + if accelerator_value is not None + else "ACCELERATOR_TYPE_UNSPECIFIED" ) - - def _list_all_publisher_models( - self, - api_config: types.ListPublisherModelsConfig, - ) -> list[types.PublisherModel]: - """Fetches all pages of publisher models from the API. - - Args: - api_config: The configuration for the ListPublisherModels API call, - including filter and version settings. - - Returns: - A list of all ``PublisherModel`` objects across all pages. 
- """ - all_models = [] - page_token = None - while True: - if page_token: - api_config = types.ListPublisherModelsConfig( - filter=api_config.filter, - list_all_versions=api_config.list_all_versions, - page_token=page_token, - ) - response = self._list_publisher_models( - parent="publishers/*", - config=api_config, - ) - all_models.extend(response.publisher_models or []) - page_token = response.next_page_token - if not page_token: - break - return all_models - - def _list( - self, - model_filter: Optional[str], - include_hugging_face_models: Optional[bool], - deployable_only: bool, - ) -> list[str]: - """Shared implementation for listing models. - - Args: - model_filter: Optional substring to filter models by. - include_hugging_face_models: Whether to include HuggingFace models. - deployable_only: If True, only return models with deployment configs. - - Returns: - A list of formatted model name strings. - """ - include_hf = include_hugging_face_models is True - - filter_str = self._build_filter_str( - model_filter, include_hf, deployable_only=deployable_only + accelerator_count = ( + machine.accelerator_count + if machine.accelerator_count is not None + else 0 ) - + else: + accelerator_type = None + accelerator_count = None + + if not ModelGarden._matches_filter(machine_type, machine_type_filter): + continue + # ACCELERATOR_TYPE_UNSPECIFIED means "no accelerator" and never matches. 
+ if accelerator_type_filter and not has_accelerator: + continue + if not ModelGarden._matches_filter( + accelerator_type, accelerator_type_filter + ): + continue + if not ModelGarden._matches_filter( + container, serving_container_image_uri_filter + ): + continue + + result.append( + types.DeployOption( + option_name=opt.deploy_task_name, + serving_container_image_uri=container, + machine_type=machine_type, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + ) + ) + + if not result: + raise ValueError("No deploy options found.") + + return result + + @staticmethod + def _format_concise_deploy_options( + options: list[types.DeployOption], + ) -> str: + """Formats deploy options into a human-readable string. + + Mirrors the legacy ``vertexai.model_garden`` SDK output: each option is + rendered as a ``[Option N: ]`` block followed by its non-null + fields (container image, machine type, accelerator type/count). + + Args: + options: The deploy options to format. + + Returns: + A human-readable, multi-line string describing the deploy options. + """ + fields = [ + "serving_container_image_uri", + "machine_type", + "accelerator_type", + "accelerator_count", + ] + blocks = [] + for i, option in enumerate(options): + if option.option_name: + header = f"[Option {i + 1}: {option.option_name}]\n" + else: + header = f"[Option {i + 1}]\n" + lines = [] + for field in fields: + value = getattr(option, field) + if value is None: + continue + if field == "accelerator_count": + lines.append(f" {field}={value},") + else: + lines.append(f' {field}="{value}",') + blocks.append(header + "\n".join(lines)) + return "\n\n".join(blocks) + + def _list_all_publisher_models( + self, + api_config: types.ListPublisherModelsConfig, + ) -> list[types.PublisherModel]: + """Fetches all pages of publisher models from the API. + + Args: + api_config: The configuration for the ListPublisherModels API call, + including filter and version settings. 
+ + Returns: + A list of all ``PublisherModel`` objects across all pages. + """ + all_models = [] + page_token = None + while True: + if page_token: api_config = types.ListPublisherModelsConfig( - filter=filter_str, - list_all_versions=True, - ) - - models = self._list_all_publisher_models(api_config) - - if deployable_only: - # The VERIFIED_DEPLOYMENT_SUCCEED label filter is only applied - # server-side for HF models. For native models, the server returns all - # models and we must filter client-side by checking for - # multi_deploy_vertex configs. - models = [m for m in models if self._has_deploy_config(m)] - - return [self._format_model_name(m, include_hf) for m in models] - - def list_deployable_models( - self, - config: Optional[types.ListDeployableModelsConfigOrDict] = None, - ) -> list[str]: - """Lists models in Model Garden that support deployment. - - Returns models that have at least one verified deployment configuration. - When ``include_hugging_face_models`` is False (the default), - HuggingFace models are excluded from the results. - - Args: - config: Optional configuration for filtering results. Accepts a - ``ListDeployableModelsConfig`` instance or an equivalent dict. - - Returns: - A list of model name strings in the format - ``'{publisher}/{model}@{version}'`` (e.g. ``'google/gemma2@gemma-2-2b-it'``) - or ``'{publisher}/{model}'`` when ``include_hugging_face_models`` is True - (e.g. ``'meta-llama/Llama-3.3-70B-Instruct'``). - """ - if config is None: - config = types.ListDeployableModelsConfig() - if isinstance(config, dict): - config = types.ListDeployableModelsConfig.model_validate(config) - - return self._list( - config.model_filter, - config.include_hugging_face_models, - deployable_only=True, - ) - - def list_models( - self, - config: Optional[types.ListModelGardenModelsConfigOrDict] = None, - ) -> list[str]: - """Lists all models available in Model Garden. - - Returns all models regardless of deployment support. 
When - ``include_hugging_face_models`` is False (the default), HuggingFace - models are excluded from the results. - - Args: - config: Optional configuration for filtering results. Accepts a - ``ListModelGardenModelsConfig`` instance or an equivalent dict. - - Returns: - A list of model name strings in the format - ``'{publisher}/{model}@{version}'`` (e.g. ``'google/gemma2@gemma-2-2b-it'``) - or ``'{publisher}/{model}'`` when ``include_hugging_face_models`` is True - (e.g. ``'meta-llama/Llama-3.3-70B-Instruct'``). - """ - if config is None: - config = types.ListModelGardenModelsConfig() - if isinstance(config, dict): - config = types.ListModelGardenModelsConfig.model_validate(config) - - return self._list( - config.model_filter, - config.include_hugging_face_models, - deployable_only=False, + filter=api_config.filter, + list_all_versions=api_config.list_all_versions, + page_token=page_token, ) + response = self._list_publisher_models( + parent="publishers/*", + config=api_config, + ) + all_models.extend(response.publisher_models or []) + page_token = response.next_page_token + if not page_token: + break + return all_models + + def _list( + self, + model_filter: Optional[str], + include_hugging_face_models: Optional[bool], + deployable_only: bool, + ) -> list[str]: + """Shared implementation for listing models. + + Args: + model_filter: Optional substring to filter models by. + include_hugging_face_models: Whether to include HuggingFace models. + deployable_only: If True, only return models with deployment configs. + + Returns: + A list of formatted model name strings. 
def list_deployable_models(
    self,
    config: Optional[types.ListDeployableModelsConfigOrDict] = None,
) -> list[str]:
    """Lists models in Model Garden that support deployment.

    Delegates to ``self._list`` with ``deployable_only=True``, passing the
    ``model_filter`` and ``include_hugging_face_models`` values from the
    config.

    Args:
        config: Optional configuration for filtering results. Accepts a
            ``ListDeployableModelsConfig`` instance or an equivalent dict;
            when omitted, a default ``ListDeployableModelsConfig`` is used.

    Returns:
        A list of model name strings in the format
        ``'{publisher}/{model}@{version}'`` (e.g.
        ``'google/gemma2@gemma-2-2b-it'``), or ``'{publisher}/{model}'`` when
        ``include_hugging_face_models`` is True (e.g.
        ``'meta-llama/Llama-3.3-70B-Instruct'``).
    """
    if config is None:
        settings = types.ListDeployableModelsConfig()
    elif isinstance(config, dict):
        settings = types.ListDeployableModelsConfig.model_validate(config)
    else:
        settings = config

    return self._list(
        settings.model_filter,
        settings.include_hugging_face_models,
        deployable_only=True,
    )
def list_publisher_model_deploy_options(
    self,
    model: str,
    config: Optional[types.ListPublisherModelDeployOptionsConfigOrDict] = None,
) -> Union[str, list[types.DeployOption]]:
    """Lists the verified deploy options for a Model Garden publisher model.

    Supports Google open models (e.g. ``'google/gemma3@gemma-3-12b-it'``),
    partner publisher models (e.g.
    ``'mistralai/mistral-7b@mistral-7b-instruct-v0.2'``), and Hugging Face
    model IDs (e.g. ``'meta-llama/Llama-3.3-70B-Instruct'``).

    Args:
        model: The publisher model to list deploy options for. Accepts the
            full resource name
            ``'publishers/{publisher}/models/{model}@{version}'``, a
            simplified ``'{publisher}/{model}@{version}'`` (or without the
            ``@{version}``), or a Hugging Face model ID
            ``'{organization}/{model}'``.
        config: Optional configuration for filtering the deploy options.
            Accepts a ``ListPublisherModelDeployOptionsConfig`` instance or
            an equivalent dict.

    Returns:
        A list of ``DeployOption`` objects (container image URI, machine
        type, accelerator type and count). If ``config.concise`` is True,
        returns a human-readable string describing those options instead.

    Raises:
        ValueError: If ``model`` is not a valid publisher model name, if the
            model does not support deployment, or if no deploy options match
            the provided filters.
    """
    if config is None:
        settings = types.ListPublisherModelDeployOptionsConfig()
    elif isinstance(config, dict):
        settings = types.ListPublisherModelDeployOptionsConfig.model_validate(
            config
        )
    else:
        settings = config

    # The Hugging Face shape check runs on the caller's raw string, before
    # the name is normalized.
    looks_like_hugging_face = self._is_hugging_face_model(model)
    fetch_config = types.GetPublisherModelConfig(
        is_hugging_face_model=looks_like_hugging_face,
        include_equivalent_model_garden_model_deployment_configs=True,
    )
    resource_name = self._reconcile_model_name(model)
    publisher_model = self._get_publisher_model(
        name=resource_name,
        config=fetch_config,
    )

    deploy_options = self._extract_and_filter_deploy_options(
        publisher_model,
        machine_type_filter=settings.machine_type_filter,
        accelerator_type_filter=settings.accelerator_type_filter,
        serving_container_image_uri_filter=(
            settings.serving_container_image_uri_filter
        ),
    )

    if settings.concise is True:
        return self._format_concise_deploy_options(deploy_options)
    return deploy_options
- ) - else: - request_dict = _ListPublisherModelsRequestParameters_to_vertex( - parameter_model - ) - request_url_dict = request_dict.get("_url") - if request_url_dict: - path = "{parent}/models".format_map(request_url_dict) - else: - path = "{parent}/models" - - query_params = request_dict.get("_query") - if query_params: - path = f"{path}?{urlencode(query_params)}" - # TODO: remove the hack that pops config. - request_dict.pop("config", None) - - http_options: Optional[types.HttpOptions] = None - if ( - parameter_model.config is not None - and parameter_model.config.http_options is not None - ): - http_options = parameter_model.config.http_options - - request_dict = _common.convert_to_dict(request_dict) - request_dict = _common.encode_unserializable_types(request_dict) - - response = await self._api_client.async_request( - "get", path, request_dict, http_options - ) - - response_dict = {} if not response.body else json.loads(response.body) - - return_value = types.ListPublisherModelsResponse._from_response( - response=response_dict, - kwargs=( - { - "config": { - "response_schema": getattr( - parameter_model.config, "response_schema", None - ), - "response_json_schema": getattr( - parameter_model.config, "response_json_schema", None - ), - "include_all_fields": getattr( - parameter_model.config, "include_all_fields", None - ), - } + """Model Garden module.""" + + async def _list_publisher_models( + self, + *, + parent: Optional[str] = None, + config: Optional[types.ListPublisherModelsConfigOrDict] = None, + ) -> types.ListPublisherModelsResponse: + """Lists publisher models (internal).""" + + parameter_model = types._ListPublisherModelsRequestParameters( + parent=parent, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in Gemini Enterprise Agent Platform" + " mode, not in Gemini Developer API mode." 
+ ) + else: + request_dict = _ListPublisherModelsRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{parent}/models".format_map(request_url_dict) + else: + path = "{parent}/models" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ListPublisherModelsResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), } - if getattr(parameter_model, "config", None) - else {} - ), - ) - - self._api_client._verify_response(return_value) - return return_value - - async def _list_all_publisher_models( - self, - api_config: types.ListPublisherModelsConfig, - ) -> list[types.PublisherModel]: - """Fetches all pages of publisher models from the API. - - Args: - api_config: The configuration for the ListPublisherModels API call, - including filter and version settings. - - Returns: - A list of all ``PublisherModel`` objects across all pages. 
- """ - all_models = [] - page_token = None - while True: - if page_token: - api_config = types.ListPublisherModelsConfig( - filter=api_config.filter, - list_all_versions=api_config.list_all_versions, - page_token=page_token, - ) - response = await self._list_publisher_models( - parent="publishers/*", - config=api_config, - ) - all_models.extend(response.publisher_models or []) - page_token = response.next_page_token - if not page_token: - break - return all_models - - async def _list( - self, - model_filter: Optional[str], - include_hugging_face_models: Optional[bool], - deployable_only: bool, - ) -> list[str]: - """Shared implementation for listing models. - - Args: - model_filter: Optional substring to filter models by. - include_hugging_face_models: Whether to include HuggingFace models. - deployable_only: If True, only return models with deployment configs. - - Returns: - A list of formatted model name strings. - """ - include_hf = include_hugging_face_models is True - - filter_str = ModelGarden._build_filter_str( - model_filter, include_hf, deployable_only=deployable_only - ) - + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _get_publisher_model( + self, + *, + name: str, + config: Optional[types.GetPublisherModelConfigOrDict] = None, + ) -> types.PublisherModel: + """Gets a publisher model (internal).""" + + parameter_model = types._GetPublisherModelRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in Gemini Enterprise Agent Platform" + " mode, not in Gemini Developer API mode." 
+ ) + else: + request_dict = _GetPublisherModelRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.PublisherModel._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _list_all_publisher_models( + self, + api_config: types.ListPublisherModelsConfig, + ) -> list[types.PublisherModel]: + """Fetches all pages of publisher models from the API. + + Args: + api_config: The configuration for the ListPublisherModels API call, + including filter and version settings. + + Returns: + A list of all ``PublisherModel`` objects across all pages. 
+ """ + all_models = [] + page_token = None + while True: + if page_token: api_config = types.ListPublisherModelsConfig( - filter=filter_str, - list_all_versions=True, - ) - - models = await self._list_all_publisher_models(api_config) - - if deployable_only: - models = [m for m in models if ModelGarden._has_deploy_config(m)] - - return [ModelGarden._format_model_name(m, include_hf) for m in models] - - async def list_deployable_models( - self, - config: Optional[types.ListDeployableModelsConfigOrDict] = None, - ) -> list[str]: - """Lists models in Model Garden that support deployment. - - Returns models that have at least one verified deployment configuration. - When ``include_hugging_face_models`` is False (the default), - HuggingFace models are excluded from the results. - - Args: - config: Optional configuration for filtering results. Accepts a - ``ListDeployableModelsConfig`` instance or an equivalent dict. - - Returns: - A list of model name strings in the format - ``'{publisher}/{model}@{version}'`` (e.g. ``'google/gemma2@gemma-2-2b-it'``) - or ``'{publisher}/{model}'`` when ``include_hugging_face_models`` is True - (e.g. ``'meta-llama/Llama-3.3-70B-Instruct'``). - """ - if config is None: - config = types.ListDeployableModelsConfig() - if isinstance(config, dict): - config = types.ListDeployableModelsConfig.model_validate(config) - - return await self._list( - config.model_filter, - config.include_hugging_face_models, - deployable_only=True, - ) - - async def list_models( - self, - config: Optional[types.ListModelGardenModelsConfigOrDict] = None, - ) -> list[str]: - """Lists all models available in Model Garden. - - Returns all models regardless of deployment support. When - ``include_hugging_face_models`` is False (the default), HuggingFace - models are excluded from the results. - - Args: - config: Optional configuration for filtering results. Accepts a - ``ListModelGardenModelsConfig`` instance or an equivalent dict. 
- - Returns: - A list of model name strings in the format - ``'{publisher}/{model}@{version}'`` (e.g. ``'google/gemma2@gemma-2-2b-it'``) - or ``'{publisher}/{model}'`` when ``include_hugging_face_models`` is True - (e.g. ``'meta-llama/Llama-3.3-70B-Instruct'``). - """ - if config is None: - config = types.ListModelGardenModelsConfig() - if isinstance(config, dict): - config = types.ListModelGardenModelsConfig.model_validate(config) - - return await self._list( - config.model_filter, - config.include_hugging_face_models, - deployable_only=False, + filter=api_config.filter, + list_all_versions=api_config.list_all_versions, + page_token=page_token, ) + response = await self._list_publisher_models( + parent="publishers/*", + config=api_config, + ) + all_models.extend(response.publisher_models or []) + page_token = response.next_page_token + if not page_token: + break + return all_models + + async def _list( + self, + model_filter: Optional[str], + include_hugging_face_models: Optional[bool], + deployable_only: bool, + ) -> list[str]: + """Shared implementation for listing models. + + Args: + model_filter: Optional substring to filter models by. + include_hugging_face_models: Whether to include HuggingFace models. + deployable_only: If True, only return models with deployment configs. + + Returns: + A list of formatted model name strings. 
+ """ + include_hf = include_hugging_face_models is True + + filter_str = ModelGarden._build_filter_str( + model_filter, include_hf, deployable_only=deployable_only + ) + + api_config = types.ListPublisherModelsConfig( + filter=filter_str, + list_all_versions=True, + ) + + models = await self._list_all_publisher_models(api_config) + + if deployable_only: + models = [m for m in models if ModelGarden._has_deploy_config(m)] + + return [ModelGarden._format_model_name(m, include_hf) for m in models] + + async def list_deployable_models( + self, + config: Optional[types.ListDeployableModelsConfigOrDict] = None, + ) -> list[str]: + """Lists models in Model Garden that support deployment. + + Returns models that have at least one verified deployment configuration. + When ``include_hugging_face_models`` is False (the default), + HuggingFace models are excluded from the results. + + Args: + config: Optional configuration for filtering results. Accepts a + ``ListDeployableModelsConfig`` instance or an equivalent dict. + + Returns: + A list of model name strings in the format + ``'{publisher}/{model}@{version}'`` (e.g. + ``'google/gemma2@gemma-2-2b-it'``) + or ``'{publisher}/{model}'`` when ``include_hugging_face_models`` is True + (e.g. ``'meta-llama/Llama-3.3-70B-Instruct'``). + """ + if config is None: + config = types.ListDeployableModelsConfig() + if isinstance(config, dict): + config = types.ListDeployableModelsConfig.model_validate(config) + + return await self._list( + config.model_filter, + config.include_hugging_face_models, + deployable_only=True, + ) + + async def list_models( + self, + config: Optional[types.ListModelGardenModelsConfigOrDict] = None, + ) -> list[str]: + """Lists all models available in Model Garden. + + Returns all models regardless of deployment support. When + ``include_hugging_face_models`` is False (the default), HuggingFace + models are excluded from the results. + + Args: + config: Optional configuration for filtering results. 
Accepts a + ``ListModelGardenModelsConfig`` instance or an equivalent dict. + + Returns: + A list of model name strings in the format + ``'{publisher}/{model}@{version}'`` (e.g. + ``'google/gemma2@gemma-2-2b-it'``) + or ``'{publisher}/{model}'`` when ``include_hugging_face_models`` is True + (e.g. ``'meta-llama/Llama-3.3-70B-Instruct'``). + """ + if config is None: + config = types.ListModelGardenModelsConfig() + if isinstance(config, dict): + config = types.ListModelGardenModelsConfig.model_validate(config) + + return await self._list( + config.model_filter, + config.include_hugging_face_models, + deployable_only=False, + ) + + async def list_publisher_model_deploy_options( + self, + model: str, + config: Optional[ + types.ListPublisherModelDeployOptionsConfigOrDict + ] = None, + ) -> Union[str, list[types.DeployOption]]: + """Lists the verified deploy options for a Model Garden publisher model. + + Supports Google open models (e.g. ``'google/gemma3@gemma-3-12b-it'``), + partner publisher models (e.g. + ``'mistralai/mistral-7b@mistral-7b-instruct-v0.2'``), and Hugging Face + model IDs (e.g. ``'meta-llama/Llama-3.3-70B-Instruct'``). + + Args: + model: The publisher model to list deploy options for. Accepts the full + resource name ``'publishers/{publisher}/models/{model}@{version}'``, a + simplified ``'{publisher}/{model}@{version}'`` (or without the + ``@{version}``), or a Hugging Face model ID + ``'{organization}/{model}'``. + config: Optional configuration for filtering the deploy options. Accepts a + ``ListPublisherModelDeployOptionsConfig`` instance or an equivalent + dict. + + Returns: + A list of ``DeployOption`` objects, one per verified deployment + configuration (container image URI, machine type, accelerator type and + count). If ``config.concise`` is True, returns a human-readable string + describing those deploy options instead. 
+ + Raises: + ValueError: If ``model`` is not a valid publisher model name, if the + model does not support deployment, or if no deploy options match the + provided filters. + """ + if config is None: + config = types.ListPublisherModelDeployOptionsConfig() + if isinstance(config, dict): + config = types.ListPublisherModelDeployOptionsConfig.model_validate( + config + ) + + api_config = types.GetPublisherModelConfig( + is_hugging_face_model=ModelGarden._is_hugging_face_model(model), + include_equivalent_model_garden_model_deployment_configs=True, + ) + publisher_model = await self._get_publisher_model( + name=ModelGarden._reconcile_model_name(model), config=api_config + ) + + options = ModelGarden._extract_and_filter_deploy_options( + publisher_model, + machine_type_filter=config.machine_type_filter, + accelerator_type_filter=config.accelerator_type_filter, + serving_container_image_uri_filter=config.serving_container_image_uri_filter, + ) + + if config.concise is True: + return ModelGarden._format_concise_deploy_options(options) + + return options diff --git a/agentplatform/_genai/types/__init__.py b/agentplatform/_genai/types/__init__.py index 134f7edd98..6085d22691 100644 --- a/agentplatform/_genai/types/__init__.py +++ b/agentplatform/_genai/types/__init__.py @@ -95,6 +95,7 @@ from .common import _GetImportFilesOperationParameters from .common import _GetMultimodalDatasetOperationParameters from .common import _GetMultimodalDatasetParameters +from .common import _GetPublisherModelRequestParameters from .common import _GetRagConfigOperationParameters from .common import _GetRagConfigRequestParameters from .common import _GetRagCorpusRequestParameters @@ -470,6 +471,9 @@ from .common import DeleteSkillOperation from .common import DeleteSkillOperationDict from .common import DeleteSkillOperationOrDict +from .common import DeployOption +from .common import DeployOptionDict +from .common import DeployOptionOrDict from .common import DirectUploadSource from .common 
import DirectUploadSourceDict from .common import DirectUploadSourceOrDict @@ -723,6 +727,9 @@ from .common import GetPromptConfig from .common import GetPromptConfigDict from .common import GetPromptConfigOrDict +from .common import GetPublisherModelConfig +from .common import GetPublisherModelConfigDict +from .common import GetPublisherModelConfigOrDict from .common import GetRagConfig from .common import GetRagConfigDict from .common import GetRagConfigOperationConfig @@ -875,6 +882,9 @@ from .common import ListPromptsConfig from .common import ListPromptsConfigDict from .common import ListPromptsConfigOrDict +from .common import ListPublisherModelDeployOptionsConfig +from .common import ListPublisherModelDeployOptionsConfigDict +from .common import ListPublisherModelDeployOptionsConfigOrDict from .common import ListPublisherModelsConfig from .common import ListPublisherModelsConfigDict from .common import ListPublisherModelsConfigOrDict @@ -3454,6 +3464,9 @@ "ListPublisherModelsResponse", "ListPublisherModelsResponseDict", "ListPublisherModelsResponseOrDict", + "GetPublisherModelConfig", + "GetPublisherModelConfigDict", + "GetPublisherModelConfigOrDict", "PromptOptimizerConfig", "PromptOptimizerConfigDict", "PromptOptimizerConfigOrDict", @@ -3553,6 +3566,12 @@ "ListModelGardenModelsConfig", "ListModelGardenModelsConfigDict", "ListModelGardenModelsConfigOrDict", + "ListPublisherModelDeployOptionsConfig", + "ListPublisherModelDeployOptionsConfigDict", + "ListPublisherModelDeployOptionsConfigOrDict", + "DeployOption", + "DeployOptionDict", + "DeployOptionOrDict", "A2aTaskState", "State", "Strategy", @@ -3736,6 +3755,7 @@ "_GetSkillRevisionRequestParameters", "_ListSkillRevisionsRequestParameters", "_ListPublisherModelsRequestParameters", + "_GetPublisherModelRequestParameters", "evals", "agent_engines", "prompts", diff --git a/agentplatform/_genai/types/common.py b/agentplatform/_genai/types/common.py index 2a8b150772..a7fb696770 100644 --- 
a/agentplatform/_genai/types/common.py +++ b/agentplatform/_genai/types/common.py @@ -23797,6 +23797,71 @@ class ListPublisherModelsResponseDict(TypedDict, total=False): ] +class GetPublisherModelConfig(_common.BaseModel): + """Config for getting a publisher model.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + hugging_face_token: Optional[str] = Field( + default=None, + description="""Optional. Hugging Face access token for gated models.""", + ) + include_equivalent_model_garden_model_deployment_configs: Optional[bool] = Field( + default=None, + description="""Optional. Whether to include the deploy options of equivalent + Model Garden models.""", + ) + is_hugging_face_model: Optional[bool] = Field( + default=None, + description="""Optional. Whether the requested model is a Hugging Face model.""", + ) + + +class GetPublisherModelConfigDict(TypedDict, total=False): + """Config for getting a publisher model.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + hugging_face_token: Optional[str] + """Optional. Hugging Face access token for gated models.""" + + include_equivalent_model_garden_model_deployment_configs: Optional[bool] + """Optional. Whether to include the deploy options of equivalent + Model Garden models.""" + + is_hugging_face_model: Optional[bool] + """Optional. 
Whether the requested model is a Hugging Face model.""" + + +GetPublisherModelConfigOrDict = Union[ + GetPublisherModelConfig, GetPublisherModelConfigDict +] + + +class _GetPublisherModelRequestParameters(_common.BaseModel): + """Parameters for getting a publisher model.""" + + name: Optional[str] = Field(default=None, description="""""") + config: Optional[GetPublisherModelConfig] = Field(default=None, description="""""") + + +class _GetPublisherModelRequestParametersDict(TypedDict, total=False): + """Parameters for getting a publisher model.""" + + name: Optional[str] + """""" + + config: Optional[GetPublisherModelConfigDict] + """""" + + +_GetPublisherModelRequestParametersOrDict = Union[ + _GetPublisherModelRequestParameters, _GetPublisherModelRequestParametersDict +] + + class PromptOptimizerConfig(_common.BaseModel): """VAPO Prompt Optimizer Config.""" @@ -25843,3 +25908,106 @@ class ListModelGardenModelsConfigDict(TypedDict, total=False): ListModelGardenModelsConfigOrDict = Union[ ListModelGardenModelsConfig, ListModelGardenModelsConfigDict ] + + +class ListPublisherModelDeployOptionsConfig(_common.BaseModel): + """Config for listing the deploy options of a publisher model.""" + + machine_type_filter: Optional[Union[str, list[str]]] = Field( + default=None, + description="""Optional. Case-insensitive substring filter on the machine type. + Accepts a single keyword (e.g. ``'g2'``) or a list of keywords (e.g. + ``['n1', 'g2']``); an option matches if it contains any of them.""", + ) + accelerator_type_filter: Optional[Union[str, list[str]]] = Field( + default=None, + description="""Optional. Case-insensitive substring filter on the accelerator + type. Accepts a single keyword (e.g. ``'L4'``) or a list of keywords + (e.g. ``['T4', 'L4']``); an option matches if it contains any of them.""", + ) + serving_container_image_uri_filter: Optional[Union[str, list[str]]] = Field( + default=None, + description="""Optional. 
Case-insensitive substring filter on the serving + container image URI. Accepts a single keyword (e.g. ``'vllm'``) or a list + of keywords (e.g. ``['vllm', 'tgi']``); an option matches if it contains + any of them.""", + ) + concise: Optional[bool] = Field( + default=None, + description="""Optional. If True, returns a human-readable string describing the + deploy options (container and machine specs) instead of a list of + ``DeployOption`` objects.""", + ) + + +class ListPublisherModelDeployOptionsConfigDict(TypedDict, total=False): + """Config for listing the deploy options of a publisher model.""" + + machine_type_filter: Optional[Union[str, list[str]]] + """Optional. Case-insensitive substring filter on the machine type. + Accepts a single keyword (e.g. ``'g2'``) or a list of keywords (e.g. + ``['n1', 'g2']``); an option matches if it contains any of them.""" + + accelerator_type_filter: Optional[Union[str, list[str]]] + """Optional. Case-insensitive substring filter on the accelerator + type. Accepts a single keyword (e.g. ``'L4'``) or a list of keywords + (e.g. ``['T4', 'L4']``); an option matches if it contains any of them.""" + + serving_container_image_uri_filter: Optional[Union[str, list[str]]] + """Optional. Case-insensitive substring filter on the serving + container image URI. Accepts a single keyword (e.g. ``'vllm'``) or a list + of keywords (e.g. ``['vllm', 'tgi']``); an option matches if it contains + any of them.""" + + concise: Optional[bool] + """Optional. 
If True, returns a human-readable string describing the + deploy options (container and machine specs) instead of a list of + ``DeployOption`` objects.""" + + +ListPublisherModelDeployOptionsConfigOrDict = Union[ + ListPublisherModelDeployOptionsConfig, + ListPublisherModelDeployOptionsConfigDict, +] + + +class DeployOption(_common.BaseModel): + """A verified deploy option for a model.""" + + option_name: Optional[str] = Field( + default=None, description="""The name of the deploy task.""" + ) + serving_container_image_uri: Optional[str] = Field( + default=None, description="""The URI of the serving container.""" + ) + machine_type: Optional[str] = Field( + default=None, description="""The machine type.""" + ) + accelerator_type: Optional[str] = Field( + default=None, description="""The accelerator type.""" + ) + accelerator_count: Optional[int] = Field( + default=None, description="""The number of accelerators.""" + ) + + +class DeployOptionDict(TypedDict, total=False): + """A verified deploy option for a model.""" + + option_name: Optional[str] + """The name of the deploy task.""" + + serving_container_image_uri: Optional[str] + """The URI of the serving container.""" + + machine_type: Optional[str] + """The machine type.""" + + accelerator_type: Optional[str] + """The accelerator type.""" + + accelerator_count: Optional[int] + """The number of accelerators.""" + + +DeployOptionOrDict = Union[DeployOption, DeployOptionDict] diff --git a/tests/unit/agentplatform/genai/replays/test_genai_model_garden.py b/tests/unit/agentplatform/genai/replays/test_genai_model_garden.py index b0f5b97958..9cebf7d6bb 100644 --- a/tests/unit/agentplatform/genai/replays/test_genai_model_garden.py +++ b/tests/unit/agentplatform/genai/replays/test_genai_model_garden.py @@ -34,16 +34,72 @@ def test_list_deployable_models(client): def test_list_models(client): - """Tests listing all baseline models in Model Garden.""" - models = client.model_garden.list_models( + """Tests listing all 
baseline models in Model Garden.""" + models = client.model_garden.list_models( config=types.ListModelGardenModelsConfig( include_hugging_face_models=False, model_filter="timesfm", ) ) - assert len(models) > 0 - assert isinstance(models[0], str) - assert "timesfm" in models[0].lower() + assert len(models) > 0 + assert isinstance(models[0], str) + assert "timesfm" in models[0].lower() + + +def test_list_publisher_model_deploy_options(client): + """Tests listing the verified deploy options for an open model.""" + options = client.model_garden.list_publisher_model_deploy_options( + model="google/gemma3@gemma-3-12b-it" + ) + assert len(options) > 0 + assert isinstance(options[0], types.DeployOption) + # Every verified deploy option exposes a serving container image. + assert options[0].serving_container_image_uri + + +def test_list_publisher_model_deploy_options_with_filter(client): + """Tests filtering an open model's deploy options by accelerator type.""" + options = client.model_garden.list_publisher_model_deploy_options( + model="google/gemma3@gemma-3-12b-it", + config=types.ListPublisherModelDeployOptionsConfig( + accelerator_type_filter="NVIDIA" + ), + ) + assert len(options) > 0 + for option in options: + assert "NVIDIA" in (option.accelerator_type or "") + + +def test_list_publisher_model_deploy_options_concise(client): + """Tests the concise (human-readable string) output for an open model.""" + options = client.model_garden.list_publisher_model_deploy_options( + model="google/gemma3@gemma-3-12b-it", + config=types.ListPublisherModelDeployOptionsConfig(concise=True), + ) + assert isinstance(options, str) + assert "[Option 1" in options + + +def test_list_publisher_model_deploy_options_hugging_face(client): + """Tests deploy options for a Hugging Face model. + + Exercises the distinct GetPublisherModel request path where + is_hugging_face_model=True is sent. 
+ """ + options = client.model_garden.list_publisher_model_deploy_options( + model="codellama/codellama-7b-hf" + ) + assert len(options) > 0 + assert isinstance(options[0], types.DeployOption) + assert options[0].serving_container_image_uri + + +def test_list_publisher_model_deploy_options_no_deploy_support(client): + """Tests a model with no verified deployment config raises ValueError.""" + with pytest.raises(ValueError, match="does not support deployment"): + client.model_garden.list_publisher_model_deploy_options( + model="google/gemini-embedding-001@default" + ) pytestmark = pytest_helper.setup( @@ -80,3 +136,14 @@ async def test_list_models_async(client): assert len(models) > 0 assert isinstance(models[0], str) assert "timesfm" in models[0].lower() + + +@pytest.mark.asyncio +async def test_list_publisher_model_deploy_options_async(client): + """Tests listing the deploy options for an open model asynchronously.""" + options = await client.aio.model_garden.list_publisher_model_deploy_options( + model="google/gemma3@gemma-3-12b-it" + ) + assert len(options) > 0 + assert isinstance(options[0], types.DeployOption) + assert options[0].serving_container_image_uri diff --git a/tests/unit/agentplatform/genai/test_genai_model_garden.py b/tests/unit/agentplatform/genai/test_genai_model_garden.py index 100b5edbb8..fe2b054089 100644 --- a/tests/unit/agentplatform/genai/test_genai_model_garden.py +++ b/tests/unit/agentplatform/genai/test_genai_model_garden.py @@ -14,6 +14,7 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring +import asyncio from unittest import mock from agentplatform._genai import model_garden from agentplatform._genai import types @@ -34,6 +35,16 @@ def mock_client(): return model_garden.ModelGarden(mock_api_client) +@pytest.fixture +def mock_async_client(): + mock_api_client = mock.Mock(spec=client.Client) + mock_api_client.project = _TEST_PROJECT + mock_api_client.location = _TEST_LOCATION + mock_api_client.vertexai = 
True + + return model_garden.AsyncModelGarden(mock_api_client) + + def _make_deployable_model(name, version_id="001"): """Helper to create a PublisherModel with multi_deploy_vertex support.""" return types.PublisherModel( @@ -405,10 +416,738 @@ def test_build_filter_str_with_model_filter(): def test_build_filter_str_escapes_special_chars(): - """Tests that special regex characters in model_filter are escaped.""" - build_filter = model_garden.ModelGarden._build_filter_str - result = build_filter( + """Tests that special regex characters in model_filter are escaped.""" + build_filter = model_garden.ModelGarden._build_filter_str + result = build_filter( model_filter="model.v2+", include_hugging_face_models=False, deployable_only=False ) - # re.escape turns '.' into '\\.' and '+' into '\\+' - assert r"model\.v2\+" in result + # re.escape turns '.' into '\\.' and '+' into '\\+' + assert r"model\.v2\+" in result + + +# ---- list_publisher_model_deploy_options tests ---- + + +def _make_deploy_option( + deploy_task_name="option-1", + machine_type="g2-standard-12", + accelerator_type="NVIDIA_L4", + accelerator_count=1, + image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/vllm", +): + """Builds a single PublisherModelCallToActionDeploy option.""" + return types.PublisherModelCallToActionDeploy( + deploy_task_name=deploy_task_name, + dedicated_resources=types.DedicatedResources( + machine_spec=types.MachineSpec( + machine_type=machine_type, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + ) + ), + container_spec=types.ModelContainerSpec(image_uri=image_uri), + ) + + +def _make_model_with_deploy_options( + options, name="publishers/google/models/gemma-2b" +): + """Builds a PublisherModel exposing the given deploy options.""" + return types.PublisherModel( + name=name, + supported_actions=types.PublisherModelCallToAction( + multi_deploy_vertex=types.PublisherModelCallToActionDeployVertex( + multi_deploy_vertex=options, + ) + 
), + ) + + +def test_list_publisher_model_deploy_options_basic(mock_client): + """Tests extraction of a single deploy option into a DeployOption.""" + dummy_model = _make_model_with_deploy_options([_make_deploy_option()]) + + with mock.patch.object( + mock_client, "_get_publisher_model", return_value=dummy_model + ) as mock_get: + # A simplified name is reconciled to the full publisher resource name. + options = mock_client.list_publisher_model_deploy_options( + model="google/gemma3@gemma-3-12b-it" + ) + + # The GetPublisherModel call must use the reconciled name, flag the model + # as non-Hugging-Face (it has an @version), and request equivalent + # deployment configs. + mock_get.assert_called_once_with( + name="publishers/google/models/gemma3@gemma-3-12b-it", + config=types.GetPublisherModelConfig( + is_hugging_face_model=False, + include_equivalent_model_garden_model_deployment_configs=True, + ), + ) + assert len(options) == 1 + assert isinstance(options[0], types.DeployOption) + assert options[0].option_name == "option-1" + assert options[0].machine_type == "g2-standard-12" + # accelerator_type is returned as the enum's string value (legacy parity). 
+ assert options[0].accelerator_type == "NVIDIA_L4" + assert options[0].accelerator_count == 1 + assert "vllm" in options[0].serving_container_image_uri + + +def test_list_publisher_model_deploy_options_multiple(mock_client): + """Tests that all deploy options are returned when no filters are set.""" + dummy_model = _make_model_with_deploy_options( + [ + _make_deploy_option(deploy_task_name="g2", machine_type="g2-standard-12"), + _make_deploy_option(deploy_task_name="a3", machine_type="a3-highgpu-8g"), + ] + ) + + with mock.patch.object( + mock_client, "_get_publisher_model", return_value=dummy_model + ): + options = mock_client.list_publisher_model_deploy_options( + model="publishers/google/models/gemma-2b@001" + ) + + assert [o.option_name for o in options] == ["g2", "a3"] + + +def test_list_publisher_model_deploy_options_machine_type_filter(mock_client): + """Tests machine_type_filter is a case-insensitive substring match.""" + dummy_model = _make_model_with_deploy_options( + [ + _make_deploy_option(deploy_task_name="g2", machine_type="g2-standard-12"), + _make_deploy_option(deploy_task_name="a3", machine_type="a3-highgpu-8g"), + ] + ) + + with mock.patch.object( + mock_client, "_get_publisher_model", return_value=dummy_model + ): + options = mock_client.list_publisher_model_deploy_options( + model="publishers/google/models/gemma-2b@001", + config=types.ListPublisherModelDeployOptionsConfig( + machine_type_filter="G2" + ), + ) + + assert len(options) == 1 + assert options[0].option_name == "g2" + + +def test_list_publisher_model_deploy_options_accelerator_type_filter( + mock_client, +): + """Tests accelerator_type_filter is a case-insensitive substring match.""" + dummy_model = _make_model_with_deploy_options([ + _make_deploy_option(deploy_task_name="l4", accelerator_type="NVIDIA_L4"), + _make_deploy_option( + deploy_task_name="h100", accelerator_type="NVIDIA_H100_80GB" + ), + ]) + + with mock.patch.object( + mock_client, "_get_publisher_model", 
return_value=dummy_model + ): + options = mock_client.list_publisher_model_deploy_options( + model="publishers/google/models/gemma-2b@001", + config=types.ListPublisherModelDeployOptionsConfig( + accelerator_type_filter="h100" + ), + ) + + assert len(options) == 1 + assert options[0].option_name == "h100" + + +def test_list_publisher_model_deploy_options_image_uri_filter(mock_client): + """Tests serving_container_image_uri_filter is a case-insensitive match.""" + dummy_model = _make_model_with_deploy_options( + [ + _make_deploy_option(deploy_task_name="vllm", image_uri="docker/vllm"), + _make_deploy_option(deploy_task_name="tgi", image_uri="docker/tgi"), + ] + ) + + with mock.patch.object( + mock_client, "_get_publisher_model", return_value=dummy_model + ): + options = mock_client.list_publisher_model_deploy_options( + model="publishers/google/models/gemma-2b@001", + config=types.ListPublisherModelDeployOptionsConfig( + serving_container_image_uri_filter="VLLM" + ), + ) + + assert len(options) == 1 + assert options[0].option_name == "vllm" + + +def test_list_publisher_model_deploy_options_machine_type_filter_list( + mock_client, +): + """Tests a list of keywords matches options containing ANY of them (legacy parity).""" + dummy_model = _make_model_with_deploy_options([ + _make_deploy_option(deploy_task_name="g2", machine_type="g2-standard-12"), + _make_deploy_option(deploy_task_name="a3", machine_type="a3-highgpu-8g"), + _make_deploy_option(deploy_task_name="n1", machine_type="n1-standard-8"), + ]) + + with mock.patch.object( + mock_client, "_get_publisher_model", return_value=dummy_model + ): + options = mock_client.list_publisher_model_deploy_options( + model="publishers/google/models/gemma-2b@001", + config=types.ListPublisherModelDeployOptionsConfig( + machine_type_filter=["A3", "n1"] + ), + ) + + assert [o.option_name for o in options] == ["a3", "n1"] + + +def test_list_publisher_model_deploy_options_accelerator_type_filter_list( + mock_client, +): + """Tests 
a list accelerator filter matches options containing ANY keyword.""" + dummy_model = _make_model_with_deploy_options([ + _make_deploy_option(deploy_task_name="l4", accelerator_type="NVIDIA_L4"), + _make_deploy_option( + deploy_task_name="t4", accelerator_type="NVIDIA_TESLA_T4" + ), + _make_deploy_option( + deploy_task_name="h100", accelerator_type="NVIDIA_H100_80GB" + ), + ]) + + with mock.patch.object( + mock_client, "_get_publisher_model", return_value=dummy_model + ): + options = mock_client.list_publisher_model_deploy_options( + model="publishers/google/models/gemma-2b@001", + config=types.ListPublisherModelDeployOptionsConfig( + accelerator_type_filter=["T4", "L4"] + ), + ) + + assert [o.option_name for o in options] == ["l4", "t4"] + + +def test_list_publisher_model_deploy_options_image_uri_filter_list(mock_client): + """Tests a list image-uri filter matches any keyword.""" + dummy_model = _make_model_with_deploy_options( + [ + _make_deploy_option(deploy_task_name="vllm", image_uri="docker/vllm"), + _make_deploy_option(deploy_task_name="tgi", image_uri="docker/tgi"), + _make_deploy_option(deploy_task_name="sglang", image_uri="docker/sglang"), + ] + ) + + with mock.patch.object( + mock_client, "_get_publisher_model", return_value=dummy_model + ): + options = mock_client.list_publisher_model_deploy_options( + model="publishers/google/models/gemma-2b@001", + config=types.ListPublisherModelDeployOptionsConfig( + serving_container_image_uri_filter=["vllm", "tgi"] + ), + ) + + assert [o.option_name for o in options] == ["vllm", "tgi"] + + +def test_matches_filter(): + """Tests the keyword-filter helper (single keyword, list, None, misses).""" + matches = model_garden.ModelGarden._matches_filter + # No filter -> always matches. + assert matches("g2-standard-12", None) is True + assert matches(None, None) is True + # Single keyword, case-insensitive substring. 
+ assert matches("g2-standard-12", "G2") is True + assert matches("g2-standard-12", "n1") is False + # List of keywords -> match if ANY is contained. + assert matches("a3-highgpu-8g", ["n1", "a3"]) is True + assert matches("a3-highgpu-8g", ["n1", "g2"]) is False + # A missing (None) value never matches a non-empty filter. + assert matches(None, "g2") is False + # An empty list filter behaves like "no filter". + assert matches("anything", []) is True + + +def test_list_publisher_model_deploy_options_dict_config(mock_client): + """Tests config passed as a dict is validated and applied.""" + dummy_model = _make_model_with_deploy_options( + [ + _make_deploy_option(deploy_task_name="g2", machine_type="g2-standard-12"), + _make_deploy_option(deploy_task_name="a3", machine_type="a3-highgpu-8g"), + ] + ) + + with mock.patch.object( + mock_client, "_get_publisher_model", return_value=dummy_model + ): + options = mock_client.list_publisher_model_deploy_options( + model="publishers/google/models/gemma-2b@001", + config={"machine_type_filter": "a3"}, + ) + + assert len(options) == 1 + assert options[0].option_name == "a3" + + +def test_list_publisher_model_deploy_options_default_config(mock_client): + """Tests config=None returns all options.""" + dummy_model = _make_model_with_deploy_options([_make_deploy_option()]) + + with mock.patch.object( + mock_client, "_get_publisher_model", return_value=dummy_model + ): + options = mock_client.list_publisher_model_deploy_options( + model="publishers/google/models/gemma-2b@001" + ) + + assert len(options) == 1 + + +def test_list_publisher_model_deploy_options_no_deploy_support_raises( + mock_client, +): + """Tests ValueError when the model does not support deployment (legacy parity).""" + dummy_model = types.PublisherModel(name="publishers/google/models/bert-base") + + with mock.patch.object( + mock_client, "_get_publisher_model", return_value=dummy_model + ): + with pytest.raises(ValueError, match="does not support deployment"): + 
mock_client.list_publisher_model_deploy_options( + model="publishers/google/models/bert-base@001" + ) + + +def test_list_publisher_model_deploy_options_no_match_raises(mock_client): + """Tests ValueError when filters exclude every option (legacy parity).""" + dummy_model = _make_model_with_deploy_options([_make_deploy_option()]) + + with mock.patch.object( + mock_client, "_get_publisher_model", return_value=dummy_model + ): + with pytest.raises(ValueError, match="No deploy options found."): + mock_client.list_publisher_model_deploy_options( + model="publishers/google/models/gemma-2b@001", + config=types.ListPublisherModelDeployOptionsConfig( + machine_type_filter="does-not-exist" + ), + ) + + +def test_reconcile_model_name_simplified_with_version(): + """Tests a simplified name with a version is expanded to the full name.""" + reconcile = model_garden.ModelGarden._reconcile_model_name + assert ( + reconcile("google/gemma3@gemma-3-12b-it") + == "publishers/google/models/gemma3@gemma-3-12b-it" + ) + + +def test_reconcile_model_name_simplified_without_version(): + """Tests a simplified name without a version is expanded to the full name.""" + reconcile = model_garden.ModelGarden._reconcile_model_name + assert reconcile("google/gemma3") == "publishers/google/models/gemma3" + + +def test_reconcile_model_name_full_resource_name(): + """Tests a full resource name (with version) is returned normalized.""" + reconcile = model_garden.ModelGarden._reconcile_model_name + assert ( + reconcile("publishers/google/models/gemma3@gemma-3-12b-it") + == "publishers/google/models/gemma3@gemma-3-12b-it" + ) + + +def test_reconcile_model_name_full_resource_name_no_version(): + """Tests a full resource name WITHOUT a version is returned unchanged. + + Regression test: previously a versionless full name fell through to the + simplified branch and was mangled into 'publishers/publishers/models/...'. 
+ """ + reconcile = model_garden.ModelGarden._reconcile_model_name + assert ( + reconcile("publishers/google/models/gemma3") + == "publishers/google/models/gemma3" + ) + + +def test_reconcile_model_name_hugging_face_full_resource_name_no_version(): + """Tests a versionless Hugging Face full resource name is preserved.""" + reconcile = model_garden.ModelGarden._reconcile_model_name + assert ( + reconcile("publishers/hf-codellama/models/codellama-7b-hf") + == "publishers/hf-codellama/models/codellama-7b-hf" + ) + + +def test_reconcile_model_name_lowercases(): + """Tests names are lowercased (Hugging Face parity with legacy SDK).""" + reconcile = model_garden.ModelGarden._reconcile_model_name + assert ( + reconcile("Meta-Llama/Llama-3.3-70B-Instruct") + == "publishers/meta-llama/models/llama-3.3-70b-instruct" + ) + + +def test_reconcile_model_name_invalid_raises(): + """Tests an invalid name raises ValueError.""" + reconcile = model_garden.ModelGarden._reconcile_model_name + with pytest.raises(ValueError, match="not a valid publisher model name"): + reconcile("invalid-name-without-slash") + + +def test_reconcile_model_name_model_registry_raises(): + """Tests a Model Registry resource name raises ValueError. + + Without an explicit guard, ``projects/.../locations/.../models/...`` would + match the simplified ``{publisher}/{model}`` regex and be silently mangled + into ``publishers/projects/models/.../locations/.../models/...``. The guard + rejects it loudly with the same ``not a valid publisher model name`` + message used for any other unsupported input. 
+ """ + reconcile = model_garden.ModelGarden._reconcile_model_name + for name in ( + "projects/123/locations/us-central1/models/456", + "projects/my-project/locations/europe-west1/models/9876543210@1", + "projects/p/locations/l/models/m", + ): + with pytest.raises(ValueError, match="not a valid publisher model name"): + reconcile(name) + + +def test_is_hugging_face_model(): + """Tests the Hugging Face model heuristic.""" + is_hf = model_garden.ModelGarden._is_hugging_face_model + # Bare org/model (single slash, no @version) -> Hugging Face. + assert is_hf("meta-llama/Llama-3.3-70B-Instruct") is True + # Simplified native names without @version also match (handled + # correctly by _reconcile_model_name). + assert is_hf("google/gemma3") is True + # Names with @version or a publishers/ prefix are not Hugging Face. + assert is_hf("google/gemma3@gemma-3-12b-it") is False + assert is_hf("publishers/google/models/gemma3@gemma-3-12b-it") is False + assert is_hf("publishers/hf-meta-llama/models/llama-3.3") is False + + +def test_list_publisher_model_deploy_options_hugging_face_model(mock_client): + """Tests an HF model name sends is_hugging_face_model=True (legacy parity).""" + dummy_model = _make_model_with_deploy_options([_make_deploy_option()]) + + with mock.patch.object( + mock_client, "_get_publisher_model", return_value=dummy_model + ) as mock_get: + mock_client.list_publisher_model_deploy_options( + model="meta-llama/Llama-3.3-70B-Instruct" + ) + + mock_get.assert_called_once_with( + name="publishers/meta-llama/models/llama-3.3-70b-instruct", + config=types.GetPublisherModelConfig( + is_hugging_face_model=True, + include_equivalent_model_garden_model_deployment_configs=True, + ), + ) + + +def test_list_publisher_model_deploy_options_async(mock_async_client): + """Tests the async client returns deploy options.""" + dummy_model = _make_model_with_deploy_options([_make_deploy_option()]) + + with mock.patch.object( + mock_async_client, + "_get_publisher_model", + 
new=mock.AsyncMock(return_value=dummy_model), + ): + options = asyncio.run( + mock_async_client.list_publisher_model_deploy_options( + model="publishers/google/models/gemma-2b@001" + ) + ) + + assert len(options) == 1 + assert options[0].option_name == "option-1" + assert options[0].machine_type == "g2-standard-12" + + +def test_list_publisher_model_deploy_options_async_no_deploy_support_raises( + mock_async_client, +): + """Tests the async client raises when deployment is unsupported.""" + dummy_model = types.PublisherModel(name="publishers/google/models/bert-base") + + with mock.patch.object( + mock_async_client, + "_get_publisher_model", + new=mock.AsyncMock(return_value=dummy_model), + ): + with pytest.raises(ValueError, match="does not support deployment"): + asyncio.run( + mock_async_client.list_publisher_model_deploy_options( + model="publishers/google/models/bert-base@001" + ) + ) + + +def _make_deploy_option_no_accelerator( + deploy_task_name="tpu", + machine_type="ct5lp-hightpu-1t", + image_uri="docker/hexllm", +): + """Builds a deploy option whose machine has no GPU accelerator (e.g. 
TPU).""" + return types.PublisherModelCallToActionDeploy( + deploy_task_name=deploy_task_name, + dedicated_resources=types.DedicatedResources( + machine_spec=types.MachineSpec(machine_type=machine_type) + ), + container_spec=types.ModelContainerSpec(image_uri=image_uri), + ) + + +def test_list_publisher_model_deploy_options_no_accelerator_defaults( + mock_client, +): + """Tests no-accelerator machines report UNSPECIFIED/0 (legacy proto parity).""" + dummy_model = _make_model_with_deploy_options( + [_make_deploy_option_no_accelerator()] + ) + + with mock.patch.object( + mock_client, "_get_publisher_model", return_value=dummy_model + ): + options = mock_client.list_publisher_model_deploy_options( + model="publishers/google/models/gemma-2b@001" + ) + + assert options[0].machine_type == "ct5lp-hightpu-1t" + # Over gRPC the legacy SDK surfaces these proto3 defaults; we match them. + assert options[0].accelerator_type == "ACCELERATOR_TYPE_UNSPECIFIED" + assert options[0].accelerator_count == 0 + + +def test_list_publisher_model_deploy_options_no_accelerator_concise( + mock_client, +): + """Tests concise output for a no-accelerator machine matches legacy.""" + dummy_model = _make_model_with_deploy_options( + [_make_deploy_option_no_accelerator()] + ) + + with mock.patch.object( + mock_client, "_get_publisher_model", return_value=dummy_model + ): + result = mock_client.list_publisher_model_deploy_options( + model="publishers/google/models/gemma-2b@001", + config=types.ListPublisherModelDeployOptionsConfig(concise=True), + ) + + expected = ( + "[Option 1: tpu]\n" + ' serving_container_image_uri="docker/hexllm",\n' + ' machine_type="ct5lp-hightpu-1t",\n' + ' accelerator_type="ACCELERATOR_TYPE_UNSPECIFIED",\n' + " accelerator_count=0," + ) + assert result == expected + + +def test_list_publisher_model_deploy_options_accelerator_filter_excludes_unspecified( + mock_client, +): + """Tests accelerator_type_filter excludes no-accelerator options (legacy parity).""" + dummy_model = 
_make_model_with_deploy_options( + [ + _make_deploy_option_no_accelerator(deploy_task_name="tpu"), + _make_deploy_option(deploy_task_name="gpu", accelerator_type="NVIDIA_L4"), + ] + ) + + with mock.patch.object( + mock_client, "_get_publisher_model", return_value=dummy_model + ): + options = mock_client.list_publisher_model_deploy_options( + model="publishers/google/models/gemma-2b@001", + config=types.ListPublisherModelDeployOptionsConfig( + accelerator_type_filter="NVIDIA" + ), + ) + + assert [o.option_name for o in options] == ["gpu"] + + +# ---- concise option tests ---- + + +def test_list_publisher_model_deploy_options_not_concise_returns_list( + mock_client, +): + """Tests that without concise the method returns a list of DeployOption.""" + dummy_model = _make_model_with_deploy_options([_make_deploy_option()]) + + with mock.patch.object( + mock_client, "_get_publisher_model", return_value=dummy_model + ): + result = mock_client.list_publisher_model_deploy_options( + model="publishers/google/models/gemma-2b@001", + config=types.ListPublisherModelDeployOptionsConfig(concise=False), + ) + + assert isinstance(result, list) + assert isinstance(result[0], types.DeployOption) + + +def test_list_publisher_model_deploy_options_concise_returns_string( + mock_client, +): + """Tests concise=True returns the legacy-format human-readable string.""" + dummy_model = _make_model_with_deploy_options([ + _make_deploy_option( + deploy_task_name="option-1", + machine_type="g2-standard-12", + accelerator_type="NVIDIA_L4", + accelerator_count=1, + image_uri="docker/vllm", + ) + ]) + + with mock.patch.object( + mock_client, "_get_publisher_model", return_value=dummy_model + ): + result = mock_client.list_publisher_model_deploy_options( + model="publishers/google/models/gemma-2b@001", + config=types.ListPublisherModelDeployOptionsConfig(concise=True), + ) + + # Matches the legacy SDK's concise formatting exactly. 
+ expected = ( + "[Option 1: option-1]\n" + ' serving_container_image_uri="docker/vllm",\n' + ' machine_type="g2-standard-12",\n' + ' accelerator_type="NVIDIA_L4",\n' + " accelerator_count=1," + ) + assert isinstance(result, str) + assert result == expected + + +def test_list_publisher_model_deploy_options_concise_multiple(mock_client): + """Tests concise formatting of multiple options separated by a blank line.""" + dummy_model = _make_model_with_deploy_options( + [ + _make_deploy_option( + deploy_task_name="g2", + machine_type="g2-standard-12", + accelerator_type="NVIDIA_L4", + accelerator_count=1, + image_uri="docker/vllm", + ), + _make_deploy_option( + deploy_task_name="a3", + machine_type="a3-highgpu-8g", + accelerator_type="NVIDIA_H100_80GB", + accelerator_count=8, + image_uri="docker/hexllm", + ), + ] + ) + + with mock.patch.object( + mock_client, "_get_publisher_model", return_value=dummy_model + ): + result = mock_client.list_publisher_model_deploy_options( + model="publishers/google/models/gemma-2b@001", + config=types.ListPublisherModelDeployOptionsConfig(concise=True), + ) + + expected = ( + "[Option 1: g2]\n" + ' serving_container_image_uri="docker/vllm",\n' + ' machine_type="g2-standard-12",\n' + ' accelerator_type="NVIDIA_L4",\n' + " accelerator_count=1," + "\n\n" + "[Option 2: a3]\n" + ' serving_container_image_uri="docker/hexllm",\n' + ' machine_type="a3-highgpu-8g",\n' + ' accelerator_type="NVIDIA_H100_80GB",\n' + " accelerator_count=8," + ) + assert result == expected + + +def test_list_publisher_model_deploy_options_concise_with_filter(mock_client): + """Tests concise output honors filters before formatting.""" + dummy_model = _make_model_with_deploy_options( + [ + _make_deploy_option(deploy_task_name="g2", machine_type="g2-standard-12"), + _make_deploy_option(deploy_task_name="a3", machine_type="a3-highgpu-8g"), + ] + ) + + with mock.patch.object( + mock_client, "_get_publisher_model", return_value=dummy_model + ): + result = 
mock_client.list_publisher_model_deploy_options( + model="publishers/google/models/gemma-2b@001", + config=types.ListPublisherModelDeployOptionsConfig( + machine_type_filter="a3", concise=True + ), + ) + + assert result.startswith("[Option 1: a3]") + assert "a3-highgpu-8g" in result + assert "g2-standard-12" not in result + + +def test_format_concise_deploy_options_omits_none_fields(): + """Tests that None fields are omitted and the header has no name if unset.""" + options = [ + types.DeployOption( + machine_type="g2-standard-12", + accelerator_count=1, + ) + ] + result = model_garden.ModelGarden._format_concise_deploy_options(options) + expected = ( + "[Option 1]\n" + ' machine_type="g2-standard-12",\n' + " accelerator_count=1," + ) + assert result == expected + + +def test_list_publisher_model_deploy_options_concise_async(mock_async_client): + """Tests the async client returns a concise string when requested.""" + dummy_model = _make_model_with_deploy_options( + [ + _make_deploy_option( + deploy_task_name="option-1", + machine_type="g2-standard-12", + accelerator_type="NVIDIA_L4", + accelerator_count=1, + image_uri="docker/vllm", + ) + ] + ) + + with mock.patch.object( + mock_async_client, + "_get_publisher_model", + new=mock.AsyncMock(return_value=dummy_model), + ): + result = asyncio.run( + mock_async_client.list_publisher_model_deploy_options( + model="publishers/google/models/gemma-2b@001", + config=types.ListPublisherModelDeployOptionsConfig(concise=True), + ) + ) + + assert isinstance(result, str) + assert result.startswith("[Option 1: option-1]")