
Commit 6e37708

add training script using unsloth
1 parent c0ddf9f commit 6e37708

File tree

4 files changed: +155, -3 lines
New file (training script):

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
from unsloth import FastLanguageModel
import torch
from unsloth import add_new_tokens
from typing import Optional, List
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported
from unsloth import UnslothTrainer, UnslothTrainingArguments
import fire
import wandb
from datasets import load_dataset


def load_model(rank: int = 128, train_embeddings: bool = True, add_special_tokens: Optional[List[str]] = None):
    """Load the 4-bit Llama-3-8B base model, optionally add new special tokens, and attach a LoRA adapter."""
    max_seq_length = 2048  # Choose any! We auto support RoPE Scaling internally!
    dtype = None  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
    load_in_4bit = True  # Use 4bit quantization to reduce memory usage. Can be False.

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = "unsloth/llama-3-8b-bnb-4bit",
        max_seq_length = max_seq_length,
        dtype = dtype,
        load_in_4bit = load_in_4bit,
    )

    if add_special_tokens:  # only extend the vocabulary when new tokens are actually provided
        add_new_tokens(model, tokenizer, new_tokens = add_special_tokens)

    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"]

    if train_embeddings:
        target_modules += ["embed_tokens", "lm_head"]
    model = FastLanguageModel.get_peft_model(
        model,
        r = rank,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
        target_modules = target_modules,
        lora_alpha = rank / 4,
        lora_dropout = 0,  # Supports any, but = 0 is optimized
        bias = "none",  # Supports any, but = "none" is optimized
        # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
        use_gradient_checkpointing = "unsloth",  # True or "unsloth" for very long context
        random_state = 3407,
        use_rslora = True,  # We support rank stabilized LoRA
        loftq_config = None,  # And LoftQ
    )

    return model, tokenizer


def train(model, tokenizer, dataset, run_name: str, batch_size: int = 64, max_seq_length = 2048):
    """Fine-tune the model on the given dataset for one epoch, log to Weights & Biases, and save the LoRA adapter locally."""
    wandb.init(
        project="chemnlp-ablations",
        name=run_name
    )
    trainer = UnslothTrainer(
        model = model,
        tokenizer = tokenizer,
        train_dataset = dataset,
        dataset_text_field = "text",
        max_seq_length = max_seq_length,
        dataset_num_proc = 2,

        args = UnslothTrainingArguments(
            per_device_train_batch_size = batch_size,
            gradient_accumulation_steps = 1,
            warmup_ratio = 0.1,
            num_train_epochs = 1,
            learning_rate = 5e-5,
            embedding_learning_rate = 1e-5,  # lower learning rate for embed_tokens / lm_head
            fp16 = not is_bfloat16_supported(),
            bf16 = is_bfloat16_supported(),
            logging_steps = 1,
            optim = "adamw_8bit",
            weight_decay = 0.01,
            lr_scheduler_type = "linear",
            seed = 3407,
            output_dir = f"outputs_{run_name}",
        ),
    )

    # Show current memory stats
    gpu_stats = torch.cuda.get_device_properties(0)
    start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
    max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
    print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
    print(f"{start_gpu_memory} GB of memory reserved.")

    trainer_stats = trainer.train()

    model.save_pretrained(f"lora_model_{run_name}")  # Local saving
    tokenizer.save_pretrained(f"lora_model_{run_name}")


def create_dataset(tokenizer, datasets):
    """Load JSONL files with a 'text' field and append the EOS token to every sample."""
    EOS_TOKEN = tokenizer.eos_token  # Must add EOS_TOKEN
    def formatting_prompts_func(examples):
        outputs = []
        for t in examples['text']:
            outputs.append(t + EOS_TOKEN)
        return { "text" : outputs, }

    dataset = load_dataset("json", data_files=datasets)
    dataset = dataset["train"]

    dataset = dataset.map(formatting_prompts_func, batched = True)

    return dataset


if __name__ == "__main__":
    model, tokenizer = load_model(train_embeddings=True, add_special_tokens=None)

    dataset = create_dataset(tokenizer, ["data/chemnlp_train.json", "data/chemnlp_val.json"])

    train(model, tokenizer, dataset, "lora_128", batch_size=64)
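
As a usage note: a minimal sketch of how these functions could be reused for a run with chemistry-specific special tokens; the token strings, rank, run name, and single-file dataset below are illustrative assumptions, not part of this commit.

    # Hypothetical variant run (illustrative values only).
    model, tokenizer = load_model(
        rank=64,
        train_embeddings=True,  # keep embed_tokens / lm_head trainable so the new tokens can be learned
        add_special_tokens=["<SMILES>", "</SMILES>"],  # assumed placeholder tokens
    )
    dataset = create_dataset(tokenizer, ["data/chemnlp_train.json"])
    train(model, tokenizer, dataset, run_name="lora_64_special_tokens", batch_size=32)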

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ chemnlp-generate-meta = "chemnlp.data.meta_yaml_generator:cli"
 chemnlp-augment-meta = "chemnlp.data.meta_yaml_augmenter:cli"
 chemnlp-sample = "chemnlp.data.sampler_cli:cli"
 chemnlp-add-random-split-column = "chemnlp.data.utils:add_random_split_column_cli"
+chemnlp-concatenate-jsonl = "chemnlp.data.utils:concatenate_jsonl_files_cli"

 [tool.setuptools_scm]
 version_scheme = "post-release"
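
For context, the new chemnlp-concatenate-jsonl entry point is a thin python-fire wrapper around concatenate_jsonl_files, which is added in the src/chemnlp/data/utils.py hunk further below. A minimal sketch of what that wiring amounts to, with an illustrative invocation in the comment (the paths are assumptions):

    # Rough equivalent of the installed console script; python-fire exposes
    # root_dir, output_file, datasets, and file_type as command-line flags, e.g.
    #   chemnlp-concatenate-jsonl --root_dir sampled_datasets --output_file data/chemnlp_train.json
    import fire
    from chemnlp.data.utils import concatenate_jsonl_files

    if __name__ == "__main__":
        fire.Fire(concatenate_jsonl_files)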

src/chemnlp/data/sampler.py

Lines changed: 2 additions & 3 deletions
@@ -839,9 +839,8 @@ def export(self, output_dir: str, template: str) -> pd.DataFrame:
         df_split = self.df[self.df["split"] == split]
         samples = []
         for _, row in tqdm(df_split.iterrows(), total=len(df_split)):
-            sample_dict = row.to_dict()
-            sample = self._fill_template(template, sample_dict)
-            samples.append(sample)
+            sampled = self.sample(row, template)
+            samples.append(sampled)
         df_out = pd.DataFrame(samples)

         # if self.benchmarking_templates:
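
A hedged reading of this change: export() now routes each row through the sampler's own sample() method instead of filling the template directly, so exported rows share one code path with on-the-fly sampling. A toy, self-contained sketch of that shape (these helper names and the format-based filling are illustrative, not the real TemplateSampler internals):

    import pandas as pd

    def sample(row, template):
        # single place where template filling (and any preprocessing) happens
        return {"text": template.format(**row.to_dict())}

    def export(df, template):
        # every exported row goes through sample(), mirroring the refactor above
        samples = [sample(row, template) for _, row in df.iterrows()]
        return pd.DataFrame(samples)

    df = pd.DataFrame({"name": ["ethanol"], "SMILES": ["CCO"]})
    print(export(df, "The SMILES of {name} is {SMILES}."))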

src/chemnlp/data/utils.py

Lines changed: 39 additions & 0 deletions
@@ -8,6 +8,45 @@
 import pandas as pd


+from pathlib import Path
+import fire
+
+def get_all_datasets(root_dir):
+    return [d.name for d in Path(root_dir).iterdir() if d.is_dir()]
+
+def concatenate_jsonl_files(root_dir, output_file, datasets=None, file_type='train'):
+    root_dir = Path(root_dir)
+
+    if datasets is None:
+        datasets = get_all_datasets(root_dir)
+    elif isinstance(datasets, str):
+        datasets = [datasets]
+
+    print(f"Processing datasets: {', '.join(datasets)}")
+    print(f"File type: {file_type}.jsonl")
+
+    with open(output_file, 'w') as outfile:
+        for dataset in datasets:
+            dataset_path = root_dir / dataset
+            if not dataset_path.is_dir():
+                print(f"Warning: Dataset '{dataset}' not found. Skipping.")
+                continue
+
+            for chunk_dir in dataset_path.glob('chunk_*'):
+                for template_dir in chunk_dir.glob('template_*'):
+                    jsonl_file = template_dir / f'{file_type}.jsonl'
+                    if jsonl_file.is_file():
+                        with open(jsonl_file, 'r') as infile:
+                            for line in infile:
+                                outfile.write(line)
+
+    print(f"Concatenated {file_type}.jsonl files have been saved to {output_file}")
+
+def concatenate_jsonl_files_cli():
+    fire.Fire(concatenate_jsonl_files)
+
+
+
 def add_random_split_column(df):
     # Calculate the number of rows for each split
     n_rows = len(df)
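
A minimal usage sketch for the new helper, assuming the directory layout it walks (<root_dir>/<dataset>/chunk_*/template_*/<file_type>.jsonl); the root directory name is an illustrative assumption, while the output path matches the file consumed by the training script above:

    from chemnlp.data.utils import concatenate_jsonl_files

    # Merge every dataset's train.jsonl files under root_dir into one training file.
    concatenate_jsonl_files(
        root_dir="sampled_datasets",        # assumed location of the sampled data
        output_file="data/chemnlp_train.json",
        datasets=None,                      # None -> all sub-directories of root_dir
        file_type="train",                  # pick which <split>.jsonl files to concatenate
    )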

0 commit comments
