pytorch时空数据处理4——图像转文本/字幕Image-Captionning(二)

本文主要是介绍pytorch时空数据处理4——图像转文本/字幕Image-Captionning(二),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

pytorch时空数据处理4——图像转文本/字幕Image-Captionning(二)

  • pytorch时空数据处理4——图像转文本/字幕Image-Captionning(二)
    • Dataset
    • Inputs to model
    • Caption Lengths
    • Data pipeline
    • Encoder
    • Attention
    • Decoder
    • 代码
      • 数据集初始化 create_input_files.py
      • 训练 train.py
      • 测试

pytorch时空数据处理4——图像转文本/字幕Image-Captionning(二)

书接上文,本篇主要讲解工程代码结构和代码运行。
代码来源:git

Dataset

我正在使用MSCOCO '14数据集。您需要下载训练(13GB)和验证(6GB)。
我们将使用安德烈·卡帕西的训练、验证和测试分割方法。这个压缩文件包含标题。您还可以找到FlushT 8K和FlushT 30K数据集的拆分和标题,所以如果MSCOCO对您的计算机来说太大,请随意使用它们来代替MSCOCO。

Inputs to model

图像由于我们使用的是预处理编码器,我们需要将图像处理成预处理编码器习惯的形式。
预训练的ImageNet模型可作为PyTorch的torchvision模块的一部分。论文原文详细说明了我们需要执行的预处理或转换——像素值必须在[0,1]范围内,然后我们必须通过ImageNet图像的RGB通道的平均值和标准偏差对图像进行归一化。

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

此外,PyTorch遵循NCHW惯例,这意味着通道尺寸©必须在尺寸尺寸之前。我们将调整所有MSCOCO图像的大小为256x256,以保持一致性。因此,馈送到模型的图像必须是维度为N,3,256,256的浮动张量,并且必须通过前述的平均值和标准偏差进行归一化。n为批量大小。字幕字幕是解码器的目标和输入,因为每个单词都用来生成下一个单词。然而,要生成第一个单词,我们需要第零个单词< start >。最后,我们应该预测解码器必须学会预测字幕的结束。这是必要的,因为我们需要知道在推理过程中什么时候停止解码。
例如:<start> a man holds a football <end>
因为我们将标题作为固定大小的张量传递,所以我们需要用< pad >标记将标题(自然长度可变)填充到相同的长度。
<start> a man holds a football <end> <pad> <pad> <pad>…
此外,我们创建一个word_map,它是语料库中每个单词的索引映射,包括<start>,<end>和<pad>标记。像其他库一样,PyTorch也需要编码为索引的单词来为其查找嵌入或标识其在预测单词分数中的位置。
例如:9876 1 5 120 1 5406 9877 9878 9878 9878…
因此,提供给模型的字幕必须是尺寸为N,L的Int张量,其中L是填充长度。

Caption Lengths

由于字幕是填充的,因此我们需要跟踪每个字幕的长度。这是实际长度+ 2(对于和标记)。
字幕长度也很重要,因为您可以使用PyTorch构建动态图形。我们仅处理序列的长度,并且不会在上浪费计算量。
因此,提供给模型的字幕长度必须是维度N的Int张量。

Data pipeline

请参阅utils.py中的create_input_files()。

  1. 这将读取下载的数据并保存以下文件–一个HDF5文件,该文件包含I,3、256、256张量中每个分割的图像,其中I是分割中的图像数。像素值仍在[0,255]范围内,并存储为无符号8位Ints。
  2. 每个分割的JSON文件,其中包含N_c *I编码的字幕列表,其中N_c是每个图像采样的字幕数量。这些标题与HDF5文件中的图像的顺序相同。因此,第i个标题将对应于第i //N_cth个图像。
  3. 每个分割的JSON文件,其中包含N_c * I字幕长度列表。 ith值是ith标题的长度,它对应于i // N_cth图像。
  4. 一个包含word_map(单词到索引的字典)的JSON文件。

在我们保存这些文件,我们可以选择只使用字幕是短于阈值,并且仓不太频繁的话到标记。
我们将HDF5文件用于图像,因为我们将在训练/验证期间直接从磁盘读取它们。它们太大了,无法一次放入RAM。但是我们确实将所有字幕及其长度加载到内存中。
请参阅datasets.py中的CaptionDataset。
这是PyTorch数据集的子类。它需要定义一个__len__方法,该方法返回数据集的大小,以及一个__getitem__方法,该方法返回第i个图像,标题和标题长度。
我们从磁盘读取图像,将像素转换为[0,255],然后在此类内对其进行规范化。
PyTorch DataLoader在train.py中将使用该数据集,以创建一批数据并将其馈送到模型中以进行训练或验证。

Encoder

请参阅models.py中的编码器。
我们使用PyTorch的Torchvision模块中已经提供的经过预训练的ResNet-101。丢弃最后两层(池化层和线性层),因为我们只需要对图像进行编码,而无需对其进行分类。
我们确实添加了AdaptiveAvgPool2d()层,以将编码大小调整为固定大小。这样就可以将可变大小的图像馈送到编码器。 (但是,我们确实将输入图像的大小调整为256、256,因为我们必须将它们存储为单个张量。)由于我们可能想对编码器进行微调,因此我们添加了fine_tune()方法来启用或禁用计算编码器参数的梯度。我们仅在ResNet中微调卷积块2到4,因为第一个卷积块通常会学到一些非常重要的图像处理基础知识,例如检测直线,边缘,曲线等。我们不会打乱基准特征。

Attention

请参阅models.py中的Attention。
注意网络很简单–它仅由线性层和几个激活组成。
单独的线性层将解码器的编码图像(展平为N,14 * 14,2048)和隐藏状态(输出)都转换为相同尺寸,即。注意大小。然后添加它们并激活ReLU。第三线性层将此结果转换为1的维度,随后我们应用softmax生成权重alpha。

Decoder

请参阅models.py中的DecoderWithAttention。
此处接收编码器的输出,并将其展平为N,14 * 14,2048尺寸。这很方便,并且避免了多次调整张量的形状。
我们使用init_hidden_​​state()方法使用编码图像初始化LSTM的隐藏状态和单元状态,该方法使用两个单独的线性层。
首先,我们通过减少字幕长度来对N个图像和字幕进行排序。这样一来,我们只能处理有效的时间步,即不能处理。
在这里插入图片描述
我们可以遍历每个时间步,仅处理有色区域,该区域是该时间步的有效批次大小N_t。通过排序,可以使任何时间步长的顶部N_t与上一步的输出对齐。例如,在第三时间步,我们使用上一步的前5个输出仅处理前5个图像。
使用PyTorch LSTMCell在for循环中手动执行此迭代,而不是使用PyTorch LSTM在没有循环的情况下自动迭代。这是因为我们需要在每个解码步骤之间执行Attention机制。 LSTMCell是单个时间步操作,而LSTM将连续地在多个时间步上迭代并立即提供所有输出。
我们使用Attention网络在每个时间步计算权重和注意力加权编码。在论文的第4.2.1节中,他们建议通过滤波器或门传递注意力加权编码。此门是解码器先前隐藏状态的S型激活线性变换。作者指出,这有助于Attention网络将更多的重点放在图像中的对象上。
我们将过滤后的注意力加权编码与上一个单词的嵌入(开始)连接起来,然后运行LSTMCell生成新的隐藏状态(或输出)。线性层将这种新的隐藏状态转换为词汇表中每个单词的分数,并将其存储起来。
我们还存储每个时间步长的注意力网络返回的权重。您会很快明白为什么。

代码

数据集初始化 create_input_files.py

为了快速训练和跑通
选择使用flickr8k数据集,样本量较小,调整和训练都节约时间。

from utils import create_input_filesif __name__ == '__main__':# Create input files (along with word map)create_input_files(dataset='flickr8k',karpathy_json_path='/home/wy/docker/resource/cocodataset/dataset_flickr8k.json',image_folder='/home/wy/docker/resource/cocodataset/Flickr8k/Flicker8k_Dataset',captions_per_image=5,min_word_freq=5,output_folder='/home/wy/docker/resource/cocodataset/Flickr8k/data',max_len=50)

训练 train.py

原文代码pytorch0.4 这里的代码是修改后可以在 pytorch1.0以后版本可以运行的。

import time
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
from models import Encoder, DecoderWithAttention
from datasets import *
from utils import *
from nltk.translate.bleu_score import corpus_bleu
torch.cuda.set_device(9)# Data parameters
data_folder = '/home/wy/docker/resource/cocodataset/Flickr8k/data'  # folder with data files saved by create_input_files.py
data_name = 'flickr8k_5_cap_per_img_5_min_word_freq'  # base name shared by data files# Model parameters
emb_dim = 512  # dimension of word embeddings
attention_dim = 512  # dimension of attention linear layers
decoder_dim = 512  # dimension of decoder RNN
dropout = 0.5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # sets device for model and PyTorch tensors
cudnn.benchmark = True  # set to true only if inputs to model are fixed size; otherwise lot of computational overhead# Training parameters
start_epoch = 0
epochs = 10  # number of epochs to train for (if early stopping is not triggered)
epochs_since_improvement = 0  # keeps track of number of epochs since there's been an improvement in validation BLEU
batch_size = 32
workers = 1  # for data-loading; right now, only 1 works with h5py
encoder_lr = 1e-4  # learning rate for encoder if fine-tuning
decoder_lr = 4e-4  # learning rate for decoder
grad_clip = 5.  # clip gradients at an absolute value of
alpha_c = 1.  # regularization parameter for 'doubly stochastic attention', as in the paper
best_bleu4 = 0.  # BLEU-4 score right now
print_freq = 100  # print training/validation stats every __ batches
fine_tune_encoder = False  # fine-tune encoder?
checkpoint = None  # path to checkpoint, None if nonedef main():"""Training and validation."""global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map# Read word mapword_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')with open(word_map_file, 'r') as j:word_map = json.load(j)# Initialize / load checkpointif checkpoint is None:decoder = DecoderWithAttention(attention_dim=attention_dim,embed_dim=emb_dim,decoder_dim=decoder_dim,vocab_size=len(word_map),dropout=dropout)decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),lr=decoder_lr)encoder = Encoder()encoder.fine_tune(fine_tune_encoder)encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),lr=encoder_lr) if fine_tune_encoder else Noneelse:checkpoint = torch.load(checkpoint)start_epoch = checkpoint['epoch'] + 1epochs_since_improvement = checkpoint['epochs_since_improvement']best_bleu4 = checkpoint['bleu-4']decoder = checkpoint['decoder']decoder_optimizer = checkpoint['decoder_optimizer']encoder = checkpoint['encoder']encoder_optimizer = checkpoint['encoder_optimizer']if fine_tune_encoder is True and encoder_optimizer is None:encoder.fine_tune(fine_tune_encoder)encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),lr=encoder_lr)# Move to GPU, if availabledecoder = decoder.to(device)encoder = encoder.to(device)# Loss functioncriterion = nn.CrossEntropyLoss().to(device)# Custom dataloadersnormalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])train_loader = torch.utils.data.DataLoader(CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)val_loader = torch.utils.data.DataLoader(CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])),batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)# Epochsfor epoch in range(start_epoch, epochs):# Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20if epochs_since_improvement == 20:breakif epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:adjust_learning_rate(decoder_optimizer, 0.8)if fine_tune_encoder:adjust_learning_rate(encoder_optimizer, 0.8)# One epoch's trainingtrain(train_loader=train_loader,encoder=encoder,decoder=decoder,criterion=criterion,encoder_optimizer=encoder_optimizer,decoder_optimizer=decoder_optimizer,epoch=epoch)# One epoch's validationrecent_bleu4 = validate(val_loader=val_loader,encoder=encoder,decoder=decoder,criterion=criterion)# Check if there was an improvementis_best = recent_bleu4 > best_bleu4best_bleu4 = max(recent_bleu4, best_bleu4)if not is_best:epochs_since_improvement += 1print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))else:epochs_since_improvement = 0# Save checkpoint 保存在utils包里面,全部倒导入的。save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer,decoder_optimizer, recent_bleu4, is_best)def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epoch):"""Performs one epoch's training.:param train_loader: DataLoader for training data:param encoder: encoder model:param decoder: decoder model:param criterion: loss layer:param encoder_optimizer: optimizer to update encoder's weights (if fine-tuning):param decoder_optimizer: optimizer to update decoder's weights:param epoch: epoch number"""decoder.train()  # train mode (dropout and batchnorm is used)encoder.train()#utils 方法 AverageMeterbatch_time = AverageMeter()  # forward prop. + back prop. timedata_time = AverageMeter()  # data loading timelosses = AverageMeter()  # loss (per word decoded)top5accs = AverageMeter()  # top5 accuracystart = time.time()# Batchesfor i, (imgs, caps, caplens) in enumerate(train_loader):data_time.update(time.time() - start)# Move to GPU, if availableimgs = imgs.to(device)caps = caps.to(device)caplens = caplens.to(device)# Forward prop.imgs = encoder(imgs)scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)# Since we decoded starting with <start>, the targets are all words after <start>, up to <end>targets = caps_sorted[:, 1:]# Remove timesteps that we didn't decode at, or are pads# pack_padded_sequence is an easy trick to do thisscores = pack_padded_sequence(scores, decode_lengths, batch_first=True).datatargets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data# Calculate lossloss = criterion(scores, targets)# Add doubly stochastic attention regularizationloss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()# Back prop.decoder_optimizer.zero_grad()if encoder_optimizer is not None:encoder_optimizer.zero_grad()loss.backward()# Clip gradientsif grad_clip is not None:clip_gradient(decoder_optimizer, grad_clip)if encoder_optimizer is not None:clip_gradient(encoder_optimizer, grad_clip)# Update weightsdecoder_optimizer.step()if encoder_optimizer is not None:encoder_optimizer.step()# Keep track of metricstop5 = accuracy(scores, targets, 5)losses.update(loss.item(), sum(decode_lengths))top5accs.update(top5, sum(decode_lengths))batch_time.update(time.time() - start)start = time.time()# Print statusif i % print_freq == 0:print('Epoch: [{0}][{1}/{2}]\t''Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t''Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t''Loss {loss.val:.4f} ({loss.avg:.4f})\t''Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, i, len(train_loader),batch_time=batch_time,data_time=data_time, loss=losses,top5=top5accs))def validate(val_loader, encoder, decoder, criterion):"""Performs one epoch's validation.:param val_loader: DataLoader for validation data.:param encoder: encoder model:param decoder: decoder model:param criterion: loss layer:return: BLEU-4 score"""decoder.eval()  # eval mode (no dropout or batchnorm)if encoder is not None:encoder.eval()batch_time = AverageMeter()losses = AverageMeter()top5accs = AverageMeter()start = time.time()references = list()  # references (true captions) for calculating BLEU-4 scorehypotheses = list()  # hypotheses (predictions)# explicitly disable gradient calculation to avoid CUDA memory error# solves the issue #57with torch.no_grad():# Batchesfor i, (imgs, caps, caplens, allcaps) in enumerate(val_loader):# Move to device, if availableimgs = imgs.to(device)caps = caps.to(device)caplens = caplens.to(device)# Forward prop.if encoder is not None:imgs = encoder(imgs)scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)# Since we decoded starting with <start>, the targets are all words after <start>, up to <end>targets = caps_sorted[:, 1:]# Remove timesteps that we didn't decode at, or are pads# pack_padded_sequence is an easy trick to do thisscores_copy = scores.clone()scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).datatargets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data# Calculate lossloss = criterion(scores, targets)# Add doubly stochastic attention regularizationloss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()# Keep track of metricslosses.update(loss.item(), sum(decode_lengths))top5 = accuracy(scores, targets, 5)top5accs.update(top5, sum(decode_lengths))batch_time.update(time.time() - start)start = time.time()if i % print_freq == 0:print('Validation: [{0}/{1}]\t''Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t''Loss {loss.val:.4f} ({loss.avg:.4f})\t''Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time,loss=losses, top5=top5accs))# Store references (true captions), and hypothesis (prediction) for each image# If for n images, we have n hypotheses, and references a, b, c... for each image, we need -# references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]# Referencesallcaps = allcaps[sort_ind]  # because images were sorted in the decoderfor j in range(allcaps.shape[0]):img_caps = allcaps[j].tolist()img_captions = list(map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<pad>']}],img_caps))  # remove <start> and padsreferences.append(img_captions)# Hypotheses_, preds = torch.max(scores_copy, dim=2)preds = preds.tolist()temp_preds = list()for j, p in enumerate(preds):temp_preds.append(preds[j][:decode_lengths[j]])  # remove padspreds = temp_predshypotheses.extend(preds)assert len(references) == len(hypotheses)# Calculate BLEU-4 scoresbleu4 = corpus_bleu(references, hypotheses)print('\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format(loss=losses,top5=top5accs,bleu=bleu4))return bleu4if __name__ == '__main__':main()

测试

使用了原git上提供的预训练模型

import torch
import torch.nn.functional as F
import numpy as np
import json
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import skimage.transform
import argparse
from scipy.misc import imread, imresize
from PIL import Imagetorch.cuda.set_device(9)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_size=3):"""Reads an image and captions it with beam search.:param encoder: encoder model:param decoder: decoder model:param image_path: path to image:param word_map: word map:param beam_size: number of sequences to consider at each decode-step:return: caption, weights for visualization"""k = beam_sizevocab_size = len(word_map)# Read image and processimg = imread(image_path)#当为单通道图像时,转化为三通道if len(img.shape) == 2:img = img[:, :, np.newaxis] #增加纬度img = np.concatenate([img, img, img], axis=2) #拼接为三通道img = imresize(img, (256, 256))img = img.transpose(2, 0, 1)#矩阵转置 通道数放在前面img = img / 255.img = torch.FloatTensor(img).to(device)normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])transform = transforms.Compose([normalize])image = transform(img)  # (3, 256, 256)# Encodeimage = image.unsqueeze(0)  # (1, 3, 256, 256)encoder_out = encoder(image)  # (1, enc_image_size, enc_image_size, encoder_dim) 1,14,14,2048enc_image_size = encoder_out.size(1)print('enc_image_size:',enc_image_size)encoder_dim = encoder_out.size(3)print('encoder_dim:',encoder_dim)# Flatten encodingencoder_out = encoder_out.view(1, -1, encoder_dim)  # (1, num_pixels, encoder_dim) 1,196,2048#表示了图像的196个区域各自的特征# print('encoder_out:',encoder_out)num_pixels = encoder_out.size(1)#第二位 196 #print('num_pixels:',num_pixels)# We'll treat the problem as having a batch size of k#print(encoder_out.size())encoder_out = encoder_out.expand(k, num_pixels, encoder_dim)  # (k, num_pixels, encoder_dim)1->k纬度扩展,五份特征#print(encoder_out.size())# Tensor to store top k previous words at each step; now they're just <start>k_prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device)  # (k, 1)#print('k_prev_words:',k_prev_words)# Tensor to store top k sequences; now they're just <start>seqs = k_prev_words  # (k, 1)# Tensor to store top k sequences' scores; now they're just 0top_k_scores = torch.zeros(k, 1).to(device)  # (k, 1)# Tensor to store top k sequences' alphas; now they're just 1s 这里其实就是存储每个字对应图像上的关注区域,映射在14*14的张量上面seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(device)  # (k, 1, enc_image_size, enc_image_size)# Lists to store completed sequences, their alphas and scorescomplete_seqs = list()complete_seqs_alpha = list()complete_seqs_scores = list()# Start decodingstep = 1h, c = decoder.init_hidden_state(encoder_out)#h0print('h, c',h.size(),c.size())# s is a number less than or equal to k, because sequences are removed from this process once they hit <end>while True:embeddings = decoder.embedding(k_prev_words).squeeze(1)  # (s, embed_dim) (5,隐层512)print('embeddings',embeddings.size())#encode的图片表示 和  隐状态awe, alpha = decoder.attention(encoder_out, h)  # (s, encoder_dim), (s, num_pixels)(5,2048(),5,196(attention 存储字对应图像各部分的权重))print(' awe, alpha',awe.size(),alpha.size())#0/0alpha = alpha.view(-1, enc_image_size, enc_image_size)  # (s, enc_image_size, enc_image_size)(5,14,14)gate = decoder.sigmoid(decoder.f_beta(h))  # gating scalar, (s, encoder_dim)awe = gate * awe#给特征赋予权重h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c))  # (s, decoder_dim)输入(512,2048),(512,512)带权重的特征和上一次的lstm输出和细胞状态值scores = decoder.fc(h)  # (s, vocab_size)scores = F.log_softmax(scores, dim=1)print('scores',scores.size())# Add 每一句 含有多少词 更新scores = top_k_scores.expand_as(scores) + scores  # (s, vocab_size)print('top_k_scores,scores',top_k_scores.size(),scores.size())# For the first step, all k points will have the same scores (since same k previous words, h, c)if step == 1:top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)  # (s)else:# Unroll and find top scores, and their unrolled indicestop_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)  # (s) 取词,topprint('top_k_scores,top_k_words',top_k_scores.size(),top_k_words.size())# Convert unrolled indices to actual indices of scoresprev_word_inds = torch.floor_divide(top_k_words, vocab_size)#prev_word_inds = top_k_words / vocab_size  # (s)next_word_inds = top_k_words % vocab_size  # (s)print('top_k_scores,top_k_words,prev_word_inds,next_word_inds',top_k_words,top_k_scores,prev_word_inds,next_word_inds)# Add new words to sequences, alphasseqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)  # (s, step+1)#词加一seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)],   #词对应图像区域加一dim=1)  # (s, step+1, enc_image_size, enc_image_size)# Which sequences are incomplete (didn't reach <end>)? 挑出这次循环完结的 句子incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) ifnext_word != word_map['<end>']]complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))# Set aside complete sequences 挑出完整序列if len(complete_inds) > 0:complete_seqs.extend(seqs[complete_inds].tolist()) #追加全部序列complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist())complete_seqs_scores.extend(top_k_scores[complete_inds])k -= len(complete_inds)  # reduce beam length accordingly# Proceed with incomplete sequencesif k == 0:break#更新参数 只保留未完全序列参数seqs = seqs[incomplete_inds]seqs_alpha = seqs_alpha[incomplete_inds]h = h[prev_word_inds[incomplete_inds]]c = c[prev_word_inds[incomplete_inds]]encoder_out = encoder_out[prev_word_inds[incomplete_inds]]top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)# Break if things have been going on too longif step > 50:breakstep += 1#标记 scores分数最高序列作为返回值。i = complete_seqs_scores.index(max(complete_seqs_scores))seq = complete_seqs[i]alphas = complete_seqs_alpha[i]return seq, alphasdef visualize_att(image_path, seq, alphas, rev_word_map, smooth=True):"""Visualizes caption with weights at every word.Adapted from paper authors' repo: https://github.com/kelvinxu/arctic-captions/blob/master/alpha_visualization.ipynb:param image_path: path to image that has been captioned:param seq: caption:param alphas: weights:param rev_word_map: reverse word mapping, i.e. ix2word:param smooth: smooth weights?"""image = Image.open(image_path)image = image.resize([14 * 12, 14 * 12], Image.LANCZOS)words = [rev_word_map[ind] for ind in seq]print(words)for t in range(len(words)):if t > 50:breakplt.subplot(np.ceil(len(words) / 5.), 5, t + 1)plt.text(0, 1, '%s' % (words[t]), color='black', backgroundcolor='white', fontsize=12)plt.imshow(image)current_alpha = alphas[t, :]if smooth:alpha = skimage.transform.pyramid_expand(current_alpha.numpy(), upscale=12, sigma=8)else:alpha = skimage.transform.resize(current_alpha.numpy(), [14 * 12, 14 * 12])if t == 0:plt.imshow(alpha, alpha=0)else:plt.imshow(alpha, alpha=0.8)plt.set_cmap(cm.Greys_r)plt.axis('off')plt.show()import scipyprint(scipy.__version__)
checkpoint = torch.load('./BEST_checkpoint_coco_5_cap_per_img_5_min_word_freq.pth.tar', map_location=str(device))
decoder = checkpoint['decoder']
decoder = decoder.to(device)
decoder.eval()
encoder = checkpoint['encoder']
encoder = encoder.to(device)
encoder.eval()# Load word map (word2ix)
with open('./WORDMAP_coco_5_cap_per_img_5_min_word_freq.json', 'r') as j:word_map = json.load(j)
rev_word_map = {v: k for k, v in word_map.items()}  # ix2word
# Encode, decode with attention and beam search
seq, alphas = caption_image_beam_search(encoder, decoder,'img/q.jpg', word_map,5)
alphas = torch.FloatTensor(alphas)# Visualize caption and attention of best sequence
visualize_att('img/q.jpg', seq, alphas, rev_word_map,True)

这里使用了齐天大圣作为测试图片,输出很有趣,一个长头发的女人在看着相机。
在这里插入图片描述
更多细节请查看原文和git。

这篇关于pytorch时空数据处理4——图像转文本/字幕Image-Captionning(二)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1141556

相关文章

Python xmltodict实现简化XML数据处理

《Pythonxmltodict实现简化XML数据处理》Python社区为提供了xmltodict库,它专为简化XML与Python数据结构的转换而设计,本文主要来为大家介绍一下如何使用xmltod... 目录一、引言二、XMLtodict介绍设计理念适用场景三、功能参数与属性1、parse函数2、unpa

基于WinForm+Halcon实现图像缩放与交互功能

《基于WinForm+Halcon实现图像缩放与交互功能》本文主要讲述在WinForm中结合Halcon实现图像缩放、平移及实时显示灰度值等交互功能,包括初始化窗口的不同方式,以及通过特定事件添加相应... 目录前言初始化窗口添加图像缩放功能添加图像平移功能添加实时显示灰度值功能示例代码总结最后前言本文将

PyTorch使用教程之Tensor包详解

《PyTorch使用教程之Tensor包详解》这篇文章介绍了PyTorch中的张量(Tensor)数据结构,包括张量的数据类型、初始化、常用操作、属性等,张量是PyTorch框架中的核心数据结构,支持... 目录1、张量Tensor2、数据类型3、初始化(构造张量)4、常用操作5、常用属性5.1 存储(st

通过C#获取PDF中指定文本或所有文本的字体信息

《通过C#获取PDF中指定文本或所有文本的字体信息》在设计和出版行业中,字体的选择和使用对最终作品的质量有着重要影响,然而,有时我们可能会遇到包含未知字体的PDF文件,这使得我们无法准确地复制或修改文... 目录引言C# 获取PDF中指定文本的字体信息C# 获取PDF文档中用到的所有字体信息引言在设计和出

Python数据处理之导入导出Excel数据方式

《Python数据处理之导入导出Excel数据方式》Python是Excel数据处理的绝佳工具,通过Pandas和Openpyxl等库可以实现数据的导入、导出和自动化处理,从基础的数据读取和清洗到复杂... 目录python导入导出Excel数据开启数据之旅:为什么Python是Excel数据处理的最佳拍档

Java操作xls替换文本或图片的功能实现

《Java操作xls替换文本或图片的功能实现》这篇文章主要给大家介绍了关于Java操作xls替换文本或图片功能实现的相关资料,文中通过示例代码讲解了文件上传、文件处理和Excel文件生成,需要的朋友可... 目录准备xls模板文件:template.xls准备需要替换的图片和数据功能实现包声明与导入类声明与

python解析HTML并提取span标签中的文本

《python解析HTML并提取span标签中的文本》在网页开发和数据抓取过程中,我们经常需要从HTML页面中提取信息,尤其是span元素中的文本,span标签是一个行内元素,通常用于包装一小段文本或... 目录一、安装相关依赖二、html 页面结构三、使用 BeautifulSoup javascript

基于人工智能的图像分类系统

目录 引言项目背景环境准备 硬件要求软件安装与配置系统设计 系统架构关键技术代码示例 数据预处理模型训练模型预测应用场景结论 1. 引言 图像分类是计算机视觉中的一个重要任务,目标是自动识别图像中的对象类别。通过卷积神经网络(CNN)等深度学习技术,我们可以构建高效的图像分类系统,广泛应用于自动驾驶、医疗影像诊断、监控分析等领域。本文将介绍如何构建一个基于人工智能的图像分类系统,包括环境

lvgl8.3.6 控件垂直布局 label控件在image控件的下方显示

在使用 LVGL 8.3.6 创建一个垂直布局,其中 label 控件位于 image 控件下方,你可以使用 lv_obj_set_flex_flow 来设置布局为垂直,并确保 label 控件在 image 控件后添加。这里是如何步骤性地实现它的一个基本示例: 创建父容器:首先创建一个容器对象,该对象将作为布局的基础。设置容器为垂直布局:使用 lv_obj_set_flex_flow 设置容器

Python:豆瓣电影商业数据分析-爬取全数据【附带爬虫豆瓣,数据处理过程,数据分析,可视化,以及完整PPT报告】

**爬取豆瓣电影信息,分析近年电影行业的发展情况** 本文是完整的数据分析展现,代码有完整版,包含豆瓣电影爬取的具体方式【附带爬虫豆瓣,数据处理过程,数据分析,可视化,以及完整PPT报告】   最近MBA在学习《商业数据分析》,大实训作业给了数据要进行数据分析,所以先拿豆瓣电影练练手,网络上爬取豆瓣电影TOP250较多,但对于豆瓣电影全数据的爬取教程很少,所以我自己做一版。 目