diff --git a/__tests__/utils.ts b/__tests__/utils.ts index b757201..b567440 100644 --- a/__tests__/utils.ts +++ b/__tests__/utils.ts @@ -74,7 +74,7 @@ export function createMessage(message: Partial> & { [tagNa Target: env.Process.Id, Owner: env.Process.Owner, From: env.Process.Owner, - ["Block-Height"]: "1", + ["Block-Height"]: 1, Timestamp: defaultTimestamp, Module: "examplemodule", Cron: false, diff --git a/controller/controller.lua b/controller/controller.lua index d680c4e..b16c273 100644 --- a/controller/controller.lua +++ b/controller/controller.lua @@ -224,18 +224,12 @@ Handlers.add( -- TODO: timeout here? (what if this doesn't return in time, the liquidation remains in a pending state) -- liquidate the loan - ao.send({ + local loanLiquidationRes = ao.send({ Target = msg.From, Action = "Transfer", Quantity = msg.Tags.Quantity, Recipient = Tokens[msg.From] - }) - - -- get result of liquidation - local loanLiquidationRes = Handlers.receive({ - From = Tokens[msg.From], - ["X-Reference"] = tostring(ao.reference) - }) + }).receive(Tokens[msg.From]) -- check loan liquidation result if loanLiquidationRes.Tags.Error then @@ -408,19 +402,13 @@ function tokens.spawnProtocolLogo(collateralLogo) -- message that spawns the logo -- we're sending this to ourselves - ao.send({ + ---@type Message + local spawnedImage = ao.send({ Target = ao.id, Action = "Spawn-Logo", ["Content-Type"] = "image/svg+xml", Data = logoPart1 .. "/" .. collateralLogo .. logoPart2 - }) - - -- now receive the message we are sending ourselves - ---@type Message - local spawnedImage = Handlers.receive({ - From = ao.id, - Reference = tostring(ao.reference) - }) + }).receive(ao.id) return spawnedImage.Id end diff --git a/package.json b/package.json index a169ae5..8387a77 100644 --- a/package.json +++ b/package.json @@ -12,6 +12,6 @@ "typescript": "^5.5.4" }, "dependencies": { - "@permaweb/ao-loader": "^0.0.42" + "@permaweb/ao-loader": "^0.0.43" } } diff --git a/src/borrow/pool.lua b/src/borrow/pool.lua index f2628f5..b6af4cb 100644 --- a/src/borrow/pool.lua +++ b/src/borrow/pool.lua @@ -37,14 +37,16 @@ function mod.setup(msg) ---@type string[] Friends = Friends or json.decode(ao.env.Process.Tags.Friends or "[]") - -- global current timestamp for the oracle + -- global current timestamp and block for the oracle Timestamp = msg.Timestamp + Block = msg["Block-Height"] end --- This syncs the global timestamp using the current message +-- This syncs the global timestamp anc block using the current message ---@type HandlerFunction function mod.syncTimestamp(msg) Timestamp = msg.Timestamp + Block = msg["Block-Height"] end return mod diff --git a/src/liquidations/oracle.lua b/src/liquidations/oracle.lua index df6969f..05e2515 100644 --- a/src/liquidations/oracle.lua +++ b/src/liquidations/oracle.lua @@ -15,7 +15,7 @@ function mod.setup() -- oracle process id Oracle = Oracle or ao.env.Process.Tags.Oracle - -- oracle delay tolerance in miliseconds + -- oracle delay tolerance in milliseconds ---@type number MaxOracleDelay = MaxOracleDelay or tonumber(ao.env.Process.Tags["Oracle-Delay-Tolerance"]) or 0 @@ -64,7 +64,7 @@ function mod.getPrice(...) Target = Oracle, Action = "v2.Request-Latest-Data", Tickers = json.encode(pricesToSync) - }).receive().Data + }).receive(nil, Block + 1).Data -- check if there was any data returned assert(rawData ~= nil and rawData ~= "", "No data returned from the oracle") diff --git a/src/process.lua b/src/process.lua index 6e9151d..3c48bdb 100644 --- a/src/process.lua +++ b/src/process.lua @@ -69,7 +69,17 @@ local function setup_handlers() Handlers.add( "borrow-loan-interest-sync-dynamic", Handlers.utils.continue(Handlers.utils.hasMatchingTagOf("Action", { - "Borrow", "Repay", "Borrow-Balance", "Borrow-Capacity", "Position", "Global-Position", "Positions", "Redeem", "Transfer", "Liquidate-Borrow" + "Borrow", + "Repay", + "Borrow-Balance", + "Borrow-Capacity", + "Position", + "Global-Position", + "Positions", + "Redeem", + "Transfer", + "Liquidate-Borrow", + "Mint" })), interest.syncInterests ) @@ -115,18 +125,17 @@ local function setup_handlers() config.setLiquidationThreshold ) - Handlers.add( - "liquidate-borrow", - { + Handlers.advanced({ + name = "liquidate-borrow", + pattern = { From = CollateralID, Action = "Credit-Notice", Sender = ao.env.Process.Owner, ["X-Action"] = "Liquidate-Borrow" }, - liquidate.liquidateBorrow, - nil, - liquidate.refund - ) + handle = liquidate.liquidateBorrow, + errorHandler = liquidate.refund + }) Handlers.add( "liquidate-position", { From = ao.env.Process.Owner, Action = "Liquidate-Position" }, @@ -148,13 +157,16 @@ local function setup_handlers() Handlers.utils.hasMatchingTag("Action", "Borrow"), borrow ) - Handlers.add( - "borrow-repay", - { From = CollateralID, Action = "Credit-Notice", ["X-Action"] = "Repay" }, - repay.handler, - nil, - repay.error - ) + Handlers.advanced({ + name = "borrow-repay", + pattern = { + From = CollateralID, + Action = "Credit-Notice", + ["X-Action"] = "Repay" + }, + handle = repay.handler, + errorHandler = repay.error + }) Handlers.add( "borrow-position-balance", Handlers.utils.hasMatchingTag("Action", "Borrow-Balance"), @@ -181,13 +193,16 @@ local function setup_handlers() position.allPositions ) - Handlers.add( - "supply-mint", - { From = CollateralID, Action = "Credit-Notice", ["X-Action"] = "Mint" }, - mint.handler, - nil, - mint.error - ) + Handlers.advanced({ + name = "supply-mint", + pattern = { + From = CollateralID, + Action = "Credit-Notice", + ["X-Action"] = "Mint" + }, + handle = mint.handler, + errorHandler = mint.error + }) Handlers.add( "supply-price", Handlers.utils.hasMatchingTag("Action", "Get-Price"), diff --git a/src/utils/ao.lua b/src/utils/ao.lua index 6a7e6a8..01301c2 100644 --- a/src/utils/ao.lua +++ b/src/utils/ao.lua @@ -197,11 +197,16 @@ function ao.send(msg) resolver) end - message.receive = function(...) - local from = message.Target - if select("#", ...) == 1 then from = select(1, ...) end - return - Handlers.receive({From = from, ["X-Reference"] = referenceString}) + message.receive = function(from, timeout) + if from == nil then from = message.Target end + + local result, expired = Handlers.receive({ + From = from, + ["X-Reference"] = referenceString + }, timeout) + assert(not expired, "Response expired") + + return result end return message @@ -267,13 +272,15 @@ function ao.spawn(module, msg) }, callback) end - spawn.receive = function() - return Handlers.receive({ + spawn.receive = function(timeout) + local result, expired = Handlers.receive({ Action = "Spawned", From = ao.id, ["Reference"] = spawnRef - }) + }, timeout) + assert(not expired, "Response expired") + return result end return spawn diff --git a/src/utils/handlers.lua b/src/utils/handlers.lua index fda7dc2..134cb02 100644 --- a/src/utils/handlers.lua +++ b/src/utils/handlers.lua @@ -1,6 +1,26 @@ -- Copyright (c) 2024 Forward Research -- Code from the aos codebase: https://github.com/permaweb/aos +--- The Handlers library provides a flexible way to manage and execute a series of handlers based on patterns. Each handler consists of a pattern function, a handle function, and a name. This library is suitable for scenarios where different actions need to be taken based on varying input criteria. Returns the handlers table. +-- @module handlers + +--- The handlers table +-- @table handlers +-- @field _version The version number of the handlers module +-- @field list The list of handlers +-- @field coroutines The coroutines of the handlers +-- @field onceNonce The nonce for the once handlers +-- @field utils The handlers-utils module +-- @field generateResolver The generateResolver function +-- @field receive The receive function +-- @field once The once function +-- @field add The add function +-- @field append The append function +-- @field prepend The prepend function +-- @field setActive The handler activation function +-- @field advanced The advanced handler function +-- @field remove The remove function +-- @field evaluate The evaluate function local handlers = { _version = "0.0.5" } local coroutine = require('coroutine') local utils = require('.utils.utils') @@ -17,7 +37,12 @@ else end handlers.onceNonce = 0 - +--- Given an array, a property name, and a value, returns the index of the object in the array that has the property with the value. +-- @lfunction findIndexByProp +-- @tparam {table[]} array The array to search through +-- @tparam {string} prop The property name to check +-- @tparam {any} value The value to check for in the property +-- @treturn {number | nil} The index of the object in the array that has the property with the value, or nil if no such object is found local function findIndexByProp(array, prop, value) for index, object in ipairs(array) do if object[prop] == value then @@ -27,18 +52,10 @@ local function findIndexByProp(array, prop, value) return nil end -local function assertAddArgs(name, pattern, handle, maxRuns) - assert( - type(name) == 'string' and - (type(pattern) == 'function' or type(pattern) == 'table' or type(pattern) == 'string'), - 'Invalid arguments given. Expected: \n' .. - '\tname : string, ' .. - '\tpattern : Action : string | MsgMatch : table,\n' .. - '\t\tfunction(msg: Message) : {-1 = break, 0 = skip, 1 = continue},\n' .. - '\thandle(msg : Message) : void) | Resolver,\n' .. - '\tMaxRuns? : number | "inf" | nil') -end - +--- Given a resolver specification, returns a resolver function. +-- @function generateResolver +-- @tparam {table | function} resolveSpec The resolver specification +-- @treturn {function} A resolver function function handlers.generateResolver(resolveSpec) return function(msg) -- If the resolver is a single function, call it. @@ -46,30 +63,57 @@ function handlers.generateResolver(resolveSpec) if type(resolveSpec) == "function" then return resolveSpec(msg) else - for matchSpec, func in pairs(resolveSpec) do - if utils.matchesSpec(msg, matchSpec) then - return func(msg) - end + for matchSpec, func in pairs(resolveSpec) do + if utils.matchesSpec(msg, matchSpec) then + return func(msg) end + end end end end --- Returns the next message that matches the pattern +--- Given a pattern, returns the next message that matches the pattern. -- This function uses Lua's coroutines under-the-hood to add a handler, pause, -- and then resume the current coroutine. This allows us to effectively block -- processing of one message until another is received that matches the pattern. -function handlers.receive(pattern) +-- @function receive +-- @tparam {table | function} pattern The pattern to check for in the message +-- @tparam {table | nil} timeout Timeout after which the handler will error +function handlers.receive(pattern, timeout) local self = coroutine.running() - handlers.once(pattern, function (msg) + local function resume(msg, expired) -- If the result of the resumed coroutine is an error then we should bubble it up to the process - local _, success, errmsg = coroutine.resume(self, msg) + local _, success, errmsg = coroutine.resume(self, msg, expired) assert(success, errmsg) - end) + end + + handlers.advanced({ + name = "_once_" .. tostring(handlers.onceNonce), + position = "prepend", + pattern = pattern, + maxRuns = 1, + timeout = timeout, + handle = function (msg) + resume(msg, false) + end, + onRemove = function (reason) + if reason ~= "timeout" then return end + resume({}, true) + end + }) + handlers.onceNonce = handlers.onceNonce + 1 + return coroutine.yield(pattern) end +--- Given a name, a pattern, and a handle, adds a handler to the list. +-- If name is not provided, "_once_" prefix plus onceNonce will be used as the name. +-- Adds handler with maxRuns of 1 such that it will only be called once then removed from the list. +-- @function once +-- @tparam {string} name The name of the handler +-- @tparam {table | function | string} pattern The pattern to check for in the message +-- @tparam {function} handle The function to call if the pattern matches function handlers.once(...) local name, pattern, handle if select("#", ...) == 3 then @@ -82,174 +126,314 @@ function handlers.once(...) pattern = select(1, ...) handle = select(2, ...) end - handlers.add(name, pattern, handle, 1) + handlers.prepend(name, pattern, handle, 1) end +--- Given a name, a pattern, and a handle, adds a handler to the list. +-- @function add +-- @tparam {string} name The name of the handler +-- @tparam {table | function | string} pattern The pattern to check for in the message +-- @tparam {function} handle The function to call if the pattern matches +-- @tparam {number | string | nil} maxRuns The maximum number of times the handler should run, or nil if there is no limit function handlers.add(...) + -- select arguments based on the amount of arguments provided local args = select("#", ...) local name = select(1, ...) local pattern = select(1, ...) local handle = select(2, ...) - local maxRuns, errorHandler + local maxRuns if args >= 3 then pattern = select(2, ...) handle = select(3, ...) end if args >= 4 then maxRuns = select(4, ...) end - if args == 5 then errorHandler = select(5, ...) end - - assertAddArgs(name, pattern, handle, maxRuns) - - handle = handlers.generateResolver(handle) - - -- update existing handler by name - local idx = findIndexByProp(handlers.list, "name", name) - if idx ~= nil and idx > 0 then - -- found update - handlers.list[idx].pattern = pattern - handlers.list[idx].handle = handle - handlers.list[idx].maxRuns = maxRuns - handlers.list[idx].errorHandler = errorHandler - else - -- not found then add - table.insert(handlers.list, { pattern = pattern, handle = handle, name = name, maxRuns = maxRuns, errorHandler = errorHandler }) - end - return #handlers.list -end - -function handlers.append(...) - local args = select("#", ...) - local name = select(1, ...) - local pattern = select(1, ...) - local handle = select(2, ...) - - local maxRuns, errorHandler - - if args >= 3 then - pattern = select(2, ...) - handle = select(3, ...) - end - if args >= 4 then maxRuns = select(4, ...) end - if args == 5 then errorHandler = select(5, ...) end - - assertAddArgs(name, pattern, handle, maxRuns) - - handle = handlers.generateResolver(handle) - -- update existing handler by name - local idx = findIndexByProp(handlers.list, "name", name) - if idx ~= nil and idx > 0 then - -- found update - handlers.list[idx].pattern = pattern - handlers.list[idx].handle = handle - handlers.list[idx].maxRuns = maxRuns - handlers.list[idx].errorHandler = errorHandler - else - table.insert(handlers.list, { pattern = pattern, handle = handle, name = name, maxRuns = maxRuns, errorHandler = errorHandler }) - end - - + -- configure handler + return handlers.advanced({ + name = name, + pattern = pattern, + handle = handle, + maxRuns = maxRuns + }) end +--- Appends a new handler to the end of the handlers list. +-- @function append +-- @tparam {string} name The name of the handler +-- @tparam {table | function | string} pattern The pattern to check for in the message +-- @tparam {function} handle The function to call if the pattern matches +-- @tparam {number | string | nil} maxRuns The maximum number of times the handler should run, or nil if there is no limit +handlers.append = handlers.add + +--- Prepends a new handler to the beginning of the handlers list. +-- @function prepend +-- @tparam {string} name The name of the handler +-- @tparam {table | function | string} pattern The pattern to check for in the message +-- @tparam {function} handle The function to call if the pattern matches +-- @tparam {number | string | nil} maxRuns The maximum number of times the handler should run, or nil if there is no limit function handlers.prepend(...) + -- select arguments based on the amount of arguments provided local args = select("#", ...) local name = select(1, ...) local pattern = select(1, ...) local handle = select(2, ...) - local maxRuns, errorHandler + local maxRuns if args >= 3 then pattern = select(2, ...) handle = select(3, ...) end if args >= 4 then maxRuns = select(4, ...) end - if args == 5 then errorHandler = select(5, ...) end - - assertAddArgs(name, pattern, handle, maxRuns) - handle = handlers.generateResolver(handle) - - -- update existing handler by name - local idx = findIndexByProp(handlers.list, "name", name) - if idx ~= nil and idx > 0 then - -- found update - handlers.list[idx].pattern = pattern - handlers.list[idx].handle = handle - handlers.list[idx].maxRuns = maxRuns - handlers.list[idx].errorHandler = errorHandler - else - table.insert(handlers.list, 1, { pattern = pattern, handle = handle, name = name, maxRuns = maxRuns, errorHandler = errorHandler }) - end - - + -- configure handler + return handlers.advanced({ + name = name, + pattern = pattern, + handle = handle, + maxRuns = maxRuns, + position = 'prepend' + }) end +--- Returns an object that allows adding a new handler before a specified handler. +-- @function before +-- @tparam {string} handleName The name of the handler before which the new handler will be added +-- @treturn {table} An object with an `add` method to insert the new handler function handlers.before(handleName) assert(type(handleName) == 'string', 'Handler name MUST be a string') - local idx = findIndexByProp(handlers.list, "name", handleName) return { - add = function (name, pattern, handle, maxRuns, errorHandler) - assertAddArgs(name, pattern, handle, maxRuns) - - handle = handlers.generateResolver(handle) - - if idx then - table.insert(handlers.list, idx, { pattern = pattern, handle = handle, name = name, maxRuns = maxRuns, errorHandler = errorHandler }) - end - + add = function (name, pattern, handle, maxRuns) + -- configure handler + return handlers.advanced({ + name = name, + pattern = pattern, + handle = handle, + maxRuns = maxRuns, + position = { + type = 'before', + target = handleName + } + }) end } end +--- Returns an object that allows adding a new handler after a specified handler. +-- @function after +-- @tparam {string} handleName The name of the handler after which the new handler will be added +-- @treturn {table} An object with an `add` method to insert the new handler function handlers.after(handleName) assert(type(handleName) == 'string', 'Handler name MUST be a string') - local idx = findIndexByProp(handlers.list, "name", handleName) + return { - add = function (name, pattern, handle, maxRuns, errorHandler) - assertAddArgs(name, pattern, handle, maxRuns) - - handle = handlers.generateResolver(handle) - - if idx then - table.insert(handlers.list, idx + 1, { pattern = pattern, handle = handle, name = name, maxRuns = maxRuns, errorHandler = errorHandler }) - end - + add = function (name, pattern, handle, maxRuns) + -- configure handler + return handlers.advanced({ + name = name, + pattern = pattern, + handle = handle, + maxRuns = maxRuns, + position = { + type = 'after', + target = handleName + } + }) end } end +--- Allows activating/deactivating a handler +-- @function setActive +-- @tparam {string} name The target handler's name +-- @tparam {boolean} status The handlers active status +function handlers.setActive(name, status) + assert(type(status) == 'boolean', 'Invalid status: must be a boolean') + + -- find handler + local idx = findIndexByProp(handlers.list, 'name', name) + + -- not found + if idx == nil or idx <= 0 then return end + + -- reverse provided status + handlers.list[idx].inactive = not status +end + +--- Allows creating and adding a handler with advanced options using a simple configuration table +-- @function advanced +-- @tparam {table} config The new handler's configuration +function handlers.advanced(config) + -- validate handler config + assert(type(config.name) == 'string', 'Invalid handler name: must be a string') + assert( + type(config.pattern) == 'function' or type(config.pattern) == 'table' or type(config.pattern) == 'string', + 'Invalid pattern: must be a function, a table or a string' + ) + + if config.position ~= nil then + assert( + type(config.position) == 'table' or config.position == 'append' or config.position == 'prepend', + 'Invalid position: must be a table or "append"/"prepend"' + ) + + if type(config.position) == 'table' then + assert( + config.position.type == 'append' or config.position.type == 'prepend' or config.position.type == 'before' or config.position.type == 'after', + 'Invalid position.type: must be one of ("append", "prepend", "before", "after")' + ) + assert( + config.position.target == nil or type(config.position.target) == 'string', + 'Invalid position.target: must be a string (handler name)' + ) + end + end + + assert( + type(config.handle) == 'function' or type(config.handle) == 'table', + 'Invalid handle: must be a function or a table of resolvers' + ) + assert( + config.runType == nil or config.runType == 'continue' or config.runType == 'break' or config.runType == 1 or config.runType == -1, + 'Invalid runType: must be "continue"/1 or "break"/-1' + ) + assert( + config.maxRuns == nil or type(config.maxRuns) == 'number', + "Invalid maxRuns: must be an integer" + ) + assert( + config.errorHandler == nil or type(config.errorHandler) == 'function', + "Invalid error handler: must be a function" + ) + assert( + config.onRemove == nil or type(config.onRemove) == 'function', + "Invalid onRemove: must be a function" + ) + assert( + config.inactive == nil or type(config.inactive) == 'boolean', + 'Invalid inactive: must be a boolean' + ) + + if config.timeout then + assert( + type(config.timeout) == 'table' or type(config.timeout) == 'number', + 'Invalid timeout: must be a table or a number' + ) + + if type(config.timeout) == 'table' then + assert( + config.timeout.type == 'milliseconds' or config.timeout.type == 'blocks', + 'Invalid timeout.type: must be of ("milliseconds" or "blocks")' + ) + assert( + type(config.timeout.value) == 'number', + 'Invalid timeout.value: must be an integer' + ) + end + end + + -- generate resolver for the handler + config.handle = handlers.generateResolver(config.handle) + + -- handle timeout when it is a number (blocks) + if type(config.timeout) == 'number' then + config.timeout = { + type = 'blocks', + value = config.timeout + } + end + + -- if the handler already exists, find it and update + local idx = findIndexByProp(handlers.list, 'name', config.name) + + if idx ~= nil and idx > 0 then + -- found a handler to update + handlers[idx] = config + else + -- a handler with this name doesn't exist yet, so we add it + -- + -- calculate the position the handler should be added at + -- (by default it's the end of the list) + idx = #handlers.list + 1 + if config.position and config.position ~= 'append' then + if config.position == 'prepend' or config.position.type == 'prepend' then + idx = 1 + elseif type(config.position) == 'table' and config.position.type ~= 'append' then + idx = findIndexByProp(handlers.list, 'name', config.position.target) + + if config.position.type == 'after' and idx and idx > 0 then + idx = idx + 1 + end + + if not idx or idx <= 0 then + return #handlers.list + end + end + end + + -- add handler + table.insert(handlers.list, idx, config) + end + + return #handlers.list +end + +--- Removes a handler from the handlers list by name. +-- @function remove +-- @tparam {string} name The name of the handler to be removed function handlers.remove(name) assert(type(name) == 'string', 'name MUST be string') if #handlers.list == 1 and handlers.list[1].name == name then + if handlers.list[1].onRemove ~= nil then + handlers.list[1].onRemove("user-remove") + end handlers.list = {} - end local idx = findIndexByProp(handlers.list, "name", name) if idx ~= nil and idx > 0 then + if handlers.list[idx].onRemove ~= nil then + handlers.list[idx].onRemove("user-remove") + end table.remove(handlers.list, idx) end - end ---- return 0 to not call handler, -1 to break after handler is called, 1 to continue +--- Evaluates each handler against a given message and environment. Handlers are called in the order they appear in the handlers list. +-- Return 0 to not call handler, -1 to break after handler is called, 1 to continue +-- @function evaluate +-- @tparam {table} msg The message to be processed by the handlers. +-- @tparam {table} env The environment in which the handlers are executed. +-- @treturn The response from the handler(s). Returns a default message if no handler matches. function handlers.evaluate(msg, env) local handled = false assert(type(msg) == 'table', 'msg is not valid') assert(type(env) == 'table', 'env is not valid') - + for _, o in ipairs(handlers.list) do - if o.name ~= "_default" then + if o.name ~= "_default" and not o.inactive then local match = utils.matchesSpec(msg, o.pattern) if not (type(match) == 'number' or type(match) == 'string' or type(match) == 'boolean') then error("Pattern result is not valid, it MUST be string, number, or boolean") end - + + -- ensure the handler hasn't timed out yet + if o.timeout ~= nil then + -- remove handler if it timed out + if (o.timeout.type == 'milliseconds' and o.timeout.value < msg.Timestamp) or (o.timeout.type == 'blocks' and o.timeout.value < msg["Block-Height"]) then + if o.onRemove ~= nil then + o.onRemove("timeout") + o.onRemove = nil + end + handlers.remove(o.name) + match = 0 + end + end + -- handle boolean returns if type(match) == "boolean" and match == true then match = -1 @@ -269,6 +453,16 @@ function handlers.evaluate(msg, env) end if match ~= 0 then + -- the pattern matched, now we overwrite it with the + -- handler's "runType" configuration, if there's any + if o.runType ~= nil then + if type(o.runType) == 'number' then + match = o.runType + else + match = o.runType == 'continue' and 1 or -1 + end + end + if match < 0 then handled = true end @@ -276,12 +470,28 @@ function handlers.evaluate(msg, env) local status, err = pcall(o.handle, msg, env) if not status then if not o.errorHandler then error(err) - else pcall(o.errorHandler, msg, env, err) end + else + -- allow error handler to override the default + -- handler behavior (break/continue) + local errorHandlerRes = o.errorHandler(msg, env, err) + + if errorHandlerRes ~= nil then + if type(errorHandlerRes) ~= "number" then + match = errorHandlerRes == "break" and -1 or 1 + elseif errorHandlerRes == 1 or errorHandlerRes == -1 then + match = errorHandlerRes + end + end + end end -- remove handler if maxRuns is reached. maxRuns can be either a number or "inf" if o.maxRuns ~= nil and o.maxRuns ~= "inf" then o.maxRuns = o.maxRuns - 1 if o.maxRuns == 0 then + if o.onRemove ~= nil then + o.onRemove("expired") + o.onRemove = nil + end handlers.remove(o.name) end end @@ -296,4 +506,4 @@ function handlers.evaluate(msg, env) assert(handled, "The request could not be handled") end -return handlers \ No newline at end of file +return handlers diff --git a/src/utils/scheduler.lua b/src/utils/scheduler.lua index 8635807..0965523 100644 --- a/src/utils/scheduler.lua +++ b/src/utils/scheduler.lua @@ -17,32 +17,57 @@ function mod.schedule(...) -- if there are no messages to be sent, we don't do anything if #messages == 0 then return {} end - ---@type HandlerFunction - local function responseHandler(msg) - table.insert(responses, msg) - - -- continue execution when all responses are back - if #responses == #messages then - -- if the result of the resumed coroutine is an error, then we should bubble it up to the process - local _, success, errmsg = coroutine.resume(thread, responses) - - assert(success, errmsg) - end - end - -- send messages for _, msg in ipairs(messages) do ao.send(msg) -- wait for response - Handlers.once( - { From = msg.Target, ["X-Reference"] = tostring(ao.reference) }, - responseHandler - ) + Handlers.advanced({ + name = "_once_" .. tostring(Handlers.onceNonce), + position = "prepend", + pattern = { + From = msg.Target, + ["X-Reference"] = tostring(ao.reference) + }, + maxRuns = 1, + -- TODO: is this an optimal timeout? + timeout = Block + 1, + handle = function (_msg) + table.insert(responses, _msg) + + -- continue execution only when all responses are back + if #responses == #messages then + -- if the result of the resumed coroutine is an error, then we should bubble it up to the process + local _, success, errmsg = coroutine.resume(thread, responses, false) + + assert(success, errmsg) + end + end, + onRemove = function (reason) + -- do not continue if the handler wasn't removed because of a timeout + -- or if the coroutine has already been resumed + if reason ~= "timeout" or coroutine.status(thread) ~= "suspended" then return end + + -- resume execution on timeout, because a timeout + -- invalidates all results + local _, success, errmsg = coroutine.resume(thread, {}, true) + + assert(success, errmsg) + end + }) + Handlers.onceNonce = Handlers.onceNonce + 1 end -- yield execution, till all responses are back - return coroutine.yield({ From = messages[#messages], ["X-Reference"] = tostring(ao.reference) }) + local result, expired = coroutine.yield({ + From = messages[#messages], + ["X-Reference"] = tostring(ao.reference) + }) + + -- check if expired + assert(not expired, "A scheduled response has expired") + + return result end return mod diff --git a/yarn.lock b/yarn.lock index 467af30..fe78934 100644 --- a/yarn.lock +++ b/yarn.lock @@ -510,10 +510,10 @@ "@jridgewell/resolve-uri" "^3.1.0" "@jridgewell/sourcemap-codec" "^1.4.14" -"@permaweb/ao-loader@^0.0.42": - version "0.0.42" - resolved "https://registry.yarnpkg.com/@permaweb/ao-loader/-/ao-loader-0.0.42.tgz#e9baa81134be2ce84affaa4c7b95b0406493d0e2" - integrity sha512-xQiixn7jcb2x7+TtOPO8nARPlgVbFTTHQzSYgARW3rDGfBWw5mGb2Ax3bspdsiuKKaLxF+dxl6A8v2DprlLWjQ== +"@permaweb/ao-loader@^0.0.43": + version "0.0.43" + resolved "https://registry.yarnpkg.com/@permaweb/ao-loader/-/ao-loader-0.0.43.tgz#bf980be06ec397ad475a023550b5526bf53e2a9b" + integrity sha512-xPYzyKSCqtL0U8oUcCrW+uPpm7IcMncM5IPVGCGKljxA3IQA/HI8S5XA6tcZUaDRCl8VSVsJzqOgkdzy1JGi5w== dependencies: "@permaweb/wasm-metering" "^0.2.2"