spaCy/spacy/pipeline.pyx

# coding: utf8
from __future__ import unicode_literals
from thinc.api import chain, layerize, with_getitem
from thinc.neural import Model, Softmax
import numpy
from .tokens.doc cimport Doc
from .syntax.parser cimport Parser
#from .syntax.beam_parser cimport BeamParser
from .syntax.ner cimport BiluoPushDown
from .syntax.arc_eager cimport ArcEager
from .tagger import Tagger
from ._ml import build_tok2vec, flatten
# TODO: The disorganization here is pretty embarrassing. At least it's only
# internals.
from .syntax.parser import get_templates as get_feature_templates
from .attrs import DEP, ENT_TYPE


class TokenVectorEncoder(object):
    '''Assign position-sensitive vectors to tokens, using a CNN or RNN.'''
    def __init__(self, vocab, **cfg):
        self.vocab = vocab
        self.model = build_tok2vec(vocab.lang, **cfg)
        # Chain the tok2vec model into a softmax layer, so tag prediction can
        # reuse the shared token vectors.
        self.tagger = chain(
                        self.model,
                        flatten,
                        Softmax(self.vocab.morphology.n_tags, 64))

    def __call__(self, doc):
        # Set the doc's tensor from the tok2vec model, then predict tags.
        doc.tensor = self.model([doc])[0]
        self.predict_tags([doc])

    def begin_update(self, docs, drop=0.):
        tensors, bp_tensors = self.model.begin_update(docs, drop=drop)
        for i, doc in enumerate(docs):
            doc.tensor = tensors[i]
        self.predict_tags(docs)
        return tensors, bp_tensors

    def predict_tags(self, docs, drop=0.):
        cdef Doc doc
        scores, _ = self.tagger.begin_update(docs, drop=drop)
        # The tagger's output is flattened across the batch, so keep a running
        # token offset into the scores array.
        idx = 0
        for i, doc in enumerate(docs):
            tag_ids = scores[idx:idx+len(doc)].argmax(axis=1)
            for j, tag_id in enumerate(tag_ids):
                doc.vocab.morphology.assign_tag_id(&doc.c[j], tag_id)
                idx += 1

    def update(self, docs, golds, drop=0., sgd=None):
        scores, finish_update = self.tagger.begin_update(docs, drop=drop)
        # Gradient of softmax + cross-entropy: the predicted probabilities,
        # with 1 subtracted at each token's gold tag index.
        losses = scores.copy()
        idx = 0
        for i, gold in enumerate(golds):
            if hasattr(self.tagger.ops.xp, 'scatter_add'):
                ids = numpy.zeros((len(gold),), dtype='i')
                start = idx
                for j, tag in enumerate(gold.tags):
                    ids[j] = docs[0].vocab.morphology.tag_names.index(tag)
                    idx += 1
                self.tagger.ops.xp.scatter_add(losses[start:idx], ids, -1.0)
            else:
                for j, tag in enumerate(gold.tags):
                    tag_id = docs[0].vocab.morphology.tag_names.index(tag)
                    losses[idx, tag_id] -= 1.
                    idx += 1
        finish_update(losses, sgd)
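

# A minimal usage sketch for TokenVectorEncoder, mirroring the methods above.
# Illustrative only: it assumes a Vocab whose language has a tok2vec config
# understood by build_tok2vec, a GoldParse-like `gold` with a .tags list, and
# an `optimizer` compatible with thinc's sgd callback.
#
#     encoder = TokenVectorEncoder(vocab)
#     doc = Doc(vocab, words=[u'This', u'is', u'a', u'sentence', u'.'])
#     encoder(doc)                          # sets doc.tensor, predicts tags
#     tensors, bp = encoder.begin_update([doc], drop=0.1)
#     encoder.update([doc], [gold], drop=0.1, sgd=optimizer)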


cdef class EntityRecognizer(Parser):
    """
    Annotate named entities on Doc objects.
    """
    TransitionSystem = BiluoPushDown

    feature_templates = get_feature_templates('ner')

    def add_label(self, label):
        Parser.add_label(self, label)
        if isinstance(label, basestring):
            label = self.vocab.strings[label]
        # Set label into serializer. Super hacky :(
        for attr, freqs in self.vocab.serializer_freqs:
            if attr == ENT_TYPE and label not in freqs:
                freqs.append([label, 1])
        self.vocab._serializer = None
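

# A hedged usage sketch: in practice the entity recognizer runs as part of a
# loaded pipeline rather than being constructed directly here. Assumes a
# trained 'en' model is installed; the entities it annotates are exposed
# through doc.ents.
#
#     import spacy
#     nlp = spacy.load('en')
#     doc = nlp(u'Apple is looking at buying a U.K. startup.')
#     for ent in doc.ents:
#         print(ent.text, ent.label_)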

#
#cdef class BeamEntityRecognizer(BeamParser):
#    """
#    Annotate named entities on Doc objects.
#    """
#    TransitionSystem = BiluoPushDown
#
#    feature_templates = get_feature_templates('ner')
#
#    def add_label(self, label):
#        Parser.add_label(self, label)
#        if isinstance(label, basestring):
#            label = self.vocab.strings[label]
#        # Set label into serializer. Super hacky :(
#        for attr, freqs in self.vocab.serializer_freqs:
#            if attr == ENT_TYPE and label not in freqs:
#                freqs.append([label, 1])
#        self.vocab._serializer = None
#


cdef class DependencyParser(Parser):
    TransitionSystem = ArcEager

    feature_templates = get_feature_templates('basic')

    def add_label(self, label):
        Parser.add_label(self, label)
        if isinstance(label, basestring):
            label = self.vocab.strings[label]
        for attr, freqs in self.vocab.serializer_freqs:
            if attr == DEP and label not in freqs:
                freqs.append([label, 1])
        # Super hacky :(
        self.vocab._serializer = None
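

# A hedged usage sketch, analogous to the entity recognizer above: the parser
# runs inside a loaded pipeline and writes its analysis to token.dep_ and
# token.head. Assumes a trained 'en' model is installed.
#
#     import spacy
#     nlp = spacy.load('en')
#     doc = nlp(u'She ate the pizza with a fork.')
#     for token in doc:
#         print(token.text, token.dep_, token.head.text)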

#
#cdef class BeamDependencyParser(BeamParser):
#    TransitionSystem = ArcEager
#
#    feature_templates = get_feature_templates('basic')
#
#    def add_label(self, label):
#        Parser.add_label(self, label)
#        if isinstance(label, basestring):
#            label = self.vocab.strings[label]
#        for attr, freqs in self.vocab.serializer_freqs:
#            if attr == DEP and label not in freqs:
#                freqs.append([label, 1])
#        # Super hacky :(
#        self.vocab._serializer = None
#

#__all__ = [Tagger, DependencyParser, EntityRecognizer, BeamDependencyParser, BeamEntityRecognizer]
__all__ = [Tagger, DependencyParser, EntityRecognizer]