diff --git a/README_HPC.md b/README_HPC.md new file mode 100644 index 00000000..76bff00a --- /dev/null +++ b/README_HPC.md @@ -0,0 +1,147 @@ +# HPC and DDP Usage Guide + +This package supports single GPU, multi-GPU, and HPC clusters using PyTorch Lightning. The primary goal is consistent behavior across CPU, GPU, and DDP environments. The secondary goal is correct ordering of predictions under DDP. + +## Core Principles + +**Map-style datasets**: Lightning can shard data with DistributedSampler when using `torch.utils.data.Dataset`. + +**LightningDataModule pattern**: Builds datasets from user arrays and manages train/val/test splits consistently. + +**Stable prediction ordering**: Return prediction payloads that include stable indices, then gather and reorder on rank 0. + +## Dataset and Batch Structure + +Datasets are map-style (`torch.utils.data.Dataset`). Each `__getitem__` returns a dict with standard keys: `contexts`, `predictors`, `outcomes`, plus indexing keys for DDP-safe prediction assembly: `idx` (dataset local index), `orig_idx` (stable original row ID). For multitask variants, additional keys include `sample_idx`, `outcome_idx`, `predictor_idx`. DDP sharding can change sample order and pad the last batch, so these indices enable correct reconstruction. + +## DataModule Usage + +The DataModule converts numpy or pandas arrays into tensors and slices by split indices. It passes `orig_idx` into each dataset so every sample reports its original row ID. + +**Split configuration**: Provide `train_idx`, `val_idx`, `test_idx`, `predict_idx` directly, or provide a `splitter(C, X, Y)` callable that returns `(train_idx, val_idx, test_idx)`. If `predict_idx` is not provided, it defaults to `test_idx` when present, otherwise defaults to the full range. + +**Example instantiation**: +```python +from contextualized.regression.datamodules import ContextualizedRegressionDataModule + +dm = ContextualizedRegressionDataModule( + C=C, X=X, Y=Y, + task_type="singletask_multivariate", + train_idx=train_idx, + val_idx=val_idx, + test_idx=test_idx, + predict_idx=predict_idx, + train_batch_size=32, + val_batch_size=64, + test_batch_size=64, + predict_batch_size=64, + num_workers=4, + pin_memory=True, + persistent_workers=True, + drop_last=False, + shuffle_train=True, + shuffle_eval=False, +) + +trainer.fit(model, datamodule=dm) +preds = trainer.predict(model, datamodule=dm) +``` + +If calling loaders manually, call `dm.setup(stage="predict")` before retrieving the dataloader. + +## DDP Prediction Mechanics + +Prediction assembly occurs only on rank 0. Non-rank-0 processes return `None`. This prevents duplicated outputs and keeps the API stable under DDP. + +**Predict step payload**: Each `LightningModule.predict_step` returns a dict containing indices (`idx`, `orig_idx`, and optional task indices), batch content when needed (`contexts`, `predictors`), and model outputs (`betas`, `mus`, and sometimes `correlations`). Tensors are detached and moved to CPU inside `predict_step` to keep GPU memory stable. + +**Gather and reorder process**: The trainer helper packs requested keys from local predict outputs, gathers to rank 0, merges payloads by concatenating on axis 0, stable sorts by `idx` when present (otherwise `orig_idx`), and deduplicates padded samples from `DistributedSampler`. If using `dist.gather_object`, ensure the collective backend supports object gathers (commonly Gloo). If your default process group is NCCL, use a Gloo group for object gather or switch to tensor all-gather. 
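A minimal sketch of this gather-and-reorder flow is below. It is an illustration, not the package's internal helper: the function name `gather_predictions_rank0` is hypothetical, the Gloo group is assumed to be created by the caller, and the payload keys follow the dict structure described above.

```python
import numpy as np
import torch.distributed as dist

def gather_predictions_rank0(local_payload, gloo_group=None):
    """Sketch: gather per-rank predict outputs to rank 0 and restore original
    sample order. Assumes local_payload maps keys like 'idx', 'orig_idx',
    'betas', 'mus' to numpy arrays with the sample axis first."""
    if not (dist.is_available() and dist.is_initialized()):
        return local_payload  # single-process run: nothing to gather

    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # Object gathers need a backend that supports them (commonly Gloo).
    gathered = [None] * world_size if rank == 0 else None
    dist.gather_object(local_payload, gathered, dst=0, group=gloo_group)

    if rank != 0:
        return None  # non-rank-0 processes return None by design

    # Merge payloads by concatenating every key on axis 0.
    merged = {k: np.concatenate([p[k] for p in gathered], axis=0) for k in gathered[0]}

    # Stable sort by idx (falling back to orig_idx), then drop padded
    # duplicates introduced by DistributedSampler.
    key = merged["idx"] if "idx" in merged else merged["orig_idx"]
    order = np.argsort(key, kind="stable")
    merged = {k: v[order] for k, v in merged.items()}
    _, first = np.unique(key[order], return_index=True)
    return {k: v[first] for k, v in merged.items()}
```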
+ +## Training Configuration + +**Single node, single GPU**: Use `Trainer(accelerator="gpu", devices=1, strategy="auto")`. Recommended data settings: `pin_memory=True`, `num_workers` tuned to CPU count and batch size, `persistent_workers=True` if `num_workers > 0`. + +**Single node, multi-GPU**: Two supported patterns exist. + +Pattern A (Lightning spawns processes): Use a normal Python launch and let Lightning spawn processes. Set `devices` to GPU count and `strategy="ddp"` or `DDPStrategy(...)`. Example for 2 GPUs: `Trainer(accelerator="gpu", devices=2, strategy="ddp")`. + +Pattern B (torchrun launch, recommended for clusters): Use torchrun to launch one process per GPU. Each process should use `devices=1` (or omit devices) and set `strategy="ddp"` or `DDPStrategy(...)`. Example for 2 GPUs on one node: +```bash +export CUDA_VISIBLE_DEVICES=0,1 +torchrun --standalone --nproc_per_node=2 your_script.py +``` + +Recommended environment variables for NCCL: +```bash +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 +export TOKENIZERS_PARALLELISM=false +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 +export NCCL_DEBUG=WARN +export NCCL_IB_DISABLE=1 +export NCCL_SOCKET_IFNAME=eno1 +``` + +Set `NCCL_SOCKET_IFNAME` to your actual NIC. If you see hangs, confirm `WORLD_SIZE`, `RANK`, `LOCAL_RANK` are correct. + +## Sklearn-Style Wrappers + +The sklearn-style API uses the same DataModule path when available, keeping the public surface area stable. +```python +from contextualized.easy.wrappers import ContextualizedRegressor + +m = ContextualizedRegressor( + num_archetypes=8, + encoder_type="mlp", + encoder_kwargs={"width": 256, "layers": 2, "link_fn": "identity"}, + trainer_kwargs={ + "accelerator": "gpu", + "devices": 1, + "max_epochs": 50, + }, +) + +m.fit(C_train, X_train, Y_train, val_split=0.2) +yhat = m.predict(C_test, X_test) +betas, mus = m.predict_params(C_test) +``` + +**DDP behavior for wrappers**: Under DDP, rank 0 returns arrays and non-rank-0 returns `None`. If running under torchrun, only rank 0 should consume outputs. Do not stack outputs across ranks. + +## Batch Sizing for Strong Scaling + +Strong scaling means fixed global batch size. Per-GPU batch is `global_batch_size / world_size`. Set one global batch size and keep it fixed. As you add GPUs, let the per-GPU batch shrink. This reduces OOM risk when scaling up GPU count. + +If you hit OOM: reduce global batch first, then reduce model width or layers, then reduce DataLoader workers if CPU memory becomes the bottleneck. + +## Data Movement and Pinned Memory + +Pinned CPU buffers improve host-to-device transfer speed. Recommended for GPU training: `pin_memory=True` in DataLoader, use pinned host tensors in synthetic or streaming benchmarks. For real datasets, use normal CPU tensors and let DataLoader pin memory. + +## Common Pitfalls + +**Wrong batch dict keys**: Models expect batch dict keys `contexts`, `predictors`, `outcomes`. If using custom datasets, match these names. + +**Device mismatch under torchrun**: Under torchrun, processes already map to devices. Use `devices=1` per process (or omit devices), and each process uses `LOCAL_RANK` as its CUDA device. + +**Dropping samples during eval**: Do not drop samples for validation, test, or predict. It breaks ordering assumptions. Use `drop_last=False` for eval loaders. + +**Expecting prediction output on every rank**: Prediction helpers are rank-0-only by design. Non-rank-0 returns `None`. 
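As a concrete guard, and continuing the hypothetical `m`, `C_test`, and `X_test` from the wrapper example above, a torchrun-launched script can consume outputs like this (a sketch of the calling pattern, not part of the package):

```python
import os
import numpy as np

# Every rank executes this script under torchrun, but only rank 0 gets arrays back.
rank = int(os.environ.get("RANK", "0"))

preds = m.predict(C_test, X_test)   # ndarray on rank 0, None on other ranks
params = m.predict_params(C_test)   # (betas, mus) on rank 0, None on other ranks

if rank == 0 and preds is not None:
    np.save("predictions.npy", preds)  # persist results from the main process only
# Never stack or merge these outputs across ranks; rank 0 already holds the full set.
```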
+
## Minimal DDP Launch Recipe

Single node, 4 GPUs:
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
torchrun --standalone --nproc_per_node=4 train_script.py
```

Trainer settings that work well: `strategy=DDPStrategy(find_unused_parameters=False, broadcast_buffers=False)`, mixed precision on GPU if it is stable for your model, and logging sync only on epoch metrics (not per step).

## Benchmark Pattern for Scaling

A good scaling benchmark fixes the global batch size, runs a warmup window before measuring steady state, feeds already-batched inputs to reduce DataLoader overhead, and uses pinned CPU memory when measuring host-to-device transfer. Keep loss computation simple, shape-safe, and inside the benchmark harness; avoid clever reshapes that depend on internal model conventions.

## Summary

HPC readiness comes from three components: map-style datasets and a DataModule so Lightning can shard correctly, prediction payloads that carry stable indices and are CPU-friendly, and rank-0 gather and reorder so predictions match the user-expected order. Following these patterns keeps multi-GPU training and prediction stable and repeatable. \ No newline at end of file diff --git a/checkpoints/epoch=0-step=2-v1.ckpt b/checkpoints/epoch=0-step=2-v1.ckpt new file mode 100644 index 00000000..7979456c Binary files /dev/null and b/checkpoints/epoch=0-step=2-v1.ckpt differ diff --git a/checkpoints/epoch=0-step=2-v2.ckpt b/checkpoints/epoch=0-step=2-v2.ckpt new file mode 100644 index 00000000..988fd9a9 Binary files /dev/null and b/checkpoints/epoch=0-step=2-v2.ckpt differ diff --git a/checkpoints/epoch=0-step=2-v3.ckpt b/checkpoints/epoch=0-step=2-v3.ckpt new file mode 100644 index 00000000..3bf8a039 Binary files /dev/null and b/checkpoints/epoch=0-step=2-v3.ckpt differ diff --git a/checkpoints/epoch=0-step=2.ckpt b/checkpoints/epoch=0-step=2.ckpt new file mode 100644 index 00000000..9a5597bb Binary files /dev/null and b/checkpoints/epoch=0-step=2.ckpt differ diff --git a/contextualized/__init__.py b/contextualized/__init__.py index 2f24b87b..3c5e4932 100644 --- a/contextualized/__init__.py +++ b/contextualized/__init__.py @@ -2,7 +2,13 @@ models, distributions, and functions with context-specific parameters. For more details, please refer to contextualized.ml.
""" +import torch +if torch.cuda.is_available(): + try: + torch.set_float32_matmul_precision("high") + except Exception: + pass from contextualized import analysis from contextualized import dags from contextualized import easy @@ -10,3 +16,5 @@ from contextualized import baselines from contextualized import utils from contextualized.utils import * + + diff --git a/contextualized/callbacks.py b/contextualized/callbacks.py index b41d4d28..5c81d425 100644 --- a/contextualized/callbacks.py +++ b/contextualized/callbacks.py @@ -87,4 +87,3 @@ def write_on_batch_end( self.arr[n, yi, xi, 0] = beta self.arr[n, yi, xi, 1] = mu - diff --git a/contextualized/data.py b/contextualized/data.py index 9c46291f..a43bc03c 100644 --- a/contextualized/data.py +++ b/contextualized/data.py @@ -1,5 +1,5 @@ import torch -from lightning import LightningDataModule +from lightning.pytorch import LightningDataModule from contextualized.regression.datasets import MultivariateDataset, UnivariateDataset, MultitaskMultivariateDataset, MultitaskUnivariateDataset from sklearn.model_selection import train_test_split diff --git a/contextualized/easy/ContextualGAM.py b/contextualized/easy/ContextualGAM.py index 5ea6cda5..e09ce295 100644 --- a/contextualized/easy/ContextualGAM.py +++ b/contextualized/easy/ContextualGAM.py @@ -46,4 +46,4 @@ class ContextualGAMRegressor(ContextualizedRegressor): def __init__(self, **kwargs): kwargs["encoder_type"] = "ngam" - super().__init__(**kwargs) + super().__init__(**kwargs) \ No newline at end of file diff --git a/contextualized/easy/ContextualizedClassifier.py b/contextualized/easy/ContextualizedClassifier.py index 30a9d980..11bb17ef 100644 --- a/contextualized/easy/ContextualizedClassifier.py +++ b/contextualized/easy/ContextualizedClassifier.py @@ -55,4 +55,4 @@ def predict_proba(self, C, X, **kwargs): """ # Returns a np array of shape N samples, K outcomes, 2. probs = super().predict(C, X, **kwargs) - return np.array([1 - probs, probs]).T.swapaxes(0, 1) + return np.array([1 - probs, probs]).T.swapaxes(0, 1) \ No newline at end of file diff --git a/contextualized/easy/ContextualizedNetworks.py b/contextualized/easy/ContextualizedNetworks.py index 1c4a8f26..ec3e2155 100644 --- a/contextualized/easy/ContextualizedNetworks.py +++ b/contextualized/easy/ContextualizedNetworks.py @@ -5,6 +5,8 @@ from typing import * import numpy as np +import torch +import torch.distributed as dist from contextualized.easy.wrappers import SKLearnWrapper from contextualized.regression.trainers import CorrelationTrainer, MarkovTrainer @@ -21,24 +23,65 @@ from contextualized.dags.graph_utils import dag_pred_np +def _is_distributed() -> bool: + """Returns True if torch.distributed is available and initialized.""" + return dist.is_available() and dist.is_initialized() + + +def _rank() -> int: + """Returns the current distributed rank, defaulting to 0 when not distributed.""" + if _is_distributed(): + return dist.get_rank() + return 0 + + class ContextualizedNetworks(SKLearnWrapper): """ sklearn-like interface to Contextualized Networks. """ def _split_train_data( - self, C: np.ndarray, X: np.ndarray, **kwargs - ) -> Tuple[List[np.ndarray], List[np.ndarray]]: + self, + C: np.ndarray, + X: np.ndarray, + Y: Optional[np.ndarray] = None, + *, + Y_required: bool = False, + val_split: Optional[float] = None, + random_state: Optional[int] = None, + shuffle: bool = True, + **kwargs, + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: """Splits data into train and test sets. 
+ Notes: + This override exists to set the default behavior for networks (Y is not required), + while preserving compatibility with SKLearnWrapper._split_train_data. + Args: C (np.ndarray): Contextual features for each sample. X (np.ndarray): The data matrix. + Y (Optional[np.ndarray], optional): Optional targets. Defaults to None. + Y_required (bool, optional): Whether Y is required. Defaults to False. + val_split (Optional[float], optional): Validation split fraction. Defaults to None. + random_state (Optional[int], optional): Random state for splitting. Defaults to None. + shuffle (bool, optional): Whether to shuffle before splitting. Defaults to True. + **kwargs: Additional keyword arguments forwarded to the base implementation. Returns: - Tuple[List[np.ndarray], List[np.ndarray]]: The train and test sets for C and X as ([C_train, X_train], [C_test, X_test]). + Tuple[np.ndarray, Optional[np.ndarray]]: The train/test split outputs as returned by + SKLearnWrapper._split_train_data. """ - return super()._split_train_data(C, X, Y_required=False, **kwargs) + return super()._split_train_data( + C, + X, + Y, + Y_required=Y_required, + val_split=val_split, + random_state=random_state, + shuffle=shuffle, + **kwargs, + ) def predict_networks( self, @@ -51,20 +94,39 @@ def predict_networks( List[np.ndarray], Tuple[np.ndarray, np.ndarray], Tuple[List[np.ndarray], List[np.ndarray]], + None, ]: """Predicts context-specific networks given contextual features. + Notes: + Under DDP, prediction helpers are rank-0 only (by design in the trainers/wrapper). + In such cases, this method returns None on non-rank-0 processes. + Args: - C (np.ndarray): Contextual features for each sample (n_samples, n_context_features) - with_offsets (bool, optional): If True, returns both the network parameters and offsets. Defaults to False. - individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to False. + C (np.ndarray): Contextual features for each sample (n_samples, n_context_features). + with_offsets (bool, optional): If True, returns both the network parameters and + offsets (when available). Defaults to False. + individual_preds (bool, optional): If True, returns the predictions for each + bootstrap. Defaults to False. + **kwargs: Keyword arguments forwarded to predict_params. Returns: - Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray, np.ndarray], Tuple[List[np.ndarray], List[np.ndarray]]]: The predicted network parameters (and offsets if with_offsets is True). Returned as lists of individual bootstraps if individual_preds is True. + Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray, np.ndarray], + Tuple[List[np.ndarray], List[np.ndarray]], None]: + The predicted network parameters (and offsets if with_offsets is True). + Returned as lists of individual bootstraps if individual_preds is True. + Returns None on non-rank-0 under DDP. """ - betas, mus = self.predict_params( + out = self.predict_params( C, individual_preds=individual_preds, uses_y=False, **kwargs ) + if out is None: + return None + + betas, mus = out + if betas is None: + return None + if with_offsets: return betas, mus return betas @@ -72,33 +134,38 @@ def predict_networks( def predict_X( self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs ) -> Union[np.ndarray, List[np.ndarray]]: - """Reconstructs the data matrix based on predicted contextualized networks and the true data matrix. 
+ """Reconstructs the data matrix based on predicted contextualized networks and + the true data matrix. + Useful for measuring reconstruction error or for imputation. Args: - C (np.ndarray): Contextual features for each sample (n_samples, n_context_features) - X (np.ndarray): The data matrix (n_samples, n_features) - individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to False. - **kwargs: Keyword arguments for the Lightning trainer's predict_y method. + C (np.ndarray): Contextual features for each sample (n_samples, n_context_features). + X (np.ndarray): The data matrix (n_samples, n_features). + individual_preds (bool, optional): If True, returns the predictions for each + bootstrap. Defaults to False. + **kwargs: Keyword arguments for the Lightning trainer's prediction method. Returns: - Union[np.ndarray, List[np.ndarray]]: The predicted data matrix, or matrices for each bootstrap if individual_preds is True (n_samples, n_features). + Union[np.ndarray, List[np.ndarray]]: The predicted data matrix, or matrices for + each bootstrap if individual_preds is True (n_samples, n_features). """ return self.predict(C, X, individual_preds=individual_preds, **kwargs) class ContextualizedCorrelationNetworks(ContextualizedNetworks): """ - Contextualized Correlation Networks reveal context-varying feature correlations, interaction strengths, dependencies in feature groups. - Uses the Contextualized Networks model, see the `paper `__ for detailed estimation procedures. - - Args: - n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1. - num_archetypes (int, optional): Number of archetypes to use. Defaults to 10. Always uses archetypes in the ContextualizedMetaModel. - encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp". - alpha (float, optional): Regularization strength. Defaults to 0.0. - mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0. - l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0. + Contextualized Correlation Networks reveal context-varying feature correlations, + interaction strengths, and dependencies in feature groups. + + Uses the Contextualized Networks model. + + Notes: + This implementation includes CPU/DDP-safe prediction behavior: + - When using a LightningDataModule outside Trainer.fit/predict, setup(stage="predict") + is called before predict_dataloader(). + - Under DDP, only rank-0 returns numpy outputs; non-rank-0 returns None, while still + executing the per-model predict loop to avoid collective mismatches/hangs. """ def __init__(self, **kwargs): @@ -108,74 +175,172 @@ def __init__(self, **kwargs): def predict_correlation( self, C: np.ndarray, individual_preds: bool = True, squared: bool = True - ) -> Union[np.ndarray, List[np.ndarray]]: + ) -> Union[np.ndarray, List[np.ndarray], None]: """Predicts context-specific correlations between features. + Notes: + Under DDP, only rank-0 returns numpy outputs. If any per-model prediction returns + None (rank-0-only behavior), this method returns None. + Args: - C (Numpy ndarray): Contextual features for each sample (n_samples, n_context_features) - individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to True. 
+ C (np.ndarray): Contextual features for each sample (n_samples, n_context_features). + individual_preds (bool, optional): If True, returns the predictions for each + bootstrap. Defaults to True. squared (bool, optional): If True, returns the squared correlations. Defaults to True. Returns: - Union[np.ndarray, List[np.ndarray]]: The predicted context-specific correlation matrices, or matrices for each bootstrap if individual_preds is True (n_samples, n_features, n_features). + Union[np.ndarray, List[np.ndarray], None]: + The predicted context-specific correlation matrices, or matrices for each + bootstrap if individual_preds is True (n_samples, n_features, n_features). + Returns None on non-rank-0 under DDP. """ - get_dataloader = lambda i: self.models[i].dataloader( - C, np.zeros((len(C), self.x_dim)) - ) - rhos = np.array( - [ - self.trainers[i].predict_params(self.models[i], get_dataloader(i))[0] - for i in range(len(self.models)) - ] + C_scaled = self._maybe_scale_C(C) + Y_zero = np.zeros((len(C_scaled), self.x_dim), dtype=np.float32) + + dm = self._build_datamodule( + C=C_scaled, + X=np.zeros((len(C_scaled), self.x_dim), dtype=np.float32), + Y=Y_zero, + predict_idx=np.arange(len(C_scaled)), + data_kwargs=dict( + train_batch_size=self._init_kwargs["data"].get("train_batch_size", 16), + val_batch_size=self._init_kwargs["data"].get("val_batch_size", 16), + test_batch_size=self._init_kwargs["data"].get("test_batch_size", 16), + predict_batch_size=self._init_kwargs["data"].get( + "predict_batch_size", 16 + ), + num_workers=self._init_kwargs["data"].get("num_workers", 0), + pin_memory=self._init_kwargs["data"].get( + "pin_memory", (self.accelerator in ("cuda", "gpu")) + ), + persistent_workers=self._init_kwargs["data"].get( + "persistent_workers", False + ), + drop_last=False, + shuffle_train=False, + shuffle_eval=False, + dtype=self._init_kwargs["data"].get("dtype", torch.float), + ), + task_type="singletask_univariate", ) + + dm.setup(stage="predict") + pred_loader = dm.predict_dataloader() + + saw_none = False + rhos_list: List[np.ndarray] = [] + + for i in range(len(self.models)): + rho_i = self.trainers[i].predict_correlation(self.models[i], pred_loader) + if rho_i is None: + saw_none = True + continue + rhos_list.append(rho_i) + + if saw_none: + return None + + rhos = np.array(rhos_list) + if individual_preds: if squared: return np.square(rhos) return rhos - else: - if squared: - return np.square(np.mean(rhos, axis=0)) - return np.mean(rhos, axis=0) + + mean_rhos = np.mean(rhos, axis=0) + if squared: + return np.square(mean_rhos) + return mean_rhos def measure_mses( self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False - ) -> Union[np.ndarray, List[np.ndarray]]: + ) -> Union[np.ndarray, List[np.ndarray], None]: """Measures mean-squared errors. + Notes: + This method computes MSEs from reconstructions returned by predict_X, including + handling potential (bootstrap, sample, feature) or (bootstrap, sample, feature, feature) + tensor shapes, and handling N_hat != N_true by truncation to min(N_hat, N_true). + Args: - C (np.ndarray): Contextual features for each sample (n_samples, n_context_features) - X (np.ndarray): The data matrix (n_samples, n_features) - individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to False. + C (np.ndarray): Contextual features for each sample (n_samples, n_context_features). + X (np.ndarray): The data matrix (n_samples, n_features). 
+ individual_preds (bool, optional): If True, returns the MSEs for each bootstrap. + Defaults to False. Returns: - Union[np.ndarray, List[np.ndarray]]: The mean-squared errors for each sample, or for each bootstrap if individual_preds is True (n_samples). + Union[np.ndarray, List[np.ndarray], None]: + The mean-squared errors for each sample, or for each bootstrap if + individual_preds is True (n_samples). Returns None on non-rank-0 under DDP. """ - betas, mus = self.predict_networks(C, individual_preds=True, with_offsets=True) - mses = np.zeros((len(betas), len(C))) # n_bootstraps x n_samples - for i in range(X.shape[-1]): - for j in range(X.shape[-1]): - tiled_xi = np.array([X[:, i] for _ in range(len(betas))]) - tiled_xj = np.array([X[:, j] for _ in range(len(betas))]) - residuals = tiled_xi - betas[:, :, i, j] * tiled_xj - mus[:, :, i, j] - mses += residuals**2 / (X.shape[-1] ** 2) - if not individual_preds: - mses = np.mean(mses, axis=0) - return mses + X_hat = self.predict_X(C, X, individual_preds=True) + if X_hat is None: + return None + + X_hat = np.array(X_hat) + + if X_hat.ndim not in (3, 4): + raise ValueError( + f"Unexpected X_hat ndim={X_hat.ndim} with shape {X_hat.shape} in " + "ContextualizedCorrelationNetworks.measure_mses" + ) + + N_true, F = X.shape + + if X_hat.ndim == 3: + B, N_hat, F_hat = X_hat.shape + if F_hat != F: + raise ValueError( + f"Feature dimension mismatch between X_hat (F={F_hat}) and X (F={F}) " + "in ContextualizedCorrelationNetworks.measure_mses" + ) + + N_eff = min(N_hat, N_true) + if N_hat != N_true: + X_hat = X_hat[:, :N_eff, :] + X_eff = X[:N_eff, :] + else: + X_eff = X + + X_true = X_eff[None, :, :] + residuals = X_hat - X_true + mses = (residuals**2).mean(axis=-1) + + else: + B, N_hat, F1, F2 = X_hat.shape + if F1 != F: + raise ValueError( + f"Feature dimension mismatch between X_hat (F1={F1}) and X (F={F}) " + "in ContextualizedCorrelationNetworks.measure_mses" + ) + + N_eff = min(N_hat, N_true) + if N_hat != N_true: + X_hat = X_hat[:, :N_eff, :, :] + X_eff = X[:N_eff, :] + else: + X_eff = X + + X_true = X_eff[None, :, :, None] + residuals = X_hat - X_true + mses = (residuals**2).mean(axis=(-1, -2)) + + if individual_preds: + return mses + return mses.mean(axis=0) class ContextualizedMarkovNetworks(ContextualizedNetworks): """ - Contextualized Markov Networks reveal context-varying feature dependencies, cliques, and modules. - Implemented as Contextualized Gaussian Precision Matrices, directly interpretable as Markov Networks. - Uses the Contextualized Networks model, see the `paper `__ for detailed estimation procedures. - - Args: - n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1. - num_archetypes (int, optional): Number of archetypes to use. Defaults to 10. Always uses archetypes in the ContextualizedMetaModel. - encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp". - alpha (float, optional): Regularization strength. Defaults to 0.0. - mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0. - l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0. + Contextualized Markov Networks reveal context-varying feature dependencies, cliques, + and modules. + + Implemented as Contextualized Gaussian Precision Matrices, directly interpretable as + Markov Networks. 
+ + Notes: + This implementation includes CPU/DDP-safe prediction behavior analogous to + ContextualizedCorrelationNetworks.predict_correlation. """ def __init__(self, **kwargs): @@ -183,98 +348,135 @@ def __init__(self, **kwargs): def predict_precisions( self, C: np.ndarray, individual_preds: bool = True - ) -> Union[np.ndarray, List[np.ndarray]]: + ) -> Union[np.ndarray, List[np.ndarray], None]: """Predicts context-specific precision matrices. - Can be converted to context-specific Markov networks by binarizing the networks and setting all non-zero entries to 1. - Can be converted to context-specific covariance matrices by taking the inverse. + + Notes: + Under DDP, only rank-0 returns numpy outputs. If any per-model prediction returns + None (rank-0-only behavior), this method returns None. Args: - C (np.ndarray): Contextual features for each sample (n_samples, n_context_features) - individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to True. + C (np.ndarray): Contextual features for each sample (n_samples, n_context_features). + individual_preds (bool, optional): If True, returns the predictions for each + bootstrap. Defaults to True. Returns: - Union[np.ndarray, List[np.ndarray]]: The predicted context-specific Markov networks as precision matrices, or matrices for each bootstrap if individual_preds is True (n_samples, n_features, n_features). + Union[np.ndarray, List[np.ndarray], None]: + The predicted context-specific precision matrices, or matrices for each + bootstrap if individual_preds is True (n_samples, n_features, n_features). + Returns None on non-rank-0 under DDP. """ - get_dataloader = lambda i: self.models[i].dataloader( - C, np.zeros((len(C), self.x_dim)) - ) - precisions = np.array( - [ - self.trainers[i].predict_precision(self.models[i], get_dataloader(i)) - for i in range(len(self.models)) - ] + C_scaled = self._maybe_scale_C(C) + Y_zero = np.zeros((len(C_scaled), self.x_dim), dtype=np.float32) + + dm = self._build_datamodule( + C=C_scaled, + X=np.zeros((len(C_scaled), self.x_dim), dtype=np.float32), + Y=Y_zero, + predict_idx=np.arange(len(C_scaled)), + data_kwargs=dict( + train_batch_size=self._init_kwargs["data"].get("train_batch_size", 16), + val_batch_size=self._init_kwargs["data"].get("val_batch_size", 16), + test_batch_size=self._init_kwargs["data"].get("test_batch_size", 16), + predict_batch_size=self._init_kwargs["data"].get( + "predict_batch_size", 16 + ), + num_workers=self._init_kwargs["data"].get("num_workers", 0), + pin_memory=self._init_kwargs["data"].get( + "pin_memory", (self.accelerator in ("cuda", "gpu")) + ), + persistent_workers=self._init_kwargs["data"].get( + "persistent_workers", False + ), + drop_last=False, + shuffle_train=False, + shuffle_eval=False, + dtype=self._init_kwargs["data"].get("dtype", torch.float), + ), + task_type="singletask_univariate", ) + + dm.setup(stage="predict") + pred_loader = dm.predict_dataloader() + + saw_none = False + prec_list: List[np.ndarray] = [] + + for i in range(len(self.models)): + p_i = self.trainers[i].predict_precision(self.models[i], pred_loader) + if p_i is None: + saw_none = True + continue + prec_list.append(p_i) + + if saw_none: + return None + + precisions = np.array(prec_list) if individual_preds: return precisions return np.mean(precisions, axis=0) def measure_mses( self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False - ) -> Union[np.ndarray, List[np.ndarray]]: + ) -> Union[np.ndarray, List[np.ndarray], None]: """Measures mean-squared 
errors. Args: - C (np.ndarray): Contextual features for each sample (n_samples, n_context_features) - X (np.ndarray): The data matrix (n_samples, n_features) - individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to False. + C (np.ndarray): Contextual features for each sample (n_samples, n_context_features). + X (np.ndarray): The data matrix (n_samples, n_features). + individual_preds (bool, optional): If True, returns the MSEs for each bootstrap. + Defaults to False. Returns: - Union[np.ndarray, List[np.ndarray]]: The mean-squared errors for each sample, or for each bootstrap if individual_preds is True (n_samples). + Union[np.ndarray, List[np.ndarray], None]: + The mean-squared errors for each sample, or for each bootstrap if + individual_preds is True (n_samples). Returns None on non-rank-0 under DDP. """ - betas, mus = self.predict_networks(C, individual_preds=True, with_offsets=True) + out = self.predict_networks(C, individual_preds=True, with_offsets=True) + if out is None: + return None + betas, mus = out + mses = np.zeros((len(betas), len(C))) # n_bootstraps x n_samples - for bootstrap in range(len(betas)): - for i in range(X.shape[-1]): - # betas are n_boostraps x n_samples x n_features x n_features - # preds[bootstrap, sample, i] = X[sample, :].dot(betas[bootstrap, sample, i, :]) + F = X.shape[-1] + for b in range(len(betas)): + for i in range(F): preds = np.array( [ - X[j].dot(betas[bootstrap, j, i, :]) + mus[bootstrap, j, i] + X[j].dot(betas[b, j, i, :]) + mus[b, j, i] for j in range(len(X)) ] ) residuals = X[:, i] - preds - mses[bootstrap, :] += residuals**2 / (X.shape[-1]) - if not individual_preds: - mses = np.mean(mses, axis=0) - return mses + mses[b, :] += residuals**2 / F + + if individual_preds: + return mses + return np.mean(mses, axis=0) class ContextualizedBayesianNetworks(ContextualizedNetworks): """ - Contextualized Bayesian Networks and Directed Acyclic Graphs (DAGs) reveal context-dependent causal relationships, effect sizes, and variable ordering. - Uses the NOTMAD model, see the `paper `__ for detailed estimation procedures. - - Args: - n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1. - num_archetypes (int, optional): Number of archetypes to use. Defaults to 16. Always uses archetypes in the ContextualizedMetaModel. - encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp". - archetype_dag_loss_type (str, optional): The type of loss to use for the archetype loss. Defaults to "l1". - archetype_l1 (float, optional): The strength of the l1 regularization for the archetype loss. Defaults to 0.0. - archetype_dag_params (dict, optional): Parameters for the archetype loss. Defaults to {"loss_type": "l1", "params": {"alpha": 0.0, "rho": 0.0, "s": 0.0, "tol": 1e-4}}. - archetype_dag_loss_params (dict, optional): Parameters for the archetype loss. Defaults to {"alpha": 0.0, "rho": 0.0, "s": 0.0, "tol": 1e-4}. - archetype_alpha (float, optional): The strength of the alpha regularization for the archetype loss. Defaults to 0.0. - archetype_rho (float, optional): The strength of the rho regularization for the archetype loss. Defaults to 0.0. - archetype_s (float, optional): The strength of the s regularization for the archetype loss. Defaults to 0.0. - archetype_tol (float, optional): The tolerance for the archetype loss. Defaults to 1e-4. - archetype_use_dynamic_alpha_rho (bool, optional): Whether to use dynamic alpha and rho for the archetype loss. 
Defaults to False. - init_mat (np.ndarray, optional): The initial adjacency matrix for the archetype loss. Defaults to None. - num_factors (int, optional): The number of factors for the archetype loss. Defaults to 0. - factor_mat_l1 (float, optional): The strength of the l1 regularization for the factor matrix for the archetype loss. Defaults to 0. - sample_specific_dag_loss_type (str, optional): The type of loss to use for the sample-specific loss. Defaults to "l1". - sample_specific_alpha (float, optional): The strength of the alpha regularization for the sample-specific loss. Defaults to 0.0. - sample_specific_rho (float, optional): The strength of the rho regularization for the sample-specific loss. Defaults to 0.0. - sample_specific_s (float, optional): The strength of the s regularization for the sample-specific loss. Defaults to 0.0. - sample_specific_tol (float, optional): The tolerance for the sample-specific loss. Defaults to 1e-4. - sample_specific_use_dynamic_alpha_rho (bool, optional): Whether to use dynamic alpha and rho for the sample-specific loss. Defaults to False. + Contextualized Bayesian Networks and Directed Acyclic Graphs (DAGs) reveal + context-dependent causal relationships, effect sizes, and variable ordering. + + Uses the NOTMAD model. + + Notes: + This wrapper preserves the HPC/DDP behavior: rank-0 produces arrays, non-rank-0 + returns None where applicable. """ def _parse_private_init_kwargs(self, **kwargs): - """ - Parses the kwargs for the NOTMAD model. + """Parses the kwargs for the NOTMAD model. Args: - **kwargs: Keyword arguments for the NOTMAD model, including the encoder, archetype loss, sample-specific loss, and optimization parameters. + **kwargs: Keyword arguments for the NOTMAD model, including the encoder, + archetype loss, sample-specific loss, and optimization parameters. + + Returns: + List[str]: Names of kwargs consumed/handled by this parser. """ # Encoder Parameters self._init_kwargs["model"]["encoder_kwargs"] = { @@ -323,11 +525,11 @@ def _parse_private_init_kwargs(self, **kwargs): self._init_kwargs["model"]["archetype_loss_params"]["dag"]["params"][ param ] = kwargs.pop(f"archetype_{param}", value) + + # Sample-specific parameters sample_specific_dag_loss_type = kwargs.pop( "sample_specific_dag_loss_type", DEFAULT_DAG_LOSS_TYPE ) - - # Sample-specific parameters self._init_kwargs["model"]["sample_specific_loss_params"] = { "l1": kwargs.pop("sample_specific_l1", 0.0), "dag": kwargs.pop( @@ -336,7 +538,9 @@ def _parse_private_init_kwargs(self, **kwargs): "loss_type": sample_specific_dag_loss_type, "params": kwargs.pop( "sample_specific_dag_loss_params", - DEFAULT_DAG_LOSS_PARAMS[sample_specific_dag_loss_type].copy(), + DEFAULT_DAG_LOSS_PARAMS[ + sample_specific_dag_loss_type + ].copy(), ), }, ), @@ -401,31 +605,41 @@ def __init__(self, **kwargs): def predict_params( self, C: np.ndarray, **kwargs - ) -> Union[np.ndarray, List[np.ndarray]]: - """Predicts context-specific Bayesian network parameters as linear coefficients in a linear structural equation model (SEM). + ) -> Union[np.ndarray, List[np.ndarray], None]: + """Predicts context-specific Bayesian network parameters as linear coefficients + in a linear structural equation model (SEM). Args: - C (np.ndarray): Contextual features for each sample (n_samples, n_context_features) - **kwargs: Keyword arguments for the contextualized.dags.GraphTrainer's predict_params method. + C (np.ndarray): Contextual features for each sample (n_samples, n_context_features). 
+ **kwargs: Keyword arguments for contextualized.dags.GraphTrainer.predict_params. Returns: - Union[np.ndarray, List[np.ndarray]]: The linear coefficients of the predicted context-specific Bayesian network parameters (n_samples, n_features, n_features). Returned as lists of individual bootstraps if individual_preds is True. + Union[np.ndarray, List[np.ndarray], None]: + The linear coefficients of the predicted context-specific Bayesian network + parameters (n_samples, n_features, n_features). Returned as lists of + individual bootstraps if individual_preds is True. Returns None on + non-rank-0 under DDP. """ # No mus for NOTMAD at present. return super().predict_params(C, model_includes_mus=False, **kwargs) def predict_networks( self, C: np.ndarray, project_to_dag: bool = True, **kwargs - ) -> Union[np.ndarray, List[np.ndarray]]: + ) -> Union[np.ndarray, List[np.ndarray], None]: """Predicts context-specific Bayesian networks. Args: - C (np.ndarray): Contextual features for each sample (n_samples, n_context_features) - project_to_dag (bool, optional): If True, guarantees returned graphs are DAGs by trimming edges until acyclicity is satisified. Defaults to True. - **kwargs: Keyword arguments for the contextualized.dags.GraphTrainer's predict_params method. + C (np.ndarray): Contextual features for each sample (n_samples, n_context_features). + project_to_dag (bool, optional): If True, guarantees returned graphs are DAGs by + trimming edges until acyclicity is satisified. Defaults to True. + **kwargs: Keyword arguments for contextualized.dags.GraphTrainer.predict_params. Returns: - Union[np.ndarray, List[np.ndarray]]: The linear coefficients of the predicted context-specific Bayesian network parameters (n_samples, n_features, n_features). Returned as lists of individual bootstraps if individual_preds is True. + Union[np.ndarray, List[np.ndarray], None]: + The linear coefficients of the predicted context-specific Bayesian network + parameters (n_samples, n_features, n_features). Returned as lists of + individual bootstraps if individual_preds is True. Returns None on + non-rank-0 under DDP. """ if kwargs.pop("with_offsets", False): print("No offsets can be returned by NOTMAD.") @@ -436,23 +650,30 @@ def predict_networks( def measure_mses( self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs - ) -> Union[np.ndarray, List[np.ndarray]]: + ) -> Union[np.ndarray, List[np.ndarray], None]: """Measures mean-squared errors. Args: - C (np.ndarray): Contextual features for each sample (n_samples, n_context_features) - X (np.ndarray): The data matrix (n_samples, n_features) - individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to False. - **kwargs: Keyword arguments for the contextualized.dags.GraphTrainer's predict_params method. + C (np.ndarray): Contextual features for each sample (n_samples, n_context_features). + X (np.ndarray): The data matrix (n_samples, n_features). + individual_preds (bool, optional): If True, returns the MSEs for each bootstrap. + Defaults to False. + **kwargs: Keyword arguments for contextualized.dags.GraphTrainer.predict_params. Returns: - Union[np.ndarray, List[np.ndarray]]: The mean-squared errors for each sample, or for each bootstrap if individual_preds is True (n_samples). + Union[np.ndarray, List[np.ndarray], None]: + The mean-squared errors for each sample, or for each bootstrap if + individual_preds is True (n_samples). Returns None on non-rank-0 under DDP. 
""" betas = self.predict_networks(C, individual_preds=True, **kwargs) + if betas is None: + return None + mses = np.zeros((len(betas), len(C))) # n_bootstraps x n_samples - for bootstrap in range(len(betas)): - X_pred = dag_pred_np(X, betas[bootstrap]) - mses[bootstrap, :] = np.mean((X - X_pred) ** 2, axis=1) - if not individual_preds: - mses = np.mean(mses, axis=0) - return mses + for b in range(len(betas)): + X_pred = dag_pred_np(X, betas[b]) + mses[b, :] = np.mean((X - X_pred) ** 2, axis=1) + + if individual_preds: + return mses + return np.mean(mses, axis=0) diff --git a/contextualized/easy/ContextualizedRegressor.py b/contextualized/easy/ContextualizedRegressor.py index 275f2ee9..8e1a350f 100644 --- a/contextualized/easy/ContextualizedRegressor.py +++ b/contextualized/easy/ContextualizedRegressor.py @@ -7,10 +7,7 @@ ContextualizedRegression, ) from contextualized.easy.wrappers import SKLearnWrapper -from contextualized.regression import RegressionTrainer - -# TODO: Multitask metamodels -# TODO: Task-specific link functions. +from contextualized.regression.trainers import RegressionTrainer class ContextualizedRegressor(SKLearnWrapper): @@ -35,12 +32,11 @@ def __init__(self, **kwargs): elif self.num_archetypes > 0: constructor = ContextualizedRegression else: - print( - f""" - Was told to construct a ContextualizedRegressor with {self.num_archetypes} - archetypes, but this should be a non-negative integer.""" + raise ValueError( + f"num_archetypes must be a non-negative integer, got {self.num_archetypes}." ) + extra_model_kwargs = ["base_param_predictor", "base_y_predictor", "y_dim"] extra_data_kwargs = ["Y_val"] trainer_constructor = RegressionTrainer @@ -53,4 +49,4 @@ def __init__(self, **kwargs): ) def _split_train_data(self, C, X, Y=None, Y_required=False, **kwargs): - return super()._split_train_data(C, X, Y, Y_required=True, **kwargs) + return super()._split_train_data(C, X, Y, Y_required=True, **kwargs) \ No newline at end of file diff --git a/contextualized/easy/tests.py b/contextualized/easy/tests.py index 2f368468..5ca6d269 100644 --- a/contextualized/easy/tests.py +++ b/contextualized/easy/tests.py @@ -398,4 +398,4 @@ def test_regressor_normalization(self): if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file diff --git a/contextualized/easy/wrappers/SKLearnWrapper.py b/contextualized/easy/wrappers/SKLearnWrapper.py index 101966e8..83968ad4 100644 --- a/contextualized/easy/wrappers/SKLearnWrapper.py +++ b/contextualized/easy/wrappers/SKLearnWrapper.py @@ -4,17 +4,26 @@ import copy import os -from typing import * +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np -from pytorch_lightning.callbacks.early_stopping import EarlyStopping -from pytorch_lightning.callbacks import ModelCheckpoint +import torch +import torch.distributed as dist +from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.callbacks.early_stopping import EarlyStopping +from lightning.pytorch.strategies import DDPStrategy from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler -import torch from contextualized.functions import LINK_FUNCTIONS -from contextualized.regression import REGULARIZERS, LOSSES +from contextualized.regression import LOSSES, REGULARIZERS + +# Prefer the new, DDP-safe DataModule path when available. 
+try: + from contextualized.regression.datamodules import ContextualizedRegressionDataModule +except Exception: + ContextualizedRegressionDataModule = None + DEFAULT_LEARNING_RATE = 1e-3 DEFAULT_N_BOOTSTRAPS = 1 @@ -30,26 +39,50 @@ DEFAULT_NORMALIZE = False +def _dist_initialized() -> bool: + return dist.is_available() and dist.is_initialized() + + +def _rank() -> int: + if _dist_initialized(): + return int(dist.get_rank()) + return int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", "0"))) + + +def _is_main_process() -> bool: + return _rank() == 0 + + +def _world_size_env() -> int: + try: + return int(os.environ.get("WORLD_SIZE", "1")) + except Exception: + return 1 + + class SKLearnWrapper: """ An sklearn-like wrapper for Contextualized models. Args: - base_constructor (class): The base class to construct the model. - extra_model_kwargs (dict): Extra kwargs to pass to the model constructor. - extra_data_kwargs (dict): Extra kwargs to pass to the dataloader constructor. - trainer_constructor (class): The trainer class to use. + base_constructor (callable/class): LightningModule constructor for the model. + extra_model_kwargs (list[str] or set[str]): Extra kw names allowed in "model". + extra_data_kwargs (list[str] or set[str]): Extra kw names allowed in "data". + trainer_constructor (class): Trainer class (should provide predict_y / predict_params for DDP-safe inference). n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1. encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp". loss_fn (torch.nn.Module, optional): Loss function. Defaults to LOSSES["mse"]. link_fn (torch.nn.Module, optional): Link function. Defaults to LINK_FUNCTIONS["identity"]. alpha (float, optional): Regularization strength. Defaults to 0.0. - mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to context-specific parameters or context-specific offsets. - l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. - normalize (bool, optional): If True, automatically standardize inputs during training and inverse-transform predictions. Defaults to False. + mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to + context-specific parameters or context-specific offsets. + l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 + vs l2 parameter norms. + normalize (bool, optional): If True, automatically standardize inputs during training and inverse-transform + predictions. Defaults to False. 
""" - def _set_defaults(self): + def _set_defaults(self) -> None: self.default_learning_rate = DEFAULT_LEARNING_RATE self.default_n_bootstraps = DEFAULT_N_BOOTSTRAPS self.default_es_patience = DEFAULT_ES_PATIENCE @@ -72,27 +105,46 @@ def __init__( **kwargs, ): self._set_defaults() + self.base_constructor = base_constructor - self.n_bootstraps = 1 - self.models = None - self.trainers = None - self.dataloaders = None - self.normalize = kwargs.pop("normalize", self.default_normalize) - self.scalers = {"C": None, "X": None, "Y": None} - self.context_dim = None - self.x_dim = None - self.y_dim = None self.trainer_constructor = trainer_constructor - self.accelerator = "gpu" if torch.cuda.is_available() else "cpu" - self.acceptable_kwargs = { + + self._trainer_init_kwargs = kwargs.pop("trainer_kwargs", None) + + self.n_bootstraps: int = 1 + self.models: Optional[List[Any]] = None + self.trainers: Optional[List[Any]] = None + self.dataloaders: Optional[Dict[str, List[Any]]] = None + + self.normalize: bool = bool(kwargs.pop("normalize", self.default_normalize)) + self.scalers: Dict[str, Optional[StandardScaler]] = {"C": None, "X": None, "Y": None} + + self.context_dim: Optional[int] = None + self.x_dim: Optional[int] = None + self.y_dim: Optional[int] = None + + self.accelerator: str = "gpu" if torch.cuda.is_available() else "cpu" + + self.acceptable_kwargs: Dict[str, List[str]] = { "data": [ "train_batch_size", "val_batch_size", "test_batch_size", + "predict_batch_size", "C_val", "X_val", + "Y_val", "val_split", + "random_state", + "num_workers", + "pin_memory", + "persistent_workers", + "drop_last", + "shuffle_train", + "shuffle_eval", + "dtype", ], + "model": [ "loss_fn", "link_fn", @@ -104,6 +156,10 @@ def __init__( "learning_rate", "context_dim", "x_dim", + "y_dim", + "width", + "layers", + "encoder_link_fn", ], "trainer": [ "max_epochs", @@ -112,6 +168,17 @@ def __init__( "callbacks", "callback_constructors", "accelerator", + "devices", + "strategy", + "plugins", + "logger", + "enable_checkpointing", + "num_sanity_val_steps", + "default_root_dir", + "log_every_n_steps", + "precision", + "enable_progress_bar", + "limit_val_batches", ], "fit": [], "wrapper": [ @@ -124,6 +191,7 @@ def __init__( "normalize", ], } + self._update_acceptable_kwargs("model", extra_model_kwargs) self._update_acceptable_kwargs("data", extra_data_kwargs) self._update_acceptable_kwargs( @@ -132,6 +200,7 @@ def __init__( self._update_acceptable_kwargs( "data", kwargs.pop("remove_data_kwargs", []), acceptable=False ) + self.convenience_kwargs = [ "alpha", "l1_ratio", @@ -141,124 +210,71 @@ def __init__( "layers", "encoder_link_fn", ] + self.constructor_kwargs = self._organize_constructor_kwargs(**kwargs) - self.constructor_kwargs["encoder_kwargs"]["width"] = kwargs.pop( - "width", self.constructor_kwargs["encoder_kwargs"]["width"] - ) - self.constructor_kwargs["encoder_kwargs"]["layers"] = kwargs.pop( - "layers", self.constructor_kwargs["encoder_kwargs"]["layers"] - ) - self.constructor_kwargs["encoder_kwargs"]["link_fn"] = kwargs.pop( - "encoder_link_fn", - self.constructor_kwargs["encoder_kwargs"].get( - "link_fn", self.default_encoder_link_fn - ), - ) + + if "encoder_kwargs" in self.constructor_kwargs: + ek = self.constructor_kwargs["encoder_kwargs"] + ek["width"] = kwargs.pop("width", ek.get("width", self.default_encoder_width)) + ek["layers"] = kwargs.pop("layers", ek.get("layers", self.default_encoder_layers)) + ek["link_fn"] = kwargs.pop( + "encoder_link_fn", ek.get("link_fn", self.default_encoder_link_fn) + ) + 
else: + self.constructor_kwargs["width"] = kwargs.pop( + "width", self.constructor_kwargs.get("width", self.default_encoder_width) + ) + self.constructor_kwargs["layers"] = kwargs.pop( + "layers", self.constructor_kwargs.get("layers", self.default_encoder_layers) + ) + self.constructor_kwargs["encoder_link_fn"] = kwargs.pop( + "encoder_link_fn", + self.constructor_kwargs.get("encoder_link_fn", self.default_encoder_link_fn), + ) + self.not_constructor_kwargs = { k: v for k, v in kwargs.items() if k not in self.constructor_kwargs and k not in self.convenience_kwargs } - # Some args will not be ignored by wrapper because sub-class will handle them. - # self.private_kwargs = kwargs.pop("private_kwargs", []) - # self.private_kwargs.append("private_kwargs") - # Add Predictor-Specific kwargs for parsing. - self._init_kwargs, unrecognized_general_kwargs = self._organize_kwargs( - **self.not_constructor_kwargs - ) - for key, value in self.constructor_kwargs.items(): - self._init_kwargs["model"][key] = value - recognized_private_init_kwargs = self._parse_private_init_kwargs(**kwargs) - for kwarg in set(unrecognized_general_kwargs) - set( - recognized_private_init_kwargs - ): - print(f"Received unknown keyword argument {kwarg}, probably ignoring.") - def _organize_and_expand_fit_kwargs(self, **kwargs): - """ - Private function to organize kwargs passed to constructor or - fit function. - """ - organized_kwargs, unrecognized_general_kwargs = self._organize_kwargs(**kwargs) - recognized_private_kwargs = self._parse_private_fit_kwargs(**kwargs) - for kwarg in set(unrecognized_general_kwargs) - set(recognized_private_kwargs): - print(f"Received unknown keyword argument {kwarg}, probably ignoring.") - # Add kwargs from __init__ to organized_kwargs, keeping more recent kwargs. - for category, category_kwargs in self._init_kwargs.items(): - for key, value in category_kwargs.items(): - if key not in organized_kwargs[category]: - organized_kwargs[category][key] = value - - # Add necessary kwargs. 
- def maybe_add_kwarg(category, kwarg, default_val): - if kwarg in self.acceptable_kwargs[category]: - organized_kwargs[category][kwarg] = organized_kwargs[category].get( - kwarg, default_val - ) + self._init_kwargs, unrecognized = self._organize_kwargs(**self.not_constructor_kwargs) - # Model - maybe_add_kwarg("model", "learning_rate", self.default_learning_rate) - maybe_add_kwarg("model", "context_dim", self.context_dim) - maybe_add_kwarg("model", "x_dim", self.x_dim) - maybe_add_kwarg("model", "y_dim", self.y_dim) - if ( - "num_archetypes" in organized_kwargs["model"] - and organized_kwargs["model"]["num_archetypes"] == 0 - ): - del organized_kwargs["model"]["num_archetypes"] + for k, v in self.constructor_kwargs.items(): + self._init_kwargs["model"][k] = v - # Data - maybe_add_kwarg("data", "train_batch_size", self.default_train_batch_size) - maybe_add_kwarg("data", "val_batch_size", self.default_val_batch_size) - maybe_add_kwarg("data", "test_batch_size", self.default_test_batch_size) + if isinstance(self._trainer_init_kwargs, dict): + self._init_kwargs["trainer"].update(self._trainer_init_kwargs) - # Wrapper - maybe_add_kwarg("wrapper", "n_bootstraps", self.default_n_bootstraps) - - # Trainer - maybe_add_kwarg( - "trainer", - "callback_constructors", - [ - lambda i: EarlyStopping( - monitor=kwargs.get("es_monitor", "val_loss"), - mode=kwargs.get("es_mode", "min"), - patience=kwargs.get("es_patience", self.default_es_patience), - verbose=kwargs.get("es_verbose", False), - min_delta=kwargs.get("es_min_delta", 0.00), - ) - ], - ) - organized_kwargs["trainer"]["callback_constructors"].append( - lambda i: ModelCheckpoint( - monitor=kwargs.get("es_monitor", "val_loss"), - dirpath=f"{kwargs.get('checkpoint_path', './lightning_logs')}/boot_{i}_checkpoints", - filename="{epoch}-{val_loss:.2f}", - ) - ) - maybe_add_kwarg("trainer", "accelerator", self.accelerator) - return organized_kwargs + recognized_private = set(self._parse_private_init_kwargs(**kwargs)) + for kw in unrecognized: + if kw not in recognized_private: + print(f"Received unknown keyword argument {kw}, probably ignoring.") - def _parse_private_fit_kwargs(self, **kwargs): + def _parse_private_fit_kwargs(self, **kwargs) -> List[str]: """ Parse private (model-specific) kwargs passed to fit function. Return the list of parsed kwargs. """ return [] - def _parse_private_init_kwargs(self, **kwargs): + def _parse_private_init_kwargs(self, **kwargs) -> List[str]: """ Parse private (model-specific) kwargs passed to constructor. Return the list of parsed kwargs. """ return [] - def _update_acceptable_kwargs(self, category, new_kwargs, acceptable=True): + def _update_acceptable_kwargs( + self, category, new_kwargs, acceptable: bool = True + ) -> None: """ Helper function to update the acceptable kwargs. + If acceptable=True, the new kwargs will be added to the list of acceptable kwargs. If acceptable=False, the new kwargs will be removed from the list of acceptable kwargs. """ + new_kwargs = list(new_kwargs) if new_kwargs is not None else [] if acceptable: self.acceptable_kwargs[category] = list( set(self.acceptable_kwargs[category]).union(set(new_kwargs)) @@ -268,139 +284,71 @@ def _update_acceptable_kwargs(self, category, new_kwargs, acceptable=True): set(self.acceptable_kwargs[category]) - set(new_kwargs) ) - def _organize_kwargs(self, **kwargs): + def _organize_kwargs(self, **kwargs) -> Tuple[Dict[str, Dict[str, Any]], List[str]]: """ - Private helper function to organize kwargs passed to constructor or - fit function. 
+ Private helper function to organize kwargs passed to constructor or fit function. Organizes kwargs into data, model, trainer, fit, and wrapper categories. """ - - # Combine default allowed keywords with subclass-specfic - organized_kwargs = {category: {} for category in self.acceptable_kwargs} - unrecognized_kwargs = [] - for kwarg, value in kwargs.items(): - # if kwarg in self.private_kwargs: - # continue - not_found = True - for category, category_kwargs in self.acceptable_kwargs.items(): - if kwarg in category_kwargs: - organized_kwargs[category][kwarg] = value - not_found = False + out = {cat: {} for cat in self.acceptable_kwargs} + unknown: List[str] = [] + for k, v in kwargs.items(): + placed = False + for cat, allowed in self.acceptable_kwargs.items(): + if k in allowed: + out[cat][k] = v + placed = True break - if not_found: - unrecognized_kwargs.append(kwarg) - - return organized_kwargs, unrecognized_kwargs + if not placed: + unknown.append(k) + return out, unknown - def _organize_constructor_kwargs(self, **kwargs): + def _organize_constructor_kwargs(self, **kwargs) -> Dict[str, Any]: """ - Helper function to set all the default constructor or changes allowed. + Helper function to set all the default constructor kwargs or changes allowed. """ - constructor_kwargs = {} - - def maybe_add_constructor_kwarg(kwarg, default_val): - if kwarg in self.acceptable_kwargs["model"]: - constructor_kwargs[kwarg] = kwargs.get(kwarg, default_val) - - maybe_add_constructor_kwarg("link_fn", LINK_FUNCTIONS["identity"]) - maybe_add_constructor_kwarg("univariate", False) - maybe_add_constructor_kwarg("encoder_type", self.default_encoder_type) - maybe_add_constructor_kwarg("loss_fn", LOSSES["mse"]) - maybe_add_constructor_kwarg( - "encoder_kwargs", - { - "width": kwargs.get("encoder_width", self.default_encoder_width), - "layers": kwargs.get("encoder_layers", self.default_encoder_layers), - "link_fn": kwargs.get("encoder_link_fn", self.default_encoder_link_fn), - }, - ) - if kwargs.get("subtype_probabilities", False): - constructor_kwargs["encoder_kwargs"]["link_fn"] = LINK_FUNCTIONS["softmax"] + ctor: Dict[str, Any] = {} + + def maybe_add(kw, default_val): + if kw in self.acceptable_kwargs["model"]: + ctor[kw] = kwargs.get(kw, default_val) + + maybe_add("link_fn", LINK_FUNCTIONS["identity"]) + maybe_add("univariate", False) + maybe_add("encoder_type", self.default_encoder_type) + maybe_add("loss_fn", LOSSES["mse"]) + + if "encoder_kwargs" in self.acceptable_kwargs["model"]: + ctor["encoder_kwargs"] = kwargs.get( + "encoder_kwargs", + { + "width": kwargs.get("encoder_width", self.default_encoder_width), + "layers": kwargs.get("encoder_layers", self.default_encoder_layers), + "link_fn": kwargs.get("encoder_link_fn", self.default_encoder_link_fn), + }, + ) + if kwargs.get("subtype_probabilities", False): + ctor["encoder_kwargs"]["link_fn"] = LINK_FUNCTIONS["softmax"] + else: + maybe_add("width", self.default_encoder_width) + maybe_add("layers", self.default_encoder_layers) + maybe_add("encoder_link_fn", self.default_encoder_link_fn) + if kwargs.get("subtype_probabilities", False): + ctor["encoder_link_fn"] = LINK_FUNCTIONS["softmax"] - # Make regularizer if "model_regularizer" in self.acceptable_kwargs["model"]: - if "alpha" in kwargs and kwargs["alpha"] > 0: - constructor_kwargs["model_regularizer"] = REGULARIZERS["l1_l2"]( - kwargs["alpha"], + alpha = float(kwargs.get("alpha", 0.0) or 0.0) + if alpha > 0: + ctor["model_regularizer"] = REGULARIZERS["l1_l2"]( + alpha, kwargs.get("l1_ratio", 1.0), 
kwargs.get("mu_ratio", 0.5), ) else: - constructor_kwargs["model_regularizer"] = kwargs.get( + ctor["model_regularizer"] = kwargs.get( "model_regularizer", REGULARIZERS["none"] ) - return constructor_kwargs - - def _split_train_data(self, C, X, Y=None, Y_required=False, **kwargs): - if "C_val" in kwargs: - if "X_val" in kwargs: - if Y_required and "Y_val" in kwargs: - train_data = [C, X, Y] - val_data = [kwargs["C_val"], X, kwargs["X_val"], Y, kwargs["Y_val"]] - return train_data, val_data - print("Y_val not provided, not using the provided C_val or X_val.") - else: - print("X_val not provided, not using the provided C_val.") - if "val_split" in kwargs: - if 0 <= kwargs["val_split"] < 1: - val_split = kwargs["val_split"] - else: - print( - """val_split={kwargs['val_split']} provided but should be between 0 - and 1 to indicate proportion of data to use as validation.""" - ) - raise ValueError - else: - val_split = self.default_val_split - if Y is None: - if val_split > 0: - C_train, C_val, X_train, X_val = train_test_split( - C, X, test_size=val_split, shuffle=True - ) - else: - C_train, X_train = C, X - C_val, X_val = C, X - train_data = [C_train, X_train] - val_data = [C_val, X_val] - else: - if val_split > 0: - C_train, C_val, X_train, X_val, Y_train, Y_val = train_test_split( - C, X, Y, test_size=val_split, shuffle=True - ) - else: - C_train, X_train, Y_train = C, X, Y - C_val, X_val, Y_val = C, X, Y - train_data = [C_train, X_train, Y_train] - val_data = [C_val, X_val, Y_val] - return train_data, val_data - def _build_dataloader(self, model, batch_size, *data): - """ - Helper function to build a single dataloder. - Expects *args to contain whatever data (C,X,Y) is necessary for this model. - """ - return model.dataloader(*data, batch_size=batch_size) - - def _build_dataloaders(self, model, train_data, val_data, **kwargs): - """ - :param model: - :param **kwargs: - """ - train_dataloader = self._build_dataloader( - model, - kwargs.get("train_batch_size", self.default_train_batch_size), - *train_data, - ) - if val_data is None: - val_dataloader = None - else: - val_dataloader = self._build_dataloader( - model, - kwargs.get("val_batch_size", self.default_val_batch_size), - *val_data, - ) - - return train_dataloader, val_dataloader + return ctor def _maybe_scale_C(self, C: np.ndarray) -> np.ndarray: if self.normalize and self.scalers["C"] is not None: @@ -412,104 +360,281 @@ def _maybe_scale_X(self, X: np.ndarray) -> np.ndarray: return self.scalers["X"].transform(X) return X - def predict( - self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs - ) -> Union[np.ndarray, List[np.ndarray]]: - """Predict outcomes from context C and predictors X. + def _nanrobust_mean(self, arr: np.ndarray, axis: int = 0) -> np.ndarray: + if not np.isfinite(arr).all(): + arr = np.where(np.isfinite(arr), arr, np.nan) + with np.errstate(invalid="ignore"): + mean = np.nanmean(arr, axis=axis) + if np.isnan(mean).any(): + raise RuntimeError( + "All bootstraps produced non-finite predictions for some items." + ) + return mean - Args: - C (np.ndarray): Context array of shape (n_samples, n_context_features) - X (np.ndarray): Predictor array of shape (N, n_features) - individual_preds (bool, optional): Whether to return individual predictions for each model. Defaults to False. 
+ def _default_num_workers(self, devices: int) -> int: + try: + n_cpu = os.cpu_count() or 0 + except Exception: + n_cpu = 0 + if n_cpu <= 0: + return 0 + if self.accelerator != "gpu": + return min(2, n_cpu) - Returns: - Union[np.ndarray, List[np.ndarray]]: The outcomes predicted by the context-specific models (n_samples, y_dim). Returned as lists of individual bootstraps if individual_preds is True. - """ - if not hasattr(self, "models") or self.models is None: - raise ValueError( - "Trying to predict with a model that hasn't been trained yet." - ) - predictions = np.array( - [ - self.trainers[i].predict_y( - self.models[i], - self.models[i].dataloader( - self._maybe_scale_C(C), - self._maybe_scale_X(X), - np.zeros((len(C), self.y_dim)), - ), - **kwargs, - ) - for i in range(len(self.models)) - ] - ) - if individual_preds: - preds = predictions - else: - preds = np.mean(predictions, axis=0) - if self.normalize and self.scalers["Y"] is not None: - if individual_preds: - preds = np.array([self.scalers["Y"].inverse_transform(p) for p in preds]) + world = max(1, _world_size_env() if _world_size_env() > 1 else devices) + cpu_per_rank = max(1, n_cpu // world) + return int(min(4, max(2, cpu_per_rank // 2))) + + def _safe_val_split(self, n: int, val_split: float) -> float: + vs = float(val_split) + if vs <= 0.0: + return 0.0 + if int(round(n * vs)) < 2: + return 0.0 + return vs + + def _resolve_train_val_arrays( + self, + C: np.ndarray, + X: np.ndarray, + Y: Optional[np.ndarray], + *, + C_val: Optional[np.ndarray], + X_val: Optional[np.ndarray], + Y_val: Optional[np.ndarray], + Y_required: bool, + val_split: float, + random_state: Optional[int] = None, + shuffle: bool = True, + ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], np.ndarray, Optional[np.ndarray]]: + if ( + C_val is not None + and X_val is not None + and (not Y_required or Y_val is not None) + ): + n_tr = int(C.shape[0]) + C_all = np.concatenate([C, C_val], axis=0) + X_all = np.concatenate([X, X_val], axis=0) + + if Y is None: + Y_all = None else: - preds = self.scalers["Y"].inverse_transform(preds) - return preds + if Y_val is None and Y_required: + raise ValueError("Y_val is required when Y is provided.") + Y_all = np.concatenate([Y, Y_val], axis=0) if Y_val is not None else Y - def predict_params( + train_idx = np.arange(n_tr) + val_idx = np.arange(n_tr, int(C_all.shape[0])) + return C_all, X_all, Y_all, train_idx, val_idx + + n = int(C.shape[0]) + vs = self._safe_val_split(n, val_split) + if vs <= 0.0: + return C, X, Y, np.arange(n), None + + split_kwargs = dict(test_size=vs, shuffle=shuffle) + if random_state is not None: + split_kwargs["random_state"] = random_state + + tr_idx, va_idx = train_test_split(np.arange(n), **split_kwargs) + return C, X, Y, tr_idx, va_idx + + + def _build_datamodule( self, C: np.ndarray, - individual_preds: bool = False, - model_includes_mus: bool = True, - **kwargs, - ) -> Union[ - np.ndarray, - List[np.ndarray], - Tuple[np.ndarray, np.ndarray], - Tuple[List[np.ndarray], List[np.ndarray]], - ]: - """ - Predict context-specific model parameters from context C. + X: np.ndarray, + Y: Optional[np.ndarray], + *, + train_idx: Optional[np.ndarray], + val_idx: Optional[np.ndarray], + test_idx: Optional[np.ndarray], + predict_idx: Optional[np.ndarray], + data_kwargs: Dict[str, Any], + task_type: str, + ): + if ContextualizedRegressionDataModule is None: + raise RuntimeError( + "ContextualizedRegressionDataModule is not available in this installation." 
+ ) - Args: - C (np.ndarray): Context array of shape (n_samples, n_context_features) - individual_preds (bool, optional): Whether to return individual model predictions for each bootstrap. Defaults to False, averaging across bootstraps. - model_includes_mus (bool, optional): Whether the model includes context-specific offsets (mu). Defaults to True. + dk = { + "train_batch_size": self.default_train_batch_size, + "val_batch_size": self.default_val_batch_size, + "test_batch_size": self.default_test_batch_size, + "predict_batch_size": self.default_val_batch_size, + "num_workers": 0, + "pin_memory": (self.accelerator == "gpu"), + "persistent_workers": False, + "drop_last": False, + "shuffle_train": True, + "shuffle_eval": False, + "dtype": torch.float, + } + dk.update(data_kwargs or {}) - Returns: - Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray, np.ndarray], Tuple[List[np.ndarray], List[np.ndarray]]: The parameters of the predicted context-specific models. - Returned as lists of individual bootstraps if individual_preds is True, otherwise averages the bootstraps for a better estimate. - If model_includes_mus is True, returns both coefficients and offsets as a tuple of (betas, mus). Otherwise, returns coefficients (betas) only. - For model_includes_mus=True, ([betas], [mus]) if individual_preds is True, otherwise (betas, mus). - For model_includes_mus=False, [betas] if individual_preds is True, otherwise betas. - betas is shape (n_samples, x_dim, y_dim) or (n_samples, x_dim) if y_dim = 1. - mus is shape (n_samples, y_dim) or (n_samples,) if y_dim = 1. + return ContextualizedRegressionDataModule( + C=C, + X=X, + Y=Y, + task_type=task_type, + train_idx=train_idx, + val_idx=val_idx, + test_idx=test_idx, + predict_idx=predict_idx, + train_batch_size=dk["train_batch_size"], + val_batch_size=dk["val_batch_size"], + test_batch_size=dk["test_batch_size"], + predict_batch_size=dk["predict_batch_size"], + num_workers=dk["num_workers"], + pin_memory=dk["pin_memory"], + persistent_workers=dk["persistent_workers"], + drop_last=dk["drop_last"], + shuffle_train=dk["shuffle_train"], + shuffle_eval=dk["shuffle_eval"], + dtype=dk["dtype"], + ) + + def _use_datamodule_for_model(self, model: Any) -> bool: + if ContextualizedRegressionDataModule is None: + return False + return not callable(getattr(model, "dataloader", None)) + + def _organize_and_expand_fit_kwargs(self, **kwargs) -> Dict[str, Dict[str, Any]]: """ - # Returns betas, mus - if kwargs.pop("uses_y", True): - get_dataloader = lambda i: self.models[i].dataloader( - self._maybe_scale_C(C), - np.zeros((len(C), self.x_dim)), - np.zeros((len(C), self.y_dim)) - ) + Private function to organize kwargs passed to constructor or fit function. 
+ """ + organized, unrecognized = self._organize_kwargs(**kwargs) + recognized_private = set(self._parse_private_fit_kwargs(**kwargs)) + for kw in unrecognized: + if kw not in recognized_private: + print(f"Received unknown keyword argument {kw}, probably ignoring.") + + for category, cat_kwargs in self._init_kwargs.items(): + for k, v in cat_kwargs.items(): + organized[category].setdefault(k, v) + + def maybe_add(cat: str, k: str, default_val: Any) -> None: + if k in self.acceptable_kwargs[cat]: + organized[cat][k] = organized[cat].get(k, default_val) + + maybe_add("model", "learning_rate", self.default_learning_rate) + maybe_add("model", "context_dim", self.context_dim) + maybe_add("model", "x_dim", self.x_dim) + maybe_add("model", "y_dim", self.y_dim) + + if organized["model"].get("num_archetypes", 1) == 0: + organized["model"].pop("num_archetypes", None) + + maybe_add("data", "train_batch_size", self.default_train_batch_size) + maybe_add("data", "val_batch_size", self.default_val_batch_size) + maybe_add("data", "test_batch_size", self.default_test_batch_size) + maybe_add( + "data", + "predict_batch_size", + organized["data"].get("val_batch_size", self.default_val_batch_size), + ) + + maybe_add("trainer", "accelerator", self.accelerator) + organized["trainer"].setdefault("enable_progress_bar", False) + organized["trainer"].setdefault("logger", False) + organized["trainer"].setdefault("num_sanity_val_steps", 0) + + world = _world_size_env() + launched_externally = world > 1 and ( + os.environ.get("LOCAL_RANK") is not None or os.environ.get("RANK") is not None + ) + + if "devices" not in organized["trainer"]: + organized["trainer"]["devices"] = 1 if launched_externally else (world if world > 1 else 1) + + devices_cfg = organized["trainer"].get("devices", 1) + if isinstance(devices_cfg, int): + devices = devices_cfg + elif isinstance(devices_cfg, (list, tuple)): + devices = len(devices_cfg) else: - get_dataloader = lambda i: self.models[i].dataloader( - self._maybe_scale_C(C), - np.zeros((len(C), self.x_dim)) - ) - predictions = [ - self.trainers[i].predict_params(self.models[i], get_dataloader(i), **kwargs) - for i in range(len(self.models)) - ] - if model_includes_mus: - betas = np.array([p[0] for p in predictions]) - mus = np.array([p[1] for p in predictions]) - if individual_preds: - return betas, mus + devices = 1 + + if world > 1 and (not launched_externally) and devices != world: + if _is_main_process(): + print( + f"[WARNING] WORLD_SIZE={world} but devices={devices}; " + f"overriding devices -> {world}." 
+ ) + organized["trainer"]["devices"] = world + devices = world + + + if "strategy" not in organized["trainer"]: + if devices > 1 or world > 1: + organized["trainer"]["strategy"] = DDPStrategy( + find_unused_parameters=False, + broadcast_buffers=False, + process_group_backend="nccl" if torch.cuda.is_available() else "gloo", + ) else: - return np.mean(betas, axis=0), np.mean(mus, axis=0) - betas = np.array(predictions) - if not individual_preds: - return np.mean(betas, axis=0) - return betas + organized["trainer"]["strategy"] = "auto" + + if self.accelerator == "gpu": + organized["trainer"].setdefault("precision", "16-mixed") + else: + organized["trainer"].setdefault("precision", 32) + + maybe_add("data", "num_workers", self._default_num_workers(devices)) + maybe_add("data", "pin_memory", self.accelerator == "gpu") + maybe_add( + "data", + "persistent_workers", + organized["data"].get("num_workers", 0) > 0, + ) + maybe_add("data", "drop_last", (devices > 1 or world > 1)) + maybe_add("data", "shuffle_train", True) + maybe_add("data", "shuffle_eval", False) + maybe_add("data", "dtype", torch.float) + + maybe_add("wrapper", "n_bootstraps", self.default_n_bootstraps) + + val_split = float(organized["data"].get("val_split", self.default_val_split)) + organized["data"]["val_split"] = val_split + + use_val = self._safe_val_split(10, val_split) > 0.0 + es_patience = organized["wrapper"].get("es_patience", self.default_es_patience) + es_monitor = organized["wrapper"].get( + "es_monitor", "val_loss" if use_val else "train_loss" + ) + es_mode = organized["wrapper"].get("es_mode", "min") + es_verbose = organized["wrapper"].get("es_verbose", False) + es_min_delta = organized["wrapper"].get("es_min_delta", 0.0) + + cb_ctors = organized["trainer"].get("callback_constructors", None) + if cb_ctors is None: + cb_ctors = [] + + organized["trainer"].setdefault("enable_checkpointing", True) + + if es_patience is not None and int(es_patience) > 0: + cb_ctors.append( + lambda i: EarlyStopping( + monitor=es_monitor, + mode=es_mode, + patience=int(es_patience), + verbose=bool(es_verbose), + min_delta=float(es_min_delta), + ) + ) + + if bool(organized["trainer"].get("enable_checkpointing", True)): + cb_ctors.append( + lambda i: ModelCheckpoint( + monitor=es_monitor, + dirpath=f"{kwargs.get('checkpoint_path', './lightning_logs')}/boot_{i}_checkpoints", + filename="{epoch}-{val_loss:.4f}", + ) + ) + + organized["trainer"]["callback_constructors"] = cb_ctors + return organized def fit(self, *args, **kwargs) -> None: """ @@ -518,7 +643,7 @@ def fit(self, *args, **kwargs) -> None: Args: C (np.ndarray): Context array of shape (n_samples, n_context_features) X (np.ndarray): Predictor array of shape (N, n_features) - Y (np.ndarray, optional): Target array of shape (N, n_targets). Defaults to None, where X will be used as targets such as in Contextualized Networks. + Y (np.ndarray, optional): Target array of shape (N, n_targets). Defaults to None. max_epochs (int, optional): Maximum number of epochs to train for. Defaults to 1. learning_rate (float, optional): Learning rate for optimizer. Defaults to 1e-3. val_split (float, optional): Proportion of data to use for validation and early stopping. Defaults to 0.2. @@ -531,72 +656,442 @@ def fit(self, *args, **kwargs) -> None: es_mode (str, optional): Mode for early stopping. Defaults to "min". es_verbose (bool, optional): Whether to print early stopping updates. Defaults to False. 
""" - self.models = [] - self.trainers = [] + self.models, self.trainers = [], [] self.dataloaders = {"train": [], "val": [], "test": []} - C, X = args[0], args[1] + + if len(args) < 2: + raise ValueError("fit expects at least (C, X) as positional args.") + + C = kwargs.pop("C", None) + X = kwargs.pop("X", None) + Y = kwargs.pop("Y", None) + + if C is None or X is None: + C = args[0] + X = args[1] + if len(args) >= 3: + Y = args[2] + if C is None or X is None: + raise ValueError("fit requires C and X.") + + C = np.asarray(C) + X = np.asarray(X) + if Y is not None: + Y = np.asarray(Y) + if self.normalize: if self.scalers["C"] is None: self.scalers["C"] = StandardScaler().fit(C) C = self.scalers["C"].transform(C) + if self.scalers["X"] is None: self.scalers["X"] = StandardScaler().fit(X) X = self.scalers["X"].transform(X) - self.context_dim = C.shape[-1] - self.x_dim = X.shape[-1] - if len(args) == 3: - Y = args[2] - if kwargs.get("Y", None) is not None: - Y = kwargs.get("Y") - if len(Y.shape) == 1: # add feature dimension to Y if not given. - Y = np.expand_dims(Y, 1) - if self.normalize and not np.array_equal(np.unique(Y), np.array([0, 1])): - if self.scalers["Y"] is None: - self.scalers["Y"] = StandardScaler().fit(Y) - Y = self.scalers["Y"].transform(Y) - self.y_dim = Y.shape[-1] - args = (C, X, Y) + + self.context_dim = int(C.shape[-1]) + self.x_dim = int(X.shape[-1]) + + if Y is None: + Y = X else: - self.y_dim = self.x_dim - args = (C, X) - organized_kwargs = self._organize_and_expand_fit_kwargs(**kwargs) - self.n_bootstraps = organized_kwargs["wrapper"].get( - "n_bootstraps", self.n_bootstraps + if Y.ndim == 1: + Y = np.expand_dims(Y, 1) + + if self.normalize and self.scalers["Y"] is not None: + pass + + if self.normalize and not np.array_equal(np.unique(Y), np.array([0, 1])): + if self.scalers["Y"] is None: + self.scalers["Y"] = StandardScaler().fit(Y) + Y = self.scalers["Y"].transform(Y) + + self.y_dim = int(Y.shape[-1]) + + organized = self._organize_and_expand_fit_kwargs(**kwargs) + self.n_bootstraps = int( + organized["wrapper"].get("n_bootstraps", self.n_bootstraps) ) - for bootstrap in range(self.n_bootstraps): - model = self.base_constructor(**organized_kwargs["model"]) - train_data, val_data = self._split_train_data( - *args, **organized_kwargs["data"] - ) - train_dataloader, val_dataloader = self._build_dataloaders( - model, - train_data, - val_data, - **organized_kwargs["data"], - ) - # Makes a new trainer for each bootstrap fit - bad practice, but necessary here. - my_trainer_kwargs = copy.deepcopy(organized_kwargs["trainer"]) - # Must reconstruct the callbacks because they save state from fitting trajectories. 
- my_trainer_kwargs["callbacks"] = [ - f(bootstrap) - for f in organized_kwargs["trainer"]["callback_constructors"] - ] - del my_trainer_kwargs["callback_constructors"] - trainer = self.trainer_constructor( - **my_trainer_kwargs, enable_progress_bar=False - ) - checkpoint_callback = my_trainer_kwargs["callbacks"][1] - os.makedirs(checkpoint_callback.dirpath, exist_ok=True) - try: - trainer.fit( - model, train_dataloader, val_dataloader, **organized_kwargs["fit"] + + val_split = float(organized["data"].get("val_split", self.default_val_split)) + val_split = self._safe_val_split(int(C.shape[0]), val_split) + organized["data"]["val_split"] = val_split + use_val = val_split > 0.0 + + if not use_val: + new_ctors = [] + for ctor in organized["trainer"].get("callback_constructors", []): + + def _wrap_ctor(_ctor): + def _inner(i): + cb = _ctor(i) + if ( + isinstance(cb, EarlyStopping) + and isinstance(getattr(cb, "monitor", ""), str) + and cb.monitor.startswith("val_") + ): + return EarlyStopping( + monitor="train_loss", + mode=getattr(cb, "mode", "min"), + patience=getattr(cb, "patience", self.default_es_patience), + verbose=getattr(cb, "verbose", False), + min_delta=getattr(cb, "min_delta", 0.0), + ) + if ( + isinstance(cb, ModelCheckpoint) + and isinstance(getattr(cb, "monitor", ""), str) + and cb.monitor.startswith("val_") + ): + cb.monitor = None + return cb + + return _inner + + new_ctors.append(_wrap_ctor(ctor)) + organized["trainer"]["callback_constructors"] = new_ctors + organized["trainer"].setdefault("limit_val_batches", 0) + + C_val = organized["data"].get("C_val", None) + X_val = organized["data"].get("X_val", None) + Y_val = organized["data"].get("Y_val", None) + + univariate_flag = bool(organized["model"].get("univariate", False)) + task_type = "singletask_univariate" if univariate_flag else "singletask_multivariate" + + C_all, X_all, Y_all, train_idx, val_idx = self._resolve_train_val_arrays( + C, + X, + Y, + C_val=C_val, + X_val=X_val, + Y_val=Y_val, + Y_required=True, + val_split=val_split, + random_state=organized["data"].get("random_state", None), + ) + + + for b in range(self.n_bootstraps): + model_kwargs = dict(organized["model"]) + model_kwargs.pop("univariate", None) + + model = self.base_constructor(**model_kwargs) + + use_dm = self._use_datamodule_for_model(model) + + trainer_kwargs = copy.deepcopy(organized["trainer"]) + cb_ctors = trainer_kwargs.pop("callback_constructors", []) + callbacks = list(trainer_kwargs.get("callbacks", [])) + callbacks.extend([ctor(b) for ctor in cb_ctors]) + trainer_kwargs["callbacks"] = callbacks + + for cb in callbacks: + if isinstance(cb, ModelCheckpoint): + try: + os.makedirs(cb.dirpath, exist_ok=True) + except Exception: + pass + + from contextualized.regression.trainers import make_trainer_with_env + + trainer = make_trainer_with_env(self.trainer_constructor, **trainer_kwargs) + + if use_dm: + dm = self._build_datamodule( + C=C_all, + X=X_all, + Y=Y_all, + train_idx=train_idx, + val_idx=val_idx if use_val else None, + test_idx=None, + predict_idx=None, + data_kwargs=organized["data"], + task_type=task_type, ) - except: - trainer.fit(model, train_dataloader, **organized_kwargs["fit"]) - if kwargs.get("max_epochs", 1) > 0: - best_checkpoint = torch.load(checkpoint_callback.best_model_path) - model.load_state_dict(best_checkpoint["state_dict"]) - self.dataloaders["train"].append(train_dataloader) - self.dataloaders["val"].append(val_dataloader) + + if _is_main_process(): + print( + f"[RANK {_rank()}] train_idx[:5]={train_idx[:5]}, " + 
f"val_idx[:5]={val_idx[:5] if val_idx is not None else None}" + ) + + trainer.fit(model, datamodule=dm, **organized["fit"]) + + try: + dm.setup("fit") + self.dataloaders["train"].append(dm.train_dataloader()) + self.dataloaders["val"].append(dm.val_dataloader() if use_val else None) + self.dataloaders["test"].append(None) + except Exception: + self.dataloaders["train"].append(None) + self.dataloaders["val"].append(None) + self.dataloaders["test"].append(None) + + else: + train_data = ( + [C_all[train_idx], X_all[train_idx], Y_all[train_idx]] + if Y_all is not None + else [C_all[train_idx], X_all[train_idx]] + ) + + val_data = None + if use_val and val_idx is not None: + val_data = ( + [C_all[val_idx], X_all[val_idx], Y_all[val_idx]] + if Y_all is not None + else [C_all[val_idx], X_all[val_idx]] + ) + + train_dl = model.dataloader( + *train_data, + batch_size=organized["data"].get( + "train_batch_size", self.default_train_batch_size + ), + ) + + val_dl = None + if val_data is not None: + val_dl = model.dataloader( + *val_data, + batch_size=organized["data"].get( + "val_batch_size", self.default_val_batch_size + ), + ) + + try: + trainer.fit(model, train_dl, val_dl, **organized["fit"]) + except Exception: + trainer.fit(model, train_dl, **organized["fit"]) + + self.dataloaders["train"].append(train_dl) + self.dataloaders["val"].append(val_dl) + self.dataloaders["test"].append(None) + + ckpt_cb = next( + (cb for cb in trainer.callbacks if isinstance(cb, ModelCheckpoint)), + None, + ) + if ckpt_cb is not None and getattr(ckpt_cb, "best_model_path", None): + best_path = ckpt_cb.best_model_path + if isinstance(best_path, str) and best_path and os.path.exists(best_path): + try: + best = torch.load(best_path, map_location="cpu") + if isinstance(best, dict) and "state_dict" in best: + model.load_state_dict(best["state_dict"]) + except Exception: + pass + self.models.append(model) self.trainers.append(trainer) + + return None + + def predict( + self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs + ) -> Union[np.ndarray, List[np.ndarray], None]: + """Predict outcomes from context C and predictors X. + + Args: + C (np.ndarray): Context array of shape (n_samples, n_context_features) + X (np.ndarray): Predictor array of shape (N, n_features) + individual_preds (bool, optional): Whether to return individual predictions for each model. Defaults to False. + + Returns: + Union[np.ndarray, List[np.ndarray], None]: Predicted outcomes. If individual_preds is True, returns + predictions for each bootstrap. Returns None if any trainer returns None. + """ + if self.models is None or self.trainers is None: + raise ValueError("Trying to predict with a model that hasn't been trained yet.") + + C = np.asarray(C) + X = np.asarray(X) + Cq = self._maybe_scale_C(C) + Xq = self._maybe_scale_X(X) + + preds_all: List[np.ndarray] = [] + saw_none = False + + for model, trainer in zip(self.models, self.trainers): + if not hasattr(trainer, "predict_y"): + raise RuntimeError( + "Trainer does not implement predict_y(). " + "Use contextualized.regression.trainers.RegressionTrainer (or a subclass)." 
+ ) + + use_dm = self._use_datamodule_for_model(model) + + if use_dm: + Yq = np.zeros((len(Cq), int(self.y_dim or 1)), dtype=np.float32) + + univariate_flag = bool(self._init_kwargs.get("model", {}).get("univariate", False)) + task_type = ( + "singletask_univariate" + if univariate_flag + else "singletask_multivariate" + ) + + dm = self._build_datamodule( + C=Cq, + X=Xq, + Y=Yq, + train_idx=None, + val_idx=None, + test_idx=None, + predict_idx=np.arange(len(Cq)), + data_kwargs={**self._init_kwargs.get("data", {}), **kwargs}, + task_type=task_type, + ) + dm.setup("predict") + dl = dm.predict_dataloader() + else: + dl = model.dataloader( + Cq, + Xq, + np.zeros((len(Cq), int(self.y_dim or 1))), + batch_size=kwargs.get( + "predict_batch_size", self.default_val_batch_size + ), + ) + + yhat = trainer.predict_y(model, dl, **kwargs) + if yhat is None: + saw_none = True + continue + + preds_all.append(np.asarray(yhat, dtype=float)) + + if saw_none: + return None + + predictions = np.array(preds_all, dtype=float) + + if individual_preds: + out = predictions + else: + bad = ~np.isfinite(predictions) + if bad.any(): + num_bad_boots = np.unique(np.where(bad)[0]).size + print( + f"Warning: {num_bad_boots}/{len(preds_all)} bootstraps produced " + f"non-finite predictions; excluding them from the ensemble." + ) + out = self._nanrobust_mean(predictions, axis=0) + + if self.normalize and self.scalers["Y"] is not None: + if individual_preds: + out = np.array([self.scalers["Y"].inverse_transform(p) for p in out]) + else: + out = self.scalers["Y"].inverse_transform(out) + + return out + + def predict_params( + self, + C: np.ndarray, + individual_preds: bool = False, + model_includes_mus: bool = True, + **kwargs, + ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray], Tuple[None, None], None]: + """ + Predict context-specific model parameters from context C. + + Args: + C (np.ndarray): Context array of shape (n_samples, n_context_features) + individual_preds (bool, optional): Whether to return individual model predictions for each bootstrap. + Defaults to False, averaging across bootstraps. + model_includes_mus (bool, optional): Whether the model includes context-specific offsets (mu). + Defaults to True. + + Returns: + Union[np.ndarray, Tuple[np.ndarray, np.ndarray], Tuple[None, None], None]: + If model_includes_mus is True, returns (betas, mus); otherwise returns betas. + If individual_preds is True, returns arrays stacked over bootstraps. + Returns (None, None) or None if any trainer returns None. + """ + if self.models is None or self.trainers is None: + raise ValueError("Trying to predict with a model that hasn't been trained yet.") + + C = np.asarray(C) + Cq = self._maybe_scale_C(C) + + uses_y = bool(kwargs.pop("uses_y", True)) + + betas_list: List[np.ndarray] = [] + mus_list: List[np.ndarray] = [] + saw_none = False + + for model, trainer in zip(self.models, self.trainers): + if not hasattr(trainer, "predict_params"): + raise RuntimeError( + "Trainer does not implement predict_params(). " + "Use contextualized.regression.trainers.RegressionTrainer (or a subclass)." 
+ ) + + use_dm = self._use_datamodule_for_model(model) + + if use_dm: + X_zero = np.zeros((len(Cq), int(self.x_dim or 1)), dtype=np.float32) + Y_zero = ( + np.zeros((len(Cq), int(self.y_dim or 1)), dtype=np.float32) + if uses_y + else None + ) + + univariate_flag = bool(self._init_kwargs.get("model", {}).get("univariate", False)) + task_type = ( + "singletask_univariate" + if univariate_flag + else "singletask_multivariate" + ) + + dm = self._build_datamodule( + C=Cq, + X=X_zero, + Y=Y_zero, + train_idx=None, + val_idx=None, + test_idx=None, + predict_idx=np.arange(len(Cq)), + data_kwargs={**self._init_kwargs.get("data", {}), **kwargs}, + task_type=task_type, + ) + dm.setup("predict") + dl = dm.predict_dataloader() + else: + if uses_y: + dl = model.dataloader( + Cq, + np.zeros((len(Cq), int(self.x_dim or 1))), + np.zeros((len(Cq), int(self.y_dim or 1))), + ) + else: + dl = model.dataloader( + Cq, + np.zeros((len(Cq), int(self.x_dim or 1))), + ) + + out = trainer.predict_params(model, dl, **kwargs) + if out is None or (isinstance(out, tuple) and out[0] is None): + saw_none = True + continue + + if model_includes_mus: + b, m = out + betas_list.append(np.asarray(b)) + mus_list.append(np.asarray(m)) + else: + betas_list.append(np.asarray(out)) + + if saw_none: + return (None, None) if model_includes_mus else None + + betas = np.array(betas_list) + + if model_includes_mus: + mus = np.array(mus_list) + if individual_preds: + return betas, mus + return np.mean(betas, axis=0), np.mean(mus, axis=0) + + if individual_preds: + return betas + return np.mean(betas, axis=0) diff --git a/contextualized/modules.py b/contextualized/modules.py index 96d48b07..65dd45c0 100644 --- a/contextualized/modules.py +++ b/contextualized/modules.py @@ -8,6 +8,25 @@ from contextualized.functions import LINK_FUNCTIONS +def _resolve_link_fn(maybe_link): + """ + Accepts either: + - a string key (looked up in LINK_FUNCTIONS), or + - a callable (returned as-is, including functools.partial) + """ + if isinstance(maybe_link, str): + try: + return LINK_FUNCTIONS[maybe_link] + except KeyError as e: + raise KeyError( + f"Unknown link_fn '{maybe_link}'. " + f"Valid options: {list(LINK_FUNCTIONS.keys())}" + ) from e + if callable(maybe_link): + return maybe_link + raise TypeError(f"link_fn must be str or callable, got {type(maybe_link).__name__}") + + class SoftSelect(nn.Module): """ Parameter sharing for multiple context encoders: @@ -91,7 +110,7 @@ def __init__( else: # Linear encoder mlp_layers = [nn.Linear(input_dim, output_dim)] self.mlp = nn.Sequential(*mlp_layers) - self.link_fn = LINK_FUNCTIONS[link_fn] + self.link_fn = _resolve_link_fn(link_fn) def forward(self, X): """Torch Forward pass.""" @@ -114,8 +133,12 @@ def __init__( link_fn="identity", ): super().__init__() - self.intput_dim = input_dim + self.input_dim = input_dim self.output_dim = output_dim + + # Each feature-wise network uses an identity link; the global link is applied once. 
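For context on the comment above: a NAM prediction here is the sum of one small per-feature network applied to each input column, with the output link applied once to that sum. A toy sketch of the composition, independent of this patch and using plain `torch.nn` pieces rather than the package's `MLP` class:

```python
import torch
from torch import nn

x_dim, width = 3, 8
# One tiny network per feature; each uses an identity output (no per-feature link).
feature_nets = nn.ModuleList(
    [nn.Sequential(nn.Linear(1, width), nn.ReLU(), nn.Linear(width, 1)) for _ in range(x_dim)]
)
link_fn = torch.sigmoid  # the single, global link

X = torch.randn(5, x_dim)
out = sum(net(X[:, j].unsqueeze(-1)) for j, net in enumerate(feature_nets))
y = link_fn(out)  # link applied exactly once, to the summed contributions
```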
+ per_feat_link = "identity" + self.nams = nn.ModuleList( [ MLP( @@ -124,17 +147,17 @@ def __init__( width, layers, activation=activation, - link_fn=identity_link, + link_fn=per_feat_link, ) for _ in range(input_dim) ] ) - self.link_fn = LINK_FUNCTIONS[link_fn] + self.link_fn = _resolve_link_fn(link_fn) def forward(self, X): """Torch Forward pass.""" ret = self.nams[0](X[:, 0].unsqueeze(-1)) - for i, nam in enumerate(self.nams[1:]): + for i, nam in enumerate(self.nams[1:], start=1): ret += nam(X[:, i].unsqueeze(-1)) return self.link_fn(ret) diff --git a/contextualized/regression/__init__.py b/contextualized/regression/__init__.py index e4498af7..9e8fd308 100644 --- a/contextualized/regression/__init__.py +++ b/contextualized/regression/__init__.py @@ -19,6 +19,10 @@ TasksplitContextualizedUnivariateRegression, ) from contextualized.regression.trainers import RegressionTrainer +from contextualized.regression.datamodules import ( + ContextualizedRegressionDataModule, + TASK_TO_DATASET, +) DATASETS = { "multivariate": MultivariateDataset, @@ -26,7 +30,46 @@ "multitask_multivariate": MultitaskMultivariateDataset, "multitask_univariate": MultitaskUnivariateDataset, } + LOSSES = {"mse": MSE, "bceloss": BCELoss} + MODELS = ["multivariate", "univariate"] + METAMODELS = ["simple", "subtype", "multitask", "tasksplit"] + TRAINERS = {"regression_trainer": RegressionTrainer} + +# New exports for distributed-ready data handling +DATAMODULES = { + "regression": ContextualizedRegressionDataModule, +} + +__all__ = [ + # datasets + "MultivariateDataset", + "UnivariateDataset", + "MultitaskMultivariateDataset", + "MultitaskUnivariateDataset", + "DATASETS", + # datamodules + "ContextualizedRegressionDataModule", + "TASK_TO_DATASET", + "DATAMODULES", + # losses/regularizers + "MSE", + "BCELoss", + "REGULARIZERS", + "LOSSES", + # models + "NaiveContextualizedRegression", + "ContextualizedRegression", + "MultitaskContextualizedRegression", + "TasksplitContextualizedRegression", + "ContextualizedUnivariateRegression", + "TasksplitContextualizedUnivariateRegression", + "MODELS", + "METAMODELS", + # trainers + "RegressionTrainer", + "TRAINERS", +] \ No newline at end of file diff --git a/contextualized/regression/datamodules.py b/contextualized/regression/datamodules.py new file mode 100644 index 00000000..1e53abf0 --- /dev/null +++ b/contextualized/regression/datamodules.py @@ -0,0 +1,225 @@ +# contextualized/regression/datamodules.py +from __future__ import annotations + +from typing import Callable, Optional, Sequence, Tuple, Union, Dict +import numpy as np +import pandas as pd +import torch +from torch.utils.data import DataLoader +import lightning.pytorch as pl + +from .datasets import ( + MultivariateDataset, + UnivariateDataset, + MultitaskMultivariateDataset, + MultitaskUnivariateDataset, +) + +TensorLike = Union[np.ndarray, pd.DataFrame, pd.Series, torch.Tensor] +IndexLike = Optional[Union[Sequence[int], np.ndarray, torch.Tensor]] + +TASK_TO_DATASET = { + "singletask_multivariate": MultivariateDataset, + "singletask_univariate": UnivariateDataset, + "multitask_multivariate": MultitaskMultivariateDataset, + "multitask_univariate": MultitaskUnivariateDataset, +} + + +def _to_tensor(x: TensorLike, dtype: torch.dtype) -> torch.Tensor: + if isinstance(x, torch.Tensor): + return x.to(dtype=dtype, copy=False) + if isinstance(x, (pd.DataFrame, pd.Series)): + x = x.to_numpy(copy=False) + return torch.as_tensor(x, dtype=dtype) + + +def _maybe_index(x: torch.Tensor, idx: IndexLike) -> torch.Tensor: + if idx is None: 
+        return x
+    if isinstance(idx, torch.Tensor):
+        return x[idx]
+    if isinstance(idx, np.ndarray):
+        idx = torch.as_tensor(idx, dtype=torch.long)
+        return x[idx]
+    return x[torch.as_tensor(idx, dtype=torch.long)]
+
+
+def _to_index_tensor(idx: IndexLike) -> Optional[torch.Tensor]:
+    """Normalize an index-like into a 1D CPU LongTensor."""
+    if idx is None:
+        return None
+    if isinstance(idx, torch.Tensor):
+        out = idx.to(dtype=torch.long, device="cpu")
+    elif isinstance(idx, np.ndarray):
+        out = torch.as_tensor(idx, dtype=torch.long, device="cpu")
+    else:
+        out = torch.as_tensor(idx, dtype=torch.long, device="cpu")
+    return out.view(-1)
+
+
+class ContextualizedRegressionDataModule(pl.LightningDataModule):
+    """
+    DataModule that returns map-style datasets for contextualized regression,
+    allowing Lightning's Trainer (DDP) to auto-attach DistributedSampler and shard data.
+
+    task_type ∈ {
+        "singletask_multivariate",
+        "singletask_univariate",
+        "multitask_multivariate",
+        "multitask_univariate",
+    }
+    """
+
+    def __init__(
+        self,
+        C: TensorLike,
+        X: TensorLike,
+        Y: Optional[TensorLike],
+        *,
+        task_type: str,
+        train_idx: IndexLike = None,
+        val_idx: IndexLike = None,
+        test_idx: IndexLike = None,
+        predict_idx: IndexLike = None,
+        splitter: Optional[
+            Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]],
+                     Tuple[IndexLike, IndexLike, IndexLike]]
+        ] = None,
+        train_batch_size: int = 32,
+        val_batch_size: int = 32,
+        test_batch_size: int = 32,
+        predict_batch_size: int = 32,
+        num_workers: int = 0,
+        pin_memory: bool = True,
+        persistent_workers: bool = False,
+        drop_last: bool = False,
+        shuffle_train: bool = True,
+        shuffle_eval: bool = False,
+        dtype: torch.dtype = torch.float,
+    ):
+        super().__init__()
+        if task_type not in TASK_TO_DATASET:
+            raise ValueError(
+                f"Unknown task_type={task_type!r}. "
+                f"Expected one of {list(TASK_TO_DATASET)}."
+ ) + self.task_type = task_type + + self._C_raw = C + self._X_raw = X + self._Y_raw = Y + + self.train_idx = train_idx + self.val_idx = val_idx + self.test_idx = test_idx + self.predict_idx = predict_idx + self.splitter = splitter + + self.train_batch_size = train_batch_size + self.val_batch_size = val_batch_size + self.test_batch_size = test_batch_size + self.predict_batch_size = predict_batch_size + self.num_workers = num_workers + self.pin_memory = pin_memory + self.persistent_workers = bool(persistent_workers and num_workers > 0) + self.drop_last = drop_last + self.shuffle_train = shuffle_train + self.shuffle_eval = shuffle_eval + self.dtype = dtype + + self.C: Optional[torch.Tensor] = None + self.X: Optional[torch.Tensor] = None + self.Y: Optional[torch.Tensor] = None + + self.ds_train = None + self.ds_val = None + self.ds_test = None + self.ds_predict = None + + def prepare_data(self) -> None: + pass + + def setup(self, stage: Optional[str] = None) -> None: + C = _to_tensor(self._C_raw, self.dtype) + X = _to_tensor(self._X_raw, self.dtype) + Y = None if self._Y_raw is None else _to_tensor(self._Y_raw, self.dtype) + + if self.train_idx is None and self.val_idx is None and self.test_idx is None: + if self.splitter is not None: + tr, va, te = self.splitter(C, X, Y) + self.train_idx, self.val_idx, self.test_idx = tr, va, te + + if self.predict_idx is None: + if self.test_idx is not None: + self.predict_idx = self.test_idx + else: + self.predict_idx = torch.arange(C.shape[0], dtype=torch.long) + + def _mk_dataset(idx: IndexLike): + if idx is None: + return None + + idx_t = _to_index_tensor(idx) + + C_s = _maybe_index(C, idx_t) + X_s = _maybe_index(X, idx_t) + Y_s = None if (Y is None) else _maybe_index(Y, idx_t) + ds_cls = TASK_TO_DATASET[self.task_type] + + if Y_s is None: + Y_s = X_s + + return ds_cls(C_s, X_s, Y_s, orig_idx=idx_t, dtype=self.dtype) + + self.ds_train = _mk_dataset(self.train_idx) + self.ds_val = _mk_dataset(self.val_idx) + self.ds_test = _mk_dataset(self.test_idx) + self.ds_predict = _mk_dataset(self.predict_idx) + + self.C, self.X, self.Y = C, X, Y + + def _common_dl_kwargs(self, batch_size: int, *, drop_last: Optional[bool] = None) -> Dict: + return { + "batch_size": batch_size, + "num_workers": self.num_workers, + "pin_memory": self.pin_memory, + "persistent_workers": bool(self.num_workers > 0 and self.persistent_workers), + "drop_last": self.drop_last if drop_last is None else bool(drop_last), + } + + def train_dataloader(self) -> DataLoader: + if self.ds_train is None: + raise RuntimeError("train dataset is not set; provide train_idx or splitter.") + return DataLoader( + dataset=self.ds_train, + shuffle=self.shuffle_train, + **self._common_dl_kwargs(self.train_batch_size, drop_last=self.drop_last), + ) + + def val_dataloader(self): + if self.ds_val is None: + return None + return DataLoader( + dataset=self.ds_val, + shuffle=self.shuffle_eval, + **self._common_dl_kwargs(self.val_batch_size, drop_last=False), + ) + + def test_dataloader(self) -> DataLoader: + if self.ds_test is None: + raise RuntimeError("test dataset is not set; provide test_idx or splitter.") + return DataLoader( + dataset=self.ds_test, + shuffle=self.shuffle_eval, + **self._common_dl_kwargs(self.test_batch_size, drop_last=False), + ) + + def predict_dataloader(self) -> DataLoader: + if self.ds_predict is None: + raise RuntimeError("predict dataset is not set; provide predict_idx/test_idx.") + return DataLoader( + dataset=self.ds_predict, + shuffle=False, + 
**self._common_dl_kwargs(self.predict_batch_size, drop_last=False), + ) diff --git a/contextualized/regression/datasets.py b/contextualized/regression/datasets.py index 911997cc..30b36d05 100644 --- a/contextualized/regression/datasets.py +++ b/contextualized/regression/datasets.py @@ -11,21 +11,28 @@ class MultivariateDataset(Dataset): """ Simple multivariate dataset with context, predictors, and outcomes. """ - def __init__(self, C, X, Y, dtype=torch.float): - self.C = torch.tensor(C, dtype=dtype) - self.X = torch.tensor(X, dtype=dtype) - self.Y = torch.tensor(Y, dtype=dtype) - self.c_dim = C.shape[-1] - self.x_dim = X.shape[-1] - self.y_dim = Y.shape[-1] + def __init__(self, C, X, Y, orig_idx=None, dtype=torch.float): + self.C = torch.as_tensor(C, dtype=dtype) + self.X = torch.as_tensor(X, dtype=dtype) + self.Y = torch.as_tensor(Y, dtype=dtype) + + if orig_idx is None: + self.orig_idx = torch.arange(len(self.C), dtype=torch.long) + else: + self.orig_idx = torch.as_tensor(orig_idx, dtype=torch.long).view(-1) + + self.c_dim = self.C.shape[-1] + self.x_dim = self.X.shape[-1] + self.y_dim = self.Y.shape[-1] self.dtype = dtype - + def __len__(self): return len(self.C) - + def __getitem__(self, idx): return { "idx": idx, + "orig_idx": self.orig_idx[idx], "contexts": self.C[idx], "predictors": self.X[idx].expand(self.y_dim, -1), "outcomes": self.Y[idx].unsqueeze(-1), @@ -36,21 +43,28 @@ class UnivariateDataset(Dataset): """ Simple univariate dataset with context, predictors, and one outcome. """ - def __init__(self, C, X, Y, dtype=torch.float): - self.C = torch.tensor(C, dtype=dtype) - self.X = torch.tensor(X, dtype=dtype) - self.Y = torch.tensor(Y, dtype=dtype) - self.c_dim = C.shape[-1] - self.x_dim = X.shape[-1] - self.y_dim = Y.shape[-1] + def __init__(self, C, X, Y, orig_idx=None, dtype=torch.float): + self.C = torch.as_tensor(C, dtype=dtype) + self.X = torch.as_tensor(X, dtype=dtype) + self.Y = torch.as_tensor(Y, dtype=dtype) + + if orig_idx is None: + self.orig_idx = torch.arange(len(self.C), dtype=torch.long) + else: + self.orig_idx = torch.as_tensor(orig_idx, dtype=torch.long).view(-1) + + self.c_dim = self.C.shape[-1] + self.x_dim = self.X.shape[-1] + self.y_dim = self.Y.shape[-1] self.dtype = dtype - + def __len__(self): return len(self.C) - + def __getitem__(self, idx): return { "idx": idx, + "orig_idx": self.orig_idx[idx], "contexts": self.C[idx], "predictors": self.X[idx].expand(self.y_dim, -1).unsqueeze(-1), "outcomes": self.Y[idx].expand(self.x_dim, -1).T.unsqueeze(-1), @@ -61,27 +75,36 @@ class MultitaskMultivariateDataset(Dataset): """ Multi-task Multivariate Dataset. 
""" - def __init__(self, C, X, Y, dtype=torch.float): - self.C = torch.tensor(C, dtype=dtype) - self.X = torch.tensor(X, dtype=dtype) - self.Y = torch.tensor(Y, dtype=dtype) - self.c_dim = C.shape[-1] - self.x_dim = X.shape[-1] - self.y_dim = Y.shape[-1] + def __init__(self, C, X, Y, orig_idx=None, dtype=torch.float): + self.C = C.to(dtype) if isinstance(C, torch.Tensor) else torch.as_tensor(C, dtype=dtype) + self.X = X.to(dtype) if isinstance(X, torch.Tensor) else torch.as_tensor(X, dtype=dtype) + self.Y = Y.to(dtype) if isinstance(Y, torch.Tensor) else torch.as_tensor(Y, dtype=dtype) + + if orig_idx is None: + self.orig_idx = torch.arange(len(self.C), dtype=torch.long) + else: + self.orig_idx = torch.as_tensor(orig_idx, dtype=torch.long).view(-1) + + self.c_dim = self.C.shape[-1] + self.x_dim = self.X.shape[-1] + self.y_dim = self.Y.shape[-1] self.dtype = dtype - + def __len__(self): return len(self.C) * self.y_dim - + def __getitem__(self, idx): # Get task-split sample indices n_i = idx // self.y_dim y_i = idx % self.y_dim + # Create a one-hot encoding for the task - t = torch.zeros(self.y_dim) + t = torch.zeros(self.y_dim, dtype=self.dtype) t[y_i] = 1 + return { "idx": idx, + "orig_idx": self.orig_idx[n_i], "contexts": self.C[n_i], "task": t, "predictors": self.X[n_i], @@ -90,55 +113,44 @@ def __getitem__(self, idx): "outcome_idx": y_i, } - # def __next__(self): - # if self.y_i >= self.y_dim: - # self.n_i += 1 - # self.y_i = 0 - # if self.n_i >= self.n: - # self.n_i = 0 - # raise StopIteration - # t = torch.zeros(self.y_dim) - # t[self.y_i] = 1 - # ret = ( - # self.C[self.n_i], - # t, - # self.X[self.n_i], - # self.Y[self.n_i, self.y_i].unsqueeze(0), - # self.n_i, - # self.y_i, - # ) - # self.y_i += 1 - # return ret - class MultitaskUnivariateDataset(Dataset): """ Multitask Univariate Dataset. Splits each sample into univariate X and Y feature pairs for univariate regression tasks. 
- """ - def __init__(self, C, X, Y, dtype=torch.float): - self.C = torch.tensor(C, dtype=dtype) - self.X = torch.tensor(X, dtype=dtype) - self.Y = torch.tensor(Y, dtype=dtype) - self.c_dim = C.shape[-1] - self.x_dim = X.shape[-1] - self.y_dim = Y.shape[-1] + """ + def __init__(self, C, X, Y, orig_idx=None, dtype=torch.float): + self.C = torch.as_tensor(C, dtype=dtype) + self.X = torch.as_tensor(X, dtype=dtype) + self.Y = torch.as_tensor(Y, dtype=dtype) + + if orig_idx is None: + self.orig_idx = torch.arange(len(self.C), dtype=torch.long) + else: + self.orig_idx = torch.as_tensor(orig_idx, dtype=torch.long).view(-1) + + self.c_dim = self.C.shape[-1] + self.x_dim = self.X.shape[-1] + self.y_dim = self.Y.shape[-1] self.dtype = dtype - + def __len__(self): return len(self.C) * self.x_dim * self.y_dim - + def __getitem__(self, idx): # Get task-split sample indices n_i = idx // (self.x_dim * self.y_dim) x_i = (idx // self.y_dim) % self.x_dim y_i = idx % self.y_dim + # Create a one-hot encoding for the task - t = torch.zeros(self.x_dim + self.y_dim) + t = torch.zeros(self.x_dim + self.y_dim, dtype=self.dtype) t[x_i] = 1 t[self.x_dim + y_i] = 1 + return { "idx": idx, + "orig_idx": self.orig_idx[n_i], "contexts": self.C[n_i], "task": t, "predictors": self.X[n_i, x_i].unsqueeze(0), @@ -146,4 +158,4 @@ def __getitem__(self, idx): "sample_idx": n_i, "predictor_idx": x_i, "outcome_idx": y_i, - } \ No newline at end of file + } diff --git a/contextualized/regression/lightning_modules.py b/contextualized/regression/lightning_modules.py index c31398c9..4a696a07 100644 --- a/contextualized/regression/lightning_modules.py +++ b/contextualized/regression/lightning_modules.py @@ -11,11 +11,13 @@ Implemented with PyTorch Lightning """ +from .datamodules import ContextualizedRegressionDataModule + from abc import abstractmethod import numpy as np import torch from torch.utils.data import DataLoader -import lightning as pl +import lightning.pytorch as pl from contextualized.regression.regularizers import REGULARIZERS from contextualized.regression.losses import MSE @@ -31,6 +33,70 @@ ) +def _resolve_registry_or_callable(maybe_obj, registry, name: str): + """ + + :param maybe_obj: + :param registry: + :param name: + + """ + if isinstance(maybe_obj, str): + try: + return registry[maybe_obj] + except KeyError as e: + raise KeyError( + f"Unknown {name} '{maybe_obj}'. Valid keys: {list(registry.keys())}" + ) from e + if callable(maybe_obj): + return maybe_obj + raise TypeError( + f"{name} must be a string key or a callable, got {type(maybe_obj).__name__}" + ) + + +def _resolve_loss(maybe_loss): + """ + + :param maybe_loss: + + """ + if isinstance(maybe_loss, str): + if maybe_loss.lower() == "mse": + return MSE + raise KeyError( + f"Unknown loss_fn '{maybe_loss}'. " + "Pass a callable loss or the string 'mse'." + ) + if callable(maybe_loss): + return maybe_loss + raise TypeError( + f"loss_fn must be a string key or a callable, got {type(maybe_loss).__name__}" + ) + + +def _resolve_regularizer(maybe_reg): + """ + + :param maybe_reg: + + """ + if isinstance(maybe_reg, str): + try: + return REGULARIZERS[maybe_reg] + except KeyError as e: + raise KeyError( + f"Unknown model_regularizer '{maybe_reg}'. 
" + f"Valid keys: {list(REGULARIZERS.keys())}" + ) from e + if callable(maybe_reg): + return maybe_reg + raise TypeError( + "model_regularizer must be a string key or a callable, got " + f"{type(maybe_reg).__name__}" + ) + + class ContextualizedRegressionBase(pl.LightningModule): """ Abstract class for Contextualized Regression. @@ -72,7 +138,7 @@ class ContextualizedRegressionBase(pl.LightningModule): # self.base_y_predictor = base_y_predictor # self.base_param_predictor = base_param_predictor # self._build_metamodel( - # context_dim, + # context_dim, # x_dim, # y_dim, # univariate, @@ -84,8 +150,8 @@ class ContextualizedRegressionBase(pl.LightningModule): # @abstractmethod # def _build_metamodel( - # self, - # context_dim, + # self, + # context_dim, # x_dim, # y_dim, # univariate, @@ -102,7 +168,7 @@ class ContextualizedRegressionBase(pl.LightningModule): # """ # # builds the metamodel # self.metamodel = SINGLE_TASK_METAMODELS[self.metamodel_type]( - # context_dim, + # context_dim, # x_dim, # y_dim, # univariate, @@ -176,7 +242,9 @@ def forward(self, batch): if not self.fit_intercept: mu = torch.zeros_like(mu) if self.base_param_predictor is not None: - base_beta, base_mu = self.base_param_predictor.predict_params(batch["contexts"]) + base_beta, base_mu = self.base_param_predictor.predict_params( + batch["contexts"] + ) beta = beta + base_beta.to(beta.device) mu = mu + base_mu.to(mu.device) return beta, mu @@ -188,6 +256,48 @@ def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) return optimizer + def _batch_size_from_batch(self, batch: dict) -> int: + """ + + :param batch: + + """ + if ( + isinstance(batch, dict) + and "contexts" in batch + and isinstance(batch["contexts"], torch.Tensor) + ): + return int(batch["contexts"].shape[0]) + return 1 + + def _predict_payload(self, batch: dict, **outputs) -> dict: + """ + + :param batch: + :param **outputs: + + """ + out = {} + for k in ( + "idx", + "orig_idx", + "sample_idx", + "outcome_idx", + "predictor_idx", + "contexts", + "predictors", + ): + if isinstance(batch, dict) and k in batch: + out[k] = batch[k] + + out.update(outputs) + + for k, v in list(out.items()): + if isinstance(v, torch.Tensor): + out[k] = v.detach().cpu() + return out + + def training_step(self, batch, batch_idx): """ @@ -196,7 +306,28 @@ def training_step(self, batch, batch_idx): """ loss = self._batch_loss(batch, batch_idx) - self.log_dict({"train_loss": loss}) + bs = self._batch_size_from_batch(batch) + + self.log( + "train_loss_step", + loss, + on_step=True, + on_epoch=False, + prog_bar=True, + sync_dist=False, + batch_size=bs, + ) + + self.log( + "train_loss", + loss, + on_step=False, + on_epoch=True, + prog_bar=False, + sync_dist=True, + batch_size=bs, + ) + return loss def validation_step(self, batch, batch_idx): @@ -207,7 +338,16 @@ def validation_step(self, batch, batch_idx): """ loss = self._batch_loss(batch, batch_idx) - self.log_dict({"val_loss": loss}) + bs = self._batch_size_from_batch(batch) + self.log( + "val_loss", + loss, + on_step=False, + on_epoch=True, + prog_bar=True, + sync_dist=True, + batch_size=bs, + ) return loss def test_step(self, batch, batch_idx): @@ -218,7 +358,16 @@ def test_step(self, batch, batch_idx): """ loss = self._batch_loss(batch, batch_idx) - self.log_dict({"test_loss": loss}) + bs = self._batch_size_from_batch(batch) + self.log( + "test_loss", + loss, + on_step=False, + on_epoch=True, + prog_bar=True, + sync_dist=True, + batch_size=bs, + ) return loss def 
_predict_from_models(self, X, beta_hat, mu_hat): @@ -229,7 +378,122 @@ def _predict_from_models(self, X, beta_hat, mu_hat): :param mu_hat: """ - return self.link_fn((beta_hat * X).sum(axis=-1).unsqueeze(-1) + mu_hat) + if isinstance(X, torch.Tensor) and X.dim() == 4 and X.shape[-1] == 1: + X = X.to(device=beta_hat.device, dtype=beta_hat.dtype) + + if beta_hat.dim() == 3: + beta_hat = beta_hat.unsqueeze(-1) + if beta_hat.dim() != 4 or beta_hat.shape[-1] != 1: + raise RuntimeError( + f"Univariate expects beta_hat (B,y,x,1); got {beta_hat.shape}" + ) + + if not isinstance(mu_hat, torch.Tensor): + mu_hat = torch.as_tensor( + mu_hat, device=beta_hat.device, dtype=beta_hat.dtype + ) + else: + mu_hat = mu_hat.to(device=beta_hat.device, dtype=beta_hat.dtype) + + if mu_hat.dim() == 2: + mu_hat = ( + mu_hat.unsqueeze(-1) + .unsqueeze(-1) + .expand(-1, beta_hat.shape[1], beta_hat.shape[2], 1) + ) + elif mu_hat.dim() == 3: + if mu_hat.shape[-1] == 1: + mu_hat = mu_hat.unsqueeze(-1).expand( + -1, beta_hat.shape[1], beta_hat.shape[2], 1 + ) + else: + mu_hat = mu_hat.unsqueeze(-1) + elif mu_hat.dim() == 4 and mu_hat.shape[-1] == 1: + pass + else: + raise RuntimeError( + f"Unsupported mu_hat shape for univariate: {mu_hat.shape}" + ) + + out = (beta_hat * X).sum(dim=-1, keepdim=True) + mu_hat + return self.link_fn(out) + + if not isinstance(beta_hat, torch.Tensor): + raise RuntimeError(f"beta_hat must be a tensor, got {type(beta_hat)}") + + if beta_hat.dim() == 4 and beta_hat.shape[-1] == 1: + beta_hat = beta_hat.squeeze(-1) + + if beta_hat.dim() != 3: + raise RuntimeError( + f"_predict_from_models expects beta_hat with shape (B, y, x) " + f"or (B, y, x, 1); got {beta_hat.shape}" + ) + + B, y_dim, x_dim = beta_hat.shape + + if not isinstance(X, torch.Tensor): + X = torch.as_tensor(X, device=beta_hat.device, dtype=beta_hat.dtype) + else: + X = X.to(device=beta_hat.device, dtype=beta_hat.dtype) + + if X.dim() == 2: + if X.shape[0] != B: + raise RuntimeError( + f"X batch dim {X.shape[0]} != beta_hat batch dim {B}. " + f"X.shape={X.shape}, beta_hat.shape={beta_hat.shape}" + ) + if X.shape[1] != x_dim: + raise RuntimeError( + f"X feature dim {X.shape[1]} != x_dim {x_dim}. " + f"X.shape={X.shape}, beta_hat.shape={beta_hat.shape}" + ) + X = X.unsqueeze(1).expand(-1, y_dim, -1) + + elif X.dim() == 3: + if X.shape[0] != B: + raise RuntimeError( + f"X batch dim {X.shape[0]} != beta_hat batch dim {B}. " + f"X.shape={X.shape}, beta_hat.shape={beta_hat.shape}" + ) + + if X.shape[1] == y_dim and X.shape[2] == x_dim: + pass + elif X.shape[1] == 1 and X.shape[2] == x_dim: + X = X.expand(-1, y_dim, -1) + elif X.shape[1] == x_dim and X.shape[2] == y_dim and x_dim == y_dim: + X = X.permute(0, 2, 1) + else: + raise RuntimeError( + f"Unexpected X shape {X.shape} for beta_hat {beta_hat.shape}. " + "Cannot safely align dimensions." + ) + else: + raise RuntimeError( + f"Unsupported X.ndim={X.dim()} for _predict_from_models; " + f"expected 2 or 3. 
X.shape={X.shape}, beta_hat.shape={beta_hat.shape}" + ) + + if not isinstance(mu_hat, torch.Tensor): + mu_hat = torch.as_tensor(mu_hat, device=beta_hat.device, dtype=beta_hat.dtype) + else: + mu_hat = mu_hat.to(device=beta_hat.device, dtype=beta_hat.dtype) + + if mu_hat.dim() == 4 and mu_hat.shape[-1] == 1: + mu_hat = mu_hat.squeeze(-1) + + if mu_hat.dim() == 2: + mu_hat = mu_hat.unsqueeze(-1) + elif mu_hat.dim() == 3: + pass + else: + raise RuntimeError( + f"Unsupported mu_hat.ndim={mu_hat.dim()} in _predict_from_models; " + f"mu_hat.shape={mu_hat.shape}" + ) + + out = (beta_hat * X).sum(dim=-1, keepdim=True) + mu_hat + return self.link_fn(out) def _predict_y(self, C, X, beta_hat, mu_hat): """ @@ -261,85 +525,6 @@ def _predict_y(self, C, X, beta_hat, mu_hat): # return DataLoader(dataset=DataIterable(dataset_constructor(C, X, Y)), **kwargs) -# class NaiveContextualizedRegression(ContextualizedRegressionBase): -# """See NaiveMetamodel""" - -# def _build_metamodel(self, *args, **kwargs): -# """ - -# :param *args: -# :param **kwargs: - -# """ -# kwargs["univariate"] = False -# self.metamodel = NaiveMetamodel(*args, **kwargs) - -# def _batch_loss(self, batch, batch_idx): -# """ - -# :param batch: -# :param batch_idx: - -# """ -# C, X, Y, _ = batch -# beta_hat, mu_hat = self.predict_step(batch, batch_idx) -# pred_loss = self.loss_fn(Y, self._predict_y(C, X, beta_hat, mu_hat)) -# reg_loss = self.model_regularizer(beta_hat, mu_hat) -# return pred_loss + reg_loss - -# def predict_step(self, batch, batch_idx): -# """ - -# :param batch: -# :param batch_idx: - -# """ -# C, _, _, _ = batch -# beta_hat, mu_hat = self(C) -# return beta_hat, mu_hat - - # def _params_reshape(self, preds, dataloader): - # """ - - # :param preds: - # :param dataloader: - - # """ - # ds = dataloader.dataset.dataset - # betas = np.zeros((ds.n, ds.y_dim, ds.x_dim)) - # mus = np.zeros((ds.n, ds.y_dim)) - # for (beta_hats, mu_hats), data in zip(preds, dataloader): - # _, _, _, n_idx = data - # betas[n_idx] = beta_hats - # mus[n_idx] = mu_hats.squeeze(-1) - # return betas, mus - - # def _y_reshape(self, preds, dataloader): - # """ - - # :param preds: - # :param dataloader: - - # """ - # ds = dataloader.dataset.dataset - # ys = np.zeros((ds.n, ds.y_dim)) - # for (beta_hats, mu_hats), data in zip(preds, dataloader): - # C, X, _, n_idx = data - # ys[n_idx] = self._predict_y(C, X, beta_hats, mu_hats).squeeze(-1) - # return ys - - # def dataloader(self, C, X, Y, **kwargs): - # """ - - # :param C: - # :param X: - # :param Y: - # :param **kwargs: - - # """ - # return self._dataloader(C, X, Y, MultivariateDataset, **kwargs) - - class ContextualizedRegression(ContextualizedRegressionBase): """Supports SubtypeMetamodel and NaiveMetamodel, see selected metamodel for docs""" def __init__( @@ -364,19 +549,20 @@ def __init__( base_param_predictor=None, ): super().__init__() + self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) + self.learning_rate = learning_rate self.fit_intercept = fit_intercept - self.link_fn = LINK_FUNCTIONS[link_fn] - if loss_fn == "mse": - self.loss_fn = MSE - else: - raise ValueError("Supported loss_fn's: mse") - self.model_regularizer = REGULARIZERS[model_regularizer] + self.link_fn = _resolve_registry_or_callable(link_fn, LINK_FUNCTIONS, "link_fn") + self.loss_fn = _resolve_loss(loss_fn) + + self.model_regularizer = _resolve_regularizer(model_regularizer) + self.base_y_predictor = base_y_predictor self.base_param_predictor = base_param_predictor if metamodel_type == "subtype": self.metamodel 
= SubtypeMetamodel( - context_dim=context_dim, + context_dim=context_dim, x_dim=x_dim, y_dim=y_dim, univariate=False, @@ -388,7 +574,7 @@ def __init__( if num_archetypes is not None: raise ValueError("NaiveMetamodel does not support num_archetypes.") self.metamodel = NaiveMetamodel( - context_dim=context_dim, + context_dim=context_dim, x_dim=x_dim, y_dim=y_dim, univariate=False, @@ -406,69 +592,22 @@ def _batch_loss(self, batch, batch_idx): """ beta_hat, mu_hat = self(batch) - pred_loss = self.loss_fn(batch["outcomes"], self._predict_y(batch["contexts"], batch["predictors"], beta_hat, mu_hat)) + pred_loss = self.loss_fn( + batch["outcomes"], + self._predict_y(batch["contexts"], batch["predictors"], beta_hat, mu_hat), + ) reg_loss = self.model_regularizer(beta_hat, mu_hat) return pred_loss + reg_loss def predict_step(self, batch, batch_idx): - """ - - :param batch: - :param batch_idx: - - """ beta_hat, mu_hat = self(batch) - batch.update({ - "betas": beta_hat, - "mus": mu_hat.squeeze(-1), - }) - return batch - - # def _params_reshape(self, preds, dataloader): - # """ - - # :param preds: - # :param dataloader: - - # """ - # ds = dataloader.dataset.dataset - # betas = np.zeros((ds.n, ds.y_dim, ds.x_dim)) - # mus = np.zeros((ds.n, ds.y_dim)) - # for (beta_hats, mu_hats), data in zip(preds, dataloader): - # _, _, _, n_idx = data - # betas[n_idx] = beta_hats - # mus[n_idx] = mu_hats.squeeze(-1) - # return betas, mus - - # def _y_reshape(self, preds, dataloader): - # """ - - # :param preds: - # :param dataloader: - - # """ - # ds = dataloader.dataset.dataset - # ys = np.zeros((ds.n, ds.y_dim)) - # for (beta_hats, mu_hats), data in zip(preds, dataloader): - # C, X, _, n_idx = data - # ys[n_idx] = self._predict_y(C, X, beta_hats, mu_hats).squeeze(-1) - # return ys - - # def dataloader(self, C, X, Y, **kwargs): - # """ - - # :param C: - # :param X: - # :param Y: - # :param **kwargs: - - # """ - # return self._dataloader(C, X, Y, MultivariateDataset, **kwargs) + mu_hat = mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1) + return self._predict_payload(batch, betas=beta_hat, mus=mu_hat) class NaiveContextualizedRegression(ContextualizedRegression): """Handle for NaiveMetamodel usage of ContextualizedRegression. - Does not use archetypes. + Does not use archetypes. 
""" def __init__( self, @@ -503,8 +642,9 @@ def __init__( loss_fn=loss_fn, model_regularizer=model_regularizer, base_y_predictor=base_y_predictor, - base_param_predictor=base_param_predictor + base_param_predictor=base_param_predictor, ) + self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) class MultitaskContextualizedRegression(ContextualizedRegressionBase): @@ -528,14 +668,15 @@ def __init__( model_regularizer="none", ): super().__init__() + self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) + self.learning_rate = learning_rate self.fit_intercept = fit_intercept - self.link_fn = LINK_FUNCTIONS[link_fn] - if loss_fn == "mse": - self.loss_fn = MSE - else: - raise ValueError("Supported loss_fn's: mse") - self.model_regularizer = REGULARIZERS[model_regularizer] + self.link_fn = _resolve_registry_or_callable(link_fn, LINK_FUNCTIONS, "link_fn") + self.loss_fn = _resolve_loss(loss_fn) + + self.model_regularizer = _resolve_regularizer(model_regularizer) + self.metamodel = MultitaskMetamodel( context_dim=context_dim, x_dim=x_dim, @@ -555,7 +696,6 @@ def forward(self, batch): beta, mu = self.metamodel(batch["contexts"], batch["task"]) if not self.fit_intercept: mu = torch.zeros_like(mu) - # Does not support base_param_predictor return beta, mu def _batch_loss(self, batch, batch_idx): @@ -566,10 +706,13 @@ def _batch_loss(self, batch, batch_idx): """ beta_hat, mu_hat = self(batch) - pred_loss = self.loss_fn(batch['outcomes'], self._predict_y(batch['contexts'], batch['predictors'], beta_hat, mu_hat)) + pred_loss = self.loss_fn( + batch["outcomes"], + self._predict_y(batch["contexts"], batch["predictors"], beta_hat, mu_hat), + ) reg_loss = self.model_regularizer(beta_hat, mu_hat) return pred_loss + reg_loss - + def _predict_y(self, C, X, beta_hat, mu_hat): """ @@ -580,65 +723,12 @@ def _predict_y(self, C, X, beta_hat, mu_hat): """ Y = self._predict_from_models(X, beta_hat, mu_hat) - # Does not support base_y_predictor return Y def predict_step(self, batch, batch_idx): - """ - - :param batch: - :param batch_idx: - - """ beta_hat, mu_hat = self(batch) - batch.update({ - "betas": beta_hat, - "mus": mu_hat.squeeze(-1), - }) - # Return batch with predictions - return batch - - - # def _params_reshape(self, preds, dataloader): - # """ - - # :param preds: - # :param dataloader: - - # """ - # ds = dataloader.dataset.dataset - # betas = np.zeros((ds.n, ds.y_dim, ds.x_dim)) - # mus = np.zeros((ds.n, ds.y_dim)) - # for (beta_hats, mu_hats), data in zip(preds, dataloader): - # _, _, _, _, n_idx, y_idx = data - # betas[n_idx, y_idx] = beta_hats - # mus[n_idx, y_idx] = mu_hats.squeeze(-1) - # return betas, mus - - # def _y_reshape(self, preds, dataloader): - # """ - - # :param preds: - # :param dataloader: - - # """ - # ds = dataloader.dataset.dataset - # ys = np.zeros((ds.n, ds.y_dim)) - # for (beta_hats, mu_hats), data in zip(preds, dataloader): - # C, _, X, _, n_idx, y_idx = data - # ys[n_idx, y_idx] = self._predict_y(C, X, beta_hats, mu_hats).squeeze(-1) - # return ys - - # def dataloader(self, C, X, Y, **kwargs): - # """ - - # :param C: - # :param X: - # :param Y: - # :param **kwargs: - - # """ - # return self._dataloader(C, X, Y, MultitaskMultivariateDataset, **kwargs) + mu_hat = mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1) + return self._predict_payload(batch, betas=beta_hat, mus=mu_hat) class TasksplitContextualizedRegression(ContextualizedRegressionBase): @@ -671,15 +761,16 @@ def __init__( model_regularizer="none", ): super().__init__() + 
self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) + self.learning_rate = learning_rate self.metamodel_type = metamodel_type self.fit_intercept = fit_intercept - self.link_fn = LINK_FUNCTIONS[link_fn] - if loss_fn == "mse": - self.loss_fn = MSE - else: - raise ValueError("Supported loss_fn's: mse") - self.model_regularizer = REGULARIZERS[model_regularizer] + self.link_fn = _resolve_registry_or_callable(link_fn, LINK_FUNCTIONS, "link_fn") + self.loss_fn = _resolve_loss(loss_fn) + + self.model_regularizer = _resolve_regularizer(model_regularizer) + self.metamodel = TasksplitMetamodel( context_dim=context_dim, x_dim=x_dim, @@ -692,7 +783,7 @@ def __init__( task_encoder_type=task_encoder_type, task_encoder_kwargs=task_encoder_kwargs, ) - + def forward(self, batch): """ @@ -702,7 +793,6 @@ def forward(self, batch): beta, mu = self.metamodel(batch["contexts"], batch["task"]) if not self.fit_intercept: mu = torch.zeros_like(mu) - # Does not support base_param_predictor return beta, mu def _batch_loss(self, batch, batch_idx): @@ -713,10 +803,13 @@ def _batch_loss(self, batch, batch_idx): """ beta_hat, mu_hat = self(batch) - pred_loss = self.loss_fn(batch['outcomes'], self._predict_y(batch['contexts'], batch['predictors'], beta_hat, mu_hat)) + pred_loss = self.loss_fn( + batch["outcomes"], + self._predict_y(batch["contexts"], batch["predictors"], beta_hat, mu_hat), + ) reg_loss = self.model_regularizer(beta_hat, mu_hat) return pred_loss + reg_loss - + def _predict_y(self, C, X, beta_hat, mu_hat): """ @@ -727,90 +820,12 @@ def _predict_y(self, C, X, beta_hat, mu_hat): """ Y = self._predict_from_models(X, beta_hat, mu_hat) - # Does not support base_y_predictor return Y def predict_step(self, batch, batch_idx): - """ - - :param batch: - :param batch_idx: - - """ beta_hat, mu_hat = self(batch) - batch.update({ - "betas": beta_hat, - "mus": mu_hat.squeeze(-1), - }) - # Return batch with predictions - return batch - - # def _batch_loss(self, batch, batch_idx): - # """ - - # :param batch: - # :param batch_idx: - - # """ - # beta_hat, mu_hat = self(batch) - # pred_loss = self.loss_fn(batch["outcomes"], self._predict_y(batch["contexts"], batch["predictors"], beta_hat, mu_hat)) - # reg_loss = self.model_regularizer(beta_hat, mu_hat) - # return pred_loss + reg_loss - - # def predict_step(self, batch, batch_idx): - # """ - - # :param batch: - # :param batch_idx: - - # """ - # beta_hat, mu_hat = self(batch) - # batch.update({ - # "betas": beta_hat, - # "mus": mu_hat.squeeze(-1) - # }) - # return batch - - # def _params_reshape(self, preds, dataloader): - # """ - - # :param preds: - # :param dataloader: - - # """ - # ds = dataloader.dataset.dataset - # betas = np.zeros((ds.n, ds.y_dim, ds.x_dim)) - # mus = np.zeros((ds.n, ds.y_dim)) - # for (beta_hats, mu_hats), data in zip(preds, dataloader): - # _, _, _, _, n_idx, y_idx = data - # betas[n_idx, y_idx] = beta_hats - # mus[n_idx, y_idx] = mu_hats.squeeze(-1) - # return betas, mus - - # def _y_reshape(self, preds, dataloader): - # """ - - # :param preds: - # :param dataloader: - - # """ - # ds = dataloader.dataset.dataset - # ys = np.zeros((ds.n, ds.y_dim)) - # for (beta_hats, mu_hats), data in zip(preds, dataloader): - # C, _, X, _, n_idx, y_idx = data - # ys[n_idx, y_idx] = self._predict_y(C, X, beta_hats, mu_hats).squeeze(-1) - # return ys - - # def dataloader(self, C, X, Y, **kwargs): - # """ - - # :param C: - # :param X: - # :param Y: - # :param **kwargs: - - # """ - # return self._dataloader(C, X, Y, MultitaskMultivariateDataset, 
**kwargs) + mu_hat = mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1) + return self._predict_payload(batch, betas=beta_hat, mus=mu_hat) class ContextualizedUnivariateRegression(ContextualizedRegressionBase): @@ -837,19 +852,20 @@ def __init__( base_param_predictor=None, ): super().__init__() + self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) + self.learning_rate = learning_rate self.fit_intercept = fit_intercept - self.link_fn = LINK_FUNCTIONS[link_fn] - if loss_fn == "mse": - self.loss_fn = MSE - else: - raise ValueError("Supported loss_fn's: mse") - self.model_regularizer = REGULARIZERS[model_regularizer] + self.link_fn = _resolve_registry_or_callable(link_fn, LINK_FUNCTIONS, "link_fn") + self.loss_fn = _resolve_loss(loss_fn) + + self.model_regularizer = _resolve_regularizer(model_regularizer) + self.base_y_predictor = base_y_predictor self.base_param_predictor = base_param_predictor if metamodel_type == "subtype": self.metamodel = SubtypeMetamodel( - context_dim=context_dim, + context_dim=context_dim, x_dim=x_dim, y_dim=y_dim, univariate=True, @@ -861,7 +877,7 @@ def __init__( if num_archetypes is not None: raise ValueError("NaiveMetamodel does not support num_archetypes.") self.metamodel = NaiveMetamodel( - context_dim=context_dim, + context_dim=context_dim, x_dim=x_dim, y_dim=y_dim, univariate=True, @@ -870,7 +886,7 @@ def __init__( ) else: raise ValueError("Supported metamodel_type's: subtype, naive") - + def forward(self, batch): """ @@ -880,9 +896,8 @@ def forward(self, batch): beta, mu = self.metamodel(batch["contexts"]) if not self.fit_intercept: mu = torch.zeros_like(mu) - # Does not support base_param_predictor return beta, mu - + def _batch_loss(self, batch, batch_idx): """ @@ -891,64 +906,17 @@ def _batch_loss(self, batch, batch_idx): """ beta_hat, mu_hat = self(batch) - pred_loss = self.loss_fn(batch["outcomes"], self._predict_y(batch["contexts"], batch["predictors"], beta_hat, mu_hat)) + pred_loss = self.loss_fn( + batch["outcomes"], + self._predict_y(batch["contexts"], batch["predictors"], beta_hat, mu_hat), + ) reg_loss = self.model_regularizer(beta_hat, mu_hat) return pred_loss + reg_loss def predict_step(self, batch, batch_idx): - """ - - :param batch: - :param batch_idx: - - """ beta_hat, mu_hat = self(batch) - batch.update({ - "betas": beta_hat.squeeze(-1), - "mus": mu_hat.squeeze(-1), - }) - return batch - - # def _params_reshape(self, preds, dataloader): - # """ - - # :param preds: - # :param dataloader: - - # """ - # ds = dataloader.dataset.dataset - # betas = np.zeros((ds.n, ds.y_dim, ds.x_dim)) - # mus = np.zeros((ds.n, ds.y_dim, ds.x_dim)) - # for (beta_hats, mu_hats), data in zip(preds, dataloader): - # _, _, _, n_idx = data - # betas[n_idx] = beta_hats.squeeze(-1) - # mus[n_idx] = mu_hats.squeeze(-1) - # return betas, mus - - # def _y_reshape(self, preds, dataloader): - # """ - - # :param preds: - # :param dataloader: - - # """ - # ds = dataloader.dataset.dataset - # ys = np.zeros((ds.n, ds.y_dim, ds.x_dim)) - # for (beta_hats, mu_hats), data in zip(preds, dataloader): - # C, X, _, n_idx = data - # ys[n_idx] = self._predict_y(C, X, beta_hats, mu_hats).squeeze(-1) - # return ys - - # def dataloader(self, C, X, Y, **kwargs): - # """ - - # :param C: - # :param X: - # :param Y: - # :param **kwargs: - - # """ - # return self._dataloader(C, X, Y, UnivariateDataset, **kwargs) + mu_hat = mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1) + return self._predict_payload(batch, betas=beta_hat, mus=mu_hat) class 
MultitaskContextualizedUnivariateRegression(ContextualizedRegressionBase): @@ -973,14 +941,15 @@ def __init__( model_regularizer="none", ): super().__init__() + self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) + self.learning_rate = learning_rate self.fit_intercept = fit_intercept - self.link_fn = LINK_FUNCTIONS[link_fn] - if loss_fn == "mse": - self.loss_fn = MSE - else: - raise ValueError("Supported loss_fn's: mse") - self.model_regularizer = REGULARIZERS[model_regularizer] + self.link_fn = _resolve_registry_or_callable(link_fn, LINK_FUNCTIONS, "link_fn") + self.loss_fn = _resolve_loss(loss_fn) + + self.model_regularizer = _resolve_regularizer(model_regularizer) + self.metamodel = MultitaskMetamodel( context_dim=context_dim, x_dim=x_dim, @@ -990,7 +959,7 @@ def __init__( encoder_type=encoder_type, encoder_kwargs=encoder_kwargs, ) - + def forward(self, batch): """ @@ -1000,7 +969,6 @@ def forward(self, batch): beta, mu = self.metamodel(batch["contexts"], batch["task"]) if not self.fit_intercept: mu = torch.zeros_like(mu) - # Does not support base_param_predictor return beta, mu def _batch_loss(self, batch, batch_idx): @@ -1011,10 +979,13 @@ def _batch_loss(self, batch, batch_idx): """ beta_hat, mu_hat = self(batch) - pred_loss = self.loss_fn(batch['outcomes'], self._predict_y(batch['contexts'], batch['predictors'], beta_hat, mu_hat)) + pred_loss = self.loss_fn( + batch["outcomes"], + self._predict_y(batch["contexts"], batch["predictors"], beta_hat, mu_hat), + ) reg_loss = self.model_regularizer(beta_hat, mu_hat) return pred_loss + reg_loss - + def _predict_y(self, C, X, beta_hat, mu_hat): """ @@ -1025,22 +996,13 @@ def _predict_y(self, C, X, beta_hat, mu_hat): """ Y = self._predict_from_models(X, beta_hat, mu_hat) - # Does not support base_y_predictor return Y def predict_step(self, batch, batch_idx): - """ - - :param batch: - :param batch_idx: - - """ beta_hat, mu_hat = self(batch) - batch.update({ - "betas": beta_hat.squeeze(-1), - "mus": mu_hat.squeeze(-1), - }) - return batch + mu_hat = mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1) + return self._predict_payload(batch, betas=beta_hat, mus=mu_hat) + class TasksplitContextualizedUnivariateRegression(ContextualizedRegressionBase): """See TasksplitMetamodel""" @@ -1071,14 +1033,15 @@ def __init__( model_regularizer="none", ): super().__init__() + self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) + self.learning_rate = learning_rate self.fit_intercept = fit_intercept - self.link_fn = LINK_FUNCTIONS[link_fn] - if loss_fn == "mse": - self.loss_fn = MSE - else: - raise ValueError("Supported loss_fn's: mse") - self.model_regularizer = REGULARIZERS[model_regularizer] + self.link_fn = _resolve_registry_or_callable(link_fn, LINK_FUNCTIONS, "link_fn") + self.loss_fn = _resolve_loss(loss_fn) + + self.model_regularizer = _resolve_regularizer(model_regularizer) + self.metamodel = TasksplitMetamodel( context_dim=context_dim, x_dim=x_dim, @@ -1091,7 +1054,7 @@ def __init__( task_encoder_type=task_encoder_type, task_encoder_kwargs=task_encoder_kwargs, ) - + def forward(self, batch): """ @@ -1101,7 +1064,6 @@ def forward(self, batch): beta, mu = self.metamodel(batch["contexts"], batch["task"]) if not self.fit_intercept: mu = torch.zeros_like(mu) - # Does not support base_param_predictor return beta, mu def _batch_loss(self, batch, batch_idx): @@ -1112,10 +1074,13 @@ def _batch_loss(self, batch, batch_idx): """ beta_hat, mu_hat = self(batch) - pred_loss = self.loss_fn(batch['outcomes'], 
self._predict_y(batch['contexts'], batch['predictors'], beta_hat, mu_hat)) + pred_loss = self.loss_fn( + batch["outcomes"], + self._predict_y(batch["contexts"], batch["predictors"], beta_hat, mu_hat), + ) reg_loss = self.model_regularizer(beta_hat, mu_hat) return pred_loss + reg_loss - + def _predict_y(self, C, X, beta_hat, mu_hat): """ @@ -1126,65 +1091,12 @@ def _predict_y(self, C, X, beta_hat, mu_hat): """ Y = self._predict_from_models(X, beta_hat, mu_hat) - # Does not support base_y_predictor return Y def predict_step(self, batch, batch_idx): - """ - - :param batch: - :param batch_idx: - - """ beta_hat, mu_hat = self(batch) - batch.update({ - "betas": beta_hat.squeeze(-1), - "mus": mu_hat.squeeze(-1), - }) - return batch - - # def _params_reshape(self, preds, dataloader): - # """ - - # :param preds: - # :param dataloader: - - # """ - # ds = dataloader.dataset.dataset - # betas = np.zeros((ds.n, ds.y_dim, ds.x_dim)) - # mus = betas.copy() - # for (beta_hats, mu_hats), data in zip(preds, dataloader): - # _, _, _, _, n_idx, x_idx, y_idx = data - # betas[n_idx, y_idx, x_idx] = beta_hats.squeeze(-1) - # mus[n_idx, y_idx, x_idx] = mu_hats.squeeze(-1) - # return betas, mus - - # def _y_reshape(self, preds, dataloader): - # """ - - # :param preds: - # :param dataloader: - - # """ - # ds = dataloader.dataset.dataset - # ys = np.zeros((ds.n, ds.y_dim, ds.x_dim)) - # for (beta_hats, mu_hats), data in zip(preds, dataloader): - # C, _, X, _, n_idx, x_idx, y_idx = data - # ys[n_idx, y_idx, x_idx] = self._predict_y(C, X, beta_hats, mu_hats).squeeze( - # -1 - # ) - # return ys - - # def dataloader(self, C, X, Y, **kwargs): - # """ - - # :param C: - # :param X: - # :param Y: - # :param **kwargs: - - # """ - # return self._dataloader(C, X, Y, MultitaskUnivariateDataset, **kwargs) + mu_hat = mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1) + return self._predict_payload(batch, betas=beta_hat, mus=mu_hat) class ContextualizedCorrelation(ContextualizedUnivariateRegression): @@ -1198,26 +1110,21 @@ def __init__(self, context_dim, x_dim, **kwargs): if "y_dim" in kwargs: del kwargs["y_dim"] super().__init__(context_dim, x_dim, x_dim, **kwargs) + self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) def predict_step(self, batch, batch_idx): - """ - - :param batch: - :param batch_idx: - - """ beta_hat, mu_hat = self(batch) beta_hat = beta_hat.squeeze(-1) + beta_hat_T = beta_hat.transpose(1, 2) signs = torch.sign(beta_hat) signs[signs != signs.transpose(1, 2)] = 0 correlations = signs * torch.sqrt(torch.abs(beta_hat * beta_hat_T)) - batch.update({ - "betas": beta_hat.squeeze(-1), - "mus": mu_hat.squeeze(-1), - "correlations": correlations, - }) - return batch + + mu_hat = mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1) + return self._predict_payload( + batch, betas=beta_hat, mus=mu_hat, correlations=correlations + ) class MultitaskContextualizedCorrelation(MultitaskContextualizedUnivariateRegression): @@ -1231,6 +1138,7 @@ def __init__(self, context_dim, x_dim, **kwargs): if "y_dim" in kwargs: del kwargs["y_dim"] super().__init__(context_dim, x_dim, x_dim, **kwargs) + self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) class TasksplitContextualizedCorrelation(TasksplitContextualizedUnivariateRegression): @@ -1244,6 +1152,7 @@ def __init__(self, context_dim, x_dim, **kwargs): if "y_dim" in kwargs: del kwargs["y_dim"] super().__init__(context_dim, x_dim, x_dim, **kwargs) + self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) 
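Note for readers of this diff: every `predict_step` above now returns `self._predict_payload(batch, betas=..., mus=...)` instead of mutating and returning the batch, but `_predict_payload` itself lives on the base module outside this hunk. The block below is only a minimal sketch of what such a helper plausibly does, assuming batch dicts carry `idx`/`orig_idx` (plus the optional task indices) and that everything is detached to CPU for rank-0 gathering; the method name appears in this diff, but the exact keys and signature shown here are assumptions.

```python
import torch


def _predict_payload(self, batch, **outputs):
    """Sketch only: pack stable indices, optional batch content, and model
    outputs into a CPU-resident dict for DDP-safe gather/reorder on rank 0."""
    payload = {}
    # Stable indices allow reordering and dedup of samples padded by DistributedSampler.
    for key in ("idx", "orig_idx", "sample_idx", "outcome_idx", "predictor_idx"):
        if key in batch:
            payload[key] = batch[key].detach().cpu()
    # Batch content kept so y_hat can be recomputed on rank 0 when needed.
    for key in ("contexts", "predictors"):
        if key in batch:
            payload[key] = batch[key].detach().cpu()
    # Model outputs (betas, mus, correlations, ...): detach to keep GPU memory flat.
    for name, value in outputs.items():
        payload[name] = value.detach().cpu() if torch.is_tensor(value) else value
    return payload
```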
class ContextualizedNeighborhoodSelection(ContextualizedRegression): @@ -1266,35 +1175,16 @@ def __init__( super().__init__( context_dim, x_dim, x_dim, model_regularizer=model_regularizer, **kwargs ) + self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) + self.register_buffer("diag_mask", torch.ones(x_dim, x_dim) - torch.eye(x_dim)) def predict_step(self, batch, batch_idx): - """ - - :param batch: - :param batch_idx: - - """ - C, _, _, _ = batch - beta_hat, mu_hat = self(C) + beta_hat, mu_hat = self(batch) beta_hat = beta_hat * self.diag_mask.expand(beta_hat.shape[0], -1, -1) - return beta_hat, mu_hat - - def dataloader(self, C, X, Y=None, **kwargs): - """ - - :param C: - :param X: - :param Y: - :param **kwargs: - """ - - if Y is not None: - print( - "Passed a Y, but this is a Markov Graph between X featuers. Ignoring Y." - ) - return super().dataloader(C, X, X, **kwargs) + mu_hat = mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1) + return self._predict_payload(batch, betas=beta_hat, mus=mu_hat) class ContextualizedMarkovGraph(ContextualizedRegression): @@ -1309,35 +1199,14 @@ def __init__(self, context_dim, x_dim, **kwargs): if "y_dim" in kwargs: del kwargs["y_dim"] super().__init__(context_dim, x_dim, x_dim, **kwargs) + self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) + self.register_buffer("diag_mask", torch.ones(x_dim, x_dim) - torch.eye(x_dim)) def predict_step(self, batch, batch_idx): - """ - - :param batch: - :param batch_idx: - - """ - C, _, _, _ = batch - beta_hat, mu_hat = self(C) - beta_hat = beta_hat + torch.transpose( - beta_hat, 1, 2 - ) # hotfix to enforce symmetry + beta_hat, mu_hat = self(batch) + beta_hat = beta_hat + beta_hat.transpose(1, 2) beta_hat = beta_hat * self.diag_mask.expand(beta_hat.shape[0], -1, -1) - return beta_hat, mu_hat - - def dataloader(self, C, X, Y=None, **kwargs): - """ - :param C: - :param X: - :param Y: - :param **kwargs: - - """ - - if Y is not None: - print( - "Passed a Y, but this is a Markov Graph between X featuers. Ignoring Y." - ) - return super().dataloader(C, X, X, **kwargs) + mu_hat = mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1) + return self._predict_payload(batch, betas=beta_hat, mus=mu_hat) diff --git a/contextualized/regression/metamodels.py b/contextualized/regression/metamodels.py index feca94af..5df95120 100644 --- a/contextualized/regression/metamodels.py +++ b/contextualized/regression/metamodels.py @@ -274,4 +274,4 @@ def forward(self, C, T): MULTITASK_METAMODELS = { "multitask": MultitaskMetamodel, "tasksplit": TasksplitMetamodel, -} +} \ No newline at end of file diff --git a/contextualized/regression/regularizers.py b/contextualized/regression/regularizers.py index 5ba15648..eee1f904 100644 --- a/contextualized/regression/regularizers.py +++ b/contextualized/regression/regularizers.py @@ -78,4 +78,4 @@ def l1_l2_reg(alpha, l1_ratio=0.5, mu_ratio=0.5): return partial(l1_l2_reg_fn, alpha, l1_ratio, mu_ratio) -REGULARIZERS = {"none": no_reg(), "l1": l1_reg, "l2": l2_reg, "l1_l2": l1_l2_reg} +REGULARIZERS = {"none": no_reg(), "l1": l1_reg, "l2": l2_reg, "l1_l2": l1_l2_reg} \ No newline at end of file diff --git a/contextualized/regression/trainers.py b/contextualized/regression/trainers.py index 759e09ea..107d1f9d 100644 --- a/contextualized/regression/trainers.py +++ b/contextualized/regression/trainers.py @@ -2,8 +2,193 @@ PyTorch-Lightning trainers used for Contextualized regression. 
""" +from typing import Any, Tuple, List, Dict, Optional + import numpy as np -import pytorch_lightning as pl +import torch +import torch.distributed as dist +import lightning.pytorch as pl +from lightning.pytorch.plugins.environments import LightningEnvironment +import os +from lightning.pytorch.strategies import DDPStrategy + + +def _stack_from_preds(preds: List[dict], key: str) -> torch.Tensor: + """ + Concatenate a tensor field from the list of batch dicts returned by predict(). + """ + preds = _flatten_pl_predict_output(preds) + parts = [] + for p in preds: + val = p[key] + if isinstance(val, np.ndarray): + val = torch.from_numpy(val) + parts.append(val.detach().cpu()) + return torch.cat(parts, dim=0) + + +def _is_distributed() -> bool: + return dist.is_available() and dist.is_initialized() + + +def _is_main_process() -> bool: + return (not _is_distributed()) or dist.get_rank() == 0 + + +def _flatten_pl_predict_output(preds): + """ + Lightning can return: + - list[dict] (single dataloader) + - list[list[dict]] (multiple dataloaders) + + Normalize to list[dict]. + """ + if preds is None: + return [] + if len(preds) > 0 and isinstance(preds[0], list): + out = [] + for sub in preds: + out.extend(sub) + return out + return preds + + +def _to_numpy_cpu(x): + if x is None: + return None + if isinstance(x, np.ndarray): + return x + if torch.is_tensor(x): + return x.detach().cpu().numpy() + return np.asarray(x) + + +def _pack_keys_from_preds(preds: list, keys: Tuple[str, ...]) -> Dict[str, np.ndarray]: + """ + Pack only requested keys from list[dict] predictions into numpy arrays. + Concats on axis 0. + """ + preds = _flatten_pl_predict_output(preds) + if not preds: + return {} + + packed: Dict[str, List[np.ndarray]] = {k: [] for k in keys} + for p in preds: + for k in keys: + if k in p: + v = _to_numpy_cpu(p[k]) + if v is not None: + packed[k].append(v) + + out: Dict[str, np.ndarray] = {} + for k, parts in packed.items(): + if not parts: + continue + out[k] = np.concatenate(parts, axis=0) + return out + + +def _gather_object_to_rank0(obj): + """ + Gather arbitrary Python objects to rank 0. + + Returns: + - list[obj] on rank 0 + - None on other ranks + """ + if not _is_distributed(): + return [obj] + + world_size = dist.get_world_size() + if world_size == 1: + return [obj] + + if _is_main_process(): + gathered = [None for _ in range(world_size)] + dist.gather_object(obj, object_gather_list=gathered, dst=0) + return gathered + else: + dist.gather_object(obj, object_gather_list=None, dst=0) + return None + + +def _merge_packed_payloads( + payloads: List[Optional[Dict[str, np.ndarray]]], +) -> Dict[str, np.ndarray]: + """ + Merge list[dict[str, np.ndarray]] -> dict[str, np.ndarray] by concatenation axis 0. + """ + merged: Dict[str, np.ndarray] = {} + payloads = [p for p in payloads if p] + if not payloads: + return merged + + keys = set() + for p in payloads: + keys.update(p.keys()) + + for k in keys: + chunks = [ + p[k] + for p in payloads + if (k in p) and (p[k] is not None) and (len(p[k]) > 0) + ] + if not chunks: + continue + merged[k] = np.concatenate(chunks, axis=0) + return merged + + +def _stable_sort_and_dedupe(payload: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """ + Sort payload arrays by dataset-local 'idx' when present (correct for subsets), + else fall back to 'orig_idx'. Then dedupe (DistributedSampler may pad/duplicate). 
+ """ + if not payload: + return payload + + key = "idx" if "idx" in payload else ("orig_idx" if "orig_idx" in payload else None) + if key is None: + return payload + + k = payload[key].astype(np.int64) + if k.size == 0: + return payload + + order = np.argsort(k, kind="mergesort") + k_sorted = k[order] + _, uniq_pos = np.unique(k_sorted, return_index=True) + keep = order[np.sort(uniq_pos)] + + out: Dict[str, np.ndarray] = {} + for name, v in payload.items(): + if isinstance(v, np.ndarray) and v.shape[0] == k.shape[0]: + out[name] = v[keep] + else: + out[name] = v + return out + + +def _gather_predict_payload( + preds, keys: Tuple[str, ...] +) -> Optional[Dict[str, np.ndarray]]: + """ + Packs requested keys from local preds, gathers to rank0 under DDP, merges, and + stable-sorts/dedupes by orig_idx (if present). + + Returns: + - payload dict on rank 0 + - None on non-rank0 in DDP + """ + local = _pack_keys_from_preds(preds, keys) + + gathered = _gather_object_to_rank0(local) + if gathered is None: + return None + + merged = _merge_packed_payloads(gathered) + merged = _stable_sort_and_dedupe(merged) + return merged class RegressionTrainer(pl.Trainer): @@ -11,22 +196,92 @@ class RegressionTrainer(pl.Trainer): Trains the contextualized.regression lightning_modules """ - def predict_params(self, model, dataloader): + @torch.no_grad() + def predict_params( + self, model: pl.LightningModule, dataloader + ) -> Tuple[np.ndarray, np.ndarray]: """ Returns context-specific regression models - beta (numpy.ndarray): (n, y_dim, x_dim) - mu (numpy.ndarray): (n, y_dim, [1 if normal regression, x_dim if univariate]) """ preds = super().predict(model, dataloader) - return model._params_reshape(preds, dataloader) - def predict_y(self, model, dataloader): + payload = _gather_predict_payload(preds, keys=("idx", "orig_idx", "betas", "mus")) + if payload is None: + return None, None + + if "betas" not in payload or "mus" not in payload: + raise RuntimeError( + "predict_params: predict_step must return 'betas' and 'mus' (and ideally 'orig_idx')." + ) + + return payload["betas"], payload["mus"] + + @torch.no_grad() + def predict_y(self, model: pl.LightningModule, dataloader) -> np.ndarray: """ Returns context-specific predictions of the response Y - y_hat (numpy.ndarray): (n, y_dim, [1 if normal regression, x_dim if univariate]) """ preds = super().predict(model, dataloader) - return model._y_reshape(preds, dataloader) + + payload = _gather_predict_payload( + preds, keys=("idx", "orig_idx", "contexts", "predictors", "betas", "mus") + ) + + if payload is None: + return None + + if "betas" not in payload or "mus" not in payload: + raise RuntimeError("predict_y: predict_step must return 'betas' and 'mus'.") + + betas = torch.as_tensor(payload["betas"]) + mus = torch.as_tensor(payload["mus"]) + + if ("contexts" in payload) and ("predictors" in payload): + C = torch.as_tensor(payload["contexts"]) + X = torch.as_tensor(payload["predictors"]) + else: + ds = getattr(dataloader, "dataset", None) + if ds is None: + raise RuntimeError( + "predict_y: dataloader has no .dataset; cannot reconstruct C/X." 
+ ) + + idx_np = payload["idx"].astype(np.int64) + idx_t = torch.as_tensor(idx_np, dtype=torch.long) + + if hasattr(ds, "dataset") and hasattr(ds, "indices"): + base = ds.dataset + if not (hasattr(base, "C") and hasattr(base, "X")): + raise RuntimeError("predict_y: Subset base dataset must expose .C and .X.") + base_pos = np.asarray(ds.indices, dtype=np.int64)[idx_np] + base_pos_t = torch.as_tensor(base_pos, dtype=torch.long) + C = base.C[base_pos_t] + X = base.X[base_pos_t] + else: + if not (hasattr(ds, "C") and hasattr(ds, "X")): + raise RuntimeError( + "predict_y: dataset must expose .C and .X tensors for Option A prediction." + ) + C = ds.C[idx_t] + X = ds.X[idx_t] + + if torch.is_tensor(C): + C = C.to(dtype=betas.dtype) + else: + C = torch.as_tensor(C, dtype=betas.dtype) + + if torch.is_tensor(X): + X = X.to(dtype=betas.dtype) + else: + X = torch.as_tensor(X, dtype=betas.dtype) + + with torch.no_grad(): + yhat = model._predict_y(C, X, betas, mus).detach().cpu().numpy() + + return yhat class CorrelationTrainer(RegressionTrainer): @@ -34,16 +289,31 @@ class CorrelationTrainer(RegressionTrainer): Trains the contextualized.regression correlation lightning_modules """ - def predict_correlation(self, model, dataloader): + @torch.no_grad() + def predict_correlation(self, model: pl.LightningModule, dataloader) -> np.ndarray: """ Returns context-specific correlation networks containing Pearson's correlation coefficient - correlation (numpy.ndarray): (n, x_dim, x_dim) """ - betas, _ = super().predict_params(model, dataloader) + preds = super().predict(model, dataloader) + preds_flat = _flatten_pl_predict_output(preds) + + if preds_flat and ("correlations" in preds_flat[0]): + payload = _gather_predict_payload(preds, keys=("orig_idx", "correlations")) + if payload is None: + return None + if "correlations" not in payload: + raise RuntimeError( + "predict_correlation: predict_step returned no 'correlations'." + ) + return payload["correlations"] + + betas, _ = self.predict_params(model, dataloader) + if betas is None: + return None + signs = np.sign(betas) - signs[signs != np.transpose(signs, (0, 2, 1))] = ( - 0 # remove asymmetric estimations - ) + signs[signs != np.transpose(signs, (0, 2, 1))] = 0 correlations = signs * np.sqrt(np.abs(betas * np.transpose(betas, (0, 2, 1)))) return correlations @@ -53,13 +323,35 @@ class MarkovTrainer(CorrelationTrainer): Trains the contextualized.regression markov graph lightning_modules """ - def predict_precision(self, model, dataloader): + @torch.no_grad() + def predict_precision(self, model: pl.LightningModule, dataloader) -> np.ndarray: """ Returns context-specific precision matrix under a Gaussian graphical model Assuming all diagonal precisions are equal and constant over context, this is equivalent to the negative of the multivariate regression coefficient. - precision (numpy.ndarray): (n, x_dim, x_dim) """ - # A trick in the markov lightning_module predict_step makes makes the predict_correlation - # output equivalent to negative precision values here. return -super().predict_correlation(model, dataloader) + + +def choose_lightning_environment() -> LightningEnvironment: + """ + Returns the Lightning environment plugin used for single-process runs. + """ + return LightningEnvironment() + + +def make_trainer_with_env(trainer_cls, **trainer_kwargs): + """ + Factory that respects caller-provided `devices` and `strategy`. + Does not inject LightningEnvironment when torchrun is managing processes. 
+ """ + import os + + world_size = int(os.environ.get("WORLD_SIZE", "1")) + + if "plugins" not in trainer_kwargs and world_size == 1: + env = choose_lightning_environment() + trainer_kwargs["plugins"] = [env] + + return trainer_cls(**trainer_kwargs) diff --git a/contextualized/tests.py b/contextualized/tests.py index 1e4e8718..0bff28d0 100644 --- a/contextualized/tests.py +++ b/contextualized/tests.py @@ -203,4 +203,4 @@ def test_save_load(self): if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file diff --git a/contextualized/utils/__init__.py b/contextualized/utils/__init__.py new file mode 100644 index 00000000..00e7196b --- /dev/null +++ b/contextualized/utils/__init__.py @@ -0,0 +1,47 @@ +""" +Utility functions and simple helper predictors used across the library, +including saving/loading of contextualized models. +""" + +from __future__ import annotations + +import torch + + +def save(model, path: str) -> None: + """Save a model object to disk.""" + with open(path, "wb") as out_file: + torch.save(model, out_file) + + +def load(path: str): + """Load a model object from disk.""" + with open(path, "rb") as in_file: + # Newer torch supports weights_only; older versions do not. + try: + return torch.load(in_file, weights_only=False) + except TypeError: + return torch.load(in_file) + + +class DummyParamPredictor: + """Predicts parameters as all zeros (for unit tests / baselines).""" + + def __init__(self, beta_dim, mu_dim): + self.beta_dim = beta_dim + self.mu_dim = mu_dim + + def predict_params(self, *args): + n = len(args[0]) + return torch.zeros((n, *self.beta_dim)), torch.zeros((n, *self.mu_dim)) + + +class DummyYPredictor: + """Predicts Y values as all zeros (for unit tests / baselines).""" + + def __init__(self, y_dim): + self.y_dim = y_dim + + def predict_y(self, *args): + n = len(args[0]) + return torch.zeros((n, *self.y_dim))