Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 131 additions & 64 deletions torch_concepts/nn/modules/mid/inference/deterministic.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,41 @@
"""Deterministic inference for probabilistic graphical models."""

import warnings

import torch

from .....distributions import Delta
from .forward import ForwardInference
from ..models.variable import Variable

from ..models.variable import _DEFAULT_ACTIVATIONS

class DeterministicInference(ForwardInference):
"""
Deterministic forward inference for probabilistic graphical models.

This inference engine performs deterministic (maximum likelihood) inference by
returning raw logits/outputs from CPDs without sampling. It's useful for
prediction tasks where you want the most likely values rather than samples
from the distribution.

Inherits all functionality from ForwardInference but implements activate()
to map logits to probabilities without stochastic sampling.
This inference engine propagates raw CPD outputs forward without sampling or
using distribution-specific semantics. Because of that, it treats every
variable as deterministic: if the provided probabilistic model contains
variables whose distribution is not :class:`Delta`, the variables are
converted to :class:`Delta` with identity activations during initialization.

Example:
>>> import torch
>>> from torch.distributions import Bernoulli
>>> from torch_concepts import LatentVariable, ConceptVariable
>>> from torch_concepts.distributions import Delta
>>> from torch_concepts.nn import DeterministicInference, ParametricCPD, ProbabilisticModel, LinearConceptToConcept
>>>
>>> # Create a simple PGM: latent -> A -> B
>>> input_var = LatentVariable('input', parents=[], distribution=Delta, size=10)
>>> var_A = ConceptVariable('A', parents=['input'], distribution=Bernoulli, size=1)
>>> var_B = ConceptVariable('B', parents=['A'], distribution=Bernoulli, size=1)
>>> input_var = LatentVariable('input', distribution=Delta, size=10)
>>> var_A = ConceptVariable('A', distribution=Delta, size=1)
>>> var_B = ConceptVariable('B', distribution=Delta, size=1)
>>>
>>> # Define CPDs
>>> from torch.nn import Identity, Linear
>>> cpd_emb = ParametricCPD('input', parametrization=Identity())
>>> cpd_A = ParametricCPD('A', parametrization=Linear(10, 1))
>>> cpd_B = ParametricCPD('B', parametrization=LinearConceptToConcept(1, 1))
>>> cpd_A = ParametricCPD('A', parametrization=Linear(10, 1), parents=['input'])
>>> cpd_B = ParametricCPD('B', parametrization=LinearConceptToConcept(1, 1), parents=['A'])
>>>
>>> # Create probabilistic model
>>> pgm = ProbabilisticModel(
Expand All @@ -47,93 +48,159 @@ class DeterministicInference(ForwardInference):
>>>
>>> # Perform inference - returns logits, not samples
>>> x = torch.randn(4, 10) # batch_size=4, latent_size=10
>>> results = inference.predict({'input': x})
>>>
>>> # Results contain raw logits for Bernoulli variables
>>> print(results['A'].shape) # torch.Size([4, 1]) - logits, not {0,1}
>>> print(results['B'].shape) # torch.Size([4, 1]) - logits, not {0,1}
>>>
>>> # Query specific concepts - returns concatenated logits
>>> output = inference.query(['B', 'A'], evidence={'input': x})
>>> print(output.shape) # torch.Size([4, 2])
>>> # output contains [logit_B, logit_A] for each sample
>>>
>>> # Convert logits to probabilities if needed
>>> prob_A = torch.sigmoid(results['A'])
>>> print(prob_A.shape) # torch.Size([4, 1])
>>>
>>> # Get hard predictions (0 or 1)
>>> pred_A = (prob_A > 0.5).float()
>>> print(pred_A) # Binary predictions
>>> print(output.probs.shape) # torch.Size([4, 2])
>>> # output.probs contains [logit_B, logit_A] for each sample
"""
    def __init__(
        self,
        probabilistic_model,
        graph_learner=None,
        detach: bool = False,
        lazy: bool = False,
        p: float = 0.0,
        propagate: str = 'probs',
        *args,
        **kwargs,
    ):
        """
        Initialize deterministic inference over ``probabilistic_model``.

        Args:
            probabilistic_model: Model whose non-Delta variables are coerced
                to Delta before the parent initializer runs.
            graph_learner: Optional graph learner, forwarded to the parent.
            detach: Forwarded to the parent ``ForwardInference`` initializer.
            lazy: Forwarded to the parent ``ForwardInference`` initializer.
            p: Ground-truth propagation probability; must lie in [0, 1].
            propagate: Representation propagated between variables, either
                ``'logits'`` or ``'probs'``.

        Raises:
            ValueError: If ``p`` is outside [0, 1], or ``propagate`` is not
                ``'logits'`` or ``'probs'``.
        """
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"p must be in [0, 1], got {p}")
        if propagate not in ('logits', 'probs'):
            raise ValueError(f"propagate must be 'logits' or 'probs', got {propagate!r}")
        self.propagate = propagate
        # Coerce before super().__init__ so the parent only ever sees
        # Delta-distributed variables.
        self._coerce_variables_to_delta(probabilistic_model)
        super().__init__(
            probabilistic_model,
            graph_learner,
            detach,
            lazy,
            p,
            *args,
            **kwargs,
        )

def _coerce_variables_to_delta(self, probabilistic_model) -> None:
var_to_change = [
var for var in probabilistic_model.variables
if var.distribution is not Delta
]

if var_to_change:
non_delta_summary = ", ".join(
f"{var.concept} ({getattr(var.distribution, '__name__', var.distribution)})"
for var in var_to_change
)
with warnings.catch_warnings():
warnings.simplefilter("always", UserWarning)
warnings.warn(
"DeterministicInference assumes all variables are Delta() variables. " \
"All Variables will be changed to Delta(), with activations set to identity. "
f"Non-Delta variables: {non_delta_summary}.",
UserWarning,
stacklevel=3,
)

for var in probabilistic_model.variables:
if var in var_to_change:
if self.propagate == 'logits':
var.activation = lambda x: x
else:
var.activation = _DEFAULT_ACTIVATIONS.get(var.distribution, lambda x: x)
var.distribution = Delta
var.dist_kwargs = {}

    def activate(self, pred: torch.Tensor, variable: Variable) -> torch.Tensor:
        """
        Apply the variable's activation function to raw CPD outputs.

        The activation is stored on the :class:`Variable` instance itself;
        this engine simply applies it without any stochastic sampling.

        Args:
            pred: Prediction tensor (raw CPD output).
            variable: The Variable whose prediction is being activated.

        Returns:
            torch.Tensor: Activated prediction tensor.
        """
        return variable.activation(pred)

def ground_truth_to_evidence(self, value: torch.Tensor, cardinality: int) -> torch.Tensor:
def ground_truth_to_evidence(self, value: torch.Tensor, size: int, type: str) -> torch.Tensor:
"""
Convert discrete ground truth to activated probabilities for propagation.
Convert ground truth to tensors used for propagation.

Since the inference engine now owns the activation, propagation values
must be in the same representation as ``activate`` produces:
probabilities for Bernoulli (0.0/1.0) and one-hot for Categorical.
Ground-truth propagation keeps the existing discrete encoding: 0.0/1.0
for binary variables and one-hot vectors for categorical variables.
Dense tensors with shape ``(batch_size, size)`` are passed
through directly, which supports already-encoded categorical evidence
and continuous variables.

Supports both binary (cardinality=1) and categorical (cardinality>1)
variables. DOES NOT SUPPORT CONTINUOUS VARIABLES.
Supports binary (size=1), categorical (size>1), and
dense continuous variables.

Parameters
----------
value : torch.Tensor
Ground truth tensor. Shape: (batch_size,) or (batch_size, 1).
- Binary (cardinality=1): values in {0, 1}
- Categorical (cardinality>1): class indices
cardinality : int
- Binary (size=1): binary values with shape (batch_size, )
- Categorical (size>1): class indices with shape (batch_size,) or one-hot vectors with shape (batch_size, size)
- Continuous: values with shape (batch_size, size)
size : int
Number of features/classes for this variable.
type : str
Type of the variable ('binary', 'categorical', 'continuous' or 'delta').

Returns
-------
torch.Tensor
Probability / one-hot tensor. Shape: (batch_size, cardinality).
Value tensor. Shape: (batch_size, size).
"""

# TODO: add support for continuous variables

# Allow (batch,) and unsqueeze to (batch, 1)
if value.dim() == 1:
value = value.unsqueeze(-1)

if value.dim() != 2 or value.shape[-1] != 1:
if value.dim() != 2:
raise ValueError(
f"Expected shape (batch,) or (batch, 1), got {tuple(value.shape)}."
f"Expected shape (batch,), (batch, 1), or "
f"(batch, {size}), got {tuple(value.shape)}."
)

if cardinality == 1:
# Binary: validate values are in {0, 1}
unique_vals = value.unique()
if not all(v in (0, 1) for v in unique_vals.tolist()):
import warnings
width = value.shape[-1]

if type == 'binary':
if width != 1:
raise ValueError(
f"Expected shape (batch,) or (batch, 1) for binary variable, "
f"got {tuple(value.shape)}."
)
if not torch.all((value == 0) | (value == 1)):
unique_vals = value.unique()
warnings.warn(
f"Binary ground truth contains values outside {{0, 1}}: "
f"{unique_vals.tolist()}. Values will be used as-is.",
stacklevel=2,
)
return value.float()
else:
# Categorical: return one-hot probabilities
return torch.nn.functional.one_hot(
value.squeeze(-1).long(), num_classes=cardinality
).float()
probs = value.float()
return torch.logit(probs, eps=1e-7) if self.propagate == 'logits' else probs

elif type == 'categorical':
if width == size:
# Already one-hot encoded
one_hot = value.float()
elif width != 1:
raise ValueError(
f"Expected shape (batch,), (batch, 1), or "
f"(batch, {size}) for categorical variable, got {tuple(value.shape)}."
)
else:
# Class indices → one-hot
one_hot = torch.nn.functional.one_hot(
value.squeeze(-1).long(), num_classes=size
).float()
return torch.logit(one_hot, eps=1e-7) if self.propagate == 'logits' else one_hot

else: # 'continuous' or 'delta'
if width == size:
return value.float()
raise ValueError(
f"Expected shape (batch, {size}) for {type} variable, "
f"got {tuple(value.shape)}."
)
6 changes: 3 additions & 3 deletions torch_concepts/nn/modules/mid/inference/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def activate(self, pred: torch.Tensor, variable: Variable) -> torch.Tensor:

Subclass contracts:

* ``DeterministicInference`` — Bernoulli → sigmoid, Categorical → softmax
* ``DeterministicInference`` — apply activation function before propagation
* ``AncestralSamplingInference`` — sample from the distribution

Args:
Expand Down Expand Up @@ -783,7 +783,8 @@ def query(
idx = index_map[name]
gt_value = self.ground_truth_to_evidence(
value=ground_truth[:, idx:idx+1],
cardinality=variable.size,
size=variable.size,
type=variable.type
)
if self.p >= 1.0:
propagation[name] = gt_value
Expand Down Expand Up @@ -1006,4 +1007,3 @@ def unrolled_probabilistic_model(self) -> ProbabilisticModel:
self._unrolled_query_vars = set(v.concept for v in new_variables)

return ProbabilisticModel(new_variables, new_parametric_cpds)

16 changes: 13 additions & 3 deletions torch_concepts/nn/modules/mid/models/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@
RelaxedBernoulli,
OneHotCategorical,
RelaxedOneHotCategorical,
# TODO: add support for continuous distributions
# Normal,
# MultivariateNormal,
Normal,
MultivariateNormal,
Delta
]

Expand Down Expand Up @@ -60,6 +59,16 @@
Delta: lambda x: x,
}

# Maps each supported distribution class to the variable-type tag used by
# inference engines (e.g. ground-truth encoding in ground_truth_to_evidence).
_DEFAULT_TYPES: Dict[Type[Distribution], str] = {
    Bernoulli: 'binary',
    RelaxedBernoulli: 'binary',
    OneHotCategorical: 'categorical',
    RelaxedOneHotCategorical: 'categorical',
    Normal: 'continuous',
    MultivariateNormal: 'continuous',
    Delta: 'delta'
}

# Number of raw parameters needed to parameterise each supported distribution
# given a variable of a certain *size* (event dimension).
_PARAM_DIMS: Dict[Type[Distribution], Dict[str, Callable[[int], int]]] = {
Expand Down Expand Up @@ -300,6 +309,7 @@ def __init__(self, concepts: Union[str, List[str]],
self.size = size
self.dist_kwargs = dist_kwargs if dist_kwargs is not None else {}
self.metadata = metadata if metadata is not None else {}
self.type = _DEFAULT_TYPES[distribution]
if activation is not None:
self.activation = activation
else:
Expand Down
Loading