Adjust train.py to train both English and German models

This commit is contained in:
Wolfgang Seeker 2016-03-03 15:21:00 +01:00
parent 3448cb40a4
commit 690c5acabf
6 changed files with 67 additions and 45 deletions

View File

@ -14,6 +14,7 @@ import re
import spacy.util
from spacy.en import English
from spacy.de import German
from spacy.syntax.util import Config
from spacy.gold import read_json_file
@ -25,6 +26,7 @@ from spacy.syntax.arc_eager import ArcEager
from spacy.syntax.ner import BiluoPushDown
from spacy.tagger import Tagger
from spacy.syntax.parser import Parser
from spacy.syntax.nonproj import PseudoProjectivity
def _corrupt(c, noise_level):
@ -82,7 +84,7 @@ def _merge_sents(sents):
def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
seed=0, gold_preproc=False, n_sents=0, corruption_level=0,
beam_width=1, verbose=False,
use_orig_arc_eager=False):
use_orig_arc_eager=False, pseudoprojective=False):
dep_model_dir = path.join(model_dir, 'deps')
ner_model_dir = path.join(model_dir, 'ner')
pos_model_dir = path.join(model_dir, 'pos')
@ -96,9 +98,13 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
os.mkdir(ner_model_dir)
os.mkdir(pos_model_dir)
if pseudoprojective:
# preprocess training data here before ArcEager.get_labels() is called
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
labels=ArcEager.get_labels(gold_tuples),
beam_width=beam_width)
beam_width=beam_width,projectivize=pseudoprojective)
Config.write(ner_model_dir, 'config', features='ner', seed=seed,
labels=BiluoPushDown.get_labels(gold_tuples),
beam_width=0)
@ -107,6 +113,8 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
gold_tuples = gold_tuples[:n_sents]
nlp = Language(data_dir=model_dir, tagger=False, parser=False, entity=False)
if nlp.lang == 'de':
nlp.vocab.morphology.lemmatizer = lambda string,pos: set([string])
nlp.tagger = Tagger.blank(nlp.vocab, Tagger.default_templates())
nlp.parser = Parser.from_dir(dep_model_dir, nlp.vocab.strings, ArcEager)
nlp.entity = Parser.from_dir(ner_model_dir, nlp.vocab.strings, BiluoPushDown)
@ -131,12 +139,9 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
raw_text = add_noise(raw_text, corruption_level)
tokens = nlp.tokenizer(raw_text)
nlp.tagger(tokens)
gold = GoldParse(tokens, annot_tuples, make_projective=True)
gold = GoldParse(tokens, annot_tuples)
if not gold.is_projective:
raise Exception(
"Non-projective sentence in training, after we should "
"have enforced projectivity: %s" % annot_tuples
)
raise Exception("Non-projective sentence in training: %s" % annot_tuples)
loss += nlp.parser.train(tokens, gold)
nlp.entity.train(tokens, gold)
nlp.tagger.train(tokens, gold.tags)
@ -152,6 +157,8 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=False,
beam_width=None, cand_preproc=None):
nlp = Language(data_dir=model_dir)
if nlp.lang == 'de':
nlp.vocab.morphology.lemmatizer = lambda string,pos: set([string])
if beam_width is not None:
nlp.parser.cfg.beam_width = beam_width
scorer = Scorer()
@ -200,6 +207,7 @@ def write_parses(Language, dev_loc, model_dir, out_loc):
@plac.annotations(
language=("The language to train", "positional", None, str, ['en','de']),
train_loc=("Location of training file or directory"),
dev_loc=("Location of development file or directory"),
model_dir=("Location of output model directory",),
@ -211,19 +219,22 @@ def write_parses(Language, dev_loc, model_dir, out_loc):
n_iter=("Number of training iterations", "option", "i", int),
verbose=("Verbose error reporting", "flag", "v", bool),
debug=("Debug mode", "flag", "d", bool),
pseudoprojective=("Use pseudo-projective parsing", "flag", "p", bool),
)
def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
debug=False, corruption_level=0.0, gold_preproc=False, eval_only=False):
def main(language, train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
debug=False, corruption_level=0.0, gold_preproc=False, eval_only=False, pseudoprojective=False):
lang = {'en':English, 'de':German}.get(language)
if not eval_only:
gold_train = list(read_json_file(train_loc))
train(English, gold_train, model_dir,
train(lang, gold_train, model_dir,
feat_set='basic' if not debug else 'debug',
gold_preproc=gold_preproc, n_sents=n_sents,
corruption_level=corruption_level, n_iter=n_iter,
verbose=verbose)
verbose=verbose,pseudoprojective=pseudoprojective)
if out_loc:
write_parses(English, dev_loc, model_dir, out_loc)
scorer = evaluate(English, list(read_json_file(dev_loc)),
write_parses(lang, dev_loc, model_dir, out_loc)
scorer = evaluate(lang, list(read_json_file(dev_loc)),
model_dir, gold_preproc=gold_preproc, verbose=verbose)
print('TOK', scorer.token_acc)
print('POS', scorer.tags_acc)

View File

@ -6,4 +6,4 @@ from ..language import Language
class German(Language):
pass
lang = 'de'

View File

@ -244,14 +244,8 @@ cdef class GoldParse:
raise Exception("Cycle found: %s" % cycle)
if make_projective:
# projectivity here means non-proj arcs are being disconnected
np_arcs = []
for word in range(self.length):
if nonproj.is_nonproj_arc(word,self.heads):
np_arcs.append(word)
for np_arc in np_arcs:
self.heads[np_arc] = None
self.labels[np_arc] = ''
proj_heads,_ = nonproj.PseudoProjectivity.projectivize(self.heads,self.labels)
self.heads = proj_heads
self.brackets = {}
for (gold_start, gold_end, label_str) in brackets:

View File

@ -94,9 +94,10 @@ cdef class Parser:
moves = transition_system(strings, cfg.labels)
templates = get_templates(cfg.features)
model = ParserModel(templates)
project = cfg.projectivize if hasattr(cfg,'projectivize') else False
if path.exists(path.join(model_dir, 'model')):
model.load(path.join(model_dir, 'model'))
return cls(strings, moves, model, cfg.projectivize)
return cls(strings, moves, model, project)
@classmethod
def load(cls, pkg_or_str_or_file, vocab):

View File

@ -143,7 +143,7 @@ cdef class Tagger:
@classmethod
def blank(cls, vocab, templates):
model = TaggerModel(N_CONTEXT_FIELDS, templates)
model = TaggerModel(templates)
return cls(vocab, model)
@classmethod

View File

@ -1,9 +1,6 @@
from __future__ import unicode_literals
import pytest
from spacy.tokens.doc import Doc
from spacy.vocab import Vocab
from spacy.tokenizer import Tokenizer
from spacy.attrs import DEP, HEAD
import numpy
@ -56,12 +53,28 @@ def test_is_nonproj_tree():
assert(is_nonproj_tree(partial_tree) == False)
assert(is_nonproj_tree(multirooted_tree) == True)
def test_pseudoprojectivity():
def deprojectivize(proj_heads, deco_labels, EN):
    """Run PseudoProjectivity.deprojectivize on a tree given as plain lists.

    Builds a placeholder token sequence with ``EN.tokenizer``, loads the
    given heads and labels into it, calls
    ``PseudoProjectivity.deprojectivize`` on it, and reads the result back.

    Args:
        proj_heads: absolute head index for each token position.
        deco_labels: dependency label string for each token position
            (possibly decorated, e.g. ``'acl||dobj'``).
        EN: language pipeline object providing ``tokenizer`` and
            ``vocab.strings`` (supplied by the test fixture).

    Returns:
        A ``(heads, labels)`` tuple: absolute head indices and label strings
        read back from the token sequence after deprojectivization.
    """
    n_tokens = len(proj_heads)
    # The words themselves are irrelevant; only the sequence length matters.
    sent = EN.tokenizer.tokens_from_list(['whatever'] * n_tokens)

    # The HEAD column is filled with offsets relative to the token position,
    # and the DEP column with string-store ids rather than label strings.
    rel_proj_heads = [head - i for i, head in enumerate(proj_heads)]
    label_ids = [EN.vocab.strings[label] for label in deco_labels]

    # zip() is a lazy iterator on Python 3; materialise it so that numpy
    # builds an (n_tokens, 2) integer array instead of raising TypeError.
    parse = numpy.asarray(list(zip(rel_proj_heads, label_ids)),
                          dtype=numpy.int32)
    sent.from_array([HEAD, DEP], parse)

    PseudoProjectivity.deprojectivize(sent)

    # Read the annotations back and undo the two encodings applied above.
    parse = sent.to_array([HEAD, DEP])
    deproj_heads = [i + head for i, head in enumerate(parse[:, 0])]
    undeco_labels = [EN.vocab.strings[int(label_id)] for label_id in parse[:, 1]]
    return deproj_heads, undeco_labels
@pytest.mark.models
def test_pseudoprojectivity(EN):
tree = [1,2,2]
nonproj_tree = [1,2,2,4,5,2,7,4,2]
labels = ['NK','SB','ROOT','NK','OA','OC','SB','RC','--']
labels = ['det','nsubj','root','det','dobj','aux','nsubj','acl','punct']
nonproj_tree2 = [9,1,3,1,5,6,9,8,6,1,6,12,13,10,1]
labels2 = ['MO','ROOT','NK','SB','MO','NK','OA','NK','AG','OC','MNR','MO','NK','NK','--']
labels2 = ['advmod','root','det','nsubj','advmod','det','dobj','det','nmod','aux','nmod','advmod','det','amod','punct']
assert(PseudoProjectivity.decompose('X||Y') == ('X','Y'))
assert(PseudoProjectivity.decompose('X') == ('X',''))
@ -80,29 +93,32 @@ def test_pseudoprojectivity():
proj_heads, deco_labels = PseudoProjectivity.projectivize(nonproj_tree,labels)
assert(proj_heads == [1,2,2,4,5,2,7,5,2])
assert(deco_labels == ['NK','SB','ROOT','NK','OA','OC','SB','RC||OA','--'])
# deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize(proj_heads,deco_labels)
# assert(deproj_heads == nonproj_tree)
# assert(undeco_labels == labels)
assert(deco_labels == ['det','nsubj','root','det','dobj','aux','nsubj','acl||dobj','punct'])
deproj_heads, undeco_labels = deprojectivize(proj_heads,deco_labels,EN)
assert(deproj_heads == nonproj_tree)
assert(undeco_labels == labels)
proj_heads, deco_labels = PseudoProjectivity.projectivize(nonproj_tree2,labels2)
assert(proj_heads == [1,1,3,1,5,6,9,8,6,1,9,12,13,10,1])
assert(deco_labels == ['MO||OC','ROOT','NK','SB','MO','NK','OA','NK','AG','OC','MNR||OA','MO','NK','NK','--'])
# deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize(proj_heads,deco_labels)
# assert(deproj_heads == nonproj_tree2)
# assert(undeco_labels == labels2)
assert(deco_labels == ['advmod||aux','root','det','nsubj','advmod','det','dobj','det','nmod','aux','nmod||dobj','advmod','det','amod','punct'])
deproj_heads, undeco_labels = deprojectivize(proj_heads,deco_labels,EN)
assert(deproj_heads == nonproj_tree2)
assert(undeco_labels == labels2)
# if decoration is wrong such that there is no head with the desired label
# the structure is kept and the label is undecorated
# deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize([1,2,2,4,5,2,7,5,2],['NK','SB','ROOT','NK','OA','OC','SB','RC||DA','--'])
# assert(deproj_heads == [1,2,2,4,5,2,7,5,2])
# assert(undeco_labels == ['NK','SB','ROOT','NK','OA','OC','SB','RC','--'])
proj_heads = [1,2,2,4,5,2,7,5,2]
deco_labels = ['det','nsubj','root','det','dobj','aux','nsubj','acl||iobj','punct']
deproj_heads, undeco_labels = deprojectivize(proj_heads,deco_labels,EN)
assert(deproj_heads == proj_heads)
assert(undeco_labels == ['det','nsubj','root','det','dobj','aux','nsubj','acl','punct'])
# if there are two potential new heads, the first one is chosen even if it's wrong
# deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize([1,1,3,1,5,6,9,8,6,1,9,12,13,10,1], \
# ['MO||OC','ROOT','NK','OC','MO','NK','OA','NK','AG','OC','MNR||OA','MO','NK','NK','--'])
# assert(deproj_heads == [3,1,3,1,5,6,9,8,6,1,6,12,13,10,1])
# assert(undeco_labels == ['MO','ROOT','NK','OC','MO','NK','OA','NK','AG','OC','MNR','MO','NK','NK','--'])
proj_heads = [1,1,3,1,5,6,9,8,6,1,9,12,13,10,1]
deco_labels = ['advmod||aux','root','det','aux','advmod','det','dobj','det','nmod','aux','nmod||dobj','advmod','det','amod','punct']
deproj_heads, undeco_labels = deprojectivize(proj_heads,deco_labels,EN)
assert(deproj_heads == [3,1,3,1,5,6,9,8,6,1,6,12,13,10,1])
assert(undeco_labels == ['advmod','root','det','aux','advmod','det','dobj','det','nmod','aux','nmod','advmod','det','amod','punct'])