Mirror of https://github.com/explosion/spaCy.git (synced 2025-11-04 01:48:04 +03:00)
			
		
		
		
	Fix train.py for 1.0
This commit is contained in:
		
parent 271a120d30
commit 080d29e092
					
				| 
						 | 
				
			
@@ -14,22 +14,31 @@ class Trainer(object):
         self.gold_tuples = gold_tuples
 
     def epochs(self, nr_epoch, augment_data=None, gold_preproc=False):
-        def _epoch():
-            for raw_text, paragraph_tuples in self.gold_tuples:
+        cached_golds = {}
+        def _epoch(indices):
+            for i in indices:
+                raw_text, paragraph_tuples = self.gold_tuples[i]
                 if gold_preproc:
                     raw_text = None
                 else:
                     paragraph_tuples = merge_sents(paragraph_tuples)
-                if augment_data is not None:
+                if augment_data is None:
+                    docs = self.make_docs(raw_text, paragraph_tuples)
+                    if i in cached_golds:
+                        golds = cached_golds[i]
+                    else:
+                        golds = self.make_golds(docs, paragraph_tuples)
+                else:
                     raw_text, paragraph_tuples = augment_data(raw_text, paragraph_tuples)
                     docs = self.make_docs(raw_text, paragraph_tuples)
                     golds = self.make_golds(docs, paragraph_tuples)
                 for doc, gold in zip(docs, golds):
                     yield doc, gold
 
+        indices = list(range(len(self.gold_tuples)))
         for itn in range(nr_epoch):
-            random.shuffle(self.gold_tuples)
-            yield _epoch()
+            random.shuffle(indices)
+            yield _epoch(indices)
 
     def update(self, doc, gold):
         for process in self.nlp.pipeline:
| 
						 | 
				
			
@@ -48,7 +57,7 @@ class Trainer(object):
             docs = self.make_docs(raw_text, paragraph_tuples)
             golds = self.make_golds(docs, paragraph_tuples)
             for doc, gold in zip(docs, golds):
-                for process in self.nlp.pipeline[1:]:
+                for process in self.nlp.pipeline:
                     process(doc)
                 scorer.score(doc, gold)
         return scorer
| 
						 | 
				
			
@@ -62,8 +71,8 @@ class Trainer(object):
 
     def make_golds(self, docs, paragraph_tuples):
         if len(docs) == 1:
-            return [GoldParse(docs[0], sent_tuples[0])
+            return [GoldParse.from_annot_tuples(docs[0], sent_tuples[0])
                     for sent_tuples in paragraph_tuples]
         else:
-            return [GoldParse(doc, sent_tuples[0])
+            return [GoldParse.from_annot_tuples(doc, sent_tuples[0])
                     for doc, sent_tuples in zip(docs, paragraph_tuples)]
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user