Merge pull request #901 from raphael0202/train_ud

Split CONLLX file using tabs and not default split separators
This commit is contained in:
Matthew Honnibal 2017-03-21 23:39:45 +01:00 committed by GitHub
commit f4010053a6

View File

@ -1,18 +1,13 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import plac import plac
import json import json
from os import path
import shutil
import os
import random import random
import io
import pathlib import pathlib
from spacy.tokens import Doc from spacy.tokens import Doc
from spacy.syntax.nonproj import PseudoProjectivity from spacy.syntax.nonproj import PseudoProjectivity
from spacy.language import Language from spacy.language import Language
from spacy.gold import GoldParse from spacy.gold import GoldParse
from spacy.vocab import Vocab
from spacy.tagger import Tagger from spacy.tagger import Tagger
from spacy.pipeline import DependencyParser, BeamDependencyParser from spacy.pipeline import DependencyParser, BeamDependencyParser
from spacy.syntax.parser import get_templates from spacy.syntax.parser import get_templates
@ -23,7 +18,6 @@ import spacy.attrs
import io import io
def read_conllx(loc, n=0): def read_conllx(loc, n=0):
with io.open(loc, 'r', encoding='utf8') as file_: with io.open(loc, 'r', encoding='utf8') as file_:
text = file_.read() text = file_.read()
@ -35,7 +29,8 @@ def read_conllx(loc, n=0):
lines.pop(0) lines.pop(0)
tokens = [] tokens = []
for line in lines: for line in lines:
id_, word, lemma, pos, tag, morph, head, dep, _1, _2 = line.split() id_, word, lemma, pos, tag, morph, head, dep, _1, \
_2 = line.split('\t')
if '-' in id_ or '.' in id_: if '-' in id_ or '.' in id_:
continue continue
try: try:
@ -134,7 +129,7 @@ def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None):
random.shuffle(train_sents) random.shuffle(train_sents)
scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc)) scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
print('%d:\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas, scorer.tags_acc)) print('%d:\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas, scorer.tags_acc))
nlp = Language(vocab=vocab, tagger=tagger, parser=parser) nlp = LangClass(vocab=vocab, tagger=tagger, parser=parser)
nlp.end_training(model_dir) nlp.end_training(model_dir)
scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc)) scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
print('%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.las, scorer.tags_acc)) print('%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.las, scorer.tags_acc))