-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda11_pip]==0.4.25  # The extra (cuda11_pip) must match the CUDA major version in use (CUDA 11); wheels come from the -f index above
orbax-checkpoint==0.5.1
absl-py==2.1.0
flax==0.8.1
grain==0.1.0
scipy==1.9.0
tensorflow==2.16.1
tensorboardX==2.6.2
google-cloud-storage==2.16.0
tensorflow_datasets==4.8.3
ml_collections==0.1.1
tensorflow_text==2.16.1
sentencepiece==0.2.0
cloud_tpu_diagnostics==0.1.5
protobuf==3.20.3
einops==0.6.1