diff --git a/.gitignore b/.gitignore index 7f9d3a91..d375e1b4 100644 --- a/.gitignore +++ b/.gitignore @@ -204,3 +204,11 @@ dmypy.json # Pyre type checker .pyre/ + +# Generated protobuf Python files (auto-generated by tests) +p-isa_tools/data_formats/python/heracles/proto/*_pb2.py +!p-isa_tools/data_formats/python/heracles/proto/__init__.py + +# Test artifacts +p-isa_tools/data_formats/test/*.program_trace +p-isa_tools/data_formats/test/*.data_trace* diff --git a/.typos.toml b/.typos.toml index ff6fa075..f7b5b957 100644 --- a/.typos.toml +++ b/.typos.toml @@ -8,6 +8,8 @@ # variation of params parms = "parms" bload = "bload" +ser = "ser" +SerType = "SerType" [files] extend-exclude = [ diff --git a/p-isa_tools/CPPLINT.cfg b/CPPLINT.cfg similarity index 100% rename from p-isa_tools/CPPLINT.cfg rename to CPPLINT.cfg diff --git a/p-isa_tools/CMakeLists.txt b/p-isa_tools/CMakeLists.txt index 80f767e2..750bb468 100644 --- a/p-isa_tools/CMakeLists.txt +++ b/p-isa_tools/CMakeLists.txt @@ -22,12 +22,15 @@ else() set(CMAKE_BUILD_TYPE "Debug" CACHE STRING "Choose the type of Build" FORCE) endif() -option(ENABLE_DATA_FORMATS "Enable support for the data formats library" OFF) +option(ENABLE_DATA_FORMATS "Enable support for the data formats library" ON) message(ENABLE_DATA_FORMATS="${ENABLE_DATA_FORMATS}") option(ENABLE_FUNCTIONAL_MODELER "Enable building of functional modeler" ON) message(ENABLE_FUNCTIONAL_MODELER="${ENABLE_FUNCTIONAL_MODELER}") +option(ENABLE_KERNGEN "Enable kerngen (dependencies only)" ON) +message(ENABLE_KERNGEN="${ENABLE_KERNGEN}") + option(ENABLE_PROGRAM_MAPPER "Enable building of program mapper" ON) message(ENABLE_PROGRAM_MAPPER="${ENABLE_PROGRAM_MAPPER}") @@ -63,6 +66,12 @@ file(GLOB_RECURSE IDE_HEADERS program_mapper/*.h functional_modeler/*.h dependen # Build sub-directories add_subdirectory(common) +if(ENABLE_KERNGEN) +add_subdirectory(kerngen) +endif() +if(ENABLE_DATA_FORMATS) +add_subdirectory(data_formats) +endif() if(ENABLE_FUNCTIONAL_MODELER) add_subdirectory(functional_modeler) endif() diff --git a/p-isa_tools/cmake/dependencies.cmake b/p-isa_tools/cmake/dependencies.cmake index 47f8fbc9..9c1649f4 100644 --- a/p-isa_tools/cmake/dependencies.cmake +++ b/p-isa_tools/cmake/dependencies.cmake @@ -45,15 +45,3 @@ if (NOT snap_POPULATED) include_directories(${snap_SOURCE_DIR}/snap-core ${snap_SOURCE_DIR}/glib-core) message(STATUS "Finished building SNAP") endif() - -if(ENABLE_DATA_FORMATS) - find_package(HERACLES_DATA_FORMATS CONFIG) - if(NOT HERACLES_DATA_FORMATS_FOUND) - FetchContent_Declare( - heracles_data_formats - GIT_REPOSITORY git@github.com:IntelLabs/HERACLES-data-formats.git - GIT_TAG main - ) - FetchContent_MakeAvailable(heracles_data_formats) - endif() -endif() diff --git a/p-isa_tools/data_formats/CMakeLists.txt b/p-isa_tools/data_formats/CMakeLists.txt new file mode 100644 index 00000000..a2243b37 --- /dev/null +++ b/p-isa_tools/data_formats/CMakeLists.txt @@ -0,0 +1,57 @@ +cmake_minimum_required(VERSION 3.15.0...3.24.0) + +project(HERACLES_DATA_FORMATS VERSION 1.0.0) + +set(CMAKE_CXX_STANDARD 17) +set(HERACLES_DATA_FORMATS_CMAKE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake") +include(CMakePackageConfigHelpers) +include(GNUInstallDirs) + +# Set default output directories +set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/${CMAKE_INSTALL_LIBDIR}") +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/${CMAKE_INSTALL_LIBDIR}") +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/${CMAKE_INSTALL_BINDIR}") + +include(${PROJECT_SOURCE_DIR}/cmake/utils.cmake) +include(${PROJECT_SOURCE_DIR}/cmake/protobuf.cmake) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -O3 -fPIC") +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wall -O3 -fPIC") +set(CMAKE_POSITION_INDEPENDENT_CODE ON) +set(CMAKE_INSTALL_RPATH "$ORIGIN;$ORIGIN/${CMAKE_INSTALL_LIBDIR}") +set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_INSTALL_LIBDIR}) + +find_package(OpenMP REQUIRED) + +option(BUILD_TEST "Build c++/python tests with ctest" ON) +enable_testing() + +add_subdirectory(proto) +add_subdirectory(cpp) +if(BUILD_TEST) + add_subdirectory(test) +endif() + +# install python utility functions +install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/python/ + DESTINATION python + # TODO(skmono): to be added afterwards + # FILES_MATCHING + # PATTERN "*.py" +) + +# copy python utility functions to build/python +add_custom_target(HERACLES_DATA_FORMATS_COPY_PYTHON + ALL + DEPENDS HERACLES_data_proto +) +add_custom_command( + TARGET HERACLES_DATA_FORMATS_COPY_PYTHON + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_LIST_DIR}/python ${PROJECT_BINARY_DIR}/python/ +) + +option(ENABLE_OPENFHE_TRACER "Build the OpenFHE tracer" ON) +if(ENABLE_OPENFHE_TRACER) +add_subdirectory(tracers/openfhe) +endif() diff --git a/p-isa_tools/data_formats/README.md b/p-isa_tools/data_formats/README.md new file mode 100644 index 00000000..d71077f8 --- /dev/null +++ b/p-isa_tools/data_formats/README.md @@ -0,0 +1,128 @@ +# HERACLES Data formatter interface + +## CMake Configure and Build +```bash +cmake -S . -B build +cmake --build build --parallel +``` +_Note: for now cmake will _not build with `ninja`_ and is only tested for +(default) `CMAKE_GENERATOR='Unix Makefiles`_ + + +## Run test +```bash +cmake --build build --target test +``` + +Note: Python `[dev]` dependencies from the root `pyproject.toml` are required to run the Python test. They can be installed via +```bash +pip install -e ".[dev]" # from repository root +``` + +## C++ + +### Importing the **HERACLES-Data-Formats** Library + +The C++ library be found and included with cmake by including +following statements in the cmakefile of the project depending on the +HERACLES data formats library: +```cmake +find_package(HERACLES_DATA_FORMATS 1.0.0 REQUIRED) +... +target_link_libraries( PUBLIC HERACLES_DATA_FORMATS::heracles_data_formats) +``` +Assuming you follow the convention of having all code +checked out in the same directory and named by their component name, you +can then build that project by executing the following: + +```bash +# from project root +HERACLES_DATA_FORMATS_DIR=$(pwd)/../HERACLES-data-formats/build cmake -S . -B build +cmake --build build --parallel +``` +Alternatively, you can also build and install HERACLES-data-formats +(with the destination chosen, e.g., using the +`-DCMAKE_INSTALL_PREFIX=/path/to/install` argument, and an `cmake +--build build --target install` after the build ). However, when +installing be careful in not forgetting to re-install +after each change and subsequent build or accidentally picking up +older versions installed elsewhere and earlier searched in CMAKE's +search paths. + + +### Usage example +The library can be used in the ```C++``` code, e.g., as followed: +```c++ +// protobuf headers +#include "heracles/heracles_proto.h" +// cpp utility headers +#include "heracles/heracles_data_formats.h" + +int main() { + heracles::fhe_trace::Trace trace; + heracles::data::InputPolynomials input_polys; + + return 0; +} +``` +Refer to the [heracles_test.cpp](src/data_formats/test/heracles_test.cpp) source +code for additional examples of using Heracles protobuf objects and +utility functions as well as [Protocol Buffer Basics: +C++](https://protobuf.dev/getting-started/cpptutorial/) for more +general information on using generated C++ protobuf code. + + +## Python + + +For the Python package to be used independently of CMake/C++ builds, the optional `dev` dependencies are required. + +1. **Install dependencies**: +```bash +# For development (includes grpcio-tools for compiling protos, pytest for testing) +pip install -e ".[dev]" +``` + +2. **Compile Protocol Buffers**: +```bash +python p-isa_tools/data_formats/compile_protos.py +``` + +This generates the Python protobuf files in `p-isa_tools/data_formats/python/heracles/proto/`. + +3. **Generate test traces** (if needed for testing): +```bash +python p-isa_tools/data_formats/test/generate_test_traces.py +``` +Alternatively, you can simply run the `pytest` tests, which will create the protobuf files and/or test traces if they do not exist yet. + +### Running Tests + +From the repository root: +```bash +pytest p-isa_tools/data_formats/test/ +``` +(The path is optional, but avoids running unrelated tests) + +### Usage example +The **HERACLES-Data-Formats** library can be imported via, e.g., +```python +from heracles.proto.common_pb2 import Scheme +from heracles.proto.fhe_trace_pb2 import Trace, Instruction +import heracles.fhe_trace.io as hfi +import heracles.data.io as hdi + +# Create and save a trace +trace = Trace() +trace.scheme = Scheme.SCHEME_BGV +hfi.store_trace("my_trace.bin", trace) + +# Load a trace +loaded_trace = hfi.load_trace("my_trace.bin") +``` + +Refer to the [heracles_test.py](test/heracles_test.py) script for +examples of using Heracles protobuf objects and utility functions as +well as [Protocol Buffer Basics: +Python](https://protobuf.dev/getting-started/pythontutorial/) for more +general information on using generated python protobuf code. diff --git a/p-isa_tools/data_formats/cmake/HERACLES_DATA_FORMATSConfig.cmake.in b/p-isa_tools/data_formats/cmake/HERACLES_DATA_FORMATSConfig.cmake.in new file mode 100644 index 00000000..c09a5028 --- /dev/null +++ b/p-isa_tools/data_formats/cmake/HERACLES_DATA_FORMATSConfig.cmake.in @@ -0,0 +1,25 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +@PACKAGE_INIT@ + +include(CMakeFindDependencyMacro) + +include(${CMAKE_CURRENT_LIST_DIR}/HERACLES_DATA_FORMATSTargets.cmake) + +if(TARGET HERACLES_DATA_FORMATS::heracles_data_formats) + set(HERACLES_DATA_FORMATS_FOUND TRUE) + message(STATUS "Heracles Data Formats Library found") +else() + message(STATUS "Heracles Data Formats Library not found") +endif() + +# Requirement for protobuf +find_package(ZLIB REQUIRED) + +set(HERACLES_DATA_FORMATS_VERSION "@HERACLES_DATA_FORMATS_VERSION") +set(HERACLES_DATA_FORMATS_VERSION_MAJOR "@HERACLES_DATA_FORMATS_VERSION_MAJOR") +set(HERACLES_DATA_FORMATS_VERSION_MINOR "@HERACLES_DATA_FORMATS_VERSION") +set(HERACLES_DATA_FORMATS_VERSION_PATCH "@HERACLES_DATA_FORMATS_VERSION") + +set(HERACLES_DATA_FORMATS_DEBUG "@HERACLES_DATA_FORMATS_DEBUG") diff --git a/p-isa_tools/data_formats/cmake/protobuf-generate.cmake b/p-isa_tools/data_formats/cmake/protobuf-generate.cmake new file mode 100644 index 00000000..52140a6f --- /dev/null +++ b/p-isa_tools/data_formats/cmake/protobuf-generate.cmake @@ -0,0 +1,155 @@ +function(protobuf_generate) + include(CMakeParseArguments) + + set(_options APPEND_PATH) + set(_singleargs LANGUAGE OUT_VAR EXPORT_MACRO PROTOC_OUT_DIR PLUGIN PLUGIN_OPTIONS DEPENDENCIES) + if(COMMAND target_sources) + list(APPEND _singleargs TARGET) + endif() + set(_multiargs PROTOS IMPORT_DIRS GENERATE_EXTENSIONS PROTOC_OPTIONS) + + cmake_parse_arguments(protobuf_generate "${_options}" "${_singleargs}" "${_multiargs}" "${ARGN}") + + if(NOT protobuf_generate_PROTOS AND NOT protobuf_generate_TARGET) + message(SEND_ERROR "Error: protobuf_generate called without any targets or source files") + return() + endif() + + if(NOT protobuf_generate_OUT_VAR AND NOT protobuf_generate_TARGET) + message(SEND_ERROR "Error: protobuf_generate called without a target or output variable") + return() + endif() + + if(NOT protobuf_generate_LANGUAGE) + set(protobuf_generate_LANGUAGE cpp) + endif() + string(TOLOWER ${protobuf_generate_LANGUAGE} protobuf_generate_LANGUAGE) + + if(NOT protobuf_generate_PROTOC_OUT_DIR) + set(protobuf_generate_PROTOC_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}) + endif() + + if(protobuf_generate_EXPORT_MACRO AND protobuf_generate_LANGUAGE STREQUAL cpp) + set(_dll_export_decl "dllexport_decl=${protobuf_generate_EXPORT_MACRO}") + endif() + + foreach(_option ${_dll_export_decl} ${protobuf_generate_PLUGIN_OPTIONS}) + # append comma - not using CMake lists and string replacement as users + # might have semicolons in options + if(_plugin_options) + set( _plugin_options "${_plugin_options},") + endif() + set(_plugin_options "${_plugin_options}${_option}") + endforeach() + + if(protobuf_generate_PLUGIN) + set(_plugin "--plugin=${protobuf_generate_PLUGIN}") + endif() + + if(NOT protobuf_generate_GENERATE_EXTENSIONS) + if(protobuf_generate_LANGUAGE STREQUAL cpp) + set(protobuf_generate_GENERATE_EXTENSIONS .pb.h .pb.cc) + elseif(protobuf_generate_LANGUAGE STREQUAL python) + set(protobuf_generate_GENERATE_EXTENSIONS _pb2.py) + else() + message(SEND_ERROR "Error: protobuf_generate given unknown Language ${LANGUAGE}, please provide a value for GENERATE_EXTENSIONS") + return() + endif() + endif() + + if(protobuf_generate_TARGET) + get_target_property(_source_list ${protobuf_generate_TARGET} SOURCES) + foreach(_file ${_source_list}) + if(_file MATCHES "proto$") + list(APPEND protobuf_generate_PROTOS ${_file}) + endif() + endforeach() + endif() + + if(NOT protobuf_generate_PROTOS) + message(SEND_ERROR "Error: protobuf_generate could not find any .proto files") + return() + endif() + + if(protobuf_generate_APPEND_PATH) + # Create an include path for each file specified + foreach(_file ${protobuf_generate_PROTOS}) + get_filename_component(_abs_file ${_file} ABSOLUTE) + get_filename_component(_abs_dir ${_abs_file} DIRECTORY) + list(FIND _protobuf_include_path ${_abs_dir} _contains_already) + if(${_contains_already} EQUAL -1) + list(APPEND _protobuf_include_path -I ${_abs_dir}) + endif() + endforeach() + endif() + + foreach(DIR ${protobuf_generate_IMPORT_DIRS}) + get_filename_component(ABS_PATH ${DIR} ABSOLUTE) + list(FIND _protobuf_include_path ${ABS_PATH} _contains_already) + if(${_contains_already} EQUAL -1) + list(APPEND _protobuf_include_path -I ${ABS_PATH}) + endif() + endforeach() + + if(NOT _protobuf_include_path) + set(_protobuf_include_path -I ${CMAKE_CURRENT_SOURCE_DIR}) + endif() + + set(_generated_srcs_all) + foreach(_proto ${protobuf_generate_PROTOS}) + get_filename_component(_abs_file ${_proto} ABSOLUTE) + get_filename_component(_abs_dir ${_abs_file} DIRECTORY) + + get_filename_component(_file_full_name ${_proto} NAME) + string(FIND "${_file_full_name}" "." _file_last_ext_pos REVERSE) + string(SUBSTRING "${_file_full_name}" 0 ${_file_last_ext_pos} _basename) + + set(_suitable_include_found FALSE) + foreach(DIR ${_protobuf_include_path}) + if(NOT DIR STREQUAL "-I") + file(RELATIVE_PATH _rel_dir ${DIR} ${_abs_dir}) + string(FIND "${_rel_dir}" "../" _is_in_parent_folder) + if (NOT ${_is_in_parent_folder} EQUAL 0) + set(_suitable_include_found TRUE) + break() + endif() + endif() + endforeach() + + if(NOT _suitable_include_found) + message(SEND_ERROR "Error: protobuf_generate could not find any correct proto include directory.") + return() + endif() + + set(_generated_srcs) + foreach(_ext ${protobuf_generate_GENERATE_EXTENSIONS}) + list(APPEND _generated_srcs "${protobuf_generate_PROTOC_OUT_DIR}/${_rel_dir}/${_basename}${_ext}") + endforeach() + list(APPEND _generated_srcs_all ${_generated_srcs}) + + set(_comment "Running ${protobuf_generate_LANGUAGE} protocol buffer compiler on ${_proto}") + if(protobuf_generate_PROTOC_OPTIONS) + set(_comment "${_comment}, protoc-options: ${protobuf_generate_PROTOC_OPTIONS}") + endif() + if(_plugin_options) + set(_comment "${_comment}, plugin-options: ${_plugin_options}") + endif() + + add_custom_command( + OUTPUT ${_generated_srcs} + COMMAND ${CMAKE_COMMAND} -E env "LD_LIBRARY_PATH=${protobuf_LIB_DIR}" ${protobuf_PROTOC_EXECUTABLE}#protobuf_executable#protobuf::protoc + ARGS ${protobuf_generate_PROTOC_OPTIONS} --${protobuf_generate_LANGUAGE}_out ${_plugin_options}:${protobuf_generate_PROTOC_OUT_DIR} ${_plugin} ${_protobuf_include_path} ${_abs_file} + DEPENDS ${_abs_file} ${protobuf_PROTOC_EXE} ${protobuf_generate_DEPENDENCIES} protobuf_executable + COMMENT ${_comment} + VERBATIM ) + endforeach() + + set_source_files_properties(${_generated_srcs_all} PROPERTIES GENERATED TRUE) + if(protobuf_generate_OUT_VAR) + set(${protobuf_generate_OUT_VAR} ${_generated_srcs_all} PARENT_SCOPE) + endif() + if(protobuf_generate_TARGET) + target_sources(${protobuf_generate_TARGET} PRIVATE ${_generated_srcs_all}) + endif() + +endfunction() diff --git a/p-isa_tools/data_formats/cmake/protobuf.cmake b/p-isa_tools/data_formats/cmake/protobuf.cmake new file mode 100644 index 00000000..55210b6a --- /dev/null +++ b/p-isa_tools/data_formats/cmake/protobuf.cmake @@ -0,0 +1,73 @@ +# Recent release version +set(PROTOBUF_EXT_GIT_TAG v4.23.4) +set(PROTOBUF_EXT_GIT_URL https://github.com/protocolbuffers/protobuf.git) +set(PROTOBUF_EXT_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/ext_protobuf) +set(PROTOBUF_EXT_DESTDIR ${PROTOBUF_EXT_PREFIX}/protobuf_install) + +include(ExternalProject) +ExternalProject_Add( + ext_protobuf + GIT_REPOSITORY ${PROTOBUF_EXT_GIT_URL} + GIT_TAG ${PROTOBUF_EXT_GIT_TAG} + PREFIX ${PROTOBUF_EXT_PREFIX} + CMAKE_ARGS ${CMAKE_CXX_FLAGS} + -DCMAKE_INSTALL_PREFIX=${PROTOBUF_EXT_DESTDIR} + -Dprotobuf_BUILD_EXAMPLES=OFF + -Dprotobuf_BUILD_TESTS=OFF + -DABSL_PROPAGATE_CXX_STD=ON + -DCMAKE_INSTALL_LIBDIR=lib + -DCMAKE_BUILD_TYPE=Release + -DCMAKE_INSTALL_RPATH=$ORIGIN + -Dprotobuf_BUILD_SHARED_LIBS=ON + UPDATE_COMMAND "" + INSTALL_COMMAND make install +) + + +set(protobuf_SOURCE_DIR ${PROTOBUF_EXT_PREFIX}/src/ext_protobuf) +set(protobuf_INCLUDE_DIR ${PROTOBUF_EXT_DESTDIR}/include) +set(protobuf_LIB_DIR ${PROTOBUF_EXT_DESTDIR}/lib) +set(protobuf_BIN_DIR ${PROTOBUF_EXT_DESTDIR}/bin) + +# setup protobuf executable +add_executable(protobuf_executable IMPORTED GLOBAL) +add_dependencies(protobuf_executable ext_protobuf) +set_target_properties(protobuf_executable PROPERTIES + IMPORTED_LOCATION ${protobuf_BIN_DIR}/protoc +) +set(protobuf_PROTOC_EXECUTABLE ${protobuf_BIN_DIR}/protoc) + +foreach(_protobuf_lib_name ${protobuf_SHARED_LIB_NAMES}) + set(_protobuf_lib_filename_shared "lib${_protobuf_lib_name}${CMAKE_SHARED_LIBRARY_SUFFIX}") + add_library(${_protobuf_lib_name} SHARED IMPORTED GLOBAL) + add_dependencies(${_protobuf_lib_name} ext_protobuf) + set_target_properties(${_protobuf_lib_name} PROPERTIES + IMPORTED_LOCATION ${protobuf_LIB_DIR}/${_protobuf_lib_filename_shared} + INCLUDE_DIRECTORIES ${protobuf_INCLUDE_DIR} + ) +endforeach() + +foreach(_protobuf_lib_name ${protobuf_STATIC_LIB_NAMES}) + set(_protobuf_lib_filename_static "lib${_protobuf_lib_name}${CMAKE_STATIC_LIBRARY_SUFFIX}") + add_library(${_protobuf_lib_name} STATIC IMPORTED GLOBAL) + add_dependencies(${_protobuf_lib_name} ext_protobuf) + set_target_properties(${_protobuf_lib_name} PROPERTIES + IMPORTED_LOCATION ${protobuf_LIB_DIR}/${_protobuf_lib_filename_static} + INCLUDE_DIRECTORIES ${protobuf_INCLUDE_DIR} + ) +endforeach() + +install(DIRECTORY ${protobuf_INCLUDE_DIR}/ + DESTINATION include +) + +# copy library and binary files when installing +install(DIRECTORY ${protobuf_LIB_DIR}/ + DESTINATION lib + USE_SOURCE_PERMISSIONS +) + +install(DIRECTORY ${protobuf_BIN_DIR}/ + DESTINATION bin + USE_SOURCE_PERMISSIONS +) diff --git a/p-isa_tools/data_formats/cmake/utils.cmake b/p-isa_tools/data_formats/cmake/utils.cmake new file mode 100644 index 00000000..7abc4829 --- /dev/null +++ b/p-isa_tools/data_formats/cmake/utils.cmake @@ -0,0 +1,141 @@ +# used for manually linking to protobuf/absl generated files +set(protobuf_SHARED_LIB_NAMES + protobuf + protobuf-lite + absl_log_internal_check_op + absl_leak_check + absl_die_if_null + absl_log_internal_conditions + absl_log_internal_message + absl_log_internal_nullguard + absl_examine_stack + absl_log_internal_format + absl_log_internal_proto + absl_log_internal_log_sink_set + absl_log_sink + absl_log_entry + absl_flags + absl_flags_internal + absl_flags_marshalling + absl_flags_reflection + absl_flags_config + absl_flags_program_name + absl_flags_private_handle_accessor + absl_flags_commandlineflag + absl_flags_commandlineflag_internal + absl_log_initialize + absl_log_globals + absl_log_internal_globals + absl_hash + absl_city + absl_low_level_hash + absl_raw_hash_set + absl_hashtablez_sampler + absl_statusor + absl_status + absl_cord + absl_cordz_info + absl_cord_internal + absl_cordz_functions + absl_exponential_biased + absl_cordz_handle + absl_crc_cord_state + absl_crc32c + absl_crc_internal + absl_crc_cpu_detect + absl_bad_optional_access + absl_str_format_internal + absl_strerror + absl_synchronization + absl_stacktrace + absl_symbolize + absl_debugging_internal + absl_demangle_internal + absl_graphcycles_internal + absl_malloc_internal + absl_time + absl_civil_time + absl_time_zone + absl_bad_variant_access + absl_strings + absl_throw_delegate + absl_int128 + absl_strings_internal + absl_base + absl_raw_logging_internal + absl_log_severity + absl_spinlock_wait +) + +set(protobuf_STATIC_LIB_NAMES + utf8_validity +) + +set(protobuf_LIB_NAMES + protobuf + protobuf-lite + absl_log_internal_check_op + absl_leak_check + absl_die_if_null + absl_log_internal_conditions + absl_log_internal_message + absl_log_internal_nullguard + absl_examine_stack + absl_log_internal_format + absl_log_internal_proto + absl_log_internal_log_sink_set + absl_log_sink + absl_log_entry + absl_flags + absl_flags_internal + absl_flags_marshalling + absl_flags_reflection + absl_flags_config + absl_flags_program_name + absl_flags_private_handle_accessor + absl_flags_commandlineflag + absl_flags_commandlineflag_internal + absl_log_initialize + absl_log_globals + absl_log_internal_globals + absl_hash + absl_city + absl_low_level_hash + absl_raw_hash_set + absl_hashtablez_sampler + absl_statusor + absl_status + absl_cord + absl_cordz_info + absl_cord_internal + absl_cordz_functions + absl_exponential_biased + absl_cordz_handle + absl_crc_cord_state + absl_crc32c + absl_crc_internal + absl_crc_cpu_detect + absl_bad_optional_access + absl_str_format_internal + absl_strerror + absl_synchronization + absl_stacktrace + absl_symbolize + absl_debugging_internal + absl_demangle_internal + absl_graphcycles_internal + absl_malloc_internal + absl_time + absl_civil_time + absl_time_zone + absl_bad_variant_access + utf8_validity + absl_strings + absl_throw_delegate + absl_int128 + absl_strings_internal + absl_base + absl_raw_logging_internal + absl_log_severity + absl_spinlock_wait +) diff --git a/p-isa_tools/data_formats/cpp/.clang-format b/p-isa_tools/data_formats/cpp/.clang-format new file mode 100644 index 00000000..87bfc5ac --- /dev/null +++ b/p-isa_tools/data_formats/cpp/.clang-format @@ -0,0 +1,126 @@ +--- +Language: Cpp +# BasedOnStyle: Microsoft +AccessModifierOffset: -4 +AlignAfterOpenBracket: AlwaysBreak +AlignConsecutiveMacros: false +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlines: Left +AlignOperands: true +AlignTrailingComments: false +AllowAllArgumentsOnNextLine: true +AllowAllConstructorInitializersOnNextLine: true +AllowAllParametersOfDeclarationOnNextLine: true +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: None +AllowShortLambdasOnASingleLine: Inline +AllowShortIfStatementsOnASingleLine: Never +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: false +AlwaysBreakTemplateDeclarations: Yes +BinPackArguments: true +BinPackParameters: true +BraceWrapping: + AfterCaseLabel: true + AfterClass: true + AfterControlStatement: true + AfterEnum: true + AfterFunction: true + AfterNamespace: true + AfterObjCDeclaration: true + AfterStruct: true + AfterUnion: true + AfterExternBlock: true + BeforeCatch: true + BeforeElse: true + IndentBraces: false + SplitEmptyFunction: false + SplitEmptyRecord: false + SplitEmptyNamespace: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Custom # Allman +BreakBeforeInheritanceComma: false +BreakInheritanceList: BeforeColon +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BreakConstructorInitializers: BeforeColon +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: true +ColumnLimit: 120 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: false +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: false +DerivePointerAlignment: false +DisableFormat: false +ExperimentalAutoDetectBinPacking: false +FixNamespaceComments: true +ForEachMacros: + - foreach + - Q_FOREACH + - BOOST_FOREACH +IncludeBlocks: Merge +IncludeCategories: + - Regex: '<.*>' + Priority: 1 + - Regex: '"seal/util/.*"' + Priority: -2 + - Regex: '"seal/.*"' + Priority: -3 +IncludeIsMainRegex: '(Test)?$' +IndentCaseLabels: false +IndentPPDirectives: None +IndentWidth: 4 +IndentWrappedFunctionNames: false +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: false +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: Inner +ObjCBinPackProtocolList: Auto +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: true +PenaltyBreakAssignment: 2 +PenaltyBreakBeforeFirstCallParameter: 19 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyBreakTemplateDeclaration: 10 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 1000 +PointerAlignment: Right +ReflowComments: true +SortIncludes: true +SortUsingDeclarations: true +SpaceAfterCStyleCast: false +SpaceAfterLogicalNot: false +SpaceAfterTemplateKeyword: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeCpp11BracedList: false +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: true +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Auto +StatementMacros: + - Q_UNUSED + - QT_REQUIRE_VERSION +TabWidth: 4 +UseTab: Never +... diff --git a/p-isa_tools/data_formats/cpp/CMakeLists.txt b/p-isa_tools/data_formats/cpp/CMakeLists.txt new file mode 100644 index 00000000..0e60ab0e --- /dev/null +++ b/p-isa_tools/data_formats/cpp/CMakeLists.txt @@ -0,0 +1,21 @@ +target_sources(heracles_data_formats + PRIVATE + ${CMAKE_CURRENT_LIST_DIR}/heracles/data/io.cpp + ${CMAKE_CURRENT_LIST_DIR}/heracles/data/transform.cpp + ${CMAKE_CURRENT_LIST_DIR}/heracles/data/math.cpp + ${CMAKE_CURRENT_LIST_DIR}/heracles/util/util.cpp + ${CMAKE_CURRENT_LIST_DIR}/heracles/fhe_trace/io.cpp +) + +target_include_directories(heracles_data_formats + PUBLIC $ + PUBLIC $ +) +target_link_libraries(heracles_data_formats + PRIVATE OpenMP::OpenMP_CXX +) +install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/ + DESTINATION include + FILES_MATCHING + PATTERN "*.h" +) diff --git a/p-isa_tools/data_formats/cpp/heracles/data/io.cpp b/p-isa_tools/data_formats/cpp/heracles/data/io.cpp new file mode 100644 index 00000000..d440f7e6 --- /dev/null +++ b/p-isa_tools/data_formats/cpp/heracles/data/io.cpp @@ -0,0 +1,316 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "heracles/data/io.h" +#include +#include +#include +#include +#include +#include +#include + +namespace heracles::data +{ +hdf_manifest parse_manifest(const std::string &filename) +{ + hdf_manifest manifest; + std::ifstream file(filename); + + if (!file.is_open()) + { + throw std::runtime_error("Manifest file not found: " + filename); + } + + std::string header, current_line; + auto trim = [](const std::string &line) { + std::string res = line; + res.erase(std::remove_if(res.begin(), res.end(), isspace), res.end()); + return res; + }; + + std::string cur_field = ""; + bool found_first_field = false; + int linenum = 0; + while (std::getline(file, current_line)) + { + ++linenum; + current_line = trim(current_line); + if (current_line.front() == '[' && current_line.back() == ']') + { + cur_field = current_line.substr(1, current_line.size() - 2); + found_first_field = true; + continue; + } + + if (!found_first_field) + { + continue; + } + + std::stringstream current_line_ss(current_line); + std::string value; + std::vector values; + while (std::getline(current_line_ss, value, '=')) + { + values.push_back(trim(value)); + } + // if format is not "a=b", pass + if (values.size() != 2) + { + std::cout << "Warning : ignoring incorrect format in line :" << linenum << std::endl; + continue; + } + + manifest[cur_field][values[0]] = values[1]; + } + + if (!found_first_field) + throw std::runtime_error("Incorrect manifest format: " + filename); + + return manifest; +} + +void generate_manifest(const std::string &filename, const hdf_manifest &manifest) +{ + std::stringstream ss; + for (const auto &[field, values] : manifest) + { + ss << "[" << field << "]" << std::endl; + for (const auto &[key, fn] : values) + { + ss << key << "=" << fn << std::endl; + } + } + std::ofstream ofile(filename, std::ios::out); + ofile << ss.rdbuf(); +} + +bool store_hec_context_json(const std::string &filename, const heracles::data::FHEContext &context) +{ + std::string json_str; + auto rc = ::google::protobuf::util::MessageToJsonString(context, &json_str); + std::ofstream json_ofile(filename, std::ios::out); + json_ofile << json_str; + + return true; +} + +bool store_testvector_json(const std::string &filename, const heracles::data::TestVector &test_vector) +{ + std::string json_str; + auto rc = ::google::protobuf::util::MessageToJsonString(test_vector, &json_str); + std::ofstream json_ofile(filename, std::ios::out); + + json_ofile << json_str; + return true; +} + +void store_hec_context( + hdf_manifest *manifest_out, const std::string &filename, const heracles::data::FHEContext &context_pb) +{ + auto tmp_context = context_pb; + if (context_pb.ByteSizeLong() > (1 << 30)) + { + int gkct = 1; + for (const auto &[ge, key] : tmp_context.ckks_info().keys().rotation_keys()) + { + std::stringstream parts_fnss; + parts_fnss << filename << "_hec_context_part_" << gkct++; + std::ofstream pb_ofile(parts_fnss.str(), std::ios::out | std::ios::binary); + (*manifest_out)["rotation_keys"][std::to_string(ge)] = parts_fnss.str(); + if (!key.SerializeToOstream(&pb_ofile)) + throw std::runtime_error("Serializing rotation key failed"); + } + tmp_context.mutable_ckks_info()->mutable_keys()->clear_rotation_keys(); + } + + auto main_fs = filename + "_hec_context_part_0"; + (*manifest_out)["context"]["main"] = main_fs; + std::ofstream pb_ofile(main_fs, std::ios::out | std::ios::binary); + if (!tmp_context.SerializeToOstream(&pb_ofile)) + throw std::runtime_error("Serializing main hec context failed"); +} +void store_testvector( + hdf_manifest *manifest_out, const std::string &filename, const heracles::data::TestVector &testvector_pb) +{ + if (testvector_pb.ByteSizeLong() > (1 << 30)) + { + int tvct = 0; + for (const auto &[sym, data_part] : testvector_pb.sym_data_map()) + { + std::stringstream parts_fnss; + parts_fnss << filename << "_testvector_part_" << tvct++; + auto parts_fn = parts_fnss.str(); + (*manifest_out)["testvector"][sym] = parts_fn; + std::ofstream pb_ofile(parts_fn, std::ios::out | std::ios::binary); + if (!data_part.SerializeToOstream(&pb_ofile)) + throw std::runtime_error("Serializing test vector part " + sym + " failed. File : " + parts_fn); + } + return; + } + + auto full_fn = filename + "_testvector_part_0"; + (*manifest_out)["testvector"]["full"] = full_fn; + std::ofstream pb_ofile(full_fn, std::ios::out | std::ios::binary); + if (!testvector_pb.SerializeToOstream(&pb_ofile)) + throw std::runtime_error("Serializing full test vector failed. File : " + full_fn); +} + +bool store_data_trace( + const std::string &filename, const heracles::data::FHEContext &context_pb, + const heracles::data::TestVector &testvector_pb) +{ + hdf_manifest manifest_datatrace; + try + { + store_hec_context(&manifest_datatrace, filename, context_pb); + store_testvector(&manifest_datatrace, filename, testvector_pb); + generate_manifest(filename, manifest_datatrace); + } + catch (const std::runtime_error &err) + { + std::cerr << "Runtime error during store_data_trace, err: " << err.what() << std::endl; + throw err; + } + catch (...) + { + std::cerr << "Unknown exception caught in " << __FUNCTION__ << "in file" << __FILE__ << std::endl; + throw; + } + + return true; +} + +void load_hec_context_from_manifest(heracles::data::FHEContext *context_pb, const hdf_manifest &manifest) +{ + try + { + std::filesystem::path main_fn(manifest.at("context").at("main")); + std::ifstream context_pb_ifile(main_fn, std::ios::in | std::ios::binary); + context_pb->ParseFromIstream(&context_pb_ifile); + + if (manifest.count("rotation_keys")) + { + for (const auto &[ge, gk_fn] : manifest.at("rotation_keys")) + { + heracles::data::KeySwitch gk_pb; + std::ifstream gk_pb_ifile(gk_fn, std::ios::in | std::ios::binary); + gk_pb.ParseFromIstream(&gk_pb_ifile); + (*(context_pb->mutable_ckks_info() + ->mutable_keys() + ->mutable_rotation_keys()))[static_cast(std::stoul(ge))] = gk_pb; + } + } + } + + catch (const std::runtime_error &err) + { + std::cerr << "Runtime error during load_hec_context, err: " << err.what() << std::endl; + throw err; + } + catch (...) + { + std::cerr << "Unknown exception caught in " << __FUNCTION__ << "in file" << __FILE__ << std::endl; + throw; + } +} +void load_testvector_from_manifest(heracles::data::TestVector *testvector_pb, const hdf_manifest &manifest) +{ + try + { + if (manifest.at("testvector").find("full") != manifest.at("testvector").end()) + { // single file + const auto &full_fn = manifest.at("testvector").at("full"); + std::ifstream pb_ifile(full_fn, std::ios::in | std::ios::binary); + testvector_pb->ParseFromIstream(&pb_ifile); + } + else + { // segmented + for (const auto &[sym, parts_fn] : manifest.at("testvector")) + { + std::ifstream pb_ifile(parts_fn, std::ios::in | std::ios::binary); + (*testvector_pb->mutable_sym_data_map())[sym].ParseFromIstream(&pb_ifile); + } + } + } + catch (const std::runtime_error &err) + { + std::cerr << "Runtime error during _load_testvector, err: " << err.what() << std::endl; + throw err; + } + catch (...) + { + std::cerr << "Unknown exception caught in " << __FUNCTION__ << "in file" << __FILE__ << std::endl; + throw; + } +} + +heracles::data::FHEContext load_hec_context(const std::string &filename) +{ + heracles::data::FHEContext context_pb; + try + { + auto manifest = parse_manifest(filename); + load_hec_context_from_manifest(&context_pb, manifest); + } + catch (const std::runtime_error &err) + { + std::cerr << "Runtime error during load_data_trace, err: " << err.what() << std::endl; + throw err; + } + catch (...) + { + std::cerr << "Unknown exception caught in " << __FUNCTION__ << "in file" << __FILE__ << std::endl; + throw; + } + + return context_pb; +} +heracles::data::TestVector load_testvector(const std::string &filename) +{ + heracles::data::TestVector testvector_pb; + try + { + auto manifest = parse_manifest(filename); + load_testvector_from_manifest(&testvector_pb, manifest); + } + catch (const std::runtime_error &err) + { + std::cerr << "Runtime error during load_data_trace, err: " << err.what() << std::endl; + throw err; + } + catch (...) + { + std::cerr << "Unknown exception caught in " << __FUNCTION__ << "in file" << __FILE__ << std::endl; + throw; + } + + return testvector_pb; +} + +std::pair load_data_trace(const std::string &filename) +{ + heracles::data::FHEContext context_pb; + heracles::data::TestVector testvector_pb; + try + { + auto manifest = parse_manifest(filename); + load_hec_context_from_manifest(&context_pb, manifest); + load_testvector_from_manifest(&testvector_pb, manifest); + } + catch (const std::runtime_error &err) + { + std::cerr << "Runtime error during load_data_trace, err: " << err.what() << std::endl; + throw err; + } + catch (...) + { + std::cerr << "Unknown exception caught in " << __FUNCTION__ << "in file" << __FILE__ << std::endl; + throw; + } + + return { context_pb, testvector_pb }; +} + +} // namespace heracles::data diff --git a/p-isa_tools/data_formats/cpp/heracles/data/math.cpp b/p-isa_tools/data_formats/cpp/heracles/data/math.cpp new file mode 100644 index 00000000..32a4e9cb --- /dev/null +++ b/p-isa_tools/data_formats/cpp/heracles/data/math.cpp @@ -0,0 +1,145 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "heracles/data/math.h" +#include +#include +#include + +std::uint32_t heracles::math::reverse_bits(std::uint32_t operand, std::uint32_t bit_count) +{ + uint32_t c[2]; + multiply_uint(operand, operand, c); + + if (bit_count == 0) + return 0; + + operand = + (((operand & static_cast(0xaaaaaaaaU)) >> 1) | ((operand & static_cast(0x55555555U)) << 1)); + operand = + (((operand & static_cast(0xccccccccU)) >> 2) | ((operand & static_cast(0x33333333U)) << 2)); + operand = + (((operand & static_cast(0xf0f0f0f0U)) >> 4) | ((operand & static_cast(0x0f0f0f0fU)) << 4)); + operand = + (((operand & static_cast(0xff00ff00U)) >> 8) | ((operand & static_cast(0x00ff00ffU)) << 8)); + return (static_cast(operand >> 16) | static_cast(operand << 16)) >> + (32 - static_cast(bit_count)); +} + +template <> +void heracles::math::multiply_uint(uint32_t operand1, uint32_t operand2, uint32_t *result) +{ + auto operand1_coeff_right = operand1 & 0x0000FFFFU; + auto operand2_coeff_right = operand2 & 0x0000FFFFU; + + operand1 >>= 16; + operand2 >>= 16; + auto middle1 = operand1 * operand2_coeff_right; + uint32_t middle; + auto left = operand1 * operand2 + + (static_cast(add_uint(middle1, operand2 * operand1_coeff_right, &middle)) << 16); + auto right = operand1_coeff_right * operand2_coeff_right; + auto temp_sum = (right >> 16) + (middle & 0x0000FFFFU); + result[1] = static_cast(left + (middle >> 16) + (temp_sum >> 16)); + result[0] = static_cast((temp_sum << 16) | (right & 0x0000FFFFU)); +} + +template <> +void heracles::math::multiply_uint(uint64_t operand1, uint64_t operand2, uint64_t *result) +{ + auto operand1_coeff_right = operand1 & 0x00000000FFFFFFFFUL; + auto operand2_coeff_right = operand2 & 0x00000000FFFFFFFFUL; + + operand1 >>= 32; + operand2 >>= 32; + auto middle1 = operand1 * operand2_coeff_right; + uint64_t middle; + auto left = operand1 * operand2 + + (static_cast(add_uint(middle1, operand2 * operand1_coeff_right, &middle)) << 32); + auto right = operand1_coeff_right * operand2_coeff_right; + auto temp_sum = (right >> 32) + (middle & 0x00000000FFFFFFFFUL); + result[1] = static_cast(left + (middle >> 32) + (temp_sum >> 32)); + result[0] = static_cast((temp_sum << 32) | (right & 0x00000000FFFFFFFFUL)); +} + +template <> +size_t heracles::math::get_msb_index(std::uint32_t value) +{ + static const std::uint8_t BitPositionLookup[32] = { 0, 1, 16, 2, 29, 17, 3, 22, 30, 20, 18, 11, 13, 4, 7, 23, + 31, 15, 28, 21, 19, 10, 12, 6, 14, 27, 9, 5, 26, 8, 25, 24 }; + + value |= (value >> 1); + value |= (value >> 2); + value |= (value >> 4); + value |= (value >> 8); + value |= (value >> 16); + + return BitPositionLookup[((value - (value >> 1)) * 0x06EB14F9U) >> 27]; +} + +template <> +size_t heracles::math::get_msb_index(std::uint64_t value) +{ + static const std::uint8_t BitPositionLookup[64] = { 63, 0, 58, 1, 59, 47, 53, 2, 60, 39, 48, 27, 54, 33, 42, 3, + 61, 51, 37, 40, 49, 18, 28, 20, 55, 30, 34, 11, 43, 14, 22, 4, + 62, 57, 46, 52, 38, 26, 32, 41, 50, 36, 17, 19, 29, 10, 13, 21, + 56, 45, 25, 31, 35, 16, 9, 12, 44, 24, 15, 8, 23, 7, 6, 5 }; + value |= (value >> 1); + value |= (value >> 2); + value |= (value >> 4); + value |= (value >> 8); + value |= (value >> 16); + value |= (value >> 32); + + return BitPositionLookup[((value - (value >> 1)) * 0x07EDD5E59A4E28C2UL) >> 58]; +} + +std::tuple heracles::math::xgcd(uint64_t x, uint64_t y) +{ + int64_t prev_a = 1; + int64_t a = 0; + int64_t prev_b = 0; + int64_t b = 1; + + while (y != 0) + { + int64_t q = static_cast(x / y); + int64_t temp = static_cast(x % y); + x = y; + y = static_cast(temp); + + temp = a; + a = prev_a - a * q; + prev_a = temp; + + temp = b; + b = prev_b - b * q; + prev_b = temp; + } + return std::make_tuple(x, prev_a, prev_b); +} + +std::tuple heracles::math::xgcd(uint32_t x, uint32_t y) +{ + int32_t prev_a = 1; + int32_t a = 0; + int32_t prev_b = 0; + int32_t b = 1; + + while (y != 0) + { + int32_t q = static_cast(x / y); + int32_t temp = static_cast(x % y); + x = y; + y = static_cast(temp); + + temp = a; + a = prev_a - a * q; + prev_a = temp; + + temp = b; + b = prev_b - b * q; + prev_b = temp; + } + return std::make_tuple(x, prev_a, prev_b); +} diff --git a/p-isa_tools/data_formats/cpp/heracles/data/transform.cpp b/p-isa_tools/data_formats/cpp/heracles/data/transform.cpp new file mode 100644 index 00000000..0b14531a --- /dev/null +++ b/p-isa_tools/data_formats/cpp/heracles/data/transform.cpp @@ -0,0 +1,506 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "heracles/data/transform.h" +#include +#include +#include +#include +#include +#include +#include "heracles/data/math.h" +#include "heracles/util/util.h" + +namespace hmath = heracles::math; +namespace hutildata = heracles::util::data; +namespace heracles::data +{ +void extract_metadata_polys( + heracles::data::MetadataPolynomials *metadata_polys, const heracles::data::FHEContext &context) +{ + auto sym_poly_map = metadata_polys->mutable_metadata()->mutable_sym_poly_map(); + + auto N = context.n(); + std::uint32_t nQ = context.q_i_size(); + std::vector psi(context.psi().begin(), context.psi().end()); + std::vector psi_inv(psi.size()); + for (size_t i = 0; i < psi.size(); ++i) + hmath::try_invert_uint_mod(psi[i], context.q_i(i), &psi_inv[i]); // 32 + + std::set galois_elts; + if (context.scheme() == heracles::common::SCHEME_BGV) + { + for (const auto &pt : context.bgv_info().plaintext_specific()) + { + auto keys = pt.keys(); + for (const auto &[ge, _] : keys.rotation_keys()) + galois_elts.insert(ge); + } + } + else if (context.scheme() == heracles::common::SCHEME_CKKS) + { + for (const auto &[ge, _] : context.ckks_info().keys().rotation_keys()) + galois_elts.insert(ge); + } + + for (std::uint32_t i = 0; i < nQ; ++i) + { + std::string key_psi_default = "psi_default_" + std::to_string(i); + std::string key_ipsi_default = "ipsi_default_" + std::to_string(i); + std::vector vpsi(N), vipsi(N); +#pragma omp parallel for + for (uint32_t j = 0; j < N; ++j) + { + vpsi[j] = hutildata::convert_to_montgomery( + hmath::exponentiate_uint_mod(psi[i], j, context.q_i(i)), context.q_i(i)); + vipsi[j] = hutildata::convert_to_montgomery( + hmath::exponentiate_uint_mod(psi_inv[i], j, context.q_i(i)), context.q_i(i)); + } + hutildata::poly_bit_reverse(&(*sym_poly_map)[key_psi_default], vpsi); + hutildata::poly_bit_reverse(&(*sym_poly_map)[key_ipsi_default], vipsi); + + // Get ql half and ql half mod q + // Rescale: (qi-1) / 2 mod qj for all i < j, (i>=2) + // mod_raise: (qi-1) / 2 mod qj for i=0,1 & all j + if (context.scheme() == heracles::common::SCHEME_CKKS && i < context.q_size()) + { + uint32_t qlHalf_i = context.q_i(i) >> 1; + (*sym_poly_map)["qlHalf_" + hutildata::toStrKey({ i })].mutable_coeffs()->Resize(N, qlHalf_i); + + auto jMax = i <= 1 ? static_cast(context.q_size()) : i; + for (std::uint32_t j = 0; j < jMax; ++j) + { + (*sym_poly_map)["qlHalfModq_" + hutildata::toStrKey({ i, j })].mutable_coeffs()->Resize( + N, qlHalf_i % context.q_i(j)); + } + } + + for (const uint32_t ge : galois_elts) + { + uint32_t exp_scale; + hmath::try_invert_uint_mod(ge, 2 * N, &exp_scale); // 32 + std::string key_ipsi_ge_i = "ipsi_" + std::to_string(ge) + "_" + std::to_string(i); + std::vector tmp(N); +#pragma omp parallel for + for (uint32_t j = 0; j < N; ++j) + { + tmp[j] = hutildata::convert_to_montgomery( + hmath::exponentiate_uint_mod(psi_inv[i], exp_scale * j, context.q_i(i)), context.q_i(i)); + } + hutildata::poly_bit_reverse(&(*sym_poly_map)[key_ipsi_ge_i], tmp); + } + } + + // key switch keys + if (context.scheme() == heracles::common::SCHEME_BGV) + { + for (int pt = 0; pt < context.bgv_info().plaintext_specific_size(); ++pt) + { + auto keys = context.bgv_info().plaintext_specific(pt).keys(); + std::string rlk_prefix = "rlk_" + std::to_string(pt); + hutildata::transform_and_flatten_key_switch( + metadata_polys->mutable_metadata(), rlk_prefix, keys.relin_key()); + for (const auto &[ge, key] : keys.rotation_keys()) + { + std::stringstream gk_prefix; + gk_prefix << "gk_" << pt << "_" << ge; + hutildata::transform_and_flatten_key_switch(metadata_polys->mutable_metadata(), gk_prefix.str(), key); + } + } + } + else if (context.scheme() == heracles::common::SCHEME_CKKS) + { + auto keys = context.ckks_info().keys(); + std::string rlk_prefix = "rlk"; + hutildata::transform_and_flatten_key_switch(metadata_polys->mutable_metadata(), rlk_prefix, keys.relin_key()); + + for (const auto &[ge, key] : keys.rotation_keys()) + { + std::stringstream gk_prefix; + gk_prefix << "gk_" << ge; + hutildata::transform_and_flatten_key_switch(metadata_polys->mutable_metadata(), gk_prefix.str(), key); + } + } + + // bootstrapping + if (context.scheme() == heracles::common::SCHEME_BGV) + { + if (context.bgv_info().has_recrypt_key()) + hutildata::transform_and_flatten_ciphertext( + metadata_polys->mutable_metadata(), "bk", context.bgv_info().recrypt_key()); + } + else if (context.scheme() == heracles::common::SCHEME_CKKS) + { + std::vector zeros(N, 0); + *((*sym_poly_map)["zero"].mutable_coeffs()) = { zeros.begin(), zeros.end() }; + } +} + +void extract_metadata_twiddles( + heracles::data::MetadataTwiddles *metadata_twiddles, const heracles::data::FHEContext &context) +{ + std::vector omega; + std::vector omega_inv; + omega.reserve(context.key_rns_num()); + omega_inv.reserve(context.key_rns_num()); + + // TODO(skmono): replace "default" to "0" for future update on ntt/intt + for (size_t i = 0; i < context.key_rns_num(); ++i) + { + omega.push_back(hmath::exponentiate_uint_mod(context.psi(i), 2U, context.q_i(i))); + uint32_t inv; + hmath::try_invert_uint_mod(omega.back(), context.q_i(i), &inv); // 32 + omega_inv.push_back(inv); + } + + auto twiddles_ntt = metadata_twiddles->mutable_twiddles_ntt(); + auto twiddles_intt = metadata_twiddles->mutable_twiddles_intt(); + + metadata_twiddles->set_only_power_of_two(false); + + for (size_t i = 0; i < context.key_rns_num(); ++i) + { + auto default_ntt = (*twiddles_ntt)["default"].add_rns_polys(); + auto default_intt = (*twiddles_intt)["default"].add_rns_polys(); + std::vector vntt(context.n() / 2), vintt(context.n() / 2); +#pragma omp parallel for + for (uint32_t j = 0; j < context.n() / 2; ++j) + { + vntt[j] = hutildata::convert_to_montgomery( + hmath::exponentiate_uint_mod(omega[i], j, context.q_i(i)), context.q_i(i)); + vintt[j] = hutildata::convert_to_montgomery( + hmath::exponentiate_uint_mod(omega_inv[i], j, context.q_i(i)), context.q_i(i)); + } + *(default_ntt->mutable_coeffs()) = { vntt.begin(), vntt.end() }; + *(default_intt->mutable_coeffs()) = { vintt.begin(), vintt.end() }; + default_ntt->set_modulus(context.q_i(i)); + default_intt->set_modulus(context.q_i(i)); + } + + // twiddle factors for galois elements + std::set galois_elts; + if (context.scheme() == heracles::common::SCHEME_BGV) + { + for (const auto &pt : context.bgv_info().plaintext_specific()) + { + auto keys = pt.keys(); + for (const auto &[ge, _] : keys.rotation_keys()) + galois_elts.insert(ge); + } + } + else if (context.scheme() == heracles::common::SCHEME_CKKS) + { + for (const auto &[ge, _] : context.ckks_info().keys().rotation_keys()) + galois_elts.insert(ge); + } + + for (const uint32_t ge : galois_elts) + { + uint32_t exp_scale; + hmath::try_invert_uint_mod(ge, 2 * context.n(), &exp_scale); // 32 + for (size_t i = 0; i < context.key_rns_num(); ++i) + { + auto ge_intt = (*twiddles_intt)[std::to_string(ge)].add_rns_polys(); + std::vector vintt_ge(context.n() / 2); +#pragma omp parallel for + for (uint32_t j = 0; j < context.n() / 2; ++j) + { + vintt_ge[j] = hutildata::convert_to_montgomery( + hmath::exponentiate_uint_mod(omega_inv[i], exp_scale * j, context.q_i(i)), context.q_i(i)); + } + *(ge_intt->mutable_coeffs()) = { vintt_ge.begin(), vintt_ge.end() }; + ge_intt->set_modulus(context.q_i(i)); + } + } +} + +bool extract_metadata_immediates( + heracles::data::MetadataImmediates *metadata_immediates, const heracles::data::FHEContext &context) +{ + auto sym_immediate_map = metadata_immediates->mutable_sym_immediate_map(); + (*sym_immediate_map)["one"] = 1; + if (context.scheme() == heracles::common::SCHEME_BGV) + { + uint32_t inv = 0; + + for (size_t i = 0; i < context.key_rns_num(); ++i) + { + (*sym_immediate_map)["R2_" + std::to_string(i)] = + hmath::exponentiate_uint_mod(hutildata::montgomery_R, 2UL, static_cast(context.q_i(i))); + hmath::try_invert_uint_mod(context.n(), context.q_i(i), &inv); // 32 + (*sym_immediate_map)["iN_" + std::to_string(i)] = hutildata::convert_to_montgomery(inv, context.q_i(i)); + for (size_t j = 0; j < i; ++j) + { + hmath::try_invert_uint_mod(context.q_i(i), context.q_i(j), &inv); // 32 + (*sym_immediate_map)["inv_q_i_" + std::to_string(i) + "_mod_q_j_" + std::to_string(j)] = + hutildata::convert_to_montgomery(inv, context.q_i(j)); + } + for (int pt = 0; pt < context.bgv_info().plaintext_specific_size(); ++pt) + { + hmath::try_invert_uint_mod( + static_cast(context.bgv_info().plaintext_specific(pt).plaintext_modulus()), + context.q_i(i), &inv); // 32 + (*sym_immediate_map)["neg_inv_t_" + std::to_string(pt) + "_mod_q_i_" + std::to_string(i)] = + hutildata::convert_to_montgomery(-inv, context.q_i(i)); + (*sym_immediate_map)["t_" + std::to_string(pt) + "_mod_q_i_" + std::to_string(i)] = + hutildata::convert_to_montgomery( + context.bgv_info().plaintext_specific(pt).plaintext_modulus(), context.q_i(i)); + } + } + + (*sym_immediate_map)["iN"] = static_cast(0x100000000ULL / static_cast(context.n())); + auto k = context.bgv_info().plaintext_specific(0).keys().relin_key().k(); + uint32_t p = context.q_i(context.key_rns_num() - 1); + for (uint32_t i = 0; i < context.key_rns_num() - 1; ++i) + { + hmath::try_invert_uint_mod(p, context.q_i(i), &inv); // 32 + (*sym_immediate_map)["inv_p_mod_q_i_" + std::to_string(i)] = + hutildata::convert_to_montgomery(inv, context.q_i(i)); + } + + for (size_t l = 0; l < context.key_rns_num() - 1; ++l) + { + for (size_t j = 0; j < context.key_rns_num(); ++j) + { + for (uint32_t i = 0; i < l + 1; ++i) + { + uint32_t q_over_qi_mod_qj = 1; + for (size_t k = 0; k < context.key_rns_num(); ++k) + { + if (k != i) + q_over_qi_mod_qj = + hmath::multiply_uint_mod(q_over_qi_mod_qj, context.q_i(k), context.q_i(j)); // 32 + } + (*sym_immediate_map) + ["base_change_matrix_" + std::to_string(i) + "_" + std::to_string(j) + "_" + + std::to_string(k)] = hutildata::convert_to_montgomery(q_over_qi_mod_qj, context.q_i(j)); + if (i == j) + { + hmath::try_invert_uint_mod(q_over_qi_mod_qj, context.q_i(i), &inv); // 32 + (*sym_immediate_map)["inv_punctured_prod_" + std::to_string(i) + "_" + std::to_string(i)] = + hutildata::convert_to_montgomery(inv, context.q_i(i)); + } + } + } + } + } + else if (context.scheme() == heracles::common::SCHEME_CKKS) + { + uint32_t inv = 0; + + auto dnum = context.digit_size(); + auto alpha = context.alpha(); + auto sizeQ = context.q_size(); + auto sizeP = context.key_rns_num() - sizeQ; + + for (size_t i = 0; i < context.key_rns_num(); ++i) + { + (*sym_immediate_map)["R2_" + std::to_string(i)] = + hmath::exponentiate_uint_mod(hutildata::montgomery_R, 2UL, static_cast(context.q_i(i))); + hmath::try_invert_uint_mod(context.n(), context.q_i(i), &inv); // 32 + (*sym_immediate_map)["iN_" + std::to_string(i)] = hutildata::convert_to_montgomery(inv, context.q_i(i)); + } + (*sym_immediate_map)["iN"] = static_cast(0x100000000ULL / static_cast(context.n())); + + // TODO: remove dnum/alpha + // Get q0 inv mod q1 and q1 inv mod q0 for ModRaise kernel + std::uint32_t q0InvModq1 = heracles::math::get_invert_uint_mod(context.q_i(0), context.q_i(1)); + std::uint32_t q1InvModq0 = heracles::math::get_invert_uint_mod(context.q_i(1), context.q_i(0)); + (*sym_immediate_map)["q0InvModq1"] = hutildata::convert_to_montgomery(q0InvModq1, context.q_i(1)); + + (*sym_immediate_map)["q1InvModq0"] = hutildata::convert_to_montgomery(q1InvModq0, context.q_i(0)); + + // Metadata for key-switching (Relin, Rotate) + // PartQHatInvModq_{i}_{j} = (Q/Qi)^-1 mod qj; equals to zero for qj \notin Qi + for (uint32_t i = 0; i < dnum; ++i) + { + for (uint32_t j = 0; j < sizeQ; ++j) + { + (*sym_immediate_map)["partQHatInvModq_" + hutildata::toStrKey({ i, j })] = + hutildata::convert_to_montgomery( + context.ckks_info().metadata_extra().at("partQHatInvModq_" + hutildata::toStrKey({ i, j })), + context.q_i(j)); + } + } + + // PartQlHatInvModq_{i}_{j}_{l} = (Q^(i*alpha + j)_i/ql)^-1 mod ql for ql \in Q^(i*alpha + j)_i + for (uint32_t i = 0; i < dnum; ++i) + { + uint32_t digitSize = i < (dnum - 1) ? alpha : sizeQ - alpha * (dnum - 1); + for (uint32_t j = 0; j < digitSize; ++j) + { + for (uint32_t l = 0; l < j + 1; ++l) + { + (*sym_immediate_map)["partQlHatInvModq_" + hutildata::toStrKey({ i, j, l })] = + hutildata::convert_to_montgomery( + context.ckks_info().metadata_extra().at( + "partQlHatInvModq_" + hutildata::toStrKey({ i, j, l })), + context.q_i(alpha * i + l)); + } + } + } + + // PartQlHatModp_{i}_{j}_{l}_{s} = (Q^(i)_j/ql)^-1 mod qs or ps, for qs \notin Q^(i)_j + for (uint32_t i = 0; i < sizeQ; ++i) + { + uint32_t beta = std::ceil(static_cast(i + 1) / static_cast(alpha)); + for (uint32_t j = 0; j < beta; ++j) + { + uint32_t digitSize = j < beta - 1 ? alpha : (i + 1) - alpha * (beta - 1); + auto sizeCompl = (i + 1) + sizeP - digitSize; + for (uint32_t l = 0; l < digitSize; ++l) + { + for (uint32_t s = 0; s < sizeCompl; ++s) + { + size_t idx = + s < alpha * j ? s : (s < i + 1 - digitSize ? s + digitSize : s + digitSize + sizeQ - i - 1); + (*sym_immediate_map)["partQlHatModp_" + hutildata::toStrKey({ i, j, l, s })] = + hutildata::convert_to_montgomery( + context.ckks_info().metadata_extra().at( + "partQlHatModp_" + hutildata::toStrKey({ i, j, l, s })), + context.q_i(idx)); + } + } + } + } + + // pInvModq_{i} = P^{-1} mod qi + for (uint32_t i = 0; i < sizeQ; ++i) + { + (*sym_immediate_map)["pInvModq_" + std::to_string(i)] = hutildata::convert_to_montgomery( + context.ckks_info().metadata_extra().at("pInvModq_" + std::to_string(i)), context.q_i(i)); + (*sym_immediate_map)["pModq_" + std::to_string(i)] = hutildata::convert_to_montgomery( + context.ckks_info().metadata_extra().at("pModq_" + std::to_string(i)), context.q_i(i)); + } + + // pInvModp_{i} = P^{-1} mod pi + for (uint32_t i = 0; i < sizeP; ++i) + (*sym_immediate_map)["pHatInvModp_" + std::to_string(i)] = hutildata::convert_to_montgomery( + context.ckks_info().metadata_extra().at("pHatInvModp_" + std::to_string(i)), context.q_i(i + sizeQ)); + + // pHatModq_{i}_{j} = P/pi mod qj + for (uint32_t i = 0; i < sizeP; ++i) + { + for (uint32_t j = 0; j < sizeQ; ++j) + { + (*sym_immediate_map)["pHatModq_" + hutildata::toStrKey({ i, j })] = hutildata::convert_to_montgomery( + context.ckks_info().metadata_extra().at("pHatModq_" + hutildata::toStrKey({ i, j })), + context.q_i(j)); + } + } + + // Metadata for Rescale + // qlInvModq_{i}_{j} = q_{sizeQ-(i+1)}^{-1} mod qj + // QlQlInvModqlDivqlModq_{i}_{j} = ((Q/q_{sizeQ-(i+1)})^{-1} mod q_{sizeQ-(i+1)} * (Q/q_{sizeQ-(i+1)})) mod + // qj + for (uint32_t i = 0; i < sizeQ - 1; ++i) + { + for (uint32_t j = 0; j < sizeQ - i - 1; ++j) + { + (*sym_immediate_map)["qlInvModq_" + hutildata::toStrKey({ i, j })] = hutildata::convert_to_montgomery( + context.ckks_info().metadata_extra().at("qlInvModq_" + hutildata::toStrKey({ i, j })), + context.q_i(j)); + (*sym_immediate_map)["QlQlInvModqlDivqlModq_" + hutildata::toStrKey({ i, j })] = + hutildata::convert_to_montgomery( + context.ckks_info().metadata_extra().at( + "QlQlInvModqlDivqlModq_" + hutildata::toStrKey({ i, j })), + context.q_i(j)); + } + } + + // Metadata for Bootstrap + for (size_t i = 0; i < 2; ++i) + { + for (size_t j = 0; j < sizeQ; ++j) + { + (*sym_immediate_map)["qlModq_" + std::to_string(i) + "_" + std::to_string(j)] = + hutildata::convert_to_montgomery(context.q_i(i), context.q_i(j)); + } + } + + auto boot_correction = context.ckks_info().metadata_extra().at("boot_correction"); + for (std::uint32_t i = 0; i < 32; ++i) + { + std::uint32_t val = 1 << i; + for (size_t j = 0; j < sizeQ; ++j) + { + (*sym_immediate_map)["bmu_" + std::to_string(val) + "_" + std::to_string(j)] = + hutildata::convert_to_montgomery(val, context.q_i(j)); + // only perform once + if (i == 0) + (*sym_immediate_map)["bmu_" + std::to_string(boot_correction)] = + hutildata::convert_to_montgomery(boot_correction, context.q_i(j)); + } + } + } + else + return false; + + return true; +} + +void extract_polys(heracles::data::DataPolynomials *polys, const heracles::data::TestVector &testvector) +{ + for (const auto &[key, data] : testvector.sym_data_map()) + { + hutildata::transform_and_flatten_dcrtpoly(polys->mutable_data(), key, data.dcrtpoly()); + } +} + +void extract_metadata_params(heracles::data::MetadataParams *metadata_params, const heracles::data::FHEContext &context) +{ + auto sym_param_map = metadata_params->mutable_sym_param_map(); + (*sym_param_map)["key_rns_num"] = context.key_rns_num(); + (*sym_param_map)["digit_size"] = context.digit_size(); + (*sym_param_map)["q_size"] = context.q_size(); + (*sym_param_map)["alpha"] = context.alpha(); + + // TODO: this is duplicate of above "digit_size", later merge two one + (*sym_param_map)["dnum"] = context.digit_size(); +} + +void convert_polys_to_testvector(heracles::data::TestVector *testvector, const heracles::data::DataPolynomials &polys) +{ + std::unordered_map> sym_map; + for (const auto &item : polys.data().sym_poly_map()) + { + // find root symbol and order/num_rns + auto [sym_basename, order, rns] = hutildata::split_symbol_name(item.first); + if (sym_map.find(sym_basename) == sym_map.end()) + { + sym_map[sym_basename] = { order + 1, rns + 1 }; + continue; + } + auto &[order_max, rns_num] = sym_map[sym_basename]; + order_max = std::max(order_max, order + 1); + rns_num = std::max(rns_num, rns + 1); + } + + auto merge_sym = [](std::string root, uint32_t o, uint32_t r) { + return root + "_" + std::to_string(o) + "_" + std::to_string(r); + }; + + for (const auto &[sym_basename, params] : sym_map) + { + heracles::data::DCRTPoly data; + for (uint32_t i = 0; i < params.first; ++i) + { + auto poly = data.add_polys(); + for (uint32_t j = 0; j < params.second; ++j) + { + hutildata::convert_rnspoly_to_original( + poly->add_rns_polys(), polys.data().sym_poly_map().at(merge_sym(sym_basename, i, j))); + } + } + *((*(testvector->mutable_sym_data_map()))[sym_basename].mutable_dcrtpoly()) = data; + } +} + +void prune_polys( + heracles::data::TestVector *testvector, const heracles::data::FHEContext &context, + const heracles::fhe_trace::Trace &trace) +{ + throw std::logic_error("Not yet implemented!"); +} + +} // namespace heracles::data diff --git a/p-isa_tools/data_formats/cpp/heracles/fhe_trace/io.cpp b/p-isa_tools/data_formats/cpp/heracles/fhe_trace/io.cpp new file mode 100644 index 00000000..f4bff5ca --- /dev/null +++ b/p-isa_tools/data_formats/cpp/heracles/fhe_trace/io.cpp @@ -0,0 +1,58 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "heracles/fhe_trace/io.h" +#include +#include +#include +#include +#include + +namespace heracles::fhe_trace +{ +bool store_trace(const std::string &filename, const heracles::fhe_trace::Trace &trace) +{ + std::ofstream pb_ofile(filename, std::ios::out | std::ios::binary); + return trace.SerializeToOstream(&pb_ofile); +} + +heracles::fhe_trace::Trace load_trace(const std::string &filename) +{ + std::ifstream pb_ifile(filename, std::ios::binary); + heracles::fhe_trace::Trace trace; + if (!trace.ParseFromIstream(&pb_ifile)) + throw std::runtime_error("Cannot read from file : " + filename); + return trace; +} + +bool store_json_trace(const std::string &filename, const heracles::fhe_trace::Trace &trace) +{ + std::string str; + auto status = google::protobuf::util::MessageToJsonString(trace, &str); + if (!status.ok()) + return false; + std::ofstream json_ofile(filename, std::ios::out); + json_ofile << str; + return true; +} + +heracles::fhe_trace::Trace load_json_trace(const std::string &filename) +{ + std::ifstream json_ifile(filename); + if (!json_ifile.is_open()) + { + throw std::runtime_error("Cannot open file: " + filename); + } + + std::string json_str((std::istreambuf_iterator(json_ifile)), std::istreambuf_iterator()); + + heracles::fhe_trace::Trace trace; + auto status = google::protobuf::util::JsonStringToMessage(json_str, &trace); + if (!status.ok()) + { + throw std::runtime_error("Cannot parse JSON from file: " + filename); + } + + return trace; +} +} // namespace heracles::fhe_trace diff --git a/p-isa_tools/data_formats/cpp/heracles/util/util.cpp b/p-isa_tools/data_formats/cpp/heracles/util/util.cpp new file mode 100644 index 00000000..fc9c014f --- /dev/null +++ b/p-isa_tools/data_formats/cpp/heracles/util/util.cpp @@ -0,0 +1,281 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "heracles/util/util.h" +#include +#include + +namespace hmath = heracles::math; + +namespace heracles::util +{ +namespace data +{ + // some utility functions needed during HEC transformations + + // - montgomery transformation + uint32_t convert_to_montgomery(const uint32_t num, const std::uint32_t modulus) + { + return static_cast((static_cast(num) << 32) % modulus); + } + + // - montgomery to normal conversion + uint32_t convert_to_normal(const uint32_t num, const std::uint32_t modulus) + { + uint64_t inv_r = 0; + hmath::try_invert_uint_mod(montgomery_R, static_cast(modulus), &inv_r); // 64 + return hmath::multiply_uint_mod(num, static_cast(inv_r), modulus); + } + + uint32_t convert_to_normal_inv_r(const uint32_t num, const std::uint32_t inv_r, const std::uint32_t modulus) + { + return hmath::multiply_uint_mod(num, inv_r, modulus); + } + + void poly_bit_reverse(heracles::data::RNSPolynomial *dst, const heracles::data::RNSPolynomial &src) + { + std::size_t degree = src.coeffs_size(); + std::size_t logDegree = log2(degree); + if (degree != static_cast(1 << logDegree)) + throw std::runtime_error("RNS polynomial degree mismatch"); + + std::vector tmp(degree); + +#pragma omp parallel for + for (size_t i = 0; i < degree; i++) + { + tmp[i] = src.coeffs(hmath::reverse_bits(i, logDegree)); + } + *(dst->mutable_coeffs()) = { tmp.begin(), tmp.end() }; + dst->set_modulus(src.modulus()); + } + + void poly_bit_reverse(heracles::data::RNSPolynomial *dst, const std::vector &src) + { + std::size_t degree = src.size(); + std::size_t logDegree = log2(degree); + if (degree != static_cast(1 << logDegree)) + throw std::runtime_error("RNS polynomial degree mismatch"); + + std::vector tmp(degree); + +#pragma omp parallel for + for (size_t i = 0; i < degree; i++) + { + tmp[i] = src[hmath::reverse_bits(i, logDegree)]; + } + *(dst->mutable_coeffs()) = { tmp.begin(), tmp.end() }; + } + + void poly_bit_reverse_inplace(heracles::data::RNSPolynomial *src) + { + heracles::data::RNSPolynomial tmp; + poly_bit_reverse(&tmp, *src); + *src = tmp; + } + + void transform_and_flatten_key_switch( + heracles::data::PolySymbols *poly_symbols, const std::string &prefix, const heracles::data::KeySwitch &data) + { + for (int d = 0; d < data.digits_size(); ++d) + { + for (int p = 0; p < data.digits(d).polys_size(); ++p) + { + std::string flatten_prefix = prefix + "_" + std::to_string(p) + "_" + std::to_string(d); + transform_and_flatten_poly(poly_symbols, flatten_prefix, data.digits(d).polys(p)); + } + } + } + + void transform_and_flatten_ciphertext( + heracles::data::PolySymbols *poly_symbols, const std::string &prefix, const heracles::data::Ciphertext &data) + { + for (int p = 0; p < data.polys_size(); ++p) + { + std::string flatten_prefix = prefix + "_" + std::to_string(p); + transform_and_flatten_poly(poly_symbols, flatten_prefix, data.polys(p)); + } + } + + void transform_and_flatten_plaintext( + heracles::data::PolySymbols *poly_symbols, const std::string &prefix, const heracles::data::Plaintext &data) + { + transform_and_flatten_poly(poly_symbols, prefix, data.poly()); + } + + void transform_and_flatten_dcrtpoly( + heracles::data::PolySymbols *poly_symbols, const std::string &prefix, const heracles::data::DCRTPoly &data) + { + for (int p = 0; p < data.polys_size(); ++p) + { + std::string flatten_prefix = prefix + "_" + std::to_string(p); + transform_and_flatten_poly(poly_symbols, flatten_prefix, data.polys(p)); + } + } + + void transform_and_flatten_poly( + heracles::data::PolySymbols *poly_symbols, const std::string &prefix, const heracles::data::Polynomial &poly) + { + for (int r = 0; r < poly.rns_polys_size(); ++r) + { + std::string poly_prefix = prefix + "_" + std::to_string(r); + auto sym_poly_map = poly_symbols->mutable_sym_poly_map(); + std::vector tmp(poly.rns_polys(r).coeffs_size()); +#pragma omp parallel for + for (int j = 0; j < poly.rns_polys(r).coeffs_size(); ++j) + tmp[j] = convert_to_montgomery(poly.rns_polys(r).coeffs(j), poly.rns_polys(r).modulus()); + + poly_bit_reverse(&(*sym_poly_map)[poly_prefix], tmp); + (*sym_poly_map)[poly_prefix].set_modulus(poly.rns_polys(r).modulus()); + } + } + + void convert_rnspoly_to_original(heracles::data::RNSPolynomial *dest, const heracles::data::RNSPolynomial &src) + { + uint64_t inv_r = 0; + hmath::try_invert_uint_mod(montgomery_R, static_cast(src.modulus()), &inv_r); + + dest->set_modulus(src.modulus()); + std::vector tmp(src.coeffs_size()); +#pragma omp parallel for + for (int j = 0; j < src.coeffs_size(); ++j) + tmp[j] = convert_to_normal_inv_r(src.coeffs(j), static_cast(inv_r), src.modulus()); + + poly_bit_reverse(dest, tmp); + } + + std::tuple split_symbol_name(const std::string &sym) + { + std::vector buf; + + int loc = -1, prev_loc = 0; + do + { + loc = sym.find('_', loc + 1); + + auto tmp = sym.substr(prev_loc, loc == std::string::npos ? std::string::npos : loc - prev_loc); + buf.push_back(tmp); + prev_loc = loc + 1; + } while (loc != std::string::npos && buf.size() <= 2); + + if (buf.size() != 3) + throw std::runtime_error("Symbol name is not in correct form"); + + return { buf[0], std::stoul(buf[1]), std::stoul(buf[2]) }; + } + + std::vector toIndex(const std::string &key) + { + std::istringstream ss(key); + std::vector indices; + std::string buf; + while (std::getline(ss, buf, '_')) + { + // skip if not digit + if (buf.find_first_not_of("0123456789") == std::string::npos) + indices.push_back(std::stoul(buf)); + } + return indices; + } + + std::string toStrKey(const std::vector &indices) + { + std::ostringstream key; + std::string sep = ""; + for (const auto &idx : indices) + { + key << sep << idx; + sep = "_"; + } + return key.str(); + } +} // namespace data + +namespace fhe_trace +{ + void print_instruction(const heracles::fhe_trace::Instruction &inst, const std::string &header, bool printBKops) + { + if (printBKops || inst.op().substr(0, 3) != "bk_") + std::cout << header << (header.length() > 0 ? " " : "") << inst << std::endl; + } + + std::ostream &operator<<(std::ostream &out, const heracles::fhe_trace::Instruction &inst) + { + std::string op = inst.op(); + out << op << DELIMITER; + + auto dest = inst.args().dests(0); + out << dest.symbol_name() << DELIMITER << dest.num_rns() << DELIMITER << dest.order() << DELIMITER; + + for (const auto &src : inst.args().srcs()) + { + out << src.symbol_name() << DELIMITER << src.num_rns() << DELIMITER << src.order() << DELIMITER; + } + + for (const auto &[k, v] : inst.args().params()) + { + out << v.value() << DELIMITER; + } + + return out; + } + + void print_trace(const heracles::fhe_trace::Trace &trace) + { + std::string scheme = heracles::common::Scheme_descriptor()->FindValueByNumber(trace.scheme())->name(); + scheme = scheme.substr(7); // Remove "SCHEME_" prefix + std::uint32_t N = trace.n(); + + size_t sz_instructions = trace.instructions_size(); + for (size_t i = 0; i < sz_instructions; ++i) + { + auto inst = trace.instructions(i); + + std::cout << i << ":"; + std::cout << scheme << DELIMITER << N << DELIMITER << inst << std::endl; + } + } + + std::pair, std::vector> get_symbols( + const heracles::fhe_trace::Instruction &inst) + { + std::pair, std::vector> res; + + for (const auto &dest : inst.args().dests()) + res.second.push_back(dest.symbol_name()); + + for (const auto &src : inst.args().srcs()) + res.first.push_back(src.symbol_name()); + + return res; + } + std::pair, std::unordered_set> get_all_symbols( + const heracles::fhe_trace::Trace &trace, bool exclusive_outputs) + { + std::unordered_set symbols_input; + std::unordered_set symbols_output; + for (const auto &instruction : trace.instructions()) + { + std::string op = instruction.op(); + if (op.substr(0, 3) == "bk_") + continue; + + auto [src_symbols, dest_symbols] = get_symbols(instruction); + for (const std::string &sym : src_symbols) + symbols_input.insert(sym); + for (const std::string &sym : dest_symbols) + symbols_output.insert(sym); + } + if (exclusive_outputs) + { + std::unordered_set tmp; + std::set_difference( + symbols_output.begin(), symbols_output.end(), symbols_input.begin(), symbols_input.end(), + std::inserter(tmp, tmp.begin())); + symbols_output = tmp; + } + + return std::make_pair(symbols_input, symbols_output); + } +} // namespace fhe_trace +} // namespace heracles::util diff --git a/p-isa_tools/data_formats/cpp/include/heracles/data/io.h b/p-isa_tools/data_formats/cpp/include/heracles/data/io.h new file mode 100644 index 00000000..1ba7ba1b --- /dev/null +++ b/p-isa_tools/data_formats/cpp/include/heracles/data/io.h @@ -0,0 +1,38 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include +#include "heracles/proto/data.pb.h" + +namespace heracles::data +{ +using hdf_manifest = std::unordered_map>; + +hdf_manifest parse_manifest(const std::string &filename); +void generate_manifest(const std::string &filename, const hdf_manifest &manifest); + +bool store_data_trace( + const std::string &filename, const heracles::data::FHEContext &context_pb, + const heracles::data::TestVector &testvector_pb); +std::pair load_data_trace(const std::string &filename); + +void store_hec_context( + hdf_manifest *manifest_out, const std::string &filename, const heracles::data::FHEContext &context_pb); +void store_testvector( + hdf_manifest *manifest_out, const std::string &filename, const heracles::data::TestVector &testvector_pb); +heracles::data::FHEContext load_hec_context(const std::string &filename); +heracles::data::TestVector load_testvector(const std::string &filename); + +void load_hec_context_from_manifest(heracles::data::FHEContext *context_pb, const hdf_manifest &manifest); +void load_testvector_from_manifest(heracles::data::TestVector *testvector_pb, const hdf_manifest &manifest); + +//================================== +// For debugging +//================================== +bool store_hec_context_json(const std::string &filename, const heracles::data::FHEContext &context); +bool store_testvector_json(const std::string &filename, const heracles::data::TestVector &test_vector); +} // namespace heracles::data diff --git a/p-isa_tools/data_formats/cpp/include/heracles/data/math.h b/p-isa_tools/data_formats/cpp/include/heracles/data/math.h new file mode 100644 index 00000000..d9d623c2 --- /dev/null +++ b/p-isa_tools/data_formats/cpp/include/heracles/data/math.h @@ -0,0 +1,406 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace heracles::math +{ +template +struct isUInt : std::conditional< + std::is_integral::value && std::is_unsigned::value && + ((sizeof(T) == sizeof(std::uint64_t)) || (sizeof(T) == sizeof(std::uint32_t))), + std::true_type, std::false_type>::type +{}; +template +constexpr bool is_uint_v = isUInt::value; + +template >> +inline T add_uint_mod(const T operand1, const T operand2, const T modulus) +{ + T res = operand1 + operand2; + return res >= modulus ? res - modulus : res; +} + +template >> +inline T negate_uint_mod(const T operand, const T modulus) +{ + T non_zero = static_cast(operand != 0); + return (modulus - operand) & static_cast(-non_zero); +} + +template +inline void multiply_uint(T operand1, T operand2, T *result) +{ + throw std::logic_error("undefined behavior"); +} + +template <> +void multiply_uint(uint32_t operand1, uint32_t operand2, uint32_t *result); + +template <> +void multiply_uint(uint64_t operand1, uint64_t operand2, uint64_t *result); + +template +inline size_t get_msb_index(T value) +{ + throw std::logic_error("undefined behavior"); +} + +template <> +size_t get_msb_index(uint32_t value); +template <> +size_t get_msb_index(uint64_t value); + +template >> +inline int get_significant_bit_count(T value) +{ + if (value == 0) + return 0; + return static_cast(get_msb_index(value) + 1); +} + +template >> +inline int get_significant_bit_count_uint(const T *value, size_t uint_count) +{ + const size_t Tbitsz = sizeof(T) * 8; + value += uint_count - 1; + while (*value == 0 && uint_count > 1) + { + uint_count--; + value--; + } + return static_cast(uint_count - 1) * Tbitsz + get_significant_bit_count(*value); +} + +template >> +inline void right_shift_uint3(const T *operand, int shift_amount, T *result) +{ + const size_t Tbitsz = sizeof(T) * 8; + const size_t shift_amount_sz = static_cast(shift_amount); + if (shift_amount_sz & (Tbitsz * 2)) + { + result[0] = operand[2]; + result[1] = 0; + result[2] = 0; + } + else if (shift_amount_sz & Tbitsz) + { + result[0] = operand[1]; + result[1] = operand[2]; + result[2] = 0; + } + else + { + result[0] = operand[0]; + result[1] = operand[1]; + result[2] = operand[2]; + } + + size_t bit_shift_amount = shift_amount_sz & (Tbitsz - 1); + if (bit_shift_amount) + { + size_t neg_bit_shift_amount = Tbitsz - bit_shift_amount; + result[0] = (result[0] >> bit_shift_amount) | (result[1] << neg_bit_shift_amount); + result[1] = (result[1] >> bit_shift_amount) | (result[2] << neg_bit_shift_amount); + result[2] = result[2] >> bit_shift_amount; + } +} + +template >> +inline void left_shift_uint3(const T *operand, int shift_amount, T *result) +{ + const size_t Tbitsz = sizeof(T) * 8; + const size_t shift_amount_sz = static_cast(shift_amount); + if (shift_amount_sz & (Tbitsz * 2)) + { + result[2] = operand[0]; + result[1] = 0; + result[0] = 0; + } + else if (shift_amount_sz & Tbitsz) + { + result[2] = operand[1]; + result[1] = operand[0]; + result[0] = 0; + } + else + { + result[2] = operand[2]; + result[1] = operand[1]; + result[0] = operand[0]; + } + + size_t bit_shift_amount = shift_amount_sz & (Tbitsz - 1); + if (bit_shift_amount) + { + size_t neg_bit_shift_amount = Tbitsz - bit_shift_amount; + result[2] = (result[2] << bit_shift_amount) | (result[1] >> neg_bit_shift_amount); + result[1] = (result[1] << bit_shift_amount) | (result[0] >> neg_bit_shift_amount); + result[0] = result[0] << bit_shift_amount; + } +} + +template >> +inline unsigned char add_uint(T operand1, T operand2, T *result) +{ + *result = operand1 + operand2; + return *result < operand1; +} + +template >> +inline unsigned char add_uint(T operand1, T operand2, unsigned char carry, T *result) +{ + operand1 += operand2; + *result = operand1 + carry; + return (operand1 < operand2) || (~operand1 < carry); +} + +template >> +inline unsigned char add_uint_base(const T *operand1, const T *operand2, size_t uint_count, T *result) +{ + unsigned char carry = add_uint(*operand1++, *operand2++, result++); + for (; --uint_count; operand1++, operand2++, result++) + { + T temp_result; + carry = add_uint(*operand1, *operand2, carry, &temp_result); + *result = temp_result; + } + return carry; +} + +template >> +inline unsigned char sub_uint(T operand1, T operand2, T *result) +{ + *result = operand1 - operand2; + return operand2 > operand1; +} + +template >> +inline unsigned char sub_uint(T operand1, T operand2, unsigned char borrow, T *result) +{ + T diff = operand1 - operand2; + *result = diff - (borrow != 0); + return (diff > operand1) || (diff < borrow); +} + +template >> +inline unsigned char sub_uint_base(const T *operand1, const T *operand2, size_t uint_count, T *result) +{ + unsigned char borrow = sub_uint(*operand1++, *operand2++, result++); + for (; --uint_count; operand1++, operand2++, result++) + { + T temp_result; + borrow = sub_uint(*operand1, *operand2, borrow, &temp_result); + *result = temp_result; + } + return borrow; +} + +template >> +inline void set_zero_uint(size_t uint_count, T *result) +{ + std::fill_n(result, uint_count, static_cast(0)); +} + +template >> +inline void divide_uint3_inplace(T *numerator, T denominator, T *quotient) +{ + size_t Tbitsz = sizeof(T) * 8; + size_t uint_count = 3; + quotient[0] = 0; + quotient[1] = 0; + quotient[2] = 0; + + int numerator_bits = get_significant_bit_count_uint(numerator, uint_count); + int denominator_bits = get_significant_bit_count(denominator); + + if (numerator_bits < denominator_bits) + return; + + uint_count = static_cast((numerator_bits + Tbitsz - 1) / Tbitsz); + if (uint_count == 1) + { + *quotient = *numerator / denominator; + *numerator -= *quotient * denominator; + return; + } + + std::vector shifted_denominator(uint_count, 0); + shifted_denominator[0] = denominator; + + std::vector difference(uint_count); + int denominator_shift = numerator_bits - denominator_bits; + + heracles::math::left_shift_uint3(shifted_denominator.data(), denominator_shift, shifted_denominator.data()); + denominator_bits += denominator_shift; + + int remaining_shifts = denominator_shift; + + while (numerator_bits == denominator_bits) + { + if (heracles::math::sub_uint_base(numerator, shifted_denominator.data(), uint_count, difference.data())) + { + if (remaining_shifts == 0) + break; + heracles::math::add_uint_base(difference.data(), numerator, uint_count, difference.data()); + heracles::math::left_shift_uint3(quotient, 1, quotient); + remaining_shifts--; + } + quotient[0] |= 1; + numerator_bits = heracles::math::get_significant_bit_count_uint(difference.data(), uint_count); + int numerator_shift = denominator_bits - numerator_bits; + numerator_shift = std::min(numerator_shift, remaining_shifts); + + if (numerator_bits > 0) + { + left_shift_uint3(difference.data(), numerator_shift, numerator); + numerator_bits += numerator_shift; + } + else + heracles::math::set_zero_uint(uint_count, numerator); + + heracles::math::left_shift_uint3(quotient, numerator_shift, quotient); + remaining_shifts -= numerator_shift; + } + if (numerator_bits > 0) + heracles::math::right_shift_uint3(numerator, denominator_shift, numerator); +} + +template >> +inline T multiply_uint_mod(const T operand1, const T operand2, const T modulus) +{ + if (modulus == 0) + throw std::invalid_argument("modulus cannot be zero"); + + T prod[2]; + multiply_uint(operand1, operand2, prod); + + // barrett reduction 32-bit + T numerator[3]{ 0, 0, 1 }; + T quotient[3]{ 0, 0, 0 }; + + heracles::math::divide_uint3_inplace(numerator, modulus, quotient); + + std::vector const_ratio{ quotient[0], quotient[1], numerator[0] }; + + T tmp1, tmp2[2], tmp3, carry[2]; + + multiply_uint(prod[0], const_ratio[0], carry); + + heracles::math::multiply_uint(prod[0], const_ratio[1], tmp2); + tmp3 = tmp2[1] + heracles::math::add_uint(tmp2[0], carry[1], &tmp1); + + heracles::math::multiply_uint(prod[1], const_ratio[0], tmp2); + carry[1] = tmp2[1] + heracles::math::add_uint(tmp1, tmp2[0], &tmp1); + + tmp1 = prod[1] * const_ratio[1] + tmp3 + carry[1]; + tmp3 = prod[0] - tmp1 * modulus; + + return tmp3 >= modulus ? tmp3 - modulus : tmp3; +} + +template >> +inline T exponentiate_uint_mod(const T operand, T exponent, const T modulus) +{ + if (exponent == 0) + return 1; + if (exponent == 1) + return operand; + T power = operand; + T product = 0; + T intermediate = 1; + while (true) + { + if (exponent & 1) + { + product = multiply_uint_mod(power, intermediate, modulus); + std::swap(product, intermediate); + } + exponent >>= 1; + if (exponent == 0) + break; + product = multiply_uint_mod(power, power, modulus); + std::swap(product, power); + } + return intermediate; +} + +std::tuple xgcd(uint64_t x, uint64_t y); +std::tuple xgcd(uint32_t x, uint32_t y); + +template >> +inline bool try_invert_uint_mod(const T value, const T modulus, T *result) +{ + if (value == 0) + return false; + + auto gcd_tuple = xgcd(value, modulus); + if (std::get<0>(gcd_tuple) != 1) + return false; + else if (std::get<1>(gcd_tuple) < 0) + { + *result = static_cast(std::get<1>(gcd_tuple)) + modulus; + return true; + } + + *result = static_cast(std::get<1>(gcd_tuple)); + return true; +} + +template >> +T get_invert_uint_mod(const T value, const T modulus) +{ + T result; + if (!try_invert_uint_mod(value, modulus, &result)) + { + std::ostringstream msg; + msg << "Cannot invert value " << value << " with modulus " << modulus; + throw std::runtime_error(msg.str()); + } + + return result; +} +std::uint32_t reverse_bits(const std::uint32_t operand, std::uint32_t bit_count = 32); + +inline std::uint32_t montgomeryAdd(const std::uint32_t a, const std::uint32_t b, const std::uint32_t modulus) +{ + return heracles::math::add_uint_mod(a, b, modulus); +} + +inline std::uint32_t montgomeryMul( + const std::uint32_t a, const std::uint32_t b, const std::uint32_t modulus, bool use_mont = true) +{ + if (!use_mont) + return a * b % modulus; + + std::uint32_t u[2]; + heracles::math::multiply_uint(a, b, u); + + std::uint32_t k = modulus - 2; + std::uint32_t m[2]; + // u[0] = lower 32bit + heracles::math::multiply_uint(u[0], k, m); + + // z = low 32bit m (m[0]) * modulus + std::uint32_t z[2]; + heracles::math::multiply_uint(m[0], modulus, z); + + // _u = u + z + std::uint32_t _u[2]; + heracles::math::add_uint_base(u, z, 2, _u); + + // return high 32bit (_u[1]) + return _u[1] < modulus ? _u[1] : _u[1] - modulus; +} +} // namespace heracles::math diff --git a/p-isa_tools/data_formats/cpp/include/heracles/data/transform.h b/p-isa_tools/data_formats/cpp/include/heracles/data/transform.h new file mode 100644 index 00000000..168b6c74 --- /dev/null +++ b/p-isa_tools/data_formats/cpp/include/heracles/data/transform.h @@ -0,0 +1,54 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "heracles/heracles_proto.h" + +namespace heracles::data +{ +// Note: extraction does all expansion and transformation, e.g., bit-reversal and montgomery conversion of +// data in context & test-vectors. I.e., test-vector and context are mostly HEC agnostic ... + +/* + Extract symbol/value map of all metadata polynomials as needed to build (after swizzling) memory images or + DMA downloads + */ +void extract_metadata_polys( + heracles::data::MetadataPolynomials *metadata_polys, const heracles::data::FHEContext &context); + +/* + Extract symbol/value map of all twiddles as needed to build (after swizzling & replicating) memory images or + DMA downloads + */ +void extract_metadata_twiddles( + heracles::data::MetadataTwiddles *metadata_twiddles, const heracles::data::FHEContext &context); + +/* + Extract symbol/value map of all immediates as needed for final code instantiation + */ +bool extract_metadata_immediates( + heracles::data::MetadataImmediates *metadata_immediates, const heracles::data::FHEContext &context); + +/* + Extract symbol/value map of all input/output polynomials as needed to build (after swizzling) memory images + or DMA downloads + */ +void extract_polys(heracles::data::DataPolynomials *polys, const heracles::data::TestVector &testvector); + +/* + Extract metadata parameters (no polynomials, immediates and twiddles) - downsized context + */ +void extract_metadata_params( + heracles::data::MetadataParams *metadata_params, const heracles::data::FHEContext &context); + +void convert_polys_to_testvector(heracles::data::TestVector *testvector, const heracles::data::DataPolynomials &polys); + +/* + Prune data polynomials based on trace - unused data are removed + */ +void prune_polys( + heracles::data::TestVector *testvector, const heracles::data::FHEContext &context, + const heracles::fhe_trace::Trace &trace); + +} // namespace heracles::data diff --git a/p-isa_tools/data_formats/cpp/include/heracles/fhe_trace/io.h b/p-isa_tools/data_formats/cpp/include/heracles/fhe_trace/io.h new file mode 100644 index 00000000..4d1d7545 --- /dev/null +++ b/p-isa_tools/data_formats/cpp/include/heracles/fhe_trace/io.h @@ -0,0 +1,32 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include "heracles/proto/fhe_trace.pb.h" + +namespace heracles::fhe_trace +{ +/* + Serialize and store a HE op trace. + */ +bool store_trace(const std::string &filename, const heracles::fhe_trace::Trace &trace); + +/* + Load and deserialize a HEC context. + */ +heracles::fhe_trace::Trace load_trace(const std::string &filename); + +/* + Serialize and store a HE op trace in JSON format. + */ +bool store_json_trace(const std::string &filename, const heracles::fhe_trace::Trace &trace); + +/* + Load and deserialize a HE op trace from JSON format. + */ +heracles::fhe_trace::Trace load_json_trace(const std::string &filename); + +} // namespace heracles::fhe_trace diff --git a/p-isa_tools/data_formats/cpp/include/heracles/heracles_data_formats.h b/p-isa_tools/data_formats/cpp/include/heracles/heracles_data_formats.h new file mode 100644 index 00000000..467efdcf --- /dev/null +++ b/p-isa_tools/data_formats/cpp/include/heracles/heracles_data_formats.h @@ -0,0 +1,10 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "heracles/data/io.h" +#include "heracles/data/math.h" +#include "heracles/data/transform.h" +#include "heracles/fhe_trace/io.h" +#include "heracles/util/util.h" diff --git a/p-isa_tools/data_formats/cpp/include/heracles/util/util.h b/p-isa_tools/data_formats/cpp/include/heracles/util/util.h new file mode 100644 index 00000000..61514979 --- /dev/null +++ b/p-isa_tools/data_formats/cpp/include/heracles/util/util.h @@ -0,0 +1,92 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include "heracles/data/math.h" +#include "heracles/proto/data.pb.h" +#include "heracles/proto/fhe_trace.pb.h" + +namespace heracles::util +{ +namespace data +{ + // some utility functions needed during HEC transformations + + // - montgomery transform + // - constants + const uint64_t montgomery_R_bits = 32; + const uint64_t montgomery_R = 1ULL << montgomery_R_bits; + + // - single value + // - montgomery transformation + uint32_t convert_to_montgomery(const uint32_t num, const std::uint32_t modulus); + // - montgomery to normal conversion + uint32_t convert_to_normal(const uint32_t num, const std::uint32_t modulus); + uint32_t convert_to_normal_inv_r(const uint32_t num, const std::uint32_t inv_r, const std::uint32_t modulus); + + // - bit-reversal + // - shuffle power-of-two-sized poly vector according to bit-reversal of index + // Note: we assume dst is allocated object but does not contain anything! + void poly_bit_reverse(heracles::data::RNSPolynomial *dst, const heracles::data::RNSPolynomial &src); + void poly_bit_reverse(heracles::data::RNSPolynomial *dst, const std::vector &src); + void poly_bit_reverse_inplace(heracles::data::RNSPolynomial *src); + + void transform_and_flatten_key_switch( + heracles::data::PolySymbols *poly_symbols, const std::string &prefix, const heracles::data::KeySwitch &data); + void transform_and_flatten_ciphertext( + heracles::data::PolySymbols *poly_symbols, const std::string &prefix, const heracles::data::Ciphertext &data); + void transform_and_flatten_plaintext( + heracles::data::PolySymbols *poly_symbols, const std::string &prefix, const heracles::data::Plaintext &data); + void transform_and_flatten_dcrtpoly( + heracles::data::PolySymbols *poly_symbols, const std::string &prefix, const heracles::data::DCRTPoly &data); + void transform_and_flatten_poly( + heracles::data::PolySymbols *poly_symbols, const std::string &prefix, const heracles::data::Polynomial &poly); + + void convert_rnspoly_to_original(heracles::data::RNSPolynomial *dest, const heracles::data::RNSPolynomial &src); + + std::tuple split_symbol_name(const std::string &sym); + // convert protobuf map of Array field to index vector + std::vector toIndex(const std::string &key); + std::string toStrKey(const std::vector &indices); +} // namespace data + +namespace fhe_trace +{ + constexpr char DELIMITER = ','; + + /* + Print single instruction + */ + void print_instruction( + const heracles::fhe_trace::Instruction &inst, const std::string &header = "", bool printBKops = false); + + std::ostream &operator<<(std::ostream &out, const heracles::fhe_trace::Instruction &inst); + + /* + Print trace + */ + void print_trace(const heracles::fhe_trace::Trace &trace); + + /* + Get input(s) and output symbols of instruction pb + */ + std::pair, std::vector> get_symbols( + const heracles::fhe_trace::Instruction &inst); + + /* + Get all input(s) and output symbols of trace pb + If exclusive_outputs==true, return outputs that are never used as inputs + */ + std::pair, std::unordered_set> get_all_symbols( + const heracles::fhe_trace::Trace &trace, bool exclusive_outputs = false); + +} // namespace fhe_trace +} // namespace heracles::util diff --git a/p-isa_tools/data_formats/proto/.clang-format b/p-isa_tools/data_formats/proto/.clang-format new file mode 100644 index 00000000..9d49187f --- /dev/null +++ b/p-isa_tools/data_formats/proto/.clang-format @@ -0,0 +1,5 @@ +{BasedOnStyle: Google, +AlignConsecutiveDeclarations: true, +AlignConsecutiveAssignments: true, +ColumnLimit: 0, +IndentWidth: 4} diff --git a/p-isa_tools/data_formats/proto/CMakeLists.txt b/p-isa_tools/data_formats/proto/CMakeLists.txt new file mode 100644 index 00000000..bbc8af8a --- /dev/null +++ b/p-isa_tools/data_formats/proto/CMakeLists.txt @@ -0,0 +1,168 @@ +set(HERACLES_OUTPUT_DIR ${PROJECT_BINARY_DIR}/heracles) +set(HERACLES_PROTO_OUTPUT_DIR ${HERACLES_OUTPUT_DIR}/proto) +set(HERACLES_PYTHON_PROTO_OUTPUT_DIR ${PROJECT_BINARY_DIR}/python/heracles/proto) + +file(MAKE_DIRECTORY ${HERACLES_PROTO_OUTPUT_DIR}) +file(MAKE_DIRECTORY ${HERACLES_PYTHON_PROTO_OUTPUT_DIR}) + +# amalgamated single header for protobuf generated files +configure_file(heracles_proto.h.in ${HERACLES_OUTPUT_DIR}/heracles_proto.h) + +######## +# Build HERACLES_data_proto +######## +add_library(HERACLES_data_proto + OBJECT + heracles/common.proto + heracles/maps.proto + heracles/fhe_trace.proto + heracles/data.proto +) + +target_include_directories(HERACLES_data_proto + PUBLIC + $ + $ + $ + ${protobuf_INCLUDE_DIR} +) + +# include custom protobuf-generate function +include(${PROJECT_SOURCE_DIR}/cmake/protobuf-generate.cmake) + +protobuf_generate( + TARGET HERACLES_data_proto + LANGUAGE cpp + IMPORT_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/heracles ${protobuf_SOURCE_DIR}/src + PROTOC_OUT_DIR ${HERACLES_PROTO_OUTPUT_DIR} +) + +# generate python protobuf library +protobuf_generate( + TARGET HERACLES_data_proto + LANGUAGE python + IMPORT_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/heracles ${protobuf_SOURCE_DIR}/src + PROTOC_OUT_DIR ${HERACLES_PYTHON_PROTO_OUTPUT_DIR} +) + +# patch generated Python protobuf files to support "relative" paths for modules +# this allows using the generated protobuf python files to be correctly imported +add_custom_target(HERACLES_DATA_PROTO_PATCH + ALL + DEPENDS + ${HERACLES_PYTHON_PROTO_OUTPUT_DIR}/fhe_trace_pb2.py + ${HERACLES_PYTHON_PROTO_OUTPUT_DIR}/data_pb2.py +) + +# BSD sed requires space between -i '' +if(APPLE) + set(sed_space " ") +endif() + +# replaces "import common_pb2" to "from . import common_pb2" +add_custom_command( + TARGET HERACLES_DATA_PROTO_PATCH + POST_BUILD + COMMAND sed -i${sed_space}'' '/^import common_pb2 as/s/^/from . /' *_pb2.py + WORKING_DIRECTORY ${HERACLES_PYTHON_PROTO_OUTPUT_DIR} +) + +# install generated protobuf headers +install(DIRECTORY ${HERACLES_OUTPUT_DIR}/ + DESTINATION include/heracles + FILES_MATCHING + PATTERN "*.h" +) + +# install generated python protobuf files +install(DIRECTORY ${HERACLES_PYTHON_PROTO_OUTPUT_DIR}/ + DESTINATION python/heracles/proto + FILES_MATCHING + PATTERN "*.py" +) + +# main library build +add_library(heracles_data_formats + SHARED + ${HERACLES_PROTO_OUTPUT_DIR}/common.pb.cc + ${HERACLES_PROTO_OUTPUT_DIR}/data.pb.cc + ${HERACLES_PROTO_OUTPUT_DIR}/maps.pb.cc + ${HERACLES_PROTO_OUTPUT_DIR}/fhe_trace.pb.cc +) +add_library(HERACLES_DATA_FORMATS::heracles_data_formats ALIAS heracles_data_formats) + +set_target_properties(heracles_data_formats PROPERTIES + BUILD_WITH_INSTALL_RPATH FALSE + LINK_FLAGS "-Wl,-rpath,'$ORIGIN'" +) + +target_include_directories(heracles_data_formats + PUBLIC $ + PUBLIC $ + PUBLIC $ + PUBLIC $ +) + +# Required for linking protobuf library +target_link_directories(heracles_data_formats + PUBLIC $ + PUBLIC $ +) + +find_package(ZLIB REQUIRED) +target_link_libraries(heracles_data_formats PRIVATE ZLIB::ZLIB) +if(UNIX AND NOT APPLE) + target_link_libraries(heracles_data_formats PRIVATE rt) +endif() + +add_dependencies(heracles_data_formats ext_protobuf) +foreach(_protobuf_lib_name ${protobuf_LIB_NAMES}) + target_link_libraries(heracles_data_formats PUBLIC ${_protobuf_lib_name}) +endforeach() + +install(TARGETS heracles_data_formats + DESTINATION lib +) + +# install and find_package mechanism +set(HERACLES_DATA_FORMATS_TARGET_FILENAME ${PROJECT_BINARY_DIR}/cmake/HERACLES_DATA_FORMATSTargets.cmake) +set(HERACLES_DATA_FORMATS_CONFIG_IN_FILENAME ${HERACLES_DATA_FORMATS_CMAKE_PATH}/HERACLES_DATA_FORMATSConfig.cmake.in) +set(HERACLES_DATA_FORMATS_CONFIG_FILENAME ${PROJECT_BINARY_DIR}/cmake/HERACLES_DATA_FORMATSConfig.cmake) +set(HERACLES_DATA_FORMATS_CONFIG_VERSION_FILENAME ${PROJECT_BINARY_DIR}/cmake/HERACLES_DATA_FORMATSConfigVersion.cmake) +set(HERACLES_DATA_FORMATS_CONFIG_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR}/cmake/HERACLES_DATA_FORMATS-${HERACLES_DATA_FORMATS_VERSION}) + +install( + EXPORT HERACLES_DATA_FORMATSTargets + NAMESPACE HERACLES_DATA_FORMATS:: + DESTINATION ${HERACLES_DATA_FORMATS_CONFIG_INSTALL_DIR} +) + +write_basic_package_version_file( + ${HERACLES_DATA_FORMATS_CONFIG_VERSION_FILENAME} + VERSION ${HERACLES_DATA_FORMATS_VERSION} + COMPATIBILITY ExactVersion +) + +configure_package_config_file( + ${HERACLES_DATA_FORMATS_CONFIG_IN_FILENAME} ${HERACLES_DATA_FORMATS_CONFIG_FILENAME} + INSTALL_DESTINATION ${HERACLES_DATA_FORMATS_CONFIG_INSTALL_DIR} +) + +install( + TARGETS heracles_data_formats + EXPORT HERACLES_DATA_FORMATSTargets + ARCHIVE DESTINATION lib + LIBRARY DESTINATION lib + RUNTIME DESTINATION bin +) + +install( + FILES ${HERACLES_DATA_FORMATS_CONFIG_FILENAME} + ${HERACLES_DATA_FORMATS_CONFIG_VERSION_FILENAME} + DESTINATION ${HERACLES_DATA_FORMATS_CONFIG_INSTALL_DIR} +) +# Added to enable non-installation build +export(EXPORT HERACLES_DATA_FORMATSTargets + NAMESPACE HERACLES_DATA_FORMATS:: + FILE ${HERACLES_DATA_FORMATS_TARGET_FILENAME} +) diff --git a/p-isa_tools/data_formats/proto/heracles/common.proto b/p-isa_tools/data_formats/proto/heracles/common.proto new file mode 100644 index 00000000..9d55bc95 --- /dev/null +++ b/p-isa_tools/data_formats/proto/heracles/common.proto @@ -0,0 +1,26 @@ +syntax = "proto3"; + +import "google/protobuf/descriptor.proto"; + +extend google.protobuf.EnumValueOptions { + optional string string_name = 123456789; +} + +package heracles.common; + +enum Scheme { + SCHEME_UNSPECIFIED = 0; + SCHEME_BGV = 1 [ + (string_name) = "bgv" + ]; + SCHEME_BFV = 2 [ + (string_name) = "bfv" + ]; + SCHEME_CKKS = 3 [ + (string_name) = "ckks" + ]; + // Specific for OpenFHE_CKKS : To be merged with ckks after + SCHEME_OPENFHE_CKKS = 4 [ + (string_name) = "openfhe_ckks" + ]; +} diff --git a/p-isa_tools/data_formats/proto/heracles/data.proto b/p-isa_tools/data_formats/proto/heracles/data.proto new file mode 100644 index 00000000..3d414bb3 --- /dev/null +++ b/p-isa_tools/data_formats/proto/heracles/data.proto @@ -0,0 +1,196 @@ +syntax = "proto3"; + +import "common.proto"; + +import "google/protobuf/descriptor.proto"; + +extend google.protobuf.EnumValueOptions { + optional string datatype_name = 123456787; +} + +package heracles.data; + +// COMMON NAMES & NAMING CONVENTIONS +//================================= + +// - N : degree of polynomial +// - t : plaintext modulus +// - q_i: i-th (rns) prime ciphertext modulus, with i=0 the lowest level and the maximal i corresponding to the number of rns terms of a key +// - k: digit size of key-switch keys; 1 initially in SEAL until we add large-digit decomposition to SEAL +// - p : special prime, i.e., q_{rns-num-1}. Note this assumes single-prime-as-digit (i.e., k=1) as in SEAL, will have to extended for large-digit keyswitch as in OpenFHE +// - q_last : last prime at _current_ level of a ciphertext, not the really last one (which would be p) +// - psi_i: 2N-th root of unity of Z_q_i +// - omega_i = psi^2 = Nth root of unity of Z_q_i +// - ... + +// MAIN INTERFACE FOR PRODUCERS +//================================= + +// Note: These types are mostly HEC agnostic and all HEC specific +// expansion & transformation (e.g., bit-reversal, montgomery ..) happens +// in the extraction functions (see transform.h). The exception being that +// primes must be HEC ntt-friendly and RNS-32 ... + +// FHE context of a particular program, to be provided by a library or alike, covering mostly only minimally required information. +// Additional information & transformations required for HEC is handled in ExtendedFHEContext +// - filename: "hec_context.bin" +message FHEContext { + heracles.common.Scheme scheme = 1; + uint32 N = 2; // polynomial degree, must be power-of-two + uint32 key_rns_num = 3; // number of rns-terms of a key-switch key (and an upper-bound of number of rns-terms of any ciphertext) + uint32 digit_size = 4; // number of rns terms per digit for key switching, for SEAL initially always 1, OpenFHE = GetNumPerPartQ + repeated uint32 q_i = 5; // i-th (rns) prime ciphertext modulus, with first "composite_degree" the lowest level and last "few" used only in keys and during key-switch + repeated uint32 psi = 6; // 2n-th root of Z_{q_i}, implies elsewhere used n-th root omega = psi^2 + uint32 q_size = 7; // n(Q), for SEAL it is always ( key_rns_num - 1 ), OpenFHE (cc->GetElementParams()->GetParams().size()) + uint32 alpha = 8; // ceil(q_size/dnum) + oneof scheme_specific { + CKKSSpecific ckks_info = 9; + BGVSpecific bgv_info = 10; + } +} + +// TODO: to be implemented +enum DataType { + TYPE_UNSPECIFIED = 0; + + TYPE_CIPHERTEXT = 1 [ + (datatype_name) = "ciphertext" + ]; + TYPE_PLAINTEXT = 2 [ + (datatype_name) = "plaintext" + ]; + // unneeded.. we only care about ciphertext and plaintext + TYPE_DCRTPOLY = 3 [ + (datatype_name) = "dcrtpoly" + ]; + TYPE_KEYSWITCH = 4 [ + (datatype_name) = "keyswitch" + ]; +} + +// Test Vector data, i.e., input and golden outputvalues in original form +// - filename: "hec_testvec.bin" +message TestVector { + map sym_data_map = 1; // +} + +// TODO: merge this with above +message Data { + DCRTPoly dcrtpoly = 1; // this is only used for v2 +} + +// MAIN INTERFACE FOR CONSUMERs & +// ROOT TYPES FOR SERIALIZATIONS +//================================= + +// Map from symbols of inputs and output to a computation to their value +message DataPolynomials { + PolySymbols data = 1; +} + +// Map from symbols of metadata polynomials to their value +message MetadataPolynomials { + PolySymbols metadata = 1; +} + +// Map from symbols of metadata twiddles to their value +message MetadataTwiddles { + bool only_power_of_two = 1; + // depending on above, below are either N/2 powers omega^i + // or only the power of twos of omega (starting with power 0!) + // In either case it's not a "true" polynomial as for ciphertext, plaintext or keys + map twiddles_ntt = 2; + map twiddles_intt = 3; + // string index is the twiddle type, as e.g., returned by `map_twiddle_type` +} + +// Map from symbols of metadata immediates to their value +message MetadataImmediates { + map sym_immediate_map = 1; +} + +// Map from metadata parameters - avoids storing context +message MetadataParams { + map sym_param_map = 1; +} +//================================= +// AUXILIARY TYPES +//================================= + +message BGVSpecific { + repeated BGVPlaintextSpecific plaintext_specific = 1; // the index in this list (with initial element having index 0) will be the index used in 'plaintext_index' field of `heracles.fhe_trace.Instruction` field. + Ciphertext recrypt_key = 2; +} + +message BGVPlaintextSpecific { + Keys keys = 1; + uint64 plaintext_modulus = 2; +} + +message CKKSSpecific { + Keys keys = 1; + uint32 composite_degree = 2; // BASE_NUM_LEVELS_TO_DROP + // Scaling factors + repeated double scaling_factor_real = 3; // size: q_size (CryptoParametersRNS->GetScalingFactorReal()) + repeated double scaling_factor_real_big = 4; // size: q_size - 1(CryptoParametersRNS->GetScalingFactorRealBig()) + map metadata_extra = 5; +} + +message Keys { + KeySwitch relin_key = 1; + map rotation_keys = 2; +} + +message KeySwitch { + uint32 k = 1; // number of rns terms forming a digit, for SEAL initially always 1. + repeated Ciphertext digits = 2; // the digits which should have FHEContext.key_rns_num rns terms and of which there should be ceil(FHEContext.key_rns_num/k) many + // Polynomial ipsi = 3; +} + +message Ciphertext { + repeated Polynomial polys = 1; // usually repetition is 2, although can be 3 (although that is supported only as input for add* and relin) + bool in_ntt_form = 2; + uint32 level = 3; + uint32 depth = 4; + double scalingFactor = 5; + uint32 scalingFactorInt = 6; +} + +message Plaintext { + Polynomial poly = 1; // if not in_ntt_form this is only a single RNSPolynomial (and only possible if plaintext modulus is < 2^32) + bool in_ntt_form = 2; + uint32 level = 3; + uint32 depth = 4; + double scalingFactor = 5; + uint32 scalingFactorInt = 6; +} + +// DCRTPoly form +message DCRTPoly { + repeated Polynomial polys = 1; // size = order + bool in_ntt_form = 2; +} + +message Polynomial { + // typically ~10 for OpenFHE (size = curr_rns) + repeated RNSPolynomial rns_polys = 1; + bool in_OpenFHE_EVALUATION = 3; +} + +message RNSPolynomial { + repeated uint32 coeffs = 1; // repeated a power-of-two times with power-of-two + uint32 modulus = 2; // need modulus in case curr_rns < max_rns, as it will use part of Q and all of P +} + +message HECRNSPolynomial { + repeated HECBasePolynomial base_polys = 1; +} + +message HECBasePolynomial { + repeated uint32 coeffs = 1; // repeated exactly 8192 times +} + +message PolySymbols { + map sym_poly_map = 1; // + // name: keys -> { name_c : ctxt} -> { name_c_d : Poly } -> { name_c_d_r : RNS32Poly } -> { name_c_d_r_i : RNS32_8kPoly +} diff --git a/p-isa_tools/data_formats/proto/heracles/fhe_trace.proto b/p-isa_tools/data_formats/proto/heracles/fhe_trace.proto new file mode 100644 index 00000000..3daa106e --- /dev/null +++ b/p-isa_tools/data_formats/proto/heracles/fhe_trace.proto @@ -0,0 +1,73 @@ +syntax = "proto3"; + +import "common.proto"; +import "google/protobuf/descriptor.proto"; + +extend google.protobuf.EnumValueOptions { + optional string valuetype_name = 123456788; +} +package heracles.fhe_trace; + +// MAIN INTERFACE FOR BOTH CONSUMERs AND PRODUCERS +// ROOT TYPE FOR SERIALIZATIONS +//================================= + +// A sequence of instructions at an HE abstraction level. +// - filename: "hec_trace.bin" +message Trace { + // Sequence of HE instructions + repeated Instruction instructions = 1; + heracles.common.Scheme scheme = 2; + uint32 N = 3; // poly modulus degree + uint32 key_rns_num = 4; + uint32 q_size = 5; // n(Q) + uint32 dnum = 6; // digit size + uint32 alpha = 7; // ceil(n(Q)/dnum) Note: key_rns_num=n(Q) + n(P) +} + +message Instruction { + string op = 1; + uint32 plaintext_index = 2; // which plaintext algebra used, can be ignored for CKKS. Used as index into `plaintext_specific` field of `heracles.data.BGVSpecific` object inside `heracles.data.FHEContext`. + Operands args = 3; // inputs/outputs and additional params + string evalop_name = 4; // (OpenFHE specific) Evaluator level call tracking, helps identifying what eval op invoked atomic ops +} + +message Operands { + repeated OperandObject dests = 1; + repeated OperandObject srcs = 2; + map params = 3; +} + +message Parameter { + string value = 1; + ValueType type = 2; +} +enum ValueType { + UINT32 = 0 [ + (valuetype_name) = "UINT32" + ]; + UINT64 = 1 [ + (valuetype_name) = "UINT64" + ]; + INT32 = 2 [ + (valuetype_name) = "INT32" + ]; + INT64 = 3 [ + (valuetype_name) = "INT64" + ]; + FLOAT = 4 [ + (valuetype_name) = "FLOAT" + ]; + DOUBLE = 5 [ + (valuetype_name) = "DOUBLE" + ]; + STRING = 6 [ + (valuetype_name) = "STRING" + ]; +} + +message OperandObject { + string symbol_name = 1; + uint32 num_rns = 2; // size = curr_rns of dcrtpoly + uint32 order = 3; // typically 2 for ct/pt (can be 3), single DCRTPoly will always be 1 +} diff --git a/p-isa_tools/data_formats/proto/heracles/maps.proto b/p-isa_tools/data_formats/proto/heracles/maps.proto new file mode 100644 index 00000000..9fd4cb2e --- /dev/null +++ b/p-isa_tools/data_formats/proto/heracles/maps.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; + +package heracles.maps; + +// MAIN INTERFACE FOR BOTH CONSUMERs AND PRODUCERS +// ROOT TYPE FOR SERIALIZATIONS +//================================= + +// Map associating (virtual) HBM address to symbols in code. +// Required in user-space loader library of driver to form the join with meta-data and input values +// when creating the DMA download requests and to find address of outputs to retrieve via DMA uploads. +// see also heracles.data.PolySymbols for "peer" map +message HBM { + map sym_addr_map = 1; +} diff --git a/p-isa_tools/data_formats/proto/heracles_proto.h.in b/p-isa_tools/data_formats/proto/heracles_proto.h.in new file mode 100644 index 00000000..30f385c8 --- /dev/null +++ b/p-isa_tools/data_formats/proto/heracles_proto.h.in @@ -0,0 +1,13 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + + + +#ifndef _heracles_proto_h_ +#define _heracles_proto_h_ + +#include "heracles/proto/common.pb.h" +#include "heracles/proto/data.pb.h" +#include "heracles/proto/fhe_trace.pb.h" +#include "heracles/proto/maps.pb.h" +#endif // _heracles_proto_h_ diff --git a/p-isa_tools/data_formats/python/__init__.py b/p-isa_tools/data_formats/python/__init__.py new file mode 100644 index 00000000..4057dc01 --- /dev/null +++ b/p-isa_tools/data_formats/python/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/p-isa_tools/data_formats/python/heracles/.gitignore b/p-isa_tools/data_formats/python/heracles/.gitignore new file mode 100644 index 00000000..7e13bf5a --- /dev/null +++ b/p-isa_tools/data_formats/python/heracles/.gitignore @@ -0,0 +1,2 @@ +# protoc-generated files +/proto diff --git a/p-isa_tools/data_formats/python/heracles/data/__init__.py b/p-isa_tools/data_formats/python/heracles/data/__init__.py new file mode 100644 index 00000000..4057dc01 --- /dev/null +++ b/p-isa_tools/data_formats/python/heracles/data/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/p-isa_tools/data_formats/python/heracles/data/io.py b/p-isa_tools/data_formats/python/heracles/data/io.py new file mode 100644 index 00000000..e823681a --- /dev/null +++ b/p-isa_tools/data_formats/python/heracles/data/io.py @@ -0,0 +1,174 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +# TODO: create also C++ variants of below; given how simple and stable these functions should be just in replicated form, not shared code + +import json +import sys + +import heracles.proto.data_pb2 as hpd +from google.protobuf.json_format import MessageToDict + +# load & store functions +# =============================== + + +def parse_manifest(filename: str) -> dict: + manifest: dict = {} + with open(filename) as fp: + cur_field = None + found_first_field = False + for linenum, cur_line in enumerate(fp): + cur_line = cur_line.rstrip() + if cur_line.startswith("[") and cur_line.endswith("]"): + cur_field = cur_line[1:-1] + found_first_field = True + manifest[cur_field] = {} + continue + + if not found_first_field: + continue + cur_line_split = cur_line.split("=") + if len(cur_line_split) != 2: + print( + f"Warning: ignoring incorrect format in line {linenum}", + file=sys.stderr, + ) + continue + manifest[cur_field][cur_line_split[0]] = cur_line_split[1] + + if not found_first_field: + raise Exception(f"Incorrect manifest file format: {filename}") + + return manifest + + +def generate_manifest(filename: str, manifest: dict): + with open(filename, "w") as fp: + for field, values in manifest.items(): + fp.write(f"[{field}]\n") + for key, fn in values.items(): + fp.write(f"{key}={fn}\n") + + +# re-check +def store_hec_context_json(filename: str, context: hpd.FHEContext): + print( + "Warning: Dumping FHE Context data trace to json can take a long time", + file=sys.stderr, + ) + with open(filename, "w") as fp: + json.dump(MessageToDict(context), fp) + + +# re-check +def store_testvector_json(filename: str, testvector: hpd.TestVector): + print( + "Warning: Dumping TestVector data trace to json can take a long time", + file=sys.stderr, + ) + with open(filename, "w") as fp: + json.dump(MessageToDict(testvector), fp) + + +def load_hec_context_from_manifest(manifest: dict) -> hpd.FHEContext: + context_base_fn = manifest["context"]["main"] + context_pb = hpd.FHEContext() + + with open(context_base_fn, "rb") as fp: + context_pb.ParseFromString(fp.read()) + + if "rotation_keys" in manifest: + for ge, gk_fn in manifest["rotation_keys"].items(): + gk_pb = hpd.KeySwitch() + with open(gk_fn, "rb") as fp: + gk_pb.ParseFromString(fp.read()) + context_pb.ckks_info.keys.rotation_keys[int(ge)].CopyFrom(gk_pb) + + return context_pb + + +def store_hec_context(filename: str, context_pb: hpd.FHEContext) -> dict: + hec_context_manifest: dict = {"context": {}} + tmp_context = hpd.FHEContext() + tmp_context.CopyFrom(context_pb) + + if tmp_context.ByteSize() > 1 << 30: + hec_context_manifest["rotation_keys"] = {} + for gkct, (ge, gk_pb) in enumerate(tmp_context.ckks_info.keys.rotation_keys.items()): + parts_fn = f"{filename}_hec_context_part_{gkct + 1}" + hec_context_manifest["rotation_keys"][ge] = parts_fn + with open(parts_fn, "wb") as fp: + fp.write(gk_pb.SerializeToString()) + tmp_context.ckks_info.keys.ClearField("rotation_keys") + + main_fn = f"{filename}_hec_context_part_0" + hec_context_manifest["context"]["main"] = main_fn + with open(main_fn, "wb") as fp: + fp.write(tmp_context.SerializeToString()) + + return hec_context_manifest + + +def load_testvector_from_manifest(manifest: dict) -> hpd.TestVector: + # segmented + testvector_pb = hpd.TestVector() + if len(manifest["testvector"]) > 1: + for sym, parts_fn in manifest["testvector"].items(): + data = hpd.Data() + with open(parts_fn, "rb") as fp: + data.ParseFromString(fp.read()) + testvector_pb.sym_data_map[sym].CopyFrom(data) + # whole + else: + full_fn = manifest["testvector"]["full"] + with open(full_fn, "rb") as fp: + testvector_pb.ParseFromString(fp.read()) + + return testvector_pb + + +def store_testvector(filename: str, testvector_pb: hpd.TestVector) -> dict: + testvector_manifest: dict = {"testvector": {}} + if testvector_pb.ByteSize() > 1 << 30: + for tvct, (sym, data_pb) in enumerate(testvector_pb.sym_data_map.items()): + parts_fn = f"{filename}_testvector_part_{tvct}" + testvector_manifest["testvector"][sym] = parts_fn + with open(parts_fn, "wb") as fp: + fp.write(data_pb.SerializeToString()) + else: + full_fn = f"{filename}_testvector_part_0" + testvector_manifest["testvector"]["full"] = full_fn + with open(full_fn, "wb") as fp: + fp.write(testvector_pb.SerializeToString()) + + return testvector_manifest + + +def load_hec_context(filename: str) -> hpd.FHEContext: + manifest = parse_manifest(filename) + return load_hec_context_from_manifest(manifest) + + +def load_testvector(filename: str) -> hpd.FHEContext: + manifest = parse_manifest(filename) + return load_testvector_from_manifest(manifest) + + +def load_data_trace(filename: str) -> tuple[hpd.FHEContext, hpd.TestVector]: + manifest = parse_manifest(filename) + return ( + load_hec_context_from_manifest(manifest), + load_testvector_from_manifest(manifest), + ) + + +def store_data_trace(filename: str, context_pb: hpd.FHEContext, testvector_pb: hpd.TestVector): + generate_manifest( + filename, + { + **store_hec_context(filename, context_pb), + **store_testvector(filename, testvector_pb), + }, + ) diff --git a/p-isa_tools/data_formats/python/heracles/data/naming.py b/p-isa_tools/data_formats/python/heracles/data/naming.py new file mode 100644 index 00000000..d9bfc63f --- /dev/null +++ b/p-isa_tools/data_formats/python/heracles/data/naming.py @@ -0,0 +1,224 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import re + +import heracles.proto.common_pb2 as hpc +import heracles.proto.data_pb2 as hpd +import heracles.proto.fhe_trace_pb2 as hpf +import heracles.util.data as hud +import regex_spm + +# SYMBOL MAPPING +# ===================== + + +def map_mem_sym(context: hpd.FHEContext, instr: hpf.Instruction, sym_obj_name: str) -> str: + """ + Map potentially symbols used in kernels as register arguments pointing + to (polynomial) memory from non-universal from to a universal one. + Note this includes both "normal" source and destination arguments + from FHE operation as well as meta-data such as keys + or (i)psis for negative wrapped convolution in (i)ntt + `sym_obj_name` is string returned from function `sym_get_obj_name` + applied to flattened (sub-)object names found in lower-level traces + """ + # TODO: Actual implementation likely in C++ as we need same function also eventually there + # and to have only single implementation (via native code invocation from python) would + # be easier and more robust to maintain + + args = instr.args + # UNUSED? q_size = context.q_size if context.q_size > 0 else context.key_rns_num - context.digit_size + + args_src = args.srcs # getattr(args, args.WhichOneof("op_specific")) + args_dest = args.dests # getattr(args, args.WhichOneof("op_specific_dest")) + + match sym_obj_name: + # rules for "normal" arguments + case "output": + mapped_sym_obj_name = args_dest[0].symbol_name + case "c" | "input": + mapped_sym_obj_name = args_src[0].symbol_name + case "d" | "p": + mapped_sym_obj_name = args_src[1].symbol_name + # rules for meta data + case "psi" | "ipsi": + # See comments in extract_metadata_polys why we do this additional safety check + mapped_sym_obj_name = f"{sym_obj_name}_{'default'}" + case "ipsi_rot": + # See comments in extract_metadata_polys why we do this additional safety check + mapped_sym_obj_name = f"ipsi_{map_twiddle_type(context, instr)}" + case "rlk": + if context.scheme == hpc.SCHEME_BGV: + mapped_sym_obj_name = f"rlk_{instr.plaintext_index}" + else: + mapped_sym_obj_name = "rlk" + case "gk": + if context.scheme == hpc.SCHEME_BGV: + mapped_sym_obj_name = f"gk_{instr.plaintext_index}_{args.params['galois_elt'].value}" + else: + mapped_sym_obj_name = f"gk_{args.params['galois_elt'].value}" + case "q_last_half": + mapped_sym_obj_name = f"q_i_{args_dest.num_rns}_half_mod_q_j" # in flattening we start with 0 but dest is already last-1 .. + case _: + mapped_sym_obj_name = sym_obj_name + + # print(f"DEBUG (TRACE): map_mem_sym(.., {sym_obj_name}) -> {mapped_sym_obj_name}", file=sys.stderr) + return mapped_sym_obj_name + + +def map_immediate_sym(context: hpd.FHEContext, instr: hpf.Instruction, sym_imm_name: str) -> str: + """ + Map potentially non-universal immediate symbols used in kernels with either a universal one + or (e.g., for `add_corrected`) an actual numerical value + """ + # TODO: Actual implementation in C++ & native code invocation here for same reasons as mentioned for `map_mem_sym` + + if context.scheme == hpc.SCHEME_BGV: + bgv_info = context.bgv_info + key_rns_num = context.key_rns_num + # NOTE: we assume _all_ keys have same digit size!! + k = bgv_info.plaintext_specific[0].keys.relin_key.k + match regex_spm.fullmatch_in(sym_imm_name): + case r"^(c|d)_mont_adjusting_factor_(\d+)$" as m: + match m[1]: + case "c": + adj_factor = int(instr.args.params["adj_factor1"].value) + case "d": + adj_factor = int(instr.args.params["adj_factor2"].value) + case _: + raise ValueError(f"Invalid immediate symbol name: {sym_imm_name}") + adj_factor = hud.convert_to_montgomery(adj_factor, context.q_i[int(m[2])]) + mapped_sym_imm_name = f"{adj_factor}" + + case r"^it$" as m: + mapped_sym_imm_name = ( + f"neg_inv_t_{instr.plaintext_index}_mod_q_i_{instr.args.srcs[0].num_rns - 1}" # in flattening we start with 0 .. + ) + + case r"^t_inverse_mod_p_(\d+)$" as m: + mapped_sym_imm_name = ( + f"neg_inv_t_{instr.plaintext_index}_mod_q_i_{key_rns_num - k + int(m[1])}" # in flattening we start with 0 .. + ) + + case r"^iq_(\d+)$" as m: + mapped_sym_imm_name = f"inv_q_i_{instr.args.srcs[0].num_rns - 1}_mod_q_j_{m[1]}" # in flattening we start with 0 .. + + case r"^t_(\d+)$" as m: + mapped_sym_imm_name = f"t_{instr.plaintext_index}_mod_q_i_{m[1]}" + + case r"^pinv_q_(\d+)$" as m: + mapped_sym_imm_name = f"inv_p_mod_q_i_{m[1]}" + + case r"^corr-inv-target-corr-q-scalar_(\d+)$" as m: + mont_adj_factor = hud.convert_to_montgomery(int(instr.args.params["adj_factor1"].value), context.q_i[int(m[1])]) + mapped_sym_imm_name = f"{mont_adj_factor}" + + case r"^const-reduced_(\d+)$" as m: + adj_factor = int(instr.args.params["adj_factor1"].value) + if bool(instr.args.params["do_invert"].value): + adj_factor = pow(adj_factor, -1, context.q_i[int(m[1])]) + mont_adj_factor = hud.convert_to_montgomery(adj_factor, context.q_i[int(m[1])]) + mapped_sym_imm_name = f"{mont_adj_factor}" + + case r"^BaseChangeMatrix_(\d+_\d+)$" as m: + mapped_sym_imm_name = f"base_change_matrix_{instr.args.srcs[0].num_rns - 1}_{m[1]}" # in flattening we start with 0 .. + + case r"^InvPuncturedProd_(\d+)$" as m: + mapped_sym_imm_name = f"inv_punctured_prod_{instr.args.srcs[0].num_rns - 1}_{m[1]}" # in flattening we start with 0 .. + + case _: + mapped_sym_imm_name = sym_imm_name + + # print(f"DEBUG (TRACE): map_immediate_sym(.., {sym_imm_name}) -> {mapped_sym_imm_name}", file=sys.stderr) + return mapped_sym_imm_name + + +def map_twiddle_type(context: hpd.FHEContext, instr: hpf.Instruction) -> str: + """ + Map to the twiddle type used by this instruction. + Will always a type even if instruction doesn't necessarily need twiddles. + """ + + if instr.op.lower() not in ( + "rotate", + "boot_fastrotation_ext", + "boot_fastrotation_ext_noaddfirst", + "boot_galois_plain", + "boot_conjugate", + "boot_addrotate_c0", + ): + return "default" + + return instr.args.params["galois_elt"].value + + +# OBJECT (un)FLATTENING +# ========================== + + +def get_sym_obj_name(flat_sym_name: str) -> str: + """ + Extract the symbolic name of an object from a provided flattened (polynomial-based sub-)object names as found in lower-level traces + """ + return split_sym_name(flat_sym_name)[0] + + +flat_obj_syn_name_pattern = re.compile(r"^([a-zA-Z_]+)_([\d_]*)$") + + +def split_sym_name(flat_sym_name: str) -> tuple[str, str | None]: + """ + Split the symbolic name of an object from a provided flattened(polynomial-based sub-)object names + as found in lower-level traces into the obj name and the extension. + If the name is _not_ a name of a flattened object, + the name is returned and the second element of the tuple, normally the extension, is 'None' + """ + match = flat_obj_syn_name_pattern.match(flat_sym_name) + if match: + return (match.group(1), match.group(2)) + return (flat_sym_name, None) + + +def combine_sym_name(sym_obj_name: str, sym_obj_extension: str) -> str: + """ + Split the symbolic name of an object from a provided flattened + (polynomial-based sub-)object names as found in lower-level traces into the obj name and the extension + """ + return f"{sym_obj_name}_{sym_obj_extension}" + + +# NOTE: To make flattened names more readable/robust, the naming might change to +# +# a. keyname: keys -> { keyname_d${digitnum}: ciphertext } // _d for digit +# b. ctxtname: ciphertext -> {ctxtname_p${polynum}: polynomial } // _p for number of (full) polynomial (=degree) of ciphertext +# c. polyname: polynomial -> {polyname_r${rnsnum}: rns32polynomial } # arbitrary size rns32 polys with _r for rns +# d. rnsname: rns32polynomial -> {rnsname_c${chunkname}: rns32/8kpolynomial chunks with _c for chunk +# +# Above has to be applied recursively to required depth +# -- depending on consumer transformation d. is not always needed ... -- +# with results the union of intermediary sets, i.e., only resulting in a single set, not sets of sets. +# E.g., a ciphertext ctx mapped to level d with name c would result in set { c_p$P_r$R_c$C : ctx.poly[$P].rns-poly[$R].chunk[$C] +# with $P, $R & $C ranging over the size of the corresponding dimension}. +# +# Whether we change that or not depends whether on how much "parallel" change happen in psim & kernels. +# Once that would be in, it probably also makes sense to add a more complete decomposition function along the lines of +# +# def decompose_flat_sym_name(flat_sym_name: str) -> (str, type, digit, poly, rns, chunk): + + +# TODO: Consider adding some object flattening and de-flattening functions. +# +# E.g., something like +# def poly2rnspoly(symbol: str, poly: Polynomial) -> dict[str, RNSPolynomial]: +# and/or +# def poly2rnspoly(dict[str, Polynomial]) -> dict[str, RNSPolynomial]: +# for batch mode and related functions also for KeySwitch, Ciphertext and Plaintext instead of Polynomial. +# +# For treating results, i.e., outputs from HEC, we might eventually also need inverse functions +# +# As part of swizzling and un-swizzling, we also have to flatten and unflatten RNSPolynomials to/from (sets of) HECBasePolynomials, +# so some related functions could also be useful. +# +# Will add them on-demand as need arises ... diff --git a/p-isa_tools/data_formats/python/heracles/data/transform.py b/p-isa_tools/data_formats/python/heracles/data/transform.py new file mode 100644 index 00000000..b29466e8 --- /dev/null +++ b/p-isa_tools/data_formats/python/heracles/data/transform.py @@ -0,0 +1,427 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +# Note: see also the C++ definition of equivalent functions +# - given that some of these function might evolve, best done via shared code & native code integration in python +# TODO: Duplicate in C++ and revert below into native code invocations .. + +import functools as ft +import math +import operator as op + +import heracles.proto.common_pb2 as hpc +import heracles.proto.data_pb2 as hpd +import heracles.util.data as hud + + +def galois_elements_from_context(context: hpd.FHEContext) -> set: + match context.scheme: + case hpc.SCHEME_BGV: + return {ge for pt in context.bgv_info.plaintext_specific for ge, _ in pt.keys.rotation_keys.items()} + case hpc.SCHEME_CKKS: + return set(context.ckks_info.keys.rotation_keys.keys()) + case _: + raise ValueError("Only BGV and CKKS schemes are supported.") + + +def extract_metadata_polys(context: hpd.FHEContext) -> hpd.MetadataPolynomials: + """ + Extract symbol/value map of all metadata polynomials as needed to build (after swizzling) memory images or DMA downloads + """ + N = context.N + q_i = context.q_i + nQ = len(q_i) + psi = context.psi + psi_inv = [pow(psi, -1, q) for psi, q in zip(context.psi, q_i, strict=False)] + # UNUSED? k = context.ckks_info.keys.relin_key.k + # if context.scheme == hpc.SCHEME_CKKS else context.bgv_info.plaintext_specific[0].keys.relin_key.k + + meta_polys = hpd.MetadataPolynomials() + + # - galois_element-specific version for rotation + galois_elts = galois_elements_from_context(context) + for i in range(nQ): + # TODO (eventually): refactor below to common naming functions + + # - powers of psi for negative wrapped convolusion + # - default version for ntt, mod-switch & relin + meta_polys.metadata.sym_poly_map[f"psi_default_{i}"].coeffs.extend( + hud.convert_to_montgomery(pow(psi[i], j, q_i[i]), q_i[i]) for j in range(N) + ) + hud.poly_bit_reverse_inplace(meta_polys.metadata.sym_poly_map[f"psi_default_{i}"]) + + meta_polys.metadata.sym_poly_map[f"ipsi_default_{i}"].coeffs.extend( + hud.convert_to_montgomery(pow(psi_inv[i], j, q_i[i]), q_i[i]) for j in range(N) + ) + hud.poly_bit_reverse_inplace(meta_polys.metadata.sym_poly_map[f"ipsi_default_{i}"]) + + # calculate in house + if context.scheme == hpc.SCHEME_CKKS and i < context.q_size: + # TODO: Revisit later if addi/subi is added, then we can put these to meta_immediates, instead of meta_polys + # ql_half + # for i in range(context.q_size): + qlHalf_i = q_i[i] >> 1 + meta_polys.metadata.sym_poly_map[f"qlHalf_{i}"].coeffs.extend([qlHalf_i] * N) + + jMax = context.q_size if i <= 1 else i + # ql_half_modq + # Rescale: (qi-1) / 2 mod qj for all i < j, (i>=2) + # mod_raise: (qi-1) / 2 mod qj for i=0,1 & all j + for j in range(jMax): + meta_polys.metadata.sym_poly_map[f"qlHalfModq_{i}_{j}"].coeffs.extend([qlHalf_i % q_i[j]] * N) + + for ge in galois_elts: + exp_scale = pow(ge, -1, 2 * N) + # if i >= nQ - k: + # # the ones for the special primes have to be the normal, not the rotation ones + # # NOTE: boot_apply_galois also does use these ipsi/psi but as computed in + # # SEAL-Bootstrapping implies all-rotation variants, so would fail psim/simics + # # comparison when run with polynomials of key_nrns terms. Luckily, that is not + # # a needed scenario, so re-using these is fine (but to be safe, we do check also + # # during the mappings) + # exp_scale = 1 + # else: + # exp_scale = inv_ge + meta_polys.metadata.sym_poly_map[f"ipsi_{str(ge)}_{i}"].coeffs.extend( + hud.convert_to_montgomery(pow(psi_inv[i], exp_scale * j, q_i[i]), q_i[i]) for j in range(N) + ) + hud.poly_bit_reverse_inplace(meta_polys.metadata.sym_poly_map[f"ipsi_{str(ge)}_{i}"]) + # NOTE: we need only rotation-specific intt twiddles, not ntt ones but to keep logic in psim & step1 + # simple and uniform, we still create a separate version + # meta_polys.metadata.sym_poly_map[f"psi_{str(ge)}_{i}"].coeffs.extend( + # [hud.convert_to_montgomery(pow(psi[i], j, q_i[i]), q_i[i]) for j in range(N)]) + # hud.poly_bit_reverse_inplace(meta_polys.metadata.sym_poly_map[f"psi_{str(ge)}_{i}"]) + + # - key-switch keys + if context.scheme == hpc.SCHEME_BGV: + for pt in context.bgv_info.plaintext_specific: + # - relin + transform_and_flatten_key_switch( + context, + f"rlk_{pt}", + pt.keys.relin_key, + meta_polys.metadata.sym_poly_map, + ) + # - rotation + for ge, rk in pt.keys.rotation_keys.items(): + transform_and_flatten_key_switch( + context, + f"gk_{pt}_{ge}", + rk, + meta_polys.metadata.sym_poly_map, + ) + elif context.scheme == hpc.SCHEME_CKKS: + keys = context.ckks_info.keys + # - relin + transform_and_flatten_key_switch(context, "rlk", keys.relin_key, meta_polys.metadata.sym_poly_map) + # - rotation + for ge, rk in keys.rotation_keys.items(): + transform_and_flatten_key_switch( + context, + f"gk_{ge}", + rk, + meta_polys.metadata.sym_poly_map, + ) + + # for bootstrapping + if context.scheme == hpc.SCHEME_BGV and context.bgv_info.recrypt_key: + transform_and_flatten_ciphertext( + context, + "bk", + context.bgv_info.recrypt_key, + meta_polys.metadata.sym_poly_map, + ) + elif context.scheme == hpc.SCHEME_CKKS: + # zero polys added + meta_polys.metadata.sym_poly_map["zero"].coeffs.extend([0] * N) + + return meta_polys + + +def extract_metadata_twiddles(context: hpd.FHEContext) -> hpd.MetadataTwiddles: + """ + Extract symbol/value map of all twiddles as needed to build (after swizzling & replicating) memory images or DMA downloads + """ + N = context.N + q_i = context.q_i + nQ = len(q_i) + # UNUSED? psi = context.psi + omega = [pow(psi, 2, q) for psi, q in zip(context.psi, q_i, strict=False)] + omega_inv = [pow(o, -1, q) for o, q in zip(omega, q_i, strict=False)] + # UNUSED? k = context.ckks_info.keys.relin_key.k + # if context.scheme == hpc.SCHEME_CKKS else context.bgv_info.plaintext_specific[0].keys.relin_key.k + + twiddles = hpd.MetadataTwiddles() + twiddles.only_power_of_two = False + # "normal" twiddles + for i in range(nQ): + twiddles.twiddles_ntt["default"].rns_polys.add().coeffs.extend( + hud.convert_to_montgomery(pow(omega[i], j, q_i[i]), q_i[i]) for j in range(N // 2) + ) + twiddles.twiddles_intt["default"].rns_polys.add().coeffs.extend( + hud.convert_to_montgomery(pow(omega_inv[i], j, q_i[i]), q_i[i]) for j in range(N // 2) + ) + # rotation related twiddles + galois_elts = galois_elements_from_context(context) + for ge in galois_elts: + exp_scale = pow(ge, -1, 2 * N) + for i in range(nQ): + # inv_ge = pow(ge, -1, 2 * N) + # if i >= nQ - k: + # # the ones for the special primes have to be the normal, not the rotation ones + # # (see also additional comments in extract_metadata_polys regarding these psi/ipsis) + # exp_scale = 1 + # else: + # exp_scale = inv_ge + twiddles.twiddles_intt[str(ge)].rns_polys.add().coeffs.extend( + hud.convert_to_montgomery(pow(omega_inv[i], exp_scale * j, q_i[i]), q_i[i]) for j in range(N // 2) + ) + # NOTE we need only rotation-specific intt twiddles, not ntt ones but to keep logic in psim & step1 + # simple and uniform, we still create a separate version + # twiddles.twiddles_ntt[str(ge)].rns_polys.add().coeffs.extend( + # [hud.convert_to_montgomery(pow(omega[i], j, q_i[i]), q_i[i]) for j in range(N//2)]) + + return twiddles + + +def extract_metadata_immediates(context: hpd.FHEContext) -> hpd.MetadataImmediates: # noqa: C901 + """ + Extract symbol/value map of all immediates as needed for final code instantiation + """ + N = context.N + q_i = context.q_i + nQ = context.key_rns_num # is same with len(q_i) + + immediates = hpd.MetadataImmediates() + # NOTE: the "real" 1, not montgomery 1 !! + immediates.sym_immediate_map["one"] = 1 + + if context.scheme == hpc.SCHEME_CKKS: + # dnum = # of digits + dnum = context.digit_size + # alpha = ceil(sizeQ / dnum); # of RNS primes in each digit + alpha = context.alpha + # sizeQ = nQ - sizeP + sizeQ = context.q_size + sizeP = nQ - sizeQ + + for i, q in enumerate(q_i): + immediates.sym_immediate_map[f"R2_{i}"] = hud.montgomery_R**2 % q + immediates.sym_immediate_map[f"iN_{i}"] = hud.convert_to_montgomery(pow(N, -1, q), q) + + # Global Metadata + # iN, the inverse of N mod q_i but is identical in montgomery form across moduli but as we use + # iN_i in some cases (e.g., in some psim scripts) and iN in others we generate both + immediates.sym_immediate_map.update( + { + "iN": (1 << 32) // N, + "q0InvModq1": hud.convert_to_montgomery(pow(q_i[0], -1, q_i[1]), q_i[1]), + "q1InvModq0": hud.convert_to_montgomery(pow(q_i[1], -1, q_i[0]), q_i[0]), + } + ) + + # Metadata for key-switching (Relin, Rotate) + # PartQHatInvModq_{i}_{j} = (Q/Qi)^-1 mod qj; equals to zero for qj \notin Qi + for i in range(dnum): + for j in range(sizeQ): + immediates.sym_immediate_map[f"partQHatInvModq_{i}_{j}"] = hud.convert_to_montgomery( + context.ckks_info.metadata_extra[f"partQHatInvModq_{i}_{j}"], + q_i[j], + ) + # PartQlHatInvModq_{i}_{j}_{l} = (Q^(i*alpha + j)_i/ql)^-1 mod ql for ql \in Q^(i*alpha + j)_i + for i in range(dnum): + digitSize = alpha if i < dnum - 1 else sizeQ - alpha * (dnum - 1) + for j in range(digitSize): + for l in range(j + 1): # noqa E741 + immediates.sym_immediate_map[f"partQlHatInvModq_{i}_{j}_{l}"] = hud.convert_to_montgomery( + context.ckks_info.metadata_extra[f"partQlHatInvModq_{i}_{j}_{l}"], + q_i[alpha * i + l], + ) + # PartQlHatModp_{i}_{j}_{l}_{s} = (Q^(i)_j/ql)^-1 mod qs or ps, for qs \notin Q^(i)_j + for i in range(sizeQ): + beta = math.ceil(float(i + 1) / alpha) + for j in range(beta): + digitSize = alpha if j < beta - 1 else (i + 1) - alpha * (beta - 1) + sizeCompl = (i + 1) - digitSize + sizeP + for l in range(digitSize): # noqa E741 + for s in range(sizeCompl): + if s < alpha * j: + idx = s + elif s < (i + 1) - digitSize: + idx = s + digitSize + else: + idx = s - (i + 1) + digitSize + sizeQ + immediates.sym_immediate_map[f"partQlHatModp_{i}_{j}_{l}_{s}"] = hud.convert_to_montgomery( + context.ckks_info.metadata_extra[f"partQlHatModp_{i}_{j}_{l}_{s}"], + q_i[idx], + ) + + # pInvModq_{i} = P^{-1} mod qi + for i in range(sizeQ): + immediates.sym_immediate_map[f"pInvModq_{i}"] = hud.convert_to_montgomery( + context.ckks_info.metadata_extra[f"pInvModq_{i}"], q_i[i] + ) + immediates.sym_immediate_map[f"pModq_{i}"] = hud.convert_to_montgomery(context.ckks_info.metadata_extra[f"pModq_{i}"], q_i[i]) + # pInvModp_{i} = P^{-1} mod pi + for i in range(sizeP): + idx = i + sizeQ + immediates.sym_immediate_map[f"pHatInvModp_{i}"] = hud.convert_to_montgomery( + context.ckks_info.metadata_extra[f"pHatInvModp_{i}"], q_i[idx] + ) + # pHatModq_{i}_{j} = P/pi mod qj + for i in range(sizeP): + for j in range(sizeQ): + immediates.sym_immediate_map[f"pHatModq_{i}_{j}"] = hud.convert_to_montgomery( + context.ckks_info.metadata_extra[f"pHatModq_{i}_{j}"], q_i[j] + ) + + # Metadata for Rescale + # qlInvModq_{i}_{j} = q_{sizeQ-(i+1)}^{-1} mod qj + for i in range(sizeQ - 1): + for j in range(sizeQ - (i + 1)): + immediates.sym_immediate_map[f"qlInvModq_{i}_{j}"] = hud.convert_to_montgomery( + context.ckks_info.metadata_extra[f"qlInvModq_{i}_{j}"], q_i[j] + ) + # QlQlInvModqlDivqlModq_{i}_{j} = ((Q/q_{sizeQ-(i+1)})^{-1} mod q_{sizeQ-(i+1)} * (Q/q_{sizeQ-(i+1)})) mod qj + immediates.sym_immediate_map[f"QlQlInvModqlDivqlModq_{i}_{j}"] = hud.convert_to_montgomery( + context.ckks_info.metadata_extra[f"QlQlInvModqlDivqlModq_{i}_{j}"], + q_i[j], + ) + + # Metadata for Bootstrap + for i in (0, 1): + for j in range(sizeQ): + immediates.sym_immediate_map[f"qlModq_{i}_{j}"] = hud.convert_to_montgomery(context.q_i[i], q_i[j]) + + # Metadata for boot_mul_uint + for i in range(32): + val = 1 << i + for j in range(sizeQ): + immediates.sym_immediate_map[f"bmu_{val}_{j}"] = hud.convert_to_montgomery(val, q_i[j]) + if i == 0: + boot_correction = context.ckks_info.metadata_extra["boot_correction"] + immediates.sym_immediate_map[f"bmu_{boot_correction}_{j}"] = hud.convert_to_montgomery(boot_correction, q_i[j]) + + else: # SCHEME_BGV + for i, q in enumerate(q_i): + immediates.sym_immediate_map[f"R2_{i}"] = hud.montgomery_R**2 % q + immediates.sym_immediate_map[f"iN_{i}"] = hud.convert_to_montgomery(pow(N, -1, q), q) + for j in range(i): + immediates.sym_immediate_map[f"inv_q_i_{i}_mod_q_j_{j}"] = hud.convert_to_montgomery(pow(q, -1, q_i[j]), q_i[j]) + if context.scheme == hpc.SCHEME_BGV: + bgv_info = context.bgv_info + for pt_idx, pt in enumerate(bgv_info.plaintext_specific): + immediates.sym_immediate_map[f"neg_inv_t_{pt}_mod_q_i_{i}"] = hud.convert_to_montgomery( + -pow(pt.plaintext_modulus, -1, q), q + ) + immediates.sym_immediate_map[f"t_{pt_idx}_mod_q_i_{i}"] = hud.convert_to_montgomery(pt.plaintext_modulus, q) + + immediates.sym_immediate_map["iN"] = (1 << 32) // N + # iN, the inverse of N mod q_i but is identical in montgomery form across moduli but as we use + # iN_i in some cases (e.g., in some psim scripts) and iN in others we generate both + + k = ( + context.ckks_info.keys.relin_key.k + if context.scheme == hpc.SCHEME_CKKS + else context.bgv_info.plaintext_specific[0].keys.relin_key.k + ) + # NOTE we assume _all_ keys have same digit size!! + p = ft.reduce(op.mul, [q_i[context.key_rns_num - i - 1] for i in range(k)]) + for i in range(nQ - k): + immediates.sym_immediate_map[f"inv_p_mod_q_i_{i}"] = hud.convert_to_montgomery(pow(p, -1, q_i[i]), q_i[i]) + + # for base-extension in bootstrapping, boot_dot_prod kernel needs + # - base_change_matrix[i][l+j] := q/q_i mod q_{l+j} for 0 <= i <= l and 1 <= j <= L-l + # - inv_punctured_prod[i] = (q/q_i)^{-1} mod q_i for 0 <= i <= l + # with l number of rns of input and L the number of rns terms of a key + # and q = product_i=0..l(q_i) / Q = product_i=0..L(q_i) + # Then, from RNS_q(a), we can compute RNS_Q(a+qI) via (sum_{0 <= i <= l} base_change_matrix[i][j] * + # (inv_punctured_prod[i] * a mod q_i)) mod q_{l+j} for 1 <= j <= L-l + # + # As we have to "universalize" that for all l's, we add an additional indirection and map to above + # in map_immediate_sym + + # below is 1-1 translated code from export_metadata_bootstrap_dot_product in serialize.cpp, + # with additional outer loop make universal + for l in range(nQ - 1): # noqa E741 + for j in range(nQ): + for i in range(l + 1): + q_over_qi_mod_qj = 1 + # UNUSED? inv_q_over_qi_mod_qi = 1 + for k in range(nQ): + # qhat_mod_qi = q/qi (mod qi) + if k != i: + q_over_qi_mod_qj = (q_over_qi_mod_qj * q_i[k]) % q_i[j] + immediates.sym_immediate_map[f"base_change_matrix_{l}_{i}_{j}"] = hud.convert_to_montgomery(q_over_qi_mod_qj, q_i[j]) + if i == j: + immediates.sym_immediate_map[f"inv_punctured_prod_{l}_{i}"] = hud.convert_to_montgomery( + pow(q_over_qi_mod_qj, -1, q_i[i]), q_i[i] + ) + + return immediates + + +def extract_polys(test_vector: hpd.TestVector) -> hpd.DataPolynomials: + """ + Helper function to extract polys + """ + polys = hpd.DataPolynomials() + for sym, val in test_vector.sym_data_map.items(): + transform_and_flatten_dcrtpoly(f"{sym}", val.dcrtpoly, polys.data.sym_poly_map) + + return polys + + +def transform_and_flatten_key_switch( + context: hpd.FHEContext, + prefix: str, + key_switch: hpd.KeySwitch, + sym_poly_map: dict[str, hpd.RNSPolynomial], +): + # NOTE psim and kernels expect flattening not in natural hierarchical + # version, so we cannot call transform transform_and_flatten_ciphertext but + # have to do two-level unrol + for d, digit in enumerate(key_switch.digits): + for p, poly in enumerate(digit.polys): + transform_and_flatten_poly(f"{prefix}_{p}_{d}", poly, sym_poly_map) + + +def transform_and_flatten_ciphertext( + context: hpd.FHEContext, + prefix: str, + ciphertext: hpd.Ciphertext, + sym_poly_map: dict[str, hpd.RNSPolynomial], +): + for p, poly in enumerate(ciphertext.polys): + transform_and_flatten_poly(f"{prefix}_{p}", poly, sym_poly_map) + + +def transform_and_flatten_dcrtpoly( + prefix: str, + dcrtpoly: hpd.DCRTPoly, + sym_poly_map: dict[str, hpd.RNSPolynomial], +): + for p, poly in enumerate(dcrtpoly.polys): + transform_and_flatten_poly(f"{prefix}_{p}", poly, sym_poly_map) + + +def transform_and_flatten_plaintext( + context: hpd.FHEContext, + prefix: str, + plaintext: hpd.Plaintext, + sym_poly_map: dict[str, hpd.RNSPolynomial], +): + transform_and_flatten_poly(f"{prefix}", plaintext.poly, sym_poly_map) + + +def transform_and_flatten_poly( + prefix: str, + poly: hpd.Polynomial, + sym_poly_map: dict[str, hpd.RNSPolynomial], +): + # print(f"DEBUG (TRACE): transform_and_flatten_poly(...{prefix}...)", file=sys.stderr) + for r, rns in enumerate(poly.rns_polys): + rns_poly = sym_poly_map[f"{prefix}_{r}"] + rns_poly.coeffs.extend(hud.convert_to_montgomery(coeff, rns.modulus) for coeff in rns.coeffs) + hud.poly_bit_reverse_inplace(rns_poly) diff --git a/p-isa_tools/data_formats/python/heracles/fhe_trace/__init__.py b/p-isa_tools/data_formats/python/heracles/fhe_trace/__init__.py new file mode 100644 index 00000000..4057dc01 --- /dev/null +++ b/p-isa_tools/data_formats/python/heracles/fhe_trace/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/p-isa_tools/data_formats/python/heracles/fhe_trace/io.py b/p-isa_tools/data_formats/python/heracles/fhe_trace/io.py new file mode 100644 index 00000000..aa75bece --- /dev/null +++ b/p-isa_tools/data_formats/python/heracles/fhe_trace/io.py @@ -0,0 +1,30 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +# TODO: create also C++ variants of below; given how simple and stable these functions should be just in replicated form, not shared code + +from heracles.proto.fhe_trace_pb2 import Trace + +# load & store functions +# =============================== + + +def store_trace(filename: str, trace: Trace): + """ + Serialize and store a HEC trace. Filename is constructed by concatenating `filename_prefix` with standard suffix. + Prefix can contain directory paths, although they must all be existing directories + """ + with open(filename, "wb") as f: + f.write(trace.SerializeToString()) + + +def load_trace(filename: str) -> Trace: + """ + Load and deserialize a HEC trace. Filename is constructed by concatenating `filename_prefix` with standard suffix. + Prefix can contain directory paths, although they must all be existing directories + """ + trace = Trace() + with open(filename, "rb") as f: + trace.ParseFromString(f.read()) + return trace diff --git a/p-isa_tools/data_formats/python/heracles/util/__init__.py b/p-isa_tools/data_formats/python/heracles/util/__init__.py new file mode 100644 index 00000000..4057dc01 --- /dev/null +++ b/p-isa_tools/data_formats/python/heracles/util/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/p-isa_tools/data_formats/python/heracles/util/data.py b/p-isa_tools/data_formats/python/heracles/util/data.py new file mode 100644 index 00000000..42830102 --- /dev/null +++ b/p-isa_tools/data_formats/python/heracles/util/data.py @@ -0,0 +1,28 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import heracles.proto.data_pb2 as hpd + +# - montgomery transform +montgomery_r_bits = 32 +montgomery_r = 1 << montgomery_r_bits + + +def convert_to_montgomery(num: int, modulus: int) -> int: + return (num << montgomery_r_bits) % modulus + + +# - bit-reversal +def poly_bit_reverse_inplace(a: hpd.RNSPolynomial): + a_in = hpd.RNSPolynomial() + a_in.CopyFrom(a) + n = len(a.coeffs) + j = 0 + for i in range(1, n): + b = n >> 1 + while j >= b: + j -= b + b >>= 1 + j += b + if j > i: + a.coeffs[i], a.coeffs[j] = a_in.coeffs[j], a_in.coeffs[i] diff --git a/p-isa_tools/data_formats/python/heracles/util/fhe_trace.py b/p-isa_tools/data_formats/python/heracles/util/fhe_trace.py new file mode 100644 index 00000000..4288f6ca --- /dev/null +++ b/p-isa_tools/data_formats/python/heracles/util/fhe_trace.py @@ -0,0 +1,29 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import heracles.proto.fhe_trace_pb2 as hpf + + +def get_all_symbols(trace: hpf.Trace, get_intermediates: bool = False): + syms_input = set() + syms_output = set() + syms_intermediate = set() + for instruction in trace.instructions: + if instruction.op.startswith("bk_"): + continue + + for src in instruction.args.srcs: + syms_input.add(src.symbol_name) + for dest in instruction.args.dests: + syms_output.add(dest.symbol_name) + + if get_intermediates: + # get pure inputs + syms_input_exclusive = syms_input - syms_output + # get intermediates + syms_intermediate = syms_input - syms_input_exclusive + # get pure outputs + syms_output = syms_output - syms_intermediate + syms_input = syms_input_exclusive + + return [syms_input, syms_output, syms_intermediate] diff --git a/p-isa_tools/data_formats/test/.clang-format b/p-isa_tools/data_formats/test/.clang-format new file mode 100644 index 00000000..ccd0912f --- /dev/null +++ b/p-isa_tools/data_formats/test/.clang-format @@ -0,0 +1,126 @@ +--- +Language: Cpp +# BasedOnStyle: Microsoft +AccessModifierOffset: -4 +AlignAfterOpenBracket: AlwaysBreak +AlignConsecutiveMacros: false +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlines: Left +AlignOperands: true +AlignTrailingComments: false +AllowAllArgumentsOnNextLine: true +AllowAllConstructorInitializersOnNextLine: true +AllowAllParametersOfDeclarationOnNextLine: true +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: None +AllowShortLambdasOnASingleLine: Inline +AllowShortIfStatementsOnASingleLine: Never +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: false +AlwaysBreakTemplateDeclarations: Yes +BinPackArguments: true +BinPackParameters: true +BraceWrapping: + AfterCaseLabel: true + AfterClass: true + AfterControlStatement: true + AfterEnum: true + AfterFunction: true + AfterNamespace: true + AfterObjCDeclaration: true + AfterStruct: true + AfterUnion: true + AfterExternBlock: true + BeforeCatch: true + BeforeElse: true + IndentBraces: false + SplitEmptyFunction: false + SplitEmptyRecord: false + SplitEmptyNamespace: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Custom # Allman +BreakBeforeInheritanceComma: false +BreakInheritanceList: BeforeColon +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BreakConstructorInitializers: BeforeColon +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: true +ColumnLimit: 120 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: false +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: false +DerivePointerAlignment: false +DisableFormat: false +ExperimentalAutoDetectBinPacking: false +FixNamespaceComments: true +ForEachMacros: + - foreach + - Q_FOREACH + - BOOST_FOREACH +IncludeBlocks: Merge +IncludeCategories: + - Regex: '<.*>' + Priority: 1 + - Regex: '"seal/util/.*"' + Priority: -2 + - Regex: '"seal/.*"' + Priority: -3 +IncludeIsMainRegex: '(Test)?$' +IndentCaseLabels: false +IndentPPDirectives: None +IndentWidth: 4 +IndentWrappedFunctionNames: false +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: false +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: All +ObjCBinPackProtocolList: Auto +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: true +PenaltyBreakAssignment: 2 +PenaltyBreakBeforeFirstCallParameter: 19 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyBreakTemplateDeclaration: 10 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 1000 +PointerAlignment: Right +ReflowComments: true +SortIncludes: true +SortUsingDeclarations: true +SpaceAfterCStyleCast: false +SpaceAfterLogicalNot: false +SpaceAfterTemplateKeyword: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeCpp11BracedList: false +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: true +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Auto +StatementMacros: + - Q_UNUSED + - QT_REQUIRE_VERSION +TabWidth: 4 +UseTab: Never +... diff --git a/p-isa_tools/data_formats/test/.gitignore b/p-isa_tools/data_formats/test/.gitignore new file mode 100644 index 00000000..799d18e6 --- /dev/null +++ b/p-isa_tools/data_formats/test/.gitignore @@ -0,0 +1,3 @@ +# test artifacts +*.program_trace +*.data_trace* diff --git a/p-isa_tools/data_formats/test/CMakeLists.txt b/p-isa_tools/data_formats/test/CMakeLists.txt new file mode 100644 index 00000000..d1a05d85 --- /dev/null +++ b/p-isa_tools/data_formats/test/CMakeLists.txt @@ -0,0 +1,33 @@ +# HERACLES_data_proto test +add_executable(heracles_test + heracles_test.cpp +) +target_link_libraries(heracles_test PRIVATE heracles_data_formats) +add_test( + NAME heracles_proto_c++ + COMMAND ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/heracles_test + WORKING_DIRECTORY ${PROJECT_BINARY_DIR} +) +set(Python3_FIND_VIRTUALENV FIRST) +find_package(Python3 REQUIRED COMPONENTS Interpreter) +add_test( + NAME heracles_test_python + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/heracles_test.py + WORKING_DIRECTORY ${PROJECT_BINARY_DIR} +) + +set_tests_properties(heracles_test_python PROPERTIES + ENVIRONMENT "PYTHONPATH=${PROJECT_BINARY_DIR}/python:$ENV{PYTHONPATH}" +) + +add_executable(heracles_math_test + heracles_math_test.cpp + math_unittest/unittest.cpp +) + +target_link_libraries(heracles_math_test PRIVATE heracles_data_formats) +add_test( + NAME heracles_math_UNITTEST + COMMAND ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/heracles_math_test + WORKING_DIRECTORY ${PROJECT_BINARY_DIR} +) diff --git a/p-isa_tools/data_formats/test/compile_protos.py b/p-isa_tools/data_formats/test/compile_protos.py new file mode 100644 index 00000000..af2fbf11 --- /dev/null +++ b/p-isa_tools/data_formats/test/compile_protos.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Compile Protocol Buffer files to Python modules. +This script can be run standalone without CMake/C++ dependencies. +""" + +import sys +from pathlib import Path + +from grpc_tools import protoc + + +def compile_protos(): + """Compile all .proto files to Python modules.""" + + # Get the directory where this script is located + script_dir = Path(__file__).parent.absolute() + + # The script is now in test/, so go up one level to data_formats + if script_dir.name == "test": + base_dir = script_dir.parent + else: + # Fallback: try to find data_formats from repo root + base_dir = Path.cwd() / "p-isa_tools" / "data_formats" + if not base_dir.exists(): + base_dir = script_dir.parent + + proto_dir = base_dir / "proto" / "heracles" + python_dir = base_dir / "python" + + # Create output directory for generated files + output_dir = python_dir / "heracles" / "proto" + output_dir.mkdir(parents=True, exist_ok=True) + + # Create __init__.py in proto directory + init_file = output_dir / "__init__.py" + if not init_file.exists(): + init_file.write_text("# Auto-generated proto package\n") + + # Find all .proto files + proto_files = list(proto_dir.glob("*.proto")) + + if not proto_files: + print(f"No .proto files found in {proto_dir}") + return 1 + + print(f"Found {len(proto_files)} proto files to compile:") + for proto_file in proto_files: + print(f" - {proto_file.name}") + + # Find the grpcio_tools package to get the google protobuf includes + import grpc_tools + + grpc_tools_path = Path(grpc_tools.__file__).parent + proto_include = grpc_tools_path / "_proto" + + # Compile all proto files at once to handle dependencies + print("Compiling all proto files...") + + # protoc arguments - compile all files together + args = [ + "grpc_tools.protoc", + f"-I{proto_include}", # Include path for google/protobuf/*.proto + f"--proto_path={proto_dir}", + f"--python_out={output_dir}", + ] + [str(proto_file) for proto_file in proto_files] + + # Run protoc + result = protoc.main(args) + + if result != 0: + print("Error compiling proto files") + return result + + print(f"\nSuccessfully compiled {len(proto_files)} proto files to {output_dir}") + + # Fix imports in generated files to use relative imports + print("\nFixing imports in generated files...") + for py_file in output_dir.glob("*_pb2.py"): + content = py_file.read_text() + + # Replace absolute imports with relative imports for local proto files + for proto_file in proto_files: + module_name = proto_file.stem + old_import = f"import {module_name}_pb2" + new_import = f"from . import {module_name}_pb2" + content = content.replace(old_import, new_import) + + py_file.write_text(content) + print(f" Fixed imports in {py_file.name}") + + print("\nProto compilation complete!") + return 0 + + +if __name__ == "__main__": + sys.exit(compile_protos()) diff --git a/p-isa_tools/data_formats/test/generate_test_traces.py b/p-isa_tools/data_formats/test/generate_test_traces.py new file mode 100644 index 00000000..ca81e42f --- /dev/null +++ b/p-isa_tools/data_formats/test/generate_test_traces.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python3 +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Generate test trace files for heracles_test.py +This replaces the need for C++ test to generate these files. +""" + +import heracles.data.io as hdi +import heracles.fhe_trace.io as hfi +import heracles.proto.common_pb2 as hpc +import heracles.proto.data_pb2 as hpd +import heracles.proto.fhe_trace_pb2 as hpf + + +def generate_program_trace(): + """Generate test.program_trace file matching the C++ test.""" + # Create sample trace + trace = hpf.Trace() + + # Set context + trace.scheme = hpc.SCHEME_BGV + trace.key_rns_num = 70 + trace.N = 16384 + + # First instruction - NEGATE + negate = trace.instructions.add() + negate.op = "NEGATE" + negate.plaintext_index = 2 + + # NEGATE destination + neg_dest = negate.args.dests.add() + neg_dest.symbol_name = "t1" + neg_dest.num_rns = 5 + neg_dest.order = 2 + + # NEGATE source + neg_src = negate.args.srcs.add() + neg_src.symbol_name = "in1" + neg_src.num_rns = 5 + neg_src.order = 2 + + # Second instruction - ADD + add = trace.instructions.add() + add.op = "ADD" + add.plaintext_index = 2 + + # ADD destination + add_dest = add.args.dests.add() + add_dest.symbol_name = "out1" + add_dest.num_rns = 5 + add_dest.order = 2 + + # ADD sources + add_src1 = add.args.srcs.add() + add_src1.symbol_name = "t1" + add_src1.num_rns = 5 + add_src1.order = 2 + + add_src2 = add.args.srcs.add() + add_src2.symbol_name = "in2" + add_src2.num_rns = 5 + add_src2.order = 2 + + # Save the trace + hfi.store_trace("test.program_trace", trace) + print("Generated test.program_trace") + + return trace + + +def generate_data_trace(): + """Generate test.data_trace file (context and test vector).""" + # Create FHE context + context = hpd.FHEContext() + context.scheme = hpc.SCHEME_BGV + context.N = 16384 + context.key_rns_num = 70 + context.q_size = 5 + + # Add some basic BGV-specific information + bgv_spec = context.bgv_info + + # Add a plaintext specification (index 2 as used in the trace) + pt_spec = bgv_spec.plaintext_specific.add() + pt_spec.plaintext_modulus = 65537 + pt_spec = bgv_spec.plaintext_specific.add() + pt_spec.plaintext_modulus = 65537 + pt_spec = bgv_spec.plaintext_specific.add() # Index 2 + pt_spec.plaintext_modulus = 65537 + + # Create TestVector with some sample data + testvector = hpd.TestVector() + + # Add data for symbols used in the trace + for symbol in ["in1", "in2", "t1", "out1", "output_0_1_2"]: + data = testvector.sym_data_map[symbol] + + # Add a simple DCRTPoly (the Data message only has dcrtpoly field) + dcrt = data.dcrtpoly + dcrt.in_ntt_form = True + + # Add polynomial data (2 for ciphertext order) + for _ in range(2): # 2 polynomials for a ciphertext + poly = dcrt.polys.add() + poly.in_OpenFHE_EVALUATION = False + + # Add RNS polynomials (5 moduli as specified in trace) + for _ in range(5): + poly.rns_polys.add() + # Just create empty RNS polynomial structure + + # Save the data trace + hdi.store_data_trace("test.data_trace", context, testvector) + print("Generated test.data_trace and associated files") + + return context, testvector + + +def main(): + """Generate all test trace files.""" + print("Generating test trace files...") + + # Generate the program trace + generate_program_trace() + + # Generate the data trace + generate_data_trace() + + print("\nTest trace generation complete!") + print("Files created:") + print(" - test.program_trace") + print(" - test.data_trace (manifest)") + print(" - test.data_trace_hec_context_part_0") + print(" - test.data_trace_testvector_part_0") + + # Verify the files can be loaded + print("\nVerifying files can be loaded...") + hfi.load_trace("test.program_trace") + hdi.load_hec_context("test.data_trace") + hdi.load_data_trace("test.data_trace") + print("✓ All files loaded successfully") + + +if __name__ == "__main__": + main() diff --git a/p-isa_tools/data_formats/test/heracles_math_test.cpp b/p-isa_tools/data_formats/test/heracles_math_test.cpp new file mode 100644 index 00000000..2f072840 --- /dev/null +++ b/p-isa_tools/data_formats/test/heracles_math_test.cpp @@ -0,0 +1,31 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "math_unittest/unittest.h" + +int main(int /*argc*/, const char * /*argv*/[]) +{ + TEST_add_uint32_mod(); + TEST_multiply_uint_mod(); + TEST_exponentiate_uint32_mod(); + TEST_negate_uint_mod(); + TEST_try_invert_uint_mod32(); + TEST_get_msb_index(); + TEST_get_significant_bit_count(); + TEST_get_significant_bit_count_uint(); + TEST_divide_uint96_inplace(); + TEST_left_shift_uint96(); + TEST_right_shift_uint96(); + TEST_add_uint32_base(); + TEST_sub_uint32_base(); + TEST_xgcd32(); + TEST_reverse_bits(); + + TEST_add_uint64_mod(); + TEST_exponentiate_uint64_mod(); + TEST_divide_uint192_inplace(); + TEST_left_shift_uint192(); + TEST_right_shift_uint192(); + TEST_add_uint64_base(); + TEST_sub_uint64_base(); +} diff --git a/p-isa_tools/data_formats/test/heracles_test.cpp b/p-isa_tools/data_formats/test/heracles_test.cpp new file mode 100644 index 00000000..b937039a --- /dev/null +++ b/p-isa_tools/data_formats/test/heracles_test.cpp @@ -0,0 +1,134 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include +#include +#include "google/protobuf/util/json_util.h" +#include "heracles/heracles_data_formats.h" +#include "heracles/heracles_proto.h" + +void fhe_trace_tests() +{ + // Create sample trace + heracles::fhe_trace::Trace trace; + // - context + trace.set_scheme(heracles::common::SCHEME_BGV); + trace.set_key_rns_num(70); + trace.set_n(16384); + // - first instruction + auto negate = trace.add_instructions(); + negate->set_op("NEGATE"); + negate->set_plaintext_index(2); + auto neg_args = negate->mutable_args(); + auto neg_args_dest = neg_args->add_dests(); + neg_args_dest->set_symbol_name("t1"); + neg_args_dest->set_num_rns(5); + neg_args_dest->set_order(2); + auto neg_args_src = neg_args->add_srcs(); + neg_args_src->set_symbol_name("in1"); + neg_args_src->set_num_rns(5); + neg_args_src->set_order(2); + // - second instruction + auto add = trace.add_instructions(); + add->set_op("ADD"); + add->set_plaintext_index(2); + auto add_args = add->mutable_args(); + heracles::fhe_trace::OperandObject *add_args_dest = new heracles::fhe_trace::OperandObject(); + add_args_dest->set_symbol_name("out1"); + add_args_dest->set_num_rns(5); + add_args_dest->set_order(2); + // add_args->add_dests()->CopyFrom(*add_args_dest); + add_args->add_dests()->CopyFrom(*add_args_dest); + add_args->add_srcs()->CopyFrom(add_args->dests(0)); + add_args->mutable_srcs(0)->set_symbol_name("t1"); + add_args->add_srcs()->CopyFrom(add_args->dests(0)); + add_args->mutable_srcs(1)->set_symbol_name("in2"); + + // display it .. + std::cout << "debug string: " << trace.DebugString() << std::endl; + std::string json; + auto rc = google::protobuf::util::MessageToJsonString(trace, &json); + std::cout << "json: " << json << std::endl; + + // accessing enums as default strings and as our own version .. + auto scheme = trace.scheme(); + std::cout << "scheme: as-num=" << scheme + << " / as-default-string=" << heracles::common::Scheme_descriptor()->FindValueByNumber(scheme)->name() + << " / as-friendly-string=" + << heracles::common::Scheme_descriptor()->value(scheme)->options().GetExtension( + heracles::common::string_name) + << std::endl; + + // serialize it to file + if (!heracles::fhe_trace::store_trace("test.program_trace", trace)) + { + std::cerr << "Could not serialize" << std::endl; + exit(1); + } + + // deserialize it from file + heracles::fhe_trace::Trace deserialized_trace; + trace = heracles::fhe_trace::load_trace("test.program_trace"); + + std::cout << "debug string: " << deserialized_trace.DebugString() << std::endl; +} + +void map_tests() +{ + // serialize/deserialize of the input map objects ... + heracles::data::DataPolynomials polys; + auto poly_map = polys.mutable_data()->mutable_sym_poly_map(); + auto key = "key"; + (*poly_map)[key].add_coeffs(1); + (*poly_map)[key].add_coeffs(2); + auto *coeffs = ((*poly_map)[key].mutable_coeffs()); + coeffs->Resize(8, -1); + // coeffs->Add(3); + coeffs->at(2) = -3; + // coeffs->Add(4); + (*poly_map)[key].set_coeffs(3, -4); + + std::cout << "debug string: " << polys.DebugString() << std::endl; + std::string json; + auto rc = google::protobuf::util::MessageToJsonString(polys, &json); + std::cout << "json: " << json << std::endl; + + // serialize to buffer ... + // std::byte would be nicer but had trouble compiling with -std=c++17 + std::vector buf(polys.ByteSizeLong()); + if (!polys.SerializeToArray(buf.data(), buf.size())) + { + std::cerr << "Could not serialize" << std::endl; + exit(1); + } + + // .. and deserialize it to new object + heracles::data::DataPolynomials new_polys; + if (!new_polys.ParseFromArray(buf.data(), buf.size())) + { + std::cerr << "Could not serialize" << std::endl; + exit(1); + } + + std::cout << "new: " << new_polys.DebugString() << std::endl; +} + +void cpp_data_tests() +{ + heracles::data::FHEContext context; + heracles::data::TestVector testvector; + context.set_scheme(heracles::common::SCHEME_BGV); + heracles::data::store_data_trace("test.data_trace", context, testvector); + + auto [new_context, new_testvector] = heracles::data::load_data_trace("test.data_trace"); + std::cout << "COMPLETE: cpp_data_tests" << std::endl; +} + +int main(int /*argc*/, const char * /*argv*/[]) +{ + map_tests(); + fhe_trace_tests(); + cpp_data_tests(); +} diff --git a/p-isa_tools/data_formats/test/heracles_test.py b/p-isa_tools/data_formats/test/heracles_test.py new file mode 100644 index 00000000..67da6164 --- /dev/null +++ b/p-isa_tools/data_formats/test/heracles_test.py @@ -0,0 +1,100 @@ +#!/bin/env python3 +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import os +import sys +from pathlib import Path + +# Setup: Ensure protos are compiled and test data exists +test_dir = Path(__file__).parent.absolute() +proto_dir = test_dir.parent / "python" / "heracles" / "proto" + +# Check and compile protos if needed +proto_files = ["common_pb2.py", "data_pb2.py", "fhe_trace_pb2.py", "maps_pb2.py"] +if not proto_dir.exists() or not all((proto_dir / f).exists() for f in proto_files): + print("Proto files not found, compiling...") + try: + import compile_protos + + sys.path.insert(0, str(test_dir)) + compile_protos.compile_protos() + except ImportError: + print("ERROR: grpcio-tools not found. Install with: pip install -e '.[dev]'") + sys.exit(1) + +# Add python directory to path +sys.path.insert(0, str(test_dir.parent / "python")) + +# Check and generate test traces if needed +os.chdir(test_dir) +if not os.path.exists("test.program_trace") or not os.path.exists("test.data_trace"): + print("Generating test traces...") + import generate_test_traces + + generate_test_traces.main() + +# Now import the modules (after setup is complete) +import google.protobuf.json_format as gpj # noqa: E402 +import heracles.data.io as hdi # noqa: E402 +import heracles.data.naming as hdn # noqa: E402 +import heracles.fhe_trace.io as hfi # noqa: E402 +import heracles.proto.common_pb2 as hpc # noqa: E402 + + +def test(): + # simulate interaction of program mapper with this library .... + # Test data is already generated at module import time + trace = hfi.load_trace("test.program_trace") + hec_context = hdi.load_hec_context("test.data_trace") + + for fhe_instr in trace.instructions: + # find compiled kernel for operation fhe_instr.op and for each HEC-ISA instruction in kernel .. + # .. find all memory and immediate symbols ... + flat_obj_sym = "output_0_1_2" + mem_sym_prefix = hdn.get_sym_obj_name(flat_obj_sym) + immediate_sym = "meta_7" + # ... and their universal form ... + # (Note: first call is a replacement, with slightly different arguments, + # of the `replace_symbols` function from Sim0.5.1 `program_mapper.py`)) + universal_mem_sym_prefix = hdn.map_mem_sym(hec_context, fhe_instr, mem_sym_prefix) + # TODO: hdn.map_immediate_sym will fail until heracles_test.cpp exports a full context with + # keys so far just skip ... + universal_immediate_sym = None + # universal_immediate_sym = hdn.map_immediate_sym( + # hec_context, fhe_instr, immediate_sym + # ) + # ... and in kernel ... + print( + f"replace for operation '{fhe_instr.op}' memory symbol-prefix '{mem_sym_prefix}'" + f" with '{universal_mem_sym_prefix}' and immediate symbol-prefix '{immediate_sym}' with '{universal_immediate_sym}'" + ) + + # complete dump ... + print(trace) + print(gpj.MessageToJson(trace)) + print(gpj.MessageToDict(trace)) + + # selective access to trace information ... + print( + f"scheme num={trace.scheme} / " + f"default-string={hpc.Scheme.DESCRIPTOR.values_by_number[trace.scheme].name} / " + f"friendly-string={hpc.Scheme.DESCRIPTOR.values_by_number[trace.scheme].GetOptions()}" + ) # TODO: extract heracles.instruction.string_name extension + first_instr = trace.instructions[0] + src1 = first_instr.args.srcs[0].symbol_name + src2 = first_instr.args.srcs[1].symbol_name if len(first_instr.args.srcs) > 1 else "N/A" + + dest = first_instr.args.dests[0].symbol_name + print(f"first instruction arguments: destination '{dest}', src1='{src1}', src2='{src2}'") + second_instr = trace.instructions[1] + src1 = second_instr.args.srcs[0].symbol_name + src2 = second_instr.args.srcs[1].symbol_name if len(second_instr.args.srcs) > 1 else "N/A" + + dest = second_instr.args.dests[0].symbol_name + + print(f"second instruction arguments: destination '{dest}', src1='{src1}', src2='{src2}'") + + +if __name__ == "__main__": + test() diff --git a/p-isa_tools/data_formats/test/math_unittest/unittest.cpp b/p-isa_tools/data_formats/test/math_unittest/unittest.cpp new file mode 100644 index 00000000..9f79a47b --- /dev/null +++ b/p-isa_tools/data_formats/test/math_unittest/unittest.cpp @@ -0,0 +1,1104 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "unittest.h" +#include +#include +#include +#include +#include "heracles/data/math.h" + +bool TEST_add_uint32_mod() +{ + std::uint32_t mod; + + mod = 2; + ASSERT_EQ(0, heracles::math::add_uint_mod(0, 0, mod)); + ASSERT_EQ(1, heracles::math::add_uint_mod(0, 1, mod)); + ASSERT_EQ(1, heracles::math::add_uint_mod(1, 0, mod)); + ASSERT_EQ(0, heracles::math::add_uint_mod(1, 1, mod)); + + mod = 10; + ASSERT_EQ(0, heracles::math::add_uint_mod(0, 0, mod)); + ASSERT_EQ(1, heracles::math::add_uint_mod(0, 1, mod)); + ASSERT_EQ(1, heracles::math::add_uint_mod(1, 0, mod)); + ASSERT_EQ(2, heracles::math::add_uint_mod(1, 1, mod)); + ASSERT_EQ(4, heracles::math::add_uint_mod(7, 7, mod)); + ASSERT_EQ(3, heracles::math::add_uint_mod(6, 7, mod)); + + mod = 1305843001; + ASSERT_EQ(0, heracles::math::add_uint_mod(0, 0, mod)); + ASSERT_EQ(1, heracles::math::add_uint_mod(0, 1, mod)); + ASSERT_EQ(1, heracles::math::add_uint_mod(1, 0, mod)); + ASSERT_EQ(2, heracles::math::add_uint_mod(1, 1, mod)); + ASSERT_EQ(0, heracles::math::add_uint_mod(652921500, 652921501, mod)); + ASSERT_EQ(1, heracles::math::add_uint_mod(652921501, 652921501, mod)); + ASSERT_EQ(1305842999, heracles::math::add_uint_mod(1305843000, 1305843000, mod)); + + return true; +} + +bool TEST_multiply_uint_mod() +{ + std::uint32_t mod; + + mod = 2; + ASSERT_EQ(0, heracles::math::multiply_uint_mod(0U, 0U, mod)); + ASSERT_EQ(0, heracles::math::multiply_uint_mod(0U, 1U, mod)); + ASSERT_EQ(0, heracles::math::multiply_uint_mod(1U, 0U, mod)); + ASSERT_EQ(1, heracles::math::multiply_uint_mod(1U, 1U, mod)); + + mod = 10; + ASSERT_EQ(0, heracles::math::multiply_uint_mod(0U, 0U, mod)); + ASSERT_EQ(0, heracles::math::multiply_uint_mod(0U, 1U, mod)); + ASSERT_EQ(0, heracles::math::multiply_uint_mod(1U, 0U, mod)); + ASSERT_EQ(1, heracles::math::multiply_uint_mod(1U, 1U, mod)); + ASSERT_EQ(9, heracles::math::multiply_uint_mod(7U, 7U, mod)); + ASSERT_EQ(2, heracles::math::multiply_uint_mod(6U, 7U, mod)); + ASSERT_EQ(2, heracles::math::multiply_uint_mod(7U, 6U, mod)); + + mod = 1305843001; + ASSERT_EQ(0, heracles::math::multiply_uint_mod(0U, 0U, mod)); + ASSERT_EQ(0, heracles::math::multiply_uint_mod(0U, 1U, mod)); + ASSERT_EQ(0, heracles::math::multiply_uint_mod(1U, 0U, mod)); + ASSERT_EQ(1, heracles::math::multiply_uint_mod(1U, 1U, mod)); + ASSERT_EQ(326460750, heracles::math::multiply_uint_mod(652921500U, 652921501U, mod)); + ASSERT_EQ(326460750, heracles::math::multiply_uint_mod(652921501U, 652921500U, mod)); + ASSERT_EQ(979382251, heracles::math::multiply_uint_mod(652921501U, 652921501U, mod)); + ASSERT_EQ(1, heracles::math::multiply_uint_mod(1305843000U, 1305843000U, mod)); + + return true; +} + +bool TEST_exponentiate_uint32_mod() +{ + std::uint32_t mod; + + mod = 5; + ASSERT_EQ(1, heracles::math::exponentiate_uint_mod(1U, 0U, mod)); + ASSERT_EQ(1, heracles::math::exponentiate_uint_mod(1U, 0xFFFFFFFFU, mod)); + ASSERT_EQ(3, heracles::math::exponentiate_uint_mod(2U, 0xFFFFFFFFU, mod)); + + mod = 0x10000000; + ASSERT_EQ(0, heracles::math::exponentiate_uint_mod(2U, 30U, mod)); + ASSERT_EQ(0, heracles::math::exponentiate_uint_mod(2U, 59U, mod)); + + mod = 131313131; + ASSERT_EQ(26909095, heracles::math::exponentiate_uint_mod(242424242U, 16U, mod)); + + return true; +} + +bool TEST_negate_uint_mod() +{ + std::uint32_t mod; + + mod = 2; + ASSERT_EQ(0, heracles::math::negate_uint_mod(0U, mod)); + ASSERT_EQ(1, heracles::math::negate_uint_mod(1U, mod)); + + mod = 0xFFFF; + ASSERT_EQ(0, heracles::math::negate_uint_mod(0U, mod)); + ASSERT_EQ(0xFFFE, heracles::math::negate_uint_mod(1U, mod)); + ASSERT_EQ(1, heracles::math::negate_uint_mod(0xFFFEU, mod)); + + mod = 1844674403; + ASSERT_EQ(0, heracles::math::negate_uint_mod(0U, mod)); + ASSERT_EQ(1844674402, heracles::math::negate_uint_mod(1U, mod)); + + return true; +} + +bool TEST_try_invert_uint_mod32() +{ + std::uint32_t mod, result; + + mod = 5; + ASSERT_EQ(false, heracles::math::try_invert_uint_mod(0U, mod, &result)); + ASSERT_EQ(true, heracles::math::try_invert_uint_mod(1U, mod, &result)); + ASSERT_EQ(1, result); + ASSERT_EQ(true, heracles::math::try_invert_uint_mod(2U, mod, &result)); + ASSERT_EQ(3, result); + ASSERT_EQ(true, heracles::math::try_invert_uint_mod(3U, mod, &result)); + ASSERT_EQ(2, result); + ASSERT_EQ(true, heracles::math::try_invert_uint_mod(4U, mod, &result)); + ASSERT_EQ(4, result); + + mod = 6; + ASSERT_EQ(false, heracles::math::try_invert_uint_mod(2U, mod, &result)); + ASSERT_EQ(false, heracles::math::try_invert_uint_mod(3U, mod, &result)); + ASSERT_EQ(true, heracles::math::try_invert_uint_mod(5U, mod, &result)); + ASSERT_EQ(5, result); + + mod = 1351315121; + ASSERT_EQ(true, heracles::math::try_invert_uint_mod(331975426U, mod, &result)); + ASSERT_EQ(1052541512, result); + + return true; +} + +bool TEST_get_significant_bit_count_uint() +{ + std::vector val(2, 0); + + val[0] = 0; + val[1] = 0; + ASSERT_EQ(0, heracles::math::get_significant_bit_count_uint(val.data(), 2)); + + val[0] = 1; + val[1] = 0; + ASSERT_EQ(1, heracles::math::get_significant_bit_count_uint(val.data(), 2)); + + val[0] = 2; + val[1] = 0; + ASSERT_EQ(2, heracles::math::get_significant_bit_count_uint(val.data(), 2)); + + val[0] = 3; + val[1] = 0; + ASSERT_EQ(2, heracles::math::get_significant_bit_count_uint(val.data(), 2)); + + val[0] = 29; + val[1] = 0; + ASSERT_EQ(5, heracles::math::get_significant_bit_count_uint(val.data(), 2)); + + val[0] = 4; + val[1] = 0; + ASSERT_EQ(3, heracles::math::get_significant_bit_count_uint(val.data(), 2)); + + val[0] = 0xFFFFFFFF; + val[1] = 0; + ASSERT_EQ(32, heracles::math::get_significant_bit_count_uint(val.data(), 2)); + + val[0] = 0; + val[1] = 1; + ASSERT_EQ(33, heracles::math::get_significant_bit_count_uint(val.data(), 2)); + + val[0] = 0xFFFFFFFF; + val[1] = 1; + ASSERT_EQ(33, heracles::math::get_significant_bit_count_uint(val.data(), 2)); + + val[0] = 0xFFFFFFFF; + val[1] = 0x70000000; + ASSERT_EQ(63, heracles::math::get_significant_bit_count_uint(val.data(), 2)); + + val[0] = 0xFFFFFFFF; + val[1] = 0x80000000; + ASSERT_EQ(64, heracles::math::get_significant_bit_count_uint(val.data(), 2)); + + val[0] = 0xFFFFFFFF; + val[1] = 0xFFFFFFFF; + ASSERT_EQ(64, heracles::math::get_significant_bit_count_uint(val.data(), 2)); + + std::vector val64(2, 0); + + val64[0] = 0; + val64[1] = 0; + ASSERT_EQ(0, heracles::math::get_significant_bit_count_uint(val64.data(), 2)); + + val64[0] = 1; + val64[1] = 0; + ASSERT_EQ(1, heracles::math::get_significant_bit_count_uint(val64.data(), 2)); + + val64[0] = 2; + val64[1] = 0; + ASSERT_EQ(2, heracles::math::get_significant_bit_count_uint(val64.data(), 2)); + + val64[0] = 3; + val64[1] = 0; + ASSERT_EQ(2, heracles::math::get_significant_bit_count_uint(val64.data(), 2)); + + val64[0] = 29; + val64[1] = 0; + ASSERT_EQ(5, heracles::math::get_significant_bit_count_uint(val64.data(), 2)); + + val64[0] = 4; + val64[1] = 0; + ASSERT_EQ(3, heracles::math::get_significant_bit_count_uint(val64.data(), 2)); + + val64[0] = 0xFFFFFFFF; + val64[1] = 0; + ASSERT_EQ(32, heracles::math::get_significant_bit_count_uint(val64.data(), 2)); + + val64[0] = 0; + val64[1] = 1; + ASSERT_EQ(65, heracles::math::get_significant_bit_count_uint(val64.data(), 2)); + + val64[0] = 0xFFFFFFFF; + val64[1] = 1; + ASSERT_EQ(65, heracles::math::get_significant_bit_count_uint(val64.data(), 2)); + + val64[0] = 0xFFFFFFFFFFFFFFFF; + val64[1] = 0x7000000000000000; + ASSERT_EQ(127, heracles::math::get_significant_bit_count_uint(val64.data(), 2)); + + val64[0] = 0xFFFFFFFFFFFFFFFF; + val64[1] = 0x8000000000000000; + ASSERT_EQ(128, heracles::math::get_significant_bit_count_uint(val64.data(), 2)); + + val64[0] = 0xFFFFFFFFFFFFFFFF; + val64[1] = 0xFFFFFFFFFFFFFFFF; + ASSERT_EQ(128, heracles::math::get_significant_bit_count_uint(val64.data(), 2)); + + return true; +} + +bool TEST_divide_uint96_inplace() +{ + std::vector input(3, 0); + std::vector quotient(3, 0); + + input[0] = 0; + input[1] = 0; + input[2] = 0; + heracles::math::divide_uint3_inplace(input.data(), 1U, quotient.data()); + ASSERT_EQ(0, input[0]); + ASSERT_EQ(0, input[1]); + ASSERT_EQ(0, input[2]); + ASSERT_EQ(0, quotient[0]); + ASSERT_EQ(0, quotient[1]); + ASSERT_EQ(0, quotient[2]); + + input[0] = 1; + input[1] = 0; + input[2] = 0; + heracles::math::divide_uint3_inplace(input.data(), 1U, quotient.data()); + ASSERT_EQ(0, input[0]); + ASSERT_EQ(0, input[1]); + ASSERT_EQ(0, input[2]); + ASSERT_EQ(1, quotient[0]); + ASSERT_EQ(0, quotient[1]); + ASSERT_EQ(0, quotient[2]); + + input[0] = 0x10101010U; + input[1] = 0x2B2B2B2BU; + input[2] = 0xF1F1F1F1U; + heracles::math::divide_uint3_inplace(input.data(), 0x1000U, quotient.data()); + ASSERT_EQ(0x10, input[0]); + ASSERT_EQ(0, input[1]); + ASSERT_EQ(0, input[2]); + ASSERT_EQ(0xB2B10101, quotient[0]); + ASSERT_EQ(0x1F12B2B2, quotient[1]); + ASSERT_EQ(0xF1F1F, quotient[2]); + + input[0] = 12121212; + input[1] = 34343434; + input[2] = 56565656; + heracles::math::divide_uint3_inplace(input.data(), 78787878U, quotient.data()); + ASSERT_EQ(18181818, input[0]); + ASSERT_EQ(0, input[1]); + ASSERT_EQ(0, input[2]); + ASSERT_EQ(991146299, quotient[0]); + ASSERT_EQ(3083566264, quotient[1]); + ASSERT_EQ(0, quotient[2]); + + return true; +} + +bool TEST_left_shift_uint96() +{ + std::vector a(3, 0); + std::vector b(3, 0xFFFFFFFF); + + heracles::math::left_shift_uint3(a.data(), 0, b.data()); + ASSERT_EQ(0, b[0]); + ASSERT_EQ(0, b[1]); + ASSERT_EQ(0, b[2]); + + std::fill_n(b.data(), b.size(), 0xFFFFFFFF); + heracles::math::left_shift_uint3(a.data(), 10, b.data()); + ASSERT_EQ(0, b[0]); + ASSERT_EQ(0, b[1]); + ASSERT_EQ(0, b[2]); + heracles::math::left_shift_uint3(a.data(), 10, a.data()); + ASSERT_EQ(0, a[0]); + ASSERT_EQ(0, a[1]); + ASSERT_EQ(0, a[2]); + + a[0] = 0x55555555; + a[1] = 0xAAAAAAAA; + a[2] = 0xCDCDCDCD; + heracles::math::left_shift_uint3(a.data(), 0, b.data()); + ASSERT_EQ(0x55555555, b[0]); + ASSERT_EQ(0xAAAAAAAA, b[1]); + ASSERT_EQ(0xCDCDCDCD, b[2]); + heracles::math::left_shift_uint3(a.data(), 0, a.data()); + ASSERT_EQ(0x55555555, a[0]); + ASSERT_EQ(0xAAAAAAAA, a[1]); + ASSERT_EQ(0xCDCDCDCD, a[2]); + heracles::math::left_shift_uint3(a.data(), 1, b.data()); + ASSERT_EQ(0xAAAAAAAA, b[0]); + ASSERT_EQ(0x55555554, b[1]); + ASSERT_EQ(0x9B9B9B9B, b[2]); + heracles::math::left_shift_uint3(a.data(), 2, b.data()); + ASSERT_EQ(0x55555554, b[0]); + ASSERT_EQ(0xAAAAAAA9, b[1]); + ASSERT_EQ(0x37373736, b[2]); + heracles::math::left_shift_uint3(a.data(), 32, b.data()); + ASSERT_EQ(0, b[0]); + ASSERT_EQ(0x55555555, b[1]); + ASSERT_EQ(0xAAAAAAAA, b[2]); + heracles::math::left_shift_uint3(a.data(), 33, b.data()); + ASSERT_EQ(0, b[0]); + ASSERT_EQ(0xAAAAAAAA, b[1]); + ASSERT_EQ(0x55555554, b[2]); + heracles::math::left_shift_uint3(a.data(), 95, b.data()); + ASSERT_EQ(0, b[0]); + ASSERT_EQ(0, b[1]); + ASSERT_EQ(0x80000000, b[2]); + + heracles::math::left_shift_uint3(a.data(), 2, a.data()); + ASSERT_EQ(0x55555554, a[0]); + ASSERT_EQ(0xAAAAAAA9, a[1]); + ASSERT_EQ(0x37373736, a[2]); + + heracles::math::left_shift_uint3(a.data(), 32, a.data()); + ASSERT_EQ(0, a[0]); + ASSERT_EQ(0x55555554, a[1]); + ASSERT_EQ(0xAAAAAAA9, a[2]); + + return true; +} + +bool TEST_right_shift_uint96() +{ + std::vector a(3, 0); + std::vector b(3, 0xFFFFFFFF); + + heracles::math::right_shift_uint3(a.data(), 0, b.data()); + ASSERT_EQ(0, b[0]); + ASSERT_EQ(0, b[1]); + ASSERT_EQ(0, b[2]); + + std::fill_n(b.data(), b.size(), 0xFFFFFFFF); + heracles::math::right_shift_uint3(a.data(), 10, b.data()); + ASSERT_EQ(0, b[0]); + ASSERT_EQ(0, b[1]); + ASSERT_EQ(0, b[2]); + heracles::math::right_shift_uint3(a.data(), 10, b.data()); + ASSERT_EQ(0, b[0]); + ASSERT_EQ(0, b[1]); + ASSERT_EQ(0, b[2]); + + a[0] = 0x55555555; + a[1] = 0xAAAAAAAA; + a[2] = 0xCDCDCDCD; + heracles::math::right_shift_uint3(a.data(), 0, b.data()); + ASSERT_EQ(0x55555555, b[0]); + ASSERT_EQ(0xAAAAAAAA, b[1]); + ASSERT_EQ(0xCDCDCDCD, b[2]); + heracles::math::right_shift_uint3(a.data(), 0, a.data()); + ASSERT_EQ(0x55555555, a[0]); + ASSERT_EQ(0xAAAAAAAA, a[1]); + ASSERT_EQ(0xCDCDCDCD, a[2]); + heracles::math::right_shift_uint3(a.data(), 1, b.data()); + ASSERT_EQ(0x2AAAAAAA, b[0]); + ASSERT_EQ(0xD5555555, b[1]); + ASSERT_EQ(0x66E6E6E6, b[2]); + heracles::math::right_shift_uint3(a.data(), 2, b.data()); + ASSERT_EQ(0x95555555, b[0]); + ASSERT_EQ(0x6AAAAAAA, b[1]); + ASSERT_EQ(0x33737373, b[2]); + heracles::math::right_shift_uint3(a.data(), 32, b.data()); + ASSERT_EQ(0xAAAAAAAA, b[0]); + ASSERT_EQ(0xCDCDCDCD, b[1]); + ASSERT_EQ(0, b[2]); + heracles::math::right_shift_uint3(a.data(), 33, b.data()); + ASSERT_EQ(0xD5555555, b[0]); + ASSERT_EQ(0x66E6E6E6, b[1]); + ASSERT_EQ(0, b[2]); + heracles::math::right_shift_uint3(a.data(), 95, b.data()); + ASSERT_EQ(1, b[0]); + ASSERT_EQ(0, b[1]); + ASSERT_EQ(0, b[2]); + + heracles::math::right_shift_uint3(a.data(), 2, a.data()); + ASSERT_EQ(0x95555555, a[0]); + ASSERT_EQ(0x6AAAAAAA, a[1]); + ASSERT_EQ(0x33737373, a[2]); + + heracles::math::right_shift_uint3(a.data(), 32, a.data()); + ASSERT_EQ(0x6AAAAAAA, a[0]); + ASSERT_EQ(0x33737373, a[1]); + ASSERT_EQ(0, a[2]); + + return true; +} + +bool TEST_add_uint32_base() +{ + std::vector a(2, 0); + std::vector b(2, 0); + std::vector c(2); + + c[0] = 0xFFFFFFFF; + c[1] = 0xFFFFFFFF; + + ASSERT_EQ(0, heracles::math::add_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0, c[0]); + ASSERT_EQ(0, c[1]); + + a[0] = 0xFFFFFFFF; + a[1] = 0xFFFFFFFF; + b[0] = 0; + b[1] = 0; + std::fill_n(c.data(), c.size(), 0); + ASSERT_EQ(0, heracles::math::add_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0xFFFFFFFF, c[0]); + ASSERT_EQ(0xFFFFFFFF, c[1]); + + a[0] = 0xFFFFFFFE; + a[1] = 0xFFFFFFFF; + b[0] = 1; + b[1] = 0; + std::fill_n(c.data(), c.size(), 0); + ASSERT_EQ(0, heracles::math::add_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0xFFFFFFFF, c[0]); + ASSERT_EQ(0xFFFFFFFF, c[1]); + + a[0] = 0xFFFFFFFF; + a[1] = 0xFFFFFFFF; + b[0] = 1; + b[1] = 0; + std::fill_n(c.data(), c.size(), 0); + ASSERT_NE(0, heracles::math::add_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0, c[0]); + ASSERT_EQ(0, c[1]); + + a[0] = 0xFFFFFFFF; + a[1] = 0xFFFFFFFF; + b[0] = 0xFFFFFFFF; + b[1] = 0xFFFFFFFF; + std::fill_n(c.data(), c.size(), 0); + ASSERT_NE(0, heracles::math::add_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0xFFFFFFFE, c[0]); + ASSERT_EQ(0xFFFFFFFF, c[1]); + ASSERT_NE(0, heracles::math::add_uint_base(a.data(), b.data(), 2, a.data())); + ASSERT_EQ(0xFFFFFFFE, a[0]); + ASSERT_EQ(0xFFFFFFFF, a[1]); + + a[0] = 0xFFFFFFFF; + a[1] = 0; + b[0] = 1; + b[1] = 0; + std::fill_n(c.data(), c.size(), 0); + ASSERT_EQ(0, heracles::math::add_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0, c[0]); + ASSERT_EQ(1, c[1]); + + a[0] = 0xFFFFFFFF; + a[1] = 5; + b[0] = 1; + b[1] = 0; + std::fill_n(c.data(), c.size(), 0); + ASSERT_EQ(0, heracles::math::add_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0, c[0]); + ASSERT_EQ(6, c[1]); + + return true; +} + +bool TEST_sub_uint32_base() +{ + std::vector a(2, 0); + std::vector b(2, 0); + std::vector c(2); + + c[0] = 0xFFFFFFFF; + c[1] = 0xFFFFFFFF; + + ASSERT_EQ(0, heracles::math::sub_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0, c[0]); + ASSERT_EQ(0, c[1]); + + a[0] = 0xFFFFFFFF; + a[1] = 0xFFFFFFFF; + b[0] = 0; + b[1] = 0; + std::fill_n(c.data(), c.size(), 0); + ASSERT_EQ(0, heracles::math::sub_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0xFFFFFFFF, c[0]); + ASSERT_EQ(0xFFFFFFFF, c[1]); + + a[0] = 0xFFFFFFFF; + a[1] = 0xFFFFFFFF; + b[0] = 0; + b[1] = 0; + std::fill_n(c.data(), c.size(), 0); + ASSERT_EQ(0, heracles::math::sub_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0xFFFFFFFF, c[0]); + ASSERT_EQ(0xFFFFFFFF, c[1]); + + a[0] = 0xFFFFFFFF; + a[1] = 0xFFFFFFFF; + b[0] = 1; + b[1] = 0; + std::fill_n(c.data(), c.size(), 0); + ASSERT_EQ(0, heracles::math::sub_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0xFFFFFFFE, c[0]); + ASSERT_EQ(0xFFFFFFFF, c[1]); + + a[0] = 0; + a[1] = 0; + b[0] = 1; + b[1] = 0; + std::fill_n(c.data(), c.size(), 0); + ASSERT_NE(0, heracles::math::sub_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0xFFFFFFFF, c[0]); + ASSERT_EQ(0xFFFFFFFF, c[1]); + ASSERT_NE(0, heracles::math::sub_uint_base(a.data(), b.data(), 2, a.data())); + ASSERT_EQ(0xFFFFFFFF, a[0]); + ASSERT_EQ(0xFFFFFFFF, a[1]); + + a[0] = 0xFFFFFFFF; + a[1] = 0xFFFFFFFF; + b[0] = 0xFFFFFFFF; + b[1] = 0xFFFFFFFF; + std::fill_n(c.data(), c.size(), 0); + ASSERT_EQ(0, heracles::math::sub_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0, c[0]); + ASSERT_EQ(0, c[1]); + ASSERT_EQ(0, heracles::math::sub_uint_base(a.data(), b.data(), 2, a.data())); + ASSERT_EQ(0, a[0]); + ASSERT_EQ(0, a[1]); + + a[0] = 0xFFFFFFFE; + a[1] = 0xFFFFFFFF; + b[0] = 0xFFFFFFFF; + b[1] = 0xFFFFFFFF; + std::fill_n(c.data(), c.size(), 0); + ASSERT_NE(0, heracles::math::sub_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0xFFFFFFFF, c[0]); + ASSERT_EQ(0xFFFFFFFF, c[1]); + + a[0] = 0; + a[1] = 1; + b[0] = 1; + b[1] = 0; + std::fill_n(c.data(), c.size(), 0); + ASSERT_EQ(0, heracles::math::sub_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0xFFFFFFFF, c[0]); + ASSERT_EQ(0, c[1]); + + return true; +} + +bool TEST_xgcd32() +{ + std::tuple result; + + result = heracles::math::xgcd(7U, 7U); + ASSERT_EQ(result, std::make_tuple<>(7, 0, 1)); + result = heracles::math::xgcd(2U, 2U); + ASSERT_EQ(result, std::make_tuple<>(2, 0, 1)); + + result = heracles::math::xgcd(1U, 1U); + ASSERT_EQ(result, std::make_tuple<>(1, 0, 1)); + result = heracles::math::xgcd(1U, 2U); + ASSERT_EQ(result, std::make_tuple<>(1, 1, 0)); + result = heracles::math::xgcd(5U, 6U); + ASSERT_EQ(result, std::make_tuple<>(1, -1, 1)); + result = heracles::math::xgcd(13U, 19U); + ASSERT_EQ(result, std::make_tuple<>(1, 3, -2)); + result = heracles::math::xgcd(14U, 21U); + ASSERT_EQ(result, std::make_tuple<>(7, -1, 1)); + + result = heracles::math::xgcd(2U, 1U); + ASSERT_EQ(result, std::make_tuple<>(1, 0, 1)); + result = heracles::math::xgcd(6U, 5U); + ASSERT_EQ(result, std::make_tuple<>(1, 1, -1)); + result = heracles::math::xgcd(19U, 13U); + ASSERT_EQ(result, std::make_tuple<>(1, -2, 3)); + result = heracles::math::xgcd(21U, 14U); + ASSERT_EQ(result, std::make_tuple<>(7, 1, -1)); + + return true; +} + +bool TEST_reverse_bits() +{ + ASSERT_EQ(0, heracles::math::reverse_bits(0)); + ASSERT_EQ(0x80000000, heracles::math::reverse_bits(1)); + ASSERT_EQ(0x40000000, heracles::math::reverse_bits(2)); + ASSERT_EQ(0xC0000000, heracles::math::reverse_bits(3)); + ASSERT_EQ(0x00010000, heracles::math::reverse_bits(0x00008000)); + ASSERT_EQ(0xFFFF0000, heracles::math::reverse_bits(0x0000FFFF)); + ASSERT_EQ(0x0000FFFF, heracles::math::reverse_bits(0xFFFF0000)); + ASSERT_EQ(0x00008000, heracles::math::reverse_bits(0x00010000)); + + ASSERT_EQ(0, heracles::math::reverse_bits(0xFFFFFFFF, 0)); + + ASSERT_EQ(0, heracles::math::reverse_bits(0, 32)); + ASSERT_EQ(0x80000000, heracles::math::reverse_bits(1, 32)); + ASSERT_EQ(0x40000000, heracles::math::reverse_bits(2, 32)); + ASSERT_EQ(0xC0000000, heracles::math::reverse_bits(3, 32)); + ASSERT_EQ(0x00010000, heracles::math::reverse_bits(0x00008000, 32)); + ASSERT_EQ(0xFFFF0000, heracles::math::reverse_bits(0x0000FFFF, 32)); + ASSERT_EQ(0x0000FFFF, heracles::math::reverse_bits(0xFFFF0000, 32)); + ASSERT_EQ(0x00008000, heracles::math::reverse_bits(0x00010000, 32)); + + ASSERT_EQ(0, heracles::math::reverse_bits(0, 16)); + ASSERT_EQ(0x00008000, heracles::math::reverse_bits(1, 16)); + ASSERT_EQ(0x00004000, heracles::math::reverse_bits(2, 16)); + ASSERT_EQ(0x0000C000, heracles::math::reverse_bits(3, 16)); + ASSERT_EQ(0x00000001, heracles::math::reverse_bits(0x00008000, 16)); + ASSERT_EQ(0x0000FFFF, heracles::math::reverse_bits(0x0000FFFF, 16)); + ASSERT_EQ(0x00000000, heracles::math::reverse_bits(0xFFFF0000, 16)); + ASSERT_EQ(0x00000000, heracles::math::reverse_bits(0x00010000, 16)); + ASSERT_EQ(3, heracles::math::reverse_bits(0x0000C000, 16)); + ASSERT_EQ(2, heracles::math::reverse_bits(0x00004000, 16)); + ASSERT_EQ(1, heracles::math::reverse_bits(0x00008000, 16)); + ASSERT_EQ(0x0000FFFF, heracles::math::reverse_bits(0xFFFFFFFF, 16)); + + return true; +} + +bool TEST_add_uint64_mod() +{ + std::uint64_t mod; + + mod = 2; + ASSERT_EQ(0, heracles::math::add_uint_mod(0, 0, mod)); + ASSERT_EQ(1, heracles::math::add_uint_mod(0, 1, mod)); + ASSERT_EQ(1, heracles::math::add_uint_mod(1, 0, mod)); + ASSERT_EQ(0, heracles::math::add_uint_mod(1, 1, mod)); + + mod = 10; + ASSERT_EQ(0, heracles::math::add_uint_mod(0, 0, mod)); + ASSERT_EQ(1, heracles::math::add_uint_mod(0, 1, mod)); + ASSERT_EQ(1, heracles::math::add_uint_mod(1, 0, mod)); + ASSERT_EQ(2, heracles::math::add_uint_mod(1, 1, mod)); + ASSERT_EQ(4, heracles::math::add_uint_mod(7, 7, mod)); + ASSERT_EQ(3, heracles::math::add_uint_mod(6, 7, mod)); + + mod = 1305843001; + ASSERT_EQ(0, heracles::math::add_uint_mod(0, 0, mod)); + ASSERT_EQ(1, heracles::math::add_uint_mod(0, 1, mod)); + ASSERT_EQ(1, heracles::math::add_uint_mod(1, 0, mod)); + ASSERT_EQ(2, heracles::math::add_uint_mod(1, 1, mod)); + ASSERT_EQ(0, heracles::math::add_uint_mod(652921500, 652921501, mod)); + ASSERT_EQ(1, heracles::math::add_uint_mod(652921501, 652921501, mod)); + ASSERT_EQ(1305842999, heracles::math::add_uint_mod(1305843000, 1305843000, mod)); + + return true; +} + +bool TEST_exponentiate_uint64_mod() +{ + std::uint64_t mod; + + mod = 5; + ASSERT_EQ(1, heracles::math::exponentiate_uint_mod(1UL, 0UL, mod)); + ASSERT_EQ(1, heracles::math::exponentiate_uint_mod(1UL, 0xFFFFFFFFFFFFFFFFUL, mod)); + ASSERT_EQ(3, heracles::math::exponentiate_uint_mod(2UL, 0xFFFFFFFFFFFFFFFFUL, mod)); + + mod = 0x1000000000000000ULL; + ASSERT_EQ(0, heracles::math::exponentiate_uint_mod(2UL, 60UL, mod)); + ASSERT_EQ(0x800000000000000ULL, heracles::math::exponentiate_uint_mod(2UL, 59UL, mod)); + + mod = 131313131313; + ASSERT_EQ(39418477653ULL, heracles::math::exponentiate_uint_mod(2424242424UL, 16UL, mod)); + + return true; +} + +bool TEST_get_msb_index() +{ + ASSERT_EQ(0, heracles::math::get_msb_index(1U)); + ASSERT_EQ(1, heracles::math::get_msb_index(2U)); + ASSERT_EQ(1, heracles::math::get_msb_index(3U)); + ASSERT_EQ(2, heracles::math::get_msb_index(4U)); + ASSERT_EQ(4, heracles::math::get_msb_index(16U)); + ASSERT_EQ(15, heracles::math::get_msb_index(0xFFFFU)); + ASSERT_EQ(15, heracles::math::get_msb_index(0xFFFFUL)); + ASSERT_EQ(16, heracles::math::get_msb_index(0x10000U)); + ASSERT_EQ(16, heracles::math::get_msb_index(0x10000UL)); + ASSERT_EQ(31, heracles::math::get_msb_index(0xFFFFFFFFU)); + ASSERT_EQ(31, heracles::math::get_msb_index(0xFFFFFFFFUL)); + ASSERT_EQ(32, heracles::math::get_msb_index(0x100000000UL)); + ASSERT_EQ(63, heracles::math::get_msb_index(0xFFFFFFFFFFFFFFFFUL)); + + return true; +} + +bool TEST_get_significant_bit_count() +{ + ASSERT_EQ(0, heracles::math::get_significant_bit_count(0U)); + ASSERT_EQ(1, heracles::math::get_significant_bit_count(1U)); + ASSERT_EQ(2, heracles::math::get_significant_bit_count(2U)); + ASSERT_EQ(2, heracles::math::get_significant_bit_count(3U)); + ASSERT_EQ(3, heracles::math::get_significant_bit_count(4U)); + ASSERT_EQ(3, heracles::math::get_significant_bit_count(5U)); + ASSERT_EQ(3, heracles::math::get_significant_bit_count(6U)); + ASSERT_EQ(3, heracles::math::get_significant_bit_count(7U)); + ASSERT_EQ(4, heracles::math::get_significant_bit_count(8U)); + ASSERT_EQ(31, heracles::math::get_significant_bit_count(0x70000000U)); + ASSERT_EQ(31, heracles::math::get_significant_bit_count(0x7FFFFFFFU)); + ASSERT_EQ(32, heracles::math::get_significant_bit_count(0x80000000U)); + ASSERT_EQ(32, heracles::math::get_significant_bit_count(0xFFFFFFFFU)); + + return true; +} + +bool TEST_divide_uint192_inplace() +{ + std::vector input(3, 0); + std::vector quotient(3, 0); + + input[0] = 0; + input[1] = 0; + input[2] = 0; + heracles::math::divide_uint3_inplace(input.data(), 1UL, quotient.data()); + ASSERT_EQ(0, input[0]); + ASSERT_EQ(0, input[1]); + ASSERT_EQ(0, input[2]); + ASSERT_EQ(0, quotient[0]); + ASSERT_EQ(0, quotient[1]); + ASSERT_EQ(0, quotient[2]); + + input[0] = 1; + input[1] = 0; + input[2] = 0; + heracles::math::divide_uint3_inplace(input.data(), 1UL, quotient.data()); + ASSERT_EQ(0, input[0]); + ASSERT_EQ(0, input[1]); + ASSERT_EQ(0, input[2]); + ASSERT_EQ(1, quotient[0]); + ASSERT_EQ(0, quotient[1]); + ASSERT_EQ(0, quotient[2]); + + input[0] = 0x10101010U; + input[1] = 0x2B2B2B2BU; + input[2] = 0xF1F1F1F1U; + heracles::math::divide_uint3_inplace(input.data(), 0x1000UL, quotient.data()); + ASSERT_EQ(0x10, input[0]); + ASSERT_EQ(0, input[1]); + ASSERT_EQ(0, input[2]); + ASSERT_EQ(0xB2B0000000010101ULL, quotient[0]); + ASSERT_EQ(0x1F1000000002B2B2ULL, quotient[1]); + ASSERT_EQ(0xF1F1FULL, quotient[2]); + + input[0] = 1212121212121212ULL; + input[1] = 3434343434343434ULL; + input[2] = 5656565656565656ULL; + heracles::math::divide_uint3_inplace(input.data(), 7878787878787878UL, quotient.data()); + ASSERT_EQ(7272727272727272ULL, input[0]); + ASSERT_EQ(0, input[1]); + ASSERT_EQ(0, input[2]); + ASSERT_EQ(17027763760347278414ULL, quotient[0]); + ASSERT_EQ(13243816258047883211ULL, quotient[1]); + ASSERT_EQ(0, quotient[2]); + + return true; +} + +bool TEST_left_shift_uint192() +{ + std::vector a(3, 0); + std::vector b(3, 0xFFFFFFFFFFFFFFFF); + + heracles::math::left_shift_uint3(a.data(), 0, b.data()); + ASSERT_EQ(0, b[0]); + ASSERT_EQ(0, b[1]); + ASSERT_EQ(0, b[2]); + + std::fill_n(b.data(), b.size(), 0xFFFFFFFFFFFFFFFF); + heracles::math::left_shift_uint3(a.data(), 10, b.data()); + ASSERT_EQ(0, b[0]); + ASSERT_EQ(0, b[1]); + ASSERT_EQ(0, b[2]); + heracles::math::left_shift_uint3(a.data(), 10, a.data()); + ASSERT_EQ(0, a[0]); + ASSERT_EQ(0, a[1]); + ASSERT_EQ(0, a[2]); + + a[0] = 0x5555555555555555; + a[1] = 0xAAAAAAAAAAAAAAAA; + a[2] = 0xCDCDCDCDCDCDCDCD; + heracles::math::left_shift_uint3(a.data(), 0, b.data()); + ASSERT_EQ(0x5555555555555555, b[0]); + ASSERT_EQ(0xAAAAAAAAAAAAAAAA, b[1]); + ASSERT_EQ(0xCDCDCDCDCDCDCDCD, b[2]); + heracles::math::left_shift_uint3(a.data(), 0, a.data()); + ASSERT_EQ(0x5555555555555555, a[0]); + ASSERT_EQ(0xAAAAAAAAAAAAAAAA, a[1]); + ASSERT_EQ(0xCDCDCDCDCDCDCDCD, a[2]); + heracles::math::left_shift_uint3(a.data(), 1, b.data()); + ASSERT_EQ(0xAAAAAAAAAAAAAAAA, b[0]); + ASSERT_EQ(0x5555555555555554, b[1]); + ASSERT_EQ(0x9B9B9B9B9B9B9B9B, b[2]); + heracles::math::left_shift_uint3(a.data(), 2, b.data()); + ASSERT_EQ(0x5555555555555554, b[0]); + ASSERT_EQ(0xAAAAAAAAAAAAAAA9, b[1]); + ASSERT_EQ(0x3737373737373736, b[2]); + heracles::math::left_shift_uint3(a.data(), 64, b.data()); + ASSERT_EQ(0, b[0]); + ASSERT_EQ(0x5555555555555555, b[1]); + ASSERT_EQ(0xAAAAAAAAAAAAAAAA, b[2]); + heracles::math::left_shift_uint3(a.data(), 65, b.data()); + ASSERT_EQ(0, b[0]); + ASSERT_EQ(0xAAAAAAAAAAAAAAAA, b[1]); + ASSERT_EQ(0x5555555555555554, b[2]); + heracles::math::left_shift_uint3(a.data(), 191, b.data()); + ASSERT_EQ(0, b[0]); + ASSERT_EQ(0, b[1]); + ASSERT_EQ(0x8000000000000000, b[2]); + + heracles::math::left_shift_uint3(a.data(), 2, a.data()); + ASSERT_EQ(0x5555555555555554, a[0]); + ASSERT_EQ(0xAAAAAAAAAAAAAAA9, a[1]); + ASSERT_EQ(0x3737373737373736, a[2]); + + heracles::math::left_shift_uint3(a.data(), 64, a.data()); + ASSERT_EQ(0, a[0]); + ASSERT_EQ(0x5555555555555554, a[1]); + ASSERT_EQ(0xAAAAAAAAAAAAAAA9, a[2]); + + return true; +} + +bool TEST_right_shift_uint192() +{ + std::vector a(3, 0); + std::vector b(3, 0xFFFFFFFFFFFFFFFF); + + heracles::math::right_shift_uint3(a.data(), 0, b.data()); + ASSERT_EQ(0, b[0]); + ASSERT_EQ(0, b[1]); + ASSERT_EQ(0, b[2]); + + std::fill_n(b.data(), b.size(), 0xFFFFFFFFFFFFFFFF); + heracles::math::right_shift_uint3(a.data(), 10, b.data()); + ASSERT_EQ(0, b[0]); + ASSERT_EQ(0, b[1]); + ASSERT_EQ(0, b[2]); + heracles::math::right_shift_uint3(a.data(), 10, b.data()); + ASSERT_EQ(0, b[0]); + ASSERT_EQ(0, b[1]); + ASSERT_EQ(0, b[2]); + + a[0] = 0x5555555555555555; + a[1] = 0xAAAAAAAAAAAAAAAA; + a[2] = 0xCDCDCDCDCDCDCDCD; + heracles::math::right_shift_uint3(a.data(), 0, b.data()); + ASSERT_EQ(0x5555555555555555, b[0]); + ASSERT_EQ(0xAAAAAAAAAAAAAAAA, b[1]); + ASSERT_EQ(0xCDCDCDCDCDCDCDCD, b[2]); + heracles::math::right_shift_uint3(a.data(), 0, a.data()); + ASSERT_EQ(0x5555555555555555, a[0]); + ASSERT_EQ(0xAAAAAAAAAAAAAAAA, a[1]); + ASSERT_EQ(0xCDCDCDCDCDCDCDCD, a[2]); + heracles::math::right_shift_uint3(a.data(), 1, b.data()); + ASSERT_EQ(0x2AAAAAAAAAAAAAAA, b[0]); + ASSERT_EQ(0xD555555555555555, b[1]); + ASSERT_EQ(0x66E6E6E6E6E6E6E6, b[2]); + heracles::math::right_shift_uint3(a.data(), 2, b.data()); + ASSERT_EQ(0x9555555555555555, b[0]); + ASSERT_EQ(0x6AAAAAAAAAAAAAAA, b[1]); + ASSERT_EQ(0x3373737373737373, b[2]); + heracles::math::right_shift_uint3(a.data(), 64, b.data()); + ASSERT_EQ(0xAAAAAAAAAAAAAAAA, b[0]); + ASSERT_EQ(0xCDCDCDCDCDCDCDCD, b[1]); + ASSERT_EQ(0, b[2]); + heracles::math::right_shift_uint3(a.data(), 65, b.data()); + ASSERT_EQ(0xD555555555555555, b[0]); + ASSERT_EQ(0x66E6E6E6E6E6E6E6, b[1]); + ASSERT_EQ(0, b[2]); + heracles::math::right_shift_uint3(a.data(), 191, b.data()); + ASSERT_EQ(1, b[0]); + ASSERT_EQ(0, b[1]); + ASSERT_EQ(0, b[2]); + + heracles::math::right_shift_uint3(a.data(), 2, a.data()); + ASSERT_EQ(0x9555555555555555, a[0]); + ASSERT_EQ(0x6AAAAAAAAAAAAAAA, a[1]); + ASSERT_EQ(0x3373737373737373, a[2]); + + heracles::math::right_shift_uint3(a.data(), 64, a.data()); + ASSERT_EQ(0x6AAAAAAAAAAAAAAA, a[0]); + ASSERT_EQ(0x3373737373737373, a[1]); + ASSERT_EQ(0, a[2]); + + return true; +} + +bool TEST_add_uint64_base() +{ + std::vector a(2, 0); + std::vector b(2, 0); + std::vector c(2); + + c[0] = 0xFFFFFFFF; + c[1] = 0xFFFFFFFF; + + ASSERT_EQ(0, heracles::math::add_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0, c[0]); + ASSERT_EQ(0, c[1]); + + a[0] = 0xFFFFFFFF; + a[1] = 0xFFFFFFFF; + b[0] = 0; + b[1] = 0; + std::fill_n(c.data(), c.size(), 0); + ASSERT_EQ(0, heracles::math::add_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0xFFFFFFFF, c[0]); + ASSERT_EQ(0xFFFFFFFF, c[1]); + + a[0] = 0xFFFFFFFE; + a[1] = 0xFFFFFFFF; + b[0] = 1; + b[1] = 0; + std::fill_n(c.data(), c.size(), 0); + ASSERT_EQ(0, heracles::math::add_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0xFFFFFFFF, c[0]); + ASSERT_EQ(0xFFFFFFFF, c[1]); + + a[0] = 0xFFFFFFFFFFFFFFFF; + a[1] = 0xFFFFFFFFFFFFFFFF; + b[0] = 1; + b[1] = 0; + std::fill_n(c.data(), c.size(), 0xFFFFFFFFFFFFFFFF); + ASSERT_NE(0, heracles::math::add_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0, c[0]); + ASSERT_EQ(0, c[1]); + + a[0] = 0xFFFFFFFFFFFFFFFF; + a[1] = 0xFFFFFFFFFFFFFFFF; + b[0] = 0xFFFFFFFFFFFFFFFF; + b[1] = 0xFFFFFFFFFFFFFFFF; + std::fill_n(c.data(), c.size(), 0); + ASSERT_NE(0, heracles::math::add_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0xFFFFFFFFFFFFFFFE, c[0]); + ASSERT_EQ(0xFFFFFFFFFFFFFFFF, c[1]); + ASSERT_NE(0, heracles::math::add_uint_base(a.data(), b.data(), 2, a.data())); + ASSERT_EQ(0xFFFFFFFFFFFFFFFE, a[0]); + ASSERT_EQ(0xFFFFFFFFFFFFFFFF, a[1]); + + a[0] = 0xFFFFFFFFFFFFFFFF; + a[1] = 0; + b[0] = 1; + b[1] = 0; + std::fill_n(c.data(), c.size(), 0); + ASSERT_EQ(0, heracles::math::add_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0, c[0]); + ASSERT_EQ(1, c[1]); + + a[0] = 0xFFFFFFFFFFFFFFFF; + a[1] = 5; + b[0] = 1; + b[1] = 0; + std::fill_n(c.data(), c.size(), 0); + ASSERT_EQ(0, heracles::math::add_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0, c[0]); + ASSERT_EQ(6, c[1]); + + return true; +} + +bool TEST_sub_uint64_base() +{ + std::vector a(2, 0); + std::vector b(2, 0); + std::vector c(2); + + c[0] = 0xFFFFFFFF; + c[1] = 0xFFFFFFFF; + + ASSERT_EQ(0, heracles::math::sub_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0, c[0]); + ASSERT_EQ(0, c[1]); + + a[0] = 0xFFFFFFFF; + a[1] = 0xFFFFFFFF; + b[0] = 0; + b[1] = 0; + std::fill_n(c.data(), c.size(), 0); + ASSERT_EQ(0, heracles::math::sub_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0xFFFFFFFF, c[0]); + ASSERT_EQ(0xFFFFFFFF, c[1]); + + a[0] = 0xFFFFFFFF; + a[1] = 0xFFFFFFFF; + b[0] = 0; + b[1] = 0; + std::fill_n(c.data(), c.size(), 0); + ASSERT_EQ(0, heracles::math::sub_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0xFFFFFFFF, c[0]); + ASSERT_EQ(0xFFFFFFFF, c[1]); + + a[0] = 0xFFFFFFFF; + a[1] = 0xFFFFFFFF; + b[0] = 1; + b[1] = 0; + std::fill_n(c.data(), c.size(), 0); + ASSERT_EQ(0, heracles::math::sub_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0xFFFFFFFE, c[0]); + ASSERT_EQ(0xFFFFFFFF, c[1]); + + a[0] = 0; + a[1] = 0; + b[0] = 1; + b[1] = 0; + std::fill_n(c.data(), c.size(), 0); + ASSERT_NE(0, heracles::math::sub_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0xFFFFFFFFFFFFFFFF, c[0]); + ASSERT_EQ(0xFFFFFFFFFFFFFFFF, c[1]); + ASSERT_NE(0, heracles::math::sub_uint_base(a.data(), b.data(), 2, a.data())); + ASSERT_EQ(0xFFFFFFFFFFFFFFFF, a[0]); + ASSERT_EQ(0xFFFFFFFFFFFFFFFF, a[1]); + + a[0] = 0xFFFFFFFFFFFFFFFF; + a[1] = 0xFFFFFFFFFFFFFFFF; + b[0] = 0xFFFFFFFFFFFFFFFF; + b[1] = 0xFFFFFFFFFFFFFFFF; + std::fill_n(c.data(), c.size(), 0); + ASSERT_EQ(0, heracles::math::sub_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0, c[0]); + ASSERT_EQ(0, c[1]); + ASSERT_EQ(0, heracles::math::sub_uint_base(a.data(), b.data(), 2, a.data())); + ASSERT_EQ(0, a[0]); + ASSERT_EQ(0, a[1]); + + a[0] = 0xFFFFFFFFFFFFFFFE; + a[1] = 0xFFFFFFFFFFFFFFFF; + b[0] = 0xFFFFFFFFFFFFFFFF; + b[1] = 0xFFFFFFFFFFFFFFFF; + std::fill_n(c.data(), c.size(), 0); + ASSERT_NE(0, heracles::math::sub_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0xFFFFFFFFFFFFFFFF, c[0]); + ASSERT_EQ(0xFFFFFFFFFFFFFFFF, c[1]); + + a[0] = 0; + a[1] = 1; + b[0] = 1; + b[1] = 0; + std::fill_n(c.data(), c.size(), 0); + ASSERT_EQ(0, heracles::math::sub_uint_base(a.data(), b.data(), 2, c.data())); + ASSERT_EQ(0xFFFFFFFFFFFFFFFF, c[0]); + ASSERT_EQ(0, c[1]); + + return true; +} + +bool TEST_montgomery_add() +{ + ASSERT_EQ(11661950U, heracles::math::montgomeryAdd(177890559U, 470380160U, 536608769U)); + ASSERT_EQ(330474188U, heracles::math::montgomeryAdd(192697207U, 137776981U, 536608769U)); + ASSERT_EQ(111700460U, heracles::math::montgomeryAdd(72857859U, 38842601U, 536215553U)); + ASSERT_EQ(301757272U, heracles::math::montgomeryAdd(482904845U, 355067980U, 536215553U)); + ASSERT_EQ(149531932U, heracles::math::montgomeryAdd(83952415U, 65579517U, 1070727169U)); + ASSERT_EQ(176142121U, heracles::math::montgomeryAdd(441592427U, 805276863U, 1070727169U)); + + return true; +} + +bool TEST_montgomery_mul() +{ + ASSERT_EQ(514071123U, heracles::math::montgomeryMul(166645782U, 378454820U, 1070727169U)); + ASSERT_EQ(930227960U, heracles::math::montgomeryMul(45847266U, 378454820U, 1070727169U)); + ASSERT_EQ(313946907U, heracles::math::montgomeryMul(257508513U, 63724800U, 378470401U)); + ASSERT_EQ(256679068U, heracles::math::montgomeryMul(94982773U, 100100078U, 378470401U)); + ASSERT_EQ(183766988U, heracles::math::montgomeryMul(104720473U, 242438106U, 381616129U)); + ASSERT_EQ(149148360U, heracles::math::montgomeryMul(158503089U, 242438106U, 381616129U)); + + return true; +} diff --git a/p-isa_tools/data_formats/test/math_unittest/unittest.h b/p-isa_tools/data_formats/test/math_unittest/unittest.h new file mode 100644 index 00000000..9dc083f9 --- /dev/null +++ b/p-isa_tools/data_formats/test/math_unittest/unittest.h @@ -0,0 +1,79 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#ifndef ASSERT_EQ +#define ASSERT_EQ(a, b) \ + do \ + { \ + if ((a) != (b)) /* NOLINT(readability/braces) */ \ + { \ + fprintf(stderr, "Test fail: %s - %s:%d", __FUNCTION__, __FILE__, __LINE__); \ + abort(); \ + } \ + } while (false) +#endif // ASSERT_EQ +#ifndef ASSERT_NE +#define ASSERT_NE(a, b) \ + do \ + { \ + if ((a) == (b)) /* NOLINT(readability/braces) */ \ + { \ + fprintf(stderr, "Test fail: %s - %s:%d", __FUNCTION__, __FILE__, __LINE__); \ + abort(); \ + } \ + } while (false) +#endif // ASSERT_NE + +#ifndef ASSERT_FALSE +#define ASSERT_FALSE(a) \ + do \ + { \ + if (a) /* NOLINT(readability/braces) */ \ + { \ + fprintf(stderr, "Test fail: %s - %s:%d", __FUNCTION__, __FILE__, __LINE__); \ + abort(); \ + } \ + } while (false) +#endif // ASSERT_FALSE + +#ifndef ASSERT_TRUE +#define ASSERT_TRUE(a) \ + do \ + { \ + if (!(a)) /* NOLINT(readability/braces) */ \ + { \ + fprintf(stderr, "Test fail: %s - %s:%d", __FUNCTION__, __FILE__, __LINE__); \ + abort(); \ + } \ + } while (false) +#endif // ASSERT_TRUE + +bool TEST_add_uint32_mod(); +bool TEST_multiply_uint_mod(); +bool TEST_exponentiate_uint32_mod(); +bool TEST_negate_uint_mod(); +bool TEST_try_invert_uint_mod32(); +bool TEST_get_msb_index(); +bool TEST_get_significant_bit_count(); +bool TEST_get_significant_bit_count_uint(); +bool TEST_divide_uint96_inplace(); +bool TEST_left_shift_uint96(); +bool TEST_right_shift_uint96(); +bool TEST_add_uint32_base(); +bool TEST_sub_uint32_base(); +bool TEST_xgcd32(); +bool TEST_reverse_bits(); + +bool TEST_add_uint64_mod(); +bool TEST_exponentiate_uint64_mod(); +bool TEST_divide_uint192_inplace(); +bool TEST_left_shift_uint192(); +bool TEST_right_shift_uint192(); +bool TEST_add_uint64_base(); +bool TEST_sub_uint64_base(); +bool TEST_montgomery_add(); +bool TEST_montgomery_mul(); +#endif // TRACING_DATA_FORMATS_SRC_DATA_FORMATS_TEST_MATH_UNITTEST_UNITTEST_H_ diff --git a/p-isa_tools/data_formats/tracers/openfhe/CMakeLists.txt b/p-isa_tools/data_formats/tracers/openfhe/CMakeLists.txt new file mode 100644 index 00000000..03b8fe24 --- /dev/null +++ b/p-isa_tools/data_formats/tracers/openfhe/CMakeLists.txt @@ -0,0 +1,68 @@ +# Include FetchContent module +include(FetchContent) + +# Add OpenFHE via FetchContent +FetchContent_Declare( + OpenFHE + GIT_REPOSITORY https://github.com/AlexanderViand/openfhe-development.git + GIT_TAG 7e2eafc49b6b4c1a0000b1722a8749797ec277d9 # head of `tracing` branch (2025-08-07) +) + +# Set OpenFHE build options before making it available +set(BUILD_UNITTESTS OFF CACHE BOOL "" FORCE) +set(BUILD_EXAMPLES OFF CACHE BOOL "" FORCE) +set(BUILD_BENCHMARKS OFF CACHE BOOL "" FORCE) +set(ENABLE_TRACER ON CACHE BOOL "" FORCE) + +message(STATUS "Fetching OpenFHE, this may take a while...") +FetchContent_MakeAvailable(OpenFHE) +message(STATUS "Finished configuring OpenFHE") + +FetchContent_GetProperties(openfhe) + +# Create executable from tracing_example.cpp +add_executable(tracing_example tracing_example.cpp tracer.h) +# Set C++ standard +target_compile_features(tracing_example PRIVATE cxx_std_17) +# Link with OpenFHE libraries +target_link_libraries(tracing_example PRIVATE HERACLES_DATA_FORMATS::heracles_data_formats OPENFHEcore OPENFHEpke OPENFHEbinfhe) + +target_include_directories(tracing_example PRIVATE + # Third Party Includes + $ + $ + $ + $ + # public headers that sit in the repo + $ + $ + $ + # generated header (configure_file → config_core.h) + $) + + +# Set compiler flags for optimization and debug info +target_compile_options(tracing_example PRIVATE + $<$:-O3> + $<$:-g -O0> +) + + +# define a custom target that runs tracing, then submits the trace to the program mapper, finally sending the pisa and mem file to the functional modeler +add_custom_target( + run_tracing_example + # Run the actual example, which will generate the traces as end-to-end-test/tracing_example.bin and end-to-end-test/tracing_example_data.bin + COMMAND tracing_example + # Run the program mapper on the instruction trace, generating end-to-end-test/tracing_example.bin.csv + COMMAND env VIRTUAL_ENV=${VENV_PATH} PATH=${VENV_PATH}/bin:$ENV{PATH} PYTHONPATH=${VENV_SITE_PACKAGES}:$ENV{PYTHONPATH} $ ${CMAKE_BINARY_DIR}/end-to-end-test/tracing_example.bin ${CMAKE_SOURCE_DIR}/kerngen/kerngen.py + COMMAND $ ${CMAKE_BINARY_DIR}/end-to-end-test/tracing_example_pisa.csv --verbose --hec_dataformats_mode --hec_dataformats_poly_program_location ${CMAKE_BINARY_DIR}/end-to-end-test/tracing_example.bin --hec_dataformats_data ${CMAKE_BINARY_DIR}/end-to-end-test/tracing_example_data.bin + # TODO: Next step: assemble the *.pisa.csv and the *tw.mem using the assembler from https://github.com/IntelLabs/HERACLES-HGCF + # TODO: Then it's a dead end :( as we don't have any tooling that can support non-toy sized workloads + DEPENDS tracing_example program_mapper functional_modeler create-end-to-end-test-dir + WORKING_DIRECTORY ${CMAKE_BINARY_DIR}/end-to-end-test +) + +add_custom_target(create-end-to-end-test-dir + COMMAND ${CMAKE_COMMAND} -E remove_directory ${CMAKE_BINARY_DIR}/end-to-end-test + COMMAND ${CMAKE_COMMAND} -E make_directory ${CMAKE_BINARY_DIR}/end-to-end-test +) diff --git a/p-isa_tools/data_formats/tracers/openfhe/README.md b/p-isa_tools/data_formats/tracers/openfhe/README.md new file mode 100644 index 00000000..4dfdc180 --- /dev/null +++ b/p-isa_tools/data_formats/tracers/openfhe/README.md @@ -0,0 +1,88 @@ +# OpenFHE Tracer + +This directory contains the HERACLES tracer implementation for OpenFHE, which enables extraction of FHE computation traces for use with the p-ISA tools. + +## Overview + +The OpenFHE tracer captures homomorphic encryption operations performed by OpenFHE and generates trace files that can be processed by the p-ISA toolchain. This enables: + +- Extraction of FHE instruction sequences +- Generation of data traces for polynomial operations +- End-to-end compilation from FHE programs to hardware accelerator instructions + +## Files + +- `tracer.h` - Main tracer implementation that hooks into OpenFHE's tracing infrastructure +- `tracing_example.cpp` - Example program demonstrating tracer usage with basic FHE operations +- `CMakeLists.txt` - Build configuration that fetches OpenFHE with tracing support enabled + +## Building + +From the repository root: + +```bash +# Configure the project +mkdir -p build +cmake -B build -S p-isa_tools + +# Build the tracing example +cmake --build build --target tracing_example +``` + +## Running the Example + +To run the complete end-to-end tracing pipeline: + +```bash +cmake --build build --target run_tracing_example +``` + +This will: +1. Execute the tracing example, generating trace files in `build/end-to-end-test/`: + - `tracing_example.bin` - FHE instruction trace + - `tracing_example_data.bin` - Data trace with polynomial values +2. Run the program mapper on the instruction trace to generate `tracing_example.bin.csv` + (this internally uses the kernel generator) +3. Run the functional modeler to process the traces (see note below) + +### Output Files + +After running, you'll find the following in `build/end-to-end-test/`: +- `tracing_example.bin` - Binary FHE instruction trace +- `tracing_example_data.bin` - Binary data trace +- `tracing_example.bin.csv` - Mapped instruction sequence +- `tracing_example_pisa.csv` - p-ISA instructions (if functional modeler succeeds) + +## Known Issues + +> **Note:** The functional modeler may fail with an error about missing `partQHatInvModq`. This is currently expected as recent versions of OpenFHE no longer use this parameter internally, but the p-isa_tools still expect it. + +## Usage in Your Own Code + +To use the tracer in your OpenFHE application: + +```cpp +#include "tracer.h" + +// After creating your CryptoContext +auto cc = GenCryptoContext(parameters); + +// Create and attach the tracer +IF_TRACE(auto tracer = std::make_shared>("output_name", cc)); +IF_TRACE(cc->setTracer(tracer)); + +// Your FHE operations will now be traced +auto result = cc->EvalAdd(cipher1, cipher2); + +// Save trace files +IF_TRACE(tracer->saveBinaryTrace()); +IF_TRACE(tracer->saveJsonTrace()); // for debugging/manual inspection +``` + +## Dependencies + +Note: these are automatically fetched/created by CMake when using `run_tracing_example` + +- OpenFHE +- HERACLES_DATA_FORMATS library (built as part of p-isa_tools) +- Python environment with kerngen for program mapping diff --git a/p-isa_tools/data_formats/tracers/openfhe/tracer.h b/p-isa_tools/data_formats/tracers/openfhe/tracer.h new file mode 100644 index 00000000..8bc553d7 --- /dev/null +++ b/p-isa_tools/data_formats/tracers/openfhe/tracer.h @@ -0,0 +1,772 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#ifndef P_ISA_TOOLS_DATA_FORMATS_TRACERS_OPENFHE_TRACER_H_ +#define P_ISA_TOOLS_DATA_FORMATS_TRACERS_OPENFHE_TRACER_H_ + +// Defines ENABLE_TRACER (via config_core.h) so needs to be outside the #ifdef ENABLE_TRACER_SUPPORT +#include "utils/tracing.h" + +#ifdef ENABLE_TRACER + +#ifdef WITH_OPENMP +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ciphertext-ser.h" +#include "cryptocontext-ser.h" +#include "key/key-ser.h" +#include "plaintext-ser.h" +#include "scheme/bfvrns/bfvrns-ser.h" +#include "scheme/bgvrns/bgvrns-ser.h" +#include "scheme/ckksrns/ckksrns-ser.h" +#include "utils/hashutil.h" + +#include +#include +#include + +namespace lbcrypto { + +// Note: while this follows the standard template conventions of OpenFHE, it's really only designed to work for DCRTPoly... +template +class HeraclesTracer; + +template +class HeraclesFunctionTracer : public FunctionTracer +{ +public: + HeraclesFunctionTracer(const std::string &func, HeraclesTracer *tracer) : + m_tracer(tracer) + { + m_currentInstruction = heracles::fhe_trace::Instruction(); + + // TODO: we should differentiate between high-level ops and low-level ops + // and use eval_op name for higher level ops that created several lower-level ops + // but that also requires adding a bit of scoping logic in HeraclesTracer + m_currentInstruction.set_evalop_name(func); // Store the original function name + + m_currentInstruction.set_op(m_tracer->getHeraclesInstruction(func)); + + auto cc = m_tracer->getCryptoContext(); + if (cc->getSchemeId() != CKKSRNS_SCHEME) + { + // FIXME: set this based on the plaintext algebra being used + m_currentInstruction.set_plaintext_index(0); + } + } + + ~HeraclesFunctionTracer() override + { + // Transfer collected operands and parameters to the instruction + for (const auto &source : m_sources) + { + m_currentInstruction.mutable_args()->add_srcs()->CopyFrom(source); + } + + for (const auto &dest : m_destinations) + { + m_currentInstruction.mutable_args()->add_dests()->CopyFrom(dest); + } + + // Transfer parameters using the stored names + for (size_t i = 0; i < m_parameters.size() && i < m_parameterNames.size(); ++i) + { + (*m_currentInstruction.mutable_args()->mutable_params())[m_parameterNames[i]].CopyFrom(m_parameters[i]); + } + + // Finalize the instruction and add it to the tracer + m_tracer->addInstruction(m_currentInstruction); + } + + // Input registration methods + + /// Register data with in/out flag to avoid duplication of conversion logic + void registerData(std::vector elements, std::string name, bool isOutput = false) + { + if (elements.empty()) + throw std::runtime_error("Cannot register empty data."); + + // Use the semantic name (ct, pt, sk, pk, etc.) instead of always "dcrtpoly" + std::string id = m_tracer->getUniqueObjectId(elements, name); + + // Create OperandObject (name, num_rns, order) + auto operand = heracles::fhe_trace::OperandObject(); + operand.set_symbol_name(id); + operand.set_num_rns(elements[0].GetNumOfElements()); + operand.set_order(elements.size()); + + // Add to appropriate member variable for later processing in destructor + if (isOutput) + { + m_destinations.push_back(operand); + m_tracer->trackOutput(id); + } + else + { + m_sources.push_back(operand); + // Check for orphaned inputs (objects that weren't registered as outputs) + m_tracer->checkInput(id, m_currentInstruction.op()); + } + + // Add to TestVector: Convert DCRTPoly to protobuf format + auto data = heracles::data::Data(); + auto dcrtpoly = data.mutable_dcrtpoly(); + + // Set whether the polynomial is in NTT form + dcrtpoly->set_in_ntt_form(elements[0].GetFormat() == Format::EVALUATION); + + for (const auto &element : elements) + { + auto poly = dcrtpoly->add_polys(); + convertDCRTPolyToProtobuf(poly, element); + } + + m_tracer->storeData(id, data); + } + + // Helper for single elements + void registerData(Element element, std::string name, bool isOutput = false) + { + registerData(std::vector(1, element), name, isOutput); + } + + void registerInput(Ciphertext ciphertext, std::string name, bool isMutable) override + { + registerData(ciphertext->GetElements(), name.empty() ? "ciphertext" : name); + } + + void registerInput(ConstCiphertext ciphertext, std::string name, bool isMutable) override + { + registerData(ciphertext->GetElements(), name.empty() ? "ciphertext" : name); + } + + void registerInput(Plaintext plaintext, std::string name, bool isMutable) override + { + registerData(plaintext->GetElement(), name.empty() ? "plaintext" : name); + } + + void registerInput(ConstPlaintext plaintext, std::string name, bool isMutable) override + { + registerData(plaintext->GetElement(), name.empty() ? "plaintext" : name); + } + + void registerInput(const PublicKey publicKey, std::string name, bool isMutable) override + { + registerData(publicKey->GetPublicElements(), name.empty() ? "publickey" : name, false); + } + + void registerInput(const PrivateKey privateKey, std::string name, bool isMutable) override + { + registerData(privateKey->GetPrivateElement(), name.empty() ? "secretkey" : name, false); + } + + void registerInput(const EvalKey evalKey, std::string name, bool isMutable) override + { + name = name.empty() ? "evalkey" : name; + // EvalKey doesn't have GetElement method, just skip for now + // FIXME: implement proper EvalKey extraction + std::cout << "Warning: EvalKey registration not yet implemented in HeraclesTracer." << std::endl; + (void)evalKey; // Suppress unused parameter warning + } + + void registerInput(const PlaintextEncodings encoding, std::string name, bool isMutable) override + { + std::string encodingStr; + switch (encoding) + { + case PlaintextEncodings::COEF_PACKED_ENCODING: + encodingStr = "COEF_PACKED_ENCODING"; + break; + case PlaintextEncodings::PACKED_ENCODING: + encodingStr = "PACKED_ENCODING"; + break; + case PlaintextEncodings::STRING_ENCODING: + encodingStr = "STRING_ENCODING"; + break; + case PlaintextEncodings::CKKS_PACKED_ENCODING: + encodingStr = "CKKS_PACKED_ENCODING"; + break; + default: + encodingStr = "UNKNOWN_ENCODING"; + break; + } + addParameter(name.empty() ? "encoding" : name, encodingStr, "string"); + } + + void registerInput(const std::vector &values, std::string name, bool isMutable) override + { + addParameter(name.empty() ? "int64_vector" : name, values.size(), "uint64"); + } + + void registerInput(const std::vector &values, std::string name, bool isMutable) override + { + addParameter(name.empty() ? "int32_vector" : name, values.size(), "uint32"); + } + + void registerInput(const std::vector &values, std::string name, bool isMutable) override + { + addParameter(name.empty() ? "uint32_vector" : name, values.size(), "uint32"); + } + + void registerInput(const std::vector &values, std::string name, bool isMutable) override + { + addParameter(name.empty() ? "double_vector" : name, values.size(), "uint64"); + } + + void registerInput(double value, std::string name, bool isMutable) override + { + addParameter(name.empty() ? "double" : name, value, "double"); + } + + void registerInput(std::complex value, std::string name, bool isMutable) override + { + addParameter(name.empty() ? "complex_real" : name + "_real", value.real(), "double"); + addParameter(name.empty() ? "complex_imag" : name + "_imag", value.imag(), "double"); + } + + void registerInput(const std::vector> &values, std::string name, bool isMutable) override + { + addParameter(name.empty() ? "complex_vector" : name, values.size(), "uint64"); + } + + void registerInput(int64_t value, std::string name, bool isMutable) override + { + addParameter(name.empty() ? "int64" : name, value, "int64"); + } + + void registerInput(size_t value, std::string name, bool isMutable) override + { + addParameter(name.empty() ? "size_t" : name, value, "uint64"); + } + + void registerInput(bool value, std::string name, bool isMutable) override + { + addParameter(name.empty() ? "bool" : name, value ? "true" : "false", "string"); + } + + void registerInput(const std::string &value, std::string name, bool isMutable) override + { + addParameter(name.empty() ? "string" : name, value, "string"); + } + + void registerInput(const std::shared_ptr>> &evalKeyMap, std::string name, + bool isMutable) override + { + size_t mapSize = evalKeyMap ? evalKeyMap->size() : 0; + addParameter(name.empty() ? "eval_key_map_size" : name + "_size", mapSize, "uint64"); + } + + void registerInput(void *ptr, std::string name, bool isMutable) override + { + throw std::runtime_error("HERACLES tracing does not support registering non-typed inputs."); + } + + // Output registration methods + Ciphertext registerOutput(Ciphertext ciphertext, std::string name) override + { + if (ciphertext && ciphertext->GetElements().size() > 0) + { + registerData(ciphertext->GetElements(), name.empty() ? "ciphertext" : name, true); // true = output + } + return ciphertext; + } + + ConstCiphertext registerOutput(ConstCiphertext ciphertext, std::string name) override + { + if (ciphertext && ciphertext->GetElements().size() > 0) + { + registerData(ciphertext->GetElements(), name.empty() ? "ciphertext" : name, true); // true = output + } + return ciphertext; + } + + Plaintext registerOutput(Plaintext plaintext, std::string name) override + { + if (plaintext) + { + // Convert single element to vector for registerData + std::vector elements = { plaintext->GetElement() }; + registerData(elements, name.empty() ? "plaintext" : name, true); // true = output + } + return plaintext; + } + + KeyPair registerOutput(KeyPair keyPair, std::string name) override + { + if (keyPair.publicKey) + registerData(keyPair.publicKey->GetPublicElements(), "publickey", true); + if (keyPair.secretKey) + registerData(keyPair.secretKey->GetPrivateElement(), "secretkey", true); + return keyPair; + } + + EvalKey registerOutput(EvalKey evalKey, std::string name) override + { + if (evalKey) + { + // Convert evaluation key elements to vector for registerData + std::vector elements = evalKey->GetBVector(); // Get the B vector + registerData(elements, name.empty() ? "evalkey" : name, true); // true = output + } + return evalKey; + } + + std::vector> registerOutput(std::vector> evalKeys, std::string name) override + { + // TODO: registerData this, too + return evalKeys; + } + + std::vector> registerOutput(std::vector> ciphertexts, + std::string name) override + { + for (auto &ct : ciphertexts) + { + registerOutput(ct, name); + } + return ciphertexts; + } + + std::shared_ptr>> registerOutput( + std::shared_ptr>> evalKeyMap, std::string name) override + { + // TODO: registerData this, too + return evalKeyMap; + } + + PublicKey registerOutput(PublicKey publicKey, std::string name) override + { + // TODO: registerData this, too + return publicKey; + } + + PrivateKey registerOutput(PrivateKey privateKey, std::string name) override + { + // TODO: registerData this, too + return privateKey; + } + + std::string registerOutput(const std::string &value, std::string name) override + { + // TODO: registerData this, too + return value; + } + + Element registerOutput(Element element, std::string name) override + { + // TODO: registerData this, too + return element; + } + +private: + HeraclesTracer *m_tracer; + heracles::fhe_trace::Instruction m_currentInstruction; + + // We record the args and params in case some ops require reordering them + std::vector m_sources; + std::vector m_destinations; + std::vector m_parameters; + std::vector m_parameterNames; // Parameter names corresponding to m_parameters + + /// Helper to extract SSA ID from objects for HERACLES tracing (delegated to HeraclesTracer) + template + std::string getObjectId(T obj, const std::string &type) + { + return m_tracer->getUniqueObjectId(obj, type); + } + + /// Helper to create HERACLES OperandObject for ciphertexts/plaintexts + void setHERACLESOperandObject(heracles::fhe_trace::OperandObject *opObj, const std::string &objectId, + size_t numRNS = 0, size_t order = 1) + { + opObj->set_symbol_name(objectId); + opObj->set_num_rns(numRNS); + opObj->set_order(order); + } + + /// Helper to add parameter to HERACLES instruction + template + void addParameter(const std::string &name, const T &value, const std::string &type) + { + heracles::fhe_trace::Parameter param; + std::stringstream ss; + ss << value; + param.set_value(ss.str()); + + // Set parameter type based on type string + std::string upperType = type; + std::transform(upperType.begin(), upperType.end(), upperType.begin(), ::toupper); + + if (upperType == "DOUBLE") + { + param.set_type(heracles::fhe_trace::ValueType::DOUBLE); + } + else if (upperType == "FLOAT") + { + param.set_type(heracles::fhe_trace::ValueType::FLOAT); + } + else if (upperType == "INT32") + { + param.set_type(heracles::fhe_trace::ValueType::INT32); + } + else if (upperType == "INT64") + { + param.set_type(heracles::fhe_trace::ValueType::INT64); + } + else if (upperType == "UINT32") + { + param.set_type(heracles::fhe_trace::ValueType::UINT32); + } + else if (upperType == "UINT64") + { + param.set_type(heracles::fhe_trace::ValueType::UINT64); + } + else + { + param.set_type(heracles::fhe_trace::ValueType::STRING); + } + + // Store in member variables with name for later processing + m_parameterNames.push_back(name); + m_parameters.push_back(param); + } + + /// Helper to convert DCRTPoly to HERACLES protobuf format + void convertDCRTPolyToProtobuf(heracles::data::Polynomial *proto_poly, const Element &dcrtpoly) + { + const auto &elems = dcrtpoly.GetAllElements(); + proto_poly->set_in_openfhe_evaluation((dcrtpoly.GetFormat() == Format::EVALUATION)); + + for (size_t l = 0; l < dcrtpoly.GetNumOfElements(); ++l) + { + size_t poly_degree = elems[l].GetLength(); + auto elem_vals = elems[l].GetValues(); + auto rns_poly_pb = proto_poly->add_rns_polys(); + + std::vector v_coeffs(poly_degree); + for (size_t j = 0; j < poly_degree; ++j) + { + v_coeffs[j] = elem_vals[j].ConvertToInt(); + } + + *rns_poly_pb->mutable_coeffs() = { v_coeffs.begin(), v_coeffs.end() }; + rns_poly_pb->set_modulus(elems[l].GetModulus().ConvertToInt()); + } + } +}; + +/// HERACLES Protobuf Tracing implementation +/// Generates protobuf traces compatible with the HERACLES project +template +class HeraclesTracer : public Tracer +{ +public: + HeraclesTracer(const std::string &filename = "openfhe-heracles-trace", const CryptoContext &cc = nullptr, bool warnOnUnregisteredInputs = true) : + m_filename(filename), m_context(cc), m_warnOnUnregisteredInputs(warnOnUnregisteredInputs) + { + if (!cc) + { + throw std::runtime_error("HeraclesTracer requires a valid CryptoContext - cannot be null"); + } + _initializeTrace(); + } + + ~HeraclesTracer() override = default; + + // Override the virtual createFunctionTracer method (required by new API) + std::unique_ptr> createFunctionTracer(std::string func) override + { + // Check if this the func matches a no_emit_prefix + // If yes, return a null tracer that does not do anything + for (const auto &prefix : no_emit_prefixes) + if (func.find(prefix) == 0) + return std::make_unique>(); + + // Otherwise, create a real tracer that will emit instructions + return std::make_unique>(func, this); + } + + CryptoContext getCryptoContext() + { + return m_context; + } + + /// Generate unique object ID using SimpleTracer-style logic + template + std::string getUniqueObjectId(T obj, const std::string &type) + { + // Serialize and hash the object for uniqueness detection + std::stringstream serialStream; + Serial::Serialize(obj, serialStream, SerType::BINARY); + const std::string hash = HashUtil::HashString(serialStream.str()); + + // Check if we already have a unique ID for this hash + auto hashIt = m_uniqueID.find(hash); + if (hashIt != m_uniqueID.end()) + { + // Object already seen - reuse existing ID + return hashIt->second; + } + + // Generate new ID using counter + size_t &counter = m_counters[type]; + std::string id = type + "_" + std::to_string(++counter); + m_uniqueID[hash] = id; + return id; + } + + void addInstruction(const heracles::fhe_trace::Instruction &instruction) + { + std::lock_guard lock(m_mutex); + m_FHETrace->add_instructions()->CopyFrom(instruction); + } + + /// Track an object ID as a known output + void trackOutput(const std::string &objectId) + { + std::lock_guard lock(m_mutex); + m_knownOutputs.insert(objectId); + } + + /// Check if an input object ID was previously registered as an output + /// Prints a warning if the object appears to be "orphaned" (not from any traced output) + void checkInput(const std::string &objectId, const std::string &operationName) + { + std::lock_guard lock(m_mutex); + if (m_warnOnUnregisteredInputs && m_knownOutputs.find(objectId) == m_knownOutputs.end()) + { + std::cout << "WARNING: Object '" << objectId << "' used as input in operation '" << operationName + << "' but was never registered as output of any traced operation." << std::endl; + std::cout << "This is normal if only tracing server-side code (and indicates this is a client input)," + << " but may indicate missing internal tracing logic if tracing client and server side code." << std::endl; + } + } + + /// Store data for test vector + void storeData(const std::string &objectId, const heracles::data::Data &data) + { + std::lock_guard lock(m_mutex); + (*m_TestVector->mutable_sym_data_map())[objectId] = data; + } + + /// Save trace to file in binary format + void saveBinaryTrace() + { + std::lock_guard lock(m_mutex); + + heracles::fhe_trace::store_trace(m_filename + ".bin", *m_FHETrace); + + // Create manifest for the binary files + heracles::data::hdf_manifest manifest; + + // Store context and test vector with manifest + heracles::data::store_hec_context(&manifest, m_filename + "_context.bin", *m_FHEContext); + heracles::data::store_testvector(&manifest, m_filename + "_testvector.bin", *m_TestVector); + + // Store the combined data trace + heracles::data::store_data_trace(m_filename + "_data.bin", *m_FHEContext, *m_TestVector); + + // Generate the manifest file + heracles::data::generate_manifest(m_filename + "_manifest.txt", manifest); + } + + /// Save trace to file in JSON format + void saveJsonTrace() + { + std::lock_guard lock(m_mutex); + heracles::fhe_trace::store_json_trace(m_filename + ".json", *m_FHETrace); + heracles::data::store_hec_context_json(m_filename + "_context.json", *m_FHEContext); + heracles::data::store_testvector_json(m_filename + "_testvector.json", *m_TestVector); + // Note: the combined data trace object is not available in *.json + } + + std::string getHeraclesInstruction(std::string functionName) const + { + // No mutex lock, since we're just reading a const member + // Check the map, if not in there, return the functionName + // Note: this is using prefix matching! + for (const auto &[key, value] : op_name_map) + if (functionName.find(key) == 0) + return value; + return functionName; + } + +private: + /// Guards access to member variables for accesses by FunctionTracer(s) + mutable std::mutex m_mutex; + + // ID management (accessible by HeraclesFunctionTracer for naming logic) + std::unordered_map m_uniqueID; // hash -> human-readable ID + std::unordered_map m_counters; // type -> counter + + // Track known output IDs to detect missing tracing calls + std::unordered_set m_knownOutputs; // object IDs that have been registered as outputs + + std::string m_filename; // Filename basis to use. Will be extended with _data and *.bin/*.json + CryptoContext m_context; // CryptoContext for the current trace + bool m_warnOnUnregisteredInputs; // Whether to warn on unregistered inputs + + // Generated traces (nullptr until tracing is finished) + std::unique_ptr m_FHETrace = nullptr; + std::unique_ptr m_FHEContext = nullptr; + std::unique_ptr m_TestVector = nullptr; + + /// Instructions to skip emission for (but still trace nested instructions) + /// WARNING: the match is on the PREFIX of the instruction name, + /// so LeveledSHERNS::AdjustForMultInPlace will match + /// LeveledSHERNS::AdjustForMultInPlace(ciphertext1, ciphertext2) + /// but also LeveledSHERNS::AdjustForMultInPlace(ciphertext, plaintext) + /// Note: This means that InPlace versions are also matched, e.g., + /// LeveledSHERNS::EvalAdd will also match LeveledSHERNS::EvalAddInPlace! + const std::unordered_set no_emit_prefixes = { + // Ignore all CryptoContext high-level wrappers + "CryptoContext::", + // Automagic Adjustment Wrappers + "LeveledSHEBase::AdjustForMult", + "LeveledSHERNS::AdjustForMult", + "LeveledSHERNS::AdjustForAddOrSub", + "LeveledSHECKKSRNS::AdjustLevelsAndDepth", // also covers "..ToOne" version + // Multiplication Wrappers + "LeveledSHEBase::EvalMult", + "LeveledSHERNS::EvalMult", + "LeveledSHECKKSRNS::EvalMult(", // We do want LeveledSHECKKSRNS::EvalMultCore + "LeveledSHECKKSRNS::EvalMultInPlace(", // so we can't just match on EvalMult! + // Addition/Subtraction Wrappers + "LeveledSHERNS::EvalAdd(", // Again, we want the ::...Core version + "LeveledSHERNS::EvalAddInPlace(", + "LeveledSHERNS::EvalSub(", // Again, we want the ::...Core version + "LeveledSHERNS::EvalSubInPlace(", + }; + + /// Mapping from OpenFHE function name (prefix) to HERACLES instruction name + /// WARNING: this is also prefix match, so it will match the beginning of the function name + const std::unordered_map op_name_map = { + // Addition + { "LeveledSHEBase::EvalAddCore(Ciphertext,Ciphertext)", "add" }, + { "LeveledSHEBase::EvalAddCoreInPlace(Ciphertext,Ciphertext)", "add" }, + { "LeveledSHEBase::EvalAddCore(Ciphertext,Plaintext)", "add" }, + { "LeveledSHEBase::EvalAddCoreInPlace(Ciphertext,Plaintext)", "add" }, + // Subtraction + { "LeveledSHEBase::EvalSubCore(Ciphertext,Ciphertext)", "sub" }, + { "LeveledSHEBase::EvalSubCoreInPlace(Ciphertext,Ciphertext)", "sub" }, + { "LeveledSHEBase::EvalSubCore(Ciphertext,Plaintext)", "sub" }, + { "LeveledSHEBase::EvalSubCoreInPlace(Ciphertext,Plaintext)", "sub" }, + // Multiplication (scheme-specific) + { "LeveledSHECKKSRNS::EvalMultCore(Ciphertext,Ciphertext)", "mul" }, + { "LeveledSHECKKSRNS::EvalMultCoreInPlace(ciphertext, ciphertext)", "mul" }, + { "LeveledSHECKKSRNS::EvalMultCore(Ciphertext,Plaintext)", "mul" }, + { "LeveledSHECKKSRNS::EvalMultCoreInPlace(Ciphertext,Plaintext)", "mul" }, + { "LeveledSHECKKSRNS::EvalMultCore(Ciphertext,double)", "muli" }, + { "LeveledSHECKKSRNS::EvalMultCoreInPlace(Ciphertext,double)", "muli" }, + // Also map the high-level wrappers in case they slip through + { "LeveledSHECKKSRNS::EvalMult(Ciphertext,double)", "muli" }, + { "LeveledSHECKKSRNS::EvalMultInPlace(Ciphertext,double)", "muli" }, + // Modulus Reduction / Rescale + { "LeveledSHECKKSRNS::ModReduceInternal", "rescale" }, + // Rotation + { "LeveledSHEBase::EvalAutomorphism", "rotate" } + + }; + + void _initializeTrace() + { + m_FHETrace = std::make_unique(); + m_TestVector = std::make_unique(); + m_FHEContext = std::make_unique(); + _initializeContext(); + + m_FHETrace->set_scheme(m_FHEContext->scheme()); + m_FHETrace->set_n(m_FHEContext->n()); + m_FHETrace->set_key_rns_num(m_FHEContext->key_rns_num()); + m_FHETrace->set_q_size(m_FHEContext->q_size()); + m_FHETrace->set_dnum(m_FHEContext->digit_size()); + m_FHETrace->set_alpha(m_FHEContext->alpha()); + } + + void + _initializeContext() + { + if (!m_context) + { + throw std::runtime_error("No CryptoContext provided for HERACLES tracing"); + } + + auto cc_rns = std::dynamic_pointer_cast(m_context->GetCryptoParameters()); + if (!cc_rns) + throw std::runtime_error("HERACLES requires RNS parameters."); + auto key_rns = cc_rns->GetParamsQP()->GetParams(); + + auto scheme = m_context->getSchemeId(); + switch (scheme) + { + case SCHEME::CKKSRNS_SCHEME: + { + m_FHEContext->set_scheme(heracles::common::SCHEME_CKKS); + // Add CKKS-specific information + // FIXME: set_has_ckks_info() is private, need to find correct way to set this + // m_FHEContext->set_has_ckks_info(); + auto ckks_info = m_FHEContext->mutable_ckks_info(); + size_t sizeQ = m_context->GetElementParams()->GetParams().size(); + for (size_t i = 0; i < sizeQ; ++i) + { + ckks_info->add_scaling_factor_real(cc_rns->GetScalingFactorReal(i)); + if (i < sizeQ - 1) + ckks_info->add_scaling_factor_real_big(cc_rns->GetScalingFactorRealBig(i)); + } + } + break; + case SCHEME::BGVRNS_SCHEME: + { + m_FHEContext->set_scheme(heracles::common::SCHEME_BGV); + // BGV not fully supported yet + } + break; + case SCHEME::BFVRNS_SCHEME: + { + m_FHEContext->set_scheme(heracles::common::SCHEME_BFV); + // BFV not fully supported yet + } + break; + default: + throw std::runtime_error("Unsupported scheme for HERACLES tracing"); + } + + // TODO: check in old tracing code what these should be set to! + auto poly_degree = m_context->GetRingDimension(); + m_FHEContext->set_n(poly_degree); + m_FHEContext->set_key_rns_num(key_rns.size()); + m_FHEContext->set_alpha(cc_rns->GetNumPerPartQ()); + m_FHEContext->set_digit_size(cc_rns->GetNumPartQ()); + for (const auto &parms : key_rns) + { + auto q_i = parms->GetModulus(); + m_FHEContext->add_q_i(q_i.ConvertToInt()); + + auto psi_i = RootOfUnity(poly_degree * 2, parms->GetModulus()); + m_FHEContext->add_psi(psi_i.ConvertToInt()); + } + m_FHEContext->set_q_size(m_context->GetElementParams()->GetParams().size()); + m_FHEContext->set_alpha(cc_rns->GetNumPerPartQ()); + } +}; + +} // namespace lbcrypto + +#endif // ENABLE_TRACER + +#endif // P_ISA_TOOLS_DATA_FORMATS_TRACERS_OPENFHE_TRACER_H_ diff --git a/p-isa_tools/data_formats/tracers/openfhe/tracing_example.cpp b/p-isa_tools/data_formats/tracers/openfhe/tracing_example.cpp new file mode 100644 index 00000000..215055f5 --- /dev/null +++ b/p-isa_tools/data_formats/tracers/openfhe/tracing_example.cpp @@ -0,0 +1,267 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +//================================================================================== +// BSD 2-Clause License +// +// Copyright (c) 2014-2022, NJIT, Duality Technologies Inc. and other contributors +// +// All rights reserved. +// +// Author TPOC: contact@openfhe.org +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +//================================================================================== + +/* + Simple examples for CKKS, based on OpenFHE's "src/pke/examples/simple-real-numbers.cpp" + */ + +#define PROFILE + +#include "openfhe.h" +#include "tracer.h" + +using namespace lbcrypto; + +int main() +{ + // Step 1: Setup CryptoContext + + // A. Specify main parameters + /* A1) Multiplicative depth: + * The CKKS scheme we setup here will work for any computation + * that has a multiplicative depth equal to 'multDepth'. + * This is the maximum possible depth of a given multiplication, + * but not the total number of multiplications supported by the + * scheme. + * + * For example, computation f(x, y) = x^2 + x*y + y^2 + x + y has + * a multiplicative depth of 1, but requires a total of 3 multiplications. + * On the other hand, computation g(x_i) = x1*x2*x3*x4 can be implemented + * either as a computation of multiplicative depth 3 as + * g(x_i) = ((x1*x2)*x3)*x4, or as a computation of multiplicative depth 2 + * as g(x_i) = (x1*x2)*(x3*x4). + * + * For performance reasons, it's generally preferable to perform operations + * in the shorted multiplicative depth possible. + */ + uint32_t multDepth = 1; + + /* A2) Bit-length of scaling factor. + * CKKS works for real numbers, but these numbers are encoded as integers. + * For instance, real number m=0.01 is encoded as m'=round(m*D), where D is + * a scheme parameter called scaling factor. Suppose D=1000, then m' is 10 (an + * integer). Say the result of a computation based on m' is 130, then at + * decryption, the scaling factor is removed so the user is presented with + * the real number result of 0.13. + * + * Parameter 'scaleModSize' determines the bit-length of the scaling + * factor D, but not the scaling factor itself. The latter is implementation + * specific, and it may also vary between ciphertexts in certain versions of + * CKKS (e.g., in FLEXIBLEAUTO). + * + * Choosing 'scaleModSize' depends on the desired accuracy of the + * computation, as well as the remaining parameters like multDepth or security + * standard. This is because the remaining parameters determine how much noise + * will be incurred during the computation (remember CKKS is an approximate + * scheme that incurs small amounts of noise with every operation). The + * scaling factor should be large enough to both accommodate this noise and + * support results that match the desired accuracy. + */ + uint32_t scaleModSize = 50; + + /* A3) Number of plaintext slots used in the ciphertext. + * CKKS packs multiple plaintext values in each ciphertext. + * The maximum number of slots depends on a security parameter called ring + * dimension. In this instance, we don't specify the ring dimension directly, + * but let the library choose it for us, based on the security level we + * choose, the multiplicative depth we want to support, and the scaling factor + * size. + * + * Please use method GetRingDimension() to find out the exact ring dimension + * being used for these parameters. Give ring dimension N, the maximum batch + * size is N/2, because of the way CKKS works. + */ + uint32_t batchSize = 8; + + /* A4) Desired security level based on FHE standards. + * This parameter can take four values. Three of the possible values + * correspond to 128-bit, 192-bit, and 256-bit security, and the fourth value + * corresponds to "NotSet", which means that the user is responsible for + * choosing security parameters. Naturally, "NotSet" should be used only in + * non-production environments, or by experts who understand the security + * implications of their choices. + * + * If a given security level is selected, the library will consult the current + * security parameter tables defined by the FHE standards consortium + * (https://homomorphicencryption.org/introduction/) to automatically + * select the security parameters. Please see "TABLES of RECOMMENDED + * PARAMETERS" in the following reference for more details: + * http://homomorphicencryption.org/wp-content/uploads/2018/11/HomomorphicEncryptionStandardv1.1.pdf + */ + CCParams parameters; + parameters.SetMultiplicativeDepth(multDepth); + parameters.SetScalingModSize(scaleModSize); + parameters.SetBatchSize(batchSize); + + CryptoContext cc = GenCryptoContext(parameters); + + // Enable the features that you wish to use + cc->Enable(PKE); + cc->Enable(KEYSWITCH); + cc->Enable(LEVELEDSHE); + std::cout << "CKKS scheme is using ring dimension " << cc->GetRingDimension() << std::endl + << std::endl; + + // B. Step 2: Key Generation + /* B1) Generate encryption keys. + * These are used for encryption/decryption, as well as in generating + * different kinds of keys. + */ + auto keys = cc->KeyGen(); + + /* B2) Generate the digit size + * In CKKS, whenever someone multiplies two ciphertexts encrypted with key s, + * we get a result with some components that are valid under key s, and + * with an additional component that's valid under key s^2. + * + * In most cases, we want to perform relinearization of the multiplication + * result, i.e., we want to transform the s^2 component of the ciphertext so + * it becomes valid under original key s. To do so, we need to create what we + * call a relinearization key with the following line. + */ + cc->EvalMultKeyGen(keys.secretKey); + + /* B3) Generate the rotation keys + * CKKS supports rotating the contents of a packed ciphertext, but to do so, + * we need to create what we call a rotation key. This is done with the + * following call, which takes as input a vector with indices that correspond + * to the rotation offset we want to support. Negative indices correspond to + * right shift and positive to left shift. Look at the output of this demo for + * an illustration of this. + * + * Keep in mind that rotations work over the batch size or entire ring dimension (if the batch size is not specified). + * This means that, if ring dimension is 8 and batch + * size is not specified, then an input (1,2,3,4,0,0,0,0) rotated by 2 will become + * (3,4,0,0,0,0,1,2) and not (3,4,1,2,0,0,0,0). + * If ring dimension is 8 and batch + * size is set to 4, then the rotation of (1,2,3,4) by 2 will become (3,4,1,2). + * Also, as someone can observe + * in the output of this demo, since CKKS is approximate, zeros are not exact + * - they're just very small numbers. + */ + cc->EvalRotateKeyGen(keys.secretKey, { 1, -2 }); + + // Step 3: Encoding and encryption of inputs + + // Inputs + std::vector x1 = { 0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0, 5.0 }; + std::vector x2 = { 5.0, 4.0, 3.0, 2.0, 1.0, 0.75, 0.5, 0.25 }; + + // Encoding as plaintexts + Plaintext ptxt1 = cc->MakeCKKSPackedPlaintext(x1); + Plaintext ptxt2 = cc->MakeCKKSPackedPlaintext(x2); + + std::cout << "Input x1: " << ptxt1 << std::endl; + std::cout << "Input x2: " << ptxt2 << std::endl; + + // Encrypt the encoded vectors + auto c1 = cc->Encrypt(keys.publicKey, ptxt1); + auto c2 = cc->Encrypt(keys.publicKey, ptxt2); + + // Step 4: Evaluation + // Note: Since we do not want to emit instructions for client-side operations like key generation, we only start the trace here + // and we also do not emit warnings for unregisteret inputs, as we're guarnateed to have those (the ptxts and ctxts created above). + IF_TRACE(auto tracer = std::make_shared>("tracing_example", cc, false)); + IF_TRACE(cc->setTracer(tracer)); + + // Homomorphic addition + auto cAdd = cc->EvalAdd(c1, c2); + + // Homomorphic subtraction + auto cSub = cc->EvalSub(c1, c2); + + // Homomorphic scalar multiplication + auto cScalar = cc->EvalMult(c1, 4.0); + + // Homomorphic multiplication + auto cMul = cc->EvalMult(c1, c2); + + // Homomorphic rotations + auto cRot1 = cc->EvalRotate(c1, 1); + auto cRot2 = cc->EvalRotate(c1, -2); + + // Note: since we don't want to emit instructions for client-side operations such as decryption/decoding, we save the trace here + IF_TRACE(tracer->saveBinaryTrace()); + IF_TRACE(tracer->saveJsonTrace()); // Only used for easy debugging/manual inspection of the trace w/o needing to use protoc + + // Step 5: Decryption and output + Plaintext result; + // We set the cout precision to 8 decimal digits for a nicer output. + // If you want to see the error/noise introduced by CKKS, bump it up + // to 15 and it should become visible. + std::cout.precision(8); + + std::cout << std::endl + << "Results of homomorphic computations: " << std::endl; + + cc->Decrypt(keys.secretKey, c1, &result); + result->SetLength(batchSize); + std::cout << "x1 = " << result; + std::cout << "Estimated precision in bits: " << result->GetLogPrecision() << std::endl; + + // Decrypt the result of addition + cc->Decrypt(keys.secretKey, cAdd, &result); + result->SetLength(batchSize); + std::cout << "x1 + x2 = " << result; + std::cout << "Estimated precision in bits: " << result->GetLogPrecision() << std::endl; + + // Decrypt the result of subtraction + cc->Decrypt(keys.secretKey, cSub, &result); + result->SetLength(batchSize); + std::cout << "x1 - x2 = " << result << std::endl; + + // Decrypt the result of scalar multiplication + cc->Decrypt(keys.secretKey, cScalar, &result); + result->SetLength(batchSize); + std::cout << "4 * x1 = " << result << std::endl; + + // Decrypt the result of multiplication + cc->Decrypt(keys.secretKey, cMul, &result); + result->SetLength(batchSize); + std::cout << "x1 * x2 = " << result << std::endl; + + // Decrypt the result of rotations + + cc->Decrypt(keys.secretKey, cRot1, &result); + result->SetLength(batchSize); + std::cout << std::endl + << "In rotations, very small outputs (~10^-10 here) correspond to 0's:" << std::endl; + std::cout << "x1 rotate by 1 = " << result << std::endl; + + cc->Decrypt(keys.secretKey, cRot2, &result); + result->SetLength(batchSize); + std::cout << "x1 rotate by -2 = " << result << std::endl; + + return 0; +} diff --git a/p-isa_tools/functional_modeler/data_handlers/json_data_handler.h b/p-isa_tools/functional_modeler/data_handlers/json_data_handler.h index 1652521c..134c910c 100644 --- a/p-isa_tools/functional_modeler/data_handlers/json_data_handler.h +++ b/p-isa_tools/functional_modeler/data_handlers/json_data_handler.h @@ -297,7 +297,7 @@ std::vector>> JSONDataHandler::getAllim } auto metadata = m_input_json["metadata"]; auto inputs = metadata.find("immediate"); - if (inputs != metadata.end() && !inputs->empty()) + if (!inputs.operator==(metadata.end()) && !inputs->empty()) { for (const auto &input : inputs->items()) { diff --git a/p-isa_tools/kerngen/CMakeLists.txt b/p-isa_tools/kerngen/CMakeLists.txt new file mode 100644 index 00000000..b0f0cb1d --- /dev/null +++ b/p-isa_tools/kerngen/CMakeLists.txt @@ -0,0 +1,37 @@ +message( STATUS "Installing kerngen Python requirements") +# First, ensure that we have python3: +find_package (Python3 COMPONENTS Interpreter) + +# Now create a venv +set(VENV_PATH "${CMAKE_BINARY_DIR}/venv") +if (NOT EXISTS "${VENV_PATH}/bin/python") + execute_process(COMMAND "${Python3_EXECUTABLE}" -m venv "${VENV_PATH}" + RESULT_VARIABLE VENV_CREATION_FAILED) + if (VENV_CREATION_FAILED) + message(FATAL_ERROR "Failed to create virtual environment at ${VENV_PATH}") + endif() +endif() + +# Now find python again, but from the venv. This requires (re)setting a few things: +set(Python3_FIND_VIRTUALENV FIRST) +set(Python3_EXECUTABLE "${VENV_PATH}/bin/python") +unset(Python3_VERSION) +find_package(Python3 COMPONENTS Interpreter Development REQUIRED) + +# Make sure we have pip/setup tools +execute_process(COMMAND "${Python3_EXECUTABLE}" -m ensurepip) +execute_process(COMMAND "${Python3_EXECUTABLE}" -m pip install --upgrade pip setuptools) + +# Install the requirements +execute_process(COMMAND "${Python3_EXECUTABLE}" -m pip install -r "${CMAKE_CURRENT_LIST_DIR}/requirements.txt") + +# Get the site-packages directory for PYTHONPATH +execute_process( + COMMAND "${Python3_EXECUTABLE}" -c "import site; print(site.getsitepackages()[0])" + OUTPUT_VARIABLE VENV_SITE_PACKAGES + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +# Make the variables available to parent scope +set(VENV_PATH "${VENV_PATH}" PARENT_SCOPE) +set(VENV_SITE_PACKAGES "${VENV_SITE_PACKAGES}" PARENT_SCOPE) diff --git a/p-isa_tools/program_mapper/p_isa/pisa_test_generator.h b/p-isa_tools/program_mapper/p_isa/pisa_test_generator.h index 8e5e289b..8219a32a 100644 --- a/p-isa_tools/program_mapper/p_isa/pisa_test_generator.h +++ b/p-isa_tools/program_mapper/p_isa/pisa_test_generator.h @@ -3,14 +3,15 @@ #pragma once -#include "program_mapper/p_isa/tests/pisa_instruction_tests.h" -#include "program_mapper/p_isa/tests/pisa_kernel_tests.h" -#include -#include #include -#include #include +#include "common/graph/graph.h" +#include "common/p_isa/p_isa.h" +#include "program_mapper/p_isa/tests/pisa_instruction_tests.h" +#include "program_mapper/p_isa/tests/pisa_kernel_tests.h" +#include "program_mapper/poly_program/polyprogram.h" + using json = nlohmann::json; namespace pisa::testgenerator { diff --git a/p-isa_tools/program_mapper/p_isa/pisakernel.cpp b/p-isa_tools/program_mapper/p_isa/pisakernel.cpp index 7bae3b54..efc20f5e 100644 --- a/p-isa_tools/program_mapper/p_isa/pisakernel.cpp +++ b/p-isa_tools/program_mapper/p_isa/pisakernel.cpp @@ -62,6 +62,24 @@ inline std::string genKernInput(const pisa::poly::PolyOperation &op) << " " << inputs[i].num_of_polynomials << "\n"; } + // For muli operations, add the immediate value as an additional DATA input + bool has_immediate = false; + if (op.Name() == "muli") + { + // Check if we have an operand parameter (immediate value) + try + { + op.getParam("operand"); + // Add immediate value as input1 with 1 polynomial (scalar) + input << "DATA input" << op.numInputOperands() << " 1\n"; + has_immediate = true; + } + catch (...) + { + // No operand parameter, treat as regular operation + } + } + // OP input << std::uppercase << op.Name() << std::nouppercase; for (int i = 0; i < op.numOutputOperands(); ++i) @@ -74,6 +92,11 @@ inline std::string genKernInput(const pisa::poly::PolyOperation &op) input << " " << "input" << i /*inputs[i].register_name*/; } + // For muli, add the immediate as the second input + if (has_immediate) + { + input << " input" << op.numInputOperands(); + } return input.str(); } diff --git a/p-isa_tools/program_mapper/poly_program/operations/core.h b/p-isa_tools/program_mapper/poly_program/operations/core.h index e221565a..18950996 100644 --- a/p-isa_tools/program_mapper/poly_program/operations/core.h +++ b/p-isa_tools/program_mapper/poly_program/operations/core.h @@ -40,6 +40,19 @@ static const PolyOperationDesc Sub("sub", { OP_NAME, FHE_SCHEME, POLYMOD_DEG_LOG * */ static const PolyOperationDesc Mul("mul", { OP_NAME, FHE_SCHEME, POLYMOD_DEG_LOG2, KEY_RNS, OUTPUT_ARGUMENT, INPUT_ARGUMENT, INPUT_ARGUMENT }); +/** \brief PolyOperation multiply immediate (scalar multiplication) PolyOperationDesc + * Op name: muli + * | Param | description | + * | ----- | ----------- | + * | FHE_SCHEME | specifies the FHE_SCHEME of the poly operation | + * | POLYMOD_DEG_LOG2 | Specifies the modulus degree of the input polynomials | + * | KEY_RNS | Specifies number of RNS key values | + * | OUTPUT_ARGUMENT | Destination ciphertext | + * | INPUT_ARGUMENT | Input ciphertext label | + * | PARAM | Immediate/scalar value | + * */ +static const PolyOperationDesc Muli("muli", { OP_NAME, FHE_SCHEME, POLYMOD_DEG_LOG2, KEY_RNS, OUTPUT_ARGUMENT, INPUT_ARGUMENT, PARAM }); + /** \brief PolyOperation Square PolyOperationDesc * Adds two polynomials with specified specified polymodulus degree, RNS terms, and polynomial parts and writes the result to the specified output. * | Param | description | diff --git a/p-isa_tools/program_mapper/poly_program/poly_operation_library.h b/p-isa_tools/program_mapper/poly_program/poly_operation_library.h index e53636f8..4f67991d 100644 --- a/p-isa_tools/program_mapper/poly_program/poly_operation_library.h +++ b/p-isa_tools/program_mapper/poly_program/poly_operation_library.h @@ -13,6 +13,7 @@ static std::map core_operation_library = { { "sub", library::core::Sub }, { "mul", library::core::Mul }, { "mul_plain", library::core::Mul }, + { "muli", library::core::Muli }, // Multiply immediate (scalar multiplication) { "square", library::core::Square }, { "ntt", library::core::Ntt }, { "intt", library::core::Intt }, diff --git a/p-isa_tools/program_mapper/poly_program/polyprogram.cpp b/p-isa_tools/program_mapper/poly_program/polyprogram.cpp index 53ceec57..397059b9 100644 --- a/p-isa_tools/program_mapper/poly_program/polyprogram.cpp +++ b/p-isa_tools/program_mapper/poly_program/polyprogram.cpp @@ -209,6 +209,10 @@ void PolyOperation::setComponents(const heracles::fhe_trace::Instruction &instr_ setGaloisElt(stoi(v.value())); if (k == "factor") setFactor(stoi(v.value())); + if (k == "operand") + // For muli operations, the operand is the immediate scalar value + // Store it as a parameter that will be used when generating the kernel + setParam({ k, { v.value(), ValueType::DOUBLE } }); } } heracles::fhe_trace::Instruction *PolyOperation::getProtobuffFHETraceInstruction() diff --git a/pyproject.toml b/pyproject.toml index 9859c8df..7ee16417 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,20 +3,26 @@ requires = ["setuptools"] build-backend = "setuptools.build_meta" [tool.setuptools] -packages = ["assembler_tools", "p-isa_tools"] +packages = ["assembler_tools", "p-isa_tools", "heracles"] [tool.setuptools.package-dir] "" = "." +"heracles" = "p-isa_tools/data_formats/python/heracles" [project] name = "encrypted-computing-sdk" version = "0.1.0" description = "Encrypted Computing SDK" requires-python = ">=3.10" -dependencies = [] +dependencies = [ + "protobuf==4.23.4", + "regex-spm", + "numpy", +] [project.optional-dependencies] dev = [ + "grpcio-tools==1.56.2", # needed for compiling protos from python directly "pytest>=7.4.0", "pre-commit>=3.5.0", # pre-commit installs its own version of the tools below, @@ -110,7 +116,7 @@ skip-magic-trailing-comma = false line-ending = "auto" [tool.pytest.ini_options] -pythonpath = ["p-isa_tools/kerngen", "assembler_tools/hec-assembler-tools"] +pythonpath = ["p-isa_tools/kerngen", "assembler_tools/hec-assembler-tools", "p-isa_tools/data_formats/python"] [tool.mypy] python_version = "3.10" @@ -118,3 +124,6 @@ namespace_packages = true show_error_codes = true pretty = true files = ["p-isa_tools/"] +exclude = [ + "p-isa_tools/data_formats/python/heracles/proto/.*_pb2\\.py$", +]