# 打开组合算子
export FLAGS_prim_enable_dynamic=true && export FLAGS_prim_all=true
# 打开 CINN 编译器相关 FLAG
export FLAGS_use_cinn=true
export FLAGS_cinn_new_group_scheduler=true
export FLAGS_group_schedule_tiling_first=true
export FLAGS_cinn_bucket_compile=true
# 打开 PIR 模式
export FLAGS_enable_pir_api=true
# 是否打印 Program IR 信息
export FLAGS_print_ir=false
python run_net.py
上述代码示例中我们创建了一个简单的rms_norm
计算子图,使用飞桨的动转静流程将子图转为静态图并调用编译器 CINN 进行优化和执行。经过性能对比测试,在 A100 GPU 环境中上述子图使用 CINN 可以取得 3 倍左右的性能提升(该性能数据仅供学习参考,在实际应用模型中能够取得的性能提升效果一般会低于该数据)。
注:由于飞桨的编译器仍然处在快速迭代开发阶段,我们设置了较多 FLAGS 进行分支的选择和调试,因此现阶段在使用 CINN 时需要对如下 FLAGS(FLAGS_prim_enable_dynamic
、 FLAGS_cinn_new_group_scheduler
、 FLAGS_group_schedule_tiling_first
、 FLAGS_cinn_bucket_compile
、 FLAGS_enable_pir_api
) 进行手动设置,待后续相关功能完备后这些 FLAGS 会默认开启,无需再手动设置。
四、设计架构
飞桨框架编译器(CINN, Compiler Infrastructure for Neural Networks)整体架构如上图所示,大体可以分为三个模块,分别是编译器前端、编译器后端和执行器部分。
1. 编译器前端
一般来说编译器前端需要将不同框架和格式的深度学习模型转换为编译器的内部 IR 并进行图级别的优化,CINN 作为飞桨框架原生编译器,可以直接使用飞桨框架提供的模型加载和中间表示(Paddle IR,简称 PIR)组件,因此 CINN 前端的主要功能是基于 PIR 进行图层级别的优化,并对子图进行划分为后端高性能 Kernel 代码生成提供支持。CINN 前端关键的流程可分为三部分:
a. 组合算子拆分
飞桨框架中将算子划分为基础算子(也称作原子算子,语义上该算子无法更进一步拆分成其他算子。基础算子语义上可以通过重组等价实现组合算子的逻辑)和非基础算子两类大,由于非基础算子数量较多,并且在编译器中较难识别和处理,因此我们使用组合算子拆分的方式将非基础算子拆分为等价的基础算子组合,原始计算图经过组合算子拆分后可以大幅提升性能的可优化空间。
b. 图优化 Pass
在计算图层级进行 PIR 的 Pass 优化,常见的图优化 Pass 包括:常量折叠、死代码消除(DCE)、公共子表达式消除(CSE)、冗余算子消除、算子计算合并等。
c. 算子融合
算子融合是编译器前端非常重要的一个功能,主要是将多个算子打包到一个子图中(对应为一个 FusionOp),交给编译器后端生成一个高效的硬件相关计算 Kernel。 算子融合的本质是通过 IO 优化加速访存密集算子,如果我们将两个连续 Kernel 合并为一个 Kernel 调用,我们会减少中间变量的读写开销,因此在访存密集型的 2 个 Op 上,融合可以获取更高的性能。举个例子,如下图:
编译器后端主要负责将前端处理后的 IR 转换为目标硬件可执行的代码或硬件描述。主要功能包括基于硬件特性的 IR 优化、高效内存管理和代码生成等。
2.1. CINN AST IR
AST IR 打印示例:
ScheduleBlock(root)
serial for (i, 0, 32)
serial for (j_0, 0, 64)
serial for (j_1, 0, 128)
ScheduleBlock(A)
vi, vj = axis.bind(i, j_0 * 64 + j_1) // tensor 下标与循环变量的仿射变换
A[vi, vj] = X[vi, vj] * 2
CINN AST IR 中包含了以下信息,但集合和映射并不显示使用某种数据结构进行存储。
集合:语句实例 & 内存单元
映射:
访存关系:语句实例 <---> 内存单元
依赖关系:语句实例 <---> 语句实例
执行顺序:语句实例 -----> 语句实例
执行顺序 = 语句实例的先后关系
语句实例集合范围 = 循环边界 + 循环步长 ------ 循环构成一个带约束的整数空间,即迭代空间,迭代空间决定了语句实例,语句实例充满了迭代空间。
2.2. 基于 AST IR 的 Schedule
Schedule 为定义在 CINN AST IR 上的优化策略,常见的 Schedule 包括:LoopAlignment, Tile, Inline, Vectorize, Unroll 等。
以一个组合算子为例模拟可能的 AST 变换过程:
[S1, S2, 1024] ==E=> [S1, S2, 1024] ==R=> [S1, S2] ==E=> [S1, S2] ==B=> [S1, S2, 1024] ==E=> [S1, S2, 1024]
(1) LowerToAst 得到的结果
// Elemenwise-1
serial for (i, 0, S1)
serial for (j, 0, S2)
serial for (k, 0, 1024)
ScheduleBlock(A)
vi, vj, vk = axis.bind(i, j, k)
A[vi, vj, vk] = X[vi, vj, vk] * 2
// Elemenwise-2
serial for (i, 0, S1)
serial for (j, 0, S2)
serial for (k, 0, 1024)
ScheduleBlock(B)
vi, vj, vk = axis.bind(i, j, k)
B[vi, vj, vk] = A[vi, vj, vk] + 1
// Reduce-1
serial for (i, 0, S1)
serial for (j, 0, S2)
ScheduleBlock(C__reduce_init)
vi, vj = axis.bind(i, j)
C_init[vi, vj] = 0
serial for (i, 0, S1)
serial for (j, 0, S2)
serial for (k, 0, 1024) // Reduce
ScheduleBlock(C)
vi, vj, vk = axis.bind(i, j, k)
C[vi, vj] = C[vi, vj] + B[vi, vj, vk]
// Elemenwise-3
serial for (i, 0, S1)
serial for (j, 0, S2)
ScheduleBlock(D)
vi, vj = axis.bind(i, j)
D[vi, vj] = C[vi, vj] * 2
// Broadcast-1
serial for (i, 0, S1)
serial for (j, 0, S2)
serial for (k, 0, 1024) // Broadcast
ScheduleBlock(E)
vi, vj, vk = axis.bind(i, j, k)
E[vi, vj, vk] = D[vi, vj]
// Elemenwise-4
serial for (i, 0, S1)
serial for (j, 0, S2)
serial for (k, 0, 1024)
ScheduleBlock(F)
vi, vj, vk = axis.bind(i, j, k)
F[vi, vj, vk] = E[vi, vj, vk] + 1
(2) 迭代空间对齐
// 所有 ScheduleBlock 的 loop nest 都变为以下 2 种格式中的一种
// 1
serial for (sp, 0, S1 * S2) // pure_spatial_iter
serial for (rb, 0, 1024) // impure_spatial_iter
ScheduleBlock(XXX)
vsp1, vsp2, vrb = axis.bind(sp / S2, sp % S2, rb)
XXX = XXXXXX
// 2
serial for (sp, 0, S1 * S2) // pure_spatial_iter
ScheduleBlock(XXX)
vsp1, vsp2 = axis.bind(sp / S2, sp % S2)
XXX = XXXXXX
(3) Tile: 对所有 ScheduleBlock 的 loop nest 做相同的 Tile
// pure_spatial 轴 Tile 为:-1 * 16 * 64 Tile size 可为参数传入
serial for (sp1, 0, S1 * S2 / 1024)
serial for (sp2, 0, 16)
serial for (sp3, 0, 64) // S1 * S2 / 16 / 64, predicate: sp1 * 1024 + sp2 * 16 + sp3 < S1 * S2
XXXXXX
// impure_spatial_iter 轴 Tile 为 32
serial for (sp1, 0, S1 * S2 / 1024)
serial for (sp2, 0, 16)
serial for (sp3, 0, 64)
serial for (rb1, 0, 32)
serial for (rb2, 0, 32)
ScheduleBlock(XXX)
predicate = sp1 * 1024 + sp2 * 16 + sp3 < S1 * S2
vsp1 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) / S2)
vsp2 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) % S2)
vrb = axis.bind(rb1 * 32 + rb2)
XXX = XXXXX
(4) ComputeInline
// 例如 ScheduleBlock(A) inline 到 ScheduleBlock(B)
serial for (sp1, 0, S1 * S2 / 1024)
serial for (sp2, 0, 16)
serial for (sp3, 0, 64)
serial for (rb1, 0, 32)
serial for (rb2, 0, 32)
ScheduleBlock(A)
predicate = sp1 * 1024 + sp2 * 16 + sp3 < S1 * S2
vsp1 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) / S2)
vsp2 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) % S2)
vrb = axis.bind(rb1 * 32 + rb2)
B[vsp1, vsp2, vrb] = (X[vsp1, vsp2, vrb] * 2) + 1
(5) Reduce 优化: two step reduce & 绑定部分 reduce 轴到 cuda
// 为了简洁,此处省略 reduce_init Block 和 predicate
serial for (sp1, 0, S1 * S2 / 1024)
serial for (sp2, 0, 16)
serial for (sp3, 0, 64)
CudaBind[ThreadIdx.x] for (rb1, 0, 32)
serial for (rb2, 0, 32)
ScheduleBlock(C_rf)
vsp1 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) / S2)
vsp2 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) % S2)
vrb1 = axis.bind(rb1)
vrb2 = axis.bind(rb2)
C_rf[vsp1, vsp2, vrb1] = C_rf[vsp1, vsp2, vrb1] + B[vsp1, vsp2, vrb1 * 32 + vrb2]
ScheduleBlock(C)
vsp1 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) / S2)
vsp2 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) % S2)
vrb1 = axis.bind(rb1)
C[vsp1, vsp2] = C[vsp1, vsp2] + C_rf[vsp1, vsp2, vrb1]
(6) 循环融合: ComputeAt && SimpleComputeAt,融合外层循环乘积相同的循环,并且保证不破坏图级别依赖(规则负责)和元素级别依赖(原语负责)
serial for (sp1, 0, S1 * S2 / 1024)
serial for (sp2, 0, 16)
serial for (sp3, 0, 64)
ScheduleBlock(D)
vsp1 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) / S2)
vsp2 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) % S2)
D[vsp1, vsp2] = C[vsp1, vsp2] * 2
serial for (rb1, 0, 32)
serial for (rb2, 0, 32)
ScheduleBlock(E)
vsp1 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) / S2)
vsp2 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) % S2)
vrb = axis.bind(rb1 * 32 + rb2)
E[vsp1, vsp2, vrb] = D[vsp1, vsp2]
ScheduleBlock(F)
vsp1 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) / S2)
vsp2 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) % S2)
vrb = axis.bind(rb1 * 32 + rb2)
F[vsp1, vsp2, vrb] = E[vsp1, vsp2, vrb] + 1
(7) Bind Cuda 轴:在第二步中,所有 ScheduleBlock 对应的循环要 bind 到同一 Cuda 轴
serial for (sp1, 0, S1 * S2 / 1024)
CudaBind[BlockIdx.x] for (sp2, 0, 16)
CudaBind[ThreadIdx.y] for (sp3, 0, 64)
CudaBind[ThreadIdx.x] for (rb1, 0, 32)
serial for (rb2, 0, 32)
ScheduleBlock(XXX)
2.3. Kernel 代码生成与编译
Codegen 在 CINN IR AST 上做前序遍历,打印出对应硬件的指令,并通过硬件相对应的编译器(如 llvm、nvcc 等)进行编译得到可运行的函数指针,该指针会被封装到 `JitKernelOp`` 中用于后续执行器的解析执行。
a. 以函数定义为例子,cuda kernel func 和 x86 kernel func 的不同的是,cuda kernel func 会在函数名前增加 __global__
针对 x86 硬件,转义 ir::_LoweredFunc_
的代码如下:
void CodeGenC::Visit(const ir::_LoweredFunc_ *op) {
PrintFunctionDeclaration(op); // 前序遍历继续转义函数名、函数参数等
str_ += "\n";
在 NV GPU 上的转义代码如下:
void CodeGenCUDA_Dev::Visit(const ir::_LoweredFunc_ *op) {
str_ += "__global__\n"; // 和 x86 的不同,增加 __global__
PrintFunctionDeclaration(op); // 前序遍历继续转义函数名、函数参数等
str_ += "\n";
b. 在动态形状场景下,还会 codegen 出 infer shape function, infer shape function 的 CINN IR 会在 Bucket Lowering 中得到,转义过程复用的 x86 硬件的 codegen。infer shape kernel 如下:
// infer shape 函数名字的组成:kernel_name + "infer_shape"
// 函数参数:
// kernel_args: 指针数组,和 kernel func args 一致
// kernel_args_num: kernel_args 的长度
// tensor_shape_args: 指针数组,存储输出 tensor 的 shape
function fn_exp_0_subtract_0_infer_shape (kernel_args, kernel_args_num, tensor_shape_args)
int64 S0 = cinn_get_value_in_cuda_kernel_args(kernel_args, 2)
// CINN IR 暂时不支持数据索引的语法,暂时用函数调用实现,下面 2 条语句等价于
// tensor_shape_args[0] = {S0, 256ll};
// 即第 0 个出 tensor 的 shape 为{S0, 256ll};
infer_shape_set_value(0, 0, S0, tensor_shape_args)
infer_shape_set_value(0, 1, 256ll, tensor_shape_args)