Add experimental NeuralLabeller

Matthew Honnibal 2017-05-21 17:52:30 -05:00
parent 9b1b0742fd
commit 8d1e64be69
2 changed files with 45 additions and 1 deletion

View File

@@ -16,6 +16,7 @@ from .syntax.parser import get_templates
 from .syntax.nonproj import PseudoProjectivity
 from .pipeline import NeuralDependencyParser, EntityRecognizer
 from .pipeline import TokenVectorEncoder, NeuralTagger, NeuralEntityRecognizer
+from .pipeline import NeuralLabeller
 from .compat import json_dumps
 from .attrs import IS_STOP
 from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES, TOKENIZER_INFIXES
@@ -230,7 +231,7 @@ class Language(object):
         for doc, gold in docs_golds:
             yield doc, gold
 
-    def begin_training(self, gold_tuples, **cfg):
+    def begin_training(self, get_gold_tuples, **cfg):
         """Allocate models, pre-process training data and acquire a trainer and
         optimizer. Used as a contextmanager.
@@ -244,6 +245,7 @@ class Language(object):
         >>> for docs, golds in epoch:
         >>>     state = nlp.update(docs, golds, sgd=optimizer)
         """
+        self.pipeline.append(NeuralLabeller(self.vocab))
         # Populate vocab
         for _, annots_brackets in get_gold_tuples():
            for annots, _ in annots_brackets:
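
Note the signature change in the second hunk: `begin_training` now takes `get_gold_tuples` and, as the third hunk shows, the body calls it as `get_gold_tuples()`. Callers must therefore pass a zero-argument callable that returns the training tuples, not the materialized tuples themselves. A minimal sketch of the new calling convention, with toy data (the annotation values and the bare `Language()` setup are illustrative assumptions, not part of the commit):

    from spacy.language import Language

    def get_gold_tuples():
        # One (raw_text, [(annotations, brackets), ...]) pair; each
        # annotations tuple is (ids, words, tags, heads, deps, ents).
        ids = [0, 2]
        words = ['A', 'sentence']
        tags = ['DT', 'NN']
        heads = [1, 1]
        deps = ['det', 'ROOT']
        ents = ['O', 'O']
        return [('A sentence', [((ids, words, tags, heads, deps, ents), [])])]

    nlp = Language()
    # Pass the callable, not its result: begin_training() can then iterate
    # the data more than once, which an exhausted generator would not allow.
    nlp.begin_training(get_gold_tuples)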

View File

@@ -31,6 +31,7 @@ from .syntax.stateclass cimport StateClass
 from .gold cimport GoldParse
 from .morphology cimport Morphology
 from .vocab cimport Vocab
+from .syntax.nonproj import PseudoProjectivity
 
 from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP, POS
 from ._ml import rebatch, Tok2Vec, flatten, get_col, doc2feats
@@ -148,6 +149,7 @@ class TokenVectorEncoder(object):
         if self.model is True:
             self.model = self.Model()
+
 
     def use_params(self, params):
         """Replace weights of models in the pipeline with those provided in the
         params dictionary.
@@ -252,6 +254,46 @@ class NeuralTagger(object):
         with self.model.use_params(params):
             yield
 
+
+class NeuralLabeller(NeuralTagger):
+    name = 'nn_labeller'
+    def __init__(self, vocab, model=True):
+        self.vocab = vocab
+        self.model = model
+        self.labels = {}
+
+    def set_annotations(self, docs, dep_ids):
+        pass
+
+    def begin_training(self, gold_tuples, pipeline=None):
+        gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
+        for raw_text, annots_brackets in gold_tuples:
+            for annots, brackets in annots_brackets:
+                ids, words, tags, heads, deps, ents = annots
+                for dep in deps:
+                    if dep not in self.labels:
+                        self.labels[dep] = len(self.labels)
+        token_vector_width = pipeline[0].model.nO
+        self.model = with_flatten(
+            Softmax(len(self.labels), token_vector_width))
+
+    def get_loss(self, docs, golds, scores):
+        scores = self.model.ops.flatten(scores)
+        cdef int idx = 0
+        correct = numpy.zeros((scores.shape[0],), dtype='i')
+        guesses = scores.argmax(axis=1)
+        for gold in golds:
+            for tag in gold.labels:
+                if tag is None:
+                    correct[idx] = guesses[idx]
+                else:
+                    correct[idx] = self.labels[tag]
+                idx += 1
+        correct = self.model.ops.xp.array(correct, dtype='i')
+        d_scores = scores - to_categorical(correct, nb_classes=scores.shape[1])
+        loss = (d_scores**2).sum()
+        d_scores = self.model.ops.unflatten(d_scores, [len(d) for d in docs])
+        return float(loss), d_scores
 
 
 cdef class EntityRecognizer(LinearParser):
     """Annotate named entities on Doc objects."""