Retep's

[TIL] Collected: Softmax


Softmax

Softmax looks like a simple elementwise op, but it is actually two global reductions (a max and a sum) followed by an elementwise write:

y_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x))
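For reference, here is a minimal pure-Python sketch of the formula above. It uses no Triton, and the name `softmax_reference` is only illustrative. Subtracting max(x) keeps every exponent at or below zero, so `math.exp` cannot overflow even for large inputs.

```python
import math

def softmax_reference(x):
    """Numerically stable softmax of a list of floats (unblocked reference)."""
    m = max(x)                            # global max: the shared baseline
    exps = [math.exp(v - m) for v in x]   # every exponent is <= 0, so no overflow
    denom = sum(exps)                     # global denominator
    return [e / denom for e in exps]

# Without the max shift, math.exp(1000.0) would raise OverflowError.
y = softmax_reference([1000.0, 1001.0, 1002.0])
print(y)
```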

The annoying part is that every output element needs the same global max(x) and the same global denominator. If N <= BLOCK_SIZE, one Triton program can own the whole vector and do both reductions locally. But once the vector is larger than a block, tl.max and tl.sum only reduce over the lanes of a single program, so each program sees only its own block's max and sum.

That is why the kernel is split into three phases.

Phase 1: summarize each block

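import triton
import triton.language as tl

@triton.jit
def block_stats_kernel(input_ptr, block_m_ptr, block_l_ptr, N, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N

    # Padded lanes load -inf: they never win the max and add exp(-inf) = 0 to the sum.
    x = tl.load(input_ptr + offsets, mask=mask, other=-float("inf"))
    m = tl.max(x, axis=0)                 # local max m_b
    l = tl.sum(tl.exp(x - m), axis=0)     # local denominator l_b, shifted by m_b

    # One (m_b, l_b) pair per program.
    tl.store(block_m_ptr + pid, m)
    tl.store(block_l_ptr + pid, l)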

Each program owns a contiguous chunk (block b) and computes two scalars:

m_b = max_{i in block b} x_i
l_b = sum_{i in block b} exp(x_i - m_b)

The second value is not the final denominator yet. It is a denominator relative to the block's own numerical baseline m_b. Notice that we need to store both the max and the sum. The max alone is not enough. Neither is the sum, because it was computed after shifting by a local max. m_b tells us how to translate the block's denominator into the global baseline. l_b tells us how much mass lies inside the block after the local shift. Together they form a compact sufficient statistic for softmax normalization.
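To make the per-block summary concrete, here is a host-side sketch of what phase 1 produces. It is a plain-Python emulation, not the kernel itself, and `block_stats` is a hypothetical helper name. Python slicing drops the out-of-range lanes. The kernel instead loads those lanes as -inf, which contributes exp(-inf) = 0, so both approaches give the same (m_b, l_b).

```python
import math

def block_stats(x, block_size):
    """Host-side emulation of block_stats_kernel: one (m_b, l_b) pair per block."""
    stats = []
    for start in range(0, len(x), block_size):
        # The last slice may be short; the kernel masks those lanes instead.
        block = x[start:start + block_size]
        m_b = max(block)                              # local max
        l_b = sum(math.exp(v - m_b) for v in block)   # local denominator under baseline m_b
        stats.append((m_b, l_b))
    return stats

stats = block_stats([0.5, 3.0, -1.0, 2.0, 10.0, 9.0, 4.0], block_size=3)
for m_b, l_b in stats:
    print(m_b, l_b)
```

Each l_b lies between 1 and the block's length. The block's max element contributes exp(0) = 1, and every other element contributes at most 1.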

Phase 2: merge block summaries

@triton.jit
def reduce_stats_kernel(
    block_m_ptr, block_l_ptr, global_m_ptr, global_l_ptr, num_blocks, BLOCK_SIZE: tl.constexpr
):
    # Launched as a single program; BLOCK_SIZE must be a power of two >= num_blocks.
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_blocks

    # Padded lanes get m = -inf and l = 0, so they contribute 0 * exp(-inf) = 0 below.
    mb = tl.load(block_m_ptr + offsets, mask=mask, other=-float("inf"))
    lb = tl.load(block_l_ptr + offsets, mask=mask, other=0.0)

    M = tl.max(mb, axis=0)                      # global max
    L = tl.sum(lb * tl.exp(mb - M), axis=0)     # rescale each l_b to baseline M

    tl.store(global_m_ptr, M)
    tl.store(global_l_ptr, L)

The key line is:

L = tl.sum(lb * tl.exp(mb - M), axis=0)

This rescales every block’s local denominator into the global denominator’s coordinate system.

Why does this work?

l_b = sum_{i in block b} exp(x_i - m_b)
l_b * exp(m_b - M)
    = sum_{i in block b} exp(x_i - m_b) * exp(m_b - M)
    = sum_{i in block b} exp(x_i - M)

So after summing over blocks:

L = sum_i exp(x_i - M)

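The largest block max is also the largest element overall, so M = max(x). That makes L exactly the denominator in the definition at the top. This is the same log-sum-exp shift as before, just applied hierarchically.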
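A small numeric check of the derivation, written as a pure-Python sketch (`merge_block_stats` is a hypothetical name mirroring reduce_stats_kernel). It merges two hand-built block summaries and compares the result with the denominator computed directly over all elements.

```python
import math

def merge_block_stats(stats):
    """Emulates reduce_stats_kernel: combine per-block (m_b, l_b) into global (M, L)."""
    M = max(m_b for m_b, _ in stats)                         # max of block maxima = global max
    L = sum(l_b * math.exp(m_b - M) for m_b, l_b in stats)   # rescale each l_b to baseline M
    return M, L

# Two hand-built blocks: [1.0, 2.0] and [5.0].
stats = [
    (2.0, math.exp(1.0 - 2.0) + math.exp(0.0)),  # block [1.0, 2.0], m_b = 2.0
    (5.0, 1.0),                                  # block [5.0],      m_b = 5.0
]
M, L = merge_block_stats(stats)
direct = sum(math.exp(v - 5.0) for v in [1.0, 2.0, 5.0])
print(M, L, direct)
```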

Phase 3: normalize every element

@triton.jit
def softmax_kernel(input_ptr, output_ptr, global_m_ptr, global_l_ptr, N, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N

    # Every program reads the same two scalars produced by phase 2.
    M = tl.load(global_m_ptr)
    L = tl.load(global_l_ptr)
    x = tl.load(input_ptr + offsets, mask=mask, other=-float("inf"))
    y = tl.exp(x - M) / L

    tl.store(output_ptr + offsets, y, mask=mask)

Now the kernel becomes embarrassingly parallel again. Every program reads the same two scalars, M and L, then writes one contiguous block of outputs.
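To see the three launches end to end without a GPU, here is a pure-Python sketch that loops over program ids the way each grid would. It is an emulation under stated assumptions, not the post's launcher. The helper `load_block` is hypothetical and only mimics `tl.load` with `mask=offsets < N`: out-of-range lanes receive the `other` value.

```python
import math

NEG_INF = -math.inf

def load_block(buf, pid, block_size, other):
    """Mimics tl.load(..., mask=offsets < N, other=other) for one program."""
    offsets = range(pid * block_size, (pid + 1) * block_size)
    return [buf[i] if i < len(buf) else other for i in offsets]

def softmax_three_launches(x, block_size):
    n = len(x)
    num_blocks = (n + block_size - 1) // block_size      # grid size for phases 1 and 3
    block_m = [0.0] * num_blocks
    block_l = [0.0] * num_blocks

    # Launch 1: every pid writes its own (m_b, l_b); padded lanes load -inf.
    for pid in range(num_blocks):
        xb = load_block(x, pid, block_size, NEG_INF)
        m = max(xb)
        block_m[pid] = m
        block_l[pid] = sum(math.exp(v - m) for v in xb)  # exp(-inf) == 0.0 for padding

    # Launch 2: a single program merges all block summaries.
    M = max(block_m)
    L = sum(l * math.exp(m - M) for m, l in zip(block_m, block_l))

    # Launch 3: every pid normalizes its block with the shared (M, L).
    out = [0.0] * n
    for pid in range(num_blocks):
        xb = load_block(x, pid, block_size, NEG_INF)
        for lane, v in enumerate(xb):
            i = pid * block_size + lane
            if i < n:                                    # the store mask
                out[i] = math.exp(v - M) / L
    return out

y = softmax_three_launches([1.0, 2.0, 3.0, 4.0, 5.0], block_size=2)
print(sum(y))
```

The outer `for pid` loops stand in for the grids. Each launch finishes before the next one starts, and that ordering is exactly the synchronization boundary the kernel split provides.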

BLOCK_SIZE determination

Round the phase-2 reduction block size up to the next power of two:

reduce_block = 1
while reduce_block < num_blocks:
    reduce_block <<= 1

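Triton reductions operate over a static block of lanes, and tl.arange only accepts a power-of-two range. The phase-2 block therefore has to be rounded up rather than sized exactly to num_blocks. The mask makes the padded lanes harmless: they load -inf for the max and 0.0 for the sum.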
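The doubling loop can also be written with integer bit length. The sketch below uses hypothetical names and checks that both forms agree for every n >= 1. The two forms differ only at n = 0, which cannot occur because num_blocks is at least 1. In host code, `triton.next_power_of_2`, as used in Triton's fused-softmax tutorial, serves the same purpose.

```python
def next_pow2_loop(n):
    """The doubling loop from the post."""
    p = 1
    while p < n:
        p <<= 1
    return p

def next_pow2_bits(n):
    """Same result for n >= 1 without a loop: 1 << ceil(log2(n))."""
    return 1 << (n - 1).bit_length()

num_blocks = 1000
reduce_block = next_pow2_bits(num_blocks)
print(reduce_block, reduce_block - num_blocks)  # block size, number of padded lanes
```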

The tradeoff is that phase 2 assumes all num_blocks = ceil(N / BLOCK_SIZE) block summaries fit into one Triton program. For very large N, this should become a tree reduction instead of a single reduction program.
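Here is a host-side sketch of that tree reduction. It is an assumption about how one might generalize phase 2, not code from the post, and `merge`, `tree_reduce`, and `fan_in` are hypothetical names. Each level would be one extra kernel launch in which every program merges at most `fan_in` summaries, so `fan_in` plays the role of the phase-2 BLOCK_SIZE.

```python
import math

def merge(a, b):
    """Merge two (m, l) summaries onto the larger baseline."""
    (m_a, l_a), (m_b, l_b) = a, b
    M = max(m_a, m_b)
    return M, l_a * math.exp(m_a - M) + l_b * math.exp(m_b - M)

def tree_reduce(stats, fan_in):
    """Merge groups of `fan_in` summaries level by level until one is left.

    Returns the final (M, L) and the number of levels (extra launches) needed.
    """
    if fan_in < 2:
        raise ValueError("fan_in must be at least 2")
    levels = 0
    while len(stats) > 1:
        next_level = []
        for start in range(0, len(stats), fan_in):
            group = stats[start:start + fan_in]   # what one program would own
            acc = group[0]
            for s in group[1:]:
                acc = merge(acc, s)
            next_level.append(acc)
        stats = next_level
        levels += 1
    return stats[0], levels

(M, L), levels = tree_reduce([(0.0, 1.0)] * 10_000, fan_in=1024)
print(M, L, levels)
```

With 10,000 block summaries and a fan-in of 1024, the first level leaves 10 summaries and the second leaves one, so this takes two merge levels.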

Why three kernels?

Q: Why not compute the global stats and output in one kernel?

Because there is no grid-wide synchronization inside a normal Triton kernel. Programs can synchronize within themselves, but independent programs cannot say “everyone finish phase 1, then everyone start phase 2” inside the same launch.

Kernel launches become the synchronization boundaries:

  1. after phase 1, all block_m and block_l values exist
  2. after phase 2, global_m and global_l exist
  3. phase 3 can safely normalize the whole vector

This is similar to the reduction post: the work is parallel locally, but global meaning only appears after a reduction boundary.

Cost Analysis

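This design reads the input twice: once in phase 1 and once in phase 3.

It also launches three kernels and writes small intermediate buffers. That sounds expensive, but it buys the missing global synchronization and keeps each memory access pattern simple:

  1. phase 1 reads N elements contiguously and writes two scalars per block
  2. phase 2 reads 2 * num_blocks scalars and writes two
  3. phase 3 reads N elements and writes N outputs, both contiguous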

For a single large vector softmax, this is a reasonable decomposition. The main limitation is launch overhead and the single-program phase-2 reduction. For many rows, the usual next step is to map one row or one row-split to each program and reduce per row, but the same invariant remains:

(m_a, l_a) + (m_b, l_b)
=> M = max(m_a, m_b)
=> L = l_a * exp(m_a - M) + l_b * exp(m_b - M)

That pair (m, l) is the useful mental model. Softmax across blocks is not “compute exponentials, then somehow sum them globally”; it is “carry enough shifted statistics that blocks can be merged without losing numerical stability.”
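Because this merge rule is associative and commutative (up to floating-point rounding), a "block" can be as small as one element, and summaries can be merged in any order. Folding the input one element at a time with it is the technique usually called online softmax. The sketch below is pure Python, and its names are hypothetical. It folds the same data forward and backward and checks that both orders reach the same (M, L).

```python
import math
import random

def merge(a, b):
    """(m_a, l_a) + (m_b, l_b) -> (M, L), the merge rule above."""
    (m_a, l_a), (m_b, l_b) = a, b
    M = max(m_a, m_b)
    return M, l_a * math.exp(m_a - M) + l_b * math.exp(m_b - M)

IDENTITY = (-math.inf, 0.0)  # merging with this leaves the other summary unchanged

def online_stats(xs):
    """Fold the input one element at a time; each element is the block (x_i, 1.0)."""
    acc = IDENTITY
    for v in xs:
        acc = merge(acc, (v, 1.0))
    return acc

random.seed(0)
xs = [random.uniform(-20.0, 20.0) for _ in range(1000)]
M_fwd, L_fwd = online_stats(xs)
M_rev, L_rev = online_stats(reversed(xs))
print(M_fwd == M_rev, abs(L_fwd - L_rev) / L_fwd)
```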

