Description
Caskade's `get_values` and `set_values` are nice for abstracting away the params tensor (or list, or dict), but as a result a given index has no meaning to the user. Ideally, there would be helper functions like:
- `find_index(param)`, which identifies which parts of the params tensor correspond to a given param, and
- `find_param(idx)`, which identifies the param associated with a given params tensor index.
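A minimal sketch of the two lookups, assuming the params tensor is simply every param's values concatenated in a fixed order. `Param`, `build_layout`, and the `layout` argument are hypothetical stand-ins for illustration, not caskade's actual classes or API:

```python
from dataclasses import dataclass
from math import prod


# Hypothetical stand-in for a caskade param: just a name and a shape.
# eq=False keeps identity-based hashing, so two params with the same
# name and shape are still distinct dict keys.
@dataclass(eq=False)
class Param:
    name: str
    shape: tuple = ()


def build_layout(params):
    """Map each param to its [start, stop) range in a flat params tensor,
    assuming params are concatenated in the order given."""
    layout, start = {}, 0
    for p in params:
        n = prod(p.shape)  # prod(()) == 1, so a scalar takes one slot
        layout[p] = range(start, start + n)
        start += n
    return layout


def find_index(layout, param):
    """Return the flat indices of the params tensor that hold `param`."""
    return list(layout[param])


def find_param(layout, idx):
    """Return the param whose range contains flat index `idx`."""
    for p, r in layout.items():
        if idx in r:
            return p
    raise IndexError(f"index {idx} is outside the params tensor")
```

For example, with a scalar `x` followed by a 2x3 `w`, `x` sits at index 0, `w` covers indices 1 to 6, and index 4 maps back to `w`. A linear scan is fine for a handful of params; a bisect over the range starts would scale better.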
There are some extra things to think about too:
- For `find_index`, if the param is not a scalar, it will correspond to many indices.
- For `find_param`, if the param is not a scalar, it would be good to also know which element of the param is associated with `idx`.
- For both, we would need to decide what to do when there are param groups.
- For `find_index`, if the user provides a `Module`, we could return the indices of all of that module's child params.
- For both, we could allow list inputs rather than just single values.
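A rough sketch of how some of these extensions could fit together, again with hypothetical `Param` and `Module` stand-ins rather than caskade's real classes. It assumes each non-scalar param is flattened in row-major (C) order, collects a module's params depth-first, and accepts lists for both lookups. Param groups are not handled here:

```python
from dataclasses import dataclass, field
from math import prod


@dataclass(eq=False)
class Param:  # hypothetical stand-in, not caskade.Param
    name: str
    shape: tuple = ()


@dataclass(eq=False)
class Module:  # hypothetical stand-in, not caskade.Module
    name: str
    children: list = field(default_factory=list)  # Params or sub-Modules


def leaf_params(node):
    """Yield every Param under `node` (or `node` itself) depth-first."""
    if isinstance(node, Param):
        yield node
    else:
        for child in node.children:
            yield from leaf_params(child)


def build_layout(root):
    """Assign each leaf Param a contiguous range in the flat params tensor."""
    layout, start = {}, 0
    for p in leaf_params(root):
        n = prod(p.shape)
        layout[p] = range(start, start + n)
        start += n
    return layout


def unravel(offset, shape):
    """Turn a flat offset into a multi-index, assuming row-major order."""
    index = []
    for dim in reversed(shape):
        offset, r = divmod(offset, dim)
        index.append(r)
    return tuple(reversed(index))  # () for a scalar param


def find_index(layout, target):
    """Flat indices for a Param, all Params under a Module, or a list of these."""
    if isinstance(target, list):
        return [i for t in target for i in find_index(layout, t)]
    return [i for p in leaf_params(target) for i in layout[p]]


def find_param(layout, idx):
    """(param, index within that param) for a flat index, or a list of indices."""
    if isinstance(idx, list):
        return [find_param(layout, i) for i in idx]
    for p, r in layout.items():
        if idx in r:
            return p, unravel(idx - r.start, p.shape)
    raise IndexError(f"index {idx} is outside the params tensor")
```

With a module `lens` holding a scalar `x` and a 2x3 `w`, plus a sibling scalar `z`, asking for `lens` returns indices 0 to 6, and flat index 5 resolves to `w` at element `(1, 1)`. Returning `(param, element_index)` pairs keeps the scalar and non-scalar cases uniform, at the cost of an always-present tuple.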