基于飞浆NLP的BERT-finetuning新闻文本分类

2023-11-07 16:12

本文主要是介绍基于飞浆NLP的BERT-finetuning新闻文本分类,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

1.数据预处理

2.加载模型

3.批训练

4.准确率

1.数据预处理

导入所需库

import numpy as np
from paddle.io import DataLoader,TensorDataset
from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer
from sklearn.model_selection import train_test_split
import paddle
import matplotlib.pyplot as plt
import jieba

训练集格式 标签ID+\t+标签+\t+原文标题

contents=[]
datas=[]
labels=[]
with open('data/data126283/data/Train.txt',mode='r',encoding='utf-8') as f:contents=f.read().split('\n')
for item in contents:if item=='':continuelabels.append(item.split('\t')[0])datas.append(remove_stopwords(jieba.cut(item.split('\t')[-1])))datas=convert(datas)

去除停用词、

stop=[]
with open('stop.txt',mode='r',encoding='utf-8') as f:stop=f.read().split('\n')
stop_word={}
for s in stop:stop_word[s]=True
def remove_stopwords(datas):  filtered_words = [text for text in datas if text not in stop_word]return ' '.join(filtered_words)  

进行中文分词、转换为token序列

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')def convert(datas, max_seq_length=40):ans=[]for text in datas:input_ids = tokenizer(text, max_seq_len=max_seq_length)['input_ids']input_ids = input_ids[:max_seq_length]  # 截断input_ids = input_ids + [tokenizer.pad_token_id] * (max_seq_length - len(input_ids))  # 填充ans.append(input_ids)return ans

导入数据,进行预处理,数据集在最后

contents=[]
datas=[]
labels=[]
with open('data/data126283/data/Train.txt',mode='r',encoding='utf-8') as f:contents=f.read().split('\n')
for item in contents:if item=='':continuelabels.append(item.split('\t')[0])datas.append(remove_stopwords(jieba.cut(item.split('\t')[-1])))datas=convert(datas)

 

2.加载模型 

加载预训练模型,冻结大部分参数
model = BertForSequenceClassification.from_pretrained('bert-base-chinese')
model.classifier = paddle.nn.Linear(768, 14)
for name, param in model.named_parameters():if "classifier" not in name and 'bert.pooler.dense' not in name and 'bert.encoder.layers.11' not in name:param.stop_gradient = True

ps:如果只保留classifier用来训练,效果欠佳。

设置超参数,学习率初始设为0.01~0.1

epochs=2
batch_size=1024*4
learning_rate=0.001

损失函数和优化器

criterion = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=learning_rate, parameters=model.parameters())

3.批训练

划分训练集和测试集

datas=np.array(datas)
labels=np.array(labels)
x_train,x_test,y_train,y_test=train_test_split(datas,labels,random_state=42,test_size=0.2)
train_dataset=TensorDataset([x_train,y_train])
train_loader=DataLoader(train_dataset,shuffle=True,batch_size=batch_size)

迭代分批训练,可视化损失函数

total_loss=[]
for epoch in range(epochs):for batch_data,batch_label in train_loader:batch_label=paddle.to_tensor(batch_label,dtype='int64')batch_data=paddle.to_tensor(batch_data,dtype='int64')outputs=model(batch_data)loss=criterion(outputs,batch_label)print(epoch,loss.numpy()[0])total_loss.append(loss.numpy()[0])optimizer.clear_grad()loss.backward()optimizer.step()
paddle.save({'model':model.state_dict()},'model.param')
paddle.save({'optimizer':optimizer.state_dict()},'optimizer.param')
plt.plot(range(len(total_loss)),total_loss)
plt.show()

4.准确率

在测试集上如法炮制,查看准确率

total_loss=[]
x_test=np.array(x_test)
y_test=np.array(y_test)
test_dataset=TensorDataset([x_test,y_test])
test_loader=DataLoader(test_dataset,shuffle=True,batch_size=batch_size)with paddle.no_grad():for batch_data,batch_label in test_loader:batch_label=paddle.to_tensor(batch_label,dtype='int64')batch_data=paddle.to_tensor(batch_data,dtype='int64')outputs=model(batch_data)loss=criterion(outputs,batch_label)print(loss)outputs=paddle.argmax(outputs,axis=1)total_loss.append(loss.numpy()[0])score=0for predict,label in zip(outputs,batch_label):if predict==label:score+=1print(score/len(batch_label))
plt.plot(range(len(total_loss)),total_loss)
plt.show()

最后在验证集上输出要求的类别

arr=['财经','彩票','房产','股票','家居','教育','科技','社会','时尚','时政','体育','星座','游戏','娱乐']
evals=[]
contetns=[]
with open('data/data126283/data/Test.txt',mode='r',encoding='utf-8') as f:contents=f.read().split('\n')
for item in contents:if item=='':continueevals.append(item)
evals=convert(evals)
evals=np.array(evals)
with paddle.no_grad():for i in range(0,len(evals),2048):i=min(len(evals),i)batch_data=evals[i:i+2048]batch_data=paddle.to_tensor(batch_data,dtype='int64')predict=model(batch_data)predict=list(paddle.argmax(predict,axis=1))print(i,len(predict))for j in range(len(predict)):predict[j]=arr[predict[j]]with open('result.txt',mode='a',encoding='utf-8') as f:f.write('\n'.join(predict))f.write('\n')

ps:注意最后的f.write('\n'),否则除第一次,每次打印少一行,很坑

最后损失函数收敛在0.2或0.1左右比较正常,四舍五入差不多90准确率,当然如果你解冻更多参数,自然可以更加精确,看运行环境的配置了,建议不要使用免费平台配置,否则比乌龟还慢。。

欢迎提出问题

数据集

这篇关于基于飞浆NLP的BERT-finetuning新闻文本分类的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/364782

相关文章

使用Python实现文本转语音(TTS)并播放音频

《使用Python实现文本转语音(TTS)并播放音频》在开发涉及语音交互或需要语音提示的应用时,文本转语音(TTS)技术是一个非常实用的工具,下面我们来看看如何使用gTTS和playsound库将文本... 目录什么是 gTTS 和 playsound安装依赖库实现步骤 1. 导入库2. 定义文本和语言 3

Python实现常用文本内容提取

《Python实现常用文本内容提取》在日常工作和学习中,我们经常需要从PDF、Word文档中提取文本,本文将介绍如何使用Python编写一个文本内容提取工具,有需要的小伙伴可以参考下... 目录一、引言二、文本内容提取的原理三、文本内容提取的设计四、文本内容提取的实现五、完整代码示例一、引言在日常工作和学

Pytorch微调BERT实现命名实体识别

《Pytorch微调BERT实现命名实体识别》命名实体识别(NER)是自然语言处理(NLP)中的一项关键任务,它涉及识别和分类文本中的关键实体,BERT是一种强大的语言表示模型,在各种NLP任务中显著... 目录环境准备加载预训练BERT模型准备数据集标记与对齐微调 BERT最后总结环境准备在继续之前,确

Java实现将Markdown转换为纯文本

《Java实现将Markdown转换为纯文本》这篇文章主要为大家详细介绍了两种在Java中实现Markdown转纯文本的主流方法,文中的示例代码讲解详细,大家可以根据需求选择适合的方案... 目录方法一:使用正则表达式(轻量级方案)方法二:使用 Flexmark-Java 库(专业方案)1. 添加依赖(Ma

Linux使用cut进行文本提取的操作方法

《Linux使用cut进行文本提取的操作方法》Linux中的cut命令是一个命令行实用程序,用于从文件或标准输入中提取文本行的部分,本文给大家介绍了Linux使用cut进行文本提取的操作方法,文中有详... 目录简介基础语法常用选项范围选择示例用法-f:字段选择-d:分隔符-c:字符选择-b:字节选择--c

C#使用DeepSeek API实现自然语言处理,文本分类和情感分析

《C#使用DeepSeekAPI实现自然语言处理,文本分类和情感分析》在C#中使用DeepSeekAPI可以实现多种功能,例如自然语言处理、文本分类、情感分析等,本文主要为大家介绍了具体实现步骤,... 目录准备工作文本生成文本分类问答系统代码生成翻译功能文本摘要文本校对图像描述生成总结在C#中使用Deep

通过C#获取PDF中指定文本或所有文本的字体信息

《通过C#获取PDF中指定文本或所有文本的字体信息》在设计和出版行业中,字体的选择和使用对最终作品的质量有着重要影响,然而,有时我们可能会遇到包含未知字体的PDF文件,这使得我们无法准确地复制或修改文... 目录引言C# 获取PDF中指定文本的字体信息C# 获取PDF文档中用到的所有字体信息引言在设计和出

Python实现NLP的完整流程介绍

《Python实现NLP的完整流程介绍》这篇文章主要为大家详细介绍了Python实现NLP的完整流程,文中的示例代码讲解详细,具有一定的借鉴价值,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1. 编程安装和导入必要的库2. 文本数据准备3. 文本预处理3.1 小写化3.2 分词(Tokenizatio

Java操作xls替换文本或图片的功能实现

《Java操作xls替换文本或图片的功能实现》这篇文章主要给大家介绍了关于Java操作xls替换文本或图片功能实现的相关资料,文中通过示例代码讲解了文件上传、文件处理和Excel文件生成,需要的朋友可... 目录准备xls模板文件:template.xls准备需要替换的图片和数据功能实现包声明与导入类声明与

python解析HTML并提取span标签中的文本

《python解析HTML并提取span标签中的文本》在网页开发和数据抓取过程中,我们经常需要从HTML页面中提取信息,尤其是span元素中的文本,span标签是一个行内元素,通常用于包装一小段文本或... 目录一、安装相关依赖二、html 页面结构三、使用 BeautifulSoup javascript