import torch
import torch.nn as nn

class ANN(nn.Module):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),
            nn.BatchNorm2d(32, eps=1e-3),
            nn.ReLU(),
            nn.AvgPool2d(2, 2),

            nn.Conv2d(32, 32, 3, 1),
            nn.BatchNorm2d(32, eps=1e-3),
            nn.ReLU(),
            nn.AvgPool2d(2, 2),

            nn.Conv2d(32, 32, 3, 1),
            nn.BatchNorm2d(32, eps=1e-3),
            nn.ReLU(),
            nn.AvgPool2d(2, 2),

            nn.Flatten(),
            nn.Linear(32, 10),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.network(x)
        return x
Note: if you need to flatten a tensor, define an nn.Flatten module in the network and use that module in the forward function instead of calling view().
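As a hypothetical illustration of this note (FlattenExample is not part of the tutorial code, and the layer sizes assume a 28x28 single-channel input), the pattern looks like this:

class FlattenExample(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, 3, 1)
        self.flatten = nn.Flatten()        # declared as a module in the network ...
        self.fc = nn.Linear(8 * 26 * 26, 10)

    def forward(self, x):
        x = self.conv(x)
        x = self.flatten(x)                # ... and used here, instead of x.view(x.size(0), -1)
        return self.fc(x)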
Define our hyperparameters:
device = input('输入运行的设备,例如“cpu”或“cuda:0”\n input device, e.g., "cpu" or "cuda:0": ')
dataset_dir = input('输入保存MNIST数据集的位置,例如“./”\n input root directory for saving MNIST dataset, e.g., "./": ')
batch_size = int(input('输入batch_size,例如“64”\n input batch_size, e.g., "64": '))
learning_rate = float(input('输入学习率,例如“1e-3”\n input learning rate, e.g., "1e-3": '))
T = int(input('输入仿真时长,例如“100”\n input simulating steps, e.g., "100": '))
train_epoch = int(input('输入训练轮数,即遍历训练集的次数,例如“10”\n input training epochs, e.g., "10": '))
model_name = input('输入模型名字,例如“mnist”\n input model name, for log_dir generating , e.g., "mnist": ')
All temporary files generated afterwards will be stored in the automatically generated log folder (log_dir).
Initialize the data loader, network, optimizer, and loss function:
# Initialize the network
ann = ANN().to(device)
# Define the loss function
loss_function = nn.CrossEntropyLoss()
# Use the Adam optimizer
optimizer = torch.optim.Adam(ann.parameters(), lr=learning_rate, weight_decay=5e-4)
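The train_data_loader and test_data_loader used below are built in the full example but not shown in this excerpt; the following is a minimal sketch, assuming torchvision is installed and using a plain ToTensor transform:

import torchvision

train_data_loader = torch.utils.data.DataLoader(
    dataset=torchvision.datasets.MNIST(root=dataset_dir, train=True,
                                       transform=torchvision.transforms.ToTensor(),
                                       download=True),
    batch_size=batch_size, shuffle=True, drop_last=False)
test_data_loader = torch.utils.data.DataLoader(
    dataset=torchvision.datasets.MNIST(root=dataset_dir, train=False,
                                       transform=torchvision.transforms.ToTensor(),
                                       download=True),
    batch_size=batch_size, shuffle=False, drop_last=False)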
Train the ANN and test it periodically. The pre-written training routine in utils can also be used for training:
best_acc = 0.0
for epoch in range(train_epoch):
    # Train the network using the pre-written training routine in utils;
    # it is written in the same way as classic ANN training
    utils.train_ann(net=ann,
                    device=device,
                    data_loader=train_data_loader,
                    optimizer=optimizer,
                    loss_function=loss_function,
                    epoch=epoch
                    )
    # Validate the network using the pre-written validation routine in utils
    acc = utils.val_ann(net=ann,
                        device=device,
                        data_loader=test_data_loader,
                        epoch=epoch
                        )
    if best_acc <= acc:
        best_acc = acc
        utils.save_model(ann, log_dir, model_name + '.pkl')
The complete code is located in ann2snn.examples.cnn_mnist.py, where we also use TensorBoard to save the training logs. You can run it directly from the Python command line:
>>> import spikingjelly.clock_driven.ann2snn.examples.cnn_mnist as cnn_mnist
>>> cnn_mnist.main()
输入运行的设备,例如“cpu”或“cuda:0”
input device, e.g., "cpu" or "cuda:0": cuda:15
输入保存MNIST数据集的位置,例如“./”
input root directory for saving MNIST dataset, e.g., "./": ./mnist
输入batch_size,例如“64”
input batch_size, e.g., "64": 128
输入学习率,例如“1e-3”
input learning rate, e.g., "1e-3": 1e-3
输入仿真时长,例如“100”
input simulating steps, e.g., "100": 100
输入训练轮数,即遍历训练集的次数,例如“10”
input training epochs, e.g., "10": 10
输入模型名字,用于自动生成日志文档,例如“cnn_mnist”
input model name, for log_dir generating , e.g., "cnn_mnist"
Epoch 0 [1/937] ANN Training Loss:2.252 Accuracy:0.078
Epoch 0 [101/937] ANN Training Loss:1.423 Accuracy:0.669
Epoch 0 [201/937] ANN Training Loss:1.117 Accuracy:0.773
Epoch 0 [301/937] ANN Training Loss:0.953 Accuracy:0.795
Epoch 0 [401/937] ANN Training Loss:0.865 Accuracy:0.788
Epoch 0 [501/937] ANN Training Loss:0.807 Accuracy:0.792
Epoch 0 [601/937] ANN Training Loss:0.764 Accuracy:0.795
Epoch 0 [701/937] ANN Training Loss:0.726 Accuracy:0.835
Epoch 0 [801/937] ANN Training Loss:0.681 Accuracy:0.880
Epoch 0 [901/937] ANN Training Loss:0.641 Accuracy:0.889
100%|██████████| 100/100 [00:00<00:00, 116.12it/s]
Epoch 0 [100/100] ANN Validating Loss:0.327 Accuracy:0.881
Save model to: cnn_mnist-XXXXX\cnn_mnist.pkl
......
In this example, the model is trained for 10 epochs, and the test-set accuracy finally reaches 98.8%.
Take a small portion of data from the training set for the model normalization step. Here we take 192 images.
# Load the data used to normalize the model
percentage = 0.004  # use 0.4% of the training data
norm_data_list = []
for idx, (imgs, targets) in enumerate(train_data_loader):
    norm_data_list.append(imgs)
    if idx == int(len(train_data_loader) * percentage) - 1:
        break
norm_data = torch.cat(norm_data_list)
print('use %d imgs to parse' % (norm_data.size(0)))
Call the class parser in ann2snn, using the ONNX kernel.
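The conversion call itself does not appear in this excerpt; the sketch below assumes the parser interface suggested by the surrounding text (a parser object constructed with name, log_dir and kernel='onnx', whose parse method takes the ANN and the normalization data; parser_device is assumed to be a device string chosen like device above):

onnxparser = parser(name=model_name,
                    log_dir=log_dir + '/parser',
                    kernel='onnx')
snn = onnxparser.parse(ann, norm_data.to(parser_device))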
We can save the converted SNN model, and define a plt.figure for plotting:
torch.save(snn, os.path.join(log_dir,'snn-'+model_name+'.pkl'))
fig = plt.figure('simulator')
Now we define the simulator for the SNN. Since our task is classification, we choose the class ``classify_simulator``:
sim = classify_simulator(snn,
                         log_dir=log_dir + '/simulator',
                         device=simulator_device,
                         canvas=fig
                         )
sim.simulate(test_data_loader,
             T=T,
             online_drawer=True,
             ann_acc=ann_acc,
             fig_name=model_name,
             step_max=True
             )
Since the simulation takes a relatively long time, a tqdm progress bar is used to estimate the remaining simulation time. When the simulation ends, the simulator prints a summary:
simulator is working on the normal mode, device: cuda:0
100%|██████████| 100/100 [00:46<00:00, 2.15it/s]
--------------------simulator summary--------------------
time elapsed: 46.55072790000008 (sec)
---------------------------------------------------------
From the final output we can see that the simulation took 46.6 s. The accuracy of the converted SNN can be seen in plot.pdf in the simulator folder; the best converted accuracy is 98.51%, i.e. the conversion causes a 0.37% drop in performance. The conversion loss can be reduced by increasing the inference time.
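For instance, the simulation shown above could simply be rerun with a larger number of time steps (400 below is only an illustrative value, not a setting from the tutorial) to trade simulation time for accuracy:

sim.simulate(test_data_loader,
             T=400,               # longer simulation, illustrative value
             online_drawer=True,
             ann_acc=ann_acc,
             fig_name=model_name,
             step_max=True
             )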