[TIL] Collected: Softmax Attention

Softmax Attention

Today I wrote a Triton kernel for attention:

out = softmax(Q @ K.T / sqrt(d)) @ V

This looks like two matrix multiplications with a softmax in the middle. The tricky part is the softmax.

For one query row, attention does:

scores = Q[row] @ K.T
weights = softmax(scores)
out[row] = weights @ V

Softmax is blocking. To compute even one output weight, we need the max and the denominator for the entire row:

softmax(x_i) = exp(x_i - max(x)) / sum_j exp(x_j - max(x))

So we cannot look at one tile of keys, normalize it, and move on. That would produce a local softmax for the tile, not the real softmax over all keys. Each query row has to see every key before its output is final.
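
To see the difference in numbers, here is a minimal pure-Python sketch (not part of the kernel; `softmax` is a helper I define just for this illustration). It normalizes a four-key row once as a whole and once tile by tile:

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

scores = [1.0, 2.0, 3.0, 4.0]   # one query row against four keys

# Correct: one softmax over the whole row.
global_weights = softmax(scores)

# Wrong: normalize each tile of two keys on its own.
tile_weights = softmax(scores[:2]) + softmax(scores[2:])

print(sum(global_weights))  # sums to 1 (up to rounding)
print(sum(tile_weights))    # sums to 2, because each tile was normalized separately
```

The per-tile version gives the first key far too much weight, because it never learned that later keys have larger scores.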

one query row

K block 0      K block 1      K block 2      ...      K block n
---------      ---------      ---------               ---------
 scores         scores         scores                  scores
    \              |              |                       /
     \             |              |                      /
      +------------+--------------+----------------------+
                   |
          one row-wide softmax
                   |
              final output row

The kernel handles this by streaming across K and V blocks. It never stores the full M x N score matrix. Instead it keeps a running max, a running sum, and a running output accumulator for each query row.

Program layout

row_pid = tl.program_id(0)
col_pid = tl.program_id(1)

m_indexes = row_pid * M_BLOCK + tl.arange(0, M_BLOCK)
v_indexes = col_pid * d_BLOCK + tl.arange(0, d_BLOCK)

The launch grid is:

grid = (triton.cdiv(M, M_BLOCK), triton.cdiv(d, d_BLOCK))

So each Triton program owns one block of M_BLOCK query rows and one block of d_BLOCK output columns.

In this version:

M_BLOCK = 16
N_BLOCK = 64
d_BLOCK = 128

That means one program computes a [16, 128] tile of the final output.

The first grid axis is parallel over query rows. The second grid axis is parallel over output feature columns. Different programs can run independently because they write different output tiles.

output O: [M, d]

                 d dimension
          col_pid=0        col_pid=1        col_pid=2
        +--------------+--------------+--------------+
row 0   | program      | program      | program      |
block   | row_pid=0    | row_pid=0    | row_pid=0    |
        +--------------+--------------+--------------+
row 1   | program      | program      | program      |
block   | row_pid=1    | row_pid=1    | row_pid=1    |
        +--------------+--------------+--------------+
row 2   | program      | program      | program      |
block   | row_pid=2    | row_pid=2    | row_pid=2    |
        +--------------+--------------+--------------+

Each box writes one [M_BLOCK, d_BLOCK] output tile.
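
As a rough sketch of that mapping in plain Python (the problem size `M = 40, d = 256` and the helper `tile_bounds` are made up for illustration; `cdiv` mirrors the arithmetic of `triton.cdiv`):

```python
def cdiv(a, b):
    """Ceiling division, same arithmetic as triton.cdiv."""
    return (a + b - 1) // b

M, d = 40, 256              # hypothetical problem size
M_BLOCK, d_BLOCK = 16, 128  # block sizes from this post

grid = (cdiv(M, M_BLOCK), cdiv(d, d_BLOCK))

def tile_bounds(row_pid, col_pid):
    """Half-open row and column ranges written by one program, clipped to the output."""
    rows = (row_pid * M_BLOCK, min((row_pid + 1) * M_BLOCK, M))
    cols = (col_pid * d_BLOCK, min((col_pid + 1) * d_BLOCK, d))
    return rows, cols

for row_pid in range(grid[0]):
    for col_pid in range(grid[1]):
        print((row_pid, col_pid), tile_bounds(row_pid, col_pid))
```

The last row block is only partly full here (rows 32 to 39), which is what the `m_mask` in the kernel is for.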

The state we carry

running_sum = tl.zeros([M_BLOCK], dtype=tl.float32)
running_max = tl.full([M_BLOCK], float("-inf"), dtype=tl.float32)
out_acc = tl.zeros([M_BLOCK, d_BLOCK], dtype=tl.float32)
scale = 1 / tl.sqrt(d + 0.0)

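For every query row inside the block, we keep a running max, a running sum (the softmax denominator so far), and one row of the output accumulator.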

After some prefix of keys has been processed, the invariant is:

running_sum[i] = sum_j exp(score[i, j] - running_max[i])
out_acc[i, :] = sum_j exp(score[i, j] - running_max[i]) * V[j, :]

This is why the running sum is useful. I would not say it “amortizes the compute” exactly. We still have to compute every QK score. What it amortizes is the normalization work and memory traffic across the scan. We do not need to store all scores, launch a separate softmax, then read them again. The denominator is built incrementally as the row streams through the kernel.

after block 0:
  running_max = max(scores_0)
  running_sum = sum(exp(scores_0 - running_max))

after block 1:
  running_max = max(old_running_max, max(scores_1))
  running_sum = rescaled_old_sum + sum(exp(scores_1 - running_max))

after block 2:
  same update again
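
The same update for a single row, as a small pure-Python sketch (the function name `online_max_sum` and the sample scores are mine, not from the kernel). It checks that the streamed state matches the max and denominator computed directly over the whole row:

```python
import math

def online_max_sum(scores, block):
    """Scan one row of scores in blocks, carrying (running_max, running_sum)."""
    running_max = float("-inf")
    running_sum = 0.0
    for start in range(0, len(scores), block):
        chunk = scores[start:start + block]
        new_max = max(running_max, max(chunk))
        # Rescale the old sum into the new max's coordinate system.
        # On the first block this is exp(-inf) = 0, and the old sum is 0 anyway.
        alpha = math.exp(running_max - new_max)
        running_sum = running_sum * alpha + sum(math.exp(s - new_max) for s in chunk)
        running_max = new_max
    return running_max, running_sum

scores = [0.5, -1.0, 3.0, 2.0, 7.5, 0.0, 1.0]
m, s = online_max_sum(scores, block=3)

# Direct computation over the full row, for comparison.
direct_max = max(scores)
direct_sum = sum(math.exp(x - direct_max) for x in scores)
print(m == direct_max, abs(s - direct_sum) < 1e-12)
```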

Computing a block of scores

for nn in tl.range(0, tl.cdiv(N, N_BLOCK)):
    n_indexes = nn * N_BLOCK + tl.arange(0, N_BLOCK)
    n_mask = n_indexes < N
    acc = tl.zeros((M_BLOCK, N_BLOCK), dtype=tl.float32)

The outer loop walks through all keys. This loop is not parallel inside one program, because softmax for a row depends on all key blocks. We can process the blocks one after another and update the running statistics.

Inside each key block, we compute:

acc = Q_block @ K_block.T

That dot product is itself tiled over the feature dimension:

for dd in tl.range(0, tl.cdiv(d, d_BLOCK)):
    d_indexes = dd * d_BLOCK + tl.arange(0, d_BLOCK)

    cur_q = tl.load(Q + M_offsets, mask=M_mask, other=0.0)
    cur_k = tl.load(K + K_offsets, mask=K_mask, other=0.0)

    acc += tl.dot(cur_q, tl.trans(cur_k)) * scale

This produces a score tile:

acc.shape == [M_BLOCK, N_BLOCK]
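
For one (query, key) pair, the feature-dimension tiling looks roughly like this in plain Python (a sketch with made-up vectors; `tiled_score` is a hypothetical helper, not something in the kernel):

```python
import math

def tiled_score(q, k, d_block):
    """Scaled dot product of q and k, accumulated one feature tile at a time."""
    d = len(q)
    scale = 1.0 / math.sqrt(d)
    acc = 0.0
    for start in range(0, d, d_block):
        q_tile = q[start:start + d_block]
        k_tile = k[start:start + d_block]
        # Each partial product is scaled before it is added, as in the kernel.
        acc += sum(a * b for a, b in zip(q_tile, k_tile)) * scale
    return acc

q = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
k = [0.5, -1.0, 2.0, 0.0, 1.0, -0.5]

full = sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q))
tiled = tiled_score(q, k, d_block=4)
print(abs(tiled - full) < 1e-12)
```

Scaling each partial product gives the same result as scaling once at the end, since the scale factors out of the sum.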

The tl.dot is the parallel part. The threads in the program cooperate to compute the block matmul, and Triton can lower the dot to efficient GPU instructions. On NVIDIA GPUs, compatible shapes and dtypes can use tensor core matrix-multiply instructions. The Python-looking block operation is not one scalar thread doing all the work.

Are these loops unrolled?

The short answer: not automatically in the general case.

M_BLOCK, N_BLOCK, and d_BLOCK are tl.constexpr, so the shapes inside the program are compile-time constants. That helps Triton generate fixed-size vector operations for loads, reductions, and dot products.

But the loop bounds use runtime values:

tl.cdiv(N, N_BLOCK)
tl.cdiv(d, d_BLOCK)

Since N and d are normal runtime arguments, Triton cannot fully unroll these loops for all possible inputs. They are real loops in the generated program. If N or d were made compile-time constants, or if the loop bounds were otherwise known at compile time, the compiler would have more room to unroll them.

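So the useful distinction is: the block shapes are compile-time constants, so the work inside one iteration is fixed-size and vectorized, while the loop trip counts depend on the runtime values N and d, so the loops themselves stay real loops.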

Masking before softmax

acc = tl.where(m_mask[:, None] & n_mask[None, :], acc, float("-inf"))

For matmul, padded values can be zero. For softmax, padded scores must be -inf.

If an invalid key stayed as 0.0, it could affect the row max or add exp(0) to the denominator. That would change the distribution. With -inf, the padded score contributes:

exp(-inf) = 0

So it disappears.
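
A quick pure-Python check of that claim (illustrative values; `softmax` is a local helper, not a library call):

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

valid = [2.0, 1.0, 0.5]                 # three real keys
padded_zero = valid + [0.0]             # one padded slot left at 0.0
padded_inf = valid + [float("-inf")]    # one padded slot masked to -inf

reference = softmax(valid)
with_zero = softmax(padded_zero)
with_inf = softmax(padded_inf)

print(reference)
print(with_zero[:3])   # every real weight shrinks: the padding took some mass
print(with_inf[:3])    # matches the reference; the padded weight is 0
```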

Online softmax

After computing one score block, the kernel merges it into the row statistics:

current_max = tl.max(acc, axis=-1)
new_running_max = tl.maximum(current_max, running_max)
alpha = tl.exp(running_max - new_running_max)

weights = tl.exp(acc - new_running_max[:, None])
incre = tl.sum(weights, axis=1)

running_sum = tl.fma(running_sum, alpha, incre)
running_max = new_running_max

This is the log-sum-exp trick in streaming form.

Suppose the old max was m_old, and the new max after seeing this block is m_new. The old denominator was measured relative to m_old:

sum_old = sum(exp(score_old - m_old))

To merge it with the new block, we convert it into the new coordinate system:

sum_old * exp(m_old - m_new)

That is alpha.

Then we add the new block’s contribution:

sum(exp(score_new - m_new))

So the merged denominator is:

running_sum =
    old_sum * exp(old_max - new_max)
    + new_sum

old state, measured at old_max
  old_sum
  old_out_acc
        |
        | multiply by alpha = exp(old_max - new_max)
        v
old state, measured at new_max
  old_sum * alpha
  old_out_acc * alpha
        |
        | add current block contribution
        v
new state
  running_sum
  out_acc

This is the reason the kernel can stream across N. Softmax still needs the whole row, but it does not need to keep the whole row in memory.

Multiplying by V in the same pass

The kernel also immediately applies the current unnormalized weights to V:

cur_v = tl.load(V + v_offsets, mask=V_mask, other=0.0)
weighted_v = tl.dot(weights, cur_v)
out_acc = tl.fma(out_acc, alpha[:, None], weighted_v)

This is the same idea as the denominator update.

If the row max changes, the old output accumulator also has to be rescaled:

old_out_acc * exp(old_max - new_max)

Then we add the contribution from the new key block:

weights @ V_block

This keeps running_sum and out_acc in the same coordinate system. At the end, we can normalize once:

out_acc /= running_sum[:, None]
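
Putting the denominator update and the V update together for one query row, a pure-Python sketch could look like this (function names and sample data are mine; the real kernel does this for M_BLOCK rows at once with tl.dot):

```python
import math

def attention_row_online(scores, V, block):
    """Stream over key blocks for one query row, as in the kernel's main loop."""
    d = len(V[0])
    running_max = float("-inf")
    running_sum = 0.0
    out_acc = [0.0] * d
    for start in range(0, len(scores), block):
        chunk = scores[start:start + block]
        v_chunk = V[start:start + block]
        new_max = max(running_max, max(chunk))
        alpha = math.exp(running_max - new_max)
        weights = [math.exp(s - new_max) for s in chunk]   # unnormalized
        running_sum = running_sum * alpha + sum(weights)
        # Rescale the old accumulator, then add weights @ V_block.
        out_acc = [
            out_acc[c] * alpha + sum(w * v[c] for w, v in zip(weights, v_chunk))
            for c in range(d)
        ]
        running_max = new_max
    return [x / running_sum for x in out_acc]   # normalize once at the end

def attention_row_reference(scores, V):
    """Plain softmax(scores) @ V with the whole row in memory."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(wi * v[c] for wi, v in zip(w, V)) / total for c in range(len(V[0]))]

scores = [0.1, 2.0, -1.0, 4.0, 0.5]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0], [0.5, 0.5]]

online = attention_row_online(scores, V, block=2)
reference = attention_row_reference(scores, V)
print(online)
print(reference)
```

The two results agree up to floating-point rounding, whatever block size is used.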

What is parallel and what is serial?

There are a few layers here.

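Across the grid, programs are parallel: each (row_pid, col_pid) pair owns a different output tile.

Inside one program, block operations are parallel: tl.dot, tl.max, tl.sum, and the elementwise tl.exp all act on whole tiles at once.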

Inside one program, the key scan is serial:

for nn in tl.range(0, tl.cdiv(N, N_BLOCK)):

That loop has to visit key blocks in sequence because each iteration updates running_max, running_sum, and out_acc.

Under the hood, a Triton program maps to GPU execution resources roughly like a CUDA thread block. Many programs are launched, and the GPU scheduler assigns them to SMs. Each program runs as a group of warps. When one warp is waiting on memory, the SM can run another ready warp. This is how latency is hidden.

So “serial” here does not mean the whole GPU is idle. It means one output tile has a loop-carried dependency over key blocks. The GPU still runs many other output tiles at the same time.

GPU view

SM 0:  program A  -> scans K blocks 0, 1, 2, ...
SM 1:  program B  -> scans K blocks 0, 1, 2, ...
SM 2:  program C  -> scans K blocks 0, 1, 2, ...
SM 3:  program D  -> scans K blocks 0, 1, 2, ...

Inside each program, the scan over K blocks is ordered.
Across programs, many output tiles run at the same time.

Why fuse instead of using GEMM + softmax + GEMM?

The unfused version is conceptually clean:

S = Q @ K.T
P = softmax(S)
O = P @ V

If we use optimized kernels, both GEMMs can be very fast. The problem is the materialized attention matrix S or P.

For shapes:

Q: [M, d]
K: [N, d]
V: [N, d]
S/P: [M, N]
O: [M, d]

The separate-kernel path has to move a lot of intermediate data through HBM:

  1. GEMM reads Q and K, then writes S (M * N elements).
  2. Softmax reads S, writes P, and often needs multiple passes over each row for max, sum, and normalize.
  3. The second GEMM reads P and V, then writes O.

unfused path

Q, K -> GEMM -> S in HBM -> softmax -> P in HBM -> GEMM with V -> O
                  ^                         ^
                  |                         |
          huge intermediate          huge intermediate

fused path

Q block, K block -> score tile -> online softmax -> apply V block -> O tile
                       |
                       +-- stays temporary, then gets discarded

Even if the GEMMs are optimized, the M x N matrix is huge. For example, if:

M = 4096
N = 4096
d = 128

Then S has:

4096 * 4096 = 16,777,216 elements

In fp16, that is about 32 MB for one matrix. Writing scores, reading scores, writing probabilities, and reading probabilities is already around 128 MB of intermediate traffic, before counting the actual Q, K, V, and O traffic. If softmax does more than one pass over S, the traffic grows again.

Step                         | Unfused intermediate traffic for S/P | Fused intermediate traffic for S/P
-----------------------------+--------------------------------------+-----------------------------------
write scores S               | M * N elements                       | 0 global writes
read scores for softmax      | M * N elements                       | 0 global reads
write probabilities P        | M * N elements                       | 0 global writes
read probabilities for P @ V | M * N elements                       | 0 global reads

For the 4096 x 4096 fp16 example, each full matrix pass is about 32 MB. Four passes is about 128 MB of extra HBM traffic.
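
The arithmetic behind those numbers, as a short Python sketch (MB here means 2**20 bytes; the four passes are the four rows of the table above):

```python
M, N = 4096, 4096
bytes_per_element = 2            # fp16

elements = M * N
bytes_per_pass = elements * bytes_per_element
passes = 4                       # write S, read S, write P, read P

print(elements)                          # 16777216
print(bytes_per_pass / 2**20)            # 32.0
print(passes * bytes_per_pass / 2**20)   # 128.0
```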

The fused kernel avoids writing S or P to global memory. It computes a score tile, updates the online softmax state, applies the weights to V, and discards the score tile.

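Runtime-wise, this changes the bottleneck: the unfused path spends its time moving S and P through HBM, while the fused path spends it on the two block matmuls and on streaming K and V.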

This is especially important because attention is often memory-bound around the softmax boundary. The score matrix is too large to keep on chip, so storing it forces HBM traffic and kernel launch boundaries. Fusing keeps the temporary score tile in registers/on-chip memory for the short time it is needed.

There is one caveat in my current implementation. The grid is split over output columns with col_pid, so if d > d_BLOCK, different column programs recompute the same QK.T scores for different V column tiles. For d <= 128, this is fine because there is only one output-column tile. For larger d, this wastes compute. A more tuned implementation would choose the tiling more carefully or use a different strategy to avoid recomputing attention scores too much.
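
To put a rough number on that recompute, here is a back-of-the-envelope sketch (`qk_flops` is a hypothetical helper; it counts 2 FLOPs per multiply-add and ignores everything except Q @ K.T):

```python
def cdiv(a, b):
    """Ceiling division, same arithmetic as triton.cdiv."""
    return (a + b - 1) // b

def qk_flops(M, N, d, d_BLOCK):
    """Approximate FLOPs spent on Q @ K.T summed over all programs."""
    col_programs = cdiv(d, d_BLOCK)      # programs that share the same query rows
    once = 2 * M * N * d                 # one full Q @ K.T
    return col_programs, col_programs * once

for d in (64, 128, 256, 512):
    cols, flops = qk_flops(M=4096, N=4096, d=d, d_BLOCK=128)
    print(d, cols, flops)
```

With d_BLOCK = 128, d = 512 means every score is computed four times, once per output-column program.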

Still, the main benefit is already clear: the fused kernel trades a big materialized intermediate for a streaming online softmax. That is the core idea behind FlashAttention-style kernels.

Complete Code

import torch
import triton
import triton.language as tl


@triton.jit
def softmax_attention(
    Q,
    K,
    V,
    M,
    N,
    d,
    out,
    Q_stride_M,
    Q_stride_d,
    K_stride_N,
    K_stride_d,
    V_stride_N,
    V_stride_d,
    out_stride_M,
    out_stride_d,
    M_BLOCK: tl.constexpr,
    N_BLOCK: tl.constexpr,
    d_BLOCK: tl.constexpr,
):
    # tiled gemm
    row_pid = tl.program_id(0)
    col_pid = tl.program_id(1)
    m_indexes = row_pid * M_BLOCK + tl.arange(0, M_BLOCK)

    # v_indexes selects output columns for the final multiplication with V, not for the QK gemm
    v_indexes = col_pid * d_BLOCK + tl.arange(0, d_BLOCK)
    m_mask = m_indexes < M
    v_mask = v_indexes < d

    running_sum = tl.zeros([M_BLOCK], dtype=tl.float32)
    running_max = tl.full([M_BLOCK], float("-inf"), dtype=tl.float32)
    out_acc = tl.zeros([M_BLOCK, d_BLOCK], dtype=tl.float32)
    scale = 1 / tl.sqrt(d + 0.0)

    # we need to iterate through all key blocks (one complete row of scores) to get the softmax
    # N is a runtime value, so this stays a real loop rather than being fully unrolled
    for nn in tl.range(0, tl.cdiv(N, N_BLOCK)):
        n_indexes = nn * N_BLOCK + tl.arange(0, N_BLOCK)
        n_mask = n_indexes < N
        # block gemm accumulator
        acc = tl.zeros((M_BLOCK, N_BLOCK), dtype=tl.float32)
        # this dd is for QK inner dimension tiling
        for dd in tl.range(0, tl.cdiv(d, d_BLOCK)):
            d_indexes = dd * d_BLOCK + tl.arange(0, d_BLOCK)
            d_mask = d_indexes < d
            M_offsets = (
                Q_stride_M * m_indexes[:, None] + Q_stride_d * d_indexes[None, :]
            )
            # load K as [N_BLOCK, d_BLOCK] and transpose on the fly in tl.dot
            # this may happen to be efficient since both Q and K are row-major right now
            K_offsets = (
                K_stride_N * n_indexes[:, None] + K_stride_d * d_indexes[None, :]
            )

            M_mask = m_mask[:, None] & d_mask[None, :]
            K_mask = n_mask[:, None] & d_mask[None, :]
            # load
            cur_q = tl.load(Q + M_offsets, mask=M_mask, other=0.0)
            cur_k = tl.load(K + K_offsets, mask=K_mask, other=0.0)

            # partial sum over this feature tile; the 1/sqrt(d) scale is applied to each partial product
            acc += tl.dot(cur_q, tl.trans(cur_k)) * scale
        # acc is an [M_BLOCK, N_BLOCK] gemm result
        # mask out-of-bounds entries to -inf so they contribute exp(-inf) = 0 to the softmax
        acc = tl.where(m_mask[:, None] & n_mask[None, :], acc, float("-inf"))
        # row-wise max over the key axis: one maximum per query row
        current_max = tl.max(acc, axis=-1)
        # compare with the running max and update the running_max
        new_running_max = tl.maximum(current_max, running_max)
        # alpha on updating the running_sum given the update on running_max
        alpha = tl.exp(running_max - new_running_max)  # the diff

        # compute increment to running_sum on this N_BLOCK
        weights = tl.exp(acc - new_running_max[:, None])
        # sum directly, because every exponent is already relative to new_running_max
        incre = tl.sum(weights, axis=1)

        # running_sum = running_sum * alpha + incre
        running_sum = tl.fma(running_sum, alpha, incre)
        running_max = new_running_max

        # gemm with V.
        # N is now the inner dimension: [M_BLOCK, N_BLOCK] @ [N_BLOCK, d_BLOCK]
        v_offsets = n_indexes[:, None] * V_stride_N + v_indexes[None, :] * V_stride_d
        V_mask = n_mask[:, None] & v_mask[None, :]
        # other=0.0 keeps padded V entries finite, so a zero weight times padding stays zero
        cur_v = tl.load(V + v_offsets, mask=V_mask, other=0.0)

        # dot the weights now, will normalize with running_sum later
        weighted_v = tl.dot(weights, cur_v)
        # accumulator for the gemm with V; rescale the old value by alpha, same as running_sum
        out_acc = tl.fma(out_acc, alpha[:, None], weighted_v)

    # the running sum is correct now
    out_acc /= running_sum[:, None]

    out_offsets = m_indexes[:, None] * out_stride_M + v_indexes[None, :] * out_stride_d
    out_mask = m_mask[:, None] & v_mask[None, :]
    tl.store(out + out_offsets, out_acc, mask=out_mask)


# Q, K, V, output are tensors on the GPU
def solve(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    output: torch.Tensor,
    M: int,
    N: int,
    d: int,
):
    M_BLOCK = 16
    d_BLOCK = 128
    N_BLOCK = 64

    grid = (triton.cdiv(M, M_BLOCK), triton.cdiv(d, d_BLOCK))
    Q_stride_M, Q_stride_d = Q.stride()
    K_stride_N, K_stride_d = K.stride()
    V_stride_N, V_stride_d = V.stride()
    out_stride_M, out_stride_d = output.stride()

    softmax_attention[grid](
        Q=Q,
        K=K,
        V=V,
        out=output,
        M=M,
        N=N,
        d=d,
        Q_stride_M=Q_stride_M,
        Q_stride_d=Q_stride_d,
        K_stride_N=K_stride_N,
        K_stride_d=K_stride_d,
        V_stride_N=V_stride_N,
        V_stride_d=V_stride_d,
        out_stride_M=out_stride_M,
        out_stride_d=out_stride_d,
        M_BLOCK=M_BLOCK,
        d_BLOCK=d_BLOCK,
        N_BLOCK=N_BLOCK,
    )
