Update train method

This commit is contained in:
Matthew Honnibal 2016-10-13 03:24:53 +02:00
parent 645d99523a
commit 9b55d97a8f

View File

@ -4,6 +4,7 @@ from __future__ import unicode_literals
import random
from .gold import GoldParse
from .scorer import Scorer
from .gold import merge_sents
class Trainer(object):
@ -12,9 +13,13 @@ class Trainer(object):
self.nlp = nlp
self.gold_tuples = gold_tuples
def epochs(self, nr_epoch, augment_data=None):
def epochs(self, nr_epoch, augment_data=None, gold_preproc=False):
def _epoch():
for raw_text, paragraph_tuples in self.gold_tuples:
if gold_preproc:
raw_text = None
else:
paragraph_tuples = merge_sents(paragraph_tuples)
if augment_data is not None:
raw_text, paragraph_tuples = augment_data(raw_text, paragraph_tuples)
docs = self.make_docs(raw_text, paragraph_tuples)
@ -33,9 +38,13 @@ class Trainer(object):
process(doc)
return doc
def evaluate(self, dev_sents):
def evaluate(self, dev_sents, gold_preproc=False):
scorer = Scorer()
for raw_text, paragraph_tuples in dev_sents:
if gold_preproc:
raw_text = None
else:
paragraph_tuples = merge_sents(paragraph_tuples)
docs = self.make_docs(raw_text, paragraph_tuples)
golds = self.make_golds(docs, paragraph_tuples)
for doc, gold in zip(docs, golds):