mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
Batch UD evaluation script (#3174)
* running UD eval * printing timing of tokenizer: tokens per second * timing of default English model * structured output and parameterization to compare different runs * additional flag to allow evaluation without parsing info * printing verbose log of errors for manual inspection * printing over- and undersegmented cases (and combo's) * add under and oversegmented numbers to Score and structured output * print high-freq over/under segmented words and word shapes * printing examples as part of the structured output * print the results to file * batch run of different models and treebanks per language * cleaning up code * commandline script to process all languages in spaCy & UD * heuristic to remove blinded corpora and option to run one single best per language * pathlib instead of os for file paths
This commit is contained in:
parent
5a4737df09
commit
66016ac289
|
@ -51,7 +51,8 @@
|
||||||
# - evaluate the given gold and system CoNLL-U files (loaded with load_conllu)
|
# - evaluate the given gold and system CoNLL-U files (loaded with load_conllu)
|
||||||
# - raises UDError if the concatenated tokens of gold and system file do not match
|
# - raises UDError if the concatenated tokens of gold and system file do not match
|
||||||
# - returns a dictionary with the metrics described above, each metrics having
|
# - returns a dictionary with the metrics described above, each metrics having
|
||||||
# three fields: precision, recall and f1
|
# four fields: precision, recall, f1 and aligned_accuracy (when using aligned
|
||||||
|
# words, otherwise this is None)
|
||||||
|
|
||||||
# Description of token matching
|
# Description of token matching
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
|
@ -97,7 +98,7 @@ class UDError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Load given CoNLL-U file into internal representation
|
# Load given CoNLL-U file into internal representation
|
||||||
def load_conllu(file):
|
def load_conllu(file, check_parse=True):
|
||||||
# Internal representation classes
|
# Internal representation classes
|
||||||
class UDRepresentation:
|
class UDRepresentation:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -181,8 +182,9 @@ def load_conllu(file):
|
||||||
process_word(word)
|
process_word(word)
|
||||||
|
|
||||||
# Check there is a single root node
|
# Check there is a single root node
|
||||||
if len([word for word in ud.words[sentence_start:] if word.parent is None]) != 1:
|
if check_parse:
|
||||||
raise UDError("There are multiple roots in a sentence")
|
if len([word for word in ud.words[sentence_start:] if word.parent is None]) != 1:
|
||||||
|
raise UDError("There are multiple roots in a sentence")
|
||||||
|
|
||||||
# End the sentence
|
# End the sentence
|
||||||
ud.sentences[-1].end = index
|
ud.sentences[-1].end = index
|
||||||
|
@ -198,7 +200,7 @@ def load_conllu(file):
|
||||||
if "." in columns[ID]:
|
if "." in columns[ID]:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Delete spaces from FORM so gold.characters == system.characters
|
# Delete spaces from FORM so gold.characters == system.characters
|
||||||
# even if one of them tokenizes the space.
|
# even if one of them tokenizes the space.
|
||||||
columns[FORM] = columns[FORM].replace(" ", "")
|
columns[FORM] = columns[FORM].replace(" ", "")
|
||||||
if not columns[FORM]:
|
if not columns[FORM]:
|
||||||
|
@ -247,13 +249,17 @@ def load_conllu(file):
|
||||||
return ud
|
return ud
|
||||||
|
|
||||||
# Evaluate the gold and system treebanks (loaded using load_conllu).
|
# Evaluate the gold and system treebanks (loaded using load_conllu).
|
||||||
def evaluate(gold_ud, system_ud, deprel_weights=None):
|
def evaluate(gold_ud, system_ud, deprel_weights=None, check_parse=True):
|
||||||
class Score:
|
class Score:
|
||||||
def __init__(self, gold_total, system_total, correct, aligned_total=None):
|
def __init__(self, gold_total, system_total, correct, aligned_total=None, undersegmented=None, oversegmented=None):
|
||||||
self.precision = correct / system_total if system_total else 0.0
|
self.precision = correct / system_total if system_total else 0.0
|
||||||
self.recall = correct / gold_total if gold_total else 0.0
|
self.recall = correct / gold_total if gold_total else 0.0
|
||||||
self.f1 = 2 * correct / (system_total + gold_total) if system_total + gold_total else 0.0
|
self.f1 = 2 * correct / (system_total + gold_total) if system_total + gold_total else 0.0
|
||||||
self.aligned_accuracy = correct / aligned_total if aligned_total else aligned_total
|
self.aligned_accuracy = correct / aligned_total if aligned_total else aligned_total
|
||||||
|
self.undersegmented = undersegmented
|
||||||
|
self.oversegmented = oversegmented
|
||||||
|
self.under_perc = len(undersegmented) / gold_total if gold_total and undersegmented else 0.0
|
||||||
|
self.over_perc = len(oversegmented) / gold_total if gold_total and oversegmented else 0.0
|
||||||
class AlignmentWord:
|
class AlignmentWord:
|
||||||
def __init__(self, gold_word, system_word):
|
def __init__(self, gold_word, system_word):
|
||||||
self.gold_word = gold_word
|
self.gold_word = gold_word
|
||||||
|
@ -286,17 +292,43 @@ def evaluate(gold_ud, system_ud, deprel_weights=None):
|
||||||
|
|
||||||
def spans_score(gold_spans, system_spans):
|
def spans_score(gold_spans, system_spans):
|
||||||
correct, gi, si = 0, 0, 0
|
correct, gi, si = 0, 0, 0
|
||||||
|
undersegmented = list()
|
||||||
|
oversegmented = list()
|
||||||
|
combo = 0
|
||||||
|
previous_end_si_earlier = False
|
||||||
|
previous_end_gi_earlier = False
|
||||||
while gi < len(gold_spans) and si < len(system_spans):
|
while gi < len(gold_spans) and si < len(system_spans):
|
||||||
|
previous_si = system_spans[si-1] if si > 0 else None
|
||||||
|
previous_gi = gold_spans[gi-1] if gi > 0 else None
|
||||||
if system_spans[si].start < gold_spans[gi].start:
|
if system_spans[si].start < gold_spans[gi].start:
|
||||||
|
# avoid counting the same mistake twice
|
||||||
|
if not previous_end_si_earlier:
|
||||||
|
combo += 1
|
||||||
|
oversegmented.append(str(previous_gi).strip())
|
||||||
si += 1
|
si += 1
|
||||||
elif gold_spans[gi].start < system_spans[si].start:
|
elif gold_spans[gi].start < system_spans[si].start:
|
||||||
|
# avoid counting the same mistake twice
|
||||||
|
if not previous_end_gi_earlier:
|
||||||
|
combo += 1
|
||||||
|
undersegmented.append(str(previous_si).strip())
|
||||||
gi += 1
|
gi += 1
|
||||||
else:
|
else:
|
||||||
correct += gold_spans[gi].end == system_spans[si].end
|
correct += gold_spans[gi].end == system_spans[si].end
|
||||||
|
if gold_spans[gi].end < system_spans[si].end:
|
||||||
|
undersegmented.append(str(system_spans[si]).strip())
|
||||||
|
previous_end_gi_earlier = True
|
||||||
|
previous_end_si_earlier = False
|
||||||
|
elif gold_spans[gi].end > system_spans[si].end:
|
||||||
|
oversegmented.append(str(gold_spans[gi]).strip())
|
||||||
|
previous_end_si_earlier = True
|
||||||
|
previous_end_gi_earlier = False
|
||||||
|
else:
|
||||||
|
previous_end_gi_earlier = False
|
||||||
|
previous_end_si_earlier = False
|
||||||
si += 1
|
si += 1
|
||||||
gi += 1
|
gi += 1
|
||||||
|
|
||||||
return Score(len(gold_spans), len(system_spans), correct)
|
return Score(len(gold_spans), len(system_spans), correct, None, undersegmented, oversegmented)
|
||||||
|
|
||||||
def alignment_score(alignment, key_fn, weight_fn=lambda w: 1):
|
def alignment_score(alignment, key_fn, weight_fn=lambda w: 1):
|
||||||
gold, system, aligned, correct = 0, 0, 0, 0
|
gold, system, aligned, correct = 0, 0, 0, 0
|
||||||
|
@ -425,18 +457,28 @@ def evaluate(gold_ud, system_ud, deprel_weights=None):
|
||||||
alignment = align_words(gold_ud.words, system_ud.words)
|
alignment = align_words(gold_ud.words, system_ud.words)
|
||||||
|
|
||||||
# Compute the F1-scores
|
# Compute the F1-scores
|
||||||
result = {
|
if check_parse:
|
||||||
"Tokens": spans_score(gold_ud.tokens, system_ud.tokens),
|
result = {
|
||||||
"Sentences": spans_score(gold_ud.sentences, system_ud.sentences),
|
"Tokens": spans_score(gold_ud.tokens, system_ud.tokens),
|
||||||
"Words": alignment_score(alignment, None),
|
"Sentences": spans_score(gold_ud.sentences, system_ud.sentences),
|
||||||
"UPOS": alignment_score(alignment, lambda w, parent: w.columns[UPOS]),
|
"Words": alignment_score(alignment, None),
|
||||||
"XPOS": alignment_score(alignment, lambda w, parent: w.columns[XPOS]),
|
"UPOS": alignment_score(alignment, lambda w, parent: w.columns[UPOS]),
|
||||||
"Feats": alignment_score(alignment, lambda w, parent: w.columns[FEATS]),
|
"XPOS": alignment_score(alignment, lambda w, parent: w.columns[XPOS]),
|
||||||
"AllTags": alignment_score(alignment, lambda w, parent: (w.columns[UPOS], w.columns[XPOS], w.columns[FEATS])),
|
"Feats": alignment_score(alignment, lambda w, parent: w.columns[FEATS]),
|
||||||
"Lemmas": alignment_score(alignment, lambda w, parent: w.columns[LEMMA]),
|
"AllTags": alignment_score(alignment, lambda w, parent: (w.columns[UPOS], w.columns[XPOS], w.columns[FEATS])),
|
||||||
"UAS": alignment_score(alignment, lambda w, parent: parent),
|
"Lemmas": alignment_score(alignment, lambda w, parent: w.columns[LEMMA]),
|
||||||
"LAS": alignment_score(alignment, lambda w, parent: (parent, w.columns[DEPREL])),
|
"UAS": alignment_score(alignment, lambda w, parent: parent),
|
||||||
}
|
"LAS": alignment_score(alignment, lambda w, parent: (parent, w.columns[DEPREL])),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
result = {
|
||||||
|
"Tokens": spans_score(gold_ud.tokens, system_ud.tokens),
|
||||||
|
"Sentences": spans_score(gold_ud.sentences, system_ud.sentences),
|
||||||
|
"Words": alignment_score(alignment, None),
|
||||||
|
"Feats": alignment_score(alignment, lambda w, parent: w.columns[FEATS]),
|
||||||
|
"Lemmas": alignment_score(alignment, lambda w, parent: w.columns[LEMMA]),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# Add WeightedLAS if weights are given
|
# Add WeightedLAS if weights are given
|
||||||
if deprel_weights is not None:
|
if deprel_weights is not None:
|
||||||
|
|
287
spacy/cli/ud/run_eval.py
Normal file
287
spacy/cli/ud/run_eval.py
Normal file
|
@ -0,0 +1,287 @@
|
||||||
|
import spacy
|
||||||
|
import time
|
||||||
|
import re
|
||||||
|
import plac
|
||||||
|
import operator
|
||||||
|
import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
|
||||||
|
from spacy.cli.ud import conll17_ud_eval
|
||||||
|
from spacy.cli.ud.ud_train import write_conllu
|
||||||
|
from spacy.lang.lex_attrs import word_shape
|
||||||
|
from spacy.util import get_lang_class
|
||||||
|
|
||||||
|
# All languages in spaCy - in UD format (note that Norwegian is 'no' instead of 'nb')
|
||||||
|
ALL_LANGUAGES = "ar, ca, da, de, el, en, es, fa, fi, fr, ga, he, hi, hr, hu, id, " \
|
||||||
|
"it, ja, no, nl, pl, pt, ro, ru, sv, tr, ur, vi, zh"
|
||||||
|
|
||||||
|
# Non-parsing tasks that will be evaluated (works for default models)
|
||||||
|
EVAL_NO_PARSE = ['Tokens', 'Words', 'Lemmas', 'Sentences', 'Feats']
|
||||||
|
|
||||||
|
# Tasks that will be evaluated if check_parse=True (does not work for default models)
|
||||||
|
EVAL_PARSE = ['Tokens', 'Words', 'Lemmas', 'Sentences', 'Feats', 'UPOS', 'XPOS', 'AllTags', 'UAS', 'LAS']
|
||||||
|
|
||||||
|
# Minimum frequency an error should have to be printed
|
||||||
|
PRINT_FREQ = 20
|
||||||
|
|
||||||
|
# Maximum number of errors printed per category
|
||||||
|
PRINT_TOTAL = 10
|
||||||
|
|
||||||
|
space_re = re.compile("\s+")
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(modelname, add_sentencizer=False):
|
||||||
|
""" Load a specific spaCy model """
|
||||||
|
loading_start = time.time()
|
||||||
|
nlp = spacy.load(modelname)
|
||||||
|
if add_sentencizer:
|
||||||
|
nlp.add_pipe(nlp.create_pipe('sentencizer'))
|
||||||
|
loading_end = time.time()
|
||||||
|
loading_time = loading_end - loading_start
|
||||||
|
if add_sentencizer:
|
||||||
|
return nlp, loading_time, modelname + '_sentencizer'
|
||||||
|
return nlp, loading_time, modelname
|
||||||
|
|
||||||
|
|
||||||
|
def load_default_model_sentencizer(lang):
|
||||||
|
""" Load a generic spaCy model and add the sentencizer for sentence tokenization"""
|
||||||
|
loading_start = time.time()
|
||||||
|
lang_class = get_lang_class(lang)
|
||||||
|
nlp = lang_class()
|
||||||
|
nlp.add_pipe(nlp.create_pipe('sentencizer'))
|
||||||
|
loading_end = time.time()
|
||||||
|
loading_time = loading_end - loading_start
|
||||||
|
return nlp, loading_time, lang + "_default_" + 'sentencizer'
|
||||||
|
|
||||||
|
|
||||||
|
def split_text(text):
|
||||||
|
return [space_re.sub(" ", par.strip()) for par in text.split("\n\n")]
|
||||||
|
|
||||||
|
|
||||||
|
def get_freq_tuples(my_list, print_total_threshold):
|
||||||
|
""" Turn a list of errors into frequency-sorted tuples thresholded by a certain total number """
|
||||||
|
d = {}
|
||||||
|
for token in my_list:
|
||||||
|
d.setdefault(token, 0)
|
||||||
|
d[token] += 1
|
||||||
|
return sorted(d.items(), key=operator.itemgetter(1), reverse=True)[:print_total_threshold]
|
||||||
|
|
||||||
|
|
||||||
|
def _contains_blinded_text(stats_xml):
|
||||||
|
""" Heuristic to determine whether the treebank has blinded texts or not """
|
||||||
|
tree = ET.parse(stats_xml)
|
||||||
|
root = tree.getroot()
|
||||||
|
total_tokens = int(root.find('size/total/tokens').text)
|
||||||
|
unique_lemmas = int(root.find('lemmas').get('unique'))
|
||||||
|
|
||||||
|
# assume the corpus is largely blinded when there are less than 1% unique tokens
|
||||||
|
return (unique_lemmas / total_tokens) < 0.01
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_all_treebanks(ud_dir, languages, corpus, best_per_language):
|
||||||
|
"""" Fetch the txt files for all treebanks for a given set of languages """
|
||||||
|
all_treebanks = dict()
|
||||||
|
treebank_size = dict()
|
||||||
|
for l in languages:
|
||||||
|
all_treebanks[l] = []
|
||||||
|
treebank_size[l] = 0
|
||||||
|
|
||||||
|
for treebank_dir in ud_dir.iterdir():
|
||||||
|
if treebank_dir.is_dir():
|
||||||
|
for txt_path in treebank_dir.iterdir():
|
||||||
|
if txt_path.name.endswith('-ud-' + corpus + '.txt'):
|
||||||
|
file_lang = txt_path.name.split('_')[0]
|
||||||
|
if file_lang in languages:
|
||||||
|
gold_path = treebank_dir / txt_path.name.replace('.txt', '.conllu')
|
||||||
|
stats_xml = treebank_dir / "stats.xml"
|
||||||
|
# ignore treebanks where the texts are not publicly available
|
||||||
|
if not _contains_blinded_text(stats_xml):
|
||||||
|
if not best_per_language:
|
||||||
|
all_treebanks[file_lang].append(txt_path)
|
||||||
|
# check the tokens in the gold annotation to keep only the biggest treebank per language
|
||||||
|
else:
|
||||||
|
with gold_path.open(mode='r', encoding='utf-8') as gold_file:
|
||||||
|
gold_ud = conll17_ud_eval.load_conllu(gold_file)
|
||||||
|
gold_tokens = len(gold_ud.tokens)
|
||||||
|
if treebank_size[file_lang] < gold_tokens:
|
||||||
|
all_treebanks[file_lang] = [txt_path]
|
||||||
|
treebank_size[file_lang] = gold_tokens
|
||||||
|
|
||||||
|
return all_treebanks
|
||||||
|
|
||||||
|
|
||||||
|
def run_single_eval(nlp, loading_time, print_name, text_path, gold_ud, tmp_output_path, out_file, print_header,
|
||||||
|
check_parse, print_freq_tasks):
|
||||||
|
"""" Run an evaluation of a model nlp on a certain specified treebank """
|
||||||
|
with text_path.open(mode='r', encoding='utf-8') as f:
|
||||||
|
flat_text = f.read()
|
||||||
|
|
||||||
|
# STEP 1: tokenize text
|
||||||
|
tokenization_start = time.time()
|
||||||
|
texts = split_text(flat_text)
|
||||||
|
docs = list(nlp.pipe(texts))
|
||||||
|
tokenization_end = time.time()
|
||||||
|
tokenization_time = tokenization_end - tokenization_start
|
||||||
|
|
||||||
|
# STEP 2: record stats and timings
|
||||||
|
tokens_per_s = int(len(gold_ud.tokens) / tokenization_time)
|
||||||
|
|
||||||
|
print_header_1 = ['date', 'text_path', 'gold_tokens', 'model', 'loading_time', 'tokenization_time', 'tokens_per_s']
|
||||||
|
print_string_1 = [str(datetime.date.today()), text_path.name, len(gold_ud.tokens),
|
||||||
|
print_name, "%.2f" % loading_time, "%.2f" % tokenization_time, tokens_per_s]
|
||||||
|
|
||||||
|
# STEP 3: evaluate predicted tokens and features
|
||||||
|
with tmp_output_path.open(mode="w", encoding="utf8") as tmp_out_file:
|
||||||
|
write_conllu(docs, tmp_out_file)
|
||||||
|
with tmp_output_path.open(mode="r", encoding="utf8") as sys_file:
|
||||||
|
sys_ud = conll17_ud_eval.load_conllu(sys_file, check_parse=check_parse)
|
||||||
|
tmp_output_path.unlink()
|
||||||
|
scores = conll17_ud_eval.evaluate(gold_ud, sys_ud, check_parse=check_parse)
|
||||||
|
|
||||||
|
# STEP 4: format the scoring results
|
||||||
|
eval_headers = EVAL_PARSE
|
||||||
|
if not check_parse:
|
||||||
|
eval_headers = EVAL_NO_PARSE
|
||||||
|
|
||||||
|
for score_name in eval_headers:
|
||||||
|
score = scores[score_name]
|
||||||
|
print_string_1.extend(["%.2f" % score.precision,
|
||||||
|
"%.2f" % score.recall,
|
||||||
|
"%.2f" % score.f1])
|
||||||
|
print_string_1.append("-" if score.aligned_accuracy is None else "%.2f" % score.aligned_accuracy)
|
||||||
|
print_string_1.append("-" if score.undersegmented is None else "%.4f" % score.under_perc)
|
||||||
|
print_string_1.append("-" if score.oversegmented is None else "%.4f" % score.over_perc)
|
||||||
|
|
||||||
|
print_header_1.extend([score_name + '_p', score_name + '_r', score_name + '_F', score_name + '_acc',
|
||||||
|
score_name + '_under', score_name + '_over'])
|
||||||
|
|
||||||
|
if score_name in print_freq_tasks:
|
||||||
|
print_header_1.extend([score_name + '_word_under_ex', score_name + '_shape_under_ex',
|
||||||
|
score_name + '_word_over_ex', score_name + '_shape_over_ex'])
|
||||||
|
|
||||||
|
d_under_words = get_freq_tuples(score.undersegmented, PRINT_TOTAL)
|
||||||
|
d_under_shapes = get_freq_tuples([word_shape(x) for x in score.undersegmented], PRINT_TOTAL)
|
||||||
|
d_over_words = get_freq_tuples(score.oversegmented, PRINT_TOTAL)
|
||||||
|
d_over_shapes = get_freq_tuples([word_shape(x) for x in score.oversegmented], PRINT_TOTAL)
|
||||||
|
|
||||||
|
# saving to CSV with ; seperator so blinding ; in the example output
|
||||||
|
print_string_1.append(
|
||||||
|
str({k: v for k, v in d_under_words if v > PRINT_FREQ}).replace(";", "*SEMICOLON*"))
|
||||||
|
print_string_1.append(
|
||||||
|
str({k: v for k, v in d_under_shapes if v > PRINT_FREQ}).replace(";", "*SEMICOLON*"))
|
||||||
|
print_string_1.append(
|
||||||
|
str({k: v for k, v in d_over_words if v > PRINT_FREQ}).replace(";", "*SEMICOLON*"))
|
||||||
|
print_string_1.append(
|
||||||
|
str({k: v for k, v in d_over_shapes if v > PRINT_FREQ}).replace(";", "*SEMICOLON*"))
|
||||||
|
|
||||||
|
# STEP 5: print the formatted results to CSV
|
||||||
|
if print_header:
|
||||||
|
out_file.write(';'.join(map(str, print_header_1)) + '\n')
|
||||||
|
out_file.write(';'.join(map(str, print_string_1)) + '\n')
|
||||||
|
|
||||||
|
|
||||||
|
def run_all_evals(models, treebanks, out_file, check_parse, print_freq_tasks):
|
||||||
|
"""" Run an evaluation for each language with its specified models and treebanks """
|
||||||
|
print_header = True
|
||||||
|
|
||||||
|
for tb_lang, treebank_list in treebanks.items():
|
||||||
|
print()
|
||||||
|
print("Language", tb_lang)
|
||||||
|
for text_path in treebank_list:
|
||||||
|
print(" Evaluating on", text_path)
|
||||||
|
|
||||||
|
gold_path = text_path.parent / (text_path.stem + '.conllu')
|
||||||
|
print(" Gold data from ", gold_path)
|
||||||
|
|
||||||
|
# nested try blocks to ensure the code can continue with the next iteration after a failure
|
||||||
|
try:
|
||||||
|
with gold_path.open(mode='r', encoding='utf-8') as gold_file:
|
||||||
|
gold_ud = conll17_ud_eval.load_conllu(gold_file)
|
||||||
|
|
||||||
|
for nlp, nlp_loading_time, nlp_name in models[tb_lang]:
|
||||||
|
try:
|
||||||
|
print(" Benchmarking", nlp_name)
|
||||||
|
tmp_output_path = text_path.parent / str('tmp_' + nlp_name + '.conllu')
|
||||||
|
run_single_eval(nlp, nlp_loading_time, nlp_name, text_path, gold_ud, tmp_output_path, out_file,
|
||||||
|
print_header, check_parse, print_freq_tasks)
|
||||||
|
print_header = False
|
||||||
|
except Exception as e:
|
||||||
|
print(" Ran into trouble: ", str(e))
|
||||||
|
except Exception as e:
|
||||||
|
print(" Ran into trouble: ", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@plac.annotations(
|
||||||
|
out_path=("Path to output CSV file", "positional", None, Path),
|
||||||
|
ud_dir=("Path to Universal Dependencies corpus", "positional", None, Path),
|
||||||
|
check_parse=("Set flag to evaluate parsing performance", "flag", "p", bool),
|
||||||
|
langs=("Enumeration of languages to evaluate (default: all)", "option", "l", str),
|
||||||
|
exclude_trained_models=("Set flag to exclude trained models", "flag", "t", bool),
|
||||||
|
exclude_multi=("Set flag to exclude the multi-language model as default baseline", "flag", "m", bool),
|
||||||
|
hide_freq=("Set flag to avoid printing out more detailed high-freq tokenization errors", "flag", "f", bool),
|
||||||
|
corpus=("Whether to run on train, dev or test", "option", "c", str),
|
||||||
|
best_per_language=("Set flag to only keep the largest treebank for each language", "flag", "b", bool)
|
||||||
|
)
|
||||||
|
def main(out_path, ud_dir, check_parse=False, langs=ALL_LANGUAGES, exclude_trained_models=False, exclude_multi=False,
|
||||||
|
hide_freq=False, corpus='train', best_per_language=False):
|
||||||
|
""""
|
||||||
|
Assemble all treebanks and models to run evaluations with.
|
||||||
|
When setting check_parse to True, the default models will not be evaluated as they don't have parsing functionality
|
||||||
|
"""
|
||||||
|
languages = [lang.strip() for lang in langs.split(",")]
|
||||||
|
|
||||||
|
print_freq_tasks = []
|
||||||
|
if not hide_freq:
|
||||||
|
print_freq_tasks = ['Tokens']
|
||||||
|
|
||||||
|
# fetching all relevant treebank from the directory
|
||||||
|
treebanks = fetch_all_treebanks(ud_dir, languages, corpus, best_per_language)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("Loading all relevant models for", languages)
|
||||||
|
models = dict()
|
||||||
|
|
||||||
|
# multi-lang model
|
||||||
|
multi = None
|
||||||
|
if not exclude_multi and not check_parse:
|
||||||
|
multi = load_model('xx_ent_wiki_sm', add_sentencizer=True)
|
||||||
|
|
||||||
|
# initialize all models with the multi-lang model
|
||||||
|
for lang in languages:
|
||||||
|
models[lang] = [multi] if multi else []
|
||||||
|
# add default models if we don't want to evaluate parsing info
|
||||||
|
if not check_parse:
|
||||||
|
# Norwegian is 'nb' in spaCy but 'no' in the UD corpora
|
||||||
|
if lang == 'no':
|
||||||
|
models['no'].append(load_default_model_sentencizer('nb'))
|
||||||
|
else:
|
||||||
|
models[lang].append(load_default_model_sentencizer(lang))
|
||||||
|
|
||||||
|
# language-specific trained models
|
||||||
|
if not exclude_trained_models:
|
||||||
|
if 'de' in models:
|
||||||
|
models['de'].append(load_model('de_core_news_sm'))
|
||||||
|
if 'es' in models:
|
||||||
|
models['es'].append(load_model('es_core_news_sm'))
|
||||||
|
models['es'].append(load_model('es_core_news_md'))
|
||||||
|
if 'pt' in models:
|
||||||
|
models['pt'].append(load_model('pt_core_news_sm'))
|
||||||
|
if 'it' in models:
|
||||||
|
models['it'].append(load_model('it_core_news_sm'))
|
||||||
|
if 'nl' in models:
|
||||||
|
models['nl'].append(load_model('nl_core_news_sm'))
|
||||||
|
if 'en' in models:
|
||||||
|
models['en'].append(load_model('en_core_web_sm'))
|
||||||
|
models['en'].append(load_model('en_core_web_md'))
|
||||||
|
models['en'].append(load_model('en_core_web_lg'))
|
||||||
|
if 'fr' in models:
|
||||||
|
models['fr'].append(load_model('fr_core_news_sm'))
|
||||||
|
models['fr'].append(load_model('fr_core_news_md'))
|
||||||
|
|
||||||
|
with out_path.open(mode='w', encoding='utf-8') as out_file:
|
||||||
|
run_all_evals(models, treebanks, out_file, check_parse, print_freq_tasks)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
plac.call(main)
|
Loading…
Reference in New Issue
Block a user