mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-10 00:20:35 +03:00
Adjust train.py to train both English and German models
This commit is contained in:
parent
3448cb40a4
commit
690c5acabf
|
@ -14,6 +14,7 @@ import re
|
||||||
|
|
||||||
import spacy.util
|
import spacy.util
|
||||||
from spacy.en import English
|
from spacy.en import English
|
||||||
|
from spacy.de import German
|
||||||
|
|
||||||
from spacy.syntax.util import Config
|
from spacy.syntax.util import Config
|
||||||
from spacy.gold import read_json_file
|
from spacy.gold import read_json_file
|
||||||
|
@ -25,6 +26,7 @@ from spacy.syntax.arc_eager import ArcEager
|
||||||
from spacy.syntax.ner import BiluoPushDown
|
from spacy.syntax.ner import BiluoPushDown
|
||||||
from spacy.tagger import Tagger
|
from spacy.tagger import Tagger
|
||||||
from spacy.syntax.parser import Parser
|
from spacy.syntax.parser import Parser
|
||||||
|
from spacy.syntax.nonproj import PseudoProjectivity
|
||||||
|
|
||||||
|
|
||||||
def _corrupt(c, noise_level):
|
def _corrupt(c, noise_level):
|
||||||
|
@ -82,7 +84,7 @@ def _merge_sents(sents):
|
||||||
def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
||||||
seed=0, gold_preproc=False, n_sents=0, corruption_level=0,
|
seed=0, gold_preproc=False, n_sents=0, corruption_level=0,
|
||||||
beam_width=1, verbose=False,
|
beam_width=1, verbose=False,
|
||||||
use_orig_arc_eager=False):
|
use_orig_arc_eager=False, pseudoprojective=False):
|
||||||
dep_model_dir = path.join(model_dir, 'deps')
|
dep_model_dir = path.join(model_dir, 'deps')
|
||||||
ner_model_dir = path.join(model_dir, 'ner')
|
ner_model_dir = path.join(model_dir, 'ner')
|
||||||
pos_model_dir = path.join(model_dir, 'pos')
|
pos_model_dir = path.join(model_dir, 'pos')
|
||||||
|
@ -96,9 +98,13 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
||||||
os.mkdir(ner_model_dir)
|
os.mkdir(ner_model_dir)
|
||||||
os.mkdir(pos_model_dir)
|
os.mkdir(pos_model_dir)
|
||||||
|
|
||||||
|
if pseudoprojective:
|
||||||
|
# preprocess training data here before ArcEager.get_labels() is called
|
||||||
|
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
|
||||||
|
|
||||||
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
||||||
labels=ArcEager.get_labels(gold_tuples),
|
labels=ArcEager.get_labels(gold_tuples),
|
||||||
beam_width=beam_width)
|
beam_width=beam_width,projectivize=pseudoprojective)
|
||||||
Config.write(ner_model_dir, 'config', features='ner', seed=seed,
|
Config.write(ner_model_dir, 'config', features='ner', seed=seed,
|
||||||
labels=BiluoPushDown.get_labels(gold_tuples),
|
labels=BiluoPushDown.get_labels(gold_tuples),
|
||||||
beam_width=0)
|
beam_width=0)
|
||||||
|
@ -107,6 +113,8 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
||||||
gold_tuples = gold_tuples[:n_sents]
|
gold_tuples = gold_tuples[:n_sents]
|
||||||
|
|
||||||
nlp = Language(data_dir=model_dir, tagger=False, parser=False, entity=False)
|
nlp = Language(data_dir=model_dir, tagger=False, parser=False, entity=False)
|
||||||
|
if nlp.lang == 'de':
|
||||||
|
nlp.vocab.morphology.lemmatizer = lambda string,pos: set([string])
|
||||||
nlp.tagger = Tagger.blank(nlp.vocab, Tagger.default_templates())
|
nlp.tagger = Tagger.blank(nlp.vocab, Tagger.default_templates())
|
||||||
nlp.parser = Parser.from_dir(dep_model_dir, nlp.vocab.strings, ArcEager)
|
nlp.parser = Parser.from_dir(dep_model_dir, nlp.vocab.strings, ArcEager)
|
||||||
nlp.entity = Parser.from_dir(ner_model_dir, nlp.vocab.strings, BiluoPushDown)
|
nlp.entity = Parser.from_dir(ner_model_dir, nlp.vocab.strings, BiluoPushDown)
|
||||||
|
@ -131,12 +139,9 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
||||||
raw_text = add_noise(raw_text, corruption_level)
|
raw_text = add_noise(raw_text, corruption_level)
|
||||||
tokens = nlp.tokenizer(raw_text)
|
tokens = nlp.tokenizer(raw_text)
|
||||||
nlp.tagger(tokens)
|
nlp.tagger(tokens)
|
||||||
gold = GoldParse(tokens, annot_tuples, make_projective=True)
|
gold = GoldParse(tokens, annot_tuples)
|
||||||
if not gold.is_projective:
|
if not gold.is_projective:
|
||||||
raise Exception(
|
raise Exception("Non-projective sentence in training: %s" % annot_tuples)
|
||||||
"Non-projective sentence in training, after we should "
|
|
||||||
"have enforced projectivity: %s" % annot_tuples
|
|
||||||
)
|
|
||||||
loss += nlp.parser.train(tokens, gold)
|
loss += nlp.parser.train(tokens, gold)
|
||||||
nlp.entity.train(tokens, gold)
|
nlp.entity.train(tokens, gold)
|
||||||
nlp.tagger.train(tokens, gold.tags)
|
nlp.tagger.train(tokens, gold.tags)
|
||||||
|
@ -152,6 +157,8 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
||||||
def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=False,
|
def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=False,
|
||||||
beam_width=None, cand_preproc=None):
|
beam_width=None, cand_preproc=None):
|
||||||
nlp = Language(data_dir=model_dir)
|
nlp = Language(data_dir=model_dir)
|
||||||
|
if nlp.lang == 'de':
|
||||||
|
nlp.vocab.morphology.lemmatizer = lambda string,pos: set([string])
|
||||||
if beam_width is not None:
|
if beam_width is not None:
|
||||||
nlp.parser.cfg.beam_width = beam_width
|
nlp.parser.cfg.beam_width = beam_width
|
||||||
scorer = Scorer()
|
scorer = Scorer()
|
||||||
|
@ -200,6 +207,7 @@ def write_parses(Language, dev_loc, model_dir, out_loc):
|
||||||
|
|
||||||
|
|
||||||
@plac.annotations(
|
@plac.annotations(
|
||||||
|
language=("The language to train", "positional", None, str, ['en','de']),
|
||||||
train_loc=("Location of training file or directory"),
|
train_loc=("Location of training file or directory"),
|
||||||
dev_loc=("Location of development file or directory"),
|
dev_loc=("Location of development file or directory"),
|
||||||
model_dir=("Location of output model directory",),
|
model_dir=("Location of output model directory",),
|
||||||
|
@ -211,19 +219,22 @@ def write_parses(Language, dev_loc, model_dir, out_loc):
|
||||||
n_iter=("Number of training iterations", "option", "i", int),
|
n_iter=("Number of training iterations", "option", "i", int),
|
||||||
verbose=("Verbose error reporting", "flag", "v", bool),
|
verbose=("Verbose error reporting", "flag", "v", bool),
|
||||||
debug=("Debug mode", "flag", "d", bool),
|
debug=("Debug mode", "flag", "d", bool),
|
||||||
|
pseudoprojective=("Use pseudo-projective parsing", "flag", "p", bool),
|
||||||
)
|
)
|
||||||
def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
|
def main(language, train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
|
||||||
debug=False, corruption_level=0.0, gold_preproc=False, eval_only=False):
|
debug=False, corruption_level=0.0, gold_preproc=False, eval_only=False, pseudoprojective=False):
|
||||||
|
lang = {'en':English, 'de':German}.get(language)
|
||||||
|
|
||||||
if not eval_only:
|
if not eval_only:
|
||||||
gold_train = list(read_json_file(train_loc))
|
gold_train = list(read_json_file(train_loc))
|
||||||
train(English, gold_train, model_dir,
|
train(lang, gold_train, model_dir,
|
||||||
feat_set='basic' if not debug else 'debug',
|
feat_set='basic' if not debug else 'debug',
|
||||||
gold_preproc=gold_preproc, n_sents=n_sents,
|
gold_preproc=gold_preproc, n_sents=n_sents,
|
||||||
corruption_level=corruption_level, n_iter=n_iter,
|
corruption_level=corruption_level, n_iter=n_iter,
|
||||||
verbose=verbose)
|
verbose=verbose,pseudoprojective=pseudoprojective)
|
||||||
if out_loc:
|
if out_loc:
|
||||||
write_parses(English, dev_loc, model_dir, out_loc)
|
write_parses(lang, dev_loc, model_dir, out_loc)
|
||||||
scorer = evaluate(English, list(read_json_file(dev_loc)),
|
scorer = evaluate(lang, list(read_json_file(dev_loc)),
|
||||||
model_dir, gold_preproc=gold_preproc, verbose=verbose)
|
model_dir, gold_preproc=gold_preproc, verbose=verbose)
|
||||||
print('TOK', scorer.token_acc)
|
print('TOK', scorer.token_acc)
|
||||||
print('POS', scorer.tags_acc)
|
print('POS', scorer.tags_acc)
|
||||||
|
|
|
@ -6,4 +6,4 @@ from ..language import Language
|
||||||
|
|
||||||
|
|
||||||
class German(Language):
|
class German(Language):
|
||||||
pass
|
lang = 'de'
|
||||||
|
|
|
@ -244,14 +244,8 @@ cdef class GoldParse:
|
||||||
raise Exception("Cycle found: %s" % cycle)
|
raise Exception("Cycle found: %s" % cycle)
|
||||||
|
|
||||||
if make_projective:
|
if make_projective:
|
||||||
# projectivity here means non-proj arcs are being disconnected
|
proj_heads,_ = nonproj.PseudoProjectivity.projectivize(self.heads,self.labels)
|
||||||
np_arcs = []
|
self.heads = proj_heads
|
||||||
for word in range(self.length):
|
|
||||||
if nonproj.is_nonproj_arc(word,self.heads):
|
|
||||||
np_arcs.append(word)
|
|
||||||
for np_arc in np_arcs:
|
|
||||||
self.heads[np_arc] = None
|
|
||||||
self.labels[np_arc] = ''
|
|
||||||
|
|
||||||
self.brackets = {}
|
self.brackets = {}
|
||||||
for (gold_start, gold_end, label_str) in brackets:
|
for (gold_start, gold_end, label_str) in brackets:
|
||||||
|
|
|
@ -94,9 +94,10 @@ cdef class Parser:
|
||||||
moves = transition_system(strings, cfg.labels)
|
moves = transition_system(strings, cfg.labels)
|
||||||
templates = get_templates(cfg.features)
|
templates = get_templates(cfg.features)
|
||||||
model = ParserModel(templates)
|
model = ParserModel(templates)
|
||||||
|
project = cfg.projectivize if hasattr(cfg,'projectivize') else False
|
||||||
if path.exists(path.join(model_dir, 'model')):
|
if path.exists(path.join(model_dir, 'model')):
|
||||||
model.load(path.join(model_dir, 'model'))
|
model.load(path.join(model_dir, 'model'))
|
||||||
return cls(strings, moves, model, cfg.projectivize)
|
return cls(strings, moves, model, project)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, pkg_or_str_or_file, vocab):
|
def load(cls, pkg_or_str_or_file, vocab):
|
||||||
|
|
|
@ -143,7 +143,7 @@ cdef class Tagger:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def blank(cls, vocab, templates):
|
def blank(cls, vocab, templates):
|
||||||
model = TaggerModel(N_CONTEXT_FIELDS, templates)
|
model = TaggerModel(templates)
|
||||||
return cls(vocab, model)
|
return cls(vocab, model)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -1,9 +1,6 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from spacy.tokens.doc import Doc
|
|
||||||
from spacy.vocab import Vocab
|
|
||||||
from spacy.tokenizer import Tokenizer
|
|
||||||
from spacy.attrs import DEP, HEAD
|
from spacy.attrs import DEP, HEAD
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
|
@ -56,12 +53,28 @@ def test_is_nonproj_tree():
|
||||||
assert(is_nonproj_tree(partial_tree) == False)
|
assert(is_nonproj_tree(partial_tree) == False)
|
||||||
assert(is_nonproj_tree(multirooted_tree) == True)
|
assert(is_nonproj_tree(multirooted_tree) == True)
|
||||||
|
|
||||||
def test_pseudoprojectivity():
|
|
||||||
|
def deprojectivize(proj_heads, deco_labels, EN):
    """Round-trip a projectivized parse through PseudoProjectivity.deprojectivize.

    Builds a dummy document with one token per head, writes the given heads
    and (decorated) labels into it, runs ``PseudoProjectivity.deprojectivize``
    on the document, and reads the result back.

    proj_heads: list of absolute head indices, one per token.
    deco_labels: list of dependency label strings, one per token.
    EN: language pipeline object; only ``EN.tokenizer.tokens_from_list`` and
        ``EN.vocab.strings`` are used here.

    Returns a ``(deproj_heads, undeco_labels)`` tuple: absolute head indices
    and label strings read back from the document after deprojectivization.
    """
    slen = len(proj_heads)
    # Token text is irrelevant here; only the number of tokens matters.
    sent = EN.tokenizer.tokens_from_list(['whatever'] * slen)

    # The HEAD column is written as an offset relative to the token index.
    rel_proj_heads = [head - i for i, head in enumerate(proj_heads)]
    labelids = [EN.vocab.strings[label] for label in deco_labels]

    # zip() is a lazy iterator on Python 3; materialize it so numpy builds an
    # (n, 2) int32 array instead of failing on (or wrapping) the iterator.
    pairs = list(zip(rel_proj_heads, labelids))
    parse = numpy.asarray(pairs, dtype=numpy.int32)
    sent.from_array([HEAD, DEP], parse)

    PseudoProjectivity.deprojectivize(sent)

    parse = sent.to_array([HEAD, DEP])
    # Convert the relative head offsets back to absolute token indices.
    deproj_heads = [i + head for i, head in enumerate(parse[:, 0])]
    undeco_labels = [EN.vocab.strings[int(labelid)] for labelid in parse[:, 1]]
    return deproj_heads, undeco_labels
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.models
|
||||||
|
def test_pseudoprojectivity(EN):
|
||||||
tree = [1,2,2]
|
tree = [1,2,2]
|
||||||
nonproj_tree = [1,2,2,4,5,2,7,4,2]
|
nonproj_tree = [1,2,2,4,5,2,7,4,2]
|
||||||
labels = ['NK','SB','ROOT','NK','OA','OC','SB','RC','--']
|
labels = ['det','nsubj','root','det','dobj','aux','nsubj','acl','punct']
|
||||||
nonproj_tree2 = [9,1,3,1,5,6,9,8,6,1,6,12,13,10,1]
|
nonproj_tree2 = [9,1,3,1,5,6,9,8,6,1,6,12,13,10,1]
|
||||||
labels2 = ['MO','ROOT','NK','SB','MO','NK','OA','NK','AG','OC','MNR','MO','NK','NK','--']
|
labels2 = ['advmod','root','det','nsubj','advmod','det','dobj','det','nmod','aux','nmod','advmod','det','amod','punct']
|
||||||
|
|
||||||
assert(PseudoProjectivity.decompose('X||Y') == ('X','Y'))
|
assert(PseudoProjectivity.decompose('X||Y') == ('X','Y'))
|
||||||
assert(PseudoProjectivity.decompose('X') == ('X',''))
|
assert(PseudoProjectivity.decompose('X') == ('X',''))
|
||||||
|
@ -80,29 +93,32 @@ def test_pseudoprojectivity():
|
||||||
|
|
||||||
proj_heads, deco_labels = PseudoProjectivity.projectivize(nonproj_tree,labels)
|
proj_heads, deco_labels = PseudoProjectivity.projectivize(nonproj_tree,labels)
|
||||||
assert(proj_heads == [1,2,2,4,5,2,7,5,2])
|
assert(proj_heads == [1,2,2,4,5,2,7,5,2])
|
||||||
assert(deco_labels == ['NK','SB','ROOT','NK','OA','OC','SB','RC||OA','--'])
|
assert(deco_labels == ['det','nsubj','root','det','dobj','aux','nsubj','acl||dobj','punct'])
|
||||||
# deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize(proj_heads,deco_labels)
|
deproj_heads, undeco_labels = deprojectivize(proj_heads,deco_labels,EN)
|
||||||
# assert(deproj_heads == nonproj_tree)
|
assert(deproj_heads == nonproj_tree)
|
||||||
# assert(undeco_labels == labels)
|
assert(undeco_labels == labels)
|
||||||
|
|
||||||
proj_heads, deco_labels = PseudoProjectivity.projectivize(nonproj_tree2,labels2)
|
proj_heads, deco_labels = PseudoProjectivity.projectivize(nonproj_tree2,labels2)
|
||||||
assert(proj_heads == [1,1,3,1,5,6,9,8,6,1,9,12,13,10,1])
|
assert(proj_heads == [1,1,3,1,5,6,9,8,6,1,9,12,13,10,1])
|
||||||
assert(deco_labels == ['MO||OC','ROOT','NK','SB','MO','NK','OA','NK','AG','OC','MNR||OA','MO','NK','NK','--'])
|
assert(deco_labels == ['advmod||aux','root','det','nsubj','advmod','det','dobj','det','nmod','aux','nmod||dobj','advmod','det','amod','punct'])
|
||||||
# deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize(proj_heads,deco_labels)
|
deproj_heads, undeco_labels = deprojectivize(proj_heads,deco_labels,EN)
|
||||||
# assert(deproj_heads == nonproj_tree2)
|
assert(deproj_heads == nonproj_tree2)
|
||||||
# assert(undeco_labels == labels2)
|
assert(undeco_labels == labels2)
|
||||||
|
|
||||||
# if decoration is wrong such that there is no head with the desired label
|
# if decoration is wrong such that there is no head with the desired label
|
||||||
# the structure is kept and the label is undecorated
|
# the structure is kept and the label is undecorated
|
||||||
# deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize([1,2,2,4,5,2,7,5,2],['NK','SB','ROOT','NK','OA','OC','SB','RC||DA','--'])
|
proj_heads = [1,2,2,4,5,2,7,5,2]
|
||||||
# assert(deproj_heads == [1,2,2,4,5,2,7,5,2])
|
deco_labels = ['det','nsubj','root','det','dobj','aux','nsubj','acl||iobj','punct']
|
||||||
# assert(undeco_labels == ['NK','SB','ROOT','NK','OA','OC','SB','RC','--'])
|
deproj_heads, undeco_labels = deprojectivize(proj_heads,deco_labels,EN)
|
||||||
|
assert(deproj_heads == proj_heads)
|
||||||
|
assert(undeco_labels == ['det','nsubj','root','det','dobj','aux','nsubj','acl','punct'])
|
||||||
|
|
||||||
# if there are two potential new heads, the first one is chosen even if it's wrong
|
# if there are two potential new heads, the first one is chosen even if it's wrong
|
||||||
# deproj_heads, undeco_labels = PseudoProjectivity.deprojectivize([1,1,3,1,5,6,9,8,6,1,9,12,13,10,1], \
|
proj_heads = [1,1,3,1,5,6,9,8,6,1,9,12,13,10,1]
|
||||||
# ['MO||OC','ROOT','NK','OC','MO','NK','OA','NK','AG','OC','MNR||OA','MO','NK','NK','--'])
|
deco_labels = ['advmod||aux','root','det','aux','advmod','det','dobj','det','nmod','aux','nmod||dobj','advmod','det','amod','punct']
|
||||||
# assert(deproj_heads == [3,1,3,1,5,6,9,8,6,1,6,12,13,10,1])
|
deproj_heads, undeco_labels = deprojectivize(proj_heads,deco_labels,EN)
|
||||||
# assert(undeco_labels == ['MO','ROOT','NK','OC','MO','NK','OA','NK','AG','OC','MNR','MO','NK','NK','--'])
|
assert(deproj_heads == [3,1,3,1,5,6,9,8,6,1,6,12,13,10,1])
|
||||||
|
assert(undeco_labels == ['advmod','root','det','aux','advmod','det','dobj','det','nmod','aux','nmod','advmod','det','amod','punct'])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user