Skip to content

Commit 05586f0

Browse files
Merge pull request #186 from PolicyEngine/dev
Standardise save() and load() functions for simulations
2 parents de2d3ce + ceaee0a commit 05586f0

File tree

9 files changed

+176
-224
lines changed

9 files changed

+176
-224
lines changed

changelog_entry.yaml

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,4 @@
1+
- bump: patch
2+
changes:
3+
fixed:
4+
- Standardised saving and loading of simulations.

examples/employment_income_variation_uk.py

Lines changed: 0 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -162,42 +162,10 @@ def create_dataset_with_varied_employment_income(
162162

163163
def run_simulation(dataset: PolicyEngineUKDataset) -> Simulation:
164164
"""Run a single simulation for all employment income variations."""
165-
# Specify additional variables to calculate beyond defaults
166-
variables = {
167-
"household": [
168-
# Default variables
169-
"household_id",
170-
"household_weight",
171-
"household_net_income",
172-
"hbai_household_net_income",
173-
"household_benefits",
174-
"household_tax",
175-
],
176-
"person": [
177-
"person_id",
178-
"benunit_id",
179-
"household_id",
180-
"person_weight",
181-
"employment_income",
182-
"age",
183-
],
184-
"benunit": [
185-
"benunit_id",
186-
"benunit_weight",
187-
# Individual benefits (at benunit level)
188-
"universal_credit",
189-
"child_benefit",
190-
"working_tax_credit",
191-
"child_tax_credit",
192-
"pension_credit",
193-
"income_support",
194-
],
195-
}
196165

197166
simulation = Simulation(
198167
dataset=dataset,
199168
tax_benefit_model_version=uk_latest,
200-
variables=variables,
201169
)
202170
simulation.run()
203171
return simulation

examples/employment_income_variation_us.py

Lines changed: 0 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -171,56 +171,10 @@ def create_dataset_with_varied_employment_income(
171171

172172
def run_simulation(dataset: PolicyEngineUSDataset) -> Simulation:
173173
"""Run a single simulation for all employment income variations."""
174-
# Specify variables to calculate
175-
variables = {
176-
"household": [
177-
"household_id",
178-
"household_weight",
179-
"household_net_income",
180-
"household_benefits",
181-
"household_tax",
182-
"household_market_income",
183-
],
184-
"person": [
185-
"person_id",
186-
"household_id",
187-
"marital_unit_id",
188-
"family_id",
189-
"spm_unit_id",
190-
"tax_unit_id",
191-
"person_weight",
192-
"employment_income",
193-
"age",
194-
],
195-
"spm_unit": [
196-
"spm_unit_id",
197-
"spm_unit_weight",
198-
"snap",
199-
"tanf",
200-
"spm_unit_net_income",
201-
],
202-
"tax_unit": [
203-
"tax_unit_id",
204-
"tax_unit_weight",
205-
"income_tax",
206-
"employee_payroll_tax",
207-
"eitc",
208-
"ctc",
209-
],
210-
"marital_unit": [
211-
"marital_unit_id",
212-
"marital_unit_weight",
213-
],
214-
"family": [
215-
"family_id",
216-
"family_weight",
217-
],
218-
}
219174

220175
simulation = Simulation(
221176
dataset=dataset,
222177
tax_benefit_model_version=us_latest,
223-
variables=variables,
224178
)
225179
simulation.run()
226180
return simulation

src/policyengine/core/parameter.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
class Parameter(BaseModel):
99
id: str = Field(default_factory=lambda: str(uuid4()))
1010
name: str
11+
label: str | None = None
1112
description: str | None = None
1213
data_type: type | None = None
1314
tax_benefit_model_version: TaxBenefitModelVersion

src/policyengine/core/simulation.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -21,10 +21,13 @@ class Simulation(BaseModel):
2121
tax_benefit_model_version: TaxBenefitModelVersion = None
2222
output_dataset: Dataset | None = None
2323

24-
variables: dict[str, list[str]] | None = Field(
25-
default=None,
26-
description="Optional dictionary mapping entity names to lists of variable names to calculate. If None, uses model defaults.",
27-
)
28-
2924
def run(self):
3025
self.tax_benefit_model_version.run(self)
26+
27+
def save(self):
28+
"""Save the simulation's output dataset."""
29+
self.tax_benefit_model_version.save(self)
30+
31+
def load(self):
32+
"""Load the simulation's output dataset."""
33+
self.tax_benefit_model_version.load(self)

src/policyengine/core/tax_benefit_model_version.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,16 @@ def run(self, simulation: "Simulation") -> "Simulation":
2929
"The TaxBenefitModel class must define a method to execute simulations."
3030
)
3131

32+
def save(self, simulation: "Simulation"):
33+
raise NotImplementedError(
34+
"The TaxBenefitModel class must define a method to save simulations."
35+
)
36+
37+
def load(self, simulation: "Simulation"):
38+
raise NotImplementedError(
39+
"The TaxBenefitModel class must define a method to load simulations."
40+
)
41+
3242
def get_parameter(self, name: str) -> "Parameter":
3343
"""Get a parameter by name.
3444

src/policyengine/tax_benefit_models/uk/model.py

Lines changed: 84 additions & 71 deletions
Original file line number | Diff line number | Diff line change
@@ -83,6 +83,7 @@ def __init__(self, **kwargs: dict):
8383
parameter = Parameter(
8484
id=self.id + "-" + param_node.name,
8585
name=param_node.name,
86+
label=param_node.metadata.get("label", param_node.name),
8687
tax_benefit_model_version=self,
8788
description=param_node.description,
8889
data_type=type(
@@ -152,77 +153,72 @@ def run(self, simulation: "Simulation") -> "Simulation":
152153
)
153154
modifier(microsim)
154155

155-
# Allow custom variable selection, or use defaults
156-
if simulation.variables is not None:
157-
entity_variables = simulation.variables
158-
else:
159-
# Default comprehensive variable set
160-
entity_variables = {
161-
"person": [
162-
# IDs and weights
163-
"person_id",
164-
"benunit_id",
165-
"household_id",
166-
"person_weight",
167-
# Demographics
168-
"age",
169-
"gender",
170-
"is_adult",
171-
"is_SP_age",
172-
"is_child",
173-
# Income
174-
"employment_income",
175-
"self_employment_income",
176-
"pension_income",
177-
"private_pension_income",
178-
"savings_interest_income",
179-
"dividend_income",
180-
"property_income",
181-
"total_income",
182-
"earned_income",
183-
# Benefits
184-
"universal_credit",
185-
"child_benefit",
186-
"pension_credit",
187-
"income_support",
188-
"working_tax_credit",
189-
"child_tax_credit",
190-
# Tax
191-
"income_tax",
192-
"national_insurance",
193-
],
194-
"benunit": [
195-
# IDs and weights
196-
"benunit_id",
197-
"benunit_weight",
198-
# Structure
199-
"family_type",
200-
# Income and benefits
201-
"universal_credit",
202-
"child_benefit",
203-
"working_tax_credit",
204-
"child_tax_credit",
205-
],
206-
"household": [
207-
# IDs and weights
208-
"household_id",
209-
"household_weight",
210-
# Income measures
211-
"household_net_income",
212-
"hbai_household_net_income",
213-
"equiv_hbai_household_net_income",
214-
"household_market_income",
215-
"household_gross_income",
216-
# Benefits and tax
217-
"household_benefits",
218-
"household_tax",
219-
"vat",
220-
# Housing
221-
"rent",
222-
"council_tax",
223-
"tenure_type",
224-
],
225-
}
156+
entity_variables = {
157+
"person": [
158+
# IDs and weights
159+
"person_id",
160+
"benunit_id",
161+
"household_id",
162+
"person_weight",
163+
# Demographics
164+
"age",
165+
"gender",
166+
"is_adult",
167+
"is_SP_age",
168+
"is_child",
169+
# Income
170+
"employment_income",
171+
"self_employment_income",
172+
"pension_income",
173+
"private_pension_income",
174+
"savings_interest_income",
175+
"dividend_income",
176+
"property_income",
177+
"total_income",
178+
"earned_income",
179+
# Benefits
180+
"universal_credit",
181+
"child_benefit",
182+
"pension_credit",
183+
"income_support",
184+
"working_tax_credit",
185+
"child_tax_credit",
186+
# Tax
187+
"income_tax",
188+
"national_insurance",
189+
],
190+
"benunit": [
191+
# IDs and weights
192+
"benunit_id",
193+
"benunit_weight",
194+
# Structure
195+
"family_type",
196+
# Income and benefits
197+
"universal_credit",
198+
"child_benefit",
199+
"working_tax_credit",
200+
"child_tax_credit",
201+
],
202+
"household": [
203+
# IDs and weights
204+
"household_id",
205+
"household_weight",
206+
# Income measures
207+
"household_net_income",
208+
"hbai_household_net_income",
209+
"equiv_hbai_household_net_income",
210+
"household_market_income",
211+
"household_gross_income",
212+
# Benefits and tax
213+
"household_benefits",
214+
"household_tax",
215+
"vat",
216+
# Housing
217+
"rent",
218+
"council_tax",
219+
"tenure_type",
220+
],
221+
}
226222

227223
data = {
228224
"person": pd.DataFrame(),
@@ -247,6 +243,7 @@ def run(self, simulation: "Simulation") -> "Simulation":
247243
)
248244

249245
simulation.output_dataset = PolicyEngineUKDataset(
246+
id=simulation.id,
250247
name=dataset.name,
251248
description=dataset.description,
252249
filepath=str(
@@ -262,7 +259,23 @@ def run(self, simulation: "Simulation") -> "Simulation":
262259
),
263260
)
264261

262+
def save(self, simulation: "Simulation"):
263+
"""Save the simulation's output dataset."""
265264
simulation.output_dataset.save()
266265

266+
def load(self, simulation: "Simulation"):
267+
"""Load the simulation's output dataset."""
268+
simulation.output_dataset = PolicyEngineUKDataset(
269+
id=simulation.id,
270+
name=simulation.dataset.name,
271+
description=simulation.dataset.description,
272+
filepath=str(
273+
Path(simulation.dataset.filepath).parent
274+
/ (simulation.id + ".h5")
275+
),
276+
year=simulation.dataset.year,
277+
is_output_dataset=True,
278+
)
279+
267280

268281
uk_latest = PolicyEngineUKLatest()

0 commit comments

Comments
 (0)