-
Notifications
You must be signed in to change notification settings - Fork 8
[WIP] small changes for methods.gradient_ascent #7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
[WIP] small changes for methods.gradient_ascent #7
Conversation
update master
update master
update master.
christoph-blessing
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please run black over this and check if all of the tests still pass?
mei/modules.py
Outdated
| """ | ||
|
|
||
| def __init__(self, model: Module, constraint: int, forward_kwargs: Dict[str, Any] = None): | ||
| def __init__(self, model: Module, constraint: int, target_fn=None, forward_kwargs: Dict[str, Any] = None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add a type annotation for the new argument.
| model: A PyTorch module. | ||
| constraint: An integer representing the index of a neuron in the model's output. Only the value corresponding | ||
| to that index will be returned. | ||
| target_fn: Callable, that gets as an input the constrained output of the model. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this new argument needed?
mei/modules.py
Outdated
| """ | ||
| output = self.model(x, *args, **self.forward_kwargs, **kwargs) | ||
| return output[:, self.constraint] | ||
| return self.target_fn(output[:, self.constraint]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should be tested if the model output gets passed to the function and if the return value of the function gets returned.
|
|
||
|
|
||
| class RandomNormalNullChannel(InitialGuessCreator): | ||
| """Used to create an initial guess tensor filled with values distributed according to a normal distribution.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This docstring is the same as the one for the RandomNormal initial guess creator. This might confuse people as to what the differences between the two are.
|
|
||
| _create_random_tensor = randn | ||
|
|
||
| def __init__(self, null_channel, null_value=0): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would like to see type annotations here.
| self.null_value = null_value | ||
|
|
||
| def __call__(self, *shape): | ||
| """Creates a random initial guess from which to start the MEI optimization process given a shape.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is also the same docstring as the one of the __call__ method of the RandomNormal` initial guess creator.
mei/initial.py
Outdated
| return inital | ||
|
|
||
| def __repr__(self): | ||
| return f"{self.__class__.__qualname__}()" No newline at end of file |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The file is missing a newline at the end
Fix Bug in Objective Table naming.
No description provided.