From 350785d8ce87ad7973f1386a44afcbe62336467c Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Sat, 23 Mar 2019 16:10:44 +0100
Subject: [PATCH] Fix size limits in train_textcat example

---
 examples/training/train_textcat.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/examples/training/train_textcat.py b/examples/training/train_textcat.py
index 6745ddba6..7cd492f75 100644
--- a/examples/training/train_textcat.py
+++ b/examples/training/train_textcat.py
@@ -43,7 +43,11 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None
     # nlp.create_pipe works for built-ins that are registered with spaCy
     if "textcat" not in nlp.pipe_names:
         textcat = nlp.create_pipe(
-            "textcat", config={"architecture": "simple_cnn", "exclusive_classes": True}
+            "textcat",
+            config={
+                "exclusive_classes": True,
+                "architecture": "simple_cnn",
+            }
         )
         nlp.add_pipe(textcat, last=True)
     # otherwise, get it, so we can add labels to it
@@ -56,7 +60,9 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None
 
     # load the IMDB dataset
     print("Loading IMDB data...")
-    (train_texts, train_cats), (dev_texts, dev_cats) = load_data(limit=n_texts)
+    (train_texts, train_cats), (dev_texts, dev_cats) = load_data()
+    train_texts = train_texts[:n_texts]
+    train_cats = train_cats[:n_texts]
    print(
         "Using {} examples ({} training, {} evaluation)".format(
             n_texts, len(train_texts), len(dev_texts)
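
For context, a minimal sketch (not the actual spaCy code) of the data handling this patch changes: the old call load_data(limit=n_texts) capped the dataset before it was split, which also shrank the evaluation split, while the patched version loads everything and trims only the training portion. The load_data stand-in below is hypothetical; it only mimics the ((train_texts, train_cats), (dev_texts, dev_cats)) shape the example script expects.

    # Hypothetical stand-in for the example's load_data(): returns
    # (texts, cats) pairs split into train/dev portions.
    def load_data(split=0.8):
        data = [
            ("text %d" % i, {"POSITIVE": i % 2 == 0, "NEGATIVE": i % 2 == 1})
            for i in range(1000)
        ]
        texts, cats = zip(*data)
        cut = int(len(data) * split)
        return (texts[:cut], cats[:cut]), (texts[cut:], cats[cut:])

    n_texts = 200
    (train_texts, train_cats), (dev_texts, dev_cats) = load_data()
    # Only the training portion is capped at n_texts; the dev split keeps
    # its full size, unlike with load_data(limit=n_texts).
    train_texts = train_texts[:n_texts]
    train_cats = train_cats[:n_texts]
    print(
        "Using {} examples ({} training, {} evaluation)".format(
            n_texts, len(train_texts), len(dev_texts)
        )
    )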