# Copyright 2024 RecML authors <recommendations-ml@google.com>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| 14 | +"""Models baselined.""" |

from collections.abc import Mapping, Sequence
from typing import Any

import keras
import keras_hub
from recml.layers.keras import utils

Tensor = Any


@keras.saving.register_keras_serializable("recml")
class BERT4Rec(keras.layers.Layer):
  """BERT4Rec architecture as in [1].

  Implements the BERT4Rec model architecture as described in 'BERT4Rec:
  Sequential Recommendation with Bidirectional Encoder Representations from
  Transformer' [1].

  [1] https://arxiv.org/abs/1904.06690
  """

  def __init__(
      self,
      *,
      vocab_size: int,
      max_positions: int,
      num_types: int | None = None,
      model_dim: int,
      mlp_dim: int,
      num_heads: int,
      num_layers: int,
      dropout: float = 0.0,
      norm_eps: float = 1e-12,
      add_head: bool = True,
      **kwargs,
  ):
    """Initializes the instance.

    Args:
      vocab_size: The size of the item vocabulary.
      max_positions: The maximum number of positions in a sequence.
      num_types: The number of types. If None, no type embedding is used.
        Defaults to None.
      model_dim: The width of the embeddings in the model.
      mlp_dim: The width of the MLP in each transformer block.
      num_heads: The number of attention heads in each transformer block.
      num_layers: The number of transformer blocks in the model.
      dropout: The dropout rate. Defaults to 0.
      norm_eps: The epsilon for layer normalization.
      add_head: Whether to add a masked language modeling head.
      **kwargs: Passed through to the super class.
    """

    super().__init__(**kwargs)

    self.item_embedding = keras_hub.layers.ReversibleEmbedding(
        input_dim=vocab_size,
        output_dim=model_dim,
        embeddings_initializer=keras.initializers.TruncatedNormal(stddev=0.02),
        dtype=self.dtype_policy,
        reverse_dtype=self.compute_dtype,
        name="item_embedding",
    )
    if num_types is not None:
      self.type_embedding = keras.layers.Embedding(
          input_dim=num_types,
          output_dim=model_dim,
          embeddings_initializer=keras.initializers.TruncatedNormal(
              stddev=0.02
          ),
          dtype=self.dtype_policy,
          name="type_embedding",
      )
    else:
      self.type_embedding = None

    self.position_embedding = keras_hub.layers.PositionEmbedding(
        sequence_length=max_positions,
        initializer=keras.initializers.TruncatedNormal(stddev=0.02),
        dtype=self.dtype_policy,
        name="position_embedding",
    )

    self.embeddings_norm = keras.layers.LayerNormalization(
        epsilon=norm_eps, name="embedding_norm"
    )
    self.embeddings_dropout = keras.layers.Dropout(
        dropout, name="embedding_dropout"
    )

    self.encoder_blocks = [
        keras_hub.layers.TransformerEncoder(
            intermediate_dim=mlp_dim,
            num_heads=num_heads,
            dropout=dropout,
            activation=utils.gelu_approximate,
            layer_norm_epsilon=norm_eps,
            normalize_first=False,
            dtype=self.dtype_policy,
            name=f"encoder_block_{i}",
        )
        for i in range(num_layers)
    ]
    if add_head:
      self.head = keras_hub.layers.MaskedLMHead(
          vocabulary_size=vocab_size,
          token_embedding=self.item_embedding,
          intermediate_activation=utils.gelu_approximate,
          kernel_initializer=keras.initializers.TruncatedNormal(stddev=0.02),
          dtype=self.dtype_policy,
          name="mlm_head",
      )
    else:
      self.head = None

    self._vocab_size = vocab_size
    self._model_dim = model_dim
    self._config = {
        "vocab_size": vocab_size,
        "max_positions": max_positions,
        "num_types": num_types,
        "model_dim": model_dim,
        "mlp_dim": mlp_dim,
        "num_heads": num_heads,
        "num_layers": num_layers,
        "dropout": dropout,
        "norm_eps": norm_eps,
        "add_head": add_head,
    }

  def build(self, inputs_shape: Sequence[int]):
    self.item_embedding.build(inputs_shape)
    if self.type_embedding is not None:
      self.type_embedding.build(inputs_shape)

    self.position_embedding.build((*inputs_shape, self._model_dim))
    self.embeddings_norm.build((*inputs_shape, self._model_dim))

    for encoder_block in self.encoder_blocks:
      encoder_block.build((*inputs_shape, self._model_dim))

    if self.head is not None:
      self.head.build((*inputs_shape, self._model_dim))

  def call(
      self,
      inputs: Tensor,
      type_ids: Tensor | None = None,
      padding_mask: Tensor | None = None,
      attention_mask: Tensor | None = None,
      mask_positions: Tensor | None = None,
      training: bool = False,
  ) -> Tensor:
    embeddings = self.item_embedding(inputs)
    if self.type_embedding is not None:
      if type_ids is None:
        raise ValueError(
            "`type_ids` cannot be None when `num_types` is not None."
        )
      embeddings += self.type_embedding(type_ids)
    embeddings += self.position_embedding(embeddings)

    embeddings = self.embeddings_norm(embeddings)
    embeddings = self.embeddings_dropout(embeddings, training=training)

    for encoder_block in self.encoder_blocks:
      embeddings = encoder_block(
          embeddings,
          padding_mask=padding_mask,
          attention_mask=attention_mask,
          training=training,
      )

    if self.head is None:
      return embeddings

    return self.head(embeddings, mask_positions)

  def compute_output_shape(
      self,
      inputs_shape: Sequence[int],
      mask_positions_shape: Tensor | None = None,
  ) -> Sequence[int | None]:
    if self.head is not None:
      if mask_positions_shape is None:
        raise ValueError(
            "`mask_positions_shape` cannot be None when `add_head` is True."
        )
      return (*inputs_shape[:-1], mask_positions_shape[-1], self._vocab_size)
    return (*inputs_shape, self._model_dim)

  def get_config(self) -> Mapping[str, Any]:
    return {**super().get_config(), **self._config}