diff --git a/src/Lua/LuaTable.cs b/src/Lua/LuaTable.cs index 6bc5b665..cee72838 100644 --- a/src/Lua/LuaTable.cs +++ b/src/Lua/LuaTable.cs @@ -4,7 +4,12 @@ namespace Lua; -public sealed class LuaTable : IEnumerable> +public interface ILuaEnumerable +{ + bool TryGetNext(LuaValue key, out KeyValuePair pair); +} + +public sealed class LuaTable : IEnumerable>, ILuaEnumerable { public LuaTable() : this(8, 8) { diff --git a/src/Lua/Standard/BasicLibrary.cs b/src/Lua/Standard/BasicLibrary.cs index bdf35463..d4cac830 100644 --- a/src/Lua/Standard/BasicLibrary.cs +++ b/src/Lua/Standard/BasicLibrary.cs @@ -136,10 +136,20 @@ public ValueTask GetMetatable(LuaFunctionExecutionContext context, Cancella public async ValueTask IPairs(LuaFunctionExecutionContext context, CancellationToken cancellationToken) { - var arg0 = context.GetArgument(0); + var arg0 = context.GetArgument(0); + + LuaTable metatable = default; + if (arg0.TryRead(out LuaTable table)) + { + metatable = table.Metatable; + } + else if (arg0.TryRead(out ILuaUserData userdata)) + { + metatable = userdata.Metatable; + } // If table has a metamethod __ipairs, calls it with table as argument and returns the first three results from the call. - if (arg0.Metatable != null && arg0.Metatable.TryGetValue(Metamethods.IPairs, out var metamethod)) + if (metatable != null && metatable.TryGetValue(Metamethods.IPairs, out var metamethod)) { var stack = context.Thread.Stack; var top = stack.Count; @@ -218,10 +228,16 @@ public ValueTask Load(LuaFunctionExecutionContext context, CancellationToke public ValueTask Next(LuaFunctionExecutionContext context, CancellationToken cancellationToken) { - var arg0 = context.GetArgument(0); + var arg0 = context.GetArgument(0); var arg1 = context.HasArgument(1) ? context.Arguments[1] : LuaValue.Nil; - if (arg0.TryGetNext(arg1, out var kv)) + ILuaEnumerable enumerable = default; + if (arg0.TryRead(out LuaTable table)) + enumerable = table; + else if (arg0.TryRead(out ILuaUserData userdata) && userdata is ILuaEnumerable) + enumerable = userdata as ILuaEnumerable; + + if (enumerable != null && enumerable.TryGetNext(arg1, out var kv)) { return new(context.Return(kv.Key, kv.Value)); } @@ -233,10 +249,20 @@ public ValueTask Next(LuaFunctionExecutionContext context, CancellationToke public async ValueTask Pairs(LuaFunctionExecutionContext context, CancellationToken cancellationToken) { - var arg0 = context.GetArgument(0); + var arg0 = context.GetArgument(0); + + LuaTable metatable = default; + if (arg0.TryRead(out LuaTable table)) + { + metatable = table.Metatable; + } + else if (arg0.TryRead(out ILuaUserData userdata)) + { + metatable = userdata.Metatable; + } // If table has a metamethod __pairs, calls it with table as argument and returns the first three results from the call. - if (arg0.Metatable != null && arg0.Metatable.TryGetValue(Metamethods.Pairs, out var metamethod)) + if (metatable != null && metatable.TryGetValue(Metamethods.Pairs, out var metamethod)) { var stack = context.Thread.Stack; var top = stack.Count; diff --git a/tests/Lua.Tests/UserDataPairs/UserDataPairsTests.cs b/tests/Lua.Tests/UserDataPairs/UserDataPairsTests.cs new file mode 100644 index 00000000..a5c345c9 --- /dev/null +++ b/tests/Lua.Tests/UserDataPairs/UserDataPairsTests.cs @@ -0,0 +1,117 @@ +// Copyright (C) 2021-2025 Steffen Itterheim +// Refer to included LICENSE file for terms and conditions. + +using Lua.Runtime; +using Lua.Standard; +using Lua.Tests.Helpers; + +namespace Lua.Tests.UserDataPairs; + +public class UserDataPairsTests +{ + [TestCase("userdatapairs.lua")] + public async Task Test_UserDataPairs(string file) + { + var state = LuaState.Create(); + state.Platform.StandardIO = new TestStandardIO(); + state.OpenStandardLibraries(); + state.Environment["LuaList"] = new LuaValue(new LuaList()); + + var path = FileHelper.GetAbsolutePath(file); + Directory.SetCurrentDirectory(Path.GetDirectoryName(path)!); + try + { + await state.DoFileAsync(Path.GetFileName(file)); + } + catch (LuaRuntimeException e) + { + var luaTraceback = e.LuaTraceback; + if (luaTraceback == null) + { + throw; + } + + var line = luaTraceback.FirstLine; + throw new($"{path}:{line} \n{e.InnerException}\n {e}"); + } + } +} + + +public sealed class LuaList : ILuaUserData, ILuaEnumerable +{ + static readonly LuaFunction __len = new(Metamethods.Len, (context, _) => + { + return new ValueTask(context.Return(3)); + }); + static readonly LuaFunction __pairs = new(Metamethods.Pairs, (context, _) => + { + var arg0 = context.GetArgument(0); + return new ValueTask(context.Return(LuaListIterator, arg0, LuaValue.Nil)); + }); + + static readonly LuaFunction LuaListIterator = new LuaFunction("listnext", (context, token) => + { + var list = context.GetArgument>(0); + var key = context.HasArgument(1) ? context.Arguments[1] : LuaValue.Nil; + + var index = -1; + if (key.Type is LuaValueType.Nil) + { + index = 0; + } + else if (key.TryRead(out int number) && number > 0 && number < list.ManagedArray.Length) + { + index = number; + } + + if (index != -1) + { + return new(context.Return(++index, list.ManagedArray[index - 1])); + } + + return new(context.Return(LuaValue.Nil)); + }); + + static LuaTable s_Metatable; + public LuaTable Metatable { get => s_Metatable; set => throw new NotImplementedException(); } + + public int[] ManagedArray { get; } + //public Dictionary ManagedDict { get; } + public LuaList() + { + ManagedArray = new [] { 1,2,3,4,5 }; + //ManagedDict = new Dictionary {{"TRUE", true}, {"FALSE", false}}; + + s_Metatable = new LuaTable(); + s_Metatable[Metamethods.Len] = __len; + s_Metatable[Metamethods.Pairs] = __pairs; + s_Metatable[Metamethods.IPairs] = __pairs; + } + + public bool TryGetNext(LuaValue key, out KeyValuePair pair) + { + var index = -1; + if (key.Type is LuaValueType.Nil) + { + index = 0; + } + else if (key.TryRead(out int integer) && integer > 0 && integer <= ManagedArray.Length) + { + index = integer; + } + + if (index != -1) + { + var span = ManagedArray.AsSpan(index); + for (var i = 0; i < span.Length; i++) + { + pair = new(index + i + 1, span[i]); + return true; + } + } + + pair = default; + return false; + } +} \ No newline at end of file diff --git a/tests/Lua.Tests/UserDataPairs/userdatapairs.lua b/tests/Lua.Tests/UserDataPairs/userdatapairs.lua new file mode 100644 index 00000000..7009ee37 --- /dev/null +++ b/tests/Lua.Tests/UserDataPairs/userdatapairs.lua @@ -0,0 +1,51 @@ +local iterations = 0 + +print("LuaList pairs:") +for k, v in pairs(LuaList) do + print("List[" .. tostring(k) .. "] = " .. tostring(v)) + assert(k == v) + iterations = iterations + 1 +end +assert(iterations == 5) + + +iterations = 0 +print("LuaList ipairs:") +for i, v in ipairs(LuaList) do + print("List[" .. tostring(i) .. "] = " .. tostring(v)) + assert(i == v) + iterations = iterations + 1 +end +assert(iterations == 5) + + +iterations = 0 +print("LuaList next:") +local i, v = next(LuaList, nil) +while i do + print("List[" .. tostring(i) .. "] = " .. tostring(v)) + assert(i == v) + iterations = iterations + 1 + + i, v = next(LuaList, i) +end +assert(iterations == 5) + + +local t = +{ + 1, 2, 3, 4, 5, + ["some key"] = "some value", + ["another key"] = "another value", + [-1000] = -1001, + [_G] = print, +} + +print("LuaTable pairs:") +for k, v in pairs(t) do + print(k, v) +end +print("LuaTable ipairs:") +for i, v in ipairs(t) do + print(i, v) +end