diff --git a/examples/training/conllu.py b/examples/training/conllu.py
index 3bb6248af..7f8c817d2 100644
--- a/examples/training/conllu.py
+++ b/examples/training/conllu.py
@@ -30,6 +30,7 @@
 random.seed(0)
 numpy.random.seed(0)
 def minibatch_by_words(items, size=5000):
+    random.shuffle(items)
     if isinstance(size, int):
         size_ = itertools.repeat(size)
     else:
@@ -39,10 +40,17 @@ def minibatch_by_words(items, size=5000):
         batch_size = next(size_)
         batch = []
         while batch_size >= 0:
-            doc, gold = next(items)
+            try:
+                doc, gold = next(items)
+            except StopIteration:
+                yield batch
+                return
             batch_size -= len(doc)
             batch.append((doc, gold))
-        yield batch
+        if batch:
+            yield batch
+        else:
+            break
 ################
 # Data reading #
 ################
@@ -146,10 +154,6 @@ def _make_gold(nlp, text, sent_annots):
     doc = nlp.make_doc(text)
     flat.pop('spaces')
     gold = GoldParse(doc, **flat)
-    #for annot in gold.orig_annot:
-    #    print(annot)
-    #for i in range(len(doc)):
-    #    print(doc[i].text, gold.words[i], gold.labels[i], gold.heads[i])
     return doc, gold
 
 #############################
@@ -168,12 +172,6 @@ def golds_to_gold_tuples(docs, golds):
     return tuples
 
 
-def refresh_docs(docs):
-    vocab = docs[0].vocab
-    return [Doc(vocab, words=[t.text for t in doc],
-                spaces=[t.whitespace_ for t in doc])
-            for doc in docs]
-
 ##############
 # Evaluation #
 ##############
@@ -357,25 +355,24 @@
 )
 def main(ud_dir, corpus, config, parses_loc='/tmp/dev.conllu', limit=10):
     paths = TreebankPaths(ud_dir, corpus)
+    print("Train and evaluate", corpus, "using lang", paths.lang)
     nlp = load_nlp(paths.lang, config)
 
     docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
-                            limit=limit)
+                            max_doc_length=config.max_doc_length, limit=limit)
     optimizer = initialize_pipeline(nlp, docs, golds, config)
-    n_train_words = sum(len(doc) for doc in docs)
-    print("Begin training (%d words)" % n_train_words)
 
     for i in range(config.nr_epoch):
-        docs = refresh_docs(docs)
+        docs = [nlp.make_doc(doc.text) for doc in docs]
         batches = minibatch_by_words(list(zip(docs, golds)), size=config.batch_size)
         losses = {}
-        for batch in tqdm.tqdm(batches, total=n_train_words//config.batch_size):
-            if not batch:
-                continue
-            batch_docs, batch_gold = zip(*batch)
-
-            nlp.update(batch_docs, batch_gold, sgd=optimizer,
-                       drop=config.dropout, losses=losses)
+        n_train_words = sum(len(doc) for doc in docs)
+        with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
+            for batch in batches:
+                batch_docs, batch_gold = zip(*batch)
+                pbar.update(sum(len(doc) for doc in batch_docs))
+                nlp.update(batch_docs, batch_gold, sgd=optimizer,
+                           drop=config.dropout, losses=losses)
 
         with nlp.use_params(optimizer.averages):
             dev_docs, scorer = parse_dev_data(nlp, paths.dev.text, paths.dev.conllu)
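
For reference, below is a self-contained sketch of how the patched minibatch_by_words is expected to behave after this change: the input is shuffled, batches are filled until the running word budget is spent, and when the input runs out the final partial batch is yielded instead of letting StopIteration escape the generator. The iterator setup between the two hunks (items = iter(items) and the outer while True loop) is not shown in the diff and is assumed here; the dummy string "docs" and the pairs variable are illustrative stand-ins for (Doc, GoldParse) pairs, not part of the patch.

import itertools
import random

def minibatch_by_words(items, size=5000):
    # Mirrors the patched logic: shuffle, then pull pairs from an iterator
    # until the word budget for the current batch is spent.
    random.shuffle(items)
    if isinstance(size, int):
        size_ = itertools.repeat(size)
    else:
        size_ = iter(size)
    items = iter(items)
    while True:
        batch_size = next(size_)
        batch = []
        while batch_size >= 0:
            try:
                doc, gold = next(items)
            except StopIteration:
                # Input exhausted: emit whatever is left and finish cleanly.
                yield batch
                return
            batch_size -= len(doc)
            batch.append((doc, gold))
        if batch:
            yield batch
        else:
            break

# Dummy "documents": strings whose len() stands in for the word count.
pairs = [("x" * n, None) for n in (3, 7, 2, 9, 4)]
for batch in minibatch_by_words(pairs, size=10):
    print([len(doc) for doc, gold in batch])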