
Attention and Linear Regression

Mufeng Tang

November 10, 2023

Ever since Hopfield Networks is All You Need, many papers in the past few years have pointed to the connection between Associative Memory (AM) models and the attention mechanism in Transformers. To name but a few: Universal Hopfield Network, Kernel Memory Networks, and the Tolman Eichenbaum Machine. Although the title of ‘Linking Transformers to how the brain performs computations’ sounds very cool, at the heart of these connections is the product \(AB^\top\) of two matrices \(A, B\in\mathbb R^{N\times d}\), which results in an \(N\times N\) matrix whose entries are the dot products between the rows of \(A\) and the rows of \(B\). In AM models like Hopfield Networks, this happens when you compare your query with each of the stored memories, while in Transformers, this is the ‘affinity’ matrix \(QK^\top\). These \(N\times N\) matrices are nice, because they store all the pair-wise relationships between what you have (\(K\)) and what you are presented/queried with (\(Q\)).

Since I’ve been working on AM for quite a while, I recently noticed another fundamental model that might be (loosely) connected to attention: linear regression. I find this connection much more interesting than the connection between AM models and Transformers, as AM models seek to memorize training patterns, while we know that Transformers’ power goes far beyond memorizing training data: they generalize. Meanwhile, as we learned in our first class of statistics, regression models also generalize.

In the rest of this blog, I will write down the exact mathematical expressions of this connection between linear regression and self-attention, and try to provide an interpretation and some open questions along this line. I do not seek to claim that ‘this is going to be an interesting research topic’; rather, it’s just an interesting finding that is worth transcribing from my scratch paper to a markdown file. Whether it is worth moving further to a formal LaTeX file remains to be discovered - if you are reading this and come up with a new research idea, let me know.

The Maths

Let’s start from the simplest linear regression that we learned in STATS101 (although I remember having a whole course on this - what did I learn?). Let’s say we have some training data: independent variables \(X\in\mathbb R^{N\times d}\) and dependent variables \(Y\in\mathbb R^{N\times d_y}\). Also let’s call the regression coefficients/parameters \(\beta\in\mathbb R^{d\times d_y}\). The objective function of linear regression is:

\[ \min_{\beta} \Vert Y - X\beta \Vert_F^2 \]

To get the optimal \(\beta\), we take the derivative of the squared Frobenius norm with respect to \(\beta\) and set it to 0. Assuming \(X^\top X\) is invertible, the optimal \(\hat\beta\) can be written as (I’ll skip the steps because you can literally find them everywhere):

\[ \hat\beta = (X^\top X)^{-1}X^\top Y \]

Then, according to this linear model, the fitted values should be:

\[ \hat Y = X\hat\beta = X(X^\top X)^{-1}X^\top Y \]

Without loss of generality, if we assume the \(N\) data points have zero mean, the middle \((X^\top X)^{-1}\) matrix is, up to a constant scaling factor, the inverse of the covariance matrix of the dataset. Let’s call it \(S^{-1}\); the fitted values then become:

\[ \hat Y = \textcolor{red}{X} \textcolor{blue}{S^{-1}} \textcolor{red}{X^\top} Y \]

(the colors will be useful later). That’s all about linear regression.

Now let’s look at the attention mechanism. For clarity and simplicity, I’ll focus on self-attention here. We again assume some data \(X\in\mathbb R^{N\times d}\), where \(N\) now denotes the input sequence length and \(d\) the size of each embedded token vector. We then multiply \(X\) with three trainable weight matrices \(W_q \in\mathbb R^{d\times d_k}\), \(W_k\in\mathbb R^{d\times d_k}\) and \(W_v\in\mathbb R^{d\times d_v}\) to get the query \(Q=XW_q\in\mathbb R^{N\times d_k}\), key \(K=XW_k\in\mathbb R^{N\times d_k}\) and value \(V=XW_v \in\mathbb R^{N\times d_v}\) respectively.
The attention is then calculated as:

\[ A(ttention) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V \]

Now, let’s express the \(Q\) and \(K\) matrices in terms of \(X\) and their corresponding weight matrices, with some color codes:

\[ A = \text{softmax}\left(\textcolor{red}{X} \textcolor{blue}{\frac{W_qW_k^\top}{\sqrt{d_k}}} \textcolor{red}{X^\top}\right)V \quad cf. \quad \hat Y = \textcolor{red}{X} \textcolor{blue}{(X^\top X)^{-1}} \textcolor{red}{X^\top} Y = \textcolor{red}{X} \textcolor{blue}{S^{-1}} \textcolor{red}{X^\top} Y \]

I hope at this point the color-coded equation has made it quite clear that computing the fitted values of linear regression and computing self-attention are doing something similar: both are built around a matrix of the form \(XMX^\top\). For linear regression this \(M\) is the inverse covariance matrix of the data, and for self-attention this \(M\) is the (scaled) product of the learnable parameters \(W_qW_k^\top\). In addition, self-attention has a softmax function that applies a row-wise nonlinear transformation to this matrix product.
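As a quick sanity check of the two expressions above, here is a minimal NumPy sketch. The random data, the sizes, and the helper name `softmax` are assumptions chosen for illustration; nothing here is trained. It checks that \(XS^{-1}X^\top Y\) matches an ordinary least-squares fit, and that the usual \(QK^\top\) form of self-attention equals the \(XMX^\top\) form with \(M = W_qW_k^\top/\sqrt{d_k}\).

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, d_y, d_k, d_v = 8, 3, 2, 4, 5  # illustrative sizes only

# --- Linear regression: fitted values written as X S^{-1} X^T Y ---
X = rng.normal(size=(N, d))
Y = rng.normal(size=(N, d_y))
S_inv = np.linalg.inv(X.T @ X)          # the "blue" matrix M for regression
Y_hat = X @ S_inv @ X.T @ Y

# Cross-check against NumPy's least-squares solver
beta_hat, *_ = np.linalg.lstsq(X, Y, rcond=None)
regression_matches = np.allclose(Y_hat, X @ beta_hat)

# --- Self-attention: softmax(Q K^T / sqrt(d_k)) V ---
def softmax(Z):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    Z = Z - Z.max(axis=-1, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=-1, keepdims=True)

# Random (untrained) projection weights, standing in for learned ones
W_q = rng.normal(size=(d, d_k))
W_k = rng.normal(size=(d, d_k))
W_v = rng.normal(size=(d, d_v))
Q, K, V = X @ W_q, X @ W_k, X @ W_v

A_qk = softmax(Q @ K.T / np.sqrt(d_k)) @ V

# Same computation, written in the X M X^T form
M = W_q @ W_k.T / np.sqrt(d_k)          # the "blue" matrix M for attention
A_xmx = softmax(X @ M @ X.T) @ V
attention_matches = np.allclose(A_qk, A_xmx)

print(regression_matches, attention_matches)
```

The only structural differences between the two branches are the choice of \(M\) and the softmax wrapped around \(XMX^\top\).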

Interpretations

Having set up the mathematical similarity between these two, we can come up with some interpretations of both linear regression and self-attention.

First, for linear regression: We can now interpret linear regression as a type of Universal Hopfield Network. The fitted values \(\hat Y\) are a weighted sum of the original dependent variables \(Y_1,...,Y_N\) in the training set, where the weights are some kind of ‘similarity measure’ between each pair of training data \(X_i\) and \(X_j\). When \(S\) is identity, this similarity measure is simply the dot product, and I have hand-drawn a diagram of this process above. Imagine the \(X_i\)’s are a bunch of orthonormal vectors; then the dot product \(X_i^\top X_j\) will be zero unless \(i=j\). In our example above this will give us exactly the desired output \(Y_2\). However, we almost never get orthonormal vectors, and similarity by dot product is almost never a good idea. The original Universal Hopfield Network paper discussed this, so I will not go into details here.
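The orthonormal case can be checked numerically. The sketch below is illustrative only: the data are random, the orthonormal patterns are built with a QR decomposition (my choice, not something from the diagram), and the similarity is the plain dot product, i.e. \(M = I\).

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, d_y = 4, 6, 2  # illustrative sizes, with N <= d

# Build N orthonormal patterns: QR of a random (d, N) matrix gives
# orthonormal columns, which become the rows X_i after transposing.
Q_mat, _ = np.linalg.qr(rng.normal(size=(d, N)))
X = Q_mat.T                      # shape (N, d), rows are orthonormal
Y = rng.normal(size=(N, d_y))

sim = X @ X.T                    # pairwise dot products X_i^T X_j
Y_hat = sim @ Y                  # weighted sum of the stored Y_j
exact = np.allclose(sim, np.eye(N)) and np.allclose(Y_hat, Y)

# Querying with the second stored pattern returns exactly Y_2
recalled_y2 = X[1] @ X.T @ Y

# With generic (non-orthonormal) patterns the dot-product weights mix
# several Y_j together, so recall is no longer exact.
X_raw = rng.normal(size=(N, d))
Y_raw = (X_raw @ X_raw.T) @ Y
mixed = not np.allclose(Y_raw, Y)

print(exact, mixed)
```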

When \(S\) is not identity, the product \(X_i^\top S^{-1} X_j\) is in fact a ‘whitened’ dot product that makes the similarity measure more robust to variances and correlations between features, because we can effectively decompose \(S^{-1}\) into whitening matrices. I had some experiments in my recent paper that show the benefits of this whitening step before the dot product. Conceptually, the division by \(\sqrt{d_k}\) in self-attention seems to serve a similar purpose, namely keeping the scale of the query/key dot products under control, although it doesn’t help handle correlated features like a whitening matrix does.
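To make the ‘whitened dot product’ reading concrete, here is a small sketch under stated assumptions: the data are synthetic, and the whitening matrix is taken to be the symmetric square root \(W = S^{-1/2}\) obtained from an eigendecomposition (one of several valid whitening choices). It checks that \(X S^{-1} X^\top\) equals the plain dot product between whitened points.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 200, 3

# Synthetic features with different scales and correlations
mix = np.array([[3.0, 0.0, 0.0],
                [1.5, 1.0, 0.0],
                [0.0, 0.5, 0.2]])
X = rng.normal(size=(N, d)) @ mix
X = X - X.mean(axis=0)                 # zero-mean, as assumed in the text

S = X.T @ X                            # (unnormalized) covariance matrix
eigvals, eigvecs = np.linalg.eigh(S)   # S is symmetric positive definite here

# Symmetric whitening matrix W = S^{-1/2}, so that W @ W = S^{-1}
W = eigvecs @ np.diag(eigvals ** -0.5) @ eigvecs.T
X_white = X @ W

sim_direct = X @ np.linalg.inv(S) @ X.T   # X_i^T S^{-1} X_j for all pairs
sim_white = X_white @ X_white.T           # plain dot product after whitening

same = np.allclose(sim_direct, sim_white)
decorrelated = np.allclose(X_white.T @ X_white, np.eye(d))
print(same, decorrelated)
```

After whitening, the features have identity second-moment matrix, which is the sense in which the similarity no longer depends on feature scales or correlations.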

Another important point to make about this interpretation is that although linear regression itself isn’t an AM model, in the case where we query it with an original training data point, say \(X_0\) (instead of a new data point), what it does is essentially AM, by memorizing the pattern \([X_0, Y_0]\). When \(N \leq d\) we can actually recall \(Y_0\) perfectly (imagine fitting a line with two dots on a 2d plane), so the capacity of a linear regression AM model (on condition of perfect recall) should be \(N_{max}=d\).
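The capacity claim can be probed with a short experiment. One caveat that the sketch has to handle: when \(N < d\), \(X^\top X\) is singular, so the code uses the Moore-Penrose pseudo-inverse (`np.linalg.pinv`) in place of the plain inverse. That substitution is an assumption of this sketch, as are the random data and the helper name `recall_error`.

```python
import numpy as np

rng = np.random.default_rng(0)
d, d_y = 5, 2  # illustrative feature and target dimensions

def recall_error(N):
    """Fit on N random pairs, then query with the same X and report the
    worst-case absolute error in recalling Y."""
    X = rng.normal(size=(N, d))
    Y = rng.normal(size=(N, d_y))
    # The pseudo-inverse gives the minimum-norm least-squares solution and
    # coincides with (X^T X)^{-1} X^T whenever X^T X is invertible.
    beta = np.linalg.pinv(X) @ Y
    return np.abs(X @ beta - Y).max()

err_under = recall_error(3)    # N < d: recall is exact up to rounding
err_equal = recall_error(5)    # N = d: still exact up to rounding
err_over = recall_error(20)    # N > d: residuals appear, recall is lossy

print(err_under, err_equal, err_over)
```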

Second, for attention: This similarity brings us to interpret the attention mechanism as some type of regression operation. If we map the value \(V\) to our dependent variable \(Y\), attention can be considered as the fitted value by a regression model trained on \(N\) \(X_i\) and \(V_i\) pairs, given a seen example \(X\) (self-attention) or an unseen example \(\tilde X\) (cross-attention). So at the end of the day, the attention layer is secretly learning a linear regression between its key-value pairs.
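A minimal sketch of this reading, with assumptions stated up front: the softmax is dropped, \(M\) is set to \((X^\top X)^{-1}\) by hand rather than learned, the values are set to \(Y\), and `linear_attention` is a hypothetical helper name. Under those choices, attending from unseen queries \(\tilde X\) to the stored pairs reproduces the ordinary regression prediction.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, d_y, N_new = 10, 3, 2, 4  # illustrative sizes

X = rng.normal(size=(N, d))          # stored "keys"
Y = rng.normal(size=(N, d_y))        # stored "values"
X_new = rng.normal(size=(N_new, d))  # unseen queries

def linear_attention(queries, keys, values, M):
    """Softmax-free attention: (queries M keys^T) values."""
    return queries @ M @ keys.T @ values

# Choose M as the inverse (unnormalized) covariance of the keys
M = np.linalg.inv(X.T @ X)
pred_attention = linear_attention(X_new, X, Y, M)

# Ordinary least-squares prediction for the same queries
beta_hat = np.linalg.lstsq(X, Y, rcond=None)[0]
pred_regression = X_new @ beta_hat

print(np.allclose(pred_attention, pred_regression))
```

Real attention differs in exactly the two places noted above: \(M\) is learned rather than computed from the data, and the softmax reweights the affinities before they are applied to the values.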

Some open questions worth looking into

Despite the similarity, one can argue that the connection between attention and linear regression is quite loose. Indeed, their differences are also quite obvious: attention uses a softmax function that makes the largest affinity between two vectors stand out, whereas linear regression doesn’t have this mechanism; the matrix \(W_q W_k^\top\) in attention is trainable, whereas the inverse covariance matrix \((X^\top X)^{-1}\) is directly calculated from the data. These differences open up a few questions that I wanted to look into (if time and energy allow of course…):

As I’m not an expert in Transformers or LLMs (been in the niche of NeuroAI for too long…), some of these questions have probably been asked or addressed. If you are reading this blog and happen to know any works related to these questions, please let me know. At the same time, if any of these findings/questions sparks an idea, please also let me know and I’m happy to discuss :)
