From 080d29e092d6032d65541d78da53fb6122d1f71d Mon Sep 17 00:00:00 2001 From: root Date: Fri, 25 Nov 2016 08:55:33 -0600 Subject: [PATCH] Fix train.py for 1.0 --- spacy/train.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/spacy/train.py b/spacy/train.py index 097218310..a86569100 100644 --- a/spacy/train.py +++ b/spacy/train.py @@ -14,22 +14,31 @@ class Trainer(object): self.gold_tuples = gold_tuples def epochs(self, nr_epoch, augment_data=None, gold_preproc=False): - def _epoch(): - for raw_text, paragraph_tuples in self.gold_tuples: + cached_golds = {} + def _epoch(indices): + for i in indices: + raw_text, paragraph_tuples = self.gold_tuples[i] if gold_preproc: raw_text = None else: paragraph_tuples = merge_sents(paragraph_tuples) - if augment_data is not None: + if augment_data is None: + docs = self.make_docs(raw_text, paragraph_tuples) + if i in cached_golds: + golds = cached_golds[i] + else: + golds = self.make_golds(docs, paragraph_tuples) + else: raw_text, paragraph_tuples = augment_data(raw_text, paragraph_tuples) - docs = self.make_docs(raw_text, paragraph_tuples) - golds = self.make_golds(docs, paragraph_tuples) + docs = self.make_docs(raw_text, paragraph_tuples) + golds = self.make_golds(docs, paragraph_tuples) for doc, gold in zip(docs, golds): yield doc, gold + indices = list(range(len(self.gold_tuples))) for itn in range(nr_epoch): - random.shuffle(self.gold_tuples) - yield _epoch() + random.shuffle(indices) + yield _epoch(indices) def update(self, doc, gold): for process in self.nlp.pipeline: @@ -48,7 +57,7 @@ class Trainer(object): docs = self.make_docs(raw_text, paragraph_tuples) golds = self.make_golds(docs, paragraph_tuples) for doc, gold in zip(docs, golds): - for process in self.nlp.pipeline[1:]: + for process in self.nlp.pipeline: process(doc) scorer.score(doc, gold) return scorer @@ -62,8 +71,8 @@ class Trainer(object): def make_golds(self, docs, paragraph_tuples): if len(docs) == 1: - return [GoldParse(docs[0], sent_tuples[0]) + return [GoldParse.from_annot_tuples(docs[0], sent_tuples[0]) for sent_tuples in paragraph_tuples] else: - return [GoldParse(doc, sent_tuples[0]) + return [GoldParse.from_annot_tuples(doc, sent_tuples[0]) for doc, sent_tuples in zip(docs, paragraph_tuples)]