Sentence-BERT实现文本匹配【对比损失函数】

2024-09-05 08:28

本文主要是介绍Sentence-BERT实现文本匹配【对比损失函数】,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

引言

还是基于Sentence-BERT架构,或者说Bi-Encoder架构,但是本文使用的是参考2中提出的对比损失函数。

架构

image-20210923000654664

如上图,计算两个句嵌入 u \pmb u u v \pmb v v​之间的距离(1-余弦相似度),然后使用参考2中提出的对比损失函数作为目标函数:
L = y × 1 2 ( distance ( u , v ) ) 2 + ( 1 − y ) × 1 2 { max ⁡ ( 0 , m − distance ( u , v ) ) } 2 \mathcal L= y\times \frac{1}{2} (\text{distance}(\pmb u,\pmb v))^2 + (1-y)\times \frac{1}{2} \{ \max(0, m - \text{distance}(\pmb u,\pmb v)) \}^2\\ L=y×21(distance(u,v))2+(1y)×21{max(0,mdistance(u,v))}2
这里的 y y y是真实标签,相似为1,不相似为0; m m m​表示margin(间隔值),默认为0.5。

这里 m m m的意思是,如果 u \pmb u u v \pmb v v不相似( y = 0 y=0 y=0),那么它们之间的距离只要足够大,大于等于间隔值0.5就好了。假设距离为0.6,那么 max ⁡ ( 0 , 0.5 − 0.6 ) = 0 \max(0,0.5-0.6)=0 max(0,0.50.6)=0,如果距离不够大( 0.2 0.2 0.2),那么 max ⁡ ( 0 , 0.5 − 0.2 ) = 0.3 \max(0,0.5-0.2)=0.3 max(0,0.50.2)=0.3,就会产生损失值。

整个公式的目的是拉近相似的文本对,推远不相似的文本对到一定程度就可以了。实现的时候 max ⁡ \max max可以用relu来表示。

实现

实现采用类似Huggingface的形式,每个文件夹下面有一种模型。分为modelingargumentstrainer等不同的文件。不同的架构放置在不同的文件夹内。

modeling.py:

from dataclasses import dataclassimport torch
from torch import Tensor, nnfrom transformers.file_utils import ModelOutputfrom transformers import (AutoModel,AutoTokenizer,
)import numpy as np
from tqdm.autonotebook import trange
from typing import Optionalfrom enum import Enum
import torch.nn.functional as F# 定义了三种距离函数
# 余弦相似度值越小表示越不相似,1减去它就变成了距离函数,越小(余弦越接近1)表示越相似。
class SiameseDistanceMetric(Enum):"""The metric for the contrastive loss"""EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2)MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1)COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y)@dataclass
class BiOutput(ModelOutput):loss: Optional[Tensor] = Nonescores: Optional[Tensor] = Noneclass SentenceBert(nn.Module):def __init__(self,model_name: str,trust_remote_code: bool = True,max_length: int = None,margin: float = 0.5,distance_metric=SiameseDistanceMetric.COSINE_DISTANCE,pooling_mode: str = "mean",normalize_embeddings: bool = False,) -> None:super().__init__()self.model_name = model_nameself.normalize_embeddings = normalize_embeddingsself.device = "cuda" if torch.cuda.is_available() else "cpu"self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)self.model = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code).to(self.device)self.max_length = max_lengthself.pooling_mode = pooling_modeself.distance_metric = distance_metricself.margin = margindef sentence_embedding(self, last_hidden_state, attention_mask):if self.pooling_mode == "mean":attention_mask = attention_mask.unsqueeze(-1).float()return torch.sum(last_hidden_state * attention_mask, dim=1) / torch.clamp(attention_mask.sum(1), min=1e-9)else:# clsreturn last_hidden_state[:, 0]def encode(self,sentences: str | list[str],batch_size: int = 64,convert_to_tensor: bool = True,show_progress_bar: bool = False,):if isinstance(sentences, str):sentences = [sentences]all_embeddings = []for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):batch = sentences[start_index : start_index + batch_size]features = self.tokenizer(batch,padding=True,truncation=True,return_tensors="pt",return_attention_mask=True,max_length=self.max_length,).to(self.device)out_features = self.model(**features, return_dict=True)embeddings = self.sentence_embedding(out_features.last_hidden_state, features["attention_mask"])if not self.training:embeddings = embeddings.detach()if self.normalize_embeddings:embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)if not convert_to_tensor:embeddings = embeddings.cpu()all_embeddings.extend(embeddings)if convert_to_tensor:all_embeddings = torch.stack(all_embeddings)else:all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])return all_embeddingsdef compute_loss(self, source_embed, target_embed, labels):labels = torch.tensor(labels).float().to(self.device)# 计算距离distances = self.distance_metric(source_embed, target_embed)# 实现损失函数loss = 0.5 * (labels * distances.pow(2)+ (1 - labels) * F.relu(self.margin - distances).pow(2))return loss.mean()def forward(self, source, target, labels) -> BiOutput:"""Args:source :target :"""source_embed = self.encode(source)target_embed = self.encode(target)loss = self.compute_loss(source_embed, target_embed, labels)return BiOutput(loss, None)def save_pretrained(self, output_dir: str):state_dict = self.model.state_dict()state_dict = type(state_dict)({k: v.clone().cpu().contiguous() for k, v in state_dict.items()})self.model.save_pretrained(output_dir, state_dict=state_dict)

整个模型的实现放到modeling.py文件中。

arguments.py:

from dataclasses import dataclass, field
from typing import Optionalimport os@dataclass
class ModelArguments:model_name_or_path: str = field(metadata={"help": "Path to pretrained model"})config_name: Optional[str] = field(default=None,metadata={"help": "Pretrained config name or path if not the same as model_name"},)tokenizer_name: Optional[str] = field(default=None,metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},)@dataclass
class DataArguments:train_data_path: str = field(default=None, metadata={"help": "Path to train corpus"})eval_data_path: str = field(default=None, metadata={"help": "Path to eval corpus"})max_length: int = field(default=512,metadata={"help": "The maximum total input sequence length after tokenization for input text."},)def __post_init__(self):if not os.path.exists(self.train_data_path):raise FileNotFoundError(f"cannot find file: {self.train_data_path}, please set a true path")if not os.path.exists(self.eval_data_path):raise FileNotFoundError(f"cannot find file: {self.eval_data_path}, please set a true path")

定义了模型和数据相关参数。

dataset.py:

from torch.utils.data import Dataset
from datasets import Dataset as dt
import pandas as pdfrom utils import build_dataframe_from_csvclass PairDataset(Dataset):def __init__(self, data_path: str) -> None:df = build_dataframe_from_csv(data_path)self.dataset = dt.from_pandas(df, split="train")self.total_len = len(self.dataset)def __len__(self):return self.total_lendef __getitem__(self, index) -> dict[str, str]:query1 = self.dataset[index]["query1"]query2 = self.dataset[index]["query2"]label = self.dataset[index]["label"]return {"query1": query1, "query2": query2, "label": label}class PairCollator:def __call__(self, features) -> dict[str, list[str]]:queries1 = []queries2 = []labels = []for feature in features:queries1.append(feature["query1"])queries2.append(feature["query2"])labels.append(feature["label"])return {"source": queries1, "target": queries2, "labels": labels}

数据集类考虑了LCQMC数据集的格式,即成对的语句和一个数值标签。类似:

Hello.	Hi.	1
Nice to see you.	Nice	0

trainer.py:

import torch
from transformers.trainer import Trainerfrom typing import Optional
import os
import loggingfrom modeling import SentenceBertTRAINING_ARGS_NAME = "training_args.bin"
logger = logging.getLogger(__name__)class BiTrainer(Trainer):def compute_loss(self, model: SentenceBert, inputs, return_outputs=False):outputs = model(**inputs)loss = outputs.lossreturn (loss, outputs) if return_outputs else lossdef _save(self, output_dir: Optional[str] = None, state_dict=None):# If we are executing this function, we are the process zero, so we don't check for that.output_dir = output_dir if output_dir is not None else self.args.output_diros.makedirs(output_dir, exist_ok=True)logger.info(f"Saving model checkpoint to {output_dir}")self.model.save_pretrained(output_dir)if self.tokenizer is not None:self.tokenizer.save_pretrained(output_dir)# Good practice: save your training arguments together with the trained modeltorch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

继承🤗 Transformers的Trainer类,重写了compute_loss_save方法。

这样我们就可以利用🤗 Transformers来训练我们的模型了。

utils.py:

import torch
import pandas as pd
from scipy.stats import pearsonr, spearmanr
from typing import Tupledef build_dataframe_from_csv(dataset_csv: str) -> pd.DataFrame:df = pd.read_csv(dataset_csv,sep="\t",header=None,names=["query1", "query2", "label"],)return dfdef compute_spearmanr(x, y):return spearmanr(x, y).correlationdef compute_pearsonr(x, y):return pearsonr(x, y)[0]def find_best_acc_and_threshold(scores, labels, high_score_more_similar: bool):"""Copied from https://github.com/UKPLab/sentence-transformers/tree/master"""assert len(scores) == len(labels)rows = list(zip(scores, labels))rows = sorted(rows, key=lambda x: x[0], reverse=high_score_more_similar)print(rows)max_acc = 0best_threshold = -1# positive examples number so farpositive_so_far = 0# remain negative examplesremaining_negatives = sum(labels == 0)for i in range(len(rows) - 1):score, label = rows[i]if label == 1:positive_so_far += 1else:remaining_negatives -= 1acc = (positive_so_far + remaining_negatives) / len(labels)if acc > max_acc:max_acc = accbest_threshold = (rows[i][0] + rows[i + 1][0]) / 2return max_acc, best_thresholddef metrics(y: torch.Tensor, y_pred: torch.Tensor) -> Tuple[float, float, float, float]:TP = ((y_pred == 1) & (y == 1)).sum().float()  # True PositiveTN = ((y_pred == 0) & (y == 0)).sum().float()  # True NegativeFN = ((y_pred == 0) & (y == 1)).sum().float()  # False NegatvieFP = ((y_pred == 1) & (y == 0)).sum().float()  # False Positivep = TP / (TP + FP).clamp(min=1e-8)  # Precisionr = TP / (TP + FN).clamp(min=1e-8)  # RecallF1 = 2 * r * p / (r + p).clamp(min=1e-8)  # F1 scoreacc = (TP + TN) / (TP + TN + FP + FN).clamp(min=1e-8)  # Accuraryreturn acc, p, r, F1def compute_metrics(predicts, labels):return metrics(labels, predicts)

定义了一些帮助函数,从sentence-transformers库中拷贝了寻找最佳准确率阈值的实现find_best_acc_and_threshold

除了准确率,还计算了句嵌入的余弦相似度与真实标签之间的斯皮尔曼等级相关系数指标。

最后定义训练和测试脚本。

train.py:

from transformers import set_seed, HfArgumentParser, TrainingArgumentsimport logging
from pathlib import Pathfrom datetime import datetimefrom modeling import SentenceBert
from trainer import BiTrainer
from arguments import DataArguments, ModelArguments
from dataset import PairCollator, PairDatasetlogger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",datefmt="%m/%d/%Y %H:%M:%S",level=logging.INFO,
)def main():parser = HfArgumentParser((TrainingArguments, DataArguments, ModelArguments))training_args, data_args, model_args = parser.parse_args_into_dataclasses()# 根据当前时间生成输出目录output_dir = f"{training_args.output_dir}/{model_args.model_name_or_path.replace('/', '-')}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"training_args.output_dir = output_dirlogger.info(f"Training parameters {training_args}")logger.info(f"Data parameters {data_args}")logger.info(f"Model parameters {model_args}")# 设置随机种子set_seed(training_args.seed)# 加载预训练模型model = SentenceBert(model_args.model_name_or_path,trust_remote_code=True,max_length=data_args.max_length,)tokenizer = model.tokenizer# 构建训练和测试集train_dataset = PairDataset(data_args.train_data_path)eval_dataset = PairDataset(data_args.eval_data_path)# 传入参数trainer = BiTrainer(model=model,args=training_args,train_dataset=train_dataset,eval_dataset=eval_dataset,data_collator=PairCollator(),tokenizer=tokenizer,)Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)# 开始训练trainer.train()trainer.save_model()if __name__ == "__main__":main()

训练

基于train.py定义了train.sh传入相关参数:

timestamp=$(date +%Y%m%d%H%M)
logfile="train_${timestamp}.log"# change CUDA_VISIBLE_DEVICES
CUDA_VISIBLE_DEVICES=3 nohup python train.py \--model_name_or_path=hfl/chinese-macbert-large \--output_dir=output \--train_data_path=data/train.txt \--eval_data_path=data/dev.txt \--num_train_epochs=3 \--save_total_limit=5 \--learning_rate=2e-5 \--weight_decay=0.01 \--warmup_ratio=0.01 \--bf16=True \--eval_strategy=epoch \--save_strategy=epoch \--per_device_train_batch_size=64 \--report_to="none" \--remove_unused_columns=False \--max_length=128 \> "$logfile" 2>&1 &

以上参数根据个人环境修改,这里使用的是哈工大的chinese-macbert-large预训练模型。

注意:

  • --remove_unused_columns是必须的。
  • 通过bf16=True可以加速训练同时不影响效果。
  • 其他参数可以自己调整。
100%|██████████| 11193/11193 [47:15<00:00,  4.26it/09/03/2024 18:35:18 - INFO - trainer - Saving model checkpoint to output/hfl-chinese-macbert-large-2024-09-03_17-47-58/checkpoint-11193
100%|██████████| 11193/11193 [47:29<00:00,  3.93it/s]
09/03/2024 18:35:32 - INFO - trainer - Saving model checkpoint to output/hfl-chinese-macbert-large-2024-09-03_17-47-58
{'eval_loss': 0.010313387028872967, 'eval_runtime': 58.5945, 'eval_samples_per_second': 150.219, 'eval_steps_per_second': 18.79, 'epoch': 3.0}
{'train_runtime': 2849.8189, 'train_samples_per_second': 251.349, 'train_steps_per_second': 3.928, 'train_loss': 0.006717115574075615, 'epoch': 3.0}

这里仅训练了3轮,我们拿最后保存的模型output/hfl-chinese-macbert-large-2024-09-03_17-47-58进行测试。

测试

test.py: 测试脚本见后文的完整代码。

test.sh:

# change CUDA_VISIBLE_DEVICES
CUDA_VISIBLE_DEVICES=0 python test.py \--model_name_or_path=output/hfl-chinese-macbert-large-2024-09-03_17-47-58 \--test_data_path=data/test.txt

输出:

TestArguments(model_name_or_path='output/hfl-chinese-macbert-large-2024-09-03_17-47-58/checkpoint-11193', test_data_path='data/test.txt', max_length=64, batch_size=128)
Batches: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:11<00:00,  8.76it/s]
Batches: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:11<00:00,  8.85it/s]
max_acc: 0.8944, best_threshold: 0.899751
spearman corr: 0.7950 |  pearson_corr corr: 0.7434 | compute time: 22.30s
accuracy=0.894 precision=0.899 recal=0.889 f1 score=0.8938

测试集上的准确率达到89.4%,达到了目前本系列文章的SOTA结果。spearman系数也是最佳的。

这是默认基于余弦距离训练的,修改成欧几里得距离(EUCLIDEAN),其他参数不变的情况下得到结果:

TestArguments(model_name_or_path='output/hfl-chinese-macbert-large-2024-09-03_17-50-30/checkpoint-11193', test_data_path='data/test.txt', max_length=64, batch_size=128)
Batches: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:11<00:00,  8.79it/s]
Batches: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:11<00:00,  8.85it/s]
max_acc: 0.8881, best_threshold: 0.999987
spearman corr: 0.7795 |  pearson_corr corr: 0.5701 | compute time: 22.26s
accuracy=0.888 precision=0.890 recal=0.885 f1 score=0.8876

准确率是88.8%,这倒没什么, 主要是最佳阈值0.999987太不合理了。

完整代码

完整代码: →点此←

本文代码是和某次提交相关的,Master分支上的代码随时可能会被优化。

参考

  1. [论文笔记]Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks
  2. [论文笔记]Dimensionality Reduction by Learning an Invariant Mapping

这篇关于Sentence-BERT实现文本匹配【对比损失函数】的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1138434

相关文章

hdu1043(八数码问题,广搜 + hash(实现状态压缩) )

利用康拓展开将一个排列映射成一个自然数,然后就变成了普通的广搜题。 #include<iostream>#include<algorithm>#include<string>#include<stack>#include<queue>#include<map>#include<stdio.h>#include<stdlib.h>#include<ctype.h>#inclu

hdu1171(母函数或多重背包)

题意:把物品分成两份,使得价值最接近 可以用背包,或者是母函数来解,母函数(1 + x^v+x^2v+.....+x^num*v)(1 + x^v+x^2v+.....+x^num*v)(1 + x^v+x^2v+.....+x^num*v) 其中指数为价值,每一项的数目为(该物品数+1)个 代码如下: #include<iostream>#include<algorithm>

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

【Prometheus】PromQL向量匹配实现不同标签的向量数据进行运算

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全栈,前后端开发,小程序开发,人工智能,js逆向,App逆向,网络系统安全,数据分析,Django,fastapi

让树莓派智能语音助手实现定时提醒功能

最初的时候是想直接在rasa 的chatbot上实现,因为rasa本身是带有remindschedule模块的。不过经过一番折腾后,忽然发现,chatbot上实现的定时,语音助手不一定会有响应。因为,我目前语音助手的代码设置了长时间无应答会结束对话,这样一来,chatbot定时提醒的触发就不会被语音助手获悉。那怎么让语音助手也具有定时提醒功能呢? 我最后选择的方法是用threading.Time

Android实现任意版本设置默认的锁屏壁纸和桌面壁纸(两张壁纸可不一致)

客户有些需求需要设置默认壁纸和锁屏壁纸  在默认情况下 这两个壁纸是相同的  如果需要默认的锁屏壁纸和桌面壁纸不一样 需要额外修改 Android13实现 替换默认桌面壁纸: 将图片文件替换frameworks/base/core/res/res/drawable-nodpi/default_wallpaper.*  (注意不能是bmp格式) 替换默认锁屏壁纸: 将图片资源放入vendo

C#实战|大乐透选号器[6]:实现实时显示已选择的红蓝球数量

哈喽,你好啊,我是雷工。 关于大乐透选号器在前面已经记录了5篇笔记,这是第6篇; 接下来实现实时显示当前选中红球数量,蓝球数量; 以下为练习笔记。 01 效果演示 当选择和取消选择红球或蓝球时,在对应的位置显示实时已选择的红球、蓝球的数量; 02 标签名称 分别设置Label标签名称为:lblRedCount、lblBlueCount

Kubernetes PodSecurityPolicy:PSP能实现的5种主要安全策略

Kubernetes PodSecurityPolicy:PSP能实现的5种主要安全策略 1. 特权模式限制2. 宿主机资源隔离3. 用户和组管理4. 权限提升控制5. SELinux配置 💖The Begin💖点点关注,收藏不迷路💖 Kubernetes的PodSecurityPolicy(PSP)是一个关键的安全特性,它在Pod创建之前实施安全策略,确保P

hdu 3065 AC自动机 匹配串编号以及出现次数

题意: 仍旧是天朝语题。 Input 第一行,一个整数N(1<=N<=1000),表示病毒特征码的个数。 接下来N行,每行表示一个病毒特征码,特征码字符串长度在1—50之间,并且只包含“英文大写字符”。任意两个病毒特征码,不会完全相同。 在这之后一行,表示“万恶之源”网站源码,源码字符串长度在2000000之内。字符串中字符都是ASCII码可见字符(不包括回车)。

工厂ERP管理系统实现源码(JAVA)

工厂进销存管理系统是一个集采购管理、仓库管理、生产管理和销售管理于一体的综合解决方案。该系统旨在帮助企业优化流程、提高效率、降低成本,并实时掌握各环节的运营状况。 在采购管理方面,系统能够处理采购订单、供应商管理和采购入库等流程,确保采购过程的透明和高效。仓库管理方面,实现库存的精准管理,包括入库、出库、盘点等操作,确保库存数据的准确性和实时性。 生产管理模块则涵盖了生产计划制定、物料需求计划、