Mirror of https://github.com/explosion/spaCy.git
	add backprop callback to spanpredictor
commit 2a1ad4c5d2
parent 3ba913109d
@@ -104,13 +104,13 @@ def convert_coref_scorer_inputs(
     # just use the first
     # TODO real batching
     X = X[0]
 
     word_features = xp2torch(X, requires_grad=is_train)
 
     def backprop(args: ArgsKwargs) -> List[Floats2d]:
         # convert to xp and wrap in list
         gradients = torch2xp(args.args[0])
         return [gradients]
 
     return ArgsKwargs(args=(word_features, ), kwargs={}), backprop
 
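For readers outside the codebase: convert_coref_scorer_inputs follows Thinc's PyTorch-wrapper convention, where an input converter returns the torch-side ArgsKwargs plus a callback that maps torch gradients back into xp (numpy/cupy) arrays. Below is a minimal, self-contained sketch of that contract; the toy torch.nn.Linear model is an assumption for illustration, while PyTorchWrapper, ArgsKwargs, xp2torch, and torch2xp are Thinc's actual helpers.

from typing import List

import torch
from thinc.api import ArgsKwargs, PyTorchWrapper, torch2xp, xp2torch
from thinc.types import Floats2d


def convert_inputs(model, X: List[Floats2d], is_train: bool):
    # forward: take the first doc's features and move them to torch
    X_torch = xp2torch(X[0], requires_grad=is_train)

    def backprop(args: ArgsKwargs) -> List[Floats2d]:
        # backward: torch gradient -> xp array, wrapped in a list so the
        # result mirrors the structure of the converter's input X
        return [torch2xp(args.args[0])]

    return ArgsKwargs(args=(X_torch,), kwargs={}), backprop


# Thinc invokes backprop() during the backward pass with the gradients of
# the torch-side arguments, in the order they were passed forward.
wrapped = PyTorchWrapper(torch.nn.Linear(64, 2), convert_inputs=convert_inputs)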
@@ -141,16 +141,22 @@ def convert_span_predictor_inputs(
 ):
     tok2vec, (sent_ids, head_ids) = X
     # Normally we shoudl use the input is_train, but for these two it's not relevant
 
+    def backprop(args: ArgsKwargs) -> List[Floats2d]:
+        # convert to xp and wrap in list
+        gradients = torch2xp(args.args[1])
+        return [[gradients], None]
+
+    word_features = xp2torch(tok2vec[0], requires_grad=is_train)
     sent_ids = xp2torch(sent_ids[0], requires_grad=False)
     if not head_ids[0].size:
         head_ids = torch.empty(size=(0,))
     else:
         head_ids = xp2torch(head_ids[0], requires_grad=False)
-    word_features = xp2torch(tok2vec[0], requires_grad=is_train)
 
     argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
     # TODO actually support backprop
-    return argskwargs, lambda dX: [[]]
+    return argskwargs, backprop
 
 
 # TODO This probably belongs in the component, not the model.
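This hunk is the substance of the commit: the placeholder lambda dX: [[]] becomes a real backprop callback. Both the indexing and the return shape follow from structure. The forward call passes (sent_ids, word_features, head_ids) to torch, so in the backward pass args.args[1] is the gradient of word_features, the only differentiable input; the callback's return value mirrors the converter's input X = (tok2vec, (sent_ids, head_ids)), hence a one-element list of arrays for tok2vec and None for the integer id pair. A small sketch with made-up shapes:

import numpy
from thinc.api import torch2xp, xp2torch

# Hypothetical single-doc inputs, only to make the structures concrete.
tok2vec = [numpy.zeros((10, 64), dtype="f")]   # differentiable word features
sent_ids = [numpy.zeros((10,), dtype="i")]     # integer ids: no gradient
head_ids = [numpy.arange(3, dtype="i")]
X = (tok2vec, (sent_ids, head_ids))

# A stand-in for the torch-side gradient of word_features (args.args[1]).
d_word_features = xp2torch(numpy.ones((10, 64), dtype="f"))

# The gradient handed back to Thinc mirrors X's structure.
dX = [[torch2xp(d_word_features)], None]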
@@ -247,7 +253,6 @@ def head_data_forward(model, docs, is_train):
                 heads.append(span[0].i)
         heads = model.ops.asarray2i(heads)
         head_ids.append(heads)
-
     # each of these is a list with one entry per doc
     # backprop is just a placeholder
     # TODO it would probably be better to have a list of tuples than two lists of arrays
@@ -584,7 +589,6 @@ class SpanPredictor(torch.nn.Module):
 
         scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'))
         scores[rows, cols] = res[padding_mask]
-
         # Make sure that start <= head <= end during inference
         if not self.training:
             valid_starts = torch.log((relative_positions >= 0).to(torch.float))
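A note on the valid_starts line in the surrounding context: torch.log of a boolean cast to float is 0.0 where the condition holds and -inf where it fails, so the result can simply be added to the scores to mask out invalid positions before a softmax or argmax. A quick illustration (values are made up):

import torch

relative_positions = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
valid_starts = torch.log((relative_positions >= 0).to(torch.float))
# valid_starts == tensor([-inf, -inf, 0., 0., 0.])

scores = torch.zeros(5)
masked = scores + valid_starts  # invalid positions forced to -inf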