2D convolution in Triton
Today I looked at a small Triton kernel for single-channel 2D convolution (stride 1, no padding):
```python
import triton
import triton.language as tl


@triton.jit
def conv_2d(
    input,
    kernel,
    output,
    input_rows,
    input_cols,
    kernel_rows,
    kernel_cols,
    input_row_stride,
    input_col_stride,
    kernel_row_stride,
    kernel_col_stride,
    output_row_stride,
    output_col_stride,
    ROW_BLOCK: tl.constexpr,
    COL_BLOCK: tl.constexpr,
):
    row_pid = tl.program_id(0)
    col_pid = tl.program_id(1)
    row_start = ROW_BLOCK * row_pid
    col_start = COL_BLOCK * col_pid
    # compute output offsets for block (ROW_BLOCK, COL_BLOCK)
    output_row_offsets = row_start + tl.arange(0, ROW_BLOCK)
    output_col_offsets = col_start + tl.arange(0, COL_BLOCK)
    # valid output size for stride 1, no padding: (input - kernel + 1) in each dimension
    output_mask = (output_row_offsets < input_rows - kernel_rows + 1)[:, None] & (
        output_col_offsets < input_cols - kernel_cols + 1
    )[None, :]
    acc = tl.zeros((ROW_BLOCK, COL_BLOCK), dtype=tl.float32)
    for i in range(kernel_rows):
        for j in range(kernel_cols):
            # notice that we loop through each kernel value and compute the increment to each output element
            # that it can interact with
            k_val = tl.load(kernel + i * kernel_row_stride + j * kernel_col_stride)
            input_row_offsets = output_row_offsets + i
            input_col_offsets = output_col_offsets + j
            input_mask = (input_row_offsets < input_rows)[:, None] & (
                input_col_offsets < input_cols
            )[None, :]
            input_val = tl.load(
                input
                + input_row_offsets[:, None] * input_row_stride
                + input_col_offsets[None, :] * input_col_stride,
                mask=input_mask,
            )
            acc += k_val * input_val
    tl.store(
        output
        + output_row_offsets[:, None] * output_row_stride
        + output_col_offsets[None, :] * output_col_stride,
        acc,
        mask=output_mask,
    )
```
One program owns one output tile
The launch grid is two-dimensional:
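```python
output_rows = input_rows - kernel_rows + 1
output_cols = input_cols - kernel_cols + 1
grid = (
    triton.cdiv(output_rows, ROW_BLOCK),
    triton.cdiv(output_cols, COL_BLOCK),
)
```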
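Each Triton program computes a ROW_BLOCK x COL_BLOCK tile of the output. With ROW_BLOCK = COL_BLOCK = 32, the values used throughout this post, one program computes up to 32 x 32 = 1024 output elements. Tiles on the bottom and right edges may be only partially inside the output, and output_mask handles that case.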
```
                    output
             col_pid=0      col_pid=1      col_pid=2
          +--------------+--------------+--------------+
row_pid=0 |   32 x 32    |   32 x 32    |   32 x 32    |
          +--------------+--------------+--------------+
row_pid=1 |   32 x 32    |   32 x 32    |   32 x 32    |
          +--------------+--------------+--------------+
row_pid=2 |   32 x 32    |   32 x 32    |   32 x 32    |
          +--------------+--------------+--------------+
```
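The grid arithmetic can be replayed in plain Python. This sketch assumes a 70 x 70 output purely for illustration. The helper names cdiv and tile_bounds are hypothetical, and cdiv does the same ceiling division as triton.cdiv. It shows that edge tiles are only partly valid.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division: the same arithmetic as triton.cdiv."""
    return (a + b - 1) // b


def tile_bounds(row_pid: int, col_pid: int, output_rows: int, output_cols: int,
                row_block: int = 32, col_block: int = 32):
    """Half-open output ranges that program (row_pid, col_pid) actually writes."""
    row_start = row_block * row_pid
    col_start = col_block * col_pid
    # Rows/cols past the output edge are the ones output_mask switches off in the kernel.
    row_end = min(row_start + row_block, output_rows)
    col_end = min(col_start + col_block, output_cols)
    return (row_start, row_end), (col_start, col_end)


# Illustrative output size: 70 is not a multiple of 32, so the last tiles are partial.
output_rows, output_cols = 70, 70
grid = (cdiv(output_rows, 32), cdiv(output_cols, 32))
print(grid)                                          # (3, 3)
print(tile_bounds(2, 2, output_rows, output_cols))   # ((64, 70), (64, 70)): only 6 x 6 valid
```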
Inside one program:
```python
output_row_offsets = row_start + tl.arange(0, ROW_BLOCK)
output_col_offsets = col_start + tl.arange(0, COL_BLOCK)
acc = tl.zeros((ROW_BLOCK, COL_BLOCK), dtype=tl.float32)
```
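output_row_offsets is a vector of 32 row indexes and output_col_offsets is a vector of 32 column indexes. Indexing with [:, None] turns the row vector into a 32 x 1 column, and [None, :] turns the column vector into a 1 x 32 row. Adding the two after scaling by the strides broadcasts them into a 32 x 32 block of addresses.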
This is the usual Triton style: create offset vectors, broadcast them into a block, load a block, do block arithmetic, then store a block.
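Here is a small pure-Python emulation of what [:, None] and [None, :] produce, with nested lists in place of Triton tensors. block_addresses is a hypothetical helper. The 4 x 4 block, the row stride of 100 and the start offsets are illustrative only.

```python
def block_addresses(row_start, col_start, row_block, col_block, row_stride, col_stride):
    """Emulate row_offsets[:, None] * row_stride + col_offsets[None, :] * col_stride."""
    row_offsets = [row_start + r for r in range(row_block)]  # like row_start + tl.arange(0, ROW_BLOCK)
    col_offsets = [col_start + c for c in range(col_block)]  # like col_start + tl.arange(0, COL_BLOCK)
    # [:, None] makes the row index vary down the block, [None, :] makes the column vary across it.
    return [[r * row_stride + c * col_stride for c in col_offsets] for r in row_offsets]


addrs = block_addresses(row_start=4, col_start=8, row_block=4, col_block=4,
                        row_stride=100, col_stride=1)
for row in addrs:
    print(row)
# [408, 409, 410, 411]
# [508, 509, 510, 511]
# [608, 609, 610, 611]
# [708, 709, 710, 711]
```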
The convolution loop
For each kernel element (i, j), the program loads one scalar kernel value:
```python
k_val = tl.load(kernel + i * kernel_row_stride + j * kernel_col_stride)
```
Then it loads a 32 x 32 input tile shifted by (i, j):
```python
input_row_offsets = output_row_offsets + i
input_col_offsets = output_col_offsets + j
input_val = tl.load(
    input
    + input_row_offsets[:, None] * input_row_stride
    + input_col_offsets[None, :] * input_col_stride,
    mask=input_mask,
)
```
Then it accumulates:
```python
acc += k_val * input_val
```
So the mental model is:
```
for every kernel position:
    load one scalar kernel value
    load the input patch covered by the output tile at that kernel offset
    multiply the whole input tile by the scalar
    add it into the output accumulator
```
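The loop order can be checked in plain Python. The sketch below uses hypothetical helper names and nested lists instead of tensors. For brevity it treats the whole output as a single tile. It accumulates scalar times shifted tile in the same order as the kernel and compares the result with the usual per-output sliding-window sum. Like the kernel, it computes cross-correlation, so the kernel is not flipped.

```python
def conv2d_tile_order(inp, ker):
    """Accumulate k_val * shifted input tile, one kernel element at a time (the kernel's loop order)."""
    out_rows = len(inp) - len(ker) + 1
    out_cols = len(inp[0]) - len(ker[0]) + 1
    acc = [[0.0] * out_cols for _ in range(out_rows)]
    for i in range(len(ker)):
        for j in range(len(ker[0])):
            k_val = ker[i][j]                      # one scalar kernel load
            for r in range(out_rows):              # ...applied to the whole output tile
                for c in range(out_cols):
                    acc[r][c] += k_val * inp[r + i][c + j]
    return acc


def conv2d_direct(inp, ker):
    """Reference: each output element sums over its own kernel-sized window."""
    out_rows = len(inp) - len(ker) + 1
    out_cols = len(inp[0]) - len(ker[0]) + 1
    return [
        [
            sum(inp[r + i][c + j] * ker[i][j]
                for i in range(len(ker)) for j in range(len(ker[0])))
            for c in range(out_cols)
        ]
        for r in range(out_rows)
    ]


inp = [[float(r * 5 + c) for c in range(5)] for r in range(4)]   # 4 x 5 input
ker = [[1.0, 0.0, -1.0], [2.0, 0.0, -2.0]]                        # 2 x 3 kernel
print(conv2d_tile_order(inp, ker) == conv2d_direct(inp, ker))    # True
```

Both functions do the same multiply-adds. Only the loop nesting differs, and that nesting is what lets the Triton kernel turn each kernel element into one block operation.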
For a 3 x 3 kernel, one program performs 9 block loads from the input and 9 scalar loads from the kernel. For a 5 x 5 kernel, it performs 25 block loads from the input and 25 scalar loads from the kernel.
Memory access pattern
Assume the input is row-major contiguous:
```python
input_row_stride = input_cols
input_col_stride = 1
```
For a fixed kernel offset (i, j), the input addresses are:
```python
input
+ (output_row_offsets[:, None] + i) * input_cols
+ (output_col_offsets[None, :] + j)
```
That means each row of the 32 x 32 block is contiguous in memory:
```
input rows loaded by one program for one kernel element
row r+0:   x x x x x x x x ...   32 contiguous elements
row r+1:   x x x x x x x x ...   32 contiguous elements
row r+2:   x x x x x x x x ...   32 contiguous elements
...
row r+31:  x x x x x x x x ...   32 contiguous elements
```
Across rows there is a stride of input_cols. So the memory access is contiguous along columns and strided along rows. This is a natural layout for row-major images.
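To confirm the pattern, list the flat addresses for one kernel offset and group them into contiguous runs. This sketch assumes a row-major input with input_cols = 100 and uses a small 4 x 8 tile so the output stays readable. tile_addresses and contiguous_runs are hypothetical helpers.

```python
def tile_addresses(row_start, col_start, i, j, row_block, col_block, input_cols):
    """Flat addresses touched for kernel offset (i, j) on a row-major input (col stride 1)."""
    return [
        [(row_start + r + i) * input_cols + (col_start + c + j) for c in range(col_block)]
        for r in range(row_block)
    ]


def contiguous_runs(addresses):
    """Split a sorted address list into runs of consecutive integers."""
    runs = [[addresses[0]]]
    for a in addresses[1:]:
        if a == runs[-1][-1] + 1:
            runs[-1].append(a)
        else:
            runs.append([a])
    return runs


block = tile_addresses(row_start=0, col_start=0, i=1, j=2,
                       row_block=4, col_block=8, input_cols=100)
runs = contiguous_runs(sorted(a for row in block for a in row))
print(len(runs), [len(run) for run in runs])   # 4 [8, 8, 8, 8]: one run per tile row
print([run[0] for run in runs])                # [102, 202, 302, 402]: runs are input_cols apart
```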
The output store has the same shape:
```python
output
+ output_row_offsets[:, None] * output_row_stride
+ output_col_offsets[None, :] * output_col_stride
```
With:
```python
output_row_stride = output_cols
output_col_stride = 1
```
each output row is also contiguous.
So the kernel has a clean access pattern:
- Kernel memory: tiny scalar loads, each reused across the whole output tile.
- Input memory: one ROW_BLOCK x COL_BLOCK block load per kernel element.
- Output memory: one ROW_BLOCK x COL_BLOCK block store at the end.
The important part is that the output is not written repeatedly. The partial sum stays in acc until all kernel positions are consumed.
What gets reused
There are two kinds of reuse here.
The first reuse is the kernel scalar:
```python
k_val = tl.load(...)
acc += k_val * input_val
```
One scalar k_val is applied to all 32 x 32 input elements in the current block. That is good. The kernel is small, and the scalar load is cheap compared with the input tile load.
The second reuse is implicit overlap between neighboring output elements. In convolution, adjacent outputs share most of their input window. For example, two horizontally adjacent 3 x 3 outputs share 6 of their 9 input values.
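You can check the 6-of-9 figure by listing the input coordinates each output reads. window is a hypothetical helper. The coordinates (10, 20) and (10, 21) are arbitrary and stand for any two horizontally adjacent outputs.

```python
def window(r, c, kernel_rows=3, kernel_cols=3):
    """Input coordinates read by output (r, c) in a stride-1, unpadded convolution."""
    return {(r + i, c + j) for i in range(kernel_rows) for j in range(kernel_cols)}


left, right = window(10, 20), window(10, 21)
print(len(left & right), "of", len(left))   # 6 of 9

# For a k x k kernel, horizontal neighbours share k * (k - 1) inputs.
for k in (3, 5, 7):
    print(k, len(window(0, 0, k, k) & window(0, 1, k, k)))   # 3 6, then 5 20, then 7 42
```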
This kernel does not explicitly stage that shared input region into shared memory. Instead, each kernel offset loads a shifted 32 x 32 tile from global memory:
```
kernel offset (0, 0): load input rows r..r+31, cols c..c+31
kernel offset (0, 1): load input rows r..r+31, cols c+1..c+32
kernel offset (0, 2): load input rows r..r+31, cols c+2..c+33
```
Those loads overlap heavily. Hardware caches may catch some of that reuse, but the program itself is not organizing the input tile as:
```
(ROW_BLOCK + kernel_rows - 1) x (COL_BLOCK + kernel_cols - 1)
```
For a 32 x 32 output tile and a 3 x 3 kernel, the unique input region is only 34 x 34 = 1156 elements. The current code, however, loads 9 x 32 x 32 = 9216 input elements at the Triton block level, about 8x more. Many of those loads hit overlapping memory locations. (This is why we need explicit memory placement in TileLang!)
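The counts above can be reproduced in a few lines. The sketch below also shows the alternative: stage the (ROW_BLOCK + kernel_rows - 1) x (COL_BLOCK + kernel_cols - 1) halo tile once, then compute every output of the tile from it. This is plain Python standing in for on-chip memory, not Triton or TileLang code. load_counts, load_halo and conv_from_halo are hypothetical names, and the input values are arbitrary integers.

```python
def load_counts(row_block, col_block, kernel_rows, kernel_cols):
    """(elements loaded by the shifted-tile kernel, unique elements in the halo tile)."""
    loaded = kernel_rows * kernel_cols * row_block * col_block
    unique = (row_block + kernel_rows - 1) * (col_block + kernel_cols - 1)
    return loaded, unique


def load_halo(inp, row_start, col_start, row_block, col_block, kernel_rows, kernel_cols):
    """Copy the input region one output tile needs, exactly once (stand-in for staging it on chip)."""
    return [
        row[col_start:col_start + col_block + kernel_cols - 1]
        for row in inp[row_start:row_start + row_block + kernel_rows - 1]
    ]


def conv_from_halo(halo, ker, row_block, col_block):
    """Compute a full tile of outputs while reading only from the staged halo."""
    return [
        [
            sum(halo[r + i][c + j] * ker[i][j]
                for i in range(len(ker)) for j in range(len(ker[0])))
            for c in range(col_block)
        ]
        for r in range(row_block)
    ]


print(load_counts(32, 32, 3, 3))   # (9216, 1156)

inp = [[(r * 7 + c * 3) % 11 for c in range(40)] for r in range(40)]
ker = [[1, 2, 1], [0, 0, 0], [-1, -2, -1]]
halo = load_halo(inp, 0, 0, 32, 32, 3, 3)
tile = conv_from_halo(halo, ker, 32, 32)
print(len(halo), len(halo[0]))     # 34 34
```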
A production-style implicit GEMM convolution
For a faster convolution kernel, we can instead express convolution as an implicit matrix multiplication.
The conceptual matrix multiplication is:
```
output[M, OC] = input_im2col[M, K] @ weight[K, OC]
```
where:
```
M = batch * output_rows * output_cols
K = input_channels * kernel_rows * kernel_cols
```
We do not actually materialize input_im2col. The kernel computes the input addresses for the current M x K tile on the fly, loads those values, and feeds them into tl.dot.
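This version assumes contiguous NHWC input, HWIO kernel, and NHWC output. In NHWC the channel dimension is innermost, so consecutive k values that share a (kh, kw) position read consecutive input addresses. That makes the K dimension friendlier for vectorized loads and matrix multiply. In HWIO the output channels are innermost, so each row of the weight tile is contiguous along N.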
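Before the Triton version, the formulation can be checked with an explicit, materialized im2col in plain Python. This is a reference only, under the same assumptions as the kernel: NHWC input, HWIO weights, stride 1, no padding. The helper names are hypothetical and nested lists stand in for tensors. K is ordered as (kh, kw, ic) with ic fastest, which matches the decomposition the kernel uses.

```python
def im2col_nhwc(x, kernel_rows, kernel_cols):
    """Materialize input_im2col[M, K]: one row per output pixel, K ordered (kh, kw, ic)."""
    n, h, wd, c = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    out_rows, out_cols = h - kernel_rows + 1, wd - kernel_cols + 1
    rows = []
    for b in range(n):                      # m = (b * out_rows + r) * out_cols + q
        for r in range(out_rows):
            for q in range(out_cols):
                rows.append([x[b][r + kh][q + kw][ic]
                             for kh in range(kernel_rows)
                             for kw in range(kernel_cols)
                             for ic in range(c)])
    return rows


def hwio_as_matrix(w):
    """Flatten HWIO weights to weight[K, OC] using the same (kh, kw, ic) ordering of K."""
    return [w[kh][kw][ic] for kh in range(len(w)) for kw in range(len(w[0])) for ic in range(len(w[0][0]))]


def matmul(a, b):
    """Plain [M, K] @ [K, N] matrix product."""
    return [[sum(a_row[k] * b[k][o] for k in range(len(b))) for o in range(len(b[0]))] for a_row in a]


def conv2d_direct_nhwc(x, w):
    """Per-output reference, flattened to [M, OC] in the same m order for comparison."""
    kr, kc, c, oc = len(w), len(w[0]), len(w[0][0]), len(w[0][0][0])
    out = []
    for b in range(len(x)):
        for r in range(len(x[0]) - kr + 1):
            for q in range(len(x[0][0]) - kc + 1):
                out.append([sum(x[b][r + i][q + j][ic] * w[i][j][ic][o]
                                for i in range(kr) for j in range(kc) for ic in range(c))
                            for o in range(oc)])
    return out


# Tiny integer example: batch 2, 4 x 5 image, 3 input channels, 2 x 2 kernel, 4 output channels.
N, H, W, C, KR, KC, OC = 2, 4, 5, 3, 2, 2, 4
x = [[[[(b + 2 * r + 3 * q + 5 * ic) % 7 for ic in range(C)] for q in range(W)] for r in range(H)] for b in range(N)]
w = [[[[(i + j + ic - o) % 3 - 1 for o in range(OC)] for ic in range(C)] for j in range(KC)] for i in range(KR)]

a = im2col_nhwc(x, KR, KC)           # [M, K] = [2 * 3 * 4, 3 * 2 * 2]
print(len(a), len(a[0]))             # 24 12
print(matmul(a, hwio_as_matrix(w)) == conv2d_direct_nhwc(x, w))   # True
```

The Triton kernel computes the same product but never builds `a`. It regenerates each [BLOCK_M, BLOCK_K] slice of it from addresses on the fly.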
```python
import torch
import triton
import triton.language as tl


@triton.jit
def conv2d_implicit_gemm_nhwc_hwio(
    input,
    kernel,
    output,
    batch,
    input_rows,
    input_cols,
    input_channels: tl.constexpr,
    output_channels,
    kernel_rows: tl.constexpr,
    kernel_cols: tl.constexpr,
    output_rows,
    output_cols,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    m_offsets = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    n_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    k_offsets = tl.arange(0, BLOCK_K)
    output_spatial = output_rows * output_cols
    total_m = batch * output_spatial
    total_k: tl.constexpr = input_channels * kernel_rows * kernel_cols
    batch_offsets = m_offsets // output_spatial
    spatial_offsets = m_offsets % output_spatial
    out_rows = spatial_offsets // output_cols
    out_cols = spatial_offsets % output_cols
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k_start in range(0, total_k, BLOCK_K):
        k = k_start + k_offsets
        # translate k into indexes to read
        ic = k % input_channels
        kw = (k // input_channels) % kernel_cols
        kh = k // (input_channels * kernel_cols)
        input_rows_for_k = out_rows[:, None] + kh[None, :]
        input_cols_for_k = out_cols[:, None] + kw[None, :]
        input_channels_for_k = ic[None, :]
        # translate indexes to offsets
        input_ptrs = (
            input
            + batch_offsets[:, None] * input_rows * input_cols * input_channels
            + input_rows_for_k * input_cols * input_channels
            + input_cols_for_k * input_channels
            + input_channels_for_k
        )
        input_mask = (
            (m_offsets[:, None] < total_m)
            & (k[None, :] < total_k)
            & (input_rows_for_k < input_rows)
            & (input_cols_for_k < input_cols)
        )
        input_tile = tl.load(input_ptrs, mask=input_mask, other=0.0)
        weight_ptrs = (
            kernel
            + kh[:, None] * kernel_cols * input_channels * output_channels
            + kw[:, None] * input_channels * output_channels
            + ic[:, None] * output_channels
            + n_offsets[None, :]
        )
        weight_mask = (k[:, None] < total_k) & (n_offsets[None, :] < output_channels)
        weight_tile = tl.load(weight_ptrs, mask=weight_mask, other=0.0)
        acc += tl.dot(input_tile, weight_tile, input_precision="tf32")
    output_ptrs = (
        output
        + batch_offsets[:, None] * output_rows * output_cols * output_channels
        + out_rows[:, None] * output_cols * output_channels
        + out_cols[:, None] * output_channels
        + n_offsets[None, :]
    )
    output_mask = (m_offsets[:, None] < total_m) & (
        n_offsets[None, :] < output_channels
    )
    tl.store(output_ptrs, acc, mask=output_mask)
```
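The index math is the part most worth double-checking. The sketch below replays it in plain Python for one small shape, with sizes chosen arbitrarily for illustration. It makes three checks:

- k maps one-to-one onto (kh, kw, ic).
- The kernel's HWIO weight offset simplifies to k * output_channels + oc.
- The NHWC input offset matches ordinary row-major flattening.

decompose_k is a hypothetical helper that copies the kernel's arithmetic.

```python
from itertools import product

input_channels, kernel_rows, kernel_cols, output_channels = 3, 2, 4, 5
input_rows, input_cols = 6, 7
total_k = input_channels * kernel_rows * kernel_cols


def decompose_k(k):
    """Same arithmetic as the kernel: ic varies fastest, then kw, then kh."""
    ic = k % input_channels
    kw = (k // input_channels) % kernel_cols
    kh = k // (input_channels * kernel_cols)
    return kh, kw, ic


# 1) k -> (kh, kw, ic) visits every kernel tap exactly once.
taps = [decompose_k(k) for k in range(total_k)]
assert sorted(taps) == list(product(range(kernel_rows), range(kernel_cols), range(input_channels)))

# 2) The kernel's HWIO weight offset collapses to k * output_channels + oc.
for k, oc in product(range(total_k), range(output_channels)):
    kh, kw, ic = decompose_k(k)
    weight_offset = (kh * kernel_cols * input_channels * output_channels
                     + kw * input_channels * output_channels
                     + ic * output_channels + oc)
    assert weight_offset == k * output_channels + oc

# 3) The NHWC input offset equals row-major flattening of [batch, rows, cols, channels].
for b, r, c, ic in product(range(2), range(input_rows), range(input_cols), range(input_channels)):
    offset = (b * input_rows * input_cols * input_channels
              + r * input_cols * input_channels + c * input_channels + ic)
    assert offset == ((b * input_rows + r) * input_cols + c) * input_channels + ic

print("index math checks passed")
```

Check 2 also shows that the weight matrix is just the HWIO buffer read as a contiguous [K, OC] matrix.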
The implicit GEMM kernel turns convolution into the form GPUs are best at:
```python
acc += tl.dot(input_tile, weight_tile)
```
The tiles are:
```
input_tile:  [BLOCK_M, BLOCK_K]
weight_tile: [BLOCK_K, BLOCK_N]
acc:         [BLOCK_M, BLOCK_N]
```
That gives much better reuse than the scalar version:
- Each input_tile value is reused across BLOCK_N output channels.
- Each weight_tile value is reused across BLOCK_M output spatial positions.
- The accumulator stays in registers until the full K dimension is reduced.
- fp16/bf16 inputs, and fp32 inputs under input_precision="tf32", can map to tensor-core matrix instructions through tl.dot on GPUs that support them.
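One rough way to quantify the reuse is to count elements loaded per multiply-accumulate in one step of each kernel's inner loop. This back-of-the-envelope sketch ignores caches, masks and data types, and the function names are hypothetical. It uses BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 32 for the GEMM version and 32 x 32 tiles for the scalar version. The two kernels solve different problems (single channel vs multi-channel), so the numbers only illustrate per-step reuse.

```python
def scalar_version_step(row_block=32, col_block=32):
    """One kernel element (i, j): one scalar plus one input tile loaded, one MAC per output."""
    loads = 1 + row_block * col_block
    macs = row_block * col_block
    return loads, macs


def implicit_gemm_step(block_m=64, block_n=64, block_k=32):
    """One K step: an input tile and a weight tile feed a BLOCK_M x BLOCK_N x BLOCK_K dot."""
    loads = block_m * block_k + block_k * block_n
    macs = block_m * block_n * block_k
    return loads, macs


for name, (loads, macs) in [("scalar", scalar_version_step()), ("implicit GEMM", implicit_gemm_step())]:
    print(f"{name}: {loads} elements loaded, {macs} MACs, {macs / loads:.1f} MACs per element")
# scalar: 1025 elements loaded, 1024 MACs, 1.0 MACs per element
# implicit GEMM: 4096 elements loaded, 131072 MACs, 32.0 MACs per element
```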
There are still details to tune per GPU and shape. BLOCK_M = 64, BLOCK_N = 64, and BLOCK_K = 32 are starting points, not universal constants. For small channel counts, smaller BLOCK_K or BLOCK_N may win. For large output channels, BLOCK_N = 128 can be better. For fp32-heavy workloads, register pressure and occupancy become more important.
This kernel also assumes stride 1, no padding, and dilation 1. Padding and stride are easy to add to the address math:
```python
input_rows_for_k = out_rows[:, None] * stride_h + kh[None, :] - pad_h
input_cols_for_k = out_cols[:, None] * stride_w + kw[None, :] - pad_w
```
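The input mask must then check that the computed input coordinates lie inside [0, input_rows) and [0, input_cols). Padding can push them below zero, so a lower-bound check is needed in addition to the existing upper-bound one. The masked load's other=0.0 then supplies the zero padding. The output size changes too: output_rows = (input_rows + 2 * pad_h - kernel_rows) // stride_h + 1, and likewise for output_cols.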
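Here is a plain-Python reference for the generalized address math. To keep it short it assumes symmetric zero padding, dilation 1 and a single channel, and the function name is hypothetical. It uses the output-size formula above and the same two-sided bounds check the mask needs. Out-of-bounds reads contribute nothing, as a masked load with other=0.0 would.

```python
def conv2d_strided_padded(inp, ker, stride_h=1, stride_w=1, pad_h=0, pad_w=0):
    """Single-channel cross-correlation with stride and zero padding (dilation 1)."""
    input_rows, input_cols = len(inp), len(inp[0])
    kernel_rows, kernel_cols = len(ker), len(ker[0])
    output_rows = (input_rows + 2 * pad_h - kernel_rows) // stride_h + 1
    output_cols = (input_cols + 2 * pad_w - kernel_cols) // stride_w + 1
    out = [[0.0] * output_cols for _ in range(output_rows)]
    for out_row in range(output_rows):
        for out_col in range(output_cols):
            for kh in range(kernel_rows):
                for kw in range(kernel_cols):
                    # Same address math as the modified kernel.
                    in_row = out_row * stride_h + kh - pad_h
                    in_col = out_col * stride_w + kw - pad_w
                    # Two-sided mask: padding can push coordinates below 0.
                    if 0 <= in_row < input_rows and 0 <= in_col < input_cols:
                        out[out_row][out_col] += inp[in_row][in_col] * ker[kh][kw]
                    # Otherwise the tap contributes 0, like a masked load with other=0.0.
    return out


inp = [[1.0, 2.0, 3.0, 4.0],
       [5.0, 6.0, 7.0, 8.0],
       [9.0, 10.0, 11.0, 12.0],
       [13.0, 14.0, 15.0, 16.0]]
ones = [[1.0] * 3 for _ in range(3)]
print(conv2d_strided_padded(inp, ones, stride_h=2, stride_w=2, pad_h=1, pad_w=1))
# [[14.0, 30.0], [57.0, 99.0]]
```

With the defaults (stride 1, no padding) it reduces to the valid convolution the kernel above computes.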