I'm excited to share my Jax implementation of SAC from pixels + image augmentations from DrQ (github.com/ikostrikov/jax…, see train_pixels.py). This Jax version is almost twice as fast as our original PyTorch implementation.
@ikostrikov awesome work! Is this 2x faster in wall-time for training the agent, or just the forward-backward pass of the NN?
@ericjang11 The forward-backward pass. It takes the same number of updates, so overall training takes almost half the time.
@ikostrikov @ericjang11 Do you have an idea of where the performance difference comes from? Is it primarily framework overhead?
@cHHillee @ericjang11 I think Jax jit makes the difference. There is another repository that compares Jax and PyTorch and gets a similar performance gain (github.com/yang-song/scor…).
@ikostrikov @ericjang11 Well, Jax jit still needs to be doing something to have better perf :) For example, if you're just doing a single large matmul, Jax jit isn't going to have better performance. From what I've seen, it's usually either fusion or overhead, so I was curious if you had an idea.
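A minimal sketch of the kind of jit-compiled forward-backward step being discussed. The toy MLP, names, and shapes here are my own illustration, not code from the linked repo; the point is just that `jax.jit` hands the whole traced forward-backward computation to XLA, which can fuse ops and remove per-op Python dispatch overhead:

```python
import jax
import jax.numpy as jnp

# Hypothetical toy loss: a tiny two-layer MLP with an MSE objective.
# This stands in for the actual SAC networks, which are much larger.
def loss_fn(params, x, y):
    w1, w2 = params
    h = jnp.tanh(x @ w1)          # hidden layer
    pred = h @ w2                 # output layer
    return jnp.mean((pred - y) ** 2)

# jax.grad builds the backward pass; jax.jit compiles forward + backward
# into a single XLA program, so repeated calls skip Python-level op dispatch.
step = jax.jit(jax.grad(loss_fn))

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (64, 32))
y = jax.random.normal(key, (64, 1))
params = (
    jax.random.normal(key, (32, 16)),
    jax.random.normal(key, (16, 1)),
)

# First call triggers compilation; subsequent calls reuse the cached program.
grads = step(params, x, y)
```

The speedup shows up when the step is called many times with the same shapes, as in RL training loops: compilation cost is paid once, and every later update runs the fused XLA program.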