Merge pull request #6024 from explosion/chore/registry-renaming

Ines Montani 2020-09-04 10:54:10 +02:00, committed by GitHub
commit 2189046869
24 changed files with 97 additions and 82 deletions

View File

@@ -36,7 +36,7 @@ max_length = 0
 limit = 0
 
 [training.batcher]
-@batchers = "batch_by_words.v1"
+@batchers = "spacy.batch_by_words.v1"
 discard_oversize = false
 tolerance = 0.2

View File

@@ -35,7 +35,7 @@ max_length = 0
 limit = 0
 
 [training.batcher]
-@batchers = "batch_by_words.v1"
+@batchers = "spacy.batch_by_words.v1"
 discard_oversize = false
 tolerance = 0.2

View File

@@ -29,7 +29,7 @@ name = "{{ transformer["name"] }}"
 tokenizer_config = {"use_fast": true}
 
 [components.transformer.model.get_spans]
-@span_getters = "strided_spans.v1"
+@span_getters = "spacy-transformers.strided_spans.v1"
 window = 128
 stride = 96
@@ -204,13 +204,13 @@ max_length = 0
 
 {% if use_transformer %}
 [training.batcher]
-@batchers = "batch_by_padded.v1"
+@batchers = "spacy.batch_by_padded.v1"
 discard_oversize = true
 size = 2000
 buffer = 256
 
 {%- else %}
 [training.batcher]
-@batchers = "batch_by_words.v1"
+@batchers = "spacy.batch_by_words.v1"
 discard_oversize = false
 tolerance = 0.2

View File

@@ -69,7 +69,7 @@ max_length = 2000
 limit = 0
 
 [training.batcher]
-@batchers = "batch_by_words.v1"
+@batchers = "spacy.batch_by_words.v1"
 discard_oversize = false
 tolerance = 0.2

View File

@@ -249,6 +249,12 @@ class EntityRenderer:
         colors = dict(DEFAULT_LABEL_COLORS)
         user_colors = registry.displacy_colors.get_all()
         for user_color in user_colors.values():
+            if callable(user_color):
+                # Since this comes from the function registry, we want to make
+                # sure we support functions that *return* a dict of colors
+                user_color = user_color()
+            if not isinstance(user_color, dict):
+                raise ValueError(Errors.E925.format(obj=type(user_color)))
             colors.update(user_color)
         colors.update(options.get("colors", {}))
         self.default_color = DEFAULT_ENTITY_COLOR
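With this change, entries in the `displacy_colors` registry may be functions that return a dict of colors, not just plain dicts. A minimal sketch of what that enables (the registry name and label color are made up for illustration):

```python
from spacy.util import registry

@registry.displacy_colors("my_colors.v1")
def create_colors():
    # The renderer calls this function and verifies the result is a dict,
    # raising the new E925 error otherwise.
    return {"FRUIT": "#aa9cfc"}
```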

View File

@@ -476,6 +476,8 @@ class Errors:
     E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
 
     # TODO: fix numbering after merging develop into master
+    E925 = ("Invalid color values for displaCy visualizer: expected dictionary "
+            "mapping label names to colors but got: {obj}")
     E926 = ("It looks like you're trying to modify nlp.{attr} directly. This "
             "doesn't work because it's an immutable computed property. If you "
             "need to modify the pipeline, use the built-in methods like "

View File

@@ -11,7 +11,7 @@ ItemT = TypeVar("ItemT")
 BatcherT = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
 
-@registry.batchers("batch_by_padded.v1")
+@registry.batchers("spacy.batch_by_padded.v1")
 def configure_minibatch_by_padded_size(
     *,
     size: Sizing,
@@ -46,7 +46,7 @@ def configure_minibatch_by_padded_size(
     )
 
-@registry.batchers("batch_by_words.v1")
+@registry.batchers("spacy.batch_by_words.v1")
 def configure_minibatch_by_words(
     *,
     size: Sizing,
@@ -70,7 +70,7 @@ def configure_minibatch_by_words(
     )
 
-@registry.batchers("batch_by_sequence.v1")
+@registry.batchers("spacy.batch_by_sequence.v1")
 def configure_minibatch(
     size: Sizing, get_length: Optional[Callable[[ItemT], int]] = None
 ) -> BatcherT:
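Only the registered names change here; the function signatures stay the same, so configs and lookups just need the `spacy.` prefix now. A sketch of resolving one of the renamed batchers directly (argument values are illustrative):

```python
from spacy.util import registry

# The bare name "batch_by_words.v1" no longer resolves after this commit.
make_batcher = registry.batchers.get("spacy.batch_by_words.v1")
batcher = make_batcher(
    size=100, tolerance=0.2, discard_oversize=False, get_length=None
)
```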

View File

@@ -24,7 +24,7 @@ def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model:
     return model
 
-@registry.assets.register("spacy.KBFromFile.v1")
+@registry.misc.register("spacy.KBFromFile.v1")
 def load_kb(kb_path: str) -> Callable[[Vocab], KnowledgeBase]:
     def kb_from_file(vocab):
         kb = KnowledgeBase(vocab, entity_vector_length=1)
@@ -34,7 +34,7 @@ def load_kb(kb_path: str) -> Callable[[Vocab], KnowledgeBase]:
     return kb_from_file
 
-@registry.assets.register("spacy.EmptyKB.v1")
+@registry.misc.register("spacy.EmptyKB.v1")
 def empty_kb(entity_vector_length: int) -> Callable[[Vocab], KnowledgeBase]:
     def empty_kb_factory(vocab):
         return KnowledgeBase(vocab=vocab, entity_vector_length=entity_vector_length)
@@ -42,6 +42,6 @@ def empty_kb(entity_vector_length: int) -> Callable[[Vocab], KnowledgeBase]:
     return empty_kb_factory
 
-@registry.assets.register("spacy.CandidateGenerator.v1")
+@registry.misc.register("spacy.CandidateGenerator.v1")
 def create_candidates() -> Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]:
     return get_candidates

View File

@@ -39,12 +39,12 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
     requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"],
    assigns=["token.ent_kb_id"],
     default_config={
-        "kb_loader": {"@assets": "spacy.EmptyKB.v1", "entity_vector_length": 64},
+        "kb_loader": {"@misc": "spacy.EmptyKB.v1", "entity_vector_length": 64},
         "model": DEFAULT_NEL_MODEL,
         "labels_discard": [],
         "incl_prior": True,
         "incl_context": True,
-        "get_candidates": {"@assets": "spacy.CandidateGenerator.v1"},
+        "get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
     },
 )
 def make_entity_linker(
def make_entity_linker( def make_entity_linker(

View File

@@ -14,7 +14,7 @@ LANGUAGES = ["el", "en", "fr", "nl"]
 @pytest.mark.parametrize("lang", LANGUAGES)
 def test_lemmatizer_initialize(lang, capfd):
-    @registry.assets("lemmatizer_init_lookups")
+    @registry.misc("lemmatizer_init_lookups")
     def lemmatizer_init_lookups():
         lookups = Lookups()
         lookups.add_table("lemma_lookup", {"cope": "cope"})
@@ -25,9 +25,7 @@ def test_lemmatizer_initialize(lang, capfd):
     """Test that languages can be initialized."""
     nlp = get_lang_class(lang)()
-    nlp.add_pipe(
-        "lemmatizer", config={"lookups": {"@assets": "lemmatizer_init_lookups"}}
-    )
+    nlp.add_pipe("lemmatizer", config={"lookups": {"@misc": "lemmatizer_init_lookups"}})
     # Check for stray print statements (see #3342)
     doc = nlp("test")  # noqa: F841
     captured = capfd.readouterr()

View File

@@ -31,7 +31,7 @@ def pattern_dicts():
     ]
 
-@registry.assets("attribute_ruler_patterns")
+@registry.misc("attribute_ruler_patterns")
 def attribute_ruler_patterns():
     return [
         {
@@ -86,7 +86,7 @@ def test_attributeruler_init_patterns(nlp, pattern_dicts):
     # initialize with patterns from asset
     nlp.add_pipe(
         "attribute_ruler",
-        config={"pattern_dicts": {"@assets": "attribute_ruler_patterns"}},
+        config={"pattern_dicts": {"@misc": "attribute_ruler_patterns"}},
     )
     doc = nlp("This is a test.")
     assert doc[2].lemma_ == "the"

View File

@@ -137,7 +137,7 @@ def test_kb_undefined(nlp):
 def test_kb_empty(nlp):
     """Test that the EL can't train with an empty KB"""
-    config = {"kb_loader": {"@assets": "spacy.EmptyKB.v1", "entity_vector_length": 342}}
+    config = {"kb_loader": {"@misc": "spacy.EmptyKB.v1", "entity_vector_length": 342}}
     entity_linker = nlp.add_pipe("entity_linker", config=config)
     assert len(entity_linker.kb) == 0
     with pytest.raises(ValueError):
@@ -183,7 +183,7 @@ def test_el_pipe_configuration(nlp):
     ruler = nlp.add_pipe("entity_ruler")
     ruler.add_patterns([pattern])
 
-    @registry.assets.register("myAdamKB.v1")
+    @registry.misc.register("myAdamKB.v1")
     def mykb() -> Callable[["Vocab"], KnowledgeBase]:
         def create_kb(vocab):
             kb = KnowledgeBase(vocab, entity_vector_length=1)
@@ -199,7 +199,7 @@ def test_el_pipe_configuration(nlp):
     # run an EL pipe without a trained context encoder, to check the candidate generation step only
     nlp.add_pipe(
         "entity_linker",
-        config={"kb_loader": {"@assets": "myAdamKB.v1"}, "incl_context": False},
+        config={"kb_loader": {"@misc": "myAdamKB.v1"}, "incl_context": False},
     )
     # With the default get_candidates function, matching is case-sensitive
     text = "Douglas and douglas are not the same."
@@ -211,7 +211,7 @@ def test_el_pipe_configuration(nlp):
     def get_lowercased_candidates(kb, span):
         return kb.get_alias_candidates(span.text.lower())
 
-    @registry.assets.register("spacy.LowercaseCandidateGenerator.v1")
+    @registry.misc.register("spacy.LowercaseCandidateGenerator.v1")
     def create_candidates() -> Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]:
         return get_lowercased_candidates
@@ -220,9 +220,9 @@ def test_el_pipe_configuration(nlp):
         "entity_linker",
         "entity_linker",
         config={
-            "kb_loader": {"@assets": "myAdamKB.v1"},
+            "kb_loader": {"@misc": "myAdamKB.v1"},
             "incl_context": False,
-            "get_candidates": {"@assets": "spacy.LowercaseCandidateGenerator.v1"},
+            "get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"},
         },
     )
     doc = nlp(text)
@@ -282,7 +282,7 @@ def test_append_invalid_alias(nlp):
 def test_preserving_links_asdoc(nlp):
     """Test that Span.as_doc preserves the existing entity links"""
 
-    @registry.assets.register("myLocationsKB.v1")
+    @registry.misc.register("myLocationsKB.v1")
     def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
         def create_kb(vocab):
             mykb = KnowledgeBase(vocab, entity_vector_length=1)
@@ -304,7 +304,7 @@ def test_preserving_links_asdoc(nlp):
     ]
     ruler = nlp.add_pipe("entity_ruler")
     ruler.add_patterns(patterns)
-    el_config = {"kb_loader": {"@assets": "myLocationsKB.v1"}, "incl_prior": False}
+    el_config = {"kb_loader": {"@misc": "myLocationsKB.v1"}, "incl_prior": False}
     el_pipe = nlp.add_pipe("entity_linker", config=el_config, last=True)
     el_pipe.begin_training(lambda: [])
     el_pipe.incl_context = False
@@ -387,7 +387,7 @@ def test_overfitting_IO():
         doc = nlp(text)
         train_examples.append(Example.from_dict(doc, annotation))
 
-    @registry.assets.register("myOverfittingKB.v1")
+    @registry.misc.register("myOverfittingKB.v1")
     def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
         def create_kb(vocab):
             # create artificial KB - assign same prior weight to the two russ cochran's
@@ -408,7 +408,7 @@ def test_overfitting_IO():
     # Create the Entity Linker component and add it to the pipeline
     nlp.add_pipe(
         "entity_linker",
-        config={"kb_loader": {"@assets": "myOverfittingKB.v1"}},
+        config={"kb_loader": {"@misc": "myOverfittingKB.v1"}},
         last=True,
     )

View File

@@ -13,7 +13,7 @@ def nlp():
 @pytest.fixture
 def lemmatizer(nlp):
-    @registry.assets("cope_lookups")
+    @registry.misc("cope_lookups")
     def cope_lookups():
         lookups = Lookups()
         lookups.add_table("lemma_lookup", {"cope": "cope"})
@@ -23,13 +23,13 @@ def lemmatizer(nlp):
         return lookups
 
     lemmatizer = nlp.add_pipe(
-        "lemmatizer", config={"mode": "rule", "lookups": {"@assets": "cope_lookups"}}
+        "lemmatizer", config={"mode": "rule", "lookups": {"@misc": "cope_lookups"}}
     )
     return lemmatizer
 
 def test_lemmatizer_init(nlp):
-    @registry.assets("cope_lookups")
+    @registry.misc("cope_lookups")
     def cope_lookups():
         lookups = Lookups()
         lookups.add_table("lemma_lookup", {"cope": "cope"})
@@ -39,7 +39,7 @@ def test_lemmatizer_init(nlp):
         return lookups
 
     lemmatizer = nlp.add_pipe(
-        "lemmatizer", config={"mode": "lookup", "lookups": {"@assets": "cope_lookups"}}
+        "lemmatizer", config={"mode": "lookup", "lookups": {"@misc": "cope_lookups"}}
     )
     assert isinstance(lemmatizer.lookups, Lookups)
     assert lemmatizer.mode == "lookup"
@@ -51,14 +51,14 @@ def test_lemmatizer_init(nlp):
     nlp.remove_pipe("lemmatizer")
 
-    @registry.assets("empty_lookups")
+    @registry.misc("empty_lookups")
     def empty_lookups():
         return Lookups()
 
     with pytest.raises(ValueError):
         nlp.add_pipe(
             "lemmatizer",
-            config={"mode": "lookup", "lookups": {"@assets": "empty_lookups"}},
+            config={"mode": "lookup", "lookups": {"@misc": "empty_lookups"}},
         )
@@ -79,7 +79,7 @@ def test_lemmatizer_config(nlp, lemmatizer):
 
 def test_lemmatizer_serialize(nlp, lemmatizer):
-    @registry.assets("cope_lookups")
+    @registry.misc("cope_lookups")
     def cope_lookups():
         lookups = Lookups()
         lookups.add_table("lemma_lookup", {"cope": "cope"})
@@ -90,7 +90,7 @@ def test_lemmatizer_serialize(nlp, lemmatizer):
     nlp2 = English()
     lemmatizer2 = nlp2.add_pipe(
-        "lemmatizer", config={"mode": "rule", "lookups": {"@assets": "cope_lookups"}}
+        "lemmatizer", config={"mode": "rule", "lookups": {"@misc": "cope_lookups"}}
     )
     lemmatizer2.from_bytes(lemmatizer.to_bytes())
     assert lemmatizer.to_bytes() == lemmatizer2.to_bytes()

View File

@@ -71,7 +71,7 @@ def tagger():
 def entity_linker():
     nlp = Language()
 
-    @registry.assets.register("TestIssue5230KB.v1")
+    @registry.misc.register("TestIssue5230KB.v1")
     def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
         def create_kb(vocab):
             kb = KnowledgeBase(vocab, entity_vector_length=1)
@@ -80,7 +80,7 @@ def entity_linker():
         return create_kb
 
-    config = {"kb_loader": {"@assets": "TestIssue5230KB.v1"}}
+    config = {"kb_loader": {"@misc": "TestIssue5230KB.v1"}}
     entity_linker = nlp.add_pipe("entity_linker", config=config)
     # need to add model for two reasons:
     # 1. no model leads to error in serialization,

View File

@@ -28,7 +28,7 @@ path = ${paths.train}
 path = ${paths.dev}
 
 [training.batcher]
-@batchers = "batch_by_words.v1"
+@batchers = "spacy.batch_by_words.v1"
 size = 666
 
 [nlp]

View File

@@ -85,7 +85,7 @@ def test_serialize_subclassed_kb():
             super().__init__(vocab, entity_vector_length)
             self.custom_field = custom_field
 
-    @registry.assets.register("spacy.CustomKB.v1")
+    @registry.misc.register("spacy.CustomKB.v1")
     def custom_kb(
         entity_vector_length: int, custom_field: int
     ) -> Callable[["Vocab"], KnowledgeBase]:
@@ -101,7 +101,7 @@ def test_serialize_subclassed_kb():
     nlp = English()
     config = {
         "kb_loader": {
-            "@assets": "spacy.CustomKB.v1",
+            "@misc": "spacy.CustomKB.v1",
             "entity_vector_length": 342,
             "custom_field": 666,
         }

View File

@@ -76,7 +76,7 @@ class registry(thinc.registry):
     lemmatizers = catalogue.create("spacy", "lemmatizers", entry_points=True)
     lookups = catalogue.create("spacy", "lookups", entry_points=True)
     displacy_colors = catalogue.create("spacy", "displacy_colors", entry_points=True)
-    assets = catalogue.create("spacy", "assets", entry_points=True)
+    misc = catalogue.create("spacy", "misc", entry_points=True)
     # Callback functions used to manipulate nlp object etc.
     callbacks = catalogue.create("spacy", "callbacks")
     batchers = catalogue.create("spacy", "batchers", entry_points=True)
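This one-line change is the heart of the PR: the `assets` table becomes `misc`, including its entry-point namespace. A sketch of registering and resolving a function under the new name (the function itself is hypothetical):

```python
from spacy.util import registry

@registry.misc("my_slang_dict.v1")
def create_slang_dict():
    return {"lol": "laughing out loud"}

# Config blocks like {"@misc": "my_slang_dict.v1"} resolve to this function;
# direct lookup goes through catalogue's .get().
assert registry.misc.get("my_slang_dict.v1")() == {"lol": "laughing out loud"}
```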

View File

@@ -320,7 +320,7 @@ for details and system requirements.
 > tokenizer_config = {"use_fast": true}
 >
 > [model.get_spans]
-> @span_getters = "strided_spans.v1"
+> @span_getters = "spacy-transformers.strided_spans.v1"
 > window = 128
 > stride = 96
 > ```
@@ -673,11 +673,11 @@ into the "real world". This requires 3 main components:
 > subword_features = true
 >
 > [kb_loader]
-> @assets = "spacy.EmptyKB.v1"
+> @misc = "spacy.EmptyKB.v1"
 > entity_vector_length = 64
 >
 > [get_candidates]
-> @assets = "spacy.CandidateGenerator.v1"
+> @misc = "spacy.CandidateGenerator.v1"
 > ```
 
 The `EntityLinker` model architecture is a Thinc `Model` with a

View File

@@ -271,7 +271,7 @@ training -> dropout field required
 training -> optimizer field required
 training -> optimize extra fields not permitted
 
-{'vectors': 'en_vectors_web_lg', 'seed': 0, 'accumulate_gradient': 1, 'init_tok2vec': None, 'raw_text': None, 'patience': 1600, 'max_epochs': 0, 'max_steps': 20000, 'eval_frequency': 200, 'frozen_components': [], 'optimize': None, 'batcher': {'@batchers': 'batch_by_words.v1', 'discard_oversize': False, 'tolerance': 0.2, 'get_length': None, 'size': {'@schedules': 'compounding.v1', 'start': 100, 'stop': 1000, 'compound': 1.001, 't': 0.0}}, 'dev_corpus': {'@readers': 'spacy.Corpus.v1', 'path': '', 'max_length': 0, 'gold_preproc': False, 'limit': 0}, 'score_weights': {'tag_acc': 0.5, 'dep_uas': 0.25, 'dep_las': 0.25, 'sents_f': 0.0}, 'train_corpus': {'@readers': 'spacy.Corpus.v1', 'path': '', 'max_length': 0, 'gold_preproc': False, 'limit': 0}}
+{'vectors': 'en_vectors_web_lg', 'seed': 0, 'accumulate_gradient': 1, 'init_tok2vec': None, 'raw_text': None, 'patience': 1600, 'max_epochs': 0, 'max_steps': 20000, 'eval_frequency': 200, 'frozen_components': [], 'optimize': None, 'batcher': {'@batchers': 'spacy.batch_by_words.v1', 'discard_oversize': False, 'tolerance': 0.2, 'get_length': None, 'size': {'@schedules': 'compounding.v1', 'start': 100, 'stop': 1000, 'compound': 1.001, 't': 0.0}}, 'dev_corpus': {'@readers': 'spacy.Corpus.v1', 'path': '', 'max_length': 0, 'gold_preproc': False, 'limit': 0}, 'score_weights': {'tag_acc': 0.5, 'dep_uas': 0.25, 'dep_las': 0.25, 'sents_f': 0.0}, 'train_corpus': {'@readers': 'spacy.Corpus.v1', 'path': '', 'max_length': 0, 'gold_preproc': False, 'limit': 0}}
 
 If your config contains missing values, you can run the 'init fill-config'
 command to fill in all the defaults, if possible:
@@ -361,7 +361,7 @@ Module spacy.gold.loggers
 File /path/to/spacy/gold/loggers.py (line 8)
 
 [training.batcher]
 Registry @batchers
-Name batch_by_words.v1
+Name spacy.batch_by_words.v1
 Module spacy.gold.batchers
 File /path/to/spacy/gold/batchers.py (line 49)
 [training.batcher.size]

View File

@@ -34,8 +34,8 @@ architectures and their arguments and hyperparameters.
 >     "incl_prior": True,
 >     "incl_context": True,
 >     "model": DEFAULT_NEL_MODEL,
->     "kb_loader": {'@assets': 'spacy.EmptyKB.v1', 'entity_vector_length': 64},
->     "get_candidates": {'@assets': 'spacy.CandidateGenerator.v1'},
+>     "kb_loader": {'@misc': 'spacy.EmptyKB.v1', 'entity_vector_length': 64},
+>     "get_candidates": {'@misc': 'spacy.CandidateGenerator.v1'},
 > }
 > nlp.add_pipe("entity_linker", config=config)
 > ```
@@ -66,7 +66,7 @@ https://github.com/explosion/spaCy/blob/develop/spacy/pipeline/entity_linker.py
 > entity_linker = nlp.add_pipe("entity_linker", config=config)
 >
 > # Construction via add_pipe with custom KB and candidate generation
-> config = {"kb": {"@assets": "my_kb.v1"}}
+> config = {"kb": {"@misc": "my_kb.v1"}}
 > entity_linker = nlp.add_pipe("entity_linker", config=config)
 >
 > # Construction from class

View File

@@ -307,7 +307,6 @@ factories.
 | Registry name     | Description |
 | ----------------- | ----------- |
 | `architectures`   | Registry for functions that create [model architectures](/api/architectures). Can be used to register custom model architectures and reference them in the `config.cfg`. |
-| `assets`          | Registry for data assets, knowledge bases etc. |
 | `batchers`        | Registry for training and evaluation [data batchers](#batchers). |
 | `callbacks`       | Registry for custom callbacks to [modify the `nlp` object](/usage/training#custom-code-nlp-callbacks) before training. |
 | `displacy_colors` | Registry for custom color scheme for the [`displacy` NER visualizer](/usage/visualizers). Automatically reads from [entry points](/usage/saving-loading#entry-points). |
@@ -318,6 +317,7 @@ factories.
 | `loggers`         | Registry for functions that log [training results](/usage/training). |
 | `lookups`         | Registry for large lookup tables available via `vocab.lookups`. |
 | `losses`          | Registry for functions that create [losses](https://thinc.ai/docs/api-loss). |
+| `misc`            | Registry for miscellaneous functions that return data assets, knowledge bases or anything else you may need. |
 | `optimizers`      | Registry for functions that create [optimizers](https://thinc.ai/docs/api-optimizers). |
 | `readers`         | Registry for training and evaluation data readers like [`Corpus`](/api/corpus). |
 | `schedules`       | Registry for functions that create [schedules](https://thinc.ai/docs/api-schedules). |
@@ -364,7 +364,7 @@ results to a [Weights & Biases](https://www.wandb.com/) dashboard. Instead of
 using one of the built-in loggers listed here, you can also
 [implement your own](/usage/training#custom-logging).
 
-#### spacy.ConsoleLogger.v1 {#ConsoleLogger tag="registered function"}
+#### spacy.ConsoleLogger {#ConsoleLogger tag="registered function"}
 
 > #### Example config
 >
@@ -410,7 +410,7 @@ start decreasing across epochs.
 
 </Accordion>
 
-#### spacy.WandbLogger.v1 {#WandbLogger tag="registered function"}
+#### spacy.WandbLogger {#WandbLogger tag="registered function"}
 
 > #### Installation
 >
@@ -466,7 +466,7 @@ Instead of using one of the built-in batchers listed here, you can also
 [implement your own](/usage/training#custom-code-readers-batchers), which may or
 may not use a custom schedule.
 
-#### batch_by_words.v1 {#batch_by_words tag="registered function"}
+#### batch_by_words {#batch_by_words tag="registered function"}
 
 Create minibatches of roughly a given number of words. If any examples are
 longer than the specified batch length, they will appear in a batch by
@@ -478,7 +478,7 @@ themselves, or be discarded if `discard_oversize` is set to `True`. The argument
 >
 > ```ini
 > [training.batcher]
-> @batchers = "batch_by_words.v1"
+> @batchers = "spacy.batch_by_words.v1"
 > size = 100
 > tolerance = 0.2
 > discard_oversize = false
@@ -493,13 +493,13 @@ themselves, or be discarded if `discard_oversize` is set to `True`. The argument
 | `discard_oversize` | Whether to discard sequences that by themselves exceed the tolerated size. ~~bool~~ |
 | `get_length`       | Optional function that receives a sequence item and returns its length. Defaults to the built-in `len()` if not set. ~~Optional[Callable[[Any], int]]~~ |
 
-#### batch_by_sequence.v1 {#batch_by_sequence tag="registered function"}
+#### batch_by_sequence {#batch_by_sequence tag="registered function"}
 
 > #### Example config
 >
 > ```ini
 > [training.batcher]
-> @batchers = "batch_by_sequence.v1"
+> @batchers = "spacy.batch_by_sequence.v1"
 > size = 32
 > get_length = null
 > ```
@@ -511,13 +511,13 @@ Create a batcher that creates batches of the specified size.
 
 | `size`       | The target number of items per batch. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Sequence[int]]~~ |
 | `get_length` | Optional function that receives a sequence item and returns its length. Defaults to the built-in `len()` if not set. ~~Optional[Callable[[Any], int]]~~ |
 
-#### batch_by_padded.v1 {#batch_by_padded tag="registered function"}
+#### batch_by_padded {#batch_by_padded tag="registered function"}
 
 > #### Example config
 >
 > ```ini
 > [training.batcher]
-> @batchers = "batch_by_padded.v1"
+> @batchers = "spacy.batch_by_padded.v1"
 > size = 100
 > buffer = 256
 > discard_oversize = false

View File

@@ -453,7 +453,7 @@ using the `@spacy.registry.span_getters` decorator.
 > #### Example
 >
 > ```python
-> @spacy.registry.span_getters("sent_spans.v1")
+> @spacy.registry.span_getters("custom_sent_spans")
 > def configure_get_sent_spans() -> Callable:
 >     def get_sent_spans(docs: Iterable[Doc]) -> List[List[Span]]:
 >         return [list(doc.sents) for doc in docs]
@@ -472,7 +472,7 @@ using the `@spacy.registry.span_getters` decorator.
 >
 > ```ini
 > [transformer.model.get_spans]
-> @span_getters = "doc_spans.v1"
+> @span_getters = "spacy-transformers.doc_spans.v1"
 > ```
 
 Create a span getter that uses the whole document as its spans. This is the best
@@ -485,7 +485,7 @@ texts.
 >
 > ```ini
 > [transformer.model.get_spans]
-> @span_getters = "sent_spans.v1"
+> @span_getters = "spacy-transformers.sent_spans.v1"
 > ```
 
 Create a span getter that uses sentence boundary markers to extract the spans.
@@ -500,7 +500,7 @@ more meaningful windows to attend over.
 >
 > ```ini
 > [transformer.model.get_spans]
-> @span_getters = "strided_spans.v1"
+> @span_getters = "spacy-transformers.strided_spans.v1"
 > window = 128
 > stride = 96
 > ```

View File

@@ -331,7 +331,7 @@ name = "bert-base-cased"
 tokenizer_config = {"use_fast": true}
 
 [components.transformer.model.get_spans]
-@span_getters = "doc_spans.v1"
+@span_getters = "spacy-transformers.doc_spans.v1"
 
 [components.transformer.annotation_setter]
 @annotation_setters = "spacy-transformers.null_annotation_setter.v1"
@@ -369,8 +369,9 @@ all defaults.
 
 To change any of the settings, you can edit the `config.cfg` and re-run the
 training. To change any of the functions, like the span getter, you can replace
-the name of the referenced function, e.g. `@span_getters = "sent_spans.v1"` to
-process sentences. You can also register your own functions using the
+the name of the referenced function, e.g.
+`@span_getters = "spacy-transformers.sent_spans.v1"` to process sentences. You
+can also register your own functions using the
 [`span_getters` registry](/api/top-level#registry). For instance, the following
 custom function returns [`Span`](/api/span) objects following sentence
 boundaries, unless a sentence exceeds a certain amount of tokens, in which case
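The docs' full example isn't shown in this hunk; a sketch of the kind of span getter described, assuming sub-sentence spans capped at `max_length` tokens:

```python
from typing import Callable, Iterable, List

import spacy
from spacy.tokens import Doc, Span

@spacy.registry.span_getters("custom_sent_spans")
def configure_custom_sent_spans(max_length: int) -> Callable:
    def get_custom_sent_spans(docs: Iterable[Doc]) -> List[List[Span]]:
        spans = []
        for doc in docs:
            spans.append([])
            for sent in doc.sents:
                # Emit each sentence whole, or in max_length-sized chunks.
                start = sent.start
                while start < sent.end:
                    end = min(start + max_length, sent.end)
                    spans[-1].append(doc[start:end])
                    start = end
        return spans
    return get_custom_sent_spans
```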

View File

@@ -842,12 +842,20 @@ load and train custom pipelines with custom components. A simple solution is to
 **register a function** that returns your resources. The
 [registry](/api/top-level#registry) lets you **map string names to functions**
 that create objects, so given a name and optional arguments, spaCy will know how
-to recreate the object. To register a function that returns a custom asset, you
-can use the `@spacy.registry.assets` decorator with a single argument, the name:
+to recreate the object. To register a function that returns your custom
+dictionary, you can use the `@spacy.registry.misc` decorator with a single
+argument, the name:
+
+> #### What's the misc registry?
+>
+> The [`registry`](/api/top-level#registry) provides different categories for
+> different types of functions, for example model architectures, tokenizers or
+> batchers. `misc` is intended for miscellaneous functions that don't fit
+> anywhere else.
 
 ```python
 ### Registered function for assets {highlight="1"}
-@spacy.registry.assets("acronyms.slang_dict.v1")
+@spacy.registry.misc("acronyms.slang_dict.v1")
 def create_acronyms_slang_dict():
     dictionary = {"lol": "laughing out loud", "brb": "be right back"}
     dictionary.update({value: key for key, value in dictionary.items()})
@@ -856,9 +864,9 @@ def create_acronyms_slang_dict():
 
 In your `default_config` (and later in your
 [training config](/usage/training#config)), you can now refer to the function
-registered under the name `"acronyms.slang_dict.v1"` using the `@assets` key.
-This tells spaCy how to create the value, and when your component is created,
-the result of the registered function is passed in as the key `"dictionary"`.
+registered under the name `"acronyms.slang_dict.v1"` using the `@misc` key. This
+tells spaCy how to create the value, and when your component is created, the
+result of the registered function is passed in as the key `"dictionary"`.
 
 > #### config.cfg
 >
@@ -867,22 +875,22 @@ the result of the registered function is passed in as the key `"dictionary"`.
 > factory = "acronyms"
 >
 > [components.acronyms.dictionary]
-> @assets = "acronyms.slang_dict.v1"
+> @misc = "acronyms.slang_dict.v1"
 > ```
 
 ```diff
 - default_config = {"dictionary": DICTIONARY}
-+ default_config = {"dictionary": {"@assets": "acronyms.slang_dict.v1"}}
++ default_config = {"dictionary": {"@misc": "acronyms.slang_dict.v1"}}
 ```
 Using a registered function also means that you can easily include your custom
 components in pipelines that you [train](/usage/training). To make sure spaCy
-knows where to find your custom `@assets` function, you can pass in a Python
-file via the argument `--code`. If someone else is using your component, all
-they have to do to customize the data is to register their own function and swap
-out the name. Registered functions can also take **arguments**, by the way, that
-can be defined in the config as well; you can read more about this in the docs
-on [training with custom code](/usage/training#custom-code).
+knows where to find your custom `@misc` function, you can pass in a Python file
+via the argument `--code`. If someone else is using your component, all they
+have to do to customize the data is to register their own function and swap out
+the name. Registered functions can also take **arguments**, by the way, that can
+be defined in the config as well; you can read more about this in the docs on
+[training with custom code](/usage/training#custom-code).
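As a sketch of that last point, a registered function can declare arguments that the config supplies next to the `@misc` reference (the names here are illustrative, not part of this PR):

```python
from typing import Dict

import spacy

@spacy.registry.misc("acronyms.slang_dict.v2")
def create_slang_dict(extra: Dict[str, str]) -> Dict[str, str]:
    # "extra" is filled in from the config block referencing this function.
    dictionary = {"lol": "laughing out loud", "brb": "be right back"}
    dictionary.update(extra)
    return dictionary

# In the config, arguments live next to the reference:
# [components.acronyms.dictionary]
# @misc = "acronyms.slang_dict.v2"
# [components.acronyms.dictionary.extra]
# afaik = "as far as I know"
```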
 ### Python type hints and pydantic validation {#type-hints new="3"}