mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 09:14:32 +03:00
Replace function registries with catalogue (#4584)
* Replace functions registries with catalogue * Update __init__.py * Fix test * Revert unrelated flag [ci skip]
This commit is contained in:
parent
0f8678c0b1
commit
09cec3e41b
|
@ -6,12 +6,12 @@ blis>=0.4.0,<0.5.0
|
|||
murmurhash>=0.28.0,<1.1.0
|
||||
wasabi>=0.4.0,<1.1.0
|
||||
srsly>=0.1.0,<1.1.0
|
||||
catalogue>=0.0.7,<1.1.0
|
||||
# Third party dependencies
|
||||
numpy>=1.15.0
|
||||
requests>=2.13.0,<3.0.0
|
||||
plac>=0.9.6,<1.2.0
|
||||
pathlib==1.0.1; python_version < "3.4"
|
||||
importlib_metadata>=0.20; python_version < "3.8"
|
||||
# Optional dependencies
|
||||
jsonschema>=2.6.0,<3.1.0
|
||||
# Development dependencies
|
||||
|
|
|
@ -48,13 +48,13 @@ install_requires =
|
|||
blis>=0.4.0,<0.5.0
|
||||
wasabi>=0.4.0,<1.1.0
|
||||
srsly>=0.1.0,<1.1.0
|
||||
catalogue>=0.0.7,<1.1.0
|
||||
# Third-party dependencies
|
||||
setuptools
|
||||
numpy>=1.15.0
|
||||
plac>=0.9.6,<1.2.0
|
||||
requests>=2.13.0,<3.0.0
|
||||
pathlib==1.0.1; python_version < "3.4"
|
||||
importlib_metadata>=0.20; python_version < "3.8"
|
||||
|
||||
[options.extras_require]
|
||||
lookups =
|
||||
|
|
|
@ -15,7 +15,7 @@ from .glossary import explain
|
|||
from .about import __version__
|
||||
from .errors import Errors, Warnings, deprecation_warning
|
||||
from . import util
|
||||
from .util import register_architecture, get_architecture
|
||||
from .util import registry
|
||||
from .language import component
|
||||
|
||||
|
||||
|
|
|
@ -36,11 +36,6 @@ try:
|
|||
except ImportError:
|
||||
cupy = None
|
||||
|
||||
try: # Python 3.8
|
||||
import importlib.metadata as importlib_metadata
|
||||
except ImportError:
|
||||
import importlib_metadata # noqa: F401
|
||||
|
||||
try:
|
||||
from thinc.neural.optimizers import Optimizer # noqa: F401
|
||||
except ImportError:
|
||||
|
|
|
@ -5,7 +5,7 @@ import uuid
|
|||
|
||||
from .templates import TPL_DEP_SVG, TPL_DEP_WORDS, TPL_DEP_ARCS, TPL_ENTS
|
||||
from .templates import TPL_ENT, TPL_ENT_RTL, TPL_FIGURE, TPL_TITLE, TPL_PAGE
|
||||
from ..util import minify_html, escape_html, get_entry_points, ENTRY_POINTS
|
||||
from ..util import minify_html, escape_html, registry
|
||||
from ..errors import Errors
|
||||
|
||||
|
||||
|
@ -242,7 +242,7 @@ class EntityRenderer(object):
|
|||
"CARDINAL": "#e4e7d2",
|
||||
"PERCENT": "#e4e7d2",
|
||||
}
|
||||
user_colors = get_entry_points(ENTRY_POINTS.displacy_colors)
|
||||
user_colors = registry.displacy_colors.get_all()
|
||||
for user_color in user_colors.values():
|
||||
colors.update(user_color)
|
||||
colors.update(options.get("colors", {}))
|
||||
|
|
|
@ -51,8 +51,8 @@ class BaseDefaults(object):
|
|||
filenames = {name: root / filename for name, filename in cls.resources}
|
||||
if LANG in cls.lex_attr_getters:
|
||||
lang = cls.lex_attr_getters[LANG](None)
|
||||
user_lookups = util.get_entry_point(util.ENTRY_POINTS.lookups, lang, {})
|
||||
filenames.update(user_lookups)
|
||||
if lang in util.registry.lookups:
|
||||
filenames.update(util.registry.lookups.get(lang))
|
||||
lookups = Lookups()
|
||||
for name, filename in filenames.items():
|
||||
data = util.load_language_data(filename)
|
||||
|
@ -155,7 +155,7 @@ class Language(object):
|
|||
100,000 characters in one text.
|
||||
RETURNS (Language): The newly constructed object.
|
||||
"""
|
||||
user_factories = util.get_entry_points(util.ENTRY_POINTS.factories)
|
||||
user_factories = util.registry.factories.get_all()
|
||||
self.factories.update(user_factories)
|
||||
self._meta = dict(meta)
|
||||
self._path = None
|
||||
|
|
|
@ -3,10 +3,10 @@ from __future__ import unicode_literals
|
|||
from thinc.api import chain
|
||||
from thinc.v2v import Maxout
|
||||
from thinc.misc import LayerNorm
|
||||
from ..util import register_architecture, make_layer
|
||||
from ..util import registry, make_layer
|
||||
|
||||
|
||||
@register_architecture("thinc.FeedForward.v1")
|
||||
@registry.architectures.register("thinc.FeedForward.v1")
|
||||
def FeedForward(config):
|
||||
layers = [make_layer(layer_cfg) for layer_cfg in config["layers"]]
|
||||
model = chain(*layers)
|
||||
|
@ -14,7 +14,7 @@ def FeedForward(config):
|
|||
return model
|
||||
|
||||
|
||||
@register_architecture("spacy.LayerNormalizedMaxout.v1")
|
||||
@registry.architectures.register("spacy.LayerNormalizedMaxout.v1")
|
||||
def LayerNormalizedMaxout(config):
|
||||
width = config["width"]
|
||||
pieces = config["pieces"]
|
||||
|
|
|
@ -6,11 +6,11 @@ from thinc.v2v import Maxout, Model
|
|||
from thinc.i2v import HashEmbed, StaticVectors
|
||||
from thinc.t2t import ExtractWindow
|
||||
from thinc.misc import Residual, LayerNorm, FeatureExtracter
|
||||
from ..util import make_layer, register_architecture
|
||||
from ..util import make_layer, registry
|
||||
from ._wire import concatenate_lists
|
||||
|
||||
|
||||
@register_architecture("spacy.Tok2Vec.v1")
|
||||
@registry.architectures.register("spacy.Tok2Vec.v1")
|
||||
def Tok2Vec(config):
|
||||
doc2feats = make_layer(config["@doc2feats"])
|
||||
embed = make_layer(config["@embed"])
|
||||
|
@ -24,13 +24,13 @@ def Tok2Vec(config):
|
|||
return tok2vec
|
||||
|
||||
|
||||
@register_architecture("spacy.Doc2Feats.v1")
|
||||
@registry.architectures.register("spacy.Doc2Feats.v1")
|
||||
def Doc2Feats(config):
|
||||
columns = config["columns"]
|
||||
return FeatureExtracter(columns)
|
||||
|
||||
|
||||
@register_architecture("spacy.MultiHashEmbed.v1")
|
||||
@registry.architectures.register("spacy.MultiHashEmbed.v1")
|
||||
def MultiHashEmbed(config):
|
||||
# For backwards compatibility with models before the architecture registry,
|
||||
# we have to be careful to get exactly the same model structure. One subtle
|
||||
|
@ -78,7 +78,7 @@ def MultiHashEmbed(config):
|
|||
return layer
|
||||
|
||||
|
||||
@register_architecture("spacy.CharacterEmbed.v1")
|
||||
@registry.architectures.register("spacy.CharacterEmbed.v1")
|
||||
def CharacterEmbed(config):
|
||||
from .. import _ml
|
||||
|
||||
|
@ -94,7 +94,7 @@ def CharacterEmbed(config):
|
|||
return model
|
||||
|
||||
|
||||
@register_architecture("spacy.MaxoutWindowEncoder.v1")
|
||||
@registry.architectures.register("spacy.MaxoutWindowEncoder.v1")
|
||||
def MaxoutWindowEncoder(config):
|
||||
nO = config["width"]
|
||||
nW = config["window_size"]
|
||||
|
@ -110,7 +110,7 @@ def MaxoutWindowEncoder(config):
|
|||
return model
|
||||
|
||||
|
||||
@register_architecture("spacy.MishWindowEncoder.v1")
|
||||
@registry.architectures.register("spacy.MishWindowEncoder.v1")
|
||||
def MishWindowEncoder(config):
|
||||
from thinc.v2v import Mish
|
||||
|
||||
|
@ -124,12 +124,12 @@ def MishWindowEncoder(config):
|
|||
return model
|
||||
|
||||
|
||||
@register_architecture("spacy.PretrainedVectors.v1")
|
||||
@registry.architectures.register("spacy.PretrainedVectors.v1")
|
||||
def PretrainedVectors(config):
|
||||
return StaticVectors(config["vectors_name"], config["width"], config["column"])
|
||||
|
||||
|
||||
@register_architecture("spacy.TorchBiLSTMEncoder.v1")
|
||||
@registry.architectures.register("spacy.TorchBiLSTMEncoder.v1")
|
||||
def TorchBiLSTMEncoder(config):
|
||||
import torch.nn
|
||||
from thinc.extra.wrappers import PyTorchWrapperRNN
|
||||
|
|
19
spacy/tests/test_architectures.py
Normal file
19
spacy/tests/test_architectures.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import pytest
|
||||
from spacy import registry
|
||||
from thinc.v2v import Affine
|
||||
from catalogue import RegistryError
|
||||
|
||||
|
||||
@registry.architectures.register("my_test_function")
|
||||
def create_model(nr_in, nr_out):
|
||||
return Affine(nr_in, nr_out)
|
||||
|
||||
|
||||
def test_get_architecture():
|
||||
arch = registry.architectures.get("my_test_function")
|
||||
assert arch is create_model
|
||||
with pytest.raises(RegistryError):
|
||||
registry.architectures.get("not_an_existing_key")
|
|
@ -1,19 +0,0 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import pytest
|
||||
from spacy import register_architecture
|
||||
from spacy import get_architecture
|
||||
from thinc.v2v import Affine
|
||||
|
||||
|
||||
@register_architecture("my_test_function")
|
||||
def create_model(nr_in, nr_out):
|
||||
return Affine(nr_in, nr_out)
|
||||
|
||||
|
||||
def test_get_architecture():
|
||||
arch = get_architecture("my_test_function")
|
||||
assert arch is create_model
|
||||
with pytest.raises(KeyError):
|
||||
get_architecture("not_an_existing_key")
|
113
spacy/util.py
113
spacy/util.py
|
@ -13,6 +13,7 @@ import functools
|
|||
import itertools
|
||||
import numpy.random
|
||||
import srsly
|
||||
import catalogue
|
||||
import sys
|
||||
|
||||
try:
|
||||
|
@ -27,29 +28,20 @@ except ImportError:
|
|||
|
||||
from .symbols import ORTH
|
||||
from .compat import cupy, CudaStream, path2str, basestring_, unicode_
|
||||
from .compat import import_file, importlib_metadata
|
||||
from .compat import import_file
|
||||
from .errors import Errors, Warnings, deprecation_warning
|
||||
|
||||
|
||||
LANGUAGES = {}
|
||||
ARCHITECTURES = {}
|
||||
_data_path = Path(__file__).parent / "data"
|
||||
_PRINT_ENV = False
|
||||
|
||||
|
||||
# NB: Ony ever call this once! If called more than ince within the
|
||||
# function, test_issue1506 hangs and it's not 100% clear why.
|
||||
AVAILABLE_ENTRY_POINTS = importlib_metadata.entry_points()
|
||||
|
||||
|
||||
class ENTRY_POINTS(object):
|
||||
"""Available entry points to register extensions."""
|
||||
|
||||
factories = "spacy_factories"
|
||||
languages = "spacy_languages"
|
||||
displacy_colors = "spacy_displacy_colors"
|
||||
lookups = "spacy_lookups"
|
||||
architectures = "spacy_architectures"
|
||||
class registry(object):
|
||||
languages = catalogue.create("spacy", "languages", entry_points=True)
|
||||
architectures = catalogue.create("spacy", "architectures", entry_points=True)
|
||||
lookups = catalogue.create("spacy", "lookups", entry_points=True)
|
||||
factories = catalogue.create("spacy", "factories", entry_points=True)
|
||||
displacy_colors = catalogue.create("spacy", "displacy_colors", entry_points=True)
|
||||
|
||||
|
||||
def set_env_log(value):
|
||||
|
@ -65,8 +57,7 @@ def lang_class_is_loaded(lang):
|
|||
lang (unicode): Two-letter language code, e.g. 'en'.
|
||||
RETURNS (bool): Whether a Language class has been loaded.
|
||||
"""
|
||||
global LANGUAGES
|
||||
return lang in LANGUAGES
|
||||
return lang in registry.languages
|
||||
|
||||
|
||||
def get_lang_class(lang):
|
||||
|
@ -75,19 +66,16 @@ def get_lang_class(lang):
|
|||
lang (unicode): Two-letter language code, e.g. 'en'.
|
||||
RETURNS (Language): Language class.
|
||||
"""
|
||||
global LANGUAGES
|
||||
# Check if an entry point is exposed for the language code
|
||||
entry_point = get_entry_point(ENTRY_POINTS.languages, lang)
|
||||
if entry_point is not None:
|
||||
LANGUAGES[lang] = entry_point
|
||||
return entry_point
|
||||
if lang not in LANGUAGES:
|
||||
# Check if language is registered / entry point is available
|
||||
if lang in registry.languages:
|
||||
return registry.languages.get(lang)
|
||||
else:
|
||||
try:
|
||||
module = importlib.import_module(".lang.%s" % lang, "spacy")
|
||||
except ImportError as err:
|
||||
raise ImportError(Errors.E048.format(lang=lang, err=err))
|
||||
LANGUAGES[lang] = getattr(module, module.__all__[0])
|
||||
return LANGUAGES[lang]
|
||||
set_lang_class(lang, getattr(module, module.__all__[0]))
|
||||
return registry.languages.get(lang)
|
||||
|
||||
|
||||
def set_lang_class(name, cls):
|
||||
|
@ -96,8 +84,7 @@ def set_lang_class(name, cls):
|
|||
name (unicode): Name of Language class.
|
||||
cls (Language): Language class.
|
||||
"""
|
||||
global LANGUAGES
|
||||
LANGUAGES[name] = cls
|
||||
registry.languages.register(name, func=cls)
|
||||
|
||||
|
||||
def get_data_path(require_exists=True):
|
||||
|
@ -121,49 +108,11 @@ def set_data_path(path):
|
|||
_data_path = ensure_path(path)
|
||||
|
||||
|
||||
def register_architecture(name, arch=None):
|
||||
"""Decorator to register an architecture. An architecture is a function
|
||||
that returns a Thinc Model object.
|
||||
|
||||
name (unicode): The name of the architecture to register.
|
||||
arch (Model): Optional architecture if function is called directly and
|
||||
not used as a decorator.
|
||||
RETURNS (callable): Function to register architecture.
|
||||
"""
|
||||
global ARCHITECTURES
|
||||
if arch is not None:
|
||||
ARCHITECTURES[name] = arch
|
||||
return arch
|
||||
|
||||
def do_registration(arch):
|
||||
ARCHITECTURES[name] = arch
|
||||
return arch
|
||||
|
||||
return do_registration
|
||||
|
||||
|
||||
def make_layer(arch_config):
|
||||
arch_func = get_architecture(arch_config["arch"])
|
||||
arch_func = registry.architectures.get(arch_config["arch"])
|
||||
return arch_func(arch_config["config"])
|
||||
|
||||
|
||||
def get_architecture(name):
|
||||
"""Get a model architecture function by name. Raises a KeyError if the
|
||||
architecture is not found.
|
||||
|
||||
name (unicode): The mame of the architecture.
|
||||
RETURNS (Model): The architecture.
|
||||
"""
|
||||
# Check if an entry point is exposed for the architecture code
|
||||
entry_point = get_entry_point(ENTRY_POINTS.architectures, name)
|
||||
if entry_point is not None:
|
||||
ARCHITECTURES[name] = entry_point
|
||||
if name not in ARCHITECTURES:
|
||||
names = ", ".join(sorted(ARCHITECTURES.keys()))
|
||||
raise KeyError(Errors.E174.format(name=name, names=names))
|
||||
return ARCHITECTURES[name]
|
||||
|
||||
|
||||
def ensure_path(path):
|
||||
"""Ensure string is converted to a Path.
|
||||
|
||||
|
@ -327,34 +276,6 @@ def get_package_path(name):
|
|||
return Path(pkg.__file__).parent
|
||||
|
||||
|
||||
def get_entry_points(key):
|
||||
"""Get registered entry points from other packages for a given key, e.g.
|
||||
'spacy_factories' and return them as a dictionary, keyed by name.
|
||||
|
||||
key (unicode): Entry point name.
|
||||
RETURNS (dict): Entry points, keyed by name.
|
||||
"""
|
||||
result = {}
|
||||
for entry_point in AVAILABLE_ENTRY_POINTS.get(key, []):
|
||||
result[entry_point.name] = entry_point.load()
|
||||
return result
|
||||
|
||||
|
||||
def get_entry_point(key, value, default=None):
|
||||
"""Check if registered entry point is available for a given name and
|
||||
load it. Otherwise, return None.
|
||||
|
||||
key (unicode): Entry point name.
|
||||
value (unicode): Name of entry point to load.
|
||||
default: Optional default value to return.
|
||||
RETURNS: The loaded entry point or None.
|
||||
"""
|
||||
for entry_point in AVAILABLE_ENTRY_POINTS.get(key, []):
|
||||
if entry_point.name == value:
|
||||
return entry_point.load()
|
||||
return default
|
||||
|
||||
|
||||
def is_in_jupyter():
|
||||
"""Check if user is running spaCy from a Jupyter notebook by detecting the
|
||||
IPython kernel. Mainly used for the displaCy visualizer.
|
||||
|
|
Loading…
Reference in New Issue
Block a user