mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Fix loading models with pretrained vectors
This commit is contained in:
		
							parent
							
								
									96b612873b
								
							
						
					
					
						commit
						81f4005f3d
					
				| 
						 | 
				
			
			@ -636,11 +636,11 @@ class Language(object):
 | 
			
		|||
        """
 | 
			
		||||
        path = util.ensure_path(path)
 | 
			
		||||
        deserializers = OrderedDict((
 | 
			
		||||
            ('vocab', lambda p: self.vocab.from_disk(p)),
 | 
			
		||||
            ('meta.json', lambda p: self.meta.update(util.read_json(p))),
 | 
			
		||||
            ('vocab', lambda p: (
 | 
			
		||||
                self.vocab.from_disk(p) and _fix_pretrained_vectors_name(self))),
 | 
			
		||||
            ('tokenizer', lambda p: self.tokenizer.from_disk(p, vocab=False)),
 | 
			
		||||
            ('meta.json', lambda p: self.meta.update(util.read_json(p)))
 | 
			
		||||
        ))
 | 
			
		||||
        _fix_pretrained_vectors_name(self)
 | 
			
		||||
        for name, proc in self.pipeline:
 | 
			
		||||
            if name in disable:
 | 
			
		||||
                continue
 | 
			
		||||
| 
						 | 
				
			
			@ -682,11 +682,11 @@ class Language(object):
 | 
			
		|||
        RETURNS (Language): The `Language` object.
 | 
			
		||||
        """
 | 
			
		||||
        deserializers = OrderedDict((
 | 
			
		||||
            ('vocab', lambda b: self.vocab.from_bytes(b)),
 | 
			
		||||
            ('meta', lambda b: self.meta.update(ujson.loads(b))),
 | 
			
		||||
            ('vocab', lambda b: (
 | 
			
		||||
                self.vocab.from_bytes(b) and _fix_pretrained_vectors_name(self))),
 | 
			
		||||
            ('tokenizer', lambda b: self.tokenizer.from_bytes(b, vocab=False)),
 | 
			
		||||
            ('meta', lambda b: self.meta.update(ujson.loads(b)))
 | 
			
		||||
        ))
 | 
			
		||||
        _fix_pretrained_vectors_name(self)
 | 
			
		||||
        for i, (name, proc) in enumerate(self.pipeline):
 | 
			
		||||
            if name in disable:
 | 
			
		||||
                continue
 | 
			
		||||
| 
						 | 
				
			
			@ -708,12 +708,12 @@ def _fix_pretrained_vectors_name(nlp):
 | 
			
		|||
        nlp.vocab.vectors.name = vectors_name
 | 
			
		||||
    else:
 | 
			
		||||
        raise ValueError(Errors.E092)
 | 
			
		||||
    link_vectors_to_models(nlp.vocab)
 | 
			
		||||
    for name, proc in nlp.pipeline:
 | 
			
		||||
        if not hasattr(proc, 'cfg'):
 | 
			
		||||
            continue
 | 
			
		||||
        if proc.cfg.get('pretrained_dims'):
 | 
			
		||||
            assert nlp.vocab.vectors.name
 | 
			
		||||
            proc.cfg['pretrained_vectors'] = nlp.vocab.vectors.name
 | 
			
		||||
        proc.cfg.setdefault('deprecation_fixes', {})
 | 
			
		||||
        proc.cfg['deprecation_fixes']['vectors_name'] = nlp.vocab.vectors.name
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DisabledPipes(list):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user