Improve output on trainer

This commit is contained in:
Matthew Honnibal 2017-03-11 11:12:48 -06:00
parent b438dfd3f3
commit 1224c4d3c6

View File

@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import unicode_literals
import random
import tqdm
from .gold import GoldParse
from .scorer import Scorer
from .gold import merge_sents
@ -12,11 +13,12 @@ class Trainer(object):
def __init__(self, nlp, gold_tuples):
self.nlp = nlp
self.gold_tuples = gold_tuples
self.nr_epoch = 0
def epochs(self, nr_epoch, augment_data=None, gold_preproc=False):
cached_golds = {}
def _epoch(indices):
for i in indices:
for i in tqdm.tqdm(indices):
raw_text, paragraph_tuples = self.gold_tuples[i]
if gold_preproc:
raw_text = None
@ -39,11 +41,12 @@ class Trainer(object):
for itn in range(nr_epoch):
random.shuffle(indices)
yield _epoch(indices)
self.nr_epoch += 1
def update(self, doc, gold):
for process in self.nlp.pipeline:
if hasattr(process, 'update'):
process.update(doc, gold)
loss = process.update(doc, gold, itn=self.nr_epoch)
process(doc)
return doc