From 037ba2078974e760f36eb4d1d2e8e9307df150fe Mon Sep 17 00:00:00 2001 From: research-developer <115124732+research-developer@users.noreply.github.com> Date: Tue, 21 Oct 2025 05:46:54 -0600 Subject: [PATCH] Add tri-fold confidence semiring and utilities --- nsm/models/confidence/__init__.py | 26 ++++- nsm/models/confidence/base.py | 6 +- nsm/models/confidence/trifold.py | 130 ++++++++++++++++++++++++ tests/models/confidence/test_trifold.py | 98 ++++++++++++++++++ 4 files changed, 258 insertions(+), 2 deletions(-) create mode 100644 nsm/models/confidence/trifold.py create mode 100644 tests/models/confidence/test_trifold.py diff --git a/nsm/models/confidence/__init__.py b/nsm/models/confidence/__init__.py index 91c1b4c..8b86f84 100644 --- a/nsm/models/confidence/__init__.py +++ b/nsm/models/confidence/__init__.py @@ -1,12 +1,36 @@ -# Confidence propagation infrastructure +"""Confidence propagation infrastructure. + +The package exposes common semiring implementations together with operator +tooling for reasoning over multi-channel confidence scores. The +``TriFoldSemiring`` models subject/predicate/object log-scores with a shared +center channel and ships with differentiable folding (``Phi``) and unfolding +(``Psi``) helpers for neural modules that need to align these channels during +training. +""" from .base import BaseSemiring from .temperature import TemperatureScheduler from .examples import ProductSemiring, MinMaxSemiring +from .trifold import ( + TriFoldSemiring, + as_trifold, + split_trifold, + Phi_min, + Phi_mean, + Phi_logsumexp, + Psi, +) __all__ = [ 'BaseSemiring', 'TemperatureScheduler', 'ProductSemiring', 'MinMaxSemiring', + 'TriFoldSemiring', + 'as_trifold', + 'split_trifold', + 'Phi_min', + 'Phi_mean', + 'Phi_logsumexp', + 'Psi', ] diff --git a/nsm/models/confidence/base.py b/nsm/models/confidence/base.py index e5ba3d5..6a61d63 100644 --- a/nsm/models/confidence/base.py +++ b/nsm/models/confidence/base.py @@ -237,7 +237,11 @@ def verify_semiring_properties( try: # Test 2: Identity element (typically 1.0) a = test_values[0] - identity = torch.tensor(1.0) + identity_fn = getattr(semiring, "get_combine_identity", None) + if callable(identity_fn): + identity = identity_fn(a) + else: + identity = torch.ones_like(a) combined = semiring.combine(torch.stack([a, identity])) results['combine_identity'] = torch.allclose(combined, a, atol=atol) diff --git a/nsm/models/confidence/trifold.py b/nsm/models/confidence/trifold.py new file mode 100644 index 0000000..ec3423c --- /dev/null +++ b/nsm/models/confidence/trifold.py @@ -0,0 +1,130 @@ +"""Tri-fold confidence semiring utilities. + +This module provides a semiring where each element tracks four correlated +log-scores corresponding to subject, predicate, object, and a shared center +channel. The :class:`TriFoldSemiring` composes multi-hop reasoning chains via +component-wise addition (log-space composition) and aggregates alternative paths +with component-wise maxima. Helper functions convert between structured tuples +and packed tensors and expose differentiable folding operators (``Phi``) that +update the center channel together with an unfolding operator (``Psi``) that +broadcasts the center score back to the leaf channels. +""" + +from __future__ import annotations + +from typing import Callable, Tuple + +import torch +from torch import Tensor + +from .base import BaseSemiring + + +TRIFOLD_CHANNELS = 4 + + +def as_trifold( + subject: Tensor, + predicate: Tensor, + obj: Tensor, + center: Tensor, +) -> Tensor: + """Stack individual channels into a tri-fold tensor. + + All channels are broadcast to a common shape before being stacked along the + last dimension. The resulting tensor always has a final dimension of size 4 + representing ``(subject, predicate, object, center)``. + """ + + subject, predicate, obj, center = torch.broadcast_tensors( + subject, predicate, obj, center + ) + return torch.stack((subject, predicate, obj, center), dim=-1) + + +def split_trifold(trifold: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Split a tri-fold tensor into ``(subject, predicate, object, center)``.""" + + if trifold.size(-1) != TRIFOLD_CHANNELS: + raise ValueError( + f"Expected last dimension of size {TRIFOLD_CHANNELS}, got {trifold.size(-1)}" + ) + subject = trifold[..., 0] + predicate = trifold[..., 1] + obj = trifold[..., 2] + center = trifold[..., 3] + return subject, predicate, obj, center + + +def _ensure_trifold(confidences: Tensor) -> Tensor: + if confidences.size(-1) != TRIFOLD_CHANNELS: + raise ValueError( + "TriFoldSemiring expects tensors with last dimension == 4. " + f"Received shape {tuple(confidences.shape)}" + ) + return confidences + + +class TriFoldSemiring(BaseSemiring): + """Semiring for reasoning over tri-fold log-score tuples.""" + + def combine(self, confidences: Tensor, dim: int = -2, **_: object) -> Tensor: + """Compose sequential steps by component-wise addition.""" + + confidences = _ensure_trifold(confidences) + if confidences.ndim < 2: + return confidences + return torch.sum(confidences, dim=dim) + + def aggregate(self, confidences: Tensor, dim: int = -2, **_: object) -> Tensor: + """Aggregate alternative paths with component-wise maxima.""" + + confidences = _ensure_trifold(confidences) + if confidences.ndim < 2: + return confidences + return torch.max(confidences, dim=dim).values + + def get_name(self) -> str: + return "TriFold" + + def get_combine_identity(self, reference: Tensor) -> Tensor: + """Return the additive identity (zero vector) for the semiring.""" + + return torch.zeros_like(reference) + + +def _phi(trifold: Tensor, reducer: Callable[[Tensor], Tensor]) -> Tensor: + """Apply a reducer over subject/predicate/object and update the center.""" + + trifold = _ensure_trifold(trifold) + subject, predicate, obj, _ = split_trifold(trifold) + stacked = torch.stack((subject, predicate, obj), dim=-1) + new_center = reducer(stacked) + return as_trifold(subject, predicate, obj, new_center) + + +def Phi_min(trifold: Tensor) -> Tensor: + """Fold the minimum of subject/predicate/object into the center channel.""" + + return _phi(trifold, lambda x: torch.min(x, dim=-1).values) + + +def Phi_mean(trifold: Tensor) -> Tensor: + """Fold the mean of subject/predicate/object into the center channel.""" + + return _phi(trifold, lambda x: torch.mean(x, dim=-1)) + + +def Phi_logsumexp(trifold: Tensor) -> Tensor: + """Fold the log-sum-exp of subject/predicate/object into the center channel.""" + + return _phi(trifold, lambda x: torch.logsumexp(x, dim=-1)) + + +def Psi(trifold: Tensor) -> Tensor: + """Broadcast the center channel back to subject/predicate/object channels.""" + + trifold = _ensure_trifold(trifold) + _, _, _, center = split_trifold(trifold) + return as_trifold(center, center, center, center) + diff --git a/tests/models/confidence/test_trifold.py b/tests/models/confidence/test_trifold.py new file mode 100644 index 0000000..1589cb7 --- /dev/null +++ b/tests/models/confidence/test_trifold.py @@ -0,0 +1,98 @@ +import torch + +from nsm.models.confidence.base import verify_semiring_properties +from nsm.models.confidence.trifold import ( + TriFoldSemiring, + as_trifold, + split_trifold, + Phi_min, + Phi_mean, + Phi_logsumexp, + Psi, +) + + +def test_trifold_helpers_round_trip(): + subject = torch.tensor([0.1, 0.2]) + predicate = torch.tensor([0.3, 0.4]) + obj = torch.tensor([0.5, 0.6]) + center = torch.tensor([0.7, 0.8]) + + packed = as_trifold(subject, predicate, obj, center) + unpacked = split_trifold(packed) + + assert torch.allclose(unpacked[0], subject) + assert torch.allclose(unpacked[1], predicate) + assert torch.allclose(unpacked[2], obj) + assert torch.allclose(unpacked[3], center) + + +def test_trifold_semiring_combine_adds_componentwise(): + semiring = TriFoldSemiring() + confidences = torch.tensor( + [[0.1, 0.2, 0.3, 0.4], [0.4, 0.5, 0.6, 0.7]] + ) + + combined = semiring.combine(confidences) + expected = torch.tensor([0.5, 0.7, 0.9, 1.1]) + + assert torch.allclose(combined, expected) + + +def test_trifold_semiring_aggregate_max_componentwise(): + semiring = TriFoldSemiring() + confidences = torch.tensor( + [[0.1, 0.4, 0.3, 0.2], [0.6, 0.5, 0.7, 0.8], [0.2, 0.3, 0.4, 0.9]] + ) + + aggregated = semiring.aggregate(confidences) + expected = torch.tensor([0.6, 0.5, 0.7, 0.9]) + + assert torch.allclose(aggregated, expected) + + +def test_phi_operators_update_center_channel(): + trifold = as_trifold( + torch.tensor([0.0, 0.5]), + torch.tensor([0.1, 0.6]), + torch.tensor([-0.2, 0.2]), + torch.tensor([0.3, 0.4]), + ) + + phi_min = Phi_min(trifold) + phi_mean = Phi_mean(trifold) + phi_logsumexp = Phi_logsumexp(trifold) + + stacked = torch.stack(split_trifold(trifold)[:3], dim=0) + expected_min = stacked.min(dim=0).values + expected_mean = stacked.mean(dim=0) + expected_logsumexp = torch.logsumexp(stacked, dim=0) + + assert torch.allclose(split_trifold(phi_min)[3], expected_min) + assert torch.allclose(split_trifold(phi_mean)[3], expected_mean) + assert torch.allclose(split_trifold(phi_logsumexp)[3], expected_logsumexp) + + +def test_psi_broadcasts_center_channel(): + trifold = as_trifold( + torch.tensor([0.1, 0.2]), + torch.tensor([0.3, 0.4]), + torch.tensor([0.5, 0.6]), + torch.tensor([0.7, 0.8]), + ) + + broadcast = Psi(trifold) + subject, predicate, obj, center = split_trifold(broadcast) + + assert torch.allclose(subject, center) + assert torch.allclose(predicate, center) + assert torch.allclose(obj, center) + + +def test_trifold_semiring_compatible_with_property_checks(): + semiring = TriFoldSemiring() + test_values = torch.randn(5, 4) + + results = verify_semiring_properties(semiring, test_values=test_values) + + assert all(results.values()), f"Failed properties: {results}"