mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-24 16:24:16 +03:00
Add add_label methods to Tagger and TextCategorizer
This commit is contained in:
parent
5ab4e96144
commit
e7a9174877
|
@ -11,9 +11,9 @@ import ujson
|
|||
import msgpack
|
||||
|
||||
from thinc.api import chain
|
||||
from thinc.v2v import Softmax
|
||||
from thinc.v2v import Affine, Softmax
|
||||
from thinc.t2v import Pooling, max_pool, mean_pool
|
||||
from thinc.neural.util import to_categorical
|
||||
from thinc.neural.util import to_categorical, copy_array
|
||||
from thinc.neural._classes.difference import Siamese, CauchySimilarity
|
||||
|
||||
from .tokens.doc cimport Doc
|
||||
|
@ -130,6 +130,15 @@ class Pipe(object):
|
|||
documents and their predicted scores."""
|
||||
raise NotImplementedError
|
||||
|
||||
def add_label(self, label):
    """Register a new output label for the model to predict.

    Abstract hook: this base implementation always raises, so each
    concrete pipe has to supply its own version.

    A pre-trained model can be extended with additional labels this
    way, but do so with care, because of the "catastrophic forgetting"
    problem.

    label: The label to add.
    """
    raise NotImplementedError
|
||||
|
||||
def begin_training(self, gold_tuples=tuple(), pipeline=None):
|
||||
"""Initialize the pipe for training, using data examples if available.
|
||||
If no model has been initialized yet, the model is added."""
|
||||
|
@ -325,6 +334,14 @@ class Tagger(Pipe):
|
|||
self.cfg.setdefault('pretrained_dims',
|
||||
self.vocab.vectors.data.shape[1])
|
||||
|
||||
@property
def labels(self):
    """Return the list of tag names kept under cfg['tag_names'].

    When the key is missing, an empty list is stored first, so the
    returned object is always the list held in the config and in-place
    changes to it are kept.
    """
    config = self.cfg
    if 'tag_names' not in config:
        config['tag_names'] = []
    return config['tag_names']

@labels.setter
def labels(self, value):
    """Replace the tag names stored under cfg['tag_names'] with `value`."""
    config = self.cfg
    config['tag_names'] = value
|
||||
|
||||
def __call__(self, doc):
|
||||
tags = self.predict([doc])
|
||||
self.set_annotations([doc], tags)
|
||||
|
@ -352,6 +369,7 @@ class Tagger(Pipe):
|
|||
cdef Doc doc
|
||||
cdef int idx = 0
|
||||
cdef Vocab vocab = self.vocab
|
||||
tags = list(self.labels)
|
||||
for i, doc in enumerate(docs):
|
||||
doc_tag_ids = batch_tag_ids[i]
|
||||
if hasattr(doc_tag_ids, 'get'):
|
||||
|
@ -359,7 +377,7 @@ class Tagger(Pipe):
|
|||
for j, tag_id in enumerate(doc_tag_ids):
|
||||
# Don't clobber preset POS tags
|
||||
if doc.c[j].tag == 0 and doc.c[j].pos == 0:
|
||||
vocab.morphology.assign_tag_id(&doc.c[j], tag_id)
|
||||
vocab.morphology.assign_tag(&doc.c[j], tags[tag_id])
|
||||
idx += 1
|
||||
doc.is_tagged = True
|
||||
|
||||
|
@ -420,6 +438,17 @@ class Tagger(Pipe):
|
|||
def Model(cls, n_tags, **cfg):
|
||||
return build_tagger_model(n_tags, **cfg)
|
||||
|
||||
def add_label(self, label):
    """Add an output tag, growing the model's final layer by one unit.

    When a model is present, a new Softmax layer with one more output
    is created, the existing weights and biases are copied into its
    leading rows, and it replaces the old final layer. The label is
    then appended to `self.labels`.

    label: The tag name to add.
    RETURNS (int): 0 if the label was already known (nothing changes),
        otherwise 1.
    """
    if label in self.labels:
        return 0
    model = self.model
    # The output layer can only be resized once a model exists. Before
    # training the attribute may still hold a placeholder instead of a
    # network (presumably True/False/None -- TODO confirm against
    # __init__); indexing it would raise, so only record the label then.
    if model is not None and not isinstance(model, bool):
        smaller = model[-1]._layers[-1]
        larger = Softmax(len(self.labels)+1, smaller.nI)
        # Keep the learned parameters for the tags that already exist.
        copy_array(larger.W[:smaller.nO], smaller.W)
        copy_array(larger.b[:smaller.nO], smaller.b)
        model[-1]._layers[-1] = larger
    self.labels.append(label)
    return 1
|
||||
|
||||
def use_params(self, params):
|
||||
with self.model.use_params(params):
|
||||
yield
|
||||
|
@ -675,7 +704,7 @@ class TextCategorizer(Pipe):
|
|||
|
||||
@property
def labels(self):
    """Return the category labels stored under cfg['labels'].

    If no labels are configured yet, the default ['LABEL'] is written
    into the config first. setdefault is used rather than get so that
    the returned list is the one held by the config: with get, the
    default would be a throwaway list and labels appended to it would
    be lost.
    """
    return self.cfg.setdefault('labels', ['LABEL'])
|
||||
|
||||
@labels.setter
|
||||
def labels(self, value):
|
||||
|
@ -727,6 +756,17 @@ class TextCategorizer(Pipe):
|
|||
mean_square_error = ((scores-truths)**2).sum(axis=1).mean()
|
||||
return mean_square_error, d_scores
|
||||
|
||||
def add_label(self, label):
    """Add an output category, growing the model's final layer by one unit.

    When a model is present, a new Affine layer with one more output
    is created, the existing weights and biases are copied into its
    leading rows, and it replaces the old final layer. The label is
    then appended to `self.labels`.

    label: The category name to add.
    RETURNS (int): 0 if the label was already known (nothing changes),
        otherwise 1.
    """
    if label in self.labels:
        return 0
    model = self.model
    # The output layer can only be resized once a model exists. Before
    # training the attribute may still hold a placeholder instead of a
    # network (presumably True/False/None -- TODO confirm against
    # __init__); indexing it would raise, so only record the label then.
    if model is not None and not isinstance(model, bool):
        smaller = model[-1]._layers[-1]
        larger = Affine(len(self.labels)+1, smaller.nI)
        # Keep the learned parameters for the categories that already exist.
        copy_array(larger.W[:smaller.nO], smaller.W)
        copy_array(larger.b[:smaller.nO], smaller.b)
        model[-1]._layers[-1] = larger
    self.labels.append(label)
    return 1
|
||||
|
||||
def begin_training(self, gold_tuples=tuple(), pipeline=None):
|
||||
if pipeline and getattr(pipeline[0], 'name', None) == 'tensorizer':
|
||||
token_vector_width = pipeline[0].model.nO
|
||||
|
|
Loading…
Reference in New Issue
Block a user