mirror of https://github.com/explosion/spaCy.git (synced 2025-02-04 13:40:34 +03:00)

commit 690c5acabf (parent 3448cb40a4)

    adjust train.py to train both english and german models

@@ -14,6 +14,7 @@ import re
 import spacy.util
 from spacy.en import English
+from spacy.de import German

 from spacy.syntax.util import Config
 from spacy.gold import read_json_file

@@ -25,6 +26,7 @@ from spacy.syntax.arc_eager import ArcEager
 from spacy.syntax.ner import BiluoPushDown
 from spacy.tagger import Tagger
 from spacy.syntax.parser import Parser
+from spacy.syntax.nonproj import PseudoProjectivity


 def _corrupt(c, noise_level):

@@ -82,7 +84,7 @@ def _merge_sents(sents):
 def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
           seed=0, gold_preproc=False, n_sents=0, corruption_level=0,
           beam_width=1, verbose=False,
-          use_orig_arc_eager=False):
+          use_orig_arc_eager=False, pseudoprojective=False):
     dep_model_dir = path.join(model_dir, 'deps')
     ner_model_dir = path.join(model_dir, 'ner')
     pos_model_dir = path.join(model_dir, 'pos')

@@ -96,9 +98,13 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
     os.mkdir(ner_model_dir)
     os.mkdir(pos_model_dir)

+    if pseudoprojective:
+        # preprocess training data here before ArcEager.get_labels() is called
+        gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
+
     Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
                  labels=ArcEager.get_labels(gold_tuples),
-                 beam_width=beam_width)
+                 beam_width=beam_width,projectivize=pseudoprojective)
     Config.write(ner_model_dir, 'config', features='ner', seed=seed,
                  labels=BiluoPushDown.get_labels(gold_tuples),
                  beam_width=0)

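For context: pseudo-projective parsing (Nivre & Nilsson, 2005) makes non-projective trees learnable by a strictly projective parser. Each non-projective arc is lifted to a dominating head, and the dependency label records the original head's relation so the lift can be undone at parse time. A worked example using the trees asserted in the tests at the end of this commit:

    # heads are absolute token indices; arc 7 -> 4 crosses arc 5 -> 2,
    # so the tree is non-projective
    heads  = [1, 2, 2, 4, 5, 2, 7, 4, 2]
    labels = ['det', 'nsubj', 'root', 'det', 'dobj', 'aux', 'nsubj', 'acl', 'punct']

    # PseudoProjectivity.projectivize(heads, labels) lifts token 7 from head 4
    # to head 5 and decorates its label with the original head's own relation:
    #   heads  -> [1, 2, 2, 4, 5, 2, 7, 5, 2]
    #   labels -> [..., 'acl||dobj', ...]    # 'dobj' is token 4's label
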
@@ -107,6 +113,8 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
         gold_tuples = gold_tuples[:n_sents]

     nlp = Language(data_dir=model_dir, tagger=False, parser=False, entity=False)
+    if nlp.lang == 'de':
+        nlp.vocab.morphology.lemmatizer = lambda string,pos: set([string])
     nlp.tagger = Tagger.blank(nlp.vocab, Tagger.default_templates())
     nlp.parser = Parser.from_dir(dep_model_dir, nlp.vocab.strings, ArcEager)
     nlp.entity = Parser.from_dir(ner_model_dir, nlp.vocab.strings, BiluoPushDown)

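German support lands here without lemmatizer data, so the commit installs an identity lemmatizer: every surface form becomes its own one-element lemma set. In isolation (the POS argument is accepted but ignored):

    lemmatize = lambda string, pos: set([string])
    assert lemmatize(u'Häuser', u'NOUN') == set([u'Häuser'])
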
@@ -131,12 +139,9 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
                     raw_text = add_noise(raw_text, corruption_level)
                     tokens = nlp.tokenizer(raw_text)
                 nlp.tagger(tokens)
-                gold = GoldParse(tokens, annot_tuples, make_projective=True)
+                gold = GoldParse(tokens, annot_tuples)
                 if not gold.is_projective:
-                    raise Exception(
-                        "Non-projective sentence in training, after we should "
-                        "have enforced projectivity: %s" % annot_tuples
-                    )
+                    raise Exception("Non-projective sentence in training: %s" % annot_tuples)
                 loss += nlp.parser.train(tokens, gold)
                 nlp.entity.train(tokens, gold)
                 nlp.tagger.train(tokens, gold.tags)

@@ -152,6 +157,8 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
 def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=False,
              beam_width=None, cand_preproc=None):
     nlp = Language(data_dir=model_dir)
+    if nlp.lang == 'de':
+        nlp.vocab.morphology.lemmatizer = lambda string,pos: set([string])
     if beam_width is not None:
         nlp.parser.cfg.beam_width = beam_width
     scorer = Scorer()

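The same two-line stub now appears in both train() and evaluate(). A small helper would keep the call sites in sync; a sketch, not part of the commit:

    def install_identity_lemmatizer(nlp):
        # German has no lemmatizer data yet; fall back to the surface form.
        if nlp.lang == 'de':
            nlp.vocab.morphology.lemmatizer = lambda string, pos: set([string])
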
@@ -200,6 +207,7 @@ def write_parses(Language, dev_loc, model_dir, out_loc):


 @plac.annotations(
+    language=("The language to train", "positional", None, str, ['en','de']),
     train_loc=("Location of training file or directory"),
     dev_loc=("Location of development file or directory"),
     model_dir=("Location of output model directory",),

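plac reads each annotation tuple as (help, kind, abbrev, type, choices), so the new entry makes the language a required positional argument restricted to the two supported codes. A minimal standalone sketch of the mechanism (script name hypothetical):

    import plac

    @plac.annotations(
        language=("The language to train", "positional", None, str, ['en', 'de']),
    )
    def main(language):
        print(language)

    if __name__ == '__main__':
        plac.call(main)   # e.g. `python demo.py de`; other codes are rejected by plac
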
@@ -211,19 +219,22 @@ def write_parses(Language, dev_loc, model_dir, out_loc):
     n_iter=("Number of training iterations", "option", "i", int),
     verbose=("Verbose error reporting", "flag", "v", bool),
     debug=("Debug mode", "flag", "d", bool),
+    pseudoprojective=("Use pseudo-projective parsing", "flag", "p", bool),
 )
-def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
-         debug=False, corruption_level=0.0, gold_preproc=False, eval_only=False):
+def main(language, train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
+         debug=False, corruption_level=0.0, gold_preproc=False, eval_only=False, pseudoprojective=False):
+    lang = {'en':English, 'de':German}.get(language)
+
     if not eval_only:
         gold_train = list(read_json_file(train_loc))
-        train(English, gold_train, model_dir,
+        train(lang, gold_train, model_dir,
               feat_set='basic' if not debug else 'debug',
               gold_preproc=gold_preproc, n_sents=n_sents,
               corruption_level=corruption_level, n_iter=n_iter,
-              verbose=verbose)
+              verbose=verbose,pseudoprojective=pseudoprojective)
     if out_loc:
-        write_parses(English, dev_loc, model_dir, out_loc)
-    scorer = evaluate(English, list(read_json_file(dev_loc)),
+        write_parses(lang, dev_loc, model_dir, out_loc)
+    scorer = evaluate(lang, list(read_json_file(dev_loc)),
                       model_dir, gold_preproc=gold_preproc, verbose=verbose)
     print('TOK', scorer.token_acc)
     print('POS', scorer.tags_acc)

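main() now resolves the language code to a Language subclass once and threads it through train(), write_parses() and evaluate(). The dict lookup itself is permissive; it is plac's choices list above that actually rejects unknown codes:

    # hypothetical invocation (script path and data locations assumed, not from the diff):
    #   python train.py de train.json dev.json /models/de -p
    lang = {'en': English, 'de': German}.get(language)   # any other code would yield None
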
@@ -6,4 +6,4 @@ from ..language import Language


 class German(Language):
-    pass
+    lang = 'de'

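This hunk is in the German language package (apparently spacy/de/__init__.py, given the relative import). The lang attribute is what the nlp.lang == 'de' checks in train.py key on; a minimal sketch of the mechanism, with the base-class detail assumed rather than shown in the diff:

    class Language(object):
        lang = None            # overridden per language

    class German(Language):
        lang = 'de'

    assert German.lang == 'de'   # enables the identity-lemmatizer branches above
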
@@ -244,14 +244,8 @@ cdef class GoldParse:
             raise Exception("Cycle found: %s" % cycle)

         if make_projective:
-            # projectivity here means non-proj arcs are being disconnected
-            np_arcs = []
-            for word in range(self.length):
-                if nonproj.is_nonproj_arc(word,self.heads):
-                    np_arcs.append(word)
-            for np_arc in np_arcs:
-                self.heads[np_arc] = None
-                self.labels[np_arc] = ''
+            proj_heads,_ = nonproj.PseudoProjectivity.projectivize(self.heads,self.labels)
+            self.heads = proj_heads

         self.brackets = {}
         for (gold_start, gold_end, label_str) in brackets:

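The behavioral change in GoldParse (apparently spacy/gold.pyx, a Cython class): the old code simply disconnected every non-projective arc (head None, empty label), losing those annotations; the new code repairs the tree by lifting, so every token keeps a head. Note that only the lifted heads are kept, while the decorated labels are discarded:

    # inside the make_projective branch, as in the diff:
    proj_heads, _ = nonproj.PseudoProjectivity.projectivize(self.heads, self.labels)
    self.heads = proj_heads
    # the underscore drops the decorated 'X||Y' labels; label decoration only
    # matters on the pseudoprojective training path, where train.py preprocesses
    # gold_tuples itself (see above)
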
@@ -94,9 +94,10 @@ cdef class Parser:
         moves = transition_system(strings, cfg.labels)
         templates = get_templates(cfg.features)
         model = ParserModel(templates)
+        project = cfg.projectivize if hasattr(cfg,'projectivize') else False
         if path.exists(path.join(model_dir, 'model')):
             model.load(path.join(model_dir, 'model'))
-        return cls(strings, moves, model, cfg.projectivize)
+        return cls(strings, moves, model, project)

     @classmethod
     def load(cls, pkg_or_str_or_file, vocab):

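The hasattr guard keeps Parser.from_dir() compatible with config files saved before projectivize existed. An equivalent, arguably tidier spelling with the same False default:

    project = getattr(cfg, 'projectivize', False)
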
@@ -143,7 +143,7 @@ cdef class Tagger:

     @classmethod
     def blank(cls, vocab, templates):
-        model = TaggerModel(N_CONTEXT_FIELDS, templates)
+        model = TaggerModel(templates)
         return cls(vocab, model)

     @classmethod

@@ -1,9 +1,6 @@
 from __future__ import unicode_literals
 import pytest

-from spacy.tokens.doc import Doc
-from spacy.vocab import Vocab
-from spacy.tokenizer import Tokenizer
+from spacy.attrs import DEP, HEAD
+import numpy

@@ -56,12 +53,28 @@ def test_is_nonproj_tree():
     assert(is_nonproj_tree(partial_tree) == False)
     assert(is_nonproj_tree(multirooted_tree) == True)

-def test_pseudoprojectivity():
+def deprojectivize(proj_heads, deco_labels, EN):
+    slen = len(proj_heads)
+    sent = EN.tokenizer.tokens_from_list(['whatever'] * slen)
+    rel_proj_heads = [ head-i for i,head in enumerate(proj_heads) ]
+    labelids = [ EN.vocab.strings[label] for label in deco_labels ]
+    parse = numpy.asarray(zip(rel_proj_heads,labelids), dtype=numpy.int32)
+    sent.from_array([HEAD,DEP],parse)
+    PseudoProjectivity.deprojectivize(sent)
+    parse = sent.to_array([HEAD,DEP])
+    deproj_heads = [ i+head for i,head in enumerate(parse[:,0]) ]
+    undeco_labels = [ EN.vocab.strings[int(labelid)] for labelid in parse[:,1] ]
+    return deproj_heads, undeco_labels
+
+
+@pytest.mark.models
+def test_pseudoprojectivity(EN):
     tree = [1,2,2]
     nonproj_tree = [1,2,2,4,5,2,7,4,2]
-    labels = ['NK','SB','ROOT','NK','OA','OC','SB','RC','--']
+    labels = ['det','nsubj','root','det','dobj','aux','nsubj','acl','punct']
     nonproj_tree2 = [9,1,3,1,5,6,9,8,6,1,6,12,13,10,1]
-    labels2 = ['MO','ROOT','NK','SB','MO','NK','OA','NK','AG','OC','MNR','MO','NK','NK','--']
+    labels2 = ['advmod','root','det','nsubj','advmod','det','dobj','det','nmod','aux','nmod','advmod','det','amod','punct']

     assert(PseudoProjectivity.decompose('X||Y') == ('X','Y'))
     assert(PseudoProjectivity.decompose('X') == ('X',''))

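The deprojectivize() helper round-trips a parse through a Doc array, where heads are stored as offsets relative to each token; hence head-i on the way in and i+head on the way out. The index arithmetic in isolation:

    proj_heads = [1, 2, 2]                             # absolute head indices
    rel = [h - i for i, h in enumerate(proj_heads)]    # -> [1, 1, 0] (relative)
    assert [i + r for i, r in enumerate(rel)] == proj_heads

One portability note: numpy.asarray(zip(...)) assumes Python 2, where zip() returns a list; under Python 3 it would need list(zip(...)).
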
|
@ -80,29 +93,32 @@ def test_pseudoprojectivity():
|
|||
|
||||
proj_heads, deco_labels = PseudoProjectivity.projectivize(nonproj_tree,labels)
|
||||
assert(proj_heads == [1,2,2,4,5,2,7,5,2])
|
||||
assert(deco_labels == ['NK','SB','ROOT','NK','OA','OC','SB','RC||OA','--'])
|
||||
# deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize(proj_heads,deco_labels)
|
||||
# assert(deproj_heads == nonproj_tree)
|
||||
# assert(undeco_labels == labels)
|
||||
assert(deco_labels == ['det','nsubj','root','det','dobj','aux','nsubj','acl||dobj','punct'])
|
||||
deproj_heads, undeco_labels = deprojectivize(proj_heads,deco_labels,EN)
|
||||
assert(deproj_heads == nonproj_tree)
|
||||
assert(undeco_labels == labels)
|
||||
|
||||
proj_heads, deco_labels = PseudoProjectivity.projectivize(nonproj_tree2,labels2)
|
||||
assert(proj_heads == [1,1,3,1,5,6,9,8,6,1,9,12,13,10,1])
|
||||
assert(deco_labels == ['MO||OC','ROOT','NK','SB','MO','NK','OA','NK','AG','OC','MNR||OA','MO','NK','NK','--'])
|
||||
# deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize(proj_heads,deco_labels)
|
||||
# assert(deproj_heads == nonproj_tree2)
|
||||
# assert(undeco_labels == labels2)
|
||||
assert(deco_labels == ['advmod||aux','root','det','nsubj','advmod','det','dobj','det','nmod','aux','nmod||dobj','advmod','det','amod','punct'])
|
||||
deproj_heads, undeco_labels = deprojectivize(proj_heads,deco_labels,EN)
|
||||
assert(deproj_heads == nonproj_tree2)
|
||||
assert(undeco_labels == labels2)
|
||||
|
||||
# if decoration is wrong such that there is no head with the desired label
|
||||
# the structure is kept and the label is undecorated
|
||||
# deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize([1,2,2,4,5,2,7,5,2],['NK','SB','ROOT','NK','OA','OC','SB','RC||DA','--'])
|
||||
# assert(deproj_heads == [1,2,2,4,5,2,7,5,2])
|
||||
# assert(undeco_labels == ['NK','SB','ROOT','NK','OA','OC','SB','RC','--'])
|
||||
proj_heads = [1,2,2,4,5,2,7,5,2]
|
||||
deco_labels = ['det','nsubj','root','det','dobj','aux','nsubj','acl||iobj','punct']
|
||||
deproj_heads, undeco_labels = deprojectivize(proj_heads,deco_labels,EN)
|
||||
assert(deproj_heads == proj_heads)
|
||||
assert(undeco_labels == ['det','nsubj','root','det','dobj','aux','nsubj','acl','punct'])
|
||||
|
||||
# if there are two potential new heads, the first one is chosen even if it's wrong
|
||||
# deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize([1,1,3,1,5,6,9,8,6,1,9,12,13,10,1], \
|
||||
# ['MO||OC','ROOT','NK','OC','MO','NK','OA','NK','AG','OC','MNR||OA','MO','NK','NK','--'])
|
||||
# assert(deproj_heads == [3,1,3,1,5,6,9,8,6,1,6,12,13,10,1])
|
||||
# assert(undeco_labels == ['MO','ROOT','NK','OC','MO','NK','OA','NK','AG','OC','MNR','MO','NK','NK','--'])
|
||||
proj_heads = [1,1,3,1,5,6,9,8,6,1,9,12,13,10,1]
|
||||
deco_labels = ['advmod||aux','root','det','aux','advmod','det','dobj','det','nmod','aux','nmod||dobj','advmod','det','amod','punct']
|
||||
deproj_heads, undeco_labels = deprojectivize(proj_heads,deco_labels,EN)
|
||||
assert(deproj_heads == [3,1,3,1,5,6,9,8,6,1,6,12,13,10,1])
|
||||
assert(undeco_labels == ['advmod','root','det','aux','advmod','det','dobj','det','nmod','aux','nmod','advmod','det','amod','punct'])
|
||||
|
||||
|
||||
|
||||
|
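The last two blocks pin down deprojectivization's fallback behavior: if the decoration names a relation no candidate head carries ('acl||iobj' above), the structure is left alone and only the decoration is stripped; if two candidate heads match, the first is taken even when it is the wrong one. The decoration syntax itself, restated with the labels used in these tests (same scheme as the 'X||Y' assertion above):

    from spacy.syntax.nonproj import PseudoProjectivity

    assert PseudoProjectivity.decompose('acl||dobj') == ('acl', 'dobj')
    assert PseudoProjectivity.decompose('acl') == ('acl', '')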