mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
fix train_textcat script
This commit is contained in:
parent
7dd53d0964
commit
674c39bff9
|
@ -20,6 +20,7 @@ import spacy
|
|||
from spacy import util
|
||||
from spacy.util import minibatch, compounding
|
||||
from spacy.gold import Example
|
||||
from thinc.api import Config
|
||||
|
||||
|
||||
@plac.annotations(
|
||||
|
@ -42,8 +43,9 @@ def main(config_path, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=Non
|
|||
output_dir.mkdir()
|
||||
|
||||
print(f"Loading nlp model from {config_path}")
|
||||
nlp_config = util.load_config(config_path, create_objects=False)["nlp"]
|
||||
nlp = util.load_model_from_config(nlp_config)
|
||||
nlp_config = Config().from_disk(config_path)
|
||||
print(f"config: {nlp_config}")
|
||||
nlp, _ = util.load_model_from_config(nlp_config)
|
||||
|
||||
# ensure the nlp object was defined with a textcat component
|
||||
if "textcat" not in nlp.pipe_names:
|
||||
|
|
|
@ -1,19 +1,14 @@
|
|||
[nlp]
|
||||
lang = "en"
|
||||
pipeline = ["textcat"]
|
||||
|
||||
[nlp.pipeline.textcat]
|
||||
[components]
|
||||
|
||||
[components.textcat]
|
||||
factory = "textcat"
|
||||
|
||||
[nlp.pipeline.textcat.model]
|
||||
@architectures = "spacy.TextCatCNN.v1"
|
||||
exclusive_classes = false
|
||||
|
||||
[nlp.pipeline.textcat.model.tok2vec]
|
||||
@architectures = "spacy.HashEmbedCNN.v1"
|
||||
pretrained_vectors = null
|
||||
width = 96
|
||||
depth = 4
|
||||
embed_size = 2000
|
||||
window_size = 1
|
||||
maxout_pieces = 3
|
||||
subword_features = true
|
||||
[components.textcat.model]
|
||||
@architectures = "spacy.TextCatBOW.v1"
|
||||
exclusive_classes = true
|
||||
ngram_size = 1
|
||||
no_output_layer = false
|
||||
|
|
Loading…
Reference in New Issue
Block a user