Retep's

[TIL] Collected: Convolution


2D convolution in Triton

Today I looked at a small Triton kernel for single-channel 2D convolution:

import triton
import triton.language as tl


@triton.jit
def conv_2d(
    input,
    kernel,
    output,
    input_rows,
    input_cols,
    kernel_rows,
    kernel_cols,
    input_row_stride,
    input_col_stride,
    kernel_row_stride,
    kernel_col_stride,
    output_row_stride,
    output_col_stride,
    ROW_BLOCK: tl.constexpr,
    COL_BLOCK: tl.constexpr,
):
    row_pid = tl.program_id(0)
    col_pid = tl.program_id(1)
    row_start = ROW_BLOCK * row_pid
    col_start = COL_BLOCK * col_pid
    # compute output offsets for block (ROW_BLOCK, COL_BLOCK)
    output_row_offsets = row_start + tl.arange(0, ROW_BLOCK)
    output_col_offsets = col_start + tl.arange(0, COL_BLOCK)
    output_mask = (output_row_offsets < input_rows - kernel_rows + 1)[:, None] & (
        output_col_offsets < input_cols - kernel_cols + 1
    )[None, :]
    acc = tl.zeros((ROW_BLOCK, COL_BLOCK), dtype=tl.float32)
    for i in range(kernel_rows):
        for j in range(kernel_cols):
            # notice that we loop through each kernel value and compute the increment to each output element
            # that it can interact with
            k_val = tl.load(kernel + i * kernel_row_stride + j * kernel_col_stride)
            input_row_offsets = output_row_offsets + i
            input_col_offsets = output_col_offsets + j
            input_mask = (input_row_offsets < input_rows)[:, None] & (
                input_col_offsets < input_cols
            )[None, :]
            input_val = tl.load(
                input
                + input_row_offsets[:, None] * input_row_stride
                + input_col_offsets[None, :] * input_col_stride,
                mask=input_mask,
            )
            acc += k_val * input_val

    tl.store(
        output
        + output_row_offsets[:, None] * output_row_stride
        + output_col_offsets[None, :] * output_col_stride,
        acc,
        mask=output_mask,
    )
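To pin down exactly what conv_2d computes, here is a plain-Python reference. It is a sketch on nested lists (the name conv2d_valid_reference is hypothetical, not part of the kernel): a "valid" 2D cross-correlation, meaning the kernel is not flipped and there is no padding, so the output is (input_rows - kernel_rows + 1) x (input_cols - kernel_cols + 1).

```python
def conv2d_valid_reference(image, kernel):
    """Valid 2D cross-correlation on nested lists (no kernel flip, no padding)."""
    input_rows, input_cols = len(image), len(image[0])
    kernel_rows, kernel_cols = len(kernel), len(kernel[0])
    # Same output extent that output_mask allows in the Triton kernel.
    output_rows = input_rows - kernel_rows + 1
    output_cols = input_cols - kernel_cols + 1
    output = [[0.0] * output_cols for _ in range(output_rows)]
    for r in range(output_rows):
        for c in range(output_cols):
            acc = 0.0  # fp32 accumulator in the kernel; a double-precision float here
            for i in range(kernel_rows):
                for j in range(kernel_cols):
                    # Output (r, c) reads input (r + i, c + j), as in the kernel.
                    acc += kernel[i][j] * image[r + i][c + j]
            output[r][c] = acc
    return output


image = [[1, 2, 3, 4],
         [5, 6, 7, 8],
         [9, 10, 11, 12]]
diff = [[1, 0],
        [0, -1]]
print(conv2d_valid_reference(image, diff))
# [[-5.0, -5.0, -5.0], [-5.0, -5.0, -5.0]]
```

A 1 x 2 kernel [[0, 1]] makes the "no flip" part easy to see: each output is simply image[r][c + 1].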

One program owns one output tile

The launch grid is two-dimensional:

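output_rows = input_rows - kernel_rows + 1
output_cols = input_cols - kernel_cols + 1
grid = (
    triton.cdiv(output_rows, ROW_BLOCK),
    triton.cdiv(output_cols, COL_BLOCK),
)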

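Each Triton program computes a ROW_BLOCK x COL_BLOCK tile of the output. With ROW_BLOCK = COL_BLOCK = 32, one program computes up to 32 x 32 = 1024 output elements. Programs on the bottom and right edges compute fewer, because output_mask drops positions past the output boundary.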
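The grid math and the "up to" can be checked on the host. This is a sketch assuming the valid (no-padding) output size that output_mask uses; conv2d_grid is a hypothetical helper, and cdiv mirrors what triton.cdiv returns for positive integers.

```python
def cdiv(a, b):
    # Ceiling division for positive integers (same result as triton.cdiv).
    return (a + b - 1) // b


def conv2d_grid(input_rows, input_cols, kernel_rows, kernel_cols,
                row_block=32, col_block=32):
    """Return the launch grid and the size of the bottom-right edge tile."""
    output_rows = input_rows - kernel_rows + 1
    output_cols = input_cols - kernel_cols + 1
    grid = (cdiv(output_rows, row_block), cdiv(output_cols, col_block))
    # The last program along each axis only owns the leftover rows/cols;
    # output_mask switches off the rest of its tile.
    last_tile = (output_rows - (grid[0] - 1) * row_block,
                 output_cols - (grid[1] - 1) * col_block)
    return grid, last_tile


print(conv2d_grid(100, 100, 3, 3))  # ((4, 4), (2, 2))
```

For a 100 x 100 input and a 3 x 3 kernel the output is 98 x 98, so the edge programs only keep 2 of their 32 rows or columns.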

output

              col_pid=0        col_pid=1        col_pid=2
          +--------------+--------------+--------------+
row_pid=0 |  32 x 32     |  32 x 32     |  32 x 32     |
          +--------------+--------------+--------------+
row_pid=1 |  32 x 32     |  32 x 32     |  32 x 32     |
          +--------------+--------------+--------------+
row_pid=2 |  32 x 32     |  32 x 32     |  32 x 32     |
          +--------------+--------------+--------------+

Inside one program:

output_row_offsets = row_start + tl.arange(0, ROW_BLOCK)
output_col_offsets = col_start + tl.arange(0, COL_BLOCK)
acc = tl.zeros((ROW_BLOCK, COL_BLOCK), dtype=tl.float32)

output_row_offsets is a vector of 32 row indexes. output_col_offsets is a vector of 32 column indexes. By combining them with [:, None] and [None, :], Triton creates a 32 x 32 block of addresses.
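In plain Python the broadcast is just an outer sum of the two scaled offset vectors. A sketch with a hypothetical block_offsets helper and a tiny 2 x 3 block instead of 32 x 32:

```python
def block_offsets(row_offsets, col_offsets, row_stride, col_stride):
    """Emulate row_offsets[:, None] * row_stride + col_offsets[None, :] * col_stride."""
    # Outer comprehension plays the role of [:, None] (one entry per row),
    # inner comprehension the role of [None, :] (one entry per column).
    return [[r * row_stride + c * col_stride for c in col_offsets]
            for r in row_offsets]


# A 2 x 3 block starting at (row 4, col 6) of a row-major image with 10 columns.
print(block_offsets([4, 5], [6, 7, 8], row_stride=10, col_stride=1))
# [[46, 47, 48], [56, 57, 58]]
```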

This is the usual Triton style: create offset vectors, broadcast them into a block, load a block, do block arithmetic, then store a block.

The convolution loop

For each kernel element (i, j), the program loads one scalar kernel value:

k_val = tl.load(kernel + i * kernel_row_stride + j * kernel_col_stride)

Then it loads a 32 x 32 input tile shifted by (i, j):

input_row_offsets = output_row_offsets + i
input_col_offsets = output_col_offsets + j

input_val = tl.load(
    input
    + input_row_offsets[:, None] * input_row_stride
    + input_col_offsets[None, :] * input_col_stride,
    mask=input_mask,
)

Then it accumulates:

acc += k_val * input_val

So the mental model is:

for every kernel position:
    load one scalar kernel value
    load the input patch covered by the output tile at that kernel offset
    multiply the whole input tile by the scalar
    add it into the output accumulator
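The same loop order can be emulated on the host for a single program: kernel positions outside, the whole tile inside. This is a sketch (conv2d_one_program is a hypothetical name) that returns only the positions output_mask would store. It also shows why the input load can skip other=: a valid output row r reads at most row r + kernel_rows - 1 = input_rows - 1, so masked input lanes only ever feed outputs that are discarded.

```python
def conv2d_one_program(image, kernel, row_start, col_start, row_block, col_block):
    """Emulate one conv_2d program: loop over kernel positions, update the whole tile."""
    input_rows, input_cols = len(image), len(image[0])
    kernel_rows, kernel_cols = len(kernel), len(kernel[0])
    output_rows = input_rows - kernel_rows + 1
    output_cols = input_cols - kernel_cols + 1
    acc = [[0.0] * col_block for _ in range(row_block)]
    for i in range(kernel_rows):
        for j in range(kernel_cols):
            k_val = kernel[i][j]  # one scalar "load" per kernel position
            # One shifted tile: every accumulator entry gets k_val * image[r + i][c + j].
            for tr in range(row_block):
                for tc in range(col_block):
                    r, c = row_start + tr + i, col_start + tc + j
                    # input_mask: skip reads past the image edge. In Triton these lanes
                    # hold unspecified values, but only outputs dropped below use them.
                    if r < input_rows and c < input_cols:
                        acc[tr][tc] += k_val * image[r][c]
    # output_mask: keep only positions inside the valid output.
    return {(row_start + tr, col_start + tc): acc[tr][tc]
            for tr in range(row_block) for tc in range(col_block)
            if row_start + tr < output_rows and col_start + tc < output_cols}


image = [[r * 5 + c for c in range(5)] for r in range(5)]
box = [[1, 1], [1, 1]]
# 4 x 4 output, 3 x 3 tiles: the bottom-right program keeps a single output.
print(conv2d_one_program(image, box, row_start=3, col_start=3, row_block=3, col_block=3))
# {(3, 3): 84.0}
```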

For a 3 x 3 kernel, one program performs 9 block loads from the input and 9 scalar loads from the kernel. For a 5 x 5 kernel, it performs 25 block loads from the input and 25 scalar loads from the kernel.

Memory access pattern

Assume the input is row-major contiguous:

input_row_stride = input_cols
input_col_stride = 1

For a fixed kernel offset (i, j), the input addresses are:

input
+ (output_row_offsets[:, None] + i) * input_cols
+ (output_col_offsets[None, :] + j)

That means each row of the 32 x 32 block is contiguous in memory:

input rows loaded by one program for one kernel element

row r+0:  x x x x x x x x ... 32 contiguous elements
row r+1:  x x x x x x x x ... 32 contiguous elements
row r+2:  x x x x x x x x ... 32 contiguous elements
...
row r+31: x x x x x x x x ... 32 contiguous elements

Across rows there is a stride of input_cols. So the memory access is contiguous along columns and strided along rows. This is a natural layout for row-major images.
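A quick way to see this pattern is to enumerate the flat offsets for a small tile. The sketch assumes a contiguous row-major input (input_row_stride = input_cols, input_col_stride = 1); the helper name is hypothetical and the tile is 3 x 4 rather than 32 x 32.

```python
def shifted_tile_addresses(row_start, col_start, i, j, rows, cols, input_cols):
    """Flat offsets of the input tile loaded for kernel offset (i, j), row-major input."""
    # input_row_stride = input_cols, input_col_stride = 1
    return [[(row_start + tr + i) * input_cols + (col_start + tc + j)
             for tc in range(cols)]
            for tr in range(rows)]


addrs = shifted_tile_addresses(row_start=0, col_start=0, i=1, j=2,
                               rows=3, cols=4, input_cols=100)
for row in addrs:
    contiguous = all(b - a == 1 for a, b in zip(row, row[1:]))
    print(row[0], row[-1], contiguous)
# 102 105 True
# 202 205 True
# 302 305 True
```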

The output store has the same shape:

output
+ output_row_offsets[:, None] * output_row_stride
+ output_col_offsets[None, :] * output_col_stride

With:

output_row_stride = output_cols
output_col_stride = 1

each output row is also contiguous.

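So the kernel has a clean access pattern: each input block load and the final output store touch 32 contiguous elements per row, with a stride of one image row between rows.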

The important part is that the output is not written repeatedly. The partial sum stays in acc until all kernel positions are consumed.

What gets reused

There are two kinds of reuse here.

The first reuse is the kernel scalar:

k_val = tl.load(...)
acc += k_val * input_val

One scalar k_val is applied to all 32 x 32 input elements in the current block, so each kernel value is loaded once per program and reused 1024 times. The kernel is small, and the scalar load is cheap compared with the input tile load.

The second reuse is implicit overlap between neighboring output elements. In convolution, adjacent outputs share most of their input window. For example, two horizontally adjacent 3 x 3 outputs share 6 of their 9 input values.
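Counting the overlap with sets confirms the 6-of-9 figure. In general, two horizontally adjacent outputs share kernel_rows * (kernel_cols - 1) inputs. A small sketch with a hypothetical window helper:

```python
def window(row, col, kernel_rows, kernel_cols):
    """Input coordinates read by the output element at (row, col)."""
    return {(row + i, col + j)
            for i in range(kernel_rows) for j in range(kernel_cols)}


print(len(window(0, 0, 3, 3) & window(0, 1, 3, 3)))  # 6
print(len(window(0, 0, 5, 5) & window(0, 1, 5, 5)))  # 20
```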

This kernel does not explicitly stage that shared input region into shared memory. Instead, each kernel offset loads a shifted 32 x 32 tile from global memory:

kernel offset (0, 0): load input rows r..r+31, cols c..c+31
kernel offset (0, 1): load input rows r..r+31, cols c+1..c+32
kernel offset (0, 2): load input rows r..r+31, cols c+2..c+33

Those loads overlap heavily. Hardware caches may catch some of that reuse, but the program itself is not organizing the input tile as:

(ROW_BLOCK + kernel_rows - 1) x (COL_BLOCK + kernel_cols - 1)

For a 32 x 32 output tile and a 3 x 3 kernel, the unique input region is only 34 x 34 = 1156 elements. But the current code loads 9 x 32 x 32 = 9216 input elements at the Triton block level, about 8x the unique footprint, and most of those refer to overlapping memory locations. (This is why we need explicit memory placement in TileLang!)
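The same arithmetic for other kernel sizes, as a small sketch (input_traffic is a hypothetical name). "Loaded" counts block-level load lanes, not DRAM traffic, since caches absorb part of the redundancy.

```python
def input_traffic(row_block, col_block, kernel_rows, kernel_cols):
    """Elements loaded by the shifted-tile loop vs. the unique halo region."""
    loaded = kernel_rows * kernel_cols * row_block * col_block
    unique = (row_block + kernel_rows - 1) * (col_block + kernel_cols - 1)
    return loaded, unique, loaded / unique


for k in (3, 5):
    loaded, unique, ratio = input_traffic(32, 32, k, k)
    print(f"{k}x{k}: loaded={loaded} unique={unique} ratio={ratio:.2f}")
# 3x3: loaded=9216 unique=1156 ratio=7.97
# 5x5: loaded=25600 unique=1296 ratio=19.75
```

The redundancy grows roughly with the number of kernel taps, which is what makes staging the halo tile more attractive for larger kernels.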

A production-style implicit GEMM convolution

For a faster convolution kernel, we can instead express convolution as an implicit matrix multiplication.

The conceptual matrix multiplication is:

output[M, OC] = input_im2col[M, K] @ weight[K, OC]

where:

M = batch * output_rows * output_cols
K = input_channels * kernel_rows * kernel_cols

We do not actually materialize input_im2col. The kernel computes the input addresses for the current M x K tile on the fly, loads those values, and feeds them into tl.dot.
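For intuition, the matrix the kernel never builds can be materialized on a toy input. This sketch uses nested lists and assumes NHWC input, HWIO weights, stride 1 and no padding, like the kernel; the helper names are hypothetical. The column order (kh, kw, ic) with ic fastest matches how the kernel decodes k.

```python
def im2col_nhwc(x, kernel_rows, kernel_cols):
    """Materialize input_im2col[M, K] from NHWC nested lists (stride 1, no padding).

    Row m is (n, oh, ow); column k is (kh, kw, ic) with ic varying fastest.
    """
    batch, rows, cols = len(x), len(x[0]), len(x[0][0])
    channels = len(x[0][0][0])
    out_rows, out_cols = rows - kernel_rows + 1, cols - kernel_cols + 1
    matrix = []
    for n in range(batch):
        for oh in range(out_rows):
            for ow in range(out_cols):
                matrix.append([x[n][oh + kh][ow + kw][ic]
                               for kh in range(kernel_rows)
                               for kw in range(kernel_cols)
                               for ic in range(channels)])
    return matrix


def hwio_to_matrix(w):
    """Flatten HWIO weights w[kh][kw][ic][oc] into weight[K, OC], same k order."""
    return [w[kh][kw][ic]
            for kh in range(len(w))
            for kw in range(len(w[0]))
            for ic in range(len(w[0][0]))]


def matmul(a, b):
    """Plain [M, K] @ [K, N] on nested lists."""
    return [[sum(a_row[k] * b[k][n] for k in range(len(b)))
             for n in range(len(b[0]))]
            for a_row in a]


# 1 x 3 x 3 x 2 input: channel 0 holds 3 * row + col, channel 1 is constant 1.
x = [[[[3 * r + c, 1] for c in range(3)] for r in range(3)]]
# 2 x 2 x 2 x 1 weights: top-left tap on channel 0, bottom-right tap on channel 1.
w = [[[[1], [0]], [[0], [0]]],
     [[[0], [0]], [[0], [1]]]]
a = im2col_nhwc(x, 2, 2)
print(len(a), len(a[0]))             # 4 8  (M = 1 * 2 * 2, K = 2 * 2 * 2)
print(matmul(a, hwio_to_matrix(w)))  # [[1], [2], [4], [5]]
```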

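This version assumes contiguous NHWC input, HWIO kernel, and NHWC output. In NHWC the input-channel dimension C is innermost and contiguous, and the kernel decodes k with ic varying fastest, so neighboring k values map to neighboring input addresses. That makes the input tile friendlier for vectorized loads along K. In HWIO the output-channel dimension O is innermost, so each row of the weight tile is contiguous along n.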
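To see why the ic-fastest ordering helps, here is a sketch that decodes k exactly as the kernel does and maps it to a flat NHWC offset (helper names are hypothetical; stride 1, no padding). For one output pixel, all kernel_cols * input_channels values that share a kh land on consecutive addresses; the next kh jumps by one image row.

```python
def nhwc_offset(n, h, w, c, rows, cols, channels):
    """Flat offset of (n, h, w, c) in a contiguous NHWC tensor."""
    return ((n * rows + h) * cols + w) * channels + c


def k_to_input_offset(k, n, oh, ow, rows, cols, channels, kernel_cols):
    """Decode k like the Triton kernel (ic fastest), then address the input."""
    ic = k % channels
    kw = (k // channels) % kernel_cols
    kh = k // (channels * kernel_cols)
    return nhwc_offset(n, oh + kh, ow + kw, ic, rows, cols, channels)


# Output pixel (n=0, oh=0, ow=0) of an 8 x 8 image, 32 channels, 3 x 3 kernel.
offsets = [k_to_input_offset(k, 0, 0, 0, rows=8, cols=8, channels=32, kernel_cols=3)
           for k in range(3 * 3 * 32)]
print(offsets[:4])               # [0, 1, 2, 3]
print(offsets[95], offsets[96])  # 95 256  (kh = 1 starts one row later: 8 * 32)
```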

import torch
import triton
import triton.language as tl


@triton.jit
def conv2d_implicit_gemm_nhwc_hwio(
    input,
    kernel,
    output,
    batch,
    input_rows,
    input_cols,
    input_channels: tl.constexpr,
    output_channels,
    kernel_rows: tl.constexpr,
    kernel_cols: tl.constexpr,
    output_rows,
    output_cols,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    m_offsets = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    n_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    k_offsets = tl.arange(0, BLOCK_K)

    output_spatial = output_rows * output_cols
    total_m = batch * output_spatial
    total_k: tl.constexpr = input_channels * kernel_rows * kernel_cols

    batch_offsets = m_offsets // output_spatial
    spatial_offsets = m_offsets % output_spatial
    out_rows = spatial_offsets // output_cols
    out_cols = spatial_offsets % output_cols

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k_start in range(0, total_k, BLOCK_K):
        k = k_start + k_offsets

        # translate k into indexes to read
        ic = k % input_channels
        kw = (k // input_channels) % kernel_cols
        kh = k // (input_channels * kernel_cols)

        input_rows_for_k = out_rows[:, None] + kh[None, :]
        input_cols_for_k = out_cols[:, None] + kw[None, :]
        input_channels_for_k = ic[None, :]

        # translate indexes to offsets
        input_ptrs = (
            input
            + batch_offsets[:, None] * input_rows * input_cols * input_channels
            + input_rows_for_k * input_cols * input_channels
            + input_cols_for_k * input_channels
            + input_channels_for_k
        )

        input_mask = (
            (m_offsets[:, None] < total_m)
            & (k[None, :] < total_k)
            & (input_rows_for_k < input_rows)
            & (input_cols_for_k < input_cols)
        )

        input_tile = tl.load(input_ptrs, mask=input_mask, other=0.0)

        weight_ptrs = (
            kernel
            + kh[:, None] * kernel_cols * input_channels * output_channels
            + kw[:, None] * input_channels * output_channels
            + ic[:, None] * output_channels
            + n_offsets[None, :]
        )

        weight_mask = (k[:, None] < total_k) & (n_offsets[None, :] < output_channels)
        weight_tile = tl.load(weight_ptrs, mask=weight_mask, other=0.0)

        acc += tl.dot(input_tile, weight_tile, input_precision="tf32")

    output_ptrs = (
        output
        + batch_offsets[:, None] * output_rows * output_cols * output_channels
        + out_rows[:, None] * output_cols * output_channels
        + out_cols[:, None] * output_channels
        + n_offsets[None, :]
    )
    output_mask = (m_offsets[:, None] < total_m) & (
        n_offsets[None, :] < output_channels
    )

    tl.store(output_ptrs, acc, mask=output_mask)
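The k loop walks the flattened (kh, kw, ic) axis in BLOCK_K steps, and the k < total_k term in both masks switches off lanes in the final chunk. A host-side sketch (k_chunks is a hypothetical helper) shows how much of a chunk goes unused for a small-channel layer, which is part of why small channel counts can favor a smaller BLOCK_K.

```python
def k_chunks(input_channels, kernel_rows, kernel_cols, block_k):
    """List (k_start, valid_lanes, masked_lanes) for each BLOCK_K chunk of the k loop."""
    total_k = input_channels * kernel_rows * kernel_cols
    chunks = []
    for k_start in range(0, total_k, block_k):
        k = [k_start + t for t in range(block_k)]
        valid = [kk for kk in k if kk < total_k]  # the k < total_k part of the masks
        chunks.append((k_start, len(valid), block_k - len(valid)))
    return chunks


# RGB input with a 3 x 3 kernel: total_k = 27.
print(k_chunks(input_channels=3, kernel_rows=3, kernel_cols=3, block_k=16))
# [(0, 16, 0), (16, 11, 5)]
```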

This implicit-GEMM formulation turns convolution into the form GPUs are best at, a tiled matrix multiply:

acc += tl.dot(input_tile, weight_tile)

The tiles are:

input_tile:  [BLOCK_M, BLOCK_K]
weight_tile: [BLOCK_K, BLOCK_N]
acc:         [BLOCK_M, BLOCK_N]

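That gives much better reuse than the scalar version: every element of input_tile contributes to BLOCK_N output channels, every element of weight_tile contributes to BLOCK_M output positions, and tl.dot can use the GPU's tensor cores where they are available.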
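A rough count of multiply-adds per loaded element makes the reuse concrete. This is a sketch (reuse_per_k_step is a hypothetical helper) that ignores caches and masked lanes; per k step the ratio works out to BLOCK_M * BLOCK_N / (BLOCK_M + BLOCK_N).

```python
def reuse_per_k_step(block_m, block_n, block_k):
    """Elements loaded vs. multiply-adds for one k-loop iteration (no caches, no masks)."""
    loads = block_m * block_k + block_k * block_n  # input_tile + weight_tile
    macs = block_m * block_n * block_k             # work done by one tl.dot
    return loads, macs, macs / loads


print(reuse_per_k_step(64, 64, 32))  # (4096, 131072, 32.0)

# First kernel, one kernel position: a 32 x 32 tile plus one scalar, 1024 multiply-adds.
scalar_loads, scalar_macs = 32 * 32 + 1, 32 * 32
print(round(scalar_macs / scalar_loads, 3))  # 0.999
```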

There are still details to tune per GPU and shape. BLOCK_M = 64, BLOCK_N = 64, and BLOCK_K = 32 are starting points, not universal constants. For small channel counts, smaller BLOCK_K or BLOCK_N may win. For large output channels, BLOCK_N = 128 can be better. For fp32-heavy workloads, register pressure and occupancy become more important.
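The launch code for this kernel isn't shown. Since pid_m walks M and pid_n walks output channels, a launch would presumably use a grid like the one in this sketch (implicit_gemm_launch is a hypothetical helper, and the layer shape is an arbitrary example, not a benchmark). It makes it easy to see how block sizes change the number of programs and k-loop iterations.

```python
def cdiv(a, b):
    # Ceiling division for positive integers (same result as triton.cdiv).
    return (a + b - 1) // b


def implicit_gemm_launch(batch, output_rows, output_cols, input_channels,
                         output_channels, kernel_rows, kernel_cols,
                         block_m=64, block_n=64, block_k=32):
    """Grid shape and k-loop trip count for conv2d_implicit_gemm_nhwc_hwio."""
    total_m = batch * output_rows * output_cols
    total_k = input_channels * kernel_rows * kernel_cols
    grid = (cdiv(total_m, block_m), cdiv(output_channels, block_n))
    return grid, cdiv(total_k, block_k)


print(implicit_gemm_launch(8, 56, 56, 64, 128, 3, 3))               # ((392, 2), 18)
print(implicit_gemm_launch(8, 56, 56, 64, 128, 3, 3, block_n=128))  # ((392, 1), 18)
```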

This kernel also assumes stride 1, no padding, and dilation 1. Padding and stride are easy to add to the address math:

input_rows_for_k = out_rows[:, None] * stride_h + kh[None, :] - pad_h
input_cols_for_k = out_cols[:, None] * stride_w + kw[None, :] - pad_w

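Then the input mask must check that those computed input coordinates are inside [0, input_rows) and [0, input_cols). With padding they can be negative, so the lower bound matters too. Since the input load already passes other=0.0, masked-out positions read as zero, which is exactly zero padding.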
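Here is a single-channel host-side sketch of that address math and mask (conv2d_strided_padded is a hypothetical helper on nested lists). The output size uses the standard formula for stride and zero padding, (input_rows + 2 * pad_h - kernel_rows) // stride_h + 1, which is an assumption here since the kernel above doesn't compute it.

```python
def conv2d_strided_padded(image, kernel, stride_h, stride_w, pad_h, pad_w):
    """Single-channel cross-correlation with stride and zero padding (nested lists)."""
    input_rows, input_cols = len(image), len(image[0])
    kernel_rows, kernel_cols = len(kernel), len(kernel[0])
    output_rows = (input_rows + 2 * pad_h - kernel_rows) // stride_h + 1
    output_cols = (input_cols + 2 * pad_w - kernel_cols) // stride_w + 1
    out = [[0] * output_cols for _ in range(output_rows)]
    for oh in range(output_rows):
        for ow in range(output_cols):
            for kh in range(kernel_rows):
                for kw in range(kernel_cols):
                    # Same address math as input_rows_for_k / input_cols_for_k.
                    r = oh * stride_h + kh - pad_h
                    c = ow * stride_w + kw - pad_w
                    # Mask with both bounds; outside reads act like other=0.0.
                    if 0 <= r < input_rows and 0 <= c < input_cols:
                        out[oh][ow] += kernel[kh][kw] * image[r][c]
    return out


ones5 = [[1] * 5 for _ in range(5)]
ones3 = [[1] * 3 for _ in range(3)]
print(conv2d_strided_padded(ones5, ones3, 2, 2, 1, 1))
# [[4, 6, 4], [6, 9, 6], [4, 6, 4]]
```

With an all-ones input, each output simply counts how many of its 9 taps land inside the image, so the corners (4) and edges (6) show where padding kicks in.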

