Quiet-STaR: Language Models Can Teach Themselves to Think Before Speaking

Paper · arXiv 2403.09629 · Published March 14, 2024
Reasoning by Reflection

https://arxiv.org/abs/2403.09629

Reasoning is implicit in almost all written text: for example, in the steps left unstated between the lines of a proof, or in the theory of mind underlying a conversation. In the Self-Taught Reasoner (STaR, Zelikman et al., 2022), useful thinking is learned by inferring rationales from few-shot examples in question answering and learning from those that lead to a correct answer. This is a highly constrained setting: ideally, a language model could instead learn to infer unstated rationales in arbitrary text. We present Quiet-STaR, a generalization of STaR in which LMs learn to generate rationales at each token to explain future text, improving their predictions. We address key challenges, including 1) the computational cost of generating continuations, 2) the fact that the LM does not initially know how to generate or use internal thoughts, and 3) the need to predict beyond individual next tokens.

Reasoning about the implications of text to predict later text has consistently been shown to improve LM performance on a variety of tasks, but methods for allowing LMs to learn from their reasoning (e.g., Zelikman et al., 2022) have focused on solving individual tasks or predefined sets of tasks (e.g., Wei et al., 2021b).

Quiet-STaR proceeds by generating rationales after every token to explain future text (think), mixing the future-text predictions with and without rationales (talk), and then learning to generate better rationales using REINFORCE (learn).
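The "learn" step can be illustrated with a minimal sketch of the REINFORCE signal at a single position. It assumes several thoughts are sampled at that position, each thought is scored by the log-likelihood (under the mixed "talk" prediction) of the next few true tokens, and the reward is that score minus the mean score across the sampled thoughts. Summing over several future tokens reflects the need to predict beyond the single next token. The function name and input layout are hypothetical, and plain floats stand in for the autodiff tensors a real implementation would use.

```python
from statistics import mean


def quiet_star_reinforce(future_token_logprobs, thought_logprobs):
    """Hypothetical per-position REINFORCE terms for a batch of sampled thoughts.

    future_token_logprobs[k]: list of log p(true future token i | context, thought k)
        for the next few true tokens (assumed to come from the mixed prediction).
    thought_logprobs[k]: log-probability the LM assigned to generating thought k.
    Returns (rewards, loss).
    """
    # Score each thought by how well it helps predict several future tokens.
    scores = [sum(lps) for lps in future_token_logprobs]
    # Baseline: the mean score over the thoughts sampled at this position.
    baseline = mean(scores)
    rewards = [s - baseline for s in scores]
    # REINFORCE surrogate: rewards are treated as constants, so differentiating
    # this loss w.r.t. model parameters gives -r * grad log p(thought), averaged.
    loss = -mean(r * lp for r, lp in zip(rewards, thought_logprobs))
    return rewards, loss


rewards, loss = quiet_star_reinforce([[-1.0, -1.0], [-2.0, -2.0]], [-0.5, -2.0])
```

Here the first thought makes the true continuation more likely than average, so it receives a positive reward and minimizing the loss pushes its generation probability up, while the second thought is pushed down.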

We introduce custom meta-tokens at the start and end of each thought, allowing the LM to learn when it should be generating a rationale and when it should make a prediction based on that rationale.
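As a minimal sketch of how such meta-tokens frame a thought, the helper below builds, for every position in a token sequence, the context the LM would see after thinking at that position: the text so far, a start-of-thought token, the thought, and an end-of-thought token, after which the model predicts the next real token. The token strings and function names are illustrative assumptions, not the paper's actual vocabulary entries or code, and the sketch ignores the efficiency question of generating all these thoughts at once.

```python
# Illustrative meta-token strings (assumed; the real vocabulary entries may differ).
START_THOUGHT = "<|startofthought|>"
END_THOUGHT = "<|endofthought|>"


def wrap_thought(prefix, thought):
    """Context seen when predicting the next real token after a thought."""
    return list(prefix) + [START_THOUGHT] + list(thought) + [END_THOUGHT]


def thought_contexts(tokens, thoughts):
    """One thought per position: thoughts[j] is inserted after tokens[: j + 1]."""
    if len(thoughts) != len(tokens):
        raise ValueError("expected exactly one thought per token position")
    return [wrap_thought(tokens[: j + 1], thoughts[j]) for j in range(len(tokens))]


contexts = thought_contexts(["2", "+", "2"], [["a"], ["b"], ["c"]])
```

Because the end-of-thought token always precedes the next real token, the model can learn that anything between the two meta-tokens is internal reasoning rather than text it is expected to reproduce.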

We also apply a mixing head to retrospectively determine how much to incorporate the next-token prediction from a given thought into the current next-token prediction.
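The mixing step can be sketched as a learned scalar weight that interpolates between the prediction made without a thought and the prediction made after it. This is a minimal sketch under stated assumptions: the head is a hypothetical single linear layer with a sigmoid over the concatenated hidden states from before and after the thought, and the two predictions are mixed as log-probabilities and then renormalized. The exact head architecture and mixing space used in the paper may differ.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def mixing_head_weight(h_base, h_thought, weights, bias):
    """Hypothetical one-layer head: sigmoid(w . [h_base; h_thought] + b) in (0, 1)."""
    features = list(h_base) + list(h_thought)
    return sigmoid(sum(w * f for w, f in zip(weights, features)) + bias)


def mix_predictions(base_logits, thought_logits, w):
    """Interpolate the two log-probability vectors with weight w on the thought.

    The interpolated values are renormalized with log_softmax so the result is
    again a valid log-distribution (an assumption of this sketch).
    """
    base_lp = log_softmax(base_logits)
    thought_lp = log_softmax(thought_logits)
    mixed = [(1.0 - w) * b + w * t for b, t in zip(base_lp, thought_lp)]
    return log_softmax(mixed)


w = mixing_head_weight([0.2, -0.1], [0.4, 0.3], [0.5, 0.5, -0.5, 1.0], 0.0)
mixed = mix_predictions([1.0, 2.0, 3.0], [3.0, 0.0, -1.0], w)
```

With a weight near 0 the model falls back to its ordinary next-token prediction, which is one way to let an untrained thought be ignored rather than hurt predictions; a weight near 1 relies fully on the post-thought prediction.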

In particular, the Self-Taught Reasoner (STaR, Zelikman et al., 2022) showed that LMs can bootstrap their reasoning ability on question-answering (QA) datasets by sampling rationales to attempt to answer questions, training on rationales if they led to a correct final answer, and then repeating this to iteratively solve more difficult problems. Yet training from curated QA datasets limits the scale and generalizability of the rationales. QA datasets, especially high-quality ones, require thoughtful curation and will inherently only ever cover a subset of reasoning tasks. Thus, we extend STaR: instead of the LM learning to reason on particular tasks like mathematical QA, we train an LM to generate reasoning that helps it infer future text from a large internet text corpus. As a result, we allow the LM to learn from the diverse tasks present in language (Weber et al., 2021). This builds on an intuition essential to the current language modeling paradigm, namely, that “language models are unsupervised multitask learners” (Radford et al., 2019). As in STaR, we leverage the LM’s pre-existing reasoning ability to generate rationales and train the LM on them with a REINFORCE-based reward (Williams, 1992). We refer to this technique as Quiet-STaR, as it can be understood as applying STaR “quietly”, training the model to think before it speaks.
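The STaR cycle described above (sample rationales, keep those that reach the correct answer, train on them, repeat) can be sketched with the model abstracted away. In this simplified sketch, `sample` and `finetune` are hypothetical callables standing in for rationale sampling and fine-tuning; it shows only the filter-train-repeat structure, not the details of the original method. Quiet-STaR replaces the correct-answer filter with a reward for predicting future text.

```python
from typing import Callable, List, Tuple

# Hypothetical interfaces: a sampler maps a question to (rationale, answer),
# and a fine-tuning routine maps the kept examples to an updated sampler.
Sampler = Callable[[str], Tuple[str, str]]
Example = Tuple[str, str, str]  # (question, rationale, answer)


def star_round(problems: List[Tuple[str, str]], sample: Sampler) -> List[Example]:
    """One STaR round: keep only rationales whose final answer is correct."""
    kept = []
    for question, gold in problems:
        rationale, answer = sample(question)
        if answer == gold:
            kept.append((question, rationale, answer))
    return kept


def star(
    problems: List[Tuple[str, str]],
    sample: Sampler,
    finetune: Callable[[List[Example]], Sampler],
    n_rounds: int,
) -> Sampler:
    """Repeat filter-then-train for n_rounds, returning the final sampler."""
    for _ in range(n_rounds):
        kept = star_round(problems, sample)
        sample = finetune(kept)
    return sample
```

In practice, `sample` would prompt the LM with few-shot rationale examples and `finetune` would train it on the kept examples; the sketch only fixes the control flow between them.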

We generalize STaR to learn reasoning from diverse unstructured text data. To our knowledge, this is the first work explicitly training LMs to reason generally from text, rather than on curated reasoning tasks or collections of reasoning tasks.

There have been many works on training and exploiting language models to solve difficult tasks by first training them to reason through those tasks. For example, Rajani et al. (2019) demonstrated that a pre-trained language model fine-tuned to output human reasoning traces before answering multiple-choice commonsense reasoning questions outperformed one trained directly on answers. Shwartz et al. (2020) demonstrated that language models, when provided with some scaffolding, can generate these helpful chain-of-thought solutions without additional supervision.

Another direction for teaching reasoning relies on a language model’s own generated reasoning, which can be seen as building on a large body of literature on self-play (Silver et al., 2017; Anthony et al., 2017; Polu & Sutskever, 2020). These include methods such as the Self-Taught Reasoner (Zelikman et al., 2022), which demonstrated that a language model iteratively trained on its reasoning that led to correct answers could solve increasingly difficult problems.

Recently, a growing body of work has demonstrated the usefulness of custom tokens optimized to perform specific functions in the context of a neural network; for this reason, they have also been referred to as “function vectors” (Todd et al., 2023). One of the original instantiations of this was prompt-tuning (Lester et al., 2021) (and, relatedly, prefix-tuning (Li & Liang, 2021)), where the embeddings corresponding to the tokens of a prompt could be optimized to better accomplish a task. Others have applied meta-tokens to compress long prompts for efficiency (Li et al., 2023; Jung & Kim, 2023). Most relevant to this work, Mu et al. (2024) optimized a token such that, when the tokens after it could not attend to the tokens before it (i.e., a context compression token), it would provide sufficient information to future tokens. Although we do not focus on compression, we share the problem of learning a token that affects attention and controls complex downstream behavior. In one related work, Goyal et al. (2023) show that learning a single “pause” token (essentially representing each token as two tokens) improves LM performance. However, unlike the thought tokens in our work, this pause token does not initialize a thought; instead, it can be seen as acting as the entirety of the thought. We find that reasoning in language is significantly more helpful.

https://www.reddit.com/r/LocalLLaMA/comments/1bfifi2/quietstar_language_models_can_teach_themselves_to/