mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 00:46:28 +03:00
consistently use registry as callable
This commit is contained in:
parent
212f0e779e
commit
d900c55061
|
@ -8,7 +8,7 @@ from ...kb import KnowledgeBase, Candidate, get_candidates
|
|||
from ...vocab import Vocab
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.EntityLinker.v1")
|
||||
@registry.architectures("spacy.EntityLinker.v1")
|
||||
def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model:
|
||||
with Model.define_operators({">>": chain, "**": clone}):
|
||||
token_width = tok2vec.get_dim("nO")
|
||||
|
@ -25,7 +25,7 @@ def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model:
|
|||
return model
|
||||
|
||||
|
||||
@registry.misc.register("spacy.KBFromFile.v1")
|
||||
@registry.misc("spacy.KBFromFile.v1")
|
||||
def load_kb(kb_path: Path) -> Callable[[Vocab], KnowledgeBase]:
|
||||
def kb_from_file(vocab):
|
||||
kb = KnowledgeBase(vocab, entity_vector_length=1)
|
||||
|
@ -35,7 +35,7 @@ def load_kb(kb_path: Path) -> Callable[[Vocab], KnowledgeBase]:
|
|||
return kb_from_file
|
||||
|
||||
|
||||
@registry.misc.register("spacy.EmptyKB.v1")
|
||||
@registry.misc("spacy.EmptyKB.v1")
|
||||
def empty_kb(entity_vector_length: int) -> Callable[[Vocab], KnowledgeBase]:
|
||||
def empty_kb_factory(vocab):
|
||||
return KnowledgeBase(vocab=vocab, entity_vector_length=entity_vector_length)
|
||||
|
@ -43,6 +43,6 @@ def empty_kb(entity_vector_length: int) -> Callable[[Vocab], KnowledgeBase]:
|
|||
return empty_kb_factory
|
||||
|
||||
|
||||
@registry.misc.register("spacy.CandidateGenerator.v1")
|
||||
@registry.misc("spacy.CandidateGenerator.v1")
|
||||
def create_candidates() -> Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]:
|
||||
return get_candidates
|
||||
|
|
|
@ -16,7 +16,7 @@ if TYPE_CHECKING:
|
|||
from ...tokens import Doc # noqa: F401
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.PretrainVectors.v1")
|
||||
@registry.architectures("spacy.PretrainVectors.v1")
|
||||
def create_pretrain_vectors(
|
||||
maxout_pieces: int, hidden_size: int, loss: str
|
||||
) -> Callable[["Vocab", Model], Model]:
|
||||
|
@ -40,7 +40,7 @@ def create_pretrain_vectors(
|
|||
return create_vectors_objective
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.PretrainCharacters.v1")
|
||||
@registry.architectures("spacy.PretrainCharacters.v1")
|
||||
def create_pretrain_characters(
|
||||
maxout_pieces: int, hidden_size: int, n_characters: int
|
||||
) -> Callable[["Vocab", Model], Model]:
|
||||
|
|
|
@ -10,7 +10,7 @@ from ..tb_framework import TransitionModel
|
|||
from ...tokens import Doc
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.TransitionBasedParser.v1")
|
||||
@registry.architectures("spacy.TransitionBasedParser.v1")
|
||||
def transition_parser_v1(
|
||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||
state_type: Literal["parser", "ner"],
|
||||
|
@ -31,7 +31,7 @@ def transition_parser_v1(
|
|||
)
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.TransitionBasedParser.v2")
|
||||
@registry.architectures("spacy.TransitionBasedParser.v2")
|
||||
def transition_parser_v2(
|
||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||
state_type: Literal["parser", "ner"],
|
||||
|
|
|
@ -6,7 +6,7 @@ from ...util import registry
|
|||
from ...tokens import Doc
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.Tagger.v1")
|
||||
@registry.architectures("spacy.Tagger.v1")
|
||||
def build_tagger_model(
|
||||
tok2vec: Model[List[Doc], List[Floats2d]], nO: Optional[int] = None
|
||||
) -> Model[List[Doc], List[Floats2d]]:
|
||||
|
|
|
@ -15,7 +15,7 @@ from ...tokens import Doc
|
|||
from .tok2vec import get_tok2vec_width
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.TextCatCNN.v1")
|
||||
@registry.architectures("spacy.TextCatCNN.v1")
|
||||
def build_simple_cnn_text_classifier(
|
||||
tok2vec: Model, exclusive_classes: bool, nO: Optional[int] = None
|
||||
) -> Model[List[Doc], Floats2d]:
|
||||
|
@ -41,7 +41,7 @@ def build_simple_cnn_text_classifier(
|
|||
return model
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.TextCatBOW.v1")
|
||||
@registry.architectures("spacy.TextCatBOW.v1")
|
||||
def build_bow_text_classifier(
|
||||
exclusive_classes: bool,
|
||||
ngram_size: int,
|
||||
|
@ -60,7 +60,7 @@ def build_bow_text_classifier(
|
|||
return model
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.TextCatEnsemble.v2")
|
||||
@registry.architectures("spacy.TextCatEnsemble.v2")
|
||||
def build_text_classifier_v2(
|
||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||
linear_model: Model[List[Doc], Floats2d],
|
||||
|
@ -112,7 +112,7 @@ def init_ensemble_textcat(model, X, Y) -> Model:
|
|||
return model
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.TextCatLowData.v1")
|
||||
@registry.architectures("spacy.TextCatLowData.v1")
|
||||
def build_text_classifier_lowdata(
|
||||
width: int, dropout: Optional[float], nO: Optional[int] = None
|
||||
) -> Model[List[Doc], Floats2d]:
|
||||
|
|
|
@ -14,7 +14,7 @@ from ...pipeline.tok2vec import Tok2VecListener
|
|||
from ...attrs import intify_attr
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.Tok2VecListener.v1")
|
||||
@registry.architectures("spacy.Tok2VecListener.v1")
|
||||
def tok2vec_listener_v1(width: int, upstream: str = "*"):
|
||||
tok2vec = Tok2VecListener(upstream_name=upstream, width=width)
|
||||
return tok2vec
|
||||
|
@ -31,7 +31,7 @@ def get_tok2vec_width(model: Model):
|
|||
return nO
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.HashEmbedCNN.v1")
|
||||
@registry.architectures("spacy.HashEmbedCNN.v1")
|
||||
def build_hash_embed_cnn_tok2vec(
|
||||
*,
|
||||
width: int,
|
||||
|
@ -87,7 +87,7 @@ def build_hash_embed_cnn_tok2vec(
|
|||
)
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.Tok2Vec.v2")
|
||||
@registry.architectures("spacy.Tok2Vec.v2")
|
||||
def build_Tok2Vec_model(
|
||||
embed: Model[List[Doc], List[Floats2d]],
|
||||
encode: Model[List[Floats2d], List[Floats2d]],
|
||||
|
@ -108,7 +108,7 @@ def build_Tok2Vec_model(
|
|||
return tok2vec
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.MultiHashEmbed.v1")
|
||||
@registry.architectures("spacy.MultiHashEmbed.v1")
|
||||
def MultiHashEmbed(
|
||||
width: int,
|
||||
attrs: List[Union[str, int]],
|
||||
|
@ -182,7 +182,7 @@ def MultiHashEmbed(
|
|||
return model
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.CharacterEmbed.v1")
|
||||
@registry.architectures("spacy.CharacterEmbed.v1")
|
||||
def CharacterEmbed(
|
||||
width: int,
|
||||
rows: int,
|
||||
|
@ -255,7 +255,7 @@ def CharacterEmbed(
|
|||
return model
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.MaxoutWindowEncoder.v2")
|
||||
@registry.architectures("spacy.MaxoutWindowEncoder.v2")
|
||||
def MaxoutWindowEncoder(
|
||||
width: int, window_size: int, maxout_pieces: int, depth: int
|
||||
) -> Model[List[Floats2d], List[Floats2d]]:
|
||||
|
@ -287,7 +287,7 @@ def MaxoutWindowEncoder(
|
|||
return with_array(model, pad=receptive_field)
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.MishWindowEncoder.v2")
|
||||
@registry.architectures("spacy.MishWindowEncoder.v2")
|
||||
def MishWindowEncoder(
|
||||
width: int, window_size: int, depth: int
|
||||
) -> Model[List[Floats2d], List[Floats2d]]:
|
||||
|
@ -310,7 +310,7 @@ def MishWindowEncoder(
|
|||
return with_array(model)
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.TorchBiLSTMEncoder.v1")
|
||||
@registry.architectures("spacy.TorchBiLSTMEncoder.v1")
|
||||
def BiLSTMEncoder(
|
||||
width: int, depth: int, dropout: float
|
||||
) -> Model[List[Floats2d], List[Floats2d]]:
|
||||
|
|
|
@ -230,7 +230,7 @@ def test_el_pipe_configuration(nlp):
|
|||
def get_lowercased_candidates(kb, span):
|
||||
return kb.get_alias_candidates(span.text.lower())
|
||||
|
||||
@registry.misc.register("spacy.LowercaseCandidateGenerator.v1")
|
||||
@registry.misc("spacy.LowercaseCandidateGenerator.v1")
|
||||
def create_candidates() -> Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]:
|
||||
return get_lowercased_candidates
|
||||
|
||||
|
|
|
@ -160,7 +160,7 @@ subword_features = false
|
|||
"""
|
||||
|
||||
|
||||
@registry.architectures.register("my_test_parser")
|
||||
@registry.architectures("my_test_parser")
|
||||
def my_parser():
|
||||
tok2vec = build_Tok2Vec_model(
|
||||
MultiHashEmbed(
|
||||
|
|
|
@ -108,7 +108,7 @@ def test_serialize_subclassed_kb():
|
|||
super().__init__(vocab, entity_vector_length)
|
||||
self.custom_field = custom_field
|
||||
|
||||
@registry.misc.register("spacy.CustomKB.v1")
|
||||
@registry.misc("spacy.CustomKB.v1")
|
||||
def custom_kb(
|
||||
entity_vector_length: int, custom_field: int
|
||||
) -> Callable[["Vocab"], KnowledgeBase]:
|
||||
|
|
|
@ -4,12 +4,12 @@ from thinc.api import Linear
|
|||
from catalogue import RegistryError
|
||||
|
||||
|
||||
@registry.architectures.register("my_test_function")
|
||||
def create_model(nr_in, nr_out):
|
||||
return Linear(nr_in, nr_out)
|
||||
|
||||
|
||||
def test_get_architecture():
|
||||
|
||||
@registry.architectures("my_test_function")
|
||||
def create_model(nr_in, nr_out):
|
||||
return Linear(nr_in, nr_out)
|
||||
|
||||
arch = registry.architectures.get("my_test_function")
|
||||
assert arch is create_model
|
||||
with pytest.raises(RegistryError):
|
||||
|
|
|
@ -27,7 +27,7 @@ def test_readers():
|
|||
factory = "textcat"
|
||||
"""
|
||||
|
||||
@registry.readers.register("myreader.v1")
|
||||
@registry.readers("myreader.v1")
|
||||
def myreader() -> Dict[str, Callable[[Language, str], Iterable[Example]]]:
|
||||
annots = {"cats": {"POS": 1.0, "NEG": 0.0}}
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ next: /usage/projects
|
|||
> ```python
|
||||
> from thinc.api import Model, chain
|
||||
>
|
||||
> @spacy.registry.architectures.register("model.v1")
|
||||
> @spacy.registry.architectures("model.v1")
|
||||
> def build_model(width: int, classes: int) -> Model:
|
||||
> tok2vec = build_tok2vec(width)
|
||||
> output_layer = build_output_layer(width, classes)
|
||||
|
@ -563,7 +563,7 @@ matrix** (~~Floats2d~~) of predictions:
|
|||
|
||||
```python
|
||||
### The model architecture
|
||||
@spacy.registry.architectures.register("rel_model.v1")
|
||||
@spacy.registry.architectures("rel_model.v1")
|
||||
def create_relation_model(...) -> Model[List[Doc], Floats2d]:
|
||||
model = ... # 👈 model will go here
|
||||
return model
|
||||
|
@ -589,7 +589,7 @@ transforms the instance tensor into a final tensor holding the predictions:
|
|||
|
||||
```python
|
||||
### The model architecture {highlight="6"}
|
||||
@spacy.registry.architectures.register("rel_model.v1")
|
||||
@spacy.registry.architectures("rel_model.v1")
|
||||
def create_relation_model(
|
||||
create_instance_tensor: Model[List[Doc], Floats2d],
|
||||
classification_layer: Model[Floats2d, Floats2d],
|
||||
|
@ -613,7 +613,7 @@ The `classification_layer` could be something like a
|
|||
|
||||
```python
|
||||
### The classification layer
|
||||
@spacy.registry.architectures.register("rel_classification_layer.v1")
|
||||
@spacy.registry.architectures("rel_classification_layer.v1")
|
||||
def create_classification_layer(
|
||||
nO: int = None, nI: int = None
|
||||
) -> Model[Floats2d, Floats2d]:
|
||||
|
@ -650,7 +650,7 @@ that has the full implementation.
|
|||
|
||||
```python
|
||||
### The layer that creates the instance tensor
|
||||
@spacy.registry.architectures.register("rel_instance_tensor.v1")
|
||||
@spacy.registry.architectures("rel_instance_tensor.v1")
|
||||
def create_tensors(
|
||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||
pooling: Model[Ragged, Floats2d],
|
||||
|
@ -731,7 +731,7 @@ are within a **maximum distance** (in number of tokens) of each other:
|
|||
|
||||
```python
|
||||
### Candidate generation
|
||||
@spacy.registry.misc.register("rel_instance_generator.v1")
|
||||
@spacy.registry.misc("rel_instance_generator.v1")
|
||||
def create_instances(max_length: int) -> Callable[[Doc], List[Tuple[Span, Span]]]:
|
||||
def get_candidates(doc: "Doc") -> List[Tuple[Span, Span]]:
|
||||
candidates = []
|
||||
|
|
Loading…
Reference in New Issue
Block a user