Mirror of https://github.com/explosion/spaCy.git
ensure the lang of vocab and nlp stay consistent (#4057)
* ensure the language of vocab and nlp stay consistent across serialization
* equality with ==
This commit is contained in:
parent a83c0add2e
commit f7d950de6d
spacy/errors.py
@@ -415,6 +415,8 @@ class Errors(object):
             "is assigned to a KB identifier.")
     E149 = ("Error deserializing model. Check that the config used to create the "
             "component matches the model being loaded.")
+    E150 = ("The language of the `nlp` object and the `vocab` should be the same, "
+            "but found '{nlp}' and '{vocab}' respectively.")
 
 
 @add_codes
 class TempErrors(object):
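For illustration, a minimal sketch of how the new template reads once formatted (in spaCy's errors module, the @add_codes decorator prepends the error code on attribute access):

from spacy.errors import Errors

# Prints: "[E150] The language of the `nlp` object and the `vocab` should be
# the same, but found 'de' and 'en' respectively."
print(Errors.E150.format(nlp="de", vocab="en"))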
spacy/language.py
@@ -14,7 +14,8 @@ import srsly
 from .tokenizer import Tokenizer
 from .vocab import Vocab
 from .lemmatizer import Lemmatizer
-from .pipeline import DependencyParser, Tensorizer, Tagger, EntityRecognizer, EntityLinker
+from .pipeline import DependencyParser, Tagger
+from .pipeline import Tensorizer, EntityRecognizer, EntityLinker
 from .pipeline import SimilarityHook, TextCategorizer, Sentencizer
 from .pipeline import merge_noun_chunks, merge_entities, merge_subtokens
 from .pipeline import EntityRuler
@@ -158,6 +159,9 @@ class Language(object):
             vocab = factory(self, **meta.get("vocab", {}))
             if vocab.vectors.name is None:
                 vocab.vectors.name = meta.get("vectors", {}).get("name")
+        else:
+            if (self.lang and vocab.lang) and (self.lang != vocab.lang):
+                raise ValueError(Errors.E150.format(nlp=self.lang, vocab=vocab.lang))
         self.vocab = vocab
         if make_doc is True:
             factory = self.Defaults.create_tokenizer
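A minimal sketch of the mismatch this check now catches (assuming the v2.1-era constructors shown in this diff; German and English are stand-ins for any two languages):

from spacy.lang.en import English
from spacy.lang.de import German

en_vocab = English().vocab        # a vocab whose lang is "en"
try:
    nlp = German(vocab=en_vocab)  # nlp.lang is "de", so the languages disagree
except ValueError as err:
    print(err)                    # [E150] ... found 'de' and 'en' respectively.

Previously the mismatched vocab was accepted silently, so the wrong lang could end up in the serialized meta.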
@@ -173,7 +177,10 @@ class Language(object):
 
     @property
     def meta(self):
-        self._meta.setdefault("lang", self.vocab.lang)
+        if self.vocab.lang:
+            self._meta.setdefault("lang", self.vocab.lang)
+        else:
+            self._meta.setdefault("lang", self.lang)
         self._meta.setdefault("name", "model")
         self._meta.setdefault("version", "0.0.0")
         self._meta.setdefault("spacy_version", ">={}".format(about.__version__))
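A sketch of the fallback path, assuming a bare Vocab() whose lang is the empty string because it carries no LANG lexeme-attribute getter:

from spacy.vocab import Vocab
from spacy.lang.en import English

nlp = English(vocab=Vocab())  # a falsy vocab.lang passes the E150 check
print(nlp.vocab.lang)         # ""
print(nlp.meta["lang"])       # "en" -- falls back to nlp.lang instead of ""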
@@ -618,7 +625,9 @@ class Language(object):
         if component_cfg is None:
             component_cfg = {}
         docs, golds = zip(*docs_golds)
-        docs = [self.make_doc(doc) if isinstance(doc, basestring_) else doc for doc in docs]
+        docs = [
+            self.make_doc(doc) if isinstance(doc, basestring_) else doc for doc in docs
+        ]
         golds = list(golds)
         for name, pipe in self.pipeline:
             kwargs = component_cfg.get(name, {})
@@ -769,8 +778,12 @@ class Language(object):
             exclude = disable
         path = util.ensure_path(path)
         serializers = OrderedDict()
-        serializers["tokenizer"] = lambda p: self.tokenizer.to_disk(p, exclude=["vocab"])
-        serializers["meta.json"] = lambda p: p.open("w").write(srsly.json_dumps(self.meta))
+        serializers["tokenizer"] = lambda p: self.tokenizer.to_disk(
+            p, exclude=["vocab"]
+        )
+        serializers["meta.json"] = lambda p: p.open("w").write(
+            srsly.json_dumps(self.meta)
+        )
         for name, proc in self.pipeline:
             if not hasattr(proc, "name"):
                 continue
@@ -799,14 +812,20 @@ class Language(object):
         path = util.ensure_path(path)
         deserializers = OrderedDict()
         deserializers["meta.json"] = lambda p: self.meta.update(srsly.read_json(p))
-        deserializers["vocab"] = lambda p: self.vocab.from_disk(p) and _fix_pretrained_vectors_name(self)
-        deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(p, exclude=["vocab"])
+        deserializers["vocab"] = lambda p: self.vocab.from_disk(
+            p
+        ) and _fix_pretrained_vectors_name(self)
+        deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(
+            p, exclude=["vocab"]
+        )
         for name, proc in self.pipeline:
             if name in exclude:
                 continue
             if not hasattr(proc, "from_disk"):
                 continue
-            deserializers[name] = lambda p, proc=proc: proc.from_disk(p, exclude=["vocab"])
+            deserializers[name] = lambda p, proc=proc: proc.from_disk(
+                p, exclude=["vocab"]
+            )
         if not (path / "vocab").exists() and "vocab" not in exclude:
             # Convert to list here in case exclude is (default) tuple
             exclude = list(exclude) + ["vocab"]
@@ -852,14 +871,20 @@ class Language(object):
             exclude = disable
         deserializers = OrderedDict()
         deserializers["meta.json"] = lambda b: self.meta.update(srsly.json_loads(b))
-        deserializers["vocab"] = lambda b: self.vocab.from_bytes(b) and _fix_pretrained_vectors_name(self)
-        deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(b, exclude=["vocab"])
+        deserializers["vocab"] = lambda b: self.vocab.from_bytes(
+            b
+        ) and _fix_pretrained_vectors_name(self)
+        deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(
+            b, exclude=["vocab"]
+        )
         for name, proc in self.pipeline:
             if name in exclude:
                 continue
             if not hasattr(proc, "from_bytes"):
                 continue
-            deserializers[name] = lambda b, proc=proc: proc.from_bytes(b, exclude=["vocab"])
+            deserializers[name] = lambda b, proc=proc: proc.from_bytes(
+                b, exclude=["vocab"]
+            )
         exclude = util.get_serialization_exclude(deserializers, exclude, kwargs)
         util.from_bytes(bytes_data, deserializers, exclude)
         return self
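For the bytes path, a minimal round-trip sketch of what the fix guarantees: the language code survives to_bytes/from_bytes on both the nlp object and its vocab.

from spacy.lang.en import English

nlp = English()
nlp2 = English().from_bytes(nlp.to_bytes())
assert nlp2.lang == "en"
assert nlp2.vocab.lang == "en"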
spacy/tests/regression/test_issue4054.py (new file, 33 lines)
@@ -0,0 +1,33 @@
+# coding: utf8
+from __future__ import unicode_literals
+
+from spacy.vocab import Vocab
+
+import spacy
+from spacy.lang.en import English
+from spacy.tests.util import make_tempdir
+from spacy.util import ensure_path
+
+
+def test_issue4054(en_vocab):
+    """Test that a new blank model can be made with a vocab from file,
+    and that serialization does not drop the language at any point."""
+    nlp1 = English()
+    vocab1 = nlp1.vocab
+
+    with make_tempdir() as d:
+        vocab_dir = ensure_path(d / "vocab")
+        if not vocab_dir.exists():
+            vocab_dir.mkdir()
+        vocab1.to_disk(vocab_dir)
+
+        vocab2 = Vocab().from_disk(vocab_dir)
+        print("lang", vocab2.lang)
+        nlp2 = spacy.blank("en", vocab=vocab2)
+
+        nlp_dir = ensure_path(d / "nlp")
+        if not nlp_dir.exists():
+            nlp_dir.mkdir()
+        nlp2.to_disk(nlp_dir)
+        nlp3 = spacy.load(nlp_dir)
+        assert nlp3.lang == "en"