-
Notifications
You must be signed in to change notification settings - Fork 7
Enhance Inference #62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR enhances the inference module with three main improvements: fixing a torch tensor conversion issue, adding multi-chain plotting capabilities with customizable colors, and implementing a function to determine fixed parameters based on model specifications.
Key Changes
- Fixed torch tensor to numpy conversion by adding
.detach()before.numpy()call - Enhanced
plot_triangleto support multiple chains with color customization and updated docstrings to NumPy format - Added
get_fixed_paramsfunction to parse model strings and identify fixed vs. free parameters
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 11 comments.
| File | Description |
|---|---|
| sunbird/inference/pocomc.py | Added torch tensor detachment before numpy conversion in log_likelihood method |
| sunbird/inference/priors/priors.py | Implemented get_fixed_params function to parse cosmology and HOD model strings and determine fixed parameters |
| sunbird/inference/samples/chain.py | Refactored plotting methods to support multiple chains, added plot_map method, updated docstrings to NumPy style, and added type hints |
Comments suppressed due to low confidence (2)
sunbird/inference/samples/chain.py:29
- Class methods or methods of a type deriving from type should have 'cls', rather than 'self', as their first parameter.
def load(self, filename: str):
sunbird/inference/samples/chain.py:47
- Class methods or methods of a type deriving from type should have 'cls', rather than 'self', as their first parameter.
def to_getdist(self, chain, add_derived: bool = False, **kwargs) -> MCSamples:
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| percentile_list = [percentile, percentile] if isinstance(percentile, int) else percentile | ||
| colors = kwargs.get('colors', ['k'] + [f'C{i}' for i in range(len(chains)-1)]) | ||
| label_dict = kwargs.get('label_dict', {}) | ||
| params = kwargs.get('params', self.names) | ||
| names = [p for p in params if p in self.names] | ||
| labels = {k: v for k, v in zip(self.names, self.labels)} | ||
| labels = [labels[n] for n in names] | ||
| chain_labels = [label_dict.get(chain.data.get('label', ''), chain.data.get('label', None)) for chain in chains] # replace label with actual name if provided | ||
|
|
||
| fig, ax = plt.subplots(1, len(names), sharey=True, figsize=(3*len(names), 3)) | ||
|
|
||
| for i, chain in enumerate(chains): | ||
| maxl = chain.samples[chain.loglike.argmax()] | ||
| mean = chain.samples.mean(axis=0) | ||
| percentiles = np.percentile(chain.samples, percentile_list, axis=0) |
Copilot
AI
Dec 20, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The percentile calculation is incorrect. The description states percentile_list contains the percentiles (e.g., [5, 95] for 95% confidence), but np.percentile expects percentile values directly. For a 95% confidence interval, you would want the 2.5th and 97.5th percentiles, not [95, 95]. The current code with percentile=95 will compute [95th, 95th] percentiles which doesn't create a proper confidence interval.
| Plots the Maximum A Posteriori (MAP) point, the mean and the 95% confidence interval (default) for each parameter in the chain. | ||
| Parameters | ||
| ---------- | ||
| percentile : int | list, optional | ||
| The percentile to use for the error bars. | ||
| If an integer is provided, the error bars will both be the same and correspond to the given percentile. | ||
| If a list is provided, the first element will be the lower percentile and the second element the upper percentile. | ||
| Defaults to 95. |
Copilot
AI
Dec 20, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The docstring states that percentile is used for "95% confidence interval (default)" and the parameter description says "the error bars will both be the same and correspond to the given percentile." This is misleading because a confidence interval requires two different percentiles (e.g., 2.5th and 97.5th for 95% CI). The current implementation with percentile=95 would compute [95th, 95th] which is not a confidence interval. The documentation should clarify what percentile value should be used, or the implementation should be changed to accept a confidence level and calculate the appropriate percentiles.
| Plots the Maximum A Posteriori (MAP) point, the mean and the 95% confidence interval (default) for each parameter in the chain. | |
| Parameters | |
| ---------- | |
| percentile : int | list, optional | |
| The percentile to use for the error bars. | |
| If an integer is provided, the error bars will both be the same and correspond to the given percentile. | |
| If a list is provided, the first element will be the lower percentile and the second element the upper percentile. | |
| Defaults to 95. | |
| Plots the Maximum A Posteriori (MAP) point, the mean and percentile-based error bars for each parameter in the chain. | |
| Parameters | |
| ---------- | |
| percentile : int | list, optional | |
| The percentile(s) to use for the error bars. | |
| If an integer is provided, the same percentile is used for both sides of the error bar (i.e. a degenerate interval at that percentile). | |
| If a list is provided, the first element is taken as the lower percentile and the second element as the upper percentile | |
| (e.g. [2.5, 97.5] for a central 95% confidence interval). | |
| Defaults to 95, which corresponds to using the 95th percentile on both sides. |
sunbird.inference.samples.chainto accept several chains in the triangle plot (with colors!), and updated docstrings (Enhance Chains plots #61 )sunbird.inference.priors(Rework priors #56 final implementation)