Skip to content

Commit 8f7455d

Browse files
committed
Add unified diff support for luatest.assert_equals
This patch adds unified diff output for `t.assert_equals()` failures, using a vendored Lua implementation of google/diff-match-patch (`luatest/vendor/diff_match_patch.lua` taken from [^1]). Closes #412 [^1]: https://github.com/google/diff-match-patch/blob/master/lua/diff_match_patch.lua
1 parent a0930d4 commit 8f7455d

File tree

6 files changed

+2745
-12
lines changed

6 files changed

+2745
-12
lines changed

.luacheckrc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
include_files = {"**/*.lua", "*.rockspec", "*.luacheckrc"}
2-
exclude_files = {"build.luarocks/", "lua_modules/", "tmp/", ".luarocks/", ".rocks/"}
2+
exclude_files = {"build.luarocks/", "lua_modules/", "tmp/", ".luarocks/", ".rocks/", "luatest/vendor/"}
33

44
max_line_length = 120

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
## Unreleased
44

5+
- Added support for unified diff output in `t.assert_equals()` failure messages
6+
when expected and actual values are YAML-serializable (gh-412).
57
- Fixed a bug when the JUnit reporter generated invalid XML for parameterized
68
tests with string arguments (gh-407).
79
- Group and suite hooks must now be registered using the call-style

luatest/assertions.lua

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
local math = require('math')
77

88
local comparator = require('luatest.comparator')
9+
local diff = require('luatest.diff')
910
local mismatch_formatter = require('luatest.mismatch_formatter')
1011
local pp = require('luatest.pp')
1112
local log = require('luatest.log')
@@ -83,6 +84,12 @@ local function error_msg_equality(actual, expected, deep_analysis)
8384
if success then
8485
result = table.concat({result, mismatchResult}, '\n')
8586
end
87+
88+
local diff_result = diff.build_unified_diff(expected, actual)
89+
if diff_result then
90+
result = table.concat({result, 'diff:', diff_result}, '\n')
91+
end
92+
8693
return result
8794
end
8895
return string.format("expected: %s, actual: %s",

luatest/diff.lua

Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
local yaml = require('yaml')
2+
local uri = require('uri')
3+
4+
-- diff_match_patch expects bit32
5+
if not rawget(_G, 'bit32') then
6+
_G.bit32 = require('bit')
7+
end
8+
9+
local diff_match_patch = require('luatest.vendor.diff_match_patch')
10+
11+
diff_match_patch.settings({
12+
Diff_Timeout = 0,
13+
Patch_Margin = 1e9,
14+
})
15+
16+
local M = {}
17+
18+
-- Maximum number of distinct line IDs that can be encoded as single-byte chars.
19+
local MAX_LINE_ID = 0x100
20+
21+
local function encode_line_id(id)
22+
if id >= MAX_LINE_ID then
23+
return nil
24+
end
25+
26+
return string.char(id)
27+
end
28+
29+
local function decode_line_id(encoded)
30+
return encoded:byte(1)
31+
end
32+
33+
-- Recursively normalize a value into something that:
34+
-- * is safe and stable for YAML encoding;
35+
-- * produces meaningful diffs for values that provide informative tostring();
36+
-- * does NOT produce noisy diffs for opaque userdata/cdata (newproxy, ffi types, etc).
37+
local function normalize_for_yaml(value)
38+
local t = type(value)
39+
40+
if t == 'table' then
41+
local res = {}
42+
for k, v in pairs(value) do
43+
local nk = normalize_for_yaml(k)
44+
if nk == nil then
45+
-- YAML keys must be representable; fallback to tostring.
46+
nk = tostring(k)
47+
end
48+
res[nk] = normalize_for_yaml(v)
49+
end
50+
return res
51+
end
52+
53+
if t == 'cdata' or t == 'userdata' then
54+
local ok, s = pcall(tostring, value)
55+
if ok and type(s) == 'string' then
56+
return s
57+
end
58+
59+
return '<unknown cdata/userdata>'
60+
end
61+
62+
if t == 'function' or t == 'thread' then
63+
return '<' .. t .. '>'
64+
end
65+
66+
-- other primitive types.
67+
return value
68+
end
69+
70+
-- Encode a Lua value as YAML after normalizing it to a diff-friendly form.
71+
local function encode_yaml(value)
72+
local ok, encoded = pcall(yaml.encode, normalize_for_yaml(value))
73+
if ok then
74+
return encoded
75+
end
76+
end
77+
78+
-- Convert a supported Lua value into a textual form suitable for diffing.
79+
--
80+
-- * Tables are serialized to YAML with recursive normalization.
81+
-- * Strings are used as-is.
82+
-- * Numbers / booleans are converted via tostring().
83+
-- * Top-level opaque userdata/cdata disable diffing when tostring() fails (return nil).
84+
local function as_yaml(value)
85+
local t = type(value)
86+
87+
if t == 'cdata' or t == 'userdata' then
88+
local ok, s = pcall(tostring, value)
89+
if ok and type(s) == 'string' then
90+
return s
91+
end
92+
93+
return nil
94+
end
95+
96+
if t == 'string' then
97+
return value
98+
end
99+
100+
local encoded = encode_yaml(value)
101+
if encoded ~= nil then
102+
return encoded
103+
end
104+
105+
local ok, s = pcall(tostring, value)
106+
if ok and type(s) == 'string' then
107+
return s
108+
end
109+
end
110+
111+
-- Map two multiline texts to compact "char sequences" and shared line table.
112+
-- Returns nil if the number of unique lines exceeds MAX_LINE_ID.
113+
local function lines_to_chars(text1, text2)
114+
local line_array = {}
115+
local line_hash = {}
116+
117+
local function add_line(line)
118+
local id = line_hash[line]
119+
if id == nil then
120+
id = #line_array + 1
121+
local encoded = encode_line_id(id)
122+
if encoded == nil then
123+
return nil
124+
end
125+
line_array[id] = line
126+
line_hash[line] = id
127+
end
128+
129+
return encode_line_id(id)
130+
end
131+
132+
local function munge(text)
133+
local tokens = {}
134+
local start = 1
135+
136+
while true do
137+
local newline_pos = text:find('\n', start, true)
138+
if newline_pos == nil then
139+
local tail = text:sub(start)
140+
if tail ~= '' then
141+
local token = add_line(tail)
142+
if token == nil then
143+
return nil
144+
end
145+
table.insert(tokens, token)
146+
end
147+
break
148+
end
149+
150+
local token = add_line(text:sub(start, newline_pos))
151+
if token == nil then
152+
return nil
153+
end
154+
table.insert(tokens, token)
155+
start = newline_pos + 1
156+
end
157+
158+
return table.concat(tokens)
159+
end
160+
161+
local chars1 = munge(text1)
162+
if chars1 == nil then
163+
return nil
164+
end
165+
166+
local chars2 = munge(text2)
167+
if chars2 == nil then
168+
return nil
169+
end
170+
171+
return chars1, chars2, line_array
172+
end
173+
174+
-- Expand a "char sequence" produced by lines_to_chars back into full text.
175+
local function chars_to_lines(text, line_array)
176+
local out = {}
177+
178+
for i = 1, #text do
179+
local id = decode_line_id(text:sub(i, i))
180+
local line = line_array[id]
181+
if line == nil then
182+
return nil
183+
end
184+
table.insert(out, line)
185+
end
186+
187+
return table.concat(out)
188+
end
189+
190+
-- Compute line-based diff using diff_match_patch, falling back to nil on failure.
191+
local function diff_by_lines(text1, text2)
192+
local chars1, chars2, line_array = lines_to_chars(text1, text2)
193+
if chars1 == nil then
194+
return nil
195+
end
196+
197+
local diffs = diff_match_patch.diff_main(chars1, chars2, false)
198+
diff_match_patch.diff_cleanupSemantic(diffs)
199+
200+
for i, diff in ipairs(diffs) do
201+
local text = chars_to_lines(diff[2], line_array)
202+
if text == nil then
203+
return nil
204+
end
205+
diffs[i][2] = text
206+
end
207+
208+
return diffs
209+
end
210+
211+
-- Normalize patch text from diff_match_patch: unescape it, drop junk lines,
212+
-- and ensure it is valid, readable unified diff.
213+
local function prettify_patch(patch_text)
214+
-- patch_toText() escapes non-ascii symbols using URL escaping. Convert it
215+
-- back to preserve the original values in unified diff output.
216+
patch_text = uri.unescape(patch_text)
217+
218+
local out = {}
219+
220+
for line in (patch_text .. '\n'):gmatch('(.-)\n') do
221+
if line ~= '' and line ~= ' ' then
222+
local first = line:sub(1, 1)
223+
224+
if first ~= '@' and first ~= '+'
225+
and first ~= '-' and first ~= ' ' then
226+
line = ' ' .. line
227+
end
228+
229+
table.insert(out, line)
230+
end
231+
end
232+
233+
return table.concat(out, '\n')
234+
end
235+
236+
--- Build unified diff for expected and actual values serialized to YAML.
237+
-- Tries line-based diff first, falls back to char-based.
238+
-- Returns nil when values can't be serialized or there is no diff.
239+
function M.build_unified_diff(expected, actual)
240+
local expected_text = as_yaml(expected)
241+
local actual_text = as_yaml(actual)
242+
243+
if expected_text == nil or actual_text == nil then
244+
return nil
245+
end
246+
247+
local diffs = diff_by_lines(expected_text, actual_text)
248+
local used_line_diff = true
249+
250+
if diffs == nil then
251+
diffs = diff_match_patch.diff_main(expected_text, actual_text)
252+
used_line_diff = false
253+
end
254+
255+
if not used_line_diff then
256+
diff_match_patch.diff_cleanupSemantic(diffs)
257+
end
258+
259+
local patches = diff_match_patch.patch_make(expected_text,
260+
actual_text, diffs)
261+
local patch_text = diff_match_patch.patch_toText(patches)
262+
263+
if patch_text == '' then
264+
return nil
265+
end
266+
267+
return prettify_patch(patch_text)
268+
end
269+
270+
return M

0 commit comments

Comments
 (0)