github.com/google-researc… (based on JAX/Linen) allows you to play with Memorizing and Block-Recurrent Transformers. Unlike other codebases, it is based on Transformer-XL and lets you train on long documents using sliding-window attention.
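For readers unfamiliar with the setup: sliding-window attention lets each token attend only to a bounded window of recent tokens, while a Transformer-XL-style cache carries context across segments, so cost stays flat on long documents. A minimal JAX sketch of the windowed causal-masking idea (not the Meliad implementation; shapes and names are illustrative):

```python
import jax
import jax.numpy as jnp

def sliding_window_attention(q, k, v, window):
    """Causal attention where each query sees only the previous `window`
    positions (itself included). q, k, v: [seq_len, num_heads, head_dim]."""
    seq_len, _, head_dim = q.shape
    scores = jnp.einsum("qhd,khd->hqk", q, k) / jnp.sqrt(head_dim)
    pos_q = jnp.arange(seq_len)[:, None]
    pos_k = jnp.arange(seq_len)[None, :]
    # Allow key j for query i only if i - window < j <= i.
    mask = (pos_k <= pos_q) & (pos_k > pos_q - window)
    scores = jnp.where(mask[None, :, :], scores, -1e30)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("hqk,khd->qhd", weights, v)

# Tiny smoke test.
rng = jax.random.PRNGKey(0)
x = jax.random.normal(rng, (16, 4, 8))            # [seq, heads, dim]
out = sliding_window_attention(x, x, x, window=4)
print(out.shape)                                  # (16, 4, 8)
```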
@ChrSzegedy Wonderful, thank you for the pointer and the amazing research!
@erik_nijkamp This code was mostly developed by DeLesley Hutchins, with contributions from @MarkusNRabe and @Yuhu_ai_.
@ChrSzegedy @MarkusNRabe @Yuhu_ai_ This is nice code, thanks again! Two rather silly questions: (1) The implementation seems to rely on pmap() SPMD, limited to 8 TPU cores. I guess your 8B models and training code won't be released? (2) Besides perplexity experiments, have you tried few-shot prompting with memory on some benchmark?
@erik_nijkamp @MarkusNRabe @Yuhu_ai_ It is true that this code base can accommodate only models of limited size, but there are simple patches to fix that. We have not tried few-shot prompting on the memory.
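A guess at what such a patch could look like (my assumption, not the authors' actual fix): replace the per-host pmap() data parallelism with jit plus explicit sharding over a device mesh, which is not capped at the 8 local cores. A sketch with illustrative names and a placeholder loss:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical data-parallel setup over all available devices, not just 8.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
batch_sharding = NamedSharding(mesh, P("data"))   # shard the batch dimension
replicated = NamedSharding(mesh, P())             # replicate the parameters

def loss_fn(params, batch):
    # Placeholder loss; a real model would go here.
    return jnp.mean((batch @ params) ** 2)

@jax.jit
def train_step(params, batch):
    grads = jax.grad(loss_fn)(params, batch)
    return params - 1e-3 * grads

params = jax.device_put(jnp.ones((8, 8)), replicated)
# Batch size should divide evenly across the devices in the mesh.
batch = jax.device_put(jnp.ones((32, 8)), batch_sharding)
params = train_step(params, batch)
```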
@ChrSzegedy @MarkusNRabe @Yuhu_ai_ We are playing with Meliad and your models, thanks again! Probably very underspecified, but would you have any high-level thoughts (or practical/empirical findings) on the Memorizing Transformer, Block-Recurrent Transformer, S4, Perceiver AR, and RETRO?