Replace function registries with catalogue (#4584)

* Replace functions registries with catalogue

* Update __init__.py

* Fix test

* Revert unrelated flag [ci skip]
This commit is contained in:
Ines Montani 2019-11-07 11:45:22 +01:00 committed by GitHub
parent 0f8678c0b1
commit 09cec3e41b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 56 additions and 140 deletions

View File

@ -6,12 +6,12 @@ blis>=0.4.0,<0.5.0
murmurhash>=0.28.0,<1.1.0 murmurhash>=0.28.0,<1.1.0
wasabi>=0.4.0,<1.1.0 wasabi>=0.4.0,<1.1.0
srsly>=0.1.0,<1.1.0 srsly>=0.1.0,<1.1.0
catalogue>=0.0.7,<1.1.0
# Third party dependencies # Third party dependencies
numpy>=1.15.0 numpy>=1.15.0
requests>=2.13.0,<3.0.0 requests>=2.13.0,<3.0.0
plac>=0.9.6,<1.2.0 plac>=0.9.6,<1.2.0
pathlib==1.0.1; python_version < "3.4" pathlib==1.0.1; python_version < "3.4"
importlib_metadata>=0.20; python_version < "3.8"
# Optional dependencies # Optional dependencies
jsonschema>=2.6.0,<3.1.0 jsonschema>=2.6.0,<3.1.0
# Development dependencies # Development dependencies

View File

@ -48,13 +48,13 @@ install_requires =
blis>=0.4.0,<0.5.0 blis>=0.4.0,<0.5.0
wasabi>=0.4.0,<1.1.0 wasabi>=0.4.0,<1.1.0
srsly>=0.1.0,<1.1.0 srsly>=0.1.0,<1.1.0
catalogue>=0.0.7,<1.1.0
# Third-party dependencies # Third-party dependencies
setuptools setuptools
numpy>=1.15.0 numpy>=1.15.0
plac>=0.9.6,<1.2.0 plac>=0.9.6,<1.2.0
requests>=2.13.0,<3.0.0 requests>=2.13.0,<3.0.0
pathlib==1.0.1; python_version < "3.4" pathlib==1.0.1; python_version < "3.4"
importlib_metadata>=0.20; python_version < "3.8"
[options.extras_require] [options.extras_require]
lookups = lookups =

View File

@ -15,7 +15,7 @@ from .glossary import explain
from .about import __version__ from .about import __version__
from .errors import Errors, Warnings, deprecation_warning from .errors import Errors, Warnings, deprecation_warning
from . import util from . import util
from .util import register_architecture, get_architecture from .util import registry
from .language import component from .language import component

View File

@ -36,11 +36,6 @@ try:
except ImportError: except ImportError:
cupy = None cupy = None
try: # Python 3.8
import importlib.metadata as importlib_metadata
except ImportError:
import importlib_metadata # noqa: F401
try: try:
from thinc.neural.optimizers import Optimizer # noqa: F401 from thinc.neural.optimizers import Optimizer # noqa: F401
except ImportError: except ImportError:

View File

@ -5,7 +5,7 @@ import uuid
from .templates import TPL_DEP_SVG, TPL_DEP_WORDS, TPL_DEP_ARCS, TPL_ENTS from .templates import TPL_DEP_SVG, TPL_DEP_WORDS, TPL_DEP_ARCS, TPL_ENTS
from .templates import TPL_ENT, TPL_ENT_RTL, TPL_FIGURE, TPL_TITLE, TPL_PAGE from .templates import TPL_ENT, TPL_ENT_RTL, TPL_FIGURE, TPL_TITLE, TPL_PAGE
from ..util import minify_html, escape_html, get_entry_points, ENTRY_POINTS from ..util import minify_html, escape_html, registry
from ..errors import Errors from ..errors import Errors
@ -242,7 +242,7 @@ class EntityRenderer(object):
"CARDINAL": "#e4e7d2", "CARDINAL": "#e4e7d2",
"PERCENT": "#e4e7d2", "PERCENT": "#e4e7d2",
} }
user_colors = get_entry_points(ENTRY_POINTS.displacy_colors) user_colors = registry.displacy_colors.get_all()
for user_color in user_colors.values(): for user_color in user_colors.values():
colors.update(user_color) colors.update(user_color)
colors.update(options.get("colors", {})) colors.update(options.get("colors", {}))

View File

@ -51,8 +51,8 @@ class BaseDefaults(object):
filenames = {name: root / filename for name, filename in cls.resources} filenames = {name: root / filename for name, filename in cls.resources}
if LANG in cls.lex_attr_getters: if LANG in cls.lex_attr_getters:
lang = cls.lex_attr_getters[LANG](None) lang = cls.lex_attr_getters[LANG](None)
user_lookups = util.get_entry_point(util.ENTRY_POINTS.lookups, lang, {}) if lang in util.registry.lookups:
filenames.update(user_lookups) filenames.update(util.registry.lookups.get(lang))
lookups = Lookups() lookups = Lookups()
for name, filename in filenames.items(): for name, filename in filenames.items():
data = util.load_language_data(filename) data = util.load_language_data(filename)
@ -155,7 +155,7 @@ class Language(object):
100,000 characters in one text. 100,000 characters in one text.
RETURNS (Language): The newly constructed object. RETURNS (Language): The newly constructed object.
""" """
user_factories = util.get_entry_points(util.ENTRY_POINTS.factories) user_factories = util.registry.factories.get_all()
self.factories.update(user_factories) self.factories.update(user_factories)
self._meta = dict(meta) self._meta = dict(meta)
self._path = None self._path = None

View File

@ -3,10 +3,10 @@ from __future__ import unicode_literals
from thinc.api import chain from thinc.api import chain
from thinc.v2v import Maxout from thinc.v2v import Maxout
from thinc.misc import LayerNorm from thinc.misc import LayerNorm
from ..util import register_architecture, make_layer from ..util import registry, make_layer
@register_architecture("thinc.FeedForward.v1") @registry.architectures.register("thinc.FeedForward.v1")
def FeedForward(config): def FeedForward(config):
layers = [make_layer(layer_cfg) for layer_cfg in config["layers"]] layers = [make_layer(layer_cfg) for layer_cfg in config["layers"]]
model = chain(*layers) model = chain(*layers)
@ -14,7 +14,7 @@ def FeedForward(config):
return model return model
@register_architecture("spacy.LayerNormalizedMaxout.v1") @registry.architectures.register("spacy.LayerNormalizedMaxout.v1")
def LayerNormalizedMaxout(config): def LayerNormalizedMaxout(config):
width = config["width"] width = config["width"]
pieces = config["pieces"] pieces = config["pieces"]

View File

@ -6,11 +6,11 @@ from thinc.v2v import Maxout, Model
from thinc.i2v import HashEmbed, StaticVectors from thinc.i2v import HashEmbed, StaticVectors
from thinc.t2t import ExtractWindow from thinc.t2t import ExtractWindow
from thinc.misc import Residual, LayerNorm, FeatureExtracter from thinc.misc import Residual, LayerNorm, FeatureExtracter
from ..util import make_layer, register_architecture from ..util import make_layer, registry
from ._wire import concatenate_lists from ._wire import concatenate_lists
@register_architecture("spacy.Tok2Vec.v1") @registry.architectures.register("spacy.Tok2Vec.v1")
def Tok2Vec(config): def Tok2Vec(config):
doc2feats = make_layer(config["@doc2feats"]) doc2feats = make_layer(config["@doc2feats"])
embed = make_layer(config["@embed"]) embed = make_layer(config["@embed"])
@ -24,13 +24,13 @@ def Tok2Vec(config):
return tok2vec return tok2vec
@register_architecture("spacy.Doc2Feats.v1") @registry.architectures.register("spacy.Doc2Feats.v1")
def Doc2Feats(config): def Doc2Feats(config):
columns = config["columns"] columns = config["columns"]
return FeatureExtracter(columns) return FeatureExtracter(columns)
@register_architecture("spacy.MultiHashEmbed.v1") @registry.architectures.register("spacy.MultiHashEmbed.v1")
def MultiHashEmbed(config): def MultiHashEmbed(config):
# For backwards compatibility with models before the architecture registry, # For backwards compatibility with models before the architecture registry,
# we have to be careful to get exactly the same model structure. One subtle # we have to be careful to get exactly the same model structure. One subtle
@ -78,7 +78,7 @@ def MultiHashEmbed(config):
return layer return layer
@register_architecture("spacy.CharacterEmbed.v1") @registry.architectures.register("spacy.CharacterEmbed.v1")
def CharacterEmbed(config): def CharacterEmbed(config):
from .. import _ml from .. import _ml
@ -94,7 +94,7 @@ def CharacterEmbed(config):
return model return model
@register_architecture("spacy.MaxoutWindowEncoder.v1") @registry.architectures.register("spacy.MaxoutWindowEncoder.v1")
def MaxoutWindowEncoder(config): def MaxoutWindowEncoder(config):
nO = config["width"] nO = config["width"]
nW = config["window_size"] nW = config["window_size"]
@ -110,7 +110,7 @@ def MaxoutWindowEncoder(config):
return model return model
@register_architecture("spacy.MishWindowEncoder.v1") @registry.architectures.register("spacy.MishWindowEncoder.v1")
def MishWindowEncoder(config): def MishWindowEncoder(config):
from thinc.v2v import Mish from thinc.v2v import Mish
@ -124,12 +124,12 @@ def MishWindowEncoder(config):
return model return model
@register_architecture("spacy.PretrainedVectors.v1") @registry.architectures.register("spacy.PretrainedVectors.v1")
def PretrainedVectors(config): def PretrainedVectors(config):
return StaticVectors(config["vectors_name"], config["width"], config["column"]) return StaticVectors(config["vectors_name"], config["width"], config["column"])
@register_architecture("spacy.TorchBiLSTMEncoder.v1") @registry.architectures.register("spacy.TorchBiLSTMEncoder.v1")
def TorchBiLSTMEncoder(config): def TorchBiLSTMEncoder(config):
import torch.nn import torch.nn
from thinc.extra.wrappers import PyTorchWrapperRNN from thinc.extra.wrappers import PyTorchWrapperRNN

View File

@ -0,0 +1,19 @@
# coding: utf8
from __future__ import unicode_literals
import pytest
from spacy import registry
from thinc.v2v import Affine
from catalogue import RegistryError
@registry.architectures.register("my_test_function")
def create_model(nr_in, nr_out):
return Affine(nr_in, nr_out)
def test_get_architecture():
arch = registry.architectures.get("my_test_function")
assert arch is create_model
with pytest.raises(RegistryError):
registry.architectures.get("not_an_existing_key")

View File

@ -1,19 +0,0 @@
# coding: utf8
from __future__ import unicode_literals
import pytest
from spacy import register_architecture
from spacy import get_architecture
from thinc.v2v import Affine
@register_architecture("my_test_function")
def create_model(nr_in, nr_out):
return Affine(nr_in, nr_out)
def test_get_architecture():
arch = get_architecture("my_test_function")
assert arch is create_model
with pytest.raises(KeyError):
get_architecture("not_an_existing_key")

View File

@ -13,6 +13,7 @@ import functools
import itertools import itertools
import numpy.random import numpy.random
import srsly import srsly
import catalogue
import sys import sys
try: try:
@ -27,29 +28,20 @@ except ImportError:
from .symbols import ORTH from .symbols import ORTH
from .compat import cupy, CudaStream, path2str, basestring_, unicode_ from .compat import cupy, CudaStream, path2str, basestring_, unicode_
from .compat import import_file, importlib_metadata from .compat import import_file
from .errors import Errors, Warnings, deprecation_warning from .errors import Errors, Warnings, deprecation_warning
LANGUAGES = {}
ARCHITECTURES = {}
_data_path = Path(__file__).parent / "data" _data_path = Path(__file__).parent / "data"
_PRINT_ENV = False _PRINT_ENV = False
# NB: Ony ever call this once! If called more than ince within the class registry(object):
# function, test_issue1506 hangs and it's not 100% clear why. languages = catalogue.create("spacy", "languages", entry_points=True)
AVAILABLE_ENTRY_POINTS = importlib_metadata.entry_points() architectures = catalogue.create("spacy", "architectures", entry_points=True)
lookups = catalogue.create("spacy", "lookups", entry_points=True)
factories = catalogue.create("spacy", "factories", entry_points=True)
class ENTRY_POINTS(object): displacy_colors = catalogue.create("spacy", "displacy_colors", entry_points=True)
"""Available entry points to register extensions."""
factories = "spacy_factories"
languages = "spacy_languages"
displacy_colors = "spacy_displacy_colors"
lookups = "spacy_lookups"
architectures = "spacy_architectures"
def set_env_log(value): def set_env_log(value):
@ -65,8 +57,7 @@ def lang_class_is_loaded(lang):
lang (unicode): Two-letter language code, e.g. 'en'. lang (unicode): Two-letter language code, e.g. 'en'.
RETURNS (bool): Whether a Language class has been loaded. RETURNS (bool): Whether a Language class has been loaded.
""" """
global LANGUAGES return lang in registry.languages
return lang in LANGUAGES
def get_lang_class(lang): def get_lang_class(lang):
@ -75,19 +66,16 @@ def get_lang_class(lang):
lang (unicode): Two-letter language code, e.g. 'en'. lang (unicode): Two-letter language code, e.g. 'en'.
RETURNS (Language): Language class. RETURNS (Language): Language class.
""" """
global LANGUAGES # Check if language is registered / entry point is available
# Check if an entry point is exposed for the language code if lang in registry.languages:
entry_point = get_entry_point(ENTRY_POINTS.languages, lang) return registry.languages.get(lang)
if entry_point is not None: else:
LANGUAGES[lang] = entry_point
return entry_point
if lang not in LANGUAGES:
try: try:
module = importlib.import_module(".lang.%s" % lang, "spacy") module = importlib.import_module(".lang.%s" % lang, "spacy")
except ImportError as err: except ImportError as err:
raise ImportError(Errors.E048.format(lang=lang, err=err)) raise ImportError(Errors.E048.format(lang=lang, err=err))
LANGUAGES[lang] = getattr(module, module.__all__[0]) set_lang_class(lang, getattr(module, module.__all__[0]))
return LANGUAGES[lang] return registry.languages.get(lang)
def set_lang_class(name, cls): def set_lang_class(name, cls):
@ -96,8 +84,7 @@ def set_lang_class(name, cls):
name (unicode): Name of Language class. name (unicode): Name of Language class.
cls (Language): Language class. cls (Language): Language class.
""" """
global LANGUAGES registry.languages.register(name, func=cls)
LANGUAGES[name] = cls
def get_data_path(require_exists=True): def get_data_path(require_exists=True):
@ -121,49 +108,11 @@ def set_data_path(path):
_data_path = ensure_path(path) _data_path = ensure_path(path)
def register_architecture(name, arch=None):
"""Decorator to register an architecture. An architecture is a function
that returns a Thinc Model object.
name (unicode): The name of the architecture to register.
arch (Model): Optional architecture if function is called directly and
not used as a decorator.
RETURNS (callable): Function to register architecture.
"""
global ARCHITECTURES
if arch is not None:
ARCHITECTURES[name] = arch
return arch
def do_registration(arch):
ARCHITECTURES[name] = arch
return arch
return do_registration
def make_layer(arch_config): def make_layer(arch_config):
arch_func = get_architecture(arch_config["arch"]) arch_func = registry.architectures.get(arch_config["arch"])
return arch_func(arch_config["config"]) return arch_func(arch_config["config"])
def get_architecture(name):
"""Get a model architecture function by name. Raises a KeyError if the
architecture is not found.
name (unicode): The mame of the architecture.
RETURNS (Model): The architecture.
"""
# Check if an entry point is exposed for the architecture code
entry_point = get_entry_point(ENTRY_POINTS.architectures, name)
if entry_point is not None:
ARCHITECTURES[name] = entry_point
if name not in ARCHITECTURES:
names = ", ".join(sorted(ARCHITECTURES.keys()))
raise KeyError(Errors.E174.format(name=name, names=names))
return ARCHITECTURES[name]
def ensure_path(path): def ensure_path(path):
"""Ensure string is converted to a Path. """Ensure string is converted to a Path.
@ -327,34 +276,6 @@ def get_package_path(name):
return Path(pkg.__file__).parent return Path(pkg.__file__).parent
def get_entry_points(key):
"""Get registered entry points from other packages for a given key, e.g.
'spacy_factories' and return them as a dictionary, keyed by name.
key (unicode): Entry point name.
RETURNS (dict): Entry points, keyed by name.
"""
result = {}
for entry_point in AVAILABLE_ENTRY_POINTS.get(key, []):
result[entry_point.name] = entry_point.load()
return result
def get_entry_point(key, value, default=None):
"""Check if registered entry point is available for a given name and
load it. Otherwise, return None.
key (unicode): Entry point name.
value (unicode): Name of entry point to load.
default: Optional default value to return.
RETURNS: The loaded entry point or None.
"""
for entry_point in AVAILABLE_ENTRY_POINTS.get(key, []):
if entry_point.name == value:
return entry_point.load()
return default
def is_in_jupyter(): def is_in_jupyter():
"""Check if user is running spaCy from a Jupyter notebook by detecting the """Check if user is running spaCy from a Jupyter notebook by detecting the
IPython kernel. Mainly used for the displaCy visualizer. IPython kernel. Mainly used for the displaCy visualizer.