Skip to content

Commit f4874e6

Browse files
Liu KeyuLiu Keyu
authored andcommitted
Fix bugs
Fix bugs Fix bugs Fix bugs
1 parent 2692b96 commit f4874e6

File tree

3 files changed

+19
-72
lines changed

3 files changed

+19
-72
lines changed

src/mqt/predictor/rl/actions.py

Lines changed: 10 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
from qiskit.transpiler import CouplingMap
5151
from qiskit.transpiler.passes import (
5252
ApplyLayout,
53-
BasicSwap,
5453
BasisTranslator,
5554
Collect2qBlocks,
5655
CollectCliffords,
@@ -125,6 +124,7 @@ class Action:
125124
pass_type: PassType
126125
transpile_pass: (
127126
list[qiskit_BasePass | tket_BasePass]
127+
| Callable[..., list[Any]]
128128
| Callable[..., list[qiskit_BasePass | tket_BasePass]]
129129
| Callable[
130130
...,
@@ -144,7 +144,8 @@ class DeviceDependentAction(Action):
144144
"""Action that represents a device-specific compilation pass that can be applied to a specific device."""
145145

146146
transpile_pass: (
147-
Callable[..., list[qiskit_BasePass | tket_BasePass]]
147+
Callable[..., list[Any]]
148+
| Callable[..., list[qiskit_BasePass | tket_BasePass]]
148149
| Callable[
149150
...,
150151
Callable[..., tuple[Any, ...] | Circuit],
@@ -466,7 +467,7 @@ def get_openqasm_gates() -> list[str]:
466467

467468
register_action(
468469
DeviceDependentAction(
469-
"SabreLayout+BasicSwap",
470+
"SabreLayout+AIRouting",
470471
CompilationOrigin.QISKIT,
471472
PassType.MAPPING,
472473
stochastic=True,
@@ -482,26 +483,19 @@ def get_openqasm_gates() -> list[str]:
482483
FullAncillaAllocation(coupling_map=CouplingMap(device.build_coupling_map())),
483484
EnlargeWithAncilla(),
484485
ApplyLayout(),
485-
BasicSwap(coupling_map=CouplingMap(device.build_coupling_map())),
486+
SafeAIRouting(coupling_map=device.build_coupling_map(), optimization_level=3, layout_mode="improve"),
486487
],
487488
)
488489
)
489490

490491
register_action(
491492
DeviceDependentAction(
492-
"SabreLayout+AIRouting",
493+
"AIRouting",
493494
CompilationOrigin.QISKIT,
494495
PassType.MAPPING,
495496
stochastic=True,
496-
transpile_pass=lambda device, max_iteration=(20, 20): [
497-
SabreLayout(
498-
coupling_map=CouplingMap(device.build_coupling_map()),
499-
skip_routing=True,
500-
layout_trials=max_iteration[0],
501-
swap_trials=max_iteration[1],
502-
max_iterations=4,
503-
seed=None,
504-
),
497+
transpile_pass=lambda device: [
498+
TrivialLayout(coupling_map=CouplingMap(device.build_coupling_map())),
505499
FullAncillaAllocation(coupling_map=CouplingMap(device.build_coupling_map())),
506500
EnlargeWithAncilla(),
507501
ApplyLayout(),
@@ -532,20 +526,6 @@ def get_openqasm_gates() -> list[str]:
532526
)
533527
)
534528

535-
register_action(
536-
DeviceDependentAction(
537-
name="DenseLayout+BasicSwap",
538-
origin=CompilationOrigin.QISKIT,
539-
pass_type=PassType.MAPPING,
540-
transpile_pass=lambda device: [
541-
DenseLayout(coupling_map=CouplingMap(device.build_coupling_map())),
542-
FullAncillaAllocation(coupling_map=CouplingMap(device.build_coupling_map())),
543-
EnlargeWithAncilla(),
544-
ApplyLayout(),
545-
BasicSwap(coupling_map=CouplingMap(device.build_coupling_map())),
546-
],
547-
)
548-
)
549529

550530
register_action(
551531
DeviceDependentAction(
@@ -579,45 +559,11 @@ def get_openqasm_gates() -> list[str]:
579559
FullAncillaAllocation(coupling_map=CouplingMap(device.build_coupling_map())),
580560
EnlargeWithAncilla(),
581561
ApplyLayout(),
582-
SafeAIRouting(coupling_map=device.build_coupling_map(), optimization_level=3, layout_mode="optimize"),
562+
SafeAIRouting(coupling_map=device.build_coupling_map(), optimization_level=3, layout_mode="improve"),
583563
],
584564
)
585565
)
586566

587-
register_action(
588-
DeviceDependentAction(
589-
name="VF2Layout+BasicSwap",
590-
origin=CompilationOrigin.QISKIT,
591-
pass_type=PassType.MAPPING,
592-
transpile_pass=lambda device: [
593-
VF2Layout(
594-
coupling_map=CouplingMap(device.build_coupling_map()),
595-
target=device,
596-
),
597-
ConditionalController(
598-
[
599-
FullAncillaAllocation(coupling_map=CouplingMap(device.build_coupling_map())),
600-
EnlargeWithAncilla(),
601-
ApplyLayout(),
602-
],
603-
condition=lambda property_set: property_set["VF2Layout_stop_reason"]
604-
== VF2LayoutStopReason.SOLUTION_FOUND,
605-
),
606-
ConditionalController(
607-
[
608-
TrivialLayout(coupling_map=CouplingMap(device.build_coupling_map())),
609-
FullAncillaAllocation(coupling_map=CouplingMap(device.build_coupling_map())),
610-
EnlargeWithAncilla(),
611-
ApplyLayout(),
612-
],
613-
# Run if VF2Layout did not find a solution
614-
condition=lambda property_set: property_set["VF2Layout_stop_reason"]
615-
!= VF2LayoutStopReason.SOLUTION_FOUND,
616-
),
617-
BasicSwap(coupling_map=CouplingMap(device.build_coupling_map())),
618-
],
619-
)
620-
)
621567

622568
register_action(
623569
DeviceDependentAction(
@@ -691,7 +637,7 @@ def get_openqasm_gates() -> list[str]:
691637
condition=lambda property_set: property_set["VF2Layout_stop_reason"]
692638
!= VF2LayoutStopReason.SOLUTION_FOUND,
693639
),
694-
SafeAIRouting(coupling_map=device.build_coupling_map(), optimization_level=3, layout_mode="optimize"),
640+
SafeAIRouting(coupling_map=device.build_coupling_map(), optimization_level=3, layout_mode="improve"),
695641
],
696642
)
697643
)

src/mqt/predictor/rl/predictorenv.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ def step(self, action: int) -> tuple[dict[str, Any], float, bool, bool, dict[Any
188188
RuntimeError: If no valid actions are left.
189189
"""
190190
self.used_actions.append(str(self.action_set[action].name))
191+
logger.info(f"Applying: {self.action_set[action].name!s}")
191192
altered_qc = self.apply_action(action)
192193
if not altered_qc:
193194
return (
@@ -377,9 +378,9 @@ def metric_fn(circ: QuantumCircuit) -> float:
377378
pm = PassManager(transpile_pass)
378379
altered_qc = pm.run(self.state)
379380
pm_property_set = dict(pm.property_set) if hasattr(pm, "property_set") else {}
380-
if action_index in (self.actions_mapping_indices + self.actions_final_optimization_indices):
381-
pm_property_set = dict(pm.property_set)
382-
altered_qc = self._handle_qiskit_layout_postprocessing(action, pm_property_set, altered_qc)
381+
if action_index in (self.actions_mapping_indices + self.actions_final_optimization_indices):
382+
pm_property_set = dict(pm.property_set)
383+
altered_qc = self._handle_qiskit_layout_postprocessing(action, pm_property_set, altered_qc)
383384

384385
return altered_qc
385386

@@ -466,7 +467,7 @@ def _apply_bqskit_action(self, action: Action, action_index: int) -> QuantumCirc
466467

467468
def determine_valid_actions_for_state(self) -> list[int]:
468469
"""Determines and returns the valid actions for the current state."""
469-
check_nat_gates = GatesInBasis(target=self.device)
470+
check_nat_gates = GatesInBasis(basis_gates=self.device.operation_names)
470471
check_nat_gates(self.state)
471472
only_nat_gates = check_nat_gates.property_set["all_gates_in_basis"]
472473

@@ -477,7 +478,7 @@ def determine_valid_actions_for_state(self) -> list[int]:
477478
if not only_nat_gates: # not native gates yet
478479
return self.actions_synthesis_indices + self.actions_opt_indices
479480

480-
if mapped and self.layout is not None: # The circuit is correctly mapped.
481+
if mapped and self.layout is not None: # The circuit is correctly mapped
481482
return [self.action_terminate_index, *self.actions_opt_indices, *self.actions_final_optimization_indices]
482483
# The circuit is not mapped yet
483484
# Or the circuit was mapped but some optimization actions change its structure and the circuit is again unmapped

tests/compilation/test_predictor_rl.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_predictor_env_reset_from_string() -> None:
3737
device = get_device("ibm_eagle_127")
3838
predictor = Predictor(figure_of_merit="expected_fidelity", device=device)
3939
qasm_path = Path("test.qasm")
40-
qc = get_benchmark("dj", BenchmarkLevel.ALG, 3)
40+
qc = get_benchmark("dj", BenchmarkLevel.INDEP, 3)
4141
with qasm_path.open("w", encoding="utf-8") as f:
4242
dump(qc, f)
4343
assert predictor.env.reset(qc=qasm_path)[0] == create_feature_dict(qc)
@@ -69,7 +69,7 @@ def test_qcompile_with_newly_trained_models() -> None:
6969
"""
7070
figure_of_merit = "expected_fidelity"
7171
device = get_device("ibm_falcon_127")
72-
qc = get_benchmark("ghz", BenchmarkLevel.ALG, 3)
72+
qc = get_benchmark("ghz", BenchmarkLevel.INDEP, 3)
7373
predictor = Predictor(figure_of_merit=figure_of_merit, device=device)
7474

7575
model_name = "model_" + figure_of_merit + "_" + device.description
@@ -95,7 +95,7 @@ def test_qcompile_with_newly_trained_models() -> None:
9595

9696
def test_qcompile_with_false_input() -> None:
9797
"""Test the qcompile function with false input."""
98-
qc = get_benchmark("dj", BenchmarkLevel.ALG, 5)
98+
qc = get_benchmark("dj", BenchmarkLevel.INDEP, 5)
9999
with pytest.raises(ValueError, match=re.escape("figure_of_merit must not be None if predictor_singleton is None.")):
100100
rl_compile(qc, device=get_device("quantinuum_h2_56"), figure_of_merit=None)
101101
with pytest.raises(ValueError, match=re.escape("device must not be None if predictor_singleton is None.")):

0 commit comments

Comments
 (0)