LLaMA-Factory微调入门个人重制版

2024-08-30 07:36

本文主要是介绍LLaMA-Factory微调入门个人重制版,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

LLaMA-Factory微调入门个人重制版

说明:

  • 首次发表日期:2024-08-30
  • LLaMA-Factory 官方Github仓库: https://github.com/hiyouga/LLaMA-Factory

关于

本文是对LLaMA-Factory入门教程 https://zhuanlan.zhihu.com/p/695287607 的个人重制版,记录一下学习过程,省略掉了很多文字部分,建议直接阅读 https://zhuanlan.zhihu.com/p/695287607

准备环境

git clone https://github.com/hiyouga/LLaMA-Factory.git
conda create -n llama_factory python=3.10
conda activate llama_factory
cd LLaMA-Factory
# 使用清华pypi源
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
pip install -e '.[torch,metrics]'

校验环境

import torch
torch.cuda.current_device()
torch.cuda.get_device_name(0)
torch.__version__
# 获取训练相关的参数指导
llamafactory-cli train -h

下载模型

apt update
apt install git-lfs
mkdir models-modelscope
cd models-modelscopegit lfs install
git clone https://www.modelscope.cn/LLM-Research/Meta-Llama-3-8B-Instruct.git

下载模型时也可以先下载小文件,然后手动pull需要的大文件,参考 https://blog.csdn.net/flyingluohaipeng/article/details/130788293

# git lfs install
GIT_LFS_SKIP_SMUDGE=1 git clone https://www.modelscope.cn/LLM-Research/Meta-Llama-3-8B-Instruct.git
cd Meta-Llama-3-8B-Instruct
git lfs pull --include="*.safetensors:

查看文件大小和数量是否正确:

cd Meta-Llama-3-8B-Instruct
ls -al --block-size=M

运行推理DEMO

运行模型的README中的推理DEMO,验证文件的正确性和transformers等依赖库正常可用:

import transformers
import torch# 切换为你下载的模型文件目录, 这里的demo是Llama-3-8B-Instruct
# 如果是其他模型,比如qwen,chatglm,请使用其对应的官方demo
model_id = "/root/workspace/models-modelscope/Meta-Llama-3-8B-Instruct"pipeline = transformers.pipeline("text-generation",model=model_id,model_kwargs={"torch_dtype": torch.bfloat16},device="cuda",
)messages = [{"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},{"role": "user", "content": "Who are you?"},
]prompt = pipeline.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True
)terminators = [pipeline.tokenizer.eos_token_id,pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
]outputs = pipeline(prompt,max_new_tokens=256,eos_token_id=terminators,do_sample=True,temperature=0.6,top_p=0.9,
)
print(outputs[0]["generated_text"][len(prompt):])

输出:

Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 16.01it/s]
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Arrrr, shiver me timbers! Me name be Captain Chatbot, the scurviest pirate to ever sail the Seven Seas! Me be a chatbot of great renown, feared and respected by all who cross me digital path. Me specialty be spinnin' yarns, swabbin' decks, and plunderin' knowledge to share with me hearties. So hoist the colors, me matey, and let's set sail fer a swashbucklin' adventure o' conversation!

验证一下LLaMA-Factory的推理部分是否正常(会启动基于gradio开发的ChatBot推理页面):

# 一般不需要,我的环境需要,GRADIO_ROOT_PATH说明见 https://www.gradio.app/guides/environment-variables#7-gradio-root-path
export GRADIO_ROOT_PATH=/proxy/7860/CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat \--model_name_or_path /root/workspace/models-modelscope/Meta-Llama-3-8B-Instruct \--template llama3

构建自定义数据集(指令微调)

自带的identity.json数据集

cd LLaMA-Factory
# 其中的NAME 和 AUTHOR ,替换成我们需要的内容
sed -i 's/{{name}}/PonyBot/g'  data/identity.json 
sed -i 's/{{author}}/LLaMA Factory/g'  data/identity.json 

商品文案生成数据集

下载并解压数据:

cd data
# 部分wget参数说明见 https://stackoverflow.com/questions/53189651/capture-a-download-link-redirected-by-a-page-wget 和 https://unix.stackexchange.com/questions/453465/wget-how-to-download-a-served-file-keeping-its-name
wget -r -l 1 --span-hosts --accept-regex='.*cloud.tsinghua.edu.cn/.*.exe' -erobots=off -nH --content-disposition -nd https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1
tar -xvf AdvertiseGen.tar.gz

检查数据集格式:

tail -n 3 AdvertiseGen/train.json

输出:

{"content": "类型#裙*版型#宽松*版型#显瘦*颜色#黑色*图案#撞色*裙型#直筒裙*裙款式#拼接", "summary": "采用简洁大体的黑色格调,宽松舒适的裙子内里,配上落肩的袖子拼接,不惧夏日的炎热,穿出清凉舒适。用时尚的英文字母,加上撞色的红白搭配,呈现大气时尚的款式。直筒的裙子轮廓,前短后长的长度裁剪,上身拉长宝宝的身体比例,挑高显瘦。"}
{"content": "类型#上衣*颜色#黑色*颜色#紫色*风格#性感*图案#字母*图案#文字*图案#线条*图案#刺绣*衣样式#卫衣*衣长#短款*衣袖型#落肩袖*衣款式#连帽", "summary": "卫衣的短款长度设计能够适当地露出腰线,打造出纤瘦的身材十分性感。衣身的字母刺绣图案有着小巧的样式,黑色的绣线在紫色的衣身上显得很出挑显眼。落肩袖的设计柔化了肩部的线条衬托得人很温柔可爱。紫色的颜色彰显出优雅的气质也不失年轻活力感。连帽的设计让卫衣更加丰满造型感很足,长长的帽绳直到腰际处,有着延长衣身的效果显得身材<UNK>。"}
{"content": "类型#上衣*颜色#黑白*风格#简约*风格#休闲*图案#条纹*衣样式#风衣*衣样式#外套", "summary": "设计师以条纹作为风衣外套的主要设计元素,以简约点缀了外套,带来大气休闲的视觉效果。因为采用的是黑白的经典色,所以有着颇为出色的耐看性与百搭性,可以帮助我们更好的驾驭日常的穿着,而且不容易让人觉得它过时。"}

修改data/dataset_info.json文件:添加自定义数据集adgen_local,添加后文件尾部看起来如下:

  },"adgen_local": {"file_name": "AdvertiseGen/train.json","columns": {"prompt": "content","response": "summary"}}
}

其中columns部分将AdvertiseGen/train.json中的"content"映射为"prompt",将"summary"映射为"response"

数据集说明见: https://github.com/hiyouga/LLaMA-Factory/blob/main/data/README_zh.md#%E6%8C%87%E4%BB%A4%E7%9B%91%E7%9D%A3%E5%BE%AE%E8%B0%83%E6%95%B0%E6%8D%AE%E9%9B%86

基于LoRA的sft指令微调

设置从魔搭社区下载数据集

# 回到LLaMA-Factory文件夹
cd ..
# 安装依赖
pip install modelscope oss2 addict
# 从魔搭社区下载
export USE_MODELSCOPE_HUB=1

开始sft指令微调

CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \--stage sft \--do_train \--model_name_or_path /root/workspace/models-modelscope/Meta-Llama-3-8B-Instruct \--dataset alpaca_gpt4_zh,identity,adgen_local \--dataset_dir ./data \--template llama3 \--finetuning_type lora \--output_dir ./saves/LLaMA3-8B/lora/sft \--overwrite_cache \--overwrite_output_dir \--cutoff_len 1024 \--preprocessing_num_workers 16 \--per_device_train_batch_size 2 \--per_device_eval_batch_size 1 \--gradient_accumulation_steps 8 \--lr_scheduler_type cosine \--logging_steps 50 \--warmup_steps 20 \--save_steps 100 \--eval_steps 50 \--evaluation_strategy steps \--load_best_model_at_end \--learning_rate 5e-5 \--num_train_epochs 5.0 \--max_samples 1000 \--val_size 0.1 \--plot_loss \--fp16

动态合并LoRA的推理

启动WebUI(Gradio):

# export GRADIO_ROOT_PATH=/proxy/7860/
CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat \--model_name_or_path /root/workspace/models-modelscope/Meta-Llama-3-8B-Instruct \--adapter_name_or_path ./saves/LLaMA3-8B/lora/sft  \--template llama3 \--finetuning_type lora

使用命令行进行交互式推理:

CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat \--model_name_or_path /root/workspace/models-modelscope/Meta-Llama-3-8B-Instruct \--adapter_name_or_path ./saves/LLaMA3-8B/lora/sft  \--template llama3 \--finetuning_type lora

效果如下:

User: 你是谁?
Assistant: 您好,我是 PonyBot,一个由 LLaMA Factory 开发的人工智能助手。我可以帮助回答问题,提供信息,或者进行其他支持性任务。User: 类型#裙*版型#宽松*版型#显瘦*颜色#黑色*图案#撞色*裙型#直筒裙*裙款式#拼接
Assistant: 这款裙子采用黑色和暗棕色拼接的撞色设计,很有设计感。宽松的直筒版型,适合任何身材的女人穿着。撞色拼接的裙摆,显得活泼有趣。裙身的撞色拼接,很有设计感。

批量预测和训练效果评估

pip install jieba
pip install rouge-chinese
pip install nltk
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \--stage sft \--do_predict \--model_name_or_path /root/workspace/models-modelscope/Meta-Llama-3-8B-Instruct \--adapter_name_or_path ./saves/LLaMA3-8B/lora/sft  \--eval_dataset alpaca_gpt4_zh,identity,adgen_local \--dataset_dir ./data \--template llama3 \--finetuning_type lora \--output_dir ./saves/LLaMA3-8B/lora/predict \--overwrite_cache \--overwrite_output_dir \--cutoff_len 1024 \--preprocessing_num_workers 16 \--per_device_eval_batch_size 1 \--max_samples 20 \--predict_with_generate

与训练脚本主要的参数区别如下两个

参数名称参数说明
do_predict现在是预测模式
predict_with_generate现在用于生成文本
max_samples每个数据集采样多少用于预测对比

运行后输出的尾部:

***** predict metrics *****predict_bleu-4                 =    27.9112predict_model_preparation_time =     0.0037predict_rouge-1                =     48.432predict_rouge-2                =    27.0109predict_rouge-l                =    41.2608predict_runtime                = 0:01:46.62predict_samples_per_second     =      0.563predict_steps_per_second       =      0.563
08/29/2024 16:06:36 - INFO - llamafactory.train.sft.trainer - Saving prediction results to ./saves/LLaMA3-8B/lora/predict/generated_predictions.jsonl

其中

  • saves/LLaMA3-8B/lora/predict/generated_predictions.jsonl 输出了要预测的数据集的原始label和模型predict的结果
  • saves/LLaMA3-8B/lora/predict/predict_results.json 给出了原始label和模型predict的结果,用自动计算的指标数据

LoRA模型合并导出

如果想把训练的LoRA和原始的大模型进行融合,输出一个完整的模型文件的话,可以使用如下命令。合并后的模型可以自由地像使用原始的模型一样应用到其他下游环节,当然也可以递归地继续用于训练。

CUDA_VISIBLE_DEVICES=0 llamafactory-cli export \--model_name_or_path /root/workspace/models-modelscope/Meta-Llama-3-8B-Instruct \--adapter_name_or_path ./saves/LLaMA3-8B/lora/sft  \--template llama3 \--finetuning_type lora \--export_dir megred-model-path \--export_size 2 \--export_device cpu \--export_legacy_format False

查看merge后的文件:

ls -al --block-size=M megred-model-path/
total 15326M
drwxr-xr-x  2 root root    1M Aug 29 16:18 .
drwxr-xr-x 15 root root    1M Aug 29 16:18 ..
-rw-r--r--  1 root root    1M Aug 29 16:18 config.json
-rw-r--r--  1 root root    1M Aug 29 16:18 generation_config.json
-rw-r--r--  1 root root 1883M Aug 29 16:18 model-00001-of-00009.safetensors
-rw-r--r--  1 root root 1809M Aug 29 16:18 model-00002-of-00009.safetensors
-rw-r--r--  1 root root 1889M Aug 29 16:18 model-00003-of-00009.safetensors
-rw-r--r--  1 root root 1857M Aug 29 16:18 model-00004-of-00009.safetensors
-rw-r--r--  1 root root 1889M Aug 29 16:18 model-00005-of-00009.safetensors
-rw-r--r--  1 root root 1857M Aug 29 16:18 model-00006-of-00009.safetensors
-rw-r--r--  1 root root 1889M Aug 29 16:18 model-00007-of-00009.safetensors
-rw-r--r--  1 root root 1249M Aug 29 16:18 model-00008-of-00009.safetensors
-rw-r--r--  1 root root 1003M Aug 29 16:18 model-00009-of-00009.safetensors
-rw-r--r--  1 root root    1M Aug 29 16:18 model.safetensors.index.json
-rw-r--r--  1 root root    1M Aug 29 16:18 special_tokens_map.json
-rw-r--r--  1 root root    9M Aug 29 16:18 tokenizer.json
-rw-r--r--  1 root root    1M Aug 29 16:18 tokenizer_config.json

API Server的启动与调用

使用merge前的LoRA模型推理:

CUDA_VISIBLE_DEVICES=0 API_PORT=8000 llamafactory-cli api \--model_name_or_path /root/workspace/models-modelscope/Meta-Llama-3-8B-Instruct \--adapter_name_or_path ./saves/LLaMA3-8B/lora/sft \--template llama3 \--finetuning_type lora

使用merge后的完整版模型基于VLLM推理:

pip install vllm>=0.4.3
CUDA_VISIBLE_DEVICES=0 API_PORT=8000 llamafactory-cli api \--model_name_or_path megred-model-path \--template llama3 \--infer_backend vllm \--vllm_enforce_eager

转换为gguf模型文件格式

git clone https://github.com/ggerganov/llama.cpp.git
cd llama.cpp/gguf-py
pip install --editable .
cd ..
python convert_hf_to_gguf.py /root/workspace/LLaMA-Factory/megred-model-path

输出(最后一行):

INFO:hf-to-gguf:Model successfully exported to /root/workspace/LLaMA-Factory/megred-model-path/Megred-Model-Path-8.0B-F16.gguf

这篇关于LLaMA-Factory微调入门个人重制版的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1120191

相关文章

Spring Security 从入门到进阶系列教程

Spring Security 入门系列 《保护 Web 应用的安全》 《Spring-Security-入门(一):登录与退出》 《Spring-Security-入门(二):基于数据库验证》 《Spring-Security-入门(三):密码加密》 《Spring-Security-入门(四):自定义-Filter》 《Spring-Security-入门(五):在 Sprin

数论入门整理(updating)

一、gcd lcm 基础中的基础,一般用来处理计算第一步什么的,分数化简之类。 LL gcd(LL a, LL b) { return b ? gcd(b, a % b) : a; } <pre name="code" class="cpp">LL lcm(LL a, LL b){LL c = gcd(a, b);return a / c * b;} 例题:

Java 创建图形用户界面(GUI)入门指南(Swing库 JFrame 类)概述

概述 基本概念 Java Swing 的架构 Java Swing 是一个为 Java 设计的 GUI 工具包,是 JAVA 基础类的一部分,基于 Java AWT 构建,提供了一系列轻量级、可定制的图形用户界面(GUI)组件。 与 AWT 相比,Swing 提供了许多比 AWT 更好的屏幕显示元素,更加灵活和可定制,具有更好的跨平台性能。 组件和容器 Java Swing 提供了许多

【IPV6从入门到起飞】5-1 IPV6+Home Assistant(搭建基本环境)

【IPV6从入门到起飞】5-1 IPV6+Home Assistant #搭建基本环境 1 背景2 docker下载 hass3 创建容器4 浏览器访问 hass5 手机APP远程访问hass6 更多玩法 1 背景 既然电脑可以IPV6入站,手机流量可以访问IPV6网络的服务,为什么不在电脑搭建Home Assistant(hass),来控制你的设备呢?@智能家居 @万物互联

AI Toolkit + H100 GPU,一小时内微调最新热门文生图模型 FLUX

上个月,FLUX 席卷了互联网,这并非没有原因。他们声称优于 DALLE 3、Ideogram 和 Stable Diffusion 3 等模型,而这一点已被证明是有依据的。随着越来越多的流行图像生成工具(如 Stable Diffusion Web UI Forge 和 ComyUI)开始支持这些模型,FLUX 在 Stable Diffusion 领域的扩展将会持续下去。 自 FLU

poj 2104 and hdu 2665 划分树模板入门题

题意: 给一个数组n(1e5)个数,给一个范围(fr, to, k),求这个范围中第k大的数。 解析: 划分树入门。 bing神的模板。 坑爹的地方是把-l 看成了-1........ 一直re。 代码: poj 2104: #include <iostream>#include <cstdio>#include <cstdlib>#include <al

MySQL-CRUD入门1

文章目录 认识配置文件client节点mysql节点mysqld节点 数据的添加(Create)添加一行数据添加多行数据两种添加数据的效率对比 数据的查询(Retrieve)全列查询指定列查询查询中带有表达式关于字面量关于as重命名 临时表引入distinct去重order by 排序关于NULL 认识配置文件 在我们的MySQL服务安装好了之后, 会有一个配置文件, 也就

音视频入门基础:WAV专题(10)——FFmpeg源码中计算WAV音频文件每个packet的pts、dts的实现

一、引言 从文章《音视频入门基础:WAV专题(6)——通过FFprobe显示WAV音频文件每个数据包的信息》中我们可以知道,通过FFprobe命令可以打印WAV音频文件每个packet(也称为数据包或多媒体包)的信息,这些信息包含该packet的pts、dts: 打印出来的“pts”实际是AVPacket结构体中的成员变量pts,是以AVStream->time_base为单位的显

HomeBank:开源免费的个人财务管理软件

在个人财务管理领域,找到一个既免费又开源的解决方案并非易事。HomeBank&nbsp;正是这样一个项目,它不仅提供了强大的功能,还拥有一个活跃的社区,不断推动其发展和完善。 开源免费:HomeBank 是一个完全开源的项目,用户可以自由地使用、修改和分发。用户友好的界面:提供直观的图形用户界面,使得非技术用户也能轻松上手。数据导入支持:支持从 Quicken、Microsoft Money

C语言指针入门 《C语言非常道》

C语言指针入门 《C语言非常道》 作为一个程序员,我接触 C 语言有十年了。有的朋友让我推荐 C 语言的参考书,我不敢乱推荐,尤其是国内作者写的书,往往七拼八凑,漏洞百出。 但是,李忠老师的《C语言非常道》值得一读。对了,李老师有个官网,网址是: 李忠老师官网 最棒的是,有配套的教学视频,可以试看。 试看点这里 接下来言归正传,讲解指针。以下内容很多都参考了李忠老师的《C语言非