# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This code is adapted from:
https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py
"""

import re
import math
import collections
from functools import lru_cache


def _get_ngrams(segment, max_order):
    """Extracts all n-grams up to a given maximum order from an input segment.

    Args:
        segment: text segment from which n-grams will be extracted.
        max_order: maximum length in tokens of the n-grams returned by this
            method.

    Returns:
        The Counter containing all n-grams up to max_order in segment
        with a count of how many times each n-gram occurred.
    """
    ngram_counts = collections.Counter()
    for order in range(1, max_order + 1):
        for i in range(0, len(segment) - order + 1):
            ngram = tuple(segment[i : i + order])
            ngram_counts[ngram] += 1
    return ngram_counts

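# Worked example (added for illustration, not part of the referenced script):
# every contiguous n-gram of the tokenized segment up to max_order is counted,
# e.g.
#
#     _get_ngrams(["the", "cat", "sat"], max_order=2)
#     -> Counter({("the",): 1, ("cat",): 1, ("sat",): 1,
#                 ("the", "cat"): 1, ("cat", "sat"): 1})

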
def compute_bleu(reference_corpus, translation_corpus, max_order=4, smooth=False):
    """Computes BLEU score of translated segments against one or more references.

    Args:
        reference_corpus: list of lists of references for each translation. Each
            reference should be tokenized into a list of tokens.
        translation_corpus: list of translations to score. Each translation
            should be tokenized into a list of tokens.
        max_order: Maximum n-gram order to use when computing BLEU score.
        smooth: Whether or not to apply Lin et al. 2004 smoothing.

    Returns:
        6-Tuple with the BLEU score, the n-gram precisions, the brevity penalty,
        the length ratio, the translation length and the reference length.
    """
    matches_by_order = [0] * max_order
    possible_matches_by_order = [0] * max_order
    reference_length = 0
    translation_length = 0
    for references, translation in zip(reference_corpus, translation_corpus):
        reference_length += min(len(r) for r in references)
        translation_length += len(translation)

        merged_ref_ngram_counts = collections.Counter()
        for reference in references:
            merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
        translation_ngram_counts = _get_ngrams(translation, max_order)
        overlap = translation_ngram_counts & merged_ref_ngram_counts
        for ngram in overlap:
            matches_by_order[len(ngram) - 1] += overlap[ngram]
        for order in range(1, max_order + 1):
            possible_matches = len(translation) - order + 1
            if possible_matches > 0:
                possible_matches_by_order[order - 1] += possible_matches

    precisions = [0] * max_order
    for i in range(0, max_order):
        if smooth:
            precisions[i] = (matches_by_order[i] + 1.0) / (
                possible_matches_by_order[i] + 1.0
            )
        else:
            if possible_matches_by_order[i] > 0:
                precisions[i] = (
                    float(matches_by_order[i]) / possible_matches_by_order[i]
                )
            else:
                precisions[i] = 0.0

    if min(precisions) > 0:
        p_log_sum = sum((1.0 / max_order) * math.log(p) for p in precisions)
        geo_mean = math.exp(p_log_sum)
    else:
        geo_mean = 0

    ratio = float(translation_length) / reference_length

    if ratio > 1.0:
        bp = 1.0
    else:
        bp = math.exp(1 - 1.0 / ratio)

    bleu = geo_mean * bp

    return (bleu, precisions, bp, ratio, translation_length, reference_length)

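# Reading the result (illustrative note, not from the referenced script): the
# BLEU score combines the geometric mean of the modified n-gram precisions
# p_1..p_N (N = max_order) with a brevity penalty BP:
#
#     BLEU = BP * exp(sum_{n=1}^{N} (1/N) * log(p_n))
#     BP   = 1 if c > r else exp(1 - r / c)
#
# where c is the total translation length and r the shortest-reference length,
# matching the `geo_mean`, `bp` and `ratio` values computed above.

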
class BaseTokenizer:
    """A base dummy tokenizer to derive from."""

    def signature(self):
        """
        Returns a signature for the tokenizer.
        :return: signature string
        """
        return "none"

    def __call__(self, line):
        """
        Tokenizes an input line with the tokenizer.
        :param line: a segment to tokenize
        :return: the tokenized line
        """
        return line


class TokenizerRegexp(BaseTokenizer):
    def signature(self):
        return "re"

    def __init__(self):
        self._re = [
            # language-dependent part (assuming Western languages)
            (re.compile(r"([\{-\~\[-\` -\&\(-\+\:-\@\/])"), r" \1 "),
            # tokenize period and comma unless preceded by a digit
            (re.compile(r"([^0-9])([\.,])"), r"\1 \2 "),
            # tokenize period and comma unless followed by a digit
            (re.compile(r"([\.,])([^0-9])"), r" \1 \2"),
            # tokenize dash when preceded by a digit
            (re.compile(r"([0-9])(-)"), r"\1 \2 "),
            # one space only between words
            # NOTE: Doing this in Python (below) is faster
            # (re.compile(r'\s+'), r' '),
        ]

    @lru_cache(maxsize=2**16)
    def __call__(self, line):
        """Common post-processing tokenizer for `13a` and `zh` tokenizers.

        :param line: a segment to tokenize
        :return: the tokenized line
        """
        for _re, repl in self._re:
            line = _re.sub(repl, line)

        # no leading or trailing spaces, single space within words
        # return ' '.join(line.split())
        # This line is changed with regard to the original tokenizer (seen
        # above) to return individual words
        return line.split()

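# Illustrative behaviour (example string made up for this note): the regexps
# split punctuation into separate tokens and the final split() collapses the
# extra whitespace they introduce, e.g.
#
#     TokenizerRegexp()("Hello, world!")
#     -> ["Hello", ",", "world", "!"]

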
class Tokenizer13a(BaseTokenizer):
    def signature(self):
        return "13a"

    def __init__(self):
        self._post_tokenizer = TokenizerRegexp()

    @lru_cache(maxsize=2**16)
    def __call__(self, line):
        """Tokenizes an input line using a relatively minimal tokenization
        that is however equivalent to mteval-v13a, used by WMT.

        :param line: a segment to tokenize
        :return: the tokenized line
        """

        # language-independent part:
        line = line.replace("<skipped>", "")
        line = line.replace("-\n", "")
        line = line.replace("\n", " ")

        if "&" in line:
            line = line.replace("&quot;", '"')
            line = line.replace("&amp;", "&")
            line = line.replace("&lt;", "<")
            line = line.replace("&gt;", ">")

        return self._post_tokenizer(f" {line} ")

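# Illustrative behaviour (example string made up for this note): SGML entities
# are unescaped before the regexp-based post-tokenizer is applied, e.g.
#
#     Tokenizer13a()("Hello &amp; goodbye.")
#     -> ["Hello", "&", "goodbye", "."]

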
def compute_blue_score(
    predictions, references, tokenizer=Tokenizer13a(), max_order=4, smooth=False
):
    # if only one reference is provided make sure we still use list of lists
    if isinstance(references[0], str):
        references = [[ref] for ref in references]

    references = [[tokenizer(r) for r in ref] for ref in references]
    predictions = [tokenizer(p) for p in predictions]
    score = compute_bleu(
        reference_corpus=references,
        translation_corpus=predictions,
        max_order=max_order,
        smooth=smooth,
    )
    (bleu, precisions, bp, ratio, translation_length, reference_length) = score
    return bleu

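# Usage sketch (sentences made up for this note): predictions and references
# are plain strings; a flat list of single references is wrapped into a list
# of lists internally, e.g.
#
#     compute_blue_score(["the cat sat on the mat"],
#                        ["the cat sat on the mat"])
#     -> 1.0 (exact match, brevity penalty of 1.0)

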
def cal_distance(word1, word2):
    """Computes the Levenshtein (edit) distance between two token sequences."""
    m = len(word1)
    n = len(word2)
    # if either sequence is empty, the distance is the length of the other
    if m * n == 0:
        return m + n
    # dp[i][j] is the edit distance between word1[:i] and word2[:j]
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m + 1):
        dp[i][0] = i
    for j in range(n + 1):
        dp[0][j] = j
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            a = dp[i - 1][j] + 1  # deletion
            b = dp[i][j - 1] + 1  # insertion
            c = dp[i - 1][j - 1]  # match or substitution
            if word1[i - 1] != word2[j - 1]:
                c += 1
            dp[i][j] = min(a, b, c)
    return dp[m][n]

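# Worked example (added for illustration): the distance is the minimum number
# of insertions, deletions and substitutions turning one sequence into the
# other, e.g.
#
#     cal_distance(["the", "cat", "sat"], ["the", "sat"])
#     -> 1 (one deletion)

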
def compute_edit_distance(prediction, label):
    """Word-level edit distance between two whitespace-separated strings."""
    prediction = prediction.strip().split(" ")
    label = label.strip().split(" ")
    distance = cal_distance(prediction, label)
    return distance

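
# Minimal self-check sketch (not part of the referenced script; the example
# sentences are made up for illustration).
if __name__ == "__main__":
    prediction = "the cat sat on the mat"
    # identical prediction and reference -> BLEU of 1.0
    print(compute_blue_score([prediction], [prediction]))
    # one substituted word ("mat" -> "rug") -> word-level edit distance of 1
    print(compute_edit_distance(prediction, "the cat sat on the rug"))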