From c388833ca63f065ae9ffa8fc2fc8defafb8ab3cb Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 25 Feb 2018 10:38:06 +0100 Subject: [PATCH] Minibatch by number of tokens, support other vectors, refactor CoNLL printing --- examples/training/conllu.py | 61 +++++++++++++++++++++++++++++-------- 1 file changed, 49 insertions(+), 12 deletions(-) diff --git a/examples/training/conllu.py b/examples/training/conllu.py index 673033280..148475bbc 100644 --- a/examples/training/conllu.py +++ b/examples/training/conllu.py @@ -8,21 +8,38 @@ import re import sys import spacy import spacy.util -from spacy.tokens import Doc +from spacy.tokens import Token, Doc from spacy.gold import GoldParse, minibatch from spacy.syntax.nonproj import projectivize from collections import defaultdict, Counter from timeit import default_timer as timer from spacy.matcher import Matcher +import itertools import random import numpy.random +import cytoolz from spacy._align import align random.seed(0) numpy.random.seed(0) +def minibatch_by_words(items, size=5000): + if isinstance(size, int): + size_ = itertools.repeat(size) + else: + size_ = size + items = iter(items) + while True: + batch_size = next(size_) + batch = [] + while batch_size >= 0: + doc, gold = next(items) + batch_size -= len(doc) + batch.append((doc, gold)) + yield batch + def get_token_acc(docs, golds): '''Quick function to evaluate tokenization accuracy.''' @@ -214,31 +231,51 @@ def print_conllu(docs, file_): offsets = [(span.start_char, span.end_char) for span in spans] for start_char, end_char in offsets: doc.merge(start_char, end_char) - #print([t.text for t in doc]) file_.write("# newdoc id = {i}\n".format(i=i)) for j, sent in enumerate(doc.sents): file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j)) file_.write("# text = {text}\n".format(text=sent.text)) - for k, t in enumerate(sent): - if t.head.i == t.i: - head = 0 - else: - head = k + (t.head.i - t.i) + 1 - fields = [str(k+1), t.text, t.lemma_, t.pos_, t.tag_, '_', - str(head), t.dep_.lower(), '_', '_'] - file_.write('\t'.join(fields) + '\n') + for k, token in enumerate(sent): + file_.write(token._.get_conllu_lines(k) + '\n') file_.write('\n') +#def get_sent_conllu(sent, sent_id): +# lines = ["# sent_id = {sent_id}".format(sent_id=sent_id)] + +def get_token_conllu(token, i): + if token._.begins_fused: + n = 1 + while token.nbor(n)._.inside_fused: + n += 1 + id_ = '%d-%d' % (k, k+n) + lines = [id_, token.text, '_', '_', '_', '_', '_', '_', '_', '_'] + else: + lines = [] + if token.head.i == token.i: + head = 0 + else: + head = i + (token.head.i - token.i) + 1 + fields = [str(i+1), token.text, token.lemma_, token.pos_, token.tag_, '_', + str(head), token.dep_.lower(), '_', '_'] + lines.append('\t'.join(fields)) + return '\n'.join(lines) + +Token.set_extension('get_conllu_lines', method=get_token_conllu) +Token.set_extension('begins_fused', default=False) +Token.set_extension('inside_fused', default=False) + def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc, output_loc): - nlp = spacy.blank(lang) if lang == 'en': + nlp = spacy.blank(lang) vec_nlp = spacy.util.load_model('spacy/data/en_core_web_lg/en_core_web_lg-2.0.0') nlp.vocab.vectors = vec_nlp.vocab.vectors for lex in vec_nlp.vocab: _ = nlp.vocab[lex.orth_] vec_nlp = None + else: + nlp = spacy.load(lang) with open(conllu_train_loc) as conllu_file: with open(text_train_loc) as text_file: docs, golds = read_data(nlp, conllu_file, text_file, @@ -272,7 +309,7 @@ def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc, spacy.util.env_opt('batch_compound', 1.001)) for i in range(30): docs = refresh_docs(docs) - batches = minibatch(list(zip(docs, golds)), size=batch_sizes) + batches = minibatch_by_words(list(zip(docs, golds)), size=1000) with tqdm.tqdm(total=n_train_words, leave=False) as pbar: losses = {} for batch in batches: