import torch
from unsloth import FastLanguageModel
from unsloth import add_new_tokens
from typing import Optional, List
from unsloth import is_bfloat16_supported
from unsloth import UnslothTrainer, UnslothTrainingArguments
import wandb
from datasets import load_dataset
import fire

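# Load the 4-bit Llama-3 8B base model and wrap it in a (rank-stabilized) LoRA
# adapter; optionally register extra special tokens and make the embedding matrices
# trainable so those tokens can be learned.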
def load_model(
    rank: int = 128,
    train_embeddings: bool = True,
    add_special_tokens: Optional[List[str]] = None,
):
    max_seq_length = 2048  # Choose any! We auto support RoPE Scaling internally!
    dtype = None  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
    load_in_4bit = True  # Use 4bit quantization to reduce memory usage. Can be False.

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="unsloth/llama-3-8b-bnb-4bit",
        max_seq_length=max_seq_length,
        dtype=dtype,
        load_in_4bit=load_in_4bit,
    )

    # Only extend the vocabulary when extra special tokens were actually supplied.
    if add_special_tokens is not None:
        add_new_tokens(model, tokenizer, new_tokens=add_special_tokens)

    target_modules = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ]

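    # Newly added tokens only become useful if the embedding matrix and LM head are
    # trained too, so include them as LoRA targets when train_embeddings is set.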
    if train_embeddings:
        target_modules += ["embed_tokens", "lm_head"]
    model = FastLanguageModel.get_peft_model(
        model,
        r=rank,  # Choose any number > 0! Suggested 8, 16, 32, 64, 128
        target_modules=target_modules,
        lora_alpha=rank / 4,
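        # With use_rslora=True (below), the adapter update is scaled by
        # lora_alpha / sqrt(r) rather than lora_alpha / r (rank-stabilized LoRA).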
        lora_dropout=0,  # Supports any, but = 0 is optimized
        bias="none",  # Supports any, but = "none" is optimized
        # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
        use_gradient_checkpointing="unsloth",  # True or "unsloth" for very long context
        random_state=3407,
        use_rslora=True,  # We support rank stabilized LoRA
        loftq_config=None,  # And LoftQ
    )

    return model, tokenizer


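# Run a single epoch of causal-LM training on the "text" field, with a separate,
# lower learning rate for the embedding matrices, and log metrics to Weights & Biases.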
def train(
    model, tokenizer, dataset, run_name: str, batch_size: int = 64, max_seq_length=2048
):
    wandb.init(project="chemnlp-ablations", name=run_name)
    trainer = UnslothTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        dataset_text_field="text",
        max_seq_length=max_seq_length,
        dataset_num_proc=2,
        args=UnslothTrainingArguments(
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=1,
            warmup_ratio=0.1,
            num_train_epochs=1,
            learning_rate=5e-5,
            embedding_learning_rate=1e-5,
            fp16=not is_bfloat16_supported(),
            bf16=is_bfloat16_supported(),
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="linear",
            seed=3407,
            output_dir=f"outputs_{run_name}",
        ),
    )

    # @title Show current memory stats
    gpu_stats = torch.cuda.get_device_properties(0)
    start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
    max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
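    # Assumed reporting step (standard Unsloth notebook wording): print the stats
    # computed above before training starts.
    print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
    print(f"{start_gpu_memory} GB of memory reserved.")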

    trainer_stats = trainer.train()

    model.save_pretrained(f"lora_model_{run_name}")  # Local saving
    tokenizer.save_pretrained(f"lora_model_{run_name}")


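# Load one or more JSON files into a single dataset and append the EOS token to every
# example so the model learns where each document ends.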
def create_dataset(tokenizer, datasets):
    EOS_TOKEN = tokenizer.eos_token  # Must add EOS_TOKEN

    def formatting_prompts_func(examples):
        outputs = []
        for t in examples["text"]:
            outputs.append(t + EOS_TOKEN)
        return {
            "text": outputs,
        }

    dataset = load_dataset("json", data_files=datasets)
    dataset = dataset["train"]

    dataset = dataset.map(formatting_prompts_func, batched=True)

    return dataset


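# Entry point: python-fire turns run()'s parameters into command-line flags, wiring
# together model loading, dataset construction, and training.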
def run(
    data_files: List[str],
    train_embeddings: bool,
    run_name: str,
    batch_size: int,
    add_special_tokens: Optional[List[str]] = None,
):
    model, tokenizer = load_model(
        train_embeddings=train_embeddings, add_special_tokens=add_special_tokens
    )

    dataset = create_dataset(tokenizer, data_files)

    train(model, tokenizer, dataset, run_name, batch_size=batch_size)


if __name__ == "__main__":
    fire.Fire(run)