
Linear Regression: Understanding the Matrix Calculus

Linear regression is possibly the most well-known machine learning algorithm. It tries to find a linear relationship between a given set of input-output pairs. One notable aspect is that linear regression, unlike most of its peers, has a closed-form solution.

The mathematics involved in the derivation of this solution (also known as the Normal equation) is pretty basic. However, to understand the equation in its commonly-used form, we need to appreciate some matrix calculus. In this post, I will attempt to explain, from the ground up, the linear regression formula along with the necessary matrix calculus. I do assume that you are familiar with matrices (like transposes and matrix multiplication), and basic calculus.

The Basics

Given a set of \(n\) data points \((\mathbf{x}_1, y_1)\), \((\mathbf{x}_2, y_2)\), \(\dots\) \((\mathbf{x}_n, y_n)\), linear regression tries to find a line (or a hyperplane in higher dimensions) which maps the input \(\mathbf{x}\) to output \(y\). Here, each \(\mathbf{x}_i\) may be a \(d\)-dimensional tuple, i.e., \((x_{i,1}, x_{i,2}, \dots, x_{i,d})\).

A linear function on \(\mathbf{x}_i\) can be represented as \(w_0 + w_1 x_{i,1} + \dots + w_d x_{i,d}\), where \(w_0, \dots, w_d\) are real numbers. In order to make the formula more uniform, we assume that \(\mathbf{x}_i\) contains an extra element \(x_{i,0}\) which is always equal to \(1\). So our required function becomes \(w_0 x_{i,0} + w_1 x_{i,1} + \dots + w_d x_{i,d}\). We denote it as \(h_{\mathbf{w}}(\mathbf{x}_i)\).

Using matrices, we can write \(h_{\mathbf{w}}(\mathbf{x}_i)\) in a much more compact form. Conventionally, we use column matrices to represent vectors. Thus we have

\[ \mathbf{w} = \begin{pmatrix} w_0\\\
w_1\\\
w_2\\\
\vdots\\\
w_d \end{pmatrix}, \hspace{1cm}\mathbf{x}_i = \begin{pmatrix} x_{i,0}\\\
x_{i,1}\\\
x_{i,2}\\\
\vdots\\\
x_{i,d} \end{pmatrix} \]

Our function \(h_{\mathbf{w}}(\mathbf{x}_i)\) thus can be written as \(\mathbf{w}^\intercal\mathbf{x}_i\), or equivalently, as \(\mathbf{x}_i^\intercal\mathbf{w}\).

If there are \(n\) data points \(\mathbf{x}_1, \mathbf{x}_2, \dots, \mathbf{x}_n\), the outputs corresponding to all these can be kept together in a column matrix as follows: \[ \begin{pmatrix} \mathbf{x}_1^\intercal\mathbf{w}\\\
\mathbf{x}_2^\intercal\mathbf{w}\\\
\vdots\\\
\mathbf{x}_n^\intercal\mathbf{w} \end{pmatrix} \]

Note that this column matrix can be decomposed into the following product \[ \begin{pmatrix} \mathbf{x}_1^\intercal\mathbf{w}\\\
\mathbf{x}_2^\intercal\mathbf{w}\\\
\vdots\\\
\mathbf{x}_n^\intercal\mathbf{w} \end{pmatrix} = \begin{pmatrix} —\mathbf{x}_1^\intercal— \\\
—\mathbf{x}_2^\intercal—\\\
\vdots\\\
—\mathbf{x}_n^\intercal— \end{pmatrix} \times \mathbf{w} \]

It is important to keep in mind that even though this last product looks like a scalar multiplication, \(\mathbf{w}\) is in fact a matrix; the LHS and RHS are nevertheless equal.

Understandably, we call the first matrix in the decomposition \(\mathbf{X}\). It is an \(n\times (d+1)\) matrix, each of whose rows represents a data point. Thus we can compute the outputs \(h_\mathbf{w}(\mathbf{x}_i)\) corresponding to all the data points \(\mathbf{x}_i\) using the single matrix product \(\mathbf{X}\mathbf{w}\).
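To see this at work numerically, here is a minimal NumPy sketch. The data and variable names are made up purely for illustration; nothing here is specific to any library beyond plain NumPy.

```python
import numpy as np

# Three data points, each with d = 2 features (made-up numbers).
X_raw = np.array([[1.0, 2.0],
                  [3.0, 4.0],
                  [5.0, 6.0]])

# Prepend the constant column x_{i,0} = 1 to get the n x (d+1) matrix X.
X = np.hstack([np.ones((X_raw.shape[0], 1)), X_raw])

# A weight column matrix w of shape (d+1) x 1.
w = np.array([[0.5], [1.0], [-2.0]])

# All n predictions h_w(x_i) in a single matrix product.
predictions = X @ w   # shape (n, 1)
print(predictions)
```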

The Loss Function

To recapitulate, the linear function we want to learn is represented by the weight matrix \(\mathbf{w}\) and, for a given set of inputs \(\mathbf{X}\), the output of the linear function is given by \(\mathbf{X}\mathbf{w}\). We also have a set of \(y\) values. The scalar \(y_i\) is the actual output corresponding to the input \(\mathbf{x}_i\). These \(y\)’s can also be treated together as a column matrix \(\mathbf{y}\), whose \(i^\text{th}\) element is the \(i^\text{th}\) output.

Thus \(\mathbf{X}\mathbf{w}\) is the output given (or predicted) by our linear function, and \(\mathbf{y}\) is the actual output. The difference is simply \(\mathbf{X}\mathbf{w} - \mathbf{y}\). We can see that this matrix is a column matrix with \(n\) rows.

The aim of linear regression is to minimize the errors as much as possible. But rather than try to minimize each error separately—which would be a hard task, as decreasing one error might cause another error to shoot up—we try to minimize the sum of squares of the individual errors.

If \(\mathbf{a}\) is a column matrix, the sum of squares of its elements is just the product \(\mathbf{a}^\intercal \mathbf{a}\). (Why?) Therefore, what we would like to minimize is \((\mathbf{X}\mathbf{w} - \mathbf{y})^\intercal (\mathbf{X}\mathbf{w} - \mathbf{y})\). This quantity is known as the loss function of \(\mathbf{w}\).
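Here is a short, self-contained sketch of that computation, with made-up inputs and targets; it also checks that \(\mathbf{a}^\intercal \mathbf{a}\) really is the sum of squares of the elements.

```python
import numpy as np

# Made-up design matrix (leading column of ones), weights, and targets.
X = np.array([[1.0, 1.0, 2.0],
              [1.0, 3.0, 4.0],
              [1.0, 5.0, 6.0]])
w = np.array([[0.5], [1.0], [-2.0]])
y = np.array([[1.0], [2.0], [3.0]])

# The residual column matrix Xw - y.
residuals = X @ w - y

# Sum of squared errors written as the product a^T a (a 1 x 1 matrix).
loss = (residuals.T @ residuals).item()

# The same number computed element-wise, confirming the identity.
assert np.isclose(loss, np.sum((X @ w - y) ** 2))
print(loss)
```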

Minimizing the Loss Function

The solution to the linear regression problem is the point \(\mathbf{w}\) at which the loss function is the minimum. To find it, we simply find the derivative of the loss function and equate it to zero.

In general, functions may have multiple minima and/or maxima. Some functions may not even have a minimum¹. But here, we don’t have to worry about those cases. The sum of squares of errors, our loss function, is a quadratic function. It turns out that a quadratic function (think of \(y=x^2\)) has only one extreme point. In our case, that extreme point happens to be the minimum².

Partial Derivatives

Our loss function depends on not one, but \(d+1\) variables. So to find the minimum point, we need to take partial derivatives with respect to all of these variables and equate them to zero. Rather than take the \((d+1)\) derivatives separately, we can use the power of matrices to avoid the unnecessary work.

First, let us use a simple convention. To differentiate a scalar with respect to a column matrix of variables, we differentiate the scalar with respect to each variable in the column matrix and collect the results in a column matrix. The output will thus have the same shape as the denominator.

If \(f(\mathbf{w})\) is a scalar, its derivative wrt \(\mathbf{w}\) will thus be:

\[ \frac{\partial f(\mathbf{w})}{\partial\mathbf{w}} = \begin{pmatrix} \frac{\partial f(\mathbf{w})}{\partial w_0}\\\
\frac{\partial f(\mathbf{w})}{\partial w_1}\\\
\vdots\\\
\frac{\partial f(\mathbf{w})}{\partial w_d} \end{pmatrix} \]

As an example, if \(\mathbf{z}\) is an \(n\times 1\) matrix, \(\frac{\partial \mathbf{z}^\intercal \mathbf{z}}{\partial \mathbf{z}}\) is \[ \begin{pmatrix} \frac{\partial}{\partial z_1} (z_1^2 + \dots + z_n^2)\\\
\vdots\\\
\frac{\partial}{\partial z_n} (z_1^2 + \dots + z_n^2) \end{pmatrix} = \begin{pmatrix} 2z_1\\\
\vdots\\\
2z_n \end{pmatrix} = 2\mathbf{z} \]
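If you want some empirical reassurance, a quick finite-difference check of this identity (with an arbitrary \(\mathbf{z}\) that I picked for the sketch) could look like this:

```python
import numpy as np

def f(z):
    # The scalar z^T z.
    return (z.T @ z).item()

z = np.array([[1.0], [-2.0], [3.0]])   # an arbitrary n x 1 column matrix
eps = 1e-6

# Numerical gradient: perturb each element of z in turn.
numeric = np.zeros_like(z)
for i in range(z.shape[0]):
    dz = np.zeros_like(z)
    dz[i, 0] = eps
    numeric[i, 0] = (f(z + dz) - f(z - dz)) / (2 * eps)

# Matches the analytic derivative 2z.
assert np.allclose(numeric, 2 * z)
```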

The Chain Rule

A relatively easy method for computing the derivative of the loss function is using the chain rule.

In the chain rule of calculus, if \(u\) is a function of \(v\) and \(v\) in turn is a function of \(w\), \[\frac{\partial u}{\partial w} = \frac{\partial u}{\partial v}\frac{\partial v}{\partial w}\]

Does the same rule work for matrix differentiation also? Let us find out.

We define \(\mathbf{v} = \mathbf{X}\mathbf{w} - \mathbf{y}\), and \(u = \mathbf{v}^\intercal \mathbf{v}\). Here \(\mathbf{v}\) is a column matrix of size \(n \times 1\), and \(u\) is a scalar.

Then what we wish to compute is \(\frac{\partial u}{\partial \mathbf{w}}\). By the convention we are following, this is a \((d+1)\times 1\) column matrix. Similarly, \(\frac{\partial u}{\partial \mathbf{v}}\) is an \(n \times 1\) matrix.

But what about \(\frac{\partial \mathbf{v}}{\partial \mathbf{w}}\)? It is a differentiation of a column matrix by another column matrix. Intuitively, the output should consist of the derivative of each element in \(\mathbf{v}\) with respect to each element in \(\mathbf{w}\). But in which order should the matrices be processed? Should we take each element of the numerator and differentiate it with respect to each element of the denominator, or take each element of the denominator and differentiate each element of the numerator with respect to it?

We will follow the convention of populating the output matrix in the column-first manner. In the first case, we will get the following \((d+1)\times n\) matrix. Note that we are assuming that the elements of \(\mathbf{v}\) are \(v_1, v_2, \dots, v_n\), as they are the errors corresponding to the data points \(\mathbf{x}_1, \mathbf{x}_2, \dots, \mathbf{x}_n\). In contrast, \(\mathbf{w}\)’s indexing starts with \(0\): \(w_0, w_1, \dots, w_d\).

\begin{pmatrix} \frac{\partial v_1}{\partial w_0} \dots \frac{\partial v_{n}}{\partial w_0}\\\
\vdots\\\
\frac{\partial v_1}{\partial w_d} \dots \frac{\partial v_{n}}{\partial w_d} \end{pmatrix}

In the second case, our output would be an \(n \times (d+1)\) matrix.

\begin{pmatrix} \frac{\partial v_1}{\partial w_0} \dots \frac{\partial v_{1}}{\partial w_d}\\\
\vdots\\\
\frac{\partial v_n}{\partial w_0} \dots \frac{\partial v_{n}}{\partial w_d} \end{pmatrix}

There is a bigger problem: no matter which of the two we choose, the chain rule equation would not work. Remember that the RHS of the scalar chain-rule equation was the product \(\frac{\partial u}{\partial v}\frac{\partial v}{\partial w}\). In the matrix version, we have \(\frac{\partial u}{\partial \mathbf{v}}\), an \(n\times 1\) column matrix, as the first element. As we saw just now, the dimension of the second matrix may be either \((d+1) \times n\) or \(n \times (d+1)\). In both cases, we cannot multiply the first and second matrices. Are we stuck?

One thing we should realize at this point is that matrix calculus, unlike calculus proper, is not a fundamental branch of mathematics. It is just a shorthand for doing multiple calculus operations in a single shot. That is why we have different conventions for representing the results of operations.

With that in mind, note that even though we can’t multiply the matrices in the given order, they can be multiplied in the reverse order in one case. The product \(\frac{\partial \mathbf{v}}{\partial \mathbf{w}}\frac{\partial u}{\partial \mathbf{v}}\) is indeed valid when \(\frac{\partial \mathbf{v}}{\partial \mathbf{w}}\) has \((d+1)\) rows and \(n\) columns. In fact, this choice is just an extension of our convention for computing derivatives of scalars with respect to column matrices: in both cases, we devote one column to each numerator value. Furthermore, the result is a column matrix with \((d+1)\) rows, matching our expected output.

But we cannot create rules out of thin air just because the dimensions match. Maths does not work that way. Let us do a proper check.

Our hypothesis is that \(\frac{\partial u}{\partial \mathbf{w}} = \frac{\partial \mathbf{v}}{\partial \mathbf{w}}\frac{\partial u}{\partial \mathbf{v}}\).

The LHS is a column matrix consisting of \(u\)’s derivatives with respect to \(w_0, \dots, w_d\). In particular, consider \(\frac{\partial u}{\partial w_i}\): \[ \frac{\partial u}{\partial w_i} = \frac{\partial u}{\partial v_1}\frac{\partial v_1}{\partial w_i} + \frac{\partial u}{\partial v_2}\frac{\partial v_2}{\partial w_i} + \dots + \frac{\partial u}{\partial v_n}\frac{\partial v_n}{\partial w_i} \]

If we stack the above formula for \(i=0, 1, \dots, d\) in a column, we get our output. It would look as follows:

\begin{pmatrix} \frac{\partial u}{\partial v_1}\frac{\partial v_1}{\partial w_0} + \dots + \frac{\partial u}{\partial v_n}\frac{\partial v_n}{\partial w_0}\\\
\frac{\partial u}{\partial v_1}\frac{\partial v_1}{\partial w_1} + \dots + \frac{\partial u}{\partial v_n}\frac{\partial v_n}{\partial w_1}\\\
\vdots\\\
\frac{\partial u}{\partial v_1}\frac{\partial v_1}{\partial w_d} + \dots + \frac{\partial u}{\partial v_n}\frac{\partial v_n}{\partial w_d} \end{pmatrix}

It is easy to see that the above matrix is the same as the following product:

\begin{align} \begin{pmatrix} \frac{\partial v_1}{\partial w_0} & \frac{\partial v_2}{\partial w_0} & \dots & \frac{\partial v_n}{\partial w_0}\\\
\frac{\partial v_1}{\partial w_1} & \frac{\partial v_2}{\partial w_1} & \dots & \frac{\partial v_n}{\partial w_1}\\\
& \vdots\\\
\frac{\partial v_1}{\partial w_d} & \frac{\partial v_2}{\partial w_d} & \dots & \frac{\partial v_n}{\partial w_d}\\\
\end{pmatrix} \times \begin{pmatrix} \frac{\partial u}{\partial v_1}\\\
\frac{\partial u}{\partial v_2}\\\
\vdots\\\
\frac{\partial u}{\partial v_n}\\\
\end{pmatrix} \end{align}

This is exactly what we expected! So we can happily conclude that our hypothesis is valid.
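For the sceptical, here is a small numerical sketch of the hypothesis on random data of my own choosing. It builds \(\frac{\partial \mathbf{v}}{\partial \mathbf{w}}\) column by column with finite differences, uses \(\frac{\partial u}{\partial \mathbf{v}} = 2\mathbf{v}\) (which follows from the \(\mathbf{z}^\intercal\mathbf{z}\) example above), and compares their product with a direct numerical gradient of \(u\).

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 5, 2
X = np.hstack([np.ones((n, 1)), rng.normal(size=(n, d))])
y = rng.normal(size=(n, 1))
w = rng.normal(size=(d + 1, 1))

def v_of(w):
    return X @ w - y

def u_of(w):
    v = v_of(w)
    return (v.T @ v).item()

eps = 1e-6

# dv/dw laid out with one column per element of v: a (d+1) x n matrix.
dv_dw = np.zeros((d + 1, n))
for i in range(d + 1):
    dw = np.zeros((d + 1, 1))
    dw[i, 0] = eps
    dv_dw[i, :] = ((v_of(w + dw) - v_of(w - dw)) / (2 * eps)).ravel()

# du/dv = 2v, from the earlier z^T z example.
du_dv = 2 * v_of(w)

# Direct numerical gradient of u with respect to w.
du_dw = np.zeros((d + 1, 1))
for i in range(d + 1):
    dw = np.zeros((d + 1, 1))
    dw[i, 0] = eps
    du_dw[i, 0] = (u_of(w + dw) - u_of(w - dw)) / (2 * eps)

# The hypothesised matrix chain rule: du/dw = (dv/dw)(du/dv).
assert np.allclose(du_dw, dv_dw @ du_dv)
```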

Finding the Derivatives

Let us go on to compute the derivative of the loss function.

The second term in the chain rule expansion of \(\frac{\partial u}{\partial \mathbf{w}}\) is \(\frac{\partial u}{\partial \mathbf{v}}\) which is \(\frac{\partial \mathbf{v}^\intercal \mathbf{v}}{\partial \mathbf{v}}\). This is just like the example we did, and the derivative is \(2\mathbf{v}\), i.e., \(2(\mathbf{X}\mathbf{w}-\mathbf{y})\).

The first term is \(\frac{\partial \mathbf{v}}{\partial \mathbf{w}}\). Substituting for \(\mathbf{v}\), we get

\begin{align} \frac{\partial \mathbf{v}}{\partial \mathbf{w}} &= \frac{\partial (\mathbf{X}\mathbf{w}-\mathbf{y})}{\partial \mathbf{w}}\\\
&=\frac{\partial (\mathbf{X}\mathbf{w})}{\partial \mathbf{w}} \end{align}

In the last step, \(\mathbf{y}\) disappears because it is independent of \(\mathbf{w}\).

We now know how to compute this derivative. As per our convention, we take each element of \(\mathbf{X}\mathbf{w}\) and add to the output a column consisting of its derivatives with respect to each weight in \(\mathbf{w}\). The \(j^\text{th}\) element of \(\mathbf{X}\mathbf{w}\) is \(\mathbf{x}_j^\intercal\mathbf{w}\), and its derivatives with respect to \(w_0, \dots, w_d\) are just \(x_{j,0}, \dots, x_{j,d}\); so the \(j^\text{th}\) column of the output is \(\mathbf{x}_j\), and the derivative is \(\mathbf{X}^\intercal\).

Multiplying the two derivatives using our very own chain rule, we find the derivative of the loss function to be \[\mathbf{X}^\intercal \times 2 (\mathbf{X}\mathbf{w}-\mathbf{y}) = 2\mathbf{X}^\intercal\mathbf{X}\mathbf{w} - 2\mathbf{X}^\intercal\mathbf{y} \]
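A quick sanity check of this closed-form gradient against a numerical one, on random data that I made up for the purpose:

```python
import numpy as np

rng = np.random.default_rng(1)
n, d = 6, 3
X = np.hstack([np.ones((n, 1)), rng.normal(size=(n, d))])
y = rng.normal(size=(n, 1))
w = rng.normal(size=(d + 1, 1))

def loss(w):
    v = X @ w - y
    return (v.T @ v).item()

# Analytic gradient from the chain rule: 2 X^T X w - 2 X^T y.
grad = 2 * X.T @ X @ w - 2 * X.T @ y

# Numerical gradient, one coordinate at a time.
eps = 1e-6
numeric = np.zeros_like(w)
for i in range(d + 1):
    dw = np.zeros_like(w)
    dw[i, 0] = eps
    numeric[i, 0] = (loss(w + dw) - loss(w - dw)) / (2 * eps)

assert np.allclose(grad, numeric)
```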

Getting the Minimum

Setting this derivative to zero and solving for \(\mathbf{w}\) (assuming \(\mathbf{X}^\intercal \mathbf{X}\) is invertible), we finally get the Normal equation:

\[\mathbf{w} = (\mathbf{X}^\intercal \mathbf{X})^{-1}\mathbf{X}^\intercal \mathbf{y}\]
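In code this is essentially a one-liner. Below is a minimal NumPy sketch with synthetic data of my own; note that it is numerically preferable to solve the linear system \(\mathbf{X}^\intercal\mathbf{X}\mathbf{w} = \mathbf{X}^\intercal\mathbf{y}\) rather than form the inverse explicitly.

```python
import numpy as np

rng = np.random.default_rng(2)
n, d = 50, 3
X = np.hstack([np.ones((n, 1)), rng.normal(size=(n, d))])
true_w = np.array([[1.0], [2.0], [-1.0], [0.5]])
y = X @ true_w + 0.01 * rng.normal(size=(n, 1))

# Normal equation, solved as a linear system instead of forming the inverse.
w_hat = np.linalg.solve(X.T @ X, X.T @ y)

# np.linalg.lstsq minimises the same sum of squares, so the two should agree.
w_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)
assert np.allclose(w_hat, w_lstsq)
print(w_hat.ravel())
```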

Appendix

Just for completeness, I will now outline a Normal equation derivation that does not require the chain rule. Feel free to skip this section if you have already understood the method given above—unless you don’t want to miss out on some more matrix calculus.

OK, let us start by expanding out the loss function.

\begin{align} h(\mathbf{w}) &= (\mathbf{X}\mathbf{w} - \mathbf{y})^\intercal (\mathbf{X}\mathbf{w} - \mathbf{y})\\\
& = ((\mathbf{X}\mathbf{w})^\intercal - \mathbf{y}^\intercal) (\mathbf{X}\mathbf{w} - \mathbf{y})\\\
& = (\mathbf{w}^\intercal \mathbf{X}^\intercal - \mathbf{y}^\intercal) (\mathbf{X}\mathbf{w} - \mathbf{y})\\\
& = \mathbf{w}^\intercal \mathbf{X}^\intercal \mathbf{X}\mathbf{w} - \mathbf{y}^\intercal \mathbf{X}\mathbf{w} - \mathbf{w}^\intercal \mathbf{X}^\intercal \mathbf{y} + \mathbf{y}^\intercal \mathbf{y} \end{align}

To compute the partial derivatives of the second and third terms, let us first observe that \(\mathbf{X}^\intercal \mathbf{y}\) is a column matrix of size \((d+1) \times 1\). If we denote this product as \(\mathbf{s}\), the third term becomes \(\mathbf{w}^\intercal \mathbf{s}\). Interestingly, the second term is then \(\mathbf{s}^\intercal \mathbf{w}\), because \(\mathbf{y}^\intercal \mathbf{X}\mathbf{w} = (\mathbf{X}^\intercal \mathbf{y})^\intercal \mathbf{w}\). Both are \(1 \times 1\) matrices, i.e., scalars, and a scalar equals its own transpose; since \((\mathbf{s}^\intercal \mathbf{w})^\intercal = \mathbf{w}^\intercal \mathbf{s}\), the second and third terms are equal.

If \(s_0, \dots, s_d\) are the elements of the column matrix \(\mathbf{s}\), the second and third terms are both equal to \(s_0 w_0 + s_1 w_1 + \dots + s_d w_d\). Taking the derivative with respect to \(\mathbf{w}\), the result is simply \(\mathbf{s}\), i.e., \(\mathbf{X}^\intercal \mathbf{y}\).
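A tiny numerical confirmation of this fact, with arbitrary numbers standing in for \(\mathbf{s}\) and \(\mathbf{w}\):

```python
import numpy as np

# Arbitrary s and w; s plays the role of X^T y.
s = np.array([[2.0], [-1.0], [3.0]])
w = np.array([[0.5], [1.5], [-0.5]])

def f(w):
    return (s.T @ w).item()

# Numerical gradient of s^T w with respect to w.
eps = 1e-6
numeric = np.zeros_like(w)
for i in range(w.shape[0]):
    dw = np.zeros_like(w)
    dw[i, 0] = eps
    numeric[i, 0] = (f(w + dw) - f(w - dw)) / (2 * eps)

# The gradient is simply s.
assert np.allclose(numeric, s)
```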

The First Term

Computing the derivative of the first term is a bit more involved. Observe that \(\mathbf{X}^\intercal \mathbf{X}\) is a square matrix of size \((d+1) \times (d+1)\). Let us denote this product as \(\mathbf{B}\) for convenience. Now we will compute the derivative of \(\mathbf{w}^\intercal \mathbf{B} \mathbf{w}\) wrt \(\mathbf{w}\).

If \(\mathbf{b}_i\) stands for the \(i^\text{th}\) row of the matrix \(\mathbf{B}\), written as a column vector, then \(\mathbf{B} \mathbf{w}\) is given by \[ \begin{pmatrix} \mathbf{b}_0^\intercal \mathbf{w}\\\
\mathbf{b}_1^\intercal \mathbf{w}\\\
\vdots\\\
\mathbf{b}_d^\intercal \mathbf{w} \end{pmatrix} \]

And hence \(\mathbf{w}^\intercal \mathbf{B} \mathbf{w}\) is \[ w_0 (\mathbf{b}_0^\intercal \mathbf{w}) + w_1 (\mathbf{b}_1^\intercal \mathbf{w}) + \dots + w_d (\mathbf{b}_d^\intercal \mathbf{w}) \]

We will use \(s\) to denote the above sum (not to be confused with the column matrix \(\mathbf{s}\) from before). Let us consider its partial derivative with respect to a single weight \(w_i\), i.e., \(\frac{\partial s}{\partial w_i}\).

We can see that for \(j\not=i\),

\begin{align} \frac{\partial (w_j (\mathbf{b}_{j}^\intercal \mathbf{w}))}{\partial w_i} &= \frac{\partial}{\partial w_i} (w_j (b_{j,0} w_0 + \dots + b_{j,i} w_i + \dots + b_{j,d} w_d))\\\
&= w_j b_{j,i} \end{align}

But for the \(i^\text{th}\) term, we have

\begin{align} \frac{\partial (w_i (\mathbf{b}_i^\intercal \mathbf{w}))}{\partial w_i} &= \frac{\partial}{\partial w_i} (w_i (b_{i, 0} w_0 + \dots + b_{i, i} w_i + \dots + b_{i, d} w_d))\\\
&= \frac{\partial}{\partial w_i} (b_{i, 0} w_i w_0 + \dots + b_{i, i} w_i^2 + \dots + b_{i, d} w_i w_d)\\\
&= b_{i, 0} w_0 + \dots + 2 b_{i, i} w_i + \dots + b_{i, d} w_d \end{align}

Adding these together, we get

\begin{align} \frac{\partial s}{\partial w_i} &= w_0 (b_{0, i} + b_{i, 0}) + \dots + 2w_i b_{i, i} + \dots + w_d (b_{d, i} + b_{i, d}) \end{align}

This can be decomposed into a matrix product as below:

\begin{align} \frac{\partial s}{\partial w_i} &= \begin{pmatrix} b_{0, i} + b_{i, 0} & \dots & 2b_{i, i} & \dots & b_{d, i} + b_{i, d} \end{pmatrix} \begin{pmatrix} w_0\\\
\vdots\\\
w_d \end{pmatrix}\\\
&= \left( \begin{pmatrix} b_{0, i} & \dots & b_{i, i} & \dots & b_{d, i} \end{pmatrix} + \begin{pmatrix} b_{i, 0} & \dots & b_{i, i} & \dots & b_{i, d} \end{pmatrix} \right) \begin{pmatrix} w_0\\\
\vdots\\\
w_d \end{pmatrix} \end{align}

In other words, \(\frac{\partial s}{\partial w_i}\) is simply the sum of the \(i^\text{th}\) row and the \(i^\text{th}\) column of \(\mathbf{B}\), multiplied with \(\mathbf{w}\). But the \(i^\text{th}\) column of \(\mathbf{B}\) is the \(i^\text{th}\) row of \(\mathbf{B}^\intercal\). Therefore, the full partial derivative is \[ \frac{\partial s}{\partial \mathbf{w}} = (\mathbf{B} + \mathbf{B}^\intercal) \mathbf{w} \]

Now we can substitute back \(\mathbf{B} = \mathbf{X}^\intercal \mathbf{X}\), and get

\begin{align} \frac{\partial s}{\partial \mathbf{w}} &= (\mathbf{X}^\intercal \mathbf{X} + (\mathbf{X}^\intercal \mathbf{X})^\intercal) \mathbf{w}\\\
&=(\mathbf{X}^\intercal \mathbf{X} + \mathbf{X}^\intercal \mathbf{X}) \mathbf{w}\\\
&=2\mathbf{X}^\intercal \mathbf{X} \mathbf{w} \end{align}
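And the same kind of finite-difference check for the quadratic term, using an arbitrary \(\mathbf{X}\) that I generated just for this sketch:

```python
import numpy as np

rng = np.random.default_rng(3)
n, d = 7, 2
X = rng.normal(size=(n, d + 1))
B = X.T @ X                       # the (d+1) x (d+1) matrix B = X^T X
w = rng.normal(size=(d + 1, 1))

def quad(w):
    # The scalar w^T B w.
    return (w.T @ B @ w).item()

# Numerical gradient of the quadratic form.
eps = 1e-6
numeric = np.zeros_like(w)
for i in range(d + 1):
    dw = np.zeros_like(w)
    dw[i, 0] = eps
    numeric[i, 0] = (quad(w + dw) - quad(w - dw)) / (2 * eps)

# (B + B^T) w, which here equals 2 X^T X w because B is symmetric.
assert np.allclose(numeric, (B + B.T) @ w)
assert np.allclose(numeric, 2 * X.T @ X @ w)
```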

Adding Them Up

Now that we have the derivatives of all the terms, we can simply combine them to get the full derivative. (The last term of the expansion, \(\mathbf{y}^\intercal \mathbf{y}\), does not depend on \(\mathbf{w}\), so its derivative is zero and it drops out.)

\begin{align} \frac{\partial h(\mathbf{w})}{\partial \mathbf{w}} &= \frac{\partial}{\partial \mathbf{w}} (\mathbf{w}^\intercal \mathbf{X}^\intercal \mathbf{X}\mathbf{w}) - \frac{\partial}{\partial \mathbf{w}} (\mathbf{y}^\intercal \mathbf{X}\mathbf{w}) - \frac{\partial}{\partial \mathbf{w}}(\mathbf{w}^\intercal \mathbf{X}^\intercal \mathbf{y})\\\
&=2\mathbf{X}^\intercal\mathbf{X}\mathbf{w} - \mathbf{X}^\intercal\mathbf{y} - \mathbf{X}^\intercal\mathbf{y}\\\
&=2\mathbf{X}^\intercal\mathbf{X}\mathbf{w} - 2\mathbf{X}^\intercal\mathbf{y} \end{align}

This agrees with the derivative we got using the chain rule. So we will stop here.


  1. Fittingly, linear functions are an example of functions that have neither a minimum nor a maximum. ↩︎

  2. It is also possible that the extremum of a quadratic function is a global maximum. Think about \(y=-x^2\). But in that case, we can simply multiply the function by \(-1\) to get one with a global minimum. ↩︎