bert文本分类微调笔记

2024-06-22 22:04
文章标签 笔记 分类 微调 文本 bert

本文主要是介绍bert文本分类微调笔记,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Bert实现文本分类微调Demo

import random
from collections import namedtuple'''
有四种文本需要做分类,请使用bert处理这个分类问题
'''# 使用namedtuple定义一个类别(Category),包含两个字段:名称(name)和样例(samples)
Category = namedtuple('Category', ['name', 'samples'])# 定义四个不同的类别及其对应的样例文本
categories = [Category('Weather Forecast', ['今天北京晴转多云,气温20-25度。', '明天上海有小雨,记得带伞。']),  # 天气预报类别的样例Category('Company Financial Report', ['本季度公司净利润增长20%。', '年度财务报告显示,成本控制良好。']),  # 公司财报类别的样例Category('Company Audit Materials', ['审计发现内部控制存在漏洞。', '审计确认财务报表无重大错报。']),  # 公司审计材料类别的样例Category('Product Marketing Ad', ['新口味可乐,清爽上市!', '买一送一,仅限今日。'])  # 产品营销广告类别的样例
]def generate_data(num_samples_per_category=50):''' 生成模拟数据集输入:- num_samples_per_category: 每个类别生成的样本数量,默认为50输出:- data: 包含文本样本及其对应类别的列表,每项为一个元组(text, label)'''data = []  # 初始化存储数据的列表for category in categories:  # 遍历所有类别for _ in range(num_samples_per_category):  # 对每个类别生成指定数量的样本sample = random.choice(category.samples)  # 从该类别的样例中随机选择一条文本data.append((sample, category.name))  # 将文本及其类别添加到data列表中return data# 调用generate_data函数生成模拟数据集
train_data = generate_data(100)  # 为每个类别生成100个训练样本
test_data = generate_data(6)     # 生成少量(6个)测试样本用于演示'''
train_data = 
[('明天上海有小雨,记得带伞。', 'Weather Forecast'),('明天上海有小雨,记得带伞。', 'Weather Forecast'),('今天北京晴转多云,气温20-25度。', 'Weather Forecast'),('今天北京晴转多云,气温20-25度。', 'Weather Forecast'),('今天北京晴转多云,气温20-25度。', 'Weather Forecast'),('明天上海有小雨,记得带伞。', 'Weather Forecast'),('明天上海有小雨,记得带伞。', 'Weather Forecast'),('明天上海有小雨,记得带伞。', 'Weather Forecast'),('今天北京晴转多云,气温20-25度。', 'Weather Forecast'),]
'''from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from torch.utils.data import DataLoader, TensorDataset
import torch
import torch.nn.functional as F# 步骤1: 定义类别到标签的映射
label_map = {category.name: index for index, category in enumerate(categories)}
num_labels = len(categories)  # 类别总数# 步骤2: 初始化BERT分词器和模型
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=num_labels)# 步骤3: 准备数据集
def encode_texts(texts, labels):# 对文本进行编码,得到BERT模型需要的输入格式encodings = tokenizer(texts, truncation=True, padding=True, return_tensors='pt')# 将标签名称转换为对应的索引label_ids = torch.tensor([label_map[label] for label in labels])return encodings, label_idsdef prepare_data(data):texts, labels = zip(*data)  # 解压数据encodings, label_ids = encode_texts(texts, labels)  # 编码数据dataset = TensorDataset(encodings['input_ids'], encodings['attention_mask'], label_ids)  # 创建数据集return DataLoader(dataset, batch_size=8, shuffle=True)  # 创建数据加载器# 步骤4: 准备训练和测试数据
train_loader = prepare_data(train_data)
test_loader = prepare_data(test_data)# 步骤5: 定义训练和评估函数
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)def train_epoch(model, data_loader, optimizer):model.train()total_loss = 0for batch in data_loader:optimizer.zero_grad()input_ids, attention_mask, labels = batchinput_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)outputs = model(input_ids, attention_mask=attention_mask, labels=labels)loss = outputs.losstotal_loss += loss.item()loss.backward()optimizer.step()return total_loss / len(data_loader)def evaluate(model, data_loader):model.eval()total_acc = 0total_count = 0with torch.no_grad():for batch in data_loader:input_ids, attention_mask, labels = batchinput_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)outputs = model(input_ids, attention_mask=attention_mask)predictions = torch.argmax(outputs.logits, dim=1)total_acc += (predictions == labels).sum().item()total_count += labels.size(0)return total_acc / total_count# 步骤6: 训练模型
optimizer = AdamW(model.parameters(), lr=2e-5)for epoch in range(3):  # 训练3个epochtrain_loss = train_epoch(model, train_loader, optimizer)acc = evaluate(model, test_loader)print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Test Accuracy: {acc*100:.2f}%')# 步骤7: 使用微调后的模型进行预测
def predict(text):encodings = tokenizer(text, truncation=True, padding=True, return_tensors='pt')input_ids = encodings['input_ids'].to(device)attention_mask = encodings['attention_mask'].to(device)with torch.no_grad():outputs = model(input_ids, attention_mask=attention_mask)predicted_class_id = torch.argmax(outputs.logits).item()return categories[predicted_class_id].name# 预测一个新文本
new_text = ["明天的天气怎么样?"]  # 注意这里是一个列表
predicted_category = predict(new_text)
print(f'The predicted category for the new text is: {predicted_category}')

这篇关于bert文本分类微调笔记的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1085540

相关文章

使用Python实现文本转语音(TTS)并播放音频

《使用Python实现文本转语音(TTS)并播放音频》在开发涉及语音交互或需要语音提示的应用时,文本转语音(TTS)技术是一个非常实用的工具,下面我们来看看如何使用gTTS和playsound库将文本... 目录什么是 gTTS 和 playsound安装依赖库实现步骤 1. 导入库2. 定义文本和语言 3

Python实现常用文本内容提取

《Python实现常用文本内容提取》在日常工作和学习中,我们经常需要从PDF、Word文档中提取文本,本文将介绍如何使用Python编写一个文本内容提取工具,有需要的小伙伴可以参考下... 目录一、引言二、文本内容提取的原理三、文本内容提取的设计四、文本内容提取的实现五、完整代码示例一、引言在日常工作和学

Pytorch微调BERT实现命名实体识别

《Pytorch微调BERT实现命名实体识别》命名实体识别(NER)是自然语言处理(NLP)中的一项关键任务,它涉及识别和分类文本中的关键实体,BERT是一种强大的语言表示模型,在各种NLP任务中显著... 目录环境准备加载预训练BERT模型准备数据集标记与对齐微调 BERT最后总结环境准备在继续之前,确

Java实现将Markdown转换为纯文本

《Java实现将Markdown转换为纯文本》这篇文章主要为大家详细介绍了两种在Java中实现Markdown转纯文本的主流方法,文中的示例代码讲解详细,大家可以根据需求选择适合的方案... 目录方法一:使用正则表达式(轻量级方案)方法二:使用 Flexmark-Java 库(专业方案)1. 添加依赖(Ma

Linux使用cut进行文本提取的操作方法

《Linux使用cut进行文本提取的操作方法》Linux中的cut命令是一个命令行实用程序,用于从文件或标准输入中提取文本行的部分,本文给大家介绍了Linux使用cut进行文本提取的操作方法,文中有详... 目录简介基础语法常用选项范围选择示例用法-f:字段选择-d:分隔符-c:字符选择-b:字节选择--c

C#使用DeepSeek API实现自然语言处理,文本分类和情感分析

《C#使用DeepSeekAPI实现自然语言处理,文本分类和情感分析》在C#中使用DeepSeekAPI可以实现多种功能,例如自然语言处理、文本分类、情感分析等,本文主要为大家介绍了具体实现步骤,... 目录准备工作文本生成文本分类问答系统代码生成翻译功能文本摘要文本校对图像描述生成总结在C#中使用Deep

通过C#获取PDF中指定文本或所有文本的字体信息

《通过C#获取PDF中指定文本或所有文本的字体信息》在设计和出版行业中,字体的选择和使用对最终作品的质量有着重要影响,然而,有时我们可能会遇到包含未知字体的PDF文件,这使得我们无法准确地复制或修改文... 目录引言C# 获取PDF中指定文本的字体信息C# 获取PDF文档中用到的所有字体信息引言在设计和出

Java操作xls替换文本或图片的功能实现

《Java操作xls替换文本或图片的功能实现》这篇文章主要给大家介绍了关于Java操作xls替换文本或图片功能实现的相关资料,文中通过示例代码讲解了文件上传、文件处理和Excel文件生成,需要的朋友可... 目录准备xls模板文件:template.xls准备需要替换的图片和数据功能实现包声明与导入类声明与

python解析HTML并提取span标签中的文本

《python解析HTML并提取span标签中的文本》在网页开发和数据抓取过程中,我们经常需要从HTML页面中提取信息,尤其是span元素中的文本,span标签是一个行内元素,通常用于包装一小段文本或... 目录一、安装相关依赖二、html 页面结构三、使用 BeautifulSoup javascript

基于人工智能的图像分类系统

目录 引言项目背景环境准备 硬件要求软件安装与配置系统设计 系统架构关键技术代码示例 数据预处理模型训练模型预测应用场景结论 1. 引言 图像分类是计算机视觉中的一个重要任务,目标是自动识别图像中的对象类别。通过卷积神经网络(CNN)等深度学习技术,我们可以构建高效的图像分类系统,广泛应用于自动驾驶、医疗影像诊断、监控分析等领域。本文将介绍如何构建一个基于人工智能的图像分类系统,包括环境