MLIR Technical Notes

MLIR

Building

git clone https://github.com/llvm/llvm-project.git
mkdir llvm-project/build
cd llvm-project/build
cmake -G Ninja ../llvm \
   -DLLVM_ENABLE_PROJECTS=mlir \
   -DLLVM_BUILD_EXAMPLES=ON \
   -DLLVM_TARGETS_TO_BUILD="X86;NVPTX;AMDGPU" \
   -DCMAKE_BUILD_TYPE=Release \
   -DLLVM_ENABLE_ASSERTIONS=ON \
#  -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DLLVM_ENABLE_LLD=ON

cmake --build . --target check-mlir

Basic Concepts

The concepts below are explained using a Toy program from the official MLIR tutorial:

# User defined generic function that operates on unknown shaped arguments.
def multiply_transpose(a, b) {
  return transpose(a) * transpose(b);
}

def main() {
  var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
  var b<2, 3> = [1, 2, 3, 4, 5, 6];
  var c = multiply_transpose(a, b);
  var d = multiply_transpose(b, a);
  print(d);
}

Converted to IR, it is defined in terms of the constructs described below.

mlir::Operation

%2 = "toy.mul"(%0, %1) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> loc("Toy/Ch2/codegen.toy":5:25)

An Operation is an operation, which can also be thought of as a computation. In the line above, %2 = "toy.mul"(%0, %1) is one operation. In Toy, toy.mul, toy.transpose, toy.constant, and so on are all Operation names.

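In code, mlir::Operation is the generic definition, carrying the generic interfaces and attributes; MulOp, TransposeOp, ConstantOp, and so on are specific definitions, carrying op-specific attributes. The former can be converted to the latter with llvm::dyn_cast (checked, yields a null op on mismatch) or llvm::cast (unchecked, asserts on mismatch); the latter is converted to the former with getOperation(). For example: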

void processConstantOp(mlir::Operation *operation) {
  ConstantOp op = llvm::dyn_cast<ConstantOp>(operation);

  // This operation is not an instance of `ConstantOp`.
  if (!op)
    return;

  // Get the internal operation instance wrapped by the smart pointer.
  mlir::Operation *internalOperation = op.getOperation();
  assert(internalOperation == operation &&
         "these operation instances are the same");
}
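
The relationship between the generic Operation and a specific op wrapper can be sketched without MLIR. The following is a simplified standalone model, not MLIR's implementation: the types Operation, ConstantOp, MulOp and the function dynCast are hypothetical names that only mimic how a classof check drives a checked cast.

```cpp
#include <cassert>
#include <string>

// Simplified stand-in for mlir::Operation: only a name is stored.
struct Operation {
  std::string name;
};

// Thin wrapper around an Operation*, similar in spirit to an MLIR op class.
// A wrapper holding nullptr means "the cast failed".
struct ConstantOp {
  Operation *op;
  explicit ConstantOp(Operation *o = nullptr) : op(o) {}
  static bool classof(const Operation *o) { return o->name == "toy.constant"; }
  explicit operator bool() const { return op != nullptr; }
  Operation *getOperation() const { return op; }
};

struct MulOp {
  Operation *op;
  explicit MulOp(Operation *o = nullptr) : op(o) {}
  static bool classof(const Operation *o) { return o->name == "toy.mul"; }
  explicit operator bool() const { return op != nullptr; }
  Operation *getOperation() const { return op; }
};

// Checked cast (hypothetical helper): returns a null wrapper when the
// operation is of another kind.
template <typename OpTy>
OpTy dynCast(Operation *o) {
  assert(o && "dynCast requires a non-null operation");
  return OpTy::classof(o) ? OpTy(o) : OpTy();
}
```

In this model, dynCast<ConstantOp> on an operation named "toy.constant" yields a wrapper whose getOperation() is the original pointer, while dynCast<MulOp> on the same operation yields a null wrapper.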

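A specific op can access its attributes directly. mlir::Operation cannot access a specific op's attributes directly, but it can reach them indirectly, for example: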

// getAttrOfType
IntegerAttr opsetAttr = op->getAttrOfType<::mlir::Attribute>("onnx_opset").dyn_cast_or_null<IntegerAttr>();
if (opsetAttr)
  opset = opsetAttr.getValue().getSExtValue();
// getAttr
mlir::Type targetType = op->getAttr("to").cast<::mlir::TypeAttr>().getValue();

mlir::Operation provides these commonly used interfaces:

OperationName getName() { return name; }
/// Remove this operation from its parent block and delete it.
void erase();
/// Remove the operation from its parent block, but don't delete it.
void remove();
/// Returns the operation block that contains this operation.
Block *getBlock() { return block; }
/// Return the context this operation is associated with.
MLIRContext *getContext();
/// Returns the closest surrounding operation that contains this operation
/// or nullptr if this is a top-level operation.
Operation *getParentOp() { return block ? block->getParentOp() : nullptr; }
/// Return the closest surrounding parent operation that is of type 'OpTy'.
template <typename OpTy> OpTy getParentOfType();
void dump();
unsigned getNumOperands();
Value getOperand(unsigned idx) { return getOpOperand(idx).get(); }
void setOperand(unsigned idx, Value value);
unsigned getNumResults();
/// Get the 'idx'th result of this operation.
OpResult getResult(unsigned idx) { return OpResult(this, idx); }
void replaceAllUsesWith(Operation *op);
void moveBefore(Operation *existingOp);
void moveBefore(Block *block, llvm::iplist<Operation>::iterator iterator);
void moveAfter(Operation *existingOp);
void moveAfter(Block *block, llvm::iplist<Operation>::iterator iterator);

mlir::Value

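A Value can be thought of as an operand, an argument, and so on, such as %arg0 and %1 in the example.

Names such as %1, %2, and %3 are in SSA (static single assignment) form: each is defined exactly once.

Almost every interface of Value is useful: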

Type getType() const;
void setType(Type newType);
MLIRContext *getContext() const { return getType().getContext(); }
Location getLoc() const;
Region *getParentRegion();
Block *getParentBlock();
void replaceAllUsesWith(Value newValue) const;
void replaceUsesWithIf(Value newValue, function_ref<bool(OpOperand &)> shouldReplace);
use_range getUses() const { return {use_begin(), use_end()}; }
bool hasOneUse() const;
bool use_empty() const;
void dump();

BlockArgument

Examples are %arg0, %arg1, and so on in the example.

class BlockArgument : public Value{
  /// Returns the block that owns this argument.
  Block *getOwner() const { return getImpl()->owner; }

  /// Return the type of this value.
  Type getType() const { return getImpl()->type; }

  /// Set the type of this value.
  void setType(Type newType) { getImpl()->type = newType; }

  /// Returns the number of this argument.
  unsigned getArgNumber() const { return getImpl()->index; }
}

OpResult

Examples are %0, %1, %2, and so on in the example.

class OpResult : public Value {
public:
  using Value::Value;

  static bool classof(Value value) {
    return value.getKind() != Kind::BlockArgument;
  }

  /// Returns the operation that owns this result.
  Operation *getOwner() const;

  /// Returns the number of this result.
  unsigned getResultNumber() const;
}

Getting the Operation from a Value works as follows:

// Suppose `value` corresponds to %2; then:
Operation * operation = value.getDefiningOp();
toy::MulOp op = value.getDefiningOp<toy::MulOp>();
// get Value from operation
Value inner_value = operation->getResult(0);
assert(inner_value == value);

mlir::Type

A Type is the type of a Value. For instance, tensor<*xf64>, tensor<2x3xf64>, and so on in the example correspond to TensorType, which derives from mlir::Type.

Commonly used Type interfaces:

bool operator==(Type other) const { return impl == other.impl; }
bool operator!=(Type other) const { return !(*this == other); }
explicit operator bool() const { return impl; }
bool operator!() const { return impl == nullptr; }
template <typename U> bool isa() const;
template <typename First, typename Second, typename... Rest> bool isa() const;
template <typename U> U dyn_cast() const;
template <typename U> U dyn_cast_or_null() const;
template <typename U> U cast() const;
bool isIndex() const;
bool isBF16() const;
bool isF16() const;
bool isF32() const;
bool isF64() const;
bool isF80() const;
bool isF128() const;

/// Return true if this is an integer type with the specified width.
bool isInteger(unsigned width) const;

ShapedType

ShapedType describes a shape. A shape is either ranked or unranked, and each dimension of a ranked shape is either static or dynamic.

Shapes take the following forms:

  • [*] if it is an unranked shape
  • [?, 2] if a rank 2 tensor with one unknown dimension
  • [3, 4] is a rank 2 static tensor
  • [] is a scalar
  • [1] is a rank 1 tensor with 1 element
  • [invalid] for an invalid shape

It provides the following interfaces, implemented with a Trait:

class ShapedType : public ::mlir::TypeInterface<ShapedType, detail::ShapedTypeInterfaceTraits> {
public:
  using ::mlir::TypeInterface<ShapedType, detail::ShapedTypeInterfaceTraits>::TypeInterface;
  template <typename ConcreteType>
  struct Trait : public detail::ShapedTypeTrait<ConcreteType> {};
  ::mlir::ShapedType cloneWith(::llvm::Optional<::llvm::ArrayRef<int64_t>> shape, ::mlir::Type elementType) const;
  ::mlir::Type getElementType() const;
  bool hasRank() const;
  ::llvm::ArrayRef<int64_t> getShape() const;
  /// Returns whether the shape has a rank.
  bool hasRank() const;
  /// Returns the element type.
  Type getElementType() const;
  /// Populates the dimensions from shape referenced.
  /// Requires: shape is ranked.
  void getDims(SmallVectorImpl<int64_t> &res) const;
  /// Populates the dimensions of the ShapeTypeComponents.
  /// Requires: shape is ranked.
  void getDims(ShapedTypeComponents &res) const;
  /// Returns the size of the index'th dimension.
  /// Requires: shape is ranked.
  int64_t getDimSize(int index) const;
  /// Returns whether the index'th dimension is dynamic.
  /// Requires: shape is ranked.
  bool isDynamicDim(int index) const { // true if dimSize equals kDynamicSize (i.e. -1)
    return ShapedType::isDynamic(getDimSize(index));
  }
  /// Returns whether the shape is fully static.
  bool hasStaticShape() const;
  /// Returns the rank of the shape.
  /// Requires: shape is ranked.
  int64_t getRank() const;
  /// Returns the number of elements in the shape.
  /// Requires: hasStaticShape
  int64_t getNumElements() const;
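
The ranked, static, and dynamic distinctions can be illustrated with a small standalone model. This is only a sketch under stated assumptions: RankedShape is a hypothetical type, not MLIR's ShapedType, and kDynamicSize is fixed at -1 to match the value noted in the listing.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Sentinel for an unknown dimension (assumed to be -1, as in the listing).
constexpr int64_t kDynamicSize = -1;

// Hypothetical ranked shape: one entry per dimension.
struct RankedShape {
  std::vector<int64_t> dims;

  int64_t getRank() const { return static_cast<int64_t>(dims.size()); }

  // True when the index'th dimension is unknown.
  bool isDynamicDim(int index) const {
    assert(index >= 0 && index < getRank() && "dimension index out of range");
    return dims[index] == kDynamicSize;
  }

  // True when no dimension is dynamic.
  bool hasStaticShape() const {
    for (int64_t d : dims)
      if (d == kDynamicSize)
        return false;
    return true;
  }

  // Product of all dimensions; a rank-0 shape (a scalar) has one element.
  int64_t getNumElements() const {
    assert(hasStaticShape() && "element count needs a static shape");
    int64_t n = 1;
    for (int64_t d : dims)
      n *= d;
    return n;
  }
};
```

With this model, the shape [?, 2] has rank 2 and is not static, [3, 4] is static with 12 elements, and [] is a scalar with exactly one element.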

This class is generated automatically by TableGen from the following definition:

def ShapedTypeInterface : TypeInterface<"ShapedType"> {
  let cppNamespace = "::mlir";
  let description = [{
    ...

TensorType

Examples are tensor<*xf64>, tensor<2x3xf64>, and so on. TensorType derives from Type and from ShapedType::Trait<TensorType>.

It is defined as follows:

class TensorType : public Type, public ShapedType::Trait<TensorType> {
public:
  using Type::Type;

  /// Returns the element type of this tensor type.
  Type getElementType() const;

  /// Returns if this type is ranked, i.e. it has a known number of dimensions.
  bool hasRank() const;

  /// Returns the shape of this tensor type.
  ArrayRef<int64_t> getShape() const;

  /// Clone this type with the given shape and element type. If the
  /// provided shape is `None`, the current shape of the type is used.
  TensorType cloneWith(Optional<ArrayRef<int64_t>> shape,
                       Type elementType) const;

  /// Return true if the specified element type is ok in a tensor.
  static bool isValidElementType(Type type);

  /// Methods for support type inquiry through isa, cast, and dyn_cast.
  static bool classof(Type type);

  /// Allow implicit conversion to ShapedType.
  operator ShapedType() const { return cast<ShapedType>(); }
};

Related information can be obtained from a Value as follows:

mlir::Type type = value.getType(); // get mlir::Type
mlir::TensorType tensor_type = type.cast<mlir::TensorType>(); // cast to tensor type
ArrayRef<int64_t> shape = tensor_type.getShape(); // get shape
mlir::Type eleType = tensor_type.getElementType();
assert(eleType.isF64() == true);

RankedTensorType and UnrankedTensorType

tensor<*xf32> corresponds to UnrankedTensorType; RankedTensorType takes the following forms:

// Known rank but unknown dimensions.
tensor<? x ? x ? x ? x f32>
// Partially known dimensions.
tensor<? x ? x 13 x ? x f32>
// Full static shape.
tensor<17 x 4 x 13 x 4 x f32>
// Tensor with rank zero. Represents a scalar.
tensor<f32>
// Zero-element dimensions are allowed; they should be optimized away before lowering tensors to vectors.
tensor<0 x 42 x f32>

Usage example:

// create
mlir::Type getType(ArrayRef<int64_t> shape) {
  // If the shape is empty, then this type is unranked.
  if (shape.empty())
    return mlir::UnrankedTensorType::get(builder.getF64Type());

  // Otherwise, we use the given shape.
  return mlir::RankedTensorType::get(shape, builder.getF64Type());
}

mlir::Attribute

Attributes define the properties of an op. For example, part of the definition of the ONNX LstmOp is:

  let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$X,
    AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$W,
    AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$R,
    AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef, NoneType]>:$B,
    AnyTypeOf<[TensorOf<[I32]>, AnyMemRef, NoneType]>:$sequence_lens,
    AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef, NoneType, NoneType]>:$initial_h,
    AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef, NoneType, NoneType, NoneType]>:$initial_c,
    AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef, NoneType, NoneType, NoneType, NoneType]>:$P,
    OptionalAttr<F32ArrayAttr>:$activation_alpha,
    OptionalAttr<F32ArrayAttr>:$activation_beta,
    OptionalAttr<StrArrayAttr>:$activations,
    OptionalAttr<F32Attr>:$clip,
    DefaultValuedStrAttr<StrAttr, "forward">:$direction,
    OptionalAttr<SI64Attr>:$hidden_size,
    DefaultValuedAttr<SI64Attr, "0">:$input_forget);
......

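Here the AnyTypeOf entries define Values, and the Attr entries define attributes. There are two attribute qualifiers (an attribute without a qualifier is required):

  • OptionalAttr: optional attribute

  • DefaultValuedAttr: attribute with a default value

Attributes come in these types:

  • Unsigned integer: UI64Attr, UI32Attr, UI16Attr, UI8Attr, UI1Attr
  • Signed integer: SI64Attr, SI32Attr, SI16Attr, SI8Attr, SI1Attr
  • Floating point: F32Attr, F64Attr
  • String: StrAttr
  • Boolean: BoolAttr
  • Array: BoolArrayAttr, StrArrayAttr, I32ArrayAttr, F32ArrayAttr
  • Dictionary: DictionaryAttr

def F32Attr : FloatAttrBase<F32, "32-bit float attribute">;
def F64Attr : FloatAttrBase<F64, "64-bit float attribute">;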

Attributes of each type can be created through a builder, for example:

::mlir::Builder((*this)->getContext()).getF32FloatAttr(7.0)
::mlir::Builder((*this)->getContext()).getF64ArrayAttr({7., 8.})
::mlir::Builder((*this)->getContext()).getStrArrayAttr({"a", "b"})

mlir::FuncOp

It takes the following forms:

// External function definitions.
func @abort()
func @scribble(i32, i64, memref<? x 128 x f32, #layout_map0>) -> f64

// A function that returns its argument twice:
func @count(%x: i64) -> (i64, i64)
attributes {fruit: "banana"} {
  return %x, %x: i64, i64
}

// A function with an argument attribute
func @example_fn_arg(%x: i32 {swift.self = unit})

// A function with a result attribute
func @example_fn_result() -> (f64 {dialectName.attrName = 0 : i64})

// A function with an attribute
func @example_fn_attr() attributes {dialectName.attrName = false}

It is created as follows:

// Create a module with 2 functions.
OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
for (StringRef name : {"secret", "not_secret"}) {
  FuncOp func =
    FuncOp::create(builder.getUnknownLoc(), name,
                   builder.getFunctionType(llvm::None, llvm::None));
  func.setPrivate();
  module->push_back(func);
}

It is traversed as follows:

// Walk ops of a specific type
void runOnOperation() override {
  FuncOp func = getOperation();
  SmallVector<scf::ForOp, 4> loops;
  func.walk([&](scf::ForOp forOp) {
    if (getNestingDepth(forOp) == loopDepth)
      loops.push_back(forOp);
  });
  auto annotateFn = [this](unsigned i, Operation *op, OpBuilder b) {
    if (annotateLoop) {
      op->setAttr("unrolled_iteration", b.getUI32IntegerAttr(i));
    }
  };
  for (auto loop : loops)
    (void)loopUnrollByFactor(loop, unrollFactor, annotateFn);
}

// Walk all ops
func.walk([&](mlir::Operation *op) {
   // process Operation `op`.
});
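
To make the traversal concrete, here is a minimal standalone sketch of a walk over nested operations. It is a simplified model, not MLIR's implementation: Operation, add, walk, and collectNames are hypothetical names, regions and blocks are collapsed into a single list of nested operations, and post-order visiting (nested operations before their parent) is an assumption of this sketch.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Simplified operation: regions and blocks are collapsed into one list of
// nested operations to keep the sketch short.
struct Operation {
  std::string name;
  std::vector<std::unique_ptr<Operation>> nested;

  explicit Operation(std::string n) : name(std::move(n)) {
    assert(!name.empty() && "operation name must not be empty");
  }

  // Appends a nested operation and returns a non-owning pointer to it.
  Operation *add(std::string n) {
    nested.push_back(std::unique_ptr<Operation>(new Operation(std::move(n))));
    return nested.back().get();
  }
};

// Post-order walk: visit nested operations first, then the operation itself.
void walk(Operation &op, const std::function<void(Operation &)> &fn) {
  for (auto &child : op.nested)
    walk(*child, fn);
  fn(op);
}

// Collects the names of all visited operations in visiting order.
std::vector<std::string> collectNames(Operation &root) {
  std::vector<std::string> names;
  walk(root, [&](Operation &op) { names.push_back(op.name); });
  return names;
}
```

For a function holding a loop that contains one multiplication, followed by a return, this sketch visits the multiplication, then the loop, then the return, and the function last.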

mlir::Block

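A Block is a sequence of operations enclosed in {}. Most ops contain no block; a FuncOp with a body contains at least one Block (its entry block).

So for an op placed directly in a function body, op->getParentOp() returns the FuncOp.

Block provides the following interfaces: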

Region *getParent() const;

/// Returns the closest surrounding operation that contains this block.
Operation *getParentOp();
/// Return if this block is the entry block in the parent region.
bool isEntryBlock();

/// Insert this block (which must not already be in a region) right before the specified block.
void insertBefore(Block *block);

/// Unlink this block from its current region and insert it right before the specific block.
void moveBefore(Block *block);

/// Unlink this Block from its parent region and delete it.
void erase();
BlockArgListType getArguments() { return arguments; }
unsigned getNumArguments() { return arguments.size(); }
BlockArgument getArgument(unsigned i) { return arguments[i]; }
iterator begin() { return operations.begin(); }
iterator end() { return operations.end(); }
Operation &back() { return operations.back(); }
Operation &front() { return operations.front(); }

mlir::Region

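Most ops have no region; regions appear in ops such as control-flow ops.

A non-empty Region contains one or more Blocks, and the arguments of the first block are also the arguments of the region.

In TableGen, region constraints are defined as follows: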

class Region<Pred condition, string descr = ""> :
    RegionConstraint<condition, descr>;

// Any region.
def AnyRegion : Region<CPred<"true">, "any region">;

// A region with the given number of blocks.
class SizedRegion<int numBlocks> : Region<
  CPred<"::llvm::hasNItems($_self, " # numBlocks # ")">,
  "region with " # numBlocks # " blocks">;

Taking scf.if as an example:

def IfOp : SCF_Op<"if",
      [DeclareOpInterfaceMethods<RegionBranchOpInterface,
                                 ["getNumRegionInvocations",
                                  "getRegionInvocationBounds"]>,
       SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveSideEffects,
       NoRegionArguments]> {
  let summary = "if-then-else operation";
  let arguments = (ins I1:$condition);
  let results = (outs Variadic<AnyType>:$results);
  let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion);
  let extraClassDeclaration = [{
    OpBuilder getThenBodyBuilder(OpBuilder::Listener *listener = nullptr) {
      Block* body = getBody(0);
      return getResults().empty() ? OpBuilder::atBlockTerminator(body, listener)
                                  : OpBuilder::atBlockEnd(body, listener);
    }
    OpBuilder getElseBodyBuilder(OpBuilder::Listener *listener = nullptr) {
      Block* body = getBody(1);
      return getResults().empty() ? OpBuilder::atBlockTerminator(body, listener)
                                  : OpBuilder::atBlockEnd(body, listener);
    }
    Block* thenBlock();
    YieldOp thenYield();
    Block* elseBlock();
    YieldOp elseYield();
  }];
  let hasFolder = 1;
  ...

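As the definition shows, IfOp declares two regions: thenRegion must hold exactly one block (SizedRegion<1>), while elseRegion (AnyRegion) may be empty. A piece of IR looks like this: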

%x, %y = scf.if %b -> (f32, f32) {
  %x_true = ...
  %y_true = ...
  scf.yield %x_true, %y_true : f32, f32
} else {
  %x_false = ...
  %y_false = ...
  scf.yield %x_false, %y_false : f32, f32
}

Region provides the following interfaces:

MLIRContext *getContext();

/// Return a location for this region. This is the location attached to the
/// parent container. The region must have a valid parent container.
Location getLoc();

//===--------------------------------------------------------------------===//
// Block list management
//===--------------------------------------------------------------------===//

using BlockListType = llvm::iplist<Block>;
BlockListType &getBlocks() { return blocks; }
iterator begin() { return blocks.begin(); }
iterator end() { return blocks.end(); }
Block &back() { return blocks.back(); }
Block &front() { return blocks.front(); }
unsigned getNumArguments() { return getArguments().size(); }
BlockArgument getArgument(unsigned i) { return getArguments()[i]; }

mlir::ModuleOp

ModuleOp can be thought of as the top-level op. It takes the following form:

module {
  func @foo()
}

mlir::MLIRContext

The MLIR context, which can be thought of as the top-level owner of a set of objects. All MLIR objects depend on an MLIRContext.

mlir::Builder

A builder for creating global objects such as types and attributes. It is declared as follows:

class Builder {
public:
  explicit Builder(MLIRContext *context) : context(context) {}
  explicit Builder(Operation *op) : Builder(op->getContext()) {}

  MLIRContext *getContext() const { return context; }

  // Locations.
  Location getUnknownLoc();
  // Types.
  FloatType getBF16Type();
  FloatType getF16Type();
  FloatType getF32Type();
  FloatType getF64Type();
  IntegerType getIntegerType(unsigned width);
  IntegerType getIntegerType(unsigned width, bool isSigned);
  // Attributes.
  NamedAttribute getNamedAttr(StringRef name, Attribute val);
  UnitAttr getUnitAttr();
  BoolAttr getBoolAttr(bool value);
  DictionaryAttr getDictionaryAttr(ArrayRef<NamedAttribute> value);
  IntegerAttr getIntegerAttr(Type type, int64_t value);
  IntegerAttr getIntegerAttr(Type type, const APInt &value);
  FloatAttr getFloatAttr(Type type, double value);
  FloatAttr getFloatAttr(Type type, const APFloat &value);
  StringAttr getStringAttr(const Twine &bytes);
  ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
  // Convenience methods for fixed types.
  FloatAttr getF16FloatAttr(float value);
  FloatAttr getF32FloatAttr(float value);
  FloatAttr getF64FloatAttr(double value);
  IntegerAttr getI8IntegerAttr(int8_t value);
  IntegerAttr getI16IntegerAttr(int16_t value);
  IntegerAttr getI32IntegerAttr(int32_t value);
  ArrayAttr getI32ArrayAttr(ArrayRef<int32_t> values);
  ArrayAttr getF32ArrayAttr(ArrayRef<float> values);

OpBuilder

OpBuilder derives from Builder and is used to create Operations.

It provides the following interface:

/// Create an operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args>
OpTy create(Location location, Args &&...args) {
  OperationState state(location,
  getCheckRegisteredInfo<OpTy>(location.getContext()));
  OpTy::build(*this, state, std::forward<Args>(args)...);
  auto *op = createOperation(state);
  auto result = dyn_cast<OpTy>(op);
  assert(result && "builder didn't return the right type");
  return result;
}

Taking a piece of onnx-mlir code as an example, ops are created as follows:

auto gemmOp = builder.create<ONNXGemmOp>(UnknownLoc::get(&ctx),
                                         /*Y=*/yType, /*A=*/aVal, /*B=*/bVal, /*C=*/cVal, alphaAttr, betaAttr,
                                         aTransAttr, bTransAttr);
gemmOp.getResult().setType(yType);

llvm::SmallVector<Value, 1> results = {gemmOp.getResult()};
builder.create<ReturnOp>(UnknownLoc::get(&ctx), results);

Dialect

To be added.

Traits

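Further reading: Traits; "C++ traits技术浅谈" (A Brief Look at C++ Traits Techniques)

Traits, also known as the characteristic-extraction technique, extract the type associated with the object passed in (a return type, for instance), so that a single interface can provide the matching behavior for each type.

In MLIR the base class of traits is TraitBase<ConcreteType, TraitType>. Its variants include AttributeTrait, OpTrait, TypeTrait, and so on.

ConcreteType is the concrete class the trait is attached to, and TraitType is the trait class itself.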
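
As a reminder of the plain C++ traits technique that this builds on, the following sketch extracts a result type from the element type so that one interface serves several types. It is an illustration only: SumTraits, sum, and sumExample are hypothetical names chosen for this example and are unrelated to MLIR.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Primary template: by default, summing elements of type T yields T.
template <typename T>
struct SumTraits {
  using ResultType = T;
};

// Specialization: sum 8-bit values into a wider type to avoid overflow.
template <>
struct SumTraits<uint8_t> {
  using ResultType = uint32_t;
};

// One interface; the result type is extracted from the trait.
template <typename T>
typename SumTraits<T>::ResultType sum(const std::vector<T> &values) {
  typename SumTraits<T>::ResultType total = 0;
  for (const T &v : values)
    total += v;
  return total;
}

// Usage: 200 + 100 does not fit in uint8_t, but fits in the extracted type.
inline void sumExample() {
  std::vector<uint8_t> bytes = {200, 100};
  assert(sum(bytes) == 300u);
}
```

The caller writes sum(values) in every case; the trait decides which type carries the result.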

Steps for Using a Trait

1 Define the Trait. An example Trait class definition:

template <typename ConcreteType>
class MyTrait : public OpTrait::TraitBase<ConcreteType, MyTrait> {
public:
  /// Override the 'verifyTrait' hook to add additional verification on the
  /// concrete operation.
  static LogicalResult verifyTrait(Operation *op) {
    // ...
  }
};

2 Attach the Trait. For example:

class MyOp : public Op<MyOp, MyTrait> {};

3 Use the Trait. For example:

Operation *op = ..;
if (op->hasTrait<MyTrait>())
  ...;
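
The three steps can be tried out in standalone C++ with a reduced model of the same pattern. This is a sketch under stated assumptions, not MLIR's code: TraitBase and Op here are simplified re-creations, the ops AddOp and ReturnOp are hypothetical, and hasTrait is a compile-time static query rather than the runtime query on Operation shown above.

```cpp
#include <type_traits>

// Base of every trait: records the op class (ConcreteType) the trait is
// attached to and the trait template itself (TraitType).
template <typename ConcreteType, template <typename> class TraitType>
struct TraitBase {};

// Step 1: define traits.
template <typename ConcreteType>
struct IsCommutative : TraitBase<ConcreteType, IsCommutative> {};

template <typename ConcreteType>
struct IsTerminator : TraitBase<ConcreteType, IsTerminator> {};

// An op inherits from every trait instantiated with itself (CRTP).
template <typename ConcreteType, template <typename> class... Traits>
struct Op : Traits<ConcreteType>... {
  // Step 3: query at compile time whether a trait is attached.
  template <template <typename> class Trait>
  static constexpr bool hasTrait() {
    return std::is_base_of<Trait<ConcreteType>, ConcreteType>::value;
  }
};

// Step 2: attach traits to concrete ops.
struct AddOp : Op<AddOp, IsCommutative> {};
struct ReturnOp : Op<ReturnOp, IsTerminator> {};

static_assert(AddOp::hasTrait<IsCommutative>(), "AddOp is commutative");
static_assert(!AddOp::hasTrait<IsTerminator>(), "AddOp is not a terminator");
```

Because each trait is a base class of the op it is attached to, the query reduces to an inheritance check.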

CastOp Example

CastOp is defined as follows:

def CastOp : Toy_Op<"cast", [
     DeclareOpInterfaceMethods<CastOpInterface>,
     DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
     NoSideEffect,
     SameOperandsAndResultShape
  ]> {
  let summary = "shape cast operation";
  ...;

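This op uses four traits, which fall into two groups. The generic group is NoSideEffect and SameOperandsAndResultShape; the op-specific group is CastOpInterface and ShapeInferenceOpInterface, declared through DeclareOpInterfaceMethods.

The two groups differ in three ways:

  • A generic trait can be used by any op, and the op does not need to define a dedicated interface for it; an op-specific one requires the op to implement the corresponding interface.

  • Generic traits are invoked directly by MLIR internals; op-specific ones must be called explicitly.

  • They are queried differently:

    Operation *op = ..;
    if (op->hasTrait<OpTrait::SameOperandsAndResultShape>())
      ...;
    if (isa<ShapeInference>(op))
      ...;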
    

The code generated by TableGen looks like this:

class CastOpAdaptor {
public:
  CastOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr, ::mlir::RegionRange regions = {});

  CastOpAdaptor(CastOp &op);
  ...;
};

class CastOp : public ::mlir::Op<CastOp, ::mlir::OpTrait::ZeroRegion, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::OneTypedResult<::mlir::TensorType>::Impl, ::mlir::OpTrait::ZeroSuccessor, ::mlir::OpTrait::OneOperand, ::mlir::CastOpInterface::Trait, ShapeInference::Trait, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::OpTrait::SameOperandsAndResultShape> {
public:
  using Op::Op;
  using Op::print;
  using Adaptor = CastOpAdaptor;
public:
  ...;

ShapeInferenceOpInterface

It is defined as follows:

def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
  let description = [{
    Interface to access a registered method to infer the return types for an
    operation that can be used during type inference.
  }];

  let methods = [
    InterfaceMethod<"Infer and set the output shape for the current operation.",
                    "void", "inferShapes">
  ];
}

CastOp therefore has to implement the inferShapes interface:

void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }

CastOpInterface

def CastOpInterface : OpInterface<"CastOpInterface"> {
  let description = [{
    A cast-like operation is one that converts from a set of input types to a
    set of output types. The arity of the inputs may be from 0-N, whereas the
    arity of the outputs may be anything from 1-N. Cast-like operations are
    trivially removable in cases where they produce an No-op, i.e when the
    input types and output types match 1-1.
  }];
  let cppNamespace = "::mlir";

  let methods = [
    StaticInterfaceMethod<[{
        Returns true if the given set of input and result types are compatible
        to cast using this cast operation.
      }],
      "bool", "areCastCompatible",
      (ins "::mlir::TypeRange":$inputs, "::mlir::TypeRange":$outputs)
    >,
  ];
  ...;
}

CastOp therefore has to implement the areCastCompatible interface:

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  // The inputs must be Tensors with the same element type.
  TensorType input = inputs.front().dyn_cast<TensorType>();
  TensorType output = outputs.front().dyn_cast<TensorType>();
  if (!input || !output || input.getElementType() != output.getElementType())
    return false;
  // The shape is required to match if both types are ranked.
  return !input.hasRank() || !output.hasRank() || input == output;
}
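
The compatibility rule can be exercised on its own with a reduced model. The sketch below assumes a hypothetical TensorType struct with helper constructors ranked and unranked, and covers only the single-input, single-output case; it mirrors the decision logic above rather than reproducing MLIR's types.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical tensor type: an element type plus an optional shape.
struct TensorType {
  std::string elementType;
  bool hasRank;
  std::vector<long> shape; // empty and unused when hasRank is false

  bool operator==(const TensorType &other) const {
    return elementType == other.elementType && hasRank == other.hasRank &&
           shape == other.shape;
  }
};

TensorType unranked(std::string elementType) {
  assert(!elementType.empty() && "element type must be named");
  return TensorType{std::move(elementType), false, {}};
}

TensorType ranked(std::string elementType, std::vector<long> shape) {
  assert(!elementType.empty() && "element type must be named");
  return TensorType{std::move(elementType), true, std::move(shape)};
}

// Element types must match, and shapes must match only when both sides
// are ranked.
bool areCastCompatible(const TensorType &input, const TensorType &output) {
  if (input.elementType != output.elementType)
    return false;
  return !input.hasRank || !output.hasRank || input == output;
}
```

Under this model an unranked f64 tensor can be cast to a 2x3 f64 tensor and back, while a 2x3 tensor cannot be cast to a 3x2 tensor or to a tensor with a different element type.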

SameOperandsAndResultShape

Adds verification of all operand and result shapes to the op, ensuring that they are all the same.

It is defined in OpDefinition.h as follows:

template <typename ConcreteType>
class SameOperandsAndResultShape
    : public TraitBase<ConcreteType, SameOperandsAndResultShape> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifySameOperandsAndResultShape(op);
  }
};

It is defined in OpBase.td as follows:

// Op has same operand and result shape.
def SameOperandsAndResultShape : NativeOpTrait<"SameOperandsAndResultShape">;

NoSideEffect

Defined in include/mlir/Interfaces/SideEffectInterfaces.td, this trait declares that the op has no side effects, so the generated getEffects reports none:

void CastOp::getEffects(::mlir::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) {
}

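Exactly when side effects need special handling is still unclear to me (to be investigated). Note that an op marked NoSideEffect may be erased when its results are unused, so the trait belongs only on ops that really have no side effects, not on ops such as print.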

Generic Traits

See include/mlir/IR/OpBase.td:

// Op defines an affine scope.
def AffineScope : NativeOpTrait<"AffineScope">;
// Op defines an automatic allocation scope.
def AutomaticAllocationScope : NativeOpTrait<"AutomaticAllocationScope">;
// Op supports operand broadcast behavior.
def ResultsBroadcastableShape :
  NativeOpTrait<"ResultsBroadcastableShape">;
// X op Y == Y op X
def Commutative  : NativeOpTrait<"IsCommutative">;
// op op X == op X (unary) / X op X == X (binary)
def Idempotent  : NativeOpTrait<"IsIdempotent">;
// op op X == X
def Involution  : NativeOpTrait<"IsInvolution">;
// Op behaves like a constant.
def ConstantLike : NativeOpTrait<"ConstantLike">;
// Op is isolated from above.
def IsolatedFromAbove : NativeOpTrait<"IsIsolatedFromAbove">;
// Op results are float or vectors/tensors thereof.
def ResultsAreFloatLike : NativeOpTrait<"ResultsAreFloatLike">;
// Op has the same operand type.
def SameTypeOperands : NativeOpTrait<"SameTypeOperands">;
// Op has same shape for all operands.
def SameOperandsShape : NativeOpTrait<"SameOperandsShape">;
// Op has same operand and result shape.
def SameOperandsAndResultShape : NativeOpTrait<"SameOperandsAndResultShape">;
// Op has the same element type (or type itself, if scalar) for all operands.
def SameOperandsElementType : NativeOpTrait<"SameOperandsElementType">;
// Op has the same operand and result element type (or type itself, if scalar).
def SameOperandsAndResultElementType :
  NativeOpTrait<"SameOperandsAndResultElementType">;
// Op is a terminator.
def Terminator : NativeOpTrait<"IsTerminator">;
// Op can be safely normalized in the presence of MemRefs with
// non-identity maps.
def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">;
// Op is elementwise on tensor/vector operands and results.
def Elementwise : NativeOpTrait<"Elementwise">;
// Elementwise op can be applied to scalars instead tensor/vector operands.
def Scalarizable : NativeOpTrait<"Scalarizable">;
// Elementwise op can be applied to all-vector operands.
def Vectorizable : NativeOpTrait<"Vectorizable">;
// Elementwise op can be applied to all-tensor operands.
def Tensorizable : NativeOpTrait<"Tensorizable">;

Miscellaneous

CPred

Wraps a C++ expression as a predicate. It is defined as follows:

// `$_builder` will be replaced by a mlir::Builder instance.
//  `$_op` will be replaced by the current operation.
// `$_self` will be replaced with the entity this predicate is attached to.

class CPred<code pred> : Pred {
  code predExpr = "(" # pred # ")";
}

Usage examples:

// Whether a type is a TensorType.
def IsTensorTypePred : CPred<"$_self.isa<::mlir::TensorType>()">;
// Whether a type is an UnrankedTensorType
def IsUnrankedTensorTypePred: CPred<"$_self.isa<::mlir::UnrankedTensorType>()">;
// Signless integer type of a specific width.
class I<int width>
    : Type<CPred<"$_self.isSignlessInteger(" # width # ")">,
                  width # "-bit signless integer", "::mlir::IntegerType">,
      BuildableType<"$_builder.getIntegerType(" # width # ")"> {
  int bitwidth = width;
}
def I1  : I<1>;
def I8  : I<8>;
def I16 : I<16>;

Signless Integer

IntegerType comes in three kinds:

// Sized integers like i1, i4, i8, i16, i32.
signed-integer-type ::= `si` [1-9][0-9]*
unsigned-integer-type ::= `ui` [1-9][0-9]*
signless-integer-type ::= `i` [1-9][0-9]*
integer-type ::= signed-integer-type |
                 unsigned-integer-type |
                 signless-integer-type

The meaning of a signless integer is explained by this passage from Integer signedness semantics:

For the standard dialect, the choice is to have signless integer types. An integer value does not have an intrinsic sign, and it’s up to the specific op for interpretation. For example, ops like arith.addi and arith.muli do two’s complement arithmetic, but some other operations get a sign, e.g. arith.divsi vs arith.divui.

In other words, the value carries no sign of its own; each operation decides how to interpret it.
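
The idea can be demonstrated in plain C++ with 8-bit values. This is a sketch, not MLIR code: I8, addi, divsi, and divui are hypothetical names that only echo the arith ops quoted above, and the signed reinterpretation assumes a two's complement platform.

```cpp
#include <cassert>
#include <cstdint>

// A "signless" 8-bit value: just a bit pattern with no intrinsic sign.
using I8 = uint8_t;

// Like arith.addi: two's complement addition gives the same bits whether
// the operands are viewed as signed or unsigned.
I8 addi(I8 a, I8 b) { return static_cast<I8>(a + b); }

// Like arith.divsi: interpret both bit patterns as signed.
I8 divsi(I8 a, I8 b) {
  assert(b != 0 && "division by zero");
  return static_cast<I8>(static_cast<int8_t>(a) / static_cast<int8_t>(b));
}

// Like arith.divui: interpret both bit patterns as unsigned.
I8 divui(I8 a, I8 b) {
  assert(b != 0 && "division by zero");
  return static_cast<I8>(a / b);
}
```

The bit pattern 0xF0 is 240 when read as unsigned and -16 when read as signed. Adding 0x04 gives 0xF4 under either reading, whereas dividing by 0x04 gives 0xFC (-4) with divsi and 0x3C (60) with divui.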
