Matrix multiplication is one of the most fundamental operations in machine learning. In PyTorch, you’ll often see three different ways to do it:
- `torch.matmul`
- `torch.mm`
- `@` (the Python operator)
At first glance, they look interchangeable — but they aren’t. PyTorch makes deliberate choices about when and how these operators work, especially for 1-D tensors and batched operations. Let’s break it all down.
1. torch.mm: Strict Matrix × Matrix
torch.mm is the most “traditional” operator.
- Inputs: Must be exactly 2-D tensors.
- Behavior: Standard matrix multiplication: `(n,m) @ (m,p) → (n,p)`.
- Limitations: No 1-D support, no batching, no broadcasting.
```python
import torch

a = torch.randn(2, 3)
b = torch.randn(3, 4)
print(torch.mm(a, b).shape)  # torch.Size([2, 4])
```
❌ If you try with a vector:
```python
x = torch.tensor([1, 2, 3])
torch.mm(x, x)  # RuntimeError
```
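The same strictness applies to batching: even well-shaped 3-D tensors are rejected. A minimal sketch (the shapes here are illustrative):

```python
A = torch.randn(10, 3, 4)
B = torch.randn(10, 4, 5)
torch.mm(A, B)  # RuntimeError (torch.mm accepts only 2-D inputs)
```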
So: torch.mm is strict and safe, but less flexible.
2. torch.matmul: The General Workhorse
torch.matmul is more powerful and flexible.
- Inputs: 1-D, 2-D, or N-D tensors.
- Special cases for 1-D:
  - `(n,) @ (n,)` → scalar (dot product).
  - `(n,) @ (n,m)` → `(m,)`.
  - `(n,m) @ (m,)` → `(n,)`.
- Batch support: `(b,n,m) @ (b,m,p) → (b,n,p)` (very useful in deep learning).
- Broadcasting: Works across leading dimensions (see the broadcasting example below).
Example: dot product
```python
a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 3])
print(torch.matmul(a, b))  # tensor(14)
```
Example: batch multiplication
```python
A = torch.randn(10, 3, 4)  # batch of 10 matrices
B = torch.randn(10, 4, 5)
print(torch.matmul(A, B).shape)  # torch.Size([10, 3, 5])
```
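Example: broadcasting across leading dimensions (a minimal sketch; the shapes are illustrative):

```python
A = torch.randn(10, 3, 4)  # batch of 10 matrices
B = torch.randn(4, 5)      # a single matrix, broadcast across the batch
print(torch.matmul(A, B).shape)  # torch.Size([10, 3, 5])
```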
✅ Use torch.matmul when you want flexibility.
3. @: The Shorthand
Python’s @ operator is mapped to torch.matmul.
- Everything said about `matmul` applies here.
- Cleaner syntax, widely used in PyTorch codebases.
```python
a = torch.randn(2, 3)
b = torch.randn(3, 4)
print((a @ b).shape)  # torch.Size([2, 4])
```
So: `@` = `torch.matmul`.
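As a quick, illustrative sanity check that the two really are the same operation:

```python
a = torch.randn(2, 3)
b = torch.randn(3, 4)
print(torch.allclose(a @ b, torch.matmul(a, b)))  # True
```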
4. Special 1-D Tensor Rules
This is where beginners get tripped up. Let’s see what happens when we mix 1-D and 2-D tensors.
| Left shape | Right shape | Result | Notes |
|---|---|---|---|
| `(n,)` | `(n,)` | scalar (shape `()`) | dot product |
| `(n,)` | `(n, m)` | `(m,)` | row vector × matrix |
| `(n, m)` | `(m,)` | `(n,)` | matrix × column vector |
| `(n, m)` | `(n,)` | `(n,)` only if `n == m` | shape mismatch otherwise |
| `(n, m)` | `(m, p)` | `(n, p)` | standard matrix multiply |
Examples:
```python
# Dot product
a = torch.tensor([1, 2, 3])
print(a @ a)  # tensor(14)

# Vector × Matrix (float literals so both operands share a dtype)
a = torch.tensor([1., 2., 3.])
b = torch.randn(3, 2)
print((a @ b).shape)  # torch.Size([2])

# Matrix × Vector
a = torch.randn(2, 3)
b = torch.tensor([1., 2., 3.])
print((a @ b).shape)  # torch.Size([2])
```
❌ Invalid example:
```python
a = torch.randn(2, 3)
b = torch.tensor([1., 2.])  # mismatched inner dimension
a @ b  # RuntimeError
```
5. Quick Comparison
| Operator | Allowed Inputs | Special 1-D Rules | Batch Matmul | Recommended Usage |
|---|---|---|---|---|
| `torch.mm` | 2-D only | ❌ no | ❌ no | Strict 2-D matrix multiply |
| `torch.matmul` | 1-D, 2-D, N-D | ✅ yes | ✅ yes | General-purpose multiply |
| `@` | 1-D, 2-D, N-D | ✅ yes | ✅ yes | Shorthand for readability |
6. Practical Tips
- Use `*` for elementwise multiplication.
- Use `@` (or `matmul`) for linear algebra multiplication; the two are contrasted in the sketch below.
- If you need strict 2-D only, stick with `mm`.
- When debugging, always check shapes with `.shape`.
- For deep learning models (e.g. attention layers), batched matmul is essential; use `@`.
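To make the first two tips concrete, here is a minimal sketch contrasting elementwise `*` with matrix `@` (values chosen so the results are easy to verify by hand):

```python
a = torch.tensor([[1., 2.], [3., 4.]])
b = torch.tensor([[5., 6.], [7., 8.]])

print(a * b)  # elementwise: [[ 5., 12.], [21., 32.]]
print(a @ b)  # matrix multiply: [[19., 22.], [43., 50.]]
```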
7. Closing Thoughts
PyTorch gives you multiple tools for matrix multiplication:
- `mm` is strict and safe.
- `matmul` and `@` are flexible and powerful.

The special treatment of 1-D tensors makes code concise (dot products, matrix-vector multiplies), but it also means you need to be careful about shapes. Once you internalize these rules, you can move confidently between `mm`, `matmul`, and `@` without surprises.
