From 862da5e7932bda6ceee3091c915b782a8215cc83 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Tue, 22 May 2018 18:29:45 +0200 Subject: [PATCH] Support pipeline factories via entry points (#2348) --- spacy/errors.py | 2 ++ spacy/language.py | 7 ++++++- spacy/util.py | 13 +++++++++++++ 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/spacy/errors.py b/spacy/errors.py index b812a6f76..ad0518fca 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -46,6 +46,8 @@ class Warnings(object): "use context-sensitive tensors. You can always add your own word " "vectors, or use one of the larger models instead if available.") W008 = ("Evaluating {obj}.similarity based on empty vectors.") + W009 = ("Custom factory '{name}' provided by entry points of another " + "package overwrites built-in factory.") @add_codes diff --git a/spacy/language.py b/spacy/language.py index be331e72d..e1e01d0ca 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -28,7 +28,7 @@ from .lang.punctuation import TOKENIZER_INFIXES from .lang.tokenizer_exceptions import TOKEN_MATCH from .lang.tag_map import TAG_MAP from .lang.lex_attrs import LEX_ATTRS, is_stop -from .errors import Errors +from .errors import Errors, Warnings, user_warning from . import util from . import about @@ -139,6 +139,11 @@ class Language(object): 100,000 characters in one text. RETURNS (Language): The newly constructed object. """ + user_factories = util.get_entry_points('spacy_factories') + for factory in user_factories.keys(): + if factory in self.factories: + user_warning(Warnings.W009.format(name=factory)) + self.factories.update(user_factories) self._meta = dict(meta) self._path = None if vocab is True: diff --git a/spacy/util.py b/spacy/util.py index b80142c38..80adf7257 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -220,6 +220,19 @@ def get_package_path(name): return Path(pkg.__file__).parent +def get_entry_points(key): + """Get registered entry points from other packages for a given key, e.g. 
+ 'spacy_factories', and return them as a dictionary, keyed by name. + + key (unicode): Entry point group name. + RETURNS (dict): Entry points, keyed by name. + """ + result = {} + for entry_point in pkg_resources.iter_entry_points(key): + result[entry_point.name] = entry_point.load() + return result + + + def is_in_jupyter(): """Check if user is running spaCy from a Jupyter notebook by detecting the IPython kernel. Mainly used for the displaCy visualizer.