Refactor initial sample generation into reusable utility function #239

thomasckng wants to merge 1 commit into jim-dev from …
Conversation
… utility function
Pull Request Overview
This PR refactors duplicated initial sample generation logic into a reusable utility function. The refactoring extracts common code that was repeated across multiple files for generating valid MCMC initial samples by repeatedly sampling from a prior and applying transforms until all values are finite.
- Consolidates duplicate initial sampling logic from multiple files into a single utility function
- Updates import statements to include the new utility function
- Replaces inline sampling loops with calls to the new `generate_initial_samples` function
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| src/jimgw/core/utils.py | Adds the new generate_initial_samples utility function and required imports |
| src/jimgw/core/single_event/likelihood.py | Replaces inline sampling logic with call to new utility function |
| src/jimgw/core/jim.py | Replaces inline sampling logic with call to new utility function and updates imports |
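For readers who want the shape of the extracted helper, the diffs below all repeat the same resample-until-finite loop. Here is a minimal NumPy sketch of that pattern; the actual function in `src/jimgw/core/utils.py` uses JAX and the repo's prior/transform objects, and the names `sample_fn`, `transform_fns`, and `log_shift` below are illustrative stand-ins, not the repo's API:

```python
import numpy as np

def generate_initial_samples(sample_fn, transform_fns, n_samples, max_tries=100):
    """Draw samples, apply transforms, and redraw rows that come out non-finite."""
    guess_dict = sample_fn(n_samples)
    for transform in transform_fns:
        guess_dict = transform(guess_dict)
    position = np.array(list(guess_dict.values())).T  # shape (n_samples, n_params)

    for _ in range(max_tries):
        non_finite = np.where(~np.isfinite(position).all(axis=1))[0]
        if non_finite.size == 0:
            break
        # Redraw only as many candidates as there are invalid rows
        guess_dict = sample_fn(non_finite.size)
        for transform in transform_fns:
            guess_dict = transform(guess_dict)
        guess = np.array(list(guess_dict.values())).T
        finite_rows = np.where(np.isfinite(guess).all(axis=1))[0]
        n = min(finite_rows.size, non_finite.size)
        # Fill invalid slots with finite candidate rows only
        position[non_finite[:n]] = guess[finite_rows[:n]]
    return position

rng = np.random.default_rng(0)

def sample_fn(n):
    return {"x": rng.normal(size=n), "y": rng.uniform(size=n)}

def log_shift(d):
    # Hypothetical transform: NaN/-inf whenever y <= 0.5, mimicking a
    # transform that maps part of the prior support outside its domain
    with np.errstate(divide="ignore", invalid="ignore"):
        return {"x": d["x"], "y": np.log(d["y"] - 0.5)}

pos = generate_initial_samples(sample_fn, [log_shift], 8)
```

Note the sketch already folds in two of the review suggestions below: it redraws only `non_finite.size` candidates per retry and fills invalid slots from `finite_rows` rather than from the head of the batch.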
```python
guess = prior.sample(subkey, n_samples)
for transform in sample_transforms:
    guess = jax.vmap(transform.forward)(guess)
guess = jnp.array(list(guess.values())).T
```
Converting dictionary values to a list may not preserve consistent ordering across Python versions. This could lead to parameter misalignment. Consider using a consistent key ordering or parameter names list to ensure deterministic behavior.
```diff
- guess = jnp.array(list(guess.values())).T
+ guess = jnp.array([guess[key] for key in sorted(guess.keys())]).T
```
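To see the reviewer's point concretely, here is a toy NumPy example (the parameter names `m1`/`m2` are made up): the column order of the stacked array follows the dict's insertion order, which depends on how the transforms happened to build the dict, whereas sorting the keys pins it down:

```python
import numpy as np

# A transform could emit keys in either order; here "m2" was inserted first
samples = {"m2": np.array([1.0, 2.0]), "m1": np.array([3.0, 4.0])}

by_insertion = np.array(list(samples.values())).T                  # columns: m2, m1
by_sorted_key = np.array([samples[k] for k in sorted(samples)]).T  # columns: m1, m2
```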
```python
)[0]
common_length = min(len(finite_guess), len(non_finite_index))
initial_position = initial_position.at[non_finite_index[:common_length]].set(
    guess[:common_length]
```
Indexing `guess[:common_length]` takes the first `common_length` rows, but it should take the rows corresponding to `finite_guess[:common_length]` to ensure only finite samples are used.
```diff
- guess[:common_length]
+ guess[finite_guess[:common_length]]
```
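A small NumPy repro of the bug this suggestion fixes: when the first rows of the batch happen to be the non-finite ones, slicing from the head reinserts NaNs, while indexing through `finite_guess` selects only valid rows:

```python
import numpy as np

guess = np.array([[np.nan, 1.0],   # row 0: non-finite
                  [2.0,    3.0],   # row 1: finite
                  [4.0,    5.0]])  # row 2: finite
finite_guess = np.where(np.isfinite(guess).all(axis=1))[0]  # array([1, 2])
common_length = 2

wrong = guess[:common_length]                # first two rows, NaN included
fixed = guess[finite_guess[:common_length]]  # rows 1 and 2, all finite
```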
```python
guess = prior.sample(subkey, n_samples)
for transform in sample_transforms:
    guess = jax.vmap(transform.forward)(guess)
guess = jnp.array(list(guess.values())).T
finite_guess = jnp.where(
    jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess), axis=1)
)[0]
common_length = min(len(finite_guess), len(non_finite_index))
initial_position = initial_position.at[non_finite_index[:common_length]].set(
    guess[:common_length]
)
```
Sampling `n_samples` every iteration is inefficient when only a few samples need to be replaced. Consider sampling only `len(non_finite_index)` samples to reduce unnecessary computation.
```diff
- guess = prior.sample(subkey, n_samples)
- for transform in sample_transforms:
-     guess = jax.vmap(transform.forward)(guess)
- guess = jnp.array(list(guess.values())).T
- finite_guess = jnp.where(
-     jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess), axis=1)
- )[0]
- common_length = min(len(finite_guess), len(non_finite_index))
- initial_position = initial_position.at[non_finite_index[:common_length]].set(
-     guess[:common_length]
- )
+ num_invalid = len(non_finite_index)  # Number of invalid samples
+ if num_invalid > 0:  # Only sample if there are invalid samples
+     guess = prior.sample(subkey, num_invalid)
+     for transform in sample_transforms:
+         guess = jax.vmap(transform.forward)(guess)
+     guess = jnp.array(list(guess.values())).T
+     finite_guess = jnp.where(
+         jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess), axis=1)
+     )[0]
+     common_length = min(len(finite_guess), len(non_finite_index))
+     initial_position = initial_position.at[non_finite_index[:common_length]].set(
+         guess[:common_length]
+     )
```
No description provided.