diff --git a/README.md b/README.md index b77adc6..55ed574 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,8 @@ # async.nvim Small async library for Neovim plugins +[API documentation](async.md) + ## Example Take the current function that uses a callback style function to run a system process. @@ -19,8 +21,8 @@ end If we want to emulate something like: -```lua -echo foo && echo bar && echo baz +```bash +echo 'foo' && echo 'bar' && echo 'baz' ``` Would need to be implemented as: @@ -60,7 +62,7 @@ local run_job_a = a.wrap(run_job, 3) Now we need to create a top level function to initialize the async context. To do this we can use `void` or `sync`. -Note: the main difference between `void` and `sync` is that `sync` functions can be called with a callback (like the `run_job` in a non-async context, however the user must provide the number of agurments. +Note: the main difference between `void` and `sync` is that `sync` functions can be called with a callback, like the original `run_job` in a non-async context, however the user must provide the number of arguments. For this example we will use `void`: @@ -84,3 +86,45 @@ main() We can now call `run_job_a` in linear imperative fashion without needing to define callbacks. The arguments provided to the callback in the original function are simply returned by the async version. + +## The `async_t` handle + +This library supports cancelling async functions that are currently running. This is done via the `async_t` handle interface. +The handle must provide the methods `cancel()` and `is_cancelled()`, and the purpose of these is to allow the cancelled async function to run any cleanup and free any resources it has created. + +### Example use with `vim.loop.spawn`: + +Typically applications to `vim.loop.spawn` make use of `stdio` pipes for communicating. This involves creating `uv_pipe_t` objects. +If a job is cancelled then these objects must be closed. + +```lua +local function run_job = async.wrap(function(cmd, args, callback) + local stdout = vim.loop.new_pipe(false) + + local raw_handle + raw_handle = vim.loop.spawn(cmd, { args = args, stdio = { nil, stdout }}, + function(code) + stdout:close() + raw_handle:close() + callback(code) + end + ) + + local handle = {} + + handle.is_cancelled = function(_) + return raw_handle.is_closing() + end + + handle.cancel = function(_, cb) + raw_handle:close(function() + stdout:close(cb) + end) + end + + return handle +end) +``` + +So even if `run_job` is called in a deep function stack, calling `cancel()` on any parent async function will allow the job to be cancelled safely. + diff --git a/async.md b/async.md index d1864cf..e568637 100644 --- a/async.md +++ b/async.md @@ -5,7 +5,43 @@ Small async library for Neovim plugins ## Functions -### `sync(func, argc)` +### `running()` + +Returns whether the current execution context is async. + + +#### Returns + + `boolean?` + +--- +### `run(func, callback, ...)` + +Run a function in an async context. + +#### Parameters: + +* `func` (`function`): +* `callback` (`function`): +* `...` (`any`): Arguments for func + +#### Returns + + `async_t`: Handle + +--- +### `wait(argc, func, ...)` + +Wait on a callback style function + +#### Parameters: + +* `argc` (`integer?`): The number of arguments of func. +* `func` (`function`): callback style function to execute +* `...` (`any`): Arguments for func + +--- +### `create(func, argc, strict)` Use this to create a function which executes in an async context but called from a non-async context. Inherently this cannot return anything @@ -15,9 +51,14 @@ Use this to create a function which executes in an async context but * `func` (`function`): * `argc` (`number`): The number of arguments of func. Defaults to 0 +* `strict` (`boolean`): Error when called in non-async context + +#### Returns + + `function(...):async_t` --- -### `void(func)` +### `void(func, strict)` Create a function which executes in an async context but called from a non-async context. @@ -25,9 +66,10 @@ Create a function which executes in an async context but #### Parameters: * `func` (`function`): +* `strict` (`boolean`): Error when called in non-async context --- -### `wrap(func, argc, protected)` +### `wrap(func, argc, protected, strict)` Creates an async function with a callback style function. @@ -36,18 +78,23 @@ Creates an async function with a callback style function. * `func` (`function`): A callback style function to be converted. The last argument must be the callback. * `argc` (`integer`): The number of arguments of func. Must be included. * `protected` (`boolean`): call the function in protected mode (like pcall) +* `strict` (`boolean`): Error when called in non-async context + +#### Returns + + `function`: Returns an async function --- -### `join(n, interrupt_check, thunks)` +### `join(thunks, n, interrupt_check)` Run a collection of async functions (`thunks`) concurrently and return when all have finished. #### Parameters: +* `thunks` (`function[]`): * `n` (`integer`): Max number of thunks to run concurrently * `interrupt_check` (`function`): Function to abort thunks between calls -* `thunks` (`function[]`): --- ### `curry(fn, ...)` diff --git a/ldoc.ltp b/ldoc.ltp index 49d5f73..68f1e15 100644 --- a/ldoc.ltp +++ b/ldoc.ltp @@ -43,6 +43,25 @@ $(lev3) $(subnames): > end > end -- for > end -- if params +> if item.retgroups then + +$(lev3) Returns + +> for _, group in ldoc.ipairs(item.retgroups) do +> for r in group:iter() do +> local type, ctypes = item:return_type(r) +> if type ~= '' then +> if r.text ~= '' then + `$(type)`: $(r.text) +> else + `$(type)` +> end +> else + $(r.text) +> end +> end +> end +> end --- > end diff --git a/lua/async.lua b/lua/async.lua index 01c089f..c786e0d 100644 --- a/lua/async.lua +++ b/lua/async.lua @@ -1,9 +1,13 @@ --- Small async library for Neovim plugins --- @module async +-- Store all the async threads in a weak table so we don't prevent them from +-- being garbage collected +local handles = setmetatable({}, { __mode = 'k' }) + local M = {} --- Coroutine.running() was changed between Lua 5.1 and 5.2: +-- Note: coroutine.running() was changed between Lua 5.1 and 5.2: -- - 5.1: Returns the running coroutine, or nil when called by the main thread. -- - 5.2: Returns the running coroutine plus a boolean, true when the running -- coroutine is the main one. @@ -11,46 +15,138 @@ local M = {} -- For LuaJIT, 5.2 behaviour is enabled with LUAJIT_ENABLE_LUA52COMPAT -- -- We need to handle both. -local main_co_or_nil = coroutine.running() -local function execute(func, callback, ...) +--- Returns whether the current execution context is async. +--- +--- @treturn boolean? +function M.running() + local current = coroutine.running() + if current and handles[current] then + return true + end +end + +local function is_Async_T(handle) + if handle + and type(handle) == 'table' + and vim.is_callable(handle.cancel) + and vim.is_callable(handle.is_cancelled) then + return true + end +end + +local Async_T = {} + +-- Analogous to uv.close +function Async_T:cancel(cb) + -- Cancel anything running on the event loop + if self._current and not self._current:is_cancelled() then + self._current:cancel(cb) + end +end + +function Async_T.new(co) + local handle = setmetatable({}, { __index = Async_T }) + handles[co] = handle + return handle +end + +-- Analogous to uv.is_closing +function Async_T:is_cancelled() + return self._current and self._current:is_cancelled() +end + +--- Run a function in an async context. +--- @tparam function func +--- @tparam function callback +--- @tparam any ... Arguments for func +--- @treturn async_t Handle +function M.run(func, callback, ...) + vim.validate { + func = { func, 'function' }, + callback = { callback, 'function', true } + } + local co = coroutine.create(func) + local handle = Async_T.new(co) local function step(...) local ret = {coroutine.resume(co, ...)} - local stat, nargs, protected, err_or_fn = unpack(ret) + local ok = ret[1] - if not stat then - error(string.format("The coroutine failed with this message: %s\n%s", - err_or_fn, debug.traceback(co))) + if not ok then + local err = ret[2] + error(string.format("The coroutine failed with this message:\n%s\n%s", + err, debug.traceback(co))) end if coroutine.status(co) == 'dead' then if callback then - callback(unpack(ret, 4)) + callback(unpack(ret, 4, table.maxn(ret))) end return end - assert(type(err_or_fn) == 'function', "type error :: expected func") + local nargs, fn = ret[2], ret[3] + local args = {select(4, unpack(ret))} - local args = {select(5, unpack(ret))} + assert(type(fn) == 'function', "type error :: expected func") - if protected then - args[nargs] = function(...) - step(true, ...) - end - local ok, err = pcall(err_or_fn, unpack(args, 1, nargs)) - if not ok then - step(false, err) - end - else - args[nargs] = step - err_or_fn(unpack(args, 1, nargs)) + args[nargs] = step + + local r = fn(unpack(args, 1, nargs)) + if is_Async_T(r) then + handle._current = r end end step(...) + return handle +end + +local function wait(argc, func, ...) + vim.validate { + argc = { argc, 'number' }, + func = { func, 'function' }, + } + + -- Always run the wrapped functions in xpcall and re-raise the error in the + -- coroutine. This makes pcall work as normal. + local function pfunc(...) + local args = { ... } + local cb = args[argc] + args[argc] = function(...) + cb(true, ...) + end + xpcall(func, function(err) + cb(false, err, debug.traceback()) + end, unpack(args, 1, argc)) + end + + local ret = {coroutine.yield(argc, pfunc, ...)} + + local ok = ret[1] + if not ok then + local _, err, traceback = unpack(ret) + error(string.format("Wrapped function failed: %s\n%s", err, traceback)) + end + + return unpack(ret, 2, table.maxn(ret)) +end + +--- Wait on a callback style function +--- +--- @tparam integer? argc The number of arguments of func. +--- @tparam function func callback style function to execute +--- @tparam any ... Arguments for func +function M.wait(...) + if type(...) == 'number' then + return wait(...) + else + -- Assume argc is equal to the number of passed arguments (- 1 for function + -- that is first argument, + 1 for callback that hasn't been passed). + return wait(select('#', ...), ...) + end end --- Use this to create a function which executes in an async context but @@ -58,50 +154,70 @@ end --- since it is non-blocking --- @tparam function func --- @tparam number argc The number of arguments of func. Defaults to 0 -function M.sync(func, argc) +--- @tparam boolean strict Error when called in non-async context +--- @treturn function(...):async_t +function M.create(func, argc, strict) + vim.validate { + func = { func, 'function' }, + argc = { argc, 'number', true } + } argc = argc or 0 return function(...) - if coroutine.running() ~= main_co_or_nil then + if M.running() then + if strict then + error('This function must run in a non-async context') + end return func(...) end - local callback = select(argc+1, ...) - execute(func, callback, unpack({...}, 1, argc)) + local callback = select(argc + 1, ...) + return M.run(func, callback, unpack({...}, 1, argc)) end end --- Create a function which executes in an async context but --- called from a non-async context. --- @tparam function func -function M.void(func) +--- @tparam boolean strict Error when called in non-async context +function M.void(func, strict) + vim.validate { func = { func, 'function' } } return function(...) - if coroutine.running() ~= main_co_or_nil then + if M.running() then + if strict then + error('This function must run in a non-async context') + end return func(...) end - execute(func, nil, ...) + return M.run(func, nil, ...) end end --- Creates an async function with a callback style function. +--- --- @tparam function func A callback style function to be converted. The last argument must be the callback. --- @tparam integer argc The number of arguments of func. Must be included. ---- @tparam boolean protected call the function in protected mode (like pcall) ---- @return function Returns an async function -function M.wrap(func, argc, protected) - assert(argc) +--- @tparam boolean strict Error when called in non-async context +--- @treturn function Returns an async function +function M.wrap(func, argc, strict) + vim.validate { + argc = { argc, 'number' }, + } return function(...) - if coroutine.running() == main_co_or_nil then + if not M.running() then + if strict then + error('This function must run in an async context') + end return func(...) end - return coroutine.yield(argc, protected, func, ...) + return M.wait(argc, func, ...) end end --- Run a collection of async functions (`thunks`) concurrently and return when --- all have finished. +--- @tparam function[] thunks --- @tparam integer n Max number of thunks to run concurrently --- @tparam function interrupt_check Function to abort thunks between calls ---- @tparam function[] thunks -function M.join(n, interrupt_check, thunks) +function M.join(thunks, n, interrupt_check) local function run(finish) if #thunks == 0 then return finish() @@ -113,7 +229,7 @@ function M.join(n, interrupt_check, thunks) local ret = {} local function cb(...) - ret[#ret+1] = {...} + ret[#ret + 1] = {...} to_go = to_go - 1 if to_go == 0 then finish(ret) @@ -130,7 +246,10 @@ function M.join(n, interrupt_check, thunks) end end - return coroutine.yield(1, false, run) + if not M.running() then + return run + end + return M.wait(1, false, run) end --- Partially applying arguments to an async function @@ -142,7 +261,7 @@ function M.curry(fn, ...) return function(...) local other = {...} for i = 1, select('#', ...) do - args[nargs+i] = other[i] + args[nargs + i] = other[i] end fn(unpack(args)) end