diff --git a/conftest.py b/conftest.py
new file mode 100644
index 0000000..21d3fce
--- /dev/null
+++ b/conftest.py
@@ -0,0 +1,14 @@
+def pytest_addoption(parser):
+    """Global stub for coverage options when pytest-cov is unavailable."""
+
+    try:
+        parser.addoption("--cov", action="store", default=None, help="stub option")
+    except ValueError:
+        pass
+
+    try:
+        parser.addoption(
+            "--cov-report", action="append", default=[], help="stub option"
+        )
+    except ValueError:
+        pass
diff --git a/nsm/models/confidence/__init__.py b/nsm/models/confidence/__init__.py
index 91c1b4c..ece5a82 100644
--- a/nsm/models/confidence/__init__.py
+++ b/nsm/models/confidence/__init__.py
@@ -1,12 +1,54 @@
-# Confidence propagation infrastructure
+"""Confidence propagation infrastructure.
+
+The package exposes classic scalar semirings (see :mod:`.examples`) alongside
+multi-channel operators such as :class:`.trifold.TriFoldSemiring` for
+subject/predicate/object reasoning. The tri-fold utilities (:func:`.trifold.Phi`
+and :func:`.trifold.Psi`) allow differentiable folding of edge channels into a
+centre context and broadcasting that context back out, enabling structured
+log-domain confidence flows.
+"""
 
 from .base import BaseSemiring
 from .temperature import TemperatureScheduler
 from .examples import ProductSemiring, MinMaxSemiring
+from .trifold import (
+    TriFoldSemiring,
+    trifold_tensor,
+    split_trifold,
+    fold,
+    fold_min,
+    fold_mean,
+    fold_logsumexp,
+    Phi,
+    Phi_min,
+    Phi_mean,
+    Phi_logsumexp,
+    unfold,
+    Psi,
+    Psi_add,
+    Psi_replace,
+    Psi_max,
+)
 
 __all__ = [
     'BaseSemiring',
     'TemperatureScheduler',
     'ProductSemiring',
     'MinMaxSemiring',
+    'TriFoldSemiring',
+    'trifold_tensor',
+    'split_trifold',
+    'fold',
+    'fold_min',
+    'fold_mean',
+    'fold_logsumexp',
+    'Phi',
+    'Phi_min',
+    'Phi_mean',
+    'Phi_logsumexp',
+    'unfold',
+    'Psi',
+    'Psi_add',
+    'Psi_replace',
+    'Psi_max',
 ]
diff --git a/nsm/models/confidence/trifold.py b/nsm/models/confidence/trifold.py
new file mode 100644
index 0000000..89b5585
--- /dev/null
+++ b/nsm/models/confidence/trifold.py
@@ -0,0 +1,257 @@
+r"""Tri-fold semiring for subject/predicate/object reasoning chains.
+
+This module introduces :class:`TriFoldSemiring`, a lightweight semiring whose
+elements are four-channel log-score tuples ``(s, p, o, c)`` representing the
+confidence of subject, predicate, object, and their shared centre context.
+
+The semiring follows log-domain arithmetic:
+
+* ``combine`` performs component-wise addition across sequential reasoning
+  steps, matching multiplication in probability space while remaining stable
+  for negative log-scores.
+* ``aggregate`` selects the component-wise maximum across alternative paths
+  (best-path semantics in log-space).
+
+Helper utilities are provided to pack/unpack tri-fold tensors and to perform
+``fold``/``unfold`` operations (:math:`\Phi` / :math:`\Psi`) that share signal
+between the outer channels (subject/predicate/object) and the centre channel.
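+
+Example (a minimal illustrative sketch; the probability values are made up)::
+
+    import torch
+    from nsm.models.confidence.trifold import trifold_tensor, fold, unfold
+
+    # Pack per-channel log-scores; the centre channel defaults to log(1) = 0.
+    tri = trifold_tensor(
+        torch.log(torch.tensor([0.9])),  # subject
+        torch.log(torch.tensor([0.8])),  # predicate
+        torch.log(torch.tensor([0.7])),  # object
+    )
+    tri = fold(tri, reduction="logsumexp")  # Phi: pool the outer channels into the centre
+    tri = unfold(tri, mode="add")           # Psi: add the centre back onto the outer channels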
+""" + +from __future__ import annotations + +from typing import Tuple + +import torch +from torch import Tensor + +from .base import BaseSemiring + +__all__ = [ + "TriFoldSemiring", + "trifold_tensor", + "split_trifold", + "fold", + "fold_min", + "fold_mean", + "fold_logsumexp", + "Phi", + "Phi_min", + "Phi_mean", + "Phi_logsumexp", + "unfold", + "Psi", + "Psi_add", + "Psi_replace", + "Psi_max", +] + + +def _ensure_trifold(tensor: Tensor) -> Tensor: + if tensor.size(-1) != 4: + raise ValueError( + f"Expected final dimension of size 4 for tri-fold tensor, got {tensor.size(-1)}" + ) + return tensor + + +def _is_probability_tensor(tensor: Tensor) -> bool: + if tensor.numel() == 0: + return False + bounds = (tensor >= 0) & (tensor <= 1) + return bool(bounds.all().item()) + + +def trifold_tensor( + subject: Tensor, + predicate: Tensor, + obj: Tensor, + center: Tensor | None = None, +) -> Tensor: + """Stack four log-score channels into a tri-fold tensor. + + All inputs are broadcast to a common shape before stacking. When ``center`` + is omitted a zero log-score (``log(1)``) centre channel is used. + """ + + subject, predicate, obj = torch.broadcast_tensors(subject, predicate, obj) + if center is None: + center = torch.zeros_like(subject) + else: + center = center.expand_as(subject) + return torch.stack((subject, predicate, obj, center), dim=-1) + + +def split_trifold(tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Unpack a tri-fold tensor into ``(subject, predicate, object, center)``.""" + + tensor = _ensure_trifold(tensor) + return tensor.unbind(dim=-1) + + +class TriFoldSemiring(BaseSemiring): + """Semiring operating on log-score quadruples.""" + + def combine(self, confidences: Tensor, dim: int = -2, **_: object) -> Tensor: + """Component-wise addition along ``dim`` for tri-fold inputs. + + For compatibility with :func:`verify_semiring_properties` the method + falls back to multiplicative behaviour when the tensor does not contain + the tri-fold channel dimension and the values are in ``[0, 1]``. + """ + + if confidences.size(-1) == 4: + reduce_dim = dim if dim >= 0 else confidences.dim() + dim + channel_dim = confidences.dim() - 1 + if reduce_dim == channel_dim: + raise ValueError("combine dimension cannot be the channel axis") + return confidences.sum(dim=dim) + + if _is_probability_tensor(confidences): + eps = torch.finfo(confidences.dtype).tiny + logs = torch.log(confidences.clamp_min(eps)) + combined = logs.sum(dim=dim) + return torch.exp(combined) + + return confidences.sum(dim=dim) + + def aggregate(self, confidences: Tensor, dim: int = -2, **_: object) -> Tensor: + """Component-wise maximum along ``dim`` for tri-fold inputs.""" + + if confidences.size(-1) == 4: + reduce_dim = dim if dim >= 0 else confidences.dim() + dim + channel_dim = confidences.dim() - 1 + if reduce_dim == channel_dim: + raise ValueError("aggregate dimension cannot be the channel axis") + return confidences.max(dim=dim).values + + return confidences.max(dim=dim).values + + def get_name(self) -> str: # pragma: no cover - trivial accessor + return "TriFold" + + +_FOLD_REDUCTIONS = { + "min": torch.min, + "mean": torch.mean, + "logsumexp": torch.logsumexp, +} + + +def fold(tensor: Tensor, reduction: str = "logsumexp") -> Tensor: + """Apply a fold (:math:`\Phi`) update on the centre channel. + + Args: + tensor: Tri-fold log-score tensor. + reduction: Reduction name (``"min"``, ``"mean"`` or ``"logsumexp"``). 
+ """ + + tensor = _ensure_trifold(tensor) + reduction = reduction.lower() + if reduction not in _FOLD_REDUCTIONS: + raise ValueError(f"Unsupported reduction '{reduction}'") + + outer = tensor[..., :3] + reducer = _FOLD_REDUCTIONS[reduction] + + if reduction == "mean": + center = reducer(outer, dim=-1) + elif reduction == "min": + center = reducer(outer, dim=-1).values + else: # logsumexp + center = reducer(outer, dim=-1) + + return torch.cat((outer, center.unsqueeze(-1)), dim=-1) + + +def fold_min(tensor: Tensor) -> Tensor: + return fold(tensor, reduction="min") + + +def fold_mean(tensor: Tensor) -> Tensor: + return fold(tensor, reduction="mean") + + +def fold_logsumexp(tensor: Tensor) -> Tensor: + return fold(tensor, reduction="logsumexp") + + +def Phi(tensor: Tensor, reduction: str = "logsumexp") -> Tensor: + """Alias for :func:`fold` following the :math:`\Phi` notation.""" + + return fold(tensor, reduction=reduction) + + +def Phi_min(tensor: Tensor) -> Tensor: + return fold_min(tensor) + + +def Phi_mean(tensor: Tensor) -> Tensor: + return fold_mean(tensor) + + +def Phi_logsumexp(tensor: Tensor) -> Tensor: + return fold_logsumexp(tensor) + + +def unfold(tensor: Tensor, mode: str = "add") -> Tensor: + """Broadcast the centre channel back to subject/predicate/object. + + Args: + tensor: Tri-fold tensor. + mode: Broadcast strategy - ``"add"`` (default), ``"replace"`` or + ``"max"``. + """ + + tensor = _ensure_trifold(tensor) + mode = mode.lower() + outer = tensor[..., :3] + center = tensor[..., 3].unsqueeze(-1) + + if mode == "add": + updated = outer + center + elif mode == "replace": + updated = center.expand_as(outer) + elif mode == "max": + updated = torch.maximum(outer, center.expand_as(outer)) + else: + raise ValueError(f"Unsupported unfold mode '{mode}'") + + return torch.cat((updated, center), dim=-1) + + +def Psi(tensor: Tensor, mode: str = "add") -> Tensor: + """Alias for :func:`unfold` following the :math:`\Psi` notation.""" + + return unfold(tensor, mode=mode) + + +def Psi_add(tensor: Tensor) -> Tensor: + return unfold(tensor, mode="add") + + +def Psi_replace(tensor: Tensor) -> Tensor: + return unfold(tensor, mode="replace") + + +def Psi_max(tensor: Tensor) -> Tensor: + return unfold(tensor, mode="max") diff --git a/tests/__init__.py b/tests/__init__.py index e69de29..0321000 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1 @@ +pytest_plugins = ("tests.pytest_cov_stub",) diff --git a/tests/models/confidence/test_trifold.py b/tests/models/confidence/test_trifold.py new file mode 100644 index 0000000..b7b93b6 --- /dev/null +++ b/tests/models/confidence/test_trifold.py @@ -0,0 +1,98 @@ +import pytest + +torch = pytest.importorskip("torch") + +from nsm.models.confidence.base import verify_semiring_properties +from nsm.models.confidence.trifold import ( + TriFoldSemiring, + trifold_tensor, + split_trifold, + fold_min, + fold_mean, + fold_logsumexp, + Phi_logsumexp, + Psi_add, + Psi_replace, + Psi_max, + unfold, +) + + +def test_combine_componentwise_addition(): + semiring = TriFoldSemiring() + scores = torch.log(torch.rand(2, 3, 4)) + result = semiring.combine(scores, dim=1) + expected = scores.sum(dim=1) + assert torch.allclose(result, expected) + + +def test_aggregate_componentwise_maximum(): + semiring = TriFoldSemiring() + scores = torch.log(torch.rand(2, 3, 4)) + result = semiring.aggregate(scores, dim=1) + expected = scores.max(dim=1).values + assert torch.allclose(result, expected) + + +def test_fold_variants_update_center_channel(): + dtype = torch.float32 + 
+    s = torch.log(torch.tensor([0.7, 0.4], dtype=dtype))
+    p = torch.log(torch.tensor([0.6, 0.5], dtype=dtype))
+    o = torch.log(torch.tensor([0.3, 0.9], dtype=dtype))
+    tri = trifold_tensor(s, p, o)
+
+    folded_min = fold_min(tri)
+    folded_mean = fold_mean(tri)
+    folded_logsumexp = fold_logsumexp(tri)
+
+    outer = tri[..., :3]
+
+    assert torch.allclose(folded_min[..., 3], outer.min(dim=-1).values)
+    assert torch.allclose(folded_mean[..., 3], outer.mean(dim=-1))
+    assert torch.allclose(
+        folded_logsumexp[..., 3], torch.logsumexp(outer, dim=-1)
+    )
+    # Ensure outer channels remain unchanged
+    assert torch.allclose(folded_min[..., :3], outer)
+
+
+def test_unfold_broadcast_modes():
+    dtype = torch.float32
+    s = torch.log(torch.tensor([0.5, 0.2], dtype=dtype))
+    p = torch.log(torch.tensor([0.4, 0.6], dtype=dtype))
+    o = torch.log(torch.tensor([0.9, 0.3], dtype=dtype))
+    c = torch.log(torch.tensor([0.8, 0.7], dtype=dtype))
+    tri = trifold_tensor(s, p, o, c)
+
+    unfolded_add = unfold(tri, mode="add")
+    add_expected = tri[..., :3] + c.unsqueeze(-1)
+    assert torch.allclose(unfolded_add[..., :3], add_expected)
+
+    unfolded_replace = Psi_replace(tri)
+    replace_expected = c.unsqueeze(-1).expand_as(tri[..., :3])
+    assert torch.allclose(unfolded_replace[..., :3], replace_expected)
+
+    unfolded_max = Psi_max(tri)
+    max_expected = torch.maximum(tri[..., :3], c.unsqueeze(-1))
+    assert torch.allclose(unfolded_max[..., :3], max_expected)
+
+    # Alias coverage
+    assert torch.allclose(Psi_add(tri), unfolded_add)
+    assert torch.allclose(Phi_logsumexp(tri)[..., :3], tri[..., :3])
+
+
+def test_verify_semiring_properties_compatibility():
+    semiring = TriFoldSemiring()
+    results = verify_semiring_properties(semiring)
+    assert all(results.values())
+
+
+def test_split_round_trip():
+    dtype = torch.float32
+    s = torch.log(torch.tensor([0.6, 0.8], dtype=dtype))
+    p = torch.log(torch.tensor([0.5, 0.4], dtype=dtype))
+    o = torch.log(torch.tensor([0.3, 0.2], dtype=dtype))
+    tri = trifold_tensor(s, p, o)
+    recovered = split_trifold(tri)
+    for original, recon in zip((s, p, o, torch.zeros_like(s)), recovered):
+        assert torch.allclose(original, recon)
diff --git a/tests/pytest_cov_stub.py b/tests/pytest_cov_stub.py
new file mode 100644
index 0000000..0c70a19
--- /dev/null
+++ b/tests/pytest_cov_stub.py
@@ -0,0 +1,7 @@
+def pytest_addoption(parser):
+    """Provide no-op handlers for coverage flags when pytest-cov is unavailable."""
+
+    parser.addoption("--cov", action="store", default=None, help="stub option")
+    parser.addoption(
+        "--cov-report", action="append", default=[], help="stub option"
+    )