mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-25 13:11:03 +03:00 
			
		
		
		
	Fix multitask objectives
This commit is contained in:
		
							parent
							
								
									d1246c95fb
								
							
						
					
					
						commit
						8f06903e09
					
				|  | @ -681,13 +681,19 @@ class MultitaskObjective(Tagger): | ||||||
|         return tokvecs, scores |         return tokvecs, scores | ||||||
| 
 | 
 | ||||||
|     def get_loss(self, docs, golds, scores): |     def get_loss(self, docs, golds, scores): | ||||||
|  |         assert len(docs) == len(golds) | ||||||
|         cdef int idx = 0 |         cdef int idx = 0 | ||||||
|         correct = numpy.zeros((scores.shape[0],), dtype='i') |         correct = numpy.zeros((scores.shape[0],), dtype='i') | ||||||
|         guesses = scores.argmax(axis=1) |         guesses = scores.argmax(axis=1) | ||||||
|         for gold in golds: |         for i, gold in enumerate(golds): | ||||||
|             for i in range(len(gold.labels)): |             for j in range(len(docs[i])): | ||||||
|                 label = self.make_label(i, gold.words, gold.tags, gold.heads, |                 # Handles alignment for tokenization differences | ||||||
|                                         gold.labels, gold.ents) |                 gold_idx = gold.cand_to_gold[j] | ||||||
|  |                 if gold_idx is None: | ||||||
|  |                     idx += 1 | ||||||
|  |                     continue | ||||||
|  |                 label = self.make_label(gold_idx, gold.words, gold.tags, | ||||||
|  |                                         gold.heads, gold.labels, gold.ents) | ||||||
|                 if label is None or label not in self.labels: |                 if label is None or label not in self.labels: | ||||||
|                     correct[idx] = guesses[idx] |                     correct[idx] = guesses[idx] | ||||||
|                 else: |                 else: | ||||||
|  |  | ||||||
|  | @ -542,6 +542,7 @@ cdef class Parser: | ||||||
|     def update(self, docs, golds, drop=0., sgd=None, losses=None): |     def update(self, docs, golds, drop=0., sgd=None, losses=None): | ||||||
|         if not any(self.moves.has_gold(gold) for gold in golds): |         if not any(self.moves.has_gold(gold) for gold in golds): | ||||||
|             return None |             return None | ||||||
|  |         assert len(docs) == len(golds) | ||||||
|         if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= 0.0: |         if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= 0.0: | ||||||
|             return self.update_beam(docs, golds, |             return self.update_beam(docs, golds, | ||||||
|                     self.cfg['beam_width'], self.cfg['beam_density'], |                     self.cfg['beam_width'], self.cfg['beam_density'], | ||||||
|  | @ -551,6 +552,8 @@ cdef class Parser: | ||||||
|         if isinstance(docs, Doc) and isinstance(golds, GoldParse): |         if isinstance(docs, Doc) and isinstance(golds, GoldParse): | ||||||
|             docs = [docs] |             docs = [docs] | ||||||
|             golds = [golds] |             golds = [golds] | ||||||
|  |         for multitask in self._multitasks: | ||||||
|  |             multitask.update(docs, golds, drop=drop, sgd=sgd) | ||||||
|         cuda_stream = util.get_cuda_stream() |         cuda_stream = util.get_cuda_stream() | ||||||
|         states, golds, max_steps = self._init_gold_batch(docs, golds) |         states, golds, max_steps = self._init_gold_batch(docs, golds) | ||||||
|         (tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(docs, cuda_stream, |         (tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(docs, cuda_stream, | ||||||
|  | @ -605,8 +608,6 @@ cdef class Parser: | ||||||
|                 break |                 break | ||||||
|         self._make_updates(d_tokvecs, |         self._make_updates(d_tokvecs, | ||||||
|             bp_tokvecs, backprops, sgd, cuda_stream) |             bp_tokvecs, backprops, sgd, cuda_stream) | ||||||
|         for multitask in self._multitasks: |  | ||||||
|             multitask.update(docs, golds, drop=drop, sgd=sgd) |  | ||||||
|      |      | ||||||
|     def update_beam(self, docs, golds, width=None, density=None, |     def update_beam(self, docs, golds, width=None, density=None, | ||||||
|             drop=0., sgd=None, losses=None): |             drop=0., sgd=None, losses=None): | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user