Improving large language models with concept-aware fine-tuning

Paper · arXiv 2506.07833 · Published June 9, 2025
Training · Fine-Tuning · MechInterp

Large language models (LLMs) have become the cornerstone of modern AI. However, the existing paradigm of next-token prediction fundamentally limits their ability to form coherent, high-level concepts, making it a critical barrier to human-like understanding and reasoning. Take the phrase "ribonucleic acid" as an example: an LLM will first decompose it into tokens, i.e., artificial text fragments ("rib" → "on" → …), then learn each token sequentially, rather than grasping the phrase as a unified, coherent semantic entity. This fragmented representation hinders deeper conceptual understanding and, ultimately, the development of truly intelligent systems. In response, we introduce Concept-Aware Fine-Tuning (CAFT), a novel multi-token training method that redefines how LLMs are fine-tuned. By enabling the learning of sequences that span multiple tokens, CAFT fosters stronger concept-aware learning. Our experiments demonstrate significant improvements over conventional next-token fine-tuning methods across diverse tasks, including traditional applications like text summarization and domain-specific ones like de novo protein design. Multi-token prediction was previously possible only in the prohibitively expensive pretraining phase; CAFT is, to our knowledge, the first method to bring the multi-token setting to the post-training phase, effectively democratizing its benefits for the broader community of practitioners and researchers.

Importantly, this training paradigm rests on a seemingly unassailable training objective: next-token prediction. A vocabulary of tokens, or text fragments, is first created using tokenization algorithms, most commonly byte-pair encoding (BPE) (Sennrich et al., 2015), which forms word/subword tokens based on their frequency in the training corpus. After the texts are tokenized with this vocabulary, the tokens are fed into the model, which is trained to predict the next token autoregressively. For example, as shown in Figure 1(a,b), if a Llama 3 model (Grattafiori et al., 2024) is tasked with predicting "ribonucleic acid" as part of a given question, the phrase is first deconstructed, i.e., tokenized, into "rib", "on", "ucle", "ic", and "acid". The model is then trained to predict a single token in each forward pass, starting from "rib".
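To make the tokenization step concrete, the snippet below inspects how a BPE tokenizer splits the phrase using the Hugging Face transformers library. The checkpoint name is illustrative (Llama 3 checkpoints on the Hub are gated), and the exact fragments depend on the tokenizer's learned vocabulary.

```python
from transformers import AutoTokenizer

# Illustrative checkpoint; any BPE-based tokenizer can be substituted.
# (Llama 3 checkpoints on the Hugging Face Hub require gated access.)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

tokens = tokenizer.tokenize("ribonucleic acid")
print(tokens)
# e.g. something like ['rib', 'on', 'ucle', 'ic', 'Ġacid'], where 'Ġ'
# marks a leading space; the exact splits depend on the vocabulary
# learned from the training corpus.
```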

Specifically, at each position in the training corpus, models are trained to predict the following n tokens using n output heads. However, these multi-token prediction (MTP) methods are restricted to the pretraining phase, which entails prohibitive costs and diminished effectiveness. First, pretraining is inherently orders of magnitude more computationally expensive than post-training, making existing MTP methods infeasible for all but a select group of well-resourced labs. Second, pretraining teaches models general knowledge and language-modeling skills, while post-training teaches specific, relevant skills. Existing methods thus fail to adequately learn domain-specific, multi-token concepts: they exhibit only incremental gains over their next-token counterparts on downstream tasks.
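The sketch below illustrates this n-head design under stated assumptions: a shared trunk produces hidden states, and head i is trained against the token i+1 positions ahead. The class, function names, and uniform loss averaging are illustrative, not taken from any specific MTP implementation.

```python
import torch
import torch.nn as nn

class MultiTokenHeads(nn.Module):
    """Minimal MTP sketch: a shared trunk feeds n output heads,
    where head i predicts the token at offset i+1."""
    def __init__(self, hidden_dim: int, vocab_size: int, n_heads: int = 4):
        super().__init__()
        self.heads = nn.ModuleList(
            [nn.Linear(hidden_dim, vocab_size) for _ in range(n_heads)]
        )

    def forward(self, hidden_states: torch.Tensor) -> list[torch.Tensor]:
        # hidden_states: (batch, seq_len, hidden_dim) from the shared trunk
        return [head(hidden_states) for head in self.heads]

def mtp_loss(logits_per_head: list[torch.Tensor], input_ids: torch.Tensor):
    # Head i at position t is trained against the token at t + i + 1.
    losses = []
    for i, logits in enumerate(logits_per_head):
        shift = i + 1
        pred = logits[:, :-shift, :]       # drop positions with no target
        target = input_ids[:, shift:]      # targets shifted by i + 1
        losses.append(nn.functional.cross_entropy(
            pred.reshape(-1, pred.size(-1)), target.reshape(-1)
        ))
    return torch.stack(losses).mean()
```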

Naturally, one would expect multi-token prediction to be applied to fine-tuning instead. However, to the best of our knowledge, research in this direction has so far been unsuccessful, finding that fine-tuning with multi-token prediction yields similar or worse performance (Gloeckle et al., 2024; Cai et al., 2024). Incorporating the multi-token setting into the post-training phase is extremely challenging because it represents a dramatic distribution shift; given that post-training is much shorter than pretraining, models fail to adapt, leading to degradation. In response, we introduce Concept-Aware Fine-Tuning (CAFT), a novel multi-token fine-tuning method for next-token models. First, auxiliary heads that predict token positions beyond the next immediate token are trained on an instruction-tuning mixture whose ground-truth responses are distilled from the model itself. We provide trained task-agnostic auxiliary heads for a range of popular open-source models, allowing practitioners to focus on their task-specific MTP fine-tuning, as illustrated in Figure 1c. The auxiliary heads and multi-token loss function are then added on top of full or Low-Rank Adaptation (LoRA) fine-tuning of the base model.
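A rough sketch of how such a fine-tuning step might combine the two loss terms is shown below. This is not the authors' released implementation: the head wiring, the hypothetical `aux_weight` coefficient, and the use of final-layer hidden states are all assumptions for illustration.

```python
import torch

def caft_training_step(model, aux_heads, batch, optimizer, aux_weight=0.5):
    """Sketch of one CAFT step: the usual next-token cross-entropy
    plus auxiliary multi-token losses from extra heads (assumed wiring)."""
    outputs = model(batch["input_ids"], output_hidden_states=True)
    hidden = outputs.hidden_states[-1]   # (B, T, H) final-layer states

    # Ordinary next-token cross-entropy on the model's own head.
    next_token_loss = torch.nn.functional.cross_entropy(
        outputs.logits[:, :-1, :].reshape(-1, outputs.logits.size(-1)),
        batch["input_ids"][:, 1:].reshape(-1),
    )

    # Auxiliary heads predict tokens 2, 3, ... positions ahead.
    aux_loss = 0.0
    for i, head in enumerate(aux_heads):
        shift = i + 2
        logits = head(hidden)[:, :-shift, :]
        target = batch["input_ids"][:, shift:]
        aux_loss = aux_loss + torch.nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)), target.reshape(-1)
        )

    # Weighted combination; the weighting scheme here is an assumption.
    loss = next_token_loss + aux_weight * aux_loss / len(aux_heads)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
```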

We empirically demonstrate CAFT's effectiveness across diverse domains, including traditional ones like text summarization and domain-specific ones like de novo protein design. CAFT achieves superior performance to its next-token full and LoRA fine-tuning counterparts, with gains comparable to or larger than those of existing MTP pretraining methods at a fraction of the computational cost. Additionally, we find that CAFT LoRA often outperforms next-token full fine-tuning, suggesting that models learn more effectively in a multi-token setting. In settings where multi-token prediction is especially advantageous, performance can improve several-fold.

Beyond these results, CAFT carries significant implications for the scientific community. First, by bringing multi-token prediction into the post-training phase, our method democratizes the benefits of MTP for the broader community of practitioners and researchers.