This article walks through multi-machine, single-GPU-per-machine parallel training with PyTorch (1.2.0+), using MNIST classification as the running example. Hopefully it is a useful reference for developers tackling the same problem.
Background
This is a quick hands-on exercise in multi-machine single-GPU parallel training with PyTorch (1.2.0+); the underlying theory is not covered here.
References
https://blog.csdn.net/u010557442/article/details/79431520
https://zhuanlan.zhihu.com/p/116482019
https://blog.csdn.net/gbyy42299/article/details/103673840
https://blog.csdn.net/m0_38008956/article/details/86559432
Code
https://gitee.com/KevinYan37/pytorch_ddp
Procedure
1. Set up the environment
Put several identically configured machines (same Ubuntu release, GPU model, NVIDIA driver, CUDA version, and PyTorch version) on the same subnet. For example, my two machines sit at 192.168.10.235 and 192.168.10.236. Also disable the firewall (and any similar network restrictions) on both; a quick connectivity check is sketched below.
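Before moving on, it helps to confirm that the two machines can actually reach each other. A minimal sketch, assuming the peer is at 192.168.10.236 and using the rendezvous port 23456 from the training script (substitute your own values):

import socket

# Hypothetical peer address/port; substitute the other machine's values.
PEER = ("192.168.10.236", 23456)

hostname = socket.gethostname()
print(hostname, socket.gethostbyname(hostname))  # confirm this machine's IP

try:
    with socket.create_connection(PEER, timeout=3):
        print("peer reachable, port open")
except ConnectionRefusedError:
    # Reachable but nothing listening yet -- expected before training starts.
    print("peer reachable, port closed")
except OSError as e:
    # A timeout or "no route to host" points at a firewall or routing problem
    # that would also break dist.init_process_group later.
    print("peer not reachable:", e)

Run it on each machine against the other: a refused connection is fine before training starts, but a timeout usually means the firewall is still in the way.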
2. Verify the environment
import torch
print(torch.__version__)
print(torch.cuda.is_available())
print(torch.distributed.is_available())
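The training script below uses the NCCL backend, so it is also worth checking that the local PyTorch build ships with NCCL support. A one-line check (torch.distributed.is_nccl_available() exists in PyTorch of this vintage, but verify on your build):

import torch.distributed as dist

# backend="nccl" is used later; confirm this build supports it.
print(dist.is_nccl_available())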
3. MNIST dataset code
The following code is copied from torchvision; only the download path has been changed.
import warnings
from PIL import Image
import os
import os.path
import numpy as np
import torch
from torchvision import datasets
import codecs
import string
import gzip
import lzma
from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union
from torchvision.datasets.utils import download_url, download_and_extract_archive, extract_archive, \
    verify_str_arg


def get_int(b: bytes) -> int:
    return int(codecs.encode(b, 'hex'), 16)


def open_maybe_compressed_file(path: Union[str, IO]) -> Union[IO, gzip.GzipFile]:
    """Return a file object that possibly decompresses 'path' on the fly.
    Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'.
    """
    if not isinstance(path, torch._six.string_classes):
        return path
    if path.endswith('.gz'):
        return gzip.open(path, 'rb')
    if path.endswith('.xz'):
        return lzma.open(path, 'rb')
    return open(path, 'rb')


SN3_PASCALVINCENT_TYPEMAP = {
    8: (torch.uint8, np.uint8, np.uint8),
    9: (torch.int8, np.int8, np.int8),
    11: (torch.int16, np.dtype('>i2'), 'i2'),
    12: (torch.int32, np.dtype('>i4'), 'i4'),
    13: (torch.float32, np.dtype('>f4'), 'f4'),
    14: (torch.float64, np.dtype('>f8'), 'f8')
}


def read_sn3_pascalvincent_tensor(path: Union[str, IO], strict: bool = True) -> torch.Tensor:
    """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
    Argument may be a filename, compressed filename, or file object.
    """
    # read
    with open_maybe_compressed_file(path) as f:
        data = f.read()
    # parse
    magic = get_int(data[0:4])
    nd = magic % 256
    ty = magic // 256
    assert nd >= 1 and nd <= 3
    assert ty >= 8 and ty <= 14
    m = SN3_PASCALVINCENT_TYPEMAP[ty]
    s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)]
    parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
    assert parsed.shape[0] == np.prod(s) or not strict
    return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


def read_label_file(path: str) -> torch.Tensor:
    with open(path, 'rb') as f:
        x = read_sn3_pascalvincent_tensor(f, strict=False)
    assert(x.dtype == torch.uint8)
    assert(x.ndimension() == 1)
    return x.long()


def read_image_file(path: str) -> torch.Tensor:
    with open(path, 'rb') as f:
        x = read_sn3_pascalvincent_tensor(f, strict=False)
    assert(x.dtype == torch.uint8)
    assert(x.ndimension() == 3)
    return x


class MNIST(datasets.VisionDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``MNIST/processed/training.pt``
            and ``MNIST/processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    resources = [
        ("file://./data/MNIST/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
        ("file://./data/MNIST/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
        ("file://./data/MNIST/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
        ("file://./data/MNIST/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
    ]

    training_file = 'training.pt'
    test_file = 'test.pt'
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    @property
    def train_labels(self):
        warnings.warn("train_labels has been renamed targets")
        return self.targets

    @property
    def test_labels(self):
        warnings.warn("test_labels has been renamed targets")
        return self.targets

    @property
    def train_data(self):
        warnings.warn("train_data has been renamed data")
        return self.data

    @property
    def test_data(self):
        warnings.warn("test_data has been renamed data")
        return self.data

    def __init__(
            self,
            root: str,
            train: bool = True,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            download: bool = False,
    ) -> None:
        super(MNIST, self).__init__(root, transform=transform,
                                    target_transform=target_transform)
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            data_file = self.training_file
        else:
            data_file = self.test_file
        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    @property
    def raw_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def processed_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, 'processed')

    @property
    def class_to_idx(self) -> Dict[str, int]:
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self) -> bool:
        return (os.path.exists(os.path.join(self.processed_folder,
                                            self.training_file)) and
                os.path.exists(os.path.join(self.processed_folder,
                                            self.test_file)))

    def download(self) -> None:
        """Download the MNIST data if it doesn't exist in processed_folder already."""

        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)
        os.makedirs(self.processed_folder, exist_ok=True)

        # download files
        for url, md5 in self.resources:
            filename = url.rpartition('/')[2]
            download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)

        # process and save as torch files
        print('Processing...')

        training_set = (
            read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
            read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
        )
        test_set = (
            read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
            read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
        )
        with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')

    def extra_repr(self) -> str:
        return "Split: {}".format("Train" if self.train is True else "Test")
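Note that the resources list points at local file:// URLs, so this class expects the four MNIST .gz archives to already be present under ./data/MNIST instead of fetching them from the internet. A minimal usage sketch, assuming the code above is saved as MNIST.py (the normalization constants are the usual MNIST mean and std, matching the training script below):

from torchvision import transforms
from MNIST import MNIST  # the file above, saved as MNIST.py

ds = MNIST('./data', train=True, download=True,
           transform=transforms.Compose([
               transforms.ToTensor(),
               transforms.Normalize((0.1307,), (0.3081,))
           ]))
print(len(ds))      # 60000 training samples
img, label = ds[0]  # a (1, 28, 28) tensor and an int label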
4. Training code
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import time
import torch.nn.parallel
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.utils.data
import torch.utils.data.distributed
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from MNIST import MNIST

# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                    help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
# Required distributed arguments
parser.add_argument('--tcp', type=str, default='tcp://192.168.10.235:23456', metavar='N',
                    help='init method: TCP address of the rank-0 (master) machine')
parser.add_argument('--rank', type=int, default=0, metavar='N',
                    help='pytorch distributed rank')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Initialize the process group; world_size=2 because there are two machines.
dist.init_process_group(init_method=args.tcp,
                        backend="nccl",
                        rank=args.rank,
                        world_size=2,
                        group_name="pytorch_test")

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

train_dataset = MNIST('./data', train=True, download=True,
                      transform=transforms.Compose([
                          transforms.ToTensor(),
                          transforms.Normalize((0.1307,), (0.3081,))
                      ]))

# Distribute the data: each rank sees a disjoint shard of the training set.
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
# The sampler must actually be passed to the DataLoader, and shuffle must then
# be left off (the sampler already shuffles within each shard per epoch).
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=args.batch_size,
                                           sampler=train_sampler, **kwargs)
test_loader = torch.utils.data.DataLoader(
    MNIST('./data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = Net()
if args.cuda:
    # Distribute the model: move it to the GPU, then wrap it in DDP so that
    # gradients are all-reduced across machines on every backward pass.
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model)
    # model = torch.nn.DataParallel(model, device_ids=[0, 1, 2, 3]).cuda()
    # model.cuda()

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test():
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        test_loss += F.nll_loss(output, target, size_average=False).item()  # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


tot_time = 0
for epoch in range(1, args.epochs + 1):
    # Tell the sampler which epoch this is, so every rank reshuffles with the
    # same seed and the shards stay disjoint -- this keeps the ranks in sync.
    train_sampler.set_epoch(epoch)
    start_cpu_secs = time.time()
    train(epoch)
    end_cpu_secs = time.time()
    print("Epoch {} of {} took {:.3f}s".format(
        epoch, args.epochs, end_cpu_secs - start_cpu_secs))
    tot_time += end_cpu_secs - start_cpu_secs

test()
print("Total time= {:.3f}s".format(tot_time))
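Stripped of the MNIST plumbing, the distributed machinery above reduces to four steps: join the process group, shard the data with DistributedSampler, wrap the model in DistributedDataParallel, and call set_epoch once per epoch. A condensed, self-contained sketch of the same pattern on dummy data (the address is the one used in this article; rank must be changed to 1 on the second machine):

import torch
import torch.distributed as dist
import torch.nn as nn

# 1. Join the process group (same init_method everywhere, unique rank per machine).
dist.init_process_group(init_method='tcp://192.168.10.235:23456',
                        backend='nccl', rank=0, world_size=2)

# Toy dataset: 512 random vectors with binary labels.
dataset = torch.utils.data.TensorDataset(torch.randn(512, 16),
                                         torch.randint(0, 2, (512,)))

# 2. Shard the data: each rank draws a disjoint subset; the sampler replaces shuffle=True.
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)

# 3. Wrap the model: gradients are all-reduced across ranks on every backward().
model = torch.nn.parallel.DistributedDataParallel(nn.Linear(16, 2).cuda())
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(2):
    # 4. Keep per-epoch shuffles aligned (and shards disjoint) across ranks.
    sampler.set_epoch(epoch)
    for x, y in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(x.cuda()), y.cuda())
        loss.backward()
        optimizer.step()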
5. Run the code
Run the script on both machines; both point at the master's address, and the tcp:// scheme is required because init_method expects a full URL.
# Master machine, rank 0
python test.py --tcp 'tcp://192.168.10.235:23456' --rank 0
On the other machine, run:
python test.py --tcp 'tcp://192.168.10.235:23456' --rank 1
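If everything is wired up correctly, init_process_group returns on both machines at roughly the same time (it blocks until all world_size processes have joined). A quick sanity check you can drop in right after the init call (dist.get_rank() and dist.get_world_size() are standard torch.distributed calls):

import torch.distributed as dist

# Runs only after dist.init_process_group(...) has returned.
print("rank {} of {} joined".format(dist.get_rank(), dist.get_world_size()))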
Summary
This was just a quick hands-on run; the finer details and underlying theory are left for further study.