diff --git a/doc/modules/high_level_api.rst b/doc/modules/high_level_api.rst
index f608032e..e6186bd3 100644
--- a/doc/modules/high_level_api.rst
+++ b/doc/modules/high_level_api.rst
@@ -118,6 +118,25 @@ A CBM learns interpretable concept representations and uses them to predict task
         }
     )
 
+**Concept Memory Reasoner (CMR)**
+
+A neurosymbolic concept-based model with task-specific rule selection and reasoning:
+
+.. code-block:: python
+
+    from torch_concepts.nn import ConceptMemoryReasoner
+
+    model = ConceptMemoryReasoner(
+        input_size=2048,
+        annotations=annotations,
+        task_names=['class_A', 'class_B'],
+        n_rules=10,
+        memory_latent_size=100,
+        memory_decoder_hidden_layers=1,
+        selector_hidden_layers=1,
+        eps=1e-3,
+    )
+
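+Once constructed, the model is queried like any other high-level model. A
+minimal usage sketch (assuming ``x`` is a batch of input features matching
+``input_size``):
+
+.. code-block:: python
+
+    out = model(x=x, query=['class_A', 'class_B'])
+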
+""" +import torch +from sklearn.metrics import accuracy_score +from torch.nn import ModuleDict + +from torch_concepts import seed_everything +from torch_concepts.data.datasets import ToyDataset +from torch_concepts.nn import ( + CategoricalSelectorLatentToExogenous, + LinearLatentToConcept, + MixMemoryConceptExogenousToConcept, +) + + +def main(): + latent_dims = 10 + n_epochs = 500 + n_samples = 1000 + nb_rules = 10 + memory_latent_size = 100 + rec_weight = 0.1 + + seed_everything(42) + + # Load dataset + dataset = ToyDataset(dataset='xor', seed=42, n_gen=n_samples) + x_train = dataset.input_data + concept_idx = list(dataset.graph.edge_index[0].unique().numpy()) + task_idx = list(dataset.graph.edge_index[1].unique().numpy()) + c_train = dataset.concepts[:, concept_idx] + y_train = dataset.concepts[:, task_idx] + + # Get dimensions + n_features = x_train.shape[1] + n_concepts = c_train.shape[1] + n_tasks = y_train.shape[1] + + # Build model using low-level layers + latent_encoder = torch.nn.Sequential( + torch.nn.Linear(n_features, latent_dims), + torch.nn.LeakyReLU(), + ) + + exog_encoder = CategoricalSelectorLatentToExogenous( + in_latent=latent_dims, + out_concepts=n_tasks, + out_exogenous=nb_rules, + ) + + c_encoder = LinearLatentToConcept(in_latent=latent_dims, out_concepts=n_concepts) + + y_predictor = MixMemoryConceptExogenousToConcept( + in_concepts=n_concepts, + in_exogenous=nb_rules, + out_concepts=n_tasks, + memory_latent_size=memory_latent_size, + eps=0.001, + ) + model = ModuleDict( + {"latent_encoder": latent_encoder, + "exog_encoder": exog_encoder, + "concept_encoder": c_encoder, + "task_predictor": y_predictor} + ) + + optimizer = torch.optim.AdamW(model.parameters(), lr=0.01) + loss_fn_y = torch.nn.BCELoss(reduction='none') + loss_fn_c = torch.nn.BCEWithLogitsLoss() + model.train() + + for epoch in range(n_epochs): + optimizer.zero_grad() + + # Generate concept and task predictions + emb = latent_encoder(x_train) + exog = exog_encoder(latent=emb) + c_pred = c_encoder(latent=emb) + y_pred = y_predictor(concepts=c_pred, exogenous=exog) + y_pred_with_rec = y_predictor(concepts=c_pred, exogenous=exog, include_rec=True, rec_weight=rec_weight) + + # Compute loss + concept_loss = loss_fn_c(c_pred, c_train) + task_loss_no_rec = loss_fn_y(y_pred, y_train) + task_loss_rec = loss_fn_y(y_pred_with_rec, y_train) + task_loss = (y_train * task_loss_rec + (1 - y_train) * task_loss_no_rec).mean() # only apply rec loss to positive samples + loss = concept_loss + task_loss + + loss.backward() + optimizer.step() + + if epoch % 100 == 0: + task_accuracy = accuracy_score(y_train, y_pred.detach() > 0.5) + concept_accuracy = accuracy_score(c_train, c_pred.detach() > 0.) 
+
+    return
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/utilization/2.2_model/10_different_training_modes.py b/examples/utilization/2.2_model/10_different_training_modes.py
index c3df99b7..d753a24e 100644
--- a/examples/utilization/2.2_model/10_different_training_modes.py
+++ b/examples/utilization/2.2_model/10_different_training_modes.py
@@ -23,7 +23,12 @@
 import torch
 
 from torch_concepts import seed_everything
-from torch_concepts.nn import ConceptBottleneckModel, ConceptEmbeddingModel
+from torch_concepts.nn import (
+    CMRLoss,
+    ConceptBottleneckModel,
+    ConceptEmbeddingModel,
+    ConceptMemoryReasoner,
+)
 from torch_concepts.nn.modules.mid.inference import (
     DeterministicInference,
     IndependentInference
@@ -202,6 +207,44 @@ def main():
     trainer_cem.fit(model_cem, datamodule=datamodule)
     evaluate(model_cem, datamodule, n_concepts, query)
 
+    # =========================================================================
+    # CMR WITH JOINT TRAINING
+    # =========================================================================
+    print("\n" + "=" * 60)
+    print("Example 4: CMR with Joint Training")
+    print("=" * 60)
+    print("Uses DeterministicInference for both training and evaluation")
+
+    cmr_loss = CMRLoss()
+    optim_kwargs_cmr = {'lr': 0.01}
+
+    model_cmr = ConceptMemoryReasoner(
+        input_size=n_features,
+        annotations=annotations,
+        variable_distributions=variable_distributions,
+        task_names=['xor'],
+        n_rules=10,
+        memory_latent_size=100,
+        memory_decoder_hidden_layers=1,
+        selector_hidden_layers=1,
+        rec_weight=0.1,
+        eps=1e-3,
+        latent_encoder_kwargs={'hidden_size': 16, 'n_layers': 1},
+        inference=DeterministicInference,
+        train_inference=DeterministicInference,
+        lightning=True,
+        loss=cmr_loss,
+        optim_class=optim,
+        optim_kwargs=optim_kwargs_cmr,
+    )
+    print(f"Model type: {type(model_cmr).__name__}")
+    print(f"Eval inference: {model_cmr.eval_inference.__class__.__name__}")
+    print(f"Training inference: {model_cmr.train_inference.__class__.__name__}")
+
+    trainer_cmr = Trainer(max_epochs=100)
+    trainer_cmr.fit(model_cmr, datamodule=datamodule)
+    evaluate(model_cmr, datamodule, n_concepts, query)
+
 
 if __name__ == "__main__":
     main()
\ No newline at end of file
diff --git a/tests/nn/modules/high/models/test_cmr.py b/tests/nn/modules/high/models/test_cmr.py
new file mode 100644
index 00000000..9d4a79a0
--- /dev/null
+++ b/tests/nn/modules/high/models/test_cmr.py
@@ -0,0 +1,320 @@
+"""
+Comprehensive tests for the Concept Memory Reasoner (CMR).
+
+Tests cover:
+- Model initialization with various configurations
+- Forward pass and output shapes
+- CMR-specific parameter handling
+- Inference mode configuration
+- Filter methods
+- Factory behavior (PyTorch vs Lightning)
+- Cardinality constraint checks
+"""
+import unittest
+import torch
+import torch.nn as nn
+from torch.distributions import Bernoulli
+
+from torch_concepts.nn.modules.high.models.cmr import ConceptMemoryReasoner
+from torch_concepts.nn.modules.high.base.learner import BaseLearner
+from torch_concepts.annotations import AxisAnnotation, Annotations
+from torch_concepts.nn.modules.mid.inference import (
+    DeterministicInference,
+    IndependentInference,
+)
+
+
+class DummyBackbone(nn.Module):
+    """Simple backbone for testing."""
+
+    def __init__(self, out_features=8):
+        super().__init__()
+        self.out_features = out_features
+
+    def forward(self, x):
+        return torch.ones(x.shape[0], self.out_features)
+
+
+class TestCMRInitialization(unittest.TestCase):
+    """Test CMR initialization."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        self.ann = Annotations({
+            1: AxisAnnotation(
+                labels=['c1', 'c2', 'task1'],
+                cardinalities=[1, 1, 1],
+                metadata={
+                    'c1': {'type': 'binary', 'distribution': Bernoulli},
+                    'c2': {'type': 'binary', 'distribution': Bernoulli},
+                    'task1': {'type': 'binary', 'distribution': Bernoulli},
+                }
+            )
+        })
+
+    def test_init_basic(self):
+        """Test basic initialization."""
+        model = ConceptMemoryReasoner(
+            input_size=8,
+            annotations=self.ann,
+            task_names=['task1'],
+        )
+
+        self.assertIsInstance(model.model, nn.Module)
+        self.assertTrue(hasattr(model, 'inference'))
+        self.assertEqual(model.concept_names, ['c1', 'c2', 'task1'])
+
+    def test_init_with_cmr_specific_params(self):
+        """Test initialization with CMR-specific hyperparameters."""
+        model = ConceptMemoryReasoner(
+            input_size=8,
+            annotations=self.ann,
+            task_names=['task1'],
+            n_rules=7,
+            memory_latent_size=64,
+            memory_decoder_hidden_layers=2,
+            selector_hidden_layers=2,
+            eps=1e-4,
+        )
+
+        self.assertIsInstance(model.model, nn.Module)
+        self.assertTrue(hasattr(model, 'eval_inference'))
+        self.assertTrue(hasattr(model, 'train_inference'))
+
+    def test_init_with_backbone(self):
+        """Test initialization with a custom backbone."""
+        backbone = DummyBackbone()
+        model = ConceptMemoryReasoner(
+            input_size=8,
+            annotations=self.ann,
+            backbone=backbone,
+            task_names=['task1'],
+        )
+
+        self.assertIsNotNone(model.backbone)
+
+    def test_init_with_latent_encoder(self):
+        """Test initialization with a latent encoder config."""
+        model = ConceptMemoryReasoner(
+            input_size=8,
+            annotations=self.ann,
+            task_names=['task1'],
+            latent_encoder_kwargs={'hidden_size': 16, 'n_layers': 2},
+        )
+
+        self.assertEqual(model.latent_size, 16)
+
+    def test_init_with_deterministic_inference(self):
+        """Test initialization with deterministic inference."""
+        model = ConceptMemoryReasoner(
+            input_size=8,
+            annotations=self.ann,
+            task_names=['task1'],
+            inference=DeterministicInference,
+        )
+        model.eval()
+
+        self.assertIsInstance(model.inference, DeterministicInference)
+
+    def test_init_with_independent_train_inference(self):
+        """Test initialization with independent train inference."""
+        model = ConceptMemoryReasoner(
+            input_size=8,
+            annotations=self.ann,
+            task_names=['task1'],
+            inference=DeterministicInference,
+            train_inference=IndependentInference,
+        )
+        model.train()
+
+        self.assertIsInstance(model.inference, IndependentInference)
+
+    def test_factory_default_is_pytorch(self):
+        """Test that the default lightning=False creates a pure PyTorch model."""
+        model = ConceptMemoryReasoner(
+            input_size=8,
+            annotations=self.ann,
+            task_names=['task1'],
+        )
+
+        self.assertFalse(isinstance(model, BaseLearner))
+
+    def test_factory_lightning_training(self):
+        """Test that lightning=True creates a Lightning model."""
+        model = ConceptMemoryReasoner(
+            lightning=True,
+            input_size=8,
+            annotations=self.ann,
+            task_names=['task1'],
+        )
+
+        self.assertIsInstance(model, BaseLearner)
+
+    def test_cardinality_constraint(self):
+        """Test the CMR cardinality constraint for concept variables."""
+        ann_bad = Annotations({
+            1: AxisAnnotation(
+                labels=['c1', 'c2', 'task1'],
+                cardinalities=[2, 1, 1],
+                metadata={
+                    'c1': {'type': 'binary', 'distribution': Bernoulli},
+                    'c2': {'type': 'binary', 'distribution': Bernoulli},
+                    'task1': {'type': 'binary', 'distribution': Bernoulli},
+                }
+            )
+        })
+
+        with self.assertRaises(AssertionError):
+            ConceptMemoryReasoner(
+                input_size=8,
+                annotations=ann_bad,
+                task_names=['task1'],
+            )
+
+
+class TestCMRForward(unittest.TestCase):
+    """Test the CMR forward pass."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        self.ann = Annotations({
+            1: AxisAnnotation(
+                labels=['c1', 'c2', 'task1'],
+                cardinalities=[1, 1, 1],
+                metadata={
+                    'c1': {'type': 'binary', 'distribution': Bernoulli},
+                    'c2': {'type': 'binary', 'distribution': Bernoulli},
+                    'task1': {'type': 'binary', 'distribution': Bernoulli},
+                }
+            )
+        })
+
+        self.model = ConceptMemoryReasoner(
+            input_size=8,
+            annotations=self.ann,
+            task_names=['task1'],
+        )
+
+    def test_forward_basic(self):
+        """Test basic forward pass."""
+        x = torch.randn(2, 8)
+        query = ['c1', 'c2']
+        out = self.model(query=query, x=x)
+
+        self.assertEqual(out.shape[0], 2)
+        self.assertEqual(out.shape[1], 2)
+
+    def test_forward_all_concepts(self):
+        """Test forward with concepts and task."""
+        x = torch.randn(4, 8)
+        query = ['c1', 'c2', 'task1']
+        out = self.model(query=query, x=x)
+
+        self.assertEqual(out.shape[0], 4)
+        self.assertEqual(out.shape[1], 3)
+
+    def test_forward_only_task(self):
+        """Test forward with only the task variable."""
+        x = torch.randn(3, 8)
+        query = ['task1']
+        out = self.model(query=query, x=x)
+
+        self.assertEqual(out.shape[0], 3)
+        self.assertEqual(out.shape[1], 1)
+
+    def test_forward_with_backbone(self):
+        """Test forward pass with a backbone."""
+        backbone = DummyBackbone(out_features=8)
+        model = ConceptMemoryReasoner(
+            input_size=8,
+            annotations=self.ann,
+            backbone=backbone,
+            task_names=['task1'],
+        )
+
+        x = torch.randn(2, 100)
+        query = ['c1', 'task1']
+        out = model(query=query, x=x)
+
+        self.assertEqual(out.shape[0], 2)
+        self.assertEqual(out.shape[1], 2)
+
+
+class TestCMRFilterMethods(unittest.TestCase):
+    """Test CMR filter methods."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        ann = Annotations({
+            1: AxisAnnotation(
+                labels=['c1', 'c2', 'task1'],
+                cardinalities=[1, 1, 1],
+                metadata={
+                    'c1': {'type': 'binary', 'distribution': Bernoulli},
+                    'c2': {'type': 'binary', 'distribution': Bernoulli},
+                    'task1': {'type': 'binary', 'distribution': Bernoulli},
+                }
+            )
+        })
+
+        self.model = ConceptMemoryReasoner(
+            input_size=8,
+            annotations=ann,
+            task_names=['task1'],
+        )
+
+    def test_filter_output_for_loss(self):
+        """Test filter_output_for_loss returns explicit CMR loss kwargs."""
+        x = torch.randn(2, 8)
+        query = ['c1', 'c2', 'task1']
+        out_no_rec = self.model(query=query, x=x, include_rec=False, rec_weight=0.1)
+        out_with_rec = self.model(query=query, x=x, include_rec=True, rec_weight=0.1)
+        target = torch.randint(0, 2, out_no_rec.shape).float()
+
+        filtered = self.model.filter_output_for_loss(
+            {'no_rec': out_no_rec, 'with_rec': out_with_rec},
+            target,
+        )
+
+        self.assertIsInstance(filtered, dict)
+        self.assertIn('concept_input', filtered)
+        self.assertIn('concept_target', filtered)
+        self.assertIn('task_input', filtered)
+        self.assertIn('task_input_with_rec', filtered)
+        self.assertIn('task_target', filtered)
+
+        self.assertEqual(filtered['concept_input'].shape, (2, 2))
+        self.assertEqual(filtered['concept_target'].shape, (2, 2))
+        self.assertEqual(filtered['task_input'].shape, (2, 1))
+        self.assertEqual(filtered['task_input_with_rec'].shape, (2, 1))
+        self.assertEqual(filtered['task_target'].shape, (2, 1))
+
+    def test_filter_output_for_loss_requires_dict(self):
+        """Test filter_output_for_loss rejects non-dict and incomplete dict inputs."""
+        out = torch.randn(2, 3)
+        target = torch.randint(0, 2, out.shape).float()
+
+        with self.assertRaises(ValueError):
+            self.model.filter_output_for_loss(out, target)
+
+        with self.assertRaises(ValueError):
+            self.model.filter_output_for_loss({'no_rec': out}, target)
+
+    def test_filter_output_for_metrics(self):
+        """Test filter_output_for_metrics returns the correct format."""
+        x = torch.randn(2, 8)
+        query = ['c1', 'c2', 'task1']
+        out = self.model(query=query, x=x)
+        target = torch.randint(0, 2, out.shape).float()
+
+        filtered = self.model.filter_output_for_metrics(out, target)
+
+        self.assertIsInstance(filtered, dict)
+        self.assertIn('preds', filtered)
+        self.assertIn('target', filtered)
+        self.assertTrue(torch.allclose(filtered['preds'], out))
+        self.assertTrue(torch.allclose(filtered['target'], target))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/nn/modules/low/encoders/test_selector.py b/tests/nn/modules/low/encoders/test_selector.py
index 64e5af14..4cecf6a5 100644
--- a/tests/nn/modules/low/encoders/test_selector.py
+++ b/tests/nn/modules/low/encoders/test_selector.py
@@ -6,7 +6,10 @@
 import unittest
 import torch
 import torch.nn as nn
 
-from torch_concepts.nn.modules.low.encoders.selector import SelectorLatentToExogenous
+from torch_concepts.nn.modules.low.encoders.selector import (
+    CategoricalSelectorLatentToExogenous,
+    SelectorLatentToExogenous,
+)
 
 
 class TestSelectorLatentToExogenous(unittest.TestCase):
@@ -128,5 +131,111 @@ def test_batch_processing(self):
         self.assertEqual(output.shape, (batch_size, 3, 4))
 
 
+class TestCategoricalSelectorLatentToExogenous(unittest.TestCase):
+    """Test CategoricalSelectorLatentToExogenous."""
+
+    def test_initialization(self):
+        """Test selector initialization."""
+        selector = CategoricalSelectorLatentToExogenous(
+            in_latent=64,
+            out_concepts=5,
+            out_exogenous=8,
+            selector_hidden_layers=2,
+        )
+        self.assertEqual(selector.in_latent, 64)
+        self.assertEqual(selector.out_concepts, 5)
+        self.assertEqual(selector.out_exogenous, 8)
+        self.assertEqual(selector.selector_hidden_layers, 2)
+
+    def test_forward_shape(self):
+        """Test forward pass output shape."""
+        selector = CategoricalSelectorLatentToExogenous(
+            in_latent=64,
+            out_concepts=4,
+            out_exogenous=6,
+        )
+        latent = torch.randn(2, 64)
+        output = selector(latent=latent)
+        self.assertEqual(output.shape, (2, 4, 6))
+
+    def test_output_is_normalized_over_exogenous_dim(self):
+        """Test output probabilities sum to 1 over the exogenous dimension."""
+        selector = CategoricalSelectorLatentToExogenous(
+            in_latent=32,
+            out_concepts=3,
+            out_exogenous=5,
+        )
+        latent = torch.randn(3, 32)
+        output = selector(latent=latent)
+
+        sums = output.sum(dim=-1)
+        self.assertTrue(torch.allclose(sums, torch.ones_like(sums), atol=1e-5))
+
+    def test_gradient_flow(self):
+        """Test gradient flow through the selector."""
+        selector = CategoricalSelectorLatentToExogenous(
+            in_latent=32,
+            out_concepts=3,
+            out_exogenous=4,
+        )
+        embeddings = torch.randn(2, 32, requires_grad=True)
+        output = selector(latent=embeddings)
+        loss = output.sum()
+        loss.backward()
+        self.assertIsNotNone(embeddings.grad)
+
+    def test_hidden_layers_configuration(self):
+        """Test configurable hidden layers in the selector network."""
+        selector_zero = CategoricalSelectorLatentToExogenous(
+            in_latent=32,
+            out_concepts=3,
+            out_exogenous=4,
+            selector_hidden_layers=0,
+        )
+        selector_two = CategoricalSelectorLatentToExogenous(
+            in_latent=32,
+            out_concepts=3,
+            out_exogenous=4,
+            selector_hidden_layers=2,
+        )
+
+        linear_zero = sum(isinstance(layer, nn.Linear) for layer in selector_zero.selector)
+        linear_two = sum(isinstance(layer, nn.Linear) for layer in selector_two.selector)
+
+        self.assertEqual(linear_zero, 1)
+        self.assertEqual(linear_two, 3)
+
+    def test_selector_hidden_layers_validation(self):
+        """Test hidden layer argument validation."""
+        with self.assertRaises(ValueError):
+            CategoricalSelectorLatentToExogenous(
+                in_latent=32,
+                out_concepts=3,
+                out_exogenous=4,
+                selector_hidden_layers=-1,
+            )
+
+    def test_selector_network(self):
+        """Test selector network structure."""
+        selector = CategoricalSelectorLatentToExogenous(
+            in_latent=64,
+            out_concepts=4,
+            out_exogenous=6,
+        )
+        self.assertIsInstance(selector.selector, nn.Sequential)
+
+    def test_batch_processing(self):
+        """Test different batch sizes."""
+        selector = CategoricalSelectorLatentToExogenous(
+            in_latent=32,
+            out_concepts=3,
+            out_exogenous=4,
+        )
+        for batch_size in [1, 4, 8]:
+            embeddings = torch.randn(batch_size, 32)
+            output = selector(latent=embeddings)
+            self.assertEqual(output.shape, (batch_size, 3, 4))
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/nn/modules/low/predictors/test_exogenous.py b/tests/nn/modules/low/predictors/test_exogenous.py
index 5149f12f..4a9d1893 100644
--- a/tests/nn/modules/low/predictors/test_exogenous.py
+++ b/tests/nn/modules/low/predictors/test_exogenous.py
@@ -6,7 +6,7 @@
 import unittest
 import torch
 import torch.nn as nn
 
-from torch_concepts.nn import MixConceptExogegnousToConcept
+from torch_concepts.nn import MixConceptExogegnousToConcept, MixMemoryConceptExogenousToConcept
 
 
 class TestMixConceptExogegnousToConcept(unittest.TestCase):
@@ -67,5 +67,110 @@ def test_gradient_flow(self):
         self.assertIsNotNone(exogenous.grad)
 
 
+class TestMixMemoryConceptExogenousToConcept(unittest.TestCase):
+    """Test MixMemoryConceptExogenousToConcept."""
+
+    def test_initialization(self):
+        """Test predictor initialization."""
+        predictor = MixMemoryConceptExogenousToConcept(
+            in_concepts=10,
+            in_exogenous=5,
+            out_concepts=3,
+            memory_latent_size=64,
+            memory_decoder_hidden_layers=2,
+            eps=1e-3,
+        )
+        self.assertEqual(predictor.in_concepts, 10)
+        self.assertEqual(predictor.in_exogenous, 5)
+        self.assertEqual(predictor.out_concepts, 3)
+        self.assertEqual(predictor.memory_decoder_hidden_layers, 2)
+        self.assertEqual(predictor.memory.weight.shape, (3, 64))
+        self.assertEqual(predictor.memory_network_shape, (5, 10, 3))
+
+    def test_forward_shape(self):
+        """Test forward pass output shape."""
+        predictor = MixMemoryConceptExogenousToConcept(
+            in_concepts=8,
+            in_exogenous=4,
+            out_concepts=2,
+        )
+        concepts = torch.randn(4, 8)
+        exogenous = torch.softmax(torch.randn(4, 2, 4), dim=-1)
+        output = predictor(concepts=concepts, exogenous=exogenous)
+        self.assertEqual(output.shape, (4, 2))
+
+    def test_forward_with_optional_flags(self):
+        """Test forward pass with the include_rec and hard_roles options."""
+        predictor = MixMemoryConceptExogenousToConcept(
+            in_concepts=6,
+            in_exogenous=3,
+            out_concepts=2,
+        )
+        concepts = torch.randn(3, 6)
+        exogenous = torch.softmax(torch.randn(3, 2, 3), dim=-1)
+
+        output_rec = predictor(
+            concepts=concepts,
+            exogenous=exogenous,
+            include_rec=True,
+            rec_weight=1.0,
+        )
+        output_hard = predictor(
+            concepts=concepts,
+            exogenous=exogenous,
+            hard_roles=True,
+        )
+        self.assertEqual(output_rec.shape, (3, 2))
+        self.assertEqual(output_hard.shape, (3, 2))
+
+    def test_memory_decoder_hidden_layers_config(self):
+        """Test the configurable number of hidden layers in the memory decoder."""
+        predictor_zero = MixMemoryConceptExogenousToConcept(
+            in_concepts=4,
+            in_exogenous=3,
+            out_concepts=2,
+            memory_decoder_hidden_layers=0,
+        )
+        predictor_two = MixMemoryConceptExogenousToConcept(
+            in_concepts=4,
+            in_exogenous=3,
+            out_concepts=2,
+            memory_decoder_hidden_layers=2,
+        )
+
+        linear_zero = sum(isinstance(layer, torch.nn.Linear) for layer in predictor_zero.memory_decoder)
+        linear_two = sum(isinstance(layer, torch.nn.Linear) for layer in predictor_two.memory_decoder)
+
+        self.assertEqual(linear_zero, 1)
+        self.assertEqual(linear_two, 3)
+
+    def test_memory_decoder_hidden_layers_validation(self):
+        """Test hidden layer argument validation."""
+        with self.assertRaises(ValueError):
+            MixMemoryConceptExogenousToConcept(
+                in_concepts=4,
+                in_exogenous=2,
+                out_concepts=1,
+                memory_decoder_hidden_layers=-1,
+            )
+
+    def test_gradient_flow(self):
+        """Test gradients flow to exogenous and memory while concepts stay detached."""
+        predictor = MixMemoryConceptExogenousToConcept(
+            in_concepts=5,
+            in_exogenous=4,
+            out_concepts=2,
+        )
+        concepts = torch.randn(2, 5, requires_grad=True)
+        exogenous = torch.softmax(torch.randn(2, 2, 4), dim=-1).requires_grad_()
+        output = predictor(concepts=concepts, exogenous=exogenous)
+        loss = output.sum()
+        loss.backward()
+
+        self.assertIsNone(concepts.grad)
+        self.assertIsNotNone(exogenous.grad)
+        self.assertIsNotNone(predictor.memory.weight.grad)
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/nn/modules/test_loss.py b/tests/nn/modules/test_loss.py
index c8fc37e7..cd9fbb0b 100644
--- a/tests/nn/modules/test_loss.py
+++ b/tests/nn/modules/test_loss.py
@@ -4,11 +4,12 @@
 Tests loss functions for concept-based learning:
 - ConceptLoss: Unified loss for concepts with different types
 - WeightedConceptLoss: Weighted combination of concept and task losses
+- CMRLoss: Explicit CMR objective with reconstruction-aware task branch
 """
 import unittest
 import torch
 from torch import nn
-from torch_concepts.nn.modules.loss import ConceptLoss, WeightedConceptLoss
+from torch_concepts.nn.modules.loss import CMRLoss, ConceptLoss, WeightedConceptLoss
 from torch_concepts.nn.modules.utils import GroupConfig
 from torch_concepts.annotations import AxisAnnotation, Annotations
 
@@ -419,5 +420,95 @@ def test_unused_loss_warning(self):
         self.assertTrue(any("continuous" in str(warning.message).lower() for warning in w))
 
 
+class TestCMRLoss(unittest.TestCase):
+    """Test the explicit CMR objective implemented by CMRLoss."""
+
+    def setUp(self):
+        """Set up tensors for explicit CMR loss inputs."""
+        torch.manual_seed(42)
+        self.batch_size = 12
+        self.n_concepts = 4
+        self.n_tasks = 2
+
+        self.concept_input = torch.randn(self.batch_size, self.n_concepts)
+        self.concept_target = torch.randint(
+            0, 2, (self.batch_size, self.n_concepts)
+        ).float()
+
+        self.task_input = torch.rand(self.batch_size, self.n_tasks)
+        self.task_input_with_rec = torch.rand(self.batch_size, self.n_tasks)
+        self.task_target = torch.randint(
+            0, 2, (self.batch_size, self.n_tasks)
+        ).float()
+
+    def test_basic_forward(self):
+        """Test basic forward pass with explicit tensors."""
+        loss_fn = CMRLoss()
+        loss = loss_fn(
+            concept_input=self.concept_input,
+            concept_target=self.concept_target,
+            task_input=self.task_input,
+            task_input_with_rec=self.task_input_with_rec,
+            task_target=self.task_target,
+        )
+
+        self.assertIsInstance(loss, torch.Tensor)
+        self.assertEqual(loss.shape, ())
+        self.assertTrue(loss >= 0)
+
+    def test_requires_explicit_tensors(self):
+        """Test that missing explicit CMR kwargs raise a ValueError."""
+        loss_fn = CMRLoss()
+        with self.assertRaises(ValueError):
+            loss_fn(input=self.task_input, target=self.task_target)
+
+    def test_gradient_flow(self):
+        """Test gradients flow through CMRLoss inputs."""
+        loss_fn = CMRLoss()
+
+        concept_input = self.concept_input.clone().detach().requires_grad_(True)
+        task_input = self.task_input.clone().detach().requires_grad_(True)
+        task_input_with_rec = self.task_input_with_rec.clone().detach().requires_grad_(True)
+
+        loss = loss_fn(
+            concept_input=concept_input,
+            concept_target=self.concept_target,
+            task_input=task_input,
+            task_input_with_rec=task_input_with_rec,
+            task_target=self.task_target,
+        )
+        loss.backward()
+
+        self.assertIsNotNone(concept_input.grad)
+        self.assertIsNotNone(task_input.grad)
+        self.assertIsNotNone(task_input_with_rec.grad)
+        self.assertTrue(torch.any(concept_input.grad != 0))
+        self.assertTrue(torch.any(task_input.grad != 0))
+        self.assertTrue(torch.any(task_input_with_rec.grad != 0))
+
+    def test_weight_parameters_affect_value(self):
+        """Test that changing concept/task weights changes the final loss."""
+        loss_concept_only = CMRLoss(concept_weight=1.0, task_weight=0.0)(
+            concept_input=self.concept_input,
+            concept_target=self.concept_target,
+            task_input=self.task_input,
+            task_input_with_rec=self.task_input_with_rec,
+            task_target=self.task_target,
+        )
+        loss_task_only = CMRLoss(concept_weight=0.0, task_weight=1.0)(
+            concept_input=self.concept_input,
+            concept_target=self.concept_target,
+            task_input=self.task_input,
+            task_input_with_rec=self.task_input_with_rec,
+            task_target=self.task_target,
+        )
+
+        self.assertNotAlmostEqual(
+            loss_concept_only.item(),
+            loss_task_only.item(),
+            places=5,
+        )
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/torch_concepts/nn/__init__.py b/torch_concepts/nn/__init__.py
index 2b5b9d75..2eb8f978 100644
--- a/torch_concepts/nn/__init__.py
+++ b/torch_concepts/nn/__init__.py
@@ -22,11 +22,17 @@
 from .modules.low.encoders.linear import LinearLatentToConcept, LinearExogenousToConcept
 from .modules.low.encoders.exogenous import LinearLatentToExogenous
 from .modules.low.encoders.stochastic import StochasticLatentToConcept
-from .modules.low.encoders.selector import SelectorLatentToExogenous
+from .modules.low.encoders.selector import (
+    CategoricalSelectorLatentToExogenous,
+    SelectorLatentToExogenous,
+)
 
 # Predictors
 from .modules.low.predictors.linear import LinearConceptToConcept
-from .modules.low.predictors.exogenous import MixConceptExogegnousToConcept
+from .modules.low.predictors.exogenous import (
+    MixConceptExogegnousToConcept,
+    MixMemoryConceptExogenousToConcept,
+)
 from .modules.low.predictors.hypernet import HyperlinearConceptExogenousToConcept
 from .modules.low.predictors.call import CallableConceptToConcept
 
@@ -37,7 +43,7 @@
 from .modules.low.graph.wanda import WANDAGraphLearner
 
 # Loss functions
-from .modules.loss import ConceptLoss, WeightedConceptLoss
+from .modules.loss import CMRLoss, ConceptLoss, WeightedConceptLoss
 
 # Metrics
 from .modules.metrics import ConceptMetrics
@@ -46,6 +52,7 @@
 from .modules.high.models.blackbox import BlackBox, BlackBoxTaskOnly
 from .modules.high.models.cbm import ConceptBottleneckModel
 from .modules.high.models.cem import ConceptEmbeddingModel
+from .modules.high.models.cmr import ConceptMemoryReasoner
 
 
@@ -105,6 +112,7 @@
     # Predictor classes
     "LinearConceptToConcept",
     "MixConceptExogegnousToConcept",
+    "MixMemoryConceptExogenousToConcept",
     "HyperlinearConceptExogenousToConcept",
     "CallableConceptToConcept",
 
@@ -113,12 +121,14 @@
     "ResidualMLP",
     "MLP",
 
+    "CategoricalSelectorLatentToExogenous",
     "SelectorLatentToExogenous",
 
     # COSMO
     "WANDAGraphLearner",
 
     # Loss functions
+    "CMRLoss",
     "ConceptLoss",
     "WeightedConceptLoss",
 
@@ -130,6 +140,7 @@
     "BlackBoxTaskOnly",
     "ConceptBottleneckModel",
     "ConceptEmbeddingModel",
+    "ConceptMemoryReasoner",
 
     # Models (mid-level)
     "ParametricCPD",
diff --git a/torch_concepts/nn/modules/high/models/cmr.py b/torch_concepts/nn/modules/high/models/cmr.py
new file mode 100644
index 00000000..020edeeb
--- /dev/null
+++ b/torch_concepts/nn/modules/high/models/cmr.py
@@ -0,0 +1,255 @@
+"""Concept-based Memory Reasoner (CMR).
+
+References:
+    Debot et al. "Interpretable Concept-Based Memory Reasoning", NeurIPS 2024.
+    https://arxiv.org/abs/2407.15527
+"""
+
+from typing import List, Optional, Union
+
+import torch
+
+from .....annotations import Annotations
+
+from ...low.base.inference import BaseInference
+from ...low.encoders.linear import LinearLatentToConcept
+from ...low.encoders.selector import CategoricalSelectorLatentToExogenous
+from ...low.predictors.exogenous import MixMemoryConceptExogenousToConcept
+from ...low.lazy import LazyConstructor
+
+from ...mid.inference.deterministic import DeterministicInference
+from ...mid.constructors.bipartite import BipartiteModel
+
+from ..base.bipartite import BaseBipartiteModel
+
+
+class ConceptMemoryReasoner(BaseBipartiteModel):
+    """Concept Memory Reasoner with configurable training mode.
+
+    A unified CMR class that works as a pure PyTorch module by default,
+    or as a Lightning module when ``lightning=True``.
+
+    Parameters
+    ----------
+    input_size : int
+        Dimensionality of input features (after backbone if used).
+    annotations : Annotations
+        Concept annotations with labels, cardinalities, and distributions.
+    task_names : Union[List[str], str]
+        Names of task variables (subset of annotation labels).
+    n_rules : int, optional
+        Number of candidate rules per task. Defaults to 10.
+    memory_latent_size : int, optional
+        Latent size of the task-specific rule memory. Defaults to 100.
+    memory_decoder_hidden_layers : int, optional
+        Number of hidden layers in the rule memory decoder. Defaults to 1.
+    selector_hidden_layers : int, optional
+        Number of hidden layers in the rule selector MLP. Defaults to 1.
+    rec_weight : float, optional
+        Exponent applied to the rule reconstruction-quality term in CMR's
+        reconstruction-aware task prediction. Defaults to 0.1.
+    eps : float, optional
+        Numerical scaling factor used in the memory decoder softmax.
+        Defaults to 1e-3.
+    lightning : bool, default False
+        If True, adds Lightning training capabilities.
+        If False (default), works as a pure PyTorch module.
+    inference : BaseInference, optional
+        Inference engine class for evaluation. Defaults to DeterministicInference.
+    inference_kwargs : dict, optional
+        Keyword arguments for the evaluation inference engine.
+    train_inference : BaseInference, optional
+        Inference engine class for training. Only used when lightning=True.
+        Defaults to DeterministicInference.
+    train_inference_kwargs : dict, optional
+        Keyword arguments for the training inference engine.
+    variable_distributions : Mapping, optional
+        Distribution classes for each concept if not in annotations.
+    **kwargs
+        Additional arguments passed to BaseBipartiteModel, including:
+
+        - **backbone** : Feature extraction module (e.g., ResNet)
+        - **latent_encoder** : Custom encoder for latent space
+        - **latent_encoder_kwargs** : Arguments for latent encoder
+
+        Lightning training (when lightning=True):
+
+        - **loss** : Loss function (nn.Module)
+        - **metrics** : ConceptMetrics or dict of MetricCollections
+        - **optim_class** : Optimizer class (e.g., torch.optim.Adam)
+        - **optim_kwargs** : Optimizer arguments (e.g., {'lr': 0.001})
+        - **scheduler_class** : LR scheduler class
+        - **scheduler_kwargs** : Scheduler arguments
+
+    Examples
+    --------
+    >>> # Pure PyTorch module (default)
+    >>> model = ConceptMemoryReasoner(
+    ...     input_size=8,
+    ...     annotations=ann,
+    ...     task_names=['task'],
+    ...     n_rules=10,
+    ...     rec_weight=0.1,
+    ... )
+    >>> out = model(x, query=['c1', 'task'])  # Direct forward pass
+
+    >>> # Lightning training enabled
+    >>> model = ConceptMemoryReasoner(
+    ...     lightning=True,
+    ...     input_size=8,
+    ...     annotations=ann,
+    ...     task_names=['task'],
+    ...     n_rules=10,
+    ...     rec_weight=0.1,
+    ...     loss=my_loss,
+    ...     optim_class=torch.optim.Adam,
+    ...     optim_kwargs={'lr': 0.001}
+    ... )
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        annotations: Annotations,
+        task_names: Union[List[str], str],
+        n_rules: int = 10,
+        memory_latent_size: int = 100,
+        memory_decoder_hidden_layers: int = 1,
+        selector_hidden_layers: int = 1,
+        rec_weight: float = 0.1,
+        eps: float = 1e-3,
+        inference: Optional[BaseInference] = DeterministicInference,
+        inference_kwargs: Optional[dict] = None,
+        train_inference: Optional[BaseInference] = DeterministicInference,
+        train_inference_kwargs: Optional[dict] = None,
+        lightning: bool = False,
+        **kwargs
+    ):
+        super().__init__(
+            input_size=input_size,
+            annotations=annotations,
+            task_names=task_names,
+            lightning=lightning,
+            **kwargs
+        )
+
+        # Extract concept cardinalities (excluding tasks)
+        concept_idxs = [self.concept_names.index(name) for name in self.concept_names
+                        if name not in self.task_names]
+        cardinalities = [self.concept_annotations.cardinalities[i] for i in concept_idxs]
+        assert all(cardinality == 1 for cardinality in cardinalities), (
+            "ConceptMemoryReasoner currently requires all concepts "
+            f"to have cardinality 1, got {cardinalities}."
+        )
+
+        self.rec_weight = rec_weight
+
+        # Build bipartite model architecture with CMR components
+        self.model = BipartiteModel(
+            task_names=task_names,
+            input_size=self.latent_size,
+            annotations=annotations,
+            encoder=LazyConstructor(LinearLatentToConcept),
+            internal_exogenous=LazyConstructor(
+                CategoricalSelectorLatentToExogenous,
+                out_exogenous=n_rules,
+                selector_hidden_layers=selector_hidden_layers,
+            ),
+            predictor=LazyConstructor(
+                MixMemoryConceptExogenousToConcept,
+                memory_latent_size=memory_latent_size,
+                memory_decoder_hidden_layers=memory_decoder_hidden_layers,
+                eps=eps,
+            ),
+            use_source_exogenous=False,
+        )
+
+        self.eval_inference = inference(
+            self.model.probabilistic_model,
+            **(inference_kwargs or {})
+        )
+        self.train_inference = train_inference(
+            self.model.probabilistic_model,
+            **(train_inference_kwargs or {})
+        )
+
+    def filter_output_for_loss(self, forward_out, target):
+        """Build explicit CMR loss kwargs for :class:`CMRLoss`.
+
+        Parameters
+        ----------
+        forward_out : dict
+            Dictionary with keys ``no_rec`` and ``with_rec`` containing full
+            model predictions (concepts + tasks).
+        target : torch.Tensor
+            Ground-truth tensor aligned with ``self.concept_names``.
+
+        Returns
+        -------
+        dict
+            Explicit tensors required by ``CMRLoss``.
+        """
+        if not isinstance(forward_out, dict):
+            raise ValueError(
+                "ConceptMemoryReasoner.filter_output_for_loss expects a dict "
+                "with 'no_rec' and 'with_rec' predictions."
+            )
+
+        if 'no_rec' not in forward_out or 'with_rec' not in forward_out:
+            raise ValueError(
+                "ConceptMemoryReasoner.filter_output_for_loss requires both "
+                "'no_rec' and 'with_rec' entries in forward_out."
+            )
+
+        no_rec = forward_out['no_rec']
+        with_rec = forward_out['with_rec']
+
+        task_indices = [
+            i for i, name in enumerate(self.concept_names)
+            if name in self.task_names
+        ]
+        concept_indices = [
+            i for i, name in enumerate(self.concept_names)
+            if name not in self.task_names
+        ]
+
+        return {
+            'concept_input': no_rec[:, concept_indices],
+            'concept_target': target[:, concept_indices],
+            'task_input': no_rec[:, task_indices],
+            'task_input_with_rec': with_rec[:, task_indices],
+            'task_target': target[:, task_indices],
+        }
+
+    def shared_step(self, batch, step):
+        """CMR-specific Lightning step using explicit no-rec/with-rec losses."""
+        inputs, concepts, _ = self.unpack_batch(batch)
+        batch_size = batch['inputs']['x'].size(0)
+        c = concepts['c']
+
+        inference_kwargs = self._get_inference_kwargs(batch)
+
+        out_no_rec = self.forward(
+            x=inputs['x'],
+            query=self.concept_names,
+            evidence=None,
+            include_rec=False,
+            rec_weight=self.rec_weight,
+            **inference_kwargs,
+        )
+        out_with_rec = self.forward(
+            x=inputs['x'],
+            query=self.concept_names,
+            evidence=None,
+            include_rec=True,
+            rec_weight=self.rec_weight,
+            **inference_kwargs,
+        )
+
+        loss = None
+        if self.loss is not None:
+            loss_args = self.filter_output_for_loss(
+                {'no_rec': out_no_rec, 'with_rec': out_with_rec},
+                c,
+            )
+            loss = self.loss(**loss_args)
+            self.log_loss(step, loss, batch_size=batch_size)
+
+        metrics_args = self.filter_output_for_metrics(out_no_rec, c)
+        self.update_and_log_metrics(metrics_args, step, batch_size)
+        return loss
diff --git a/torch_concepts/nn/modules/loss.py b/torch_concepts/nn/modules/loss.py
index 4b1c6d56..1e367773 100644
--- a/torch_concepts/nn/modules/loss.py
+++ b/torch_concepts/nn/modules/loss.py
@@ -214,4 +214,103 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         c_loss = self.concept_loss(concept_input, concept_target)
         t_loss = self.task_loss(task_input, task_target)
 
-        return c_loss * self.concept_weight + t_loss * self.task_weight
\ No newline at end of file
+        return c_loss * self.concept_weight + t_loss * self.task_weight
+
+
+class CMRLoss(nn.Module):
+    """
+    Loss for the Concept-based Memory Reasoner (CMR).
+
+    Implements the objective used in the CMR examples:
+
+    - concept loss on concept logits
+    - task loss without the reconstruction term
+    - task loss with the reconstruction term
+    - a blended task objective that applies the reconstruction-aware loss to
+      positive targets and the standard loss to negative targets
+
+    Args:
+        concept_weight: Weight applied to the concept loss.
+        task_weight: Weight applied to the blended task loss.
+    """
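+    # Per element, the blended task objective computed below is
+    #     L_task = y * BCE(p_rec, y) + (1 - y) * BCE(p_no_rec, y)
+    # i.e. the reconstruction-aware probabilities are only supervised on
+    # positive targets, matching the low-level CMR example script.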
+ ) \ No newline at end of file diff --git a/torch_concepts/nn/modules/low/encoders/selector.py b/torch_concepts/nn/modules/low/encoders/selector.py index f525424c..0f546eb9 100644 --- a/torch_concepts/nn/modules/low/encoders/selector.py +++ b/torch_concepts/nn/modules/low/encoders/selector.py @@ -12,6 +12,102 @@ from ..base.layer import BaseEncoder +class CategoricalSelectorLatentToExogenous(BaseEncoder): + """ + Categorical selector that outputs concept-wise assignment probabilities. + + This module maps latent inputs to logits of shape + ``(batch_size, out_concepts, out_exogenous)`` and applies a softmax over + the exogenous dimension to produce normalized mixing probabilities. + + Attributes: + out_exogenous (int): Hidden width used in the selector MLP. + out_concepts (int): Number of output concepts. + selector_hidden_layers (int): Number of hidden layers in the selector MLP. + selector (nn.Sequential): Attention network for memory selection. + + Args: + in_latent: Number of input latent features. + out_exogenous: Number of output exogenous features. + out_concepts: Number of output concept representations. + selector_hidden_layers: Number of hidden layers in the selector MLP. + Must be >= 0. + *args: Additional positional arguments for linear layers in the selector. + **kwargs: Additional keyword arguments for linear layers in the selector. + + References: + Debot et al. "Interpretable Concept-Based Memory Reasoning", NeurIPS 2024. https://arxiv.org/abs/2407.15527 + """ + def __init__( + self, + in_latent: int, + out_exogenous: int, # nb_rules + out_concepts: int, # nb_tasks + selector_hidden_layers: int = 1, + *args, + **kwargs, + ): + """ + Initialize the categorical selector. + + Args: + in_latent: Number of input latent features. + out_exogenous: Number of output exogenous features. + out_concepts: Number of output concepts. + selector_hidden_layers: Number of hidden layers in the selector + MLP. Must be >= 0. + *args: Additional positional arguments for linear layers in the selector. + **kwargs: Additional keyword arguments for linear layers in the selector. + """ + super().__init__( + in_latent=in_latent, + out_concepts=out_concepts, + ) + if selector_hidden_layers < 0: + raise ValueError("selector_hidden_layers must be >= 0") + + self.out_exogenous = out_exogenous + self.out_concepts = out_concepts + self.selector_hidden_layers = selector_hidden_layers + self._selector_out_shape = (out_concepts, out_exogenous) + self._selector_out_dim = np.prod(self._selector_out_shape).item() + + selector_layers = [] + in_features = in_latent + for _ in range(selector_hidden_layers): + selector_layers.extend([ + torch.nn.Linear(in_features, in_latent, *args, **kwargs), + torch.nn.ReLU(), + ]) + in_features = in_latent + selector_layers.extend([ + torch.nn.Linear(in_features, self._selector_out_dim, *args, **kwargs), + torch.nn.Unflatten(-1, self._selector_out_shape), + ]) + self.selector = torch.nn.Sequential(*selector_layers) + + def forward( + self, + latent: torch.Tensor, + ) -> torch.Tensor: + """ + Compute concept-wise mixing probabilities from latent input. + + Applies the selector MLP and normalizes logits with softmax over + the exogenous axis. + + Args: + latent: Input latent of shape (batch_size, in_latent). + + Returns: + torch.Tensor: Mixing probabilities of shape + (batch_size, out_concepts, out_exogenous). 
+ """ + mixing_coeff = self.selector(latent) + mixing_probs = torch.softmax(mixing_coeff, dim=-1) # [Batch x Task x Memory] + return mixing_probs + + class SelectorLatentToExogenous(BaseEncoder): """ Memory-based selector for exogenous variables with attention mechanism. diff --git a/torch_concepts/nn/modules/low/predictors/exogenous.py b/torch_concepts/nn/modules/low/predictors/exogenous.py index cfa91c1b..f4823e58 100644 --- a/torch_concepts/nn/modules/low/predictors/exogenous.py +++ b/torch_concepts/nn/modules/low/predictors/exogenous.py @@ -140,3 +140,134 @@ def forward( groups=list(self.cardinalities_expanded), ) return self.predictor(c_mix.flatten(start_dim=1)) + + +class MixMemoryConceptExogenousToConcept(BasePredictor): + """ + Memory-based concept-to-task predictor used in Concept-based Memory Reasoner. + + This predictor combines concept probabilities with rule selection probabilities + and a learned task-specific rule memory. Each task has an embedding that is + decoded into rule parameters with 3 coefficients per concept-rule pair. + Input concepts are treated as Bernoulli random variables. + + Main reference: "Interpretable Concept-Based Memory Reasoning" + (Debot et al., NeurIPS 2024). + + Attributes: + in_concepts (int): Number of input concepts. + in_exogenous (int): Number of rules per output concept. + out_concepts (int): Number of output concepts. + eps (float): Small scaling factor to avoid floating-point problems with the softmax. + memory_decoder_hidden_layers (int): Number of hidden layers in + ``memory_decoder``. + memory_network_shape (tuple[int, int, int]): Decoded memory tensor shape + ``(in_exogenous, in_concepts, 3)``. + memory (nn.Embedding): Learnable task memory embeddings. + memory_decoder (nn.Sequential): Decoder from memory embeddings to rule parameters. + + Args: + in_concepts: Number of input concepts. + in_exogenous: Number of rules per output concept. + out_concepts: Number of output tasks. + memory_latent_size: Size of the learned task memory embedding. + memory_decoder_hidden_layers: Number of hidden layers in the memory + decoder MLP. Must be >= 0. + eps: Scaling factor applied after memory softmax to avoid floating-point error problems. + + Input Shapes: + - ``concepts``: ``(batch_size, in_concepts)`` + - ``exogenous``: ``(batch_size, out_concepts, in_exogenous)`` + + Output Shape: + - ``(batch_size, out_concepts)`` task probabilities. + + References: + Debot et al. "Interpretable Concept-Based Memory Reasoning", NeurIPS 2024. 
https://arxiv.org/abs/2407.15527 + """ + def __init__( + self, + in_concepts: int, # concepts + in_exogenous: int, # rules + out_concepts: int, # tasks + memory_latent_size: int = 100, # size of the learned rule memory latent space + memory_decoder_hidden_layers: int = 1, + eps: float = 0.001 + ): + super().__init__( + in_concepts=in_concepts, + in_exogenous=in_exogenous, + out_concepts=out_concepts, + ) + + if memory_decoder_hidden_layers < 0: + raise ValueError("memory_decoder_hidden_layers must be >= 0") + + self.eps = eps + self.memory_decoder_hidden_layers = memory_decoder_hidden_layers + + self.memory_network_shape = (in_exogenous, in_concepts, 3) # nb_rules, nb_concepts, 3 (for 3 memory slots per concept) + self.memory_network_latent = self.memory_network_shape[0] * self.memory_network_shape[1] * self.memory_network_shape[2] + + self.memory = torch.nn.Embedding(out_concepts, memory_latent_size) + + decoder_layers = [torch.nn.Linear(memory_latent_size, self.memory_network_latent)] + for _ in range(memory_decoder_hidden_layers): + decoder_layers.extend([ + torch.nn.LeakyReLU(), + torch.nn.Linear(self.memory_network_latent, self.memory_network_latent), + ]) + decoder_layers.append(torch.nn.Unflatten(-1, self.memory_network_shape)) + self.memory_decoder = torch.nn.Sequential(*decoder_layers) + + def forward( + self, + concepts: torch.Tensor, + exogenous: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + """ + Forward pass through the predictor. + + Args: + concepts: Concept logits of shape (batch_size, in_concepts) representing Bernoulli random variables. + exogenous: Concept exogenous of shape (batch_size, out_concepts, exogenous_dim) representing rule selection probabilities. + **kwargs: Optional controls: + - include_rec (bool): If True, include a rule reconstruction-quality term. + - rec_weight (float): Exponent applied to the reconstruction-quality term. + - hard_roles (bool): If True, discretize decoded memory roles with argmax. + + Returns: + torch.Tensor: Task probabilities of shape (batch_size, out_concepts). + + Note: + Concept probabilities are detached from the graph in this method, + so gradients do not flow from the task loss back into ``concepts``, avoiding task leakage. 
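+        # Role semantics (following the CMR paper): slot 0 of the decoded
+        # memory means the concept must be true in the rule, slot 1 means it
+        # must be false, and slot 2 means it is irrelevant. The product over
+        # concepts below evaluates each rule; `exogenous` then mixes the
+        # per-rule predictions.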
+        include_rec = kwargs.get("include_rec", False)
+        rec_weight = kwargs.get("rec_weight", 1.0)
+        hard_roles = kwargs.get("hard_roles", False)
+
+        c_probs = torch.sigmoid(concepts).detach()
+        # (batch_size, out_concepts, in_exogenous, in_concepts)
+        c_probs_expanded = c_probs.unsqueeze(1).unsqueeze(1).expand(
+            -1, exogenous.shape[1], exogenous.shape[2], -1
+        )
+
+        # Decode the memory:
+        # (out_concepts, in_exogenous, in_concepts, 3) = (nb_tasks, nb_rules, nb_concepts, 3)
+        memory_decoded = self.memory_decoder(self.memory.weight)
+        memory_decoded = (1 - self.eps) * torch.softmax(memory_decoded, dim=-1)
+        # (batch_size, out_concepts, in_exogenous, in_concepts, 3)
+        memory_decoded_expanded = memory_decoded.unsqueeze(0).expand(
+            concepts.shape[0], -1, -1, -1, -1
+        )
+
+        if hard_roles:
+            # (batch_size, out_concepts, in_exogenous, in_concepts)
+            role_indices = torch.argmax(memory_decoded_expanded, dim=-1)
+            # (batch_size, out_concepts, in_exogenous, in_concepts, 3)
+            memory_decoded_expanded = torch.nn.functional.one_hot(
+                role_indices, num_classes=3
+            ).float()
+
+        # Use CMR's inference equations to compute each task from the predicted
+        # concepts, the decoded memory (rules), and the exogenous
+        # (rule-selection probabilities).
+        # (batch_size, out_concepts, in_exogenous)
+        y_per_rule = (
+            c_probs_expanded * memory_decoded_expanded[..., 0]
+            + (1 - c_probs_expanded) * memory_decoded_expanded[..., 1]
+            + memory_decoded_expanded[..., 2]
+        ).prod(dim=3)
+        if include_rec:
+            # (batch_size, out_concepts, in_exogenous)
+            y_rec_per_rule = (
+                c_probs_expanded * memory_decoded_expanded[..., 0]
+                + (1 - c_probs_expanded) * memory_decoded_expanded[..., 1]
+                + 0.5 * memory_decoded_expanded[..., 2]
+            ).prod(dim=3)
+            y_rec_per_rule = torch.pow(y_rec_per_rule + 1e-6, rec_weight)
+        else:
+            y_rec_per_rule = torch.ones_like(y_per_rule)  # dummy
+
+        # (batch_size, out_concepts)
+        y_pred = (y_per_rule * y_rec_per_rule * exogenous).sum(dim=2)
+
+        return y_pred