diff --git a/bin/parser/train_ud.py b/bin/parser/train_ud.py index c96faf7b9..4b3080ce5 100644 --- a/bin/parser/train_ud.py +++ b/bin/parser/train_ud.py @@ -24,9 +24,10 @@ import io -def read_conllx(loc): +def read_conllx(loc, n=0): with io.open(loc, 'r', encoding='utf8') as file_: text = file_.read() + i = 0 for sent in text.strip().split('\n\n'): lines = sent.strip().split('\n') if lines: @@ -47,6 +48,9 @@ def read_conllx(loc): raise tuples = [list(t) for t in zip(*tokens)] yield (None, [[tuples, []]]) + i += 1 + if n >= 1 and i >= n: + break def score_model(vocab, tagger, parser, gold_docs, verbose=False): @@ -73,7 +77,7 @@ def main(train_loc, dev_loc, model_dir, tag_map_loc=None): actions = ArcEager.get_actions(gold_parses=train_sents) features = get_templates('basic') - + model_dir = pathlib.Path(model_dir) if not (model_dir / 'deps').exists(): (model_dir / 'deps').mkdir() @@ -95,25 +99,26 @@ def main(train_loc, dev_loc, model_dir, tag_map_loc=None): for tag in tags: assert tag in tag_map, repr(tag) tagger = Tagger(vocab, tag_map=tag_map) - parser = DependencyParser(vocab, actions=actions, features=features) - + parser = DependencyParser(vocab, actions=actions, features=features, L1=0.0) + for itn in range(15): + loss = 0. for _, doc_sents in train_sents: for (ids, words, tags, heads, deps, ner), _ in doc_sents: doc = Doc(vocab, words=words) gold = GoldParse(doc, tags=tags, heads=heads, deps=deps) tagger(doc) - parser.update(doc, gold) + loss += parser.update(doc, gold, itn=itn) doc = Doc(vocab, words=words) tagger.update(doc, gold) random.shuffle(train_sents) scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc)) - print('%d:\t%.3f\t%.3f' % (itn, scorer.uas, scorer.tags_acc)) + print('%d:\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas, scorer.tags_acc)) nlp = Language(vocab=vocab, tagger=tagger, parser=parser) nlp.end_training(model_dir) scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc)) print('%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.las, scorer.tags_acc)) - + if __name__ == '__main__': plac.call(main)