From ed18a6efbd0aed54be103921ceedd15157722cb7 Mon Sep 17 00:00:00 2001 From: BreakBB <33514570+BreakBB@users.noreply.github.com> Date: Tue, 14 May 2019 16:59:31 +0200 Subject: [PATCH] Add check for callable to 'Language.replace_pipe' to fix #3737 (#3741) --- spacy/errors.py | 2 ++ spacy/language.py | 5 +++++ spacy/tests/pipeline/test_pipe_methods.py | 6 ++++-- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/spacy/errors.py b/spacy/errors.py index 5f964114e..b28393156 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -383,6 +383,8 @@ class Errors(object): E133 = ("The sum of prior probabilities for alias '{alias}' should not exceed 1, " "but found {sum}.") E134 = ("Alias '{alias}' defined for unknown entity '{entity}'.") + E135 = ("If you meant to replace a built-in component, use `create_pipe`: " + "`nlp.replace_pipe('{name}', nlp.create_pipe('{name}'))`") @add_codes diff --git a/spacy/language.py b/spacy/language.py index 6bd21b0bc..924c0b423 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -333,6 +333,11 @@ class Language(object): """ if name not in self.pipe_names: raise ValueError(Errors.E001.format(name=name, opts=self.pipe_names)) + if not hasattr(component, "__call__"): + msg = Errors.E003.format(component=repr(component), name=name) + if isinstance(component, basestring_) and component in self.factories: + msg += Errors.E135.format(name=name) + raise ValueError(msg) self.pipeline[self.pipe_names.index(name)] = (name, component) def rename_pipe(self, old_name, new_name): diff --git a/spacy/tests/pipeline/test_pipe_methods.py b/spacy/tests/pipeline/test_pipe_methods.py index d36201718..a0870784c 100644 --- a/spacy/tests/pipeline/test_pipe_methods.py +++ b/spacy/tests/pipeline/test_pipe_methods.py @@ -52,11 +52,13 @@ def test_get_pipe(nlp, name): assert nlp.get_pipe(name) == new_pipe -@pytest.mark.parametrize("name,replacement", [("my_component", lambda doc: doc)]) -def test_replace_pipe(nlp, name, replacement): +@pytest.mark.parametrize("name,replacement,not_callable", [("my_component", lambda doc: doc, {})]) +def test_replace_pipe(nlp, name, replacement, not_callable): with pytest.raises(ValueError): nlp.replace_pipe(name, new_pipe) nlp.add_pipe(new_pipe, name=name) + with pytest.raises(ValueError): + nlp.replace_pipe(name, not_callable) nlp.replace_pipe(name, replacement) assert nlp.get_pipe(name) != new_pipe assert nlp.get_pipe(name) == replacement