Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions nsm/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,8 @@ def generate_triples(self) -> List[SemanticTriple]:
predicate = f"pred_{torch.randint(0, self.num_predicates, (1,)).item()}"
obj = f"entity_{torch.randint(0, self.num_entities, (1,)).item()}"

# Random confidence and level
confidence = torch.rand(1).item()
# Random confidence log-scores and level
confidence = torch.log_softmax(torch.randn(4), dim=0)
level = torch.randint(1, self.num_levels + 1, (1,)).item()

triple = SemanticTriple(
Expand Down
66 changes: 50 additions & 16 deletions nsm/data/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
for graph neural network processing.
"""

from typing import List, Optional, Tuple, Dict
from typing import Dict, List, Optional, Sequence, Tuple, Union
import torch
from torch import Tensor
from torch_geometric.data import Data
Expand Down Expand Up @@ -36,7 +36,7 @@ class GraphConstructor:
>>> constructor = GraphConstructor()
>>> graph = constructor.construct(triples)
>>> print(graph)
Data(x=[3, 64], edge_index=[2, 2], edge_attr=[2, 1], ...)
Data(x=[3, 64], edge_index=[2, 2], edge_attr=[2, 4], ...)
"""

def __init__(
Expand Down Expand Up @@ -75,7 +75,7 @@ def construct(
- x: Node features [num_nodes, feat_dim]
- edge_index: Edge connectivity [2, num_edges]
- edge_attr: Edge attributes including:
- confidence: [num_edges, 1]
- confidence: [num_edges, 4] log-score tensors
- edge_type: [num_edges] (predicate indices)
- num_nodes: Total number of unique entities
- num_edges: Total number of triples
Expand All @@ -91,14 +91,14 @@ def construct(
return Data(
x=torch.zeros(0, self.node_feature_dim),
edge_index=torch.zeros(2, 0, dtype=torch.long),
edge_attr=torch.zeros(0, 1)
edge_attr=torch.zeros(0, 4)
)

# Build vocabulary from triples
node_indices = {} # entity -> node_idx mapping
edge_list = [] # List of (src, dst) tuples
edge_types = [] # List of predicate indices
confidences = [] # List of confidence scores
confidence_vectors = [] # List of [4] confidence tensors
node_levels = {} # entity -> level mapping

current_node_idx = 0
Expand Down Expand Up @@ -128,7 +128,7 @@ def construct(
# Create edge from subject to object
edge_list.append((node_indices[triple.subject], node_indices[triple.object]))
edge_types.append(pred_idx)
confidences.append(triple.confidence)
confidence_vectors.append(triple.get_confidence_tensor())

# Convert to tensors
num_nodes = len(node_indices)
Expand All @@ -139,8 +139,8 @@ def construct(
# Edge types for R-GCN
edge_type = torch.tensor(edge_types, dtype=torch.long)

# Confidence scores
confidence = torch.tensor(confidences, dtype=torch.float32).unsqueeze(1)
# Confidence score tensors [num_edges, 4]
confidence = torch.stack(confidence_vectors, dim=0)

# Node features
if node_features is None:
Expand Down Expand Up @@ -217,13 +217,18 @@ def construct_hierarchical(

return concrete_graph, abstract_graph

def add_self_loops(self, data: Data, self_loop_weight: float = 1.0) -> Data:
def add_self_loops(
self,
data: Data,
self_loop_weight: Union[float, Sequence[float], Tensor] = 1.0
) -> Data:
"""
Add self-loops to graph (useful for GNN message passing).

Args:
data: PyG Data object
self_loop_weight: Confidence score for self-loops
self_loop_weight: Confidence score for self-loops. Accepts
scalar or length-4 iterable/tensor.

Returns:
Data object with self-loops added
Expand All @@ -237,11 +242,7 @@ def add_self_loops(self, data: Data, self_loop_weight: float = 1.0) -> Data:
data.edge_index = torch.cat([data.edge_index, self_loop_index], dim=1)

# Add self-loop attributes
self_loop_attr = torch.full(
(num_nodes, 1),
self_loop_weight,
dtype=torch.float32
)
self_loop_attr = self._expand_self_loop_weight(num_nodes, self_loop_weight)
data.edge_attr = torch.cat([data.edge_attr, self_loop_attr], dim=0)

# Add self-loop edge types (use special index for self-loops)
Expand All @@ -256,6 +257,39 @@ def add_self_loops(self, data: Data, self_loop_weight: float = 1.0) -> Data:

return data

@staticmethod
def _expand_self_loop_weight(
num_nodes: int,
weight: Union[float, Sequence[float], Tensor]
) -> Tensor:
"""Expand self-loop weight(s) to [num_nodes, 4] tensor."""
if isinstance(weight, Tensor):
tensor = weight.detach().clone().to(dtype=torch.float32)
flat = tensor.reshape(-1)
if flat.numel() == 1:
value = float(flat.item())
return torch.full((num_nodes, 4), value, dtype=torch.float32)
if flat.numel() != 4:
raise ValueError(
"Self-loop weight tensor must have 1 or 4 elements"
)
return flat.reshape(1, 4).repeat(num_nodes, 1)

if isinstance(weight, Sequence):
values = list(weight)
tensor = torch.tensor(values, dtype=torch.float32)
flat = tensor.reshape(-1)
if flat.numel() == 1:
return GraphConstructor._expand_self_loop_weight(num_nodes, flat.item())
if flat.numel() != 4:
raise ValueError(
"Self-loop weight sequence must have length 1 or 4"
)
return flat.reshape(1, 4).repeat(num_nodes, 1)

value = float(weight)
return torch.full((num_nodes, 4), value, dtype=torch.float32)

def batch_construct(
self,
triple_lists: List[List[SemanticTriple]]
Expand Down Expand Up @@ -310,7 +344,7 @@ def visualize_graph_structure(data: Data) -> str:
Nodes: 5
Edges: 8
Node features: [5, 64]
Edge attributes: [8, 1]
Edge attributes: [8, 4]
"""
info = [
"Graph Structure:",
Expand Down
133 changes: 111 additions & 22 deletions nsm/data/triple.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""

from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Sequence, Tuple, Union
import torch
from torch import Tensor

Expand All @@ -25,7 +25,8 @@ class SemanticTriple:
subject: The subject entity (string identifier or index)
predicate: The relationship type (string identifier or index)
object: The object entity (string identifier or index)
confidence: Learnable confidence score in [0, 1], default 1.0
confidence: Learnable confidence score. Accepts scalar in [0, 1]
or length-4 log-score tensor/list.
level: Hierarchical level (1=concrete, 2=abstract for Phase 1)
metadata: Additional information (provenance, timestamp, etc.)

Expand Down Expand Up @@ -61,48 +62,121 @@ class SemanticTriple:
subject: Union[str, int]
predicate: Union[str, int]
object: Union[str, int]
confidence: float = 1.0
confidence: Union[float, Sequence[float], Tensor] = 1.0
level: int = 1
metadata: Dict[str, Any] = field(default_factory=dict)
_confidence_tensor: Tensor = field(init=False, repr=False, compare=False)
_confidence_is_scalar: bool = field(init=False, repr=False, compare=False)

def __post_init__(self):
"""Validate triple attributes after initialization."""
# Validate confidence in [0, 1]
if not 0.0 <= self.confidence <= 1.0:
raise ValueError(
f"Confidence must be in [0, 1], got {self.confidence}"
)
self._set_confidence(self.confidence)

# Validate level (Phase 1: 1-2, Phase 2+: 1-6)
if not 1 <= self.level <= 6:
raise ValueError(
f"Level must be in [1, 6], got {self.level}"
)

def _set_confidence(
self,
confidence_value: Union[float, Sequence[float], Tensor]
) -> None:
"""Normalize confidence input into scalar + 4-vector tensor."""
tensor, is_scalar, scalar_value = self._coerce_confidence_tensor(
confidence_value
)

self._confidence_tensor = tensor
self._confidence_is_scalar = is_scalar

if is_scalar:
self.confidence = scalar_value
else:
# Aggregate vector confidence into scalar summary for legacy paths.
probs = torch.softmax(tensor, dim=0)
self.confidence = float(probs.max().item())

@staticmethod
def _coerce_confidence_tensor(
confidence_value: Union[float, Sequence[float], Tensor]
) -> Tuple[Tensor, bool, float]:
"""Convert confidence inputs into a standardized tensor."""
if isinstance(confidence_value, Tensor):
tensor = confidence_value.detach().clone().to(dtype=torch.float32)
flat = tensor.reshape(-1)
if flat.numel() == 1:
scalar = float(flat.item())
if not 0.0 <= scalar <= 1.0:
raise ValueError(
f"Confidence must be in [0, 1], got {scalar}"
)
vector = torch.full((4,), scalar, dtype=torch.float32)
return vector, True, scalar
if flat.numel() != 4:
raise ValueError(
"Confidence tensor must have 4 elements, "
f"got shape {tuple(tensor.shape)}"
)
return flat.reshape(4), False, float('nan')

if isinstance(confidence_value, (int, float)):
scalar = float(confidence_value)
if not 0.0 <= scalar <= 1.0:
raise ValueError(
f"Confidence must be in [0, 1], got {scalar}"
)
vector = torch.full((4,), scalar, dtype=torch.float32)
return vector, True, scalar

if isinstance(confidence_value, Sequence):
values = list(confidence_value)
tensor = torch.tensor(values, dtype=torch.float32)
flat = tensor.reshape(-1)
if flat.numel() == 1:
return SemanticTriple._coerce_confidence_tensor(flat.item())
if flat.numel() != 4:
raise ValueError(
"Confidence sequence must contain exactly 4 values, "
f"got {len(values)}"
)
return flat.reshape(4), False, float('nan')

raise TypeError(
"Confidence must be a float, Tensor, or sequence of floats."
)

def uses_confidence_vector(self) -> bool:
"""Return True if the triple stores a 4-channel confidence vector."""
return not self._confidence_is_scalar

def get_confidence_tensor(self) -> Tensor:
"""Return the 4-element confidence tensor for this triple."""
return self._confidence_tensor.clone()

def to_tensor(self) -> Tensor:
"""
Convert confidence to a PyTorch tensor.
Convert confidence representation to a PyTorch tensor.

Returns:
Tensor: Scalar tensor containing confidence value
Tensor: Shape [4] tensor containing confidence values/log-scores
"""
return torch.tensor(self.confidence, dtype=torch.float32)
return self.get_confidence_tensor()

def update_confidence(self, new_confidence: float) -> None:
def update_confidence(
self,
new_confidence: Union[float, Sequence[float], Tensor]
) -> None:
"""
Update the confidence score.

Args:
new_confidence: New confidence value in [0, 1]
new_confidence: New confidence value(s)

Raises:
ValueError: If new_confidence not in [0, 1]
ValueError: If new_confidence has invalid value or shape
"""
if not 0.0 <= new_confidence <= 1.0:
raise ValueError(
f"Confidence must be in [0, 1], got {new_confidence}"
)
self.confidence = new_confidence
self._set_confidence(new_confidence)

def is_concrete(self) -> bool:
"""Check if this triple represents concrete actions/environment."""
Expand Down Expand Up @@ -222,14 +296,29 @@ def get_unique_predicates(self) -> set[Union[str, int]]:
"""Get set of all unique predicate types."""
return {t.predicate for t in self.triples}

def get_confidence_tensor(self) -> Tensor:
def get_confidence_tensor(self, *, as_vector: bool = False) -> Tensor:
"""
Get confidence scores as a tensor.

Args:
as_vector: If True, returns stacked [num_triples, 4] tensors.
If False (default), returns scalar confidences.

Returns:
Tensor: Shape [num_triples] containing all confidence scores
Tensor containing confidence information.
- Shape [num_triples] when as_vector is False
- Shape [num_triples, 4] when as_vector is True
"""
confidences = [t.confidence for t in self.triples]
if not self.triples:
if as_vector:
return torch.zeros(0, 4, dtype=torch.float32)
return torch.zeros(0, dtype=torch.float32)

if as_vector:
tensors = [t.get_confidence_tensor() for t in self.triples]
return torch.stack(tensors, dim=0)

confidences = [float(t.confidence) for t in self.triples]
return torch.tensor(confidences, dtype=torch.float32)

def __len__(self) -> int:
Expand Down
Loading
Loading