Commit 0857db2

Add model prefix for loading
1 parent be4690e commit 0857db2

4 files changed: +15 −7 lines

chebai/cli.py — 3 additions & 2 deletions

@@ -16,8 +16,9 @@ def cli():
 @click.argument("batch_size", type=click.INT)
 @click.option("-g", "--group", default="default")
 @click.option("--version", default=None)
+@click.option("--load-prefix", default=None)
 @click.argument("args", nargs=-1)
-def train(experiment, batch_size, group, version, args):
+def train(experiment, batch_size, group, version, load_prefix, args):
     """Run experiment identified by EXPERIMENT in batches of size BATCH_SIZE."""
     try:
         ex = experiments.EXPERIMENTS[experiment](batch_size, group, version=version)
@@ -26,7 +27,7 @@ def train(experiment, batch_size, group, version, args):
             "Experiment ID not found. The following are available:"
             + ", ".join(experiments.EXPERIMENTS.keys())
         )
-    ex.train(batch_size, *args)
+    ex.train(batch_size, *args, load_prefix=load_prefix)
 
 
 @click.command()
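
The new --load-prefix option travels from the command line into the train() callback as a plain keyword. A minimal sketch of that wiring, using click's test runner so it runs without the rest of the repo (the command body and sample values are placeholders, not this commit's code):

```python
import click
from click.testing import CliRunner

@click.command()
@click.argument("experiment")
@click.argument("batch_size", type=click.INT)
@click.option("--load-prefix", default=None)
@click.argument("args", nargs=-1)
def train(experiment, batch_size, load_prefix, args):
    # In the real command this value is forwarded as ex.train(..., load_prefix=load_prefix).
    click.echo(f"load_prefix={load_prefix!r}")

runner = CliRunner()
result = runner.invoke(train, ["my_experiment", "32", "--load-prefix", "electra."])
print(result.output)  # load_prefix='electra.'
```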

chebai/experiments.py — 3 additions & 2 deletions

@@ -41,13 +41,14 @@ def model_kwargs(self, *args) -> Dict:
     def build_dataset(self, batch_size) -> datasets.XYBaseDataModule:
         raise NotImplementedError
 
-    def train(self, batch_size, *args):
+    def train(self, batch_size, *args, **kwargs):
         self.MODEL.run(
             self.dataset,
             self.MODEL.NAME,
             loss=self.LOSS,
             model_kwargs=self.model_kwargs(*args),
-            version=self.version
+            version=self.version,
+            **kwargs
         )
 
     def test(self, ckpt_path, *args):
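
Experiment.train() now accepts arbitrary keyword arguments and forwards them untouched to MODEL.run(). A standalone toy of that pass-through pattern (simplified stand-ins, not the real classes):

```python
def run(dataset, name, loss=None, model_kwargs=None, version=None, **kwargs):
    # kwargs carries anything train() did not name explicitly, e.g. load_prefix.
    return {"version": version, **(model_kwargs or {}), **kwargs}

def train(batch_size, *args, **kwargs):
    return run("dataset", "model", model_kwargs={"batch_size": batch_size}, **kwargs)

print(train(32, load_prefix="electra."))
# {'version': None, 'batch_size': 32, 'load_prefix': 'electra.'}
```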

chebai/models/base.py — 3 additions & 2 deletions

@@ -146,7 +146,8 @@ def run(
         model_kwargs: dict = None,
         loss=torch.nn.BCELoss,
         weighted=False,
-        version=None
+        version=None,
+        **kwargs
     ):
         if model_args is None:
             model_args = []
@@ -199,7 +200,7 @@ def run(
 
         # Calculate weights per class
 
-        net = cls(*model_args, loss_cls=loss, **model_kwargs)
+        net = cls(*model_args, loss_cls=loss, **model_kwargs, **kwargs)
 
         # Early stopping seems to be bugged right now with ddp accelerator :(
         es = EarlyStopping(
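
Merging two **-expansions in the cls(...) call works only while model_kwargs and the forwarded kwargs stay disjoint; Python rejects a keyword supplied by both at the call site. A quick illustration with a generic function (not repo code):

```python
def ctor(**kw):
    return kw

print(ctor(**{"loss_cls": "bce"}, **{"load_prefix": "electra."}))  # fine
# ctor(**{"load_prefix": "a"}, **{"load_prefix": "b"})
# TypeError: ctor() got multiple values for keyword argument 'load_prefix'
```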

chebai/models/electra.py — 6 additions & 1 deletion

@@ -150,10 +150,15 @@ def __init__(self, **kwargs):
         kwargs["config"]["num_labels"] = self.out_dim
         self.config = ElectraConfig(**kwargs["config"], output_attentions=True)
         self.word_dropout = nn.Dropout(kwargs["config"].get("word_dropout", 0))
+        model_prefix = kwargs.get("load_prefix", None)
         if pretrained_checkpoint:
             with open(pretrained_checkpoint, "rb") as fin:
                 model_dict = torch.load(fin, map_location=self.device)
-            self.electra = ElectraModel.from_pretrained(None, state_dict=model_dict['state_dict'], config=self.config)
+            if model_prefix:
+                state_dict = {str(k)[len(model_prefix):]: v for k, v in model_dict["state_dict"].items() if str(k).startswith(model_prefix)}
+            else:
+                state_dict = model_dict["state_dict"]
+            self.electra = ElectraModel.from_pretrained(None, state_dict=state_dict, config=self.config)
         else:
             self.electra = ElectraModel(config=self.config)
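
The key-renaming branch above can be exercised in isolation: keep only state-dict entries under the prefix and drop the prefix from their keys. A standalone sketch, assuming checkpoint keys like "electra.embeddings.word_embeddings.weight" (the helper name and sample keys are illustrative):

```python
import torch

def strip_prefix(state_dict, prefix):
    """Keep entries whose keys start with `prefix`, minus the prefix itself."""
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

ckpt = {
    "electra.embeddings.word_embeddings.weight": torch.zeros(3, 4),
    "output_layer.weight": torch.zeros(2, 4),  # dropped: no matching prefix
}
print(list(strip_prefix(ckpt, "electra.")))
# ['embeddings.word_embeddings.weight']
```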
