diff --git a/examples/training/train_textcat.py b/examples/training/train_textcat.py index 7eb356100..6018827a4 100644 --- a/examples/training/train_textcat.py +++ b/examples/training/train_textcat.py @@ -16,6 +16,11 @@ from spacy.gold import GoldParse, minibatch from spacy.util import compounding from spacy.pipeline import TextCategorizer +# TODO: Remove this once we're not supporting models trained with thinc <6.9.0 +import thinc.neural._classes.layernorm +thinc.neural._classes.layernorm.set_compat_six_eight(False) + + def train_textcat(tokenizer, textcat, train_texts, train_cats, dev_texts, dev_cats, @@ -28,12 +33,13 @@ def train_textcat(tokenizer, textcat, train_docs = [tokenizer(text) for text in train_texts] train_gold = [GoldParse(doc, cats=cats) for doc, cats in zip(train_docs, train_cats)] - train_data = zip(train_docs, train_gold) + train_data = list(zip(train_docs, train_gold)) batch_sizes = compounding(4., 128., 1.001) for i in range(n_iter): losses = {} - train_data = tqdm.tqdm(train_data, leave=False) # Progress bar - for batch in minibatch(train_data, size=batch_sizes): + # Progress bar and minibatching + batches = minibatch(tqdm.tqdm(train_data, leave=False), size=batch_sizes) + for batch in batches: docs, golds = zip(*batch) textcat.update(docs, golds, sgd=optimizer, drop=0.2, losses=losses) @@ -65,12 +71,13 @@ def evaluate(tokenizer, textcat, texts, cats): return {'textcat_p': precis, 'textcat_r': recall, 'textcat_f': fscore} -def load_data(): +def load_data(limit=0): # Partition off part of the train data --- avoid running experiments # against test. train_data, _ = thinc.extra.datasets.imdb() random.shuffle(train_data) + train_data = train_data[-limit:] texts, labels = zip(*train_data) cats = [(['POSITIVE'] if y else []) for y in labels] @@ -90,7 +97,7 @@ def main(model_loc=None): textcat = TextCategorizer(tokenizer.vocab, labels=['POSITIVE']) print("Load IMDB data") - (train_texts, train_cats), (dev_texts, dev_cats) = load_data() + (train_texts, train_cats), (dev_texts, dev_cats) = load_data(limit=1000) print("Itn.\tLoss\tP\tR\tF") progress = '{i:d} {loss:.3f} {textcat_p:.3f} {textcat_r:.3f} {textcat_f:.3f}'