From ee2f914ad8a10539c32d77b9bc5cd586b29643e4 Mon Sep 17 00:00:00 2001
From: research-developer <115124732+research-developer@users.noreply.github.com>
Date: Tue, 21 Oct 2025 05:46:13 -0600
Subject: [PATCH] Add TriFold semiring reasoning head

---
 README.md                               |   1 +
 nsm/models/confidence/__init__.py       |  12 ++
 nsm/models/confidence/trifold.py        | 269 ++++++++++++++++++++++++
 nsm/models/hierarchical.py              |  93 +++++++-
 tests/models/confidence/test_trifold.py |  66 ++++++
 5 files changed, 433 insertions(+), 8 deletions(-)
 create mode 100644 nsm/models/confidence/trifold.py
 create mode 100644 tests/models/confidence/test_trifold.py

diff --git a/README.md b/README.md
index 6b73ee8..90e8dd9 100644
--- a/README.md
+++ b/README.md
@@ -34,6 +34,7 @@ Six semantic levels:
 - **Symbolic Layer**: Semantic graph with typed edges (R-GCN), level tags, confidence values
 - **Neural Layer**: Learnable confidence tensors (provenance semirings), message passing with gradient flow
+- **TriFold Semantic Head**: Tropical semiring over (subject, predicate, object, center) log-scores with fold/unfold operators for semantic triple convergence
 - **Training**: Cycle consistency loss `||WHY(WHAT(x))-x||²`, information-theoretic pruning (80-85% sparsification)
 
 ## Current Phase: Phase 1 Foundation
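
A minimal sketch of the algebra the new README bullet names, using only standard
PyTorch (the tensors reuse values from the unit tests at the end of this patch):
sequential composition adds (s, p, o, c) log-scores channel-wise, and competing
hypotheses merge via channel-wise max, i.e. the max-plus tropical semiring.

    import torch

    # Two edges along one path: composition adds log-scores channel-wise.
    path = torch.tensor([[0.0, -1.0, -2.0, -0.5],   # (s, p, o, c) for edge 1
                         [0.2, -0.3, -0.1, -0.4]])  # (s, p, o, c) for edge 2
    composed = path.sum(dim=0)                       # tensor([ 0.2, -1.3, -2.1, -0.9])

    # Two competing paths: aggregation takes the channel-wise max.
    hypotheses = torch.stack([composed, torch.tensor([0.1, -0.2, -0.5, -1.0])])
    best = hypotheses.max(dim=0).values              # tensor([ 0.2, -0.2, -0.5, -0.9])
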
+""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Dict, Optional, Tuple + +import torch +from torch import Tensor +import torch.nn as nn + +from .base import BaseSemiring + + +TRIFOLD_CHANNELS = 4 # subject, predicate, object, center + + +def _validate_trifold_shape(tensor: Tensor) -> None: + """Ensure the final dimension encodes a tri-fold state.""" + + if tensor.size(-1) != TRIFOLD_CHANNELS: + raise ValueError( + f"TriFold tensors must have last dimension of size {TRIFOLD_CHANNELS}, " + f"got {tensor.size(-1)}" + ) + + +class TriFoldSemiring(BaseSemiring): + """Tropical-style semiring over tri-fold log scores.""" + + def __init__(self, zero: float = float("-inf"), one: float = 0.0): + self.zero = zero + self.one = one + + def combine( + self, + confidences: Tensor, + dim: int = -2, + mask: Optional[Tensor] = None, + keepdim: bool = False, + **kwargs, + ) -> Tensor: + """Sequential composition corresponds to addition in log-space.""" + + _validate_trifold_shape(confidences) + + if mask is not None: + mask = mask.to(confidences.dtype) + while mask.dim() < confidences.dim() - 1: + mask = mask.unsqueeze(-1) + confidences = confidences * mask + + combined = torch.sum(confidences, dim=dim, keepdim=keepdim) + return combined + + def aggregate( + self, + confidences: Tensor, + dim: int = -2, + mask: Optional[Tensor] = None, + keepdim: bool = False, + **kwargs, + ) -> Tensor: + """Aggregate competing hypotheses via component-wise maximum.""" + + _validate_trifold_shape(confidences) + + values = confidences + if mask is not None: + mask = mask.to(confidences.device) + while mask.dim() < confidences.dim() - 1: + mask = mask.unsqueeze(-1) + fill_value = torch.full_like(confidences, self.zero) + values = torch.where(mask.bool(), confidences, fill_value) + + aggregated = torch.max(values, dim=dim, keepdim=keepdim).values + return aggregated + + def element( + self, + subject: float, + predicate: float, + obj: float, + center: float, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Tensor: + """Create a tri-fold element tensor.""" + + return torch.tensor([subject, predicate, obj, center], device=device, dtype=dtype) + + def get_name(self) -> str: + return "TriFold" + + +@dataclass +class TriFoldOutput: + """Container for tri-fold reasoning outputs.""" + + states: Tensor + aggregated: Tensor + center: Tensor + loops: Tensor + fold_history: Tensor + + +class TriFoldFold(nn.Module): + """Fold operator :math:`\Phi` that accumulates loop evidence.""" + + def __init__( + self, + alpha: float = 1.0, + reduction: str = "min", + ) -> None: + super().__init__() + self.alpha = alpha + self.reduction = reduction + + reducers: Dict[str, Callable[[Tensor], Tensor]] = { + "min": lambda x: torch.min(x, dim=-1).values, + "mean": lambda x: torch.mean(x, dim=-1), + "logsumexp": lambda x: torch.logsumexp(x, dim=-1), + } + + if reduction not in reducers: + raise ValueError( + f"Unknown reduction '{reduction}'. 
diff --git a/nsm/models/confidence/trifold.py b/nsm/models/confidence/trifold.py
new file mode 100644
index 0000000..5cc1b40
--- /dev/null
+++ b/nsm/models/confidence/trifold.py
@@ -0,0 +1,269 @@
+r"""TriFold semiring and operators for neurosymbolic semantic triples.
+
+Implements the recursive triadic confidence algebra described in the
+project discussion:
+
+- ``TriFoldSemiring`` operates in log-space over 4-tuples ``(s, p, o, c)``
+  representing subject, predicate, object, and convergence scores.
+- ``TriFoldFold`` implements the folding operator :math:`\Phi` that pushes
+  loop evidence into the nexus.
+- ``TriFoldUnfold`` implements the unfolding operator :math:`\Psi` that
+  propagates nexus coherence back to each loop.
+- ``TriFoldReasoner`` orchestrates iterative fold/unfold message passing and
+  provides aggregated semantics for each graph in a batch.
+
+The implementation keeps the semiring operations distributive by
+encapsulating folding/unfolding as separate differentiable modules instead
+of changing the semiring product. This mirrors the specification in the
+"TriFold" design document and allows seamless integration with the existing
+neurosymbolic hierarchy.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Callable, Dict, Optional, Tuple
+
+import torch
+from torch import Tensor
+import torch.nn as nn
+
+from .base import BaseSemiring
+
+
+TRIFOLD_CHANNELS = 4  # subject, predicate, object, center
+
+
+def _validate_trifold_shape(tensor: Tensor) -> None:
+    """Ensure the final dimension encodes a tri-fold state."""
+
+    if tensor.size(-1) != TRIFOLD_CHANNELS:
+        raise ValueError(
+            f"TriFold tensors must have last dimension of size {TRIFOLD_CHANNELS}, "
+            f"got {tensor.size(-1)}"
+        )
+
+
+class TriFoldSemiring(BaseSemiring):
+    """Tropical-style (max-plus) semiring over tri-fold log scores."""
+
+    def __init__(self, zero: float = float("-inf"), one: float = 0.0):
+        self.zero = zero
+        self.one = one
+
+    def combine(
+        self,
+        confidences: Tensor,
+        dim: int = -2,
+        mask: Optional[Tensor] = None,
+        keepdim: bool = False,
+        **kwargs,
+    ) -> Tensor:
+        """Sequential composition corresponds to addition in log-space."""
+
+        _validate_trifold_shape(confidences)
+
+        if mask is not None:
+            mask = mask.to(confidences.dtype)
+            while mask.dim() < confidences.dim():  # broadcast over the channel dim
+                mask = mask.unsqueeze(-1)
+            confidences = confidences * mask
+
+        combined = torch.sum(confidences, dim=dim, keepdim=keepdim)
+        return combined
+
+    def aggregate(
+        self,
+        confidences: Tensor,
+        dim: int = -2,
+        mask: Optional[Tensor] = None,
+        keepdim: bool = False,
+        **kwargs,
+    ) -> Tensor:
+        """Aggregate competing hypotheses via component-wise maximum."""
+
+        _validate_trifold_shape(confidences)
+
+        values = confidences
+        if mask is not None:
+            mask = mask.to(confidences.device)
+            while mask.dim() < confidences.dim():  # broadcast over the channel dim
+                mask = mask.unsqueeze(-1)
+            fill_value = torch.full_like(confidences, self.zero)
+            values = torch.where(mask.bool(), confidences, fill_value)
+
+        aggregated = torch.max(values, dim=dim, keepdim=keepdim).values
+        return aggregated
+
+    def element(
+        self,
+        subject: float,
+        predicate: float,
+        obj: float,
+        center: float,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ) -> Tensor:
+        """Create a tri-fold element tensor."""
+
+        return torch.tensor([subject, predicate, obj, center], device=device, dtype=dtype)
+
+    def get_name(self) -> str:
+        return "TriFold"
+
+
+@dataclass
+class TriFoldOutput:
+    """Container for tri-fold reasoning outputs."""
+
+    states: Tensor        # per-node (s, p, o, c) states after fold/unfold
+    aggregated: Tensor    # component-wise max over nodes (per graph if batched)
+    center: Tensor        # aggregated convergence (nexus) channel
+    loops: Tensor         # aggregated (s, p, o) loop channels
+    fold_history: Tensor  # fold values per iteration [iterations, ...]
+
+
+class TriFoldFold(nn.Module):
+    r"""Fold operator :math:`\Phi` that accumulates loop evidence."""
+
+    def __init__(
+        self,
+        alpha: float = 1.0,
+        reduction: str = "min",
+    ) -> None:
+        super().__init__()
+        self.alpha = alpha
+        self.reduction = reduction
+
+        reducers: Dict[str, Callable[[Tensor], Tensor]] = {
+            "min": lambda x: torch.min(x, dim=-1).values,
+            "mean": lambda x: torch.mean(x, dim=-1),
+            "logsumexp": lambda x: torch.logsumexp(x, dim=-1),
+        }
+
+        if reduction not in reducers:
+            raise ValueError(
+                f"Unknown reduction '{reduction}'. Expected one of {list(reducers)}"
+            )
+
+        self._reduce = reducers[reduction]
+
+    def forward(
+        self,
+        states: Tensor,
+        mask: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        _validate_trifold_shape(states)
+
+        loops, center = states[..., :3], states[..., 3:]
+        fold_value = self._reduce(loops)
+
+        if mask is not None:
+            mask = mask.to(states.dtype)
+            while mask.dim() < fold_value.dim():
+                mask = mask.unsqueeze(-1)
+            fold_value = fold_value * mask
+
+        center = center + self.alpha * fold_value.unsqueeze(-1)
+        updated = torch.cat([loops, center], dim=-1)
+        return updated, fold_value
+
+
+class TriFoldUnfold(nn.Module):
+    r"""Unfold operator :math:`\Psi` that redistributes nexus coherence."""
+
+    def __init__(self, beta: float = 0.2) -> None:
+        super().__init__()
+        self.beta = beta
+
+    def forward(
+        self,
+        states: Tensor,
+        mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        _validate_trifold_shape(states)
+
+        loops, center = states[..., :3], states[..., 3:]
+        delta = self.beta * center
+
+        if mask is not None:
+            mask = mask.to(states.dtype)
+            while mask.dim() < delta.dim():
+                mask = mask.unsqueeze(-1)
+            delta = delta * mask
+
+        loops = loops + delta
+        return torch.cat([loops, center], dim=-1)
+
+
+class TriFoldReasoner(nn.Module):
+    """Iterative fold/unfold reasoning over tri-fold states."""
+
+    def __init__(
+        self,
+        semiring: Optional[TriFoldSemiring] = None,
+        alpha: float = 1.0,
+        beta: float = 0.2,
+        reduction: str = "min",
+    ) -> None:
+        super().__init__()
+        self.semiring = semiring or TriFoldSemiring()
+        self.fold = TriFoldFold(alpha=alpha, reduction=reduction)
+        self.unfold = TriFoldUnfold(beta=beta)
+
+    def forward(
+        self,
+        states: Tensor,
+        batch: Optional[Tensor] = None,
+        mask: Optional[Tensor] = None,
+        iterations: int = 1,
+    ) -> TriFoldOutput:
+        _validate_trifold_shape(states)
+
+        updated = states
+        history = []
+
+        for _ in range(iterations):
+            updated, fold_value = self.fold(updated, mask=mask)
+            history.append(fold_value)
+            updated = self.unfold(updated, mask=mask)
+
+        if history:
+            fold_history_tensor = torch.stack(history, dim=0)
+        else:
+            fold_history_tensor = torch.zeros(
+                (0,) + updated.shape[:-1],
+                device=updated.device,
+                dtype=updated.dtype,
+            )
+
+        aggregated = self._aggregate(updated, batch=batch, mask=mask)
+        center = aggregated[..., 3]
+        loops = aggregated[..., :3]
+
+        return TriFoldOutput(
+            states=updated,
+            aggregated=aggregated,
+            center=center,
+            loops=loops,
+            fold_history=fold_history_tensor,
+        )
+
+    def _aggregate(
+        self,
+        states: Tensor,
+        batch: Optional[Tensor] = None,
+        mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        if batch is None:
+            return self.semiring.aggregate(states, dim=-2 if states.dim() > 1 else 0, mask=mask)
+
+        unique_batches = torch.unique(batch, sorted=True)
+        aggregated_states = []
+        for idx in unique_batches.tolist():
+            batch_mask = batch == idx
+            aggregated_states.append(
+                self.semiring.aggregate(states[batch_mask], dim=0, mask=None)
+            )
+
+        return torch.stack(aggregated_states, dim=0)
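
To make the operators concrete, here is one fold/unfold cycle traced by hand,
reusing the first row of the unit tests at the end of this patch; only
`TriFoldReasoner` and its keywords come from this module:

    import torch
    from nsm.models.confidence import TriFoldReasoner

    states = torch.tensor([[0.1, 0.2, 0.3, 0.0]])   # one node: (s, p, o, c)
    reasoner = TriFoldReasoner(alpha=1.0, beta=0.5, reduction='min')
    out = reasoner(states, iterations=1)

    # Fold:   c' = c + alpha * min(s, p, o) = 0.0 + 0.1           -> 0.10
    # Unfold: each loop gains beta * c' = 0.5 * 0.1 = 0.05        -> [0.15, 0.25, 0.35]
    print(out.states)   # tensor([[0.1500, 0.2500, 0.3500, 0.1000]])
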
diff --git a/nsm/models/hierarchical.py b/nsm/models/hierarchical.py
index 0c25921..4915485 100644
--- a/nsm/models/hierarchical.py
+++ b/nsm/models/hierarchical.py
@@ -32,6 +32,7 @@
 from .pooling import SymmetricGraphPooling
 from .confidence.base import BaseSemiring
 from .confidence.examples import ProductSemiring
+from .confidence.trifold import TriFoldReasoner, TRIFOLD_CHANNELS
 
 
 class SymmetricHierarchicalLayer(nn.Module):
@@ -46,6 +47,7 @@ class SymmetricHierarchicalLayer(nn.Module):
     - Coupling layers for invertible transformations
     - Graph pooling for hierarchical coarsening
     - Semiring for confidence propagation
+    - Tri-fold semantic head for recursive subject-predicate-object reasoning
 
     Args:
         node_features (int): Node feature dimensionality
@@ -56,6 +58,12 @@ class SymmetricHierarchicalLayer(nn.Module):
         hidden_dim (int): Hidden dimension for coupling/R-GCN
         semiring (BaseSemiring, optional): Confidence propagation semiring
         dropout (float): Dropout rate for regularization
+        tri_semantics (bool): Enable tri-fold semantic triple head
+        tri_hidden_dim (int, optional): Hidden dimension for tri-fold projector
+        tri_iterations (int): Number of fold/unfold refinement steps
+        tri_fold_reduction (str): Reduction used by fold operator ('min', 'mean', 'logsumexp')
+        tri_alpha (float): Scaling factor for fold accumulation
+        tri_beta (float): Scaling factor for unfold broadcasting
 
     Example:
         >>> layer = SymmetricHierarchicalLayer(
@@ -88,7 +96,13 @@ def __init__(
         coupling_layers: int = 3,
         hidden_dim: int = 128,
         semiring: Optional[BaseSemiring] = None,
-        dropout: float = 0.1
+        dropout: float = 0.1,
+        tri_semantics: bool = True,
+        tri_hidden_dim: Optional[int] = None,
+        tri_iterations: int = 2,
+        tri_fold_reduction: str = 'min',
+        tri_alpha: float = 1.0,
+        tri_beta: float = 0.2,
     ):
         super().__init__()
@@ -149,6 +163,27 @@ def __init__(
         # Dropout for regularization
         self.dropout = nn.Dropout(dropout)
 
+        # Tri-fold semantics head
+        self.enable_trifold = tri_semantics
+        self.trifold_iterations = tri_iterations
+        if self.enable_trifold:
+            hidden = tri_hidden_dim or node_features
+            projector_layers = [
+                nn.Linear(node_features, hidden),
+                nn.ReLU(),
+                nn.Dropout(dropout),
+                nn.Linear(hidden, TRIFOLD_CHANNELS),
+            ]
+            self.trifold_projector = nn.Sequential(*projector_layers)
+            self.trifold_reasoner = TriFoldReasoner(
+                alpha=tri_alpha,
+                beta=tri_beta,
+                reduction=tri_fold_reduction,
+            )
+        else:
+            self.trifold_projector = None
+            self.trifold_reasoner = None
+
         # Layer normalization for stability
         self.norm_l1 = nn.LayerNorm(node_features)
         self.norm_l2 = nn.LayerNorm(node_features)
@@ -160,7 +195,7 @@ def why_operation(
         edge_type: Tensor,
         edge_attr: Optional[Tensor] = None,
         batch: Optional[Tensor] = None
-    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Tensor, Tensor]:
+    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor], Tensor, Tensor]:
         """WHY operation: Abstract from concrete (L1) to abstract (L2).
 
         Steps:
@@ -181,6 +216,7 @@ def why_operation(
             - x_abstract (Tensor): Abstract node features [num_pooled, node_features]
             - edge_index_abstract (Tensor): Abstract edge index [2, num_pooled_edges]
             - edge_attr_abstract (Tensor, optional): Abstract edge attributes
+            - batch_abstract (Tensor, optional): Abstract batch assignments
             - perm (Tensor): Pooling indices
             - score (Tensor): Node selection scores
         """
@@ -216,7 +252,14 @@ def why_operation(
         x_abstract = self.norm_l2(x_abstract)
         x_abstract = F.relu(x_abstract)
-        return x_abstract, edge_index_abstract, edge_attr_abstract, perm, score
+        return (
+            x_abstract,
+            edge_index_abstract,
+            edge_attr_abstract,
+            batch_abstract,
+            perm,
+            score,
+        )
 
     def what_operation(
         self,
@@ -277,23 +320,51 @@ def forward(
             - x_abstract: Abstract representations
             - x_reconstructed: Reconstructed concrete features (if return_cycle_loss)
             - cycle_loss: Reconstruction error (if return_cycle_loss)
+            - batch_abstract: Batch assignments for abstract nodes
             - perm: Pooling indices
             - score: Node selection scores
+            - tri_fold_states: Per-node tri-fold states (if enabled)
+            - tri_fold_summary: Aggregated tri-fold summary per graph (if enabled)
+            - tri_fold_center: Nexus coherence scores (if enabled)
+            - tri_fold_loops: Aggregated loop scores (if enabled)
+            - tri_fold_history: Fold operator values per iteration (if enabled)
         """
         original_num_nodes = x.size(0)
 
         # WHY operation
-        x_abstract, edge_index_abstract, edge_attr_abstract, perm, score = \
-            self.why_operation(x, edge_index, edge_type, edge_attr, batch)
+        (
+            x_abstract,
+            edge_index_abstract,
+            edge_attr_abstract,
+            batch_abstract,
+            perm,
+            score,
+        ) = self.why_operation(x, edge_index, edge_type, edge_attr, batch)
 
         result = {
             'x_abstract': x_abstract,
             'edge_index_abstract': edge_index_abstract,
             'edge_attr_abstract': edge_attr_abstract,
+            'batch_abstract': batch_abstract,
             'perm': perm,
             'score': score
         }
 
+        if self.enable_trifold:
+            tri_input = self.trifold_projector(x_abstract)
+            tri_result = self.trifold_reasoner(
+                tri_input,
+                batch=batch_abstract,
+                iterations=self.trifold_iterations,
+            )
+            result.update({
+                'tri_fold_states': tri_result.states,
+                'tri_fold_summary': tri_result.aggregated,
+                'tri_fold_center': tri_result.center,
+                'tri_fold_loops': tri_result.loops,
+                'tri_fold_history': tri_result.fold_history,
+            })
+
         if return_cycle_loss:
             # WHAT operation
             x_reconstructed = self.what_operation(
@@ -350,7 +421,9 @@ def __repr__(self) -> str:
                 f'  num_relations={self.num_relations},\n'
                 f'  num_bases={self.num_bases},\n'
                 f'  pool_ratio={self.pool_ratio:.2f},\n'
-                f'  semiring={self.semiring.get_name()}\n'
+                f'  semiring={self.semiring.get_name()},\n'
+                f'  tri_semantics={self.enable_trifold},\n'
+                f'  tri_iterations={self.trifold_iterations}\n'
                 f')')
@@ -470,7 +543,9 @@ def forward(
         if batch is not None:
             # Batch-wise global pooling
             from torch_geometric.nn import global_mean_pool
-            batch_abstract = batch[result['perm']]
+            batch_abstract = result.get('batch_abstract')
+            if batch_abstract is None and result.get('perm') is not None:
+                batch_abstract = batch[result['perm']]
             x_graph = global_mean_pool(x_abstract, batch_abstract)
         else:
             # Single graph: mean pooling
@@ -484,7 +559,9 @@ def forward(
         if batch is not None:
             # Batch-wise global pooling
             from torch_geometric.nn import global_mean_pool
-            batch_abstract = batch[result['perm']]
+            batch_abstract = result.get('batch_abstract')
+            if batch_abstract is None and result.get('perm') is not None:
+                batch_abstract = batch[result['perm']]
             x_graph = global_mean_pool(x_abstract, batch_abstract)
         else:
             # Single graph: mean pooling
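
How the head surfaces in the layer's outputs, as a sketch only: the layer's
constructor and forward signatures are partially outside this diff, so the
sizes and positional arguments below are assumptions; only the tri_* keywords
and the tri_fold_* result keys are defined by the patch.

    import torch
    from nsm.models.hierarchical import SymmetricHierarchicalLayer

    # Illustrative sizes; positional arguments assumed from the visible docstring.
    layer = SymmetricHierarchicalLayer(
        node_features=64,
        num_relations=4,
        tri_semantics=True,
        tri_iterations=2,
        tri_fold_reduction='min',
    )
    x = torch.randn(10, 64)
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
    edge_type = torch.zeros(4, dtype=torch.long)
    result = layer(x, edge_index, edge_type)
    print(result['tri_fold_center'])   # nexus coherence for the pooled graph
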
diff --git a/tests/models/confidence/test_trifold.py b/tests/models/confidence/test_trifold.py
new file mode 100644
index 0000000..fce1da4
--- /dev/null
+++ b/tests/models/confidence/test_trifold.py
@@ -0,0 +1,66 @@
+import torch
+
+from nsm.models.confidence import TriFoldSemiring, TriFoldReasoner
+
+
+def test_trifold_semiring_combine_adds_channels():
+    semiring = TriFoldSemiring()
+    path = torch.tensor([
+        [0.0, -1.0, -2.0, -0.5],
+        [0.2, -0.3, -0.1, -0.4],
+    ])
+    combined = semiring.combine(path, dim=0)
+    expected = torch.tensor([0.2, -1.3, -2.1, -0.9])
+    assert torch.allclose(combined, expected, atol=1e-5)
+
+
+def test_trifold_semiring_aggregate_max_componentwise():
+    semiring = TriFoldSemiring()
+    candidates = torch.tensor([
+        [0.1, -0.2, -0.5, -1.0],
+        [-0.3, 0.4, -0.1, -0.2],
+        [0.0, -0.1, 0.2, -0.3],
+    ])
+    aggregated = semiring.aggregate(candidates, dim=0)
+    expected = torch.tensor([0.1, 0.4, 0.2, -0.2])
+    assert torch.allclose(aggregated, expected, atol=1e-5)
+
+
+def test_trifold_reasoner_fold_unfold_cycle():
+    states = torch.tensor([
+        [0.1, 0.2, 0.3, 0.0],
+        [0.3, 0.1, 0.0, -0.5],
+    ])
+    reasoner = TriFoldReasoner(alpha=1.0, beta=0.5, reduction='min')
+    output = reasoner(states, iterations=1)
+
+    expected_states = torch.tensor([
+        [0.15, 0.25, 0.35, 0.1],
+        [0.05, -0.15, -0.25, -0.5],
+    ])
+    assert torch.allclose(output.states, expected_states, atol=1e-5)
+    assert torch.allclose(output.aggregated, torch.tensor([0.15, 0.25, 0.35, 0.1]), atol=1e-5)
+    assert torch.isclose(output.center, torch.tensor(0.1), atol=1e-5)
+    assert torch.allclose(output.loops, torch.tensor([0.15, 0.25, 0.35]), atol=1e-5)
+    assert torch.allclose(output.fold_history.squeeze(0), torch.tensor([0.1, 0.0]), atol=1e-5)
+
+
+def test_trifold_reasoner_batch_aggregation():
+    states = torch.tensor([
+        [0.0, 0.0, 0.0, 0.0],
+        [0.5, 0.2, -0.1, -0.2],
+        [0.1, 0.3, 0.4, 0.5],
+    ])
+    batch = torch.tensor([0, 0, 1])
+    reasoner = TriFoldReasoner(alpha=0.0, beta=0.0)
+    output = reasoner(states, batch=batch, iterations=0)
+
+    assert output.fold_history.shape[0] == 0
+    assert output.aggregated.shape == (2, 4)
+    expected = torch.tensor([
+        [0.5, 0.2, 0.0, 0.0],
+        [0.1, 0.3, 0.4, 0.5],
+    ])
+    assert torch.allclose(output.aggregated, expected, atol=1e-5)
+    assert torch.allclose(output.center, torch.tensor([0.0, 0.5]), atol=1e-5)
+    assert torch.allclose(output.loops, expected[:, :3], atol=1e-5)
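
The tests cover composition, aggregation, and the fold/unfold cycle. The
distributivity property the module docstring claims to preserve can also be
spot-checked directly, since max-plus satisfies a + max(b, c) = max(a + b, a + c)
channel-wise; a minimal check with arbitrary values:

    import torch

    a = torch.tensor([0.2, -0.3, -0.1, -0.4])
    b = torch.tensor([0.1, -0.2, -0.5, -1.0])
    c = torch.tensor([0.0, -0.1, 0.2, -0.3])

    lhs = a + torch.maximum(b, c)            # combine distributes over aggregate
    rhs = torch.maximum(a + b, a + c)
    assert torch.allclose(lhs, rhs)
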