mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
💫 Improve error message when model.from_bytes() dies (#4014)
* Improve error message when model.from_bytes() dies

  When Thinc's model.from_bytes() is called with a mismatched model, we often get a particularly ungraceful error, e.g. "AttributeError: FunctionLayer has no attribute G". This is because we're trying to load the parameters for something like a LayerNorm layer, and the model architecture has some other layer there instead. This is obviously terrible, especially since the error *type* is wrong. I've changed it to raise a ValueError. The error message is still probably a bit terse, but it's hard to be sure exactly what's gone wrong.

* Update spacy/pipeline/pipes.pyx
* Update spacy/pipeline/pipes.pyx
* Update spacy/pipeline/pipes.pyx
* Update spacy/syntax/nn_parser.pyx
* Update spacy/syntax/nn_parser.pyx
* Update spacy/pipeline/pipes.pyx

  Co-Authored-By: Matthew Honnibal <honnibal+gh@gmail.com>

* Update spacy/pipeline/pipes.pyx

  Co-Authored-By: Matthew Honnibal <honnibal+gh@gmail.com>

Co-authored-by: Ines Montani <ines@ines.io>
This commit is contained in:
parent
87fcf3141c
commit
73e095923f
|
@ -413,7 +413,8 @@ class Errors(object):
|
|||
"This is likely a bug in spaCy, so feel free to open an issue.")
|
||||
E148 = ("Expected {ents} KB identifiers but got {ids}. Make sure that each entity in `doc.ents` "
|
||||
"is assigned to a KB identifier.")
|
||||
|
||||
E149 = ("Error deserializing model. Check that the config used to create the "
|
||||
"component matches the model being loaded.")
|
||||
|
||||
@add_codes
|
||||
class TempErrors(object):
|
||||
|
|
|
@ -167,7 +167,10 @@ class Pipe(object):
|
|||
self.cfg["pretrained_vectors"] = self.vocab.vectors.name
|
||||
if self.model is True:
|
||||
self.model = self.Model(**self.cfg)
|
||||
self.model.from_bytes(b)
|
||||
try:
|
||||
self.model.from_bytes(b)
|
||||
except AttributeError:
|
||||
raise ValueError(Errors.E149)
|
||||
|
||||
deserialize = OrderedDict()
|
||||
deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
|
||||
|
@ -196,7 +199,10 @@ class Pipe(object):
|
|||
self.cfg["pretrained_vectors"] = self.vocab.vectors.name
|
||||
if self.model is True:
|
||||
self.model = self.Model(**self.cfg)
|
||||
self.model.from_bytes(p.open("rb").read())
|
||||
try:
|
||||
self.model.from_bytes(p.open("rb").read())
|
||||
except AttributeError:
|
||||
raise ValueError(Errors.E149)
|
||||
|
||||
deserialize = OrderedDict()
|
||||
deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p))
|
||||
|
@ -562,7 +568,10 @@ class Tagger(Pipe):
|
|||
"token_vector_width",
|
||||
self.cfg.get("token_vector_width", 96))
|
||||
self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg)
|
||||
self.model.from_bytes(b)
|
||||
try:
|
||||
self.model.from_bytes(b)
|
||||
except AttributeError:
|
||||
raise ValueError(Errors.E149)
|
||||
|
||||
def load_tag_map(b):
|
||||
tag_map = srsly.msgpack_loads(b)
|
||||
|
@ -600,7 +609,10 @@ class Tagger(Pipe):
|
|||
if self.model is True:
|
||||
self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg)
|
||||
with p.open("rb") as file_:
|
||||
self.model.from_bytes(file_.read())
|
||||
try:
|
||||
self.model.from_bytes(file_.read())
|
||||
except AttributeError:
|
||||
raise ValueError(Errors.E149)
|
||||
|
||||
def load_tag_map(p):
|
||||
tag_map = srsly.read_msgpack(p)
|
||||
|
@ -1315,9 +1327,12 @@ class EntityLinker(Pipe):
|
|||
|
||||
def from_disk(self, path, exclude=tuple(), **kwargs):
|
||||
def load_model(p):
|
||||
if self.model is True:
|
||||
if self.model is True:
|
||||
self.model = self.Model(**self.cfg)
|
||||
self.model.from_bytes(p.open("rb").read())
|
||||
try:
|
||||
self.model.from_bytes(p.open("rb").read())
|
||||
except AttributeError:
|
||||
raise ValueError(Errors.E149)
|
||||
|
||||
def load_kb(p):
|
||||
kb = KnowledgeBase(vocab=self.vocab, entity_vector_length=self.cfg["entity_width"])
|
||||
|
|
|
@ -631,7 +631,10 @@ cdef class Parser:
|
|||
cfg = {}
|
||||
with (path / 'model').open('rb') as file_:
|
||||
bytes_data = file_.read()
|
||||
self.model.from_bytes(bytes_data)
|
||||
try:
|
||||
self.model.from_bytes(bytes_data)
|
||||
except AttributeError:
|
||||
raise ValueError(Errors.E149)
|
||||
self.cfg.update(cfg)
|
||||
return self
|
||||
|
||||
|
@ -663,6 +666,9 @@ cdef class Parser:
|
|||
else:
|
||||
cfg = {}
|
||||
if 'model' in msg:
|
||||
self.model.from_bytes(msg['model'])
|
||||
try:
|
||||
self.model.from_bytes(msg['model'])
|
||||
except AttributeError:
|
||||
raise ValueError(Errors.E149)
|
||||
self.cfg.update(cfg)
|
||||
return self
|
||||
|
|
Loading…
Reference in New Issue
Block a user