From ce1e4eace2a5adb8dfa6166fdd3d3c06886c9b16 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sat, 23 Feb 2019 11:55:16 +0100 Subject: [PATCH] Default to former TextCategorizer model * Keep TextCategorizer default model same as v2.0 * Add option 'architecture' that allows "simple_cnn" to switch to simpler model. * Add option exclusive_classes, defaulting to False. If set to True, the model treats classes as mutually exclusive, i.e. only one class can be true per instance. --- spacy/pipeline/pipes.pyx | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index 100b5abfd..4e052ef16 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -24,7 +24,8 @@ from ..vocab cimport Vocab from ..syntax import nonproj from ..attrs import POS, ID from ..parts_of_speech import X -from .._ml import Tok2Vec, build_tagger_model, build_simple_cnn_text_classifier +from .._ml import Tok2Vec, build_tagger_model +from .._ml import build_text_classifier, build_simple_cnn_text_classifier from .._ml import link_vectors_to_models, zero_init, flatten from .._ml import masked_language_model, create_default_optimizer from ..errors import Errors, TempErrors @@ -862,8 +863,11 @@ class TextCategorizer(Pipe): token_vector_width = cfg["token_vector_width"] else: token_vector_width = util.env_opt("token_vector_width", 96) - tok2vec = Tok2Vec(token_vector_width, embed_size, **cfg) - return build_simple_cnn_text_classifier(tok2vec, nr_class, **cfg) + if cfg.get('architecture') == 'simple_cnn': + tok2vec = Tok2Vec(token_vector_width, embed_size, **cfg) + return build_simple_cnn_text_classifier(tok2vec, nr_class, **cfg) + else: + return build_text_classifier(nr_class, **cfg) @property def tok2vec(self):