From 54a539a113d9a57406a970c09f5906df5a0dc97c Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 23 Jul 2017 00:34:12 +0200 Subject: [PATCH] Finish text classifier example --- examples/training/train_textcat.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/examples/training/train_textcat.py b/examples/training/train_textcat.py index fa6e4f6ad..033cc50a9 100644 --- a/examples/training/train_textcat.py +++ b/examples/training/train_textcat.py @@ -80,14 +80,14 @@ def load_data(): return (train_texts, train_cats), (dev_texts, dev_cats) -def main(): +def main(model_loc=None): nlp = spacy.lang.en.English() tokenizer = nlp.tokenizer textcat = TextCategorizer(tokenizer.vocab, labels=['POSITIVE']) print("Load IMDB data") (train_texts, train_cats), (dev_texts, dev_cats) = load_data() - + print("Itn.\tLoss\tP\tR\tF") progress = '{i:d} {loss:.3f} {textcat_p:.3f} {textcat_r:.3f} {textcat_f:.3f}' @@ -95,6 +95,15 @@ def main(): train_texts, train_cats, dev_texts, dev_cats, n_iter=20)): print(progress.format(i=i, loss=loss, **scores)) + # How to save, load and use + nlp.pipeline.append(textcat) + if model_loc is not None: + nlp.to_disk(model_loc) + + nlp = spacy.load(model_loc) + doc = nlp(u'This movie sucked!') + print(doc.cats) + if __name__ == '__main__':