本文主要是介绍pytorch时空数据处理4——图像转文本/字幕Image-Captionning(二),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
pytorch时空数据处理4——图像转文本/字幕Image-Captionning(二)
- pytorch时空数据处理4——图像转文本/字幕Image-Captionning(二)
- Dataset
- Inputs to model
- Caption Lengths
- Data pipeline
- Encoder
- Attention
- Decoder
- 代码
- 数据集初始化 create_input_files.py
- 训练 train.py
- 测试
pytorch时空数据处理4——图像转文本/字幕Image-Captionning(二)
书接上文,本篇主要讲解工程代码结构和代码运行。
代码来源:git
Dataset
我正在使用MSCOCO '14数据集。您需要下载训练(13GB)和验证(6GB)。
我们将使用安德烈·卡帕西的训练、验证和测试分割方法。这个压缩文件包含标题。您还可以找到FlushT 8K和FlushT 30K数据集的拆分和标题,所以如果MSCOCO对您的计算机来说太大,请随意使用它们来代替MSCOCO。
Inputs to model
图像由于我们使用的是预处理编码器,我们需要将图像处理成预处理编码器习惯的形式。
预训练的ImageNet模型可作为PyTorch的torchvision模块的一部分。论文原文详细说明了我们需要执行的预处理或转换——像素值必须在[0,1]范围内,然后我们必须通过ImageNet图像的RGB通道的平均值和标准偏差对图像进行归一化。
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
此外,PyTorch遵循NCHW惯例,这意味着通道尺寸©必须在尺寸尺寸之前。我们将调整所有MSCOCO图像的大小为256x256,以保持一致性。因此,馈送到模型的图像必须是维度为N,3,256,256的浮动张量,并且必须通过前述的平均值和标准偏差进行归一化。n为批量大小。字幕字幕是解码器的目标和输入,因为每个单词都用来生成下一个单词。然而,要生成第一个单词,我们需要第零个单词< start >。最后,我们应该预测解码器必须学会预测字幕的结束。这是必要的,因为我们需要知道在推理过程中什么时候停止解码。
例如:<start> a man holds a football <end>
因为我们将标题作为固定大小的张量传递,所以我们需要用< pad >标记将标题(自然长度可变)填充到相同的长度。
<start> a man holds a football <end> <pad> <pad> <pad>…
此外,我们创建一个word_map,它是语料库中每个单词的索引映射,包括<start>,<end>和<pad>标记。像其他库一样,PyTorch也需要编码为索引的单词来为其查找嵌入或标识其在预测单词分数中的位置。
例如:9876 1 5 120 1 5406 9877 9878 9878 9878…
因此,提供给模型的字幕必须是尺寸为N,L的Int张量,其中L是填充长度。
Caption Lengths
由于字幕是填充的,因此我们需要跟踪每个字幕的长度。这是实际长度+ 2(对于和标记)。
字幕长度也很重要,因为您可以使用PyTorch构建动态图形。我们仅处理序列的长度,并且不会在上浪费计算量。
因此,提供给模型的字幕长度必须是维度N的Int张量。
Data pipeline
请参阅utils.py中的create_input_files()。
- 这将读取下载的数据并保存以下文件–一个HDF5文件,该文件包含I,3、256、256张量中每个分割的图像,其中I是分割中的图像数。像素值仍在[0,255]范围内,并存储为无符号8位Ints。
- 每个分割的JSON文件,其中包含N_c *I编码的字幕列表,其中N_c是每个图像采样的字幕数量。这些标题与HDF5文件中的图像的顺序相同。因此,第i个标题将对应于第i //N_cth个图像。
- 每个分割的JSON文件,其中包含N_c * I字幕长度列表。 ith值是ith标题的长度,它对应于i // N_cth图像。
- 一个包含word_map(单词到索引的字典)的JSON文件。
在我们保存这些文件,我们可以选择只使用字幕是短于阈值,并且仓不太频繁的话到标记。
我们将HDF5文件用于图像,因为我们将在训练/验证期间直接从磁盘读取它们。它们太大了,无法一次放入RAM。但是我们确实将所有字幕及其长度加载到内存中。
请参阅datasets.py中的CaptionDataset。
这是PyTorch数据集的子类。它需要定义一个__len__方法,该方法返回数据集的大小,以及一个__getitem__方法,该方法返回第i个图像,标题和标题长度。
我们从磁盘读取图像,将像素转换为[0,255],然后在此类内对其进行规范化。
PyTorch DataLoader在train.py中将使用该数据集,以创建一批数据并将其馈送到模型中以进行训练或验证。
Encoder
请参阅models.py中的编码器。
我们使用PyTorch的Torchvision模块中已经提供的经过预训练的ResNet-101。丢弃最后两层(池化层和线性层),因为我们只需要对图像进行编码,而无需对其进行分类。
我们确实添加了AdaptiveAvgPool2d()层,以将编码大小调整为固定大小。这样就可以将可变大小的图像馈送到编码器。 (但是,我们确实将输入图像的大小调整为256、256,因为我们必须将它们存储为单个张量。)由于我们可能想对编码器进行微调,因此我们添加了fine_tune()方法来启用或禁用计算编码器参数的梯度。我们仅在ResNet中微调卷积块2到4,因为第一个卷积块通常会学到一些非常重要的图像处理基础知识,例如检测直线,边缘,曲线等。我们不会打乱基准特征。
Attention
请参阅models.py中的Attention。
注意网络很简单–它仅由线性层和几个激活组成。
单独的线性层将解码器的编码图像(展平为N,14 * 14,2048)和隐藏状态(输出)都转换为相同尺寸,即。注意大小。然后添加它们并激活ReLU。第三线性层将此结果转换为1的维度,随后我们应用softmax生成权重alpha。
Decoder
请参阅models.py中的DecoderWithAttention。
此处接收编码器的输出,并将其展平为N,14 * 14,2048尺寸。这很方便,并且避免了多次调整张量的形状。
我们使用init_hidden_state()方法使用编码图像初始化LSTM的隐藏状态和单元状态,该方法使用两个单独的线性层。
首先,我们通过减少字幕长度来对N个图像和字幕进行排序。这样一来,我们只能处理有效的时间步,即不能处理。
我们可以遍历每个时间步,仅处理有色区域,该区域是该时间步的有效批次大小N_t。通过排序,可以使任何时间步长的顶部N_t与上一步的输出对齐。例如,在第三时间步,我们使用上一步的前5个输出仅处理前5个图像。
使用PyTorch LSTMCell在for循环中手动执行此迭代,而不是使用PyTorch LSTM在没有循环的情况下自动迭代。这是因为我们需要在每个解码步骤之间执行Attention机制。 LSTMCell是单个时间步操作,而LSTM将连续地在多个时间步上迭代并立即提供所有输出。
我们使用Attention网络在每个时间步计算权重和注意力加权编码。在论文的第4.2.1节中,他们建议通过滤波器或门传递注意力加权编码。此门是解码器先前隐藏状态的S型激活线性变换。作者指出,这有助于Attention网络将更多的重点放在图像中的对象上。
我们将过滤后的注意力加权编码与上一个单词的嵌入(开始)连接起来,然后运行LSTMCell生成新的隐藏状态(或输出)。线性层将这种新的隐藏状态转换为词汇表中每个单词的分数,并将其存储起来。
我们还存储每个时间步长的注意力网络返回的权重。您会很快明白为什么。
代码
数据集初始化 create_input_files.py
为了快速训练和跑通
选择使用flickr8k数据集,样本量较小,调整和训练都节约时间。
from utils import create_input_filesif __name__ == '__main__':# Create input files (along with word map)create_input_files(dataset='flickr8k',karpathy_json_path='/home/wy/docker/resource/cocodataset/dataset_flickr8k.json',image_folder='/home/wy/docker/resource/cocodataset/Flickr8k/Flicker8k_Dataset',captions_per_image=5,min_word_freq=5,output_folder='/home/wy/docker/resource/cocodataset/Flickr8k/data',max_len=50)
训练 train.py
原文代码pytorch0.4 这里的代码是修改后可以在 pytorch1.0以后版本可以运行的。
import time
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
from models import Encoder, DecoderWithAttention
from datasets import *
from utils import *
from nltk.translate.bleu_score import corpus_bleu
torch.cuda.set_device(9)# Data parameters
data_folder = '/home/wy/docker/resource/cocodataset/Flickr8k/data' # folder with data files saved by create_input_files.py
data_name = 'flickr8k_5_cap_per_img_5_min_word_freq' # base name shared by data files# Model parameters
emb_dim = 512 # dimension of word embeddings
attention_dim = 512 # dimension of attention linear layers
decoder_dim = 512 # dimension of decoder RNN
dropout = 0.5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # sets device for model and PyTorch tensors
cudnn.benchmark = True # set to true only if inputs to model are fixed size; otherwise lot of computational overhead# Training parameters
start_epoch = 0
epochs = 10 # number of epochs to train for (if early stopping is not triggered)
epochs_since_improvement = 0 # keeps track of number of epochs since there's been an improvement in validation BLEU
batch_size = 32
workers = 1 # for data-loading; right now, only 1 works with h5py
encoder_lr = 1e-4 # learning rate for encoder if fine-tuning
decoder_lr = 4e-4 # learning rate for decoder
grad_clip = 5. # clip gradients at an absolute value of
alpha_c = 1. # regularization parameter for 'doubly stochastic attention', as in the paper
best_bleu4 = 0. # BLEU-4 score right now
print_freq = 100 # print training/validation stats every __ batches
fine_tune_encoder = False # fine-tune encoder?
checkpoint = None # path to checkpoint, None if nonedef main():"""Training and validation."""global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map# Read word mapword_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')with open(word_map_file, 'r') as j:word_map = json.load(j)# Initialize / load checkpointif checkpoint is None:decoder = DecoderWithAttention(attention_dim=attention_dim,embed_dim=emb_dim,decoder_dim=decoder_dim,vocab_size=len(word_map),dropout=dropout)decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),lr=decoder_lr)encoder = Encoder()encoder.fine_tune(fine_tune_encoder)encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),lr=encoder_lr) if fine_tune_encoder else Noneelse:checkpoint = torch.load(checkpoint)start_epoch = checkpoint['epoch'] + 1epochs_since_improvement = checkpoint['epochs_since_improvement']best_bleu4 = checkpoint['bleu-4']decoder = checkpoint['decoder']decoder_optimizer = checkpoint['decoder_optimizer']encoder = checkpoint['encoder']encoder_optimizer = checkpoint['encoder_optimizer']if fine_tune_encoder is True and encoder_optimizer is None:encoder.fine_tune(fine_tune_encoder)encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),lr=encoder_lr)# Move to GPU, if availabledecoder = decoder.to(device)encoder = encoder.to(device)# Loss functioncriterion = nn.CrossEntropyLoss().to(device)# Custom dataloadersnormalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])train_loader = torch.utils.data.DataLoader(CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)val_loader = torch.utils.data.DataLoader(CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])),batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)# Epochsfor epoch in range(start_epoch, epochs):# Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20if epochs_since_improvement == 20:breakif epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:adjust_learning_rate(decoder_optimizer, 0.8)if fine_tune_encoder:adjust_learning_rate(encoder_optimizer, 0.8)# One epoch's trainingtrain(train_loader=train_loader,encoder=encoder,decoder=decoder,criterion=criterion,encoder_optimizer=encoder_optimizer,decoder_optimizer=decoder_optimizer,epoch=epoch)# One epoch's validationrecent_bleu4 = validate(val_loader=val_loader,encoder=encoder,decoder=decoder,criterion=criterion)# Check if there was an improvementis_best = recent_bleu4 > best_bleu4best_bleu4 = max(recent_bleu4, best_bleu4)if not is_best:epochs_since_improvement += 1print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))else:epochs_since_improvement = 0# Save checkpoint 保存在utils包里面,全部倒导入的。save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer,decoder_optimizer, recent_bleu4, is_best)def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epoch):"""Performs one epoch's training.:param train_loader: DataLoader for training data:param encoder: encoder model:param decoder: decoder model:param criterion: loss layer:param encoder_optimizer: optimizer to update encoder's weights (if fine-tuning):param decoder_optimizer: optimizer to update decoder's weights:param epoch: epoch number"""decoder.train() # train mode (dropout and batchnorm is used)encoder.train()#utils 方法 AverageMeterbatch_time = AverageMeter() # forward prop. + back prop. timedata_time = AverageMeter() # data loading timelosses = AverageMeter() # loss (per word decoded)top5accs = AverageMeter() # top5 accuracystart = time.time()# Batchesfor i, (imgs, caps, caplens) in enumerate(train_loader):data_time.update(time.time() - start)# Move to GPU, if availableimgs = imgs.to(device)caps = caps.to(device)caplens = caplens.to(device)# Forward prop.imgs = encoder(imgs)scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)# Since we decoded starting with <start>, the targets are all words after <start>, up to <end>targets = caps_sorted[:, 1:]# Remove timesteps that we didn't decode at, or are pads# pack_padded_sequence is an easy trick to do thisscores = pack_padded_sequence(scores, decode_lengths, batch_first=True).datatargets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data# Calculate lossloss = criterion(scores, targets)# Add doubly stochastic attention regularizationloss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()# Back prop.decoder_optimizer.zero_grad()if encoder_optimizer is not None:encoder_optimizer.zero_grad()loss.backward()# Clip gradientsif grad_clip is not None:clip_gradient(decoder_optimizer, grad_clip)if encoder_optimizer is not None:clip_gradient(encoder_optimizer, grad_clip)# Update weightsdecoder_optimizer.step()if encoder_optimizer is not None:encoder_optimizer.step()# Keep track of metricstop5 = accuracy(scores, targets, 5)losses.update(loss.item(), sum(decode_lengths))top5accs.update(top5, sum(decode_lengths))batch_time.update(time.time() - start)start = time.time()# Print statusif i % print_freq == 0:print('Epoch: [{0}][{1}/{2}]\t''Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t''Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t''Loss {loss.val:.4f} ({loss.avg:.4f})\t''Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, i, len(train_loader),batch_time=batch_time,data_time=data_time, loss=losses,top5=top5accs))def validate(val_loader, encoder, decoder, criterion):"""Performs one epoch's validation.:param val_loader: DataLoader for validation data.:param encoder: encoder model:param decoder: decoder model:param criterion: loss layer:return: BLEU-4 score"""decoder.eval() # eval mode (no dropout or batchnorm)if encoder is not None:encoder.eval()batch_time = AverageMeter()losses = AverageMeter()top5accs = AverageMeter()start = time.time()references = list() # references (true captions) for calculating BLEU-4 scorehypotheses = list() # hypotheses (predictions)# explicitly disable gradient calculation to avoid CUDA memory error# solves the issue #57with torch.no_grad():# Batchesfor i, (imgs, caps, caplens, allcaps) in enumerate(val_loader):# Move to device, if availableimgs = imgs.to(device)caps = caps.to(device)caplens = caplens.to(device)# Forward prop.if encoder is not None:imgs = encoder(imgs)scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)# Since we decoded starting with <start>, the targets are all words after <start>, up to <end>targets = caps_sorted[:, 1:]# Remove timesteps that we didn't decode at, or are pads# pack_padded_sequence is an easy trick to do thisscores_copy = scores.clone()scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).datatargets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data# Calculate lossloss = criterion(scores, targets)# Add doubly stochastic attention regularizationloss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()# Keep track of metricslosses.update(loss.item(), sum(decode_lengths))top5 = accuracy(scores, targets, 5)top5accs.update(top5, sum(decode_lengths))batch_time.update(time.time() - start)start = time.time()if i % print_freq == 0:print('Validation: [{0}/{1}]\t''Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t''Loss {loss.val:.4f} ({loss.avg:.4f})\t''Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time,loss=losses, top5=top5accs))# Store references (true captions), and hypothesis (prediction) for each image# If for n images, we have n hypotheses, and references a, b, c... for each image, we need -# references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]# Referencesallcaps = allcaps[sort_ind] # because images were sorted in the decoderfor j in range(allcaps.shape[0]):img_caps = allcaps[j].tolist()img_captions = list(map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<pad>']}],img_caps)) # remove <start> and padsreferences.append(img_captions)# Hypotheses_, preds = torch.max(scores_copy, dim=2)preds = preds.tolist()temp_preds = list()for j, p in enumerate(preds):temp_preds.append(preds[j][:decode_lengths[j]]) # remove padspreds = temp_predshypotheses.extend(preds)assert len(references) == len(hypotheses)# Calculate BLEU-4 scoresbleu4 = corpus_bleu(references, hypotheses)print('\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format(loss=losses,top5=top5accs,bleu=bleu4))return bleu4if __name__ == '__main__':main()
测试
使用了原git上提供的预训练模型
import torch
import torch.nn.functional as F
import numpy as np
import json
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import skimage.transform
import argparse
from scipy.misc import imread, imresize
from PIL import Imagetorch.cuda.set_device(9)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_size=3):"""Reads an image and captions it with beam search.:param encoder: encoder model:param decoder: decoder model:param image_path: path to image:param word_map: word map:param beam_size: number of sequences to consider at each decode-step:return: caption, weights for visualization"""k = beam_sizevocab_size = len(word_map)# Read image and processimg = imread(image_path)#当为单通道图像时,转化为三通道if len(img.shape) == 2:img = img[:, :, np.newaxis] #增加纬度img = np.concatenate([img, img, img], axis=2) #拼接为三通道img = imresize(img, (256, 256))img = img.transpose(2, 0, 1)#矩阵转置 通道数放在前面img = img / 255.img = torch.FloatTensor(img).to(device)normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])transform = transforms.Compose([normalize])image = transform(img) # (3, 256, 256)# Encodeimage = image.unsqueeze(0) # (1, 3, 256, 256)encoder_out = encoder(image) # (1, enc_image_size, enc_image_size, encoder_dim) 1,14,14,2048enc_image_size = encoder_out.size(1)print('enc_image_size:',enc_image_size)encoder_dim = encoder_out.size(3)print('encoder_dim:',encoder_dim)# Flatten encodingencoder_out = encoder_out.view(1, -1, encoder_dim) # (1, num_pixels, encoder_dim) 1,196,2048#表示了图像的196个区域各自的特征# print('encoder_out:',encoder_out)num_pixels = encoder_out.size(1)#第二位 196 #print('num_pixels:',num_pixels)# We'll treat the problem as having a batch size of k#print(encoder_out.size())encoder_out = encoder_out.expand(k, num_pixels, encoder_dim) # (k, num_pixels, encoder_dim)1->k纬度扩展,五份特征#print(encoder_out.size())# Tensor to store top k previous words at each step; now they're just <start>k_prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device) # (k, 1)#print('k_prev_words:',k_prev_words)# Tensor to store top k sequences; now they're just <start>seqs = k_prev_words # (k, 1)# Tensor to store top k sequences' scores; now they're just 0top_k_scores = torch.zeros(k, 1).to(device) # (k, 1)# Tensor to store top k sequences' alphas; now they're just 1s 这里其实就是存储每个字对应图像上的关注区域,映射在14*14的张量上面seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(device) # (k, 1, enc_image_size, enc_image_size)# Lists to store completed sequences, their alphas and scorescomplete_seqs = list()complete_seqs_alpha = list()complete_seqs_scores = list()# Start decodingstep = 1h, c = decoder.init_hidden_state(encoder_out)#h0print('h, c',h.size(),c.size())# s is a number less than or equal to k, because sequences are removed from this process once they hit <end>while True:embeddings = decoder.embedding(k_prev_words).squeeze(1) # (s, embed_dim) (5,隐层512)print('embeddings',embeddings.size())#encode的图片表示 和 隐状态awe, alpha = decoder.attention(encoder_out, h) # (s, encoder_dim), (s, num_pixels)(5,2048(),5,196(attention 存储字对应图像各部分的权重))print(' awe, alpha',awe.size(),alpha.size())#0/0alpha = alpha.view(-1, enc_image_size, enc_image_size) # (s, enc_image_size, enc_image_size)(5,14,14)gate = decoder.sigmoid(decoder.f_beta(h)) # gating scalar, (s, encoder_dim)awe = gate * awe#给特征赋予权重h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c)) # (s, decoder_dim)输入(512,2048),(512,512)带权重的特征和上一次的lstm输出和细胞状态值scores = decoder.fc(h) # (s, vocab_size)scores = F.log_softmax(scores, dim=1)print('scores',scores.size())# Add 每一句 含有多少词 更新scores = top_k_scores.expand_as(scores) + scores # (s, vocab_size)print('top_k_scores,scores',top_k_scores.size(),scores.size())# For the first step, all k points will have the same scores (since same k previous words, h, c)if step == 1:top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s)else:# Unroll and find top scores, and their unrolled indicestop_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s) 取词,topprint('top_k_scores,top_k_words',top_k_scores.size(),top_k_words.size())# Convert unrolled indices to actual indices of scoresprev_word_inds = torch.floor_divide(top_k_words, vocab_size)#prev_word_inds = top_k_words / vocab_size # (s)next_word_inds = top_k_words % vocab_size # (s)print('top_k_scores,top_k_words,prev_word_inds,next_word_inds',top_k_words,top_k_scores,prev_word_inds,next_word_inds)# Add new words to sequences, alphasseqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1)#词加一seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)], #词对应图像区域加一dim=1) # (s, step+1, enc_image_size, enc_image_size)# Which sequences are incomplete (didn't reach <end>)? 挑出这次循环完结的 句子incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) ifnext_word != word_map['<end>']]complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))# Set aside complete sequences 挑出完整序列if len(complete_inds) > 0:complete_seqs.extend(seqs[complete_inds].tolist()) #追加全部序列complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist())complete_seqs_scores.extend(top_k_scores[complete_inds])k -= len(complete_inds) # reduce beam length accordingly# Proceed with incomplete sequencesif k == 0:break#更新参数 只保留未完全序列参数seqs = seqs[incomplete_inds]seqs_alpha = seqs_alpha[incomplete_inds]h = h[prev_word_inds[incomplete_inds]]c = c[prev_word_inds[incomplete_inds]]encoder_out = encoder_out[prev_word_inds[incomplete_inds]]top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)# Break if things have been going on too longif step > 50:breakstep += 1#标记 scores分数最高序列作为返回值。i = complete_seqs_scores.index(max(complete_seqs_scores))seq = complete_seqs[i]alphas = complete_seqs_alpha[i]return seq, alphasdef visualize_att(image_path, seq, alphas, rev_word_map, smooth=True):"""Visualizes caption with weights at every word.Adapted from paper authors' repo: https://github.com/kelvinxu/arctic-captions/blob/master/alpha_visualization.ipynb:param image_path: path to image that has been captioned:param seq: caption:param alphas: weights:param rev_word_map: reverse word mapping, i.e. ix2word:param smooth: smooth weights?"""image = Image.open(image_path)image = image.resize([14 * 12, 14 * 12], Image.LANCZOS)words = [rev_word_map[ind] for ind in seq]print(words)for t in range(len(words)):if t > 50:breakplt.subplot(np.ceil(len(words) / 5.), 5, t + 1)plt.text(0, 1, '%s' % (words[t]), color='black', backgroundcolor='white', fontsize=12)plt.imshow(image)current_alpha = alphas[t, :]if smooth:alpha = skimage.transform.pyramid_expand(current_alpha.numpy(), upscale=12, sigma=8)else:alpha = skimage.transform.resize(current_alpha.numpy(), [14 * 12, 14 * 12])if t == 0:plt.imshow(alpha, alpha=0)else:plt.imshow(alpha, alpha=0.8)plt.set_cmap(cm.Greys_r)plt.axis('off')plt.show()import scipyprint(scipy.__version__)
checkpoint = torch.load('./BEST_checkpoint_coco_5_cap_per_img_5_min_word_freq.pth.tar', map_location=str(device))
decoder = checkpoint['decoder']
decoder = decoder.to(device)
decoder.eval()
encoder = checkpoint['encoder']
encoder = encoder.to(device)
encoder.eval()# Load word map (word2ix)
with open('./WORDMAP_coco_5_cap_per_img_5_min_word_freq.json', 'r') as j:word_map = json.load(j)
rev_word_map = {v: k for k, v in word_map.items()} # ix2word
# Encode, decode with attention and beam search
seq, alphas = caption_image_beam_search(encoder, decoder,'img/q.jpg', word_map,5)
alphas = torch.FloatTensor(alphas)# Visualize caption and attention of best sequence
visualize_att('img/q.jpg', seq, alphas, rev_word_map,True)
这里使用了齐天大圣作为测试图片,输出很有趣,一个长头发的女人在看着相机。
更多细节请查看原文和git。
这篇关于pytorch时空数据处理4——图像转文本/字幕Image-Captionning(二)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!