Matrix multiplication is one of the most fundamental operations in machine learning. In PyTorch, you’ll often see three different ways to do it:

  • torch.matmul
  • torch.mm
  • @ (the Python operator)

At first glance, they look interchangeable — but they aren’t. PyTorch makes deliberate choices about when and how these operators work, especially for 1-D tensors and batched operations. Let’s break it all down.


1. torch.mm: Strict Matrix × Matrix

torch.mm is the most “traditional” operator.

  • Inputs: Must be exactly 2-D tensors.
  • Behavior: Standard matrix multiplication (n,m) @ (m,p) → (n,p).
  • Limitations: No 1-D support, no batching, no broadcasting.

import torch

a = torch.randn(2, 3)
b = torch.randn(3, 4)

print(torch.mm(a, b).shape)   # torch.Size([2, 4])

❌ If you try with a vector:

x = torch.tensor([1,2,3])
torch.mm(x, x)   # RuntimeError
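
If you must use mm with a vector, one minimal workaround is to lift it to 2-D with view before multiplying:

x = torch.tensor([1,2,3])
# (1, 3) @ (3, 1) → (1, 1)
print(torch.mm(x.view(1, -1), x.view(-1, 1)))   # tensor([[14]])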

So: torch.mm is strict and safe, but less flexible.


2. torch.matmul: The General Workhorse

torch.matmul is more powerful and flexible.

  • Inputs: 1-D, 2-D, or N-D tensors.
  • Special cases for 1-D:
    • (n,) @ (n,) → scalar (dot product).
    • (n,) @ (n,m) → (m,).
    • (n,m) @ (m,) → (n,).
  • Batch support: (b,n,m) @ (b,m,p) → (b,n,p) (very useful in deep learning).
  • Broadcasting: Works across leading dimensions (example below).

Example: dot product

a = torch.tensor([1,2,3])
b = torch.tensor([1,2,3])

print(torch.matmul(a, b))   # tensor(14)

Example: batch multiplication

A = torch.randn(10, 3, 4)  # batch of 10 matrices
B = torch.randn(10, 4, 5)
print(torch.matmul(A, B).shape)  # torch.Size([10, 3, 5])
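
Example: broadcasting over leading dimensions (the single matrix is applied to every batch element)

A = torch.randn(10, 3, 4)   # batch of 10 matrices
B = torch.randn(4, 5)       # single matrix, broadcast across the batch
print(torch.matmul(A, B).shape)  # torch.Size([10, 3, 5])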

✅ Use torch.matmul when you want flexibility.


3. @: The Shorthand

Python’s @ operator is mapped to torch.matmul.

  • Everything said about matmul applies here.
  • Cleaner syntax, widely used in PyTorch codebases.

a = torch.randn(2,3)
b = torch.randn(3,4)

print((a @ b).shape)   # torch.Size([2, 4]) 

So: @ = torch.matmul.


4. Special 1-D Tensor Rules

This is where beginners get tripped up. Let’s see what happens when we mix 1-D and 2-D tensors.

Left shape | Right shape | Result               | Notes
(n,)       | (n,)        | () (scalar)          | dot product
(n,)       | (n, m)      | (m,)                 | row vector × matrix
(n, m)     | (m,)        | (n,)                 | matrix × column vector
(n, m)     | (n,)        | (n,) only if n == m  | else shape mismatch error
(n, m)     | (m, p)      | (n, p)               | standard matrix multiply

Examples:

# Dot product
a = torch.tensor([1,2,3])
print(a @ a)   # tensor(14)

# Vector × Matrix
a = torch.tensor([1,2,3])
b = torch.randn(3,2)
print((a @ b).shape)  # torch.Size([2])

# Matrix × Vector
a = torch.randn(2,3)
b = torch.tensor([1,2,3])
print((a @ b).shape)  # torch.Size([2])

❌ Invalid example:

a = torch.randn(2,3)
b = torch.tensor([1,2])   # mismatched inner dim
a @ b   # RuntimeError
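
And the square case from the table, where (n,m) @ (n,) succeeds because n == m:

a = torch.randn(3, 3)
b = torch.tensor([1., 2., 3.])
print((a @ b).shape)   # torch.Size([3])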

5. Quick Comparison

Operator     | Allowed Inputs | Special 1-D Rules | Batch Matmul | Recommended Usage
torch.mm     | 2-D only       | ❌ no             | ❌ no        | Strict 2-D matrix multiply
torch.matmul | 1-D, 2-D, N-D  | ✅ yes            | ✅ yes       | General-purpose multiply
@            | 1-D, 2-D, N-D  | ✅ yes            | ✅ yes       | Shorthand for readability

6. Practical Tips

  • Use * for elementwise multiplication.
  • Use @ (or matmul) for linear algebra multiplication; the sketch after this list contrasts the two.
  • If you need strict 2-D only, stick with mm.
  • When debugging, always check shapes with .shape.
  • For deep learning models (e.g. attention layers), batched matmul is essential — use @.
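
A minimal sketch contrasting * and @ from the first two tips, on the same pair of matrices:

a = torch.tensor([[1., 2.], [3., 4.]])
b = torch.tensor([[5., 6.], [7., 8.]])

print(a * b)   # elementwise: tensor([[ 5., 12.], [21., 32.]])
print(a @ b)   # matrix multiply: tensor([[19., 22.], [43., 50.]])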

7. Closing Thoughts

PyTorch gives you multiple tools for matrix multiplication:

  • mm is strict and safe.
  • matmul and @ are flexible and powerful.

The special treatment of 1-D tensors makes code concise (dot products, matrix-vector multiplies), but it also means you need to be careful about shapes. Once you internalize these rules, you can move confidently between mm, matmul, and @ without surprises.

