
Attention and Linear Regression

Mufeng Tang

November 10, 2023

Ever since Hopfield Networks is All You Need, many papers in the past few years have pointed to the connection between Associative Memory (AM) models and the attention mechanism in Transformers. To name but a few: Universal Hopfield Networks, Kernel Memory Networks, and the Tolman-Eichenbaum Machine. Although the title of 'Linking Transformers to how the brain performs computations' sounds very cool, at the heart of these connections is the outer product between two matrices $A, B \in \mathbb{R}^{N \times d}$, which results in an $N \times N$ matrix $AB^\top$. In AM models like Hopfield Networks, this happens when you try to compare your query with each of the stored memories, while in Transformers, this is the 'affinity' matrix $QK^\top$. These $N \times N$ matrices are nice because they store all the pair-wise relationships between what you have ($K$) and what you are presented/queried with ($Q$).

Since I’ve been working on AM for quite a while, I recently noticed another fundamental model that might be (loosely) connected to attention: linear regression. I find this connection much more interesting than the connection between AM models and Transformers, as AM models seek to memorize training patterns, while we know that Transformers’ power goes far beyond memorizing training data: they generalize. Meanwhile, as we learned in our first class of statistics, regression models also generalize.

In the following part of this blog, I will write down the exact mathematical expressions of this connection between linear regression and self-attention, and try to provide an interpretation and some open questions following this line. I do not seek to claim that 'this is going to be an interesting research topic'; rather, it's just an interesting finding that is worth being transcribed from my scratch paper to a markdown file. Whether it is worth being moved further into a formal LaTeX file remains to be discovered - if you are reading this and come up with a new research idea, let me know.

The Maths

Let's start from the simplest linear regression that we learned in STATS101 (although I remember having a whole course on this - what did I learn?). Let's say we have some training data: independent variables $X \in \mathbb{R}^{N \times d}$ and dependent variables $Y \in \mathbb{R}^{N \times d_y}$. Also, let's call the regression coefficients/parameters $\beta \in \mathbb{R}^{d \times d_y}$. The objective function of linear regression is:

$$\min_\beta \|Y - X\beta\|_F^2$$

To get the optimal $\beta$, we take the derivative of the squared Frobenius norm with respect to $\beta$ and set it to 0. The optimal $\hat{\beta}$ can be written as (I'll skip the steps because you can literally find it everywhere):

$$\hat{\beta} = (X^\top X)^{-1} X^\top Y$$

Then, according to this linear model, the fitted values should be:

$$\hat{Y} = X\hat{\beta} = X(X^\top X)^{-1} X^\top Y$$

Without loss of generality, if we assume the $N$ data points have zero mean, the middle matrix $(X^\top X)^{-1}$ is (up to a factor of $N$) the inverse of the covariance matrix of the dataset. Let's call it $S^{-1}$; then the fitted values become:

$$\hat{Y} = X S^{-1} X^\top Y$$

(keep this form in mind; it will be useful later). That's all about linear regression.

Now let's look at the attention mechanism. For clarity and simplicity, I'll focus on self-attention here. We again assume some data $X \in \mathbb{R}^{N \times d}$, where $N$ now denotes the input sequence length and $d$ the size of each embedded token vector. We then multiply $X$ with three trainable weight matrices $W_q \in \mathbb{R}^{d \times d_k}$, $W_k \in \mathbb{R}^{d \times d_k}$ and $W_v \in \mathbb{R}^{d \times d_v}$ respectively to get the query $Q = XW_q \in \mathbb{R}^{N \times d_k}$, key $K = XW_k \in \mathbb{R}^{N \times d_k}$ and value $V = XW_v \in \mathbb{R}^{N \times d_v}$. The attention is then calculated as:

$$A(\text{ttention}) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

Now, let's express the $Q$ and $K$ matrices in terms of $X$ and their corresponding weight matrices, and place the result next to the regression expression:

$$A = \mathrm{softmax}\left(\frac{X W_q W_k^\top X^\top}{\sqrt{d_k}}\right)V \qquad \text{cf.} \qquad \hat{Y} = X(X^\top X)^{-1} X^\top Y = X S^{-1} X^\top Y$$

I hope at this point the side-by-side equations have made it quite clear that calculating the fitted values of linear regression and calculating self-attention are doing something similar: they both take the form $X M X^\top$. For linear regression this $M$ is the inverse covariance matrix of the data, and for self-attention this $M$ is the product of the learnable parameters $W_q W_k^\top$. In addition, self-attention has a softmax function that performs a nonlinear transformation of this matrix product.
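
To make the comparison concrete, here is a minimal NumPy sketch (my own illustration, not part of the original derivation) that computes both quantities side by side; the toy shapes, the random data, and the randomly initialized `Wq`/`Wk` are all assumptions made purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, d_y, d_k = 8, 4, 3, 4

# Toy data: X plays the role of both the regression inputs and the token embeddings.
X = rng.normal(size=(N, d))
X -= X.mean(axis=0)              # zero-mean, so X.T @ X is (N times) the covariance
Y = rng.normal(size=(N, d_y))    # regression targets / "values"

# Linear regression fitted values: Y_hat = X (X^T X)^{-1} X^T Y = X S^{-1} X^T Y
S_inv = np.linalg.inv(X.T @ X)
Y_hat = X @ S_inv @ X.T @ Y

# Self-attention with value V = Y and random query/key weights:
# A = softmax(X Wq Wk^T X^T / sqrt(d_k)) V
def softmax(Z, axis=-1):
    Z = Z - Z.max(axis=axis, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=axis, keepdims=True)

Wq = rng.normal(size=(d, d_k))
Wk = rng.normal(size=(d, d_k))
A = softmax(X @ Wq @ Wk.T @ X.T / np.sqrt(d_k)) @ Y

# Both outputs are row-wise mixtures of Y, with mixing weights given by X M X^T.
print(Y_hat.shape, A.shape)      # (8, 3) (8, 3)
```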

Interpretations

Having set up the mathematical similarity between these two, we can come up with some interpretations of both linear regression and self-attention.

First, for linear regression: We can now interpret linear regression as a type of Universal Hopfield Network. The fitted values $\hat{Y}$ are now a weighted sum of the original dependent variables $Y_1, \dots, Y_N$ in the training set, where the weights are some kind of 'similarity measure' between each pair of training data $X_i$ and $X_j$. When $S$ is the identity, this similarity measure is simply the dot product, and I have hand-drawn a diagram of this process above. Imagine if the $X_i$'s are a bunch of orthonormal vectors; then the dot product $X_i^\top X_j$ will be zero unless $i = j$. In our example above this will give us exactly the desired output $Y_2$. However, we almost never get orthonormal vectors, and similarity by dot product is almost never a good idea. The original Universal Hopfield Networks paper discussed this, so I will not go into details here.
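
As a quick sanity check of the orthonormal case (my own toy example, not from the diagram above): when the rows of $X$ are orthonormal and $S$ is the identity, the similarity matrix $XX^\top$ is exactly the identity, so the readout returns each stored $Y_i$ perfectly.

```python
import numpy as np

rng = np.random.default_rng(1)
N, d, d_y = 3, 5, 2

# Build X with orthonormal rows via a QR decomposition of a random matrix.
Q, _ = np.linalg.qr(rng.normal(size=(d, N)))  # Q: (d, N) with orthonormal columns
X = Q.T                                       # (N, d) with orthonormal rows
Y = rng.normal(size=(N, d_y))

# With S = I the readout is X X^T Y; since X X^T = I_N here,
# each stored Y_i is recalled exactly.
readout = X @ X.T @ Y
print(np.allclose(readout, Y))   # True
```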

When $S$ is not the identity, the product $X_i^\top S^{-1} X_j$ is in fact a 'whitened' dot product that makes the similarity measure more robust to variances and correlations between features, because we can effectively decompose $S^{-1}$ into whitening matrices. I had some experiments in my recent paper that show the benefits of this whitening step before the dot product. Conceptually, the division by $\sqrt{d_k}$ in self-attention seems to serve a similar purpose in controlling the variance of the key/query dot products, although it doesn't handle correlated features the way a whitening matrix does.
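
For concreteness, here is a small sketch (again my own illustration) of the whitening view: decomposing $S^{-1} = W^\top W$, e.g. via a Cholesky factorization, turns $X_i^\top S^{-1} X_j$ into an ordinary dot product between the whitened vectors $W X_i$ and $W X_j$.

```python
import numpy as np

rng = np.random.default_rng(2)
N, d = 50, 4

# Correlated, zero-mean features.
X = rng.normal(size=(N, d)) @ rng.normal(size=(d, d))
X -= X.mean(axis=0)

S_inv = np.linalg.inv(X.T @ X)        # inverse (unnormalized) covariance
L = np.linalg.cholesky(S_inv)         # S^{-1} = L L^T
W = L.T                               # so S^{-1} = W^T W: W is a whitening matrix

Xw = X @ W.T                          # whitened data: row i is (W X_i)^T

# The plain dot product of whitened vectors equals X_i^T S^{-1} X_j for every pair.
print(np.allclose(Xw @ Xw.T, X @ S_inv @ X.T))   # True
```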

Another important point to make about this interpretation is that although linear regression itself isn't an AM model, in the case where we query it with an original training data point, say $X_0$ (instead of a new data point), what it does is essentially AM: it recalls the memorized pattern $[X_0, Y_0]$. When $N \le d$ we can actually recall $Y_0$ perfectly (imagine fitting a line through two dots on a 2D plane). Strictly speaking, when $N < d$ the matrix $X^\top X$ is no longer invertible and we need the minimum-norm (pseudoinverse) solution, but the fitted values on the training points are still exact. So the capacity of a linear regression AM model (on the condition of perfect recall) should be $N_{\max} = d$.
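
A quick numerical check of this recall claim (a toy example of my own): when $N \le d$ and $X$ has full row rank, querying the fitted model with the training inputs reproduces $Y$ exactly; once $N > d$, it generally doesn't. The pseudoinverse is used below so the same expression also covers the $N < d$ case.

```python
import numpy as np

rng = np.random.default_rng(3)

def recall_error(N, d, d_y=2):
    """Query the fitted regression with its own training inputs and
    measure how far the 'recalled' outputs are from the stored Y."""
    X = rng.normal(size=(N, d))
    Y = rng.normal(size=(N, d_y))
    # Pseudoinverse handles the N < d case, where X^T X is singular.
    Y_hat = X @ np.linalg.pinv(X) @ Y
    return np.abs(Y_hat - Y).max()

print(recall_error(N=4, d=8))    # ~0: perfect recall when N < d
print(recall_error(N=8, d=8))    # ~0: still exact at N = d (the capacity)
print(recall_error(N=32, d=8))   # order 1: recall fails once N > d
```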

Second, for attention: This similarity leads us to interpret the attention mechanism as some type of regression operation. If we map the value $V$ to our dependent variable $Y$, attention can be considered as the fitted values of a regression model trained on $N$ pairs $(X_i, V_i)$, evaluated at a seen example $X$ (self-attention) or an unseen example $\tilde{X}$ (cross-attention). So at the end of the day, the attention layer is secretly learning a linear regression between its keys and values.
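
Following this reading, here is a rough sketch of the cross-attention analogue (once more my own illustration, with toy shapes and randomly initialized weights as assumptions): fit a regression from the 'key-side' inputs $X$ to the values $V$, then evaluate it at new queries $\tilde{X}$, next to an attention-style readout of the same key/value set.

```python
import numpy as np

rng = np.random.default_rng(4)
N, M, d, d_v, d_k = 16, 5, 8, 3, 8

X = rng.normal(size=(N, d))        # "key-side" inputs
V = rng.normal(size=(N, d_v))      # values (regression targets)
X_new = rng.normal(size=(M, d))    # unseen queries: the cross-attention case

# Regression view: fit beta on (X, V), then predict for the new queries.
beta = np.linalg.pinv(X) @ V
V_pred = X_new @ beta              # = X_new (X^T X)^{-1} X^T V when X has full column rank

# Attention view: the same X M X^T structure, but with a learnable M = Wq Wk^T
# and a softmax over the affinities.
def softmax(Z, axis=-1):
    Z = Z - Z.max(axis=axis, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=axis, keepdims=True)

Wq = rng.normal(size=(d, d_k))
Wk = rng.normal(size=(d, d_k))
V_attn = softmax(X_new @ Wq @ Wk.T @ X.T / np.sqrt(d_k)) @ V

print(V_pred.shape, V_attn.shape)  # (5, 3) (5, 3): both mix the stored values row-wise
```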

Some open questions worth looking into

Despite the similarity, one can argue that the connection between attention and linear regression is quite loose. Indeed, their differences are also quite obvious: attention uses a softmax function that makes the largest affinity between two vectors stand out, whereas linear regression has no such mechanism; the matrix $W_q W_k^\top$ in attention is trainable, whereas the inverse covariance matrix $(X^\top X)^{-1}$ is directly calculated from the embedded tokens. These differences open up a few questions that I wanted to look into (if time and energy allow, of course…):

As I'm not an expert in Transformers or LLMs (been in the niche of NeuroAI for too long…), some of these questions have probably been asked or addressed. If you are reading this blog and happen to know any work related to these questions, please let me know. At the same time, if any of these findings/questions sparks an idea, please also let me know - I'm happy to discuss :)
