diff --git a/.gitignore b/.gitignore index 7f756fb5a0..d646eb5568 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ tmp/* sources/* site/* +*.DS_Store *.pyc *.swp templates/examples/audio/* diff --git a/examples/keras_rs/dlrm.py b/examples/keras_rs/dlrm.py new file mode 100644 index 0000000000..af357f99e8 --- /dev/null +++ b/examples/keras_rs/dlrm.py @@ -0,0 +1,413 @@ +""" +Title: Ranking with Deep Learning Recommendation Model +Author: Harshith Kulkarni +Date created: 2025/06/02 +Last modified: 2025/06/20 +Description: Rank movies with DLRM with KerasRS +""" + +""" +## Introduction + +This tutorial demonstrates how to use the Deep Learning Recommendation Model (DLRM) to +effectively learn the relationships between items and user preferences using a +dot-product interaction mechanism. For more details, please refer to the +[DLRM](https://arxiv.org/pdf/1906.00091) paper. + +DLRM is designed to excel at capturing explicit, bounded-degree feature interactions and +is particularly effective at processing both categorical and continuous (sparse/dense) +input features. The architecture consists of three main components: dedicated input +layers to handle diverse features (typically embedding layers for categorical features), +a dot-product interaction layer to explicitly model feature interactions, and a +Multi-Layer Perceptron (MLP) to capture implicit feature relationships. + +The dot-product interaction layer lies at the heart of DLRM, efficiently computing +pairwise interactions between different feature embeddings. This contrasts with models +like Deep & Cross Network (DCN), which can treat elements within a feature vector as +independent units, potentially leading to a higher-dimensional space and increased +computational cost. The MLP is a standard feedforward network. The DLRM is formed by +combining the interaction layer and MLP. 
+ +The following image illustrates the DLRM architecture: + +![DLRM Architecture](https://raw.githubusercontent.com/kharshith-k/keras-io/refs/heads/keras-rs-examples/examples/keras_rs/img/dlrm/dlrm_architecture.gif) + + +Now that we have a foundational understanding of DLRM's architecture and key +characteristics, let's dive into the code. We will train a DLRM on a real-world dataset +to demonstrate its capability to learn meaningful feature interactions. Let's begin by +setting the backend to JAX and organizing our imports. +""" + +"""shell +!pip install keras-rs +""" + +import os + +os.environ["KERAS_BACKEND"] = "jax" # `"tensorflow"`/`"torch"` + +import keras +import matplotlib.pyplot as plt +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds +from mpl_toolkits.axes_grid1 import make_axes_locatable + +import keras_rs + +""" +Let's also define variables which will be reused throughout the example. +""" + +MOVIELENS_CONFIG = { + # features + "int_features": [ + "movie_id", + "user_id", + "user_gender", + "bucketized_user_age", + ], + "str_features": [ + "user_zip_code", + "user_occupation_text", + ], + # model + "embedding_dim": 8, + "deep_net_num_units": [192, 192, 192], + # training + "learning_rate": 1e-2, + "num_epochs": 25, + "batch_size": 8192, +} + + +""" +Here, we define helper functions for plotting the training metrics and for +visualising a feature-interaction matrix as a heatmap. Also, we define functions for +compiling, training and evaluating a given model, and for reporting its results.
+""" + + +def plot_training_metrics(history): + """Graphs all metrics tracked in the history object.""" + plt.figure(figsize=(12, 6)) + + for metric_name, metric_values in history.history.items(): + plt.plot(metric_values, label=metric_name.replace("_", " ").title()) + + plt.title("Metrics over Epochs") + plt.xlabel("Epoch") + plt.ylabel("Metric Value") + plt.legend() + plt.grid(True) + + +def visualize_layer(matrix, features, cmap=plt.cm.Blues): + """Plots `matrix` as a heatmap with a colorbar, using `features` as tick labels.""" + plt.figure(figsize=(9, 9)) + # `fignum=0` reuses the figure created above instead of opening a second one. + im = plt.matshow(matrix, cmap=cmap, fignum=0) + ax = plt.gca() + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="5%", pad=0.05) + plt.colorbar(im, cax=cax) + cax.tick_params(labelsize=10) + + # Set tick locations explicitly before setting labels + ax.set_xticks(np.arange(len(features))) + ax.set_yticks(np.arange(len(features))) + + ax.set_xticklabels(features, rotation=45, fontsize=5) + ax.set_yticklabels(features, fontsize=5) + + plt.show() + + +def train_and_evaluate( + learning_rate, + epochs, + train_data, + test_data, + model, + plot_metrics=False, +): + """Compiles and fits `model`, then returns its test RMSE and parameter count.""" + optimizer = keras.optimizers.AdamW(learning_rate=learning_rate) + loss = keras.losses.MeanSquaredError() + rmse = keras.metrics.RootMeanSquaredError() + model.compile( + optimizer=optimizer, + loss=loss, + metrics=[rmse], + ) + + history = model.fit( + train_data, + epochs=epochs, + verbose=1, + ) + if plot_metrics: + plot_training_metrics(history) + + results = model.evaluate(test_data, return_dict=True, verbose=1) + rmse_value = results["root_mean_squared_error"] + + return rmse_value, model.count_params() + + +def print_stats(rmse_list, num_params, model_name): + """Prints the mean RMSE (± std over several trials) and the parameter count."""
+ num_trials = len(rmse_list) + avg_rmse = np.mean(rmse_list) + std_rmse = np.std(rmse_list) + + if num_trials == 1: + print(f"{model_name}: RMSE = {avg_rmse}; #params = {num_params}") + else: + print(f"{model_name}: RMSE = {avg_rmse} ± {std_rmse}; #params = {num_params}") + + +""" +## Real-world example + +Let's use the MovieLens 100K dataset. This dataset is used to train models to +predict users' movie ratings, based on user-related features and movie-related +features. + +### Preparing the dataset + +The dataset processing steps here are similar to what's given in the +[basic ranking](/keras_rs/examples/basic_ranking/) +tutorial. Let's load the dataset, and keep only the useful columns. +""" + +ratings_ds = tfds.load("movielens/100k-ratings", split="train") +ratings_ds = ratings_ds.map( + lambda x: ( + { + "movie_id": int(x["movie_id"]), + "user_id": int(x["user_id"]), + "user_gender": int(x["user_gender"]), + "user_zip_code": x["user_zip_code"], + "user_occupation_text": x["user_occupation_text"], + "bucketized_user_age": int(x["bucketized_user_age"]), + }, + x["user_rating"], # label + ) +) + +""" +For every feature, let's get the list of unique values, i.e., vocabulary, so +that we can use that for the embedding layer. +""" + +vocabularies = {} +for feature_name in MOVIELENS_CONFIG["int_features"] + MOVIELENS_CONFIG["str_features"]: + vocabulary = ratings_ds.batch(10_000).map(lambda x, y: x[feature_name]) + vocabularies[feature_name] = np.unique(np.concatenate(list(vocabulary))) + +""" +One thing we need to do is to use `keras.layers.StringLookup` and +`keras.layers.IntegerLookup` to convert all features into indices, which can +then be fed into embedding layers.
+""" + +lookup_layers = {} +lookup_layers.update( + { + feature: keras.layers.IntegerLookup(vocabulary=vocabularies[feature]) + for feature in MOVIELENS_CONFIG["int_features"] + } +) +lookup_layers.update( + { + feature: keras.layers.StringLookup(vocabulary=vocabularies[feature]) + for feature in MOVIELENS_CONFIG["str_features"] + } +) + +ratings_ds = ratings_ds.map( + lambda x, y: ( + { + feature_name: lookup_layers[feature_name](x[feature_name]) + for feature_name in vocabularies + }, + y, + ) +) + +""" +Let's split our data into train and test sets. We also use `cache()` and +`prefetch()` for better performance. +""" +# A fixed, non-reshuffling order keeps the `take`/`skip` splits below disjoint. +ratings_ds = ratings_ds.shuffle(100_000, seed=42, reshuffle_each_iteration=False) + +train_ds = ( + ratings_ds.take(80_000) + .batch(MOVIELENS_CONFIG["batch_size"]) + .cache() + .prefetch(tf.data.AUTOTUNE) +) +test_ds = ( + ratings_ds.skip(80_000) + .take(20_000) + .batch(MOVIELENS_CONFIG["batch_size"]) + .cache() + .prefetch(tf.data.AUTOTUNE) +) + +""" +### Building the model + +The model will have embedding layers, followed by DotInteraction and feedforward +layers. +""" + + +class DLRM(keras.Model): + def __init__( + self, + dense_num_units_lst, + embedding_dim=MOVIELENS_CONFIG["embedding_dim"], + **kwargs, + ): + super().__init__(**kwargs) + + # Layers. + # One table per feature; the `+ 1` row is for the lookup layers' OOV index. + self.embedding_layers = [] + for feature_name, vocabulary in vocabularies.items(): + self.embedding_layers.append( + keras.layers.Embedding( + input_dim=len(vocabulary) + 1, + output_dim=embedding_dim, + ) + ) + + self.dot_layer = keras_rs.layers.DotInteraction() + + self.dense_layers = [] + for num_units in dense_num_units_lst: + self.dense_layers.append(keras.layers.Dense(num_units, activation="relu")) + + self.output_layer = keras.layers.Dense(1) + + # Attributes.
+ self.dense_num_units_lst = dense_num_units_lst + self.embedding_dim = embedding_dim + + def call(self, inputs): + embeddings = [] + for feature_name, embedding_layer in zip(vocabularies, self.embedding_layers): + embeddings.append(embedding_layer(inputs[feature_name])) + + # Pass the list of embeddings to the DotInteraction layer + x = self.dot_layer(embeddings) + + for dense_layer in self.dense_layers: + x = dense_layer(x) + + x = self.output_layer(x) + + return x + + +dot_network = DLRM( + dense_num_units_lst=MOVIELENS_CONFIG["deep_net_num_units"], + embedding_dim=MOVIELENS_CONFIG["embedding_dim"], +) +rmse, dot_network_num_params = train_and_evaluate( + learning_rate=MOVIELENS_CONFIG["learning_rate"], + epochs=MOVIELENS_CONFIG["num_epochs"], + train_data=train_ds, + test_data=test_ds, + model=dot_network, + plot_metrics=True, +) +print_stats( + rmse_list=[rmse], + num_params=dot_network_num_params, + model_name="Dot Network", +) + + +""" +### Visualizing feature interactions + +The DotInteraction layer itself doesn't have a conventional "weight" matrix like a Dense +layer. Instead, its function is to compute the dot product between the embedding vectors +of your features. + +To visualize the strength of these interactions, we can calculate a matrix representing +the pairwise interaction strength between all feature embeddings. A common way to do this +is to take the dot product of the embedding matrices for each pair of features and then +aggregate the result into a single value (like the mean of the absolute values) that +represents the overall interaction strength.
+""" + + +def get_dot_interaction_matrix(model, feature_names): + """Returns the matrix of mean absolute dot products between feature embeddings.""" + # One trained weight table per feature, in the order of `model.embedding_layers`. + embedding_weights = [layer.get_weights()[0] for layer in model.embedding_layers] + num_features = len(feature_names) + interaction_matrix = np.zeros((num_features, num_features)) + + # Fill every (i, j) cell, including the diagonal (a feature paired with itself) + for i in range(num_features): + for j in range(num_features): + # Dot product of every row of table i with every row of table j + interaction = np.dot(embedding_weights[i], embedding_weights[j].T) + # Collapse them to one scalar: the mean of the absolute values + interaction_strength = np.mean(np.abs(interaction)) + interaction_matrix[i, j] = interaction_strength + + return interaction_matrix + + +# Feature names in dict order, which is also the order DLRM builds its embeddings in +feature_names = list(vocabularies.keys()) + +# Calculate the interaction matrix +interaction_matrix = get_dot_interaction_matrix(dot_network, feature_names) + +# Visualize the matrix as a heatmap +print("\nVisualizing the feature interaction strengths:") +visualize_layer(interaction_matrix, feature_names) + +dlrm_rmse_list = [] + +for _ in range(20): + # Every trial trains a newly constructed model on the same train/test datasets. + dot_network = DLRM( + dense_num_units_lst=MOVIELENS_CONFIG["deep_net_num_units"], + embedding_dim=MOVIELENS_CONFIG["embedding_dim"], + ) + rmse, dot_network_num_params = train_and_evaluate( + learning_rate=MOVIELENS_CONFIG["learning_rate"], + epochs=MOVIELENS_CONFIG["num_epochs"], + train_data=train_ds, + test_data=test_ds, + model=dot_network, + ) + dlrm_rmse_list.append(rmse) + +print_stats( + rmse_list=dlrm_rmse_list, + num_params=dot_network_num_params, + model_name="Dot Network", +) + + +def plot_rmse(rmse_list, model_name): + plt.figure() + plt.plot(rmse_list) + plt.title(f"RMSE over trials for {model_name}") + plt.xlabel("Trial") + plt.ylabel("RMSE") + plt.show() + + +plot_rmse(dlrm_rmse_list, "Dot Network") diff --git
a/examples/keras_rs/img/dlrm/dlrm_architecture.gif b/examples/keras_rs/img/dlrm/dlrm_architecture.gif new file mode 100644 index 0000000000..7186adc365 Binary files /dev/null and b/examples/keras_rs/img/dlrm/dlrm_architecture.gif differ diff --git a/examples/keras_rs/ipynb/dlrm.ipynb b/examples/keras_rs/ipynb/dlrm.ipynb new file mode 100644 index 0000000000..c0b410ae69 --- /dev/null +++ b/examples/keras_rs/ipynb/dlrm.ipynb @@ -0,0 +1,2716 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "33c74727" + }, + "source": [ + "## Introduction\n", + "\n", + "This tutorial demonstrates how to use the Deep Learning Recommendation Model (DLRM) to effectively learn the relationships between items and user preferences using a dot-product interaction mechanism. For more details, please refer to the [DLRM](https://arxiv.org/pdf/1906.00091) paper.\n", + "\n", + "DLRM is designed to excel at capturing explicit, bounded-degree feature interactions and is particularly effective at processing both categorical and continuous (sparse/dense) input features. The architecture consists of three main components: dedicated input layers to handle diverse features (typically embedding layers for categorical features), a dot-product interaction layer to explicitly model feature interactions, and a Multi-Layer Perceptron (MLP) to capture implicit feature relationships.\n", + "\n", + "The dot-product interaction layer lies at the heart of DLRM, efficiently computing pairwise interactions between different feature embeddings. This contrasts with models like Deep & Cross Network (DCN), which can treat elements within a feature vector as independent units, potentially leading to a higher-dimensional space and increased computational cost. The MLP is a standard feedforward network. 
The DLRM is formed by combining the interaction layer and MLP.\n", + "\n", + "The following image illustrates the DLRM architecture:\n", + "\n", + "![DLRM Architecture](https://raw.githubusercontent.com/kharshith-k/keras-io/refs/heads/keras-rs-examples/examples/keras_rs/img/dlrm/dlrm_architecture.gif)\n", + "\n", + "\n", + "Now that we have a foundational understanding of DLRM's architecture and key characteristics, let's dive into the code. We will train a DLRM on a real-world dataset to demonstrate its capability to learn meaningful feature interactions. Let's begin by setting the backend to JAX and organizing our imports." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "blj4_0Wg62kR", + "outputId": "c386b17a-550c-4b21-dfc9-f7cf51b4a3a1" + }, + "outputs": [], + "source": [ + "!pip install keras-rs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "QeFXrN1962kT" + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"KERAS_BACKEND\"] = \"jax\" # `\"tensorflow\"`/`\"torch\"`\n", + "\n", + "import keras\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "import tensorflow_datasets as tfds\n", + "from mpl_toolkits.axes_grid1 import make_axes_locatable\n", + "\n", + "import keras_rs" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gHdT0D1762kT" + }, + "source": [ + "Let's also define variables which will be reused throughout the example." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GRJOAiMv62kT" + }, + "outputs": [], + "source": [ + "MOVIELENS_CONFIG = {\n", + " # features\n", + " \"int_features\": [\n", + " \"movie_id\",\n", + " \"user_id\",\n", + " \"user_gender\",\n", + " \"bucketized_user_age\",\n", + " ],\n", + " \"str_features\": [\n", + " \"user_zip_code\",\n", + " \"user_occupation_text\",\n", + " ],\n", + " # model\n", + " \"embedding_dim\": 8,\n", + " \"deep_net_num_units\": [192, 192, 192],\n", + " # training\n", + " \"learning_rate\": 1e-2,\n", + " \"num_epochs\": 25,\n", + " \"batch_size\": 8192,\n", + "}\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_6J6XFNC62kT" + }, + "source": [ + "Here, we define helper functions for plotting the training metrics and for\n", + "visualising a feature-interaction matrix as a heatmap. Also, we define functions for\n", + "compiling, training and evaluating a given model, and for reporting its results." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jyojSTKw62kT" + }, + "outputs": [], + "source": [ + "def plot_training_metrics(history):\n", + " \"\"\"Graphs all metrics tracked in the history object.\"\"\"\n", + " plt.figure(figsize=(12, 6))\n", + "\n", + " for metric_name, metric_values in history.history.items():\n", + " plt.plot(metric_values, label=metric_name.replace('_', ' ').title())\n", + "\n", + " plt.title('Metrics over Epochs')\n", + " plt.xlabel('Epoch')\n", + " plt.ylabel('Metric Value')\n", + " plt.legend()\n", + " plt.grid(True)\n", + "\n", + "def visualize_layer(matrix, features, cmap=plt.cm.Blues):\n", + " plt.figure(figsize=(9, 9))\n", + "\n", + " im = plt.matshow(matrix, cmap=cmap)\n", + "\n", + " ax = plt.gca()\n", + " divider = make_axes_locatable(plt.gca())\n", + " cax = divider.append_axes(\"right\", size=\"5%\", pad=0.05)\n", + " plt.colorbar(im, cax=cax)\n", + " cax.tick_params(labelsize=10)\n", + "\n", + " # Set tick locations explicitly before setting
labels\n", + " ax.set_xticks(np.arange(len(features)))\n", + " ax.set_yticks(np.arange(len(features)))\n", + "\n", + " ax.set_xticklabels(features, rotation=45, fontsize=5)\n", + " ax.set_yticklabels(features, fontsize=5)\n", + "\n", + " plt.show()\n", + "\n", + "\n", + "def train_and_evaluate(\n", + " learning_rate,\n", + " epochs,\n", + " train_data,\n", + " test_data,\n", + " model,\n", + " plot_metrics=False,\n", + "):\n", + " optimizer = keras.optimizers.AdamW(learning_rate=learning_rate)\n", + " loss = keras.losses.MeanSquaredError()\n", + " rmse = keras.metrics.RootMeanSquaredError()\n", + "\n", + " model.compile(\n", + " optimizer=optimizer,\n", + " loss=loss,\n", + " metrics=[rmse],\n", + " )\n", + "\n", + " history = model.fit(\n", + " train_data,\n", + " epochs=epochs,\n", + " verbose=1,\n", + " )\n", + " if plot_metrics:\n", + " plot_training_metrics(history)\n", + "\n", + " results = model.evaluate(test_data, return_dict=True, verbose=1)\n", + " rmse_value = results[\"root_mean_squared_error\"]\n", + "\n", + " return rmse_value, model.count_params()\n", + "\n", + "\n", + "def print_stats(rmse_list, num_params, model_name):\n", + " # Report metrics.\n", + " num_trials = len(rmse_list)\n", + " avg_rmse = np.mean(rmse_list)\n", + " std_rmse = np.std(rmse_list)\n", + "\n", + " if num_trials == 1:\n", + " print(f\"{model_name}: RMSE = {avg_rmse}; #params = {num_params}\")\n", + " else:\n", + " print(f\"{model_name}: RMSE = {avg_rmse} ± {std_rmse}; #params = {num_params}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tVHJBIJ_62kV" + }, + "source": [ + "## Real-world example\n", + "\n", + "Let's use the MovieLens 100K dataset. 
This dataset is used to train models to\n", + "predict users' movie ratings, based on user-related features and movie-related\n", + "features.\n", + "\n", + "### Preparing the dataset\n", + "\n", + "The dataset processing steps here are similar to what's given in the\n", + "[basic ranking](/keras_rs/examples/basic_ranking/)\n", + "tutorial. Let's load the dataset, and keep only the useful columns." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 187, + "referenced_widgets": [ + "58dfd53b556843dabaecd6cf2d6935ef", + "3d29ef3e67884c69866b982e4039d58c", + "4d9b8158ca2e4a3dba9cd223836238ff", + "efae889f54364be4b29cf1dbc7de5a09", + "6204ff5ab70543c184e221f9c8f241b7", + "612dc1537523492abf3dd0ac3fff16a1", + "5d089130d4474d33a2c2a27315c0f2af", + "8685726324d7451c89cc81048e3cc519", + "67bced2a98ef4a35a3138e2397049489", + "93911ffe44c94b149031cce6606355db", + "423b08111fd444b5b8739d6bb5c64393", + "f3c1148efd194daeae69b8dbcc24c9d5", + "176df16de19f4d32ac1b85c5c2082b40", + "b1fe6058815a412d968e69beea11ecb5", + "06299c873ad44899a078291779f6e8e4", + "ac6c0a7f4a574b2880c6e7a76c54f82b", + "2fd3f84a72f04a809f68af13b7430895", + "749464bb77044f2ba799e1c3595d7b37", + "0d6cb36062bc48f792c40439ca1d1cee", + "d04e7b6888de4ca6b5eb77145d5efdde", + "e6718c0203aa4e7894b1960fa83744b8", + "a5abded0002a4afab3e20a89fe56c893", + "dc7d6718b55b499ebde1bd8c3b77c108", + "799bd47f82f2439ab87a96c9ceabc246", + "069c8512afb24fb9ba180bd7d5f86122", + "d9fc394eeabe49b79910eec9d9b93ceb", + "26a524fd760048149ecaf15ed9e2cc5d", + "bba0acfe88c640188a2b45eb174962a6", + "9dfa77aae53d46b0a1ad0028a2c87c5e", + "8eb493a9d7fe42249acbad63ce15ab0f", + "2a7644ab2b7a45b5be4bb8980192be7f", + "bb222800db944751a360d3da73211a60", + "5a498fb2073245ef8b04911efc6d2d81", + "024cbbb854c943319b9be0d743e09f84", + "d76021431ab94848ab28d21553c37237", + "5ec7292979364685a75a01df20852dee", + "c51f2c0b4b684266881cc78606f91c48", + 
"129109799f124e80bf5c36fcb74fbc8d", + "874e6e7899994faab2341102a9f407de", + "3c94a1d214364e1b9fd7c3ff52a03026", + "5643b587cb634078bcf7b3cc637c23e0", + "0d1862d8e67e45889c7fde05787d409e", + "92d2807a12cb4a9182c70e061bc3598c", + "a062a4d30c1147feb8faf96fba1af0df", + "6c26c23d5e9544e8b114c816ad667320", + "4643620958e34bf685682c3f44780370", + "d86b9434231b4cd28ce1611e778454c2", + "d5804d8334c54434a447f19188ee751f", + "a1fefc93f4ca4fc9866774e28308d62d", + "1115049423874cd78f3b85184f8d7e15", + "7986925eb5624324aee93da08a630099", + "2e706f8ef4d34b39b092199c17262dd7", + "a6000ca73675474ea7c2b1b7e6c4186c", + "25df22d29aab44578894f0aa0ce53aa4", + "6898fd5bc3634fe681a96bd01b629b53", + "ec2a121d9157485d9002795e0c86d471", + "6a0a7673658a4d77921bd05493ad50b6", + "f0b8e85a1e854efa8a43b9ada4109915", + "68bb8313572f4baf853e4313627cab07", + "71f75584dc4041a0adc23a87dbedc2c2", + "12605d31b53048008fabf7e93ebab4c5", + "f4a0ae32f65a45e4b2015e164c7963be", + "cf2d1c56ac624e7b865769436a3e1f33", + "4c3c33ad6f484d30b091858c13ded0da", + "78cdea4a88414d109466b4b1ad60ad18", + "11ac95d724414a02aec0f5424b271b41" + ] + }, + "id": "kipGDLuZ62kW", + "outputId": "389e6a82-e5fa-4f31-8bd2-1ee107c9aed8" + }, + "outputs": [], + "source": [ + "ratings_ds = tfds.load(\"movielens/100k-ratings\", split=\"train\")\n", + "ratings_ds = ratings_ds.map(\n", + " lambda x: (\n", + " {\n", + " \"movie_id\": int(x[\"movie_id\"]),\n", + " \"user_id\": int(x[\"user_id\"]),\n", + " \"user_gender\": int(x[\"user_gender\"]),\n", + " \"user_zip_code\": x[\"user_zip_code\"],\n", + " \"user_occupation_text\": x[\"user_occupation_text\"],\n", + " \"bucketized_user_age\": int(x[\"bucketized_user_age\"]),\n", + " },\n", + " x[\"user_rating\"], # label\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3DrNFuZi62kW" + }, + "source": [ + "For every feature, let's get the list of unique values, i.e., vocabulary, so\n", + "that we can use that for the embedding layer." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pczAh0MV62kW" + }, + "outputs": [], + "source": [ + "vocabularies = {}\n", + "for feature_name in MOVIELENS_CONFIG[\"int_features\"] + MOVIELENS_CONFIG[\"str_features\"]:\n", + " vocabulary = ratings_ds.batch(10_000).map(lambda x, y: x[feature_name])\n", + " vocabularies[feature_name] = np.unique(np.concatenate(list(vocabulary)))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1PLhYqF062kW" + }, + "source": [ + "One thing we need to do is to use `keras.layers.StringLookup` and\n", + "`keras.layers.IntegerLookup` to convert all features into indices, which can\n", + "then be fed into embedding layers." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pLMvMAIF62kW" + }, + "outputs": [], + "source": [ + "lookup_layers = {}\n", + "lookup_layers.update(\n", + " {\n", + " feature: keras.layers.IntegerLookup(vocabulary=vocabularies[feature])\n", + " for feature in MOVIELENS_CONFIG[\"int_features\"]\n", + " }\n", + ")\n", + "lookup_layers.update(\n", + " {\n", + " feature: keras.layers.StringLookup(vocabulary=vocabularies[feature])\n", + " for feature in MOVIELENS_CONFIG[\"str_features\"]\n", + " }\n", + ")\n", + "\n", + "ratings_ds = ratings_ds.map(\n", + " lambda x, y: (\n", + " {\n", + " feature_name: lookup_layers[feature_name](x[feature_name])\n", + " for feature_name in vocabularies\n", + " },\n", + " y,\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XBKk-mX962kW" + }, + "source": [ + "Let's split our data into train and test sets. We also use `cache()` and\n", + "`prefetch()` for better performance." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Ha3v4rHN62kW" + }, + "outputs": [], + "source": [ + "ratings_ds = ratings_ds.shuffle(100_000)\n", + "\n", + "train_ds = (\n", + " ratings_ds.take(80_000)\n", + " .batch(MOVIELENS_CONFIG[\"batch_size\"])\n", + " .cache()\n", + " .prefetch(tf.data.AUTOTUNE)\n", + ")\n", + "test_ds = (\n", + " ratings_ds.skip(80_000)\n", + " .batch(MOVIELENS_CONFIG[\"batch_size\"])\n", + " .take(20_000)\n", + " .cache()\n", + " .prefetch(tf.data.AUTOTUNE)\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RXuXvsl062kW" + }, + "source": [ + "### Building the model\n", + "\n", + "The model will have embedding layers, followed by DotInteraction and feedforward\n", + "layers." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ww3ZJ6R062kW" + }, + "outputs": [], + "source": [ + "class DLRM(keras.Model):\n", + " def __init__(\n", + " self,\n", + " dense_num_units_lst,\n", + " embedding_dim=MOVIELENS_CONFIG[\"embedding_dim\"],\n", + " **kwargs,\n", + " ):\n", + " super().__init__(**kwargs)\n", + "\n", + " # Layers.\n", + "\n", + " self.embedding_layers = []\n", + " for feature_name, vocabulary in vocabularies.items():\n", + " self.embedding_layers.append(\n", + " keras.layers.Embedding(\n", + " input_dim=len(vocabulary) + 1,\n", + " output_dim=embedding_dim,\n", + " )\n", + " )\n", + "\n", + " self.dot_layer = keras_rs.layers.DotInteraction()\n", + "\n", + " self.dense_layers = []\n", + " for num_units in dense_num_units_lst:\n", + " self.dense_layers.append(keras.layers.Dense(num_units, activation=\"relu\"))\n", + "\n", + " self.output_layer = keras.layers.Dense(1)\n", + "\n", + " # Attributes.\n", + " self.dense_num_units_lst = dense_num_units_lst\n", + " self.embedding_dim = embedding_dim\n", + "\n", + " def call(self, inputs):\n", + " embeddings = []\n", + " for feature_name, embedding_layer in zip(vocabularies, self.embedding_layers):\n", 
+ " embeddings.append(embedding_layer(inputs[feature_name]))\n", + "\n", + " # Pass the list of embeddings to the DotInteraction layer\n", + " x = self.dot_layer(embeddings)\n", + "\n", + " for dense_layer in self.dense_layers:\n", + " x = dense_layer(x)\n", + "\n", + " x = self.output_layer(x)\n", + "\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "OAoTyyzi_Nau", + "outputId": "2f469cae-09e1-45fd-853d-79dffc1d5251" + }, + "outputs": [], + "source": [ + "dot_network = DLRM(\n", + " dense_num_units_lst=MOVIELENS_CONFIG[\"deep_net_num_units\"],\n", + " embedding_dim=MOVIELENS_CONFIG[\"embedding_dim\"],\n", + ")\n", + "rmse, dot_network_num_params = train_and_evaluate(\n", + " learning_rate=MOVIELENS_CONFIG[\"learning_rate\"],\n", + " epochs=MOVIELENS_CONFIG[\"num_epochs\"],\n", + " train_data=train_ds,\n", + " test_data=test_ds,\n", + " model= dot_network,\n", + " plot_metrics=True\n", + ")\n", + "print_stats(\n", + " rmse_list=[rmse],\n", + " num_params=dot_network_num_params,\n", + " model_name=\"Dot Network\",\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "E68GrcDN62kW" + }, + "source": [ + "### Visualizing feature interactions\n", + "\n", + "The DotInteraction layer itself doesn't have a conventional \"weight\" matrix like a Dense layer. Instead, its function is to compute the dot product between the embedding vectors of your features.\n", + "\n", + "To visualize the strength of these interactions, we can calculate a matrix representing the pairwise interaction strength between all feature embeddings. A common way to do this is to take the dot product of the embedding matrices for each pair of features and then aggregate the result into a single value (like the mean of the absolute values) that represents the overall interaction strength." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "R8WPVVn-Jq81" + }, + "outputs": [], + "source": [ + "def get_dot_interaction_matrix(model, feature_names):\n", + " # Extract the trained embedding weights from the model\n", + " embedding_weights = [layer.get_weights()[0] for layer in model.embedding_layers]\n", + "\n", + " num_features = len(feature_names)\n", + " interaction_matrix = np.zeros((num_features, num_features))\n", + "\n", + " # Iterate through each pair of features to calculate their interaction strength\n", + " for i in range(num_features):\n", + " for j in range(num_features):\n", + " # Calculate the dot product between the two embedding matrices\n", + " interaction = np.dot(embedding_weights[i], embedding_weights[j].T)\n", + " # Take the mean of the absolute values as a measure of interaction strength\n", + " interaction_strength = np.mean(np.abs(interaction))\n", + " interaction_matrix[i, j] = interaction_strength\n", + "\n", + " return interaction_matrix" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 512 + }, + "id": "NH-VtGrMJtAt", + "outputId": "63eac5d9-8e3e-4b40-bc29-1d51d9827a8d" + }, + "outputs": [], + "source": [ + "# Get the list of feature names in the correct order\n", + "feature_names = list(vocabularies.keys())\n", + "\n", + "# Calculate the interaction matrix\n", + "interaction_matrix = get_dot_interaction_matrix(dot_network, feature_names)\n", + "\n", + "# Visualize the matrix as a heatmap\n", + "print(\"\\nVisualizing the feature interaction strengths:\")\n", + "visualize_layer(interaction_matrix, feature_names)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "IbpjtvnQ62kW" + }, + "outputs": [], + "source": [ + "dlrm_rmse_list = []\n", + "\n", + "for _ in range(20):\n", + "\n", + " dot_network = DLRM(\n", + " 
dense_num_units_lst=MOVIELENS_CONFIG[\"deep_net_num_units\"],\n", + " embedding_dim=MOVIELENS_CONFIG[\"embedding_dim\"],\n", + " )\n", + " rmse, dot_network_num_params = train_and_evaluate(\n", + " learning_rate=MOVIELENS_CONFIG[\"learning_rate\"],\n", + " epochs=MOVIELENS_CONFIG[\"num_epochs\"],\n", + " train_data=train_ds,\n", + " test_data=test_ds,\n", + " model= dot_network,\n", + " )\n", + " dlrm_rmse_list.append(rmse)\n", + "\n", + "print_stats(\n", + " rmse_list=dlrm_rmse_list,\n", + " num_params=dot_network_num_params,\n", + " model_name=\"Dot Network\",\n", + ")\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8mskfVqa-uGx" + }, + "outputs": [], + "source": [ + "def plot_rmse(rmse_list, model_name):\n", + " plt.figure()\n", + " plt.plot(rmse_list)\n", + " plt.title(f'RMSE over trials for {model_name}')\n", + " plt.xlabel('Trial')\n", + " plt.ylabel('RMSE')\n", + " plt.show()\n", + "\n", + "plot_rmse(dlrm_rmse_list, \"Dot Network\")" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "024cbbb854c943319b9be0d743e09f84": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + 
"IPY_MODEL_d76021431ab94848ab28d21553c37237", + "IPY_MODEL_5ec7292979364685a75a01df20852dee", + "IPY_MODEL_c51f2c0b4b684266881cc78606f91c48" + ], + "layout": "IPY_MODEL_129109799f124e80bf5c36fcb74fbc8d" + } + }, + "06299c873ad44899a078291779f6e8e4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e6718c0203aa4e7894b1960fa83744b8", + "placeholder": "​", + "style": "IPY_MODEL_a5abded0002a4afab3e20a89fe56c893", + "value": " 4/4 [00:00<00:00,  6.70 MiB/s]" + } + }, + "069c8512afb24fb9ba180bd7d5f86122": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_8eb493a9d7fe42249acbad63ce15ab0f", + "max": 1, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_2a7644ab2b7a45b5be4bb8980192be7f", + "value": 1 + } + }, + "0d1862d8e67e45889c7fde05787d409e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + 
"_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "0d6cb36062bc48f792c40439ca1d1cee": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "20px" + } + }, + "1115049423874cd78f3b85184f8d7e15": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + 
"grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "11ac95d724414a02aec0f5424b271b41": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "12605d31b53048008fabf7e93ebab4c5": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": 
null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "129109799f124e80bf5c36fcb74fbc8d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": null + } + }, + "176df16de19f4d32ac1b85c5c2082b40": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + 
"_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2fd3f84a72f04a809f68af13b7430895", + "placeholder": "​", + "style": "IPY_MODEL_749464bb77044f2ba799e1c3595d7b37", + "value": "Dl Size...: 100%" + } + }, + "25df22d29aab44578894f0aa0ce53aa4": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "26a524fd760048149ecaf15ed9e2cc5d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + 
"bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2a7644ab2b7a45b5be4bb8980192be7f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "2e706f8ef4d34b39b092199c17262dd7": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + 
"grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "20px" + } + }, + "2fd3f84a72f04a809f68af13b7430895": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "3c94a1d214364e1b9fd7c3ff52a03026": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + 
"_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "3d29ef3e67884c69866b982e4039d58c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_612dc1537523492abf3dd0ac3fff16a1", + "placeholder": "​", + "style": "IPY_MODEL_5d089130d4474d33a2c2a27315c0f2af", + "value": "Dl Completed...: 100%" + } + }, + "423b08111fd444b5b8739d6bb5c64393": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "4643620958e34bf685682c3f44780370": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_1115049423874cd78f3b85184f8d7e15", + "placeholder": "​", + "style": "IPY_MODEL_7986925eb5624324aee93da08a630099", + "value": "Generating train 
examples...: " + } + }, + "4c3c33ad6f484d30b091858c13ded0da": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "4d9b8158ca2e4a3dba9cd223836238ff": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_8685726324d7451c89cc81048e3cc519", + "max": 1, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_67bced2a98ef4a35a3138e2397049489", + "value": 1 + } + }, + "5643b587cb634078bcf7b3cc637c23e0": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + 
"grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "58dfd53b556843dabaecd6cf2d6935ef": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_3d29ef3e67884c69866b982e4039d58c", + "IPY_MODEL_4d9b8158ca2e4a3dba9cd223836238ff", + "IPY_MODEL_efae889f54364be4b29cf1dbc7de5a09" + ], + "layout": "IPY_MODEL_6204ff5ab70543c184e221f9c8f241b7" + } + }, + "5a498fb2073245ef8b04911efc6d2d81": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "5d089130d4474d33a2c2a27315c0f2af": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + 
"_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "5ec7292979364685a75a01df20852dee": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_5643b587cb634078bcf7b3cc637c23e0", + "max": 1, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_0d1862d8e67e45889c7fde05787d409e", + "value": 1 + } + }, + "612dc1537523492abf3dd0ac3fff16a1": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + 
"top": null, + "visibility": null, + "width": null + } + }, + "6204ff5ab70543c184e221f9c8f241b7": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "67bced2a98ef4a35a3138e2397049489": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "6898fd5bc3634fe681a96bd01b629b53": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": 
"@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "68bb8313572f4baf853e4313627cab07": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_78cdea4a88414d109466b4b1ad60ad18", + "placeholder": "​", + "style": "IPY_MODEL_11ac95d724414a02aec0f5424b271b41", + "value": " 0/100000 [00:00<?, ? examples/s]" + } + }, + "6a0a7673658a4d77921bd05493ad50b6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_12605d31b53048008fabf7e93ebab4c5", + "placeholder": "​", + "style": "IPY_MODEL_f4a0ae32f65a45e4b2015e164c7963be", + "value": "Shuffling /root/tensorflow_datasets/movielens/100k-ratings/incomplete.S3UIO5_0.1.1/movielens-train.tfrecord*...:   0%" + } + }, + "6c26c23d5e9544e8b114c816ad667320": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": 
"HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_4643620958e34bf685682c3f44780370", + "IPY_MODEL_d86b9434231b4cd28ce1611e778454c2", + "IPY_MODEL_d5804d8334c54434a447f19188ee751f" + ], + "layout": "IPY_MODEL_a1fefc93f4ca4fc9866774e28308d62d" + } + }, + "71f75584dc4041a0adc23a87dbedc2c2": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": null + } + }, + "749464bb77044f2ba799e1c3595d7b37": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": 
"@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "78cdea4a88414d109466b4b1ad60ad18": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "7986925eb5624324aee93da08a630099": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "799bd47f82f2439ab87a96c9ceabc246": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { 
+ "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_bba0acfe88c640188a2b45eb174962a6", + "placeholder": "​", + "style": "IPY_MODEL_9dfa77aae53d46b0a1ad0028a2c87c5e", + "value": "Extraction completed...: 100%" + } + }, + "8685726324d7451c89cc81048e3cc519": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "20px" + } + }, + "874e6e7899994faab2341102a9f407de": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + 
"_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8eb493a9d7fe42249acbad63ce15ab0f": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + 
"min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "20px" + } + }, + "92d2807a12cb4a9182c70e061bc3598c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "93911ffe44c94b149031cce6606355db": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": 
null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9dfa77aae53d46b0a1ad0028a2c87c5e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "a062a4d30c1147feb8faf96fba1af0df": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "a1fefc93f4ca4fc9866774e28308d62d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": 
"@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": null + } + }, + "a5abded0002a4afab3e20a89fe56c893": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "a6000ca73675474ea7c2b1b7e6c4186c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "ac6c0a7f4a574b2880c6e7a76c54f82b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", 
+ "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b1fe6058815a412d968e69beea11ecb5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_0d6cb36062bc48f792c40439ca1d1cee", + "max": 1, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_d04e7b6888de4ca6b5eb77145d5efdde", + "value": 1 + } + }, + "bb222800db944751a360d3da73211a60": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", 
+ "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "bba0acfe88c640188a2b45eb174962a6": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + 
"left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c51f2c0b4b684266881cc78606f91c48": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_92d2807a12cb4a9182c70e061bc3598c", + "placeholder": "​", + "style": "IPY_MODEL_a062a4d30c1147feb8faf96fba1af0df", + "value": " 1/1 [00:28<00:00, 28.00s/ splits]" + } + }, + "cf2d1c56ac624e7b865769436a3e1f33": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + 
"min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d04e7b6888de4ca6b5eb77145d5efdde": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "d5804d8334c54434a447f19188ee751f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_25df22d29aab44578894f0aa0ce53aa4", + "placeholder": "​", + "style": "IPY_MODEL_6898fd5bc3634fe681a96bd01b629b53", + "value": " 97781/? 
[00:27<00:00, 3799.36 examples/s]" + } + }, + "d76021431ab94848ab28d21553c37237": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_874e6e7899994faab2341102a9f407de", + "placeholder": "​", + "style": "IPY_MODEL_3c94a1d214364e1b9fd7c3ff52a03026", + "value": "Generating splits...: 100%" + } + }, + "d86b9434231b4cd28ce1611e778454c2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "info", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2e706f8ef4d34b39b092199c17262dd7", + "max": 1, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_a6000ca73675474ea7c2b1b7e6c4186c", + "value": 1 + } + }, + "d9fc394eeabe49b79910eec9d9b93ceb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_bb222800db944751a360d3da73211a60", + "placeholder": 
"​", + "style": "IPY_MODEL_5a498fb2073245ef8b04911efc6d2d81", + "value": " 23/23 [00:00<00:00, 39.76 file/s]" + } + }, + "dc7d6718b55b499ebde1bd8c3b77c108": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_799bd47f82f2439ab87a96c9ceabc246", + "IPY_MODEL_069c8512afb24fb9ba180bd7d5f86122", + "IPY_MODEL_d9fc394eeabe49b79910eec9d9b93ceb" + ], + "layout": "IPY_MODEL_26a524fd760048149ecaf15ed9e2cc5d" + } + }, + "e6718c0203aa4e7894b1960fa83744b8": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + 
"visibility": null, + "width": null + } + }, + "ec2a121d9157485d9002795e0c86d471": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_6a0a7673658a4d77921bd05493ad50b6", + "IPY_MODEL_f0b8e85a1e854efa8a43b9ada4109915", + "IPY_MODEL_68bb8313572f4baf853e4313627cab07" + ], + "layout": "IPY_MODEL_71f75584dc4041a0adc23a87dbedc2c2" + } + }, + "efae889f54364be4b29cf1dbc7de5a09": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_93911ffe44c94b149031cce6606355db", + "placeholder": "​", + "style": "IPY_MODEL_423b08111fd444b5b8739d6bb5c64393", + "value": " 1/1 [00:00<00:00,  1.60 url/s]" + } + }, + "f0b8e85a1e854efa8a43b9ada4109915": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_cf2d1c56ac624e7b865769436a3e1f33", + "max": 
100000, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_4c3c33ad6f484d30b091858c13ded0da", + "value": 100000 + } + }, + "f3c1148efd194daeae69b8dbcc24c9d5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_176df16de19f4d32ac1b85c5c2082b40", + "IPY_MODEL_b1fe6058815a412d968e69beea11ecb5", + "IPY_MODEL_06299c873ad44899a078291779f6e8e4" + ], + "layout": "IPY_MODEL_ac6c0a7f4a574b2880c6e7a76c54f82b" + } + }, + "f4a0ae32f65a45e4b2015e164c7963be": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}