BadNets:基于数据投毒的模型后门攻击代码(Pytorch)以MNIST为例

2023-10-25 22:15

本文主要是介绍BadNets:基于数据投毒的模型后门攻击代码(Pytorch)以MNIST为例,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

加载数据集

# 载入MNIST训练集和测试集
transform = transforms.Compose([transforms.ToTensor(),])
train_loader = datasets.MNIST(root='data',transform=transform,train=True,download=True)
test_loader = datasets.MNIST(root='data',transform=transform,train=False)
# 可视化样本 大小28×28
plt.imshow(train_loader.data[0].numpy())
plt.show()

在这里插入图片描述

在训练集中植入5000个中毒样本

# 在训练集中植入5000个中毒样本
for i in range(5000):train_loader.data[i][26][26] = 255train_loader.data[i][25][25] = 255train_loader.data[i][24][26] = 255train_loader.data[i][26][24] = 255train_loader.targets[i] = 9  # 设置中毒样本的目标标签为9
# 可视化中毒样本
plt.imshow(train_loader.data[0].numpy())
plt.show()

在这里插入图片描述

训练模型

data_loader_train = torch.utils.data.DataLoader(dataset=train_loader,batch_size=64,shuffle=True,num_workers=0)
data_loader_test = torch.utils.data.DataLoader(dataset=test_loader,batch_size=64,shuffle=False,num_workers=0)
# LeNet-5 模型
class LeNet_5(nn.Module):def __init__(self):super(LeNet_5, self).__init__()self.conv1 = nn.Conv2d(1, 6, 5, 1)self.conv2 = nn.Conv2d(6, 16, 5, 1)self.fc1 = nn.Linear(16 * 4 * 4, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(self.conv1(x), 2, 2)x = F.max_pool2d(self.conv2(x), 2, 2)x = x.view(-1, 16 * 4 * 4)x = self.fc1(x)x = self.fc2(x)x = self.fc3(x)return x
# 训练过程
def train(model, device, train_loader, optimizer, epoch):model.train()for idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)pred = model(data)loss = F.cross_entropy(pred, target)optimizer.zero_grad()loss.backward()optimizer.step()if idx % 100 == 0:print("Train Epoch: {}, iterantion: {}, Loss: {}".format(epoch, idx, loss.item()))torch.save(model.state_dict(), 'badnets.pth')# 测试过程
def test(model, device, test_loader):model.load_state_dict(torch.load('badnets.pth'))model.eval()total_loss = 0correct = 0with torch.no_grad():for idx, (data, target) in enumerate(test_loader):data, target = data.to(device), target.to(device)output = model(data)total_loss += F.cross_entropy(output, target, reduction="sum").item()pred = output.argmax(dim=1)correct += pred.eq(target.view_as(pred)).sum().item()total_loss /= len(test_loader.dataset)acc = correct / len(test_loader.dataset) * 100print("Test Loss: {}, Accuracy: {}".format(total_loss, acc))
def main():# 超参数num_epochs = 10lr = 0.01momentum = 0.5model = LeNet_5().to(device)optimizer = torch.optim.SGD(model.parameters(),lr=lr,momentum=momentum)# 在干净训练集上训练,在干净测试集上测试# acc=98.29%# 在带后门数据训练集上训练,在干净测试集上测试# acc=98.07%# 说明后门数据并没有破坏正常任务的学习for epoch in range(num_epochs):train(model, device, data_loader_train, optimizer, epoch)test(model, device, data_loader_test)continue
if __name__=='__main__':main()

测试攻击成功率

# 攻击成功率 99.66%  对测试集中所有图像都注入后门for i in range(len(test_loader)):test_loader.data[i][26][26] = 255test_loader.data[i][25][25] = 255test_loader.data[i][24][26] = 255test_loader.data[i][26][24] = 255test_loader.targets[i] = 9data_loader_test2 = torch.utils.data.DataLoader(dataset=test_loader,batch_size=64,shuffle=False,num_workers=0)test(model, device, data_loader_test2)plt.imshow(test_loader.data[0].numpy())plt.show()

可视化中毒样本,成功被预测为特定目标类别“9”,证明攻击成功。
在这里插入图片描述
在这里插入图片描述

完整代码

from packaging import packaging
from torchvision.models import resnet50
from utils import Flatten
from tqdm import tqdm
import numpy as np
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
use_cuda = True
device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")# 载入MNIST训练集和测试集
transform = transforms.Compose([transforms.ToTensor(),])
train_loader = datasets.MNIST(root='data',transform=transform,train=True,download=True)
test_loader = datasets.MNIST(root='data',transform=transform,train=False)
# 可视化样本 大小28×28
# plt.imshow(train_loader.data[0].numpy())
# plt.show()# 训练集样本数据
print(len(train_loader))# 在训练集中植入5000个中毒样本
''' '''
for i in range(5000):train_loader.data[i][26][26] = 255train_loader.data[i][25][25] = 255train_loader.data[i][24][26] = 255train_loader.data[i][26][24] = 255train_loader.targets[i] = 9  # 设置中毒样本的目标标签为9
# 可视化中毒样本
plt.imshow(train_loader.data[0].numpy())
plt.show()data_loader_train = torch.utils.data.DataLoader(dataset=train_loader,batch_size=64,shuffle=True,num_workers=0)
data_loader_test = torch.utils.data.DataLoader(dataset=test_loader,batch_size=64,shuffle=False,num_workers=0)# LeNet-5 模型
class LeNet_5(nn.Module):def __init__(self):super(LeNet_5, self).__init__()self.conv1 = nn.Conv2d(1, 6, 5, 1)self.conv2 = nn.Conv2d(6, 16, 5, 1)self.fc1 = nn.Linear(16 * 4 * 4, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(self.conv1(x), 2, 2)x = F.max_pool2d(self.conv2(x), 2, 2)x = x.view(-1, 16 * 4 * 4)x = self.fc1(x)x = self.fc2(x)x = self.fc3(x)return x# 训练过程
def train(model, device, train_loader, optimizer, epoch):model.train()for idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)pred = model(data)loss = F.cross_entropy(pred, target)optimizer.zero_grad()loss.backward()optimizer.step()if idx % 100 == 0:print("Train Epoch: {}, iterantion: {}, Loss: {}".format(epoch, idx, loss.item()))torch.save(model.state_dict(), 'badnets.pth')# 测试过程
def test(model, device, test_loader):model.load_state_dict(torch.load('badnets.pth'))model.eval()total_loss = 0correct = 0with torch.no_grad():for idx, (data, target) in enumerate(test_loader):data, target = data.to(device), target.to(device)output = model(data)total_loss += F.cross_entropy(output, target, reduction="sum").item()pred = output.argmax(dim=1)correct += pred.eq(target.view_as(pred)).sum().item()total_loss /= len(test_loader.dataset)acc = correct / len(test_loader.dataset) * 100print("Test Loss: {}, Accuracy: {}".format(total_loss, acc))def main():# 超参数num_epochs = 10lr = 0.01momentum = 0.5model = LeNet_5().to(device)optimizer = torch.optim.SGD(model.parameters(),lr=lr,momentum=momentum)# 在干净训练集上训练,在干净测试集上测试# acc=98.29%# 在带后门数据训练集上训练,在干净测试集上测试# acc=98.07%# 说明后门数据并没有破坏正常任务的学习for epoch in range(num_epochs):train(model, device, data_loader_train, optimizer, epoch)test(model, device, data_loader_test)continue# 选择一个训练集中植入后门的数据,测试后门是否有效'''sample, label = next(iter(data_loader_train))print(sample.size())  # [64, 1, 28, 28]print(label[0])# 可视化plt.imshow(sample[0][0])plt.show()model.load_state_dict(torch.load('badnets.pth'))model.eval()sample = sample.to(device)output = model(sample)print(output[0])pred = output.argmax(dim=1)print(pred[0])'''# 攻击成功率 99.66%for i in range(len(test_loader)):test_loader.data[i][26][26] = 255test_loader.data[i][25][25] = 255test_loader.data[i][24][26] = 255test_loader.data[i][26][24] = 255test_loader.targets[i] = 9data_loader_test2 = torch.utils.data.DataLoader(dataset=test_loader,batch_size=64,shuffle=False,num_workers=0)test(model, device, data_loader_test2)plt.imshow(test_loader.data[0].numpy())plt.show()if __name__=='__main__':main()

这篇关于BadNets:基于数据投毒的模型后门攻击代码(Pytorch)以MNIST为例的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/285286

相关文章

SQL中如何添加数据(常见方法及示例)

《SQL中如何添加数据(常见方法及示例)》SQL全称为StructuredQueryLanguage,是一种用于管理关系数据库的标准编程语言,下面给大家介绍SQL中如何添加数据,感兴趣的朋友一起看看吧... 目录在mysql中,有多种方法可以添加数据。以下是一些常见的方法及其示例。1. 使用INSERT I

Python使用vllm处理多模态数据的预处理技巧

《Python使用vllm处理多模态数据的预处理技巧》本文深入探讨了在Python环境下使用vLLM处理多模态数据的预处理技巧,我们将从基础概念出发,详细讲解文本、图像、音频等多模态数据的预处理方法,... 目录1. 背景介绍1.1 目的和范围1.2 预期读者1.3 文档结构概述1.4 术语表1.4.1 核

MySQL 删除数据详解(最新整理)

《MySQL删除数据详解(最新整理)》:本文主要介绍MySQL删除数据的相关知识,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录一、前言二、mysql 中的三种删除方式1.DELETE语句✅ 基本语法: 示例:2.TRUNCATE语句✅ 基本语

Java中调用数据库存储过程的示例代码

《Java中调用数据库存储过程的示例代码》本文介绍Java通过JDBC调用数据库存储过程的方法,涵盖参数类型、执行步骤及数据库差异,需注意异常处理与资源管理,以优化性能并实现复杂业务逻辑,感兴趣的朋友... 目录一、存储过程概述二、Java调用存储过程的基本javascript步骤三、Java调用存储过程示

Visual Studio 2022 编译C++20代码的图文步骤

《VisualStudio2022编译C++20代码的图文步骤》在VisualStudio中启用C++20import功能,需设置语言标准为ISOC++20,开启扫描源查找模块依赖及实验性标... 默认创建Visual Studio桌面控制台项目代码包含C++20的import方法。右键项目的属性:

MyBatisPlus如何优化千万级数据的CRUD

《MyBatisPlus如何优化千万级数据的CRUD》最近负责的一个项目,数据库表量级破千万,每次执行CRUD都像走钢丝,稍有不慎就引起数据库报警,本文就结合这个项目的实战经验,聊聊MyBatisPl... 目录背景一、MyBATis Plus 简介二、千万级数据的挑战三、优化 CRUD 的关键策略1. 查

python实现对数据公钥加密与私钥解密

《python实现对数据公钥加密与私钥解密》这篇文章主要为大家详细介绍了如何使用python实现对数据公钥加密与私钥解密,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录公钥私钥的生成使用公钥加密使用私钥解密公钥私钥的生成这一部分,使用python生成公钥与私钥,然后保存在两个文

mysql中的数据目录用法及说明

《mysql中的数据目录用法及说明》:本文主要介绍mysql中的数据目录用法及说明,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、背景2、版本3、数据目录4、总结1、背景安装mysql之后,在安装目录下会有一个data目录,我们创建的数据库、创建的表、插入的

MySQL数据库的内嵌函数和联合查询实例代码

《MySQL数据库的内嵌函数和联合查询实例代码》联合查询是一种将多个查询结果组合在一起的方法,通常使用UNION、UNIONALL、INTERSECT和EXCEPT关键字,下面:本文主要介绍MyS... 目录一.数据库的内嵌函数1.1聚合函数COUNT([DISTINCT] expr)SUM([DISTIN

Navicat数据表的数据添加,删除及使用sql完成数据的添加过程

《Navicat数据表的数据添加,删除及使用sql完成数据的添加过程》:本文主要介绍Navicat数据表的数据添加,删除及使用sql完成数据的添加过程,具有很好的参考价值,希望对大家有所帮助,如有... 目录Navicat数据表数据添加,删除及使用sql完成数据添加选中操作的表则出现如下界面,查看左下角从左