mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Predict tags with encoder
This commit is contained in:
parent
56073a11ef
commit
94e86ae00a
|
@ -5,6 +5,7 @@ from thinc.api import chain, layerize, with_getitem
|
|||
from thinc.neural import Model, Softmax
|
||||
import numpy
|
||||
|
||||
from .tokens.doc cimport Doc
|
||||
from .syntax.parser cimport Parser
|
||||
#from .syntax.beam_parser cimport BeamParser
|
||||
from .syntax.ner cimport BiluoPushDown
|
||||
|
@ -30,24 +31,42 @@ class TokenVectorEncoder(object):
|
|||
|
||||
def __call__(self, doc):
|
||||
doc.tensor = self.model([doc])[0]
|
||||
self.predict_tags([doc])
|
||||
|
||||
def begin_update(self, docs, drop=0.):
|
||||
tensors, bp_tensors = self.model.begin_update(docs, drop=drop)
|
||||
for i, doc in enumerate(docs):
|
||||
doc.tensor = tensors[i]
|
||||
self.predict_tags(docs)
|
||||
return tensors, bp_tensors
|
||||
|
||||
def predict_tags(self, docs, drop=0.):
|
||||
cdef Doc doc
|
||||
scores, _ = self.tagger.begin_update(docs, drop=drop)
|
||||
idx = 0
|
||||
for i, doc in enumerate(docs):
|
||||
tag_ids = scores[idx:idx+len(doc)].argmax(axis=1)
|
||||
for j, tag_id in enumerate(tag_ids):
|
||||
doc.vocab.morphology.assign_tag_id(&doc.c[j], tag_id)
|
||||
idx += 1
|
||||
|
||||
def update(self, docs, golds, drop=0., sgd=None):
|
||||
scores, finish_update = self.tagger.begin_update(docs, drop=drop)
|
||||
losses = scores.copy()
|
||||
idx = 0
|
||||
for i, gold in enumerate(golds):
|
||||
if hasattr(self.tagger.ops.xp, 'scatter_add'):
|
||||
ids = numpy.zeros((len(gold),), dtype='i')
|
||||
start = idx
|
||||
for j, tag in enumerate(gold.tags):
|
||||
ids[j] = docs[0].vocab.morphology.tag_names.index(tag)
|
||||
idx += 1
|
||||
self.tagger.ops.xp.scatter_add(losses[start:idx], ids, -1.0)
|
||||
else:
|
||||
for j, tag in enumerate(gold.tags):
|
||||
tag_id = docs[0].vocab.morphology.tag_names.index(tag)
|
||||
losses[idx, tag_id] -= 1.
|
||||
idx += 1
|
||||
finish_update(losses, sgd)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user