Matrix multiplication is one of the most fundamental operations in machine learning. In PyTorch, you’ll often see three different ways to do it:
- `torch.matmul`
- `torch.mm`
- `@` (the Python operator)
At first glance, they look interchangeable — but they aren’t. PyTorch makes deliberate choices about when and how these operators work, especially for 1-D tensors and batched operations. Let’s break it all down.
## 1. `torch.mm`: Strict Matrix × Matrix

`torch.mm` is the most “traditional” operator.
- Inputs: Must be exactly 2-D tensors.
- Behavior: Standard matrix multiplication, `(n,m) @ (m,p) → (n,p)`.
- Limitations: No 1-D support, no batching, no broadcasting.
```python
import torch

a = torch.randn(2, 3)
b = torch.randn(3, 4)
print(torch.mm(a, b).shape)  # torch.Size([2, 4])
```
❌ If you try with a vector:
```python
x = torch.tensor([1, 2, 3])
torch.mm(x, x)  # RuntimeError
```
So: `torch.mm` is strict and safe, but less flexible.
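If you really do need `torch.mm` with a vector, you can reshape it to an explicit 2-D tensor first. A minimal sketch (the `unsqueeze` pattern is just one way to do this):

```python
import torch

x = torch.tensor([1., 2., 3.])
row = x.unsqueeze(0)       # shape (1, 3)
col = x.unsqueeze(1)       # shape (3, 1)
print(torch.mm(row, col))  # tensor([[14.]]), a 1x1 matrix rather than a scalar
```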
## 2. `torch.matmul`: The General Workhorse

`torch.matmul` is more powerful and flexible.
- Inputs: 1-D, 2-D, or N-D tensors.
- Special cases for 1-D:
  - `(n,) @ (n,)` → scalar (dot product).
  - `(n,) @ (n,m)` → `(m,)`.
  - `(n,m) @ (m,)` → `(n,)`.
- Batch support: `(b,n,m) @ (b,m,p) → (b,n,p)` (very useful in deep learning).
- Broadcasting: Works across leading dimensions.
Example: dot product
```python
a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 3])
print(torch.matmul(a, b))  # tensor(14)
```
Example: batch multiplication
```python
A = torch.randn(10, 3, 4)  # batch of 10 matrices
B = torch.randn(10, 4, 5)
print(torch.matmul(A, B).shape)  # torch.Size([10, 3, 5])
```
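Example: broadcasting. The 2-D operand below has no batch dimension, so it is broadcast against every matrix in the batch, matching the broadcasting rule mentioned above:

```python
A = torch.randn(10, 3, 4)  # batch of 10 matrices
W = torch.randn(4, 5)      # a single matrix, no batch dimension
print(torch.matmul(A, W).shape)  # torch.Size([10, 3, 5])
```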
✅ Use `torch.matmul` when you want flexibility.
## 3. `@`: The Shorthand

Python’s `@` operator is mapped to `torch.matmul`.
- Everything said about `matmul` applies here.
- Cleaner syntax, widely used in PyTorch codebases.
```python
a = torch.randn(2, 3)
b = torch.randn(3, 4)
print((a @ b).shape)  # torch.Size([2, 4])
```
So: `@` = `torch.matmul`.
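You can check the equivalence directly: for tensors, `a @ b` dispatches to the same operation as `torch.matmul(a, b)`, so the results are identical:

```python
a = torch.randn(2, 3)
b = torch.randn(3, 4)
print(torch.equal(a @ b, torch.matmul(a, b)))  # True
```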
## 4. Special 1-D Tensor Rules
This is where beginners get tripped up. Let’s see what happens when we mix 1-D and 2-D tensors.
| Left shape | Right shape | Result | Notes |
|---|---|---|---|
| `(n,)` | `(n,)` | scalar, shape `()` | dot product |
| `(n,)` | `(n, m)` | `(m,)` | row vector × matrix |
| `(n, m)` | `(m,)` | `(n,)` | matrix × column vector |
| `(n, m)` | `(n,)` | only valid if `n == m` | else shape mismatch |
| `(n, m)` | `(m, p)` | `(n, p)` | standard matrix multiply |
Examples:
```python
# Dot product
a = torch.tensor([1, 2, 3])
print(a @ a)          # tensor(14)
print((a @ a).shape)  # torch.Size([]), a 0-d tensor

# Vector × Matrix (float dtypes, so the operands match torch.randn)
a = torch.tensor([1., 2., 3.])
b = torch.randn(3, 2)
print((a @ b).shape)  # torch.Size([2])

# Matrix × Vector
a = torch.randn(2, 3)
b = torch.tensor([1., 2., 3.])
print((a @ b).shape)  # torch.Size([2])
```
❌ Invalid example:
```python
a = torch.randn(2, 3)
b = torch.tensor([1., 2.])  # mismatched inner dim
a @ b  # RuntimeError
```
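When you hit an error like this, printing shapes is the quickest diagnosis. A minimal sketch:

```python
a = torch.randn(2, 3)
b = torch.tensor([1., 2.])
print(a.shape, b.shape)  # torch.Size([2, 3]) torch.Size([2])
# Inner dimensions (3 vs. 2) do not match; a length-3 vector works:
b = torch.tensor([1., 2., 3.])
print((a @ b).shape)     # torch.Size([2])
```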
## 5. Quick Comparison
| Operator | Allowed Inputs | Special 1-D Rules | Batch Matmul | Recommended Usage |
|---|---|---|---|---|
| `torch.mm` | 2-D only | ❌ no | ❌ no | Strict 2-D matrix multiply |
| `torch.matmul` | 1-D, 2-D, N-D | ✅ yes | ✅ yes | General-purpose multiply |
| `@` | 1-D, 2-D, N-D | ✅ yes | ✅ yes | Shorthand for readability |
## 6. Practical Tips
- Use `*` for elementwise multiplication.
- Use `@` (or `matmul`) for linear algebra multiplication.
- If you need strict 2-D only, stick with `mm`.
- When debugging, always check shapes with `.shape`.
- For deep learning models (e.g. attention layers), batched matmul is essential; use `@` (see the sketch below).
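Here is a minimal sketch of that last tip, computing attention-style scores with a batched `@` (the shapes are hypothetical, chosen only for illustration):

```python
import torch

Q = torch.randn(8, 16, 64)  # batch of 8 sequences, 16 tokens, 64-dim queries
K = torch.randn(8, 16, 64)  # matching keys

# Batched matmul over the leading batch dimension:
scores = Q @ K.transpose(-2, -1)
print(scores.shape)  # torch.Size([8, 16, 16])
```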
## 7. Closing Thoughts
PyTorch gives you multiple tools for matrix multiplication:

- `mm` is strict and safe.
- `matmul` and `@` are flexible and powerful.
The special treatment of 1-D tensors makes code concise (dot products, matrix-vector multiplies), but it also means you need to be careful about shapes. Once you internalize these rules, you can move confidently between `mm`, `matmul`, and `@` without surprises.