RL-STaR: Theoretical Analysis of Reinforcement Learning Frameworks for Self-Taught Reasoner

Paper · arXiv 2410.23912 · Published October 31, 2024
Tags: Self Refinement · Self Consistency · Feedback · Reasoning · o1 · o3 · Search

The reasoning abilities of large language models (LLMs) have improved with chain-of-thought (CoT) prompting, allowing models to solve complex tasks in a stepwise manner. However, training CoT capabilities requires detailed reasoning data, which is often scarce. The self-taught reasoner (STaR) framework addresses this by using reinforcement learning to automatically generate reasoning steps, reducing reliance on human-labeled data. Although STaR and its variants have demonstrated empirical success, a theoretical foundation explaining these improvements is lacking. This work provides a theoretical framework for understanding the effectiveness of reinforcement learning on CoT reasoning and STaR. Our contributions are: (1) an analysis of policy improvement, showing why LLM reasoning improves iteratively with STaR; (2) conditions for convergence to an optimal reasoning policy; (3) an examination of STaR’s robustness, explaining how it can improve reasoning even when incorporating occasional incorrect steps; and (4) criteria for the quality of pre-trained models necessary to initiate effective reasoning improvement. This framework aims to bridge empirical findings with theoretical insights, advancing reinforcement learning approaches for reasoning in LLMs.

Numerous improvements to STaR have since been introduced [HYM+24, ZHS+24], demonstrating empirically that LLMs can effectively learn reasoning steps via reinforcement learning without human intervention.

In this research, we propose a theoretical framework tailored to analyzing the effectiveness of reinforcement learning on CoT reasoning and STaR, which answers the following questions:

• Policy improvement: Why can LLMs improve their reasoning capabilities with each iteration of STaR?

• Convergence to optimal policy: If an optimal reasoning model exists, can STaR recover this optimal reasoner in the limit of infinitely many iterations? (This question and the previous one are formalized in the sketch after this list.)

• Existence of incorrect reasoning steps in STaR: In STaR, the model can generate incorrect intermediate reasoning steps yet still arrive at the correct final answer, which means these erroneous steps are included in that iteration’s training data. We aim to explain why STaR can still enhance the LLM’s reasoning capabilities despite the inclusion of these incorrect steps.

• Properties of pre-trained models for STaR: Since STaR requires a pre-trained LLM to bootstrap the discovery of reasoning steps in the first iteration, how good must the pre-trained LLM be at solving reasoning problems?
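To make the first two questions precise, one convenient formalization is sketched below; the notation ($\pi_k$, $J$, $R$) is ours, introduced for illustration, and is not necessarily the paper's. Let $\pi_k$ denote the LLM's reasoning policy after STaR iteration $k$, and let $R(q, z, a)$ equal $1$ if the sampled reasoning chain $z$ and answer $a$ are correct for question $q$, and $0$ otherwise. Define the expected reasoning accuracy

\[
J(\pi) \;=\; \mathbb{E}_{q \sim \mathcal{D},\; (z, a) \sim \pi(\cdot \mid q)} \bigl[ R(q, z, a) \bigr].
\]

Policy improvement then asks whether $J(\pi_{k+1}) \ge J(\pi_k)$ holds at every iteration $k$, and convergence asks whether $\lim_{k \to \infty} J(\pi_k) = \max_{\pi} J(\pi)$.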

2.2 Theories of Chain-of-Thought

The Chain-of-Thought (CoT) techniques [WWS+22] enable large language models (LLMs) to tackle complex reasoning tasks by breaking down solutions into a series of sequential steps. Beyond empirical success, some theoretical insights into CoT reasoning have emerged. For instance, [PLG24] models the CoT process using Bayesian networks, where questions, answers, and reasoning steps are nodes within the network, and providing a structured path of reasoning steps has been shown to boost LLM performance. Additionally, [XL24] introduces the concept of length generalization, where LLMs can solve complex problems by generalizing patterns from simpler training examples. In [Mal23], the authors extend the PAC supervised learning framework to a PAC auto-regressive framework, demonstrating that an auto-regressive learner can learn linear threshold circuits when CoT steps are provided. Furthermore, [FZG+24] shows that with CoT, transformers can address problem classes solvable by dynamic programming, even when problem sizes grow polynomially. Although these works lay a theoretical foundation for CoT, they fall short of explaining why reinforcement learning could enhance CoT capabilities in LLMs. Moreover, these studies underscore the necessity of ample reasoning-step examples in training data to develop strong CoT abilities during inference.
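As a concrete illustration of the Bayesian-network view in [PLG24], one can treat a question $q$, reasoning steps $z_1, \dots, z_T$, and answer $a$ as nodes whose joint distribution factorizes along a chain; the linear-chain topology here is our simplifying assumption, and the actual graph in [PLG24] may be more general:

\[
p(a, z_{1:T} \mid q) \;=\; p(z_1 \mid q) \prod_{t=2}^{T} p(z_t \mid z_{t-1}) \; p(a \mid z_T),
\qquad
p(a \mid q) \;=\; \sum_{z_{1:T}} p(a, z_{1:T} \mid q).
\]

Under this reading, supplying a structured path $z_{1:T}$ conditions the answer distribution on intermediate evidence, which is one way to interpret the reported performance gains.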

To reduce the labeling effort for CoT during training, the Self-Taught Reasoner (STaR) framework [ZWMG22] employs a reinforcement learning approach, specifically a policy gradient method, to enable LLMs to enhance their reasoning abilities autonomously. STaR initially generates reasoning steps through in-context learning to elicit chain-of-thought processes. Only the reasoning paths that lead to correct answers are retained, so the model strengthens iteratively: in each round, the LLM generates new reasoning paths, which are filtered and added to the training data. Several STaR extensions have been introduced to further enhance the framework. For instance, [ZHS+24] proposed Quiet-STaR, a variant in which language models produce token-level rationales justifying upcoming text, refining their predictions.
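For concreteness, the listing below is a minimal Python sketch of the STaR outer loop just described. The interfaces generate_rationale and finetune are hypothetical placeholders standing in for few-shot CoT sampling and supervised fine-tuning; they are not the original implementation's API.

def star_loop(model, dataset, generate_rationale, finetune,
              num_iterations=10, samples_per_question=4):
    """Iteratively self-train on answer-filtered rationales (sketch)."""
    for _ in range(num_iterations):
        train_set = []
        for question, gold_answer in dataset:
            for _ in range(samples_per_question):
                # Sample a chain-of-thought rationale and final answer
                # via few-shot in-context prompting.
                rationale, answer = generate_rationale(model, question)
                # Keep only rationales whose final answer is correct.
                # A kept rationale may still contain incorrect
                # intermediate steps, the robustness question above.
                if answer == gold_answer:
                    train_set.append((question, rationale, answer))
        # Fine-tune on the self-generated, answer-filtered data;
        # the updated model bootstraps the next iteration.
        model = finetune(model, train_set)
    return model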