Skip to content

Commit f72d74f

Browse files
committed
revise training script
1 parent 6e37708 commit f72d74f

File tree

3 files changed

+102
-84
lines changed

experiments/ablations/continued_pretrain.py

Lines changed: 78 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -2,82 +2,89 @@
22
import torch
33
from unsloth import add_new_tokens
44
from typing import Optional, List
5-
from transformers import TrainingArguments
65
from unsloth import is_bfloat16_supported
76
from unsloth import UnslothTrainer, UnslothTrainingArguments
8-
import fire
97
import wandb
108
from datasets import load_dataset
9+
import fire
1110

12-
13-
def load_model(rank: int = 128, train_embeddings: bool = True, add_special_tokens: Optional[List[str]]=None):
14-
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
15-
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
16-
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
11+
def load_model(
12+
rank: int = 128,
13+
train_embeddings: bool = True,
14+
add_special_tokens: Optional[List[str]] = None,
15+
):
16+
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
17+
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
18+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
1719

1820
model, tokenizer = FastLanguageModel.from_pretrained(
19-
model_name = "unsloth/llama-3-8b-bnb-4bit",
20-
max_seq_length = max_seq_length,
21-
dtype = dtype,
22-
load_in_4bit = load_in_4bit,
21+
model_name="unsloth/llama-3-8b-bnb-4bit",
22+
max_seq_length=max_seq_length,
23+
dtype=dtype,
24+
load_in_4bit=load_in_4bit,
2325
)
2426

25-
add_new_tokens(model, tokenizer, new_tokens = add_special_tokens)
27+
add_new_tokens(model, tokenizer, new_tokens=add_special_tokens)
2628

27-
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
28-
"gate_proj", "up_proj", "down_proj"]
29+
target_modules = [
30+
"q_proj",
31+
"k_proj",
32+
"v_proj",
33+
"o_proj",
34+
"gate_proj",
35+
"up_proj",
36+
"down_proj",
37+
]
2938

3039
if train_embeddings:
31-
target_modules += ["embed_tokens", "lm_head"]
40+
target_modules += ["embed_tokens", "lm_head"]
3241
model = FastLanguageModel.get_peft_model(
3342
model,
34-
r = rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
35-
target_modules = target_modules,
36-
lora_alpha = rank/4,
37-
lora_dropout = 0, # Supports any, but = 0 is optimized
38-
bias = "none", # Supports any, but = "none" is optimized
43+
r=rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
44+
target_modules=target_modules,
45+
lora_alpha=rank / 4,
46+
lora_dropout=0, # Supports any, but = 0 is optimized
47+
bias="none", # Supports any, but = "none" is optimized
3948
# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
40-
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
41-
random_state = 3407,
42-
use_rslora = True, # We support rank stabilized LoRA
43-
loftq_config = None, # And LoftQ
49+
use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context
50+
random_state=3407,
51+
use_rslora=True, # We support rank stabilized LoRA
52+
loftq_config=None, # And LoftQ
4453
)
4554

4655
return model, tokenizer
4756

4857

49-
def train(model, tokenizer, dataset, run_name: str, batch_size:int =64, max_seq_length = 2048):
50-
wandb.init(
51-
project="chemnlp-ablations",
52-
name=run_name
53-
)
58+
def train(
59+
model, tokenizer, dataset, run_name: str, batch_size: int = 64, max_seq_length=2048
60+
):
61+
wandb.init(project="chemnlp-ablations", name=run_name)
5462
trainer = UnslothTrainer(
55-
model = model,
56-
tokenizer = tokenizer,
57-
train_dataset = dataset,
58-
dataset_text_field = "text",
59-
max_seq_length = max_seq_length,
60-
dataset_num_proc = 2,
61-
62-
args = UnslothTrainingArguments(
63-
per_device_train_batch_size = batch_size,
64-
gradient_accumulation_steps = 1,
65-
warmup_ratio = 0.1,
66-
num_train_epochs = 1,
67-
learning_rate = 5e-5,
68-
embedding_learning_rate = 1e-5,
69-
fp16 = not is_bfloat16_supported(),
70-
bf16 = is_bfloat16_supported(),
71-
logging_steps = 1,
72-
optim = "adamw_8bit",
73-
weight_decay = 0.01,
74-
lr_scheduler_type = "linear",
75-
seed = 3407,
76-
output_dir = f"outputs_{run_name}",
63+
model=model,
64+
tokenizer=tokenizer,
65+
train_dataset=dataset,
66+
dataset_text_field="text",
67+
max_seq_length=max_seq_length,
68+
dataset_num_proc=2,
69+
args=UnslothTrainingArguments(
70+
per_device_train_batch_size=batch_size,
71+
gradient_accumulation_steps=1,
72+
warmup_ratio=0.1,
73+
num_train_epochs=1,
74+
learning_rate=5e-5,
75+
embedding_learning_rate=1e-5,
76+
fp16=not is_bfloat16_supported(),
77+
bf16=is_bfloat16_supported(),
78+
logging_steps=1,
79+
optim="adamw_8bit",
80+
weight_decay=0.01,
81+
lr_scheduler_type="linear",
82+
seed=3407,
83+
output_dir=f"outputs_{run_name}",
7784
),
7885
)
7986

80-
#@title Show current memory stats
87+
# @title Show current memory stats
8188
gpu_stats = torch.cuda.get_device_properties(0)
8289
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
8390
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
@@ -86,28 +93,38 @@ def train(model, tokenizer, dataset, run_name: str, batch_size:int =64, max_seq_
8693

8794
trainer_stats = trainer.train()
8895

89-
model.save_pretrained(f"lora_model_{run_name}") # Local saving
96+
model.save_pretrained(f"lora_model_{run_name}") # Local saving
9097
tokenizer.save_pretrained(f"lora_model_{run_name}")
9198

9299

93100
def create_dataset(tokenizer, datasets):
94-
EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
101+
EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
102+
95103
def formatting_prompts_func(examples):
96104
outputs = []
97-
for t in examples['text']:
105+
for t in examples["text"]:
98106
outputs.append(t + EOS_TOKEN)
99-
return { "text" : outputs, }
107+
return {
108+
"text": outputs,
109+
}
100110

101111
dataset = load_dataset("json", data_files=datasets)
102112
dataset = dataset["train"]
103113

104-
dataset = dataset.map(formatting_prompts_func, batched = True)
114+
dataset = dataset.map(formatting_prompts_func, batched=True)
105115

106116
return dataset
107117

108-
if __name__ == "__main__":
109-
model, tokenizer = load_model(train_embeddings=True, add_special_tokens=None)
110118

111-
dataset = create_dataset(tokenizer, ["data/chemnlp_train.json", "data/chemnlp_val.json"])
119+
def run(data_files: List[str], train_embeddings: bool, run_name: str, batch_size: int, add_special_tokens: Optional[List[str]]=None)
120+
model, tokenizer = load_model(train_embeddings=train_embeddings, add_special_tokens=add_special_tokens )
112121

113-
train(model, tokenizer, dataset, "lora_128", batch_size=64)
122+
dataset = create_dataset(
123+
tokenizer, data_files
124+
)
125+
126+
train(model, tokenizer, dataset, run_name, batch_size=batch_size)
127+
128+
129+
if __name__ == "__main__":
130+
fire.Fire(run)

pyproject.toml

Lines changed: 15 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -13,21 +13,21 @@ dynamic = ["version"]
1313
[project.optional-dependencies]
1414
dev = ["pre-commit", "pytest"]
1515
dataset_creation = [
16-
"PyTDC",
17-
"rdkit",
18-
"ruamel.yaml",
19-
"selfies",
20-
"deepsmiles",
21-
"pubchempy",
22-
"bioc",
23-
"pylatexenc",
24-
"canonicalize_psmiles@git+https://github.com/Ramprasad-Group/canonicalize_psmiles.git",
25-
"rxn-chem-utils",
26-
"backoff",
27-
"givemeconformer",
28-
"chembl_webresource_client",
29-
"dask",
30-
"pandarallel",
16+
"PyTDC",
17+
"rdkit",
18+
"ruamel.yaml",
19+
"selfies",
20+
"deepsmiles",
21+
"pubchempy",
22+
"bioc",
23+
"pylatexenc",
24+
"canonicalize_psmiles@git+https://github.com/Ramprasad-Group/canonicalize_psmiles.git",
25+
"rxn-chem-utils",
26+
"backoff",
27+
"givemeconformer",
28+
"chembl_webresource_client",
29+
"dask",
30+
"pandarallel"
3131
]
3232

3333
[project.scripts]

src/chemnlp/data/utils.py

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -9,12 +9,13 @@
99

1010

1111
from pathlib import Path
12-
import fire
12+
1313

1414
def get_all_datasets(root_dir):
1515
return [d.name for d in Path(root_dir).iterdir() if d.is_dir()]
1616

17-
def concatenate_jsonl_files(root_dir, output_file, datasets=None, file_type='train'):
17+
18+
def concatenate_jsonl_files(root_dir, output_file, datasets=None, file_type="train"):
1819
root_dir = Path(root_dir)
1920

2021
if datasets is None:
@@ -25,28 +26,28 @@ def concatenate_jsonl_files(root_dir, output_file, datasets=None, file_type='tra
2526
print(f"Processing datasets: {', '.join(datasets)}")
2627
print(f"File type: {file_type}.jsonl")
2728

28-
with open(output_file, 'w') as outfile:
29+
with open(output_file, "w") as outfile:
2930
for dataset in datasets:
3031
dataset_path = root_dir / dataset
3132
if not dataset_path.is_dir():
3233
print(f"Warning: Dataset '{dataset}' not found. Skipping.")
3334
continue
3435

35-
for chunk_dir in dataset_path.glob('chunk_*'):
36-
for template_dir in chunk_dir.glob('template_*'):
37-
jsonl_file = template_dir / f'{file_type}.jsonl'
36+
for chunk_dir in dataset_path.glob("chunk_*"):
37+
for template_dir in chunk_dir.glob("template_*"):
38+
jsonl_file = template_dir / f"{file_type}.jsonl"
3839
if jsonl_file.is_file():
39-
with open(jsonl_file, 'r') as infile:
40+
with open(jsonl_file, "r") as infile:
4041
for line in infile:
4142
outfile.write(line)
4243

4344
print(f"Concatenated {file_type}.jsonl files have been saved to {output_file}")
4445

46+
4547
def concatenate_jsonl_files_cli():
4648
fire.Fire(concatenate_jsonl_files)
4749

4850

49-
5051
def add_random_split_column(df):
5152
# Calculate the number of rows for each split
5253
n_rows = len(df)

0 commit comments

Comments
 (0)