adjust train.py to train both english and german models

This commit is contained in:
Wolfgang Seeker 2016-03-03 15:21:00 +01:00
parent 3448cb40a4
commit 690c5acabf
6 changed files with 67 additions and 45 deletions

View File

@ -14,6 +14,7 @@ import re
import spacy.util import spacy.util
from spacy.en import English from spacy.en import English
from spacy.de import German
from spacy.syntax.util import Config from spacy.syntax.util import Config
from spacy.gold import read_json_file from spacy.gold import read_json_file
@ -25,6 +26,7 @@ from spacy.syntax.arc_eager import ArcEager
from spacy.syntax.ner import BiluoPushDown from spacy.syntax.ner import BiluoPushDown
from spacy.tagger import Tagger from spacy.tagger import Tagger
from spacy.syntax.parser import Parser from spacy.syntax.parser import Parser
from spacy.syntax.nonproj import PseudoProjectivity
def _corrupt(c, noise_level): def _corrupt(c, noise_level):
@ -82,7 +84,7 @@ def _merge_sents(sents):
def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
seed=0, gold_preproc=False, n_sents=0, corruption_level=0, seed=0, gold_preproc=False, n_sents=0, corruption_level=0,
beam_width=1, verbose=False, beam_width=1, verbose=False,
use_orig_arc_eager=False): use_orig_arc_eager=False, pseudoprojective=False):
dep_model_dir = path.join(model_dir, 'deps') dep_model_dir = path.join(model_dir, 'deps')
ner_model_dir = path.join(model_dir, 'ner') ner_model_dir = path.join(model_dir, 'ner')
pos_model_dir = path.join(model_dir, 'pos') pos_model_dir = path.join(model_dir, 'pos')
@ -96,9 +98,13 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
os.mkdir(ner_model_dir) os.mkdir(ner_model_dir)
os.mkdir(pos_model_dir) os.mkdir(pos_model_dir)
if pseudoprojective:
# preprocess training data here before ArcEager.get_labels() is called
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed, Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
labels=ArcEager.get_labels(gold_tuples), labels=ArcEager.get_labels(gold_tuples),
beam_width=beam_width) beam_width=beam_width,projectivize=pseudoprojective)
Config.write(ner_model_dir, 'config', features='ner', seed=seed, Config.write(ner_model_dir, 'config', features='ner', seed=seed,
labels=BiluoPushDown.get_labels(gold_tuples), labels=BiluoPushDown.get_labels(gold_tuples),
beam_width=0) beam_width=0)
@ -107,6 +113,8 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
gold_tuples = gold_tuples[:n_sents] gold_tuples = gold_tuples[:n_sents]
nlp = Language(data_dir=model_dir, tagger=False, parser=False, entity=False) nlp = Language(data_dir=model_dir, tagger=False, parser=False, entity=False)
if nlp.lang == 'de':
nlp.vocab.morphology.lemmatizer = lambda string,pos: set([string])
nlp.tagger = Tagger.blank(nlp.vocab, Tagger.default_templates()) nlp.tagger = Tagger.blank(nlp.vocab, Tagger.default_templates())
nlp.parser = Parser.from_dir(dep_model_dir, nlp.vocab.strings, ArcEager) nlp.parser = Parser.from_dir(dep_model_dir, nlp.vocab.strings, ArcEager)
nlp.entity = Parser.from_dir(ner_model_dir, nlp.vocab.strings, BiluoPushDown) nlp.entity = Parser.from_dir(ner_model_dir, nlp.vocab.strings, BiluoPushDown)
@ -131,12 +139,9 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
raw_text = add_noise(raw_text, corruption_level) raw_text = add_noise(raw_text, corruption_level)
tokens = nlp.tokenizer(raw_text) tokens = nlp.tokenizer(raw_text)
nlp.tagger(tokens) nlp.tagger(tokens)
gold = GoldParse(tokens, annot_tuples, make_projective=True) gold = GoldParse(tokens, annot_tuples)
if not gold.is_projective: if not gold.is_projective:
raise Exception( raise Exception("Non-projective sentence in training: %s" % annot_tuples)
"Non-projective sentence in training, after we should "
"have enforced projectivity: %s" % annot_tuples
)
loss += nlp.parser.train(tokens, gold) loss += nlp.parser.train(tokens, gold)
nlp.entity.train(tokens, gold) nlp.entity.train(tokens, gold)
nlp.tagger.train(tokens, gold.tags) nlp.tagger.train(tokens, gold.tags)
@ -152,6 +157,8 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=False, def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=False,
beam_width=None, cand_preproc=None): beam_width=None, cand_preproc=None):
nlp = Language(data_dir=model_dir) nlp = Language(data_dir=model_dir)
if nlp.lang == 'de':
nlp.vocab.morphology.lemmatizer = lambda string,pos: set([string])
if beam_width is not None: if beam_width is not None:
nlp.parser.cfg.beam_width = beam_width nlp.parser.cfg.beam_width = beam_width
scorer = Scorer() scorer = Scorer()
@ -200,6 +207,7 @@ def write_parses(Language, dev_loc, model_dir, out_loc):
@plac.annotations( @plac.annotations(
language=("The language to train", "positional", None, str, ['en','de']),
train_loc=("Location of training file or directory"), train_loc=("Location of training file or directory"),
dev_loc=("Location of development file or directory"), dev_loc=("Location of development file or directory"),
model_dir=("Location of output model directory",), model_dir=("Location of output model directory",),
@ -211,19 +219,22 @@ def write_parses(Language, dev_loc, model_dir, out_loc):
n_iter=("Number of training iterations", "option", "i", int), n_iter=("Number of training iterations", "option", "i", int),
verbose=("Verbose error reporting", "flag", "v", bool), verbose=("Verbose error reporting", "flag", "v", bool),
debug=("Debug mode", "flag", "d", bool), debug=("Debug mode", "flag", "d", bool),
pseudoprojective=("Use pseudo-projective parsing", "flag", "p", bool),
) )
def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False, def main(language, train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
debug=False, corruption_level=0.0, gold_preproc=False, eval_only=False): debug=False, corruption_level=0.0, gold_preproc=False, eval_only=False, pseudoprojective=False):
lang = {'en':English, 'de':German}.get(language)
if not eval_only: if not eval_only:
gold_train = list(read_json_file(train_loc)) gold_train = list(read_json_file(train_loc))
train(English, gold_train, model_dir, train(lang, gold_train, model_dir,
feat_set='basic' if not debug else 'debug', feat_set='basic' if not debug else 'debug',
gold_preproc=gold_preproc, n_sents=n_sents, gold_preproc=gold_preproc, n_sents=n_sents,
corruption_level=corruption_level, n_iter=n_iter, corruption_level=corruption_level, n_iter=n_iter,
verbose=verbose) verbose=verbose,pseudoprojective=pseudoprojective)
if out_loc: if out_loc:
write_parses(English, dev_loc, model_dir, out_loc) write_parses(lang, dev_loc, model_dir, out_loc)
scorer = evaluate(English, list(read_json_file(dev_loc)), scorer = evaluate(lang, list(read_json_file(dev_loc)),
model_dir, gold_preproc=gold_preproc, verbose=verbose) model_dir, gold_preproc=gold_preproc, verbose=verbose)
print('TOK', scorer.token_acc) print('TOK', scorer.token_acc)
print('POS', scorer.tags_acc) print('POS', scorer.tags_acc)

View File

@ -6,4 +6,4 @@ from ..language import Language
class German(Language): class German(Language):
pass lang = 'de'

View File

@ -244,14 +244,8 @@ cdef class GoldParse:
raise Exception("Cycle found: %s" % cycle) raise Exception("Cycle found: %s" % cycle)
if make_projective: if make_projective:
# projectivity here means non-proj arcs are being disconnected proj_heads,_ = nonproj.PseudoProjectivity.projectivize(self.heads,self.labels)
np_arcs = [] self.heads = proj_heads
for word in range(self.length):
if nonproj.is_nonproj_arc(word,self.heads):
np_arcs.append(word)
for np_arc in np_arcs:
self.heads[np_arc] = None
self.labels[np_arc] = ''
self.brackets = {} self.brackets = {}
for (gold_start, gold_end, label_str) in brackets: for (gold_start, gold_end, label_str) in brackets:

View File

@ -94,9 +94,10 @@ cdef class Parser:
moves = transition_system(strings, cfg.labels) moves = transition_system(strings, cfg.labels)
templates = get_templates(cfg.features) templates = get_templates(cfg.features)
model = ParserModel(templates) model = ParserModel(templates)
project = cfg.projectivize if hasattr(cfg,'projectivize') else False
if path.exists(path.join(model_dir, 'model')): if path.exists(path.join(model_dir, 'model')):
model.load(path.join(model_dir, 'model')) model.load(path.join(model_dir, 'model'))
return cls(strings, moves, model, cfg.projectivize) return cls(strings, moves, model, project)
@classmethod @classmethod
def load(cls, pkg_or_str_or_file, vocab): def load(cls, pkg_or_str_or_file, vocab):

View File

@ -143,7 +143,7 @@ cdef class Tagger:
@classmethod @classmethod
def blank(cls, vocab, templates): def blank(cls, vocab, templates):
model = TaggerModel(N_CONTEXT_FIELDS, templates) model = TaggerModel(templates)
return cls(vocab, model) return cls(vocab, model)
@classmethod @classmethod

View File

@ -1,9 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import pytest import pytest
from spacy.tokens.doc import Doc
from spacy.vocab import Vocab
from spacy.tokenizer import Tokenizer
from spacy.attrs import DEP, HEAD from spacy.attrs import DEP, HEAD
import numpy import numpy
@ -56,12 +53,28 @@ def test_is_nonproj_tree():
assert(is_nonproj_tree(partial_tree) == False) assert(is_nonproj_tree(partial_tree) == False)
assert(is_nonproj_tree(multirooted_tree) == True) assert(is_nonproj_tree(multirooted_tree) == True)
def test_pseudoprojectivity():
def deprojectivize(proj_heads, deco_labels, EN):
slen = len(proj_heads)
sent = EN.tokenizer.tokens_from_list(['whatever'] * slen)
rel_proj_heads = [ head-i for i,head in enumerate(proj_heads) ]
labelids = [ EN.vocab.strings[label] for label in deco_labels ]
parse = numpy.asarray(zip(rel_proj_heads,labelids), dtype=numpy.int32)
sent.from_array([HEAD,DEP],parse)
PseudoProjectivity.deprojectivize(sent)
parse = sent.to_array([HEAD,DEP])
deproj_heads = [ i+head for i,head in enumerate(parse[:,0]) ]
undeco_labels = [ EN.vocab.strings[int(labelid)] for labelid in parse[:,1] ]
return deproj_heads, undeco_labels
@pytest.mark.models
def test_pseudoprojectivity(EN):
tree = [1,2,2] tree = [1,2,2]
nonproj_tree = [1,2,2,4,5,2,7,4,2] nonproj_tree = [1,2,2,4,5,2,7,4,2]
labels = ['NK','SB','ROOT','NK','OA','OC','SB','RC','--'] labels = ['det','nsubj','root','det','dobj','aux','nsubj','acl','punct']
nonproj_tree2 = [9,1,3,1,5,6,9,8,6,1,6,12,13,10,1] nonproj_tree2 = [9,1,3,1,5,6,9,8,6,1,6,12,13,10,1]
labels2 = ['MO','ROOT','NK','SB','MO','NK','OA','NK','AG','OC','MNR','MO','NK','NK','--'] labels2 = ['advmod','root','det','nsubj','advmod','det','dobj','det','nmod','aux','nmod','advmod','det','amod','punct']
assert(PseudoProjectivity.decompose('X||Y') == ('X','Y')) assert(PseudoProjectivity.decompose('X||Y') == ('X','Y'))
assert(PseudoProjectivity.decompose('X') == ('X','')) assert(PseudoProjectivity.decompose('X') == ('X',''))
@ -80,29 +93,32 @@ def test_pseudoprojectivity():
proj_heads, deco_labels = PseudoProjectivity.projectivize(nonproj_tree,labels) proj_heads, deco_labels = PseudoProjectivity.projectivize(nonproj_tree,labels)
assert(proj_heads == [1,2,2,4,5,2,7,5,2]) assert(proj_heads == [1,2,2,4,5,2,7,5,2])
assert(deco_labels == ['NK','SB','ROOT','NK','OA','OC','SB','RC||OA','--']) assert(deco_labels == ['det','nsubj','root','det','dobj','aux','nsubj','acl||dobj','punct'])
# deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize(proj_heads,deco_labels) deproj_heads, undeco_labels = deprojectivize(proj_heads,deco_labels,EN)
# assert(deproj_heads == nonproj_tree) assert(deproj_heads == nonproj_tree)
# assert(undeco_labels == labels) assert(undeco_labels == labels)
proj_heads, deco_labels = PseudoProjectivity.projectivize(nonproj_tree2,labels2) proj_heads, deco_labels = PseudoProjectivity.projectivize(nonproj_tree2,labels2)
assert(proj_heads == [1,1,3,1,5,6,9,8,6,1,9,12,13,10,1]) assert(proj_heads == [1,1,3,1,5,6,9,8,6,1,9,12,13,10,1])
assert(deco_labels == ['MO||OC','ROOT','NK','SB','MO','NK','OA','NK','AG','OC','MNR||OA','MO','NK','NK','--']) assert(deco_labels == ['advmod||aux','root','det','nsubj','advmod','det','dobj','det','nmod','aux','nmod||dobj','advmod','det','amod','punct'])
# deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize(proj_heads,deco_labels) deproj_heads, undeco_labels = deprojectivize(proj_heads,deco_labels,EN)
# assert(deproj_heads == nonproj_tree2) assert(deproj_heads == nonproj_tree2)
# assert(undeco_labels == labels2) assert(undeco_labels == labels2)
# if decoration is wrong such that there is no head with the desired label # if decoration is wrong such that there is no head with the desired label
# the structure is kept and the label is undecorated # the structure is kept and the label is undecorated
# deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize([1,2,2,4,5,2,7,5,2],['NK','SB','ROOT','NK','OA','OC','SB','RC||DA','--']) proj_heads = [1,2,2,4,5,2,7,5,2]
# assert(deproj_heads == [1,2,2,4,5,2,7,5,2]) deco_labels = ['det','nsubj','root','det','dobj','aux','nsubj','acl||iobj','punct']
# assert(undeco_labels == ['NK','SB','ROOT','NK','OA','OC','SB','RC','--']) deproj_heads, undeco_labels = deprojectivize(proj_heads,deco_labels,EN)
assert(deproj_heads == proj_heads)
assert(undeco_labels == ['det','nsubj','root','det','dobj','aux','nsubj','acl','punct'])
# if there are two potential new heads, the first one is chosen even if it's wrong # if there are two potential new heads, the first one is chosen even if it's wrong
# deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize([1,1,3,1,5,6,9,8,6,1,9,12,13,10,1], \ proj_heads = [1,1,3,1,5,6,9,8,6,1,9,12,13,10,1]
# ['MO||OC','ROOT','NK','OC','MO','NK','OA','NK','AG','OC','MNR||OA','MO','NK','NK','--']) deco_labels = ['advmod||aux','root','det','aux','advmod','det','dobj','det','nmod','aux','nmod||dobj','advmod','det','amod','punct']
# assert(deproj_heads == [3,1,3,1,5,6,9,8,6,1,6,12,13,10,1]) deproj_heads, undeco_labels = deprojectivize(proj_heads,deco_labels,EN)
# assert(undeco_labels == ['MO','ROOT','NK','OC','MO','NK','OA','NK','AG','OC','MNR','MO','NK','NK','--']) assert(deproj_heads == [3,1,3,1,5,6,9,8,6,1,6,12,13,10,1])
assert(undeco_labels == ['advmod','root','det','aux','advmod','det','dobj','det','nmod','aux','nmod','advmod','det','amod','punct'])