1
1
using System ;
2
2
using System . Reflection ;
3
3
using System . Threading . Tasks ;
4
+ using WampSharp . Core . Utilities ;
4
5
5
6
namespace WampSharp . Rpc . Client
6
7
{
7
8
internal static class TaskExtensions
8
9
{
9
- private static readonly MethodInfo mCastTask = GetCastTaskMethod ( ) ;
10
+ private static readonly MethodInfo mCastTaskToGenericTask = GetCastTaskToGenericTaskMethod ( ) ;
11
+ private static readonly MethodInfo mCastToNonGenericTask = GetCastGenericTaskToNonGenericMethod ( ) ;
10
12
11
- private static MethodInfo GetCastTaskMethod ( )
13
+ private static MethodInfo GetCastGenericTaskToNonGenericMethod ( )
14
+ {
15
+ return typeof ( TaskExtensions ) . GetMethod ( "InnerCastTask" ,
16
+ BindingFlags . Static | BindingFlags . NonPublic ) ;
17
+ }
18
+
19
+ private static MethodInfo GetCastTaskToGenericTaskMethod ( )
12
20
{
13
21
return typeof ( TaskExtensions ) . GetMethod ( "InternalCastTask" ,
14
22
BindingFlags . Static | BindingFlags . NonPublic ) ;
15
23
}
16
24
17
25
public static Task Cast ( this Task < object > task , Type taskType )
18
26
{
19
- return ( Task ) mCastTask . MakeGenericMethod ( taskType ) . Invoke ( null , new object [ ] { task } ) ;
27
+ return ( Task ) mCastTaskToGenericTask . MakeGenericMethod ( taskType ) . Invoke ( null , new object [ ] { task } ) ;
20
28
}
21
29
22
30
private static Task < T > InternalCastTask < T > ( Task < object > task )
@@ -39,7 +47,12 @@ public static Task<object> CastTask(this Task task)
39
47
}
40
48
else
41
49
{
42
- result = InnerCastTask ( ( dynamic ) task ) ;
50
+ Type underlyingType = UnwrapReturnType ( task . GetType ( ) ) ;
51
+
52
+ MethodInfo method =
53
+ mCastToNonGenericTask . MakeGenericMethod ( underlyingType ) ;
54
+
55
+ result = ( Task < object > ) method . Invoke ( null , new object [ ] { task } ) ;
43
56
}
44
57
45
58
return result ;
@@ -71,5 +84,33 @@ private static TResult ContinueWithSafeCallback<TTask, TResult>(TTask task, Func
71
84
72
85
return result ;
73
86
}
87
+
88
+ /// <summary>
89
+ /// Unwraps the return type of a given method.
90
+ /// </summary>
91
+ /// <param name="returnType">The given return type.</param>
92
+ /// <returns>The unwrapped return type.</returns>
93
+ /// <example>
94
+ /// void, Task -> object
95
+ /// Task{string} -> string
96
+ /// int -> int
97
+ /// </example>
98
+ public static Type UnwrapReturnType ( Type returnType )
99
+ {
100
+ if ( returnType == typeof ( void ) || returnType == typeof ( Task ) )
101
+ {
102
+ return typeof ( object ) ;
103
+ }
104
+
105
+ Type taskType =
106
+ returnType . GetClosedGenericTypeImplementation ( typeof ( Task < > ) ) ;
107
+
108
+ if ( taskType != null )
109
+ {
110
+ return returnType . GetGenericArguments ( ) [ 0 ] ;
111
+ }
112
+
113
+ return returnType ;
114
+ }
74
115
}
75
116
}
0 commit comments