Goyalayus

Notes, essays, and fragments from the edge of understanding.

BROADCASTING

December 9, 2025

Original Substack post


Broadcasting is one of those ideas that looks like a shape trick until you understand what the tensor library is actually doing.

At the Python level, broadcasting means:

x = torch.randn(4, 3)
b = torch.randn(3)
y = x + b

and PyTorch behaves as if b had shape (4, 3).

But the important part is this: PyTorch usually does not copy b four times. It creates a view whose metadata says, "when the row index changes, keep reading from the same memory."

That is the core idea:

Broadcasting is mostly shape metadata plus stride tricks. The data is not stretched. The pointer arithmetic changes.

To understand broadcasting properly, you need three things:

  1. shape
  2. storage
  3. stride

Once those are clear, expand, repeat, view, reshape, transpose, and batched matmul all become much less magical.

Shape, Storage, and Stride

A tensor has a logical shape, but its values live in a flat storage buffer.

Suppose:

import torch

a = torch.tensor([
    [10, 11, 12],
    [20, 21, 22],
])

print(a.shape)   # torch.Size([2, 3])
print(a.stride()) # (3, 1)

Logically, a is:

shape = (2, 3)

row 0: 10  11  12
row 1: 20  21  22

Physically, the values are stored as one flat buffer:

storage index:  0   1   2   3   4   5
value:         10  11  12  20  21  22

The stride tells PyTorch how many storage positions to jump when an index changes by 1 along a dimension.

For a:

shape  = (2, 3)
stride = (3, 1)

That means:

  • moving down one row jumps by 3 elements
  • moving right one column jumps by 1 element

The general offset formula is:

$$\text{offset}(i_0, i_1, \dots, i_{n-1}) = \sum_{k=0}^{n-1} i_k \, s_k$$

where $s_k$ is the stride for dimension $k$.

For the matrix above:

$$\text{offset}(i, j) = 3i + 1 \cdot j = 3i + j$$

So:

a[0, 0] -> offset 3*0 + 0 = 0 -> 10
a[0, 2] -> offset 3*0 + 2 = 2 -> 12
a[1, 0] -> offset 3*1 + 0 = 3 -> 20
a[1, 2] -> offset 3*1 + 2 = 5 -> 22
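The same arithmetic fits in a few lines of plain Python. This is a minimal sketch that models storage as a flat list; `offset` and `read` are hypothetical helpers, not PyTorch APIs, and `storage_offset` is included only because real tensors carry one (it defaults to 0 here).

```python
# Flat storage for the 2x3 matrix a, in row-major order.
storage = [10, 11, 12, 20, 21, 22]
stride = (3, 1)

def offset(index, stride, storage_offset=0):
    """Storage position for a logical index: storage_offset + sum of i_k * s_k."""
    return storage_offset + sum(i * s for i, s in zip(index, stride))

def read(storage, index, stride):
    """Read one logical element by following the strides into flat storage."""
    return storage[offset(index, stride)]

print(offset((1, 2), stride))         # 5
print(read(storage, (1, 0), stride))  # 20
```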

Contiguous Strides

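For a normal row-major contiguous tensor with shape:

$$(d_0, d_1, \dots, d_{n-1})$$

the last dimension moves fastest, so:

$$s_{n-1} = 1$$

and for any earlier dimension:

$$s_k = s_{k+1} \, d_{k+1} = \prod_{j=k+1}^{n-1} d_j$$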

Example:

x = torch.empty(2, 3, 4)
print(x.stride()) # (12, 4, 1)

Why?

shape  = (2, 3, 4)
stride = (3*4, 4, 1)
       = (12, 4, 1)

The offset formula is:

$$\text{offset}(i_0, i_1, i_2) = 12 \, i_0 + 4 \, i_1 + i_2$$

This is what "contiguous" means: the logical traversal order matches the physical memory order.
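The "multiply the sizes to the right" rule translates directly into a loop that walks the shape from the last dimension backwards. `contiguous_strides` is a hypothetical helper, a sketch of the rule above; it does not try to match PyTorch on size-0 dimensions.

```python
def contiguous_strides(shape):
    """Row-major strides: the last dim has stride 1, and each earlier dim
    jumps over one full block of everything to its right."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))

print(contiguous_strides((2, 3, 4)))  # (12, 4, 1)
print(contiguous_strides((2, 3)))     # (3, 1)
```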

Transpose: Same Storage, Different Stride

Now transpose the matrix:

a = torch.arange(6).reshape(2, 3)
b = a.t()

print(a)
print(a.shape, a.stride()) # torch.Size([2, 3]) (3, 1)

print(b)
print(b.shape, b.stride()) # torch.Size([3, 2]) (1, 3)

a is:

logical a, shape (2, 3), stride (3, 1)

0  1  2
3  4  5

b = a.t() is:

logical b, shape (3, 2), stride (1, 3)

0  3
1  4
2  5

The storage did not move:

storage index: 0  1  2  3  4  5
value:         0  1  2  3  4  5

Only the metadata changed.

For b, the offset formula is:

$$\text{offset}_b(i, j) = 1 \cdot i + 3 \cdot j$$

So:

b[0, 0] -> offset 0 + 3*0 = 0 -> 0
b[0, 1] -> offset 0 + 3*1 = 3 -> 3
b[1, 0] -> offset 1 + 3*0 = 1 -> 1
b[2, 1] -> offset 2 + 3*1 = 5 -> 5

This is why b is non-contiguous: stepping along its second axis jumps by 3 in storage instead of 1.

print(a.is_contiguous()) # True
print(b.is_contiguous()) # False
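The transpose can be modeled the same way: keep the storage list, swap the shape and stride tuples, and read through the new strides. `read` and `is_contiguous` are hypothetical helpers, and the contiguity test is simplified to "strides equal the row-major strides for this shape".

```python
storage = [0, 1, 2, 3, 4, 5]  # shared by a and b

a_shape, a_stride = (2, 3), (3, 1)
# Transposing a 2D tensor swaps the shape and the stride; storage is untouched.
b_shape, b_stride = a_shape[::-1], a_stride[::-1]  # (3, 2), (1, 3)

def read(index, stride):
    return storage[sum(i * s for i, s in zip(index, stride))]

def is_contiguous(shape, stride):
    """Simplified check: strides equal the row-major strides for this shape.
    (PyTorch's real check also ignores strides of size-1 dimensions.)"""
    expected, step = [], 1
    for size in reversed(shape):
        expected.append(step)
        step *= size
    return tuple(reversed(expected)) == tuple(stride)

b_rows = [[read((i, j), b_stride) for j in range(b_shape[1])] for i in range(b_shape[0])]
print(b_rows)                            # [[0, 3], [1, 4], [2, 5]]
print(is_contiguous(a_shape, a_stride))  # True
print(is_contiguous(b_shape, b_stride))  # False
```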

Some operations can work with non-contiguous tensors. Others need contiguous memory. For example:

b.view(6)        # raises a RuntimeError here: the strides do not allow a flat view
b.reshape(6)     # works; for this b it has to copy
b.contiguous()   # explicitly makes a contiguous copy

The short version:

  • view changes shape metadata only, so it needs compatible strides
  • reshape returns a view when the strides allow it, otherwise it copies
  • contiguous copies data into standard row-major layout (or returns the tensor itself if it is already contiguous)

Broadcasting Rules

Broadcasting compares shapes from right to left.

Two dimensions are compatible if:

  1. they are equal, or
  2. one of them is 1, or
  3. one of them does not exist

The result dimension is the size that is not 1, or the shared size when the two are equal. A missing dimension counts as 1.

Example:

    (4, 3)
+      (3)
---------
    (4, 3)

The missing leading dimension in (3,) behaves like 1:

    (4, 3)
+   (1, 3)
---------
    (4, 3)

Another example:

    (10, 1, 3, 4)
+       (20, 3, 1)
------------------
    (10, 20, 3, 4)

Aligned right:

    10   1   3   4
         20  3   1
------------------
    10   20  3   4

Compatibility check:

dim -1: 4 vs 1  -> 4
dim -2: 3 vs 3  -> 3
dim -3: 1 vs 20 -> 20
dim -4: 10 vs missing -> 10
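The right-to-left check can be written as a small pure-Python function. This is a sketch of the rules listed above, not PyTorch's implementation; `broadcast_shape` is a hypothetical name (PyTorch's own helper for this is `torch.broadcast_shapes`).

```python
from itertools import zip_longest

def broadcast_shape(a, b):
    """Combine two shapes: align from the right, treat missing dimensions
    as 1, and require the sizes in each column to be equal or to include a 1."""
    result = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x == y or y == 1:
            result.append(x)
        elif x == 1:
            result.append(y)
        else:
            raise ValueError(f"incompatible sizes {x} and {y} in shapes {a} and {b}")
    return tuple(reversed(result))

print(broadcast_shape((4, 3), (3,)))               # (4, 3)
print(broadcast_shape((10, 1, 3, 4), (20, 3, 1)))  # (10, 20, 3, 4)
```

Passing `(5, 6)` and `(5, 6, 10)` raises `ValueError`, matching the failing example that follows.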

A failing example: x has shape (5, 6), and the result you want has shape (5, 6, 10).

If you try to combine (5, 6) with (5, 6, 10), PyTorch aligns from the right:

target: 5  6  10
input:     5   6

The last dimensions are 10 and 6, so this fails.

If you actually mean "add a singleton dimension at the end," you need:

x = torch.randn(5, 6)
x = x.unsqueeze(-1) # shape (5, 6, 1)

Now:

target: 5  6  10
input:  5  6   1

That broadcasts successfully.

Broadcasting as Stride 0

Here is the key example.

v = torch.tensor([[10, 20, 30]])
print(v.shape, v.stride()) # torch.Size([1, 3]) (3, 1)

e = v.expand(4, 3)
print(e.shape, e.stride()) # torch.Size([4, 3]) (0, 1)

Logical view:

e, shape (4, 3)

10  20  30
10  20  30
10  20  30
10  20  30

Physical storage is still just:

storage index: 0   1   2
value:        10  20  30

The expanded tensor has stride (0, 1).

Its offset formula is:

$$\text{offset}_e(i, j) = 0 \cdot i + 1 \cdot j = j$$

So all rows point to the same storage:

e[0, 0] -> offset 0*0 + 0 = 0 -> 10
e[1, 0] -> offset 0*1 + 0 = 0 -> 10
e[2, 0] -> offset 0*2 + 0 = 0 -> 10
e[3, 0] -> offset 0*3 + 0 = 0 -> 10

e[0, 2] -> offset 0*0 + 2 = 2 -> 30
e[3, 2] -> offset 0*3 + 2 = 2 -> 30

That is the whole trick.

Stride 0 means:

moving along this dimension does not move the memory pointer.

Memory diagram:

logical expanded tensor:

          col0       col1       col2
row0   storage[0] storage[1] storage[2]
row1   storage[0] storage[1] storage[2]
row2   storage[0] storage[1] storage[2]
row3   storage[0] storage[1] storage[2]

actual storage:

storage[0] = 10
storage[1] = 20
storage[2] = 30

It looks like 12 numbers, but there are only 3 stored numbers.
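To make stride 0 concrete, here is a pure-Python model of e: three stored numbers, a (4, 3) shape, and stride (0, 1). It is a sketch of the idea rather than of PyTorch internals; `read` is a hypothetical helper.

```python
storage = [10, 20, 30]   # the only three numbers that exist

shape = (4, 3)
stride = (0, 1)          # row stride 0: changing the row never moves the pointer

def read(index):
    return storage[sum(i * s for i, s in zip(index, stride))]

rows = [[read((i, j)) for j in range(shape[1])] for i in range(shape[0])]
print(rows)          # [[10, 20, 30], [10, 20, 30], [10, 20, 30], [10, 20, 30]]
print(len(storage))  # 3

# Writing to storage shows up in every logical row at once.
storage[0] = 99
print(read((3, 0)))  # 99
```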

expand vs repeat

expand creates a view. repeat creates a copy.

v = torch.tensor([[10, 20, 30]])

e = v.expand(4, 3)
r = v.repeat(4, 1)

print(e.shape, e.stride()) # torch.Size([4, 3]) (0, 1)
print(r.shape, r.stride()) # torch.Size([4, 3]) (3, 1)

Conceptually:

v storage:
10 20 30

expand(4, 3) storage:
10 20 30

repeat(4, 1) storage:
10 20 30 10 20 30 10 20 30 10 20 30

So:

  • expand is cheap in memory
  • repeat uses more memory
  • expand can only expand dimensions of size 1
  • repeat can tile any dimension because it physically copies
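A rough pure-Python model of the difference: the expanded "tensor" reuses the original list through stride 0, while the repeated one owns a brand-new list. The dictionaries are hypothetical stand-ins for tensor metadata, a sketch rather than how PyTorch stores tensors.

```python
storage = [10, 20, 30]

# "expand": a view described by (storage, shape, stride) that reuses storage.
expanded = {"storage": storage, "shape": (4, 3), "stride": (0, 1)}

# "repeat": a new flat buffer holding 4 physical copies of the row.
repeated = {"storage": storage * 4, "shape": (4, 3), "stride": (3, 1)}

def read(t, index):
    return t["storage"][sum(i * s for i, s in zip(index, t["stride"]))]

print(len(expanded["storage"]), len(repeated["storage"]))  # 3 12

storage[1] = -1
print(read(expanded, (2, 1)))  # -1  (the view sees the change)
print(read(repeated, (2, 1)))  # 20  (the copy does not)
```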

The difference matters a lot for large tensors.

x = torch.randn(1, 4096)

expanded = x.expand(8192, 4096)
repeated = x.repeat(8192, 1)

print(expanded.storage().size()) # 4096 elements, shared with x
print(repeated.storage().size()) # 8192 * 4096 = 33,554,432 elements

The exact storage API has changed across PyTorch versions, but the idea is stable: expand shares storage, repeat allocates.

A safer modern check:

print(expanded.untyped_storage().nbytes()) # 16384 bytes (4096 float32 values)
print(repeated.untyped_storage().nbytes()) # 134217728 bytes (8192 * 4096 float32 values)

Why You Should Be Careful Writing to Expanded Tensors

Because an expanded dimension maps several logical elements to the same memory location, in-place writes are ambiguous.

v = torch.tensor([[10., 20., 30.]])
e = v.expand(4, 3)

e.add_(1)

PyTorch rejects in-place operations like this (here, e.add_(1) raises a RuntimeError) because multiple elements alias the same storage location.

The problem is:

e[0, 0], e[1, 0], e[2, 0], e[3, 0]

all refer to the same underlying value.

If you need independent writable data, use:

e = v.expand(4, 3).clone()

or:

r = v.repeat(4, 1)

Use expand when you want virtual broadcasting. Use repeat when you genuinely need physical copies.
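The aliasing problem can be seen by listing every storage offset a view touches. `has_obvious_self_overlap` is a hypothetical helper that tests only one sufficient condition (a dimension of size greater than 1 with stride 0); it is a simplification, not PyTorch's actual overlap check, which covers more cases.

```python
from itertools import product

def has_obvious_self_overlap(shape, stride):
    """True if some dimension of size > 1 has stride 0, which guarantees
    that distinct logical indices share a storage slot."""
    return any(size > 1 and s == 0 for size, s in zip(shape, stride))

def offsets(shape, stride):
    """Every storage offset touched, one per logical element."""
    return [sum(i * s for i, s in zip(idx, stride))
            for idx in product(*(range(n) for n in shape))]

expanded = ((4, 3), (0, 1))
repeated = ((4, 3), (3, 1))

print(has_obvious_self_overlap(*expanded))  # True
print(has_obvious_self_overlap(*repeated))  # False

# 12 logical elements but only 3 distinct offsets for the expanded view.
print(len(offsets(*expanded)), len(set(offsets(*expanded))))  # 12 3
```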

Broadcasting in Elementwise Operations

Elementwise operations use broadcasting directly.

x = torch.arange(12).reshape(4, 3)
b = torch.tensor([100, 200, 300])

y = x + b
print(y)

Shapes:

x: (4, 3)
b:    (3)
y: (4, 3)

Equivalent mental model:

y = x + b.expand(4, 3)

But you usually do not write the expand yourself because PyTorch does it inside the operation.

Formula:

$$y_{ij} = x_{ij} + b_j$$

Since b has no row dimension, its row stride is effectively 0.
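Here is a pure-Python sketch of how an elementwise loop can read b with a row stride of 0. It assumes both inputs are described by flat storage plus shape and stride; `broadcast_strides` and `add_2d` are hypothetical helpers, and real PyTorch kernels are far more elaborate.

```python
def broadcast_strides(shape, stride, out_shape):
    """Strides for reading a tensor as if it had out_shape: missing leading dims
    and size-1 dims get stride 0 (assumes shape already broadcasts to out_shape)."""
    result = [0] * (len(out_shape) - len(shape))
    for size, s in zip(shape, stride):
        result.append(0 if size == 1 else s)
    return tuple(result)

def add_2d(x_storage, x_shape, x_stride, b_storage, b_shape, b_stride, out_shape):
    """Elementwise add for a 2D output, reading each input through its own strides."""
    xs = broadcast_strides(x_shape, x_stride, out_shape)
    bs = broadcast_strides(b_shape, b_stride, out_shape)
    return [[x_storage[i * xs[0] + j * xs[1]] + b_storage[i * bs[0] + j * bs[1]]
             for j in range(out_shape[1])]
            for i in range(out_shape[0])]

x = list(range(12))    # like torch.arange(12).reshape(4, 3): shape (4, 3), stride (3, 1)
b = [100, 200, 300]    # shape (3,), stride (1,)

print(broadcast_strides((3,), (1,), (4, 3)))  # (0, 1)
y = add_2d(x, (4, 3), (3, 1), b, (3,), (1,), (4, 3))
print(y[0], y[3])  # [100, 201, 302] [109, 210, 311]
```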

Another common pattern is applying per-channel statistics to image batches, as in normalization:

images = torch.randn(32, 3, 224, 224) # N, C, H, W
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

normalized = (images - mean[None, :, None, None]) / std[None, :, None, None]

Shapes:

images: (32, 3, 224, 224)
mean:   ( 1, 3,   1,   1)
std:    ( 1, 3,   1,   1)
result: (32, 3, 224, 224)

The formula is:

$$\text{normalized}_{n,c,h,w} = \frac{\text{images}_{n,c,h,w} - \text{mean}_c}{\text{std}_c}$$

The same channel mean is reused for every batch item and every spatial position.

None, unsqueeze, and Shape Alignment

Broadcasting works from the right, so sometimes you must insert dimensions explicitly.

These two are equivalent:

x = torch.randn(5, 6)

a = x.unsqueeze(-1) # (5, 6, 1)
b = x[:, :, None]  # (5, 6, 1)

Useful patterns:

v = torch.randn(3)

v[None, :]    # (1, 3), row vector
v[:, None]    # (3, 1), column vector

Outer sum:

a = torch.tensor([10, 20, 30]) # (3,)
b = torch.tensor([1, 2, 3, 4]) # (4,)

out = a[:, None] + b[None, :]
print(out.shape) # (3, 4)

Formula:

$$\text{out}_{ij} = a_i + b_j$$

Memory view:

a[:, None] has shape (3, 1)

10
20
30

b[None, :] has shape (1, 4)

1  2  3  4

result shape (3, 4)

11 12 13 14
21 22 23 24
31 32 33 34

This is not a special "outer" operator. It is just broadcasting.
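The broadcast result can be checked against an explicit double loop, which is what the formula above says element by element. Plain Python lists stand in for the tensors in this sketch.

```python
a = [10, 20, 30]   # plays the role of a[:, None], shape (3, 1)
b = [1, 2, 3, 4]   # plays the role of b[None, :], shape (1, 4)

# out[i][j] = a[i] + b[j]: a is reused across columns, b across rows.
out = [[a[i] + b[j] for j in range(len(b))] for i in range(len(a))]

for row in out:
    print(row)
# [11, 12, 13, 14]
# [21, 22, 23, 24]
# [31, 32, 33, 34]
```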

Broadcasting Is Not Matrix Multiplication

Elementwise multiplication and matrix multiplication use different rules.

Elementwise Product

Elementwise multiplication uses broadcasting.

a = torch.randn(3, 4)
b = torch.randn(4)
c = a * b

print(c.shape) # (3, 4)

Formula:

$$c_{ij} = a_{ij} \, b_j$$

The shapes align as:

a: (3, 4)
b:    (4)
c: (3, 4)

But:

a = torch.randn(3, 4)
b = torch.randn(4, 5)
a * b

fails because:

(3, 4)
(4, 5)

Rightmost dimensions 4 and 5 are incompatible.

Matrix Product

Matrix multiplication is linear algebra, not elementwise broadcasting.

a = torch.randn(3, 4)
b = torch.randn(4, 5)
c = a @ b

print(c.shape) # (3, 5)

Formula:

$$c_{ij} = \sum_{k} a_{ik} \, b_{kj}$$

For torch.matmul, the last two dimensions are matrix dimensions:

a: (..., n, k)
b: (..., k, m)
out: (..., n, m)

The inner dimension k must match.
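Written as plain loops, the matrix product is a sum over the shared inner dimension k, which is exactly why k must match. `matmul_2d` is a hypothetical helper over nested lists, a sketch of the formula rather than of how torch.matmul is implemented.

```python
def matmul_2d(a, b):
    """c[i][j] = sum over k of a[i][k] * b[k][j], for nested-list matrices."""
    n, k = len(a), len(a[0])
    k2, m = len(b), len(b[0])
    if k != k2:
        raise ValueError(f"inner dimensions differ: {k} vs {k2}")
    return [[sum(a[i][t] * b[t][j] for t in range(k)) for j in range(m)]
            for i in range(n)]

a = [[1, 2, 3, 4],
     [5, 6, 7, 8],
     [9, 10, 11, 12]]      # shape (3, 4)
b = [[1, 0, 0, 0, 1],
     [0, 1, 0, 0, 1],
     [0, 0, 1, 0, 1],
     [0, 0, 0, 1, 1]]      # shape (4, 5)

c = matmul_2d(a, b)
print(len(c), len(c[0]))  # 3 5
print(c[0])               # [1, 2, 3, 4, 10]
```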

The leading dimensions are batch dimensions and are broadcast.

Example:

a = torch.randn(10, 1, 3, 4)
b = torch.randn(1, 20, 4, 5)
c = a @ b

print(c.shape) # (10, 20, 3, 5)

Break it into two parts:

matrix core:
(3, 4) @ (4, 5) -> (3, 5)

batch dims:
(10, 1) and (1, 20) -> (10, 20)

final:
(10, 20, 3, 5)

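Formula:

$$c_{p,q,i,j} = \sum_{k} a_{p,q,i,k} \, b_{p,q,k,j}$$

More carefully, because broadcasting reuses the singleton dimensions:

$$c_{p,q,i,j} = \sum_{k} a_{p,0,i,k} \, b_{0,q,k,j}$$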

The batch dimensions broadcast, but the matrix multiplication still sums over the inner dimension.
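The two-part shape rule can be sketched in Python: peel off the last two dimensions as the matrix core, check the inner size, and broadcast whatever is left. This assumes both inputs have at least two dimensions (the 1D cases follow different rules); `broadcast_shape` and `batched_matmul_shape` are hypothetical helpers.

```python
from itertools import zip_longest

def broadcast_shape(a, b):
    """Broadcast two batch shapes, aligned from the right."""
    result = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"batch sizes {x} and {y} do not broadcast")
        result.append(x if y == 1 else y)
    return tuple(reversed(result))

def batched_matmul_shape(a_shape, b_shape):
    """Output shape of a @ b when both inputs have at least 2 dimensions."""
    *a_batch, n, k = a_shape
    *b_batch, k2, m = b_shape
    if k != k2:
        raise ValueError(f"inner dimensions differ: {k} vs {k2}")
    return broadcast_shape(tuple(a_batch), tuple(b_batch)) + (n, m)

print(batched_matmul_shape((10, 1, 3, 4), (1, 20, 4, 5)))  # (10, 20, 3, 5)
print(batched_matmul_shape((3, 4), (4, 5)))                # (3, 5)
```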

Vector Cases in matmul

torch.matmul has special rules for 1D tensors.

Vector @ Matrix

v = torch.randn(4)
m = torch.randn(4, 5)

out = v @ m
print(out.shape) # (5,)

PyTorch internally treats the vector as (1, 4):

(4,) @ (4, 5)
becomes
(1, 4) @ (4, 5) -> (1, 5)
then squeeze -> (5,)

Formula:

$$\text{out}_j = \sum_{k} v_k \, m_{kj}$$

Matrix @ Vector

m = torch.randn(3, 4)
v = torch.randn(4)

out = m @ v
print(out.shape) # (3,)

PyTorch internally treats the vector as (4, 1):

(3, 4) @ (4,)
becomes
(3, 4) @ (4, 1) -> (3, 1)
then squeeze -> (3,)

Formula:

$$\text{out}_i = \sum_{k} m_{ik} \, v_k$$
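The promote-then-squeeze rule for 1D operands can be sketched with nested lists. `vec_mat` and `mat_vec` are hypothetical helpers that mirror the two cases above: add a size-1 dimension, do an ordinary matrix product, then drop that dimension again.

```python
def matmul_2d(a, b):
    """Plain matrix product of nested lists: sum over the shared inner index."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def vec_mat(v, m):
    """(k,) @ (k, n): treat v as a (1, k) matrix, multiply, then drop that 1."""
    return matmul_2d([v], m)[0]

def mat_vec(m, v):
    """(n, k) @ (k,): treat v as a (k, 1) matrix, multiply, then drop that 1."""
    column = [[x] for x in v]
    return [row[0] for row in matmul_2d(m, column)]

m = [[1, 2, 3],
     [4, 5, 6]]                # shape (2, 3)
print(vec_mat([1, 1], m))      # [5, 7, 9]  shape (3,)
print(mat_vec(m, [1, 0, 1]))   # [4, 10]    shape (2,)
```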

Common PyTorch Shape Tools

Broadcasting is automatic shape compatibility. These tools let you control the geometry manually.

unsqueeze

Adds a dimension of size 1.

x = torch.randn(3)
print(x.unsqueeze(0).shape)  # (1, 3)
print(x.unsqueeze(1).shape)  # (3, 1)

This is metadata-only.

squeeze

Removes dimensions of size 1.

x = torch.randn(1, 3, 1, 5)

print(x.squeeze().shape)     # (3, 5)
print(x.squeeze(0).shape)    # (3, 1, 5)

Be careful with bare squeeze() in model code. If your batch size is 1, it can accidentally remove the batch dimension.

permute

Reorders axes.

x = torch.randn(2, 3, 4)
y = x.permute(2, 0, 1)

print(y.shape)   # (4, 2, 3)
print(y.stride()) # (1, 12, 4)

permute does not copy data: it returns a view with reordered shape and strides. The result is often non-contiguous.
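Since permute only reorders metadata, it can be modeled by reordering the shape and stride tuples together. `permute` here is a hypothetical pure-Python helper applied to the (2, 3, 4) example above.

```python
def permute(shape, stride, dims):
    """Reorder shape and stride together; storage is not touched."""
    return tuple(shape[d] for d in dims), tuple(stride[d] for d in dims)

shape, stride = (2, 3, 4), (12, 4, 1)   # contiguous, like torch.randn(2, 3, 4)
new_shape, new_stride = permute(shape, stride, (2, 0, 1))
print(new_shape, new_stride)  # (4, 2, 3) (1, 12, 4)
```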

view

Changes shape without copying, but only when the current strides permit it.

x = torch.arange(12).reshape(3, 4)
y = x.view(2, 6)

After a transpose or permute:

x = torch.arange(12).reshape(3, 4)
t = x.t()

t.view(12) # fails with a RuntimeError

because t is non-contiguous.

reshape

Like view, except it may copy if needed.

z = t.reshape(12)

This is convenient, but remember: if it must copy, it costs memory and time.

expand

Expands size-1 dimensions by using stride 0.

x = torch.randn(1, 3, 1)
y = x.expand(5, 3, 7)

print(y.shape)   # (5, 3, 7)
print(y.stride()) # (0, 1, 0)

Only dimensions of size 1 can be expanded to larger sizes. expand can also add new leading dimensions, which get stride 0.
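The stride rule behind expand can be sketched as follows. `expand_strides` is a hypothetical helper: it gives stride 0 to new leading dimensions and to stretched size-1 dimensions, keeps the stride of unchanged dimensions, and skips the -1 shorthand that expand also accepts.

```python
def expand_strides(shape, stride, new_shape):
    """Strides that expanding (shape, stride) to new_shape would produce."""
    pad = len(new_shape) - len(shape)
    if pad < 0:
        raise ValueError("expand cannot remove dimensions")
    result = [0] * pad                   # brand-new leading dimensions
    for size, s, target in zip(shape, stride, new_shape[pad:]):
        if target == size:
            result.append(s)             # unchanged: keep the old stride
        elif size == 1:
            result.append(0)             # size 1 stretched: stride 0
        else:
            raise ValueError(f"cannot expand size {size} to {target}")
    return tuple(result)

print(expand_strides((1, 3, 1), (3, 1, 1), (5, 3, 7)))  # (0, 1, 0)
print(expand_strides((1, 3), (3, 1), (2, 4, 3)))        # (0, 0, 1)
```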

repeat

Tiles data physically.

x = torch.tensor([[1, 2, 3]])
y = x.repeat(4, 1)

print(y)

Use this when you need real independent copies.

A Practical Debugging Checklist

When a broadcasting error happens, do this:

  1. Write both shapes vertically.
  2. Align them from the right.
  3. Compare each column.
  4. A column is valid if the sizes are equal, one is 1, or one is missing.
  5. Insert None / unsqueeze where your intended singleton dimension is missing.
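The checklist can be automated with a small helper that prints the right-aligned columns and reports which ones break the rules. `explain_broadcast` is a hypothetical debugging aid sketched for this post, not part of PyTorch.

```python
from itertools import zip_longest

def explain_broadcast(a, b):
    """Print the right-aligned columns of two shapes and return the failing dims.
    Missing dimensions are shown as '-' and are always compatible."""
    bad = []
    columns = list(zip_longest(reversed(a), reversed(b)))
    for pos, (x, y) in enumerate(columns, start=1):
        ok = x is None or y is None or x == y or 1 in (x, y)
        label = f"dim -{pos}: {'-' if x is None else x} vs {'-' if y is None else y}"
        print(label, "ok" if ok else "FAIL")
        if not ok:
            bad.append(-pos)
    return bad

print(explain_broadcast((5, 6), (10,)))
# dim -1: 6 vs 10 FAIL
# dim -2: 5 vs - ok
# [-1]
```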

Example:

x = torch.randn(5, 6)
w = torch.randn(10)

x + w

Alignment:

x: 5   6
w:     10

6 vs 10 fails.

If you wanted (5, 6, 10), write:

x3 = x[:, :, None] # (5, 6, 1)
w3 = w[None, None, :] # (1, 1, 10)

y = x3 + w3
print(y.shape) # (5, 6, 10)

Alignment:

x3: 5  6   1
w3: 1  1  10
--------------
y:  5  6  10

Formula:

$$y_{ijk} = x_{ij} + w_k$$

The Mental Model

A tensor is not just a block of numbers. A tensor is:

storage pointer + shape + stride + storage offset + dtype + device

Broadcasting mostly changes shape and stride.

The most important stride trick is:

stride 0 = reuse the same memory location along this axis

So when you see:

x = torch.randn(1, 3)
y = x.expand(4, 3)

read it as:

make a logical (4, 3) tensor,
but when the row index changes,
do not move the pointer.

That is why broadcasting is fast, why expand is memory-cheap, why in-place writes on expanded views are dangerous, and why repeat is a completely different operation.

Once you think in shape + stride, broadcasting stops being magic. It becomes pointer arithmetic with a very nice interface.