-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
58 lines (50 loc) · 1.78 KB
/
train.py
File metadata and controls
58 lines (50 loc) · 1.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer, SFTConfig
from datasets import load_from_disk # or use your data_loader result
def run_training():
    """Fine-tune a causal LM with QLoRA (4-bit base model + LoRA adapters).

    Side effects: downloads the base model and tokenizer from the Hugging Face
    Hub, reads the preprocessed dataset from ``processed_data``, writes
    checkpoints to ``./results`` and the final adapter to ``./trained_agent_1``.

    Returns:
        None
    """
    model_id = "meta-llama/Llama-3-8b-hf"  # NOTE(review): verify repo id — the official hub id looks like "meta-llama/Meta-Llama-3-8B"

    # 1. 4-bit NF4 quantization config (saves VRAM); compute runs in fp16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

    # 2. Load model & tokenizer.
    # device_map="auto" lets accelerate place the quantized weights on the
    # available GPU(s) instead of leaving them on CPU.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Llama tokenizers ship with no pad token; reuse EOS so padded batching works.
    tokenizer.pad_token = tokenizer.eos_token

    # 3. LoRA configuration (the "efficiency" part): rank-16 adapters on the
    # attention query/value projections only.
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj"],  # targets the attention layers
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )

    # 4. Trainer configuration. Effective batch size = 4 * 4 = 16.
    training_args = SFTConfig(
        output_dir="./results",
        max_seq_length=1024,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        num_train_epochs=3,
        save_steps=100,
        logging_steps=10,
        dataset_text_field="text",  # this matches our preprocess.py output
    )

    # 5. Initialize trainer.
    # FIX: the configured tokenizer was never handed to the trainer, so the
    # pad-token setup above had no effect (SFTTrainer silently loaded its own).
    # Pass it via `processing_class` (the `tokenizer` kwarg's name in TRL >= 0.12).
    trainer = SFTTrainer(
        model=model,
        train_dataset=load_from_disk("processed_data"),  # use your preprocessed data
        peft_config=peft_config,
        args=training_args,
        processing_class=tokenizer,
    )

    print("--- 🚀 Starting Agent 1 Training ---")
    trainer.train()
    trainer.save_model("./trained_agent_1")
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    run_training()