diff --git a/tests/unit/client_connect_spec.lua b/tests/unit/client_connect_spec.lua new file mode 100644 index 0000000..430dd71 --- /dev/null +++ b/tests/unit/client_connect_spec.lua @@ -0,0 +1,62 @@ +local client = require("swank.client") +local transport_mod = require("swank.transport") + +describe("client.connect behaviour", function() + it("handles transport connect error and notifies user", function() + -- Ensure config provides defaults + local swank = require("swank") + swank.config = swank.config or {} + swank.config.server = { host = "127.0.0.1", port = 4005 } + + -- Mock transport.Transport.new to return a transport whose connect returns error + local orig_transport = transport_mod.Transport + transport_mod.Transport = { + new = function(on_message, on_disconnect) + return { + connect = function(self, host, port, cb) + -- support both colon and dot call styles; call cb with error + cb("econnrefused") + end, + disconnect = function() end, + } + end, + } + + local notified = false + local orig_notify = vim.notify + vim.notify = function(msg, _lvl) if msg:find("connection failed") then notified = true end end + + -- Call client.connect and assert it handles failure + client.connect(nil, nil) + + -- restore + vim.notify = orig_notify + transport_mod.Transport = orig_transport + end) + + it("successful connect sets connected state", function() + local swank = require("swank") + swank.config = swank.config or {} + swank.config.server = { host = "127.0.0.1", port = 4005 } + + local orig_transport = transport_mod.Transport + transport_mod.Transport = { + new = function(on_message, on_disconnect) + return { + connect = function(self, host, port, cb) + cb(nil) + end, + send = function() end, + disconnect = function() end, + } + end, + } + + client._test_reset() + client.connect(nil, nil) + assert.is_true(client.is_connected()) + + transport_mod.Transport = orig_transport + client._test_reset() + end) +end) diff --git a/tests/unit/transport_connect_spec.lua b/tests/unit/transport_connect_spec.lua new file mode 100644 index 0000000..b572995 --- /dev/null +++ b/tests/unit/transport_connect_spec.lua @@ -0,0 +1,65 @@ +local transport_mod = require("swank.transport") + +describe("transport connect behavior", function() + it("connect success: invokes on_message via read_start callback", function() + local received = {} + local disconnected = false + + -- Mock uv.new_tcp to return a controllable handle + local orig_uv = vim.uv + local handle = {} + function handle:connect(host, port, cb) + -- simulate async connect success + cb(nil) + end + function handle:read_start(cb) + -- store the read callback; we'll simulate the incoming frame after connect + self._read_cb = cb + end + function handle:close() end + vim.uv = { new_tcp = function() return handle end } + + local t = transport_mod.Transport.new(function(msg) table.insert(received, msg) end, + function() disconnected = true end) + + local connected_err = nil + t:connect("127.0.0.1", 4005, function(err) connected_err = err end) + assert.is_nil(connected_err) + + -- simulate an incoming framed message (bypass scheduling) + local body = "(hello)" + local frame = string.format("%06x", #body) .. body + t:_feed(frame) + + assert.equals(1, #received) + assert.equals("(hello)", received[1]) + + vim.uv = orig_uv + end) + + it("connect error: closes handle and returns error", function() + local notified = false + local orig_notify = vim.notify + vim.notify = function(msg, _level) if msg:find("connection failed") then notified = true end end + + local orig_uv = vim.uv + local closed = false + local handle = {} + function handle:connect(host, port, cb) + cb("econnrefused") + end + function handle:close() closed = true end + function handle:read_start(cb) end + vim.uv = { new_tcp = function() return handle end } + + local t = transport_mod.Transport.new(function() end, function() end) + local got_err = nil + t:connect("127.0.0.1", 4005, function(err) got_err = err end) + + -- on error, transport should not have a live handle + assert.is_nil(t.handle) + + vim.notify = orig_notify + vim.uv = orig_uv + end) +end)