Skip to content

Commit 9d2972e

Browse files
committed
fix wrapping
1 parent f72d74f commit 9d2972e

File tree

4 files changed

+18
-7
lines changed

4 files changed

+18
-7
lines changed

data/tabular/melting_points/meta.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ targets:
1919
benchmarks: []
2020
identifiers:
2121
- id: SMILES
22-
type: text
22+
type: SMILES
2323
description: SMILES
2424
- id: NAME
25-
type: text
25+
type: Other
2626
description: name
2727
license: CC BY 4.0
2828
links:

experiments/ablations/continued_pretrain.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def load_model(
2323
dtype=dtype,
2424
load_in_4bit=load_in_4bit,
2525
)
26-
27-
add_new_tokens(model, tokenizer, new_tokens=add_special_tokens)
26+
if add_special_tokens is not None:
27+
add_new_tokens(model, tokenizer, new_tokens=add_special_tokens)
2828

2929
target_modules = [
3030
"q_proj",
@@ -116,7 +116,12 @@ def formatting_prompts_func(examples):
116116
return dataset
117117

118118

119-
def run(data_files: List[str], train_embeddings: bool, run_name: str, batch_size: int, add_special_tokens: Optional[List[str]]=None)
119+
def run(data_files: List[str], run_name: str, batch_size: int=64, add_special_tokens: Optional[List[str]]=None, train_embeddings: bool=True):
120+
print(f"Data files {data_files}")
121+
print(f"Run name {run_name}")
122+
print(f"Batch size {batch_size}")
123+
print(f"Add special tokens {add_special_tokens}")
124+
print(f"Train embeddings {train_embeddings}")
120125
model, tokenizer = load_model(train_embeddings=train_embeddings, add_special_tokens=add_special_tokens )
121126

122127
dataset = create_dataset(

src/chemnlp/data/sampler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def _wrap_identifier(self, identifier: str, value: str) -> str:
172172
"""Wrap the identifier value with tags if wrap_identifiers is enabled."""
173173

174174
if not self.wrap_identifiers:
175+
logger.debug("Not wrapping identifiers.")
175176
return value
176177

177178
identifier_type = next(
@@ -188,9 +189,11 @@ def _wrap_identifier(self, identifier: str, value: str) -> str:
188189
except ValueError:
189190
identifier_type = None
190191

192+
logger.debug(f'Identifier type: {identifier_type}, value: {value}')
191193
if identifier_type and identifier_type not in self.config.get(
192194
"excluded_from_wrapping", []
193195
):
196+
logger.debug(f"Wrapping {identifier_type} with tags.")
194197
return f"[BEGIN_{identifier_type}]{value}[END_{identifier_type}]"
195198
return value
196199

src/chemnlp/data/sampler_cli.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def process_dataset(
102102
"excluded_from_wrapping": ["Other"],
103103
}
104104

105+
105106
templates = meta["templates"]
106107
if benchmarking:
107108
templates = [t for t in templates if "<EOI>" in t]
@@ -116,7 +117,9 @@ def process_dataset(
116117
logger.debug(f"Processing chunk {chunk_idx} to {chunk_output_dir}")
117118
os.makedirs(chunk_output_dir, exist_ok=True)
118119

119-
sampler = TemplateSampler(df_chunk, meta, config, data_dir)
120+
sampler = TemplateSampler(df_chunk, meta=meta, config=config, path_data_dir=data_dir)
121+
if wrap_identifiers:
122+
assert sampler.wrap_identifiers, "Wrap identifiers must be enabled in the sampler"
120123

121124
for template_idx, template in enumerate(templates):
122125
print(
@@ -177,7 +180,7 @@ def main(
177180
benchmarking,
178181
additional_templates,
179182
use_standard_templates,
180-
wrap_identifiers,
183+
wrap_identifiers=wrap_identifiers,
181184
)
182185

183186

0 commit comments

Comments (0)