From ecdbcb0c5bb00f9f8dbc2c0498f55b698c32551e Mon Sep 17 00:00:00 2001 From: Anthony Walsh Date: Tue, 24 Jun 2025 16:30:18 +0200 Subject: [PATCH] Add perceval parameter expression handling to CircuitConverter. --- merlin/pcvl_pytorch/locirc_to_tensor.py | 37 +++++++++++++++++++++---- 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/merlin/pcvl_pytorch/locirc_to_tensor.py b/merlin/pcvl_pytorch/locirc_to_tensor.py index 39eebb58..9b516ed5 100644 --- a/merlin/pcvl_pytorch/locirc_to_tensor.py +++ b/merlin/pcvl_pytorch/locirc_to_tensor.py @@ -23,9 +23,12 @@ from __future__ import annotations import random +import ast import torch from multipledispatch import dispatch + +from perceval.utils import Expression from perceval.components import ( BS, PERM, @@ -386,11 +389,13 @@ def _compute_tensor(self, comp: AComponent) -> torch.Tensor: NotImplementedError: If BS convention is not supported """ param_values = [] - - for _index, param in enumerate(comp.get_parameters(all_params=True)): + for index, param in enumerate(comp.get_parameters(all_params=True, expressions=True)): if param.is_variable: - (tensor_id, idx_in_tensor) = self.param_mapping[param.name] - param_values.append(self.torch_params[tensor_id][..., idx_in_tensor]) + if isinstance(param, Expression): + param_values.append(self._parse_expression(param)) + else: + (tensor_id, idx_in_tensor) = self.param_mapping[param.name] + param_values.append(self.torch_params[tensor_id][..., idx_in_tensor]) else: param_values.append(torch.tensor(float(param), dtype=self.tensor_fdtype, device=self.device)) @@ -436,8 +441,11 @@ def _compute_tensor(self, comp: AComponent) -> torch.Tensor: Batched 1x1 phase tensor of shape (batch_size, 1, 1) """ if comp.param("phi").is_variable: - (tensor_id, idx_in_tensor) = self.param_mapping[comp.param("phi").name] - phase = self.torch_params[tensor_id][..., idx_in_tensor] + if isinstance(comp.param("phi"), Expression): + phase = self._parse_expression(comp.param("phi")) + else: + (tensor_id, idx_in_tensor) = self.param_mapping[comp.param("phi").name] + phase = self.torch_params[tensor_id][..., idx_in_tensor] else: phase = torch.tensor(comp.param("phi")._value, dtype=self.tensor_fdtype, device=self.device) @@ -448,3 +456,20 @@ def _compute_tensor(self, comp: AComponent) -> torch.Tensor: unitary_tensor = torch.exp(1j * phase).reshape(-1, 1) # reshape so that in any case, we have 2 dim return unitary_tensor.unsqueeze(-1) # to change shape of tensor to (b, 1, 1) + def _parse_expression(self, expression: Expression) -> torch.Tensor: + """Returns value of a given expression by parsing its name.""" + # Base params in expression + param_list = expression.parameters + tensor_ids_and_indices = [self.param_mapping[p.name] for p in param_list] + + assign_params = { + param_list[i].name: self.torch_params[id][..., idx] + for i, (id, idx) in enumerate(tensor_ids_and_indices) + } + + # Use ast to parse expression name & assign value + tree = ast.parse(expression.name, mode="eval") + code = compile(tree, "", mode="eval") + value = eval(code, {}, assign_params) + + return value