# flake8: noqa
"""Train for CONLL 2017 UD treebank evaluation. Takes .conllu files, writes
.conllu format for development data, allowing the official scorer to be used.
"""

from __future__ import unicode_literals

import plac
import tqdm
from pathlib import Path
import re
import sys
import srsly

import spacy
import spacy.util
from ...tokens import Token, Doc
from ...gold import GoldParse
from ...util import compounding, minibatch_by_words
from ...syntax.nonproj import projectivize
from ...matcher import Matcher

# from ...morphology import Fused_begin, Fused_inside
from ... import displacy

from collections import defaultdict, Counter
from timeit import default_timer as timer

Fused_begin = None
Fused_inside = None
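# The Fused_begin / Fused_inside morphology flags (see the commented-out
# import above) are stubbed to None so the module still imports without them;
# _get_token_conllu() checks these flags before writing fused-token range
# lines.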

import itertools
import random
import numpy.random

from . import conll17_ud_eval

from ... import lang
from ...lang import zh
from ...lang import ja
from ...lang import ru


################
# Data reading #
################

space_re = re.compile(r"\s+")


def split_text(text):
    return [space_re.sub(" ", par.strip()) for par in text.split("\n\n")]
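
# For example (illustrative input, not from the shared-task data):
#
#     split_text("Hello   world.\n\nSecond   paragraph.")
#     -> ["Hello world.", "Second paragraph."]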


##############
# Evaluation #
##############


def read_conllu(file_):
    docs = []
    sent = []
    doc = []
    for line in file_:
        if line.startswith("# newdoc"):
            if doc:
                docs.append(doc)
            doc = []
        elif line.startswith("#"):
            continue
        elif not line.strip():
            if sent:
                doc.append(sent)
            sent = []
        else:
            sent.append(list(line.strip().split("\t")))
            if len(sent[-1]) != 10:
                print(repr(line))
                raise ValueError
    if sent:
        doc.append(sent)
    if doc:
        docs.append(doc)
    return docs
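
# read_conllu() returns a nested structure: a list of documents, each a list
# of sentences, each a list of 10-field CoNLL-U rows, e.g. (a sketch with
# illustrative values):
#
#     docs[0][0][0] == ["1", "They", "they", "PRON", "PRP", "_", "2", "nsubj", "_", "_"]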


def evaluate(nlp, text_loc, gold_loc, sys_loc, limit=None):
    if text_loc.parts[-1].endswith(".conllu"):
        docs = []
        with text_loc.open() as file_:
            for conllu_doc in read_conllu(file_):
                for conllu_sent in conllu_doc:
                    words = [line[1] for line in conllu_sent]
                    docs.append(Doc(nlp.vocab, words=words))
        for name, component in nlp.pipeline:
            docs = list(component.pipe(docs))
    else:
        with text_loc.open("r", encoding="utf8") as text_file:
            texts = split_text(text_file.read())
            docs = list(nlp.pipe(texts))
    with sys_loc.open("w", encoding="utf8") as out_file:
        write_conllu(docs, out_file)
    with gold_loc.open("r", encoding="utf8") as gold_file:
        gold_ud = conll17_ud_eval.load_conllu(gold_file)
        with sys_loc.open("r", encoding="utf8") as sys_file:
            sys_ud = conll17_ud_eval.load_conllu(sys_file)
        scores = conll17_ud_eval.evaluate(gold_ud, sys_ud)
    return docs, scores
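
# evaluate() accepts either a pre-tokenized .conllu file (the UDPipe run; only
# the word forms are reused and the loaded pipeline is applied component by
# component) or a raw .txt file (nlp.pipe() handles segmentation as well).
# Either way, the system output is written to sys_loc and scored against
# gold_loc with the official CoNLL 2017 evaluation script.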


def write_conllu(docs, file_):
    merger = Matcher(docs[0].vocab)
    merger.add("SUBTOK", None, [{"DEP": "subtok", "op": "+"}])
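    # The pattern above matches runs of tokens that the parser labelled
    # "subtok", so pieces of a token that were split during tokenization can
    # be merged back into a single row for the CoNLL-U output.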
    for i, doc in enumerate(docs):
        matches = merger(doc)
        spans = [doc[start : end + 1] for _, start, end in matches]
        with doc.retokenize() as retokenizer:
            for span in spans:
                retokenizer.merge(span)
        # TODO: This shouldn't be necessary? Should be handled in merge
        for word in doc:
            if word.i == word.head.i:
                word.dep_ = "ROOT"
        file_.write("# newdoc id = {i}\n".format(i=i))
        for j, sent in enumerate(doc.sents):
            file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j))
            file_.write("# text = {text}\n".format(text=sent.text))
            for k, token in enumerate(sent):
                file_.write(_get_token_conllu(token, k, len(sent)) + "\n")
            file_.write("\n")
            for word in sent:
                if word.head.i == word.i and word.dep_ == "ROOT":
                    break
            else:
                print("Rootless sentence!")
                print(sent)
                print(i)
                for w in sent:
                    print(w.i, w.text, w.head.text, w.head.i, w.dep_)
                raise ValueError
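
# write_conllu() emits one block per document, roughly (a sketch with
# illustrative values; the token columns are tab-separated):
#
#     # newdoc id = 0
#     # sent_id = 0.0
#     # text = They buy books.
#     1   They   they   PRON   PRP   _   2   nsubj   _   _
#     ...
#
# Fused surface tokens additionally get a range line such as "1-2" (see
# _get_token_conllu() below).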


def _get_token_conllu(token, k, sent_len):
    if token.check_morph(Fused_begin) and (k + 1 < sent_len):
        n = 1
        text = [token.text]
        while token.nbor(n).check_morph(Fused_inside):
            text.append(token.nbor(n).text)
            n += 1
        id_ = "%d-%d" % (k + 1, (k + n))
        fields = [id_, "".join(text)] + ["_"] * 8
        lines = ["\t".join(fields)]
    else:
        lines = []
    if token.head.i == token.i:
        head = 0
    else:
        head = k + (token.head.i - token.i) + 1
    fields = [
        str(k + 1),
        token.text,
        token.lemma_,
        token.pos_,
        token.tag_,
        "_",
        str(head),
        token.dep_.lower(),
        "_",
        "_",
    ]
    if token.check_morph(Fused_begin) and (k + 1 < sent_len):
        if k == 0:
            fields[1] = token.norm_[0].upper() + token.norm_[1:]
        else:
            fields[1] = token.norm_
    elif token.check_morph(Fused_inside):
        fields[1] = token.norm_
    elif token._.split_start is not None:
        split_start = token._.split_start
        split_end = token._.split_end
        split_len = (split_end.i - split_start.i) + 1
        n_in_split = token.i - split_start.i
        subtokens = guess_fused_orths(split_start.text, [""] * split_len)
        fields[1] = subtokens[n_in_split]

    lines.append("\t".join(fields))
    return "\n".join(lines)


def guess_fused_orths(word, ud_forms):
    """The UD data 'fused tokens' don't necessarily expand to keys that match
    the form. We need orths that exactly match the string. Here we make a best
    effort to divide up the word."""
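    # A sketch of the three cases handled below (illustrative forms, not taken
    # from any particular treebank):
    #
    #     guess_fused_orths("dela", ["de", "la"])   -> ["de", "la"]    (exact join)
    #     guess_fused_orths("dela", ["dei", "a"])   -> ["del", "a"]    (lengths add up)
    #     guess_fused_orths("della", ["di", "la"])  -> ["dell", "a"]   (fallback split)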
    if word == "".join(ud_forms):
        # Happy case: we get a perfect split, with each letter accounted for.
        return ud_forms
    elif len(word) == sum(len(subtoken) for subtoken in ud_forms):
        # Unideal, but at least the lengths match.
        output = []
        remain = word
        for subtoken in ud_forms:
            assert len(subtoken) >= 1
            output.append(remain[: len(subtoken)])
            remain = remain[len(subtoken) :]
        assert len(remain) == 0, (word, ud_forms, remain)
        return output
    else:
        # Say the word is 6 characters long and there are three subtokens. The
        # orths *must* concatenate to the original string, so arbitrarily
        # split it [4, 1, 1].
        first = word[: len(word) - (len(ud_forms) - 1)]
        output = [first]
        remain = word[len(first) :]
        for i in range(1, len(ud_forms)):
            assert remain
            output.append(remain[:1])
            remain = remain[1:]
        assert len(remain) == 0, (word, output, remain)
        return output


def print_results(name, ud_scores):
    fields = {}
    if ud_scores is not None:
        fields.update(
            {
                "words": ud_scores["Words"].f1 * 100,
                "sents": ud_scores["Sentences"].f1 * 100,
                "tags": ud_scores["XPOS"].f1 * 100,
                "uas": ud_scores["UAS"].f1 * 100,
                "las": ud_scores["LAS"].f1 * 100,
            }
        )
    else:
        fields.update({"words": 0.0, "sents": 0.0, "tags": 0.0, "uas": 0.0, "las": 0.0})
    tpl = "\t".join(
        (name, "{las:.1f}", "{uas:.1f}", "{tags:.1f}", "{sents:.1f}", "{words:.1f}")
    )
    print(tpl.format(**fields))
    return fields
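
# print_results() prints one tab-separated row per input type: name, then LAS,
# UAS, TAG, SENT and WORD F1 as percentages, matching the header printed in
# main(). For example (illustrative numbers):
#
#     raw    78.4    82.1    93.0    85.7    99.1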


def get_token_split_start(token):
    if token.text == "":
        assert token.i != 0
        i = -1
        while token.nbor(i).text == "":
            i -= 1
        return token.nbor(i)
    elif (token.i + 1) < len(token.doc) and token.nbor(1).text == "":
        return token
    else:
        return None


def get_token_split_end(token):
    if (token.i + 1) == len(token.doc):
        return token if token.text == "" else None
    elif token.text != "" and token.nbor(1).text != "":
        return None
    i = 1
    while (token.i + i) < len(token.doc) and token.nbor(i).text == "":
        i += 1
    return token.nbor(i - 1)
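
# get_token_split_start() / get_token_split_end() appear to rely on the
# convention that the extra pieces of a surface token split during
# tokenization carry an empty text: for a run like ["dell", "", ""],
# split_start of any member is the "dell" token and split_end is the last
# empty token.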


##################
# Initialization #
##################


def load_nlp(experiments_dir, corpus):
    nlp = spacy.load(experiments_dir / corpus / "best-model")
    return nlp


def initialize_pipeline(nlp, docs, golds, config, device):
    nlp.add_pipe(nlp.create_pipe("parser"))
    return nlp


@plac.annotations(
    test_data_dir=(
        "Path to Universal Dependencies test data",
        "positional",
        None,
        Path,
    ),
    experiment_dir=("Parent directory with output model", "positional", None, Path),
    corpus=(
        "UD corpus to evaluate, e.g. UD_English, UD_Spanish, etc",
        "positional",
        None,
        str,
    ),
)
def main(test_data_dir, experiment_dir, corpus):
    Token.set_extension("split_start", getter=get_token_split_start)
    Token.set_extension("split_end", getter=get_token_split_end)
    Token.set_extension("begins_fused", default=False)
    Token.set_extension("inside_fused", default=False)
    lang.zh.Chinese.Defaults.use_jieba = False
    lang.ja.Japanese.Defaults.use_janome = False
    lang.ru.Russian.Defaults.use_pymorphy2 = False
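    # The three Defaults flags above disable the optional external packages
    # (jieba, janome, pymorphy2) for Chinese, Japanese and Russian, presumably
    # so the evaluation can run without those packages installed.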

    nlp = load_nlp(experiment_dir, corpus)

    treebank_code = nlp.meta["treebank"]
    for section in ("test", "dev"):
        if section == "dev":
            section_dir = "conll17-ud-development-2017-03-19"
        else:
            section_dir = "conll17-ud-test-2017-05-09"
        text_path = test_data_dir / "input" / section_dir / (treebank_code + ".txt")
        udpipe_path = (
            test_data_dir / "input" / section_dir / (treebank_code + "-udpipe.conllu")
        )
        gold_path = test_data_dir / "gold" / section_dir / (treebank_code + ".conllu")

        header = [section, "LAS", "UAS", "TAG", "SENT", "WORD"]
        print("\t".join(header))
        inputs = {"gold": gold_path, "udp": udpipe_path, "raw": text_path}
        for input_type in ("udp", "raw"):
            input_path = inputs[input_type]
            output_path = (
                experiment_dir / corpus / "{section}.conllu".format(section=section)
            )

            parsed_docs, test_scores = evaluate(nlp, input_path, gold_path, output_path)

            accuracy = print_results(input_type, test_scores)
            acc_path = (
                experiment_dir
                / corpus
                / "{section}-accuracy.json".format(section=section)
            )
            srsly.write_json(acc_path, accuracy)


if __name__ == "__main__":
    plac.call(main)