Open
Description
Currently, the schedule logic for GPT-2 looks like this:
# code snippet from slapo/model_schedule/gpt2.py
...
attn_op = []
for idx in range(model_config.num_hidden_layers):
    sub_sch = sch[attn_path.replace("N", str(idx))]
    with init_empty_weights(enable=delay_init):
        new_mod = Attention(**init_config)
        attn_op.append(new_mod.module.attn_op_name)
    sub_sch.replace(new_mod)
    cnt += 1
...
Issues with this code snippet, from my view:
- The primitive function `replace` does not support type inference: it is difficult to discover which options the primitive accepts, and its docstring is not accessible.
- Selecting a subgraph is not intuitive because of the sub-schedule concept. Treating the schedule as a dictionary/hash table is not intuitive to me. For a single model, it seems natural to have a single schedule. That schedule may affect only part of the model, and it can be viewed as a list of tuples, e.g., `(module_part_id, "replace_with", new_module_obj)`. This also facilitates debugging the schedule, e.g., removing entries from the schedule disables the corresponding steps.
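The list-of-tuples view in the second bullet can be sketched roughly as follows. This is only an illustration under stated assumptions, not Slapo's API: `ScheduleEntry`, `apply_schedule`, and `make_model` are hypothetical names, and plain `SimpleNamespace` objects and strings stand in for real modules.

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, List


@dataclass(frozen=True)
class ScheduleEntry:
    """One scheduling step (hypothetical structure, not part of Slapo)."""
    module_path: str  # dotted path of the sub-module to act on
    primitive: str    # primitive name; only "replace_with" is handled here
    payload: Any      # argument of the primitive, e.g. the new module object


def apply_schedule(model: Any, schedule: List[ScheduleEntry]) -> Any:
    """Apply every entry in order, mutating and returning `model`."""
    for entry in schedule:
        if entry.primitive != "replace_with":
            raise ValueError(f"unsupported primitive: {entry.primitive}")
        *parents, leaf = entry.module_path.split(".")
        owner = model
        for name in parents:  # walk down to the parent of the target
            owner = getattr(owner, name)
        setattr(owner, leaf, entry.payload)
    return model


def make_model() -> SimpleNamespace:
    # Toy two-layer model; strings stand in for attention modules.
    return SimpleNamespace(
        layer0=SimpleNamespace(attn="slow_attn"),
        layer1=SimpleNamespace(attn="slow_attn"),
    )


schedule = [
    ScheduleEntry("layer0.attn", "replace_with", "fast_attn"),
    ScheduleEntry("layer1.attn", "replace_with", "fast_attn"),
]

full = apply_schedule(make_model(), schedule)
# Debugging: drop the second entry to disable that replacement.
partial = apply_schedule(make_model(), schedule[:1])

print(full.layer0.attn, full.layer1.attn)        # fast_attn fast_attn
print(partial.layer0.attn, partial.layer1.attn)  # fast_attn slow_attn
```

Because the schedule is plain data, disabling a step is a list operation rather than a code change, which is the debugging benefit mentioned above.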
I recommend changing the APIs to the following:
...
for idx in range(model_config.num_hidden_layers):
    sub_module = slapo.select(model, "transformer.h."+str(idx))
    with init_empty_weights(enable=delay_init):
        new_mod = Attention(**init_config)
    cur_schedule = slapo.replace(cur_schedule, sub_module, new_mod)
    cnt += 1
...
The select method can be further improved to support fuzzy matching:
...
sub_modules = slapo.select(model, "transformer.h.*")
with init_empty_weights(enable=delay_init):
    new_mod = Attention(**init_config)
cur_schedule = slapo.replace(cur_schedule, sub_modules, new_mod)
cnt = len(sub_modules)
...
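To make the proposal more concrete, here is a minimal sketch of how a wildcard `select` and a functional `replace` might behave. Everything in it is an assumption for illustration: the function bodies, the `Schedule` alias, and the `named_paths` helper are hypothetical, `select` is assumed to take the model as its first argument, `*` is assumed to match exactly one path segment, and `SimpleNamespace` objects stand in for real modules.

```python
from fnmatch import fnmatchcase
from types import SimpleNamespace
from typing import Any, List, Tuple

# Hypothetical schedule type: an immutable sequence of (path, primitive, payload).
Schedule = Tuple[Tuple[str, str, Any], ...]


def named_paths(module: Any, prefix: str = "") -> List[str]:
    """List dotted paths of all nested namespaces, depth first."""
    paths = []
    for name, child in vars(module).items():
        if isinstance(child, SimpleNamespace):
            path = f"{prefix}.{name}" if prefix else name
            paths.append(path)
            paths.extend(named_paths(child, path))
    return paths


def select(model: Any, pattern: str) -> List[str]:
    """Return sub-module paths matching `pattern`, segment by segment.

    A `*` segment matches one path segment, so "transformer.h.*" selects
    the layers but not their children.
    """
    want = pattern.split(".")
    selected = []
    for path in named_paths(model):
        got = path.split(".")
        if len(got) == len(want) and all(
            fnmatchcase(g, w) for g, w in zip(got, want)
        ):
            selected.append(path)
    return selected


def replace(schedule: Schedule, targets: List[str], new_module: Any) -> Schedule:
    """Return a new schedule with one "replace_with" entry per target.

    The input schedule is not modified.
    """
    return schedule + tuple((t, "replace_with", new_module) for t in targets)


# Toy model shaped like transformer.h.{0,1,2}.attn plus transformer.ln_f.
model = SimpleNamespace(
    transformer=SimpleNamespace(h=SimpleNamespace(), ln_f=SimpleNamespace())
)
for i in range(3):
    setattr(model.transformer.h, str(i), SimpleNamespace(attn=SimpleNamespace()))

sub_modules = select(model, "transformer.h.*")
cur_schedule = replace((), sub_modules, "new_attention")
cnt = len(sub_modules)
print(sub_modules)  # ['transformer.h.0', 'transformer.h.1', 'transformer.h.2']
print(cnt)          # 3
```

Since `select` and `replace` are ordinary module-level functions with type hints and docstrings, editors can show their signatures and documentation, which addresses the first issue raised above.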