diff --git a/spacy/train.py b/spacy/train.py index a86569100..175c99cf2 100644 --- a/spacy/train.py +++ b/spacy/train.py @@ -2,6 +2,7 @@ from __future__ import absolute_import from __future__ import unicode_literals import random +import tqdm from .gold import GoldParse from .scorer import Scorer from .gold import merge_sents @@ -12,11 +13,12 @@ class Trainer(object): def __init__(self, nlp, gold_tuples): self.nlp = nlp self.gold_tuples = gold_tuples + self.nr_epoch = 0 def epochs(self, nr_epoch, augment_data=None, gold_preproc=False): cached_golds = {} def _epoch(indices): - for i in indices: + for i in tqdm.tqdm(indices): raw_text, paragraph_tuples = self.gold_tuples[i] if gold_preproc: raw_text = None @@ -39,11 +41,12 @@ class Trainer(object): for itn in range(nr_epoch): random.shuffle(indices) yield _epoch(indices) - + self.nr_epoch += 1 + def update(self, doc, gold): for process in self.nlp.pipeline: if hasattr(process, 'update'): - process.update(doc, gold) + loss = process.update(doc, gold, itn=self.nr_epoch) process(doc) return doc