tests - added additional tests for conditional samplers and toml updates #27
Open
krishjp wants to merge 1 commit into extropic-ai:main from krishjp:main
@@ -0,0 +1,200 @@

```python
import jax
import jax.numpy as jnp

from thrml.conditional_samplers import BernoulliConditional, SoftmaxConditional
from thrml.pgm import SpinNode
from thrml.block_management import Block, BlockSpec, block_state_to_global
from thrml.models.ising import IsingEBM, IsingSamplingProgram, hinton_init
from thrml.models.discrete_ebm import SpinEBMFactor, SpinGibbsConditional
from thrml.block_sampling import BlockGibbsSpec


def test_bernoulli_conditional_sample_shapes_and_dtype():
    """Basic shape/dtype checks for BernoulliConditional behaviour."""

    class ConstBern(BernoulliConditional):
        def compute_parameters(self, key, interactions, active_flags, states, sampler_state, output_sd):
            return jnp.array([0.0, 10.0, -10.0]), None

    sampler = ConstBern()
    key = jax.random.PRNGKey(0)
    output_sd = jax.ShapeDtypeStruct((3,), dtype=jnp.bool_)

    sample, state = sampler.sample(key, [], [], [], None, output_sd)

    assert isinstance(sample, jnp.ndarray)
    assert sample.shape == (3,)
    assert sample.dtype == jnp.bool_
    assert state is None


def test_bernoulli_sample_given_parameters_consistent_dtype():
    class ConstBern(BernoulliConditional):
        def compute_parameters(self, *args, **kwargs):
            return jnp.zeros((3,)), None

    sampler = ConstBern()
    params = jnp.array([100.0, -100.0, 0.0])
    output_sd = jax.ShapeDtypeStruct((3,), dtype=jnp.bool_)

    sample, state = sampler.sample_given_parameters(jax.random.PRNGKey(1), params, None, output_sd)
    assert sample.shape == (3,)
    assert sample.dtype == jnp.bool_
    assert state is None


def test_bernoulli_conditional_sampling_bias():
    """Verify that the Bernoulli sampler respects parameter biases.

    A large positive parameter should bias toward True, a large negative one
    toward False, and zero should give roughly even frequencies.
    """

    class ConstBern(BernoulliConditional):
        def compute_parameters(self, key, interactions, active_flags, states, sampler_state, output_sd):
            return jnp.array([100.0, -100.0, 0.0]), None

    sampler = ConstBern()
    output_sd = jax.ShapeDtypeStruct((3,), dtype=jnp.bool_)

    key = jax.random.PRNGKey(42)
    samples_list = []
    for _ in range(100):
        key, subkey = jax.random.split(key)
        sample, _ = sampler.sample(subkey, [], [], [], None, output_sd)
        samples_list.append(sample)

    samples = jnp.array(samples_list)

    assert jnp.mean(samples[:, 0]) > 0.95
    assert jnp.mean(samples[:, 1]) < 0.05
    assert 0.3 < jnp.mean(samples[:, 2]) < 0.7


def test_softmax_conditional_sample_shapes_and_dtype():
    """Basic checks for SoftmaxConditional behaviour.

    Check that the sampler accepts a [b, M] parameter matrix and returns
    an integer array with the expected shape and dtype.
    """

    class ConstSoftmax(SoftmaxConditional):
        def compute_parameters(self, key, interactions, active_flags, states, sampler_state, output_sd):
            return jnp.array([[10.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 10.0]]), None

    sampler = ConstSoftmax()
    params, _ = sampler.compute_parameters(None, [], [], [], None, None)
    output_sd = jax.ShapeDtypeStruct((2,), dtype=jnp.uint8)

    sample, state = sampler.sample_given_parameters(jax.random.PRNGKey(2), params, None, output_sd)

    assert isinstance(sample, jnp.ndarray)
    assert sample.shape == (2,)
    assert sample.dtype == jnp.uint8
    assert state is None


def test_softmax_conditional_categorical_bias():
    """Verify that the Softmax sampler respects parameter biases."""

    class ConstSoftmax(SoftmaxConditional):
        def compute_parameters(self, key, interactions, active_flags, states, sampler_state, output_sd):
            return jnp.array([[10.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 10.0]]), None

    sampler = ConstSoftmax()
    output_sd = jax.ShapeDtypeStruct((2,), dtype=jnp.uint8)

    key = jax.random.PRNGKey(42)
    samples_list = []
    for _ in range(100):
        key, subkey = jax.random.split(key)
        sample, _ = sampler.sample(subkey, [], [], [], None, output_sd)
        samples_list.append(sample)

    samples = jnp.array(samples_list)

    # The large logit in each row should dominate the categorical draw.
    assert jnp.mean(samples[:, 0] == 0) > 0.95
    assert jnp.mean(samples[:, 1] == 3) > 0.95


def test_spin_gibbs_conditional_with_ising_chain():
    """Build a five-node Ising chain and check the hinton_init block states."""
    nodes = [SpinNode() for _ in range(5)]
    edges = [(nodes[i], nodes[i + 1]) for i in range(4)]

    biases = jnp.array([5.0, 0.0, 0.0, 0.0, -5.0])
    weights = jnp.ones((4,)) * 2.0
    beta = jnp.array(1.0)

    model = IsingEBM(nodes, edges, biases, weights, beta)

    # Even/odd blocks give the usual checkerboard partition of the chain.
    free_blocks = [Block(nodes[::2]), Block(nodes[1::2])]
    # Constructing the sampling program should not raise.
    program = IsingSamplingProgram(model, free_blocks, clamped_blocks=[])

    key = jax.random.PRNGKey(0)
    k_init, _ = jax.random.split(key, 2)

    init_state = hinton_init(k_init, model, free_blocks, tuple())

    # hinton_init returns one boolean state array per free block.
    assert len(init_state) == 2
    assert init_state[0].shape == (3,)
    assert init_state[1].shape == (2,)
    assert init_state[0].dtype == jnp.bool_
    assert init_state[1].dtype == jnp.bool_


def test_spin_gibbs_conditional_energy_consistency():
    """Verify that SpinGibbsConditional respects the energy landscape.

    A strong external field on the middle spin should bias samples toward
    the lower-energy configuration in which that spin is up.
    """
    # Create a 3-node chain with a strong positive bias on node 1
    nodes = [SpinNode() for _ in range(3)]
    edges = [(nodes[0], nodes[1]), (nodes[1], nodes[2])]

    biases = jnp.array([0.0, 10.0, 0.0])
    weights = jnp.array([0.1, 0.1])
    beta = jnp.array(1.0)

    model = IsingEBM(nodes, edges, biases, weights, beta)

    free_blocks = [Block(nodes)]
    program = IsingSamplingProgram(model, free_blocks, clamped_blocks=[])

    key = jax.random.PRNGKey(1)
    init_state = hinton_init(key, model, free_blocks, tuple())

    # A single initial sample should already reflect the strong bias on node 1.
    assert init_state[0][1].astype(jnp.float32) > 0.5


def test_spin_gibbs_conditional_with_coupling():
    """Verify that the sampler respects edge coupling.

    When two nodes are strongly coupled with a positive weight, they should
    tend to take the same value.
    """
    # Create a pair of nodes with strong positive coupling
    nodes = [SpinNode() for _ in range(2)]
    edges = [(nodes[0], nodes[1])]

    biases = jnp.array([0.0, 0.0])
    weights = jnp.array([10.0])
    beta = jnp.array(1.0)

    model = IsingEBM(nodes, edges, biases, weights, beta)

    key = jax.random.PRNGKey(2)
    free_blocks = [Block([nodes[0]]), Block([nodes[1]])]
    init_state = hinton_init(key, model, free_blocks, tuple())

    assert len(init_state) == 2
    assert all(s.dtype == jnp.bool_ for s in init_state)
```
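The frequency thresholds in `test_bernoulli_conditional_sampling_bias` follow directly from the parameter values if `BernoulliConditional` interprets its parameters as logits, i.e. P(True) = sigmoid(parameter). The diff itself does not state this, so the sanity check below is only a sketch under that assumption:

```python
import jax
import jax.numpy as jnp

# Assumption: BernoulliConditional treats its parameters as logits,
# so P(True) = sigmoid(param). This is not confirmed by the diff.
params = jnp.array([100.0, -100.0, 0.0])
probs = jax.nn.sigmoid(params)
print(probs)  # approx [1.0, 0.0, 0.5], matching the >0.95, <0.05 and 0.3-0.7 thresholds
```

Under that reading, the thresholds leave a comfortable margin for 100 draws from a fixed PRNG seed.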
These changes seem fine, but I would separate the conditional test additions out into another PR. I will have to review the code and determine whether these tests are warranted (regardless, they should also be adjusted to match the style of the other tests).
Sounds good. I'll make another PR for the tests and adjust them to match the styling.