Teaching Large Language Models to Reason with Reinforcement Learning

Paper · arXiv 2403.04642 · Published March 7, 2024

https://arxiv.org/abs/2403.04642


Reinforcement Learning from Human Feedback (RLHF) has emerged as a dominant approach for aligning LLM outputs with human preferences. Inspired by the success of RLHF, we study the performance of multiple algorithms that learn from feedback (Expert Iteration, Proximal Policy Optimization (PPO), Return-Conditioned RL) on improving LLM reasoning capabilities. We investigate both sparse and dense rewards provided to the LLM both heuristically and via a learned reward model. We additionally start from multiple model sizes and initializations both with and without supervised fine-tuning (SFT) data. Overall, we find all algorithms perform comparably, with Expert Iteration performing best in most cases. Surprisingly, we find the sample complexity of Expert Iteration is similar to that of PPO…
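As a rough illustration of Expert Iteration with a sparse reward, the loop below samples several candidate answers per question from the current policy. It keeps only the samples that earn reward, deduplicates them, and then "fine-tunes" on the kept set. This is a minimal sketch under stated assumptions, not the paper's implementation. The policy is a hypothetical toy: a weighted table of candidate answers per question instead of an LLM. The names `sparse_reward` and `expert_iteration` and the parameters `rounds`, `k`, and `seed` are illustrative. Upweighting kept answers stands in for supervised fine-tuning.

```python
import random
from collections import Counter


def sparse_reward(answer: str, gold: str) -> float:
    # Sparse, outcome-only reward: 1 for a correct final answer, else 0.
    return 1.0 if answer == gold else 0.0


def expert_iteration(policy: dict[str, Counter], dataset: dict[str, str],
                     rounds: int = 3, k: int = 8, seed: int = 0) -> dict[str, Counter]:
    """Toy Expert Iteration loop (illustrative sketch only).

    policy:  hypothetical stand-in for an LLM, mapping each question to a
             Counter of candidate answer -> sampling weight.
    dataset: question -> gold final answer.
    """
    rng = random.Random(seed)
    for _ in range(rounds):
        kept: set[tuple[str, str]] = set()
        for question, gold in dataset.items():
            answers = list(policy[question])
            weights = [policy[question][a] for a in answers]
            # Sample k rollouts from the current policy.
            samples = rng.choices(answers, weights=weights, k=k)
            # Filter by reward; the set also deduplicates identical samples.
            kept.update((question, s) for s in samples if sparse_reward(s, gold) > 0)
        # Assumed stand-in for SFT on the filtered samples:
        # upweight each kept (question, answer) pair once per round.
        for question, answer in kept:
            policy[question][answer] += 1.0
    return policy


policy = {"2+2": Counter({"4": 1.0, "5": 1.0, "22": 1.0})}
trained = expert_iteration(policy, {"2+2": "4"}, rounds=5, k=16)
print(trained["2+2"])
```

Only answers that receive reward are ever reinforced, so incorrect candidates keep their initial weight. The paper's actual procedure may differ, for example in how many rollouts are drawn per question and whether each round restarts training from the initial model. This toy only mimics the sample, filter, and imitate structure.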

Improvements in model instructability have further increased apparent model capability by making complex behaviors more accessible via instruction prompting. This has led to a number of increasingly sophisticated prompting strategies augmenting LLM reasoning capabilities such as Chain-of-Thought (Wei et al., 2022) or Tree-of-Thoughts (Yao et al., 2023).

Previous work in reinforcement learning (RL), such as AlphaGo (Silver et al., 2017), AlphaStar (Vinyals et al., 2019), and OpenAI Dota 2 (Berner et al., 2019), demonstrates that RL techniques can be used to train neural networks capable of sophisticated planning and reasoning in game environments. Cicero (Bakhtin et al., 2022) in particular succeeds in combining an RL-trained planning agent with a dialogue-fine-tuned LLM to achieve nearly superhuman performance in the board game Diplomacy. Given these previous successes and the inherently interactive nature of problem solving, applying RL to LLM reasoning seems a natural next step.

….

However, RL training does not significantly improve the pass@n score beyond what light supervised fine-tuning already achieves. This suggests that, even with RL training, our best models are not discovering solutions beyond those that (light) supervised fine-tuning can discover given the same rollout budget.
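To make the pass@n comparison concrete, one commonly used unbiased estimator works as follows. Draw n samples for a problem and count the c that are correct. Then estimate the probability that at least one of k samples drawn without replacement from them is correct. This is a sketch assuming that standard definition. The paper's exact evaluation code is not shown in this excerpt, and the function names `pass_at_k` and `mean_pass_at_k` are illustrative.

```python
from math import comb


def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased estimate of pass@k from n samples, c of which are correct.

    Computes 1 - C(n-c, k) / C(n, k): one minus the probability that a
    random size-k subset of the n samples contains no correct sample.
    """
    if not 0 <= c <= n or not 1 <= k <= n:
        raise ValueError("need 0 <= c <= n and 1 <= k <= n")
    if n - c < k:
        # Too few incorrect samples to fill a size-k subset,
        # so every subset contains at least one correct sample.
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)


def mean_pass_at_k(results: list[tuple[int, int]], k: int) -> float:
    """Average pass@k over a benchmark, given one (n, c) pair per problem."""
    return sum(pass_at_k(n, c, k) for n, c in results) / len(results)
```

For example, `pass_at_k(10, 1, 1)` is about 0.1, while `pass_at_k(4, 2, 2)` is 5/6. A problem the model solves only occasionally still counts as solved at a larger sampling budget, so pass@n measures whether any sampled solution is correct rather than only the most likely one.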

This observation, taken together with the fast convergence of both online algorithms and the low impact of ORM guidance and dense rewards, suggests that models are not engaging in a significant amount of exploration beyond their pretraining/SFT data. Regardless of the type of algorithm used or the quality of the reward, all student models engage in similar exploration, resulting in similar performance.

….

Crucial in our setting is the use of a pretrained model that imparts a strong exploration prior. Without such a prior, exploration in a high-dimensional textual action space would be impossible. However, this prior also appears to constrain exploration at the beginning of training, and additional SFT only makes things worse. We view the discovery of new techniques that encourage complex, rich exploration of reasoning problems as fundamental to progress in LLM reasoning capability. More sophisticated prompting strategies such as Tree-of-Thoughts (Yao et al., 2023) and combining LLM generative abilities with evolutionary algorithms (Lehman et al., 2022) have already begun to make progress in this direction.

….

It remains unclear exactly which factors have the biggest impact during RL fine-tuning, due to wide variance in tasks, pretraining data, supervised fine-tuning data, the RL algorithm used, and the reward source.

The sequential decision-making setting poses additional challenges because it has a lower tolerance for errors: the environment's stochasticity or the agent's actions can lead to unseen, and sometimes unrecoverable, states.

….

Our key finding is that, in contrast to (self-)supervised learning, where the context can simply contain a few different examples (or predictions), in sequential decision making it is crucial for the context to contain full or partial trajectories (or sequences of predictions) to cover the potentially wide range of states the agent may find itself in at deployment.