From 9837aeb15d75bc0c2dd41706f51c81c6e0275f7f Mon Sep 17 00:00:00 2001 From: Gerixmus Date: Wed, 15 Oct 2025 20:01:02 +0200 Subject: [PATCH 1/3] docs: update to nnx optimizer --- docs/source/JAX_Vision_transformer.ipynb | 2 +- docs/source/JAX_Vision_transformer.md | 2 +- docs/source/JAX_basic_text_classification.ipynb | 4 ++-- docs/source/JAX_basic_text_classification.md | 2 +- docs/source/JAX_examples_image_segmentation.ipynb | 2 +- docs/source/JAX_examples_image_segmentation.md | 2 +- docs/source/JAX_for_LLM_pretraining.ipynb | 2 +- docs/source/JAX_for_LLM_pretraining.md | 2 +- docs/source/JAX_for_PyTorch_users.ipynb | 2 +- docs/source/JAX_for_PyTorch_users.md | 2 +- docs/source/JAX_image_captioning.ipynb | 2 +- docs/source/JAX_image_captioning.md | 2 +- docs/source/JAX_machine_translation.ipynb | 2 +- docs/source/JAX_machine_translation.md | 2 +- docs/source/JAX_time_series_classification.ipynb | 2 +- docs/source/JAX_time_series_classification.md | 2 +- .../source/JAX_transformer_text_classification.ipynb | 2 +- docs/source/JAX_transformer_text_classification.md | 2 +- docs/source/JAX_visualizing_models_metrics.ipynb | 2 +- docs/source/JAX_visualizing_models_metrics.md | 2 +- docs/source/digits_diffusion_model.ipynb | 2 +- docs/source/digits_diffusion_model.md | 2 +- docs/source/digits_vae.ipynb | 12 ++++++------ docs/source/digits_vae.md | 12 ++++++------ docs/source/neural_net_basics.ipynb | 2 +- docs/source/neural_net_basics.md | 2 +- 26 files changed, 37 insertions(+), 37 deletions(-) diff --git a/docs/source/JAX_Vision_transformer.ipynb b/docs/source/JAX_Vision_transformer.ipynb index c391d932..cf56fccf 100644 --- a/docs/source/JAX_Vision_transformer.ipynb +++ b/docs/source/JAX_Vision_transformer.ipynb @@ -879,7 +879,7 @@ "plt.show()\n", "\n", "\n", - "optimizer = nnx.ModelAndOptimizer(model, optax.sgd(lr_schedule, momentum, nesterov=True))" + "optimizer = nnx.Optimizer(model, optax.sgd(lr_schedule, momentum, nesterov=True))" ] }, { diff --git a/docs/source/JAX_Vision_transformer.md b/docs/source/JAX_Vision_transformer.md index 5e1e5b63..94d1a889 100644 --- a/docs/source/JAX_Vision_transformer.md +++ b/docs/source/JAX_Vision_transformer.md @@ -619,7 +619,7 @@ plt.xlim((0, num_epochs)) plt.show() -optimizer = nnx.ModelAndOptimizer(model, optax.sgd(lr_schedule, momentum, nesterov=True)) +optimizer = nnx.Optimizer(model, optax.sgd(lr_schedule, momentum, nesterov=True)) ``` Define a loss function with `optax.softmax_cross_entropy_with_integer_labels`: diff --git a/docs/source/JAX_basic_text_classification.ipynb b/docs/source/JAX_basic_text_classification.ipynb index bae39f30..bde4d185 100644 --- a/docs/source/JAX_basic_text_classification.ipynb +++ b/docs/source/JAX_basic_text_classification.ipynb @@ -516,7 +516,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": null, "id": "6d9f4756-4e64-49d1-81dd-20c0e0480dd0", "metadata": {}, "outputs": [], @@ -528,7 +528,7 @@ "learning_rate = 0.0005\n", "momentum = 0.9\n", "\n", - "optimizer = nnx.ModelAndOptimizer(model, optax.adam(learning_rate, momentum))" + "optimizer = nnx.Optimizer(model, optax.adam(learning_rate, momentum))" ] }, { diff --git a/docs/source/JAX_basic_text_classification.md b/docs/source/JAX_basic_text_classification.md index 5621ba37..f9d64429 100644 --- a/docs/source/JAX_basic_text_classification.md +++ b/docs/source/JAX_basic_text_classification.md @@ -303,7 +303,7 @@ num_epochs = 10 learning_rate = 0.0005 momentum = 0.9 -optimizer = nnx.ModelAndOptimizer(model, optax.adam(learning_rate, momentum)) 
+optimizer = nnx.Optimizer(model, optax.adam(learning_rate, momentum)) ``` ```{code-cell} ipython3 diff --git a/docs/source/JAX_examples_image_segmentation.ipynb b/docs/source/JAX_examples_image_segmentation.ipynb index 7d4a825b..c1a34dbc 100644 --- a/docs/source/JAX_examples_image_segmentation.ipynb +++ b/docs/source/JAX_examples_image_segmentation.ipynb @@ -1587,7 +1587,7 @@ "plt.show()\n", "\n", "\n", - "optimizer = nnx.ModelAndOptimizer(model, optax.adam(lr_schedule, momentum))" + "optimizer = nnx.Optimizer(model, optax.adam(lr_schedule, momentum))" ] }, { diff --git a/docs/source/JAX_examples_image_segmentation.md b/docs/source/JAX_examples_image_segmentation.md index 1fad1e8b..c5268dbc 100644 --- a/docs/source/JAX_examples_image_segmentation.md +++ b/docs/source/JAX_examples_image_segmentation.md @@ -1067,7 +1067,7 @@ plt.xlim((0, num_epochs)) plt.show() -optimizer = nnx.ModelAndOptimizer(model, optax.adam(lr_schedule, momentum)) +optimizer = nnx.Optimizer(model, optax.adam(lr_schedule, momentum)) ``` Let us implement Jaccard loss and the loss function combining Cross-Entropy and Jaccard losses. diff --git a/docs/source/JAX_for_LLM_pretraining.ipynb b/docs/source/JAX_for_LLM_pretraining.ipynb index d8a4aab8..67b0c667 100644 --- a/docs/source/JAX_for_LLM_pretraining.ipynb +++ b/docs/source/JAX_for_LLM_pretraining.ipynb @@ -978,7 +978,7 @@ ], "source": [ "model = create_model(rngs=nnx.Rngs(0))\n", - "optimizer = nnx.ModelAndOptimizer(model, optax.adam(1e-3))\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3))\n", "metrics = nnx.MultiMetric(\n", " loss=nnx.metrics.Average('loss'),\n", ")\n", diff --git a/docs/source/JAX_for_LLM_pretraining.md b/docs/source/JAX_for_LLM_pretraining.md index c26d9290..863ce7f5 100644 --- a/docs/source/JAX_for_LLM_pretraining.md +++ b/docs/source/JAX_for_LLM_pretraining.md @@ -476,7 +476,7 @@ id: Ysl6CsfENeJN outputId: 5dd06dca-f030-4927-a9b6-35d412da535c --- model = create_model(rngs=nnx.Rngs(0)) -optimizer = nnx.ModelAndOptimizer(model, optax.adam(1e-3)) +optimizer = nnx.Optimizer(model, optax.adam(1e-3)) metrics = nnx.MultiMetric( loss=nnx.metrics.Average('loss'), ) diff --git a/docs/source/JAX_for_PyTorch_users.ipynb b/docs/source/JAX_for_PyTorch_users.ipynb index e77950b6..a3a01c77 100644 --- a/docs/source/JAX_for_PyTorch_users.ipynb +++ b/docs/source/JAX_for_PyTorch_users.ipynb @@ -1272,7 +1272,7 @@ "learning_rate = 0.005\n", "momentum = 0.9\n", "\n", - "optimizer = nnx.ModelAndOptimizer(model, optax.adamw(learning_rate, momentum))" + "optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))" ] }, { diff --git a/docs/source/JAX_for_PyTorch_users.md b/docs/source/JAX_for_PyTorch_users.md index 8998f22d..0e9fc6d0 100644 --- a/docs/source/JAX_for_PyTorch_users.md +++ b/docs/source/JAX_for_PyTorch_users.md @@ -703,7 +703,7 @@ import optax learning_rate = 0.005 momentum = 0.9 -optimizer = nnx.ModelAndOptimizer(model, optax.adamw(learning_rate, momentum)) +optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum)) ``` ```{code-cell} ipython3 diff --git a/docs/source/JAX_image_captioning.ipynb b/docs/source/JAX_image_captioning.ipynb index d6c5f64a..3b48b02f 100644 --- a/docs/source/JAX_image_captioning.ipynb +++ b/docs/source/JAX_image_captioning.ipynb @@ -1304,7 +1304,7 @@ "momentum = 0.9\n", "total_steps = len(train_dataset) // train_batch_size\n", "\n", - "optimizer = nnx.ModelAndOptimizer(\n", + "optimizer = nnx.Optimizer(\n", " model, optax.sgd(learning_rate, momentum, nesterov=True), wrt=trainable_params_filter\n", ")" 
] diff --git a/docs/source/JAX_image_captioning.md b/docs/source/JAX_image_captioning.md index c521058d..f51819c2 100644 --- a/docs/source/JAX_image_captioning.md +++ b/docs/source/JAX_image_captioning.md @@ -925,7 +925,7 @@ learning_rate = 0.015 momentum = 0.9 total_steps = len(train_dataset) // train_batch_size -optimizer = nnx.ModelAndOptimizer( +optimizer = nnx.Optimizer( model, optax.sgd(learning_rate, momentum, nesterov=True), wrt=trainable_params_filter ) ``` diff --git a/docs/source/JAX_machine_translation.ipynb b/docs/source/JAX_machine_translation.ipynb index 34e1af37..99394c6e 100644 --- a/docs/source/JAX_machine_translation.ipynb +++ b/docs/source/JAX_machine_translation.ipynb @@ -647,7 +647,7 @@ "outputs": [], "source": [ "model = TransformerModel(sequence_length, vocab_size, embed_dim, latent_dim, num_heads, dropout_rate, rngs=rng)\n", - "optimizer = nnx.ModelAndOptimizer(model, optax.adamw(learning_rate))" + "optimizer = nnx.Optimizer(model, optax.adamw(learning_rate))" ] }, { diff --git a/docs/source/JAX_machine_translation.md b/docs/source/JAX_machine_translation.md index e357ebd5..16e87778 100644 --- a/docs/source/JAX_machine_translation.md +++ b/docs/source/JAX_machine_translation.md @@ -452,7 +452,7 @@ def evaluate_model(epoch): ```{code-cell} ipython3 model = TransformerModel(sequence_length, vocab_size, embed_dim, latent_dim, num_heads, dropout_rate, rngs=rng) -optimizer = nnx.ModelAndOptimizer(model, optax.adamw(learning_rate)) +optimizer = nnx.Optimizer(model, optax.adamw(learning_rate)) ``` ## Start the Training! diff --git a/docs/source/JAX_time_series_classification.ipynb b/docs/source/JAX_time_series_classification.ipynb index 9163f0e9..be8a69cb 100644 --- a/docs/source/JAX_time_series_classification.ipynb +++ b/docs/source/JAX_time_series_classification.ipynb @@ -506,7 +506,7 @@ "learning_rate = 0.0005\n", "momentum = 0.9\n", "\n", - "optimizer = nnx.ModelAndOptimizer(model, optax.adam(learning_rate, momentum))" + "optimizer = nnx.Optimizer(model, optax.adam(learning_rate, momentum))" ] }, { diff --git a/docs/source/JAX_time_series_classification.md b/docs/source/JAX_time_series_classification.md index 57ff86a6..ec7f4171 100644 --- a/docs/source/JAX_time_series_classification.md +++ b/docs/source/JAX_time_series_classification.md @@ -250,7 +250,7 @@ num_epochs = 300 learning_rate = 0.0005 momentum = 0.9 -optimizer = nnx.ModelAndOptimizer(model, optax.adam(learning_rate, momentum)) +optimizer = nnx.Optimizer(model, optax.adam(learning_rate, momentum)) ``` We'll define a loss and logits computation function using Optax's diff --git a/docs/source/JAX_transformer_text_classification.ipynb b/docs/source/JAX_transformer_text_classification.ipynb index 6ef22c6e..476d58fc 100644 --- a/docs/source/JAX_transformer_text_classification.ipynb +++ b/docs/source/JAX_transformer_text_classification.ipynb @@ -758,7 +758,7 @@ "learning_rate = 0.0001 # The learning rate.\n", "momentum = 0.9 # Momentum for Adam.\n", "\n", - "optimizer = nnx.ModelAndOptimizer(model, optax.adam(learning_rate, momentum))" + "optimizer = nnx.Optimizer(model, optax.adam(learning_rate, momentum))" ] }, { diff --git a/docs/source/JAX_transformer_text_classification.md b/docs/source/JAX_transformer_text_classification.md index b2294a1a..882a7413 100644 --- a/docs/source/JAX_transformer_text_classification.md +++ b/docs/source/JAX_transformer_text_classification.md @@ -377,7 +377,7 @@ num_epochs = 10 # Number of epochs during training. learning_rate = 0.0001 # The learning rate. 
momentum = 0.9 # Momentum for Adam. -optimizer = nnx.ModelAndOptimizer(model, optax.adam(learning_rate, momentum)) +optimizer = nnx.Optimizer(model, optax.adam(learning_rate, momentum)) ``` Next, we define the loss function - `compute_losses_and_logits()` - using `optax.softmax_cross_entropy_with_integer_labels`: diff --git a/docs/source/JAX_visualizing_models_metrics.ipynb b/docs/source/JAX_visualizing_models_metrics.ipynb index cc5a801e..5280d493 100644 --- a/docs/source/JAX_visualizing_models_metrics.ipynb +++ b/docs/source/JAX_visualizing_models_metrics.ipynb @@ -248,7 +248,7 @@ "import jax\n", "import optax\n", "\n", - "optimizer = nnx.ModelAndOptimizer(model, optax.sgd(learning_rate=0.05))\n", + "optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=0.05))\n", "\n", "def loss_fun(\n", " model: nnx.Module,\n", diff --git a/docs/source/JAX_visualizing_models_metrics.md b/docs/source/JAX_visualizing_models_metrics.md index 37c9fdbb..7819d42e 100644 --- a/docs/source/JAX_visualizing_models_metrics.md +++ b/docs/source/JAX_visualizing_models_metrics.md @@ -145,7 +145,7 @@ In order to track loss across our training run, we've collected the loss functio import jax import optax -optimizer = nnx.ModelAndOptimizer(model, optax.sgd(learning_rate=0.05)) +optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=0.05)) def loss_fun( model: nnx.Module, diff --git a/docs/source/digits_diffusion_model.ipynb b/docs/source/digits_diffusion_model.ipynb index 54e97572..4b4088a9 100644 --- a/docs/source/digits_diffusion_model.ipynb +++ b/docs/source/digits_diffusion_model.ipynb @@ -758,7 +758,7 @@ ")\n", "\n", "# Optimizer configuration (AdamW) with gradient clipping.\n", - "optimizer = nnx.ModelAndOptimizer(model, optax.chain(\n", + "optimizer = nnx.Optimizer(model, optax.chain(\n", " optax.clip_by_global_norm(0.5), # Gradient clipping.\n", " optax.adamw(\n", " learning_rate=schedule_fn,\n", diff --git a/docs/source/digits_diffusion_model.md b/docs/source/digits_diffusion_model.md index 4a8b2815..9bb021b5 100644 --- a/docs/source/digits_diffusion_model.md +++ b/docs/source/digits_diffusion_model.md @@ -641,7 +641,7 @@ schedule_fn = optax.join_schedules( ) # Optimizer configuration (AdamW) with gradient clipping. -optimizer = nnx.ModelAndOptimizer(model, optax.chain( +optimizer = nnx.Optimizer(model, optax.chain( optax.clip_by_global_norm(0.5), # Gradient clipping. 
optax.adamw( learning_rate=schedule_fn, diff --git a/docs/source/digits_vae.ipynb b/docs/source/digits_vae.ipynb index f4bd9b19..086500a3 100644 --- a/docs/source/digits_vae.ipynb +++ b/docs/source/digits_vae.ipynb @@ -297,7 +297,7 @@ " rngs=nnx.Rngs(0, noise=1),\n", ")\n", "\n", - "optimizer = nnx.ModelAndOptimizer(model, optax.adam(1e-3))\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3))\n", "\n", "@nnx.jit\n", "def train_step(model: VAE, optimizer: nnx.Optimizer, x: jax.Array):\n", @@ -449,7 +449,7 @@ " rngs=nnx.Rngs(0, noise=1),\n", ")\n", "\n", - "optimizer = nnx.ModelAndOptimizer(model, optax.adam(1e-3))\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3))\n", "\n", "with jax.debug_nans(True):\n", " for epoch in range(2001):\n", @@ -516,7 +516,7 @@ " rngs=nnx.Rngs(0, noise=1),\n", ")\n", "\n", - "optimizer = nnx.ModelAndOptimizer(model, optax.adam(1e-3))\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3))\n", "\n", "for epoch in range(501):\n", " loss = train_step(model, optimizer, images_train)\n", @@ -597,7 +597,7 @@ " rngs=nnx.Rngs(0, noise=1),\n", ")\n", "\n", - "optimizer = nnx.ModelAndOptimizer(model, optax.adam(1e-3))\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3))\n", "train_step(model, optimizer, images_train)" ] }, @@ -654,7 +654,7 @@ " rngs=nnx.Rngs(0, noise=1),\n", ")\n", "\n", - "optimizer = nnx.ModelAndOptimizer(model, optax.adam(1e-3))\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3))\n", "\n", "for i in range(5):\n", " train_step(model, optimizer, images_train)" @@ -807,7 +807,7 @@ " rngs=nnx.Rngs(0, noise=1),\n", ")\n", "\n", - "optimizer = nnx.ModelAndOptimizer(model, optax.adam(1e-3))\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3))\n", "\n", "for epoch in range(2001):\n", " loss = train_step(model, optimizer, images_train)\n", diff --git a/docs/source/digits_vae.md b/docs/source/digits_vae.md index 825728cb..2732a2af 100644 --- a/docs/source/digits_vae.md +++ b/docs/source/digits_vae.md @@ -207,7 +207,7 @@ model = VAE( rngs=nnx.Rngs(0, noise=1), ) -optimizer = nnx.ModelAndOptimizer(model, optax.adam(1e-3)) +optimizer = nnx.Optimizer(model, optax.adam(1e-3)) @nnx.jit def train_step(model: VAE, optimizer: nnx.Optimizer, x: jax.Array): @@ -247,7 +247,7 @@ model = VAE( rngs=nnx.Rngs(0, noise=1), ) -optimizer = nnx.ModelAndOptimizer(model, optax.adam(1e-3)) +optimizer = nnx.Optimizer(model, optax.adam(1e-3)) with jax.debug_nans(True): for epoch in range(2001): @@ -290,7 +290,7 @@ model = VAE( rngs=nnx.Rngs(0, noise=1), ) -optimizer = nnx.ModelAndOptimizer(model, optax.adam(1e-3)) +optimizer = nnx.Optimizer(model, optax.adam(1e-3)) for epoch in range(501): loss = train_step(model, optimizer, images_train) @@ -328,7 +328,7 @@ model = VAE( rngs=nnx.Rngs(0, noise=1), ) -optimizer = nnx.ModelAndOptimizer(model, optax.adam(1e-3)) +optimizer = nnx.Optimizer(model, optax.adam(1e-3)) train_step(model, optimizer, images_train) ``` @@ -361,7 +361,7 @@ model = VAE( rngs=nnx.Rngs(0, noise=1), ) -optimizer = nnx.ModelAndOptimizer(model, optax.adam(1e-3)) +optimizer = nnx.Optimizer(model, optax.adam(1e-3)) for i in range(5): train_step(model, optimizer, images_train) @@ -438,7 +438,7 @@ model = VAE( rngs=nnx.Rngs(0, noise=1), ) -optimizer = nnx.ModelAndOptimizer(model, optax.adam(1e-3)) +optimizer = nnx.Optimizer(model, optax.adam(1e-3)) for epoch in range(2001): loss = train_step(model, optimizer, images_train) diff --git a/docs/source/neural_net_basics.ipynb b/docs/source/neural_net_basics.ipynb index 6a8f536e..2b783fb8 100644 --- 
a/docs/source/neural_net_basics.ipynb +++ b/docs/source/neural_net_basics.ipynb @@ -237,7 +237,7 @@ "import jax\n", "import optax\n", "\n", - "optimizer = nnx.ModelAndOptimizer(model, optax.sgd(learning_rate=0.05))\n", + "optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=0.05))\n", "\n", "def loss_fun(\n", " model: nnx.Module,\n", diff --git a/docs/source/neural_net_basics.md b/docs/source/neural_net_basics.md index 366e2218..05896560 100644 --- a/docs/source/neural_net_basics.md +++ b/docs/source/neural_net_basics.md @@ -129,7 +129,7 @@ With the `SimpleNN` model created and instantiated, we can now choose the loss f import jax import optax -optimizer = nnx.ModelAndOptimizer(model, optax.sgd(learning_rate=0.05)) +optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=0.05)) def loss_fun( model: nnx.Module, From 7d03555ce76b05683789db48da8fdc17aca4ec92 Mon Sep 17 00:00:00 2001 From: Gerixmus Date: Wed, 15 Oct 2025 22:51:18 +0200 Subject: [PATCH 2/3] chore: pin flax to 0.10.6 for colab compatibility --- docs/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/requirements.txt b/docs/requirements.txt index 5dc8941d..ac1dbc48 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -10,6 +10,7 @@ sphinx-copybutton matplotlib penzai scikit-learn +flax==0.10.6 # install jax-ai-stack from current directory . From 686b46d05c092bc6eab4b534238195e6db65bf0c Mon Sep 17 00:00:00 2001 From: Gerixmus Date: Wed, 15 Oct 2025 23:52:10 +0200 Subject: [PATCH 3/3] chore: perform jupytext sync --- docs/source/JAX_Vision_transformer.ipynb | 1 + docs/source/JAX_Vision_transformer.md | 2 +- docs/source/JAX_basic_text_classification.ipynb | 1 + docs/source/JAX_basic_text_classification.md | 2 +- docs/source/JAX_examples_image_segmentation.ipynb | 1 + docs/source/JAX_examples_image_segmentation.md | 2 +- docs/source/JAX_for_LLM_pretraining.md | 2 +- docs/source/JAX_for_PyTorch_users.ipynb | 1 + docs/source/JAX_for_PyTorch_users.md | 2 +- docs/source/JAX_image_captioning.ipynb | 1 + docs/source/JAX_image_captioning.md | 2 +- docs/source/JAX_machine_translation.ipynb | 1 + docs/source/JAX_machine_translation.md | 2 +- docs/source/JAX_time_series_classification.ipynb | 1 + docs/source/JAX_time_series_classification.md | 2 +- docs/source/JAX_transformer_text_classification.ipynb | 1 + docs/source/JAX_transformer_text_classification.md | 2 +- docs/source/JAX_visualizing_models_metrics.ipynb | 1 + docs/source/JAX_visualizing_models_metrics.md | 2 +- docs/source/digits_diffusion_model.md | 2 +- docs/source/digits_vae.md | 2 +- docs/source/neural_net_basics.md | 2 +- 22 files changed, 22 insertions(+), 13 deletions(-) diff --git a/docs/source/JAX_Vision_transformer.ipynb b/docs/source/JAX_Vision_transformer.ipynb index cf56fccf..80253f08 100644 --- a/docs/source/JAX_Vision_transformer.ipynb +++ b/docs/source/JAX_Vision_transformer.ipynb @@ -1268,6 +1268,7 @@ ], "metadata": { "jupytext": { + "default_lexer": "ipython3", "formats": "ipynb,md:myst" }, "kernelspec": { diff --git a/docs/source/JAX_Vision_transformer.md b/docs/source/JAX_Vision_transformer.md index 94d1a889..be1bec0b 100644 --- a/docs/source/JAX_Vision_transformer.md +++ b/docs/source/JAX_Vision_transformer.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.2 + jupytext_version: 1.17.3 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/docs/source/JAX_basic_text_classification.ipynb b/docs/source/JAX_basic_text_classification.ipynb index 
bde4d185..5229b33f 100644 --- a/docs/source/JAX_basic_text_classification.ipynb +++ b/docs/source/JAX_basic_text_classification.ipynb @@ -1025,6 +1025,7 @@ ], "metadata": { "jupytext": { + "default_lexer": "ipython3", "formats": "ipynb,md:myst" }, "kernelspec": { diff --git a/docs/source/JAX_basic_text_classification.md b/docs/source/JAX_basic_text_classification.md index f9d64429..853747d5 100644 --- a/docs/source/JAX_basic_text_classification.md +++ b/docs/source/JAX_basic_text_classification.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.2 + jupytext_version: 1.17.3 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/docs/source/JAX_examples_image_segmentation.ipynb b/docs/source/JAX_examples_image_segmentation.ipynb index c1a34dbc..a0bf9a83 100644 --- a/docs/source/JAX_examples_image_segmentation.ipynb +++ b/docs/source/JAX_examples_image_segmentation.ipynb @@ -2542,6 +2542,7 @@ ], "metadata": { "jupytext": { + "default_lexer": "ipython3", "formats": "ipynb,md:myst" }, "kernelspec": { diff --git a/docs/source/JAX_examples_image_segmentation.md b/docs/source/JAX_examples_image_segmentation.md index c5268dbc..7b3c2979 100644 --- a/docs/source/JAX_examples_image_segmentation.md +++ b/docs/source/JAX_examples_image_segmentation.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.2 + jupytext_version: 1.17.3 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/docs/source/JAX_for_LLM_pretraining.md b/docs/source/JAX_for_LLM_pretraining.md index 863ce7f5..70469551 100644 --- a/docs/source/JAX_for_LLM_pretraining.md +++ b/docs/source/JAX_for_LLM_pretraining.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.2 + jupytext_version: 1.17.3 kernelspec: display_name: Python 3 name: python3 diff --git a/docs/source/JAX_for_PyTorch_users.ipynb b/docs/source/JAX_for_PyTorch_users.ipynb index a3a01c77..5c93e33c 100644 --- a/docs/source/JAX_for_PyTorch_users.ipynb +++ b/docs/source/JAX_for_PyTorch_users.ipynb @@ -1565,6 +1565,7 @@ "provenance": [] }, "jupytext": { + "default_lexer": "ipython3", "formats": "ipynb,md:myst" }, "kernelspec": { diff --git a/docs/source/JAX_for_PyTorch_users.md b/docs/source/JAX_for_PyTorch_users.md index 0e9fc6d0..e3ddc277 100644 --- a/docs/source/JAX_for_PyTorch_users.md +++ b/docs/source/JAX_for_PyTorch_users.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.2 + jupytext_version: 1.17.3 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/docs/source/JAX_image_captioning.ipynb b/docs/source/JAX_image_captioning.ipynb index 3b48b02f..6c731374 100644 --- a/docs/source/JAX_image_captioning.ipynb +++ b/docs/source/JAX_image_captioning.ipynb @@ -2316,6 +2316,7 @@ ], "metadata": { "jupytext": { + "default_lexer": "ipython3", "formats": "ipynb,md:myst" }, "kernelspec": { diff --git a/docs/source/JAX_image_captioning.md b/docs/source/JAX_image_captioning.md index f51819c2..6adc08d8 100644 --- a/docs/source/JAX_image_captioning.md +++ b/docs/source/JAX_image_captioning.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.2 + jupytext_version: 1.17.3 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/docs/source/JAX_machine_translation.ipynb b/docs/source/JAX_machine_translation.ipynb index 
99394c6e..a892fe13 100644 --- a/docs/source/JAX_machine_translation.ipynb +++ b/docs/source/JAX_machine_translation.ipynb @@ -1039,6 +1039,7 @@ ], "metadata": { "jupytext": { + "default_lexer": "ipython3", "formats": "ipynb,md:myst" }, "kernelspec": { diff --git a/docs/source/JAX_machine_translation.md b/docs/source/JAX_machine_translation.md index 16e87778..9e67d179 100644 --- a/docs/source/JAX_machine_translation.md +++ b/docs/source/JAX_machine_translation.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.2 + jupytext_version: 1.17.3 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/docs/source/JAX_time_series_classification.ipynb b/docs/source/JAX_time_series_classification.ipynb index be8a69cb..08046c84 100644 --- a/docs/source/JAX_time_series_classification.ipynb +++ b/docs/source/JAX_time_series_classification.ipynb @@ -1516,6 +1516,7 @@ ], "metadata": { "jupytext": { + "default_lexer": "ipython3", "formats": "ipynb,md:myst" }, "kernelspec": { diff --git a/docs/source/JAX_time_series_classification.md b/docs/source/JAX_time_series_classification.md index ec7f4171..d72c4570 100644 --- a/docs/source/JAX_time_series_classification.md +++ b/docs/source/JAX_time_series_classification.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.2 + jupytext_version: 1.17.3 kernelspec: display_name: jax-env language: python diff --git a/docs/source/JAX_transformer_text_classification.ipynb b/docs/source/JAX_transformer_text_classification.ipynb index 476d58fc..211a5073 100644 --- a/docs/source/JAX_transformer_text_classification.ipynb +++ b/docs/source/JAX_transformer_text_classification.ipynb @@ -1321,6 +1321,7 @@ ], "metadata": { "jupytext": { + "default_lexer": "ipython3", "formats": "ipynb,md:myst" }, "kernelspec": { diff --git a/docs/source/JAX_transformer_text_classification.md b/docs/source/JAX_transformer_text_classification.md index 882a7413..ff3bed81 100644 --- a/docs/source/JAX_transformer_text_classification.md +++ b/docs/source/JAX_transformer_text_classification.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.2 + jupytext_version: 1.17.3 kernelspec: display_name: jax-env language: python diff --git a/docs/source/JAX_visualizing_models_metrics.ipynb b/docs/source/JAX_visualizing_models_metrics.ipynb index 5280d493..34df3912 100644 --- a/docs/source/JAX_visualizing_models_metrics.ipynb +++ b/docs/source/JAX_visualizing_models_metrics.ipynb @@ -449,6 +449,7 @@ "provenance": [] }, "jupytext": { + "default_lexer": "ipython3", "formats": "ipynb,md:myst" }, "kernelspec": { diff --git a/docs/source/JAX_visualizing_models_metrics.md b/docs/source/JAX_visualizing_models_metrics.md index 7819d42e..bb763ca7 100644 --- a/docs/source/JAX_visualizing_models_metrics.md +++ b/docs/source/JAX_visualizing_models_metrics.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.2 + jupytext_version: 1.17.3 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/docs/source/digits_diffusion_model.md b/docs/source/digits_diffusion_model.md index 9bb021b5..04661827 100644 --- a/docs/source/digits_diffusion_model.md +++ b/docs/source/digits_diffusion_model.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.2 + jupytext_version: 1.17.3 kernelspec: display_name: Python 3 name: python3 
diff --git a/docs/source/digits_vae.md b/docs/source/digits_vae.md index 2732a2af..92a9b1aa 100644 --- a/docs/source/digits_vae.md +++ b/docs/source/digits_vae.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.2 + jupytext_version: 1.17.3 kernelspec: display_name: Python 3 name: python3 diff --git a/docs/source/neural_net_basics.md b/docs/source/neural_net_basics.md index 05896560..d57bce3a 100644 --- a/docs/source/neural_net_basics.md +++ b/docs/source/neural_net_basics.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.2 + jupytext_version: 1.17.3 kernelspec: display_name: Python 3 name: python3
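
For readers following this series, the end state after all three patches is the flax 0.10.6 `nnx.Optimizer` training pattern. A minimal self-contained sketch of that pattern follows; the toy `Model`, the synthetic data, and the MSE loss are illustrative assumptions and do not appear in the patched notebooks, but the optimizer construction and `update(grads)` call match the 0.10.x API these docs now target:

import jax
import jax.numpy as jnp
import optax
from flax import nnx

class Model(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(4, 2, rngs=rngs)

    def __call__(self, x):
        return self.linear(x)

model = Model(rngs=nnx.Rngs(0))
# Construct the optimizer from the model and an optax transform,
# as the patched tutorials do (formerly nnx.ModelAndOptimizer).
optimizer = nnx.Optimizer(model, optax.adam(1e-3))

@nnx.jit
def train_step(model: Model, optimizer: nnx.Optimizer, x: jax.Array, y: jax.Array):
    # Loss closes over the batch; nnx.value_and_grad differentiates
    # with respect to the model's nnx.Param state.
    def loss_fn(model: Model):
        return jnp.mean((model(x) - y) ** 2)
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)  # 0.10.x signature: update(grads); later flax releases revise this API
    return loss

loss = train_step(model, optimizer, jnp.ones((8, 4)), jnp.ones((8, 2)))

This is the same train_step shape the digits_vae and text-classification tutorials touched above already use; pinning flax==0.10.6 (patch 2) is what keeps that signature stable on Colab.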