Reverse Thinking Makes LLMs Stronger Reasoners

Paper · arXiv 2411.19865 · Published November 29, 2024
Reasoning Architectures · Training · Fine-Tuning

Reverse thinking plays a crucial role in human reasoning. Humans can reason not only from a problem to a solution but also in reverse, i.e., starting from the solution and reasoning towards the problem. This often enhances overall reasoning performance, as it enables consistency checks between forward and backward thinking. To enable Large Language Models (LLMs) to perform reverse thinking, we introduce Reverse-Enhanced Thinking (REVTHINK), a framework composed of data augmentation and learning objectives. In REVTHINK, we augment the dataset by collecting structured forward-backward reasoning from a teacher model, consisting of: (1) the original question, (2) forward reasoning, (3) backward question, and (4) backward reasoning. We then employ three objectives to train a smaller student model in a multi-task learning fashion: (a) generate forward reasoning from a question, (b) generate a backward question from a question, and (c) generate backward reasoning from the backward question. Experiments across 12 datasets covering commonsense, math, and logical reasoning show an average 13.53% improvement over the student model's zero-shot performance and a 6.84% improvement over the strongest knowledge-distillation baselines.

An effective way to improve test scores is to reason both forward and backward. In forward reasoning, we begin with the question and work step by step toward an answer. Reverse thinking, on the other hand, starts from the predicted answer and works backward to the original question. This two-way approach lets us verify the accuracy of a solution and identify potential errors. Consider a simple math problem: Emma has two apples, and Jack has three. How many do they have together? Forward reasoning leads to the calculation 2 + 3 = 5. Using reverse reasoning, we start from the conclusion that they have five apples. If Emma has two, we can ask: how many does Jack have? The result is three, which matches the original statement and confirms the forward answer.
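To make the consistency check concrete, here is a toy Python sketch of the forward computation and the backward check on the apple problem (purely illustrative, not part of the paper's pipeline):

```python
# Toy forward/backward consistency check on the apple problem.
emma, jack = 2, 3

# Forward reasoning: question -> answer.
total = emma + jack            # 2 + 3 = 5

# Backward reasoning: start from the predicted answer and re-derive a
# quantity given in the question, then check that it matches.
jack_recovered = total - emma  # 5 - 2 = 3
assert jack_recovered == jack  # consistent, so the forward answer checks out
```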

This raises a first question: can reverse thinking be applied to broader, less structured domains? Moreover, prior backward-reasoning methods operate at test time, serving a verification purpose: given a solution, we can ask the LLM to think backward and check whether the forward reasoning is correct. While they show moderate improvements over other test-time methods such as Self-Consistency (Wang et al., 2022), this prompts a second question: instead of using backward reasoning for verification at test time, can we train a model to inherently think backward, thereby improving its forward reasoning?
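For concreteness, here is a hedged sketch of what test-time backward verification looks like, i.e., the baseline use of backward reasoning that REVTHINK moves into training. Here `llm` stands for any text-completion callable, and the prompts are illustrative assumptions rather than those of any specific method:

```python
def backward_verify(llm, question: str, candidate_solution: str) -> bool:
    """Ask the model to reason backward from a candidate solution and
    judge whether the reconstruction is consistent with the question."""
    backward_prompt = (
        "Given the answer below, work backward step by step to recover a "
        "quantity stated in the original question.\n"
        f"Question: {question}\nCandidate answer: {candidate_solution}"
    )
    reconstruction = llm(backward_prompt)

    check_prompt = (
        f"Original question: {question}\n"
        f"Backward reconstruction: {reconstruction}\n"
        "Is the reconstruction consistent with the question? Answer yes or no."
    )
    return llm(check_prompt).strip().lower().startswith("yes")
```

If the check fails, the forward solution is flagged as suspect. Note that this scheme spends extra model calls at inference, which is exactly the overhead REVTHINK avoids by internalizing backward reasoning during training.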

We begin by augmenting the dataset using a larger, more capable teacher model. Reasoning benchmark data typically consists of a question and an answer. We extend this by generating (1) forward reasoning, (2) a backward question, and (3) backward reasoning, all through few-shot prompting of the teacher model. Both the forward and backward reasoning take the form of Chain-of-Thought (Wei et al., 2022). We retain only those data points where the forward reasoning is accurate (verified against ground truth) and where the backward reasoning aligns with the original question (validated by prompting the teacher model).

After augmenting the dataset, we propose three key objectives for training the smaller student model. Specifically, the student learns to: (1) generate correct forward reasoning from the question, (2) generate a backward question from the original question, and (3) generate backward reasoning from the backward question. The rationale for these objectives is threefold. First, generating correct reasoning from a question is a standard method of knowledge distillation (Li et al., 2023a; West et al., 2022). Second, producing a reverse question encourages the student model to "think" about how to invert a problem and determine the right question to ask. Lastly, solving the backward question reinforces the student's ability to reason backward. At test time, the student model is prompted with the question and generates only forward reasoning, as in standard zero-shot inference. In essence, our pipeline internalizes the ability to reason backward during training, while keeping test-time computation as efficient as zero-shot approaches.
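A minimal Python sketch of the filtering step and the three training pairs, assuming the teacher outputs have already been collected; the names (`AugmentedExample`, `make_training_pairs`) are ours for illustration, not the authors' code:

```python
from dataclasses import dataclass
from typing import List, Tuple

@dataclass
class AugmentedExample:
    question: str               # (1) original question Q
    forward_reasoning: str      # (2) teacher CoT for Q
    backward_question: str      # (3) teacher-generated inverse question Q'
    backward_reasoning: str     # (4) teacher CoT for Q'
    forward_correct: bool       # forward answer matched ground truth
    backward_consistent: bool   # teacher judged Q' reasoning consistent with Q

def filter_examples(data: List[AugmentedExample]) -> List[AugmentedExample]:
    """Keep only examples whose forward reasoning is verified correct and
    whose backward reasoning is consistent with the original question."""
    return [ex for ex in data if ex.forward_correct and ex.backward_consistent]

def make_training_pairs(ex: AugmentedExample) -> List[Tuple[str, str]]:
    """Expand one example into the three (input, target) pairs used for
    multi-task fine-tuning of the student."""
    return [
        (ex.question, ex.forward_reasoning),            # (a) Q  -> forward CoT
        (ex.question, ex.backward_question),            # (b) Q  -> backward question
        (ex.backward_question, ex.backward_reasoning),  # (c) Q' -> backward CoT
    ]
```

Each pair can then be fed to a standard language-modeling loss; at inference only the format of pair (a) is exercised, which keeps the cost equal to zero-shot prompting.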

Reasoning with LLMs. A large body of research has shown that LLM reasoning can be improved via advanced test-time approaches, such as prompting and aggregation. Representative methods include Chain-of-Thought (CoT) prompting (Kojima et al., 2022; Wei et al., 2022), Self-Consistency (Wang et al., 2022), Tree-of-Thought prompting (Yao et al., 2024), Self-Reflection (Shinn et al., 2024; Madaan et al., 2024; Yao et al., 2022), and multi-agent collaboration (Du et al., 2023; Liang et al., 2023; Wang et al., 2023; Lu et al., 2024; Feng et al., 2024; Chen et al., 2023). Several works leverage backward reasoning to verify a chain of thought and improve math reasoning (Weng et al., 2022; Jiang et al., 2024). While effective, these methods operate at test time and show only moderate improvements over other test-time methods such as Self-Consistency (Wang et al., 2022). They have also mostly been developed for mathematical tasks, limiting their generalizability.

Knowledge distillation is an effective way to transfer knowledge from a larger teacher model to a smaller student model.

Additionally, teacher-model outputs can be used to augment ground-truth data (Ding et al., 2024). Our method aligns with this recent trend, leveraging the teacher model to generate CoT reasoning, along with backward questions and backward reasoning, to augment the data.

We focus on the mutually inverse relationship between a question and its backward counterpart. In our reasoning tasks, backward questions and backward reasoning are often absent and must be generated by LLMs. Our innovation lies in establishing connections between forward questions and forward reasoning, and between backward questions and backward reasoning, thereby exploiting the consistency of these connections within our training objectives.