ensure the lang of vocab and nlp stay consistent (#4057)

* ensure the language of vocab and nlp stay consistent across serialization

* equality with ==
Sofie Van Landeghem 2019-08-01 17:13:01 +02:00 committed by Matthew Honnibal
parent a83c0add2e
commit f7d950de6d
3 changed files with 71 additions and 11 deletions

spacy/errors.py

@@ -415,6 +415,8 @@ class Errors(object):
             "is assigned to a KB identifier.")
     E149 = ("Error deserializing model. Check that the config used to create the "
             "component matches the model being loaded.")
+    E150 = ("The language of the `nlp` object and the `vocab` should be the same, "
+            "but found '{nlp}' and '{vocab}' respectively.")
 
 @add_codes
 class TempErrors(object):
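For reference, a minimal sketch of how the new template renders once formatted. The "en"/"de" values are hypothetical; the real ones come from nlp.lang and vocab.lang, and the @add_codes decorator prefixes the message with its code:

```python
from spacy.errors import Errors

msg = Errors.E150.format(nlp="en", vocab="de")
# -> "[E150] The language of the `nlp` object and the `vocab` should be the
#     same, but found 'en' and 'de' respectively."
```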

spacy/language.py

@@ -14,7 +14,8 @@ import srsly
 from .tokenizer import Tokenizer
 from .vocab import Vocab
 from .lemmatizer import Lemmatizer
-from .pipeline import DependencyParser, Tensorizer, Tagger, EntityRecognizer, EntityLinker
+from .pipeline import DependencyParser, Tagger
+from .pipeline import Tensorizer, EntityRecognizer, EntityLinker
 from .pipeline import SimilarityHook, TextCategorizer, Sentencizer
 from .pipeline import merge_noun_chunks, merge_entities, merge_subtokens
 from .pipeline import EntityRuler
@@ -158,6 +159,9 @@ class Language(object):
             vocab = factory(self, **meta.get("vocab", {}))
             if vocab.vectors.name is None:
                 vocab.vectors.name = meta.get("vectors", {}).get("name")
+        else:
+            if (self.lang and vocab.lang) and (self.lang != vocab.lang):
+                raise ValueError(Errors.E150.format(nlp=self.lang, vocab=vocab.lang))
         self.vocab = vocab
         if make_doc is True:
             factory = self.Defaults.create_tokenizer
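A hedged sketch of the new guard in action: constructing a pipeline with a vocab whose language disagrees now fails fast rather than silently diverging (language classes and attributes as in spaCy v2.1.x):

```python
from spacy.lang.de import German
from spacy.lang.en import English

de_vocab = German().vocab           # de_vocab.lang == "de"
try:
    nlp = English(vocab=de_vocab)   # English.lang == "en" -> mismatch
except ValueError as err:
    print(err)                      # E150, formatted with 'en' and 'de'
```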
@@ -173,7 +177,10 @@ class Language(object):
     @property
     def meta(self):
-        self._meta.setdefault("lang", self.vocab.lang)
+        if self.vocab.lang:
+            self._meta.setdefault("lang", self.vocab.lang)
+        else:
+            self._meta.setdefault("lang", self.lang)
         self._meta.setdefault("name", "model")
         self._meta.setdefault("version", "0.0.0")
         self._meta.setdefault("spacy_version", ">={}".format(about.__version__))
@@ -618,7 +625,9 @@ class Language(object):
         if component_cfg is None:
             component_cfg = {}
         docs, golds = zip(*docs_golds)
-        docs = [self.make_doc(doc) if isinstance(doc, basestring_) else doc for doc in docs]
+        docs = [
+            self.make_doc(doc) if isinstance(doc, basestring_) else doc for doc in docs
+        ]
         golds = list(golds)
         for name, pipe in self.pipeline:
             kwargs = component_cfg.get(name, {})
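This hunk is only a line-length reflow, but the line it touches documents useful behavior: evaluate() accepts raw strings in place of Doc objects and coerces them with make_doc(). A sketch, with the GoldParse usage being my assumption of the spaCy v2.1.x API:

```python
from spacy.gold import GoldParse
from spacy.lang.en import English

nlp = English()
doc = nlp.make_doc("I like London.")
gold = GoldParse(doc, entities=["O", "O", "U-LOC", "O"])  # one BILUO tag per token
scorer = nlp.evaluate([("I like London.", gold)])         # raw text instead of a Doc
```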
@@ -769,8 +778,12 @@ class Language(object):
         exclude = disable
         path = util.ensure_path(path)
         serializers = OrderedDict()
-        serializers["tokenizer"] = lambda p: self.tokenizer.to_disk(p, exclude=["vocab"])
-        serializers["meta.json"] = lambda p: p.open("w").write(srsly.json_dumps(self.meta))
+        serializers["tokenizer"] = lambda p: self.tokenizer.to_disk(
+            p, exclude=["vocab"]
+        )
+        serializers["meta.json"] = lambda p: p.open("w").write(
+            srsly.json_dumps(self.meta)
+        )
         for name, proc in self.pipeline:
             if not hasattr(proc, "name"):
                 continue
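Behavior is unchanged by the reflow; for orientation, a usage sketch of what the serializers write (layout as I understand it for spaCy v2.1.x, path purely illustrative):

```python
from spacy.lang.en import English

nlp = English()
nlp.to_disk("/tmp/en_model")
# Writes /tmp/en_model/meta.json, a single `tokenizer` file, a vocab/
# directory, and one directory per pipeline component that defines to_disk.
```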
@@ -799,14 +812,20 @@ class Language(object):
         path = util.ensure_path(path)
         deserializers = OrderedDict()
         deserializers["meta.json"] = lambda p: self.meta.update(srsly.read_json(p))
-        deserializers["vocab"] = lambda p: self.vocab.from_disk(p) and _fix_pretrained_vectors_name(self)
-        deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(p, exclude=["vocab"])
+        deserializers["vocab"] = lambda p: self.vocab.from_disk(
+            p
+        ) and _fix_pretrained_vectors_name(self)
+        deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(
+            p, exclude=["vocab"]
+        )
         for name, proc in self.pipeline:
             if name in exclude:
                 continue
             if not hasattr(proc, "from_disk"):
                 continue
-            deserializers[name] = lambda p, proc=proc: proc.from_disk(p, exclude=["vocab"])
+            deserializers[name] = lambda p, proc=proc: proc.from_disk(
+                p, exclude=["vocab"]
+            )
         if not (path / "vocab").exists() and "vocab" not in exclude:
             # Convert to list here in case exclude is (default) tuple
             exclude = list(exclude) + ["vocab"]
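Reading it back mirrors the order above: the vocab is deserialized first (then _fix_pretrained_vectors_name runs), and every other deserializer gets exclude=["vocab"] so the shared vocab is not loaded twice. A round-trip sketch under spaCy v2.1.x assumptions:

```python
from spacy.lang.en import English

nlp = English()
nlp.to_disk("/tmp/en_model")

nlp2 = English().from_disk("/tmp/en_model")
assert nlp2.vocab.lang == nlp2.lang == "en"  # languages stay consistent
```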
@@ -852,14 +871,20 @@ class Language(object):
         exclude = disable
         deserializers = OrderedDict()
         deserializers["meta.json"] = lambda b: self.meta.update(srsly.json_loads(b))
-        deserializers["vocab"] = lambda b: self.vocab.from_bytes(b) and _fix_pretrained_vectors_name(self)
-        deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(b, exclude=["vocab"])
+        deserializers["vocab"] = lambda b: self.vocab.from_bytes(
+            b
+        ) and _fix_pretrained_vectors_name(self)
+        deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(
+            b, exclude=["vocab"]
+        )
         for name, proc in self.pipeline:
             if name in exclude:
                 continue
             if not hasattr(proc, "from_bytes"):
                 continue
-            deserializers[name] = lambda b, proc=proc: proc.from_bytes(b, exclude=["vocab"])
+            deserializers[name] = lambda b, proc=proc: proc.from_bytes(
+                b, exclude=["vocab"]
+            )
         exclude = util.get_serialization_exclude(deserializers, exclude, kwargs)
         util.from_bytes(bytes_data, deserializers, exclude)
         return self
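The byte-level API follows the same pattern; a sketch of that round trip (spaCy v2.1.x):

```python
from spacy.lang.en import English

nlp = English()
data = nlp.to_bytes()

nlp2 = English().from_bytes(data)
assert nlp2.meta["lang"] == "en"
```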

spacy/tests/regression/test_issue4054.py (new file)

@@ -0,0 +1,33 @@
+# coding: utf8
+from __future__ import unicode_literals
+
+from spacy.vocab import Vocab
+import spacy
+from spacy.lang.en import English
+from spacy.tests.util import make_tempdir
+from spacy.util import ensure_path
+
+
+def test_issue4054(en_vocab):
+    """Test that a new blank model can be made with a vocab from file,
+    and that serialization does not drop the language at any point."""
+    nlp1 = English()
+    vocab1 = nlp1.vocab
+
+    with make_tempdir() as d:
+        vocab_dir = ensure_path(d / "vocab")
+        if not vocab_dir.exists():
+            vocab_dir.mkdir()
+        vocab1.to_disk(vocab_dir)
+
+        vocab2 = Vocab().from_disk(vocab_dir)
+        print("lang", vocab2.lang)
+        nlp2 = spacy.blank("en", vocab=vocab2)
+
+        nlp_dir = ensure_path(d / "nlp")
+        if not nlp_dir.exists():
+            nlp_dir.mkdir()
+        nlp2.to_disk(nlp_dir)
+        nlp3 = spacy.load(nlp_dir)
+
+        assert nlp3.lang == "en"
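A complementary negative case (a sketch, not part of the commit): pairing a saved-and-reloaded pipeline's counterpart vocab with a blank pipeline of a different language should now be rejected by the E150 guard. German().vocab is used because its lang attribute is set, unlike a bare Vocab():

```python
import pytest
import spacy
from spacy.lang.de import German


def test_lang_mismatch_raises():
    with pytest.raises(ValueError):
        spacy.blank("en", vocab=German().vocab)  # 'en' vs 'de' -> E150
```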