Conversation

@kanodiaayush (Contributor):

Temporary working commit for Tianyu to review.


self.variation = variation

assert distribution in ['gaussian', 'gamma'], f'Unsupported distribution {distribution}'
Collaborator:

Suggested change
- assert distribution in ['gaussian', 'gamma'], f'Unsupported distribution {distribution}'
+ if distribution not in ['gaussian', 'gamma']:
+     raise ValueError(f'Unsupported distribution {distribution}')
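
As an aside on the design choice (not part of the original thread): assert statements are stripped entirely when Python runs with the -O flag, so explicit raises are the safer pattern for validating user-supplied arguments. A minimal sketch of the pattern, with check_distribution as a hypothetical helper name:

VALID_DISTRIBUTIONS = ('gaussian', 'gamma')

def check_distribution(distribution: str) -> None:
    # Unlike assert, this still fires under `python -O`.
    if distribution not in VALID_DISTRIBUTIONS:
        raise ValueError(f'Unsupported distribution {distribution}')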


assert distribution in ['gaussian', 'gamma'], f'Unsupported distribution {distribution}'
if distribution == 'gamma':
    '''
Collaborator:

Did you comment out this code chunk using '''? Can you remove it if we don't need it here?

Contributor Author:

Yes, this is an artefact of this being a temporary commit.

torch.Tensor: the log prior of the variable with shape (num_seeds, num_classes).
"""
# p(sample)
# DEBUG_MARKER
Collaborator:

Remove the debug marker?

concentration = torch.exp(mu)
rate = self.prior_variance
out = Gamma(concentration=concentration,
            rate=rate).log_prob(sample)
Collaborator:

This might be a bug! The rate and prior_variance are different quantities; can you double-check it?

Contributor Author:

Yes, for the gamma case I meant prior_variance to represent the rate. Earlier I also set prior_variance to the rate, so this should be correct.

Contributor Author:

Thanks for flagging it though.

Collaborator:

Is prior_variance the rate of the prior or the log(rate) of the prior?
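
To make the two readings concrete (illustrative only; the thread does not settle which convention the code intends), they differ by an exponential:

import torch
from torch.distributions import Gamma

mu = torch.zeros(3)
sample = torch.ones(3)
prior_variance = torch.tensor(1.0)

# Reading 1: prior_variance is the rate itself (what the code above does).
lp_rate = Gamma(concentration=torch.exp(mu),
                rate=prior_variance).log_prob(sample)

# Reading 2: prior_variance is log(rate) and must be exponentiated,
# mirroring how variational_logstd is exponentiated further down.
lp_log_rate = Gamma(concentration=torch.exp(mu),
                    rate=torch.exp(prior_variance)).log_prob(sample)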

    rate = torch.exp(self.variational_logstd)
    return Gamma(concentration=concentration, rate=rate)
else:
    raise NotImplementedError("Unknown variational distribution type.")
Collaborator:

Suggested change
- raise NotImplementedError("Unknown variational distribution type.")
+ raise NotImplementedError(f"Unknown variational distribution type {self.distribution}.")

def is_observable(name: str) -> bool:
return any(name.startswith(prefix) for prefix in observable_prefix)
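
For example, with observable_prefix = ('price_', 'obs_') (hypothetical values), is_observable('price_milk') returns True and is_observable('theta_user') returns False.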

utility_string = utility_string.replace(' - ', ' + -')  # rewrite subtraction as addition of a negated term
Collaborator:

We might want to find a way to parse utilities even when the user does not put spaces around + or -; let's do this later.
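
One possible shape for that later fix (purely illustrative; split_utility_terms is a hypothetical helper, though the 'sign' key matches the term dicts used below):

import re

def split_utility_terms(utility_string: str):
    # Tokenize 'a*x - b*y + c' into signed terms, tolerating
    # missing spaces around '+' and '-'.
    terms = []
    for sign, body in re.findall(r'([+-]?)\s*([^+-]+)', utility_string):
        terms.append({'sign': -1 if sign == '-' else 1, 'term': body.strip()})
    return terms

print(split_utility_terms('alpha*price-beta*distance + gamma'))
# [{'sign': 1, 'term': 'alpha*price'}, {'sign': -1, 'term': 'beta*distance'}, {'sign': 1, 'term': 'gamma'}]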

Contributor Author:

Agreed.

coef = (coef_sample_0 * coef_sample_1).sum(dim=-1)

additive_term = (coef * obs).sum(dim=-1)
additive_term *= term['sign']
Collaborator:

Can we just apply a single additive_term *= term['sign'] outside the if/elif chain?
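
What that refactor would look like, sketched with hypothetical shapes and a hypothetical branch condition standing in for the real ones:

import torch

term = {'sign': -1}                     # hypothetical term spec
term_is_interaction = True              # hypothetical branch condition
obs = torch.randn(3)
coef_sample_0 = torch.randn(3, 5)
coef_sample_1 = torch.randn(3, 5)

if term_is_interaction:
    coef = (coef_sample_0 * coef_sample_1).sum(dim=-1)
    additive_term = (coef * obs).sum(dim=-1)
else:
    additive_term = (coef_sample_0.sum(dim=-1) * obs).sum(dim=-1)

# single sign application after the chain, instead of one per branch
additive_term = additive_term * term['sign']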

Contributor Author:

Yes, that should be possible, but I'll make sure.

loss = -elbo  # minimizing the negative ELBO maximizes the evidence lower bound
return loss

# DEBUG_MARKER
Collaborator:

Let's remove the debug marker.

Collaborator:

I did not review the configurations and main in your supermarket-specific script.

Contributor Author:

Yes, I'll take care of those.
