diff --git a/packages/bigframes/bigframes/bigquery/_operations/ai.py b/packages/bigframes/bigframes/bigquery/_operations/ai.py index d32e764fa89c..1070671f5a1f 100644 --- a/packages/bigframes/bigframes/bigquery/_operations/ai.py +++ b/packages/bigframes/bigframes/bigquery/_operations/ai.py @@ -23,7 +23,7 @@ import pandas as pd -from bigframes import clients, dataframe, dtypes, series, session +from bigframes import dataframe, dtypes, series, session from bigframes import pandas as bpd from bigframes.bigquery._operations import utils as bq_utils from bigframes.core import convert @@ -885,7 +885,11 @@ def classify( input: PROMPT_TYPE, categories: tuple[str, ...] | list[str], *, + examples: list[tuple[str, str]] | None = None, connection_id: str | None = None, + endpoint: str | None = None, + optimization_mode: Literal["minimize_cost", "maximize_quality"] | None = None, + max_error_ratio: float | None = None, ) -> series.Series: """ Classifies a given input into one of the specified categories. It will always return one of the provided categories best fit the prompt input. @@ -903,22 +907,30 @@ def classify( [2 rows x 2 columns] - .. note:: - - This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the - Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is" - and might have limited support. For more information, see the launch stage descriptions - (https://cloud.google.com/products#product-launch-stages). - Args: input (str | Series | List[str|Series] | Tuple[str|Series, ...]): A mixture of Series and string literals that specifies the input to send to the model. The Series can be BigFrames Series or pandas Series. categories (tuple[str, ...] | list[str]): Categories to classify the input into. + examples (list[tuple[str, str]], optional): + An array that contains representative examples of input strings and the output category + that you expect. 
You can provide examples to help the model understand your + intended categorization for inputs with nuanced or subjective logic. We recommend providing at most 5 examples. connection_id (str, optional): Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`. If not provided, the query uses your end-user credential. + endpoint (str, optional): + A STRING value that specifies the Vertex AI endpoint to use for the model. You can specify any + generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically + identifies and uses the full endpoint of the model. + optimization_mode (Literal["minimize_cost", "maximize_quality"], optional): + A STRING value that specifies the optimization strategy to use. Supported values are ``minimize_cost`` + and ``maximize_quality``. + max_error_ratio (float, optional): + A value between ``0.0`` and ``1.0`` that contains the maximum acceptable ratio of row-level + inference failures to rows processed by this function. The default value is 1.0. + This argument isn't supported when ``optimization_mode`` is set to ``minimize_cost``. Returns: bigframes.series.Series: A new series of strings. 
@@ -927,10 +939,16 @@ def classify( prompt_context, series_list = _separate_context_and_series(input) assert len(series_list) > 0 + example_tuples = tuple(examples) if examples is not None else None + operator = ai_ops.AIClassify( prompt_context=tuple(prompt_context), categories=tuple(categories), + examples=example_tuples, connection_id=connection_id, + endpoint=endpoint, + optimization_mode=optimization_mode, + max_error_ratio=max_error_ratio, ) return series_list[0]._apply_nary_op(operator, series_list[1:]) @@ -1249,14 +1267,6 @@ def _convert_series( return result -def _resolve_connection_id(series: series.Series, connection_id: str | None): - return clients.get_canonical_bq_connection_id( - connection_id or series._session.bq_connection, - series._session._project, - series._session._location, - ) - - def _to_dataframe( data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series], series_rename: str, diff --git a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py index 39c0ffb8d037..03ec72f4e44b 100644 --- a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py +++ b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py @@ -1996,7 +1996,11 @@ def ai_classify( return ai_ops.AIClassify( _construct_prompt(values, op.prompt_context), # type: ignore op.categories, # type: ignore + _construct_examples(op.examples), # type: ignore op.connection_id, # type: ignore + op.endpoint, # type: ignore + op.optimization_mode.upper() if op.optimization_mode is not None else None, # type: ignore + op.max_error_ratio, # type: ignore ).to_expr() @@ -2040,6 +2044,26 @@ def _construct_prompt( return ibis.struct(prompt) +def _construct_examples( + examples: tuple[tuple[str, str]] | None, +) -> ibis_types.ArrayValue | None: + if examples is None: + return None + + results: list[ibis_types.StructValue] = [] + + for example 
in examples: + ibis_example = ibis.struct( + { + "_field_1": example[0], + "_field_2": example[1], + } + ) + results.append(ibis_example) + + return ibis.array(results) + + @scalar_op_compiler.register_nary_op(ops.RowKey, pass_op=True) def rowkey_op_impl(*values: ibis_types.Value, op: ops.RowKey) -> ibis_types.Value: return bigframes.core.compile.ibis_compiler.default_ordering.gen_row_key(values) diff --git a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py index 2860c9d50d20..abb979ddb43a 100644 --- a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py +++ b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py @@ -149,6 +149,16 @@ def _construct_named_args(op: ops.ScalarOp) -> list[sge.Kwarg]: args.append( sge.Kwarg(this=field, expression=sge.Literal.string(value.upper())) ) + elif field == "examples": + example_expressions = [ + sge.Tuple( + expressions=[sge.Literal.string(key), sge.Literal.string(val)] + ) + for key, val in value + ] + args.append( + sge.Kwarg(this=field, expression=sge.array(*example_expressions)) + ) else: args.append( sge.Kwarg(this=field, expression=sge.Literal.string(str(value))) diff --git a/packages/bigframes/bigframes/operations/ai_ops.py b/packages/bigframes/bigframes/operations/ai_ops.py index 968591b2077d..0d9438741f46 100644 --- a/packages/bigframes/bigframes/operations/ai_ops.py +++ b/packages/bigframes/bigframes/operations/ai_ops.py @@ -160,7 +160,11 @@ class AIClassify(base_ops.NaryOp): prompt_context: Tuple[str | None, ...] categories: tuple[str, ...] + examples: tuple[tuple[str, str], ...] 
| None connection_id: str | None + endpoint: str | None + optimization_mode: str | None + max_error_ratio: float | None def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: return dtypes.STRING_DTYPE diff --git a/packages/bigframes/tests/system/small/bigquery/test_ai.py b/packages/bigframes/tests/system/small/bigquery/test_ai.py index 79179a57351e..421d83db8e08 100644 --- a/packages/bigframes/tests/system/small/bigquery/test_ai.py +++ b/packages/bigframes/tests/system/small/bigquery/test_ai.py @@ -355,6 +355,15 @@ def test_ai_classify(session): assert result.dtype == dtypes.STRING_DTYPE +def test_ai_classify_with_examples(session): + s = bpd.Series(["cat", "orchid"], session=session) + + result = bbq.ai.classify(s, ["animal", "plant"], examples=[("dog", "animal")]) + + assert len(result) == len(s) + assert result.dtype == dtypes.STRING_DTYPE + + def test_ai_classify_multi_model(session, bq_connection): df = session.from_glob_path( "gs://bigframes-dev-testing/a_multimodel/images/*", diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify_with_params/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify_with_params/out.sql new file mode 100644 index 000000000000..982b747f8927 --- /dev/null +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify_with_params/out.sql @@ -0,0 +1,9 @@ +SELECT + AI.CLASSIFY( + input => (`string_col`), + categories => ['greeting', 'rejection'], + examples => [('hi', 'greeting'), ('bye', 'rejection')], + endpoint => 'gemini-2.5-flash', + max_error_ratio => 0.1 + ) AS `result` +FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py index 
76716ca4db24..f09436dc8509 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py @@ -392,7 +392,29 @@ def test_ai_classify(scalar_types_df: dataframe.DataFrame, snapshot, connection_ op = ops.AIClassify( prompt_context=(None,), categories=("greeting", "rejection"), + examples=None, connection_id=connection_id, + endpoint=None, + optimization_mode=None, + max_error_ratio=None, + ) + + sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"]) + + snapshot.assert_match(sql, "out.sql") + + +def test_ai_classify_with_params(scalar_types_df: dataframe.DataFrame, snapshot): + col_name = "string_col" + + op = ops.AIClassify( + prompt_context=(None,), + categories=("greeting", "rejection"), + examples=(("hi", "greeting"), ("bye", "rejection")), + connection_id=None, + endpoint="gemini-2.5-flash", + optimization_mode=None, + max_error_ratio=0.1, ) sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"]) diff --git a/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py b/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py index 3368671b9004..7a6a31c4b72a 100644 --- a/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py +++ b/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py @@ -155,12 +155,16 @@ class AIClassify(Value): input: Value categories: Value[dt.Array[dt.String]] + examples: Optional[Value] connection_id: Optional[Value[dt.String]] + endpoint: Optional[Value[dt.String]] + optimization_mode: Optional[Value[dt.String]] + max_error_ratio: Optional[Value[dt.Float64]] shape = rlz.shape_like("input") @attribute - def dtype(self) -> dt.Struct: + def dtype(self) -> dt.DataType: return dt.string