[TIL] MLIR Toy Explained Ch 4

Link to tutorial

In this blog post, we will explore how to use interfaces for rewrites.

Interface

Interfaces are another core MLIR offering. They were introduced to address massive code duplication in pass implementations. Consider operations like add, mul, and div: they share a common structure and semantic meaning, so they are often targets of the same pass. However, as illustrated in the previous post, we would need to implement the pass for every kind of operation. Is there a way to implement the shared logic once? This leads to the concept of interfaces (or traits). We write the common logic against an interface, and each operation implements that interface. This separates the generic logic from the operation-specific logic, which is easier to maintain and is a more accurate abstraction.

We will illustrate the use of interfaces with an example: shape inference. Currently we allow a tensor to be unshaped, but for actual codegen every tensor needs shape information (e.g. for bounds checks). Therefore, we need to infer all tensor shapes from the shapes that are already known.
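To make the idea concrete outside of MLIR, here is a minimal sketch in plain C++. It is only an analogy: `Op`, `ShapeInferable`, `MulOp`, `TransposeOp`, `PrintOp` and `inferAll` are hypothetical names made up for this illustration, and real MLIR interfaces are generated from ODS rather than built on virtual inheritance like this.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

// Hypothetical stand-ins for illustration only; this is not MLIR's API.
using Shape = std::vector<int64_t>;

// Base class for every operation in this small model.
struct Op {
  virtual ~Op() = default;
  Shape operandShape; // shape of the input
  Shape resultShape;  // empty means "not inferred yet"
};

// The interface: the one thing the generic logic needs from an operation.
struct ShapeInferable {
  virtual ~ShapeInferable() = default;
  virtual void inferShapes() = 0;
};

// Element-wise op: the result has the same shape as the operand.
struct MulOp : Op, ShapeInferable {
  void inferShapes() override { resultShape = operandShape; }
};

// Transpose: the result shape is the operand shape reversed.
struct TransposeOp : Op, ShapeInferable {
  void inferShapes() override {
    resultShape.assign(operandShape.rbegin(), operandShape.rend());
  }
};

// An op that does not implement the interface.
struct PrintOp : Op {};

// The generic logic is written once, against the interface only.
// Returns the number of operations whose shape was inferred.
int inferAll(const std::vector<std::unique_ptr<Op>> &ops) {
  int inferred = 0;
  for (const auto &op : ops) {
    assert(op && "null operation in the list");
    // Rough analog of dyn_cast<ShapeInference>(op) in MLIR.
    if (auto *iface = dynamic_cast<ShapeInferable *>(op.get())) {
      iface->inferShapes();
      ++inferred;
    }
  }
  return inferred;
}
```

Note that `inferAll` never mentions `MulOp` or `TransposeOp`. Adding a new operation only requires implementing the interface on that operation; the generic routine stays untouched.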

Inlining

First, we need to inline all functions. This is because user-defined functions are shape-generic: the tensors in a function can have different shapes on each call, depending on the inputs. An alternative is function specialization, that is, creating a separate copy of the function for each combination of input shapes.

Inlining can be done pretty easily with the built-in inliner pass. We only need to implement the DialectInlinerInterface for Toy:

struct ToyInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All call operations within toy can be inlined.
  bool isLegalToInline(Operation *call, Operation *callable, bool wouldBeCloned) const final {
    return true;
  }

  /// All operations within toy can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }

  // All functions within toy can be inlined.
  bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
    return true;
  }

  /// Handle the given inlined terminator(toy.return) by replacing it with a new
  /// operation as necessary.
  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
    // Only "toy.return" needs to be handled here.
    auto returnOp = cast<ReturnOp>(op);

    // Replace the values directly with the return operands.
    assert(returnOp.getNumOperands() == valuesToRepl.size());
    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
      valuesToRepl[it.index()].replaceAllUsesWith(it.value());
  }

};

We also mark every non-main function as private when generating MLIR from the AST, so that the inliner can clean them up once all their calls have been inlined:

mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
  ...
  // If this function isn't main, then set the visibility to private.
  if (funcAST.getProto()->getName() != "main")
    function.setPrivate();

  return function;
}
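Why does visibility matter here? A public function might still be referenced from outside the module, so it cannot be deleted even when nothing in the module calls it. A private function with no remaining call sites is provably dead. The following plain C++ sketch models that rule; `FuncSymbol` and `eraseDeadPrivateFunctions` are hypothetical names for illustration and do not reflect how MLIR stores symbols internally.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical model of a function symbol; not MLIR's data structure.
struct FuncSymbol {
  std::string name;
  bool isPrivate;    // corresponds to calling setPrivate() during MLIR generation
  int remainingUses; // call sites that still reference the symbol
};

// Erase functions that are private and have no remaining call sites.
// Public functions always stay, because outside code may reference them.
void eraseDeadPrivateFunctions(std::vector<FuncSymbol> &funcs) {
  funcs.erase(std::remove_if(funcs.begin(), funcs.end(),
                             [](const FuncSymbol &f) {
                               assert(f.remainingUses >= 0);
                               return f.isPrivate && f.remainingUses == 0;
                             }),
              funcs.end());
}
```

In this model, once every call to a private function has been inlined its use count drops to zero and it is erased, while main survives because it is public.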

Next, we register the interface during dialect initialization:

/// Dialect initialization, the instance will be owned by the context. This is
/// the point of registration of types and operations for the dialect.
void ToyDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
      >();
  addInterfaces<ToyInlinerInterface>();
}

Then, we want the inliner to operate on toy.generic_call and toy.func. Since these operations are defined by our own dialect, the inliner has no prior knowledge that they are a call op and a function. We tell it by adding interfaces to the declarations of those operations:

// in Ops.td
include "mlir/Interfaces/CallInterfaces.td"
def FuncOp : Toy_Op<"func",
    [FunctionOpInterface, IsolatedFromAbove]> {
  ...
}

def GenericCallOp : Toy_Op<"generic_call",
    [DeclareOpInterfaceMethods<CallOpInterface>]> {
  ...
}

The last piece needed to make inlining work is a cast operation: the inputs of a function are unshaped, but the tensors in the main function are shaped. We explicitly cast the shaped tensors to unshaped ones (we will address this later!) so that the function can be inlined correctly.

For the cast op, notice that the result type is the target type of the cast, so its behavior is already fully described by the operation’s type signature.

def CastOp : Toy_Op<"cast", [
     DeclareOpInterfaceMethods<CastOpInterface>,
     DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
     Pure,
     SameOperandsAndResultShape
  ]> {
  let summary = "shape cast operation";
  let description = [{
    The "cast" operation converts a tensor from one type to an equivalent type
    without changing any data elements. The source and destination types must
    both be tensor types with the same element type. If both are ranked, then
    shape is required to match. The operation is invalid if converting to a
    mismatching constant dimension.
  }];

  let arguments = (ins F64Tensor:$input);
  let results = (outs F64Tensor:$output);

  let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
}
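The compatibility rule stated in the op description ("if both are ranked, then shape is required to match") can be sketched in plain C++. This is a simplified model with assumptions: the element type is fixed to f64 so only shapes are compared, and `TensorShape`, `isCastCompatible` and `castTo` are hypothetical names. In the real dialect this check would live in the hook that CastOpInterface asks the op to implement.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model of an f64 tensor type: std::nullopt stands for an
// unranked tensor (tensor<*xf64>), a vector for a ranked one such as
// tensor<2x3xf64>.
using TensorShape = std::optional<std::vector<int64_t>>;

// A cast is acceptable if either side is unranked, or if both sides are
// ranked and their shapes are identical.
bool isCastCompatible(const TensorShape &from, const TensorShape &to) {
  if (!from || !to)
    return true;
  return *from == *to;
}

// The result type of a cast is simply the target type, so "performing" the
// cast in this model only checks compatibility and returns the target.
TensorShape castTo(const TensorShape &input, const TensorShape &target) {
  assert(isCastCompatible(input, target) && "incompatible cast");
  return target;
}
```

For example, casting a 2x3 tensor to an unranked tensor is allowed, which is exactly what the inliner needs at a call site, while casting a 2x3 tensor to a 3x2 tensor is rejected.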

We can create the cast in materializeCallConversion, a hook that the MLIR inliner calls when the type of a value at a call does not match the type expected by the callable region.

struct ToyInlinerInterface : public DialectInlinerInterface {
  ...
  /// Attempts to materialize a conversion for a type mismatch between a call
  /// from this dialect, and a callable region. This method should generate an
  /// operation that takes 'input' as the only operand, and produces a single
  /// result of 'resultType'. If a conversion can not be generated, nullptr
  /// should be returned.
  Operation *materializeCallConversion(OpBuilder &builder, Value input,
                                       Type resultType,
                                       Location conversionLoc) const final {
    return CastOp::create(builder, conversionLoc, resultType, input);
  }
};

That’s it. We just add the pass returned by createInlinerPass to the pass manager, and it will inline all function calls for us.

Shape inference

The core of shape inference is computing the output type of an operation given its inputs. We will again leverage ODS, this time to declare the interface (we used ODS to declare operations before). We create a new file called ShapeInferenceInterface.td. Notice that in the methods section, we define an interface method called inferShapes.

#ifndef SHAPE_INFERENCE_INTERFACE
#define SHAPE_INFERENCE_INTERFACE

include "mlir/IR/OpBase.td"

def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
  let description = [{
    Interface to access a registered method to infer the return types for an
    operation that can be used during type inference.
  }];

  let methods = [
    InterfaceMethod<"Infer and set the output shape for the current operation.",
                    "void", "inferShapes">
  ];
}
#endif // SHAPE_INFERENCE_INTERFACE

Next, we include the interface in Ops.td and add its declaration to the ops that we care about: transpose, add, and mul, plus the cast op shown above.

include "toy/ShapeInferenceInterface.td"
def MulOp : Toy_Op<"mul",
    [..., DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
  ...
}

We also need to provide the implementation of inferShapes for those operations in Dialect.cpp. It’s pretty simple: for an element-wise op, the output shape is the same as the input shape. Note that inferShapes is a mutator: it returns nothing and sets the result type in place.

void MulOp::inferShapes() { getResult().setType(getLhs().getType()); }

For transpose, the result shape is the operand shape with its dimensions reversed (for a 2-D tensor, the two dimensions are swapped):

void TransposeOp::inferShapes() {
  auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
  SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
  getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}

Finally, we need to create a pass that leverages the interface to perform type inference. Create a new file called ShapeInferencePass.cpp:

struct ShapeInferencePass
    : public mlir::PassWrapper<ShapeInferencePass, OperationPass<toy::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ShapeInferencePass)
  StringRef getArgument() const override { return "toy-shape-inference"; }

  void runOnOperation() override {
    auto f = getOperation();

    // Populate the worklist with the operations that need shape inference:
    // these are operations that return a dynamic shape.
    llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
    f.walk([&](mlir::Operation *op) {
      if (returnsDynamicShape(op))
        opWorklist.insert(op);
    });

    // Iterate on the operations in the worklist until all operations have been
    // inferred or no change happened (fix point).
    while (!opWorklist.empty()) {
      // Find the next operation ready for inference, that is an operation
      // with all operands already resolved (non-generic).
      auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
      if (nextop == opWorklist.end())
        break;

      Operation *op = *nextop;
      opWorklist.erase(op);

      // Ask the operation to infer its output shapes.
      LDBG() << "Inferring shape for: " << *op;
      if (auto shapeOp = dyn_cast<ShapeInference>(op)) {
        shapeOp.inferShapes();
      } else {
        op->emitError("unable to infer shape of operation without shape "
                      "inference interface");
        return signalPassFailure();
      }
    }

    // If the operation worklist isn't empty, this indicates a failure.
    if (!opWorklist.empty()) {
      f.emitError("Shape inference failed, ")
          << opWorklist.size() << " operations couldn't be inferred\n";
      signalPassFailure();
    }
  }

  /// A utility method that returns if the given operation has all of its
  /// operands inferred.
  static bool allOperandsInferred(Operation *op) {
    return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
      return llvm::isa<RankedTensorType>(operandType);
    });
  }

  /// A utility method that returns if the given operation has a dynamically
  /// shaped result.
  static bool returnsDynamicShape(Operation *op) {
    return llvm::any_of(op->getResultTypes(), [](Type resultType) {
      return !llvm::isa<RankedTensorType>(resultType);
    });
  }
};

/// Create a Shape Inference pass.
std::unique_ptr<mlir::Pass> mlir::toy::createShapeInferencePass() {
  return std::make_unique<ShapeInferencePass>();
}

Notice a few things:

  1. ShapeInferencePass inherits from PassWrapper<ShapeInferencePass, OperationPass<toy::FuncOp>>, so it is an operation pass anchored on toy::FuncOp.
  2. Anchoring the pass on toy::FuncOp is intentional: we want it to run within a single function and never across function boundaries (it is “intraprocedural”). After inlining there is usually only the main function left, but anchoring the pass this way keeps it safe even if other functions remain.

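The shape inference pass works as follows:

  1. Walk all operations in the function body.
  2. Collect every operation that returns a dynamic shape, i.e. whose result is not yet a ranked tensor, into a worklist.
  3. Iteratively pick an operation whose operands are all inferred and call inferShapes on it, until the worklist is empty or no further progress can be made. If operations remain at that point, the pass fails.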
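The steps above can be modeled without MLIR. The following is a minimal sketch in plain C++, assuming a made-up miniature IR: `Op`, `Kind`, `ValueTable` and `runShapeInference` are hypothetical names, values are referred to by index, and an unranked value is represented by `std::nullopt`. The real pass uses a SmallPtrSet and the ShapeInference interface instead.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>;

enum class Kind { Mul, Transpose };

// Hypothetical miniature IR: each op reads values by index and defines
// exactly one result value.
struct Op {
  Kind kind;
  std::vector<std::size_t> operands; // indices into the value table
  std::size_t result;                // index into the value table
};

// A value's shape is std::nullopt while it is still unranked.
using ValueTable = std::vector<std::optional<Shape>>;

static bool allOperandsInferred(const Op &op, const ValueTable &values) {
  return std::all_of(op.operands.begin(), op.operands.end(),
                     [&](std::size_t v) { return values[v].has_value(); });
}

// Per-op inference rule, computed from the first operand.
static Shape inferShape(const Op &op, const ValueTable &values) {
  assert(!op.operands.empty());
  const Shape &in = *values[op.operands.front()];
  if (op.kind == Kind::Transpose)
    return Shape(in.rbegin(), in.rend()); // reversed dimensions
  return in;                              // element-wise: same shape
}

// Returns the number of operations left uninferred (0 means success).
std::size_t runShapeInference(const std::vector<Op> &ops, ValueTable &values) {
  // Steps 1 and 2: collect the ops whose result is still unranked.
  std::vector<const Op *> worklist;
  for (const Op &op : ops)
    if (!values[op.result])
      worklist.push_back(&op);

  // Step 3: repeatedly pick an op whose operands are all known.
  while (!worklist.empty()) {
    auto next = std::find_if(
        worklist.begin(), worklist.end(),
        [&](const Op *op) { return allOperandsInferred(*op, values); });
    if (next == worklist.end())
      break; // no op is ready, so no further progress is possible
    const Op *op = *next;
    worklist.erase(next);
    values[op->result] = inferShape(*op, values);
  }
  return worklist.size();
}
```

Because the loop always searches for a ready operation, the order in which operations were collected does not matter: an op that depends on another op's result simply waits until that result has a shape.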

Finally, we add the pass to the pass manager and we are done! Interfaces are a great software engineering technique that really shines in modern compilers.

