Support 'bow' architecture for TextCategorizer

This allows efficient ngram bag-of-words models, which are better when
the classifier needs to run quickly, especially when the texts are long.
Pass architecture="bow" to use it. The extra arguments ngram_size and
attr are also available, e.g. ngram_size=2 means unigram and bigram
features will be extracted.
This commit is contained in:
Matthew Honnibal 2019-03-23 16:08:27 +01:00
parent 2ddd552d52
commit b7c75049f7

View File

@ -25,6 +25,7 @@ from ..attrs import POS, ID
from ..parts_of_speech import X
from .._ml import Tok2Vec, build_tagger_model
from .._ml import build_text_classifier, build_simple_cnn_text_classifier
from .._ml import build_bow_text_classifier
from .._ml import link_vectors_to_models, zero_init, flatten
from .._ml import masked_language_model, create_default_optimizer
from ..errors import Errors, TempErrors
@ -876,6 +877,8 @@ class TextCategorizer(Pipe):
if cfg.get("architecture") == "simple_cnn":
tok2vec = Tok2Vec(token_vector_width, embed_size, **cfg)
return build_simple_cnn_text_classifier(tok2vec, nr_class, **cfg)
elif cfg.get("architecture") == "bow":
return build_bow_text_classifier(nr_class, **cfg)
else:
return build_text_classifier(nr_class, **cfg)