Pretraining RNNs Without Recurrence
🔗 Source: arXiv
Pretraining Recurrent Networks without Recurrence
🚀 Technical Novelty
- Mechanism: Trains nonlinear RNNs with one-step supervised learning on memory-transition labels (m_t, x_{t+1}) → m_{t+1}. The memory states m_t are produced by a time-parallel Transformer encoder optimized to retain only future-predictive information (the predictive state).
- Nuance: BPTT sends gradients along an O(T) sequential path, and linear-attention models gain parallelism at the cost of restricted expressivity. SMT (the proposed method) instead decouples the memory representation from the memory dynamics. This gives fully parallel training with stable O(1) credit assignment, since each gradient passes through a single transition, while keeping the full capacity of a nonlinear RNN.
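The one-step training scheme described above can be sketched as follows. This is a minimal NumPy sketch under stated assumptions, not the paper's implementation. The Transformer teacher is replaced by a hypothetical fixed causal running-mean encoder. The RNN cell is assumed to have the form m_{t+1} = tanh(m_t A + x_{t+1} B + c). The loss is assumed to be mean squared error, optimized with plain gradient descent. The names `teacher_memory`, `one_step_loss_and_grads`, and `pretrain` are illustrative. The sketch shows two things. First, every label is available before training starts. Second, every transition is an independent regression example, so all T of them are fit in one batched call without unrolling the cell.

```python
import numpy as np


def teacher_memory(x, W_enc):
    # Hypothetical stand-in for the time-parallel Transformer teacher (an
    # assumption, not the paper's encoder): a causal running mean of projected
    # inputs, squashed by tanh. All T labels are produced at once, no recurrence.
    T = x.shape[0]
    proj = x @ W_enc                                     # (T, d_m)
    causal_mean = np.cumsum(proj, axis=0) / np.arange(1, T + 1)[:, None]
    m = np.tanh(causal_mean)                             # m_1 .. m_T
    m0 = np.zeros((1, m.shape[1]))                       # initial memory m_0
    return np.vstack([m0, m])                            # (T + 1, d_m)


def one_step_loss_and_grads(params, m, x):
    # MSE loss for the transitions (m_t, x_{t+1}) -> m_{t+1}, t = 0..T-1.
    # Rows of x hold x_1..x_T, so row t of x pairs with m[t] and target m[t + 1].
    A, B, c = params
    m_in, m_target = m[:-1], m[1:]                       # (T, d_m) each
    pred = np.tanh(m_in @ A + x @ B + c)                 # one cell step per row
    err = pred - m_target
    T = x.shape[0]
    loss = np.sum(err ** 2) / T
    # Backprop through a single cell application: the gradient path has
    # length 1 no matter how long the sequence is.
    d_pre = (2.0 / T) * err * (1.0 - pred ** 2)
    grads = (m_in.T @ d_pre, x.T @ d_pre, d_pre.sum(axis=0))
    return loss, grads


def pretrain(x, d_m=8, steps=300, lr=0.1, seed=0):
    rng = np.random.default_rng(seed)
    d_x = x.shape[1]
    W_enc = rng.normal(scale=0.5, size=(d_x, d_m))
    m = teacher_memory(x, W_enc)                         # labels, computed once
    params = [rng.normal(scale=0.1, size=(d_m, d_m)),   # A: memory -> memory
              rng.normal(scale=0.1, size=(d_x, d_m)),   # B: input -> memory
              np.zeros(d_m)]                             # c: bias
    losses = []
    for _ in range(steps):
        loss, grads = one_step_loss_and_grads(params, m, x)
        params = [p - lr * g for p, g in zip(params, grads)]
        losses.append(loss)
    return params, m, losses
```

For example, `pretrain(np.random.default_rng(1).normal(size=(64, 4)))` fits a cell to 64 transitions at once. The training loop iterates over optimizer steps, never over time steps. In a real setup, the toy teacher, loss, and optimizer would be replaced by the paper's own choices.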
💡 Yield
- Outperforms standard BPTT on language modeling and pixel-sequence modeling tasks because it captures long-range dependencies more effectively. It also enables time-parallel RNN pretraining without ever unrolling the network, which sharply reduces the sequential-computation bottleneck of BPTT.
⚠️ Limitations
- The teacher Transformer's own expressivity limits may cap pretraining quality. Surpassing the teacher may require downstream BPTT fine-tuning.
- Not designed for direct reasoning, because intermediate steps receive no supervision.
- Memory drifts after pretraining, so a lightweight adaptation phase is needed.
- Currently computes and trains only a single memory state per sequence.
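One hypothetical way to observe the memory drift listed above is to compare two sets of predictions against the teacher labels. This assumes "drift" means the gap between states the RNN produces when run on its own outputs and the teacher labels it was trained to match. That reading is an interpretation; the summary does not spell it out. The first set is teacher-forced: every step starts from the teacher's m_t, as in training. The second set is free-running: every step starts from the cell's own previous output, as at deployment. The cell form tanh(m A + x B + c) and the names `cell` and `drift_profile` are illustrative assumptions.

```python
import numpy as np


def cell(m, x, A, B, c):
    # Assumed nonlinear RNN cell: m_{t+1} = tanh(m_t A + x_{t+1} B + c).
    # Works on a single step (1-D m, x) or a batch of steps (2-D m, x).
    return np.tanh(m @ A + x @ B + c)


def drift_profile(x, m_teacher, A, B, c):
    # Assumption: drift is measured as L2 distance from the teacher labels.
    # x: (T, d_x) inputs x_1..x_T; m_teacher: (T + 1, d_m) labels m_0..m_T.
    T = x.shape[0]
    # Teacher-forced: every step starts from the teacher's m_t (parallel).
    forced = cell(m_teacher[:-1], x, A, B, c)
    # Free-running: every step starts from the cell's own previous output,
    # so one-step errors can compound (inherently sequential).
    rollout = np.empty_like(forced)
    m = m_teacher[0]
    for t in range(T):
        m = cell(m, x[t], A, B, c)
        rollout[t] = m
    target = m_teacher[1:]
    forced_err = np.linalg.norm(forced - target, axis=1)
    rollout_err = np.linalg.norm(rollout - target, axis=1)
    return forced_err, rollout_err
```

Both error curves agree at the first step, because both start from m_0. With a trained cell, a free-running curve that pulls away from the teacher-forced one over time signals compounding error. A brief adaptation phase would target that kind of drift.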