Generalization to New Sequential Decision Making Tasks with In-Context Learning

Paper · arXiv 2312.03801 · Published December 6, 2023
Tags: Reasoning Architectures · Tasks Planning · Flaws · Novel Architectures

Compared to (self-)supervised learning, the sequential decision making setting poses additional challenges: it has a lower tolerance for errors, since the environment's stochasticity or the agent's own actions can lead to unseen, and sometimes unrecoverable, states. In this paper, we use an illustrative example to show that naively applying transformers to sequential decision making problems does not enable in-context learning of new tasks. We then demonstrate how training on sequences of trajectories with certain distributional properties leads to in-context learning of new sequential decision making tasks.

In the (self-)supervised setting, by contrast, large transformers trained on vast amounts of data can learn new tasks from only a few examples, without any parameter updates (Brown et al., 2020; Kaplan et al., 2020; Olsson et al., 2022a; Chan et al., 2022a). This emergent phenomenon is called few-shot or in-context learning (ICL), and it is achieved by simply conditioning the model's outputs on a context containing a few examples of solving the task (Brown et al., 2020; Ganguli et al., 2022; Wei et al., 2022).
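
As a minimal illustration of this conditioning, the sketch below assembles a few-shot prompt for a toy sentiment task; the task, prompt format, and examples are our own, not the paper's. The resulting string would simply be fed to a frozen pretrained language model.

```python
# Minimal sketch of in-context learning: the task is specified entirely
# through the prompt, and no parameter updates are performed. The toy
# sentiment task and prompt format are illustrative assumptions.

few_shot_examples = [
    ("great movie, loved it", "positive"),
    ("boring and far too long", "negative"),
]
query = "a genuine delight from start to finish"

# Concatenate a few (input, label) demonstrations, then append the query.
context = "".join(f"Review: {x}\nSentiment: {y}\n\n" for x, y in few_shot_examples)
prompt = context + f"Review: {query}\nSentiment:"

# `prompt` would be passed to a frozen pretrained language model; its next
# token prediction acts as the answer, with the weights left untouched.
print(prompt)
```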

We first focus on the data distributional properties required to enable in-context learning of sequential decision making tasks. Our key finding is that in contrast to (self-)supervised learning where the context can simply contain a few different examples (or predictions), in sequential decision making it is crucial for the context to contain full/partial trajectories (or sequences of predictions) to cover the potentially wide range of states the agent may find itself in at deployment. Translating this insight into a dataset construction pipeline, we demonstrate that we can enable in-context learning of unseen tasks on both the MiniHack and Procgen benchmarks from only a handful of demonstrations.
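
To make the dataset construction concrete, here is a minimal sketch of how one training sequence could be assembled; the function name, the `(observation, action)` trajectory layout, and the `level_to_trajs` mapping are our own illustrative choices, not the paper's actual pipeline.

```python
import random

def build_training_sequence(level_to_trajs, level, num_context_trajs=2):
    """Assemble one training sequence: a few full demonstration trajectories
    from `level`, followed by the query trajectory whose actions the model
    learns to predict. Each trajectory is assumed to be a list of
    (observation, action) pairs -- a hypothetical layout for illustration.
    """
    trajs = random.sample(level_to_trajs[level], num_context_trajs + 1)
    context_trajs, query_traj = trajs[:-1], trajs[-1]

    tokens = []
    # Flatten the full demonstration trajectories into the context...
    for traj in context_trajs:
        tokens.extend(traj)
    # ...then append the query trajectory. The loss would be applied to the
    # query actions, conditioned on everything that precedes them.
    query_start = len(tokens)
    tokens.extend(query_traj)
    return tokens, query_start
```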

Burstiness was first introduced by Chan et al. (2022a) to describe a dataset sample whose context contains examples that are similar to the query (for example, from the same class). Here, we extend this idea to trajectory burstiness, which we define as the probability that a given input sequence contains at least two trajectories from the same level (Figure 2). This ensures that the information needed to solve the task is contained entirely within the sequence, which is why this design aids in-context learning. Note, however, that even trajectories from the same level may vary because of the environment's inherent stochasticity, so the model must still generalize to handle these cases. These trajectories can be viewed as the sequential decision making equivalent of few-shot examples in supervised learning.
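
The definition above translates directly into a sampling rule. Below is a minimal sketch, where `p_bursty` is the trajectory-burstiness probability; all names and the dictionary layout are hypothetical rather than taken from the paper's code.

```python
import random

def sample_bursty_sequence(level_to_trajs, p_bursty=0.9, num_context_trajs=2):
    """With probability `p_bursty`, the context trajectories come from the
    same level as the query (a "bursty" sequence); otherwise they are drawn
    from other levels. Illustrative sketch of the stated definition.
    """
    levels = list(level_to_trajs)
    query_level = random.choice(levels)
    query = random.choice(level_to_trajs[query_level])

    if random.random() < p_bursty:
        # Bursty: context shares the query's level. The trajectories can
        # still differ because of the environment's inherent stochasticity.
        pool = [t for t in level_to_trajs[query_level] if t is not query]
        context = random.sample(pool, num_context_trajs)
    else:
        # Non-bursty: context trajectories come from different levels.
        other_levels = [l for l in levels if l != query_level]
        context = [random.choice(level_to_trajs[random.choice(other_levels)])
                   for _ in range(num_context_trajs)]
    return context + [query]  # the query trajectory always comes last
```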

In-Context Learning with Transformers: In-context learning (ICL), a term first coined by Brown et al. (2020), is a phenomenon where a model learns a completely new task simply by conditioning on a few examples, without any weight updates. Many works have since studied this phenomenon (von Oswald et al., 2022; Chan et al., 2022b; Akyürek et al., 2022; Olsson et al., 2022b; Hahn and Goyal, 2023; Xie et al., 2021; Dai et al., 2022). For example, Chan et al. (2022a) analyze what makes large language models perform well on few-shot learning tasks through the lens of data properties. Their findings suggest that ICL naturally arises when the data follows a Zipfian (power-law) distribution and exhibits inherent "burstiness." Kirsch et al. (2022) demonstrate that ICL emerges as a function of the number of tasks and model size, with a clear phase transition between instance memorization, task memorization, and generalization. More recently, Garg et al. (2022) show that standard transformers can learn entire function classes in context, such as linear functions, sparse linear functions, and two-layer MLPs.
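
For intuition about the Zipfian property, a class sampler of the kind studied by Chan et al. (2022a) can be sketched in a few lines; the parameter names and defaults below are our own.

```python
import numpy as np

def sample_zipfian_classes(num_classes=1000, alpha=1.0, size=8, rng=None):
    """Sample class indices whose frequencies follow a Zipfian (power-law)
    distribution: the class of rank r appears with probability proportional
    to r**(-alpha). Illustrative sketch of the data property linked to ICL.
    """
    rng = rng or np.random.default_rng()
    ranks = np.arange(1, num_classes + 1)
    probs = ranks ** (-float(alpha))
    probs /= probs.sum()
    return rng.choice(num_classes, size=size, p=probs)
```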

The Decision Transformer (DT) (Chen et al., 2021) was one of the first works to treat policy optimization as a sequence modelling problem and to train transformer policies conditioned on the episode history and future return. Inspired by DT, the Multi-Game Decision Transformer (MGDT; Lee et al., 2022) trains a single-trajectory transformer to solve multiple Atari games, but its generalization capabilities are limited without additional fine-tuning on the unseen tasks. In contrast with our work, they do not provide the model with additional demonstrations of the new task at test time. Closer to our work, Prompt-DT (Xu et al., 2022) shows that DTs conditioned on an explicit prompt exhibit few-shot learning of new tasks (i.e., with different reward functions) in the same environment (i.e., with the same states and dynamics). However, the test tasks they consider differ only slightly from the training ones: for example, the agent has to walk in a different direction, so the reward function changes, but the states, actions, and dynamics remain the same. In contrast, our train and test tasks differ greatly, e.g., playing a platform game versus navigating a maze, with entirely new states, actions, dynamics, and reward functions. Similarly, Melo (2022) shows that transformers are meta-reinforcement learners, but they require access to rewards and many demonstrations to solve new tasks at test time. While Team et al. (2023b) demonstrate few-shot online learning of new tasks, their model is trained on billions of tasks, whereas we consider the few-shot offline learning setting and train our model on only a handful of tasks. Another related work is Algorithmic Distillation (AD; Laskin et al., 2022), which aims to learn a policy improvement operator by training a transformer on sequences of trajectories with increasing returns. However, AD assumes access to multiple model checkpoints with different proficiency levels and has not demonstrated generalization to entirely new tasks with different states, actions, dynamics, and rewards. In contrast with all these works, our goal is to generalize to completely new tasks with different states, actions, dynamics, and rewards from a handful of expert demonstrations. We demonstrate cross-task generalization on MiniHack and Procgen, two domains that contain vastly different tasks, from maze navigation to platform games. Our work is also the first to extensively study how different factors (such as task diversity, trajectory burstiness, environment stochasticity, and model and dataset size) influence the emergence of in-context learning in sequential decision making.

We find that a key ingredient during pretraining is to include in the context entire trajectories that belong to the same environment level as the query trajectory, a property we call trajectory burstiness. In addition, we find that larger model and dataset sizes, as well as more task diversity, environment stochasticity, and trajectory burstiness, all result in better few-shot learning of out-of-distribution tasks.
