Replace function registries with catalogue (#4584)

* Replace functions registries with catalogue

* Update __init__.py

* Fix test

* Revert unrelated flag [ci skip]
This commit is contained in:
Ines Montani 2019-11-07 11:45:22 +01:00 committed by GitHub
parent 0f8678c0b1
commit 09cec3e41b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 56 additions and 140 deletions

View File

@ -6,12 +6,12 @@ blis>=0.4.0,<0.5.0
murmurhash>=0.28.0,<1.1.0
wasabi>=0.4.0,<1.1.0
srsly>=0.1.0,<1.1.0
catalogue>=0.0.7,<1.1.0
# Third party dependencies
numpy>=1.15.0
requests>=2.13.0,<3.0.0
plac>=0.9.6,<1.2.0
pathlib==1.0.1; python_version < "3.4"
importlib_metadata>=0.20; python_version < "3.8"
# Optional dependencies
jsonschema>=2.6.0,<3.1.0
# Development dependencies

View File

@ -48,13 +48,13 @@ install_requires =
blis>=0.4.0,<0.5.0
wasabi>=0.4.0,<1.1.0
srsly>=0.1.0,<1.1.0
catalogue>=0.0.7,<1.1.0
# Third-party dependencies
setuptools
numpy>=1.15.0
plac>=0.9.6,<1.2.0
requests>=2.13.0,<3.0.0
pathlib==1.0.1; python_version < "3.4"
importlib_metadata>=0.20; python_version < "3.8"
[options.extras_require]
lookups =

View File

@ -15,7 +15,7 @@ from .glossary import explain
from .about import __version__
from .errors import Errors, Warnings, deprecation_warning
from . import util
from .util import register_architecture, get_architecture
from .util import registry
from .language import component

View File

@ -36,11 +36,6 @@ try:
except ImportError:
cupy = None
try: # Python 3.8
import importlib.metadata as importlib_metadata
except ImportError:
import importlib_metadata # noqa: F401
try:
from thinc.neural.optimizers import Optimizer # noqa: F401
except ImportError:

View File

@ -5,7 +5,7 @@ import uuid
from .templates import TPL_DEP_SVG, TPL_DEP_WORDS, TPL_DEP_ARCS, TPL_ENTS
from .templates import TPL_ENT, TPL_ENT_RTL, TPL_FIGURE, TPL_TITLE, TPL_PAGE
from ..util import minify_html, escape_html, get_entry_points, ENTRY_POINTS
from ..util import minify_html, escape_html, registry
from ..errors import Errors
@ -242,7 +242,7 @@ class EntityRenderer(object):
"CARDINAL": "#e4e7d2",
"PERCENT": "#e4e7d2",
}
user_colors = get_entry_points(ENTRY_POINTS.displacy_colors)
user_colors = registry.displacy_colors.get_all()
for user_color in user_colors.values():
colors.update(user_color)
colors.update(options.get("colors", {}))

View File

@ -51,8 +51,8 @@ class BaseDefaults(object):
filenames = {name: root / filename for name, filename in cls.resources}
if LANG in cls.lex_attr_getters:
lang = cls.lex_attr_getters[LANG](None)
user_lookups = util.get_entry_point(util.ENTRY_POINTS.lookups, lang, {})
filenames.update(user_lookups)
if lang in util.registry.lookups:
filenames.update(util.registry.lookups.get(lang))
lookups = Lookups()
for name, filename in filenames.items():
data = util.load_language_data(filename)
@ -155,7 +155,7 @@ class Language(object):
100,000 characters in one text.
RETURNS (Language): The newly constructed object.
"""
user_factories = util.get_entry_points(util.ENTRY_POINTS.factories)
user_factories = util.registry.factories.get_all()
self.factories.update(user_factories)
self._meta = dict(meta)
self._path = None

View File

@ -3,10 +3,10 @@ from __future__ import unicode_literals
from thinc.api import chain
from thinc.v2v import Maxout
from thinc.misc import LayerNorm
from ..util import register_architecture, make_layer
from ..util import registry, make_layer
@register_architecture("thinc.FeedForward.v1")
@registry.architectures.register("thinc.FeedForward.v1")
def FeedForward(config):
layers = [make_layer(layer_cfg) for layer_cfg in config["layers"]]
model = chain(*layers)
@ -14,7 +14,7 @@ def FeedForward(config):
return model
@register_architecture("spacy.LayerNormalizedMaxout.v1")
@registry.architectures.register("spacy.LayerNormalizedMaxout.v1")
def LayerNormalizedMaxout(config):
width = config["width"]
pieces = config["pieces"]

View File

@ -6,11 +6,11 @@ from thinc.v2v import Maxout, Model
from thinc.i2v import HashEmbed, StaticVectors
from thinc.t2t import ExtractWindow
from thinc.misc import Residual, LayerNorm, FeatureExtracter
from ..util import make_layer, register_architecture
from ..util import make_layer, registry
from ._wire import concatenate_lists
@register_architecture("spacy.Tok2Vec.v1")
@registry.architectures.register("spacy.Tok2Vec.v1")
def Tok2Vec(config):
doc2feats = make_layer(config["@doc2feats"])
embed = make_layer(config["@embed"])
@ -24,13 +24,13 @@ def Tok2Vec(config):
return tok2vec
@register_architecture("spacy.Doc2Feats.v1")
@registry.architectures.register("spacy.Doc2Feats.v1")
def Doc2Feats(config):
columns = config["columns"]
return FeatureExtracter(columns)
@register_architecture("spacy.MultiHashEmbed.v1")
@registry.architectures.register("spacy.MultiHashEmbed.v1")
def MultiHashEmbed(config):
# For backwards compatibility with models before the architecture registry,
# we have to be careful to get exactly the same model structure. One subtle
@ -78,7 +78,7 @@ def MultiHashEmbed(config):
return layer
@register_architecture("spacy.CharacterEmbed.v1")
@registry.architectures.register("spacy.CharacterEmbed.v1")
def CharacterEmbed(config):
from .. import _ml
@ -94,7 +94,7 @@ def CharacterEmbed(config):
return model
@register_architecture("spacy.MaxoutWindowEncoder.v1")
@registry.architectures.register("spacy.MaxoutWindowEncoder.v1")
def MaxoutWindowEncoder(config):
nO = config["width"]
nW = config["window_size"]
@ -110,7 +110,7 @@ def MaxoutWindowEncoder(config):
return model
@register_architecture("spacy.MishWindowEncoder.v1")
@registry.architectures.register("spacy.MishWindowEncoder.v1")
def MishWindowEncoder(config):
from thinc.v2v import Mish
@ -124,12 +124,12 @@ def MishWindowEncoder(config):
return model
@register_architecture("spacy.PretrainedVectors.v1")
@registry.architectures.register("spacy.PretrainedVectors.v1")
def PretrainedVectors(config):
return StaticVectors(config["vectors_name"], config["width"], config["column"])
@register_architecture("spacy.TorchBiLSTMEncoder.v1")
@registry.architectures.register("spacy.TorchBiLSTMEncoder.v1")
def TorchBiLSTMEncoder(config):
import torch.nn
from thinc.extra.wrappers import PyTorchWrapperRNN

View File

@ -0,0 +1,19 @@
# coding: utf8
from __future__ import unicode_literals
import pytest
from spacy import registry
from thinc.v2v import Affine
from catalogue import RegistryError
@registry.architectures.register("my_test_function")
def create_model(nr_in, nr_out):
return Affine(nr_in, nr_out)
def test_get_architecture():
arch = registry.architectures.get("my_test_function")
assert arch is create_model
with pytest.raises(RegistryError):
registry.architectures.get("not_an_existing_key")

View File

@ -1,19 +0,0 @@
# coding: utf8
from __future__ import unicode_literals
import pytest
from spacy import register_architecture
from spacy import get_architecture
from thinc.v2v import Affine
@register_architecture("my_test_function")
def create_model(nr_in, nr_out):
return Affine(nr_in, nr_out)
def test_get_architecture():
arch = get_architecture("my_test_function")
assert arch is create_model
with pytest.raises(KeyError):
get_architecture("not_an_existing_key")

View File

@ -13,6 +13,7 @@ import functools
import itertools
import numpy.random
import srsly
import catalogue
import sys
try:
@ -27,29 +28,20 @@ except ImportError:
from .symbols import ORTH
from .compat import cupy, CudaStream, path2str, basestring_, unicode_
from .compat import import_file, importlib_metadata
from .compat import import_file
from .errors import Errors, Warnings, deprecation_warning
LANGUAGES = {}
ARCHITECTURES = {}
_data_path = Path(__file__).parent / "data"
_PRINT_ENV = False
# NB: Ony ever call this once! If called more than ince within the
# function, test_issue1506 hangs and it's not 100% clear why.
AVAILABLE_ENTRY_POINTS = importlib_metadata.entry_points()
class ENTRY_POINTS(object):
"""Available entry points to register extensions."""
factories = "spacy_factories"
languages = "spacy_languages"
displacy_colors = "spacy_displacy_colors"
lookups = "spacy_lookups"
architectures = "spacy_architectures"
class registry(object):
languages = catalogue.create("spacy", "languages", entry_points=True)
architectures = catalogue.create("spacy", "architectures", entry_points=True)
lookups = catalogue.create("spacy", "lookups", entry_points=True)
factories = catalogue.create("spacy", "factories", entry_points=True)
displacy_colors = catalogue.create("spacy", "displacy_colors", entry_points=True)
def set_env_log(value):
@ -65,8 +57,7 @@ def lang_class_is_loaded(lang):
lang (unicode): Two-letter language code, e.g. 'en'.
RETURNS (bool): Whether a Language class has been loaded.
"""
global LANGUAGES
return lang in LANGUAGES
return lang in registry.languages
def get_lang_class(lang):
@ -75,19 +66,16 @@ def get_lang_class(lang):
lang (unicode): Two-letter language code, e.g. 'en'.
RETURNS (Language): Language class.
"""
global LANGUAGES
# Check if an entry point is exposed for the language code
entry_point = get_entry_point(ENTRY_POINTS.languages, lang)
if entry_point is not None:
LANGUAGES[lang] = entry_point
return entry_point
if lang not in LANGUAGES:
# Check if language is registered / entry point is available
if lang in registry.languages:
return registry.languages.get(lang)
else:
try:
module = importlib.import_module(".lang.%s" % lang, "spacy")
except ImportError as err:
raise ImportError(Errors.E048.format(lang=lang, err=err))
LANGUAGES[lang] = getattr(module, module.__all__[0])
return LANGUAGES[lang]
set_lang_class(lang, getattr(module, module.__all__[0]))
return registry.languages.get(lang)
def set_lang_class(name, cls):
@ -96,8 +84,7 @@ def set_lang_class(name, cls):
name (unicode): Name of Language class.
cls (Language): Language class.
"""
global LANGUAGES
LANGUAGES[name] = cls
registry.languages.register(name, func=cls)
def get_data_path(require_exists=True):
@ -121,49 +108,11 @@ def set_data_path(path):
_data_path = ensure_path(path)
def register_architecture(name, arch=None):
"""Decorator to register an architecture. An architecture is a function
that returns a Thinc Model object.
name (unicode): The name of the architecture to register.
arch (Model): Optional architecture if function is called directly and
not used as a decorator.
RETURNS (callable): Function to register architecture.
"""
global ARCHITECTURES
if arch is not None:
ARCHITECTURES[name] = arch
return arch
def do_registration(arch):
ARCHITECTURES[name] = arch
return arch
return do_registration
def make_layer(arch_config):
arch_func = get_architecture(arch_config["arch"])
arch_func = registry.architectures.get(arch_config["arch"])
return arch_func(arch_config["config"])
def get_architecture(name):
"""Get a model architecture function by name. Raises a KeyError if the
architecture is not found.
name (unicode): The mame of the architecture.
RETURNS (Model): The architecture.
"""
# Check if an entry point is exposed for the architecture code
entry_point = get_entry_point(ENTRY_POINTS.architectures, name)
if entry_point is not None:
ARCHITECTURES[name] = entry_point
if name not in ARCHITECTURES:
names = ", ".join(sorted(ARCHITECTURES.keys()))
raise KeyError(Errors.E174.format(name=name, names=names))
return ARCHITECTURES[name]
def ensure_path(path):
"""Ensure string is converted to a Path.
@ -327,34 +276,6 @@ def get_package_path(name):
return Path(pkg.__file__).parent
def get_entry_points(key):
"""Get registered entry points from other packages for a given key, e.g.
'spacy_factories' and return them as a dictionary, keyed by name.
key (unicode): Entry point name.
RETURNS (dict): Entry points, keyed by name.
"""
result = {}
for entry_point in AVAILABLE_ENTRY_POINTS.get(key, []):
result[entry_point.name] = entry_point.load()
return result
def get_entry_point(key, value, default=None):
"""Check if registered entry point is available for a given name and
load it. Otherwise, return None.
key (unicode): Entry point name.
value (unicode): Name of entry point to load.
default: Optional default value to return.
RETURNS: The loaded entry point or None.
"""
for entry_point in AVAILABLE_ENTRY_POINTS.get(key, []):
if entry_point.name == value:
return entry_point.load()
return default
def is_in_jupyter():
"""Check if user is running spaCy from a Jupyter notebook by detecting the
IPython kernel. Mainly used for the displaCy visualizer.