mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	Fix parser for GPU
This commit is contained in:
		
							parent
							
								
									260707a4c3
								
							
						
					
					
						commit
						7431e9c87f
					
				| 
						 | 
					@ -19,12 +19,10 @@ cdef struct WeightsC:
 | 
				
			||||||
    const float* feat_bias
 | 
					    const float* feat_bias
 | 
				
			||||||
    const float* hidden_bias
 | 
					    const float* hidden_bias
 | 
				
			||||||
    const float* hidden_weights
 | 
					    const float* hidden_weights
 | 
				
			||||||
    const float* vectors
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef struct ActivationsC:
 | 
					cdef struct ActivationsC:
 | 
				
			||||||
    int* token_ids
 | 
					    int* token_ids
 | 
				
			||||||
    float* vectors
 | 
					 | 
				
			||||||
    float* unmaxed
 | 
					    float* unmaxed
 | 
				
			||||||
    float* scores
 | 
					    float* scores
 | 
				
			||||||
    float* hiddens
 | 
					    float* hiddens
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -50,8 +50,6 @@ cdef WeightsC get_c_weights(model) except *:
 | 
				
			||||||
    cdef np.ndarray vec2scores_b = model.vec2scores.b
 | 
					    cdef np.ndarray vec2scores_b = model.vec2scores.b
 | 
				
			||||||
    output.hidden_weights = <const float*>vec2scores_W.data
 | 
					    output.hidden_weights = <const float*>vec2scores_W.data
 | 
				
			||||||
    output.hidden_bias = <const float*>vec2scores_b.data
 | 
					    output.hidden_bias = <const float*>vec2scores_b.data
 | 
				
			||||||
    cdef np.ndarray tokvecs = model.tokvecs
 | 
					 | 
				
			||||||
    output.vectors = <float*>tokvecs.data
 | 
					 | 
				
			||||||
    return output
 | 
					    return output
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -72,7 +70,6 @@ cdef void resize_activations(ActivationsC* A, SizesC n) nogil:
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
    if A._max_size == 0:
 | 
					    if A._max_size == 0:
 | 
				
			||||||
        A.token_ids = <int*>calloc(n.states * n.feats, sizeof(A.token_ids[0]))
 | 
					        A.token_ids = <int*>calloc(n.states * n.feats, sizeof(A.token_ids[0]))
 | 
				
			||||||
        A.vectors = <float*>calloc(n.states * n.embed_width, sizeof(A.vectors[0]))
 | 
					 | 
				
			||||||
        A.scores = <float*>calloc(n.states * n.classes, sizeof(A.scores[0]))
 | 
					        A.scores = <float*>calloc(n.states * n.classes, sizeof(A.scores[0]))
 | 
				
			||||||
        A.unmaxed = <float*>calloc(n.states * n.hiddens * n.pieces, sizeof(A.unmaxed[0]))
 | 
					        A.unmaxed = <float*>calloc(n.states * n.hiddens * n.pieces, sizeof(A.unmaxed[0]))
 | 
				
			||||||
        A.hiddens = <float*>calloc(n.states * n.hiddens, sizeof(A.hiddens[0]))
 | 
					        A.hiddens = <float*>calloc(n.states * n.hiddens, sizeof(A.hiddens[0]))
 | 
				
			||||||
| 
						 | 
					@ -81,8 +78,6 @@ cdef void resize_activations(ActivationsC* A, SizesC n) nogil:
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        A.token_ids = <int*>realloc(A.token_ids,
 | 
					        A.token_ids = <int*>realloc(A.token_ids,
 | 
				
			||||||
            n.states * n.feats * sizeof(A.token_ids[0]))
 | 
					            n.states * n.feats * sizeof(A.token_ids[0]))
 | 
				
			||||||
        A.vectors = <float*>realloc(A.vectors,
 | 
					 | 
				
			||||||
            n.states * n.embed_width * sizeof(A.vectors[0]))
 | 
					 | 
				
			||||||
        A.scores = <float*>realloc(A.scores,
 | 
					        A.scores = <float*>realloc(A.scores,
 | 
				
			||||||
            n.states * n.classes * sizeof(A.scores[0]))
 | 
					            n.states * n.classes * sizeof(A.scores[0]))
 | 
				
			||||||
        A.unmaxed = <float*>realloc(A.unmaxed,
 | 
					        A.unmaxed = <float*>realloc(A.unmaxed,
 | 
				
			||||||
| 
						 | 
					@ -242,7 +237,7 @@ class ParserStepModel(Model):
 | 
				
			||||||
    def begin_update(self, states, drop=0.):
 | 
					    def begin_update(self, states, drop=0.):
 | 
				
			||||||
        token_ids = self.get_token_ids(states)
 | 
					        token_ids = self.get_token_ids(states)
 | 
				
			||||||
        vector, get_d_tokvecs = self.state2vec.begin_update(token_ids, drop=0.0)
 | 
					        vector, get_d_tokvecs = self.state2vec.begin_update(token_ids, drop=0.0)
 | 
				
			||||||
        mask = self.ops.get_dropout_mask(vector.shape, drop)
 | 
					        mask = self.vec2scores.ops.get_dropout_mask(vector.shape, drop)
 | 
				
			||||||
        if mask is not None:
 | 
					        if mask is not None:
 | 
				
			||||||
            vector *= mask
 | 
					            vector *= mask
 | 
				
			||||||
        scores, get_d_vector = self.vec2scores.begin_update(vector, drop=drop)
 | 
					        scores, get_d_vector = self.vec2scores.begin_update(vector, drop=drop)
 | 
				
			||||||
| 
						 | 
					@ -251,7 +246,7 @@ class ParserStepModel(Model):
 | 
				
			||||||
            d_vector = get_d_vector(d_scores, sgd=sgd)
 | 
					            d_vector = get_d_vector(d_scores, sgd=sgd)
 | 
				
			||||||
            if mask is not None:
 | 
					            if mask is not None:
 | 
				
			||||||
                d_vector *= mask
 | 
					                d_vector *= mask
 | 
				
			||||||
            if isinstance(self.ops, CupyOps) \
 | 
					            if isinstance(self.state2vec.ops, CupyOps) \
 | 
				
			||||||
            and not isinstance(token_ids, self.state2vec.ops.xp.ndarray):
 | 
					            and not isinstance(token_ids, self.state2vec.ops.xp.ndarray):
 | 
				
			||||||
                # Move token_ids and d_vector to GPU, asynchronously
 | 
					                # Move token_ids and d_vector to GPU, asynchronously
 | 
				
			||||||
                self.backprops.append((
 | 
					                self.backprops.append((
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user