Conversation
…finite-difference hessian just diagonal).
…nstead of the sum likelihood.
Pull Request Overview
This PR implements autodifferentiation capabilities using JAX for computing gradients and Hessians of likelihood functions, supporting both Gaussian and Poisson likelihoods.
- Adds JAX as an optional dependency with fallback to existing implementations
- Updates likelihood API to support different computation methods (exact, finite difference, JAX autodiff)
- Modifies likelihood functions to return element-wise arrays instead of scalar sums
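The element-wise pattern lends itself to JAX: summing the per-observation terms yields the scalar that `jax.grad` differentiates. A minimal sketch for the Poisson case (the function name `poisson_loglik` is illustrative, not the PR's actual API):

```python
import jax
import jax.numpy as jnp

def poisson_loglik(eta, y):
    # Per-observation Poisson log-likelihood terms (the log y! constant is dropped):
    # y_i * eta_i - exp(eta_i), with eta the log-rate.
    return y * eta - jnp.exp(eta)

# Sum the element-wise terms to get the scalar objective JAX differentiates.
grad_fn = jax.grad(lambda eta, y: jnp.sum(poisson_loglik(eta, y)))

eta = jnp.array([0.1, -0.3, 0.5])
y = jnp.array([1.0, 0.0, 2.0])
g = grad_fn(eta, y)  # analytically: y - exp(eta)
```

Keeping the likelihood element-wise means the same function serves both the summed objective and any per-observation diagnostics.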
Reviewed Changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| src/dalia/core/likelihood.py | Adds base methods for gradient/Hessian computation with method selection and finite difference fallbacks |
| src/dalia/configs/likelihood_config.py | Extends configuration to include method selection for derivative computation |
| src/dalia/likelihoods/gaussian.py | Implements JAX autodiff methods and updates return types for element-wise computation |
| src/dalia/likelihoods/poisson.py | Adds JAX support and updates likelihood evaluation to return arrays |
| src/dalia/likelihoods/binomial.py | Updates likelihood return type for consistency |
| src/dalia/models/coregional_model.py | Updates method calls to use new likelihood API |
| src/dalia/core/model.py | Adapts to new likelihood interface and handles different return types |
| src/dalia/submodels/brainiac.py | Adds TODO comment about potential refactoring |
Comments suppressed due to low confidence (1)
src/dalia/likelihoods/binomial.py:22

The parameter order in the `super().__init__()` call is inconsistent with other likelihood classes in this PR. Based on the changes in gaussian.py and poisson.py, it should be `super().__init__(n_observations, config)`.

```python
super().__init__(config, n_observations)
```
```diff
     ) -> None:
         """Initializes the Poisson likelihood."""
-        super().__init__(config, n_observations)
+        super().__init__(n_observations, config)
```
The parameter order in the `super().__init__()` call is inconsistent with the parent class constructor. Based on the parent class definition, it should be `super().__init__(config, n_observations)` to match the expected signature.
```diff
-        super().__init__(n_observations, config)
+        super().__init__(config, n_observations)
```
```diff
         Gradient of the likelihood.
         """
-        pass
+        self.finite_difference_gradient_likelihood(eta, y, **kwargs)
```
These abstract method implementations are missing return statements. They should return the result of the finite difference computation.
```diff
-        self.finite_difference_gradient_likelihood(eta, y, **kwargs)
+        return self.finite_difference_gradient_likelihood(eta, y, **kwargs)
```
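The missing `return` matters because the base-class default would otherwise silently yield `None`. A self-contained sketch of the fallback pattern (class and method names follow the PR's apparent API but are reconstructed, not copied from the source):

```python
import numpy as np
from numpy.typing import NDArray

class Likelihood:
    """Sketch of the base-class finite-difference fallback (reconstructed)."""

    def evaluate_likelihood(self, eta: NDArray, y: NDArray, **kwargs) -> NDArray:
        raise NotImplementedError

    def evaluate_gradient_likelihood(self, eta: NDArray, y: NDArray, **kwargs) -> NDArray:
        # The default must *return* the fallback result, not merely call it.
        return self.finite_difference_gradient_likelihood(eta, y, **kwargs)

    def finite_difference_gradient_likelihood(
        self, eta: NDArray, y: NDArray, h: float = 1e-2, **kwargs
    ) -> NDArray:
        # Central differences on the summed likelihood, one coordinate at a time.
        grad = np.zeros_like(eta)
        for i in range(eta.size):
            step = np.zeros_like(eta)
            step[i] = h
            f_plus = np.sum(self.evaluate_likelihood(eta + step, y, **kwargs))
            f_minus = np.sum(self.evaluate_likelihood(eta - step, y, **kwargs))
            grad[i] = (f_plus - f_minus) / (2.0 * h)
        return grad
```

A subclass only needs to override `evaluate_likelihood` (or supply a JAX implementation) to get gradients for free.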
```diff
         Hessian of the likelihood.
         """
-        pass
+        self.finite_difference_hessian_likelihood(eta, y, **kwargs)
```
These abstract method implementations are missing return statements. They should return the result of the finite difference computation.
```diff
-        self.finite_difference_hessian_likelihood(eta, y, **kwargs)
+        return self.finite_difference_hessian_likelihood(eta, y, **kwargs)
```
```python
    def finite_difference_hessian_likelihood(
        self,
        eta: NDArray,
        y: NDArray,
        h: float = 1e-2,
        **kwargs,
    ) -> NDArray:
```
The method signature includes explicit `eta` and `y` parameters, but the abstract method `evaluate_hessian_likelihood` does not include them in its signature; the finite difference implementation expects these parameters to be passed explicitly.
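An earlier comment in this thread notes that the finite-difference Hessian is diagonal only. A hedged sketch of such a helper (name and exact signature are assumptions, not the PR's code), which is exact only when the likelihood factorizes over observations so each term depends on a single `eta_i`:

```python
import numpy as np

def finite_difference_hessian_diag(loglik, eta, y, h=1e-2):
    # Second-order central differences on the summed likelihood, diagonal only:
    #   d^2 f / d eta_i^2  ~=  (f(eta + h e_i) - 2 f(eta) + f(eta - h e_i)) / h^2
    # Off-diagonal entries are taken as zero, which holds when the likelihood
    # factorizes over observations.
    f0 = np.sum(loglik(eta, y))
    diag = np.zeros_like(eta)
    for i in range(eta.size):
        step = np.zeros_like(eta)
        step[i] = h
        f_plus = np.sum(loglik(eta + step, y))
        f_minus = np.sum(loglik(eta - step, y))
        diag[i] = (f_plus - 2.0 * f0 + f_minus) / h**2
    return np.diag(diag)
```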
Adds the capability to use JAX for auto-differentiation (gradient and Hessian likelihood). So far, the following likelihoods are supported:
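For the Gaussian case, a minimal illustration of obtaining both derivatives with JAX (unit variance, constants dropped; `gaussian_loglik` is an illustrative name, not the package's API):

```python
import jax
import jax.numpy as jnp

def gaussian_loglik(eta, y):
    # Per-observation Gaussian log-likelihood terms (unit variance, constants dropped).
    return -0.5 * (y - eta) ** 2

# Differentiate the scalar sum of the element-wise terms.
total = lambda eta, y: jnp.sum(gaussian_loglik(eta, y))
grad_fn = jax.grad(total)       # analytically: y - eta
hess_fn = jax.hessian(total)    # analytically: -I

eta = jnp.array([0.0, 1.0, 2.0])
y = jnp.array([0.5, 1.0, 1.5])
g = grad_fn(eta, y)
H = hess_fn(eta, y)
```

Because the log-likelihood is a sum of per-observation terms, `jax.hessian` returns a diagonal matrix here, matching the diagonal finite-difference fallback.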