From df0f2296932cc5ab3d7711c93675c7d9edff0910 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 24 Nov 2025 15:33:36 +0100 Subject: [PATCH 1/4] fix parallel bootstrapping --- pySEQ/SEQoutput.py | 12 ++++++------ pySEQ/SEQuential.py | 2 +- pySEQ/helpers/_bootstrap.py | 31 +++++++++++++++++++++++++------ tests/test_accessor.py | 25 +++++++++++++++++++++++++ tests/test_bootstrap.py | 0 tests/test_parallel.py | 27 +++++++++++++++++++++++++++ 6 files changed, 84 insertions(+), 13 deletions(-) create mode 100644 tests/test_accessor.py delete mode 100644 tests/test_bootstrap.py diff --git a/pySEQ/SEQoutput.py b/pySEQ/SEQoutput.py index 8819be1..1cb4d32 100644 --- a/pySEQ/SEQoutput.py +++ b/pySEQ/SEQoutput.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from typing import List, Optional, Literal from .SEQopts import SEQopts -import statsmodels.formula.api as smf +from statsmodels.base.wrapper import ResultsWrapper import polars as pl import matplotlib.figure @@ -9,10 +9,10 @@ class SEQoutput: options: SEQopts = None method: str = None - numerator_models: List[smf.MNLogit] = None - denominator_models: List[smf.MNLogit] = None - outcome_models: List[List[smf.glm]] = None - compevent_models: List[List[smf.glm]] = None + numerator_models: List[ResultsWrapper] = None + denominator_models: List[ResultsWrapper] = None + outcome_models: List[List[ResultsWrapper]] = None + compevent_models: List[List[ResultsWrapper]] = None weight_statistics: dict = None hazard: pl.DataFrame = None km_data: pl.DataFrame = None @@ -78,7 +78,7 @@ def retrieve_data(self, case _: data = self.km_data if data is None: - ValueError("Data {type} was not created in the SEQuential process") + raise ValueError("Data {type} was not created in the SEQuential process") return data \ No newline at end of file diff --git a/pySEQ/SEQuential.py b/pySEQ/SEQuential.py index c07ed45..cfa1977 100644 --- a/pySEQ/SEQuential.py +++ b/pySEQ/SEQuential.py @@ -216,7 +216,7 @@ def collect(self): "numerator_model", "denominator_model", "outcome_model", "hazard_ratio", "risk_estimates", - "km_data", "diagnostics", + "km_data", "km_graph", "diagnostics", "_survival_time", "_hazard_time", "_model_time", "_expansion_time", "weight_stats" diff --git a/pySEQ/helpers/_bootstrap.py b/pySEQ/helpers/_bootstrap.py index 8dee0df..5e2987e 100644 --- a/pySEQ/helpers/_bootstrap.py +++ b/pySEQ/helpers/_bootstrap.py @@ -1,6 +1,7 @@ from functools import wraps from concurrent.futures import ProcessPoolExecutor, as_completed import polars as pl +import numpy as np from tqdm import tqdm import copy import time @@ -23,6 +24,19 @@ def _prepare_boot_data(self, data, boot_id): return bootstrapped +def _bootstrap_worker(obj, method_name, original_DT, i, seed, args, kwargs): + obj = copy.deepcopy(obj) + obj._rng = np.random.RandomState(seed + i) if seed is not None else np.random.RandomState() + obj.DT = _prepare_boot_data(obj, original_DT, i) + + # Disable bootstrapping to prevent recursion + obj.bootstrap_nboot = 0 + + method = getattr(obj, method_name) + result = method(*args, **kwargs) + obj._rng = None + return result + def bootstrap_loop(method): @wraps(method) def wrapper(self, *args, **kwargs): @@ -38,17 +52,22 @@ def wrapper(self, *args, **kwargs): original_DT = self.DT nboot = self.bootstrap_nboot ncores = self.ncores + seed = getattr(self, "seed", None) + method_name = method.__name__ - def _worker(i): - obj = copy.deepcopy(self) - obj.DT = _prepare_boot_data(obj, original_DT, i) - return method(obj, *args, **kwargs) - if getattr(self, "parallel", False): + original_rng = getattr(self, "_rng", None) + self._rng = None + with ProcessPoolExecutor(max_workers=ncores) as executor: - futures = [executor.submit(_worker, i) for i in range(nboot)] + futures = [ + executor.submit(_bootstrap_worker, self, method_name, original_DT, i, seed, args, kwargs) + for i in range(nboot) + ] for j in tqdm(as_completed(futures), total=nboot, desc="Bootstrapping..."): results.append(j.result()) + + self._rng = original_rng else: for i in tqdm(range(nboot), desc="Bootstrapping..."): self.DT = _prepare_boot_data(self, original_DT, i) diff --git a/tests/test_accessor.py b/tests/test_accessor.py new file mode 100644 index 0000000..317ef5c --- /dev/null +++ b/tests/test_accessor.py @@ -0,0 +1,25 @@ +from pySEQ import SEQuential, SEQopts +from pySEQ.data import load_data +import pytest + +def test_ITT_collector(): + data = load_data("SEQdata") + + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method = "ITT", + parameters=SEQopts() + ) + s.expand() + s.fit() + collector = s.collect() + outcomes = collector.retrieve_data("unique_outcomes") + with pytest.raises(ValueError): + collector.retrieve_data("km_data") \ No newline at end of file diff --git a/tests/test_bootstrap.py b/tests/test_bootstrap.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_parallel.py b/tests/test_parallel.py index e69de29..18c9c0b 100644 --- a/tests/test_parallel.py +++ b/tests/test_parallel.py @@ -0,0 +1,27 @@ +from pySEQ import SEQuential, SEQopts +from pySEQ.data import load_data + +def test_parallel_ITT(): + data = load_data("SEQdata") + + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method = "ITT", + parameters=SEQopts(parallel=True, + bootstrap_nboot=2) + ) + s.expand() + s.bootstrap() + s.fit() + matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() + assert matrix == [-6.828506035553407, 0.18935003090041902, 0.12717241010542563, + 0.033715156987629266, -0.00014691202235029346, 0.044566165558944326, + 0.0005787770439053261, 0.0032906669395295026, -0.01339242049205771, + 0.20072409918428052] \ No newline at end of file From 57ffb8826c6a22dd157284bae5afe5eebd46b8ca Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 24 Nov 2025 15:50:30 +0100 Subject: [PATCH 2/4] overhaul pyproject --- pyproject.toml | 40 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2eca252..956a738 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,8 +6,30 @@ build-backend = "setuptools.build_meta" name = "pySEQ" version = "0.9.0" description = "Sequentially Nested Target Trial Emulation" -authors = [{name = "Ryan ODea", email = "ryan.odea@psi.ch"}] +readme = "README.md" +license = {text = "MIT"} +keywords = ["causal inference", "sequential trial emulation", "target trial", "observational studies"] requires-python = ">=3.10" +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Science/Research", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12" +] + +authors = [ + {name = "Ryan O'Dea", email = "ryan.odea@psi.ch"}, + {name = "Alejandro Szmulewicz", email = "aszmulewicz@hsph.harvard.edu"}, + {name = "Tom Palmer", email = "tom.palmer@bristol.ac.uk"}, + {name = "Miguel Hernan", email = "mhernan@hsph.harvard.edu"}, +] + +maintainers = [ + {name = "Ryan O'Dea", email = "ryan.odea@psi.ch"}, +] + dependencies = [ "numpy", "polars", @@ -18,10 +40,22 @@ dependencies = [ "lifelines" ] -[tools.setuptools] +[project.urls] +Homepage = "https://github.com/CausalInference/pySEQ" +Repository = "https://github.com/CausalInference/pySEQ" +"Bug Tracker" = "https://github.com/CausalInference/pySEQ/issues" + +"Ryan O'Dea (ORCID)" = "https://orcid.org/0009-0000-0103-9546" +"Alejandro Szmulewicz (ORCID)" = "https://orcid.org/0000-0002-2664-802X" +"Tom Palmer (ORCID)" = "https://orcid.org/0000-0003-4655-4511" +"Miguel Hernan (ORCID)" = "https://orcid.org/0000-0003-1619-8456" +"University of Bristol (ROR)" = "https://ror.org/0524sp257" +"Harvard University (ROR)" = "https://ror.org/03vek6s52" + +[tool.setuptools] packages = ["pySEQ", "pySEQ.data"] -[tools.setuptools.package-data] +[tool.setuptools.package-data] SEQdata = ["data/*.csv"] [tool.pytest.ini_options] From e7678cfce181db2d1ebaeaf72744d8e6670f420e Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 24 Nov 2025 16:05:12 +0100 Subject: [PATCH 3/4] added a datachecker --- pySEQ/SEQuential.py | 3 ++- pySEQ/error/__init__.py | 3 ++- pySEQ/error/_datachecker.py | 29 +++++++++++++++++++++++++++++ tests/test_covariates.py | 4 ++-- 4 files changed, 35 insertions(+), 4 deletions(-) diff --git a/pySEQ/SEQuential.py b/pySEQ/SEQuential.py index cfa1977..3240e9e 100644 --- a/pySEQ/SEQuential.py +++ b/pySEQ/SEQuential.py @@ -8,7 +8,7 @@ from .SEQopts import SEQopts from .SEQoutput import SEQoutput -from .error import _param_checker +from .error import _param_checker, _datachecker from .helpers import _col_string, bootstrap_loop, _format_time from .initialization import _outcome, _numerator, _denominator, _cense_numerator, _cense_denominator from .expansion import _binder, _dynamic, _random_selection, _diagnostics @@ -69,6 +69,7 @@ def __init__( self.cense_denominator = _cense_denominator(self) _param_checker(self) + _datachecker(self) def expand(self): start = time.perf_counter() diff --git a/pySEQ/error/__init__.py b/pySEQ/error/__init__.py index 4ff4313..c9ee5ee 100644 --- a/pySEQ/error/__init__.py +++ b/pySEQ/error/__init__.py @@ -1 +1,2 @@ -from ._param_checker import _param_checker \ No newline at end of file +from ._param_checker import _param_checker +from ._datachecker import _datachecker \ No newline at end of file diff --git a/pySEQ/error/_datachecker.py b/pySEQ/error/_datachecker.py index e69de29..da8277f 100644 --- a/pySEQ/error/_datachecker.py +++ b/pySEQ/error/_datachecker.py @@ -0,0 +1,29 @@ +import polars as pl + +def _datachecker(self): + check = self.data.group_by(self.id_col).agg([ + pl.len().alias("row_count"), + pl.col(self.time_col).max().alias("max_time") + ]) + + invalid = check.filter(pl.col("row_count") != pl.col("max_time") + 1) + if len(invalid) > 0: + raise ValueError( + f"Data validation failed: {len(invalid)} ID(s) have mismatched " + f"This suggests invalid times" + f"Invalid IDs:\n{invalid}" + ) + + for col in self.excused_colnames: + violations = self.data.sort([self.id_col, self.time_col]).group_by(self.id_col).agg([ + ((pl.col(col).cum_sum().shift(1, fill_value=0) > 0) & (pl.col(col) == 0)) + .any() + .alias("has_violation") + ]).filter(pl.col("has_violation")) + + if len(violations) > 0: + raise ValueError( + f"Column '{col}' violates 'once one, always one' rule for excusing treatment " + f"{len(violations)} ID(s) have zeros after ones." + ) + \ No newline at end of file diff --git a/tests/test_covariates.py b/tests/test_covariates.py index 40ee6fe..f07b2c4 100644 --- a/tests/test_covariates.py +++ b/tests/test_covariates.py @@ -122,7 +122,7 @@ def test_PreE_censoring_excused_covariates(): parameters=SEQopts(weighted=True, weight_preexpansion=True, excused=True, - excused_colnames=["ExcusedZero", "ExcusedOne"]) + excused_colnames=["excusedZero", "excusedOne"]) ) assert s.covariates == "tx_init_bas+followup+followup_sq+trial+trial_sq" assert s.numerator is None @@ -144,7 +144,7 @@ def test_PostE_censoring_excused_covariates(): method = "censoring", parameters=SEQopts(weighted=True, excused=True, - excused_colnames=["ExcusedZero", "ExcusedOne"]) + excused_colnames=["excusedZero", "excusedOne"]) ) assert s.covariates == "tx_init_bas+followup+followup_sq+trial+trial_sq+sex+N_bas+L_bas+P_bas" assert s.numerator == "sex+N_bas+L_bas+P_bas+followup+followup_sq+trial+trial_sq" From f63abf1917ee62c81ba70c14d662305a292dea86 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 24 Nov 2025 16:12:10 +0100 Subject: [PATCH 4/4] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 128b0c7..0a48102 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,7 @@ model.bootstrap(bootstrap_nboot = 20) # Run 20 bootstrap samples model.fit() # Fit the model model.survival() # Create survival curves model.plot() # Create and show a plot of the survival curves +model.collect() # Collection of important information ```