本文主要是介绍MLIR笔记(6),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
5. 方言与操作
5.1. 方言的概念
在MLIR里,通过Dialect类来抽象方言。具体的每种方言都需要从这个基类派生一个类型,并实现重载自己所需的虚函数。
MLIR文档里这样描述方言( MLIR Language Reference - MLIR):
方言是这样的机制:它融入并扩展MLIR生态系统。它们允许定义新的操作,以及属性与类型。向每个方言给出唯一的名字空间作为定义的每个属性/操作/类型的前缀。例如,Affine方言定义了名字空间affine。
MLIR允许多个方言共存于一个模块中,即使是在主干之外的那些。方言由特定的遍生成与消费。对不同的方言之间以及方言内部的转换,MLIR提供了一个框架。MLIR支持的几个方言:
- Affine dialect
- GPU dialect
- LLVM dialect
- SPIR-V dialect
- Standard dialect
- Vector dialect
在教程中,还给出了Toy方言的例子。
5.2. 操作的概念
MLIR引入了一个称为操作(operation)的统一概念来描述许多不同的抽象与计算层次。在MLIR系统中,从指令到函数再到模块,一切都塑造为Op。 MLIR没有固定的Op集合,因此允许并鼓励用户自定义扩展Op。编译器遍会保守地对待未知Op,并且MLIR支持通过特征(traits)、特权操作hook和优化接口等方式向遍描述Op语义。MLIR里的操作是完全可扩展的(没有固定的操作列表),并具有应用特定的语义。例如,MLIR支持目标无关操作、仿射(affine)操作,以及目标特定机器操作。
操作的内部表示是简单的:操作由一个唯一字符串标识(如dim、tf.Conv2d、x86.repmovsb、ppc.eieio等),可以返回0或多个结果,接受0或多个操作数,有一个属性字典,有0或多个后继者,以及0或多个封闭的区域。通用打印形式包括所有这些元素,加上一个函数类型来表示结果与操作数的类型。
例子:
// An operation that produces two results.
// The results of %result can be accessed via the <name> `#` <opNo> syntax.
%result:2 = "foo_div"() : () -> (f32, i32)
// Pretty form that defines a unique name for each result.
%foo, %bar = "foo_div"() : () -> (f32, i32)
// Invoke a TensorFlow function called tf.scramble with two inputs
// and an attribute "fruit".
%2 = "tf.scramble"(%result#0, %bar) {fruit = "banana"} : (f32, i32) -> f32
5.3. 方言的管理
显然,管理方言最恰当的地方就是MLIRContext,不过MLIRContext只是上下文的接口,真正的实现是MLIRContextImpl。在MLIRContextImpl中有这样一些容器:
DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects
DialectRegistry dialectsRegistry
llvm::StringMap<AbstractOperation> registeredOperations
llvm::StringMap<PointerUnion<Dialect *, MLIRContext *>, llvm::BumpPtrAllocator &> identifiers
第一个容器保存已载入的方言对象,第二个容器记录已注册的方言,第三个容器保存已注册的抽象操作,第四个容器记录上下文里可见的标识符。
这里,方言的注册是指MLIRContext知道这个方言的标识符和构造方法,但方言对象并没有构造。方言对象的构造发生在载入时,在载入时刻,不仅构造方言对象,相关的接口也会一并准备好。抽象操作与操作相关,参考操作的管理。
5.3.1. 方言的注册
MLIR提供了一组标准的方言,它们提供了许多有用的功能。为了让程序能方便地使用标准方言,首先,每个程序在main()的入口都要注册标准方言,像这样:
int main(int argc, char **argv) {
mlir::registerAllDialects();
… // 其他初始化
registerAllDialects()的定义如下:
67 inline void registerAllDialects() {
68 static bool initOnce =
69 ([]() { registerAllDialects(getGlobalDialectRegistry()); }(), true);
70 (void)initOnce;
71 }
69行的getGlobalDialectRegistry()返回一个类型为llvm::ManagedStatic<DialectRegistry>的对象dialectRegistry,它过可视为DialectRegistry的静态对象,这个类通过一个类型为std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>的容器registry来记录标准方言。因此,同一行上的registerAllDialects()重载函数是:
41 inline void registerAllDialects(DialectRegistry ®istry) {
42 // clang-format off
43 registry.insert<acc::OpenACCDialect,
44 AffineDialect,
45 avx512::AVX512Dialect,
46 gpu::GPUDialect,
47 LLVM::LLVMAVX512Dialect,
48 LLVM::LLVMDialect,
49 linalg::LinalgDialect,
50 scf::SCFDialect,
51 omp::OpenMPDialect,
52 pdl::PDLDialect,
53 pdl_interp::PDLInterpDialect,
54 quant::QuantizationDialect,
55 spirv::SPIRVDialect,
56 StandardOpsDialect,
57 vector::VectorDialect,
58 NVVM::NVVMDialect,
59 ROCDL::ROCDLDialect,
60 SDBMDialect,
61 shape::ShapeDialect>();
62 // clang-format on
63 }
这个方法列出了MLIR目前实现的标准方言,DialectRegistry通过一系列insert()方法完成注册:
252 template <typename ConcreteDialect, typename OtherDialect,
253 typename... MoreDialects>
254 void insert() {
255 insert<ConcreteDialect>();
256 insert<OtherDialect, MoreDialects...>();
257 }
241 template <typename ConcreteDialect>
242 void insert() {
243 insert(TypeID::get<ConcreteDialect>(),
244 ConcreteDialect::getDialectNamespace(),
245 static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
246 // Just allocate the dialect, the context
247 // takes ownership of it.
248 return ctx->getOrLoadDialect<ConcreteDialect>();
249 })));
250 }
53 void DialectRegistry::insert(TypeID typeID, StringRef name,
54 DialectAllocatorFunction ctor) {
55 auto inserted = registry.insert(
56 std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
57 if (!inserted.second && inserted.first->second.first != typeID) {
58 llvm::report_fatal_error(
59 "Trying to register different dialects for the same namespace: " +
60 name);
61 }
62 }
注意,这里只是注册了这些方言的构造方法,并没有把这些方言的Dialect对象构造出来。这是因为Dialect对象的构造需要一个MLIRContext实例,因此要把Dialect对象的构造推迟到MLIRContext对象构造出来后。另外,不是每个程序都需要所有的标准方言,在需要时构造所需的方言才比较合理。所以,DialectRegistry提供了两个函数:loadByName(),loadAll()。前者构造指定名字的标准方言,后者构造所有的标准方言。
5.3.2. 方言的载入
从上面注册的构造方法我们看到,实际执行构造的函数是MLIRContext的getOrLoadDialect(),这也是一系列调用:
69 template <typename T>
70 T *getOrLoadDialect() {
71 return static_cast<T *>(
72 getOrLoadDialect(T::getDialectNamespace(), TypeID::get<T>(), [this]() {
73 std::unique_ptr<T> dialect(new T(this));
74 return dialect;
75 }));
76 }
模板参数T是具体的方言类型,它是Dialect的派生类,Dialect没有定义getDialectNamespace(),派生类必须提供自己的定义。在MLIRContext里这个名字将作为这个方言的身份识别。另外,Dialect及其派生类亦是MLIR类型系统中的组成,它们都有TypeIDMLIRContext::getOrLoadDialect()在2021版本里的定义如下:
511 Dialect *
512 MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
513 function_ref<std::unique_ptr<Dialect>()> ctor) {
514 auto &impl = getImpl();
515 // Get the correct insertion position sorted by namespace.
516 std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace];
517
518 if (!dialect) {
519 LLVM_DEBUG(llvm::dbgs()
520 << "Load new dialect in Context " << dialectNamespace << "\n");
521 #ifndef NDEBUG
522 if (impl.multiThreadedExecutionContext != 0)
523 llvm::report_fatal_error(
524 "Loading a dialect (" + dialectNamespace +
525 ") while in a multi-threaded execution context (maybe "
526 "the PassManager): this can indicate a "
527 "missing `dependentDialects` in a pass for example.");
528 #endif
529 dialect = ctor();
530 assert(dialect && "dialect ctor failed");
531
532 // Refresh all the identifiers dialect field, this catches cases where a
533 // dialect may be loaded after identifier prefixed with this dialect name
534 // were already created.
535 llvm::SmallString<32> dialectPrefix(dialectNamespace);
536 dialectPrefix.push_back('.');
537 for (auto &identifierEntry : impl.identifiers)
538 if (identifierEntry.second.is<MLIRContext *>() &&
539 identifierEntry.first().startswith(dialectPrefix))
540 identifierEntry.second = dialect.get();
541
542 // Actually register the interfaces with delayed registration.
543 impl.dialectsRegistry.registerDelayedInterfaces(dialect.get());
544 return dialect.get();
545 }
546
547 // Abort if dialect with namespace has already been registered.
548 if (dialect->getTypeID() != dialectID)
549 llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
550 "' has already been registered");
551
552 return dialect.get();
553 }
从上面可以看到,构造出来的Dialect对象存放在MLIRContextImpl的loadedDialects容器中(类型DenseMap<StringRef, std::unique_ptr<Dialect>>)。同样,MLIRContext也提供这些函数获取指定方言的Dialect对象:getOrLoadDialect(),loadDialect()等。比如,Toy例子代码有这样的代码片段来构建自己的方言对象:
int dumpMLIR() {
mlir::MLIRContext context(/*loadAllDialects=*/false);
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::toy::ToyDialect>();
identifiers 是MLIRContextImpl里有类型为llvm::StringMap<PointerUnion<Dialect *, MLIRContext *>, llvm::BumpPtrAllocator &>的容器。MLIR里的标识符是带有上下文前缀或方言前缀的(以“.”分隔),identifiers容器就是关联标识符与其所在上下文对象或方言对象的,一个操作在创建时首先假设它在一个上下文(参考Identifier::get()),上面537行的for循环检查是否已经创建了具有这个方言名的上下文对象,如果是把它替换为对应的方言对象。
5.3.2.1. 方言接口
接下来,通过DialectRegistry::registerDelayedInterfaces()向MLIRContextImpl注册方言的接口。这里“延迟接口”的意思是只在方言载入(或创建)时才注册接口。
106 void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const {
107 auto it = interfaces.find(dialect->getTypeID());
108 if (it == interfaces.end())
109 return;
110
111 // Add an interface if it is not already present.
112 for (const auto &kvp : it->getSecond().dialectInterfaces) {
113 if (dialect->getRegisteredInterface(kvp.first))
114 continue;
115 dialect->addInterface(kvp.second(dialect));
116 }
117
118 // Add attribute, operation and type interfaces.
119 for (const auto &kvp : it->getSecond().objectInterfaces)
120 kvp.second(dialect->getContext());
121 }
方言可用的接口都保存在DialectRegistry类型为DenseMap<TypeID, DelayedInterfaces>的interfaces容器中,其中DelayedInterfaces是DialectRegistry里这样的一个嵌套定义:
283 struct DelayedInterfaces {
284 /// Dialect interfaces.
285 SmallVector<std::pair<TypeID, DialectInterfaceAllocatorFunction>, 2>
286 dialectInterfaces;
287 /// Attribute/Operation/Type interfaces.
288 SmallVector<std::pair<TypeID, ObjectInterfaceAllocatorFunction>, 2>
289 objectInterfaces;
290 };
在下面我们会看到,方言除了自己的接口,还支持操作/类型/属性的外部模式接口,289行的objectInterfaces是存放这些操作接口的地方。这两个容器用到这两个定义:
30 using DialectInterfaceAllocatorFunction =
31 std::function<std::unique_ptr<DialectInterface>(Dialect *)>;
32 using ObjectInterfaceAllocatorFunction = std::function<void(MLIRContext *)>;
顾名思义,这两个std::function封装的方法用于创建接口对象,在上面115与120行,它们被调用来创建具体的接口对象。Dialect的addInterface()将生成的接口对象保存在registeredInterfaces容器中(类型DenseMap<TypeID, std::unique_ptr<DialectInterface>>)。而120行处,创建的接口对象实际上保存在对应操作/类型/属性抽象对象的容器中,下面会看到。
这篇关于MLIR笔记(6)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!