Image Style Transfer with TensorFlow

2024-04-27 02:48

This article walks through an implementation of image style transfer in TensorFlow: the training script, the test (inference) script, and the transformation network.

1. Results

Original image:

Style image:

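2. Training: a model is trained on a dataset of more than 80,000 images for one chosen style image

Dataset link (training data, 80,000+ images, about 12 GB, so fairly large):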
http://msvocds.blob.core.windows.net/coco2014/train2014.zip

Training code:

from __future__ import print_function
import sys, os, pdb
import numpy as np
import scipy.misc
from src.optimize import optimize
from argparse import ArgumentParser
from src.utils import save_img, get_img, exists, list_files
import evaluate  # iterative optimization

CONTENT_WEIGHT = 7.5e0
STYLE_WEIGHT = 1e2
TV_WEIGHT = 2e2

LEARNING_RATE = 1e-3
NUM_EPOCHS = 2
CHECKPOINT_DIR = 'checkpoints'
CHECKPOINT_ITERATIONS = 2000
VGG_PATH = 'data/imagenet-vgg-verydeep-19.mat'
TRAIN_PATH = 'data/'  # path to the training images
BATCH_SIZE = 4
DEVICE = '/gpu:0'   # compute on the GPU
FRAC_GPU = 1


# Check that every option has been set to a usable value
def check_opts(opts):
    exists(opts.checkpoint_dir, "checkpoint dir not found!")
    exists(opts.style, "style path not found!")
    exists(opts.train_path, "train path not found!")
    if opts.test or opts.test_dir:
        exists(opts.test, "test img not found!")
        exists(opts.test_dir, "test directory not found!")
    exists(opts.vgg_path, "vgg network data not found!")
    assert opts.epochs > 0
    assert opts.batch_size > 0
    assert opts.checkpoint_iterations > 0
    assert os.path.exists(opts.vgg_path)
    assert opts.content_weight >= 0
    assert opts.style_weight >= 0
    assert opts.tv_weight >= 0
    assert opts.learning_rate >= 0


def _get_files(img_dir):
    files = list_files(img_dir)
    return [os.path.join(img_dir, x) for x in files]


def main():
    # build_parser() is not shown here; it is part of the full source linked at the end
    parser = build_parser()
    options = parser.parse_args()
    check_opts(options)

    style_target = get_img(options.style)
    if not options.slow:
        content_targets = _get_files(options.train_path)
    elif options.test:
        content_targets = [options.test]

    kwargs = {
        "slow": options.slow,
        "epochs": options.epochs,
        "print_iterations": options.checkpoint_iterations,
        "batch_size": options.batch_size,
        "save_path": os.path.join(options.checkpoint_dir, 'fns.ckpt'),
        "learning_rate": options.learning_rate
    }

    # Slow mode optimizes a single image directly, so it needs more steps and a larger step size
    if options.slow:
        if options.epochs < 10:
            kwargs['epochs'] = 1000
        if options.learning_rate < 1:
            kwargs['learning_rate'] = 1e1

    args = [
        content_targets,
        style_target,
        options.content_weight,
        options.style_weight,
        options.tv_weight,
        options.vgg_path
    ]

    for preds, losses, i, epoch in optimize(*args, **kwargs):
        style_loss, content_loss, tv_loss, loss = losses

        print('Epoch %d, Iteration: %d, Loss: %s' % (epoch, i, loss))
        to_print = (style_loss, content_loss, tv_loss)
        print('style: %s, content:%s, tv: %s' % to_print)
        if options.test:
            assert options.test_dir != False
            preds_path = '%s/%s_%s.png' % (options.test_dir, epoch, i)
            if not options.slow:
                ckpt_dir = os.path.dirname(options.checkpoint_dir)
                evaluate.ffwd_to_img(options.test, preds_path,
                                     options.checkpoint_dir)
            else:
                # note: `img` is not defined anywhere in the code shown here
                save_img(preds_path, img)
    ckpt_dir = options.checkpoint_dir
    cmd_text = 'python evaluate.py --checkpoint %s ...' % ckpt_dir
    print("Training complete. For evaluation:\n    `%s`" % cmd_text)


if __name__ == '__main__':
    main()

Pretrained VGG model:
http://www.vlfeat.org/matconvnet/models/beta16/imagenet-vgg-verydeep-19.mat
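
The training script passes three weights (content, style, and total variation) to `optimize`, whose body in `src/optimize.py` is not shown in this article. As a rough illustration of what such weights usually multiply, here is a minimal NumPy sketch. The function names, the normalization constants, and the use of a single feature layer are assumptions for illustration, not the repository's actual implementation.

```python
import numpy as np

# Weights copied from the training script above
CONTENT_WEIGHT = 7.5e0
STYLE_WEIGHT = 1e2
TV_WEIGHT = 2e2


def gram_matrix(features):
    """Channel-by-channel correlations of an (H, W, C) feature map."""
    h, w, c = features.shape
    flat = features.reshape(h * w, c)
    return flat.T @ flat / flat.size


def content_loss(gen_feats, content_feats):
    """Mean squared difference between two feature maps."""
    return np.mean((gen_feats - content_feats) ** 2)


def style_loss(gen_feats, style_feats):
    """Mean squared difference between the two Gram matrices."""
    diff = gram_matrix(gen_feats) - gram_matrix(style_feats)
    return np.sum(diff ** 2) / diff.size


def tv_loss(image):
    """Total variation: squared differences between neighbouring pixels."""
    dy = image[1:, :, :] - image[:-1, :, :]
    dx = image[:, 1:, :] - image[:, :-1, :]
    return (np.sum(dy ** 2) + np.sum(dx ** 2)) / image.size


def total_loss(gen_img, gen_feats, content_feats, style_feats):
    """Weighted sum of the three terms (hypothetical helper)."""
    return (CONTENT_WEIGHT * content_loss(gen_feats, content_feats)
            + STYLE_WEIGHT * style_loss(gen_feats, style_feats)
            + TV_WEIGHT * tv_loss(gen_img))


if __name__ == '__main__':
    rng = np.random.default_rng(0)
    img = rng.random((32, 32, 3))          # stand-in for a generated image
    feats = rng.random((8, 8, 16))         # stand-in for VGG features
    other = rng.random((8, 8, 16))
    print(total_loss(img, feats, other, other))
```

In a real run the feature maps would come from several layers of the pretrained VGG network, and the loss would be minimized with respect to the transformation network's weights rather than evaluated on random arrays.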

3. Test code: given a model trained for one style, running it produces the stylized image

from __future__ import print_function
import sys
sys.path.insert(0, 'src')
import numpy as np, src.vgg, pdb, os
from src import transform
import scipy.misc
import tensorflow as tf
from src.utils import save_img, get_img, exists, list_files
from argparse import ArgumentParser
from collections import defaultdict
import time
import json
import subprocess
import numpy

BATCH_SIZE = 4
DEVICE = '/gpu:0'


def from_pipe(opts):
    command = ["ffprobe",
               '-v', "quiet",
               '-print_format', 'json',
               '-show_streams', opts.in_path]

    info = json.loads(str(subprocess.check_output(command), encoding="utf8"))
    width = int(info["streams"][0]["width"])
    height = int(info["streams"][0]["height"])
    fps = round(eval(info["streams"][0]["r_frame_rate"]))

    command = ["ffmpeg",
               '-loglevel', "quiet",
               '-i', opts.in_path,
               '-f', 'image2pipe',
               '-pix_fmt', 'rgb24',
               '-vcodec', 'rawvideo', '-']

    pipe_in = subprocess.Popen(command, stdout=subprocess.PIPE, bufsize=10 ** 9, stdin=None, stderr=None)

    command = ["ffmpeg",
               '-loglevel', "info",
               '-y',  # (optional) overwrite output file if it exists
               '-f', 'rawvideo',
               '-vcodec', 'rawvideo',
               '-s', str(width) + 'x' + str(height),  # size of one frame
               '-pix_fmt', 'rgb24',
               '-r', str(fps),  # frames per second
               '-i', '-',  # The input comes from a pipe
               '-an',  # Tells FFMPEG not to expect any audio
               '-c:v', 'libx264',
               '-preset', 'slow',
               '-crf', '18',
               opts.out]

    pipe_out = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=None, stderr=None)
    g = tf.Graph()
    soft_config = tf.ConfigProto(allow_soft_placement=True)
    soft_config.gpu_options.allow_growth = True

    with g.as_default(), g.device(opts.device), \
            tf.Session(config=soft_config) as sess:
        batch_shape = (opts.batch_size, height, width, 3)
        img_placeholder = tf.placeholder(tf.float32, shape=batch_shape,
                                         name='img_placeholder')
        preds = transform.net(img_placeholder)
        saver = tf.train.Saver()
        if os.path.isdir(opts.checkpoint):
            ckpt = tf.train.get_checkpoint_state(opts.checkpoint)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                raise Exception("No checkpoint found...")
        else:
            saver.restore(sess, opts.checkpoint)

        X = np.zeros(batch_shape, dtype=np.float32)
        nbytes = 3 * width * height  # bytes in one raw RGB frame
        read_input = True
        last = False

        while read_input:
            count = 0
            while count < opts.batch_size:
                raw_image = pipe_in.stdout.read(width * height * 3)

                if len(raw_image) != nbytes:
                    # End of the stream: either nothing is left, or the last batch is partial
                    if count == 0:
                        read_input = False
                    else:
                        last = True
                        X = X[:count]
                        batch_shape = (count, height, width, 3)
                        img_placeholder = tf.placeholder(tf.float32, shape=batch_shape,
                                                         name='img_placeholder')
                        preds = transform.net(img_placeholder)
                    break

                image = numpy.fromstring(raw_image, dtype='uint8')
                image = image.reshape((height, width, 3))
                X[count] = image
                count += 1

            if read_input:
                if last:
                    read_input = False
                _preds = sess.run(preds, feed_dict={img_placeholder: X})

                for i in range(0, batch_shape[0]):
                    img = np.clip(_preds[i], 0, 255).astype(np.uint8)
                    try:
                        pipe_out.stdin.write(img)
                    except IOError as err:
                        ffmpeg_error = pipe_out.stderr.read()
                        error = (str(err) +
                                 ("\n\nFFMPEG encountered "
                                  "the following error while writing file:"
                                  "\n\n %s" % ffmpeg_error))
                        read_input = False
                        print(error)
        pipe_out.terminate()
        pipe_in.terminate()
        pipe_out.stdin.close()
        pipe_in.stdout.close()
        del pipe_in
        del pipe_out


# get img_shape
def ffwd(data_in, paths_out, checkpoint_dir, device_t='/gpu:0', batch_size=4):
    assert len(paths_out) > 0
    is_paths = type(data_in[0]) == str
    if is_paths:
        assert len(data_in) == len(paths_out)
        img_shape = get_img(data_in[0]).shape
    else:
        assert data_in.size[0] == len(paths_out)
        # img_shape = X[0].shape

    g = tf.Graph()
    batch_size = min(len(paths_out), batch_size)
    curr_num = 0
    soft_config = tf.ConfigProto(allow_soft_placement=True)
    soft_config.gpu_options.allow_growth = True
    with g.as_default(), g.device(device_t), tf.Session(config=soft_config) as sess:
        batch_shape = (batch_size,) + img_shape
        img_placeholder = tf.placeholder(tf.float32, shape=batch_shape,
                                         name='img_placeholder')

        preds = transform.net(img_placeholder)
        saver = tf.train.Saver()
        if os.path.isdir(checkpoint_dir):
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                raise Exception("No checkpoint found...")
        else:
            saver.restore(sess, checkpoint_dir)

        # Process all full batches first
        num_iters = int(len(paths_out) / batch_size)
        for i in range(num_iters):
            pos = i * batch_size
            curr_batch_out = paths_out[pos:pos + batch_size]
            if is_paths:
                curr_batch_in = data_in[pos:pos + batch_size]
                X = np.zeros(batch_shape, dtype=np.float32)
                for j, path_in in enumerate(curr_batch_in):
                    img = get_img(path_in)
                    assert img.shape == img_shape, \
                        'Images have different dimensions. ' + \
                        'Resize images or use --allow-different-dimensions.'
                    X[j] = img
            else:
                X = data_in[pos:pos + batch_size]

            _preds = sess.run(preds, feed_dict={img_placeholder: X})
            for j, path_out in enumerate(curr_batch_out):
                save_img(path_out, _preds[j])

        remaining_in = data_in[num_iters * batch_size:]
        remaining_out = paths_out[num_iters * batch_size:]
    # Leftover images that do not fill a batch are processed one at a time
    if len(remaining_in) > 0:
        ffwd(remaining_in, remaining_out, checkpoint_dir,
             device_t=device_t, batch_size=1)


def ffwd_to_img(in_path, out_path, checkpoint_dir, device='/cpu:0'):
    paths_in, paths_out = [in_path], [out_path]
    ffwd(paths_in, paths_out, checkpoint_dir, batch_size=1, device_t=device)


def ffwd_different_dimensions(in_path, out_path, checkpoint_dir,
                              device_t=DEVICE, batch_size=4):
    # Group the images by shape so that each group can share one graph
    in_path_of_shape = defaultdict(list)
    out_path_of_shape = defaultdict(list)
    for i in range(len(in_path)):
        in_image = in_path[i]
        out_image = out_path[i]
        shape = "%dx%dx%d" % get_img(in_image).shape
        in_path_of_shape[shape].append(in_image)
        out_path_of_shape[shape].append(out_image)
    for shape in in_path_of_shape:
        print('Processing images of shape %s' % shape)
        ffwd(in_path_of_shape[shape], out_path_of_shape[shape],
             checkpoint_dir, device_t, batch_size)


def check_opts(opts):
    exists(opts.checkpoint_dir, 'Checkpoint not found!')
    exists(opts.in_path, 'In path not found!')
    if os.path.isdir(opts.out_path):
        exists(opts.out_path, 'out dir not found!')
    assert opts.batch_size > 0


def build_parser():
    parser = ArgumentParser()
    parser.add_argument('--checkpoint', type=str,
                        dest='checkpoint_dir',
                        help='dir or .ckpt file to load checkpoint from',
                        metavar='CHECKPOINT', required=True,
                        default='./model/la_muse.ckpt')

    parser.add_argument('--in-path', type=str,
                        dest='in_path', help='dir or file to transform',
                        metavar='IN_PATH', required=True,
                        default='./examples/content/stata.jpg')

    help_out = 'destination (dir or file) of transformed file or files'
    parser.add_argument('--out-path', type=str,
                        dest='out_path', help=help_out, metavar='OUT_PATH',
                        required=True, default='./')

    parser.add_argument('--device', type=str,
                        dest='device', help='device to perform compute on',
                        metavar='DEVICE', default=DEVICE)

    parser.add_argument('--batch-size', type=int,
                        dest='batch_size', help='batch size for feedforwarding',
                        metavar='BATCH_SIZE', default=BATCH_SIZE)

    parser.add_argument('--allow-different-dimensions', action='store_true',
                        dest='allow_different_dimensions',
                        help='allow different image dimensions')

    return parser


def main():
    parser = build_parser()
    opts = parser.parse_args()
    # Verify that the paths given in the arguments exist
    check_opts(opts)

    if not os.path.isdir(opts.in_path):
        if os.path.exists(opts.out_path) and os.path.isdir(opts.out_path):
            # Use the input image's file name as the output file name
            out_path = os.path.join(opts.out_path, os.path.basename(opts.in_path))
        else:
            out_path = opts.out_path

        ffwd_to_img(opts.in_path, out_path, opts.checkpoint_dir,
                    device=opts.device)
    else:
        files = list_files(opts.in_path)
        full_in = [os.path.join(opts.in_path, x) for x in files]
        full_out = [os.path.join(opts.out_path, x) for x in files]
        if opts.allow_different_dimensions:
            ffwd_different_dimensions(full_in, full_out, opts.checkpoint_dir,
                                      device_t=opts.device, batch_size=opts.batch_size)
        else:
            ffwd(full_in, full_out, opts.checkpoint_dir, device_t=opts.device,
                 batch_size=opts.batch_size)


if __name__ == '__main__':
    main()
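
The batching in `ffwd` can be hard to follow inside the TensorFlow session code: it runs as many full batches as fit, then calls itself with `batch_size=1` for whatever is left over. The sketch below reproduces only that splitting logic in plain Python so it can be checked in isolation. `split_batches` is a hypothetical helper written for this illustration and does not exist in the repository.

```python
def split_batches(paths, batch_size):
    """Return the batches ffwd would process, in order.

    Full batches come first; leftover items are emitted one per batch,
    which mirrors the recursive call with batch_size=1.
    """
    assert len(paths) > 0
    batch_size = min(len(paths), batch_size)
    num_iters = len(paths) // batch_size
    batches = [paths[i * batch_size:(i + 1) * batch_size]
               for i in range(num_iters)]
    remaining = paths[num_iters * batch_size:]
    batches.extend([p] for p in remaining)
    return batches


if __name__ == '__main__':
    files = ['img%d.jpg' % n for n in range(10)]
    for batch in split_batches(files, 4):
        print(batch)
```

With ten files and a batch size of four, this gives two batches of four followed by two batches of one. Handling the tail one image at a time is simple, but each recursive call builds a new graph and restores the checkpoint again, so a directory whose size is not a multiple of the batch size pays that cost once more.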

4. The neural network model

import tensorflow as tf, pdb

WEIGHTS_INIT_STDEV = .1


# Network architecture
def net(image):
    conv1 = _conv_layer(image, 32, 9, 1)
    conv2 = _conv_layer(conv1, 64, 3, 2)
    conv3 = _conv_layer(conv2, 128, 3, 2)
    # Residual blocks
    resid1 = _residual_block(conv3, 3)
    resid2 = _residual_block(resid1, 3)
    resid3 = _residual_block(resid2, 3)
    resid4 = _residual_block(resid3, 3)
    resid5 = _residual_block(resid4, 3)
    conv_t1 = _conv_tranpose_layer(resid5, 64, 3, 2)
    conv_t2 = _conv_tranpose_layer(conv_t1, 32, 3, 2)
    conv_t3 = _conv_layer(conv_t2, 3, 9, 1, relu=False)
    preds = tf.nn.tanh(conv_t3) * 150 + 255./2
    return preds


def _conv_layer(net, num_filters, filter_size, strides, relu=True):
    weights_init = _conv_init_vars(net, num_filters, filter_size)
    strides_shape = [1, strides, strides, 1]
    net = tf.nn.conv2d(net, weights_init, strides_shape, padding='SAME')
    net = _instance_norm(net)
    if relu:
        net = tf.nn.relu(net)

    return net


# Transposed convolution (upsampling)
def _conv_tranpose_layer(net, num_filters, filter_size, strides):
    weights_init = _conv_init_vars(net, num_filters, filter_size, transpose=True)  # True: transposed convolution

    batch_size, rows, cols, in_channels = [i.value for i in net.get_shape()]
    new_rows, new_cols = int(rows * strides), int(cols * strides)  # output size after upsampling
    # new_shape = #tf.pack([tf.shape(net)[0], new_rows, new_cols, num_filters])

    new_shape = [batch_size, new_rows, new_cols, num_filters]  # the new shape
    tf_shape = tf.stack(new_shape)
    strides_shape = [1, strides, strides, 1]

    net = tf.nn.conv2d_transpose(net, weights_init, tf_shape, strides_shape, padding='SAME')
    net = _instance_norm(net)
    return tf.nn.relu(net)


# Residual block
def _residual_block(net, filter_size=3):
    tmp = _conv_layer(net, 128, filter_size, 1)
    return net + _conv_layer(tmp, 128, filter_size, 1, relu=False)


# Instance normalization: statistics are taken per image over the spatial axes, not over the batch
def _instance_norm(net, train=True):
    batch, rows, cols, channels = [i.value for i in net.get_shape()]  # feature map dimensions
    var_shape = [channels]
    # Mean and variance of the current feature map
    mu, sigma_sq = tf.nn.moments(net, [1, 2], keep_dims=True)
    shift = tf.Variable(tf.zeros(var_shape))
    scale = tf.Variable(tf.ones(var_shape))
    epsilon = 1e-3
    normalized = (net - mu) / (sigma_sq + epsilon)**(.5)
    return scale * normalized + shift


def _conv_init_vars(net, out_channels, filter_size, transpose=False):
    _, rows, cols, in_channels = [i.value for i in net.get_shape()]
    if not transpose:
        weights_shape = [filter_size, filter_size, in_channels, out_channels]
    else:
        weights_shape = [filter_size, filter_size, out_channels, in_channels]   # transposed convolution

    weights_init = tf.Variable(tf.truncated_normal(weights_shape, stddev=WEIGHTS_INIT_STDEV, seed=1), dtype=tf.float32)
    return weights_init
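
Two details of the network are easy to check outside TensorFlow: what `_instance_norm` computes, and the value range produced by the final `tanh(...) * 150 + 255./2`. The NumPy sketch below mirrors both under simplifying assumptions: `scale` and `shift` are plain scalars here instead of learned per-channel variables, and `to_pixels` is a hypothetical name for the last line of `net`.

```python
import numpy as np


def instance_norm(x, scale=1.0, shift=0.0, epsilon=1e-3):
    """Normalize each image and channel over its spatial axes.

    x has shape (batch, rows, cols, channels). The mean and variance are
    taken over axes 1 and 2, so every image is normalized independently
    of the others in the batch.
    """
    mu = x.mean(axis=(1, 2), keepdims=True)
    sigma_sq = x.var(axis=(1, 2), keepdims=True)
    normalized = (x - mu) / np.sqrt(sigma_sq + epsilon)
    return scale * normalized + shift


def to_pixels(raw):
    """Map unbounded activations to the output range used by net()."""
    return np.tanh(raw) * 150 + 255. / 2


if __name__ == '__main__':
    rng = np.random.default_rng(0)
    x = rng.normal(5.0, 3.0, size=(2, 8, 8, 3))
    y = instance_norm(x)
    print(y.mean(axis=(1, 2)))   # close to 0 for every image and channel
    print(y.var(axis=(1, 2)))    # close to 1 for every image and channel
    print(to_pixels(np.array([-100.0, 0.0, 100.0])))
```

Since tanh lies between -1 and 1, the output range is -22.5 to 277.5, slightly wider than the valid pixel range. This is consistent with `from_pipe` clipping predictions to 0..255 before converting them to `uint8`.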

There is too much code to show all of it here. For the complete demo, see the GitHub repository:

https://github.com/Whq123/Style-transfer-of-picture
