Don't use the same vocab for source models (#8388)

* Don't use the same vocab for source models

The source models should not be loaded with the vocab from the current
pipeline because this loads the vectors from the source model into the
current vocab.

The strings are all copied in `Language.create_pipe_from_source`, so if
the vectors are configured correctly in the current pipeline, the
sourced component will work as expected. If there is a vector mismatch,
a warning is shown. (It's not possible to inspect whether the vectors
are actually used by the component, so a warning is the best option.)

* Update comment on source model loading
This commit is contained in:
Adriane Boyd 2021-06-21 09:33:33 +02:00 committed by GitHub
parent 02d2fdb123
commit 7abfa25035
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 3 deletions

View File

@@ -1696,9 +1696,12 @@ class Language:
else:
model = pipe_cfg["source"]
if model not in source_nlps:
# We only need the components here and we need to init
# model with the same vocab as the current nlp object
source_nlps[model] = util.load_model(model, vocab=nlp.vocab)
# We only need the components here and we intentionally
# do not load the model with the same vocab because
# this would cause the vectors to be copied into the
# current nlp object (all the strings will be added in
# create_pipe_from_source)
source_nlps[model] = util.load_model(model)
source_name = pipe_cfg.get("component", pipe_name)
listeners_replaced = False
if "replace_listeners" in pipe_cfg:

View File

@@ -475,3 +475,26 @@ def test_language_init_invalid_vocab(value):
with pytest.raises(ValueError) as e:
Language(value)
assert err_fragment in str(e.value)
def test_language_source_and_vectors(nlp2):
    """Check the effects of sourcing a component from another pipeline.

    A fresh pipeline with a "textcat" component is built and initialized,
    then that component is sourced into ``nlp2``. The test asserts that a
    warning is logged on the "spacy" logger, that a string known only to
    the source pipeline ends up in ``nlp2``'s string store, and that the
    serialized vectors of the source pipeline are unchanged afterwards.
    """
    source_nlp = Language(Vocab())
    classifier = source_nlp.add_pipe("textcat")
    classifier.add_label("POSITIVE")
    classifier.add_label("NEGATIVE")
    source_nlp.initialize()

    # A string that neither pipeline knows about before it is added below.
    marker = "thisisalongstring"
    for pipeline in (source_nlp, nlp2):
        assert marker not in pipeline.vocab.strings
    source_nlp.vocab.strings.add(marker)

    # The two pipelines must start out with different vectors, otherwise
    # the mismatch warning checked below would have no reason to fire.
    vectors_before = source_nlp.vocab.vectors.to_bytes()
    assert vectors_before != nlp2.vocab.vectors.to_bytes()

    # TODO: convert to pytest.warns for v3.1
    spacy_logger = logging.getLogger("spacy")
    with mock.patch.object(spacy_logger, "warning") as warning_spy:
        nlp2.add_pipe("textcat", name="textcat2", source=source_nlp)
        warning_spy.assert_called()

    # Sourcing copies the source pipeline's strings into the target vocab.
    assert marker in nlp2.vocab.strings
    # The source pipeline's vectors must be byte-for-byte what they were.
    # NOTE(review): only the source pipeline's vectors are compared here;
    # confirm whether nlp2's vectors should be checked as well.
    assert source_nlp.vocab.vectors.to_bytes() == vectors_before