Improve argument passing in textcat

This commit is contained in:
Matthew Honnibal 2018-03-16 12:30:51 +01:00
parent eb2a3c5971
commit 565ef8c4d8

View File

@ -866,8 +866,8 @@ class TextCategorizer(Pipe):
name = 'textcat' name = 'textcat'
@classmethod @classmethod
def Model(cls, nr_class=1, width=64, **cfg): def Model(cls, **cfg):
return build_text_classifier(nr_class, width, **cfg) return build_text_classifier(**cfg)
def __init__(self, vocab, model=True, **cfg): def __init__(self, vocab, model=True, **cfg):
self.vocab = vocab self.vocab = vocab
@ -948,8 +948,9 @@ class TextCategorizer(Pipe):
token_vector_width = 64 token_vector_width = 64
if self.model is True: if self.model is True:
self.cfg['pretrained_dims'] = self.vocab.vectors_length self.cfg['pretrained_dims'] = self.vocab.vectors_length
self.model = self.Model(len(self.labels), token_vector_width, self.cfg['nr_class'] = len(self.labels)
**self.cfg) self.cfg['width'] = token_vector_width
self.model = self.Model(**self.cfg)
link_vectors_to_models(self.vocab) link_vectors_to_models(self.vocab)
if sgd is None: if sgd is None:
sgd = self.create_optimizer() sgd = self.create_optimizer()