mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	addressing suggestions by @polm
This commit is contained in:
		
							parent
							
								
									e4b4b67ef6
								
							
						
					
					
						commit
						06d680b269
					
				|  | @ -511,20 +511,24 @@ class SpanPredictor(TrainablePipe): | |||
|         set_dropout_rate(self.model, drop) | ||||
| 
 | ||||
|         total_loss = 0 | ||||
|         docs = [eg.predicted for eg in examples] | ||||
|         for doc, eg in zip(docs, examples): | ||||
|         old_spans = [eg.predicted.spans for eg in examples] | ||||
|         for eg in examples: | ||||
|             # replicates the EntityLinker's behaviour and | ||||
|             # copies annotations over https://bit.ly/3iweDcW | ||||
|             # takes 'coref_head_clusters' from the reference. | ||||
|             # https://github.com/explosion/spaCy/blob/master/spacy/pipeline/entity_linker.py#L313 | ||||
|             doc = eg.predicted | ||||
|             for key, sg in eg.reference.spans.items(): | ||||
|                 if key.startswith(self.input_prefix): | ||||
|                     aligned_spans = eg.get_aligned_spans_x2y(sg) | ||||
|                     doc.spans[key] = [doc[span.start:span.end] for span in aligned_spans] | ||||
|                     doc.spans[key] = eg.get_aligned_spans_y2x(sg) | ||||
|             span_scores, backprop = self.model.begin_update([doc]) | ||||
|             loss, d_scores = self.get_loss([eg], span_scores) | ||||
|             total_loss += loss | ||||
|             # TODO check shape here | ||||
|             backprop(d_scores) | ||||
|         # Restore examples | ||||
|         for spans, eg in zip(old_spans, examples): | ||||
|             for key, sg in spans.items(): | ||||
|                 eg.predicted.spans[key] = sg | ||||
| 
 | ||||
|         if sgd is not None: | ||||
|             self.finish_update(sgd) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user