Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 26 additions & 16 deletions packages/bigframes/bigframes/bigquery/_operations/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import pandas as pd

from bigframes import clients, dataframe, dtypes, series, session
from bigframes import dataframe, dtypes, series, session
from bigframes import pandas as bpd
from bigframes.bigquery._operations import utils as bq_utils
from bigframes.core import convert
Expand Down Expand Up @@ -885,7 +885,11 @@ def classify(
input: PROMPT_TYPE,
categories: tuple[str, ...] | list[str],
*,
examples: list[tuple[str, str]] | None = None,
connection_id: str | None = None,
endpoint: str | None = None,
optimization_mode: Literal["minimize_cost", "maximize_quality"] | None = None,
max_error_ratio: float | None = None,
Comment thread
sycai marked this conversation as resolved.
) -> series.Series:
"""
Classifies a given input into one of the specified categories. It will always return one of the provided categories best fit the prompt input.
Expand All @@ -903,22 +907,30 @@ def classify(
<BLANKLINE>
[2 rows x 2 columns]

.. note::

This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the
Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is"
and might have limited support. For more information, see the launch stage descriptions
(https://cloud.google.com/products#product-launch-stages).

Args:
input (str | Series | List[str|Series] | Tuple[str|Series, ...]):
A mixture of Series and string literals that specifies the input to send to the model. The Series can be BigFrames Series
or pandas Series.
categories (tuple[str, ...] | list[str]):
Categories to classify the input into.
examples (list[tuple[str, str]], optional):
An array that contains representative examples of input strings and the output category
that you expect. You can provide examples to help the model understand your
intended threshold for a condition with nuanced or subjective logic. We recommend providing at most 5 examples.
connection_id (str, optional):
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
If not provided, the query uses your end-user credential.
endpoint (str, optional):
A STRING value that specifies the Vertex AI endpoint to use for the model. You can specify any
generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically
identifies and uses the full endpoint of the model.
optimization_mode (Literal["minimize_cost", "maximize_quality"], optional):
A STRING value that specifies the optimization strategy to use. Supported values are ``minimize_cost``
and ``maximize_quality``.
max_error_ratio (float, optional):
A value between ``0.0`` and ``1.0`` that contains the maximum acceptable ratio of row-level
inference failures to rows processed on this function. The default value is 1.0.
This argument isn't supported when ``optimization_mode`` is set to ``minimize_cost``.

Returns:
bigframes.series.Series: A new series of strings.
Expand All @@ -927,10 +939,16 @@ def classify(
prompt_context, series_list = _separate_context_and_series(input)
assert len(series_list) > 0

example_tuples = tuple(examples) if examples is not None else None

operator = ai_ops.AIClassify(
prompt_context=tuple(prompt_context),
categories=tuple(categories),
examples=example_tuples,
connection_id=connection_id,
endpoint=endpoint,
optimization_mode=optimization_mode,
max_error_ratio=max_error_ratio,
)

return series_list[0]._apply_nary_op(operator, series_list[1:])
Expand Down Expand Up @@ -1249,14 +1267,6 @@ def _convert_series(
return result


def _resolve_connection_id(series: series.Series, connection_id: str | None):
return clients.get_canonical_bq_connection_id(
connection_id or series._session.bq_connection,
series._session._project,
series._session._location,
)


def _to_dataframe(
data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
series_rename: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1996,7 +1996,11 @@ def ai_classify(
return ai_ops.AIClassify(
_construct_prompt(values, op.prompt_context), # type: ignore
op.categories, # type: ignore
_construct_examples(op.examples), # type: ignore
Comment thread
sycai marked this conversation as resolved.
op.connection_id, # type: ignore
op.endpoint, # type: ignore
op.optimization_mode.upper() if op.optimization_mode is not None else None, # type: ignore
op.max_error_ratio, # type: ignore
).to_expr()


Expand Down Expand Up @@ -2040,6 +2044,26 @@ def _construct_prompt(
return ibis.struct(prompt)


def _construct_examples(
    examples: tuple[tuple[str, str], ...] | None,
) -> ibis_types.ArrayValue | None:
    """Convert (input, category) example pairs into an ibis array of structs.

    Args:
        examples:
            Pairs of (example input, expected category) used as few-shot
            guidance for AI.CLASSIFY, or ``None`` when no examples were
            supplied.

    Returns:
        An ibis array literal of two-field structs, or ``None`` when
        *examples* is ``None``.
    """
    # NOTE: the annotation is the variadic ``tuple[tuple[str, str], ...]``
    # to match the ``examples`` field declared on the AIClassify op; the
    # non-variadic form would describe a tuple of exactly one pair.
    if examples is None:
        return None

    # ibis structs require field names; use positional placeholder names so
    # the pair compiles as an anonymous (input, category) struct.
    return ibis.array(
        [
            ibis.struct(
                {
                    "_field_1": example_input,
                    "_field_2": example_category,
                }
            )
            for example_input, example_category in examples
        ]
    )


@scalar_op_compiler.register_nary_op(ops.RowKey, pass_op=True)
def rowkey_op_impl(*values: ibis_types.Value, op: ops.RowKey) -> ibis_types.Value:
    """Compile ``ops.RowKey`` by delegating to the shared default-ordering
    helper, which derives a row key from the given value columns."""
    return bigframes.core.compile.ibis_compiler.default_ordering.gen_row_key(values)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,16 @@ def _construct_named_args(op: ops.ScalarOp) -> list[sge.Kwarg]:
args.append(
sge.Kwarg(this=field, expression=sge.Literal.string(value.upper()))
)
elif field == "examples":
example_expressions = [
sge.Tuple(
expressions=[sge.Literal.string(key), sge.Literal.string(val)]
)
for key, val in value
]
args.append(
sge.Kwarg(this=field, expression=sge.array(*example_expressions))
)
else:
args.append(
sge.Kwarg(this=field, expression=sge.Literal.string(str(value)))
Expand Down
4 changes: 4 additions & 0 deletions packages/bigframes/bigframes/operations/ai_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,11 @@ class AIClassify(base_ops.NaryOp):

prompt_context: Tuple[str | None, ...]
categories: tuple[str, ...]
examples: tuple[tuple[str, str], ...] | None
connection_id: str | None
endpoint: str | None
optimization_mode: str | None
max_error_ratio: float | None

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
return dtypes.STRING_DTYPE
Expand Down
9 changes: 9 additions & 0 deletions packages/bigframes/tests/system/small/bigquery/test_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,15 @@ def test_ai_classify(session):
assert result.dtype == dtypes.STRING_DTYPE


def test_ai_classify_with_examples(session):
    """AI.CLASSIFY with few-shot examples yields one string label per input row."""
    input_series = bpd.Series(["cat", "orchid"], session=session)

    classified = bbq.ai.classify(
        input_series,
        ["animal", "plant"],
        examples=[("dog", "animal")],
    )

    assert len(classified) == len(input_series)
    assert classified.dtype == dtypes.STRING_DTYPE


def test_ai_classify_multi_model(session, bq_connection):
df = session.from_glob_path(
"gs://bigframes-dev-testing/a_multimodel/images/*",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
SELECT
AI.CLASSIFY(
input => (`string_col`),
categories => ['greeting', 'rejection'],
examples => [('hi', 'greeting'), ('bye', 'rejection')],
endpoint => 'gemini-2.5-flash',
max_error_ratio => 0.1
) AS `result`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,29 @@ def test_ai_classify(scalar_types_df: dataframe.DataFrame, snapshot, connection_
op = ops.AIClassify(
prompt_context=(None,),
categories=("greeting", "rejection"),
examples=None,
connection_id=connection_id,
endpoint=None,
optimization_mode=None,
max_error_ratio=None,
)

sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"])

snapshot.assert_match(sql, "out.sql")


def test_ai_classify_with_params(scalar_types_df: dataframe.DataFrame, snapshot):
col_name = "string_col"

op = ops.AIClassify(
prompt_context=(None,),
categories=("greeting", "rejection"),
examples=(("hi", "greeting"), ("bye", "rejection")),
connection_id=None,
endpoint="gemini-2.5-flash",
optimization_mode=None,
max_error_ratio=0.1,
)

sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,16 @@ class AIClassify(Value):

input: Value
categories: Value[dt.Array[dt.String]]
examples: Optional[Value]
connection_id: Optional[Value[dt.String]]
endpoint: Optional[Value[dt.String]]
optimization_mode: Optional[Value[dt.String]]
max_error_ratio: Optional[Value[dt.Float64]]

shape = rlz.shape_like("input")

@attribute
def dtype(self) -> dt.Struct:
def dtype(self) -> dt.DataType:
return dt.string


Expand Down
Loading