diff --git a/spacy/errors.py b/spacy/errors.py index 1a576faee..2b803576c 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -128,7 +128,8 @@ class Errors: "got {component} (name: '{name}'). If you're using a custom " "component factory, double-check that it correctly returns your " "initialized component.") - E004 = ("Can't set up pipeline component: a factory for '{name}' already exists.") + E004 = ("Can't set up pipeline component: a factory for '{name}' already " + "exists. Existing factory: {func}. New factory: {new_func}") E005 = ("Pipeline component '{name}' returned None. If you're using a " "custom component, maybe you forgot to return the processed Doc?") E006 = ("Invalid constraints for adding pipeline component. You can only " diff --git a/spacy/language.py b/spacy/language.py index 90d2cf81a..1b9a2bfc8 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -396,13 +396,21 @@ class Language: style="default config", name=name, cfg_type=type(default_config) ) raise ValueError(err) - internal_name = cls.get_factory_name(name) - if internal_name in registry.factories: - # We only check for the internal name here – it's okay if it's a - # subclass and the base class has a factory of the same name - raise ValueError(Errors.E004.format(name=name)) def add_factory(factory_func: Callable) -> Callable: + internal_name = cls.get_factory_name(name) + if internal_name in registry.factories: + # We only check for the internal name here – it's okay if it's a + # subclass and the base class has a factory of the same name. We + # also only raise if the function is different to prevent raising + # if module is reloaded. + existing_func = registry.factories.get(internal_name) + if not util.is_same_func(factory_func, existing_func): + err = Errors.E004.format( + name=name, func=existing_func, new_func=factory_func + ) + raise ValueError(err) + arg_names = util.get_arg_names(factory_func) if "nlp" not in arg_names or "name" not in arg_names: raise ValueError(Errors.E964.format(name=name)) @@ -472,6 +480,21 @@ class Language: def factory_func(nlp: cls, name: str) -> Callable[[Doc], Doc]: return component_func + internal_name = cls.get_factory_name(name) + if internal_name in registry.factories: + # We only check for the internal name here – it's okay if it's a + # subclass and the base class has a factory of the same name. We + # also only raise if the function is different to prevent raising + # if module is reloaded. It's hacky, but we need to check the + # existing functure for a closure and whether that's identical + # to the component function (because factory_func created above + # will always be different, even for the same function) + existing_func = registry.factories.get(internal_name) + closure = existing_func.__closure__ + wrapped = [c.cell_contents for c in closure][0] if closure else None + if util.is_same_func(wrapped, component_func): + factory_func = existing_func # noqa: F811 + cls.factory( component_name, assigns=assigns, diff --git a/spacy/tests/pipeline/test_pipe_factories.py b/spacy/tests/pipeline/test_pipe_factories.py index aa682fefe..f75c9ec8c 100644 --- a/spacy/tests/pipeline/test_pipe_factories.py +++ b/spacy/tests/pipeline/test_pipe_factories.py @@ -438,3 +438,26 @@ def test_pipe_factories_from_source_config(): config = nlp.config["components"]["custom"] assert config["factory"] == name assert config["arg"] == "world" + + +def test_pipe_factories_decorator_idempotent(): + """Check that decorator can be run multiple times if the function is the + same. This is especially relevant for live reloading because we don't + want spaCy to raise an error if a module registering components is reloaded. + """ + name = "test_pipe_factories_decorator_idempotent" + func = lambda nlp, name: lambda doc: doc + for i in range(5): + Language.factory(name, func=func) + nlp = Language() + nlp.add_pipe(name) + Language.factory(name, func=func) + # Make sure it also works for component decorator, which creates the + # factory function + name2 = f"{name}2" + func2 = lambda doc: doc + for i in range(5): + Language.component(name2, func=func2) + nlp = Language() + nlp.add_pipe(name) + Language.component(name2, func=func2) diff --git a/spacy/util.py b/spacy/util.py index 8ba164dc1..149bbb21a 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -673,6 +673,25 @@ def get_object_name(obj: Any) -> str: return repr(obj) +def is_same_func(func1: Callable, func2: Callable) -> bool: + """Approximately decide whether two functions are the same, even if their + identity is different (e.g. after they have been live reloaded). Mostly + used in the @Language.component and @Language.factory decorators to decide + whether to raise if a factory already exists. Allows decorator to run + multiple times with the same function. + + func1 (Callable): The first function. + func2 (Callable): The second function. + RETURNS (bool): Whether it's the same function (most likely). + """ + if not callable(func1) or not callable(func2): + return False + same_name = func1.__qualname__ == func2.__qualname__ + same_file = inspect.getfile(func1) == inspect.getfile(func2) + same_code = inspect.getsourcelines(func1) == inspect.getsourcelines(func2) + return all([same_name, same_file, same_code]) + + def get_cuda_stream( require: bool = False, non_blocking: bool = True ) -> Optional[CudaStream]: