Retep's

[TIL] Collected: RoPE in TileLang


RoPE in TileLang

Today I wrote a small TileLang kernel for rotary positional embedding (RoPE), as part of my project implementing Qwen inference with TileLang.

The input tensor is laid out as:

X: [N, S, H, D]

with batch size N, sequence length S, H heads, and head dimension D.

RoPE overview

RoPE applies position information by rotating pairs of hidden dimensions. For a head vector at position $p$, split the vector into two halves:

$$x = [x_0, x_1, \ldots, x_{d/2-1}, x_{d/2}, \ldots, x_{d-1}]$$

and treat each pair as:

$$(a_i, b_i) = (x_i, x_{i + d/2})$$

Each pair is rotated by an angle that depends on the token position $p$ and the rotary dimension $i$:

$$\theta_i = 10000^{-i/(d/2)}, \qquad \phi_{p,i} = p \cdot \theta_i$$
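To make the frequency schedule concrete, here is a small plain-Python sketch (not TileLang) that evaluates $\phi_{p,i}$ for one position. The helper name `rope_angles` and the sizes are made up for illustration.

```python
import math


def rope_angles(p: int, d: int, base: float = 10000.0) -> list[float]:
    """Return phi_{p,i} = p * base^(-i / (d/2)) for i = 0 .. d/2 - 1."""
    half = d // 2
    # theta_i decays geometrically from 1 (i = 0) towards 1/base
    thetas = [base ** (-i / half) for i in range(half)]
    return [p * t for t in thetas]


angles = rope_angles(p=3, d=8)
# dimension 0 rotates by p radians; each later dimension rotates 10x slower here
print([round(a, 4) for a in angles])  # → [3.0, 0.3, 0.03, 0.003]
```

With d = 8 the exponent steps by 1/4, so consecutive frequencies differ by a factor of $10000^{1/4} = 10$.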

The rotation is:

$$\begin{bmatrix} a'_i \\ b'_i \end{bmatrix} = \begin{bmatrix} \cos \phi_{p,i} & -\sin \phi_{p,i} \\ \sin \phi_{p,i} & \cos \phi_{p,i} \end{bmatrix} \begin{bmatrix} a_i \\ b_i \end{bmatrix}$$

So the elementwise formulas are:

$$a'_i = a_i \cos \phi_{p,i} - b_i \sin \phi_{p,i}$$

$$b'_i = b_i \cos \phi_{p,i} + a_i \sin \phi_{p,i}$$
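A quick plain-Python check of the pair rotation: the elementwise formulas and the 2x2 matrix give the same result, and the rotation preserves the length of the pair. `rotate_pair` and `rotate_pair_matrix` are hypothetical helper names used only for this illustration.

```python
import math


def rotate_pair(a: float, b: float, phi: float) -> tuple[float, float]:
    """Elementwise form: a' = a cos(phi) - b sin(phi), b' = b cos(phi) + a sin(phi)."""
    c, s = math.cos(phi), math.sin(phi)
    return a * c - b * s, b * c + a * s


def rotate_pair_matrix(a: float, b: float, phi: float) -> tuple[float, float]:
    """Same rotation written as a 2x2 matrix-vector product."""
    R = [[math.cos(phi), -math.sin(phi)],
         [math.sin(phi), math.cos(phi)]]
    return R[0][0] * a + R[0][1] * b, R[1][0] * a + R[1][1] * b


# a quarter turn maps (1, 0) to (approximately) (0, 1)
a2, b2 = rotate_pair(1.0, 0.0, math.pi / 2)
```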

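For the full vector, this is equivalent to multiplying by a block-diagonal rotation matrix, once the coordinates are reordered into pairs $(a_0, b_0, a_1, b_1, \ldots, a_{d/2-1}, b_{d/2-1})$: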

$$R_p = \begin{bmatrix} R_{p,0} & 0 & \cdots & 0 \\ 0 & R_{p,1} & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & R_{p,d/2-1} \end{bmatrix}$$

where:

$$R_{p,i} = \begin{bmatrix} \cos \phi_{p,i} & -\sin \phi_{p,i} \\ \sin \phi_{p,i} & \cos \phi_{p,i} \end{bmatrix}$$

In the split-half layout used by this kernel, I do not physically build that matrix. Instead, I compute two basis tensors:

$$C_{p,i} = \cos\left(p \cdot 10000^{-i/(d/2)}\right), \qquad S_{p,i} = \sin\left(p \cdot 10000^{-i/(d/2)}\right)$$

and apply the same rotation with elementwise operations:

$$\text{real}' = \text{real} \odot C - \text{imag} \odot S$$

$$\text{imag}' = \text{imag} \odot C + \text{real} \odot S$$
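The split-half elementwise form and the block-diagonal matrix form should agree once the coordinates are interleaved into pairs. The sketch below checks that on one head vector in plain Python. Both helpers (`rope_split_half`, `rope_block_diag`) are hypothetical and only meant to illustrate the equivalence; building the full d x d matrix is exactly what the kernel avoids.

```python
import math


def rope_split_half(x: list[float], p: int, base: float = 10000.0) -> list[float]:
    """Elementwise RoPE on one head vector in split-half layout."""
    half = len(x) // 2
    out = [0.0] * len(x)
    for i in range(half):
        phi = p * base ** (-i / half)
        a, b = x[i], x[i + half]
        out[i] = a * math.cos(phi) - b * math.sin(phi)
        out[i + half] = b * math.cos(phi) + a * math.sin(phi)
    return out


def rope_block_diag(x: list[float], p: int, base: float = 10000.0) -> list[float]:
    """Same rotation via an explicit block-diagonal R_p acting on (a_0, b_0, a_1, b_1, ...)."""
    d, half = len(x), len(x) // 2
    # interleave the halves: perm = [0, half, 1, half + 1, ...]
    perm = [j for i in range(half) for j in (i, i + half)]
    y = [x[j] for j in perm]
    R = [[0.0] * d for _ in range(d)]
    for i in range(half):  # fill the 2x2 block R_{p,i}
        phi = p * base ** (-i / half)
        c, s = math.cos(phi), math.sin(phi)
        R[2 * i][2 * i], R[2 * i][2 * i + 1] = c, -s
        R[2 * i + 1][2 * i], R[2 * i + 1][2 * i + 1] = s, c
    y_rot = [sum(R[r][k] * y[k] for k in range(d)) for r in range(d)]
    out = [0.0] * d
    for k, j in enumerate(perm):  # undo the interleaving
        out[j] = y_rot[k]
    return out


x = [0.5, -1.0, 2.0, 0.25, 1.5, -0.75, 0.1, 3.0]
```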

Note that real/imag are multiplied elementwise by C/S. Since C and S have shape [S, D // 2], they are broadcast over the N and H dimensions.

That maps directly onto a kernel: the [N, S, H, D // 2] index space is split into tiles, and each program loads the real and imaginary halves of its tile, computes the matching cosine and sine tile, and writes both halves of the rotated output. The elementwise math is very easy to express in TileLang: a T.Parallel loop covers every element of the tile, and TileLang maps the iterations onto threads.
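To illustrate the tiling and the broadcast, here is a CPU emulation in plain Python that walks the same 4D tile grid a kernel would. It is a sketch under my own assumptions: `rope_tiled` and `ceildiv` are hypothetical helpers, the tensor is a nested list, and the bounds check for partial tiles is added so the emulation is well defined when the block sizes do not divide the shape.

```python
import math
from itertools import product


def ceildiv(a: int, b: int) -> int:
    return -(-a // b)


def rope_tiled(X, offset, BLOCK_N, BLOCK_S, BLOCK_H, BLOCK_D, base=10000.0):
    """CPU emulation of the tile decomposition; X is a nested list of shape [N][S][H][D]."""
    N, S, H, D = len(X), len(X[0]), len(X[0][0]), len(X[0][0][0])
    half_D = D // 2
    O = [[[[0.0] * D for _ in range(H)] for _ in range(S)] for _ in range(N)]
    grid = product(range(ceildiv(N, BLOCK_N)), range(ceildiv(S, BLOCK_S)),
                   range(ceildiv(H, BLOCK_H)), range(ceildiv(half_D, BLOCK_D)))
    for pid_n, pid_s, pid_h, pid_d in grid:  # one "program" per tile
        # The cos/sin tile depends only on (s, d): computed once per program,
        # then reused (broadcast) for every n and h in the tile.
        basis = {}
        for s, d in product(range(BLOCK_S), range(BLOCK_D)):
            seq_idx = offset + pid_s * BLOCK_S + s
            dim_idx = pid_d * BLOCK_D + d
            phi = seq_idx * base ** (-dim_idx / half_D)
            basis[s, d] = (math.cos(phi), math.sin(phi))
        for n, s, h, d in product(range(BLOCK_N), range(BLOCK_S), range(BLOCK_H), range(BLOCK_D)):
            gn, gs = pid_n * BLOCK_N + n, pid_s * BLOCK_S + s
            gh, gd = pid_h * BLOCK_H + h, pid_d * BLOCK_D + d
            if gn >= N or gs >= S or gh >= H or gd >= half_D:
                continue  # bounds check for partial tiles (added for this emulation)
            c, sn = basis[s, d]
            re, im = X[gn][gs][gh][gd], X[gn][gs][gh][gd + half_D]
            O[gn][gs][gh][gd] = re * c - im * sn
            O[gn][gs][gh][gd + half_D] = im * c + re * sn
    return O
```

Because every tile recomputes its basis from global indices, the result does not depend on the block sizes chosen.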

RoPE treats each head vector as two halves:

x_real = X[..., : D // 2]
x_imag = X[..., D // 2 :]

Then each pair is rotated by a position-dependent angle:

real = x_real * cos - x_imag * sin
imag = x_imag * cos + x_real * sin

Combining grid dimensions

My first kernel tried to launch one grid dimension for every tensor axis:

with T.Kernel(
    T.ceildiv(N, BLOCK_N),
    T.ceildiv(S, BLOCK_S),
    T.ceildiv(H, BLOCK_H),
    T.ceildiv(half_D, BLOCK_D),
    threads=256,
) as (pid_n, pid_s, pid_h, pid_d):

That is a natural way to think about the work, but TileLang follows the CUDA launch model: the grid can have at most three dimensions.

The fix was to pack the head block and dimension block into one physical grid axis:

num_h_blocks = T.ceildiv(H, BLOCK_H)
num_d_blocks = T.ceildiv(half_D, BLOCK_D)

with T.Kernel(
    T.ceildiv(N, BLOCK_N),
    T.ceildiv(S, BLOCK_S),
    num_h_blocks * num_d_blocks,
    threads=256,
) as (pid_n, pid_s, pid_hd):
    pid_h = pid_hd // num_d_blocks
    pid_d = pid_hd % num_d_blocks

This keeps N and S as explicit grid axes, while pid_hd decodes back into the block indices (pid_h, pid_d) of the logical (H, D // 2) tile grid.
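The packing is just a row-major index over the (head block, dimension block) grid. A small plain-Python check that the decode used in the kernel visits every pair exactly once (`pack`, `unpack`, and the sizes are illustrative, not part of the kernel):

```python
def ceildiv(a: int, b: int) -> int:
    return -(-a // b)


def pack(pid_h: int, pid_d: int, num_d_blocks: int) -> int:
    """Row-major packing of (head block, dim block) into one grid index."""
    return pid_h * num_d_blocks + pid_d


def unpack(pid_hd: int, num_d_blocks: int) -> tuple[int, int]:
    """Inverse, as in the kernel: pid_h = pid_hd // n, pid_d = pid_hd % n."""
    return pid_hd // num_d_blocks, pid_hd % num_d_blocks


H, half_D, BLOCK_H, BLOCK_D = 12, 64, 4, 16  # illustrative sizes only
num_h_blocks, num_d_blocks = ceildiv(H, BLOCK_H), ceildiv(half_D, BLOCK_D)
decoded = [unpack(pid_hd, num_d_blocks) for pid_hd in range(num_h_blocks * num_d_blocks)]
```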

Position index vs dimension index

Another easy mistake was mixing up the sequence index and the dimension index.

The RoPE angle is:

angle = position * base ^ (-dim / half_dim)

So the sequence position scales the angle, and the dimension index selects the frequency:

seq_idx = offset + pid_s * BLOCK_S + s
dim_idx = pid_d * BLOCK_D + d
freq = T.pow(10000.0, -dim_idx.astype("float32") / half_D)

cos_basis[s, d] = T.cos(seq_idx * freq)
sin_basis[s, d] = T.sin(seq_idx * freq)

The basis must be indexed by both s and d. Computing only one value per sequence position is not enough, because every rotary dimension has a different frequency.
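As a sanity check on the indexing, the sketch below builds the full [S, D // 2] cos/sin table and confirms that a tile computed from global indices, the way the kernel loop does it, matches the corresponding slice. `full_basis` and `block_basis` are hypothetical plain-Python helpers.

```python
import math


def full_basis(S: int, half_D: int, offset: int, base: float = 10000.0):
    """cos/sin tables of shape [S][half_D]: angle = (offset + s) * base^(-d / half_D)."""
    cos_t = [[math.cos((offset + s) * base ** (-d / half_D)) for d in range(half_D)] for s in range(S)]
    sin_t = [[math.sin((offset + s) * base ** (-d / half_D)) for d in range(half_D)] for s in range(S)]
    return cos_t, sin_t


def block_basis(pid_s, pid_d, BLOCK_S, BLOCK_D, half_D, offset, base=10000.0):
    """One [BLOCK_S][BLOCK_D] tile, indexed by both s and d as in the kernel loop."""
    cos_b = [[0.0] * BLOCK_D for _ in range(BLOCK_S)]
    sin_b = [[0.0] * BLOCK_D for _ in range(BLOCK_S)]
    for s in range(BLOCK_S):
        for d in range(BLOCK_D):
            seq_idx = offset + pid_s * BLOCK_S + s
            dim_idx = pid_d * BLOCK_D + d
            freq = base ** (-float(dim_idx) / half_D)  # float exponent
            cos_b[s][d] = math.cos(seq_idx * freq)
            sin_b[s][d] = math.sin(seq_idx * freq)
    return cos_b, sin_b


cos_t, sin_t = full_basis(S=8, half_D=8, offset=5)
cos_b, sin_b = block_basis(pid_s=1, pid_d=1, BLOCK_S=4, BLOCK_D=4, half_D=8, offset=5)
```

Within one row of `cos_t` every entry differs, which is why a single value per position would not be enough.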

Integer division is not float division

TVM, which TileLang builds on, does not allow ambiguous integer division:

freq = T.pow(10000.0, -dim_idx / half_D)

This failed because both dim_idx and half_D are integer expressions. RoPE needs a real-valued exponent, so dim_idx has to be cast before division:

freq = T.pow(10000.0, -dim_idx.astype("float32") / half_D)

This is one of those compiler details that is easy to miss when translating from PyTorch, where / automatically means floating-point division.
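Plain Python shows why the cast matters. In this small sketch, computing the exponent with integer floor division collapses every dimension except the first onto one frequency, while `/` keeps a distinct frequency per dimension. (This only illustrates the pitfall; TVM itself raises an error rather than picking a division mode.)

```python
half_D = 64
dims = range(half_D)

# Float exponent: every rotary dimension gets its own frequency.
freq_float = [10000.0 ** (-d / half_D) for d in dims]

# Integer (floor) exponent: -d // half_D is 0 for d == 0 and -1 otherwise,
# so all dimensions but the first share the frequency 10000^-1.
freq_int = [10000.0 ** (-d // half_D) for d in dims]
```

A truncating integer division would be just as wrong: it gives an exponent of 0, and therefore a frequency of 1, for every dimension.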

Writing both halves

Since the kernel rotates pairs from the first and second halves of the head dimension, the output also has to be written in two pieces:

T.copy(O_local_real, O[..., pid_d * BLOCK_D : (pid_d + 1) * BLOCK_D])
T.copy(
    O_local_imag,
    O[..., pid_d * BLOCK_D + half_D : (pid_d + 1) * BLOCK_D + half_D],
)

This is cleaner than trying to store both halves into a local buffer whose last dimension is only BLOCK_D.
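The two slices written by each program are BLOCK_D columns wide and sit half_D columns apart. A plain-Python sketch (hypothetical `half_slices` helper, assuming half_D is a multiple of BLOCK_D) confirms that across all programs they cover every column of the head dimension exactly once:

```python
def ceildiv(a: int, b: int) -> int:
    return -(-a // b)


def half_slices(pid_d: int, BLOCK_D: int, half_D: int) -> tuple[range, range]:
    """Columns written by program pid_d: (real half, imag half)."""
    lo, hi = pid_d * BLOCK_D, (pid_d + 1) * BLOCK_D
    return range(lo, hi), range(lo + half_D, hi + half_D)


D, BLOCK_D = 128, 16  # illustrative sizes
half_D = D // 2
written = []
for pid_d in range(ceildiv(half_D, BLOCK_D)):
    real_cols, imag_cols = half_slices(pid_d, BLOCK_D, half_D)
    written.extend(real_cols)
    written.extend(imag_cols)
```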

Appendix

Complete kernel:


import tilelang
import tilelang.language as T


@tilelang.jit
def rope(X, offset, BLOCK_N, BLOCK_S, BLOCK_H, BLOCK_D):
    N, S, H, D = T.const("N, S, H, D")
    dtype = T.float32
    X: T.Tensor((N, S, H, D), dtype)
    O = T.empty((N, S, H, D), dtype)
    # derived size: length of each rotary half
    half_D = D // 2

    num_h_blocks = T.ceildiv(H, BLOCK_H)
    num_d_blocks = T.ceildiv(half_D, BLOCK_D)
    with T.Kernel(
        T.ceildiv(N, BLOCK_N),
        T.ceildiv(S, BLOCK_S),
        num_h_blocks * num_d_blocks,
        # TileLang supports at most 3 grid dimensions (CUDA launch model), so the H and D blocks are packed into one axis.
        threads=256,
    ) as (
        pid_n,
        pid_s,
        pid_hd,
    ):
        pid_h = pid_hd // num_d_blocks
        pid_d = pid_hd % num_d_blocks
        X_local_real = T.alloc_fragment((BLOCK_N, BLOCK_S, BLOCK_H, BLOCK_D), dtype)
        X_local_imag = T.alloc_fragment((BLOCK_N, BLOCK_S, BLOCK_H, BLOCK_D), dtype)
        O_local_real = T.alloc_fragment((BLOCK_N, BLOCK_S, BLOCK_H, BLOCK_D), dtype)
        O_local_imag = T.alloc_fragment((BLOCK_N, BLOCK_S, BLOCK_H, BLOCK_D), dtype)
        cos_basis = T.alloc_fragment((BLOCK_S, BLOCK_D), dtype)
        sin_basis = T.alloc_fragment((BLOCK_S, BLOCK_D), dtype)
        for s, d in T.Parallel(BLOCK_S, BLOCK_D):
            seq_idx = offset + pid_s * BLOCK_S + s
            dim_idx = pid_d * BLOCK_D + d
            freq = T.pow(10000.0, -dim_idx.astype("float32") / half_D)
            cos_basis[s, d] = T.cos(seq_idx * freq)
            sin_basis[s, d] = T.sin(seq_idx * freq)
        # for each block, we process
        # x[..., BLOCK_D * pid_d : BLOCK_D * (pid_d + 1)]
        # and x[..., BLOCK_D * pid_d + half_D : BLOCK_D * (pid_d + 1) + half_D]
        n_blk_id = pid_n
        s_blk_id = pid_s
        T.copy(
            X[
                n_blk_id * BLOCK_N : (n_blk_id + 1) * BLOCK_N,
                s_blk_id * BLOCK_S : (s_blk_id + 1) * BLOCK_S,
                pid_h * BLOCK_H : (pid_h + 1) * BLOCK_H,
                pid_d * BLOCK_D : (pid_d + 1) * BLOCK_D,
            ],
            X_local_real,
        )
        T.copy(
            X[
                n_blk_id * BLOCK_N : (n_blk_id + 1) * BLOCK_N,
                s_blk_id * BLOCK_S : (s_blk_id + 1) * BLOCK_S,
                pid_h * BLOCK_H : (pid_h + 1) * BLOCK_H,
                pid_d * BLOCK_D + half_D : (pid_d + 1) * BLOCK_D + half_D,
            ],
            X_local_imag,
        )

        for n, s, h, d in T.Parallel(BLOCK_N, BLOCK_S, BLOCK_H, BLOCK_D):
            real = (
                X_local_real[n, s, h, d] * cos_basis[s, d]
                - X_local_imag[n, s, h, d] * sin_basis[s, d]
            )
            imag = (
                X_local_imag[n, s, h, d] * cos_basis[s, d]
                + X_local_real[n, s, h, d] * sin_basis[s, d]
            )
            O_local_real[n, s, h, d] = real
            O_local_imag[n, s, h, d] = imag

        T.copy(
            O_local_real,
            O[
                n_blk_id * BLOCK_N : (n_blk_id + 1) * BLOCK_N,
                s_blk_id * BLOCK_S : (s_blk_id + 1) * BLOCK_S,
                pid_h * BLOCK_H : (pid_h + 1) * BLOCK_H,
                pid_d * BLOCK_D : (pid_d + 1) * BLOCK_D,
            ],
        )
        T.copy(
            O_local_imag,
            O[
                n_blk_id * BLOCK_N : (n_blk_id + 1) * BLOCK_N,
                s_blk_id * BLOCK_S : (s_blk_id + 1) * BLOCK_S,
                pid_h * BLOCK_H : (pid_h + 1) * BLOCK_H,
                pid_d * BLOCK_D + half_D : (pid_d + 1) * BLOCK_D + half_D,
            ],
        )
    return O

The Torch reference I tested against is:

import torch


def rope_ref(X: torch.Tensor, offset: int) -> torch.Tensor:
    # X: [N, S, H, D]; positions start at `offset`
    N, S, H, D = X.shape
    half_d = D // 2
    positions = torch.arange(offset, offset + S, dtype=torch.float32, device=X.device)
    dims = torch.arange(half_d, dtype=torch.float32, device=X.device)
    freqs = torch.pow(10000.0, -dims / half_d)
    angles = torch.outer(positions, freqs)  # [S, half_d]

    # broadcast the [S, half_d] basis over N and H
    cos_basis = torch.cos(angles).reshape(1, S, 1, half_d)
    sin_basis = torch.sin(angles).reshape(1, S, 1, half_d)

    x_real = X[..., :half_d]
    x_imag = X[..., half_d:]
    real = x_real * cos_basis - x_imag * sin_basis
    imag = x_imag * cos_basis + x_real * sin_basis
    return torch.cat([real, imag], dim=-1)
