Improve checks for sourced components (#7490)

* Improve checks for sourced components

* Remove language class checks

* Convert Python warning to logger warning

* Remove unused warning

* Fix formatting
This commit is contained in:
Adriane Boyd 2021-04-19 10:36:32 +02:00 committed by GitHub
parent 05bdbe28bb
commit 1ad646cbcf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 45 additions and 4 deletions

View File

@ -159,6 +159,8 @@ class Warnings:
"http://spacy.io/usage/v3#jupyter-notebook-gpu")
W112 = ("The model specified to use for initial vectors ({name}) has no "
"vectors. This is almost certainly a mistake.")
W113 = ("Sourced component '{name}' may not work as expected: source "
"vectors are not identical to current pipeline vectors.")
@add_codes
@ -651,8 +653,8 @@ class Errors:
"returned the initialized nlp object instead?")
E944 = ("Can't copy pipeline component '{name}' from source '{model}': "
"not found in pipeline. Available components: {opts}")
E945 = ("Can't copy pipeline component '{name}' from source. Expected loaded "
"nlp object, but got: {source}")
E945 = ("Can't copy pipeline component '{name}' from source. Expected "
"loaded nlp object, but got: {source}")
E947 = ("`Matcher.add` received invalid `greedy` argument: expected "
"a string value from {expected} but got: '{arg}'")
E948 = ("`Matcher.add` received invalid 'patterns' argument: expected "

View File

@ -682,9 +682,14 @@ class Language:
name (str): Optional alternative name to use in current pipeline.
RETURNS (Tuple[Callable, str]): The component and its factory name.
"""
# TODO: handle errors and mismatches (vectors etc.)
if not isinstance(source, self.__class__):
# Check source type
if not isinstance(source, Language):
raise ValueError(Errors.E945.format(name=source_name, source=type(source)))
# Check vectors, with faster checks first
if self.vocab.vectors.shape != source.vocab.vectors.shape or \
self.vocab.vectors.key2row != source.vocab.vectors.key2row or \
self.vocab.vectors.to_bytes() != source.vocab.vectors.to_bytes():
util.logger.warning(Warnings.W113.format(name=source_name))
if not source_name in source.component_names:
raise KeyError(
Errors.E944.format(

View File

@ -1,4 +1,6 @@
import pytest
import mock
import logging
from spacy.language import Language
from spacy.lang.en import English
from spacy.lang.de import German
@ -402,6 +404,38 @@ def test_pipe_factories_from_source():
nlp.add_pipe("custom", source=source_nlp)
def test_pipe_factories_from_source_language_subclass():
    """Check sourcing a component across different Language classes.

    A tagger sourced from an ``English`` pipeline must be addable to a custom
    ``English`` subclass and to an unrelated ``German`` pipeline. When the
    target pipeline's vectors differ from the source's, sourcing is still
    attempted, and a warning is expected on the "spacy" logger.
    """

    class CustomEnglishDefaults(English.Defaults):
        stop_words = {"custom", "stop"}

    @registry.languages("custom_en")
    class CustomEnglish(English):
        lang = "custom_en"
        Defaults = CustomEnglishDefaults

    source_nlp = English()
    source_nlp.add_pipe("tagger")

    # custom subclass
    nlp = CustomEnglish()
    nlp.add_pipe("tagger", source=source_nlp)
    assert "tagger" in nlp.pipe_names

    # non-subclass
    nlp = German()
    nlp.add_pipe("tagger", source=source_nlp)
    assert "tagger" in nlp.pipe_names

    # mismatched vectors: give the target pipeline a single 4-dimensional
    # vector so its vectors table no longer matches the source pipeline's
    nlp = English()
    nlp.vocab.vectors.resize((1, 4))
    nlp.vocab.vectors.add("cat", vector=[1, 2, 3, 4])
    logger = logging.getLogger("spacy")
    # Patch the logger method so the warning can be observed directly
    with mock.patch.object(logger, "warning") as mock_warning:
        nlp.add_pipe("tagger", source=source_nlp)
        mock_warning.assert_called()
def test_pipe_factories_from_source_custom():
"""Test adding components from a source model with custom components."""
name = "test_pipe_factories_from_source_custom"