Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ poetry.lock
# virtual environments
venv
.venv
uv.lock

# cache folders
.pytest_cache
.benchmarks
/docs/site/
/site/

139 changes: 129 additions & 10 deletions src/jiwer/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
#

"""
Utility method to visualize the alignment between one or more reference and hypothesis
pairs.
Utility method to visualize the alignment and errors between one or more reference
and hypothesis pairs.
"""

from collections import defaultdict
from typing import List, Union, Optional

from jiwer.process import CharacterOutput, WordOutput, AlignmentChunk

__all__ = ["visualize_alignment"]
__all__ = ["visualize_alignment", "collect_error_counts", "visualize_error_counts"]


def visualize_alignment(
Expand Down Expand Up @@ -65,16 +65,19 @@ def visualize_alignment(
```
will produce this visualization:
```txt
sentence 1
=== SENTENCE 1 ===

REF: # short one here
HYP: shoe order one *
I S D

sentence 2
=== sentence 2 ===

REF: quite a bit of # # longer sentence #
HYP: quite * bit of an even longest sentence here
D I I S I

=== SUMMARY ===
number of sentences: 2
substitutions=2 deletions=2 insertions=4 hits=5

Expand All @@ -87,12 +90,13 @@ def visualize_alignment(
When `show_measures=False`, only the alignment will be printed:

```txt
sentence 1
=== SENTENCE 1 ===

REF: # short one here
HYP: shoe order one *
I S D

sentence 2
=== SENTENCE 2 ===
REF: quite a bit of # # longer sentence #
HYP: quite * bit of an even longest sentence here
D I I S I
Expand All @@ -101,7 +105,7 @@ def visualize_alignment(
When setting `line_width=80`, the following output will be split into multiple lines:

```txt
sentence 1
=== SENTENCE 1 ===
REF: This is a very long sentence that is *** much longer than the previous one
HYP: This is a very loong sentence that is not much longer than the previous one
S I
Expand All @@ -122,13 +126,14 @@ def visualize_alignment(
):
continue

final_str += f"sentence {idx+1}\n"
final_str += f"=== SENTENCE {idx+1} ===\n\n"
final_str += _construct_comparison_string(
gt, hp, chunks, include_space_seperator=not is_cer, line_width=line_width
)
final_str += "\n"

if show_measures:
final_str += "=== SUMMARY ===\n"
final_str += f"number of sentences: {len(alignment)}\n"
final_str += f"substitutions={output.substitutions} "
final_str += f"deletions={output.deletions} "
Expand Down Expand Up @@ -213,3 +218,117 @@ def _construct_comparison_string(
return agg_str + f"{ref_str[:-1]}\n{hyp_str[:-1]}\n{op_str[:-1]}\n"
else:
return agg_str + f"{ref_str}\n{hyp_str}\n{op_str}\n"


def collect_error_counts(output: Union[WordOutput, CharacterOutput]):
"""
Retrieve three dictionaries, which count the frequency of how often
each word or character was substituted, inserted, or deleted.
The substitution dictionary has, as keys, a 2-tuple (from, to).
The other two dictionaries have the inserted/deleted words or characters as keys.

Args:
output: The processed output of reference and hypothesis pair(s).

Returns:
A three-tuple of dictionaries, in the order substitutions, insertions, deletions.
"""
substitutions = defaultdict(lambda: 0)
insertions = defaultdict(lambda: 0)
deletions = defaultdict(lambda: 0)

for idx, sentence_chunks in enumerate(output.alignments):
ref = output.references[idx]
hyp = output.hypotheses[idx]
sep = " " if isinstance(output, WordOutput) else ""

for chunk in sentence_chunks:
if chunk.type == "insert":
inserted = sep.join(hyp[chunk.hyp_start_idx : chunk.hyp_end_idx])
insertions[inserted] += 1
if chunk.type == "delete":
deleted = sep.join(ref[chunk.ref_start_idx : chunk.ref_end_idx])
deletions[deleted] += 1
if chunk.type == "substitute":
replaced = sep.join(ref[chunk.ref_start_idx : chunk.ref_end_idx])
by = sep.join(hyp[chunk.hyp_start_idx : chunk.hyp_end_idx])
substitutions[(replaced, by)] += 1

return substitutions, insertions, deletions


def visualize_error_counts(
output: Union[WordOutput, CharacterOutput],
show_substitutions: bool = True,
show_insertions: bool = True,
show_deletions: bool = True,
top_k: Optional[int] = None,
):
"""
Visualize which words (or characters), and how often, were substituted, inserted, or deleted.

Args:
output:
show_substitutions: If true, visualize substitution errors.
show_insertions: If true, visualize insertion errors.
show_deletions: If true, visualize deletion errors.
top_k: If set, only visualize the k most frequent errors.

Returns: A string which visualizes the words/characters and their frequencies.

"""
s, i, d = collect_error_counts(output)

def build_list(errors: dict):
if len(errors) == 0:
return "none"

keys = [k for k in errors.keys()]
keys = sorted(keys, reverse=True, key=lambda k: errors[k])

if top_k is not None:
keys = keys[:top_k]

# we get the maximum length of all words to nicely pad output
ln = max(len(k) if isinstance(k, str) else max(len(e) for e in k) for k in keys)

# here we construct the string
build = ""

for count, (k, v) in enumerate(
sorted(errors.items(), key=lambda tpl: tpl[1], reverse=True)
):
if top_k is not None and count >= top_k:
break

if isinstance(k, tuple):
build += f"{k[0]: <{ln}} --> {k[1]:<{ln}} = {v}x\n"
else:
build += f"{k:<{ln}} = {v}x\n"

return build

output = ""

if show_substitutions:
if output != "":
output += "\n"
output += "=== SUBSTITUTIONS ===\n"
output += build_list(s)

if show_insertions:
if output != "":
output += "\n"
output += "=== INSERTIONS ===\n"
output += build_list(i)

if show_deletions:
if output != "":
output += "\n"
output += "=== DELETIONS ===\n"
output += build_list(d)

if output[-1:] == "\n":
output = output[:-1]

return output
Loading