政安晨:【Keras机器学习示例演绎】(二十六)—— 图像相似性搜索的度量学习

本文主要是介绍政安晨:【Keras机器学习示例演绎】(二十六)—— 图像相似性搜索的度量学习,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

概述

设置

数据集

嵌入模型

测试


政安晨的个人主页:政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:在 CIFAR-10 图像上使用相似度量学习的示例。

概述


度量学习旨在训练能将输入嵌入高维空间的模型,从而使训练方案所定义的 "相似 "输入彼此靠近。这些模型一经训练,就能为下游系统生成对这种相似性有用的嵌入模型,例如作为搜索的排名信号,或作为另一种监督问题的预训练嵌入模型。

设置


将 Keras 后端设置为 tensorflow。

import osos.environ["KERAS_BACKEND"] = "tensorflow"import random
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from collections import defaultdict
from PIL import Image
from sklearn.metrics import ConfusionMatrixDisplay
import keras
from keras import layers

数据集


在本示例中,我们将使用 CIFAR-10 数据集。

from keras.datasets import cifar10(x_train, y_train), (x_test, y_test) = cifar10.load_data()x_train = x_train.astype("float32") / 255.0
y_train = np.squeeze(y_train)
x_test = x_test.astype("float32") / 255.0
y_test = np.squeeze(y_test)

为了了解数据集,我们可以将 25 个随机例子组成的网格可视化。

height_width = 32def show_collage(examples):box_size = height_width + 2num_rows, num_cols = examples.shape[:2]collage = Image.new(mode="RGB",size=(num_cols * box_size, num_rows * box_size),color=(250, 250, 250),)for row_idx in range(num_rows):for col_idx in range(num_cols):array = (np.array(examples[row_idx, col_idx]) * 255).astype(np.uint8)collage.paste(Image.fromarray(array), (col_idx * box_size, row_idx * box_size))# Double size for visualisation.collage = collage.resize((2 * num_cols * box_size, 2 * num_rows * box_size))return collage# Show a collage of 5x5 random images.
sample_idxs = np.random.randint(0, 50000, size=(5, 5))
examples = x_train[sample_idxs]
show_collage(examples)

度量学习提供的训练数据并不是明确的(X,y)对,而是使用以我们想要表达的相似性方式相关的多个实例。

在我们的例子中,我们将使用同一类别的实例来表示相似性;单个训练实例将不是一幅图像,而是同一类别的一对图像。

在提及这对图像时,我们将使用常见的度量学习名称:锚图像(随机选择的图像)和正图像(随机选择的另一张同类图像)。

为此,我们需要建立一种从类到该类实例的查询形式。在生成用于训练的数据时,我们将从该查找表中采样。

class_idx_to_train_idxs = defaultdict(list)
for y_train_idx, y in enumerate(y_train):class_idx_to_train_idxs[y].append(y_train_idx)class_idx_to_test_idxs = defaultdict(list)
for y_test_idx, y in enumerate(y_test):class_idx_to_test_idxs[y].append(y_test_idx)

在本例中,我们使用的是最简单的训练方法;一个批次将由分布在各个类别中的(锚、正)对组成。

学习的目标是使锚和正对在批次中更接近、更远离其他实例。

在这种情况下,批次大小将由类的数量决定;对于 CIFAR-10,类的数量为 10。

num_classes = 10class AnchorPositivePairs(keras.utils.Sequence):def __init__(self, num_batches):super().__init__()self.num_batches = num_batchesdef __len__(self):return self.num_batchesdef __getitem__(self, _idx):x = np.empty((2, num_classes, height_width, height_width, 3), dtype=np.float32)for class_idx in range(num_classes):examples_for_class = class_idx_to_train_idxs[class_idx]anchor_idx = random.choice(examples_for_class)positive_idx = random.choice(examples_for_class)while positive_idx == anchor_idx:positive_idx = random.choice(examples_for_class)x[0, class_idx] = x_train[anchor_idx]x[1, class_idx] = x_train[positive_idx]return x

我们可以用另一张拼贴图来直观地展示一批结果。上排显示从 10 个类别中随机选择的锚点,下排显示相应的 10 个阳性锚点。

examples = next(iter(AnchorPositivePairs(num_batches=1)))show_collage(examples)

嵌入模型


我们定义了一个带有 train_step 的自定义模型,它首先嵌入锚点和正点,然后使用它们的成对点乘作为 softmax 的对数。

class EmbeddingModel(keras.Model):def train_step(self, data):# Note: Workaround for open issue, to be removed.if isinstance(data, tuple):data = data[0]anchors, positives = data[0], data[1]with tf.GradientTape() as tape:# Run both anchors and positives through model.anchor_embeddings = self(anchors, training=True)positive_embeddings = self(positives, training=True)# Calculate cosine similarity between anchors and positives. As they have# been normalised this is just the pair wise dot products.similarities = keras.ops.einsum("ae,pe->ap", anchor_embeddings, positive_embeddings)# Since we intend to use these as logits we scale them by a temperature.# This value would normally be chosen as a hyper parameter.temperature = 0.2similarities /= temperature# We use these similarities as logits for a softmax. The labels for# this call are just the sequence [0, 1, 2, ..., num_classes] since we# want the main diagonal values, which correspond to the anchor/positive# pairs, to be high. This loss will move embeddings for the# anchor/positive pairs together and move all other pairs apart.sparse_labels = keras.ops.arange(num_classes)loss = self.compute_loss(y=sparse_labels, y_pred=similarities)# Calculate gradients and apply via optimizer.gradients = tape.gradient(loss, self.trainable_variables)self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))# Update and return metrics (specifically the one for the loss value).for metric in self.metrics:# Calling `self.compile` will by default add a [`keras.metrics.Mean`](/api/metrics/metrics_wrappers#mean-class) lossif metric.name == "loss":metric.update_state(loss)else:metric.update_state(sparse_labels, similarities)return {m.name: m.result() for m in self.metrics}

接下来,我们将介绍从图像映射到嵌入空间的结构。该模型由一系列 2d 卷积组成,然后进行全局池化,最后线性投影到嵌入空间。按照度量学习的常见方法,我们对嵌入空间进行归一化处理,以便使用简单的点积来衡量相似性。为了简单起见,我们有意缩小了模型的规模。

inputs = layers.Input(shape=(height_width, height_width, 3))
x = layers.Conv2D(filters=32, kernel_size=3, strides=2, activation="relu")(inputs)
x = layers.Conv2D(filters=64, kernel_size=3, strides=2, activation="relu")(x)
x = layers.Conv2D(filters=128, kernel_size=3, strides=2, activation="relu")(x)
x = layers.GlobalAveragePooling2D()(x)
embeddings = layers.Dense(units=8, activation=None)(x)
embeddings = layers.UnitNormalization()(embeddings)model = EmbeddingModel(inputs, embeddings)

最后,我们运行训练。在 Google Colab GPU 实例上,这大约需要一分钟。

model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3),loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)history = model.fit(AnchorPositivePairs(num_batches=1000), epochs=20)plt.plot(history.history["loss"])
plt.show()
Epoch 1/2077/1000 ━[37m━━━━━━━━━━━━━━━━━━━  1s 2ms/step - loss: 2.2962WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1700589927.295343 3724442 device_compiler.h:187] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.1000/1000 ━━━━━━━━━━━━━━━━━━━━ 6s 2ms/step - loss: 2.2504
Epoch 2/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 2.1068
Epoch 3/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 2.0646
Epoch 4/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 2.0210
Epoch 5/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.9857
Epoch 6/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.9543
Epoch 7/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.9175
Epoch 8/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.8740
Epoch 9/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.8474
Epoch 10/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.8380
Epoch 11/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.8146
Epoch 12/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7658
Epoch 13/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7512
Epoch 14/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7671
Epoch 15/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7245
Epoch 16/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7001
Epoch 17/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7099
Epoch 18/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.6775
Epoch 19/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.6547
Epoch 20/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.6356

测试


我们可以将该模型应用于测试集,并考虑嵌入空间中的近邻,从而检验该模型的质量。

首先,我们嵌入测试集并计算所有近邻。回想一下,由于嵌入是单位长度的,我们可以通过点积计算余弦相似度。

near_neighbours_per_example = 10embeddings = model.predict(x_test)
gram_matrix = np.einsum("ae,be->ab", embeddings, embeddings)
near_neighbours = np.argsort(gram_matrix.T)[:, -(near_neighbours_per_example + 1) :]
 313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step

为了直观地检验这些嵌入,我们可以为 5 个随机例子建立一个近邻拼贴图。下图的第一列是随机选取的图像,随后的 10 列按相似度排序显示了近邻图像。

num_collage_examples = 5examples = np.empty((num_collage_examples,near_neighbours_per_example + 1,height_width,height_width,3,),dtype=np.float32,
)
for row_idx in range(num_collage_examples):examples[row_idx, 0] = x_test[row_idx]anchor_near_neighbours = reversed(near_neighbours[row_idx][:-1])for col_idx, nn_idx in enumerate(anchor_near_neighbours):examples[row_idx, col_idx + 1] = x_test[nn_idx]show_collage(examples)

我们还可以通过混淆矩阵来考虑近邻的正确性,从而对性能进行量化。

让我们从 10 个类别中各抽取 10 个例子,并将它们的近邻视为一种预测形式;也就是说,该例子及其近邻是否属于同一类别?

我们观察到,每个动物类别的表现一般都很好,与其他动物类别混淆的情况最多。车辆类别也遵循同样的模式。

confusion_matrix = np.zeros((num_classes, num_classes))# For each class.
for class_idx in range(num_classes):# Consider 10 examples.example_idxs = class_idx_to_test_idxs[class_idx][:10]for y_test_idx in example_idxs:# And count the classes of its near neighbours.for nn_idx in near_neighbours[y_test_idx][:-1]:nn_class_idx = y_test[nn_idx]confusion_matrix[class_idx, nn_class_idx] += 1# Display a confusion matrix.
labels = ["Airplane","Automobile","Bird","Cat","Deer","Dog","Frog","Horse","Ship","Truck",
]
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=labels)
disp.plot(include_values=True, cmap="viridis", ax=None, xticks_rotation="vertical")
plt.show()


这篇关于政安晨:【Keras机器学习示例演绎】(二十六)—— 图像相似性搜索的度量学习的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/951419

相关文章

SpringCloud集成AlloyDB的示例代码

《SpringCloud集成AlloyDB的示例代码》AlloyDB是GoogleCloud提供的一种高度可扩展、强性能的关系型数据库服务,它兼容PostgreSQL,并提供了更快的查询性能... 目录1.AlloyDBjavascript是什么?AlloyDB 的工作原理2.搭建测试环境3.代码工程1.

Java中ArrayList的8种浅拷贝方式示例代码

《Java中ArrayList的8种浅拷贝方式示例代码》:本文主要介绍Java中ArrayList的8种浅拷贝方式的相关资料,讲解了Java中ArrayList的浅拷贝概念,并详细分享了八种实现浅... 目录引言什么是浅拷贝?ArrayList 浅拷贝的重要性方法一:使用构造函数方法二:使用 addAll(

Golang使用etcd构建分布式锁的示例分享

《Golang使用etcd构建分布式锁的示例分享》在本教程中,我们将学习如何使用Go和etcd构建分布式锁系统,分布式锁系统对于管理对分布式系统中共享资源的并发访问至关重要,它有助于维护一致性,防止竞... 目录引言环境准备新建Go项目实现加锁和解锁功能测试分布式锁重构实现失败重试总结引言我们将使用Go作

JAVA利用顺序表实现“杨辉三角”的思路及代码示例

《JAVA利用顺序表实现“杨辉三角”的思路及代码示例》杨辉三角形是中国古代数学的杰出研究成果之一,是我国北宋数学家贾宪于1050年首先发现并使用的,:本文主要介绍JAVA利用顺序表实现杨辉三角的思... 目录一:“杨辉三角”题目链接二:题解代码:三:题解思路:总结一:“杨辉三角”题目链接题目链接:点击这里

SpringBoot使用注解集成Redis缓存的示例代码

《SpringBoot使用注解集成Redis缓存的示例代码》:本文主要介绍在SpringBoot中使用注解集成Redis缓存的步骤,包括添加依赖、创建相关配置类、需要缓存数据的类(Tes... 目录一、创建 Caching 配置类二、创建需要缓存数据的类三、测试方法Spring Boot 熟悉后,集成一个外

Springboot使用RabbitMQ实现关闭超时订单(示例详解)

《Springboot使用RabbitMQ实现关闭超时订单(示例详解)》介绍了如何在SpringBoot项目中使用RabbitMQ实现订单的延时处理和超时关闭,通过配置RabbitMQ的交换机、队列和... 目录1.maven中引入rabbitmq的依赖:2.application.yml中进行rabbit

Python绘制土地利用和土地覆盖类型图示例详解

《Python绘制土地利用和土地覆盖类型图示例详解》本文介绍了如何使用Python绘制土地利用和土地覆盖类型图,并提供了详细的代码示例,通过安装所需的库,准备地理数据,使用geopandas和matp... 目录一、所需库的安装二、数据准备三、绘制土地利用和土地覆盖类型图四、代码解释五、其他可视化形式1.

opencv实现像素统计的示例代码

《opencv实现像素统计的示例代码》本文介绍了OpenCV中统计图像像素信息的常用方法和函数,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一... 目录1. 统计像素值的基本信息2. 统计像素值的直方图3. 统计像素值的总和4. 统计非零像素的数量

Python使用asyncio实现异步操作的示例

《Python使用asyncio实现异步操作的示例》本文主要介绍了Python使用asyncio实现异步操作的示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋... 目录1. 基础概念2. 实现异步 I/O 的步骤2.1 定义异步函数2.2 使用 await 等待异

spring 参数校验Validation示例详解

《spring参数校验Validation示例详解》Spring提供了Validation工具类来实现对客户端传来的请求参数的有效校验,本文给大家介绍spring参数校验Validation示例详... 目录前言一、Validation常见的校验注解二、Validation的简单应用三、分组校验四、自定义校