
TypeError after command momi.optimize() #3

Description

@quank

Here is a toy example I just tried, using momi3 to optimize a demographic inference. The commands and the resulting TypeError are below; any suggestions would be appreciated.

import demes
import demesdraw

# Toy demography: ancestral ABC splits into AB and C at time 800,
# AB splits into A and B at time 550, and C pulses into B at time 240.
b = demes.Builder()
b.add_deme("ABC", epochs=[dict(start_size=1200, end_time=800)])
b.add_deme("AB", ancestors=["ABC"], epochs=[dict(start_size=1000, end_time=550)])
b.add_deme("C", ancestors=["ABC"], epochs=[dict(start_size=300)])
b.add_deme("A", ancestors=["AB"], epochs=[dict(start_size=800)])
b.add_deme("B", ancestors=["AB"], epochs=[dict(start_size=700)])
b.add_pulse(sources=["C"], dest="B", time=240, proportions=[0.28])
g = b.resolve()

demesdraw.tubes(g)

from momi3.MOMI import Momi

sampled_demes = ["A", "B", "C"]
sample_sizes = [20, 30, 15]

momi = Momi(g, sampled_demes=sampled_demes, sample_sizes=sample_sizes, low_memory=True)
params = momi._default_params
bounds = momi.bound_sampler(params, 1000, seed=108)
momi = momi.bound(bounds)

# Train all size parameters except eta_0, simulate a JSFS, then fit.
params.set_train_all_etas(True)
params.set_train("eta_0", False)
jsfs = momi.simulate(200, seed=1124321)

momi.optimize(params=params, jsfs=jsfs, stepsize=0.5, maxiter=50)

The full TypeError traceback is as follows:


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In [7], line 1
----> 1 momi.optimize(params=params, jsfs=jsfs, stepsize=0.5, maxiter=50)

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/momi3-0.0.0-py3.10.egg/momi3/MOMI.py:284, in Momi.optimize(self, params, jsfs, stepsize, maxiter, theta_train_dict_0, htol, monitor_training)
281 negative_loglik_with_gradient = self.negative_loglik_with_gradient
282 sampled_demes = self.sampled_demes
--> 284 return ProjectedGradient_optimizer(
285 negative_loglik_with_gradient=negative_loglik_with_gradient,
286 params=params,
287 jsfs=jsfs,
288 stepsize=stepsize,
289 maxiter=maxiter,
290 theta_train_dict_0=theta_train_dict_0,
291 sampled_demes=sampled_demes,
292 htol=htol,
293 monitor_training=monitor_training,
294 )

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/momi3-0.0.0-py3.10.egg/momi3/optimizers.py:121, in ProjectedGradient_optimizer(negative_loglik_with_gradient, params, jsfs, stepsize, maxiter, sampled_demes, theta_train_dict_0, htol, monitor_training)
118 plt.xlabel("Iteration Number")
120 else:
--> 121 opt_result = pg.run(theta_train_0, hyperparams_proj=(A, b, G, h))
122 theta_train_hat = opt_result.params
123 pg_state = opt_result.state

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/projected_gradient.py:137, in ProjectedGradient.run(self, init_params, hyperparams_proj, *args, **kwargs)
132 def run(self,
133 init_params: Any,
134 hyperparams_proj: Optional[Any] = None,
135 *args,
136 **kwargs) -> base.OptStep:
--> 137 return self._pg.run(init_params, hyperparams_proj, *args, **kwargs)

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/base.py:359, in IterativeSolver.run(self, init_params, *args, **kwargs)
352 decorator = idf.custom_root(
353 self.optimality_fun,
354 has_aux=True,
355 solve=self.implicit_diff_solve,
356 reference_signature=reference_signature)
357 run = decorator(run)
--> 359 return run(init_params, *args, **kwargs)

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py:251, in _custom_root.<locals>.wrapped_solver_fun(*args, **kwargs)
249 args, kwargs = _signature_bind(solver_fun_signature, *args, **kwargs)
250 keys, vals = list(kwargs.keys()), list(kwargs.values())
--> 251 return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)

[... skipping hidden 5 frame]

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py:207, in _custom_root.<locals>.make_custom_vjp_solver_fun.<locals>.solver_fun_flat(*flat_args)
204 @jax.custom_vjp
205 def solver_fun_flat(*flat_args):
206 args, kwargs = _extract_kwargs(kwarg_keys, flat_args)
--> 207 return solver_fun(*args, **kwargs)

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/base.py:321, in IterativeSolver._run(self, init_params, *args, **kwargs)
303 # We unroll the very first iteration. This allows init_val and body_fun
304 # below to have the same output type, which is a requirement of
305 # lax.while_loop and lax.scan.
(...)
316 # of a lax.cond for now in order to avoid staging the initial
317 # update and the run loop. They might not be staging compatible.
319 zero_step = self._make_zero_step(init_params, state)
--> 321 opt_step = self.update(init_params, state, *args, **kwargs)
322 init_val = (opt_step, (args, kwargs))
324 unroll = self._get_unroll_option()

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/proximal_gradient.py:305, in ProximalGradient.update(self, params, state, hyperparams_prox, *args, **kwargs)
293 """Performs one iteration of proximal gradient.
294
295 Args:
(...)
302 (params, state)
303 """
304 f = self._update_accel if self.acceleration else self._update
--> 305 return f(params, state, hyperparams_prox, args, kwargs)

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/proximal_gradient.py:266, in ProximalGradient._update_accel(self, x, state, hyperparams_prox, args, kwargs)
263 stepsize = state.stepsize
264 (y_fun_val, aux), y_fun_grad = self._value_and_grad_with_aux(y, *args,
265 **kwargs)
--> 266 next_x, next_stepsize = self._iter(iter_num, y, y_fun_val, y_fun_grad,
267 stepsize, hyperparams_prox, args, kwargs)
268 next_t = 0.5 * (1 + jnp.sqrt(1 + 4 * t ** 2))
269 diff_x = tree_sub(next_x, x)

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/proximal_gradient.py:240, in ProximalGradient._iter(self, iter_num, x, x_fun_val, x_fun_grad, stepsize, hyperparams_prox, args, kwargs)
238 else:
239 next_stepsize = self.stepsize
--> 240 next_x = self._prox_grad(x, x_fun_grad, next_stepsize, hyperparams_prox)
241 return next_x, next_stepsize

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/proximal_gradient.py:209, in ProximalGradient._prox_grad(self, x, x_fun_grad, stepsize, hyperparams_prox)
208 def _prox_grad(self, x, x_fun_grad, stepsize, hyperparams_prox):
--> 209 update = tree_add_scalar_mul(x, -stepsize, x_fun_grad)
210 return self.prox(update, hyperparams_prox, stepsize)

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/tree_util.py:91, in tree_add_scalar_mul(tree_x, scalar, tree_y)
89 def tree_add_scalar_mul(tree_x, scalar, tree_y):
90 """Compute tree_x + scalar * tree_y."""
---> 91 return tree_map(lambda x, y: x + scalar * y, tree_x, tree_y)

[... skipping hidden 2 frame]

File ~/tools/mambaforge/envs/momi3/lib/python3.10/site-packages/jaxopt/_src/tree_util.py:91, in tree_add_scalar_mul.<locals>.<lambda>(x, y)
89 def tree_add_scalar_mul(tree_x, scalar, tree_y):
90 """Compute tree_x + scalar * tree_y."""
---> 91 return tree_map(lambda x, y: x + scalar * y, tree_x, tree_y)

TypeError: unsupported operand type(s) for *: 'float' and 'dict'

That is the end of the error output.
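For what it's worth, the failing operation at the bottom of the traceback can be reproduced outside momi3. jaxopt's tree_add_scalar_mul maps x + scalar * y over the parameter and gradient pytrees, so it raises exactly this TypeError whenever the gradient tree carries a dict where the parameter tree has a plain float leaf. A minimal sketch (the pytree contents below are invented for illustration, not taken from momi3's internals):

import jax

stepsize = 0.5
theta_train = {"eta_1": 1000.0}   # parameter pytree: plain float leaf
grad = {"eta_1": {"eta_1": 0.5}}  # gradient pytree: dict where a float is expected

# Same operation as jaxopt's tree_add_scalar_mul(theta_train, -stepsize, grad)
jax.tree_util.tree_map(lambda x, y: x + (-stepsize) * y, theta_train, grad)
# -> TypeError: unsupported operand type(s) for *: 'float' and 'dict'

So it looks as though the theta_train pytree and the gradient returned by negative_loglik_with_gradient may not have matching structures, though I have not dug into where the mismatch is introduced.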
