Add ud_train and ud_evaluate CLI commands

Matthew Honnibal 2018-03-10 23:41:55 +01:00
parent 4b72c38556
commit 3478ea76d1
2 changed files with 964 additions and 0 deletions

spacy/cli/conll17_ud_eval.py (new file, 570 lines)

@@ -0,0 +1,570 @@
#!/usr/bin/env python
# CoNLL 2017 UD Parsing evaluation script.
#
# Compatible with Python 2.7 and 3.2+, can be used either as a module
# or a standalone executable.
#
# Copyright 2017 Institute of Formal and Applied Linguistics (UFAL),
# Faculty of Mathematics and Physics, Charles University, Czech Republic.
#
# Changelog:
# - [02 Jan 2017] Version 0.9: Initial release
# - [25 Jan 2017] Version 0.9.1: Fix bug in LCS alignment computation
# - [10 Mar 2017] Version 1.0: Add documentation and test
# Compare HEADs correctly using aligned words
# Allow evaluation with erroneous spaces in forms
# Compare forms in LCS case insensitively
# Detect cycles and multiple root nodes
# Compute AlignedAccuracy
# Command line usage
# ------------------
# conll17_ud_eval.py [-v] [-w weights_file] gold_conllu_file system_conllu_file
#
# - if no -v is given, only the CoNLL17 UD Shared Task evaluation LAS metric
# is printed
# - if -v is given, several metrics are printed (as precision, recall, F1 score,
# and, in case the metric is computed on aligned words, also accuracy on these):
# - Tokens: how well do the gold tokens match system tokens
# - Sentences: how well do the gold sentences match system sentences
# - Words: how well can the gold words be aligned to system words
# - UPOS: using aligned words, how well does UPOS match
# - XPOS: using aligned words, how well does XPOS match
# - Feats: using aligned words, how well does FEATS match
# - AllTags: using aligned words, how well does UPOS+XPOS+FEATS match
# - Lemmas: using aligned words, how well does LEMMA match
# - UAS: using aligned words, how well does HEAD match
# - LAS: using aligned words, how well does HEAD+DEPREL(ignoring subtypes) match
# - if weights_file is given (with lines containing deprel-weight pairs),
# one more metric is shown:
# - WeightedLAS: as LAS, but each deprel (ignoring subtypes) has a different weight
# API usage
# ---------
# - load_conllu(file)
# - loads CoNLL-U file from given file object to an internal representation
# - the file object should return str on both Python 2 and Python 3
# - raises UDError exception if the given file cannot be loaded
# - evaluate(gold_ud, system_ud)
# - evaluate the given gold and system CoNLL-U files (loaded with load_conllu)
# - raises UDError if the concatenated tokens of gold and system file do not match
# - returns a dictionary with the metrics described above, each metric having
# three fields: precision, recall and f1
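# - example (an illustrative sketch, not part of the original documentation;
#   the file names are placeholders and the module's load_conllu_file() helper
#   defined below is used to open the files):
#
#       gold_ud = load_conllu_file("gold.conllu")
#       system_ud = load_conllu_file("system.conllu")
#       scores = evaluate(gold_ud, system_ud)
#       print("LAS F1 = {:.2f}".format(100 * scores["LAS"].f1))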
# Description of token matching
# -----------------------------
# In order to match tokens of gold file and system file, we consider the text
# resulting from concatenation of gold tokens and text resulting from
# concatenation of system tokens. These texts should match -- if they do not,
# the evaluation fails.
#
# If the texts do match, every token is represented as a range in this original
# text, and tokens are equal only if their range is the same.
# Description of word matching
# ----------------------------
# When matching words of gold file and system file, we first match the tokens.
# The words which are also tokens are matched as tokens, but words in multi-word
# tokens have to be handled differently.
#
# To handle multi-word tokens, we start by finding "multi-word spans".
# A multi-word span is a span in the original text such that
# - it contains at least one multi-word token
# - all multi-word tokens in the span (considering both gold and system ones)
# are completely inside the span (i.e., they do not "stick out")
# - the multi-word span is as small as possible
#
# For every multi-word span, we align the gold and system words completely
# inside this span using LCS on their FORMs. The words not intersecting
# (even partially) any multi-word span are then aligned as tokens.
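# Worked example (added for illustration; it mirrors the unit tests at the end
# of this file): if the gold data contains the multi-word token "abc" analysed
# as the words "a", "b" and "c", and the system outputs the plain words "a",
# "b" and "c", the token "abc" forms one multi-word span, the LCS over the
# lowercased FORMs aligns all three word pairs, and the Words score comes out
# as precision = recall = F1 = 1.0.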
from __future__ import division
from __future__ import print_function
import argparse
import io
import sys
import unittest
# CoNLL-U column names
ID, FORM, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC = range(10)
# UD Error is used when raising exceptions in this module
class UDError(Exception):
pass
# Load given CoNLL-U file into internal representation
def load_conllu(file):
# Internal representation classes
class UDRepresentation:
def __init__(self):
# Characters of all the tokens in the whole file.
# Whitespace between tokens is not included.
self.characters = []
# List of UDSpan instances with start&end indices into `characters`.
self.tokens = []
# List of UDWord instances.
self.words = []
# List of UDSpan instances with start&end indices into `characters`.
self.sentences = []
class UDSpan:
def __init__(self, start, end, characters):
self.start = start
# Note that self.end marks the first position **after the end** of span,
# so we can use characters[start:end] or range(start, end).
self.end = end
self.characters = characters
@property
def text(self):
return ''.join(self.characters[self.start:self.end])
def __str__(self):
return self.text
def __repr__(self):
return self.text
class UDWord:
def __init__(self, span, columns, is_multiword):
# Span of this word (or MWT, see below) within ud_representation.characters.
self.span = span
# 10 columns of the CoNLL-U file: ID, FORM, LEMMA,...
self.columns = columns
# is_multiword==True means that this word is part of a multi-word token.
# In that case, self.span marks the span of the whole multi-word token.
self.is_multiword = is_multiword
# Reference to the UDWord instance representing the HEAD (or None if root).
self.parent = None
# Let's ignore language-specific deprel subtypes.
self.columns[DEPREL] = columns[DEPREL].split(':')[0]
ud = UDRepresentation()
# Load the CoNLL-U file
index, sentence_start = 0, None
linenum = 0
while True:
line = file.readline()
linenum += 1
if not line:
break
line = line.rstrip("\r\n")
# Handle sentence start boundaries
if sentence_start is None:
# Skip comments
if line.startswith("#"):
continue
# Start a new sentence
ud.sentences.append(UDSpan(index, 0, ud.characters))
sentence_start = len(ud.words)
if not line:
# Add parent UDWord links and check there are no cycles
def process_word(word):
if word.parent == "remapping":
raise UDError("There is a cycle in a sentence")
if word.parent is None:
head = int(word.columns[HEAD])
if head > len(ud.words) - sentence_start:
raise UDError("HEAD '{}' points outside of the sentence".format(word.columns[HEAD]))
if head:
parent = ud.words[sentence_start + head - 1]
word.parent = "remapping"
process_word(parent)
word.parent = parent
for word in ud.words[sentence_start:]:
process_word(word)
# Check there is a single root node
if len([word for word in ud.words[sentence_start:] if word.parent is None]) != 1:
raise UDError("There are multiple roots in a sentence")
# End the sentence
ud.sentences[-1].end = index
sentence_start = None
continue
# Read next token/word
columns = line.split("\t")
if len(columns) != 10:
raise UDError("The CoNLL-U line {} does not contain 10 tab-separated columns: '{}'".format(linenum, line))
# Skip empty nodes
if "." in columns[ID]:
continue
# Delete spaces from FORM so gold.characters == system.characters
# even if one of them tokenizes the space.
columns[FORM] = columns[FORM].replace(" ", "")
if not columns[FORM]:
raise UDError("There is an empty FORM in the CoNLL-U file -- line %d" % linenum)
# Save token
ud.characters.extend(columns[FORM])
ud.tokens.append(UDSpan(index, index + len(columns[FORM]), ud.characters))
index += len(columns[FORM])
# Handle multi-word tokens to save word(s)
if "-" in columns[ID]:
try:
start, end = map(int, columns[ID].split("-"))
except:
raise UDError("Cannot parse multi-word token ID '{}'".format(columns[ID]))
for _ in range(start, end + 1):
word_line = file.readline().rstrip("\r\n")
linenum += 1  # keep the reported line number in sync for error messages
word_columns = word_line.split("\t")
if len(word_columns) != 10:
print(columns)
raise UDError("The CoNLL-U line {} does not contain 10 tab-separated columns: '{}'".format(linenum, word_line))
ud.words.append(UDWord(ud.tokens[-1], word_columns, is_multiword=True))
# Basic tokens/words
else:
try:
word_id = int(columns[ID])
except:
raise UDError("Cannot parse word ID '{}'".format(columns[ID]))
if word_id != len(ud.words) - sentence_start + 1:
raise UDError("Incorrect word ID '{}' for word '{}', expected '{}'".format(columns[ID], columns[FORM], len(ud.words) - sentence_start + 1))
try:
head_id = int(columns[HEAD])
except:
raise UDError("Cannot parse HEAD '{}'".format(columns[HEAD]))
if head_id < 0:
raise UDError("HEAD cannot be negative")
ud.words.append(UDWord(ud.tokens[-1], columns, is_multiword=False))
if sentence_start is not None:
raise UDError("The CoNLL-U file does not end with empty line")
return ud
# Evaluate the gold and system treebanks (loaded using load_conllu).
def evaluate(gold_ud, system_ud, deprel_weights=None):
class Score:
def __init__(self, gold_total, system_total, correct, aligned_total=None):
self.precision = correct / system_total if system_total else 0.0
self.recall = correct / gold_total if gold_total else 0.0
self.f1 = 2 * correct / (system_total + gold_total) if system_total + gold_total else 0.0
self.aligned_accuracy = correct / aligned_total if aligned_total else aligned_total
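# Worked example (illustrative numbers): gold_total=10, system_total=8,
# correct=6 gives precision = 6/8 = 0.75, recall = 6/10 = 0.60 and
# f1 = 2*6/(8+10) ~= 0.667; aligned_accuracy stays None for the span-based
# metrics, which pass no aligned_total.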
class AlignmentWord:
def __init__(self, gold_word, system_word):
self.gold_word = gold_word
self.system_word = system_word
self.gold_parent = None
self.system_parent_gold_aligned = None
class Alignment:
def __init__(self, gold_words, system_words):
self.gold_words = gold_words
self.system_words = system_words
self.matched_words = []
self.matched_words_map = {}
def append_aligned_words(self, gold_word, system_word):
self.matched_words.append(AlignmentWord(gold_word, system_word))
self.matched_words_map[system_word] = gold_word
def fill_parents(self):
# We represent root parents in both gold and system data by '0'.
# For gold data, we represent non-root parent by the corresponding gold word.
# For system data, we represent non-root parent by either the gold word aligned
# to the parent system node, or by None if no gold word is aligned to the parent.
for words in self.matched_words:
words.gold_parent = words.gold_word.parent if words.gold_word.parent is not None else 0
words.system_parent_gold_aligned = self.matched_words_map.get(words.system_word.parent, None) \
if words.system_word.parent is not None else 0
def lower(text):
if sys.version_info < (3, 0) and isinstance(text, str):
return text.decode("utf-8").lower()
return text.lower()
def spans_score(gold_spans, system_spans):
correct, gi, si = 0, 0, 0
while gi < len(gold_spans) and si < len(system_spans):
if system_spans[si].start < gold_spans[gi].start:
si += 1
elif gold_spans[gi].start < system_spans[si].start:
gi += 1
else:
correct += gold_spans[gi].end == system_spans[si].end
si += 1
gi += 1
return Score(len(gold_spans), len(system_spans), correct)
def alignment_score(alignment, key_fn, weight_fn=lambda w: 1):
gold, system, aligned, correct = 0, 0, 0, 0
for word in alignment.gold_words:
gold += weight_fn(word)
for word in alignment.system_words:
system += weight_fn(word)
for words in alignment.matched_words:
aligned += weight_fn(words.gold_word)
if key_fn is None:
# Return score for whole aligned words
return Score(gold, system, aligned)
for words in alignment.matched_words:
if key_fn(words.gold_word, words.gold_parent) == key_fn(words.system_word, words.system_parent_gold_aligned):
correct += weight_fn(words.gold_word)
return Score(gold, system, correct, aligned)
def beyond_end(words, i, multiword_span_end):
if i >= len(words):
return True
if words[i].is_multiword:
return words[i].span.start >= multiword_span_end
return words[i].span.end > multiword_span_end
def extend_end(word, multiword_span_end):
if word.is_multiword and word.span.end > multiword_span_end:
return word.span.end
return multiword_span_end
def find_multiword_span(gold_words, system_words, gi, si):
# We know gold_words[gi].is_multiword or system_words[si].is_multiword.
# Find the start of the multiword span (gs, ss), so the multiword span is minimal.
# Initialize multiword_span_end characters index.
if gold_words[gi].is_multiword:
multiword_span_end = gold_words[gi].span.end
if not system_words[si].is_multiword and system_words[si].span.start < gold_words[gi].span.start:
si += 1
else: # if system_words[si].is_multiword
multiword_span_end = system_words[si].span.end
if not gold_words[gi].is_multiword and gold_words[gi].span.start < system_words[si].span.start:
gi += 1
gs, ss = gi, si
# Find the end of the multiword span
# (so both gi and si are pointing to the word following the multiword span end).
while not beyond_end(gold_words, gi, multiword_span_end) or \
not beyond_end(system_words, si, multiword_span_end):
if gi < len(gold_words) and (si >= len(system_words) or
gold_words[gi].span.start <= system_words[si].span.start):
multiword_span_end = extend_end(gold_words[gi], multiword_span_end)
gi += 1
else:
multiword_span_end = extend_end(system_words[si], multiword_span_end)
si += 1
return gs, ss, gi, si
def compute_lcs(gold_words, system_words, gi, si, gs, ss):
lcs = [[0] * (si - ss) for i in range(gi - gs)]
for g in reversed(range(gi - gs)):
for s in reversed(range(si - ss)):
if lower(gold_words[gs + g].columns[FORM]) == lower(system_words[ss + s].columns[FORM]):
lcs[g][s] = 1 + (lcs[g+1][s+1] if g+1 < gi-gs and s+1 < si-ss else 0)
lcs[g][s] = max(lcs[g][s], lcs[g+1][s] if g+1 < gi-gs else 0)
lcs[g][s] = max(lcs[g][s], lcs[g][s+1] if s+1 < si-ss else 0)
return lcs
def align_words(gold_words, system_words):
alignment = Alignment(gold_words, system_words)
gi, si = 0, 0
while gi < len(gold_words) and si < len(system_words):
if gold_words[gi].is_multiword or system_words[si].is_multiword:
# A: Multi-word tokens => align via LCS within the whole "multiword span".
gs, ss, gi, si = find_multiword_span(gold_words, system_words, gi, si)
if si > ss and gi > gs:
lcs = compute_lcs(gold_words, system_words, gi, si, gs, ss)
# Store aligned words
s, g = 0, 0
while g < gi - gs and s < si - ss:
if lower(gold_words[gs + g].columns[FORM]) == lower(system_words[ss + s].columns[FORM]):
alignment.append_aligned_words(gold_words[gs+g], system_words[ss+s])
g += 1
s += 1
elif lcs[g][s] == (lcs[g+1][s] if g+1 < gi-gs else 0):
g += 1
else:
s += 1
else:
# B: No multi-word token => align according to spans.
if (gold_words[gi].span.start, gold_words[gi].span.end) == (system_words[si].span.start, system_words[si].span.end):
alignment.append_aligned_words(gold_words[gi], system_words[si])
gi += 1
si += 1
elif gold_words[gi].span.start <= system_words[si].span.start:
gi += 1
else:
si += 1
alignment.fill_parents()
return alignment
# Check that underlying character sequences do match
if gold_ud.characters != system_ud.characters:
index = 0
while gold_ud.characters[index] == system_ud.characters[index]:
index += 1
raise UDError(
"The concatenation of tokens in gold file and in system file differ!\n" +
"First 20 differing characters in gold file: '{}' and system file: '{}'".format(
"".join(gold_ud.characters[index:index + 20]),
"".join(system_ud.characters[index:index + 20])
)
)
# Align words
alignment = align_words(gold_ud.words, system_ud.words)
# Compute the F1-scores
result = {
"Tokens": spans_score(gold_ud.tokens, system_ud.tokens),
"Sentences": spans_score(gold_ud.sentences, system_ud.sentences),
"Words": alignment_score(alignment, None),
"UPOS": alignment_score(alignment, lambda w, parent: w.columns[UPOS]),
"XPOS": alignment_score(alignment, lambda w, parent: w.columns[XPOS]),
"Feats": alignment_score(alignment, lambda w, parent: w.columns[FEATS]),
"AllTags": alignment_score(alignment, lambda w, parent: (w.columns[UPOS], w.columns[XPOS], w.columns[FEATS])),
"Lemmas": alignment_score(alignment, lambda w, parent: w.columns[LEMMA]),
"UAS": alignment_score(alignment, lambda w, parent: parent),
"LAS": alignment_score(alignment, lambda w, parent: (parent, w.columns[DEPREL])),
}
# Add WeightedLAS if weights are given
if deprel_weights is not None:
def weighted_las(word):
return deprel_weights.get(word.columns[DEPREL], 1.0)
result["WeightedLAS"] = alignment_score(alignment, lambda w, parent: (parent, w.columns[DEPREL]), weighted_las)
return result
def load_deprel_weights(weights_file):
if weights_file is None:
return None
deprel_weights = {}
for line in weights_file:
# Ignore comments and empty lines
if line.startswith("#") or not line.strip():
continue
columns = line.rstrip("\r\n").split()
if len(columns) != 2:
raise ValueError("Expected two columns in the UD Relations weights file on line '{}'".format(line))
deprel_weights[columns[0]] = float(columns[1])
return deprel_weights
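# An illustrative weights file for the --weights option (the contents are made
# up): each non-comment line holds a deprel and its weight, separated by
# whitespace, e.g.
#
#     nsubj 2.0
#     punct 0.5
#
# Deprels not listed fall back to weight 1.0 (see weighted_las in evaluate).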
def load_conllu_file(path):
_file = open(path, mode="r", **({"encoding": "utf-8"} if sys.version_info >= (3, 0) else {}))
return load_conllu(_file)
def evaluate_wrapper(args):
# Load CoNLL-U files
gold_ud = load_conllu_file(args.gold_file)
system_ud = load_conllu_file(args.system_file)
# Load weights if requested
deprel_weights = load_deprel_weights(args.weights)
return evaluate(gold_ud, system_ud, deprel_weights)
def main():
# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument("gold_file", type=str,
help="Name of the CoNLL-U file with the gold data.")
parser.add_argument("system_file", type=str,
help="Name of the CoNLL-U file with the predicted data.")
parser.add_argument("--weights", "-w", type=argparse.FileType("r"), default=None,
metavar="deprel_weights_file",
help="Compute WeightedLAS using given weights for Universal Dependency Relations.")
parser.add_argument("--verbose", "-v", default=0, action="count",
help="Print all metrics.")
args = parser.parse_args()
# Use verbose if weights are supplied
if args.weights is not None and not args.verbose:
args.verbose = 1
# Evaluate
evaluation = evaluate_wrapper(args)
# Print the evaluation
if not args.verbose:
print("LAS F1 Score: {:.2f}".format(100 * evaluation["LAS"].f1))
else:
metrics = ["Tokens", "Sentences", "Words", "UPOS", "XPOS", "Feats", "AllTags", "Lemmas", "UAS", "LAS"]
if args.weights is not None:
metrics.append("WeightedLAS")
print("Metrics | Precision | Recall | F1 Score | AligndAcc")
print("-----------+-----------+-----------+-----------+-----------")
for metric in metrics:
print("{:11}|{:10.2f} |{:10.2f} |{:10.2f} |{}".format(
metric,
100 * evaluation[metric].precision,
100 * evaluation[metric].recall,
100 * evaluation[metric].f1,
"{:10.2f}".format(100 * evaluation[metric].aligned_accuracy) if evaluation[metric].aligned_accuracy is not None else ""
))
if __name__ == "__main__":
main()
# Tests, which can be executed with `python -m unittest conll17_ud_eval`.
class TestAlignment(unittest.TestCase):
@staticmethod
def _load_words(words):
"""Prepare fake CoNLL-U files with fake HEAD to prevent multiple roots errors."""
lines, num_words = [], 0
for w in words:
parts = w.split(" ")
if len(parts) == 1:
num_words += 1
lines.append("{}\t{}\t_\t_\t_\t_\t{}\t_\t_\t_".format(num_words, parts[0], int(num_words>1)))
else:
lines.append("{}-{}\t{}\t_\t_\t_\t_\t_\t_\t_\t_".format(num_words + 1, num_words + len(parts) - 1, parts[0]))
for part in parts[1:]:
num_words += 1
lines.append("{}\t{}\t_\t_\t_\t_\t{}\t_\t_\t_".format(num_words, part, int(num_words>1)))
return load_conllu((io.StringIO if sys.version_info >= (3, 0) else io.BytesIO)("\n".join(lines+["\n"])))
def _test_exception(self, gold, system):
self.assertRaises(UDError, evaluate, self._load_words(gold), self._load_words(system))
def _test_ok(self, gold, system, correct):
metrics = evaluate(self._load_words(gold), self._load_words(system))
gold_words = sum((max(1, len(word.split(" ")) - 1) for word in gold))
system_words = sum((max(1, len(word.split(" ")) - 1) for word in system))
self.assertEqual((metrics["Words"].precision, metrics["Words"].recall, metrics["Words"].f1),
(correct / system_words, correct / gold_words, 2 * correct / (gold_words + system_words)))
def test_exception(self):
self._test_exception(["a"], ["b"])
def test_equal(self):
self._test_ok(["a"], ["a"], 1)
self._test_ok(["a", "b", "c"], ["a", "b", "c"], 3)
def test_equal_with_multiword(self):
self._test_ok(["abc a b c"], ["a", "b", "c"], 3)
self._test_ok(["a", "bc b c", "d"], ["a", "b", "c", "d"], 4)
self._test_ok(["abcd a b c d"], ["ab a b", "cd c d"], 4)
self._test_ok(["abc a b c", "de d e"], ["a", "bcd b c d", "e"], 5)
def test_alignment(self):
self._test_ok(["abcd"], ["a", "b", "c", "d"], 0)
self._test_ok(["abc", "d"], ["a", "b", "c", "d"], 1)
self._test_ok(["a", "bc", "d"], ["a", "b", "c", "d"], 2)
self._test_ok(["a", "bc b c", "d"], ["a", "b", "cd"], 2)
self._test_ok(["abc a BX c", "def d EX f"], ["ab a b", "cd c d", "ef e f"], 4)
self._test_ok(["ab a b", "cd bc d"], ["a", "bc", "d"], 2)
self._test_ok(["a", "bc b c", "d"], ["ab AX BX", "cd CX a"], 1)

spacy/cli/ud_train.py (new file, 394 lines)

@@ -0,0 +1,394 @@
'''Train for CoNLL 2017 UD treebank evaluation. Takes .conllu files, writes
.conllu format for development data, allowing the official scorer to be used.
'''
from __future__ import unicode_literals
import plac
import tqdm
import attr
from pathlib import Path
import re
import sys
import json
import spacy
import spacy.util
from ..tokens import Token, Doc
from ..gold import GoldParse
from ..syntax.nonproj import projectivize
from ..matcher import Matcher
from collections import defaultdict, Counter
from timeit import default_timer as timer
import itertools
import random
import numpy.random
import cytoolz
from . import conll17_ud_eval
from .. import lang
from ..lang import zh
from ..lang import ja
lang.zh.Chinese.Defaults.use_jieba = False
lang.ja.Japanese.Defaults.use_janome = False
random.seed(0)
numpy.random.seed(0)
def minibatch_by_words(items, size=5000):
random.shuffle(items)
if isinstance(size, int):
size_ = itertools.repeat(size)
else:
size_ = size
items = iter(items)
while True:
batch_size = next(size_)
batch = []
while batch_size >= 0:
try:
doc, gold = next(items)
except StopIteration:
if batch:
yield batch
return
batch_size -= len(doc)
batch.append((doc, gold))
if batch:
yield batch
else:
break
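# Minimal usage sketch (assumed data; it mirrors the call in main() below): the
# generator shuffles (doc, gold) pairs and yields batches whose combined token
# count is roughly `size`:
#
#     for batch in minibatch_by_words(list(zip(docs, golds)), size=1000):
#         batch_docs, batch_golds = zip(*batch)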
################
# Data reading #
################
space_re = re.compile(r'\s+')
def split_text(text):
return [space_re.sub(' ', par.strip()) for par in text.split('\n\n')]
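# For example, split_text("A  sentence.\n\nAnother\nparagraph.") returns
# ['A sentence.', 'Another paragraph.']: paragraphs are split on double
# newlines and internal whitespace is collapsed to single spaces.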
def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
max_doc_length=None, limit=None):
'''Read the CoNLL-U format into (Doc, GoldParse) tuples. If raw_text=True,
include Doc objects created using nlp.make_doc and then aligned against
the gold-standard sequences. If oracle_segments=True, include Doc objects
created from the gold-standard segments. At least one must be True.'''
if not raw_text and not oracle_segments:
raise ValueError("At least one of raw_text or oracle_segments must be True")
paragraphs = split_text(text_file.read())
conllu = read_conllu(conllu_file)
# sd is spacy doc; cd is conllu doc
# cs is conllu sent, ct is conllu token
docs = []
golds = []
for doc_id, (text, cd) in enumerate(zip(paragraphs, conllu)):
sent_annots = []
for cs in cd:
sent = defaultdict(list)
for id_, word, lemma, pos, tag, morph, head, dep, _, space_after in cs:
if '.' in id_:
continue
if '-' in id_:
continue
id_ = int(id_)-1
head = int(head)-1 if head != '0' else id_
sent['words'].append(word)
sent['tags'].append(tag)
sent['heads'].append(head)
sent['deps'].append('ROOT' if dep == 'root' else dep)
sent['spaces'].append(space_after == '_')
sent['entities'] = ['-'] * len(sent['words'])
sent['heads'], sent['deps'] = projectivize(sent['heads'],
sent['deps'])
if oracle_segments:
docs.append(Doc(nlp.vocab, words=sent['words'], spaces=sent['spaces']))
golds.append(GoldParse(docs[-1], **sent))
sent_annots.append(sent)
if raw_text and max_doc_length and len(sent_annots) >= max_doc_length:
doc, gold = _make_gold(nlp, None, sent_annots)
sent_annots = []
docs.append(doc)
golds.append(gold)
if limit and len(docs) >= limit:
return docs, golds
if raw_text and sent_annots:
doc, gold = _make_gold(nlp, None, sent_annots)
docs.append(doc)
golds.append(gold)
if limit and len(docs) >= limit:
return docs, golds
return docs, golds
def read_conllu(file_):
docs = []
sent = []
doc = []
for line in file_:
if line.startswith('# newdoc'):
if doc:
docs.append(doc)
doc = []
elif line.startswith('#'):
continue
elif not line.strip():
if sent:
doc.append(sent)
sent = []
else:
sent.append(list(line.strip().split('\t')))
if len(sent[-1]) != 10:
print(repr(line))
raise ValueError
if sent:
doc.append(sent)
if doc:
docs.append(doc)
return docs
def _make_gold(nlp, text, sent_annots):
# Flatten the conll annotations, and adjust the head indices
flat = defaultdict(list)
for sent in sent_annots:
flat['heads'].extend(len(flat['words'])+head for head in sent['heads'])
for field in ['words', 'tags', 'deps', 'entities', 'spaces']:
flat[field].extend(sent[field])
# Construct text if necessary
assert len(flat['words']) == len(flat['spaces'])
if text is None:
text = ''.join(word+' '*space for word, space in zip(flat['words'], flat['spaces']))
doc = nlp.make_doc(text)
flat.pop('spaces')
gold = GoldParse(doc, **flat)
return doc, gold
#############################
# Data transforms for spaCy #
#############################
def golds_to_gold_tuples(docs, golds):
'''Get out the annoying 'tuples' format used by begin_training, given the
GoldParse objects.'''
tuples = []
for doc, gold in zip(docs, golds):
text = doc.text
ids, words, tags, heads, labels, iob = zip(*gold.orig_annot)
sents = [((ids, words, tags, heads, labels, iob), [])]
tuples.append((text, sents))
return tuples
##############
# Evaluation #
##############
def evaluate(nlp, text_loc, gold_loc, sys_loc, limit=None):
with text_loc.open('r', encoding='utf8') as text_file:
texts = split_text(text_file.read())
docs = list(nlp.pipe(texts))
with sys_loc.open('w', encoding='utf8') as out_file:
write_conllu(docs, out_file)
with gold_loc.open('r', encoding='utf8') as gold_file:
gold_ud = conll17_ud_eval.load_conllu(gold_file)
with sys_loc.open('r', encoding='utf8') as sys_file:
sys_ud = conll17_ud_eval.load_conllu(sys_file)
scores = conll17_ud_eval.evaluate(gold_ud, sys_ud)
return scores
def write_conllu(docs, file_):
merger = Matcher(docs[0].vocab)
merger.add('SUBTOK', None, [{'DEP': 'subtok', 'op': '+'}])
for i, doc in enumerate(docs):
matches = merger(doc)
spans = [doc[start:end+1] for _, start, end in matches]
offsets = [(span.start_char, span.end_char) for span in spans]
for start_char, end_char in offsets:
doc.merge(start_char, end_char)
file_.write("# newdoc id = {i}\n".format(i=i))
for j, sent in enumerate(doc.sents):
file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j))
file_.write("# text = {text}\n".format(text=sent.text))
for k, token in enumerate(sent):
file_.write(token._.get_conllu_lines(k) + '\n')
file_.write('\n')
def print_progress(itn, losses, ud_scores):
fields = {
'dep_loss': losses.get('parser', 0.0),
'tag_loss': losses.get('tagger', 0.0),
'words': ud_scores['Words'].f1 * 100,
'sents': ud_scores['Sentences'].f1 * 100,
'tags': ud_scores['XPOS'].f1 * 100,
'uas': ud_scores['UAS'].f1 * 100,
'las': ud_scores['LAS'].f1 * 100,
}
header = ['Epoch', 'Loss', 'LAS', 'UAS', 'TAG', 'SENT', 'WORD']
if itn == 0:
print('\t'.join(header))
tpl = '\t'.join((
'{:d}',
'{dep_loss:.1f}',
'{las:.1f}',
'{uas:.1f}',
'{tags:.1f}',
'{sents:.1f}',
'{words:.1f}',
))
print(tpl.format(itn, **fields))
#def get_sent_conllu(sent, sent_id):
# lines = ["# sent_id = {sent_id}".format(sent_id=sent_id)]
def get_token_conllu(token, i):
if token._.begins_fused:
n = 1
while token.nbor(n)._.inside_fused:
n += 1
id_ = '%d-%d' % (i, i+n)
lines = [id_, token.text, '_', '_', '_', '_', '_', '_', '_', '_']
else:
lines = []
if token.head.i == token.i:
head = 0
else:
head = i + (token.head.i - token.i) + 1
fields = [str(i+1), token.text, token.lemma_, token.pos_, token.tag_, '_',
str(head), token.dep_.lower(), '_', '_']
lines.append('\t'.join(fields))
return '\n'.join(lines)
Token.set_extension('get_conllu_lines', method=get_token_conllu)
Token.set_extension('begins_fused', default=False)
Token.set_extension('inside_fused', default=False)
##################
# Initialization #
##################
def load_nlp(corpus, config):
lang = corpus.split('_')[0]
nlp = spacy.blank(lang)
if config.vectors:
nlp.vocab.from_disk(config.vectors / 'vocab')
return nlp
def initialize_pipeline(nlp, docs, golds, config):
nlp.add_pipe(nlp.create_pipe('parser'))
if config.multitask_tag:
nlp.parser.add_multitask_objective('tag')
if config.multitask_sent:
nlp.parser.add_multitask_objective('sent_start')
nlp.parser.moves.add_action(2, 'subtok')
nlp.add_pipe(nlp.create_pipe('tagger'))
for gold in golds:
for tag in gold.tags:
if tag is not None:
nlp.tagger.add_label(tag)
# Replace labels that didn't make the frequency cutoff
actions = set(nlp.parser.labels)
label_set = set([act.split('-')[1] for act in actions if '-' in act])
for gold in golds:
for i, label in enumerate(gold.labels):
if label is not None and label not in label_set:
gold.labels[i] = label.split('||')[0]
return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds))
########################
# Command line helpers #
########################
@attr.s
class Config(object):
vectors = attr.ib(default=None)
max_doc_length = attr.ib(default=10)
multitask_tag = attr.ib(default=True)
multitask_sent = attr.ib(default=True)
nr_epoch = attr.ib(default=30)
batch_size = attr.ib(default=1000)
dropout = attr.ib(default=0.2)
@classmethod
def load(cls, loc):
with Path(loc).open('r', encoding='utf8') as file_:
cfg = json.load(file_)
return cls(**cfg)
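# An illustrative config file for Config.load (the file name and values are
# assumptions; the values shown mirror the defaults above):
#
#     {"vectors": null, "max_doc_length": 10, "multitask_tag": true,
#      "multitask_sent": true, "nr_epoch": 30, "batch_size": 1000,
#      "dropout": 0.2}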
class Dataset(object):
def __init__(self, path, section):
self.path = path
self.section = section
self.conllu = None
self.text = None
for file_path in self.path.iterdir():
name = file_path.parts[-1]
if section in name and name.endswith('conllu'):
self.conllu = file_path
elif section in name and name.endswith('txt'):
self.text = file_path
if self.conllu is None:
msg = "Could not find .conllu file in {path} for {section}"
raise IOError(msg.format(section=section, path=path))
if self.text is None:
msg = "Could not find .txt file in {path} for {section}"
raise IOError(msg.format(section=section, path=path))
self.lang = self.conllu.parts[-1].split('-')[0].split('_')[0]
class TreebankPaths(object):
def __init__(self, ud_path, treebank, **cfg):
self.train = Dataset(ud_path / treebank, 'train')
self.dev = Dataset(ud_path / treebank, 'dev')
self.lang = self.train.lang
@plac.annotations(
ud_dir=("Path to Universal Dependencies corpus", "positional", None, Path),
corpus=("UD corpus to train and evaluate on, e.g. en, es_ancora, etc",
"positional", None, str),
parses_dir=("Directory to write the development parses", "positional", None, Path),
config=("Path to json formatted config file", "positional", None, Config.load),
limit=("Size limit", "option", "n", int)
)
def main(ud_dir, parses_dir, config, corpus, limit=0):
paths = TreebankPaths(ud_dir, corpus)
if not (parses_dir / corpus).exists():
(parses_dir / corpus).mkdir()
print("Train and evaluate", corpus, "using lang", paths.lang)
nlp = load_nlp(paths.lang, config)
docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
max_doc_length=config.max_doc_length, limit=limit)
optimizer = initialize_pipeline(nlp, docs, golds, config)
for i in range(config.nr_epoch):
docs = [nlp.make_doc(doc.text) for doc in docs]
batches = minibatch_by_words(list(zip(docs, golds)), size=config.batch_size)
losses = {}
n_train_words = sum(len(doc) for doc in docs)
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
for batch in batches:
batch_docs, batch_gold = zip(*batch)
pbar.update(sum(len(doc) for doc in batch_docs))
nlp.update(batch_docs, batch_gold, sgd=optimizer,
drop=config.dropout, losses=losses)
out_path = parses_dir / corpus / 'epoch-{i}.conllu'.format(i=i)
with nlp.use_params(optimizer.averages):
scores = evaluate(nlp, paths.dev.text, paths.dev.conllu, out_path)
print_progress(i, losses, scores)
if __name__ == '__main__':
plac.call(main)
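# Illustrative invocation (paths and corpus name are assumptions; plac builds
# the CLI from main()'s signature, so the positional order is ud_dir,
# parses_dir, config, corpus):
#
#     python -m spacy.cli.ud_train /path/to/ud-treebanks-v2.0 /path/to/parses \
#         config.json es_ancora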