diff --git a/bindings/InterfaceBindings.cpp b/bindings/InterfaceBindings.cpp index 1d636a45..5896f349 100644 --- a/bindings/InterfaceBindings.cpp +++ b/bindings/InterfaceBindings.cpp @@ -21,12 +21,15 @@ #include #include #include +#include +#include #include #include // NOLINT(misc-include-cleaner) #include // NOLINT(misc-include-cleaner) #include // NOLINT(misc-include-cleaner) #include #include +#include #include #include @@ -35,6 +38,11 @@ using namespace nb::literals; namespace { +size_t boundedStrnlen(const char* data, size_t max) { + const auto* end = static_cast(std::memchr(data, '\0', max)); + return end != nullptr ? static_cast(end - data) : max; +} + /** * @brief Checks whether the given result is OK, and throws a runtime_error * otherwise. @@ -62,6 +70,44 @@ struct StatevectorCPP { // NOLINTNEXTLINE(misc-use-internal-linkage) void bindFramework(nb::module_& m) { + // Bind the Result enum + nb::enum_(m, "Result", "Represents the result of an operation.") + .value("OK", OK, "Indicates that the operation was successful.") + .value("ERROR", ERROR, "Indicates that an error occurred."); + + // Bind the LoadResultStatus enum + nb::enum_( + m, "LoadResultStatus", + "Represents the result of a code loading operation.") + .value("OK", LOAD_OK, "Indicates that the code was loaded successfully.") + .value("PARSE_ERROR", LOAD_PARSE_ERROR, + "Indicates that the code could not be parsed.") + .value("INTERNAL_ERROR", LOAD_INTERNAL_ERROR, + "Indicates that an internal error occurred while loading."); + + // Bind the LoadResult struct + nb::class_(m, "LoadResult") + .def(nb::init<>()) + .def_rw("status", &LoadResult::status, + "Indicates whether the load was successful and why it failed.") + .def_rw("line", &LoadResult::line, + "The line number of the error location, or 0 if unknown.") + .def_rw("column", &LoadResult::column, + "The column number of the error location, or 0 if unknown.") + .def_prop_ro( + "message", + [](const LoadResult& self) { + const auto* data = std::data(self.message); + const std::string_view messageView( + data, boundedStrnlen(data, LOAD_RESULT_MESSAGE_MAX)); + if (messageView.empty()) { + return nb::none(); + } + return nb::cast(std::string(messageView)); + }, + "A human-readable error message, or None if none is available.") + .doc() = "The result of a code loading operation."; + // Bind the VariableType enum nb::enum_(m, "VariableType", "The type of a classical variable.") @@ -164,13 +210,16 @@ Contains one element for each of the `num_states` states in the state vector.)") .def( "load_code", [](SimulationState* self, const char* code) { - checkOrThrow(self->loadCode(self, code)); + return self->loadCode(self, code); }, "code"_a, R"(Loads the given code into the simulation state. Args: - code: The code to load.)") + code: The code to load. + +Returns: + LoadResult: The result of the load operation.)") .def( "step_forward", [](SimulationState* self) { checkOrThrow(self->stepForward(self)); }, diff --git a/include/backend/dd/DDSimDebug.hpp b/include/backend/dd/DDSimDebug.hpp index 84fc4fcf..94e195f2 100644 --- a/include/backend/dd/DDSimDebug.hpp +++ b/include/backend/dd/DDSimDebug.hpp @@ -277,9 +277,9 @@ Result ddsimInit(SimulationState* self); * @brief Loads the given code into the simulation state. * @param self The instance to load the code into. * @param code The code to load. - * @return The result of the operation. + * @return The result of the load operation. */ -Result ddsimLoadCode(SimulationState* self, const char* code); +LoadResult ddsimLoadCode(SimulationState* self, const char* code); /** * @brief Steps the simulation forward by one instruction. * @param self The instance to step forward. diff --git a/include/backend/debug.h b/include/backend/debug.h index 2fd7c21f..b3840132 100644 --- a/include/backend/debug.h +++ b/include/backend/debug.h @@ -50,9 +50,9 @@ struct SimulationStateStruct { * @brief Loads the given code into the simulation state. * @param self The instance to load the code into. * @param code The code to load. - * @return The result of the operation. + * @return The result of the load operation. */ - Result (*loadCode)(SimulationState* self, const char* code); + LoadResult (*loadCode)(SimulationState* self, const char* code); /** * @brief Steps the simulation forward by one instruction. diff --git a/include/common.h b/include/common.h index c31e2ff7..247e1e1a 100644 --- a/include/common.h +++ b/include/common.h @@ -42,6 +42,52 @@ typedef enum { ERROR, } Result; +/** + * @brief The result of a code loading operation. + */ +typedef enum { + /** + * @brief Indicates that the code was loaded successfully. + */ + LOAD_OK, + /** + * @brief Indicates that the code could not be parsed. + */ + LOAD_PARSE_ERROR, + /** + * @brief Indicates that an internal error occurred while loading the code. + */ + LOAD_INTERNAL_ERROR, +} LoadResultStatus; + +/** + * @brief Maximum length of a load error message (including null terminator). + */ +#define LOAD_RESULT_MESSAGE_MAX 1024 + +/** + * @brief The result of a code loading operation. + */ +typedef struct { + /** + * @brief Indicates whether the load was successful and why it failed. + */ + LoadResultStatus status; + /** + * @brief The line number of the error location, or 0 if unknown. + */ + size_t line; + /** + * @brief The column number of the error location, or 0 if unknown. + */ + size_t column; + /** + * @brief A human-readable error message, or empty string if none is + * available. + */ + char message[LOAD_RESULT_MESSAGE_MAX]; +} LoadResult; + /** * @brief The type of classical variables. * diff --git a/include/common/parsing/CodePreprocessing.hpp b/include/common/parsing/CodePreprocessing.hpp index fe2f94b6..f777d234 100644 --- a/include/common/parsing/CodePreprocessing.hpp +++ b/include/common/parsing/CodePreprocessing.hpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -211,6 +212,24 @@ struct ClassicControlledGate { std::vector operations; }; +/** + * @brief Represents a parsed classic-controlled condition. + */ +struct ClassicCondition { + /** + * @brief The name of the classical register. + */ + std::string registerName; + /** + * @brief Optional bit index if the condition targets a single bit. + */ + std::optional bitIndex; + /** + * @brief The expected value in the condition comparison. + */ + size_t expectedValue; +}; + /** * @brief Represents a function definition in the code. */ @@ -286,6 +305,22 @@ bool isClassicControlledGate(const std::string& line); */ ClassicControlledGate parseClassicControlledGate(const std::string& code); +/** + * @brief Parse a classic-controlled condition expression. + * @param condition The condition string to parse. + * @return The parsed condition, or std::nullopt if it cannot be parsed. + */ +std::optional +parseClassicConditionExpression(const std::string& condition); + +/** + * @brief Parse a classic-controlled condition from a classic-controlled gate. + * @param code The code to parse. + * @return The parsed condition, or std::nullopt if it cannot be parsed. + */ +std::optional +parseClassicConditionFromCode(const std::string& code); + /** * @brief Check if a given line is a variable declaration. * diff --git a/include/common/parsing/ParsingError.hpp b/include/common/parsing/ParsingError.hpp index 68aa3459..7078d8f8 100644 --- a/include/common/parsing/ParsingError.hpp +++ b/include/common/parsing/ParsingError.hpp @@ -15,6 +15,7 @@ #pragma once +#include #include #include @@ -30,6 +31,44 @@ class ParsingError : public std::runtime_error { * @param msg The error message. */ explicit ParsingError(const std::string& msg); + + /** + * @brief Constructs a new ParsingError with location information. + * @param line The one-based line number, or 0 if unknown. + * @param column The one-based column number, or 0 if unknown. + * @param detail The error detail message. + */ + ParsingError(size_t line, size_t column, std::string detail); + + /** + * @brief Constructs a new ParsingError with location information and message. + * @param line The one-based line number, or 0 if unknown. + * @param column The one-based column number, or 0 if unknown. + * @param detail The error detail message. + * @param message The formatted error message. + */ + ParsingError(size_t line, size_t column, std::string detail, + const std::string& message); + + /** + * @brief Gets the line number of the error location, or 0 if unknown. + */ + size_t line() const noexcept; + + /** + * @brief Gets the column number of the error location, or 0 if unknown. + */ + size_t column() const noexcept; + + /** + * @brief Gets the error detail message. + */ + const std::string& detail() const noexcept; + +private: + size_t line_ = 0; + size_t column_ = 0; + std::string detail_; }; } // namespace mqt::debugger diff --git a/python/mqt/debugger/__init__.py b/python/mqt/debugger/__init__.py index 75ac8d43..7705d2a7 100644 --- a/python/mqt/debugger/__init__.py +++ b/python/mqt/debugger/__init__.py @@ -18,6 +18,9 @@ Diagnostics, ErrorCause, ErrorCauseType, + LoadResult, + LoadResultStatus, + Result, SimulationState, Statevector, Variable, @@ -33,6 +36,9 @@ "Diagnostics", "ErrorCause", "ErrorCauseType", + "LoadResult", + "LoadResultStatus", + "Result", "SimulationState", "Statevector", "Variable", diff --git a/python/mqt/debugger/check/run_preparation.py b/python/mqt/debugger/check/run_preparation.py index d3468b84..bbd2e772 100644 --- a/python/mqt/debugger/check/run_preparation.py +++ b/python/mqt/debugger/check/run_preparation.py @@ -42,22 +42,32 @@ def start_compilation(code: Path, output_dir: Path) -> None: output_dir (Path): The directory to store the compiled slices. """ state = dbg.create_ddsim_simulation_state() - with code.open("r", encoding="utf-8") as f: - code_str = f.read() - state.load_code(code_str) - i = 0 - while True: - i += 1 - settings = dbg.CompilationSettings( - opt=0, - slice_index=i - 1, - ) - compiled = state.compile(settings) - if not compiled: - break - with (output_dir / f"slice_{i}.qasm").open("w") as f: - f.write(compiled) - dbg.destroy_ddsim_simulation_state(state) + try: + with code.open("r", encoding="utf-8") as f: + code_str = f.read() + load_result = state.load_code(code_str) + if load_result.status != dbg.LoadResultStatus.OK: + message = load_result.message or "Error loading code" + raise RuntimeError(message) + i = 0 + compiled_any = False + while True: + i += 1 + settings = dbg.CompilationSettings( + opt=0, + slice_index=i - 1, + ) + compiled = state.compile(settings) + if not compiled: + break + compiled_any = True + with (output_dir / f"slice_{i}.qasm").open("w") as f: + f.write(compiled) + if not compiled_any: + msg = "No compiled slices produced; check input code for validity." + raise RuntimeError(msg) + finally: + dbg.destroy_ddsim_simulation_state(state) # ------------------------- diff --git a/python/mqt/debugger/dap/dap_server.py b/python/mqt/debugger/dap/dap_server.py index 661bb017..90ef27dd 100644 --- a/python/mqt/debugger/dap/dap_server.py +++ b/python/mqt/debugger/dap/dap_server.py @@ -13,7 +13,7 @@ import json import socket import sys -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any import mqt.debugger @@ -114,6 +114,10 @@ def __init__(self, host: str = "127.0.0.1", port: int = 4711) -> None: self.simulation_state = mqt.debugger.SimulationState() self.lines_start_at_one = True self.columns_start_at_one = True + self.pending_highlights: list[dict[str, Any]] = [] + self.source_file = {"name": "", "path": ""} + self.source_code = "" + self._prevent_exit = False def start(self) -> None: """Start the DAP server and listen for one connection.""" @@ -166,12 +170,32 @@ def handle_client(self, connection: socket.socket) -> None: result, cmd = self.handle_command(payload) result_payload = json.dumps(result) send_message(result_payload, connection) + if isinstance( + cmd, + ( + mqt.debugger.dap.messages.NextDAPMessage, + mqt.debugger.dap.messages.StepBackDAPMessage, + mqt.debugger.dap.messages.StepInDAPMessage, + mqt.debugger.dap.messages.StepOutDAPMessage, + mqt.debugger.dap.messages.ContinueDAPMessage, + mqt.debugger.dap.messages.ReverseContinueDAPMessage, + mqt.debugger.dap.messages.RestartFrameDAPMessage, + mqt.debugger.dap.messages.RestartDAPMessage, + mqt.debugger.dap.messages.LaunchDAPMessage, + ), + ): + self._prevent_exit = False e: mqt.debugger.dap.messages.DAPEvent | None = None if isinstance(cmd, mqt.debugger.dap.messages.LaunchDAPMessage): e = mqt.debugger.dap.messages.InitializedDAPEvent() event_payload = json.dumps(e.encode()) send_message(event_payload, connection) + if isinstance( + cmd, (mqt.debugger.dap.messages.LaunchDAPMessage, mqt.debugger.dap.messages.RestartDAPMessage) + ): + clear_event = mqt.debugger.dap.messages.GrayOutDAPEvent([], self.source_file) + send_message(json.dumps(clear_event.encode()), connection) if ( isinstance( cmd, (mqt.debugger.dap.messages.LaunchDAPMessage, mqt.debugger.dap.messages.RestartDAPMessage) @@ -236,6 +260,18 @@ def handle_client(self, connection: socket.socket) -> None: ) event_payload = json.dumps(e.encode()) send_message(event_payload, connection) + if self.pending_highlights: + try: + highlight_event = mqt.debugger.dap.messages.HighlightError( + self.pending_highlights, + self.source_file, + ) + send_message(json.dumps(highlight_event.encode()), connection) + self._prevent_exit = True + except (TypeError, ValueError): + pass + finally: + self.pending_highlights = [] self.regular_checks(connection) def regular_checks(self, connection: socket.socket) -> None: @@ -245,7 +281,11 @@ def regular_checks(self, connection: socket.socket) -> None: connection (socket.socket): The client socket. """ e: mqt.debugger.dap.messages.DAPEvent | None = None - if self.simulation_state.is_finished() and self.simulation_state.get_instruction_count() != 0: + if ( + self.simulation_state.is_finished() + and self.simulation_state.get_instruction_count() != 0 + and not self._prevent_exit + ): e = mqt.debugger.dap.messages.ExitedDAPEvent(0) event_payload = json.dumps(e.encode()) send_message(event_payload, connection) @@ -325,7 +365,16 @@ def handle_assertion_fail(self, connection: socket.socket) -> None: line, column, connection, + "stderr", ) + highlight_entries = self.collect_highlight_entries(current_instruction, error_causes) + if highlight_entries: + try: + highlight_event = mqt.debugger.dap.messages.HighlightError(highlight_entries, self.source_file) + send_message(json.dumps(highlight_event.encode()), connection) + self._prevent_exit = True + except (TypeError, ValueError): + pass def code_pos_to_coordinates(self, pos: int) -> tuple[int, int]: """Helper method to convert a code position to line and column. @@ -337,14 +386,18 @@ def code_pos_to_coordinates(self, pos: int) -> tuple[int, int]: The line and column, 0-or-1-indexed. """ lines = self.source_code.split("\n") - line = 0 + line = 1 if lines else 0 col = 0 for i, line_code in enumerate(lines): - if pos < len(line_code): + if pos <= len(line_code): line = i + 1 col = pos break pos -= len(line_code) + 1 + else: + if lines: + line = len(lines) + col = len(lines[-1]) if self.columns_start_at_one: col += 1 if not self.lines_start_at_one: @@ -391,8 +444,167 @@ def format_error_cause(self, cause: mqt.debugger.ErrorCause) -> str: else "" ) + def collect_highlight_entries( + self, + failing_instruction: int, + error_causes: list[mqt.debugger.ErrorCause] | None = None, + ) -> list[dict[str, Any]]: + """Collect highlight entries for the current assertion failure.""" + highlights: list[dict[str, Any]] = [] + if self.source_code: + try: + if error_causes is None: + diagnostics = self.simulation_state.get_diagnostics() + error_causes = diagnostics.potential_error_causes() + except RuntimeError: + error_causes = [] + + for cause in error_causes: + message = self.format_error_cause(cause) + reason = self._format_highlight_reason(cause.type_) + entry = self._build_highlight_entry(cause.instruction, reason, message) + if entry is not None: + highlights.append(entry) + + if not highlights: + entry = self._build_highlight_entry( + failing_instruction, + mqt.debugger.dap.messages.HighlightReason.ASSERTION_FAILED, + "Assertion failed at this instruction.", + ) + if entry is not None: + highlights.append(entry) + + return highlights + + def _build_highlight_entry( + self, + instruction: int, + reason: mqt.debugger.dap.messages.HighlightReason, + message: str, + ) -> dict[str, Any] | None: + """Create a highlight entry for a specific instruction.""" + try: + start_pos, end_pos = self.simulation_state.get_instruction_position(instruction) + except RuntimeError: + return None + start_line, start_column = self.code_pos_to_coordinates(start_pos) + if end_pos < len(self.source_code) and self.source_code[end_pos] == "\n": + end_position_exclusive = end_pos + else: + end_position_exclusive = min(len(self.source_code), end_pos + 1) + end_line, end_column = self.code_pos_to_coordinates(end_position_exclusive) + snippet = self.source_code[start_pos : end_pos + 1].replace("\r", "") + return { + "instruction": int(instruction), + "range": { + "start": {"line": start_line, "column": start_column}, + "end": {"line": end_line, "column": end_column}, + }, + "reason": reason, + "code": snippet.strip(), + "message": message, + } + + @staticmethod + def _format_highlight_reason( + cause_type: mqt.debugger.ErrorCauseType | None, + ) -> mqt.debugger.dap.messages.HighlightReason: + """Return a short identifier for the highlight reason.""" + if cause_type == mqt.debugger.ErrorCauseType.MissingInteraction: + return mqt.debugger.dap.messages.HighlightReason.MISSING_INTERACTION + if cause_type == mqt.debugger.ErrorCauseType.ControlAlwaysZero: + return mqt.debugger.dap.messages.HighlightReason.CONTROL_ALWAYS_ZERO + return mqt.debugger.dap.messages.HighlightReason.UNKNOWN + + def queue_parse_error( + self, + error_message: str, + line: int | None = None, + column: int | None = None, + ) -> None: + """Store highlight data for a parse error to be emitted later.""" + detail = error_message.strip() + if not detail: + detail = "An error occurred while parsing the code." + if line is None or column is None: + line = 1 + column = 1 + entry = self._build_parse_error_highlight(line, column, detail) + if entry is not None: + self.pending_highlights = [entry] + + def _build_parse_error_highlight(self, line: int, column: int, detail: str) -> dict[str, Any] | None: + """Create a highlight entry for a parse error.""" + if not self.source_code: + return None + lines = self.source_code.split("\n") + if not lines: + return None + line = max(1, min(line, len(lines))) + column = max(1, column) + line_index = line - 1 + line_text = lines[line_index] + + if column <= 1 and line_index > 0 and not line_text.strip(): + prev_index = line_index - 1 + while prev_index >= 0 and not lines[prev_index].strip(): + prev_index -= 1 + if prev_index >= 0: + line_index = prev_index + line = line_index + 1 + line_text = lines[line_index] + stripped = line_text.lstrip() + column = max(1, len(line_text) - len(stripped) + 1) if stripped else 1 + + # Clamp to end-of-line to keep columns within bounds while preserving end >= start. + max_column = len(line_text) + 1 + column = min(column, max_column) + end_column = max_column + snippet = line_text.strip() or line_text + return { + "instruction": -1, + "range": { + "start": {"line": line, "column": column}, + "end": {"line": line, "column": end_column if end_column > 0 else column}, + }, + "reason": mqt.debugger.dap.messages.HighlightReason.PARSE_ERROR, + "code": snippet, + "message": detail, + } + + def _flatten_message_parts(self, parts: list[Any]) -> list[str]: + """Flatten nested message structures into plain text lines.""" + flattened: list[str] = [] + for part in parts: + if isinstance(part, str): + if part: + flattened.append(part) + elif isinstance(part, dict): + title = part.get("title") + if isinstance(title, str) and title: + flattened.append(title) + body = part.get("body") + if isinstance(body, list): + flattened.extend(self._flatten_message_parts(body)) + elif isinstance(body, str) and body: + flattened.append(body) + end = part.get("end") + if isinstance(end, str) and end: + flattened.append(end) + elif isinstance(part, list): + flattened.extend(self._flatten_message_parts(part)) + elif part is not None: + flattened.append(str(part)) + return flattened + def send_message_hierarchy( - self, message: dict[str, str | list[Any] | dict[str, Any]], line: int, column: int, connection: socket.socket + self, + message: dict[str, str | list[Any] | dict[str, Any]], + line: int, + column: int, + connection: socket.socket, + category: str = "console", ) -> None: """Send a hierarchy of messages to the client. @@ -401,34 +613,56 @@ def send_message_hierarchy( line: The line number. column: The column number. connection: The client socket. + category: The output category (console/stdout/stderr). """ - if "title" in message: - title_event = mqt.debugger.dap.messages.OutputDAPEvent( - "console", cast("str", message["title"]), "start", line, column, self.source_file - ) - send_message(json.dumps(title_event.encode()), connection) - - if "body" in message: - body = message["body"] - if isinstance(body, list): - for msg in body: - if isinstance(msg, dict): - self.send_message_hierarchy(msg, line, column, connection) - else: - output_event = mqt.debugger.dap.messages.OutputDAPEvent( - "console", msg, None, line, column, self.source_file - ) - send_message(json.dumps(output_event.encode()), connection) - elif isinstance(body, dict): - self.send_message_hierarchy(body, line, column, connection) - elif isinstance(body, str): - output_event = mqt.debugger.dap.messages.OutputDAPEvent( - "console", body, None, line, column, self.source_file - ) - send_message(json.dumps(output_event.encode()), connection) + raw_body = message.get("body") + body: list[str] | None = None + if isinstance(raw_body, list): + body = self._flatten_message_parts(raw_body) + elif isinstance(raw_body, str): + body = [raw_body] + end_value = message.get("end") + end = end_value if isinstance(end_value, str) else None + title = str(message.get("title", "")) + self.send_message_simple(title, body, end, line, column, connection, category) + + def send_message_simple( + self, + title: str, + body: list[str] | None, + end: str | None, + line: int, + column: int, + connection: socket.socket, + category: str = "console", + ) -> None: + """Send a simple message to the client. - if "end" in message or "title" in message: - end_event = mqt.debugger.dap.messages.OutputDAPEvent( - "console", cast("str", message.get("end")), "end", line, column, self.source_file - ) - send_message(json.dumps(end_event.encode()), connection) + Args: + title (str): The title of the message. + body (list[str]): The body of the message. + end (str | None): The end of the message. + line (int): The line number. + column (int): The column number. + connection (socket.socket): The client socket. + category (str): The output category (console/stdout/stderr). + """ + segments: list[str] = [] + if title: + segments.append(title) + if body: + segments.extend(body) + if end: + segments.append(end) + if not segments: + return + output_text = "\n".join(segments) + event = mqt.debugger.dap.messages.OutputDAPEvent( + category, + output_text, + None, + line, + column, + self.source_file, + ) + send_message(json.dumps(event.encode()), connection) diff --git a/python/mqt/debugger/dap/messages/__init__.py b/python/mqt/debugger/dap/messages/__init__.py index c905cfef..d1e0ed19 100644 --- a/python/mqt/debugger/dap/messages/__init__.py +++ b/python/mqt/debugger/dap/messages/__init__.py @@ -21,6 +21,7 @@ from .exception_info_message import ExceptionInfoDAPMessage from .exited_dap_event import ExitedDAPEvent from .gray_out_event import GrayOutDAPEvent +from .highlight_error_dap_message import HighlightError, HighlightReason from .initialize_dap_message import InitializeDAPMessage from .initialized_dap_event import InitializedDAPEvent from .launch_dap_message import LaunchDAPMessage @@ -55,6 +56,8 @@ "ExceptionInfoDAPMessage", "ExitedDAPEvent", "GrayOutDAPEvent", + "HighlightError", + "HighlightReason", "InitializeDAPMessage", "InitializedDAPEvent", "LaunchDAPMessage", diff --git a/python/mqt/debugger/dap/messages/highlight_error_dap_message.py b/python/mqt/debugger/dap/messages/highlight_error_dap_message.py new file mode 100644 index 00000000..d6525c1e --- /dev/null +++ b/python/mqt/debugger/dap/messages/highlight_error_dap_message.py @@ -0,0 +1,195 @@ +# Copyright (c) 2024 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +"""Represents the custom 'highlightError' DAP event.""" + +from __future__ import annotations + +import enum +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any + +from .dap_event import DAPEvent + + +class HighlightReason(enum.Enum): + """Represents the reason for highlighting a range.""" + + MISSING_INTERACTION = "missingInteraction" + CONTROL_ALWAYS_ZERO = "controlAlwaysZero" + ASSERTION_FAILED = "assertionFailed" + PARSE_ERROR = "parseError" + UNKNOWN = "unknown" + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +class HighlightError(DAPEvent): + """Represents the 'highlightError' custom DAP event. + + Attributes: + event_name (str): DAP event identifier emitted by this message. + highlights (list[dict[str, Any]]): Normalized highlight entries with ranges and metadata. + source (dict[str, Any]): Normalized DAP source information for the highlighted file. + """ + + event_name = "highlightError" + + highlights: list[dict[str, Any]] + source: dict[str, Any] + + def __init__(self, highlights: Sequence[Mapping[str, Any]], source: Mapping[str, Any]) -> None: + """Create a new 'highlightError' DAP event message. + + Args: + highlights (Sequence[Mapping[str, Any]]): Highlight entries describing the problematic ranges. + source (Mapping[str, Any]): Information about the current source file. + """ + self.highlights = [self._normalize_highlight(entry) for entry in highlights] + self.source = self._normalize_source(source) + super().__init__() + + def validate(self) -> None: + """Validate the 'highlightError' DAP event message after creation. + + Raises: + ValueError: If required highlight fields are missing or empty. + """ + if not self.highlights: + msg = "At least one highlight entry is required to show the issue location." + raise ValueError(msg) + + for highlight in self.highlights: + if "message" not in highlight or not str(highlight["message"]).strip(): + msg = "Each highlight entry must contain a descriptive 'message'." + raise ValueError(msg) + + def encode(self) -> dict[str, Any]: + """Encode the 'highlightError' DAP event message as a dictionary. + + Returns: + dict[str, Any]: The encoded DAP event payload. + """ + encoded = super().encode() + encoded["body"] = {"highlights": self.highlights, "source": self.source} + return encoded + + @staticmethod + def _normalize_highlight(entry: Mapping[str, Any]) -> dict[str, Any]: + """Return a shallow copy of a highlight entry with guaranteed structure. + + Args: + entry (Mapping[str, Any]): Highlight metadata including a range mapping. + + Returns: + dict[str, Any]: A normalized highlight entry suitable for serialization. + + Raises: + TypeError: If the range mapping or its positions are not mappings. + ValueError: If required fields are missing or malformed. + """ + if "range" not in entry: + msg = "A highlight entry must contain a 'range'." + raise ValueError(msg) + highlight_range = entry["range"] + if not isinstance(highlight_range, Mapping): + msg = "Highlight range must be a mapping with 'start' and 'end'." + raise TypeError(msg) + + start = HighlightError._normalize_position(highlight_range.get("start")) + end = HighlightError._normalize_position(highlight_range.get("end")) + if HighlightError._start_comes_after_end(start, end): + msg = "Highlight range 'end' must be after 'start'." + raise ValueError(msg) + + normalized = dict(entry) + normalized["instruction"] = int(normalized.get("instruction", -1)) + reason = normalized.get("reason", HighlightReason.UNKNOWN) + if isinstance(reason, HighlightReason): + normalized["reason"] = reason.value + else: + normalized["reason"] = str(reason) + normalized["code"] = str(normalized.get("code", "")) + normalized["message"] = str(normalized.get("message", "")).strip() + normalized["range"] = { + "start": start, + "end": end, + } + return normalized + + @staticmethod + def _normalize_position(position: Mapping[str, Any] | None) -> dict[str, int]: + """Normalize a position mapping, ensuring it includes a line and column. + + Args: + position (Mapping[str, Any] | None): The position mapping to normalize. + + Returns: + dict[str, int]: A normalized position with integer line and column. + + Raises: + TypeError: If the provided position is not a mapping. + ValueError: If required keys are missing. + """ + if not isinstance(position, Mapping): + msg = "Highlight positions must be mappings with 'line' and 'column'." + raise TypeError(msg) + try: + line = int(position["line"]) + column = int(position["column"]) + except (KeyError, TypeError, ValueError) as exc: + msg = "Highlight positions require 'line' and 'column'." + raise ValueError(msg) from exc + return { + "line": line, + "column": column, + } + + @staticmethod + def _normalize_source(source: Mapping[str, Any] | None) -> dict[str, Any]: + """Create a defensive copy of the provided DAP Source information. + + Args: + source (Mapping[str, Any] | None): The source mapping to normalize. + + Returns: + dict[str, Any]: Normalized source information with string fields. + + Raises: + TypeError: If the source is not a mapping. + ValueError: If required keys are missing. + """ + if not isinstance(source, Mapping): + msg = "Source information must be provided as a mapping." + raise TypeError(msg) + normalized = dict(source) + if "name" not in normalized or "path" not in normalized: + msg = "Source mappings must at least provide 'name' and 'path'." + raise ValueError(msg) + normalized["name"] = str(normalized["name"]) + normalized["path"] = str(normalized["path"]) + return normalized + + @staticmethod + def _start_comes_after_end(start: Mapping[str, Any], end: Mapping[str, Any]) -> bool: + """Return True if 'start' describes a position after 'end'. + + Args: + start (Mapping[str, Any]): The start position mapping. + end (Mapping[str, Any]): The end position mapping. + + Returns: + bool: True when the start position is after the end position. + """ + start_line = int(start.get("line", 0)) + start_column = int(start.get("column", 0)) + end_line = int(end.get("line", 0)) + end_column = int(end.get("column", 0)) + return (end_line, end_column) < (start_line, start_column) diff --git a/python/mqt/debugger/dap/messages/launch_dap_message.py b/python/mqt/debugger/dap/messages/launch_dap_message.py index be3fd6e3..6830a370 100644 --- a/python/mqt/debugger/dap/messages/launch_dap_message.py +++ b/python/mqt/debugger/dap/messages/launch_dap_message.py @@ -10,10 +10,13 @@ from __future__ import annotations +import contextlib import locale from pathlib import Path from typing import TYPE_CHECKING, Any +import mqt.debugger + from .dap_message import DAPMessage if TYPE_CHECKING: @@ -63,21 +66,22 @@ def handle(self, server: DAPServer) -> dict[str, Any]: dict[str, Any]: The response to the request. """ program_path = Path(self.program) - code = program_path.read_text(encoding=locale.getpreferredencoding(False)) + server.source_file = {"name": program_path.name, "path": self.program} + parsed_successfully = True + code = program_path.read_text(encoding=locale.getpreferredencoding(do_setlocale=False)) server.source_code = code - try: - server.simulation_state.load_code(code) - except RuntimeError: - return { - "type": "response", - "request_seq": self.sequence_number, - "success": False, - "command": "launch", - "message": "An error occurred while parsing the code.", - } - if not self.stop_on_entry: + load_result = server.simulation_state.load_code(code) + if load_result.status != mqt.debugger.LoadResultStatus.OK: + parsed_successfully = False + line = load_result.line if load_result.line > 0 else None + column = load_result.column if load_result.column > 0 else None + message = str(load_result.message or "") + server.queue_parse_error(message, line, column) + if parsed_successfully and not self.stop_on_entry: server.simulation_state.run_simulation() - server.source_file = {"name": program_path.name, "path": self.program} + if not parsed_successfully: + with contextlib.suppress(RuntimeError): + server.simulation_state.reset_simulation() return { "type": "response", "request_seq": self.sequence_number, diff --git a/python/mqt/debugger/dap/messages/restart_dap_message.py b/python/mqt/debugger/dap/messages/restart_dap_message.py index 7411887d..3e563d5d 100644 --- a/python/mqt/debugger/dap/messages/restart_dap_message.py +++ b/python/mqt/debugger/dap/messages/restart_dap_message.py @@ -10,10 +10,13 @@ from __future__ import annotations +import contextlib import locale from pathlib import Path from typing import TYPE_CHECKING, Any +import mqt.debugger + from .dap_message import DAPMessage if TYPE_CHECKING: @@ -64,12 +67,22 @@ def handle(self, server: DAPServer) -> dict[str, Any]: """ server.simulation_state.reset_simulation() program_path = Path(self.program) - code = program_path.read_text(encoding=locale.getpreferredencoding(False)) + server.source_file = {"name": program_path.name, "path": self.program} + parsed_successfully = True + code = program_path.read_text(encoding=locale.getpreferredencoding(do_setlocale=False)) server.source_code = code - server.simulation_state.load_code(code) - if not self.stop_on_entry: + load_result = server.simulation_state.load_code(code) + if load_result.status != mqt.debugger.LoadResultStatus.OK: + parsed_successfully = False + line = load_result.line if load_result.line > 0 else None + column = load_result.column if load_result.column > 0 else None + message = str(load_result.message or "") + server.queue_parse_error(message, line, column) + if parsed_successfully and not self.stop_on_entry: server.simulation_state.run_simulation() - server.source_file = {"name": program_path.name, "path": self.program} + if not parsed_successfully: + with contextlib.suppress(RuntimeError): + server.simulation_state.reset_simulation() return { "type": "response", "request_seq": self.sequence_number, diff --git a/python/mqt/debugger/pydebugger.pyi b/python/mqt/debugger/pydebugger.pyi index cfc4f0e1..30386e54 100644 --- a/python/mqt/debugger/pydebugger.pyi +++ b/python/mqt/debugger/pydebugger.pyi @@ -154,6 +154,53 @@ class Diagnostics: A list of new assertions. """ +class Result(enum.Enum): + """Represents the result of an operation.""" + + OK = 0 + """Indicates that the operation was successful.""" + + ERROR = 1 + """Indicates that an error occurred.""" + +class LoadResultStatus(enum.Enum): + """Represents the result of a code loading operation.""" + + OK = 0 + """Indicates that the code was loaded successfully.""" + + PARSE_ERROR = 1 + """Indicates that the code could not be parsed.""" + + INTERNAL_ERROR = 2 + """Indicates that an internal error occurred while loading.""" + +class LoadResult: + """The result of a code loading operation.""" + + def __init__(self) -> None: ... + @property + def status(self) -> LoadResultStatus: + """Indicates whether the load was successful and why it failed.""" + + @status.setter + def status(self, arg: LoadResultStatus, /) -> None: ... + @property + def line(self) -> int: + """The line number of the error location, or 0 if unknown.""" + + @line.setter + def line(self, arg: int, /) -> None: ... + @property + def column(self) -> int: + """The column number of the error location, or 0 if unknown.""" + + @column.setter + def column(self, arg: int, /) -> None: ... + @property + def message(self) -> object: + """A human-readable error message, or None if none is available.""" + class VariableType(enum.Enum): """The type of a classical variable.""" @@ -313,11 +360,14 @@ class SimulationState: def init(self) -> None: """Initializes the simulation state.""" - def load_code(self, code: str) -> None: + def load_code(self, code: str) -> LoadResult: """Loads the given code into the simulation state. Args: code: The code to load. + + Returns: + LoadResult: The result of the load operation. """ def step_forward(self) -> None: diff --git a/src/backend/dd/DDSimDebug.cpp b/src/backend/dd/DDSimDebug.cpp index c807ad6a..1bed2617 100644 --- a/src/backend/dd/DDSimDebug.cpp +++ b/src/backend/dd/DDSimDebug.cpp @@ -25,6 +25,7 @@ #include "common/parsing/AssertionParsing.hpp" #include "common/parsing/AssertionTools.hpp" #include "common/parsing/CodePreprocessing.hpp" +#include "common/parsing/ParsingError.hpp" #include "common/parsing/Utils.hpp" #include "dd/DDDefinitions.hpp" #include "dd/Operations.hpp" @@ -48,12 +49,14 @@ #include #include #include +#include #include #include #include #include #include #include +#include #include #include @@ -61,6 +64,11 @@ namespace mqt::debugger { namespace { +size_t boundedStrnlen(const char* data, size_t max) { + const auto* end = static_cast(std::memchr(data, '\0', max)); + return end != nullptr ? static_cast(end - data) : max; +} + /** * @brief Cast a `SimulationState` pointer to a `DDSimulationState` pointer. * @@ -75,6 +83,59 @@ DDSimulationState* toDDSimulationState(SimulationState* state) { // NOLINTEND(cppcoreguidelines-pro-type-reinterpret-cast) } +struct DDSimulationStateGuard { + explicit DDSimulationStateGuard(DDSimulationState* state) : state(state) {} + DDSimulationStateGuard(const DDSimulationStateGuard&) = delete; + DDSimulationStateGuard& operator=(const DDSimulationStateGuard&) = delete; + ~DDSimulationStateGuard() { + if (state != nullptr) { + destroyDDSimulationState(state); + } + } + DDSimulationState* state; +}; + +/** + * @brief Evaluate a classic-controlled condition from the original code. + * @param ddsim The simulation state. + * @param instructionIndex The instruction index to inspect. + * @return The evaluated condition, or std::nullopt if it cannot be evaluated. + */ +std::optional evaluateClassicConditionFromCode(DDSimulationState* ddsim, + size_t instructionIndex) { + if (instructionIndex >= ddsim->instructionObjects.size()) { + return std::nullopt; + } + const auto& code = ddsim->instructionObjects[instructionIndex].code; + const auto parsed = parseClassicConditionFromCode(code); + if (!parsed.has_value()) { + return std::nullopt; + } + + size_t registerValue = 0; + if (parsed->bitIndex.has_value()) { + const auto bitName = parsed->registerName + "[" + + std::to_string(parsed->bitIndex.value()) + "]"; + const auto& value = ddsim->variables[bitName].value.boolValue; + registerValue = value ? 1ULL : 0ULL; + } else { + const auto regIt = std::ranges::find_if( + ddsim->classicalRegisters, [&parsed](const auto& reg) { + return reg.name == parsed->registerName; + }); + if (regIt == ddsim->classicalRegisters.end()) { + return std::nullopt; + } + for (size_t i = 0; i < regIt->size; i++) { + const auto name = getClassicalBitName(ddsim, regIt->index + i); + const auto& value = ddsim->variables[name].value.boolValue; + registerValue |= (value ? 1ULL : 0ULL) << i; + } + } + + return registerValue == parsed->expectedValue; +} + /** * @brief Generate a random number between 0 and 1. * @@ -251,11 +312,23 @@ bool checkAssertionEqualityCircuit( } DDSimulationState secondSimulation; - createDDSimulationState(&secondSimulation); - secondSimulation.interface.loadCode(&secondSimulation.interface, - assertion->getCircuitCode().c_str()); + if (createDDSimulationState(&secondSimulation) == ERROR) { + throw std::runtime_error( + "Failed to initialize simulation for equality assertion."); + } + const DDSimulationStateGuard secondSimulationGuard(&secondSimulation); + const auto loadResult = secondSimulation.interface.loadCode( + &secondSimulation.interface, assertion->getCircuitCode().c_str()); + if (loadResult.status != LOAD_OK) { + const auto* data = std::data(loadResult.message); + const std::string_view messageView( + data, boundedStrnlen(data, LOAD_RESULT_MESSAGE_MAX)); + throw std::runtime_error( + !messageView.empty() + ? std::string(messageView) + : "Failed to load circuit for equality assertion."); + } if (!secondSimulation.assertionInstructions.empty()) { - destroyDDSimulationState(&secondSimulation); throw std::runtime_error( "Circuit equality assertions cannot contain nested assertions"); } @@ -269,7 +342,6 @@ bool checkAssertionEqualityCircuit( sv2.amplitudes = amplitudes2.data(); secondSimulation.interface.getStateVectorFull(&secondSimulation.interface, &sv2); - destroyDDSimulationState(&secondSimulation); Statevector sv; sv.numQubits = qubits.size(); @@ -467,6 +539,27 @@ bool areAssertionsIndependent(DDSimulationState* ddsim, }); } +void setLoadResultMessage(LoadResult& result, const std::string& message) { + result.message[0] = '\0'; + if (message.empty()) { + return; + } + const auto copyLen = std::min( + message.size(), static_cast(LOAD_RESULT_MESSAGE_MAX - 1)); + std::copy_n(message.data(), copyLen, std::begin(result.message)); + result.message[LOAD_RESULT_MESSAGE_MAX - 1] = '\0'; +} + +LoadResult makeLoadResult(LoadResultStatus status, size_t line, size_t column, + const std::string& message) { + LoadResult result{}; + result.status = status; + result.line = line; + result.column = column; + setLoadResultMessage(result, message); + return result; +} + /** * @brief Compile an assertion using projective measurements. * @param ddsim The simulation state. @@ -583,25 +676,46 @@ Result ddsimInit(SimulationState* self) { return OK; } -Result ddsimLoadCode(SimulationState* self, const char* code) { +LoadResult ddsimLoadCode(SimulationState* self, const char* code) { auto* ddsim = toDDSimulationState(self); ddsim->currentInstruction = 0; ddsim->previousInstructionStack.clear(); ddsim->callReturnStack.clear(); ddsim->callSubstitutions.clear(); ddsim->restoreCallReturnStack.clear(); + ddsim->ready = false; ddsim->code = code; ddsim->variables.clear(); ddsim->variableNames.clear(); + ddsim->instructionTypes.clear(); + ddsim->instructionStarts.clear(); + ddsim->instructionEnds.clear(); + ddsim->functionDefinitions.clear(); + ddsim->assertionInstructions.clear(); + ddsim->successorInstructions.clear(); + ddsim->classicalRegisters.clear(); + ddsim->qubitRegisters.clear(); + ddsim->dataDependencies.clear(); + ddsim->functionCallers.clear(); + ddsim->targetQubits.clear(); + ddsim->instructionObjects.clear(); try { std::stringstream ss{preprocessAssertionCode(code, ddsim)}; const auto imported = qasm3::Importer::import(ss); ddsim->qc = std::make_unique(imported); qc::CircuitOptimizer::flattenOperations(*ddsim->qc, true); + } catch (const ParsingError& e) { + return makeLoadResult(LOAD_PARSE_ERROR, e.line(), e.column(), e.detail()); } catch (const std::exception& e) { - std::cerr << e.what() << "\n"; - return ERROR; + std::string message = e.what(); + if (message.empty()) { + message = "An error occurred while executing the operation"; + } + return makeLoadResult(LOAD_INTERNAL_ERROR, 0, 0, message); + } catch (...) { + return makeLoadResult(LOAD_INTERNAL_ERROR, 0, 0, + "An error occurred while executing the operation"); } ddsim->iterator = ddsim->qc->begin(); @@ -613,7 +727,7 @@ Result ddsimLoadCode(SimulationState* self, const char* code) { ddsim->ready = true; - return OK; + return makeLoadResult(LOAD_OK, 0, 0, ""); } Result ddsimChangeClassicalVariableValue(SimulationState* self, @@ -994,20 +1108,31 @@ Result ddsimStepForward(SimulationState* self) { throw std::runtime_error("If-else operations with non-equality " "comparisons are currently not supported"); } - if (op->getControlBit().has_value()) { - throw std::runtime_error("If-else operations controlled by a single " - "classical bit are currently not supported"); - } - const auto& controls = op->getControlRegister(); - const auto& exp = op->getExpectedValueRegister(); - size_t registerValue = 0; - for (size_t i = 0; i < controls->getSize(); i++) { - const auto name = - getClassicalBitName(ddsim, controls->getStartIndex() + i); - const auto& value = ddsim->variables[name].value.boolValue; - registerValue |= (value ? 1ULL : 0ULL) << i; + const auto condition = + evaluateClassicConditionFromCode(ddsim, currentInstruction); + bool conditionMet = false; + if (condition.has_value()) { + conditionMet = condition.value(); + } else { + const auto& exp = op->getExpectedValueRegister(); + size_t registerValue = 0; + if (op->getControlBit().has_value()) { + const auto controlBit = op->getControlBit().value(); + const auto name = getClassicalBitName(ddsim, controlBit); + const auto& value = ddsim->variables[name].value.boolValue; + registerValue = value ? 1ULL : 0ULL; + } else { + const auto& controls = op->getControlRegister(); + for (size_t i = 0; i < controls->getSize(); i++) { + const auto name = + getClassicalBitName(ddsim, controls->getStartIndex() + i); + const auto& value = ddsim->variables[name].value.boolValue; + registerValue |= (value ? 1ULL : 0ULL) << i; + } + } + conditionMet = (registerValue == exp); } - if (registerValue == exp) { + if (conditionMet) { auto* thenOp = op->getThenOp(); currDD = dd::getDD(*thenOp, *ddsim->dd); } else if (op->getElseOp() != nullptr) { @@ -1081,20 +1206,31 @@ Result ddsimStepBackward(SimulationState* self) { throw std::runtime_error("If-else operations with non-equality " "comparisons are currently not supported"); } - if (op->getControlBit().has_value()) { - throw std::runtime_error("If-else operations controlled by a single " - "classical bit are currently not supported"); - } - const auto& controls = op->getControlRegister(); - const auto& exp = op->getExpectedValueRegister(); - size_t registerValue = 0; - for (size_t i = 0; i < controls->getSize(); i++) { - const auto name = - getClassicalBitName(ddsim, controls->getStartIndex() + i); - const auto& value = ddsim->variables[name].value.boolValue; - registerValue |= (value ? 1ULL : 0ULL) << i; + const auto condition = + evaluateClassicConditionFromCode(ddsim, ddsim->currentInstruction); + bool conditionMet = false; + if (condition.has_value()) { + conditionMet = condition.value(); + } else { + const auto& exp = op->getExpectedValueRegister(); + size_t registerValue = 0; + if (op->getControlBit().has_value()) { + const auto controlBit = op->getControlBit().value(); + const auto name = getClassicalBitName(ddsim, controlBit); + const auto& value = ddsim->variables[name].value.boolValue; + registerValue = value ? 1ULL : 0ULL; + } else { + const auto& controls = op->getControlRegister(); + for (size_t i = 0; i < controls->getSize(); i++) { + const auto name = + getClassicalBitName(ddsim, controls->getStartIndex() + i); + const auto& value = ddsim->variables[name].value.boolValue; + registerValue |= (value ? 1ULL : 0ULL) << i; + } + } + conditionMet = (registerValue == exp); } - if (registerValue == exp) { + if (conditionMet) { auto* thenOp = op->getThenOp(); currDD = dd::getInverseDD(*thenOp, *ddsim->dd); } else if (op->getElseOp() != nullptr) { @@ -1119,6 +1255,10 @@ Result ddsimStepBackward(SimulationState* self) { } Result ddsimRunAll(SimulationState* self, size_t* failedAssertions) { + auto* ddsim = toDDSimulationState(self); + if (!ddsim->ready) { + return ERROR; + } size_t errorCount = 0; while (!self->isFinished(self)) { const Result result = self->runSimulation(self); @@ -1129,7 +1269,9 @@ Result ddsimRunAll(SimulationState* self, size_t* failedAssertions) { errorCount++; } } - *failedAssertions = errorCount; + if (failedAssertions != nullptr) { + *failedAssertions = errorCount; + } return OK; } @@ -1389,6 +1531,11 @@ Result ddsimSetBreakpoint(SimulationState* self, size_t desiredPosition, for (auto i = 0ULL; i < ddsim->instructionTypes.size(); i++) { const size_t start = ddsim->instructionStarts[i]; const size_t end = ddsim->instructionEnds[i]; + if (desiredPosition < start) { + *targetInstruction = i; + ddsim->breakpoints.insert(i); + return OK; + } if (desiredPosition >= start && desiredPosition <= end) { if (ddsim->functionDefinitions.contains(i)) { // Breakpoint may be located in a sub-gate of the gate definition. diff --git a/src/common/parsing/CodePreprocessing.cpp b/src/common/parsing/CodePreprocessing.cpp index c0ce03ff..4c0b3112 100644 --- a/src/common/parsing/CodePreprocessing.cpp +++ b/src/common/parsing/CodePreprocessing.cpp @@ -20,11 +20,15 @@ #include "common/parsing/Utils.hpp" #include +#include #include +#include #include #include #include +#include #include +#include #include #include #include @@ -33,6 +37,176 @@ namespace mqt::debugger { namespace { +/** + * @brief Check whether a string is non-empty and contains only digits. + * @param text The string to validate. + * @return True if the string is non-empty and all characters are digits. + */ +bool isDigits(const std::string& text) { + if (text.empty()) { + return false; + } + return std::ranges::all_of( + text, [](unsigned char c) { return std::isdigit(c) != 0; }); +} + +/** + * @brief 1-based line/column location within source text. + */ +struct LineColumn { + size_t line = 1; + size_t column = 1; +}; + +/** + * @brief Compute the 1-based line and column for a given character offset. + * @param code The source code to inspect. + * @param offset The zero-based character offset in the source code. + * @return The line and column of the offset in the source code. + */ +LineColumn lineColumnForOffset(const std::string& code, size_t offset) { + LineColumn location; + const auto lineStartPos = code.rfind('\n', offset); + const size_t lineStart = (lineStartPos == std::string::npos) + ? 0 + : static_cast(lineStartPos + 1); + location.line = 1; + for (size_t i = 0; i < lineStart; i++) { + if (code[i] == '\n') { + location.line++; + } + } + location.column = offset - lineStart + 1; + return location; +} + +/** + * @brief Compute the 1-based line and column for a target within a line. + * @param code The source code to inspect. + * @param instructionStart The zero-based offset of the instruction start. + * @param target The target token to locate on the line. + * @return The line and column of the target, or the first non-space column. + */ +LineColumn lineColumnForTarget(const std::string& code, size_t instructionStart, + const std::string& target) { + LineColumn location = lineColumnForOffset(code, instructionStart); + const auto lineStartPos = code.rfind('\n', instructionStart); + const size_t lineStart = (lineStartPos == std::string::npos) + ? 0 + : static_cast(lineStartPos + 1); + auto lineEndPos = code.find('\n', instructionStart); + const size_t lineEnd = (lineEndPos == std::string::npos) + ? code.size() + : static_cast(lineEndPos); + const auto lineText = code.substr(lineStart, lineEnd - lineStart); + if (!target.empty()) { + const auto targetPos = lineText.find(target); + if (targetPos != std::string::npos) { + location.column = targetPos + 1; + return location; + } + } + const auto nonSpace = lineText.find_first_not_of(" \t"); + if (nonSpace != std::string::npos) { + location.column = nonSpace + 1; + } + return location; +} + +ParsingError makeParseError(const std::string& code, size_t instructionStart, + const std::string& detail, + const std::string& target = "") { + const auto location = lineColumnForTarget(code, instructionStart, target); + const std::string message = ":" + std::to_string(location.line) + ":" + + std::to_string(location.column) + ": " + detail; + return {location.line, location.column, detail, message}; +} + +/** + * @brief Build an error detail string for an invalid target. + * @param target The invalid target token. + * @param context Additional context to append. + * @return The formatted detail string. + */ +std::string invalidTargetDetail(const std::string& target, + const std::string& context) { + std::string detail = "Invalid target qubit "; + detail += target; + detail += context; + detail += "."; + return detail; +} + +/** + * @brief Build an error detail string for an invalid register declaration. + * @param trimmedLine The register declaration line. + * @return The formatted detail string. + */ +std::string invalidRegisterDetail(const std::string& trimmedLine) { + std::string detail = "Invalid register declaration "; + detail += trimmedLine; + detail += "."; + return detail; +} + +/** + * @brief Validate target references against known registers and indices. + * @param code The source code to inspect. + * @param instructionStart The zero-based offset of the instruction start. + * @param targets The target tokens to validate. + * @param definedRegisters The registers defined in the current scope. + * @param shadowedRegisters The shadowed register names in the current scope. + * @param context Additional context to append to error messages. + */ +void validateTargets(const std::string& code, size_t instructionStart, + const std::vector& targets, + const std::map& definedRegisters, + const std::vector& shadowedRegisters, + const std::string& context) { + for (const auto& target : targets) { + if (target.empty()) { + std::string detail = "Empty target"; + detail += context; + detail += "."; + throw makeParseError(code, instructionStart, detail); + } + const auto open = target.find('['); + if (open == std::string::npos) { + continue; + } + const auto close = target.find(']', open + 1); + if (open == 0 || close == std::string::npos || close != target.size() - 1) { + throw makeParseError(code, instructionStart, + invalidTargetDetail(target, context), target); + } + const auto registerName = target.substr(0, open); + const auto indexText = target.substr(open + 1, close - open - 1); + if (!isDigits(indexText)) { + throw makeParseError(code, instructionStart, + invalidTargetDetail(target, context), target); + } + size_t registerIndex = 0; + try { + registerIndex = std::stoul(indexText); + } catch (const std::invalid_argument&) { + throw makeParseError(code, instructionStart, + invalidTargetDetail(target, context), target); + } catch (const std::out_of_range&) { + throw makeParseError(code, instructionStart, + invalidTargetDetail(target, context), target); + } + if (std::ranges::find(shadowedRegisters, registerName) != + shadowedRegisters.end()) { + continue; + } + const auto found = definedRegisters.find(registerName); + if (found == definedRegisters.end() || found->second <= registerIndex) { + throw makeParseError(code, instructionStart, + invalidTargetDetail(target, context), target); + } + } +} + /** * @brief Sweep a given code string for blocks and replace them with a unique * identifier. @@ -230,6 +404,72 @@ ClassicControlledGate parseClassicControlledGate(const std::string& code) { return {.condition = condition.str(), .operations = operations}; } +std::optional +parseClassicConditionExpression(const std::string& condition) { + auto normalized = removeWhitespace(condition); + if (!normalized.empty() && normalized.front() == '(') { + normalized.erase(0, 1); + } + const auto eqPos = normalized.find("=="); + if (eqPos == std::string::npos) { + return std::nullopt; + } + const auto lhs = normalized.substr(0, eqPos); + const auto rhs = normalized.substr(eqPos + 2); + if (lhs.empty() || rhs.empty()) { + return std::nullopt; + } + + if (!isDigits(rhs)) { + return std::nullopt; + } + size_t expected = 0; + try { + expected = std::stoull(rhs); + } catch (const std::invalid_argument&) { + return std::nullopt; + } catch (const std::out_of_range&) { + return std::nullopt; + } + + const auto bracketPos = lhs.find('['); + if (bracketPos != std::string::npos) { + const auto closePos = lhs.find(']', bracketPos + 1); + if (bracketPos == 0 || closePos == std::string::npos || + closePos != lhs.size() - 1) { + return std::nullopt; + } + const auto base = lhs.substr(0, bracketPos); + const auto indexText = + lhs.substr(bracketPos + 1, closePos - bracketPos - 1); + if (!isDigits(indexText)) { + return std::nullopt; + } + size_t bitIndex = 0; + try { + bitIndex = std::stoull(indexText); + } catch (const std::invalid_argument&) { + return std::nullopt; + } catch (const std::out_of_range&) { + return std::nullopt; + } + return ClassicCondition{ + .registerName = base, .bitIndex = bitIndex, .expectedValue = expected}; + } + + return ClassicCondition{ + .registerName = lhs, .bitIndex = std::nullopt, .expectedValue = expected}; +} + +std::optional +parseClassicConditionFromCode(const std::string& code) { + if (!isClassicControlledGate(code)) { + return std::nullopt; + } + const auto condition = parseClassicControlledGate(code).condition; + return parseClassicConditionExpression(condition); +} + bool isMeasurement(const std::string& line) { return line.find("->") != std::string::npos; } @@ -250,10 +490,13 @@ std::vector parseParameters(const std::string& instruction) { } if (isClassicControlledGate(instruction)) { - const auto end = instruction.find(')'); - - return parseParameters( - instruction.substr(end + 1, instruction.length() - end - 1)); + const auto classic = parseClassicControlledGate(instruction); + std::vector parameters; + for (const auto& op : classic.operations) { + const auto targets = parseParameters(op); + parameters.insert(parameters.end(), targets.begin(), targets.end()); + } + return parameters; } auto parts = splitString( @@ -332,7 +575,11 @@ preprocessCode(const std::string& code, size_t startIndex, auto isAssert = isAssertion(line); auto blockPos = line.find("$__block"); - const size_t trueStart = pos + blocksOffset; + const auto leadingPos = blocksRemoved.find_first_not_of(" \t\r\n", pos); + const size_t trueStart = + ((leadingPos != std::string::npos && leadingPos < end) ? leadingPos + : pos) + + blocksOffset; Block block{.valid = false, .code = ""}; if (blockPos != std::string::npos) { @@ -349,6 +596,12 @@ preprocessCode(const std::string& code, size_t startIndex, line.replace(blockPos, endPos - blockPos + 1, ""); } + if (block.valid && isClassicControlledGate(line)) { + line.append(" { ").append(block.code).append(" }"); + block.valid = false; + block.code.clear(); + } + const auto targets = parseParameters(line); const size_t trueEnd = end + blocksOffset; @@ -358,7 +611,18 @@ preprocessCode(const std::string& code, size_t startIndex, replaceString(replaceString(trimmedLine, "creg", ""), "qreg", "")); const auto parts = splitString(declaration, {'[', ']'}); const auto& name = parts[0]; - const auto size = std::stoi(parts[1]); + const auto sizeText = parts.size() > 1 ? parts[1] : ""; + if (name.empty() || !isDigits(sizeText)) { + throw makeParseError(code, trueStart, + invalidRegisterDetail(trimmedLine)); + } + size_t size = 0; + try { + size = std::stoul(sizeText); + } catch (const std::exception&) { + throw makeParseError(code, trueStart, + invalidRegisterDetail(trimmedLine)); + } definedRegisters.insert({name, size}); } @@ -403,14 +667,6 @@ preprocessCode(const std::string& code, size_t startIndex, continue; } - if (isClassicControlledGate(line)) { - if (block.valid) { - throw ParsingError( - "Classic-controlled gates with body blocks are not supported. Use " - "individual `if` statements for each operation."); - } - } - bool isFunctionCall = false; std::string calledFunction; if (!tokens.empty() && @@ -423,25 +679,16 @@ preprocessCode(const std::string& code, size_t startIndex, auto a = parseAssertion(line, block.code); unfoldAssertionTargetRegisters(*a, definedRegisters, shadowedRegisters); a->validate(); - for (const auto& target : a->getTargetQubits()) { - if (std::ranges::find(shadowedRegisters, target) != - shadowedRegisters.end()) { - continue; - } - const auto registerName = variableBaseName(target); - const auto registerIndex = - std::stoul(splitString(splitString(target, '[')[1], ']')[0]); - - if (!definedRegisters.contains(registerName) || - definedRegisters[registerName] <= registerIndex) { - throw ParsingError("Invalid target qubit " + target + - " in assertion."); - } - } + validateTargets(code, trueStart, a->getTargetQubits(), definedRegisters, + shadowedRegisters, " in assertion"); instructions.emplace_back(i, line, a, a->getTargetQubits(), trueStart, trueEnd, i + 1, isFunctionCall, calledFunction, false, false, block); } else { + if (!isVariableDeclaration(line)) { + validateTargets(code, trueStart, targets, definedRegisters, + shadowedRegisters, ""); + } std::unique_ptr a(nullptr); instructions.emplace_back(i, line, a, targets, trueStart, trueEnd, i + 1, isFunctionCall, calledFunction, false, false, diff --git a/src/common/parsing/ParsingError.cpp b/src/common/parsing/ParsingError.cpp index 0b60a322..50634bd6 100644 --- a/src/common/parsing/ParsingError.cpp +++ b/src/common/parsing/ParsingError.cpp @@ -15,11 +15,29 @@ #include "common/parsing/ParsingError.hpp" +#include #include #include +#include namespace mqt::debugger { -ParsingError::ParsingError(const std::string& msg) : std::runtime_error(msg) {} +ParsingError::ParsingError(const std::string& msg) + : std::runtime_error(msg), detail_(msg) {} + +ParsingError::ParsingError(size_t line, size_t column, std::string detail) + : std::runtime_error(detail), line_(line), column_(column), + detail_(std::move(detail)) {} + +ParsingError::ParsingError(size_t line, size_t column, std::string detail, + const std::string& message) + : std::runtime_error(message), line_(line), column_(column), + detail_(std::move(detail)) {} + +size_t ParsingError::line() const noexcept { return line_; } + +size_t ParsingError::column() const noexcept { return column_; } + +const std::string& ParsingError::detail() const noexcept { return detail_; } } // namespace mqt::debugger diff --git a/src/frontend/cli/CliFrontEnd.cpp b/src/frontend/cli/CliFrontEnd.cpp index e81f760f..52d14755 100644 --- a/src/frontend/cli/CliFrontEnd.cpp +++ b/src/frontend/cli/CliFrontEnd.cpp @@ -22,15 +22,27 @@ #include #include #include +#include #include #include #include +#include #include namespace mqt::debugger { namespace { +size_t boundedStrnlen(const char* data, size_t max) { + const auto* end = static_cast(std::memchr(data, '\0', max)); + return end != nullptr ? static_cast(end - data) : max; +} + +std::string_view loadResultMessageView(const LoadResult& result) { + const auto* data = std::data(result.message); + return {data, boundedStrnlen(data, LOAD_RESULT_MESSAGE_MAX)}; +} + /** * @brief ANSI escape sequence for resetting the background color. * @@ -68,8 +80,13 @@ void CliFrontEnd::run(const char* code, SimulationState* state) { std::string command; const auto result = state->loadCode(state, code); state->resetSimulation(state); - if (result == ERROR) { - std::cout << "Error loading code\n"; + if (result.status != LOAD_OK) { + const auto messageView = loadResultMessageView(result); + if (!messageView.empty()) { + std::cout << "Error loading code: " << messageView << "\n"; + } else { + std::cout << "Error loading code\n"; + } return; } diff --git a/test/python/test_compilation.py b/test/python/test_compilation.py index a26f4a99..07d8c695 100644 --- a/test/python/test_compilation.py +++ b/test/python/test_compilation.py @@ -23,8 +23,9 @@ import pytest +import mqt.debugger as dbg from mqt.debugger import check -from mqt.debugger.check import result_checker, runtime_check +from mqt.debugger.check import result_checker, run_preparation, runtime_check if TYPE_CHECKING: import types @@ -199,6 +200,48 @@ def test_incorrect_good_sample_size(compiled_slice_1: str) -> None: assert errors >= 75 +def test_start_compilation_raises_on_invalid_code(tmp_path: Path) -> None: + """Ensure invalid input code surfaces as a RuntimeError.""" + invalid_code = tmp_path / "invalid.qasm" + invalid_code.write_text("INVALID QASM", encoding="utf-8") + + output_dir = tmp_path / "out" + output_dir.mkdir() + + with pytest.raises(RuntimeError): + check.start_compilation(invalid_code, output_dir) + + +def test_start_compilation_raises_on_load_error(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + """Ensure load errors surface as RuntimeError with the provided message.""" + + class DummyLoadResult: + status = dbg.LoadResultStatus.PARSE_ERROR + message = "Bad input" + + class DummyState: + @staticmethod + def load_code(_code: str) -> DummyLoadResult: + return DummyLoadResult() + + @staticmethod + def compile(_settings: dbg.CompilationSettings) -> str: + msg = "compile should not be called" + raise AssertionError(msg) + + monkeypatch.setattr(run_preparation.dbg, "create_ddsim_simulation_state", DummyState) + monkeypatch.setattr(run_preparation.dbg, "destroy_ddsim_simulation_state", lambda _state: None) + + invalid_code = tmp_path / "invalid.qasm" + invalid_code.write_text("INVALID QASM", encoding="utf-8") + + output_dir = tmp_path / "out" + output_dir.mkdir() + + with pytest.raises(RuntimeError, match="Bad input"): + check.start_compilation(invalid_code, output_dir) + + def test_sample_estimate(compiled_slice_1: str) -> None: """Test the estimation of required shots. diff --git a/test/python/test_dap_server.py b/test/python/test_dap_server.py new file mode 100644 index 00000000..229b8bdf --- /dev/null +++ b/test/python/test_dap_server.py @@ -0,0 +1,44 @@ +# Copyright (c) 2024 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +"""Tests for the DAP server helper utilities.""" + +from __future__ import annotations + +from types import SimpleNamespace + +from mqt.debugger.dap.dap_server import DAPServer + + +def test_code_pos_to_coordinates_handles_line_end() -> None: + """Ensure coordinates for newline positions stay on the current line.""" + server = DAPServer() + server.source_code = "measure q[0] -> c[0];\nmeasure q[1] -> c[1];\n" + line, column = server.code_pos_to_coordinates(server.source_code.index("\n")) + assert line == 1 + # Column is 1-based because the DAP client requests it that way. + assert column == len("measure q[0] -> c[0];") + 1 + + +def test_build_highlight_entry_does_not_span_next_instruction() -> None: + """Ensure highlight ranges stop at the end of the instruction.""" + server = DAPServer() + server.source_code = "measure q[0] -> c[0];\nmeasure q[1] -> c[1];\n" + first_line_end = server.source_code.index("\n") + fake_diagnostics = SimpleNamespace(potential_error_causes=list) + fake_state = SimpleNamespace( + get_instruction_position=lambda _instr: (0, first_line_end), + get_diagnostics=lambda: fake_diagnostics, + ) + server.simulation_state = fake_state # type: ignore[assignment] + + entries = server.collect_highlight_entries(0) + assert entries + entry = entries[0] + assert entry["range"]["start"]["line"] == 1 + assert entry["range"]["end"]["line"] == 1 diff --git a/test/test_custom_code.cpp b/test/test_custom_code.cpp index bf1bd69b..8528bbb6 100644 --- a/test/test_custom_code.cpp +++ b/test/test_custom_code.cpp @@ -86,8 +86,13 @@ TEST_F(CustomCodeTest, IfElseOperationMulti) { loadCode(2, 1, "x q[0];" "measure q[0] -> c[0];" - "if(c==1) { x q[0]; x q[1]; }", - true); + "if(c==1) { x q[0]; x q[1]; }"); + ASSERT_EQ(state->runSimulation(state), OK); + + std::array amplitudes{}; + Statevector sv{2, 4, amplitudes.data()}; + state->getStateVectorFull(state, &sv); + ASSERT_TRUE(complexEquality(amplitudes[2], 1, 0.0)); } /** diff --git a/test/test_simulation.cpp b/test/test_simulation.cpp index 1e575cd6..29130cfd 100644 --- a/test/test_simulation.cpp +++ b/test/test_simulation.cpp @@ -60,7 +60,8 @@ class SimulationTest : public testing::TestWithParam { */ void loadFromFile(const std::string& testName) { const auto code = readFromCircuitsPath(testName); - state->loadCode(state, code.c_str()); + const auto result = state->loadCode(state, code.c_str()); + ASSERT_EQ(result.status, OK); } /** diff --git a/test/utils/common_fixtures.hpp b/test/utils/common_fixtures.hpp index 3fb0c1ca..b138c6c7 100644 --- a/test/utils/common_fixtures.hpp +++ b/test/utils/common_fixtures.hpp @@ -123,8 +123,12 @@ class CustomCodeFixture : public testing::Test { bool shouldFail = false, const char* preamble = "") { userCode = code; fullCode = addBoilerplate(numQubits, numClassics, code, preamble); - ASSERT_EQ(state->loadCode(state, fullCode.c_str()), - shouldFail ? ERROR : OK); + const auto result = state->loadCode(state, fullCode.c_str()); + if (shouldFail) { + ASSERT_NE(result.status, LOAD_OK); + } else { + ASSERT_EQ(result.status, LOAD_OK); + } } /** @@ -181,7 +185,8 @@ class LoadFromFileFixture : public virtual testing::Test { */ void loadFromFile(const std::string& testName) { const auto code = readFromCircuitsPath(testName); - state->loadCode(state, code.c_str()); + const auto result = state->loadCode(state, code.c_str()); + ASSERT_EQ(result.status, LOAD_OK); } /**