@ericjang11 forward-backward pass. It takes the same number of updates, so in total it takes almost half the time.
@ikostrikov @ericjang11 Do you have an idea of where the performance difference comes from? Is it primarily framework overhead?
@cHHillee @ericjang11 I think Jax jit makes the difference. There is another repository that compares Jax and PyTorch and gets a similar performance gain (github.com/yang-song/scor…).
@ikostrikov @ericjang11 Well, Jax jit still needs to be doing something to get better perf :) For example, if you're just doing a single large matmul, Jax jit isn't going to improve performance. From what I've seen, the gains usually come from either fusion or reduced overhead, so I was curious if you had an idea.
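To make the fusion/overhead point concrete, here is a minimal JAX sketch, not taken from either repository. It uses a Polyak (soft) target-network update, a standard step in SAC, as an example of many small elementwise ops spread across a parameter tree. The function name `soft_update`, the parameter shapes, and `tau = 0.005` are hypothetical choices for illustration.

```python
import jax
import jax.numpy as jnp


def soft_update(params, target_params, tau):
    # Polyak averaging, leaf by leaf: target <- tau * online + (1 - tau) * target.
    return jax.tree_util.tree_map(
        lambda p, t: tau * p + (1.0 - tau) * t, params, target_params
    )


# Called eagerly, each multiply/add on each leaf is dispatched from Python as
# its own small op. Under jax.jit the whole function is traced once and handed
# to XLA as a single computation, so the elementwise ops can be fused and the
# per-op Python dispatch overhead disappears.
soft_update_jit = jax.jit(soft_update)

params = {"w": jnp.ones((256, 256)), "b": jnp.zeros(256)}
target = {"w": jnp.zeros((256, 256)), "b": jnp.ones(256)}
new_target = soft_update_jit(params, target, 0.005)
```

By contrast, a single large `jnp.dot` is already one kernel with or without jit, so there is little for XLA to fuse, which matches the matmul caveat in the tweet.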
@cHHillee @ikostrikov @ericjang11 Most of the gains come from more efficient CPU-to-GPU data transfer and faster data augmentation. If you fix these things in our original DrQ code (github.com/denisyarats/drq), you can easily get a 2x speedup. Beyond that, it is just matmuls, and Jax doesn't do those much faster.
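As an illustration of what "faster data augmentation" can look like, here is a minimal NumPy sketch of pad-and-random-crop (random shift) augmentation applied to a whole batch with one vectorized gather instead of a per-image Python loop. The 4-pixel edge-replicated pad is an assumption modeled on DrQ's random-shift augmentation; the function name `random_shift` and its signature are hypothetical, not code from either repository.

```python
import numpy as np


def random_shift(imgs, pad=4, rng=None):
    """Randomly shift each image in a (N, C, H, W) batch by up to `pad` pixels."""
    rng = np.random.default_rng() if rng is None else rng
    n, c, h, w = imgs.shape
    # Replicate border pixels so shifted-in regions are not black.
    padded = np.pad(imgs, ((0, 0), (0, 0), (pad, pad), (pad, pad)), mode="edge")
    # One random top-left corner per image, in [0, 2 * pad].
    dy = rng.integers(0, 2 * pad + 1, size=n)
    dx = rng.integers(0, 2 * pad + 1, size=n)
    rows = dy[:, None] + np.arange(h)[None, :]  # (N, H)
    cols = dx[:, None] + np.arange(w)[None, :]  # (N, W)
    batch = np.arange(n)[:, None, None, None]   # (N, 1, 1, 1)
    chan = np.arange(c)[None, :, None, None]    # (1, C, 1, 1)
    # Broadcasts to (N, C, H, W): every crop is gathered in a single indexing op.
    return padded[batch, chan, rows[:, None, :, None], cols[:, None, None, :]]
```

NumPy keeps the sketch self-contained; in practice most of the speedup would come from running the same batched gather on the accelerator (for example with `jax.numpy`) so that augmented images never round-trip through the CPU.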
@cHHillee @ikostrikov @ericjang11 I think the main advantage of Jax is that you don't have to think much about these types of optimizations to get optimal performance, which saves a lot of time...
@denisyarats @cHHillee @ericjang11 I fully agree. I believe that, given a lot of time, it's possible to get better performance from manual optimization than from the automatic optimization provided by jit.
@ikostrikov @denisyarats @cHHillee Thanks, very interesting! Fwd/bwd wall time is the dominant factor for state-based policies. Do any of these numbers change for image-based SAC?
@ericjang11 @denisyarats @cHHillee Fwd/bwd also takes most of the compute time for image-based SAC. If I remember correctly, rendering takes less than 10% of the overall time, but I benchmarked it a long time ago.
@ikostrikov @denisyarats @cHHillee I see. IIRC, MuJoCo renders images on the CPU, which can take time on the same order as fwd/bwd for sufficiently large images.
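The concern here is essentially Amdahl's law: a faster fwd/bwd only speeds up the share of each step it accounts for. As a rough worked example, treating rendering and fwd/bwd as the only costs and assuming they run sequentially, let $r$ be the fraction of wall time spent rendering (not sped up) and $s$ the speedup applied to fwd/bwd. The values of $r$ and $s$ below are illustrative, not measurements: $r = 0.1$ roughly matches the under-10% rendering share mentioned for image-based SAC, and $r = 0.5$ stands in for large images where rendering is on the same order as fwd/bwd.

$$S = \frac{1}{r + \frac{1 - r}{s}}$$

$$r = 0.1,\ s = 2:\quad S = \frac{1}{0.1 + 0.45} = \frac{1}{0.55} \approx 1.82$$

$$r = 0.5,\ s = 2:\quad S = \frac{1}{0.5 + 0.25} = \frac{1}{0.75} \approx 1.33$$

So a 2x faster fwd/bwd still yields most of its benefit when rendering is cheap, but much less once rendering becomes comparable.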
@ericjang11 @denisyarats @cHHillee We use EGL rendering (headless, hardware accelerated). The DeepMind Control Suite supports it by default, but it's also possible to build gym's MuJoCo environments with EGL rendering support: github.com/openai/mujoco-…
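For the DeepMind Control Suite, a minimal sketch of selecting EGL looks like the following. It relies on dm_control reading the `MUJOCO_GL` environment variable to choose its OpenGL backend. The helper name `render_pixels`, the `cheetah`/`run` task, and the 84x84 size are hypothetical choices for illustration. This does not cover the gym/mujoco-py EGL build linked above.

```python
import os

# dm_control chooses its OpenGL backend from MUJOCO_GL; set it before
# dm_control is imported for the first time in the process.
os.environ["MUJOCO_GL"] = "egl"


def render_pixels(domain="cheetah", task="run", size=84):
    """Render one frame headlessly via EGL (needs dm_control and a GPU driver)."""
    from dm_control import suite  # imported lazily, after MUJOCO_GL is set

    env = suite.load(domain_name=domain, task_name=task)
    env.reset()
    # Returns an RGB uint8 array of shape (size, size, 3).
    return env.physics.render(height=size, width=size, camera_id=0)
```

On a machine with dm_control and EGL-capable drivers installed, `render_pixels()` should return an (84, 84, 3) frame without needing a display server.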