[跑代码]BK-SDM: A Lightweight, Fast, and Cheap Version of Stable Diffusion

2023-11-30 22:12

本文主要是介绍[跑代码]BK-SDM: A Lightweight, Fast, and Cheap Version of Stable Diffusion,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Installation(下载代码-装环境)

conda create -n bk-sdm python=3.8
conda activate bk-sdm
git clone https://github.com/Nota-NetsPresso/BK-SDM.git
cd BK-SDM
pip install -r requirements.txt
Note on the torch versions we've used
  • torch 1.13.1 for MS-COCO evaluation & DreamBooth finetuning on a single 24GB RTX3090
     

  • torch 2.0.1 for KD pretraining on a single 80GB A10
    火炬2.0.1在单个80GB A100上进行KD预训练

    • 如果A100上总批大小为256的预训练导致gpu内存不足,请检查torch版本并考虑升级到torch>2.0.0。
      我的版本也是torch2.0.1 单个A100(80G)理论上吃的下256batch

小的例子

PNDM采样器 50步去噪声

等效代码(仅修改SD-v1.4的U-Net,同时保留其文本编码器和图像解码器):

Distillation Pretraining

Our code was based on train_text_to_image.py of Diffusers 0.15.0.dev0. To access the latest version, use this link.
BK-SDM的diffusers版本0.15
我的diffusers版本比较高0.24.0

检测是否能够训练(先下载数据集get_laion_data.sh再运行代码kd_train_toy.sh)

1 一个玩具数据集(11K的img-txt对)下载到。

bash scripts/get_laion_data.sh preprocessed_11k

/data/laion_aes/preprocessed_11k (1.7GB in tar.gz;1.8GB数据文件夹)。
get_laion_data.sh

需要修改,实际就是下载这三个数据集,我自行下载

# https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/preprocessed_11k.tar.gz
# https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/preprocessed_212k.tar.gz
# https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/preprocessed_2256k.tar.gz

我修改后下载文件名 https://... .../preprocessed_11k.tar.gz直接粘贴到网址里面也可以下载
wget $S3_URL -0 $FILe_PATH
$S3_URL 就是这个网址
$FILe_PATH 就是下载路径./data/laion_aes/preprocessed_11k

DATA_TYPE=$"preprocessed_11k"  # {preprocessed_11k, preprocessed_212k, preprocessed_2256k}
FILE_NAME="${DATA_TYPE}.tar.gz"DATA_DIR="./data/laion_aes/"
FILE_UNZIP_DIR="${DATA_DIR}${DATA_TYPE}"
FILE_PATH="${DATA_DIR}${FILE_NAME}"if [ "$DATA_TYPE" = "preprocessed_11k" ] || [ "$DATA_TYPE" = "preprocessed_212k" ]; thenecho "-> preprocessed_11k or 212k"S3_URL="https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/${FILE_NAME}"
elif [ "$DATA_TYPE" = "preprocessed_2256k" ]; thenS3_URL="https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.25plus/${FILE_NAME}"
elseecho "Something wrong in data folder name"exit
fiwget $S3_URL -O $FILE_PATH
tar -xvzf $FILE_PATH -C $DATA_DIR
echo "downloaded to ${FILE_UNZIP_DIR}"

2 一个小脚本可以用来验证代码的可执行性,并找到与你的GPU匹配的批处理大小。
批量大小为8 (=4×2),训练BK-SDM-Base 20次迭代大约需要5分钟和22GB的GPU内存。

bash scripts/kd_train_toy.sh
MODEL_NAME="CompVis/stable-diffusion-v1-4"
TRAIN_DATA_DIR="./data/laion_aes/preprocessed_11k" # please adjust it if needed
UNET_CONFIG_PATH="./src/unet_config"UNET_NAME="bk_small" # option: ["bk_base", "bk_small", "bk_tiny"]
OUTPUT_DIR="./results/toy_"$UNET_NAME # please adjust it if neededBATCH_SIZE=2
GRAD_ACCUMULATION=4StartTime=$(date +%s)CUDA_VISIBLE_DEVICES=1 accelerate launch src/kd_train_text_to_image.py \--pretrained_model_name_or_path $MODEL_NAME \--train_data_dir $TRAIN_DATA_DIR\--use_ema \--resolution 512 --center_crop --random_flip \--train_batch_size $BATCH_SIZE \--gradient_checkpointing \--mixed_precision="fp16" \--learning_rate 5e-05 \--max_grad_norm 1 \--lr_scheduler="constant" --lr_warmup_steps=0 \--report_to="all" \--max_train_steps=20 \--seed 1234 \--gradient_accumulation_steps $GRAD_ACCUMULATION \--checkpointing_steps 5 \--valid_steps 5 \--lambda_sd 1.0 --lambda_kd_output 1.0 --lambda_kd_feat 1.0 \--use_copy_weight_from_teacher \--unet_config_path $UNET_CONFIG_PATH --unet_config_name $UNET_NAME \--output_dir $OUTPUT_DIREndTime=$(date +%s)
echo "** KD training takes $(($EndTime - $StartTime)) seconds."

单GPU训练BK-SDM{Base, Small, Tiny}-0.22M数据训练
 

bash scripts/get_laion_data.sh preprocessed_212k
bash scripts/kd_train.sh

1 下载数据集preprocessed_212k
2 训练kd_train.sh
(256batch 训练BD-SM-Base 50K轮次需要300hours/53G单卡)
(64batch 训练BD-SM-Base 50K轮次需要60hours/28G单卡) 不理解?
 

单GPU训练BK-SDM{Base, Small, Tiny}-2.3M数据训练

这篇关于[跑代码]BK-SDM: A Lightweight, Fast, and Cheap Version of Stable Diffusion的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/438740

相关文章

uniapp接入微信小程序原生代码配置方案(优化版)

uniapp项目需要把微信小程序原生语法的功能代码嵌套过来,无需把原生代码转换为uniapp,可以配置拷贝的方式集成过来 1、拷贝代码包到src目录 2、vue.config.js中配置原生代码包直接拷贝到编译目录中 3、pages.json中配置分包目录,原生入口组件的路径 4、manifest.json中配置分包,使用原生组件 5、需要把原生代码包里的页面修改成组件的方

公共筛选组件(二次封装antd)支持代码提示

如果项目是基于antd组件库为基础搭建,可使用此公共筛选组件 使用到的库 npm i antdnpm i lodash-esnpm i @types/lodash-es -D /components/CommonSearch index.tsx import React from 'react';import { Button, Card, Form } from 'antd'

17.用300行代码手写初体验Spring V1.0版本

1.1.课程目标 1、了解看源码最有效的方式,先猜测后验证,不要一开始就去调试代码。 2、浓缩就是精华,用 300行最简洁的代码 提炼Spring的基本设计思想。 3、掌握Spring框架的基本脉络。 1.2.内容定位 1、 具有1年以上的SpringMVC使用经验。 2、 希望深入了解Spring源码的人群,对 Spring有一个整体的宏观感受。 3、 全程手写实现SpringM

代码随想录算法训练营:12/60

非科班学习算法day12 | LeetCode150:逆波兰表达式 ,Leetcode239: 滑动窗口最大值  目录 介绍 一、基础概念补充: 1.c++字符串转为数字 1. std::stoi, std::stol, std::stoll, std::stoul, std::stoull(最常用) 2. std::stringstream 3. std::atoi, std

记录AS混淆代码模板

开启混淆得先在build.gradle文件中把 minifyEnabled false改成true,以及shrinkResources true//去除无用的resource文件 这些是写在proguard-rules.pro文件内的 指定代码的压缩级别 -optimizationpasses 5 包明不混合大小写 -dontusemixedcaseclassnames 不去忽略非公共

麻了!一觉醒来,代码全挂了。。

作为⼀名程序员,相信大家平时都有代码托管的需求。 相信有不少同学或者团队都习惯把自己的代码托管到GitHub平台上。 但是GitHub大家知道,经常在访问速度这方面并不是很快,有时候因为网络问题甚至根本连网站都打不开了,所以导致使用体验并不友好。 经常一觉醒来,居然发现我竟然看不到我自己上传的代码了。。 那在国内,除了GitHub,另外还有一个比较常用的Gitee平台也可以用于

众所周知,配置即代码≠基础设置即代码

​前段时间翻到几条留言,问: “配置即代码和基础设施即代码一样吗?” “配置即代码是什么?怎么都是基础设施即代码?” 我们都是知道,DevOp的快速发展,让服务器管理与配置的时间大大减少,配置即代码和基础设施即代码作为DevOps的重要实践,在其中起到了关键性作用。 不少人将二者看作是一件事,配置即大代码是关于管理特定的应用程序配置设置本身,而基础设施即代码更关注的是部署支持应用程序环境所需的

53、Flink Interval Join 代码示例

1、概述 interval Join 默认会根据 keyBy 的条件进行 Join 此时为 Inner Join; interval Join 算子的水位线会取两条流中水位线的最小值; interval Join 迟到数据的判定是以 interval Join 算子的水位线为基准; interval Join 可以分别输出两条流中迟到的数据-[sideOutputLeftLateData,

Git代码管理的常用操作

在VS022中,Git的管理要先建立本地或远程仓库,然后commit到本地,最后push到远程代码库。 或者不建立本地的情况,直接拉取已有的远程代码。 Git是一个分布式版本控制系统,用于跟踪和管理文件的变化。它可以记录文件的修改历史,并且可以轻松地回滚到任何历史版本。 Git的基本概念包括: 仓库(Repository):Git使用仓库来存储文件的版本历史。一个仓库可以包含多个文件

HTML文档插入JS代码的几种方法

在HTML文档里嵌入客户端JavaScript代码有4中方法: 1.内联,放置在< script>和标签对之间。 2.放置在由< script>标签的src属性指定的外部文件中。 3.放置在HTML事件处理程序中,该事件处理程序由onclick或onmouseover这样的HTML属性值指定。 4.放在一个URL里,这个URL使用特殊的“javascript:”协议。 在JS编程中,主张