RoPE in TileLang
Today I wrote a small TileLang kernel for rotary positional embedding (RoPE) as part of my Qwen Inference with TileLang project.
The input tensor is laid out as:
X: [N, S, H, D]
- N: batch size
- S: sequence length
- H: number of heads
- D: head dimension
RoPE overview
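RoPE applies position information by rotating pairs of hidden dimensions. For a head vector $x \in \mathbb{R}^D$ at position $m$, split the vector into two halves:
$$x_{\text{real}} = (x_0, \dots, x_{D/2-1}), \qquad x_{\text{imag}} = (x_{D/2}, \dots, x_{D-1})$$
and treat each pair as a complex number, with $\mathrm{j}$ the imaginary unit:
$$z_i = x_i + \mathrm{j}\, x_{i+D/2}, \qquad i = 0, \dots, D/2 - 1$$
Each pair is rotated by an angle that depends on the token position $m$ and the rotary dimension $i$:
$$m\theta_i, \qquad \theta_i = 10000^{-i/(D/2)}$$
The rotation is:
$$z_i' = z_i\, e^{\mathrm{j} m\theta_i}$$
So the elementwise formulas are:
$$x_i' = x_i \cos(m\theta_i) - x_{i+D/2} \sin(m\theta_i)$$
$$x_{i+D/2}' = x_{i+D/2} \cos(m\theta_i) + x_i \sin(m\theta_i)$$
For the full vector, this is equivalent to multiplying by a block-diagonal rotation matrix once the coordinates are listed pair by pair, $(x_0, x_{D/2}, x_1, x_{D/2+1}, \dots)$:
$$x' = R_m x, \qquad R_m = \operatorname{diag}\big(R(m\theta_0), R(m\theta_1), \dots, R(m\theta_{D/2-1})\big)$$
where:
$$R(\phi) = \begin{pmatrix} \cos\phi & -\sin\phi \\ \sin\phi & \cos\phi \end{pmatrix}$$
In the split-half layout used by this kernel, I do not physically build that matrix. Instead, I compute two basis tensors of shape [S, D // 2], indexed by position $m$ and rotary dimension $i$:
$$C_{m,i} = \cos(m\theta_i), \qquad S_{m,i} = \sin(m\theta_i)$$
and apply the same rotation with elementwise operations:
$$x'_{\text{real}} = x_{\text{real}} \odot C - x_{\text{imag}} \odot S, \qquad x'_{\text{imag}} = x_{\text{imag}} \odot C + x_{\text{real}} \odot S$$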
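As a sanity check on the derivation, here is a small pure-Python sketch (no TileLang or PyTorch; the helper names thetas, rope_elementwise, rope_complex and rope_matrix are hypothetical) that rotates one head vector three ways: with the split-half elementwise formulas, with complex multiplication, and with an explicitly built D × D matrix in split-half order. All three agree, which is why the kernel can skip the O(D²) matrix-vector product and do O(D) elementwise work instead.

```python
import cmath
import math


def thetas(d: int, base: float = 10000.0) -> list[float]:
    """theta_i = base ** (-i / (D/2)) for i in [0, D/2)."""
    half = d // 2
    return [base ** (-i / half) for i in range(half)]


def rope_elementwise(x: list[float], m: int) -> list[float]:
    """Split-half RoPE: rotate the pair (x[i], x[i + D/2]) by m * theta_i."""
    half = len(x) // 2
    real, imag = x[:half], x[half:]
    out_real, out_imag = [], []
    for i, th in enumerate(thetas(len(x))):
        c, s = math.cos(m * th), math.sin(m * th)
        out_real.append(real[i] * c - imag[i] * s)
        out_imag.append(imag[i] * c + real[i] * s)
    return out_real + out_imag


def rope_complex(x: list[float], m: int) -> list[float]:
    """Same rotation written as z_i * exp(j * m * theta_i)."""
    half = len(x) // 2
    z = [complex(x[i], x[i + half]) * cmath.exp(1j * m * th)
         for i, th in enumerate(thetas(len(x)))]
    return [v.real for v in z] + [v.imag for v in z]


def rope_matrix(x: list[float], m: int) -> list[float]:
    """Build the full D x D rotation matrix (split-half order) and multiply."""
    d, half = len(x), len(x) // 2
    r = [[0.0] * d for _ in range(d)]
    for i, th in enumerate(thetas(d)):
        c, s = math.cos(m * th), math.sin(m * th)
        # 2x2 block [[c, -s], [s, c]] acting on the pair (x[i], x[i + half])
        r[i][i], r[i][i + half] = c, -s
        r[i + half][i], r[i + half][i + half] = s, c
    return [sum(r[row][col] * x[col] for col in range(d)) for row in range(d)]


x = [0.5, -1.0, 2.0, 0.25, 1.5, -0.75, 0.1, 3.0]
m = 7
a, b, c = rope_elementwise(x, m), rope_complex(x, m), rope_matrix(x, m)
print(all(math.isclose(p, q, abs_tol=1e-12) and math.isclose(p, r, abs_tol=1e-12)
          for p, q, r in zip(a, b, c)))  # True
```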
Note that the multiplications between x_real / x_imag and C / S are elementwise. C and S have shape [S, D // 2] (sequence length by rotary dimension), so they are broadcast over the batch (N) and head (H) dimensions. That maps directly onto a kernel: each program owns a [BLOCK_N, BLOCK_S, BLOCK_H, BLOCK_D] tile of the [N, S, H, D // 2] problem, loads the matching real and imaginary halves, computes the corresponding [BLOCK_S, BLOCK_D] cosine and sine tiles, and writes both halves of the rotated output. The elementwise part is easy to write in TileLang: T.Parallel expresses a loop over every element of a tile, and the compiler maps those iterations onto threads.
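To make the tile decomposition concrete, here is a pure-Python emulation of the launch. It is a sketch, not TileLang: rope_tiled, rope_untiled and the example sizes are hypothetical, and each iteration of the grid loop plays the role of one program. The basis tile is built once per program from (s, d) only and reused for every n and h, and partial edge tiles are masked. Like the kernel's slicing, it assumes D // 2 is a multiple of BLOCK_D.

```python
import itertools
import math


def rope_tiled(x, offset, block_n, block_s, block_h, block_d):
    """Emulate the tiled launch on a nested list x[n][s][h][d] (pure Python)."""
    n_sz, s_sz, h_sz, d_sz = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    half_d = d_sz // 2
    assert half_d % block_d == 0, "sketch assumes BLOCK_D divides D // 2"
    out = [[[[0.0] * d_sz for _ in range(h_sz)] for _ in range(s_sz)] for _ in range(n_sz)]
    grid = itertools.product(
        range(-(-n_sz // block_n)),  # ceildiv
        range(-(-s_sz // block_s)),
        range(-(-h_sz // block_h)),
        range(half_d // block_d),
    )
    for pid_n, pid_s, pid_h, pid_d in grid:  # one iteration = one program
        # Basis tile [block_s][block_d]: depends on (s, d) only, never on n or h.
        cos_t = [[0.0] * block_d for _ in range(block_s)]
        sin_t = [[0.0] * block_d for _ in range(block_s)]
        for s, d in itertools.product(range(block_s), range(block_d)):
            seq_idx = offset + pid_s * block_s + s
            dim_idx = pid_d * block_d + d
            freq = 10000.0 ** (-dim_idx / half_d)
            cos_t[s][d] = math.cos(seq_idx * freq)
            sin_t[s][d] = math.sin(seq_idx * freq)
        # Elementwise rotation; the same basis entry is reused for every n and h.
        for n, s, h, d in itertools.product(
            range(block_n), range(block_s), range(block_h), range(block_d)
        ):
            gn = pid_n * block_n + n
            gs = pid_s * block_s + s
            gh = pid_h * block_h + h
            if gn >= n_sz or gs >= s_sz or gh >= h_sz:
                continue  # lanes that fall outside a partial edge tile
            vec, out_vec = x[gn][gs][gh], out[gn][gs][gh]
            i = pid_d * block_d + d
            re, im = vec[i], vec[i + half_d]
            out_vec[i] = re * cos_t[s][d] - im * sin_t[s][d]
            out_vec[i + half_d] = im * cos_t[s][d] + re * sin_t[s][d]
    return out


def rope_untiled(x, offset):
    """Direct split-half RoPE over the whole tensor, for comparison."""
    half_d = len(x[0][0][0]) // 2
    freqs = [10000.0 ** (-i / half_d) for i in range(half_d)]
    out = []
    for batch in x:
        out.append([])
        for s, seq in enumerate(batch):
            out[-1].append([])
            for vec in seq:
                xr, xi = vec[:half_d], vec[half_d:]
                angles = [(offset + s) * f for f in freqs]
                out[-1][-1].append(
                    [r * math.cos(a) - q * math.sin(a) for r, q, a in zip(xr, xi, angles)]
                    + [q * math.cos(a) + r * math.sin(a) for r, q, a in zip(xr, xi, angles)]
                )
    return out


# Example sizes (hypothetical): S and H are not multiples of their block sizes.
N, S, H, D = 2, 5, 3, 8
x = [[[[math.sin(n + 2 * s + 3 * h + 0.1 * d) for d in range(D)]
       for h in range(H)] for s in range(S)] for n in range(N)]
tiled = rope_tiled(x, offset=4, block_n=1, block_s=2, block_h=2, block_d=2)
print(tiled == rope_untiled(x, offset=4))  # True
```

Because both versions perform the same floating-point operations in the same order, the comparison can be exact here; against a real GPU kernel you would compare with a tolerance instead.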
RoPE treats each head vector as two halves:
```python
x_real = X[..., : D // 2]
x_imag = X[..., D // 2 :]
```
Then each pair is rotated by a position-dependent angle:
```python
real = x_real * cos - x_imag * sin
imag = x_imag * cos + x_real * sin
```
Combining grid dimensions
My first kernel tried to launch one grid dimension for every tensor axis:
```python
with T.Kernel(
    T.ceildiv(N, BLOCK_N),
    T.ceildiv(S, BLOCK_S),
    T.ceildiv(H, BLOCK_H),
    T.ceildiv(half_D, BLOCK_D),
    threads=256,
) as (pid_n, pid_s, pid_h, pid_d):
    ...  # kernel body
```
That is a natural way to think about the work, but TileLang follows the CUDA launch model: the grid can have at most three dimensions.
The fix was to pack the head block and dimension block into one physical grid axis:
```python
num_h_blocks = T.ceildiv(H, BLOCK_H)
num_d_blocks = T.ceildiv(half_D, BLOCK_D)
with T.Kernel(
    T.ceildiv(N, BLOCK_N),
    T.ceildiv(S, BLOCK_S),
    num_h_blocks * num_d_blocks,
    threads=256,
) as (pid_n, pid_s, pid_hd):
    # Decode the packed axis back into the logical (head block, dim block) pair.
    pid_h = pid_hd // num_d_blocks
    pid_d = pid_hd % num_d_blocks
```
This keeps N and S as explicit grid axes, while pid_hd decodes back into the logical (H, D // 2) tile.
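A quick way to convince yourself the packing is lossless is to enumerate it in plain Python. This sketch (hypothetical shapes and block sizes, no TileLang) decodes every pid_hd with the same // and % as the kernel and checks that each logical (pid_h, pid_d) tile appears exactly once, in row-major order. The inverse mapping is pid_hd = pid_h * num_d_blocks + pid_d.

```python
def decode(pid_hd: int, num_d_blocks: int) -> tuple[int, int]:
    """Mirror of the kernel: split the packed axis into (pid_h, pid_d)."""
    return pid_hd // num_d_blocks, pid_hd % num_d_blocks


H, half_D = 12, 64          # example shapes (hypothetical)
BLOCK_H, BLOCK_D = 4, 16    # example tile sizes (hypothetical)
num_h_blocks = (H + BLOCK_H - 1) // BLOCK_H        # ceildiv -> 3
num_d_blocks = (half_D + BLOCK_D - 1) // BLOCK_D   # ceildiv -> 4

tiles = [decode(pid_hd, num_d_blocks) for pid_hd in range(num_h_blocks * num_d_blocks)]
expected = [(h, d) for h in range(num_h_blocks) for d in range(num_d_blocks)]
print(tiles == expected)  # True
# Encoding goes the other way: pid_hd = pid_h * num_d_blocks + pid_d
print(all(h * num_d_blocks + d == k for k, (h, d) in enumerate(tiles)))  # True
```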
Position index vs dimension index
Another easy mistake was mixing up the sequence index and the dimension index.
The RoPE angle (with base = 10000) is:
angle = position * base ^ (-dim / half_dim)
So the sequence position sets how far a pair is rotated, while the dimension index selects the rotation frequency:
```python
seq_idx = offset + pid_s * BLOCK_S + s
dim_idx = pid_d * BLOCK_D + d
freq = T.pow(10000.0, -dim_idx.astype("float32") / half_D)
cos_basis[s, d] = T.cos(seq_idx * freq)
sin_basis[s, d] = T.sin(seq_idx * freq)
```
The basis must be indexed by both s and d. Computing only one value per sequence position is not enough, because every rotary dimension has a different frequency.
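To see why the basis needs both indices, here is a plain-Python sketch (the rope_basis helper and the sizes are illustrative, not part of the kernel) that builds the [S, D // 2] tables with the same position/dimension roles as the kernel. With D = 8 the exponents are 0, -0.25, -0.5 and -0.75, so the frequencies are roughly 1, 0.1, 0.01 and 0.001: a single row holds four different angles.

```python
import math


def rope_basis(seq_len: int, half_d: int, offset: int = 0):
    """Return (cos, sin) tables of shape [seq_len][half_d]."""
    freqs = [10000.0 ** (-i / half_d) for i in range(half_d)]  # one per rotary dim
    cos = [[math.cos((offset + s) * f) for f in freqs] for s in range(seq_len)]
    sin = [[math.sin((offset + s) * f) for f in freqs] for s in range(seq_len)]
    return cos, sin


cos_basis, sin_basis = rope_basis(seq_len=4, half_d=4)
# Position 0: every angle is 0, so the whole row is cos = 1, sin = 0.
print(cos_basis[0])  # [1.0, 1.0, 1.0, 1.0]
# Position 3: one token, but a different angle (and cosine) per rotary dim.
print(len(set(cos_basis[3])))  # 4
```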
Integer division is not float division
TVM, which TileLang builds on, rejects / between two integer expressions, because it is ambiguous whether floor or truncating division is meant:
```python
freq = T.pow(10000.0, -dim_idx / half_D)
```
This failed because both dim_idx and half_D are integer expressions. RoPE needs a real-valued exponent, so dim_idx has to be cast before division:
```python
freq = T.pow(10000.0, -dim_idx.astype("float32") / half_D)
```
This is one of those compiler details that is easy to miss when translating from PyTorch, where / automatically means floating-point division.
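Plain Python shows what would go wrong if the exponent were computed with integer division. This only illustrates the arithmetic, not TVM itself: Python's // is floor division and int() truncates toward zero. With either integer rule, all rotary dimensions collapse onto one or two frequencies instead of D // 2 distinct ones.

```python
half_D = 4
dims = range(half_D)

true_exps = [-i / half_D for i in dims]        # real-valued exponents
floor_exps = [(-i) // half_D for i in dims]    # floor division
trunc_exps = [int(-i / half_D) for i in dims]  # truncation toward zero

print(true_exps)   # [0.0, -0.25, -0.5, -0.75]
print(floor_exps)  # [0, -1, -1, -1]
print(trunc_exps)  # [0, 0, 0, 0]

# Only the real-valued exponent gives D // 2 distinct frequencies.
print(len({10000.0 ** e for e in true_exps}))   # 4
print(len({10000.0 ** e for e in floor_exps}))  # 2
print(len({10000.0 ** e for e in trunc_exps}))  # 1
```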
Writing both halves
Since the kernel rotates pairs from the first and second halves of the head dimension, the output also has to be written in two pieces:
```python
# "..." abbreviates the N, S and H slices spelled out in the complete kernel.
T.copy(O_local_real, O[..., pid_d * BLOCK_D : (pid_d + 1) * BLOCK_D])
T.copy(
    O_local_imag,
    O[..., pid_d * BLOCK_D + half_D : (pid_d + 1) * BLOCK_D + half_D],
)
```
This is cleaner than trying to store both halves into a local buffer whose last dimension is only BLOCK_D.
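The write offsets can be checked without a GPU. This plain-Python sketch (hypothetical sizes; written_indices is an illustrative helper) lists the head-dimension indices written by every pid_d, using the same slice bounds as the two T.copy calls, and assumes out-of-range elements are simply skipped at the tensor bound D. When BLOCK_D divides D // 2, every index in [0, D) is written exactly once. When it does not, the first slice of the last tile runs past half_D into the second half, so some indices are written twice; the kernel as written does not check for this.

```python
def written_indices(D: int, BLOCK_D: int) -> list[int]:
    """Head-dim indices written by all pid_d tiles (first-half + second-half slice)."""
    half_D = D // 2
    num_d_blocks = (half_D + BLOCK_D - 1) // BLOCK_D  # same as T.ceildiv
    idx = []
    for pid_d in range(num_d_blocks):
        real = range(pid_d * BLOCK_D, (pid_d + 1) * BLOCK_D)
        imag = range(pid_d * BLOCK_D + half_D, (pid_d + 1) * BLOCK_D + half_D)
        # Assumption: elements past the tensor bound D are skipped.
        idx += [i for i in real if i < D] + [i for i in imag if i < D]
    return sorted(idx)


print(written_indices(D=16, BLOCK_D=4) == list(range(16)))  # True
# half_D = 6 is not a multiple of 4: indices 6 and 7 are written twice.
print(written_indices(D=12, BLOCK_D=4))  # [0, 1, 2, 3, 4, 5, 6, 6, 7, 7, 8, 9, 10, 11]
```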
Appendix
Complete kernel:
```python
import tilelang
import tilelang.language as T


@tilelang.jit
def rope(X, offset, BLOCK_N, BLOCK_S, BLOCK_H, BLOCK_D):
    N, S, H, D = T.const("N, S, H, D")
    dtype = T.float32
    X: T.Tensor((N, S, H, D), dtype)
    O = T.empty((N, S, H, D), dtype)
    # Derived constants, computed outside the kernel body.
    half_D = D // 2
    num_h_blocks = T.ceildiv(H, BLOCK_H)
    num_d_blocks = T.ceildiv(half_D, BLOCK_D)
    with T.Kernel(
        T.ceildiv(N, BLOCK_N),
        T.ceildiv(S, BLOCK_S),
        # The launch grid has at most 3 dimensions (CUDA launch model),
        # so the H and D // 2 block indices are packed into one axis.
        num_h_blocks * num_d_blocks,
        threads=256,
    ) as (pid_n, pid_s, pid_hd):
        pid_h = pid_hd // num_d_blocks
        pid_d = pid_hd % num_d_blocks
        X_local_real = T.alloc_fragment((BLOCK_N, BLOCK_S, BLOCK_H, BLOCK_D), dtype)
        X_local_imag = T.alloc_fragment((BLOCK_N, BLOCK_S, BLOCK_H, BLOCK_D), dtype)
        O_local_real = T.alloc_fragment((BLOCK_N, BLOCK_S, BLOCK_H, BLOCK_D), dtype)
        O_local_imag = T.alloc_fragment((BLOCK_N, BLOCK_S, BLOCK_H, BLOCK_D), dtype)
        cos_basis = T.alloc_fragment((BLOCK_S, BLOCK_D), dtype)
        sin_basis = T.alloc_fragment((BLOCK_S, BLOCK_D), dtype)
        # One (cos, sin) value per (position, rotary dim) in this tile.
        for s, d in T.Parallel(BLOCK_S, BLOCK_D):
            seq_idx = offset + pid_s * BLOCK_S + s
            dim_idx = pid_d * BLOCK_D + d
            freq = T.pow(10000.0, -dim_idx.astype("float32") / half_D)
            cos_basis[s, d] = T.cos(seq_idx * freq)
            sin_basis[s, d] = T.sin(seq_idx * freq)
        # For each block, we process
        #   x[..., BLOCK_D * pid_d : BLOCK_D * (pid_d + 1)]
        # and
        #   x[..., BLOCK_D * pid_d + half_D : BLOCK_D * (pid_d + 1) + half_D]
        # (this assumes half_D % BLOCK_D == 0, so the first slice never
        # reaches into the second half of the head dimension).
        n_blk_id = pid_n
        s_blk_id = pid_s
        T.copy(
            X[
                n_blk_id * BLOCK_N : (n_blk_id + 1) * BLOCK_N,
                s_blk_id * BLOCK_S : (s_blk_id + 1) * BLOCK_S,
                pid_h * BLOCK_H : (pid_h + 1) * BLOCK_H,
                pid_d * BLOCK_D : (pid_d + 1) * BLOCK_D,
            ],
            X_local_real,
        )
        T.copy(
            X[
                n_blk_id * BLOCK_N : (n_blk_id + 1) * BLOCK_N,
                s_blk_id * BLOCK_S : (s_blk_id + 1) * BLOCK_S,
                pid_h * BLOCK_H : (pid_h + 1) * BLOCK_H,
                pid_d * BLOCK_D + half_D : (pid_d + 1) * BLOCK_D + half_D,
            ],
            X_local_imag,
        )
        # Rotate every element; the [s, d] basis is broadcast over n and h.
        for n, s, h, d in T.Parallel(BLOCK_N, BLOCK_S, BLOCK_H, BLOCK_D):
            real = (
                X_local_real[n, s, h, d] * cos_basis[s, d]
                - X_local_imag[n, s, h, d] * sin_basis[s, d]
            )
            imag = (
                X_local_imag[n, s, h, d] * cos_basis[s, d]
                + X_local_real[n, s, h, d] * sin_basis[s, d]
            )
            O_local_real[n, s, h, d] = real
            O_local_imag[n, s, h, d] = imag
        T.copy(
            O_local_real,
            O[
                n_blk_id * BLOCK_N : (n_blk_id + 1) * BLOCK_N,
                s_blk_id * BLOCK_S : (s_blk_id + 1) * BLOCK_S,
                pid_h * BLOCK_H : (pid_h + 1) * BLOCK_H,
                pid_d * BLOCK_D : (pid_d + 1) * BLOCK_D,
            ],
        )
        T.copy(
            O_local_imag,
            O[
                n_blk_id * BLOCK_N : (n_blk_id + 1) * BLOCK_N,
                s_blk_id * BLOCK_S : (s_blk_id + 1) * BLOCK_S,
                pid_h * BLOCK_H : (pid_h + 1) * BLOCK_H,
                pid_d * BLOCK_D + half_D : (pid_d + 1) * BLOCK_D + half_D,
            ],
        )
    return O
```
The Torch reference I tested against is:
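```python
import torch


def rope_reference(X: torch.Tensor, offset: int) -> torch.Tensor:
    """Split-half RoPE on X of shape [N, S, H, D], positions starting at offset."""
    _, S, _, D = X.shape
    half_d = D // 2
    positions = torch.arange(offset, offset + S, dtype=torch.float32, device=X.device)
    dims = torch.arange(half_d, dtype=torch.float32, device=X.device)
    freqs = torch.pow(10000.0, -dims / half_d)
    # angles[s, i] = position_s * freq_i, shape [S, half_d]
    angles = torch.outer(positions, freqs)
    # Reshape to [1, S, 1, half_d] so the basis broadcasts over N and H.
    cos_basis = torch.cos(angles).reshape(1, S, 1, half_d)
    sin_basis = torch.sin(angles).reshape(1, S, 1, half_d)
    x_real = X[..., :half_d]
    x_imag = X[..., half_d:]
    real = x_real * cos_basis - x_imag * sin_basis
    imag = x_imag * cos_basis + x_real * sin_basis
    return torch.cat([real, imag], dim=-1)
```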