deuglify kb deserializer

This commit is contained in:
svlandeg 2019-07-03 15:00:42 +02:00
parent 8840d4b1b3
commit 668b17ea4a
2 changed files with 12 additions and 9 deletions

View File

@ -118,7 +118,7 @@ class Language(object):
"tagger": lambda nlp, **cfg: Tagger(nlp.vocab, **cfg), "tagger": lambda nlp, **cfg: Tagger(nlp.vocab, **cfg),
"parser": lambda nlp, **cfg: DependencyParser(nlp.vocab, **cfg), "parser": lambda nlp, **cfg: DependencyParser(nlp.vocab, **cfg),
"ner": lambda nlp, **cfg: EntityRecognizer(nlp.vocab, **cfg), "ner": lambda nlp, **cfg: EntityRecognizer(nlp.vocab, **cfg),
"entity_linker": lambda nlp, **cfg: EntityLinker(**cfg), "entity_linker": lambda nlp, **cfg: EntityLinker(nlp.vocab, **cfg),
"similarity": lambda nlp, **cfg: SimilarityHook(nlp.vocab, **cfg), "similarity": lambda nlp, **cfg: SimilarityHook(nlp.vocab, **cfg),
"textcat": lambda nlp, **cfg: TextCategorizer(nlp.vocab, **cfg), "textcat": lambda nlp, **cfg: TextCategorizer(nlp.vocab, **cfg),
"sentencizer": lambda nlp, **cfg: Sentencizer(**cfg), "sentencizer": lambda nlp, **cfg: Sentencizer(**cfg),
@ -811,13 +811,6 @@ class Language(object):
exclude = list(exclude) + ["vocab"] exclude = list(exclude) + ["vocab"]
util.from_disk(path, deserializers, exclude) util.from_disk(path, deserializers, exclude)
# download the KB for the entity linking component - requires the vocab
for pipe_name, pipe in self.pipeline:
if pipe_name == "entity_linker":
kb = KnowledgeBase(vocab=self.vocab, entity_vector_length=pipe.cfg["entity_width"])
kb.load_bulk(path / pipe_name / "kb")
pipe.set_kb(kb)
self._path = path self._path = path
return self return self

View File

@ -13,6 +13,7 @@ from thinc.misc import LayerNorm
from thinc.neural.util import to_categorical from thinc.neural.util import to_categorical
from thinc.neural.util import get_array_module from thinc.neural.util import get_array_module
from spacy.kb import KnowledgeBase
from ..cli.pretrain import get_cossim_loss from ..cli.pretrain import get_cossim_loss
from .functions import merge_subtokens from .functions import merge_subtokens
from ..tokens.doc cimport Doc from ..tokens.doc cimport Doc
@ -1079,7 +1080,8 @@ class EntityLinker(Pipe):
model = build_nel_encoder(embed_width=embed_width, hidden_width=hidden_width, ner_types=len(type_to_int), **cfg) model = build_nel_encoder(embed_width=embed_width, hidden_width=hidden_width, ner_types=len(type_to_int), **cfg)
return model return model
def __init__(self, **cfg): def __init__(self, vocab, **cfg):
self.vocab = vocab
self.model = True self.model = True
self.kb = None self.kb = None
self.cfg = dict(cfg) self.cfg = dict(cfg)
@ -1277,6 +1279,7 @@ class EntityLinker(Pipe):
def to_disk(self, path, exclude=tuple(), **kwargs): def to_disk(self, path, exclude=tuple(), **kwargs):
serialize = OrderedDict() serialize = OrderedDict()
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg) serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
serialize["vocab"] = lambda p: self.vocab.to_disk(p)
serialize["kb"] = lambda p: self.kb.dump(p) serialize["kb"] = lambda p: self.kb.dump(p)
if self.model not in (None, True, False): if self.model not in (None, True, False):
serialize["model"] = lambda p: p.open("wb").write(self.model.to_bytes()) serialize["model"] = lambda p: p.open("wb").write(self.model.to_bytes())
@ -1289,8 +1292,15 @@ class EntityLinker(Pipe):
self.model = self.Model(**self.cfg) self.model = self.Model(**self.cfg)
self.model.from_bytes(p.open("rb").read()) self.model.from_bytes(p.open("rb").read())
def load_kb(p):
kb = KnowledgeBase(vocab=self.vocab, entity_vector_length=self.cfg["entity_width"])
kb.load_bulk(p)
self.set_kb(kb)
deserialize = OrderedDict() deserialize = OrderedDict()
deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p)) deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p))
deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
deserialize["kb"] = load_kb
deserialize["model"] = load_model deserialize["model"] = load_model
exclude = util.get_serialization_exclude(deserialize, exclude, kwargs) exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
util.from_disk(path, deserialize, exclude) util.from_disk(path, deserialize, exclude)