Skip to content

Commit d3a678e

Browse files
author
Elad Zelingher
committed
Handling tasks that return internal types
1 parent 05c17b5 commit d3a678e

File tree

3 files changed

+89
-19
lines changed

3 files changed

+89
-19
lines changed

src/WampSharp.Tests/Api/RpcServerTests.cs

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using Moq;
1+
using System.Threading.Tasks;
2+
using Moq;
23
using NUnit.Framework;
34
using WampSharp.Rpc;
45
using WampSharp.Tests.TestHelpers;
@@ -15,6 +16,12 @@ public interface ICalculator
1516
int Square(int x);
1617
}
1718

19+
public interface INumberProcessor
20+
{
21+
[WampRpcMethod("test/square")]
22+
Task ProcessNumber(int x);
23+
}
24+
1825
[Test]
1926
public void RequestContextIsSet()
2027
{
@@ -44,5 +51,40 @@ public void RequestContextIsSet()
4451

4552
Assert.That(context.SessionId, Is.EqualTo(channel.GetMonitor().SessionId));
4653
}
54+
55+
#if NET45
56+
57+
[Test]
58+
public void AsyncAwaitTaskWork()
59+
{
60+
WampPlayground playground = new WampPlayground();
61+
62+
IWampHost host = playground.Host;
63+
64+
WampRequestContext context = null;
65+
66+
Mock<INumberProcessor> mock = new Mock<INumberProcessor>();
67+
68+
mock.Setup(x => x.ProcessNumber(It.IsAny<int>()))
69+
.Returns(async (int x) =>
70+
{
71+
});
72+
73+
host.HostService(mock.Object);
74+
75+
host.Open();
76+
77+
IWampChannel<MockRaw> channel = playground.CreateNewChannel();
78+
79+
channel.Open();
80+
81+
INumberProcessor proxy = channel.GetRpcProxy<INumberProcessor>();
82+
83+
Task task = proxy.ProcessNumber(4);
84+
85+
mock.Verify(x => x.ProcessNumber(4));
86+
}
87+
88+
#endif
4789
}
4890
}

src/WampSharp/Rpc/Client/TaskExtensions.cs

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,30 @@
11
using System;
22
using System.Reflection;
33
using System.Threading.Tasks;
4+
using WampSharp.Core.Utilities;
45

56
namespace WampSharp.Rpc.Client
67
{
78
internal static class TaskExtensions
89
{
9-
private static readonly MethodInfo mCastTask = GetCastTaskMethod();
10+
private static readonly MethodInfo mCastTaskToGenericTask = GetCastTaskToGenericTaskMethod();
11+
private static readonly MethodInfo mCastToNonGenericTask = GetCastGenericTaskToNonGenericMethod();
1012

11-
private static MethodInfo GetCastTaskMethod()
13+
private static MethodInfo GetCastGenericTaskToNonGenericMethod()
14+
{
15+
return typeof(TaskExtensions).GetMethod("InnerCastTask",
16+
BindingFlags.Static | BindingFlags.NonPublic);
17+
}
18+
19+
private static MethodInfo GetCastTaskToGenericTaskMethod()
1220
{
1321
return typeof(TaskExtensions).GetMethod("InternalCastTask",
1422
BindingFlags.Static | BindingFlags.NonPublic);
1523
}
1624

1725
public static Task Cast(this Task<object> task, Type taskType)
1826
{
19-
return (Task)mCastTask.MakeGenericMethod(taskType).Invoke(null, new object[] { task });
27+
return (Task)mCastTaskToGenericTask.MakeGenericMethod(taskType).Invoke(null, new object[] { task });
2028
}
2129

2230
private static Task<T> InternalCastTask<T>(Task<object> task)
@@ -39,7 +47,12 @@ public static Task<object> CastTask(this Task task)
3947
}
4048
else
4149
{
42-
result = InnerCastTask((dynamic)task);
50+
Type underlyingType = UnwrapReturnType(task.GetType());
51+
52+
MethodInfo method =
53+
mCastToNonGenericTask.MakeGenericMethod(underlyingType);
54+
55+
result = (Task<object>) method.Invoke(null, new object[] {task});
4356
}
4457

4558
return result;
@@ -71,5 +84,33 @@ private static TResult ContinueWithSafeCallback<TTask, TResult>(TTask task, Func
7184

7285
return result;
7386
}
87+
88+
/// <summary>
89+
/// Unwraps the return type of a given method.
90+
/// </summary>
91+
/// <param name="returnType">The given return type.</param>
92+
/// <returns>The unwrapped return type.</returns>
93+
/// <example>
94+
/// void, Task -> object
95+
/// Task{string} -> string
96+
/// int -> int
97+
/// </example>
98+
public static Type UnwrapReturnType(Type returnType)
99+
{
100+
if (returnType == typeof(void) || returnType == typeof(Task))
101+
{
102+
return typeof(object);
103+
}
104+
105+
Type taskType =
106+
returnType.GetClosedGenericTypeImplementation(typeof(Task<>));
107+
108+
if (taskType != null)
109+
{
110+
return returnType.GetGenericArguments()[0];
111+
}
112+
113+
return returnType;
114+
}
74115
}
75116
}

src/WampSharp/Rpc/Client/WampRpcSerializer.cs

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,7 @@ public WampRpcCall Serialize(MethodInfo method, object[] arguments)
3838

3939
private Type ExtractReturnType(Type returnType)
4040
{
41-
if (returnType == typeof (void) || returnType == typeof(Task))
42-
{
43-
return typeof (object);
44-
}
45-
46-
Type taskType =
47-
returnType.GetClosedGenericTypeImplementation(typeof (Task<>));
48-
49-
if (taskType != null)
50-
{
51-
return returnType.GetGenericArguments()[0];
52-
}
53-
54-
return returnType;
41+
return TaskExtensions.UnwrapReturnType(returnType);
5542
}
5643
}
5744
}

0 commit comments

Comments
 (0)