mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	Fix train.py for 1.0
This commit is contained in:
		
							parent
							
								
									271a120d30
								
							
						
					
					
						commit
						080d29e092
					
				| 
						 | 
					@ -14,22 +14,31 @@ class Trainer(object):
 | 
				
			||||||
        self.gold_tuples = gold_tuples
 | 
					        self.gold_tuples = gold_tuples
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def epochs(self, nr_epoch, augment_data=None, gold_preproc=False):
 | 
					    def epochs(self, nr_epoch, augment_data=None, gold_preproc=False):
 | 
				
			||||||
        def _epoch():
 | 
					        cached_golds = {}
 | 
				
			||||||
            for raw_text, paragraph_tuples in self.gold_tuples:
 | 
					        def _epoch(indices):
 | 
				
			||||||
 | 
					            for i in indices:
 | 
				
			||||||
 | 
					                raw_text, paragraph_tuples = self.gold_tuples[i]
 | 
				
			||||||
                if gold_preproc:
 | 
					                if gold_preproc:
 | 
				
			||||||
                    raw_text = None
 | 
					                    raw_text = None
 | 
				
			||||||
                else:
 | 
					                else:
 | 
				
			||||||
                    paragraph_tuples = merge_sents(paragraph_tuples)
 | 
					                    paragraph_tuples = merge_sents(paragraph_tuples)
 | 
				
			||||||
                if augment_data is not None:
 | 
					                if augment_data is None:
 | 
				
			||||||
 | 
					                    docs = self.make_docs(raw_text, paragraph_tuples)
 | 
				
			||||||
 | 
					                    if i in cached_golds:
 | 
				
			||||||
 | 
					                        golds = cached_golds[i]
 | 
				
			||||||
 | 
					                    else:
 | 
				
			||||||
 | 
					                        golds = self.make_golds(docs, paragraph_tuples)
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
                    raw_text, paragraph_tuples = augment_data(raw_text, paragraph_tuples)
 | 
					                    raw_text, paragraph_tuples = augment_data(raw_text, paragraph_tuples)
 | 
				
			||||||
                    docs = self.make_docs(raw_text, paragraph_tuples)
 | 
					                    docs = self.make_docs(raw_text, paragraph_tuples)
 | 
				
			||||||
                    golds = self.make_golds(docs, paragraph_tuples)
 | 
					                    golds = self.make_golds(docs, paragraph_tuples)
 | 
				
			||||||
                for doc, gold in zip(docs, golds):
 | 
					                for doc, gold in zip(docs, golds):
 | 
				
			||||||
                    yield doc, gold
 | 
					                    yield doc, gold
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        indices = list(range(len(self.gold_tuples)))
 | 
				
			||||||
        for itn in range(nr_epoch):
 | 
					        for itn in range(nr_epoch):
 | 
				
			||||||
            random.shuffle(self.gold_tuples)
 | 
					            random.shuffle(indices)
 | 
				
			||||||
            yield _epoch()
 | 
					            yield _epoch(indices)
 | 
				
			||||||
 
 | 
					 
 | 
				
			||||||
    def update(self, doc, gold):
 | 
					    def update(self, doc, gold):
 | 
				
			||||||
        for process in self.nlp.pipeline:
 | 
					        for process in self.nlp.pipeline:
 | 
				
			||||||
| 
						 | 
					@ -48,7 +57,7 @@ class Trainer(object):
 | 
				
			||||||
            docs = self.make_docs(raw_text, paragraph_tuples)
 | 
					            docs = self.make_docs(raw_text, paragraph_tuples)
 | 
				
			||||||
            golds = self.make_golds(docs, paragraph_tuples)
 | 
					            golds = self.make_golds(docs, paragraph_tuples)
 | 
				
			||||||
            for doc, gold in zip(docs, golds):
 | 
					            for doc, gold in zip(docs, golds):
 | 
				
			||||||
                for process in self.nlp.pipeline[1:]:
 | 
					                for process in self.nlp.pipeline:
 | 
				
			||||||
                    process(doc)
 | 
					                    process(doc)
 | 
				
			||||||
                scorer.score(doc, gold)
 | 
					                scorer.score(doc, gold)
 | 
				
			||||||
        return scorer
 | 
					        return scorer
 | 
				
			||||||
| 
						 | 
					@ -62,8 +71,8 @@ class Trainer(object):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def make_golds(self, docs, paragraph_tuples):
 | 
					    def make_golds(self, docs, paragraph_tuples):
 | 
				
			||||||
        if len(docs) == 1:
 | 
					        if len(docs) == 1:
 | 
				
			||||||
            return [GoldParse(docs[0], sent_tuples[0])
 | 
					            return [GoldParse.from_annot_tuples(docs[0], sent_tuples[0])
 | 
				
			||||||
                    for sent_tuples in paragraph_tuples]
 | 
					                    for sent_tuples in paragraph_tuples]
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            return [GoldParse(doc, sent_tuples[0])
 | 
					            return [GoldParse.from_annot_tuples(doc, sent_tuples[0])
 | 
				
			||||||
                    for doc, sent_tuples in zip(docs, paragraph_tuples)]
 | 
					                    for doc, sent_tuples in zip(docs, paragraph_tuples)]
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user