diff --git a/native/src/jni/hook_bridge.cpp b/native/src/jni/hook_bridge.cpp index 38cc45faf..b4d841cbd 100644 --- a/native/src/jni/hook_bridge.cpp +++ b/native/src/jni/hook_bridge.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -273,171 +274,251 @@ VECTOR_DEF_NATIVE_METHOD(jobject, HookBridge, allocateObject, jclass cls) { } /** - * @brief A high-performance, low-level implementation of Method.invoke for super.method() calls. + * Core JNI backend for non-virtual method invocation and special object initialization. * - * This function manually unboxes arguments from a jobject array into a jvalue C-style array, - * calls the appropriate JNI `CallNonvirtual...MethodA` function, - * and then boxes the return value back into a jobject. - * This avoids the overhead of Java reflection. - * - * @warning This is a very sensitive function. - * The `shorty` descriptor must perfectly match the method's actual signature. + * Implementation details: + * 1. Dispatches using JNI CallNonvirtualMethodA. + * 2. Employs stack allocation (alloca) for JNI argument mapping. + * 3. Safely mirrors standard Java reflection (NPEs on null primitives/receivers). + * 4. Prevents JNI Type Confusion and memory leaks by caching primitive wrappers globally, + * while leveraging java.lang.Number for fast implicit widening/narrowing. + * 5. Accurately catches and wraps target method exceptions into InvocationTargetException. */ VECTOR_DEF_NATIVE_METHOD(jobject, HookBridge, invokeSpecialMethod, jobject method, jcharArray shorty, jclass cls, jobject thiz, jobjectArray args) { - // --- Cache all necessary MethodIDs for boxing/unboxing primitive wrappers - // --- This is a major performance optimization, done only once. - static auto *const get_int = - env->GetMethodID(env->FindClass("java/lang/Integer"), "intValue", "()I"); - static auto *const get_double = - env->GetMethodID(env->FindClass("java/lang/Double"), "doubleValue", "()D"); - static auto *const get_long = - env->GetMethodID(env->FindClass("java/lang/Long"), "longValue", "()J"); - static auto *const get_float = - env->GetMethodID(env->FindClass("java/lang/Float"), "floatValue", "()F"); - static auto *const get_short = - env->GetMethodID(env->FindClass("java/lang/Short"), "shortValue", "()S"); - static auto *const get_byte = - env->GetMethodID(env->FindClass("java/lang/Byte"), "byteValue", "()B"); - static auto *const get_char = - env->GetMethodID(env->FindClass("java/lang/Character"), "charValue", "()C"); - static auto *const get_boolean = - env->GetMethodID(env->FindClass("java/lang/Boolean"), "booleanValue", "()Z"); - static auto *const set_int = env->GetStaticMethodID(env->FindClass("java/lang/Integer"), - "valueOf", "(I)Ljava/lang/Integer;"); - static auto *const set_double = env->GetStaticMethodID(env->FindClass("java/lang/Double"), - "valueOf", "(D)Ljava/lang/Double;"); + // --- JNI Global Reference Caching --- + // Cached once per process lifecycle to maintain extreme performance and prevent JNI aborts. + static jclass cls_Number = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Number")); + static jclass cls_Boolean = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Boolean")); + static jclass cls_Character = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Character")); + + // Globally cache primitive wrapper classes for safe return value boxing + static jclass cls_Integer = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Integer")); + static jclass cls_Double = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Double")); + static jclass cls_Long = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Long")); + static jclass cls_Float = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Float")); + static jclass cls_Short = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Short")); + static jclass cls_Byte = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Byte")); + + static jclass cls_ITE = + (jclass)env->NewGlobalRef(env->FindClass("java/lang/reflect/InvocationTargetException")); + + static auto *const ctor_ite = env->GetMethodID(cls_ITE, "", "(Ljava/lang/Throwable;)V"); + + static auto *const get_int = env->GetMethodID(cls_Number, "intValue", "()I"); + static auto *const get_double = env->GetMethodID(cls_Number, "doubleValue", "()D"); + static auto *const get_long = env->GetMethodID(cls_Number, "longValue", "()J"); + static auto *const get_float = env->GetMethodID(cls_Number, "floatValue", "()F"); + static auto *const get_short = env->GetMethodID(cls_Number, "shortValue", "()S"); + static auto *const get_byte = env->GetMethodID(cls_Number, "byteValue", "()B"); + + static auto *const get_char = env->GetMethodID(cls_Character, "charValue", "()C"); + static auto *const get_boolean = env->GetMethodID(cls_Boolean, "booleanValue", "()Z"); + + static auto *const set_int = + env->GetStaticMethodID(cls_Integer, "valueOf", "(I)Ljava/lang/Integer;"); + static auto *const set_double = + env->GetStaticMethodID(cls_Double, "valueOf", "(D)Ljava/lang/Double;"); static auto *const set_long = - env->GetStaticMethodID(env->FindClass("java/lang/Long"), "valueOf", "(J)Ljava/lang/Long;"); - static auto *const set_float = env->GetStaticMethodID(env->FindClass("java/lang/Float"), - "valueOf", "(F)Ljava/lang/Float;"); - static auto *const set_short = env->GetStaticMethodID(env->FindClass("java/lang/Short"), - "valueOf", "(S)Ljava/lang/Short;"); + env->GetStaticMethodID(cls_Long, "valueOf", "(J)Ljava/lang/Long;"); + static auto *const set_float = + env->GetStaticMethodID(cls_Float, "valueOf", "(F)Ljava/lang/Float;"); + static auto *const set_short = + env->GetStaticMethodID(cls_Short, "valueOf", "(S)Ljava/lang/Short;"); static auto *const set_byte = - env->GetStaticMethodID(env->FindClass("java/lang/Byte"), "valueOf", "(B)Ljava/lang/Byte;"); - static auto *const set_char = env->GetStaticMethodID(env->FindClass("java/lang/Character"), - "valueOf", "(C)Ljava/lang/Character;"); - static auto *const set_boolean = env->GetStaticMethodID(env->FindClass("java/lang/Boolean"), - "valueOf", "(Z)Ljava/lang/Boolean;"); + env->GetStaticMethodID(cls_Byte, "valueOf", "(B)Ljava/lang/Byte;"); + static auto *const set_char = + env->GetStaticMethodID(cls_Character, "valueOf", "(C)Ljava/lang/Character;"); + static auto *const set_boolean = + env->GetStaticMethodID(cls_Boolean, "valueOf", "(Z)Ljava/lang/Boolean;"); auto target = env->FromReflectedMethod(method); - auto param_len = env->GetArrayLength(shorty) - 1; // First char is return type. + auto param_len = env->GetArrayLength(shorty) - 1; - // --- Argument Validation --- - if (env->GetArrayLength(args) != param_len) { + // --- Argument & Receiver Validation --- + auto args_len = args != nullptr ? env->GetArrayLength(args) : 0; + if (args_len != param_len) { env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "args.length does not match parameter count"); return nullptr; } + if (thiz == nullptr) { - env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), - "`this` cannot be null for a non-virtual call"); + env->ThrowNew(env->FindClass("java/lang/NullPointerException"), "null receiver"); return nullptr; } - // --- Unbox Arguments --- - std::vector a(param_len); + // Allocate jvalue array on the stack + jvalue *a = param_len > 0 ? static_cast(alloca(param_len * sizeof(jvalue))) : nullptr; + auto *const shorty_char = env->GetCharArrayElements(shorty, nullptr); + if (shorty_char == nullptr) { + return nullptr; // JVM already threw OutOfMemoryError + } + + // RAII/Helper for clean JNI array exits + auto abort_and_return = [&]() { + env->ReleaseCharArrayElements(shorty, shorty_char, JNI_ABORT); + return nullptr; + }; + + // --- Safe Unboxing --- for (jint i = 0; i != param_len; ++i) { jobject element = env->GetObjectArrayElement(args, i); - if (env->ExceptionCheck()) { - env->ReleaseCharArrayElements(shorty, shorty_char, JNI_ABORT); - return nullptr; - } + if (env->ExceptionCheck()) return abort_and_return(); - // The shorty string at index i+1 describes the type of the i-th parameter. - switch (shorty_char[i + 1]) { - case 'I': - a[i].i = env->CallIntMethod(element, get_int); - break; - case 'D': - a[i].d = env->CallDoubleMethod(element, get_double); - break; - case 'J': - a[i].j = env->CallLongMethod(element, get_long); - break; - case 'F': - a[i].f = env->CallFloatMethod(element, get_float); - break; - case 'S': - a[i].s = env->CallShortMethod(element, get_short); - break; - case 'B': - a[i].b = env->CallByteMethod(element, get_byte); - break; - case 'C': - a[i].c = env->CallCharMethod(element, get_char); - break; - case 'Z': - a[i].z = env->CallBooleanMethod(element, get_boolean); - break; - default: // Assumes 'L' or '[' for object types - a[i].l = element; - // Set element to null so we don't delete the local ref twice. - // The reference is stored in the jvalue array and is still valid. - element = nullptr; - break; + char type = shorty_char[i + 1]; + + if (element == nullptr) { + if (type != 'L' && type != '[') { + env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), + "null primitive argument"); + return abort_and_return(); + } + a[i].l = nullptr; + } else { + if (type == 'Z') { + if (!env->IsInstanceOf(element, cls_Boolean)) { + env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), + "Expected Boolean"); + return abort_and_return(); + } + a[i].z = env->CallBooleanMethod(element, get_boolean); + } else if (type == 'C') { + if (!env->IsInstanceOf(element, cls_Character)) { + env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), + "Expected Character"); + return abort_and_return(); + } + a[i].c = env->CallCharMethod(element, get_char); + } else if (type != 'L' && type != '[') { + bool is_number = env->IsInstanceOf(element, cls_Number) == JNI_TRUE; + bool is_character = + !is_number && (env->IsInstanceOf(element, cls_Character) == JNI_TRUE); + + if (!is_number && !is_character) { + env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), + "Expected Number or Character"); + return abort_and_return(); + } + + // If a Character is passed to a numeric parameter, extract its value for widening + jchar c_val = 0; + if (is_character) { + c_val = env->CallCharMethod(element, get_char); + if (env->ExceptionCheck()) return abort_and_return(); + } + + switch (type) { + case 'I': + a[i].i = env->CallIntMethod(element, get_int); + break; + case 'D': + a[i].d = env->CallDoubleMethod(element, get_double); + break; + case 'J': + a[i].j = env->CallLongMethod(element, get_long); + break; + case 'F': + a[i].f = env->CallFloatMethod(element, get_float); + break; + case 'S': + a[i].s = env->CallShortMethod(element, get_short); + break; + case 'B': + a[i].b = env->CallByteMethod(element, get_byte); + break; + } + } else { + a[i].l = element; + element = + nullptr; // Transferred ownership to jvalue array; will be freed on return + } } - // Clean up the local reference for the wrapper object if it was created. if (element) env->DeleteLocalRef(element); + if (env->ExceptionCheck()) return abort_and_return(); + } - // Check for exceptions during the unboxing call (e.g., - // NullPointerException). - if (env->ExceptionCheck()) { - env->ReleaseCharArrayElements(shorty, shorty_char, JNI_ABORT); - return nullptr; + // --- Non-virtual Invocation --- + jvalue ret_val; + switch (shorty_char[0]) { + case 'I': + ret_val.i = env->CallNonvirtualIntMethodA(thiz, cls, target, a); + break; + case 'D': + ret_val.d = env->CallNonvirtualDoubleMethodA(thiz, cls, target, a); + break; + case 'J': + ret_val.j = env->CallNonvirtualLongMethodA(thiz, cls, target, a); + break; + case 'F': + ret_val.f = env->CallNonvirtualFloatMethodA(thiz, cls, target, a); + break; + case 'S': + ret_val.s = env->CallNonvirtualShortMethodA(thiz, cls, target, a); + break; + case 'B': + ret_val.b = env->CallNonvirtualByteMethodA(thiz, cls, target, a); + break; + case 'C': + ret_val.c = env->CallNonvirtualCharMethodA(thiz, cls, target, a); + break; + case 'Z': + ret_val.z = env->CallNonvirtualBooleanMethodA(thiz, cls, target, a); + break; + case 'L': + ret_val.l = env->CallNonvirtualObjectMethodA(thiz, cls, target, a); + break; + default: + env->CallNonvirtualVoidMethodA(thiz, cls, target, a); + break; + } + + // --- Exception Wrapping --- + jthrowable target_exception = env->ExceptionOccurred(); + if (target_exception) { + env->ExceptionClear(); + jobject ite = env->NewObject(cls_ITE, ctor_ite, target_exception); + // Ensure NewObject didn't fail due to OOM before throwing + if (ite) { + env->Throw(static_cast(ite)); } + return abort_and_return(); } - // --- Call Non-virtual Method and Box Return Value --- + // --- Box Return Value --- jobject value = nullptr; - // The shorty string at index 0 describes the return type. switch (shorty_char[0]) { case 'I': - value = - env->CallStaticObjectMethod(jclass{nullptr}, - set_int, // Use Integer.valueOf() to box - env->CallNonvirtualIntMethodA(thiz, cls, target, a.data())); + value = env->CallStaticObjectMethod(cls_Integer, set_int, ret_val.i); break; case 'D': - value = env->CallStaticObjectMethod( - jclass{nullptr}, set_double, - env->CallNonvirtualDoubleMethodA(thiz, cls, target, a.data())); + value = env->CallStaticObjectMethod(cls_Double, set_double, ret_val.d); break; case 'J': - value = env->CallStaticObjectMethod( - jclass{nullptr}, set_long, env->CallNonvirtualLongMethodA(thiz, cls, target, a.data())); + value = env->CallStaticObjectMethod(cls_Long, set_long, ret_val.j); break; case 'F': - value = env->CallStaticObjectMethod( - jclass{nullptr}, set_float, - env->CallNonvirtualFloatMethodA(thiz, cls, target, a.data())); + value = env->CallStaticObjectMethod(cls_Float, set_float, ret_val.f); break; case 'S': - value = env->CallStaticObjectMethod( - jclass{nullptr}, set_short, - env->CallNonvirtualShortMethodA(thiz, cls, target, a.data())); + value = env->CallStaticObjectMethod(cls_Short, set_short, ret_val.s); break; case 'B': - value = env->CallStaticObjectMethod( - jclass{nullptr}, set_byte, env->CallNonvirtualByteMethodA(thiz, cls, target, a.data())); + value = env->CallStaticObjectMethod(cls_Byte, set_byte, ret_val.b); break; case 'C': - value = env->CallStaticObjectMethod( - jclass{nullptr}, set_char, env->CallNonvirtualCharMethodA(thiz, cls, target, a.data())); + value = env->CallStaticObjectMethod(cls_Character, set_char, ret_val.c); break; case 'Z': - value = env->CallStaticObjectMethod( - jclass{nullptr}, set_boolean, - env->CallNonvirtualBooleanMethodA(thiz, cls, target, a.data())); + value = env->CallStaticObjectMethod(cls_Boolean, set_boolean, ret_val.z); break; - case 'L': // Return type is an object, no boxing needed. - value = env->CallNonvirtualObjectMethodA(thiz, cls, target, a.data()); + case 'L': + value = ret_val.l; break; - default: // Assumes 'V' for void return type. case 'V': - env->CallNonvirtualVoidMethodA(thiz, cls, target, a.data()); + value = nullptr; break; }