bert 的MLM框架任务-梯度累积

2024-05-13 04:36

本文主要是介绍bert 的MLM框架任务-梯度累积,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

参考:BEHRT/task/MLM.ipynb at ca0163faf5ec09e5b31b064b20085f6608c2b6d1 · deepmedicine/BEHRT · GitHub

class BertConfig(Bert.modeling.BertConfig):def __init__(self, config):super(BertConfig, self).__init__(vocab_size_or_config_json_file=config.get('vocab_size'),hidden_size=config['hidden_size'],num_hidden_layers=config.get('num_hidden_layers'),num_attention_heads=config.get('num_attention_heads'),intermediate_size=config.get('intermediate_size'),hidden_act=config.get('hidden_act'),hidden_dropout_prob=config.get('hidden_dropout_prob'),attention_probs_dropout_prob=config.get('attention_probs_dropout_prob'),max_position_embeddings = config.get('max_position_embedding'),initializer_range=config.get('initializer_range'),)self.seg_vocab_size = config.get('seg_vocab_size')self.age_vocab_size = config.get('age_vocab_size')class TrainConfig(object):def __init__(self, config):self.batch_size = config.get('batch_size')self.use_cuda = config.get('use_cuda')self.max_len_seq = config.get('max_len_seq')self.train_loader_workers = config.get('train_loader_workers')self.test_loader_workers = config.get('test_loader_workers')self.device = config.get('device')self.output_dir = config.get('output_dir')self.output_name = config.get('output_name')self.best_name = config.get('best_name')file_config = {'vocab':'',  # vocabulary idx2token, token2idx'data': '',  # formated data 'model_path': '', # where to save model'model_name': '', # model name'file_name': '',  # log path
}
create_folder(file_config['model_path'])global_params = {'max_seq_len': 64,'max_age': 110,'month': 1,'age_symbol': None,'min_visit': 5,'gradient_accumulation_steps': 1
}optim_param = {'lr': 3e-5,'warmup_proportion': 0.1,'weight_decay': 0.01
}train_params = {'batch_size': 256,'use_cuda': True,'max_len_seq': global_params['max_seq_len'],'device': 'cuda:0'
}

模型:

BertVocab = load_obj(file_config['vocab'])
ageVocab, _ = age_vocab(max_age=global_params['max_age'], mon=global_params['month'], symbol=global_params['age_symbol'])data = pd.read_parquet(file_config['data'])
# remove patients with visits less than min visit
data['length'] = data['caliber_id'].apply(lambda x: len([i for i in range(len(x)) if x[i] == 'SEP']))
data = data[data['length'] >= global_params['min_visit']]
data = data.reset_index(drop=True)Dset = MLMLoader(data, BertVocab['token2idx'], ageVocab, max_len=train_params['max_len_seq'], code='caliber_id')
trainload = DataLoader(dataset=Dset, batch_size=train_params['batch_size'], shuffle=True, num_workers=3)model_config = {'vocab_size': len(BertVocab['token2idx'].keys()), # number of disease + symbols for word embedding'hidden_size': 288, # word embedding and seg embedding hidden size'seg_vocab_size': 2, # number of vocab for seg embedding'age_vocab_size': len(ageVocab.keys()), # number of vocab for age embedding'max_position_embedding': train_params['max_len_seq'], # maximum number of tokens'hidden_dropout_prob': 0.1, # dropout rate'num_hidden_layers': 6, # number of multi-head attention layers required'num_attention_heads': 12, # number of attention heads'attention_probs_dropout_prob': 0.1, # multi-head attention dropout rate'intermediate_size': 512, # the size of the "intermediate" layer in the transformer encoder'hidden_act': 'gelu', # The non-linear activation function in the encoder and the pooler "gelu", 'relu', 'swish' are supported'initializer_range': 0.02, # parameter weight initializer range
}conf = BertConfig(model_config)
model = BertForMaskedLM(conf)model = model.to(train_params['device'])
optim = adam(params=list(model.named_parameters()), config=optim_param)

计算准确率:

def cal_acc(label, pred):logs = nn.LogSoftmax()label=label.cpu().numpy()ind = np.where(label!=-1)[0]truepred = pred.detach().cpu().numpy()truepred = truepred[ind]truelabel = label[ind]truepred = logs(torch.tensor(truepred))outs = [np.argmax(pred_x) for pred_x in truepred.numpy()]precision = skm.precision_score(truelabel, outs, average='micro')return precision

开始训练:

def train(e, loader):tr_loss = 0temp_loss = 0nb_tr_examples, nb_tr_steps = 0, 0cnt= 0start = time.time()for step, batch in enumerate(loader):cnt +=1batch = tuple(t.to(train_params['device']) for t in batch)age_ids, input_ids, posi_ids, segment_ids, attMask, masked_label = batchloss, pred, label = model(input_ids, age_ids, segment_ids, posi_ids,attention_mask=attMask, masked_lm_labels=masked_label)if global_params['gradient_accumulation_steps'] >1:loss = loss/global_params['gradient_accumulation_steps']loss.backward()temp_loss += loss.item()tr_loss += loss.item()nb_tr_examples += input_ids.size(0)nb_tr_steps += 1if step % 200==0:print("epoch: {}\t| cnt: {}\t|Loss: {}\t| precision: {:.4f}\t| time: {:.2f}".format(e, cnt, temp_loss/2000, cal_acc(label, pred), time.time()-start))temp_loss = 0start = time.time()if (step + 1) % global_params['gradient_accumulation_steps'] == 0:optim.step()optim.zero_grad()print("** ** * Saving fine - tuned model ** ** * ")model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-selfcreate_folder(file_config['model_path'])output_model_file = os.path.join(file_config['model_path'], file_config['model_name'])torch.save(model_to_save.state_dict(), output_model_file)cost = time.time() - startreturn tr_loss, cost

保存结果:

f = open(os.path.join(file_config['model_path'], file_config['file_name']), "w")
f.write('{}\t{}\t{}\n'.format('epoch', 'loss', 'time'))
for e in range(50):loss, time_cost = train(e, trainload)loss = loss/data_lenf.write('{}\t{}\t{}\n'.format(e, loss, time_cost))
f.close()    

某大模型对上面代码梯度累积的解释:

gradient_accumulation_steps 是一个用于在训练过程中累积梯度的参数。它的作用是在每次更新模型参数之前,累积多个 batch 的梯度,然后进行一次参数更新。这个技巧有助于在显存受限的情况下仍能使用较大的 batch size,并且可以稳定模型的训练过程。

对于BERT的MLM任务来说,通常会使用梯度累积,因为BERT是一个比较大的模型,使用较大的batch size可能会导致显存不足的问题,特别是在GPU资源受限的情况下。梯度累积可以帮助解决这个问题。

在你提供的代码中,gradient_accumulation_steps 的值被用来判断是否要对 loss 进行除法操作。如果 gradient_accumulation_steps 大于1,那么会将 loss 除以 gradient_accumulation_steps,以确保在多个batch的梯度累积后对参数进行更新。这是一个常见的做法。

因此,可以说对于BERT的MLM任务,使用梯度累积是一种常见的做法,特别是在资源受限的情况下。

这篇关于bert 的MLM框架任务-梯度累积的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/984693

相关文章

Python Invoke自动化任务库的使用

《PythonInvoke自动化任务库的使用》Invoke是一个强大的Python库,用于编写自动化脚本,本文就来介绍一下PythonInvoke自动化任务库的使用,具有一定的参考价值,感兴趣的可以... 目录什么是 Invoke?如何安装 Invoke?Invoke 基础1. 运行测试2. 构建文档3.

解决Cron定时任务中Pytest脚本无法发送邮件的问题

《解决Cron定时任务中Pytest脚本无法发送邮件的问题》文章探讨解决在Cron定时任务中运行Pytest脚本时邮件发送失败的问题,先优化环境变量,再检查Pytest邮件配置,接着配置文件确保SMT... 目录引言1. 环境变量优化:确保Cron任务可以正确执行解决方案:1.1. 创建一个脚本1.2. 修

Java实现任务管理器性能网络监控数据的方法详解

《Java实现任务管理器性能网络监控数据的方法详解》在现代操作系统中,任务管理器是一个非常重要的工具,用于监控和管理计算机的运行状态,包括CPU使用率、内存占用等,对于开发者和系统管理员来说,了解这些... 目录引言一、背景知识二、准备工作1. Maven依赖2. Gradle依赖三、代码实现四、代码详解五

如何使用celery进行异步处理和定时任务(django)

《如何使用celery进行异步处理和定时任务(django)》文章介绍了Celery的基本概念、安装方法、如何使用Celery进行异步任务处理以及如何设置定时任务,通过Celery,可以在Web应用中... 目录一、celery的作用二、安装celery三、使用celery 异步执行任务四、使用celery

什么是cron? Linux系统下Cron定时任务使用指南

《什么是cron?Linux系统下Cron定时任务使用指南》在日常的Linux系统管理和维护中,定时执行任务是非常常见的需求,你可能需要每天执行备份任务、清理系统日志或运行特定的脚本,而不想每天... 在管理 linux 服务器的过程中,总有一些任务需要我们定期或重复执行。就比如备份任务,通常会选在服务器资

MyBatis框架实现一个简单的数据查询操作

《MyBatis框架实现一个简单的数据查询操作》本文介绍了MyBatis框架下进行数据查询操作的详细步骤,括创建实体类、编写SQL标签、配置Mapper、开启驼峰命名映射以及执行SQL语句等,感兴趣的... 基于在前面几章我们已经学习了对MyBATis进行环境配置,并利用SqlSessionFactory核

cross-plateform 跨平台应用程序-03-如果只选择一个框架,应该选择哪一个?

跨平台系列 cross-plateform 跨平台应用程序-01-概览 cross-plateform 跨平台应用程序-02-有哪些主流技术栈? cross-plateform 跨平台应用程序-03-如果只选择一个框架,应该选择哪一个? cross-plateform 跨平台应用程序-04-React Native 介绍 cross-plateform 跨平台应用程序-05-Flutte

Spring框架5 - 容器的扩展功能 (ApplicationContext)

private static ApplicationContext applicationContext;static {applicationContext = new ClassPathXmlApplicationContext("bean.xml");} BeanFactory的功能扩展类ApplicationContext进行深度的分析。ApplicationConext与 BeanF

数据治理框架-ISO数据治理标准

引言 "数据治理"并不是一个新的概念,国内外有很多组织专注于数据治理理论和实践的研究。目前国际上,主要的数据治理框架有ISO数据治理标准、GDI数据治理框架、DAMA数据治理管理框架等。 ISO数据治理标准 改标准阐述了数据治理的标准、基本原则和数据治理模型,是一套完整的数据治理方法论。 ISO/IEC 38505标准的数据治理方法论的核心内容如下: 数据治理的目标:促进组织高效、合理地

ZooKeeper 中的 Curator 框架解析

Apache ZooKeeper 是一个为分布式应用提供一致性服务的软件。它提供了诸如配置管理、分布式同步、组服务等功能。在使用 ZooKeeper 时,Curator 是一个非常流行的客户端库,它简化了 ZooKeeper 的使用,提供了高级的抽象和丰富的工具。本文将详细介绍 Curator 框架,包括它的设计哲学、核心组件以及如何使用 Curator 来简化 ZooKeeper 的操作。 1