whisper 模型源码解读

2024-06-16 23:12
文章标签 源码 whisper 解读 模型

本文主要是介绍whisper 模型源码解读,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在这里插入图片描述

whisper官方源码

whisper 模型官方代码:https://github.com/openai/whisper/blob/main/whisper/model.py ;注释如下

import base64
import gzip
from dataclasses import dataclass
from typing import Dict, Iterable, Optionalimport numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn# 从其他模块导入必要的函数
from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function@dataclass
class ModelDimensions:"""该类用于存储模型的各项参数"""n_mels: int  # Mel谱图的频带数量n_audio_ctx: int  # 音频上下文窗口大小n_audio_state: int  # 音频状态维度n_audio_head: int  # 音频注意力头数量n_audio_layer: int  # 音频层数量n_vocab: int  # 词汇表大小n_text_ctx: int  # 文本上下文窗口大小n_text_state: int  # 文本状态维度n_text_head: int  # 文本注意力头数量n_text_layer: int  # 文本层数量class LayerNorm(nn.LayerNorm):def forward(self, x: Tensor) -> Tensor:"""重写 forward 方法,确保输入张量的类型在归一化前后保持一致"""return super().forward(x.float()).type(x.dtype)class Linear(nn.Linear):def forward(self, x: Tensor) -> Tensor:"""重写 forward 方法,确保权重和偏置与输入张量的类型一致"""return F.linear(x,self.weight.to(x.dtype),None if self.bias is None else self.bias.to(x.dtype),)class Conv1d(nn.Conv1d):def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:"""重写 _conv_forward 方法,确保卷积操作中的权重和偏置与输入张量的类型一致"""return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))def sinusoids(length, channels, max_timescale=10000):"""生成用于位置嵌入的正弦曲线"""assert channels % 2 == 0log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)class MultiHeadAttention(nn.Module):def __init__(self, n_state: int, n_head: int):"""初始化多头注意力层"""super().__init__()self.n_head = n_headself.query = Linear(n_state, n_state)self.key = Linear(n_state, n_state, bias=False)self.value = Linear(n_state, n_state)self.out = Linear(n_state, n_state)def forward(self,x: Tensor,xa: Optional[Tensor] = None,mask: Optional[Tensor] = None,kv_cache: Optional[dict] = None,):"""多头注意力的前向传播"""q = self.query(x)if kv_cache is None or xa is None or self.key not in kv_cache:# 如果没有缓存键和值,则正常计算k = self.key(x if xa is None else xa)v = self.value(x if xa is None else xa)else:# 如果有缓存,则使用缓存的键和值k = kv_cache[self.key]v = kv_cache[self.value]wv, qk = self.qkv_attention(q, k, v, mask)return self.out(wv), qkdef qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):"""计算 QKV 注意力"""n_batch, n_ctx, n_state = q.shapescale = (n_state // self.n_head) ** -0.25q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scalek = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scalev = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)qk = q @ kif mask is not None:qk = qk + mask[:n_ctx, :n_ctx]qk = qk.float()w = F.softmax(qk, dim=-1).to(q.dtype)return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()class ResidualAttentionBlock(nn.Module):def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):"""初始化残差注意力块"""super().__init__()self.attn = MultiHeadAttention(n_state, n_head)self.attn_ln = LayerNorm(n_state)self.cross_attn = (MultiHeadAttention(n_state, n_head) if cross_attention else None)self.cross_attn_ln = LayerNorm(n_state) if cross_attention else Nonen_mlp = n_state * 4self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))self.mlp_ln = LayerNorm(n_state)def forward(self,x: Tensor,xa: Optional[Tensor] = None,mask: Optional[Tensor] = None,kv_cache: Optional[dict] = None,):"""残差注意力块的前向传播"""x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]if self.cross_attn:x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]x = x + self.mlp(self.mlp_ln(x))return xclass AudioEncoder(nn.Module):def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):"""初始化音频编码器"""super().__init__()self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList([ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)])self.ln_post = LayerNorm(n_state)def forward(self, x: Tensor):"""前向传播,处理音频输入x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)音频的Mel谱图"""x = F.gelu(self.conv1(x))x = F.gelu(self.conv2(x))x = x.permute(0, 2, 1)assert x.shape[1:] == self.positional_embedding.shape, "音频形状不正确"x = (x + self.positional_embedding).to(x.dtype)for block in self.blocks:x = block(x)x = self.ln_post(x)return xclass TextDecoder(nn.Module):def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):"""初始化文本解码器"""super().__init__()self.token_embedding = nn.Embedding(n_vocab, n_state)self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList([ResidualAttentionBlock(n_state, n_head, cross_attention=True)for _ in range(n_layer)])self.ln = LayerNorm(n_state)mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)self.register_buffer("mask", mask, persistent=False)def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):"""前向传播,处理文本输入并结合音频特征x : torch.LongTensor, shape = (batch_size, <= n_ctx)文本的标记序列xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)编码后的音频特征"""offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0x = (self.token_embedding(x)+ self.positional_embedding[offset : offset + x.shape[-1]])x = x.to(xa.dtype)for block in self.blocks:x = block(x, xa, mask=self.mask, kv_cache=kv_cache)x = self.ln(x)logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()return logitsclass Whisper(nn.Module):def __init__(self, dims: ModelDimensions):"""初始化 Whisper 模型"""super().__init__()self.dims = dimsself.encoder = AudioEncoder(self.dims.n_mels,self.dims.n_audio_ctx,self.dims.n_audio_state,self.dims.n_audio_head,self.dims.n_audio_layer,)self.decoder = TextDecoder(self.dims.n_vocab,self.dims.n_text_ctx,self.dims.n_text_state,self.dims.n_text_head,self.dims.n_text_layer,)# 默认情况下,使用解码器层的后一半进行时间对齐;# 若要使用特定的注意力头,可以使用 `set_alignment_heads()` 方法。all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)all_heads[self.dims.n_text_layer // 2 :] = Trueself.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)def set_alignment_heads(self, dump: bytes):"""设置对齐的注意力头"""array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head)self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)def embed_audio(self, mel: torch.Tensor):"""编码音频特征"""return self.encoder(mel)def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):"""获取预测的logits"""return self.decoder(tokens, audio_features)def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:"""前向传播"""return self.decoder(tokens, self.encoder(mel))@propertydef device(self):"""获取模型所在的设备"""return next(self.parameters()).device@propertydef is_multilingual(self):"""判断模型是否支持多语言"""return self.dims.n_vocab >= 51865@propertydef num_languages(self):"""获取模型支持的语言数量"""return self.dims.n_vocab - 51765 - int(self.is_multilingual)def install_kv_cache_hooks(self, cache: Optional[dict] = None):"""为键和值的投影模块安装缓存钩子返回-------cache : Dict[nn.Module, torch.Tensor]映射键/值投影模块到其缓存的字典对象hooks : List[RemovableHandle]用于停止调用钩子的 PyTorch RemovableHandle 对象列表"""cache = {**cache} if cache is not None else {}hooks = []def save_to_cache(module, _, output):if module not in cache or output.shape[1] > self.dims.n_text_ctx:# 第一次标记或交叉注意时保存原始值cache[module] = outputelse:cache[module] = torch.cat([cache[module], output], dim=1).detach()return cache[module]def install_hooks(layer: nn.Module):if isinstance(layer, MultiHeadAttention):hooks.append(layer.key.register_forward_hook(save_to_cache))hooks.append(layer.value.register_forward_hook(save_to_cache))self.decoder.apply(install_hooks)return cache, hooksdetect_language = detect_language_function  # 语言检测函数transcribe = transcribe_function  # 转录函数decode = decode_function  # 解码函数

语音识别自回归解码过程分析和举例说明

分析

语音识别自回归解码过程通常涉及以下步骤:

  1. 音频预处理:首先将输入的音频信号转换为Mel谱图。这一步骤在实际应用中通常由音频前端处理模块完成。

  2. 音频编码:将预处理后的Mel谱图输入到音频编码器中,生成音频特征表示。这些特征表示将作为后续文本解码器的输入。

  3. 文本解码:文本解码器通过自回归方式生成文本序列。具体来说,文本解码器在每个时间步上根据前一步生成的文本标记以及音频特征生成下一个文本标记。

  4. 语言检测和转录:在生成的文本序列基础上,可以进行语言检测,确认文本所使用的语言。此外,转录过程将生成的文本序列转换为最终的文本输出。

具体步骤

以下代码展示了上述过程的具体实现:

import torch# 初始化模型参数
dims = ModelDimensions(n_mels=80,n_audio_ctx=1500,n_audio_state=512,n_audio_head=8,n_audio_layer=6,n_vocab=51865,n_text_ctx=448,n_text_state=512,n_text_head=8,n_text_layer=6,
)# 创建模型实例
model = Whisper(dims)# 假设我们有一个Mel谱图输入
mel_spectrogram = torch.randn(1, 80, 1500)  # (batch_size, n_mels, n_audio_ctx)# 编码音频特征
audio_features = model.embed_audio(mel_spectrogram)# 假设我们有一个初始的文本标记序列
initial_tokens = torch.tensor([[1, 2, 3]])  # (batch_size, seq_len)# 自回归解码过程
for _ in range(10):  # 假设生成长度为10的序列logits = model.logits(initial_tokens, audio_features)next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)initial_tokens = torch.cat([initial_tokens, next_token], dim=-1)# 最终生成的文本标记序列
final_tokens = initial_tokens# 打印生成的文本标记序列
print("Generated tokens:", final_tokens)

举例说明

假设我们有一段音频,其Mel谱图表示如下:

mel_spectrogram = torch.randn(1, 80, 1500)

我们希望通过自回归解码生成对应的文本表示。首先,我们将Mel谱图输入到音频编码器中,得到音频特征表示:

audio_features = model.embed_audio(mel_spectrogram)

然后,我们使用一个初始的文本标记序列(例如,序列开始标记)开始自回归解码过程:

initial_tokens = torch.tensor([[1]])  # 序列开始标记

在每个时间步,我们根据当前的文本标记序列和音频特征生成下一个文本标记:

logits = model.logits(initial_tokens, audio_features)
next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
initial_tokens = torch.cat([initial_tokens, next_token], dim=-1)

这个过程重复若干次(例如10次)直到生成完整的文本序列:

for _ in range(10):logits = model.logits(initial_tokens, audio_features)next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)initial_tokens = torch.cat([initial_tokens, next_token], dim=-1)

最终得到的文本标记序列为:

final_tokens = initial_tokens
print("Generated tokens:", final_tokens)

以上示例展示了从音频输入到文本输出的完整自回归解码过程。

这篇关于whisper 模型源码解读的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1067820

相关文章

一份LLM资源清单围观技术大佬的日常;手把手教你在美国搭建「百万卡」AI数据中心;为啥大模型做不好简单的数学计算? | ShowMeAI日报

👀日报&周刊合集 | 🎡ShowMeAI官网 | 🧡 点赞关注评论拜托啦! 1. 为啥大模型做不好简单的数学计算?从大模型高考数学成绩不及格说起 司南评测体系 OpenCompass 选取 7 个大模型 (6 个开源模型+ GPT-4o),组织参与了 2024 年高考「新课标I卷」的语文、数学、英语考试,然后由经验丰富的判卷老师评判得分。 结果如上图所

大语言模型(LLMs)能够进行推理和规划吗?

大语言模型(LLMs),基本上是经过强化训练的 n-gram 模型,它们在网络规模的语言语料库(实际上,可以说是我们文明的知识库)上进行了训练,展现出了一种超乎预期的语言行为,引发了我们的广泛关注。从训练和操作的角度来看,LLMs 可以被认为是一种巨大的、非真实的记忆库,相当于为我们所有人提供了一个外部的系统 1(见图 1)。然而,它们表面上的多功能性让许多研究者好奇,这些模型是否也能在通常需要系

springboot家政服务管理平台 LW +PPT+源码+讲解

3系统的可行性研究及需求分析 3.1可行性研究 3.1.1技术可行性分析 经过大学四年的学习,已经掌握了JAVA、Mysql数据库等方面的编程技巧和方法,对于这些技术该有的软硬件配置也是齐全的,能够满足开发的需要。 本家政服务管理平台采用的是Mysql作为数据库,可以绝对地保证用户数据的安全;可以与Mysql数据库进行无缝连接。 所以,家政服务管理平台在技术上是可以实施的。 3.1

人工和AI大语言模型成本对比 ai语音模型

这里既有AI,又有生活大道理,无数渺小的思考填满了一生。 上一专题搭建了一套GMM-HMM系统,来识别连续0123456789的英文语音。 但若不是仅针对数字,而是所有普通词汇,可能达到十几万个词,解码过程将非常复杂,识别结果组合太多,识别结果不会理想。因此只有声学模型是完全不够的,需要引入语言模型来约束识别结果。让“今天天气很好”的概率高于“今天天汽很好”的概率,得到声学模型概率高,又符合表达

智能客服到个人助理,国内AI大模型如何改变我们的生活?

引言 随着人工智能(AI)技术的高速发展,AI大模型越来越多地出现在我们的日常生活和工作中。国内的AI大模型在过去几年里取得了显著的进展,不少独创的技术点和实际应用令人瞩目。 那么,国内的AI大模型有哪些独创的技术点?它们在实际应用中又有哪些出色表现呢?此外,普通人又该如何利用这些大模型提升工作和生活的质量和效率呢?本文将为你一一解析。 一、国内AI大模型的独创技术点 多模态学习 多

高仿精仿愤怒的小鸟android版游戏源码

这是一款很完美的高仿精仿愤怒的小鸟android版游戏源码,大家可以研究一下吧、 为了报复偷走鸟蛋的肥猪们,鸟儿以自己的身体为武器,仿佛炮弹一样去攻击肥猪们的堡垒。游戏是十分卡通的2D画面,看着愤怒的红色小鸟,奋不顾身的往绿色的肥猪的堡垒砸去,那种奇妙的感觉还真是令人感到很欢乐。而游戏的配乐同样充满了欢乐的感觉,轻松的节奏,欢快的风格。 源码下载

OpenCompass:大模型测评工具

大模型相关目录 大模型,包括部署微调prompt/Agent应用开发、知识库增强、数据库增强、知识图谱增强、自然语言处理、多模态等大模型应用开发内容 从0起步,扬帆起航。 大模型应用向开发路径:AI代理工作流大模型应用开发实用开源项目汇总大模型问答项目问答性能评估方法大模型数据侧总结大模型token等基本概念及参数和内存的关系大模型应用开发-华为大模型生态规划从零开始的LLaMA-Factor

模型压缩综述

https://www.cnblogs.com/shixiangwan/p/9015010.html

基于Java医院药品交易系统详细设计和实现(源码+LW+调试文档+讲解等)

💗博主介绍:✌全网粉丝10W+,CSDN作者、博客专家、全栈领域优质创作者,博客之星、平台优质作者、专注于Java、小程序技术领域和毕业项目实战✌💗 🌟文末获取源码+数据库🌟 感兴趣的可以先收藏起来,还有大家在毕设选题,项目以及论文编写等相关问题都可以给我留言咨询,希望帮助更多的人  Java精品实战案例《600套》 2023-2025年最值得选择的Java毕业设计选题大全:1000个热

AI赋能天气:微软研究院发布首个大规模大气基础模型Aurora

编者按:气候变化日益加剧,高温、洪水、干旱,频率和强度不断增加的全球极端天气给整个人类社会都带来了难以估计的影响。这给现有的天气预测模型提出了更高的要求——这些模型要更准确地预测极端天气变化,为政府、企业和公众提供更可靠的信息,以便做出及时的准备和响应。为了应对这一挑战,微软研究院开发了首个大规模大气基础模型 Aurora,其超高的预测准确率、效率及计算速度,实现了目前最先进天气预测系统性能的显著