Retep's

[TIL] MLIR Toy Explained Ch 6

Link to tutorial

In this blog post, we will explore how to lower the toy.print operation, which the previous chapter left unlowered, down to the LLVM dialect and from there to LLVM IR.

Transitive lowering

The official definition of transitive lowering is: an A->B->C lowering; that is, a lowering in which multiple patterns may be applied in order to fully transform an illegal operation into a set of legal ones.

This means a single pattern does not have to produce fully legal operations on its own. In this example, PrintOpLowering generates a structured scf loop nest instead of the branch form of the LLVM dialect. As long as the same conversion also contains patterns that lower scf to cf and cf to LLVM, the full lowering still succeeds.
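To make the A->B->C idea concrete, here is a small plain C++ model. It is hypothetical and is not the MLIR API. Operations are just names, a "pattern" maps an illegal op to the ops that replace it, and only LLVM ops count as legal. The op names and the helpers legalize, examplePatterns and kLegalOps are illustrative assumptions.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical toy model of transitive (A->B->C) legalization; not the MLIR
// API. An op is just a name, and a "pattern" maps an op name to the ops that
// replace it. Replacement ops may themselves still be illegal.
using Patterns = std::map<std::string, std::vector<std::string>>;

// Recursively legalizes `op`, appending the final legal ops to `out`.
// Returns false if some op is illegal and no pattern can rewrite it.
bool legalize(const std::string &op, const std::set<std::string> &legal,
              const Patterns &patterns, std::vector<std::string> &out,
              int depth = 0) {
  if (legal.count(op)) {
    out.push_back(op);  // already legal: keep as-is
    return true;
  }
  if (depth > 16) return false;  // guard against cyclic pattern sets
  auto it = patterns.find(op);
  if (it == patterns.end()) return false;  // illegal op, no pattern: failure
  for (const std::string &newOp : it->second)
    if (!legalize(newOp, legal, patterns, out, depth + 1)) return false;
  return true;
}

// Illustrative pattern set loosely mirroring this chapter (op names only).
Patterns examplePatterns(bool includeScfToCf) {
  Patterns p = {
      {"toy.print", {"scf.for", "llvm.call"}},  // PrintOpLowering
      {"cf.br", {"llvm.br"}},                   // cf -> llvm
      {"cf.cond_br", {"llvm.cond_br"}},         // cf -> llvm
  };
  if (includeScfToCf) p["scf.for"] = {"cf.br", "cf.cond_br"};  // scf -> cf
  return p;
}

// Only LLVM dialect ops are legal, as with a full conversion to LLVM.
const std::set<std::string> kLegalOps = {"llvm.call", "llvm.br",
                                         "llvm.cond_br"};
```

With the scf -> cf pattern present, toy.print reaches legality in three steps. Without it, the toy.print rewrite leaves behind an scf.for that nothing can legalize, so the whole conversion fails. This mirrors why the pass later in this post also adds populateSCFToControlFlowConversionPatterns.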

Lowering to LLVM

Let’s build a pass to fully lower the Toy dialect. We first implement a conversion pattern, PrintOpLowering. For a toy.print op, we:

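  1. Do some preparation: look up (or declare, if missing) the printf function with getOrInsertPrintf, and create the global format strings ("%f " and "\n") with getOrCreateGlobalString.
  2. Create one scf.for loop for each dimension of the input memref.
  3. Insert a memref.load of the current element and a printf call into the innermost loop body.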
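As a sanity check on these steps, here is a plain C++ reference model of what the lowered code should print. It is a hypothetical helper and not part of the tutorial. The model assumes the elements sit in a row-major f64 buffer. It emits one loop per dimension and prints each element with printf("%f ") in the innermost body. After each inner dimension it prints a newline, mirroring the `if (i != e - 1)` check in the pattern code. The names simulateToyPrint and emitDim are ours.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <string>
#include <vector>

// Hypothetical reference model of the output produced by the lowered
// toy.print; `data` holds the elements in row-major order.
static void emitDim(const std::vector<long> &shape,
                    const std::vector<double> &data, std::size_t dim,
                    std::size_t &flat, std::string &out) {
  if (dim == shape.size()) {  // innermost body: memref.load + printf("%f ")
    char buf[512];            // large enough for any "%f" of a double
    std::snprintf(buf, sizeof(buf), "%f ", data[flat++]);
    out += buf;
    return;
  }
  for (long i = 0; i < shape[dim]; ++i) {  // one scf.for per dimension
    emitDim(shape, data, dim + 1, flat, out);
    if (dim + 1 != shape.size()) out += "\n";  // mirrors `if (i != e - 1)`
  }
}

std::string simulateToyPrint(const std::vector<long> &shape,
                             const std::vector<double> &data) {
  std::size_t count = 1;
  for (long d : shape) count *= static_cast<std::size_t>(d);
  assert(count == data.size() && "data size must match the shape");
  std::string out;
  std::size_t flat = 0;
  emitDim(shape, data, 0, flat, out);
  return out;
}
```

For shape {2, 3} and data 1 to 6, the model gives "1.000000 2.000000 3.000000 \n4.000000 5.000000 6.000000 \n". A rank-1 input gets no newline at all, because its only loop is also the innermost one.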

Notice a few things:

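  1. We create the loop first, then empty its body with rewriter.eraseOp(&nested). We do this because the scf::ForOp builder inserts an scf.yield terminator into the body by default. We erase that terminator and re-create it ourselves after the optional newline call. We iterate with make_early_inc_range so that erasing the current operation does not invalidate the iterator we continue with.
  2. After building each loop, we set the insertion point to the start of its body. The next loop is therefore created inside it, before the newline call. After the last iteration, the load and the printf call land at the start of the innermost loop’s body. We don’t need to restore the insertion point afterwards, because it only matters for the ops this pattern creates; the rewrite driver repositions the rewriter before applying the next pattern.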
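Both points can be modeled outside MLIR with a std::list. The following is a hypothetical sketch, not MLIR code. It assumes a block is a list of op names and an insertion point is a list iterator, with new ops inserted before it as MLIR's OpBuilder does. clearBlock shows the early-increment erase idiom. buildOuterLoopBody replays the order in which PrintOpLowering fills a non-innermost loop body. Both helper names are ours.

```cpp
#include <cassert>
#include <list>
#include <string>

// Hypothetical model (not MLIR): a block is a list of op names, and an
// insertion point is an iterator; new ops are inserted *before* it, as with
// MLIR's OpBuilder.
using Block = std::list<std::string>;

// Erases every op in `block`. Like iterating with llvm::make_early_inc_range,
// the iterator is advanced *before* the current element is erased, so the
// erase never invalidates the iterator we continue with.
void clearBlock(Block &block) {
  for (auto it = block.begin(); it != block.end();) {
    auto current = it++;  // early increment
    block.erase(current);
  }
}

// Replays the order of steps PrintOpLowering uses to fill one non-innermost
// loop body, and returns the resulting body.
Block buildOuterLoopBody() {
  Block body = {"scf.yield"};  // default terminator from the loop builder
  clearBlock(body);            // the rewriter.eraseOp(&nested) loop
  auto ip = body.end();        // setInsertionPointToEnd(loop.getBody())
  body.insert(ip, "llvm.call @printf(newline)");
  body.insert(ip, "scf.yield");
  ip = body.begin();           // setInsertionPointToStart(loop.getBody())
  body.insert(ip, "scf.for (next dimension)");
  return body;
}
```

The resulting order places the nested loop first, then the newline call, then the terminator. That is why the pattern moves the insertion point back to the start of the body after emitting the newline and the yield.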
class PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
public:
  using OpConversionPattern<toy::PrintOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(toy::PrintOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto *context = rewriter.getContext();
    auto memRefType = llvm::cast<MemRefType>((*op->operand_type_begin()));
    auto memRefShape = memRefType.getShape();
    auto loc = op->getLoc();

    ModuleOp parentModule = op->getParentOfType<ModuleOp>();

    // Get a symbol reference to the printf function, inserting it if necessary.
    auto printfRef = getOrInsertPrintf(rewriter, parentModule);
    Value formatSpecifierCst = getOrCreateGlobalString(
        loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule);
    Value newLineCst = getOrCreateGlobalString(
        loc, rewriter, "nl", StringRef("\n\0", 2), parentModule);

    // Create a loop for each of the dimensions within the shape.
    SmallVector<Value, 4> loopIvs;
    for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) {
      auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
      auto upperBound =
          arith::ConstantIndexOp::create(rewriter, loc, memRefShape[i]);
      auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
      auto loop =
          scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
      for (Operation &nested : make_early_inc_range(*loop.getBody()))
        rewriter.eraseOp(&nested);
      loopIvs.push_back(loop.getInductionVar());

      // Terminate the loop body.
      rewriter.setInsertionPointToEnd(loop.getBody());

      // Insert a newline after each of the inner dimensions of the shape.
      if (i != e - 1)
        LLVM::CallOp::create(rewriter, loc, getPrintfType(context), printfRef,
                             newLineCst);
      scf::YieldOp::create(rewriter, loc);
      rewriter.setInsertionPointToStart(loop.getBody());
    }

    // Generate a call to printf for the current element of the loop.
    auto elementLoad =
        memref::LoadOp::create(rewriter, loc, op.getInput(), loopIvs);
    LLVM::CallOp::create(rewriter, loc, getPrintfType(context), printfRef,
                         ArrayRef<Value>({formatSpecifierCst, elementLoad}));

    // Notify the rewriter that this operation has been removed.
    rewriter.eraseOp(op);
    return success();
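  }

  // getPrintfType, getOrInsertPrintf, and getOrCreateGlobalString are private
  // static helpers of this class in the upstream tutorial; omitted here.
};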

Finally, we define the pass that uses this conversion pattern. Notice that we also add several upstream lowering patterns, populated by helpers like populateAffineToStdConversionPatterns. When several patterns can match, MLIR treats the benefit attached to each pattern in the pattern set (1 unless specified otherwise) as its priority. The dialect conversion driver behind applyFullConversion differs in its details, but the greedy pattern rewrite driver gives a good mental model. Roughly:

  1. Put operations into a worklist
  2. Pick an operation
  3. Find patterns whose root kind matches the operation
  4. Try those patterns in order of decreasing benefit; on the first successful match, rewrite and update the IR
  5. Add affected operations back to the worklist
  6. Repeat until no more rewrites apply.
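The six steps above can be sketched as a tiny worklist driver. This is a heavily simplified, hypothetical model and not MLIR's implementation. Ops are plain names, and a pattern carries a root name, a benefit, and a rewrite callback that may fail to match. Pattern, runDriver and exampleRun are names we made up for illustration. In this sketch, ties in benefit keep insertion order.

```cpp
#include <algorithm>
#include <cassert>
#include <deque>
#include <functional>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical, heavily simplified model of a benefit-ordered worklist
// rewrite driver; not MLIR's implementation. An op is just a name.
struct Pattern {
  std::string root;  // op name this pattern is rooted at
  int benefit;       // higher benefit is tried first
  // Returns false if the match fails; otherwise fills `replacement`.
  std::function<bool(const std::string &, std::vector<std::string> &)> rewrite;
};

std::vector<std::string> runDriver(std::vector<std::string> ops,
                                   std::vector<Pattern> patterns,
                                   int maxSteps = 10000) {
  // Order candidates by decreasing benefit (ties keep insertion order here).
  std::stable_sort(patterns.begin(), patterns.end(),
                   [](const Pattern &a, const Pattern &b) {
                     return a.benefit > b.benefit;
                   });
  std::deque<std::string> worklist(ops.begin(), ops.end());  // step 1
  std::vector<std::string> result;
  while (!worklist.empty()) {
    if (--maxSteps < 0) throw std::runtime_error("rewrites did not converge");
    std::string op = worklist.front();  // step 2: pick an operation
    worklist.pop_front();
    bool rewritten = false;
    for (const Pattern &p : patterns) {
      if (p.root != op) continue;  // step 3: root kind must match
      std::vector<std::string> replacement;
      if (!p.rewrite(op, replacement)) continue;  // match failed: next pattern
      // Steps 4-5: apply the rewrite and put the new ops on the worklist.
      for (auto it = replacement.rbegin(); it != replacement.rend(); ++it)
        worklist.push_front(*it);
      rewritten = true;
      break;
    }
    if (!rewritten) result.push_back(op);  // no pattern applies: op is final
  }
  return result;  // step 6: worklist drained
}

// Example: two patterns rooted at "toy.op". The benefit-2 pattern wins when
// it matches; otherwise the driver falls back to the benefit-1 pattern.
std::vector<std::string> exampleRun(bool highBenefitMatches) {
  std::vector<Pattern> patterns = {
      {"toy.op", 1,
       [](const std::string &, std::vector<std::string> &out) {
         out = {"low.a", "low.b"};
         return true;
       }},
      {"toy.op", 2,
       [highBenefitMatches](const std::string &, std::vector<std::string> &out) {
         if (!highBenefitMatches) return false;
         out = {"high"};
         return true;
       }},
  };
  return runDriver({"toy.op", "other.op"}, patterns);
}
```

When the benefit-2 pattern matches, the driver ends with {"high", "other.op"}. When it declines, the driver falls back to the benefit-1 pattern and ends with {"low.a", "low.b", "other.op"}.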
namespace {
struct ToyToLLVMLoweringPass
    : public PassWrapper<ToyToLLVMLoweringPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ToyToLLVMLoweringPass)
  StringRef getArgument() const override { return "toy-to-llvm"; }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<LLVM::LLVMDialect, scf::SCFDialect>();
  }
  void runOnOperation() final;
};
} // namespace

void ToyToLLVMLoweringPass::runOnOperation() {
  // The first thing to define is the conversion target. This will define the
  // final target for this lowering. For this lowering, we are only targeting
  // the LLVM dialect.
  LLVMConversionTarget target(getContext());
  target.addLegalOp<ModuleOp>();

  // During this lowering, we will also be lowering the MemRef types, that are
  // currently being operated on, to a representation in LLVM. To perform this
  // conversion we use a TypeConverter as part of the lowering. This converter
  // details how one type maps to another. This is necessary now that we will be
  // doing more complicated lowerings, involving loop region arguments.
  LLVMTypeConverter typeConverter(&getContext());

  // Now that the conversion target has been defined, we need to provide the
  // patterns used for lowering. At this point of the compilation process, we
  // have a combination of `toy`, `affine`, `arith`, `memref`, and `func`
  // operations. Luckily, there already exist sets of patterns to transform
  // these dialects. These patterns lower in multiple stages, relying on
  // transitive lowerings. Transitive lowering, or A->B->C lowering, is when
  // multiple patterns must be applied to fully transform an illegal operation
  // into a set of legal ones.
  RewritePatternSet patterns(&getContext());
  populateAffineToStdConversionPatterns(patterns);
  populateSCFToControlFlowConversionPatterns(patterns);
  mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
  populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
  cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
  populateFuncToLLVMConversionPatterns(typeConverter, patterns);

  // The only remaining operation to lower from the `toy` dialect, is the
  // PrintOp.
  patterns.add<PrintOpLowering>(&getContext());

  // We want to completely lower to LLVM, so we use a `FullConversion`. This
  // ensures that only legal operations will remain after the conversion.
  auto module = getOperation();
  if (failed(applyFullConversion(module, target, std::move(patterns))))
    signalPassFailure();
}
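The TypeConverter comment in the pass is easier to picture with a concrete layout. Under MLIR's default LLVM lowering, a ranked memref becomes a descriptor struct. The struct holds the allocated pointer, the aligned pointer, an offset, and one size and one stride per dimension. For a 2-D memref, that is !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>. The sketch below mirrors that layout as a plain C++ struct and shows the address arithmetic a lowered memref.load performs. It assumes f64 elements, as in Toy. MemRefDescriptor, loadElement and exampleLoad are our own illustrative names, not MLIR types.

```cpp
#include <cassert>
#include <cstdint>

// Plain C++ mirror of the descriptor a ranked memref<... x f64> lowers to:
//   !llvm.struct<(ptr, ptr, i64, array<Rank x i64>, array<Rank x i64>)>
// The C++ names are illustrative only.
template <int Rank>
struct MemRefDescriptor {
  double *allocated;      // pointer returned by the allocator (used to free)
  double *aligned;        // aligned pointer that loads/stores go through
  std::int64_t offset;    // offset (in elements) from `aligned`
  std::int64_t sizes[Rank];    // extent of each dimension
  std::int64_t strides[Rank];  // elements between consecutive indices
};

// What a lowered memref.load computes: aligned[offset + sum(idx[d] * stride[d])].
template <int Rank>
double loadElement(const MemRefDescriptor<Rank> &m,
                   const std::int64_t (&idx)[Rank]) {
  std::int64_t linear = m.offset;
  for (int d = 0; d < Rank; ++d) {
    assert(idx[d] >= 0 && idx[d] < m.sizes[d] && "index out of bounds");
    linear += idx[d] * m.strides[d];
  }
  return m.aligned[linear];
}

// Example: a row-major 2x3 buffer {1..6}; strides {3, 1}, offset 0.
double exampleLoad(std::int64_t i, std::int64_t j) {
  static double buffer[6] = {1, 2, 3, 4, 5, 6};
  MemRefDescriptor<2> m{buffer, buffer, 0, {2, 3}, {3, 1}};
  const std::int64_t idx[2] = {i, j};
  return loadElement(m, idx);
}
```

For the row-major 2x3 example, element [1][2] sits at linear index 0 + 1*3 + 2*1 = 5. That index holds the value 6.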
