From 824dfced92e85826babfba409a29920fd738462f Mon Sep 17 00:00:00 2001
From: julius
Date: Fri, 3 Nov 2023 14:59:00 +0100
Subject: [PATCH 1/6] Implement a prototypical 2-layer version of GMLVQ

---
 src/prototorch/models/glvq.py | 48 ++++++++++++++++++++++++++++++++++-
 1 file changed, 47 insertions(+), 1 deletion(-)

diff --git a/src/prototorch/models/glvq.py b/src/prototorch/models/glvq.py
index 4328b10..1ba9f06 100644
--- a/src/prototorch/models/glvq.py
+++ b/src/prototorch/models/glvq.py
@@ -5,9 +5,10 @@ from prototorch.core.distances import (
     lomega_distance,
     omega_distance,
+    ML_omega_distance,
     squared_euclidean_distance,
 )
-from prototorch.core.initializers import EyeLinearTransformInitializer
+from prototorch.core.initializers import (EyeLinearTransformInitializer, LLTI)
 from prototorch.core.losses import (
     GLVQLoss,
     lvq1_loss,
@@ -229,6 +230,51 @@ def lambda_matrix(self):
         return lam.detach().cpu()
 
 
+class GMLMLVQ(GLVQ):
+    """Generalized Multi-Layer Matrix Learning Vector Quantization.
+
+    Implemented as a regular GLVQ network that simply uses a different distance
+    function. This makes it easier to implement a localized variant.
+    """
+
+    # Parameters
+    _omega: torch.Tensor
+
+    def __init__(self, hparams, **kwargs):
+        distance_fn = kwargs.pop("distance_fn", ML_omega_distance)
+        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
+
+        # Additional parameters
+        omega_initializer = kwargs.get("omega_initializer")
+        masks = kwargs.get("masks")
+        omega_0 = LLTI(masks[0]).generate(1, 1)
+        omega_1 = LLTI(masks[1]).generate(1, 1)
+        self.register_parameter("_omega_0", Parameter(omega_0))
+        self.register_parameter("_omega_1", Parameter(omega_1))
+        self.mask_0 = masks[0]
+        self.mask_1 = masks[1]
+
+    @property
+    def omega_matrix(self):
+        return self._omega.detach().cpu()
+
+    @property
+    def lambda_matrix(self):
+        omega = self._omega.detach()  # (input_dim, latent_dim)
+        lam = omega @ omega.T
+        return lam.detach().cpu()
+
+    def compute_distances(self, x):
+        protos, _ = self.proto_layer()
+        distances = self.distance_layer(x, protos, self._omega_0,
+                                        self._omega_1, self.mask_0,
+                                        self.mask_1)
+        return distances
+
+    def extra_repr(self):
+        return f"(omega): (shape: {tuple(self._omega.shape)})"
+
+
 class GMLVQ(GLVQ):
     """Generalized Matrix Learning Vector Quantization.
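A note on `ML_omega_distance`: it is imported from prototorch.core.distances, but its implementation is not part of this series. The call in `compute_distances` above, `distance_layer(x, protos, omega_0, omega_1, mask_0, mask_1)`, is consistent with chaining two masked omega projections before taking squared Euclidean distances. The following is only a minimal sketch of that idea, assuming elementwise masking and shapes that compose, e.g. omega_0/mask_0 of shape (input_dim, hidden_dim) and omega_1/mask_1 of shape (hidden_dim, latent_dim); it is not the prototorch implementation, and the function name is hypothetical.

    import torch

    def two_layer_masked_omega_distance(x, protos, omega_0, omega_1, mask_0, mask_1):
        # Sketch only: zero out the masked-off entries of each omega, then
        # project the data batch and the prototypes through both layers.
        x_proj = x @ (omega_0 * mask_0) @ (omega_1 * mask_1)
        p_proj = protos @ (omega_0 * mask_0) @ (omega_1 * mask_1)
        # Pairwise squared Euclidean distances, shape (batch_size, num_prototypes).
        diff = x_proj.unsqueeze(1) - p_proj.unsqueeze(0)
        return torch.sum(diff**2, dim=-1)

Whether the masks are re-applied at every distance evaluation or only once at initialization (via LLTI) cannot be read off this patch; the sketch assumes the former.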
From 1786031b4e339d4a3f2da9d509c1b471d4534cd8 Mon Sep 17 00:00:00 2001
From: julius
Date: Mon, 6 Nov 2023 16:32:57 +0100
Subject: [PATCH 2/6] adjust omega_matrix property

---
 src/prototorch/models/glvq.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/prototorch/models/glvq.py b/src/prototorch/models/glvq.py
index 1ba9f06..4de706b 100644
--- a/src/prototorch/models/glvq.py
+++ b/src/prototorch/models/glvq.py
@@ -238,7 +238,8 @@ class GMLMLVQ(GLVQ):
     """
 
     # Parameters
-    _omega: torch.Tensor
+    _omega_0: torch.Tensor
+    _omega_1: torch.Tensor
 
     def __init__(self, hparams, **kwargs):
         distance_fn = kwargs.pop("distance_fn", ML_omega_distance)
         super().__init__(hparams, distance_fn=distance_fn, **kwargs)
@@ -255,8 +256,8 @@ def __init__(self, hparams, **kwargs):
         self.mask_1 = masks[1]
 
     @property
-    def omega_matrix(self):
-        return self._omega.detach().cpu()
+    def omega_matrices(self):
+        return [self._omega_0.detach().cpu(), self._omega_1.detach().cpu()]
 
     @property
     def lambda_matrix(self):

From c6f718a1d4c010734bef187d255df189a7616c4f Mon Sep 17 00:00:00 2001
From: julius
Date: Tue, 7 Nov 2023 16:44:13 +0100
Subject: [PATCH 3/6] GMLMLVQ: allow for 2 or more omega layers

---
 src/prototorch/models/glvq.py | 88 ++++++++++++++++++-----------------
 1 file changed, 46 insertions(+), 42 deletions(-)

diff --git a/src/prototorch/models/glvq.py b/src/prototorch/models/glvq.py
index 4de706b..7709e93 100644
--- a/src/prototorch/models/glvq.py
+++ b/src/prototorch/models/glvq.py
@@ -1,14 +1,17 @@
 """Models based on the GLVQ framework."""
 
+from typing import LiteralString
+
 import torch
+from numpy.typing import NDArray
 from prototorch.core.competitions import wtac
 from prototorch.core.distances import (
+    ML_omega_distance,
     lomega_distance,
     omega_distance,
-    ML_omega_distance,
     squared_euclidean_distance,
 )
-from prototorch.core.initializers import (EyeLinearTransformInitializer, LLTI)
+from prototorch.core.initializers import LLTI, EyeLinearTransformInitializer
 from prototorch.core.losses import (
     GLVQLoss,
     lvq1_loss,
@@ -16,7 +19,7 @@
 )
 from prototorch.core.transforms import LinearTransform
 from prototorch.nn.wrappers import LambdaLayer, LossLayer
-from torch.nn.parameter import Parameter
+from torch.nn import Parameter, ParameterList
 
 from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
 from .extras import ltangent_distance, orthogonalization
@@ -46,26 +49,28 @@ def __init__(self, hparams, **kwargs):
 
     def initialize_prototype_win_ratios(self):
         self.register_buffer(
-            "prototype_win_ratios",
-            torch.zeros(self.num_prototypes, device=self.device))
+            "prototype_win_ratios", torch.zeros(self.num_prototypes, device=self.device)
+        )
 
     def on_train_epoch_start(self):
         self.initialize_prototype_win_ratios()
 
     def log_prototype_win_ratios(self, distances):
         batch_size = len(distances)
-        prototype_wc = torch.zeros(self.num_prototypes,
-                                   dtype=torch.long,
-                                   device=self.device)
-        wi, wc = torch.unique(distances.min(dim=-1).indices,
-                              sorted=True,
-                              return_counts=True)
+        prototype_wc = torch.zeros(
+            self.num_prototypes, dtype=torch.long, device=self.device
+        )
+        wi, wc = torch.unique(
+            distances.min(dim=-1).indices, sorted=True, return_counts=True
+        )
         prototype_wc[wi] = wc
         prototype_wr = prototype_wc / batch_size
-        self.prototype_win_ratios = torch.vstack([
-            self.prototype_win_ratios,
-            prototype_wr,
-        ])
+        self.prototype_win_ratios = torch.vstack(
+            [
+                self.prototype_win_ratios,
+                prototype_wr,
+            ]
+        )
 
     def shared_step(self, batch, batch_idx):
         x, y = batch
@@ -110,11 +115,9 @@ class SiameseGLVQ(GLVQ):
 
     """
 
-    def __init__(self,
-                 hparams,
-                 backbone=torch.nn.Identity(),
-                 both_path_gradients=False,
-                 **kwargs):
+    def __init__(
+        self, hparams, backbone=torch.nn.Identity(), both_path_gradients=False, **kwargs
+    ):
         distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
         super().__init__(hparams, distance_fn=distance_fn, **kwargs)
         self.backbone = backbone
@@ -176,6 +179,7 @@ class GRLVQ(SiameseGLVQ):
 
     TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
     """
+
     _relevances: torch.Tensor
 
     def __init__(self, hparams, **kwargs):
@@ -186,8 +190,7 @@ def __init__(self, hparams, **kwargs):
         self.register_parameter("_relevances", Parameter(relevances))
 
         # Override the backbone
-        self.backbone = LambdaLayer(self._apply_relevances,
-                                    name="relevance scaling")
+        self.backbone = LambdaLayer(self._apply_relevances, name="relevance scaling")
 
     def _apply_relevances(self, x):
         return x @ torch.diag(self._relevances)
@@ -211,8 +214,9 @@ def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
 
         # Override the backbone
-        omega_initializer = kwargs.get("omega_initializer",
-                                       EyeLinearTransformInitializer())
+        omega_initializer = kwargs.get(
+            "omega_initializer", EyeLinearTransformInitializer()
+        )
         self.backbone = LinearTransform(
             self.hparams["input_dim"],
             self.hparams["latent_dim"],
@@ -232,48 +236,46 @@ def lambda_matrix(self):
 class GMLMLVQ(GLVQ):
     """Generalized Multi-Layer Matrix Learning Vector Quantization.
 
+    Masks are applied to the omega layers to achieve sparsity and constrain
+    learning to certain items of each omega.
     Implemented as a regular GLVQ network that simply uses a different distance
     function. This makes it easier to implement a localized variant.
     """
 
     # Parameters
-    _omega_0: torch.Tensor
-    _omega_1: torch.Tensor
+    _omegas: list[torch.Tensor]
+    masks: list[torch.Tensor]
 
     def __init__(self, hparams, **kwargs):
         distance_fn = kwargs.pop("distance_fn", ML_omega_distance)
         super().__init__(hparams, distance_fn=distance_fn, **kwargs)
 
         # Additional parameters
-        omega_initializer = kwargs.get("omega_initializer")
         masks = kwargs.get("masks")
-        omega_0 = LLTI(masks[0]).generate(1, 1)
-        omega_1 = LLTI(masks[1]).generate(1, 1)
-        self.register_parameter("_omega_0", Parameter(omega_0))
-        self.register_parameter("_omega_1", Parameter(omega_1))
-        self.mask_0 = masks[0]
-        self.mask_1 = masks[1]
+        for i, _mask in enumerate(masks):
+            self.register_buffer(f"_mask_{i}", _mask)
+        self._masks = [self.__getattr__(f"_mask_{i}") for i,_ in enumerate(masks)]
+        self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in masks])
 
     @property
     def omega_matrices(self):
-        return [self._omega_0.detach().cpu(), self._omega_1.detach().cpu()]
+        return [_omega.detach().cpu() for _omega in self._omegas]
 
     @property
     def lambda_matrix(self):
+        # TODO update to respective lambda calculation rules.
         omega = self._omega.detach()  # (input_dim, latent_dim)
         lam = omega @ omega.T
         return lam.detach().cpu()
 
     def compute_distances(self, x):
         protos, _ = self.proto_layer()
-        distances = self.distance_layer(x, protos, self._omega_0,
-                                        self._omega_1, self.mask_0,
-                                        self.mask_1)
+        distances = self.distance_layer(x, protos, self._omegas, self._masks)
         return distances
 
     def extra_repr(self):
-        return f"(omega): (shape: {tuple(self._omega.shape)})"
+        return f"(omegas): (shapes: {[tuple(_omega.shape) for _omega in self._omegas]})"
 
 
 class GMLVQ(GLVQ):
@@ -292,10 +294,12 @@ def __init__(self, hparams, **kwargs):
         super().__init__(hparams, distance_fn=distance_fn, **kwargs)
 
         # Additional parameters
-        omega_initializer = kwargs.get("omega_initializer",
-                                       EyeLinearTransformInitializer())
-        omega = omega_initializer.generate(self.hparams["input_dim"],
-                                           self.hparams["latent_dim"])
+        omega_initializer = kwargs.get(
+            "omega_initializer", EyeLinearTransformInitializer()
+        )
+        omega = omega_initializer.generate(
+            self.hparams["input_dim"], self.hparams["latent_dim"]
+        )
         self.register_parameter("_omega", Parameter(omega))
 
     @property

From 78f8b6cc00e8140bcc83439cdff8fefd694fbf57 Mon Sep 17 00:00:00 2001
From: julius
Date: Tue, 7 Nov 2023 18:52:51 +0100
Subject: [PATCH 4/6] remove accidental LiteralString import

---
 src/prototorch/models/glvq.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/prototorch/models/glvq.py b/src/prototorch/models/glvq.py
index 7709e93..cc187e7 100644
--- a/src/prototorch/models/glvq.py
+++ b/src/prototorch/models/glvq.py
@@ -1,7 +1,5 @@
 """Models based on the GLVQ framework."""
 
-from typing import LiteralString
-
 import torch
 from numpy.typing import NDArray
 from prototorch.core.competitions import wtac
@@ -255,7 +253,7 @@ def __init__(self, hparams, **kwargs):
         masks = kwargs.get("masks")
         for i, _mask in enumerate(masks):
             self.register_buffer(f"_mask_{i}", _mask)
-        self._masks = [self.__getattr__(f"_mask_{i}") for i,_ in enumerate(masks)]
+        self._masks = [self.__getattr__(f"_mask_{i}") for i, _ in enumerate(masks)]
         self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in masks])
 
     @property

From adafb4998593ced0308e5ae093496b20a12fef74 Mon Sep 17 00:00:00 2001
From: julius
Date: Tue, 7 Nov 2023 19:17:43 +0100
Subject: [PATCH 5/6] masks -> ParameterList(requires_grad=False)

---
 src/prototorch/models/glvq.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/prototorch/models/glvq.py b/src/prototorch/models/glvq.py
index cc187e7..341b4b1 100644
--- a/src/prototorch/models/glvq.py
+++ b/src/prototorch/models/glvq.py
@@ -250,11 +250,10 @@ def __init__(self, hparams, **kwargs):
         super().__init__(hparams, distance_fn=distance_fn, **kwargs)
 
         # Additional parameters
-        masks = kwargs.get("masks")
-        for i, _mask in enumerate(masks):
-            self.register_buffer(f"_mask_{i}", _mask)
-        self._masks = [self.__getattr__(f"_mask_{i}") for i, _ in enumerate(masks)]
-        self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in masks])
+        self._masks = ParameterList(
+            [Parameter(mask, requires_grad=False) for mask in kwargs.get("masks")]
+        )
+        self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in self._masks])
 
     @property

From 1fe2951ebf398b35661bdff13717632e9254ac7f Mon Sep 17 00:00:00 2001
From: julius
Date: Tue, 30 Apr 2024 09:56:39 +0200
Subject: [PATCH 6/6] style: ran pre-commit

---
 src/prototorch/models/callbacks.py |  5 ++-
 src/prototorch/models/glvq.py      | 62 +++++++++++++++---------------
 2 files changed, 34 insertions(+), 33 deletions(-)

diff --git a/src/prototorch/models/callbacks.py b/src/prototorch/models/callbacks.py
index b58ccc1..e0d8fae 100644
--- a/src/prototorch/models/callbacks.py
+++ b/src/prototorch/models/callbacks.py
@@ -10,7 +10,8 @@
 from .extras import ConnectionTopology
 
 if TYPE_CHECKING:
-    from prototorch.models import GLVQ, GrowingNeuralGas
+    from prototorch.models.glvq import GLVQ
+    from prototorch.models.knn import GrowingNeuralGas
 
 
 class PruneLoserPrototypes(pl.Callback):
@@ -61,7 +62,7 @@ def on_train_epoch_end(self, trainer, pl_module: "GLVQ"):
                                      return_counts=True)
         distribution = dict(zip(labels.tolist(), counts.tolist()))
 
-        logging.info(f"Re-adding pruned prototypes...")
+        logging.info("Re-adding pruned prototypes...")
         logging.debug(f"distribution={distribution}")
 
         pl_module.add_prototypes(
diff --git a/src/prototorch/models/glvq.py b/src/prototorch/models/glvq.py
index 341b4b1..941f442 100644
--- a/src/prototorch/models/glvq.py
+++ b/src/prototorch/models/glvq.py
@@ -47,28 +47,26 @@ def __init__(self, hparams, **kwargs):
 
     def initialize_prototype_win_ratios(self):
         self.register_buffer(
-            "prototype_win_ratios", torch.zeros(self.num_prototypes, device=self.device)
-        )
+            "prototype_win_ratios",
+            torch.zeros(self.num_prototypes, device=self.device))
 
     def on_train_epoch_start(self):
         self.initialize_prototype_win_ratios()
 
     def log_prototype_win_ratios(self, distances):
         batch_size = len(distances)
-        prototype_wc = torch.zeros(
-            self.num_prototypes, dtype=torch.long, device=self.device
-        )
-        wi, wc = torch.unique(
-            distances.min(dim=-1).indices, sorted=True, return_counts=True
-        )
+        prototype_wc = torch.zeros(self.num_prototypes,
+                                   dtype=torch.long,
+                                   device=self.device)
+        wi, wc = torch.unique(distances.min(dim=-1).indices,
+                              sorted=True,
+                              return_counts=True)
         prototype_wc[wi] = wc
         prototype_wr = prototype_wc / batch_size
-        self.prototype_win_ratios = torch.vstack(
-            [
-                self.prototype_win_ratios,
-                prototype_wr,
-            ]
-        )
+        self.prototype_win_ratios = torch.vstack([
+            self.prototype_win_ratios,
+            prototype_wr,
+        ])
 
     def shared_step(self, batch, batch_idx):
         x, y = batch
@@ -113,9 +111,11 @@ class SiameseGLVQ(GLVQ):
 
     """
 
-    def __init__(
-        self, hparams, backbone=torch.nn.Identity(), both_path_gradients=False, **kwargs
-    ):
+    def __init__(self,
+                 hparams,
+                 backbone=torch.nn.Identity(),
+                 both_path_gradients=False,
+                 **kwargs):
         distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
         super().__init__(hparams, distance_fn=distance_fn, **kwargs)
         self.backbone = backbone
@@ -188,7 +188,8 @@ def __init__(self, hparams, **kwargs):
         self.register_parameter("_relevances", Parameter(relevances))
 
         # Override the backbone
-        self.backbone = LambdaLayer(self._apply_relevances, name="relevance scaling")
+        self.backbone = LambdaLayer(self._apply_relevances,
+                                    name="relevance scaling")
 
     def _apply_relevances(self, x):
         return x @ torch.diag(self._relevances)
@@ -212,9 +213,8 @@ def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
 
         # Override the backbone
-        omega_initializer = kwargs.get(
-            "omega_initializer", EyeLinearTransformInitializer()
-        )
+        omega_initializer = kwargs.get("omega_initializer",
+                                       EyeLinearTransformInitializer())
         self.backbone = LinearTransform(
             self.hparams["input_dim"],
             self.hparams["latent_dim"],
@@ -250,10 +250,12 @@ def __init__(self, hparams, **kwargs):
         super().__init__(hparams, distance_fn=distance_fn, **kwargs)
 
         # Additional parameters
-        self._masks = ParameterList(
-            [Parameter(mask, requires_grad=False) for mask in kwargs.get("masks")]
-        )
-        self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in self._masks])
+        self._masks = ParameterList([
+            Parameter(mask, requires_grad=False)
+            for mask in kwargs.get("masks")
+        ])
+        self._omegas = ParameterList(
+            [LLTI(mask).generate(1, 1) for mask in self._masks])
 
     @property
     def omega_matrices(self):
@@ -291,12 +293,10 @@ def __init__(self, hparams, **kwargs):
         super().__init__(hparams, distance_fn=distance_fn, **kwargs)
 
         # Additional parameters
-        omega_initializer = kwargs.get(
-            "omega_initializer", EyeLinearTransformInitializer()
-        )
-        omega = omega_initializer.generate(
-            self.hparams["input_dim"], self.hparams["latent_dim"]
-        )
+        omega_initializer = kwargs.get("omega_initializer",
+                                       EyeLinearTransformInitializer())
+        omega = omega_initializer.generate(self.hparams["input_dim"],
+                                           self.hparams["latent_dim"])
         self.register_parameter("_omega", Parameter(omega))
 
     @property
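After the final revision the distance call is `distance_layer(x, protos, self._omegas, self._masks)` with an arbitrary number of omega layers. Under the same assumptions as before (elementwise masking, mask shapes that chain from input_dim down to the final latent dimension, which also fixes the shapes LLTI generates), the generalized computation can be sketched as follows; this is an illustration with a hypothetical function name, not the prototorch implementation of ML_omega_distance.

    import torch

    def multilayer_masked_omega_distance(x, protos, omegas, masks):
        # Sketch only: push the data batch and the prototypes through the whole
        # stack of masked omega transforms, then compare them in the final space.
        for omega, mask in zip(omegas, masks):
            effective = omega * mask  # keep only the unmasked entries of this layer
            x = x @ effective
            protos = protos @ effective
        diff = x.unsqueeze(1) - protos.unsqueeze(0)  # (batch, num_prototypes, latent_dim)
        return torch.sum(diff**2, dim=-1)

On the open TODO in lambda_matrix: one plausible reading, stated here as an assumption rather than as part of the series, is to mirror the single-layer rule lam = omega @ omega.T using the product of the masked omegas, i.e. omega_total = (omega_0 * mask_0) @ (omega_1 * mask_1) @ ... and lam = omega_total @ omega_total.T.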