A Mechanistic Interpretation of Arithmetic Reasoning in Language Models using Causal Mediation Analysis
In this paper, we present a set of analyses aimed at mechanistically interpreting language models (LMs) on the task of answering simple arithmetic questions (e.g., “What is the product of 11 and 17?”). In particular, we hypothesize that the computations involved in reasoning about such arithmetic problems are carried out by a specific subset of the network. We then test this hypothesis by adopting a causal mediation analysis framework (Vig et al., 2020; Meng et al., 2022), in which the model is viewed as a causal graph going from inputs to outputs, and the model components (e.g., neurons or layers) are viewed as mediators (Pearl, 2001). Within this framework, we assess the impact of a mediator on the observed output behavior by conducting controlled interventions on the activations of specific subsets of the model and examining the resulting changes in the probabilities assigned to different numerical predictions.
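The intervention procedure can be illustrated on a toy network. The sketch below is not the paper's implementation. It assumes a hypothetical two-layer network with random weights whose hidden layer plays the role of the mediator, hypothetical answer indices `r` and `r_cf` for the base and counterfactual results, and a relative-change effect measure chosen here for illustration. Copying the counterfactual hidden activations into the base run isolates the effect that flows through the patched units.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy model: input -> hidden layer (the mediator) -> distribution over 10 candidate answers.
W_in = rng.normal(size=(4, 8))
W_out = rng.normal(size=(8, 10))


def softmax(z):
    z = z - z.max()  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()


def run(x, patch_idx=None, patch_values=None):
    """Forward pass; optionally overwrite the hidden units in `patch_idx` with `patch_values`."""
    h = np.tanh(x @ W_in)
    if patch_idx is not None:
        h = h.copy()
        h[patch_idx] = patch_values[patch_idx]
    return softmax(h @ W_out), h


def indirect_effect(p_base, p_patched, r, r_cf):
    """Illustrative relative-change measure: average of the relative gain of r_cf and relative loss of r."""
    gain_cf = (p_patched[r_cf] - p_base[r_cf]) / p_base[r_cf]
    loss_r = (p_base[r] - p_patched[r]) / p_patched[r]
    return 0.5 * (gain_cf + loss_r)


# Two hypothetical "prompts": a base query and a counterfactual query with a different result.
x_base = rng.normal(size=4)
x_cf = rng.normal(size=4)
r, r_cf = 3, 7  # hypothetical indices of the base and counterfactual results

p_base, _ = run(x_base)
p_cf, h_cf = run(x_cf)

# Intervene on subsets of the mediator by copying counterfactual activations into the base run.
subsets = [("none", np.array([], dtype=int)), ("first half", np.arange(4)), ("all", np.arange(8))]
for name, idx in subsets:
    p_patched, _ = run(x_base, patch_idx=idx, patch_values=h_cf)
    print(f"{name:>10}: effect = {indirect_effect(p_base, p_patched, r, r_cf):+.3f}")
```

Patching no units leaves the output unchanged (effect 0). Patching all units reproduces the counterfactual distribution exactly, because in this toy the output depends on the input only through the hidden layer. A partial subset measures the effect mediated by those units alone.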
Through this experimental procedure, we track the flow of information within the model and identify the model components that encode information about the result of arithmetic queries. Our findings show that the model processes the input by using attention to convey information about the operator and the operands from mid-sequence token positions in early layers to the final token. At this location, the information is processed by a set of MLP modules, which write result-related information into the residual stream (shown in Figure 1). We verify this finding for bi- and tri-variate arithmetic queries across four pretrained language models of three sizes: 2.8B, 6B, and 7B parameters. Finally, we compare the effect of different model components on answering arithmetic questions with their effect on two additional tasks: a synthetic task that involves retrieving a number from the prompt, and answering questions about factual knowledge. This comparison validates the specificity of the activation dynamics observed on arithmetic queries.
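To localize where result-related information flows, the same kind of intervention can be swept over module types, layers, and token positions. The sketch below assumes a hypothetical toy residual-stream model with random weights (single-head causal attention without value or output projections, a one-matrix tanh MLP, no layer normalization), a counterfactual prompt that differs from the base prompt at a single position, and a simplified effect measure: the change in the probability of a hypothetical counterfactual answer index `r_cf`. It records one effect per (module, layer, position) triple, the kind of grid from which an information flow is read off.

```python
import numpy as np

rng = np.random.default_rng(1)
T, D, L, V = 5, 16, 3, 10  # hypothetical: tokens, width, layers, answer vocabulary size

W_q = rng.normal(size=(L, D, D)) / np.sqrt(D)
W_k = rng.normal(size=(L, D, D)) / np.sqrt(D)
W_mlp = rng.normal(size=(L, D, D)) / np.sqrt(D)
W_u = rng.normal(size=(D, V))  # unembedding read from the final position


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def run(x, patch=None):
    """x: (T, D) embeddings. patch: (kind, layer, pos, cache) overwrites one component output at one position."""
    resid = x.copy()
    cache = {}
    mask = np.triu(np.full((T, T), -np.inf), k=1)  # causal mask: position t sees positions <= t
    for l in range(L):
        scores = (resid @ W_q[l]) @ (resid @ W_k[l]).T / np.sqrt(D) + mask
        attn_out = softmax(scores) @ resid
        if patch is not None and patch[:2] == ("attn", l):
            attn_out[patch[2]] = patch[3][("attn", l)][patch[2]]
        cache[("attn", l)] = attn_out
        resid = resid + attn_out  # attention writes into the residual stream
        mlp_out = np.tanh(resid @ W_mlp[l])
        if patch is not None and patch[:2] == ("mlp", l):
            mlp_out[patch[2]] = patch[3][("mlp", l)][patch[2]]
        cache[("mlp", l)] = mlp_out
        resid = resid + mlp_out  # the MLP writes into the residual stream
    return softmax(resid[-1] @ W_u), cache


x_base = rng.normal(size=(T, D))
x_cf = x_base.copy()
x_cf[1] = rng.normal(size=D)  # hypothetical counterfactual: only position 1 (e.g. an operand) changes
r_cf = 7  # hypothetical index of the counterfactual result

p_base, _ = run(x_base)
_, cf_cache = run(x_cf)

# Effect of patching each component at each position with its counterfactual output.
effects = np.zeros((2, L, T))
for k, kind in enumerate(("attn", "mlp")):
    for l in range(L):
        for t in range(T):
            p_patched, _ = run(x_base, patch=(kind, l, t, cf_cache))
            effects[k, l, t] = p_patched[r_cf] - p_base[r_cf]

print(np.round(effects, 3))
k, l, t = np.unravel_index(np.abs(effects).argmax(), effects.shape)
print(f"largest |effect|: {('attn', 'mlp')[k]} at layer {l}, position {t}")
```

Because attention is causally masked, components at position 0 never see the altered token, so their effects are zero; this doubles as a sanity check. With random weights, the toy makes no prediction about where effects concentrate in a real LM.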