Improve train_ud script

This commit is contained in:
Matthew Honnibal 2017-01-09 09:53:46 -06:00
parent 363f09e68c
commit 4ff92184f1

View File

@ -18,6 +18,7 @@ from spacy.pipeline import DependencyParser
from spacy.syntax.parser import get_templates from spacy.syntax.parser import get_templates
from spacy.syntax.arc_eager import ArcEager from spacy.syntax.arc_eager import ArcEager
from spacy.scorer import Scorer from spacy.scorer import Scorer
from spacy.language_data.tag_map import TAG_MAP as DEFAULT_TAG_MAP
import spacy.attrs import spacy.attrs
import io import io
@ -61,9 +62,12 @@ def score_model(vocab, tagger, parser, gold_docs, verbose=False):
return scorer return scorer
def main(train_loc, dev_loc, model_dir, tag_map_loc): def main(train_loc, dev_loc, model_dir, tag_map_loc=None):
if tag_map_loc:
with open(tag_map_loc) as file_: with open(tag_map_loc) as file_:
tag_map = json.loads(file_.read()) tag_map = json.loads(file_.read())
else:
tag_map = DEFAULT_TAG_MAP
train_sents = list(read_conllx(train_loc)) train_sents = list(read_conllx(train_loc))
train_sents = PseudoProjectivity.preprocess_training_data(train_sents) train_sents = PseudoProjectivity.preprocess_training_data(train_sents)
@ -73,9 +77,10 @@ def main(train_loc, dev_loc, model_dir, tag_map_loc):
model_dir = pathlib.Path(model_dir) model_dir = pathlib.Path(model_dir)
if not (model_dir / 'deps').exists(): if not (model_dir / 'deps').exists():
(model_dir / 'deps').mkdir() (model_dir / 'deps').mkdir()
with (model_dir / 'deps' / 'config.json').open('w') as file_: with (model_dir / 'deps' / 'config.json').open('wb') as file_:
json.dump({'pseudoprojective': True, 'labels': actions, 'features': features}, file_) file_.write(
json.dumps(
{'pseudoprojective': True, 'labels': actions, 'features': features}).encode('utf8'))
vocab = Vocab(lex_attr_getters=Language.Defaults.lex_attr_getters, tag_map=tag_map) vocab = Vocab(lex_attr_getters=Language.Defaults.lex_attr_getters, tag_map=tag_map)
# Populate vocab # Populate vocab
for _, doc_sents in train_sents: for _, doc_sents in train_sents:
@ -86,6 +91,7 @@ def main(train_loc, dev_loc, model_dir, tag_map_loc):
_ = vocab[dep] _ = vocab[dep]
for tag in tags: for tag in tags:
_ = vocab[tag] _ = vocab[tag]
if tag_map:
for tag in tags: for tag in tags:
assert tag in tag_map, repr(tag) assert tag in tag_map, repr(tag)
tagger = Tagger(vocab, tag_map=tag_map) tagger = Tagger(vocab, tag_map=tag_map)