Softmax
Softmax looks like a simple elementwise op, but it is actually a global reduction followed by an elementwise write:
y_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x))
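As a CPU reference for the formula (plain Python, not Triton; the name softmax_reference is only for illustration), subtracting max(x) keeps every exponent at or below zero, so exp never overflows:

```python
import math

def softmax_reference(x):
    """Numerically stable softmax of a list of floats, written straight from the formula."""
    m = max(x)                            # global max(x)
    exps = [math.exp(v - m) for v in x]   # every exponent is <= 0, so no overflow
    denom = sum(exps)                     # global denominator
    return [e / denom for e in exps]

# math.exp(1000.0) on its own raises OverflowError; after the shift the largest term is exp(0) = 1.
print(softmax_reference([1000.0, 1001.0, 1002.0]))
```

The shift cancels between numerator and denominator, so the output is the same as softmax of [0, 1, 2].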
The annoying part is that every output element needs the same global max(x) and the same global denominator. If N <= BLOCK_SIZE, one Triton program can own the whole vector and do the reduction locally. But once the vector is larger than a block, tl.max and tl.sum only solve the local problem.
That is why the kernel is split into three phases.
Phase 1: summarize each block
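```python
import triton
import triton.language as tl

@triton.jit
def block_stats_kernel(input_ptr, block_m_ptr, block_l_ptr, N, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    # Each program owns one contiguous chunk of BLOCK_SIZE elements
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    # Out-of-range lanes load -inf: they cannot raise the max and add exp(-inf) = 0 to the sum
    x = tl.load(input_ptr + offsets, mask=mask, other=-float("inf"))
    m = tl.max(x, axis=0)                  # local max m_b
    l = tl.sum(tl.exp(x - m), axis=0)      # local denominator l_b under baseline m_b
    tl.store(block_m_ptr + pid, m)
    tl.store(block_l_ptr + pid, l)
```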
Each program owns a contiguous chunk. It computes:
m_b = max_{i in block b} x_i
l_b = sum_{i in block b} exp(x_i - m_b)
The second value is not the final denominator yet; it is a denominator under the block’s own numerical baseline. Notice that we need to store both the max and the sum. The max alone is not enough. The sum alone is also not enough, because it was computed after shifting by a local max. m_b tells us how to translate that block’s denominator into the global baseline, and l_b tells us the amount of mass inside the block after the local shift. Together they are a compact sufficient statistic for softmax normalization.
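To make phase 1 concrete without a GPU, here is a CPU mirror in plain Python (block_stats is a hypothetical helper, and Python slicing stands in for the mask on the last, partial block). It also shows why l_b alone is not enough: two blocks can carry very different mass yet report the same l_b.

```python
import math

def block_stats(x, block_size):
    """CPU mirror of block_stats_kernel: one (m_b, l_b) pair per block."""
    block_m, block_l = [], []
    for start in range(0, len(x), block_size):
        block = x[start:start + block_size]        # the chunk one program owns
        m = max(block)                              # local max m_b
        l = sum(math.exp(v - m) for v in block)     # mass under the local baseline
        block_m.append(m)
        block_l.append(l)
    return block_m, block_l

x = [0.0, 1.0, 50.0, 51.0]
block_m, block_l = block_stats(x, block_size=2)
# Both blocks report l_b = 1 + exp(-1), although the second block holds far more mass;
# only m_b tells them apart.
print(block_m, block_l)
```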
Phase 2: merge block summaries
```python
@triton.jit
def reduce_stats_kernel(
    block_m_ptr, block_l_ptr, global_m_ptr, global_l_ptr, num_blocks, BLOCK_SIZE: tl.constexpr
):
    # A single program loads every block summary (requires BLOCK_SIZE >= num_blocks)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_blocks
    mb = tl.load(block_m_ptr + offsets, mask=mask, other=-float("inf"))
    lb = tl.load(block_l_ptr + offsets, mask=mask, other=0.0)
    M = tl.max(mb, axis=0)                       # global max
    L = tl.sum(lb * tl.exp(mb - M), axis=0)      # rescale each l_b to baseline M, then add
    tl.store(global_m_ptr, M)
    tl.store(global_l_ptr, L)
```
The key line is:
L = tl.sum(lb * tl.exp(mb - M), axis=0)
This rescales every block’s local denominator into the global denominator’s coordinate system.
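A CPU mirror of phase 2, under the same assumptions (plain Python, hypothetical helper name reduce_stats), fed with the summaries phase 1 would produce for x = [0, 1, 50, 51] split into blocks of two:

```python
import math

def reduce_stats(block_m, block_l):
    """CPU mirror of reduce_stats_kernel: merge all (m_b, l_b) pairs into (M, L)."""
    M = max(block_m)                                                  # global max
    L = sum(l * math.exp(m - M) for m, l in zip(block_m, block_l))   # rescale, then add
    return M, L

block_m = [1.0, 51.0]
block_l = [1.0 + math.exp(-1.0), 1.0 + math.exp(-1.0)]
M, L = reduce_stats(block_m, block_l)
# The first block is scaled by exp(1 - 51) = exp(-50), so it barely contributes.
print(M, L)
```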
Why does this work?
l_b = sum_{i in block b} exp(x_i - m_b)
l_b * exp(m_b - M)
= sum_{i in block b} exp(x_i - m_b) * exp(m_b - M)
= sum_{i in block b} exp(x_i - M)
So after summing over blocks:
L = sum_i exp(x_i - M)
This is basically the same log-sum-exp trick, just applied hierarchically.
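A quick numeric check of the derivation, on an arbitrary input and a hypothetical block size of 2: the hierarchical L matches the direct sum, and the pair also recovers log-sum-exp, since M + log(L) = log(sum_i exp(x_i)).

```python
import math

x = [3.0, -2.0, 7.5, 7.0, 0.25, 12.0]   # arbitrary test input
block_size = 2                            # hypothetical block size for the split

# Hierarchical: per-block (m_b, l_b), then rescale into the global baseline M
blocks = [x[i:i + block_size] for i in range(0, len(x), block_size)]
stats = [(max(b), sum(math.exp(v - max(b)) for v in b)) for b in blocks]
M = max(m for m, _ in stats)
L = sum(l * math.exp(m - M) for m, l in stats)

# Direct: one pass over the whole vector with the same global max
L_direct = sum(math.exp(v - M) for v in x)

# M + log(L) equals log(sum_i exp(x_i)), i.e. logsumexp(x)
print(L, L_direct, M + math.log(L))
```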
Phase 3: normalize every element
```python
@triton.jit
def softmax_kernel(input_ptr, output_ptr, global_m_ptr, global_l_ptr, N, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    # Every program reads the same two scalars produced by phase 2
    M = tl.load(global_m_ptr)
    L = tl.load(global_l_ptr)
    x = tl.load(input_ptr + offsets, mask=mask, other=-float("inf"))
    y = tl.exp(x - M) / L
    tl.store(output_ptr + offsets, y, mask=mask)
```
Now the kernel becomes embarrassingly parallel again. Every program reads the same two scalars, M and L, then writes one contiguous block of outputs.
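Putting the three phases together on CPU gives a reference to test the kernels against. This is a sketch in plain Python, not the Triton host code: each phase runs to completion before the next one starts, which is the role the kernel launch boundaries play on the GPU. The function names and the block size of 4 are illustrative.

```python
import math

def three_phase_softmax(x, block_size):
    """CPU sketch of the three launches; each phase finishes before the next starts."""
    # Phase 1: per-block summaries (m_b, l_b)
    blocks = [x[i:i + block_size] for i in range(0, len(x), block_size)]
    block_m = [max(b) for b in blocks]
    block_l = [sum(math.exp(v - m) for v in b) for b, m in zip(blocks, block_m)]
    # Phase 2: merge into the global pair (M, L)
    M = max(block_m)
    L = sum(l * math.exp(m - M) for m, l in zip(block_m, block_l))
    # Phase 3: every element only needs the two scalars M and L
    return [math.exp(v - M) / L for v in x]

def softmax_reference(x):
    """One-shot stable softmax to compare against."""
    m = max(x)
    exps = [math.exp(v - m) for v in x]
    s = sum(exps)
    return [e / s for e in exps]

x = [float(i % 7) - 3.0 for i in range(10)]   # 10 elements, so the last block is partial
y = three_phase_softmax(x, block_size=4)
print(max(abs(a - b) for a, b in zip(y, softmax_reference(x))))
```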
BLOCK_SIZE determination
Make the reduction block size the smallest power of two that is at least num_blocks:
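```python
# Smallest power of two >= num_blocks
# (for num_blocks >= 1 this matches triton.next_power_of_2(num_blocks))
reduce_block = 1
while reduce_block < num_blocks:
    reduce_block <<= 1
```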
Triton reductions operate over a static block of lanes: BLOCK_SIZE is a tl.constexpr, and tl.arange requires its range to be a power of two. Rounding up leaves some lanes past num_blocks, and the mask makes those padded lanes harmless:
- padded mb lanes load -inf, so they cannot affect max
- padded lb lanes load 0.0, so they cannot affect sum
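The padding argument can be checked directly. This sketch (hypothetical helper merge_padded, made-up block summaries) reuses the rounding loop above and pads the phase-2 inputs the way the masked loads do; the merged result is identical with or without the padded lanes.

```python
import math

def merge_padded(block_m, block_l, reduce_block):
    """Mimic reduce_stats_kernel's masked loads by padding the inputs to reduce_block lanes."""
    pad = reduce_block - len(block_m)
    mb = block_m + [-math.inf] * pad     # masked lanes: other=-float("inf")
    lb = block_l + [0.0] * pad           # masked lanes: other=0.0
    M = max(mb)
    # Padded terms are 0.0 * exp(-inf) = 0.0 * 0.0 = 0.0
    L = sum(l * math.exp(m - M) for m, l in zip(mb, lb))
    return M, L

block_m = [2.0, 5.0, 1.0]     # three real block summaries (made-up values)
block_l = [1.5, 2.0, 3.25]

reduce_block = 1              # same rounding loop as above
while reduce_block < len(block_m):
    reduce_block <<= 1        # ends at 4

print(merge_padded(block_m, block_l, len(block_m)))   # no padding
print(merge_padded(block_m, block_l, reduce_block))   # one padded lane, same result
```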
The tradeoff is that phase 2 assumes all block summaries fit into one Triton program. For very large N, this should become a tree reduction instead of a single reduction program.
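One possible shape for that tree reduction, sketched on CPU: repeatedly merge groups of fan_in neighbouring (m, l) pairs until one pair remains. On a GPU each pass would be its own launch, for example a variant of reduce_stats_kernel that merges one group per program. merge, tree_reduce and fan_in are hypothetical names; the merge rule is the same rescaling phase 2 uses.

```python
import math

def merge(a, b):
    """Combine two (m, l) summaries with the same rescaling phase 2 uses."""
    (m_a, l_a), (m_b, l_b) = a, b
    M = max(m_a, m_b)
    return M, l_a * math.exp(m_a - M) + l_b * math.exp(m_b - M)

def tree_reduce(stats, fan_in):
    """Merge (m, l) pairs in passes; each pass stands in for one kernel launch
    in which every program merges up to fan_in neighbouring summaries."""
    if fan_in < 2:
        raise ValueError("fan_in must be at least 2")
    while len(stats) > 1:
        next_stats = []
        for start in range(0, len(stats), fan_in):
            group = stats[start:start + fan_in]   # the summaries one program would own
            acc = group[0]
            for s in group[1:]:
                acc = merge(acc, s)
            next_stats.append(acc)
        stats = next_stats
    return stats[0]

x = [20.0 * math.sin(0.1 * i) for i in range(1000)]   # arbitrary test input
stats = [(v, 1.0) for v in x]    # a one-element block has m = x_i and l = exp(0) = 1
M, L = tree_reduce(stats, fan_in=8)                    # 1000 -> 125 -> 16 -> 2 -> 1 pairs
print(M, L)
```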
Why three kernels?
Q: Why not compute the global stats and output in one kernel?
Because there is no grid-wide synchronization inside a normal Triton kernel. Programs can synchronize within themselves, but independent programs cannot say “everyone finish phase 1, then everyone start phase 2” inside the same launch.
Kernel launches become the synchronization boundaries:
- after phase 1, all block_m and block_l values exist
- after phase 2, global_m and global_l exist
- phase 3 can safely normalize the whole vector
This is similar to the reduction post: the work is parallel locally, but global meaning only appears after a reduction boundary.
Cost Analysis
This design reads the input twice:
- once to compute block summaries
- once to write normalized outputs
It also launches three kernels and writes small intermediate buffers. That sounds expensive, but it buys the missing global synchronization and keeps each memory access pattern simple:
- phase 1: contiguous reads, tiny writes
- phase 2: tiny reads, two scalar writes
- phase 3: contiguous reads and contiguous writes
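A back-of-the-envelope count of global-memory traffic makes the cost concrete. This sketch assumes fp32 inputs and statistics and ignores caching (so the repeated reads of M and L are overcounted); softmax_traffic_bytes is a hypothetical helper. For N = 2^20 and BLOCK_SIZE = 1024, the three phases move about 1.5x the bytes of an ideal single read plus single write, almost all of it from reading x twice.

```python
def softmax_traffic_bytes(n, block_size, elem_bytes=4, stat_bytes=4):
    """Rough global-memory bytes per phase, assuming fp32 data and no cache reuse."""
    num_blocks = -(-n // block_size)                           # ceil(n / block_size) programs
    phase1 = n * elem_bytes + 2 * num_blocks * stat_bytes      # read x; write m_b and l_b
    phase2 = 2 * num_blocks * stat_bytes + 2 * stat_bytes      # read m_b and l_b; write M and L
    phase3 = 2 * n * elem_bytes + 2 * num_blocks * stat_bytes  # read x, write y; each program reads M and L
    return phase1, phase2, phase3

p1, p2, p3 = softmax_traffic_bytes(n=1 << 20, block_size=1024)
total = p1 + p2 + p3
ideal = 2 * (1 << 20) * 4      # read x once, write y once
print(p1, p2, p3, total, total / ideal)
```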
For softmax over a single large vector, this is a reasonable decomposition. The main limitations are launch overhead and the single-program phase-2 reduction. For many rows, the usual next step is to map one row or one row-split to each program and reduce per row, but the same invariant remains:
(m_a, l_a) + (m_b, l_b)
=> M = max(m_a, m_b)
=> L = l_a * exp(m_a - M) + l_b * exp(m_b - M)
That pair (m, l) is the useful mental model. Softmax across blocks is not “compute exponentials, then somehow sum them globally”; it is “carry enough shifted statistics that blocks can be merged without losing numerical stability.”
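Because the merge has an identity, (-inf, 0), and its result does not depend on grouping or order (up to floating-point rounding), a single program can also walk a long row chunk by chunk and carry only (m, l). That is one way the per-row variant could work; the sketch below is plain Python with hypothetical names and a chunk width of 8.

```python
import math
from functools import reduce

IDENTITY = (-math.inf, 0.0)   # empty summary: merging with it changes nothing

def merge(a, b):
    """(m_a, l_a) + (m_b, l_b) -> (M, L), the rule above.
    Assumes at least one side has a finite m (otherwise -inf - -inf is nan)."""
    (m_a, l_a), (m_b, l_b) = a, b
    M = max(m_a, m_b)
    return M, l_a * math.exp(m_a - M) + l_b * math.exp(m_b - M)

def chunk_stats(chunk):
    """(m, l) for one chunk, as phase 1 computes per block."""
    m = max(chunk)
    return m, sum(math.exp(v - m) for v in chunk)

row = [0.5 * ((7 * i) % 13) - 2.0 for i in range(37)]   # arbitrary row
chunks = [row[i:i + 8] for i in range(0, len(row), 8)]  # hypothetical chunk width 8

# Carry only (m, l) while walking the row, in either direction
forward = reduce(merge, map(chunk_stats, chunks), IDENTITY)
backward = reduce(merge, map(chunk_stats, reversed(chunks)), IDENTITY)
print(forward, backward)
```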