diff --git a/guides/custom_train_step_in_jax.py b/guides/custom_train_step_in_jax.py
index 6cdd8b912b..c4bcf0a8ce 100644
--- a/guides/custom_train_step_in_jax.py
+++ b/guides/custom_train_step_in_jax.py
@@ -60,17 +60,10 @@
 - We implement a fully-stateless `compute_loss_and_updates()` method
 to compute the loss as well as the updated values for the non-trainable
 variables of the model. Internally, it calls `stateless_call()` and
-the built-in `compute_loss()`.
+the built-in `stateless_compute_loss()`.
 - We implement a fully-stateless `train_step()` method to compute current
 metric values (including the loss) as well as updated values for the
 trainable variables, the optimizer variables, and the metric variables.
-
-Note that you can also take into account the `sample_weight` argument by:
-
-- Unpacking the data as `x, y, sample_weight = data`
-- Passing `sample_weight` to `compute_loss()`
-- Passing `sample_weight` alongside `y` and `y_pred`
-to metrics in `stateless_update_state()`
 """
 
 
@@ -79,8 +72,10 @@ def compute_loss_and_updates(
         self,
         trainable_variables,
         non_trainable_variables,
+        metrics_variables,
         x,
         y,
+        sample_weight,
         training=False,
     ):
         y_pred, non_trainable_variables = self.stateless_call(
@@ -89,8 +84,21 @@ def compute_loss_and_updates(
             x,
             training=training,
         )
-        loss = self.compute_loss(x, y, y_pred)
-        return loss, (y_pred, non_trainable_variables)
+        loss, (
+            trainable_variables,
+            non_trainable_variables,
+            metrics_variables,
+        ) = self.stateless_compute_loss(
+            trainable_variables,
+            non_trainable_variables,
+            metrics_variables,
+            x=x,
+            y=y,
+            y_pred=y_pred,
+            sample_weight=sample_weight,
+            training=training,
+        )
+        return loss, (y_pred, non_trainable_variables, metrics_variables)
 
     def train_step(self, state, data):
         (
@@ -99,25 +107,24 @@ def train_step(self, state, data):
             optimizer_variables,
             metrics_variables,
         ) = state
-        x, y = data
+        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
 
         # Get the gradient function.
         grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)
 
         # Compute the gradients.
-        (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
+        (loss, (y_pred, non_trainable_variables, metrics_variables)), grads = grad_fn(
             trainable_variables,
             non_trainable_variables,
+            metrics_variables,
             x,
             y,
+            sample_weight,
             training=True,
         )
 
         # Update trainable variables and optimizer variables.
-        (
-            trainable_variables,
-            optimizer_variables,
-        ) = self.optimizer.stateless_apply(
+        trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
             optimizer_variables, grads, trainable_variables
         )
 
@@ -129,10 +136,12 @@ def train_step(self, state, data):
                 len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
             ]
             if metric.name == "loss":
-                this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
+                this_metric_vars = metric.stateless_update_state(
+                    this_metric_vars, loss, sample_weight=sample_weight
+                )
             else:
                 this_metric_vars = metric.stateless_update_state(
-                    this_metric_vars, y, y_pred
+                    this_metric_vars, y, y_pred, sample_weight=sample_weight
                 )
             logs[metric.name] = metric.stateless_result(this_metric_vars)
             new_metrics_vars += this_metric_vars
@@ -186,6 +195,7 @@ def compute_loss_and_updates(
         non_trainable_variables,
         x,
         y,
+        sample_weight,
         training=False,
     ):
         y_pred, non_trainable_variables = self.stateless_call(
@@ -194,7 +204,7 @@ def compute_loss_and_updates(
             x,
             training=training,
         )
-        loss = self.loss_fn(y, y_pred)
+        loss = self.loss_fn(y, y_pred, sample_weight=sample_weight)
         return loss, (y_pred, non_trainable_variables)
 
     def train_step(self, state, data):
@@ -204,7 +214,7 @@ def train_step(self, state, data):
             optimizer_variables,
             metrics_variables,
         ) = state
-        x, y = data
+        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
 
         # Get the gradient function.
         grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)
@@ -215,14 +225,12 @@ def train_step(self, state, data):
             non_trainable_variables,
             x,
             y,
+            sample_weight,
             training=True,
         )
 
         # Update trainable variables and optimizer variables.
-        (
-            trainable_variables,
-            optimizer_variables,
-        ) = self.optimizer.stateless_apply(
+        trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
             optimizer_variables, grads, trainable_variables
         )
 
@@ -231,10 +239,10 @@ def train_step(self, state, data):
         mae_metric_vars = metrics_variables[len(self.loss_tracker.variables) :]
 
         loss_tracker_vars = self.loss_tracker.stateless_update_state(
-            loss_tracker_vars, loss
+            loss_tracker_vars, loss, sample_weight=sample_weight
        )
         mae_metric_vars = self.mae_metric.stateless_update_state(
-            mae_metric_vars, y, y_pred
+            mae_metric_vars, y, y_pred, sample_weight=sample_weight
         )
 
         logs = {}
@@ -287,7 +295,7 @@ def metrics(self):
 class CustomModel(keras.Model):
     def test_step(self, state, data):
         # Unpack the data.
-        x, y = data
+        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
         (
             trainable_variables,
             non_trainable_variables,
@@ -301,21 +309,37 @@ def test_step(self, state, data):
             x,
             training=False,
         )
-        loss = self.compute_loss(x, y, y_pred)
+        loss, (
+            trainable_variables,
+            non_trainable_variables,
+            metrics_variables,
+        ) = self.stateless_compute_loss(
+            trainable_variables,
+            non_trainable_variables,
+            metrics_variables,
+            x=x,
+            y=y,
+            y_pred=y_pred,
+            sample_weight=sample_weight,
+            training=False,
+        )
 
         # Update metrics.
         new_metrics_vars = []
+        logs = {}
         for metric in self.metrics:
             this_metric_vars = metrics_variables[
                 len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
             ]
             if metric.name == "loss":
-                this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
+                this_metric_vars = metric.stateless_update_state(
+                    this_metric_vars, loss, sample_weight=sample_weight
+                )
             else:
                 this_metric_vars = metric.stateless_update_state(
-                    this_metric_vars, y, y_pred
+                    this_metric_vars, y, y_pred, sample_weight=sample_weight
                 )
-            logs = metric.stateless_result(this_metric_vars)
+            logs[metric.name] = metric.stateless_result(this_metric_vars)
             new_metrics_vars += this_metric_vars
 
         # Return metric logs and updated state variables.
@@ -336,7 +360,7 @@ def test_step(self, state, data):
 # Evaluate with our custom test_step
 x = np.random.random((1000, 32))
 y = np.random.random((1000, 1))
-model.evaluate(x, y)
+model.evaluate(x, y, return_dict=True)
 
 
 """
diff --git a/guides/ipynb/custom_train_step_in_jax.ipynb b/guides/ipynb/custom_train_step_in_jax.ipynb
index a975b66ab0..2fa5ece61f 100644
--- a/guides/ipynb/custom_train_step_in_jax.ipynb
+++ b/guides/ipynb/custom_train_step_in_jax.ipynb
@@ -91,17 +91,10 @@
     "- We implement a fully-stateless `compute_loss_and_updates()` method\n",
     "to compute the loss as well as the updated values for the non-trainable\n",
     "variables of the model. Internally, it calls `stateless_call()` and\n",
-    "the built-in `compute_loss()`.\n",
+    "the built-in `stateless_compute_loss()`.\n",
     "- We implement a fully-stateless `train_step()` method to compute current\n",
     "metric values (including the loss) as well as updated values for the\n",
-    "trainable variables, the optimizer variables, and the metric variables.\n",
-    "\n",
-    "Note that you can also take into account the `sample_weight` argument by:\n",
-    "\n",
-    "- Unpacking the data as `x, y, sample_weight = data`\n",
-    "- Passing `sample_weight` to `compute_loss()`\n",
-    "- Passing `sample_weight` alongside `y` and `y_pred`\n",
-    "to metrics in `stateless_update_state()`"
+    "trainable variables, the optimizer variables, and the metric variables."
    ]
   },
   {
@@ -118,8 +111,10 @@
     "        self,\n",
     "        trainable_variables,\n",
     "        non_trainable_variables,\n",
+    "        metrics_variables,\n",
     "        x,\n",
     "        y,\n",
+    "        sample_weight,\n",
     "        training=False,\n",
     "    ):\n",
     "        y_pred, non_trainable_variables = self.stateless_call(\n",
@@ -128,8 +123,21 @@
     "            x,\n",
     "            training=training,\n",
     "        )\n",
-    "        loss = self.compute_loss(x, y, y_pred)\n",
-    "        return loss, (y_pred, non_trainable_variables)\n",
+    "        loss, (\n",
+    "            trainable_variables,\n",
+    "            non_trainable_variables,\n",
+    "            metrics_variables,\n",
+    "        ) = self.stateless_compute_loss(\n",
+    "            trainable_variables,\n",
+    "            non_trainable_variables,\n",
+    "            metrics_variables,\n",
+    "            x=x,\n",
+    "            y=y,\n",
+    "            y_pred=y_pred,\n",
+    "            sample_weight=sample_weight,\n",
+    "            training=training,\n",
+    "        )\n",
+    "        return loss, (y_pred, non_trainable_variables, metrics_variables)\n",
     "\n",
     "    def train_step(self, state, data):\n",
     "        (\n",
@@ -138,25 +146,24 @@
     "            optimizer_variables,\n",
     "            metrics_variables,\n",
     "        ) = state\n",
-    "        x, y = data\n",
+    "        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)\n",
     "\n",
     "        # Get the gradient function.\n",
     "        grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)\n",
     "\n",
     "        # Compute the gradients.\n",
-    "        (loss, (y_pred, non_trainable_variables)), grads = grad_fn(\n",
+    "        (loss, (y_pred, non_trainable_variables, metrics_variables)), grads = grad_fn(\n",
     "            trainable_variables,\n",
     "            non_trainable_variables,\n",
+    "            metrics_variables,\n",
     "            x,\n",
     "            y,\n",
+    "            sample_weight,\n",
     "            training=True,\n",
     "        )\n",
     "\n",
     "        # Update trainable variables and optimizer variables.\n",
-    "        (\n",
-    "            trainable_variables,\n",
-    "            optimizer_variables,\n",
-    "        ) = self.optimizer.stateless_apply(\n",
+    "        trainable_variables, optimizer_variables = self.optimizer.stateless_apply(\n",
     "            optimizer_variables, grads, trainable_variables\n",
     "        )\n",
     "\n",
@@ -168,10 +175,12 @@
     "                len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)\n",
     "            ]\n",
     "            if metric.name == \"loss\":\n",
-    "                this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)\n",
+    "                this_metric_vars = metric.stateless_update_state(\n",
+    "                    this_metric_vars, loss, sample_weight=sample_weight\n",
+    "                )\n",
     "            else:\n",
     "                this_metric_vars = metric.stateless_update_state(\n",
-    "                    this_metric_vars, y, y_pred\n",
+    "                    this_metric_vars, y, y_pred, sample_weight=sample_weight\n",
     "                )\n",
     "            logs[metric.name] = metric.stateless_result(this_metric_vars)\n",
     "            new_metrics_vars += this_metric_vars\n",
@@ -253,6 +262,7 @@
     "        non_trainable_variables,\n",
     "        x,\n",
     "        y,\n",
+    "        sample_weight,\n",
     "        training=False,\n",
     "    ):\n",
     "        y_pred, non_trainable_variables = self.stateless_call(\n",
@@ -261,7 +271,7 @@
     "            x,\n",
     "            training=training,\n",
     "        )\n",
-    "        loss = self.loss_fn(y, y_pred)\n",
+    "        loss = self.loss_fn(y, y_pred, sample_weight=sample_weight)\n",
     "        return loss, (y_pred, non_trainable_variables)\n",
     "\n",
     "    def train_step(self, state, data):\n",
@@ -271,7 +281,7 @@
     "            optimizer_variables,\n",
     "            metrics_variables,\n",
     "        ) = state\n",
-    "        x, y = data\n",
+    "        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)\n",
     "\n",
     "        # Get the gradient function.\n",
     "        grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)\n",
@@ -282,14 +292,12 @@
     "            non_trainable_variables,\n",
     "            x,\n",
     "            y,\n",
+    "            sample_weight,\n",
     "            training=True,\n",
     "        )\n",
     "\n",
     "        # Update trainable variables and optimizer variables.\n",
-    "        (\n",
-    "            trainable_variables,\n",
-    "            optimizer_variables,\n",
-    "        ) = self.optimizer.stateless_apply(\n",
+    "        trainable_variables, optimizer_variables = self.optimizer.stateless_apply(\n",
     "            optimizer_variables, grads, trainable_variables\n",
     "        )\n",
     "\n",
@@ -298,10 +306,10 @@
     "        mae_metric_vars = metrics_variables[len(self.loss_tracker.variables) :]\n",
     "\n",
     "        loss_tracker_vars = self.loss_tracker.stateless_update_state(\n",
-    "            loss_tracker_vars, loss\n",
+    "            loss_tracker_vars, loss, sample_weight=sample_weight\n",
     "        )\n",
     "        mae_metric_vars = self.mae_metric.stateless_update_state(\n",
-    "            mae_metric_vars, y, y_pred\n",
+    "            mae_metric_vars, y, y_pred, sample_weight=sample_weight\n",
     "        )\n",
     "\n",
     "        logs = {}\n",
@@ -368,7 +376,7 @@
     "class CustomModel(keras.Model):\n",
     "    def test_step(self, state, data):\n",
     "        # Unpack the data.\n",
-    "        x, y = data\n",
+    "        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)\n",
     "        (\n",
     "            trainable_variables,\n",
     "            non_trainable_variables,\n",
@@ -382,21 +390,37 @@
     "            x,\n",
     "            training=False,\n",
     "        )\n",
-    "        loss = self.compute_loss(x, y, y_pred)\n",
+    "        loss, (\n",
+    "            trainable_variables,\n",
+    "            non_trainable_variables,\n",
+    "            metrics_variables,\n",
+    "        ) = self.stateless_compute_loss(\n",
+    "            trainable_variables,\n",
+    "            non_trainable_variables,\n",
+    "            metrics_variables,\n",
+    "            x=x,\n",
+    "            y=y,\n",
+    "            y_pred=y_pred,\n",
+    "            sample_weight=sample_weight,\n",
+    "            training=False,\n",
+    "        )\n",
     "\n",
     "        # Update metrics.\n",
     "        new_metrics_vars = []\n",
+    "        logs = {}\n",
     "        for metric in self.metrics:\n",
     "            this_metric_vars = metrics_variables[\n",
     "                len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)\n",
     "            ]\n",
     "            if metric.name == \"loss\":\n",
-    "                this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)\n",
+    "                this_metric_vars = metric.stateless_update_state(\n",
+    "                    this_metric_vars, loss, sample_weight=sample_weight\n",
+    "                )\n",
     "            else:\n",
     "                this_metric_vars = metric.stateless_update_state(\n",
-    "                    this_metric_vars, y, y_pred\n",
+    "                    this_metric_vars, y, y_pred, sample_weight=sample_weight\n",
     "                )\n",
-    "            logs = metric.stateless_result(this_metric_vars)\n",
+    "            logs[metric.name] = metric.stateless_result(this_metric_vars)\n",
     "            new_metrics_vars += this_metric_vars\n",
     "\n",
     "        # Return metric logs and updated state variables.\n",
@@ -417,7 +441,7 @@
     "# Evaluate with our custom test_step\n",
     "x = np.random.random((1000, 32))\n",
     "y = np.random.random((1000, 1))\n",
-    "model.evaluate(x, y)\n",
+    "model.evaluate(x, y, return_dict=True)\n",
     ""
    ]
   },
diff --git a/guides/md/custom_train_step_in_jax.md b/guides/md/custom_train_step_in_jax.md
index d61d9ef5b5..a5c5bde722 100644
--- a/guides/md/custom_train_step_in_jax.md
+++ b/guides/md/custom_train_step_in_jax.md
@@ -64,18 +64,11 @@ Let's start from a simple example:
 - We implement a fully-stateless `compute_loss_and_updates()` method
 to compute the loss as well as the updated values for the non-trainable
 variables of the model. Internally, it calls `stateless_call()` and
-the built-in `compute_loss()`.
+the built-in `stateless_compute_loss()`.
 - We implement a fully-stateless `train_step()` method to compute current
 metric values (including the loss) as well as updated values for the
 trainable variables, the optimizer variables, and the metric variables.
 
-Note that you can also take into account the `sample_weight` argument by:
-
-- Unpacking the data as `x, y, sample_weight = data`
-- Passing `sample_weight` to `compute_loss()`
-- Passing `sample_weight` alongside `y` and `y_pred`
-to metrics in `stateless_update_state()`
-
 
 
 ```python
@@ -84,8 +77,10 @@ class CustomModel(keras.Model):
         self,
         trainable_variables,
         non_trainable_variables,
+        metrics_variables,
         x,
         y,
+        sample_weight,
         training=False,
     ):
         y_pred, non_trainable_variables = self.stateless_call(
@@ -94,8 +89,21 @@ class CustomModel(keras.Model):
             x,
             training=training,
         )
-        loss = self.compute_loss(x, y, y_pred)
-        return loss, (y_pred, non_trainable_variables)
+        loss, (
+            trainable_variables,
+            non_trainable_variables,
+            metrics_variables,
+        ) = self.stateless_compute_loss(
+            trainable_variables,
+            non_trainable_variables,
+            metrics_variables,
+            x=x,
+            y=y,
+            y_pred=y_pred,
+            sample_weight=sample_weight,
+            training=training,
+        )
+        return loss, (y_pred, non_trainable_variables, metrics_variables)
 
     def train_step(self, state, data):
         (
@@ -104,25 +112,24 @@ class CustomModel(keras.Model):
             optimizer_variables,
             metrics_variables,
         ) = state
-        x, y = data
+        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
 
         # Get the gradient function.
         grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)
 
         # Compute the gradients.
-        (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
+        (loss, (y_pred, non_trainable_variables, metrics_variables)), grads = grad_fn(
             trainable_variables,
             non_trainable_variables,
+            metrics_variables,
             x,
             y,
+            sample_weight,
             training=True,
         )
 
         # Update trainable variables and optimizer variables.
-        (
-            trainable_variables,
-            optimizer_variables,
-        ) = self.optimizer.stateless_apply(
+        trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
             optimizer_variables, grads, trainable_variables
         )
 
@@ -134,10 +141,12 @@ class CustomModel(keras.Model):
                 len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
             ]
             if metric.name == "loss":
-                this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
+                this_metric_vars = metric.stateless_update_state(
+                    this_metric_vars, loss, sample_weight=sample_weight
+                )
             else:
                 this_metric_vars = metric.stateless_update_state(
-                    this_metric_vars, y, y_pred
+                    this_metric_vars, y, y_pred, sample_weight=sample_weight
                 )
             logs[metric.name] = metric.stateless_result(this_metric_vars)
             new_metrics_vars += this_metric_vars
@@ -173,16 +182,21 @@ model.fit(x, y, epochs=3)
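
Reviewer note: the behavioral core of this change is that the custom steps now unpack `sample_weight` via `keras.utils.unpack_x_y_sample_weight()` and thread it through `stateless_compute_loss()` and each metric's `stateless_update_state()`. Below is a minimal sketch of how those two helpers behave in isolation; the toy arrays and the expected values in the comments are illustrative assumptions, not taken from the guide itself.

```python
import numpy as np
import keras

# unpack_x_y_sample_weight() normalizes the data tuple fed to the step,
# padding missing entries with None: (x, y) -> (x, y, None).
x = np.random.random((4, 32)).astype("float32")
y = np.random.random((4, 1)).astype("float32")
w = np.array([1.0, 1.0, 0.0, 0.0], dtype="float32")
print(keras.utils.unpack_x_y_sample_weight((x, y)))     # (x, y, None)
print(keras.utils.unpack_x_y_sample_weight((x, y, w)))  # (x, y, w)

# Stateless metric updates: pass the current variable values in, get
# updated values out; the metric object itself is never mutated.
metric = keras.metrics.MeanAbsoluteError()
y_true = np.array([[0.0], [1.0], [2.0]], dtype="float32")
y_pred = np.array([[0.5], [1.0], [3.0]], dtype="float32")
sw = np.array([1.0, 1.0, 0.0], dtype="float32")  # masks out the third sample

vars_in = [v.value for v in metric.variables]
vars_out = metric.stateless_update_state(vars_in, y_true, y_pred, sample_weight=sw)
print(metric.stateless_result(vars_out))  # weighted MAE: (0.5 + 0.0) / 2 = 0.25
```

Since `unpack_x_y_sample_weight()` returns `sample_weight=None` when no weights are supplied, and `None` weights make losses and metrics fall back to unweighted updates, the rewritten steps should behave identically to the old ones for users who never pass `sample_weight`.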