Multi-Token Attention

Paper · arXiv 2504.00927 · Published April 1, 2025
Novel Architectures · Memory · LLM Architecture

Soft attention is a critical mechanism powering LLMs to locate relevant parts within a given context. However, individual attention weights are determined by the similarity of only a single query and key token vector. This “single token attention” bottlenecks the amount of information used in distinguishing a relevant part from the rest of the context. To address this issue, we propose a new attention method, Multi-Token Attention (MTA), which allows LLMs to condition their attention weights on multiple query and key vectors simultaneously. This is achieved by applying convolution operations over queries, keys and heads, allowing nearby queries and keys to affect each other’s attention weights for more precise attention. As a result, our method can locate relevant context using richer, more nuanced information that can exceed a single vector’s capacity. Through extensive evaluations, we demonstrate that MTA achieves enhanced performance on a range of popular benchmarks. Notably, it outperforms Transformer baseline models on standard language modeling tasks, and on tasks that require searching for information within long contexts, where our method’s ability to leverage richer information proves particularly beneficial.

Each attention value in standard multi-head attention, see Equation 1, depends solely on a single key and query vector. That means all the necessary information for finding and attending to a relevant part of the context must be compressed into these single vectors. This might not be ideal if we are looking for a sentence containing multiple elements. Consider for example the sentence “Where did Alice see the rabbit?”. We could try to find instances of “Alice” and “rabbit” independently and then check if there is a sentence that has both. Let $q_a$ and $q_r$ be query vectors encoding “Alice” and “rabbit” respectively (assuming a word tokenizer), then their attention weights are computed as follows:

$$a_a = \mathrm{Softmax}\left(q_a K^\top / \sqrt{d}\right), \quad a_r = \mathrm{Softmax}\left(q_r K^\top / \sqrt{d}\right) \tag{2}$$
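As an illustration of Equation 2, the following is a minimal NumPy sketch under toy assumptions: one-hot word embeddings serve as keys, the queries are scaled one-hot vectors, and the context sentence, the scale factor 6.0, and the helper names `softmax` and `attention_weights` are all hypothetical choices made for this example rather than anything specified in the text.

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over the last axis.
    x = x - np.max(x, axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attention_weights(q, K):
    # Equation 2: Softmax(q K^T / sqrt(d)) for a single query vector q.
    d = K.shape[-1]
    return softmax(q @ K.T / np.sqrt(d))

# Hypothetical toy context, tokenized by word.
context = ["Bob", "saw", "rabbit", "Alice", "saw", "rabbit"]
vocab = {w: i for i, w in enumerate(dict.fromkeys(context))}
one_hot = np.eye(len(vocab))

# Assumption: keys are one-hot word vectors, queries are scaled one-hots.
K = one_hot[[vocab[w] for w in context]]
q_a = 6.0 * one_hot[vocab["Alice"]]
q_r = 6.0 * one_hot[vocab["rabbit"]]

a_a = attention_weights(q_a, K)  # peaks where "Alice" occurs
a_r = attention_weights(q_r, K)  # peaks where "rabbit" occurs

print(dict(zip(range(len(context)), a_a.round(3))))
print(dict(zip(range(len(context)), a_r.round(3))))
```

In this toy setup `a_r` assigns the same weight to both occurrences of “rabbit”, so the “rabbit” query alone cannot tell which occurrence sits next to “Alice”.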

By doing normal attention with those queries, we can attend to where “Alice” and “rabbit” are mentioned in the context. All we have to do then is to check if both attention weights $a_a$ and $a_r$ have higher probabilities at the same nearby locations, e.g. in the same sentence, which will indicate that the sentence mentions both “Alice” and “rabbit”. Unfortunately, normal attention lacks such interaction between attention maps, and instead only uses them to compute output values. Even if we use different attention heads to find “Alice” and “rabbit”, there is no mechanism to combine these attention weights. This motivates us to modify the attention mechanism to allow combining different attention maps from nearby locations (both in terms of query and key locations), or between different attention heads.

As shown in Figure 1 (right), our proposed Multi-Token Attention consists of three important components built on top of multi-head attention: key-query convolution, head mixing convolution, and group normalization with a gating mechanism. The overall MTA convolution applies the key-query convolution to combine multiple keys and queries within heads, and the head convolution to share knowledge between heads and amplify important information. Finally, we apply group normalization with scalar gating to push back against residual streams and improve gradient flow. In this section, we will describe each component of MTA in detail.
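To make the idea of combining attention maps from nearby query and key locations concrete, here is a rough NumPy sketch of a key-query convolution for a single head. It rests on several assumptions that this section does not state: the convolution is applied to the pre-softmax logits, future positions are zeroed before mixing and masked out again afterwards, and the kernel is set by hand (in a trained model it would presumably be learned). The function name `key_query_conv`, the toy sentence, and the kernel values are hypothetical. Head mixing and group normalization with gating are not covered.

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over the last axis.
    x = x - np.max(x, axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def key_query_conv(logits, kernel):
    # Mix each logit with logits of earlier queries and nearby keys.
    # kernel[u, v] weights the logit of query i-u and key j+(v - ck//2).
    logits = np.asarray(logits, dtype=float)
    T = logits.shape[0]
    cq, ck = kernel.shape
    half = ck // 2
    causal = np.tril(np.ones((T, T), dtype=bool))
    # Assumption: zero out future keys so they cannot leak into the mix.
    masked = np.where(causal, logits, 0.0)
    padded = np.pad(masked, ((cq - 1, 0), (half, ck - 1 - half)))
    out = np.zeros_like(masked)
    for u in range(cq):
        for v in range(ck):
            rows = slice(cq - 1 - u, cq - 1 - u + T)
            cols = slice(v, v + T)
            out += kernel[u, v] * padded[rows, cols]
    # Re-apply the causal mask so softmax ignores future positions.
    return np.where(causal, out, -np.inf)

# Hypothetical toy sequence: a context followed by the query words.
words = ["Alice", "slept", "Bob", "saw", "rabbit",
         "Alice", "saw", "rabbit", "Alice", "rabbit"]
vocab = {w: i for i, w in enumerate(dict.fromkeys(words))}
d = len(vocab)
K = np.eye(d)[[vocab[w] for w in words]]   # one-hot keys (assumption)
Q = 6.0 * K                                # scaled one-hot queries
logits = Q @ K.T / np.sqrt(d)

# A 1x1 identity kernel reduces to standard causal attention.
standard = softmax(key_query_conv(logits, np.ones((1, 1))))

# Hand-set kernel: current query at the same key, plus the previous
# query at the key two positions earlier ("Alice ... rabbit").
kernel = np.zeros((2, 5))
kernel[0, 2] = 1.0
kernel[1, 0] = 1.0
mixed = softmax(key_query_conv(logits, kernel))

print(standard[9].round(3))  # last query ("rabbit"), standard attention
print(mixed[9].round(3))     # same query after key-query mixing
```

With standard attention, the final “rabbit” query weights both context occurrences of “rabbit” equally. After mixing in the previous “Alice” query, the occurrence that follows “Alice” two positions earlier receives the largest weight, which is the kind of interaction between attention maps that the text argues normal attention lacks.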