Here is a toy example in which I tried to use momi3 to optimize (fit) a demographic model.
The commands and the resulting TypeError traceback are included below; any suggestions are appreciated.
import demes
import demesdraw
# Minimal reproduction script: build a small demes graph, hand it to momi3, and
# call momi.optimize(...), which is the call that raises the reported TypeError.

# --- Demographic model -------------------------------------------------------
b = demes.Builder()
# Root deme "ABC": start_size 1200, epoch ends at time 800.
b.add_deme("ABC", epochs = [dict(start_size = 1200, end_time = 800)])
# "AB" descends from "ABC": start_size 1000, epoch ends at time 550.
b.add_deme("AB", ancestors = ["ABC"], epochs = [dict(start_size = 1000, end_time = 550)])
# "C" also descends from "ABC"; no end_time is given for its single epoch.
b.add_deme("C", ancestors = ["ABC"], epochs = [dict(start_size = 300)])
# "A" and "B" both descend from "AB"; no end_time is given for either.
b.add_deme("A", ancestors = ["AB"], epochs = [dict(start_size = 800)])
b.add_deme("B", ancestors = ["AB"], epochs = [dict(start_size = 700)])
# Single pulse with source "C" and destination "B" at time 240, proportion 0.28.
b.add_pulse(sources = ["C"], dest = "B", time = 240, proportions = [0.28])
# Resolve the builder into the graph object `g` that is used below.
g = b.resolve()
# Draw the graph (demesdraw "tubes" plot); the return value is discarded.
demesdraw.tubes(g)

# --- momi3 setup ---------------------------------------------------------------
from momi3.MOMI import Momi
# Demes that are sampled, and one sample size per deme in the same order.
sampled_demes = ["A", "B", "C"]
sample_sizes = [20, 30, 15]
momi = Momi(g, sampled_demes=sampled_demes, sample_sizes=sample_sizes, low_memory=True)
# NOTE(review): _default_params is a private attribute (leading underscore);
# a public accessor may exist -- confirm against the momi3 API.
params = momi._default_params
# presumably derives parameter bounds from 1000 samples with a fixed seed -- TODO confirm
bounds = momi.bound_sampler(params, 1000, seed=108)
# `momi` is rebound to whatever bound(bounds) returns; the calls below use that object.
momi = momi.bound(bounds)
# presumably marks every eta parameter as trainable, then excludes 'eta_0' -- TODO confirm
params.set_train_all_etas(True)
params.set_train('eta_0', False)
# Simulate the `jsfs` input for the fit; the meaning of the positional 200 is not
# visible here -- verify against momi.simulate.
jsfs = momi.simulate(200, seed=1124321)
# This is the failing call: the traceback reports
# "TypeError: unsupported operand type(s) for *: 'float' and 'dict'"
# raised inside jaxopt's tree_add_scalar_mul during the first update step.
momi.optimize(params=params, jsfs=jsfs, stepsize=0.5, maxiter=50)
The TypeError traceback is shown below:
TypeError Traceback (most recent call last)
Cell In [7], line 1
----> 1 momi.optimize(params=params, jsfs=jsfs, stepsize=0.5, maxiter=50)
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/momi3-0.0.0-py3.10.egg/momi3/MOMI.py:284, in Momi.optimize(self, params, jsfs, stepsize, maxiter, theta_train_dict_0, htol, monitor_training)
281 negative_loglik_with_gradient = self.negative_loglik_with_gradient
282 sampled_demes = self.sampled_demes
--> 284 return ProjectedGradient_optimizer(
285 negative_loglik_with_gradient=negative_loglik_with_gradient,
286 params=params,
287 jsfs=jsfs,
288 stepsize=stepsize,
289 maxiter=maxiter,
290 theta_train_dict_0=theta_train_dict_0,
291 sampled_demes=sampled_demes,
292 htol=htol,
293 monitor_training=monitor_training,
294 )
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/momi3-0.0.0-py3.10.egg/momi3/optimizers.py:121, in ProjectedGradient_optimizer(negative_loglik_with_gradient, params, jsfs, stepsize, maxiter, sampled_demes, theta_train_dict_0, htol, monitor_training)
118 plt.xlabel("Iteration Number")
120 else:
--> 121 opt_result = pg.run(theta_train_0, hyperparams_proj=(A, b, G, h))
122 theta_train_hat = opt_result.params
123 pg_state = opt_result.state
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/projected_gradient.py:137, in ProjectedGradient.run(self, init_params, hyperparams_proj, *args, **kwargs)
132 def run(self,
133 init_params: Any,
134 hyperparams_proj: Optional[Any] = None,
135 *args,
136 **kwargs) -> base.OptStep:
--> 137 return self._pg.run(init_params, hyperparams_proj, *args, **kwargs)
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/base.py:359, in IterativeSolver.run(self, init_params, *args, **kwargs)
352 decorator = idf.custom_root(
353 self.optimality_fun,
354 has_aux=True,
355 solve=self.implicit_diff_solve,
356 reference_signature=reference_signature)
357 run = decorator(run)
--> 359 return run(init_params, *args, **kwargs)
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py:251, in _custom_root.<locals>.wrapped_solver_fun(*args, **kwargs)
249 args, kwargs = _signature_bind(solver_fun_signature, *args, **kwargs)
250 keys, vals = list(kwargs.keys()), list(kwargs.values())
--> 251 return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
[... skipping hidden 5 frame]
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py:207, in _custom_root.<locals>.make_custom_vjp_solver_fun.<locals>.solver_fun_flat(*flat_args)
204 @jax.custom_vjp
205 def solver_fun_flat(*flat_args):
206 args, kwargs = _extract_kwargs(kwarg_keys, flat_args)
--> 207 return solver_fun(*args, **kwargs)
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/base.py:321, in IterativeSolver._run(self, init_params, *args, **kwargs)
303 # We unroll the very first iteration. This allows init_val and body_fun
304 # below to have the same output type, which is a requirement of
305 # lax.while_loop and lax.scan.
(...)
316 # of a lax.cond for now in order to avoid staging the initial
317 # update and the run loop. They might not be staging compatible.
319 zero_step = self._make_zero_step(init_params, state)
--> 321 opt_step = self.update(init_params, state, *args, **kwargs)
322 init_val = (opt_step, (args, kwargs))
324 unroll = self._get_unroll_option()
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/proximal_gradient.py:305, in ProximalGradient.update(self, params, state, hyperparams_prox, *args, **kwargs)
293 """Performs one iteration of proximal gradient.
294
295 Args:
(...)
302 (params, state)
303 """
304 f = self._update_accel if self.acceleration else self._update
--> 305 return f(params, state, hyperparams_prox, args, kwargs)
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/proximal_gradient.py:266, in ProximalGradient._update_accel(self, x, state, hyperparams_prox, args, kwargs)
263 stepsize = state.stepsize
264 (y_fun_val, aux), y_fun_grad = self._value_and_grad_with_aux(y, *args,
265 **kwargs)
--> 266 next_x, next_stepsize = self._iter(iter_num, y, y_fun_val, y_fun_grad,
267 stepsize, hyperparams_prox, args, kwargs)
268 next_t = 0.5 * (1 + jnp.sqrt(1 + 4 * t ** 2))
269 diff_x = tree_sub(next_x, x)
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/proximal_gradient.py:240, in ProximalGradient._iter(self, iter_num, x, x_fun_val, x_fun_grad, stepsize, hyperparams_prox, args, kwargs)
238 else:
239 next_stepsize = self.stepsize
--> 240 next_x = self._prox_grad(x, x_fun_grad, next_stepsize, hyperparams_prox)
241 return next_x, next_stepsize
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/proximal_gradient.py:209, in ProximalGradient._prox_grad(self, x, x_fun_grad, stepsize, hyperparams_prox)
208 def _prox_grad(self, x, x_fun_grad, stepsize, hyperparams_prox):
--> 209 update = tree_add_scalar_mul(x, -stepsize, x_fun_grad)
210 return self.prox(update, hyperparams_prox, stepsize)
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/tree_util.py:91, in tree_add_scalar_mul(tree_x, scalar, tree_y)
89 def tree_add_scalar_mul(tree_x, scalar, tree_y):
90 """Compute tree_x + scalar * tree_y."""
---> 91 return tree_map(lambda x, y: x + scalar * y, tree_x, tree_y)
[... skipping hidden 2 frame]
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/tree_util.py:91, in tree_add_scalar_mul.<locals>.<lambda>(x, y)
89 def tree_add_scalar_mul(tree_x, scalar, tree_y):
90 """Compute tree_x + scalar * tree_y."""
---> 91 return tree_map(lambda x, y: x + scalar * y, tree_x, tree_y)
TypeError: unsupported operand type(s) for *: 'float' and 'dict'
End of the traceback.
Here is a toy example in which I tried to use momi3 to optimize (fit) a demographic model.
The commands and the resulting TypeError traceback are included below; any suggestions are appreciated.
The TypeError traceback is shown below:
TypeError Traceback (most recent call last)
Cell In [7], line 1
----> 1 momi.optimize(params=params, jsfs=jsfs, stepsize=0.5, maxiter=50)
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/momi3-0.0.0-py3.10.egg/momi3/MOMI.py:284, in Momi.optimize(self, params, jsfs, stepsize, maxiter, theta_train_dict_0, htol, monitor_training)
281 negative_loglik_with_gradient = self.negative_loglik_with_gradient
282 sampled_demes = self.sampled_demes
--> 284 return ProjectedGradient_optimizer(
285 negative_loglik_with_gradient=negative_loglik_with_gradient,
286 params=params,
287 jsfs=jsfs,
288 stepsize=stepsize,
289 maxiter=maxiter,
290 theta_train_dict_0=theta_train_dict_0,
291 sampled_demes=sampled_demes,
292 htol=htol,
293 monitor_training=monitor_training,
294 )
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/momi3-0.0.0-py3.10.egg/momi3/optimizers.py:121, in ProjectedGradient_optimizer(negative_loglik_with_gradient, params, jsfs, stepsize, maxiter, sampled_demes, theta_train_dict_0, htol, monitor_training)
118 plt.xlabel("Iteration Number")
120 else:
--> 121 opt_result = pg.run(theta_train_0, hyperparams_proj=(A, b, G, h))
122 theta_train_hat = opt_result.params
123 pg_state = opt_result.state
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/projected_gradient.py:137, in ProjectedGradient.run(self, init_params, hyperparams_proj, *args, **kwargs)
132 def run(self,
133 init_params: Any,
134 hyperparams_proj: Optional[Any] = None,
135 *args,
136 **kwargs) -> base.OptStep:
--> 137 return self._pg.run(init_params, hyperparams_proj, *args, **kwargs)
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/base.py:359, in IterativeSolver.run(self, init_params, *args, **kwargs)
352 decorator = idf.custom_root(
353 self.optimality_fun,
354 has_aux=True,
355 solve=self.implicit_diff_solve,
356 reference_signature=reference_signature)
357 run = decorator(run)
--> 359 return run(init_params, *args, **kwargs)
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py:251, in _custom_root.<locals>.wrapped_solver_fun(*args, **kwargs)
249 args, kwargs = _signature_bind(solver_fun_signature, *args, **kwargs)
250 keys, vals = list(kwargs.keys()), list(kwargs.values())
--> 251 return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py:207, in _custom_root.<locals>.make_custom_vjp_solver_fun.<locals>.solver_fun_flat(*flat_args)
204 @jax.custom_vjp
205 def solver_fun_flat(*flat_args):
206 args, kwargs = _extract_kwargs(kwarg_keys, flat_args)
--> 207 return solver_fun(*args, **kwargs)
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/base.py:321, in IterativeSolver._run(self, init_params, *args, **kwargs)
303 # We unroll the very first iteration. This allows init_val and body_fun
304 # below to have the same output type, which is a requirement of
305 # lax.while_loop and lax.scan.
(...)
316 # of a lax.cond for now in order to avoid staging the initial
317 # update and the run loop. They might not be staging compatible.
319 zero_step = self._make_zero_step(init_params, state)
--> 321 opt_step = self.update(init_params, state, *args, **kwargs)
322 init_val = (opt_step, (args, kwargs))
324 unroll = self._get_unroll_option()
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/proximal_gradient.py:305, in ProximalGradient.update(self, params, state, hyperparams_prox, *args, **kwargs)
293 """Performs one iteration of proximal gradient.
294
295 Args:
(...)
302 (params, state)
303 """
304 f = self._update_accel if self.acceleration else self._update
--> 305 return f(params, state, hyperparams_prox, args, kwargs)
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/proximal_gradient.py:266, in ProximalGradient._update_accel(self, x, state, hyperparams_prox, args, kwargs)
263 stepsize = state.stepsize
264 (y_fun_val, aux), y_fun_grad = self._value_and_grad_with_aux(y, *args,
265 **kwargs)
--> 266 next_x, next_stepsize = self._iter(iter_num, y, y_fun_val, y_fun_grad,
267 stepsize, hyperparams_prox, args, kwargs)
268 next_t = 0.5 * (1 + jnp.sqrt(1 + 4 * t ** 2))
269 diff_x = tree_sub(next_x, x)
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/proximal_gradient.py:240, in ProximalGradient._iter(self, iter_num, x, x_fun_val, x_fun_grad, stepsize, hyperparams_prox, args, kwargs)
238 else:
239 next_stepsize = self.stepsize
--> 240 next_x = self._prox_grad(x, x_fun_grad, next_stepsize, hyperparams_prox)
241 return next_x, next_stepsize
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/proximal_gradient.py:209, in ProximalGradient._prox_grad(self, x, x_fun_grad, stepsize, hyperparams_prox)
208 def _prox_grad(self, x, x_fun_grad, stepsize, hyperparams_prox):
--> 209 update = tree_add_scalar_mul(x, -stepsize, x_fun_grad)
210 return self.prox(update, hyperparams_prox, stepsize)
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/tree_util.py:91, in tree_add_scalar_mul(tree_x, scalar, tree_y)
89 def tree_add_scalar_mul(tree_x, scalar, tree_y):
90 """Compute tree_x + scalar * tree_y."""
---> 91 return tree_map(lambda x, y: x + scalar * y, tree_x, tree_y)
File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/tree_util.py:91, in tree_add_scalar_mul.<locals>.<lambda>(x, y)
89 def tree_add_scalar_mul(tree_x, scalar, tree_y):
90 """Compute tree_x + scalar * tree_y."""
---> 91 return tree_map(lambda x, y: x + scalar * y, tree_x, tree_y)
TypeError: unsupported operand type(s) for *: 'float' and 'dict'
End of the traceback.