mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Add ud_train and ud_evaluate CLI commands
This commit is contained in:
parent
4b72c38556
commit
3478ea76d1
570
spacy/cli/conll17_ud_eval.py
Normal file
570
spacy/cli/conll17_ud_eval.py
Normal file
|
@ -0,0 +1,570 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# CoNLL 2017 UD Parsing evaluation script.
|
||||
#
|
||||
# Compatible with Python 2.7 and 3.2+, can be used either as a module
|
||||
# or a standalone executable.
|
||||
#
|
||||
# Copyright 2017 Institute of Formal and Applied Linguistics (UFAL),
|
||||
# Faculty of Mathematics and Physics, Charles University, Czech Republic.
|
||||
#
|
||||
# Changelog:
|
||||
# - [02 Jan 2017] Version 0.9: Initial release
|
||||
# - [25 Jan 2017] Version 0.9.1: Fix bug in LCS alignment computation
|
||||
# - [10 Mar 2017] Version 1.0: Add documentation and test
|
||||
# Compare HEADs correctly using aligned words
|
||||
# Allow evaluation with errorneous spaces in forms
|
||||
# Compare forms in LCS case insensitively
|
||||
# Detect cycles and multiple root nodes
|
||||
# Compute AlignedAccuracy
|
||||
|
||||
# Command line usage
|
||||
# ------------------
|
||||
# conll17_ud_eval.py [-v] [-w weights_file] gold_conllu_file system_conllu_file
|
||||
#
|
||||
# - if no -v is given, only the CoNLL17 UD Shared Task evaluation LAS metrics
|
||||
# is printed
|
||||
# - if -v is given, several metrics are printed (as precision, recall, F1 score,
|
||||
# and in case the metric is computed on aligned words also accuracy on these):
|
||||
# - Tokens: how well do the gold tokens match system tokens
|
||||
# - Sentences: how well do the gold sentences match system sentences
|
||||
# - Words: how well can the gold words be aligned to system words
|
||||
# - UPOS: using aligned words, how well does UPOS match
|
||||
# - XPOS: using aligned words, how well does XPOS match
|
||||
# - Feats: using aligned words, how well does FEATS match
|
||||
# - AllTags: using aligned words, how well does UPOS+XPOS+FEATS match
|
||||
# - Lemmas: using aligned words, how well does LEMMA match
|
||||
# - UAS: using aligned words, how well does HEAD match
|
||||
# - LAS: using aligned words, how well does HEAD+DEPREL(ignoring subtypes) match
|
||||
# - if weights_file is given (with lines containing deprel-weight pairs),
|
||||
# one more metric is shown:
|
||||
# - WeightedLAS: as LAS, but each deprel (ignoring subtypes) has different weight
|
||||
|
||||
# API usage
|
||||
# ---------
|
||||
# - load_conllu(file)
|
||||
# - loads CoNLL-U file from given file object to an internal representation
|
||||
# - the file object should return str on both Python 2 and Python 3
|
||||
# - raises UDError exception if the given file cannot be loaded
|
||||
# - evaluate(gold_ud, system_ud)
|
||||
# - evaluate the given gold and system CoNLL-U files (loaded with load_conllu)
|
||||
# - raises UDError if the concatenated tokens of gold and system file do not match
|
||||
# - returns a dictionary with the metrics described above, each metrics having
|
||||
# three fields: precision, recall and f1
|
||||
|
||||
# Description of token matching
|
||||
# -----------------------------
|
||||
# In order to match tokens of gold file and system file, we consider the text
|
||||
# resulting from concatenation of gold tokens and text resulting from
|
||||
# concatenation of system tokens. These texts should match -- if they do not,
|
||||
# the evaluation fails.
|
||||
#
|
||||
# If the texts do match, every token is represented as a range in this original
|
||||
# text, and tokens are equal only if their range is the same.
|
||||
|
||||
# Description of word matching
|
||||
# ----------------------------
|
||||
# When matching words of gold file and system file, we first match the tokens.
|
||||
# The words which are also tokens are matched as tokens, but words in multi-word
|
||||
# tokens have to be handled differently.
|
||||
#
|
||||
# To handle multi-word tokens, we start by finding "multi-word spans".
|
||||
# Multi-word span is a span in the original text such that
|
||||
# - it contains at least one multi-word token
|
||||
# - all multi-word tokens in the span (considering both gold and system ones)
|
||||
# are completely inside the span (i.e., they do not "stick out")
|
||||
# - the multi-word span is as small as possible
|
||||
#
|
||||
# For every multi-word span, we align the gold and system words completely
|
||||
# inside this span using LCS on their FORMs. The words not intersecting
|
||||
# (even partially) any multi-word span are then aligned as tokens.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import io
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
# CoNLL-U column names
|
||||
ID, FORM, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC = range(10)
|
||||
|
||||
# UD Error is used when raising exceptions in this module
|
||||
class UDError(Exception):
|
||||
pass
|
||||
|
||||
# Load given CoNLL-U file into internal representation
|
||||
def load_conllu(file):
|
||||
# Internal representation classes
|
||||
class UDRepresentation:
|
||||
def __init__(self):
|
||||
# Characters of all the tokens in the whole file.
|
||||
# Whitespace between tokens is not included.
|
||||
self.characters = []
|
||||
# List of UDSpan instances with start&end indices into `characters`.
|
||||
self.tokens = []
|
||||
# List of UDWord instances.
|
||||
self.words = []
|
||||
# List of UDSpan instances with start&end indices into `characters`.
|
||||
self.sentences = []
|
||||
class UDSpan:
|
||||
def __init__(self, start, end, characters):
|
||||
self.start = start
|
||||
# Note that self.end marks the first position **after the end** of span,
|
||||
# so we can use characters[start:end] or range(start, end).
|
||||
self.end = end
|
||||
self.characters = characters
|
||||
|
||||
@property
|
||||
def text(self):
|
||||
return ''.join(self.characters[self.start:self.end])
|
||||
|
||||
def __str__(self):
|
||||
return self.text
|
||||
|
||||
def __repr__(self):
|
||||
return self.text
|
||||
class UDWord:
|
||||
def __init__(self, span, columns, is_multiword):
|
||||
# Span of this word (or MWT, see below) within ud_representation.characters.
|
||||
self.span = span
|
||||
# 10 columns of the CoNLL-U file: ID, FORM, LEMMA,...
|
||||
self.columns = columns
|
||||
# is_multiword==True means that this word is part of a multi-word token.
|
||||
# In that case, self.span marks the span of the whole multi-word token.
|
||||
self.is_multiword = is_multiword
|
||||
# Reference to the UDWord instance representing the HEAD (or None if root).
|
||||
self.parent = None
|
||||
# Let's ignore language-specific deprel subtypes.
|
||||
self.columns[DEPREL] = columns[DEPREL].split(':')[0]
|
||||
|
||||
ud = UDRepresentation()
|
||||
|
||||
# Load the CoNLL-U file
|
||||
index, sentence_start = 0, None
|
||||
linenum = 0
|
||||
while True:
|
||||
line = file.readline()
|
||||
linenum += 1
|
||||
if not line:
|
||||
break
|
||||
line = line.rstrip("\r\n")
|
||||
|
||||
# Handle sentence start boundaries
|
||||
if sentence_start is None:
|
||||
# Skip comments
|
||||
if line.startswith("#"):
|
||||
continue
|
||||
# Start a new sentence
|
||||
ud.sentences.append(UDSpan(index, 0, ud.characters))
|
||||
sentence_start = len(ud.words)
|
||||
if not line:
|
||||
# Add parent UDWord links and check there are no cycles
|
||||
def process_word(word):
|
||||
if word.parent == "remapping":
|
||||
raise UDError("There is a cycle in a sentence")
|
||||
if word.parent is None:
|
||||
head = int(word.columns[HEAD])
|
||||
if head > len(ud.words) - sentence_start:
|
||||
raise UDError("HEAD '{}' points outside of the sentence".format(word.columns[HEAD]))
|
||||
if head:
|
||||
parent = ud.words[sentence_start + head - 1]
|
||||
word.parent = "remapping"
|
||||
process_word(parent)
|
||||
word.parent = parent
|
||||
|
||||
for word in ud.words[sentence_start:]:
|
||||
process_word(word)
|
||||
|
||||
# Check there is a single root node
|
||||
if len([word for word in ud.words[sentence_start:] if word.parent is None]) != 1:
|
||||
raise UDError("There are multiple roots in a sentence")
|
||||
|
||||
# End the sentence
|
||||
ud.sentences[-1].end = index
|
||||
sentence_start = None
|
||||
continue
|
||||
|
||||
# Read next token/word
|
||||
columns = line.split("\t")
|
||||
if len(columns) != 10:
|
||||
raise UDError("The CoNLL-U line {} does not contain 10 tab-separated columns: '{}'".format(linenum, line))
|
||||
|
||||
# Skip empty nodes
|
||||
if "." in columns[ID]:
|
||||
continue
|
||||
|
||||
# Delete spaces from FORM so gold.characters == system.characters
|
||||
# even if one of them tokenizes the space.
|
||||
columns[FORM] = columns[FORM].replace(" ", "")
|
||||
if not columns[FORM]:
|
||||
raise UDError("There is an empty FORM in the CoNLL-U file -- line %d" % linenum)
|
||||
|
||||
# Save token
|
||||
ud.characters.extend(columns[FORM])
|
||||
ud.tokens.append(UDSpan(index, index + len(columns[FORM]), ud.characters))
|
||||
index += len(columns[FORM])
|
||||
|
||||
# Handle multi-word tokens to save word(s)
|
||||
if "-" in columns[ID]:
|
||||
try:
|
||||
start, end = map(int, columns[ID].split("-"))
|
||||
except:
|
||||
raise UDError("Cannot parse multi-word token ID '{}'".format(columns[ID]))
|
||||
|
||||
for _ in range(start, end + 1):
|
||||
word_line = file.readline().rstrip("\r\n")
|
||||
word_columns = word_line.split("\t")
|
||||
if len(word_columns) != 10:
|
||||
print(columns)
|
||||
raise UDError("The CoNLL-U line {} does not contain 10 tab-separated columns: '{}'".format(linenum, word_line))
|
||||
ud.words.append(UDWord(ud.tokens[-1], word_columns, is_multiword=True))
|
||||
# Basic tokens/words
|
||||
else:
|
||||
try:
|
||||
word_id = int(columns[ID])
|
||||
except:
|
||||
raise UDError("Cannot parse word ID '{}'".format(columns[ID]))
|
||||
if word_id != len(ud.words) - sentence_start + 1:
|
||||
raise UDError("Incorrect word ID '{}' for word '{}', expected '{}'".format(columns[ID], columns[FORM], len(ud.words) - sentence_start + 1))
|
||||
|
||||
try:
|
||||
head_id = int(columns[HEAD])
|
||||
except:
|
||||
raise UDError("Cannot parse HEAD '{}'".format(columns[HEAD]))
|
||||
if head_id < 0:
|
||||
raise UDError("HEAD cannot be negative")
|
||||
|
||||
ud.words.append(UDWord(ud.tokens[-1], columns, is_multiword=False))
|
||||
|
||||
if sentence_start is not None:
|
||||
raise UDError("The CoNLL-U file does not end with empty line")
|
||||
|
||||
return ud
|
||||
|
||||
# Evaluate the gold and system treebanks (loaded using load_conllu).
|
||||
def evaluate(gold_ud, system_ud, deprel_weights=None):
|
||||
class Score:
|
||||
def __init__(self, gold_total, system_total, correct, aligned_total=None):
|
||||
self.precision = correct / system_total if system_total else 0.0
|
||||
self.recall = correct / gold_total if gold_total else 0.0
|
||||
self.f1 = 2 * correct / (system_total + gold_total) if system_total + gold_total else 0.0
|
||||
self.aligned_accuracy = correct / aligned_total if aligned_total else aligned_total
|
||||
class AlignmentWord:
|
||||
def __init__(self, gold_word, system_word):
|
||||
self.gold_word = gold_word
|
||||
self.system_word = system_word
|
||||
self.gold_parent = None
|
||||
self.system_parent_gold_aligned = None
|
||||
class Alignment:
|
||||
def __init__(self, gold_words, system_words):
|
||||
self.gold_words = gold_words
|
||||
self.system_words = system_words
|
||||
self.matched_words = []
|
||||
self.matched_words_map = {}
|
||||
def append_aligned_words(self, gold_word, system_word):
|
||||
self.matched_words.append(AlignmentWord(gold_word, system_word))
|
||||
self.matched_words_map[system_word] = gold_word
|
||||
def fill_parents(self):
|
||||
# We represent root parents in both gold and system data by '0'.
|
||||
# For gold data, we represent non-root parent by corresponding gold word.
|
||||
# For system data, we represent non-root parent by either gold word aligned
|
||||
# to parent system nodes, or by None if no gold words is aligned to the parent.
|
||||
for words in self.matched_words:
|
||||
words.gold_parent = words.gold_word.parent if words.gold_word.parent is not None else 0
|
||||
words.system_parent_gold_aligned = self.matched_words_map.get(words.system_word.parent, None) \
|
||||
if words.system_word.parent is not None else 0
|
||||
|
||||
def lower(text):
|
||||
if sys.version_info < (3, 0) and isinstance(text, str):
|
||||
return text.decode("utf-8").lower()
|
||||
return text.lower()
|
||||
|
||||
def spans_score(gold_spans, system_spans):
|
||||
correct, gi, si = 0, 0, 0
|
||||
while gi < len(gold_spans) and si < len(system_spans):
|
||||
if system_spans[si].start < gold_spans[gi].start:
|
||||
si += 1
|
||||
elif gold_spans[gi].start < system_spans[si].start:
|
||||
gi += 1
|
||||
else:
|
||||
correct += gold_spans[gi].end == system_spans[si].end
|
||||
si += 1
|
||||
gi += 1
|
||||
|
||||
return Score(len(gold_spans), len(system_spans), correct)
|
||||
|
||||
def alignment_score(alignment, key_fn, weight_fn=lambda w: 1):
|
||||
gold, system, aligned, correct = 0, 0, 0, 0
|
||||
|
||||
for word in alignment.gold_words:
|
||||
gold += weight_fn(word)
|
||||
|
||||
for word in alignment.system_words:
|
||||
system += weight_fn(word)
|
||||
|
||||
for words in alignment.matched_words:
|
||||
aligned += weight_fn(words.gold_word)
|
||||
|
||||
if key_fn is None:
|
||||
# Return score for whole aligned words
|
||||
return Score(gold, system, aligned)
|
||||
|
||||
for words in alignment.matched_words:
|
||||
if key_fn(words.gold_word, words.gold_parent) == key_fn(words.system_word, words.system_parent_gold_aligned):
|
||||
correct += weight_fn(words.gold_word)
|
||||
|
||||
return Score(gold, system, correct, aligned)
|
||||
|
||||
def beyond_end(words, i, multiword_span_end):
|
||||
if i >= len(words):
|
||||
return True
|
||||
if words[i].is_multiword:
|
||||
return words[i].span.start >= multiword_span_end
|
||||
return words[i].span.end > multiword_span_end
|
||||
|
||||
def extend_end(word, multiword_span_end):
|
||||
if word.is_multiword and word.span.end > multiword_span_end:
|
||||
return word.span.end
|
||||
return multiword_span_end
|
||||
|
||||
def find_multiword_span(gold_words, system_words, gi, si):
|
||||
# We know gold_words[gi].is_multiword or system_words[si].is_multiword.
|
||||
# Find the start of the multiword span (gs, ss), so the multiword span is minimal.
|
||||
# Initialize multiword_span_end characters index.
|
||||
if gold_words[gi].is_multiword:
|
||||
multiword_span_end = gold_words[gi].span.end
|
||||
if not system_words[si].is_multiword and system_words[si].span.start < gold_words[gi].span.start:
|
||||
si += 1
|
||||
else: # if system_words[si].is_multiword
|
||||
multiword_span_end = system_words[si].span.end
|
||||
if not gold_words[gi].is_multiword and gold_words[gi].span.start < system_words[si].span.start:
|
||||
gi += 1
|
||||
gs, ss = gi, si
|
||||
|
||||
# Find the end of the multiword span
|
||||
# (so both gi and si are pointing to the word following the multiword span end).
|
||||
while not beyond_end(gold_words, gi, multiword_span_end) or \
|
||||
not beyond_end(system_words, si, multiword_span_end):
|
||||
if gi < len(gold_words) and (si >= len(system_words) or
|
||||
gold_words[gi].span.start <= system_words[si].span.start):
|
||||
multiword_span_end = extend_end(gold_words[gi], multiword_span_end)
|
||||
gi += 1
|
||||
else:
|
||||
multiword_span_end = extend_end(system_words[si], multiword_span_end)
|
||||
si += 1
|
||||
return gs, ss, gi, si
|
||||
|
||||
def compute_lcs(gold_words, system_words, gi, si, gs, ss):
|
||||
lcs = [[0] * (si - ss) for i in range(gi - gs)]
|
||||
for g in reversed(range(gi - gs)):
|
||||
for s in reversed(range(si - ss)):
|
||||
if lower(gold_words[gs + g].columns[FORM]) == lower(system_words[ss + s].columns[FORM]):
|
||||
lcs[g][s] = 1 + (lcs[g+1][s+1] if g+1 < gi-gs and s+1 < si-ss else 0)
|
||||
lcs[g][s] = max(lcs[g][s], lcs[g+1][s] if g+1 < gi-gs else 0)
|
||||
lcs[g][s] = max(lcs[g][s], lcs[g][s+1] if s+1 < si-ss else 0)
|
||||
return lcs
|
||||
|
||||
def align_words(gold_words, system_words):
|
||||
alignment = Alignment(gold_words, system_words)
|
||||
|
||||
gi, si = 0, 0
|
||||
while gi < len(gold_words) and si < len(system_words):
|
||||
if gold_words[gi].is_multiword or system_words[si].is_multiword:
|
||||
# A: Multi-word tokens => align via LCS within the whole "multiword span".
|
||||
gs, ss, gi, si = find_multiword_span(gold_words, system_words, gi, si)
|
||||
|
||||
if si > ss and gi > gs:
|
||||
lcs = compute_lcs(gold_words, system_words, gi, si, gs, ss)
|
||||
|
||||
# Store aligned words
|
||||
s, g = 0, 0
|
||||
while g < gi - gs and s < si - ss:
|
||||
if lower(gold_words[gs + g].columns[FORM]) == lower(system_words[ss + s].columns[FORM]):
|
||||
alignment.append_aligned_words(gold_words[gs+g], system_words[ss+s])
|
||||
g += 1
|
||||
s += 1
|
||||
elif lcs[g][s] == (lcs[g+1][s] if g+1 < gi-gs else 0):
|
||||
g += 1
|
||||
else:
|
||||
s += 1
|
||||
else:
|
||||
# B: No multi-word token => align according to spans.
|
||||
if (gold_words[gi].span.start, gold_words[gi].span.end) == (system_words[si].span.start, system_words[si].span.end):
|
||||
alignment.append_aligned_words(gold_words[gi], system_words[si])
|
||||
gi += 1
|
||||
si += 1
|
||||
elif gold_words[gi].span.start <= system_words[si].span.start:
|
||||
gi += 1
|
||||
else:
|
||||
si += 1
|
||||
|
||||
alignment.fill_parents()
|
||||
|
||||
return alignment
|
||||
|
||||
# Check that underlying character sequences do match
|
||||
if gold_ud.characters != system_ud.characters:
|
||||
index = 0
|
||||
while gold_ud.characters[index] == system_ud.characters[index]:
|
||||
index += 1
|
||||
|
||||
raise UDError(
|
||||
"The concatenation of tokens in gold file and in system file differ!\n" +
|
||||
"First 20 differing characters in gold file: '{}' and system file: '{}'".format(
|
||||
"".join(gold_ud.characters[index:index + 20]),
|
||||
"".join(system_ud.characters[index:index + 20])
|
||||
)
|
||||
)
|
||||
|
||||
# Align words
|
||||
alignment = align_words(gold_ud.words, system_ud.words)
|
||||
|
||||
# Compute the F1-scores
|
||||
result = {
|
||||
"Tokens": spans_score(gold_ud.tokens, system_ud.tokens),
|
||||
"Sentences": spans_score(gold_ud.sentences, system_ud.sentences),
|
||||
"Words": alignment_score(alignment, None),
|
||||
"UPOS": alignment_score(alignment, lambda w, parent: w.columns[UPOS]),
|
||||
"XPOS": alignment_score(alignment, lambda w, parent: w.columns[XPOS]),
|
||||
"Feats": alignment_score(alignment, lambda w, parent: w.columns[FEATS]),
|
||||
"AllTags": alignment_score(alignment, lambda w, parent: (w.columns[UPOS], w.columns[XPOS], w.columns[FEATS])),
|
||||
"Lemmas": alignment_score(alignment, lambda w, parent: w.columns[LEMMA]),
|
||||
"UAS": alignment_score(alignment, lambda w, parent: parent),
|
||||
"LAS": alignment_score(alignment, lambda w, parent: (parent, w.columns[DEPREL])),
|
||||
}
|
||||
|
||||
# Add WeightedLAS if weights are given
|
||||
if deprel_weights is not None:
|
||||
def weighted_las(word):
|
||||
return deprel_weights.get(word.columns[DEPREL], 1.0)
|
||||
result["WeightedLAS"] = alignment_score(alignment, lambda w, parent: (parent, w.columns[DEPREL]), weighted_las)
|
||||
|
||||
return result
|
||||
|
||||
def load_deprel_weights(weights_file):
|
||||
if weights_file is None:
|
||||
return None
|
||||
|
||||
deprel_weights = {}
|
||||
for line in weights_file:
|
||||
# Ignore comments and empty lines
|
||||
if line.startswith("#") or not line.strip():
|
||||
continue
|
||||
|
||||
columns = line.rstrip("\r\n").split()
|
||||
if len(columns) != 2:
|
||||
raise ValueError("Expected two columns in the UD Relations weights file on line '{}'".format(line))
|
||||
|
||||
deprel_weights[columns[0]] = float(columns[1])
|
||||
|
||||
return deprel_weights
|
||||
|
||||
def load_conllu_file(path):
|
||||
_file = open(path, mode="r", **({"encoding": "utf-8"} if sys.version_info >= (3, 0) else {}))
|
||||
return load_conllu(_file)
|
||||
|
||||
def evaluate_wrapper(args):
|
||||
# Load CoNLL-U files
|
||||
gold_ud = load_conllu_file(args.gold_file)
|
||||
system_ud = load_conllu_file(args.system_file)
|
||||
|
||||
# Load weights if requested
|
||||
deprel_weights = load_deprel_weights(args.weights)
|
||||
|
||||
return evaluate(gold_ud, system_ud, deprel_weights)
|
||||
|
||||
def main():
|
||||
# Parse arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("gold_file", type=str,
|
||||
help="Name of the CoNLL-U file with the gold data.")
|
||||
parser.add_argument("system_file", type=str,
|
||||
help="Name of the CoNLL-U file with the predicted data.")
|
||||
parser.add_argument("--weights", "-w", type=argparse.FileType("r"), default=None,
|
||||
metavar="deprel_weights_file",
|
||||
help="Compute WeightedLAS using given weights for Universal Dependency Relations.")
|
||||
parser.add_argument("--verbose", "-v", default=0, action="count",
|
||||
help="Print all metrics.")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Use verbose if weights are supplied
|
||||
if args.weights is not None and not args.verbose:
|
||||
args.verbose = 1
|
||||
|
||||
# Evaluate
|
||||
evaluation = evaluate_wrapper(args)
|
||||
|
||||
# Print the evaluation
|
||||
if not args.verbose:
|
||||
print("LAS F1 Score: {:.2f}".format(100 * evaluation["LAS"].f1))
|
||||
else:
|
||||
metrics = ["Tokens", "Sentences", "Words", "UPOS", "XPOS", "Feats", "AllTags", "Lemmas", "UAS", "LAS"]
|
||||
if args.weights is not None:
|
||||
metrics.append("WeightedLAS")
|
||||
|
||||
print("Metrics | Precision | Recall | F1 Score | AligndAcc")
|
||||
print("-----------+-----------+-----------+-----------+-----------")
|
||||
for metric in metrics:
|
||||
print("{:11}|{:10.2f} |{:10.2f} |{:10.2f} |{}".format(
|
||||
metric,
|
||||
100 * evaluation[metric].precision,
|
||||
100 * evaluation[metric].recall,
|
||||
100 * evaluation[metric].f1,
|
||||
"{:10.2f}".format(100 * evaluation[metric].aligned_accuracy) if evaluation[metric].aligned_accuracy is not None else ""
|
||||
))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
# Tests, which can be executed with `python -m unittest conll17_ud_eval`.
|
||||
class TestAlignment(unittest.TestCase):
|
||||
@staticmethod
|
||||
def _load_words(words):
|
||||
"""Prepare fake CoNLL-U files with fake HEAD to prevent multiple roots errors."""
|
||||
lines, num_words = [], 0
|
||||
for w in words:
|
||||
parts = w.split(" ")
|
||||
if len(parts) == 1:
|
||||
num_words += 1
|
||||
lines.append("{}\t{}\t_\t_\t_\t_\t{}\t_\t_\t_".format(num_words, parts[0], int(num_words>1)))
|
||||
else:
|
||||
lines.append("{}-{}\t{}\t_\t_\t_\t_\t_\t_\t_\t_".format(num_words + 1, num_words + len(parts) - 1, parts[0]))
|
||||
for part in parts[1:]:
|
||||
num_words += 1
|
||||
lines.append("{}\t{}\t_\t_\t_\t_\t{}\t_\t_\t_".format(num_words, part, int(num_words>1)))
|
||||
return load_conllu((io.StringIO if sys.version_info >= (3, 0) else io.BytesIO)("\n".join(lines+["\n"])))
|
||||
|
||||
def _test_exception(self, gold, system):
|
||||
self.assertRaises(UDError, evaluate, self._load_words(gold), self._load_words(system))
|
||||
|
||||
def _test_ok(self, gold, system, correct):
|
||||
metrics = evaluate(self._load_words(gold), self._load_words(system))
|
||||
gold_words = sum((max(1, len(word.split(" ")) - 1) for word in gold))
|
||||
system_words = sum((max(1, len(word.split(" ")) - 1) for word in system))
|
||||
self.assertEqual((metrics["Words"].precision, metrics["Words"].recall, metrics["Words"].f1),
|
||||
(correct / system_words, correct / gold_words, 2 * correct / (gold_words + system_words)))
|
||||
|
||||
def test_exception(self):
|
||||
self._test_exception(["a"], ["b"])
|
||||
|
||||
def test_equal(self):
|
||||
self._test_ok(["a"], ["a"], 1)
|
||||
self._test_ok(["a", "b", "c"], ["a", "b", "c"], 3)
|
||||
|
||||
def test_equal_with_multiword(self):
|
||||
self._test_ok(["abc a b c"], ["a", "b", "c"], 3)
|
||||
self._test_ok(["a", "bc b c", "d"], ["a", "b", "c", "d"], 4)
|
||||
self._test_ok(["abcd a b c d"], ["ab a b", "cd c d"], 4)
|
||||
self._test_ok(["abc a b c", "de d e"], ["a", "bcd b c d", "e"], 5)
|
||||
|
||||
def test_alignment(self):
|
||||
self._test_ok(["abcd"], ["a", "b", "c", "d"], 0)
|
||||
self._test_ok(["abc", "d"], ["a", "b", "c", "d"], 1)
|
||||
self._test_ok(["a", "bc", "d"], ["a", "b", "c", "d"], 2)
|
||||
self._test_ok(["a", "bc b c", "d"], ["a", "b", "cd"], 2)
|
||||
self._test_ok(["abc a BX c", "def d EX f"], ["ab a b", "cd c d", "ef e f"], 4)
|
||||
self._test_ok(["ab a b", "cd bc d"], ["a", "bc", "d"], 2)
|
||||
self._test_ok(["a", "bc b c", "d"], ["ab AX BX", "cd CX a"], 1)
|
394
spacy/cli/ud_train.py
Normal file
394
spacy/cli/ud_train.py
Normal file
|
@ -0,0 +1,394 @@
|
|||
'''Train for CONLL 2017 UD treebank evaluation. Takes .conllu files, writes
|
||||
.conllu format for development data, allowing the official scorer to be used.
|
||||
'''
|
||||
from __future__ import unicode_literals
|
||||
import plac
|
||||
import tqdm
|
||||
import attr
|
||||
from pathlib import Path
|
||||
import re
|
||||
import sys
|
||||
import json
|
||||
|
||||
import spacy
|
||||
import spacy.util
|
||||
from ..tokens import Token, Doc
|
||||
from ..gold import GoldParse
|
||||
from ..syntax.nonproj import projectivize
|
||||
from ..matcher import Matcher
|
||||
from collections import defaultdict, Counter
|
||||
from timeit import default_timer as timer
|
||||
|
||||
import itertools
|
||||
import random
|
||||
import numpy.random
|
||||
import cytoolz
|
||||
|
||||
from . import conll17_ud_eval
|
||||
|
||||
from .. import lang
|
||||
from .. import lang
|
||||
from ..lang import zh
|
||||
from ..lang import ja
|
||||
|
||||
lang.zh.Chinese.Defaults.use_jieba = False
|
||||
lang.ja.Japanese.Defaults.use_janome = False
|
||||
|
||||
random.seed(0)
|
||||
numpy.random.seed(0)
|
||||
|
||||
def minibatch_by_words(items, size=5000):
|
||||
random.shuffle(items)
|
||||
if isinstance(size, int):
|
||||
size_ = itertools.repeat(size)
|
||||
else:
|
||||
size_ = size
|
||||
items = iter(items)
|
||||
while True:
|
||||
batch_size = next(size_)
|
||||
batch = []
|
||||
while batch_size >= 0:
|
||||
try:
|
||||
doc, gold = next(items)
|
||||
except StopIteration:
|
||||
if batch:
|
||||
yield batch
|
||||
return
|
||||
batch_size -= len(doc)
|
||||
batch.append((doc, gold))
|
||||
if batch:
|
||||
yield batch
|
||||
else:
|
||||
break
|
||||
|
||||
################
|
||||
# Data reading #
|
||||
################
|
||||
|
||||
space_re = re.compile('\s+')
|
||||
def split_text(text):
|
||||
return [space_re.sub(' ', par.strip()) for par in text.split('\n\n')]
|
||||
|
||||
|
||||
def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
|
||||
max_doc_length=None, limit=None):
|
||||
'''Read the CONLLU format into (Doc, GoldParse) tuples. If raw_text=True,
|
||||
include Doc objects created using nlp.make_doc and then aligned against
|
||||
the gold-standard sequences. If oracle_segments=True, include Doc objects
|
||||
created from the gold-standard segments. At least one must be True.'''
|
||||
if not raw_text and not oracle_segments:
|
||||
raise ValueError("At least one of raw_text or oracle_segments must be True")
|
||||
paragraphs = split_text(text_file.read())
|
||||
conllu = read_conllu(conllu_file)
|
||||
# sd is spacy doc; cd is conllu doc
|
||||
# cs is conllu sent, ct is conllu token
|
||||
docs = []
|
||||
golds = []
|
||||
for doc_id, (text, cd) in enumerate(zip(paragraphs, conllu)):
|
||||
sent_annots = []
|
||||
for cs in cd:
|
||||
sent = defaultdict(list)
|
||||
for id_, word, lemma, pos, tag, morph, head, dep, _, space_after in cs:
|
||||
if '.' in id_:
|
||||
continue
|
||||
if '-' in id_:
|
||||
continue
|
||||
id_ = int(id_)-1
|
||||
head = int(head)-1 if head != '0' else id_
|
||||
sent['words'].append(word)
|
||||
sent['tags'].append(tag)
|
||||
sent['heads'].append(head)
|
||||
sent['deps'].append('ROOT' if dep == 'root' else dep)
|
||||
sent['spaces'].append(space_after == '_')
|
||||
sent['entities'] = ['-'] * len(sent['words'])
|
||||
sent['heads'], sent['deps'] = projectivize(sent['heads'],
|
||||
sent['deps'])
|
||||
if oracle_segments:
|
||||
docs.append(Doc(nlp.vocab, words=sent['words'], spaces=sent['spaces']))
|
||||
golds.append(GoldParse(docs[-1], **sent))
|
||||
|
||||
sent_annots.append(sent)
|
||||
if raw_text and max_doc_length and len(sent_annots) >= max_doc_length:
|
||||
doc, gold = _make_gold(nlp, None, sent_annots)
|
||||
sent_annots = []
|
||||
docs.append(doc)
|
||||
golds.append(gold)
|
||||
if limit and len(docs) >= limit:
|
||||
return docs, golds
|
||||
|
||||
if raw_text and sent_annots:
|
||||
doc, gold = _make_gold(nlp, None, sent_annots)
|
||||
docs.append(doc)
|
||||
golds.append(gold)
|
||||
if limit and len(docs) >= limit:
|
||||
return docs, golds
|
||||
return docs, golds
|
||||
|
||||
|
||||
def read_conllu(file_):
|
||||
docs = []
|
||||
sent = []
|
||||
doc = []
|
||||
for line in file_:
|
||||
if line.startswith('# newdoc'):
|
||||
if doc:
|
||||
docs.append(doc)
|
||||
doc = []
|
||||
elif line.startswith('#'):
|
||||
continue
|
||||
elif not line.strip():
|
||||
if sent:
|
||||
doc.append(sent)
|
||||
sent = []
|
||||
else:
|
||||
sent.append(list(line.strip().split('\t')))
|
||||
if len(sent[-1]) != 10:
|
||||
print(repr(line))
|
||||
raise ValueError
|
||||
if sent:
|
||||
doc.append(sent)
|
||||
if doc:
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
|
||||
def _make_gold(nlp, text, sent_annots):
|
||||
# Flatten the conll annotations, and adjust the head indices
|
||||
flat = defaultdict(list)
|
||||
for sent in sent_annots:
|
||||
flat['heads'].extend(len(flat['words'])+head for head in sent['heads'])
|
||||
for field in ['words', 'tags', 'deps', 'entities', 'spaces']:
|
||||
flat[field].extend(sent[field])
|
||||
# Construct text if necessary
|
||||
assert len(flat['words']) == len(flat['spaces'])
|
||||
if text is None:
|
||||
text = ''.join(word+' '*space for word, space in zip(flat['words'], flat['spaces']))
|
||||
doc = nlp.make_doc(text)
|
||||
flat.pop('spaces')
|
||||
gold = GoldParse(doc, **flat)
|
||||
return doc, gold
|
||||
|
||||
#############################
|
||||
# Data transforms for spaCy #
|
||||
#############################
|
||||
|
||||
def golds_to_gold_tuples(docs, golds):
|
||||
'''Get out the annoying 'tuples' format used by begin_training, given the
|
||||
GoldParse objects.'''
|
||||
tuples = []
|
||||
for doc, gold in zip(docs, golds):
|
||||
text = doc.text
|
||||
ids, words, tags, heads, labels, iob = zip(*gold.orig_annot)
|
||||
sents = [((ids, words, tags, heads, labels, iob), [])]
|
||||
tuples.append((text, sents))
|
||||
return tuples
|
||||
|
||||
|
||||
##############
|
||||
# Evaluation #
|
||||
##############
|
||||
|
||||
def evaluate(nlp, text_loc, gold_loc, sys_loc, limit=None):
|
||||
with text_loc.open('r', encoding='utf8') as text_file:
|
||||
texts = split_text(text_file.read())
|
||||
docs = list(nlp.pipe(texts))
|
||||
with sys_loc.open('w', encoding='utf8') as out_file:
|
||||
write_conllu(docs, out_file)
|
||||
with gold_loc.open('r', encoding='utf8') as gold_file:
|
||||
gold_ud = conll17_ud_eval.load_conllu(gold_file)
|
||||
with sys_loc.open('r', encoding='utf8') as sys_file:
|
||||
sys_ud = conll17_ud_eval.load_conllu(sys_file)
|
||||
scores = conll17_ud_eval.evaluate(gold_ud, sys_ud)
|
||||
return scores
|
||||
|
||||
|
||||
def write_conllu(docs, file_):
|
||||
merger = Matcher(docs[0].vocab)
|
||||
merger.add('SUBTOK', None, [{'DEP': 'subtok', 'op': '+'}])
|
||||
for i, doc in enumerate(docs):
|
||||
matches = merger(doc)
|
||||
spans = [doc[start:end+1] for _, start, end in matches]
|
||||
offsets = [(span.start_char, span.end_char) for span in spans]
|
||||
for start_char, end_char in offsets:
|
||||
doc.merge(start_char, end_char)
|
||||
file_.write("# newdoc id = {i}\n".format(i=i))
|
||||
for j, sent in enumerate(doc.sents):
|
||||
file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j))
|
||||
file_.write("# text = {text}\n".format(text=sent.text))
|
||||
for k, token in enumerate(sent):
|
||||
file_.write(token._.get_conllu_lines(k) + '\n')
|
||||
file_.write('\n')
|
||||
|
||||
|
||||
def print_progress(itn, losses, ud_scores):
|
||||
fields = {
|
||||
'dep_loss': losses.get('parser', 0.0),
|
||||
'tag_loss': losses.get('tagger', 0.0),
|
||||
'words': ud_scores['Words'].f1 * 100,
|
||||
'sents': ud_scores['Sentences'].f1 * 100,
|
||||
'tags': ud_scores['XPOS'].f1 * 100,
|
||||
'uas': ud_scores['UAS'].f1 * 100,
|
||||
'las': ud_scores['LAS'].f1 * 100,
|
||||
}
|
||||
header = ['Epoch', 'Loss', 'LAS', 'UAS', 'TAG', 'SENT', 'WORD']
|
||||
if itn == 0:
|
||||
print('\t'.join(header))
|
||||
tpl = '\t'.join((
|
||||
'{:d}',
|
||||
'{dep_loss:.1f}',
|
||||
'{las:.1f}',
|
||||
'{uas:.1f}',
|
||||
'{tags:.1f}',
|
||||
'{sents:.1f}',
|
||||
'{words:.1f}',
|
||||
))
|
||||
print(tpl.format(itn, **fields))
|
||||
|
||||
#def get_sent_conllu(sent, sent_id):
|
||||
# lines = ["# sent_id = {sent_id}".format(sent_id=sent_id)]
|
||||
|
||||
def get_token_conllu(token, i):
|
||||
if token._.begins_fused:
|
||||
n = 1
|
||||
while token.nbor(n)._.inside_fused:
|
||||
n += 1
|
||||
id_ = '%d-%d' % (i, i+n)
|
||||
lines = [id_, token.text, '_', '_', '_', '_', '_', '_', '_', '_']
|
||||
else:
|
||||
lines = []
|
||||
if token.head.i == token.i:
|
||||
head = 0
|
||||
else:
|
||||
head = i + (token.head.i - token.i) + 1
|
||||
fields = [str(i+1), token.text, token.lemma_, token.pos_, token.tag_, '_',
|
||||
str(head), token.dep_.lower(), '_', '_']
|
||||
lines.append('\t'.join(fields))
|
||||
return '\n'.join(lines)
|
||||
|
||||
Token.set_extension('get_conllu_lines', method=get_token_conllu)
|
||||
Token.set_extension('begins_fused', default=False)
|
||||
Token.set_extension('inside_fused', default=False)
|
||||
|
||||
|
||||
##################
|
||||
# Initialization #
|
||||
##################
|
||||
|
||||
|
||||
def load_nlp(corpus, config):
|
||||
lang = corpus.split('_')[0]
|
||||
nlp = spacy.blank(lang)
|
||||
if config.vectors:
|
||||
nlp.vocab.from_disk(config.vectors / 'vocab')
|
||||
return nlp
|
||||
|
||||
def initialize_pipeline(nlp, docs, golds, config):
|
||||
nlp.add_pipe(nlp.create_pipe('parser'))
|
||||
if config.multitask_tag:
|
||||
nlp.parser.add_multitask_objective('tag')
|
||||
if config.multitask_sent:
|
||||
nlp.parser.add_multitask_objective('sent_start')
|
||||
nlp.parser.moves.add_action(2, 'subtok')
|
||||
nlp.add_pipe(nlp.create_pipe('tagger'))
|
||||
for gold in golds:
|
||||
for tag in gold.tags:
|
||||
if tag is not None:
|
||||
nlp.tagger.add_label(tag)
|
||||
# Replace labels that didn't make the frequency cutoff
|
||||
actions = set(nlp.parser.labels)
|
||||
label_set = set([act.split('-')[1] for act in actions if '-' in act])
|
||||
for gold in golds:
|
||||
for i, label in enumerate(gold.labels):
|
||||
if label is not None and label not in label_set:
|
||||
gold.labels[i] = label.split('||')[0]
|
||||
return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds))
|
||||
|
||||
|
||||
########################
|
||||
# Command line helpers #
|
||||
########################
|
||||
|
||||
@attr.s
|
||||
class Config(object):
|
||||
vectors = attr.ib(default=None)
|
||||
max_doc_length = attr.ib(default=10)
|
||||
multitask_tag = attr.ib(default=True)
|
||||
multitask_sent = attr.ib(default=True)
|
||||
nr_epoch = attr.ib(default=30)
|
||||
batch_size = attr.ib(default=1000)
|
||||
dropout = attr.ib(default=0.2)
|
||||
|
||||
@classmethod
|
||||
def load(cls, loc):
|
||||
with Path(loc).open('r', encoding='utf8') as file_:
|
||||
cfg = json.load(file_)
|
||||
return cls(**cfg)
|
||||
|
||||
|
||||
class Dataset(object):
|
||||
def __init__(self, path, section):
|
||||
self.path = path
|
||||
self.section = section
|
||||
self.conllu = None
|
||||
self.text = None
|
||||
for file_path in self.path.iterdir():
|
||||
name = file_path.parts[-1]
|
||||
if section in name and name.endswith('conllu'):
|
||||
self.conllu = file_path
|
||||
elif section in name and name.endswith('txt'):
|
||||
self.text = file_path
|
||||
if self.conllu is None:
|
||||
msg = "Could not find .txt file in {path} for {section}"
|
||||
raise IOError(msg.format(section=section, path=path))
|
||||
if self.text is None:
|
||||
msg = "Could not find .txt file in {path} for {section}"
|
||||
self.lang = self.conllu.parts[-1].split('-')[0].split('_')[0]
|
||||
|
||||
|
||||
class TreebankPaths(object):
|
||||
def __init__(self, ud_path, treebank, **cfg):
|
||||
self.train = Dataset(ud_path / treebank, 'train')
|
||||
self.dev = Dataset(ud_path / treebank, 'dev')
|
||||
self.lang = self.train.lang
|
||||
|
||||
|
||||
@plac.annotations(
|
||||
ud_dir=("Path to Universal Dependencies corpus", "positional", None, Path),
|
||||
corpus=("UD corpus to train and evaluate on, e.g. en, es_ancora, etc",
|
||||
"positional", None, str),
|
||||
parses_dir=("Directory to write the development parses", "positional", None, Path),
|
||||
config=("Path to json formatted config file", "positional", None, Config.load),
|
||||
limit=("Size limit", "option", "n", int)
|
||||
)
|
||||
def main(ud_dir, parses_dir, config, corpus, limit=0):
|
||||
paths = TreebankPaths(ud_dir, corpus)
|
||||
if not (parses_dir / corpus).exists():
|
||||
(parses_dir / corpus).mkdir()
|
||||
print("Train and evaluate", corpus, "using lang", paths.lang)
|
||||
nlp = load_nlp(paths.lang, config)
|
||||
|
||||
docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
|
||||
max_doc_length=config.max_doc_length, limit=limit)
|
||||
|
||||
optimizer = initialize_pipeline(nlp, docs, golds, config)
|
||||
|
||||
for i in range(config.nr_epoch):
|
||||
docs = [nlp.make_doc(doc.text) for doc in docs]
|
||||
batches = minibatch_by_words(list(zip(docs, golds)), size=config.batch_size)
|
||||
losses = {}
|
||||
n_train_words = sum(len(doc) for doc in docs)
|
||||
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
|
||||
for batch in batches:
|
||||
batch_docs, batch_gold = zip(*batch)
|
||||
pbar.update(sum(len(doc) for doc in batch_docs))
|
||||
nlp.update(batch_docs, batch_gold, sgd=optimizer,
|
||||
drop=config.dropout, losses=losses)
|
||||
|
||||
out_path = parses_dir / corpus / 'epoch-{i}.conllu'.format(i=i)
|
||||
with nlp.use_params(optimizer.averages):
|
||||
scores = evaluate(nlp, paths.dev.text, paths.dev.conllu, out_path)
|
||||
print_progress(i, losses, scores)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
plac.call(main)
|
Loading…
Reference in New Issue
Block a user