本文主要是介绍《PyTorch》Part6 PyTorch之seq2seq,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
《PyTorch》Part6 PyTorch之seq2seq
基于PyTorch实现聊天机器人。
环境配置:
torch 1.6.0+cu101
torchvision 0.7.0+cu101
显卡:NVIDIA GTX 1050;内存:2GB(注:下文代码通过 torch.device("cpu") 在 CPU 上运行)
1.模型下载:https://download.pytorch.org/models/tutorials/4000_checkpoint.tar
2.注意:代码中的所有注释都必须使用英文,否则调用 torch.jit.script() 时会报错。
"""
Notes:
1. All comments in the code must be written in English; otherwise 'torch.jit.script()' throws an error.
2. Model download site: https://download.pytorch.org/models/tutorials/4000_checkpoint.tar
"""
# Chatbot inference script: loads a trained seq2seq checkpoint, converts the
# encoder/decoder to TorchScript and greedily decodes replies.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import os
import unicodedata
import numpy as np

# All tensors and modules in this script live on the CPU.
device = torch.device("cpu")

MAX_LENGTH = 10  # Maximum sentence length

PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token
class Voc:
    """Vocabulary mapping words to integer indices (and back), with counts.

    Indices 0-2 are reserved for the PAD, SOS and EOS tokens. This script
    replaces the whole instance state from a checkpoint's ``voc_dict``, so
    the attribute names below must not change.
    """

    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

    def addSentence(self, sentence):
        """Add every space-separated word of ``sentence`` to the vocabulary."""
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        """Assign ``word`` the next free index, or bump its count if already known."""
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    def trim(self, min_count):
        """Remove words seen fewer than ``min_count`` times and rebuild the indices.

        Only the first call has an effect; later calls return immediately.
        Note that surviving words restart with a count of 1.
        """
        if self.trimmed:
            return
        self.trimmed = True

        keep_words = [k for k, v in self.word2count.items() if v >= min_count]

        total = len(self.word2index)
        # Avoid ZeroDivisionError when trimming an empty vocabulary.
        ratio = len(keep_words) / total if total else 0.0
        print('keep_words {} / {} = {:.4f}'.format(len(keep_words), total, ratio))

        # Reinitialize dictionaries
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3
        for word in keep_words:
            self.addWord(word)


def normalizeString(s):
    """Lowercase and clean ``s`` so it can be split on single spaces.

    Accented letters are folded to ASCII, each of ``.!?`` becomes its own
    token, any other run of non-letters becomes one space, and the ends are
    trimmed so ``indexesFromSentence`` never sees an empty token.
    """
    s = s.lower()
    # Fold accented letters to their base letter (e.g. "é" -> "e") instead of
    # letting the regex below turn them into word breaks.
    s = ''.join(c for c in unicodedata.normalize('NFD', s)
                if unicodedata.category(c) != 'Mn')
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    # Without strip(), leading/trailing spaces make split(' ') yield '' tokens,
    # which then fail the vocabulary lookup with KeyError.
    return s.strip()


def indexesFromSentence(voc, sentence):
    """Map a normalized sentence to word indices followed by EOS_token.

    Raises KeyError for any word missing from ``voc``.
    """
    return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]


class EncoderRNN(nn.Module):
    """Bidirectional GRU encoder whose forward/backward outputs are summed."""

    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding
        # nn.GRU applies dropout only between stacked layers, so use 0 for one layer.
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout), bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        """Encode a padded batch of index sequences.

        Args:
            input_seq: LongTensor of word indices laid out (seq_len, batch),
                as built by ``evaluate``.
            input_lengths: true sequence lengths, passed to pack_padded_sequence.
            hidden: optional initial GRU hidden state.

        Returns:
            (outputs, hidden); ``outputs`` has the two directions summed, so
            its last dimension is ``hidden_size``.
        """
        embedded = self.embedding(input_seq)
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        outputs, hidden = self.gru(packed, hidden)
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
        # Sum the forward and backward halves: last dim 2*H -> H.
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        return outputs, hidden


class Attn(torch.nn.Module):
    """Luong attention layer with 'dot', 'general' or 'concat' scoring."""

    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = torch.nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = torch.nn.Linear(self.hidden_size * 2, hidden_size)
            # torch.FloatTensor(n) is uninitialized memory; it must be filled
            # (e.g. by load_state_dict) before this layer is used.
            self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size))

    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1),
                                      encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        """Return softmax attention weights of shape (batch, 1, seq_len).

        Assumes ``encoder_outputs`` is (seq_len, batch, hidden_size), as
        produced by EncoderRNN.
        """
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)
        # (seq_len, batch) -> (batch, seq_len)
        attn_energies = attn_energies.t()
        return F.softmax(attn_energies, dim=1).unsqueeze(1)


class LuongAttnDecoderRNN(nn.Module):
    """GRU decoder with Luong attention; runs one time step per call."""

    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout))
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_step, last_hidden, encoder_outputs):
        """Decode a single step.

        Args:
            input_step: word indices for this step, presumably (1, batch).
            last_hidden: previous GRU hidden state.
            encoder_outputs: encoder outputs to attend over.

        Returns:
            (output, hidden) where ``output`` holds softmax probabilities over
            the ``output_size`` vocabulary for each batch element.
        """
        embedded = self.embedding(input_step)
        embedded = self.embedding_dropout(embedded)
        rnn_output, hidden = self.gru(embedded, last_hidden)
        attn_weights = self.attn(rnn_output, encoder_outputs)
        # Weighted sum of encoder outputs: (batch, 1, seq) x (batch, seq, H).
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        output = self.out(concat_output)
        output = F.softmax(output, dim=1)
        return output, hidden


class GreedySearchDecoder(torch.jit.ScriptModule):
    """TorchScript greedy decoder for a single input sentence (batch size 1).

    Always runs exactly ``max_length`` decoding steps; it does not stop at EOS.
    """

    def __init__(self, encoder, decoder, decoder_n_layers):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self._device = device
        self._SOS_token = SOS_token
        self._decoder_n_layers = decoder_n_layers

    __constants__ = ['_device', '_SOS_token', '_decoder_n_layers']

    @torch.jit.script_method
    def forward(self, input_seq: torch.Tensor, input_length: torch.Tensor, max_length: int):
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        # Seed the decoder with the first decoder_n_layers encoder hidden layers.
        decoder_hidden = encoder_hidden[:self._decoder_n_layers]
        decoder_input = torch.ones(1, 1, device=self._device, dtype=torch.long) * self._SOS_token
        all_tokens = torch.zeros([0], device=self._device, dtype=torch.long)
        all_scores = torch.zeros([0], device=self._device)
        for _ in range(max_length):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            # Pick the most probable token and record it with its probability.
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            decoder_input = torch.unsqueeze(decoder_input, 0)
        return all_tokens, all_scores


def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
    """Greedily decode a reply to an already-normalized ``sentence``.

    ``encoder`` and ``decoder`` are unused here because ``searcher`` already
    wraps them; they are kept so existing call sites keep working.

    Returns:
        list of decoded words (may contain 'EOS' / 'PAD' entries).

    Raises:
        KeyError: if ``sentence`` contains a word unknown to ``voc``.
    """
    # words -> indexes, as a batch of one sentence
    indexes_batch = [indexesFromSentence(voc, sentence)]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    # (batch, seq_len) -> (seq_len, batch), the layout pack_padded_sequence expects
    input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
    input_batch = input_batch.to(device)
    lengths = lengths.to(device)
    tokens, scores = searcher(input_batch, lengths, max_length)
    # indexes -> words
    decoded_words = [voc.index2word[token.item()] for token in tokens]
    return decoded_words


def evaluateInput(encoder, decoder, searcher, voc):
    """Interactive chat loop on stdin; exits on 'q', 'quit' or end of input."""
    while True:
        try:
            input_sentence = input('> ')
            # Check if it is quit case
            if input_sentence == 'q' or input_sentence == 'quit':
                break
            input_sentence = normalizeString(input_sentence)
            output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
            output_words[:] = [x for x in output_words if x not in ('EOS', 'PAD')]
            print('Bot:', ' '.join(output_words))
        except KeyError:
            print("Error: Encountered unknown word.")
        except EOFError:
            # stdin closed (Ctrl-D / end of piped input): stop instead of crashing.
            break


def evaluateExample(sentence, encoder, decoder, searcher, voc):
    """Echo ``sentence`` and print the bot's greedy reply to it."""
    print("> " + sentence)
    input_sentence = normalizeString(sentence)
    output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
    output_words[:] = [x for x in output_words if x not in ('EOS', 'PAD')]
    print('Bot:', ' '.join(output_words))


save_dir = os.path.join("data", "save")
# Model configuration. These sizes must match the architecture stored in the
# checkpoint, because its state dicts are loaded into modules built from them.
corpus_name = "cornell movie-dialogs corpus"
model_name = 'cb_model'
attn_model = 'dot'  # alternatives: 'general', 'concat'
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
batch_size = 64
checkpoint_iter = 4000
# Alternative checkpoint path built from the configuration, kept for reference:
# loadFilename = os.path.join(save_dir, model_name, corpus_name,
#                             '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size),
#                             '{}_checkpoint.tar'.format(checkpoint_iter))
loadFilename = 'data/4000_checkpoint.tar'

# map_location loads every tensor onto the CPU regardless of where it was saved.
# NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
encoder_sd = checkpoint['en']
decoder_sd = checkpoint['de']
encoder_optimizer_sd = checkpoint['en_opt']
decoder_optimizer_sd = checkpoint['de_opt']
embedding_sd = checkpoint['embedding']

# Restore the vocabulary by replacing the instance state wholesale.
voc = Voc(corpus_name)
voc.__dict__ = checkpoint['voc_dict']

print('Building encoder and decoder ...')
# Rebuild the modules with the configured sizes, then load the trained weights.
# The same embedding instance is shared by the encoder and the decoder.
embedding = nn.Embedding(voc.num_words, hidden_size)
embedding.load_state_dict(embedding_sd)
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
encoder.load_state_dict(encoder_sd)
decoder.load_state_dict(decoder_sd)

encoder = encoder.to(device)
decoder = decoder.to(device)

# eval() switches dropout layers to inference mode before tracing.
encoder.eval()
decoder.eval()
print('Models built and ready to go!')
# Expected console output so far:
#   Building encoder and decoder
#   Models built and ready to go!

# Trace the encoder using a random example sequence of MAX_LENGTH tokens.
test_seq = torch.LongTensor(MAX_LENGTH, 1).random_(0, voc.num_words).to(device)
test_seq_length = torch.LongTensor([test_seq.size()[0]]).to(device)
traced_encoder = torch.jit.trace(encoder, (test_seq, test_seq_length))

# Trace the decoder using the traced encoder's outputs as example inputs.
test_encoder_outputs, test_encoder_hidden = traced_encoder(test_seq, test_seq_length)
test_decoder_hidden = test_encoder_hidden[:decoder.n_layers]
# Put the example input on the same device as the other trace inputs.
test_decoder_input = torch.LongTensor(1, 1).random_(0, voc.num_words).to(device)
traced_decoder = torch.jit.trace(decoder, (test_decoder_input, test_decoder_hidden, test_encoder_outputs))

# GreedySearchDecoder is a torch.jit.ScriptModule wrapping the traced modules.
scripted_searcher = GreedySearchDecoder(traced_encoder, traced_decoder, decoder.n_layers)

sentences = ["hello", "what's up?", "who are you?", "where am I?", "where are you from?",
             "Are you ok?", "Do you know about China", "Are you foolish?"]
for s in sentences:
    evaluateExample(s, traced_encoder, traced_decoder, scripted_searcher, voc)

# For an interactive session ('q' or 'quit' to exit), use instead:
# evaluateInput(traced_encoder, traced_decoder, scripted_searcher, voc)
运行结果:
> hello
Bot: hello .
> what's up?
Bot: i m going to get my car .
> who are you?
Bot: i m the owner .
> where am I?
Bot: in the house .
> where are you from?
Bot: south america .
> Are you ok?
Bot: i m fine .
> Do you know about China
Bot: i know .
> Are you foolish?
Bot: yes .
进程已结束,退出代码 0
这篇关于《PyTorch》Part6 PyTorch之seq2seq的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!