diff --git a/.gitignore b/.gitignore index a98720b..96b0498 100644 --- a/.gitignore +++ b/.gitignore @@ -43,6 +43,7 @@ env/ *.pt *.ckpt checkpoints/ +checkpoints/*/ # All subdirectories *.pkl # Jupyter @@ -102,8 +103,15 @@ datasets/ !configs/*.yaml !configs/*.yml +# Domain-specific data (generated at runtime) +data/causal/ +data/kg/ +data/planning/ +!data/*/.gitkeep + # Results & Outputs results/ +results/*/ # All subdirectories outputs/ figures/ plots/ @@ -131,4 +139,15 @@ Desktop.ini .git/worktrees/ # Keep empty directories -!.gitkeep \ No newline at end of file +!.gitkeep + +# Branch-specific summary documents (NSM-26) +*_DATASET_SUMMARY.md +*_SUMMARY.md +*_ANALYSIS.md +*_FINDINGS.md +*_STATUS.md + +# Auto-generated scripts (NSM-26) +experiments/run_*.sh +experiments/training_log.jsonl \ No newline at end of file diff --git a/KG_IMPLEMENTATION_SUMMARY.md b/KG_IMPLEMENTATION_SUMMARY.md new file mode 100644 index 0000000..d9ee56a --- /dev/null +++ b/KG_IMPLEMENTATION_SUMMARY.md @@ -0,0 +1,245 @@ +# Knowledge Graph Dataset Implementation Summary + +## Overview + +Successfully implemented a comprehensive Knowledge Graph dataset generator for NSM Phase 1 exploration (NSM-10). This is one of three parallel dataset explorations to empirically validate the best domain for 2-level hierarchical reasoning. + +## What Was Implemented + +### 1. KnowledgeGraphTripleDataset (`nsm/data/knowledge_graph_dataset.py`) + +**Core Features:** +- Generates 20K synthetic triples across 5K entities +- 50+ predicate types spanning biographical, geographic, creative, and conceptual relations +- 2-level hierarchy: L1 (facts/instances) and L2 (types/categories) +- Confidence scores 0.5-1.0 for partial observability +- 6 entity categories: People, Places, Organizations, Concepts, Awards, Dates + +**Entity Generation:** +- Named entities: Einstein, Curie, Newton, Paris, London, MIT, Harvard, etc. +- Rich biographical data: born_in, works_at, won, studied_at +- Geographic hierarchies: city → country → continent +- Type hierarchies: Person → Living_Being → Entity + +**Reasoning Support:** +- Multi-hop query generation (2-hop paths) +- Type consistency checking pairs +- Analogical reasoning support +- Link prediction labels + +### 2. Comprehensive Tests (`tests/data/test_kg_dataset.py`) + +**21 Test Cases:** +- Dataset generation and initialization +- Triple structure validation +- Level distribution (L1 vs L2) +- Confidence variance and diversity +- Predicate type coverage (50+) +- Entity diversity and categories +- Named entity inclusion +- Multi-hop reasoning paths +- Type hierarchy validation +- PyG interface compliance +- Caching and reproducibility + +**Test Results:** +- ✅ 21/21 tests passing +- ✅ 98% code coverage +- ✅ Reproducibility verified +- ✅ PyG DataLoader compatible + +### 3. Example Script (`examples/knowledge_graph_example.py`) + +**16 Demonstration Sections:** +1. Dataset creation and statistics +2. Sample triples (L1 and L2) +3. Predicate type distribution +4. Entity category breakdown +5. Named entity examples +6. PyG graph structure +7. Multi-hop reasoning queries +8. Type consistency pairs +9. Biographical reasoning chains +10. Type hierarchy display +11. Confidence distribution +12. PyG DataLoader batching +13. Reasoning pattern examples +14. Instance-of relations +15. Geographic chains +16. Professional relations + +### 4. 
Evaluation Metrics (`nsm/evaluation/kg_metrics.py`) + +**Metrics Implemented:** +- **Link Prediction:** Hits@K, MRR, Mean/Median Rank +- **Analogical Reasoning:** A:B :: C:D with vector arithmetic +- **Type Consistency:** Precision, Recall, F1, Confusion Matrix +- **Multi-hop Reasoning:** Exact match, Hits@K, Average Precision +- **Calibration:** ECE, MCE, Calibration Curves + +## Design Decisions + +### 1. Entity-Centric Knowledge Representation +**Rationale:** Knowledge graphs excel at entity relationships and type hierarchies, making them ideal for testing hierarchical abstraction in NSM. + +### 2. 50+ Predicate Types +**Rationale:** Rich relation vocabulary enables diverse reasoning patterns and tests R-GCN's basis decomposition (NSM-17). + +### 3. Confidence Variance (0.5-1.0) +**Rationale:** Partial observability tests NSM's confidence propagation (NSM-12) and provenance semiring implementation. + +### 4. Named Entity Inclusion +**Rationale:** Real-world entities (Einstein, Paris) make debugging and interpretation easier during development. + +### 5. Reproducible Generation with Seeds +**Rationale:** Essential for comparing across exploration branches (NSM-10, NSM-12, NSM-11). + +## Integration Points + +### With NSM-18 (PyG Infrastructure): +- ✅ Extends `BaseSemanticTripleDataset` +- ✅ Uses `GraphConstructor` for graph building +- ✅ Compatible with `TripleVocabulary` +- ✅ Returns PyG `Data` objects + +### With NSM-17 (R-GCN): +- ✅ Edge types for 50+ predicates +- ✅ Confidence as edge attributes +- ✅ Typed relations ready for basis decomposition + +### With NSM-12 (Confidence Exploration): +- ✅ Wide confidence range (0.5-1.0) +- ✅ Product semiring evaluation ready +- ✅ Calibration metrics implemented + +### With NSM-14 (Training Loop): +- ✅ Link prediction labels +- ✅ Batch loading compatible +- ✅ Evaluation metrics ready + +## Testing Results + +``` +======================== 21 passed, 3 warnings in 4.43s ======================== + +Coverage: + nsm/data/knowledge_graph_dataset.py: 98% + nsm/data/dataset.py: 69% + +Key Metrics: + - 5000 triples generated + - 1298+ unique entities + - 66 predicates (50+ expected, extras from random generation) + - L1/L2 ratio: ~87%/13% (facts vs types) + - Average confidence: 0.77 +``` + +## Commits Made + +1. **4441471** - Implement KnowledgeGraphTripleDataset for relational reasoning + - Core dataset class with entity/predicate generation + - Multi-hop query support + - Type hierarchy implementation + - Fix for PyTorch 2.6 weights_only parameter + +2. **2b5f1f2** - Add comprehensive tests for KnowledgeGraphTripleDataset + - 21 test cases covering all functionality + - 98% code coverage + - Reproducibility and caching tests + +3. **467d2ff** - Add knowledge graph dataset example and visualization + - 16 demonstration sections + - Reasoning chain examples + - PyG DataLoader integration + +4. **46cdacc** - Add evaluation metrics for knowledge graph reasoning + - Link prediction (Hits@K, MRR) + - Analogical reasoning + - Type consistency checking + - Calibration metrics (ECE/MCE) + +## Next Steps for NSM-10 Evaluation + +### Comparison Criteria (from CLAUDE.md): +1. **Task accuracy (40%):** Link prediction, type inference +2. **Calibration (20%):** ECE on confidence scores +3. **Multi-hop (20%):** 2-hop reasoning accuracy +4. 
**Interpretability (20%):** Debugging and explainability + +### Evaluation Protocol: +```bash +# Run evaluation suite +python -m tests.evaluation_suite --dataset knowledge_graph --output results/kg.json + +# Compare with other branches +python compare_results.py results/kg.json results/planning.json results/causal.json +``` + +### Expected Strengths: +- ✅ Rich predicate diversity (50+ types) +- ✅ Clear type hierarchies (instance_of, subclass_of) +- ✅ Multi-hop paths (2-hop queries) +- ✅ Entity-centric interpretability + +### Potential Weaknesses: +- ⚠️ Less hierarchical structure than planning domain +- ⚠️ May need deeper hierarchies for full NSM evaluation +- ⚠️ Random relation generation may create noise + +## Files Changed + +``` +nsm/data/knowledge_graph_dataset.py (new, 682 lines) +nsm/data/dataset.py (modified, +1 line for weights_only fix) +tests/data/test_kg_dataset.py (new, 394 lines) +examples/knowledge_graph_example.py (new, 216 lines) +nsm/evaluation/__init__.py (new, 13 lines) +nsm/evaluation/kg_metrics.py (new, 449 lines) +``` + +**Total:** 1,755 lines of new code + +## Domain Properties + +### Level 1 (Facts/Instances): +- **Biographical:** born_in, died_in, works_at, studied_at, won +- **Geographic:** located_in, capital_of, borders, adjacent_to +- **Creative:** created, authored, composed, designed, invented +- **Professional:** employed_by, founded, leads, member_of +- **Temporal:** occurred_in, started_on, ended_on + +### Level 2 (Types/Categories): +- **Type hierarchy:** instance_of, subclass_of, category_of +- **Generalizations:** typically_has, usually_in, commonly_has +- **Abstract:** related_to, similar_to, implies, requires, enables + +### Mathematical Foundation: +``` +Knowledge Graph G = (E, R, T) where: +- E: Set of entities (5K) +- R: Set of typed relations (50+) +- T ⊆ E × R × E: Set of triples (20K) +- L: Level function L: T → {1, 2} +- C: Confidence function C: T → [0.5, 1.0] +``` + +## Conclusion + +✅ **Implementation Complete** +- Fully functional KG dataset generator +- Comprehensive test coverage (21/21 passing) +- Rich evaluation metrics +- Ready for NSM-10 parallel exploration + +✅ **NSM-18 Integration Verified** +- Compatible with BaseSemanticTripleDataset +- PyG Data objects working +- Vocabulary and graph construction validated + +✅ **Ready for Evaluation** +- Evaluation metrics implemented +- Comparison protocol defined +- Documentation complete + +**Branch:** dataset-knowledge-graph +**Status:** ✅ Ready for evaluation and PR (once NSM-10 exploration complete) diff --git a/examples/knowledge_graph_example.py b/examples/knowledge_graph_example.py new file mode 100644 index 0000000..8a79f7d --- /dev/null +++ b/examples/knowledge_graph_example.py @@ -0,0 +1,216 @@ +""" +Knowledge Graph Dataset Example + +Demonstrates loading and using the KnowledgeGraphTripleDataset +for relational reasoning tasks. +""" + +import torch +from pathlib import Path + +from nsm.data.knowledge_graph_dataset import KnowledgeGraphTripleDataset +from nsm.data.graph import visualize_graph_structure + + +def main(): + """Run Knowledge Graph dataset examples.""" + print("=" * 80) + print("Knowledge Graph Dataset Example") + print("=" * 80) + + # Create dataset + print("\n1. Creating Knowledge Graph Dataset...") + dataset = KnowledgeGraphTripleDataset( + root="data/kg_example", + split='train', + num_entities=1000, + num_triples=5000, + seed=42 + ) + print(f" Dataset created: {dataset}") + + # Display statistics + print("\n2. 
Dataset Statistics:") + stats = dataset.get_statistics() + for key, value in stats.items(): + print(f" {key}: {value}") + + # Show sample triples + print("\n3. Sample Triples (First 10):") + print(" " + "-" * 76) + for i, triple in enumerate(dataset.triples[:10]): + print(f" {i+1}. ({triple.subject}, {triple.predicate}, {triple.object})") + print(f" Confidence: {triple.confidence:.3f}, Level: {triple.level}") + + # Show Level 1 (facts) examples + print("\n4. Level 1 Triples (Facts/Instances):") + print(" " + "-" * 76) + level1_triples = [t for t in dataset.triples if t.level == 1][:10] + for i, triple in enumerate(level1_triples): + print(f" {i+1}. {triple.subject} --[{triple.predicate}]--> {triple.object}") + print(f" Confidence: {triple.confidence:.3f}") + + # Show Level 2 (types) examples + print("\n5. Level 2 Triples (Types/Categories):") + print(" " + "-" * 76) + level2_triples = [t for t in dataset.triples if t.level == 2][:10] + for i, triple in enumerate(level2_triples): + print(f" {i+1}. {triple.subject} --[{triple.predicate}]--> {triple.object}") + print(f" Confidence: {triple.confidence:.3f}") + + # Show predicate diversity + print("\n6. Predicate Types:") + predicates = {} + for triple in dataset.triples: + pred = triple.predicate + level = triple.level + key = f"L{level}: {pred}" + predicates[key] = predicates.get(key, 0) + 1 + + # Show top 20 predicates + sorted_preds = sorted(predicates.items(), key=lambda x: x[1], reverse=True)[:20] + for pred, count in sorted_preds: + print(f" {pred}: {count} triples") + + # Entity categories + print("\n7. Entity Categories:") + print(f" People: {len(dataset.people)}") + print(f" Places: {len(dataset.places)}") + print(f" Organizations: {len(dataset.organizations)}") + print(f" Concepts: {len(dataset.concepts)}") + print(f" Awards: {len(dataset.awards)}") + print(f" Total unique entities: {len(dataset.entities)}") + + # Show some famous entities + print("\n8. Sample Named Entities:") + all_entities = set() + for triple in dataset.triples: + all_entities.add(triple.subject) + all_entities.add(triple.object) + + famous_people = [p for p in dataset.PERSON_NAMES if p in all_entities][:5] + famous_places = [p for p in dataset.PLACES if p in all_entities][:5] + + print(f" People: {', '.join(famous_people)}") + print(f" Places: {', '.join(famous_places)}") + + # Get and visualize a graph + print("\n9. PyTorch Geometric Graph (Sample):") + graph, label = dataset[0] + print(visualize_graph_structure(graph)) + print(f" Label (confidence): {label.item():.3f}") + + # Multi-hop reasoning queries + print("\n10. Multi-hop Reasoning Queries (Sample):") + print(" " + "-" * 74) + queries = dataset.get_multi_hop_queries(num_queries=5) + for i, query in enumerate(queries, 1): + print(f" Query {i}:") + print(f" Start: {query['start_entity']}") + print(f" Path: {' -> '.join(query['relations'])}") + if 'intermediate' in query: + print(f" Via: {query['intermediate']}") + print(f" Answer: {query['expected_answer']}") + print() + + # Type consistency checking + print("\n11. Type Consistency Pairs (Sample):") + print(" " + "-" * 74) + pairs = dataset.get_type_consistency_pairs(num_pairs=10) + for i, (entity, entity_type, is_consistent) in enumerate(pairs, 1): + status = "✓ VALID" if is_consistent else "✗ INVALID" + print(f" {i}. {entity} : {entity_type} -> {status}") + + # Show biographical chain example + print("\n12. 
Example Biographical Reasoning Chain:") + print(" " + "-" * 74) + # Find a person with multiple relations + person_triples = {} + for triple in dataset.triples: + if triple.subject in dataset.people: + if triple.subject not in person_triples: + person_triples[triple.subject] = [] + person_triples[triple.subject].append(triple) + + # Find person with rich biography + for person, triples in list(person_triples.items())[:5]: + if len(triples) >= 3: + print(f" Person: {person}") + for triple in triples[:5]: + print(f" {triple.predicate} -> {triple.object} (conf: {triple.confidence:.3f})") + break + + # Type hierarchy example + print("\n13. Type Hierarchy:") + print(" " + "-" * 74) + hierarchy = dataset.type_hierarchy + for child, parent in list(hierarchy.items())[:10]: + print(f" {child} --[subclass_of]--> {parent}") + + # Confidence distribution + print("\n14. Confidence Score Distribution:") + confidences = [t.confidence for t in dataset.triples] + ranges = [ + (0.5, 0.6, "0.5-0.6"), + (0.6, 0.7, "0.6-0.7"), + (0.7, 0.8, "0.7-0.8"), + (0.8, 0.9, "0.8-0.9"), + (0.9, 1.0, "0.9-1.0"), + ] + + for low, high, label in ranges: + count = sum(1 for c in confidences if low <= c <= high) + pct = 100 * count / len(confidences) + bar = "█" * int(pct / 2) + print(f" {label}: {bar} {pct:.1f}% ({count} triples)") + + # PyG Data Loader example + print("\n15. PyTorch Geometric DataLoader Example:") + from torch_geometric.loader import DataLoader + + # Create small batch + loader = DataLoader(dataset, batch_size=4, shuffle=True) + batch = next(iter(loader)) + print(f" Batch of {batch.num_graphs} graphs:") + print(f" Total nodes: {batch.num_nodes}") + print(f" Total edges: {batch.edge_index.size(1)}") + print(f" Node features: {batch.x.shape}") + print(f" Edge features: {batch.edge_attr.shape}") + + # Show specific reasoning patterns + print("\n16. Specific Reasoning Patterns:") + print(" " + "-" * 74) + + # Find instance-of relations + instance_of = [t for t in dataset.triples if t.predicate == "instance_of"][:5] + print(" a) Instance-of Relations:") + for triple in instance_of: + print(f" {triple.subject} is a {triple.object}") + + # Find born_in + located_in chains + print("\n b) Geographic Chains:") + born_in = {t.subject: t.object for t in dataset.triples if t.predicate == "born_in"} + located_in = {t.subject: t.object for t in dataset.triples if t.predicate == "located_in"} + + count = 0 + for person, city in list(born_in.items())[:20]: + if city in located_in: + country = located_in[city] + print(f" {person} born in {city}, which is in {country}") + count += 1 + if count >= 3: + break + + # Find work relations + print("\n c) Professional Relations:") + works_at = [t for t in dataset.triples if t.predicate == "works_at"][:5] + for triple in works_at: + print(f" {triple.subject} works at {triple.object}") + + print("\n" + "=" * 80) + print("Example complete!") + print("=" * 80) + + +if __name__ == '__main__': + main() diff --git a/experiments/train_kg.py b/experiments/train_kg.py new file mode 100644 index 0000000..b01a480 --- /dev/null +++ b/experiments/train_kg.py @@ -0,0 +1,324 @@ +""" +Training script for NSM-23: Knowledge Graph Domain with Link Prediction. + +Domain-specific implementation: +- 66 relations (large relation vocabulary: IsA, PartOf, LocatedIn, etc.) 
+- 12 bases for R-GCN (81.8% parameter reduction) +- Pool ratio: 0.13 (weak hierarchy - preserve fine-grained relations) +- Link prediction task (binary classification: valid/invalid triple) +- Negative sampling for incomplete KG (50/50 split) + +Target metrics: +- Reconstruction error: <30% (higher tolerance due to weak hierarchy) +- Hits@10: ≥70% (positive class accuracy) +- MRR: ≥0.5 (average confidence on true triples) +- Analogical reasoning: ≥60% (overall accuracy) + +Setup: + pip install -e . # Install package from project root + +Usage: + python experiments/train_kg.py --epochs 100 --batch-size 32 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.optim import Adam +from torch.utils.data import DataLoader, random_split +from torch_geometric.loader import DataLoader as GeometricDataLoader +import argparse +from pathlib import Path +import json + +# NOTE: Install the package before running this script: +# pip install -e . +# from the project root directory + +from nsm.data.knowledge_graph_dataset import KnowledgeGraphTripleDataset +from nsm.models import NSMModel +from nsm.training import NSMTrainer, compute_classification_metrics +from nsm.models.confidence.temperature import TemperatureScheduler +from nsm.evaluation.kg_metrics import ( + compute_hits_at_k, + compute_mrr, + compute_analogical_reasoning_accuracy +) + + +def compute_kg_metrics( + preds: torch.Tensor, + labels: torch.Tensor, + task_type: str, + dataset: KnowledgeGraphTripleDataset = None +) -> dict: + """Compute knowledge graph domain-specific metrics. + + Args: + preds: Predicted logits + labels: Ground truth labels + task_type: Task type + dataset: KnowledgeGraphTripleDataset for link prediction evaluation + + Returns: + dict: Metrics including Hits@10, MRR, analogical reasoning + """ + metrics = compute_classification_metrics(preds, labels, task_type) + + # Domain-specific metrics + if dataset is not None: + # Hits@10 (top-10 accuracy for link prediction) + hits_at_10 = compute_hits_at_k(preds, labels, dataset, k=10) + metrics['hits@10'] = hits_at_10 + + # Mean Reciprocal Rank + mrr = compute_mrr(preds, labels, dataset) + metrics['mrr'] = mrr + + # Analogical reasoning (A:B::C:?) + analogical_acc = compute_analogical_reasoning_accuracy(preds, labels, dataset) + metrics['analogical_reasoning'] = analogical_acc + + return metrics + + +def create_kg_model( + node_features: int, + num_classes: int, + device: torch.device +) -> NSMModel: + """Create NSM model configured for knowledge graph domain. + + Configuration (from NSM-23): + - 66 relations (large vocabulary: IsA, PartOf, LocatedIn, etc.) + - 12 bases (81.8% parameter reduction - critical for 66 relations) + - pool_ratio=0.13 (weak hierarchy - preserve fine-grained relations) + - Link prediction task + + Args: + node_features: Node feature dimensionality + num_classes: Number of output classes (for link prediction) + device: Device + + Returns: + NSMModel: Configured model + """ + model = NSMModel( + node_features=node_features, + num_relations=66, # Large relation vocabulary + num_classes=num_classes, + num_bases=12, # 81.8% parameter reduction + pool_ratio=0.13, # Weak hierarchy (preserve fine-grained facts) + task_type='link_prediction', + num_levels=3 # Phase 1.5: 3-level hierarchy to break symmetry bias + ) + + return model.to(device) + + +def collate_fn(batch_list): + """Collate function for KnowledgeGraphTripleDataset. + + PyG Data objects need special handling for batching. 
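+
+    The default tensor collate cannot stack graphs of varying size, so we
+    delegate to torch_geometric.data.Batch.from_data_list, which concatenates
+    node/edge tensors and builds the per-node `batch` assignment vector.
+    Note: the dtype=torch.long conversion below assumes binary {0, 1} labels;
+    if a dataset variant yields continuous confidence labels instead, they
+    would be truncated and should be thresholded first.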
+ """ + from torch_geometric.data import Batch + + # Batch PyG Data objects + data_list = [item[0] for item in batch_list] + # Labels are already binary (0 or 1) from generate_labels() + labels = torch.tensor([item[1].item() for item in batch_list], dtype=torch.long) + + batched_data = Batch.from_data_list(data_list) + + # Create batch dict + batch = { + 'x': batched_data.x, + 'edge_index': batched_data.edge_index, + 'edge_type': batched_data.edge_type, + 'edge_attr': batched_data.edge_attr if hasattr(batched_data, 'edge_attr') else None, + 'batch': batched_data.batch, + 'y': labels + } + + return batch + + +def main(args): + # Device + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + + # Dataset + print("Loading knowledge graph dataset...") + dataset = KnowledgeGraphTripleDataset( + root=args.data_dir, + split='train', + num_entities=args.num_entities, + num_triples=args.num_triples, + seed=args.seed + ) + + # Split into train/val + train_size = int(0.8 * len(dataset)) + val_size = len(dataset) - train_size + train_dataset, val_dataset = random_split( + dataset, + [train_size, val_size], + generator=torch.Generator().manual_seed(args.seed) + ) + + print(f"Train: {train_size} graphs, Val: {val_size} graphs") + + # Data loaders + train_loader = DataLoader( + train_dataset, + batch_size=args.batch_size, + shuffle=True, + collate_fn=collate_fn, + num_workers=args.num_workers + ) + + val_loader = DataLoader( + val_dataset, + batch_size=args.batch_size, + shuffle=False, + collate_fn=collate_fn, + num_workers=args.num_workers + ) + + # Model + print("Creating knowledge graph NSM model...") + model = create_kg_model( + node_features=args.node_features, + num_classes=args.num_classes, + device=device + ) + + print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}") + + # Optimizer + optimizer = Adam( + model.parameters(), + lr=args.lr, + weight_decay=args.weight_decay + ) + + # Temperature scheduler + temp_scheduler = TemperatureScheduler( + initial_temp=1.0, + final_temp=0.3, + decay_rate=0.9999, + warmup_epochs=10 + ) + + # Trainer + print("Initializing trainer...") + trainer = NSMTrainer( + model=model, + optimizer=optimizer, + device=device, + cycle_loss_weight=args.cycle_loss_weight, + gradient_clip=args.gradient_clip, + temp_scheduler=temp_scheduler, + checkpoint_dir=args.checkpoint_dir, + log_interval=args.log_interval, + use_wandb=args.use_wandb, + use_tensorboard=args.use_tensorboard + ) + + # Train + print(f"\nStarting training for {args.epochs} epochs...") + print(f"Checkpoint directory: {args.checkpoint_dir}") + + history = trainer.train( + train_loader=train_loader, + val_loader=val_loader, + epochs=args.epochs, + task_type='link_prediction', + compute_metrics=lambda p, l, t: compute_kg_metrics(p, l, t, dataset), + early_stopping_patience=args.early_stopping_patience, + save_best_only=True + ) + + # Save final results + results = { + 'args': vars(args), + 'final_train_loss': history['train'][-1]['total_loss'], + 'final_val_loss': history['val'][-1]['total_loss'], + 'best_val_loss': trainer.best_val_loss, + 'final_metrics': history['val'][-1] + } + + results_path = Path(args.checkpoint_dir) / 'results.json' + with open(results_path, 'w') as f: + json.dump(results, f, indent=2) + + print("\n" + "="*80) + print("Training Complete!") + print("="*80) + print(f"Best validation loss: {trainer.best_val_loss:.4f}") + print(f"Final metrics:") + for key, value in history['val'][-1].items(): + if isinstance(value, (int, 
float)): + print(f" {key}: {value:.4f}") + print(f"\nResults saved to: {results_path}") + print(f"Best model: {args.checkpoint_dir}/best_model.pt") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Train NSM on Knowledge Graph Domain') + + # Data + parser.add_argument('--data-dir', type=str, default='data/kg', + help='Data directory') + parser.add_argument('--num-entities', type=int, default=100, + help='Number of entities in KG') + parser.add_argument('--num-triples', type=int, default=500, + help='Number of triples per graph') + parser.add_argument('--num-workers', type=int, default=4, + help='Data loader workers') + + # Model + parser.add_argument('--node-features', type=int, default=64, + help='Node feature dimensionality') + parser.add_argument('--num-classes', type=int, default=2, + help='Number of output classes') + + # Training + parser.add_argument('--epochs', type=int, default=100, + help='Number of training epochs') + parser.add_argument('--batch-size', type=int, default=32, + help='Batch size') + parser.add_argument('--lr', type=float, default=1e-3, + help='Learning rate') + parser.add_argument('--weight-decay', type=float, default=1e-5, + help='Weight decay') + parser.add_argument('--cycle-loss-weight', type=float, default=0.15, + help='Weight for cycle consistency loss (higher for weak hierarchy)') + parser.add_argument('--gradient-clip', type=float, default=1.0, + help='Gradient clipping value') + parser.add_argument('--early-stopping-patience', type=int, default=20, + help='Early stopping patience') + + # Logging + parser.add_argument('--checkpoint-dir', type=str, default='checkpoints/kg', + help='Checkpoint directory') + parser.add_argument('--log-interval', type=int, default=10, + help='Logging interval (steps)') + parser.add_argument('--use-wandb', action='store_true', + help='Use Weights & Biases') + parser.add_argument('--use-tensorboard', action='store_true', + help='Use Tensorboard') + + # Misc + parser.add_argument('--seed', type=int, default=42, + help='Random seed') + + args = parser.parse_args() + + # Set seeds + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(args.seed) + + main(args) diff --git a/nsm/evaluation/__init__.py b/nsm/evaluation/__init__.py index d89a01c..e508eac 100644 --- a/nsm/evaluation/__init__.py +++ b/nsm/evaluation/__init__.py @@ -1,7 +1,7 @@ """ NSM evaluation and validation modules. -Provides metrics, preflight checks, and domain-specific evaluation utilities. +Provides metrics, preflight checks, process cleanup, and domain-specific evaluation utilities. 
""" from nsm.evaluation.preflight_checks import ( @@ -22,7 +22,14 @@ kill_process ) +from nsm.evaluation.kg_metrics import ( + compute_link_prediction_metrics, + compute_analogical_reasoning_accuracy, + compute_type_consistency_accuracy, +) + __all__ = [ + # Preflight checks 'run_preflight_checks', 'check_dataset_balance', 'check_cycle_loss_weight', @@ -32,7 +39,12 @@ 'check_class_weights', 'PreflightCheckError', 'PreflightCheckWarning', + # Process cleanup 'check_and_cleanup', 'find_training_processes', 'kill_process', + # KG-specific metrics + 'compute_link_prediction_metrics', + 'compute_analogical_reasoning_accuracy', + 'compute_type_consistency_accuracy', ] diff --git a/nsm/evaluation/kg_metrics.py b/nsm/evaluation/kg_metrics.py new file mode 100644 index 0000000..4221238 --- /dev/null +++ b/nsm/evaluation/kg_metrics.py @@ -0,0 +1,547 @@ +""" +Knowledge Graph Evaluation Metrics + +Implements evaluation metrics for knowledge graph reasoning tasks: +- Link prediction (Hits@K, MRR) +- Analogical reasoning (A:B :: C:?) +- Type consistency checking + +Mathematical Foundation: + Link prediction: Given (h, r, ?), rank all entities by score + - Hits@K: Fraction of correct entities in top K + - MRR: Mean reciprocal rank = 1/N * Σ(1/rank_i) + + Analogical reasoning: Given A:B :: C:D, find D + - Requires embedding-based similarity or graph traversal + + Type consistency: Verify (entity, type) consistency + - Binary classification: consistent vs inconsistent +""" + +from typing import List, Tuple, Dict, Optional, Set +import torch +from torch import Tensor +import numpy as np + + +def compute_link_prediction_metrics( + predictions: Tensor, + targets: Tensor, + k_values: List[int] = [1, 3, 10], +) -> Dict[str, float]: + """ + Compute link prediction metrics: Hits@K and MRR. 
+ + Args: + predictions: Predicted scores [batch_size, num_candidates] + Higher scores indicate more likely candidates + targets: True entity indices [batch_size] + k_values: Values of K for Hits@K computation + + Returns: + Dictionary containing: + - hits@1, hits@3, hits@10: Fraction of correct predictions in top K + - mrr: Mean reciprocal rank + - mean_rank: Average rank of correct entity + + Mathematical Foundation: + Hits@K = (1/N) * Σ I[rank(correct) ≤ K] + MRR = (1/N) * Σ (1 / rank(correct)) + + Examples: + >>> predictions = torch.tensor([[0.1, 0.9, 0.3], [0.8, 0.2, 0.7]]) + >>> targets = torch.tensor([1, 0]) # True indices + >>> metrics = compute_link_prediction_metrics(predictions, targets) + >>> print(f"Hits@1: {metrics['hits@1']:.3f}") + """ + batch_size = predictions.size(0) + num_candidates = predictions.size(1) + + # Get ranks of target entities + # Sort predictions in descending order + sorted_indices = torch.argsort(predictions, dim=1, descending=True) + + # Find rank of each target + ranks = [] + for i in range(batch_size): + target_idx = targets[i].item() + # Find position of target in sorted list (1-indexed) + rank = (sorted_indices[i] == target_idx).nonzero(as_tuple=True)[0].item() + 1 + ranks.append(rank) + + ranks_tensor = torch.tensor(ranks, dtype=torch.float32) + + # Compute metrics + metrics = {} + + # Hits@K for each K value + for k in k_values: + hits_at_k = (ranks_tensor <= k).float().mean().item() + metrics[f'hits@{k}'] = hits_at_k + + # Mean Reciprocal Rank + mrr = (1.0 / ranks_tensor).mean().item() + metrics['mrr'] = mrr + + # Mean Rank + mean_rank = ranks_tensor.mean().item() + metrics['mean_rank'] = mean_rank + + # Median Rank + median_rank = ranks_tensor.median().item() + metrics['median_rank'] = median_rank + + return metrics + + +def compute_analogical_reasoning_accuracy( + embeddings: Tensor, + analogy_queries: List[Tuple[int, int, int, int]], + k: int = 1, +) -> Dict[str, float]: + """ + Compute accuracy on analogical reasoning: A:B :: C:D. + + Uses vector arithmetic: D ≈ C + (B - A) + + Args: + embeddings: Entity embeddings [num_entities, embed_dim] + analogy_queries: List of (A, B, C, D) entity index tuples + k: Top-K accuracy threshold + + Returns: + Dictionary containing: + - accuracy@k: Fraction of queries where D is in top K predictions + - average_rank: Average rank of correct D + + Mathematical Foundation: + Given embeddings e_A, e_B, e_C, e_D: + Find D' = argmax_{i} sim(e_i, e_C + e_B - e_A) + where sim is cosine similarity or dot product. 
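+        This implementation scores candidates with the unnormalized dot
+        product and masks A, B, and C out of the candidate set to avoid
+        trivial answers.
+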
+ + Examples: + >>> embeddings = torch.randn(100, 64) + >>> queries = [(0, 1, 2, 3), (4, 5, 6, 7)] # A:B :: C:D + >>> metrics = compute_analogical_reasoning_accuracy(embeddings, queries) + """ + if len(analogy_queries) == 0: + return {'accuracy@1': 0.0, 'average_rank': float('inf')} + + correct_count = 0 + ranks = [] + + for a_idx, b_idx, c_idx, d_idx in analogy_queries: + # Vector arithmetic: D ≈ C + (B - A) + e_a = embeddings[a_idx] + e_b = embeddings[b_idx] + e_c = embeddings[c_idx] + e_d = embeddings[d_idx] + + # Predicted D embedding + predicted_d = e_c + (e_b - e_a) + + # Compute similarity to all entities + similarities = torch.matmul(embeddings, predicted_d) + + # Mask out A, B, C to avoid trivial solutions + similarities[a_idx] = float('-inf') + similarities[b_idx] = float('-inf') + similarities[c_idx] = float('-inf') + + # Get ranked entities + sorted_indices = torch.argsort(similarities, descending=True) + + # Find rank of correct D + rank = (sorted_indices == d_idx).nonzero(as_tuple=True)[0].item() + 1 + ranks.append(rank) + + # Check if in top K + if rank <= k: + correct_count += 1 + + accuracy = correct_count / len(analogy_queries) + avg_rank = np.mean(ranks) + + return { + f'accuracy@{k}': accuracy, + 'average_rank': avg_rank, + 'num_queries': len(analogy_queries), + } + + +def compute_type_consistency_accuracy( + predictions: Tensor, + labels: Tensor, + threshold: float = 0.5, +) -> Dict[str, float]: + """ + Compute type consistency checking accuracy. + + Args: + predictions: Predicted consistency scores [num_pairs] + Values in [0, 1] where 1 = consistent + labels: Ground truth labels [num_pairs] + 1 = consistent, 0 = inconsistent + threshold: Classification threshold (default 0.5) + + Returns: + Dictionary containing: + - accuracy: Overall classification accuracy + - precision: Precision for positive class (consistent) + - recall: Recall for positive class + - f1: F1 score + - true_positives, false_positives, true_negatives, false_negatives + + Mathematical Foundation: + Binary classification metrics: + - Accuracy = (TP + TN) / (TP + TN + FP + FN) + - Precision = TP / (TP + FP) + - Recall = TP / (TP + FN) + - F1 = 2 * (Precision * Recall) / (Precision + Recall) + + Examples: + >>> predictions = torch.tensor([0.9, 0.3, 0.8, 0.1]) + >>> labels = torch.tensor([1, 0, 1, 0]) + >>> metrics = compute_type_consistency_accuracy(predictions, labels) + """ + # Binary predictions + binary_preds = (predictions >= threshold).float() + + # Compute confusion matrix elements + tp = ((binary_preds == 1) & (labels == 1)).sum().float().item() + tn = ((binary_preds == 0) & (labels == 0)).sum().float().item() + fp = ((binary_preds == 1) & (labels == 0)).sum().float().item() + fn = ((binary_preds == 0) & (labels == 1)).sum().float().item() + + total = tp + tn + fp + fn + + # Compute metrics + accuracy = (tp + tn) / total if total > 0 else 0.0 + + precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 + recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 + + f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0 + + return { + 'accuracy': accuracy, + 'precision': precision, + 'recall': recall, + 'f1': f1, + 'true_positives': int(tp), + 'true_negatives': int(tn), + 'false_positives': int(fp), + 'false_negatives': int(fn), + } + + +def compute_multi_hop_reasoning_accuracy( + graph_queries: List[Dict], + predicted_answers: List[Set[int]], + ground_truth_answers: List[Set[int]], +) -> Dict[str, float]: + """ + Compute accuracy for multi-hop reasoning queries. 
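+    Note: predicted answers arrive as unranked sets, so hits@1 and hits@3
+    coincide in this implementation; they only diverge once ranked
+    predictions are available.
+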
+ + Args: + graph_queries: List of query dictionaries with path information + predicted_answers: List of sets of predicted entity indices + ground_truth_answers: List of sets of correct entity indices + + Returns: + Dictionary containing: + - exact_match: Fraction with exact answer set match + - hits@1: Fraction with at least one correct answer in top-1 + - hits@3: Fraction with at least one correct answer in top-3 + - average_precision: Mean precision across queries + + Mathematical Foundation: + Multi-hop reasoning requires chaining relations: + Given path (r1, r2, ..., rk) and start entity e0, + Find {ek | ∃e1,...,ek-1: (e0,r1,e1), (e1,r2,e2), ..., (ek-1,rk,ek)} + + Examples: + >>> queries = [{'path': ['born_in', 'located_in'], 'start': 0}] + >>> predicted = [{5, 6}] + >>> ground_truth = [{5, 7}] + >>> metrics = compute_multi_hop_reasoning_accuracy(queries, predicted, ground_truth) + """ + if len(predicted_answers) != len(ground_truth_answers): + raise ValueError("Number of predictions must match ground truth") + + exact_matches = 0 + hits_1 = 0 + hits_3 = 0 + precisions = [] + + for pred_set, true_set in zip(predicted_answers, ground_truth_answers): + # Exact match + if pred_set == true_set: + exact_matches += 1 + + # Hits@K - check if any predicted answer is correct + if len(pred_set & true_set) > 0: + hits_1 += 1 # At least one correct + hits_3 += 1 + + # Precision for this query + if len(pred_set) > 0: + precision = len(pred_set & true_set) / len(pred_set) + precisions.append(precision) + else: + precisions.append(0.0) + + num_queries = len(predicted_answers) + + return { + 'exact_match': exact_matches / num_queries if num_queries > 0 else 0.0, + 'hits@1': hits_1 / num_queries if num_queries > 0 else 0.0, + 'hits@3': hits_3 / num_queries if num_queries > 0 else 0.0, + 'average_precision': np.mean(precisions) if precisions else 0.0, + 'num_queries': num_queries, + } + + +def compute_calibration_error( + confidences: Tensor, + accuracies: Tensor, + num_bins: int = 10, +) -> Dict[str, float]: + """ + Compute Expected Calibration Error (ECE). + + Measures how well predicted confidence scores match actual accuracy. + + Args: + confidences: Predicted confidence scores [num_samples] + accuracies: Binary correctness (1 = correct, 0 = incorrect) [num_samples] + num_bins: Number of bins for calibration curve + + Returns: + Dictionary containing: + - ece: Expected Calibration Error + - mce: Maximum Calibration Error + - calibration_curve: List of (bin_confidence, bin_accuracy, bin_count) tuples + + Mathematical Foundation: + ECE = Σ (|Bm| / N) * |acc(Bm) - conf(Bm)| + where Bm is the set of samples in bin m, + acc(Bm) is the accuracy in bin m, + conf(Bm) is the average confidence in bin m. 
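+        The Maximum Calibration Error reported alongside ECE is the worst bin:
+        MCE = max_m |acc(Bm) - conf(Bm)|
+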
+ + Examples: + >>> confidences = torch.tensor([0.9, 0.8, 0.6, 0.3]) + >>> accuracies = torch.tensor([1.0, 1.0, 0.0, 0.0]) + >>> metrics = compute_calibration_error(confidences, accuracies) + """ + confidences = confidences.cpu().numpy() + accuracies = accuracies.cpu().numpy() + + bin_boundaries = np.linspace(0, 1, num_bins + 1) + bin_lowers = bin_boundaries[:-1] + bin_uppers = bin_boundaries[1:] + + ece = 0.0 + mce = 0.0 + calibration_curve = [] + + total_samples = len(confidences) + + for bin_lower, bin_upper in zip(bin_lowers, bin_uppers): + # Find samples in this bin + in_bin = (confidences >= bin_lower) & (confidences < bin_upper) + + # Handle last bin inclusively + if bin_upper == 1.0: + in_bin = (confidences >= bin_lower) & (confidences <= bin_upper) + + bin_size = in_bin.sum() + + if bin_size > 0: + # Average confidence in bin + bin_confidence = confidences[in_bin].mean() + + # Average accuracy in bin + bin_accuracy = accuracies[in_bin].mean() + + # Calibration error for this bin + bin_error = abs(bin_accuracy - bin_confidence) + + # Weighted contribution to ECE + ece += (bin_size / total_samples) * bin_error + + # Update MCE + mce = max(mce, bin_error) + + calibration_curve.append(( + float(bin_confidence), + float(bin_accuracy), + int(bin_size) + )) + else: + calibration_curve.append((0.0, 0.0, 0)) + + return { + 'ece': float(ece), + 'mce': float(mce), + 'calibration_curve': calibration_curve, + 'num_bins': num_bins, + } + + +def compute_kg_comprehensive_metrics( + model: torch.nn.Module, + dataset, + device: str = 'cpu', + num_samples: Optional[int] = None, +) -> Dict[str, float]: + """ + Compute comprehensive evaluation metrics for KG dataset. + + Args: + model: Trained NSM model + dataset: KnowledgeGraphTripleDataset instance + device: Device to run evaluation on + num_samples: Number of samples to evaluate (None = all) + + Returns: + Dictionary with all KG-specific metrics + + Note: + This is a convenience function that orchestrates all KG metrics. + Requires model to have methods: forward(), predict_link(), etc. + """ + model.eval() + model = model.to(device) + + metrics = {} + + # Sample dataset if needed + if num_samples is not None: + indices = torch.randperm(len(dataset))[:num_samples].tolist() + else: + indices = list(range(len(dataset))) + + # Collect predictions and targets + all_predictions = [] + all_targets = [] + all_confidences = [] + + with torch.no_grad(): + for idx in indices: + graph, label = dataset[idx] + graph = graph.to(device) + + # Get model prediction + # NOTE: This function is a stub for future implementation + # when we need batch evaluation across all KG metrics. + # Currently, metrics are computed per-batch in the training loop + # using the simplified wrapper functions below. + pass + + # NOTE: Not yet implemented - use simplified wrappers in training loop instead + # Future implementation would compute: + # - Link prediction metrics (currently handled in training loop) + # - Type consistency metrics (requires type hierarchy) + # - Calibration error (requires confidence scores) + + return metrics + + +# Simplified wrapper functions for training loop integration + +def compute_hits_at_k(preds: torch.Tensor, labels: torch.Tensor, dataset=None, k: int = 10) -> float: + """Simplified Hits@K for training loop (binary classification variant). + + Note: The current KG dataset does binary link prediction (valid/invalid triple), + not entity ranking. This metric approximates ranking-style Hits@K using + classification confidence scores. 
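+    For true entity-ranking Hits@K, use compute_link_prediction_metrics above,
+    which expects per-candidate scores rather than binary logits.
+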
+ + Args: + preds: Predicted logits [batch_size, num_classes] for binary classification + labels: Ground truth labels [batch_size] (0 or 1) + dataset: Dataset (optional, unused in binary mode) + k: Unused in binary classification mode + + Returns: + float: Accuracy for positive class (approximates Hits@K) + """ + # For binary classification, treat as accuracy on positive examples + if preds.dim() == 2 and preds.size(1) == 2: + # Two-class logits [batch_size, 2] + pred_labels = torch.argmax(preds, dim=1) + else: + # Single probability output + pred_labels = (torch.sigmoid(preds.squeeze()) > 0.5).long() + + # Compute accuracy on positive examples only (valid triples) + positive_mask = (labels == 1) + if positive_mask.sum() > 0: + hits = (pred_labels[positive_mask] == labels[positive_mask]).float().mean().item() + else: + hits = 0.0 + + return hits + + +def compute_mrr(preds: torch.Tensor, labels: torch.Tensor, dataset=None) -> float: + """Simplified Mean Reciprocal Rank for training loop (binary classification variant). + + Note: The current KG dataset does binary link prediction (valid/invalid triple), + not entity ranking. This metric approximates MRR using prediction confidence + on positive examples. + + Args: + preds: Predicted logits [batch_size, num_classes] for binary classification + labels: Ground truth labels [batch_size] (0 or 1) + dataset: Dataset (optional, unused in binary mode) + + Returns: + float: Average confidence on positive examples (approximates MRR) + """ + # For binary classification, compute confidence on positive examples + if preds.dim() == 2 and preds.size(1) == 2: + # Two-class logits: get probability for positive class + probs = torch.softmax(preds, dim=1)[:, 1] # Probability of class 1 + else: + # Single probability output + probs = torch.sigmoid(preds.squeeze()) + + # Average confidence on true positive examples + positive_mask = (labels == 1) + if positive_mask.sum() > 0: + mrr_approx = probs[positive_mask].mean().item() + else: + mrr_approx = 0.0 + + return mrr_approx + + +def compute_analogical_reasoning_accuracy(preds: torch.Tensor, labels: torch.Tensor, dataset=None) -> float: + """Simplified analogical reasoning for training loop (binary classification variant). + + Note: The current KG dataset does binary link prediction (valid/invalid triple), + not analogical reasoning patterns. This metric returns overall accuracy as a + proxy for reasoning capability. 
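+    Warning: this wrapper shares its name with the embedding-based analogical
+    metric defined earlier in this module and, being defined later, shadows
+    it; imports from nsm.evaluation.kg_metrics resolve to this binary variant.
+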
+ + Args: + preds: Predicted logits [batch_size, num_classes] for binary classification + labels: Ground truth labels [batch_size] (0 or 1) + dataset: Dataset (optional, unused in binary mode) + + Returns: + float: Overall classification accuracy (proxy for analogical reasoning) + """ + # For binary classification, compute accuracy + if preds.dim() == 2 and preds.size(1) == 2: + # Two-class logits + pred_labels = torch.argmax(preds, dim=1) + else: + # Single probability output + pred_labels = (torch.sigmoid(preds.squeeze()) > 0.5).long() + + correct = (pred_labels == labels).sum().item() + total = labels.size(0) + + accuracy = correct / total if total > 0 else 0.0 + return accuracy diff --git a/nsm/models/hierarchical.py b/nsm/models/hierarchical.py index 94cb1d7..4a1ca6a 100644 --- a/nsm/models/hierarchical.py +++ b/nsm/models/hierarchical.py @@ -359,6 +359,7 @@ class NSMModel(nn.Module): Integrates all components: - Two SymmetricHierarchicalLayers for L1↔L2↔L3 + - Dual-pass prediction (abstract + concrete fusion) by default - Task-specific prediction heads - Confidence-aware output @@ -370,8 +371,11 @@ class NSMModel(nn.Module): pool_ratio (float): Pooling ratio for each level task_type (str): 'classification', 'regression', or 'link_prediction' num_levels (int): Number of hierarchy levels (2 or 3, default 3) + use_dual_pass (bool): Use dual-pass prediction (default True) + fusion_mode (str): Fusion strategy for dual-pass ('equal', 'learned', 'abstract_only', 'concrete_only') Example: + >>> # Dual-pass mode (default) >>> model = NSMModel( ... node_features=64, ... num_relations=16, @@ -380,6 +384,16 @@ class NSMModel(nn.Module): ... num_levels=3 ... ) >>> + >>> # Single-pass mode (opt-out) + >>> model = NSMModel( + ... node_features=64, + ... num_relations=16, + ... num_classes=2, + ... task_type='classification', + ... num_levels=3, + ... use_dual_pass=False + ... ) + >>> >>> # Forward pass >>> output = model(x, edge_index, edge_type, edge_attr, batch) >>> logits = output['logits'] @@ -399,7 +413,7 @@ def __init__( pool_ratio: float = 0.5, task_type: str = 'classification', num_levels: int = 3, - use_dual_pass: bool = False, + use_dual_pass: bool = True, fusion_mode: str = 'equal' ): super().__init__() diff --git a/test_kg_metrics_fix.py b/test_kg_metrics_fix.py new file mode 100644 index 0000000..ae4baba --- /dev/null +++ b/test_kg_metrics_fix.py @@ -0,0 +1,111 @@ +""" +Test that KG metrics work correctly with binary classification. 
+""" +import torch +from nsm.evaluation.kg_metrics import ( + compute_hits_at_k, + compute_mrr, + compute_analogical_reasoning_accuracy +) + +def test_kg_metrics_binary_classification(): + """Test KG metrics with binary classification (2-class logits).""" + print("Testing KG metrics with binary classification...") + + # Simulate batch of 10 samples with 2-class logits + batch_size = 10 + preds = torch.randn(batch_size, 2) # [10, 2] logits + labels = torch.tensor([1, 0, 1, 1, 0, 1, 0, 1, 1, 0]) # Binary labels + + # Test compute_hits_at_k + hits = compute_hits_at_k(preds, labels, k=10) + print(f"✓ Hits@K: {hits:.4f}") + assert isinstance(hits, float), "Hits@K should return float" + assert 0.0 <= hits <= 1.0, "Hits@K should be in [0, 1]" + + # Test compute_mrr + mrr = compute_mrr(preds, labels) + print(f"✓ MRR: {mrr:.4f}") + assert isinstance(mrr, float), "MRR should return float" + assert 0.0 <= mrr <= 1.0, "MRR should be in [0, 1]" + + # Test compute_analogical_reasoning_accuracy + acc = compute_analogical_reasoning_accuracy(preds, labels) + print(f"✓ Analogical Reasoning Accuracy: {acc:.4f}") + assert isinstance(acc, float), "Accuracy should return float" + assert 0.0 <= acc <= 1.0, "Accuracy should be in [0, 1]" + + print("\n✅ All KG metrics work correctly with binary classification!") + return True + +def test_kg_metrics_single_output(): + """Test KG metrics with single probability output.""" + print("\nTesting KG metrics with single probability output...") + + # Simulate batch with single probability output + batch_size = 10 + preds = torch.randn(batch_size, 1) # [10, 1] probabilities (before sigmoid) + labels = torch.tensor([1, 0, 1, 1, 0, 1, 0, 1, 1, 0]) + + # Test compute_hits_at_k + hits = compute_hits_at_k(preds, labels, k=10) + print(f"✓ Hits@K (single output): {hits:.4f}") + + # Test compute_mrr + mrr = compute_mrr(preds, labels) + print(f"✓ MRR (single output): {mrr:.4f}") + + # Test compute_analogical_reasoning_accuracy + acc = compute_analogical_reasoning_accuracy(preds, labels) + print(f"✓ Analogical Reasoning Accuracy (single output): {acc:.4f}") + + print("\n✅ Metrics work with single output format!") + return True + +def test_edge_cases(): + """Test edge cases like all zeros, all ones, etc.""" + print("\nTesting edge cases...") + + # All positive labels + preds = torch.randn(5, 2) + labels = torch.ones(5, dtype=torch.long) + + hits = compute_hits_at_k(preds, labels) + mrr = compute_mrr(preds, labels) + acc = compute_analogical_reasoning_accuracy(preds, labels) + + print(f"✓ All positive labels: hits={hits:.4f}, mrr={mrr:.4f}, acc={acc:.4f}") + + # All negative labels + labels = torch.zeros(5, dtype=torch.long) + + hits = compute_hits_at_k(preds, labels) # Should handle no positives + mrr = compute_mrr(preds, labels) # Should return 0 + acc = compute_analogical_reasoning_accuracy(preds, labels) + + print(f"✓ All negative labels: hits={hits:.4f}, mrr={mrr:.4f}, acc={acc:.4f}") + + print("\n✅ Edge cases handled correctly!") + return True + +if __name__ == '__main__': + print("=" * 60) + print("KG Metrics Fix Verification") + print("=" * 60) + print() + + try: + test_kg_metrics_binary_classification() + test_kg_metrics_single_output() + test_edge_cases() + + print("\n" + "=" * 60) + print("✅ All verification tests passed!") + print("=" * 60) + print("\nKG metrics are now compatible with binary classification.") + + except Exception as e: + print(f"\n✗ Test failed: {e}") + import traceback + traceback.print_exc() + exit(1) diff --git a/test_merge_verification.py 
b/test_merge_verification.py new file mode 100644 index 0000000..dfc743e --- /dev/null +++ b/test_merge_verification.py @@ -0,0 +1,190 @@ +""" +Quick verification test for merged dual-pass and single-pass modes. + +Note: As of Phase 1.5, dual-pass mode is the default (use_dual_pass=True). +Single-pass mode is now opt-out via use_dual_pass=False. +""" +import torch +import torch.nn.functional as F +from nsm.models.hierarchical import NSMModel + +def test_dual_pass_default(): + """Test dual-pass mode as default (3-level).""" + print("Testing dual-pass mode (default behavior)...") + + model = NSMModel( + node_features=64, + num_relations=4, + num_classes=2, + task_type='classification', + num_levels=3 + # use_dual_pass defaults to True + ) + + # Create dummy data + num_nodes = 20 + x = torch.randn(num_nodes, 64) + edge_index = torch.randint(0, num_nodes, (2, 40)) + edge_type = torch.randint(0, 4, (40,)) + edge_attr = torch.randn(40) # 1D confidence weights + batch = torch.zeros(num_nodes, dtype=torch.long) + + # Forward pass + output = model(x, edge_index, edge_type, edge_attr, batch) + + # Verify dual-pass outputs + assert 'logits' in output, "Missing logits" + assert 'logits_abstract' in output, "Missing logits_abstract (dual-pass should be default)" + assert 'logits_concrete' in output, "Missing logits_concrete" + assert 'fusion_weights' in output, "Missing fusion_weights" + assert 'cycle_loss' in output, "Missing cycle_loss" + assert output['logits'].shape == (1, 2), f"Unexpected logits shape: {output['logits'].shape}" + + print(f"✓ Dual-pass mode (default) works!") + print(f" Fused logits shape: {output['logits'].shape}") + print(f" Abstract logits shape: {output['logits_abstract'].shape}") + print(f" Concrete logits shape: {output['logits_concrete'].shape}") + print(f" Fusion weights: {output['fusion_weights']}") + print(f" Cycle loss: {output['cycle_loss'].item():.4f}") + return True + +def test_single_pass_mode(): + """Test single-pass mode (opt-out via use_dual_pass=False).""" + print("\nTesting single-pass mode (opt-out)...") + + model = NSMModel( + node_features=64, + num_relations=4, + num_classes=2, + task_type='classification', + num_levels=3, + use_dual_pass=False # Explicitly opt-out of dual-pass + ) + + # Create dummy data + num_nodes = 20 + x = torch.randn(num_nodes, 64) + edge_index = torch.randint(0, num_nodes, (2, 40)) + edge_type = torch.randint(0, 4, (40,)) + edge_attr = torch.randn(40) # 1D confidence weights + batch = torch.zeros(num_nodes, dtype=torch.long) + + # Forward pass + output = model(x, edge_index, edge_type, edge_attr, batch) + + # Verify output + assert 'logits' in output, "Missing logits" + assert 'cycle_loss' in output, "Missing cycle_loss" + assert 'x_l2' in output, "Missing x_l2" + assert 'x_l3' in output, "Missing x_l3" + assert output['logits'].shape == (1, 2), f"Unexpected logits shape: {output['logits'].shape}" + + # Verify no dual-pass outputs + assert 'logits_abstract' not in output, "Should not have dual-pass outputs in single-pass mode" + + print(f"✓ Single-pass mode (opt-out) works!") + print(f" Logits shape: {output['logits'].shape}") + print(f" Cycle loss: {output['cycle_loss'].item():.4f}") + return True + +def test_dual_pass_learned_fusion(): + """Test dual-pass mode with learned fusion.""" + print("\nTesting dual-pass with learned fusion...") + + model = NSMModel( + node_features=64, + num_relations=4, + num_classes=2, + task_type='classification', + num_levels=3, + use_dual_pass=True, + fusion_mode='learned' + ) + + # Create dummy data + 
num_nodes = 20 + x = torch.randn(num_nodes, 64) + edge_index = torch.randint(0, num_nodes, (2, 40)) + edge_type = torch.randint(0, 4, (40,)) + edge_attr = torch.randn(40) # 1D confidence weights + batch = torch.zeros(num_nodes, dtype=torch.long) + + # Forward pass + output = model(x, edge_index, edge_type, edge_attr, batch) + + # Verify output + assert 'logits' in output, "Missing logits" + assert 'logits_abstract' in output, "Missing logits_abstract" + assert 'logits_concrete' in output, "Missing logits_concrete" + assert 'fusion_weights' in output, "Missing fusion_weights" + assert 'cycle_loss' in output, "Missing cycle_loss" + assert output['logits'].shape == (1, 2), f"Unexpected logits shape: {output['logits'].shape}" + + print(f"✓ Dual-pass with learned fusion works!") + print(f" Fused logits shape: {output['logits'].shape}") + print(f" Fusion weights (learned): {output['fusion_weights']}") + print(f" Cycle loss: {output['cycle_loss'].item():.4f}") + return True + +def test_2level_backward_compat(): + """Test 2-level mode for backward compatibility. + + Note: 2-level mode doesn't support dual-pass, so use_dual_pass=False is required. + """ + print("\nTesting 2-level mode (backward compatibility)...") + + model = NSMModel( + node_features=64, + num_relations=4, + num_classes=2, + task_type='classification', + num_levels=2, # 2-level mode + use_dual_pass=False # 2-level doesn't support dual-pass + ) + + # Create dummy data + num_nodes = 20 + x = torch.randn(num_nodes, 64) + edge_index = torch.randint(0, num_nodes, (2, 40)) + edge_type = torch.randint(0, 4, (40,)) + edge_attr = torch.randn(40) # 1D confidence weights + batch = torch.zeros(num_nodes, dtype=torch.long) + + # Forward pass + output = model(x, edge_index, edge_type, edge_attr, batch) + + # Verify output + assert 'logits' in output, "Missing logits" + assert 'cycle_loss' in output, "Missing cycle_loss" + assert output['logits'].shape == (1, 2), f"Unexpected logits shape: {output['logits'].shape}" + + print(f"✓ 2-level mode works!") + print(f" Logits shape: {output['logits'].shape}") + print(f" Cycle loss: {output['cycle_loss'].item():.4f}") + return True + +if __name__ == '__main__': + print("=" * 60) + print("Merge Verification Test Suite (Phase 1.5)") + print("=" * 60) + print("Note: Dual-pass mode is now the default (use_dual_pass=True)\n") + + try: + test_dual_pass_default() + test_dual_pass_learned_fusion() + test_single_pass_mode() + test_2level_backward_compat() + + print("\n" + "=" * 60) + print("✓ All tests passed! Merge successful.") + print("=" * 60) + print("\n✅ Dual-pass mode is now the default") + print("✅ Single-pass mode available via use_dual_pass=False") + print("✅ Backward compatibility maintained with 2-level mode") + print("\nReady to push and create PR!") + + except Exception as e: + print(f"\n✗ Test failed: {e}") + import traceback + traceback.print_exc() + exit(1) diff --git a/tests/data/test_kg_dataset.py b/tests/data/test_kg_dataset.py new file mode 100644 index 0000000..92ff085 --- /dev/null +++ b/tests/data/test_kg_dataset.py @@ -0,0 +1,394 @@ +""" +Tests for Knowledge Graph Triple Dataset + +Validates KG dataset generation, entity diversity, multi-hop reasoning, +and type hierarchy consistency. 
+""" + +import pytest +import torch +import shutil +from pathlib import Path + +from nsm.data.knowledge_graph_dataset import KnowledgeGraphTripleDataset +from nsm.data.triple import SemanticTriple + + +@pytest.fixture +def test_data_dir(tmp_path): + """Create temporary data directory.""" + data_dir = tmp_path / "kg_test" + yield data_dir + # Cleanup + if data_dir.exists(): + shutil.rmtree(data_dir) + + +@pytest.fixture +def small_kg_dataset(test_data_dir): + """Create small KG dataset for testing.""" + return KnowledgeGraphTripleDataset( + root=str(test_data_dir), + split='train', + num_entities=100, + num_triples=500, + seed=42 + ) + + +@pytest.fixture +def medium_kg_dataset(test_data_dir): + """Create medium KG dataset.""" + return KnowledgeGraphTripleDataset( + root=str(test_data_dir / "medium"), + split='train', + num_entities=1000, + num_triples=5000, + seed=42 + ) + + +class TestKGDatasetGeneration: + """Test dataset generation and basic properties.""" + + def test_initialization(self, small_kg_dataset): + """Test dataset initializes correctly.""" + assert small_kg_dataset is not None + assert len(small_kg_dataset.triples) == 500 + assert len(small_kg_dataset) == 500 + + def test_triple_structure(self, small_kg_dataset): + """Test that generated triples have correct structure.""" + for triple in small_kg_dataset.triples[:10]: + assert isinstance(triple, SemanticTriple) + assert isinstance(triple.subject, str) + assert isinstance(triple.predicate, str) + assert isinstance(triple.object, str) + assert 0.0 <= triple.confidence <= 1.0 + assert triple.level in [1, 2] + + def test_level_distribution(self, small_kg_dataset): + """Test that both Level 1 and Level 2 triples exist.""" + level_1_count = sum(1 for t in small_kg_dataset.triples if t.level == 1) + level_2_count = sum(1 for t in small_kg_dataset.triples if t.level == 2) + + assert level_1_count > 0, "Should have Level 1 (fact) triples" + assert level_2_count > 0, "Should have Level 2 (type) triples" + assert level_1_count + level_2_count == 500 + + def test_confidence_variance(self, small_kg_dataset): + """Test that confidence scores vary appropriately.""" + confidences = [t.confidence for t in small_kg_dataset.triples] + + # Check range + assert min(confidences) >= 0.5, "Min confidence should be >= 0.5" + assert max(confidences) <= 1.0, "Max confidence should be <= 1.0" + + # Check variance (should not all be the same) + unique_confidences = len(set(confidences)) + assert unique_confidences > 10, "Should have diverse confidence scores" + + def test_predicate_diversity(self, small_kg_dataset): + """Test that dataset uses diverse predicates.""" + predicates = set(t.predicate for t in small_kg_dataset.triples) + + # Should have at least 15 different predicates + assert len(predicates) >= 15, f"Only {len(predicates)} unique predicates" + + # Check for expected predicate categories + level1_preds = set(t.predicate for t in small_kg_dataset.triples if t.level == 1) + level2_preds = set(t.predicate for t in small_kg_dataset.triples if t.level == 2) + + # Level 1 should include facts + expected_l1 = ['born_in', 'located_in', 'works_at', 'studied_at', 'won'] + assert any(pred in level1_preds for pred in expected_l1), \ + "Level 1 should include biographical/factual predicates" + + # Level 2 should include types + expected_l2 = ['instance_of', 'subclass_of', 'typically_has'] + assert any(pred in level2_preds for pred in expected_l2), \ + "Level 2 should include type/category predicates" + + +class TestEntityDiversity: + """Test entity 
+class TestEntityDiversity:
+    """Test entity generation and diversity."""
+
+    def test_entity_count(self, small_kg_dataset):
+        """Test that dataset has sufficient unique entities."""
+        entities = set()
+        for triple in small_kg_dataset.triples:
+            entities.add(triple.subject)
+            entities.add(triple.object)
+
+        # Should have diverse entities
+        assert len(entities) >= 50, f"Only {len(entities)} unique entities"
+
+    def test_entity_categories(self, small_kg_dataset):
+        """Test that entities span multiple categories."""
+        # Check that we have different types of entities
+        assert len(small_kg_dataset.people) > 0, "Should have people"
+        assert len(small_kg_dataset.places) > 0, "Should have places"
+        assert len(small_kg_dataset.organizations) > 0, "Should have organizations"
+        assert len(small_kg_dataset.concepts) > 0, "Should have concepts"
+        assert len(small_kg_dataset.awards) > 0, "Should have awards"
+
+    def test_named_entities(self, small_kg_dataset):
+        """Test that famous entities are included."""
+        all_entities = set()
+        for triple in small_kg_dataset.triples:
+            all_entities.add(triple.subject)
+            all_entities.add(triple.object)
+
+        # Check for some expected named entities
+        expected_people = ["Albert_Einstein", "Marie_Curie", "Isaac_Newton"]
+        found_people = [p for p in expected_people if p in all_entities]
+        assert len(found_people) > 0, "Should include famous people"
+
+        expected_places = ["London", "Paris", "New_York"]
+        found_places = [p for p in expected_places if p in all_entities]
+        assert len(found_places) > 0, "Should include major cities"
+
+    def test_entity_types_mapping(self, small_kg_dataset):
+        """Test that entities have type mappings."""
+        assert len(small_kg_dataset.entity_types) > 0
+        assert "Person" in small_kg_dataset.entity_types.values()
+        assert "Place" in small_kg_dataset.entity_types.values()
+
+
+class TestMultiHopReasoning:
+    """Test multi-hop reasoning capabilities."""
+
+    def test_multi_hop_query_generation(self, small_kg_dataset):
+        """Test generation of multi-hop queries."""
+        queries = small_kg_dataset.get_multi_hop_queries(num_queries=20)
+
+        # Should generate some queries
+        assert len(queries) > 0, "Should generate multi-hop queries"
+
+        # Check query structure
+        for query in queries[:5]:
+            assert 'start_entity' in query
+            assert 'relations' in query
+            assert 'expected_answer' in query
+            assert isinstance(query['relations'], list)
+            assert len(query['relations']) > 0
+
+    def test_two_hop_reasoning_paths(self, medium_kg_dataset):
+        """Test that 2-hop reasoning paths exist."""
+        queries = medium_kg_dataset.get_multi_hop_queries(num_queries=50)
+
+        two_hop_queries = [q for q in queries if q.get('query_type') == '2-hop']
+        assert len(two_hop_queries) > 0, "Should have 2-hop reasoning paths"
+
+        # Verify path structure
+        for query in two_hop_queries[:3]:
+            assert len(query['relations']) == 2
+            assert 'intermediate' in query
+            assert query['start_entity'] != query['expected_answer']
+
+
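+# For reference, a 2-hop query dict exercised above has this shape (the entity
+# values here are illustrative, not guaranteed dataset contents):
+#   {'query_type': '2-hop', 'start_entity': 'Marie_Curie',
+#    'relations': ['born_in', 'located_in'], 'intermediate': 'Paris',
+#    'expected_answer': 'France'}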
"subclass_of" + ] + + assert len(instance_of_triples) > 0, "Should have instance_of triples" + assert len(subclass_of_triples) > 0, "Should have subclass_of triples" + + # Check confidence for type triples + for triple in instance_of_triples[:5]: + assert triple.level == 2, "Type triples should be Level 2" + assert triple.confidence >= 0.5 + + def test_type_consistency_pairs(self, small_kg_dataset): + """Test generation of type consistency checking pairs.""" + pairs = small_kg_dataset.get_type_consistency_pairs(num_pairs=50) + + assert len(pairs) > 0, "Should generate consistency pairs" + + positive_pairs = [p for p in pairs if p[2] is True] + negative_pairs = [p for p in pairs if p[2] is False] + + assert len(positive_pairs) > 0, "Should have positive examples" + assert len(negative_pairs) > 0, "Should have negative examples" + + # Check pair structure + for entity, entity_type, is_consistent in pairs[:5]: + assert isinstance(entity, str) + assert isinstance(entity_type, str) + assert isinstance(is_consistent, bool) + + +class TestDatasetInterface: + """Test PyG dataset interface compliance.""" + + def test_getitem(self, small_kg_dataset): + """Test __getitem__ returns correct format.""" + graph, label = small_kg_dataset[0] + + # Check graph structure + assert hasattr(graph, 'x'), "Graph should have node features" + assert hasattr(graph, 'edge_index'), "Graph should have edge_index" + assert hasattr(graph, 'edge_attr'), "Graph should have edge_attr" + assert hasattr(graph, 'edge_type'), "Graph should have edge_type" + + # Check label + assert isinstance(label, torch.Tensor) + assert label.shape == (1,), f"Label shape should be (1,), got {label.shape}" + assert 0.0 <= label.item() <= 1.0, "Label should be confidence in [0, 1]" + + def test_batch_loading(self, small_kg_dataset): + """Test that multiple samples can be loaded.""" + batch_size = 5 + samples = [small_kg_dataset[i] for i in range(batch_size)] + + assert len(samples) == batch_size + for graph, label in samples: + assert graph.num_nodes > 0 + assert graph.edge_index.size(1) > 0 + + def test_statistics(self, small_kg_dataset): + """Test dataset statistics computation.""" + stats = small_kg_dataset.get_statistics() + + assert 'num_triples' in stats + assert 'num_entities' in stats + assert 'num_predicates' in stats + assert 'avg_confidence' in stats + assert 'level_distribution' in stats + + assert stats['num_triples'] == 500 + assert stats['num_entities'] > 0 + assert stats['num_predicates'] >= 15 + assert 0.0 <= stats['avg_confidence'] <= 1.0 + assert 1 in stats['level_distribution'] + assert 2 in stats['level_distribution'] + + +class TestCaching: + """Test dataset caching functionality.""" + + def test_cache_creation(self, test_data_dir): + """Test that cache files are created.""" + dataset = KnowledgeGraphTripleDataset( + root=str(test_data_dir / "cache_test"), + split='train', + num_entities=50, + num_triples=200, + seed=42 + ) + + cache_file = test_data_dir / "cache_test" / "processed" / "train_triples.pt" + assert cache_file.exists(), "Cache file should be created" + + def test_cache_loading(self, test_data_dir): + """Test that cached data is loaded correctly.""" + root = str(test_data_dir / "cache_test2") + + # Create initial dataset + dataset1 = KnowledgeGraphTripleDataset( + root=root, + split='train', + num_entities=50, + num_triples=200, + seed=42 + ) + triples1 = dataset1.triples.copy() + + # Load from cache + dataset2 = KnowledgeGraphTripleDataset( + root=root, + split='train', + num_entities=50, + num_triples=200, + 
+class TestCaching:
+    """Test dataset caching functionality."""
+
+    def test_cache_creation(self, test_data_dir):
+        """Test that cache files are created."""
+        dataset = KnowledgeGraphTripleDataset(
+            root=str(test_data_dir / "cache_test"),
+            split='train',
+            num_entities=50,
+            num_triples=200,
+            seed=42
+        )
+
+        cache_file = test_data_dir / "cache_test" / "processed" / "train_triples.pt"
+        assert cache_file.exists(), "Cache file should be created"
+
+    def test_cache_loading(self, test_data_dir):
+        """Test that cached data is loaded correctly."""
+        root = str(test_data_dir / "cache_test2")
+
+        # Create initial dataset
+        dataset1 = KnowledgeGraphTripleDataset(
+            root=root,
+            split='train',
+            num_entities=50,
+            num_triples=200,
+            seed=42
+        )
+        triples1 = dataset1.triples.copy()
+
+        # Load from cache
+        dataset2 = KnowledgeGraphTripleDataset(
+            root=root,
+            split='train',
+            num_entities=50,
+            num_triples=200,
+            seed=99  # Different seed shouldn't matter - should load from cache
+        )
+        triples2 = dataset2.triples
+
+        # Should be identical (loaded from cache)
+        assert len(triples1) == len(triples2)
+        # First triple should be the same
+        assert triples1[0].subject == triples2[0].subject
+        assert triples1[0].predicate == triples2[0].predicate
+        assert triples1[0].object == triples2[0].object
+
+
+class TestReproducibility:
+    """Test reproducibility with seeds."""
+
+    def test_seed_reproducibility(self, test_data_dir):
+        """Test that same seed produces same triples."""
+        dataset1 = KnowledgeGraphTripleDataset(
+            root=str(test_data_dir / "seed1"),
+            split='train',
+            num_entities=100,
+            num_triples=500,
+            seed=42
+        )
+
+        dataset2 = KnowledgeGraphTripleDataset(
+            root=str(test_data_dir / "seed2"),
+            split='train',
+            num_entities=100,
+            num_triples=500,
+            seed=42
+        )
+
+        # Should generate same triples
+        assert len(dataset1.triples) == len(dataset2.triples)
+
+        # Check first 10 triples match
+        for i in range(10):
+            t1 = dataset1.triples[i]
+            t2 = dataset2.triples[i]
+            assert t1.subject == t2.subject
+            assert t1.predicate == t2.predicate
+            assert t1.object == t2.object
+            assert abs(t1.confidence - t2.confidence) < 1e-6
+
+    def test_different_seeds_differ(self, test_data_dir):
+        """Test that different seeds produce different triples."""
+        dataset1 = KnowledgeGraphTripleDataset(
+            root=str(test_data_dir / "diff_seed1"),
+            split='train',
+            num_entities=100,
+            num_triples=500,
+            seed=42
+        )
+
+        dataset2 = KnowledgeGraphTripleDataset(
+            root=str(test_data_dir / "diff_seed2"),
+            split='train',
+            num_entities=100,
+            num_triples=500,
+            seed=123
+        )
+
+        # Should generate different triples
+        different_count = 0
+        for i in range(min(50, len(dataset1.triples))):
+            t1 = dataset1.triples[i]
+            t2 = dataset2.triples[i]
+            if t1.subject != t2.subject or t1.predicate != t2.predicate:
+                different_count += 1
+
+        assert different_count > 0, "Different seeds should produce different triples"
+
+
+if __name__ == '__main__':
+    pytest.main([__file__, '-v'])
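+
+# To reproduce the reported coverage figures (assumes pytest-cov is installed):
+#   pytest tests/data/test_kg_dataset.py -v --cov=nsm.data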