diff --git a/src/Lua.SourceGenerator/LuaObjectGenerator.Emit.cs b/src/Lua.SourceGenerator/LuaObjectGenerator.Emit.cs index 4772b3ef..21379e88 100644 --- a/src/Lua.SourceGenerator/LuaObjectGenerator.Emit.cs +++ b/src/Lua.SourceGenerator/LuaObjectGenerator.Emit.cs @@ -177,12 +177,16 @@ static bool ValidateMembers(TypeMetadata typeMetadata, Compilation compilation, } PARAMETERS: - foreach (var typeSymbol in method.Symbol.Parameters - .Select(x => x.Type)) + foreach (var para in method.Symbol.Parameters) { + var typeSymbol = para.Type; if (SymbolEqualityComparer.Default.Equals(typeSymbol, references.LuaValue)) continue; if (SymbolEqualityComparer.Default.Equals(typeSymbol, typeMetadata.Symbol)) continue; + if (typeSymbol is IArrayTypeSymbol arr && para.IsParams) + { + typeSymbol = arr.ElementType; + } var conversion = compilation.ClassifyConversion(typeSymbol, references.LuaValue); if (!conversion.Exists && (typeSymbol is not INamedTypeSymbol namedTypeSymbol || !metaDict.ContainsKey(namedTypeSymbol))) { @@ -360,12 +364,23 @@ static void EmitMethodFunction(string functionName, string chunkName, TypeMetada index++; } + var hasParams = methodMetadata.Symbol.Parameters.Any(x => x.IsParams); + var nonParamsCount = 0; foreach (var parameter in methodMetadata.Symbol.Parameters) { var isParameterLuaValue = SymbolEqualityComparer.Default.Equals(parameter.Type, references.LuaValue); - if (parameter.HasExplicitDefaultValue) + if (parameter.IsParams) + { + builder.AppendLine($"var parameters = new LuaValue[context.ArgumentCount - {nonParamsCount}];"); + builder.AppendLine($"for (var i = 0; i < context.ArgumentCount - {nonParamsCount}; i++)"); + builder.AppendLine("{"); + builder.AppendLine($" parameters[i] = context.GetArgument(i + {nonParamsCount});"); + builder.AppendLine("}"); + } + else if (parameter.HasExplicitDefaultValue) { + nonParamsCount++; var syntax = (ParameterSyntax)parameter.DeclaringSyntaxReferences[0].GetSyntax(); if (isParameterLuaValue) @@ -379,6 +394,7 @@ static void EmitMethodFunction(string functionName, string chunkName, TypeMetada } else { + nonParamsCount++; if (isParameterLuaValue) { builder.AppendLine($"var arg{index} = context.GetArgument({index});"); @@ -404,13 +420,31 @@ static void EmitMethodFunction(string functionName, string chunkName, TypeMetada if (methodMetadata.IsStatic) { builder.Append($"{typeMetadata.FullTypeName}.{methodMetadata.Symbol.Name}(", false); - builder.Append(string.Join(",", Enumerable.Range(0, index).Select(x => $"arg{x}")), false); + builder.Append(string.Join(",", Enumerable.Range(0, hasParams ? index - 1 : index).Select(x => $"arg{x}")), false); + if (hasParams) + { + if (index > 1) + { + builder.Append(",", false); + } + builder.Append("parameters", false); + // params must be the last parameter in C# + } builder.AppendLine(");", false); } else { builder.Append($"userData.{methodMetadata.Symbol.Name}("); - builder.Append(string.Join(",", Enumerable.Range(1, index - 1).Select(x => $"arg{x}")), false); + builder.Append(string.Join(",", Enumerable.Range(1, hasParams ? index - 2 : index - 1).Select(x => $"arg{x}")), false); + if (hasParams) + { + if (index > 1) + { + builder.Append(",", false); + } + builder.Append("parameters"); + // params must be the last parameter in C# + } builder.AppendLine(");", false); } diff --git a/src/Lua.SourceGenerator/MethodMetadata.cs b/src/Lua.SourceGenerator/MethodMetadata.cs index f3adba7c..50b394d7 100644 --- a/src/Lua.SourceGenerator/MethodMetadata.cs +++ b/src/Lua.SourceGenerator/MethodMetadata.cs @@ -19,7 +19,10 @@ public MethodMetadata(IMethodSymbol symbol, SymbolReferences references) IsStatic = symbol.IsStatic; var returnType = symbol.ReturnType; - var fullName = (returnType.ContainingNamespace.IsGlobalNamespace ? "" : (returnType.ContainingNamespace + ".")) + returnType.Name; + var isArray = returnType is IArrayTypeSymbol arrayType; + var fullName = isArray ? + "System.Array" + : (returnType.ContainingNamespace.IsGlobalNamespace ? "" : (returnType.ContainingNamespace + ".")) + returnType.Name; IsAsync = fullName is "System.Threading.Tasks.Task" or "System.Threading.Tasks.ValueTask" or "Cysharp.Threading.Tasks.UniTask" diff --git a/tests/Lua.Tests/LuaObjectTests.cs b/tests/Lua.Tests/LuaObjectTests.cs index c531d2ee..d3fdcec1 100644 --- a/tests/Lua.Tests/LuaObjectTests.cs +++ b/tests/Lua.Tests/LuaObjectTests.cs @@ -17,6 +17,19 @@ public static void MethodVoid() Console.WriteLine("HEY!"); } + [LuaMember] + public static LuaTable ParamsMethod(params LuaValue[] arguments) + { + var table = new LuaTable(arguments.Length, arguments.Length); + for (int i = 0; i < arguments.Length; i++) + { + // lua starts at 1 + table[i + 1] = arguments[i]; + } + + return table; + } + [LuaMember] public static async Task MethodAsync() { @@ -88,6 +101,21 @@ public async Task Test_MethodVoid() Assert.That(results, Has.Length.EqualTo(0)); } + [Test] + public async Task Test_ParamsMethod() + { + var userData = new TestUserData(); + + var state = LuaState.Create(); + state.Environment["test"] = userData; + var results = await state.DoStringAsync("return test.ParamsMethod('abc', 'def')"); + + Assert.That(results, Has.Length.EqualTo(1)); + var table = results[0].Read(); + Assert.That(table[1].Read(), Is.EqualTo("abc")); + Assert.That(table[2].Read(), Is.EqualTo("def")); + } + [Test] public async Task Test_MethodAsync() {