Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion nsm/models/confidence/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,36 @@
# Confidence propagation infrastructure
"""Confidence propagation infrastructure.

The package exposes common semiring implementations together with operator
tooling for reasoning over multi-channel confidence scores. The
``TriFoldSemiring`` models subject/predicate/object log-scores with a shared
center channel and ships with differentiable folding (``Phi``) and unfolding
(``Psi``) helpers for neural modules that need to align these channels during
training.
"""

from .base import BaseSemiring
from .temperature import TemperatureScheduler
from .examples import ProductSemiring, MinMaxSemiring
from .trifold import (
TriFoldSemiring,
as_trifold,
split_trifold,
Phi_min,
Phi_mean,
Phi_logsumexp,
Psi,
)

__all__ = [
'BaseSemiring',
'TemperatureScheduler',
'ProductSemiring',
'MinMaxSemiring',
'TriFoldSemiring',
'as_trifold',
'split_trifold',
'Phi_min',
'Phi_mean',
'Phi_logsumexp',
'Psi',
]
6 changes: 5 additions & 1 deletion nsm/models/confidence/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,11 @@ def verify_semiring_properties(
try:
# Test 2: Identity element (typically 1.0)
a = test_values[0]
identity = torch.tensor(1.0)
identity_fn = getattr(semiring, "get_combine_identity", None)
if callable(identity_fn):
identity = identity_fn(a)
else:
identity = torch.ones_like(a)
combined = semiring.combine(torch.stack([a, identity]))
results['combine_identity'] = torch.allclose(combined, a, atol=atol)

Expand Down
130 changes: 130 additions & 0 deletions nsm/models/confidence/trifold.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""Tri-fold confidence semiring utilities.

This module provides a semiring where each element tracks four correlated
log-scores corresponding to subject, predicate, object, and a shared center
channel. The :class:`TriFoldSemiring` composes multi-hop reasoning chains via
component-wise addition (log-space composition) and aggregates alternative paths
with component-wise maxima. Helper functions convert between structured tuples
and packed tensors and expose differentiable folding operators (``Phi``) that
update the center channel together with an unfolding operator (``Psi``) that
broadcasts the center score back to the leaf channels.
"""

from __future__ import annotations

from typing import Callable, Tuple

import torch
from torch import Tensor

from .base import BaseSemiring


TRIFOLD_CHANNELS = 4


def as_trifold(
subject: Tensor,
predicate: Tensor,
obj: Tensor,
center: Tensor,
) -> Tensor:
"""Stack individual channels into a tri-fold tensor.

All channels are broadcast to a common shape before being stacked along the
last dimension. The resulting tensor always has a final dimension of size 4
representing ``(subject, predicate, object, center)``.
"""

subject, predicate, obj, center = torch.broadcast_tensors(
subject, predicate, obj, center
)
return torch.stack((subject, predicate, obj, center), dim=-1)


def split_trifold(trifold: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Split a tri-fold tensor into ``(subject, predicate, object, center)``."""

if trifold.size(-1) != TRIFOLD_CHANNELS:
raise ValueError(
f"Expected last dimension of size {TRIFOLD_CHANNELS}, got {trifold.size(-1)}"
)
subject = trifold[..., 0]
predicate = trifold[..., 1]
obj = trifold[..., 2]
center = trifold[..., 3]
return subject, predicate, obj, center


def _ensure_trifold(confidences: Tensor) -> Tensor:
if confidences.size(-1) != TRIFOLD_CHANNELS:
raise ValueError(
"TriFoldSemiring expects tensors with last dimension == 4. "
f"Received shape {tuple(confidences.shape)}"
)
return confidences


class TriFoldSemiring(BaseSemiring):
"""Semiring for reasoning over tri-fold log-score tuples."""

def combine(self, confidences: Tensor, dim: int = -2, **_: object) -> Tensor:
"""Compose sequential steps by component-wise addition."""

confidences = _ensure_trifold(confidences)
if confidences.ndim < 2:
return confidences
return torch.sum(confidences, dim=dim)

def aggregate(self, confidences: Tensor, dim: int = -2, **_: object) -> Tensor:
"""Aggregate alternative paths with component-wise maxima."""

confidences = _ensure_trifold(confidences)
if confidences.ndim < 2:
return confidences
return torch.max(confidences, dim=dim).values

def get_name(self) -> str:
return "TriFold"

def get_combine_identity(self, reference: Tensor) -> Tensor:
"""Return the additive identity (zero vector) for the semiring."""

return torch.zeros_like(reference)


def _phi(trifold: Tensor, reducer: Callable[[Tensor], Tensor]) -> Tensor:
"""Apply a reducer over subject/predicate/object and update the center."""

trifold = _ensure_trifold(trifold)
subject, predicate, obj, _ = split_trifold(trifold)
stacked = torch.stack((subject, predicate, obj), dim=-1)
new_center = reducer(stacked)
return as_trifold(subject, predicate, obj, new_center)


def Phi_min(trifold: Tensor) -> Tensor:
"""Fold the minimum of subject/predicate/object into the center channel."""

return _phi(trifold, lambda x: torch.min(x, dim=-1).values)


def Phi_mean(trifold: Tensor) -> Tensor:
"""Fold the mean of subject/predicate/object into the center channel."""

return _phi(trifold, lambda x: torch.mean(x, dim=-1))


def Phi_logsumexp(trifold: Tensor) -> Tensor:
"""Fold the log-sum-exp of subject/predicate/object into the center channel."""

return _phi(trifold, lambda x: torch.logsumexp(x, dim=-1))


def Psi(trifold: Tensor) -> Tensor:
"""Broadcast the center channel back to subject/predicate/object channels."""

trifold = _ensure_trifold(trifold)
_, _, _, center = split_trifold(trifold)
return as_trifold(center, center, center, center)

98 changes: 98 additions & 0 deletions tests/models/confidence/test_trifold.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import torch

from nsm.models.confidence.base import verify_semiring_properties
from nsm.models.confidence.trifold import (
TriFoldSemiring,
as_trifold,
split_trifold,
Phi_min,
Phi_mean,
Phi_logsumexp,
Psi,
)


def test_trifold_helpers_round_trip():
subject = torch.tensor([0.1, 0.2])
predicate = torch.tensor([0.3, 0.4])
obj = torch.tensor([0.5, 0.6])
center = torch.tensor([0.7, 0.8])

packed = as_trifold(subject, predicate, obj, center)
unpacked = split_trifold(packed)

assert torch.allclose(unpacked[0], subject)
assert torch.allclose(unpacked[1], predicate)
assert torch.allclose(unpacked[2], obj)
assert torch.allclose(unpacked[3], center)


def test_trifold_semiring_combine_adds_componentwise():
semiring = TriFoldSemiring()
confidences = torch.tensor(
[[0.1, 0.2, 0.3, 0.4], [0.4, 0.5, 0.6, 0.7]]
)

combined = semiring.combine(confidences)
expected = torch.tensor([0.5, 0.7, 0.9, 1.1])

assert torch.allclose(combined, expected)


def test_trifold_semiring_aggregate_max_componentwise():
semiring = TriFoldSemiring()
confidences = torch.tensor(
[[0.1, 0.4, 0.3, 0.2], [0.6, 0.5, 0.7, 0.8], [0.2, 0.3, 0.4, 0.9]]
)

aggregated = semiring.aggregate(confidences)
expected = torch.tensor([0.6, 0.5, 0.7, 0.9])

assert torch.allclose(aggregated, expected)


def test_phi_operators_update_center_channel():
trifold = as_trifold(
torch.tensor([0.0, 0.5]),
torch.tensor([0.1, 0.6]),
torch.tensor([-0.2, 0.2]),
torch.tensor([0.3, 0.4]),
)

phi_min = Phi_min(trifold)
phi_mean = Phi_mean(trifold)
phi_logsumexp = Phi_logsumexp(trifold)

stacked = torch.stack(split_trifold(trifold)[:3], dim=0)
expected_min = stacked.min(dim=0).values
expected_mean = stacked.mean(dim=0)
expected_logsumexp = torch.logsumexp(stacked, dim=0)

assert torch.allclose(split_trifold(phi_min)[3], expected_min)
assert torch.allclose(split_trifold(phi_mean)[3], expected_mean)
assert torch.allclose(split_trifold(phi_logsumexp)[3], expected_logsumexp)


def test_psi_broadcasts_center_channel():
trifold = as_trifold(
torch.tensor([0.1, 0.2]),
torch.tensor([0.3, 0.4]),
torch.tensor([0.5, 0.6]),
torch.tensor([0.7, 0.8]),
)

broadcast = Psi(trifold)
subject, predicate, obj, center = split_trifold(broadcast)

assert torch.allclose(subject, center)
assert torch.allclose(predicate, center)
assert torch.allclose(obj, center)


def test_trifold_semiring_compatible_with_property_checks():
semiring = TriFoldSemiring()
test_values = torch.randn(5, 4)

results = verify_semiring_properties(semiring, test_values=test_values)

assert all(results.values()), f"Failed properties: {results}"
Loading