2.9-tf2-数据增强-tf_flowers

2023-11-07 19:51
文章标签 数据 tf 2.9 增强 flowers tf2

本文主要是介绍2.9-tf2-数据增强-tf_flowers,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

    • 1.导入包
    • 2.加载数据
    • 3.数据预处理
    • 4.数据增强
    • 5.预处理层的两种方法
    • 6.把与处理层用在数据集上
    • 7.训练模型
    • 8.自定义数据增强
    • 9.Using tf.image

tf_flowers数据集
在这里插入图片描述

1.导入包

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfdsfrom tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

2.加载数据

(train_ds, val_ds, test_ds), metadata = tfds.load('tf_flowers',split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],with_info=True,as_supervised=True,
)
#The flowers dataset has five classes.num_classes = metadata.features['label'].num_classes
print(num_classes)
5

复现数据:

#Let's retrieve an image from the dataset and use it to demonstrate data augmentation.get_label_name = metadata.features['label'].int2strimage, label = next(iter(train_ds))
_ = plt.imshow(image)
_ = plt.title(get_label_name(label))

在这里插入图片描述

3.数据预处理

Use Keras preprocessing layers

#Resizing and rescaling
#You can use preprocessing layers to resize your images to a consistent shape, and to rescale pixel values.IMG_SIZE = 180resize_and_rescale = tf.keras.Sequential([layers.experimental.preprocessing.Resizing(IMG_SIZE, IMG_SIZE),layers.experimental.preprocessing.Rescaling(1./255)
])
#Note: the rescaling layer above standardizes pixel values to [0,1]. If instead you wanted [-1,1], you would write Rescaling(1./127.5, offset=-1).
result = resize_and_rescale(image)
_ = plt.imshow(result)

在这里插入图片描述

#You can verify the pixels are in [0-1].print("Min and max pixel values:", result.numpy().min(), result.numpy().max())
Min and max pixel values: 0.0 1.0

4.数据增强

#Data augmentation
#You can use preprocessing layers for data augmentation as well.#Let's create a few preprocessing layers and apply them repeatedly to the same image.data_augmentation = tf.keras.Sequential([layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),layers.experimental.preprocessing.RandomRotation(0.2),
])
# Add the image to a batch
image = tf.expand_dims(image, 0)
plt.figure(figsize=(10, 10))
for i in range(9):augmented_image = data_augmentation(image)ax = plt.subplot(3, 3, i + 1)plt.imshow(augmented_image[0])plt.axis("off")

在这里插入图片描述

#There are a variety of preprocessing layers you can use for data augmentation including layers.RandomContrast, layers.RandomCrop, layers.RandomZoom, and others.

5.预处理层的两种方法

There are two ways you can use these preprocessing layers, with important tradeoffs.

  1. 第一种方法
Option 1: Make the preprocessing layers part of your model
model = tf.keras.Sequential([resize_and_rescale,data_augmentation,layers.Conv2D(16, 3, padding='same', activation='relu'),layers.MaxPooling2D(),# Rest of your model
])
There are two important points to be aware of in this case:Data augmentation will run on-device, synchronously with the rest of your layers, and benefit from GPU acceleration.When you export your model using model.save, the preprocessing layers will be saved along with the rest of your model. If you later deploy this model, it will automatically standardize images (according to the configuration of your layers). This can save you from the effort of having to reimplement that logic server-side.Note: Data augmentation is inactive at test time so input images will only be augmented during calls to model.fit (not model.evaluate or model.predict).
  1. 第二种方法:
#Option 2: Apply the preprocessing layers to your dataset
aug_ds = train_ds.map(lambda x, y: (resize_and_rescale(x, training=True), y))
With this approach, you use Dataset.map to create a dataset that yields batches of augmented images. In this case:Data augmentation will happen asynchronously on the CPU, and is non-blocking. You can overlap the training of your model on the GPU with data preprocessing, using Dataset.prefetch, shown below.
In this case the prepreprocessing layers will not be exported with the model when you call model.save. You will need to attach them to your model before saving it or reimplement them server-side. After training, you can attach the preprocessing layers before export.

6.把与处理层用在数据集上

Configure the train, validation, and test datasets with the preprocessing layers you created above. You will also configure the datasets for performance, using parallel reads and buffered prefetching to yield batches from disk without I/O become blocking. 
Note: data augmentation should only be applied to the training set.
batch_size = 32
AUTOTUNE = tf.data.experimental.AUTOTUNEdef prepare(ds, shuffle=False, augment=False):# Resize and rescale all datasetsds = ds.map(lambda x, y: (resize_and_rescale(x), y), num_parallel_calls=AUTOTUNE)if shuffle:ds = ds.shuffle(1000)# Batch all datasetsds = ds.batch(batch_size)# Use data augmentation only on the training setif augment:ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), num_parallel_calls=AUTOTUNE)# Use buffered prefecting on all datasetsreturn ds.prefetch(buffer_size=AUTOTUNE)
train_ds = prepare(train_ds, shuffle=True, augment=True)
val_ds = prepare(val_ds)
test_ds = prepare(test_ds)

7.训练模型

model = tf.keras.Sequential([layers.Conv2D(16, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Conv2D(32, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Conv2D(64, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Flatten(),layers.Dense(128, activation='relu'),layers.Dense(num_classes)
])
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
epochs=5
history = model.fit(train_ds,validation_data=val_ds,epochs=epochs
)
Epoch 1/5
92/92 [==============================] - 30s 315ms/step - loss: 1.5078 - accuracy: 0.3428 - val_loss: 1.0809 - val_accuracy: 0.6240
Epoch 2/5
92/92 [==============================] - 28s 303ms/step - loss: 1.0781 - accuracy: 0.5724 - val_loss: 0.9762 - val_accuracy: 0.6322
Epoch 3/5
92/92 [==============================] - 28s 295ms/step - loss: 1.0083 - accuracy: 0.5900 - val_loss: 0.9570 - val_accuracy: 0.6376
Epoch 4/5
92/92 [==============================] - 28s 300ms/step - loss: 0.9537 - accuracy: 0.6116 - val_loss: 0.9081 - val_accuracy: 0.6485
Epoch 5/5
92/92 [==============================] - 28s 301ms/step - loss: 0.8816 - accuracy: 0.6525 - val_loss: 0.8353 - val_accuracy: 0.6594
loss, acc = model.evaluate(test_ds)
print("Accuracy", acc)
12/12 [==============================] - 1s 83ms/step - loss: 0.8226 - accuracy: 0.6567
Accuracy 0.6566757559776306

8.自定义数据增强

First, you will create a layers.Lambda layer. This is a good way to write concise code. Next, you will write a new layer via subclassing, which gives you more control. Both layers will randomly invert the colors in an image, accoring to some probability.
def random_invert_img(x, p=0.5):if  tf.random.uniform([]) < p:x = (255-x)else:xreturn xdef random_invert(factor=0.5):return layers.Lambda(lambda x: random_invert_img(x, factor))random_invert = random_invert()plt.figure(figsize=(10, 10))
for i in range(9):augmented_image = random_invert(image)ax = plt.subplot(3, 3, i + 1)plt.imshow(augmented_image[0].numpy().astype("uint8"))plt.axis("off")

在这里插入图片描述

#Next, implement a custom layer by subclassing.class RandomInvert(layers.Layer):def __init__(self, factor=0.5, **kwargs):super().__init__(**kwargs)self.factor = factordef call(self, x):return random_invert_img(x)_ = plt.imshow(RandomInvert()(image)[0])

在这里插入图片描述

9.Using tf.image

Since the flowers dataset was previously configured with data augmentation, let's reimport it to start fresh.(train_ds, val_ds, test_ds), metadata = tfds.load('tf_flowers',split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],with_info=True,as_supervised=True,
)
#Retrieve an image to work with.image, label = next(iter(train_ds))
_ = plt.imshow(image)
_ = plt.title(get_label_name(label))

在这里插入图片描述

Let's use the following function to visualize and compare the original and augmented images side-by-side.def visualize(original, augmented):fig = plt.figure()plt.subplot(1,2,1)plt.title('Original image')plt.imshow(original)plt.subplot(1,2,2)plt.title('Augmented image')plt.imshow(augmented)
#Data augmentation
#Flipping the image
3Flip the image either vertically or horizontally.flipped = tf.image.flip_left_right(image)
visualize(image, flipped)

在这里插入图片描述

#Grayscale an image.grayscaled = tf.image.rgb_to_grayscale(image)
visualize(image, tf.squeeze(grayscaled))
_ = plt.colorbar()

在这里插入图片描述

#Saturate an image by providing a saturation factor.saturated = tf.image.adjust_saturation(image, 3)
visualize(image, saturated)

在这里插入图片描述

#Change image brightness
#Change the brightness of image by providing a brightness factor.bright = tf.image.adjust_brightness(image, 0.4)
visualize(image, bright)

在这里插入图片描述

#Center crop the image
#Crop the image from center up to the image part you desire.cropped = tf.image.central_crop(image, central_fraction=0.5)
visualize(image,cropped)

在这里插入图片描述

#Rotate the image
#Rotate an image by 90 degrees.rotated = tf.image.rot90(image)
visualize(image, rotated)

在这里插入图片描述

#Apply augmentation to a dataset
#As before, apply data augmentation to a dataset using Dataset.map.def resize_and_rescale(image, label):image = tf.cast(image, tf.float32)image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])image = (image / 255.0)return image, labeldef augment(image,label):image, label = resize_and_rescale(image, label)# Add 6 pixels of paddingimage = tf.image.resize_with_crop_or_pad(image, IMG_SIZE + 6, IMG_SIZE + 6) # Random crop back to the original sizeimage = tf.image.random_crop(image, size=[IMG_SIZE, IMG_SIZE, 3])image = tf.image.random_brightness(image, max_delta=0.5) # Random brightnessimage = tf.clip_by_value(image, 0, 1)return image, label
#Configure the datasets
train_ds = (train_ds.shuffle(1000).map(augment, num_parallel_calls=AUTOTUNE).batch(batch_size).prefetch(AUTOTUNE)
) val_ds = (val_ds.map(resize_and_rescale, num_parallel_calls=AUTOTUNE).batch(batch_size).prefetch(AUTOTUNE)
)test_ds = (test_ds.map(resize_and_rescale, num_parallel_calls=AUTOTUNE).batch(batch_size).prefetch(AUTOTUNE)
)#These datasets can now be used to train a model as shown previously.

这篇关于2.9-tf2-数据增强-tf_flowers的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/365916

相关文章

MySQL InnoDB引擎ibdata文件损坏/删除后使用frm和ibd文件恢复数据

《MySQLInnoDB引擎ibdata文件损坏/删除后使用frm和ibd文件恢复数据》mysql的ibdata文件被误删、被恶意修改,没有从库和备份数据的情况下的数据恢复,不能保证数据库所有表数据... 参考:mysql Innodb表空间卸载、迁移、装载的使用方法注意!此方法只适用于innodb_fi

mysql通过frm和ibd文件恢复表_mysql5.7根据.frm和.ibd文件恢复表结构和数据

《mysql通过frm和ibd文件恢复表_mysql5.7根据.frm和.ibd文件恢复表结构和数据》文章主要介绍了如何从.frm和.ibd文件恢复MySQLInnoDB表结构和数据,需要的朋友可以参... 目录一、恢复表结构二、恢复表数据补充方法一、恢复表结构(从 .frm 文件)方法 1:使用 mysq

mysql8.0无备份通过idb文件恢复数据的方法、idb文件修复和tablespace id不一致处理

《mysql8.0无备份通过idb文件恢复数据的方法、idb文件修复和tablespaceid不一致处理》文章描述了公司服务器断电后数据库故障的过程,作者通过查看错误日志、重新初始化数据目录、恢复备... 周末突然接到一位一年多没联系的妹妹打来电话,“刘哥,快来救救我”,我脑海瞬间冒出妙瓦底,电信火苲马扁.

golang获取prometheus数据(prometheus/client_golang包)

《golang获取prometheus数据(prometheus/client_golang包)》本文主要介绍了使用Go语言的prometheus/client_golang包来获取Prometheu... 目录1. 创建链接1.1 语法1.2 完整示例2. 简单查询2.1 语法2.2 完整示例3. 范围值

javaScript在表单提交时获取表单数据的示例代码

《javaScript在表单提交时获取表单数据的示例代码》本文介绍了五种在JavaScript中获取表单数据的方法:使用FormData对象、手动提取表单数据、使用querySelector获取单个字... 方法 1:使用 FormData 对象FormData 是一个方便的内置对象,用于获取表单中的键值

Rust中的BoxT之堆上的数据与递归类型详解

《Rust中的BoxT之堆上的数据与递归类型详解》本文介绍了Rust中的BoxT类型,包括其在堆与栈之间的内存分配,性能优势,以及如何利用BoxT来实现递归类型和处理大小未知类型,通过BoxT,Rus... 目录1. Box<T> 的基础知识1.1 堆与栈的分工1.2 性能优势2.1 递归类型的问题2.2

Python使用Pandas对比两列数据取最大值的五种方法

《Python使用Pandas对比两列数据取最大值的五种方法》本文主要介绍使用Pandas对比两列数据取最大值的五种方法,包括使用max方法、apply方法结合lambda函数、函数、clip方法、w... 目录引言一、使用max方法二、使用apply方法结合lambda函数三、使用np.maximum函数

Redis的数据过期策略和数据淘汰策略

《Redis的数据过期策略和数据淘汰策略》本文主要介绍了Redis的数据过期策略和数据淘汰策略,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一... 目录一、数据过期策略1、惰性删除2、定期删除二、数据淘汰策略1、数据淘汰策略概念2、8种数据淘汰策略

轻松上手MYSQL之JSON函数实现高效数据查询与操作

《轻松上手MYSQL之JSON函数实现高效数据查询与操作》:本文主要介绍轻松上手MYSQL之JSON函数实现高效数据查询与操作的相关资料,MySQL提供了多个JSON函数,用于处理和查询JSON数... 目录一、jsON_EXTRACT 提取指定数据二、JSON_UNQUOTE 取消双引号三、JSON_KE

Python给Excel写入数据的四种方法小结

《Python给Excel写入数据的四种方法小结》本文主要介绍了Python给Excel写入数据的四种方法小结,包含openpyxl库、xlsxwriter库、pandas库和win32com库,具有... 目录1. 使用 openpyxl 库2. 使用 xlsxwriter 库3. 使用 pandas 库