diff --git a/conceptarium/conf/loss/joint_nll.yaml b/conceptarium/conf/loss/joint_nll.yaml new file mode 100644 index 00000000..c4df0879 --- /dev/null +++ b/conceptarium/conf/loss/joint_nll.yaml @@ -0,0 +1 @@ +_target_: "torch_concepts.nn.JointNLLLoss" diff --git a/examples/utilization/1_pgm/4_variable_elimination_inference.py b/examples/utilization/1_pgm/4_variable_elimination_inference.py new file mode 100644 index 00000000..e4e63749 --- /dev/null +++ b/examples/utilization/1_pgm/4_variable_elimination_inference.py @@ -0,0 +1,253 @@ +""" +Example: Differentiable Variable Elimination for Exact Inference +================================================================ + +Demonstrates training a concept-based Bayesian Network using differentiable +Variable Elimination (VE) and subsequently using VE for exact conditional +queries at test time. + +Scenario — Job Offer Model +--------------------------- +:: + + [Economy] [Talent] + \\ / + [Studies] + | + [JobOffer] + +All variables are binary (Bernoulli). A ``ToyDAGDataset`` generates +samples from the ground-truth BN and produces autoencoder embeddings. +The model takes each sample's embedding as input and predicts the +concept values through input-conditioned CPDs, trained by maximising +the log-likelihood via differentiable VE. + +Training +-------- +1. Each CPD's neural network takes the input embedding (concatenated + with parent-state features for child nodes) and outputs logits. +2. VE multiplies the per-sample factors to compute the per-sample + joint distribution P(economy, talent, studies, job_offer | x). +3. NLL loss is the negative log of the joint entry corresponding to + the observed concept values. +4. Gradients flow through VE back to the CPD network weights. 
+ +Test-Time Queries +----------------- +Use VE (without input conditioning) to compute exact conditional +distributions such as: + +- P(studies) — marginal probability +- P(studies | economy=1) — forward query +- P(economy | job_offer=1) — explaining away +""" + +import torch +import numpy as np +from torch.distributions import Bernoulli + +from torch_concepts import ConceptVariable +from torch_concepts.nn import ParametricCPD, ProbabilisticModel, VariableEliminationInference +from torch_concepts.data.datasets.categorical_toy_dag import ToyDAGDataset + +# ── Ground truth CPTs ──────────────────────────────────────────────── +GT_P_ECONOMY = 0.7 +GT_P_TALENT = 0.6 +GT_P_STUDIES = {(0, 0): 0.1, (0, 1): 0.4, (1, 0): 0.5, (1, 1): 0.9} +GT_P_JOB = {0: 0.2, 1: 0.8} + +NODE_NAMES = ["economy", "talent", "studies", "job_offer"] +COL = {name: i for i, name in enumerate(NODE_NAMES)} + +N_SAMPLES = 5000 +N_EPOCHS = 2000 +LR = 0.05 +EMB_DIM = 8 + + +def main(): + # ── 1. Generate dataset via ToyDAGDataset ──────────────────────── + print("Generating data via ToyDAGDataset ...") + + cpt_studies = np.zeros((2, 2, 2)) + for (e, t), p in GT_P_STUDIES.items(): + cpt_studies[1, e, t] = p + cpt_studies[0, e, t] = 1.0 - p + + cpt_job = np.zeros((2, 2)) + for s, p in GT_P_JOB.items(): + cpt_job[1, s] = p + cpt_job[0, s] = 1.0 - p + + dataset = ToyDAGDataset( + variables=NODE_NAMES, + cardinalities={n: 2 for n in NODE_NAMES}, + dag=[("economy", "studies"), ("talent", "studies"), + ("studies", "job_offer")], + conditional_probs={("studies",): cpt_studies, + ("studies", "job_offer"): cpt_job}, + root_priors={"economy": np.array([1 - GT_P_ECONOMY, GT_P_ECONOMY]), + "talent": np.array([1 - GT_P_TALENT, GT_P_TALENT])}, + seed=42, + n_gen=N_SAMPLES, + autoencoder_kwargs={"latent_dim": EMB_DIM, "epochs": 200}, + ) + + # Extract embeddings and concept labels + embeddings = dataset.input_data # (N, EMB_DIM) + concepts = dataset.concepts # (N, n_concepts) + data = concepts.float() # concept 
labels as float + + print(f"\nDataset: {N_SAMPLES} samples, embedding dim = {EMB_DIM}") + print(f"Empirical frequencies: " + f"economy={data[:, 0].mean():.3f} " + f"talent={data[:, 1].mean():.3f} " + f"studies={data[:, 2].mean():.3f} " + f"job_offer={data[:, 3].mean():.3f}") + + # ── 2. Build model (input-conditioned CPDs) ────────────────────── + print("\nBuilding model ...") + model = ProbabilisticModel( + variables=[ConceptVariable("economy", distribution=Bernoulli), + ConceptVariable("talent", distribution=Bernoulli), + ConceptVariable("studies", distribution=Bernoulli), + ConceptVariable("job_offer", distribution=Bernoulli)], + factors=[ + ParametricCPD("economy", + parametrization=torch.nn.Linear(EMB_DIM, 1)), + ParametricCPD("talent", + parametrization=torch.nn.Linear(EMB_DIM, 1)), + ParametricCPD("studies", + parametrization=torch.nn.Linear(2, 1), + parents=["economy", "talent"]), + ParametricCPD("job_offer", + parametrization=torch.nn.Linear(1, 1), + parents=["studies"]), + ], + ) + + # ── 3. Train via VE with input embeddings ──────────────────────── + print(f"\nTraining via differentiable VE ({N_EPOCHS} epochs) ...") + model.train() + ve = VariableEliminationInference(model) + optimizer = torch.optim.Adam(model.parameters(), lr=LR) + idx = data.long() + + for epoch in range(N_EPOCHS): + optimizer.zero_grad() + out = ve.query(query=NODE_NAMES, evidence={'input': embeddings}, + return_log_joint=True) + log_joint = out['log_joint'] # (N, 2, 2, 2, 2) + # Index each sample's observed state + sample_idx = torch.arange(idx.size(0)) + loss = -log_joint[sample_idx, idx[:, 0], idx[:, 1], + idx[:, 2], idx[:, 3]].mean() + loss.backward() + optimizer.step() + if epoch % 200 == 0 or epoch == N_EPOCHS - 1: + print(f" Epoch {epoch:4d} NLL = {loss.item():.4f}") + + # ── 4. 
VE queries (averaged over embeddings) vs empirical ───────── + model.eval() + + def empirical_cond(query_col, query_val, evidence): + mask = torch.ones(data.size(0), dtype=torch.bool) + for col, val in evidence.items(): + mask &= data[:, col] == val + subset = data[mask] + if subset.size(0) == 0: + return float('nan') + return (subset[:, query_col] == query_val).float().mean().item() + + def empirical_joint_cond(query_cols, query_vals, evidence): + mask = torch.ones(data.size(0), dtype=torch.bool) + for col, val in evidence.items(): + mask &= data[:, col] == val + subset = data[mask] + if subset.size(0) == 0: + return float('nan') + match = torch.ones(subset.size(0), dtype=torch.bool) + for c, v in zip(query_cols, query_vals): + match &= subset[:, c] == v + return match.float().mean().item() + + print("\n" + "=" * 60) + print("VE Queries (averaged over embeddings) vs Empirical") + print("=" * 60) + + def ve_query_avg(query_vars, evidence): + """Run batched VE and average P(query|x) over matching embeddings.""" + if evidence: + mask = torch.ones(data.size(0), dtype=torch.bool) + for k, v in evidence.items(): + mask &= data[:, COL[k]] == v + embs = embeddings[mask] + else: + embs = embeddings + ev = dict(evidence) + ev['input'] = embs + # Get per-concept probabilities, average over batch + probs = ve.query(query=query_vars, evidence=ev) # (N, n_features) + return probs.mean(dim=0) # average over batch + + with torch.no_grad(): + # Marginals: E_x[ P(var | x) ] + print("\n--- Marginal probabilities ---") + print(f" {'query':<45s} {'VE':>8s} {'Empirical':>9s}") + for var in NODE_NAMES: + avg = ve_query_avg([var], {}) + ve_p = avg.item() + emp_p = empirical_cond(COL[var], 1, {}) + print(f" P({var}=1){'':<35s} {ve_p:8.4f} {emp_p:9.4f}") + + # Forward queries + print("\n--- Forward queries ---") + print(f" {'query':<45s} {'VE':>8s} {'Empirical':>9s}") + for qvar, ev in [("studies", {"economy": 1}), + ("studies", {"economy": 1, "talent": 1}), + ("job_offer", {"studies": 
1})]: + avg = ve_query_avg([qvar], ev) + ve_p = avg.item() + emp_p = empirical_cond(COL[qvar], 1, + {COL[k]: v for k, v in ev.items()}) + ev_str = ", ".join(f"{k}={v}" for k, v in ev.items()) + label = f"P({qvar}=1 | {ev_str})" + print(f" {label:<45s} {ve_p:8.4f} {emp_p:9.4f}") + + # Explaining away + print("\n--- Explaining-away queries ---") + print(f" {'query':<45s} {'VE':>8s} {'Empirical':>9s}") + for qvar, ev in [("economy", {"job_offer": 1}), + ("talent", {"job_offer": 1}), + ("economy", {"job_offer": 1, "talent": 1})]: + avg = ve_query_avg([qvar], ev) + ve_p = avg.item() + emp_p = empirical_cond(COL[qvar], 1, + {COL[k]: v for k, v in ev.items()}) + ev_str = ", ".join(f"{k}={v}" for k, v in ev.items()) + label = f"P({qvar}=1 | {ev_str})" + print(f" {label:<45s} {ve_p:8.4f} {emp_p:9.4f}") + + # Joint conditional — use return_log_joint for multi-variable joint + print("\n--- Joint conditional queries ---") + print(f" {'query':<45s} {'VE':>8s} {'Empirical':>9s}") + ev = {"job_offer": 1} + mask = data[:, COL["job_offer"]] == 1 + embs = embeddings[mask] + out = ve.query(query=["economy", "talent"], + evidence={'input': embs, "job_offer": 1}, + return_log_joint=True) + # Exponentiate log-joint to get P(economy, talent | job_offer=1) + avg = torch.exp(out['log_joint']).mean(dim=0) # (2, 2) + emp_ev = {COL["job_offer"]: 1} + for e in range(2): + for t in range(2): + ve_p = avg[e, t].item() + emp_p = empirical_joint_cond( + [COL["economy"], COL["talent"]], [e, t], emp_ev) + label = f"P(economy={e}, talent={t} | job_offer=1)" + print(f" {label:<45s} {ve_p:8.4f} {emp_p:9.4f}") + + +if __name__ == "__main__": + main() diff --git a/tests/nn/modules/low/inference/test_intervention.py b/tests/nn/modules/low/inference/test_intervention.py index d398dfa3..d5a8c31a 100644 --- a/tests/nn/modules/low/inference/test_intervention.py +++ b/tests/nn/modules/low/inference/test_intervention.py @@ -5,7 +5,7 @@ _set_submodule, _as_list, ) -from torch_concepts.nn.modules.mid.models.cpd 
import ParametricCPD +from torch_concepts.nn.modules.mid.models.parametric_cpd import ParametricCPD from torch_concepts.nn.modules.low.inference.intervention import ( _GlobalPolicyState, ) diff --git a/tests/nn/modules/mid/inference/test_detach_concepts.py b/tests/nn/modules/mid/inference/test_detach_concepts.py index ae0fa835..dec6e742 100644 --- a/tests/nn/modules/mid/inference/test_detach_concepts.py +++ b/tests/nn/modules/mid/inference/test_detach_concepts.py @@ -17,7 +17,7 @@ from torch_concepts import LatentVariable, ConceptVariable, ExogenousVariable from torch_concepts.distributions import Delta from torch_concepts.nn import DeterministicInference -from torch_concepts.nn.modules.mid.models.cpd import ParametricCPD +from torch_concepts.nn.modules.mid.models.parametric_cpd import ParametricCPD from torch_concepts.nn.modules.mid.models.probabilistic_model import ProbabilisticModel from torch_concepts.nn.modules.low.predictors.linear import LinearConceptToConcept diff --git a/tests/nn/modules/mid/inference/test_deterministic.py b/tests/nn/modules/mid/inference/test_deterministic.py index b8588649..a58129b2 100644 --- a/tests/nn/modules/mid/inference/test_deterministic.py +++ b/tests/nn/modules/mid/inference/test_deterministic.py @@ -22,7 +22,7 @@ from torch_concepts.nn.modules.mid.models.variable import ( Variable, ConceptVariable, LatentVariable, ) -from torch_concepts.nn.modules.mid.models.cpd import ParametricCPD +from torch_concepts.nn.modules.mid.models.parametric_cpd import ParametricCPD from torch_concepts.nn.modules.mid.models.probabilistic_model import ProbabilisticModel from torch_concepts.nn.modules.low.predictors.linear import LinearConceptToConcept diff --git a/tests/nn/modules/mid/inference/test_forward.py b/tests/nn/modules/mid/inference/test_forward.py index 738bed3d..29fd5eff 100644 --- a/tests/nn/modules/mid/inference/test_forward.py +++ b/tests/nn/modules/mid/inference/test_forward.py @@ -15,7 +15,7 @@ from torch_concepts.nn import 
AncestralSamplingInference, DeterministicInference, WANDAGraphLearner, GraphModel, LazyConstructor, LinearLatentToExogenous, \ LinearExogenousToConcept, HyperlinearConceptExogenousToConcept from torch_concepts.nn.modules.mid.models.variable import Variable -from torch_concepts.nn.modules.mid.models.cpd import ParametricCPD +from torch_concepts.nn.modules.mid.models.parametric_cpd import ParametricCPD from torch_concepts.nn.modules.mid.models.probabilistic_model import ProbabilisticModel from torch_concepts.nn.modules.mid.inference.forward import ForwardInference from torch_concepts.distributions import Delta diff --git a/tests/nn/modules/mid/inference/test_independent.py b/tests/nn/modules/mid/inference/test_independent.py index 9b22e0ba..3faf3e47 100644 --- a/tests/nn/modules/mid/inference/test_independent.py +++ b/tests/nn/modules/mid/inference/test_independent.py @@ -17,7 +17,7 @@ from torch_concepts import InputVariable, EndogenousVariable, ExogenousVariable from torch_concepts.nn.modules.mid.models.variable import Variable -from torch_concepts.nn.modules.mid.models.cpd import ParametricCPD +from torch_concepts.nn.modules.mid.models.parametric_cpd import ParametricCPD from torch_concepts.nn.modules.mid.models.probabilistic_model import ProbabilisticModel from torch_concepts.nn.modules.mid.inference.independent import IndependentInference from torch_concepts.nn.modules.mid.inference.deterministic import DeterministicInference diff --git a/tests/nn/modules/mid/inference/test_variable_elimination.py b/tests/nn/modules/mid/inference/test_variable_elimination.py new file mode 100644 index 00000000..9a6854f8 --- /dev/null +++ b/tests/nn/modules/mid/inference/test_variable_elimination.py @@ -0,0 +1,474 @@ +""" +Comprehensive tests for VariableEliminationInference. 
+ +Tests cover: +- Unbatched marginal queries (single + all variables) +- Batched queries (input-conditioned CPDs) +- Evidence conditioning (set_evidence path) +- return_logits mode +- return_log_joint mode +- Elimination ordering: min-degree heuristic and user-provided +- Order caching +- ground_truth_to_evidence pass-through +- _factor_to_tensor for binary and categorical variables +- Gradient flow through the full VE pipeline +- Numerical correctness against brute-force enumeration +""" +import pytest +import torch +import torch.nn as nn +from torch.distributions import Bernoulli, Categorical + +from torch_concepts.distributions import Delta +from torch_concepts.nn.modules.mid.inference.variable_elimination import ( + VariableEliminationInference, + _min_degree_order, +) +from torch_concepts.nn.modules.mid.models.variable import ( + ConceptVariable, + LatentVariable, +) +from torch_concepts.nn.modules.mid.models.parametric_cpd import ParametricCPD +from torch_concepts.nn.modules.mid.models.probabilistic_model import ProbabilisticModel +from torch_concepts.nn.modules.mid.models.factor import Factor + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_binary_chain(latent_dim=4): + """Build input -> A -> B, both Bernoulli.""" + input_var = LatentVariable('input', distribution=Delta, size=latent_dim) + var_A = ConceptVariable('A', distribution=Bernoulli, size=1) + var_B = ConceptVariable('B', distribution=Bernoulli, size=1) + + cpd_input = ParametricCPD('input', parametrization=nn.Identity()) + cpd_A = ParametricCPD('A', parametrization=nn.Linear(latent_dim, 1), + parents=['input']) + cpd_B = ParametricCPD('B', parametrization=nn.Linear(1, 1), + parents=['A']) + + pgm = ProbabilisticModel( + variables=[input_var, var_A, var_B], + factors=[cpd_input, cpd_A, cpd_B], + ) + return pgm + + +def _make_diamond(latent_dim=4): + """Build input 
-> A -> C, input -> B -> C (diamond / v-structure).""" + input_var = LatentVariable('input', distribution=Delta, size=latent_dim) + var_A = ConceptVariable('A', distribution=Bernoulli, size=1) + var_B = ConceptVariable('B', distribution=Bernoulli, size=1) + var_C = ConceptVariable('C', distribution=Bernoulli, size=1) + + cpd_input = ParametricCPD('input', parametrization=nn.Identity()) + cpd_A = ParametricCPD('A', parametrization=nn.Linear(latent_dim, 1), + parents=['input']) + cpd_B = ParametricCPD('B', parametrization=nn.Linear(latent_dim, 1), + parents=['input']) + cpd_C = ParametricCPD('C', parametrization=nn.Linear(2, 1), + parents=['A', 'B']) + + pgm = ProbabilisticModel( + variables=[input_var, var_A, var_B, var_C], + factors=[cpd_input, cpd_A, cpd_B, cpd_C], + ) + return pgm + + +def _make_categorical_chain(latent_dim=4, k=3): + """Build input -> A (Categorical K) -> B (Bernoulli).""" + input_var = LatentVariable('input', distribution=Delta, size=latent_dim) + var_A = ConceptVariable('A', distribution=Categorical, size=k) + var_B = ConceptVariable('B', distribution=Bernoulli, size=1) + + cpd_input = ParametricCPD('input', parametrization=nn.Identity()) + cpd_A = ParametricCPD('A', parametrization=nn.Linear(latent_dim, k), + parents=['input']) + cpd_B = ParametricCPD('B', parametrization=nn.Linear(k, 1), + parents=['A']) + + pgm = ProbabilisticModel( + variables=[input_var, var_A, var_B], + factors=[cpd_input, cpd_A, cpd_B], + ) + return pgm + + +# =========================================================================== +# _min_degree_order +# =========================================================================== + +class TestMinDegreeOrder: + """Tests for the min-degree elimination ordering heuristic.""" + + def test_returns_all_variables(self): + cards = {'A': 2, 'B': 2, 'C': 2} + fa = Factor(torch.ones(2, 2), ['A', 'B'], cards) + fb = Factor(torch.ones(2, 2), ['B', 'C'], cards) + order = _min_degree_order([fa, fb], ['A', 'B', 'C']) + assert 
set(order) == {'A', 'B', 'C'} + assert len(order) == 3 + + def test_leaf_eliminated_before_hub(self): + """In a chain A-B-C, the leaves A and C have degree 1 + while B has degree 2 — B should not be eliminated first.""" + cards = {'A': 2, 'B': 2, 'C': 2} + fa = Factor(torch.ones(2, 2), ['A', 'B'], cards) + fb = Factor(torch.ones(2, 2), ['B', 'C'], cards) + order = _min_degree_order([fa, fb], ['A', 'B', 'C']) + # B has the highest degree, so it should not be first + assert order[0] != 'B' + + def test_single_variable(self): + f = Factor(torch.ones(2), ['A'], {'A': 2}) + order = _min_degree_order([f], ['A']) + assert order == ['A'] + + def test_empty_elimination(self): + f = Factor(torch.ones(2), ['A'], {'A': 2}) + order = _min_degree_order([f], []) + assert order == [] + + +# =========================================================================== +# query — basic +# =========================================================================== + +class TestVEQueryUnbatched: + """Unbatched queries (no input tensor).""" + + def test_marginal_single_var_no_input(self): + """Query P(A) from a chain A -> B without input.""" + var_A = ConceptVariable('A', distribution=Bernoulli, size=1) + var_B = ConceptVariable('B', distribution=Bernoulli, size=1) + cpd_A = ParametricCPD('A', parametrization=nn.Linear(1, 1)) + cpd_B = ParametricCPD('B', parametrization=nn.Linear(1, 1), + parents=['A']) + pgm = ProbabilisticModel( + variables=[var_A, var_B], + factors=[cpd_A, cpd_B], + ) + ve = VariableEliminationInference(pgm) + result = ve.query(['A']) + assert result.ndim == 1 # (n_features,) + assert result.shape[-1] == 1 # binary → 1 column + + def test_marginal_probs_sum_to_one(self): + """P(A=0) + P(A=1) = 1 for Bernoulli.""" + var_A = ConceptVariable('A', distribution=Bernoulli, size=1) + cpd_A = ParametricCPD('A', parametrization=nn.Linear(1, 1)) + pgm = ProbabilisticModel(variables=[var_A], factors=[cpd_A]) + ve = VariableEliminationInference(pgm) + result = ve.query(['A']) # 
P(A=1) + p1 = result.item() + assert 0.0 <= p1 <= 1.0 + + def test_conditional_with_evidence(self): + """Query P(B | A=1) from A -> B.""" + var_A = ConceptVariable('A', distribution=Bernoulli, size=1) + var_B = ConceptVariable('B', distribution=Bernoulli, size=1) + cpd_A = ParametricCPD('A', parametrization=nn.Linear(1, 1)) + cpd_B = ParametricCPD('B', parametrization=nn.Linear(1, 1), + parents=['A']) + pgm = ProbabilisticModel( + variables=[var_A, var_B], factors=[cpd_A, cpd_B]) + ve = VariableEliminationInference(pgm) + result = ve.query(['B'], evidence={'A': 1}) + p1 = result.item() + assert 0.0 <= p1 <= 1.0 + + +class TestVEQueryBatched: + """Batched queries with input-conditioned CPDs.""" + + def test_batched_output_shape(self): + pgm = _make_binary_chain() + ve = VariableEliminationInference(pgm) + x = torch.randn(8, 4) + result = ve.query(['A', 'B'], evidence={'input': x}) + assert result.shape == (8, 2) # 2 binary → 2 columns + + def test_batched_probs_valid(self): + pgm = _make_binary_chain() + ve = VariableEliminationInference(pgm) + x = torch.randn(4, 4) + result = ve.query(['A', 'B'], evidence={'input': x}) + assert (result >= 0).all() and (result <= 1).all() + + def test_batched_with_evidence(self): + pgm = _make_diamond() + ve = VariableEliminationInference(pgm) + x = torch.randn(4, 4) + result = ve.query(['C'], evidence={'input': x, 'A': 1}) + assert result.shape == (4, 1) + assert (result >= 0).all() and (result <= 1).all() + + +# =========================================================================== +# query — return modes +# =========================================================================== + +class TestVEReturnModes: + """Test return_logits and return_log_joint.""" + + def test_return_logits_binary(self): + pgm = _make_binary_chain() + ve = VariableEliminationInference(pgm) + x = torch.randn(4, 4) + logits = ve.query(['A', 'B'], evidence={'input': x}, + return_logits=True) + # Logits can be any real number + assert logits.shape == 
(4, 2) + # Applying sigmoid should recover probabilities + probs = torch.sigmoid(logits) + assert (probs >= 0).all() and (probs <= 1).all() + + def test_return_logits_matches_probs(self): + pgm = _make_binary_chain() + ve = VariableEliminationInference(pgm) + x = torch.randn(4, 4) + probs = ve.query(['A', 'B'], evidence={'input': x}) + logits = ve.query(['A', 'B'], evidence={'input': x}, + return_logits=True) + recovered = torch.sigmoid(logits) + torch.testing.assert_close(recovered, probs, atol=1e-5, rtol=1e-5) + + def test_return_log_joint_keys(self): + pgm = _make_binary_chain() + ve = VariableEliminationInference(pgm) + x = torch.randn(4, 4) + out = ve.query(['A', 'B'], evidence={'input': x}, + return_log_joint=True) + assert isinstance(out, dict) + assert 'log_joint' in out + assert 'logits' in out + + def test_return_log_joint_shape(self): + pgm = _make_binary_chain() + ve = VariableEliminationInference(pgm) + x = torch.randn(4, 4) + out = ve.query(['A', 'B'], evidence={'input': x}, + return_log_joint=True) + assert out['log_joint'].shape == (4, 2, 2) # (batch, card_A, card_B) + assert out['logits'].shape == (4, 2) + + def test_log_joint_sums_to_one(self): + pgm = _make_binary_chain() + ve = VariableEliminationInference(pgm) + x = torch.randn(4, 4) + out = ve.query(['A', 'B'], evidence={'input': x}, + return_log_joint=True) + joint = torch.exp(out['log_joint']) + sums = joint.sum(dim=(1, 2)) + torch.testing.assert_close(sums, torch.ones(4), atol=1e-5, rtol=1e-5) + + def test_return_log_joint_takes_precedence(self): + """return_log_joint=True should return dict even if return_logits=True.""" + pgm = _make_binary_chain() + ve = VariableEliminationInference(pgm) + x = torch.randn(4, 4) + out = ve.query(['A', 'B'], evidence={'input': x}, + return_logits=True, return_log_joint=True) + assert isinstance(out, dict) + + +# =========================================================================== +# Categorical variables +# 
=========================================================================== + +class TestVECategorical: + """Tests with categorical variables.""" + + def test_categorical_prob_shape(self): + pgm = _make_categorical_chain(k=3) + ve = VariableEliminationInference(pgm) + x = torch.randn(4, 4) + result = ve.query(['A', 'B'], evidence={'input': x}) + # Categorical(3) → 3 columns + Bernoulli → 1 column = 4 + assert result.shape == (4, 4) + + def test_categorical_probs_sum_to_one(self): + pgm = _make_categorical_chain(k=3) + ve = VariableEliminationInference(pgm) + x = torch.randn(4, 4) + result = ve.query(['A', 'B'], evidence={'input': x}) + cat_probs = result[:, :3] # first 3 columns for A + sums = cat_probs.sum(dim=1) + torch.testing.assert_close(sums, torch.ones(4), atol=1e-5, rtol=1e-5) + + def test_categorical_logits(self): + pgm = _make_categorical_chain(k=3) + ve = VariableEliminationInference(pgm) + x = torch.randn(4, 4) + logits = ve.query(['A', 'B'], evidence={'input': x}, + return_logits=True) + assert logits.shape == (4, 4) + + def test_log_joint_shape_categorical(self): + pgm = _make_categorical_chain(k=3) + ve = VariableEliminationInference(pgm) + x = torch.randn(4, 4) + out = ve.query(['A', 'B'], evidence={'input': x}, + return_log_joint=True) + # A has 3 states, B has 2 states → (batch, 3, 2) + assert out['log_joint'].shape == (4, 3, 2) + + +# =========================================================================== +# Elimination ordering +# =========================================================================== + +class TestEliminationOrdering: + """Test user-provided and cached orderings.""" + + def test_user_provided_order(self): + pgm = _make_diamond() + ve = VariableEliminationInference(pgm, + elimination_order=['A', 'B', 'C']) + x = torch.randn(4, 4) + result = ve.query(['C'], evidence={'input': x}) + assert result.shape == (4, 1) + + def test_order_caching(self): + pgm = _make_binary_chain() + ve = VariableEliminationInference(pgm) + x = 
torch.randn(4, 4) + # First call fills cache + ve.query(['A', 'B'], evidence={'input': x}) + assert len(ve._order_cache) == 1 + # Second call with same query/evidence pattern uses cache + ve.query(['A', 'B'], evidence={'input': x}) + assert len(ve._order_cache) == 1 + + def test_different_queries_separate_cache(self): + pgm = _make_diamond() + ve = VariableEliminationInference(pgm) + x = torch.randn(4, 4) + ve.query(['A'], evidence={'input': x}) + ve.query(['C'], evidence={'input': x}) + assert len(ve._order_cache) == 2 + + +# =========================================================================== +# ground_truth_to_evidence +# =========================================================================== + +class TestGroundTruthToEvidence: + """VE's ground_truth_to_evidence is identity.""" + + def test_passthrough(self): + pgm = _make_binary_chain() + ve = VariableEliminationInference(pgm) + t = torch.tensor([0, 1, 0, 1]) + result = ve.ground_truth_to_evidence(t, cardinality=2) + assert result is t + + +# =========================================================================== +# Gradient flow +# =========================================================================== + +class TestVEGradientFlow: + """Ensure VE is fully differentiable.""" + + def test_grad_through_query(self): + pgm = _make_binary_chain() + ve = VariableEliminationInference(pgm) + x = torch.randn(4, 4) + result = ve.query(['A', 'B'], evidence={'input': x}) + result.sum().backward() + # CPD parameters should have gradients + for p in pgm.parameters(): + if p.requires_grad: + assert p.grad is not None + + def test_grad_through_log_joint(self): + pgm = _make_binary_chain() + ve = VariableEliminationInference(pgm) + x = torch.randn(4, 4) + out = ve.query(['A', 'B'], evidence={'input': x}, + return_log_joint=True) + out['log_joint'].sum().backward() + for p in pgm.parameters(): + if p.requires_grad: + assert p.grad is not None + + def test_grad_through_logits(self): + pgm = _make_binary_chain() 
+ ve = VariableEliminationInference(pgm) + x = torch.randn(4, 4) + logits = ve.query(['A', 'B'], evidence={'input': x}, + return_logits=True) + logits.sum().backward() + for p in pgm.parameters(): + if p.requires_grad: + assert p.grad is not None + + +# =========================================================================== +# Numerical correctness — brute-force comparison +# =========================================================================== + +class TestVENumericalCorrectness: + """Compare VE results against brute-force enumeration.""" + + def _brute_force_joint(self, pgm, x): + """Compute joint by building all factors and multiplying them.""" + factors = pgm.build_factors(input=x) + result = factors[0] + for f in factors[1:]: + result = result.product(f) + _, normalised = result.normalize() + return normalised + + def test_marginal_matches_brute_force(self): + """P(A) from VE should match marginalising B from joint.""" + pgm = _make_binary_chain() + ve = VariableEliminationInference(pgm) + x = torch.randn(4, 4) + + # VE result + p_a_ve = ve.query(['A'], evidence={'input': x}) + + # Brute-force: build joint P(A,B), marginalise B + joint = self._brute_force_joint(pgm, x) + # joint has variables ['A', 'B'], shape (batch, 2, 2) + p_a_bf = joint.values.sum(dim=2)[:, 1] # P(A=1) + + torch.testing.assert_close(p_a_ve.squeeze(), p_a_bf, + atol=1e-5, rtol=1e-5) + + def test_conditional_matches_brute_force(self): + """P(B|A=1) from VE should match slicing and normalising joint.""" + pgm = _make_binary_chain() + ve = VariableEliminationInference(pgm) + x = torch.randn(4, 4) + + # VE + p_b_given_a1 = ve.query(['B'], evidence={'input': x, 'A': 1}) + + # Brute-force: P(A=1, B) / P(A=1) + joint = self._brute_force_joint(pgm, x) + p_a1_b = joint.values[:, 1, :] # (batch, 2) — P(A=1, B=0) and P(A=1, B=1) + p_a1 = p_a1_b.sum(dim=1, keepdim=True) + p_b_given_a1_bf = (p_a1_b / p_a1)[:, 1] # P(B=1 | A=1) + + torch.testing.assert_close(p_b_given_a1.squeeze(), 
p_b_given_a1_bf, + atol=1e-5, rtol=1e-5) + + def test_log_joint_matches_brute_force(self): + """log P(A,B) from return_log_joint should match log of brute-force joint.""" + pgm = _make_binary_chain() + ve = VariableEliminationInference(pgm) + x = torch.randn(4, 4) + + out = ve.query(['A', 'B'], evidence={'input': x}, + return_log_joint=True) + bf = self._brute_force_joint(pgm, x) + expected_log = torch.log(bf.values.clamp(min=1e-10)) + + torch.testing.assert_close(out['log_joint'], expected_log, + atol=1e-5, rtol=1e-5) diff --git a/tests/nn/modules/mid/models/test_cpd.py b/tests/nn/modules/mid/models/test_cpd.py index f4c0f978..86991562 100644 --- a/tests/nn/modules/mid/models/test_cpd.py +++ b/tests/nn/modules/mid/models/test_cpd.py @@ -6,7 +6,7 @@ import torch.nn as nn from torch.distributions import Bernoulli, OneHotCategorical -from torch_concepts.nn.modules.mid.models.cpd import ParametricCPD +from torch_concepts.nn.modules.mid.models.parametric_cpd import ParametricCPD from torch_concepts.nn.modules.mid.models.variable import Variable, ConceptVariable from torch_concepts.distributions import Delta @@ -414,14 +414,6 @@ def test_get_parent_combinations_delta_parent(self): inputs, states = cpd._get_parent_combinations() self.assertIsNotNone(inputs) - def test_build_cpt_without_variable(self): - """Test build_cpt raises error when variable not linked.""" - module = nn.Linear(10, 1) - cpd = ParametricCPD(concepts='concept', parametrization=module) - - with self.assertRaises(RuntimeError): - cpd.build_cpt() - class TestParametricCPDParentCap(unittest.TestCase): """Test _get_parent_combinations caps exponential blowup.""" @@ -446,21 +438,5 @@ def test_within_cap_succeeds(self): self.assertEqual(inputs.shape[0], 8) # 2^3 = 8 -class TestParametricCPDSharedGuard(unittest.TestCase): - """Test shared CPD guard in build_cpt / build_potential.""" - - def test_shared_cpd_build_cpt_raises(self): - """build_cpt should raise NotImplementedError for shared CPDs.""" - cpd = 
ParametricCPD(concepts='A', parametrization=nn.Linear(5, 2), shared=True) - with self.assertRaises(NotImplementedError): - cpd.build_cpt() - - def test_shared_cpd_build_potential_raises(self): - """build_potential should raise NotImplementedError for shared CPDs.""" - cpd = ParametricCPD(concepts='A', parametrization=nn.Linear(5, 2), shared=True) - with self.assertRaises(NotImplementedError): - cpd.build_potential() - - if __name__ == '__main__': unittest.main() diff --git a/tests/nn/modules/mid/models/test_cpd_parent_preservation.py b/tests/nn/modules/mid/models/test_cpd_parent_preservation.py index f06a63fe..ba68faaa 100644 --- a/tests/nn/modules/mid/models/test_cpd_parent_preservation.py +++ b/tests/nn/modules/mid/models/test_cpd_parent_preservation.py @@ -20,7 +20,7 @@ import torch.nn as nn from torch.distributions import Bernoulli, OneHotCategorical from torch_concepts.nn.modules.mid.models.variable import Variable -from torch_concepts.nn.modules.mid.models.cpd import ParametricCPD +from torch_concepts.nn.modules.mid.models.parametric_cpd import ParametricCPD from torch_concepts.nn.modules.mid.models.probabilistic_model import ProbabilisticModel from torch_concepts.nn.modules.low.lazy import LazyConstructor diff --git a/tests/nn/modules/mid/models/test_factor.py b/tests/nn/modules/mid/models/test_factor.py new file mode 100644 index 00000000..2bd81ce0 --- /dev/null +++ b/tests/nn/modules/mid/models/test_factor.py @@ -0,0 +1,323 @@ +""" +Comprehensive tests for the Factor class. 
+ +Tests cover: +- Construction: shape validation, batched vs unbatched, cardinality checks +- product(): scope union, shared variables, einsum, batched × unbatched +- marginalize(): sum-out correctness, error on missing variable +- set_evidence(): slicing, out-of-range errors +- normalize(): partition function, batched normalisation +- _align(): permutation and unsqueezing for broadcast +- __mul__: shorthand for product +- __repr__: string representation +""" +import pytest +import torch + +from torch_concepts.nn.modules.mid.models.factor import Factor, _EINSUM_SUBSCRIPTS + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _cards(*names_and_sizes): + """Build a cardinality dict from (name, size) pairs.""" + return {n: s for n, s in names_and_sizes} + + +# =========================================================================== +# Construction +# =========================================================================== + +class TestFactorConstruction: + """Tests for Factor.__init__.""" + + def test_unbatched_valid(self): + f = Factor(torch.ones(2, 3), ['A', 'B'], _cards(('A', 2), ('B', 3))) + assert f.variables == ['A', 'B'] + assert not f.batched + + def test_batched_valid(self): + f = Factor(torch.ones(4, 2, 3), ['A', 'B'], + _cards(('A', 2), ('B', 3)), batched=True) + assert f.batched + assert f.values.shape == (4, 2, 3) + + def test_wrong_ndim_raises(self): + with pytest.raises(ValueError, match="dimensions"): + Factor(torch.ones(2, 3), ['A'], _cards(('A', 2))) + + def test_wrong_card_raises(self): + with pytest.raises(ValueError, match="cardinality"): + Factor(torch.ones(2, 3), ['A', 'B'], + _cards(('A', 2), ('B', 5))) + + def test_batched_wrong_ndim(self): + with pytest.raises(ValueError, match="dimensions"): + Factor(torch.ones(2, 3), ['A', 'B'], + _cards(('A', 2), ('B', 3)), batched=True) + + def 
test_extra_cardinalities_ignored(self): + cards = _cards(('A', 2), ('B', 3), ('C', 4)) + f = Factor(torch.ones(2, 3), ['A', 'B'], cards) + assert f.variables == ['A', 'B'] + + +# =========================================================================== +# Product +# =========================================================================== + +class TestFactorProduct: + """Tests for Factor.product (and __mul__).""" + + def test_disjoint_scopes(self): + fa = Factor(torch.tensor([0.6, 0.4]), ['A'], _cards(('A', 2), ('B', 3))) + fb = Factor(torch.tensor([0.1, 0.3, 0.6]), ['B'], _cards(('A', 2), ('B', 3))) + p = fa.product(fb) + assert set(p.variables) == {'A', 'B'} + # Values should be the outer product + expected = torch.tensor([0.6, 0.4]).unsqueeze(1) * torch.tensor([0.1, 0.3, 0.6]).unsqueeze(0) + torch.testing.assert_close(p.values, expected) + + def test_shared_scope(self): + cards = _cards(('A', 2), ('B', 2)) + fab = Factor(torch.tensor([[0.3, 0.7], [0.9, 0.1]]), ['A', 'B'], cards) + fb = Factor(torch.tensor([0.4, 0.6]), ['B'], cards) + p = fab.product(fb) + assert p.variables == ['A', 'B'] + expected = torch.tensor([[0.3 * 0.4, 0.7 * 0.6], + [0.9 * 0.4, 0.1 * 0.6]]) + torch.testing.assert_close(p.values, expected) + + def test_mul_shorthand(self): + fa = Factor(torch.tensor([0.5, 0.5]), ['A'], _cards(('A', 2))) + fb = Factor(torch.tensor([0.3, 0.7]), ['A'], _cards(('A', 2))) + p1 = fa.product(fb) + p2 = fa * fb + torch.testing.assert_close(p1.values, p2.values) + + def test_batched_times_unbatched(self): + cards = _cards(('A', 2)) + fa = Factor(torch.rand(8, 2), ['A'], cards, batched=True) + fb = Factor(torch.tensor([0.3, 0.7]), ['A'], cards, batched=False) + p = fa.product(fb) + assert p.batched + assert p.values.shape == (8, 2) + torch.testing.assert_close(p.values, fa.values * fb.values.unsqueeze(0)) + + def test_unbatched_times_batched(self): + cards = _cards(('A', 2)) + fa = Factor(torch.tensor([0.3, 0.7]), ['A'], cards) + fb = Factor(torch.rand(8, 2), 
['A'], cards, batched=True) + p = fa.product(fb) + assert p.batched + assert p.values.shape == (8, 2) + + def test_batched_times_batched(self): + cards = _cards(('A', 2), ('B', 3)) + fa = Factor(torch.rand(4, 2), ['A'], cards, batched=True) + fb = Factor(torch.rand(4, 3), ['B'], cards, batched=True) + p = fa.product(fb) + assert p.batched + assert p.values.shape == (4, 2, 3) + + def test_scope_too_large_raises(self): + """Product whose union scope exceeds the einsum subscript limit.""" + n = len(_EINSUM_SUBSCRIPTS) + # Build two factors whose union scope has n+1 vars (exceeds limit) + vars_a = [f'v{i}' for i in range(n)] + vars_b = [f'v{n}'] # one new var beyond the limit + # Don't put cardinalities so the constructor won't validate sizes + fa = Factor(torch.ones(*([1] * n)), vars_a, {}) + fb = Factor(torch.ones(1), vars_b, {}) + with pytest.raises(ValueError, match="einsum limit"): + fa.product(fb) + + def test_product_shares_cardinalities(self): + cards = _cards(('A', 2), ('B', 3)) + fa = Factor(torch.ones(2), ['A'], cards) + fb = Factor(torch.ones(3), ['B'], cards) + p = fa.product(fb) + assert p.cardinalities is cards + + +# =========================================================================== +# Marginalize +# =========================================================================== + +class TestFactorMarginalize: + """Tests for Factor.marginalize.""" + + def test_sum_out_one_var(self): + cards = _cards(('A', 2), ('B', 3)) + vals = torch.tensor([[1.0, 2.0, 3.0], + [4.0, 5.0, 6.0]]) + f = Factor(vals, ['A', 'B'], cards) + m = f.marginalize('B') + assert m.variables == ['A'] + torch.testing.assert_close(m.values, torch.tensor([6.0, 15.0])) + + def test_sum_out_first_var(self): + cards = _cards(('A', 2), ('B', 3)) + vals = torch.tensor([[1.0, 2.0, 3.0], + [4.0, 5.0, 6.0]]) + f = Factor(vals, ['A', 'B'], cards) + m = f.marginalize('A') + assert m.variables == ['B'] + torch.testing.assert_close(m.values, torch.tensor([5.0, 7.0, 9.0])) + + def 
test_batched_marginalize(self): + cards = _cards(('A', 2), ('B', 3)) + vals = torch.rand(8, 2, 3) + f = Factor(vals, ['A', 'B'], cards, batched=True) + m = f.marginalize('A') + assert m.batched + assert m.values.shape == (8, 3) + torch.testing.assert_close(m.values, vals.sum(dim=1)) + + def test_missing_variable_raises(self): + f = Factor(torch.ones(2), ['A'], _cards(('A', 2))) + with pytest.raises(ValueError, match="not in the factor scope"): + f.marginalize('X') + + +# =========================================================================== +# set_evidence +# =========================================================================== + +class TestFactorSetEvidence: + """Tests for Factor.set_evidence.""" + + def test_slice_first_var(self): + cards = _cards(('A', 2), ('B', 3)) + vals = torch.tensor([[1.0, 2.0, 3.0], + [4.0, 5.0, 6.0]]) + f = Factor(vals, ['A', 'B'], cards) + e = f.set_evidence('A', 0) + assert e.variables == ['B'] + torch.testing.assert_close(e.values, torch.tensor([1.0, 2.0, 3.0])) + + def test_slice_second_var(self): + cards = _cards(('A', 2), ('B', 3)) + vals = torch.tensor([[1.0, 2.0, 3.0], + [4.0, 5.0, 6.0]]) + f = Factor(vals, ['A', 'B'], cards) + e = f.set_evidence('B', 2) + assert e.variables == ['A'] + torch.testing.assert_close(e.values, torch.tensor([3.0, 6.0])) + + def test_batched_evidence(self): + cards = _cards(('A', 2), ('B', 3)) + vals = torch.rand(8, 2, 3) + f = Factor(vals, ['A', 'B'], cards, batched=True) + e = f.set_evidence('A', 1) + assert e.batched + assert e.variables == ['B'] + torch.testing.assert_close(e.values, vals[:, 1, :]) + + def test_missing_variable_raises(self): + f = Factor(torch.ones(2), ['A'], _cards(('A', 2))) + with pytest.raises(ValueError, match="not in the factor scope"): + f.set_evidence('X', 0) + + def test_out_of_range_raises(self): + f = Factor(torch.ones(2), ['A'], _cards(('A', 2))) + with pytest.raises(ValueError, match="out of range"): + f.set_evidence('A', 5) + + def 
test_negative_state_raises(self): + f = Factor(torch.ones(2), ['A'], _cards(('A', 2))) + with pytest.raises(ValueError, match="out of range"): + f.set_evidence('A', -1) + + +# =========================================================================== +# Normalize +# =========================================================================== + +class TestFactorNormalize: + """Tests for Factor.normalize.""" + + def test_unbatched(self): + f = Factor(torch.tensor([2.0, 3.0, 5.0]), ['A'], _cards(('A', 3))) + Z, normed = f.normalize() + assert Z.item() == pytest.approx(10.0) + torch.testing.assert_close(normed.values, torch.tensor([0.2, 0.3, 0.5])) + + def test_batched(self): + vals = torch.tensor([[1.0, 3.0], [2.0, 2.0]]) + f = Factor(vals, ['A'], _cards(('A', 2)), batched=True) + Z, normed = f.normalize() + assert Z.shape == (2,) + torch.testing.assert_close(Z, torch.tensor([4.0, 4.0])) + torch.testing.assert_close(normed.values.sum(dim=1), + torch.tensor([1.0, 1.0])) + + def test_multi_dim_unbatched(self): + vals = torch.ones(2, 3) + f = Factor(vals, ['A', 'B'], _cards(('A', 2), ('B', 3))) + Z, normed = f.normalize() + assert Z.item() == pytest.approx(6.0) + assert normed.values.sum().item() == pytest.approx(1.0) + + def test_multi_dim_batched(self): + vals = torch.ones(4, 2, 3) + f = Factor(vals, ['A', 'B'], _cards(('A', 2), ('B', 3)), + batched=True) + Z, normed = f.normalize() + assert Z.shape == (4,) + for i in range(4): + assert normed.values[i].sum().item() == pytest.approx(1.0) + + +# =========================================================================== +# __repr__ +# =========================================================================== + +class TestFactorRepr: + def test_repr_contains_variables(self): + f = Factor(torch.ones(2, 3), ['A', 'B'], _cards(('A', 2), ('B', 3))) + r = repr(f) + assert 'A' in r and 'B' in r + assert '[2, 3]' in r + + +# =========================================================================== +# Gradient flow +# 
=========================================================================== + +class TestFactorGradientFlow: + """Make sure all operations are differentiable.""" + + def test_product_grad(self): + a = torch.tensor([0.3, 0.7], requires_grad=True) + b = torch.tensor([0.4, 0.6], requires_grad=True) + cards = _cards(('A', 2), ('B', 2)) + fa = Factor(a, ['A'], cards) + fb = Factor(b, ['B'], cards) + p = fa.product(fb) + p.values.sum().backward() + assert a.grad is not None + assert b.grad is not None + + def test_marginalize_grad(self): + v = torch.rand(2, 3, requires_grad=True) + f = Factor(v, ['A', 'B'], _cards(('A', 2), ('B', 3))) + m = f.marginalize('B') + m.values.sum().backward() + assert v.grad is not None + + def test_normalize_grad(self): + v = torch.rand(2, 3, requires_grad=True) + f = Factor(v, ['A', 'B'], _cards(('A', 2), ('B', 3))) + _, normed = f.normalize() + normed.values.sum().backward() + assert v.grad is not None + + def test_set_evidence_grad(self): + v = torch.rand(2, 3, requires_grad=True) + f = Factor(v, ['A', 'B'], _cards(('A', 2), ('B', 3))) + e = f.set_evidence('A', 0) + e.values.sum().backward() + assert v.grad is not None diff --git a/tests/nn/modules/mid/models/test_initialization.py b/tests/nn/modules/mid/models/test_initialization.py index 4191bf2f..776ab791 100644 --- a/tests/nn/modules/mid/models/test_initialization.py +++ b/tests/nn/modules/mid/models/test_initialization.py @@ -16,7 +16,7 @@ ExogenousVariable, LatentVariable, ) -from torch_concepts.nn.modules.mid.models.cpd import ParametricCPD +from torch_concepts.nn.modules.mid.models.parametric_cpd import ParametricCPD from torch_concepts.distributions import Delta diff --git a/tests/nn/modules/mid/models/test_probabilistic_model.py b/tests/nn/modules/mid/models/test_probabilistic_model.py index 363fca1b..26481433 100644 --- a/tests/nn/modules/mid/models/test_probabilistic_model.py +++ b/tests/nn/modules/mid/models/test_probabilistic_model.py @@ -10,8 +10,8 @@ import torch.nn as nn 
from torch.distributions import Bernoulli, OneHotCategorical from torch_concepts.nn.modules.mid.models.variable import Variable -from torch_concepts.nn.modules.mid.models.factor import ParametricFactor -from torch_concepts.nn.modules.mid.models.cpd import ParametricCPD +from torch_concepts.nn.modules.mid.models.parametric_factor import ParametricFactor +from torch_concepts.nn.modules.mid.models.parametric_cpd import ParametricCPD from torch_concepts.distributions import Delta from torch_concepts.nn.modules.mid.models.probabilistic_model import ( ProbabilisticModel, @@ -303,49 +303,6 @@ def test_get_variable_parents_nonexistent(self): model = ProbabilisticModel(variables=[var], factors=[cpd]) self.assertEqual(model.get_variable_parents('Z'), []) - def test_build_cpts_no_parents_delta(self): - """Test build_cpts for Delta variable with no parents.""" - var = Variable(concepts='x', distribution=Delta, size=1) - module = nn.Linear(in_features=2, out_features=1) - cpd = ParametricCPD(concepts='x', parametrization=module) - - model = ProbabilisticModel(variables=[var], factors=[cpd]) - cpts = model.build_cpts() - - self.assertIn('x', cpts) - self.assertIsInstance(cpts['x'], torch.Tensor) - self.assertGreaterEqual(cpts['x'].shape[-1], 1) - - def test_build_potentials_no_parents_delta(self): - """Test build_potentials for Delta variable with no parents.""" - var = Variable(concepts='x', distribution=Delta, size=1) - module = nn.Linear(in_features=2, out_features=1) - cpd = ParametricCPD(concepts='x', parametrization=module) - - model = ProbabilisticModel(variables=[var], factors=[cpd]) - pots = model.build_potentials() - - self.assertIn('x', pots) - self.assertIsInstance(pots['x'], torch.Tensor) - - def test_build_cpts_with_parent_bernoulli(self): - """Test build_cpts with parent-child Bernoulli structure.""" - parent = Variable(concepts='p', distribution=Bernoulli, size=1) - child = Variable(concepts='c', distribution=Bernoulli, size=1) - - parent_cpd = 
ParametricCPD(concepts='p', parametrization=nn.Linear(1, 1)) - child_cpd = ParametricCPD(concepts='c', parametrization=nn.Linear(1, 1), parents=['p']) - - model = ProbabilisticModel( - variables=[parent, child], - factors=[parent_cpd, child_cpd] - ) - - cpts = model.build_cpts() - self.assertIn('c', cpts) - cpt_c = cpts['c'] - self.assertGreaterEqual(cpt_c.shape[1], 1) - def test_get_by_distribution(self): """Test that get_by_distribution works on ProbabilisticModel.""" parent = Variable(concepts='p', distribution=Bernoulli, size=1) @@ -548,44 +505,6 @@ def test_resolve_invalid_parent_type_raises(self): with self.assertRaises(TypeError): ProbabilisticModel(variables=[var], factors=[cpd]) - # --- Line 214: _make_temp_parametric_cpd with non-CPD module --- - - def test_make_temp_cpd_with_plain_module(self): - """_make_temp_parametric_cpd works when passed a plain nn.Module.""" - var = Variable(concepts='A', distribution=Bernoulli, size=1) - cpd = ParametricCPD(concepts='A', parametrization=nn.Linear(10, 1)) - model = ProbabilisticModel(variables=[var], factors=[cpd]) - # Call with a plain nn.Module (not a ParametricCPD) to hit else branch - plain_module = nn.Linear(10, 1) - temp_cpd = model._make_temp_parametric_cpd('A', plain_module) - self.assertIsInstance(temp_cpd, ParametricCPD) - self.assertIs(temp_cpd.variable, var) - self.assertIs(temp_cpd.parametrization, plain_module) - - # --- build_cpts / build_potentials reject shared CPDs --- - - def test_build_cpts_rejects_shared_cpds(self): - """build_cpts raises NotImplementedError for models with shared CPDs.""" - var_a = Variable(concepts='A', distribution=Bernoulli, size=1) - var_b = Variable(concepts='B', distribution=Bernoulli, size=1) - shared_cpd = ParametricCPD( - concepts=['A', 'B'], parametrization=nn.Linear(10, 2), shared=True) - model = ProbabilisticModel( - variables=[var_a, var_b], factors=[shared_cpd]) - with self.assertRaises(NotImplementedError): - model.build_cpts() - - def 
test_build_potentials_rejects_shared_cpds(self): - """build_potentials raises NotImplementedError for models with shared CPDs.""" - var_a = Variable(concepts='A', distribution=Bernoulli, size=1) - var_b = Variable(concepts='B', distribution=Bernoulli, size=1) - shared_cpd = ParametricCPD( - concepts=['A', 'B'], parametrization=nn.Linear(10, 2), shared=True) - model = ProbabilisticModel( - variables=[var_a, var_b], factors=[shared_cpd]) - with self.assertRaises(NotImplementedError): - model.build_potentials() - if __name__ == '__main__': unittest.main() diff --git a/tests/nn/modules/mid/test_shared_cpd.py b/tests/nn/modules/mid/test_shared_cpd.py index 4bcb2dc4..7bb85ed2 100644 --- a/tests/nn/modules/mid/test_shared_cpd.py +++ b/tests/nn/modules/mid/test_shared_cpd.py @@ -17,7 +17,7 @@ LinearConceptToConcept, LinearLatentToConcept, ) -from torch_concepts.nn.modules.mid.models.cpd import ParametricCPD +from torch_concepts.nn.modules.mid.models.parametric_cpd import ParametricCPD from torch_concepts.nn.modules.mid.models.probabilistic_model import ProbabilisticModel diff --git a/tests/nn/modules/test_joint_nll_loss.py b/tests/nn/modules/test_joint_nll_loss.py new file mode 100644 index 00000000..e73c4f3f --- /dev/null +++ b/tests/nn/modules/test_joint_nll_loss.py @@ -0,0 +1,90 @@ +""" +Tests for JointNLLLoss. 
+ +Tests cover: +- Correct NLL indexing into multi-dimensional log-joint +- BCE fallback for 2D (flat logit) input +- Gradient flow through log-joint path +- Gradient flow through BCE fallback path +- Known-value computation check +- Single-concept edge case +""" +import pytest +import torch +from torch import nn + +from torch_concepts.nn.modules.loss import JointNLLLoss + + +class TestJointNLLLossLogJoint: + """Tests for the multi-dim log-joint path (ndim > 2).""" + + def test_known_value(self): + """Hand-computed NLL for a (2, 2, 2) log-joint.""" + # Two binary concepts, batch=2 + joint = torch.tensor([ + [[0.1, 0.2], [0.3, 0.4]], # sample 0 + [[0.05, 0.15], [0.35, 0.45]], # sample 1 + ]) + log_joint = torch.log(joint) + target = torch.tensor([[0, 1], [1, 0]]) # sample0: (0,1), sample1: (1,0) + + loss_fn = JointNLLLoss() + loss = loss_fn(input=log_joint, target=target) + + # sample 0: joint[0, 0, 1] = 0.2 → -log(0.2) + # sample 1: joint[1, 1, 0] = 0.35 → -log(0.35) + expected = -(torch.log(torch.tensor(0.2)) + torch.log(torch.tensor(0.35))) / 2 + torch.testing.assert_close(loss, expected) + + def test_output_is_scalar(self): + log_joint = torch.randn(16, 2, 3) + target = torch.randint(0, 2, (16, 1)).long() + target = torch.cat([target, torch.randint(0, 3, (16, 1)).long()], dim=1) + loss = JointNLLLoss()(input=log_joint, target=target) + assert loss.ndim == 0 + + def test_gradient_flows(self): + log_joint = torch.randn(8, 2, 2, requires_grad=True) + target = torch.randint(0, 2, (8, 2)) + loss = JointNLLLoss()(input=log_joint, target=target) + loss.backward() + assert log_joint.grad is not None + assert log_joint.grad.shape == log_joint.shape + + def test_single_concept(self): + """Edge case: one concept with cardinality 3 → (batch, 3) log-joint + is 2D so falls through to BCE. 
Use 3D (batch, 1, 3) for log-joint.""" + # Reshape to (batch, 1, 3) so it's treated as log-joint + log_joint = torch.log_softmax(torch.randn(4, 1, 3), dim=-1) + target = torch.randint(0, 1, (4, 1)).long() # index into dim of size 1 + # This should still work for the log-joint path + loss = JointNLLLoss()(input=log_joint, target=target) + assert loss.ndim == 0 + + def test_many_concepts(self): + """Batch with 4 binary concepts → shape (batch, 2, 2, 2, 2).""" + log_joint = torch.log_softmax(torch.randn(8, 2, 2, 2, 2).view(8, -1), + dim=-1).view(8, 2, 2, 2, 2) + target = torch.randint(0, 2, (8, 4)) + loss = JointNLLLoss()(input=log_joint, target=target) + assert loss.ndim == 0 + assert torch.isfinite(loss) + + +class TestJointNLLLossBCEFallback: + """Tests for the 2D BCE fallback path.""" + + def test_matches_bce(self): + logits = torch.randn(16, 4) + target = torch.rand(16, 4) + loss_joint = JointNLLLoss()(input=logits, target=target) + loss_bce = nn.BCEWithLogitsLoss()(logits, target) + torch.testing.assert_close(loss_joint, loss_bce) + + def test_gradient_flows_bce(self): + logits = torch.randn(8, 3, requires_grad=True) + target = torch.rand(8, 3) + loss = JointNLLLoss()(input=logits, target=target) + loss.backward() + assert logits.grad is not None diff --git a/torch_concepts/data/datasets/categorical_toy_dag.py b/torch_concepts/data/datasets/categorical_toy_dag.py index 78d831e8..334b4d14 100644 --- a/torch_concepts/data/datasets/categorical_toy_dag.py +++ b/torch_concepts/data/datasets/categorical_toy_dag.py @@ -33,6 +33,7 @@ def __init__( cardinalities: Dict[str, int], dag: List[Tuple[str, str]], conditional_probs: Dict[Union[Tuple[str, str], Tuple[str]], np.ndarray], + root_priors: Optional[Dict[str, np.ndarray]] = None, seed: int = 42 ): """ @@ -49,12 +50,18 @@ def __init__( For a child with multiple parents, use key (child,) with shape (child_cardinality, parent1_cardinality, parent2_cardinality, ...). 
Each CPT should sum to 1.0 along the first (child) dimension. + root_priors: Optional dictionary mapping root variable names to their prior + probability arrays. Each array has length equal to the + variable's cardinality and must sum to 1. + E.g. ``{'v1': np.array([0.3, 0.7])}`` for P(v1=0)=0.3, P(v1=1)=0.7. + Root variables without an entry are sampled uniformly. seed: Random seed for reproducibility """ self.variables = variables self.cardinalities = cardinalities self.dag = dag self.conditional_probs = conditional_probs + self.root_priors = root_priors if root_priors is not None else {} self.seed = seed # Build adjacency structure @@ -102,8 +109,12 @@ def generate_sample(self) -> Dict[str, np.ndarray]: cardinality = self.cardinalities[var] if not self.parents[var]: - # Root node: sample uniformly - value = np.random.randint(0, cardinality) + # Root node: use prior if provided, otherwise uniform + if var in self.root_priors: + probs = np.asarray(self.root_priors[var], dtype=np.float64) + value = np.random.choice(cardinality, p=probs) + else: + value = np.random.randint(0, cardinality) else: # Non-root: sample based on conditional probability parents = self.parents[var] @@ -181,6 +192,9 @@ class ToyDAGDataset(ConceptDataset): dag: List of edges representing the DAG structure as (parent, child) tuples. conditional_probs: Dictionary mapping variables to their conditional probability tables. Format: {(parent, child): array} or {(child,): array for multi-parent} + root_priors: Optional dictionary mapping root variable names to their prior + probability arrays (length = cardinality, must sum to 1). + Root variables without an entry are sampled uniformly. root: Root directory to store/load the dataset. If None, creates local folder. seed: Random seed for data generation and reproducibility. n_gen: Total number of samples to generate. 
@@ -197,6 +211,7 @@ def __init__( cardinalities: Dict[str, int], dag: List[Tuple[str, str]], conditional_probs: Dict[Union[Tuple[str, str], Tuple[str]], Union[np.ndarray, list]], + root_priors: Optional[Dict[str, Union[np.ndarray, list]]] = None, root: str = None, seed: int = 42, n_gen: int = 10000, @@ -210,6 +225,10 @@ def __init__( self.variables = variables self.cardinalities = cardinalities self.dag = dag + self.root_priors = { + k: np.asarray(v, dtype=np.float64) + for k, v in root_priors.items() + } if root_priors is not None else {} self.seed = seed self.n_gen = n_gen self.target_variable = target_variable @@ -382,6 +401,7 @@ def build(self): cardinalities=self.cardinalities, dag=self.dag, conditional_probs=self.conditional_probs, + root_priors=self.root_priors, seed=self.seed ) diff --git a/torch_concepts/nn/__init__.py b/torch_concepts/nn/__init__.py index 59a247c1..bde1f671 100644 --- a/torch_concepts/nn/__init__.py +++ b/torch_concepts/nn/__init__.py @@ -38,7 +38,7 @@ # Loss functions from .modules.loss import ConceptLoss, WeightedConceptLoss, DepthWeightedConceptLoss, \ - L1LogitRegularizer + L1LogitRegularizer, JointNLLLoss # Metrics from .modules.metrics import ConceptMetrics, compute_cace @@ -55,8 +55,8 @@ # Models (mid-level) -from .modules.mid.models.factor import ParametricFactor -from .modules.mid.models.cpd import ParametricCPD +from .modules.mid.models.parametric_factor import ParametricFactor +from .modules.mid.models.parametric_cpd import ParametricCPD from .modules.mid.models.probabilistic_model import ProbabilisticModel from .modules.mid.constructors.bipartite import BipartiteModel from .modules.mid.constructors.graph import GraphModel @@ -67,6 +67,7 @@ DeterministicInference, AncestralSamplingInference, IndependentInference, + VariableEliminationInference, ) # Interventions (low-level) @@ -126,6 +127,7 @@ "WeightedConceptLoss", "DepthWeightedConceptLoss", "L1LogitRegularizer", + "JointNLLLoss", # Metrics "ConceptMetrics", @@ -153,7 +155,7 
@@ "DeterministicInference", "AncestralSamplingInference", "IndependentInference", - + "VariableEliminationInference", # Interventions "RewiringIntervention", "GroundTruthIntervention", diff --git a/torch_concepts/nn/modules/loss.py b/torch_concepts/nn/modules/loss.py index 07b79133..1fbb89c7 100644 --- a/torch_concepts/nn/modules/loss.py +++ b/torch_concepts/nn/modules/loss.py @@ -623,4 +623,51 @@ def forward( mask = torch.isfinite(input) if mask.any(): return self.scale * input[mask].abs().mean() - return torch.tensor(0.0, device=input.device) \ No newline at end of file + return torch.tensor(0.0, device=input.device) + + +class JointNLLLoss(nn.Module): + """Negative log-likelihood loss on the joint distribution. + + Designed for inference engines (e.g., variable elimination) that + return a dict with key ``'log_joint'`` containing the log of the + normalised joint distribution over all query variables, shaped + ``(batch, card_1, card_2, ..., card_n)``. + + This loss indexes into the log-joint using the ground-truth state + indices and returns the mean NLL: + + .. math:: + + \\mathcal{L} = -\\frac{1}{B}\\sum_{i=1}^{B} + \\log P(c_1^{(i)}, c_2^{(i)}, \\ldots, c_n^{(i)} \\mid x^{(i)}) + + When the model's eval inference returns a flat ``(batch, n_features)`` + tensor instead of a log-joint, the loss falls back to + ``BCEWithLogitsLoss`` so that validation / test metrics remain valid. + + Args: + **kwargs: Must include ``input`` (log-joint tensor of shape + ``(batch, *cardinalities)`` or flat logits ``(batch, n_features)``) + and ``target`` (integer state indices or float labels of shape + ``(batch, n_concepts)``). + + Returns: + torch.Tensor: Scalar NLL loss. 
+ """ + + def __init__(self): + super().__init__() + self._bce_fallback = nn.BCEWithLogitsLoss() + + def forward(self, **kwargs) -> torch.Tensor: + input = kwargs['input'] + target = kwargs['target'] + + # Flat 2D tensor → fallback to BCE (eval inference path) + if input.ndim == 2: + return self._bce_fallback(input, target.float()) + + # Multi-dim log-joint from inference engine + idx = (torch.arange(input.size(0), device=input.device), *target.long().unbind(1)) + return -input[idx].mean() \ No newline at end of file diff --git a/torch_concepts/nn/modules/low/inference/intervention.py b/torch_concepts/nn/modules/low/inference/intervention.py index d24d3031..07a66deb 100644 --- a/torch_concepts/nn/modules/low/inference/intervention.py +++ b/torch_concepts/nn/modules/low/inference/intervention.py @@ -11,7 +11,7 @@ import torch import torch.nn as nn -from ...mid.models.cpd import ParametricCPD +from ...mid.models.parametric_cpd import ParametricCPD from ..base.inference import BaseIntervention # ---------------- core helpers ---------------- diff --git a/torch_concepts/nn/modules/mid/constructors/graph.py b/torch_concepts/nn/modules/mid/constructors/graph.py index c45672a9..db34e036 100644 --- a/torch_concepts/nn/modules/mid/constructors/graph.py +++ b/torch_concepts/nn/modules/mid/constructors/graph.py @@ -4,7 +4,7 @@ from .....annotations import Annotations from ..models.variable import Variable, LatentVariable, ExogenousVariable, ConceptVariable from .concept_graph import ConceptGraph -from ..models.cpd import ParametricCPD +from ..models.parametric_cpd import ParametricCPD from ..models.probabilistic_model import ProbabilisticModel from .....distributions import Delta from ..base.model import BaseConstructor diff --git a/torch_concepts/nn/modules/mid/inference/__init__.py b/torch_concepts/nn/modules/mid/inference/__init__.py index d3b1eddd..1c334724 100644 --- a/torch_concepts/nn/modules/mid/inference/__init__.py +++ 
b/torch_concepts/nn/modules/mid/inference/__init__.py @@ -2,10 +2,12 @@ from .deterministic import DeterministicInference from .independent import IndependentInference from .ancestral import AncestralSamplingInference +from .variable_elimination import VariableEliminationInference __all__: list[str] = [ "ForwardInference", "DeterministicInference", "AncestralSamplingInference", "IndependentInference", + "VariableEliminationInference", ] diff --git a/torch_concepts/nn/modules/mid/inference/variable_elimination.py b/torch_concepts/nn/modules/mid/inference/variable_elimination.py new file mode 100644 index 00000000..6babcbfe --- /dev/null +++ b/torch_concepts/nn/modules/mid/inference/variable_elimination.py @@ -0,0 +1,417 @@ +""" +Variable Elimination inference for probabilistic graphical models. + +This module implements the Sum-Product Variable Elimination algorithm for +computing exact conditional probabilities in both Bayesian Networks and +Markov Random Fields. All operations are differentiable so that gradients +flow back through the neural-network parameters that produced the factor +potentials — enabling end-to-end training through inference. +""" + +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union + +import torch + +from ..models.factor import Factor +from ..models.probabilistic_model import ProbabilisticModel +from ..models.variable import _BINARY_DISTRIBUTIONS, _CATEGORICAL_DISTRIBUTIONS +from ...low.base.inference import BaseInference + + +# ────────────────────────────────────────────────────────────────────── +# Elimination-ordering heuristics +# ────────────────────────────────────────────────────────────────────── + +def _min_degree_order( + factors: List[Factor], + variables_to_eliminate: List[str], +) -> List[str]: + """ + Compute a greedy min-degree elimination ordering. + + At each step the variable whose current factor-neighbourhood is + smallest (fewest other variables sharing a factor) is chosen. 
+ + Parameters + ---------- + factors : List[Factor] + The current set of factors. + variables_to_eliminate : List[str] + Variables that must be eliminated. + + Returns + ------- + List[str] + The variables in elimination order. + """ + remaining = set(variables_to_eliminate) + order: List[str] = [] + + # Pre-build adjacency: for each variable, the set of other + # variables that share at least one factor with it. + adj: dict = {v: set() for v in remaining} + for f in factors: + scope = [v for v in f.variables if v in remaining] + for v in scope: + for u in scope: + if u != v: + adj[v].add(u) + + for _ in range(len(variables_to_eliminate)): + # Pick the remaining variable with the fewest active neighbours + best = min(remaining, + key=lambda v: sum(1 for u in adj[v] if u in remaining)) + order.append(best) + remaining.remove(best) + + return order + + +# ────────────────────────────────────────────────────────────────────── +# Inference class +# ────────────────────────────────────────────────────────────────────── + +class VariableEliminationInference(BaseInference): + """ + Exact inference via Sum-Product Variable Elimination. + + Supports both Bayesian Networks (all factors are :class:`ParametricCPD`) + and Markov Random Fields (all factors are :class:`ParametricFactor`). + + The ``query`` method: + + 1. Builds discrete :class:`Factor` instances from the neural-network + parametrised CPDs / potentials (differentiable). + 2. Conditions on the evidence by slicing factors. + 3. Eliminates hidden variables in a chosen order. + 4. Normalises the remaining factor to obtain ``P(query | evidence)``. + + All tensor operations (product, marginalisation, slicing, + normalisation) are standard PyTorch ops and therefore + **differentiable** — gradients propagate back through the network + weights that produced the factor values. + + Parameters + ---------- + probabilistic_model : ProbabilisticModel + The graphical model whose factors are queried. 
class VariableEliminationInference(BaseInference):
    """
    Exact inference via Sum-Product Variable Elimination.

    Supports both Bayesian Networks (all factors are :class:`ParametricCPD`)
    and Markov Random Fields (all factors are :class:`ParametricFactor`).

    The ``query`` method:

    1. Builds discrete :class:`Factor` instances from the neural-network
       parametrised CPDs / potentials (differentiable).
    2. Conditions on the evidence by slicing factors.
    3. Eliminates hidden variables in a chosen order.
    4. Normalises the remaining factor to obtain ``P(query | evidence)``.

    All tensor operations (product, marginalisation, slicing,
    normalisation) are standard PyTorch ops and therefore
    **differentiable** — gradients propagate back through the network
    weights that produced the factor values.

    Parameters
    ----------
    probabilistic_model : ProbabilisticModel
        The graphical model whose factors are queried.
    elimination_order : List[str], optional
        A fixed elimination ordering for the hidden variables. If
        ``None`` a greedy min-degree heuristic is used.

    Example
    -------
    >>> ve = VariableEliminationInference(model)
    >>> result = ve.query(query=['B'], evidence={'A': 1})
    >>> # With a per-sample input embedding:
    >>> result = ve.query(query=['B'], evidence={'input': x, 'A': 1})
    """

    def __init__(
        self,
        probabilistic_model: ProbabilisticModel,
        elimination_order: Optional[List[str]] = None,
    ):
        super().__init__()
        self.probabilistic_model = probabilistic_model
        self.elimination_order = elimination_order
        # Elimination orders are cached per (query, evidence) signature —
        # the ordering depends only on which variables are hidden.
        self._order_cache: Dict[Tuple[Tuple[str, ...], Tuple[str, ...]], List[str]] = {}

    # ------------------------------------------------------------------
    # BaseInference interface
    # ------------------------------------------------------------------

    def query(
        self,
        query: List[str],
        evidence: Optional[Dict[str, int]] = None,
        return_logits: bool = False,
        return_log_joint: bool = False,
        **kwargs,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Compute the conditional distribution ``P(query | evidence)``.

        Parameters
        ----------
        query : List[str]
            Names of the query variables.
        evidence : dict, optional
            Mapping from observed variable names to their observed state
            index (0-based). For example ``{'A': 1}`` means A is observed
            in state 1. The special key ``'input'`` may hold the input
            embedding tensor of shape ``(batch, emb_dim)``. When present,
            each factor is conditioned on the input, yielding per-sample
            distributions. The caller's dict is never mutated.
        return_logits : bool, optional
            If ``True``, return unnormalised log-potentials instead of
            normalised probabilities. Useful during training when the
            loss expects log-scale values. Default: ``False``.
        return_log_joint : bool, optional
            If ``True``, return a dict with keys ``'log_joint'`` (the
            log of the normalised joint distribution, shape
            ``(batch, *cardinalities)``) and ``'logits'`` (per-concept
            marginal logits, shape ``(batch, n_features)``). Intended
            for use with :class:`JointNLLLoss`. Takes precedence over
            ``return_logits``.

        Returns
        -------
        torch.Tensor or dict
            When ``return_log_joint=True``: a dict with ``'log_joint'``
            and ``'logits'``. Otherwise a ``(batch, n_features)`` tensor
            (or ``(n_features,)`` when unbatched).
        """
        # Work on a shallow copy: popping 'input' below must not mutate
        # the caller's dict, otherwise a reused evidence dict silently
        # loses its embedding after the first call.
        evidence = dict(evidence) if isinstance(evidence, dict) else {}
        input_tensor = evidence.pop('input', None)

        # 1. Build factors from the parametric model. When input_tensor
        #    is not None, each factor is conditioned on it (batched).
        factors: List[Factor] = self.probabilistic_model.build_factors(
            input=input_tensor)

        # 2. Condition on evidence: replace each factor that mentions an
        #    observed variable by its evidence slice.
        for var_name, state in evidence.items():
            factors = [
                f.set_evidence(var_name, state) if var_name in f.variables else f
                for f in factors
            ]

        # 3. Hidden variables: everything in the factor scopes that is
        #    neither queried nor observed.
        all_vars: set = set()
        for f in factors:
            all_vars.update(f.variables)
        hidden = all_vars - set(query) - set(evidence.keys())

        # 4. Elimination ordering (cached by query/evidence pattern).
        cache_key = (tuple(sorted(query)), tuple(sorted(evidence.keys())))
        if cache_key in self._order_cache:
            elim_order = self._order_cache[cache_key]
        elif self.elimination_order is not None:
            # Filter user-provided order to only include actual hidden vars.
            elim_order = [v for v in self.elimination_order if v in hidden]
            self._order_cache[cache_key] = elim_order
        else:
            elim_order = _min_degree_order(factors, list(hidden))
            self._order_cache[cache_key] = elim_order

        # 5. Sum-Product VE over the hidden variables.
        phi_star = self._sum_product_ve(factors, elim_order)

        # 6. Normalise the remaining factor.
        _Z, normalised = phi_star.normalize()

        # 7. Return as log-joint dict or flat tensor.
        if return_log_joint:
            log_joint = torch.log(normalised.values.clamp(min=1e-10))
            logits = self._factor_to_tensor(normalised, query,
                                            return_logits=True)
            return {'log_joint': log_joint, 'logits': logits}

        return self._factor_to_tensor(normalised, query, return_logits)

    def ground_truth_to_evidence(
        self,
        value: torch.Tensor,
        cardinality: int,
    ) -> torch.Tensor:
        """
        Convert ground-truth labels to state indices.

        For Variable Elimination, evidence is simply integer state
        indices, so the input is returned unchanged.

        Parameters
        ----------
        value : torch.Tensor
            Ground truth tensor with integer state indices.
        cardinality : int
            Number of states for the variable (unused, kept for API
            compatibility).

        Returns
        -------
        torch.Tensor
            The same integer indices.
        """
        return value

    # ------------------------------------------------------------------
    # Factor → Tensor conversion
    # ------------------------------------------------------------------

    def _factor_to_tensor(
        self,
        joint: Factor,
        query: List[str],
        return_logits: bool,
    ) -> torch.Tensor:
        """
        Convert a normalised joint factor into a flat tensor.

        For each query variable the joint is marginalised to a univariate
        distribution and converted to either a logit or a probability
        column:

        * **Binary** (cardinality 2): one column — the logit
          ``log P(v=1) - log P(v=0)`` when *return_logits* is ``True``,
          or ``P(v=1)`` otherwise.
        * **Categorical** (cardinality K): *K* columns — the log-probs
          when *return_logits* is ``True``, else the probabilities.

        Marginals are computed directly on the raw tensor with
        :func:`torch.sum`; no intermediate :class:`Factor` objects are
        created.

        Parameters
        ----------
        joint : Factor
            Normalised factor over the query variables (possibly batched).
        query : List[str]
            Ordered concept names.
        return_logits : bool
            Whether to return logits or probabilities.

        Returns
        -------
        torch.Tensor
            Shape ``(batch, n_features)`` when batched, else
            ``(n_features,)``.
        """
        concept_to_var = self.probabilistic_model.concept_to_variable
        values = joint.values  # (batch?, *var_cards)
        offset = 1 if joint.batched else 0
        eps = 1e-10
        columns: List[torch.Tensor] = []

        for var_name in query:
            # Marginal: sum over all dims except batch + this variable.
            var_dim = joint.variables.index(var_name) + offset
            sum_dims = [d for d in range(offset, values.ndim) if d != var_dim]
            probs = values.sum(dim=sum_dims) if sum_dims else values
            # Re-normalise to absorb floating-point drift.
            probs = probs / probs.sum(dim=-1, keepdim=True)

            var = concept_to_var[var_name]
            if var.distribution in _BINARY_DISTRIBUTIONS:
                if return_logits:
                    col = (torch.log(probs[..., 1].clamp(min=eps))
                           - torch.log(probs[..., 0].clamp(min=eps)))
                else:
                    col = probs[..., 1]
                columns.append(col.unsqueeze(-1))
            elif var.distribution in _CATEGORICAL_DISTRIBUTIONS:
                if return_logits:
                    columns.append(torch.log(probs.clamp(min=eps)))
                else:
                    columns.append(probs)
            else:
                raise NotImplementedError(
                    f"_factor_to_tensor: unsupported distribution "
                    f"{var.distribution.__name__} for variable '{var_name}'."
                )

        return torch.cat(columns, dim=-1)

    # ------------------------------------------------------------------
    # Core VE routines
    # ------------------------------------------------------------------

    def _sum_product_ve(
        self,
        factors: List[Factor],
        elimination_order: List[str],
    ) -> Factor:
        """
        Run Sum-Product Variable Elimination.

        Uses a variable-to-factor index for O(degree) factor lookup per
        elimination step instead of scanning all factors.

        Parameters
        ----------
        factors : List[Factor]
            Initial factor set Φ.
        elimination_order : List[str]
            Ordered list of hidden variables to eliminate.

        Returns
        -------
        Factor
            The product of all remaining factors after elimination (φ*).

        Raises
        ------
        RuntimeError
            If no factors remain after elimination.
        """
        # Build variable → factor-id index.
        var_to_fids: Dict[str, set] = defaultdict(set)
        fid_to_factor: Dict[int, Factor] = {}
        for i, f in enumerate(factors):
            fid_to_factor[i] = f
            for v in f.variables:
                var_to_fids[v].add(i)
        next_fid = len(factors)

        for var in elimination_order:
            # Collect factors mentioning var — O(degree) via index.
            fids = var_to_fids.pop(var, set())
            phi_prime: List[Factor] = []
            for fid in fids:
                f = fid_to_factor.pop(fid, None)
                if f is not None:
                    phi_prime.append(f)
                    # Remove this factor from the other variables' entries.
                    for v in f.variables:
                        if v != var and v in var_to_fids:
                            var_to_fids[v].discard(fid)

            if not phi_prime:
                continue

            # Multiply all factors that contain the variable …
            psi = phi_prime[0]
            for f in phi_prime[1:]:
                psi = psi.product(f)

            # … then marginalise it out.
            tau = psi.marginalize(var)

            # Register the new factor in the index.
            fid_to_factor[next_fid] = tau
            for v in tau.variables:
                var_to_fids[v].add(next_fid)
            next_fid += 1

        # Multiply the remaining factors into φ*.
        remaining = list(fid_to_factor.values())
        if not remaining:
            raise RuntimeError("No factors remain after variable elimination.")

        result = remaining[0]
        for f in remaining[1:]:
            result = result.product(f)

        return result
# Subscript pool for einsum: a-z + A-Y = 51 chars ('Z' is reserved for
# the batch dimension).
_EINSUM_SUBSCRIPTS = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXY'


class Factor:
    """
    A factor (potential function) over a set of named discrete variables.

    Thin wrapper around a multi-dimensional :class:`torch.Tensor` whose
    axes correspond, in order, to the variables in ``variables``. When
    ``batched`` is true the tensor carries an extra leading batch axis and
    every operation is applied sample-wise. All operations are ordinary
    differentiable PyTorch ops, so gradients flow straight through factor
    manipulations back to whatever network produced the values.

    The class is agnostic to whether the underlying model is directed
    (Bayesian Network) or undirected (Markov Random Field).

    Parameters
    ----------
    values : torch.Tensor
        Potential table; axis ``i`` (shifted by one when batched) must
        have size ``cardinalities[variables[i]]``.
    variables : List[str]
        Axis labels, one per (non-batch) tensor dimension.
    cardinalities : Dict[str, int]
        State counts for the model's variables. Entries for variables
        outside this factor's scope are simply ignored.
    batched : bool, optional
        Treat the first tensor dimension as a batch dimension.

    Raises
    ------
    ValueError
        On a rank mismatch between *values* and *variables*, or when an
        axis size disagrees with *cardinalities*.
    """

    def __init__(
        self,
        values: torch.Tensor,
        variables: List[str],
        cardinalities: Dict[str, int],
        batched: bool = False,
    ):
        shift = 1 if batched else 0
        expected_ndim = len(variables) + shift
        if values.ndim != expected_ndim:
            raise ValueError(
                f"Tensor has {values.ndim} dimensions but expected "
                f"{expected_ndim} ({len(variables)} variables"
                f"{' + 1 batch dim' if batched else ''})."
            )
        for idx, name in enumerate(variables):
            # Only validate axes whose variable has a declared cardinality.
            if name in cardinalities and values.shape[idx + shift] != cardinalities[name]:
                raise ValueError(
                    f"Axis {idx + shift} ('{name}') has size "
                    f"{values.shape[idx + shift]} but "
                    f"cardinality is {cardinalities[name]}."
                )

        self.values = values
        self.variables: List[str] = list(variables)
        self.cardinalities: Dict[str, int] = cardinalities
        self.batched: bool = batched

    # ------------------------------------------------------------------
    # Core operations
    # ------------------------------------------------------------------

    def product(self, other: "Factor") -> "Factor":
        """
        Factor product of ``self`` and *other*.

        The result's scope is the union of both scopes, with ``self``'s
        variables first. Shared variables are aligned and multiplied
        element-wise via a single einsum; non-shared variables broadcast.
        If exactly one operand is batched, the other is expanded to match
        its batch size.

        Parameters
        ----------
        other : Factor
            The factor to multiply with.

        Returns
        -------
        Factor
            Element-wise product over the union of scopes.
        """
        union_scope = list(self.variables)
        for v in other.variables:
            if v not in union_scope:
                union_scope.append(v)

        batched_out = self.batched or other.batched

        if len(union_scope) > len(_EINSUM_SUBSCRIPTS):
            raise ValueError(
                f"Factor product scope has {len(union_scope)} variables, "
                f"exceeding the einsum limit of {len(_EINSUM_SUBSCRIPTS)}."
            )
        # One unique subscript per variable; 'Z' stays free for the batch.
        sub = {v: _EINSUM_SUBSCRIPTS[i] for i, v in enumerate(union_scope)}

        spec_a = ''.join(sub[v] for v in self.variables)
        spec_b = ''.join(sub[v] for v in other.variables)
        spec_out = ''.join(sub[v] for v in union_scope)
        lhs_t, rhs_t = self.values, other.values

        if batched_out:
            if self.batched:
                spec_a = 'Z' + spec_a
            if other.batched:
                spec_b = 'Z' + spec_b
            spec_out = 'Z' + spec_out
            # Promote the unbatched operand so einsum sees matching batch
            # dims on both sides.
            if self.batched and not other.batched:
                rhs_t = rhs_t.unsqueeze(0).expand(lhs_t.shape[0], *rhs_t.shape)
                spec_b = 'Z' + spec_b
            elif other.batched and not self.batched:
                lhs_t = lhs_t.unsqueeze(0).expand(rhs_t.shape[0], *lhs_t.shape)
                spec_a = 'Z' + spec_a

        out = torch.einsum(f'{spec_a},{spec_b}->{spec_out}', lhs_t, rhs_t)

        # Reuse the shared cardinality dict when both factors hold the
        # same object; otherwise merge.
        if self.cardinalities is other.cardinalities:
            cards = self.cardinalities
        else:
            cards = {**self.cardinalities, **other.cardinalities}
        return Factor(out, union_scope, cards, batched=batched_out)

    def marginalize(self, variable: str) -> "Factor":
        """
        Sum *variable* out of this factor.

        Parameters
        ----------
        variable : str
            Name of the variable to eliminate.

        Returns
        -------
        Factor
            A new factor whose scope no longer contains *variable*.

        Raises
        ------
        ValueError
            If *variable* is not in this factor's scope.
        """
        if variable not in self.variables:
            raise ValueError(
                f"Variable '{variable}' is not in the factor scope "
                f"{self.variables}."
            )
        axis = self.variables.index(variable) + (1 if self.batched else 0)
        reduced = self.values.sum(dim=axis)
        kept = [v for v in self.variables if v != variable]
        return Factor(reduced, kept, self.cardinalities, batched=self.batched)

    def set_evidence(self, variable: str, state: int) -> "Factor":
        """
        Condition on ``variable = state`` by slicing the tensor.

        The returned factor no longer contains *variable* — its axis is
        removed by the slice.

        Parameters
        ----------
        variable : str
            Name of the observed variable.
        state : int
            Observed state index (0-based).

        Returns
        -------
        Factor
            A new factor with *variable* fixed to *state*.

        Raises
        ------
        ValueError
            If *variable* is not in the scope or *state* is out of range.
        """
        if variable not in self.variables:
            raise ValueError(
                f"Variable '{variable}' is not in the factor scope "
                f"{self.variables}."
            )
        axis = self.variables.index(variable) + (1 if self.batched else 0)
        n_states = self.values.shape[axis]
        if not 0 <= state < n_states:
            raise ValueError(
                f"State {state} is out of range for variable '{variable}' "
                f"with {n_states} states."
            )
        # torch.select drops the indexed dimension, like basic indexing.
        sliced = self.values.select(axis, state)
        kept = [v for v in self.variables if v != variable]
        return Factor(sliced, kept, self.cardinalities, batched=self.batched)

    def normalize(self) -> Tuple[torch.Tensor, "Factor"]:
        """
        Normalise the factor so its values sum to one.

        When batched, each sample in the batch is normalised
        independently.

        Returns
        -------
        Z : torch.Tensor
            The partition function: scalar when unbatched, ``(batch,)``
            when batched.
        normalized : Factor
            A new factor with the same scope whose values sum to 1.
        """
        if not self.batched:
            total = self.values.sum()
            return total, Factor(self.values / total, self.variables,
                                 self.cardinalities, batched=False)

        var_axes = list(range(1, self.values.ndim))
        totals = self.values.sum(dim=var_axes, keepdim=True)
        scaled = self.values / totals
        return totals.reshape(self.values.size(0)), Factor(
            scaled, self.variables, self.cardinalities, batched=True)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _align(self, target_vars: List[str]) -> torch.Tensor:
        """
        Reshape ``self.values`` to broadcast against *target_vars*.

        Variables present in ``self.variables`` are permuted into the
        relative order they hold in *target_vars*; a size-1 axis is
        inserted for each variable this factor does not carry.

        Parameters
        ----------
        target_vars : List[str]
            Desired axis ordering (a superset of ``self.variables``).

        Returns
        -------
        torch.Tensor
            View / expansion of ``self.values`` with
            ``len(target_vars)`` (non-batch) dimensions.
        """
        src = {v: i for i, v in enumerate(self.variables)}
        # Axes of self that appear in target_vars, in target order.
        ordered = [src[v] for v in target_vars if v in src]

        if self.batched:
            # Batch axis stays at position 0; variable axes shift by 1.
            perm = [0] + [p + 1 for p in ordered]
            base = 1
        else:
            perm = ordered
            base = 0

        out = (self.values if perm == list(range(len(perm)))
               else self.values.permute(*perm))
        # Insert singleton axes (ascending positions keep indices valid).
        for pos, v in enumerate(target_vars):
            if v not in src:
                out = out.unsqueeze(pos + base)
        return out

    # ------------------------------------------------------------------
    # Dunder helpers
    # ------------------------------------------------------------------

    def __repr__(self) -> str:
        return (
            f"Factor(variables={self.variables}, "
            f"shape={list(self.values.shape)})"
        )

    def __mul__(self, other: "Factor") -> "Factor":
        """Allow ``f1 * f2`` as shorthand for ``f1.product(f2)``."""
        return self.product(other)
.factor import ParametricFactor -from .variable import Variable -from .....distributions import Delta +from .factor import Factor +from .parametric_factor import ParametricFactor +from .variable import ( + Variable, + _BINARY_DISTRIBUTIONS, + _CATEGORICAL_DISTRIBUTIONS, + _CONTINUOUS_DISTRIBUTIONS, +) class ParametricCPD(ParametricFactor): @@ -160,7 +165,7 @@ def _get_parent_combinations(self) -> Tuple[torch.Tensor, torch.Tensor]: # --- guard against combinatorial explosion --- total_bits = 0 for p in self.parents: - if p.distribution in [Bernoulli, RelaxedBernoulli]: + if p.distribution in _BINARY_DISTRIBUTIONS: total_bits += p.size elif p.distribution in [Categorical, OneHotCategorical, RelaxedOneHotCategorical]: total_bits += p.size # one-hot dims @@ -182,7 +187,7 @@ def _get_parent_combinations(self) -> Tuple[torch.Tensor, torch.Tensor]: input_combinations = [] state_combinations = [] - if parent_var.distribution in [Bernoulli, RelaxedBernoulli]: + if parent_var.distribution in _BINARY_DISTRIBUTIONS: input_combinations = list(product([0.0, 1.0], repeat=out_dim)) state_combinations = input_combinations elif parent_var.distribution in [Categorical, OneHotCategorical, RelaxedOneHotCategorical]: @@ -197,7 +202,7 @@ def _get_parent_combinations(self) -> Tuple[torch.Tensor, torch.Tensor]: discrete_state_vectors_list.append( [torch.tensor(s, dtype=torch.float32).unsqueeze(0) for s in state_combinations]) - elif parent_var.distribution is Delta or parent_var.distribution is torch.distributions.Normal: + elif parent_var.distribution in _CONTINUOUS_DISTRIBUTIONS: fixed_value = torch.zeros(parent_var.size).unsqueeze(0) continuous_tensors.append(fixed_value) else: @@ -227,110 +232,128 @@ def _get_parent_combinations(self) -> Tuple[torch.Tensor, torch.Tensor]: return torch.cat(all_full_inputs, dim=0), torch.cat(all_discrete_state_vectors, dim=0) # ------------------------------------------------------------------ - # CPT / potential-table construction + # Factor 
    # ------------------------------------------------------------------
    # Factor construction (for inference algorithms)
    # ------------------------------------------------------------------

    @staticmethod
    def _variable_cardinality(var: Variable) -> int:
        """
        Return the number of discrete states for *var*.

        Binary distributions always have 2 states; categorical ones use
        ``var.size``. Continuous distributions cannot be tabulated and
        raise instead of returning a cardinality.

        Raises
        ------
        ValueError
            If *var* is continuous (cannot be discretized).
        NotImplementedError
            If the distribution family is unrecognised.
        """
        if var.distribution in _BINARY_DISTRIBUTIONS:
            return 2
        elif var.distribution in _CATEGORICAL_DISTRIBUTIONS:
            return var.size
        elif var.distribution in _CONTINUOUS_DISTRIBUTIONS:
            raise ValueError(
                f"Continuous variable '{var.concept}' "
                f"(distribution={var.distribution.__name__}, size={var.size}) "
                f"cannot be discretized into a factor."
            )
        else:
            raise NotImplementedError(
                f"Cannot determine cardinality for distribution "
                f"{var.distribution.__name__}."
            )

    def build_factor(self, cardinalities: dict = None,
                     input: torch.Tensor = None) -> "Factor":
        """
        Build a :class:`Factor` representing this CPD as a multi-dimensional
        tensor ``P(child | parents)``.

        The tensor has one axis per variable in ``{parents} ∪ {child}``,
        and is filled by evaluating the parametrization over every
        parent-state combination followed by sigmoid (Bernoulli) or
        softmax (Categorical).

        Parameters
        ----------
        cardinalities : dict, optional
            Pre-computed ``{variable_name: num_states}`` mapping. If
            ``None`` the cardinalities are inferred from the
            :class:`Variable` objects.
        input : torch.Tensor, optional
            Input embedding of shape ``(batch, emb_dim)``. Only used for
            root nodes (CPDs with no parents). When provided, the
            embedding is fed to the parametrization to produce per-sample
            factor values (returned :class:`Factor` has ``batched=True``).
            Child nodes ignore this parameter — their factors are
            determined entirely by parent-state combinations.

        Returns
        -------
        Factor
            A factor over ``[*parent_names, child_name]``.
        """
        if cardinalities is None:
            cardinalities = {}

        # --- determine variable names and cardinalities -----------------
        # NOTE(review): a caller-supplied `cardinalities` dict is mutated
        # in place below — presumably intentional so one dict accumulates
        # cardinalities across all CPDs of a model; confirm against
        # ProbabilisticModel.build_factors.
        child_name = self.concept
        child_var = self.variable
        # Raises ValueError via _variable_cardinality for a continuous
        # child unless its cardinality was pre-supplied by the caller.
        child_card = cardinalities.get(
            child_name, self._variable_cardinality(child_var)
        )
        cardinalities[child_name] = child_card

        parent_names = []
        parent_cards = []
        for p in self.parents:
            # Continuous parents (Delta, Normal, …) are held at fixed values
            # during table construction and do not contribute discrete axes.
            if p.distribution in _CONTINUOUS_DISTRIBUTIONS:
                continue
            pname = p.concept
            pcard = cardinalities.get(
                pname, self._variable_cardinality(p)
            )
            cardinalities[pname] = pcard
            parent_names.append(pname)
            parent_cards.append(pcard)

        # --- evaluate the neural network over all parent combinations ---
        if input is not None and not parent_names:
            # Input-conditioned mode for root nodes (no discrete parents):
            # produce per-sample factors.
            B = input.size(0)
            logits = self.parametrization(input).unsqueeze(1)  # (B, 1, out)

            # NOTE(review): continuous children reach the sigmoid branch
            # only when their cardinality was pre-supplied (see above) —
            # verify that treating them as 2-state is intended.
            if child_var.distribution in _BINARY_DISTRIBUTIONS | _CONTINUOUS_DISTRIBUTIONS:
                p1 = torch.sigmoid(logits)
                probs = torch.cat([1.0 - p1, p1], dim=-1)  # (B, 1, 2)
            elif child_var.distribution in _CATEGORICAL_DISTRIBUTIONS:
                probs = torch.softmax(logits, dim=-1)
            else:
                raise NotImplementedError(
                    f"build_factor() not supported for "
                    f"{child_var.distribution.__name__}."
                )

            values = probs.reshape([B, child_card])
            variables = parent_names + [child_name]
            return Factor(values, variables, cardinalities, batched=True)

        # --- non-batched path (for child nodes, or when input is None) --
        all_inputs, _ = self._get_parent_combinations()
        # NOTE(review): called positionally here, while the removed
        # build_cpt() used the keyword form `parametrization(input=...)` —
        # confirm the parametrization modules accept a positional arg.
        logits = self.parametrization(all_inputs)  # (n_combos, out)

        if child_var.distribution in _BINARY_DISTRIBUTIONS | _CONTINUOUS_DISTRIBUTIONS:
            p1 = torch.sigmoid(logits)  # P(child=1 | parents)
            probs = torch.cat([1.0 - p1, p1], dim=-1)  # (n_combos, 2)
        elif child_var.distribution in _CATEGORICAL_DISTRIBUTIONS:
            probs = torch.softmax(logits, dim=-1)  # (n_combos, K)
        else:
            raise NotImplementedError(
                f"build_factor() not supported for "
                f"{child_var.distribution.__name__}."
            )

        # --- reshape into multi-dimensional tensor ----------------------
        # Shape: (*parent_cards, child_card). Row order of
        # _get_parent_combinations must match this axis order — assumed,
        # not visible here.
        if parent_cards:
            shape = parent_cards + [child_card]
        else:
            shape = [child_card]
        values = probs.reshape(shape)

        variables = parent_names + [child_name]
        return Factor(values, variables, cardinalities)
+    parametrization : nn.Module
+        A single neural-network module that computes the factor potential.
+
+    Attributes
+    ----------
+    concept : str
+        Primary concept name (first element of the scope).
+    concepts : List[str]
+        Full scope of concept names.
+    parametrization : nn.Module
+        The neural network module used to compute factor values.
+    variable : Optional[Variable]
+        The :class:`Variable` instance this factor is linked to
+        (set by :class:`ProbabilisticModel` during initialisation).
+
+    See Also
+    --------
+    ParametricCPD : Directed factor for conditional probability distributions.
+    Variable : Represents a random variable (concept) in the model.
+    ProbabilisticModel : Generic container that manages factors and variables.
+    """
+
+    def __init__(self, concepts: Union[str, List[str]],
+                 parametrization: nn.Module,
+                 **kwargs):
+        """
+        Initialize a ParametricFactor instance.
+
+        Parameters
+        ----------
+        concepts : Union[str, List[str]]
+            Single concept name or list of names (first stored as ``self.concept``).
+        parametrization : nn.Module
+            Neural network module for computing factor values.
+        **kwargs
+            Ignored at this level; accepted so that subclass keyword
+            arguments (e.g. ``parents``) pass through ``__new__`` without error.
+        """
+        super().__init__()
+        if isinstance(concepts, str):
+            self.concepts: List[str] = [concepts]
+        else:
+            self.concepts: List[str] = list(concepts)
+        self.concept: str = self.concepts[0]
+        self.parametrization = parametrization
+        self.variable: Optional[Variable] = None
+
+    def forward(self, **kwargs):
+        """
+        Compute the factor output by running the parametrization module.
+
+        Parameters
+        ----------
+        **kwargs
+            Keyword arguments passed to the parametrization module.
+
+        Returns
+        -------
+        torch.Tensor
+            Output of the parametrization module.
+ """ + return self.parametrization(**kwargs) + + def __repr__(self): + scope = self.concepts if len(self.concepts) > 1 else self.concept + return f"{self.__class__.__name__}(concepts={scope!r}, parametrization={self.parametrization.__class__.__name__})" diff --git a/torch_concepts/nn/modules/mid/models/probabilistic_model.py b/torch_concepts/nn/modules/mid/models/probabilistic_model.py index 4c1df738..297054e5 100644 --- a/torch_concepts/nn/modules/mid/models/probabilistic_model.py +++ b/torch_concepts/nn/modules/mid/models/probabilistic_model.py @@ -9,13 +9,15 @@ them without directed-graph semantics. """ +import torch from torch import nn from torch.distributions import Distribution from typing import List, Dict, Optional, Type, Union -from .variable import Variable, ExogenousVariable, ConceptVariable -from .factor import ParametricFactor -from .cpd import ParametricCPD +from .variable import Variable, ExogenousVariable, ConceptVariable, _CONTINUOUS_DISTRIBUTIONS +from .parametric_factor import ParametricFactor +from .parametric_cpd import ParametricCPD +from .factor import Factor # --------------------------------------------------------------------------- @@ -206,49 +208,52 @@ def get_variable_parents(self, concept_name: str) -> List[Variable]: cpd = self.get_module_of_concept(concept_name) return cpd.parents if cpd is not None and hasattr(cpd, 'parents') else [] - # ---- CPT / potential-table helpers (directed models) ------------------- + # ---- Factor construction (for inference algorithms) ---------------- - def _make_temp_parametric_cpd(self, concept: str, module: nn.Module) -> ParametricCPD: - """Create a temporary ParametricCPD for table-building helpers.""" - if isinstance(module, ParametricCPD): - parametrization = module.parametrization - else: - parametrization = module - f = ParametricCPD(concepts=concept, parametrization=parametrization) - f.variable = self.concept_to_variable[concept] - stored = self.factors[str(concept)] if str(concept) in 
self.factors else None - f.parents = stored.parents if stored is not None else [] - return f - - def build_potentials(self): - """Build potential tables for all concepts. - - Raises: - NotImplementedError: If the model contains shared CPDs. + def build_factors(self, cardinalities: dict = None, + input: "torch.Tensor | None" = None) -> List[Factor]: """ - self._reject_shared_cpds("build_potentials") - return { - concept: self._make_temp_parametric_cpd(concept, module).build_potential() - for concept, module in self.factors.items() - } - - def build_cpts(self): - """Build Conditional Probability Tables for all concepts. - - Raises: - NotImplementedError: If the model contains shared CPDs. + Build :class:`Factor` instances for every factor in the model. + + For directed models each :class:`ParametricCPD` produces a factor + over ``{parents ∪ child}``. For undirected models each + :class:`ParametricFactor` produces a factor over its scope. + + Parameters + ---------- + cardinalities : dict, optional + Pre-computed ``{variable_name: num_states}`` mapping. If + ``None`` the cardinalities are inferred from the + :class:`Variable` objects. + input : torch.Tensor, optional + Input embedding of shape ``(batch, emb_dim)``. Passed + through to each :class:`ParametricCPD`'s ``build_factor`` + so that factor values are conditioned on the input. + + Returns + ------- + List[Factor] + One :class:`Factor` per registered factor in the model. """ - self._reject_shared_cpds("build_cpts") - return { - concept: self._make_temp_parametric_cpd(concept, module).build_cpt() - for concept, module in self.factors.items() - } - - def _reject_shared_cpds(self, method_name: str) -> None: - """Raise if any factor is a shared CPD.""" - if self._shared_cpd_map: - raise NotImplementedError( - f"{method_name}() does not support shared CPDs. " - f"Secondary concepts {list(self._shared_cpd_map.keys())} " - f"would be silently omitted." 
- ) + if cardinalities is None: + cardinalities = {} + + factors: List[Factor] = [] + for concept, module in self.factors.items(): + if isinstance(module, ParametricCPD): + # Skip continuous/latent variables — they are embeddings, + # not discrete states that can be represented as factors. + if module.variable.distribution in _CONTINUOUS_DISTRIBUTIONS: + continue + factors.append(module.build_factor(cardinalities, + input=input)) + else: + scope_vars = [ + self.concept_to_variable[c] + for c in module.concepts + if c in self.concept_to_variable + ] + factors.append( + module.build_factor(scope_vars, cardinalities) + ) + return factors diff --git a/torch_concepts/nn/modules/mid/models/variable.py b/torch_concepts/nn/modules/mid/models/variable.py index 28d724ff..aab833f4 100644 --- a/torch_concepts/nn/modules/mid/models/variable.py +++ b/torch_concepts/nn/modules/mid/models/variable.py @@ -14,6 +14,11 @@ from .....distributions import Delta +# Distribution type groups. +_BINARY_DISTRIBUTIONS = {Bernoulli, RelaxedBernoulli} +_CATEGORICAL_DISTRIBUTIONS = {Categorical, OneHotCategorical, RelaxedOneHotCategorical} +_CONTINUOUS_DISTRIBUTIONS = {Normal, MultivariateNormal, Delta} + # Default distributions per concept type group (binary / categorical / continuous). _DEFAULT_DISTRIBUTIONS: Dict[str, Type[Distribution]] = { 'binary': Bernoulli,