
[TIL] TileLang vs Triton EP-01: Initial glance


TileLang vs Triton

I’ve been programming in both TileLang and Triton for some time, and the interesting difference between them is not just syntax. It is the level at which each language wants you to think.

Triton usually asks:

What does one program instance do to a block of elements?

TileLang usually asks:

What tiles do I move through the GPU memory hierarchy, and what operations happen on those tiles?

That sounds subtle, but it changes how the code is shaped.

Triton feels like blocked SPMD programming. You manually compute offsets, create masks, load vectors or matrices, do block operations, and store results. TileLang feels closer to a tile-level schedule language. You declare the kernel launch structure, allocate shared or fragment buffers, copy tiles between memory scopes, and call tile operations like T.gemm. Let’s look at some examples.

Example 1: vector add

Vector add is too simple to need TileLang, but it shows the API difference clearly.

In Triton:

import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0)

    tl.store(out_ptr + offsets, x + y, mask=mask)


def add(x, y):
    out = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK"]),)
    add_kernel[grid](x, y, out, x.numel(), BLOCK=1024)
    return out

One Triton program owns a contiguous block of elements. The core pattern is:

  1. Get a program id.
  2. Turn it into a vector of offsets.
  3. Mask the tail.
  4. Load, compute, store.

This is the basic Triton rhythm. The block is represented by tl.arange, and operations like x + y apply to the whole block.
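To make the program-instance view concrete, the four steps can be mimicked on the CPU. The sketch below is plain Python (no Triton, no GPU); add_program and the small BLOCK default are illustrative names and values, not part of Triton's API. Each call to add_program plays the role of one program instance, and the loop over pid plays the role of the launch grid.

```python
# Plain-Python model of Triton's "one program instance owns one block" rhythm.
# Nothing here calls Triton; lists stand in for block tensors.

def cdiv(a, b):
    """Ceiling division, like triton.cdiv."""
    return (a + b - 1) // b


def add_program(pid, x, y, out, n_elements, BLOCK):
    # 1-2. program id -> vector of offsets (like pid * BLOCK + tl.arange(0, BLOCK))
    offsets = [pid * BLOCK + j for j in range(BLOCK)]
    # 3. mask the tail
    mask = [off < n_elements for off in offsets]
    # 4. masked load (masked lanes read other=0.0), whole-block add, masked store
    xs = [x[off] if m else 0.0 for off, m in zip(offsets, mask)]
    ys = [y[off] if m else 0.0 for off, m in zip(offsets, mask)]
    result = [a + b for a, b in zip(xs, ys)]
    for off, m, v in zip(offsets, mask, result):
        if m:
            out[off] = v


def add(x, y, BLOCK=4):
    n = len(x)
    out = [0.0] * n
    for pid in range(cdiv(n, BLOCK)):  # the "grid": one call per program instance
        add_program(pid, x, y, out, n, BLOCK)
    return out
```

With five elements and BLOCK=2, the third program sees offsets [4, 5] and mask [True, False], so only the in-bounds lane is stored.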

The same idea in TileLang looks more like a kernel body with a block context and a parallel loop:

import tilelang
import tilelang.language as T
from tilelang import jit


@jit
def add(N: int, block: int = 256, dtype: str = "float32"):
    @T.prim_func
    def add_kernel(
        A: T.Tensor((N,), dtype),
        B: T.Tensor((N,), dtype),
        C: T.Tensor((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block), threads=block) as bx:
            for i in T.Parallel(block):
                gi = bx * block + i
                if gi < N:
                    C[gi] = A[gi] + B[gi]

    return add_kernel

Here the grid is declared with T.Kernel, and the work inside the block is expressed with T.Parallel. For vector add, this is not really more powerful than Triton. It is just a different spelling of the same simple idea.
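For comparison, the same launch structure can be modeled in plain Python. This is an illustrative sketch, not TileLang code: the outer loop stands in for the T.Kernel grid and the inner loop for T.Parallel. Each lane computes its own scalar index and guards itself with an if, instead of the whole block being one vector value plus a mask as in the Triton kernel above.

```python
# Plain-Python model of the TileLang vector-add structure:
# a grid of blocks (T.Kernel) and a parallel loop over lanes (T.Parallel).

def ceildiv(a, b):
    return (a + b - 1) // b


def add_tilelang_style(A, B, block=4):
    N = len(A)
    C = [0.0] * N
    for bx in range(ceildiv(N, block)):  # with T.Kernel(...) as bx
        for i in range(block):          # for i in T.Parallel(block)
            gi = bx * block + i         # scalar global index for this lane
            if gi < N:                  # per-lane tail guard instead of a mask vector
                C[gi] = A[gi] + B[gi]
    return C
```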

But notice the mental model. Triton makes the block a vector value:

offsets = pid * BLOCK + tl.arange(0, BLOCK)
x = tl.load(x_ptr + offsets, mask=mask)

TileLang makes the block a launch context and then describes parallel work inside it:

with T.Kernel(T.ceildiv(N, block), threads=block) as bx:
    for i in T.Parallel(block):
        # tail guard (bx * block + i < N) omitted for brevity
        C[bx * block + i] = A[bx * block + i] + B[bx * block + i]

For elementwise kernels, I usually find Triton more direct (its mask syntax is more familiar to me than the per-element if-guard inside a parallel loop).

Example 2: matmul

Matmul is where the design gap becomes more meaningful.

A small Triton matmul kernel is built from pointer arithmetic and block tensors:

@triton.jit
def matmul_kernel(
    A,
    B,
    C,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am: tl.constexpr,
    stride_ak: tl.constexpr,
    stride_bk: tl.constexpr,
    stride_bn: tl.constexpr,
    stride_cm: tl.constexpr,
    stride_cn: tl.constexpr,
    BM: tl.constexpr,
    BN: tl.constexpr,
    BK: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BM + tl.arange(0, BM)
    offs_n = pid_n * BN + tl.arange(0, BN)
    offs_k = tl.arange(0, BK)

    acc = tl.zeros((BM, BN), dtype=tl.float32)

    for k0 in range(0, K, BK):
        a = tl.load(
            A + offs_m[:, None] * stride_am + (k0 + offs_k[None, :]) * stride_ak,
            mask=(offs_m[:, None] < M) & (k0 + offs_k[None, :] < K),
            other=0.0,
        )
        b = tl.load(
            B + (k0 + offs_k[:, None]) * stride_bk + offs_n[None, :] * stride_bn,
            mask=(k0 + offs_k[:, None] < K) & (offs_n[None, :] < N),
            other=0.0,
        )
        acc += tl.dot(a, b)

    tl.store(
        C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn,
        acc,
        mask=(offs_m[:, None] < M) & (offs_n[None, :] < N),
    )

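The important objects are block tensors: a is the (BM, BK) tile of A, b is the (BK, BN) tile of B, and acc is the (BM, BN) float32 accumulator.

You explicitly build those tensors from pointer expressions. The compiler sees static block shapes and can lower tl.dot into efficient matrix instructions (such as tensor-core MMA on NVIDIA GPUs) when the dtypes and shapes line up.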
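To see what those block tensors contain, the sketch below replays the Triton matmul on the CPU. It is a model under simplifying assumptions: matrices are nested lists, load_block and matmul_blocked are hypothetical helpers, and tl.dot is replaced by an explicit triple loop. Out-of-bounds elements read as 0.0 (the other=0.0 in the kernel), so partial edge tiles still produce the correct product, and the final store is masked.

```python
# CPU model of the Triton matmul above (plain Python, no Triton).

def cdiv(a, b):
    return (a + b - 1) // b


def load_block(X, rows, cols):
    """Masked 2-D load: out-of-bounds elements read as 0.0, like other=0.0."""
    n_rows, n_cols = len(X), len(X[0])
    return [[X[r][c] if r < n_rows and c < n_cols else 0.0 for c in cols]
            for r in rows]


def matmul_blocked(A, B, BM=2, BN=2, BK=2):
    M, K, N = len(A), len(A[0]), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for pid_m in range(cdiv(M, BM)):          # grid axis 0
        for pid_n in range(cdiv(N, BN)):      # grid axis 1
            offs_m = [pid_m * BM + i for i in range(BM)]
            offs_n = [pid_n * BN + j for j in range(BN)]
            acc = [[0.0] * BN for _ in range(BM)]    # (BM, BN) accumulator
            for k0 in range(0, K, BK):
                offs_k = [k0 + k for k in range(BK)]
                a = load_block(A, offs_m, offs_k)    # (BM, BK) block tensor
                b = load_block(B, offs_k, offs_n)    # (BK, BN) block tensor
                for i in range(BM):                  # acc += tl.dot(a, b)
                    for j in range(BN):
                        acc[i][j] += sum(a[i][k] * b[k][j] for k in range(BK))
            for i, m in enumerate(offs_m):           # masked tl.store
                for j, n in enumerate(offs_n):
                    if m < M and n < N:
                        C[m][n] = acc[i][j]
    return C
```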

In TileLang, a matmul is usually written around the memory hierarchy:

@jit
def matmul(M: int, N: int, K: int, BM: int = 128, BN: int = 128, BK: int = 32):
    @T.prim_func
    def matmul_kernel(
        A: T.Tensor((M, K), "float16"),
        B: T.Tensor((K, N), "float16"),
        C: T.Tensor((M, N), "float16"),
    ):
        with T.Kernel(T.ceildiv(N, BN), T.ceildiv(M, BM), threads=128) as (bx, by):
            A_s = T.alloc_shared((BM, BK), "float16")
            B_s = T.alloc_shared((BK, BN), "float16")
            C_f = T.alloc_fragment((BM, BN), "float32")

            T.clear(C_f)

            for ko in T.Pipelined(T.ceildiv(K, BK), num_stages=3):
                T.copy(A[by * BM, ko * BK], A_s)
                T.copy(B[ko * BK, bx * BN], B_s)
                T.gemm(A_s, B_s, C_f)

            T.copy(C_f, C[by * BM, bx * BN])

    return matmul_kernel

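This code is not centered on pointer tensors. It is centered on named tile buffers: A_s and B_s are (BM, BK) and (BK, BN) tiles in shared memory, and C_f is a (BM, BN) float32 accumulator held in registers as a fragment.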

The loop says: pipeline over K blocks, copy global tiles into shared memory, multiply shared tiles into a fragment accumulator, then copy the fragment result back to global memory. That is a different abstraction boundary.
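T.Pipelined is what lets the copy of later K tiles overlap with the gemm on the current one. The sketch below is a simplified model of multi-buffered software pipelining, written as an event schedule in plain Python. The prologue/steady-state split and the slot rotation k % num_stages are assumptions of this model; TileLang's actual lowering may differ in detail.

```python
# Simplified model of multi-buffered software pipelining over the K loop,
# loosely in the spirit of T.Pipelined(num_k, num_stages=S).
# Illustrative only; not TileLang's generated schedule.

def pipeline_schedule(num_k, num_stages):
    """Return ordered (event, k, slot) tuples for num_k K-iterations.

    Tile k is staged in shared-memory slot k % num_stages.
    """
    events = []
    # Prologue: prefetch the first num_stages - 1 tiles.
    for k in range(min(num_stages - 1, num_k)):
        events.append(("copy", k, k % num_stages))
    # Steady state: issue the copy for a later tile, then consume tile k.
    for k in range(num_k):
        ahead = k + num_stages - 1
        if ahead < num_k:
            # Refills the slot freed by gemm k-1 (still unused when k == 0).
            events.append(("copy", ahead, ahead % num_stages))
        events.append(("gemm", k, k % num_stages))
    return events
```

With num_stages=1 the model degenerates to copy, gemm, copy, gemm with no overlap; larger values spend one shared-memory slot per stage to keep more copies in flight.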

The API difference

Memory placement

Triton exposes block values. You are still close to the address math:

ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
a = tl.load(ptrs, mask=mask, other=0.0)

TileLang exposes tile movement:

A_s = T.alloc_shared((BM, BK), "float16")
T.copy(A[by * BM, ko * BK], A_s)

In Triton, the shape is carried by tensor expressions. In TileLang, the shape is carried by tile buffers.

In Triton, memory placement is often implicit. You write load operations, but on-chip allocation (registers versus shared memory, and how data is staged) is handled entirely by the compiler. In TileLang, memory placement is part of the source program. T.alloc_shared, T.alloc_fragment, T.copy, and T.Pipelined are not incidental details. They are the program.

Therefore, Triton kernels often feel like NumPy-style block programs with explicit masks and pointer arithmetic: you think in terms of one program instance and the block tensor it owns. TileLang is more explicit about memory, so it has more of the flavor of a CUDA program, where you allocate and manage on-chip memory yourself (much as in C++ you allocate and manage heap memory explicitly).

Strides

Triton examples usually pass strides explicitly because the kernel is often written against arbitrary PyTorch tensor layouts:

ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
a = tl.load(ptrs, mask=mask, other=0.0)

The stride math is visible in the kernel. This makes it clear that the physical address is not necessarily the same thing as the logical [row, col] coordinate.
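The broadcast in that pointer expression can be spelled out in plain Python. In this sketch, storage is a flat list standing in for device memory, and tile_offsets and gather are hypothetical helpers. Reading a 2x3 row-major buffer through strides (1, 3) yields its transpose, while wrongly assuming compact strides (2, 1) for the 3x2 view yields different elements.

```python
# Plain-Python model of Triton's broadcasted pointer arithmetic:
#   ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
# Offsets are element offsets (Triton scales pointer + int by element size).

def tile_offsets(offs_m, offs_k, stride_am, stride_ak):
    # Row index comes from offs_m, column index from offs_k.
    return [[m * stride_am + k * stride_ak for k in offs_k] for m in offs_m]


def gather(storage, offsets):
    # Like tl.load on a 2-D tile of pointers (no mask here).
    return [[storage[o] for o in row] for row in offsets]


storage = [0, 1, 2, 3, 4, 5]  # a 2x3 matrix stored row-major: strides (3, 1)

A = gather(storage, tile_offsets([0, 1], [0, 1, 2], 3, 1))
At = gather(storage, tile_offsets([0, 1, 2], [0, 1], 1, 3))     # transposed view
wrong = gather(storage, tile_offsets([0, 1, 2], [0, 1], 2, 1))  # assumes compact 3x2
```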

TileLang examples often omit strides because T.Tensor((M, K), dtype) describes a logical tensor with a default compact (row-major) layout:

A: T.Tensor((M, K), "float16")

T.copy(A[by * BM, ko * BK], A_s)

Here A[by * BM, ko * BK] is a logical tile start. The copy extent is inferred from A_s, and the default row-major address calculation is derived from the tensor shape. If the input is non-contiguous, padded, transposed, or otherwise custom laid out, stride has not disappeared. It has to be represented through the tensor/buffer layout or through explicit indexing.

For a simple strided global tensor, the layout can be attached to the buffer itself. The exact spelling depends on the TileLang version, but conceptually it looks like this:

@T.prim_func
def kernel(
    A: T.Buffer((M, K), "float16", strides=(stride_am, stride_ak)),
    C: T.Buffer((M, N), "float16", strides=(stride_cm, stride_cn)),
):
    with T.Kernel(T.ceildiv(N, BN), T.ceildiv(M, BM), threads=128) as (bx, by):
        row = by * BM
        col = bx * BN

        x = A[row, col]
        C[row, col] = x

The indexing is still logical:

x = A[row, col]

but the physical address is derived from the buffer’s stride metadata, roughly like:

x = A.data[row * stride_am + col * stride_ak]
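The "buffer carries the layout" idea can be modeled with a tiny class: callers index logically and the flat address comes from stride metadata. StridedBuffer is a hypothetical helper for illustration, not a TileLang or TVM class. Here a 2x3 matrix with rows padded to 4 elements is copied into a compact destination using the same logical indices on both sides.

```python
# Toy model of a buffer that carries its own layout (stride metadata).

class StridedBuffer:
    def __init__(self, data, shape, strides):
        self.data = data        # flat backing storage
        self.shape = shape      # logical (rows, cols)
        self.strides = strides  # elements to skip per row / per column

    def address(self, row, col):
        if not (0 <= row < self.shape[0] and 0 <= col < self.shape[1]):
            raise IndexError((row, col))
        return row * self.strides[0] + col * self.strides[1]

    def __getitem__(self, idx):
        return self.data[self.address(*idx)]

    def __setitem__(self, idx, value):
        self.data[self.address(*idx)] = value


# Rows padded to 4 elements (row stride 4); -1 marks the padding slots.
A = StridedBuffer([0, 1, 2, -1, 3, 4, 5, -1], shape=(2, 3), strides=(4, 1))
# Compact row-major destination with the same logical shape.
C = StridedBuffer([0] * 6, shape=(2, 3), strides=(3, 1))

for row in range(2):
    for col in range(3):
        C[row, col] = A[row, col]  # same logical indices, different addresses
```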

This is an important difference from the common Triton style. In Triton, the logical-to-physical mapping is usually visible in the pointer expression:

ptr = A + row * stride_am + col * stride_ak
x = tl.load(ptr)

In TileLang, if the buffer carries the layout, the source code can keep using logical indices. For more exotic layouts, such as swizzled shared-memory layouts for tensor-core operations, TileLang usually represents that through layout annotations or specialized allocated buffers rather than by treating the original global tensor as a plain strided matrix.

So Triton tends to make physical layout visible by default. TileLang tends to let the common compact case read like logical tile movement.

Masks

Triton also makes edge masks explicit at the load/store site:

a = tl.load(
    A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak,
    mask=(offs_m[:, None] < M) & (offs_k[None, :] < K),
    other=0.0,
)

TileLang’s T.copy can often rely on compiler legalization for boundary safety:

T.copy(A[by * BM, ko * BK], A_s)
T.copy(C_f, C[by * BM, bx * BN])

Since A, C, A_s, and C_f all have known shapes, the compiler can insert guards for edge tiles and remove those guards when it proves the tile is fully in bounds. For matmul, this works naturally because out-of-bounds input elements should behave like zero padding, and out-of-bounds output elements should simply not be stored.
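That behavior can be modeled in plain Python. This is a sketch of the idea, not TileLang's implementation: copy_tile_in and copy_tile_out are hypothetical helpers, the runtime in-bounds check stands in for what the compiler would prove statically, and the zero fill matches the matmul case only.

```python
# Model of a boundary-safe tile copy: guarded loads zero-fill out-of-bounds
# elements, guarded stores drop them, and a fully in-bounds tile takes a
# guard-free path.

def tile_in_bounds(r0, c0, th, tw, R, C):
    return r0 + th <= R and c0 + tw <= C


def copy_tile_in(src, r0, c0, th, tw):
    R, C = len(src), len(src[0])
    if tile_in_bounds(r0, c0, th, tw, R, C):
        # Guard-free fast path: the whole tile lies inside src.
        return [src[r0 + i][c0:c0 + tw] for i in range(th)]
    # Edge tile: out-of-bounds elements behave like zero padding.
    return [[src[r0 + i][c0 + j] if r0 + i < R and c0 + j < C else 0.0
             for j in range(tw)] for i in range(th)]


def copy_tile_out(tile, dst, r0, c0):
    R, C = len(dst), len(dst[0])
    for i, row in enumerate(tile):
        for j, v in enumerate(row):
            if r0 + i < R and c0 + j < C:  # out-of-bounds results are not stored
                dst[r0 + i][c0 + j] = v
```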

But this is only a memory-safety mask. It is not always the same as the algorithmic mask. For softmax or attention, invalid positions usually need to become -inf, not 0.0. In TileLang that logic should still be written explicitly:

valid = T.all_of(row < M, col < N, col <= row)

score = T.if_then_else(
    valid,
    score,
    -T.infinity("float32"),
)

That corresponds to the Triton pattern:

score = tl.where(valid, score, float("-inf"))
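Why the fill value matters: a masked entry set to 0.0 still contributes exp(0 - max) to the softmax denominator, while -inf contributes exactly zero. The plain-Python sketch below applies a causal mask to the first row of a score matrix both ways; causal_fill is a hypothetical helper for illustration.

```python
import math


def softmax(xs):
    m = max(xs)  # subtract the row max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def causal_fill(scores, row, fill):
    # Keep columns with col <= row; replace the rest with `fill`.
    return [s if col <= row else fill for col, s in enumerate(scores)]


scores = [1.0, 2.0, 3.0]
p_inf = softmax(causal_fill(scores, 0, float("-inf")))  # algorithmic mask
p_zero = softmax(causal_fill(scores, 0, 0.0))           # memory-safety-style fill
```

With -inf, row 0 attends only to column 0. With 0.0, each hidden position still receives about 0.21 of the probability mass.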
