Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 32 additions & 6 deletions merlin/pcvl_pytorch/locirc_to_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@
from __future__ import annotations

import random
import ast

import torch
from multipledispatch import dispatch

from perceval.utils import Expression
from perceval.components import (
BS,
PERM,
Expand Down Expand Up @@ -443,11 +446,13 @@ def _compute_tensor(self, comp: AComponent) -> torch.Tensor: # type: ignore[no-
NotImplementedError: If BS convention is not supported
"""
param_values = []

for _index, param in enumerate(comp.get_parameters(all_params=True)):
for index, param in enumerate(comp.get_parameters(all_params=True, expressions=True)):
if param.is_variable:
(tensor_id, idx_in_tensor) = self.param_mapping[param.name]
param_values.append(self.torch_params[tensor_id][..., idx_in_tensor])
if isinstance(param, Expression):
param_values.append(self._parse_expression(param))
else:
(tensor_id, idx_in_tensor) = self.param_mapping[param.name]
param_values.append(self.torch_params[tensor_id][..., idx_in_tensor])
else:
param_values.append(
torch.tensor(
Expand Down Expand Up @@ -509,8 +514,11 @@ def _compute_tensor(self, comp: AComponent) -> torch.Tensor: # type: ignore[no-
Batched 1x1 phase tensor of shape (batch_size, 1, 1)
"""
if comp.param("phi").is_variable:
(tensor_id, idx_in_tensor) = self.param_mapping[comp.param("phi").name]
phase = self.torch_params[tensor_id][..., idx_in_tensor]
if isinstance(comp.param("phi"), Expression):
phase = self._parse_expression(comp.param("phi"))
else:
(tensor_id, idx_in_tensor) = self.param_mapping[comp.param("phi").name]
phase = self.torch_params[tensor_id][..., idx_in_tensor]
else:
phase = torch.tensor(
comp.param("phi")._value, dtype=self.tensor_fdtype, device=self.device
Expand All @@ -525,3 +533,21 @@ def _compute_tensor(self, comp: AComponent) -> torch.Tensor: # type: ignore[no-
-1, 1
) # reshape so that in any case, we have 2 dim
return unitary_tensor.unsqueeze(-1) # to change shape of tensor to (b, 1, 1)

def _parse_expression(self, expression: Expression) -> torch.Tensor:
"""Returns value of a given expression by parsing its name."""
# Base params in expression
param_list = expression.parameters
tensor_ids_and_indices = [self.param_mapping[p.name] for p in param_list]

assign_params = {
param_list[i].name: self.torch_params[id][..., idx]
for i, (id, idx) in enumerate(tensor_ids_and_indices)
}

# Use ast to parse expression name & assign value
tree = ast.parse(expression.name, mode="eval")
code = compile(tree, "<string>", mode="eval")
value = eval(code, {}, assign_params)

return value
Loading