Add check for callable to 'Language.replace_pipe' to fix #3737 (#3741)

This commit is contained in:
BreakBB 2019-05-14 16:59:31 +02:00 committed by Ines Montani
parent 8baff1c7c0
commit ed18a6efbd
3 changed files with 11 additions and 2 deletions

View File

@ -383,6 +383,8 @@ class Errors(object):
E133 = ("The sum of prior probabilities for alias '{alias}' should not exceed 1, "
"but found {sum}.")
E134 = ("Alias '{alias}' defined for unknown entity '{entity}'.")
E135 = ("If you meant to replace a built-in component, use `create_pipe`: "
"`nlp.replace_pipe('{name}', nlp.create_pipe('{name}'))`")
@add_codes

View File

@ -333,6 +333,11 @@ class Language(object):
"""
if name not in self.pipe_names:
raise ValueError(Errors.E001.format(name=name, opts=self.pipe_names))
if not hasattr(component, "__call__"):
msg = Errors.E003.format(component=repr(component), name=name)
if isinstance(component, basestring_) and component in self.factories:
msg += Errors.E135.format(name=name)
raise ValueError(msg)
self.pipeline[self.pipe_names.index(name)] = (name, component)
def rename_pipe(self, old_name, new_name):

View File

@ -52,11 +52,13 @@ def test_get_pipe(nlp, name):
assert nlp.get_pipe(name) == new_pipe
@pytest.mark.parametrize("name,replacement", [("my_component", lambda doc: doc)])
def test_replace_pipe(nlp, name, replacement):
@pytest.mark.parametrize("name,replacement,not_callable", [("my_component", lambda doc: doc, {})])
def test_replace_pipe(nlp, name, replacement, not_callable):
with pytest.raises(ValueError):
nlp.replace_pipe(name, new_pipe)
nlp.add_pipe(new_pipe, name=name)
with pytest.raises(ValueError):
nlp.replace_pipe(name, not_callable)
nlp.replace_pipe(name, replacement)
assert nlp.get_pipe(name) != new_pipe
assert nlp.get_pipe(name) == replacement