9 changes: 5 additions & 4 deletions bigframes/ml/base.py
@@ -24,7 +24,8 @@
 """
 
 import abc
-from typing import cast, Optional, TypeVar, Union
+import typing
+from typing import Optional, TypeVar, Union
 import warnings
 
 import bigframes_vendored.sklearn.base
@@ -133,7 +134,7 @@ def register(self: _T, vertex_ai_model_id: Optional[str] = None) -> _T:
             self._bqml_model = self._create_bqml_model()  # type: ignore
         except AttributeError:
             raise RuntimeError("A model must be trained before register.")
-        self._bqml_model = cast(core.BqmlModel, self._bqml_model)
+        self._bqml_model = typing.cast(core.BqmlModel, self._bqml_model)
 
         self._bqml_model.register(vertex_ai_model_id)
         return self
@@ -286,7 +287,7 @@ def _predict_and_retry(
                 bpd.concat([df_result, df_succ]) if df_result is not None else df_succ
             )
 
-        df_result = cast(
+        df_result = typing.cast(
             bpd.DataFrame,
             bpd.concat([df_result, df_fail]) if df_result is not None else df_fail,
         )
@@ -306,7 +307,7 @@ def _extract_output_names(self):
 
         output_names = []
         for transform_col in self._bqml_model._model._properties["transformColumns"]:
-            transform_col_dict = cast(dict, transform_col)
+            transform_col_dict = typing.cast(dict, transform_col)
             # pass the columns that are not transformed
             if "transformSql" not in transform_col_dict:
                 continue
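The change in this file is mechanical: every bare `cast(...)` call becomes `typing.cast(...)`, with the import switched from `from typing import cast` to `import typing`. As a minimal sketch of what the call does at runtime (the variable names below are illustrative, not from this PR):

```python
import typing

# typing.cast is a no-op at runtime: it simply returns its second argument.
# Its only effect is to tell static type checkers (e.g. mypy) to treat the
# value as the given type from that point on.
value: object = "bigframes"
name = typing.cast(str, value)
print(name.upper())  # BIGFRAMES
```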
6 changes: 3 additions & 3 deletions bigframes/ml/compose.py
@@ -21,7 +21,7 @@
 import re
 import types
 import typing
-from typing import cast, Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 from bigframes_vendored import constants
 import bigframes_vendored.sklearn.compose._column_transformer
@@ -218,7 +218,7 @@ def camel_to_snake(name):
 
         output_names = []
         for transform_col in bq_model._properties["transformColumns"]:
-            transform_col_dict = cast(dict, transform_col)
+            transform_col_dict = typing.cast(dict, transform_col)
             # pass the columns that are not transformed
             if "transformSql" not in transform_col_dict:
                 continue
@@ -282,7 +282,7 @@ def _merge(
             return self  # SQLScalarColumnTransformer only work inside ColumnTransformer
         feature_columns_sorted = sorted(
             [
-                cast(str, feature_column.name)
+                typing.cast(str, feature_column.name)
                 for feature_column in bq_model.feature_columns
             ]
         )
5 changes: 3 additions & 2 deletions bigframes/ml/core.py
@@ -18,7 +18,8 @@
 
 import dataclasses
 import datetime
-from typing import Callable, cast, Iterable, Mapping, Optional, Union
+import typing
+from typing import Callable, Iterable, Mapping, Optional, Union
 import uuid
 
 from google.cloud import bigquery
@@ -376,7 +377,7 @@ def copy(self, new_model_name: str, replace: bool = False) -> BqmlModel:
     def register(self, vertex_ai_model_id: Optional[str] = None) -> BqmlModel:
         if vertex_ai_model_id is None:
             # vertex id needs to start with letters. https://cloud.google.com/vertex-ai/docs/general/resource-naming
-            vertex_ai_model_id = "bigframes_" + cast(str, self._model.model_id)
+            vertex_ai_model_id = "bigframes_" + typing.cast(str, self._model.model_id)
 
             # truncate as Vertex ID only accepts 63 characters, easily exceeding the limit for temp models.
             # The possibility of conflicts should be low.
15 changes: 8 additions & 7 deletions bigframes/ml/imported.py
@@ -16,7 +16,8 @@
 
 from __future__ import annotations
 
-from typing import cast, Mapping, Optional
+import typing
+from typing import Mapping, Optional
 
 from google.cloud import bigquery
 
@@ -78,7 +79,7 @@ def predict(self, X: utils.ArrayType) -> bpd.DataFrame:
             if self.model_path is None:
                 raise ValueError("Model GCS path must be provided.")
             self._bqml_model = self._create_bqml_model()
-            self._bqml_model = cast(core.BqmlModel, self._bqml_model)
+            self._bqml_model = typing.cast(core.BqmlModel, self._bqml_model)
 
         (X,) = utils.batch_convert_to_dataframe(X)
 
@@ -99,7 +100,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> TensorFlowModel:
             if self.model_path is None:
                 raise ValueError("Model GCS path must be provided.")
             self._bqml_model = self._create_bqml_model()
-            self._bqml_model = cast(core.BqmlModel, self._bqml_model)
+            self._bqml_model = typing.cast(core.BqmlModel, self._bqml_model)
 
         new_model = self._bqml_model.copy(model_name, replace)
         return new_model.session.read_gbq_model(model_name)
@@ -157,7 +158,7 @@ def predict(self, X: utils.ArrayType) -> bpd.DataFrame:
             if self.model_path is None:
                 raise ValueError("Model GCS path must be provided.")
             self._bqml_model = self._create_bqml_model()
-            self._bqml_model = cast(core.BqmlModel, self._bqml_model)
+            self._bqml_model = typing.cast(core.BqmlModel, self._bqml_model)
 
         (X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)
 
@@ -178,7 +179,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> ONNXModel:
             if self.model_path is None:
                 raise ValueError("Model GCS path must be provided.")
             self._bqml_model = self._create_bqml_model()
-            self._bqml_model = cast(core.BqmlModel, self._bqml_model)
+            self._bqml_model = typing.cast(core.BqmlModel, self._bqml_model)
 
         new_model = self._bqml_model.copy(model_name, replace)
         return new_model.session.read_gbq_model(model_name)
@@ -276,7 +277,7 @@ def predict(self, X: utils.ArrayType) -> bpd.DataFrame:
             if self.model_path is None:
                 raise ValueError("Model GCS path must be provided.")
             self._bqml_model = self._create_bqml_model()
-            self._bqml_model = cast(core.BqmlModel, self._bqml_model)
+            self._bqml_model = typing.cast(core.BqmlModel, self._bqml_model)
 
         (X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)
 
@@ -297,7 +298,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> XGBoostModel:
             if self.model_path is None:
                 raise ValueError("Model GCS path must be provided.")
             self._bqml_model = self._create_bqml_model()
-            self._bqml_model = cast(core.BqmlModel, self._bqml_model)
+            self._bqml_model = typing.cast(core.BqmlModel, self._bqml_model)
 
         new_model = self._bqml_model.copy(model_name, replace)
         return new_model.session.read_gbq_model(model_name)
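The `predict` and `to_gbq` hunks above all follow the same pattern: `_bqml_model` is presumably declared as an Optional attribute, and once `_create_bqml_model()` has populated it, the cast tells the checker it is no longer None. A small self-contained sketch of that pattern (the class and attribute names below are made up for illustration):

```python
import typing
from typing import Optional


class LazyModel:
    def __init__(self) -> None:
        # The attribute starts out unset, so its declared type is Optional.
        self._name: Optional[str] = None

    def _load(self) -> None:
        self._name = "loaded-model"

    def predict(self) -> str:
        if self._name is None:
            self._load()
        # The checker cannot see that _load() populated the attribute, so its
        # declared type here is still Optional[str]; typing.cast narrows it
        # without performing any runtime check.
        name = typing.cast(str, self._name)
        return name.upper()


print(LazyModel().predict())  # LOADED-MODEL
```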
20 changes: 12 additions & 8 deletions bigframes/ml/llm.py
@@ -16,7 +16,8 @@
 
 from __future__ import annotations
 
-from typing import cast, Iterable, Literal, Mapping, Optional, Union
+import typing
+from typing import Iterable, Literal, Mapping, Optional, Union
 import warnings
 
 import bigframes_vendored.constants as constants
@@ -252,7 +253,7 @@ def predict(
 
         if len(X.columns) == 1:
             # BQML identified the column by name
-            col_label = cast(blocks.Label, X.columns[0])
+            col_label = typing.cast(blocks.Label, X.columns[0])
             X = X.rename(columns={col_label: "content"})
 
         options: dict = {}
@@ -391,7 +392,7 @@ def predict(
 
         if len(X.columns) == 1:
             # BQML identified the column by name
-            col_label = cast(blocks.Label, X.columns[0])
+            col_label = typing.cast(blocks.Label, X.columns[0])
             X = X.rename(columns={col_label: "content"})
 
         # TODO(garrettwu): remove transform to ObjRefRuntime when BQML supports ObjRef as input
@@ -604,7 +605,10 @@ def fit(
         options["prompt_col"] = X.columns.tolist()[0]
 
         self._bqml_model = self._bqml_model_factory.create_llm_remote_model(
-            X, y, options=options, connection_name=cast(str, self.connection_name)
+            X,
+            y,
+            options=options,
+            connection_name=typing.cast(str, self.connection_name),
         )
         return self
 
@@ -735,7 +739,7 @@ def predict(
 
         if len(X.columns) == 1:
             # BQML identified the column by name
-            col_label = cast(blocks.Label, X.columns[0])
+            col_label = typing.cast(blocks.Label, X.columns[0])
             X = X.rename(columns={col_label: "prompt"})
 
         options: dict = {
@@ -820,8 +824,8 @@ def score(
         )
 
         # BQML identified the column by name
-        X_col_label = cast(blocks.Label, X.columns[0])
-        y_col_label = cast(blocks.Label, y.columns[0])
+        X_col_label = typing.cast(blocks.Label, X.columns[0])
+        y_col_label = typing.cast(blocks.Label, y.columns[0])
         X = X.rename(columns={X_col_label: "input_text"})
         y = y.rename(columns={y_col_label: "output_text"})
 
@@ -1033,7 +1037,7 @@ def predict(
 
         if len(X.columns) == 1:
             # BQML identified the column by name
-            col_label = cast(blocks.Label, X.columns[0])
+            col_label = typing.cast(blocks.Label, X.columns[0])
             X = X.rename(columns={col_label: "prompt"})
 
         options = {
7 changes: 4 additions & 3 deletions bigframes/ml/model_selection.py
@@ -20,7 +20,8 @@
 import inspect
 from itertools import chain
 import time
-from typing import cast, Generator, List, Optional, Union
+import typing
+from typing import Generator, List, Optional, Union
 
 import bigframes_vendored.sklearn.model_selection._split as vendored_model_selection_split
 import bigframes_vendored.sklearn.model_selection._validation as vendored_model_selection_validation
@@ -99,10 +100,10 @@ def _stratify_split(df: bpd.DataFrame, stratify: bpd.Series) -> List[bpd.DataFrame]:
         train_dfs.append(train)
         test_dfs.append(test)
 
-    train_df = cast(
+    train_df = typing.cast(
         bpd.DataFrame, bpd.concat(train_dfs).drop(columns="bigframes_stratify_col")
     )
-    test_df = cast(
+    test_df = typing.cast(
         bpd.DataFrame, bpd.concat(test_dfs).drop(columns="bigframes_stratify_col")
     )
     return [train_df, test_df]
4 changes: 2 additions & 2 deletions bigframes/ml/preprocessing.py
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 import typing
-from typing import cast, Iterable, List, Literal, Optional, Union
+from typing import Iterable, List, Literal, Optional, Union
 
 import bigframes_vendored.sklearn.preprocessing._data
 import bigframes_vendored.sklearn.preprocessing._discretization
@@ -470,7 +470,7 @@ def _parse_from_sql(cls, sql: str) -> tuple[OneHotEncoder, str]:
         s = sql[sql.find("(") + 1 : sql.find(")")]
         col_label, drop_str, top_k, frequency_threshold = s.split(", ")
         drop = (
-            cast(Literal["most_frequent"], "most_frequent")
+            typing.cast(Literal["most_frequent"], "most_frequent")
             if drop_str.lower() == "'most_frequent'"
             else None
        )
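The `OneHotEncoder._parse_from_sql` hunk is the one place where the cast targets a `Literal` type rather than a class: the parsed `drop` value presumably has to satisfy an annotation such as `Optional[Literal["most_frequent"]]`, and the cast spells that out for the checker. A rough, self-contained sketch of the same idea (the function and alias below are illustrative only, not part of this PR):

```python
import typing
from typing import Literal, Optional

DropOption = Optional[Literal["most_frequent"]]


def parse_drop(raw: str) -> DropOption:
    # At runtime the cast just returns the string; it only annotates the
    # constant with the Literal type the caller expects.
    return (
        typing.cast(Literal["most_frequent"], "most_frequent")
        if raw.strip("'").lower() == "most_frequent"
        else None
    )


print(parse_drop("'most_frequent'"))  # most_frequent
```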