Skip to content

Conversation

@selamw1
Copy link
Collaborator

@selamw1 selamw1 commented Dec 18, 2025

This PR removes the TensorFlow dependency from the mlp_mnist notebook and updates the code to use Flax NNX instead of flax.linen. It also transitions the data loading pipeline to use grain, aligning the example with modern JAX best practices.

Copy link
Collaborator

@rdyro rdyro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

"from flax import nnx\n",
"\n",
"import grain.python as pygrain\n",
"from torchvision.datasets import MNIST\n",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we using torchvision? Would be best to have either only grain or grain and tensorflow-datasets (if possible avoid tensorflow because of versioning issues).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torchvision as a source for the dataset might be correct as it's a widely used and a well maintained package. This also show that optax (and jax) can be highly interoperable with the existing ecosystem.

@rdyro rdyro requested review from vroulet and removed request for vroulet December 28, 2025 19:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants