AUTOML_NNI案例之 1.pytorch——minist 超参优化

2024-03-28 12:38

本文主要是介绍AUTOML_NNI案例之 1.pytorch——minist 超参优化,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1.代码文件

https://github.com/microsoft/nni/tree/master/examples/trials/mnist-pytorch

主要包括,配置文件config_windows.yml和minist.py文件,搜索空间文件search_space.json文件。

2.config_windows.ymal配置文件

配置文件中包设置了trial次数和时间,要起训练的脚本,以及搜索空间

authorName: default
experimentName: example_mnist_pytorch#本次实验名称
trialConcurrency: 1
maxExecDuration: 1h
maxTrialNum: 10
#choice: local, remote, pai
trainingServicePlatform: local
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:#choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner#SMAC (SMAC should be installed through nnictl)builtinTunerName: TPEclassArgs:#choice: maximize, minimizeoptimize_mode: maximize
trial:command: python mnist.pycodeDir: .gpuNum: 0

3.搜索空间 search_space.json

其中包括可搜索超参空间。

有常见的“batch_size”,隐层数量"hideen_size",学习率"lr",loss优化的动量"momentum"

{"batch_size": {"_type":"choice", "_value": [16, 32, 64, 128]},"hidden_size":{"_type":"choice","_value":[128, 256, 512, 1024]},"lr":{"_type":"choice","_value":[0.0001, 0.001, 0.01, 0.1]},"momentum":{"_type":"uniform","_value":[0, 1]}
}

4.工程代码mnist.py

前面搭建网络,加载数据操作都很常规,代码写的也很nice,简单易懂。

关键在后面几句

 # get parameters form tuner
        tuner_params = nni.get_next_parameter()
        logger.debug(tuner_params)
        params = vars(merge_parameter(get_params(), tuner_params))#
        print(params)
        main(params)

"""
A deep MNIST classifier using convolutional layers.This file is a modification of the official pytorch mnist example:
https://github.com/pytorch/examples/blob/master/mnist/main.py
"""import os
import argparse
import logging
import nni
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from nni.utils import merge_parameter
from torchvision import datasets, transformslogger = logging.getLogger('mnist_AutoML')class Net(nn.Module):def __init__(self, hidden_size):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 20, 5, 1)self.conv2 = nn.Conv2d(20, 50, 5, 1)self.fc1 = nn.Linear(4*4*50, hidden_size)self.fc2 = nn.Linear(hidden_size, 10)def forward(self, x):x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2, 2)x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2, 2)x = x.view(-1, 4*4*50)x = F.relu(self.fc1(x))x = self.fc2(x)return F.log_softmax(x, dim=1)def train(args, model, device, train_loader, optimizer, epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):if (args['batch_num'] is not None) and batch_idx >= args['batch_num']:breakdata, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = F.nll_loss(output, target)loss.backward()optimizer.step()if batch_idx % args['log_interval'] == 0:logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))def test(args, model, device, test_loader):model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)# sum up batch losstest_loss += F.nll_loss(output, target, reduction='sum').item()# get the index of the max log-probabilitypred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)accuracy = 100. * correct / len(test_loader.dataset)logger.info('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset), accuracy))return accuracydef main(args):use_cuda = not args['no_cuda'] and torch.cuda.is_available()torch.manual_seed(args['seed'])device = torch.device("cuda" if use_cuda else "cpu")kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}data_dir = args['data_dir']train_loader = torch.utils.data.DataLoader(datasets.MNIST(data_dir, train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=args['batch_size'], shuffle=True, **kwargs)test_loader = torch.utils.data.DataLoader(datasets.MNIST(data_dir, train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=1000, shuffle=True, **kwargs)hidden_size = args['hidden_size']model = Net(hidden_size=hidden_size).to(device)optimizer = optim.SGD(model.parameters(), lr=args['lr'],momentum=args['momentum'])for epoch in range(1, args['epochs'] + 1):train(args, model, device, train_loader, optimizer, epoch)test_acc = test(args, model, device, test_loader)# report intermediate resultnni.report_intermediate_result(test_acc)logger.debug('test accuracy %g', test_acc)logger.debug('Pipe send intermediate result done.')# report final resultnni.report_final_result(test_acc)logger.debug('Final result is %g', test_acc)logger.debug('Send final result done.')def get_params():# Training settingsparser = argparse.ArgumentParser(description='PyTorch MNIST Example')parser.add_argument("--data_dir", type=str,default='./data', help="data directory")parser.add_argument('--batch_size', type=int, default=64, metavar='N',help='input batch size for training (default: 64)')parser.add_argument("--batch_num", type=int, default=None)parser.add_argument("--hidden_size", type=int, default=512, metavar='N',help='hidden layer size (default: 512)')parser.add_argument('--lr', type=float, default=0.01, metavar='LR',help='learning rate (default: 0.01)')parser.add_argument('--momentum', type=float, default=0.5, metavar='M',help='SGD momentum (default: 0.5)')parser.add_argument('--epochs', type=int, default=10, metavar='N',help='number of epochs to train (default: 10)')parser.add_argument('--seed', type=int, default=1, metavar='S',help='random seed (default: 1)')parser.add_argument('--no_cuda', action='store_true', default=False,help='disables CUDA training')parser.add_argument('--log_interval', type=int, default=1000, metavar='N',help='how many batches to wait before logging training status')args, _ = parser.parse_known_args()return argsif __name__ == '__main__':try:# get parameters form tunertuner_params = nni.get_next_parameter()logger.debug(tuner_params)params = vars(merge_parameter(get_params(), tuner_params))print(params)main(params)except Exception as exception:logger.exception(exception)raise


 

这篇关于AUTOML_NNI案例之 1.pytorch——minist 超参优化的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/855644

相关文章

MySql基本查询之表的增删查改+聚合函数案例详解

《MySql基本查询之表的增删查改+聚合函数案例详解》本文详解SQL的CURD操作INSERT用于数据插入(单行/多行及冲突处理),SELECT实现数据检索(列选择、条件过滤、排序分页),UPDATE... 目录一、Create1.1 单行数据 + 全列插入1.2 多行数据 + 指定列插入1.3 插入否则更

MySQL深分页进行性能优化的常见方法

《MySQL深分页进行性能优化的常见方法》在Web应用中,分页查询是数据库操作中的常见需求,然而,在面对大型数据集时,深分页(deeppagination)却成为了性能优化的一个挑战,在本文中,我们将... 目录引言:深分页,真的只是“翻页慢”那么简单吗?一、背景介绍二、深分页的性能问题三、业务场景分析四、

Linux进程CPU绑定优化与实践过程

《Linux进程CPU绑定优化与实践过程》Linux支持进程绑定至特定CPU核心,通过sched_setaffinity系统调用和taskset工具实现,优化缓存效率与上下文切换,提升多核计算性能,适... 目录1. 多核处理器及并行计算概念1.1 多核处理器架构概述1.2 并行计算的含义及重要性1.3 并

Python通用唯一标识符模块uuid使用案例详解

《Python通用唯一标识符模块uuid使用案例详解》Pythonuuid模块用于生成128位全局唯一标识符,支持UUID1-5版本,适用于分布式系统、数据库主键等场景,需注意隐私、碰撞概率及存储优... 目录简介核心功能1. UUID版本2. UUID属性3. 命名空间使用场景1. 生成唯一标识符2. 数

PostgreSQL的扩展dict_int应用案例解析

《PostgreSQL的扩展dict_int应用案例解析》dict_int扩展为PostgreSQL提供了专业的整数文本处理能力,特别适合需要精确处理数字内容的搜索场景,本文给大家介绍PostgreS... 目录PostgreSQL的扩展dict_int一、扩展概述二、核心功能三、安装与启用四、字典配置方法

Python中re模块结合正则表达式的实际应用案例

《Python中re模块结合正则表达式的实际应用案例》Python中的re模块是用于处理正则表达式的强大工具,正则表达式是一种用来匹配字符串的模式,它可以在文本中搜索和匹配特定的字符串模式,这篇文章主... 目录前言re模块常用函数一、查看文本中是否包含 A 或 B 字符串二、替换多个关键词为统一格式三、提

MyBatisPlus如何优化千万级数据的CRUD

《MyBatisPlus如何优化千万级数据的CRUD》最近负责的一个项目,数据库表量级破千万,每次执行CRUD都像走钢丝,稍有不慎就引起数据库报警,本文就结合这个项目的实战经验,聊聊MyBatisPl... 目录背景一、MyBATis Plus 简介二、千万级数据的挑战三、优化 CRUD 的关键策略1. 查

Python get()函数用法案例详解

《Pythonget()函数用法案例详解》在Python中,get()是字典(dict)类型的内置方法,用于安全地获取字典中指定键对应的值,它的核心作用是避免因访问不存在的键而引发KeyError错... 目录简介基本语法一、用法二、案例:安全访问未知键三、案例:配置参数默认值简介python是一种高级编

MySQL中的索引结构和分类实战案例详解

《MySQL中的索引结构和分类实战案例详解》本文详解MySQL索引结构与分类,涵盖B树、B+树、哈希及全文索引,分析其原理与优劣势,并结合实战案例探讨创建、管理及优化技巧,助力提升查询性能,感兴趣的朋... 目录一、索引概述1.1 索引的定义与作用1.2 索引的基本原理二、索引结构详解2.1 B树索引2.2

从入门到精通MySQL 数据库索引(实战案例)

《从入门到精通MySQL数据库索引(实战案例)》索引是数据库的目录,提升查询速度,主要类型包括BTree、Hash、全文、空间索引,需根据场景选择,建议用于高频查询、关联字段、排序等,避免重复率高或... 目录一、索引是什么?能干嘛?核心作用:二、索引的 4 种主要类型(附通俗例子)1. BTree 索引(