Can Looped Transformers Learn to Implement Multi-step Gradient Descent for In-context Learning?

Paper · arXiv 2410.08292 · Published October 10, 2024

The remarkable capability of Transformers to do reasoning and few-shot learning, without any fine-tuning, is widely conjectured to stem from their ability to implicitly simulate multi-step algorithms, such as gradient descent, with their weights in a single forward pass. Recently, there has been progress in understanding this complex phenomenon from an expressivity point of view, by demonstrating that Transformers can express such multi-step algorithms. However, our knowledge about the more fundamental aspect of their learnability, beyond single-layer models, is very limited. In particular, can training Transformers enable convergence to algorithmic solutions? In this work we resolve this question for in-context linear regression with linear looped Transformers, a multi-layer model with weight sharing that is conjectured to have an inductive bias toward learning fixed-point iterative algorithms. More specifically, for this setting we show that the global minimizer of the population training loss implements multi-step preconditioned gradient descent, with a preconditioner that adapts to the data distribution. Furthermore, we show fast convergence of gradient flow on the regression loss, despite the non-convexity of the landscape, by proving a novel gradient dominance condition. To our knowledge, this is the first theoretical analysis of a multi-layer Transformer in this setting. We further validate our theoretical findings through synthetic experiments.
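To illustrate what a preconditioner that adapts to the data distribution can buy, here is a small sketch comparing plain and preconditioned multi-step gradient descent on a single in-context linear regression prompt. It rests on several assumptions that are ours, not the paper's: Gaussian inputs with a diagonal, ill-conditioned covariance, noiseless labels, a fixed ground-truth regressor `w_star`, plain GD with unit step size (matched to the largest covariance eigenvalue), and the exact inverse population covariance as preconditioner. The paper only states that the learned preconditioner is close to this inverse, so this is an illustration, not the trained model.

```python
import numpy as np

def precond_gd_predict(X, y, x_q, A, steps):
    """Run `steps` iterations of preconditioned GD, w <- w - A @ grad,
    on the in-context loss (1/2n)||Xw - y||^2 from w = 0, then predict at x_q."""
    n, d = X.shape
    w = np.zeros(d)
    for _ in range(steps):
        grad = X.T @ (X @ w - y) / n  # gradient of (1/2n)||Xw - y||^2
        w = w - A @ grad              # A = identity gives plain GD with step size 1
    return w @ x_q

rng = np.random.default_rng(0)
n, d, steps = 2000, 2, 3
sigma_diag = np.array([1.0, 0.01])                     # assumed ill-conditioned diagonal covariance
X = rng.standard_normal((n, d)) * np.sqrt(sigma_diag)  # rows x_i ~ N(0, diag(sigma_diag))
w_star = np.array([1.0, 1.0])                          # assumed ground-truth regressor
y = X @ w_star                                         # noiseless labels y_i = <w_star, x_i>
x_q = np.array([0.0, 1.0])                             # query along the low-variance direction
target = w_star @ x_q

err_plain = abs(precond_gd_predict(X, y, x_q, np.eye(d), steps) - target)
err_precond = abs(precond_gd_predict(X, y, x_q, np.diag(1.0 / sigma_diag), steps) - target)
print(f"plain GD error: {err_plain:.3f}, preconditioned GD error: {err_precond:.3f}")
```

With the inverse covariance as preconditioner, three steps nearly recover the target along the low-variance direction. Plain GD, whose step size must stay safe for the high-variance direction, barely moves along the low-variance one in the same number of steps.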

Transformers [Vaswani et al., 2017] have revolutionized the field of machine learning and have led to state-of-the-art models for various natural language and vision tasks. Large-scale Transformer models, notably large language models, have demonstrated remarkable capabilities to solve many difficult problems, including those requiring multi-step reasoning [Brown et al., 2020, Wei et al., 2022b]. One particularly appealing property is their few-shot learning ability, where the functionality and predictions of the model adapt to additional context provided in the input, without any update to the model weights. This ability, typically referred to as “in-context learning”, has been crucial to their success in various applications. Recently, there has been a surge of interest in understanding this phenomenon, particularly since Garg et al. [2022] empirically showed that Transformers can be trained to solve many in-context learning problems based on linear regression and decision trees. Motivated by this empirical success, Von Oswald et al. [2023] and Akyürek et al. [2022] theoretically showed the following intriguing expressivity result: multi-layer Transformers with linear self-attention can implement gradient descent for linear regression, where each layer of the Transformer implements one step of gradient descent. In other words, they hypothesize that the in-context learning ability results from the model approximating gradient-based few-shot learning within its forward pass. Panigrahi et al. [2023] further extended this result to more general model classes.
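Why one gradient step fits naturally into one linear self-attention layer can be seen from a short derivation. As an illustrative assumption, take the squared loss over the n in-context examples (x_i, y_i), a step size η, and the initialization w_0 = 0. The one-step prediction at a query x_q is then a softmax-free attention read-out, with the query x_q, the keys x_i, and the values y_i.

```latex
\begin{aligned}
L(w) &= \frac{1}{2n}\sum_{i=1}^{n}\bigl(w^\top x_i - y_i\bigr)^2,
\qquad
\nabla L(w) = \frac{1}{n}\sum_{i=1}^{n}\bigl(w^\top x_i - y_i\bigr)\,x_i, \\
w_1 &= w_0 - \eta\,\nabla L(w_0) = \frac{\eta}{n}\sum_{i=1}^{n} y_i\,x_i
\qquad (w_0 = 0), \\
\hat{y}_q &= w_1^\top x_q = \frac{\eta}{n}\sum_{i=1}^{n} y_i\,\bigl(x_i^\top x_q\bigr).
\end{aligned}
```

Each further step applies the same computation to the residuals y_i - w_k^\top x_i. This is why stacking such layers can emulate multi-step gradient descent.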

While such an approximation is interesting from the point of view of expressivity, it is unclear whether Transformer models can learn to implement such algorithms. To this end, Ahn et al. [2023] and Zhang et al. [2023] theoretically show, in a Gaussian linear regression setting, that the global minimizers of a one-layer model essentially simulate a single step of preconditioned gradient descent, and that gradient flow converges to this solution. Ahn et al. [2023] further show for the multi-layer case that a single step of gradient descent can be implemented by some stationary points of the loss. However, a fundamental characterization of all the stationary points of multi-layer Transformers, and the convergence to a stationary point that implements multi-step gradient descent, remain a challenging and important open question. In this work, we focus on the learnability of such multi-step algorithms by Transformer models. Instead of multi-layer models, we consider a closely related but different class of models called looped Transformers, where the same Transformer block is applied repeatedly, in a loop, to a given input. Since multi-layer models are expected to simulate an iterative procedure such as multi-step gradient descent, looped models are a natural choice for implementing it. There is growing interest in looped models, with recent results [Giannou et al., 2023] theoretically showing that the iterative nature of the looped Transformer model can be used to simulate a programmable computer, thus allowing looped models to solve problems that require arbitrarily long computations. Looped Transformer models are also conceptually appealing for learning iterative optimization procedures: sharing parameters across layers can, in principle, provide a better inductive bias than multi-layer Transformers for learning such procedures. In fact, by employing a regression loss at various levels of looping, Yang et al. [2023] empirically find that looped Transformer models can be trained to solve in-context learning problems, and that looping on an example for longer and longer at test time converges to a desirable fixed-point solution. This leads them to conjecture that looped models can learn to express iterative algorithms¹.
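The looped construction can be made concrete with a small numerical sketch. It assumes a linear self-attention layer of the form Z ← Z + (1/n) P Z M Zᵀ Q Z, in the style of the constructions cited above, acting on a prompt matrix whose columns are the tokens (x_i, y_i) plus a query token (x_q, 0). The mask M keeps the query from acting as a key. The weights are a hypothetical hand-picked choice: Q reads features through a symmetric matrix A, and P writes −1 times the attention output into the label row. Looping this single weight-shared layer L times reproduces L steps of gradient descent preconditioned by A and started from w = 0. This is an assumption-laden illustration, not necessarily the paper's exact model, and the function names are ours.

```python
import numpy as np

def lsa_layer(Z, P, Q):
    """One linear self-attention layer with a residual connection:
    Z <- Z + (1/n) P Z M (Z^T Q Z). The mask M zeroes the query column
    (the last one) so it does not contribute as a key/value."""
    n = Z.shape[1] - 1                   # number of in-context examples
    M = np.diag([1.0] * n + [0.0])
    return Z + P @ Z @ M @ (Z.T @ Q @ Z) / n

def looped_predict(X, y, x_q, A, loops):
    """Apply the SAME layer `loops` times; the prediction is minus the
    label slot of the query token."""
    n, d = X.shape
    Z = np.zeros((d + 1, n + 1))         # columns are tokens (x_i, y_i), then (x_q, 0)
    Z[:d, :n] = X.T
    Z[d, :n] = y
    Z[:d, n] = x_q
    Q = np.zeros((d + 1, d + 1))
    Q[:d, :d] = A                        # attention scores x_i^T A x_j (A must be symmetric)
    P = np.zeros((d + 1, d + 1))
    P[d, d] = -1.0                       # only the label row is updated
    for _ in range(loops):
        Z = lsa_layer(Z, P, Q)
    return -Z[d, n]

def precond_gd_predict(X, y, x_q, A, steps):
    """Reference: preconditioned GD on (1/2n)||Xw - y||^2 from w = 0."""
    n, d = X.shape
    w = np.zeros(d)
    for _ in range(steps):
        w = w + A @ X.T @ (y - X @ w) / n
    return w @ x_q

rng = np.random.default_rng(1)
n, d = 40, 5
X = rng.standard_normal((n, d))
w_star = rng.standard_normal(d)
y = X @ w_star
x_q = rng.standard_normal(d)
B = rng.standard_normal((d, d))
A = 0.3 * np.eye(d) + 0.02 * B @ B.T     # assumed symmetric positive definite preconditioner
for L in range(1, 6):
    print(L, looped_predict(X, y, x_q, A, L), precond_gd_predict(X, y, x_q, A, L))
```

The check works because of an invariant. After k loops, the label slot of each in-context token holds the residual y_i − w_kᵀx_i, and the query slot holds −w_kᵀx_q. Each pass of the shared layer therefore performs exactly one preconditioned gradient step. The symmetry of A is needed for the scores x_iᵀAx_j to match the gradient update.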

Despite these strong expressivity results for looped models and their empirically observed inductive bias toward simulating iterative algorithms, very little is known about the optimization landscape of looped models and their theoretical convergence to desirable, interpretable iterative procedures. In fact, a priori it is not clear why training should even succeed, given that looped models rely heavily on weight sharing and thus do not enjoy the well-studied optimization benefits of overparameterization [Buhai et al., 2020, Allen-Zhu et al., 2019]. In this work, we delve deeper into the problem of optimizing looped Transformers and theoretically study their landscape and convergence for in-context linear regression under the Gaussian data distribution setting used by Ahn et al. [2023] and Zhang et al. [2023]. In particular, the main contributions of our paper are as follows:

• We obtain a precise characterization of the global minimizer of the population loss for a linear looped Transformer model, and show that it indeed implements multi-step preconditioned gradient descent with a preconditioner close to the inverse of the population covariance matrix, as intuitively expected.

• Despite the non-convexity of the loss landscape, we prove convergence of gradient flow for in-context linear regression with looped Transformers. To our knowledge, ours is the first such convergence result in this setting for a network with more than one layer.

• To show this convergence, we prove that the loss satisfies a novel gradient-dominance condition, which guides the flow toward the global optimum. We expect this convergence proof to generalize to first-order iterative algorithms such as SGD with gradient estimates computed from a single random instance [De Sa et al., 2022].

• We further show that the small sub-optimality gap achieved by our convergence analysis translates into proximity of the parameters to the global minimizer of the loss.
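A gradient-dominance condition can yield convergence despite non-convexity, and the textbook argument shows why. The version given here is for the classical Polyak-Łojasiewicz form. The symbols f (the training loss), f^\star (its global minimum value), θ (the parameters), and μ > 0 (a constant) are introduced for illustration. The paper's novel condition is not spelled out in this section and may take a different form.

```latex
\begin{aligned}
\dot{\theta}(t) &= -\nabla f(\theta(t)), \\
\frac{d}{dt}\bigl[f(\theta(t)) - f^\star\bigr]
  &= \nabla f(\theta(t))^\top \dot{\theta}(t) = -\bigl\|\nabla f(\theta(t))\bigr\|^2 \\
  &\le -\mu\,\bigl[f(\theta(t)) - f^\star\bigr]
  \qquad \text{if } \|\nabla f(\theta)\|^2 \ge \mu\,\bigl(f(\theta) - f^\star\bigr), \\
\Longrightarrow\quad f(\theta(t)) - f^\star
  &\le e^{-\mu t}\,\bigl[f(\theta(0)) - f^\star\bigr] \qquad \text{(Grönwall's inequality)}.
\end{aligned}
```

The condition never requires convexity. It only asks that the gradient be large whenever the loss is far from optimal, and this is enough to rule out spurious stationary points along the flow.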