Commit f9c26b4

Hilly12 and hashgupta authored and committed
Add BERT4Rec and SASRec.
Co-authored-by: hashgupta <yash.gupta.7782@gmail.com>
PiperOrigin-RevId: 745116291
1 parent 75ca93d commit f9c26b4

File tree

6 files changed, +800 -0 lines changed

RecML/layers/keras/bert4rec.py

Lines changed: 208 additions & 0 deletions
@@ -0,0 +1,208 @@
# Copyright 2024 RecML authors <recommendations-ml@google.com>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT4Rec model implementation."""

from collections.abc import Mapping, Sequence
from typing import Any

import keras
import keras_hub
from recml.layers.keras import utils

Tensor = Any


@keras.saving.register_keras_serializable("recml")
class BERT4Rec(keras.layers.Layer):
  """BERT4Rec architecture as in [1].

  Implements the BERT4Rec model architecture as described in 'BERT4Rec:
  Sequential Recommendation with Bidirectional Encoder Representations from
  Transformer' [1].

  [1] https://arxiv.org/abs/1904.06690
  """

  def __init__(
      self,
      *,
      vocab_size: int,
      max_positions: int,
      num_types: int | None = None,
      model_dim: int,
      mlp_dim: int,
      num_heads: int,
      num_layers: int,
      dropout: float = 0.0,
      norm_eps: float = 1e-12,
      add_head: bool = True,
      **kwargs,
  ):
    """Initializes the instance.

    Args:
      vocab_size: The size of the item vocabulary.
      max_positions: The maximum number of positions in a sequence.
      num_types: The number of types. If None, no type embedding is used.
        Defaults to None.
      model_dim: The width of the embeddings in the model.
      mlp_dim: The width of the MLP in each transformer block.
      num_heads: The number of attention heads in each transformer block.
      num_layers: The number of transformer blocks in the model.
      dropout: The dropout rate. Defaults to 0.
      norm_eps: The epsilon for layer normalization.
      add_head: Whether to add a masked language modeling head.
      **kwargs: Passed through to the super class.
    """
    super().__init__(**kwargs)

    self.item_embedding = keras_hub.layers.ReversibleEmbedding(
        input_dim=vocab_size,
        output_dim=model_dim,
        embeddings_initializer=keras.initializers.TruncatedNormal(stddev=0.02),
        dtype=self.dtype_policy,
        reverse_dtype=self.compute_dtype,
        name="item_embedding",
    )
    if num_types is not None:
      self.type_embedding = keras.layers.Embedding(
          input_dim=num_types,
          output_dim=model_dim,
          embeddings_initializer=keras.initializers.TruncatedNormal(
              stddev=0.02
          ),
          dtype=self.dtype_policy,
          name="type_embedding",
      )
    else:
      self.type_embedding = None

    self.position_embedding = keras_hub.layers.PositionEmbedding(
        sequence_length=max_positions,
        initializer=keras.initializers.TruncatedNormal(stddev=0.02),
        dtype=self.dtype_policy,
        name="position_embedding",
    )

    # Use the configured epsilon rather than a hardcoded value so that
    # `norm_eps` applies to all normalization layers consistently.
    self.embeddings_norm = keras.layers.LayerNormalization(
        epsilon=norm_eps, name="embedding_norm"
    )
    self.embeddings_dropout = keras.layers.Dropout(
        dropout, name="embedding_dropout"
    )

    self.encoder_blocks = [
        keras_hub.layers.TransformerEncoder(
            intermediate_dim=mlp_dim,
            num_heads=num_heads,
            dropout=dropout,
            activation=utils.gelu_approximate,
            layer_norm_epsilon=norm_eps,
            normalize_first=False,
            dtype=self.dtype_policy,
            name=f"encoder_block_{i}",
        )
        for i in range(num_layers)
    ]
    if add_head:
      self.head = keras_hub.layers.MaskedLMHead(
          vocabulary_size=vocab_size,
          token_embedding=self.item_embedding,
          intermediate_activation=utils.gelu_approximate,
          kernel_initializer=keras.initializers.TruncatedNormal(stddev=0.02),
          dtype=self.dtype_policy,
          name="mlm_head",
      )
    else:
      self.head = None

    self._vocab_size = vocab_size
    self._model_dim = model_dim
    self._config = {
        "vocab_size": vocab_size,
        "max_positions": max_positions,
        "num_types": num_types,
        "model_dim": model_dim,
        "mlp_dim": mlp_dim,
        "num_heads": num_heads,
        "num_layers": num_layers,
        "dropout": dropout,
        "norm_eps": norm_eps,
        "add_head": add_head,
    }

  def build(self, inputs_shape: Sequence[int]):
    self.item_embedding.build(inputs_shape)
    if self.type_embedding is not None:
      self.type_embedding.build(inputs_shape)

    self.position_embedding.build((*inputs_shape, self._model_dim))
    self.embeddings_norm.build((*inputs_shape, self._model_dim))

    for encoder_block in self.encoder_blocks:
      encoder_block.build((*inputs_shape, self._model_dim))

    if self.head is not None:
      self.head.build((*inputs_shape, self._model_dim))

  def call(
      self,
      inputs: Tensor,
      type_ids: Tensor | None = None,
      padding_mask: Tensor | None = None,
      attention_mask: Tensor | None = None,
      mask_positions: Tensor | None = None,
      training: bool = False,
  ) -> Tensor:
    embeddings = self.item_embedding(inputs)
    if self.type_embedding is not None:
      if type_ids is None:
        raise ValueError(
            "`type_ids` cannot be None when `num_types` is not None."
        )
      embeddings += self.type_embedding(type_ids)
    embeddings += self.position_embedding(embeddings)

    embeddings = self.embeddings_norm(embeddings)
    embeddings = self.embeddings_dropout(embeddings, training=training)

    for encoder_block in self.encoder_blocks:
      embeddings = encoder_block(
          embeddings,
          padding_mask=padding_mask,
          attention_mask=attention_mask,
          training=training,
      )

    if self.head is None:
      return embeddings

    return self.head(embeddings, mask_positions)

  def compute_output_shape(
      self,
      inputs_shape: Sequence[int],
      mask_positions_shape: Tensor | None = None,
  ) -> Sequence[int | None]:
    if self.head is not None:
      if mask_positions_shape is None:
        raise ValueError(
            "`mask_positions_shape` cannot be None when `add_head` is True."
        )
      return (*inputs_shape[:-1], mask_positions_shape[-1], self._vocab_size)
    return (*inputs_shape, self._model_dim)

  def get_config(self) -> Mapping[str, Any]:
    return {**super().get_config(), **self._config}
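
A minimal usage sketch for the layer above (not part of the commit): the hyperparameters are illustrative, and it assumes `keras`, `keras_hub`, and `recml` are installed. With the default `add_head=True`, the layer returns masked-LM logits, matching its `compute_output_shape` contract:

import keras
from recml.layers.keras import bert4rec

# Illustrative hyperparameters. num_types is left as None, so no type
# embedding is created and `type_ids` is not needed in the call.
layer = bert4rec.BERT4Rec(
    vocab_size=500,
    max_positions=20,
    model_dim=32,
    mlp_dim=64,
    num_heads=4,
    num_layers=3,
    dropout=0.1,
)

item_ids = keras.ops.array([[1, 2, 3], [4, 5, 0]], "int32")
padding_mask = keras.ops.array([[1, 1, 1], [1, 1, 0]], "int32")
mask_positions = keras.ops.array([[0], [0]], "int32")  # One masked slot each.

logits = layer(
    item_ids,
    padding_mask=padding_mask,
    mask_positions=mask_positions,
)
# Logits have shape (batch, num_masked_positions, vocab_size) == (2, 1, 500).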

RecML/layers/keras/bert4rec_test.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
# Copyright 2024 RecML authors <recommendations-ml@google.com>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Keras architectures."""

from absl.testing import absltest
import keras
from keras.src import testing
from recml.layers.keras import bert4rec


class BERT4RecTest(testing.TestCase):

  def test_bert4rec(self):
    item_ids = keras.ops.array([[1, 2, 3], [4, 5, 0]], "int32")
    item_type_ids = keras.ops.array([[1, 2, 3], [4, 4, 0]], "int32")
    mask = keras.ops.array([[1, 1, 1], [1, 1, 0]], "int32")
    mask_positions = keras.ops.array([[0], [0]], "int32")
    init_kws = {
        "vocab_size": 500,
        "num_types": 5,
        "max_positions": 20,
        "model_dim": 32,
        "mlp_dim": 64,
        "num_heads": 4,
        "num_layers": 3,
        "dropout": 0.1,
    }

    tvars = (
        (500 * 32)  # Item embedding
        + (5 * 32)  # Type embedding
        + (20 * 32)  # Position embedding
        + (2 * 32)  # Embedding norm
        + 3  # 3 encoder blocks
        * (
            ((32 + 1) * 32 * 3 + (32 + 1) * 32)  # Attention QKVO
            + (2 * 32)  # Attention block norm
            + ((32 + 1) * 64)  # MLP inner projection
            + ((64 + 1) * 32)  # MLP outer projection
            + (2 * 32)  # MLP block norm
        )
        + (32 + 1) * 32  # Head projection
        + (2 * 32)  # Head norm
        + 500  # Head bias
    )
    seed_generators = 1 + 3 * 3  # One seed generator for each dropout layer.

    model = bert4rec.BERT4Rec(**init_kws)
    model.build(keras.ops.shape(item_ids))
    self.assertEqual(model.count_params(), tvars)

    self.run_layer_test(
        bert4rec.BERT4Rec,
        init_kwargs={**init_kws, "add_head": False},
        input_data=item_ids,
        call_kwargs={
            "type_ids": item_type_ids,
            "padding_mask": mask,
            "mask_positions": mask_positions,
        },
        expected_output_shape=(2, 3, 32),
        expected_output_dtype="float32",
        expected_num_seed_generators=seed_generators,
        run_training_check=False,
    )


if __name__ == "__main__":
  absltest.main()
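
As a cross-check on the `tvars` arithmetic in the test above, the same totals can be tallied term by term (a back-of-the-envelope sketch, not part of the commit):

# Re-derives the expected parameter count from the test's hyperparameters.
d, mlp, vocab, types, pos, layers = 32, 64, 500, 5, 20, 3

embeddings = vocab * d + types * d + pos * d + 2 * d  # 16,864
attention = (d + 1) * d * 3 + (d + 1) * d + 2 * d  # QKV + output proj + norm
feedforward = (d + 1) * mlp + (mlp + 1) * d + 2 * d  # Two projections + norm
blocks = layers * (attention + feedforward)  # 25,632
head = (d + 1) * d + 2 * d + vocab  # Projection + norm + output bias

print(embeddings + blocks + head)  # 44,116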
