Although I don't feel much love for JAX, I realized that the new thing I want to implement would be horrible in Torch but very easy in JAX. Having the forward/backward passes as functions that take the parameters as input is so neat. Can I ever go back to Torch?
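A minimal sketch of that functional style, assuming a hypothetical toy linear model (`predict`, `loss_fn` and `sgd_step` are illustrative names, not from the thread). The parameters are a plain dict passed in as an argument, and `jax.grad` turns the forward function into its backward pass:

```python
import jax
import jax.numpy as jnp

# Hypothetical linear model: parameters are an explicit input, not hidden module state.
def predict(params, x):
    return x @ params["w"] + params["b"]

def loss_fn(params, x, y):
    # Mean squared error of the predictions.
    return jnp.mean((predict(params, x) - y) ** 2)

# jax.grad builds a function returning d(loss)/d(params) with the same
# dict structure ("w", "b") as params; jax.jit compiles it.
grad_fn = jax.jit(jax.grad(loss_fn))

def sgd_step(params, x, y, lr=0.1):
    grads = grad_fn(params, x, y)
    # Update every leaf and return new params instead of mutating in place.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = {"w": jnp.zeros((3,)), "b": jnp.array(0.0)}
x = jnp.array([[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]])
y = jnp.array([1.0, 2.0])
new_params = sgd_step(params, x, y)
```

Because nothing is stored in objects, swapping the loss, differentiating with respect to only part of the params, or vmapping over several parameter sets is just function composition.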
@FlorinGogianu @jm_alexia It has been shown to be faster in some cases. Also, working with Jacobians, Hessians, or n-th order derivatives is just nicer in JAX than in PyTorch, and you have full control over forward- vs reverse-mode autodiff.
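A small sketch of what that control looks like, using a hypothetical function `f` from R^3 to R^2: `jax.jacfwd` and `jax.jacrev` compute the same Jacobian in forward and reverse mode, `jax.hessian` composes them, nesting `jax.grad` gives higher-order derivatives, and `jax.jvp`/`jax.vjp` expose the underlying primitives:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical vector-valued function R^3 -> R^2.
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 0.0])

# Same Jacobian, two modes: forward mode suits few inputs, reverse mode few outputs.
J_fwd = jax.jacfwd(f)(x)
J_rev = jax.jacrev(f)(x)

def g(x):
    return jnp.sum(x ** 3)

# jax.hessian is forward-over-reverse (jacfwd of jacrev).
H = jax.hessian(g)(x)

def h(t):
    return t ** 4

# Third derivative by nesting grad: d^3/dt^3 t^4 = 24 t, so 48 at t = 2.
d3 = jax.grad(jax.grad(jax.grad(h)))(2.0)

# Jacobian-vector product (forward) and vector-Jacobian product (reverse).
y, tangent_out = jax.jvp(f, (x,), (jnp.array([1.0, 0.0, 0.0]),))
y2, vjp_fn = jax.vjp(f, x)
(cotangent_in,) = vjp_fn(jnp.array([1.0, 0.0]))
```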
@FlorinGogianu @jm_alexia @ikostrikov claimed a 2x speedup for some RL code, if I remember correctly? And yes, the RNG handling is a bit annoying for quick experiments. If you want something with the look and feel of PyTorch, take a look at objax or flax.
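For context on the RNG remark, a minimal sketch of JAX's explicit key handling: there is no global seed, so every random call takes a key, and you split keys to get fresh randomness. The `dropout` helper is hypothetical, just to show how randomness is threaded through a function as an argument:

```python
import jax
import jax.numpy as jnp

# No global RNG state: randomness comes from an explicit key.
key = jax.random.PRNGKey(0)

# Reusing the same key returns the same numbers.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# Split the key to get independent subkeys; keep `key` for later splits.
key, subkey1, subkey2 = jax.random.split(key, 3)
c = jax.random.normal(subkey1, (3,))
d = jax.random.normal(subkey2, (3,))

def dropout(key, x, rate=0.5):
    # Hypothetical helper: the caller must pass in a key for the random mask.
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)

out = dropout(subkey1, jnp.ones(10))
```

This bookkeeping is what makes quick experiments feel clunky, but it also makes runs reproducible and keeps functions pure under `jax.jit` and `jax.vmap`.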
@FlorinGogianu @vlastelicap @jm_alexia I compared github.com/ikostrikov/jax… against github.com/denisyarats/py… and noticed this speedup only on GPU. Another repository also claims similar gains: github.com/yang-song/scor…