The authors introduce the context-ready transformer, a recurrent neural network architecture that uses a correction network to pre-contextualize each token before it enters a D-layer transformer block.
- The correction network combines the previous position's cached summary with the current token embedding to create a contextualized input.
- At inference time the model runs sequentially like an RNN; during training, the correction step is unrolled K times so that all positions can be processed in parallel.
- A pretrained transformer can be converted to this architecture by adding a zero-initialized correction FFN and fine-tuning.
- A D=5 model outperforms a 12-layer transformer while generating 1.7x faster on an A100 GPU.
- With K=10, a single-layer model (D=1) outperforms a 6-layer transformer with a 2.6x inference speedup, and its sequential inference matches parallel (unrolled) evaluation to within 0.01 PPL.
- The architecture benefits most from wide representations and long contexts, solving all 10 composition levels on a pointer-chasing task where standard transformers fail.
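The recurrence and its parallel unrolling can be sketched in NumPy. This is a toy illustration under several assumptions not stated in the summary: the "cached summary" is taken to be the block's top-layer output at the previous position (zeros before the first token), the correction is a residual ReLU FFN on the concatenation of that summary and the token embedding, and the D-layer transformer block is replaced by D toy single-head causal attention layers. The names `correct`, `run_sequential`, and `run_unrolled` are hypothetical. The unrolled version treats training as a Jacobi-style fixed-point iteration: each of the K passes feeds every position the summaries produced by the previous pass.

```python
import numpy as np

def causal_attention_layer(h, Wq, Wk, Wv):
    """Toy single-head causal self-attention layer with a residual connection."""
    T, d = h.shape
    q, k, v = h @ Wq, h @ Wk, h @ Wv
    scores = q @ k.T / np.sqrt(d)
    # Mask future positions so position t only attends to positions <= t.
    scores = np.where(np.tril(np.ones((T, T), dtype=bool)), scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)
    return h + weights @ v

def block(c, layers):
    """Stand-in for the D-layer transformer block: D stacked causal layers."""
    h = c
    for Wq, Wk, Wv in layers:
        h = causal_attention_layer(h, Wq, Wk, Wv)
    return h

def correct(prev_summary, x, W1, W2):
    """Hypothetical correction FFN: x + FFN([prev_summary; x])."""
    z = np.concatenate([prev_summary, x], axis=-1)
    return x + np.maximum(z @ W1, 0.0) @ W2

def run_sequential(x, layers, W1, W2):
    """RNN-style inference: position t waits for the summary of position t-1."""
    T, d = x.shape
    c = np.zeros_like(x)
    s = np.zeros_like(x)
    prev = np.zeros(d)
    for t in range(T):
        c[t] = correct(prev, x[t], W1, W2)
        # Recompute over the prefix for clarity; a real decoder would reuse a KV cache.
        s[t] = block(c[: t + 1], layers)[-1]
        prev = s[t]
    return s

def run_unrolled(x, layers, W1, W2, K):
    """Training-style evaluation: K parallel correction passes over all positions."""
    s = np.zeros_like(x)
    for _ in range(K):
        # Shift summaries right by one so position t sees the summary of t-1.
        prev = np.vstack([np.zeros((1, x.shape[1])), s[:-1]])
        c = correct(prev, x, W1, W2)
        s = block(c, layers)
    return s

rng = np.random.default_rng(0)
T, d, hidden, D = 8, 16, 32, 2
x = rng.normal(size=(T, d))
layers = [tuple(rng.normal(scale=0.1, size=(d, d)) for _ in range(3)) for _ in range(D)]
W1 = rng.normal(scale=0.1, size=(2 * d, hidden))
W2 = rng.normal(scale=0.1, size=(hidden, d))

seq = run_sequential(x, layers, W1, W2)
# T parallel passes reproduce the sequential result; fewer passes only approximate it.
print(np.allclose(seq, run_unrolled(x, layers, W1, W2, K=T)))
print(np.abs(seq - run_unrolled(x, layers, W1, W2, K=2)).max())
# A zero output matrix makes the correction the identity, recovering the base model.
print(np.allclose(run_sequential(x, layers, W1, np.zeros_like(W2)), block(x, layers)))
```

In this toy, each pass fixes at least one more position, so K = T passes match sequential inference exactly (up to floating-point error) and smaller K gives an approximation. The last check mirrors the conversion recipe: with the correction's output layer zero-initialized, the model starts out identical to the underlying transformer, and fine-tuning can then move it away from that starting point.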
Overall, the approach enables substantially faster sequential inference while matching or exceeding the performance of deeper standard transformers.