Learning rate collapse prevents training recurrent neural networks at scale
Abstract
Recurrent neural networks (RNNs) are central to modeling neural computation in systems neuroscience, yet the principles that enable their stable and efficient training at large scales remain poorly understood. Seminal work in machine learning predicts that the effective learning rate should shrink with the size of feedforward networks. Here, we demonstrate an analogous phenomenon, termed learning rate collapse, in which the maximum trainable learning rate decreases inversely with the number of neurons. This behavior can be partially mitigated by scaling parameters with the inverse of the network size, though learning still takes longer for larger networks. These limits are further compounded by severe memory demands, which together make training large RNNs both unstable and computationally costly. As a proof of principle for mitigating learning rate collapse, we study the learning process of low-rank networks, which enforce a low-dimensional geometry in RNN representations. These results situate learning rate collapse within a broader lineage of scaling analyses in RNNs, with potential solutions likely to come from future work that carefully accounts for symmetry and geometry in neural representations.
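To make the two parameterizations referenced above concrete, the following is a minimal illustrative sketch (not the paper's implementation): a discrete-time rate RNN with a standard full-rank recurrent matrix initialized at scale 1/sqrt(N), alongside a rank-R network whose connectivity is factored as J = m n^T / N, the common convention for low-rank RNNs. The class names, the toy update rule, and the placement of the 1/N factor are assumptions made for illustration.

```python
# Illustrative sketch only: full-rank vs. low-rank recurrent parameterizations.
import torch
import torch.nn as nn


class FullRankRNN(nn.Module):
    def __init__(self, n_neurons: int, g: float = 1.0):
        super().__init__()
        # Standard 1/sqrt(N) initialization of the N x N recurrent weights.
        self.J = nn.Parameter(g * torch.randn(n_neurons, n_neurons) / n_neurons**0.5)

    def step(self, x: torch.Tensor) -> torch.Tensor:
        # One discrete-time update of the rate vector x (shape: batch x N).
        return torch.tanh(x @ self.J.T)


class LowRankRNN(nn.Module):
    def __init__(self, n_neurons: int, rank: int):
        super().__init__()
        # Connectivity factored as J = m n^T / N, so only 2 * N * rank
        # entries are learned and dynamics live in a rank-dimensional subspace.
        self.m = nn.Parameter(torch.randn(n_neurons, rank))
        self.n = nn.Parameter(torch.randn(n_neurons, rank))
        self.n_neurons = n_neurons

    def step(self, x: torch.Tensor) -> torch.Tensor:
        # Computing (x n) m^T / N keeps each step at O(N * rank) cost.
        return torch.tanh((x @ self.n) @ self.m.T / self.n_neurons)


if __name__ == "__main__":
    N, R, batch = 512, 2, 8
    x = torch.randn(batch, N)
    print(FullRankRNN(N).step(x).shape)    # torch.Size([8, 512])
    print(LowRankRNN(N, R).step(x).shape)  # torch.Size([8, 512])
```

Note that the 1/N factor in the low-rank update mirrors the abstract's partial mitigation of learning rate collapse by scaling parameters with the inverse of the network size; whether this matches the paper's exact scaling is an assumption here.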