Pytorch: 利用预训练的残差网络ResNet50进行图像特征提取,并可视化特征图热图

本文主要是介绍Pytorch: 利用预训练的残差网络ResNet50进行图像特征提取,并可视化特征图热图,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1. 残差网络ResNet的结构

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

2.图像特征提取和可视化分析

import cv2
import time
import os
import matplotlib.pyplot as plt
import torch
from torch import nn
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as npimgname = 'bottle_broken_large.png' 
savepath='vis_resnet50/features_bottle'
if not os.path.isdir(savepath):os.makedirs(savepath)def draw_features(width,height,x,savename):tic = time.time()fig = plt.figure(figsize=(16, 16))fig.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95, wspace=0.05, hspace=0.05)for i in range(width*height):plt.subplot(height, width, i + 1)plt.axis('off')img = x[0, i, :, :]pmin = np.min(img)pmax = np.max(img)img = ((img - pmin) / (pmax - pmin + 0.000001))*255  #float在[0,1]之间,转换成0-255img=img.astype(np.uint8)  #转成unit8img=cv2.applyColorMap(img, cv2.COLORMAP_JET) #生成heat mapimg = img[:, :, ::-1]#注意cv2(BGR)和matplotlib(RGB)通道是相反的plt.imshow(img)print("{}/{}".format(i,width*height))fig.savefig(savename, dpi=100)fig.clf()plt.close()print("time:{}".format(time.time()-tic))class ft_net(nn.Module):def __init__(self):super(ft_net, self).__init__()model_ft = models.resnet50(pretrained=True)self.model = model_ftdef forward(self, x):if True: # draw features or notx = self.model.conv1(x)draw_features(8, 8, x.cpu().numpy(),"{}/f1_conv1.png".format(savepath))x = self.model.bn1(x)draw_features(8, 8, x.cpu().numpy(),"{}/f2_bn1.png".format(savepath))x = self.model.relu(x)draw_features(8, 8, x.cpu().numpy(), "{}/f3_relu.png".format(savepath))x = self.model.maxpool(x)draw_features(8, 8, x.cpu().numpy(), "{}/f4_maxpool.png".format(savepath))x = self.model.layer1(x)draw_features(16, 16, x.cpu().numpy(), "{}/f5_layer1.png".format(savepath))x = self.model.layer2(x)draw_features(16, 32, x.cpu().numpy(), "{}/f6_layer2.png".format(savepath))x = self.model.layer3(x)draw_features(32, 32, x.cpu().numpy(), "{}/f7_layer3.png".format(savepath))x = self.model.layer4(x)draw_features(32, 32, x.cpu().numpy()[:, 0:1024, :, :], "{}/f8_layer4_1.png".format(savepath))draw_features(32, 32, x.cpu().numpy()[:, 1024:2048, :, :], "{}/f8_layer4_2.png".format(savepath))x = self.model.avgpool(x)plt.plot(np.linspace(1, 2048, 2048), x.cpu().numpy()[0, :, 0, 0])plt.savefig("{}/f9_avgpool.png".format(savepath))plt.clf()plt.close()x = x.view(x.size(0), -1)x = self.model.fc(x)plt.plot(np.linspace(1, 1000, 1000), x.cpu().numpy()[0, :])plt.savefig("{}/f10_fc.png".format(savepath))plt.clf()plt.close()else :x = self.model.conv1(x)x = self.model.bn1(x)x = self.model.relu(x)x = self.model.maxpool(x)x = self.model.layer1(x)x = self.model.layer2(x)x = self.model.layer3(x)x = self.model.layer4(x)x = self.model.avgpool(x)x = x.view(x.size(0), -1)x = self.model.fc(x)return xmodel = ft_net().cuda()# pretrained_dict = resnet50.state_dict()
# pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# model_dict.update(pretrained_dict)
# net.load_state_dict(model_dict)
model.eval()
img = cv2.imread(imgname)
img = cv2.resize(img, (288, 288))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
img = transform(img).cuda()
img = img.unsqueeze(0)with torch.no_grad():start = time.time()out = model(img)print("total time:{}".format(time.time()-start))result = out.cpu().numpy()# ind=np.argmax(out.cpu().numpy())ind = np.argsort(result, axis=1)for i in range(5):print("predict:top {} = cls {} : score {}".format(i+1,ind[0,1000-i-1],result[0,1000-i-1]))print("done")

可视化结果:

在这里插入图片描述

在这里插入图片描述

这篇关于Pytorch: 利用预训练的残差网络ResNet50进行图像特征提取,并可视化特征图热图的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/899260

相关文章

在PyCharm中安装PyTorch、torchvision和OpenCV详解

《在PyCharm中安装PyTorch、torchvision和OpenCV详解》:本文主要介绍在PyCharm中安装PyTorch、torchvision和OpenCV方式,具有很好的参考价值,... 目录PyCharm安装PyTorch、torchvision和OpenCV安装python安装PyTor

OpenCV图像形态学的实现

《OpenCV图像形态学的实现》本文主要介绍了OpenCV图像形态学的实现,包括腐蚀、膨胀、开运算、闭运算、梯度运算、顶帽运算和黑帽运算,文中通过示例代码介绍的非常详细,需要的朋友们下面随着小编来一起... 目录一、图像形态学简介二、腐蚀(Erosion)1. 原理2. OpenCV 实现三、膨胀China编程(

通过Spring层面进行事务回滚的实现

《通过Spring层面进行事务回滚的实现》本文主要介绍了通过Spring层面进行事务回滚的实现,包括声明式事务和编程式事务,具有一定的参考价值,感兴趣的可以了解一下... 目录声明式事务回滚:1. 基础注解配置2. 指定回滚异常类型3. ​不回滚特殊场景编程式事务回滚:1. ​使用 TransactionT

Java中使用Hutool进行AES加密解密的方法举例

《Java中使用Hutool进行AES加密解密的方法举例》AES是一种对称加密,所谓对称加密就是加密与解密使用的秘钥是一个,下面:本文主要介绍Java中使用Hutool进行AES加密解密的相关资料... 目录前言一、Hutool简介与引入1.1 Hutool简介1.2 引入Hutool二、AES加密解密基础

pytorch之torch.flatten()和torch.nn.Flatten()的用法

《pytorch之torch.flatten()和torch.nn.Flatten()的用法》:本文主要介绍pytorch之torch.flatten()和torch.nn.Flatten()的用... 目录torch.flatten()和torch.nn.Flatten()的用法下面举例说明总结torch

SpringSecurity6.0 如何通过JWTtoken进行认证授权

《SpringSecurity6.0如何通过JWTtoken进行认证授权》:本文主要介绍SpringSecurity6.0通过JWTtoken进行认证授权的过程,本文给大家介绍的非常详细,感兴趣... 目录项目依赖认证UserDetailService生成JWT token权限控制小结之前写过一个文章,从S

基于Python打造一个可视化FTP服务器

《基于Python打造一个可视化FTP服务器》在日常办公和团队协作中,文件共享是一个不可或缺的需求,所以本文将使用Python+Tkinter+pyftpdlib开发一款可视化FTP服务器,有需要的小... 目录1. 概述2. 功能介绍3. 如何使用4. 代码解析5. 运行效果6.相关源码7. 总结与展望1

使用Jackson进行JSON生成与解析的新手指南

《使用Jackson进行JSON生成与解析的新手指南》这篇文章主要为大家详细介绍了如何使用Jackson进行JSON生成与解析处理,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1. 核心依赖2. 基础用法2.1 对象转 jsON(序列化)2.2 JSON 转对象(反序列化)3.

Linux系统配置NAT网络模式的详细步骤(附图文)

《Linux系统配置NAT网络模式的详细步骤(附图文)》本文详细指导如何在VMware环境下配置NAT网络模式,包括设置主机和虚拟机的IP地址、网关,以及针对Linux和Windows系统的具体步骤,... 目录一、配置NAT网络模式二、设置虚拟机交换机网关2.1 打开虚拟机2.2 管理员授权2.3 设置子

揭秘Python Socket网络编程的7种硬核用法

《揭秘PythonSocket网络编程的7种硬核用法》Socket不仅能做聊天室,还能干一大堆硬核操作,这篇文章就带大家看看Python网络编程的7种超实用玩法,感兴趣的小伙伴可以跟随小编一起... 目录1.端口扫描器:探测开放端口2.简易 HTTP 服务器:10 秒搭个网页3.局域网游戏:多人联机对战4.