# flake8: noqa
"""Evaluate a trained model on the CoNLL 2017 UD treebank data. Takes .conllu
or raw text files and writes .conllu output, allowing the official scorer to
be used.
"""
from __future__ import unicode_literals

import plac
import tqdm
from pathlib import Path
import re
import sys
import srsly

import spacy
import spacy.util
from ...tokens import Token, Doc
from ...gold import GoldParse
from ...util import compounding, minibatch_by_words
from ...syntax.nonproj import projectivize
from ...matcher import Matcher

# from ...morphology import Fused_begin, Fused_inside
from ... import displacy
from collections import defaultdict, Counter
from timeit import default_timer as timer

# Placeholders: the morphology import above is disabled, so these stay None.
Fused_begin = None
Fused_inside = None

import itertools
import random
import numpy.random

from . import conll17_ud_eval

from ... import lang
from ...lang import zh
from ...lang import ja
from ...lang import ru


################
# Data reading #
################

space_re = re.compile(r"\s+")


def split_text(text):
    return [space_re.sub(" ", par.strip()) for par in text.split("\n\n")]
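
# For example (hypothetical input), split_text("One  two.\n\nThree.") returns
# ["One two.", "Three."]: paragraphs are split on blank lines, and runs of
# whitespace are collapsed to single spaces.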


##############
# Evaluation #
##############


def read_conllu(file_):
    docs = []
    sent = []
    doc = []
    for line in file_:
        if line.startswith("# newdoc"):
            if doc:
                docs.append(doc)
            doc = []
        elif line.startswith("#"):
            continue
        elif not line.strip():
            # Blank line: sentence boundary.
            if sent:
                doc.append(sent)
            sent = []
        else:
            sent.append(list(line.strip().split("\t")))
            if len(sent[-1]) != 10:
                # CoNLL-U rows must have exactly 10 tab-separated fields.
                print(repr(line))
                raise ValueError
    if sent:
        doc.append(sent)
    if doc:
        docs.append(doc)
    return docs
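
# read_conllu returns a nested list, docs -> sentences -> token rows, where
# each row holds the 10 CoNLL-U columns (ID, FORM, LEMMA, UPOS, XPOS, FEATS,
# HEAD, DEPREL, DEPS, MISC).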


def evaluate(nlp, text_loc, gold_loc, sys_loc, limit=None):
    if text_loc.parts[-1].endswith(".conllu"):
        # Pre-tokenized input: build Docs from the given word forms, then run
        # each pipeline component over them.
        docs = []
        with text_loc.open() as file_:
            for conllu_doc in read_conllu(file_):
                for conllu_sent in conllu_doc:
                    words = [line[1] for line in conllu_sent]
                    docs.append(Doc(nlp.vocab, words=words))
        for name, component in nlp.pipeline:
            docs = list(component.pipe(docs))
    else:
        # Raw text input: tokenize and parse with the full pipeline.
        with text_loc.open("r", encoding="utf8") as text_file:
            texts = split_text(text_file.read())
            docs = list(nlp.pipe(texts))
    with sys_loc.open("w", encoding="utf8") as out_file:
        write_conllu(docs, out_file)
    with gold_loc.open("r", encoding="utf8") as gold_file:
        gold_ud = conll17_ud_eval.load_conllu(gold_file)
        with sys_loc.open("r", encoding="utf8") as sys_file:
            sys_ud = conll17_ud_eval.load_conllu(sys_file)
        scores = conll17_ud_eval.evaluate(gold_ud, sys_ud)
    return docs, scores
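
# A minimal sketch of driving evaluate() directly (paths are hypothetical):
#   docs, scores = evaluate(nlp, Path("dev.conllu"), Path("dev-gold.conllu"),
#                           Path("dev-sys.conllu"))
#   print(scores["LAS"].f1)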


def write_conllu(docs, file_):
    # Re-merge tokens that the model split with the "subtok" dependency, so
    # the output tokenization matches the raw text.
    merger = Matcher(docs[0].vocab)
    merger.add("SUBTOK", None, [{"DEP": "subtok", "op": "+"}])
    for i, doc in enumerate(docs):
        matches = merger(doc)
        spans = [doc[start : end + 1] for _, start, end in matches]
        offsets = [(span.start_char, span.end_char) for span in spans]
        for start_char, end_char in offsets:
            doc.merge(start_char, end_char)
        # TODO: This shouldn't be necessary? Should be handled in merge
        for word in doc:
            if word.i == word.head.i:
                word.dep_ = "ROOT"
        file_.write("# newdoc id = {i}\n".format(i=i))
        for j, sent in enumerate(doc.sents):
            file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j))
            file_.write("# text = {text}\n".format(text=sent.text))
            for k, token in enumerate(sent):
                file_.write(_get_token_conllu(token, k, len(sent)) + "\n")
            file_.write("\n")
            for word in sent:
                if word.head.i == word.i and word.dep_ == "ROOT":
                    break
            else:
                print("Rootless sentence!")
                print(sent)
                print(i)
                for w in sent:
                    print(w.i, w.text, w.head.text, w.head.i, w.dep_)
                raise ValueError
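
# The SUBTOK pattern matches runs of tokens whose dependency label is
# "subtok"; taking doc[start : end + 1] extends each match by one token so
# the final piece of the split word (which does not carry the "subtok" label)
# is merged in as well.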


def _get_token_conllu(token, k, sent_len):
    if token.check_morph(Fused_begin) and (k + 1 < sent_len):
        # Token starts a fused surface form: emit the multi-word token range
        # line (e.g. "1-2") with the concatenated surface text.
        n = 1
        text = [token.text]
        while token.nbor(n).check_morph(Fused_inside):
            text.append(token.nbor(n).text)
            n += 1
        id_ = "%d-%d" % (k + 1, (k + n))
        fields = [id_, "".join(text)] + ["_"] * 8
        lines = ["\t".join(fields)]
    else:
        lines = []
    if token.head.i == token.i:
        head = 0
    else:
        # CoNLL-U uses 1-based, sentence-internal head indices.
        head = k + (token.head.i - token.i) + 1
    fields = [
        str(k + 1),
        token.text,
        token.lemma_,
        token.pos_,
        token.tag_,
        "_",
        str(head),
        token.dep_.lower(),
        "_",
        "_",
    ]
    if token.check_morph(Fused_begin) and (k + 1 < sent_len):
        if k == 0:
            fields[1] = token.norm_[0].upper() + token.norm_[1:]
        else:
            fields[1] = token.norm_
    elif token.check_morph(Fused_inside):
        fields[1] = token.norm_
    elif token._.split_start is not None:
        split_start = token._.split_start
        split_end = token._.split_end
        split_len = (split_end.i - split_start.i) + 1
        n_in_split = token.i - split_start.i
        subtokens = guess_fused_orths(split_start.text, [""] * split_len)
        fields[1] = subtokens[n_in_split]

    lines.append("\t".join(fields))
    return "\n".join(lines)


def guess_fused_orths(word, ud_forms):
    """The UD data 'fused tokens' don't necessarily expand to forms that match
    the surface string. We need orths that exactly match the string. Here we
    make a best effort to divide up the word."""
    if word == "".join(ud_forms):
        # Happy case: we get a perfect split, with each letter accounted for.
        return ud_forms
    elif len(word) == sum(len(subtoken) for subtoken in ud_forms):
        # Not ideal, but at least the lengths match.
        output = []
        remain = word
        for subtoken in ud_forms:
            assert len(subtoken) >= 1
            output.append(remain[: len(subtoken)])
            remain = remain[len(subtoken) :]
        assert len(remain) == 0, (word, ud_forms, remain)
        return output
    else:
        # Let's say word is 6 letters long, and there are three subtokens. The
        # orths *must* equal the original string. Arbitrarily, split [4, 1, 1].
        first = word[: len(word) - (len(ud_forms) - 1)]
        output = [first]
        remain = word[len(first) :]
        for i in range(1, len(ud_forms)):
            assert remain
            output.append(remain[:1])
            remain = remain[1:]
        assert len(remain) == 0, (word, output, remain)
        return output
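
# Worked examples (hypothetical inputs):
#   guess_fused_orths("cannot", ["can", "not"])  -> ["can", "not"]  (exact)
#   guess_fused_orths("aux", ["à", "les"])       -> ["au", "x"]     (fallback:
#       lengths differ, so every part after the first gets one character)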


def print_results(name, ud_scores):
    fields = {}
    if ud_scores is not None:
        fields.update(
            {
                "words": ud_scores["Words"].f1 * 100,
                "sents": ud_scores["Sentences"].f1 * 100,
                "tags": ud_scores["XPOS"].f1 * 100,
                "uas": ud_scores["UAS"].f1 * 100,
                "las": ud_scores["LAS"].f1 * 100,
            }
        )
    else:
        fields.update({"words": 0.0, "sents": 0.0, "tags": 0.0, "uas": 0.0, "las": 0.0})
    tpl = "\t".join(
        (name, "{las:.1f}", "{uas:.1f}", "{tags:.1f}", "{sents:.1f}", "{words:.1f}")
    )
    print(tpl.format(**fields))
    return fields
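
# Prints one tab-separated row per run, e.g. (hypothetical scores):
#   raw   79.3   83.1   92.4   88.0   99.1
# i.e. name, LAS, UAS, XPOS tags, sentences, words -- all F1 * 100.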


def get_token_split_start(token):
    # Split tokens are represented by empty-text continuation tokens; walk
    # left past them to find the token that starts the split.
    if token.text == "":
        assert token.i != 0
        i = -1
        while token.nbor(i).text == "":
            i -= 1
        return token.nbor(i)
    elif (token.i + 1) < len(token.doc) and token.nbor(1).text == "":
        return token
    else:
        return None


def get_token_split_end(token):
    if (token.i + 1) == len(token.doc):
        return token if token.text == "" else None
    elif token.text != "" and token.nbor(1).text != "":
        return None
    i = 1
    while (token.i + i) < len(token.doc) and token.nbor(i).text == "":
        i += 1
    return token.nbor(i - 1)
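
# Convention assumed here: when a token is split, the first piece keeps the
# surface text and each following piece has empty text, so a split span runs
# from the last non-empty token through the trailing empty-text tokens.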


##################
# Initialization #
##################


def load_nlp(experiments_dir, corpus):
    nlp = spacy.load(experiments_dir / corpus / "best-model")
    return nlp


def initialize_pipeline(nlp, docs, golds, config, device):
    nlp.add_pipe(nlp.create_pipe("parser"))
    return nlp


@plac.annotations(
    test_data_dir=(
        "Path to Universal Dependencies test data",
        "positional",
        None,
        Path,
    ),
    experiment_dir=("Parent directory with output model", "positional", None, Path),
    corpus=(
        "UD corpus to evaluate, e.g. UD_English, UD_Spanish, etc.",
        "positional",
        None,
        str,
    ),
)
def main(test_data_dir, experiment_dir, corpus):
    Token.set_extension("split_start", getter=get_token_split_start)
    Token.set_extension("split_end", getter=get_token_split_end)
    Token.set_extension("begins_fused", default=False)
    Token.set_extension("inside_fused", default=False)
    # Disable third-party tokenizers/analyzers for Chinese, Japanese and
    # Russian.
    lang.zh.Chinese.Defaults.use_jieba = False
    lang.ja.Japanese.Defaults.use_janome = False
    lang.ru.Russian.Defaults.use_pymorphy2 = False

    nlp = load_nlp(experiment_dir, corpus)

    treebank_code = nlp.meta["treebank"]
    for section in ("test", "dev"):
        if section == "dev":
            section_dir = "conll17-ud-development-2017-03-19"
        else:
            section_dir = "conll17-ud-test-2017-05-09"
        text_path = test_data_dir / "input" / section_dir / (treebank_code + ".txt")
        udpipe_path = (
            test_data_dir / "input" / section_dir / (treebank_code + "-udpipe.conllu")
        )
        gold_path = test_data_dir / "gold" / section_dir / (treebank_code + ".conllu")

        header = [section, "LAS", "UAS", "TAG", "SENT", "WORD"]
        print("\t".join(header))
        inputs = {"gold": gold_path, "udp": udpipe_path, "raw": text_path}
        for input_type in ("udp", "raw"):
            input_path = inputs[input_type]
            output_path = (
                experiment_dir / corpus / "{section}.conllu".format(section=section)
            )

            parsed_docs, test_scores = evaluate(nlp, input_path, gold_path, output_path)

            accuracy = print_results(input_type, test_scores)
            acc_path = (
                experiment_dir
                / corpus
                / "{section}-accuracy.json".format(section=section)
            )
            srsly.write_json(acc_path, accuracy)


if __name__ == "__main__":
    plac.call(main)
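
# Example invocation (script name and paths are hypothetical):
#   python ud_run_test.py /ud/test-data /ud/experiments UD_English
# plac maps the three positional arguments to test_data_dir, experiment_dir
# and corpus.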