164 changes: 143 additions & 21 deletions bigframes/bigquery/_operations/ai.py
@@ -26,6 +26,7 @@
from bigframes import clients, dataframe, dtypes
from bigframes import pandas as bpd
from bigframes import series, session
from bigframes.bigquery._operations import utils as bq_utils
from bigframes.core import convert
from bigframes.core.logging import log_adapter
import bigframes.core.sql.literals
@@ -391,7 +392,7 @@ def generate_double(

@log_adapter.method_logger(custom_base_name="bigquery_ai")
def generate_embedding(
model_name: str,
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
*,
output_dimensionality: Optional[int] = None,
@@ -415,9 +416,8 @@ def generate_embedding(
... ) # doctest: +SKIP

Args:
model_name (str):
The name of a remote model from Vertex AI, such as the
multimodalembedding@001 model.
model (bigframes.ml.base.BaseEstimator, str, or pandas.Series):
The model to use for generating embeddings: a fitted BigQuery
DataFrames model, a fully qualified model name such as
``project.dataset.model_name``, or a pandas Series containing the
model's ``modelReference``.
data (bigframes.pandas.DataFrame or bigframes.pandas.Series):
The data to generate embeddings for. If a Series is provided, it is
treated as the 'content' column. If a DataFrame is provided, it
@@ -454,20 +454,9 @@ def generate_embedding(
<https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-ai-generate-embedding#output>`_
for details.
"""
if isinstance(data, (pd.DataFrame, pd.Series)):
data = bpd.read_pandas(data)

if isinstance(data, series.Series):
data = data.copy()
data.name = "content"
data_df = data.to_frame()
elif isinstance(data, dataframe.DataFrame):
data_df = data
else:
raise ValueError(f"Unsupported data type: {type(data)}")

# We need to get the SQL for the input data to pass as a subquery to the TVF
source_sql = data_df.sql
data = _to_dataframe(data, series_rename="content")
model_name, session = bq_utils.get_model_name_and_session(model, data)
table_sql = bq_utils.to_sql(data)

struct_fields: Dict[str, bigframes.core.sql.literals.STRUCT_VALUES] = {}
if output_dimensionality is not None:
@@ -488,12 +477,128 @@
SELECT *
FROM AI.GENERATE_EMBEDDING(
MODEL `{model_name}`,
({source_sql}),
{bigframes.core.sql.literals.struct_literal(struct_fields)})
({table_sql}),
{bigframes.core.sql.literals.struct_literal(struct_fields)}
)
"""

return data_df._session.read_gbq(query)
if session is None:
return bpd.read_gbq_query(query)
else:
return session.read_gbq_query(query)


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def generate_text(
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
*,
temperature: Optional[float] = None,
max_output_tokens: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
stop_sequences: Optional[List[str]] = None,
ground_with_google_search: Optional[bool] = None,
request_type: Optional[str] = None,
) -> dataframe.DataFrame:
"""
Generates text using a BigQuery ML model.

See the `BigQuery ML GENERATE_TEXT function syntax
<https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-text>`_
for additional reference.

**Examples:**

>>> import bigframes.pandas as bpd
>>> import bigframes.bigquery as bbq
>>> df = bpd.DataFrame({"prompt": ["write a poem about apples"]})
>>> bbq.ai.generate_text(
... "project.dataset.model_name",
... df
... ) # doctest: +SKIP

Args:
model (bigframes.ml.base.BaseEstimator, str, or pandas.Series):
The model to use for text generation: a fitted BigQuery DataFrames
model, a fully qualified model name such as
``project.dataset.model_name``, or a pandas Series containing the
model's ``modelReference``.
data (bigframes.pandas.DataFrame or bigframes.pandas.Series):
The data to generate text for. If a Series is provided, it is
treated as the 'prompt' column. If a DataFrame is provided, it
must contain a 'prompt' column, or you must rename the column
that holds your prompts to 'prompt'.
temperature (float, optional):
A FLOAT64 value that controls the degree of randomness in token selection. The value
must be in the range ``[0.0, 1.0]``. A lower temperature works well
for prompts that expect a more deterministic and less open-ended
or creative response, while a higher temperature can lead to more
diverse or creative results. A temperature of ``0`` is
deterministic, meaning that the highest probability response is
always selected.
max_output_tokens (int, optional):
An INT64 value that sets the maximum number of tokens in the
generated text.
top_k (int, optional):
An INT64 value that changes how the model selects tokens for
output. A ``top_k`` of ``1`` means the next selected token is the
most probable among all tokens in the model's vocabulary. A
``top_k`` of ``3`` means that the next token is selected from
among the three most probable tokens by using temperature. The
default value is ``40``.
top_p (float, optional):
A FLOAT64 value that changes how the model selects tokens for
output. Tokens are selected from most probable to least probable
until the sum of their probabilities equals the ``top_p`` value.
For example, if tokens A, B, and C have a probability of 0.3, 0.2,
and 0.1 and the ``top_p`` value is ``0.5``, then the model will
select either A or B as the next token by using temperature. The
default value is ``0.95``.
stop_sequences (List[str], optional):
An ARRAY<STRING> value that contains the stop sequences for the model.
ground_with_google_search (bool, optional):
A BOOL value that determines whether to ground the model with Google Search.
request_type (str, optional):
A STRING value that specifies the type of inference request to send to the model, which determines the quota that the request uses.

Returns:
bigframes.pandas.DataFrame:
The generated text.
"""
data = _to_dataframe(data, series_rename="prompt")
model_name, session = bq_utils.get_model_name_and_session(model, data)
table_sql = bq_utils.to_sql(data)

struct_fields: Dict[
str,
Union[str, int, float, bool, Mapping[str, str], List[str], Mapping[str, Any]],
] = {}
if temperature is not None:
struct_fields["TEMPERATURE"] = temperature
if max_output_tokens is not None:
struct_fields["MAX_OUTPUT_TOKENS"] = max_output_tokens
if top_k is not None:
struct_fields["TOP_K"] = top_k
if top_p is not None:
struct_fields["TOP_P"] = top_p
if stop_sequences is not None:
struct_fields["STEP_SEQUENCES"] = stop_sequences
if ground_with_google_search is not None:
struct_fields["GROUND_WITH_GOOGLE_SEARCH"] = ground_with_google_search
if request_type is not None:
struct_fields["REQUEST_TYPE"] = request_type

query = f"""
SELECT *
FROM AI.GENERATE_TEXT(
MODEL `{model_name}`,
({table_sql}),
{bigframes.core.sql.literals.struct_literal(struct_fields)}
)
"""

if session is None:
return bpd.read_gbq_query(query)
else:
return session.read_gbq_query(query)


@log_adapter.method_logger(custom_base_name="bigquery_ai")
@@ -811,3 +916,20 @@ def _resolve_connection_id(series: series.Series, connection_id: str | None):
series._session._project,
series._session._location,
)


def _to_dataframe(
data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
series_rename: str,
) -> dataframe.DataFrame:
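"""Coerce the accepted inputs to a BigQuery DataFrames DataFrame, renaming a bare Series to ``series_rename``."""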
if isinstance(data, (pd.DataFrame, pd.Series)):
data = bpd.read_pandas(data)

if isinstance(data, series.Series):
data = data.copy()
data.name = series_rename
return data.to_frame()
elif isinstance(data, dataframe.DataFrame):
return data

raise ValueError(f"Unsupported data type: {type(data)}")
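A minimal usage sketch of the two wrappers changed above, assuming a remote model already exists; the model names and prompt text below are placeholders rather than identifiers from this change:

import bigframes.pandas as bpd
import bigframes.bigquery as bbq

# A bare Series is renamed to the 'prompt' column before being passed to the TVF.
prompts = bpd.Series(["write a poem about apples"])
text_df = bbq.ai.generate_text(
    "project.dataset.text_model",  # placeholder model name
    prompts,
    temperature=0.2,
    max_output_tokens=256,
)

# For embeddings, a bare Series is renamed to the 'content' column instead.
embeddings_df = bbq.ai.generate_embedding(
    "project.dataset.embedding_model",  # placeholder model name
    bpd.Series(["a sentence to embed"]),
    output_dimensionality=256,
)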
84 changes: 21 additions & 63 deletions bigframes/bigquery/_operations/ml.py
@@ -14,66 +14,20 @@

from __future__ import annotations

from typing import cast, List, Mapping, Optional, Union
from typing import List, Mapping, Optional, Union

import bigframes_vendored.constants
import google.cloud.bigquery
import pandas as pd

from bigframes.bigquery._operations import utils
import bigframes.core.logging.log_adapter as log_adapter
import bigframes.core.sql.ml
import bigframes.dataframe as dataframe
import bigframes.ml.base
import bigframes.session


# Helper to convert DataFrame to SQL string
def _to_sql(df_or_sql: Union[pd.DataFrame, dataframe.DataFrame, str]) -> str:
import bigframes.pandas as bpd

if isinstance(df_or_sql, str):
return df_or_sql

if isinstance(df_or_sql, pd.DataFrame):
bf_df = bpd.read_pandas(df_or_sql)
else:
bf_df = cast(dataframe.DataFrame, df_or_sql)

# Cache dataframes to make sure base table is not a snapshot.
# Cached dataframe creates a full copy, never uses snapshot.
# This is a workaround for internal issue b/310266666.
bf_df.cache()
sql, _, _ = bf_df._to_sql_query(include_index=False)
return sql


def _get_model_name_and_session(
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
# Other dataframe arguments to extract session from
*dataframes: Optional[Union[pd.DataFrame, dataframe.DataFrame, str]],
) -> tuple[str, Optional[bigframes.session.Session]]:
if isinstance(model, pd.Series):
try:
model_ref = model["modelReference"]
model_name = f"{model_ref['projectId']}.{model_ref['datasetId']}.{model_ref['modelId']}" # type: ignore
except KeyError:
raise ValueError("modelReference must be present in the pandas Series.")
elif isinstance(model, str):
model_name = model
else:
if model._bqml_model is None:
raise ValueError("Model must be fitted to be used in ML operations.")
return model._bqml_model.model_name, model._bqml_model.session

session = None
for df in dataframes:
if isinstance(df, dataframe.DataFrame):
session = df._session
break

return model_name, session


def _get_model_metadata(
*,
bqclient: google.cloud.bigquery.Client,
@@ -143,8 +97,12 @@ def create_model(
"""
import bigframes.pandas as bpd

training_data_sql = _to_sql(training_data) if training_data is not None else None
custom_holiday_sql = _to_sql(custom_holiday) if custom_holiday is not None else None
training_data_sql = (
utils.to_sql(training_data) if training_data is not None else None
)
custom_holiday_sql = (
utils.to_sql(custom_holiday) if custom_holiday is not None else None
)

# Determine session from DataFrames if not provided
if session is None:
@@ -227,8 +185,8 @@ def evaluate(
"""
import bigframes.pandas as bpd

model_name, session = _get_model_name_and_session(model, input_)
table_sql = _to_sql(input_) if input_ is not None else None
model_name, session = utils.get_model_name_and_session(model, input_)
table_sql = utils.to_sql(input_) if input_ is not None else None

sql = bigframes.core.sql.ml.evaluate(
model_name=model_name,
@@ -281,8 +239,8 @@ def predict(
"""
import bigframes.pandas as bpd

model_name, session = _get_model_name_and_session(model, input_)
table_sql = _to_sql(input_)
model_name, session = utils.get_model_name_and_session(model, input_)
table_sql = utils.to_sql(input_)

sql = bigframes.core.sql.ml.predict(
model_name=model_name,
@@ -340,8 +298,8 @@ def explain_predict(
"""
import bigframes.pandas as bpd

model_name, session = _get_model_name_and_session(model, input_)
table_sql = _to_sql(input_)
model_name, session = utils.get_model_name_and_session(model, input_)
table_sql = utils.to_sql(input_)

sql = bigframes.core.sql.ml.explain_predict(
model_name=model_name,
@@ -383,7 +341,7 @@ def global_explain(
"""
import bigframes.pandas as bpd

model_name, session = _get_model_name_and_session(model)
model_name, session = utils.get_model_name_and_session(model)
sql = bigframes.core.sql.ml.global_explain(
model_name=model_name,
class_level_explain=class_level_explain,
@@ -419,8 +377,8 @@ def transform(
"""
import bigframes.pandas as bpd

model_name, session = _get_model_name_and_session(model, input_)
table_sql = _to_sql(input_)
model_name, session = utils.get_model_name_and_session(model, input_)
table_sql = utils.to_sql(input_)

sql = bigframes.core.sql.ml.transform(
model_name=model_name,
@@ -500,8 +458,8 @@ def generate_text(
"""
import bigframes.pandas as bpd

model_name, session = _get_model_name_and_session(model, input_)
table_sql = _to_sql(input_)
model_name, session = utils.get_model_name_and_session(model, input_)
table_sql = utils.to_sql(input_)

sql = bigframes.core.sql.ml.generate_text(
model_name=model_name,
@@ -565,8 +523,8 @@ def generate_embedding(
"""
import bigframes.pandas as bpd

model_name, session = _get_model_name_and_session(model, input_)
table_sql = _to_sql(input_)
model_name, session = utils.get_model_name_and_session(model, input_)
table_sql = utils.to_sql(input_)

sql = bigframes.core.sql.ml.generate_embedding(
model_name=model_name,
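For reference, a sketch of the shared helpers expected to live in bigframes/bigquery/_operations/utils.py, reconstructed from the _to_sql and _get_model_name_and_session implementations removed from ml.py above; the actual utils module is not part of this excerpt and may differ in detail:

from __future__ import annotations

from typing import Optional, Union

import pandas as pd

import bigframes.dataframe as dataframe
import bigframes.ml.base
import bigframes.session


def to_sql(df_or_sql: Union[pd.DataFrame, dataframe.DataFrame, str]) -> str:
    import bigframes.pandas as bpd

    if isinstance(df_or_sql, str):
        return df_or_sql
    if isinstance(df_or_sql, pd.DataFrame):
        df_or_sql = bpd.read_pandas(df_or_sql)
    # Cache so the base table is a full copy rather than a snapshot
    # (workaround for internal issue b/310266666).
    df_or_sql.cache()
    sql, _, _ = df_or_sql._to_sql_query(include_index=False)
    return sql


def get_model_name_and_session(
    model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
    # Other dataframe arguments to extract a session from
    *dataframes: Optional[Union[pd.DataFrame, dataframe.DataFrame, str]],
) -> tuple[str, Optional[bigframes.session.Session]]:
    if isinstance(model, pd.Series):
        try:
            ref = model["modelReference"]
        except KeyError:
            raise ValueError("modelReference must be present in the pandas Series.")
        model_name = f"{ref['projectId']}.{ref['datasetId']}.{ref['modelId']}"
    elif isinstance(model, str):
        model_name = model
    else:
        if model._bqml_model is None:
            raise ValueError("Model must be fitted to be used in ML operations.")
        return model._bqml_model.model_name, model._bqml_model.session

    session = None
    for df in dataframes:
        if isinstance(df, dataframe.DataFrame):
            session = df._session
            break
    return model_name, session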