In this blog post, we will explore how we convert the AST to actual MLIR code.
Dialect
When we talk about creating a dialect, the first, somewhat philosophical, question to ask is: what is a dialect?
A dialect is a combination of the following things:
- Prefix: a defined namespace where all operations live, e.g. Toy is our namespace
- Types: custom types
- Operations: custom Ops, with invariants/semantics/traits
- Passes: analytical or transformative compiler passes to optimize the dialect
Defining a dialect is essentially defining the above components.
Operation & Opaque API
An innovation of MLIR is that it is completely operation-based. If you have looked at typical LLVM IR or assembly code, you will have noticed many different kinds of entities: functions, instructions, blocks, etc. The problem is that this makes the IR less flexible and extensible, since some of those design choices are coupled to a particular architecture.
In MLIR, everything is operation-based, which means a function is also an operation. Operations are defined opaquely, which is what enables full customizability.
%t_tensor = "toy.transpose"(%tensor) {inplace = true} : (tensor<2x3xf64>) -> tensor<3x2xf64> loc("example/file/path":12:1)
An operation has the following components:
- Types. For toy.transpose, you can see it takes in a tensor and outputs a tensor.
- Attributes. A dictionary of zero or more attributes, which are special operands that are always constant. They allow the same operation to be used in more varied ways.
- Inputs and outputs.
- A source location, such as loc("example/file/path":12:1) above.
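To make these components concrete, here is a minimal standalone C++ sketch. SketchOp and printGeneric are hypothetical names made up for this post, not MLIR's real Operation class; the sketch only mirrors the pieces listed above and prints them in a form resembling the toy.transpose line (the location is omitted).

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for an operation; this is NOT mlir::Operation.
struct SketchOp {
  std::string result;                             // SSA name of the result
  std::string name;                               // "dialect.mnemonic"
  std::vector<std::string> operands;              // SSA names of the inputs
  std::map<std::string, std::string> attributes;  // constant key/value pairs
  std::vector<std::string> operandTypes;          // one type per operand
  std::string resultType;                         // type of the result
};

// Print the op in a form resembling MLIR's generic syntax.
std::string printGeneric(const SketchOp &op) {
  assert(op.operands.size() == op.operandTypes.size());
  std::ostringstream os;
  os << op.result << " = \"" << op.name << "\"(";
  for (std::size_t i = 0; i < op.operands.size(); ++i)
    os << (i ? ", " : "") << op.operands[i];
  os << ")";
  if (!op.attributes.empty()) {
    os << " {";
    bool first = true;
    for (const auto &kv : op.attributes) {
      os << (first ? "" : ", ") << kv.first << " = " << kv.second;
      first = false;
    }
    os << "}";
  }
  os << " : (";
  for (std::size_t i = 0; i < op.operandTypes.size(); ++i)
    os << (i ? ", " : "") << op.operandTypes[i];
  os << ") -> " << op.resultType;
  return os.str();
}
```

Nothing in SketchOp is specific to transpose: the name is just a string, which is the sense in which the definition is opaque.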
Block & Regions
A block is a list of operations, similar to a basic block in LLVM. A block must end with a terminator operation that branches or returns.
A region is a list of blocks nested inside an operation. Since a function is an operation, the function body is a region. Operations inside a region can reference SSA values defined outside it, but not the other way around. (Operations marked IsolatedFromAbove, such as the toy.func defined later, opt out of this and cannot reference outside values.)
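The nesting of operations, regions and blocks can be sketched in a few lines of standalone C++. The SketchOp, SketchBlock and SketchRegion types and the blocksAreTerminated helper are hypothetical, invented for illustration rather than taken from MLIR; the helper checks the rule that every block ends with a terminator.

```cpp
#include <cassert>
#include <string>
#include <vector>

struct SketchOp;

// A block is an ordered list of operations.
struct SketchBlock {
  std::vector<SketchOp> ops;
};

// A region is an ordered list of blocks, owned by an operation.
struct SketchRegion {
  std::vector<SketchBlock> blocks;
};

// An operation may own regions, which makes the structure recursive.
struct SketchOp {
  std::string name;
  bool isTerminator;
  std::vector<SketchRegion> regions;
};

// Returns true if every block nested under `op` ends with a terminator.
bool blocksAreTerminated(const SketchOp &op) {
  for (const SketchRegion &region : op.regions) {
    for (const SketchBlock &block : region.blocks) {
      if (block.ops.empty() || !block.ops.back().isTerminator)
        return false;
      for (const SketchOp &nested : block.ops)
        if (!blocksAreTerminated(nested))
          return false;
    }
  }
  return true;
}
```

A function in this model is simply a SketchOp named "toy.func" with one region, whose single block ends in a "toy.return" op.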
Defining a dialect (Ops.td)
We will use TableGen to define the dialect. TableGen is a declarative specification language, with a code generator behind it, that lets us avoid a lot of boilerplate.
// OpBase.td provides the Dialect and Op base classes used below,
// so it has to be included first
include "mlir/IR/OpBase.td"
// define the namespace
def Toy_Dialect : Dialect {
  let name = "toy";
  let cppNamespace = "::mlir::toy";
}
// define the base for toy ops. It inherits from the base Op class
class Toy_Op<string mnemonic, list<Trait> traits = []> :
    Op<Toy_Dialect, mnemonic, traits>;
Then we need to define for each operation:
- arguments (input)
- results (output)
- builders (constructor)
- verifiers (whether extra verification is needed)
// Notice we inherit from Toy_Op here!!
// "constant" is the mnemonic for this op
def ConstantOp : Toy_Op<"constant", [Pure]> {
// Provide a summary and description for this operation. This can be used to
// auto-generate documentation of the operations within our dialect.
let summary = "constant";
let description = [{
Constant operation turns a literal into an SSA value. The data is attached
to the operation as an attribute. For example:
```mlir
%0 = toy.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>
: tensor<2x3xf64>
```
}];
// The constant operation takes an attribute as the only input.
let arguments = (ins F64ElementsAttr:$value);
// The constant operation returns a single value of TensorType.
let results = (outs F64Tensor);
// Indicate that the operation has a custom parser and printer method.
let hasCustomAssemblyFormat = 1;
// Add custom build methods for the constant operation. These methods populate
// the `state` that MLIR uses to create operations, i.e. these are used when
// using `ConstantOp::create(builder, ...)`.
let builders = [
// Build a constant with a given constant tensor value.
OpBuilder<(ins "DenseElementsAttr":$value), [{
build($_builder, $_state, value.getType(), value);
}]>,
// Build a constant with a given constant floating-point value.
OpBuilder<(ins "double":$value)>
];
// Indicate that additional verification for this operation is necessary.
let hasVerifier = 1;
}
You can see we added two custom builders. By default, a full-form builder is also generated for free (unless skipDefaultBuilders is set to 1):
builder.create<toy::ConstantOp>(loc, type, denseAttr)
However, for convenience we can create custom builders:
- builder.create<toy::ConstantOp>(loc, denseAttr) infers the output type from denseAttr
- builder.create<toy::ConstantOp>(loc, 1.0) is a shorthand overload for f64 constant construction.
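The relationship between the three builders can be sketched as plain C++ overloads that forward to one another. Everything here (SketchTensorType, SketchDenseAttr, SketchConstant, buildConstant) is a hypothetical stand-in written for this post, not the MLIR API; the point is only that the convenience forms fill in what the full form asks for.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for a tensor type, a dense attribute and the op.
struct SketchTensorType {
  std::vector<int64_t> shape;  // empty shape means a rank-0 (scalar) tensor
};
struct SketchDenseAttr {
  SketchTensorType type;
  std::vector<double> data;
};
struct SketchConstant {
  SketchTensorType resultType;
  SketchDenseAttr value;
};

// Full form: the caller spells out both the result type and the attribute.
SketchConstant buildConstant(const SketchTensorType &type,
                             const SketchDenseAttr &value) {
  return SketchConstant{type, value};
}

// Convenience form 1: infer the result type from the attribute.
SketchConstant buildConstant(const SketchDenseAttr &value) {
  return buildConstant(value.type, value);
}

// Convenience form 2: wrap a scalar in a rank-0 attribute, then forward.
SketchConstant buildConstant(double value) {
  SketchDenseAttr attr;        // shape stays empty: rank 0
  attr.data.push_back(value);
  return buildConstant(attr);
}
```

Calling buildConstant(1.0) in this sketch walks through all three overloads, which matches how the scalar builder in Dialect.cpp ends up calling the full-form build.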
The function operation is also interesting to look at. Its arguments are a symbol name, a function type (inputs -> outputs), and optional per-argument and per-result attribute dictionaries. Note that in MLIR a function can return multiple results.
def FuncOp : Toy_Op<"func", [
FunctionOpInterface, IsolatedFromAbove
]> {
let summary = "user defined function operation";
let description = [{
The "toy.func" operation represents a user defined function. These are
callable SSA-region operations that contain toy computations.
Example:
```mlir
toy.func @main() {
%0 = toy.constant dense<5.500000e+00> : tensor<f64>
%1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
toy.print %1 : tensor<2x2xf64>
toy.return
}
```
}];
let arguments = (ins
SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$function_type,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs
);
let regions = (region AnyRegion:$body);
let builders = [OpBuilder<(ins
"StringRef":$name, "FunctionType":$type,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
>];
let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//
/// Returns the argument types of this function.
ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }
Region *getCallableRegion() { return &getBody(); }
}];
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}
We also need an operation to call a function. Its arguments are a function symbol and a list of operands. Its result is the value returned by the callee (a tensor, since that is the only type we have).
def GenericCallOp : Toy_Op<"generic_call"> {
let summary = "generic call operation";
let description = [{
Generic calls represent calls to a user defined function that needs to
be specialized for the shape of its arguments. The callee name is attached
as a symbol reference via an attribute. The arguments list must match the
arguments expected by the callee. For example:
```mlir
%4 = toy.generic_call @my_func(%1, %3)
: (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
```
This is only valid if a function named "my_func" exists and takes two
arguments.
}];
// The generic call operation takes a symbol reference attribute as the
// callee, and inputs for the call.
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
// The generic call operation returns a single value of TensorType.
let results = (outs F64Tensor);
// Specialize assembly printing and parsing using a declarative format.
let assemblyFormat = [{
$callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
}];
// Add custom build methods for the generic call operation.
let builders = [
OpBuilder<(ins "StringRef":$callee, "ArrayRef<Value>":$arguments)>
];
}
Implementing the dialect (Dialect.cpp)
TableGen only takes care of the declarations for us. Implementing a dialect essentially means implementing these four methods for each operation:
- build: the constructor for the operation
- parse: parses the operation from text and constructs it
- print: the inverse of parse; dumps the operation in text format
- verify: verifies that the operation is legal
So at the beginning, we include the generated definitions and register the operations:
#include "toy/Dialect.cpp.inc"
/// Dialect initialization, the instance will be owned by the context. This is
/// the point of registration of types and operations for the dialect.
void ToyDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc" // this is the output of tablegen
>();
}
Again, taking constant op as an example, we have:
// build op
void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, double value) {
auto dataType = RankedTensorType::get({}, builder.getF64Type());
auto dataAttribute = DenseElementsAttr::get(dataType, value);
ConstantOp::build(builder, state, dataType, dataAttribute);
}
mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser,
                                    mlir::OperationState &result) {
  mlir::DenseElementsAttr value;
  // OpAsmParser is provided by MLIR, so we can leverage it directly
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseAttribute(value, "value", result.attributes))
    return failure();
  result.addTypes(value.getType());
  return success();
}
void ConstantOp::print(mlir::OpAsmPrinter &printer) {
printer << " ";
printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
printer << getValue();
}
llvm::LogicalResult ConstantOp::verify() {
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
if (!resultType)
return success();
// Check that the rank of the attribute type matches the rank of the constant
// result type.
auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
<< attrType.getRank() << " != " << resultType.getRank();
}
// Check that each of the dimensions match between the two types.
for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
return emitOpError(
"return type shape mismatches its attribute at dimension ")
<< dim << ": " << attrType.getShape()[dim]
<< " != " << resultType.getShape()[dim];
}
}
return mlir::success();
}
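Stripped of the MLIR types, the verifier above is a rank check followed by a per-dimension check. The standalone sketch below shows just that logic; checkShapes is a hypothetical helper (not part of MLIR) that returns an empty string on success and a message otherwise, where the real code returns a LogicalResult and reports through emitOpError.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Returns "" if the shapes agree, otherwise a description of the mismatch.
std::string checkShapes(const std::vector<int64_t> &attrShape,
                        const std::vector<int64_t> &resultShape) {
  // First compare the ranks (number of dimensions).
  if (attrShape.size() != resultShape.size())
    return "rank mismatch: " + std::to_string(attrShape.size()) +
           " != " + std::to_string(resultShape.size());
  // Then compare each dimension in turn.
  for (std::size_t dim = 0; dim < attrShape.size(); ++dim) {
    if (attrShape[dim] != resultShape[dim])
      return "shape mismatch at dimension " + std::to_string(dim) + ": " +
             std::to_string(attrShape[dim]) +
             " != " + std::to_string(resultShape[dim]);
  }
  return "";
}
```

For instance, comparing a 2x3 attribute against a 2x4 result reports a mismatch at dimension 1.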
Implementing MLIR generation (MLIRGen.cpp)
Now let’s implement the MLIR generation function. It should turn the AST into MLIR operations.
mlirGen is our public API.
// The public API for codegen.
mlir::OwningOpRef<mlir::ModuleOp> mlirGen(mlir::MLIRContext &context,
ModuleAST &moduleAST) {
return MLIRGenImpl(context).mlirGen(moduleAST);
}
The entry point of the implementation takes an AST module and runs mlirGen for each function, which in turn runs mlirGen for each operation in the function body.
class MLIRGenImpl {
public:
MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {}
/// Public API: convert the AST for a Toy module (source file) to an MLIR
/// Module operation.
mlir::ModuleOp mlirGen(ModuleAST &moduleAST) {
// We create an empty MLIR module and codegen functions one at a time and
// add them to the module.
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
for (FunctionAST &f : moduleAST)
mlirGen(f);
// Verify the module after we have finished constructing it, this will check
// the structural properties of the IR and invoke any specific verifiers we
// have on the Toy operations.
if (failed(mlir::verify(theModule))) {
theModule.emitError("module verification error");
return nullptr;
}
return theModule;
}
For example, in the code below, we use AddOp::create to construct an add operation, which in turn calls AddOp's build method, the same kind of build function we defined earlier for ConstantOp.
mlir::Value mlirGen(BinaryExprAST &binop) {
// First emit the operations for each side of the operation before emitting
// the operation itself. For example if the expression is `a + foo(a)`
// 1) First it will visit the LHS, which will return a reference to the
// value holding `a`. This value should have been emitted at declaration
// time and registered in the symbol table, so nothing would be
// codegen'd. If the value is not in the symbol table, an error has been
// emitted and nullptr is returned.
// 2) Then the RHS is visited (recursively) and a call to `foo` is emitted
// and the result value is returned. If an error occurs we get a nullptr
// and propagate.
//
mlir::Value lhs = mlirGen(*binop.getLHS());
if (!lhs)
return nullptr;
mlir::Value rhs = mlirGen(*binop.getRHS());
if (!rhs)
return nullptr;
auto location = loc(binop.loc());
// Derive the operation name from the binary operator. At the moment we only
// support '+' and '*'.
switch (binop.getOp()) {
case '+':
return AddOp::create(builder, location, lhs, rhs);
case '*':
return MulOp::create(builder, location, lhs, rhs);
}
emitError(location, "invalid binary operator '") << binop.getOp() << "'";
return nullptr;
}
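The same recursive pattern, including the early return when a sub-expression fails, can be tried out without MLIR. The sketch below is a simplified model under our own assumptions: SketchExpr, SketchGen, makeVar and makeBinary are hypothetical names, SSA values are plain strings, an empty string plays the role of nullptr, and the emitted text is only an approximation of Toy's syntax.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical mini AST: a variable reference or a binary expression.
struct SketchExpr {
  char op = 0;      // 0 for a variable, otherwise the operator character
  std::string var;  // variable name, used when op == 0
  std::unique_ptr<SketchExpr> lhs, rhs;
};

std::unique_ptr<SketchExpr> makeVar(const std::string &name) {
  std::unique_ptr<SketchExpr> e(new SketchExpr());
  e->var = name;
  return e;
}

std::unique_ptr<SketchExpr> makeBinary(char op, std::unique_ptr<SketchExpr> l,
                                       std::unique_ptr<SketchExpr> r) {
  std::unique_ptr<SketchExpr> e(new SketchExpr());
  e->op = op;
  e->lhs = std::move(l);
  e->rhs = std::move(r);
  return e;
}

struct SketchGen {
  std::map<std::string, std::string> symbols;  // variable name -> SSA name
  std::vector<std::string> ops;                // emitted operations, in order

  // Returns the SSA name of the expression's value, or "" on error.
  std::string gen(const SketchExpr &e) {
    if (e.op == 0) {  // variable: look it up, emit nothing
      auto it = symbols.find(e.var);
      return it == symbols.end() ? std::string() : it->second;
    }
    assert(e.lhs && e.rhs);
    std::string lhs = gen(*e.lhs);  // emit the operands first
    if (lhs.empty())
      return "";
    std::string rhs = gen(*e.rhs);
    if (rhs.empty())
      return "";
    // Only '+' and '*' are supported, as in the code above.
    const char *name = e.op == '+' ? "toy.add" : e.op == '*' ? "toy.mul" : nullptr;
    if (!name)
      return "";
    std::string result = "%" + std::to_string(ops.size());
    ops.push_back(result + " = " + name + " " + lhs + ", " + rhs);
    return result;
  }
};
```

With a and b registered in symbols, generating a + a * b emits the multiply first and the add second, because operands are always generated before the operation that uses them.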