Update train_ner_standalone example

This commit is contained in:
Matthew Honnibal 2017-09-15 10:36:46 +02:00
parent d84607f6bb
commit 027a5d8b75

View File

@ -13,24 +13,27 @@ Input data:
https://www.lt.informatik.tu-darmstadt.de/fileadmin/user_upload/Group_LangTech/data/GermEval2014_complete_data.zip
Developed for: spaCy 1.7.1
Last tested for: spaCy 1.7.1
Last tested for: spaCy 2.0.0a13
'''
from __future__ import unicode_literals, print_function
import plac
from pathlib import Path
import random
import json
from thinc.neural.optimizers import Adam
from thinc.neural.ops import NumpyOps
import tqdm
import spacy.orth as orth_funcs
from spacy.vocab import Vocab
from spacy.pipeline import BeamEntityRecognizer
from spacy.pipeline import EntityRecognizer
from spacy.pipeline import TokenVectorEncoder, NeuralEntityRecognizer
from spacy.tokenizer import Tokenizer
from spacy.tokens import Doc
from spacy.attrs import *
from spacy.gold import GoldParse
from spacy.gold import _iob_to_biluo as iob_to_biluo
from spacy.gold import iob_to_biluo
from spacy.gold import minibatch
from spacy.scorer import Scorer
import spacy.util
try:
unicode
@ -38,95 +41,40 @@ except NameError:
unicode = str
spacy.util.set_env_log(True)
def init_vocab():
return Vocab(
lex_attr_getters={
LOWER: lambda string: string.lower(),
SHAPE: orth_funcs.word_shape,
NORM: lambda string: string.lower(),
PREFIX: lambda string: string[0],
SUFFIX: lambda string: string[-3:],
CLUSTER: lambda string: 0,
IS_ALPHA: orth_funcs.is_alpha,
IS_ASCII: orth_funcs.is_ascii,
IS_DIGIT: lambda string: string.isdigit(),
IS_LOWER: orth_funcs.is_lower,
IS_PUNCT: orth_funcs.is_punct,
IS_SPACE: lambda string: string.isspace(),
IS_TITLE: orth_funcs.is_title,
IS_UPPER: orth_funcs.is_upper,
IS_STOP: lambda string: False,
IS_OOV: lambda string: True
})
def save_vocab(vocab, path):
path = Path(path)
if not path.exists():
path.mkdir()
elif not path.is_dir():
raise IOError("Can't save vocab to %s\nNot a directory" % path)
with (path / 'strings.json').open('w') as file_:
vocab.strings.dump(file_)
vocab.dump((path / 'lexemes.bin').as_posix())
def load_vocab(path):
path = Path(path)
if not path.exists():
raise IOError("Cannot load vocab from %s\nDoes not exist" % path)
if not path.is_dir():
raise IOError("Cannot load vocab from %s\nNot a directory" % path)
return Vocab.load(path)
def init_ner_model(vocab, features=None):
if features is None:
features = tuple(EntityRecognizer.feature_templates)
return EntityRecognizer(vocab, features=features)
def save_ner_model(model, path):
path = Path(path)
if not path.exists():
path.mkdir()
if not path.is_dir():
raise IOError("Can't save model to %s\nNot a directory" % path)
model.model.dump((path / 'model').as_posix())
with (path / 'config.json').open('w') as file_:
data = json.dumps(model.cfg)
if not isinstance(data, unicode):
data = data.decode('utf8')
file_.write(data)
def load_ner_model(vocab, path):
return EntityRecognizer.load(path, vocab)
class Pipeline(object):
@classmethod
def load(cls, path):
path = Path(path)
if not path.exists():
raise IOError("Cannot load pipeline from %s\nDoes not exist" % path)
if not path.is_dir():
raise IOError("Cannot load pipeline from %s\nNot a directory" % path)
vocab = load_vocab(path)
tokenizer = Tokenizer(vocab, {}, None, None, None)
ner_model = load_ner_model(vocab, path / 'ner')
return cls(vocab, tokenizer, ner_model)
def __init__(self, vocab=None, tokenizer=None, entity=None):
def __init__(self, vocab=None, tokenizer=None, tensorizer=None, entity=None):
if vocab is None:
vocab = init_vocab()
if tokenizer is None:
tokenizer = Tokenizer(vocab, {}, None, None, None)
if tensorizer is None:
tensorizer = TokenVectorEncoder(vocab)
if entity is None:
entity = init_ner_model(self.vocab)
entity = NeuralEntityRecognizer(vocab)
self.vocab = vocab
self.tokenizer = tokenizer
self.tensorizer = tensorizer
self.entity = entity
self.pipeline = [self.entity]
self.pipeline = [tensorizer, self.entity]
def begin_training(self):
for model in self.pipeline:
model.begin_training([])
optimizer = Adam(NumpyOps(), 0.001)
return optimizer
def __call__(self, input_):
doc = self.make_doc(input_)
@ -147,14 +95,18 @@ class Pipeline(object):
gold = GoldParse(doc, entities=annotations)
return gold
def update(self, input_, annot):
doc = self.make_doc(input_)
gold = self.make_gold(input_, annot)
for ner in gold.ner:
if ner not in (None, '-', 'O'):
action, label = ner.split('-', 1)
self.entity.add_label(label)
return self.entity.update(doc, gold)
def update(self, inputs, annots, sgd, losses=None, drop=0.):
if losses is None:
losses = {}
docs = [self.make_doc(input_) for input_ in inputs]
golds = [self.make_gold(input_, annot) for input_, annot in
zip(inputs, annots)]
tensors, bp_tensors = self.tensorizer.update(docs, golds, drop=drop)
d_tensors = self.entity.update((docs, tensors), golds, drop=drop,
sgd=sgd, losses=losses)
bp_tensors(d_tensors, sgd=sgd)
return losses
def evaluate(self, examples):
scorer = Scorer()
@ -164,34 +116,38 @@ class Pipeline(object):
scorer.score(doc, gold)
return scorer.scores
def average_weights(self):
self.entity.model.end_training()
def save(self, path):
def to_disk(self, path):
path = Path(path)
if not path.exists():
path.mkdir()
elif not path.is_dir():
raise IOError("Can't save pipeline to %s\nNot a directory" % path)
save_vocab(self.vocab, path / 'vocab')
save_ner_model(self.entity, path / 'ner')
self.vocab.to_disk(path / 'vocab')
self.tensorizer.to_disk(path / 'tensorizer')
self.entity.to_disk(path / 'ner')
def from_disk(self, path):
path = Path(path)
if not path.exists():
raise IOError("Cannot load pipeline from %s\nDoes not exist" % path)
if not path.is_dir():
raise IOError("Cannot load pipeline from %s\nNot a directory" % path)
self.vocab = self.vocab.from_disk(path / 'vocab')
self.tensorizer = self.tensorizer.from_disk(path / 'tensorizer')
self.entity = self.entity.from_disk(path / 'ner')
def train(nlp, train_examples, dev_examples, ctx, nr_epoch=5):
next_epoch = train_examples
def train(nlp, train_examples, dev_examples, nr_epoch=5):
sgd = nlp.begin_training()
print("Iter", "Loss", "P", "R", "F")
for i in range(nr_epoch):
this_epoch = next_epoch
next_epoch = []
loss = 0
for input_, annot in this_epoch:
loss += nlp.update(input_, annot)
if (i+1) < nr_epoch:
next_epoch.append((input_, annot))
random.shuffle(next_epoch)
random.shuffle(train_examples)
losses = {}
for batch in minibatch(tqdm.tqdm(train_examples, leave=False), size=8):
inputs, annots = zip(*batch)
nlp.update(list(inputs), list(annots), sgd, losses=losses)
scores = nlp.evaluate(dev_examples)
report_scores(i, loss, scores)
nlp.average_weights()
report_scores(i, losses['ner'], scores)
scores = nlp.evaluate(dev_examples)
report_scores(channels, i+1, loss, scores)
@ -208,7 +164,8 @@ def read_examples(path):
with path.open() as file_:
sents = file_.read().strip().split('\n\n')
for sent in sents:
if not sent.strip():
sent = sent.strip()
if not sent:
continue
tokens = sent.split('\n')
while tokens and tokens[0].startswith('#'):
@ -217,28 +174,39 @@ def read_examples(path):
iob = []
for token in tokens:
if token.strip():
pieces = token.split()
pieces = token.split('\t')
words.append(pieces[1])
iob.append(pieces[2])
yield words, iob_to_biluo(iob)
def get_labels(examples):
labels = set()
for words, tags in examples:
for tag in tags:
if '-' in tag:
labels.add(tag.split('-')[1])
return sorted(labels)
@plac.annotations(
model_dir=("Path to save the model", "positional", None, Path),
train_loc=("Path to your training data", "positional", None, Path),
dev_loc=("Path to your development data", "positional", None, Path),
)
def main(model_dir=Path('/home/matt/repos/spaCy/spacy/data/de-1.0.0'),
train_loc=None, dev_loc=None, nr_epoch=30):
train_examples = read_examples(train_loc)
def main(model_dir, train_loc, dev_loc, nr_epoch=30):
print(model_dir, train_loc, dev_loc)
train_examples = list(read_examples(train_loc))
dev_examples = read_examples(dev_loc)
nlp = Pipeline.load(model_dir)
nlp = Pipeline()
for label in get_labels(train_examples):
nlp.entity.add_label(label)
print("Add label", label)
train(nlp, train_examples, list(dev_examples), ctx, nr_epoch)
train(nlp, train_examples, list(dev_examples), nr_epoch)
nlp.save(model_dir)
nlp.to_disk(model_dir)
if __name__ == '__main__':
main()
plac.call(main)