政安晨:【Keras机器学习示例演绎】(二十六)—— 图像相似性搜索的度量学习

本文主要是介绍政安晨:【Keras机器学习示例演绎】(二十六)—— 图像相似性搜索的度量学习,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

概述

设置

数据集

嵌入模型

测试


政安晨的个人主页:政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:在 CIFAR-10 图像上使用相似度量学习的示例。

概述


度量学习旨在训练能将输入嵌入高维空间的模型,从而使训练方案所定义的 "相似 "输入彼此靠近。这些模型一经训练,就能为下游系统生成对这种相似性有用的嵌入模型,例如作为搜索的排名信号,或作为另一种监督问题的预训练嵌入模型。

设置


将 Keras 后端设置为 tensorflow。

import osos.environ["KERAS_BACKEND"] = "tensorflow"import random
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from collections import defaultdict
from PIL import Image
from sklearn.metrics import ConfusionMatrixDisplay
import keras
from keras import layers

数据集


在本示例中,我们将使用 CIFAR-10 数据集。

from keras.datasets import cifar10(x_train, y_train), (x_test, y_test) = cifar10.load_data()x_train = x_train.astype("float32") / 255.0
y_train = np.squeeze(y_train)
x_test = x_test.astype("float32") / 255.0
y_test = np.squeeze(y_test)

为了了解数据集,我们可以将 25 个随机例子组成的网格可视化。

height_width = 32def show_collage(examples):box_size = height_width + 2num_rows, num_cols = examples.shape[:2]collage = Image.new(mode="RGB",size=(num_cols * box_size, num_rows * box_size),color=(250, 250, 250),)for row_idx in range(num_rows):for col_idx in range(num_cols):array = (np.array(examples[row_idx, col_idx]) * 255).astype(np.uint8)collage.paste(Image.fromarray(array), (col_idx * box_size, row_idx * box_size))# Double size for visualisation.collage = collage.resize((2 * num_cols * box_size, 2 * num_rows * box_size))return collage# Show a collage of 5x5 random images.
sample_idxs = np.random.randint(0, 50000, size=(5, 5))
examples = x_train[sample_idxs]
show_collage(examples)

度量学习提供的训练数据并不是明确的(X,y)对,而是使用以我们想要表达的相似性方式相关的多个实例。

在我们的例子中,我们将使用同一类别的实例来表示相似性;单个训练实例将不是一幅图像,而是同一类别的一对图像。

在提及这对图像时,我们将使用常见的度量学习名称:锚图像(随机选择的图像)和正图像(随机选择的另一张同类图像)。

为此,我们需要建立一种从类到该类实例的查询形式。在生成用于训练的数据时,我们将从该查找表中采样。

class_idx_to_train_idxs = defaultdict(list)
for y_train_idx, y in enumerate(y_train):class_idx_to_train_idxs[y].append(y_train_idx)class_idx_to_test_idxs = defaultdict(list)
for y_test_idx, y in enumerate(y_test):class_idx_to_test_idxs[y].append(y_test_idx)

在本例中,我们使用的是最简单的训练方法;一个批次将由分布在各个类别中的(锚、正)对组成。

学习的目标是使锚和正对在批次中更接近、更远离其他实例。

在这种情况下,批次大小将由类的数量决定;对于 CIFAR-10,类的数量为 10。

num_classes = 10class AnchorPositivePairs(keras.utils.Sequence):def __init__(self, num_batches):super().__init__()self.num_batches = num_batchesdef __len__(self):return self.num_batchesdef __getitem__(self, _idx):x = np.empty((2, num_classes, height_width, height_width, 3), dtype=np.float32)for class_idx in range(num_classes):examples_for_class = class_idx_to_train_idxs[class_idx]anchor_idx = random.choice(examples_for_class)positive_idx = random.choice(examples_for_class)while positive_idx == anchor_idx:positive_idx = random.choice(examples_for_class)x[0, class_idx] = x_train[anchor_idx]x[1, class_idx] = x_train[positive_idx]return x

我们可以用另一张拼贴图来直观地展示一批结果。上排显示从 10 个类别中随机选择的锚点,下排显示相应的 10 个阳性锚点。

examples = next(iter(AnchorPositivePairs(num_batches=1)))show_collage(examples)

嵌入模型


我们定义了一个带有 train_step 的自定义模型,它首先嵌入锚点和正点,然后使用它们的成对点乘作为 softmax 的对数。

class EmbeddingModel(keras.Model):def train_step(self, data):# Note: Workaround for open issue, to be removed.if isinstance(data, tuple):data = data[0]anchors, positives = data[0], data[1]with tf.GradientTape() as tape:# Run both anchors and positives through model.anchor_embeddings = self(anchors, training=True)positive_embeddings = self(positives, training=True)# Calculate cosine similarity between anchors and positives. As they have# been normalised this is just the pair wise dot products.similarities = keras.ops.einsum("ae,pe->ap", anchor_embeddings, positive_embeddings)# Since we intend to use these as logits we scale them by a temperature.# This value would normally be chosen as a hyper parameter.temperature = 0.2similarities /= temperature# We use these similarities as logits for a softmax. The labels for# this call are just the sequence [0, 1, 2, ..., num_classes] since we# want the main diagonal values, which correspond to the anchor/positive# pairs, to be high. This loss will move embeddings for the# anchor/positive pairs together and move all other pairs apart.sparse_labels = keras.ops.arange(num_classes)loss = self.compute_loss(y=sparse_labels, y_pred=similarities)# Calculate gradients and apply via optimizer.gradients = tape.gradient(loss, self.trainable_variables)self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))# Update and return metrics (specifically the one for the loss value).for metric in self.metrics:# Calling `self.compile` will by default add a [`keras.metrics.Mean`](/api/metrics/metrics_wrappers#mean-class) lossif metric.name == "loss":metric.update_state(loss)else:metric.update_state(sparse_labels, similarities)return {m.name: m.result() for m in self.metrics}

接下来,我们将介绍从图像映射到嵌入空间的结构。该模型由一系列 2d 卷积组成,然后进行全局池化,最后线性投影到嵌入空间。按照度量学习的常见方法,我们对嵌入空间进行归一化处理,以便使用简单的点积来衡量相似性。为了简单起见,我们有意缩小了模型的规模。

inputs = layers.Input(shape=(height_width, height_width, 3))
x = layers.Conv2D(filters=32, kernel_size=3, strides=2, activation="relu")(inputs)
x = layers.Conv2D(filters=64, kernel_size=3, strides=2, activation="relu")(x)
x = layers.Conv2D(filters=128, kernel_size=3, strides=2, activation="relu")(x)
x = layers.GlobalAveragePooling2D()(x)
embeddings = layers.Dense(units=8, activation=None)(x)
embeddings = layers.UnitNormalization()(embeddings)model = EmbeddingModel(inputs, embeddings)

最后,我们运行训练。在 Google Colab GPU 实例上,这大约需要一分钟。

model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3),loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)history = model.fit(AnchorPositivePairs(num_batches=1000), epochs=20)plt.plot(history.history["loss"])
plt.show()
Epoch 1/2077/1000 ━[37m━━━━━━━━━━━━━━━━━━━  1s 2ms/step - loss: 2.2962WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1700589927.295343 3724442 device_compiler.h:187] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.1000/1000 ━━━━━━━━━━━━━━━━━━━━ 6s 2ms/step - loss: 2.2504
Epoch 2/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 2.1068
Epoch 3/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 2.0646
Epoch 4/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 2.0210
Epoch 5/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.9857
Epoch 6/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.9543
Epoch 7/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.9175
Epoch 8/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.8740
Epoch 9/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.8474
Epoch 10/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.8380
Epoch 11/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.8146
Epoch 12/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7658
Epoch 13/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7512
Epoch 14/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7671
Epoch 15/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7245
Epoch 16/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7001
Epoch 17/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7099
Epoch 18/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.6775
Epoch 19/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.6547
Epoch 20/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.6356

测试


我们可以将该模型应用于测试集,并考虑嵌入空间中的近邻,从而检验该模型的质量。

首先,我们嵌入测试集并计算所有近邻。回想一下,由于嵌入是单位长度的,我们可以通过点积计算余弦相似度。

near_neighbours_per_example = 10embeddings = model.predict(x_test)
gram_matrix = np.einsum("ae,be->ab", embeddings, embeddings)
near_neighbours = np.argsort(gram_matrix.T)[:, -(near_neighbours_per_example + 1) :]
 313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step

为了直观地检验这些嵌入,我们可以为 5 个随机例子建立一个近邻拼贴图。下图的第一列是随机选取的图像,随后的 10 列按相似度排序显示了近邻图像。

num_collage_examples = 5examples = np.empty((num_collage_examples,near_neighbours_per_example + 1,height_width,height_width,3,),dtype=np.float32,
)
for row_idx in range(num_collage_examples):examples[row_idx, 0] = x_test[row_idx]anchor_near_neighbours = reversed(near_neighbours[row_idx][:-1])for col_idx, nn_idx in enumerate(anchor_near_neighbours):examples[row_idx, col_idx + 1] = x_test[nn_idx]show_collage(examples)

我们还可以通过混淆矩阵来考虑近邻的正确性,从而对性能进行量化。

让我们从 10 个类别中各抽取 10 个例子,并将它们的近邻视为一种预测形式;也就是说,该例子及其近邻是否属于同一类别?

我们观察到,每个动物类别的表现一般都很好,与其他动物类别混淆的情况最多。车辆类别也遵循同样的模式。

confusion_matrix = np.zeros((num_classes, num_classes))# For each class.
for class_idx in range(num_classes):# Consider 10 examples.example_idxs = class_idx_to_test_idxs[class_idx][:10]for y_test_idx in example_idxs:# And count the classes of its near neighbours.for nn_idx in near_neighbours[y_test_idx][:-1]:nn_class_idx = y_test[nn_idx]confusion_matrix[class_idx, nn_class_idx] += 1# Display a confusion matrix.
labels = ["Airplane","Automobile","Bird","Cat","Deer","Dog","Frog","Horse","Ship","Truck",
]
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=labels)
disp.plot(include_values=True, cmap="viridis", ax=None, xticks_rotation="vertical")
plt.show()


这篇关于政安晨:【Keras机器学习示例演绎】(二十六)—— 图像相似性搜索的度量学习的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/951419

相关文章

SpringBoot线程池配置使用示例详解

《SpringBoot线程池配置使用示例详解》SpringBoot集成@Async注解,支持线程池参数配置(核心数、队列容量、拒绝策略等)及生命周期管理,结合监控与任务装饰器,提升异步处理效率与系统... 目录一、核心特性二、添加依赖三、参数详解四、配置线程池五、应用实践代码说明拒绝策略(Rejected

SQL中如何添加数据(常见方法及示例)

《SQL中如何添加数据(常见方法及示例)》SQL全称为StructuredQueryLanguage,是一种用于管理关系数据库的标准编程语言,下面给大家介绍SQL中如何添加数据,感兴趣的朋友一起看看吧... 目录在mysql中,有多种方法可以添加数据。以下是一些常见的方法及其示例。1. 使用INSERT I

SpringBoot中SM2公钥加密、私钥解密的实现示例详解

《SpringBoot中SM2公钥加密、私钥解密的实现示例详解》本文介绍了如何在SpringBoot项目中实现SM2公钥加密和私钥解密的功能,通过使用Hutool库和BouncyCastle依赖,简化... 目录一、前言1、加密信息(示例)2、加密结果(示例)二、实现代码1、yml文件配置2、创建SM2工具

MySQL 定时新增分区的实现示例

《MySQL定时新增分区的实现示例》本文主要介绍了通过存储过程和定时任务实现MySQL分区的自动创建,解决大数据量下手动维护的繁琐问题,具有一定的参考价值,感兴趣的可以了解一下... mysql创建好分区之后,有时候会需要自动创建分区。比如,一些表数据量非常大,有些数据是热点数据,按照日期分区MululbU

Python函数作用域示例详解

《Python函数作用域示例详解》本文介绍了Python中的LEGB作用域规则,详细解析了变量查找的四个层级,通过具体代码示例,展示了各层级的变量访问规则和特性,对python函数作用域相关知识感兴趣... 目录一、LEGB 规则二、作用域实例2.1 局部作用域(Local)2.2 闭包作用域(Enclos

C++20管道运算符的实现示例

《C++20管道运算符的实现示例》本文简要介绍C++20管道运算符的使用与实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录标准库的管道运算符使用自己实现类似的管道运算符我们不打算介绍太多,因为它实际属于c++20最为重要的

Java中调用数据库存储过程的示例代码

《Java中调用数据库存储过程的示例代码》本文介绍Java通过JDBC调用数据库存储过程的方法,涵盖参数类型、执行步骤及数据库差异,需注意异常处理与资源管理,以优化性能并实现复杂业务逻辑,感兴趣的朋友... 目录一、存储过程概述二、Java调用存储过程的基本javascript步骤三、Java调用存储过程示

ModelMapper基本使用和常见场景示例详解

《ModelMapper基本使用和常见场景示例详解》ModelMapper是Java对象映射库,支持自动映射、自定义规则、集合转换及高级配置(如匹配策略、转换器),可集成SpringBoot,减少样板... 目录1. 添加依赖2. 基本用法示例:简单对象映射3. 自定义映射规则4. 集合映射5. 高级配置匹

C++11作用域枚举(Scoped Enums)的实现示例

《C++11作用域枚举(ScopedEnums)的实现示例》枚举类型是一种非常实用的工具,C++11标准引入了作用域枚举,也称为强类型枚举,本文主要介绍了C++11作用域枚举(ScopedEnums... 目录一、引言二、传统枚举类型的局限性2.1 命名空间污染2.2 整型提升问题2.3 类型转换问题三、C

Java实现自定义table宽高的示例代码

《Java实现自定义table宽高的示例代码》在桌面应用、管理系统乃至报表工具中,表格(JTable)作为最常用的数据展示组件,不仅承载对数据的增删改查,还需要配合布局与视觉需求,而JavaSwing... 目录一、项目背景详细介绍二、项目需求详细介绍三、相关技术详细介绍四、实现思路详细介绍五、完整实现代码