consistently use registry as callable

This commit is contained in:
svlandeg 2021-03-02 17:56:28 +01:00
parent 212f0e779e
commit d900c55061
12 changed files with 36 additions and 36 deletions

View File

@ -8,7 +8,7 @@ from ...kb import KnowledgeBase, Candidate, get_candidates
from ...vocab import Vocab
@registry.architectures.register("spacy.EntityLinker.v1")
@registry.architectures("spacy.EntityLinker.v1")
def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model:
with Model.define_operators({">>": chain, "**": clone}):
token_width = tok2vec.get_dim("nO")
@ -25,7 +25,7 @@ def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model:
return model
@registry.misc.register("spacy.KBFromFile.v1")
@registry.misc("spacy.KBFromFile.v1")
def load_kb(kb_path: Path) -> Callable[[Vocab], KnowledgeBase]:
def kb_from_file(vocab):
kb = KnowledgeBase(vocab, entity_vector_length=1)
@ -35,7 +35,7 @@ def load_kb(kb_path: Path) -> Callable[[Vocab], KnowledgeBase]:
return kb_from_file
@registry.misc.register("spacy.EmptyKB.v1")
@registry.misc("spacy.EmptyKB.v1")
def empty_kb(entity_vector_length: int) -> Callable[[Vocab], KnowledgeBase]:
def empty_kb_factory(vocab):
return KnowledgeBase(vocab=vocab, entity_vector_length=entity_vector_length)
@ -43,6 +43,6 @@ def empty_kb(entity_vector_length: int) -> Callable[[Vocab], KnowledgeBase]:
return empty_kb_factory
@registry.misc.register("spacy.CandidateGenerator.v1")
@registry.misc("spacy.CandidateGenerator.v1")
def create_candidates() -> Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]:
return get_candidates

View File

@ -16,7 +16,7 @@ if TYPE_CHECKING:
from ...tokens import Doc # noqa: F401
@registry.architectures.register("spacy.PretrainVectors.v1")
@registry.architectures("spacy.PretrainVectors.v1")
def create_pretrain_vectors(
maxout_pieces: int, hidden_size: int, loss: str
) -> Callable[["Vocab", Model], Model]:
@ -40,7 +40,7 @@ def create_pretrain_vectors(
return create_vectors_objective
@registry.architectures.register("spacy.PretrainCharacters.v1")
@registry.architectures("spacy.PretrainCharacters.v1")
def create_pretrain_characters(
maxout_pieces: int, hidden_size: int, n_characters: int
) -> Callable[["Vocab", Model], Model]:

View File

@ -10,7 +10,7 @@ from ..tb_framework import TransitionModel
from ...tokens import Doc
@registry.architectures.register("spacy.TransitionBasedParser.v1")
@registry.architectures("spacy.TransitionBasedParser.v1")
def transition_parser_v1(
tok2vec: Model[List[Doc], List[Floats2d]],
state_type: Literal["parser", "ner"],
@ -31,7 +31,7 @@ def transition_parser_v1(
)
@registry.architectures.register("spacy.TransitionBasedParser.v2")
@registry.architectures("spacy.TransitionBasedParser.v2")
def transition_parser_v2(
tok2vec: Model[List[Doc], List[Floats2d]],
state_type: Literal["parser", "ner"],

View File

@ -6,7 +6,7 @@ from ...util import registry
from ...tokens import Doc
@registry.architectures.register("spacy.Tagger.v1")
@registry.architectures("spacy.Tagger.v1")
def build_tagger_model(
tok2vec: Model[List[Doc], List[Floats2d]], nO: Optional[int] = None
) -> Model[List[Doc], List[Floats2d]]:

View File

@ -15,7 +15,7 @@ from ...tokens import Doc
from .tok2vec import get_tok2vec_width
@registry.architectures.register("spacy.TextCatCNN.v1")
@registry.architectures("spacy.TextCatCNN.v1")
def build_simple_cnn_text_classifier(
tok2vec: Model, exclusive_classes: bool, nO: Optional[int] = None
) -> Model[List[Doc], Floats2d]:
@ -41,7 +41,7 @@ def build_simple_cnn_text_classifier(
return model
@registry.architectures.register("spacy.TextCatBOW.v1")
@registry.architectures("spacy.TextCatBOW.v1")
def build_bow_text_classifier(
exclusive_classes: bool,
ngram_size: int,
@ -60,7 +60,7 @@ def build_bow_text_classifier(
return model
@registry.architectures.register("spacy.TextCatEnsemble.v2")
@registry.architectures("spacy.TextCatEnsemble.v2")
def build_text_classifier_v2(
tok2vec: Model[List[Doc], List[Floats2d]],
linear_model: Model[List[Doc], Floats2d],
@ -112,7 +112,7 @@ def init_ensemble_textcat(model, X, Y) -> Model:
return model
@registry.architectures.register("spacy.TextCatLowData.v1")
@registry.architectures("spacy.TextCatLowData.v1")
def build_text_classifier_lowdata(
width: int, dropout: Optional[float], nO: Optional[int] = None
) -> Model[List[Doc], Floats2d]:

View File

@ -14,7 +14,7 @@ from ...pipeline.tok2vec import Tok2VecListener
from ...attrs import intify_attr
@registry.architectures.register("spacy.Tok2VecListener.v1")
@registry.architectures("spacy.Tok2VecListener.v1")
def tok2vec_listener_v1(width: int, upstream: str = "*"):
tok2vec = Tok2VecListener(upstream_name=upstream, width=width)
return tok2vec
@ -31,7 +31,7 @@ def get_tok2vec_width(model: Model):
return nO
@registry.architectures.register("spacy.HashEmbedCNN.v1")
@registry.architectures("spacy.HashEmbedCNN.v1")
def build_hash_embed_cnn_tok2vec(
*,
width: int,
@ -87,7 +87,7 @@ def build_hash_embed_cnn_tok2vec(
)
@registry.architectures.register("spacy.Tok2Vec.v2")
@registry.architectures("spacy.Tok2Vec.v2")
def build_Tok2Vec_model(
embed: Model[List[Doc], List[Floats2d]],
encode: Model[List[Floats2d], List[Floats2d]],
@ -108,7 +108,7 @@ def build_Tok2Vec_model(
return tok2vec
@registry.architectures.register("spacy.MultiHashEmbed.v1")
@registry.architectures("spacy.MultiHashEmbed.v1")
def MultiHashEmbed(
width: int,
attrs: List[Union[str, int]],
@ -182,7 +182,7 @@ def MultiHashEmbed(
return model
@registry.architectures.register("spacy.CharacterEmbed.v1")
@registry.architectures("spacy.CharacterEmbed.v1")
def CharacterEmbed(
width: int,
rows: int,
@ -255,7 +255,7 @@ def CharacterEmbed(
return model
@registry.architectures.register("spacy.MaxoutWindowEncoder.v2")
@registry.architectures("spacy.MaxoutWindowEncoder.v2")
def MaxoutWindowEncoder(
width: int, window_size: int, maxout_pieces: int, depth: int
) -> Model[List[Floats2d], List[Floats2d]]:
@ -287,7 +287,7 @@ def MaxoutWindowEncoder(
return with_array(model, pad=receptive_field)
@registry.architectures.register("spacy.MishWindowEncoder.v2")
@registry.architectures("spacy.MishWindowEncoder.v2")
def MishWindowEncoder(
width: int, window_size: int, depth: int
) -> Model[List[Floats2d], List[Floats2d]]:
@ -310,7 +310,7 @@ def MishWindowEncoder(
return with_array(model)
@registry.architectures.register("spacy.TorchBiLSTMEncoder.v1")
@registry.architectures("spacy.TorchBiLSTMEncoder.v1")
def BiLSTMEncoder(
width: int, depth: int, dropout: float
) -> Model[List[Floats2d], List[Floats2d]]:

View File

@ -230,7 +230,7 @@ def test_el_pipe_configuration(nlp):
def get_lowercased_candidates(kb, span):
return kb.get_alias_candidates(span.text.lower())
@registry.misc.register("spacy.LowercaseCandidateGenerator.v1")
@registry.misc("spacy.LowercaseCandidateGenerator.v1")
def create_candidates() -> Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]:
return get_lowercased_candidates

View File

@ -160,7 +160,7 @@ subword_features = false
"""
@registry.architectures.register("my_test_parser")
@registry.architectures("my_test_parser")
def my_parser():
tok2vec = build_Tok2Vec_model(
MultiHashEmbed(

View File

@ -108,7 +108,7 @@ def test_serialize_subclassed_kb():
super().__init__(vocab, entity_vector_length)
self.custom_field = custom_field
@registry.misc.register("spacy.CustomKB.v1")
@registry.misc("spacy.CustomKB.v1")
def custom_kb(
entity_vector_length: int, custom_field: int
) -> Callable[["Vocab"], KnowledgeBase]:

View File

@ -4,12 +4,12 @@ from thinc.api import Linear
from catalogue import RegistryError
@registry.architectures.register("my_test_function")
def create_model(nr_in, nr_out):
return Linear(nr_in, nr_out)
def test_get_architecture():
@registry.architectures("my_test_function")
def create_model(nr_in, nr_out):
return Linear(nr_in, nr_out)
arch = registry.architectures.get("my_test_function")
assert arch is create_model
with pytest.raises(RegistryError):

View File

@ -27,7 +27,7 @@ def test_readers():
factory = "textcat"
"""
@registry.readers.register("myreader.v1")
@registry.readers("myreader.v1")
def myreader() -> Dict[str, Callable[[Language, str], Iterable[Example]]]:
annots = {"cats": {"POS": 1.0, "NEG": 0.0}}

View File

@ -15,7 +15,7 @@ next: /usage/projects
> ```python
> from thinc.api import Model, chain
>
> @spacy.registry.architectures.register("model.v1")
> @spacy.registry.architectures("model.v1")
> def build_model(width: int, classes: int) -> Model:
> tok2vec = build_tok2vec(width)
> output_layer = build_output_layer(width, classes)
@ -563,7 +563,7 @@ matrix** (~~Floats2d~~) of predictions:
```python
### The model architecture
@spacy.registry.architectures.register("rel_model.v1")
@spacy.registry.architectures("rel_model.v1")
def create_relation_model(...) -> Model[List[Doc], Floats2d]:
model = ... # 👈 model will go here
return model
@ -589,7 +589,7 @@ transforms the instance tensor into a final tensor holding the predictions:
```python
### The model architecture {highlight="6"}
@spacy.registry.architectures.register("rel_model.v1")
@spacy.registry.architectures("rel_model.v1")
def create_relation_model(
create_instance_tensor: Model[List[Doc], Floats2d],
classification_layer: Model[Floats2d, Floats2d],
@ -613,7 +613,7 @@ The `classification_layer` could be something like a
```python
### The classification layer
@spacy.registry.architectures.register("rel_classification_layer.v1")
@spacy.registry.architectures("rel_classification_layer.v1")
def create_classification_layer(
nO: int = None, nI: int = None
) -> Model[Floats2d, Floats2d]:
@ -650,7 +650,7 @@ that has the full implementation.
```python
### The layer that creates the instance tensor
@spacy.registry.architectures.register("rel_instance_tensor.v1")
@spacy.registry.architectures("rel_instance_tensor.v1")
def create_tensors(
tok2vec: Model[List[Doc], List[Floats2d]],
pooling: Model[Ragged, Floats2d],
@ -731,7 +731,7 @@ are within a **maximum distance** (in number of tokens) of each other:
```python
### Candidate generation
@spacy.registry.misc.register("rel_instance_generator.v1")
@spacy.registry.misc("rel_instance_generator.v1")
def create_instances(max_length: int) -> Callable[[Doc], List[Tuple[Span, Span]]]:
def get_candidates(doc: "Doc") -> List[Tuple[Span, Span]]:
candidates = []