def f(x, y):
z = torch.cat([x, y])
if z.size(0) > 2:
return z.mul(2)
else:
return z.add(2)
我们将使用 TorchInductor 编译的最终 IR 将是 torch.cat([x, y]).add(2)
或 torch.cat([x, y]).mul(2)
(条件已扁平化),但为了确定我们处于哪个分支,我们需要知道 z
的大小,这是一个中间结果。因为 TorchDynamo 必须提前知道编译的跟踪是否有效(我们不支持类似某些 JIT 编译器的故障转移),所以我们必须能够将 z.size(0)
作为输入的表达式来减少,x.size(0) + y.size(0)
。这是通过为 PyTorch 中的所有操作符编写元函数来完成的,这些函数可以将大小信息传播到张量的输出,而无需实际对节点执行计算。
总体架构
符号形状工作流程
当我们在 Dynamo 中开始编译一个帧时,我们会分配一个 ShapeEnv(附加到 FakeTensorMode),它用于跟踪符号形状状态。
我们在进入时为张量分配符号大小(什么是静态或动态是一个策略决定,有一些旋钮)。
我们通过算子传播符号大小,同时维护 (1) FX IR,以便我们可以忠实地导出符号计算,以及 (2) 表示大小变量的 Sympy 表达式,以便我们可以对它们进行推理。
当我们在 Dynamo 跟踪或 Inductor 优化中根据符号大小进行条件判断时,我们会根据条件添加保护措施。这些保护措施可以从 Python 和 C++ 中推断出来。
这些保护措施可以对符号变量进行进一步简化。例如,如果你断言 s0 == 4
,我们现在可以将所有 s0
的出现替换为 4
。
当我们完成跟踪和优化后,我们会将所有这些保护措施与编译后的代码一起安装;只有当所有保护措施都评估为真时,编译后的代码才能重复使用。
C++ SymInt API: c10/core/SymInt.h
, SymFloat.h
, SymBool.h
Python SymInt API: torch/__init__.py
(查找 SymInt/SymFloat/SymBool
)
C++ 管道:c10/core/SymNodeImpl.h
, torch/csrc/utils/python_symnode.h
, torch/csrc/jit/python/init.cpp
Python 基础设施:torch/fx/experimental/symbolic_shapes.py
其他重要文件:torch/_subclasses/fake_tensor.py
, torch/_meta_registrations.py
, 分解,PrimTorch 引用
简化内部 API
了解 Python 类层次结构
SymInt/SymFloat/SymBool:这些是用户可见的类,模拟它们的 int/float/bool 对应类。如果你添加两个 SymInt,我们会给你一个新的 SymInt,它会象征性地跟踪整数加法是否已经发生。
SymNode:这是内部结构(可以通过例如 symint.node
访问),它保存实际的符号跟踪信息。SymNode 是类型擦除的;这使得用它来表示混合类型操作更加方便。请注意,从技术上讲,你无需从 SymInt 调用 Python SymNode;例如,XLA 的 C++ SymNodeImpl
将代替 SymNode。
ShapeEnv:每个编译上下文状态,用于跟踪到目前为止积累的所有自由符号和保护措施。每个 SymNode 记录其 ShapeEnv(反之则不然;只有当 SymNode 参与保护措施时,它们才会被使用)。
C++ 相当类似
c10::SymInt/SymFloat/SymBool:模拟 int/float/bool 的用户可见类。
c10::SymNode/SymNodeImpl:类似于 SymNode
C++ 中没有 ShapeEnv;为了便于调试,整个符号推理机制都在 Python 中。
当你编写可以用 make_fx
跟踪的代码时,它必须能够处理 SymInt/SymFloat/SymBool 流经它。动态形状手册 提供了一些关于如何做到这一点的指导。
DimDynamic 策略
Sympy 使用说明
DimDynamic/Constraint
无支撑 SymInt
为了解决控制流问题,我们会检查符号整数的提示(也称为实际值),以确定要执行哪个分支。但是,在某些情况下,我们可能没有提示:当大小变量从像 .nonzero()
或 .item()
这样的数据相关操作中出现时,就会出现所谓的无支撑符号整数。对这些符号整数执行控制流是非法的,因此我们必须在这些操作上进行图形中断。
如果天真地实现,这将过于严格:如果你尝试使用无支撑符号整数执行任何操作,大多数 PyTorch 程序会立即失败。以下是使其真正起作用的最重要的增强功能
在张量创建时,PyTorch 会预先计算关于张量的大量数据;例如,如果你使用 empty_strided
创建张量,我们会急切地对步幅进行排序并确定张量是否是非重叠且密集的。排序会产生大量保护措施。但是,更常见的是使用像 empty
这样的更高级别 API 直接生成张量,这保证会生成非重叠且密集的张量。我们修改了 PyTorch,避免不必要地重新计算这些属性。
即使需要进行非平凡的计算,有时也根本不会查询某个属性。使这些预先计算的属性延迟化,使我们能够避免在无支撑符号整数上进行保护,除非它实际上需要。
整数张量中的数据通常不知道是非负的。但是,我们提供了 constrain_range
API,用户可以通过它指定大小的上限和下限。
在 PT2 的未来版本(超出 PT2.1)中,我们将扩展我们的推理系统,根据用法推断出无支撑符号整数是类似大小的。例如,如果你将 .item()
调用的结果传递给像 torch.empty
这样的工厂函数,我们会自动推断结果是一个大小(因为如果不是,它就会失败)。这个假设会在运行时得到验证,如果它没有得到满足,就会引发错误。