Use unicode literals in train_ud

This commit is contained in:
Matthew Honnibal 2016-11-25 17:45:45 -06:00
parent bc0a202c9c
commit 22189e60db

View File

@ -1,3 +1,4 @@
from __future__ import unicode_literals
import plac
import json
from os import path
@ -5,6 +6,7 @@ import shutil
import os
import random
import io
import pathlib
from spacy.tokens import Doc
from spacy.syntax.nonproj import PseudoProjectivity
@ -17,15 +19,12 @@ from spacy.syntax.parser import get_templates
from spacy.syntax.arc_eager import ArcEager
from spacy.scorer import Scorer
import spacy.attrs
import io
try:
from codecs import open
except ImportError:
pass
def read_conllx(loc):
with open(loc, 'r', 'utf8') as file_:
with io.open(loc, 'r', encoding='utf8') as file_:
text = file_.read()
for sent in text.strip().split('\n\n'):
lines = sent.strip().split('\n')
@ -56,6 +55,7 @@ def score_model(vocab, tagger, parser, gold_docs, verbose=False):
doc = Doc(vocab, words=words)
tagger(doc)
parser(doc)
PseudoProjectivity.deprojectivize(doc)
gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
scorer.score(doc, gold, verbose=verbose)
return scorer
@ -66,8 +66,13 @@ def main(train_loc, dev_loc, model_dir, tag_map_loc):
tag_map = json.loads(file_.read())
train_sents = list(read_conllx(train_loc))
train_sents = PseudoProjectivity.preprocess_training_data(train_sents)
actions = ArcEager.get_actions(gold_parses=train_sents)
features = get_templates('basic')
model_dir = pathlib.Path(model_dir)
with (model_dir / 'deps' / 'config.json').open('wb') as file_:
json.dump({'pseudoprojective': True, 'labels': actions, 'features': features}, file_)
vocab = Vocab(lex_attr_getters=Language.Defaults.lex_attr_getters, tag_map=tag_map)
# Populate vocab
@ -75,9 +80,12 @@ def main(train_loc, dev_loc, model_dir, tag_map_loc):
for (ids, words, tags, heads, deps, ner), _ in doc_sents:
for word in words:
_ = vocab[word]
for dep in deps:
_ = vocab[dep]
for tag in tags:
_ = vocab[tag]
for tag in tags:
assert tag in tag_map, repr(tag)
print(tags)
tagger = Tagger(vocab, tag_map=tag_map)
parser = DependencyParser(vocab, actions=actions, features=features)