This post walks through fine-tuning a LLaMA model. Hopefully it offers a useful reference for developers who need it; let's work through it together!
GitHub: https://github.com/facebookresearch/llama-recipes
GitHub: https://github.com/facebookresearch/llama
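Based on the imports used below, a typical environment setup looks roughly like this (the exact package set is an assumption; adjust for your CUDA/PyTorch build):

pip install llama-recipes transformers datasets peft bitsandbytes accelerate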
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

# You can download the model from Hugging Face ("hf" means the Hugging Face format);
# alternatively, convert the original LLaMA weights with the transformers library's
# convert_llama_weights_to_hf script.
# model_id = "./models_hf/7B"
model_id = "<model path>/Llama-2-7b-chat-hf-local"

tokenizer = LlamaTokenizer.from_pretrained(model_id)
model = LlamaForCausalLM.from_pretrained(
    model_id,
    load_in_8bit=True,       # 8-bit quantization (requires bitsandbytes)
    device_map='auto',
    torch_dtype=torch.float16,
)
from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
from llama_recipes.configs.datasets import samsum_dataset

train_dataset = get_preprocessed_dataset(tokenizer, samsum_dataset, 'train')
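A quick sanity check on the preprocessed data; this assumes (as in llama-recipes' samsum preprocessing) that each example is already tokenized into an input_ids list:

# Hypothetical inspection step: count the samples and decode one back to text.
print(len(train_dataset))
print(tokenizer.decode(train_dataset[0]["input_ids"]))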
eval_prompt = """
Summarize this dialog:
A: Hi Tom, are you busy tomorrow’s afternoon?
B: I’m pretty sure I am. What’s up?
A: Can you go with me to the animal shelter?.
B: What do you want to do?
A: I want to get a puppy for my son.
B: That will make him so happy.
A: Yeah, we’ve discussed it many times. I think he’s ready now.
B: That’s good. Raising a dog is a tough issue. Like having a baby ;-)
A: I'll get him one of those little dogs.
B: One that won't grow up too big;-)
A: And eat too much;-))
B: Do you know which one he would like?
A: Oh, yes, I took him there last Monday. He showed me one that he really liked.
B: I bet you had to drag him away.
A: He wanted to take it home right away ;-).
B: I wonder what he'll name it.
A: He said he’d name it after his dead hamster – Lemmy - he's a great Motorhead fan :-)))
---
Summary:
"""model_input = tokenizer(eval_prompt, return_tensors="pt").to("cuda")model.eval()
with torch.no_grad():print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))model.train()def create_peft_config(model):from peft import (get_peft_model,LoraConfig,TaskType,prepare_model_for_int8_training,)peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM,inference_mode=False,r=8,lora_alpha=32,lora_dropout=0.05,target_modules = ["q_proj", "v_proj"])# prepare int-8 model for trainingmodel = prepare_model_for_int8_training(model)model = get_peft_model(model, peft_config)model.print_trainable_parameters()return model, peft_config# create peft config
model, lora_config = create_peft_config(model)
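With r=8 on just q_proj and v_proj, only the small low-rank adapter matrices are trainable while the 8-bit base weights stay frozen. As a rough sanity check (the figures assume the 7B model; the exact format depends on your peft version), print_trainable_parameters should report something like:

# trainable params: 4194304 || all params: 6742609920 || trainable%: 0.06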
from transformers import TrainerCallback
from contextlib import nullcontext
enable_profiler = False
output_dir = "tmp/llama-output"config = {'lora_config': lora_config,'learning_rate': 1e-4,'num_train_epochs': 1,'gradient_accumulation_steps': 2,'per_device_train_batch_size': 2,'gradient_checkpointing': False,
}# Set up profiler
if enable_profiler:
    wait, warmup, active, repeat = 1, 1, 2, 1
    total_steps = (wait + warmup + active) * (1 + repeat)
    schedule = torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat)
    profiler = torch.profiler.profile(
        schedule=schedule,
        on_trace_ready=torch.profiler.tensorboard_trace_handler(f"{output_dir}/logs/tensorboard"),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    )

    class ProfilerCallback(TrainerCallback):
        def __init__(self, profiler):
            self.profiler = profiler

        def on_step_end(self, *args, **kwargs):
            self.profiler.step()

    profiler_callback = ProfilerCallback(profiler)
else:
    profiler = nullcontext()

from transformers import default_data_collator, Trainer, TrainingArguments

# Define training args
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    bf16=True,  # Use BF16 if available
    # logging strategies
    logging_dir=f"{output_dir}/logs",
    logging_strategy="steps",
    logging_steps=10,
    save_strategy="no",
    optim="adamw_torch_fused",
    max_steps=total_steps if enable_profiler else -1,
    **{k: v for k, v in config.items() if k != 'lora_config'}
)
with profiler:
    # Create Trainer instance
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=default_data_collator,
        callbacks=[profiler_callback] if enable_profiler else [],
    )

    # Start training
    trainer.train()

# save_pretrained on a PEFT model stores only the LoRA adapter weights
model.save_pretrained(output_dir)

model.eval()
with torch.no_grad():
    print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))
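Because only the adapter weights were saved, a later session can reattach them to a freshly loaded base model. A minimal sketch, assuming peft's PeftModel API:

from peft import PeftModel

# Reload the base model, then attach the LoRA adapter saved in output_dir.
base_model = LlamaForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map='auto')
tuned_model = PeftModel.from_pretrained(base_model, output_dir)
tuned_model.eval()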
That concludes this walkthrough of fine-tuning a LLaMA model; I hope it proves helpful to fellow developers!