diff --git a/posteriors/gradient_descent/__init__.py b/posteriors/gradient_descent/__init__.py
new file mode 100644
index 0000000..d6c4ebf
--- /dev/null
+++ b/posteriors/gradient_descent/__init__.py
@@ -0,0 +1 @@
+from posteriors.gradient_descent import svgd
diff --git a/posteriors/gradient_descent/svgd.py b/posteriors/gradient_descent/svgd.py
new file mode 100644
index 0000000..e56a83d
--- /dev/null
+++ b/posteriors/gradient_descent/svgd.py
@@ -0,0 +1,97 @@
+from functools import partial
+from typing import Callable, Any, NamedTuple
+
+from optree import tree_map
+from torch.func import grad_and_value, vmap
+from optree.integration.torch import tree_ravel
+
+from posteriors.types import TensorTree, Transform
+
+
+def _build_stein_variational_gradient_step(
+    log_posterior: Callable[[TensorTree, Any], TensorTree], kernel: Callable
+):
+    """Builds a step function that computes the Stein variational gradient
+
+        phi*(x) = E_{x' ~ q}[k(x', x) grad_{x'} log p(x') + grad_{x'} k(x', x)]
+
+    for the user-defined log_posterior and kernel.
+    """
+
+    def step(params: TensorTree, batch):
+        def _phi_star_summand(param, param_, batch):
+            log_prob_grad, _ = grad_and_value(log_posterior, argnums=0)(param, batch)
+            grad_k, k = grad_and_value(kernel, argnums=0)(param, param_)
+            return tree_map(lambda gl, gk: (k * gl) + gk, log_prob_grad, grad_k)
+
+        phi_star_summand = partial(_phi_star_summand, batch=batch)
+        r_params, _ = tree_ravel(params)
+
+        # Average the summand over all ravelled particles for each parameter leaf.
+        gradients = tree_map(
+            lambda p: vmap(lambda p_: phi_star_summand(p, p_))(r_params).mean(axis=0),
+            params,
+        )
+        return gradients
+
+    return step
+
+
+def build(
+    log_posterior: Callable[[TensorTree, Any], TensorTree],
+    learning_rate: float,
+    kernel: Callable,
+) -> Transform:
+    """Builds a Stein variational gradient descent (SVGD) transform from a
+    log posterior, a learning rate and a kernel.
+
+    See Liu and Wang, 2016: https://arxiv.org/abs/1608.04471
+    """
+    step_gradient_fn = _build_stein_variational_gradient_step(log_posterior, kernel)
+    update_fn = partial(
+        update,
+        step_function=step_gradient_fn,
+        learning_rate=learning_rate,
+    )
+    return Transform(init, update_fn)
+
+
+class SVGDState(NamedTuple):
+    """State holding the current SVGD particle parameters."""
+
+    params: TensorTree
+
+
+def init(
+    params: TensorTree,
+) -> SVGDState:
+    """Initialises the SVGD state from the initial particle parameters."""
+    return SVGDState(params)
+
+
+def update(
+    state: SVGDState,
+    batch: Any,
+    step_function: Callable,
+    learning_rate: float,
+    inplace: bool = False,
+) -> SVGDState:
+    """Moves the particles one step along the Stein variational gradient."""
+    step_gradient = step_function(state.params, batch)
+    params = tree_map(lambda p, g: p + learning_rate * g, state.params, step_gradient)
+    if inplace:
+        # NamedTuple fields cannot be reassigned, so copy into the existing tensors.
+        tree_map(lambda p, new: p.data.copy_(new), state.params, params)
+        return state
+    return SVGDState(params)
diff --git a/tests/gradient_descent/test_svgd.py b/tests/gradient_descent/test_svgd.py
new file mode 100644
index 0000000..faf4a2d
--- /dev/null
+++ b/tests/gradient_descent/test_svgd.py
@@ -0,0 +1,52 @@
+from functools import partial
+import torch
+from optree import tree_map
+from posteriors.gradient_descent import svgd
+from optree.integration.torch import tree_ravel
+from torch.distributions import Normal
+
+
+def rbf_kernel(x, y, length_scale=1):
+    arg = tree_ravel(
+        tree_map(lambda x, y: torch.exp(-(1 / length_scale) * ((x - y) ** 2)), x, y)
+    )[0]
+    return arg.sum()
+
+
+def flat_log_probability(params, batch, mean, sd_diag, normalize: bool = False):
+    if normalize:
+
+        def univariate_norm_and_sum(v, m, sd):
+            return Normal(m, sd, validate_args=False).log_prob(v).sum()
+    else:
+
+        def univariate_norm_and_sum(v, m, sd):
+            return (-0.5 * ((v - m) / sd) ** 2).sum()
+
+    return univariate_norm_and_sum(params, mean, sd_diag)
+
+
+def test_svgd():
+    torch.manual_seed(42)
+    target_mean = {"a": torch.randn(2, 1) + 10, "b": torch.randn(1, 1) + 10}
+    flat_target_mean = tree_ravel(target_mean)[0]
+    target_sds = tree_map(lambda x: torch.randn_like(x).abs(), target_mean)
+    flat_target_sds = tree_ravel(target_sds)[0]
+    init_mean = tree_map(lambda x: torch.ones_like(x, requires_grad=True), target_mean)
+
+    batch = torch.arange(10).reshape(-1, 1)
+    batch_normal_log_prob_spec = partial(
+        flat_log_probability, mean=flat_target_mean, sd_diag=flat_target_sds
+    )
+
+    n_steps = 1000
+    lr = 1e-2
+    transform = svgd.build(batch_normal_log_prob_spec, lr, rbf_kernel)
+    state = transform.init(init_mean)
+
+    for _ in range(n_steps):
+        state = transform.update(state, batch, inplace=False)
+
+    flat_params = tree_ravel(state.params)[0]
+
+    assert torch.allclose(flat_params, flat_target_mean, atol=1e-0, rtol=1e-1)
diff --git a/tests/gradient_descent/test_svgd_old.py b/tests/gradient_descent/test_svgd_old.py
new file mode 100644
index 0000000..2414bf5
--- /dev/null
+++ b/tests/gradient_descent/test_svgd_old.py
@@ -0,0 +1,34 @@
+import torch
+from optree import tree_map
+from optree.integration.torch import tree_ravel
+from posteriors import gradient_descent
+
+
+def rbf_kernel(x, y, length_scale=1):
+    arg = tree_ravel(
+        tree_map(lambda x, y: torch.exp(-(1 / length_scale) * ((x - y) ** 2)), x, y)
+    )[0]
+    return arg.sum()
+
+
+def test_svgd_api():
+    torch.manual_seed(42)
+    target_mean = {"a": torch.randn(2, 1), "b": torch.randn(1, 1)}
+
+    def log_prob(p, b):
+        # Differentiable dummy log posterior so grad_and_value has a tensor output.
+        return tree_ravel(p)[0].sum()
+
+    init_mean = tree_map(lambda x: torch.ones_like(x, requires_grad=True), target_mean)
+    batch = torch.arange(3).reshape(-1, 1)
+    transform = gradient_descent.svgd.build(log_prob, 1e-1, rbf_kernel)
+
+    state = transform.init(init_mean)
+    state = transform.update(state, batch, inplace=False)
+
+    # no crashes
+    assert True
+
+
+def dummy_test():
+    pass
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 9ce4061..1facd5c 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -107,7 +107,8 @@ def test_model_to_function():
     func_output2 = func_lm(dict(lm.named_parameters()), input_ids, attention_mask)
 
-    assert type(output) == type(func_output1) == type(func_output2)
+    assert type(output) is type(func_output1)
+    assert type(func_output1) is type(func_output2)
     assert torch.allclose(output["logits"], func_output1["logits"])
     assert torch.allclose(output["logits"], func_output2["logits"])
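For reference, below is a minimal self-contained sketch of the standard SVGD update from Liu and Wang (2016) on a flat (n_particles, dim) tensor with an RBF kernel. It is independent of the tree-based implementation in the diff above and is not part of the posteriors API; the function name svgd_reference_step and the bandwidth argument h are made up for this illustration.

import torch


def svgd_reference_step(particles, log_prob, lr=1e-2, h=1.0):
    """One SVGD step: x_i <- x_i + lr * phi(x_i), with
    phi(x_i) = mean_j [k(x_j, x_i) grad_{x_j} log p(x_j) + grad_{x_j} k(x_j, x_i)]."""
    x = particles.detach().clone().requires_grad_(True)
    # Score function grad log p evaluated at every particle, shape (n, d).
    score = torch.autograd.grad(log_prob(x).sum(), x)[0]
    with torch.no_grad():
        diff = x.unsqueeze(1) - x.unsqueeze(0)      # diff[i, j] = x_i - x_j, (n, n, d)
        k = torch.exp(-diff.pow(2).sum(-1) / h)     # RBF kernel matrix, (n, n)
        # sum_j grad_{x_j} k(x_j, x_i) = (2 / h) * sum_j (x_i - x_j) k(x_i, x_j)
        repulsion = (2.0 / h) * (k.unsqueeze(-1) * diff).sum(dim=1)
        phi = (k @ score + repulsion) / x.shape[0]  # driving term + repulsion
        return particles + lr * phi


# Example: move 50 particles towards a standard normal target.
particles = torch.randn(50, 2) * 3 + 5
log_prob = lambda x: -0.5 * (x**2).sum(-1)
for _ in range(500):
    particles = svgd_reference_step(particles, log_prob, lr=1e-1)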