LLM系列 | 36:Google最新开源大模型:Gemma 2介绍及其微调(下篇)

2024-09-01 17:12

本文主要是介绍LLM系列 | 36:Google最新开源大模型:Gemma 2介绍及其微调(下篇),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  • 引言

  • 环境安装

  • 数据准备

    • 下载

    • 处理

  • 模型训练

  • 模型inference

  • 结果

    • gemma-2-9b

    • gemma-2-9b-it

引言

低头观落日,引手摘飞星。

小伙伴们好,我是微信公众号《小窗幽记机器学习》的小编:卖黑神话的小女孩。本文紧接前文Google最新开源大语言模型:Gemma 2介绍及其微调(上篇),介绍如何用中文语料微调Gemma 2模型。如想与小编进一步交流,欢迎在《小窗幽记机器学习》上获取小编微信号,或者直接添加小编的wx号:

环境安装

pip3 install -U torch transformers trl peft bitsandbytes tf-keras -i https://mirrors.cloud.tencent.com/pypi/simple
pip3 install tf-keras -i https://mirrors.cloud.tencent.com/pypi/simple

数据准备

下载

这里使用Hello-SimpleAI/HC3-Chinese数据集进行微调。预先下载:

huggingface-cli download --resume-download --repo-type dataset --local-dir-use-symlinks False Hello-SimpleAI/HC3-Chinese --local-dir /share_data_zoo/LLM/Hello-SimpleAI/HC3-Chinese/

处理

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2024/6/29 16:25
# @Author  : 卖黑神话的小女孩
# @File    : fine_tuning_data_preprocess.py
"""
预处理:划分训练集和测试集
"""
import os
import pdbfrom datasets import load_dataset# Convert dataset to OAI messages
system_message = """你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题
"""
data_dir = "/share_data_zoo/LLM/"
data_id = "Hello-SimpleAI/HC3-Chinese"
data_name = data_id.split('/')[-1]
print("data_name=", data_name)
# pdb.set_trace()
data_path = os.path.join(data_dir, data_id)"""
conversational format
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}instruction format
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
"""def create_conversation(sample):return {"messages": [{"role": "system", "content": system_message},{"role": "user", "content": sample["question"]},{"role": "assistant", "content": sample["human_answers"][0]}# for whatever reason the dataset uses a list of answers]}if __name__ == "__main__":# Load dataset from the hubdataset_dict = load_dataset("json", data_files=f"{data_path}/baike.jsonl")# 由于只有一个文件,我们将其视为训练集) split="train"dataset = dataset_dict['train']print(create_conversation(dataset[0]))# # Convert dataset to OAI messagesdataset = dataset.map(create_conversation, remove_columns=["chatgpt_answers"], batched=False)# # split dataset into 10,000 training samples and 2,500 test samples# dataset = dataset.train_test_split(test_size=4500/4616)  # baike splitdataset = dataset.train_test_split(test_size=0.1)# save datasets to diskdataset["train"].to_json("train_dataset.json", orient="records")dataset["test"].to_json("test_dataset.json", orient="records")print("Save to disk success")

模型训练

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2024/6/29 14:53
# @Author  : 卖黑神话的小女孩
# @File    : fine_tuning_gemma.py
"""
安装依赖:pip3 install -U torch transformers trl peft bitsandbytes tf-keras -i https://mirrors.cloud.tencent.com/pypi/simplepip3 install tf-keras -i https://mirrors.cloud.tencent.com/pypi/simple准备数据:运行 fine_tuning_data_preprocess.py 脚本开始训练:运行 fine_tuning_gemma.py 脚本在脚本的末尾会将lora和原始模型进行merge开始inference:运行 fine_tuning_gemma_inference.py 脚本如果报错:ImportError: /usr/local/lib/python3.10/dist-packages/transformer_engine_extensions.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops5zeros4callEN3c108ArrayRefINS2_6SymIntEEENS2_8optionalINS2_10ScalarTypeEEENS6_INS2_6LayoutEEENS6_INS2_6DeviceEEENS6_IbEEpip3 uninstall transformer-engine 即可
"""
import os
import pdb
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from trl import setup_chat_format
from datasets import load_dataset
from peft import LoraConfig
from transformers import TrainingArguments
from trl import SFTTrainer
from peft import AutoPeftModelForCausalLM
from fine_tuning_data_preprocess import data_name as train_data_nameinit_model_dir = "/share_model_zoo/LLM/"
# init_model_id = "google/gemma-2-9b"
init_model_id = "google/gemma-2-9b-it"
init_model_path = os.path.join(init_model_dir, init_model_id)
res_dir = "../result_models"
result_model_dir = os.path.join(res_dir, init_model_id, train_data_name)
print("result_model_dir=", result_model_dir)
# 检查路径是否已存在
if not os.path.exists(result_model_dir):# 递归创建目录os.makedirs(result_model_dir)print("目录已创建:", result_model_dir)
else:print("目录已存在:", result_model_dir)# Convert dataset to OAI messages
system_message = """你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题
""""""
conversational format
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}instruction format
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
"""def create_conversation(sample):return {"messages": [{"role": "system", "content": system_message},{"role": "user", "content": sample["question"]},{"role": "assistant", "content": sample["human_answers"][0]}# for whatever reason the dataset uses a list of answers]}# Load jsonl data from disk
dataset = load_dataset("json", data_files="train_dataset.json", split="train")# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(init_model_path,device_map="auto",# attn_implementation="flash_attention_2",torch_dtype=torch.bfloat16,quantization_config=bnb_config
)
tokenizer = AutoTokenizer.from_pretrained(init_model_path)
tokenizer.padding_side = 'right'  # to prevent warnings# # set chat template to OAI chatML, remove if you start from a fine-tuned model
model, tokenizer = setup_chat_format(model, tokenizer)# LoRA config based on QLoRA paper & Sebastian Raschka experiment
peft_config = LoraConfig(lora_alpha=128,lora_dropout=0.05,r=256,bias="none",target_modules="all-linear",task_type="CAUSAL_LM",
)args = TrainingArguments(output_dir=result_model_dir,  # directory to save and repository idnum_train_epochs=3,  # number of training epochsper_device_train_batch_size=1,  # batch size per device during traininggradient_accumulation_steps=2,  # number of steps before performing a backward/update passgradient_checkpointing=True,  # use gradient checkpointing to save memoryoptim="adamw_torch_fused",  # use fused adamw optimizerlogging_steps=10,  # log every 10 stepssave_strategy="epoch",  # save checkpoint every epochlearning_rate=2e-4,  # learning rate, based on QLoRA paper# bf16=True,                              # use bfloat16 precision if you have supported GPU# tf32=True,                              # use tf32 precision if you have supported GPUmax_grad_norm=0.3,  # max gradient norm based on QLoRA paperwarmup_ratio=0.03,  # warmup ratio based on QLoRA paperlr_scheduler_type="constant",  # use constant learning rate schedulerpush_to_hub=False,  # push model to hubreport_to="tensorboard",  # report metrics to tensorboard
)max_seq_length = 1024  # max sequence length for model and packing of the datasettrainer = SFTTrainer(model=model,args=args,train_dataset=dataset,peft_config=peft_config,max_seq_length=max_seq_length,tokenizer=tokenizer,packing=True,dataset_kwargs={"add_special_tokens": False,  # We template with special tokens"append_concat_token": False,  # No need to add additional separator token}
)# start training, the model will be automatically saved to the hub and the output directory
trainer.train()# save model
trainer.save_model()
print("Save model success")
### COMMENT IN TO MERGE PEFT AND BASE MODEL ##### Load PEFT model on CPU
model = AutoPeftModelForCausalLM.from_pretrained(args.output_dir,torch_dtype=torch.float16,low_cpu_mem_usage=True,
)
# Merge LoRA and base model and save
merged_model = model.merge_and_unload()
merged_model.save_pretrained(args.output_dir, safe_serialization=True, max_shard_size="2GB")
print(f"Save merged_model to {args.output_dir} success")

模型inference

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2024/6/29 16:51
# @Author  : 卖黑神话的小女孩
# @File    : fine_tuning_gemma_inference.py
"""
transformers
"""
import os
import pdb
import time
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, pipeline
from datasets import load_dataset
from fine_tuning_data_preprocess import data_name as train_data_name
from random import randintinit_model_dir = "/share_model_zoo/LLM/"
init_model_id = "google/gemma-2-9b"
# init_model_id = "google/gemma-2-9b-it"
init_model_path = os.path.join(init_model_dir, init_model_id)
res_dir = "../result_models"
result_model_dir = os.path.join(res_dir, init_model_id, train_data_name)peft_model_id = result_model_dir# Load Model with PEFT adapter
start_time = time.time()
model = AutoPeftModelForCausalLM.from_pretrained(peft_model_id,device_map="auto",torch_dtype=torch.float16
)
print(f"Load peft model={peft_model_id} success")
end_time = time.time()
model_load_cost = round(end_time - start_time, 2)
print(f"model load cost={model_load_cost}")tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
# load into pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)# Test on sample
rand_idx = 2
eval_dataset = load_dataset("json", data_files="test_dataset.json", split="train")
test_texts = eval_dataset[rand_idx]["messages"][:2]
# pdb.set_trace()
# 调用方法1:
prompt = pipe.tokenizer.apply_chat_template(eval_dataset[rand_idx]["messages"][:2], tokenize=False,add_generation_prompt=True)
outputs = pipe(prompt, repetition_penalty=1.3, max_new_tokens=256, do_sample=False, temperature=0.1, top_k=50,top_p=0.1, eos_token_id=pipe.tokenizer.eos_token_id, pad_token_id=pipe.tokenizer.pad_token_id)# # 调用方法2:
# messages = [
#     {"role": "user", "content": "你是谁?"},
# ]
# messages_outputs = pipe(
#     messages,
#     repetition_penalty=1.3,
#     max_new_tokens=256,
#     do_sample=False,
# )
#
# assistant_response = messages_outputs[0]["generated_text"][-1]["content"]
# print("assistant_response=\n", assistant_response)print(f"Query:\n{eval_dataset[rand_idx]['messages'][1]['content']}")
print(f"Original Answer:\n{eval_dataset[rand_idx]['messages'][2]['content']}")
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")

结果

gemma-2-9b

未微调结果

Query:
你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题
我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS
Generated Answer:
? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS

微调结果

Query:
我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS
Original Answer:
RouterOS是一种路由操作系统,是基于Linux核心开发,兼容x86 PC的路由软件,并通过该软件将标准的PC电脑变成专业路由器,在软件RouterOS 软路由图的开发和应用上不断的更新和发展,软件经历了多次更新和改进,使其功能在不断增强和完善。特别在无线、认证、策略路由、带宽控制和防火墙过滤等功能上有着非常突出的功能,其极高的性价比,受到许多网络人士的青睐。
Generated Answer:
RouterOS是采用先进的网络协议和算法(例如RIP、OSPF、BGP等)进行路由管理与控制以及负载均衡的一种类Unix计算机操作系统。它是在1996年由MikroTik公司开发并发布的第一个版本为2.0而设计的用于多种平台上的高级互联网网关软件包或系统。 它的主要目标是对小型办公室和家庭用户的无线局域网提供出色的性能以改善数据传输速率和其他关键指标;同时最大限度地降低成本并在设计中考虑安装复杂性及可扩充性的需求点。 它在全球拥有超过35,000个活跃的用户群并且 Mikrotik 是世界领先且最可靠的小型企业边缘联网设备供应商 。他们已成功地在全世界销售了超过4百万的产品 ,产品覆盖范围从低端到高端商业办公大楼或是 ISP 的核心机房都适用。他们的客户遍布于几乎所有可以上网的地方而且很多国家都有其代表处或者分销商 ;由于产品的易用性和高性价比 ,使得我们的产品受到许多新兴市场的青睐比如:俄罗斯 、印度 和中国等等国家的市场正在蓬勃发展 !我们确信这些还不是它们的极限!随着科技的发展 ,Internet 将会

gemma-2-9b-it

未微调结果

Query:
你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题
我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS
Generated Answer:
?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的

微调结果

Query:
我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS
Original Answer:
RouterOS是一种路由操作系统,是基于Linux核心开发,兼容x86 PC的路由软件,并通过该软件将标准的PC电脑变成专业路由器,在软件RouterOS 软路由图的开发和应用上不断的更新和发展,软件经历了多次更新和改进,使其功能在不断增强和完善。特别在无线、认证、策略路由、带宽控制和防火墙过滤等功能上有着非常突出的功能,其极高的性价比,受到许多网络人士的青睐。
Generated Answer:
RouterOS是由 Latvian Information Technologies Association(丽顿信息技术协会)开发的网络协议栈和路由器软件。它被认为是基于IPv4/IPSec、MPLS及其他高速数据传输协议的高速分组交换与包处理实现的核心;也是Internet骨干网建设的重要设备之一。
其核心应用为:防火墙服务 (NAT / IPsec)、高速度互联网接入服务器 、安全 VPN 等业务功能 。此外, 它还具有丰富的语音压缩算法等线路侧特性.因此在电信固定无线通信方面也发挥着重要作用。由于采用先进的数据转发引擎架构设计使其具备很强的扩展性 ,所以routeros体系能够兼容多种处理器结构,从x86到ARM9E-S 等等 । routeros本身提供很多高级的技术功能面,但并没有进行深入的研究工作,因为它的开发者希望把产品的源代码开放给公众以便共同改进产品性能.随着多核CPU技术的盛行以及云计算理论产生的兴起 ,许多新兴的公司都加入到了这个行业并且利用了开源软件作为基础做出了自己的创新产 品.这促进了计算机硬件的发展 和社会资源的合理利用$.当然对于一些大公司来说他们拥有足够的研发能力可以自行

这篇关于LLM系列 | 36:Google最新开源大模型:Gemma 2介绍及其微调(下篇)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1127554

相关文章

Python实现NLP的完整流程介绍

《Python实现NLP的完整流程介绍》这篇文章主要为大家详细介绍了Python实现NLP的完整流程,文中的示例代码讲解详细,具有一定的借鉴价值,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1. 编程安装和导入必要的库2. 文本数据准备3. 文本预处理3.1 小写化3.2 分词(Tokenizatio

Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)

《Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)》:本文主要介绍Python基于火山引擎豆包大模型搭建QQ机器人详细的相关资料,包括开通模型、配置APIKEY鉴权和SD... 目录豆包大模型概述开通模型付费安装 SDK 环境配置 API KEY 鉴权Ark 模型接口Prompt

Spring Boot 中整合 MyBatis-Plus详细步骤(最新推荐)

《SpringBoot中整合MyBatis-Plus详细步骤(最新推荐)》本文详细介绍了如何在SpringBoot项目中整合MyBatis-Plus,包括整合步骤、基本CRUD操作、分页查询、批... 目录一、整合步骤1. 创建 Spring Boot 项目2. 配置项目依赖3. 配置数据源4. 创建实体类

Java子线程无法获取Attributes的解决方法(最新推荐)

《Java子线程无法获取Attributes的解决方法(最新推荐)》在Java多线程编程中,子线程无法直接获取主线程设置的Attributes是一个常见问题,本文探讨了这一问题的原因,并提供了两种解决... 目录一、问题原因二、解决方案1. 直接传递数据2. 使用ThreadLocal(适用于线程独立数据)

Spring Security 从入门到进阶系列教程

Spring Security 入门系列 《保护 Web 应用的安全》 《Spring-Security-入门(一):登录与退出》 《Spring-Security-入门(二):基于数据库验证》 《Spring-Security-入门(三):密码加密》 《Spring-Security-入门(四):自定义-Filter》 《Spring-Security-入门(五):在 Sprin

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

性能测试介绍

性能测试是一种测试方法,旨在评估系统、应用程序或组件在现实场景中的性能表现和可靠性。它通常用于衡量系统在不同负载条件下的响应时间、吞吐量、资源利用率、稳定性和可扩展性等关键指标。 为什么要进行性能测试 通过性能测试,可以确定系统是否能够满足预期的性能要求,找出性能瓶颈和潜在的问题,并进行优化和调整。 发现性能瓶颈:性能测试可以帮助发现系统的性能瓶颈,即系统在高负载或高并发情况下可能出现的问题

水位雨量在线监测系统概述及应用介绍

在当今社会,随着科技的飞速发展,各种智能监测系统已成为保障公共安全、促进资源管理和环境保护的重要工具。其中,水位雨量在线监测系统作为自然灾害预警、水资源管理及水利工程运行的关键技术,其重要性不言而喻。 一、水位雨量在线监测系统的基本原理 水位雨量在线监测系统主要由数据采集单元、数据传输网络、数据处理中心及用户终端四大部分构成,形成了一个完整的闭环系统。 数据采集单元:这是系统的“眼睛”,

Hadoop数据压缩使用介绍

一、压缩原则 (1)运算密集型的Job,少用压缩 (2)IO密集型的Job,多用压缩 二、压缩算法比较 三、压缩位置选择 四、压缩参数配置 1)为了支持多种压缩/解压缩算法,Hadoop引入了编码/解码器 2)要在Hadoop中启用压缩,可以配置如下参数

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G