Skip to content

Commit 78968af

Browse files
author
Yibing Liu
authored
Merge pull request #47 from pkuyym/fix-46
Expose edit distance for error_rate.py
2 parents fe1501c + 0f9b3eb commit 78968af

File tree

1 file changed

+65
-23
lines changed

1 file changed

+65
-23
lines changed

utils/error_rate.py

Lines changed: 65 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,62 @@ def _levenshtein_distance(ref, hyp):
5656
return distance[m % 2][n]
5757

5858

59+
def word_errors(reference, hypothesis, ignore_case=False, delimiter=' '):
60+
"""Compute the levenshtein distance between reference sequence and
61+
hypothesis sequence in word-level.
62+
63+
:param reference: The reference sentence.
64+
:type reference: basestring
65+
:param hypothesis: The hypothesis sentence.
66+
:type hypothesis: basestring
67+
:param ignore_case: Whether case-sensitive or not.
68+
:type ignore_case: bool
69+
:param delimiter: Delimiter of input sentences.
70+
:type delimiter: char
71+
:return: Levenshtein distance and word number of reference sentence.
72+
:rtype: list
73+
"""
74+
if ignore_case == True:
75+
reference = reference.lower()
76+
hypothesis = hypothesis.lower()
77+
78+
ref_words = filter(None, reference.split(delimiter))
79+
hyp_words = filter(None, hypothesis.split(delimiter))
80+
81+
edit_distance = _levenshtein_distance(ref_words, hyp_words)
82+
return float(edit_distance), len(ref_words)
83+
84+
85+
def char_errors(reference, hypothesis, ignore_case=False, remove_space=False):
86+
"""Compute the levenshtein distance between reference sequence and
87+
hypothesis sequence in char-level.
88+
89+
:param reference: The reference sentence.
90+
:type reference: basestring
91+
:param hypothesis: The hypothesis sentence.
92+
:type hypothesis: basestring
93+
:param ignore_case: Whether case-sensitive or not.
94+
:type ignore_case: bool
95+
:param remove_space: Whether remove internal space characters
96+
:type remove_space: bool
97+
:return: Levenshtein distance and length of reference sentence.
98+
:rtype: list
99+
"""
100+
if ignore_case == True:
101+
reference = reference.lower()
102+
hypothesis = hypothesis.lower()
103+
104+
join_char = ' '
105+
if remove_space == True:
106+
join_char = ''
107+
108+
reference = join_char.join(filter(None, reference.split(' ')))
109+
hypothesis = join_char.join(filter(None, hypothesis.split(' ')))
110+
111+
edit_distance = _levenshtein_distance(reference, hypothesis)
112+
return float(edit_distance), len(reference)
113+
114+
59115
def wer(reference, hypothesis, ignore_case=False, delimiter=' '):
60116
"""Calculate word error rate (WER). WER compares reference text and
61117
hypothesis text in word-level. WER is defined as:
@@ -85,20 +141,15 @@ def wer(reference, hypothesis, ignore_case=False, delimiter=' '):
85141
:type delimiter: char
86142
:return: Word error rate.
87143
:rtype: float
88-
:raises ValueError: If the reference length is zero.
144+
:raises ValueError: If word number of reference is zero.
89145
"""
90-
if ignore_case == True:
91-
reference = reference.lower()
92-
hypothesis = hypothesis.lower()
146+
edit_distance, ref_len = word_errors(reference, hypothesis, ignore_case,
147+
delimiter)
93148

94-
ref_words = filter(None, reference.split(delimiter))
95-
hyp_words = filter(None, hypothesis.split(delimiter))
96-
97-
if len(ref_words) == 0:
149+
if ref_len == 0:
98150
raise ValueError("Reference's word number should be greater than 0.")
99151

100-
edit_distance = _levenshtein_distance(ref_words, hyp_words)
101-
wer = float(edit_distance) / len(ref_words)
152+
wer = float(edit_distance) / ref_len
102153
return wer
103154

104155

@@ -135,20 +186,11 @@ def cer(reference, hypothesis, ignore_case=False, remove_space=False):
135186
:rtype: float
136187
:raises ValueError: If the reference length is zero.
137188
"""
138-
if ignore_case == True:
139-
reference = reference.lower()
140-
hypothesis = hypothesis.lower()
189+
edit_distance, ref_len = char_errors(reference, hypothesis, ignore_case,
190+
remove_space)
141191

142-
join_char = ' '
143-
if remove_space == True:
144-
join_char = ''
145-
146-
reference = join_char.join(filter(None, reference.split(' ')))
147-
hypothesis = join_char.join(filter(None, hypothesis.split(' ')))
148-
149-
if len(reference) == 0:
192+
if ref_len == 0:
150193
raise ValueError("Length of reference should be greater than 0.")
151194

152-
edit_distance = _levenshtein_distance(reference, hypothesis)
153-
cer = float(edit_distance) / len(reference)
195+
cer = float(edit_distance) / ref_len
154196
return cer

0 commit comments

Comments
 (0)