fix train_textcat script

This commit is contained in:
svlandeg 2020-07-27 16:48:21 +02:00
parent 7dd53d0964
commit 674c39bff9
2 changed files with 13 additions and 16 deletions

View File

@ -20,6 +20,7 @@ import spacy
from spacy import util from spacy import util
from spacy.util import minibatch, compounding from spacy.util import minibatch, compounding
from spacy.gold import Example from spacy.gold import Example
from thinc.api import Config
@plac.annotations( @plac.annotations(
@ -42,8 +43,9 @@ def main(config_path, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=Non
output_dir.mkdir() output_dir.mkdir()
print(f"Loading nlp model from {config_path}") print(f"Loading nlp model from {config_path}")
nlp_config = util.load_config(config_path, create_objects=False)["nlp"] nlp_config = Config().from_disk(config_path)
nlp = util.load_model_from_config(nlp_config) print(f"config: {nlp_config}")
nlp, _ = util.load_model_from_config(nlp_config)
# ensure the nlp object was defined with a textcat component # ensure the nlp object was defined with a textcat component
if "textcat" not in nlp.pipe_names: if "textcat" not in nlp.pipe_names:

View File

@ -1,19 +1,14 @@
[nlp] [nlp]
lang = "en" lang = "en"
pipeline = ["textcat"]
[nlp.pipeline.textcat] [components]
[components.textcat]
factory = "textcat" factory = "textcat"
[nlp.pipeline.textcat.model] [components.textcat.model]
@architectures = "spacy.TextCatCNN.v1" @architectures = "spacy.TextCatBOW.v1"
exclusive_classes = false exclusive_classes = true
ngram_size = 1
[nlp.pipeline.textcat.model.tok2vec] no_output_layer = false
@architectures = "spacy.HashEmbedCNN.v1"
pretrained_vectors = null
width = 96
depth = 4
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true