diff --git a/merlin/pcvl_pytorch/locirc_to_tensor.py b/merlin/pcvl_pytorch/locirc_to_tensor.py index ba7f82c6..9fefd709 100644 --- a/merlin/pcvl_pytorch/locirc_to_tensor.py +++ b/merlin/pcvl_pytorch/locirc_to_tensor.py @@ -23,9 +23,12 @@ from __future__ import annotations import random +import ast import torch from multipledispatch import dispatch + +from perceval.utils import Expression from perceval.components import ( BS, PERM, @@ -443,11 +446,13 @@ def _compute_tensor(self, comp: AComponent) -> torch.Tensor: # type: ignore[no- NotImplementedError: If BS convention is not supported """ param_values = [] - - for _index, param in enumerate(comp.get_parameters(all_params=True)): + for index, param in enumerate(comp.get_parameters(all_params=True, expressions=True)): if param.is_variable: - (tensor_id, idx_in_tensor) = self.param_mapping[param.name] - param_values.append(self.torch_params[tensor_id][..., idx_in_tensor]) + if isinstance(param, Expression): + param_values.append(self._parse_expression(param)) + else: + (tensor_id, idx_in_tensor) = self.param_mapping[param.name] + param_values.append(self.torch_params[tensor_id][..., idx_in_tensor]) else: param_values.append( torch.tensor( @@ -509,8 +514,11 @@ def _compute_tensor(self, comp: AComponent) -> torch.Tensor: # type: ignore[no- Batched 1x1 phase tensor of shape (batch_size, 1, 1) """ if comp.param("phi").is_variable: - (tensor_id, idx_in_tensor) = self.param_mapping[comp.param("phi").name] - phase = self.torch_params[tensor_id][..., idx_in_tensor] + if isinstance(comp.param("phi"), Expression): + phase = self._parse_expression(comp.param("phi")) + else: + (tensor_id, idx_in_tensor) = self.param_mapping[comp.param("phi").name] + phase = self.torch_params[tensor_id][..., idx_in_tensor] else: phase = torch.tensor( comp.param("phi")._value, dtype=self.tensor_fdtype, device=self.device @@ -525,3 +533,21 @@ def _compute_tensor(self, comp: AComponent) -> torch.Tensor: # type: ignore[no- -1, 1 ) # reshape so that in any case, we have 2 dim return unitary_tensor.unsqueeze(-1) # to change shape of tensor to (b, 1, 1) + + def _parse_expression(self, expression: Expression) -> torch.Tensor: + """Returns value of a given expression by parsing its name.""" + # Base params in expression + param_list = expression.parameters + tensor_ids_and_indices = [self.param_mapping[p.name] for p in param_list] + + assign_params = { + param_list[i].name: self.torch_params[id][..., idx] + for i, (id, idx) in enumerate(tensor_ids_and_indices) + } + + # Use ast to parse expression name & assign value + tree = ast.parse(expression.name, mode="eval") + code = compile(tree, "", mode="eval") + value = eval(code, {}, assign_params) + + return value