Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions doc/modules/high_level_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,25 @@ A CBM learns interpretable concept representations and uses them to predict task
}
)

**Concept Memory Reasoner (CMR)**

A neurosymbolic concept-based model with task-specific rule selection and reasoning:

.. code-block:: python

from torch_concepts.nn import ConceptMemoryReasoner

model = ConceptMemoryReasoner(
input_size=2048,
annotations=annotations,
task_names=['class_A', 'class_B'],
n_rules=10,
memory_latent_size=100,
memory_decoder_hidden_layers=1,
selector_hidden_layers=1,
eps=1e-3,
)

**BlackBox Model**

A standard neural network for comparison baselines:
Expand Down
6 changes: 6 additions & 0 deletions doc/modules/nn.encoders.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Summary
LinearExogenousToConcept
StochasticLatentToConcept
LinearLatentToExogenous
CategoricalSelectorLatentToExogenous
SelectorLatentToExogenous


Expand All @@ -44,6 +45,11 @@ Class Documentation
:undoc-members:
:show-inheritance:

.. autoclass:: CategoricalSelectorLatentToExogenous
:members:
:undoc-members:
:show-inheritance:

.. autoclass:: SelectorLatentToExogenous
:members:
:undoc-members:
Expand Down
6 changes: 6 additions & 0 deletions doc/modules/nn.loss.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Summary

ConceptLoss
WeightedConceptLoss
CMRLoss

**Low-Level Losses**

Expand All @@ -41,6 +42,11 @@ Class Documentation
:undoc-members:
:show-inheritance:

.. autoclass:: CMRLoss
:members:
:undoc-members:
:show-inheritance:

.. autoclass:: WeightedBCEWithLogitsLoss
:members:
:undoc-members:
Expand Down
6 changes: 6 additions & 0 deletions doc/modules/nn.models.high.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Summary

ConceptBottleneckModel
ConceptEmbeddingModel
ConceptMemoryReasoner
BlackBox
BlackBoxTaskOnly

Expand All @@ -33,6 +34,11 @@ Class Documentation
:undoc-members:
:show-inheritance:

.. autoclass:: ConceptMemoryReasoner
:members:
:undoc-members:
:show-inheritance:

.. autoclass:: BlackBox
:members:
:undoc-members:
Expand Down
6 changes: 6 additions & 0 deletions doc/modules/nn.predictors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Summary

LinearConceptToConcept
MixConceptExogegnousToConcept
MixMemoryConceptExogenousToConcept
HyperlinearConceptExogenousToConcept
CallableConceptToConcept

Expand All @@ -33,6 +34,11 @@ Class Documentation
:undoc-members:
:show-inheritance:

.. autoclass:: MixMemoryConceptExogenousToConcept
:members:
:undoc-members:
:show-inheritance:

.. autoclass:: HyperlinearConceptExogenousToConcept
:members:
:undoc-members:
Expand Down
105 changes: 105 additions & 0 deletions examples/utilization/0_layer/7_concept_based_memory_reasoner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""
Example: Concept Memory Reasoner with Low-Level API

This example demonstrates how to build a Concept Memory Reasoner (CMR)
using the low-level encoder and predictor layers.
"""
import torch
from sklearn.metrics import accuracy_score
from torch.nn import ModuleDict

from torch_concepts import seed_everything
from torch_concepts.data.datasets import ToyDataset
from torch_concepts.nn import (
CategoricalSelectorLatentToExogenous,
LinearLatentToConcept,
MixMemoryConceptExogenousToConcept,
)


def main():
    """Train a low-level Concept Memory Reasoner pipeline on the XOR toy set."""
    # Hyperparameters.
    emb_size = 10
    epochs = 500
    sample_count = 1000
    rule_count = 10
    memory_latent_size = 100
    rec_weight = 0.1

    seed_everything(42)

    # Load the toy XOR dataset and split its annotation matrix into
    # concept columns (graph edge sources) and task columns (edge targets).
    dataset = ToyDataset(dataset='xor', seed=42, n_gen=sample_count)
    inputs = dataset.input_data
    concept_cols = list(dataset.graph.edge_index[0].unique().numpy())
    task_cols = list(dataset.graph.edge_index[1].unique().numpy())
    concept_labels = dataset.concepts[:, concept_cols]
    task_labels = dataset.concepts[:, task_cols]

    # Derive layer sizes from the data.
    n_features = inputs.shape[1]
    n_concepts = concept_labels.shape[1]
    n_tasks = task_labels.shape[1]

    # Assemble the CMR from low-level layers: a shared latent encoder, a
    # categorical rule selector, a concept encoder, and a memory-based
    # task predictor. NOTE: construction order matches the original
    # example so that random weight initialisation is reproducible.
    latent_encoder = torch.nn.Sequential(
        torch.nn.Linear(n_features, emb_size),
        torch.nn.LeakyReLU(),
    )
    rule_selector = CategoricalSelectorLatentToExogenous(
        in_latent=emb_size,
        out_concepts=n_tasks,
        out_exogenous=rule_count,
    )
    concept_head = LinearLatentToConcept(in_latent=emb_size, out_concepts=n_concepts)
    task_head = MixMemoryConceptExogenousToConcept(
        in_concepts=n_concepts,
        in_exogenous=rule_count,
        out_concepts=n_tasks,
        memory_latent_size=memory_latent_size,
        eps=0.001,
    )
    model = ModuleDict({
        "latent_encoder": latent_encoder,
        "exog_encoder": rule_selector,
        "concept_encoder": concept_head,
        "task_predictor": task_head,
    })

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
    # Per-element task loss so the reconstruction term can be mixed in
    # per sample below; concept loss stays a scalar mean.
    task_criterion = torch.nn.BCELoss(reduction='none')
    concept_criterion = torch.nn.BCEWithLogitsLoss()
    model.train()

    for epoch in range(epochs):
        optimizer.zero_grad()

        # Forward pass: latent embedding -> rule selection, concept
        # logits, and task predictions with and without the
        # reconstruction term.
        emb = latent_encoder(inputs)
        exog = rule_selector(latent=emb)
        concept_logits = concept_head(latent=emb)
        task_probs = task_head(concepts=concept_logits, exogenous=exog)
        task_probs_rec = task_head(
            concepts=concept_logits,
            exogenous=exog,
            include_rec=True,
            rec_weight=rec_weight,
        )

        # Apply the reconstruction-regularised loss only on positive task
        # samples; negatives use the plain prediction loss.
        concept_loss = concept_criterion(concept_logits, concept_labels)
        plain_task_loss = task_criterion(task_probs, task_labels)
        rec_task_loss = task_criterion(task_probs_rec, task_labels)
        task_loss = (task_labels * rec_task_loss + (1 - task_labels) * plain_task_loss).mean()
        loss = concept_loss + task_loss

        loss.backward()
        optimizer.step()

        if epoch % 100 == 0:
            # Task head emits probabilities (threshold 0.5); concept head
            # emits logits (threshold 0).
            task_accuracy = accuracy_score(task_labels, task_probs.detach() > 0.5)
            concept_accuracy = accuracy_score(concept_labels, concept_logits.detach() > 0.)
            print(f"Epoch {epoch}: Loss {loss.item():.2f} | Task Acc: {task_accuracy:.2f} | Concept Acc: {concept_accuracy:.2f}")


if __name__ == "__main__":
    main()
45 changes: 44 additions & 1 deletion examples/utilization/2.2_model/10_different_training_modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@

import torch
from torch_concepts import seed_everything
from torch_concepts.nn import ConceptBottleneckModel, ConceptEmbeddingModel
from torch_concepts.nn import (
CMRLoss,
ConceptBottleneckModel,
ConceptEmbeddingModel,
ConceptMemoryReasoner,
)
from torch_concepts.nn.modules.mid.inference import (
DeterministicInference,
IndependentInference
Expand Down Expand Up @@ -202,6 +207,44 @@ def main():
trainer_cem.fit(model_cem, datamodule=datamodule)
evaluate(model_cem, datamodule, n_concepts, query)

# =========================================================================
# CMR WITH JOINT TRAINING
# =========================================================================
print("\n" + "=" * 60)
print("Example 4: CMR with Joint Training")
print("=" * 60)
print("Uses DeterministicInference for both training and evaluation")

cmr_loss = CMRLoss()
optim_kwargs_cmr = {'lr': 0.01}

model_cmr = ConceptMemoryReasoner(
input_size=n_features,
annotations=annotations,
variable_distributions=variable_distributions,
task_names=['xor'],
n_rules=10,
memory_latent_size=100,
memory_decoder_hidden_layers=1,
selector_hidden_layers=1,
rec_weight=0.1,
eps=1e-3,
latent_encoder_kwargs={'hidden_size': 16, 'n_layers': 1},
inference=DeterministicInference,
train_inference=DeterministicInference,
lightning=True,
loss=cmr_loss,
optim_class=optim,
optim_kwargs=optim_kwargs_cmr,
)
print(f"Model type: {type(model_cmr).__name__}")
print(f"Eval inference: {model_cmr.eval_inference.__class__.__name__}")
print(f"Training inference: {model_cmr.train_inference.__class__.__name__}")

trainer_cmr = Trainer(max_epochs=100)
trainer_cmr.fit(model_cmr, datamodule=datamodule)
evaluate(model_cmr, datamodule, n_concepts, query)


# Run the training-modes demonstration when executed as a script.
if __name__ == "__main__":
    main()
Loading