Skip to content

Commit ca5de57

Browse files
committed
use Levenshtein.normalized_distance instead of distance
1 parent 071e6a8 commit ca5de57

File tree

3 files changed

+9
-21
lines changed

3 files changed

+9
-21
lines changed

src/dinglehopper/character_error_rate.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,7 @@ def character_error_rate_n(
2020
:return: character error rate and length of the reference
2121
"""
2222

23-
d = distance(reference, compared)
24-
n = len(reference)
25-
26-
if d == 0:
27-
return 0, n
28-
if n == 0:
29-
return float("inf"), n
30-
return d / n, n
23+
return distance(reference, compared), len(reference)
3124

3225
# XXX Should we really count newlines here?
3326

src/dinglehopper/edit_distance.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,18 @@
99

1010

1111
@multimethod
12-
def distance(seq1: List[str], seq2: List[str]) -> int:
12+
def distance(seq1: List[str], seq2: List[str]) -> float:
1313
"""Compute the Levenshtein edit distance between two lists of grapheme clusters.
1414
1515
This assumes that the grapheme clusters are already normalized.
1616
1717
Use distance(str, str) instead if you need to compare two Unicode strings.
1818
"""
19-
return Levenshtein.distance(seq1, seq2)
19+
return Levenshtein.normalized_distance(seq1, seq2)
2020

2121

2222
@distance.register
23-
def _(s1: str, s2: str) -> int:
23+
def _(s1: str, s2: str) -> float:
2424
"""Compute the Levenshtein edit distance between two Unicode strings
2525
2626
Note that this is different from levenshtein() as this function knows about Unicode
@@ -29,12 +29,12 @@ def _(s1: str, s2: str) -> int:
2929
"""
3030
seq1 = list(grapheme_clusters(unicodedata.normalize("NFC", s1)))
3131
seq2 = list(grapheme_clusters(unicodedata.normalize("NFC", s2)))
32-
return Levenshtein.distance(seq1, seq2)
32+
return Levenshtein.normalized_distance(seq1, seq2)
3333

3434

3535
@distance.register
36-
def _(s1: ExtractedText, s2: ExtractedText) -> int:
37-
return Levenshtein.distance(s1.grapheme_clusters, s2.grapheme_clusters)
36+
def _(s1: ExtractedText, s2: ExtractedText) -> float:
37+
return Levenshtein.normalized_distance(s1.grapheme_clusters, s2.grapheme_clusters)
3838

3939

4040
def editops(word1, word2):

src/dinglehopper/word_error_rate.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,10 @@ def _(reference: Iterable[T], compared: Iterable[T]) -> Tuple[float, int]:
9696
reference_seq = list(reference)
9797
compared_seq = list(compared)
9898

99-
d = Levenshtein.distance(reference_seq, compared_seq)
99+
d = Levenshtein.normalized_distance(reference_seq, compared_seq)
100100
n = len(reference_seq)
101101

102-
if d == 0:
103-
return 0, n
104-
if n == 0:
105-
return float("inf"), n
106-
return d / n, n
107-
102+
return d, n
108103

109104
def word_error_rate(reference: T, compared: T) -> float:
110105
wer: float

0 commit comments

Comments
 (0)