mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
ensure the lang of vocab and nlp stay consistent (#4057)
* ensure the language of vocab and nlp stay consistent across serialization * equality with =
This commit is contained in:
parent
a83c0add2e
commit
f7d950de6d
|
@ -415,6 +415,8 @@ class Errors(object):
|
|||
"is assigned to a KB identifier.")
|
||||
E149 = ("Error deserializing model. Check that the config used to create the "
|
||||
"component matches the model being loaded.")
|
||||
E150 = ("The language of the `nlp` object and the `vocab` should be the same, "
|
||||
"but found '{nlp}' and '{vocab}' respectively.")
|
||||
|
||||
@add_codes
|
||||
class TempErrors(object):
|
||||
|
|
|
@ -14,7 +14,8 @@ import srsly
|
|||
from .tokenizer import Tokenizer
|
||||
from .vocab import Vocab
|
||||
from .lemmatizer import Lemmatizer
|
||||
from .pipeline import DependencyParser, Tensorizer, Tagger, EntityRecognizer, EntityLinker
|
||||
from .pipeline import DependencyParser, Tagger
|
||||
from .pipeline import Tensorizer, EntityRecognizer, EntityLinker
|
||||
from .pipeline import SimilarityHook, TextCategorizer, Sentencizer
|
||||
from .pipeline import merge_noun_chunks, merge_entities, merge_subtokens
|
||||
from .pipeline import EntityRuler
|
||||
|
@ -158,6 +159,9 @@ class Language(object):
|
|||
vocab = factory(self, **meta.get("vocab", {}))
|
||||
if vocab.vectors.name is None:
|
||||
vocab.vectors.name = meta.get("vectors", {}).get("name")
|
||||
else:
|
||||
if (self.lang and vocab.lang) and (self.lang != vocab.lang):
|
||||
raise ValueError(Errors.E150.format(nlp=self.lang, vocab=vocab.lang))
|
||||
self.vocab = vocab
|
||||
if make_doc is True:
|
||||
factory = self.Defaults.create_tokenizer
|
||||
|
@ -173,7 +177,10 @@ class Language(object):
|
|||
|
||||
@property
|
||||
def meta(self):
|
||||
self._meta.setdefault("lang", self.vocab.lang)
|
||||
if self.vocab.lang:
|
||||
self._meta.setdefault("lang", self.vocab.lang)
|
||||
else:
|
||||
self._meta.setdefault("lang", self.lang)
|
||||
self._meta.setdefault("name", "model")
|
||||
self._meta.setdefault("version", "0.0.0")
|
||||
self._meta.setdefault("spacy_version", ">={}".format(about.__version__))
|
||||
|
@ -618,7 +625,9 @@ class Language(object):
|
|||
if component_cfg is None:
|
||||
component_cfg = {}
|
||||
docs, golds = zip(*docs_golds)
|
||||
docs = [self.make_doc(doc) if isinstance(doc, basestring_) else doc for doc in docs]
|
||||
docs = [
|
||||
self.make_doc(doc) if isinstance(doc, basestring_) else doc for doc in docs
|
||||
]
|
||||
golds = list(golds)
|
||||
for name, pipe in self.pipeline:
|
||||
kwargs = component_cfg.get(name, {})
|
||||
|
@ -769,8 +778,12 @@ class Language(object):
|
|||
exclude = disable
|
||||
path = util.ensure_path(path)
|
||||
serializers = OrderedDict()
|
||||
serializers["tokenizer"] = lambda p: self.tokenizer.to_disk(p, exclude=["vocab"])
|
||||
serializers["meta.json"] = lambda p: p.open("w").write(srsly.json_dumps(self.meta))
|
||||
serializers["tokenizer"] = lambda p: self.tokenizer.to_disk(
|
||||
p, exclude=["vocab"]
|
||||
)
|
||||
serializers["meta.json"] = lambda p: p.open("w").write(
|
||||
srsly.json_dumps(self.meta)
|
||||
)
|
||||
for name, proc in self.pipeline:
|
||||
if not hasattr(proc, "name"):
|
||||
continue
|
||||
|
@ -799,14 +812,20 @@ class Language(object):
|
|||
path = util.ensure_path(path)
|
||||
deserializers = OrderedDict()
|
||||
deserializers["meta.json"] = lambda p: self.meta.update(srsly.read_json(p))
|
||||
deserializers["vocab"] = lambda p: self.vocab.from_disk(p) and _fix_pretrained_vectors_name(self)
|
||||
deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(p, exclude=["vocab"])
|
||||
deserializers["vocab"] = lambda p: self.vocab.from_disk(
|
||||
p
|
||||
) and _fix_pretrained_vectors_name(self)
|
||||
deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(
|
||||
p, exclude=["vocab"]
|
||||
)
|
||||
for name, proc in self.pipeline:
|
||||
if name in exclude:
|
||||
continue
|
||||
if not hasattr(proc, "from_disk"):
|
||||
continue
|
||||
deserializers[name] = lambda p, proc=proc: proc.from_disk(p, exclude=["vocab"])
|
||||
deserializers[name] = lambda p, proc=proc: proc.from_disk(
|
||||
p, exclude=["vocab"]
|
||||
)
|
||||
if not (path / "vocab").exists() and "vocab" not in exclude:
|
||||
# Convert to list here in case exclude is (default) tuple
|
||||
exclude = list(exclude) + ["vocab"]
|
||||
|
@ -852,14 +871,20 @@ class Language(object):
|
|||
exclude = disable
|
||||
deserializers = OrderedDict()
|
||||
deserializers["meta.json"] = lambda b: self.meta.update(srsly.json_loads(b))
|
||||
deserializers["vocab"] = lambda b: self.vocab.from_bytes(b) and _fix_pretrained_vectors_name(self)
|
||||
deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(b, exclude=["vocab"])
|
||||
deserializers["vocab"] = lambda b: self.vocab.from_bytes(
|
||||
b
|
||||
) and _fix_pretrained_vectors_name(self)
|
||||
deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(
|
||||
b, exclude=["vocab"]
|
||||
)
|
||||
for name, proc in self.pipeline:
|
||||
if name in exclude:
|
||||
continue
|
||||
if not hasattr(proc, "from_bytes"):
|
||||
continue
|
||||
deserializers[name] = lambda b, proc=proc: proc.from_bytes(b, exclude=["vocab"])
|
||||
deserializers[name] = lambda b, proc=proc: proc.from_bytes(
|
||||
b, exclude=["vocab"]
|
||||
)
|
||||
exclude = util.get_serialization_exclude(deserializers, exclude, kwargs)
|
||||
util.from_bytes(bytes_data, deserializers, exclude)
|
||||
return self
|
||||
|
|
33
spacy/tests/regression/test_issue4054.py
Normal file
33
spacy/tests/regression/test_issue4054.py
Normal file
|
@ -0,0 +1,33 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from spacy.vocab import Vocab
|
||||
|
||||
import spacy
|
||||
from spacy.lang.en import English
|
||||
from spacy.tests.util import make_tempdir
|
||||
from spacy.util import ensure_path
|
||||
|
||||
|
||||
def test_issue4054(en_vocab):
|
||||
"""Test that a new blank model can be made with a vocab from file,
|
||||
and that serialization does not drop the language at any point."""
|
||||
nlp1 = English()
|
||||
vocab1 = nlp1.vocab
|
||||
|
||||
with make_tempdir() as d:
|
||||
vocab_dir = ensure_path(d / "vocab")
|
||||
if not vocab_dir.exists():
|
||||
vocab_dir.mkdir()
|
||||
vocab1.to_disk(vocab_dir)
|
||||
|
||||
vocab2 = Vocab().from_disk(vocab_dir)
|
||||
print("lang", vocab2.lang)
|
||||
nlp2 = spacy.blank("en", vocab=vocab2)
|
||||
|
||||
nlp_dir = ensure_path(d / "nlp")
|
||||
if not nlp_dir.exists():
|
||||
nlp_dir.mkdir()
|
||||
nlp2.to_disk(nlp_dir)
|
||||
nlp3 = spacy.load(nlp_dir)
|
||||
assert nlp3.lang == "en"
|
Loading…
Reference in New Issue
Block a user