Softmax Attention
Today I wrote a Triton kernel for attention:
out = softmax(Q @ K.T / sqrt(d)) @ V
This looks like two matrix multiplications with a softmax in the middle. The tricky part is the softmax.
For one query row, attention does:
scores = Q[row] @ K.T
weights = softmax(scores)
out[row] = weights @ V
Softmax is blocking. To compute even one output weight, we need the max and the denominator for the entire row:
softmax(x_i) = exp(x_i - max(x)) / sum_j exp(x_j - max(x))
So we cannot look at one tile of keys, normalize it, and move on. That would produce a local softmax for the tile, not the real softmax over all keys. Each query row has to see every key before its output is final.
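To make the "local softmax" problem concrete, here is a minimal sketch in plain Python (no Triton, made-up scores, and a helper name `softmax` that is mine). It normalizes each tile on its own and compares that with the softmax over the whole row:

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

scores = [1.0, 2.0, 3.0, 4.0]        # one query row against four keys
tiles = [scores[0:2], scores[2:4]]   # two key tiles of width 2

global_weights = softmax(scores)
# Wrong: normalize each tile independently, then concatenate.
local_weights = [w for tile in tiles for w in softmax(tile)]

print(sum(global_weights))  # approximately 1: one distribution over the row
print(sum(local_weights))   # approximately 2: one unit of mass per tile
```

The per-tile weights are each valid distributions, but only over their own tile, so they cannot be used as attention weights for the full row.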
                      one query row

K block 0      K block 1      K block 2    ...    K block n
---------      ---------      ---------           ---------
 scores         scores         scores              scores
    \              |              |                  /
     \             |              |                 /
      +------------+--------------+----------------+
                           |
                  one row-wide softmax
                           |
                    final output row
The kernel handles this by streaming across K and V blocks. It never stores the full M x N score matrix. Instead it keeps a running max, a running sum, and a running output accumulator for each query row.
Program layout
row_pid = tl.program_id(0)
col_pid = tl.program_id(1)
m_indexes = row_pid * M_BLOCK + tl.arange(0, M_BLOCK)
v_indexes = col_pid * d_BLOCK + tl.arange(0, d_BLOCK)
The launch grid is:
grid = (triton.cdiv(M, M_BLOCK), triton.cdiv(d, d_BLOCK))
So each Triton program owns:
- M_BLOCK query rows
- d_BLOCK output columns
In this version:
M_BLOCK = 16
N_BLOCK = 64
d_BLOCK = 128
That means one program computes a [16, 128] tile of the final output.
The first grid axis is parallel over query rows. The second grid axis is parallel over output feature columns. Different programs can run independently because they write different output tiles.
output O: [M, d]

                         d dimension
            col_pid=0      col_pid=1      col_pid=2
        +--------------+--------------+--------------+
row 0   | program      | program      | program      |
block   | row_pid=0    | row_pid=0    | row_pid=0    |
        +--------------+--------------+--------------+
row 1   | program      | program      | program      |
block   | row_pid=1    | row_pid=1    | row_pid=1    |
        +--------------+--------------+--------------+
row 2   | program      | program      | program      |
block   | row_pid=2    | row_pid=2    | row_pid=2    |
        +--------------+--------------+--------------+
Each box writes one [M_BLOCK, d_BLOCK] output tile.
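As a quick sanity check on the launch grid, here is a small sketch in plain Python. The helper `cdiv` mirrors the ceiling division that triton.cdiv performs, and the sizes M = 1000, d = 128 are example values I picked, not something the kernel fixes:

```python
def cdiv(a, b):
    """Ceiling division, the same rounding triton.cdiv performs."""
    return (a + b - 1) // b

M_BLOCK, d_BLOCK = 16, 128
M, d = 1000, 128  # example sizes; M is deliberately not a multiple of M_BLOCK

grid = (cdiv(M, M_BLOCK), cdiv(d, d_BLOCK))
programs = grid[0] * grid[1]

# Rows covered by the last row block that actually exist in Q.
valid_rows_in_last_block = M - (grid[0] - 1) * M_BLOCK

print(grid)                      # (63, 1)
print(programs)                  # 63
print(valid_rows_in_last_block)  # 8
```

The last row block only has 8 real rows out of 16, which is why the kernel needs a row mask on loads and on the final store.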
The state we carry
running_sum = tl.zeros([M_BLOCK], dtype=tl.float32)
running_max = tl.full([M_BLOCK], float("-inf"), dtype=tl.float32)
out_acc = tl.zeros([M_BLOCK, d_BLOCK], dtype=tl.float32)
scale = 1 / tl.sqrt(d + 0.0)
For every query row inside the block, we keep:
- running_max: the largest score seen so far
- running_sum: the softmax denominator under that max
- out_acc: the unnormalized attention output under that max
After some prefix of keys has been processed, the invariant is:
running_sum[i] = sum_j exp(score[i, j] - running_max[i])
out_acc[i, :] = sum_j exp(score[i, j] - running_max[i]) * V[j, :]
This is why the running sum is useful. I would not say it “amortizes the compute” exactly. We still have to compute every QK score. What it amortizes is the normalization work and memory traffic across the scan. We do not need to store all scores, launch a separate softmax, then read them again. The denominator is built incrementally as the row streams through the kernel.
after block 0:
running_max = max(scores_0)
running_sum = sum(exp(scores_0 - running_max))
after block 1:
running_max = max(old_running_max, max(scores_1))
running_sum = rescaled_old_sum + sum(exp(scores_1 - running_max))
after block 2:
same update again
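The block-by-block update above can be checked with a few lines of plain Python. This is only a sketch for a single row with made-up scores; the function name `streaming_denominator` is mine and does not appear in the kernel:

```python
import math

def streaming_denominator(scores, block):
    """Scan scores block by block, keeping only a running max and sum."""
    running_max = float("-inf")
    running_sum = 0.0
    for start in range(0, len(scores), block):
        tile = scores[start:start + block]
        new_max = max(running_max, max(tile))
        # Rescale the old sum to the new max, then add this tile's terms.
        alpha = math.exp(running_max - new_max)
        running_sum = running_sum * alpha + sum(math.exp(s - new_max) for s in tile)
        running_max = new_max
    return running_max, running_sum

scores = [0.5, 2.0, -1.0, 3.0, 1.5, 0.0]
stream_max, stream_sum = streaming_denominator(scores, block=2)

# Direct computation over the full row, for comparison.
direct_max = max(scores)
direct_sum = sum(math.exp(s - direct_max) for s in scores)
```

On the first block the old max is -inf, so alpha is exp(-inf) = 0 and the empty old sum drops out on its own. No special case is needed for the first iteration.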
Computing a block of scores
for nn in tl.range(0, tl.cdiv(N, N_BLOCK)):
    n_indexes = nn * N_BLOCK + tl.arange(0, N_BLOCK)
    n_mask = n_indexes < N
    acc = tl.zeros((M_BLOCK, N_BLOCK), dtype=tl.float32)
The outer loop walks through all key blocks. Within one program this loop is serial: each iteration updates the running statistics left by the previous one, and a row's softmax is not final until every key block has been visited.
Inside each key block, we compute:
acc = Q_block @ K_block.T
That dot product is itself tiled over the feature dimension:
for dd in tl.range(0, tl.cdiv(d, d_BLOCK)):
    d_indexes = dd * d_BLOCK + tl.arange(0, d_BLOCK)
    cur_q = tl.load(Q + M_offsets, mask=M_mask, other=0.0)
    cur_k = tl.load(K + K_offsets, mask=K_mask, other=0.0)
    acc += tl.dot(cur_q, tl.trans(cur_k)) * scale
This produces a score tile:
acc.shape == [M_BLOCK, N_BLOCK]
The tl.dot is the parallel part. The lanes in the program cooperate to compute the block matmul, and Triton can lower the dot to efficient GPU instructions. On NVIDIA GPUs, compatible shapes and dtypes can use tensor-core style matrix multiply instructions. The Python-looking block operation is not one scalar thread doing all the work.
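One detail in the inner loop is that scale is applied to every partial product instead of once at the end. Since scaling distributes over the sum, both give the same score. A minimal scalar sketch in plain Python, with made-up vectors and a hypothetical tile width d_block that does not divide d evenly:

```python
def dot(a, b):
    """Plain dot product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))

q = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
k = [0.5, -1.0, 2.0, 0.0, 1.0, -0.5]
d = len(q)
d_block = 4  # hypothetical tile width; the last tile is only 2 wide
scale = 1.0 / d ** 0.5

acc = 0.0
for start in range(0, d, d_block):
    # The short final slice stands in for the masked load: the padded
    # zeros from other=0.0 would add nothing to the partial product.
    acc += dot(q[start:start + d_block], k[start:start + d_block]) * scale

full = dot(q, k) * scale
```

Here acc and full agree up to floating-point rounding, so tiling the feature dimension does not change the score.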
Are these loops unrolled?
The short answer: not automatically in the general case.
M_BLOCK, N_BLOCK, and d_BLOCK are tl.constexpr, so the shapes inside the program are compile-time constants. That helps Triton generate fixed-size vector operations for loads, reductions, and dot products.
But the loop bounds use runtime values:
tl.cdiv(N, N_BLOCK)
tl.cdiv(d, d_BLOCK)
Since N and d are normal runtime arguments, Triton cannot fully unroll these loops for all possible inputs. They are real loops in the generated program. If N or d were made compile-time constants, or if the loop bounds were otherwise known at compile time, the compiler would have more room to unroll them.
So the useful distinction is:
- the block math has static shapes and is compiled into vectorized GPU code
- the scan over key blocks is still a loop
- the scan over the inner d tiles is also still a loop when d is runtime
Masking before softmax
acc = tl.where(m_mask[:, None] & n_mask[None, :], acc, float("-inf"))
For matmul, padded values can be zero. For softmax, padded scores must be -inf.
If an invalid key stayed as 0.0, it could affect the row max or add exp(0) to the denominator. That would change the distribution. With -inf, the padded score contributes:
exp(-inf) = 0
So it disappears.
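Here is a small plain-Python sketch of the difference, using made-up scores for a row with two real keys and two padded slots. The helper name `denominator` is mine:

```python
import math

def denominator(scores):
    """Softmax denominator relative to the row max."""
    m = max(scores)
    return sum(math.exp(s - m) for s in scores)

valid = [-2.0, -1.0]                          # two real keys
padded_zero = valid + [0.0, 0.0]              # padding left as 0.0
padded_inf = valid + [float("-inf")] * 2      # padding masked to -inf

d_valid = denominator(valid)
d_zero = denominator(padded_zero)
d_inf = denominator(padded_inf)

print(max(padded_zero))  # 0.0: the padding became the row max
```

With 0.0 padding, the padded slots both take over the row max and add two extra exp(0) terms, so d_zero no longer matches d_valid. With -inf padding, d_inf matches d_valid.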
Online softmax
After computing one score block, the kernel merges it into the row statistics:
current_max = tl.max(acc, axis=-1)
new_running_max = tl.maximum(current_max, running_max)
alpha = tl.exp(running_max - new_running_max)
weights = tl.exp(acc - new_running_max[:, None])
incre = tl.sum(weights, axis=1)
running_sum = tl.fma(running_sum, alpha, incre)
running_max = new_running_max
This is the log-sum-exp trick in streaming form.
Suppose the old max was m_old, and the new max after seeing this block is m_new. The old denominator was measured relative to m_old:
sum_old = sum(exp(score_old - m_old))
To merge it with the new block, we convert it into the new coordinate system:
sum_old * exp(m_old - m_new)
That is alpha.
Then we add the new block’s contribution:
sum(exp(score_new - m_new))
So the merged denominator is:
running_sum = old_sum * exp(old_max - new_max) + new_sum
old state, measured at old_max
old_sum
old_out_acc
|
| multiply by alpha = exp(old_max - new_max)
v
old state, measured at new_max
old_sum * alpha
old_out_acc * alpha
|
| add current block contribution
v
new state
running_sum
out_acc
This is the reason the kernel can stream across N. Softmax still needs the whole row, but it does not need to keep the whole row in memory.
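Written out, the rescaling step is just factoring a constant out of a sum. In the notation below (mine, chosen to match the variables above), s_j are the scores already seen, and m_old and m_new are the row max before and after the current block:

```latex
\begin{aligned}
\sum_{j \in \text{old}} e^{s_j - m_{\text{new}}}
  &= \sum_{j \in \text{old}} e^{s_j - m_{\text{old}}} \, e^{m_{\text{old}} - m_{\text{new}}} \\
  &= \alpha \cdot \text{sum}_{\text{old}},
  \qquad \alpha = e^{m_{\text{old}} - m_{\text{new}}}
\end{aligned}
```

Because m_new is never smaller than m_old, alpha is at most 1, so the rescale can shrink the old state but never overflow it.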
Multiplying by V in the same pass
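The kernel also immediately applies the current unnormalized weights to V. Masked V entries are loaded as 0.0 so that padded keys and padded columns contribute nothing to the dot product:
cur_v = tl.load(V + v_offsets, mask=V_mask, other=0.0)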
weighted_v = tl.dot(weights, cur_v)
out_acc = tl.fma(out_acc, alpha[:, None], weighted_v)
This is the same idea as the denominator update.
If the row max changes, the old output accumulator also has to be rescaled:
old_out_acc * exp(old_max - new_max)
Then we add the contribution from the new key block:
weights @ V_block
This keeps running_sum and out_acc in the same coordinate system. At the end, we can normalize once:
out_acc /= running_sum[:, None]
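To check that the whole update (max, sum, and output accumulator) matches ordinary attention, here is a single-row sketch in plain Python. It follows the same update order as the kernel snippets above, but the function names, the scores, and the V values are made up for illustration:

```python
import math

def attention_row_reference(scores, V):
    """Ordinary softmax(scores) @ V for one query row."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    cols = len(V[0])
    return [sum(w[j] * V[j][c] for j in range(len(V))) / total for c in range(cols)]

def attention_row_streaming(scores, V, block):
    """Same result, scanning keys block by block with running statistics."""
    cols = len(V[0])
    running_max = float("-inf")
    running_sum = 0.0
    out_acc = [0.0] * cols
    for start in range(0, len(scores), block):
        tile = scores[start:start + block]
        v_tile = V[start:start + block]
        new_max = max(running_max, max(tile))
        alpha = math.exp(running_max - new_max)
        weights = [math.exp(s - new_max) for s in tile]
        # Rescale both pieces of old state by alpha, then add this block.
        running_sum = running_sum * alpha + sum(weights)
        out_acc = [
            out_acc[c] * alpha + sum(w * v[c] for w, v in zip(weights, v_tile))
            for c in range(cols)
        ]
        running_max = new_max
    # Normalize once at the end.
    return [x / running_sum for x in out_acc]

scores = [0.2, 1.5, -0.7, 2.1, 0.9]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [0.5, 0.5], [-1.0, 3.0]]

ref = attention_row_reference(scores, V)
out = attention_row_streaming(scores, V, block=2)
```

The two results agree up to floating-point rounding, even though the streaming version never holds more than one block of weights at a time.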
What is parallel and what is serial?
There are a few layers here.
Across the grid, programs are parallel:
- different row_pid values compute different query rows
- different col_pid values compute different output feature columns
- these programs can run on different SMs at the same time
Inside one program, block operations are parallel:
- tl.load loads a block of elements
- tl.max reduces across a block row
- tl.sum reduces across a block row
- tl.dot computes a matrix multiply tile using many lanes/instructions
Inside one program, the key scan is serial:
for nn in tl.range(0, tl.cdiv(N, N_BLOCK)):
That loop has to visit key blocks in sequence because each iteration updates running_max, running_sum, and out_acc.
Under the hood, a Triton program maps to GPU execution resources roughly like a CUDA thread block. Many programs are launched. The GPU scheduler assigns them to SMs. Each program runs with warps. When one warp is waiting on memory, the SM can run another ready warp. This is how latency is hidden.
So “serial” here does not mean the whole GPU is idle. It means one output tile has a loop-carried dependency over key blocks. The GPU still runs many other output tiles at the same time.
GPU view
SM 0: program A -> scans K blocks 0, 1, 2, ...
SM 1: program B -> scans K blocks 0, 1, 2, ...
SM 2: program C -> scans K blocks 0, 1, 2, ...
SM 3: program D -> scans K blocks 0, 1, 2, ...
Inside each program, the scan over K blocks is ordered.
Across programs, many output tiles run at the same time.
Why fuse instead of using GEMM + softmax + GEMM?
The unfused version is conceptually clean:
S = Q @ K.T
P = softmax(S)
O = P @ V
If we use optimized kernels, both GEMMs can be very fast. The problem is the materialized attention matrix S or P.
For shapes:
Q: [M, d]
K: [N, d]
V: [N, d]
S/P: [M, N]
O: [M, d]
The separate-kernel path has to move a lot of intermediate data through HBM:
- GEMM reads Q and K, then writes S (M * N elements).
- Softmax reads S, writes P, and often needs multiple passes over each row for max, sum, and normalize.
- The second GEMM reads P and V, then writes O.
unfused path

Q, K -> GEMM -> S in HBM -> softmax -> P in HBM -> GEMM with V -> O
                   ^                      ^
                   |                      |
           huge intermediate      huge intermediate

fused path

Q block, K block -> score tile -> online softmax -> apply V block -> O tile
                        |
                        +-- stays temporary, then gets discarded
Even if the GEMMs are optimized, the M x N matrix is huge. For example, if:
M = 4096
N = 4096
d = 128
Then S has:
4096 * 4096 = 16,777,216 elements
In fp16, that is about 32 MB for one matrix. Writing scores, reading scores, writing probabilities, and reading probabilities is already around 128 MB of intermediate traffic, before counting the actual Q, K, V, and O traffic. If softmax does more than one pass over S, the traffic grows again.
| Step | Unfused intermediate traffic for S/P | Fused intermediate traffic for S/P |
|---|---|---|
| write scores S | M * N elements | 0 global writes |
| read scores for softmax | M * N elements | 0 global reads |
| write probabilities P | M * N elements | 0 global writes |
| read probabilities for P @ V | M * N elements | 0 global reads |
For the 4096 x 4096 fp16 example, each full matrix pass is about 32 MB. Four passes is about 128 MB of extra HBM traffic.
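The arithmetic behind those numbers, as a short plain-Python sketch. It assumes 2 bytes per fp16 element and exactly the four passes listed in the table, and it reports sizes in MiB (2^20 bytes), which is what the "about 32 MB" figure corresponds to:

```python
M, N = 4096, 4096
bytes_per_element = 2   # fp16
passes = 4              # write S, read S, write P, read P

elements = M * N
one_pass_bytes = elements * bytes_per_element
total_bytes = passes * one_pass_bytes

one_pass_mib = one_pass_bytes / 2**20
total_mib = total_bytes / 2**20

print(elements)      # 16777216
print(one_pass_mib)  # 32.0
print(total_mib)     # 128.0
```

A softmax that makes separate passes for max, sum, and normalize would raise the `passes` count and the total with it.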
The fused kernel avoids writing S or P to global memory. It computes a score tile, updates the online softmax state, applies the weights to V, and discards the score tile.
Runtime-wise, this changes the bottleneck:
- separate kernels pay for fast GEMM math plus large global-memory traffic for S and P
- the fused kernel does more complicated per-tile bookkeeping, but removes the largest intermediate reads and writes
This is especially important because attention is often memory-bound around the softmax boundary. The score matrix is too large to keep on chip, so storing it forces HBM traffic and kernel launch boundaries. Fusing keeps the temporary score tile in registers/on-chip memory for the short time it is needed.
There is one caveat in my current implementation. The grid is split over output columns with col_pid, so if d > d_BLOCK, different column programs recompute the same QK.T scores for different V column tiles. For d <= 128, this is fine because there is only one output-column tile. For larger d, this wastes compute. A more tuned implementation would choose the tiling more carefully or use a different strategy to avoid recomputing attention scores too much.
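A rough way to see the cost of that caveat is to count how many column programs share each row block, since every one of them recomputes the same score tiles. A plain-Python sketch, where `qk_recompute_factor` is a name I made up and d_block = 128 matches this version of the kernel:

```python
def cdiv(a, b):
    """Ceiling division, the same rounding triton.cdiv performs."""
    return (a + b - 1) // b

def qk_recompute_factor(d, d_block=128):
    """Column programs per row block; each one recomputes the same QK scores."""
    return cdiv(d, d_block)

factors = {d: qk_recompute_factor(d) for d in (64, 128, 256, 512)}
print(factors)  # {64: 1, 128: 1, 256: 2, 512: 4}
```

So at d = 512 the score computation is done four times per row block with this grid layout.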
Still, the main benefit is already clear: the fused kernel trades a big materialized intermediate for a streaming online softmax. That is the core idea behind FlashAttention-style kernels.
Complete Code
import torch
import triton
import triton.language as tl


@triton.jit
def softmax_attention(
    Q,
    K,
    V,
    M,
    N,
    d,
    out,
    Q_stride_M,
    Q_stride_d,
    K_stride_N,
    K_stride_d,
    V_stride_N,
    V_stride_d,
    out_stride_M,
    out_stride_d,
    M_BLOCK: tl.constexpr,
    N_BLOCK: tl.constexpr,
    d_BLOCK: tl.constexpr,
):
    # each program owns one [M_BLOCK, d_BLOCK] tile of the output
    row_pid = tl.program_id(0)
    col_pid = tl.program_id(1)
    m_indexes = row_pid * M_BLOCK + tl.arange(0, M_BLOCK)
    # v_indexes selects output columns for the final multiplication with V, not the QK gemm
    v_indexes = col_pid * d_BLOCK + tl.arange(0, d_BLOCK)
    m_mask = m_indexes < M
    v_mask = v_indexes < d

    running_sum = tl.zeros([M_BLOCK], dtype=tl.float32)
    running_max = tl.full([M_BLOCK], float("-inf"), dtype=tl.float32)
    out_acc = tl.zeros([M_BLOCK, d_BLOCK], dtype=tl.float32)
    scale = 1 / tl.sqrt(d + 0.0)

    # scan over all key blocks: softmax needs the complete row
    # N is a runtime value, so this stays a real loop rather than being fully unrolled
    for nn in tl.range(0, tl.cdiv(N, N_BLOCK)):
        n_indexes = nn * N_BLOCK + tl.arange(0, N_BLOCK)
        n_mask = n_indexes < N
        # block gemm accumulator
        acc = tl.zeros((M_BLOCK, N_BLOCK), dtype=tl.float32)

        # dd tiles the QK inner (feature) dimension
        for dd in tl.range(0, tl.cdiv(d, d_BLOCK)):
            d_indexes = dd * d_BLOCK + tl.arange(0, d_BLOCK)
            d_mask = d_indexes < d
            M_offsets = (
                Q_stride_M * m_indexes[:, None] + Q_stride_d * d_indexes[None, :]
            )
            # K is loaded as [N_BLOCK, d_BLOCK] and transposed on the fly in tl.dot
            # this might happen to be efficient since Q and K are both row major right now
            K_offsets = (
                K_stride_N * n_indexes[:, None] + K_stride_d * d_indexes[None, :]
            )
            M_mask = m_mask[:, None] & d_mask[None, :]
            K_mask = n_mask[:, None] & d_mask[None, :]

            # load, padding out-of-range elements with 0.0 so they add nothing to the dot
            cur_q = tl.load(Q + M_offsets, mask=M_mask, other=0.0)
            cur_k = tl.load(K + K_offsets, mask=K_mask, other=0.0)

            # partial sum, with the 1/sqrt(d) scaling applied on the fly
            acc += tl.dot(cur_q, tl.trans(cur_k)) * scale

        # acc is an M_BLOCK * N_BLOCK gemm result
        # mask out-of-range rows and keys to -inf so they add exp(-inf) = 0 to the sums
        # (fully padded rows end up with NaN statistics, but they are never stored)
        acc = tl.where(m_mask[:, None] & n_mask[None, :], acc, float("-inf"))

        # max across the key axis, giving one maximum per query row (M_BLOCK values)
        current_max = tl.max(acc, axis=-1)
        # compare with the running max and update the running_max
        new_running_max = tl.maximum(current_max, running_max)
        # alpha rescales the old running_sum and out_acc to the new max
        alpha = tl.exp(running_max - new_running_max)

        # compute increment to running_sum on this N_BLOCK
        weights = tl.exp(acc - new_running_max[:, None])
        # directly sum because all the terms are relative to new_running_max
        incre = tl.sum(weights, axis=1)
        # running_sum = running_sum * alpha + incre
        running_sum = tl.fma(running_sum, alpha, incre)
        running_max = new_running_max

        # gemm with V.
        # N is now the inner dimension: (M_BLOCK, N_BLOCK) @ (N_BLOCK, d_BLOCK)
        v_offsets = n_indexes[:, None] * V_stride_N + v_indexes[None, :] * V_stride_d
        V_mask = n_mask[:, None] & v_mask[None, :]
        # other=0.0 so that masked-out V entries cannot leak into the dot
        cur_v = tl.load(V + v_offsets, mask=V_mask, other=0.0)

        # dot the weights now, will normalize with running_sum later
        weighted_v = tl.dot(weights, cur_v)
        # rescale the old accumulator to the new max, then add this block
        out_acc = tl.fma(out_acc, alpha[:, None], weighted_v)

    # the running sum is complete now, so normalize once
    out_acc /= running_sum[:, None]

    out_offsets = m_indexes[:, None] * out_stride_M + v_indexes[None, :] * out_stride_d
    out_mask = m_mask[:, None] & v_mask[None, :]
    tl.store(out + out_offsets, out_acc, mask=out_mask)


# Q, K, V, output are tensors on the GPU
def solve(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    output: torch.Tensor,
    M: int,
    N: int,
    d: int,
):
    M_BLOCK = 16
    d_BLOCK = 128
    N_BLOCK = 64

    grid = (triton.cdiv(M, M_BLOCK), triton.cdiv(d, d_BLOCK))

    Q_stride_M, Q_stride_d = Q.stride()
    K_stride_N, K_stride_d = K.stride()
    V_stride_N, V_stride_d = V.stride()
    out_stride_M, out_stride_d = output.stride()

    softmax_attention[grid](
        Q=Q,
        K=K,
        V=V,
        out=output,
        M=M,
        N=N,
        d=d,
        Q_stride_M=Q_stride_M,
        Q_stride_d=Q_stride_d,
        K_stride_N=K_stride_N,
        K_stride_d=K_stride_d,
        V_stride_N=V_stride_N,
        V_stride_d=V_stride_d,
        out_stride_M=out_stride_M,
        out_stride_d=out_stride_d,
        M_BLOCK=M_BLOCK,
        d_BLOCK=d_BLOCK,
        N_BLOCK=N_BLOCK,
    )