

Transformers learn to implement preconditioned gradient descent for in-context learning

Kwangjun Ahn · Xiang Cheng · Hadi Daneshmand · Suvrit Sra

Great Hall & Hall B1+B2 (level 1) #512
Wed 13 Dec 8:45 a.m. PST — 10:45 a.m. PST

Abstract: Several recent works demonstrate that transformers can implement algorithms like gradient descent. By a careful construction of weights, these works show that multiple layers of transformers are expressive enough to simulate iterations of gradient descent. Going beyond the question of expressivity, we ask: Can transformers learn to implement such algorithms by training over random problem instances? To our knowledge, we make the first theoretical progress on this question via an analysis of the loss landscape for linear transformers trained over random instances of linear regression. For a single attention layer, we prove the global minimum of the training objective implements a single iteration of preconditioned gradient descent. Notably, the preconditioning matrix not only adapts to the input distribution but also to the variance induced by data inadequacy. For a transformer with $L$ attention layers, we prove certain critical points of the training objective implement $L$ iterations of preconditioned gradient descent. Our results call for future theoretical studies on learning algorithms by training transformers.
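
To make the single-layer claim concrete, the following is a minimal NumPy sketch of how one linear attention layer can reproduce one step of preconditioned gradient descent on an in-context linear regression problem. The abstract does not spell out the construction, so everything here is an illustrative assumption: the prompt layout `Z`, the weight matrices `P` and `Q`, the mask `M`, the sign convention for reading out the prediction, and the function names are all hypothetical. The preconditioner `A` is an arbitrary symmetric matrix chosen for the demo, not the learned optimum characterized in the paper.

```python
import numpy as np


def linear_attention_prediction(X, y, x_query, A):
    """One softmax-free (linear) attention layer applied to a regression prompt.

    Assumed prompt layout: columns are tokens [x_i; y_i], plus a final
    query token [x_query; 0] whose label slot the layer fills in.
    """
    n, d = X.shape
    Z = np.zeros((d + 1, n + 1))
    Z[:d, :n] = X.T          # covariates of the n in-context examples
    Z[:d, n] = x_query       # covariates of the query token
    Z[d, :n] = y             # labels; the query's label slot stays 0

    # Hypothetical weights: P reads only the label row, Q compares covariates via -A.
    P = np.zeros((d + 1, d + 1))
    P[d, d] = 1.0
    Q = np.zeros((d + 1, d + 1))
    Q[:d, :d] = -A

    # Mask so the query token does not contribute as a value.
    M = np.eye(n + 1)
    M[n, n] = 0.0

    # Residual update with linear attention, averaged over the n examples.
    Z_next = Z + (P @ Z @ M @ (Z.T @ Q @ Z)) / n

    # Assumed readout: negated label slot of the query token.
    return -Z_next[d, n]


def preconditioned_gd_prediction(X, y, x_query, A):
    """One step of preconditioned GD on the least-squares loss, starting from w = 0."""
    n, d = X.shape
    w = np.zeros(d)
    grad = X.T @ (X @ w - y) / n   # gradient of (1/2n) * sum_i (w.x_i - y_i)^2
    w = w - A @ grad               # preconditioned update with matrix A
    return w @ x_query


rng = np.random.default_rng(0)
n, d = 20, 5
X = rng.normal(size=(n, d))
w_star = rng.normal(size=d)
y = X @ w_star
x_query = rng.normal(size=d)

B = rng.normal(size=(d, d))
A = B @ B.T + np.eye(d)            # arbitrary symmetric positive definite preconditioner

print(linear_attention_prediction(X, y, x_query, A))
print(preconditioned_gd_prediction(X, y, x_query, A))
```

Under these assumptions the two printed values agree up to floating-point error, because both reduce to $\frac{1}{n}\sum_i y_i \, x_i^\top A \, x_{\text{query}}$ when `A` is symmetric. The sketch only checks that such weights exist; the result in the abstract concerns what training actually converges to.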
