Can models learn multi-token concepts during fine-tuning?
Does training models to predict multiple tokens at once, rather than one token sequentially, help them form coherent semantic units? This matters because current next-token prediction fragments concepts like "ribonucleic acid" into arbitrary subword pieces.
Next-token prediction fragments multi-token concepts into arbitrary subword units. "Ribonucleic acid" becomes "rib" → "on" → "ucle" → "ic" → "acid" — five separate prediction targets with no unified semantic representation. Concept-Aware Fine-Tuning (CAFT) introduces multi-token prediction into post-training, enabling models to learn sequences that span multiple tokens as coherent concepts.
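A quick illustration of the fragmentation, as a hedged sketch: the exact subword split depends on the tokenizer's vocabulary, so the printed pieces are an example, not output guaranteed by any particular model.

```python
# Illustrative sketch: how a subword tokenizer fragments a multi-token concept.
# The exact pieces depend on the vocabulary; treat the printed split as an example.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any BPE/subword tokenizer works
pieces = tokenizer.tokenize("ribonucleic acid")
print(pieces)
# Something like ['rib', 'on', 'ucle', 'ic', 'Ġacid'] -- each piece is a separate
# next-token prediction target, so no single step represents the whole concept.
```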
Prior multi-token prediction methods worked only during pretraining, which is prohibitively expensive and dominated by general language modeling rather than domain-specific concept formation. Earlier attempts to apply multi-token prediction during fine-tuning failed because the objective represents a dramatic distribution shift that short post-training phases cannot absorb. CAFT solves this with self-distilled auxiliary heads: first, auxiliary heads (predicting positions beyond the next token) are trained on an instruction-tuning mixture with self-distilled ground truth; then the model is fine-tuned with a multi-token loss on top of standard LoRA or full fine-tuning.
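A minimal sketch of how a multi-token objective of this kind can be wired up, assuming the auxiliary heads are plain linear heads over the final hidden states; the head architecture, the number of future positions, and the loss weight are illustrative choices, not details taken from the CAFT paper.

```python
# Sketch of a multi-token fine-tuning loss with auxiliary heads (assumptions noted above).
import torch
import torch.nn.functional as F
from torch import nn

class MultiTokenHeads(nn.Module):
    def __init__(self, hidden_size: int, vocab_size: int, num_future: int = 3):
        super().__init__()
        # one extra head per additional future position (t+2, t+3, ...)
        self.heads = nn.ModuleList(
            nn.Linear(hidden_size, vocab_size, bias=False) for _ in range(num_future)
        )

def multi_token_loss(hidden, ntp_logits, labels, aux: MultiTokenHeads, aux_weight=0.3):
    """hidden: [B, T, H] final hidden states; ntp_logits: [B, T, V]; labels: [B, T].

    Requires T to exceed the largest future offset (2 + len(aux.heads) - 1).
    """
    # standard next-token loss: position t predicts token t+1
    loss = F.cross_entropy(
        ntp_logits[:, :-1].reshape(-1, ntp_logits.size(-1)),
        labels[:, 1:].reshape(-1),
        ignore_index=-100,
    )
    # auxiliary losses: head k predicts token t+2+k from the hidden state at position t
    for k, head in enumerate(aux.heads):
        shift = 2 + k
        logits_k = head(hidden[:, :-shift])
        loss = loss + aux_weight * F.cross_entropy(
            logits_k.reshape(-1, logits_k.size(-1)),
            labels[:, shift:].reshape(-1),
            ignore_index=-100,
        )
    return loss
```

Keeping the auxiliary terms weighted below the next-token term is one plausible way to limit the distribution shift described above; the weighting actually used by CAFT is not specified here.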
The results: CAFT consistently outperforms next-token fine-tuning across text summarization and de novo protein design. CAFT LoRA often outperforms next-token full fine-tuning, suggesting models learn more effectively in a multi-token setting even with fewer trainable parameters. In settings where multi-token prediction is especially advantageous (protein design, where amino acid sequences have multi-residue semantic units), CAFT delivers multi-fold performance gains.
This connects to the format-shapes-reasoning finding: as "Does training data format shape reasoning strategy more than domain?" argues, the prediction unit (single token vs. multi-token) is a format variable that shapes what the model learns. Multi-token prediction is a higher-level format that encourages conceptual chunking rather than token-by-token prediction.
The democratization aspect matters: pretraining-phase multi-token prediction was restricted to well-resourced labs, but CAFT brings it to fine-tuning, where any practitioner can apply it. Trained task-agnostic auxiliary heads are provided for popular open-source models.
Source: Training Fine Tuning
Related concepts in this collection
- Does training data format shape reasoning strategy more than domain? What explains why models trained on multiple-choice data reason differently than those trained on free-form text? The research isolates format and domain effects to measure which one matters more. Relation: the prediction unit is a format variable; multi-token prediction changes the format at the most fundamental level.
- Can formal language pretraining make language models more efficient? Does training language models on hierarchical formal languages before natural language improve how efficiently they learn syntax? This explores whether structural inductive biases in training data matter more than raw data volume. Relation: both change the learning unit; formal languages add structural hierarchy, CAFT adds multi-token prediction.
- How do knowledge injection methods trade off flexibility and cost? When and how should domain knowledge enter an AI system? This explores the speed, training cost, and adaptability trade-offs across four injection paradigms, and when each approach suits different deployment constraints. Relation: CAFT is a new variant within the fine-tuning paradigm; same cost, different learning objective.
Original note title
multi-token concept-aware fine-tuning overcomes next-token fragmentation to form coherent semantic entities during post-training