Would Strassen matmuls be useful in AI if data movement were free?

Strassen’s algorithm multiplies n\times n matrices in O(n^{2.81}) operations, asymptotically far fewer than the O(n^3) of traditional matrix multiplication, yet it is never used in large language models or other AI systems. Why not?

The standard answer: Strassen’s algorithm involves complicated data movement, whose costs often exceed the savings in multiply-add operations. While true, I find this answer unsatisfying: with clever software and hardware design, one can often reduce data movement costs. If someone solved that problem, would Strassen’s algorithm then be worthwhile?

To explore this possibility, we’ll analyze Strassen’s algorithm under the optimistic assumption that data movement is free. We find that other, more fundamental barriers remain.

What is Strassen’s algorithm?

The distributive law, (a+b)(c+d) = ac+ad+bc+bd, has one multiplication on the left-hand side that “does the work of” four multiplications on the right-hand side. Strassen’s algorithm exploits this property to reduce the number of multiplications required, by adding terms together before multiplying. On a 2\times2\times2 matrix multiply, this trick reduces the multiplication count from 8 to 7, and the savings grow as matrices get larger.

The algebra for Strassen’s algorithm on a 2\times2\times2 matrix multiply is shown below. The specifics aren’t crucial—just note the general pattern of “add terms, multiply, then subtract to recover the result”. To compute

\begin{pmatrix}C_{11} & C_{12} \\ C_{21} & C_{22}\end{pmatrix} = \begin{pmatrix}A_{11} & A_{12} \\ A_{21} & A_{22}\end{pmatrix}\begin{pmatrix}B_{11} & B_{12} \\ B_{21} & B_{22}\end{pmatrix}

it performs the following calculations¹:

\begin{align*}
M_1 &= (A_{11} + A_{22})(B_{11} + B_{22}) \\
M_2 &= (A_{21} + A_{22})B_{11} \\
M_3 &= A_{11}(B_{12} - B_{22}) \\
M_4 &= A_{22}(B_{21} - B_{11}) \\
M_5 &= (A_{11} + A_{12})B_{22} \\
M_6 &= (A_{21} - A_{11})(B_{11} + B_{12}) \\
M_7 &= (A_{12} - A_{22})(B_{21} + B_{22}) \\
C_{11} &= M_1 + M_4 - M_5 + M_7 \\
C_{12} &= M_3 + M_5 \\
C_{21} &= M_2 + M_4 \\
C_{22} &= M_1 - M_2 + M_3 + M_6
\end{align*}
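To make the pattern concrete, here is a minimal sketch of one level of Strassen’s algorithm in NumPy (the function name strassen_2x2 and the block-slicing details are my own; the seven products follow the formulas above):

```python
import numpy as np

def strassen_2x2(A, B):
    """One level of Strassen: split A and B into 2x2 blocks, form the seven
    products M1..M7, then recombine them into C."""
    h = A.shape[0] // 2
    A11, A12, A21, A22 = A[:h, :h], A[:h, h:], A[h:, :h], A[h:, h:]
    B11, B12, B21, B22 = B[:h, :h], B[:h, h:], B[h:, :h], B[h:, h:]

    M1 = (A11 + A22) @ (B11 + B22)
    M2 = (A21 + A22) @ B11
    M3 = A11 @ (B12 - B22)
    M4 = A22 @ (B21 - B11)
    M5 = (A11 + A12) @ B22
    M6 = (A21 - A11) @ (B11 + B12)
    M7 = (A12 - A22) @ (B21 + B22)

    # 7 block multiplies instead of 8; the extra adds and subtracts are cheap.
    return np.block([[M1 + M4 - M5 + M7, M3 + M5],
                     [M2 + M4, M1 - M2 + M3 + M6]])

rng = np.random.default_rng(0)
A, B = rng.standard_normal((2, 64, 64))
print(np.allclose(strassen_2x2(A, B), A @ B))  # True
```

A full implementation would recurse on the seven sub-multiplies; here they are ordinary matmuls, which is enough to check the algebra.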

There is a genuine reduction in multiplication count. This comes at the cost of more additions, but those aren’t a major concern since additions are far cheaper than multiplications². Strassen’s algorithm also has irregular structure and less parallelism compared to traditional matrix multiplication, which can hurt real-world performance; this is what people mean by “data movement overheads”.

But supposing these could be engineered away, what problems remain? The critical one is numerical precision.

Numerical instability and how to compensate for it

Strassen’s algorithm has weaker numerical stability properties than traditional matrix multiplication. Consider how we calculate C_{12}.

In Strassen’s algorithm:

C_{12} = A_{11}(B_{12} - B_{22}) + (A_{11} + A_{12})B_{22}

In traditional matrix multiplication:

C_{12} = A_{11}B_{12} + A_{12}B_{22}.

The extra operations in Strassen’s version—B_{12}-B_{22} and A_{11}+A_{12}—introduce additional floating point round-off errors. Since the full algorithm applies these formulas recursively, the errors compound at scale.
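Here is a small experiment of my own construction, using scalars as stand-ins for the matrix blocks: it evaluates both C_{12} formulas with every intermediate rounded to float16, then compares them against an exact float64 reference. On random inputs the Strassen-style formula typically shows the larger average error, thanks to the extra roundings and the cancellation between M_3 and M_5:

```python
import numpy as np

f16 = np.float16
rng = np.random.default_rng(0)

def c12_traditional(a11, a12, b12, b22):
    # C12 = A11*B12 + A12*B22, rounding every intermediate to float16
    return f16(f16(f16(a11) * f16(b12)) + f16(f16(a12) * f16(b22)))

def c12_strassen(a11, a12, b12, b22):
    # C12 = M3 + M5 = A11*(B12 - B22) + (A11 + A12)*B22
    m3 = f16(f16(a11) * f16(f16(b12) - f16(b22)))
    m5 = f16(f16(f16(a11) + f16(a12)) * f16(b22))
    return f16(m3 + m5)

xs = rng.standard_normal((10_000, 4))          # rows of (a11, a12, b12, b22)
exact = xs[:, 0] * xs[:, 2] + xs[:, 1] * xs[:, 3]
trad = np.array([c12_traditional(*x) for x in xs], dtype=np.float64)
strs = np.array([c12_strassen(*x) for x in xs], dtype=np.float64)
print("traditional:", np.abs(trad - exact).mean())
print("strassen:   ", np.abs(strs - exact).mean())
```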

To summarize: Strassen’s algorithm is computationally cheaper but numerically less precise than traditional matrix multiplication. Since Strassen’s algorithm is better in one way and worse in another, it’s hard to compare the two directly. To make a fair comparison, we’ll control for precision: we’ll run both algorithms at the same target precision and see which is faster.

Running Strassen’s algorithm at higher precision

The overall intuition is: when Strassen adds matrices before multiplying them, those sums need extra bits to avoid overflow. If A and B each fit in 8 bits, A+B needs 9 bits. This means our “cheaper” multiplications are actually multiplying larger numbers. Since the hardware cost of multiplication grows quadratically with bit width, this effect compounds quickly.
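As a quick sanity check of the bit-growth claim (plain Python, unsigned framing):

```python
a = b = 255                                  # largest values that fit in 8 bits
print(a.bit_length(), (a + b).bit_length())  # 8 9 -> the sum needs a 9th bit
# Products of the widened operands need more bits too: 16 vs 18 here.
print((a * b).bit_length(), ((a + b) * (a + b)).bit_length())
```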

To quantify this, let’s assume we have a free choice of number formats; perhaps we’re designing custom hardware. We’ll also simplify by analyzing integers rather than floating point numbers: while neural networks use floating point in practice, the error and cost analysis is cleaner with integers. To first approximation, the hardware cost of multiplying a b-bit integer by a c-bit integer is proportional to b\times c.

Suppose the input matrices A and B contain b-bit integers. In traditional matrix multiplication, all multiplies A_{ij}B_{jk} are between two b-bit integers, each costing b^2 units of hardware. In Strassen’s algorithm, whenever we compute a sum like (A_{11} + A_{12}) before multiplication, we need a (b+1)-bit integer to store the result without loss of precision. This gives us multiplies like:

M_1 = \underbrace{(A_{11} + A_{22})}_{(b+1)\text{ bits}}\times \underbrace{(B_{11} + B_{22})}_{(b+1)\text{ bits}}

or:

M_2 = \underbrace{(A_{21} + A_{22})}_{(b+1)\text{ bits}}\times\underbrace{B_{11}}_{b\text{ bits}}

For a 2\times2\times2 matrix multiply with b-bit inputs, three of the seven products (M_1, M_6, M_7) multiply two (b+1)-bit values, while the other four multiply a (b+1)-bit value by a b-bit value, giving 3(b+1)^2 + 4b(b+1). The total multiplier cost is:

\begin{gather*} \text{Strassen: }(7b^2+10b+3)\text{ hardware cost} \\ \text{Traditional: }8b^2\text{ hardware cost} \end{gather*}

At what precisions does Strassen’s algorithm win?

Since Strassen has a smaller quadratic coefficient but a larger linear coefficient, it will be more expensive at low precisions but cheaper at high precisions. Here’s the crossover:

[Figure: hardware cost of multipliers vs. input bit width, for Strassen and traditional matrix multiplication.]

Strassen becomes cost-effective only at 11 bits and above. As the bit width grows, the speedup climbs toward its asymptotic advantage of 8/7\approx1.14, a roughly 14% speedup:

[Figure: Strassen’s speedup over traditional matrix multiplication vs. input bit width, rising toward the asymptotic speedup.]
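Both plots follow directly from the two cost formulas; here is a minimal script to reproduce the numbers (the helper names are mine):

```python
def strassen_cost(b: int) -> int:
    # M1, M6, M7 multiply two (b+1)-bit values; M2..M5 multiply a
    # (b+1)-bit value by a b-bit value: 3(b+1)^2 + 4b(b+1) = 7b^2 + 10b + 3.
    return 3 * (b + 1) ** 2 + 4 * b * (b + 1)

def traditional_cost(b: int) -> int:
    return 8 * b * b  # eight b-bit by b-bit products

for b in (4, 8, 10, 11, 16, 32, 40):
    ratio = traditional_cost(b) / strassen_cost(b)
    print(f"b={b:2d}  speedup={ratio:.3f}")
# The ratio first exceeds 1.0 at b = 11 and climbs toward 8/7 ≈ 1.143.
```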

Remember, this analysis covers just one recursive level of Strassen’s algorithm. The full algorithm applies this trick \log_2(n) times for an n\times n\times n matrix multiply. The speedup curves become more pronounced with recursion: greater slowdowns at low precision, greater speedups at high precision.
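To get a feel for the recursive case, here is a deliberately crude model of my own (not a careful analysis): assume every operand has grown by one bit per level, so after k levels Strassen performs 7^k leaf multiplies of (b+k)-bit values while the traditional algorithm performs 8^k leaf multiplies of b-bit values. This slightly overstates Strassen’s cost, since some operands (e.g. the bare B_{11} in M_2) grow more slowly:

```python
def speedup(b: int, k: int) -> float:
    """Traditional cost / Strassen cost after k recursion levels, under the
    crude assumption that every Strassen operand is (b + k) bits wide."""
    return (8 ** k * b ** 2) / (7 ** k * (b + k) ** 2)

for b in (4, 8, 16, 32):
    print(f"b={b:2d}:", "  ".join(f"k={k}: {speedup(b, k):.2f}" for k in (1, 2, 4)))
```

Even under this rough model the single-level pattern sharpens: low-precision inputs lose more with each level, while high-precision inputs gain more.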

What precisions do AI models run at?

Neural networks have steadily moved to lower precisions: from float32 to bfloat16 to fp8, with fp4 on the horizon. Even bfloat16, the workhorse format for many years, has only 8 mantissa bits. For all these formats except float32, Strassen’s algorithm is a net loss once we account for precision.

This reveals the fundamental reason Strassen’s algorithm isn’t used in neural networks:

At low baseline precision, Strassen’s algorithm is an inefficient way to trade precision for speed. It is more efficient to directly trade precision for speed using a smaller number format.


  1. From Wikipedia. See there for more details.

  2. In well-tuned Strassen implementations, the multiplications are multiplications of matrix sub-blocks, not scalar values. Matrix multiplications are vastly more expensive than matrix additions.