Poster
Weight decay induces low-rank attention layers
Seijin Kobayashi · Yassir Akram · Johannes von Oswald
West Ballroom A-D #5909
Thu 12 Dec 11 a.m. PST — 2 p.m. PST
Abstract:
The effect of regularizers such as weight decay when training deep neural networks is not well understood. We study the influence of weight decay as well as $L_2$-regularization when training neural network models in which parameter matrices interact multiplicatively. This combination is of particular interest because this parametrization is common in attention layers, the workhorse of Transformers: there, key-query as well as value-projection parameter matrices are multiplied directly with each other, $W_K^\top W_Q$ and $PW_V$. We extend previous results and show, on one hand, that any local minimum of an $L_2$-regularized loss of the form $L(AB^\top) + \lambda (\|A\|^2 + \|B\|^2)$ coincides with a minimum of the nuclear norm-regularized loss $L(AB^\top) + \lambda\|AB^\top\|_*$, and, on the other hand, that the two losses become identical exponentially quickly during training. We thus complement existing work linking $L_2$-regularization with low-rank regularization, and in particular explain why such regularization on the matrix product kicks in very early during training.

Based on these theoretical insights, we verify empirically that the key-query and value-projection matrix products $W_K^\top W_Q$ and $PW_V$ within attention layers, when optimized with weight decay, as is common in vision tasks and language modelling, indeed exhibit a significant reduction in rank, even in fully online training.

We find that, in accordance with existing work, inducing low rank in attention layers harms language models, and we observe a performance improvement when omitting weight decay in attention layers.
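The connection between $L_2$-regularized factors and nuclear-norm regularization of their product rests on the classical variational identity $\|M\|_* = \min_{AB^\top = M} \tfrac{1}{2}(\|A\|_F^2 + \|B\|_F^2)$, with the minimum attained at a "balanced" factorization built from the SVD of $M$. The snippet below is a minimal numerical sketch of that identity (it is not the authors' code, and the matrix sizes and variable names are illustrative assumptions), checking that the $L_2$ penalty at the balanced factorization equals the nuclear norm of the product:

```python
# Sketch: verify that the L2 penalty on a balanced factorization A B^T = M
# equals the nuclear norm of M (sum of singular values).
import numpy as np

rng = np.random.default_rng(0)
d, r = 8, 3
# Rank-r target matrix, standing in for a product such as W_K^T W_Q.
M = rng.standard_normal((d, r)) @ rng.standard_normal((r, d))

# Nuclear norm of M.
nuclear_norm = np.linalg.norm(M, ord="nuc")

# Balanced factorization from the SVD: A = U sqrt(S), B = V sqrt(S), so A B^T = M.
U, S, Vt = np.linalg.svd(M, full_matrices=False)
A = U @ np.diag(np.sqrt(S))
B = Vt.T @ np.diag(np.sqrt(S))

# L2 penalty (half the sum of squared Frobenius norms) at this factorization.
l2_penalty = 0.5 * (np.linalg.norm(A, "fro") ** 2 + np.linalg.norm(B, "fro") ** 2)

print(np.allclose(A @ B.T, M))                # factorization reproduces M
print(np.isclose(l2_penalty, nuclear_norm))   # penalty matches ||M||_*
```

Any unbalanced factorization of the same $M$ incurs a strictly larger penalty, which is why minimizing the $L_2$-regularized factored loss implicitly penalizes the nuclear norm of the product and thereby encourages low rank.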