This article walks through the PyTorch implementation of the BERT model code for masked language modeling (MLM). Hopefully it provides a useful reference for developers working on similar problems.
To modify ("hack") BERT, you first have to understand its structure.
This post mainly explains the classes involved in BertForMaskedLM (MLM).
Below is the MLM prediction head:
class BertLMPredictionHead(nn.Module):
    def __init__(self, config, bert_model_embedding_weights):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)
        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
                                 bert_model_embedding_weights.size(0),
                                 bias=False)
        self.decoder.weight = bert_model_embedding_weights
        self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
        # The linear layer above maps the transformer-block output
        # [batch_size, seq_len, embed_dim] to [batch_size, seq_len, vocab_size],
        # i.e. the last dimension becomes the vocabulary size, which gives the
        # MaskedLM predictions. Note that you could also just multiply by the
        # transpose of the embedding matrix here; in the general case you would
        # instead randomly initialize a new layer of parameters rather than tie
        # the weights.

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states) + self.bias
        return hidden_states


class BertOnlyMLMHead(nn.Module):
    def __init__(self, config, bert_model_embedding_weights):
        super(BertOnlyMLMHead, self).__init__()
        self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores
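To make the shape mapping and the weight tying concrete, here is a minimal standalone sketch (toy dimensions and variable names of my own, not the article's classes): the decoder shares its weight with the token embedding matrix, maps hidden states of shape [batch_size, seq_len, hidden_size] to vocabulary logits of shape [batch_size, seq_len, vocab_size], and those logits can then be scored with a cross-entropy loss that ignores every position labelled -100 (i.e. everything that was not masked).

import torch
import torch.nn as nn

# Toy dimensions; every name here is illustrative, not part of the article's code.
vocab_size, hidden_size, batch_size, seq_len = 1000, 768, 2, 8

embedding = nn.Embedding(vocab_size, hidden_size)         # token embedding matrix
decoder = nn.Linear(hidden_size, vocab_size, bias=False)  # maps hidden -> vocab logits
decoder.weight = embedding.weight                         # weight tying, as in BertLMPredictionHead
bias = nn.Parameter(torch.zeros(vocab_size))              # output-only bias per token

hidden_states = torch.randn(batch_size, seq_len, hidden_size)  # stand-in for the encoder output
prediction_scores = decoder(hidden_states) + bias
print(prediction_scores.shape)  # torch.Size([2, 8, 1000])

# MLM loss: only masked positions carry a real label, the rest are -100 and ignored.
labels = torch.full((batch_size, seq_len), -100, dtype=torch.long)
labels[0, 3] = 42  # pretend position 3 of sample 0 was [MASK]ed and the true token id is 42
loss = nn.CrossEntropyLoss(ignore_index=-100)(
    prediction_scores.view(-1, vocab_size), labels.view(-1))
print(loss.item())

Because decoder.weight is assigned from embedding.weight, the two share the same storage, so apart from the output bias no extra vocab_size x hidden_size parameters are introduced by the prediction head.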
The other class is BertModel itself:
class BertModel(BertPreTrainedModel):
    """BERT model ("Bidirectional Embedding Representations from a Transformer").

    Params:
        config: a BertConfig class instance with the configuration to build a new model

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary
            (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length]
            with the token types indices selected in [0, 1]. Type 0 corresponds to a
            `sentence A` and type 1 corresponds to a `sentence B` token
            (see the BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length]
            with indices selected in [0, 1]. It's a mask to be used if the input sequence
            length is smaller than the max input sequence length in the current batch.
            It's the mask that we typically use for attention when a batch has varying
            length sentences.
        `output_all_encoded_layers`: boolean which controls the content of the
            `encoded_layers` output as described below. Default: `True`.

    Outputs: Tuple of (encoded_layers, pooled_output)
        `encoded_layers`: controlled by the `output_all_encoded_layers` argument:
            - `output_all_encoded_layers=True`: outputs a list of the full sequences of
              encoded hidden states at the end of each attention block (i.e. 12 full
              sequences for BERT-base, 24 for BERT-large), each encoded hidden state
              being a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
            - `output_all_encoded_layers=False`: outputs only the full sequence of
              hidden states corresponding to the last attention block, of shape
              [batch_size, sequence_length, hidden_size],
        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is
            the output of a classifier pretrained on top of the hidden state associated
            to the first character of the input (`CLS`) to train on the Next-Sentence task
            (see BERT's paper).

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

    config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                                 num_hidden_layers=12, num_attention_heads=12,
                                 intermediate_size=3072)

    model = modeling.BertModel(config=config)
    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    ```
    """

    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, positional_enc, token_type_ids=None, attention_mask=None,
                output_all_encoded_layers=True, get_attention_matrices=False):
        if attention_mask is None:
            # attention_mask = torch.ones_like(input_ids)
            attention_mask = (input_ids > 0)
            # attention_mask: [batch_size, seq_length]
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # This attention mask is simpler than the triangular masking of causal attention
        # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        # Attention mask: [batch_size, 1, 1, seq_length]

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        # Add a large negative bias to the padded (invalid) regions of the attention
        # matrix, so that after the softmax those positions are effectively zero and
        # do not take part in the subsequent computation.

        # Embedding layer
        embedding_output = self.embeddings(input_ids, positional_enc, token_type_ids)
        # Output after passing through all of the defined transformer blocks
        encoded_layers, all_attention_matrices = self.encoder(
            embedding_output,
            extended_attention_mask,
            output_all_encoded_layers=output_all_encoded_layers,
            get_attention_matrices=get_attention_matrices)
        # Optionally return the attention matrices of every layer for visualization
        if get_attention_matrices:
            return all_attention_matrices
        # [-1] is the hidden state produced by the last transformer block
        sequence_output = encoded_layers[-1]
        # pooled_output is the vector corresponding to the [CLS] token of that hidden state
        pooled_output = self.pooler(sequence_output)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output
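Since the extended attention mask is the least obvious step in the forward pass above, here is a small self-contained sketch of the same arithmetic on toy tensors (all values illustrative): padding positions receive a -10000.0 bias before the softmax, so their attention weights come out effectively zero.

import torch

# Toy inputs; 0 is treated as the [PAD] id, mirroring `attention_mask = (input_ids > 0)`.
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
attention_mask = (input_ids > 0)                                    # [batch_size, seq_len]

extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # [batch_size, 1, 1, seq_len]
extended_attention_mask = extended_attention_mask.to(dtype=torch.float32)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

# Fake raw attention scores: [batch_size, num_heads, from_seq_len, to_seq_len]
scores = torch.randn(2, 12, 3, 3)
probs = torch.softmax(scores + extended_attention_mask, dim=-1)     # mask broadcasts over heads and rows
print(probs[1, 0])  # in the second sample, the last column (the padded token) is ~0 in every row

Adding the bias before the softmax, rather than zeroing weights afterwards, keeps each row normalized over the valid tokens only, which is exactly what the comments in BertModel.forward describe.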
Reference: https://blog.51cto.com/u_15060462/4254056
That concludes this article on the PyTorch BERT model code (MLM); hopefully it is of help to fellow developers!