# Extra find-links index for JAX CUDA wheels. From jax 0.4.26 onward the CUDA 12
# plugin wheels are published on PyPI, so this line is probably no longer needed.
# NOTE(review): confirm and drop it if nothing else relies on it.
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# JAX with GPU support. jax/jaxlib 0.4.26+ dropped CUDA 11, so the old
# `cuda11_pip` extra does not exist for 0.4.30. Pip only warns about an unknown
# extra and silently installs a CPU-only jaxlib. The `cuda12` extra installs
# jaxlib plus the CUDA 12 plugin and NVIDIA runtime wheels. This also lines up
# with torch 2.3.0, whose default PyPI wheel is built against CUDA 12.x.
jax[cuda12]==0.4.30
# NOTE(review): orbax-checkpoint 0.2.6 predates jax 0.4.30 and flax 0.8.0.
# Verify that checkpoint save/restore still works with this combination,
# or bump the pin.
orbax-checkpoint==0.2.6
torch==2.3.0
flax==0.8.0
# Lower bound only: newer transformers releases are picked up on each fresh
# install, so builds are not fully reproducible.
transformers>=4.40.2
# Unpinned; the `all` extra pulls in every optional smart_open backend.
smart_open[all]
