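SENet [1] won ImageNet 2017 (ILSVRC 2017), after which the ImageNet challenge was discontinued. Like ResNet before it, SENet introduces a simple architectural trick, the SE block, which improves accuracy while adding only a small number of parameters and very little extra computation.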
SENet stands for Squeeze-and-Excitation Networks (Chinese: 挤压和激励网络). Its ILSVRC 2017 entry reduced the top-5 error rate to 2.251%, and the official models and code are available on GitHub [2].
SE block
The SE block can be viewed as a computational unit built on a transformation $F_{tr}$ that maps an input $\mathbf{X} \in \mathbb{R}^{H' \times W' \times C'}$ to feature maps $\mathbf{U} \in \mathbb{R}^{H \times W \times C}$. In the notation below, $F_{tr}$ is a convolution operator and $\mathbf{V} = [\mathbf{v}_1, \mathbf{v}_2, \dots, \mathbf{v}_C]$ denotes the learned set of filter kernels, where $\mathbf{v}_c$ holds the parameters of the $c$-th filter. The output can then be written as $\mathbf{U} = [\mathbf{u}_1, \mathbf{u}_2, \dots, \mathbf{u}_C]$, where:
$$\mathbf{u}_c = \mathbf{v}_c * \mathbf{X} = \sum_{s=1}^{C'} \mathbf{v}_c^s * \mathbf{x}^s \tag{1}$$
In Equation (1), $*$ denotes convolution, $\mathbf{v}_c = [\mathbf{v}_c^1, \mathbf{v}_c^2, \dots, \mathbf{v}_c^{C'}]$, $\mathbf{X} = [\mathbf{x}^1, \mathbf{x}^2, \dots, \mathbf{x}^{C'}]$, and $\mathbf{u}_c \in \mathbb{R}^{H \times W}$. Here $\mathbf{v}_c^s$ is a 2D spatial kernel: the single channel of $\mathbf{v}_c$ that acts on the corresponding channel $\mathbf{x}^s$ of $\mathbf{X}$.
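- Bias terms are omitted to simplify the notation.
- As the convolution formula shows, the results from all input channels are summed, so channel dependencies are implicitly embedded in $\mathbf{v}_c$ and entangled with the spatial correlations captured by the filters. The goal is therefore to separate these two kinds of information so that the model can explicitly learn the relationships between channels.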
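A minimal NumPy sketch of Equation (1), assuming a single output filter, stride 1 and no padding (so the output is smaller than the input). `corr2d_valid` is a hypothetical helper; like deep-learning frameworks, it computes cross-correlation rather than flipped-kernel convolution. The sketch shows that summing the per-channel 2D results gives exactly the same map as applying one joint 3D filter, which is why channel and spatial information end up mixed in $\mathbf{u}_c$:

```python
import numpy as np

def corr2d_valid(x, k):
    """Hypothetical helper: 2D 'valid' cross-correlation of one channel x (H, W)
    with one kernel k (kh, kw), stride 1, no padding."""
    kh, kw = k.shape
    H, W = x.shape
    out = np.zeros((H - kh + 1, W - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(x[i:i + kh, j:j + kw] * k)
    return out

rng = np.random.default_rng(0)
C_in, H, W = 3, 6, 6
X = rng.standard_normal((C_in, H, W))     # input X with C' = 3 channels (channels-first)
v_c = rng.standard_normal((C_in, 3, 3))   # one filter v_c = [v_c^1, ..., v_c^{C'}]

# Eq. (1): u_c = sum over s of v_c^s * x^s (one 2D kernel per input channel, then summed)
u_c = sum(corr2d_valid(X[s], v_c[s]) for s in range(C_in))

# The same map computed as a single 3D correlation over all channels at once
u_c_joint = np.zeros_like(u_c)
for i in range(u_c.shape[0]):
    for j in range(u_c.shape[1]):
        u_c_joint[i, j] = np.sum(X[:, i:i + 3, j:j + 3] * v_c)

print(u_c.shape)  # (4, 4): the channel axis has been summed away
```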
Squeeze: Global Information Embedding
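To address this, SENet squeezes global spatial information into a channel descriptor, using global average pooling to generate channel-wise statistics. Formally, a statistic $\mathbf{z} \in \mathbb{R}^C$ is generated by shrinking $\mathbf{U}$ through its spatial dimensions $H \times W$, so that the $c$-th element of $\mathbf{z}$ is computed as: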
$$z_c = F_{sq}(\mathbf{u}_c) = \frac{1}{H \times W} \sum_{i=1}^{H} \sum_{j=1}^{W} u_c(i, j)$$
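The squeeze step is just a per-channel mean. A small NumPy sketch, assuming a channels-first `(C, H, W)` array as in PyTorch rather than the paper's $H \times W \times C$ layout; `squeeze` is a hypothetical helper name:

```python
import numpy as np

def squeeze(U):
    """Hypothetical helper for F_sq: global average pooling of U (C, H, W) -> z (C,)."""
    C, H, W = U.shape
    return U.sum(axis=(1, 2)) / (H * W)   # z_c = (1 / (H * W)) * sum_i sum_j u_c(i, j)

U = np.arange(2 * 2 * 3, dtype=float).reshape(2, 2, 3)  # C=2, H=2, W=3
z = squeeze(U)
print(z)  # [2.5 8.5]
```

In PyTorch the same operation is `F.adaptive_avg_pool2d(x, 1)`, which the implementation below uses.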
Excitation: Adaptive Recalibration
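To make use of the information aggregated by the squeeze operation, a second operation follows whose aim is to fully capture channel-wise dependencies. To do so, the function must meet two criteria:
- It must be flexible: in particular, it must be able to learn nonlinear interactions between channels.
- It must learn a non-mutually-exclusive relationship, because we want to allow multiple channels to be emphasized at the same time (rather than enforcing a one-hot activation).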
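To illustrate the second criterion, the NumPy sketch below (with made-up logits for four channels) compares a sigmoid gate with a softmax. The sigmoid gates each channel independently, so several channels can be close to 1 at once; the softmax forces the channels to share a total weight of 1, so emphasizing one channel suppresses the others:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def softmax(x):
    e = np.exp(x - x.max())   # subtract the max for numerical stability
    return e / e.sum()

logits = np.array([4.0, 4.0, 4.0, -4.0])  # made-up values: three channels to emphasize

gate_sigmoid = sigmoid(logits)   # independent gates, not mutually exclusive
gate_softmax = softmax(logits)   # competing gates that always sum to 1

print(gate_sigmoid.round(3))  # [0.982 0.982 0.982 0.018]
print(gate_softmax.round(3))  # each emphasized channel gets at most about 1/3
```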
To meet these criteria, a simple gating mechanism with a sigmoid activation is used:
$$\mathbf{s} = F_{ex}(\mathbf{z}, \mathbf{W}) = \sigma(g(\mathbf{z}, \mathbf{W})) = \sigma(\mathbf{W}_2 \delta(\mathbf{W}_1 \mathbf{z}))$$
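where $\delta$ denotes the ReLU function, $\mathbf{W}_1 \in \mathbb{R}^{\frac{C}{r} \times C}$ and $\mathbf{W}_2 \in \mathbb{R}^{C \times \frac{C}{r}}$. To limit model complexity and aid generalization, the gate is parameterized as a bottleneck of two fully connected layers: the first FC layer reduces the dimension by a ratio $r$ (a hyperparameter) and is followed by a ReLU, and the second FC layer restores the original dimension $C$. Finally, the learned per-channel activations (sigmoid outputs in the range $(0, 1)$) rescale the original features in $\mathbf{U}$: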
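A NumPy sketch of the excitation step, assuming $C = 64$ and $r = 16$ and using random (untrained) weights as placeholders; bias terms are omitted as in the text, and `excitation` is a hypothetical helper name. It also checks that the two FC layers hold $2C^2/r$ weights:

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def excitation(z, W1, W2):
    """Hypothetical helper for F_ex: s = sigmoid(W2 @ relu(W1 @ z)), no biases."""
    return sigmoid(W2 @ relu(W1 @ z))

C, r = 64, 16
rng = np.random.default_rng(0)
W1 = rng.standard_normal((C // r, C)) * 0.1   # (C/r) x C: reduce C -> C/r (random placeholder)
W2 = rng.standard_normal((C, C // r)) * 0.1   # C x (C/r): restore C/r -> C (random placeholder)
z = rng.standard_normal(C)                    # stand-in for the squeezed channel descriptor

s = excitation(z, W1, W2)
print(s.shape)            # (64,)
print(W1.size + W2.size)  # 512 = 2 * C**2 / r
```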
$$\tilde{\mathbf{x}}_c = F_{scale}(\mathbf{u}_c, s_c) = s_c \cdot \mathbf{u}_c$$
where $\widetilde{\mathbf{X}} = [\tilde{\mathbf{x}}_1, \tilde{\mathbf{x}}_2, \dots, \tilde{\mathbf{x}}_C]$ and $F_{scale}(\mathbf{u}_c, s_c)$ denotes channel-wise multiplication between the scalar $s_c$ and the feature map $\mathbf{u}_c \in \mathbb{R}^{H \times W}$.
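Putting squeeze, excitation and scale together, here is a minimal NumPy forward pass of one SE block under stated assumptions: a hypothetical `se_forward` helper, channels-first layout, random untrained weights and no biases. The scale step is a broadcast multiply of each channel map by its scalar $s_c$:

```python
import numpy as np

def se_forward(U, W1, W2):
    """Hypothetical minimal SE block on U (C, H, W): squeeze -> excitation -> scale."""
    z = U.mean(axis=(1, 2))                                     # squeeze: z in R^C
    s = 1.0 / (1.0 + np.exp(-(W2 @ np.maximum(W1 @ z, 0.0))))  # excitation: s in (0, 1)^C
    X_tilde = s[:, None, None] * U                              # scale: x~_c = s_c * u_c
    return X_tilde, s

C, H, W, r = 32, 5, 5, 4
rng = np.random.default_rng(1)
U = rng.standard_normal((C, H, W))
W1 = rng.standard_normal((C // r, C))   # random placeholder weights
W2 = rng.standard_normal((C, C // r))

X_tilde, s = se_forward(U, W1, W2)

# The broadcast above equals scaling each channel map separately
X_loop = np.stack([s[c] * U[c] for c in range(C)])
print(X_tilde.shape)  # (32, 5, 5): same shape as U
```

Because the output has the same shape as $\mathbf{U}$, the block can be dropped into an existing network without changing the layers around it.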
Applying the SE block
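The total number of additional parameters introduced by the SE blocks (ignoring bias terms) is:
$$\frac{2}{r} \sum_{s=1}^{S} N_s \cdot C_s^2$$
where $r$ is the reduction ratio, $S$ is the number of stages (a stage is the collection of blocks operating on feature maps of a common spatial dimension), $C_s$ is the dimension of the output channels, and $N_s$ is the number of repeated blocks in stage $s$. Each SE block in stage $s$ has $C_s \cdot \frac{C_s}{r}$ weights in $\mathbf{W}_1$ and the same number in $\mathbf{W}_2$, hence the term $\frac{2}{r} C_s^2$ per block.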
With $r = 16$, SE-ResNet-50 has only about 10% more parameters than ResNet-50, while its computational cost increases by less than 1%.
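As a worked check of the formula, plug in the standard ResNet-50 stage configuration ($N_s = 3, 4, 6, 3$ blocks with $C_s = 256, 512, 1024, 2048$ output channels) and $r = 16$:

```python
# Standard ResNet-50 stages: blocks per stage N_s and output channels C_s
N = [3, 4, 6, 3]
C = [256, 512, 1024, 2048]
r = 16

# (2 / r) * sum_s N_s * C_s^2, in integer arithmetic; FC biases ignored as in the formula
extra = 2 * sum(n * c * c for n, c in zip(N, C)) // r
last_stage = 2 * N[-1] * C[-1] ** 2 // r

print(extra)       # 2514944, about 2.5 million
print(last_stage)  # 1572864
```

This gives about 2.5 million extra weights, roughly 10% of the approximately 25 million parameters of ResNet-50. Because the cost grows with $C_s^2$, the final stage alone accounts for well over half of them.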
import torch
from torch import nn, optim
from torch.autograd import Variable
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms
import time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
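# Hyperparameters
# EPOCHS matches the 40-epoch log below and BATCH_SIZE matches the 128-image batch
# evaluated at the end; LEARNING_RATE and WEIGHT_DECAY are assumed values, because
# the original settings are not given.
EPOCHS = 40
BATCH_SIZE = 128
LEARNING_RATE = 0.1   # assumption
WEIGHT_DECAY = 1e-4   # assumption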
# CIFAR-10: random crop + horizontal flip for training, per-channel mean/std normalization for both
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
train_dataset = CIFAR10(root='./data', train=True, download=True,
                        transform=transforms.Compose([
                            transforms.RandomCrop(32, padding=4),
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(),
                            normalize,
                        ]))
test_dataset = CIFAR10(root='./data', train=False, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           normalize,
                       ]))
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
# Squeeze and Excitation Block Module
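class SEBlock(nn.Module):
    # Unlike the paper's SE block, which only predicts a per-channel scale s, this variant
    # also predicts a per-channel bias b, so the second 1x1 conv outputs 2 * channels.
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        # On a 1x1 pooled map, these 1x1 convolutions act as the two FC layers W1 and W2
        self.fc = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(channels // reduction, channels * 2, 1, bias=False),
        )

    def forward(self, x):
        w = F.adaptive_avg_pool2d(x, 1)         # Squeeze: (N, C, H, W) -> (N, C, 1, 1)
        w = self.fc(w)                          # Excitation bottleneck on the pooled descriptor
        w, b = w.split(w.size(1) // 2, dim=1)   # scale and bias, each (N, C, 1, 1)
        w = torch.sigmoid(w)
        return x * w + b                        # Scale and add bias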
# Residual Block with SEBlock
class ResBlock(nn.Module):
    def __init__(self, channels):
        super(ResBlock, self).__init__()
        self.conv_lower = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )
        self.conv_upper = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        )
        self.se_block = SEBlock(channels)

    def forward(self, x):
        path = self.conv_lower(x)
        path = self.conv_upper(path)
        path = self.se_block(path)   # recalibrate the residual branch before the identity addition
        path = x + path
        return F.relu(path)
# Network Module
class Network(nn.Module):
    def __init__(self, in_channel, filters, blocks, num_classes):
        super(Network, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channel, filters, 3, padding=1, bias=False),
            nn.BatchNorm2d(filters),
            nn.ReLU(),
        )
        # blocks - 1 residual blocks (9 for Network(3, 128, 10, 10))
        self.res_blocks = nn.Sequential(*[ResBlock(filters) for _ in range(blocks - 1)])
        self.out_conv = nn.Sequential(
            nn.Conv2d(filters, 128, 1, padding=0, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.conv_block(x)
        x = self.res_blocks(x)
        x = self.out_conv(x)
        x = F.adaptive_avg_pool2d(x, 1)   # global average pooling -> (N, 128, 1, 1)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)    # log-probabilities
net = Network(3, 128, 10, 10).to(device)
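# The network already returns log_softmax outputs, so NLLLoss is the matching loss
# (CrossEntropyLoss would apply log_softmax a second time, which is redundant).
ACE = nn.NLLLoss()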
opt = optim.SGD(net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY, momentum=.9, nesterov=True)
import os

os.makedirs('./data/checkpoint', exist_ok=True)

for epoch in range(1, EPOCHS + 1):
    print('[Epoch %d]' % epoch)

    # Training: BatchNorm uses batch statistics and updates its running averages
    net.train()
    train_loss = 0
    train_correct, train_total = 0, 0
    start_point = time.time()
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        opt.zero_grad()
        preds = net(inputs)
        loss = ACE(preds, labels)
        loss.backward()
        opt.step()

        train_loss += loss.item()
        train_correct += (preds.argmax(dim=1) == labels).sum().item()
        train_total += len(preds)
    print('train-acc : %.4f%% train-loss : %.5f' % (100 * train_correct / train_total, train_loss / len(train_loader)))
    print('elapsed time: %ds' % (time.time() - start_point))

    # Evaluation: BatchNorm switches to its running statistics, no gradients needed
    net.eval()
    test_loss = 0
    test_correct, test_total = 0, 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            preds = net(inputs)
            test_loss += ACE(preds, labels).item()
            test_correct += (preds.argmax(dim=1) == labels).sum().item()
            test_total += len(preds)
    print('test-acc : %.4f%% test-loss : %.5f' % (100 * test_correct / test_total, test_loss / len(test_loader)))

    torch.save(net.state_dict(), './data/checkpoint/checkpoint-%04d.bin' % epoch)
[Epoch 1]
train-acc : 62.9240% train-loss : 1.02725
elapsed time: 167s
test-acc : 59.9800% test-loss : 1.13711
[Epoch 2]
train-acc : 69.3160% train-loss : 0.85710
elapsed time: 170s
test-acc : 67.6300% test-loss : 0.92139
[Epoch 3]
train-acc : 73.9000% train-loss : 0.74356
elapsed time: 171s
test-acc : 70.7700% test-loss : 0.84002
[Epoch 4]
train-acc : 77.2340% train-loss : 0.65098
elapsed time: 171s
test-acc : 74.3400% test-loss : 0.75001
[Epoch 5]
train-acc : 79.7560% train-loss : 0.58424
elapsed time: 171s
test-acc : 74.8000% test-loss : 0.71813
[Epoch 6]
train-acc : 81.8820% train-loss : 0.52713
elapsed time: 171s
test-acc : 77.7400% test-loss : 0.66449
[Epoch 7]
train-acc : 83.0260% train-loss : 0.49098
elapsed time: 171s
test-acc : 79.3000% test-loss : 0.60599
[Epoch 8]
train-acc : 84.2880% train-loss : 0.45633
elapsed time: 171s
test-acc : 78.0500% test-loss : 0.64819
[Epoch 9]
train-acc : 85.2660% train-loss : 0.43147
elapsed time: 171s
test-acc : 80.7400% test-loss : 0.57734
[Epoch 10]
train-acc : 86.2080% train-loss : 0.39924
elapsed time: 171s
test-acc : 81.9000% test-loss : 0.53836
[Epoch 11]
train-acc : 86.9320% train-loss : 0.38040
elapsed time: 171s
test-acc : 82.7100% test-loss : 0.51160
[Epoch 12]
train-acc : 87.4740% train-loss : 0.36286
elapsed time: 170s
test-acc : 81.8500% test-loss : 0.54868
[Epoch 13]
train-acc : 88.1580% train-loss : 0.34673
elapsed time: 171s
test-acc : 83.0700% test-loss : 0.49779
[Epoch 14]
train-acc : 88.9260% train-loss : 0.31996
elapsed time: 171s
test-acc : 83.8900% test-loss : 0.48193
[Epoch 15]
train-acc : 89.1380% train-loss : 0.31583
elapsed time: 171s
test-acc : 83.9900% test-loss : 0.49245
[Epoch 16]
train-acc : 89.5460% train-loss : 0.30087
elapsed time: 170s
test-acc : 84.0100% test-loss : 0.49648
[Epoch 17]
train-acc : 90.0420% train-loss : 0.29067
elapsed time: 171s
test-acc : 85.2700% test-loss : 0.44473
[Epoch 18]
train-acc : 90.3720% train-loss : 0.28137
elapsed time: 171s
test-acc : 83.8900% test-loss : 0.49883
[Epoch 19]
train-acc : 90.6020% train-loss : 0.26961
elapsed time: 171s
test-acc : 84.4700% test-loss : 0.47203
[Epoch 20]
train-acc : 91.1460% train-loss : 0.25927
elapsed time: 170s
test-acc : 84.4200% test-loss : 0.49412
[Epoch 21]
train-acc : 91.1540% train-loss : 0.25661
elapsed time: 170s
test-acc : 85.3500% test-loss : 0.43626
[Epoch 22]
train-acc : 91.3620% train-loss : 0.24741
elapsed time: 171s
test-acc : 86.2200% test-loss : 0.41310
[Epoch 23]
train-acc : 91.9760% train-loss : 0.23271
elapsed time: 171s
test-acc : 86.5600% test-loss : 0.40795
[Epoch 24]
train-acc : 92.0000% train-loss : 0.23080
elapsed time: 171s
test-acc : 84.8000% test-loss : 0.46834
[Epoch 25]
train-acc : 92.1460% train-loss : 0.22744
elapsed time: 171s
test-acc : 85.4300% test-loss : 0.44402
[Epoch 26]
train-acc : 92.2120% train-loss : 0.22320
elapsed time: 170s
test-acc : 86.3300% test-loss : 0.41405
[Epoch 27]
train-acc : 92.3740% train-loss : 0.21625
elapsed time: 170s
test-acc : 87.3800% test-loss : 0.38440
[Epoch 28]
train-acc : 92.6960% train-loss : 0.21098
elapsed time: 171s
test-acc : 84.9300% test-loss : 0.46326
[Epoch 29]
train-acc : 92.8700% train-loss : 0.20541
elapsed time: 171s
test-acc : 86.5900% test-loss : 0.41840
[Epoch 30]
train-acc : 93.0700% train-loss : 0.20067
elapsed time: 170s
test-acc : 86.8400% test-loss : 0.42302
[Epoch 31]
train-acc : 93.2300% train-loss : 0.19319
elapsed time: 171s
test-acc : 87.1700% test-loss : 0.39542
[Epoch 32]
train-acc : 93.2280% train-loss : 0.19576
elapsed time: 171s
test-acc : 86.6500% test-loss : 0.43697
[Epoch 33]
train-acc : 93.5900% train-loss : 0.18686
elapsed time: 170s
test-acc : 86.8300% test-loss : 0.40863
[Epoch 34]
train-acc : 93.5820% train-loss : 0.18315
elapsed time: 170s
test-acc : 86.8200% test-loss : 0.42321
[Epoch 35]
train-acc : 93.6140% train-loss : 0.18232
elapsed time: 170s
test-acc : 86.1700% test-loss : 0.43491
[Epoch 36]
train-acc : 93.9620% train-loss : 0.17560
elapsed time: 170s
test-acc : 86.9100% test-loss : 0.41068
[Epoch 37]
train-acc : 93.9920% train-loss : 0.17193
elapsed time: 170s
test-acc : 87.0600% test-loss : 0.41822
[Epoch 38]
train-acc : 93.8620% train-loss : 0.17253
elapsed time: 170s
test-acc : 88.0500% test-loss : 0.38560
[Epoch 39]
train-acc : 94.2040% train-loss : 0.16850
elapsed time: 170s
test-acc : 86.7000% test-loss : 0.42949
[Epoch 40]
train-acc : 94.2940% train-loss : 0.16422
elapsed time: 170s
test-acc : 87.2100% test-loss : 0.39914
# Load the last checkpoint (epoch 40) and switch to evaluation mode for inference
net.load_state_dict(torch.load('./data/checkpoint/checkpoint-0040.bin', map_location=torch.device('cpu')))
net.eval()
Model structure as printed by PyTorch, summarized (all `BatchNorm2d` layers use `eps=1e-05, momentum=0.1, affine=True, track_running_stats=True`, and all convolutions use `stride=(1, 1)`):
- `conv_block`: Conv2d(3, 128, kernel_size=(3, 3), padding=(1, 1), bias=False) → BatchNorm2d(128) → ReLU()
- `res_blocks`: 9 identical `ResBlock`s, (0) to (8), each containing:
  - `conv_lower`: Conv2d(128, 128, kernel_size=(3, 3), padding=(1, 1), bias=False) → BatchNorm2d(128) → ReLU()
  - `conv_upper`: Conv2d(128, 128, kernel_size=(3, 3), padding=(1, 1), bias=False) → BatchNorm2d(128)
  - `se_block`: SEBlock with `fc` = Conv2d(128, 8, kernel_size=(1, 1), bias=False) → ReLU() → Conv2d(8, 256, kernel_size=(1, 1), bias=False)
- `out_conv`: Conv2d(128, 128, kernel_size=(1, 1), bias=False) → BatchNorm2d(128) → ReLU()
- `fc`: Linear(in_features=128, out_features=10, bias=True)
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, f1_score
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline
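# Evaluate a single test batch (128 images) and report a few metrics
net.eval()
dev = next(net.parameters()).device   # run inference on whatever device the model is on
with torch.no_grad():
    for images, labels in test_loader:
        pred = torch.argmax(net(images.to(dev)), dim=1).cpu()
        # sklearn convention: y_true first, y_pred second (rows = true, columns = predicted)
        print('confusion_matrix: \n', confusion_matrix(labels, pred))
        print('accuracy_score:', accuracy_score(labels, pred))
        # For single-label classification, micro-averaged precision and F1 both equal accuracy
        print('precision_score:', precision_score(labels, pred, average='micro'))
        print('f1-score:', f1_score(labels, pred, average='micro'))
        break
confusion_matrix:
 [[11  0  0  0  0  0  0  1  1  0]
 [ 0  9  0  0  0  0  0  0  0  1]
 [ 0  0 10  1  0  0  0  0  0  0]
 [ 0  0  0 11  1  1  0  2  0  0]
 [ 0  0  1  0  9  0  0  0  0  0]
 [ 0  0  2  0  0  7  0  0  0  0]
 [ 0  0  0  0  0  1 18  0  0  0]
 [ 0  0  0  1  0  0  0 12  0  0]
 [ 1  0  0  0  0  0  0  0 16  0]
 [ 0  0  0  0  0  0  0  0  0 11]]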
accuracy_score: 0.890625
precision_score: 0.890625
f1-score: 0.890625
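The three scores are identical because, for single-label classification, micro-averaged precision, recall and F1 all reduce to accuracy (here 114 correct predictions out of 128 images, i.e. 0.890625). A small NumPy sketch with made-up labels illustrates this and the row/column convention; `confusion` is a hypothetical helper that follows the same layout as `sklearn.metrics.confusion_matrix(y_true, y_pred)`, and swapping its arguments transposes the matrix:

```python
import numpy as np

def confusion(y_true, y_pred, num_classes):
    """Hypothetical helper: rows = true class, columns = predicted class."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    return cm

y_true = np.array([0, 0, 1, 2, 2, 2])   # made-up labels
y_pred = np.array([0, 1, 1, 2, 2, 0])   # made-up predictions
cm = confusion(y_true, y_pred, 3)

accuracy = np.trace(cm) / cm.sum()

# Micro-averaging pools TP/FP/FN over all classes. Each wrong prediction is one FP
# (for the predicted class) and one FN (for the true class), so pooled FP == pooled FN.
tp = np.trace(cm)
fp = (cm.sum(axis=0) - np.diag(cm)).sum()   # column sums minus diagonal, pooled
fn = (cm.sum(axis=1) - np.diag(cm)).sum()   # row sums minus diagonal, pooled
micro_precision = tp / (tp + fp)
micro_recall = tp / (tp + fn)
micro_f1 = 2 * micro_precision * micro_recall / (micro_precision + micro_recall)

print(cm)
print(accuracy, micro_precision, micro_f1)
```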
Predicted labels `pred` for this batch:
tensor([3, 8, 8, 8, 6, 6, 1, 6, 3, 1, 0, 9, 5, 7, 9, 8, 5, 7, 0, 6, 7, 0, 4, 9,
        2, 2, 4, 0, 9, 6, 6, 5, 4, 5, 9, 3, 4, 9, 9, 5, 4, 6, 5, 6, 0, 9, 4, 9,
        7, 6, 9, 8, 7, 3, 8, 8, 7, 3, 2, 5, 7, 5, 6, 3, 6, 2, 1, 2, 7, 7, 2, 6,
        8, 8, 0, 2, 9, 3, 7, 8, 8, 1, 1, 7, 2, 2, 2, 7, 8, 9, 0, 3, 8, 6, 4, 6,
        6, 0, 0, 7, 4, 5, 6, 3, 1, 1, 3, 6, 8, 7, 4, 0, 6, 2, 1, 3, 0, 4, 2, 7,
        8, 3, 1, 2, 8, 0, 8, 3])
conf_mat = confusion_matrix(labels, pred)
array([[11,  0,  0,  0,  0,  0,  0,  1,  1,  0],
       [ 0,  9,  0,  0,  0,  0,  0,  0,  0,  1],
       [ 0,  0, 10,  1,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0, 11,  1,  1,  0,  2,  0,  0],
       [ 0,  0,  1,  0,  9,  0,  0,  0,  0,  0],
       [ 0,  0,  2,  0,  0,  7,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  1, 18,  0,  0,  0],
       [ 0,  0,  0,  1,  0,  0,  0, 12,  0,  0],
       [ 1,  0,  0,  0,  0,  0,  0,  0, 16,  0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0, 11]], dtype=int64)
df = pd.DataFrame(conf_mat, index=test_dataset.classes, columns=test_dataset.classes)
| true ↓ / predicted → | airplane | automobile | bird | cat | deer | dog | frog | horse | ship | truck |
|---|---|---|---|---|---|---|---|---|---|---|
| airplane | 11 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 1 | 0 |
| automobile | 0 | 9 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 |
| bird | 0 | 0 | 10 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| cat | 0 | 0 | 0 | 11 | 1 | 1 | 0 | 2 | 0 | 0 |
| deer | 0 | 0 | 1 | 0 | 9 | 0 | 0 | 0 | 0 | 0 |
| dog | 0 | 0 | 2 | 0 | 0 | 7 | 0 | 0 | 0 | 0 |
| frog | 0 | 0 | 0 | 0 | 0 | 1 | 18 | 0 | 0 | 0 |
| horse | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 12 | 0 | 0 |
| ship | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 16 | 0 |
| truck | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 11 |
# Plot the confusion matrix as a heatmap
plt.figure(figsize=(12, 12))
sns.heatmap(df, annot=True, cbar=None, cmap="Blues")
plt.title("Confusion Matrix")
plt.ylabel("True Class")
plt.xlabel("Predicted Class")
- [1] Hu J, Shen L, Sun G. Squeeze-and-Excitation Networks. arXiv preprint arXiv:1709.01507, 2017.
- [2] hujie-frank/SENet (GitHub repository)
- [3] Zhihu article: 最后一届ImageNet冠军模型:SENet (The last ImageNet champion model: SENet)
- [4] JYPark09/SENet-Pytorch (GitHub repository)
Personal blog: madao33 blog