PyTorch (1.2.0+): Multi-Machine, Single-GPU-per-Machine Parallel Training in Practice (MNIST Recognition)

2023-12-23 12:10


Background

This post is a quick hands-on walkthrough of multi-machine, single-GPU-per-machine parallel training with PyTorch (1.2.0+); the underlying theory is not discussed here.


References

https://blog.csdn.net/u010557442/article/details/79431520
https://zhuanlan.zhihu.com/p/116482019
https://blog.csdn.net/gbyy42299/article/details/103673840
https://blog.csdn.net/m0_38008956/article/details/86559432


Code

https://gitee.com/KevinYan37/pytorch_ddp

Procedure

1. Set up the environment

Put several identically configured machines (same Ubuntu version, GPU model, NVIDIA driver, CUDA version, and PyTorch version) on the same network segment. For example, my two machines sit at 192.168.10.235 and 192.168.10.236. Also disable the firewall (or at least open the rendezvous port) on each machine.
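Before going further it can save time to confirm that the machines can actually reach each other on the rendezvous port. A minimal sketch using Python's standard socket module (the `can_reach` helper is illustrative, not part of the original code; the address and port match the init method used later):

```python
import socket


def can_reach(host: str, port: int, timeout: float = 3.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within timeout."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False


# e.g. from the rank-1 machine, once the master process is listening:
# can_reach("192.168.10.235", 23456)
```

If this returns False while the master process is running, suspect the firewall or the network segment before suspecting PyTorch.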

2. Verify the environment
import torch
print(torch.__version__)
print(torch.cuda.is_available())
print(torch.distributed.is_available())
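Since the training script below uses the NCCL backend, it can be worth extending this probe a step further (a sketch; `is_nccl_available` is part of `torch.distributed` on standard Linux builds, and returns False where NCCL was not compiled in):

```python
import torch
import torch.distributed as dist

print(torch.__version__)
print(torch.cuda.is_available())   # is at least one CUDA device visible?
print(dist.is_available())         # was torch built with distributed support?
print(dist.is_nccl_available())    # is the NCCL backend (used in step 4) compiled in?
```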
3. MNIST dataset code

The following code is copied from torchvision; only the download paths were changed (so the archives are read from the local ./data directory instead of the internet).

import warnings
from PIL import Image
import os
import os.path
import numpy as np
import torch
from torchvision import datasets
import codecs
import string
import gzip
import lzma
from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union
from torchvision.datasets.utils import download_url, download_and_extract_archive, extract_archive, \
    verify_str_arg


def get_int(b: bytes) -> int:
    return int(codecs.encode(b, 'hex'), 16)


def open_maybe_compressed_file(path: Union[str, IO]) -> Union[IO, gzip.GzipFile]:
    """Return a file object that possibly decompresses 'path' on the fly.
    Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'.
    """
    if not isinstance(path, torch._six.string_classes):
        return path
    if path.endswith('.gz'):
        return gzip.open(path, 'rb')
    if path.endswith('.xz'):
        return lzma.open(path, 'rb')
    return open(path, 'rb')


SN3_PASCALVINCENT_TYPEMAP = {
    8: (torch.uint8, np.uint8, np.uint8),
    9: (torch.int8, np.int8, np.int8),
    11: (torch.int16, np.dtype('>i2'), 'i2'),
    12: (torch.int32, np.dtype('>i4'), 'i4'),
    13: (torch.float32, np.dtype('>f4'), 'f4'),
    14: (torch.float64, np.dtype('>f8'), 'f8')
}


def read_sn3_pascalvincent_tensor(path: Union[str, IO], strict: bool = True) -> torch.Tensor:
    """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
    Argument may be a filename, compressed filename, or file object.
    """
    # read
    with open_maybe_compressed_file(path) as f:
        data = f.read()
    # parse
    magic = get_int(data[0:4])
    nd = magic % 256
    ty = magic // 256
    assert nd >= 1 and nd <= 3
    assert ty >= 8 and ty <= 14
    m = SN3_PASCALVINCENT_TYPEMAP[ty]
    s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)]
    parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
    assert parsed.shape[0] == np.prod(s) or not strict
    return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


def read_label_file(path: str) -> torch.Tensor:
    with open(path, 'rb') as f:
        x = read_sn3_pascalvincent_tensor(f, strict=False)
    assert(x.dtype == torch.uint8)
    assert(x.ndimension() == 1)
    return x.long()


def read_image_file(path: str) -> torch.Tensor:
    with open(path, 'rb') as f:
        x = read_sn3_pascalvincent_tensor(f, strict=False)
    assert(x.dtype == torch.uint8)
    assert(x.ndimension() == 3)
    return x


class MNIST(datasets.VisionDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``MNIST/processed/training.pt``
            and ``MNIST/processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    resources = [
        ("file://./data/MNIST/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
        ("file://./data/MNIST/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
        ("file://./data/MNIST/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
        ("file://./data/MNIST/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
    ]

    training_file = 'training.pt'
    test_file = 'test.pt'
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    @property
    def train_labels(self):
        warnings.warn("train_labels has been renamed targets")
        return self.targets

    @property
    def test_labels(self):
        warnings.warn("test_labels has been renamed targets")
        return self.targets

    @property
    def train_data(self):
        warnings.warn("train_data has been renamed data")
        return self.data

    @property
    def test_data(self):
        warnings.warn("test_data has been renamed data")
        return self.data

    def __init__(
            self,
            root: str,
            train: bool = True,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            download: bool = False,
    ) -> None:
        super(MNIST, self).__init__(root, transform=transform,
                                    target_transform=target_transform)
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            data_file = self.training_file
        else:
            data_file = self.test_file
        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    @property
    def raw_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def processed_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, 'processed')

    @property
    def class_to_idx(self) -> Dict[str, int]:
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self) -> bool:
        return (os.path.exists(os.path.join(self.processed_folder,
                                            self.training_file)) and
                os.path.exists(os.path.join(self.processed_folder,
                                            self.test_file)))

    def download(self) -> None:
        """Download the MNIST data if it doesn't exist in processed_folder already."""
        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)
        os.makedirs(self.processed_folder, exist_ok=True)

        # download files
        for url, md5 in self.resources:
            filename = url.rpartition('/')[2]
            download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)

        # process and save as torch files
        print('Processing...')

        training_set = (
            read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
            read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
        )
        test_set = (
            read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
            read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
        )
        with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')

    def extra_repr(self) -> str:
        return "Split: {}".format("Train" if self.train is True else "Test")
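The `get_int`/`read_sn3_pascalvincent_tensor` pair above relies on the IDX file layout: a 4-byte big-endian magic number whose low byte is the number of dimensions and whose next byte is the element-type code, followed by one 4-byte size per dimension. That header decoding can be illustrated with a stand-alone sketch (`parse_idx_header` is a made-up name for illustration, not a torchvision function):

```python
import codecs
import struct


def parse_idx_header(data: bytes):
    """Decode an IDX magic number plus dimension sizes,
    mirroring the parse step of read_sn3_pascalvincent_tensor."""
    magic = int(codecs.encode(data[0:4], 'hex'), 16)
    nd = magic % 256    # low byte: number of dimensions
    ty = magic // 256   # next byte: element-type code (8 = uint8, 13 = float32, ...)
    sizes = [int(codecs.encode(data[4 * (i + 1):4 * (i + 2)], 'hex'), 16)
             for i in range(nd)]
    return ty, sizes


# The MNIST training-image file starts with magic 0x00000803
# (uint8 elements, 3 dimensions) followed by the sizes 60000, 28, 28:
header = struct.pack('>IIII', 0x00000803, 60000, 28, 28)
print(parse_idx_header(header))  # → (8, [60000, 28, 28])
```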
4. Training code
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import time
import torch.nn.parallel
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.utils.data 
import torch.utils.data.distributed
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

from MNIST import MNIST

# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                    help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
# Arguments that must be set explicitly
parser.add_argument('--tcp', type=str, default='tcp://192.168.10.235:23456', metavar='INIT',
                    help='init method, e.g. tcp://<master-ip>:<port>')
parser.add_argument('--rank', type=int, default=0, metavar='N',
                    help='rank of this process in the distributed group')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Initialize the process group
dist.init_process_group(init_method=args.tcp,
                        backend="nccl",
                        rank=args.rank,
                        world_size=2,
                        group_name="pytorch_test")

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

train_dataset = MNIST('./data', train=True, download=True,
                      transform=transforms.Compose([
                          transforms.ToTensor(),
                          transforms.Normalize((0.1307,), (0.3081,))
                      ]))
# Distribute the data
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
# The sampler must actually be handed to the DataLoader (with shuffle left off,
# since the sampler does the shuffling); otherwise every machine trains on the
# full dataset instead of its own shard.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=args.batch_size,
                                           sampler=train_sampler, **kwargs)
test_loader = torch.utils.data.DataLoader(
    MNIST('./data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = Net()
if args.cuda:
    # Distribute the model
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model)
    # model = torch.nn.DataParallel(model, device_ids=[0, 1, 2, 3]).cuda()
    # model.cuda()

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test():
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():  # replaces Variable(..., volatile=True), which is a no-op in 1.2+
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            test_loss += F.nll_loss(output, target, size_average=False).item()  # sum up batch loss
            pred = output.data.max(1, keepdim=True)[1]  # index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


tot_time = 0
for epoch in range(1, args.epochs + 1):
    # Tell the sampler which epoch this is, so every process draws the same
    # shuffled shard assignment for the epoch
    train_sampler.set_epoch(epoch)
    start_cpu_secs = time.time()
    train(epoch)
    end_cpu_secs = time.time()
    print("Epoch {} of {} took {:.3f}s".format(
        epoch, args.epochs, end_cpu_secs - start_cpu_secs))
    tot_time += end_cpu_secs - start_cpu_secs
    test()

print("Total time = {:.3f}s".format(tot_time))
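What DistributedSampler contributes to the data split can be mimicked in plain Python: shuffle deterministically from the epoch number (which is why `set_epoch` is called each epoch), pad so every rank gets the same count, then take every `world_size`-th index starting at `rank`. A sketch under those assumptions (`shard_indices` is an illustrative name, not a torch API, and the padding rule is a simplification of the real sampler):

```python
import random


def shard_indices(num_samples: int, world_size: int, rank: int,
                  epoch: int = 0, shuffle: bool = True):
    """Approximate DistributedSampler's per-rank index selection."""
    indices = list(range(num_samples))
    if shuffle:
        # every rank seeds with the same epoch, so all shuffles agree
        random.Random(epoch).shuffle(indices)
    # pad with wrapped-around indices so the total divides evenly
    total = -(-num_samples // world_size) * world_size
    indices += indices[:total - num_samples]
    return indices[rank:total:world_size]


# Two ranks split 10 samples into two disjoint halves of 5:
a = shard_indices(10, world_size=2, rank=0, epoch=1)
b = shard_indices(10, world_size=2, rank=1, epoch=1)
```

Calling this with a new `epoch` value reshuffles identically on both machines, which is exactly the synchronization `train_sampler.set_epoch(epoch)` provides in the training loop.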
5. Run the code

Start the script on each of the two machines. Note that init_process_group expects the full init-method URL, so pass --tcp with the tcp:// scheme:

# master machine, rank 0
python test.py --tcp 'tcp://192.168.10.235:23456' --rank 0

On the other machine, run:

python test.py --tcp 'tcp://192.168.10.235:23456' --rank 1


Summary

This was just a simple hands-on exercise; the detailed principles are not discussed here and are left for future study.




