diff --git a/src/Jitex/Framework/NETFramework.cs b/src/Jitex/Framework/NETFramework.cs deleted file mode 100644 index 4a4887d..0000000 --- a/src/Jitex/Framework/NETFramework.cs +++ /dev/null @@ -1,20 +0,0 @@ -using System; -using System.Runtime.InteropServices; - -namespace Jitex.Framework -{ - internal sealed class NETFramework : RuntimeFramework - { - [DllImport("clrjit.dll", CallingConvention = CallingConvention.StdCall, SetLastError = true, EntryPoint = "getJit", BestFitMapping = true)] - private static extern IntPtr GetJit(); - - public NETFramework() : base(false) - { - } - - protected override IntPtr GetJitAddress() - { - return GetJit(); - } - } -} diff --git a/src/Jitex/Framework/Offsets/CEEInfoOffset.cs b/src/Jitex/Framework/Offsets/CEEInfoOffset.cs index 386e3c0..260b4f5 100644 --- a/src/Jitex/Framework/Offsets/CEEInfoOffset.cs +++ b/src/Jitex/Framework/Offsets/CEEInfoOffset.cs @@ -5,8 +5,8 @@ namespace Jitex.Framework.Offsets internal static class CEEInfoOffset { public static int ResolveToken { get; private set; } - public static int ConstructStringLiteral { get; private set; } + public static int GetEHInfo { get; private set; } static CEEInfoOffset() { @@ -18,10 +18,13 @@ private static void ReadOffset(bool isCore, Version version) { if (isCore && version >= new Version(8, 0, 0)) { + GetEHInfo = 0xB; ResolveToken = 0x1C; ConstructStringLiteral = 0x91; - }else if (isCore && version >= new Version(7, 0, 0)) + } + else if (isCore && version >= new Version(7, 0, 0)) { + GetEHInfo = 0xB; ResolveToken = 0x1D; ConstructStringLiteral = 0x95; } diff --git a/src/Jitex/Framework/RuntimeFramework.cs b/src/Jitex/Framework/RuntimeFramework.cs index c3d0c8c..edf8fe4 100644 --- a/src/Jitex/Framework/RuntimeFramework.cs +++ b/src/Jitex/Framework/RuntimeFramework.cs @@ -2,6 +2,7 @@ using System.Diagnostics; using System.Linq; using System.Reflection; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using Jitex.JIT.CorInfo; @@ -70,7 +71,7 @@ protected RuntimeFramework(bool isCore) IsCore = isCore; Jit = GetJitAddress(); ICorJitCompileVTable = Marshal.ReadIntPtr(Jit); - IntPtr compileMethodPtr = Marshal.ReadIntPtr(ICorJitCompileVTable); + var compileMethodPtr = Marshal.ReadIntPtr(ICorJitCompileVTable); CompileMethod = Marshal.GetDelegateForFunctionPointer(compileMethodPtr); IdentifyFrameworkVersion(); } @@ -84,11 +85,9 @@ private static RuntimeFramework GetFramework() if (_framework != null) return _framework; - string frameworkRunning = RuntimeInformation.FrameworkDescription; + var frameworkRunning = RuntimeInformation.FrameworkDescription; - if (frameworkRunning.StartsWith(".NET Framework")) - _framework = new NETFramework(); - else if (frameworkRunning.StartsWith(".NET")) + if (frameworkRunning.StartsWith(".NET")) _framework = new NETCore(); else throw new NotSupportedException($"Framework {frameworkRunning} is not supported!"); @@ -113,21 +112,21 @@ public void ReadICorJitInfoVTable(IntPtr iCorJitInfo) private void IdentifyFrameworkVersion() { - Assembly assembly = typeof(System.Runtime.GCSettings).GetTypeInfo().Assembly; + var assembly = typeof(System.Runtime.GCSettings).GetTypeInfo().Assembly; string[] assemblyPath = assembly.CodeBase.Split(new[] { '/', '\\' }, StringSplitOptions.RemoveEmptyEntries); - string frameworkName = IsCore ? "Microsoft.NETCore.App" : "Framework64"; + var frameworkName = IsCore ? "Microsoft.NETCore.App" : "Framework64"; - int frameworkIndex = Array.IndexOf(assemblyPath, frameworkName); + var frameworkIndex = Array.IndexOf(assemblyPath, frameworkName); if (frameworkIndex > 0 && frameworkIndex < assemblyPath.Length - 2) { - string version = assemblyPath[frameworkIndex + 1]; + var version = assemblyPath[frameworkIndex + 1]; if (!IsCore) version = version[1..]; - int[] versionsNumbers = version.Split('.').Select(int.Parse).ToArray(); + var versionsNumbers = version.Split('.').Select(int.Parse).ToArray(); FrameworkVersion = new Version(versionsNumbers[0], versionsNumbers[1], versionsNumbers[2]); } else if (AppContext.TargetFrameworkName.StartsWith(".NETCoreApp")) diff --git a/src/Jitex/Hook/HookManager.cs b/src/Jitex/Hook/HookManager.cs deleted file mode 100644 index 5be8e2e..0000000 --- a/src/Jitex/Hook/HookManager.cs +++ /dev/null @@ -1,57 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Runtime.InteropServices; -using Jitex.Utils; -using Mono.Unix.Native; - -namespace Jitex.Hook -{ - internal sealed class HookManager - { - private readonly IList _hooks = new List(); - - /// - /// Inject a delegate in VTable. - /// - /// Pointer to address method. - /// Delegate to be inject. - public void InjectHook(IntPtr pointerAddress, Delegate delToInject) - { - IntPtr originalAddress = Marshal.ReadIntPtr(pointerAddress); - IntPtr hookAddress = Marshal.GetFunctionPointerForDelegate(delToInject); - VTableHook hook = new VTableHook(delToInject, originalAddress, pointerAddress); - WritePointer(pointerAddress, hookAddress); - _hooks.Add(hook); - } - - /// - /// Remove hook from VTable. - /// - /// Delegate to remove. - /// - public bool RemoveHook(Delegate del) - { - VTableHook? hookFound = _hooks.FirstOrDefault(h => h.Delegate.Method.Equals(del.Method)); - - return hookFound != null && RemoveHook(hookFound); - } - - private bool RemoveHook(VTableHook hook) - { - WritePointer(hook.Address, hook.OriginalAddress); - _hooks.Remove(hook); - return true; - } - - /// - /// Write pointer on address. - /// - /// Address to write pointer. - /// Pointer to write. - private static void WritePointer(IntPtr address, IntPtr pointer) - { - MemoryHelper.UnprotectWrite(address, pointer); - } - } -} \ No newline at end of file diff --git a/src/Jitex/Hook/VTableHook.cs b/src/Jitex/Hook/VTableHook.cs deleted file mode 100644 index 55bfbb7..0000000 --- a/src/Jitex/Hook/VTableHook.cs +++ /dev/null @@ -1,26 +0,0 @@ -using System; - -namespace Jitex.Hook -{ - internal sealed class VTableHook - { - public Delegate Delegate { get; } - - /// - /// Original address - /// - public IntPtr OriginalAddress { get; } - - /// - /// New address. - /// - public IntPtr Address { get; } - - public VTableHook(Delegate @delegate, IntPtr originalAddress, IntPtr address) - { - Delegate = @delegate; - OriginalAddress = originalAddress; - Address = address; - } - } -} \ No newline at end of file diff --git a/src/Jitex/Intercept/CallContext.cs b/src/Jitex/Intercept/CallContext.cs index a00d9b5..21509e9 100644 --- a/src/Jitex/Intercept/CallContext.cs +++ b/src/Jitex/Intercept/CallContext.cs @@ -45,6 +45,8 @@ public class CallContext /// public int ParametersCount => _parameters.Length; + public Exception? Exception { get; set; } + /// /// If context is waiting for end of call /// @@ -52,7 +54,7 @@ public class CallContext /// Is used to hold original call after call ContinueAsync. /// internal bool IsWaitingForEnd { get; private set; } - + /// /// Create a new context from call (Should not be called directly). /// It's for 32 Bits. @@ -365,6 +367,17 @@ public void SetReturnValue(T value, bool validateType = true) ProceedCall = false; } + public void SetException(Exception exception) + { + + } + + public void ThrowExceptionIfNecessary() + { + if(Exception != null) + throw Exception; + } + internal void ContinueWithCode() { if (_autoResetEvent == null) diff --git a/src/Jitex/Intercept/InterceptManager.cs b/src/Jitex/Intercept/InterceptManager.cs index e904df9..118cb94 100644 --- a/src/Jitex/Intercept/InterceptManager.cs +++ b/src/Jitex/Intercept/InterceptManager.cs @@ -1,12 +1,7 @@ using System; -using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; -using System.Reflection; using System.Threading.Tasks; -using Jitex.Exceptions; -using Jitex.JIT.Context; -using Jitex.Utils.Comparer; namespace Jitex.Intercept { diff --git a/src/Jitex/Intercept/InterceptorBuilder.cs b/src/Jitex/Intercept/InterceptorBuilder.cs index 07fcbc3..ceeb4db 100644 --- a/src/Jitex/Intercept/InterceptorBuilder.cs +++ b/src/Jitex/Intercept/InterceptorBuilder.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using Jitex.PE; using System.Reflection; @@ -52,6 +53,12 @@ internal class InterceptorBuilder : IDisposable private static readonly MethodInfo GetTypeFromHandle = GetTypeFromHandle = typeof(Type).GetMethod(nameof(Type.GetTypeFromHandle))!; + private static readonly MethodInfo SetException = typeof(CallContext) + .GetMethod(nameof(CallContext.SetException), BindingFlags.Public | BindingFlags.Instance)!; + + private static readonly MethodInfo GetThrowExceptionIfNecessary = typeof(CallContext) + .GetMethod(nameof(CallContext.ThrowExceptionIfNecessary), BindingFlags.Public | BindingFlags.Instance)!; + private readonly MethodBase _method; private readonly MethodBody _body; @@ -106,13 +113,11 @@ public MethodBody InjectInterceptor(bool reuseReferences) _image = _imageReader.LoadImage(reuseReferences); } - IList localVariables = _body.LocalVariables; + var localVariables = _body.LocalVariables; - localVariables.Add(new LocalVariableInfo(typeof(CallContext))); - byte callContextVariableIndex = (byte)(localVariables.Count - 1); + var callContextVariableIndex = CreateVariable(localVariables); + var callManagerVariableIndex = CreateVariable(localVariables); - localVariables.Add(new LocalVariableInfo(typeof(CallManager))); - byte callManagerVariableIndex = (byte)(localVariables.Count - 1); int callContextCtorMetadataToken; if (OSHelper.IsX86) @@ -171,17 +176,22 @@ public MethodBody InjectInterceptor(bool reuseReferences) instructions.Add(Ldloc_S, callContextVariableIndex); instructions.Add(Callvirt, getProceedCallMetadataToken); - Instruction gotoInstruction = instructions.Add(Brfalse, 0); //if(context.ProceedCall) + //if(context.ProceedCall) + Instruction gotoReleaseTask = instructions.Add(Brfalse, 0); + var rr = _body.ReadIL(); instructions.AddRange(_body.ReadIL()); instructions.RemoveLast(); //Remove Ret instruction. if (returnType != typeof(void)) instructions.Add(Stloc_S, returnVariableIndex); - Instruction endpointGoto = instructions.Add(Ldloc_S, callManagerVariableIndex); + // WriteExceptionHandler(localVariables, instructions, callContextVariableIndex); + + //callManager.ReleaseTask(); + var endpointGoto = instructions.Add(Ldloc_S, callManagerVariableIndex); instructions.Add(Callvirt, releaseTaskMetadataToken); - gotoInstruction.Value = (endpointGoto.Offset - gotoInstruction.Offset - gotoInstruction.Size); + gotoReleaseTask.Value = (endpointGoto.Offset - gotoReleaseTask.Offset - gotoReleaseTask.Size); WriteGetReturnValue(instructions, callContextVariableIndex, callManagerVariableIndex, returnType); @@ -194,9 +204,54 @@ public MethodBody InjectInterceptor(bool reuseReferences) LocalVariables = localVariables }; + var ll = body.ReadIL(); return body; } + /// + /// Write exception handler on body. + /// + /// + /// Inject the follow code: + /// ---- + /// try { + /// #code... + /// } catch (Exception ex) { + /// context.SetException(ex); + /// } + /// + /// context.ThrowExceptionIfNecessary(); + /// ---- + /// + /// + /// + /// + private void WriteExceptionHandler(IList localVariables, Instructions instructions, + byte callContextVariableIndex) + { + //try{ + // + //} ... + var exceptionVariableIndex = CreateVariable(localVariables); + var leaveTryInstruction = instructions.Add(Leave, 0); + instructions.Add(Stloc_S, exceptionVariableIndex); + + //catch (Exception ex){ + // context.SetException(ex); + //} + instructions.Add(Ldloc_S, callContextVariableIndex); + instructions.Add(Ldloc_S, exceptionVariableIndex); + instructions.Add(Callvirt, SetException); + var leaveCatchInstruction = instructions.Add(Leave, 0); + + //context.ThrowExceptionIfNecessary(); + var endpointLeave = instructions.Add(Ldloc_S, callContextVariableIndex); + instructions.Add(Callvirt, GetThrowExceptionIfNecessary); + + leaveTryInstruction.Value = (endpointLeave.Offset - leaveTryInstruction.Offset - leaveTryInstruction.Size); + leaveCatchInstruction.Value = (endpointLeave.Offset - leaveCatchInstruction.Offset - leaveCatchInstruction.Size); + } + /// /// Write instructions to load parameters from method. /// @@ -275,8 +330,7 @@ private byte WriteInstanceParameter(Instructions instructions) _image!.AddOrGetMemberRef(PointerBox, out int pointerBoxMetadataToken); - variables.Add(new LocalVariableInfo(returnType)); - byte returnVariableIndex = (byte)(variables.Count - 1); + var returnVariableIndex = CreateVariable(variables, returnType); instructions.Add(Ldloca_S, returnVariableIndex); @@ -394,6 +448,14 @@ private void WriteTypesOnArray(Instructions instructions, IReadOnlyCollection(IList variables) => CreateVariable(variables, typeof(T)); + + private byte CreateVariable(IList variables, Type type) + { + variables.Add(new LocalVariableInfo(type)); + return (byte)(variables.Count - 1); + } + private void ValidateImageLoaded() { if (_image == null) diff --git a/src/Jitex/Internal/InternalModule.cs b/src/Jitex/Internal/InternalModule.cs index b239cb1..4e591af 100644 --- a/src/Jitex/Internal/InternalModule.cs +++ b/src/Jitex/Internal/InternalModule.cs @@ -2,7 +2,8 @@ using System.Collections.Concurrent; using System.Reflection; using System.Reflection.Emit; -using Jitex.JIT.Context; +using Jitex.JIT.Hooks.CompileMethod; +using Jitex.JIT.Hooks.Token; using Jitex.Utils; using MethodInfo = System.Reflection.MethodInfo; diff --git a/src/Jitex/JIT/Context/ContextBase.cs b/src/Jitex/JIT/Context/ContextBase.cs deleted file mode 100644 index 51f7455..0000000 --- a/src/Jitex/JIT/Context/ContextBase.cs +++ /dev/null @@ -1,46 +0,0 @@ -using System.Collections.Generic; -using System.Diagnostics; -using System.Linq; -using System.Reflection; - -namespace Jitex.JIT.Context -{ - public abstract class ContextBase - { - private MethodBase? _source; - - /// - /// If context has source method from call. - /// - public bool HasSource { get; private set; } - - /// - /// Method source from call - /// - public MethodBase? Source - { - get - { - if (!HasSource) - { - StackTrace trace = new(1, false); - - IEnumerable methods = trace.GetFrames() - .Select(frame => frame.GetMethod()) - .Where(method => method.DeclaringType.Assembly != typeof(ManagedJit).Assembly); - - _source = methods.FirstOrDefault(); - HasSource = _source != null; - } - - return _source; - } - } - - protected ContextBase(MethodBase? source, bool hasSource) - { - _source = source; - HasSource = hasSource; - } - } -} \ No newline at end of file diff --git a/src/Jitex/JIT/CorInfo/CEEInfo.cs b/src/Jitex/JIT/CorInfo/CEEInfo.cs index 62c266a..a0b3735 100644 --- a/src/Jitex/JIT/CorInfo/CEEInfo.cs +++ b/src/Jitex/JIT/CorInfo/CEEInfo.cs @@ -14,9 +14,10 @@ internal static class CEEInfo private static readonly ConstructStringLiteralDelegate _constructStringLiteral; private static readonly ResolveTokenDelegate _resolveToken; + private static readonly GetEhInfoDelegate _getEHInfo; public static IntPtr ResolveTokenIndex { get; } - + public static IntPtr GetEHInfoIndex { get; set; } public static IntPtr ConstructStringLiteralIndex { get; } [UnmanagedFunctionPointer(CallingConvention.ThisCall)] @@ -26,28 +27,39 @@ internal static class CEEInfo public delegate InfoAccessType ConstructStringLiteralDelegate(IntPtr thisHandle, IntPtr hModule, int metadataToken, IntPtr ptrString); + [UnmanagedFunctionPointer(CallingConvention.ThisCall)] + public delegate void GetEhInfoDelegate(IntPtr thisHandle, IntPtr ftn, uint ehNumber, out IntPtr clause); + static CEEInfo() { if (CEEInfoVTable == IntPtr.Zero) throw new VTableNotLoaded(nameof(CEEInfo)); + GetEHInfoIndex = CEEInfoVTable + IntPtr.Size * CEEInfoOffset.GetEHInfo; ResolveTokenIndex = CEEInfoVTable + IntPtr.Size * CEEInfoOffset.ResolveToken; ConstructStringLiteralIndex = CEEInfoVTable + IntPtr.Size * CEEInfoOffset.ConstructStringLiteral; - IntPtr resolveTokenPtr = Marshal.ReadIntPtr(ResolveTokenIndex); - IntPtr constructStringLiteralPtr = Marshal.ReadIntPtr(ConstructStringLiteralIndex); - + var resolveTokenPtr = Marshal.ReadIntPtr(ResolveTokenIndex); + var getEhInfoPtr = Marshal.ReadIntPtr(GetEHInfoIndex); + var constructStringLiteralPtr = Marshal.ReadIntPtr(ConstructStringLiteralIndex); + _resolveToken = Marshal.GetDelegateForFunctionPointer(resolveTokenPtr); + _getEHInfo = Marshal.GetDelegateForFunctionPointer(getEhInfoPtr); _constructStringLiteral = Marshal.GetDelegateForFunctionPointer(constructStringLiteralPtr); //PrepareMethod in .NET Core 2.0, will raise StackOverFlowException ResolveToken(default, default); - System.Reflection.MethodInfo constructString = typeof(CEEInfo).GetMethod(nameof(ConstructStringLiteral))!; + var constructString = typeof(CEEInfo).GetMethod(nameof(ConstructStringLiteral))!; RuntimeHelpers.PrepareMethod(constructString.MethodHandle); } + public static void GetEHInfo(IntPtr thisHandle, IntPtr ftn, uint ehNumber, out IntPtr clause) + { + _getEHInfo(thisHandle, ftn, ehNumber, out clause); + } + public static void ResolveToken(IntPtr thisHandle, IntPtr pResolvedToken) { if (thisHandle == IntPtr.Zero) diff --git a/src/Jitex/JIT/CorInfo/ConstructString.cs b/src/Jitex/JIT/CorInfo/ConstructString.cs deleted file mode 100644 index c37c267..0000000 --- a/src/Jitex/JIT/CorInfo/ConstructString.cs +++ /dev/null @@ -1,16 +0,0 @@ -using System; - -namespace Jitex.JIT.CorInfo -{ - internal class ConstructString - { - public IntPtr HandleModule { get; } - public int MetadataToken { get; } - - public ConstructString(IntPtr handleModule, int metadataToken) - { - HandleModule = handleModule; - MetadataToken = metadataToken; - } - } -} \ No newline at end of file diff --git a/src/Jitex/JIT/CorInfo/CorInfoEhClause.cs b/src/Jitex/JIT/CorInfo/CorInfoEhClause.cs new file mode 100644 index 0000000..f61e2b1 --- /dev/null +++ b/src/Jitex/JIT/CorInfo/CorInfoEhClause.cs @@ -0,0 +1,13 @@ +using System.Runtime.InteropServices; + +namespace Jitex.JIT.CorInfo; + +[StructLayout(LayoutKind.Sequential)] +public struct CorInfoEhClause +{ + public uint Flags; + public uint TryOffset; + public uint TryLength; + public uint HandlerOffset; + public uint HandlerLength; +} \ No newline at end of file diff --git a/src/Jitex/JIT/CorInfo/MethodInfo.cs b/src/Jitex/JIT/CorInfo/MethodInfo.cs index 64c23e2..9381164 100644 --- a/src/Jitex/JIT/CorInfo/MethodInfo.cs +++ b/src/Jitex/JIT/CorInfo/MethodInfo.cs @@ -20,7 +20,7 @@ public class MethodInfo : CorType private IntPtr ILCodeAddr => HInstance + MethodInfoOffset.ILCode; private IntPtr ILCodeSizeAddr => HInstance + MethodInfoOffset.ILCodeSize; private IntPtr MaxStackAddr => HInstance + MethodInfoOffset.MaxStack; - private IntPtr EHCountAddr => HInstance + MethodInfoOffset.EHCount; + public IntPtr EHCountAddr => HInstance + MethodInfoOffset.EHCount; /// /// Signature from locals variables. diff --git a/src/Jitex/JIT/Handlers/MethodCompiled.cs b/src/Jitex/JIT/Handlers/MethodCompiled.cs index bbee503..df632f7 100644 --- a/src/Jitex/JIT/Handlers/MethodCompiled.cs +++ b/src/Jitex/JIT/Handlers/MethodCompiled.cs @@ -1,8 +1,8 @@ -using Jitex.JIT.Context; -using Jitex.JIT.CorInfo; +using Jitex.JIT.CorInfo; using Jitex.Runtime; using System; using System.Reflection; +using Jitex.JIT.Hooks.CompileMethod; namespace Jitex.JIT.Handlers { diff --git a/src/Jitex/JIT/Hooks/CompileMethod/CompileMethodHook.cs b/src/Jitex/JIT/Hooks/CompileMethod/CompileMethodHook.cs new file mode 100644 index 0000000..bf50a0e --- /dev/null +++ b/src/Jitex/JIT/Hooks/CompileMethod/CompileMethodHook.cs @@ -0,0 +1,270 @@ +using System; +using System.Collections.Concurrent; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using Jitex.Framework; +using Jitex.JIT.CorInfo; +using Jitex.JIT.Handlers; +using Jitex.JIT.Hooks.ExceptionInfo; +using Jitex.JIT.Hooks.String; +using Jitex.JIT.Hooks.Token; +using Jitex.Runtime; +using Jitex.Utils; +using Jitex.Utils.Extension; +using Jitex.Utils.NativeAPI.Windows; +using Mono.Unix.Native; +using MethodInfo = Jitex.JIT.CorInfo.MethodInfo; + +namespace Jitex.JIT.Hooks.CompileMethod; + +/// +/// Method resolver handler. +/// +/// Context of method. +public delegate void MethodResolverHandler(MethodContext context); + +/// +/// Handler to event after compiled method. +/// +/// Method compiled. +public delegate void MethodCompiledHandler(MethodCompiled methodCompiled); + +internal class CompileMethodHook : HookBase +{ + private readonly RuntimeFramework _framework; + + private static RuntimeFramework.CompileMethodDelegate DelegateHook; + + [ThreadStatic] + private static ThreadTls? _tls; + + private static CompileMethodHook? Instance { get; set; } + private static readonly ConcurrentDictionary HandleSource = new(); + private event MethodCompiledHandler? OnMethodCompiled; + + internal void AddOnMethodCompiledEvent(MethodCompiledHandler handler) => OnMethodCompiled += handler; + internal void RemoveOnMethodCompiledEvent(MethodCompiledHandler handler) => OnMethodCompiled -= handler; + + + private CompileMethodHook() + { + _framework = RuntimeFramework.Framework; + } + + public static CompileMethodHook GetInstance() + { + Instance ??= new CompileMethodHook(); + return Instance; + } + + /// + /// Lock to prevent unload in compile time. + /// + private static readonly object JitLock = new object(); + + /// + /// Wrap delegate to compileMethod from ICorJitCompiler. + /// + /// this parameter (pointer to CILJIT). + /// (IN) - Pointer to ICorJitInfo. + /// (IN) - Pointer to CORINFO_METHOD_INFO. + /// (IN) - Pointer to CorJitFlag. + /// (OUT) - Pointer to NativeEntry. + /// (OUT) - Size of NativeEntry. + [MethodImpl(MethodImplOptions.NoInlining)] + private CorJitResult Hook(IntPtr thisPtr, IntPtr comp, IntPtr info, uint flags, IntPtr nativeEntry, + out int nativeSizeOfCode) + { + //Don`t move this line to below of "if"! This line is to prevent stack overflow. + _tls ??= new ThreadTls(); + + if (thisPtr == default) + { + nativeEntry = IntPtr.Zero; + nativeSizeOfCode = 0; + return 0; + } + + _tls.EnterCount++; + + try + { + MethodContext? methodContext = null; + var sigAddress = IntPtr.Zero; + var ilAddress = IntPtr.Zero; + + //Don't put anything inside "if" to be compiled! Otherwise, will raise a StackOverflow + if (_tls.EnterCount > 1) + return _framework.CompileMethod(thisPtr, comp, info, flags, nativeEntry, out nativeSizeOfCode); + + var methodInfo = new MethodInfo(info); + var methodFound = MethodHelper.GetMethodFromHandle(methodInfo.MethodHandle); + + if (methodFound == null) + return _framework.CompileMethod(thisPtr, comp, info, flags, nativeEntry, out nativeSizeOfCode); + + if (DynamicHelpers.IsDynamicScope(methodInfo.Scope)) + { + methodFound = DynamicHelpers.GetOwner(methodFound); + } + + if (GetInvocationList().Any()) + { + lock (JitLock) + { + if (_framework.CEEInfoVTable == IntPtr.Zero) + { + _framework.ReadICorJitInfoVTable(comp); + + TokenHook.GetInstance().InjectHook(CEEInfo.ResolveTokenIndex); + //StringHook.GetInstance().InjectHook(CEEInfo.ConstructStringLiteralIndex); + // ExceptionInfoHook.GetInstance().InjectHook(CEEInfo.GetEHInfoIndex); + } + } + + //Try retrieve source from call. + //--- + //Before method to be compiled, he should be "resolved" (resolveToken). + //Inside resolveToken, we can get source (which requested compilation) and destiny handle method (which be compiled). + //In theory, every method to be compiled, should pass inside resolveToken, but has some unknown cases which they will be not "resolved". + //Also, this is an inaccurate way to get source, because in some cases, can return a false source. + var hasSource = HandleSource.TryGetValue(methodInfo.MethodHandle, out var source); + + methodContext = new MethodContext(methodFound, source, hasSource); + + foreach (var handler in GetInvocationList()) + { + handler(methodContext); + + if (methodContext.IsResolved) + break; + } + + //Set instance TLS for ResolveToken. + //I know, it`s weird, but without this line, we got an SEHException. + TokenHook.GetInstance().SetNewInstanceTls(); + + if (methodContext.Mode == MethodContext.ResolveMode.IL) + { + var methodBody = methodContext.Body; + // ExceptionInfoHook.Handle = methodInfo.MethodHandle; + + if (methodBody.HasLocalVariable) + { + byte[] signatureVariables = methodBody.GetSignatureVariables(); + sigAddress = MarshalHelper.CreateArrayCopy(signatureVariables); + + methodInfo.Locals.Signature = sigAddress + 1; + methodInfo.Locals.Args = sigAddress + 3; + methodInfo.Locals.NumArgs = (ushort)methodBody.LocalVariables.Count; + } + + methodInfo.MaxStack = methodBody.MaxStackSize; + methodInfo.EHCount = methodContext.Body.EHCount; + methodInfo.ILCode = MarshalHelper.CreateArrayCopy(methodBody.IL); + methodInfo.ILCodeSize = (uint)methodBody.IL.Length; + } + } + + var result = _framework.CompileMethod(thisPtr, comp, info, flags, nativeEntry, out nativeSizeOfCode); + + if (result != CorJitResult.CORJIT_OK) + return result; + + var realNativeEntry = MemoryHelper.Read(nativeEntry); + + MethodCompiled methodCompiled = new(methodFound, methodContext, methodInfo, result, realNativeEntry, + nativeSizeOfCode); + + RuntimeMethodCache.AddMethod(methodCompiled); + // OnMethodCompiled?.Invoke(methodCompiled); + + if (ilAddress != IntPtr.Zero) + Marshal.FreeHGlobal(ilAddress); + + if (sigAddress != IntPtr.Zero) + Marshal.FreeHGlobal(sigAddress); + + if (methodContext is not { IsResolved: true }) + return result; + + if (methodContext.Mode == MethodContext.ResolveMode.Native) + { + WriteNative(methodContext.NativeCode!, ref nativeSizeOfCode, nativeEntry); + } + else if (methodContext.Mode == MethodContext.ResolveMode.Entry) + { + var entryContext = methodContext.EntryContext!; + + WriteEntry(entryContext, ref nativeSizeOfCode, nativeEntry); + + methodCompiled.NativeCode.Address = nativeEntry; + methodCompiled.NativeCode.Size = nativeSizeOfCode; + } + + return result; + } + catch (Exception ex) + { + nativeSizeOfCode = default; + throw new Exception("Failed compile method.", ex); + } + finally + { + _tls.EnterCount--; + } + } + + private static void WriteEntry(NativeCode nativeCode, ref int nativeSize, IntPtr nativeEntry) + { + MemoryHelper.Write(nativeEntry, nativeCode.Address); + + if (nativeCode.Size > 0) + nativeSize = nativeCode.Size; + } + + private static void WriteNative(byte[] nativeCode, ref int nativeSize, IntPtr nativeEntry) + { + var size = nativeCode.Length; + var address = Marshal.AllocHGlobal(size); + + unsafe + { + var ptr = Unsafe.AsPointer(ref nativeCode[0]); + Unsafe.CopyBlock(address.ToPointer(), ptr, (uint)size); + } + + MemoryHelper.Write(nativeEntry, address); + nativeSize = size; + + if (OSHelper.IsWindows) + { + Kernel32.VirtualProtect(address, size, Kernel32.MemoryProtection.EXECUTE_READ_WRITE); + } + else + { + var (alignedAddress, alignedSize) = MemoryHelper.GetAlignedAddress(address, size); + + if (OSHelper.IsHardenedRuntime) + Syscall.mprotect(alignedAddress, alignedSize, MmapProts.PROT_READ | MmapProts.PROT_EXEC); + else + Syscall.mprotect(alignedAddress, alignedSize, + MmapProts.PROT_READ | MmapProts.PROT_WRITE | MmapProts.PROT_EXEC); + } + } + + internal static void RegisterSource(IntPtr methodHandle, MethodBase? source) + { + HandleSource.AddOrUpdate(methodHandle, source, (_, _) => source); + } + + public override void PrepareHook() + { + DelegateHook = Hook; + HookAddress = Marshal.GetFunctionPointerForDelegate(DelegateHook); + RuntimeHelperExtension.PrepareDelegate(DelegateHook, IntPtr.Zero, IntPtr.Zero, IntPtr.Zero, (uint)0, + IntPtr.Zero, 0); + } +} \ No newline at end of file diff --git a/src/Jitex/JIT/Context/MethodContext.cs b/src/Jitex/JIT/Hooks/CompileMethod/MethodContext.cs similarity index 96% rename from src/Jitex/JIT/Context/MethodContext.cs rename to src/Jitex/JIT/Hooks/CompileMethod/MethodContext.cs index 8501670..40f2cb7 100644 --- a/src/Jitex/JIT/Context/MethodContext.cs +++ b/src/Jitex/JIT/Hooks/CompileMethod/MethodContext.cs @@ -2,28 +2,30 @@ using System.Collections.Generic; using System.Linq; using System.Reflection; -using System.Threading; using Jitex.Intercept; -using Jitex.PE; using Jitex.Runtime; using Jitex.Utils; using MethodBody = Jitex.Builder.Method.MethodBody; -namespace Jitex.JIT.Context +namespace Jitex.JIT.Hooks.CompileMethod { /// /// Context for method resolution. /// - public class MethodContext : ContextBase + public class MethodContext : Contextbase { private MethodBody? _body; /// /// Resolution mode. /// - [Flags] public enum ResolveMode { + /// + /// None + /// + None = 0, + /// /// MSIL (pre-compile) /// diff --git a/src/Jitex/JIT/Hooks/Contextbase.cs b/src/Jitex/JIT/Hooks/Contextbase.cs new file mode 100644 index 0000000..31e78b6 --- /dev/null +++ b/src/Jitex/JIT/Hooks/Contextbase.cs @@ -0,0 +1,48 @@ +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Reflection; +using Jitex.JIT.Hooks.Token; + +namespace Jitex.JIT.Hooks; + +public abstract class Contextbase +{ + public bool IsResolved { get; set; } + + private MethodBase? _source; + + /// + /// If context has source method from call. + /// + public bool HasSource { get; private set; } + + /// + /// Method source from call + /// + public MethodBase? Source + { + get + { + if (!HasSource) + { + StackTrace trace = new(1, false); + + var methods = trace.GetFrames() + .Select(frame => frame.GetMethod()) + .Where(method => method.DeclaringType.Assembly != typeof(Token.TokenHook).Assembly); + + _source = methods.FirstOrDefault(); + HasSource = _source != null; + } + + return _source; + } + } + + protected Contextbase(MethodBase? source, bool hasSource) + { + _source = source; + HasSource = hasSource; + } +} \ No newline at end of file diff --git a/src/Jitex/JIT/Hooks/ExceptionInfo/ExceptionContext.cs b/src/Jitex/JIT/Hooks/ExceptionInfo/ExceptionContext.cs new file mode 100644 index 0000000..535ff90 --- /dev/null +++ b/src/Jitex/JIT/Hooks/ExceptionInfo/ExceptionContext.cs @@ -0,0 +1,10 @@ +using System.Reflection; + +namespace Jitex.JIT.Hooks.ExceptionInfo; + +public class ExceptionContext : Contextbase +{ + public ExceptionContext(MethodBase? source, bool hasSource) : base(source, hasSource) + { + } +} \ No newline at end of file diff --git a/src/Jitex/JIT/Hooks/ExceptionInfo/ExceptionInfoHook.cs b/src/Jitex/JIT/Hooks/ExceptionInfo/ExceptionInfoHook.cs new file mode 100644 index 0000000..cea44ca --- /dev/null +++ b/src/Jitex/JIT/Hooks/ExceptionInfo/ExceptionInfoHook.cs @@ -0,0 +1,38 @@ +using System; +using System.Diagnostics; +using System.Runtime.InteropServices; +using Jitex.JIT.CorInfo; + +namespace Jitex.JIT.Hooks.ExceptionInfo; + +internal class ExceptionInfoHook : HookBase +{ + public static ExceptionInfoHook? Instance; + public static IntPtr Handle { get; set; } + + public ExceptionInfoHook() + { + } + + public static ExceptionInfoHook GetInstance() + { + Instance ??= new ExceptionInfoHook(); + return Instance; + } + + private static void Hook(IntPtr thisHandle, IntPtr ftn, uint ehNumber, out IntPtr clause) + { + CEEInfo.GetEHInfo(thisHandle, ftn, ehNumber, out clause); + + if (ftn == Handle) + { + var ehInfo = Marshal.PtrToStructure(clause); + Debugger.Break(); + } + } + + public override void PrepareHook() + { + throw new NotImplementedException(); + } +} \ No newline at end of file diff --git a/src/Jitex/JIT/Hooks/HookBase.cs b/src/Jitex/JIT/Hooks/HookBase.cs new file mode 100644 index 0000000..1accb85 --- /dev/null +++ b/src/Jitex/JIT/Hooks/HookBase.cs @@ -0,0 +1,80 @@ +using System; +using System.Diagnostics; +using System.Linq; +using Jitex.Framework; +using Jitex.JIT.Hooks.Token; +using Jitex.Utils; + +namespace Jitex.JIT.Hooks; + +internal abstract class HookBase : IDisposable +{ + private IntPtr _indexAddress; + private IntPtr _originalAddress; + protected IntPtr HookAddress { get; set; } + + public bool IsEnabled { get; private set; } + protected Delegate? Handlers { get; set; } + + public void InjectHook(IntPtr indexAddress) + { + _indexAddress = indexAddress; + + if (_originalAddress == default) + _originalAddress = MemoryHelper.Read(_indexAddress); + + MemoryHelper.UnprotectWrite(_indexAddress, HookAddress); + IsEnabled = true; + } + + public void RemoveHook() + { + if (_indexAddress == default || _originalAddress == default) + return; + + MemoryHelper.UnprotectWrite(_indexAddress, _originalAddress); + IsEnabled = false; + } + + public void AddHandler(THandler handler) where THandler : Delegate + { + Handlers = (THandler)Delegate.Combine(Handlers, handler); + } + + public void RemoverHandler(THandler handler) where THandler : Delegate + { + Handlers = Delegate.Remove(Handlers, handler) as THandler; + } + + internal bool HasHandler(THandler handler) where THandler : Delegate + { + return GetInvocationList() + .Any(del => del.Method == handler.Method); + } + + protected Delegate[] GetInvocationList () + { + if (Handlers == null) + return []; + + return Handlers.GetInvocationList(); + } + + + protected THandler[] GetInvocationList()where THandler : Delegate + { + if (Handlers == null) + return []; + + return Handlers.GetInvocationList().Cast().ToArray(); + } + + public abstract void PrepareHook(); + + public virtual void Dispose() + { + RemoveHook(); + Handlers = null; + GC.SuppressFinalize(this); + } +} \ No newline at end of file diff --git a/src/Jitex/JIT/Hooks/String/StringContext.cs b/src/Jitex/JIT/Hooks/String/StringContext.cs new file mode 100644 index 0000000..29a07ee --- /dev/null +++ b/src/Jitex/JIT/Hooks/String/StringContext.cs @@ -0,0 +1,23 @@ +using System.Reflection; + +namespace Jitex.JIT.Hooks.String; + +public class StringContext : Contextbase +{ + public Module Module { get; private set; } + public int MetadataToken { get; private set; } + public string Content { get; private set; } + + public StringContext(Module module, int metadataToken, string content) : base(null, false) + { + Module = module; + MetadataToken = metadataToken; + Content = content; + } + + public void ResolveString(string newContent) + { + IsResolved = true; + Content = newContent; + } +} \ No newline at end of file diff --git a/src/Jitex/JIT/Hooks/String/StringHook.cs b/src/Jitex/JIT/Hooks/String/StringHook.cs new file mode 100644 index 0000000..deff7b6 --- /dev/null +++ b/src/Jitex/JIT/Hooks/String/StringHook.cs @@ -0,0 +1,108 @@ +using System; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; +using Jitex.JIT.CorInfo; +using Jitex.Utils; +using Jitex.Utils.Extension; + +namespace Jitex.JIT.Hooks.String; + +/// +/// String resolver handler. +/// +/// Context form string. +public delegate void StringResolverHandler(StringContext context); + +internal class StringHook : HookBase +{ + private static CEEInfo.ConstructStringLiteralDelegate DelegateHook; + + [ThreadStatic] + private static ThreadTls? Tls; + + private static StringHook? Instance { get; set; } + + public static StringHook GetInstance() + { + Instance ??= new StringHook(); + return Instance; + } + + private InfoAccessType Hook(IntPtr thisHandle, IntPtr hModule, int metadataToken, + IntPtr ppValue) + { + if (thisHandle == IntPtr.Zero) + return default; + + Tls ??= new ThreadTls(); + + Tls.EnterCount++; + + try + { + if (Tls.EnterCount != 1) + return CEEInfo.ConstructStringLiteral(thisHandle, hModule, metadataToken, ppValue); + + var resolvers = GetInvocationList(); + + if (!resolvers.Any()) + return CEEInfo.ConstructStringLiteral(thisHandle, hModule, metadataToken, ppValue); + + var module = ModuleHelper.GetModuleByAddress(hModule); + var content = module!.ResolveString(metadataToken); + var context = new StringContext(module, metadataToken, content); + + foreach (StringResolverHandler resolver in resolvers) + { + resolver(context); + + if (!context.IsResolved) continue; + + if (string.IsNullOrEmpty(context.Content)) + throw new ArgumentNullException("String content can't be null or empty."); + + var result = CEEInfo.ConstructStringLiteral(thisHandle, hModule, metadataToken, ppValue); + WriteString(ppValue, context.Content!); + return result; + } + + return CEEInfo.ConstructStringLiteral(thisHandle, hModule, metadataToken, ppValue); + } + finally + { + Tls.EnterCount--; + } + } + + /// + /// Write string on OBJECTHANDLE. + /// + /// Pointer to OBJECTHANDLE. + /// Content to write. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void WriteString(IntPtr ppValue, string content) + { + var pEntry = Marshal.ReadIntPtr(ppValue); + + var objectHandle = Marshal.ReadIntPtr(pEntry); + var hashMapPtr = Marshal.ReadIntPtr(objectHandle); + var newContent = Encoding.Unicode.GetBytes(content); + + objectHandle = Marshal.AllocHGlobal(IntPtr.Size + sizeof(int) + newContent.Length); + + Marshal.WriteIntPtr(objectHandle, hashMapPtr); + Marshal.WriteInt32(objectHandle + IntPtr.Size, newContent.Length / 2); + Marshal.Copy(newContent, 0, objectHandle + IntPtr.Size + sizeof(int), newContent.Length); + + Marshal.WriteIntPtr(pEntry, objectHandle); + } + + public override void PrepareHook() + { + DelegateHook = Hook; + HookAddress = Marshal.GetFunctionPointerForDelegate(DelegateHook); + RuntimeHelperExtension.PrepareDelegate(DelegateHook, IntPtr.Zero, IntPtr.Zero, 0, IntPtr.Zero); + } +} \ No newline at end of file diff --git a/src/Jitex/JIT/CompileTls.cs b/src/Jitex/JIT/Hooks/ThreadTls.cs similarity index 69% rename from src/Jitex/JIT/CompileTls.cs rename to src/Jitex/JIT/Hooks/ThreadTls.cs index 21c3bff..84ff496 100644 --- a/src/Jitex/JIT/CompileTls.cs +++ b/src/Jitex/JIT/Hooks/ThreadTls.cs @@ -1,6 +1,6 @@ namespace Jitex.JIT { - internal class CompileTls + internal class ThreadTls { public int EnterCount; } diff --git a/src/Jitex/JIT/Context/TokenContext.cs b/src/Jitex/JIT/Hooks/Token/TokenContext.cs similarity index 82% rename from src/Jitex/JIT/Context/TokenContext.cs rename to src/Jitex/JIT/Hooks/Token/TokenContext.cs index fe6cb4f..23ee786 100644 --- a/src/Jitex/JIT/Context/TokenContext.cs +++ b/src/Jitex/JIT/Hooks/Token/TokenContext.cs @@ -1,15 +1,13 @@ using System; using System.Reflection; using Jitex.JIT.CorInfo; -using Jitex.Utils; -using MethodInfo = System.Reflection.MethodInfo; -namespace Jitex.JIT.Context +namespace Jitex.JIT.Hooks.Token { /// /// Context for token resolution. /// - public class TokenContext : ContextBase + public class TokenContext : Contextbase { private readonly ResolvedToken? _resolvedToken; private TokenKind _tokenType; @@ -165,39 +163,18 @@ public Module? Module /// public bool IsResolved { get; private set; } - /// - /// Content from string (only to string). - /// - public string? Content { get; private set; } - /// /// Constructor for token type. (non-string) /// /// Original token. /// Source method from compile tree ("requester"). /// Has source from call. - internal TokenContext(ref ResolvedToken resolvedToken, MethodBase? source, bool hasSource) : base(source, hasSource) + internal TokenContext(ref ResolvedToken resolvedToken, MethodBase? source, bool hasSource) : base(source, + hasSource) { _resolvedToken = resolvedToken; } - /// - /// Constructor for string type. - /// - /// Original string. - /// Source method who requested token. - /// /// Has source from call. - internal TokenContext(ConstructString constructString, MethodBase? source, bool hasSource) : base(source, hasSource) - { - _module = ModuleHelper.GetModuleByAddress(constructString.HandleModule); - - TokenType = TokenKind.String; - MetadataToken = constructString.MetadataToken; - - if (Module != null) - Content = Module.ResolveString(MetadataToken); - } - /// /// Resolve token from module. /// @@ -247,15 +224,5 @@ public void ResolveConstructor(ConstructorInfo constructor) { ResolveMethod(constructor); } - - /// - /// Resolve string by content string. - /// - /// Content to replace. - public void ResolveString(string content) - { - IsResolved = true; - Content = content; - } } } \ No newline at end of file diff --git a/src/Jitex/JIT/Hooks/Token/TokenHook.cs b/src/Jitex/JIT/Hooks/Token/TokenHook.cs new file mode 100644 index 0000000..4c58cce --- /dev/null +++ b/src/Jitex/JIT/Hooks/Token/TokenHook.cs @@ -0,0 +1,113 @@ +using System; +using System.Linq; +using System.Reflection; +using System.Runtime.InteropServices; +using Jitex.Framework.Offsets; +using Jitex.JIT.CorInfo; +using Jitex.Utils; +using Jitex.Utils.Extension; + +namespace Jitex.JIT.Hooks.Token; + +/// +/// Token resolver handler. +/// +/// Context of token. +public delegate void TokenResolverHandler(TokenContext context); + +internal class TokenHook : HookBase +{ + private static CEEInfo.ResolveTokenDelegate DelegateHook; + + [ThreadStatic] + private static ThreadTls? _tls; + + private static TokenHook? Instance { get; set; } + + public static TokenHook GetInstance() + { + Instance ??= new TokenHook(); + return Instance; + } + + private void Hook(IntPtr thisHandle, IntPtr pResolvedToken) + { + _tls ??= new ThreadTls(); + _tls.EnterCount++; + + if (thisHandle == IntPtr.Zero) + { + CompileMethod.CompileMethodHook.RegisterSource(IntPtr.Zero, null); + return; + } + + var token = 0; + + try + { + if (_tls.EnterCount > 1) + { + CEEInfo.ResolveToken(thisHandle, pResolvedToken); + return; + } + + var resolvers = GetInvocationList(); + + if (!resolvers.Any()) + { + CEEInfo.ResolveToken(thisHandle, pResolvedToken); + return; + } + + var resolvedToken = new ResolvedToken(pResolvedToken); + token = resolvedToken.Token; //Just to show on exception. + + MethodBase? source = null; + + if (!OSHelper.IsX86) + { + var sourceAddress = + Marshal.ReadIntPtr(thisHandle, IntPtr.Size * ResolvedTokenOffset.SourceOffset); + if (sourceAddress != default) + source = MethodHelper.GetMethodFromHandle(sourceAddress); + } + + var hasSource = source != null; + + var context = new TokenContext(ref resolvedToken, source, hasSource); + + foreach (var resolver in resolvers) + { + resolver(context); + } + + CEEInfo.ResolveToken(thisHandle, pResolvedToken); + + if (resolvedToken.HMethod != IntPtr.Zero) + { + CompileMethod.CompileMethodHook.RegisterSource(resolvedToken.HMethod, source); + } + } + catch (Exception ex) + { + throw new Exception($"Failed to resolve token: 0x{token:X}.", ex); + } + finally + { + _tls.EnterCount--; + } + } + + public override void PrepareHook() + { + DelegateHook = Hook; + HookAddress = Marshal.GetFunctionPointerForDelegate(DelegateHook); + RuntimeHelperExtension.PrepareDelegate(DelegateHook, IntPtr.Zero, IntPtr.Zero); + } + + + public void SetNewInstanceTls() + { + _tls = new ThreadTls(); + } +} \ No newline at end of file diff --git a/src/Jitex/JIT/ManagedJit.cs b/src/Jitex/JIT/ManagedJit.cs deleted file mode 100644 index b4433b9..0000000 --- a/src/Jitex/JIT/ManagedJit.cs +++ /dev/null @@ -1,614 +0,0 @@ -using Jitex.Hook; -using Jitex.JIT.CorInfo; -using Jitex.Utils; -using System; -using System.Collections.Concurrent; -using System.Linq; -using System.Reflection; -using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; -using System.Text; -using Jitex.Framework; -using Jitex.Framework.Offsets; -using Jitex.JIT.Context; -using Jitex.Runtime; -using Microsoft.Extensions.Logging; -using MethodBody = Jitex.Builder.Method.MethodBody; -using MethodInfo = Jitex.JIT.CorInfo.MethodInfo; -using static Jitex.JIT.JitexHandler; -using static Jitex.Utils.JitexLogger; -using Jitex.JIT.Handlers; -using Jitex.Utils.Extension; -using Jitex.Utils.NativeAPI.Windows; -using Mono.Unix.Native; -using static Jitex.Utils.MemoryHelper; - -namespace Jitex.JIT -{ - /// - /// Handlers to expose hooks. - /// - public static class JitexHandler - { - /// - /// Method resolver handler. - /// - /// Context of method. - public delegate void MethodResolverHandler(MethodContext context); - - /// - /// Token resolver handler. - /// - /// Context of token. - public delegate void TokenResolverHandler(TokenContext context); - - /// - /// Handler to event after compiled method. - /// - /// Method compiled. - public delegate void MethodCompiledHandler(MethodCompiled methodCompiled); - } - - /// - /// Hook instance from JIT. - /// - internal sealed class ManagedJit : IDisposable - { - private readonly ConcurrentDictionary _handleSource = new(); - - /// - /// Lock to prevent multiple instance. - /// - private static readonly object InstanceLock = new object(); - - /// - /// Lock to prevent unload in compile time. - /// - private static readonly object JitLock = new object(); - - /// - /// Current instance of JIT. - /// - private static ManagedJit? _instance; - - [ThreadStatic] - private static CompileTls? _compileTls; - - [ThreadStatic] - private static TokenTls? _tokenTls; - - private readonly HookManager _hookManager = new HookManager(); - - /// - /// Running framework. - /// - private readonly RuntimeFramework _framework; - - /// - /// Custom compíle method. - /// - private RuntimeFramework.CompileMethodDelegate _compileMethod; - - /// - /// Custom resolve token. - /// - private CEEInfo.ResolveTokenDelegate _resolveToken; - - /// - /// Custom construct string literal. - /// - private CEEInfo.ConstructStringLiteralDelegate _constructStringLiteral; - - private bool _isDisposed; - - private event MethodCompiledHandler? OnMethodCompiled; - - private MethodResolverHandler? _methodResolvers; - - private TokenResolverHandler? _tokenResolvers; - - public bool IsEnabled { get; private set; } - - /// - /// Prepare custom JIT. - /// - private ManagedJit() - { - ModuleHelper.Initialize(); - _framework = RuntimeFramework.Framework; - - _compileMethod = CompileMethod; - _resolveToken = ResolveToken; - _constructStringLiteral = ConstructStringLiteral; - - PrepareHook(); - - _hookManager.InjectHook(_framework.ICorJitCompileVTable, _compileMethod); - IsEnabled = true; - } - - private void PrepareHook() - { - Log?.LogTrace("Preparing delegate for CompileMethod"); - RuntimeHelperExtension.PrepareDelegate(_compileMethod, IntPtr.Zero, IntPtr.Zero, IntPtr.Zero, (uint)0, - IntPtr.Zero, 0); - - Log?.LogTrace("Preparing delegate for ResolveToken"); - RuntimeHelperExtension.PrepareDelegate(_resolveToken, IntPtr.Zero, IntPtr.Zero); - - Log?.LogTrace("Preparing delegate for ConstructStringLiteral"); - RuntimeHelperExtension.PrepareDelegate(_constructStringLiteral, IntPtr.Zero, IntPtr.Zero, 0, IntPtr.Zero); - } - - /// - /// Get singleton instance from ManagedJit. - /// - /// - internal static ManagedJit GetInstance() - { - lock (InstanceLock) - { - return _instance ??= new ManagedJit(); - } - } - - internal void AddMethodResolver(MethodResolverHandler methodResolver) => _methodResolvers += methodResolver; - - internal void AddTokenResolver(TokenResolverHandler tokenResolver) => _tokenResolvers += tokenResolver; - - internal void RemoveMethodResolver(MethodResolverHandler methodResolver) => _methodResolvers -= methodResolver; - - internal void RemoveTokenResolver(TokenResolverHandler tokenResolver) => _tokenResolvers -= tokenResolver; - - internal void AddOnMethodCompiledEvent(MethodCompiledHandler handler) => OnMethodCompiled += handler; - internal void RemoveOnMethodCompiledEvent(MethodCompiledHandler handler) => OnMethodCompiled -= handler; - - internal bool HasMethodResolver(MethodResolverHandler methodResolver) => _methodResolvers != null && - _methodResolvers.GetInvocationList().Any(del => del.Method == methodResolver.Method); - - internal bool HasTokenResolver(TokenResolverHandler tokenResolver) => _tokenResolvers != null && - _tokenResolvers.GetInvocationList() - .Any(del => del.Method == - tokenResolver.Method); - - - /// - /// Enable Jitex hooks - /// - internal void Enable() - { - lock (JitLock) - { - if (IsEnabled) - return; - - _hookManager.InjectHook(_framework.ICorJitCompileVTable, _compileMethod); - - if (_framework.CEEInfoVTable != IntPtr.Zero) - { - _hookManager.InjectHook(CEEInfo.ResolveTokenIndex, _resolveToken); - _hookManager.InjectHook(CEEInfo.ConstructStringLiteralIndex, _constructStringLiteral); - } - } - - IsEnabled = true; - } - - /// - /// Disable Jitex hooks - /// - internal void Disable() - { - lock (JitLock) - { - if (!IsEnabled) - return; - - _hookManager.RemoveHook(_compileMethod); - - if (_framework.CEEInfoVTable != IntPtr.Zero) - { - _hookManager.RemoveHook(_resolveToken); - _hookManager.RemoveHook(_constructStringLiteral); - } - } - - IsEnabled = false; - } - - /// - /// Wrap delegate to compileMethod from ICorJitCompiler. - /// - /// this parameter (pointer to CILJIT). - /// (IN) - Pointer to ICorJitInfo. - /// (IN) - Pointer to CORINFO_METHOD_INFO. - /// (IN) - Pointer to CorJitFlag. - /// (OUT) - Pointer to NativeEntry. - /// (OUT) - Size of NativeEntry. - [MethodImpl(MethodImplOptions.NoInlining)] - private CorJitResult CompileMethod(IntPtr thisPtr, IntPtr comp, IntPtr info, uint flags, IntPtr nativeEntry, - out int nativeSizeOfCode) - { - using IDisposable compileMethodScope = Log?.BeginScope("CompileMethod")!; - - _compileTls ??= new CompileTls(); - - if (thisPtr == default) - { - nativeEntry = IntPtr.Zero; - nativeSizeOfCode = 0; - return 0; - } - - _compileTls.EnterCount++; - - try - { - MethodContext? methodContext = null; - IntPtr sigAddress = IntPtr.Zero; - IntPtr ilAddress = IntPtr.Zero; - - //Dont put anything inside "if" to be compiled! Otherwise, will raise a StackOverflow - if (_compileTls.EnterCount > 1) - return _framework.CompileMethod(thisPtr, comp, info, flags, nativeEntry, out nativeSizeOfCode); - - MethodInfo methodInfo = new MethodInfo(info); - MethodBase? methodFound = MethodHelper.GetMethodFromHandle(methodInfo.MethodHandle); - - if (methodFound == null) - { - Log?.LogTrace( - $"Method for handle: {methodInfo.MethodHandle} not found. Calling original CompileMethod..."); - return _framework.CompileMethod(thisPtr, comp, info, flags, nativeEntry, out nativeSizeOfCode); - } - - if (DynamicHelpers.IsDynamicScope(methodInfo.Scope)) - { - Log?.LogDebug("Is a dynamic scope, getting owner..."); - methodFound = DynamicHelpers.GetOwner(methodFound); - } - - using IDisposable methodScope = Log?.BeginScope(methodFound.ToString())!; - Log?.LogInformation($"Method to be compiled: {methodFound}"); - - Delegate[] resolvers = _methodResolvers == null - ? Array.Empty() - : _methodResolvers.GetInvocationList(); - - if (resolvers.Any()) - { - lock (JitLock) - { - if (_framework.CEEInfoVTable == IntPtr.Zero) - { - Log?.LogTrace("Reading CEEInfoVTable..."); - _framework.ReadICorJitInfoVTable(comp); - - Log?.LogTrace("Injecting hook for ResolveToken"); - _hookManager.InjectHook(CEEInfo.ResolveTokenIndex, _resolveToken); - - Log?.LogTrace("Injecting hook for ConstructStringLiteralIndex"); - _hookManager.InjectHook(CEEInfo.ConstructStringLiteralIndex, _constructStringLiteral); - } - } - - //Try retrieve source from call. - //--- - //Before method to be compiled, he should be "resolved" (resolveToken). - //Inside resolveToken, we can get source (which requested compilation) and destiny handle method (which be compiled). - //In theory, every method to be compiled, should pass inside resolveToken, but has some unknown cases which they will be not "resolved". - //Also, this is an inaccurate way to get source, because in some cases, can return a false source. - bool hasSource = _handleSource.TryGetValue(methodInfo.MethodHandle, out MethodBase? source); - - methodContext = new MethodContext(methodFound, source, hasSource); - - foreach (MethodResolverHandler resolver in resolvers) - { - try - { - Log?.LogInformation( - $"Calling resolver [{resolver.Method.DeclaringType?.FullName}.{resolver.Method.Name}]"); - - resolver(methodContext); - } - catch (Exception ex) - { - Log?.LogError(ex, - $"Failed to execute resolver [{resolver.Method.DeclaringType?.FullName}.{resolver.Method.Name}]."); - } - - if (methodContext.IsResolved) - { - Log?.LogInformation( - $"Method resolved by [{resolver.Method.DeclaringType?.FullName}.{resolver.Method.Name}]"); - break; - } - } - - Log?.LogDebug( - $"Is method resolved: {methodContext.IsResolved}. ResolveMode: {methodContext.Mode.ToString()}"); - - _tokenTls = new TokenTls(); - - if (methodContext.IsResolved && methodContext.Mode.HasFlag(MethodContext.ResolveMode.IL)) - { - MethodBody methodBody = methodContext.Body; - - if (methodBody.HasLocalVariable) - { - byte[] signatureVariables = methodBody.GetSignatureVariables(); - sigAddress = MarshalHelper.CreateArrayCopy(signatureVariables); - - methodInfo.Locals.Signature = sigAddress + 1; - methodInfo.Locals.Args = sigAddress + 3; - methodInfo.Locals.NumArgs = (ushort)methodBody.LocalVariables.Count; - } - - methodInfo.MaxStack = methodBody.MaxStackSize; - methodInfo.EHCount = methodContext.Body.EHCount; - methodInfo.ILCode = MarshalHelper.CreateArrayCopy(methodBody.IL); - methodInfo.ILCodeSize = (uint)methodBody.IL.Length; - } - } - - var result = _framework.CompileMethod(thisPtr, comp, info, flags, nativeEntry, out nativeSizeOfCode); - - if (result != CorJitResult.CORJIT_OK) - { - Log?.LogCritical($"Result from original compileMethod: {result}"); - return result; - } - - var realNativeEntry = Read(nativeEntry); - - MethodCompiled methodCompiled = new(methodFound, methodContext, methodInfo, result, realNativeEntry, - nativeSizeOfCode); - - RuntimeMethodCache.AddMethod(methodCompiled); - OnMethodCompiled?.Invoke(methodCompiled); - - if (ilAddress != IntPtr.Zero) - Marshal.FreeHGlobal(ilAddress); - - if (sigAddress != IntPtr.Zero) - Marshal.FreeHGlobal(sigAddress); - - if (methodContext is not { IsResolved: true }) - return result; - - if (methodContext.Mode == MethodContext.ResolveMode.Native) - { - Log?.LogDebug("Overwriting generated native code..."); - - WriteNative(methodContext.NativeCode!, ref nativeSizeOfCode, nativeEntry); - - Log?.LogDebug("Native code overwrited."); - } - else if (methodContext.Mode == MethodContext.ResolveMode.Entry) - { - Log?.LogDebug($"Overwriting original EntryPoint..."); - - var entryContext = methodContext.EntryContext!; - - WriteEntry(entryContext, ref nativeSizeOfCode, nativeEntry); - - methodCompiled.NativeCode.Address = nativeEntry; - methodCompiled.NativeCode.Size = nativeSizeOfCode; - - Log?.LogDebug("EntryPoint overwrited."); - } - - return result; - } - catch (Exception ex) - { - Log?.LogCritical(ex, "Failed to compile method."); - nativeSizeOfCode = default; - throw new Exception("Failed compile method.", ex); - } - - finally - { - _compileTls.EnterCount--; - } - } - - private static void WriteEntry(NativeCode nativeCode, ref int nativeSize, IntPtr nativeEntry) - { - Write(nativeEntry, nativeCode.Address); - - if (nativeCode.Size > 0) - nativeSize = nativeCode.Size; - } - - private static void WriteNative(byte[] nativeCode, ref int nativeSize, IntPtr nativeEntry) - { - var size = nativeCode.Length; - var address = Marshal.AllocHGlobal(size); - - unsafe - { - var ptr = Unsafe.AsPointer(ref nativeCode[0]); - Unsafe.CopyBlock(address.ToPointer(), ptr, (uint)size); - } - - Write(nativeEntry, address); - nativeSize = size; - - if (OSHelper.IsWindows) - { - Kernel32.VirtualProtect(address, size, Kernel32.MemoryProtection.EXECUTE_READ_WRITE); - } - else - { - var (alignedAddress, alignedSize) = GetAlignedAddress(address, size); - - if (OSHelper.IsHardenedRuntime) - Syscall.mprotect(alignedAddress, alignedSize, MmapProts.PROT_READ | MmapProts.PROT_EXEC); - else - Syscall.mprotect(alignedAddress, alignedSize, - MmapProts.PROT_READ | MmapProts.PROT_WRITE | MmapProts.PROT_EXEC); - } - } - - private void ResolveToken(IntPtr thisHandle, IntPtr pResolvedToken) - { - _tokenTls ??= new TokenTls(); - _tokenTls.EnterCount++; - - if (thisHandle == IntPtr.Zero) - { - _handleSource.AddOrUpdate(IntPtr.Zero, MethodBase.GetCurrentMethod(), (ptr, b) => null); - return; - } - - int token = 0; - - try - { - if (_tokenTls.EnterCount > 1 || _tokenResolvers == null) - { - CEEInfo.ResolveToken(thisHandle, pResolvedToken); - return; - } - - Delegate[] resolvers = _tokenResolvers.GetInvocationList(); - - if (!resolvers.Any()) - { - CEEInfo.ResolveToken(thisHandle, pResolvedToken); - return; - } - - ResolvedToken resolvedToken = new ResolvedToken(pResolvedToken); - token = resolvedToken.Token; //Just to show on exception. - - MethodBase? source = null; - - if (!OSHelper.IsX86) - { - IntPtr sourceAddress = - Marshal.ReadIntPtr(thisHandle, IntPtr.Size * ResolvedTokenOffset.SourceOffset); - if (sourceAddress != default) - source = MethodHelper.GetMethodFromHandle(sourceAddress); - } - - bool hasSource = source != null; - - TokenContext context = new TokenContext(ref resolvedToken, source, hasSource); - - foreach (TokenResolverHandler resolver in resolvers) - { - resolver(context); - } - - CEEInfo.ResolveToken(thisHandle, pResolvedToken); - - if (resolvedToken.HMethod != IntPtr.Zero) - { - if (!_handleSource.TryGetValue(resolvedToken.HMethod, out MethodBase? _)) - { - _handleSource[resolvedToken.HMethod] = source; - } - } - } - catch (Exception ex) - { - throw new Exception($"Failed to resolve token: 0x{token:X}.", ex); - } - finally - { - _tokenTls.EnterCount--; - } - } - - private InfoAccessType ConstructStringLiteral(IntPtr thisHandle, IntPtr hModule, int metadataToken, - IntPtr ppValue) - { - if (thisHandle == IntPtr.Zero) - return default; - - _tokenTls ??= new TokenTls(); - - _tokenTls.EnterCount++; - - try - { - if (_tokenTls.EnterCount == 1 && _tokenResolvers != null) - { - Delegate[] resolvers = _tokenResolvers.GetInvocationList(); - - if (!resolvers.Any()) - return CEEInfo.ConstructStringLiteral(thisHandle, hModule, metadataToken, ppValue); - - ConstructString constructString = new ConstructString(hModule, metadataToken); - TokenContext context = new TokenContext(constructString, null, false); - - foreach (TokenResolverHandler resolver in resolvers) - { - resolver(context); - - if (context.IsResolved) - { - if (string.IsNullOrEmpty(context.Content)) - throw new ArgumentNullException("String content can't be null or empty."); - - var result = CEEInfo.ConstructStringLiteral(thisHandle, hModule, metadataToken, ppValue); - WriteString(ppValue, context.Content!); - return result; - } - } - } - - return CEEInfo.ConstructStringLiteral(thisHandle, hModule, metadataToken, ppValue); - } - finally - { - _tokenTls.EnterCount--; - } - } - - /// - /// Write string on OBJECTHANDLE. - /// - /// Pointer to OBJECTHANDLE. - /// Content to write. - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void WriteString(IntPtr ppValue, string content) - { - IntPtr pEntry = Marshal.ReadIntPtr(ppValue); - - IntPtr objectHandle = Marshal.ReadIntPtr(pEntry); - IntPtr hashMapPtr = Marshal.ReadIntPtr(objectHandle); - byte[] newContent = Encoding.Unicode.GetBytes(content); - - objectHandle = Marshal.AllocHGlobal(IntPtr.Size + sizeof(int) + newContent.Length); - - Marshal.WriteIntPtr(objectHandle, hashMapPtr); - Marshal.WriteInt32(objectHandle + IntPtr.Size, newContent.Length / 2); - Marshal.Copy(newContent, 0, objectHandle + IntPtr.Size + sizeof(int), newContent.Length); - - Marshal.WriteIntPtr(pEntry, objectHandle); - } - - public void Dispose() - { - lock (JitLock) - { - if (_isDisposed) - return; - - Disable(); - - _methodResolvers = null; - _tokenResolvers = null; - - _instance = null; - _isDisposed = true; - IsEnabled = false; - } - - GC.SuppressFinalize(this); - } - } -} \ No newline at end of file diff --git a/src/Jitex/JIT/TokenTls.cs b/src/Jitex/JIT/TokenTls.cs deleted file mode 100644 index 476dede..0000000 --- a/src/Jitex/JIT/TokenTls.cs +++ /dev/null @@ -1,6 +0,0 @@ -namespace Jitex.JIT -{ - internal class TokenTls : CompileTls - { - } -} diff --git a/src/Jitex/Jitex.csproj b/src/Jitex/Jitex.csproj index 8aea1dc..4ebbdc8 100644 --- a/src/Jitex/Jitex.csproj +++ b/src/Jitex/Jitex.csproj @@ -30,7 +30,7 @@ - + @@ -41,9 +41,9 @@ - - - + + + diff --git a/src/Jitex/JitexManager.cs b/src/Jitex/JitexManager.cs index d2982f0..f02896c 100644 --- a/src/Jitex/JitexManager.cs +++ b/src/Jitex/JitexManager.cs @@ -1,9 +1,12 @@ using System; using System.Collections.Generic; +using Jitex.Framework; using Jitex.JIT; using Jitex.Utils.Comparer; using Jitex.Intercept; -using static Jitex.JIT.JitexHandler; +using Jitex.JIT.Hooks.CompileMethod; +using Jitex.JIT.Hooks.String; +using Jitex.JIT.Hooks.Token; namespace Jitex { @@ -12,17 +15,11 @@ namespace Jitex /// public static class JitexManager { - private static readonly object LockModules = new object(); - private static readonly object MethodResolverLock = new object(); - private static readonly object TokenResolverLock = new object(); - private static readonly object CallInterceptorLock = new object(); - private static readonly object OnMethodCompiledLock = new object(); - - private static ManagedJit? _jit; private static InterceptManager? _interceptManager; - - private static ManagedJit Jit => _jit ??= ManagedJit.GetInstance(); private static InterceptManager InterceptManager => _interceptManager ??= InterceptManager.GetInstance(); + private static CompileMethodHook CompileMethodHook => CompileMethodHook.GetInstance(); + private static TokenHook TokenHook => TokenHook.GetInstance(); + private static StringHook StringHook => StringHook.GetInstance(); /// /// All modules load on Jitex. @@ -30,6 +27,13 @@ public static class JitexManager private static IDictionary ModulesLoaded { get; } = new Dictionary(TypeEqualityComparer.Instance); + static JitexManager() + { + CompileMethodHook.PrepareHook(); + TokenHook.PrepareHook(); + StringHook.PrepareHook(); + } + /// /// Event to raise when method was compiled. /// @@ -67,20 +71,39 @@ public static event InterceptHandler.InterceptorHandler Interceptor remove => RemoveInterceptor(value); } + /// + /// String resolver + /// + public static event StringResolverHandler StringResolver + { + add => AddStringResolver(value); + remove => RemoveStringResolver(value); + } + /// /// Returns if Jitex is enabled. /// - public static bool IsEnabled => _jit is { IsEnabled: true }; + public static bool IsEnabled => CompileMethodHook.IsEnabled; /// /// Enable Jitex /// - public static void EnableJitex() => Jit.Enable(); + public static void EnableJitex() + { + CompileMethodHook.InjectHook(RuntimeFramework.Framework.ICorJitCompileVTable); + // TokenHook.InjectHook(RuntimeFramework.Framework.CEEInfoVTable); + // StringHook.InjectHook(RuntimeFramework.Framework.CEEInfoVTable); + } /// /// Disable Jitex /// - public static void DisableJitex() => Jit.Disable(); + public static void DisableJitex() + { + CompileMethodHook.RemoveHook(); + // TokenHook.RemoveHook(); + // StringHook.RemoveHook(); + } /// /// Load module on Jitex. @@ -88,16 +111,13 @@ public static event InterceptHandler.InterceptorHandler Interceptor /// Module to load. public static void LoadModule(Type typeModule) { - lock (LockModules) + if (!ModuleIsLoaded(typeModule)) { - if (!ModuleIsLoaded(typeModule)) - { - JitexModule module = (JitexModule)Activator.CreateInstance(typeModule); + JitexModule module = (JitexModule)Activator.CreateInstance(typeModule); - module.LoadResolvers(); + module.LoadResolvers(); - ModulesLoaded.Add(typeModule, module); - } + ModulesLoaded.Add(typeModule, module); } } @@ -110,16 +130,13 @@ public static void LoadModule(Type typeModule, object? instance) { if (instance == null) throw new ArgumentNullException(nameof(instance)); - lock (LockModules) + if (!ModuleIsLoaded(typeModule)) { - if (!ModuleIsLoaded(typeModule)) - { - JitexModule module = (JitexModule)instance; + JitexModule module = (JitexModule)instance; - module.LoadResolvers(); + module.LoadResolvers(); - ModulesLoaded.Add(typeModule, module); - } + ModulesLoaded.Add(typeModule, module); } } @@ -156,13 +173,10 @@ public static void RemoveModule() where TModule : JitexModule /// Module to remove. public static void RemoveModule(Type typeModule) { - lock (LockModules) + if (ModulesLoaded.TryGetValue(typeModule, out JitexModule module)) { - if (ModulesLoaded.TryGetValue(typeModule, out JitexModule module)) - { - ModulesLoaded.Remove(typeModule); - module.Dispose(); - } + ModulesLoaded.Remove(typeModule); + module.Dispose(); } } @@ -191,13 +205,10 @@ public static bool ModuleIsLoaded(Type typeModule) /// Interceptor to call. public static void AddInterceptor(InterceptHandler.InterceptorHandler interceptorCallAsync) { - lock (CallInterceptorLock) - { - InterceptManager.AddInterceptorCall(interceptorCallAsync); + InterceptManager.AddInterceptorCall(interceptorCallAsync); - if (!IsEnabled) - EnableJitex(); - } + if (!IsEnabled) + EnableJitex(); } /// @@ -206,8 +217,7 @@ public static void AddInterceptor(InterceptHandler.InterceptorHandler intercepto /// Interceptor to remove. public static void RemoveInterceptor(InterceptHandler.InterceptorHandler interceptorCall) { - lock (CallInterceptorLock) - InterceptManager.RemoveInterceptorCall(interceptorCall); + InterceptManager.RemoveInterceptorCall(interceptorCall); } /// @@ -217,8 +227,7 @@ public static void RemoveInterceptor(InterceptHandler.InterceptorHandler interce /// Returns true if loaded, otherwise returns false. public static bool HasInterceptor(InterceptHandler.InterceptorHandler interceptorCall) { - lock (CallInterceptorLock) - return InterceptManager.HasInteceptorCall(interceptorCall); + return InterceptManager.HasInteceptorCall(interceptorCall); } /// @@ -227,48 +236,40 @@ public static bool HasInterceptor(InterceptHandler.InterceptorHandler intercepto /// Method resolver to add. public static void AddMethodResolver(MethodResolverHandler methodResolver) { - lock (MethodResolverLock) - { - Jit.AddMethodResolver(methodResolver); + CompileMethodHook.AddHandler(methodResolver); - if (!IsEnabled) - EnableJitex(); - } + if (!IsEnabled) + EnableJitex(); } /// /// Add a token resolver. /// /// Token resolver to add. - public static void AddTokenResolver(JitexHandler.TokenResolverHandler tokenResolver) + public static void AddTokenResolver(TokenResolverHandler tokenResolver) { - lock (TokenResolverLock) - { - Jit.AddTokenResolver(tokenResolver); + TokenHook.AddHandler(tokenResolver); - if (!IsEnabled) - EnableJitex(); - } + if (!IsEnabled) + EnableJitex(); } /// /// Remove a method resolver. /// /// Method resolver to remove. - public static void RemoveMethodResolver(JitexHandler.MethodResolverHandler methodResolver) + public static void RemoveMethodResolver(MethodResolverHandler methodResolver) { - lock (MethodResolverLock) - Jit.RemoveMethodResolver(methodResolver); + CompileMethodHook.RemoverHandler(methodResolver); } /// /// Remove a token resolver. /// /// Token resolver to remove. - public static void RemoveTokenResolver(JitexHandler.TokenResolverHandler tokenResolver) + public static void RemoveTokenResolver(TokenResolverHandler tokenResolver) { - lock (TokenResolverLock) - Jit.RemoveTokenResolver(tokenResolver); + TokenHook.RemoverHandler(tokenResolver); } /// @@ -277,8 +278,7 @@ public static void RemoveTokenResolver(JitexHandler.TokenResolverHandler tokenRe /// public static void AddOnMethodCompiled(MethodCompiledHandler onMethodCompiled) { - lock (OnMethodCompiledLock) - Jit.AddOnMethodCompiledEvent(onMethodCompiled); + CompileMethodHook.AddOnMethodCompiledEvent(onMethodCompiled); } /// @@ -287,8 +287,25 @@ public static void AddOnMethodCompiled(MethodCompiledHandler onMethodCompiled) /// public static void RemoveOnMethodCompiled(MethodCompiledHandler onMethodCompiled) { - lock (OnMethodCompiledLock) - Jit.RemoveOnMethodCompiledEvent(onMethodCompiled); + CompileMethodHook.RemoveOnMethodCompiledEvent(onMethodCompiled); + } + + /// + /// Add string resolver. + /// + /// + public static void AddStringResolver(StringResolverHandler stringResolver) + { + StringHook.AddHandler(stringResolver); + } + + /// + /// Remove string resolver. + /// + /// + public static void RemoveStringResolver(StringResolverHandler stringResolver) + { + StringHook.RemoverHandler(stringResolver); } /// @@ -296,16 +313,20 @@ public static void RemoveOnMethodCompiled(MethodCompiledHandler onMethodCompiled /// /// Method resolver. /// True to already loaded. False to not loaded. - public static bool HasMethodResolver(JitexHandler.MethodResolverHandler methodResolver) => - Jit.HasMethodResolver(methodResolver); + public static bool HasMethodResolver(MethodResolverHandler methodResolver) + { + return CompileMethodHook.HasHandler(methodResolver); + } /// /// Returns If a token resolver is already loaded. /// /// Token resolver. /// True to already loaded. False to not loaded. - public static bool HasTokenResolver(JitexHandler.TokenResolverHandler tokenResolver) => - Jit.HasTokenResolver(tokenResolver); + public static bool HasTokenResolver(TokenResolverHandler tokenResolver) + { + return TokenHook.HasHandler(tokenResolver); + } /// /// Unload Jitex and modules from application. @@ -325,13 +346,9 @@ public static bool TryGetModule(out TModule? instance) public static void Remove() { - if (_jit != null) - { - ModulesLoaded.Clear(); - - _jit.Dispose(); - _jit = null; - } + CompileMethodHook.Dispose(); + TokenHook.Dispose(); + StringHook.Dispose(); } } } \ No newline at end of file diff --git a/src/Jitex/JitexModule.cs b/src/Jitex/JitexModule.cs index 069ec57..a8bfb6e 100644 --- a/src/Jitex/JitexModule.cs +++ b/src/Jitex/JitexModule.cs @@ -1,5 +1,6 @@ using System; -using Jitex.JIT.Context; +using Jitex.JIT.Hooks.CompileMethod; +using Jitex.JIT.Hooks.Token; namespace Jitex { diff --git a/src/Jitex/PE/NativeReader.cs b/src/Jitex/PE/NativeReader.cs index 78d6a0c..b41c5ba 100644 --- a/src/Jitex/PE/NativeReader.cs +++ b/src/Jitex/PE/NativeReader.cs @@ -1,8 +1,10 @@ using System; using System.Collections.Concurrent; +using System.Diagnostics; using System.Reflection; using System.Runtime.CompilerServices; using dnlib.DotNet; +using dnlib.IO; using dnlib.PE; using dnlib.W32Resources; using Jitex.Framework; @@ -12,14 +14,14 @@ namespace Jitex.PE { - internal class NativeReader + public class NativeReader { + private static readonly FieldInfo _fieldData; private static readonly bool FrameworkSupportR2R; - private static readonly ConcurrentDictionary Images = new(); private readonly bool _hasRtr; - private readonly IntPtr _base; + private IntPtr _base; private uint _size; private int _entryIndexSize; private int _nElements; @@ -29,13 +31,14 @@ internal class NativeReader static NativeReader() { FrameworkSupportR2R = RuntimeFramework.Framework >= new Version(3, 0); + _fieldData = typeof(DataReaderFactory).Assembly.GetType("dnlib.IO.MemoryMappedDataReaderFactory", true) + .GetField("data", BindingFlags.Instance | BindingFlags.NonPublic); } public NativeReader(Module module) { if (!Images.TryGetValue(module, out ImageInfo image)) { - _base = ModuleHelper.GetModuleHandle(module); image = LoadImage(module); Images.TryAdd(module, image); _hasRtr = image.NumberOfElements > 0; @@ -44,7 +47,7 @@ public NativeReader(Module module) { _base = image!.BaseAddress; _size = image.Size; - _nElements = (int) image.NumberOfElements; + _nElements = (int)image.NumberOfElements; _entryIndexSize = image.EntryIndexSize; _baseOffset = image.BaseOffset; _hasRtr = image.NumberOfElements > 0; @@ -53,8 +56,9 @@ public NativeReader(Module module) private ImageInfo LoadImage(Module module) { - ModuleContext moduleContext = ModuleDef.CreateModuleContext(); - using ModuleDefMD moduleDef = ModuleDefMD.Load(module, moduleContext); + var moduleContext = ModuleDef.CreateModuleContext(); + using var moduleDef = ModuleDefMD.Load(module, moduleContext); + _base = (IntPtr)_fieldData.GetValue(moduleDef.Metadata.PEImage.DataReaderFactory); bool hasR2R = moduleDef.Metadata.ImageCor20Header.HasNativeHeader && FrameworkSupportR2R; @@ -62,7 +66,8 @@ private ImageInfo LoadImage(Module module) { _size = moduleDef.Metadata.PEImage.DataReaderFactory.Length; - IntPtr startHeaderAddress = _base + (int) moduleDef.Metadata.ImageCor20Header.ManagedNativeHeader.VirtualAddress; + IntPtr startHeaderAddress = + _base + (int)moduleDef.Metadata.ImageCor20Header.ManagedNativeHeader.VirtualAddress; uint virtualAddress = GetEntryPointSection(startHeaderAddress); if (virtualAddress == 0) @@ -72,12 +77,12 @@ private ImageInfo LoadImage(Module module) unsafe { - _baseOffset = DecodeUnsigned((int) virtualAddress, &val); + _baseOffset = DecodeUnsigned((int)virtualAddress, &val); } - _nElements = (int) (val >> 2); - _entryIndexSize = (byte) (val & 3); - return new ImageInfo(module, _base, _size, _baseOffset, (uint) _nElements, (byte) _entryIndexSize); + _nElements = (int)(val >> 2); + _entryIndexSize = (byte)(val & 3); + return new ImageInfo(module, _base, _size, _baseOffset, (uint)_nElements, (byte)_entryIndexSize); } return new ImageInfo(module); @@ -91,7 +96,8 @@ private static unsafe uint GetEntryPointSection(IntPtr startHeader) return 0; IntPtr startSection = startHeader + sizeof(READYTORUN_HEADER); - ReadOnlySpan sections = new(startSection.ToPointer(), (int) header.CoreHeader.NumberOfSections); + ReadOnlySpan sections = new(startSection.ToPointer(), + (int)header.CoreHeader.NumberOfSections); foreach (READYTORUN_SECTION section in sections) { @@ -107,7 +113,7 @@ private unsafe int DecodeUnsigned(int offset, uint* pValue) if (offset >= _size) throw new BadImageFormatException(); - uint val = *(byte*) (_base + offset); + uint val = *(byte*)(_base + offset); if ((val & 1) == 0) { *pValue = (val >> 1); @@ -118,7 +124,7 @@ private unsafe int DecodeUnsigned(int offset, uint* pValue) if (offset + 1 >= _size) throw new BadImageFormatException(); *pValue = ((val >> 2) | - ((uint) *(byte*) (_base + offset + 1) << 6)); + ((uint)*(byte*)(_base + offset + 1) << 6)); offset += 2; } else if ((val & 4) == 0) @@ -126,8 +132,8 @@ private unsafe int DecodeUnsigned(int offset, uint* pValue) if (offset + 2 >= _size) throw new BadImageFormatException(); *pValue = (val >> 3) | - ((uint) *(byte*) (_base + offset + 1) << 5) | - ((uint) *(byte*) (_base + offset + 2) << 13); + ((uint)*(byte*)(_base + offset + 1) << 5) | + ((uint)*(byte*)(_base + offset + 2) << 13); offset += 3; } else if ((val & 8) == 0) @@ -135,9 +141,9 @@ private unsafe int DecodeUnsigned(int offset, uint* pValue) if (offset + 3 >= _size) throw new BadImageFormatException(); *pValue = (val >> 4) | - ((uint) (byte*) (_base + offset + 1) << 4) | - ((uint) (byte*) (_base + offset + 2) << 12) | - ((uint) (byte*) (_base + offset + 3) << 20); + ((uint)(byte*)(_base + offset + 1) << 4) | + ((uint)(byte*)(_base + offset + 2) << 12) | + ((uint)(byte*)(_base + offset + 3) << 20); offset += 4; } else if ((val & 16) == 0) @@ -165,12 +171,12 @@ public bool IsReadyToRun(MethodBase method) uint offset = _entryIndexSize switch { - 0 => MemoryHelper.ReadUnaligned(_base, _baseOffset + (int) (index / BlockSize)), - 1 => MemoryHelper.ReadUnaligned(_base, _baseOffset + (int) (2 * (index / BlockSize))), - _ => MemoryHelper.ReadUnaligned(_base, _baseOffset + (int) (4 * (index / BlockSize))) + 0 => MemoryHelper.ReadUnaligned(_base, _baseOffset + (int)(index / BlockSize)), + 1 => MemoryHelper.ReadUnaligned(_base, _baseOffset + (int)(2 * (index / BlockSize))), + _ => MemoryHelper.ReadUnaligned(_base, _baseOffset + (int)(4 * (index / BlockSize))) }; - offset += (uint) _baseOffset; + offset += (uint)_baseOffset; for (uint bit = BlockSize >> 1; bit > 0; bit >>= 1) { @@ -179,7 +185,7 @@ public bool IsReadyToRun(MethodBase method) unsafe { - offset2 = (uint) DecodeUnsigned((int) offset, &val); + offset2 = (uint)DecodeUnsigned((int)offset, &val); } if ((index & bit) != 0) @@ -224,13 +230,13 @@ public bool DisableReadyToRun(MethodBase method) uint offset = _entryIndexSize switch { - 0 => MemoryHelper.ReadUnaligned(_base, _baseOffset + (int) (index / BlockSize)), - 1 => MemoryHelper.ReadUnaligned(_base, _baseOffset + (int) (2 * (index / BlockSize))), - _ => MemoryHelper.ReadUnaligned(_base, _baseOffset + (int) (4 * (index / BlockSize))) + 0 => MemoryHelper.ReadUnaligned(_base, _baseOffset + (int)(index / BlockSize)), + 1 => MemoryHelper.ReadUnaligned(_base, _baseOffset + (int)(2 * (index / BlockSize))), + _ => MemoryHelper.ReadUnaligned(_base, _baseOffset + (int)(4 * (index / BlockSize))) }; - offset += (uint) _baseOffset; - MemoryHelper.UnprotectWrite(_base, (int) offset, 0x00); + offset += (uint)_baseOffset; + MemoryHelper.UnprotectWrite(_base, (int)offset, 0x00); return true; } diff --git a/src/Jitex/Utils/ModuleHelper.cs b/src/Jitex/Utils/ModuleHelper.cs index 2d117be..2039f89 100644 --- a/src/Jitex/Utils/ModuleHelper.cs +++ b/src/Jitex/Utils/ModuleHelper.cs @@ -7,6 +7,7 @@ using System.Reflection; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; +using dnlib.DotNet; namespace Jitex.Utils { @@ -73,106 +74,5 @@ private static void LoadMapScopeToHandle() AppDomain.CurrentDomain.AssemblyLoad += CurrentDomainOnAssemblyLoad; } } - - public static IntPtr GetModuleHandle(Module module) - { - IntPtr address; - - if (OSHelper.IsWindows) - address = GetModuleWindows(module); - else if (OSHelper.IsLinux) - address = GetModuleLinux(module.FullyQualifiedName); - else - address = GetModuleOSX(module.FullyQualifiedName); - - if (address == default) - throw new BadImageFormatException($"Base address for module {module.FullyQualifiedName} not found!"); - - return address; - } - - private static IntPtr GetModuleLinux(string modulePath) - { - lock (LockSelfMapsLinux) - { - using FileStream fs = File.OpenRead("/proc/self/maps"); - using StreamReader sr = new(fs); - - do - { - string line = sr.ReadLine()!; - if (!line.EndsWith(modulePath)) - continue; - - int separator = line.IndexOf("-", StringComparison.Ordinal); - - //TODO: Implement Span in future... - IntPtr address = new(long.Parse(line[..separator], NumberStyles.HexNumber)); - - if (!IsValidModuleHandle(address)) - continue; - - return address; - } while (!sr.EndOfStream); - } - - return default; - } - - private static IntPtr GetModuleWindows(Module module) - { - if (!OSHelper.IsWindows) - throw new InvalidOperationException(); - - return (IntPtr)GetHInstance.Invoke(null, new object[] { module }); - } - - public static IntPtr GetModuleOSX(string modulePath) - { - //TODO: Get modules from mach_vm_region and proc_regionfilename. - Process proc = new() - { - StartInfo = new ProcessStartInfo - { - FileName = "vmmap", - Arguments = Process.GetCurrentProcess().Id.ToString(), - UseShellExecute = false, - RedirectStandardOutput = true, - CreateNoWindow = true - } - }; - - proc.Start(); - while (!proc.StandardOutput.EndOfStream) - { - string? line = proc.StandardOutput.ReadLine(); - - if (string.IsNullOrEmpty(line)) - break; - - if (!line.EndsWith(Path.GetFileName(modulePath))) - continue; - - int middleAddress = line.IndexOf("-"); - int startRangeAddress = line.LastIndexOf(' ', middleAddress) + 1; - IntPtr address = new(long.Parse(line[startRangeAddress..middleAddress], NumberStyles.HexNumber)); - - if (!IsValidModuleHandle(address)) - continue; - - return address; - } - - return default; - } - - private static bool IsValidModuleHandle(IntPtr address) - { - byte b1 = MemoryHelper.Read(address); - byte b2 = MemoryHelper.Read(address, 1); - - //Validate if address start with MZ - return b1 == 0x4D && b2 == 0x5A; - } } } \ No newline at end of file