mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Don't use the same vocab for source models (#8388)
* Don't use the same vocab for source models

  The source models should not be loaded with the vocab from the current pipeline, because this loads the vectors from the source model into the current vocab. The strings are all copied in `Language.create_pipe_from_source`, so if the vectors are configured correctly in the current pipeline, the sourced component will work as expected. If there is a vector mismatch, a warning is shown. (It's not possible to inspect whether the vectors are actually used by the component, so a warning is the best option.)

* Update comment on source model loading
This commit is contained in:
parent
02d2fdb123
commit
7abfa25035
|
@ -1696,9 +1696,12 @@ class Language:
|
||||||
else:
|
else:
|
||||||
model = pipe_cfg["source"]
|
model = pipe_cfg["source"]
|
||||||
if model not in source_nlps:
|
if model not in source_nlps:
|
||||||
# We only need the components here and we need to init
|
# We only need the components here and we intentionally
|
||||||
# model with the same vocab as the current nlp object
|
# do not load the model with the same vocab because
|
||||||
source_nlps[model] = util.load_model(model, vocab=nlp.vocab)
|
# this would cause the vectors to be copied into the
|
||||||
|
# current nlp object (all the strings will be added in
|
||||||
|
# create_pipe_from_source)
|
||||||
|
source_nlps[model] = util.load_model(model)
|
||||||
source_name = pipe_cfg.get("component", pipe_name)
|
source_name = pipe_cfg.get("component", pipe_name)
|
||||||
listeners_replaced = False
|
listeners_replaced = False
|
||||||
if "replace_listeners" in pipe_cfg:
|
if "replace_listeners" in pipe_cfg:
|
||||||
|
|
|
@ -475,3 +475,26 @@ def test_language_init_invalid_vocab(value):
|
||||||
with pytest.raises(ValueError) as e:
|
with pytest.raises(ValueError) as e:
|
||||||
Language(value)
|
Language(value)
|
||||||
assert err_fragment in str(e.value)
|
assert err_fragment in str(e.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_language_source_and_vectors(nlp2):
|
||||||
|
nlp = Language(Vocab())
|
||||||
|
textcat = nlp.add_pipe("textcat")
|
||||||
|
for label in ("POSITIVE", "NEGATIVE"):
|
||||||
|
textcat.add_label(label)
|
||||||
|
nlp.initialize()
|
||||||
|
long_string = "thisisalongstring"
|
||||||
|
assert long_string not in nlp.vocab.strings
|
||||||
|
assert long_string not in nlp2.vocab.strings
|
||||||
|
nlp.vocab.strings.add(long_string)
|
||||||
|
assert nlp.vocab.vectors.to_bytes() != nlp2.vocab.vectors.to_bytes()
|
||||||
|
vectors_bytes = nlp.vocab.vectors.to_bytes()
|
||||||
|
# TODO: convert to pytest.warns for v3.1
|
||||||
|
logger = logging.getLogger("spacy")
|
||||||
|
with mock.patch.object(logger, "warning") as mock_warning:
|
||||||
|
nlp2.add_pipe("textcat", name="textcat2", source=nlp)
|
||||||
|
mock_warning.assert_called()
|
||||||
|
# strings should be added
|
||||||
|
assert long_string in nlp2.vocab.strings
|
||||||
|
# vectors should remain unmodified
|
||||||
|
assert nlp.vocab.vectors.to_bytes() == vectors_bytes
|
||||||
|
|
Loading…
Reference in New Issue
Block a user