# Mirror of https://github.com/hoshikawa2/qlora_training.git
# (synced 2026-03-03 16:09:36 +00:00) — 89 lines, 2.3 KiB, Python.
import os
|
|
from datasets import load_dataset
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
|
|
from transformers import DataCollatorForLanguageModeling, BitsAndBytesConfig
|
|
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
|
|
from datasets import load_dataset
|
|
|
|
# Configuration
# Hugging Face Hub id of the base model to fine-tune.
model_name: str = "mistralai/Mistral-7B-Instruct-v0.2"
# Local JSON training data  <<== adjust this to the name of your JSON file.
data_file_path: str = "./datasets/qa_dataset.json"
# Where training checkpoints and the final model/tokenizer are written.
output_dir: str = "./qlora-output"
# Maximum number of tokens per example; longer texts are truncated.
max_length: int = 512

# 4-bit quantization (bitsandbytes) applied when the base model is loaded.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    # Nested ("double") quantization: the quantization constants are
    # themselves quantized, saving a little more memory.
    bnb_4bit_use_double_quant=True,
    # Half-precision compute, in line with fp16=True in the training arguments.
    bnb_4bit_compute_dtype="float16",
)

# Tokenizer and base model
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Padding reuses the EOS token (presumably because the base tokenizer defines
# no pad token of its own — confirm).
tokenizer.pad_token = tokenizer.eos_token

# Load the base model 4-bit quantized (bnb_config), let accelerate place it on
# the available devices, then prepare the quantized model for k-bit training.
# NOTE(review): trust_remote_code=True allows the Hub repository to execute
# arbitrary Python on this machine; Mistral appears to be supported natively by
# transformers, so the flag is likely unnecessary — confirm and consider removing.
model = prepare_model_for_kbit_training(
    AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        quantization_config=bnb_config,
        trust_remote_code=True,
    )
)

# LoRA: attach low-rank adapters to the attention projection layers.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    # Adapter rank and scaling (PEFT scales by lora_alpha / r = 16 / 8 = 2).
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    # Query, key, value and output projections of each attention block.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
model = get_peft_model(model, peft_config)

# Local JSON dataset, loaded as its single "train" split.
# NOTE(review): the tokenization step reads example["text"], so each JSON
# record presumably needs a "text" field — confirm against the data file.
dataset = load_dataset(path="json", data_files=data_file_path, split="train")

# Tokenization
def tokenize(example):
    """Tokenize the "text" field of an example (or a batch of examples).

    Texts are truncated to ``max_length`` tokens but NOT padded. Padding and
    label construction are left to the data collator, which pads each batch
    only to its longest sequence instead of always padding to ``max_length``.

    Args:
        example: mapping with a "text" entry. It is a single string for
            per-row mapping, or a list of strings when called through
            ``Dataset.map(..., batched=True)``.

    Returns:
        The tokenizer output (``input_ids`` and ``attention_mask``).
    """
    # No "labels" key here. With mlm=False, DataCollatorForLanguageModeling
    # rebuilds labels from input_ids at collate time and overwrites any
    # supplied ones, so the former input_ids copy was dead weight. Padded to
    # max_length, it also listed pad tokens as targets. Variable-length labels
    # would additionally break the collator's padding.
    # NOTE(review): pad_token is set to eos_token, so the collator masks EOS
    # positions as padding. If the tokenizer appends EOS at all (it may not by
    # default — confirm), the model never sees it as a target and may not learn
    # to stop. Consider a dedicated pad token.
    return tokenizer(
        example["text"],
        truncation=True,
        max_length=max_length,
    )


# batched=True hands tokenize() many rows per call, which lets the fast
# tokenizer process them together instead of one Python call per row.
tokenized_dataset = dataset.map(
    tokenize,
    batched=True,
    remove_columns=dataset.column_names,
)

# Causal-LM data collator (mlm=False). Per the transformers docs it pads each
# batch and derives labels from input_ids, with pad positions set to -100 so
# they are ignored by the loss — confirm against the installed version.
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

# Training arguments.
# Samples per optimizer step per device:
# per_device_train_batch_size * gradient_accumulation_steps = 2 * 4 = 8.
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=3,
    learning_rate=2e-4,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    fp16=True,              # mixed precision, matching the float16 compute dtype
    logging_steps=10,
    save_strategy="epoch",  # checkpoint once at the end of every epoch
    report_to="none",       # no experiment-tracking integrations
)

# Trainer: combines the LoRA-wrapped model, arguments, tokenized data and collator.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset,
)

# Training
trainer.train()

# Persist the trained model and the tokenizer side by side in output_dir.
# NOTE(review): `model` is the PEFT-wrapped model, so save_pretrained presumably
# writes only the LoRA adapter weights/config, not the full base model — confirm.
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)