def rgb2centered_yuv(input: Tensor, swing: str = "studio") -> Tensor:
"""Convert color space.
Convert images from RGB format to centered YUV444 BT.601
Args:
input: input image in RGB format, ranging 0~255
swing: "studio" for YUV studio swing (Y: -112~107,
U, V: -112~112)
"full" for YUV full swing (Y, U, V: -128~127).
default is "studio"
Returns:
output: centered YUV image
函数输入为 RGB 图像,输出为 centered YUV 图像。其中,centered YUV 是指减去了 128 的偏置的 YUV 图像,这是 BPU 图像金字塔输出的标准图像格式。对于 full swing 而言,其范围应为 -128~127。您可以通过 swing
参数控制 full 和 studio 的取向。为了和 BPU 数据流格式对齐,请您将 swing
设为 “full”。
4.2.5.3.4. 在推理时对 YUV 输入进行实时转换
在任何情况下,我们都推荐您使用上述介绍的方案,即在训练时就将 RGB 图像转成 YUV 格式,这样可以避免在推理时引入额外的性能开销和精度损失。但如果您已经使用了 RGB 图像训练了模型,我们也提供了补救措施,通过在推理的时候在模型输入处插入颜色空间转换算子,将输入的 YUV 图像实时转换为 RGB 格式,从而支持 RGB 模型的上板部署,避免您重新训练模型给您带来时间成本和资源上的损失。由于该算子随模型运行在 BPU 上,底层采用定点运算实现,因而不可避免地会引入一定的精度损失,因此仅作为补救方案,请您尽可能按照我们所推荐的方式对数据进行处理。
4.2.5.3.4.1. 算子定义
您可以在推理模型的开头(QuantStub 的后面)插入 horizon.functional.centered_yuv2rgb
或 horizon.functional.centered_yuv2bgr
算子实现该功能。以 centered_yuv2rgb
为例,其定义为:
def centered_yuv2rgb(
input: QTensor,
swing: str = "studio",
mean: Union[List[float], Tensor] = (128.0,),
std: Union[List[float], Tensor] = (128.0,),
q_scale: Union[float, Tensor] = 1.0 / 128.0,
) -> QTensor:
swing
为 YUV 的格式,可选项为 “full” 和 “studio”。为了和 BPU 的 YUV 数据格式对齐,请您将 swing
设为 “full”。
mean
, std
均为您在训练时 RGB 图像所使用的归一化均值、标准差,支持 list 和 torch.Tensor 两种输入类型,支持单通道或三通道的归一化参数。如您的归一化均值为 [128, 0, -128] 时,您可以传入一个 [128., 0., -128.] 的 list 或 torch.tensor([128., 0., -128.])。
q_scale
为您在量化训练阶段所用的 QuantStub 的 scale 数值。支持 float 和 torch.Tensor 两种数据类型。
该算子完成了以下操作:
根据给定的 swing
所对应的转换公式将输入图像转换成 RGB 格式
使用给定的 mean
和 std
对 RGB 图像进行归一化
使用给定的 q_scale
对 RGB 图像进行量化
由于该算子已经包括了对 RGB 图像的量化操作,因此在插入这个算子后您需要手动地将模型 QuantStub 的 scale 参数更改为 1。
插入该算子后的部署模型如下图所示:
该算子为部署专用算子,请勿在训练阶段使用该算子。
4.2.5.3.4.2. 使用方法
在您使用 RGB 图像完成量化训练后,您需要:
获取量化训练时模型 QuantStub 所使用的 scale 值,以及 RGB 图像所使用的归一化参数
在模型的 QuantStub 后面插入 centered_yuv2rgb
算子,算子需要传入步骤 1 中所获取的参数
调用我们工具的接口将 qat 模型转换为 quantized 模型
将 QuantStub 的 scale
参数修改成 1
import torch
from horizon_plugin_pytorch.quantization import QuantStub, prepare_qat_fx, convert_fx
from horizon_plugin_pytorch.functional import centered_yuv2rgb
class Net(torch.nn.Module):
def __init__(self):
super().__init__()
self.is_quantized = False
self.quant = QuantStub()
self.convnet = ConvNet()
def forward(self, input):
x = self.quant(input)
if self.is_quantized:
x = centered_yuv2rgb(
mean=torch.tensor([128]),
std=torch.tensor([128]),
# must be set epual to qat_net.quant.activation_post_process.scale
q_scale=torch.tensor([1 / 128]),
x = self.convnet(x)
return x
def set_qconfig(self):
data = torch.rand(1, 3, 28, 28)
net = Net()
net.set_qconfig()
qat_net = prepare_qat_fx(net)
qat_net(data)
qat_net.is_quantized = True
quantized_net = convert_fx(qat_net)
quantized_net.quant.scale.fill_(1.0)
quantized_net(data)