commit 668b17ea4a
parent 8840d4b1b3

    deuglify kb deserializer
@@ -118,7 +118,7 @@ class Language(object):
         "tagger": lambda nlp, **cfg: Tagger(nlp.vocab, **cfg),
         "parser": lambda nlp, **cfg: DependencyParser(nlp.vocab, **cfg),
         "ner": lambda nlp, **cfg: EntityRecognizer(nlp.vocab, **cfg),
-        "entity_linker": lambda nlp, **cfg: EntityLinker(**cfg),
+        "entity_linker": lambda nlp, **cfg: EntityLinker(nlp.vocab, **cfg),
         "similarity": lambda nlp, **cfg: SimilarityHook(nlp.vocab, **cfg),
         "textcat": lambda nlp, **cfg: TextCategorizer(nlp.vocab, **cfg),
         "sentencizer": lambda nlp, **cfg: Sentencizer(**cfg),
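This hunk brings the `entity_linker` factory in line with the other components: it now receives the shared `nlp.vocab`. A minimal usage sketch of what that implies, assuming the v2-style pipeline API at this revision:

# Sketch, assuming the v2-style API at this revision: the factory now
# wires the pipeline's shared vocab into the component for you.
import spacy

nlp = spacy.blank("en")
entity_linker = nlp.create_pipe("entity_linker")  # factory calls EntityLinker(nlp.vocab, **cfg)
assert entity_linker.vocab is nlp.vocab           # component shares the pipeline vocab
nlp.add_pipe(entity_linker)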
@@ -811,13 +811,6 @@ class Language(object):
         exclude = list(exclude) + ["vocab"]
         util.from_disk(path, deserializers, exclude)
-
-        # download the KB for the entity linking component - requires the vocab
-        for pipe_name, pipe in self.pipeline:
-            if pipe_name == "entity_linker":
-                kb = KnowledgeBase(vocab=self.vocab, entity_vector_length=pipe.cfg["entity_width"])
-                kb.load_bulk(path / pipe_name / "kb")
-                pipe.set_kb(kb)
         self._path = path
         return self
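With the component owning its KB deserialization, `Language.from_disk` no longer needs this special case; the generic `util.from_disk` dispatch covers it. Roughly, that dispatch maps each named deserializer to a subpath, as in this simplified sketch (an approximation, not the exact library code):

# Simplified sketch of the dispatch pattern util.from_disk implements:
# every key in the readers dict names a subpath, and the matching
# callable is invoked with that subpath.
from pathlib import Path

def from_disk_sketch(path, readers, exclude):
    path = Path(path)
    for key, reader in readers.items():
        if key not in exclude:
            reader(path / key)
    return path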
@@ -13,6 +13,7 @@ from thinc.misc import LayerNorm
 from thinc.neural.util import to_categorical
 from thinc.neural.util import get_array_module
 
+from spacy.kb import KnowledgeBase
 from ..cli.pretrain import get_cossim_loss
 from .functions import merge_subtokens
 from ..tokens.doc cimport Doc
@@ -1079,7 +1080,8 @@ class EntityLinker(Pipe):
         model = build_nel_encoder(embed_width=embed_width, hidden_width=hidden_width, ner_types=len(type_to_int), **cfg)
         return model
 
-    def __init__(self, **cfg):
+    def __init__(self, vocab, **cfg):
+        self.vocab = vocab
         self.model = True
         self.kb = None
         self.cfg = dict(cfg)
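Passing the vocab at construction time is what lets the deserializer below rebuild the KB without help from `Language`: a `KnowledgeBase` is bound to a vocab, and the component now always has one. A hedged construction sketch; `entity_width=64` is an arbitrary illustrative value:

# Sketch of the new constructor contract; entity_width=64 is an
# arbitrary example value, not taken from this diff.
from spacy.vocab import Vocab
from spacy.kb import KnowledgeBase
from spacy.pipeline import EntityLinker

vocab = Vocab()
linker = EntityLinker(vocab, entity_width=64)
kb = KnowledgeBase(vocab=vocab, entity_vector_length=64)
linker.set_kb(kb)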
@@ -1277,6 +1279,7 @@ class EntityLinker(Pipe):
     def to_disk(self, path, exclude=tuple(), **kwargs):
         serialize = OrderedDict()
         serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
+        serialize["vocab"] = lambda p: self.vocab.to_disk(p)
         serialize["kb"] = lambda p: self.kb.dump(p)
         if self.model not in (None, True, False):
             serialize["model"] = lambda p: p.open("wb").write(self.model.to_bytes())
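Serializing the vocab alongside `cfg`, `kb`, and `model` makes the component's directory self-contained. Based on the serializer keys (the layout is inferred, not shown in this diff), the pipe's directory would look roughly like:

entity_linker/
    cfg      # JSON config, including entity_width
    vocab/   # serialized Vocab the KB is bound to
    kb       # KnowledgeBase dump
    model    # thinc model bytes (only written if a real model is set)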
@@ -1289,8 +1292,15 @@
             self.model = self.Model(**self.cfg)
             self.model.from_bytes(p.open("rb").read())
 
+        def load_kb(p):
+            kb = KnowledgeBase(vocab=self.vocab, entity_vector_length=self.cfg["entity_width"])
+            kb.load_bulk(p)
+            self.set_kb(kb)
+
         deserialize = OrderedDict()
         deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p))
+        deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
+        deserialize["kb"] = load_kb
         deserialize["model"] = load_model
         exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
         util.from_disk(path, deserialize, exclude)
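Taken together, the save/load round trip now lives entirely inside `EntityLinker`: `to_disk` writes the cfg, vocab, and KB, and `from_disk` restores the cfg first, then the vocab, then rebuilds the KB from `cfg["entity_width"]`. A hedged end-to-end sketch, with illustrative values and the model step excluded, not verified against this exact revision:

# Hedged round-trip sketch; illustrative values, model excluded since
# no trained model is attached. load_kb rebuilds the KB on load.
from spacy.vocab import Vocab
from spacy.kb import KnowledgeBase
from spacy.pipeline import EntityLinker

vocab = Vocab()
linker = EntityLinker(vocab, entity_width=64)
linker.set_kb(KnowledgeBase(vocab=vocab, entity_vector_length=64))
linker.to_disk("/tmp/entity_linker")

restored = EntityLinker(Vocab())
restored.from_disk("/tmp/entity_linker", exclude=["model"])
assert restored.kb.entity_vector_length == 64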