This (Flax NNX) seems to be the future of neural network models in JAX. See the Linen-to-NNX migration guide here: https://flax.readthedocs.io/en/latest/guides/linen_to_nnx.html