From 4ff92184f145a776a29653f2adf29d2bcd0083fe Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Mon, 9 Jan 2017 09:53:46 -0600 Subject: [PATCH] Improve train_ud script --- bin/parser/train_ud.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/bin/parser/train_ud.py b/bin/parser/train_ud.py index 565eab37f..c96faf7b9 100644 --- a/bin/parser/train_ud.py +++ b/bin/parser/train_ud.py @@ -18,6 +18,7 @@ from spacy.pipeline import DependencyParser from spacy.syntax.parser import get_templates from spacy.syntax.arc_eager import ArcEager from spacy.scorer import Scorer +from spacy.language_data.tag_map import TAG_MAP as DEFAULT_TAG_MAP import spacy.attrs import io @@ -61,9 +62,12 @@ def score_model(vocab, tagger, parser, gold_docs, verbose=False): return scorer -def main(train_loc, dev_loc, model_dir, tag_map_loc): - with open(tag_map_loc) as file_: - tag_map = json.loads(file_.read()) +def main(train_loc, dev_loc, model_dir, tag_map_loc=None): + if tag_map_loc: + with open(tag_map_loc) as file_: + tag_map = json.loads(file_.read()) + else: + tag_map = DEFAULT_TAG_MAP train_sents = list(read_conllx(train_loc)) train_sents = PseudoProjectivity.preprocess_training_data(train_sents) @@ -73,9 +77,10 @@ def main(train_loc, dev_loc, model_dir, tag_map_loc): model_dir = pathlib.Path(model_dir) if not (model_dir / 'deps').exists(): (model_dir / 'deps').mkdir() - with (model_dir / 'deps' / 'config.json').open('w') as file_: - json.dump({'pseudoprojective': True, 'labels': actions, 'features': features}, file_) - + with (model_dir / 'deps' / 'config.json').open('wb') as file_: + file_.write( + json.dumps( + {'pseudoprojective': True, 'labels': actions, 'features': features}).encode('utf8')) vocab = Vocab(lex_attr_getters=Language.Defaults.lex_attr_getters, tag_map=tag_map) # Populate vocab for _, doc_sents in train_sents: @@ -86,8 +91,9 @@ def main(train_loc, dev_loc, model_dir, tag_map_loc): _ = vocab[dep] for tag in tags: _ = vocab[tag] - for tag in tags: - assert tag in tag_map, repr(tag) + if tag_map: + for tag in tags: + assert tag in tag_map, repr(tag) tagger = Tagger(vocab, tag_map=tag_map) parser = DependencyParser(vocab, actions=actions, features=features)