diff --git a/src/prototorch/models/callbacks.py b/src/prototorch/models/callbacks.py index b58ccc1..e0d8fae 100644 --- a/src/prototorch/models/callbacks.py +++ b/src/prototorch/models/callbacks.py @@ -10,7 +10,8 @@ from .extras import ConnectionTopology if TYPE_CHECKING: - from prototorch.models import GLVQ, GrowingNeuralGas + from prototorch.models.glvq import GLVQ + from prototorch.models.knn import GrowingNeuralGas class PruneLoserPrototypes(pl.Callback): @@ -61,7 +62,7 @@ def on_train_epoch_end(self, trainer, pl_module: "GLVQ"): return_counts=True) distribution = dict(zip(labels.tolist(), counts.tolist())) - logging.info(f"Re-adding pruned prototypes...") + logging.info("Re-adding pruned prototypes...") logging.debug(f"distribution={distribution}") pl_module.add_prototypes( diff --git a/src/prototorch/models/glvq.py b/src/prototorch/models/glvq.py index 4328b10..941f442 100644 --- a/src/prototorch/models/glvq.py +++ b/src/prototorch/models/glvq.py @@ -1,13 +1,15 @@ """Models based on the GLVQ framework.""" import torch +from numpy.typing import NDArray from prototorch.core.competitions import wtac from prototorch.core.distances import ( + ML_omega_distance, lomega_distance, omega_distance, squared_euclidean_distance, ) -from prototorch.core.initializers import EyeLinearTransformInitializer +from prototorch.core.initializers import LLTI, EyeLinearTransformInitializer from prototorch.core.losses import ( GLVQLoss, lvq1_loss, @@ -15,7 +17,7 @@ ) from prototorch.core.transforms import LinearTransform from prototorch.nn.wrappers import LambdaLayer, LossLayer -from torch.nn.parameter import Parameter +from torch.nn import Parameter, ParameterList from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel from .extras import ltangent_distance, orthogonalization @@ -175,6 +177,7 @@ class GRLVQ(SiameseGLVQ): TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise. """ + _relevances: torch.Tensor def __init__(self, hparams, **kwargs): @@ -229,6 +232,51 @@ def lambda_matrix(self): return lam.detach().cpu() +class GMLMLVQ(GLVQ): + """Generalized Multi-Layer Matrix Learning Vector Quantization. + Masks are applied to the omega layers to achieve sparsity and constrain + learning to certain items of each omega. + + Implemented as a regular GLVQ network that simply uses a different distance + function. This makes it easier to implement a localized variant. + """ + + # Parameters + _omegas: list[torch.Tensor] + masks: list[torch.Tensor] + + def __init__(self, hparams, **kwargs): + distance_fn = kwargs.pop("distance_fn", ML_omega_distance) + super().__init__(hparams, distance_fn=distance_fn, **kwargs) + + # Additional parameters + self._masks = ParameterList([ + Parameter(mask, requires_grad=False) + for mask in kwargs.get("masks") + ]) + self._omegas = ParameterList( + [LLTI(mask).generate(1, 1) for mask in self._masks]) + + @property + def omega_matrices(self): + return [_omega.detach().cpu() for _omega in self._omegas] + + @property + def lambda_matrix(self): + # TODO update to respective lambda calculation rules. + omega = self._omega.detach() # (input_dim, latent_dim) + lam = omega @ omega.T + return lam.detach().cpu() + + def compute_distances(self, x): + protos, _ = self.proto_layer() + distances = self.distance_layer(x, protos, self._omegas, self._masks) + return distances + + def extra_repr(self): + return f"(omegas): (shapes: {[tuple(_omega.shape) for _omega in self._omegas]})" + + class GMLVQ(GLVQ): """Generalized Matrix Learning Vector Quantization.