From aaae396a4800f4a782158e4ded8616c0c9fb4573 Mon Sep 17 00:00:00 2001
From: aditya02shah
Date: Sat, 27 Jan 2024 13:29:39 +0000
Subject: [PATCH 01/12] Initialised video-classification/vivit

---
 .../models/video_classification/__init__.py   | 13 +++++++++++++
 keras_cv/models/video_classification/vivit.py | 19 +++++++++++++++++++
 2 files changed, 32 insertions(+)
 create mode 100644 keras_cv/models/video_classification/__init__.py
 create mode 100644 keras_cv/models/video_classification/vivit.py

diff --git a/keras_cv/models/video_classification/__init__.py b/keras_cv/models/video_classification/__init__.py
new file mode 100644
index 0000000000..3992ffb59a
--- /dev/null
+++ b/keras_cv/models/video_classification/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2023 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/keras_cv/models/video_classification/vivit.py b/keras_cv/models/video_classification/vivit.py
new file mode 100644
index 0000000000..4671c9c121
--- /dev/null
+++ b/keras_cv/models/video_classification/vivit.py
@@ -0,0 +1,19 @@
+# Copyright 2024 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+from keras_cv.api_export import keras_cv_export
+from keras_cv.backend import keras
+from keras_cv.backend.config import keras_3

From a1da121eed183700507bf77e03b1b66e36e974dc Mon Sep 17 00:00:00 2001
From: aditya02shah
Date: Mon, 29 Jan 2024 14:55:37 +0000
Subject: [PATCH 02/12] Initialised ViViT model and added dependent layers

---
 .../models/video_classification/__init__.py   |   2 +-
 keras_cv/models/video_classification/vivit.py | 137 +++++++++++++++++-
 .../video_classification/vivit_layers.py      |  68 +++++++++
 3 files changed, 203 insertions(+), 4 deletions(-)
 create mode 100644 keras_cv/models/video_classification/vivit_layers.py

diff --git a/keras_cv/models/video_classification/__init__.py b/keras_cv/models/video_classification/__init__.py
index 3992ffb59a..0e9cbb5ac9 100644
--- a/keras_cv/models/video_classification/__init__.py
+++ b/keras_cv/models/video_classification/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2023 The KerasCV Authors
+# Copyright 2024 The KerasCV Authors
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
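The rest of this patch introduces the ViViT task plus the two helper layers it depends on, TubeletEmbedding and PositionalEncoder. The core idea of tubelet embedding is that a single strided Conv3D cuts the clip into spatio-temporal tubes and projects each one in the same step; a reshape then yields the transformer's token sequence. A rough shape walk-through (sizes are illustrative, a Keras 3 environment is assumed, and none of this code is taken verbatim from the diffs):

```python
import numpy as np
from keras import layers  # Keras 3 assumed

# One 32-frame 64x64 RGB clip: (batch, time, height, width, channels).
video = np.zeros((1, 32, 64, 64, 3), dtype="float32")

# A strided Conv3D patches and embeds at once: every (8, 8, 8) tube of
# pixels becomes a single 128-d vector.
tubes = layers.Conv3D(filters=128, kernel_size=(8, 8, 8), strides=(8, 8, 8))(video)
print(tubes.shape)  # (1, 4, 8, 8, 128): a 4x8x8 grid of tube embeddings

# Flattening the grid gives one token per tube, 4 * 8 * 8 = 256 in total.
tokens = layers.Reshape((-1, 128))(tubes)
print(tokens.shape)  # (1, 256, 128)
```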
diff --git a/keras_cv/models/video_classification/vivit.py b/keras_cv/models/video_classification/vivit.py index 4671c9c121..0a344fc021 100644 --- a/keras_cv/models/video_classification/vivit.py +++ b/keras_cv/models/video_classification/vivit.py @@ -12,8 +12,139 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy - from keras_cv.api_export import keras_cv_export from keras_cv.backend import keras -from keras_cv.backend.config import keras_3 +from keras_cv.models.task import Task + + +@keras_cv_export( + [ + "keras_cv.models.ViViT", + "keras_cv.models.video_classification.ViViT", + ] +) +class ViViT(Task): + """A Keras model implementing a Video Vision Transformer + for video classification. + References: + - [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) + (ICCV 2021) + + Args: + #Example + tubelet_embedder = + keras_cv.layers.TubeletEmbedding( + embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE + ) + positional_encoder = + keras_cv.layers.PositionalEncoder( + embed_dim=PROJECTION_DIM + ) + model = keras_cv.models.video_classification.ViViT( + tubelet_embedder, + positional_encoder + ) + + """ + + def __init__( + self, + tubelet_embedder, + positional_encoder, + input_shape, + transformer_layers, + num_heads, + embed_dim, + layer_norm_eps, + num_classes, + **kwargs, + ): + if not isinstance(tubelet_embedder, keras.layers.Layer): + raise ValueError( + "Argument `tubelet_embedder` must be a " + " `keras.layers.Layer` instance " + f" . Received instead " + f"tubelet_embedder={tubelet_embedder} " + f"(of type {type(tubelet_embedder)})." + ) + + if not isinstance(positional_encoder, keras.layers.Layer): + raise ValueError( + "Argument `positional_encoder` must be a " + "`keras.layers.Layer` instance " + f" . Received instead " + f"positional_encoder={positional_encoder} " + f"(of type {type(positional_encoder)})." 
+ ) + + inputs = keras.layers.Input(shape=input_shape) + patches = tubelet_embedder(inputs) + encoded_patches = positional_encoder(patches) + + for _ in range(transformer_layers): + x1 = keras.layers.LayerNormalization(epsilon=1e-6)(encoded_patches) + attention_output = keras.layers.MultiHeadAttention( + num_heads=num_heads, + key_dim=embed_dim // num_heads, + dropout=0.1, + )(x1, x1) + + x2 = keras.layers.Add()([attention_output, encoded_patches]) + + x3 = keras.layers.LayerNormalization(epsilon=1e-6)(x2) + x3 = keras.Sequential( + [ + keras.layers.Dense( + units=embed_dim * 4, activation=keras.ops.gelu + ), + keras.layers.Dense( + units=embed_dim, activation=keras.ops.gelu + ), + ] + )(x3) + + encoded_patches = keras.layers.Add()([x3, x2]) + + representation = keras.layers.LayerNormalization( + epsilon=layer_norm_eps + )(encoded_patches) + representation = keras.layers.GlobalAvgPool1D()(representation) + + outputs = keras.layers.Dense(units=num_classes, activation="softmax")( + representation + ) + + super().__init__(inputs=inputs, outputs=outputs, **kwargs) + + self.num_heads = num_heads + self.num_classes = num_classes + self.tubelet_embedder = tubelet_embedder + self.positional_encoder = positional_encoder + + def get_config(self): + return { + "num_heads": self.num_heads, + "num_classes": self.num_classes, + "tubelet_embedder": keras.saving.serialize_keras_object( + self.tubelet_embedder + ), + "positional_encoder": keras.saving.serialize_keras_object( + self.positional_encoder + ), + } + + @classmethod + def from_config(cls, config): + if "tubelet_embedder" in config and isinstance( + config["tubelet_embedder"], dict + ): + config["tubelet_embedder"] = keras.layers.deserialize( + config["tubelet_embedder"] + ) + if "positional_encoder" in config and isinstance( + config["positional_encoder"], dict + ): + config["positional_encoder"] = keras.layers.deserialize( + config["positional_encoder"] + ) + return super().from_config(config) diff --git a/keras_cv/models/video_classification/vivit_layers.py b/keras_cv/models/video_classification/vivit_layers.py new file mode 100644 index 0000000000..40420a44c6 --- /dev/null +++ b/keras_cv/models/video_classification/vivit_layers.py @@ -0,0 +1,68 @@ +# Copyright 2024 The KerasCV Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from keras_cv.api_export import keras_cv_export
+from keras_cv.backend import keras
+from keras_cv.backend import layers
+from keras_cv.backend import ops
+
+
+@keras_cv_export(
+    "keras_cv.layers.TubeletEmbedding",
+    package="keras_cv.layers",
+)
+class TubeletEmbedding(layers.Layer):
+    def __init__(self, embed_dim, patch_size, **kwargs):
+        super().__init__(**kwargs)
+        self.projection = layers.Conv3D(
+            filters=embed_dim,
+            kernel_size=patch_size,
+            strides=patch_size,
+            padding="VALID",
+        )
+        self.flatten = layers.Reshape(target_shape=(-1, embed_dim))
+
+    def call(self, videos):
+        projected_patches = self.projection(videos)
+        flattened_patches = self.flatten(projected_patches)
+        return flattened_patches
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {"embed_dim": self.embed_dim, "patch_size": self.patch_size}
+        )
+        return config
+
+
+@keras_cv_export(
+    "keras_cv.layers.PositionalEncoder",
+    package="keras_cv.layers",
+)
+class PositionalEncoder(layers.Layer):
+    def __init__(self, embed_dim, **kwargs):
+        super().__init__(**kwargs)
+        self.embed_dim = embed_dim
+
+    def build(self, input_shape):
+        _, num_tokens, _ = input_shape
+        self.position_embedding = layers.Embedding(
+            input_dim=num_tokens, output_dim=self.embed_dim
+        )
+        self.positions = ops.arange(start=0, stop=num_tokens, step=1)
+
+    def call(self, encoded_tokens):
+        encoded_positions = self.position_embedding(self.positions)
+        encoded_tokens = encoded_tokens + encoded_positions
+        return encoded_tokens

From 3ccd176fc18af9d969733babfc5a791e1e29aba1 Mon Sep 17 00:00:00 2001
From: aditya02shah
Date: Tue, 30 Jan 2024 17:03:20 +0000
Subject: [PATCH 03/12] Updated __init__.py

---
 keras_cv/models/video_classification/__init__.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/keras_cv/models/video_classification/__init__.py b/keras_cv/models/video_classification/__init__.py
index 0e9cbb5ac9..320da488c1 100644
--- a/keras_cv/models/video_classification/__init__.py
+++ b/keras_cv/models/video_classification/__init__.py
@@ -11,3 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from keras_cv.models.video_classification.vivit import ViViT

From 9a39aa604afdbef18db521991330a552b8cd9e2b Mon Sep 17 00:00:00 2001
From: aditya02shah
Date: Tue, 30 Jan 2024 17:04:59 +0000
Subject: [PATCH 04/12] Added model construction and call tests

---
 .../models/video_classification/vivit_test.py | 80 +++++++++++++++++++
 1 file changed, 80 insertions(+)
 create mode 100644 keras_cv/models/video_classification/vivit_test.py

diff --git a/keras_cv/models/video_classification/vivit_test.py b/keras_cv/models/video_classification/vivit_test.py
new file mode 100644
index 0000000000..fbcf511567
--- /dev/null
+++ b/keras_cv/models/video_classification/vivit_test.py
@@ -0,0 +1,80 @@
+# Copyright 2024 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import pytest
+
+from keras_cv.backend import keras
+from keras_cv.models.video_classification.vivit import ViViT
+from keras_cv.models.video_classification.vivit_layers import PositionalEncoder
+from keras_cv.models.video_classification.vivit_layers import TubeletEmbedding
+from keras_cv.tests.test_case import TestCase
+
+
+class ViViT_Test(TestCase):
+    def test_vivit_construction(self):
+        INPUT_SHAPE = (28, 28, 28, 1)
+        NUM_CLASSES = 11
+        PATCH_SIZE = (8, 8, 8)
+        LAYER_NORM_EPS = 1e-6
+        PROJECTION_DIM = 128
+        NUM_HEADS = 8
+        NUM_LAYERS = 8
+
+        model = ViViT(
+            tubelet_embedder=TubeletEmbedding(
+                embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE
+            ),
+            positional_encoder=PositionalEncoder(embed_dim=PROJECTION_DIM),
+            input_shape=INPUT_SHAPE,
+            transformer_layers=NUM_LAYERS,
+            num_heads=NUM_HEADS,
+            embed_dim=PROJECTION_DIM,
+            layer_norm_eps=LAYER_NORM_EPS,
+            num_classes=NUM_CLASSES,
+        )
+        model.compile(
+            optimizer="adam",
+            loss="sparse_categorical_crossentropy",
+            metrics=[
+                keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
+                keras.metrics.SparseTopKCategoricalAccuracy(
+                    5, name="top-5-accuracy"
+                ),
+            ],
+        )
+
+    def test_vivit_call(self):
+        INPUT_SHAPE = (28, 28, 28, 1)
+        NUM_CLASSES = 11
+        PATCH_SIZE = (8, 8, 8)
+        LAYER_NORM_EPS = 1e-6
+        PROJECTION_DIM = 128
+        NUM_HEADS = 8
+        NUM_LAYERS = 8
+
+        model = ViViT(
+            tubelet_embedder=TubeletEmbedding(
+                embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE
+            ),
+            positional_encoder=PositionalEncoder(embed_dim=PROJECTION_DIM),
+            input_shape=INPUT_SHAPE,
+            transformer_layers=NUM_LAYERS,
+            num_heads=NUM_HEADS,
+            embed_dim=PROJECTION_DIM,
+            layer_norm_eps=LAYER_NORM_EPS,
+            num_classes=NUM_CLASSES,
+        )
+        frames = np.random.uniform(size=(5, 28, 28, 28, 1))
+        _ = model(frames)

From 30e1c8e65b421886635c328e058dbe8bb82a3451 Mon Sep 17 00:00:00 2001
From: aditya02shah
Date: Tue, 30 Jan 2024 17:06:34 +0000
Subject: [PATCH 05/12] Updated imports

---
 keras_cv/models/video_classification/vivit_layers.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/keras_cv/models/video_classification/vivit_layers.py b/keras_cv/models/video_classification/vivit_layers.py
index 40420a44c6..63986faf02 100644
--- a/keras_cv/models/video_classification/vivit_layers.py
+++ b/keras_cv/models/video_classification/vivit_layers.py
@@ -14,7 +14,6 @@
 
 from keras_cv.api_export import keras_cv_export
 from keras_cv.backend import keras
-from keras_cv.backend import layers
 from keras_cv.backend import ops
 
 
@@ -22,16 +21,16 @@
     "keras_cv.layers.TubeletEmbedding",
     package="keras_cv.layers",
 )
-class TubeletEmbedding(layers.Layer):
+class TubeletEmbedding(keras.layers.Layer):
     def __init__(self, embed_dim, patch_size, **kwargs):
         super().__init__(**kwargs)
-        self.projection = layers.Conv3D(
+        self.projection = keras.layers.Conv3D(
             filters=embed_dim,
             kernel_size=patch_size,
             strides=patch_size,
             padding="VALID",
         )
-        self.flatten = layers.Reshape(target_shape=(-1, embed_dim))
+        self.flatten = keras.layers.Reshape(target_shape=(-1, embed_dim))
 
     def call(self, videos):
         projected_patches = self.projection(videos)
@@ -50,14 +49,14 @@ def get_config(self):
     "keras_cv.layers.PositionalEncoder",
     package="keras_cv.layers",
 )
-class PositionalEncoder(layers.Layer):
+class PositionalEncoder(keras.layers.Layer):
     def __init__(self, embed_dim, **kwargs):
         super().__init__(**kwargs)
         self.embed_dim = embed_dim
 
     def build(self, input_shape):
         _, num_tokens, _ = input_shape
-        self.position_embedding = layers.Embedding(
+        self.position_embedding = 
keras.layers.Embedding( input_dim=num_tokens, output_dim=self.embed_dim ) self.positions = ops.arange(start=0, stop=num_tokens, step=1) From 82f06c3f2a5b765f0a87768fdf3b18fc6172e5cf Mon Sep 17 00:00:00 2001 From: aditya02shah Date: Thu, 1 Feb 2024 17:07:36 +0000 Subject: [PATCH 06/12] Added tests --- .../models/video_classification/vivit_test.py | 155 ++++++++++++++---- 1 file changed, 125 insertions(+), 30 deletions(-) diff --git a/keras_cv/models/video_classification/vivit_test.py b/keras_cv/models/video_classification/vivit_test.py index fbcf511567..be8fc02984 100644 --- a/keras_cv/models/video_classification/vivit_test.py +++ b/keras_cv/models/video_classification/vivit_test.py @@ -12,10 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + import numpy as np import pytest +import tensorflow as tf from keras_cv.backend import keras +from keras_cv.backend import ops +from keras_cv.backend.config import keras_3 from keras_cv.models.video_classification.vivit import ViViT from keras_cv.models.video_classification.vivit_layers import PositionalEncoder from keras_cv.models.video_classification.vivit_layers import TubeletEmbedding @@ -24,25 +29,25 @@ class ViViT_Test(TestCase): def test_vivit_construction(self): - INPUT_SHAPE = (28, 28, 28, 1) - NUM_CLASSES = 11 - PATCH_SIZE = (8, 8, 8) - LAYER_NORM_EPS = 1e-6 - PROJECTION_DIM = 128 - NUM_HEADS = 8 - NUM_LAYERS = 8 + input_shape = (28, 28, 28, 1) + num_classes = 11 + patch_size = (8, 8, 8) + layer_norm_eps = 1e-6 + projection_dim = 128 + num_heads = 8 + num_layers = 8 model = ViViT( tubelet_embedder=TubeletEmbedding( - embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE + embed_dim=projection_dim, patch_size=patch_size ), - positional_encoder=PositionalEncoder(embed_dim=PROJECTION_DIM), - input_shape=INPUT_SHAPE, - transformer_layers=NUM_LAYERS, - num_heads=NUM_HEADS, - embed_dim=PROJECTION_DIM, - layer_norm_eps=LAYER_NORM_EPS, - num_classes=NUM_CLASSES, + positional_encoder=PositionalEncoder(embed_dim=projection_dim), + inp_shape=input_shape, + transformer_layers=num_layers, + num_heads=num_heads, + embed_dim=projection_dim, + layer_norm_eps=layer_norm_eps, + num_classes=num_classes, ) model.compile( optimizer="adam", @@ -56,25 +61,115 @@ def test_vivit_construction(self): ) def test_vivit_call(self): - INPUT_SHAPE = (28, 28, 28, 1) - NUM_CLASSES = 11 - PATCH_SIZE = (8, 8, 8) - LAYER_NORM_EPS = 1e-6 - PROJECTION_DIM = 128 - NUM_HEADS = 8 - NUM_LAYERS = 8 + input_shape = (28, 28, 28, 1) + num_classes = 11 + patch_size = (8, 8, 8) + layer_norm_eps = 1e-6 + projection_dim = 128 + num_heads = 8 + num_layers = 8 model = ViViT( tubelet_embedder=TubeletEmbedding( - embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE + embed_dim=projection_dim, patch_size=patch_size ), - positional_encoder=PositionalEncoder(embed_dim=PROJECTION_DIM), - input_shape=INPUT_SHAPE, - transformer_layers=NUM_LAYERS, - num_heads=NUM_HEADS, - embed_dim=PROJECTION_DIM, - layer_norm_eps=LAYER_NORM_EPS, - num_classes=NUM_CLASSES, + positional_encoder=PositionalEncoder(embed_dim=projection_dim), + inp_shape=input_shape, + transformer_layers=num_layers, + num_heads=num_heads, + embed_dim=projection_dim, + layer_norm_eps=layer_norm_eps, + num_classes=num_classes, ) frames = np.random.uniform(size=(5, 28, 28, 28, 1)) _ = model(frames) + + def test_weights_change(self): + input_shape = (28, 28, 28, 1) + num_classes = 11 + patch_size = (8, 8, 8) + layer_norm_eps = 1e-6 + projection_dim = 128 + num_heads = 8 + num_layers 
= 8 + + frames = np.random.uniform(size=(5, 28, 28, 28, 1)) + labels = np.ones(shape=(5)) + ds = tf.data.Dataset.from_tensor_slices((frames, labels)) + ds = ds.repeat(2) + ds = ds.batch(2) + + model = ViViT( + tubelet_embedder=TubeletEmbedding( + embed_dim=projection_dim, patch_size=patch_size + ), + positional_encoder=PositionalEncoder(embed_dim=projection_dim), + inp_shape=input_shape, + transformer_layers=num_layers, + num_heads=num_heads, + embed_dim=projection_dim, + layer_norm_eps=layer_norm_eps, + num_classes=num_classes, + ) + + model.compile( + optimizer="adam", + loss="sparse_categorical_crossentropy", + metrics=[ + keras.metrics.SparseCategoricalAccuracy(name="accuracy"), + keras.metrics.SparseTopKCategoricalAccuracy( + 5, name="top-5-accuracy" + ), + ], + ) + + layer_name = "multi_head_attention_23" + representation_layer = model.get_layer(layer_name) + + original_weights = representation_layer.get_weights() + model.fit(ds, epochs=1) + updated_weights = representation_layer.get_weights() + + for w1, w2 in zip(original_weights, updated_weights): + self.assertNotAllEqual(w1, w2) + self.assertFalse(ops.any(ops.isnan(w2))) + + @pytest.mark.large # Saving is slow, so mark these large. + def test_saved_model(self): + input_shape = (28, 28, 28, 1) + num_classes = 11 + patch_size = (8, 8, 8) + layer_norm_eps = 1e-6 + projection_dim = 128 + num_heads = 8 + num_layers = 8 + + model = ViViT( + tubelet_embedder=TubeletEmbedding( + embed_dim=projection_dim, patch_size=patch_size + ), + positional_encoder=PositionalEncoder(embed_dim=projection_dim), + inp_shape=input_shape, + transformer_layers=num_layers, + num_heads=num_heads, + embed_dim=projection_dim, + layer_norm_eps=layer_norm_eps, + num_classes=num_classes, + ) + + input_batch = np.random.uniform(size=(5, 28, 28, 28, 1)) + model_output = model(input_batch) + + save_path = os.path.join(self.get_temp_dir(), "model.keras") + if keras_3(): + model.save(save_path) + else: + model.save(save_path, save_format="keras_v3") + restored_model = keras.models.load_model(save_path) + + # Check we got the real object back. + self.assertIsInstance(restored_model, ViViT) + + # Check that output matches. + restored_output = restored_model(input_batch) + self.assertAllClose(model_output, restored_output) From 2612a04be3cd255f07baf8f469586e37bc918973 Mon Sep 17 00:00:00 2001 From: aditya02shah Date: Thu, 1 Feb 2024 17:10:25 +0000 Subject: [PATCH 07/12] Added docs and some minor adjustments --- keras_cv/models/video_classification/vivit.py | 111 +++++++++++++----- 1 file changed, 81 insertions(+), 30 deletions(-) diff --git a/keras_cv/models/video_classification/vivit.py b/keras_cv/models/video_classification/vivit.py index 0a344fc021..3706438171 100644 --- a/keras_cv/models/video_classification/vivit.py +++ b/keras_cv/models/video_classification/vivit.py @@ -26,37 +26,82 @@ class ViViT(Task): """A Keras model implementing a Video Vision Transformer for video classification. + References: - [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) (ICCV 2021) Args: - #Example - tubelet_embedder = - keras_cv.layers.TubeletEmbedding( - embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE - ) - positional_encoder = - keras_cv.layers.PositionalEncoder( - embed_dim=PROJECTION_DIM - ) - model = keras_cv.models.video_classification.ViViT( - tubelet_embedder, - positional_encoder - ) - + tubelet_embedder: 'keras.layers.Layer'. A layer for spatio-temporal tube + embedding applied to input sequences retrieved from video frames. 
+ positional_encoder: 'keras.layers.Layer'. A layer for adding positional + information to the encoded video tokens. + inp_shape: tuple, the shape of the input video frames. + num_classes: int, the number of classes for video classification. + transformer_layers: int, the number of transformer layers in the model. + Defaults to 8. + num_heads: int, the number of heads for multi-head + self-attention mechanism. Defaults to 8. + embed_dim: int, number of dimensions in the embedding space. + Defaults to 128. + layer_norm_eps: float, epsilon value for layer normalization. + Defaults to 1e-6. + + + Examples: + ```python + import keras_cv + + INPUT_SHAPE = (32, 32, 32, 1) + NUM_CLASSES = 11 + PATCH_SIZE = (8, 8, 8) + LAYER_NORM_EPS = 1e-6 + PROJECTION_DIM = 128 + NUM_HEADS = 8 + NUM_LAYERS = 8 + + frames = np.random.uniform(size=(5, 32, 32, 32, 1)) + labels = np.ones(shape=(5)) + model = ViViT( + tubelet_embedder=TubeletEmbedding( + embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE + ), + positional_encoder=PositionalEncoder(embed_dim=PROJECTION_DIM), + inp_shape=INPUT_SHAPE, + transformer_layers=NUM_LAYERS, + num_heads=NUM_HEADS, + embed_dim=PROJECTION_DIM, + layer_norm_eps=LAYER_NORM_EPS, + num_classes=NUM_CLASSES, + ) + + # Evaluate model + model(frames) + + # Train model + model.compile( + optimizer="adam", + loss="sparse_categorical_crossentropy", + metrics=[ + keras.metrics.SparseCategoricalAccuracy(name="accuracy"), + ], + ) + + model.fit(frames, labels, epochs=3) + + ``` """ def __init__( self, tubelet_embedder, positional_encoder, - input_shape, - transformer_layers, - num_heads, - embed_dim, - layer_norm_eps, + inp_shape, num_classes, + transformer_layers=8, + num_heads=8, + embed_dim=128, + layer_norm_eps=1e-6, **kwargs, ): if not isinstance(tubelet_embedder, keras.layers.Layer): @@ -77,7 +122,7 @@ def __init__( f"(of type {type(positional_encoder)})." 
) - inputs = keras.layers.Input(shape=input_shape) + inputs = keras.layers.Input(shape=inp_shape) patches = tubelet_embedder(inputs) encoded_patches = positional_encoder(patches) @@ -116,22 +161,28 @@ def __init__( super().__init__(inputs=inputs, outputs=outputs, **kwargs) + self.inp_shape = inp_shape self.num_heads = num_heads self.num_classes = num_classes self.tubelet_embedder = tubelet_embedder self.positional_encoder = positional_encoder def get_config(self): - return { - "num_heads": self.num_heads, - "num_classes": self.num_classes, - "tubelet_embedder": keras.saving.serialize_keras_object( - self.tubelet_embedder - ), - "positional_encoder": keras.saving.serialize_keras_object( - self.positional_encoder - ), - } + config = super().get_config() + config.update( + { + "num_heads": self.num_heads, + "inp_shape": self.inp_shape, + "num_classes": self.num_classes, + "tubelet_embedder": keras.saving.serialize_keras_object( + self.tubelet_embedder + ), + "positional_encoder": keras.saving.serialize_keras_object( + self.positional_encoder + ), + } + ) + return config @classmethod def from_config(cls, config): From 8099869d489e3b934c9e4d09cf9fa8f1bc13dce3 Mon Sep 17 00:00:00 2001 From: aditya02shah Date: Thu, 1 Feb 2024 17:12:05 +0000 Subject: [PATCH 08/12] Updated Documentation and Default Parameters --- .../video_classification/vivit_layers.py | 47 ++++++++++++++++--- 1 file changed, 41 insertions(+), 6 deletions(-) diff --git a/keras_cv/models/video_classification/vivit_layers.py b/keras_cv/models/video_classification/vivit_layers.py index 63986faf02..eff877d37a 100644 --- a/keras_cv/models/video_classification/vivit_layers.py +++ b/keras_cv/models/video_classification/vivit_layers.py @@ -22,15 +22,37 @@ package="keras_cv.layers", ) class TubeletEmbedding(keras.layers.Layer): - def __init__(self, embed_dim, patch_size, **kwargs): + """ + A Keras layer for spatio-temporal tube embedding applied to input sequences + retrieved from video frames. + + References: + - [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) + (ICCV 2021) + + Args: + embed_dim: int, number of dimensions in the embedding space. + Defaults to 128. + patch_size: tuple or int, size of the spatio-temporal patch. + If int, the same size is used for all dimensions. + If tuple, specifies the size for each dimension. + Defaults to 8. + + """ + + def __init__(self, embed_dim=128, patch_size=8, **kwargs): super().__init__(**kwargs) + self.embed_dim = embed_dim + self.patch_size = patch_size + + def build(self, input_shape): self.projection = keras.layers.Conv3D( - filters=embed_dim, - kernel_size=patch_size, - strides=patch_size, + filters=self.embed_dim, + kernel_size=self.patch_size, + strides=self.patch_size, padding="VALID", ) - self.flatten = keras.layers.Reshape(target_shape=(-1, embed_dim)) + self.flatten = keras.layers.Reshape(target_shape=(-1, self.embed_dim)) def call(self, videos): projected_patches = self.projection(videos) @@ -50,7 +72,20 @@ def get_config(self): package="keras_cv.layers", ) class PositionalEncoder(keras.layers.Layer): - def __init__(self, embed_dim, **kwargs): + """ + A Keras layer for adding positional information to the encoded video tokens. + + References: + - [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) + (ICCV 2021) + + Args: + embed_dim: int, number of dimensions in the embedding space. + Defaults to 128. 
+ + """ + + def __init__(self, embed_dim=128, **kwargs): super().__init__(**kwargs) self.embed_dim = embed_dim From 13f4829d61bb2ac098b9b499f77ab10d61d48700 Mon Sep 17 00:00:00 2001 From: aditya02shah Date: Wed, 7 Feb 2024 16:00:03 +0000 Subject: [PATCH 09/12] Updated comments --- keras_cv/models/video_classification/vivit_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/keras_cv/models/video_classification/vivit_test.py b/keras_cv/models/video_classification/vivit_test.py index be8fc02984..e17e9c75d7 100644 --- a/keras_cv/models/video_classification/vivit_test.py +++ b/keras_cv/models/video_classification/vivit_test.py @@ -167,9 +167,7 @@ def test_saved_model(self): model.save(save_path, save_format="keras_v3") restored_model = keras.models.load_model(save_path) - # Check we got the real object back. self.assertIsInstance(restored_model, ViViT) - # Check that output matches. restored_output = restored_model(input_batch) self.assertAllClose(model_output, restored_output) From 0b6043fbbe46b57d6fcbb73d4ac28cece27a1ffe Mon Sep 17 00:00:00 2001 From: aditya02shah Date: Fri, 23 Feb 2024 12:53:07 +0000 Subject: [PATCH 10/12] Updating parameters and build method --- keras_cv/models/video_classification/vivit.py | 87 +++++++------------ .../video_classification/vivit_layers.py | 45 ++++++---- .../models/video_classification/vivit_test.py | 31 +++---- 3 files changed, 69 insertions(+), 94 deletions(-) diff --git a/keras_cv/models/video_classification/vivit.py b/keras_cv/models/video_classification/vivit.py index 3706438171..f90c555d96 100644 --- a/keras_cv/models/video_classification/vivit.py +++ b/keras_cv/models/video_classification/vivit.py @@ -15,6 +15,8 @@ from keras_cv.api_export import keras_cv_export from keras_cv.backend import keras from keras_cv.models.task import Task +from keras_cv.models.video_classification.vivit_layers import PositionalEncoder +from keras_cv.models.video_classification.vivit_layers import TubeletEmbedding @keras_cv_export( @@ -32,17 +34,16 @@ class ViViT(Task): (ICCV 2021) Args: - tubelet_embedder: 'keras.layers.Layer'. A layer for spatio-temporal tube - embedding applied to input sequences retrieved from video frames. - positional_encoder: 'keras.layers.Layer'. A layer for adding positional - information to the encoded video tokens. inp_shape: tuple, the shape of the input video frames. num_classes: int, the number of classes for video classification. transformer_layers: int, the number of transformer layers in the model. Defaults to 8. + patch_size: tuple , contains the size of the + spatio-temporal patches for each dimension + Defaults to (8,8,8) num_heads: int, the number of heads for multi-head self-attention mechanism. Defaults to 8. - embed_dim: int, number of dimensions in the embedding space. + projection_dim: int, number of dimensions in the projection space. Defaults to 128. layer_norm_eps: float, epsilon value for layer normalization. Defaults to 1e-6. 
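This patch inverts the dependency: instead of receiving ready-made layers, the model now builds TubeletEmbedding and PositionalEncoder itself from projection_dim and patch_size. One payoff, visible further down in this diff, is that get_config can return plain Python values and the custom from_config override can be deleted. A hedged sketch of the round-trip that enables (import path as wired up in patch 03; rebuilding through the inherited Task deserialization is an assumption here, not something verified in the series):

```python
from keras_cv.models.video_classification.vivit import ViViT

# Defaults mirror this patch: projection_dim=128, patch_size=(8, 8, 8).
model = ViViT(inp_shape=(28, 28, 28, 1), num_classes=11)

config = model.get_config()  # plain ints and tuples, no nested layer configs
clone = ViViT.from_config(config)  # assumed to rebuild without custom_objects
```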
@@ -62,11 +63,10 @@ class ViViT(Task): frames = np.random.uniform(size=(5, 32, 32, 32, 1)) labels = np.ones(shape=(5)) + model = ViViT( - tubelet_embedder=TubeletEmbedding( - embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE - ), - positional_encoder=PositionalEncoder(embed_dim=PROJECTION_DIM), + projection_dim=PROJECTION_DIM, + patch_size=PATCH_SIZE, inp_shape=INPUT_SHAPE, transformer_layers=NUM_LAYERS, num_heads=NUM_HEADS, @@ -94,37 +94,29 @@ class ViViT(Task): def __init__( self, - tubelet_embedder, - positional_encoder, inp_shape, num_classes, + projection_dim=128, + patch_size=(8, 8, 8), transformer_layers=8, num_heads=8, embed_dim=128, layer_norm_eps=1e-6, **kwargs, ): - if not isinstance(tubelet_embedder, keras.layers.Layer): - raise ValueError( - "Argument `tubelet_embedder` must be a " - " `keras.layers.Layer` instance " - f" . Received instead " - f"tubelet_embedder={tubelet_embedder} " - f"(of type {type(tubelet_embedder)})." - ) - - if not isinstance(positional_encoder, keras.layers.Layer): - raise ValueError( - "Argument `positional_encoder` must be a " - "`keras.layers.Layer` instance " - f" . Received instead " - f"positional_encoder={positional_encoder} " - f"(of type {type(positional_encoder)})." - ) + self.projection_dim = projection_dim + self.patch_size = patch_size + self.tubelet_embedder = TubeletEmbedding( + embed_dim=self.projection_dim, patch_size=self.patch_size + ) + + self.positional_encoder = PositionalEncoder( + embed_dim=self.projection_dim + ) inputs = keras.layers.Input(shape=inp_shape) - patches = tubelet_embedder(inputs) - encoded_patches = positional_encoder(patches) + patches = self.tubelet_embedder(inputs) + encoded_patches = self.positional_encoder(patches) for _ in range(transformer_layers): x1 = keras.layers.LayerNormalization(epsilon=1e-6)(encoded_patches) @@ -164,8 +156,15 @@ def __init__( self.inp_shape = inp_shape self.num_heads = num_heads self.num_classes = num_classes - self.tubelet_embedder = tubelet_embedder - self.positional_encoder = positional_encoder + self.projection_dim = projection_dim + self.patch_size = patch_size + + def build(self, input_shape): + self.tubelet_embedder.build(input_shape) + flattened_patch_shape = self.tubelet_embedder.compute_output_shape( + input_shape + ) + self.positional_encoder.build(flattened_patch_shape) def get_config(self): config = super().get_config() @@ -174,28 +173,8 @@ def get_config(self): "num_heads": self.num_heads, "inp_shape": self.inp_shape, "num_classes": self.num_classes, - "tubelet_embedder": keras.saving.serialize_keras_object( - self.tubelet_embedder - ), - "positional_encoder": keras.saving.serialize_keras_object( - self.positional_encoder - ), + "projection_dim": self.projection_dim, + "patch_size": self.patch_size, } ) return config - - @classmethod - def from_config(cls, config): - if "tubelet_embedder" in config and isinstance( - config["tubelet_embedder"], dict - ): - config["tubelet_embedder"] = keras.layers.deserialize( - config["tubelet_embedder"] - ) - if "positional_encoder" in config and isinstance( - config["positional_encoder"], dict - ): - config["positional_encoder"] = keras.layers.deserialize( - config["positional_encoder"] - ) - return super().from_config(config) diff --git a/keras_cv/models/video_classification/vivit_layers.py b/keras_cv/models/video_classification/vivit_layers.py index eff877d37a..706ac33509 100644 --- a/keras_cv/models/video_classification/vivit_layers.py +++ b/keras_cv/models/video_classification/vivit_layers.py @@ -33,19 +33,16 @@ class 
TubeletEmbedding(keras.layers.Layer): Args: embed_dim: int, number of dimensions in the embedding space. Defaults to 128. - patch_size: tuple or int, size of the spatio-temporal patch. - If int, the same size is used for all dimensions. - If tuple, specifies the size for each dimension. - Defaults to 8. + patch_size: tuple , size of the spatio-temporal patch. + Specifies the size for each dimension. + Defaults to (8,8,8). """ - def __init__(self, embed_dim=128, patch_size=8, **kwargs): + def __init__(self, embed_dim=128, patch_size=(8, 8, 8), **kwargs): super().__init__(**kwargs) self.embed_dim = embed_dim self.patch_size = patch_size - - def build(self, input_shape): self.projection = keras.layers.Conv3D( filters=self.embed_dim, kernel_size=self.patch_size, @@ -54,18 +51,26 @@ def build(self, input_shape): ) self.flatten = keras.layers.Reshape(target_shape=(-1, self.embed_dim)) + def build(self, input_shape): + if input_shape is not None: + self.projection.build(input_shape) + projected_patch_shape = self.projection.compute_output_shape( + input_shape + ) + self.flatten.build(projected_patch_shape) + + def compute_output_shape(self, input_shape): + if input_shape is not None: + projected_patch_shape = self.projection.compute_output_shape( + input_shape + ) + return self.flatten.compute_output_shape(projected_patch_shape) + def call(self, videos): projected_patches = self.projection(videos) flattened_patches = self.flatten(projected_patches) return flattened_patches - def get_config(self): - config = super().get_config() - config.update( - {"embed_dim": self.embed_dim, "patch_size": self.patch_size} - ) - return config - @keras_cv_export( "keras_cv.layers.PositionalEncoder", @@ -90,11 +95,13 @@ def __init__(self, embed_dim=128, **kwargs): self.embed_dim = embed_dim def build(self, input_shape): - _, num_tokens, _ = input_shape - self.position_embedding = keras.layers.Embedding( - input_dim=num_tokens, output_dim=self.embed_dim - ) - self.positions = ops.arange(start=0, stop=num_tokens, step=1) + if input_shape is not None: + _, num_tokens, _ = input_shape + self.position_embedding = keras.layers.Embedding( + input_dim=num_tokens, output_dim=self.embed_dim + ) + self.position_embedding.build(input_shape) + self.positions = ops.arange(start=0, stop=num_tokens, step=1) def call(self, encoded_tokens): encoded_positions = self.position_embedding(self.positions) diff --git a/keras_cv/models/video_classification/vivit_test.py b/keras_cv/models/video_classification/vivit_test.py index e17e9c75d7..268b90c294 100644 --- a/keras_cv/models/video_classification/vivit_test.py +++ b/keras_cv/models/video_classification/vivit_test.py @@ -22,8 +22,6 @@ from keras_cv.backend import ops from keras_cv.backend.config import keras_3 from keras_cv.models.video_classification.vivit import ViViT -from keras_cv.models.video_classification.vivit_layers import PositionalEncoder -from keras_cv.models.video_classification.vivit_layers import TubeletEmbedding from keras_cv.tests.test_case import TestCase @@ -38,10 +36,8 @@ def test_vivit_construction(self): num_layers = 8 model = ViViT( - tubelet_embedder=TubeletEmbedding( - embed_dim=projection_dim, patch_size=patch_size - ), - positional_encoder=PositionalEncoder(embed_dim=projection_dim), + projection_dim=projection_dim, + patch_size=patch_size, inp_shape=input_shape, transformer_layers=num_layers, num_heads=num_heads, @@ -70,10 +66,8 @@ def test_vivit_call(self): num_layers = 8 model = ViViT( - tubelet_embedder=TubeletEmbedding( - embed_dim=projection_dim, 
patch_size=patch_size - ), - positional_encoder=PositionalEncoder(embed_dim=projection_dim), + projection_dim=projection_dim, + patch_size=patch_size, inp_shape=input_shape, transformer_layers=num_layers, num_heads=num_heads, @@ -100,10 +94,8 @@ def test_weights_change(self): ds = ds.batch(2) model = ViViT( - tubelet_embedder=TubeletEmbedding( - embed_dim=projection_dim, patch_size=patch_size - ), - positional_encoder=PositionalEncoder(embed_dim=projection_dim), + projection_dim=projection_dim, + patch_size=patch_size, inp_shape=input_shape, transformer_layers=num_layers, num_heads=num_heads, @@ -123,8 +115,7 @@ def test_weights_change(self): ], ) - layer_name = "multi_head_attention_23" - representation_layer = model.get_layer(layer_name) + representation_layer = model.get_layer(index=-8) # Accesses MHSA Layer original_weights = representation_layer.get_weights() model.fit(ds, epochs=1) @@ -134,7 +125,7 @@ def test_weights_change(self): self.assertNotAllEqual(w1, w2) self.assertFalse(ops.any(ops.isnan(w2))) - @pytest.mark.large # Saving is slow, so mark these large. + # @pytest.mark.large # Saving is slow, so mark these large. def test_saved_model(self): input_shape = (28, 28, 28, 1) num_classes = 11 @@ -145,10 +136,8 @@ def test_saved_model(self): num_layers = 8 model = ViViT( - tubelet_embedder=TubeletEmbedding( - embed_dim=projection_dim, patch_size=patch_size - ), - positional_encoder=PositionalEncoder(embed_dim=projection_dim), + projection_dim=projection_dim, + patch_size=patch_size, inp_shape=input_shape, transformer_layers=num_layers, num_heads=num_heads, From 9536d35cd46fc8d26310eeaef563f196d2c001e5 Mon Sep 17 00:00:00 2001 From: aditya02shah Date: Tue, 27 Feb 2024 01:33:55 +0000 Subject: [PATCH 11/12] Updated build.sh --- .kokoro/github/ubuntu/gpu/build.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.kokoro/github/ubuntu/gpu/build.sh b/.kokoro/github/ubuntu/gpu/build.sh index 76ac0631b4..f8c1bd7194 100644 --- a/.kokoro/github/ubuntu/gpu/build.sh +++ b/.kokoro/github/ubuntu/gpu/build.sh @@ -69,6 +69,7 @@ then keras_cv/models/object_detection/retinanet \ keras_cv/models/object_detection/yolo_v8 \ keras_cv/models/object_detection_3d \ + keras_cv/models/video_classification \ keras_cv/models/segmentation \ keras_cv/models/stable_diffusion else @@ -82,6 +83,7 @@ else keras_cv/models/classification \ keras_cv/models/object_detection/retinanet \ keras_cv/models/object_detection/yolo_v8 \ + keras_cv/models/video_classification \ keras_cv/models/object_detection_3d \ keras_cv/models/segmentation \ keras_cv/models/stable_diffusion From 36541bb87d9c0db1040b38fe9f5730d0285e39a5 Mon Sep 17 00:00:00 2001 From: aditya02shah Date: Fri, 1 Mar 2024 16:03:44 +0000 Subject: [PATCH 12/12] Updated Build Method --- keras_cv/models/video_classification/vivit.py | 103 ++++++++++-------- .../video_classification/vivit_layers.py | 52 ++++++--- .../models/video_classification/vivit_test.py | 12 +- 3 files changed, 100 insertions(+), 67 deletions(-) diff --git a/keras_cv/models/video_classification/vivit.py b/keras_cv/models/video_classification/vivit.py index f90c555d96..3858f423aa 100644 --- a/keras_cv/models/video_classification/vivit.py +++ b/keras_cv/models/video_classification/vivit.py @@ -39,7 +39,7 @@ class ViViT(Task): transformer_layers: int, the number of transformer layers in the model. Defaults to 8. 
patch_size: tuple , contains the size of the - spatio-temporal patches for each dimension + spatio-temporal patches for each dimension Defaults to (8,8,8) num_heads: int, the number of heads for multi-head self-attention mechanism. Defaults to 8. @@ -64,21 +64,18 @@ class ViViT(Task): frames = np.random.uniform(size=(5, 32, 32, 32, 1)) labels = np.ones(shape=(5)) + # Instantiate Model model = ViViT( projection_dim=PROJECTION_DIM, patch_size=PATCH_SIZE, inp_shape=INPUT_SHAPE, transformer_layers=NUM_LAYERS, num_heads=NUM_HEADS, - embed_dim=PROJECTION_DIM, layer_norm_eps=LAYER_NORM_EPS, num_classes=NUM_CLASSES, ) - # Evaluate model - model(frames) - - # Train model + # Compile model model.compile( optimizer="adam", loss="sparse_categorical_crossentropy", @@ -87,6 +84,10 @@ class ViViT(Task): ], ) + # Build Model + model.build(INPUT_SHAPE) + + # Train Model model.fit(frames, labels, epochs=3) ``` @@ -100,10 +101,11 @@ def __init__( patch_size=(8, 8, 8), transformer_layers=8, num_heads=8, - embed_dim=128, layer_norm_eps=1e-6, **kwargs, ): + super().__init__(**kwargs) + self.projection_dim = projection_dim self.patch_size = patch_size self.tubelet_embedder = TubeletEmbedding( @@ -113,58 +115,73 @@ def __init__( self.positional_encoder = PositionalEncoder( embed_dim=self.projection_dim ) - - inputs = keras.layers.Input(shape=inp_shape) - patches = self.tubelet_embedder(inputs) - encoded_patches = self.positional_encoder(patches) - - for _ in range(transformer_layers): - x1 = keras.layers.LayerNormalization(epsilon=1e-6)(encoded_patches) - attention_output = keras.layers.MultiHeadAttention( - num_heads=num_heads, - key_dim=embed_dim // num_heads, - dropout=0.1, - )(x1, x1) - - x2 = keras.layers.Add()([attention_output, encoded_patches]) - - x3 = keras.layers.LayerNormalization(epsilon=1e-6)(x2) - x3 = keras.Sequential( - [ - keras.layers.Dense( - units=embed_dim * 4, activation=keras.ops.gelu - ), - keras.layers.Dense( - units=embed_dim, activation=keras.ops.gelu - ), - ] - )(x3) - - encoded_patches = keras.layers.Add()([x3, x2]) - - representation = keras.layers.LayerNormalization( + self.layer_norm = keras.layers.LayerNormalization( epsilon=layer_norm_eps - )(encoded_patches) - representation = keras.layers.GlobalAvgPool1D()(representation) - - outputs = keras.layers.Dense(units=num_classes, activation="softmax")( - representation + ) + self.attention_output = keras.layers.MultiHeadAttention( + num_heads=num_heads, + key_dim=projection_dim // num_heads, + dropout=0.1, + ) + self.dense_1 = keras.layers.Dense( + units=projection_dim * 4, activation=keras.ops.gelu ) - super().__init__(inputs=inputs, outputs=outputs, **kwargs) + self.dense_2 = keras.layers.Dense( + units=projection_dim, activation=keras.ops.gelu + ) + self.add = keras.layers.Add() + self.pooling = keras.layers.GlobalAvgPool1D() + self.dense_output = keras.layers.Dense( + units=num_classes, activation="softmax" + ) self.inp_shape = inp_shape self.num_heads = num_heads self.num_classes = num_classes self.projection_dim = projection_dim self.patch_size = patch_size + self.transformer_layers = transformer_layers def build(self, input_shape): + super().build(input_shape) self.tubelet_embedder.build(input_shape) flattened_patch_shape = self.tubelet_embedder.compute_output_shape( input_shape ) self.positional_encoder.build(flattened_patch_shape) + self.layer_norm.build([None, None, self.projection_dim]) + self.attention_output.build( + query_shape=[None, None, self.projection_dim], + value_shape=[None, None, self.projection_dim], + ) + 
self.add.build( + [ + (None, None, self.projection_dim), + (None, None, self.projection_dim), + ] + ) + + self.dense_1.build([None, None, self.projection_dim]) + self.dense_2.build([None, None, self.projection_dim * 4]) + self.pooling.build([None, None, self.projection_dim]) + self.dense_output.build([None, self.projection_dim]) + + def call(self, x): + patches = self.tubelet_embedder(x) + encoded_patches = self.positional_encoder(patches) + for _ in range(self.transformer_layers): + x1 = self.layer_norm(encoded_patches) + attention_output = self.attention_output(x1, x1) + x2 = self.add([attention_output, encoded_patches]) + x3 = self.layer_norm(x2) + x4 = self.dense_1(x3) + x5 = self.dense_2(x4) + encoded_patches = self.add([x5, x2]) + representation = self.layer_norm(encoded_patches) + pooled_representation = self.pooling(representation) + outputs = self.dense_output(pooled_representation) + return outputs def get_config(self): config = super().get_config() diff --git a/keras_cv/models/video_classification/vivit_layers.py b/keras_cv/models/video_classification/vivit_layers.py index 706ac33509..53c2a0fc79 100644 --- a/keras_cv/models/video_classification/vivit_layers.py +++ b/keras_cv/models/video_classification/vivit_layers.py @@ -47,24 +47,44 @@ def __init__(self, embed_dim=128, patch_size=(8, 8, 8), **kwargs): filters=self.embed_dim, kernel_size=self.patch_size, strides=self.patch_size, + data_format="channels_last", padding="VALID", ) self.flatten = keras.layers.Reshape(target_shape=(-1, self.embed_dim)) def build(self, input_shape): - if input_shape is not None: - self.projection.build(input_shape) - projected_patch_shape = self.projection.compute_output_shape( - input_shape + super().build(input_shape) + self.projection.build( + ( + None, + input_shape[0], + input_shape[1], + input_shape[2], + input_shape[3], ) - self.flatten.build(projected_patch_shape) + ) + projected_patch_shape = self.projection.compute_output_shape( + ( + None, + input_shape[0], + input_shape[1], + input_shape[2], + input_shape[3], + ) + ) + self.flatten.build(projected_patch_shape) def compute_output_shape(self, input_shape): - if input_shape is not None: - projected_patch_shape = self.projection.compute_output_shape( - input_shape + projected_patch_shape = self.projection.compute_output_shape( + ( + None, + input_shape[0], + input_shape[1], + input_shape[2], + input_shape[3], ) - return self.flatten.compute_output_shape(projected_patch_shape) + ) + return self.flatten.compute_output_shape(projected_patch_shape) def call(self, videos): projected_patches = self.projection(videos) @@ -95,13 +115,13 @@ def __init__(self, embed_dim=128, **kwargs): self.embed_dim = embed_dim def build(self, input_shape): - if input_shape is not None: - _, num_tokens, _ = input_shape - self.position_embedding = keras.layers.Embedding( - input_dim=num_tokens, output_dim=self.embed_dim - ) - self.position_embedding.build(input_shape) - self.positions = ops.arange(start=0, stop=num_tokens, step=1) + super().build(input_shape) + _, num_tokens, _ = input_shape + self.position_embedding = keras.layers.Embedding( + input_dim=num_tokens, output_dim=self.embed_dim + ) + self.position_embedding.build(input_shape) + self.positions = ops.arange(start=0, stop=num_tokens, step=1) def call(self, encoded_tokens): encoded_positions = self.position_embedding(self.positions) diff --git a/keras_cv/models/video_classification/vivit_test.py b/keras_cv/models/video_classification/vivit_test.py index 268b90c294..ed561debc3 100644 --- 
a/keras_cv/models/video_classification/vivit_test.py +++ b/keras_cv/models/video_classification/vivit_test.py @@ -41,7 +41,6 @@ def test_vivit_construction(self): inp_shape=input_shape, transformer_layers=num_layers, num_heads=num_heads, - embed_dim=projection_dim, layer_norm_eps=layer_norm_eps, num_classes=num_classes, ) @@ -71,10 +70,10 @@ def test_vivit_call(self): inp_shape=input_shape, transformer_layers=num_layers, num_heads=num_heads, - embed_dim=projection_dim, layer_norm_eps=layer_norm_eps, num_classes=num_classes, ) + model.build(input_shape) frames = np.random.uniform(size=(5, 28, 28, 28, 1)) _ = model(frames) @@ -99,7 +98,6 @@ def test_weights_change(self): inp_shape=input_shape, transformer_layers=num_layers, num_heads=num_heads, - embed_dim=projection_dim, layer_norm_eps=layer_norm_eps, num_classes=num_classes, ) @@ -114,9 +112,8 @@ def test_weights_change(self): ), ], ) - + model.build(input_shape) representation_layer = model.get_layer(index=-8) # Accesses MHSA Layer - original_weights = representation_layer.get_weights() model.fit(ds, epochs=1) updated_weights = representation_layer.get_weights() @@ -125,7 +122,7 @@ def test_weights_change(self): self.assertNotAllEqual(w1, w2) self.assertFalse(ops.any(ops.isnan(w2))) - # @pytest.mark.large # Saving is slow, so mark these large. + @pytest.mark.large # Saving is slow, so mark these large. def test_saved_model(self): input_shape = (28, 28, 28, 1) num_classes = 11 @@ -141,11 +138,10 @@ def test_saved_model(self): inp_shape=input_shape, transformer_layers=num_layers, num_heads=num_heads, - embed_dim=projection_dim, layer_norm_eps=layer_norm_eps, num_classes=num_classes, ) - + model.build(input_shape) input_batch = np.random.uniform(size=(5, 28, 28, 28, 1)) model_output = model(input_batch)
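Taken together, the series ends with a ViViT that is constructed from plain hyperparameters and built explicitly before use. A minimal end-to-end sketch against the final state of the branch (random data for illustration; assumes a keras_cv checkout with these twelve patches applied):

```python
import numpy as np

from keras_cv.models.video_classification.vivit import ViViT

# Tiny random "videos": (batch, time, height, width, channels).
frames = np.random.uniform(size=(4, 28, 28, 28, 1)).astype("float32")
labels = np.zeros((4,), dtype="int32")

model = ViViT(inp_shape=(28, 28, 28, 1), num_classes=11)
model.build((28, 28, 28, 1))  # the final patch expects an explicit build
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
model.fit(frames, labels, epochs=1)

scores = model(frames)  # (4, 11) softmax scores, as exercised by the tests
print(scores.shape)
```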