From 09cec3e41b4408956f7ef5656c099c46695025a8 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Thu, 7 Nov 2019 11:45:22 +0100 Subject: [PATCH] Replace function registries with catalogue (#4584) * Replace functions registries with catalogue * Update __init__.py * Fix test * Revert unrelated flag [ci skip] --- requirements.txt | 2 +- setup.cfg | 2 +- spacy/__init__.py | 2 +- spacy/compat.py | 5 - spacy/displacy/render.py | 4 +- spacy/language.py | 6 +- spacy/ml/common.py | 6 +- spacy/ml/tok2vec.py | 18 ++-- spacy/tests/test_architectures.py | 19 ++++ spacy/tests/test_register_architecture.py | 19 ---- spacy/util.py | 113 ++++------------------ 11 files changed, 56 insertions(+), 140 deletions(-) create mode 100644 spacy/tests/test_architectures.py delete mode 100644 spacy/tests/test_register_architecture.py diff --git a/requirements.txt b/requirements.txt index 89118b970..12f19bb88 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,12 +6,12 @@ blis>=0.4.0,<0.5.0 murmurhash>=0.28.0,<1.1.0 wasabi>=0.4.0,<1.1.0 srsly>=0.1.0,<1.1.0 +catalogue>=0.0.7,<1.1.0 # Third party dependencies numpy>=1.15.0 requests>=2.13.0,<3.0.0 plac>=0.9.6,<1.2.0 pathlib==1.0.1; python_version < "3.4" -importlib_metadata>=0.20; python_version < "3.8" # Optional dependencies jsonschema>=2.6.0,<3.1.0 # Development dependencies diff --git a/setup.cfg b/setup.cfg index 60a24dc58..940066a9e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,13 +48,13 @@ install_requires = blis>=0.4.0,<0.5.0 wasabi>=0.4.0,<1.1.0 srsly>=0.1.0,<1.1.0 + catalogue>=0.0.7,<1.1.0 # Third-party dependencies setuptools numpy>=1.15.0 plac>=0.9.6,<1.2.0 requests>=2.13.0,<3.0.0 pathlib==1.0.1; python_version < "3.4" - importlib_metadata>=0.20; python_version < "3.8" [options.extras_require] lookups = diff --git a/spacy/__init__.py b/spacy/__init__.py index 57701179f..4a0d16a49 100644 --- a/spacy/__init__.py +++ b/spacy/__init__.py @@ -15,7 +15,7 @@ from .glossary import explain from .about import __version__ from .errors import Errors, Warnings, deprecation_warning from . import util -from .util import register_architecture, get_architecture +from .util import registry from .language import component diff --git a/spacy/compat.py b/spacy/compat.py index 5bff28815..0ea31c6b3 100644 --- a/spacy/compat.py +++ b/spacy/compat.py @@ -36,11 +36,6 @@ try: except ImportError: cupy = None -try: # Python 3.8 - import importlib.metadata as importlib_metadata -except ImportError: - import importlib_metadata # noqa: F401 - try: from thinc.neural.optimizers import Optimizer # noqa: F401 except ImportError: diff --git a/spacy/displacy/render.py b/spacy/displacy/render.py index 17b67940a..d6e33437b 100644 --- a/spacy/displacy/render.py +++ b/spacy/displacy/render.py @@ -5,7 +5,7 @@ import uuid from .templates import TPL_DEP_SVG, TPL_DEP_WORDS, TPL_DEP_ARCS, TPL_ENTS from .templates import TPL_ENT, TPL_ENT_RTL, TPL_FIGURE, TPL_TITLE, TPL_PAGE -from ..util import minify_html, escape_html, get_entry_points, ENTRY_POINTS +from ..util import minify_html, escape_html, registry from ..errors import Errors @@ -242,7 +242,7 @@ class EntityRenderer(object): "CARDINAL": "#e4e7d2", "PERCENT": "#e4e7d2", } - user_colors = get_entry_points(ENTRY_POINTS.displacy_colors) + user_colors = registry.displacy_colors.get_all() for user_color in user_colors.values(): colors.update(user_color) colors.update(options.get("colors", {})) diff --git a/spacy/language.py b/spacy/language.py index 97d6515c5..72044a0c5 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -51,8 +51,8 @@ class BaseDefaults(object): filenames = {name: root / filename for name, filename in cls.resources} if LANG in cls.lex_attr_getters: lang = cls.lex_attr_getters[LANG](None) - user_lookups = util.get_entry_point(util.ENTRY_POINTS.lookups, lang, {}) - filenames.update(user_lookups) + if lang in util.registry.lookups: + filenames.update(util.registry.lookups.get(lang)) lookups = Lookups() for name, filename in filenames.items(): data = util.load_language_data(filename) @@ -155,7 +155,7 @@ class Language(object): 100,000 characters in one text. RETURNS (Language): The newly constructed object. """ - user_factories = util.get_entry_points(util.ENTRY_POINTS.factories) + user_factories = util.registry.factories.get_all() self.factories.update(user_factories) self._meta = dict(meta) self._path = None diff --git a/spacy/ml/common.py b/spacy/ml/common.py index 963d4dc35..f90b53a15 100644 --- a/spacy/ml/common.py +++ b/spacy/ml/common.py @@ -3,10 +3,10 @@ from __future__ import unicode_literals from thinc.api import chain from thinc.v2v import Maxout from thinc.misc import LayerNorm -from ..util import register_architecture, make_layer +from ..util import registry, make_layer -@register_architecture("thinc.FeedForward.v1") +@registry.architectures.register("thinc.FeedForward.v1") def FeedForward(config): layers = [make_layer(layer_cfg) for layer_cfg in config["layers"]] model = chain(*layers) @@ -14,7 +14,7 @@ def FeedForward(config): return model -@register_architecture("spacy.LayerNormalizedMaxout.v1") +@registry.architectures.register("spacy.LayerNormalizedMaxout.v1") def LayerNormalizedMaxout(config): width = config["width"] pieces = config["pieces"] diff --git a/spacy/ml/tok2vec.py b/spacy/ml/tok2vec.py index 0b30551b5..8f86475ef 100644 --- a/spacy/ml/tok2vec.py +++ b/spacy/ml/tok2vec.py @@ -6,11 +6,11 @@ from thinc.v2v import Maxout, Model from thinc.i2v import HashEmbed, StaticVectors from thinc.t2t import ExtractWindow from thinc.misc import Residual, LayerNorm, FeatureExtracter -from ..util import make_layer, register_architecture +from ..util import make_layer, registry from ._wire import concatenate_lists -@register_architecture("spacy.Tok2Vec.v1") +@registry.architectures.register("spacy.Tok2Vec.v1") def Tok2Vec(config): doc2feats = make_layer(config["@doc2feats"]) embed = make_layer(config["@embed"]) @@ -24,13 +24,13 @@ def Tok2Vec(config): return tok2vec -@register_architecture("spacy.Doc2Feats.v1") +@registry.architectures.register("spacy.Doc2Feats.v1") def Doc2Feats(config): columns = config["columns"] return FeatureExtracter(columns) -@register_architecture("spacy.MultiHashEmbed.v1") +@registry.architectures.register("spacy.MultiHashEmbed.v1") def MultiHashEmbed(config): # For backwards compatibility with models before the architecture registry, # we have to be careful to get exactly the same model structure. One subtle @@ -78,7 +78,7 @@ def MultiHashEmbed(config): return layer -@register_architecture("spacy.CharacterEmbed.v1") +@registry.architectures.register("spacy.CharacterEmbed.v1") def CharacterEmbed(config): from .. import _ml @@ -94,7 +94,7 @@ def CharacterEmbed(config): return model -@register_architecture("spacy.MaxoutWindowEncoder.v1") +@registry.architectures.register("spacy.MaxoutWindowEncoder.v1") def MaxoutWindowEncoder(config): nO = config["width"] nW = config["window_size"] @@ -110,7 +110,7 @@ def MaxoutWindowEncoder(config): return model -@register_architecture("spacy.MishWindowEncoder.v1") +@registry.architectures.register("spacy.MishWindowEncoder.v1") def MishWindowEncoder(config): from thinc.v2v import Mish @@ -124,12 +124,12 @@ def MishWindowEncoder(config): return model -@register_architecture("spacy.PretrainedVectors.v1") +@registry.architectures.register("spacy.PretrainedVectors.v1") def PretrainedVectors(config): return StaticVectors(config["vectors_name"], config["width"], config["column"]) -@register_architecture("spacy.TorchBiLSTMEncoder.v1") +@registry.architectures.register("spacy.TorchBiLSTMEncoder.v1") def TorchBiLSTMEncoder(config): import torch.nn from thinc.extra.wrappers import PyTorchWrapperRNN diff --git a/spacy/tests/test_architectures.py b/spacy/tests/test_architectures.py new file mode 100644 index 000000000..77f1af020 --- /dev/null +++ b/spacy/tests/test_architectures.py @@ -0,0 +1,19 @@ +# coding: utf8 +from __future__ import unicode_literals + +import pytest +from spacy import registry +from thinc.v2v import Affine +from catalogue import RegistryError + + +@registry.architectures.register("my_test_function") +def create_model(nr_in, nr_out): + return Affine(nr_in, nr_out) + + +def test_get_architecture(): + arch = registry.architectures.get("my_test_function") + assert arch is create_model + with pytest.raises(RegistryError): + registry.architectures.get("not_an_existing_key") diff --git a/spacy/tests/test_register_architecture.py b/spacy/tests/test_register_architecture.py deleted file mode 100644 index 0c1b5b16f..000000000 --- a/spacy/tests/test_register_architecture.py +++ /dev/null @@ -1,19 +0,0 @@ -# coding: utf8 -from __future__ import unicode_literals - -import pytest -from spacy import register_architecture -from spacy import get_architecture -from thinc.v2v import Affine - - -@register_architecture("my_test_function") -def create_model(nr_in, nr_out): - return Affine(nr_in, nr_out) - - -def test_get_architecture(): - arch = get_architecture("my_test_function") - assert arch is create_model - with pytest.raises(KeyError): - get_architecture("not_an_existing_key") diff --git a/spacy/util.py b/spacy/util.py index 74e4cc1c6..2d5a56806 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -13,6 +13,7 @@ import functools import itertools import numpy.random import srsly +import catalogue import sys try: @@ -27,29 +28,20 @@ except ImportError: from .symbols import ORTH from .compat import cupy, CudaStream, path2str, basestring_, unicode_ -from .compat import import_file, importlib_metadata +from .compat import import_file from .errors import Errors, Warnings, deprecation_warning -LANGUAGES = {} -ARCHITECTURES = {} _data_path = Path(__file__).parent / "data" _PRINT_ENV = False -# NB: Ony ever call this once! If called more than ince within the -# function, test_issue1506 hangs and it's not 100% clear why. -AVAILABLE_ENTRY_POINTS = importlib_metadata.entry_points() - - -class ENTRY_POINTS(object): - """Available entry points to register extensions.""" - - factories = "spacy_factories" - languages = "spacy_languages" - displacy_colors = "spacy_displacy_colors" - lookups = "spacy_lookups" - architectures = "spacy_architectures" +class registry(object): + languages = catalogue.create("spacy", "languages", entry_points=True) + architectures = catalogue.create("spacy", "architectures", entry_points=True) + lookups = catalogue.create("spacy", "lookups", entry_points=True) + factories = catalogue.create("spacy", "factories", entry_points=True) + displacy_colors = catalogue.create("spacy", "displacy_colors", entry_points=True) def set_env_log(value): @@ -65,8 +57,7 @@ def lang_class_is_loaded(lang): lang (unicode): Two-letter language code, e.g. 'en'. RETURNS (bool): Whether a Language class has been loaded. """ - global LANGUAGES - return lang in LANGUAGES + return lang in registry.languages def get_lang_class(lang): @@ -75,19 +66,16 @@ def get_lang_class(lang): lang (unicode): Two-letter language code, e.g. 'en'. RETURNS (Language): Language class. """ - global LANGUAGES - # Check if an entry point is exposed for the language code - entry_point = get_entry_point(ENTRY_POINTS.languages, lang) - if entry_point is not None: - LANGUAGES[lang] = entry_point - return entry_point - if lang not in LANGUAGES: + # Check if language is registered / entry point is available + if lang in registry.languages: + return registry.languages.get(lang) + else: try: module = importlib.import_module(".lang.%s" % lang, "spacy") except ImportError as err: raise ImportError(Errors.E048.format(lang=lang, err=err)) - LANGUAGES[lang] = getattr(module, module.__all__[0]) - return LANGUAGES[lang] + set_lang_class(lang, getattr(module, module.__all__[0])) + return registry.languages.get(lang) def set_lang_class(name, cls): @@ -96,8 +84,7 @@ def set_lang_class(name, cls): name (unicode): Name of Language class. cls (Language): Language class. """ - global LANGUAGES - LANGUAGES[name] = cls + registry.languages.register(name, func=cls) def get_data_path(require_exists=True): @@ -121,49 +108,11 @@ def set_data_path(path): _data_path = ensure_path(path) -def register_architecture(name, arch=None): - """Decorator to register an architecture. An architecture is a function - that returns a Thinc Model object. - - name (unicode): The name of the architecture to register. - arch (Model): Optional architecture if function is called directly and - not used as a decorator. - RETURNS (callable): Function to register architecture. - """ - global ARCHITECTURES - if arch is not None: - ARCHITECTURES[name] = arch - return arch - - def do_registration(arch): - ARCHITECTURES[name] = arch - return arch - - return do_registration - - def make_layer(arch_config): - arch_func = get_architecture(arch_config["arch"]) + arch_func = registry.architectures.get(arch_config["arch"]) return arch_func(arch_config["config"]) -def get_architecture(name): - """Get a model architecture function by name. Raises a KeyError if the - architecture is not found. - - name (unicode): The mame of the architecture. - RETURNS (Model): The architecture. - """ - # Check if an entry point is exposed for the architecture code - entry_point = get_entry_point(ENTRY_POINTS.architectures, name) - if entry_point is not None: - ARCHITECTURES[name] = entry_point - if name not in ARCHITECTURES: - names = ", ".join(sorted(ARCHITECTURES.keys())) - raise KeyError(Errors.E174.format(name=name, names=names)) - return ARCHITECTURES[name] - - def ensure_path(path): """Ensure string is converted to a Path. @@ -327,34 +276,6 @@ def get_package_path(name): return Path(pkg.__file__).parent -def get_entry_points(key): - """Get registered entry points from other packages for a given key, e.g. - 'spacy_factories' and return them as a dictionary, keyed by name. - - key (unicode): Entry point name. - RETURNS (dict): Entry points, keyed by name. - """ - result = {} - for entry_point in AVAILABLE_ENTRY_POINTS.get(key, []): - result[entry_point.name] = entry_point.load() - return result - - -def get_entry_point(key, value, default=None): - """Check if registered entry point is available for a given name and - load it. Otherwise, return None. - - key (unicode): Entry point name. - value (unicode): Name of entry point to load. - default: Optional default value to return. - RETURNS: The loaded entry point or None. - """ - for entry_point in AVAILABLE_ENTRY_POINTS.get(key, []): - if entry_point.name == value: - return entry_point.load() - return default - - def is_in_jupyter(): """Check if user is running spaCy from a Jupyter notebook by detecting the IPython kernel. Mainly used for the displaCy visualizer.