LLM系列 | 36:Google最新开源大模型:Gemma 2介绍及其微调(下篇)

2024-09-01 17:12

本文主要是介绍LLM系列 | 36:Google最新开源大模型:Gemma 2介绍及其微调(下篇),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  • 引言

  • 环境安装

  • 数据准备

    • 下载

    • 处理

  • 模型训练

  • 模型inference

  • 结果

    • gemma-2-9b

    • gemma-2-9b-it

引言

低头观落日,引手摘飞星。

小伙伴们好,我是微信公众号《小窗幽记机器学习》的小编:卖黑神话的小女孩。本文紧接前文Google最新开源大语言模型:Gemma 2介绍及其微调(上篇),介绍如何用中文语料微调Gemma 2模型。如想与小编进一步交流,欢迎在《小窗幽记机器学习》上获取小编微信号,或者直接添加小编的wx号:

环境安装

pip3 install -U torch transformers trl peft bitsandbytes tf-keras -i https://mirrors.cloud.tencent.com/pypi/simple
pip3 install tf-keras -i https://mirrors.cloud.tencent.com/pypi/simple

数据准备

下载

这里使用Hello-SimpleAI/HC3-Chinese数据集进行微调。预先下载:

huggingface-cli download --resume-download --repo-type dataset --local-dir-use-symlinks False Hello-SimpleAI/HC3-Chinese --local-dir /share_data_zoo/LLM/Hello-SimpleAI/HC3-Chinese/

处理

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2024/6/29 16:25
# @Author  : 卖黑神话的小女孩
# @File    : fine_tuning_data_preprocess.py
"""
预处理:划分训练集和测试集
"""
import os
import pdbfrom datasets import load_dataset# Convert dataset to OAI messages
system_message = """你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题
"""
data_dir = "/share_data_zoo/LLM/"
data_id = "Hello-SimpleAI/HC3-Chinese"
data_name = data_id.split('/')[-1]
print("data_name=", data_name)
# pdb.set_trace()
data_path = os.path.join(data_dir, data_id)"""
conversational format
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}instruction format
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
"""def create_conversation(sample):return {"messages": [{"role": "system", "content": system_message},{"role": "user", "content": sample["question"]},{"role": "assistant", "content": sample["human_answers"][0]}# for whatever reason the dataset uses a list of answers]}if __name__ == "__main__":# Load dataset from the hubdataset_dict = load_dataset("json", data_files=f"{data_path}/baike.jsonl")# 由于只有一个文件,我们将其视为训练集) split="train"dataset = dataset_dict['train']print(create_conversation(dataset[0]))# # Convert dataset to OAI messagesdataset = dataset.map(create_conversation, remove_columns=["chatgpt_answers"], batched=False)# # split dataset into 10,000 training samples and 2,500 test samples# dataset = dataset.train_test_split(test_size=4500/4616)  # baike splitdataset = dataset.train_test_split(test_size=0.1)# save datasets to diskdataset["train"].to_json("train_dataset.json", orient="records")dataset["test"].to_json("test_dataset.json", orient="records")print("Save to disk success")

模型训练

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2024/6/29 14:53
# @Author  : 卖黑神话的小女孩
# @File    : fine_tuning_gemma.py
"""
安装依赖:pip3 install -U torch transformers trl peft bitsandbytes tf-keras -i https://mirrors.cloud.tencent.com/pypi/simplepip3 install tf-keras -i https://mirrors.cloud.tencent.com/pypi/simple准备数据:运行 fine_tuning_data_preprocess.py 脚本开始训练:运行 fine_tuning_gemma.py 脚本在脚本的末尾会将lora和原始模型进行merge开始inference:运行 fine_tuning_gemma_inference.py 脚本如果报错:ImportError: /usr/local/lib/python3.10/dist-packages/transformer_engine_extensions.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops5zeros4callEN3c108ArrayRefINS2_6SymIntEEENS2_8optionalINS2_10ScalarTypeEEENS6_INS2_6LayoutEEENS6_INS2_6DeviceEEENS6_IbEEpip3 uninstall transformer-engine 即可
"""
import os
import pdb
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from trl import setup_chat_format
from datasets import load_dataset
from peft import LoraConfig
from transformers import TrainingArguments
from trl import SFTTrainer
from peft import AutoPeftModelForCausalLM
from fine_tuning_data_preprocess import data_name as train_data_nameinit_model_dir = "/share_model_zoo/LLM/"
# init_model_id = "google/gemma-2-9b"
init_model_id = "google/gemma-2-9b-it"
init_model_path = os.path.join(init_model_dir, init_model_id)
res_dir = "../result_models"
result_model_dir = os.path.join(res_dir, init_model_id, train_data_name)
print("result_model_dir=", result_model_dir)
# 检查路径是否已存在
if not os.path.exists(result_model_dir):# 递归创建目录os.makedirs(result_model_dir)print("目录已创建:", result_model_dir)
else:print("目录已存在:", result_model_dir)# Convert dataset to OAI messages
system_message = """你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题
""""""
conversational format
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}instruction format
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
"""def create_conversation(sample):return {"messages": [{"role": "system", "content": system_message},{"role": "user", "content": sample["question"]},{"role": "assistant", "content": sample["human_answers"][0]}# for whatever reason the dataset uses a list of answers]}# Load jsonl data from disk
dataset = load_dataset("json", data_files="train_dataset.json", split="train")# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(init_model_path,device_map="auto",# attn_implementation="flash_attention_2",torch_dtype=torch.bfloat16,quantization_config=bnb_config
)
tokenizer = AutoTokenizer.from_pretrained(init_model_path)
tokenizer.padding_side = 'right'  # to prevent warnings# # set chat template to OAI chatML, remove if you start from a fine-tuned model
model, tokenizer = setup_chat_format(model, tokenizer)# LoRA config based on QLoRA paper & Sebastian Raschka experiment
peft_config = LoraConfig(lora_alpha=128,lora_dropout=0.05,r=256,bias="none",target_modules="all-linear",task_type="CAUSAL_LM",
)args = TrainingArguments(output_dir=result_model_dir,  # directory to save and repository idnum_train_epochs=3,  # number of training epochsper_device_train_batch_size=1,  # batch size per device during traininggradient_accumulation_steps=2,  # number of steps before performing a backward/update passgradient_checkpointing=True,  # use gradient checkpointing to save memoryoptim="adamw_torch_fused",  # use fused adamw optimizerlogging_steps=10,  # log every 10 stepssave_strategy="epoch",  # save checkpoint every epochlearning_rate=2e-4,  # learning rate, based on QLoRA paper# bf16=True,                              # use bfloat16 precision if you have supported GPU# tf32=True,                              # use tf32 precision if you have supported GPUmax_grad_norm=0.3,  # max gradient norm based on QLoRA paperwarmup_ratio=0.03,  # warmup ratio based on QLoRA paperlr_scheduler_type="constant",  # use constant learning rate schedulerpush_to_hub=False,  # push model to hubreport_to="tensorboard",  # report metrics to tensorboard
)max_seq_length = 1024  # max sequence length for model and packing of the datasettrainer = SFTTrainer(model=model,args=args,train_dataset=dataset,peft_config=peft_config,max_seq_length=max_seq_length,tokenizer=tokenizer,packing=True,dataset_kwargs={"add_special_tokens": False,  # We template with special tokens"append_concat_token": False,  # No need to add additional separator token}
)# start training, the model will be automatically saved to the hub and the output directory
trainer.train()# save model
trainer.save_model()
print("Save model success")
### COMMENT IN TO MERGE PEFT AND BASE MODEL ##### Load PEFT model on CPU
model = AutoPeftModelForCausalLM.from_pretrained(args.output_dir,torch_dtype=torch.float16,low_cpu_mem_usage=True,
)
# Merge LoRA and base model and save
merged_model = model.merge_and_unload()
merged_model.save_pretrained(args.output_dir, safe_serialization=True, max_shard_size="2GB")
print(f"Save merged_model to {args.output_dir} success")

模型inference

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2024/6/29 16:51
# @Author  : 卖黑神话的小女孩
# @File    : fine_tuning_gemma_inference.py
"""
transformers
"""
import os
import pdb
import time
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, pipeline
from datasets import load_dataset
from fine_tuning_data_preprocess import data_name as train_data_name
from random import randintinit_model_dir = "/share_model_zoo/LLM/"
init_model_id = "google/gemma-2-9b"
# init_model_id = "google/gemma-2-9b-it"
init_model_path = os.path.join(init_model_dir, init_model_id)
res_dir = "../result_models"
result_model_dir = os.path.join(res_dir, init_model_id, train_data_name)peft_model_id = result_model_dir# Load Model with PEFT adapter
start_time = time.time()
model = AutoPeftModelForCausalLM.from_pretrained(peft_model_id,device_map="auto",torch_dtype=torch.float16
)
print(f"Load peft model={peft_model_id} success")
end_time = time.time()
model_load_cost = round(end_time - start_time, 2)
print(f"model load cost={model_load_cost}")tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
# load into pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)# Test on sample
rand_idx = 2
eval_dataset = load_dataset("json", data_files="test_dataset.json", split="train")
test_texts = eval_dataset[rand_idx]["messages"][:2]
# pdb.set_trace()
# 调用方法1:
prompt = pipe.tokenizer.apply_chat_template(eval_dataset[rand_idx]["messages"][:2], tokenize=False,add_generation_prompt=True)
outputs = pipe(prompt, repetition_penalty=1.3, max_new_tokens=256, do_sample=False, temperature=0.1, top_k=50,top_p=0.1, eos_token_id=pipe.tokenizer.eos_token_id, pad_token_id=pipe.tokenizer.pad_token_id)# # 调用方法2:
# messages = [
#     {"role": "user", "content": "你是谁?"},
# ]
# messages_outputs = pipe(
#     messages,
#     repetition_penalty=1.3,
#     max_new_tokens=256,
#     do_sample=False,
# )
#
# assistant_response = messages_outputs[0]["generated_text"][-1]["content"]
# print("assistant_response=\n", assistant_response)print(f"Query:\n{eval_dataset[rand_idx]['messages'][1]['content']}")
print(f"Original Answer:\n{eval_dataset[rand_idx]['messages'][2]['content']}")
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")

结果

gemma-2-9b

未微调结果

Query:
你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题
我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS
Generated Answer:
? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS

微调结果

Query:
我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS
Original Answer:
RouterOS是一种路由操作系统,是基于Linux核心开发,兼容x86 PC的路由软件,并通过该软件将标准的PC电脑变成专业路由器,在软件RouterOS 软路由图的开发和应用上不断的更新和发展,软件经历了多次更新和改进,使其功能在不断增强和完善。特别在无线、认证、策略路由、带宽控制和防火墙过滤等功能上有着非常突出的功能,其极高的性价比,受到许多网络人士的青睐。
Generated Answer:
RouterOS是采用先进的网络协议和算法(例如RIP、OSPF、BGP等)进行路由管理与控制以及负载均衡的一种类Unix计算机操作系统。它是在1996年由MikroTik公司开发并发布的第一个版本为2.0而设计的用于多种平台上的高级互联网网关软件包或系统。 它的主要目标是对小型办公室和家庭用户的无线局域网提供出色的性能以改善数据传输速率和其他关键指标;同时最大限度地降低成本并在设计中考虑安装复杂性及可扩充性的需求点。 它在全球拥有超过35,000个活跃的用户群并且 Mikrotik 是世界领先且最可靠的小型企业边缘联网设备供应商 。他们已成功地在全世界销售了超过4百万的产品 ,产品覆盖范围从低端到高端商业办公大楼或是 ISP 的核心机房都适用。他们的客户遍布于几乎所有可以上网的地方而且很多国家都有其代表处或者分销商 ;由于产品的易用性和高性价比 ,使得我们的产品受到许多新兴市场的青睐比如:俄罗斯 、印度 和中国等等国家的市场正在蓬勃发展 !我们确信这些还不是它们的极限!随着科技的发展 ,Internet 将会

gemma-2-9b-it

未微调结果

Query:
你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题
我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS
Generated Answer:
?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的

微调结果

Query:
我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS
Original Answer:
RouterOS是一种路由操作系统,是基于Linux核心开发,兼容x86 PC的路由软件,并通过该软件将标准的PC电脑变成专业路由器,在软件RouterOS 软路由图的开发和应用上不断的更新和发展,软件经历了多次更新和改进,使其功能在不断增强和完善。特别在无线、认证、策略路由、带宽控制和防火墙过滤等功能上有着非常突出的功能,其极高的性价比,受到许多网络人士的青睐。
Generated Answer:
RouterOS是由 Latvian Information Technologies Association(丽顿信息技术协会)开发的网络协议栈和路由器软件。它被认为是基于IPv4/IPSec、MPLS及其他高速数据传输协议的高速分组交换与包处理实现的核心;也是Internet骨干网建设的重要设备之一。
其核心应用为:防火墙服务 (NAT / IPsec)、高速度互联网接入服务器 、安全 VPN 等业务功能 。此外, 它还具有丰富的语音压缩算法等线路侧特性.因此在电信固定无线通信方面也发挥着重要作用。由于采用先进的数据转发引擎架构设计使其具备很强的扩展性 ,所以routeros体系能够兼容多种处理器结构,从x86到ARM9E-S 等等 । routeros本身提供很多高级的技术功能面,但并没有进行深入的研究工作,因为它的开发者希望把产品的源代码开放给公众以便共同改进产品性能.随着多核CPU技术的盛行以及云计算理论产生的兴起 ,许多新兴的公司都加入到了这个行业并且利用了开源软件作为基础做出了自己的创新产 品.这促进了计算机硬件的发展 和社会资源的合理利用$.当然对于一些大公司来说他们拥有足够的研发能力可以自行

这篇关于LLM系列 | 36:Google最新开源大模型:Gemma 2介绍及其微调(下篇)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1127554

相关文章

0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型的操作流程

《0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeekR1模型的操作流程》DeepSeekR1模型凭借其强大的自然语言处理能力,在未来具有广阔的应用前景,有望在多个领域发... 目录0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型,3步搞定一个应

PyCharm 接入 DeepSeek最新完整教程

《PyCharm接入DeepSeek最新完整教程》文章介绍了DeepSeek-V3模型的性能提升以及如何在PyCharm中接入和使用DeepSeek进行代码开发,本文通过图文并茂的形式给大家介绍的... 目录DeepSeek-V3效果演示创建API Key在PyCharm中下载Continue插件配置Con

Deepseek R1模型本地化部署+API接口调用详细教程(释放AI生产力)

《DeepseekR1模型本地化部署+API接口调用详细教程(释放AI生产力)》本文介绍了本地部署DeepSeekR1模型和通过API调用将其集成到VSCode中的过程,作者详细步骤展示了如何下载和... 目录前言一、deepseek R1模型与chatGPT o1系列模型对比二、本地部署步骤1.安装oll

Spring AI Alibaba接入大模型时的依赖问题小结

《SpringAIAlibaba接入大模型时的依赖问题小结》文章介绍了如何在pom.xml文件中配置SpringAIAlibaba依赖,并提供了一个示例pom.xml文件,同时,建议将Maven仓... 目录(一)pom.XML文件:(二)application.yml配置文件(一)pom.xml文件:首

如何在本地部署 DeepSeek Janus Pro 文生图大模型

《如何在本地部署DeepSeekJanusPro文生图大模型》DeepSeekJanusPro模型在本地成功部署,支持图片理解和文生图功能,通过Gradio界面进行交互,展示了其强大的多模态处... 目录什么是 Janus Pro1. 安装 conda2. 创建 python 虚拟环境3. 克隆 janus

MySQL 缓存机制与架构解析(最新推荐)

《MySQL缓存机制与架构解析(最新推荐)》本文详细介绍了MySQL的缓存机制和整体架构,包括一级缓存(InnoDBBufferPool)和二级缓存(QueryCache),文章还探讨了SQL... 目录一、mysql缓存机制概述二、MySQL整体架构三、SQL查询执行全流程四、MySQL 8.0为何移除查

本地私有化部署DeepSeek模型的详细教程

《本地私有化部署DeepSeek模型的详细教程》DeepSeek模型是一种强大的语言模型,本地私有化部署可以让用户在自己的环境中安全、高效地使用该模型,避免数据传输到外部带来的安全风险,同时也能根据自... 目录一、引言二、环境准备(一)硬件要求(二)软件要求(三)创建虚拟环境三、安装依赖库四、获取 Dee

MySql9.1.0安装详细教程(最新推荐)

《MySql9.1.0安装详细教程(最新推荐)》MySQL是一个流行的关系型数据库管理系统,支持多线程和多种数据库连接途径,能够处理上千万条记录的大型数据库,本文介绍MySql9.1.0安装详细教程,... 目录mysql介绍:一、下载 Mysql 安装文件二、Mysql 安装教程三、环境配置1.右击此电脑

在 Windows 上安装 DeepSeek 的完整指南(最新推荐)

《在Windows上安装DeepSeek的完整指南(最新推荐)》在Windows上安装DeepSeek的完整指南,包括下载和安装Ollama、下载DeepSeekRXNUMX模型、运行Deep... 目录在www.chinasem.cn Windows 上安装 DeepSeek 的完整指南步骤 1:下载并安装

深入理解Apache Airflow 调度器(最新推荐)

《深入理解ApacheAirflow调度器(最新推荐)》ApacheAirflow调度器是数据管道管理系统的关键组件,负责编排dag中任务的执行,通过理解调度器的角色和工作方式,正确配置调度器,并... 目录什么是Airflow 调度器?Airflow 调度器工作机制配置Airflow调度器调优及优化建议最