diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index 2baf922db..e047f7133 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -1,9 +1,13 @@ #pragma once #include +#include #include +#include // Required for std::less +#include #include #include +#include #include #include #include @@ -24,331 +28,194 @@ namespace rnexecutorch::jsi_conversion { using namespace facebook; -// Conversion from jsi to C++ types -------------------------------------------- - -template T getValue(const jsi::Value &val, jsi::Runtime &runtime); +// ================================================================================================= +// HELPERS (Internal) +// ================================================================================================= +namespace detail { template - requires meta::IsNumeric -inline T getValue(const jsi::Value &val, jsi::Runtime &runtime) { - return static_cast(val.asNumber()); -} - -template <> -inline bool getValue(const jsi::Value &val, jsi::Runtime &runtime) { - return val.asBool(); -} - -template <> -inline std::string getValue(const jsi::Value &val, - jsi::Runtime &runtime) { - return val.getString(runtime).utf8(runtime); -} - -template <> -inline std::u32string getValue(const jsi::Value &val, - jsi::Runtime &runtime) { - std::string utf8 = getValue(val, runtime); - std::wstring_convert, char32_t> conv; - - return conv.from_bytes(utf8); -} - -template <> -inline std::shared_ptr -getValue>(const jsi::Value &val, - jsi::Runtime &runtime) { - return std::make_shared( - val.asObject(runtime).asFunction(runtime)); -} - -template <> -inline JSTensorViewIn getValue(const jsi::Value &val, - jsi::Runtime &runtime) { +inline std::pair getTypedArrayData(const jsi::Value &val, + jsi::Runtime &runtime) { jsi::Object obj = val.asObject(runtime); - JSTensorViewIn tensorView; - int scalarTypeInt = obj.getProperty(runtime, "scalarType").asNumber(); - tensorView.scalarType = static_cast(scalarTypeInt); - - jsi::Value shapeValue = obj.getProperty(runtime, "sizes"); - jsi::Array shapeArray = shapeValue.asObject(runtime).asArray(runtime); - size_t numShapeDims = shapeArray.size(runtime); - tensorView.sizes.reserve(numShapeDims); - - for (size_t i = 0; i < numShapeDims; ++i) { - int32_t dim = - getValue(shapeArray.getValueAtIndex(runtime, i), runtime); - tensorView.sizes.push_back(dim); + if (obj.isArrayBuffer(runtime)) { + jsi::ArrayBuffer buffer = obj.getArrayBuffer(runtime); + return {reinterpret_cast(buffer.data(runtime)), + buffer.size(runtime) / sizeof(T)}; } - // On JS side, TensorPtr objects hold a 'data' property which should be either - // an ArrayBuffer or TypedArray - jsi::Value dataValue = obj.getProperty(runtime, "dataPtr"); - jsi::Object dataObj = dataValue.asObject(runtime); - - // Check if it's an ArrayBuffer or TypedArray - if (dataObj.isArrayBuffer(runtime)) { - jsi::ArrayBuffer arrayBuffer = dataObj.getArrayBuffer(runtime); - tensorView.dataPtr = arrayBuffer.data(runtime); - } else { - // Handle typed arrays (Float32Array, Int32Array, etc.) - const bool isValidTypedArray = dataObj.hasProperty(runtime, "buffer") && - dataObj.hasProperty(runtime, "byteOffset") && - dataObj.hasProperty(runtime, "byteLength") && - dataObj.hasProperty(runtime, "length"); - if (!isValidTypedArray) { - throw jsi::JSError(runtime, "Data must be an ArrayBuffer or TypedArray"); - } - jsi::Value bufferValue = dataObj.getProperty(runtime, "buffer"); - if (!bufferValue.isObject() || - !bufferValue.asObject(runtime).isArrayBuffer(runtime)) { - throw jsi::JSError(runtime, - "TypedArray buffer property must be an ArrayBuffer"); - } + bool isValidTypedArray = obj.hasProperty(runtime, "buffer") && + obj.hasProperty(runtime, "byteOffset") && + obj.hasProperty(runtime, "byteLength") && + obj.hasProperty(runtime, "length"); - jsi::ArrayBuffer arrayBuffer = - bufferValue.asObject(runtime).getArrayBuffer(runtime); - size_t byteOffset = - getValue(dataObj.getProperty(runtime, "byteOffset"), runtime); - - tensorView.dataPtr = - static_cast(arrayBuffer.data(runtime)) + byteOffset; - } - return tensorView; -} - -// C++ set from JS array. Set with heterogenerous look-up (adding std::less<> -// enables querying with std::string_view). -template <> -inline std::set> -getValue>>(const jsi::Value &val, - jsi::Runtime &runtime) { - - jsi::Array array = val.asObject(runtime).asArray(runtime); - size_t length = array.size(runtime); - std::set> result; - - for (size_t i = 0; i < length; ++i) { - jsi::Value element = array.getValueAtIndex(runtime, i); - result.insert(getValue(element, runtime)); - } - return result; -} - -// Helper function to convert typed arrays to std::span -template -inline std::span getTypedArrayAsSpan(const jsi::Value &val, - jsi::Runtime &runtime) { - jsi::Object obj = val.asObject(runtime); - - const bool isValidTypedArray = obj.hasProperty(runtime, "buffer") && - obj.hasProperty(runtime, "byteOffset") && - obj.hasProperty(runtime, "byteLength") && - obj.hasProperty(runtime, "length"); if (!isValidTypedArray) { - throw jsi::JSError(runtime, "Value must be a TypedArray"); + throw jsi::JSError(runtime, "Value must be an ArrayBuffer or TypedArray"); } - // Get the underlying ArrayBuffer jsi::Value bufferValue = obj.getProperty(runtime, "buffer"); if (!bufferValue.isObject() || !bufferValue.asObject(runtime).isArrayBuffer(runtime)) { - throw jsi::JSError(runtime, - "TypedArray buffer property must be an ArrayBuffer"); + throw jsi::JSError(runtime, "TypedArray buffer must be an ArrayBuffer"); } jsi::ArrayBuffer arrayBuffer = bufferValue.asObject(runtime).getArrayBuffer(runtime); size_t byteOffset = - getValue(obj.getProperty(runtime, "byteOffset"), runtime); - size_t length = getValue(obj.getProperty(runtime, "length"), runtime); + static_cast(obj.getProperty(runtime, "byteOffset").asNumber()); + size_t length = + static_cast(obj.getProperty(runtime, "length").asNumber()); - T *dataPtr = reinterpret_cast( - static_cast(arrayBuffer.data(runtime)) + byteOffset); - - return {dataPtr, length}; + uint8_t *rawData = arrayBuffer.data(runtime) + byteOffset; + return {reinterpret_cast(rawData), length}; } -template -inline std::vector getArrayAsVector(const jsi::Value &val, - jsi::Runtime &runtime) { - jsi::Array array = val.asObject(runtime).asArray(runtime); - const size_t length = array.size(runtime); - std::vector result; - result.reserve(length); - - for (size_t i = 0; i < length; ++i) { - const jsi::Value element = array.getValueAtIndex(runtime, i); - result.push_back(getValue(element, runtime)); - } - return result; -} - -// Template specializations for std::vector types -template <> -inline std::vector -getValue>(const jsi::Value &val, - jsi::Runtime &runtime) { - return getArrayAsVector(val, runtime); -} - -template <> -inline std::vector -getValue>(const jsi::Value &val, - jsi::Runtime &runtime) { - return getArrayAsVector(val, runtime); -} +} // namespace detail -template <> -inline std::vector -getValue>(const jsi::Value &val, jsi::Runtime &runtime) { - return getArrayAsVector(val, runtime); -} +// ================================================================================================= +// JS -> C++ (JsiGetter Struct) +// We use a struct to allow partial specialization for vectors/spans +// ================================================================================================= -template <> -inline std::vector getValue>(const jsi::Value &val, - jsi::Runtime &runtime) { - return getArrayAsVector(val, runtime); -} +// Forward Declaration +template struct JsiGetter; -template <> -inline std::vector -getValue>(const jsi::Value &val, jsi::Runtime &runtime) { - return getArrayAsVector(val, runtime); +// Public API Wrapper +template T getValue(const jsi::Value &val, jsi::Runtime &runtime) { + return JsiGetter::get(val, runtime); } -template <> -inline std::vector -getValue>(const jsi::Value &val, jsi::Runtime &runtime) { - return getArrayAsVector(val, runtime); -} - -// Template specializations for std::span types -template <> -inline std::span getValue>(const jsi::Value &val, - jsi::Runtime &runtime) { - return getTypedArrayAsSpan(val, runtime); -} - -template <> -inline std::span getValue>(const jsi::Value &val, - jsi::Runtime &runtime) { - return getTypedArrayAsSpan(val, runtime); -} +template struct JsiGetter { + static T get(const jsi::Value &val, jsi::Runtime &runtime) { + if constexpr (meta::IsNumeric) { + return static_cast(val.asNumber()); + } else { + // Fallback for unsupported types + throw jsi::JSError(runtime, "Unsupported type conversion"); + } + } +}; -template <> -inline std::span getValue>(const jsi::Value &val, - jsi::Runtime &runtime) { - return getTypedArrayAsSpan(val, runtime); -} +template <> struct JsiGetter { + static bool get(const jsi::Value &val, jsi::Runtime &runtime) { + return val.asBool(); + } +}; -template <> -inline std::span -getValue>(const jsi::Value &val, jsi::Runtime &runtime) { - return getTypedArrayAsSpan(val, runtime); -} +template <> struct JsiGetter { + static std::string get(const jsi::Value &val, jsi::Runtime &runtime) { + return val.getString(runtime).utf8(runtime); + } +}; + +template <> struct JsiGetter { + static std::u32string get(const jsi::Value &val, jsi::Runtime &runtime) { + std::string utf8 = getValue(val, runtime); +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" + std::wstring_convert, char32_t> conv; + return conv.from_bytes(utf8); +#pragma clang diagnostic pop + } +}; -template <> -inline std::span getValue>(const jsi::Value &val, - jsi::Runtime &runtime) { - return getTypedArrayAsSpan(val, runtime); -} +template <> struct JsiGetter> { + static std::shared_ptr get(const jsi::Value &val, + jsi::Runtime &runtime) { + return std::make_shared( + val.asObject(runtime).asFunction(runtime)); + } +}; -template <> -inline std::span -getValue>(const jsi::Value &val, jsi::Runtime &runtime) { - return getTypedArrayAsSpan(val, runtime); -} +template <> struct JsiGetter { + static JSTensorViewIn get(const jsi::Value &val, jsi::Runtime &runtime) { + jsi::Object obj = val.asObject(runtime); + JSTensorViewIn tensorView; -template <> -inline std::span getValue>(const jsi::Value &val, - jsi::Runtime &runtime) { - return getTypedArrayAsSpan(val, runtime); -} + tensorView.scalarType = static_cast( + static_cast(obj.getProperty(runtime, "scalarType").asNumber())); -template <> -inline std::span getValue>(const jsi::Value &val, - jsi::Runtime &runtime) { - return getTypedArrayAsSpan(val, runtime); -} + jsi::Array shapeArray = + obj.getProperty(runtime, "sizes").asObject(runtime).asArray(runtime); + size_t numDims = shapeArray.size(runtime); + tensorView.sizes.reserve(numDims); -template <> -inline std::span getValue>(const jsi::Value &val, - jsi::Runtime &runtime) { - return getTypedArrayAsSpan(val, runtime); -} + for (size_t i = 0; i < numDims; ++i) { + tensorView.sizes.push_back(static_cast( + shapeArray.getValueAtIndex(runtime, i).asNumber())); + } -template <> -inline std::span -getValue>(const jsi::Value &val, jsi::Runtime &runtime) { - return getTypedArrayAsSpan(val, runtime); -} + // On JS side, TensorPtr objects hold a 'data' property which should be + // either an ArrayBuffer or TypedArray + auto [ptr, _] = detail::getTypedArrayData( + obj.getProperty(runtime, "dataPtr"), runtime); + tensorView.dataPtr = ptr; -// Conversion from C++ types to jsi -------------------------------------------- + return tensorView; + } +}; -// Implementation functions might return any type, but in a promise we can only -// return jsi::Value or jsi::Object. For each type being returned -// we add a function here. +// C++ set from JS array. Set with heterogenerous look-up (adding std::less<> +// enables querying with std::string_view). +template <> struct JsiGetter>> { + static std::set> get(const jsi::Value &val, + jsi::Runtime &runtime) { + jsi::Array array = val.asObject(runtime).asArray(runtime); + size_t length = array.size(runtime); + std::set> result; + + for (size_t i = 0; i < length; ++i) { + // Explicitly get string to avoid ambiguity + result.insert( + getValue(array.getValueAtIndex(runtime, i), runtime)); + } + return result; + } +}; -inline jsi::Value getJsiValue(std::shared_ptr valuePtr, - jsi::Runtime &runtime) { - return std::move(*valuePtr); -} +template struct JsiGetter> { + static std::vector get(const jsi::Value &val, jsi::Runtime &runtime) { + jsi::Array array = val.asObject(runtime).asArray(runtime); + size_t length = array.size(runtime); + std::vector result; + result.reserve(length); -inline jsi::Value getJsiValue(const std::vector &vec, - jsi::Runtime &runtime) { - jsi::Array array(runtime, vec.size()); - for (size_t i = 0; i < vec.size(); i++) { - array.setValueAtIndex(runtime, i, jsi::Value(static_cast(vec[i]))); + for (size_t i = 0; i < length; ++i) { + result.push_back(getValue(array.getValueAtIndex(runtime, i), runtime)); + } + return result; } - return {runtime, array}; -} +}; -inline jsi::Value getJsiValue(const std::vector &vec, - jsi::Runtime &runtime) { - jsi::Array array(runtime, vec.size()); - for (size_t i = 0; i < vec.size(); i++) { - array.setValueAtIndex(runtime, i, jsi::Value(static_cast(vec[i]))); +template struct JsiGetter> { + static std::span get(const jsi::Value &val, jsi::Runtime &runtime) { + auto [ptr, len] = detail::getTypedArrayData(val, runtime); + return std::span{ptr, len}; } - return {runtime, array}; -} +}; -inline jsi::Value getJsiValue(const std::vector &vec, - jsi::Runtime &runtime) { - jsi::Array array(runtime, vec.size()); - for (size_t i = 0; i < vec.size(); i++) { - array.setValueAtIndex(runtime, i, jsi::Value(vec[i])); - } - return {runtime, array}; -} +// ================================================================================================= +// C++ -> JS (getJsiValue) +// ================================================================================================= -inline jsi::Value getJsiValue(const std::vector &vec, - jsi::Runtime &runtime) { - jsi::Array array(runtime, vec.size()); - for (size_t i = 0; i < vec.size(); i++) { - array.setValueAtIndex(runtime, i, - jsi::String::createFromUtf8(runtime, vec[i])); - } - return {runtime, array}; +inline jsi::Value getJsiValue(int val, jsi::Runtime & /*runtime*/) { + return {val}; } - -inline jsi::Value getJsiValue(const std::vector &vec, - jsi::Runtime &runtime) { - jsi::Array array(runtime, vec.size()); - for (size_t i = 0; i < vec.size(); i++) { - array.setValueAtIndex(runtime, i, jsi::Value(vec[i])); - } - return {runtime, array}; +inline jsi::Value getJsiValue(bool val, jsi::Runtime & /*runtime*/) { + return {val}; +} +inline jsi::Value getJsiValue(double val, jsi::Runtime & /*runtime*/) { + return {val}; +} +inline jsi::Value getJsiValue(float val, jsi::Runtime & /*runtime*/) { + return {static_cast(val)}; } -// Conditional as on android, size_t and uint64_t reduce to the same type, -// introducing ambiguity template && !std::is_same_v>> inline jsi::Value getJsiValue(T val, jsi::Runtime &runtime) { - return jsi::Value(static_cast(val)); + return {static_cast(val)}; } inline jsi::Value getJsiValue(uint64_t val, jsi::Runtime &runtime) { @@ -356,12 +223,13 @@ inline jsi::Value getJsiValue(uint64_t val, jsi::Runtime &runtime) { return {runtime, bigInt}; } -inline jsi::Value getJsiValue(int val, jsi::Runtime &runtime) { - return {runtime, val}; +inline jsi::Value getJsiValue(const std::string &str, jsi::Runtime &runtime) { + return jsi::String::createFromUtf8(runtime, str); } -inline jsi::Value getJsiValue(bool val, jsi::Runtime &runtime) { - return jsi::Value(val); +inline jsi::Value getJsiValue(std::shared_ptr valuePtr, + jsi::Runtime &runtime) { + return std::move(*valuePtr); } inline jsi::Value getJsiValue(const std::shared_ptr &buf, @@ -370,13 +238,19 @@ inline jsi::Value getJsiValue(const std::shared_ptr &buf, return {runtime, arrayBuffer}; } -inline jsi::Value -getJsiValue(const std::vector> &vec, - jsi::Runtime &runtime) { +template +inline jsi::Value getJsiValue(const std::vector &vec, + jsi::Runtime &runtime) { jsi::Array array(runtime, vec.size()); - for (size_t i = 0; i < vec.size(); i++) { - jsi::ArrayBuffer arrayBuffer(runtime, vec[i]); - array.setValueAtIndex(runtime, i, jsi::Value(runtime, arrayBuffer)); + for (size_t i = 0; i < vec.size(); ++i) { + if constexpr (std::is_same_v && + !std::is_same_v) { + // Conditional as on android, size_t and uint64_t reduce to the same type, + // introducing ambiguity + array.setValueAtIndex(runtime, i, static_cast(vec[i])); + } else { + array.setValueAtIndex(runtime, i, getJsiValue(vec[i], runtime)); + } } return {runtime, array}; } @@ -386,30 +260,21 @@ inline jsi::Value getJsiValue(const std::vector &vec, jsi::Array array(runtime, vec.size()); for (size_t i = 0; i < vec.size(); i++) { jsi::Object tensorObj(runtime); - tensorObj.setProperty(runtime, "sizes", getJsiValue(vec[i].sizes, runtime)); - tensorObj.setProperty(runtime, "scalarType", - jsi::Value(static_cast(vec[i].scalarType))); - - jsi::ArrayBuffer arrayBuffer(runtime, vec[i].dataPtr); - tensorObj.setProperty(runtime, "dataPtr", arrayBuffer); - + static_cast(vec[i].scalarType)); + tensorObj.setProperty(runtime, "dataPtr", + jsi::ArrayBuffer(runtime, vec[i].dataPtr)); array.setValueAtIndex(runtime, i, tensorObj); } return {runtime, array}; } -inline jsi::Value getJsiValue(const std::string &str, jsi::Runtime &runtime) { - return jsi::String::createFromUtf8(runtime, str); -} - inline jsi::Value getJsiValue(const std::unordered_map &map, jsi::Runtime &runtime) { jsi::Object mapObj{runtime}; for (auto &[k, v] : map) { - // The string_view keys must be null-terminated! mapObj.setProperty(runtime, k.data(), v); } return mapObj; @@ -419,21 +284,23 @@ inline jsi::Value getJsiValue( const std::vector &detections, jsi::Runtime &runtime) { jsi::Array array(runtime, detections.size()); - for (std::size_t i = 0; i < detections.size(); ++i) { + for (size_t i = 0; i < detections.size(); ++i) { + const auto &d = detections[i]; jsi::Object detection(runtime); jsi::Object bbox(runtime); - bbox.setProperty(runtime, "x1", detections[i].x1); - bbox.setProperty(runtime, "y1", detections[i].y1); - bbox.setProperty(runtime, "x2", detections[i].x2); - bbox.setProperty(runtime, "y2", detections[i].y2); + + bbox.setProperty(runtime, "x1", d.x1); + bbox.setProperty(runtime, "y1", d.y1); + bbox.setProperty(runtime, "x2", d.x2); + bbox.setProperty(runtime, "y2", d.y2); detection.setProperty(runtime, "bbox", bbox); detection.setProperty( runtime, "label", jsi::String::createFromAscii( - runtime, models::object_detection::constants::kCocoLablesMap.at( - detections[i].label))); - detection.setProperty(runtime, "score", detections[i].score); + runtime, + models::object_detection::constants::kCocoLablesMap.at(d.label))); + detection.setProperty(runtime, "score", d.score); array.setValueAtIndex(runtime, i, detection); } return array; @@ -444,43 +311,39 @@ getJsiValue(const std::vector &detections, jsi::Runtime &runtime) { auto jsiDetections = jsi::Array(runtime, detections.size()); for (size_t i = 0; i < detections.size(); ++i) { - const auto &detection = detections[i]; + const auto &d = detections[i]; + auto jsiDetection = jsi::Object(runtime); + auto jsiBbox = jsi::Array(runtime, 4); - auto jsiDetectionObject = jsi::Object(runtime); - - auto jsiBboxArray = jsi::Array(runtime, 4); -#pragma unroll for (size_t j = 0; j < 4u; ++j) { - auto jsiPointObject = jsi::Object(runtime); - jsiPointObject.setProperty(runtime, "x", detection.bbox[j].x); - jsiPointObject.setProperty(runtime, "y", detection.bbox[j].y); - jsiBboxArray.setValueAtIndex(runtime, j, jsiPointObject); + auto point = jsi::Object(runtime); + point.setProperty(runtime, "x", d.bbox[j].x); + point.setProperty(runtime, "y", d.bbox[j].y); + jsiBbox.setValueAtIndex(runtime, j, point); } - jsiDetectionObject.setProperty(runtime, "bbox", jsiBboxArray); - jsiDetectionObject.setProperty( - runtime, "text", jsi::String::createFromUtf8(runtime, detection.text)); - jsiDetectionObject.setProperty(runtime, "score", detection.score); - - jsiDetections.setValueAtIndex(runtime, i, jsiDetectionObject); + jsiDetection.setProperty(runtime, "bbox", jsiBbox); + jsiDetection.setProperty(runtime, "text", + jsi::String::createFromUtf8(runtime, d.text)); + jsiDetection.setProperty(runtime, "score", d.score); + jsiDetections.setValueAtIndex(runtime, i, jsiDetection); } - return jsiDetections; } inline jsi::Value getJsiValue(const std::vector - &speechSegments, + &segments, jsi::Runtime &runtime) { - auto jsiSegments = jsi::Array(runtime, speechSegments.size()); - for (size_t i = 0; i < speechSegments.size(); i++) { - const auto &[start, end] = speechSegments[i]; - auto jsiSegmentObject = jsi::Object(runtime); - jsiSegmentObject.setProperty(runtime, "start", static_cast(start)); - jsiSegmentObject.setProperty(runtime, "end", static_cast(end)); - jsiSegments.setValueAtIndex(runtime, i, jsiSegmentObject); + auto jsiSegments = jsi::Array(runtime, segments.size()); + for (size_t i = 0; i < segments.size(); i++) { + auto segObj = jsi::Object(runtime); + segObj.setProperty(runtime, "start", + static_cast(segments[i].start)); + segObj.setProperty(runtime, "end", static_cast(segments[i].end)); + jsiSegments.setValueAtIndex(runtime, i, segObj); } return jsiSegments; } -} // namespace rnexecutorch::jsi_conversion +} // namespace rnexecutorch::jsi_conversion \ No newline at end of file