This post walks through the source of Paddle's build_cinn_pass_test (under the fluid directory).
The code lives at paddle\fluid\framework\paddle2cinn\build_cinn_pass_test.cc. Paddle's CINN and PIR code is still changing quickly, so the version you see may differ from mine.
```cpp
inline bool CheckNodeExisted(const std::unordered_set<Node*>& nodes,
                             const std::string& op_name) {
  return std::find_if(nodes.begin(), nodes.end(), [&op_name](const Node* node) {
           return node->Name() == op_name;
         }) != nodes.end();
}
```
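CheckNodeExisted is an inline function that checks whether an unordered_set of nodes contains a node whose name is op_name. It is implemented with std::find_if, whose third argument is a lambda (anonymous function). [&op_name] is the lambda's capture list: the variables from the enclosing scope listed inside the square brackets [] are captured, either by value or by reference. The & here captures op_name by reference, so the string is not copied.
For more on lambda captures, see: https://www.cnblogs.com/pzhfei/archive/2013/01/14/lambda_expression.html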
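To see the capture list and the two algorithms in isolation, here is a minimal sketch that compiles without Paddle. The Node struct is a hypothetical stand-in for ir::Node that carries only a name, and HasNodeNamed / CountNodesNamed are illustrative names mirroring CheckNodeExisted and CountNode (shown next), not Paddle functions.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <unordered_set>

// Hypothetical stand-in for ir::Node: only the name matters here.
struct Node {
  std::string name;
  const std::string& Name() const { return name; }
};

// Mirrors CheckNodeExisted. [&op_name] captures op_name by reference,
// so the lambda reads the caller's string without copying it.
inline bool HasNodeNamed(const std::unordered_set<Node*>& nodes,
                         const std::string& op_name) {
  return std::find_if(nodes.begin(), nodes.end(),
                      [&op_name](const Node* node) {
                        assert(node != nullptr);  // the set holds valid nodes
                        return node->Name() == op_name;
                      }) != nodes.end();
}

// Mirrors CountNode: same predicate, but count_if tallies every match
// instead of stopping at the first one.
inline int CountNodesNamed(const std::unordered_set<Node*>& nodes,
                           const std::string& op_name) {
  // [op_name] would also work; it copies the string into the lambda.
  return static_cast<int>(std::count_if(
      nodes.begin(), nodes.end(),
      [&op_name](const Node* node) { return node->Name() == op_name; }));
}
```

find_if stops at the first hit, which is all an existence check needs; count_if always walks the whole set.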
Next, CountNode returns the number of nodes whose name is op_name:
```cpp
inline int CountNode(const std::unordered_set<Node*>& nodes,
                     const std::string& op_name) {
  return std::count_if(nodes.begin(), nodes.end(), [&op_name](const Node* node) {
    return node->Name() == op_name;
  });
}
```
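Next, GetNode returns a node whose name matches op_name. Why is there a * in front of std::find_if? Because find_if returns an iterator, and dereferencing it (*iterator) yields the stored Node*. Two details are worth noting: the predicate is node->Name().find(op_name) != std::string::npos, so it accepts any node whose name contains op_name, not only an exact match; and if nothing matches, find_if returns nodes.end(), which must not be dereferenced, so callers have to be sure the node exists.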
```cpp
inline Node* GetNode(const std::unordered_set<Node*>& nodes,
                     const std::string& op_name) {
  return *std::find_if(nodes.begin(), nodes.end(), [&op_name](const Node* node) {
    return node->Name().find(op_name) != std::string::npos;
  });
}
```
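The substring match means GetNode(nodes, "var1") would also accept a node named var10, and which one comes back depends on the unordered_set's iteration order; the tests sidestep this by using distinct names. The substring match is probably deliberate: the AllOpSupportCinn test looks up the control-dependency var with GetNode(nodes, Node::kControlDepVarName), and as far as I can tell such vars get a suffix appended to that base name. Below is a hedged sketch of safer lookups on a hypothetical Node stand-in; FindNodeContaining and FindNodeNamedExactly are illustrative names, not Paddle APIs.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <unordered_set>

// Hypothetical stand-in for ir::Node.
struct Node {
  std::string name;
  const std::string& Name() const { return name; }
};

// Same predicate as the test's GetNode (substring match), but returns
// nullptr instead of dereferencing end() when nothing matches.
inline Node* FindNodeContaining(const std::unordered_set<Node*>& nodes,
                                const std::string& op_name) {
  auto it = std::find_if(nodes.begin(), nodes.end(),
                         [&op_name](const Node* node) {
                           return node->Name().find(op_name) !=
                                  std::string::npos;
                         });
  return it == nodes.end() ? nullptr : *it;
}

// Exact-match variant that also asserts the name is unique in the set.
inline Node* FindNodeNamedExactly(const std::unordered_set<Node*>& nodes,
                                  const std::string& name) {
  Node* found = nullptr;
  for (Node* node : nodes) {
    if (node->Name() == name) {
      assert(found == nullptr && "expected at most one node with this name");
      found = node;
    }
  }
  return found;
}
```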
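CheckGraphIndependence defines a lambda, check_node_ok, whose parameters n1 and n2 are both Node pointers. (Note that in Paddle's pre-PIR graph, a node can be either an op or a var.) The lambda can only return true when one of n1 and n2 is an op and the other is a var, and n2 also belongs to nodes: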
```cpp
inline bool CheckGraphIndependence(const std::unordered_set<Node*>& nodes) {
  auto check_node_ok = [&nodes](Node* n1, Node* n2) -> bool {
    if (n1->IsOp() && !n2->IsVar()) {
      return false;
    }
    if (n1->IsVar() && !n2->IsOp()) {
      return false;
    }
    if (nodes.count(n2) == 0) {
      return false;
    }
    return true;
  };

  for (auto node : nodes) {
    for (auto in : node->inputs) {
      if (!check_node_ok(node, in)) {
        return false;
      }
    }
    for (auto out : node->outputs) {
      if (!check_node_ok(node, out)) {
        return false;
      }
    }
  }
  return true;
}
```
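A note here: before PIR, both ops and vars in Paddle are Nodes, so the edges are defined like this:

var1 -> op1 -> var2
op3 -> var3 -> op4

op1's input is var1 and its output is var2; in the second line, var3's input is op3 and its output is op4. Calling an op the "input" of a var looks a bit odd, but that is how the graph is defined.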
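So CheckGraphIndependence checks two things for every edge: first, that it goes op -> var or var -> op; second, that the neighboring op/var belongs to the current graph's unordered_set<Node*>.
The tests later pass the graph's nodes, g->Nodes(), to CheckGraphIndependence and fail if it does not return true:

```cpp
ASSERT_TRUE(CheckGraphIndependence(g->Nodes()));
```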
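To check both rules on a toy graph, here is a self-contained sketch. The Node struct is a hypothetical stand-in for the old-IR ir::Node, AddEdge is an illustrative helper, and IsClosedBipartite reuses the same check_node_ok logic as CheckGraphIndependence.

```cpp
#include <cassert>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for Paddle's old-IR ir::Node: a node is either an
// op or a var, and it stores both edge directions.
struct Node {
  bool is_op = false;
  std::vector<Node*> inputs;
  std::vector<Node*> outputs;
  bool IsOp() const { return is_op; }
  bool IsVar() const { return !is_op; }
};

// Records the edge from -> to on both endpoints.
inline void AddEdge(Node* from, Node* to) {
  assert(from != nullptr && to != nullptr);
  from->outputs.push_back(to);
  to->inputs.push_back(from);
}

// Same logic as CheckGraphIndependence: every edge must join an op and a
// var, and every neighbour must be inside `nodes`.
inline bool IsClosedBipartite(const std::unordered_set<Node*>& nodes) {
  auto check_node_ok = [&nodes](Node* n1, Node* n2) -> bool {
    if (n1->IsOp() && !n2->IsVar()) return false;
    if (n1->IsVar() && !n2->IsOp()) return false;
    return nodes.count(n2) != 0;
  };
  for (Node* node : nodes) {
    for (Node* in : node->inputs) {
      if (!check_node_ok(node, in)) return false;
    }
    for (Node* out : node->outputs) {
      if (!check_node_ok(node, out)) return false;
    }
  }
  return true;
}
```

A var -> var edge fails the first rule, and dropping a node from the set fails the second, because its neighbour now points outside the set.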
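GetCompilationKeys collects the operators::kCompilationKey attribute of every kCinnLaunchOp node into the vector compilation_keys. Its purpose is not obvious at this point; the later tests use each key to look up the corresponding subgraph through CinnCompiler::FindGraph.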
```cpp
// Get compilation_key values
std::vector<int64_t> GetCompilationKeys(const Graph& graph) {
  std::vector<int64_t> compilation_keys;
  for (auto& node : graph.Nodes()) {
    if (node->IsOp() && node->Name() == kCinnLaunchOp) {
      compilation_keys.emplace_back(PADDLE_GET_CONST(
          int64_t, node->Op()->GetAttr(operators::kCompilationKey)));
    }
  }
  return compilation_keys;
}
```
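A hedged sketch of the key-to-subgraph idea: launch ops carry an integer key, and a registry maps each key back to its subgraph, which is how the later tests use CinnCompiler::FindGraph. OpNode, kLaunchOpName, CollectKeys and SubgraphRegistry are all hypothetical stand-ins, the subgraph is just a string here, and the real CinnCompiler is not modeled.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical op record: a name plus the integer attribute that plays the
// role of operators::kCompilationKey.
struct OpNode {
  std::string name;
  std::int64_t compilation_key;
};

// Hypothetical launch-op name standing in for kCinnLaunchOp.
const char kLaunchOpName[] = "launch";

// Same idea as GetCompilationKeys: one key per launch op.
inline std::vector<std::int64_t> CollectKeys(const std::vector<OpNode>& ops) {
  std::vector<std::int64_t> keys;
  for (const auto& op : ops) {
    if (op.name == kLaunchOpName) keys.emplace_back(op.compilation_key);
  }
  return keys;
}

// Hypothetical registry standing in for CinnCompiler: key -> subgraph text.
class SubgraphRegistry {
 public:
  std::int64_t Add(std::string subgraph) {
    std::int64_t key = next_key_++;
    graphs_.emplace(key, std::move(subgraph));
    return key;
  }
  const std::string& Find(std::int64_t key) const {
    auto it = graphs_.find(key);
    assert(it != graphs_.end() && "unknown compilation key");
    return it->second;
  }

 private:
  std::int64_t next_key_ = 0;
  std::map<std::int64_t, std::string> graphs_;
};
```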
Next comes a graph that contains no CINN subgraph: BuildNoCinnSubgraph creates an empty Graph, then adds the ops and vars one by one.
```cpp
std::unique_ptr<Graph> BuildNoCinnSubgraph() {
  ProgramDesc prog;
  auto g = std::make_unique<Graph>(prog);
  // var1 --
  //        | --> fake1 --> var3 --> fake2 --> var4
  // var2 --
  // The *Desc objects describe the ops/vars and are used below to create
  // the op and var nodes
  OpDesc fake1_op;
  fake1_op.SetType("fake1");
  OpDesc fake2_op;
  fake2_op.SetType("fake2");

  VarDesc var1("var1");
  VarDesc var2("var2");
  var2.SetPersistable(true);
  var2.SetIsParameter(true);
  VarDesc var3("var3");
  VarDesc var4("var4");

  // Create the matching ir::Node objects with the graph's Create*Node methods
  ir::Node* fake1 = g->CreateOpNode(&fake1_op);
  ir::Node* fake2 = g->CreateOpNode(&fake2_op);

  ir::Node* v1 = g->CreateVarNode(&var1);
  ir::Node* v2 = g->CreateVarNode(&var2);
  ir::Node* v3 = g->CreateVarNode(&var3);
  ir::Node* v4 = g->CreateVarNode(&var4);

  // With the nodes created, link ops and vars together
  // fill op node
  fake1->inputs = {v1, v2};
  fake1->outputs = {v3};
  fake2->inputs = {v3};
  fake2->outputs = {v4};

  // fill variable node
  v1->outputs = {fake1};
  v2->outputs = {fake1};

  v3->inputs = {fake1};
  v3->outputs = {fake2};

  v4->inputs = {fake2};

  return g;
}
```
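Because every edge has to be written twice, once on the op and once on the var, it is easy to forget one side. A hypothetical helper like Wire below fills both directions at once, and EdgesMirrored checks that the two views agree. Node, Wire and EdgesMirrored are illustrative stand-ins, not Paddle APIs.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for ir::Node.
struct Node {
  std::string name;
  std::vector<Node*> inputs;
  std::vector<Node*> outputs;
};

// Connects an op to its input and output vars, updating both endpoints.
inline void Wire(Node* op, const std::vector<Node*>& ins,
                 const std::vector<Node*>& outs) {
  assert(op != nullptr);
  for (Node* in : ins) {
    op->inputs.push_back(in);
    in->outputs.push_back(op);
  }
  for (Node* out : outs) {
    op->outputs.push_back(out);
    out->inputs.push_back(op);
  }
}

inline bool Contains(const std::vector<Node*>& v, const Node* n) {
  return std::find(v.begin(), v.end(), n) != v.end();
}

// True if every a->b edge in a.outputs also appears in b.inputs, and vice versa.
inline bool EdgesMirrored(const std::vector<Node*>& nodes) {
  for (const Node* a : nodes) {
    for (const Node* b : a->outputs) {
      if (!Contains(b->inputs, a)) return false;
    }
    for (const Node* b : a->inputs) {
      if (!Contains(b->outputs, a)) return false;
    }
  }
  return true;
}
```

With Wire, the BuildNoCinnSubgraph wiring shrinks to two calls: Wire(fake1, {v1, v2}, {v3}) and Wire(fake2, {v3}, {v4}).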
Now the first unit test:

```cpp
TEST(BuildCinnPassTest, NoCinnSubgraph) {
  auto g = BuildNoCinnSubgraph();    // build the graph with the function above
  auto previous_nodes = g->Nodes();  // snapshot the graph's nodes

  // Create the pass (this appears to be an old-IR pass)
  auto pass =
      paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
  // g is a std::unique_ptr; g.get() returns the raw Graph pointer
  pass->Apply(g.get());

  // After search, origin graph should no change
  // i.e. after the pass searches the graph, the original graph is unchanged
  ASSERT_EQ(previous_nodes, g->Nodes());
  // The graph must still be well formed and self-contained
  ASSERT_TRUE(CheckGraphIndependence(g->Nodes()));

  // After search, there should be no cinn subgraph
  // (no kCinnLaunchOp was created, so there are no compilation keys)
  ASSERT_TRUE(GetCompilationKeys(*g).empty());
}
```
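Why is there no subgraph? The pass only fuses ops that CINN supports, and fake1/fake2 are not among them, so nothing is clustered, no kCinnLaunchOp is inserted and GetCompilationKeys returns an empty vector. The sketch below illustrates the idea with a simplified union-find clustering in which supported ops connected through a shared var end up in one cluster. The Op record, the supported-op set and FindClusters are hypothetical; the real build_cinn_pass does more (for example, a real fuser has to avoid introducing cycles), so treat this only as a model.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical name-based op record (not Paddle's OpDesc).
struct Op {
  std::string type;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// Minimal union-find over op indices.
class DisjointSet {
 public:
  explicit DisjointSet(std::size_t n) : parent_(n) {
    std::iota(parent_.begin(), parent_.end(), std::size_t{0});
  }
  std::size_t Find(std::size_t x) {
    while (parent_[x] != x) {
      parent_[x] = parent_[parent_[x]];  // path halving
      x = parent_[x];
    }
    return x;
  }
  void Union(std::size_t a, std::size_t b) {
    assert(a < parent_.size() && b < parent_.size());
    std::size_t ra = Find(a);
    std::size_t rb = Find(b);
    if (ra != rb) parent_[ra] = rb;
  }

 private:
  std::vector<std::size_t> parent_;
};

// Groups supported ops that are linked producer -> var -> consumer.
// Returns the op indices of each cluster in ascending order.
inline std::vector<std::vector<std::size_t>> FindClusters(
    const std::vector<Op>& ops, const std::set<std::string>& supported) {
  DisjointSet ds(ops.size());
  std::unordered_map<std::string, std::size_t> producer;  // var -> op index
  for (std::size_t i = 0; i < ops.size(); ++i) {
    for (const auto& out : ops[i].outputs) producer[out] = i;
  }
  for (std::size_t i = 0; i < ops.size(); ++i) {
    if (supported.count(ops[i].type) == 0) continue;
    for (const auto& in : ops[i].inputs) {
      auto it = producer.find(in);
      if (it != producer.end() && supported.count(ops[it->second].type) > 0) {
        ds.Union(i, it->second);
      }
    }
  }
  std::map<std::size_t, std::vector<std::size_t>> groups;  // root -> members
  for (std::size_t i = 0; i < ops.size(); ++i) {
    if (supported.count(ops[i].type) > 0) groups[ds.Find(i)].push_back(i);
  }
  std::vector<std::vector<std::size_t>> clusters;
  for (const auto& kv : groups) clusters.push_back(kv.second);
  return clusters;
}
```

On the NoCinnSubgraph graph this returns no clusters, on the all-supported graph one cluster of three ops, and on the multi-subgraph graph further down (where fake2 sits between mul and relu) two clusters.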
Next is BuildAllOpSupportCinnGraph, which is not very different from the previous builder:
- the graph is more complex;
- the op types change from the fake ops to elementwise_add, mul and relu.
```cpp
std::unique_ptr<Graph> BuildAllOpSupportCinnGraph() {
  ProgramDesc prog;
  auto g = std::make_unique<Graph>(prog);

  // v1 --
  //      | --> mul --> v3 --
  // v2 --                   | --> add --> v5 --> relu --> v6
  //                    v4 --

  OpDesc add_op;
  add_op.SetType("elementwise_add");
  OpDesc mul_op;
  mul_op.SetType("mul");
  OpDesc relu_op;
  relu_op.SetType("relu");

  VarDesc var1("var1");
  VarDesc var2("var2");
  var2.SetPersistable(true);
  var2.SetIsParameter(true);
  VarDesc var3("var3");
  VarDesc var4("var4");
  VarDesc var5("var5");
  VarDesc var6("var6");

  ir::Node* add = g->CreateOpNode(&add_op);
  ir::Node* mul = g->CreateOpNode(&mul_op);
  ir::Node* relu = g->CreateOpNode(&relu_op);

  // An empty var node with no VarDesc behind it; its purpose is not obvious here
  ir::Node* v0 = g->CreateEmptyNode("var0", Node::Type::kVariable);
  ir::Node* v1 = g->CreateVarNode(&var1);
  ir::Node* v2 = g->CreateVarNode(&var2);
  ir::Node* v3 = g->CreateVarNode(&var3);
  ir::Node* v4 = g->CreateVarNode(&var4);
  ir::Node* v5 = g->CreateVarNode(&var5);
  ir::Node* v6 = g->CreateVarNode(&var6);
  ir::Node* v7 = g->CreateControlDepVar();

  // fill op node
  mul->inputs = {v0, v1, v2};
  mul->outputs = {v3};
  add->inputs = {v3, v4};
  add->outputs = {v5};
  relu->inputs = {v5};
  relu->outputs = {v6, v7};

  // fill variable node
  v0->outputs = {mul};
  v1->outputs = {mul};
  v2->outputs = {mul};

  v3->inputs = {mul};
  v3->outputs = {add};

  v4->outputs = {add};

  v5->inputs = {add};
  v5->outputs = {relu};

  v6->inputs = {relu};
  v7->inputs = {relu};

  return g;
}
```
The diagram in that comment is slightly off, since it leaves out v0 and v7:

```
// v1 --
//      | --> mul --> v3 --
// v2 --                   | --> add --> v5 --> relu --> v6
//                    v4 --
```

It should be:

```
// v0 --|
// v1 --|
// v2 --| --> mul --> v3 --|
//                    v4 --| --> add --> v5 --> relu --> v6
//                                                  |--> v7
```
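The next TEST follows the same pattern as before; because the graph is different, after the pass mul, elementwise_add and relu are all replaced by a single kCinnLaunchOp.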
```cpp
TEST(BuildCinnPassTest, AllOpSupportCinn) {
  auto g = BuildAllOpSupportCinnGraph();

  auto pass =
      paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
  pass->Apply(g.get());

  // After search, the graph should as following
  // v0 --|
  // v1 --|                   |--> v6
  // v2 --| --> kCinnLaunchOp |--> v7
  // v4 --|
  const auto& nodes = g->Nodes();
  // 7 nodes: 4 inputs, 2 outputs and 1 op node
  ASSERT_EQ(nodes.size(), static_cast<size_t>(7));
  // The graph must be self-contained (no edges into other graphs)
  ASSERT_TRUE(CheckGraphIndependence(nodes));

  // A new op named kCinnLaunchOp should be added
  // (kCinnLaunchOp is a string constant; check the node set contains it)
  ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));
  auto* cinn_op = GetNode(nodes, kCinnLaunchOp);
  // Fetch the var node pointers one by one
  auto* v0 = GetNode(nodes, "var0");
  auto* v1 = GetNode(nodes, "var1");
  auto* v2 = GetNode(nodes, "var2");
  auto* v4 = GetNode(nodes, "var4");
  auto* v6 = GetNode(nodes, "var6");
  auto* v7 = GetNode(nodes, Node::kControlDepVarName);

  // cinn_op's inputs must be {v0, v1, v2, v4} and its outputs {v6, v7}
  ASSERT_EQ(
      std::unordered_set<Node*>(cinn_op->inputs.begin(), cinn_op->inputs.end()),
      std::unordered_set<Node*>({v0, v1, v2, v4}));
  ASSERT_EQ(std::unordered_set<Node*>(cinn_op->outputs.begin(),
                                      cinn_op->outputs.end()),
            std::unordered_set<Node*>({v6, v7}));
  // The var nodes must now connect to cinn_op
  ASSERT_EQ(v1->outputs, std::vector<Node*>({cinn_op}));
  ASSERT_EQ(v6->inputs, std::vector<Node*>({cinn_op}));

  // previous op (mul, add, relu) should all removed
  // mul/elementwise_add/relu were fused into cinn_op, so none can be found
  ASSERT_FALSE(CheckNodeExisted(nodes, "mul"));
  ASSERT_FALSE(CheckNodeExisted(nodes, "elementwise_add"));
  ASSERT_FALSE(CheckNodeExisted(nodes, "relu"));

  // After search, there should has just one cinn subgraph
  // feed --> v1 --
  //              | --> mul --> v3 --
  // feed --> v2 --                   | --> add --> v5 --> relu --> v6 --> fetch
  //                    feed --> v4 --
  // Get the compilation keys; each key is later used to look up its subgraph
  auto compilation_keys = GetCompilationKeys(*g);
  // There is only one kCinnLaunchOp, so there is exactly one key
  ASSERT_EQ(compilation_keys.size(), static_cast<size_t>(1));
  auto* cinn_compiler = CinnCompiler::GetInstance();
  // Look up the subgraph by key and take its node set
  const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);
  const auto& subnodes = subgraph.Nodes();
  ASSERT_EQ(subnodes.size(), static_cast<size_t>(13));
  ASSERT_TRUE(CheckGraphIndependence(subnodes));

  // The cinn op is the fusion of mul, elementwise_add and relu
  ASSERT_TRUE(CheckNodeExisted(subnodes, "mul"));
  ASSERT_TRUE(CheckNodeExisted(subnodes, "elementwise_add"));
  ASSERT_TRUE(CheckNodeExisted(subnodes, "relu"));
  ASSERT_EQ(CountNode(subnodes, "feed"), 3);   // the 3 feed ops in the diagram
  ASSERT_EQ(CountNode(subnodes, "fetch"), 1);  // and 1 fetch op

  // Inputs of kCinnLaunchOp get a feed op whether or not they are parameters
  // No-parameter input should has feed op
  auto new_v1 = GetNode(subnodes, "var1");
  ASSERT_EQ(new_v1->inputs.size(), static_cast<size_t>(1));
  ASSERT_EQ(new_v1->outputs.size(), static_cast<size_t>(1));
  ASSERT_EQ(new_v1->inputs[0]->Name(), "feed");
  ASSERT_EQ(new_v1->outputs[0]->Name(), "mul");

  // Parameter input should also have the feed op
  auto new_v2 = GetNode(subnodes, "var2");
  ASSERT_EQ(new_v2->inputs.size(), static_cast<size_t>(1));
  ASSERT_EQ(new_v2->inputs[0]->Name(), "feed");
  ASSERT_EQ(new_v2->outputs.size(), static_cast<size_t>(1));
  ASSERT_EQ(new_v2->outputs[0]->Name(), "mul");

  // Outputs of kCinnLaunchOp get a fetch op
  // output should has fetch op
  auto new_v6 = GetNode(subnodes, "var6");
  ASSERT_EQ(new_v6->inputs.size(), static_cast<size_t>(1));
  ASSERT_EQ(new_v6->outputs.size(), static_cast<size_t>(1));
  ASSERT_EQ(new_v6->inputs[0]->Name(), "relu");
  ASSERT_EQ(new_v6->outputs[0]->Name(), "fetch");
}
```
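The 7 outer nodes in AllOpSupportCinn follow from the fused cluster's boundary: vars read by the fused ops but produced elsewhere become inputs of kCinnLaunchOp, vars produced inside that are used outside (or not used at all) become outputs, and the rest stay internal to the subgraph. Here is a hedged sketch of that bookkeeping on hypothetical name-based Op records, not Paddle's types; the control-dependency var v7 is modeled as an ordinary var named ctrl, and ComputeBoundary is an illustrative name.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <vector>

// Hypothetical name-based op record (not Paddle's OpDesc).
struct Op {
  std::string type;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

struct Boundary {
  std::set<std::string> inputs;    // read by the cluster, produced outside it
  std::set<std::string> internal;  // produced and consumed only inside
  std::set<std::string> outputs;   // produced inside, used outside or unused
};

// Classifies the vars touched by the ops at indices `cluster`.
inline Boundary ComputeBoundary(const std::vector<Op>& ops,
                                const std::vector<std::size_t>& cluster) {
  std::set<std::size_t> members(cluster.begin(), cluster.end());
  std::set<std::string> produced;
  for (std::size_t i : cluster) {
    assert(i < ops.size());
    produced.insert(ops[i].outputs.begin(), ops[i].outputs.end());
  }
  Boundary b;
  for (std::size_t i : cluster) {
    for (const auto& v : ops[i].inputs) {
      if (produced.count(v) == 0) b.inputs.insert(v);
    }
  }
  for (const auto& v : produced) {
    bool used_inside = false;
    bool used_outside = false;
    for (std::size_t i = 0; i < ops.size(); ++i) {
      for (const auto& in : ops[i].inputs) {
        if (in != v) continue;
        if (members.count(i) > 0) {
          used_inside = true;
        } else {
          used_outside = true;
        }
      }
    }
    // Vars nobody reads are treated as graph outputs and must leave the cluster.
    if (used_outside || !used_inside) {
      b.outputs.insert(v);
    } else {
      b.internal.insert(v);
    }
  }
  return b;
}
```

For the all-supported graph this gives inputs {var0, var1, var2, var4}, internal {var3, var5} and outputs {var6, ctrl}, so the outer graph has 4 + 2 + 1 = 7 nodes, matching the test.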
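In the first test every op is a fake op, so the pass has nothing to fuse; in the second, every op is supported by CINN. The next case mixes the two: some fake ops and some ops that CINN supports.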
```cpp
std::unique_ptr<Graph> BuildGraphWithOneCinnSubgraph() {
  ProgramDesc prog;
  auto g = std::make_unique<Graph>(prog);

  // fake1 --> v1 --
  //              | --> mul --> v3 --> relu --> v4 --> fake2
  //           v2 --

  OpDesc fake1_op;
  fake1_op.SetType("fake1");
  OpDesc mul_op;
  mul_op.SetType("mul");
  OpDesc relu_op;
  relu_op.SetType("relu");
  OpDesc fake2_op;
  fake2_op.SetType("fake2");

  VarDesc var1("var1");
  VarDesc var2("var2");
  var2.SetPersistable(true);
  var2.SetIsParameter(true);
  VarDesc var3("var3");
  VarDesc var4("var4");

  ir::Node* fake1 = g->CreateOpNode(&fake1_op);
  ir::Node* mul = g->CreateOpNode(&mul_op);
  ir::Node* relu = g->CreateOpNode(&relu_op);
  ir::Node* fake2 = g->CreateOpNode(&fake2_op);

  ir::Node* v1 = g->CreateVarNode(&var1);
  ir::Node* v2 = g->CreateVarNode(&var2);
  ir::Node* v3 = g->CreateVarNode(&var3);
  ir::Node* v4 = g->CreateVarNode(&var4);

  // fill op node
  fake1->outputs = {v1};
  mul->inputs = {v2, v1};
  mul->outputs = {v3};
  relu->inputs = {v3};
  relu->outputs = {v4};
  fake2->inputs = {v4};

  // fill variable node
  v2->outputs = {mul};

  v1->inputs = {fake1};
  v1->outputs = {mul};

  v3->inputs = {mul};
  v3->outputs = {relu};

  v4->inputs = {relu};
  v4->outputs = {fake2};

  return g;
}
```
The function above builds this graph:

```
// fake1 --> v1 --
//              | --> mul --> v3 --> relu --> v4 --> fake2
//           v2 --
```

After the CINN pass, the graph's nodes become:

```
// fake1 --> v1 --
//              | --> kCinnLaunchOp --> v4 --> fake2
//           v2 --
```

There is a single kCinnLaunchOp, and its subgraph has 9 nodes:

```
// feed --> v1 --
//             | --> mul --> v3 --> relu --> v4 --> fetch
// feed --> v2 --
```
The previous graph produced a single CINN op; the next test covers the case with several CINN ops:
```cpp
std::unique_ptr<Graph> BuildGraphWithMultiCinnSubgraph() {
  ProgramDesc prog;
  auto g = std::make_unique<Graph>(prog);

  // fake1 --> v1 --
  //              | --> mul --> v3 --> fake2 --> v4 --> relu --> v5 --> fake3
  //           v2 --

  OpDesc fake1_op;
  fake1_op.SetType("fake1");
  OpDesc mul_op;
  mul_op.SetType("mul");
  OpDesc relu_op;
  relu_op.SetType("relu");
  OpDesc fake2_op;
  fake2_op.SetType("fake2");
  OpDesc fake3_op;
  fake3_op.SetType("fake3");

  VarDesc var1("var1");
  VarDesc var2("var2");
  var2.SetPersistable(true);
  var2.SetIsParameter(true);
  VarDesc var3("var3");
  VarDesc var4("var4");
  VarDesc var5("var5");

  ir::Node* fake1 = g->CreateOpNode(&fake1_op);
  ir::Node* mul = g->CreateOpNode(&mul_op);
  ir::Node* relu = g->CreateOpNode(&relu_op);
  ir::Node* fake2 = g->CreateOpNode(&fake2_op);
  ir::Node* fake3 = g->CreateOpNode(&fake3_op);

  ir::Node* v1 = g->CreateVarNode(&var1);
  ir::Node* v2 = g->CreateVarNode(&var2);
  ir::Node* v3 = g->CreateVarNode(&var3);
  ir::Node* v4 = g->CreateVarNode(&var4);
  ir::Node* v5 = g->CreateVarNode(&var5);

  // fill op node
  fake1->outputs = {v1};
  mul->inputs = {v2, v1};
  mul->outputs = {v3};
  fake2->inputs = {v3};
  fake2->outputs = {v4};
  relu->inputs = {v4};
  relu->outputs = {v5};
  fake3->inputs = {v5};

  // fill variable node
  v2->outputs = {mul};

  v1->inputs = {fake1};
  v1->outputs = {mul};

  v3->inputs = {mul};
  v3->outputs = {fake2};

  v4->inputs = {fake2};
  v4->outputs = {relu};

  v5->inputs = {relu};
  v5->outputs = {fake3};

  return g;
}
```
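The code above builds this graph:

```
// fake1 --> v1 --
//              | --> mul --> v3 --> fake2 --> v4 --> relu --> v5 --> fake3
//           v2 --
```

With the fake2 op as the boundary, the pass creates two CINN subgraphs, i.e. two kCinnLaunchOp nodes (CinnOp below):

```
// fake1 -> v1 -
//             | -> CinnOp -> v3 -> fake2 -> v4 -> CinnOp -> v5 -> fake3
//          v2 -
```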
Running the CINN pass takes just two lines:

```cpp
auto pass =
    paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
pass->Apply(g.get());
```
This part checks that two kCinnLaunchOp nodes were created:

```cpp
// A new op named kCinnLaunchOp should be added
ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));
ASSERT_EQ(CountNode(nodes, kCinnLaunchOp), 2);
```
In the end, the pass yields two subgraphs:

```
// subgraph1:
// feed --> v4 --> relu --> v5 --> fetch
// subgraph2:
// feed --> v1 --
//             | --> mul --> v3 --> fetch
//          v2 --
```
BuildGraphWithNoNeedBufferInput builds this graph:

```
// fake1 --> v1 --                 --> v4 --> relu_grad --> v6
//           v2 -- | --> add_grad |
//           v3 --                 --> v5 --> fake2
```
What is new in BuildGraphWithNoNeedBufferInput is that add_grad_op binds its inputs to named slots with the SetInput API:

```cpp
OpDesc add_grad_op;
add_grad_op.SetType("elementwise_add_grad");
add_grad_op.SetInput(::paddle::framework::GradVarName("Out"), {"var1"});
add_grad_op.SetInput("X", {"var2"});
add_grad_op.SetInput("Y", {"var3"});
```
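framework::GradVarName builds the name of a gradient by appending the "@GRAD" suffix, so the first slot here is "Out@GRAD". As a hedged illustration of slot-based inputs, GradName and MiniOpDesc below are hypothetical stand-ins for GradVarName and OpDesc that keep only a slot-to-argument map.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Illustrative re-implementation of the naming rule behind
// framework::GradVarName: append the "@GRAD" suffix.
inline std::string GradName(const std::string& name) { return name + "@GRAD"; }

// Hypothetical stand-in for OpDesc that keeps only slot -> argument names.
class MiniOpDesc {
 public:
  void SetType(std::string type) { type_ = std::move(type); }
  const std::string& Type() const { return type_; }
  void SetInput(const std::string& slot, std::vector<std::string> args) {
    inputs_[slot] = std::move(args);
  }
  const std::vector<std::string>& Input(const std::string& slot) const {
    auto it = inputs_.find(slot);
    assert(it != inputs_.end() && "unknown input slot");
    return it->second;
  }

 private:
  std::string type_;
  std::map<std::string, std::vector<std::string>> inputs_;
};
```

Each slot maps to a list of var names, which is why SetInput takes {"var1"} rather than a single string.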
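The test that follows checks no_need_buffer_x, which I did not understand at first. My reading, not verified against the pass source: elementwise_add_grad only needs the shapes of X and Y, not their contents, and Paddle declares X and Y as no-need-buffer inputs of that op. So the pass puts var1 (bound to Out@GRAD) into the kX input of kCinnLaunchOp and var2/var3 into kNoNeedBufferX:

```cpp
// A new op named kCinnLaunchOp should be added and
// its input arguments are set correctly
ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));
ASSERT_EQ(CountNode(nodes, kCinnLaunchOp), 1);
auto* cinn_op_node = GetNode(nodes, kCinnLaunchOp);
ASSERT_EQ(cinn_op_node->Op()->Input(operators::kX),
          std::vector<std::string>({"var1"}));
auto& no_need_buffer_x = cinn_op_node->Op()->Input(operators::kNoNeedBufferX);
ASSERT_EQ(std::unordered_set<std::string>(no_need_buffer_x.begin(),
                                          no_need_buffer_x.end()),
          std::unordered_set<std::string>({"var2", "var3"}));
```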
The subgraph records the same information: its kNoNeedBufferFeeds attribute lists the feed vars whose buffers are not needed (var2 and var3), and the kInputVars, kInternalVars and kOutputVars attributes record var1, var4 and var5/var6 respectively:

```cpp
ASSERT_TRUE(CheckNodeExisted(subnodes, "elementwise_add_grad"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "relu_grad"));
ASSERT_EQ(CountNode(subnodes, "feed"), 3);
ASSERT_EQ(CountNode(subnodes, "fetch"), 2);
const auto& no_need_buffer_feeds =
    subgraph.Get<std::unordered_set<std::string>>(kNoNeedBufferFeeds);
ASSERT_EQ(no_need_buffer_feeds.size(), 2);
ASSERT_EQ(no_need_buffer_feeds,
          std::unordered_set<std::string>({"var2", "var3"}));

// check the attributes of variable lists are saved correctly
ASSERT_TRUE(subgraph.Has(kInputVars));
EXPECT_EQ(subgraph.Get<std::vector<std::string>>(kInputVars),
          std::vector<std::string>({"var1"}));
ASSERT_TRUE(subgraph.Has(kInternalVars));
EXPECT_EQ(subgraph.Get<std::vector<std::string>>(kInternalVars),
          std::vector<std::string>({"var4"}));

ASSERT_TRUE(subgraph.Has(kOutputVars));
const auto& output_vars =
    subgraph.Get<std::vector<std::string>>(kOutputVars);
EXPECT_EQ(
    std::unordered_set<std::string>(output_vars.begin(), output_vars.end()),
    std::unordered_set<std::string>({"var5", "var6"}));
```
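A hedged model of how the launch op's inputs could be split into a kX-like list and a kNoNeedBufferX-like list: a var reached only through no-need-buffer slots inside the subgraph needs just its metadata, so it can be recorded separately. The table of no-need-buffer slots, OpInputs and SplitInputs are hypothetical; the real pass presumably gets this information from Paddle's op registration rather than from a hand-written table.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical op record: op type plus slot -> var names.
struct OpInputs {
  std::string type;
  std::map<std::string, std::vector<std::string>> slots;
};

// Hypothetical table: op type -> input slots whose data the op never reads.
using NoNeedBufferTable = std::map<std::string, std::set<std::string>>;

// Splits external inputs into vars whose data is read and vars that are
// only reached through no-need-buffer slots.
inline void SplitInputs(const std::vector<OpInputs>& ops,
                        const std::set<std::string>& external_inputs,
                        const NoNeedBufferTable& table,
                        std::set<std::string>* need_buffer,
                        std::set<std::string>* no_need_buffer) {
  assert(need_buffer != nullptr && no_need_buffer != nullptr);
  std::set<std::string> data_read;
  for (const auto& op : ops) {
    auto entry = table.find(op.type);
    for (const auto& slot : op.slots) {
      bool meta_only =
          entry != table.end() && entry->second.count(slot.first) > 0;
      if (!meta_only) data_read.insert(slot.second.begin(), slot.second.end());
    }
  }
  for (const auto& var : external_inputs) {
    if (data_read.count(var) > 0) {
      need_buffer->insert(var);
    } else {
      no_need_buffer->insert(var);
    }
  }
}
```

A var would stay in the data-needed list as soon as any op in the subgraph reads it through an ordinary slot, which is the conservative choice.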
The last test, TestSkipGcVars, attaches a set of skip-GC var names to the graph before running the pass:

```cpp
TEST(BuildCinnPassTest, TestSkipGcVars) {
  auto g = BuildGraphWithOneCinnSubgraph();

  // kSkipGcVarNames lists vars the garbage collector must not free; here
  // var1 and var3 are marked (the set is attached without transferring ownership)
  std::unordered_set<std::string> all_skip_gc_vars = {"var1", "var3"};
  g->SetNotOwned(kSkipGcVarNames, &all_skip_gc_vars);

  auto pass =
      paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
  pass->Apply(g.get());

  // After search, the graph should as following
  // fake1 --> v1 --
  //              | --> kCinnLaunchOp --> v4 --> fake2
  //           v2 --
  const auto& nodes = g->Nodes();
  // 7 rather than the 6 drawn above: v3 presumably also becomes an output of
  // kCinnLaunchOp because it is a skip-GC var (see the note below)
  ASSERT_EQ(nodes.size(), static_cast<size_t>(7));
  ASSERT_TRUE(CheckGraphIndependence(nodes));

  // A new op named kCinnLaunchOp should be added
  ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));

  // After search, there should has just one cinn subgraph
  // Note v3 has fetched because of v3 in kSkipGcVarNames
  // And v1 is a feed var so v1 no need fetched though it in kSkipGcVarNames
  // feed --> v1 --
  //             | --> mul --> v3 --> relu --> v4 --> fetch
  // feed --> v2 --              |
  //                             --> fetch
  auto compilation_keys = GetCompilationKeys(*g);
  ASSERT_EQ(compilation_keys.size(), static_cast<size_t>(1));
  auto* cinn_compiler = CinnCompiler::GetInstance();
  const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);

  const auto& subnodes = subgraph.Nodes();
  ASSERT_EQ(subnodes.size(), static_cast<size_t>(10));
  ASSERT_TRUE(CheckGraphIndependence(subnodes));

  ASSERT_EQ(CountNode(subnodes, "feed"), 2);
  // var3 and var4 should has fetch op
  ASSERT_EQ(CountNode(subnodes, "fetch"), 2);
}
```
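The node counts in this test can be reproduced with a simplified model: the subgraph holds the fused ops, one feed per input var, one fetch per output var, and all boundary and internal vars; the outer graph keeps the unfused ops, the boundary vars and one launch op. Under that assumption (which ignores special nodes such as the empty var0 and the control-dependency var of the AllOpSupportCinn test), promoting the skip-GC var var3 from internal to output turns 9 subgraph / 6 outer nodes into 10 / 7, while var1 is already an input and needs no fetch. Boundary, PromoteSkipGcVars and the two count functions are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>

// Boundary vars of one fused cluster (simplified model).
struct Boundary {
  std::set<std::string> inputs;    // fed into the subgraph
  std::set<std::string> internal;  // live only inside the subgraph
  std::set<std::string> outputs;   // fetched out of the subgraph
};

// Internal vars that the GC must keep alive become outputs, so the launch
// op exposes them; inputs already live outside and are left untouched.
inline void PromoteSkipGcVars(const std::set<std::string>& skip_gc,
                              Boundary* b) {
  assert(b != nullptr);
  for (const auto& v : skip_gc) {
    if (b->internal.erase(v) > 0) b->outputs.insert(v);
  }
}

// Fused ops + one feed per input + one fetch per output + all vars.
inline std::size_t SubgraphNodeCount(std::size_t cluster_ops,
                                     const Boundary& b) {
  std::size_t vars = b.inputs.size() + b.internal.size() + b.outputs.size();
  return cluster_ops + b.inputs.size() + b.outputs.size() + vars;
}

// Unfused ops + boundary vars + the single launch op.
inline std::size_t OuterNodeCount(std::size_t outside_ops, const Boundary& b) {
  return outside_ops + b.inputs.size() + b.outputs.size() + 1;
}
```

For BuildGraphWithOneCinnSubgraph the cluster is {mul, relu} and the outside ops are {fake1, fake2}, so both count functions are called with 2.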
I haven't fully worked out the last two TESTs yet, so I'm leaving my questions about them open.