Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions common/src/main/scala/org/mockito/JavaReflectionUtils.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package org.mockito

import org.mockito.invocation.InvocationOnMock
import ru.vyarus.java.generics.resolver.GenericsResolver

import java.lang.reflect.Field
import scala.util.control.NonFatal

/**
* Utility methods for Java reflection operations, particularly for Mockito mocks.
*/
object JavaReflectionUtils {

def resolveWithJavaGenerics(invocation: InvocationOnMock): Option[Class[_]] =
try Some(GenericsResolver.resolve(invocation.getMock.getClass).`type`(invocation.method.getDeclaringClass).method(invocation.method).resolveReturnClass())
catch {
case _: Throwable => None
}

def setFinalStatic(field: Field, newValue: AnyRef): Unit =
try {
// Try to get Unsafe instance (works with both sun.misc.Unsafe and jdk.internal.misc.Unsafe)
val unsafeClass: Class[_] =
try
Class.forName("sun.misc.Unsafe")
catch {
case _: ClassNotFoundException => Class.forName("jdk.internal.misc.Unsafe")
}

val unsafeField = unsafeClass.getDeclaredField("theUnsafe")
unsafeField.setAccessible(true)
val unsafe = unsafeField.get(null)

// Get methods via reflection to handle both Unsafe implementations
val staticFieldBaseMethod = unsafeClass.getMethod("staticFieldBase", classOf[Field])
val staticFieldOffsetMethod = unsafeClass.getMethod("staticFieldOffset", classOf[Field])
val putObjectMethod = unsafeClass.getMethod("putObject", classOf[Object], classOf[Long], classOf[Object])

// Make the field accessible
field.setAccessible(true)

// Get base and offset for the field
val base: Object = staticFieldBaseMethod.invoke(unsafe, field)
val offset: Long = staticFieldOffsetMethod.invoke(unsafe, field).asInstanceOf[Long]

// Set the field value directly
putObjectMethod.invoke(unsafe, base, java.lang.Long.valueOf(offset), newValue)
} catch {
case NonFatal(e) =>
throw new IllegalStateException(s"Cannot modify final field ${field.getName}", e)
}

}
6 changes: 2 additions & 4 deletions common/src/main/scala/org/mockito/MockitoAPI.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
package org.mockito

import org.mockito.Answers.CALLS_REAL_METHODS
import org.mockito.ReflectionUtils.InvocationOnMockOps
import org.mockito.internal.configuration.plugins.Plugins.getMockMaker
import org.mockito.internal.creation.MockSettingsImpl
import org.mockito.internal.exceptions.Reporter.notAMockPassedToVerifyNoMoreInteractions
Expand Down Expand Up @@ -453,7 +452,6 @@ private[mockito] trait DoSomething {
}

private[mockito] trait MockitoEnhancer extends MockCreator {
implicit val invocationOps: InvocationOnMock => InvocationOnMockOps = InvocationOps

/**
* Delegates to <code>Mockito.mock(type: Class[T])</code> It provides a nicer API as you can, for instance, do <code>mock[MyClass]</code> instead of
Expand Down Expand Up @@ -630,9 +628,9 @@ private[mockito] trait MockitoEnhancer extends MockCreator {
(settings: MockCreationSettings[O], pt: Prettifier) => ThreadAwareMockHandler(settings, realImpl)(pt)
)

ReflectionUtils.setFinalStatic(moduleField, threadAwareMock)
JavaReflectionUtils.setFinalStatic(moduleField, threadAwareMock)
try block
finally ReflectionUtils.setFinalStatic(moduleField, realImpl)
finally JavaReflectionUtils.setFinalStatic(moduleField, realImpl)
}
}
}
Expand Down
115 changes: 29 additions & 86 deletions common/src/main/scala/org/mockito/ReflectionUtils.scala
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
package org.mockito

import java.lang.reflect.{ Field, Method, Modifier }

import org.mockito.internal.ValueClassWrapper
import org.mockito.JavaReflectionUtils.resolveWithJavaGenerics
import org.mockito.invocation.InvocationOnMock
import org.scalactic.TripleEquals._
import ru.vyarus.java.generics.resolver.GenericsResolver

import java.lang.reflect.Method
import scala.reflect.ClassTag
import scala.reflect.internal.Symbols
import scala.util.{ Try => uTry }
import scala.util.control.NonFatal

object ReflectionUtils {
import scala.reflect.runtime.{ universe => ru }
Expand All @@ -23,58 +20,37 @@ object ReflectionUtils {
def methodToJava(sym: Symbols#MethodSymbol): Method
}]

def listToTuple(l: List[Object]): Any =
l match {
case Nil => Nil
case h :: Nil => h
case _ => Class.forName(s"scala.Tuple${l.size}").getDeclaredConstructors.head.newInstance(l: _*)
}

implicit class InvocationOnMockOps(val invocation: InvocationOnMock) extends AnyVal {
def mock[M]: M = invocation.getMock.asInstanceOf[M]
def method: Method = invocation.getMethod
def arg[A: ValueClassWrapper](index: Int): A = ValueClassWrapper[A].wrapAs[A](invocation.getArgument(index))
def args: List[Any] = invocation.getArguments.toList
def callRealMethod[R](): R = invocation.callRealMethod.asInstanceOf[R]
def argsAsTuple: Any = listToTuple(args.map(_.asInstanceOf[Object]))

def returnType: Class[_] = {
val javaReturnType = method.getReturnType
private[mockito] def returnType(invocation: InvocationOnMock): Class[_] = {
val javaReturnType = invocation.method.getReturnType

if (javaReturnType == classOf[Object])
resolveWithScalaGenerics
.orElse(resolveWithJavaGenerics)
.getOrElse(javaReturnType)
else javaReturnType
}
if (javaReturnType == classOf[Object])
resolveWithScalaGenerics(invocation)
.orElse(resolveWithJavaGenerics(invocation))
.getOrElse(javaReturnType)
else javaReturnType
}

def returnsValueClass: Boolean = findTypeSymbol.exists(_.returnType.typeSymbol.isDerivedValueClass)
private[mockito] def returnsValueClass(invocation: InvocationOnMock): Boolean =
findTypeSymbol(invocation).exists(_.returnType.typeSymbol.isDerivedValueClass)

private def resolveWithScalaGenerics: Option[Class[_]] =
uTry {
findTypeSymbol
.filter(_.returnType.typeSymbol.isClass)
.map(_.asMethod.returnType.typeSymbol.asClass)
.map(mirror.runtimeClass)
}.toOption.flatten

private def findTypeSymbol =
uTry {
mirror
.classSymbol(method.getDeclaringClass)
.info
.decls
.collectFirst {
case symbol if isNonConstructorMethod(symbol) && customMirror.methodToJava(symbol) === method => symbol
}
}.toOption.flatten
private def resolveWithScalaGenerics(invocation: InvocationOnMock): Option[Class[_]] =
uTry {
findTypeSymbol(invocation)
.filter(_.returnType.typeSymbol.isClass)
.map(_.asMethod.returnType.typeSymbol.asClass)
.map(mirror.runtimeClass)
}.toOption.flatten

private def resolveWithJavaGenerics: Option[Class[_]] =
try Some(GenericsResolver.resolve(invocation.getMock.getClass).`type`(method.getDeclaringClass).method(method).resolveReturnClass())
catch {
case _: Throwable => None
}
}
private def findTypeSymbol(invocation: InvocationOnMock) =
uTry {
mirror
.classSymbol(invocation.method.getDeclaringClass)
.info
.decls
.collectFirst {
case symbol if isNonConstructorMethod(symbol) && customMirror.methodToJava(symbol) === invocation.method => symbol
}
}.toOption.flatten

private def isNonConstructorMethod(d: ru.Symbol): Boolean = d.isMethod && !d.isConstructor

Expand Down Expand Up @@ -113,37 +89,4 @@ object ReflectionUtils {
.getOrElse(Seq.empty)
}

def setFinalStatic(field: Field, newValue: AnyRef): Unit =
try {
// Try to get Unsafe instance (works with both sun.misc.Unsafe and jdk.internal.misc.Unsafe)
val unsafeClass: Class[_] =
try
Class.forName("sun.misc.Unsafe")
catch {
case _: ClassNotFoundException => Class.forName("jdk.internal.misc.Unsafe")
}

val unsafeField = unsafeClass.getDeclaredField("theUnsafe")
unsafeField.setAccessible(true)
val unsafe = unsafeField.get(null)

// Get methods via reflection to handle both Unsafe implementations
val staticFieldBaseMethod = unsafeClass.getMethod("staticFieldBase", classOf[Field])
val staticFieldOffsetMethod = unsafeClass.getMethod("staticFieldOffset", classOf[Field])
val putObjectMethod = unsafeClass.getMethod("putObject", classOf[Object], classOf[Long], classOf[Object])

// Make the field accessible
field.setAccessible(true)

// Get base and offset for the field
val base: Object = staticFieldBaseMethod.invoke(unsafe, field)
val offset: Long = staticFieldOffsetMethod.invoke(unsafe, field).asInstanceOf[Long]

// Set the field value directly
putObjectMethod.invoke(unsafe, base, java.lang.Long.valueOf(offset), newValue)
} catch {
case NonFatal(e) =>
throw new IllegalStateException(s"Cannot modify final field ${field.getName}", e)
}

}
16 changes: 14 additions & 2 deletions common/src/main/scala/org/mockito/mockito.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package org

import java.lang.reflect.Method

import org.mockito.ReflectionUtils.InvocationOnMockOps
import org.mockito.internal.{ ValueClassExtractor, ValueClassWrapper }
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.ScalaAnswer
Expand All @@ -21,7 +20,20 @@ package object mockito {

def clazz[T](implicit classTag: ClassTag[T]): Class[T] = classTag.runtimeClass.asInstanceOf[Class[T]]

implicit val InvocationOps: InvocationOnMock => InvocationOnMockOps = new InvocationOnMockOps(_)
implicit class InvocationOnMockOps(val invocation: InvocationOnMock) {
def mock[M]: M = invocation.getMock.asInstanceOf[M]
def method: Method = invocation.getMethod
def arg[A: ValueClassWrapper](index: Int): A = ValueClassWrapper[A].wrapAs[A](invocation.getArgument(index))
def args: List[Any] = invocation.getArguments.toList
def callRealMethod[R](): R = invocation.callRealMethod.asInstanceOf[R]
def argsAsTuple: Any = args.map(_.asInstanceOf[Object]) match {
case Nil => Nil
case h :: Nil => h
case l => Class.forName(s"scala.Tuple${l.size}").getDeclaredConstructors.head.newInstance(l: _*)
}
def returnType: Class[_] = ReflectionUtils.returnType(invocation)
def returnsValueClass: Boolean = ReflectionUtils.returnsValueClass(invocation)
}

def invocationToAnswer[T: ValueClassExtractor](f: InvocationOnMock => T): ScalaAnswer[T] =
ScalaAnswer.lift(f.andThen(ValueClassExtractor[T].extractAs[T]))
Expand Down