Implementing text2draw with the CLIP model

2024-08-31 16:36
Tags: implementation, model, clip, text2draw

This article walks through implementing text2draw with the CLIP model; hopefully it offers a useful reference for developers tackling the same problem.

Reference papers

CLIPDraw: Exploring Text-to-Drawing Synthesis through Language-Image Encoders
StyleCLIPDraw: Coupling Content and Style in Text-to-Drawing Translation

Practice

Code with image augmentation

The script below renders a closed Bezier path with pydiffvg, encodes four randomly augmented views of the raster output with CLIP ViT-B/32, and optimizes the path's control points so that the render's cosine similarity to the prompt "A picture of triangle" increases.

import math
import collections
import CLIP_.clip as clip
import torch
import torch.nn as nn
from torchvision import models, transforms
import numpy as np
import webp
from PIL import Image
import skimage
import torchvision
import pydiffvg
import os
import torch.nn.functional as F


class GeometrymatchLoss(torch.nn.Module):
    def __init__(self, device, reference_images_path):
        super(GeometrymatchLoss, self).__init__()
        self.device = device
        self.model, clip_preprocess = clip.load('ViT-B/32', self.device, jit=False)
        self.model.eval()
        # clip normalisation
        self.preprocess = transforms.Compose([clip_preprocess.transforms[0], clip_preprocess.transforms[-1]])
        self.reference_images_feature = self.reference_images_feature(reference_images_path)
        self.reference_images_feature = self.reference_images_feature / self.reference_images_feature.norm(dim=-1, keepdim=True)
        self.text = clip.tokenize(["A picture of triangle"]).to(device)
        self.text_features = self.model.encode_text(self.text)
        # self.text_features = self.text_features / self.text_features.norm(dim=-1, keepdim=True)
        print("text_features.requires_grad:", self.text_features.requires_grad)
        self.text_features = self.text_features.detach()
        self.shape_groups = [pydiffvg.ShapeGroup(shape_ids=torch.tensor([0]),
                                                 fill_color=torch.tensor([0.0, 0.0, 0.0, 1.0]),
                                                 stroke_color=torch.tensor([0.0, 0.0, 0.0, 1.0]))]
        # Image Augmentation Transformation
        self.augment_trans = transforms.Compose([
            transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
            transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
        ])

    def forward(self, t, canvas_width, canvas_height, shapes):
        scene_args = pydiffvg.RenderFunction.serialize_scene(canvas_width, canvas_height, shapes, self.shape_groups)
        # render the vector scene to a raster image
        render = pydiffvg.RenderFunction.apply
        target = render(canvas_width, canvas_height, 2, 2, 0, None, *scene_args)
        if target.shape[-1] == 4:
            target = self.compose_image_with_white_background(target)
        if t % 100 == 0:
            pydiffvg.imwrite(target.cpu(), f'learn/log_augs/output_{t}.png', gamma=2.2)
        # targets_ = self.preprocess(target.permute(2, 0, 1).unsqueeze(0)).to(self.device)
        img = target.unsqueeze(0)
        img = img.permute(0, 3, 1, 2)
        loss = 0
        NUM_AUGS = 4
        img_augs = []
        for n in range(NUM_AUGS):
            img_augs.append(self.augment_trans(img))
        im_batch = torch.cat(img_augs)
        image_features = self.model.encode_image(im_batch)
        # logit_scale = self.model.logit_scale.exp()
        for n in range(NUM_AUGS):
            loss -= torch.cosine_similarity(self.text_features, image_features[n:n + 1], dim=1)
        return loss

    def compose_image_with_white_background(self, img: torch.tensor) -> torch.tensor:
        if img.shape[-1] == 3:  # return img if it is already rgb
            return img
        # Compose img with white background
        alpha = img[:, :, 3:4]
        img = alpha * img[:, :, :3] + (1 - alpha) * torch.ones(img.shape[0], img.shape[1], 3, device=self.device)
        return img

    def read_png_image_from_path(self, path_to_png_image: str) -> torch.tensor:
        numpy_image = skimage.io.imread(path_to_png_image)
        normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0
        resizer = torchvision.transforms.Resize((224, 224))
        resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)
        return resized_image

    def reference_images_feature(self, reference_images_path):
        reference_images_num = len(os.listdir(reference_images_path))
        reference_images_feature = []
        for i in range(reference_images_num):
            i_reference_image = self.read_png_image_from_path(os.path.join(reference_images_path, str(i) + ".png"))
            if i_reference_image.shape[-1] == 4:
                i_reference_image = self.compose_image_with_white_background(i_reference_image)
            # targets_ = self.preprocess(i_reference_image.permute(2, 0, 1).unsqueeze(0)).to(self.device)
            i_reference_image_features = self.model.encode_image(
                i_reference_image.permute(2, 0, 1).unsqueeze(0).to(self.device)).detach()
            reference_images_feature.append(i_reference_image_features)
        return torch.cat(reference_images_feature)


def read_png_image_from_path(path_to_png_image: str) -> torch.tensor:
    if path_to_png_image.endswith('.webp'):
        numpy_image = np.array(webp.load_image(path_to_png_image))
    else:
        numpy_image = skimage.io.imread(path_to_png_image)
    normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0
    resizer = torchvision.transforms.Resize((224, 224))
    resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)
    return resized_image


if __name__ == '__main__':
    torch.autograd.set_detect_anomaly(True)
    from tqdm import tqdm

    def get_bezier_circle(radius: float = 80,
                          segments: int = 4,
                          bias: np.array = np.asarray([100., 100.])):
        deg = torch.arange(0, segments * 3 + 1) * 2 * np.pi / (segments * 3 + 1)
        points = torch.stack((torch.cos(deg), torch.sin(deg))).T
        points = points * radius + torch.tensor(bias).unsqueeze(dim=0)
        points = points.type(torch.FloatTensor).contiguous()
        return points

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    matchLoss = GeometrymatchLoss(device, "reference_images/")
    # print(matchLoss.reference_images_feature.shape)
    # img1 = read_png_image_from_path('learn/output.png')
    canvas_width, canvas_height = 224, 224
    num_segments = 4
    points1 = get_bezier_circle()
    path = pydiffvg.Path(num_control_points=torch.tensor(num_segments * [2] + [0], dtype=torch.int32),
                         points=points1, stroke_width=torch.tensor(2.0), is_closed=True)
    shapes = [path]
    path.points.requires_grad = True
    print(id(path.points))
    print(id(points1))
    points_vars = []
    points_vars.append(path.points)
    points_optim = torch.optim.Adam(points_vars, lr=1)
    pbar = tqdm(range(100000))
    print(points1)
    for t in pbar:
        points_optim.zero_grad()
        match_loss = matchLoss(t, 224, 224, shapes)
        match_loss.backward()
        # print(path.points.grad)
        points_optim.step()
        pbar.set_postfix({"match_loss": f"{match_loss.item()}"})

[Figure: result after 1000 iterations]
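Because the optimization acts on a vector path, the final result can also be kept as a vector file rather than just the PNG snapshots written during training. A minimal sketch, assuming the `shapes`, `canvas_width`/`canvas_height`, and `matchLoss` objects from the script above, and that your diffvg build provides `pydiffvg.save_svg`:

# After the training loop: export the optimized path as an SVG.
# (shapes and matchLoss.shape_groups come from the script above.)
pydiffvg.save_svg("learn/final.svg", canvas_width, canvas_height, shapes, matchLoss.shape_groups)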

Without image augmentation

This version differs only in forward: CLIP encodes the raw render directly instead of the batch of augmented views (the augmented batch is still built but goes unused), and the text features are explicitly normalized.

import math
import collections
import CLIP_.clip as clip
import torch
import torch.nn as nn
from torchvision import models, transforms
import numpy as np
import webp
from PIL import Image
import skimage
import torchvision
import pydiffvg
import os
import torch.nn.functional as F


class GeometrymatchLoss(torch.nn.Module):
    def __init__(self, device, reference_images_path):
        super(GeometrymatchLoss, self).__init__()
        self.device = device
        self.model, clip_preprocess = clip.load('ViT-B/32', self.device, jit=False)
        self.model.eval()
        # clip normalisation
        self.preprocess = transforms.Compose([clip_preprocess.transforms[0], clip_preprocess.transforms[-1]])
        # self.preprocess = transforms.Compose([clip_preprocess.transforms[-1]])  # clip normalisation
        self.reference_images_feature = self.reference_images_feature(reference_images_path)
        self.reference_images_feature = self.reference_images_feature / self.reference_images_feature.norm(dim=-1, keepdim=True)
        self.text = clip.tokenize(["A picture of triangle"]).to(device)
        # self.text = clip.tokenize(["A picture of rectangle", "A picture of triangle", "A picture of circle",
        #                            "A picture of pentagon", "A picture of five-pointed star"]).to(device)
        self.text_features = self.model.encode_text(self.text)
        self.text_features = self.text_features / self.text_features.norm(dim=-1, keepdim=True)
        print("text_features.requires_grad:", self.text_features.requires_grad)
        self.text_features = self.text_features.detach()
        self.shape_groups = [pydiffvg.ShapeGroup(shape_ids=torch.tensor([0]),
                                                 fill_color=torch.tensor([0.0, 0.0, 0.0, 1.0]),
                                                 stroke_color=torch.tensor([0.0, 0.0, 0.0, 1.0]))]
        # Image Augmentation Transformation
        self.augment_trans = transforms.Compose([
            transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
            transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
        ])

    def forward(self, t, canvas_width, canvas_height, shapes):
        scene_args = pydiffvg.RenderFunction.serialize_scene(canvas_width, canvas_height, shapes, self.shape_groups)
        # render the vector scene to a raster image
        render = pydiffvg.RenderFunction.apply
        target = render(canvas_width, canvas_height, 2, 2, 0, None, *scene_args)
        if target.shape[-1] == 4:
            target = self.compose_image_with_white_background(target)
        if t % 100 == 0:
            pydiffvg.imwrite(target.cpu(), f'learn/log/output_{t}.png', gamma=2.2)
        # targets_ = self.preprocess(target.permute(2, 0, 1).unsqueeze(0)).to(self.device)
        img = target.unsqueeze(0)
        img = img.permute(0, 3, 1, 2)
        loss = 0
        NUM_AUGS = 4
        img_augs = []
        for n in range(NUM_AUGS):
            img_augs.append(self.augment_trans(img))
        im_batch = torch.cat(img_augs)
        # NOTE: im_batch is built but never used in this version;
        # CLIP encodes the single raw render instead.
        image_features = self.model.encode_image(img)
        self.targets_features: torch.tensor = image_features[0]
        self.targets_features = self.targets_features / self.targets_features.norm(dim=-1, keepdim=True)
        loss -= torch.cosine_similarity(self.text_features, self.targets_features, dim=1)
        return loss

    def compose_image_with_white_background(self, img: torch.tensor) -> torch.tensor:
        if img.shape[-1] == 3:  # return img if it is already rgb
            return img
        # Compose img with white background
        alpha = img[:, :, 3:4]
        img = alpha * img[:, :, :3] + (1 - alpha) * torch.ones(img.shape[0], img.shape[1], 3, device=self.device)
        return img

    def read_png_image_from_path(self, path_to_png_image: str) -> torch.tensor:
        numpy_image = skimage.io.imread(path_to_png_image)
        normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0
        resizer = torchvision.transforms.Resize((224, 224))
        resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)
        return resized_image

    def reference_images_feature(self, reference_images_path):
        reference_images_num = len(os.listdir(reference_images_path))
        reference_images_feature = []
        for i in range(reference_images_num):
            i_reference_image = self.read_png_image_from_path(os.path.join(reference_images_path, str(i) + ".png"))
            if i_reference_image.shape[-1] == 4:
                i_reference_image = self.compose_image_with_white_background(i_reference_image)
            # targets_ = self.preprocess(i_reference_image.permute(2, 0, 1).unsqueeze(0)).to(self.device)
            i_reference_image_features = self.model.encode_image(
                i_reference_image.permute(2, 0, 1).unsqueeze(0).to(self.device)).detach()
            reference_images_feature.append(i_reference_image_features)
        return torch.cat(reference_images_feature)


def read_png_image_from_path(path_to_png_image: str) -> torch.tensor:
    if path_to_png_image.endswith('.webp'):
        numpy_image = np.array(webp.load_image(path_to_png_image))
    else:
        numpy_image = skimage.io.imread(path_to_png_image)
    normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0
    resizer = torchvision.transforms.Resize((224, 224))
    resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)
    return resized_image


if __name__ == '__main__':
    torch.autograd.set_detect_anomaly(True)
    from tqdm import tqdm

    def get_bezier_circle(radius: float = 80,
                          segments: int = 4,
                          bias: np.array = np.asarray([100., 100.])):
        deg = torch.arange(0, segments * 3 + 1) * 2 * np.pi / (segments * 3 + 1)
        points = torch.stack((torch.cos(deg), torch.sin(deg))).T
        points = points * radius + torch.tensor(bias).unsqueeze(dim=0)
        points = points.type(torch.FloatTensor).contiguous()
        return points

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    matchLoss = GeometrymatchLoss(device, "reference_images/")
    # print(matchLoss.reference_images_feature.shape)
    # img1 = read_png_image_from_path('learn/output.png')
    canvas_width, canvas_height = 224, 224
    num_segments = 4
    points1 = get_bezier_circle()
    path = pydiffvg.Path(num_control_points=torch.tensor(num_segments * [2] + [0], dtype=torch.int32),
                         points=points1, stroke_width=torch.tensor(2.0), is_closed=True)
    shapes = [path]
    path.points.requires_grad = True
    print(id(path.points))
    print(id(points1))
    points_vars = []
    points_vars.append(path.points)
    points_optim = torch.optim.Adam(points_vars, lr=1)
    pbar = tqdm(range(100000))
    print(points1)
    for t in pbar:
        points_optim.zero_grad()
        match_loss = matchLoss(t, 224, 224, shapes)
        match_loss.backward()
        # print(path.points.grad)
        points_optim.step()
        pbar.set_postfix({"match_loss": f"{match_loss.item()}"})

[Figure: result after 1000 iterations]
[Figure: result after 2000 iterations]
[Figure: result after 4000 iterations]
[Figure: result after 8000 iterations]

Why the results are poor without image augmentation

Explanation from the paper CLIPDraw: Exploring Text-to-Drawing Synthesis through Language-Image Encoders

[Figure: excerpt from the CLIPDraw paper]

Explanation from the paper StyleCLIPDraw: Coupling Content and Style in Text-to-Drawing Translation

[Figure: excerpt from the StyleCLIPDraw paper]

My understanding

Many different images can score a high CLIP similarity against a single text prompt, and to a human eye some of them are simply unrelated to the text; without image augmentation, the optimization is very likely to converge to one of these local optima. Augmenting the image before computing the loss fixes this: under perspective warps, random crops, and similar transforms, an image that genuinely matches the text keeps roughly the same similarity no matter how it is transformed, whereas an unrelated image generally loses similarity once transformed. Averaging the loss over several augmented views therefore keeps these spurious images from dominating the result.
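To make this concrete, here is a minimal sketch contrasting the two scoring schemes. It assumes a loaded CLIP ViT-B/32 bound to `model`, the normalized prompt embedding bound to `text_features`, and a 1x3x224x224 render bound to `img`; these placeholder names are assumptions for illustration, not part of the scripts above.

import torch
from torchvision import transforms

# Same augmentations as the training scripts above.
augment = transforms.Compose([
    transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
    transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
])

def single_view_score(model, text_features, img):
    # What the no-augmentation version optimizes: one fixed view.
    # An unrelated render can still score high on this single view.
    image_features = model.encode_image(img)
    return torch.cosine_similarity(text_features, image_features, dim=1)

def augmented_score(model, text_features, img, num_augs=4):
    # Average similarity over random perspective/crop views: only a render
    # that genuinely depicts the prompt stays similar under every view.
    views = torch.cat([augment(img) for _ in range(num_augs)])
    image_features = model.encode_image(views)
    return torch.cosine_similarity(text_features, image_features, dim=1).mean()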

That concludes this article on implementing text2draw with the CLIP model; I hope it proves helpful.



http://www.chinasem.cn/article/1124419
