政安晨:【Keras机器学习示例演绎】(二十六)—— 图像相似性搜索的度量学习

本文主要是介绍政安晨:【Keras机器学习示例演绎】(二十六)—— 图像相似性搜索的度量学习,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

概述

设置

数据集

嵌入模型

测试


政安晨的个人主页:政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:在 CIFAR-10 图像上使用相似度量学习的示例。

概述


度量学习旨在训练能将输入嵌入高维空间的模型,从而使训练方案所定义的 "相似 "输入彼此靠近。这些模型一经训练,就能为下游系统生成对这种相似性有用的嵌入模型,例如作为搜索的排名信号,或作为另一种监督问题的预训练嵌入模型。

设置


将 Keras 后端设置为 tensorflow。

import osos.environ["KERAS_BACKEND"] = "tensorflow"import random
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from collections import defaultdict
from PIL import Image
from sklearn.metrics import ConfusionMatrixDisplay
import keras
from keras import layers

数据集


在本示例中,我们将使用 CIFAR-10 数据集。

from keras.datasets import cifar10(x_train, y_train), (x_test, y_test) = cifar10.load_data()x_train = x_train.astype("float32") / 255.0
y_train = np.squeeze(y_train)
x_test = x_test.astype("float32") / 255.0
y_test = np.squeeze(y_test)

为了了解数据集,我们可以将 25 个随机例子组成的网格可视化。

height_width = 32def show_collage(examples):box_size = height_width + 2num_rows, num_cols = examples.shape[:2]collage = Image.new(mode="RGB",size=(num_cols * box_size, num_rows * box_size),color=(250, 250, 250),)for row_idx in range(num_rows):for col_idx in range(num_cols):array = (np.array(examples[row_idx, col_idx]) * 255).astype(np.uint8)collage.paste(Image.fromarray(array), (col_idx * box_size, row_idx * box_size))# Double size for visualisation.collage = collage.resize((2 * num_cols * box_size, 2 * num_rows * box_size))return collage# Show a collage of 5x5 random images.
sample_idxs = np.random.randint(0, 50000, size=(5, 5))
examples = x_train[sample_idxs]
show_collage(examples)

度量学习提供的训练数据并不是明确的(X,y)对,而是使用以我们想要表达的相似性方式相关的多个实例。

在我们的例子中,我们将使用同一类别的实例来表示相似性;单个训练实例将不是一幅图像,而是同一类别的一对图像。

在提及这对图像时,我们将使用常见的度量学习名称:锚图像(随机选择的图像)和正图像(随机选择的另一张同类图像)。

为此,我们需要建立一种从类到该类实例的查询形式。在生成用于训练的数据时,我们将从该查找表中采样。

class_idx_to_train_idxs = defaultdict(list)
for y_train_idx, y in enumerate(y_train):class_idx_to_train_idxs[y].append(y_train_idx)class_idx_to_test_idxs = defaultdict(list)
for y_test_idx, y in enumerate(y_test):class_idx_to_test_idxs[y].append(y_test_idx)

在本例中,我们使用的是最简单的训练方法;一个批次将由分布在各个类别中的(锚、正)对组成。

学习的目标是使锚和正对在批次中更接近、更远离其他实例。

在这种情况下,批次大小将由类的数量决定;对于 CIFAR-10,类的数量为 10。

num_classes = 10class AnchorPositivePairs(keras.utils.Sequence):def __init__(self, num_batches):super().__init__()self.num_batches = num_batchesdef __len__(self):return self.num_batchesdef __getitem__(self, _idx):x = np.empty((2, num_classes, height_width, height_width, 3), dtype=np.float32)for class_idx in range(num_classes):examples_for_class = class_idx_to_train_idxs[class_idx]anchor_idx = random.choice(examples_for_class)positive_idx = random.choice(examples_for_class)while positive_idx == anchor_idx:positive_idx = random.choice(examples_for_class)x[0, class_idx] = x_train[anchor_idx]x[1, class_idx] = x_train[positive_idx]return x

我们可以用另一张拼贴图来直观地展示一批结果。上排显示从 10 个类别中随机选择的锚点,下排显示相应的 10 个阳性锚点。

examples = next(iter(AnchorPositivePairs(num_batches=1)))show_collage(examples)

嵌入模型


我们定义了一个带有 train_step 的自定义模型,它首先嵌入锚点和正点,然后使用它们的成对点乘作为 softmax 的对数。

class EmbeddingModel(keras.Model):def train_step(self, data):# Note: Workaround for open issue, to be removed.if isinstance(data, tuple):data = data[0]anchors, positives = data[0], data[1]with tf.GradientTape() as tape:# Run both anchors and positives through model.anchor_embeddings = self(anchors, training=True)positive_embeddings = self(positives, training=True)# Calculate cosine similarity between anchors and positives. As they have# been normalised this is just the pair wise dot products.similarities = keras.ops.einsum("ae,pe->ap", anchor_embeddings, positive_embeddings)# Since we intend to use these as logits we scale them by a temperature.# This value would normally be chosen as a hyper parameter.temperature = 0.2similarities /= temperature# We use these similarities as logits for a softmax. The labels for# this call are just the sequence [0, 1, 2, ..., num_classes] since we# want the main diagonal values, which correspond to the anchor/positive# pairs, to be high. This loss will move embeddings for the# anchor/positive pairs together and move all other pairs apart.sparse_labels = keras.ops.arange(num_classes)loss = self.compute_loss(y=sparse_labels, y_pred=similarities)# Calculate gradients and apply via optimizer.gradients = tape.gradient(loss, self.trainable_variables)self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))# Update and return metrics (specifically the one for the loss value).for metric in self.metrics:# Calling `self.compile` will by default add a [`keras.metrics.Mean`](/api/metrics/metrics_wrappers#mean-class) lossif metric.name == "loss":metric.update_state(loss)else:metric.update_state(sparse_labels, similarities)return {m.name: m.result() for m in self.metrics}

接下来,我们将介绍从图像映射到嵌入空间的结构。该模型由一系列 2d 卷积组成,然后进行全局池化,最后线性投影到嵌入空间。按照度量学习的常见方法,我们对嵌入空间进行归一化处理,以便使用简单的点积来衡量相似性。为了简单起见,我们有意缩小了模型的规模。

inputs = layers.Input(shape=(height_width, height_width, 3))
x = layers.Conv2D(filters=32, kernel_size=3, strides=2, activation="relu")(inputs)
x = layers.Conv2D(filters=64, kernel_size=3, strides=2, activation="relu")(x)
x = layers.Conv2D(filters=128, kernel_size=3, strides=2, activation="relu")(x)
x = layers.GlobalAveragePooling2D()(x)
embeddings = layers.Dense(units=8, activation=None)(x)
embeddings = layers.UnitNormalization()(embeddings)model = EmbeddingModel(inputs, embeddings)

最后,我们运行训练。在 Google Colab GPU 实例上,这大约需要一分钟。

model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3),loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)history = model.fit(AnchorPositivePairs(num_batches=1000), epochs=20)plt.plot(history.history["loss"])
plt.show()
Epoch 1/2077/1000 ━[37m━━━━━━━━━━━━━━━━━━━  1s 2ms/step - loss: 2.2962WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1700589927.295343 3724442 device_compiler.h:187] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.1000/1000 ━━━━━━━━━━━━━━━━━━━━ 6s 2ms/step - loss: 2.2504
Epoch 2/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 2.1068
Epoch 3/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 2.0646
Epoch 4/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 2.0210
Epoch 5/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.9857
Epoch 6/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.9543
Epoch 7/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.9175
Epoch 8/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.8740
Epoch 9/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.8474
Epoch 10/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.8380
Epoch 11/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.8146
Epoch 12/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7658
Epoch 13/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7512
Epoch 14/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7671
Epoch 15/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7245
Epoch 16/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7001
Epoch 17/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7099
Epoch 18/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.6775
Epoch 19/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.6547
Epoch 20/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.6356

测试


我们可以将该模型应用于测试集,并考虑嵌入空间中的近邻,从而检验该模型的质量。

首先,我们嵌入测试集并计算所有近邻。回想一下,由于嵌入是单位长度的,我们可以通过点积计算余弦相似度。

near_neighbours_per_example = 10embeddings = model.predict(x_test)
gram_matrix = np.einsum("ae,be->ab", embeddings, embeddings)
near_neighbours = np.argsort(gram_matrix.T)[:, -(near_neighbours_per_example + 1) :]
 313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step

为了直观地检验这些嵌入,我们可以为 5 个随机例子建立一个近邻拼贴图。下图的第一列是随机选取的图像,随后的 10 列按相似度排序显示了近邻图像。

num_collage_examples = 5examples = np.empty((num_collage_examples,near_neighbours_per_example + 1,height_width,height_width,3,),dtype=np.float32,
)
for row_idx in range(num_collage_examples):examples[row_idx, 0] = x_test[row_idx]anchor_near_neighbours = reversed(near_neighbours[row_idx][:-1])for col_idx, nn_idx in enumerate(anchor_near_neighbours):examples[row_idx, col_idx + 1] = x_test[nn_idx]show_collage(examples)

我们还可以通过混淆矩阵来考虑近邻的正确性,从而对性能进行量化。

让我们从 10 个类别中各抽取 10 个例子,并将它们的近邻视为一种预测形式;也就是说,该例子及其近邻是否属于同一类别?

我们观察到,每个动物类别的表现一般都很好,与其他动物类别混淆的情况最多。车辆类别也遵循同样的模式。

confusion_matrix = np.zeros((num_classes, num_classes))# For each class.
for class_idx in range(num_classes):# Consider 10 examples.example_idxs = class_idx_to_test_idxs[class_idx][:10]for y_test_idx in example_idxs:# And count the classes of its near neighbours.for nn_idx in near_neighbours[y_test_idx][:-1]:nn_class_idx = y_test[nn_idx]confusion_matrix[class_idx, nn_class_idx] += 1# Display a confusion matrix.
labels = ["Airplane","Automobile","Bird","Cat","Deer","Dog","Frog","Horse","Ship","Truck",
]
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=labels)
disp.plot(include_values=True, cmap="viridis", ax=None, xticks_rotation="vertical")
plt.show()


这篇关于政安晨:【Keras机器学习示例演绎】(二十六)—— 图像相似性搜索的度量学习的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/951419

相关文章

OpenCV图像形态学的实现

《OpenCV图像形态学的实现》本文主要介绍了OpenCV图像形态学的实现,包括腐蚀、膨胀、开运算、闭运算、梯度运算、顶帽运算和黑帽运算,文中通过示例代码介绍的非常详细,需要的朋友们下面随着小编来一起... 目录一、图像形态学简介二、腐蚀(Erosion)1. 原理2. OpenCV 实现三、膨胀China编程(

使用Python实现全能手机虚拟键盘的示例代码

《使用Python实现全能手机虚拟键盘的示例代码》在数字化办公时代,你是否遇到过这样的场景:会议室投影电脑突然键盘失灵、躺在沙发上想远程控制书房电脑、或者需要给长辈远程协助操作?今天我要分享的Pyth... 目录一、项目概述:不止于键盘的远程控制方案1.1 创新价值1.2 技术栈全景二、需求实现步骤一、需求

Spring LDAP目录服务的使用示例

《SpringLDAP目录服务的使用示例》本文主要介绍了SpringLDAP目录服务的使用示例... 目录引言一、Spring LDAP基础二、LdapTemplate详解三、LDAP对象映射四、基本LDAP操作4.1 查询操作4.2 添加操作4.3 修改操作4.4 删除操作五、认证与授权六、高级特性与最佳

CSS will-change 属性示例详解

《CSSwill-change属性示例详解》will-change是一个CSS属性,用于告诉浏览器某个元素在未来可能会发生哪些变化,本文给大家介绍CSSwill-change属性详解,感... will-change 是一个 css 属性,用于告诉浏览器某个元素在未来可能会发生哪些变化。这可以帮助浏览器优化

C++中std::distance使用方法示例

《C++中std::distance使用方法示例》std::distance是C++标准库中的一个函数,用于计算两个迭代器之间的距离,本文主要介绍了C++中std::distance使用方法示例,具... 目录语法使用方式解释示例输出:其他说明:总结std::distance&n编程bsp;是 C++ 标准

前端高级CSS用法示例详解

《前端高级CSS用法示例详解》在前端开发中,CSS(层叠样式表)不仅是用来控制网页的外观和布局,更是实现复杂交互和动态效果的关键技术之一,随着前端技术的不断发展,CSS的用法也日益丰富和高级,本文将深... 前端高级css用法在前端开发中,CSS(层叠样式表)不仅是用来控制网页的外观和布局,更是实现复杂交

C#使用SQLite进行大数据量高效处理的代码示例

《C#使用SQLite进行大数据量高效处理的代码示例》在软件开发中,高效处理大数据量是一个常见且具有挑战性的任务,SQLite因其零配置、嵌入式、跨平台的特性,成为许多开发者的首选数据库,本文将深入探... 目录前言准备工作数据实体核心技术批量插入:从乌龟到猎豹的蜕变分页查询:加载百万数据异步处理:拒绝界面

用js控制视频播放进度基本示例代码

《用js控制视频播放进度基本示例代码》写前端的时候,很多的时候是需要支持要网页视频播放的功能,下面这篇文章主要给大家介绍了关于用js控制视频播放进度的相关资料,文中通过代码介绍的非常详细,需要的朋友可... 目录前言html部分:JavaScript部分:注意:总结前言在javascript中控制视频播放

Java中StopWatch的使用示例详解

《Java中StopWatch的使用示例详解》stopWatch是org.springframework.util包下的一个工具类,使用它可直观的输出代码执行耗时,以及执行时间百分比,这篇文章主要介绍... 目录stopWatch 是org.springframework.util 包下的一个工具类,使用它

Spring Boot 3.4.3 基于 Spring WebFlux 实现 SSE 功能(代码示例)

《SpringBoot3.4.3基于SpringWebFlux实现SSE功能(代码示例)》SpringBoot3.4.3结合SpringWebFlux实现SSE功能,为实时数据推送提供... 目录1. SSE 简介1.1 什么是 SSE?1.2 SSE 的优点1.3 适用场景2. Spring WebFlu