diff --git a/bigframes/ml/base.py b/bigframes/ml/base.py index 9b38702cce..3f6ccecaa2 100644 --- a/bigframes/ml/base.py +++ b/bigframes/ml/base.py @@ -24,7 +24,8 @@ """ import abc -from typing import cast, Optional, TypeVar, Union +import typing +from typing import Optional, TypeVar, Union import warnings import bigframes_vendored.sklearn.base @@ -133,7 +134,7 @@ def register(self: _T, vertex_ai_model_id: Optional[str] = None) -> _T: self._bqml_model = self._create_bqml_model() # type: ignore except AttributeError: raise RuntimeError("A model must be trained before register.") - self._bqml_model = cast(core.BqmlModel, self._bqml_model) + self._bqml_model = typing.cast(core.BqmlModel, self._bqml_model) self._bqml_model.register(vertex_ai_model_id) return self @@ -286,7 +287,7 @@ def _predict_and_retry( bpd.concat([df_result, df_succ]) if df_result is not None else df_succ ) - df_result = cast( + df_result = typing.cast( bpd.DataFrame, bpd.concat([df_result, df_fail]) if df_result is not None else df_fail, ) @@ -306,7 +307,7 @@ def _extract_output_names(self): output_names = [] for transform_col in self._bqml_model._model._properties["transformColumns"]: - transform_col_dict = cast(dict, transform_col) + transform_col_dict = typing.cast(dict, transform_col) # pass the columns that are not transformed if "transformSql" not in transform_col_dict: continue diff --git a/bigframes/ml/compose.py b/bigframes/ml/compose.py index d638e026e4..f8244fb0d8 100644 --- a/bigframes/ml/compose.py +++ b/bigframes/ml/compose.py @@ -21,7 +21,7 @@ import re import types import typing -from typing import cast, Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union from bigframes_vendored import constants import bigframes_vendored.sklearn.compose._column_transformer @@ -218,7 +218,7 @@ def camel_to_snake(name): output_names = [] for transform_col in bq_model._properties["transformColumns"]: - transform_col_dict = cast(dict, transform_col) + transform_col_dict = typing.cast(dict, transform_col) # pass the columns that are not transformed if "transformSql" not in transform_col_dict: continue @@ -282,7 +282,7 @@ def _merge( return self # SQLScalarColumnTransformer only work inside ColumnTransformer feature_columns_sorted = sorted( [ - cast(str, feature_column.name) + typing.cast(str, feature_column.name) for feature_column in bq_model.feature_columns ] ) diff --git a/bigframes/ml/core.py b/bigframes/ml/core.py index 4dbc1a5fa3..620843fb6e 100644 --- a/bigframes/ml/core.py +++ b/bigframes/ml/core.py @@ -18,7 +18,8 @@ import dataclasses import datetime -from typing import Callable, cast, Iterable, Mapping, Optional, Union +import typing +from typing import Callable, Iterable, Mapping, Optional, Union import uuid from google.cloud import bigquery @@ -376,7 +377,7 @@ def copy(self, new_model_name: str, replace: bool = False) -> BqmlModel: def register(self, vertex_ai_model_id: Optional[str] = None) -> BqmlModel: if vertex_ai_model_id is None: # vertex id needs to start with letters. https://cloud.google.com/vertex-ai/docs/general/resource-naming - vertex_ai_model_id = "bigframes_" + cast(str, self._model.model_id) + vertex_ai_model_id = "bigframes_" + typing.cast(str, self._model.model_id) # truncate as Vertex ID only accepts 63 characters, easily exceeding the limit for temp models. # The possibility of conflicts should be low. diff --git a/bigframes/ml/imported.py b/bigframes/ml/imported.py index 295649ed7f..56b5d6735c 100644 --- a/bigframes/ml/imported.py +++ b/bigframes/ml/imported.py @@ -16,7 +16,8 @@ from __future__ import annotations -from typing import cast, Mapping, Optional +import typing +from typing import Mapping, Optional from google.cloud import bigquery @@ -78,7 +79,7 @@ def predict(self, X: utils.ArrayType) -> bpd.DataFrame: if self.model_path is None: raise ValueError("Model GCS path must be provided.") self._bqml_model = self._create_bqml_model() - self._bqml_model = cast(core.BqmlModel, self._bqml_model) + self._bqml_model = typing.cast(core.BqmlModel, self._bqml_model) (X,) = utils.batch_convert_to_dataframe(X) @@ -99,7 +100,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> TensorFlowModel: if self.model_path is None: raise ValueError("Model GCS path must be provided.") self._bqml_model = self._create_bqml_model() - self._bqml_model = cast(core.BqmlModel, self._bqml_model) + self._bqml_model = typing.cast(core.BqmlModel, self._bqml_model) new_model = self._bqml_model.copy(model_name, replace) return new_model.session.read_gbq_model(model_name) @@ -157,7 +158,7 @@ def predict(self, X: utils.ArrayType) -> bpd.DataFrame: if self.model_path is None: raise ValueError("Model GCS path must be provided.") self._bqml_model = self._create_bqml_model() - self._bqml_model = cast(core.BqmlModel, self._bqml_model) + self._bqml_model = typing.cast(core.BqmlModel, self._bqml_model) (X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session) @@ -178,7 +179,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> ONNXModel: if self.model_path is None: raise ValueError("Model GCS path must be provided.") self._bqml_model = self._create_bqml_model() - self._bqml_model = cast(core.BqmlModel, self._bqml_model) + self._bqml_model = typing.cast(core.BqmlModel, self._bqml_model) new_model = self._bqml_model.copy(model_name, replace) return new_model.session.read_gbq_model(model_name) @@ -276,7 +277,7 @@ def predict(self, X: utils.ArrayType) -> bpd.DataFrame: if self.model_path is None: raise ValueError("Model GCS path must be provided.") self._bqml_model = self._create_bqml_model() - self._bqml_model = cast(core.BqmlModel, self._bqml_model) + self._bqml_model = typing.cast(core.BqmlModel, self._bqml_model) (X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session) @@ -297,7 +298,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> XGBoostModel: if self.model_path is None: raise ValueError("Model GCS path must be provided.") self._bqml_model = self._create_bqml_model() - self._bqml_model = cast(core.BqmlModel, self._bqml_model) + self._bqml_model = typing.cast(core.BqmlModel, self._bqml_model) new_model = self._bqml_model.copy(model_name, replace) return new_model.session.read_gbq_model(model_name) diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index f4e60f3f9d..585599c9b6 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -16,7 +16,8 @@ from __future__ import annotations -from typing import cast, Iterable, Literal, Mapping, Optional, Union +import typing +from typing import Iterable, Literal, Mapping, Optional, Union import warnings import bigframes_vendored.constants as constants @@ -252,7 +253,7 @@ def predict( if len(X.columns) == 1: # BQML identified the column by name - col_label = cast(blocks.Label, X.columns[0]) + col_label = typing.cast(blocks.Label, X.columns[0]) X = X.rename(columns={col_label: "content"}) options: dict = {} @@ -391,7 +392,7 @@ def predict( if len(X.columns) == 1: # BQML identified the column by name - col_label = cast(blocks.Label, X.columns[0]) + col_label = typing.cast(blocks.Label, X.columns[0]) X = X.rename(columns={col_label: "content"}) # TODO(garrettwu): remove transform to ObjRefRuntime when BQML supports ObjRef as input @@ -604,7 +605,10 @@ def fit( options["prompt_col"] = X.columns.tolist()[0] self._bqml_model = self._bqml_model_factory.create_llm_remote_model( - X, y, options=options, connection_name=cast(str, self.connection_name) + X, + y, + options=options, + connection_name=typing.cast(str, self.connection_name), ) return self @@ -735,7 +739,7 @@ def predict( if len(X.columns) == 1: # BQML identified the column by name - col_label = cast(blocks.Label, X.columns[0]) + col_label = typing.cast(blocks.Label, X.columns[0]) X = X.rename(columns={col_label: "prompt"}) options: dict = { @@ -820,8 +824,8 @@ def score( ) # BQML identified the column by name - X_col_label = cast(blocks.Label, X.columns[0]) - y_col_label = cast(blocks.Label, y.columns[0]) + X_col_label = typing.cast(blocks.Label, X.columns[0]) + y_col_label = typing.cast(blocks.Label, y.columns[0]) X = X.rename(columns={X_col_label: "input_text"}) y = y.rename(columns={y_col_label: "output_text"}) @@ -1033,7 +1037,7 @@ def predict( if len(X.columns) == 1: # BQML identified the column by name - col_label = cast(blocks.Label, X.columns[0]) + col_label = typing.cast(blocks.Label, X.columns[0]) X = X.rename(columns={col_label: "prompt"}) options = { diff --git a/bigframes/ml/model_selection.py b/bigframes/ml/model_selection.py index 5adfb03b7f..3d23fbf568 100644 --- a/bigframes/ml/model_selection.py +++ b/bigframes/ml/model_selection.py @@ -20,7 +20,8 @@ import inspect from itertools import chain import time -from typing import cast, Generator, List, Optional, Union +import typing +from typing import Generator, List, Optional, Union import bigframes_vendored.sklearn.model_selection._split as vendored_model_selection_split import bigframes_vendored.sklearn.model_selection._validation as vendored_model_selection_validation @@ -99,10 +100,10 @@ def _stratify_split(df: bpd.DataFrame, stratify: bpd.Series) -> List[bpd.DataFra train_dfs.append(train) test_dfs.append(test) - train_df = cast( + train_df = typing.cast( bpd.DataFrame, bpd.concat(train_dfs).drop(columns="bigframes_stratify_col") ) - test_df = cast( + test_df = typing.cast( bpd.DataFrame, bpd.concat(test_dfs).drop(columns="bigframes_stratify_col") ) return [train_df, test_df] diff --git a/bigframes/ml/preprocessing.py b/bigframes/ml/preprocessing.py index 8bf89b0838..22a3e7e222 100644 --- a/bigframes/ml/preprocessing.py +++ b/bigframes/ml/preprocessing.py @@ -18,7 +18,7 @@ from __future__ import annotations import typing -from typing import cast, Iterable, List, Literal, Optional, Union +from typing import Iterable, List, Literal, Optional, Union import bigframes_vendored.sklearn.preprocessing._data import bigframes_vendored.sklearn.preprocessing._discretization @@ -470,7 +470,7 @@ def _parse_from_sql(cls, sql: str) -> tuple[OneHotEncoder, str]: s = sql[sql.find("(") + 1 : sql.find(")")] col_label, drop_str, top_k, frequency_threshold = s.split(", ") drop = ( - cast(Literal["most_frequent"], "most_frequent") + typing.cast(Literal["most_frequent"], "most_frequent") if drop_str.lower() == "'most_frequent'" else None )