diff --git a/bin/parser/train_ud.py b/bin/parser/train_ud.py index f02aa2a73..4efc063d7 100644 --- a/bin/parser/train_ud.py +++ b/bin/parser/train_ud.py @@ -1,3 +1,4 @@ +from __future__ import unicode_literals import plac import json from os import path @@ -5,6 +6,7 @@ import shutil import os import random import io +import pathlib from spacy.tokens import Doc from spacy.syntax.nonproj import PseudoProjectivity @@ -17,15 +19,12 @@ from spacy.syntax.parser import get_templates from spacy.syntax.arc_eager import ArcEager from spacy.scorer import Scorer import spacy.attrs +import io -try: - from codecs import open -except ImportError: - pass def read_conllx(loc): - with open(loc, 'r', 'utf8') as file_: + with io.open(loc, 'r', encoding='utf8') as file_: text = file_.read() for sent in text.strip().split('\n\n'): lines = sent.strip().split('\n') @@ -56,6 +55,7 @@ def score_model(vocab, tagger, parser, gold_docs, verbose=False): doc = Doc(vocab, words=words) tagger(doc) parser(doc) + PseudoProjectivity.deprojectivize(doc) gold = GoldParse(doc, tags=tags, heads=heads, deps=deps) scorer.score(doc, gold, verbose=verbose) return scorer @@ -66,8 +66,13 @@ def main(train_loc, dev_loc, model_dir, tag_map_loc): tag_map = json.loads(file_.read()) train_sents = list(read_conllx(train_loc)) train_sents = PseudoProjectivity.preprocess_training_data(train_sents) + actions = ArcEager.get_actions(gold_parses=train_sents) features = get_templates('basic') + + model_dir = pathlib.Path(model_dir) + with (model_dir / 'deps' / 'config.json').open('wb') as file_: + json.dump({'pseudoprojective': True, 'labels': actions, 'features': features}, file_) vocab = Vocab(lex_attr_getters=Language.Defaults.lex_attr_getters, tag_map=tag_map) # Populate vocab @@ -75,9 +80,12 @@ def main(train_loc, dev_loc, model_dir, tag_map_loc): for (ids, words, tags, heads, deps, ner), _ in doc_sents: for word in words: _ = vocab[word] + for dep in deps: + _ = vocab[dep] + for tag in tags: + _ = vocab[tag] for tag in tags: assert tag in tag_map, repr(tag) - print(tags) tagger = Tagger(vocab, tag_map=tag_map) parser = DependencyParser(vocab, actions=actions, features=features)