Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 92 additions & 18 deletions lua/opencode/events.lua
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,94 @@ local M = {}
---
---@field permissions? opencode.events.permissions.Opts

local heartbeat_timer = vim.uv.new_timer()
---How often `opencode` sends heartbeat events.
local OPENCODE_HEARTBEAT_INTERVAL_MS = 30000
---@type number?
local subscription_job_id = nil

---@class opencode.events.State
---@field heartbeat_timer uv_timer_t
---@field subscription_job_id? number
---@field connected_server? opencode.server.Server

---@type table<integer, opencode.events.State>
local tab_states = {}

---@param tab? integer
---@return opencode.events.State, integer
local function get_state(tab)
tab = tab or vim.api.nvim_get_current_tabpage()
if not tab_states[tab] then
tab_states[tab] = {
heartbeat_timer = vim.uv.new_timer(),
subscription_job_id = nil,
connected_server = nil,
}
end
return tab_states[tab], tab
end

local function refresh_compat_connected_server()
local state = tab_states[vim.api.nvim_get_current_tabpage()]
M.connected_server = state and state.connected_server or nil
end

local function disconnect_state(state)
if state.subscription_job_id then
vim.fn.jobstop(state.subscription_job_id)
end
if state.heartbeat_timer then
state.heartbeat_timer:stop()
end

state.subscription_job_id = nil
state.connected_server = nil
end

local function prune_invalid_tab_states()
for tab, state in pairs(tab_states) do
if not vim.api.nvim_tabpage_is_valid(tab) then
disconnect_state(state)
if state.heartbeat_timer and not state.heartbeat_timer:is_closing() then
state.heartbeat_timer:close()
end
tab_states[tab] = nil
end
end
end

---The currently-connected `opencode` server, if any.
---Executes autocmds for received SSEs with type `OpencodeEvent:<event.type>`, passing the event and server port as data.
---Cleared when the server disposes itself, the connection errors, the heartbeat disappears, or we connect to a new server.
---@type opencode.server.Server?
M.connected_server = nil

function M.get_connected_server(tab)
prune_invalid_tab_states()
local state = tab_states[tab or vim.api.nvim_get_current_tabpage()]
return state and state.connected_server or nil
end

---@param server opencode.server.Server
function M.connect(server)
M.disconnect()
---@param tab? integer
function M.connect(server, tab)
local state
state, tab = get_state(tab)
M.disconnect(tab)

require("opencode.promise")
.resolve(server)
:next(function(_server) ---@param _server opencode.server.Server
subscription_job_id = _server:sse_subscribe(function(response) ---@param response opencode.server.Event
M.connected_server = _server
state.subscription_job_id = _server:sse_subscribe(function(response) ---@param response opencode.server.Event
state.connected_server = _server
refresh_compat_connected_server()

if heartbeat_timer then
heartbeat_timer:start(OPENCODE_HEARTBEAT_INTERVAL_MS + 5000, 0, vim.schedule_wrap(M.disconnect))
if state.heartbeat_timer then
state.heartbeat_timer:start(
OPENCODE_HEARTBEAT_INTERVAL_MS + 5000,
0,
vim.schedule_wrap(function()
M.disconnect(tab)
end)
)
end

if require("opencode.config").opts.events.enabled then
Expand All @@ -44,16 +108,17 @@ function M.connect(server)
event = response,
-- Can't pass metatable through here, so listeners need to reconstruct the server object if they want to use its methods
port = _server.port,
tab = tab,
},
})
end
end, function()
-- This is also called when the connection is closed normally by `vim.fn.jobstop`.
-- i.e. when disconnecting before connecting to a new server.
-- In that case, don't re-execute disconnect - it'd disconnect from the new server.
if M.connected_server == _server then
if state.connected_server == _server then
-- Server disappeared ungracefully, e.g. process killed, network error, etc.
M.disconnect()
M.disconnect(tab)
end
end)
end)
Expand All @@ -62,15 +127,24 @@ function M.connect(server)
end)
end

function M.disconnect()
if subscription_job_id then
vim.fn.jobstop(subscription_job_id)
end
if heartbeat_timer then
heartbeat_timer:stop()
---@param tab? integer
function M.disconnect(tab)
prune_invalid_tab_states()
local state = tab_states[tab or vim.api.nvim_get_current_tabpage()]
if not state then
refresh_compat_connected_server()
return
end

M.connected_server = nil
disconnect_state(state)
refresh_compat_connected_server()
end

vim.api.nvim_create_autocmd("TabEnter", {
callback = function()
prune_invalid_tab_states()
refresh_compat_connected_server()
end,
})

return M
26 changes: 16 additions & 10 deletions lua/opencode/server/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -344,10 +344,14 @@ end
---2. The configured port in `require("opencode.config").opts.port`.
---3. All servers, prioritizing one sharing CWD with Neovim, and prompting the user to select if multiple are found.
---@return Promise<opencode.server.Server>
local function find()
---@param tab? integer
local function find(tab)
local Promise = require("opencode.promise")
local port_opt = require("opencode.config").opts.server.port
local connected_server = require("opencode.events").connected_server
local events = require("opencode.events")
local connected_server = events.get_connected_server(tab)
local tabnr = tab and vim.api.nvim_tabpage_get_number(tab) or nil
local nvim_cwd = tabnr and vim.fn.getcwd(-1, tabnr) or vim.fn.getcwd()

return connected_server and Promise.resolve(connected_server)
or type(port_opt) == "number" and Server.new(port_opt)
Expand All @@ -364,7 +368,6 @@ local function find()
return Server.new(port)
end)
or Server.get_all():next(function(servers) ---@param servers opencode.server.Server[]
local nvim_cwd = vim.fn.getcwd()
local servers_in_cwd = vim.tbl_filter(function(server)
-- Overlaps in either direction, with no non-empty mismatch
return server.cwd:find(nvim_cwd, 0, true) == 1 or nvim_cwd:find(server.cwd, 0, true) == 1
Expand All @@ -375,14 +378,15 @@ local function find()
return servers_in_cwd[1]
else
-- Can't guess which one the user wants based on CWD - select from *all*
return require("opencode.ui.select_server").select_server(servers)
return require("opencode.ui.select_server").select_server(servers, { cwd = nvim_cwd })
end
end)
end

---Poll for an `opencode` server, rejecting if not found within five seconds.
---@return Promise<opencode.server.Server>
local function poll()
---@param tab? integer
local function poll(tab)
local Promise = require("opencode.promise")
local poll_timer, timer_err, timer_errname = vim.uv.new_timer()
if not poll_timer then
Expand All @@ -395,7 +399,7 @@ local function poll()
1000,
1000,
vim.schedule_wrap(function()
find()
find(tab)
:next(function(server)
resolve(server)
end)
Expand All @@ -418,9 +422,11 @@ end
---@return Promise<opencode.server.Server>
function Server.get()
local Promise = require("opencode.promise")
local connected_server = require("opencode.events").connected_server
local events = require("opencode.events")
local tab = vim.api.nvim_get_current_tabpage()
local connected_server = events.get_connected_server(tab)

return find()
return find(tab)
:catch(function(err)
if not err then
-- Do nothing when server selection was cancelled
Expand All @@ -434,11 +440,11 @@ function Server.get()
return Promise.reject(err)
end

return poll()
return poll(tab)
end)
:next(function(server) ---@param server opencode.server.Server
if not connected_server or connected_server.port ~= server.port then
require("opencode.events").connect(server)
events.connect(server, tab)
end
return server
end)
Expand Down
2 changes: 1 addition & 1 deletion lua/opencode/status.lua
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ end

---@return string
function M.statusline()
local connected_server = require("opencode.events").connected_server
local connected_server = require("opencode.events").get_connected_server()
local port = connected_server and connected_server.port
return M.statusline_icon() .. (port and (" :" .. tostring(port)) or "")
end
Expand Down
Loading
Loading