Merge pull request #7257 from svlandeg/fix/registry_consistency

This commit is contained in:
Ines Montani 2021-03-03 23:14:19 +11:00 committed by GitHub
commit ada4cdbd71
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 37 additions and 37 deletions

View File

@ -8,7 +8,7 @@ from ...kb import KnowledgeBase, Candidate, get_candidates
from ...vocab import Vocab from ...vocab import Vocab
@registry.architectures.register("spacy.EntityLinker.v1") @registry.architectures("spacy.EntityLinker.v1")
def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model: def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model:
with Model.define_operators({">>": chain, "**": clone}): with Model.define_operators({">>": chain, "**": clone}):
token_width = tok2vec.get_dim("nO") token_width = tok2vec.get_dim("nO")
@ -25,7 +25,7 @@ def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model:
return model return model
@registry.misc.register("spacy.KBFromFile.v1") @registry.misc("spacy.KBFromFile.v1")
def load_kb(kb_path: Path) -> Callable[[Vocab], KnowledgeBase]: def load_kb(kb_path: Path) -> Callable[[Vocab], KnowledgeBase]:
def kb_from_file(vocab): def kb_from_file(vocab):
kb = KnowledgeBase(vocab, entity_vector_length=1) kb = KnowledgeBase(vocab, entity_vector_length=1)
@ -35,7 +35,7 @@ def load_kb(kb_path: Path) -> Callable[[Vocab], KnowledgeBase]:
return kb_from_file return kb_from_file
@registry.misc.register("spacy.EmptyKB.v1") @registry.misc("spacy.EmptyKB.v1")
def empty_kb(entity_vector_length: int) -> Callable[[Vocab], KnowledgeBase]: def empty_kb(entity_vector_length: int) -> Callable[[Vocab], KnowledgeBase]:
def empty_kb_factory(vocab): def empty_kb_factory(vocab):
return KnowledgeBase(vocab=vocab, entity_vector_length=entity_vector_length) return KnowledgeBase(vocab=vocab, entity_vector_length=entity_vector_length)
@ -43,6 +43,6 @@ def empty_kb(entity_vector_length: int) -> Callable[[Vocab], KnowledgeBase]:
return empty_kb_factory return empty_kb_factory
@registry.misc.register("spacy.CandidateGenerator.v1") @registry.misc("spacy.CandidateGenerator.v1")
def create_candidates() -> Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]: def create_candidates() -> Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]:
return get_candidates return get_candidates

View File

@ -16,7 +16,7 @@ if TYPE_CHECKING:
from ...tokens import Doc # noqa: F401 from ...tokens import Doc # noqa: F401
@registry.architectures.register("spacy.PretrainVectors.v1") @registry.architectures("spacy.PretrainVectors.v1")
def create_pretrain_vectors( def create_pretrain_vectors(
maxout_pieces: int, hidden_size: int, loss: str maxout_pieces: int, hidden_size: int, loss: str
) -> Callable[["Vocab", Model], Model]: ) -> Callable[["Vocab", Model], Model]:
@ -40,7 +40,7 @@ def create_pretrain_vectors(
return create_vectors_objective return create_vectors_objective
@registry.architectures.register("spacy.PretrainCharacters.v1") @registry.architectures("spacy.PretrainCharacters.v1")
def create_pretrain_characters( def create_pretrain_characters(
maxout_pieces: int, hidden_size: int, n_characters: int maxout_pieces: int, hidden_size: int, n_characters: int
) -> Callable[["Vocab", Model], Model]: ) -> Callable[["Vocab", Model], Model]:

View File

@ -10,7 +10,7 @@ from ..tb_framework import TransitionModel
from ...tokens import Doc from ...tokens import Doc
@registry.architectures.register("spacy.TransitionBasedParser.v1") @registry.architectures("spacy.TransitionBasedParser.v1")
def transition_parser_v1( def transition_parser_v1(
tok2vec: Model[List[Doc], List[Floats2d]], tok2vec: Model[List[Doc], List[Floats2d]],
state_type: Literal["parser", "ner"], state_type: Literal["parser", "ner"],
@ -31,7 +31,7 @@ def transition_parser_v1(
) )
@registry.architectures.register("spacy.TransitionBasedParser.v2") @registry.architectures("spacy.TransitionBasedParser.v2")
def transition_parser_v2( def transition_parser_v2(
tok2vec: Model[List[Doc], List[Floats2d]], tok2vec: Model[List[Doc], List[Floats2d]],
state_type: Literal["parser", "ner"], state_type: Literal["parser", "ner"],

View File

@ -6,7 +6,7 @@ from ...util import registry
from ...tokens import Doc from ...tokens import Doc
@registry.architectures.register("spacy.Tagger.v1") @registry.architectures("spacy.Tagger.v1")
def build_tagger_model( def build_tagger_model(
tok2vec: Model[List[Doc], List[Floats2d]], nO: Optional[int] = None tok2vec: Model[List[Doc], List[Floats2d]], nO: Optional[int] = None
) -> Model[List[Doc], List[Floats2d]]: ) -> Model[List[Doc], List[Floats2d]]:

View File

@ -15,7 +15,7 @@ from ...tokens import Doc
from .tok2vec import get_tok2vec_width from .tok2vec import get_tok2vec_width
@registry.architectures.register("spacy.TextCatCNN.v1") @registry.architectures("spacy.TextCatCNN.v1")
def build_simple_cnn_text_classifier( def build_simple_cnn_text_classifier(
tok2vec: Model, exclusive_classes: bool, nO: Optional[int] = None tok2vec: Model, exclusive_classes: bool, nO: Optional[int] = None
) -> Model[List[Doc], Floats2d]: ) -> Model[List[Doc], Floats2d]:
@ -41,7 +41,7 @@ def build_simple_cnn_text_classifier(
return model return model
@registry.architectures.register("spacy.TextCatBOW.v1") @registry.architectures("spacy.TextCatBOW.v1")
def build_bow_text_classifier( def build_bow_text_classifier(
exclusive_classes: bool, exclusive_classes: bool,
ngram_size: int, ngram_size: int,
@ -60,7 +60,7 @@ def build_bow_text_classifier(
return model return model
@registry.architectures.register("spacy.TextCatEnsemble.v2") @registry.architectures("spacy.TextCatEnsemble.v2")
def build_text_classifier_v2( def build_text_classifier_v2(
tok2vec: Model[List[Doc], List[Floats2d]], tok2vec: Model[List[Doc], List[Floats2d]],
linear_model: Model[List[Doc], Floats2d], linear_model: Model[List[Doc], Floats2d],
@ -112,7 +112,7 @@ def init_ensemble_textcat(model, X, Y) -> Model:
return model return model
@registry.architectures.register("spacy.TextCatLowData.v1") @registry.architectures("spacy.TextCatLowData.v1")
def build_text_classifier_lowdata( def build_text_classifier_lowdata(
width: int, dropout: Optional[float], nO: Optional[int] = None width: int, dropout: Optional[float], nO: Optional[int] = None
) -> Model[List[Doc], Floats2d]: ) -> Model[List[Doc], Floats2d]:

View File

@ -14,7 +14,7 @@ from ...pipeline.tok2vec import Tok2VecListener
from ...attrs import intify_attr from ...attrs import intify_attr
@registry.architectures.register("spacy.Tok2VecListener.v1") @registry.architectures("spacy.Tok2VecListener.v1")
def tok2vec_listener_v1(width: int, upstream: str = "*"): def tok2vec_listener_v1(width: int, upstream: str = "*"):
tok2vec = Tok2VecListener(upstream_name=upstream, width=width) tok2vec = Tok2VecListener(upstream_name=upstream, width=width)
return tok2vec return tok2vec
@ -31,7 +31,7 @@ def get_tok2vec_width(model: Model):
return nO return nO
@registry.architectures.register("spacy.HashEmbedCNN.v1") @registry.architectures("spacy.HashEmbedCNN.v1")
def build_hash_embed_cnn_tok2vec( def build_hash_embed_cnn_tok2vec(
*, *,
width: int, width: int,
@ -87,7 +87,7 @@ def build_hash_embed_cnn_tok2vec(
) )
@registry.architectures.register("spacy.Tok2Vec.v2") @registry.architectures("spacy.Tok2Vec.v2")
def build_Tok2Vec_model( def build_Tok2Vec_model(
embed: Model[List[Doc], List[Floats2d]], embed: Model[List[Doc], List[Floats2d]],
encode: Model[List[Floats2d], List[Floats2d]], encode: Model[List[Floats2d], List[Floats2d]],
@ -108,7 +108,7 @@ def build_Tok2Vec_model(
return tok2vec return tok2vec
@registry.architectures.register("spacy.MultiHashEmbed.v1") @registry.architectures("spacy.MultiHashEmbed.v1")
def MultiHashEmbed( def MultiHashEmbed(
width: int, width: int,
attrs: List[Union[str, int]], attrs: List[Union[str, int]],
@ -182,7 +182,7 @@ def MultiHashEmbed(
return model return model
@registry.architectures.register("spacy.CharacterEmbed.v1") @registry.architectures("spacy.CharacterEmbed.v1")
def CharacterEmbed( def CharacterEmbed(
width: int, width: int,
rows: int, rows: int,
@ -255,7 +255,7 @@ def CharacterEmbed(
return model return model
@registry.architectures.register("spacy.MaxoutWindowEncoder.v2") @registry.architectures("spacy.MaxoutWindowEncoder.v2")
def MaxoutWindowEncoder( def MaxoutWindowEncoder(
width: int, window_size: int, maxout_pieces: int, depth: int width: int, window_size: int, maxout_pieces: int, depth: int
) -> Model[List[Floats2d], List[Floats2d]]: ) -> Model[List[Floats2d], List[Floats2d]]:
@ -287,7 +287,7 @@ def MaxoutWindowEncoder(
return with_array(model, pad=receptive_field) return with_array(model, pad=receptive_field)
@registry.architectures.register("spacy.MishWindowEncoder.v2") @registry.architectures("spacy.MishWindowEncoder.v2")
def MishWindowEncoder( def MishWindowEncoder(
width: int, window_size: int, depth: int width: int, window_size: int, depth: int
) -> Model[List[Floats2d], List[Floats2d]]: ) -> Model[List[Floats2d], List[Floats2d]]:
@ -310,7 +310,7 @@ def MishWindowEncoder(
return with_array(model) return with_array(model)
@registry.architectures.register("spacy.TorchBiLSTMEncoder.v1") @registry.architectures("spacy.TorchBiLSTMEncoder.v1")
def BiLSTMEncoder( def BiLSTMEncoder(
width: int, depth: int, dropout: float width: int, depth: int, dropout: float
) -> Model[List[Floats2d], List[Floats2d]]: ) -> Model[List[Floats2d], List[Floats2d]]:

View File

@ -230,7 +230,7 @@ def test_el_pipe_configuration(nlp):
def get_lowercased_candidates(kb, span): def get_lowercased_candidates(kb, span):
return kb.get_alias_candidates(span.text.lower()) return kb.get_alias_candidates(span.text.lower())
@registry.misc.register("spacy.LowercaseCandidateGenerator.v1") @registry.misc("spacy.LowercaseCandidateGenerator.v1")
def create_candidates() -> Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]: def create_candidates() -> Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]:
return get_lowercased_candidates return get_lowercased_candidates

View File

@ -160,7 +160,7 @@ subword_features = false
""" """
@registry.architectures.register("my_test_parser") @registry.architectures("my_test_parser")
def my_parser(): def my_parser():
tok2vec = build_Tok2Vec_model( tok2vec = build_Tok2Vec_model(
MultiHashEmbed( MultiHashEmbed(

View File

@ -108,7 +108,7 @@ def test_serialize_subclassed_kb():
super().__init__(vocab, entity_vector_length) super().__init__(vocab, entity_vector_length)
self.custom_field = custom_field self.custom_field = custom_field
@registry.misc.register("spacy.CustomKB.v1") @registry.misc("spacy.CustomKB.v1")
def custom_kb( def custom_kb(
entity_vector_length: int, custom_field: int entity_vector_length: int, custom_field: int
) -> Callable[["Vocab"], KnowledgeBase]: ) -> Callable[["Vocab"], KnowledgeBase]:

View File

@ -4,12 +4,12 @@ from thinc.api import Linear
from catalogue import RegistryError from catalogue import RegistryError
@registry.architectures.register("my_test_function")
def create_model(nr_in, nr_out):
return Linear(nr_in, nr_out)
def test_get_architecture(): def test_get_architecture():
@registry.architectures("my_test_function")
def create_model(nr_in, nr_out):
return Linear(nr_in, nr_out)
arch = registry.architectures.get("my_test_function") arch = registry.architectures.get("my_test_function")
assert arch is create_model assert arch is create_model
with pytest.raises(RegistryError): with pytest.raises(RegistryError):

View File

@ -27,7 +27,7 @@ def test_readers():
factory = "textcat" factory = "textcat"
""" """
@registry.readers.register("myreader.v1") @registry.readers("myreader.v1")
def myreader() -> Dict[str, Callable[[Language, str], Iterable[Example]]]: def myreader() -> Dict[str, Callable[[Language, str], Iterable[Example]]]:
annots = {"cats": {"POS": 1.0, "NEG": 0.0}} annots = {"cats": {"POS": 1.0, "NEG": 0.0}}

View File

@ -19,7 +19,7 @@ spaCy's built-in architectures that are used for different NLP tasks. All
trainable [built-in components](/api#architecture-pipeline) expect a `model` trainable [built-in components](/api#architecture-pipeline) expect a `model`
argument defined in the config and document their default architecture. argument defined in the config and document their default architecture.
Custom architectures can be registered using the Custom architectures can be registered using the
[`@spacy.registry.architectures`](/api/top-level#regsitry) decorator and used as [`@spacy.registry.architectures`](/api/top-level#registry) decorator and used as
part of the [training config](/usage/training#custom-functions). Also see the part of the [training config](/usage/training#custom-functions). Also see the
usage documentation on usage documentation on
[layers and model architectures](/usage/layers-architectures). [layers and model architectures](/usage/layers-architectures).

View File

@ -15,7 +15,7 @@ next: /usage/projects
> ```python > ```python
> from thinc.api import Model, chain > from thinc.api import Model, chain
> >
> @spacy.registry.architectures.register("model.v1") > @spacy.registry.architectures("model.v1")
> def build_model(width: int, classes: int) -> Model: > def build_model(width: int, classes: int) -> Model:
> tok2vec = build_tok2vec(width) > tok2vec = build_tok2vec(width)
> output_layer = build_output_layer(width, classes) > output_layer = build_output_layer(width, classes)
@ -563,7 +563,7 @@ matrix** (~~Floats2d~~) of predictions:
```python ```python
### The model architecture ### The model architecture
@spacy.registry.architectures.register("rel_model.v1") @spacy.registry.architectures("rel_model.v1")
def create_relation_model(...) -> Model[List[Doc], Floats2d]: def create_relation_model(...) -> Model[List[Doc], Floats2d]:
model = ... # 👈 model will go here model = ... # 👈 model will go here
return model return model
@ -589,7 +589,7 @@ transforms the instance tensor into a final tensor holding the predictions:
```python ```python
### The model architecture {highlight="6"} ### The model architecture {highlight="6"}
@spacy.registry.architectures.register("rel_model.v1") @spacy.registry.architectures("rel_model.v1")
def create_relation_model( def create_relation_model(
create_instance_tensor: Model[List[Doc], Floats2d], create_instance_tensor: Model[List[Doc], Floats2d],
classification_layer: Model[Floats2d, Floats2d], classification_layer: Model[Floats2d, Floats2d],
@ -613,7 +613,7 @@ The `classification_layer` could be something like a
```python ```python
### The classification layer ### The classification layer
@spacy.registry.architectures.register("rel_classification_layer.v1") @spacy.registry.architectures("rel_classification_layer.v1")
def create_classification_layer( def create_classification_layer(
nO: int = None, nI: int = None nO: int = None, nI: int = None
) -> Model[Floats2d, Floats2d]: ) -> Model[Floats2d, Floats2d]:
@ -650,7 +650,7 @@ that has the full implementation.
```python ```python
### The layer that creates the instance tensor ### The layer that creates the instance tensor
@spacy.registry.architectures.register("rel_instance_tensor.v1") @spacy.registry.architectures("rel_instance_tensor.v1")
def create_tensors( def create_tensors(
tok2vec: Model[List[Doc], List[Floats2d]], tok2vec: Model[List[Doc], List[Floats2d]],
pooling: Model[Ragged, Floats2d], pooling: Model[Ragged, Floats2d],
@ -731,7 +731,7 @@ are within a **maximum distance** (in number of tokens) of each other:
```python ```python
### Candidate generation ### Candidate generation
@spacy.registry.misc.register("rel_instance_generator.v1") @spacy.registry.misc("rel_instance_generator.v1")
def create_instances(max_length: int) -> Callable[[Doc], List[Tuple[Span, Span]]]: def create_instances(max_length: int) -> Callable[[Doc], List[Tuple[Span, Span]]]:
def get_candidates(doc: "Doc") -> List[Tuple[Span, Span]]: def get_candidates(doc: "Doc") -> List[Tuple[Span, Span]]:
candidates = [] candidates = []