Reinforcement Learning for Optimizing RAG for Domain Chatbots
With the advent of Large Language Models (LLMs), conversational assistants have become prevalent for domain use cases. LLMs acquire the ability to perform contextual question answering through extensive training, and Retrieval Augmented Generation (RAG) further enables a bot to answer domain-specific questions. This paper describes a RAG-based approach for building a chatbot that answers users' queries using Frequently Asked Questions (FAQ) data. We train an in-house retrieval embedding model using the infoNCE loss, and experimental results demonstrate that the in-house model works significantly better than a well-known general-purpose public embedding model, both in terms of retrieval accuracy and Out-of-Domain (OOD) query detection.
To enable multi-turn conversations with RAG, the conversation history needs to be maintained and passed to the LLM with every query. It is known that a larger input token count leads to a drop in accuracy or to hallucinations, as LLMs have the additional task of selecting the relevant information from a large context (Liu et al. 2023a).
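To make the token growth concrete, the following is a minimal sketch of how a multi-turn RAG prompt is typically assembled; the helper name and prompt layout are illustrative assumptions, not the paper's implementation. Because every past turn is re-sent, the input token count grows with each exchange:

```python
def build_prompt(system_prompt, history, query, faq_context=None):
    """Assemble the LLM input for one turn (hypothetical helper).
    `history` is a list of (user_turn, bot_turn) pairs; re-sending all of
    them every turn is what makes the token count grow over a conversation."""
    parts = [system_prompt]
    if faq_context:
        parts.append("FAQ context:\n" + faq_context)
    for user_turn, bot_turn in history:
        parts.append(f"User: {user_turn}\nAssistant: {bot_turn}")
    parts.append(f"User: {query}\nAssistant:")
    return "\n\n".join(parts)
```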
In this paper, we first describe a RAG-based approach for building a chatbot that answers users' queries using Frequently Asked Questions (FAQ) data. We have a domain FAQ dataset consisting of 72 FAQs about the credit card application process; it is prepared to answer user queries regarding general card information as well as pre- and post-application queries. We train an in-house retrieval (embedding) model using the info Noise Contrastive Estimation (infoNCE) loss (van den Oord, Li, and Vinyals 2019) on English and Hinglish paraphrase queries created using ChatGPT. The embedding model is trained to maximize query-question and query-QnA similarity.
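As an illustration, a common PyTorch formulation of the infoNCE objective with in-batch negatives is sketched below; the function name, temperature value, and batching scheme are assumptions for exposition, not the paper's exact training setup:

```python
import torch
import torch.nn.functional as F

def info_nce_loss(query_emb: torch.Tensor, target_emb: torch.Tensor,
                  temperature: float = 0.05) -> torch.Tensor:
    """infoNCE with in-batch negatives: row i of `query_emb` (a paraphrase
    query) is paired with row i of `target_emb` (the FAQ question, or the
    concatenated QnA text); all other rows in the batch act as negatives."""
    q = F.normalize(query_emb, dim=-1)
    t = F.normalize(target_emb, dim=-1)
    logits = q @ t.T / temperature          # (batch, batch) cosine similarities
    labels = torch.arange(logits.size(0), device=logits.device)
    return F.cross_entropy(logits, labels)  # matched pair is the positive class
```

Maximizing query-question and query-QnA similarity then amounts to applying this loss twice, with `target_emb` computed from the FAQ question alone and from the full QnA pair, respectively.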
We noticed that for certain patterns or sequences of queries, the bot can produce a good answer even without fetching the FAQ context. Examples of such scenarios are: (1) for a follow-up query, the FAQ context need not be retrieved if it has already been fetched for the previous query; (2) for a sequence of queries referring to the same FAQ, the context can be fetched only once at the start; (3) for OOD queries, the LLM prompt itself can guide the bot to generate the answer. Using this insight, we propose a policy gradient-based approach to optimize the number of LLM tokens and, hence, the cost.
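A minimal REINFORCE-style sketch of such a policy is given below, assuming a binary action (fetch vs. skip retrieval) and a scalar reward that trades off answer quality against token cost; the state features, network shape, and reward definition are illustrative assumptions, not the paper's exact design:

```python
import torch

class RetrievalPolicy(torch.nn.Module):
    """Decides, from an embedding of the current conversation state,
    whether to fetch FAQ context (action 1) or skip retrieval (action 0)."""
    def __init__(self, state_dim: int, hidden: int = 64):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(state_dim, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, 2),  # logits for {skip, fetch}
        )

    def forward(self, state: torch.Tensor) -> torch.distributions.Categorical:
        return torch.distributions.Categorical(logits=self.net(state))

def reinforce_step(policy, optimizer, states, actions, rewards):
    """One REINFORCE update: increase the log-probability of each taken
    action in proportion to the reward it earned. Here `rewards` could be,
    e.g., an answer-quality score minus a penalty proportional to the
    number of LLM tokens consumed."""
    dist = policy(states)
    loss = -(dist.log_prob(actions) * rewards).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

Skipping retrieval for follow-up, same-FAQ, and OOD queries then earns a higher reward (same answer quality at a lower token cost), which is exactly the behavior the policy gradient pushes the agent toward.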