Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions corrai/base/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from fastprogress.fastprogress import progress_bar
from sklearn.metrics import mean_squared_error, mean_absolute_error
from typing import Callable

from corrai.base.utils import check_datetime_index
from corrai.base.metrics import nmbe, cv_rmse
Expand All @@ -22,7 +23,7 @@
def aggregate_time_series(
results: pd.Series,
indicator: str,
method: str = "mean",
method: str | Callable = "mean",
agg_method_kwarg: dict = None,
reference_time_series: pd.Series = None,
freq: str | pd.Timedelta | dt.timedelta = None,
Expand Down Expand Up @@ -132,7 +133,9 @@ def aggregate_time_series(
"""

agg_method_kwarg = {} if agg_method_kwarg is None else agg_method_kwarg
method = METHODS[method]

if isinstance(method, str):
method = METHODS[method]

for df in results:
check_datetime_index(df)
Expand Down
42 changes: 17 additions & 25 deletions corrai/base/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class Model(ABC):
Persists the model state or parameters to disk. Optional.
"""

def __init__(self, is_dynamic: bool = True):
def __init__(self, is_dynamic: bool):
self.is_dynamic = is_dynamic

def get_property_from_param(
Expand Down Expand Up @@ -190,7 +190,19 @@ def save(self, file_path: Path):
raise NotImplementedError("No save method was defined for this model")


class IshigamiDynamic(Model):
class PyModel(Model, ABC):
def __init__(self, is_dynamic: bool):
super().__init__(is_dynamic)

def get_property_values(self, property_list: list):
return [getattr(self, name) for name in property_list]

def set_property_values(self, property_dict: dict):
for prop, val in property_dict.items():
setattr(self, prop, val)


class IshigamiDynamic(PyModel):
"""
Example implementation of the Ishigami function.

Expand Down Expand Up @@ -218,13 +230,6 @@ def __init__(self):
self.x2 = 2
self.x3 = 3

def get_property_values(self, property_list: list):
return [getattr(self, name) for name in property_list]

def set_property_values(self, property_dict: dict):
for prop, val in property_dict.items():
setattr(self, prop, val)

def simulate(
self,
property_dict: dict[str, str | int | float] = None,
Expand All @@ -250,7 +255,7 @@ def simulate(
)


class Ishigami(Model):
class Ishigami(PyModel):
"""
Example implementation of the Ishigami function.

Expand Down Expand Up @@ -278,13 +283,6 @@ def __init__(self):
self.x2 = 2
self.x3 = 3

def get_property_values(self, property_list: list):
return [getattr(self, name) for name in property_list]

def set_property_values(self, property_dict: dict):
for prop, val in property_dict.items():
setattr(self, prop, val)

def simulate(
self,
property_dict: dict[str, str | int | float] = None,
Expand All @@ -303,16 +301,13 @@ def simulate(
return pd.Series({"res": res})


class PymodelDynamic(Model):
class PymodelDynamic(PyModel):
def __init__(self):
super().__init__(is_dynamic=True)
self.prop_1 = 1
self.prop_2 = 2
self.prop_3 = 3

def get_property_values(self, property_list: list):
return [getattr(self, name) for name in property_list]

def simulate(
self,
property_dict: dict[str, str | int | float] = None,
Expand All @@ -333,16 +328,13 @@ def simulate(
)


class PymodelStatic(Model):
class PymodelStatic(PyModel):
def __init__(self):
super().__init__(is_dynamic=False)
self.prop_1 = 1
self.prop_2 = 2
self.prop_3 = 3

def get_property_values(self, property_list: list):
return [getattr(self, name) for name in property_list]

def simulate(
self,
property_dict: dict[str, str | int | float] = None,
Expand Down
Loading