From 8a0dddc30175d71d996b85c0c25949126e3c33c6 Mon Sep 17 00:00:00 2001 From: Giovanni De Felice Date: Wed, 29 Apr 2026 00:14:39 +0200 Subject: [PATCH 1/6] chore: open PR From b44f56ec2f005bbec3942de72ed7071132db68e3 Mon Sep 17 00:00:00 2001 From: Giovanni De Felice Date: Fri, 1 May 2026 01:28:42 +0200 Subject: [PATCH 2/6] rewrite deterministic inference with deltas + add 'propagate' parameter to switch between propagating logits or probs Co-authored-by: Copilot --- .../nn/modules/mid/inference/deterministic.py | 195 ++++++++++++------ .../nn/modules/mid/inference/forward.py | 6 +- .../nn/modules/mid/models/variable.py | 16 +- 3 files changed, 147 insertions(+), 70 deletions(-) diff --git a/torch_concepts/nn/modules/mid/inference/deterministic.py b/torch_concepts/nn/modules/mid/inference/deterministic.py index ff3bc6a..e62ed38 100644 --- a/torch_concepts/nn/modules/mid/inference/deterministic.py +++ b/torch_concepts/nn/modules/mid/inference/deterministic.py @@ -1,40 +1,41 @@ """Deterministic inference for probabilistic graphical models.""" +import warnings + import torch +from .....distributions import Delta from .forward import ForwardInference from ..models.variable import Variable +from ..models.variable import _DEFAULT_ACTIVATIONS class DeterministicInference(ForwardInference): """ Deterministic forward inference for probabilistic graphical models. - This inference engine performs deterministic (maximum likelihood) inference by - returning raw logits/outputs from CPDs without sampling. It's useful for - prediction tasks where you want the most likely values rather than samples - from the distribution. - - Inherits all functionality from ForwardInference but implements activate() - to map logits to probabilities without stochastic sampling. + This inference engine propagates raw CPD outputs forward without sampling or + using distribution-specific semantics. 
Because of that, it treats every + variable as deterministic: if the provided probabilistic model contains + variables whose distribution is not :class:`Delta`, the variables are + converted to :class:`Delta` with identity activations during initialization. Example: >>> import torch - >>> from torch.distributions import Bernoulli >>> from torch_concepts import LatentVariable, ConceptVariable >>> from torch_concepts.distributions import Delta >>> from torch_concepts.nn import DeterministicInference, ParametricCPD, ProbabilisticModel, LinearConceptToConcept >>> >>> # Create a simple PGM: latent -> A -> B - >>> input_var = LatentVariable('input', parents=[], distribution=Delta, size=10) - >>> var_A = ConceptVariable('A', parents=['input'], distribution=Bernoulli, size=1) - >>> var_B = ConceptVariable('B', parents=['A'], distribution=Bernoulli, size=1) + >>> input_var = LatentVariable('input', distribution=Delta, size=10) + >>> var_A = ConceptVariable('A', distribution=Delta, size=1) + >>> var_B = ConceptVariable('B', distribution=Delta, size=1) >>> >>> # Define CPDs >>> from torch.nn import Identity, Linear >>> cpd_emb = ParametricCPD('input', parametrization=Identity()) - >>> cpd_A = ParametricCPD('A', parametrization=Linear(10, 1)) - >>> cpd_B = ParametricCPD('B', parametrization=LinearConceptToConcept(1, 1)) + >>> cpd_A = ParametricCPD('A', parametrization=Linear(10, 1), parents=['input']) + >>> cpd_B = ParametricCPD('B', parametrization=LinearConceptToConcept(1, 1), parents=['A']) >>> >>> # Create probabilistic model >>> pgm = ProbabilisticModel( @@ -47,93 +48,159 @@ class DeterministicInference(ForwardInference): >>> >>> # Perform inference - returns logits, not samples >>> x = torch.randn(4, 10) # batch_size=4, latent_size=10 - >>> results = inference.predict({'input': x}) - >>> - >>> # Results contain raw logits for Bernoulli variables - >>> print(results['A'].shape) # torch.Size([4, 1]) - logits, not {0,1} - >>> print(results['B'].shape) # torch.Size([4, 1]) - 
logits, not {0,1} - >>> - >>> # Query specific concepts - returns concatenated logits >>> output = inference.query(['B', 'A'], evidence={'input': x}) - >>> print(output.shape) # torch.Size([4, 2]) - >>> # output contains [logit_B, logit_A] for each sample - >>> - >>> # Convert logits to probabilities if needed - >>> prob_A = torch.sigmoid(results['A']) - >>> print(prob_A.shape) # torch.Size([4, 1]) - >>> - >>> # Get hard predictions (0 or 1) - >>> pred_A = (prob_A > 0.5).float() - >>> print(pred_A) # Binary predictions + >>> print(output.probs.shape) # torch.Size([4, 2]) + >>> # output.probs contains [logit_B, logit_A] for each sample """ + def __init__( + self, + probabilistic_model, + graph_learner=None, + detach: bool = False, + lazy: bool = False, + p: float = 0.0, + propagate: str = 'probs', + *args, + **kwargs, + ): + if not 0.0 <= p <= 1.0: + raise ValueError(f"p must be in [0, 1], got {p}") + if propagate not in ('logits', 'probs'): + raise ValueError(f"propagate must be 'logits' or 'probs', got {propagate!r}") + self.propagate = propagate + self._coerce_variables_to_delta(probabilistic_model) + super().__init__( + probabilistic_model, + graph_learner, + detach, + lazy, + p, + *args, + **kwargs, + ) + + def _coerce_variables_to_delta(self, probabilistic_model) -> None: + var_to_change = [ + var for var in probabilistic_model.variables + if var.distribution is not Delta + ] + + if var_to_change: + non_delta_summary = ", ".join( + f"{var.concept} ({getattr(var.distribution, '__name__', var.distribution)})" + for var in var_to_change + ) + with warnings.catch_warnings(): + warnings.simplefilter("always", UserWarning) + warnings.warn( + "DeterministicInference assumes all variables are Delta() variables. " \ + "All Variables will be changed to Delta(), with activations set to identity. 
" + f"Non-Delta variables: {non_delta_summary}.", + UserWarning, + stacklevel=3, + ) + + for var in probabilistic_model.variables: + if var in var_to_change: + if self.propagate == 'logits': + var.activation = lambda x: x + else: + var.activation = _DEFAULT_ACTIVATIONS.get(var.distribution, lambda x: x) + var.distribution = Delta + var.dist_kwargs = {} + def activate(self, pred: torch.Tensor, variable: Variable) -> torch.Tensor: """ - Map logits to probabilities using the variable's activation. - - The activation function is stored on the :class:`Variable` instance - (defaulting to sigmoid for Bernoulli, softmax for Categorical, - identity for Delta, etc.). Custom activations can be provided when - constructing the variable. + Apply activation function to raw CPD outputs. Args: pred: Prediction tensor (logits). - variable: The Variable whose prediction is being propagated. + variable: The Variable whose prediction is being activated. Returns: - torch.Tensor: Probability tensor. + torch.Tensor: Activated prediction tensor. """ return variable.activation(pred) - def ground_truth_to_evidence(self, value: torch.Tensor, cardinality: int) -> torch.Tensor: + def ground_truth_to_evidence(self, value: torch.Tensor, size: int, type: str) -> torch.Tensor: """ - Convert discrete ground truth to activated probabilities for propagation. + Convert ground truth to tensors used for propagation. - Since the inference engine now owns the activation, propagation values - must be in the same representation as ``activate`` produces: - probabilities for Bernoulli (0.0/1.0) and one-hot for Categorical. + Ground-truth propagation keeps the existing discrete encoding: 0.0/1.0 + for binary variables and one-hot vectors for categorical variables. + Dense tensors with shape ``(batch_size, size)`` are passed + through directly, which supports already-encoded categorical evidence + and continuous variables. - Supports both binary (cardinality=1) and categorical (cardinality>1) - variables. 
DOES NOT SUPPORT CONTINUOUS VARIABLES. + Supports binary (size=1), categorical (size>1), and + dense continuous variables. Parameters ---------- value : torch.Tensor Ground truth tensor. Shape: (batch_size,) or (batch_size, 1). - - Binary (cardinality=1): values in {0, 1} - - Categorical (cardinality>1): class indices - cardinality : int + - Binary (size=1): binary values with shape (batch_size, ) + - Categorical (size>1): class indices with shape (batch_size,) or one-hot vectors with shape (batch_size, size) + - Continuous: values with shape (batch_size, size) + size : int Number of features/classes for this variable. + type : str + Type of the variable ('binary', 'categorical', 'continuous' or 'delta'). Returns ------- torch.Tensor - Probability / one-hot tensor. Shape: (batch_size, cardinality). + Value tensor. Shape: (batch_size, size). """ - # TODO: add support for continuous variables - # Allow (batch,) and unsqueeze to (batch, 1) if value.dim() == 1: value = value.unsqueeze(-1) - if value.dim() != 2 or value.shape[-1] != 1: + if value.dim() != 2: raise ValueError( - f"Expected shape (batch,) or (batch, 1), got {tuple(value.shape)}." + f"Expected shape (batch,), (batch, 1), or " + f"(batch, {size}), got {tuple(value.shape)}." ) - if cardinality == 1: - # Binary: validate values are in {0, 1} - unique_vals = value.unique() - if not all(v in (0, 1) for v in unique_vals.tolist()): - import warnings + width = value.shape[-1] + + if type == 'binary': + if width != 1: + raise ValueError( + f"Expected shape (batch,) or (batch, 1) for binary variable, " + f"got {tuple(value.shape)}." + ) + if not torch.all((value == 0) | (value == 1)): + unique_vals = value.unique() warnings.warn( f"Binary ground truth contains values outside {{0, 1}}: " f"{unique_vals.tolist()}. 
Values will be used as-is.", stacklevel=2, ) - return value.float() - else: - # Categorical: return one-hot probabilities - return torch.nn.functional.one_hot( - value.squeeze(-1).long(), num_classes=cardinality - ).float() + probs = value.float() + return torch.logit(probs, eps=1e-7) if self.propagate == 'logits' else probs + + elif type == 'categorical': + if width == size: + # Already one-hot encoded + one_hot = value.float() + elif width != 1: + raise ValueError( + f"Expected shape (batch,), (batch, 1), or " + f"(batch, {size}) for categorical variable, got {tuple(value.shape)}." + ) + else: + # Class indices → one-hot + one_hot = torch.nn.functional.one_hot( + value.squeeze(-1).long(), num_classes=size + ).float() + return torch.logit(one_hot, eps=1e-7) if self.propagate == 'logits' else one_hot + + else: # 'continuous' or 'delta' + if width == size: + return value.float() + raise ValueError( + f"Expected shape (batch, {size}) for {type} variable, " + f"got {tuple(value.shape)}." + ) diff --git a/torch_concepts/nn/modules/mid/inference/forward.py b/torch_concepts/nn/modules/mid/inference/forward.py index df61716..8e0a17f 100644 --- a/torch_concepts/nn/modules/mid/inference/forward.py +++ b/torch_concepts/nn/modules/mid/inference/forward.py @@ -166,7 +166,7 @@ def activate(self, pred: torch.Tensor, variable: Variable) -> torch.Tensor: Subclass contracts: - * ``DeterministicInference`` — Bernoulli → sigmoid, Categorical → softmax + * ``DeterministicInference`` — apply activation function before propagation * ``AncestralSamplingInference`` — sample from the distribution Args: @@ -783,7 +783,8 @@ def query( idx = index_map[name] gt_value = self.ground_truth_to_evidence( value=ground_truth[:, idx:idx+1], - cardinality=variable.size, + size=variable.size, + type=variable.type ) if self.p >= 1.0: propagation[name] = gt_value @@ -1006,4 +1007,3 @@ def unrolled_probabilistic_model(self) -> ProbabilisticModel: self._unrolled_query_vars = set(v.concept for v in 
new_variables) return ProbabilisticModel(new_variables, new_parametric_cpds) - diff --git a/torch_concepts/nn/modules/mid/models/variable.py b/torch_concepts/nn/modules/mid/models/variable.py index 4d89d2b..def7c3f 100644 --- a/torch_concepts/nn/modules/mid/models/variable.py +++ b/torch_concepts/nn/modules/mid/models/variable.py @@ -23,9 +23,8 @@ RelaxedBernoulli, OneHotCategorical, RelaxedOneHotCategorical, - # TODO: add support for continuous distributions - # Normal, - # MultivariateNormal, + Normal, + MultivariateNormal, Delta ] @@ -60,6 +59,16 @@ Delta: lambda x: x, } +_DEFAULT_TYPES: Dict[Type[Distribution], str] = { + Bernoulli: 'binary', + RelaxedBernoulli: 'binary', + OneHotCategorical: 'categorical', + RelaxedOneHotCategorical: 'categorical', + Normal: 'continuous', + MultivariateNormal: 'continuous', + Delta: 'delta' +} + # Number of raw parameters needed to parameterise each supported distribution # given a variable of a certain *size* (event dimension). _PARAM_DIMS: Dict[Type[Distribution], Dict[str, Callable[[int], int]]] = { @@ -300,6 +309,7 @@ def __init__(self, concepts: Union[str, List[str]], self.size = size self.dist_kwargs = dist_kwargs if dist_kwargs is not None else {} self.metadata = metadata if metadata is not None else {} + self.type = _DEFAULT_TYPES[distribution] if activation is not None: self.activation = activation else: From 22a23ac36bf7d84fb14613a42f64596b4f36ed29 Mon Sep 17 00:00:00 2001 From: Giovanni De Felice Date: Sun, 3 May 2026 10:27:50 +0200 Subject: [PATCH 3/6] minor changes to defaults and cleaning inference configs --- conceptarium/conf/model/_commons.yaml | 12 ++++++++++ conceptarium/conf/model/c2bm.yaml | 10 +-------- conceptarium/conf/model/cbm.yaml | 10 +-------- conceptarium/conf/model/cem.yaml | 10 +-------- conceptarium/conf/sweep.yaml | 13 ++++++----- .../nn/modules/mid/inference/deterministic.py | 22 +++++-------------- .../nn/modules/mid/inference/forward.py | 2 +- 7 files changed, 29 insertions(+), 50
deletions(-) diff --git a/conceptarium/conf/model/_commons.yaml b/conceptarium/conf/model/_commons.yaml index 4570ce2..0ea67d7 100644 --- a/conceptarium/conf/model/_commons.yaml +++ b/conceptarium/conf/model/_commons.yaml @@ -1,6 +1,18 @@ defaults: - _self_ +inference: + _target_: ${model.train_inference._target_} + _partial_: true + # propagate: 'logits' # 'logits'/'probs' (only for DeterministicInference) + +train_inference: + _target_: "torch_concepts.nn.DeterministicInference" + _partial_: true + # propagate: 'logits' # 'logits'/'probs' (only for DeterministicInference) + # detach: false, true + +# Allow for lightning training lightning: true # ============================================================= diff --git a/conceptarium/conf/model/c2bm.yaml b/conceptarium/conf/model/c2bm.yaml index 7786c53..338c289 100644 --- a/conceptarium/conf/model/c2bm.yaml +++ b/conceptarium/conf/model/c2bm.yaml @@ -6,12 +6,4 @@ _target_: "torch_concepts.nn.CausallyReliableConceptBottleneckModel" exogenous_size: 8 hypernet_hidden_size: 8 -hypernet_use_bias: false - -inference: - _target_: "torch_concepts.nn.DeterministicInference" - _partial_: true - -train_inference: - _target_: "torch_concepts.nn.DeterministicInference" - _partial_: true \ No newline at end of file +hypernet_use_bias: false \ No newline at end of file diff --git a/conceptarium/conf/model/cbm.yaml b/conceptarium/conf/model/cbm.yaml index 23dd57e..2292838 100644 --- a/conceptarium/conf/model/cbm.yaml +++ b/conceptarium/conf/model/cbm.yaml @@ -4,12 +4,4 @@ defaults: _target_: "torch_concepts.nn.ConceptBottleneckModel" -task_names: ${dataset.default_task_names} - -inference: - _target_: "torch_concepts.nn.DeterministicInference" - _partial_: true - -train_inference: - _target_: "torch_concepts.nn.DeterministicInference" - _partial_: true \ No newline at end of file +task_names: ${dataset.default_task_names} \ No newline at end of file diff --git a/conceptarium/conf/model/cem.yaml b/conceptarium/conf/model/cem.yaml 
index 735ca5b..baef6d7 100644 --- a/conceptarium/conf/model/cem.yaml +++ b/conceptarium/conf/model/cem.yaml @@ -6,12 +6,4 @@ _target_: "torch_concepts.nn.ConceptEmbeddingModel" task_names: ${dataset.default_task_names} -embedding_size: 8 - -inference: - _target_: "torch_concepts.nn.DeterministicInference" - _partial_: true - -train_inference: - _target_: "torch_concepts.nn.DeterministicInference" - _partial_: true \ No newline at end of file +embedding_size: 8 \ No newline at end of file diff --git a/conceptarium/conf/sweep.yaml b/conceptarium/conf/sweep.yaml index 440a4c0..ee9daa5 100644 --- a/conceptarium/conf/sweep.yaml +++ b/conceptarium/conf/sweep.yaml @@ -9,16 +9,17 @@ hydra: # standard grid search params: seed: 42 - dataset: dag_asia, dag_sachs + dataset: dag_asia, dag_sachs, dag_insurance model: cbm, cem, c2bm model.train_inference._target_: torch_concepts.nn.DeterministicInference, torch_concepts.nn.IndependentInference, torch_concepts.nn.AncestralSamplingInference # --- inference params - # +model.train_inference.detach: false, true - # +model.train_inference.p: 0.4 - loss: standard + +model.train_inference.p: 0.5 + # +model.train_inference.log_probs: True + # +model.inference.log_probs: True + loss: standard #, unweighted # --- weighted loss params # loss.concept_weight: 10 # loss.task_weight: 1 @@ -42,8 +43,8 @@ metrics: per_concept: true # true or ${dataset.default_task_names} trainer: - # logger: wandb - # log_model: true # whether to save checkpoint on wandb + logger: null # null / wandb + log_model: false # whether to save checkpoint on wandb save_top_k: 1. 
# whether to save checkpoint locally max_epochs: 200 patience: 20 diff --git a/torch_concepts/nn/modules/mid/inference/deterministic.py b/torch_concepts/nn/modules/mid/inference/deterministic.py index e62ed38..600e78d 100644 --- a/torch_concepts/nn/modules/mid/inference/deterministic.py +++ b/torch_concepts/nn/modules/mid/inference/deterministic.py @@ -58,16 +58,14 @@ def __init__( graph_learner=None, detach: bool = False, lazy: bool = False, + log_probs: bool = False, p: float = 0.0, - propagate: str = 'probs', *args, **kwargs, ): if not 0.0 <= p <= 1.0: raise ValueError(f"p must be in [0, 1], got {p}") - if propagate not in ('logits', 'probs'): - raise ValueError(f"propagate must be 'logits' or 'probs', got {propagate!r}") - self.propagate = propagate + self.log_probs = log_probs self._coerce_variables_to_delta(probabilistic_model) super().__init__( probabilistic_model, @@ -102,7 +100,7 @@ def _coerce_variables_to_delta(self, probabilistic_model) -> None: for var in probabilistic_model.variables: if var in var_to_change: - if self.propagate == 'logits': + if self.log_probs: var.activation = lambda x: x else: var.activation = _DEFAULT_ACTIVATIONS.get(var.distribution, lambda x: x) @@ -125,15 +123,7 @@ def activate(self, pred: torch.Tensor, variable: Variable) -> torch.Tensor: def ground_truth_to_evidence(self, value: torch.Tensor, size: int, type: str) -> torch.Tensor: """ Convert ground truth to tensors used for propagation. - - Ground-truth propagation keeps the existing discrete encoding: 0.0/1.0 - for binary variables and one-hot vectors for categorical variables. - Dense tensors with shape ``(batch_size, size)`` are passed - through directly, which supports already-encoded categorical evidence - and continuous variables. - - Supports binary (size=1), categorical (size>1), and - dense continuous variables. + Supports binary (size=1), categorical (size>1), and dense continuous variables. 
Parameters ---------- @@ -179,7 +169,7 @@ def ground_truth_to_evidence(self, value: torch.Tensor, size: int, type: str) -> stacklevel=2, ) probs = value.float() - return torch.logit(probs, eps=1e-7) if self.propagate == 'logits' else probs + return torch.logit(probs, eps=1e-7) if self.log_probs else probs elif type == 'categorical': if width == size: @@ -195,7 +185,7 @@ def ground_truth_to_evidence(self, value: torch.Tensor, size: int, type: str) -> one_hot = torch.nn.functional.one_hot( value.squeeze(-1).long(), num_classes=size ).float() - return torch.logit(one_hot, eps=1e-7) if self.propagate == 'logits' else one_hot + return torch.logit(one_hot, eps=1e-7) if self.log_probs else one_hot else: # 'continuous' or 'delta' if width == size: diff --git a/torch_concepts/nn/modules/mid/inference/forward.py b/torch_concepts/nn/modules/mid/inference/forward.py index 8e0a17f..3797441 100644 --- a/torch_concepts/nn/modules/mid/inference/forward.py +++ b/torch_concepts/nn/modules/mid/inference/forward.py @@ -105,7 +105,7 @@ def __init__( probabilistic_model: ProbabilisticModel, graph_learner: BaseGraphLearner = None, detach: bool = False, - lazy: bool = False, + lazy: bool = True, p: float = 0.0, *args, **kwargs From e67fe8188fcf611099b5541c73c7aa6c03254031 Mon Sep 17 00:00:00 2001 From: Giovanni De Felice Date: Sun, 3 May 2026 10:28:13 +0200 Subject: [PATCH 4/6] update ground_truth_to_evidence for ancestral sampling --- .../nn/modules/mid/inference/ancestral.py | 98 +++++++++++++++---- 1 file changed, 77 insertions(+), 21 deletions(-) diff --git a/torch_concepts/nn/modules/mid/inference/ancestral.py b/torch_concepts/nn/modules/mid/inference/ancestral.py index 4c486c6..1b8721f 100644 --- a/torch_concepts/nn/modules/mid/inference/ancestral.py +++ b/torch_concepts/nn/modules/mid/inference/ancestral.py @@ -3,6 +3,8 @@ import inspect from typing import Dict, Set +import warnings + import torch from .forward import ForwardInference @@ -102,7 +104,7 @@ def __init__(self, 
probabilistic_model: ProbabilisticModel, graph_learner: BaseGraphLearner = None, detach: bool = False, - lazy: bool = False, + lazy: bool = True, log_probs: bool = True, p: float = 0.0): super().__init__(probabilistic_model, graph_learner, detach=detach, lazy=lazy, p=p) @@ -167,6 +169,7 @@ def activate(self, pred: torch.Tensor, variable: Variable) -> torch.Tensor: ) # Decide how to pass pred based on the distribution's accepted params + # TODO: make this robust to all distribution choices if "logits" in allowed and self.log_probs: dist_kwargs["logits"] = pred dist = variable.distribution(**dist_kwargs) @@ -182,31 +185,84 @@ def activate(self, pred: torch.Tensor, variable: Variable) -> torch.Tensor: return sample # TODO: currently assumes discrete, to be extended to continuous - def ground_truth_to_evidence(self, value: torch.Tensor, cardinality: int) -> torch.Tensor: + def ground_truth_to_evidence(self, value: torch.Tensor, size: int, type: str) -> torch.Tensor: """ - Convert ground truth to raw states for ancestral sampling. - - For sampling inference, evidence should be in the same format as samples: - - Binary: (batch_size, 1) with values 0.0 or 1.0 - - Categorical: (batch_size, cardinality) one-hot encoded - + Convert ground truth to tensors used for propagation. + + Ground-truth propagation keeps the existing discrete encoding: 0.0/1.0 + for binary variables and one-hot vectors for categorical variables. + Dense tensors with shape ``(batch_size, size)`` are passed + through directly, which supports already-encoded categorical evidence + and continuous variables. + + Supports binary (size=1), categorical (size>1), and dense continuous variables. + Parameters ---------- value : torch.Tensor - Ground truth value tensor. Shape: (batch_size,). - cardinality : int - Number of classes (1 for binary, >1 for categorical). - + Ground truth tensor. Shape: (batch_size,) or (batch_size, 1). 
+ - Binary (size=1): binary values with shape (batch_size, ) + - Categorical (size>1): class indices with shape (batch_size,) or one-hot vectors with shape (batch_size, size) + - Continuous: values with shape (batch_size, size) + size : int + Number of features/classes for this variable. + type : str + Type of the variable ('binary', 'categorical', 'continuous' or 'delta'). + Returns ------- torch.Tensor - State tensor in sample format. + Value tensor. Shape: (batch_size, size). """ - if cardinality > 1: - return torch.nn.functional.one_hot( - value.squeeze(-1).long(), num_classes=cardinality - ).float() - else: - if value.dim() == 1: - value = value.unsqueeze(-1) - return value.float() + + # Allow (batch,) and unsqueeze to (batch, 1) + if value.dim() == 1: + value = value.unsqueeze(-1) + + if value.dim() != 2: + raise ValueError( + f"Expected shape (batch,), (batch, 1), or " + f"(batch, {size}), got {tuple(value.shape)}." + ) + + width = value.shape[-1] + + if type == 'binary': + if width != 1: + raise ValueError( + f"Expected shape (batch,) or (batch, 1) for binary variable, " + f"got {tuple(value.shape)}." + ) + if not torch.all((value == 0) | (value == 1)): + unique_vals = value.unique() + warnings.warn( + f"Binary ground truth contains values outside {{0, 1}}: " + f"{unique_vals.tolist()}. Values will be used as-is.", + stacklevel=2, + ) + probs = value.float() + return probs + + elif type == 'categorical': + if width == size: + # Already one-hot encoded + one_hot = value.float() + elif width != 1: + raise ValueError( + f"Expected shape (batch,), (batch, 1), or " + f"(batch, {size}) for categorical variable, got {tuple(value.shape)}." 
+ ) + else: + # Class indices → one-hot + one_hot = torch.nn.functional.one_hot( + value.squeeze(-1).long(), num_classes=size + ).float() + return one_hot + + else: # 'continuous' or 'delta' + if width == size: + return value.float() + raise ValueError( + f"Expected shape (batch, {size}) for {type} variable, " + f"got {tuple(value.shape)}." + ) From eae8f4aae967f968602fcd531c676e82c1b4be60 Mon Sep 17 00:00:00 2001 From: Giovanni De Felice Date: Sun, 3 May 2026 10:29:38 +0200 Subject: [PATCH 5/6] remove IndependentTraining for default sweep, add p=1 option --- conceptarium/conf/sweep.yaml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/conceptarium/conf/sweep.yaml b/conceptarium/conf/sweep.yaml index ee9daa5..d1c8bd1 100644 --- a/conceptarium/conf/sweep.yaml +++ b/conceptarium/conf/sweep.yaml @@ -13,10 +13,9 @@ hydra: model: cbm, cem, c2bm model.train_inference._target_: torch_concepts.nn.DeterministicInference, - torch_concepts.nn.IndependentInference, torch_concepts.nn.AncestralSamplingInference # --- inference params - +model.train_inference.p: 0.5 + +model.train_inference.p: 0.5 # use p=1 for 'independent' training # +model.train_inference.log_probs: True # +model.inference.log_probs: True loss: standard #, unweighted From b5f55d70538b30b2f88793be715c41bfdaaab699 Mon Sep 17 00:00:00 2001 From: Giovanni De Felice Date: Sun, 3 May 2026 19:21:38 +0200 Subject: [PATCH 6/6] update default configs Co-authored-by: Copilot --- conceptarium/conf/model/_commons.yaml | 23 +++++++---------------- conceptarium/conf/sweep.yaml | 8 ++++---- 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/conceptarium/conf/model/_commons.yaml b/conceptarium/conf/model/_commons.yaml index 0ea67d7..77af598 100644 --- a/conceptarium/conf/model/_commons.yaml +++ b/conceptarium/conf/model/_commons.yaml @@ -35,14 +35,14 @@ variable_distributions: binary: - _target_: "hydra.utils.get_class" path: "torch.distributions.RelaxedBernoulli" - - temperature: 1. 
+ - temperature: 0.5 categorical: - _target_: "hydra.utils.get_class" path: "torch.distributions.RelaxedOneHotCategorical" - - temperature: 1. - # TODO - # continuous: - # ... not supported yet + - temperature: 0.5 + continuous: + - _target_: "hydra.utils.get_class" + path: "torch.distributions.Normal" # ============================================================= @@ -70,16 +70,7 @@ scheduler_class: _target_: "hydra.utils.get_class" path: "torch.optim.lr_scheduler.ReduceLROnPlateau" scheduler_kwargs: - factor: 0.5 - min_lr: 0.001 + factor: 0.1 + min_lr: 1e-6 patience: 10 monitor: "val_loss" - - -# TODO: implement this -# ============================================================= -# Training settings -# ============================================================= -# train_interv_prob: 0.1 -# test_interv_policy: nodes_true # levels_true, levels_pred, nodes_true, nodes_pred, random -# test_interv_noise: 0. \ No newline at end of file diff --git a/conceptarium/conf/sweep.yaml b/conceptarium/conf/sweep.yaml index d1c8bd1..a9fd353 100644 --- a/conceptarium/conf/sweep.yaml +++ b/conceptarium/conf/sweep.yaml @@ -18,10 +18,10 @@ hydra: +model.train_inference.p: 0.5 # use p=1 for 'independent' training # +model.train_inference.log_probs: True # +model.inference.log_probs: True - loss: standard #, unweighted + loss: weighted # standard/weighted # --- weighted loss params - # loss.concept_weight: 10 - # loss.task_weight: 1 + loss.concept_weight: 10 + loss.task_weight: 1 dataset: batch_size: 2048 @@ -45,7 +45,7 @@ trainer: logger: null # null / wandb log_model: false # whether to save checkpoint on wandb save_top_k: 1. # whether to save checkpoint locally - max_epochs: 200 + max_epochs: 2000 patience: 20 matmul_precision: medium