mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 07:57:35 +03:00 
			
		
		
		
	💫 Improve error message when model.from_bytes() dies (#4014)
* Improve error message when model.from_bytes() dies

  When Thinc's model.from_bytes() is called with a mismatched model, we often get a particularly ungraceful error, e.g. "AttributeError: FunctionLayer has no attribute G". This is because we're trying to load the parameters for something like a LayerNorm layer, and the model architecture has some other layer there instead. This is obviously terrible, especially since the error *type* is wrong. I've changed it to raise a ValueError. The error message is still probably a bit terse, but it's hard to be sure exactly what's gone wrong.

* Update spacy/pipeline/pipes.pyx
* Update spacy/pipeline/pipes.pyx
* Update spacy/pipeline/pipes.pyx
* Update spacy/syntax/nn_parser.pyx
* Update spacy/syntax/nn_parser.pyx
* Update spacy/pipeline/pipes.pyx

  Co-Authored-By: Matthew Honnibal <honnibal+gh@gmail.com>

* Update spacy/pipeline/pipes.pyx

  Co-Authored-By: Matthew Honnibal <honnibal+gh@gmail.com>

Co-authored-by: Ines Montani <ines@ines.io>
This commit is contained in:
		
							parent
							
								
									87fcf3141c
								
							
						
					
					
						commit
						73e095923f
					
				|  | @ -413,7 +413,8 @@ class Errors(object): | ||||||
|             "This is likely a bug in spaCy, so feel free to open an issue.") |             "This is likely a bug in spaCy, so feel free to open an issue.") | ||||||
|     E148 = ("Expected {ents} KB identifiers but got {ids}. Make sure that each entity in `doc.ents` " |     E148 = ("Expected {ents} KB identifiers but got {ids}. Make sure that each entity in `doc.ents` " | ||||||
|             "is assigned to a KB identifier.") |             "is assigned to a KB identifier.") | ||||||
| 
 |     E149 = ("Error deserializing model. Check that the config used to create the " | ||||||
|  |             "component matches the model being loaded.") | ||||||
| 
 | 
 | ||||||
| @add_codes | @add_codes | ||||||
| class TempErrors(object): | class TempErrors(object): | ||||||
|  |  | ||||||
|  | @ -167,7 +167,10 @@ class Pipe(object): | ||||||
|                 self.cfg["pretrained_vectors"] = self.vocab.vectors.name |                 self.cfg["pretrained_vectors"] = self.vocab.vectors.name | ||||||
|             if self.model is True: |             if self.model is True: | ||||||
|                 self.model = self.Model(**self.cfg) |                 self.model = self.Model(**self.cfg) | ||||||
|             self.model.from_bytes(b) |             try: | ||||||
|  |                 self.model.from_bytes(b) | ||||||
|  |             except AttributeError: | ||||||
|  |                 raise ValueError(Errors.E149) | ||||||
| 
 | 
 | ||||||
|         deserialize = OrderedDict() |         deserialize = OrderedDict() | ||||||
|         deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b)) |         deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b)) | ||||||
|  | @ -196,7 +199,10 @@ class Pipe(object): | ||||||
|                 self.cfg["pretrained_vectors"] = self.vocab.vectors.name |                 self.cfg["pretrained_vectors"] = self.vocab.vectors.name | ||||||
|             if self.model is True: |             if self.model is True: | ||||||
|                 self.model = self.Model(**self.cfg) |                 self.model = self.Model(**self.cfg) | ||||||
|             self.model.from_bytes(p.open("rb").read()) |             try: | ||||||
|  |                 self.model.from_bytes(p.open("rb").read()) | ||||||
|  |             except AttributeError: | ||||||
|  |                 raise ValueError(Errors.E149) | ||||||
| 
 | 
 | ||||||
|         deserialize = OrderedDict() |         deserialize = OrderedDict() | ||||||
|         deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p)) |         deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p)) | ||||||
|  | @ -562,7 +568,10 @@ class Tagger(Pipe): | ||||||
|                     "token_vector_width", |                     "token_vector_width", | ||||||
|                     self.cfg.get("token_vector_width", 96)) |                     self.cfg.get("token_vector_width", 96)) | ||||||
|                 self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg) |                 self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg) | ||||||
|             self.model.from_bytes(b) |             try: | ||||||
|  |                 self.model.from_bytes(b) | ||||||
|  |             except AttributeError: | ||||||
|  |                 raise ValueError(Errors.E149) | ||||||
| 
 | 
 | ||||||
|         def load_tag_map(b): |         def load_tag_map(b): | ||||||
|             tag_map = srsly.msgpack_loads(b) |             tag_map = srsly.msgpack_loads(b) | ||||||
|  | @ -600,7 +609,10 @@ class Tagger(Pipe): | ||||||
|             if self.model is True: |             if self.model is True: | ||||||
|                 self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg) |                 self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg) | ||||||
|             with p.open("rb") as file_: |             with p.open("rb") as file_: | ||||||
|                 self.model.from_bytes(file_.read()) |                 try: | ||||||
|  |                     self.model.from_bytes(file_.read()) | ||||||
|  |                 except AttributeError: | ||||||
|  |                     raise ValueError(Errors.E149) | ||||||
| 
 | 
 | ||||||
|         def load_tag_map(p): |         def load_tag_map(p): | ||||||
|             tag_map = srsly.read_msgpack(p) |             tag_map = srsly.read_msgpack(p) | ||||||
|  | @ -1315,9 +1327,12 @@ class EntityLinker(Pipe): | ||||||
| 
 | 
 | ||||||
|     def from_disk(self, path, exclude=tuple(), **kwargs): |     def from_disk(self, path, exclude=tuple(), **kwargs): | ||||||
|         def load_model(p): |         def load_model(p): | ||||||
|              if self.model is True: |             if self.model is True: | ||||||
|                 self.model = self.Model(**self.cfg) |                 self.model = self.Model(**self.cfg) | ||||||
|              self.model.from_bytes(p.open("rb").read()) |             try:  | ||||||
|  |                 self.model.from_bytes(p.open("rb").read()) | ||||||
|  |             except AttributeError: | ||||||
|  |                 raise ValueError(Errors.E149) | ||||||
| 
 | 
 | ||||||
|         def load_kb(p): |         def load_kb(p): | ||||||
|             kb = KnowledgeBase(vocab=self.vocab, entity_vector_length=self.cfg["entity_width"]) |             kb = KnowledgeBase(vocab=self.vocab, entity_vector_length=self.cfg["entity_width"]) | ||||||
|  |  | ||||||
|  | @ -631,7 +631,10 @@ cdef class Parser: | ||||||
|                 cfg = {} |                 cfg = {} | ||||||
|             with (path / 'model').open('rb') as file_: |             with (path / 'model').open('rb') as file_: | ||||||
|                 bytes_data = file_.read() |                 bytes_data = file_.read() | ||||||
|             self.model.from_bytes(bytes_data) |             try: | ||||||
|  |                 self.model.from_bytes(bytes_data) | ||||||
|  |             except AttributeError: | ||||||
|  |                 raise ValueError(Errors.E149) | ||||||
|             self.cfg.update(cfg) |             self.cfg.update(cfg) | ||||||
|         return self |         return self | ||||||
| 
 | 
 | ||||||
|  | @ -663,6 +666,9 @@ cdef class Parser: | ||||||
|             else: |             else: | ||||||
|                 cfg = {} |                 cfg = {} | ||||||
|             if 'model' in msg: |             if 'model' in msg: | ||||||
|                 self.model.from_bytes(msg['model']) |                 try: | ||||||
|  |                     self.model.from_bytes(msg['model']) | ||||||
|  |                 except AttributeError: | ||||||
|  |                     raise ValueError(Errors.E149) | ||||||
|             self.cfg.update(cfg) |             self.cfg.update(cfg) | ||||||
|         return self |         return self | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user