政安晨:【Keras机器学习示例演绎】(二十九)—— 利用卷积 LSTM 进行下一帧视频预测

本文主要是介绍政安晨:【Keras机器学习示例演绎】(二十九)—— 利用卷积 LSTM 进行下一帧视频预测,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

简介

设置

数据集构建

数据可视化

模型构建

模型训练

帧预测可视化

预测视频


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:如何建立和训练用于下一帧视频预测的卷积 LSTM 模型。

简介


卷积 LSTM 架构通过在 LSTM 层中引入卷积递归单元,将时间序列处理和计算机视觉结合在一起。在本示例中,我们将探讨卷积 LSTM 模型在下一帧预测中的应用,下一帧预测是指在一系列过去帧的基础上预测下一个视频帧的过程。

设置

import numpy as np
import matplotlib.pyplot as pltimport keras
from keras import layersimport io
import imageio
from IPython.display import Image, display
from ipywidgets import widgets, Layout, HBox

数据集构建


在本例中,我们将使用移动 MNIST 数据集。

我们将下载该数据集,然后构建并预处理训练集和验证集。

对于下一帧预测,我们的模型将使用前一帧(我们称之为 f_n)来预测新一帧(称之为 f_(n + 1))。为了让模型能够创建这些预测,我们需要处理数据,使输入和输出 "移位",其中输入数据为帧 x_n,用于预测帧 y_(n + 1)。

# Download and load the dataset.
fpath = keras.utils.get_file("moving_mnist.npy","http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy",
)
dataset = np.load(fpath)# Swap the axes representing the number of frames and number of data samples.
dataset = np.swapaxes(dataset, 0, 1)
# We'll pick out 1000 of the 10000 total examples and use those.
dataset = dataset[:1000, ...]
# Add a channel dimension since the images are grayscale.
dataset = np.expand_dims(dataset, axis=-1)# Split into train and validation sets using indexing to optimize memory.
indexes = np.arange(dataset.shape[0])
np.random.shuffle(indexes)
train_index = indexes[: int(0.9 * dataset.shape[0])]
val_index = indexes[int(0.9 * dataset.shape[0]) :]
train_dataset = dataset[train_index]
val_dataset = dataset[val_index]# Normalize the data to the 0-1 range.
train_dataset = train_dataset / 255
val_dataset = val_dataset / 255# We'll define a helper function to shift the frames, where
# `x` is frames 0 to n - 1, and `y` is frames 1 to n.
def create_shifted_frames(data):x = data[:, 0 : data.shape[1] - 1, :, :]y = data[:, 1 : data.shape[1], :, :]return x, y# Apply the processing function to the datasets.
x_train, y_train = create_shifted_frames(train_dataset)
x_val, y_val = create_shifted_frames(val_dataset)# Inspect the dataset.
print("Training Dataset Shapes: " + str(x_train.shape) + ", " + str(y_train.shape))
print("Validation Dataset Shapes: " + str(x_val.shape) + ", " + str(y_val.shape))

演绎展示:

Downloading data from http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy819200096/819200096 ━━━━━━━━━━━━━━━━━━━━ 116s 0us/step
Training Dataset Shapes: (900, 19, 64, 64, 1), (900, 19, 64, 64, 1)
Validation Dataset Shapes: (100, 19, 64, 64, 1), (100, 19, 64, 64, 1)

数据可视化

我们的数据由一系列的帧组成,每个帧都用于预测即将到来的帧。让我们来看一些这些连续帧。

# Construct a figure on which we will visualize the images.
fig, axes = plt.subplots(4, 5, figsize=(10, 8))# Plot each of the sequential images for one random data example.
data_choice = np.random.choice(range(len(train_dataset)), size=1)[0]
for idx, ax in enumerate(axes.flat):ax.imshow(np.squeeze(train_dataset[data_choice][idx]), cmap="gray")ax.set_title(f"Frame {idx + 1}")ax.axis("off")# Print information and display the figure.
print(f"Displaying frames for example {data_choice}.")
plt.show()
Displaying frames for example 95.

模型构建

为了构建一个卷积LSTM模型,我们将使用ConvLSTM2D层,该层将接受形状为(batch_size,num_frames,width,height,channels)的输入,并返回相同形状的预测电影。

# Construct the input layer with no definite frame size.
inp = layers.Input(shape=(None, *x_train.shape[2:]))# We will construct 3 `ConvLSTM2D` layers with batch normalization,
# followed by a `Conv3D` layer for the spatiotemporal outputs.
x = layers.ConvLSTM2D(filters=64,kernel_size=(5, 5),padding="same",return_sequences=True,activation="relu",
)(inp)
x = layers.BatchNormalization()(x)
x = layers.ConvLSTM2D(filters=64,kernel_size=(3, 3),padding="same",return_sequences=True,activation="relu",
)(x)
x = layers.BatchNormalization()(x)
x = layers.ConvLSTM2D(filters=64,kernel_size=(1, 1),padding="same",return_sequences=True,activation="relu",
)(x)
x = layers.Conv3D(filters=1, kernel_size=(3, 3, 3), activation="sigmoid", padding="same"
)(x)# Next, we will build the complete model and compile it.
model = keras.models.Model(inp, x)
model.compile(loss=keras.losses.binary_crossentropy,optimizer=keras.optimizers.Adam(),
)

模型训练


有了模型和数据,我们就可以训练模型了。

# Define some callbacks to improve training.
early_stopping = keras.callbacks.EarlyStopping(monitor="val_loss", patience=10)
reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=5)# Define modifiable training hyperparameters.
epochs = 20
batch_size = 5# Fit the model to the training data.
model.fit(x_train,y_train,batch_size=batch_size,epochs=epochs,validation_data=(x_val, y_val),callbacks=[early_stopping, reduce_lr],
)

演绎展示:

Epoch 1/20180/180 ━━━━━━━━━━━━━━━━━━━━ 50s 226ms/step - loss: 0.1510 - val_loss: 0.2966 - learning_rate: 0.0010
Epoch 2/20180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0287 - val_loss: 0.1766 - learning_rate: 0.0010
Epoch 3/20180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0269 - val_loss: 0.0661 - learning_rate: 0.0010
Epoch 4/20180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0264 - val_loss: 0.0279 - learning_rate: 0.0010
Epoch 5/20180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0258 - val_loss: 0.0254 - learning_rate: 0.0010
Epoch 6/20180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0256 - val_loss: 0.0253 - learning_rate: 0.0010
Epoch 7/20180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0251 - val_loss: 0.0248 - learning_rate: 0.0010
Epoch 8/20180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0251 - val_loss: 0.0251 - learning_rate: 0.0010
Epoch 9/20180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0247 - val_loss: 0.0243 - learning_rate: 0.0010
Epoch 10/20180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0246 - val_loss: 0.0246 - learning_rate: 0.0010
Epoch 11/20180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0245 - val_loss: 0.0247 - learning_rate: 0.0010
Epoch 12/20180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0241 - val_loss: 0.0243 - learning_rate: 0.0010
Epoch 13/20180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0244 - val_loss: 0.0245 - learning_rate: 0.0010
Epoch 14/20180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0241 - val_loss: 0.0241 - learning_rate: 0.0010
Epoch 15/20180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0243 - val_loss: 0.0241 - learning_rate: 0.0010
Epoch 16/20180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0242 - val_loss: 0.0242 - learning_rate: 0.0010
Epoch 17/20180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0240 - val_loss: 0.0240 - learning_rate: 0.0010
Epoch 18/20180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0240 - val_loss: 0.0243 - learning_rate: 0.0010
Epoch 19/20180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0240 - val_loss: 0.0244 - learning_rate: 0.0010
Epoch 20/20180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0237 - val_loss: 0.0238 - learning_rate: 1.0000e-04<keras.src.callbacks.history.History at 0x7ff294f9c340>

帧预测可视化


在构建并训练好模型后,我们可以根据新视频生成一些帧预测示例。

我们将从验证集中随机挑选一个示例,然后从中选择前十个帧。在此基础上,我们可以让模型预测 10 个新帧,并将其与地面实况帧预测进行比较。

# Select a random example from the validation dataset.
example = val_dataset[np.random.choice(range(len(val_dataset)), size=1)[0]]# Pick the first/last ten frames from the example.
frames = example[:10, ...]
original_frames = example[10:, ...]# Predict a new set of 10 frames.
for _ in range(10):# Extract the model's prediction and post-process it.new_prediction = model.predict(np.expand_dims(frames, axis=0))new_prediction = np.squeeze(new_prediction, axis=0)predicted_frame = np.expand_dims(new_prediction[-1, ...], axis=0)# Extend the set of prediction frames.frames = np.concatenate((frames, predicted_frame), axis=0)# Construct a figure for the original and new frames.
fig, axes = plt.subplots(2, 10, figsize=(20, 4))# Plot the original frames.
for idx, ax in enumerate(axes[0]):ax.imshow(np.squeeze(original_frames[idx]), cmap="gray")ax.set_title(f"Frame {idx + 11}")ax.axis("off")# Plot the new frames.
new_frames = frames[10:, ...]
for idx, ax in enumerate(axes[1]):ax.imshow(np.squeeze(new_frames[idx]), cmap="gray")ax.set_title(f"Frame {idx + 11}")ax.axis("off")# Display the figure.
plt.show()
 1/1 ━━━━━━━━━━━━━━━━━━━━ 2s 2s/step1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 800ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 805ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 790ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 821ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 824ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 928ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 813ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 810ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 814ms/step

预测视频


最后,我们将从验证集中挑选几个例子,用它们制作一些 GIF,看看模型预测的视频。

你可以使用 Hugging Face Hub 上托管的训练有素的模型,也可以在 Hugging Face Spaces 上尝试演示。

# Select a few random examples from the dataset.
examples = val_dataset[np.random.choice(range(len(val_dataset)), size=5)]# Iterate over the examples and predict the frames.
predicted_videos = []
for example in examples:# Pick the first/last ten frames from the example.frames = example[:10, ...]original_frames = example[10:, ...]new_predictions = np.zeros(shape=(10, *frames[0].shape))# Predict a new set of 10 frames.for i in range(10):# Extract the model's prediction and post-process it.frames = example[: 10 + i + 1, ...]new_prediction = model.predict(np.expand_dims(frames, axis=0))new_prediction = np.squeeze(new_prediction, axis=0)predicted_frame = np.expand_dims(new_prediction[-1, ...], axis=0)# Extend the set of prediction frames.new_predictions[i] = predicted_frame# Create and save GIFs for each of the ground truth/prediction images.for frame_set in [original_frames, new_predictions]:# Construct a GIF from the selected video frames.current_frames = np.squeeze(frame_set)current_frames = current_frames[..., np.newaxis] * np.ones(3)current_frames = (current_frames * 255).astype(np.uint8)current_frames = list(current_frames)# Construct a GIF from the frames.with io.BytesIO() as gif:imageio.mimsave(gif, current_frames, "GIF", duration=200)predicted_videos.append(gif.getvalue())# Display the videos.
print(" Truth\tPrediction")
for i in range(0, len(predicted_videos), 2):# Construct and display an `HBox` with the ground truth and prediction.box = HBox([widgets.Image(value=predicted_videos[i]),widgets.Image(value=predicted_videos[i + 1]),])display(box)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 790ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/stepTruth  PredictionHBox(children=(Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xf8\…HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\xff\xff\xff\xfd\xfd\xfd\xfc\xfc\xfc\xfb\xfb\xfb\xf4\…HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xfb\…HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xfb\…HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\xff\xff\xff\xfd\xfd\xfd\xfc\xfc\xfc\xf9\xf9\xf9\xf7\…

这篇关于政安晨:【Keras机器学习示例演绎】(二十九)—— 利用卷积 LSTM 进行下一帧视频预测的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/956496

相关文章

html5的响应式布局的方法示例详解

《html5的响应式布局的方法示例详解》:本文主要介绍了HTML5中使用媒体查询和Flexbox进行响应式布局的方法,简要介绍了CSSGrid布局的基础知识和如何实现自动换行的网格布局,详细内容请阅读本文,希望能对你有所帮助... 一 使用媒体查询响应式布局        使用的参数@media这是常用的

Java使用SLF4J记录不同级别日志的示例详解

《Java使用SLF4J记录不同级别日志的示例详解》SLF4J是一个简单的日志门面,它允许在运行时选择不同的日志实现,这篇文章主要为大家详细介绍了如何使用SLF4J记录不同级别日志,感兴趣的可以了解下... 目录一、SLF4J简介二、添加依赖三、配置Logback四、记录不同级别的日志五、总结一、SLF4J

Java字符串操作技巧之语法、示例与应用场景分析

《Java字符串操作技巧之语法、示例与应用场景分析》在Java算法题和日常开发中,字符串处理是必备的核心技能,本文全面梳理Java中字符串的常用操作语法,结合代码示例、应用场景和避坑指南,可快速掌握字... 目录引言1. 基础操作1.1 创建字符串1.2 获取长度1.3 访问字符2. 字符串处理2.1 子字

QT进行CSV文件初始化与读写操作

《QT进行CSV文件初始化与读写操作》这篇文章主要为大家详细介绍了在QT环境中如何进行CSV文件的初始化、写入和读取操作,本文为大家整理了相关的操作的多种方法,希望对大家有所帮助... 目录前言一、CSV文件初始化二、CSV写入三、CSV读取四、QT 逐行读取csv文件五、Qt如何将数据保存成CSV文件前言

C++使用printf语句实现进制转换的示例代码

《C++使用printf语句实现进制转换的示例代码》在C语言中,printf函数可以直接实现部分进制转换功能,通过格式说明符(formatspecifier)快速输出不同进制的数值,下面给大家分享C+... 目录一、printf 原生支持的进制转换1. 十进制、八进制、十六进制转换2. 显示进制前缀3. 指

前端CSS Grid 布局示例详解

《前端CSSGrid布局示例详解》CSSGrid是一种二维布局系统,可以同时控制行和列,相比Flex(一维布局),更适合用在整体页面布局或复杂模块结构中,:本文主要介绍前端CSSGri... 目录css Grid 布局详解(通俗易懂版)一、概述二、基础概念三、创建 Grid 容器四、定义网格行和列五、设置行

Node.js 数据库 CRUD 项目示例详解(完美解决方案)

《Node.js数据库CRUD项目示例详解(完美解决方案)》:本文主要介绍Node.js数据库CRUD项目示例详解(完美解决方案),本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考... 目录项目结构1. 初始化项目2. 配置数据库连接 (config/db.js)3. 创建模型 (models/

通过Spring层面进行事务回滚的实现

《通过Spring层面进行事务回滚的实现》本文主要介绍了通过Spring层面进行事务回滚的实现,包括声明式事务和编程式事务,具有一定的参考价值,感兴趣的可以了解一下... 目录声明式事务回滚:1. 基础注解配置2. 指定回滚异常类型3. ​不回滚特殊场景编程式事务回滚:1. ​使用 TransactionT

使用Python实现全能手机虚拟键盘的示例代码

《使用Python实现全能手机虚拟键盘的示例代码》在数字化办公时代,你是否遇到过这样的场景:会议室投影电脑突然键盘失灵、躺在沙发上想远程控制书房电脑、或者需要给长辈远程协助操作?今天我要分享的Pyth... 目录一、项目概述:不止于键盘的远程控制方案1.1 创新价值1.2 技术栈全景二、需求实现步骤一、需求

Spring LDAP目录服务的使用示例

《SpringLDAP目录服务的使用示例》本文主要介绍了SpringLDAP目录服务的使用示例... 目录引言一、Spring LDAP基础二、LdapTemplate详解三、LDAP对象映射四、基本LDAP操作4.1 查询操作4.2 添加操作4.3 修改操作4.4 删除操作五、认证与授权六、高级特性与最佳