mirror of
https://github.com/explosion/spaCy.git
synced 2025-03-03 19:08:06 +03:00
Improve argument passing in textcat
This commit is contained in:
parent
eb2a3c5971
commit
565ef8c4d8
|
@ -866,8 +866,8 @@ class TextCategorizer(Pipe):
|
||||||
name = 'textcat'
|
name = 'textcat'
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def Model(cls, nr_class=1, width=64, **cfg):
|
def Model(cls, **cfg):
|
||||||
return build_text_classifier(nr_class, width, **cfg)
|
return build_text_classifier(**cfg)
|
||||||
|
|
||||||
def __init__(self, vocab, model=True, **cfg):
|
def __init__(self, vocab, model=True, **cfg):
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
|
@ -948,8 +948,9 @@ class TextCategorizer(Pipe):
|
||||||
token_vector_width = 64
|
token_vector_width = 64
|
||||||
if self.model is True:
|
if self.model is True:
|
||||||
self.cfg['pretrained_dims'] = self.vocab.vectors_length
|
self.cfg['pretrained_dims'] = self.vocab.vectors_length
|
||||||
self.model = self.Model(len(self.labels), token_vector_width,
|
self.cfg['nr_class'] = len(self.labels)
|
||||||
**self.cfg)
|
self.cfg['width'] = token_vector_width
|
||||||
|
self.model = self.Model(**self.cfg)
|
||||||
link_vectors_to_models(self.vocab)
|
link_vectors_to_models(self.vocab)
|
||||||
if sgd is None:
|
if sgd is None:
|
||||||
sgd = self.create_optimizer()
|
sgd = self.create_optimizer()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user