本文主要介绍AIGC笔记——SVD中UNet加载预训练权重的方法，希望能为大家解决编程问题提供一定的参考价值，有需要的开发者可以跟随小编一起学习！
1--加载方式
根据权重文件的后缀名选择对应的加载方式：
1. 加载全参数权重（.ckpt 文件）
2. 加载 LoRA 权重（.safetensors 文件）
2--简单示例
"""Load pretrained weights into the SVD UNet (``UNetSpatioTemporalConditionModel``).

The checkpoint format is chosen by file extension:

* ``.ckpt`` -- a ``torch.load``-able dict whose ``"state_dict"`` entry holds the
  full UNet parameters (optionally with a ``"global_step"`` entry);
* ``.safetensors`` -- LoRA weights; a LoRA adapter is first attached to the UNet
  with ``unet.add_adapter`` and the stored tensors are then copied into it.
"""
import sys

# Make the project package ``svd`` importable (environment-specific path).
sys.path.append("/mnt/dolphinfs/hdd_pool/docker/user/hadoop-waimai-aigc/liujinfu/Codes/v3d-vgen-motion")

import torch
from peft import LoraConfig
from safetensors import safe_open

from svd.models.i2v_svd_unet import UNetSpatioTemporalConditionModel
from svd.utils.util import zero_rank_print


def _load_full_checkpoint(unet, checkpoint_path, resume_global_step=False):
    """Load a full-parameter ``.ckpt`` checkpoint into ``unet`` in place.

    Args:
        unet: the UNet whose parameters are overwritten.
        checkpoint_path: path of the checkpoint. Its ``"state_dict"`` entry
            (or the loaded dict itself when that entry is absent) is loaded
            with ``strict=False``.
        resume_global_step: when True and the checkpoint has a
            ``"global_step"`` entry, that value is returned.

    Returns:
        The resumed global step, or 0 when it is not resumed.

    Raises:
        RuntimeError: if the checkpoint contains keys the UNet does not have.
    """
    zero_rank_print(f"resume from checkpoint: {checkpoint_path}")
    # NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
    resume_checkpoint = torch.load(checkpoint_path, map_location="cpu")
    print(f'resume_checkpoint keys: {resume_checkpoint.keys()}')
    state_dict = resume_checkpoint.get("state_dict", resume_checkpoint)

    # strict=False: keys missing from the checkpoint are tolerated (only
    # reported); keys the UNet does not know about are treated as an error.
    missing, unexpected = unet.load_state_dict(state_dict, strict=False)
    zero_rank_print(f"dit missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
    if unexpected:
        raise RuntimeError(
            f"checkpoint has {len(unexpected)} unexpected keys, e.g. {list(unexpected)[:10]}"
        )

    global_step = 0
    if resume_global_step and "global_step" in resume_checkpoint:
        global_step = resume_checkpoint['global_step']
        zero_rank_print(f"resume global_step: {global_step}")
    return global_step


def _load_lora_weights(unet, checkpoint_path):
    """Attach a LoRA adapter to ``unet`` and fill it from a ``.safetensors`` file.

    Each stored key is mapped to a module name by removing ``"unet."`` and
    ``".weight"``. The tensor is then copied into the ``weight`` of the
    submodule ``<module name>.default``.

    Raises:
        RuntimeError: if any stored tensor could not be copied into the UNet
            (module not found, or shape mismatch).
    """
    unet_lora_config = LoraConfig(
        r=64,
        lora_alpha=64,  # scaling = lora_alpha / r
        init_lora_weights="gaussian",
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
        lora_dropout=0.1,
    )
    unet.add_adapter(unet_lora_config)
    zero_rank_print(f"resume from safetensors: {checkpoint_path}")

    # Collected over ALL keys so that a failure on any key is reported, not
    # just one on the last key.
    failed_keys = []
    # no_grad: the LoRA weights are leaf parameters, and in-place copies into
    # them must not be tracked by autograd.
    with safe_open(checkpoint_path, framework="pt", device="cpu") as f, torch.no_grad():
        for key in f.keys():
            module_name = key.replace('unet.', '').replace('.weight', '')
            tensor = f.get_tensor(key)
            try:
                # '.default' is presumably the name PEFT gives an adapter added
                # without an explicit adapter_name -- confirm if that changes.
                unet.get_submodule(module_name + '.default').weight.copy_(tensor)
            except (AttributeError, RuntimeError) as err:
                # AttributeError: module/weight not found; RuntimeError: shape mismatch.
                failed_keys.append((key, str(err)))

    if failed_keys:
        raise RuntimeError(
            f"resume unet params failed for {len(failed_keys)} keys, e.g. {failed_keys[:5]}"
        )


def main():
    """Build the UNet from the pretrained SVD model, then resume it from a checkpoint."""
    pretrained_model_path = "/mnt/dolphinfs/hdd_pool/docker/user/hadoop-waimai-aigc/liujinfu/Codes/svd_models/models/stable-video-diffusion-img2vid-xt"
    unet = UNetSpatioTemporalConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")

    # LoRA example -- use a .safetensors path instead to exercise the LoRA loader:
    # resume_checkpoint_path = "/mnt/dolphinfs/hdd_pool/docker/user/hadoop-waimai-aigc/liujinfu/Codes/v3d-vgen-motion/results/outputs_motionlora_realRota_0603_1024_stride5/test-0-2024-06-03T14-31-30/checkpoints/checkpoint-500.safetensors"
    resume_checkpoint_path = "/mnt/dolphinfs/hdd_pool/docker/user/hadoop-waimai-aigc/liujinfu/Codes/v3d-vgen-motion/results/outputs_motionFull_realRota_0529_stride5/test-0-2024-05-29T10-04-34/checkpoints/checkpoint-step-5000.ckpt"

    # Load pretrained unet weights; the loader is selected by file extension.
    if resume_checkpoint_path.endswith(".ckpt"):
        _load_full_checkpoint(unet, resume_checkpoint_path, resume_global_step=False)
    elif resume_checkpoint_path.endswith(".safetensors"):
        _load_lora_weights(unet, resume_checkpoint_path)
    else:
        raise ValueError(f"unsupported checkpoint format: {resume_checkpoint_path}")

    print("All Done!")


if __name__ == "__main__":
    main()
关于AIGC笔记——SVD中UNet加载预训练权重的文章就介绍到这里，希望我们推荐的文章对各位程序员有所帮助！