Stream of Search (SoS): Learning to Search in Language

Paper · arXiv 2404.03683 · Published April 1, 2024

Language models are rarely shown fruitful mistakes while training. They then struggle to look beyond the next token, suffering from a snowballing of errors and struggling to predict the consequence of their actions several steps ahead. In this paper, we show how language models can be taught to search by representing the process of search in language, as a flattened string—a stream of search (SoS). We propose a unified language for search that captures an array of different symbolic search strategies. We demonstrate our approach using the simple yet difficult game of Countdown, where the goal is to combine input numbers with arithmetic operations to reach a target number. We pretrain a transformer-based language model from scratch on a dataset of streams of search generated by heuristic solvers. We find that SoS pretraining increases search accuracy by 25% over models trained to predict only the optimal search trajectory.

Imagine only ever seeing the right solutions to problems, never a mistake or a recovery from one. You might learn that problems must be solved in one clean pass, rather than through exploration and error. Most data used to train language models (LMs) reflects only the outcome of a decision-making process, not the process itself. LMs never learn to make mistakes. They never learn to search, plan, or backtrack. Yet complex decision-making and reasoning require search. In this paper we explore the impact of training an LM on the search process, including mistakes, and then allowing it to self-improve.

If language models can learn to search during training, they may be able to discover more flexible search strategies through self-improvement. This could lead to models that are better equipped to handle the challenges posed by error compounding and lookahead tasks.

In this paper, we demonstrate that language models can be taught to search and backtrack in language, representing the process as a serialized string, a Stream of Search (SoS). We systematize the different components of search, such as exploration, backtracking, and pruning, in a unified language. We instantiate this unified language in the context of a search problem inspired by the game of Countdown, a generalized version of the game of 24 (Yao et al., 2024; Countdown, 2024). A problem consists of input numbers and a target number. The goal is to combine the input numbers with arithmetic operations to reach the target (see Fig. 1a). Countdown presents a challenging search problem because of its high branching factor and the need to navigate the combinatorial search space efficiently towards the target number.
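To make the idea of a serialized search trace concrete, here is a minimal sketch of a Countdown solver that records every step it takes, dead ends included, as one flat string. It is an illustration only: the trace wording (`Try`, `backtrack`, `Goal reached`), the function names, and the use of plain depth-first search without a heuristic are assumptions made for this example, not the paper's actual search language or solvers.

```python
from itertools import combinations


def successors(nums):
    """Yield (step_text, remaining_numbers) for each way to combine two numbers."""
    for i, j in combinations(range(len(nums)), 2):
        a, b = max(nums[i], nums[j]), min(nums[i], nums[j])
        rest = [n for k, n in enumerate(nums) if k not in (i, j)]
        candidates = [(f"{a}+{b}", a + b), (f"{a}*{b}", a * b), (f"{a}-{b}", a - b)]
        if b != 0 and a % b == 0:  # only allow exact integer division
            candidates.append((f"{a}/{b}", a // b))
        for expr, value in candidates:
            yield f"{expr}={value}", rest + [value]


def dfs_stream(nums, target, trace):
    """Depth-first search that appends every step, including failures, to trace."""
    if len(nums) == 1:
        if nums[0] == target:
            trace.append("Goal reached")
            return True
        trace.append(f"{nums[0]} != {target}, backtrack")
        return False
    for step, remaining in successors(nums):
        trace.append(f"Try {step}, left: {remaining}")
        if dfs_stream(remaining, target, trace):
            return True
    return False


def stream_of_search(nums, target):
    """Return (solved, trace) where trace is the whole search as one string."""
    trace = [f"Start: {list(nums)}, target: {target}"]
    solved = dfs_stream(list(nums), target, trace)
    return solved, "\n".join(trace)


if __name__ == "__main__":
    solved, text = stream_of_search([1, 2, 3], 7)
    print(text)
```

The resulting string contains the failed branches as well as the final path, which is the kind of data a model trained only on optimal trajectories would never see.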

In existing methods that pair LMs with symbolic search, such as tree-of-thoughts style approaches, LMs typically play two roles: (1) to generate candidate actions or successor states in the reasoning process, and (2) to evaluate proposed actions or states, by determining validity and/or assigning a heuristic value. A symbolic search algorithm such as BFS or DFS dictates the strategy for exploration and how the steps or evaluators are called (Yao et al., 2024; Besta et al., 2023). While these methods have been shown to improve search accuracy on certain problems, the LM components are typically used only for inference, so their reasoning ability is not improved. In contrast, our work focuses on training LMs that are capable of exploration, backtracking, and other critical components of reasoning. Relative to these “extrinsic” methods, which use fixed search strategies, our method learns an “intrinsic” policy that allows the LM to autonomously search the solution space. In doing so, we avoid the high inference costs (Sel et al., 2023) required by tree-of-thoughts style approaches.
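The division of labour in the extrinsic setup can be sketched as follows. This is a hypothetical illustration, not code from any of the cited systems: `propose` and `evaluate` stand in for the two LM roles described above, the breadth-first loop with a fixed beam width is the hand-written symbolic controller, and the toy integer problem exists only so the sketch runs without a model.

```python
def extrinsic_search(start, propose, evaluate, is_goal, beam_width=3, max_depth=6):
    """Breadth-first search with a beam; the strategy is fixed in code.

    propose(state)  -> list of successor states   (LM role 1, assumed callable)
    evaluate(state) -> heuristic score, higher is better (LM role 2, assumed callable)
    Returns (goal_state_or_None, number_of_model_calls).
    """
    frontier = [start]
    model_calls = 0
    for _ in range(max_depth):
        scored = []
        for state in frontier:
            model_calls += 1  # one generation call per expanded state
            for nxt in propose(state):
                if is_goal(nxt):
                    return nxt, model_calls
                model_calls += 1  # one evaluation call per candidate
                scored.append((evaluate(nxt), nxt))
        if not scored:
            break
        # The controller, not the model, decides which states survive.
        scored.sort(key=lambda pair: pair[0], reverse=True)
        frontier = [state for _, state in scored[:beam_width]]
    return None, model_calls


# Toy stand-ins for the LM calls: reach 10 from 1 using "+1" or "*2".
TARGET = 10
toy_propose = lambda n: [n + 1, n * 2]
toy_evaluate = lambda n: -abs(TARGET - n)
toy_is_goal = lambda n: n == TARGET

if __name__ == "__main__":
    print(extrinsic_search(1, toy_propose, toy_evaluate, toy_is_goal))
```

Every expansion and every scoring step is a separate model call, which is where the inference cost of this style of search comes from. An SoS model instead produces the whole exploration, including backtracking, in a single generated string.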

We have introduced the Stream of Search (SoS) framework enabling language models to learn to solve problems by searching in language, without any external structure or components. By systematizing the elements of search into a unified language, we are able to represent various search strategies in a common format to construct a dataset with diverse streams of search. Our experiments demonstrate that training language models to search leads to superior performance compared to models trained solely on optimal trajectories. This highlights the importance of exposing models to the messy process of problem solving, with exploration and backtracking, instead of only the ideal solution steps. SoS models can then self-improve by optimizing for correctness, using STaR and APA.
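Optimizing for correctness presupposes a way to tell correct trajectories from incorrect ones. The sketch below shows one round of a STaR-style filtering step under simplifying assumptions: `sample` stands in for drawing a trajectory from the current model, `is_correct` stands in for a verifier, and the fine-tuning step that would follow is omitted. All names and the toy addition task are hypothetical and chosen only for illustration.

```python
def filter_round(problems, sample, is_correct, samples_per_problem=4):
    """Collect one correct trajectory per problem, if any sample is correct.

    The returned (problem, trajectory) pairs would form the fine-tuning set
    for the next iteration; the training itself is not shown here.
    """
    kept = []
    for problem in problems:
        for _ in range(samples_per_problem):
            trajectory = sample(problem)
            if is_correct(problem, trajectory):
                kept.append((problem, trajectory))
                break  # keep at most one trajectory per problem
    return kept


def make_toy_sampler():
    """Deterministic stand-in for a model: wrong on the first try, right afterwards."""
    attempts = {}

    def sample(problem):
        a, b = problem
        attempts[problem] = attempts.get(problem, 0) + 1
        return a + b if attempts[problem] > 1 else a + b + 1

    return sample


def toy_is_correct(problem, answer):
    """Verifier for the toy task: the answer must equal the sum of the pair."""
    a, b = problem
    return answer == a + b


if __name__ == "__main__":
    toy_problems = [(1, 2), (3, 4), (5, 6)]
    print(filter_round(toy_problems, make_toy_sampler(), toy_is_correct))
```

Only verified trajectories survive the filter, so repeated rounds push the model towards search behaviour that actually reaches the target.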

The SoS framework may address criticisms (LeCun, 2023; Bachmann & Nagarajan, 2024) of language models for planning and problem solving. The problem of snowballing errors is addressed by teaching a model to backtrack. Search allows models to explore alternative paths, overcoming failures in lookahead tasks by considering multiple possible outcomes before committing to a course of action. Crucially, SoS leads language models to learn an internal “world model” for search. Unlike symbolic search, which relies on an explicit environment model, SoS models simulate state transitions themselves. Using a learned world model allows more adaptable and generalizable search (cf. Schrittwieser et al., 2020) and addresses a key criticism of pretrained LMs.