diff --git a/policyengine_core/reforms/scenario.py b/policyengine_core/reforms/scenario.py new file mode 100644 index 00000000..1e5cfac6 --- /dev/null +++ b/policyengine_core/reforms/scenario.py @@ -0,0 +1,20 @@ +from pydantic import BaseModel, Field +from typing import TYPE_CHECKING, Callable + +if TYPE_CHECKING: + from policyengine_core.simulations import Simulation + + +class Scenario(BaseModel): + parameter_changes: ( + dict[str, dict[str | int, float | int | bool]] | None + ) = None + """A dictionary mapping parameter names to their time-period-specific changes.""" + modifier_function: Callable[["Simulation"], None] | None = None + """A function that modifies the simulation in some way, e.g., by applying a tax policy change.""" + + def __init__(self): + # Validate parameter changes + pass + # Validate modifier function? + pass diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py index a0b1ec85..160e7752 100644 --- a/policyengine_core/simulations/simulation.py +++ b/policyengine_core/simulations/simulation.py @@ -6,7 +6,7 @@ from numpy.typing import ArrayLike import logging from pathlib import Path - +from policyengine_core.reforms.scenario import Scenario from policyengine_core import commons, periods from policyengine_core.data.dataset import Dataset from policyengine_core.entities.entity import Entity @@ -91,6 +91,7 @@ def __init__( trace: bool = False, default_input_period: str = None, default_calculation_period: str = None, + scenario: Scenario = None, ): self.default_input_period = ( default_input_period or self.default_input_period @@ -227,6 +228,12 @@ def __init__( self.parent_branch = None + if scenario is not None: + # Apply parameter changes + pass + if scenario.modifier_function is not None: + scenario.modifier_function(self) + def apply_reform(self, reform: Union[tuple, Reform]): if isinstance(reform, tuple): for subreform in reform: