大模型之二十八-语音识别Whisper进阶

2024-08-30 08:12

本文主要是介绍大模型之二十八-语音识别Whisper进阶,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在上一篇博客大模型之二十七-语音识别Whisper实例浅析中遗留了几个问题,这里来看一下前两个问题。
1.如果不是Huggingface上可以下载的数据该怎么办?
2.上面的代码是可以训练了,但是训练的时候loss真的会和我们预期一致吗?比如如下怎么办?

进阶内容

在Whisper语音识别fine-tune的例子中,我们使用的是Huggingface封装好的数据加载以及Transformer工具,这将很多底层细节对开发人员屏蔽了,但是对于技术人员而言,这还远远不够,本篇通过一个要解决两个问题:
1.数据集是私有的,并不是Huggingface开源的数据集
2.不使用Huggingface封装好的Training pipeline,在Whisper开源的源代码基础之上fine-tune模型,并验证准确性。

整个框架代码使用pytorch-lightning来实现,目前很多优秀的比较大的开源都是实用pytorch-lightning来实现的。

安装一些python库

首先下载Whisper源代码,并且

! pip install git+https://github.com/openai/whisper.git
! pip install jiwer 
! pip install pytorch-lightning==2.4.0
! pip install -qqq evaluate==0.2.2

导入必要的python包

import os
import glob
import numpy as nptry:import tensorflow  # required in Colab to avoid protobuf compatibility issues
except ImportError:passimport torch
from torch import nn
import pandas as pd
import whisper
import torchaudio
import torchaudio.transforms as atfrom pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLoggerfrom tqdm.notebook import tqdm
import evaluatefrom transformers import (AdamW,get_linear_schedule_with_warmup
)

遗留的第一个问题–数据集

这里的数据集基于清华大学开源的30小时中文照着文本读而录的音频,原下载地址,
为了减小资源的开销,在有限的资源下,多迭代epoch,这里对数据集做了处理:

  • 将数据集缩到了10个小时/30小时,
  • 去掉了txt里音素的标注,只留文本,因为在这个数据集开源的时候,那时语音识别系统还是基于音素的。

可以关注私信我,联系索取处理之后的语料。

数据集处理

import globDATASET_DIR = "/kaggle/input/th30-all"
SAMPLE_RATE = 16000
BATCH_SIZE = 4
TRAIN_RATE = 0.85#whipser的输入是30s,16kHz采样率,最长480000 sample
AUDIO_MAX_LENGTH = 480000
TEXT_MAX_LENGTH = 120DEVICE = "gpu" if torch.cuda.is_available() else "cpu"###################### 读取数据信息并分离出train和val
dataset_dir = DATASET_DIR
transcripts_path_list = glob.glob(os.path.join(dataset_dir, "*.txt"))
print(len(transcripts_path_list))
13388

读取数据信息并分离出train和val

dataset_dir = DATASET_DIR
transcripts_path_list = glob.glob(os.path.join(dataset_dir, "*.txt"))
print(len(transcripts_path_list))def load_wave(wave_path, sample_rate:int=16000) -> torch.Tensor:waveform, sr = torchaudio.load(wave_path, normalize=True)if sample_rate != sr:waveform = at.Resample(sr, sample_rate)(waveform)return waveformdef get_audio_file_list(transcripts_path_list, text_max_length=120, audio_max_sample_length=480000, sample_rate=16000):audio_transcript_pair_list = []for transcripts_path in tqdm(transcripts_path_list):# audio文件目录确认audio_dir = os.path.dirname(transcripts_path)# 从翻译文本获取音频和文本with open(transcripts_path, "r") as f:text_list = f.readlines()for text in text_list:audio_id, text = text.replace("\n", "").split(":")#print(audio_id, text)audio_path = os.path.join(audio_dir, f"{audio_id}.wav")if os.path.exists(audio_path):# 检查数据audio = load_wave(audio_path, sample_rate=sample_rate)[0]if len(text) > text_max_length or len(audio) > audio_max_sample_length:print(len(text), len(audio))continueaudio_transcript_pair_list.append((audio_id, str(audio_path), text))return audio_transcript_pair_listtrain_num = int(len(transcripts_path_list) * TRAIN_RATE)
train_transcripts_path_list, eval_transcripts_path_list = transcripts_path_list[:train_num], transcripts_path_list[train_num:]
train_audio_transcript_pair_list = get_audio_file_list(train_transcripts_path_list, TEXT_MAX_LENGTH, AUDIO_MAX_LENGTH, SAMPLE_RATE)
eval_audio_transcript_pair_list = get_audio_file_list(eval_transcripts_path_list, TEXT_MAX_LENGTH, AUDIO_MAX_LENGTH, SAMPLE_RATE)
print("TRAIN AUDIO DATASET NUM: ", len(train_audio_transcript_pair_list))
print("EVAL AUDIO DATASET NUM: ", len(eval_audio_transcript_pair_list))
133880%|          | 0/11379 [00:00<?, ?it/s]  0%|          | 0/2009 [00:00<?, ?it/s]
TRAIN AUDIO DATASET NUM:  11379
EVAL AUDIO DATASET NUM:  2009

Data loader

woptions = whisper.DecodingOptions(language="zh", without_timestamps=True)
wmodel = whisper.load_model(name="small",download_root="./whisper-small")
wtokenizer = whisper.tokenizer.get_tokenizer(True, language="zh", task=woptions.task)class Th30Dataset(torch.utils.data.Dataset):def __init__(self, audio_info_list, tokenizer, sample_rate) -> None:super().__init__()self.audio_info_list = audio_info_listself.sample_rate = sample_rateself.tokenizer = tokenizerdef __len__(self):return len(self.audio_info_list)def __getitem__(self, index):audio_id, audio_path, text = self.audio_info_list[index]#aduio monoaudio = load_wave(audio_path, sample_rate=self.sample_rate)audio = whisper.pad_or_trim(audio.flatten(), AUDIO_MAX_LENGTH)mel = whisper.log_mel_spectrogram(audio)#texttext = [*self.tokenizer.sot_sequence_including_notimestamps] + self.tokenizer.encode(text)labels = text[1:] + [self.tokenizer.eot]return {"input_ids": mel,"labels": labels,"dec_input_ids": text}
class WhisperDataCollatorWhithPadding:def __call__(self, features):input_ids, labels, dec_input_ids = [], [], []for f in features:input_ids.append(f["input_ids"])labels.append(f["labels"])dec_input_ids.append(f["dec_input_ids"])input_ids = torch.concat([input_id[None, :] for input_id in input_ids])label_lengths = [len(lab) for lab in labels]dec_input_ids_length = [len(e) for e in dec_input_ids]max_label_len = max(label_lengths + dec_input_ids_length)labels = [np.pad(lab, (0, max_label_len - lab_len), 'constant', constant_values=-100) for lab, lab_len in zip(labels, label_lengths)]dec_input_ids = [np.pad(e, (0, max_label_len - e_len), 'constant', constant_values=50257) for e, e_len in zip(dec_input_ids, dec_input_ids_length)] # 50257 is eot token idbatch = {"labels": labels,"dec_input_ids": dec_input_ids}batch = {k: torch.tensor(np.array(v), requires_grad=False) for k, v in batch.items()}batch["input_ids"] = input_idsreturn batchdataset = Th30Dataset(eval_audio_transcript_pair_list, wtokenizer, SAMPLE_RATE)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=WhisperDataCollatorWhithPadding())

这是典型的Pytorch而不是前篇中Huggingface的数据加载方法,需要实现datasetDataLoader,详细参考Pytorch Lightning官方文档。至此,遗留的第一个问题解决。

验证数据集加载

DEVICE = "gpu" if torch.cuda.is_available() else "cpu"
for b in loader:print(b["labels"].shape)print(b["input_ids"].shape)print(b["dec_input_ids"].shape)for token, dec in zip(b["labels"], b["dec_input_ids"]):token[token == -100] = wtokenizer.eottext = wtokenizer.decode(token)print(text)dec[dec == -100] = wtokenizer.eottext = wtokenizer.decode(dec)print(text)break
torch.Size([2, 50])
torch.Size([2, 80, 3000])
torch.Size([2, 50])
<|zh|><|transcribe|><|notimestamps|>节目单上赫然印着特邀中央乐团百余位演奏演唱家微妙地避开了矛盾<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
<|startoftranscript|><|zh|><|transcribe|><|notimestamps|>节目单上赫然印着特邀中央乐团百余位演奏演唱家微妙地避开了矛盾<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
<|zh|><|transcribe|><|notimestamps|>放眼望去永定河两旁人声鼎沸彩旗飘扬推土机挖土机运土车正紧张地忙碌着<|endoftext|>
<|startoftranscript|><|zh|><|transcribe|><|notimestamps|>放眼望去永定河两旁人声鼎沸彩旗飘扬推土机挖土机运土车正紧张地忙碌着

验证解码器

with torch.no_grad():audio_features = wmodel.encoder(b["input_ids"].cuda())input_ids = b["input_ids"]labels = b["labels"].long()dec_input_ids = b["dec_input_ids"].long()audio_features = wmodel.encoder(input_ids.cuda())print(dec_input_ids)print(input_ids.shape, dec_input_ids.shape, audio_features.shape)print(audio_features.shape)print()# 计算解码器的输出
out = wmodel.decoder(dec_input_ids.cuda(), audio_features)print(out.shape)
print(out.view(-1, out.size(-1)).shape)
print(b["labels"].view(-1).shape)
tensor([[50258, 50260, 50359, 50363, 45161, 11386, 47446,  5708,  5266,   104,5823, 35825, 20708, 17682,  3023,   222,  5975,  1787,   106, 44365,3919,    95, 31906,  1593,   247, 11160, 31382,  1881,   237, 31382,39861,  5155, 39152,  5648,   247, 10928,  3330,   123, 18937,  2289,5881,   249,  5419,   122, 50257, 50257, 50257, 50257, 50257, 50257],[50258, 50260, 50359, 50363, 12744, 25281, 22694,  6734, 42503, 12088,3308,   111, 36257,  4479,   223,  4035, 32045, 41111,   236,  3308,116,  7391,   102,  4479,   245, 11808,   246,  3416,   105, 33597,45506, 37960,  8501,   244, 45506, 37960,  3316,   238, 45506, 17819,99, 15789,  7732,   100, 44059, 10928, 48839, 16337,   234, 20708]])
torch.Size([2, 80, 3000]) torch.Size([2, 50]) torch.Size([2, 1500, 768])
torch.Size([2, 1500, 768])torch.Size([2, 50, 51865])
torch.Size([100, 51865])
torch.Size([100])

token转文本输出

tokens = torch.argmax(out, dim=2)
for token in tokens:token[token == -100] = wtokenizer.eottext = wtokenizer.decode(token)print(text)
<|zh|><|translate|><|notimestamps|>节目单上赫然印着特邀中央乐团百余位演奏歌唱家微妙的避开了矛盾<|endoftext|><|endoftext|> <|endoftext|><|endoftext|><|endoftext|><|endoftext|>
<|zh|><|transcribe|><|notimestamps|>放眼望去,定河兩旁人身顎沸彩旗飘扬推土机挖土机运土车正紧张地忙碌着<|endoftext|>

构造trainer

class Config:learning_rate = 0.0001weight_decay = 0.01adam_epsilon = 1e-8warmup_steps = 2batch_size = 16num_worker = 2num_train_epochs = 1000gradient_accumulation_steps = 1sample_rate = SAMPLE_RATEclass WhisperModelModule(LightningModule):def __init__(self, cfg: Config, model_name="small", lang="zh", train_dataset=[], eval_dataset=[]) -> None:super().__init__()self.options = whisper.DecodingOptions(language=lang, without_timestamps=True)self.model = whisper.load_model(model_name)self.tokenizer = whisper.tokenizer.get_tokenizer(True, language="zh", task=self.options.task)# only decoder trainingfor p in self.model.encoder.parameters():p.requires_grad = Falseself.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)self.metrics_wer = evaluate.load("wer")self.metrics_cer = evaluate.load("cer")self.cfg = cfgself.__train_dataset = train_datasetself.__eval_dataset = eval_datasetdef forward(self, x):return self.model(x)def training_step(self, batch, batch_id):input_ids = batch["input_ids"]labels = batch["labels"].long()dec_input_ids = batch["dec_input_ids"].long()with torch.no_grad():audio_features = self.model.encoder(input_ids)out = self.model.decoder(dec_input_ids, audio_features)loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1))self.log("train/loss", loss, on_step=False, on_epoch=True,  prog_bar=True, logger=True)return lossdef on_train_epoch_end(self):avg_loss = self.trainer.callback_metrics.get("train/loss")# 获取当前的 epoch 数量epoch = self.current_epochprint(f"Epoch: {epoch}, Training - Loss: {avg_loss:.4f}")def validation_step(self, batch, batch_id):input_ids = batch["input_ids"]labels = batch["labels"].long()dec_input_ids = batch["dec_input_ids"].long()audio_features = self.model.encoder(input_ids)out = self.model.decoder(dec_input_ids, audio_features)loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1))out[out == -100] = self.tokenizer.eotlabels[labels == -100] = self.tokenizer.eoto_list, l_list = [], []for o, l in zip(out, labels):o = torch.argmax(o, dim=1)o_list.append(self.tokenizer.decode(o))l_list.append(self.tokenizer.decode(l))wer = self.metrics_wer.compute(references=l_list, predictions=o_list)self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)self.log("val/wer", wer, on_step=False, on_epoch=True, prog_bar=True, logger=True)# 打印到终端#print(f"Validation - Loss: {loss:.4f}, WER: {wer:.4f}")return {"wer": wer,"loss": loss}def on_validation_epoch_end(self):avg_loss = self.trainer.callback_metrics.get("val/loss")avg_wer = self.trainer.callback_metrics.get("val/wer")# 获取当前的 epoch 数量epoch = self.current_epochprint(f"Epoch: {epoch}, Validation - Loss: {avg_loss:.4f}, WER: {avg_wer:.4f}")def configure_optimizers(self):"""创建优化程序和调度器 """model = self.modelno_decay = ["bias", "LayerNorm.weight"]optimizer_grouped_parameters = [{"params": [p for n, p in model.named_parameters()if not any(nd in n for nd in no_decay)],"weight_decay": self.cfg.weight_decay,},{"params": [p for n, p in model.named_parameters()if any(nd in n for nd in no_decay)],"weight_decay": 0.0,},]optimizer = AdamW(optimizer_grouped_parameters,lr=self.cfg.learning_rate,eps=self.cfg.adam_epsilon)self.optimizer = optimizerscheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.cfg.warmup_steps,num_training_steps=self.t_total)self.scheduler = schedulerreturn [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]def setup(self, stage=None):"""初始设置(读取数据集)"""if stage == 'fit' or stage is None:self.t_total = ((len(self.__train_dataset) // (self.cfg.batch_size))// self.cfg.gradient_accumulation_steps* float(self.cfg.num_train_epochs))def train_dataloader(self):""" 创建训练数据加载程序 """dataset = Th30Dataset(self.__train_dataset, self.tokenizer, self.cfg.sample_rate)return torch.utils.data.DataLoader(dataset,batch_size=self.cfg.batch_size,drop_last=True, shuffle=True, num_workers=self.cfg.num_worker,collate_fn=WhisperDataCollatorWhithPadding())def val_dataloader(self):""" 创建验证数据加载程序 """dataset = Th30Dataset(self.__eval_dataset, self.tokenizer, self.cfg.sample_rate)return torch.utils.data.DataLoader(dataset,batch_size=self.cfg.batch_size,num_workers=self.cfg.num_worker,collate_fn=WhisperDataCollatorWhithPadding())  

主要是对LightningModule类相关方法的重载,定义了train、validate以及optimizer的行为,以及在训练过程中日志和相关信息、checkpoint的保存。

启动训练

log_output_dir = "./logs"
check_output_dir = "./artifacts"train_name = "whisper"
train_id = "00001"model_name = "small"
lang = "zh"cfg = Config()# os.mkdir(log_output_dir)
# os.mkdir(check_output_dir)tflogger = TensorBoardLogger(save_dir=log_output_dir,name=train_name,version=train_id
)from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpointcheckpoint_callback = ModelCheckpoint(dirpath=f"{check_output_dir}/checkpoint",filename="checkpoint-{epoch:04d}",save_top_k=2, # all model savesave_on_train_epoch_end=False,monitor='val/wer',  # 需要监控的验证损失mode='min',  # 最小化 val_lossverbose=True  # 打印更多的信息到控制台
)
callback_list = [checkpoint_callback, LearningRateMonitor(logging_interval="epoch")]
model = WhisperModelModule(cfg, model_name, lang, train_audio_transcript_pair_list, eval_audio_transcript_pair_list)trainer = Trainer(precision=16,accelerator="gpu",max_epochs=cfg.num_train_epochs,check_val_every_n_epoch=2,accumulate_grad_batches=cfg.gradient_accumulation_steps,logger=tflogger,callbacks=callback_list
)trainer.fit(model)
```shell
对于10小时数据集,你可能看到如下输出:

Epoch: 0, Training - Loss: 1.0872
Epoch: 1, Validation - Loss: 0.4207, WER: 0.9847
Epoch: 1, Training - Loss: 0.2955
Epoch: 2, Training - Loss: 0.5555
Epoch: 3, Validation - Loss: 0.2505, WER: 0.9006
Epoch: 3, Training - Loss: 0.0979
Epoch: 4, Training - Loss: 0.0602
Epoch: 5, Validation - Loss: 0.2889, WER: 0.8764
Epoch: 5, Training - Loss: 0.0721
Epoch: 6, Training - Loss: 0.0947
Epoch: 7, Validation - Loss: 0.3839, WER: 0.9809
Epoch: 7, Training - Loss: 0.1379

对于30小时数据集,即使用完整的th30数据,其中85%用于Training,而15%用于validation你可能看到如下输出:
```shell
3051.3s	121	Epoch: 1, Validation - Loss: 0.0588, WER: 0.3499
3061.1s	122	Epoch: 1, Training - Loss: 0.0340
4279.7s	123	Epoch: 2, Training - Loss: 0.0268
5691.4s	124	Epoch: 3, Validation - Loss: 0.0676, WER: 0.8318
5701.1s	125	Epoch: 3, Training - Loss: 0.0201
6919.2s	126	Epoch: 4, Training - Loss: 0.0257
8329.5s	127	Epoch: 5, Validation - Loss: 0.0484, WER: 0.8472
8329.5s	128	Epoch: 5, Training - Loss: 0.0144
9547.5s	129	Epoch: 6, Training - Loss: 0.1127
10959.3s	130	Epoch: 7, Validation - Loss: 0.0422, WER: 0.3982
10969.4s	131	Epoch: 7, Training - Loss: 0.0053
12188.2s	132	Epoch: 8, Training - Loss: 0.0076
13600.0s	133	Epoch: 9, Validation - Loss: 0.0482, WER: 0.8158
13600.0s	134	Epoch: 9, Training - Loss: 0.0126
14819.0s	135	Epoch: 10, Training - Loss: 0.0152
16230.8s	136	Epoch: 11, Validation - Loss: 0.0544, WER: 0.6829
16230.8s	137	Epoch: 11, Training - Loss: 0.0114
17450.1s	138	Epoch: 12, Training - Loss: 0.0174
18862.4s	139	Epoch: 13, Validation - Loss: 0.0523, WER: 0.3225
18872.1s	140	Epoch: 13, Training - Loss: 0.0117
20091.5s	141	Epoch: 14, Training - Loss: 0.0075
21503.2s	142	Epoch: 15, Validation - Loss: 0.0567, WER: 0.5187
21503.2s	143	Epoch: 15, Training - Loss: 0.0137
22722.5s	144	Epoch: 16, Training - Loss: 0.0150
24134.2s	145	Epoch: 17, Validation - Loss: 0.0631, WER: 0.4559
24134.2s	146	Epoch: 17, Training - Loss: 0.0122
25352.9s	147	Epoch: 18, Training - Loss: 0.0120
26765.0s	148	Epoch: 19, Validation - Loss: 0.0523, WER: 0.7387
26765.0s	149	Epoch: 19, Training - Loss: 0.0060
27983.9s	150	Epoch: 20, Training - Loss: 0.0154
29395.1s	151	Epoch: 21, Validation - Loss: 0.0520, WER: 0.4749
29395.1s	152	Epoch: 21, Training - Loss: 0.0073
30612.5s	153	Epoch: 22, Training - Loss: 0.6361
32022.4s	154	Epoch: 23, Validation - Loss: 0.0396, WER: 0.2912
32033.0s	155	Epoch: 23, Training - Loss: 0.0029
33250.8s	156	Epoch: 24, Training - Loss: 0.0036
34662.0s	157	Epoch: 25, Validation - Loss: 0.0461, WER: 0.6043
34662.0s	158	Epoch: 25, Training - Loss: 0.0094
35880.1s	159	Epoch: 26, Training - Loss: 0.0082
37291.0s	160	Epoch: 27, Validation - Loss: 0.0428, WER: 0.7481
37291.0s	161	Epoch: 27, Training - Loss: 0.0051
38509.7s	162	Epoch: 28, Training - Loss: 0.0075
39920.4s	163	Epoch: 29, Validation - Loss: 0.0447, WER: 0.8736
39920.4s	164	Epoch: 29, Training - Loss: 0.0091
41138.9s	165	Epoch: 30, Training - Loss: 0.0088
42549.9s	166	Epoch: 31, Validation - Loss: 0.0530, WER: 0.4500
42549.9s	167	Epoch: 31, Training - Loss: 0.0072

遗留的第二个问题

首先是数据集的问题,因为可以看到随着时长的增加,看到模型训练过程在符合预期方向走,

  1. 最低数据量:起步来说,至少需要几个小时的音频数据来进行有效的fine-tuning。例如,从10小时开始,这是一个相对较小的数据集,可以用来调试模型和流程。

  2. 中等数据量:为了获得更佳的效果,推荐使用20至50小时的音频数据。这可以帮助模型更好地学习到特定语言的特性。

  3. 理想数据量:如果资源允许,使用超过100小时的音频数据将更有助于模型性能的提升。更多的数据可以显著提高模型的泛化能力和准确性。

当然对于大模型,数据质量越高越好,数据多样性越多越好。
进一步通过tensorboard图可以看到:
请添加图片描述
在运行12个小时之后可以看到WER比一开始的确实下降了不少,但是还没有达到20%左右,最低的WER在0.2912,但是这里可以观察到一个非常有趣的现象:

在观察到训练损失(Training Loss)持续下降而验证损失(Validation Loss)和字错误率 (WER, Word Error Rate) 没有持续改善或波动较大的情况时,这通常是过拟合的一个迹象。在这种情况下,模型在训练数据上表现得越来越好,但在未见过的验证数据上的表现却没有相对应的提升,甚至出现恶化。

由于callback回调中会保持前两个在验证集上WER最小的两个checkpoint,接下来有几个思路:

1.分析模型在验证集上的错误,看是否存在特定模式或类型的错误,这可能帮助诊断问题并指导进一步模型调整,因为我们是在whipser开源的基础上fine-tune的,所以不可能简化模型结构的本身,如减少层数或神经元数目以改善过拟合。

2.可以考虑正则化技术(L2正则化、Dropout)等以有助于缓解过拟现象,增强模型的泛化能力

3.调整训练策略,调整学习率或者使用不同的优化器,以评估模型在验证集上的表现;

4.增加更多数据,帮助模型学习到更多特征,从而提高模型泛化能力

观察验证集识别效果

由于输出缩略或视觉上的相似性,一些小的差异(如标点、空白或特殊字符)可能不容易觉察。这些微小的差异在计算WER时会被考虑进去,但在人眼检查时可能会被忽略。
请添加图片描述
可以看到基本上是一致的,但是个别词是有出入的,这是因为th30是人工读的,准确性比较高,并不意味着通话、会议、游戏场景的识别率也能如此。
这里再留几个尾巴给读者自己实现:

CER

1.中文是基于字符的语言,通常我们会使用CER(Character Error Rate,字符错误率)来进行更精确的评估。然而,如果你使用的是WER来评估中文语音识别的质量,这里有几点可能需要注意:

  • 在处理中文时,如果WER是基于词的,就必须先进行准确的分词。中文没有明显的词与词之间的分隔,因此分词的准确性对于WER的计算非常关键。错误的分词可能导致高WER,即使识别的字符完全正确。
  • 中文中的一些微小差异,如同音字错误、词序变化或者是语气词的使用,都可以在视觉上看起来非常相似,但在WER的计算中会被视为错误。
  • 你查看的样本可能并不代表整体数据集的平均表现。此外,中文语音识别可能特别擅长处理某些特定的语句或者在某些领域表现更好。

数据质量

除了数据量之外,数据的质量也非常重要:

  • 多样性:数据应该涵盖多种口音、语速和语调,以及不同的背景噪声环境,这将帮助模型在各种输入条件下都能保持稳定的表现。
  • 标注准确性:确保你的数据标注尽可能准确,错误的标注会直接影响模型学习的结果。

预处理和增强

  • 预处理:对音频进行预处理,如采样率转换(确保和模型训练时使用的采样率一致),音量标准化等。
  • 数据增强:可以考虑使用音频数据增强技术,如添加背景噪声、改变语速和音高等,以增加模型的鲁棒性。

资源和迭代

  • 计算资源:fine-tuning一个语音识别模型可能需要大量的计算资源,特别是当使用大量数据时。确保你有足够的GPU资源进行训练。
  • 迭代和评估:在fine-tuning过程中需要多次迭代和评估,以找到最优的模型参数和设置。

总结来说,训练的过程如炼丹,有些训练的经验是不能从小模型直接用到大模型上的。比如small和对于large-v3两种。
在模型相对较小的时候,learning rate的设置可以比较激进,但是对非常大模型的时候,较大的lr可能导致模型一开始loss就无法收敛,是发散的,但如果设置的lr比较小,那可能使得训练的时长成倍增加,怎么办呢?针对很大的模型,warm-up策略是很多时候会使用的。

load weight and inference

checkpoint_path = "whisper-checkpoint/checkpoint-epoch0023.ckpt"
state_dict = torch.load(checkpoint_path)
print(state_dict.keys())
state_dict = state_dict['state_dict']
/tmp/ipykernel_36/4099222220.py:2: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.state_dict = torch.load(checkpoint_path)
dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'MixedPrecision'])

加载模型参数

cfg = Config()
whisper_model = WhisperModelModule(cfg)
whisper_model.load_state_dict(state_dict)
100%|███████████████████████████████████████| 461M/461M [00:05<00:00, 87.6MiB/s]
Downloading builder script:   0%|          | 0.00/4.49k [00:00<?, ?B/s]
Downloading builder script:   0%|          | 0.00/5.60k [00:00<?, ?B/s]
<All keys matched successfully>

前向推理

woptions = whisper.DecodingOptions(language="zh", without_timestamps=True)
dataset = Th30Dataset(eval_audio_transcript_pair_list, wtokenizer, SAMPLE_RATE)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=WhisperDataCollatorWhithPadding())refs = []
res = []
for b in tqdm(loader):input_ids = b["input_ids"].half().cuda()labels = b["labels"].long().cuda()with torch.no_grad():#audio_features = whisper_model.model.encoder(input_ids)#out = whisper_model.model.decoder(enc_input_ids, audio_features)results = whisper_model.model.decode(input_ids, woptions)for r in results:res.append(r.text)for l in labels:l[l == -100] = wtokenizer.eotref = wtokenizer.decode(l)refs.append(ref)```打印推理结果```for k, v in zip(refs, res):print("-"*10)print(k)print(v)

部分输出结果

  ----------
<|zh|><|transcribe|><|notimestamps|>节目单上赫然印着特邀中央乐团百余位演奏演唱家微妙地避开了矛盾<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
节目单上赫然印着特邀中央乐团百余位演奏歌唱家微妙地避开了矛盾
----------
<|zh|><|transcribe|><|notimestamps|>放眼望去永定河两旁人声鼎沸彩旗飘扬推土机挖土机运土车正紧张地忙碌着<|endoftext|>
放眼望去永定河两旁人声鼎沸彩旗飘扬推土机挖土机运土车正紧张地忙碌着
----------
<|zh|><|transcribe|><|notimestamps|>旅与游的时间比往往旅长游短与游客的愿望相悖<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
旅与游的时间比往往旅长游短与游客的愿望相悖
----------
<|zh|><|transcribe|><|notimestamps|>中学毕业后他考入尼恩罗德商业学院毕业后曾服兵役并在一家贸易公司任职<|endoftext|>
中学毕业后他考入尼恩罗德商业学院毕业后曾服兵役并在一家贸易公司任职
----------
<|zh|><|transcribe|><|notimestamps|>该片导演为虞石束杨军主要演员有韩夫一李进慕爱秋周桂云金鑫冯云魁等<|endoftext|>
该片导演为虞石束杨军主要演员有韩夫一李进慕爱秋周桂云金鑫冯云魁等
----------
<|zh|><|transcribe|><|notimestamps|>何勰二话没说立即交了一千五百元的押金又为洛桑卓玛买来了全套新衣服和住院用品<|endoftext|><|endoftext|>
何勰二话没说立即交了一千五百元的押金又为洛桑卓玛买来了全套新衣服和住院用品
----------
<|zh|><|transcribe|><|notimestamps|>印加人所创造的文明与玛雅文明阿兹特克文明一起被誉为美洲印第安三大文明<|endoftext|>
印加人所创造的文明与玛雅文明阿兹特克文明一起被誉为美洲印第安三大文明
----------
<|zh|><|transcribe|><|notimestamps|>今天陪以萌找冯邦找得又累又饿但看见以萌那副着急样我一点也吃不下<|endoftext|><|endoftext|><|endoftext|><|endoftext|>
今天陪以萌找冯邦找得又累又饿但看见以萌那副着急样我一点也吃不下
----------
<|zh|><|transcribe|><|notimestamps|>亲英的北爱尔兰新教派武装十二日晚发表声明威胁要报复爱尔兰共和军<|endoftext|>
亲英的北爱尔兰新教派武装十二日晚发表声明威胁要报复爱尔兰共和军
----------
<|zh|><|transcribe|><|notimestamps|>小仲不顾闲言碎语一天几趟往我家跑为我洗衣做饭熬药煎汤<|endoftext|><|endoftext|><|endoftext|>
小仲不顾闲言碎语一天几趟往我家跑为我洗衣做饭熬药煎汤
----------
<|zh|><|transcribe|><|notimestamps|>这位病人因贲门下胃底大弯静脉曲张伴血管瘤破裂胃内大量喷血<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
这位病人因贲门下胃底大弯静脉曲张伴血管瘤破裂胃内大量喷血
----------
<|zh|><|transcribe|><|notimestamps|>驻藏边防某部二连战士赵金站岗时隐隐约约听见营区外的不远处有哭泣声<|endoftext|>
驻藏边防某部二连战士赵金站岗时隐隐约约听见营区外的不远处有哭泣声
----------
<|zh|><|transcribe|><|notimestamps|>其种植的红富士苹果以色泽艳丽果质细脆汁多味美和极耐贮运而享誉海内外<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
其种植的红富士苹果以色泽艳丽果质细脆汁多味美和极耐贮存而享誉海内外
----------
<|zh|><|transcribe|><|notimestamps|>一九四一年十一月陕甘宁边区根据三三制原则举行参议会议员竞选一位名叫森健的学员被推为候选人<|endoftext|>
一九四一年十一月陕甘宁边区根据三三制原则举行参议会议员竞选一位名叫森健的学员被推为候选人
----------
<|zh|><|transcribe|><|notimestamps|>当船往下漂时白唇鹿扬起四蹄在岸边追随好像是送行一直跑了十几里太亲切了<|endoftext|>
当船往下漂时白唇鹿扬起四蹄在岸边追赶好像是送行一直跑了十几里太亲切了
----------
<|zh|><|transcribe|><|notimestamps|>有的单位按年人均月收入减去费用八百元后的余额为应纳税所得额<|endoftext|><|endoftext|><|endoftext|><|endoftext|>
有的单位按年人均月收入减去费用八百元后的余额为应纳税所得额
----------
<|zh|><|transcribe|><|notimestamps|>女性腰部以上特别肥胖者易患乳腺癌腰围与臀围差别不大者患癌率比一般妇女高六倍<|endoftext|>
女性腰部以上特别肥胖者易患乳腺癌腰围与臀围差别不大者患癌率比一般妇女高六倍
----------
<|zh|><|transcribe|><|notimestamps|>日本队在男子团体赛中获银牌队员岩井哲贤在个人全能赛也夺得一枚银牌<|endoftext|><|endoftext|><|endoftext|>
日本队在男子团体赛中获银牌队员岩井哲贤在个人全能赛也夺得一枚银牌
----------
<|zh|><|transcribe|><|notimestamps|>如此举措源于杭州娃哈哈食品集团公司总经理宗庆后对市场特质的洞悉<|endoftext|><|endoftext|><|endoftext|>
如此举措源于杭州娃哈哈食品集团公司总经理宗庆后对市场特质的洞悉```
接下来还有三个问题对于应用更需要细致考虑:
1.Whisper除了识别,还有直接翻译功能,在以前要先识别成中文,再汉译英等,这个好处是显而易见的,首先只要一个模型,节约部分人力、机器以及服务端GPU,业务场景上可以是会议的实时翻译、看英文视频实时翻译成中文,这会减少latency,用户体验也更好;
2.如何在实时的流式场景中使用?
3.kv-caching是个什么技术?12倍是如何做到的?这在工程部署商用价值非常大。欢迎点赞、收藏、关注,以便及时收到下一篇推送。

这篇关于大模型之二十八-语音识别Whisper进阶的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1120266

相关文章

Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)

《Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)》:本文主要介绍Python基于火山引擎豆包大模型搭建QQ机器人详细的相关资料,包括开通模型、配置APIKEY鉴权和SD... 目录豆包大模型概述开通模型付费安装 SDK 环境配置 API KEY 鉴权Ark 模型接口Prompt

Spring Security 从入门到进阶系列教程

Spring Security 入门系列 《保护 Web 应用的安全》 《Spring-Security-入门(一):登录与退出》 《Spring-Security-入门(二):基于数据库验证》 《Spring-Security-入门(三):密码加密》 《Spring-Security-入门(四):自定义-Filter》 《Spring-Security-入门(五):在 Sprin

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

Java进阶13讲__第12讲_1/2

多线程、线程池 1.  线程概念 1.1  什么是线程 1.2  线程的好处 2.   创建线程的三种方式 注意事项 2.1  继承Thread类 2.1.1 认识  2.1.2  编码实现  package cn.hdc.oop10.Thread;import org.slf4j.Logger;import org.slf4j.LoggerFactory

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G

阿里开源语音识别SenseVoiceWindows环境部署

SenseVoice介绍 SenseVoice 专注于高精度多语言语音识别、情感辨识和音频事件检测多语言识别: 采用超过 40 万小时数据训练,支持超过 50 种语言,识别效果上优于 Whisper 模型。富文本识别:具备优秀的情感识别,能够在测试数据上达到和超过目前最佳情感识别模型的效果。支持声音事件检测能力,支持音乐、掌声、笑声、哭声、咳嗽、喷嚏等多种常见人机交互事件进行检测。高效推

让树莓派智能语音助手实现定时提醒功能

最初的时候是想直接在rasa 的chatbot上实现,因为rasa本身是带有remindschedule模块的。不过经过一番折腾后,忽然发现,chatbot上实现的定时,语音助手不一定会有响应。因为,我目前语音助手的代码设置了长时间无应答会结束对话,这样一来,chatbot定时提醒的触发就不会被语音助手获悉。那怎么让语音助手也具有定时提醒功能呢? 我最后选择的方法是用threading.Time

Retrieval-based-Voice-Conversion-WebUI模型构建指南

一、模型介绍 Retrieval-based-Voice-Conversion-WebUI(简称 RVC)模型是一个基于 VITS(Variational Inference with adversarial learning for end-to-end Text-to-Speech)的简单易用的语音转换框架。 具有以下特点 简单易用:RVC 模型通过简单易用的网页界面,使得用户无需深入了

透彻!驯服大型语言模型(LLMs)的五种方法,及具体方法选择思路

引言 随着时间的发展,大型语言模型不再停留在演示阶段而是逐步面向生产系统的应用,随着人们期望的不断增加,目标也发生了巨大的变化。在短短的几个月的时间里,人们对大模型的认识已经从对其zero-shot能力感到惊讶,转变为考虑改进模型质量、提高模型可用性。 「大语言模型(LLMs)其实就是利用高容量的模型架构(例如Transformer)对海量的、多种多样的数据分布进行建模得到,它包含了大量的先验

图神经网络模型介绍(1)

我们将图神经网络分为基于谱域的模型和基于空域的模型,并按照发展顺序详解每个类别中的重要模型。 1.1基于谱域的图神经网络         谱域上的图卷积在图学习迈向深度学习的发展历程中起到了关键的作用。本节主要介绍三个具有代表性的谱域图神经网络:谱图卷积网络、切比雪夫网络和图卷积网络。 (1)谱图卷积网络 卷积定理:函数卷积的傅里叶变换是函数傅里叶变换的乘积,即F{f*g}