diff --git a/contextualized/easy/ContextualizedNetworks.py b/contextualized/easy/ContextualizedNetworks.py
index a989d46..604da99 100644
--- a/contextualized/easy/ContextualizedNetworks.py
+++ b/contextualized/easy/ContextualizedNetworks.py
@@ -107,7 +107,7 @@ def __init__(self, **kwargs):
         )
 
     def predict_correlation(
-        self, C: np.ndarray, individual_preds: bool = True, squared: bool = True
+        self, C: np.ndarray, individual_preds: bool = False, squared: bool = True
     ) -> Union[np.ndarray, List[np.ndarray]]:
         """Predicts context-specific correlations between features.
 
@@ -182,7 +182,7 @@ def __init__(self, **kwargs):
         super().__init__(ContextualizedMarkovGraph, [], [], MarkovTrainer, **kwargs)
 
     def predict_precisions(
-        self, C: np.ndarray, individual_preds: bool = True
+        self, C: np.ndarray, individual_preds: bool = False
     ) -> Union[np.ndarray, List[np.ndarray]]:
         """Predicts context-specific precision matrices.
         Can be converted to context-specific Markov networks by binarizing the networks and setting all non-zero entries to 1.
@@ -434,6 +434,98 @@ def predict_networks(
         )
         return betas
 
+    def _reconstruct_from_betas(
+        self, betas: np.ndarray, X_arr: np.ndarray
+    ) -> np.ndarray:
+
+        """Reconstructs features from predicted betas.
+
+        Args:
+            betas (np.ndarray): Coefficient matrices, shape (F, F) or (N, F, F).
+            X_arr (np.ndarray): Input data, shape (N, F).
+
+        Returns:
+            np.ndarray: Reconstructed data, shape (N, F).
+        """
+
+        n_samples, n_features = X_arr.shape
+
+        B = np.array(betas, copy=True)
+        if B.ndim == 2:
+            B = np.broadcast_to(
+                B[None, :, :], (n_samples, n_features, n_features)
+            ).copy()
+        elif B.ndim != 3:
+            raise ValueError(f"Expected betas 2D or 3D, got shape {B.shape}")
+
+        # Zero the diagonal so no feature is reconstructed from itself.
+        idx = np.arange(n_features)
+        B[:, idx, idx] = 0.0
+
+        X_hat = dag_pred_np(X_arr, B)
+        return X_hat
+
+    def predict(
+        self,
+        C: np.ndarray,
+        X: np.ndarray,
+        project_to_dag: bool = True,
+        individual_preds: bool = False,
+        **kwargs,
+    ) -> np.ndarray:
+
+        """Predicts reconstructed data from context and features.
+
+        Args:
+            C (np.ndarray): Contextual features, shape (N, K).
+            X (np.ndarray): Input data, shape (N, F).
+            project_to_dag (bool, optional): If True, enforce DAG structure. Defaults to True.
+            individual_preds (bool, optional): If True, return per-bootstrap predictions. Defaults to False.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            np.ndarray: Reconstructed predictions, shape (N, F), or (B, N, F) if individual_preds is True.
+        """
+        X_scaled = self._maybe_scale_X(X)
+
+        betas = self.predict_networks(
+            C,
+            project_to_dag=project_to_dag,
+            individual_preds=individual_preds,
+            **kwargs,
+        )
+
+        # Iterate uniformly over bootstraps: a list of (N, F, F) arrays or a stacked (B, N, F, F) array.
+        is_bootstrap_stack = isinstance(betas, np.ndarray) and betas.ndim == 4
+        if isinstance(betas, list) or is_bootstrap_stack:
+            if is_bootstrap_stack:
+                betas_iter = (betas[k] for k in range(betas.shape[0]))
+            else:
+                betas_iter = betas
+
+            reconstructions = [
+                self._reconstruct_from_betas(b, X_scaled) for b in betas_iter
+            ]
+            recon_stack = np.stack(reconstructions, axis=0)  # (B, N, F)
+
+            if self.normalize and self.scalers["X"] is not None:
+                recon_stack = np.stack(
+                    [
+                        self.scalers["X"].inverse_transform(recon_stack[k])
+                        for k in range(recon_stack.shape[0])
+                    ],
+                    axis=0,
+                )
+
+            if individual_preds:
+                return recon_stack  # (B, N, F)
+            return self._nanrobust_mean(recon_stack, axis=0)  # (N, F)
+
+        reconstructed_scaled = self._reconstruct_from_betas(betas, X_scaled)
+        if self.normalize and self.scalers["X"] is not None:
+            return self.scalers["X"].inverse_transform(reconstructed_scaled)
+        return reconstructed_scaled
+
     def measure_mses(
         self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs
     ) -> Union[np.ndarray, List[np.ndarray]]:
diff --git a/contextualized/easy/tests.py b/contextualized/easy/tests.py
index b22339d..f7f54c4 100644
--- a/contextualized/easy/tests.py
+++ b/contextualized/easy/tests.py
@@ -113,11 +113,13 @@ def test_correlation(self):
             encoder_type="ngam", num_archetypes=16
         )
         self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3)
-        rho = model.predict_correlation(self.C, squared=False)
+        rho = model.predict_correlation(self.C, individual_preds=True, squared=False)
         assert rho.shape == (1, self.n_samples, self.x_dim, self.x_dim)
         rho = model.predict_correlation(self.C, individual_preds=False, squared=False)
         assert rho.shape == (self.n_samples, self.x_dim, self.x_dim), rho.shape
-        rho_squared = model.predict_correlation(self.C, squared=True)
+        rho_squared = model.predict_correlation(
+            self.C, individual_preds=True, squared=True
+        )
         assert np.min(rho_squared) >= 0
         assert rho_squared.shape == (1, self.n_samples, self.x_dim, self.x_dim)
 