From 37235cd2f4db4d7ccd1feb14c6a9c68edc0e8d4a Mon Sep 17 00:00:00 2001 From: Eric Lam Date: Sun, 29 Jun 2025 22:22:35 +0800 Subject: [PATCH] Fix WER/CER condition and add empty input test --- tfkit/test/utility/test_utility_eval_metric.py | 18 ++++++++++++++++++ tfkit/utility/eval_metric.py | 4 ++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/tfkit/test/utility/test_utility_eval_metric.py b/tfkit/test/utility/test_utility_eval_metric.py index e1a229c..7d51f20 100644 --- a/tfkit/test/utility/test_utility_eval_metric.py +++ b/tfkit/test/utility/test_utility_eval_metric.py @@ -151,6 +151,24 @@ def test_tokenize_text(self): eval = tfkit.utility.eval_metric.EvalMetric(tokenizer, normalize_text=True) self.assertEqual(eval.tokenize_text("How's this work"), "how ' s this work") + def test_empty_er(self): + class DummyTokenizer: + special_tokens_map = {'sep_token': '[SEP]'} + + def encode(self, text, add_special_tokens=False): + return text.split() + + def decode(self, tokens, **kwargs): + return ' '.join(tokens) + + tokenizer = DummyTokenizer() + eval = tfkit.utility.eval_metric.EvalMetric(tokenizer) + eval.add_record("", "", "", task='default') + results = list(eval.cal_score('er')) + self.assertEqual(len(results), 1) + self.assertEqual(results[0][1]['WER'], 0) + self.assertEqual(results[0][1]['CER'], 0) + @pytest.mark.skip() def testNLGWithPAD(self): tokenizer = BertTokenizer.from_pretrained('voidful/albert_chinese_tiny') diff --git a/tfkit/utility/eval_metric.py b/tfkit/utility/eval_metric.py index 628b539..b45db67 100644 --- a/tfkit/utility/eval_metric.py +++ b/tfkit/utility/eval_metric.py @@ -195,8 +195,8 @@ def cal_score(self, metric): targets.append(target) data_score.append([predict, target, {'wer': wer, 'cer': cer}]) - wer = 100 * _wer(targets, predicts) if len(target) > 0 else 100 - cer = 100 * _cer(targets, predicts) if len(target) > 0 else 100 + wer = 100 * _wer(targets, predicts) if len(targets) > 0 else 100 + cer = 100 * _cer(targets, predicts) if len(targets) > 0 else 100 result = {"WER": wer, "CER": cer} data_score = sorted(data_score, key=lambda i: i[2]['wer'], reverse=False) if "nlg" in metric: