Skip to content

Commit d8c2af9

Browse files
committed
Add unified diff support for luatest.assert_equals
This patch adds unified diff output for `t.assert_equals()` failures, using a vendored Lua implementation of google/diff-match-patch (`luatest/vendor/diff_match_patch.lua` taken from [^1]). Closes #412 [^1]: https://github.com/google/diff-match-patch/blob/master/lua/diff_match_patch.lua
1 parent a0930d4 commit d8c2af9

File tree

6 files changed

+2782
-13
lines changed

6 files changed

+2782
-13
lines changed

.luacheckrc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
include_files = {"**/*.lua", "*.rockspec", "*.luacheckrc"}
2-
exclude_files = {"build.luarocks/", "lua_modules/", "tmp/", ".luarocks/", ".rocks/"}
2+
exclude_files = {"build.luarocks/", "lua_modules/", "tmp/", ".luarocks/", ".rocks/", "luatest/vendor/"}
33

44
max_line_length = 120

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
## Unreleased
44

5+
- Added support for unified diff output in `t.assert_equals()` failure messages
6+
when expected and actual values are YAML-serializable (gh-412).
57
- Fixed a bug when the JUnit reporter generated invalid XML for parameterized
68
tests with string arguments (gh-407).
79
- Group and suite hooks must now be registered using the call-style

luatest/assertions.lua

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
local math = require('math')
77

88
local comparator = require('luatest.comparator')
9+
local diff = require('luatest.diff')
910
local mismatch_formatter = require('luatest.mismatch_formatter')
1011
local pp = require('luatest.pp')
1112
local log = require('luatest.log')
@@ -83,6 +84,12 @@ local function error_msg_equality(actual, expected, deep_analysis)
8384
if success then
8485
result = table.concat({result, mismatchResult}, '\n')
8586
end
87+
88+
local diff_result = diff.build_unified_diff(expected, actual)
89+
if diff_result then
90+
result = table.concat({result, 'diff:', diff_result}, '\n')
91+
end
92+
8693
return result
8794
end
8895
return string.format("expected: %s, actual: %s",
@@ -470,7 +477,19 @@ end
470477
function M.assert_covers(actual, expected, message)
471478
if not table_covers(actual, expected) then
472479
local str_actual, str_expected = prettystr_pairs(actual, expected)
473-
failure(string.format('expected %s to cover %s', str_actual, str_expected), message, 2)
480+
local sliced_actual = table_slice(actual, expected)
481+
482+
local parts = {
483+
string.format('expected %s to cover %s', str_actual, str_expected),
484+
}
485+
486+
local diff_result = diff.build_unified_diff(expected, sliced_actual)
487+
if diff_result then
488+
table.insert(parts, 'diff:')
489+
table.insert(parts, diff_result)
490+
end
491+
492+
failure(table.concat(parts, '\n'), message, 2)
474493
end
475494
end
476495

luatest/diff.lua

Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
local yaml = require('yaml')
2+
local uri = require('uri')
3+
4+
-- diff_match_patch expects bit32
5+
if not rawget(_G, 'bit32') then
6+
_G.bit32 = require('bit')
7+
end
8+
9+
local diff_match_patch = require('luatest.vendor.diff_match_patch')
10+
11+
diff_match_patch.settings({
12+
Diff_Timeout = 0,
13+
Patch_Margin = 1e9,
14+
})
15+
16+
local M = {}
17+
18+
-- Maximum number of distinct line IDs that can be encoded as single-byte chars.
19+
local MAX_LINE_ID = 0x100
20+
21+
local function encode_line_id(id)
22+
if id >= MAX_LINE_ID then
23+
return nil
24+
end
25+
26+
return string.char(id)
27+
end
28+
29+
local function decode_line_id(encoded)
30+
return encoded:byte(1)
31+
end
32+
33+
-- Recursively normalize a value into something that:
34+
-- * is safe and stable for YAML encoding;
35+
-- * produces meaningful diffs for values that provide informative tostring();
36+
-- * does NOT produce noisy diffs for opaque userdata/cdata (newproxy, ffi types, etc).
37+
local function normalize_for_yaml(value)
38+
local t = type(value)
39+
40+
if t == 'table' then
41+
local entries = {}
42+
for k, v in pairs(value) do
43+
local nk = normalize_for_yaml(k)
44+
if nk == nil then
45+
-- YAML keys must be representable; fallback to tostring.
46+
nk = tostring(k)
47+
end
48+
table.insert(entries, {key = nk, value = v})
49+
end
50+
table.sort(entries, function(a, b)
51+
if type(a.key) == 'number' and type(b.key) == 'number' then
52+
return a.key < b.key
53+
end
54+
return tostring(a.key) < tostring(b.key)
55+
end)
56+
57+
local res = {}
58+
for _, entry in ipairs(entries) do
59+
res[entry.key] = normalize_for_yaml(entry.value)
60+
end
61+
return res
62+
end
63+
64+
if t == 'cdata' or t == 'userdata' then
65+
local ok, s = pcall(tostring, value)
66+
if ok and type(s) == 'string' then
67+
return s
68+
end
69+
70+
return '<unknown cdata/userdata>'
71+
end
72+
73+
if t == 'function' or t == 'thread' then
74+
return '<' .. t .. '>'
75+
end
76+
77+
-- other primitive types.
78+
return value
79+
end
80+
81+
-- Encode a Lua value as YAML after normalizing it to a diff-friendly form.
82+
local function encode_yaml(value)
83+
local ok, encoded = pcall(yaml.encode, normalize_for_yaml(value))
84+
if ok then
85+
return encoded
86+
end
87+
end
88+
89+
-- Convert a supported Lua value into a textual form suitable for diffing.
90+
--
91+
-- * Tables are serialized to YAML with recursive normalization.
92+
-- * Strings are used as-is.
93+
-- * Numbers / booleans are converted via tostring().
94+
-- * Top-level opaque userdata/cdata disable diffing when tostring() fails (return nil).
95+
local function as_yaml(value)
96+
local t = type(value)
97+
98+
if t == 'cdata' or t == 'userdata' then
99+
local ok, s = pcall(tostring, value)
100+
if ok and type(s) == 'string' then
101+
return s
102+
end
103+
104+
return nil
105+
end
106+
107+
if t == 'string' then
108+
return value
109+
end
110+
111+
local encoded = encode_yaml(value)
112+
if encoded ~= nil then
113+
return encoded
114+
end
115+
116+
local ok, s = pcall(tostring, value)
117+
if ok and type(s) == 'string' then
118+
return s
119+
end
120+
end
121+
122+
-- Map two multiline texts to compact "char sequences" and shared line table.
123+
-- Returns nil if the number of unique lines exceeds MAX_LINE_ID.
124+
local function lines_to_chars(text1, text2)
125+
local line_array = {}
126+
local line_hash = {}
127+
128+
local function add_line(line)
129+
local id = line_hash[line]
130+
if id == nil then
131+
id = #line_array + 1
132+
local encoded = encode_line_id(id)
133+
if encoded == nil then
134+
return nil
135+
end
136+
line_array[id] = line
137+
line_hash[line] = id
138+
end
139+
140+
return encode_line_id(id)
141+
end
142+
143+
local function munge(text)
144+
local tokens = {}
145+
local start = 1
146+
147+
while true do
148+
local newline_pos = text:find('\n', start, true)
149+
if newline_pos == nil then
150+
local tail = text:sub(start)
151+
if tail ~= '' then
152+
local token = add_line(tail)
153+
if token == nil then
154+
return nil
155+
end
156+
table.insert(tokens, token)
157+
end
158+
break
159+
end
160+
161+
local token = add_line(text:sub(start, newline_pos))
162+
if token == nil then
163+
return nil
164+
end
165+
table.insert(tokens, token)
166+
start = newline_pos + 1
167+
end
168+
169+
return table.concat(tokens)
170+
end
171+
172+
local chars1 = munge(text1)
173+
if chars1 == nil then
174+
return nil
175+
end
176+
177+
local chars2 = munge(text2)
178+
if chars2 == nil then
179+
return nil
180+
end
181+
182+
return chars1, chars2, line_array
183+
end
184+
185+
-- Expand a "char sequence" produced by lines_to_chars back into full text.
186+
local function chars_to_lines(text, line_array)
187+
local out = {}
188+
189+
for i = 1, #text do
190+
local id = decode_line_id(text:sub(i, i))
191+
local line = line_array[id]
192+
if line == nil then
193+
return nil
194+
end
195+
table.insert(out, line)
196+
end
197+
198+
return table.concat(out)
199+
end
200+
201+
-- Compute line-based diff using diff_match_patch, falling back to nil on failure.
202+
local function diff_by_lines(text1, text2)
203+
local chars1, chars2, line_array = lines_to_chars(text1, text2)
204+
if chars1 == nil then
205+
return nil
206+
end
207+
208+
local diffs = diff_match_patch.diff_main(chars1, chars2, false)
209+
diff_match_patch.diff_cleanupSemantic(diffs)
210+
211+
for i, diff in ipairs(diffs) do
212+
local text = chars_to_lines(diff[2], line_array)
213+
if text == nil then
214+
return nil
215+
end
216+
diffs[i][2] = text
217+
end
218+
219+
return diffs
220+
end
221+
222+
-- Normalize patch text from diff_match_patch: unescape it, drop junk lines,
223+
-- and ensure it is valid, readable unified diff.
224+
local function prettify_patch(patch_text)
225+
-- patch_toText() escapes non-ascii symbols using URL escaping. Convert it
226+
-- back to preserve the original values in unified diff output.
227+
patch_text = uri.unescape(patch_text)
228+
229+
local out = {}
230+
local last_sign = nil
231+
232+
for line in (patch_text .. '\n'):gmatch('(.-)\n') do
233+
if line ~= '' and line ~= ' ' then
234+
local first = line:sub(1, 1)
235+
236+
if first == '+' or first == '-' then
237+
last_sign = first
238+
elseif first == '@' or first == ' ' then
239+
last_sign = nil
240+
elseif last_sign ~= nil then
241+
line = last_sign .. line
242+
else
243+
line = ' ' .. line
244+
end
245+
246+
table.insert(out, line)
247+
end
248+
end
249+
250+
return table.concat(out, '\n')
251+
end
252+
253+
--- Build unified diff for expected and actual values serialized to YAML.
254+
-- Tries line-based diff first, falls back to char-based.
255+
-- Returns nil when values can't be serialized or there is no diff.
256+
function M.build_unified_diff(expected, actual)
257+
local expected_text = as_yaml(expected)
258+
local actual_text = as_yaml(actual)
259+
260+
if expected_text == nil or actual_text == nil then
261+
return nil
262+
end
263+
264+
local diffs = diff_by_lines(expected_text, actual_text)
265+
266+
if diffs == nil then
267+
return nil
268+
end
269+
270+
local patches = diff_match_patch.patch_make(expected_text,
271+
actual_text, diffs)
272+
local patch_text = diff_match_patch.patch_toText(patches)
273+
274+
if patch_text == '' then
275+
return nil
276+
end
277+
278+
return prettify_patch(patch_text)
279+
end
280+
281+
return M

0 commit comments

Comments
 (0)