Would Strassen matmuls be useful in AI if data movement were free?
Strassen’s algorithm for matrix multiplication uses asymptotically fewer arithmetic operations than traditional matrix multiplication, yet it is never used in large language models or other AI systems. Why not?
The standard answer: Strassen’s algorithm involves complicated data movement, whose costs often exceed the savings in multiply-add operations. While true, I find this answer unsatisfying: with clever software and hardware design, one can often reduce data movement costs. If someone solved that problem, would Strassen’s algorithm then be worthwhile?
To explore this possibility, we’ll analyze Strassen’s algorithm under the optimistic assumption that data movement is free. We find that other, more fundamental barriers remain.
What is Strassen’s algorithm?
The distributive law, $(a + b)(c + d) = ac + ad + bc + bd$, has one multiplication on the left hand side that “does the work of” four multiplications on the right hand side. Strassen’s algorithm exploits this property to reduce the number of multiplications required, by adding terms together before multiplying. On a $2 \times 2$ matrix multiply, this trick reduces the multiplication count from 8 to 7, and the savings grow as matrices get larger.
The algebra for Strassen’s algorithm on a $2 \times 2$ matrix multiply is shown below. The specifics aren’t crucial; just note the general pattern of “add terms, multiply, then subtract to recover the result”. To compute

$$\begin{pmatrix} c_{11} & c_{12} \\ c_{21} & c_{22} \end{pmatrix} = \begin{pmatrix} a_{11} & a_{12} \\ a_{21} & a_{22} \end{pmatrix} \begin{pmatrix} b_{11} & b_{12} \\ b_{21} & b_{22} \end{pmatrix}$$

it performs the following calculations¹:

$$\begin{aligned}
m_1 &= (a_{11} + a_{22})(b_{11} + b_{22}) \\
m_2 &= (a_{21} + a_{22})\,b_{11} \\
m_3 &= a_{11}(b_{12} - b_{22}) \\
m_4 &= a_{22}(b_{21} - b_{11}) \\
m_5 &= (a_{11} + a_{12})\,b_{22} \\
m_6 &= (a_{21} - a_{11})(b_{11} + b_{12}) \\
m_7 &= (a_{12} - a_{22})(b_{21} + b_{22}) \\[4pt]
c_{11} &= m_1 + m_4 - m_5 + m_7 \\
c_{12} &= m_3 + m_5 \\
c_{21} &= m_2 + m_4 \\
c_{22} &= m_1 - m_2 + m_3 + m_6
\end{aligned}$$
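To make the pattern concrete, here is a minimal NumPy sketch of one level of this recursion, applied to the four sub-blocks of a matrix (the function name and test values are my own):

```python
import numpy as np

def strassen_one_level(A, B):
    """Multiply A @ B using 7 block multiplies instead of 8.
    A and B must be square with even side length."""
    n = A.shape[0] // 2
    a11, a12, a21, a22 = A[:n, :n], A[:n, n:], A[n:, :n], A[n:, n:]
    b11, b12, b21, b22 = B[:n, :n], B[:n, n:], B[n:, :n], B[n:, n:]

    # The seven multiplications, each preceded by block additions.
    m1 = (a11 + a22) @ (b11 + b22)
    m2 = (a21 + a22) @ b11
    m3 = a11 @ (b12 - b22)
    m4 = a22 @ (b21 - b11)
    m5 = (a11 + a12) @ b22
    m6 = (a21 - a11) @ (b11 + b12)
    m7 = (a12 - a22) @ (b21 + b22)

    # Recover the result using only additions and subtractions.
    return np.block([
        [m1 + m4 - m5 + m7, m3 + m5],
        [m2 + m4, m1 - m2 + m3 + m6],
    ])

A, B = np.random.randn(4, 4), np.random.randn(4, 4)
assert np.allclose(strassen_one_level(A, B), A @ B)
```

The full algorithm applies this decomposition recursively to each of the seven block multiplies.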
There is a genuine reduction in multiplication count. This comes at the cost of more additions, but those aren’t a major concern, since additions are far cheaper than multiplications². Strassen’s algorithm also has an irregular structure and less parallelism than traditional matrix multiplication, which can hurt real-world performance; this is what people mean by “data movement overheads”.
But supposing these could be engineered away, what problems remain? The critical one is numerical precision.
Numerical instability and how to compensate for it
Strassen’s algorithm has weaker numerical stability properties than traditional matrix multiplication. Consider how we calculate $c_{12}$.

In Strassen’s algorithm:

$$c_{12} = m_3 + m_5 = a_{11}(b_{12} - b_{22}) + (a_{11} + a_{12})\,b_{22}$$

In traditional matrix multiplication:

$$c_{12} = a_{11} b_{12} + a_{12} b_{22}$$

The extra operations in Strassen’s version, the subtraction $b_{12} - b_{22}$ and the addition $a_{11} + a_{12}$, introduce additional floating point round-off errors. Since the full algorithm applies these formulas recursively, the errors compound at scale.
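Here is a small illustration of the effect in float32, with hypothetical values chosen so that Strassen’s intermediate products are large and nearly cancel:

```python
import numpy as np

# Compute c12 both ways in float32 and compare to a float64 reference.
# Values chosen (illustratively) so that m3 and m5 are ~1e7 in magnitude
# while the true result is ~10, forcing a large cancellation.
a11, a12 = np.float32(1234.5678), np.float32(1e-3)
b12, b22 = np.float32(1e-3), np.float32(8765.4321)

exact = float(a11) * float(b12) + float(a12) * float(b22)  # float64

traditional = a11 * b12 + a12 * b22               # direct c12
strassen = a11 * (b12 - b22) + (a11 + a12) * b22  # c12 = m3 + m5

print(abs(float(traditional) - exact))  # tiny, around 1e-6
print(abs(float(strassen) - exact))     # far larger, order 1, from cancellation
```

One level of this is tolerable; recursion stacks these cancellations on top of each other.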
To summarize: Strassen’s algorithm is computationally cheaper but numerically less precise than traditional matrix multiplication. Since Strassen’s algorithm is better in one way and worse in another, it’s hard to directly compare. To make a fair comparison, we’ll control for precision: we’ll run both algorithms at the same target precision and see which is faster.
Running Strassen’s algorithm at higher precision
The overall intuition is: when Strassen adds matrices before multiplying them, those sums need extra bits to avoid overflow. If $a$ and $b$ each fit in 8 bits, $a + b$ needs 9 bits. This means our “cheaper” multiplications are actually multiplying larger numbers. Since the hardware cost of multiplication grows quadratically with bit width, this effect compounds quickly.
To quantify this, let’s assume we have a free choice of number formats; perhaps we’re designing custom hardware. We’ll also simplify by analyzing integers rather than floating point numbers: while neural networks use floating point in practice, the error and cost analysis is cleaner with integers. To first approximation, the hardware cost of multiplying a $j$-bit integer by a $k$-bit integer is proportional to $j \times k$.
Suppose the input matrices $A$ and $B$ contain $b$-bit integers. In traditional matrix multiplication, all multiplies are between two $b$-bit integers, each costing $b^2$ units of hardware. In Strassen’s algorithm, whenever we compute a sum like $a_{11} + a_{12}$ before multiplication, we need a $(b+1)$-bit integer to store the result without loss of precision. This gives us multiplies like:

$$m_5 = \underbrace{(a_{11} + a_{12})}_{b+1 \text{ bits}} \times \underbrace{b_{22}}_{b \text{ bits}}, \quad \text{costing } b(b+1),$$

or:

$$m_1 = \underbrace{(a_{11} + a_{22})}_{b+1 \text{ bits}} \times \underbrace{(b_{11} + b_{22})}_{b+1 \text{ bits}}, \quad \text{costing } (b+1)^2.$$

For a $2 \times 2$ matrix multiply with $b$-bit inputs, the total multiplier cost is:

$$\text{Traditional: } 8b^2 \qquad\qquad \text{Strassen: } 4b(b+1) + 3(b+1)^2 = 7b^2 + 10b + 3$$
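In code, this cost model is just two small functions (a sketch under the $j \times k$ cost assumption above; the names are my own):

```python
def traditional_cost(b):
    # 8 multiplies of b-bit by b-bit operands.
    return 8 * b * b

def strassen_cost(b):
    # m2, m3, m4, m5: one b-bit operand times one (b+1)-bit sum.
    # m1, m6, m7: two (b+1)-bit sums multiplied together.
    return 4 * b * (b + 1) + 3 * (b + 1) ** 2  # = 7b^2 + 10b + 3
```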
At what precisions does Strassen’s algorithm win?
Since Strassen has a smaller quadratic coefficient but a larger linear coefficient, it will be more expensive at low precisions but cheaper at high precisions. Here’s the crossover: setting $7b^2 + 10b + 3 = 8b^2$ and solving gives $b = 5 + 2\sqrt{7} \approx 10.3$ bits.
Strassen becomes cost-effective only at 11 bits and above. As the precision grows toward 32 bits and beyond, it approaches its asymptotic advantage of $8/7 \approx 1.14$, roughly a 14% speedup.
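Evaluating the two cost polynomials at a few bit widths reproduces this picture:

```python
# Speedup of one Strassen level over traditional matmul, by bit width.
for b in (4, 8, 11, 16, 32, 64):
    traditional = 8 * b * b
    strassen = 7 * b * b + 10 * b + 3
    print(f"{b:2d} bits: {traditional / strassen:.3f}x")
# -> 0.826x, 0.964x, 1.008x, 1.048x, 1.094x, 1.118x
```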
Remember, this analysis covers just one recursive level of Strassen’s algorithm. The full algorithm applies this trick $\log_2 n$ times for an $n \times n$ matrix multiply. The speedup curves become more pronounced with recursion: greater slowdowns at low precision, greater speedups at high precision.
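Under a deliberately crude model of the recursion (my assumption: every level widens all operands by one bit and turns one multiply into seven), the compounding looks like this:

```python
def strassen_deep_cost(b, levels):
    # Crude model: each recursion level adds one bit to every operand
    # and replaces one multiply with seven smaller ones.
    if levels == 0:
        return b * b
    return 7 * strassen_deep_cost(b + 1, levels - 1)

for b in (8, 16, 32):
    traditional = 8 ** 5 * b * b  # 5 levels of traditional recursion
    print(f"{b} bits, 5 levels: {traditional / strassen_deep_cost(b, 5):.2f}x")
# -> roughly 0.74x at 8 bits, 1.13x at 16 bits, 1.46x at 32 bits
```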
What precisions do AI models run at?
Neural networks have steadily moved to lower precisions: from float32 to bfloat16 to fp8, with fp4 on the horizon. Even bfloat16, the workhorse format for many years, has only 8 bits of significand precision. For all these formats except float32, Strassen’s algorithm is a net loss once we account for precision.
This reveals the fundamental reason Strassen’s algorithm isn’t used in neural networks:
At low baseline precision, Strassen’s algorithm is an inefficient way to trade precision for speed. It is more efficient to trade precision for speed directly, by using a smaller number format.

¹ From Wikipedia. See there for more details.

² In well-tuned Strassen implementations, the multiplications are multiplications of matrix sub-blocks, not scalar values. Matrix multiplications are vastly more expensive than matrix additions.