Chain-of-Retrieval Augmented Generation
This paper introduces an approach for training o1-like RAG models that retrieve and reason over relevant information step by step before generating the final answer. Conventional RAG methods usually perform a single retrieval step before the generation process, which limits their effectiveness on complex queries because the retrieval results are imperfect. In contrast, our proposed method, CoRAG (Chain-of-Retrieval Augmented Generation), allows the model to dynamically reformulate the query based on the evolving state. To train CoRAG effectively, we utilize rejection sampling to automatically generate intermediate retrieval chains, thereby augmenting existing RAG datasets that only provide the correct final answer. At test time, we propose various decoding strategies to scale the model’s test-time compute by controlling the length and number of sampled retrieval chains.
Contemporary RAG systems typically employ a sequential pipeline of retrieval and generation, wherein the retrieved information serves as additional input to the generative model. The effectiveness of RAG systems predominantly relies on the quality of the retrieved information. Retrieval models are engineered for efficiency to ensure scalability to large corpora. For instance, dense retrievers [16, 33] commonly utilize a bi-encoder architecture to compress documents and queries into fixed-size vector representations. This architectural choice permits the use of fast approximate nearest neighbor search algorithms but simultaneously constrains the expressive capacity of retrieval models to handle complex queries. Furthermore, in multi-hop reasoning tasks, it is often unclear what information should be retrieved initially; decisions must be made based on the progressively evolving state of the reasoning process.
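As an illustration of the bi-encoder design described above, the following minimal sketch encodes documents and queries independently into fixed-size vectors and ranks documents by inner product. The hashed bag-of-words `encode` function is a toy stand-in for a learned encoder, and the exhaustive scan in `search` stands in for an approximate nearest neighbor index; both, along with `DIM` and the example documents, are assumptions made for illustration.

```python
import math
import zlib

DIM = 512  # fixed embedding size (illustrative value)


def encode(text):
    """Toy encoder: hashed bag-of-words, L2-normalized (not a learned model)."""
    vec = [0.0] * DIM
    for token in text.lower().split():
        # crc32 gives a hash that is stable across Python processes.
        vec[zlib.crc32(token.encode("utf-8")) % DIM] += 1.0
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else vec


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def search(doc_vectors, query, k=1):
    """Exact top-k search; returns (doc_index, score) pairs, best first."""
    q = encode(query)
    scored = sorted(((dot(q, d), i) for i, d in enumerate(doc_vectors)), reverse=True)
    return [(i, s) for s, i in scored[:k]]


docs = [
    "paris is the capital of france",
    "the eiffel tower is in paris",
    "bananas are yellow fruit",
]
# Documents are encoded once, independently of any query.
doc_vectors = [encode(d) for d in docs]
print(search(doc_vectors, "capital of france", k=2))
```

Because the query and each document interact only through a single inner product, the score cannot model fine-grained interactions between them, which is the expressiveness constraint noted above.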
To break the bottleneck of retrieval quality, we propose a framework that dynamically retrieves relevant information and plans subsequent retrieval steps based on the current state. By adjusting the number of retrieval steps at test time, our model can explore various aspects of the query and experiment with different query rewriting strategies when the retriever does not yield useful information. This paradigm mirrors the human problem-solving process, where we iteratively seek information to address complex questions.
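The control flow of such a framework can be sketched as a loop that issues a query, records what the retriever returns, and chooses the next query from the accumulated state. This is a sketch under stated assumptions rather than the trained model: `chain_of_retrieval`, the three callables it takes, the rule-based `toy_*` functions, and the fictional corpus are all hypothetical stand-ins for a language model and a real retriever.

```python
from typing import Callable, List, Optional, Tuple

# One entry per step: (query issued, documents returned).
State = List[Tuple[str, List[str]]]


def chain_of_retrieval(
    question: str,
    next_query: Callable[[str, State], Optional[str]],
    retrieve: Callable[[str], List[str]],
    answer: Callable[[str, State], str],
    max_steps: int = 4,
) -> Tuple[str, State]:
    """Retrieve step by step; max_steps bounds the number of retriever calls."""
    state: State = []
    for _ in range(max_steps):
        query = next_query(question, state)
        if query is None:  # the policy decides it has gathered enough
            break
        state.append((query, retrieve(query)))
    return answer(question, state), state


# Fictional toy corpus, keyed by exact (lowercased) query string.
TOY_CORPUS = {
    "author of the silver road": "The Silver Road was written by Ana Ruiz.",
    "birthplace of ana ruiz": "Ana Ruiz was born in Lima.",
}


def toy_retrieve(query: str) -> List[str]:
    doc = TOY_CORPUS.get(query.lower())
    return [doc] if doc else []


def toy_next_query(question: str, state: State) -> Optional[str]:
    """Rule-based stand-in for the model's query reformulation."""
    if not state:
        return "author of the silver road"
    if len(state) == 1 and state[0][1]:
        name = state[0][1][0].split("written by ")[1].rstrip(".")
        return f"birthplace of {name}"
    return None


def toy_answer(question: str, state: State) -> str:
    for _, retrieved in reversed(state):
        for doc in retrieved:
            if "born in " in doc:
                return doc.split("born in ")[1].rstrip(".")
    return "unknown"


final, trace = chain_of_retrieval(
    "Where was the author of The Silver Road born?",
    toy_next_query, toy_retrieve, toy_answer,
)
print(final, trace)
```

The second query can only be formed after the first retrieval result is known, which is the situation in multi-hop questions where a single up-front retrieval step falls short. Lowering `max_steps` trades answer quality for fewer retriever calls.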
Rather than relying solely on the model’s in-context learning capability [39] or distillation from proprietary models [1], we advocate for explicitly training language models to retrieve step by step. To this end, we utilize rejection sampling [40, 4] to augment existing RAG datasets with intermediate retrieval chains. Open-source language models are then fine-tuned on these augmented datasets using standard next-token prediction objectives. To examine the scaling behavior of our model, we propose various test-time decoding strategies, including greedy decoding, best-of-N sampling, and tree search. Different decoding strategies and hyperparameter configurations can be used to control test-time token consumption and the frequency of retriever calls. Our empirical evaluation demonstrates that CoRAG substantially surpasses strong baselines in QA tasks that require multi-hop reasoning, where retrievers frequently struggle to recall all necessary information in a single retrieval step.
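The data augmentation step can be illustrated with a minimal rejection sampling sketch: sample several candidate retrieval chains for a question whose final answer is known, and keep only those that pass an acceptance test. This passage does not specify the acceptance criterion, so the case-insensitive exact match used below is a simplifying assumption, and `rejection_sample_chains`, `sample_chain`, and the toy sampler are hypothetical names introduced for illustration.

```python
import random
from typing import Callable, List, Tuple

Chain = List[str]  # a chain is represented here as its list of queries


def rejection_sample_chains(
    question: str,
    gold_answer: str,
    sample_chain: Callable[[str, random.Random], Tuple[Chain, str]],
    num_samples: int = 8,
    seed: int = 0,
) -> List[Chain]:
    """Keep sampled chains whose predicted answer matches the gold answer."""
    rng = random.Random(seed)  # seeded for reproducibility
    accepted: List[Chain] = []
    for _ in range(num_samples):
        chain, predicted = sample_chain(question, rng)
        # Assumed acceptance test: case-insensitive exact match.
        if predicted.strip().lower() == gold_answer.strip().lower():
            accepted.append(chain)
    return accepted


GOOD_CHAIN = ["author of the silver road", "birthplace of ana ruiz"]
BAD_CHAIN = ["the silver road publisher"]


def toy_sampler(question: str, rng: random.Random) -> Tuple[Chain, str]:
    """Stand-in for stochastic decoding: picks a useful or useless chain."""
    if rng.random() < 0.5:
        return GOOD_CHAIN, "Lima"
    return BAD_CHAIN, "unknown"


kept = rejection_sample_chains(
    "Where was the author of The Silver Road born?", "Lima", toy_sampler
)
print(len(kept), "of 8 sampled chains accepted")
```

The accepted chains would then be attached to the original question and answer pair as intermediate supervision for fine-tuning. The same sample-then-select pattern underlies best-of-N decoding at test time, where no gold answer is available and the selection rule must therefore be a different one.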