Skip to content

Autodiff with JAX#122

Open
alexnick83 wants to merge 5 commits intodevfrom
likelihood-jax
Open

Autodiff with JAX#122
alexnick83 wants to merge 5 commits intodevfrom
likelihood-jax

Conversation

@alexnick83
Copy link
Collaborator

Adds the capability to use JAX for auto-differentiation (gradient and Hessian likelihood). So far, the following likelihoods are supported:

  • Gaussian
  • Poisson

@vincent-maillou vincent-maillou requested review from Copilot, lisa-gm and vincent-maillou and removed request for Copilot, lisa-gm and vincent-maillou September 12, 2025 14:32
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR implements autodifferentiation capabilities using JAX for computing gradients and Hessians of likelihood functions, supporting both Gaussian and Poisson likelihoods.

  • Adds JAX as an optional dependency with fallback to existing implementations
  • Updates likelihood API to support different computation methods (exact, finite difference, JAX autodiff)
  • Modifies likelihood functions to return element-wise arrays instead of scalar sums

Reviewed Changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
src/dalia/core/likelihood.py Adds base methods for gradient/Hessian computation with method selection and finite difference fallbacks
src/dalia/configs/likelihood_config.py Extends configuration to include method selection for derivative computation
src/dalia/likelihoods/gaussian.py Implements JAX autodiff methods and updates return types for element-wise computation
src/dalia/likelihoods/poisson.py Adds JAX support and updates likelihood evaluation to return arrays
src/dalia/likelihoods/binomial.py Updates likelihood return type for consistency
src/dalia/models/coregional_model.py Updates method calls to use new likelihood API
src/dalia/core/model.py Adapts to new likelihood interface and handles different return types
src/dalia/submodels/brainiac.py Adds TODO comment about potential refactoring
Comments suppressed due to low confidence (1)

src/dalia/likelihoods/binomial.py:22

  • The parameter order in the super().init() call is inconsistent with other likelihood classes in this PR. Based on the changes in gaussian.py and poisson.py, it should be super().__init__(n_observations, config).
        super().__init__(config, n_observations)

Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.

) -> None:
"""Initializes the Poisson likelihood."""
super().__init__(config, n_observations)
super().__init__(n_observations, config)
Copy link

Copilot AI Sep 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The parameter order in the super().init() call is inconsistent with the parent class constructor. Based on the parent class definition, it should be super().__init__(config, n_observations) to match the expected signature.

Suggested change
super().__init__(n_observations, config)
super().__init__(config, n_observations)

Copilot uses AI. Check for mistakes.
Gradient of the likelihood.
"""
pass
self.finite_difference_gradient_likelihood(eta, y, **kwargs)
Copy link

Copilot AI Sep 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These abstract method implementations are missing return statements. They should return the result of the finite difference computation.

Suggested change
self.finite_difference_gradient_likelihood(eta, y, **kwargs)
return self.finite_difference_gradient_likelihood(eta, y, **kwargs)

Copilot uses AI. Check for mistakes.
Hessian of the likelihood.
"""
pass
self.finite_difference_hessian_likelihood(eta, y, **kwargs)
Copy link

Copilot AI Sep 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These abstract method implementations are missing return statements. They should return the result of the finite difference computation.

Suggested change
self.finite_difference_hessian_likelihood(eta, y, **kwargs)
return self.finite_difference_hessian_likelihood(eta, y, **kwargs)

Copilot uses AI. Check for mistakes.
Comment on lines +214 to +220
def finite_difference_hessian_likelihood(
self,
eta: NDArray,
y: NDArray,
h: float = 1e-2,
**kwargs,
) -> NDArray:
Copy link

Copilot AI Sep 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The method signature includes eta and y parameters, but the abstract method evaluate_hessian_likelihood doesn't include these parameters in its kwargs. The finite difference implementation expects these parameters to be passed explicitly.

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants