mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	prepare for aligned heads-spans training
This commit is contained in:
		
							parent
							
								
									63a41ba50a
								
							
						
					
					
						commit
						a1d0219903
					
				| 
						 | 
					@ -503,29 +503,20 @@ class SpanPredictor(TrainablePipe):
 | 
				
			||||||
            losses = {}
 | 
					            losses = {}
 | 
				
			||||||
        losses.setdefault(self.name, 0.0)
 | 
					        losses.setdefault(self.name, 0.0)
 | 
				
			||||||
        validate_examples(examples, "SpanPredictor.update")
 | 
					        validate_examples(examples, "SpanPredictor.update")
 | 
				
			||||||
        if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
 | 
					        if not any(len(eg.reference) if eg.reference else 0 for eg in examples):
 | 
				
			||||||
            # Handle cases where there are no tokens in any docs.
 | 
					            # Handle cases where there are no tokens in any docs.
 | 
				
			||||||
            return losses
 | 
					            return losses
 | 
				
			||||||
        set_dropout_rate(self.model, drop)
 | 
					        set_dropout_rate(self.model, drop)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        total_loss = 0
 | 
					        total_loss = 0
 | 
				
			||||||
        for eg in examples:
 | 
					        for eg in examples:
 | 
				
			||||||
            # replicates the EntityLinker's behaviour and
 | 
					            # For update we use the gold coref_head_clusters
 | 
				
			||||||
            # copies annotations over https://bit.ly/3iweDcW
 | 
					            # in the reference.
 | 
				
			||||||
            # https://github.com/explosion/spaCy/blob/master/spacy/pipeline/entity_linker.py#L313
 | 
					            span_scores, backprop = self.model.begin_update([eg.reference])
 | 
				
			||||||
            doc = eg.predicted
 | 
					 | 
				
			||||||
            old_spans = eg.predicted.spans
 | 
					 | 
				
			||||||
            for key, sg in eg.reference.spans.items():
 | 
					 | 
				
			||||||
                if key.startswith(self.input_prefix):
 | 
					 | 
				
			||||||
                    doc.spans[key] = eg.get_aligned_spans_y2x(sg)
 | 
					 | 
				
			||||||
            span_scores, backprop = self.model.begin_update([doc])
 | 
					 | 
				
			||||||
            loss, d_scores = self.get_loss([eg], span_scores)
 | 
					            loss, d_scores = self.get_loss([eg], span_scores)
 | 
				
			||||||
            total_loss += loss
 | 
					            total_loss += loss
 | 
				
			||||||
            # TODO check shape here
 | 
					            # TODO check shape here
 | 
				
			||||||
            backprop(d_scores)
 | 
					            backprop(d_scores)
 | 
				
			||||||
            # Restore example
 | 
					 | 
				
			||||||
            for key, sg in old_spans.items():
 | 
					 | 
				
			||||||
                eg.predicted.spans[key] = sg
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if sgd is not None:
 | 
					        if sgd is not None:
 | 
				
			||||||
            self.finish_update(sgd)
 | 
					            self.finish_update(sgd)
 | 
				
			||||||
| 
						 | 
					@ -570,17 +561,14 @@ class SpanPredictor(TrainablePipe):
 | 
				
			||||||
        # span_scores is a Floats3d. What are the axes? mention x token x start/end
 | 
					        # span_scores is a Floats3d. What are the axes? mention x token x start/end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for eg in examples:
 | 
					        for eg in examples:
 | 
				
			||||||
 | 
					 | 
				
			||||||
            # get gold data
 | 
					 | 
				
			||||||
            gold = doc2clusters(eg.predicted, self.input_prefix)
 | 
					 | 
				
			||||||
            # flatten the gold data
 | 
					 | 
				
			||||||
            starts = []
 | 
					            starts = []
 | 
				
			||||||
            ends = []
 | 
					            ends = []
 | 
				
			||||||
            for cluster in gold:
 | 
					            for key, sg in eg.reference.spans.items():
 | 
				
			||||||
                for mention in cluster:
 | 
					                if key.startswith(self.output_prefix):
 | 
				
			||||||
                    starts.append(mention[0])
 | 
					                    for mention in sg:
 | 
				
			||||||
                    # XXX I think this was missing here
 | 
					                        starts.append(mention.start)
 | 
				
			||||||
                    ends.append(mention[1] - 1)
 | 
					                        ends.append(mention.end)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            starts = self.model.ops.xp.asarray(starts)
 | 
					            starts = self.model.ops.xp.asarray(starts)
 | 
				
			||||||
            ends = self.model.ops.xp.asarray(ends)
 | 
					            ends = self.model.ops.xp.asarray(ends)
 | 
				
			||||||
            start_scores = span_scores[:, :, 0]
 | 
					            start_scores = span_scores[:, :, 0]
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user