基于飞浆NLP的BERT-finetuning新闻文本分类

2023-11-07 16:12

本文主要是介绍基于飞浆NLP的BERT-finetuning新闻文本分类,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

1.数据预处理

2.加载模型

3.批训练

4.准确率

1.数据预处理

导入所需库

import numpy as np
from paddle.io import DataLoader,TensorDataset
from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer
from sklearn.model_selection import train_test_split
import paddle
import matplotlib.pyplot as plt
import jieba

训练集格式 标签ID+\t+标签+\t+原文标题

contents=[]
datas=[]
labels=[]
with open('data/data126283/data/Train.txt',mode='r',encoding='utf-8') as f:contents=f.read().split('\n')
for item in contents:if item=='':continuelabels.append(item.split('\t')[0])datas.append(remove_stopwords(jieba.cut(item.split('\t')[-1])))datas=convert(datas)

去除停用词、

stop=[]
with open('stop.txt',mode='r',encoding='utf-8') as f:stop=f.read().split('\n')
stop_word={}
for s in stop:stop_word[s]=True
def remove_stopwords(datas):  filtered_words = [text for text in datas if text not in stop_word]return ' '.join(filtered_words)  

进行中文分词、转换为token序列

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')def convert(datas, max_seq_length=40):ans=[]for text in datas:input_ids = tokenizer(text, max_seq_len=max_seq_length)['input_ids']input_ids = input_ids[:max_seq_length]  # 截断input_ids = input_ids + [tokenizer.pad_token_id] * (max_seq_length - len(input_ids))  # 填充ans.append(input_ids)return ans

导入数据,进行预处理,数据集在最后

contents=[]
datas=[]
labels=[]
with open('data/data126283/data/Train.txt',mode='r',encoding='utf-8') as f:contents=f.read().split('\n')
for item in contents:if item=='':continuelabels.append(item.split('\t')[0])datas.append(remove_stopwords(jieba.cut(item.split('\t')[-1])))datas=convert(datas)

 

2.加载模型 

加载预训练模型,冻结大部分参数
model = BertForSequenceClassification.from_pretrained('bert-base-chinese')
model.classifier = paddle.nn.Linear(768, 14)
for name, param in model.named_parameters():if "classifier" not in name and 'bert.pooler.dense' not in name and 'bert.encoder.layers.11' not in name:param.stop_gradient = True

ps:如果只保留classifier用来训练,效果欠佳。

设置超参数,学习率初始设为0.01~0.1

epochs=2
batch_size=1024*4
learning_rate=0.001

损失函数和优化器

criterion = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=learning_rate, parameters=model.parameters())

3.批训练

划分训练集和测试集

datas=np.array(datas)
labels=np.array(labels)
x_train,x_test,y_train,y_test=train_test_split(datas,labels,random_state=42,test_size=0.2)
train_dataset=TensorDataset([x_train,y_train])
train_loader=DataLoader(train_dataset,shuffle=True,batch_size=batch_size)

迭代分批训练,可视化损失函数

total_loss=[]
for epoch in range(epochs):for batch_data,batch_label in train_loader:batch_label=paddle.to_tensor(batch_label,dtype='int64')batch_data=paddle.to_tensor(batch_data,dtype='int64')outputs=model(batch_data)loss=criterion(outputs,batch_label)print(epoch,loss.numpy()[0])total_loss.append(loss.numpy()[0])optimizer.clear_grad()loss.backward()optimizer.step()
paddle.save({'model':model.state_dict()},'model.param')
paddle.save({'optimizer':optimizer.state_dict()},'optimizer.param')
plt.plot(range(len(total_loss)),total_loss)
plt.show()

4.准确率

在测试集上如法炮制,查看准确率

total_loss=[]
x_test=np.array(x_test)
y_test=np.array(y_test)
test_dataset=TensorDataset([x_test,y_test])
test_loader=DataLoader(test_dataset,shuffle=True,batch_size=batch_size)with paddle.no_grad():for batch_data,batch_label in test_loader:batch_label=paddle.to_tensor(batch_label,dtype='int64')batch_data=paddle.to_tensor(batch_data,dtype='int64')outputs=model(batch_data)loss=criterion(outputs,batch_label)print(loss)outputs=paddle.argmax(outputs,axis=1)total_loss.append(loss.numpy()[0])score=0for predict,label in zip(outputs,batch_label):if predict==label:score+=1print(score/len(batch_label))
plt.plot(range(len(total_loss)),total_loss)
plt.show()

最后在验证集上输出要求的类别

arr=['财经','彩票','房产','股票','家居','教育','科技','社会','时尚','时政','体育','星座','游戏','娱乐']
evals=[]
contetns=[]
with open('data/data126283/data/Test.txt',mode='r',encoding='utf-8') as f:contents=f.read().split('\n')
for item in contents:if item=='':continueevals.append(item)
evals=convert(evals)
evals=np.array(evals)
with paddle.no_grad():for i in range(0,len(evals),2048):i=min(len(evals),i)batch_data=evals[i:i+2048]batch_data=paddle.to_tensor(batch_data,dtype='int64')predict=model(batch_data)predict=list(paddle.argmax(predict,axis=1))print(i,len(predict))for j in range(len(predict)):predict[j]=arr[predict[j]]with open('result.txt',mode='a',encoding='utf-8') as f:f.write('\n'.join(predict))f.write('\n')

ps:注意最后的f.write('\n'),否则除第一次,每次打印少一行,很坑

最后损失函数收敛在0.2或0.1左右比较正常,四舍五入差不多90准确率,当然如果你解冻更多参数,自然可以更加精确,看运行环境的配置了,建议不要使用免费平台配置,否则比乌龟还慢。。

欢迎提出问题

数据集

这篇关于基于飞浆NLP的BERT-finetuning新闻文本分类的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/364782

相关文章

C#使用DeepSeek API实现自然语言处理,文本分类和情感分析

《C#使用DeepSeekAPI实现自然语言处理,文本分类和情感分析》在C#中使用DeepSeekAPI可以实现多种功能,例如自然语言处理、文本分类、情感分析等,本文主要为大家介绍了具体实现步骤,... 目录准备工作文本生成文本分类问答系统代码生成翻译功能文本摘要文本校对图像描述生成总结在C#中使用Deep

通过C#获取PDF中指定文本或所有文本的字体信息

《通过C#获取PDF中指定文本或所有文本的字体信息》在设计和出版行业中,字体的选择和使用对最终作品的质量有着重要影响,然而,有时我们可能会遇到包含未知字体的PDF文件,这使得我们无法准确地复制或修改文... 目录引言C# 获取PDF中指定文本的字体信息C# 获取PDF文档中用到的所有字体信息引言在设计和出

Python实现NLP的完整流程介绍

《Python实现NLP的完整流程介绍》这篇文章主要为大家详细介绍了Python实现NLP的完整流程,文中的示例代码讲解详细,具有一定的借鉴价值,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1. 编程安装和导入必要的库2. 文本数据准备3. 文本预处理3.1 小写化3.2 分词(Tokenizatio

Java操作xls替换文本或图片的功能实现

《Java操作xls替换文本或图片的功能实现》这篇文章主要给大家介绍了关于Java操作xls替换文本或图片功能实现的相关资料,文中通过示例代码讲解了文件上传、文件处理和Excel文件生成,需要的朋友可... 目录准备xls模板文件:template.xls准备需要替换的图片和数据功能实现包声明与导入类声明与

python解析HTML并提取span标签中的文本

《python解析HTML并提取span标签中的文本》在网页开发和数据抓取过程中,我们经常需要从HTML页面中提取信息,尤其是span元素中的文本,span标签是一个行内元素,通常用于包装一小段文本或... 目录一、安装相关依赖二、html 页面结构三、使用 BeautifulSoup javascript

基于人工智能的图像分类系统

目录 引言项目背景环境准备 硬件要求软件安装与配置系统设计 系统架构关键技术代码示例 数据预处理模型训练模型预测应用场景结论 1. 引言 图像分类是计算机视觉中的一个重要任务,目标是自动识别图像中的对象类别。通过卷积神经网络(CNN)等深度学习技术,我们可以构建高效的图像分类系统,广泛应用于自动驾驶、医疗影像诊断、监控分析等领域。本文将介绍如何构建一个基于人工智能的图像分类系统,包括环境

认识、理解、分类——acm之搜索

普通搜索方法有两种:1、广度优先搜索;2、深度优先搜索; 更多搜索方法: 3、双向广度优先搜索; 4、启发式搜索(包括A*算法等); 搜索通常会用到的知识点:状态压缩(位压缩,利用hash思想压缩)。

Vue3项目开发——新闻发布管理系统(六)

文章目录 八、首页设计开发1、页面设计2、登录访问拦截实现3、用户基本信息显示①封装用户基本信息获取接口②用户基本信息存储③用户基本信息调用④用户基本信息动态渲染 4、退出功能实现①注册点击事件②添加退出功能③数据清理 5、代码下载 八、首页设计开发 登录成功后,系统就进入了首页。接下来,也就进行首页的开发了。 1、页面设计 系统页面主要分为三部分,左侧为系统的菜单栏,右侧

BERT 论文逐段精读【论文精读】

BERT: 近 3 年 NLP 最火 CV: 大数据集上的训练好的 NN 模型,提升 CV 任务的性能 —— ImageNet 的 CNN 模型 NLP: BERT 简化了 NLP 任务的训练,提升了 NLP 任务的性能 BERT 如何站在巨人的肩膀上的?使用了哪些 NLP 已有的技术和思想?哪些是 BERT 的创新? 1标题 + 作者 BERT: Pre-trainin

8. 自然语言处理中的深度学习:从词向量到BERT

引言 深度学习在自然语言处理(NLP)领域的应用极大地推动了语言理解和生成技术的发展。通过从词向量到预训练模型(如BERT)的演进,NLP技术在机器翻译、情感分析、问答系统等任务中取得了显著成果。本篇博文将探讨深度学习在NLP中的核心技术,包括词向量、序列模型(如RNN、LSTM),以及BERT等预训练模型的崛起及其实际应用。 1. 词向量的生成与应用 词向量(Word Embedding)