diff --git a/.clang-format b/.clang-format index 23014f9..6b6f7bb 100644 --- a/.clang-format +++ b/.clang-format @@ -23,7 +23,7 @@ BraceWrapping: AfterCaseLabel: true AfterControlStatement: Always SplitEmptyFunction: false - BeforeLambdaBody: false + BeforeLambdaBody: true BeforeCatch: true BeforeWhile: true BeforeElse: true diff --git a/.gitignore b/.gitignore index b5cc7ab..c72c66c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,10 @@ build/* docs/* +deps/* +.cache .idea /graph_info.json /vcpkg_installed/ /node_modules/ /.docusaurus/ - diff --git a/.vscode/launch.json b/.vscode/launch.json index a87114b..7543e98 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -2,13 +2,13 @@ "version": "0.2.0", "configurations": [ { - "type": "gdb", + "type": "codelldb", "request": "launch", "name": "Cli-Debug", "preLaunchTask": "cmake-install-debug", - "expressions": "native", - "program": "./build/Debug/cli/BraneScriptCli", + "program": "${workspaceFolder}/build/Debug/cli/BraneScriptCli.exe", "console": "integratedTerminal", + "cwd": "${workspaceFolder}/build/Debug/cli/", "args": [ "exampleScripts/test.bscript" ] @@ -21,6 +21,7 @@ "expressions": "native", "program": "./build/Release/cli/BraneScriptCli", "console": "integratedTerminal", + "cwd": "${workspaceFolder}/build/Release/cli/", "args": [ "exampleScripts/test.bscript" ] diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..2b077ee --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,6 @@ +{ + "files.associations": { + "vector": "cpp", + "type_traits": "cpp" + } +} \ No newline at end of file diff --git a/.vscode/tasks-gcc.json b/.vscode/tasks-gcc.json new file mode 100644 index 0000000..a9c28f2 --- /dev/null +++ b/.vscode/tasks-gcc.json @@ -0,0 +1,105 @@ +{ + "version": "2.0.0", + "showOutput": "always", + "tasks": [ + { + "taskName": "cmake-configure-debug", + "label": "cmake-configure-debug", + "type": "shell", + "options": { + "cwd": "${workspaceRoot}" + }, + 
"command": "cmake -S. -B./build/Debug -G Ninja -DCMAKE_MAKE_PROGRAM=ninja -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ -DVCPKG_TARGET_TRIPLET=x64-mingw-static -DCMAKE_BUILD_TYPE=Debug -DCMAKE_TOOLCHAIN_FILE=\"${env:VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake\" -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DBS_BUILD_TESTS=ON", + "presentation": "always" + }, + { + "taskName": "cmake-build-debug", + "label": "cmake-build-debug", + "type": "shell", + "dependsOn": [], + "options": { + "cwd": "${workspaceRoot}" + }, + "command": "cmake --build ./build/Debug", + "presentation": "always" + }, + { + "taskName": "cmake-configure-release", + "label": "cmake-configure-release", + "type": "shell", + "options": { + "cwd": "${workspaceRoot}" + }, + "command": "cmake -S. -B./build/Release -G Ninja -DCMAKE_MAKE_PROGRAM=ninja -DCMAKE_C_COMPILER=\".vscode/toolchains/zcc\" -DCMAKE_CXX_COMPILER=\".vscode/toolchains/z++\" -DVCPKG_OVERLAY_TRIPLETS=.vscode/triplets/ -DVCPKG_TARGET_TRIPLET=zig-x64-windows -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=\"${env:VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake\" -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DBS_BUILD_TESTS=ON", + "presentation": "always" + }, + { + "taskName": "cmake-build-release", + "label": "cmake-build-release", + "type": "shell", + "dependsOn": [], + "options": { + "cwd": "${workspaceRoot}" + }, + "command": "cmake --build ./build/Release" + }, + { + "taskName": "build-db-debug", + "label": "build-db-debug", + "type": "shell", + "dependsOn": [ + "cmake-configure-debug" + ], + "options": { + "cwd": "${workspaceRoot}" + }, + "windows": { + "command": "mv ./build/Debug/compile_commands.json ./build/ -Force" + }, + "linux": { + "command": "mv ./build/Debug/compile_commands.json ./build/ -f" + } + }, + { + "taskName": "build-db-release", + "label": "build-db-release", + "type": "shell", + "dependsOn": [ + "cmake-configure-release" + ], + "options": { + "cwd": "${workspaceRoot}" + }, + "windows": { + "command": "mv 
./build/Release/compile_commands.json ./build/ -Force" + }, + "linux": { + "command": "mv ./build/Release/compile_commands.json ./build/ -f" + } + }, + { + "taskName": "cmake-install-debug", + "label": "cmake-install-debug", + "type": "shell", + "dependsOn": [ + "cmake-build-debug" + ], + "options": { + "cwd": "${workspaceRoot}/build/Debug" + }, + "command": "cmake --install . --component cli --prefix ./cli" + }, + { + "taskName": "cmake-install-cli-release", + "label": "cmake-install-cli-release", + "type": "shell", + "dependsOn": [ + "cmake-build-release" + ], + "options": { + "cwd": "${workspaceRoot}/build/Release" + }, + "command": "cmake --install . --component cli --prefix ./cli" + } + ] +} diff --git a/.vscode/tasks-zig.json b/.vscode/tasks-zig.json new file mode 100644 index 0000000..b3ec786 --- /dev/null +++ b/.vscode/tasks-zig.json @@ -0,0 +1,105 @@ +{ + "version": "2.0.0", + "showOutput": "always", + "tasks": [ + { + "taskName": "cmake-configure-debug", + "label": "cmake-configure-debug", + "type": "shell", + "options": { + "cwd": "${workspaceRoot}" + }, + "command": "cmake -S. 
{
    "taskName": "cmake-configure-release",
    "label": "cmake-configure-release",
    "type": "shell",
    "options": {
        "cwd": "${workspaceRoot}"
    },
    "command": "cmake -S. -B./build/Release -G Ninja -DCMAKE_MAKE_PROGRAM=ninja -DCMAKE_C_COMPILER=\"${workspaceRoot}/.vscode/toolchains/zcc.cmd\" -DCMAKE_CXX_COMPILER=\"${workspaceRoot}/.vscode/toolchains/z++.cmd\" -DVCPKG_OVERLAY_TRIPLETS=\".vscode/triplets/\" -DVCPKG_TARGET_TRIPLET=zig-x64-windows -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=\"${env:VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake\" -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DBS_BUILD_TESTS=ON",
    "presentation": "always"
},
"cmake-configure-release" + ], + "options": { + "cwd": "${workspaceRoot}" + }, + "windows": { + "command": "mv ./build/Release/compile_commands.json ./build/ -Force" + }, + "linux": { + "command": "mv ./build/Release/compile_commands.json ./build/ -f" + } + }, + { + "taskName": "cmake-install-debug", + "label": "cmake-install-debug", + "type": "shell", + "dependsOn": [ + "cmake-build-debug" + ], + "options": { + "cwd": "${workspaceRoot}/build/Debug" + }, + "command": "cmake --install . --component cli --prefix ./cli" + }, + { + "taskName": "cmake-install-cli-release", + "label": "cmake-install-cli-release", + "type": "shell", + "dependsOn": [ + "cmake-build-release" + ], + "options": { + "cwd": "${workspaceRoot}/build/Release" + }, + "command": "cmake --install . --component cli --prefix ./cli" + } + ] +} diff --git a/.vscode/tasks.json b/.vscode/tasks.json index 11b9675..60d09d3 100644 --- a/.vscode/tasks.json +++ b/.vscode/tasks.json @@ -9,16 +9,14 @@ "options": { "cwd": "${workspaceRoot}" }, - "command": "cmake -S. -B./build/Debug -G Ninja -DCMAKE_MAKE_PROGRAM=ninja -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ -DVCPKG_TARGET_TRIPLET=x64-mingw-static -DCMAKE_BUILD_TYPE=Debug -DCMAKE_TOOLCHAIN_FILE=\"${env:VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake\" -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DBS_BUILD_TESTS=ON", + "command": "cmake -S. -B./build/Debug -G Ninja -DCMAKE_MAKE_PROGRAM=ninja -DCMAKE_BUILD_TYPE=Debug -DCMAKE_TOOLCHAIN_FILE=\"${env:VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake\" -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DBS_BUILD_TESTS=ON", "presentation": "always" }, { "taskName": "cmake-build-debug", "label": "cmake-build-debug", "type": "shell", - "dependsOn": [ - "cmake-configure-debug" - ], + "dependsOn": [], "options": { "cwd": "${workspaceRoot}" }, @@ -32,16 +30,14 @@ "options": { "cwd": "${workspaceRoot}" }, - "command": "cmake -S. 
# vcpkg chainload toolchain that builds Linux targets with the Zig compiler
# wrappers (zcc / z++) that live next to this file.
# Derived from the mingw pipeline
if(NOT _VCPKG_ZIG_TOOLCHAIN)
    # Include guard: this file is chain-loaded, so only run the setup once.
    set(_VCPKG_ZIG_TOOLCHAIN 1)

    # Forward the vcpkg settings into try_compile() projects.
    list(APPEND CMAKE_TRY_COMPILE_PLATFORM_VARIABLES
        VCPKG_CRT_LINKAGE VCPKG_TARGET_ARCHITECTURE
        VCPKG_C_FLAGS VCPKG_CXX_FLAGS
        VCPKG_C_FLAGS_DEBUG VCPKG_CXX_FLAGS_DEBUG
        VCPKG_C_FLAGS_RELEASE VCPKG_CXX_FLAGS_RELEASE
        VCPKG_LINKER_FLAGS VCPKG_LINKER_FLAGS_RELEASE VCPKG_LINKER_FLAGS_DEBUG
    )

    # Map the vcpkg architecture name onto the processor name used in the target triple.
    if(VCPKG_TARGET_ARCHITECTURE STREQUAL "x86")
        set(CMAKE_SYSTEM_PROCESSOR i686 CACHE STRING "")
    elseif(VCPKG_TARGET_ARCHITECTURE STREQUAL "x64")
        set(CMAKE_SYSTEM_PROCESSOR x86_64 CACHE STRING "")
    elseif(VCPKG_TARGET_ARCHITECTURE STREQUAL "arm")
        set(CMAKE_SYSTEM_PROCESSOR armv7 CACHE STRING "")
    elseif(VCPKG_TARGET_ARCHITECTURE STREQUAL "arm64")
        set(CMAKE_SYSTEM_PROCESSOR aarch64 CACHE STRING "")
    endif()

    set(CMAKE_SYSTEM_NAME Linux)

    foreach(lang C CXX)
        set(CMAKE_${lang}_COMPILER_TARGET "${CMAKE_SYSTEM_PROCESSOR}-linux-gnu" CACHE STRING "")
    endforeach()

    # Select Zig as the compiler.
    # Which wrapper script can be executed depends on the machine running the
    # build, not on the target, so test the host platform here.
    if(CMAKE_HOST_WIN32)
        find_program(ZCC
            NAMES zcc.bat zcc.cmd
            PATHS ${CMAKE_CURRENT_LIST_DIR}
            NO_DEFAULT_PATH
        )
        find_program(ZPP
            NAMES z++.bat z++.cmd
            PATHS ${CMAKE_CURRENT_LIST_DIR}
            NO_DEFAULT_PATH
        )
    else()
        find_program(ZCC
            NAMES zcc.sh
            PATHS ${CMAKE_CURRENT_LIST_DIR}
            NO_DEFAULT_PATH
        )
        find_program(ZPP
            NAMES z++.sh
            PATHS ${CMAKE_CURRENT_LIST_DIR}
            NO_DEFAULT_PATH
        )
    endif()

    # Both wrappers are required; a missing C wrapper would otherwise leave
    # CMAKE_C_COMPILER set to "ZCC-NOTFOUND".
    if(NOT ZCC OR NOT ZPP)
        message(FATAL_ERROR "Could not find zig wrapper script in ${CMAKE_CURRENT_LIST_DIR}")
    endif()
    message("Using ZIG at ${ZPP}")
    set(CMAKE_C_COMPILER "${ZCC}" CACHE STRING "")
    set(CMAKE_CXX_COMPILER "${ZPP}" CACHE STRING "")

    message("curent c compile flags: ${CMAKE_C_FLAGS_INIT}")
    # Every append is padded with spaces so it cannot fuse with flags that are
    # already present in the *_INIT variable.
    string(APPEND CMAKE_C_FLAGS_INIT " ${VCPKG_C_FLAGS} ")
    string(APPEND CMAKE_CXX_FLAGS_INIT " ${VCPKG_CXX_FLAGS} ")
    string(APPEND CMAKE_C_FLAGS_DEBUG_INIT " ${VCPKG_C_FLAGS_DEBUG} ")
    string(APPEND CMAKE_CXX_FLAGS_DEBUG_INIT " ${VCPKG_CXX_FLAGS_DEBUG} ")
    string(APPEND CMAKE_C_FLAGS_RELEASE_INIT " ${VCPKG_C_FLAGS_RELEASE} ")
    string(APPEND CMAKE_CXX_FLAGS_RELEASE_INIT " ${VCPKG_CXX_FLAGS_RELEASE} ")

    string(APPEND CMAKE_MODULE_LINKER_FLAGS_INIT " ${VCPKG_LINKER_FLAGS} ")
    string(APPEND CMAKE_SHARED_LINKER_FLAGS_INIT " ${VCPKG_LINKER_FLAGS} ")
    string(APPEND CMAKE_EXE_LINKER_FLAGS_INIT " ${VCPKG_LINKER_FLAGS} ")

    if(VCPKG_CRT_LINKAGE STREQUAL "static")
        string(APPEND CMAKE_MODULE_LINKER_FLAGS_INIT "-static ")
        string(APPEND CMAKE_SHARED_LINKER_FLAGS_INIT "-static ")
        string(APPEND CMAKE_EXE_LINKER_FLAGS_INIT "-static ")
    endif()

    string(APPEND CMAKE_MODULE_LINKER_FLAGS_DEBUG_INIT " ${VCPKG_LINKER_FLAGS_DEBUG} ")
    string(APPEND CMAKE_SHARED_LINKER_FLAGS_DEBUG_INIT " ${VCPKG_LINKER_FLAGS_DEBUG} ")
    string(APPEND CMAKE_EXE_LINKER_FLAGS_DEBUG_INIT " ${VCPKG_LINKER_FLAGS_DEBUG} ")
    string(APPEND CMAKE_MODULE_LINKER_FLAGS_RELEASE_INIT " ${VCPKG_LINKER_FLAGS_RELEASE} ")
    string(APPEND CMAKE_SHARED_LINKER_FLAGS_RELEASE_INIT " ${VCPKG_LINKER_FLAGS_RELEASE} ")
    string(APPEND CMAKE_EXE_LINKER_FLAGS_RELEASE_INIT " ${VCPKG_LINKER_FLAGS_RELEASE} ")
endif()
endif() + message("Using ZIG at ${ZPP}") + set(CMAKE_C_COMPILER "${ZCC}" CACHE STRING "") + set(CMAKE_CXX_COMPILER "${ZPP}" CACHE STRING "") + + message("curent c compile flags: ${CMAKE_C_FLAGS_INIT}") + string(APPEND CMAKE_C_FLAGS_INIT "${VCPKG_C_FLAGS}") + string(APPEND CMAKE_CXX_FLAGS_INIT "${VCPKG_CXX_FLAGS}") + string(APPEND CMAKE_C_FLAGS_DEBUG_INIT "${VCPKG_C_FLAGS_DEBUG} -g -gdwarf") + string(APPEND CMAKE_CXX_FLAGS_DEBUG_INIT "${VCPKG_CXX_FLAGS_DEBUG} -g -gdwarf") + string(APPEND CMAKE_C_FLAGS_RELEASE_INIT "${VCPKG_C_FLAGS_RELEASE} ") + string(APPEND CMAKE_CXX_FLAGS_RELEASE_INIT "${VCPKG_CXX_FLAGS_RELEASE} ") + + string(APPEND CMAKE_MODULE_LINKER_FLAGS_INIT " ${VCPKG_LINKER_FLAGS} ") + string(APPEND CMAKE_SHARED_LINKER_FLAGS_INIT " ${VCPKG_LINKER_FLAGS} ") + string(APPEND CMAKE_EXE_LINKER_FLAGS_INIT " ${VCPKG_LINKER_FLAGS} ") + + if(VCPKG_CRT_LINKAGE STREQUAL "static") + string(APPEND CMAKE_MODULE_LINKER_FLAGS_INIT "-static ") + string(APPEND CMAKE_SHARED_LINKER_FLAGS_INIT "-static ") + string(APPEND CMAKE_EXE_LINKER_FLAGS_INIT "-static ") + endif() + + string(APPEND CMAKE_MODULE_LINKER_FLAGS_DEBUG_INIT " ${VCPKG_LINKER_FLAGS_DEBUG}") + string(APPEND CMAKE_SHARED_LINKER_FLAGS_DEBUG_INIT " ${VCPKG_LINKER_FLAGS_DEBUG}") + string(APPEND CMAKE_EXE_LINKER_FLAGS_DEBUG_INIT " ${VCPKG_LINKER_FLAGS_DEBUG} ") + string(APPEND CMAKE_MODULE_LINKER_FLAGS_RELEASE_INIT " ${VCPKG_LINKER_FLAGS_RELEASE} ") + string(APPEND CMAKE_SHARED_LINKER_FLAGS_RELEASE_INIT " ${VCPKG_LINKER_FLAGS_RELEASE} ") + string(APPEND CMAKE_EXE_LINKER_FLAGS_RELEASE_INIT " ${VCPKG_LINKER_FLAGS_RELEASE} ") +endif() + + + diff --git a/.vscode/toolchains/z++.cmd b/.vscode/toolchains/z++.cmd new file mode 100644 index 0000000..62809b8 --- /dev/null +++ b/.vscode/toolchains/z++.cmd @@ -0,0 +1 @@ +zig c++ %* diff --git a/.vscode/toolchains/z++.sh b/.vscode/toolchains/z++.sh new file mode 100644 index 0000000..0c89549 --- /dev/null +++ b/.vscode/toolchains/z++.sh @@ -0,0 +1 @@ +zig c++ "$@" diff --git 
# vcpkg overlay triplet: x64 Linux, static CRT and static libraries, built
# through the Zig toolchain file that lives in ../toolchains.
set(VCPKG_TARGET_ARCHITECTURE x64)
set(VCPKG_CRT_LINKAGE static)
set(VCPKG_LIBRARY_LINKAGE static)
# CMAKE_CURRENT_LIST_DIR has no trailing slash, so the separator before ".."
# is required for the path to resolve.
set(VCPKG_CHAINLOAD_TOOLCHAIN_FILE "${CMAKE_CURRENT_LIST_DIR}/../toolchains/toolchain-zig-linux.cmake")
set(VCPKG_CMAKE_SYSTEM_NAME "Linux")
# Let the port builds see PATH (the wrapper scripts invoke `zig` by name).
set(VCPKG_ENV_PASSTHROUGH PATH)
# Clang-based toolchains on Windows: build against the windows-gnu ABI and emit
# DWARF debug info. Guarded by WIN32 so that a Clang build for another target
# system is not forced onto the Windows triple.
if(WIN32 AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
    add_compile_options("--target=x86_64-windows-gnu" "-gdwarf")
endif()

if(MINGW)
    # -Wa,<opt> is handed to the assembler, which only runs while compiling,
    # so it is a compile option rather than a link flag.
    add_compile_options("-Wa,-mbig-obj")
    # Link the GCC/C++ runtimes statically into executables. The previous
    # CMAKE_LINKER_FLAGS assignment was dropped: CMake does not read a variable
    # of that name.
    set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -static -static-libgcc -static-libstdc++")
endif(MINGW)
Eventually this could be abstracted so that a single type return could be automatically expanded to (value: T) + [ + // Every pipeline stage has the format of [] [] [] { } + { + let value = Shared::new(0i32); + if add1 == 0 && add2 == 0 { + continue(stage: end, args: (value: value)); // Go to next pipeline stage right away + // or + skip_to(stage: end, args: (value: value)); // like a goto, but can only go forward in the pipeline. Useful for async control flow. + } + + let t1 = start(pipe: add, args: (value: value, count: add1)); // create a start command for both add pipelines. These don't get executed right away, as a hard rule no communication happens inside of segments, they're pure functions + let t2 = start(pipe: add, args: (value: value, count: add2)); + + continue(args: (value: value), deps: [t1, t2]); // Go to the next pipeline stage after deps have been completed, if t1 or t2 returned a value we could pass them as args and the next stage would be able to access their return values, deps is only required to say "we don't care about the returns, so don't make them args". + } + end { // "end" is the stage label + break(value: count); + } + ] + + pipe add(value: Shared, count: i32) + [ + (value: Shared) { // Thread safety of data is handled by the runtime, so pipeline stages are only called when all their requested data is available. Shared is basically a shared reference counted pointer with a mutex inside it, where the mutex is automatically shared or unique locked depending on the needs of the stage (Indicated with a MutRef for a mutable reference and Ref for a const reference, and plain Shared for a non-locked handle) + value += 1 + let workers = start_foreach(pipe: (value: MutRef) // start one closure style pipeline for a range of values + [{ + value += 1; + }], (value: value.handle()), 0..count); // 0..N expands to a range iterator, see rust .. 
operator used in for loops + + break(deps: [workers]); + } + ] +} + +mod test2 { + +} diff --git a/exampleScripts/test.bscript b/exampleScripts/test.bscript index ae34abc..abe07af 100644 --- a/exampleScripts/test.bscript +++ b/exampleScripts/test.bscript @@ -1,27 +1,13 @@ mod test { - pipe Foo - (a: i32, b: i32){ + pipe add_multiple(start: i32, add1: i32, add2: i32) -> (count: i32) [ - let r: i32 = a + b; + (start: i32, add1: i32, add2: i32) { + let value : i32 = start + add1; + continue(args: (value: value, add2: add2), deps: []); + } + (value: i32, add2: i32) { + let value : i32 = value + add2; + continue(args: (count: value)); + } ] - hold r - call (r)Bar(o: f32) - [ - o = o + 2; - ] - }(value: o) - - - pipe Bar - (a: i32){ - [ - (in: a)f32(o: f32); - ] - }(value: o) } - -mod test2 -{} - -mod hello_world{} - diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index fe77ac3..159b76b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -15,6 +15,8 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}) add_subdirectory(ir) add_subdirectory(cli) add_subdirectory(parser) -#add_subdirectory(compiler) +add_subdirectory(compiler) +add_subdirectory(runtime) add_subdirectory(types) + diff --git a/src/cli/CMakeLists.txt b/src/cli/CMakeLists.txt index 4ed031d..8f117cd 100644 --- a/src/cli/CMakeLists.txt +++ b/src/cli/CMakeLists.txt @@ -12,7 +12,7 @@ llvm_map_components_to_libnames(llvm_libs Support Core IRReader ...) 
target_link_libraries(main PRIVATE ${llvm_libs})]] add_executable(BraneScriptCli main.cpp) -target_link_libraries(BraneScriptCli PRIVATE parser) +target_link_libraries(BraneScriptCli PRIVATE parser compiler llvmJitBackend) install(TARGETS BraneScriptCli RUNTIME_DEPENDENCIES @@ -21,13 +21,19 @@ install(TARGETS BraneScriptCli "api-ms-.*" "ext-ms-.*" # Windows system DLLs "^/lib.*" # Linux system libraries "^/usr/lib.*" - POST_EXCLUDE_REGEXES + POST_EXCLUDE_REGEXES + "libgcc_s_seh-1\\.dll" + "libstdc++-6\\.dll" ".*system32/.*\\.dll" # Windows system directory "^/lib64/.*" # Linux system directories "^/usr/lib64/.*" DESTINATION . COMPONENT cli) +install(FILES ${CMAKE_CURRENT_BINARY_DIR}/BraneScriptCli.pdb + DESTINATION . + COMPONENT cli + OPTIONAL) install(FILES ${treesitterbranescript_BINARY_DIR}/${CMAKE_SHARED_LIBRARY_PREFIX}tree-sitter-branescript${CMAKE_SHARED_LIBRARY_SUFFIX} DESTINATION . COMPONENT cli) diff --git a/src/cli/main.cpp b/src/cli/main.cpp index 8bc2194..3f6fca9 100644 --- a/src/cli/main.cpp +++ b/src/cli/main.cpp @@ -8,9 +8,12 @@ #include #include #include "parser/documentParser.h" -#include "parser/tree_sitter_branescript.h" +#include "runtime/backends/llvm/llvmJitBackend.h" +#include "tree-sitter-branescript.h" #include +#include "compiler/compiler.h" +#include "ir/irTextSerializer.h" #include "tree_sitter/api.h" // #include "TSBindings.h" @@ -59,7 +62,7 @@ void print_tree(TSNode node, std::string_view source, int currentDepth = 0) if(field) { bool matches = ts_node_child_by_field_name(node, field, strlen(field)).id == child.id; - printf("field: \"%s\" matches epxected = %s", field, matches ? "true" : "false"); + printf("field: \"%s\" matches expected = %s", field, matches ? 
"true" : "false"); if(!matches) printf("Was epxecting: %s", ts_node_type(ts_node_child_by_field_name(node, field, strlen(field)))); } @@ -93,7 +96,7 @@ int main(int argc, char* argv[]) f.read(source_code.data(), count); f.close(); - std::cout << "Parsing: \n" << source_code << std::endl; + std::cout << "---Parsing---\n" << source_code << "\n---END---" << std::endl; TSParser* parser = ts_parser_new(); const TSLanguage* braneScriptLang = tree_sitter_branescript(); ts_parser_set_language(parser, braneScriptLang); @@ -111,9 +114,9 @@ int main(int argc, char* argv[]) printf("Parsing DocumentContext...\n"); auto bs_parser = std::make_shared(); - BraneScript::ParsedDocument doc(argv[1], source_code, bs_parser); + auto doc = std::make_shared(argv[1], source_code, bs_parser); - auto parseRes = doc.getDocumentContext(); + auto parseRes = doc->getDocumentContext(); if(!parseRes.messages.empty()) { @@ -129,6 +132,105 @@ int main(int argc, char* argv[]) for(auto& mod : parseRes.document->modules) printf("%s\n", mod.second->identifier->text.c_str()); + BraneScript::Compiler compiler; + Option> compileRes = compiler.compile({doc}); + if(!compileRes) + std::cout << "Failed to compile!" 
<< std::endl; + else + std::cout << "compile completed with messages:" << std::endl; + for(auto& m : compiler.messages()) + { + printf("[line %d, char %d]: %s\n", + m.source.range.value().start_point.row, + m.source.range.value().start_point.column, + m.message.c_str()); + } + + if(compileRes) + { + for(auto& mod : compileRes.value()) + { + std::cout << "Compiled module: " << mod.id << std::endl; + auto writeRes = BraneScript::IRSerializer::irToText(mod); + if(writeRes) + std::cout << "IR: " << writeRes.ok() << std::endl; + else + std::cout << "Failed to write ir: " << writeRes.err() << std::endl; + } + } + else + return 1; + + BraneScript::LLVMJitBackend backend; + backend.stageModule(std::make_shared(compileRes.value()[0])); + backend.processModules(); + + + for(auto& func : backend.functions()) + { + std::cout << "Was able to load function: " << func.first << std::endl; + } + + void* bindingsPage[65535]; + int8_t memPage[65535]; + for(auto& i : memPage) + i = 0; + for(auto& i : bindingsPage) + i = memPage; + + struct Args + { + uint32_t a; + uint32_t b; + uint32_t c; + + std::string print() { return std::format("(a: {}, b: {}, c: {})", a, b, c); } + }; + + auto f0 = backend.functions().at("-f0"); + auto f1 = backend.functions().at("-f1"); + + BraneScript::JitPtr input = {.binding = 0, .index = sizeof(Args)}; + BraneScript::JitPtr output = {.binding = 0, .index = sizeof(Args) * 2}; + std::cout << "JitPtr size: " << sizeof(input) << std::endl; + + auto inputPtr = (Args*)bindingsPage[input.binding] + input.index; + auto outputPtr = (Args*)bindingsPage[output.binding] + output.index; + + *inputPtr = Args{.a = 2, .b = 5, .c = 3}; + + std::cout << "Running -f1 with args " << inputPtr->print() << std::endl; + f0(bindingsPage, input.asInt(), output.asInt()); + std::cout << "-f1 returned: " << outputPtr->print() << std::endl; + + std::cout << "Mem: "; + for(size_t i = 0; i < 128; ++i) + { + std::cout << std::format("{} ", memPage[i]); + if(i % 30 == 0 && i != 0) + 
std::cout << std::endl; + } + std::cout << std::endl; + + // Move output so we can feed it back through, and clean up output memory + *inputPtr = *outputPtr; + inputPtr->c = 0; + *outputPtr = Args{0, 0, 0}; + + std::cout << "Running -f2 with args " << inputPtr->print() << std::endl; + f1(bindingsPage, input.asInt(), output.asInt()); + std::cout << "-f2 returned: " << outputPtr->print() << std::endl; + + + std::cout << "Mem: "; + for(size_t i = 0; i < 128; ++i) + { + std::cout << std::format("{} ", memPage[i]); + if(i % 30 == 0 && i != 0) + std::cout << std::endl; + } + std::cout << std::endl; + ts_tree_delete(tree); ts_parser_delete(parser); diff --git a/src/compiler/CMakeLists.txt b/src/compiler/CMakeLists.txt index e3031ac..85686c7 100644 --- a/src/compiler/CMakeLists.txt +++ b/src/compiler/CMakeLists.txt @@ -4,3 +4,10 @@ add_library(compiler STATIC ) target_link_libraries(compiler PUBLIC ir parser) + + +if(MINGW) + if(${CMAKE_BUILD_TYPE} STREQUAL "Debug") + target_compile_options(compiler PRIVATE -g -Og) + endif() +endif(MINGW) diff --git a/src/compiler/compiler.cpp b/src/compiler/compiler.cpp index 18f253d..2e90052 100644 --- a/src/compiler/compiler.cpp +++ b/src/compiler/compiler.cpp @@ -1,121 +1,754 @@ #include "compiler.h" -#include +#include +#include +#include +#include "enums/matchv.h" +#include "enums/result.h" #include -void Compiler::recordMessage(CompilerMessage message) +namespace BraneScript { - _result.messages.push_back(std::move(message)); -} - + struct CompilerContext : TextContext + { + std::unordered_map> modules; -CompileResult Compiler::compile(const std::vector>& documents) -{ - _result = CompileResult(); - _identifers.clear(); - _modules.clear(); - - _sources.clear(); - for(auto& source: documents) - _sources.emplace(source->source(), source); - - auto globalMod = std::make_shared("global"); - _identifers.insert({"global", globalMod}); - _modules.insert({"global", globalMod}); - - indexSymbolsPass(); - generateModulesPass(); - - for(auto& 
mod: _modules) - _result.modules.push_back(std::move(*mod.second)); - return std::move(_result); -} - -std::expected parseScopedIdentifier(TSNode idRoot, const ParsedDocument& doc) -{ - ScopedIdentifier id; + Option searchFor(Node identifier, size_t scope) override + { + if(auto id = std::get_if>(&identifier->scopes[scope])) + { + auto mod = modules.find((*id)->text); + if(mod != modules.end()) + return Some(mod->second); + } + return None(); + } + }; - while(!ts_node_is_null(idRoot)) + struct CompilerPass { - TSNode idNameNode = ts_node_child_by_field_name(idRoot, "id", 2); - std::optional> templateArgs; + std::vector messages; + bool recordedError = false; - // TODO actually parse templateArgs + void recordMessage(std::string message, Node ctx, CompilerMessageType type) + { + messages.push_back(CompilerMessage{ + .type = type, + .source = CompilerSource{.uri = ctx->source->uri, .range = Some(ctx->range)}, + .message = std::move(message), + }); + } - id.scopes.emplace_back(doc.nodeToString(idNameNode), templateArgs); - idRoot = ts_node_child_by_field_name(idRoot, "child", 5); - } + void recordLog(std::string message, Node ctx) + { + recordMessage(std::move(message), std::move(ctx), CompilerMessageType::Log); + } - return id; -} + void recordWarning(std::string message, Node ctx) + { + recordMessage(std::move(message), std::move(ctx), CompilerMessageType::Log); + } -void Compiler::indexSymbolsPass() -{ - for(auto& [path, doc]: _sources) + void recordError(std::string message, Node ctx) + { + recordMessage(std::move(message), std::move(ctx), CompilerMessageType::Log); + } + }; + + struct DocumentCombinerPass : CompilerPass { - TSNode root = doc->docRoot(); - - ScopedIdentifier currentScope; - indexSymbolsPass_recursive(root, currentScope, *doc); - } -} + Node compilerContext; -void Compiler::indexSymbolsPass_recursive(TSNode node, ScopedIdentifier& currentScope, const ParsedDocument& doc) -{ - size_t parentScopeLength = currentScope.scopes.size(); + static bool 
identiferTaken(const std::string& id, const Node& mod) + { + return mod->structs.contains(id) || mod->functions.contains(id) || mod->pipelines.contains(id); + } + + void mergeModule(std::string modName, std::vector> moduleSources) + { + auto mod = std::make_shared(); + mod->parent = Some(compilerContext->weak_from_this()); + + for(auto& def : moduleSources) + { + if(!mod->identifier) + mod->identifier = def->identifier; + + for(auto& s : def->structs) + { + if(identiferTaken(s.first, mod)) + recordError(std::format("May not define {}::{} multiple times", mod->identifier->text, s.first), + s.second); + mod->structs[s.first] = s.second; + s.second->parent = Some(mod->weak_from_this()); + } + + for(auto& f : def->functions) + { + if(identiferTaken(f.first, mod)) + recordError(std::format("May not define {}::{} multiple times", mod->identifier->text, f.first), + f.second); + mod->functions[f.first] = f.second; + f.second->parent = Some(mod->weak_from_this()); + } + + for(auto& p : def->pipelines) + { + if(identiferTaken(p.first, mod)) + recordError(std::format("May not define {}::{} multiple times", mod->identifier->text, p.first), + p.second); + mod->pipelines[p.first] = p.second; + p.second->parent = Some(mod->weak_from_this()); + } + // TODO combine impl statements + } + + compilerContext->modules.insert({mod->identifier->text, mod}); + } + + Option> run(const std::vector>& sources) + { + std::unordered_map>> moduleSources; + compilerContext = std::make_shared(); + + for(auto& document : sources) + { + auto ctx = document->getDocumentContext(); + for(auto& mod : ctx.document->modules) + moduleSources[mod.first].emplace_back(mod.second); + } + + for(auto& sourceList : moduleSources) + mergeModule(sourceList.first, std::move(sourceList.second)); - NodeType nodeType; - if(_lut.tryToNodeType(ts_node_symbol(node), nodeType)) + return Some(std::move(compilerContext)); + } + }; + + struct DocumentToIRPass : CompilerPass { - switch(nodeType) + Option currentModule; + Option 
currentPipeline; + Option currentFunction; + + CompilerSource sourceFor(Node ctx) + { + return CompilerSource{.uri = ctx->source->uri, .range = Some(ctx->range)}; + } + + Option compileStruct(Node ctx) + { + int structId = currentModule.value()->structs.size(); + currentModule.value()->structs.push_back(IRStruct{.id = Some(ctx->identifier->text), .members = {}}); + auto& s = currentModule.value()->structs[structId]; + + for(auto& m : ctx->members) + { + auto memberType = resolveType(m.second->type.value()); + if(!memberType) + { + messages.push_back(memberType.err()); + return None(); + } + + s.members.push_back(IRStructMember{.id = Some(m.first), .type = memberType.ok()}); + } + return Some(structId); + } + + Result resolveType(Node typeCtx) + { + std::vector members; + for(auto& m : typeCtx->members) + { + IRStructMember newMember; + if(m->label) + newMember.id = Some(m->label.value()->text); + else + newMember.id = None(); + auto typeRes = resolveType(m->type.value()); + if(!typeRes) + return typeRes; + newMember.type = typeRes.ok(); + members.push_back(newMember); + } + return resolveType(members); + } + + Result resolveType(const std::vector& members) { - case NodeType::Pipeline: + auto& structs = currentModule.value()->structs; + for(size_t i = 0; i < structs.size(); ++i) + { + auto& s = structs[i]; + if(s.id.isSome()) + continue; + if(s.members.size() != members.size()) + continue; + bool membersMatch = true; + for(size_t j = 0; j < members.size(); ++j) { - auto pipe = std::make_shared(); + auto& a = s.members[j]; + auto& b = members[j]; + if(a.id != b.id || a.type != b.type) + { + membersMatch = false; + break; + } + } + + if(membersMatch) + return Ok((int32_t)i); + } - TSNode idNode = ts_node_child_by_field_name(node, "id", 2); - Identifier identifier{ - doc.nodeToString(idNode) + IRType newId = (int32_t)structs.size(); + structs.push_back(IRStruct{.id = None(), .members = std::move(members)}); + + return Ok(newId); + } + + Result resolveType(Node typeCtx) 
+ { + for(auto mod : typeCtx->modifiers) + { + switch(mod) + { + case TypeModifiers::MutRef: + case TypeModifiers::ConstRef: + return Ok(IRNativeType::I32); // Since this is a reference type, treat it as an int + default: + assert(false && "Unhandled type modifier! This is a compiler bug!"); + return Err(CompilerMessage{.type = CompilerMessageType::Error, + .source = sourceFor(typeCtx), + .message = "Unhandled type modifier! This is a compiler bug!"}); + } + } + + // See if this is a native type, otherwise it's a struct or typedef + if(typeCtx->baseType->scopes.size() == 1) + { + if(auto id = std::get_if>(&typeCtx->baseType->scopes[0])) + { + static std::unordered_map nativeTypes = { + {"u8", IRNativeType::U8}, + {"i8", IRNativeType::I8}, + {"u16", IRNativeType::U16}, + {"i16", IRNativeType::I16}, + {"u32", IRNativeType::U32}, + {"i32", IRNativeType::I32}, + {"f32", IRNativeType::F32}, + {"u64", IRNativeType::U64}, + {"i64", IRNativeType::I64}, + {"f64", IRNativeType::F64}, + {"I128", IRNativeType::I128}, }; - - pipe->name = identifier.name; + auto nativeType = nativeTypes.find((*id)->text); + if(nativeType != nativeTypes.end()) + return Ok(nativeType->second); + } + } + + // Find identifier + auto typeNode = typeCtx->searchFor(typeCtx->baseType, 0); + if(!typeNode) + { + return Err(CompilerMessage{.type = CompilerMessageType::Error, + .source = sourceFor(typeCtx), + .message = std::format("{} was not found", typeCtx->baseType->longId())}); + } + + return MATCHV(typeNode.value(), + [&](Node& ctx) -> Result + { + std::vector members; + std::string id = ctx->longId(); + auto& structs = currentModule.value()->structs; + for(int i = 0; i < structs.size(); ++i) + { + auto& s = structs[i]; + if(s.id && s.id.value() == id) + return Ok(i); + } + + auto newStruct = compileStruct(ctx); + if(!newStruct) + return Err( + CompilerMessage{.type = CompilerMessageType::Error, + .source = sourceFor(typeCtx), + .message = std::format("{} is not a type", 
typeCtx->baseType->longId())}); + return Ok(newStruct.value()); + }, + [&](auto& notSupported) -> Result + { + return Err(CompilerMessage{.type = CompilerMessageType::Error, + .source = sourceFor(typeCtx), + .message = std::format("{} is not a type", typeCtx->baseType->longId())}); + }); + } + + std::unordered_map, IRValue> localVarLookup; + + void appendOp(IROperation op) { currentFunction.value()->operations.push_back(op); } + + IRType getValueType(IRValue value) { return currentFunction.value()->localVars[value.id]; } + + IRValue allocValue(IRType type) + { + IRValue id{.id = (uint32_t)currentFunction.value()->localVars.size()}; + currentFunction.value()->localVars.push_back(type); + return id; + } - std::get>(_identifers["global"])->pipelines.insert({identifier.name, pipe}); - currentScope.scopes.push_back(std::move(identifier)); + Option compileAnonStruct(Node ctx) + { + std::vector memberValues; + for(auto& input : ctx->members) + { + // TODO validate argument names + auto argExpr = compileExpression(input->expression); + if(!argExpr) + { + recordError("expected value", ctx); + return None(); + } + memberValues.push_back(argExpr.value()); + } + + std::vector memberDefs; + memberDefs.reserve(ctx->members.size()); + for(size_t i = 0; i < ctx->members.size(); ++i) + { + memberDefs.push_back( + IRStructMember{.id = Some(ctx->members[i]->id->text), .type = getValueType(memberValues[i])}); + } + + Result structType = resolveType(memberDefs); + if(!structType) + { + recordError("could not resolve anon struct type", ctx); + return None(); + } + // Alloc a struct, but the runtime will give us rootPtr as an I32 pointer + IRValue rootPtr = allocValue(structType.ok()); + for(size_t i = 0; i < memberValues.size(); ++i) + { + MemAcc memberAccess{}; + memberAccess.ptr = rootPtr; + memberAccess.dest = allocValue(IRNativeType::I32); + memberAccess.index = (uint8_t)i; + appendOp(memberAccess); + + MovOp mov{}; + mov.dest = memberAccess.dest; + mov.src = memberValues[i]; + 
appendOp(mov); + } + return Some(rootPtr); + } + + Option compileExpression(ExpressionContextNode expression) + { + return MATCHV(expression, + [&](Node& ctx) -> Option + { + recordError(ctx->message, ctx); + return None(); + }, + [&](Node& ctx) -> Option { return compileScope(ctx); }, + [&](Node& ctx) -> Option + { + // TODO make variable referenceable only after it's defined + auto variable = localVarLookup.find(ctx->definedValue); + if(variable == localVarLookup.end()) + { + recordError("(compiler error) could not find " + ctx->definedValue->label.value()->text, ctx); + return None(); + } + return Some(variable->second); + }, + [&](Node& ctx) -> Option + { + recordError("if expressions not implemented", ctx); + return None(); + }, + [&](Node& ctx) -> Option + { + recordError("while expressions not implemented", ctx); + return None(); + }, + [&](Node& ctx) -> Option + { + recordError("for expressions not implemented", ctx); + return None(); + }, + [&](Node& ctx) -> Option + { + auto destRes = compileExpression(ctx->lValue); + if(!destRes) + return None(); + auto destV = destRes.value(); + + auto srcRes = compileExpression(ctx->rValue); + if(!srcRes) + return None(); + auto srcV = srcRes.value(); + + MovOp mov{}; + mov.dest = destV; + mov.src = srcV; + appendOp(mov); + + return destRes; + }, + [&](Node& ctx) -> Option + { + recordError("Const values not implemented", ctx); + return None(); + }, + [&](Node& ctx) -> Option + { + auto searchRes = ctx->searchFor(ctx, 0); + if(!searchRes) + { + recordError(std::format("{} not found", ctx->longId()), ctx); + return None(); + } - _identifers.insert({currentScope.longName(), pipe}); + auto valueCast = std::get_if>(&searchRes.value()); + if(!valueCast) + { + recordError(std::format("{} is not a value", + MATCHV(searchRes.value(), [](const auto& ctx) { return ctx->longId(); })), + ctx); + return None(); + } + auto localVarRes = localVarLookup.find(*valueCast); + if(localVarRes == localVarLookup.end()) + { + 
recordError(std::format("{} is not in scope", valueCast->get()->longId()), ctx); + return None(); + } + // TODO re-evaluate this when we implement struct member access + return Some(localVarRes->second); + }, + [&](Node& ctx) -> Option + { + recordError("member access expressions not implemented", ctx); + return None(); + }, + [&](Node& ctx) -> Option + { + recordError("reference semantics not implemented", ctx); + return None(); + }, + [&](Node& ctx) -> Option + { + recordError("reference semantics not implemented", ctx); + return None(); + }, + [&](Node& ctx) -> Option + { + recordError("unary operators not implemented", ctx); + return None(); + }, + [&](Node& ctx) -> Option + { + auto leftRes = compileExpression(ctx->left); + if(!leftRes) + return None(); + auto leftV = leftRes.value(); + auto leftT = getValueType(leftV); + + auto rightRes = compileExpression(ctx->right); + if(!rightRes) + return None(); + auto rightV = rightRes.value(); + auto rightT = getValueType(rightV); + + switch(ctx->opType) + { + case BinaryOperator::Add: + { + auto op = AddOp{}; + op.left = leftV; + op.right = rightV; + op.out = allocValue(leftT); + appendOp(op); + return Some(op.out); + } + case BinaryOperator::Sub: + { + auto op = SubOp{}; + op.left = leftV; + op.right = rightV; + op.out = allocValue(leftT); + appendOp(op); + return Some(op.out); + } + break; + case BinaryOperator::Mul: + { + auto op = SubOp{}; + op.left = leftV; + op.right = rightV; + op.out = allocValue(leftT); + appendOp(op); + return Some(op.out); + } + break; + case BinaryOperator::Div: + { + auto op = DivOp{}; + op.left = leftV; + op.right = rightV; + op.out = allocValue(leftT); + appendOp(op); + return Some(op.out); + } break; + case BinaryOperator::Mod: + case BinaryOperator::Equal: + case BinaryOperator::NotEqual: + case BinaryOperator::Greater: + case BinaryOperator::GreaterEqual: + case BinaryOperator::Less: + case BinaryOperator::LessEqual: + case BinaryOperator::LogicAnd: + case BinaryOperator::LogicOr: + 
case BinaryOperator::BitwiseAnd: + case BinaryOperator::BitwiseOr: + case BinaryOperator::BitwiseXOr: + case BinaryOperator::BitshiftLeft: + case BinaryOperator::BitshiftRight: + recordError("operator type not implemented", ctx); + return None(); + } + + return None(); + }, + [&](Node& ctx) -> Option { return compileAnonStruct(ctx); }, + [&](Node& ctx) -> Option + { + auto idCast = std::get_if>(&ctx->callable); + if(!idCast) + { + recordError(std::format("{} is not an identifier", + MATCHV(ctx->callable, [](const auto& ctx) { return ctx->longId(); })), + ctx); + return None(); + } + + if((*idCast)->longId() == "continue") + { + NextStageOp ns{}; + auto inputStruct = compileAnonStruct(ctx->args); + if(!inputStruct) + { + recordError("Invalid call sig", ctx->args); + return None(); + } + ns.input = inputStruct.value(); + appendOp(ns); + return None(); + } + CallOp call{}; + auto searchRes = ctx->searchFor(*idCast, 0); + if(!searchRes) + { + recordError(std::format("{} not found", (*idCast)->longId()), ctx); + return None(); + } + + auto functionCast = std::get_if>(&searchRes.value()); + if(!functionCast) + { + recordError(std::format("{} is not a function", + MATCHV(searchRes.value(), [](const auto& ctx) { return ctx->longId(); })), + ctx); + return None(); + } + + call.function = functionCast->get()->longId(); + auto retType = resolveType(functionCast->get()->callSig->output); + if(!retType) + return None(); + call.output = allocValue(retType.ok()); + + auto inputStruct = compileAnonStruct(ctx->args); + if(!inputStruct) + { + recordError("Invalid call sig", ctx->args); + return None(); } - default: - break; + call.input = inputStruct.value(); + + // TODO account for unit returns + appendOp(call); + return Some(call.output); + }); } - } - for (int i = 0; i < ts_node_named_child_count(node); i++) { - indexSymbolsPass_recursive(ts_node_named_child(node, i), currentScope, doc); - } - currentScope.scopes.resize(parentScopeLength); -} + Option compileScope(Node scope) + { 
+ for(auto& localVar : scope->localVariables) + { + IRValue localId{(uint32_t)currentFunction.value()->localVars.size()}; + auto typeRes = resolveType(localVar->type.value()); + if(!typeRes) + return None(); + currentFunction.value()->localVars.push_back(typeRes.ok()); + localVarLookup.insert({localVar, localId}); + } -void Compiler::generateIRPass() -{ - for(auto& [modId, mod]: _modules) - { - - // GeneratePipelines defined in our sources - for(auto& [pipeId, pipe]: mod->pipelines) + Option lastExpression; + for(auto& expr : scope->expressions) + lastExpression = compileExpression(expr); + + return lastExpression; + } + + Option + compileFunction(Option identifier, Node callSig, Node body) { - if (!pipe->debugInfo.has_value()) - continue; - if (!_sources.contains(pipe->name)) - continue; - + int32_t funcId = (int32_t)currentModule.value()->functions.size(); + currentModule.value()->functions.push_back(IRFunction{}); + auto& func = currentModule.value()->functions[funcId]; + currentFunction = Some(&func); + func.id = identifier ? 
identifier.value() : std::format("-f{}", funcId); + + auto inType = resolveType(callSig->input); + if(!inType) + { + messages.push_back(inType.err()); + return None(); + } + else + func.input = std::get(inType.ok()); + auto outType = resolveType(callSig->output); + if(!outType) + { + messages.push_back(outType.err()); + return None(); + } + else + func.output = std::get(outType.ok()); + + localVarLookup.clear(); + for(auto& arg : callSig->input->members) + { + IRValue localId{(uint32_t)currentFunction.value()->localVars.size()}; + auto typeRes = resolveType(arg->type.value()); + if(!typeRes) + return None(); + currentFunction.value()->localVars.push_back(typeRes.ok()); + localVarLookup.insert({arg, localId}); + } + + compileScope(body); + + currentFunction = None(); + return Some(funcId); } + void compilePipeline(Node ctx) + { + int32_t pipeId = (int32_t)currentModule.value()->pipelines.size(); + currentModule.value()->pipelines.push_back(IRPipeline{.id = ctx->identifier->text}); + auto& pipe = currentModule.value()->pipelines[pipeId]; + currentPipeline = Some(&pipe); + + + auto inType = resolveType(ctx->callSig->input); + if(!inType) + { + messages.push_back(inType.err()); + return; + } + else + pipe.input = std::get(inType.ok()); + auto outType = resolveType(ctx->callSig->output); + if(!outType) + { + messages.push_back(outType.err()); + return; + } + else + pipe.output = std::get(outType.ok()); + + std::vector> stageFunctions; + for(auto& stage : ctx->stages) + { + if(stage->callSig->output->members.size() > 0) + { + recordError("Pipeline stages cannot return a value", stage->callSig); + return; + } + + auto res = compileFunction(None(), stage->callSig, stage->body); + if(!res) + return; + pipe.stages.push_back(res.value()); + } + + currentPipeline = None(); + } + + Option compileModule(Node ctx) + { + auto mod = IRModule{.id = ctx->identifier->text}; + currentModule = Some(&mod); + + for(auto& structCtx : ctx->structs) + { + // Do explicit struct things + } + + 
for(auto& funcCtx : ctx->functions) + { + // Do explicit function things + } + + for(auto& pipeCtx : ctx->pipelines) + compilePipeline(pipeCtx.second); + + + auto out = Some(std::move(*currentModule.value())); + currentModule = None(); + return out; + } + + Option> run(Node ctx) + { + std::vector modules; + + for(auto& mod : ctx->modules) + { + auto res = compileModule(mod.second); + if(res) + modules.push_back(std::move(res.value())); + } + + return Some(std::move(modules)); + } + }; + + Option> Compiler::compile(const std::vector>& documents) + { + DocumentCombinerPass combinerPass; + DocumentToIRPass toIRPass; + + auto compilerContext = combinerPass.run(documents); + _messages.insert(_messages.begin(), combinerPass.messages.begin(), combinerPass.messages.end()); + if(!compilerContext) + return None(); + auto modules = toIRPass.run(compilerContext.value()); + _messages.insert(_messages.begin(), toIRPass.messages.begin(), toIRPass.messages.end()); + return modules; } -} + const std::vector& Compiler::messages() const { return _messages; } +} // namespace BraneScript diff --git a/src/compiler/compiler.h b/src/compiler/compiler.h index 8ba1d61..78aaf4d 100644 --- a/src/compiler/compiler.h +++ b/src/compiler/compiler.h @@ -4,70 +4,47 @@ #include #include "../ir/ir.h" #include "../parser/documentParser.h" +#include "enums/result.h" #include -/// List of pipelines and functions provided by the runtime that we are compiling for -struct EnvDefs +namespace BraneScript { - std::unordered_map pipelines; - std::unordered_map functions; -}; - -enum class CompilerMessageType -{ - Critical = 0, - Error = 1, - Warning = 2, - Log = 3, - Verbose = 4, -}; - -struct CompilerFileSource -{ - std::string path; - std::optional range; -}; - -using CompilerSource = std::variant; - -struct CompilerMessage -{ - CompilerMessageType type; - CompilerSource source; - std::string message; -}; - -struct CompileResult -{ - std::vector modules; - std::vector messages; -}; - - -using Identifiable = 
std::variant, std::shared_ptr, std::shared_ptr, std::shared_ptr>; - -class Compiler -{ - SymbolLookupTable _lut; - std::optional _env; - std::unordered_map> _sources; - std::unordered_map> _modules; - std::unordered_map _identifers; - - - void indexSymbolsPass_recursive(TSNode node, ScopedIdentifier& currentScope, const ParsedDocument& doc); - - void indexSymbolsPass(); - void constructGenericsPass(); - void generateIRPass(); - - void generatePipeline(std::shared_ptr pipe); - - void recordMessage(CompilerMessage message); -public: - - Compiler(SymbolLookupTable lut); - CompileResult compile(const std::vector>& documents); -}; + enum class CompilerMessageType + { + Critical = 0, + Error = 1, + Warning = 2, + Log = 3, + Verbose = 4, + }; + + struct CompilerSource + { + std::string uri; + Option range; + }; + + struct CompilerMessage + { + CompilerMessageType type; + CompilerSource source; + std::string message; + }; + + template + using CompileResult = Result, std::vector>; + + using Identifiable = std::variant, IRNode, IRNode>; + + class Compiler + { + std::vector _messages; + + public: + Option> compile(const std::vector>& documents); + + const std::vector& messages() const; + }; +} // namespace BraneScript #endif diff --git a/src/ir/CMakeLists.txt b/src/ir/CMakeLists.txt index b29387f..3bc0754 100644 --- a/src/ir/CMakeLists.txt +++ b/src/ir/CMakeLists.txt @@ -1,2 +1,3 @@ -add_library(ir STATIC nodes.cpp) +add_library(ir STATIC nodes.cpp ir.cpp irTextSerializer.cpp) +target_include_directories(ir PUBLIC ${braneutilities_SOURCE_DIR}/src) diff --git a/src/ir/ir.h b/src/ir/ir.h index d86ff20..283aab1 100644 --- a/src/ir/ir.h +++ b/src/ir/ir.h @@ -7,6 +7,8 @@ #include #include +#include "enums/option.h" + namespace BraneScript { @@ -18,7 +20,7 @@ namespace BraneScript /// symbols, and negative representing external symbols, and 0 being a null value using IDRef = std::variant; - enum class BSBaseType + enum class IRNativeType { U8, I8, @@ -34,111 +36,32 @@ namespace 
BraneScript I128 }; - struct BSStructType - { - IDRef structId; - }; - - struct BSRefType; - using BSType = std::variant, IRNode>; - - struct BSRefType - { - BSType contained; - bool valueMutable; - }; + using IRType = std::variant; struct IRValue { - uint32_t id; + uint32_t id = 0xFFFFFFFF; }; - struct MovOp; - struct LoadOp; - struct StoreOp; - - struct AddOp; - struct SubOp; - struct MulOp; - struct DivOp; - struct ModOp; - - struct EqOp; - struct NeOp; - struct GtOp; - struct GeOp; - - struct LogicNotOp; - struct LogicAndOp; - struct LogicOrOp; - - struct BitNotOp; - struct BitAndOp; - struct BitOrOp; - struct BitXorOp; - - struct I32ToF32; - struct U32ToF32; - struct F32ToI32; - struct U32ToF32; - - struct ConstI32; - struct ConstU32; - struct ConstF32; - - struct CallOp; - using Operation = std::variant, - IRNode, - IRNode, - - IRNode, - IRNode, - IRNode, - IRNode, - IRNode, - - IRNode, - IRNode, - IRNode, - IRNode, - - IRNode, - IRNode, - IRNode, - - IRNode, - IRNode, - IRNode, - IRNode, - - IRNode, - IRNode, - IRNode, - - IRNode, - IRNode, - IRNode>; - - - using AsyncOperation = std::variant< - - >; - struct ConstI32 { + IRValue dest; int32_t value; }; struct ConstU32 { + IRValue dest; uint32_t value; }; struct ConstF32 { + IRValue dest; float value; }; + // Move (and cast if needed) one values data to another struct MovOp { IRValue src; @@ -147,15 +70,20 @@ namespace BraneScript struct LoadOp { - std::variant store = ConstU32{0}; - IRValue src; + IRValue ptr; IRValue dest; }; struct StoreOp { - std::variant store = ConstU32{0}; + IRValue ptr; IRValue src; + }; + + struct MemAcc + { + uint8_t index; + IRValue ptr; IRValue dest; }; @@ -236,66 +164,85 @@ namespace BraneScript { }; - struct I32ToF32 : public UnaryOp + struct CallOp { + IDRef function; + IRValue input; + IRValue output; }; - struct U32ToF32 : public UnaryOp + struct NextStageOp { + IRValue input; }; - struct F32ToI32 : public UnaryOp - { - }; + using IROperation = std::variant; + + struct 
IRStructMember { - IDRef function; - std::vector inputs; - std::vector outputs; + Option id; + IRType type; }; - struct BSPipelineStage + struct IRStruct { - std::vector localVars; - std::vector operations; - std::vector asyncOps; + Option id; + std::vector members; }; - struct BSStruct + struct IRPipeline { std::string id; - std::vector members; + IDRef input; + IDRef output; + std::vector stages; }; - struct BSPipeline + struct IRFunction { std::string id; - std::vector inputs; - std::vector outputs; - std::optional> stages; + std::vector localVars; + IDRef input; + IDRef output; + std::vector operations; }; - struct BSFunction + struct IRModule { std::string id; - std::vector localVars; - std::vector inputs; - std::vector outputs; - std::vector operations; + std::vector structs; + std::vector functions; + std::vector pipelines; }; - struct BSModule - { - std::string name; - std::vector> structs; - std::vector> functions; - std::vector> pipelines; - }; } // namespace BraneScript diff --git a/src/ir/irTextSerializer.cpp b/src/ir/irTextSerializer.cpp new file mode 100644 index 0000000..59ac2b2 --- /dev/null +++ b/src/ir/irTextSerializer.cpp @@ -0,0 +1,218 @@ +#include "irTextSerializer.h" +#include + +namespace BraneScript::IRSerializer +{ + std::string serializeNativeType(IRNativeType type) + { + switch(type) + { + case IRNativeType::U8: + return "u8"; + case IRNativeType::I8: + return "i8"; + case IRNativeType::U16: + return "u16"; + case IRNativeType::I16: + return "i16"; + case IRNativeType::U32: + return "u32"; + case IRNativeType::I32: + return "i32"; + case IRNativeType::F32: + return "f32"; + case IRNativeType::U64: + return "u64"; + case IRNativeType::I64: + return "i64"; + case IRNativeType::F64: + return "f64"; + case IRNativeType::U128: + return "u128"; + case IRNativeType::I128: + return "i128"; + } + return "\"Parse write error\""; + }; + + std::string serializeIDRef(const IDRef& id) + { + return MATCHV(id, + [](const std::string& idStr) { return 
std::format("\"{}\"", idStr); }, + [](int32_t idIndex) { return std::format("#{}", idIndex); }); + } + + std::string serializeType(const IRType& type) + { + return MATCHV(type, + [](const IRNativeType& type) { return serializeNativeType(type); }, + [](const IDRef& type) { return serializeIDRef(type); }); + }; + + std::string serializeIRValue(const IRValue& value) { return std::format("${}", value.id); } + + std::string serializeStoreVariant(const std::variant& store) + { + return MATCHV(store, + [](const IRValue& val) { return serializeIRValue(val); }, + [](const ConstU32& val) { return std::format("{}", val.value); }); + } + + std::string serializeOp(const IROperation& op) + { + return MATCHV( + op, + [](const MovOp& mov) + { return std::format("(mov {} {})", serializeIRValue(mov.src), serializeIRValue(mov.dest)); }, + [](const LoadOp& load) + { return std::format("(load {} {})", serializeIRValue(load.ptr), serializeIRValue(load.dest)); }, + [](const StoreOp& store) + { return std::format("(store {} {})", serializeIRValue(store.ptr), serializeIRValue(store.src)); }, + [](const MemAcc& memAcc) + { + return std::format( + "(ma #{} {} {})", memAcc.index, serializeIRValue(memAcc.ptr), serializeIRValue(memAcc.dest)); + }, + [](const ConstI32& c) { return std::format("(const.i32 {} {})", serializeIRValue(c.dest), c.value); }, + [](const ConstU32& c) { return std::format("(const.u32 {} {})", serializeIRValue(c.dest), c.value); }, + [](const ConstF32& c) { return std::format("(const.f32 {} {})", serializeIRValue(c.dest), c.value); }, + [](const AddOp& op) + { + return std::format( + "(add {} {} {})", serializeIRValue(op.left), serializeIRValue(op.right), serializeIRValue(op.out)); + }, + [](const SubOp& op) + { + return std::format( + "(sub {} {} {})", serializeIRValue(op.left), serializeIRValue(op.right), serializeIRValue(op.out)); + }, + [](const MulOp& op) + { + return std::format( + "(mul {} {} {})", serializeIRValue(op.left), serializeIRValue(op.right), 
serializeIRValue(op.out)); + }, + [](const DivOp& op) + { + return std::format( + "(div {} {} {})", serializeIRValue(op.left), serializeIRValue(op.right), serializeIRValue(op.out)); + }, + [](const ModOp& op) + { + return std::format( + "(mod {} {} {})", serializeIRValue(op.left), serializeIRValue(op.right), serializeIRValue(op.out)); + }, + [](const EqOp& op) + { + return std::format( + "(eq {} {} {})", serializeIRValue(op.left), serializeIRValue(op.right), serializeIRValue(op.out)); + }, + [](const GeOp& op) + { + return std::format( + "(ge {} {} {})", serializeIRValue(op.left), serializeIRValue(op.right), serializeIRValue(op.out)); + }, + [](const LogicNotOp& op) + { return std::format("(logic.not {} {})", serializeIRValue(op.in), serializeIRValue(op.out)); }, + [](const LogicAndOp& op) + { + return std::format("(logic.and {} {} {})", + serializeIRValue(op.left), + serializeIRValue(op.right), + serializeIRValue(op.out)); + }, + [](const LogicOrOp& op) + { + return std::format( + "(logic.or {} {} {})", serializeIRValue(op.left), serializeIRValue(op.right), serializeIRValue(op.out)); + }, + [](const BitNotOp& op) + { + return std::format( + "(bit.not {} {} {})", serializeIRValue(op.left), serializeIRValue(op.right), serializeIRValue(op.out)); + }, + [](const BitAndOp& op) + { + return std::format( + "(bit.and {} {} {})", serializeIRValue(op.left), serializeIRValue(op.right), serializeIRValue(op.out)); + }, + [](const BitOrOp& op) + { + return std::format( + "(bit.or {} {} {})", serializeIRValue(op.left), serializeIRValue(op.right), serializeIRValue(op.out)); + }, + [](const BitXorOp& op) + { + return std::format( + "(bit.xor {} {} {})", serializeIRValue(op.left), serializeIRValue(op.right), serializeIRValue(op.out)); + }, + [](const CallOp& call) + { + return std::format("(call {}{} {})", + serializeIDRef(call.function), + serializeIRValue(call.input), + serializeIRValue(call.output)); + }, + [](const NextStageOp& call) { return std::format("(stage.next {})", 
serializeIRValue(call.input)); } + + ); + } + + Result irToText(const IRModule& module) + { + std::string structs; + for(size_t i = 0; i < module.structs.size(); ++i) + { + auto& s = module.structs[i]; + std::string id = s.id ? s.id.value() : std::format("-s{}", i); + + std::string members; + for(size_t m = 0; m < s.members.size(); ++m) + { + std::string mid = s.members[m].id ? s.members[m].id.value() : std::format("-{}", i); + + members += std::format("(\"{}\" {})", mid, serializeType(s.members[m].type)); + } + + structs += std::format("\n(struct \"{}\" {})", id, members); + } + + std::string functions; + for(size_t i = 0; i < module.functions.size(); ++i) + { + auto& f = module.functions[i]; + + std::string input = serializeType(f.input); + std::string output = serializeType(f.output); + + std::string vars; + for(auto& var : f.localVars) + { + vars += " " + serializeType(var); + } + + std::string ops; + for(auto& op : f.operations) + { + ops += "\n\t" + serializeOp(op); + } + + functions += std::format("\n(func \"{}\" {} {} (vars{}) (ops{}))", f.id, input, output, vars, ops); + } + + std::string pipelines; + for(size_t i = 0; i < module.pipelines.size(); ++i) + { + auto& p = module.pipelines[i]; + + std::string input = serializeType(p.input); + std::string output = serializeType(p.output); + + std::string stages; + for(auto& stage : p.stages) + stages += " " + serializeIDRef(stage); + + pipelines += std::format("\n(pipe \"{}\" {} {} (stages{}))", p.id, input, output, stages); + } + return Ok(std::format("(module \"{}\" {} {} {})", module.id, structs, functions, pipelines)); + } +} // namespace BraneScript::IRSerializer diff --git a/src/ir/irTextSerializer.h b/src/ir/irTextSerializer.h new file mode 100644 index 0000000..6ba501b --- /dev/null +++ b/src/ir/irTextSerializer.h @@ -0,0 +1,9 @@ +#pragma once +#include "enums/result.h" +#include "ir.h" + +namespace BraneScript::IRSerializer +{ + Result irToText(const IRModule& module); + Result textToIR(std::string_view 
text); +} // namespace BraneScript::IRSerializer diff --git a/src/parser/CMakeLists.txt b/src/parser/CMakeLists.txt index 0cc2f97..a5f9bb5 100644 --- a/src/parser/CMakeLists.txt +++ b/src/parser/CMakeLists.txt @@ -1,9 +1,19 @@ find_package(unofficial-tree-sitter CONFIG REQUIRED) -#find_package(TreeSitterBraneScript REQUIRED) add_library(parser STATIC documentParser.cpp documentContext.cpp ) -target_link_libraries(parser PUBLIC unofficial::tree-sitter::tree-sitter tree-sitter-branescript types) + +target_link_libraries(parser PUBLIC unofficial::tree-sitter::tree-sitter tree-sitter-branescript types utilities) +target_include_directories(parser PRIVATE ${treesitterbranescript_SOURCE_DIR}/src) +target_include_directories(parser PUBLIC ${treesitterbranescript_SOURCE_DIR}/bindings/c) +target_include_directories(parser PUBLIC ${braneutilities_SOURCE_DIR}/src) + +if(MINGW) + target_compile_options(parser PRIVATE -Wmissing-field-initializers) + if(${CMAKE_BUILD_TYPE} STREQUAL "Debug") + target_compile_options(parser PRIVATE -g -Og) + endif() +endif(MINGW) diff --git a/src/parser/documentContext.cpp b/src/parser/documentContext.cpp index c840718..2234c8d 100644 --- a/src/parser/documentContext.cpp +++ b/src/parser/documentContext.cpp @@ -1,28 +1,41 @@ - #include "documentContext.h" #include +#include +#include "enums/matchv.h" namespace BraneScript { - std::optional TextContext::getNodeAtChar(TSPoint pos) { return std::nullopt; } + Option TextContext::getNodeAtChar(TSPoint pos) { return None(); } - std::optional TextContext::findIdentifier(std::string_view identifier) + Option TextContext::searchFor(Node identifier) { - return findIdentifier(identifier, 0); + return searchFor(std::move(identifier), 0); } - std::optional TextContext::findIdentifier(std::string_view identifier, uint8_t searchOptions) + Option TextContext::searchFor(Node identifier, size_t scope) { - return std::nullopt; + if(scope > 0) + return None(); + if(auto p = getParent()) + return 
p.value()->searchFor(std::move(identifier), scope); + return None(); } std::string TextContext::longId() const { return ""; } - std::string ValueContext::signature() const + std::string scopeSegementId(const ScopeSegment& segment) { - assert(false && "Unimplemented"); - std::string sig; - return sig; + return MATCHV(segment, [&](Node id) { return id->text; }); + } + + std::string ScopedIdentifier::longId() const + { + if(scopes.empty()) + return ""; + std::string id = scopeSegementId(scopes[0]); + for(size_t i = 1; i < scopes.size(); ++i) + id += "::" + scopeSegementId(scopes[i]); + return id; } std::string ValueContext::longId() const @@ -30,7 +43,7 @@ namespace BraneScript std::string id; if(parent) { - if(auto p = parent->lock()) + if(auto p = parent.value().lock()) id += p->longId() + "::"; } if(label) @@ -40,43 +53,177 @@ namespace BraneScript return id; } - std::optional PipelineContext::findIdentifier(std::string_view identifier, uint8_t searchOptions) + std::string CallSigContext::longId() const { - return std::nullopt; + auto lid = std::format("{} -> {}", input->longId(), output->longId()); + if(auto p = getParent()) + lid = std::format("{}::{}", p.value()->longId(), lid); + return lid; } - std::string PipelineContext::longId() const + Option FunctionContext::searchFor(Node identifier, size_t scope) + { + if(auto id = std::get_if>(&identifier->scopes[scope])) + { + for(auto& var : callSig->input->members) + { + if(!var->label) + continue; + if(**id == *var->label.value()) + return Some(var); + } + } + + if(scope != 0) + return None(); + if(auto p = getParent()) + return p.value()->searchFor(identifier, scope); + return None(); + } + + std::string FunctionContext::longId() const { std::string idText; if(parent) { - if(auto p = parent->lock()) + if(auto p = parent.value().lock()) idText += p->longId() + "::"; } idText += identifier->text; return idText; } - std::optional ModuleContext::getNodeAtChar(TSPoint pos) { return std::nullopt; } + Option 
ScopeContext::searchFor(Node identifier, size_t scope) + { + if(auto id = std::get_if>(&identifier->scopes[scope])) + { + for(auto& var : localVariables) + { + if(!var->label) + continue; + if(**id == *var->label.value()) + return Some(var); + } + } + + if(scope != 0) + return None(); + if(auto p = getParent()) + return p.value()->searchFor(identifier, scope); + return None(); + } + + Option PipelineStageContext::searchFor(Node identifier, size_t scope) + { + if(auto id = std::get_if>(&identifier->scopes[scope])) + { + for(auto& var : callSig->input->members) + { + if(!var->label) + continue; + if(**id == *var->label.value()) + return Some(var); + } + } + + if(scope != 0) + return None(); + if(auto p = getParent()) + return p.value()->searchFor(identifier, scope); + return None(); + } + + Option PipelineContext::searchFor(Node identifier, size_t scope) + { + if(auto id = std::get_if>(&identifier->scopes[scope])) + { + if(**id == *this->identifier) + return Some(std::static_pointer_cast(shared_from_this())); + } + + if(scope != 0) + return None(); + if(auto p = getParent()) + return p.value()->searchFor(identifier, scope); + return None(); + } + + std::string PipelineContext::longId() const + { + std::string idText; + if(parent) + { + if(auto p = parent.value().lock()) + idText += p->longId() + "::"; + } + idText += identifier->text; + return idText; + } - std::optional ModuleContext::findIdentifier(std::string_view identifier, uint8_t searchOptions) + Option StructContext::searchFor(Node identifier, size_t scope) { - return std::nullopt; + if(auto id = std::get_if>(&identifier->scopes[scope])) + { + if(**id == *this->identifier) + return Some(std::static_pointer_cast(shared_from_this())); + } + + if(scope != 0) + return None(); + if(auto p = getParent()) + return p.value()->searchFor(identifier, scope); + return None(); } - std::optional DocumentContext::getNodeAtChar(TSPoint pos) { return std::nullopt; } + Option ModuleContext::getNodeAtChar(TSPoint pos) { return 
None(); } - std::optional DocumentContext::findIdentifier(std::string_view identifier, uint8_t searchOptions) + Option ModuleContext::searchFor(Node identifier, size_t scope) { - return std::nullopt; + // Eventually this will search for generics and traits as well + if(auto res = MATCHV(identifier->scopes[scope], + [&](Node& sid) -> Option + { + if(*sid == *this->identifier) + { + if(identifier->scopes.size() == scope + 1) + return Some(std::static_pointer_cast(shared_from_this())); + // If we match the current scope, but there's more, try to match the rest of the path + scope += 1; + } + + auto s = structs.find(sid->text); + if(s != structs.end()) + return Some(s->second); + + auto f = functions.find(sid->text); + if(f != functions.end()) + return Some(f->second); + + auto p = pipelines.find(sid->text); + if(p != pipelines.end()) + return Some(f->second); + + return None(); + })) + { + return res; + } + + if(scope != 0) + return None(); + if(auto p = getParent()) + return p.value()->searchFor(identifier, scope); + return None(); } + Option DocumentContext::getNodeAtChar(TSPoint pos) { return None(); } + std::string ModuleContext::longId() const { std::string id; if(parent) { - if(auto p = parent->lock()) + if(auto p = parent.value().lock()) id += p->longId() + "::"; } return id + identifier->text; diff --git a/src/parser/documentContext.h b/src/parser/documentContext.h index 7cfa933..bc98dbd 100644 --- a/src/parser/documentContext.h +++ b/src/parser/documentContext.h @@ -3,11 +3,11 @@ #include #include -#include #include #include #include #include "../types/valueType.h" +#include "enums/option.h" #include #include @@ -33,13 +33,13 @@ namespace BraneScript struct BinaryOperatorContext; struct VariableDefinitionContext; struct AssignmentContext; - struct BlockContext; - struct SinkListContext; - struct SourceListContext; + struct AnonStructContext; + struct AnonStructTypeContext; struct CallContext; + struct CallSigContext; + struct ScopeContext; struct 
PipelineStageContext; - struct AsyncExpressionContext; - struct SinkDefContext; + struct MemberInitContext; struct FunctionContext; struct PipelineContext; struct ModuleContext; @@ -52,53 +52,57 @@ namespace BraneScript Node, Node, Node, - Node, - Node, - Node, + Node, + Node, Node, - Node, + Node, + Node, Node, - Node, + Node, Node, Node, + Node, Node, Node>; - enum IDSearchOptions : uint8_t + struct TextSource { - IDSearchOptions_ChildrenOnly = 1 << 0, // Don't search upwards through the tree - IDSearchOptions_ParentsOnly = 1 << 1, // Don't search downwards through the tree + std::string uri; }; struct TextContext : public std::enable_shared_from_this { TSRange range; - std::optional> parent; + Node source; + Option> parent; virtual ~TextContext() = default; - virtual std::optional getNodeAtChar(TSPoint pos); - virtual std::optional findIdentifier(std::string_view identifier); - virtual std::optional findIdentifier(std::string_view identifier, uint8_t searchOptions); + virtual Option getNodeAtChar(TSPoint pos); + + Option searchFor(Node identifier); + /// Searches up the context tree until we find an identifier that matches the first scope of the identifier, + /// Then attempts to search downwards through matching nodes until the full path matches. 
+ virtual Option searchFor(Node identifier, size_t scope); virtual std::string longId() const; template - std::optional> as() + Option> as() { static_assert(std::is_base_of::value, "T must be a subclass of DocumentContext"); - auto* r = std::dynamic_pointer_cast(shared_from_this()); + auto r = std::dynamic_pointer_cast(shared_from_this()); if(!r) - return std::nullopt; - return r; + return None(); + return Some(r); } template - std::optional> as() const + Option> as() const { static_assert(std::is_base_of::value, "T must be a subclass of DocumentContext"); - auto* r = std::dynamic_pointer_cast(shared_from_this()); + auto r = std::dynamic_pointer_cast(shared_from_this()); if(!r) - return std::nullopt; - return r; + return None(); + return Some(r); } template @@ -108,18 +112,18 @@ namespace BraneScript } template - std::optional> getParent() const + Option> getParent() const { if(!parent) - return std::nullopt; - if(auto p = parent->lock()) + return None(); + if(auto p = parent.value().lock()) { auto pt = p->as(); if(pt) return pt; return p->getParent(); } - return std::nullopt; + return None(); } template @@ -145,8 +149,10 @@ namespace BraneScript { std::string text; operator std::string&(); - bool operator==(const Identifier&) const; - bool operator!=(const Identifier&) const; + + inline bool operator==(const Identifier& o) const { return o.text == text; } + + inline bool operator!=(const Identifier& o) const { return o.text != text; } }; enum class TypeModifiers @@ -164,8 +170,8 @@ namespace BraneScript struct ValueContext : public TextContext { // What data does this value store - std::optional> label; - std::optional> type; + Option> label; + Option> type; // Is stored on the heap or the stack, instead of being a temporary value holder bool isLValue = false; @@ -183,7 +189,6 @@ namespace BraneScript /*ValueContext(TypeContext type, bool isLValue, bool isConst, bool isRef);*/ /*ValueContext(std::string label, TypeContext type, bool isLValue, bool isConst, bool 
isRef);*/ /**/ - virtual std::string signature() const; std::string longId() const override; }; @@ -214,6 +219,7 @@ namespace BraneScript struct ScopedIdentifier : public TextContext { std::vector scopes; + std::string longId() const override; }; struct ErrorContext @@ -229,17 +235,36 @@ namespace BraneScript struct ExpressionErrorContext; struct ScopeContext; + struct IfContext; + struct WhileContext; + struct ForContext; + struct AssignmentContext; + struct ConstValueContext; + struct LabeledValueReferenceContext; + struct MemberAccessContext; + struct CreateReferenceContext; + struct DereferenceContext; struct UnaryOperatorContext; struct BinaryOperatorContext; + struct AnonStructContext; + struct CallContext; using ExpressionContextNode = std::variant, Node, + Node, + Node, + Node, + Node, + Node, + Node, + Node, + Node, + Node, + Node, Node, - Node>; - - struct AsyncExpressionContext : public TextContext - { - }; + Node, + Node, + Node>; struct ExpressionErrorContext : public ExpressionContext, ErrorContext { @@ -261,14 +286,23 @@ namespace BraneScript std::vector> localVariables; std::vector expressions; - virtual std::optional findIdentifier(std::string_view identifier, uint8_t searchOptions); + Option searchFor(Node identifier, size_t scope) override; + }; + + struct CallSigContext : public TextContext + { + Node input; + Node output; + std::string longId() const override; }; struct PipelineStageContext : public TextContext { - std::vector> localVariables; - std::vector expressions; - std::vector> asyncExpressions; + Option> identifier; + Node callSig; + Node body; + + Option searchFor(Node identifier, size_t scope) override; }; struct IfContext : public ExpressionContext @@ -380,71 +414,51 @@ namespace BraneScript ExpressionContextNode right; }; - struct BlockContext : public ExpressionContext - { - std::vector expressions; - }; - - struct SourceListContext : public TextContext + struct AnonStructTypeContext : public TextContext { - NodeList defs; + 
NodeList members; }; - struct SinkDefContext : public TextContext + struct MemberInitContext : public TextContext { Node id; ExpressionContextNode expression; }; - struct SinkListContext : public TextContext + struct AnonStructContext : public ExpressionContext { - NodeList values; + NodeList members; }; struct CallContext : public ExpressionContext { - ScopedIdentifier id; - std::vector arguments; - std::vector outputs; - }; - - struct FunctionDescriptionContext : public TextContext - { - Identifier identifier; - Node sources; - Node sinks; - std::string longId() const override; + ExpressionContextNode callable; + Node args; }; struct FunctionContext : public TextContext { - FunctionDescriptionContext description; + Node identifier; + Node callSig; Node body; - std::optional findIdentifier(std::string_view identifier, uint8_t searchOptions) override; - std::string longId() const override; - std::string signature() const; - }; - - struct ImplContext : public TextContext - { - Node type; - LabeledNodeMap methods; + Option searchFor(Node identifier, size_t scope) override; std::string longId() const override; }; struct TraitContext : public TextContext { - Identifier identifier; - NodeList methods; + Node identifier; + LabeledNodeMap methods; std::string longId() const override; }; - struct TraitImplContext : public ImplContext + struct ImplContext : public TextContext { - Identifier trait; + Option> trait; Node type; + LabeledNodeMap methods; std::string longId() const override; }; @@ -452,12 +466,10 @@ namespace BraneScript { Node identifier; // Arguments - Node sources; - Node sinks; - + Node callSig; NodeList stages; - std::optional findIdentifier(std::string_view identifier, uint8_t searchOptions) override; + Option searchFor(Node identifier, size_t scope) override; std::string longId() const override; std::string argSig() const; std::string signature() const; @@ -465,26 +477,26 @@ namespace BraneScript struct StructContext : public TextContext { - Identifier 
identifier; + Node identifier; - NodeList variables; - NodeList functions; + LabeledNodeMap members; bool packed = false; - std::optional getNodeAtChar(TSPoint pos) override; - std::optional findIdentifier(std::string_view identifier, uint8_t searchOptions) override; + Option getNodeAtChar(TSPoint pos) override; + Option searchFor(Node identifier, size_t scope) override; std::string longId() const override; }; struct ModuleContext : public TextContext { Node identifier; - NodeList structs; - NodeList functions; - NodeList pipelines; + LabeledNodeMap structs; + LabeledNodeMap functions; + LabeledNodeMap pipelines; + + Option getNodeAtChar(TSPoint pos) override; - std::optional getNodeAtChar(TSPoint pos) override; - std::optional findIdentifier(std::string_view identifier, uint8_t searchOptions) override; + Option searchFor(Node identifier, size_t scope) override; std::string longId() const override; }; @@ -493,8 +505,7 @@ namespace BraneScript std::filesystem::path source; LabeledNodeMap modules; - std::optional getNodeAtChar(TSPoint pos) override; - std::optional findIdentifier(std::string_view identifier, uint8_t searchOptions) override; + Option getNodeAtChar(TSPoint pos) override; }; } // namespace BraneScript diff --git a/src/parser/documentParser.cpp b/src/parser/documentParser.cpp index 1da305f..ecde826 100644 --- a/src/parser/documentParser.cpp +++ b/src/parser/documentParser.cpp @@ -2,27 +2,28 @@ // Created by WireWhiz on 10/28/2024. 
// -#include "documentParser.h" - #include #include -#include #include #include #include #include #include #include + +#include "documentParser.h" +#include "enums/matchv.h" +#include "enums/option.h" #include "parser/documentContext.h" -#include "tree_sitter_branescript.h" +#include "parserEnums.h" +#include "tree-sitter-branescript.h" #include - #define Expect(node, _optional, _errorMessage) \ if(!_optional) \ { \ errorMessage(node, _errorMessage); \ - return std::nullopt; \ + return None(); \ } namespace BraneScript @@ -34,49 +35,24 @@ namespace BraneScript }; template - bool tryCastVariant(std::variant value, std::optional>& result) + bool tryCastVariant(std::variant value, Option>& result) { - return std::visit(overloads{[&](auto&& inner) -> bool { - if constexpr(((std::is_same_v) || ...)) + return std::visit(overloads{[&](auto&& inner) -> bool + { + if constexpr(((std::is_convertible()) || ...)) { - result = inner; + result = Some(inner); + return false; + } + else + { + result = None(); return false; } - result = std::nullopt; - return false; }}, std::move(value)); } - enum class TSNodeType : uint16_t - { - Unknown = 0, - Number, - Identifier, - ScopedIdentifier, - Type, - TemplateArgument, - TemplateArguments, - Add, - Sub, - Mul, - Div, - Assign, - Block, - VariableDefinition, - SinkDef, - SinkList, - SourceDef, - SourceList, - Call, - PipelineStage, - AsyncOperation, - Function, - Pipeline, - Module, - SourceFile - }; - void foreachNodeChild(TSNode node, const std::function& f) { uint32_t childCount = ts_node_child_count(node); @@ -101,156 +77,25 @@ namespace BraneScript return ts_language_symbol_name(tree_sitter_branescript(), ts_node_symbol(node)); } - class TSSymbolLookupTable - { - private: - std::vector symbolToNodeType; - std::vector nodeTypeToSymbol; - - void add_entry(const TSLanguage* lang, std::string_view ident, TSNodeType nodeType) - { - auto symbol = ts_language_symbol_for_name(lang, ident.data(), ident.size(), true); - if(symbol >= 
symbolToNodeType.size()) - symbolToNodeType.resize(symbol + 1, (TSNodeType)UINT16_MAX); - if((uint16_t)nodeType >= nodeTypeToSymbol.size()) - nodeTypeToSymbol.resize((uint16_t)nodeType + 1, (TSSymbol)UINT16_MAX); - symbolToNodeType[symbol] = nodeType; - nodeTypeToSymbol[(uint16_t)nodeType] = symbol; - } - - public: - TSSymbolLookupTable(const TSLanguage* lang) - { - add_entry(lang, "number", TSNodeType::Number); - add_entry(lang, "module", TSNodeType::Module); - add_entry(lang, "identifier", TSNodeType::Identifier); - add_entry(lang, "scopedIdentifier", TSNodeType::ScopedIdentifier); - add_entry(lang, "type", TSNodeType::Type); - add_entry(lang, "templateArgument", TSNodeType::TemplateArgument); - add_entry(lang, "templateArguments", TSNodeType::TemplateArguments); - add_entry(lang, "add", TSNodeType::Add); - add_entry(lang, "sub", TSNodeType::Sub); - add_entry(lang, "mul", TSNodeType::Mul); - add_entry(lang, "div", TSNodeType::Div); - add_entry(lang, "assign", TSNodeType::Assign); - add_entry(lang, "block", TSNodeType::Block); - add_entry(lang, "variableDefinition", TSNodeType::VariableDefinition); - add_entry(lang, "sinkDef", TSNodeType::SinkDef); - add_entry(lang, "sinkList", TSNodeType::SinkList); - add_entry(lang, "sourceDef", TSNodeType::SourceDef); - add_entry(lang, "sourceList", TSNodeType::SourceList); - add_entry(lang, "call", TSNodeType::Call); - add_entry(lang, "pipelineStage", TSNodeType::PipelineStage); - add_entry(lang, "asyncOperation", TSNodeType::AsyncOperation); - add_entry(lang, "function", TSNodeType::Function); - add_entry(lang, "pipeline", TSNodeType::Pipeline); - add_entry(lang, "sourceFile", TSNodeType::SourceFile); - } - - std::optional tryToNodeType(TSSymbol symbol) const - { - if(symbol > symbolToNodeType.size()) - return std::nullopt; - auto nodeType = symbolToNodeType[symbol]; - return (uint16_t)nodeType == UINT16_MAX ? 
std::nullopt : std::make_optional(nodeType); - } - - std::optional tryToSymbol(TSNodeType nodeType) const - { - if((uint16_t)nodeType > nodeTypeToSymbol.size()) - return std::nullopt; - auto symbol = nodeTypeToSymbol[(uint16_t)nodeType]; - return symbol == UINT16_MAX ? std::nullopt : std::make_optional(symbol); - } - }; - - enum class TSFieldName : uint16_t - { - Id = 0, - Child, - Type, - Value, - Mut, - Defs, - Left, - Right, - Sources, - Sinks, - Stages, - TemplateArgs, - }; - - class TSFieldLookupTable - { - private: - std::vector _fieldIds; - - void add_entry(const TSLanguage* lang, std::string_view textName, TSFieldName fieldName) - { - auto fID = ts_language_field_id_for_name(lang, textName.data(), textName.size()); - - if((uint16_t)fieldName >= _fieldIds.size()) - _fieldIds.resize((uint16_t)fieldName + 1, (TSFieldId)0); - _fieldIds[(uint16_t)fieldName] = fID; - } - - public: - TSFieldLookupTable(const TSLanguage* lang) - { - add_entry(lang, "id", TSFieldName::Id); - add_entry(lang, "child", TSFieldName::Child); - add_entry(lang, "type", TSFieldName::Type); - add_entry(lang, "value", TSFieldName::Value); - add_entry(lang, "mut", TSFieldName::Mut); - add_entry(lang, "defs", TSFieldName::Defs); - add_entry(lang, "left", TSFieldName::Left); - add_entry(lang, "right", TSFieldName::Right); - add_entry(lang, "sinks", TSFieldName::Sinks); - add_entry(lang, "sources", TSFieldName::Sources); - add_entry(lang, "stages", TSFieldName::Stages); - add_entry(lang, "templateArgs", TSFieldName::TemplateArgs); - } - - TSFieldId get(TSFieldName name) const - { - assert((uint16_t)name < _fieldIds.size() && "Make sure all fields are defined in the constructor"); - return _fieldIds[(uint16_t)name]; - } - }; - - const TSSymbolLookupTable symbolLookupTable = TSSymbolLookupTable(tree_sitter_branescript()); - const TSFieldLookupTable fieldLookupTable = TSFieldLookupTable(tree_sitter_branescript()); - - std::string_view typeToName(TSNodeType type) - { - auto symbol = 
symbolLookupTable.tryToSymbol(type); - if(!symbol) - return "Unknown"; - return ts_language_symbol_name(tree_sitter_branescript(), symbol.value()); - } - - TSNodeType nodeType(TSNode node) + std::string_view symbolToName(ts_symbol_identifiers symbolId) { - auto nodeType = symbolLookupTable.tryToNodeType(ts_node_symbol(node)); - if(!nodeType) - return TSNodeType::Unknown; - return *nodeType; + return ts_language_symbol_name(tree_sitter_branescript(), symbolId); } - template - bool nodeIsType(TSNode node) + template + bool nodeIsSymbol(TSNode node) { - auto nt = nodeType(node); + auto nt = ts_node_symbol(node); return ((nt == Types) || ...); } - template - void advanceWhileType(TSNode node, const std::function& f) + template + void advanceWhileSymbol(TSNode node, const std::function& f) { bool nodeIsCorrectType = true; if(!ts_node_is_named(node)) node = ts_node_next_named_sibling(node); - while(!ts_node_is_null(node) && nodeIsType(node)) + while(!ts_node_is_null(node) && nodeIsSymbol(node)) { f(node); node = ts_node_next_named_sibling(node); @@ -284,18 +129,19 @@ namespace BraneScript struct ParserAPI { std::filesystem::path path; - std::string_view source; + std::string_view sourceText; + Node source; std::shared_ptr parser; std::vector messages; TSTree* tree; std::list> scopes; - std::optional> currentScope() + Option> currentScope() { if(scopes.empty()) - return std::nullopt; - return scopes.back(); + return None(); + return Some(scopes.back()); } ScopedScope pushScope(Node node) @@ -324,29 +170,29 @@ namespace BraneScript messages.emplace_back(MessageType::Error, nodeRange(ctx), std::move(message)); } - std::optional getField(TSNode node, TSFieldName field) + Option getField(TSNode node, ts_field_identifiers field) { - auto result = ts_node_child_by_field_id(node, fieldLookupTable.get(field)); + auto result = ts_node_child_by_field_id(node, field); if(ts_node_is_null(result)) - return std::nullopt; - return result; + return None(); + return Some(result); } - bool 
expectNode(TSNode node, TSNodeType type) + bool expectSymbol(TSNode node, ts_symbol_identifiers type) { - auto nodeType = symbolLookupTable.tryToNodeType(ts_node_symbol(node)); + auto nodeType = ts_node_symbol(node); if(nodeType && type == nodeType) return true; std::string message = - std::format("Was expecting \"{}\" but found \"{}\" \n", typeToName(type), nodeText(node)); + std::format("Was expecting \"{}\" but found \"{}\" \n", symbolToName(type), nodeText(node)); errorMessage(node, message); return false; } - bool expectField(TSNode node, const std::vector& types) + bool expectField(TSNode node, const std::vector& types) { - auto nodeType = symbolLookupTable.tryToNodeType(ts_node_symbol(node)); + auto nodeType = ts_node_symbol(node); if(nodeType) { for(auto t : types) @@ -359,7 +205,7 @@ namespace BraneScript std::string message = std::format("Found \"{}\", but was expecting one of the following:\n", nodeText(node)); for(auto t : types) - message += std::format("{}\n", typeToName(t)); + message += std::format("{}\n", symbolToName(t)); errorMessage(node, message); return false; } @@ -368,7 +214,7 @@ namespace BraneScript { auto start = ts_node_start_byte(node); auto end = ts_node_end_byte(node); - return {source.data() + start, end - start}; + return {sourceText.data() + start, end - start}; } template @@ -379,141 +225,144 @@ namespace BraneScript auto new_node = std::make_shared(); new_node->range = nodeRange(context); new_node->parent = currentScope(); + new_node->source = source; return new_node; } - std::optional parse(TSNode node) + Option parse(TSNode node) { - switch(nodeType(node)) + switch((ts_symbol_identifiers)ts_node_symbol(node)) { - case TSNodeType::SourceFile: + case sym_source_file: assert(false && "Cannot call parse on root node!"); - return std::nullopt; - case TSNodeType::Module: + return None(); + case sym_module: return parseModule(node); - case TSNodeType::Pipeline: + case sym_pipeline: return parsePipeline(node); - case 
TSNodeType::Function: + case sym_function: return parseFunction(node); - case TSNodeType::PipelineStage: + case sym_pipelineStage: return parsePipelineStage(node); - case TSNodeType::AsyncOperation: - return parseAsyncOperation(node); - case TSNodeType::Call: + case sym_callSig: + return parseCallSig(node); + case sym_call: return parseCall(node); - case TSNodeType::SourceList: - return parseSourceList(node); - case TSNodeType::SinkList: - return parseSinkList(node); - case TSNodeType::VariableDefinition: + case sym_anonStruct: + return parseAnonStruct(node); + case sym_variableDefinition: return parseVariableDefinition(node); - case TSNodeType::Block: + case sym_block: return parseBlock(node); - case TSNodeType::Assign: + case sym_assign: return parseAssign(node); - case TSNodeType::Div: + case sym_div: return parseDiv(node); - case TSNodeType::Mul: + case sym_mul: return parseMul(node); - case TSNodeType::Sub: + case sym_sub: return parseSub(node); - case TSNodeType::Add: + case sym_add: return parseAdd(node); - /*case TSNodeType::TemplateArguments: - return parseTemplateArguments(node); - case TSNodeType::TemplateArgument: - return parseTemplateArgument(node);*/ - case TSNodeType::Type: + case sym_type: return parseType(node); - case TSNodeType::ScopedIdentifier: + case sym_scopedIdentifier: return parseScopedIdentifier(node); - case TSNodeType::Identifier: + case sym_identifier: return parseIdentifier(node); - case TSNodeType::Number: + case sym_valueDef: + return parseValueDef(node); + case sym_number: return parseNumber(node); - case TSNodeType::SourceDef: - return parseSourceDef(node); - case TSNodeType::SinkDef: - return parseSinkDef(node); - case TSNodeType::Unknown: + case sym_anonStructType: + return parseAnonStructType(node); default: if(ts_node_has_error(node)) { errorMessage(node, std::format("Unexpected \"{}\"", nodeText(node))); - return std::nullopt; + return None(); } - std::cout << "Unknown TSNode type for: " << nodeText(node) << std::endl; - 
assert(false && "TSNodeType unhandled!"); - return std::nullopt; + std::cout << std::format("Parser not implemented for symbol {} could not parse \"{}\"", + nodeName(node), + nodeText(node)) + << std::endl; + return None(); } } - std::optional parseExpression(TSNode node) + Option parseExpression(TSNode node) { auto parsed = parse(node); if(!parsed) - return std::nullopt; + return None(); - std::optional out; - tryCastVariant(*parsed, out); + Option out; + tryCastVariant(parsed.value(), out); + if(!out) + errorMessage(node, std::format("Expected expression, but found: {}", nodeText(node))); return out; } - std::optional> parseNumber(TSNode root) + Option> parseNumber(TSNode root) { std::string_view text = nodeText(root); float value; - auto [ptr, res] = std::from_chars(text.data(), text.data() + text.size(), value); - if(res == std::errc()) - { - errorMessage(root, std::format("Error parsing float: {}", std::make_error_code(res).message())); - return std::nullopt; - } + /*auto [ptr, res] = std::from_chars(text.data(), text.data() + text.size(), value);*/ + /*if(res == std::errc())*/ + /*{*/ + /* errorMessage(root, std::format("Error parsing float: {}", std::make_error_code(res).message()));*/ + /* return None();*/ + /*}*/ + value = std::stof((std::string)text); auto constNode = makeNode(root); constNode->value = value; - return constNode; + return Some(constNode); } - std::optional> parseIdentifier(TSNode root) + Option> parseIdentifier(TSNode root) { - if(!expectNode(root, TSNodeType::Identifier)) - return std::nullopt; + if(!expectSymbol(root, sym_identifier)) + return None(); auto ident = makeNode(root); ident->text = nodeText(root); - return ident; + return Some(ident); } - std::optional> parseScopedIdentifier(TSNode root) + Option> parseScopedIdentifier(TSNode root) { - if(!expectNode(root, TSNodeType::ScopedIdentifier)) - return std::nullopt; + if(!expectSymbol(root, sym_scopedIdentifier)) + return None(); auto scopedId = makeNode(root); auto scope = 
pushScope(scopedId); - foreachNamedNodeChild(root, [&](TSNode ident) { + foreachNamedNodeChild(root, + [&](TSNode ident) + { auto id = parseIdentifier(ident); if(!id) return; - scopedId->scopes.emplace_back(*id); + scopedId->scopes.emplace_back(id.value()); }); if(scopedId->scopes.empty()) { errorMessage(root, "Expected Identifier!"); - return std::nullopt; + return None(); } - return scopedId; + return Some(scopedId); } - std::optional> parseType(TSNode root) + Option> parseType(TSNode root) { - if(!expectNode(root, TSNodeType::Type)) - return std::nullopt; + if(!expectSymbol(root, sym_type)) + return None(); auto type = makeNode(root); auto scope = pushScope(type); bool queuedReference = false; - foreachNodeChild(root, [&](TSNode node) { + foreachNodeChild(root, + [&](TSNode node) + { if(!ts_node_is_named(node)) { auto text = nodeText(node); @@ -553,406 +402,450 @@ namespace BraneScript if(!type->baseType) { errorMessage(root, "Missing identifier"); - return std::nullopt; + return None(); } - return type; + return Some(type); } - // std::optional> parseTemplateArguments(TSNode root); - // std::optional> parseTemplateArgument(TSNode root); - std::optional> parseAdd(TSNode root) + // Option> parseTemplateArguments(TSNode root); + // Option> parseTemplateArgument(TSNode root); + Option> parseAdd(TSNode root) { - if(!expectNode(root, TSNodeType::Add)) - return std::nullopt; + if(!expectSymbol(root, sym_add)) + return None(); auto opr = makeNode(root); auto scope = pushScope(opr); opr->opType = BinaryOperator::Add; - auto tsLeft = getField(root, TSFieldName::Left); + auto tsLeft = getField(root, field_left); Expect(root, tsLeft, "Expected lvalue"); - auto leftNode = parseExpression(*tsLeft); + auto leftNode = parseExpression(tsLeft.value()); if(!leftNode) - return std::nullopt; - opr->left = *leftNode; + return None(); + opr->left = leftNode.value(); - auto tsRight = getField(root, TSFieldName::Left); + auto tsRight = getField(root, field_right); Expect(root, tsRight, 
"Expected rvalue"); - auto rightNode = parseExpression(*tsRight); + auto rightNode = parseExpression(tsRight.value()); if(!rightNode) - return std::nullopt; - opr->right = *rightNode; + return None(); + opr->right = rightNode.value(); - return opr; + return Some(opr); } - std::optional> parseSub(TSNode root) + Option> parseSub(TSNode root) { - if(!expectNode(root, TSNodeType::Sub)) - return std::nullopt; + if(!expectSymbol(root, sym_sub)) + return None(); auto opr = makeNode(root); auto scope = pushScope(opr); opr->opType = BinaryOperator::Sub; - auto tsLeft = getField(root, TSFieldName::Left); + auto tsLeft = getField(root, field_left); Expect(root, tsLeft, "Expected lvalue"); - auto leftNode = parseExpression(*tsLeft); + auto leftNode = parseExpression(tsLeft.value()); if(!leftNode) - return std::nullopt; - opr->left = *leftNode; + return None(); + opr->left = leftNode.value(); - auto tsRight = getField(root, TSFieldName::Left); + auto tsRight = getField(root, field_left); Expect(root, tsRight, "Expected rvalue"); - auto rightNode = parseExpression(*tsRight); + auto rightNode = parseExpression(tsRight.value()); if(!rightNode) - return std::nullopt; - opr->right = *rightNode; + return None(); + opr->right = rightNode.value(); - return opr; + return Some(opr); } - std::optional> parseMul(TSNode root) + Option> parseMul(TSNode root) { - if(!expectNode(root, TSNodeType::Mul)) - return std::nullopt; + if(!expectSymbol(root, sym_mul)) + return None(); auto opr = makeNode(root); auto scope = pushScope(opr); opr->opType = BinaryOperator::Mul; - auto tsLeft = getField(root, TSFieldName::Left); + auto tsLeft = getField(root, field_left); Expect(root, tsLeft, "Expected lvalue"); - auto leftNode = parseExpression(*tsLeft); + auto leftNode = parseExpression(tsLeft.value()); if(!leftNode) - return std::nullopt; - opr->left = *leftNode; + return None(); + opr->left = leftNode.value(); - auto tsRight = getField(root, TSFieldName::Left); + auto tsRight = getField(root, 
field_left); Expect(root, tsRight, "Expected rvalue"); - auto rightNode = parseExpression(*tsRight); + auto rightNode = parseExpression(tsRight.value()); if(!rightNode) - return std::nullopt; - opr->right = *rightNode; + return None(); + opr->right = rightNode.value(); - return opr; + return Some(opr); } - std::optional> parseDiv(TSNode root) + Option> parseDiv(TSNode root) { - if(!expectNode(root, TSNodeType::Div)) - return std::nullopt; + if(!expectSymbol(root, sym_div)) + return None(); auto opr = makeNode(root); auto scope = pushScope(opr); opr->opType = BinaryOperator::Div; - auto tsLeft = getField(root, TSFieldName::Left); + auto tsLeft = getField(root, field_left); Expect(root, tsLeft, "Expected lvalue"); - auto leftNode = parseExpression(*tsLeft); + auto leftNode = parseExpression(tsLeft.value()); if(!leftNode) - return std::nullopt; - opr->left = *leftNode; + return None(); + opr->left = leftNode.value(); - auto tsRight = getField(root, TSFieldName::Left); + auto tsRight = getField(root, field_left); Expect(root, tsRight, "Expected rvalue"); - auto rightNode = parseExpression(*tsRight); + auto rightNode = parseExpression(tsRight.value()); if(!rightNode) - return std::nullopt; - opr->right = *rightNode; + return None(); + opr->right = rightNode.value(); - return opr; + return Some(opr); } - std::optional> parseVariableDefinition(TSNode root) + Option> parseVariableDefinition(TSNode root) { - if(!expectNode(root, TSNodeType::VariableDefinition)) - return std::nullopt; + if(!expectSymbol(root, sym_variableDefinition)) + return None(); auto def = makeNode(root); auto scope = pushScope(def); - def->definedValue = makeNode(root); - - auto tsMut = getField(root, TSFieldName::Mut); - def->definedValue->isMut = tsMut.has_value(); - auto tsIdNode = getField(root, TSFieldName::Id); - Expect(root, tsIdNode, "Expected Identifier!"); - auto idNode = parseIdentifier(*tsIdNode); - if(!idNode) - return std::nullopt; - def->definedValue->label = idNode; + auto defField = 
getField(root, field_def); + Expect(root, defField, "Expected variable definition"); + auto valueDef = parseValueDef(defField.value()); + if(!valueDef) + return None(); + def->definedValue = valueDef.value(); - auto tsTypeNode = getField(root, TSFieldName::Type); - if(tsTypeNode) - def->definedValue->type = parseType(*tsTypeNode); + if(auto parentScope = def->getParent()) + parentScope.value()->localVariables.push_back(def->definedValue); - return def; + return Some(def); } - std::optional> parseAssign(TSNode root) + Option> parseAssign(TSNode root) { - if(!expectNode(root, TSNodeType::Assign)) - return std::nullopt; + if(!expectSymbol(root, sym_assign)) + return None(); auto assign = makeNode(root); auto scope = pushScope(assign); - auto tsLeft = getField(root, TSFieldName::Left); + auto tsLeft = getField(root, field_left); Expect(root, tsLeft, "Expected lvalue"); - auto leftNode = parseExpression(*tsLeft); + auto leftNode = parseExpression(tsLeft.value()); if(!leftNode) - return std::nullopt; - assign->lValue = *leftNode; + return None(); + assign->lValue = leftNode.value(); - auto tsRight = getField(root, TSFieldName::Left); + auto tsRight = getField(root, field_right); Expect(root, tsRight, "Expected rvalue"); - auto rightNode = parseExpression(*tsRight); + auto rightNode = parseExpression(tsRight.value()); if(!rightNode) - return std::nullopt; - assign->rValue = *rightNode; + return None(); + assign->rValue = rightNode.value(); - return assign; + return Some(assign); } - std::optional> parseBlock(TSNode root) + Option> parseBlock(TSNode root) { - if(!expectNode(root, TSNodeType::AsyncOperation)) - return std::nullopt; - auto block = makeNode(root); + if(!expectSymbol(root, sym_block)) + return None(); + auto block = makeNode(root); auto scope = pushScope(block); - foreachNamedNodeChild(root, [&](TSNode node) { - auto expr = parseExpression(node); - if(!expr) - return; - block->expressions.push_back(*expr); - }); + auto expressions = getField(root, 
field_expressions); + if(expressions) + { + foreachNamedNodeChild(root, + [&](TSNode node) + { + logMessage(node, std::format("parsing block child: {}", nodeText(node))); + auto expr = parseExpression(node); + if(!expr) + { + errorMessage(node, std::format("Was expecting expression but found: {}", nodeText(node))); + return; + } + block->expressions.push_back(expr.value()); + }); + } - return block; + + return Some(block); } - std::optional> parseSinkDef(TSNode root) + Option> parseMemberInit(TSNode root) { - if(!expectNode(root, TSNodeType::SinkDef)) - return std::nullopt; + if(!expectSymbol(root, sym_memberInit)) + return None(); - auto def = makeNode(root); - auto scope = pushScope(def); + auto member = makeNode(root); + auto scope = pushScope(member); - auto tsIdNode = getField(root, TSFieldName::Id); - Expect(root, tsIdNode, "Expected Identifier"); - auto idNode = parseIdentifier(*tsIdNode); - if(!idNode) - return std::nullopt; - def->id = *idNode; + auto idField = getField(root, field_id); + Expect(root, idField, "Expected identifier"); + auto id = parseIdentifier(idField.value()); + if(!id) + return None(); + member->id = id.value(); - auto tsValueNode = getField(root, TSFieldName::Value); - Expect(root, tsValueNode, "Expected Expression"); - auto valueNode = parseExpression(*tsValueNode); - if(!valueNode) - return std::nullopt; - def->expression = *valueNode; - return def; + auto valueField = getField(root, field_value); + Expect(root, valueField, "Expected expression"); + + auto value = parseExpression(valueField.value()); + + + if(!value) + return None(); + member->expression = value.value(); + + return Some(member); } - std::optional> parseSinkList(TSNode root) + Option> parseAnonStruct(TSNode root) { - if(!expectNode(root, TSNodeType::SinkList)) - return std::nullopt; + if(!expectSymbol(root, sym_anonStruct)) + return None(); - auto list = makeNode(root); + auto list = makeNode(root); auto scope = pushScope(list); - foreachNamedNodeChild(root, 
[&](TSNode tsSinkDefNode) { - auto sinkDefNode = parseSinkDef(tsSinkDefNode); - if(!sinkDefNode) + foreachNamedNodeChild(root, + [&](TSNode tsSinkDefNode) + { + auto memberInit = parseMemberInit(tsSinkDefNode); + if(!memberInit) + return; + list->members.push_back(memberInit.value()); + }); + + return Some(list); + } + + Option> parseAnonStructType(TSNode root) + { + if(!expectSymbol(root, sym_anonStructType)) + return None(); + + auto structDef = makeNode(root); + auto scope = pushScope(structDef); + + foreachNamedNodeChild(root, + [&](TSNode tsSinkDefNode) + { + auto memberDef = parseValueDef(tsSinkDefNode); + if(!memberDef) return; - list->values.push_back(*sinkDefNode); + structDef->members.push_back(memberDef.value()); }); - return list; + return Some(structDef); } - std::optional> parseSourceDef(TSNode root) + Option> parseValueDef(TSNode root) { - if(!expectNode(root, TSNodeType::SourceDef)) - return std::nullopt; + if(!expectSymbol(root, sym_valueDef)) + return None(); auto def = makeNode(root); auto scope = pushScope(def); def->isLValue = true; - auto mutField = getField(root, TSFieldName::Mut); - def->isMut = mutField.has_value(); + auto mutField = getField(root, field_mut); + def->isMut = mutField.isSome(); - auto tsIdNode = getField(root, TSFieldName::Id); + auto tsIdNode = getField(root, field_id); Expect(root, tsIdNode, "Expected identifier"); - def->label = parseIdentifier(*tsIdNode); + def->label = parseIdentifier(tsIdNode.value()); - auto tsTypeNode = getField(root, TSFieldName::Type); + auto tsTypeNode = getField(root, field_type); Expect(root, tsTypeNode, "Expected type"); - auto typeNode = parseType(*tsTypeNode); + auto typeNode = parseType(tsTypeNode.value()); if(!typeNode) - return std::nullopt; - def->type = *typeNode; + return None(); + def->type = Some(typeNode.value()); - return def; + return Some(def); } - std::optional> parseSourceList(TSNode root) + Option> parseCallSig(TSNode root) { - if(!expectNode(root, TSNodeType::SourceList)) - 
return std::nullopt; + if(!expectSymbol(root, sym_callSig)) + return None(); - auto list = makeNode(root); - auto scope = pushScope(list); + auto callSig = makeNode(root); + auto scope = pushScope(callSig); - foreachNamedNodeChild(root, [&](TSNode tsSourceDefNode) { - auto sourceDefNode = parseSourceDef(tsSourceDefNode); - if(!sourceDefNode) - return; - list->defs.push_back(*sourceDefNode); - }); + auto inputField = getField(root, field_input); + Expect(root, inputField, "Expected call signature input args"); + auto input = parseAnonStructType(inputField.value()); + if(!input) + return None(); + callSig->input = input.value(); + + auto outputType = getField(root, field_output); + if(outputType) + { + auto output = parseAnonStructType(outputType.value()); + if(!output) + return None(); + callSig->output = output.value(); + } + else + callSig->output = makeNode(root); - return list; + return Some(callSig); } - std::optional> parseCall(TSNode root) + Option> parseCall(TSNode root) { - auto range = nodeRange(root); - warningMessage(root, - "Function calls not implemented yet! 
\n\"" + - std::string(source.substr(range.start_byte, range.end_byte - range.start_byte)) + - "\" will be ignored"); - return std::nullopt; + if(!expectSymbol(root, sym_call)) + return None(); + auto call = makeNode(root); + auto scope = pushScope(call); + + auto funcField = getField(root, field_func); + Expect(root, funcField, "Expected callable expression"); + auto callable = parseExpression(funcField.value()); + if(!callable) + return None(); + call->callable = callable.value(); + + auto argsField = getField(root, field_args); + Expect(root, argsField, "Expected call args"); + auto args = parseAnonStruct(argsField.value()); + if(!args) + return None(); + call->args = args.value(); + + + return Some(call); } - std::optional> parsePipelineStage(TSNode root) + Option> parsePipelineStage(TSNode root) { - if(!expectNode(root, TSNodeType::PipelineStage)) - return std::nullopt; + if(!expectSymbol(root, sym_pipelineStage)) + return None(); auto stage = makeNode(root); auto scope = pushScope(stage); - foreachNamedNodeChild(root, [&](TSNode node) { - auto expr = parseExpression(node); - if(!expr) - return; - stage->expressions.push_back(*expr); - }); - - return stage; - } + auto callSigField = getField(root, field_callSig); + if(callSigField) + { + auto callSig = parseCallSig(callSigField.value()); + if(!callSig) + return None(); + stage->callSig = callSig.value(); + } - std::optional> parseAsyncOperation(TSNode root) - { - if(!expectNode(root, TSNodeType::AsyncOperation)) - return std::nullopt; - auto range = nodeRange(root); - warningMessage(root, - "Async Operations not implemented yet! 
\n\"" + - std::string(source.substr(range.start_byte, range.end_byte - range.start_byte)) + - "\" will be ignored"); - return std::nullopt; + auto bodyField = getField(root, field_body); + Expect(root, bodyField, "Expected pipeline body"); + auto body = parseBlock(bodyField.value()); + if(!body) + return None(); + stage->body = body.value(); + return Some(stage); } - std::optional> parseFunction(TSNode root) + Option> parseFunction(TSNode root) { warningMessage(root, "Functions not implemented yet!"); - return std::nullopt; + return None(); } - std::optional> parsePipeline(TSNode root) + Option> parsePipeline(TSNode root) { - if(!expectNode(root, TSNodeType::Pipeline)) - return std::nullopt; - auto tsIdNode = getField(root, TSFieldName::Id); - Expect(root, tsIdNode, "Identifier was not found"); - auto idNode = parseIdentifier(*tsIdNode); + if(!expectSymbol(root, sym_pipeline)) + return None(); auto pipe = makeNode(root); - idNode.value()->parent = pipe; - pipe->identifier = *idNode; - - auto tsSourcesNode = getField(root, TSFieldName::Sources); - Expect(root, tsSourcesNode, "Pipeline sources list not found"); - auto sourcesNode = parseSourceList(*tsSourcesNode); - if(!sourcesNode) - return std::nullopt; - pipe->sources = *sourcesNode; - - auto tsSinksNode = getField(root, TSFieldName::Sinks); - Expect(root, tsSinksNode, "Pipeline sinks not found"); - auto sinksNode = parseSinkList(*tsSinksNode); - if(!sinksNode) - return std::nullopt; - pipe->sinks = *sinksNode; - - auto tsStagesNode = getField(root, TSFieldName::Stages); + auto scope = pushScope(pipe); + + auto tsIdNode = getField(root, field_id); + Expect(root, tsIdNode, "Identifier was not found"); + auto idNode = parseIdentifier(tsIdNode.value()); + + idNode.value()->parent = Some>(pipe); + pipe->identifier = idNode.value(); + + auto callSigField = getField(root, field_callSig); + Expect(root, callSigField, "Expected call signature"); + auto callSig = parseCallSig(callSigField.value()); + if(!callSig) + return 
None(); + pipe->callSig = callSig.value(); + + auto tsStagesNode = getField(root, field_stages); Expect(root, tsStagesNode, "Pipeline must have at least one stage"); - advanceWhileType(*tsStagesNode, - [&](TSNode tsStageNode) { - auto stageNode = parse(tsStageNode); + advanceWhileSymbol(tsStagesNode.value(), + [&](TSNode tsStageNode) + { + auto stageNode = parsePipelineStage(tsStageNode); if(!stageNode) return; - std::visit(overloads{[&](Node stage) { pipe->stages.push_back(stage); }, - [&](Node asyncOp) { - pipe->stages.back()->asyncExpressions.push_back(asyncOp); - }, - [&](auto& none) { - errorMessage(tsStageNode, - std::string("Grammer was correct, but context created was wrong: ") + - typeid(none).name()); - }}, - *stageNode); + pipe->stages.push_back(stageNode.value()); }); - return pipe; + return Some(pipe); } - std::optional> parseModule(TSNode root) + Option> parseModule(TSNode root) { - if(!expectNode(root, TSNodeType::Module)) - return std::nullopt; + if(!expectSymbol(root, sym_module)) + return None(); auto mod = makeNode(root); auto scope = pushScope(mod); - auto identifier = getField(root, TSFieldName::Id); + auto identifier = getField(root, field_id); if(!identifier) - return std::nullopt; + return None(); - auto idNode = parseIdentifier(*identifier); + auto idNode = parseIdentifier(identifier.value()); if(!idNode) - return std::nullopt; - idNode.value()->parent = mod; - mod->identifier = *idNode; + return None(); + idNode.value()->parent = Some>(mod); + mod->identifier = idNode.value(); - auto firstDef = getField(root, TSFieldName::Defs); + auto firstDef = getField(root, field_defs); if(!firstDef) - return mod; - for(auto currentDef = *firstDef; !ts_node_is_null(currentDef) && ts_node_is_named(currentDef); + return Some(mod); + for(auto currentDef = firstDef.value(); !ts_node_is_null(currentDef) && ts_node_is_named(currentDef); currentDef = ts_node_next_named_sibling(currentDef)) { - std::optional def = parse(currentDef); + Option def = 
parse(currentDef); if(!def) continue; - std::visit( - overloads{[&](Node& pipeline) { mod->pipelines.push_back(std::move(pipeline)); }, - [&](Node& function) { mod->functions.push_back(std::move(function)); }, - [&](auto& none) { + MATCHV(def.value(), + [&](Node& pipeline) + { mod->pipelines.insert({pipeline->identifier->text, std::move(pipeline)}); }, + [&](Node& function) + { mod->functions.insert({function->identifier->text, std::move(function)}); }, + [&](auto& none) + { errorMessage(currentDef, std::string("Grammer was correct, but context created was wrong: ") + typeid(none).name()); - }}, - *def); + }); } - return mod; + return Some(mod); } ParserResult parseDocument() { - tree = ts_parser_parse_string(parser->parser(), nullptr, source.data(), source.size()); + tree = ts_parser_parse_string(parser->parser(), nullptr, sourceText.data(), sourceText.size()); auto doc = std::make_shared(); auto scope = pushScope(doc); @@ -961,8 +854,11 @@ namespace BraneScript TSNode root = ts_tree_root_node(tree); - foreachNodeChild(root, [&](TSNode node) { - auto newMod = parseModule(node); + foreachNodeChild(root, + [&](TSNode node) + { + Option> newMod; + newMod = parseModule(node); if(newMod) doc->modules.insert({newMod.value()->identifier->text, newMod.value()}); }); @@ -1009,8 +905,14 @@ namespace BraneScript { if(_cachedResult) return _cachedResult.value(); - ParserAPI ctx{_path, _source, _parser}; - _cachedResult = std::make_optional(ctx.parseDocument()); + ParserAPI ctx{_path, + _source, + std::make_shared(TextSource{std::format("file://{}", _path.string())}), + _parser, + {}, + {}, + {}}; + _cachedResult = Some(ctx.parseDocument()); return _cachedResult.value(); } diff --git a/src/parser/documentParser.h b/src/parser/documentParser.h index c87bcfd..c66b8b5 100644 --- a/src/parser/documentParser.h +++ b/src/parser/documentParser.h @@ -5,10 +5,9 @@ #ifndef BRANESCRIPT_TSBINDINGS_H #define BRANESCRIPT_TSBINDINGS_H -#include #include #include -#include 
"parser/documentContext.h" +#include "documentContext.h" #include namespace BraneScript @@ -54,7 +53,7 @@ namespace BraneScript std::string _source; std::shared_ptr _parser; - std::optional> _cachedResult; + Option> _cachedResult; public: ParsedDocument(std::filesystem::path path, std::string source, std::shared_ptr parser); diff --git a/src/parser/tree_sitter_BraneScript.h b/src/parser/tree_sitter_BraneScript.h deleted file mode 100644 index 9d74f45..0000000 --- a/src/parser/tree_sitter_BraneScript.h +++ /dev/null @@ -1,19 +0,0 @@ -// -// Created by WireWhiz on 10/28/2024. -// - -#ifndef BRANESCRIPT_TREE_SITTER_BRANESCRIPT_H -#define BRANESCRIPT_TREE_SITTER_BRANESCRIPT_H - -#include -#include -#include -#include - -extern "C"{ - const TSLanguage *tree_sitter_branescript(void); -} - - - -#endif // BRANESCRIPT_TREE_SITTER_BRANESCRIPT_H diff --git a/src/runtime/CMakeLists.txt b/src/runtime/CMakeLists.txt new file mode 100644 index 0000000..5552669 --- /dev/null +++ b/src/runtime/CMakeLists.txt @@ -0,0 +1,5 @@ + +add_library(runtime STATIC runtime.cpp) + +add_subdirectory(backends) + diff --git a/src/runtime/backends/CMakeLists.txt b/src/runtime/backends/CMakeLists.txt new file mode 100644 index 0000000..8eca84a --- /dev/null +++ b/src/runtime/backends/CMakeLists.txt @@ -0,0 +1,2 @@ + +add_subdirectory(llvm) diff --git a/src/runtime/backends/jitBackend.h b/src/runtime/backends/jitBackend.h new file mode 100644 index 0000000..c2342fb --- /dev/null +++ b/src/runtime/backends/jitBackend.h @@ -0,0 +1,65 @@ +#pragma once + +#include "enums/result.h" +#include "ir/ir.h" + +namespace BraneScript +{ + struct JitPtr + { + uint16_t binding; + uint16_t index; + + int32_t asInt() { return *((int32_t*)this); } + + static JitPtr fromInt(int32_t data) { return *((JitPtr*)&data); } + }; + + using JitFuncHandle = void(__cdecl*)(void*, int32_t, int32_t); + + struct JitStructType; + using JitType = std::variant>; + + struct JitStructMember + { + Option label; + JitType type; + uint16_t 
offset; + }; + + struct JitStructType + { + std::vector members; + }; + + struct JitFunction + { + std::shared_ptr input; + std::shared_ptr output; + JitFuncHandle handle; + }; + + struct JitPipeline + { + std::shared_ptr input; + std::shared_ptr output; + std::vector> stages; + }; + + struct JitModule + { + std::unordered_map> structTypes; + std::unordered_map> functions; + std::unordered_map> pipelines; + }; + + class JitBackend + { + public: + virtual void stageModule(std::shared_ptr module) = 0; + /// Consume all staged modules + virtual Result processModules() = 0; + virtual Option> getPipeline(std::string_view moduleName, + std::string_view pipelineName) = 0; + }; +} // namespace BraneScript diff --git a/src/runtime/backends/llvm/CMakeLists.txt b/src/runtime/backends/llvm/CMakeLists.txt new file mode 100644 index 0000000..92963ad --- /dev/null +++ b/src/runtime/backends/llvm/CMakeLists.txt @@ -0,0 +1,17 @@ + + + +find_package(LLVM REQUIRED CONFIG) + +message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") +message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") + + +add_library(llvmJitBackend STATIC llvmJitBackend.cpp) +target_include_directories(llvmJitBackend PRIVATE ${LLVM_INCLUDE_DIRS}) +target_compile_definitions(llvmJitBackend PRIVATE ${LLVM_DEFINITIONS}) + +llvm_map_components_to_libnames(LLVM_LIBS support core irreader executionengine asmprinter nativecodegen orcjit transformutils target analysis x86asmparser) +message(STATUS "Linking llvm libs ${LLVM_LIBS}") +target_link_libraries(llvmJitBackend PRIVATE ${LLVM_LIBS}) +target_include_directories(llvmJitBackend PUBLIC ${braneutilities_SOURCE_DIR}/src) diff --git a/src/runtime/backends/llvm/llvmJitBackend.cpp b/src/runtime/backends/llvm/llvmJitBackend.cpp new file mode 100644 index 0000000..eb67f87 --- /dev/null +++ b/src/runtime/backends/llvm/llvmJitBackend.cpp @@ -0,0 +1,757 @@ +#include "llvmJitBackend.h" + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include 
"llvm/Bitcode/BitcodeWriter.h" +#include "llvm/IR/DIBuilder.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" + +#include "llvm/Analysis/CGSCCPassManager.h" +#include "llvm/Bitcode/BitcodeReader.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/CompileUtils.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" +#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" +#include "llvm/ExecutionEngine/Orc/IRTransformLayer.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/ExecutionEngine/Orc/LLJIT.h" +#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h" +#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" +#include "llvm/ExecutionEngine/SectionMemoryManager.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Support/DynamicLibrary.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/GVN.h" + +#include +#include +#include + +namespace BraneScript +{ + struct LLVMModuleBuilderCtx + { + llvm::LLVMContext* llvmCtx; + // IR builder + std::shared_ptr> builder; + std::unique_ptr llvmMod; + + + // Debug info builer + std::unique_ptr dBuilder; + llvm::DICompileUnit* diCompileUnit = nullptr; + llvm::DIFile* diFile = nullptr; + + // walker state + llvm::Function* currentFunc = nullptr; + llvm::BasicBlock* currentBlock = nullptr; + ; + + llvm::DISubprogram* diFunction = nullptr; + + const IRModule* currentMod; + + struct StructTypeCtx + { + llvm::StructType* llvmType; + const IRStruct* bsType; + }; + + struct FunctionCtx + { + 
llvm::Function* llvmFunc; + const IRFunction* bsFunc; + }; + + struct NativeTypeCtx + { + llvm::Type* llvmType = nullptr; + std::variant type; + }; + + using TypeContext = std::variant; + + struct ValueCtx + { + llvm::Value* value = nullptr; + TypeContext type; + }; + + std::vector structs; + std::vector functions; + std::vector values; + + // Cache values + llvm::FunctionType* functionType = nullptr; + llvm::Value* memTable = nullptr; + llvm::Type* branePtrType = nullptr; + ValueCtx retPtr; + + Result getNativeType(IRNativeType type) const + { + llvm::Type* llvmType = nullptr; + switch(type) + { + case IRNativeType::U8: + case IRNativeType::I8: + llvmType = llvm::Type::getInt8Ty(*llvmCtx); + break; + case IRNativeType::U16: + case IRNativeType::I16: + llvmType = llvm::Type::getInt16Ty(*llvmCtx); + break; + case IRNativeType::U32: + case IRNativeType::I32: + llvmType = llvm::Type::getInt32Ty(*llvmCtx); + break; + case IRNativeType::F32: + llvmType = llvm::Type::getFloatTy(*llvmCtx); + break; + case IRNativeType::U64: + case IRNativeType::I64: + llvmType = llvm::Type::getInt64Ty(*llvmCtx); + break; + case IRNativeType::F64: + llvmType = llvm::Type::getDoubleTy(*llvmCtx); + break; + case IRNativeType::U128: + case IRNativeType::I128: + llvmType = llvm::Type::getInt128Ty(*llvmCtx); + break; + } + if(!llvmType) + return Err("Invalid IRNativeType value"); + return Ok(NativeTypeCtx{.llvmType = llvmType, .type = type}); + }; + + Result getStructType(const IDRef& id) + { + return MATCHV(id, + [&](const std::string& idStr) -> Result + { + for(auto& s : structs) + { + if(s.bsType->id.isNone()) + continue; + if(s.bsType->id.value() == idStr) + return Ok(s); + } + return Err(std::format("No struct defined with id {}", idStr)); + }, + [&](int32_t idIndex) -> Result + { + if(idIndex >= structs.size()) + return Err(std::format("Struct index out of range")); + return Ok(structs[idIndex]); + }); + } + + static llvm::Type* getLLVMType(const TypeContext& ctx) + { + return MATCHV(ctx, 
[](auto& ctx) -> llvm::Type* { return ctx.llvmType; }); + } + + Result getType(const IRType& type) + { + return MATCHV(type, + [&](const IRNativeType& type) -> Result + { + auto res = getNativeType(type); + if(!res) + return Err(res.err()); + return Ok(res.ok()); + }, + [&](const IDRef& type) -> Result + { + auto res = getStructType(type); + if(!res) + return Err(res.err()); + return Ok(res.ok()); + }); + }; + + Result getValue(const IRValue& value) + { + if(value.id >= values.size()) + return Err(std::format("Value index {} out of range", value.id)); + return Ok(values[value.id]); + } + + // Evalutate a int32 brane script pointer against the current memory table to get the full pointer + size_t evalCount = 0; + + llvm::Value* evaluatePtr(llvm::Type* type, llvm::Value* intPtr) + { + auto* memIndex = builder->CreateLShr(intPtr, builder->getInt32(16), "memIndex"); + auto* bindingIndex = builder->CreateAnd(intPtr, builder->getInt32(0x0000FFFF), "bindIndex"); + + auto bitType = llvm::Type::getInt1Ty(*llvmCtx); + auto* basePtr = builder->CreateLoad(llvm::PointerType::get(bitType, 0), + builder->CreateGEP(memTable->getType(), memTable, {bindingIndex})); + auto* offsetPtr = builder->CreateGEP(bitType, basePtr, memIndex); + + // Debug + builder->CreateStore(builder->CreateAdd(memIndex, builder->getInt32(100)), + builder->CreateGEP(bitType, basePtr, builder->getInt32((evalCount++) * 4 + 32))); + + return builder->CreateBitOrPointerCast(offsetPtr, llvm::PointerType::get(type, 0)); + } + + Result getStructMember(ValueCtx structValue, size_t member) + { + auto* structTypeCtx = std::get_if(&structValue.type); + if(!structTypeCtx) + return Err("Tried to get member of type that was not a struct"); + + if(member > structTypeCtx->bsType->members.size()) + return Err(std::format("Tried to get member {} but struct only has {} members", + member, + structTypeCtx->bsType->members.size())); + + + auto memberValue = builder->CreateStructGEP(structTypeCtx->llvmType, structValue.value, 
member); + auto memberTypeRes = getType(structTypeCtx->bsType->members[member].type); + if(!memberTypeRes) + return Err("Could not resove member type: " + memberTypeRes.err()); + + return Ok(ValueCtx{ + .value = memberValue, + .type = memberTypeRes.ok(), + }); + } + + Result<> buildMov(const MovOp& mov) + { + auto srcRes = getValue(mov.src); + CHECK_RESULT(srcRes); + auto destRes = getValue(mov.dest); + CHECK_RESULT(destRes); + + auto srcPtr = srcRes.ok(); + auto destPtr = destRes.ok(); + + auto srcValue = builder->CreateLoad(getLLVMType(srcPtr.type), srcPtr.value, "movArg"); + builder->CreateStore(srcValue, destPtr.value, false); + return Ok<>(); + } + + Result<> buildLoad(const LoadOp& load) + { + auto branePtrRes = getValue(load.ptr); + CHECK_RESULT(branePtrRes); + auto destRes = getValue(load.dest); + + auto branePtr = branePtrRes.ok(); + auto branePtrValue = builder->CreateLoad(getLLVMType(branePtr.type), branePtr.value); + auto destValue = destRes.ok(); + + + auto loadedPtr = evaluatePtr(getLLVMType(destValue.type), branePtr.value); + builder->CreateStore(loadedPtr, destValue.value); + return Ok<>(); + } + + Result<> buildStore(const StoreOp& store) + { + + auto branePtrRes = getValue(store.ptr); + CHECK_RESULT(branePtrRes); + auto srcRes = getValue(store.src); + + auto branePtr = branePtrRes.ok(); + auto branePtrValue = builder->CreateLoad(getLLVMType(branePtr.type), branePtr.value); + auto srcValuePtr = srcRes.ok(); + + auto srcValue = builder->CreateLoad(getLLVMType(srcValuePtr.type), srcValuePtr.value); + + + auto loadedPtr = evaluatePtr(getLLVMType(srcValuePtr.type), branePtr.value); + builder->CreateStore(srcValue, loadedPtr); + return Ok<>(); + } + + Result<> buildMemAcc(const MemAcc& memAcc) + { + auto srcRes = getValue(memAcc.ptr); + CHECK_RESULT(srcRes); + auto destRes = getValue(memAcc.dest); + CHECK_RESULT(destRes); + + auto srcPtr = srcRes.ok(); + auto destPtr = destRes.ok(); + + auto memberValue = getStructMember(srcPtr, memAcc.index); + 
CHECK_RESULT(memberValue); + values[memAcc.dest.id] = memberValue.ok(); + + return Ok<>(); + } + + Result<> buildAdd(const AddOp& op) + { + auto leftRes = getValue(op.left); + CHECK_RESULT(leftRes); + auto rightRes = getValue(op.right); + CHECK_RESULT(rightRes); + auto destRes = getValue(op.out); + CHECK_RESULT(destRes); + + auto leftPtr = leftRes.ok(); + auto rightPtr = rightRes.ok(); + auto destPtr = destRes.ok(); + + auto leftValue = builder->CreateLoad(getLLVMType(leftPtr.type), leftPtr.value, "loadAddArg"); + auto rightValue = builder->CreateLoad(getLLVMType(rightPtr.type), rightPtr.value, "loadAddArg"); + + auto res = builder->CreateAdd(leftValue, rightValue); + builder->CreateStore(res, destPtr.value, "storeAddRes"); + + return Ok<>(); + } + + Result<> buildNextStage(const NextStageOp& call) + { + auto inputRes = getValue(call.input); + CHECK_RESULT(inputRes); + auto inputValue = inputRes.ok(); + auto inputStructType = std::get_if(&inputValue.type); + if(!inputStructType) + return Err("Tried to return non struct type"); + + auto memberCount = inputStructType->bsType->members.size(); + retPtr.type = inputValue.type; // "Temporary workaround" + for(size_t i = 0; i < memberCount; ++i) + { + auto inputMemberRes = getStructMember(inputValue, i); + auto outputMemberRes = getStructMember(retPtr, i); + CHECK_RESULT(inputMemberRes); + CHECK_RESULT(outputMemberRes); + auto inputMemberValue = inputMemberRes.ok(); + auto input = + builder->CreateLoad(getLLVMType(inputMemberValue.type), inputMemberValue.value, "loadStackInput"); + builder->CreateStore(input, outputMemberRes.ok().value, false); + } + + builder->CreateRetVoid(); + return Ok<>(); + } + + Result<> buildOp(const IROperation& op) + { + return MATCHV(op, + [&](const MovOp& mov) { return buildMov(mov); }, + [&](const LoadOp& load) -> Result<> { return buildLoad(load); }, + [&](const StoreOp& store) -> Result<> { return buildStore(store); }, + [&](const MemAcc& memAcc) -> Result<> { return buildMemAcc(memAcc); }, 
+ [&](const ConstI32& c) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const ConstU32& c) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const ConstF32& c) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const AddOp& op) -> Result<> { return buildAdd(op); }, + [&](const SubOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const MulOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const DivOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const ModOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const EqOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const GeOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const LogicNotOp& op) -> Result<> + { // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const LogicAndOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const LogicOrOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const BitNotOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const BitAndOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const BitOrOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const BitXorOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const CallOp& call) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const NextStageOp& call) -> Result<> { return buildNextStage(call); } + + ); + } + + Result> 
buildModule(const IRModule& module, + const llvm::DataLayout& layout, + std::string_view triplet, + llvm::LLVMContext* ctxPtr) + { + llvmCtx = ctxPtr; + builder = std::make_shared>(*llvmCtx); + + currentMod = &module; + llvmMod = std::make_unique(module.id, *llvmCtx); + + llvmMod->setDataLayout(layout); // *_layout); + llvmMod->setTargetTriple(triplet); //_session->getExecutorProcessControl().getTargetTriple().str()); + + + branePtrType = llvm::ArrayType::get(llvm::Type::getInt16Ty(*llvmCtx), 2); + + structs.reserve(module.structs.size()); + for(size_t i = 0; i < module.structs.size(); ++i) + { + auto& s = module.structs[i]; + + std::string sid = s.id ? s.id.value() : std::format("-s{}", i); + + std::vector memberTypes; + for(size_t m = 0; m < s.members.size(); ++m) + { + std::string mid = s.members[m].id ? s.members[m].id.value() : std::format("-{}", m); + auto llvmType = getType(s.members[m].type); + if(!llvmType) + return Err(llvmType.err()); + + memberTypes.push_back(getLLVMType(llvmType.ok())); + } + + llvm::StructType* st; + if(memberTypes.empty()) + st = llvm::StructType::create(*llvmCtx, sid); + else + st = llvm::StructType::create(memberTypes, sid, false); + + structs.push_back(StructTypeCtx{.llvmType = st, .bsType = &s}); + } + + auto memTableType = llvm::ArrayType::get(llvm::PointerType::get(llvm::Type::getInt8Ty(*llvmCtx), 0), 65535); + llvm::Type* funcArgs[3] = {llvm::PointerType::get(memTableType, + 0), // Pointer to array of addresses + llvm::Type::getInt32Ty(*llvmCtx), // Brane pointer to args struct + llvm::Type::getInt32Ty(*llvmCtx)}; // Brane pointer to output struct + functionType = llvm::FunctionType::get(llvm::Type::getVoidTy(*llvmCtx), funcArgs, false); + + + functions.reserve(module.functions.size()); + for(size_t i = 0; i < module.functions.size(); ++i) + { + auto& f = module.functions[i]; + + auto inputRes = getStructType(f.input); + if(!inputRes) + return Err(std::format("Couldn't compile input type for {}: {}", f.id, inputRes.err())); + 
auto inputType = inputRes.ok(); + + auto outputRes = getStructType(f.output); + if(!outputRes) + return Err(std::format("Couldn't compile output type for {}: {}", f.id, inputRes.err())); + auto outputType = outputRes.ok(); + + auto linkage = llvm::Function::ExternalLinkage; + currentFunc = llvm::Function::Create(functionType, linkage, f.id, llvmMod.get()); + currentFunc->setCallingConv(llvm::CallingConv::C); + currentBlock = llvm::BasicBlock::Create(*llvmCtx, "entry", currentFunc); + builder->SetInsertPoint(currentBlock); + + memTable = currentFunc->getArg(0); + auto inputStruct = + ValueCtx{.value = evaluatePtr(inputType.llvmType, currentFunc->getArg(1)), .type = inputType}; + + // Really need a better return mechanism + retPtr = + ValueCtx{.value = evaluatePtr(outputType.llvmType, currentFunc->getArg(2)), .type = outputType}; + + auto inputCount = inputType.bsType->members.size(); + for(size_t i = 0; i < f.localVars.size(); ++i) + { + auto& var = f.localVars[i]; + if(i < inputCount) + { + auto argRes = getStructMember(inputStruct, i); + if(!argRes) + return Err(argRes.err()); + auto arg = argRes.ok(); + arg.value->setName(std::format("local{}", i)); + values.push_back(arg); + std::cout << i << " Re-route function arg member" << std::endl; + } + else + { + auto typeRes = getType(f.localVars[i]); + if(!typeRes) + return Err(typeRes.err()); + MATCHV(typeRes.ok(), + [&](NativeTypeCtx valueType) + { + auto* value = builder->CreateAlloca(valueType.llvmType, 0, std::format("local{}", i)); + values.push_back(ValueCtx{.value = value, .type = valueType}); + std::cout << i << " Allocated new native value" << std::endl; + }, + [&](StructTypeCtx valueType) + { + auto* value = builder->CreateAlloca(valueType.llvmType, 0, std::format("local{}", i)); + values.push_back(ValueCtx{.value = value, .type = valueType}); + std::cout << i << " Allocated new struct value" << std::endl; + }); + } + } + + for(auto& op : f.operations) + { + auto res = buildOp(op); + if(!res) + return 
Err(res.err()); + } + values.clear(); + std::string funcError; + llvm::raw_string_ostream funcErrorStream(funcError); + if(llvm::verifyFunction(*currentFunc, &funcErrorStream)) + return Err(funcError); + } + + std::string modError; + llvm::raw_string_ostream modErrorStream(modError); + if(llvm::verifyModule(*llvmMod, &modErrorStream)) + return Err(modError); + + std::string moduleContent; + llvm::raw_string_ostream moduleStream(moduleContent); + llvmMod->print(moduleStream, nullptr); + std::cout << "Staging module: \n" << moduleContent << std::endl; + + return Ok(std::move(llvmMod)); + } + + LLVMModuleBuilderCtx() {} + }; // namespace BraneScript + + bool llvmInitialized = false; + + LLVMJitBackend::LLVMJitBackend() + { + if(!llvmInitialized) + { + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + llvm::InitializeNativeTargetAsmParser(); + llvmInitialized = true; + } + + llvm::ObjectCache* OC = nullptr; // Implement later? + + + auto builderRes = + llvm::orc::LLJITBuilder() + .setCompileFunctionCreator( + [&](llvm::orc::JITTargetMachineBuilder JTMB) + -> llvm::Expected> + { + auto TM = JTMB.createTargetMachine(); + if(!TM) + return TM.takeError(); + return std::make_unique(std::move(*TM), OC); + }) + .setObjectLinkingLayerCreator([&](llvm::orc::ExecutionSession& ES, + const llvm::Triple& TT) -> std::unique_ptr + { + // Except for the GDBListener registration, the rest of + // the code is taken from LLJIT.cpp. 
+ auto GetMemMgr = []() { return std::make_unique(); }; + auto ObjLinkingLayer = std::make_unique(ES, std::move(GetMemMgr)); + if(TT.isOSBinFormatCOFF()) + { + ObjLinkingLayer->setOverrideObjectFlagsWithResponsibilityFlags(true); + ObjLinkingLayer->setAutoClaimResponsibilityForObjectSymbols(true); + } + ObjLinkingLayer->registerJITEventListener(*llvm::JITEventListener::createGDBRegistrationListener()); + return ObjLinkingLayer; + }).create(); + if(!builderRes) + throw std::runtime_error("Failed to create llvm jit"); + _llJit = std::move(*builderRes); + } + + LLVMJitBackend::~LLVMJitBackend() {}; + + void LLVMJitBackend::stageModule(std::shared_ptr module) + { + auto llvmCtx = std::make_unique(); + LLVMModuleBuilderCtx ctx; + auto newMod = ctx.buildModule(*module, _llJit->getDataLayout(), _llJit->getTargetTriple().str(), llvmCtx.get()); + if(!newMod) + { + std::cout << "Failed to stage module! " << newMod.err() << std::endl; + return; + } + _stagedModules.push_back(StageContext{ + .ir = module, + .llvm = std::make_unique(std::move(newMod.ok()), std::move(llvmCtx)), + }); + } + + Result LLVMJitBackend::processModules() + { + for(auto& mod : _stagedModules) + { + + auto libRes = _llJit->createJITDylib(mod.ir->id); + if(!libRes) + { + std::string resMessage; + llvm::raw_string_ostream resStream(resMessage); + resStream << libRes.takeError(); + std::cout << "Failed to create jit dylib: " << resMessage << std::endl; + continue; + } + auto& lib = *libRes; + + auto addRes = _llJit->addIRModule(lib, std::move(*mod.llvm)); + if(addRes) + { + std::string resMessage; + llvm::raw_string_ostream resStream(resMessage); + resStream << addRes; + std::cout << "Failed to add jit module: " << resMessage << std::endl; + continue; + } + + for(auto& func : mod.ir->functions) + { + auto functionSym = _llJit->lookup(lib, func.id); + if(!functionSym) + throw std::runtime_error("Unable to find exported function: " + toString(functionSym.takeError())); + _functions.insert({func.id, 
functionSym->toPtr()}); + } + + + /*std::unordered_map> newStructs; + // Extract all structs before populating members, so we can resolve out of order dependencies. + + for(auto& s : irModule.structs) + { + auto sConstructor = module->getFunction(s.constructorSig); + auto sDestructor = module->getFunction(s.destructorSig); + auto sCopyConstructor = module->getFunction(s.copyConstructorSig); + auto sMoveConstructor = module->getFunction(s.moveConstructorSig); + if(!sConstructor || !sDestructor || !sCopyConstructor || !sMoveConstructor) + throw std::runtime_error("Module load failed, missing constructors for: " + s.name); + auto sDef = new StructDef(s.name, sConstructor, sCopyConstructor, sMoveConstructor, sDestructor); + for(auto& tag : s.tags) + sDef->tags.insert(tag); + newStructs.emplace(s.name, sDef); + } + + for(auto& ns : newStructs) + { + if(_types.contains(ns.first)) + throw std::runtime_error("Module load failed, cannot load a type twice: " + ns.first); + _types.insert({ns.first, ns.second.get()}); + } + + size_t structIndex = 0; + for(auto& ns : newStructs) + { + auto structLayout = exportedStructLayouts[structIndex]; + auto& srcStruct = irModule.structs[structIndex]; + ns.second->size = structLayout->getSizeInBytes(); + for(size_t i = 0; i < srcStruct.members.size(); i++) + { + auto& src = srcStruct.members[i]; + auto memberOffset = structLayout->getElementOffset(i); + auto memberType = _types.find(src.type); + if(memberType == _types.end()) + throw std::runtime_error("Module load failed, unable to find type metadata: " + src.type); + ns.second->memberVars.push_back( + StructVar{src.name, VarType{memberType->second, src.isRef}, memberOffset}); + } + + printf("Adding struct %s\n", ns.first.c_str()); + printf("Size: %zu\n", ns.second->size); + printf("Members:\n"); + for(auto& m : ns.second->memberVars) + printf(" %s %s, offset: %zu\n", m.type.def->name.c_str(), m.name.c_str(), m.offset); + + module->structDefinitions.insert({ns.first, 
std::move(ns.second)}); + ++structIndex; + }*/ + + /*_modules.insert(irModule.id, module); + for(auto& dep : irModule.links) + _modules.addDependency(irModule.id, dep);*/ + } + return Ok<>(); + } + + Option> LLVMJitBackend::getPipeline(std::string_view moduleName, + std::string_view pipelineName) + { + return None(); + } + + std::unordered_map& LLVMJitBackend::functions() { return _functions; } +} // namespace BraneScript diff --git a/src/runtime/backends/llvm/llvmJitBackend.cpp~ b/src/runtime/backends/llvm/llvmJitBackend.cpp~ new file mode 100644 index 0000000..6c35d7c --- /dev/null +++ b/src/runtime/backends/llvm/llvmJitBackend.cpp~ @@ -0,0 +1,831 @@ +#include "llvmJitBackend.h" + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/Bitcode/BitcodeWriter.h" +#include "llvm/IR/DIBuilder.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" + +#include "llvm/Analysis/CGSCCPassManager.h" +#include "llvm/Bitcode/BitcodeReader.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/CompileUtils.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" +#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" +#include "llvm/ExecutionEngine/Orc/IRTransformLayer.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h" +#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" +#include "llvm/ExecutionEngine/SectionMemoryManager.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Support/DynamicLibrary.h" +#include "llvm/Support/TargetSelect.h" 
+#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/GVN.h" + +#include +#include +#include + +namespace BraneScript +{ + struct LLVMModuleBuilderCtx + { + llvm::LLVMContext* llvmCtx; + // IR builder + std::shared_ptr> builder; + std::unique_ptr llvmMod; + + + // Debug info builer + std::unique_ptr dBuilder; + llvm::DICompileUnit* diCompileUnit = nullptr; + llvm::DIFile* diFile = nullptr; + + // walker state + llvm::Function* currentFunc = nullptr; + llvm::BasicBlock* currentBlock = nullptr; + ; + + llvm::DISubprogram* diFunction = nullptr; + + const IRModule* currentMod; + + struct StructTypeCtx + { + llvm::StructType* llvmType; + const IRStruct* bsType; + }; + + struct FunctionCtx + { + llvm::Function* llvmFunc; + const IRFunction* bsFunc; + }; + + struct NativeTypeCtx + { + llvm::Type* llvmType = nullptr; + std::variant type; + }; + + using TypeContext = std::variant; + + struct ValueCtx + { + llvm::Value* value = nullptr; + TypeContext type; + }; + + std::vector structs; + std::vector functions; + std::vector values; + + // Cache values + llvm::FunctionType* functionType = nullptr; + llvm::Value* memTable = nullptr; + llvm::Type* branePtrType = nullptr; + ValueCtx retPtr; + + Result getNativeType(IRNativeType type) const + { + llvm::Type* llvmType = nullptr; + switch(type) + { + case IRNativeType::U8: + case IRNativeType::I8: + llvmType = llvm::Type::getInt8Ty(*llvmCtx); + case IRNativeType::U16: + case IRNativeType::I16: + llvmType = llvm::Type::getInt16Ty(*llvmCtx); + case IRNativeType::U32: + case IRNativeType::I32: + llvmType = llvm::Type::getInt32Ty(*llvmCtx); + case IRNativeType::F32: + llvmType = llvm::Type::getFloatTy(*llvmCtx); + case IRNativeType::U64: + case IRNativeType::I64: + llvmType = llvm::Type::getInt64Ty(*llvmCtx); + case IRNativeType::F64: + llvmType = llvm::Type::getDoubleTy(*llvmCtx); + case IRNativeType::U128: + case IRNativeType::I128: + llvmType = llvm::Type::getInt128Ty(*llvmCtx); + } + if(!llvmType) + return 
Err("Invalid IRNativeType value"); + return Ok(NativeTypeCtx{.llvmType = llvmType, .type = type}); + }; + + Result getStructType(const IDRef& id) + { + return MATCHV(id, + [&](const std::string& idStr) -> Result + { + for(auto& s : structs) + { + if(s.bsType->id.isNone()) + continue; + if(s.bsType->id.value() == idStr) + return Ok(s); + } + return Err(std::format("No struct defined with id {}", idStr)); + }, + [&](int32_t idIndex) -> Result + { + if(idIndex >= structs.size()) + return Err(std::format("Struct index out of range")); + return Ok(structs[idIndex]); + }); + } + + static llvm::Type* getLLVMType(const TypeContext& ctx) + { + return MATCHV(ctx, [](auto& ctx) -> llvm::Type* { return ctx.llvmType; }); + } + + Result getType(const IRType& type) + { + return MATCHV(type, + [&](const IRNativeType& type) -> Result + { + auto res = getNativeType(type); + if(!res) + return Err(res.err()); + return Ok(res.ok()); + }, + [&](const IDRef& type) -> Result + { + auto res = getStructType(type); + if(!res) + return Err(res.err()); + return Ok(res.ok()); + }); + }; + + Result getValue(const IRValue& value) + { + if(value.id >= values.size()) + return Err(std::format("Value index {} out of range", value.id)); + return Ok(values[value.id]); + } + + // Evalutate a int32 brane script pointer against the current memory table to get the full pointer + llvm::Value* evaluatePtr(llvm::Type* type, llvm::Value* intPtr) + { + auto* shiftedHigh = builder->CreateLShr(intPtr, builder->getInt32(16), ""); + auto* maskedLow = builder->CreateAnd(intPtr, builder->getInt32(0xFFFF), ""); + + auto* memIndex = builder->CreateTrunc(shiftedHigh, builder->getInt16Ty(), ""); + auto* bindingIndex = builder->CreateTrunc(maskedLow, builder->getInt16Ty(), "l"); + + auto* basePtr = builder->CreateGEP(memTable->getType(), memTable, {bindingIndex}); + auto* offsetPtr = builder->CreatePtrAdd(basePtr, memIndex); + return builder->CreateBitOrPointerCast(offsetPtr, llvm::PointerType::get(type, 0)); + } + + 
Result getStructMember(ValueCtx structValue, size_t member) + { + auto* structTypeCtx = std::get_if(&structValue.type); + if(!structTypeCtx) + return Err("Tried to get member of type that was not a struct"); + + if(member > structTypeCtx->bsType->members.size()) + return Err(std::format("Tried to get member {} but struct only has {} members", + member, + structTypeCtx->bsType->members.size())); + + + auto memberValue = builder->CreateStructGEP(structTypeCtx->llvmType, structValue.value, member); + auto memberTypeRes = getType(structTypeCtx->bsType->members[member].type); + if(!memberTypeRes) + return Err("Could not resove member type: " + memberTypeRes.err()); + + return Ok(ValueCtx{ + .value = memberValue, + .type = memberTypeRes.ok(), + }); + } + + Result<> buildMov(const MovOp& mov) + { + auto srcRes = getValue(mov.src); + CHECK_RESULT(srcRes); + auto destRes = getValue(mov.dest); + CHECK_RESULT(destRes); + + auto srcPtr = srcRes.ok(); + auto destPtr = destRes.ok(); + + auto srcValue = builder->CreateLoad(getLLVMType(srcPtr.type), srcPtr.value); + builder->CreateStore(srcValue, destPtr.value); + return Ok<>(); + } + + Result<> buildLoad(const LoadOp& load) + { + auto branePtrRes = getValue(load.ptr); + CHECK_RESULT(branePtrRes); + auto destRes = getValue(load.dest); + + auto branePtr = branePtrRes.ok(); + auto branePtrValue = builder->CreateLoad(getLLVMType(branePtr.type), branePtr.value); + auto destValue = destRes.ok(); + + + auto loadedPtr = evaluatePtr(getLLVMType(destValue.type), branePtr.value); + builder->CreateStore(loadedPtr, destValue.value); + return Ok<>(); + } + + Result<> buildStore(const StoreOp& store) + { + + auto branePtrRes = getValue(store.ptr); + CHECK_RESULT(branePtrRes); + auto srcRes = getValue(store.src); + + auto branePtr = branePtrRes.ok(); + auto branePtrValue = builder->CreateLoad(getLLVMType(branePtr.type), branePtr.value); + auto srcValuePtr = srcRes.ok(); + + auto srcValue = builder->CreateLoad(getLLVMType(srcValuePtr.type), 
srcValuePtr.value); + + + auto loadedPtr = evaluatePtr(getLLVMType(srcValuePtr.type), branePtr.value); + builder->CreateStore(srcValue, loadedPtr); + return Ok<>(); + } + + Result<> buildMemAcc(const MemAcc& memAcc) + { + auto srcRes = getValue(memAcc.ptr); + CHECK_RESULT(srcRes); + auto destRes = getValue(memAcc.dest); + CHECK_RESULT(destRes); + + auto srcPtr = srcRes.ok(); + auto destPtr = destRes.ok(); + + auto memberValue = getStructMember(srcPtr, memAcc.index); + CHECK_RESULT(memberValue); + values[memAcc.dest.id] = memberValue.ok(); + + return Ok<>(); + } + + Result<> buildAdd(const AddOp& op) + { + auto leftRes = getValue(op.left); + CHECK_RESULT(leftRes); + auto rightRes = getValue(op.right); + CHECK_RESULT(leftRes); + auto destRes = getValue(op.out); + CHECK_RESULT(destRes); + + auto leftPtr = leftRes.ok(); + auto rightPtr = leftRes.ok(); + auto destPtr = destRes.ok(); + + auto leftValue = builder->CreateLoad(getLLVMType(leftPtr.type), leftPtr.value); + auto rightValue = builder->CreateLoad(getLLVMType(rightPtr.type), rightPtr.value); + + auto res = builder->CreateAdd(leftValue, rightValue); + builder->CreateStore(res, destPtr.value); + + return Ok<>(); + } + + Result<> buildNextStage(const NextStageOp& call) + { + auto inputRes = getValue(call.input); + CHECK_RESULT(inputRes); + auto inputValue = inputRes.ok(); + auto inputStructType = std::get_if(&inputValue.type); + if(!inputStructType) + return Err("Tried to return non struct type"); + + auto memberCount = inputStructType->bsType->members.size(); + retPtr.type = inputValue.type; // "Temporary workaround" + for(size_t i = 0; i < memberCount; ++i) + { + auto inputMemberRes = getStructMember(inputValue, i); + + auto outputMemberRes = getStructMember(retPtr, i); + CHECK_RESULT(inputMemberRes); + CHECK_RESULT(outputMemberRes); + auto inputMemberValue = inputMemberRes.ok(); + builder->CreateStore(builder->CreateLoad(getLLVMType(inputMemberValue.type), inputMemberValue.value), + outputMemberRes.ok().value); + 
} + builder->CreateRetVoid(); + return Ok<>(); + } + + Result<> buildOp(const IROperation& op) + { + return MATCHV(op, + [&](const MovOp& mov) { return buildMov(mov); }, + [&](const LoadOp& load) -> Result<> { return buildLoad(load); }, + [&](const StoreOp& store) -> Result<> { return buildStore(store); }, + [&](const MemAcc& memAcc) -> Result<> { return buildMemAcc(memAcc); }, + [&](const ConstI32& c) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const ConstU32& c) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const ConstF32& c) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const AddOp& op) -> Result<> { return buildAdd(op); }, + [&](const SubOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const MulOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const DivOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const ModOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const EqOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const GeOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const LogicNotOp& op) -> Result<> + { // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const LogicAndOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const LogicOrOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const BitNotOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const BitAndOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const BitOrOp& op) -> Result<> + { + // TODO: 
Implement + return Err("Operand not implmeneted"); + }, + [&](const BitXorOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const CallOp& call) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const NextStageOp& call) -> Result<> { return buildNextStage(call); } + + ); + } + + Result> buildModule(const IRModule& module, + llvm::DataLayout& layout, + std::string_view triplet, + llvm::LLVMContext* ctxPtr) + { + llvmCtx = ctxPtr; + builder = std::make_shared>(*llvmCtx); + + currentMod = &module; + llvmMod = std::make_unique(module.id, *llvmCtx); + + llvmMod->setDataLayout(layout); // *_layout); + llvmMod->setTargetTriple(triplet); //_session->getExecutorProcessControl().getTargetTriple().str()); + + + branePtrType = llvm::ArrayType::get(llvm::Type::getInt16Ty(*llvmCtx), 2); + + structs.reserve(module.structs.size()); + for(size_t i = 0; i < module.structs.size(); ++i) + { + auto& s = module.structs[i]; + + std::string sid = s.id ? s.id.value() : std::format("-{}", i); + + std::vector memberTypes; + for(size_t m = 0; m < s.members.size(); ++m) + { + std::string mid = s.members[m].id ? 
s.members[m].id.value() : std::format("-{}", m); + auto llvmType = getType(s.members[m].type); + if(!llvmType) + return Err(llvmType.err()); + + memberTypes.push_back(getLLVMType(llvmType.ok())); + } + + llvm::StructType* st; + if(memberTypes.empty()) + st = llvm::StructType::create(*llvmCtx, sid); + else + st = llvm::StructType::create(memberTypes, sid, false); + + structs.push_back(StructTypeCtx{.llvmType = st, .bsType = &s}); + } + + auto memTableType = llvm::ArrayType::get(llvm::PointerType::get(llvm::Type::getInt8Ty(*llvmCtx), 0), 65535); + llvm::Type* funcArgs[3] = {llvm::PointerType::get(memTableType, + 0), // Pointer to array of addresses + llvm::Type::getInt32Ty(*llvmCtx), // Brane pointer to args struct + llvm::Type::getInt32Ty(*llvmCtx)}; // Brane pointer to output struct + functionType = llvm::FunctionType::get(llvm::Type::getVoidTy(*llvmCtx), funcArgs, false); + + + functions.reserve(module.functions.size()); + for(size_t i = 0; i < module.functions.size(); ++i) + { + auto& f = module.functions[i]; + + auto inputRes = getStructType(f.input); + if(!inputRes) + return Err(std::format("Couldn't compile input type for {}: {}", f.id, inputRes.err())); + auto inputType = inputRes.ok(); + + auto outputRes = getStructType(f.output); + if(!outputRes) + return Err(std::format("Couldn't compile output type for {}: {}", f.id, inputRes.err())); + auto outputType = outputRes.ok(); + + auto linkage = llvm::Function::ExternalLinkage; + currentFunc = llvm::Function::Create(functionType, linkage, f.id, llvmMod.get()); + currentBlock = llvm::BasicBlock::Create(*llvmCtx, "entry", currentFunc); + builder->SetInsertPoint(currentBlock); + + memTable = currentFunc->getArg(0); + auto inputStruct = + ValueCtx{.value = evaluatePtr(inputType.llvmType, currentFunc->getArg(1)), .type = inputType}; + + // Really need a better return mechanism + retPtr = + ValueCtx{.value = evaluatePtr(outputType.llvmType, currentFunc->getArg(2)), .type = outputType}; + + auto inputCount = 
inputType.bsType->members.size(); + for(size_t i = 0; i < f.localVars.size(); ++i) + { + auto& var = f.localVars[i]; + if(i < inputCount) + { + auto argRes = getStructMember(inputStruct, i); + if(!argRes) + return Err(argRes.err()); + values.push_back(argRes.ok()); + std::cout << i << " Re-route function arg member" << std::endl; + } + else + { + auto typeRes = getType(f.localVars[i]); + if(!typeRes) + return Err(typeRes.err()); + MATCHV(typeRes.ok(), + [&](NativeTypeCtx valueType) + { + auto* value = builder->CreateAlloca(valueType.llvmType); + values.push_back(ValueCtx{.value = value, .type = valueType}); + std::cout << i << " Allocated new native value" << std::endl; + }, + [&](StructTypeCtx valueType) + { + auto* value = builder->CreateAlloca(valueType.llvmType); + values.push_back(ValueCtx{.value = value, .type = valueType}); + std::cout << i << " Allocated new struct value" << std::endl; + }); + } + } + + for(auto& op : f.operations) + { + auto res = buildOp(op); + if(!res) + return Err(res.err()); + } + + values.clear(); + std::string funcError; + llvm::raw_string_ostream funcErrorStream(funcError); + if(llvm::verifyFunction(*currentFunc, &funcErrorStream)) + return Err(funcError); + } + + std::string modError; + llvm::raw_string_ostream modErrorStream(modError); + if(llvm::verifyModule(*llvmMod, &modErrorStream)) + return Err(modError); + + std::string moduleContent; + llvm::raw_string_ostream moduleStream(moduleContent); + llvmMod->print(moduleStream, nullptr); + std::cout << "Staging module: \n" << moduleContent << std::endl; + + return Ok(std::move(llvmMod)); + } + + LLVMModuleBuilderCtx() {} + }; // namespace BraneScript + + bool llvmInitialized = false; + + LLVMJitBackend::LLVMJitBackend() + { + if(!llvmInitialized) + { + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + llvm::InitializeNativeTargetAsmParser(); + llvmInitialized = true; + } + + _orcCtx = std::make_unique(std::make_unique()); + + auto EPC = 
llvm::orc::SelfExecutorProcessControl::Create(); + if(!EPC) + throw std::runtime_error("Could not create llvm epc"); + _session = std::make_unique(std::move(*EPC)); + + llvm::orc::JITTargetMachineBuilder jtmb(_session->getExecutorProcessControl().getTargetTriple()); + jtmb.setCodeGenOptLevel(llvm::CodeGenOptLevel::Aggressive); + + auto dl = jtmb.getDefaultDataLayoutForTarget(); + if(!dl) + throw std::runtime_error("Unable to get target data layout: " + toString(dl.takeError())); + _layout = std::make_unique(*dl); + + _mangler = std::make_unique(*_session, *_layout); + + _linkingLayer = std::make_unique( + *_session, []() { return std::make_unique(); }); + + // Debug stuff + _linkingLayer->registerJITEventListener(*llvm::JITEventListener::createGDBRegistrationListener()); + _linkingLayer->setProcessAllSections(true); + + _compileLayer = std::make_unique( + *_session, *_linkingLayer, std::make_unique(jtmb)); + std::shared_ptr tm; + if(auto etm = jtmb.createTargetMachine()) + tm = std::move(*etm); + _transformLayer = + std::make_unique(*_session, + *_compileLayer, + [this, tm](llvm::orc::ThreadSafeModule m, const auto& r) + { + // Add some optimizations. + std::string moduleErr; + llvm::raw_string_ostream modErrStr(moduleErr); + if(llvm::verifyModule(*m.getModuleUnlocked(), &modErrStr)) + { + std::string module; + llvm::raw_string_ostream modStr(module); + m.getModuleUnlocked()->print(modStr, nullptr); + fprintf(stderr, "%s", module.c_str()); + std::cout << "Module verification failed: " + modErrStr.str() << std::endl; + // TODO report failure async + } + + // Only return here if debug + return std::move(m); + + llvm::LoopAnalysisManager lam; + llvm::FunctionAnalysisManager fam; + llvm::CGSCCAnalysisManager cgam; + llvm::ModuleAnalysisManager mam; + + llvm::PassBuilder PB(tm.get()); + + // Register all the basic analyses with the managers. 
+ + PB.registerModuleAnalyses(mam); + PB.registerCGSCCAnalyses(cgam); + PB.registerFunctionAnalyses(fam); + PB.registerLoopAnalyses(lam); + PB.crossRegisterProxies(lam, fam, cgam, mam); + + // Create the pass manager. + auto fpm = PB.buildFunctionSimplificationPipeline(llvm::OptimizationLevel::O2, + llvm::ThinOrFullLTOPhase::FullLTOPostLink); + + + // Run the optimizations over all functions in the module being added to + // the JIT. + for(auto& f : *m.getModuleUnlocked()) + { + if(f.empty()) + continue; + fpm.run(f, fam); + } + + return std::move(m); + }); + } + + LLVMJitBackend::~LLVMJitBackend() + { + llvm::Error res = _session->endSession(); + if(res) + { + std::string err; + llvm::raw_string_ostream errorStream(err); + errorStream << res; + std::cout << "Runtime session had an error when ending: " << err << std::endl; + } + }; + + void LLVMJitBackend::stageModule(std::shared_ptr module) + { + LLVMModuleBuilderCtx ctx; + auto newMod = ctx.buildModule( + *module, *_layout, _session->getExecutorProcessControl().getTargetTriple().str(), _orcCtx->getContext()); + if(!newMod) + { + std::cout << "Failed to stage module! 
" << newMod.err() << std::endl; + return; + } + _stagedModules.push_back(StageContext{ + .ir = module, + .llvm = std::move(newMod.ok()), + }); + } + + Result LLVMJitBackend::processModules() + { + for(auto& mod : _stagedModules) + { + + auto ctx = _orcCtx->getContext(); + auto& lib = _session->createBareJITDylib(mod.ir->id); + /*for(auto& importedModule : irModule.links) + { + if(!_modules.contains(importedModule)) + throw std::runtime_error(irModule.id + " Could not link to required module: " + importedModule); + lib.addToLinkOrder(*_modules.at(importedModule)->lib); + }*/ + + llvm::orc::SymbolMap symbols; + /*for(auto& f : irModule.functions) + { + auto nativeFunc = _nativeSymbols.find(f.name); + if(nativeFunc != _nativeSymbols.end()) + symbols[(*_mangler)(f.name)] = nativeFunc->second; + } + if(!symbols.empty()) + { + if(auto err = lib.define(llvm::orc::absoluteSymbols(symbols))) + throw std::runtime_error("Unable to define native symbols: " + toString(std::move(err))); + }*/ + + // auto module = new JitModule(lib); + // module->rt = lib.getDefaultResourceTracker(); + + /*printf("Loading module %s with structs:\n", irModule.id.c_str()); + for(auto& s : (*deserializedModule)->getIdentifiedStructTypes()) + printf(" %s\n", s->getName().str().c_str());*/ + + /*std::vector exportedStructLayouts; + for(auto& s : irModule.structs) + { + auto structType = llvm::StructType::getTypeByName(*ctx, s.name); + if(!structType) + throw std::runtime_error("Module load failed, unable to find struct metadata: " + s.name); + const llvm::StructLayout* structLayout = _layout->getStructLayout(structType); + exportedStructLayouts.push_back(structLayout); + }*/ + + if(llvm::Error res = _transformLayer->add(lib.getDefaultResourceTracker(), + llvm::orc::ThreadSafeModule(std::move(mod.llvm), *_orcCtx))) + throw std::runtime_error("Unable to add script: " + toString(std::move(res))); + + /*for(auto& glob : irModule.globals)*/ + /*{*/ + /* llvm::Expected globSymbol = _session->lookup({&lib}, 
+ * (*_mangler)(glob.name));*/ + /* if(!globSymbol)*/ + /* throw std::runtime_error("Unable to find exported global: " + toString(globSymbol.takeError()));*/ + /* module->globalNames.insert({glob.name, module->globalVars.size()});*/ + /* module->globalVars.push_back((void*)globSymbol->getAddress());*/ + /*}*/ + + for(auto& func : mod.ir->functions) + { + auto mangledId = (*_mangler)(func.id); + auto functionSym = _session->lookup({&lib}, mangledId); + if(!functionSym) + throw std::runtime_error("Unable to find exported function: " + toString(functionSym.takeError())); + _functions.insert({func.id, functionSym->getAddress().toPtr()}); + } + + /*std::unordered_map> newStructs; + // Extract all structs before populating members, so we can resolve out of order dependencies. + + for(auto& s : irModule.structs) + { + auto sConstructor = module->getFunction(s.constructorSig); + auto sDestructor = module->getFunction(s.destructorSig); + auto sCopyConstructor = module->getFunction(s.copyConstructorSig); + auto sMoveConstructor = module->getFunction(s.moveConstructorSig); + if(!sConstructor || !sDestructor || !sCopyConstructor || !sMoveConstructor) + throw std::runtime_error("Module load failed, missing constructors for: " + s.name); + auto sDef = new StructDef(s.name, sConstructor, sCopyConstructor, sMoveConstructor, sDestructor); + for(auto& tag : s.tags) + sDef->tags.insert(tag); + newStructs.emplace(s.name, sDef); + } + + for(auto& ns : newStructs) + { + if(_types.contains(ns.first)) + throw std::runtime_error("Module load failed, cannot load a type twice: " + ns.first); + _types.insert({ns.first, ns.second.get()}); + } + + size_t structIndex = 0; + for(auto& ns : newStructs) + { + auto structLayout = exportedStructLayouts[structIndex]; + auto& srcStruct = irModule.structs[structIndex]; + ns.second->size = structLayout->getSizeInBytes(); + for(size_t i = 0; i < srcStruct.members.size(); i++) + { + auto& src = srcStruct.members[i]; + auto memberOffset = 
structLayout->getElementOffset(i); + auto memberType = _types.find(src.type); + if(memberType == _types.end()) + throw std::runtime_error("Module load failed, unable to find type metadata: " + src.type); + ns.second->memberVars.push_back( + StructVar{src.name, VarType{memberType->second, src.isRef}, memberOffset}); + } + + printf("Adding struct %s\n", ns.first.c_str()); + printf("Size: %zu\n", ns.second->size); + printf("Members:\n"); + for(auto& m : ns.second->memberVars) + printf(" %s %s, offset: %zu\n", m.type.def->name.c_str(), m.name.c_str(), m.offset); + + module->structDefinitions.insert({ns.first, std::move(ns.second)}); + ++structIndex; + }*/ + + /*_modules.insert(irModule.id, module); + for(auto& dep : irModule.links) + _modules.addDependency(irModule.id, dep);*/ + } + return Ok<>(); + } + + Option> LLVMJitBackend::getPipeline(std::string_view moduleName, + std::string_view pipelineName) + { + return None(); + } + + std::unordered_map& LLVMJitBackend::functions() { return _functions; } +} // namespace BraneScript diff --git a/src/runtime/backends/llvm/llvmJitBackend.cpq~ b/src/runtime/backends/llvm/llvmJitBackend.cpq~ new file mode 100644 index 0000000..6c35d7c --- /dev/null +++ b/src/runtime/backends/llvm/llvmJitBackend.cpq~ @@ -0,0 +1,831 @@ +#include "llvmJitBackend.h" + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/Bitcode/BitcodeWriter.h" +#include "llvm/IR/DIBuilder.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" + +#include "llvm/Analysis/CGSCCPassManager.h" +#include "llvm/Bitcode/BitcodeReader.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/CompileUtils.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include 
"llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" +#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" +#include "llvm/ExecutionEngine/Orc/IRTransformLayer.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h" +#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" +#include "llvm/ExecutionEngine/SectionMemoryManager.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Support/DynamicLibrary.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/GVN.h" + +#include +#include +#include + +namespace BraneScript +{ + struct LLVMModuleBuilderCtx + { + llvm::LLVMContext* llvmCtx; + // IR builder + std::shared_ptr> builder; + std::unique_ptr llvmMod; + + + // Debug info builer + std::unique_ptr dBuilder; + llvm::DICompileUnit* diCompileUnit = nullptr; + llvm::DIFile* diFile = nullptr; + + // walker state + llvm::Function* currentFunc = nullptr; + llvm::BasicBlock* currentBlock = nullptr; + ; + + llvm::DISubprogram* diFunction = nullptr; + + const IRModule* currentMod; + + struct StructTypeCtx + { + llvm::StructType* llvmType; + const IRStruct* bsType; + }; + + struct FunctionCtx + { + llvm::Function* llvmFunc; + const IRFunction* bsFunc; + }; + + struct NativeTypeCtx + { + llvm::Type* llvmType = nullptr; + std::variant type; + }; + + using TypeContext = std::variant; + + struct ValueCtx + { + llvm::Value* value = nullptr; + TypeContext type; + }; + + std::vector structs; + std::vector functions; + std::vector values; + + // Cache values + llvm::FunctionType* functionType = nullptr; + llvm::Value* memTable = nullptr; + llvm::Type* branePtrType = nullptr; + ValueCtx retPtr; + + Result getNativeType(IRNativeType type) const + { + llvm::Type* llvmType = nullptr; + switch(type) + { + case IRNativeType::U8: + 
case IRNativeType::I8: + llvmType = llvm::Type::getInt8Ty(*llvmCtx); + case IRNativeType::U16: + case IRNativeType::I16: + llvmType = llvm::Type::getInt16Ty(*llvmCtx); + case IRNativeType::U32: + case IRNativeType::I32: + llvmType = llvm::Type::getInt32Ty(*llvmCtx); + case IRNativeType::F32: + llvmType = llvm::Type::getFloatTy(*llvmCtx); + case IRNativeType::U64: + case IRNativeType::I64: + llvmType = llvm::Type::getInt64Ty(*llvmCtx); + case IRNativeType::F64: + llvmType = llvm::Type::getDoubleTy(*llvmCtx); + case IRNativeType::U128: + case IRNativeType::I128: + llvmType = llvm::Type::getInt128Ty(*llvmCtx); + } + if(!llvmType) + return Err("Invalid IRNativeType value"); + return Ok(NativeTypeCtx{.llvmType = llvmType, .type = type}); + }; + + Result getStructType(const IDRef& id) + { + return MATCHV(id, + [&](const std::string& idStr) -> Result + { + for(auto& s : structs) + { + if(s.bsType->id.isNone()) + continue; + if(s.bsType->id.value() == idStr) + return Ok(s); + } + return Err(std::format("No struct defined with id {}", idStr)); + }, + [&](int32_t idIndex) -> Result + { + if(idIndex >= structs.size()) + return Err(std::format("Struct index out of range")); + return Ok(structs[idIndex]); + }); + } + + static llvm::Type* getLLVMType(const TypeContext& ctx) + { + return MATCHV(ctx, [](auto& ctx) -> llvm::Type* { return ctx.llvmType; }); + } + + Result getType(const IRType& type) + { + return MATCHV(type, + [&](const IRNativeType& type) -> Result + { + auto res = getNativeType(type); + if(!res) + return Err(res.err()); + return Ok(res.ok()); + }, + [&](const IDRef& type) -> Result + { + auto res = getStructType(type); + if(!res) + return Err(res.err()); + return Ok(res.ok()); + }); + }; + + Result getValue(const IRValue& value) + { + if(value.id >= values.size()) + return Err(std::format("Value index {} out of range", value.id)); + return Ok(values[value.id]); + } + + // Evalutate a int32 brane script pointer against the current memory table to get the full 
pointer + llvm::Value* evaluatePtr(llvm::Type* type, llvm::Value* intPtr) + { + auto* shiftedHigh = builder->CreateLShr(intPtr, builder->getInt32(16), ""); + auto* maskedLow = builder->CreateAnd(intPtr, builder->getInt32(0xFFFF), ""); + + auto* memIndex = builder->CreateTrunc(shiftedHigh, builder->getInt16Ty(), ""); + auto* bindingIndex = builder->CreateTrunc(maskedLow, builder->getInt16Ty(), "l"); + + auto* basePtr = builder->CreateGEP(memTable->getType(), memTable, {bindingIndex}); + auto* offsetPtr = builder->CreatePtrAdd(basePtr, memIndex); + return builder->CreateBitOrPointerCast(offsetPtr, llvm::PointerType::get(type, 0)); + } + + Result getStructMember(ValueCtx structValue, size_t member) + { + auto* structTypeCtx = std::get_if(&structValue.type); + if(!structTypeCtx) + return Err("Tried to get member of type that was not a struct"); + + if(member > structTypeCtx->bsType->members.size()) + return Err(std::format("Tried to get member {} but struct only has {} members", + member, + structTypeCtx->bsType->members.size())); + + + auto memberValue = builder->CreateStructGEP(structTypeCtx->llvmType, structValue.value, member); + auto memberTypeRes = getType(structTypeCtx->bsType->members[member].type); + if(!memberTypeRes) + return Err("Could not resove member type: " + memberTypeRes.err()); + + return Ok(ValueCtx{ + .value = memberValue, + .type = memberTypeRes.ok(), + }); + } + + Result<> buildMov(const MovOp& mov) + { + auto srcRes = getValue(mov.src); + CHECK_RESULT(srcRes); + auto destRes = getValue(mov.dest); + CHECK_RESULT(destRes); + + auto srcPtr = srcRes.ok(); + auto destPtr = destRes.ok(); + + auto srcValue = builder->CreateLoad(getLLVMType(srcPtr.type), srcPtr.value); + builder->CreateStore(srcValue, destPtr.value); + return Ok<>(); + } + + Result<> buildLoad(const LoadOp& load) + { + auto branePtrRes = getValue(load.ptr); + CHECK_RESULT(branePtrRes); + auto destRes = getValue(load.dest); + + auto branePtr = branePtrRes.ok(); + auto branePtrValue = 
builder->CreateLoad(getLLVMType(branePtr.type), branePtr.value); + auto destValue = destRes.ok(); + + + auto loadedPtr = evaluatePtr(getLLVMType(destValue.type), branePtr.value); + builder->CreateStore(loadedPtr, destValue.value); + return Ok<>(); + } + + Result<> buildStore(const StoreOp& store) + { + + auto branePtrRes = getValue(store.ptr); + CHECK_RESULT(branePtrRes); + auto srcRes = getValue(store.src); + + auto branePtr = branePtrRes.ok(); + auto branePtrValue = builder->CreateLoad(getLLVMType(branePtr.type), branePtr.value); + auto srcValuePtr = srcRes.ok(); + + auto srcValue = builder->CreateLoad(getLLVMType(srcValuePtr.type), srcValuePtr.value); + + + auto loadedPtr = evaluatePtr(getLLVMType(srcValuePtr.type), branePtr.value); + builder->CreateStore(srcValue, loadedPtr); + return Ok<>(); + } + + Result<> buildMemAcc(const MemAcc& memAcc) + { + auto srcRes = getValue(memAcc.ptr); + CHECK_RESULT(srcRes); + auto destRes = getValue(memAcc.dest); + CHECK_RESULT(destRes); + + auto srcPtr = srcRes.ok(); + auto destPtr = destRes.ok(); + + auto memberValue = getStructMember(srcPtr, memAcc.index); + CHECK_RESULT(memberValue); + values[memAcc.dest.id] = memberValue.ok(); + + return Ok<>(); + } + + Result<> buildAdd(const AddOp& op) + { + auto leftRes = getValue(op.left); + CHECK_RESULT(leftRes); + auto rightRes = getValue(op.right); + CHECK_RESULT(leftRes); + auto destRes = getValue(op.out); + CHECK_RESULT(destRes); + + auto leftPtr = leftRes.ok(); + auto rightPtr = leftRes.ok(); + auto destPtr = destRes.ok(); + + auto leftValue = builder->CreateLoad(getLLVMType(leftPtr.type), leftPtr.value); + auto rightValue = builder->CreateLoad(getLLVMType(rightPtr.type), rightPtr.value); + + auto res = builder->CreateAdd(leftValue, rightValue); + builder->CreateStore(res, destPtr.value); + + return Ok<>(); + } + + Result<> buildNextStage(const NextStageOp& call) + { + auto inputRes = getValue(call.input); + CHECK_RESULT(inputRes); + auto inputValue = inputRes.ok(); + auto 
inputStructType = std::get_if(&inputValue.type); + if(!inputStructType) + return Err("Tried to return non struct type"); + + auto memberCount = inputStructType->bsType->members.size(); + retPtr.type = inputValue.type; // "Temporary workaround" + for(size_t i = 0; i < memberCount; ++i) + { + auto inputMemberRes = getStructMember(inputValue, i); + + auto outputMemberRes = getStructMember(retPtr, i); + CHECK_RESULT(inputMemberRes); + CHECK_RESULT(outputMemberRes); + auto inputMemberValue = inputMemberRes.ok(); + builder->CreateStore(builder->CreateLoad(getLLVMType(inputMemberValue.type), inputMemberValue.value), + outputMemberRes.ok().value); + } + builder->CreateRetVoid(); + return Ok<>(); + } + + Result<> buildOp(const IROperation& op) + { + return MATCHV(op, + [&](const MovOp& mov) { return buildMov(mov); }, + [&](const LoadOp& load) -> Result<> { return buildLoad(load); }, + [&](const StoreOp& store) -> Result<> { return buildStore(store); }, + [&](const MemAcc& memAcc) -> Result<> { return buildMemAcc(memAcc); }, + [&](const ConstI32& c) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const ConstU32& c) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const ConstF32& c) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const AddOp& op) -> Result<> { return buildAdd(op); }, + [&](const SubOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const MulOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const DivOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const ModOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const EqOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const GeOp& op) -> Result<> + { + // TODO: 
Implement + return Err("Operand not implmeneted"); + }, + [&](const LogicNotOp& op) -> Result<> + { // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const LogicAndOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const LogicOrOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const BitNotOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const BitAndOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const BitOrOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const BitXorOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const CallOp& call) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const NextStageOp& call) -> Result<> { return buildNextStage(call); } + + ); + } + + Result> buildModule(const IRModule& module, + llvm::DataLayout& layout, + std::string_view triplet, + llvm::LLVMContext* ctxPtr) + { + llvmCtx = ctxPtr; + builder = std::make_shared>(*llvmCtx); + + currentMod = &module; + llvmMod = std::make_unique(module.id, *llvmCtx); + + llvmMod->setDataLayout(layout); // *_layout); + llvmMod->setTargetTriple(triplet); //_session->getExecutorProcessControl().getTargetTriple().str()); + + + branePtrType = llvm::ArrayType::get(llvm::Type::getInt16Ty(*llvmCtx), 2); + + structs.reserve(module.structs.size()); + for(size_t i = 0; i < module.structs.size(); ++i) + { + auto& s = module.structs[i]; + + std::string sid = s.id ? s.id.value() : std::format("-{}", i); + + std::vector memberTypes; + for(size_t m = 0; m < s.members.size(); ++m) + { + std::string mid = s.members[m].id ? 
s.members[m].id.value() : std::format("-{}", m); + auto llvmType = getType(s.members[m].type); + if(!llvmType) + return Err(llvmType.err()); + + memberTypes.push_back(getLLVMType(llvmType.ok())); + } + + llvm::StructType* st; + if(memberTypes.empty()) + st = llvm::StructType::create(*llvmCtx, sid); + else + st = llvm::StructType::create(memberTypes, sid, false); + + structs.push_back(StructTypeCtx{.llvmType = st, .bsType = &s}); + } + + auto memTableType = llvm::ArrayType::get(llvm::PointerType::get(llvm::Type::getInt8Ty(*llvmCtx), 0), 65535); + llvm::Type* funcArgs[3] = {llvm::PointerType::get(memTableType, + 0), // Pointer to array of addresses + llvm::Type::getInt32Ty(*llvmCtx), // Brane pointer to args struct + llvm::Type::getInt32Ty(*llvmCtx)}; // Brane pointer to output struct + functionType = llvm::FunctionType::get(llvm::Type::getVoidTy(*llvmCtx), funcArgs, false); + + + functions.reserve(module.functions.size()); + for(size_t i = 0; i < module.functions.size(); ++i) + { + auto& f = module.functions[i]; + + auto inputRes = getStructType(f.input); + if(!inputRes) + return Err(std::format("Couldn't compile input type for {}: {}", f.id, inputRes.err())); + auto inputType = inputRes.ok(); + + auto outputRes = getStructType(f.output); + if(!outputRes) + return Err(std::format("Couldn't compile output type for {}: {}", f.id, inputRes.err())); + auto outputType = outputRes.ok(); + + auto linkage = llvm::Function::ExternalLinkage; + currentFunc = llvm::Function::Create(functionType, linkage, f.id, llvmMod.get()); + currentBlock = llvm::BasicBlock::Create(*llvmCtx, "entry", currentFunc); + builder->SetInsertPoint(currentBlock); + + memTable = currentFunc->getArg(0); + auto inputStruct = + ValueCtx{.value = evaluatePtr(inputType.llvmType, currentFunc->getArg(1)), .type = inputType}; + + // Really need a better return mechanism + retPtr = + ValueCtx{.value = evaluatePtr(outputType.llvmType, currentFunc->getArg(2)), .type = outputType}; + + auto inputCount = 
inputType.bsType->members.size(); + for(size_t i = 0; i < f.localVars.size(); ++i) + { + auto& var = f.localVars[i]; + if(i < inputCount) + { + auto argRes = getStructMember(inputStruct, i); + if(!argRes) + return Err(argRes.err()); + values.push_back(argRes.ok()); + std::cout << i << " Re-route function arg member" << std::endl; + } + else + { + auto typeRes = getType(f.localVars[i]); + if(!typeRes) + return Err(typeRes.err()); + MATCHV(typeRes.ok(), + [&](NativeTypeCtx valueType) + { + auto* value = builder->CreateAlloca(valueType.llvmType); + values.push_back(ValueCtx{.value = value, .type = valueType}); + std::cout << i << " Allocated new native value" << std::endl; + }, + [&](StructTypeCtx valueType) + { + auto* value = builder->CreateAlloca(valueType.llvmType); + values.push_back(ValueCtx{.value = value, .type = valueType}); + std::cout << i << " Allocated new struct value" << std::endl; + }); + } + } + + for(auto& op : f.operations) + { + auto res = buildOp(op); + if(!res) + return Err(res.err()); + } + + values.clear(); + std::string funcError; + llvm::raw_string_ostream funcErrorStream(funcError); + if(llvm::verifyFunction(*currentFunc, &funcErrorStream)) + return Err(funcError); + } + + std::string modError; + llvm::raw_string_ostream modErrorStream(modError); + if(llvm::verifyModule(*llvmMod, &modErrorStream)) + return Err(modError); + + std::string moduleContent; + llvm::raw_string_ostream moduleStream(moduleContent); + llvmMod->print(moduleStream, nullptr); + std::cout << "Staging module: \n" << moduleContent << std::endl; + + return Ok(std::move(llvmMod)); + } + + LLVMModuleBuilderCtx() {} + }; // namespace BraneScript + + bool llvmInitialized = false; + + LLVMJitBackend::LLVMJitBackend() + { + if(!llvmInitialized) + { + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + llvm::InitializeNativeTargetAsmParser(); + llvmInitialized = true; + } + + _orcCtx = std::make_unique(std::make_unique()); + + auto EPC = 
llvm::orc::SelfExecutorProcessControl::Create(); + if(!EPC) + throw std::runtime_error("Could not create llvm epc"); + _session = std::make_unique(std::move(*EPC)); + + llvm::orc::JITTargetMachineBuilder jtmb(_session->getExecutorProcessControl().getTargetTriple()); + jtmb.setCodeGenOptLevel(llvm::CodeGenOptLevel::Aggressive); + + auto dl = jtmb.getDefaultDataLayoutForTarget(); + if(!dl) + throw std::runtime_error("Unable to get target data layout: " + toString(dl.takeError())); + _layout = std::make_unique(*dl); + + _mangler = std::make_unique(*_session, *_layout); + + _linkingLayer = std::make_unique( + *_session, []() { return std::make_unique(); }); + + // Debug stuff + _linkingLayer->registerJITEventListener(*llvm::JITEventListener::createGDBRegistrationListener()); + _linkingLayer->setProcessAllSections(true); + + _compileLayer = std::make_unique( + *_session, *_linkingLayer, std::make_unique(jtmb)); + std::shared_ptr tm; + if(auto etm = jtmb.createTargetMachine()) + tm = std::move(*etm); + _transformLayer = + std::make_unique(*_session, + *_compileLayer, + [this, tm](llvm::orc::ThreadSafeModule m, const auto& r) + { + // Add some optimizations. + std::string moduleErr; + llvm::raw_string_ostream modErrStr(moduleErr); + if(llvm::verifyModule(*m.getModuleUnlocked(), &modErrStr)) + { + std::string module; + llvm::raw_string_ostream modStr(module); + m.getModuleUnlocked()->print(modStr, nullptr); + fprintf(stderr, "%s", module.c_str()); + std::cout << "Module verification failed: " + modErrStr.str() << std::endl; + // TODO report failure async + } + + // Only return here if debug + return std::move(m); + + llvm::LoopAnalysisManager lam; + llvm::FunctionAnalysisManager fam; + llvm::CGSCCAnalysisManager cgam; + llvm::ModuleAnalysisManager mam; + + llvm::PassBuilder PB(tm.get()); + + // Register all the basic analyses with the managers. 
+ + PB.registerModuleAnalyses(mam); + PB.registerCGSCCAnalyses(cgam); + PB.registerFunctionAnalyses(fam); + PB.registerLoopAnalyses(lam); + PB.crossRegisterProxies(lam, fam, cgam, mam); + + // Create the pass manager. + auto fpm = PB.buildFunctionSimplificationPipeline(llvm::OptimizationLevel::O2, + llvm::ThinOrFullLTOPhase::FullLTOPostLink); + + + // Run the optimizations over all functions in the module being added to + // the JIT. + for(auto& f : *m.getModuleUnlocked()) + { + if(f.empty()) + continue; + fpm.run(f, fam); + } + + return std::move(m); + }); + } + + LLVMJitBackend::~LLVMJitBackend() + { + llvm::Error res = _session->endSession(); + if(res) + { + std::string err; + llvm::raw_string_ostream errorStream(err); + errorStream << res; + std::cout << "Runtime session had an error when ending: " << err << std::endl; + } + }; + + void LLVMJitBackend::stageModule(std::shared_ptr module) + { + LLVMModuleBuilderCtx ctx; + auto newMod = ctx.buildModule( + *module, *_layout, _session->getExecutorProcessControl().getTargetTriple().str(), _orcCtx->getContext()); + if(!newMod) + { + std::cout << "Failed to stage module! 
" << newMod.err() << std::endl; + return; + } + _stagedModules.push_back(StageContext{ + .ir = module, + .llvm = std::move(newMod.ok()), + }); + } + + Result LLVMJitBackend::processModules() + { + for(auto& mod : _stagedModules) + { + + auto ctx = _orcCtx->getContext(); + auto& lib = _session->createBareJITDylib(mod.ir->id); + /*for(auto& importedModule : irModule.links) + { + if(!_modules.contains(importedModule)) + throw std::runtime_error(irModule.id + " Could not link to required module: " + importedModule); + lib.addToLinkOrder(*_modules.at(importedModule)->lib); + }*/ + + llvm::orc::SymbolMap symbols; + /*for(auto& f : irModule.functions) + { + auto nativeFunc = _nativeSymbols.find(f.name); + if(nativeFunc != _nativeSymbols.end()) + symbols[(*_mangler)(f.name)] = nativeFunc->second; + } + if(!symbols.empty()) + { + if(auto err = lib.define(llvm::orc::absoluteSymbols(symbols))) + throw std::runtime_error("Unable to define native symbols: " + toString(std::move(err))); + }*/ + + // auto module = new JitModule(lib); + // module->rt = lib.getDefaultResourceTracker(); + + /*printf("Loading module %s with structs:\n", irModule.id.c_str()); + for(auto& s : (*deserializedModule)->getIdentifiedStructTypes()) + printf(" %s\n", s->getName().str().c_str());*/ + + /*std::vector exportedStructLayouts; + for(auto& s : irModule.structs) + { + auto structType = llvm::StructType::getTypeByName(*ctx, s.name); + if(!structType) + throw std::runtime_error("Module load failed, unable to find struct metadata: " + s.name); + const llvm::StructLayout* structLayout = _layout->getStructLayout(structType); + exportedStructLayouts.push_back(structLayout); + }*/ + + if(llvm::Error res = _transformLayer->add(lib.getDefaultResourceTracker(), + llvm::orc::ThreadSafeModule(std::move(mod.llvm), *_orcCtx))) + throw std::runtime_error("Unable to add script: " + toString(std::move(res))); + + /*for(auto& glob : irModule.globals)*/ + /*{*/ + /* llvm::Expected globSymbol = _session->lookup({&lib}, 
+ * (*_mangler)(glob.name));*/ + /* if(!globSymbol)*/ + /* throw std::runtime_error("Unable to find exported global: " + toString(globSymbol.takeError()));*/ + /* module->globalNames.insert({glob.name, module->globalVars.size()});*/ + /* module->globalVars.push_back((void*)globSymbol->getAddress());*/ + /*}*/ + + for(auto& func : mod.ir->functions) + { + auto mangledId = (*_mangler)(func.id); + auto functionSym = _session->lookup({&lib}, mangledId); + if(!functionSym) + throw std::runtime_error("Unable to find exported function: " + toString(functionSym.takeError())); + _functions.insert({func.id, functionSym->getAddress().toPtr()}); + } + + /*std::unordered_map> newStructs; + // Extract all structs before populating members, so we can resolve out of order dependencies. + + for(auto& s : irModule.structs) + { + auto sConstructor = module->getFunction(s.constructorSig); + auto sDestructor = module->getFunction(s.destructorSig); + auto sCopyConstructor = module->getFunction(s.copyConstructorSig); + auto sMoveConstructor = module->getFunction(s.moveConstructorSig); + if(!sConstructor || !sDestructor || !sCopyConstructor || !sMoveConstructor) + throw std::runtime_error("Module load failed, missing constructors for: " + s.name); + auto sDef = new StructDef(s.name, sConstructor, sCopyConstructor, sMoveConstructor, sDestructor); + for(auto& tag : s.tags) + sDef->tags.insert(tag); + newStructs.emplace(s.name, sDef); + } + + for(auto& ns : newStructs) + { + if(_types.contains(ns.first)) + throw std::runtime_error("Module load failed, cannot load a type twice: " + ns.first); + _types.insert({ns.first, ns.second.get()}); + } + + size_t structIndex = 0; + for(auto& ns : newStructs) + { + auto structLayout = exportedStructLayouts[structIndex]; + auto& srcStruct = irModule.structs[structIndex]; + ns.second->size = structLayout->getSizeInBytes(); + for(size_t i = 0; i < srcStruct.members.size(); i++) + { + auto& src = srcStruct.members[i]; + auto memberOffset = 
structLayout->getElementOffset(i); + auto memberType = _types.find(src.type); + if(memberType == _types.end()) + throw std::runtime_error("Module load failed, unable to find type metadata: " + src.type); + ns.second->memberVars.push_back( + StructVar{src.name, VarType{memberType->second, src.isRef}, memberOffset}); + } + + printf("Adding struct %s\n", ns.first.c_str()); + printf("Size: %zu\n", ns.second->size); + printf("Members:\n"); + for(auto& m : ns.second->memberVars) + printf(" %s %s, offset: %zu\n", m.type.def->name.c_str(), m.name.c_str(), m.offset); + + module->structDefinitions.insert({ns.first, std::move(ns.second)}); + ++structIndex; + }*/ + + /*_modules.insert(irModule.id, module); + for(auto& dep : irModule.links) + _modules.addDependency(irModule.id, dep);*/ + } + return Ok<>(); + } + + Option> LLVMJitBackend::getPipeline(std::string_view moduleName, + std::string_view pipelineName) + { + return None(); + } + + std::unordered_map& LLVMJitBackend::functions() { return _functions; } +} // namespace BraneScript diff --git a/src/runtime/backends/llvm/llvmJitBackend.cpz~ b/src/runtime/backends/llvm/llvmJitBackend.cpz~ new file mode 100644 index 0000000..6ad8b2e --- /dev/null +++ b/src/runtime/backends/llvm/llvmJitBackend.cpz~ @@ -0,0 +1,787 @@ +#include "llvmJitBackend.h" + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/Bitcode/BitcodeWriter.h" +#include "llvm/IR/DIBuilder.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" + +#include "llvm/Analysis/CGSCCPassManager.h" +#include "llvm/Bitcode/BitcodeReader.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/CompileUtils.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include 
"llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" +#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" +#include "llvm/ExecutionEngine/Orc/IRTransformLayer.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h" +#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" +#include "llvm/ExecutionEngine/SectionMemoryManager.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Support/DynamicLibrary.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/GVN.h" + +#include +#include +#include + +namespace BraneScript +{ + struct LLVMModuleBuilderCtx + { + llvm::LLVMContext* llvmCtx; + // IR builder + std::shared_ptr> builder; + std::unique_ptr llvmMod; + + + // Debug info builer + std::unique_ptr dBuilder; + llvm::DICompileUnit* diCompileUnit = nullptr; + llvm::DIFile* diFile = nullptr; + + // walker state + llvm::Function* currentFunc = nullptr; + llvm::BasicBlock* currentBlock = nullptr; + ; + + llvm::DISubprogram* diFunction = nullptr; + + const IRModule* currentMod; + + struct StructTypeCtx + { + llvm::StructType* llvmType; + const IRStruct* bsType; + }; + + struct FunctionCtx + { + llvm::Function* llvmFunc; + const IRFunction* bsFunc; + }; + + struct NativeTypeCtx + { + llvm::Type* llvmType = nullptr; + std::variant type; + }; + + using TypeContext = std::variant; + + struct ValueCtx + { + llvm::Value* value = nullptr; + TypeContext type; + }; + + std::vector structs; + std::vector functions; + std::vector values; + + // Cache values + llvm::FunctionType* functionType = nullptr; + llvm::Value* memTable = nullptr; + llvm::Type* branePtrType = nullptr; + ValueCtx retPtr; + + Result getNativeType(IRNativeType type) const + { + llvm::Type* llvmType = nullptr; + switch(type) + { + case IRNativeType::U8: + 
case IRNativeType::I8: + llvmType = llvm::Type::getInt8Ty(*llvmCtx); + case IRNativeType::U16: + case IRNativeType::I16: + llvmType = llvm::Type::getInt16Ty(*llvmCtx); + case IRNativeType::U32: + case IRNativeType::I32: + llvmType = llvm::Type::getInt32Ty(*llvmCtx); + case IRNativeType::F32: + llvmType = llvm::Type::getFloatTy(*llvmCtx); + case IRNativeType::U64: + case IRNativeType::I64: + llvmType = llvm::Type::getInt64Ty(*llvmCtx); + case IRNativeType::F64: + llvmType = llvm::Type::getDoubleTy(*llvmCtx); + case IRNativeType::U128: + case IRNativeType::I128: + llvmType = llvm::Type::getInt128Ty(*llvmCtx); + } + if(!llvmType) + return Err("Invalid IRNativeType value"); + return Ok(NativeTypeCtx{.llvmType = llvmType, .type = type}); + }; + + Result getStructType(const IDRef& id) + { + return MATCHV(id, + [&](const std::string& idStr) -> Result + { + for(auto& s : structs) + { + if(s.bsType->id.isNone()) + continue; + if(s.bsType->id.value() == idStr) + return Ok(s); + } + return Err(std::format("No struct defined with id {}", idStr)); + }, + [&](int32_t idIndex) -> Result + { + if(idIndex >= structs.size()) + return Err(std::format("Struct index out of range")); + return Ok(structs[idIndex]); + }); + } + + static llvm::Type* getLLVMType(const TypeContext& ctx) + { + return MATCHV(ctx, [](auto& ctx) -> llvm::Type* { return ctx.llvmType; }); + } + + Result getType(const IRType& type) + { + return MATCHV(type, + [&](const IRNativeType& type) -> Result + { + auto res = getNativeType(type); + if(!res) + return Err(res.err()); + return Ok(res.ok()); + }, + [&](const IDRef& type) -> Result + { + auto res = getStructType(type); + if(!res) + return Err(res.err()); + return Ok(res.ok()); + }); + }; + + Result getValue(const IRValue& value) + { + if(value.id >= values.size()) + return Err(std::format("Value index {} out of range", value.id)); + return Ok(values[value.id]); + } + + // Evalutate a int32 brane script pointer against the current memory table to get the full 
pointer + llvm::Value* evaluatePtr(llvm::Type* type, llvm::Value* intPtr) + { + auto* shiftedHigh = builder->CreateLShr(intPtr, builder->getInt32(16), ""); + auto* maskedLow = builder->CreateAnd(intPtr, builder->getInt32(0xFFFF), ""); + + auto* memIndex = builder->CreateTrunc(shiftedHigh, builder->getInt16Ty(), ""); + auto* bindingIndex = builder->CreateTrunc(maskedLow, builder->getInt16Ty(), "l"); + + auto* basePtr = builder->CreateGEP(memTable->getType(), memTable, {bindingIndex}); + auto* offsetPtr = builder->CreatePtrAdd(basePtr, memIndex); + return builder->CreateBitOrPointerCast(offsetPtr, llvm::PointerType::get(type, 0)); + } + + // Build a pointer to member number `member` of structValue and pair it with that member's resolved type. + // Returns Err when structValue is not a struct, the index is out of range, or the member type cannot be resolved. + Result getStructMember(ValueCtx structValue, size_t member) + { + auto* structTypeCtx = std::get_if(&structValue.type); + if(!structTypeCtx) + return Err("Tried to get member of type that was not a struct"); + + // member is used to index members below, so member == size() must be rejected as well + if(member >= structTypeCtx->bsType->members.size()) + return Err(std::format("Tried to get member {} but struct only has {} members", + member, + structTypeCtx->bsType->members.size())); + + + auto memberValue = builder->CreateStructGEP(structTypeCtx->llvmType, structValue.value, member); + auto memberTypeRes = getType(structTypeCtx->bsType->members[member].type); + if(!memberTypeRes) + return Err("Could not resove member type: " + memberTypeRes.err()); + + return Ok(ValueCtx{ + .value = memberValue, + .type = memberTypeRes.ok(), + }); + } + + // Emit a load of the value stored at mov.src followed by a store of it to mov.dest. + Result<> buildMov(const MovOp& mov) + { + auto srcRes = getValue(mov.src); + CHECK_RESULT(srcRes); + auto destRes = getValue(mov.dest); + CHECK_RESULT(destRes); + + auto srcPtr = srcRes.ok(); + auto destPtr = destRes.ok(); + + auto srcValue = builder->CreateLoad(getLLVMType(srcPtr.type), srcPtr.value); + builder->CreateStore(srcValue, destPtr.value); + return Ok<>(); + } + + // Emit code for a LoadOp: resolve the brane pointer held in load.ptr and write the result to load.dest. + Result<> buildLoad(const LoadOp& load) + { + auto branePtrRes = getValue(load.ptr); + CHECK_RESULT(branePtrRes); + auto destRes = getValue(load.dest); + // destRes.ok() is read below, so a failed lookup has to be reported first + CHECK_RESULT(destRes); + + auto branePtr = branePtrRes.ok(); + auto branePtrValue =
builder->CreateLoad(getLLVMType(branePtr.type), branePtr.value); + auto destValue = destRes.ok(); + + + // evaluatePtr shifts and masks its argument as an integer, so it needs the loaded brane pointer value, not the address of the variable holding it + // NOTE(review): the resolved address itself is stored into dest rather than the value it points at - confirm this is intended + auto loadedPtr = evaluatePtr(getLLVMType(destValue.type), branePtrValue); + builder->CreateStore(loadedPtr, destValue.value); + return Ok<>(); + } + + // Emit code for a StoreOp: load the value of store.src and store it at the address the brane pointer in store.ptr resolves to. + Result<> buildStore(const StoreOp& store) + { + + auto branePtrRes = getValue(store.ptr); + CHECK_RESULT(branePtrRes); + auto srcRes = getValue(store.src); + // srcRes.ok() is read below, so a failed lookup has to be reported first + CHECK_RESULT(srcRes); + + auto branePtr = branePtrRes.ok(); + auto branePtrValue = builder->CreateLoad(getLLVMType(branePtr.type), branePtr.value); + auto srcValuePtr = srcRes.ok(); + + auto srcValue = builder->CreateLoad(getLLVMType(srcValuePtr.type), srcValuePtr.value); + + + // evaluatePtr shifts and masks its argument as an integer, so it needs the loaded brane pointer value, not the address of the variable holding it + auto loadedPtr = evaluatePtr(getLLVMType(srcValuePtr.type), branePtrValue); + builder->CreateStore(srcValue, loadedPtr); + return Ok<>(); + } + + // Emit code for a MemAcc: re-route the dest value slot to point at member memAcc.index of the struct in memAcc.ptr. + Result<> buildMemAcc(const MemAcc& memAcc) + { + auto srcRes = getValue(memAcc.ptr); + CHECK_RESULT(srcRes); + auto destRes = getValue(memAcc.dest); + CHECK_RESULT(destRes); + + auto srcPtr = srcRes.ok(); + auto destPtr = destRes.ok(); + + auto memberValue = getStructMember(srcPtr, memAcc.index); + CHECK_RESULT(memberValue); + values[memAcc.dest.id] = memberValue.ok(); + + return Ok<>(); + } + + // Emit code for an AddOp: load op.left and op.right, add them with CreateAdd and store the sum to op.out. + Result<> buildAdd(const AddOp& op) + { + auto leftRes = getValue(op.left); + CHECK_RESULT(leftRes); + auto rightRes = getValue(op.right); + CHECK_RESULT(rightRes); + auto destRes = getValue(op.out); + CHECK_RESULT(destRes); + + auto leftPtr = leftRes.ok(); + auto rightPtr = rightRes.ok(); + auto destPtr = destRes.ok(); + + auto leftValue = builder->CreateLoad(getLLVMType(leftPtr.type), leftPtr.value); + auto rightValue = builder->CreateLoad(getLLVMType(rightPtr.type), rightPtr.value); + + auto res = builder->CreateAdd(leftValue, rightValue); + builder->CreateStore(res, destPtr.value); + + return Ok<>(); + } + + Result<> buildNextStage(const NextStageOp& call) + { + auto inputRes = getValue(call.input); + CHECK_RESULT(inputRes); + auto inputValue = inputRes.ok(); + auto
inputStructType = std::get_if(&inputValue.type); + if(!inputStructType) + return Err("Tried to return non struct type"); + + auto memberCount = inputStructType->bsType->members.size(); + retPtr.type = inputValue.type; // "Temporary workaround" + for(size_t i = 0; i < memberCount; ++i) + { + auto inputMemberRes = getStructMember(inputValue, i); + + auto outputMemberRes = getStructMember(retPtr, i); + CHECK_RESULT(inputMemberRes); + CHECK_RESULT(outputMemberRes); + auto inputMemberValue = inputMemberRes.ok(); + builder->CreateStore(builder->CreateLoad(getLLVMType(inputMemberValue.type), inputMemberValue.value), + outputMemberRes.ok().value); + } + builder->CreateRetVoid(); + return Ok<>(); + } + + Result<> buildOp(const IROperation& op) + { + return MATCHV(op, + [&](const MovOp& mov) { return buildMov(mov); }, + [&](const LoadOp& load) -> Result<> { return buildLoad(load); }, + [&](const StoreOp& store) -> Result<> { return buildStore(store); }, + [&](const MemAcc& memAcc) -> Result<> { return buildMemAcc(memAcc); }, + [&](const ConstI32& c) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const ConstU32& c) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const ConstF32& c) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const AddOp& op) -> Result<> { return buildAdd(op); }, + [&](const SubOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const MulOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const DivOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const ModOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const EqOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const GeOp& op) -> Result<> + { + // TODO: 
Implement + return Err("Operand not implmeneted"); + }, + [&](const LogicNotOp& op) -> Result<> + { // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const LogicAndOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const LogicOrOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const BitNotOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const BitAndOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const BitOrOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const BitXorOp& op) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const CallOp& call) -> Result<> + { + // TODO: Implement + return Err("Operand not implmeneted"); + }, + [&](const NextStageOp& call) -> Result<> { return buildNextStage(call); } + + ); + } + + Result> buildModule(const IRModule& module, + llvm::DataLayout& layout, + std::string_view triplet, + llvm::LLVMContext* ctxPtr) + { + llvmCtx = ctxPtr; + builder = std::make_shared>(*llvmCtx); + + currentMod = &module; + llvmMod = std::make_unique(module.id, *llvmCtx); + + llvmMod->setDataLayout(layout); // *_layout); + llvmMod->setTargetTriple(triplet); //_session->getExecutorProcessControl().getTargetTriple().str()); + + + branePtrType = llvm::ArrayType::get(llvm::Type::getInt16Ty(*llvmCtx), 2); + + structs.reserve(module.structs.size()); + for(size_t i = 0; i < module.structs.size(); ++i) + { + auto& s = module.structs[i]; + + std::string sid = s.id ? s.id.value() : std::format("-s{}", i); + + std::vector memberTypes; + for(size_t m = 0; m < s.members.size(); ++m) + { + std::string mid = s.members[m].id ? 
s.members[m].id.value() : std::format("-{}", m); + auto llvmType = getType(s.members[m].type); + if(!llvmType) + return Err(llvmType.err()); + + memberTypes.push_back(getLLVMType(llvmType.ok())); + } + + llvm::StructType* st; + if(memberTypes.empty()) + st = llvm::StructType::create(*llvmCtx, sid); + else + st = llvm::StructType::create(memberTypes, sid, false); + + structs.push_back(StructTypeCtx{.llvmType = st, .bsType = &s}); + } + + auto memTableType = llvm::ArrayType::get(llvm::PointerType::get(llvm::Type::getInt8Ty(*llvmCtx), 0), 65535); + llvm::Type* funcArgs[3] = {llvm::PointerType::get(memTableType, + 0), // Pointer to array of addresses + llvm::Type::getInt32Ty(*llvmCtx), // Brane pointer to args struct + llvm::Type::getInt32Ty(*llvmCtx)}; // Brane pointer to output struct + functionType = llvm::FunctionType::get(llvm::Type::getVoidTy(*llvmCtx), funcArgs, false); + + + functions.reserve(module.functions.size()); + for(size_t i = 0; i < module.functions.size(); ++i) + { + auto& f = module.functions[i]; + + auto inputRes = getStructType(f.input); + if(!inputRes) + return Err(std::format("Couldn't compile input type for {}: {}", f.id, inputRes.err())); + auto inputType = inputRes.ok(); + + auto outputRes = getStructType(f.output); + if(!outputRes) + return Err(std::format("Couldn't compile output type for {}: {}", f.id, inputRes.err())); + auto outputType = outputRes.ok(); + + auto linkage = llvm::Function::ExternalLinkage; + currentFunc = llvm::Function::Create(functionType, linkage, f.id, llvmMod.get()); + currentFunc->setCallingConv(llvm::CallingConv::C); + currentBlock = llvm::BasicBlock::Create(*llvmCtx, "entry", currentFunc); + builder->SetInsertPoint(currentBlock); + + memTable = currentFunc->getArg(0); + auto inputStruct = + ValueCtx{.value = evaluatePtr(inputType.llvmType, currentFunc->getArg(1)), .type = inputType}; + + // Really need a better return mechanism + retPtr = + ValueCtx{.value = evaluatePtr(outputType.llvmType, currentFunc->getArg(2)), 
.type = outputType}; + + auto inputCount = inputType.bsType->members.size(); + for(size_t i = 0; i < f.localVars.size(); ++i) + { + auto& var = f.localVars[i]; + if(i < inputCount) + { + auto argRes = getStructMember(inputStruct, i); + if(!argRes) + return Err(argRes.err()); + values.push_back(argRes.ok()); + std::cout << i << " Re-route function arg member" << std::endl; + } + else + { + auto typeRes = getType(f.localVars[i]); + if(!typeRes) + return Err(typeRes.err()); + MATCHV(typeRes.ok(), + [&](NativeTypeCtx valueType) + { + auto* value = builder->CreateAlloca(valueType.llvmType); + values.push_back(ValueCtx{.value = value, .type = valueType}); + std::cout << i << " Allocated new native value" << std::endl; + }, + [&](StructTypeCtx valueType) + { + auto* value = builder->CreateAlloca(valueType.llvmType); + values.push_back(ValueCtx{.value = value, .type = valueType}); + std::cout << i << " Allocated new struct value" << std::endl; + }); + } + } + + for(auto& op : f.operations) + { + auto res = buildOp(op); + if(!res) + return Err(res.err()); + } + + values.clear(); + std::string funcError; + llvm::raw_string_ostream funcErrorStream(funcError); + if(llvm::verifyFunction(*currentFunc, &funcErrorStream)) + return Err(funcError); + } + + std::string modError; + llvm::raw_string_ostream modErrorStream(modError); + if(llvm::verifyModule(*llvmMod, &modErrorStream)) + return Err(modError); + + std::string moduleContent; + llvm::raw_string_ostream moduleStream(moduleContent); + llvmMod->print(moduleStream, nullptr); + std::cout << "Staging module: \n" << moduleContent << std::endl; + + return Ok(std::move(llvmMod)); + } + + LLVMModuleBuilderCtx() {} + }; // namespace BraneScript + + bool llvmInitialized = false; + + LLVMJitBackend::LLVMJitBackend() + { + if(!llvmInitialized) + { + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + llvm::InitializeNativeTargetAsmParser(); + llvmInitialized = true; + } + + _orcCtx = 
std::make_unique(std::make_unique()); + + auto EPC = llvm::orc::SelfExecutorProcessControl::Create(); + if(!EPC) + throw std::runtime_error("Could not create llvm epc"); + _session = std::make_unique(std::move(*EPC)); + + llvm::orc::JITTargetMachineBuilder jtmb(_session->getExecutorProcessControl().getTargetTriple()); + jtmb.setCodeGenOptLevel(llvm::CodeGenOptLevel::Aggressive); + + auto dl = jtmb.getDefaultDataLayoutForTarget(); + if(!dl) + throw std::runtime_error("Unable to get target data layout: " + toString(dl.takeError())); + _layout = std::make_unique(*dl); + + _mangler = std::make_unique(*_session, *_layout); + + _linkingLayer = std::make_unique( + *_session, []() { return std::make_unique(); }); + + // Debug stuff + _linkingLayer->registerJITEventListener(*llvm::JITEventListener::createGDBRegistrationListener()); + _linkingLayer->setProcessAllSections(true); + + _compileLayer = std::make_unique( + *_session, *_linkingLayer, std::make_unique(jtmb)); + std::shared_ptr tm; + if(auto etm = jtmb.createTargetMachine()) + tm = std::move(*etm); + _transformLayer = + std::make_unique(*_session, + *_compileLayer, + [this, tm](llvm::orc::ThreadSafeModule m, const auto& r) + { + // Add some optimizations. + std::string moduleErr; + llvm::raw_string_ostream modErrStr(moduleErr); + if(llvm::verifyModule(*m.getModuleUnlocked(), &modErrStr)) + { + std::string module; + llvm::raw_string_ostream modStr(module); + m.getModuleUnlocked()->print(modStr, nullptr); + fprintf(stderr, "%s", module.c_str()); + std::cout << "Module verification failed: " + modErrStr.str() << std::endl; + // TODO report failure async + } + + // Only return here if debug + std::cout << "Loaded module into runtime" << std::endl; + return std::move(m); + + llvm::LoopAnalysisManager lam; + llvm::FunctionAnalysisManager fam; + llvm::CGSCCAnalysisManager cgam; + llvm::ModuleAnalysisManager mam; + + llvm::PassBuilder PB(tm.get()); + + // Register all the basic analyses with the managers. 
+ + PB.registerModuleAnalyses(mam); + PB.registerCGSCCAnalyses(cgam); + PB.registerFunctionAnalyses(fam); + PB.registerLoopAnalyses(lam); + PB.crossRegisterProxies(lam, fam, cgam, mam); + + // Create the pass manager. + auto fpm = PB.buildFunctionSimplificationPipeline(llvm::OptimizationLevel::O2, + llvm::ThinOrFullLTOPhase::FullLTOPostLink); + + + // Run the optimizations over all functions in the module being added to + // the JIT. + for(auto& f : *m.getModuleUnlocked()) + { + if(f.empty()) + continue; + fpm.run(f, fam); + } + + return std::move(m); + }); + } + + LLVMJitBackend::~LLVMJitBackend() + { + llvm::Error res = _session->endSession(); + if(res) + { + std::string err; + llvm::raw_string_ostream errorStream(err); + errorStream << res; + std::cout << "Runtime session had an error when ending: " << err << std::endl; + } + }; + + void LLVMJitBackend::stageModule(std::shared_ptr module) + { + LLVMModuleBuilderCtx ctx; + auto newMod = ctx.buildModule( + *module, *_layout, _session->getExecutorProcessControl().getTargetTriple().str(), _orcCtx->getContext()); + if(!newMod) + { + std::cout << "Failed to stage module! 
" << newMod.err() << std::endl; + return; + } + _stagedModules.push_back(StageContext{ + .ir = module, + .llvm = std::move(newMod.ok()), + }); + } + + Result LLVMJitBackend::processModules() + { + for(auto& mod : _stagedModules) + { + + auto ctx = _orcCtx->getContext(); + auto& lib = _session->createBareJITDylib(mod.ir->id); + + if(llvm::Error res = _transformLayer->add(lib, llvm::orc::ThreadSafeModule(std::move(mod.llvm), *_orcCtx))) + throw std::runtime_error("Unable to add script: " + toString(std::move(res))); + + for(auto& func : mod.ir->functions) + { + auto mangledId = (*_mangler)(func.id); + auto functionSym = _session->lookup({&lib}, mangledId); + if(!functionSym) + throw std::runtime_error("Unable to find exported function: " + toString(functionSym.takeError())); + _functions.insert({func.id, functionSym->getAddress().toPtr()}); + } + + + /*std::unordered_map> newStructs; + // Extract all structs before populating members, so we can resolve out of order dependencies. + + for(auto& s : irModule.structs) + { + auto sConstructor = module->getFunction(s.constructorSig); + auto sDestructor = module->getFunction(s.destructorSig); + auto sCopyConstructor = module->getFunction(s.copyConstructorSig); + auto sMoveConstructor = module->getFunction(s.moveConstructorSig); + if(!sConstructor || !sDestructor || !sCopyConstructor || !sMoveConstructor) + throw std::runtime_error("Module load failed, missing constructors for: " + s.name); + auto sDef = new StructDef(s.name, sConstructor, sCopyConstructor, sMoveConstructor, sDestructor); + for(auto& tag : s.tags) + sDef->tags.insert(tag); + newStructs.emplace(s.name, sDef); + } + + for(auto& ns : newStructs) + { + if(_types.contains(ns.first)) + throw std::runtime_error("Module load failed, cannot load a type twice: " + ns.first); + _types.insert({ns.first, ns.second.get()}); + } + + size_t structIndex = 0; + for(auto& ns : newStructs) + { + auto structLayout = exportedStructLayouts[structIndex]; + auto& srcStruct = 
irModule.structs[structIndex]; + ns.second->size = structLayout->getSizeInBytes(); + for(size_t i = 0; i < srcStruct.members.size(); i++) + { + auto& src = srcStruct.members[i]; + auto memberOffset = structLayout->getElementOffset(i); + auto memberType = _types.find(src.type); + if(memberType == _types.end()) + throw std::runtime_error("Module load failed, unable to find type metadata: " + src.type); + ns.second->memberVars.push_back( + StructVar{src.name, VarType{memberType->second, src.isRef}, memberOffset}); + } + + printf("Adding struct %s\n", ns.first.c_str()); + printf("Size: %zu\n", ns.second->size); + printf("Members:\n"); + for(auto& m : ns.second->memberVars) + printf(" %s %s, offset: %zu\n", m.type.def->name.c_str(), m.name.c_str(), m.offset); + + module->structDefinitions.insert({ns.first, std::move(ns.second)}); + ++structIndex; + }*/ + + /*_modules.insert(irModule.id, module); + for(auto& dep : irModule.links) + _modules.addDependency(irModule.id, dep);*/ + } + // Each staged entry had its llvm module moved into the JIT by the loop above; drop the entries so a later call does not process them again + _stagedModules.clear(); + return Ok<>(); + } + + Option> LLVMJitBackend::getPipeline(std::string_view moduleName, + std::string_view pipelineName) + { + return None(); + } + + std::unordered_map& LLVMJitBackend::functions() { return _functions; } +} // namespace BraneScript diff --git a/src/runtime/backends/llvm/llvmJitBackend.h b/src/runtime/backends/llvm/llvmJitBackend.h new file mode 100644 index 0000000..1fec6e1 --- /dev/null +++ b/src/runtime/backends/llvm/llvmJitBackend.h @@ -0,0 +1,55 @@ +#pragma once + +#include "../jitBackend.h" + +namespace llvm +{ + class DataLayout; + class Type; + class JITEvaluatedSymbol; + class Module; + class LLVMContext; + + namespace orc + { + class LLJIT; + class ExecutionSession; + class IRCompileLayer; + class MangleAndInterner; + class RTDyldObjectLinkingLayer; + class IRTransformLayer; + class JITDylib; + class ThreadSafeModule; + } // namespace orc +} // namespace llvm + +namespace BraneScript +{ + class LLVMJitBackend : public JitBackend + { + struct StageContext + { +
std::shared_ptr ir; + std::unique_ptr llvm; + }; + + std::vector _stagedModules; + + std::unique_ptr _llJit; + + std::unordered_map _nativeSymbols; + + std::unordered_map _functions; + + public: + LLVMJitBackend(); + ~LLVMJitBackend(); + void stageModule(std::shared_ptr module) override; + /// Consume all staged modules + Result processModules() override; + Option> getPipeline(std::string_view moduleName, + std::string_view pipelineName) override; + + std::unordered_map& functions(); + }; +} // namespace BraneScript diff --git a/src/runtime/runtime.cpp b/src/runtime/runtime.cpp new file mode 100644 index 0000000..e69de29 diff --git a/src/runtime/runtime.h b/src/runtime/runtime.h new file mode 100644 index 0000000..2f587fc --- /dev/null +++ b/src/runtime/runtime.h @@ -0,0 +1,12 @@ +#pragma once +#include "backends/jitBackend.h" + +namespace BraneScript +{ + class Runtime + { + std::shared_ptr _jitBackend; + + public: + }; +} // namespace BraneScript diff --git a/vcpkg.json b/vcpkg.json index 209aa74..623d743 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -1,5 +1,16 @@ { "dependencies": [ + { + "name": "llvm", + "default-features": false, + "features": [ + "enable-assertions", + "enable-eh", + "enable-ffi", + "utils", + "default-targets" + ] + }, { "name": "tree-sitter" },