mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
Add low_data mode in textcat
This commit is contained in:
parent
ead78c7b9b
commit
a3b69bcb3d
30
spacy/_ml.py
30
spacy/_ml.py
|
@ -510,9 +510,23 @@ def foreach(layer, drop_factor=1.0):
|
|||
|
||||
|
||||
def build_text_classifier(nr_class, width=64, **cfg):
|
||||
nr_vector = cfg.get('nr_vector', 200)
|
||||
nr_vector = cfg.get('nr_vector', 5000)
|
||||
with Model.define_operators({'>>': chain, '+': add, '|': concatenate,
|
||||
'**': clone}):
|
||||
if cfg.get('low_data'):
|
||||
model = (
|
||||
SpacyVectors
|
||||
>> flatten_add_lengths
|
||||
>> with_getitem(0, LN(Affine(width, 300)))
|
||||
>> ParametricAttention(width)
|
||||
>> Pooling(sum_pool)
|
||||
>> Residual(ReLu(width, width)) ** 2
|
||||
>> zero_init(Affine(nr_class, width, drop_factor=0.0))
|
||||
>> logistic
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
lower = HashEmbed(width, nr_vector, column=1)
|
||||
prefix = HashEmbed(width//2, nr_vector, column=2)
|
||||
suffix = HashEmbed(width//2, nr_vector, column=3)
|
||||
|
@ -523,7 +537,7 @@ def build_text_classifier(nr_class, width=64, **cfg):
|
|||
>> with_flatten(
|
||||
uniqued(
|
||||
(lower | prefix | suffix | shape)
|
||||
>> LN(Maxout(width, 64+32+32+32)),
|
||||
>> LN(Maxout(width, width+(width//2)*3)),
|
||||
column=0
|
||||
)
|
||||
)
|
||||
|
@ -537,14 +551,16 @@ def build_text_classifier(nr_class, width=64, **cfg):
|
|||
cnn_model = (
|
||||
# TODO Make concatenate support lists
|
||||
concatenate_lists(trained_vectors, static_vectors)
|
||||
>> flatten_add_lengths
|
||||
>> with_getitem(0,
|
||||
SELU(width, width*2)
|
||||
>> (ExtractWindow(nW=1) >> SELU(width, width*3)) ** 2
|
||||
>> with_flatten(
|
||||
LN(Maxout(width, width*2))
|
||||
>> Residual(
|
||||
(ExtractWindow(nW=1) >> zero_init(Maxout(width, width*3)))
|
||||
) ** 2, pad=2
|
||||
)
|
||||
>> flatten_add_lengths
|
||||
>> ParametricAttention(width)
|
||||
>> Pooling(sum_pool)
|
||||
>> SELU(width, width) ** 2
|
||||
>> Residual(zero_init(Maxout(width, width)))
|
||||
>> zero_init(Affine(nr_class, width, drop_factor=0.0))
|
||||
)
|
||||
|
||||
|
|
|
@ -638,12 +638,13 @@ class TextCategorizer(BaseThincComponent):
|
|||
return mean_square_error, d_scores
|
||||
|
||||
def begin_training(self, gold_tuples=tuple(), pipeline=None):
|
||||
if pipeline:
|
||||
if pipeline and getattr(pipeline[0], 'name', None) == 'tensorizer':
|
||||
token_vector_width = pipeline[0].model.nO
|
||||
else:
|
||||
token_vector_width = 64
|
||||
if self.model is True:
|
||||
self.model = self.Model(len(self.labels), token_vector_width)
|
||||
self.model = self.Model(len(self.labels), token_vector_width,
|
||||
**self.cfg)
|
||||
|
||||
|
||||
cdef class EntityRecognizer(LinearParser):
|
||||
|
|
Loading…
Reference in New Issue
Block a user