mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
consistently use registry as callable
This commit is contained in:
parent
212f0e779e
commit
d900c55061
|
@ -8,7 +8,7 @@ from ...kb import KnowledgeBase, Candidate, get_candidates
|
||||||
from ...vocab import Vocab
|
from ...vocab import Vocab
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.EntityLinker.v1")
|
@registry.architectures("spacy.EntityLinker.v1")
|
||||||
def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model:
|
def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model:
|
||||||
with Model.define_operators({">>": chain, "**": clone}):
|
with Model.define_operators({">>": chain, "**": clone}):
|
||||||
token_width = tok2vec.get_dim("nO")
|
token_width = tok2vec.get_dim("nO")
|
||||||
|
@ -25,7 +25,7 @@ def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model:
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@registry.misc.register("spacy.KBFromFile.v1")
|
@registry.misc("spacy.KBFromFile.v1")
|
||||||
def load_kb(kb_path: Path) -> Callable[[Vocab], KnowledgeBase]:
|
def load_kb(kb_path: Path) -> Callable[[Vocab], KnowledgeBase]:
|
||||||
def kb_from_file(vocab):
|
def kb_from_file(vocab):
|
||||||
kb = KnowledgeBase(vocab, entity_vector_length=1)
|
kb = KnowledgeBase(vocab, entity_vector_length=1)
|
||||||
|
@ -35,7 +35,7 @@ def load_kb(kb_path: Path) -> Callable[[Vocab], KnowledgeBase]:
|
||||||
return kb_from_file
|
return kb_from_file
|
||||||
|
|
||||||
|
|
||||||
@registry.misc.register("spacy.EmptyKB.v1")
|
@registry.misc("spacy.EmptyKB.v1")
|
||||||
def empty_kb(entity_vector_length: int) -> Callable[[Vocab], KnowledgeBase]:
|
def empty_kb(entity_vector_length: int) -> Callable[[Vocab], KnowledgeBase]:
|
||||||
def empty_kb_factory(vocab):
|
def empty_kb_factory(vocab):
|
||||||
return KnowledgeBase(vocab=vocab, entity_vector_length=entity_vector_length)
|
return KnowledgeBase(vocab=vocab, entity_vector_length=entity_vector_length)
|
||||||
|
@ -43,6 +43,6 @@ def empty_kb(entity_vector_length: int) -> Callable[[Vocab], KnowledgeBase]:
|
||||||
return empty_kb_factory
|
return empty_kb_factory
|
||||||
|
|
||||||
|
|
||||||
@registry.misc.register("spacy.CandidateGenerator.v1")
|
@registry.misc("spacy.CandidateGenerator.v1")
|
||||||
def create_candidates() -> Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]:
|
def create_candidates() -> Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]:
|
||||||
return get_candidates
|
return get_candidates
|
||||||
|
|
|
@ -16,7 +16,7 @@ if TYPE_CHECKING:
|
||||||
from ...tokens import Doc # noqa: F401
|
from ...tokens import Doc # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.PretrainVectors.v1")
|
@registry.architectures("spacy.PretrainVectors.v1")
|
||||||
def create_pretrain_vectors(
|
def create_pretrain_vectors(
|
||||||
maxout_pieces: int, hidden_size: int, loss: str
|
maxout_pieces: int, hidden_size: int, loss: str
|
||||||
) -> Callable[["Vocab", Model], Model]:
|
) -> Callable[["Vocab", Model], Model]:
|
||||||
|
@ -40,7 +40,7 @@ def create_pretrain_vectors(
|
||||||
return create_vectors_objective
|
return create_vectors_objective
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.PretrainCharacters.v1")
|
@registry.architectures("spacy.PretrainCharacters.v1")
|
||||||
def create_pretrain_characters(
|
def create_pretrain_characters(
|
||||||
maxout_pieces: int, hidden_size: int, n_characters: int
|
maxout_pieces: int, hidden_size: int, n_characters: int
|
||||||
) -> Callable[["Vocab", Model], Model]:
|
) -> Callable[["Vocab", Model], Model]:
|
||||||
|
|
|
@ -10,7 +10,7 @@ from ..tb_framework import TransitionModel
|
||||||
from ...tokens import Doc
|
from ...tokens import Doc
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.TransitionBasedParser.v1")
|
@registry.architectures("spacy.TransitionBasedParser.v1")
|
||||||
def transition_parser_v1(
|
def transition_parser_v1(
|
||||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||||
state_type: Literal["parser", "ner"],
|
state_type: Literal["parser", "ner"],
|
||||||
|
@ -31,7 +31,7 @@ def transition_parser_v1(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.TransitionBasedParser.v2")
|
@registry.architectures("spacy.TransitionBasedParser.v2")
|
||||||
def transition_parser_v2(
|
def transition_parser_v2(
|
||||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||||
state_type: Literal["parser", "ner"],
|
state_type: Literal["parser", "ner"],
|
||||||
|
|
|
@ -6,7 +6,7 @@ from ...util import registry
|
||||||
from ...tokens import Doc
|
from ...tokens import Doc
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.Tagger.v1")
|
@registry.architectures("spacy.Tagger.v1")
|
||||||
def build_tagger_model(
|
def build_tagger_model(
|
||||||
tok2vec: Model[List[Doc], List[Floats2d]], nO: Optional[int] = None
|
tok2vec: Model[List[Doc], List[Floats2d]], nO: Optional[int] = None
|
||||||
) -> Model[List[Doc], List[Floats2d]]:
|
) -> Model[List[Doc], List[Floats2d]]:
|
||||||
|
|
|
@ -15,7 +15,7 @@ from ...tokens import Doc
|
||||||
from .tok2vec import get_tok2vec_width
|
from .tok2vec import get_tok2vec_width
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.TextCatCNN.v1")
|
@registry.architectures("spacy.TextCatCNN.v1")
|
||||||
def build_simple_cnn_text_classifier(
|
def build_simple_cnn_text_classifier(
|
||||||
tok2vec: Model, exclusive_classes: bool, nO: Optional[int] = None
|
tok2vec: Model, exclusive_classes: bool, nO: Optional[int] = None
|
||||||
) -> Model[List[Doc], Floats2d]:
|
) -> Model[List[Doc], Floats2d]:
|
||||||
|
@ -41,7 +41,7 @@ def build_simple_cnn_text_classifier(
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.TextCatBOW.v1")
|
@registry.architectures("spacy.TextCatBOW.v1")
|
||||||
def build_bow_text_classifier(
|
def build_bow_text_classifier(
|
||||||
exclusive_classes: bool,
|
exclusive_classes: bool,
|
||||||
ngram_size: int,
|
ngram_size: int,
|
||||||
|
@ -60,7 +60,7 @@ def build_bow_text_classifier(
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.TextCatEnsemble.v2")
|
@registry.architectures("spacy.TextCatEnsemble.v2")
|
||||||
def build_text_classifier_v2(
|
def build_text_classifier_v2(
|
||||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||||
linear_model: Model[List[Doc], Floats2d],
|
linear_model: Model[List[Doc], Floats2d],
|
||||||
|
@ -112,7 +112,7 @@ def init_ensemble_textcat(model, X, Y) -> Model:
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.TextCatLowData.v1")
|
@registry.architectures("spacy.TextCatLowData.v1")
|
||||||
def build_text_classifier_lowdata(
|
def build_text_classifier_lowdata(
|
||||||
width: int, dropout: Optional[float], nO: Optional[int] = None
|
width: int, dropout: Optional[float], nO: Optional[int] = None
|
||||||
) -> Model[List[Doc], Floats2d]:
|
) -> Model[List[Doc], Floats2d]:
|
||||||
|
|
|
@ -14,7 +14,7 @@ from ...pipeline.tok2vec import Tok2VecListener
|
||||||
from ...attrs import intify_attr
|
from ...attrs import intify_attr
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.Tok2VecListener.v1")
|
@registry.architectures("spacy.Tok2VecListener.v1")
|
||||||
def tok2vec_listener_v1(width: int, upstream: str = "*"):
|
def tok2vec_listener_v1(width: int, upstream: str = "*"):
|
||||||
tok2vec = Tok2VecListener(upstream_name=upstream, width=width)
|
tok2vec = Tok2VecListener(upstream_name=upstream, width=width)
|
||||||
return tok2vec
|
return tok2vec
|
||||||
|
@ -31,7 +31,7 @@ def get_tok2vec_width(model: Model):
|
||||||
return nO
|
return nO
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.HashEmbedCNN.v1")
|
@registry.architectures("spacy.HashEmbedCNN.v1")
|
||||||
def build_hash_embed_cnn_tok2vec(
|
def build_hash_embed_cnn_tok2vec(
|
||||||
*,
|
*,
|
||||||
width: int,
|
width: int,
|
||||||
|
@ -87,7 +87,7 @@ def build_hash_embed_cnn_tok2vec(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.Tok2Vec.v2")
|
@registry.architectures("spacy.Tok2Vec.v2")
|
||||||
def build_Tok2Vec_model(
|
def build_Tok2Vec_model(
|
||||||
embed: Model[List[Doc], List[Floats2d]],
|
embed: Model[List[Doc], List[Floats2d]],
|
||||||
encode: Model[List[Floats2d], List[Floats2d]],
|
encode: Model[List[Floats2d], List[Floats2d]],
|
||||||
|
@ -108,7 +108,7 @@ def build_Tok2Vec_model(
|
||||||
return tok2vec
|
return tok2vec
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.MultiHashEmbed.v1")
|
@registry.architectures("spacy.MultiHashEmbed.v1")
|
||||||
def MultiHashEmbed(
|
def MultiHashEmbed(
|
||||||
width: int,
|
width: int,
|
||||||
attrs: List[Union[str, int]],
|
attrs: List[Union[str, int]],
|
||||||
|
@ -182,7 +182,7 @@ def MultiHashEmbed(
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.CharacterEmbed.v1")
|
@registry.architectures("spacy.CharacterEmbed.v1")
|
||||||
def CharacterEmbed(
|
def CharacterEmbed(
|
||||||
width: int,
|
width: int,
|
||||||
rows: int,
|
rows: int,
|
||||||
|
@ -255,7 +255,7 @@ def CharacterEmbed(
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.MaxoutWindowEncoder.v2")
|
@registry.architectures("spacy.MaxoutWindowEncoder.v2")
|
||||||
def MaxoutWindowEncoder(
|
def MaxoutWindowEncoder(
|
||||||
width: int, window_size: int, maxout_pieces: int, depth: int
|
width: int, window_size: int, maxout_pieces: int, depth: int
|
||||||
) -> Model[List[Floats2d], List[Floats2d]]:
|
) -> Model[List[Floats2d], List[Floats2d]]:
|
||||||
|
@ -287,7 +287,7 @@ def MaxoutWindowEncoder(
|
||||||
return with_array(model, pad=receptive_field)
|
return with_array(model, pad=receptive_field)
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.MishWindowEncoder.v2")
|
@registry.architectures("spacy.MishWindowEncoder.v2")
|
||||||
def MishWindowEncoder(
|
def MishWindowEncoder(
|
||||||
width: int, window_size: int, depth: int
|
width: int, window_size: int, depth: int
|
||||||
) -> Model[List[Floats2d], List[Floats2d]]:
|
) -> Model[List[Floats2d], List[Floats2d]]:
|
||||||
|
@ -310,7 +310,7 @@ def MishWindowEncoder(
|
||||||
return with_array(model)
|
return with_array(model)
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.TorchBiLSTMEncoder.v1")
|
@registry.architectures("spacy.TorchBiLSTMEncoder.v1")
|
||||||
def BiLSTMEncoder(
|
def BiLSTMEncoder(
|
||||||
width: int, depth: int, dropout: float
|
width: int, depth: int, dropout: float
|
||||||
) -> Model[List[Floats2d], List[Floats2d]]:
|
) -> Model[List[Floats2d], List[Floats2d]]:
|
||||||
|
|
|
@ -230,7 +230,7 @@ def test_el_pipe_configuration(nlp):
|
||||||
def get_lowercased_candidates(kb, span):
|
def get_lowercased_candidates(kb, span):
|
||||||
return kb.get_alias_candidates(span.text.lower())
|
return kb.get_alias_candidates(span.text.lower())
|
||||||
|
|
||||||
@registry.misc.register("spacy.LowercaseCandidateGenerator.v1")
|
@registry.misc("spacy.LowercaseCandidateGenerator.v1")
|
||||||
def create_candidates() -> Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]:
|
def create_candidates() -> Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]:
|
||||||
return get_lowercased_candidates
|
return get_lowercased_candidates
|
||||||
|
|
||||||
|
|
|
@ -160,7 +160,7 @@ subword_features = false
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("my_test_parser")
|
@registry.architectures("my_test_parser")
|
||||||
def my_parser():
|
def my_parser():
|
||||||
tok2vec = build_Tok2Vec_model(
|
tok2vec = build_Tok2Vec_model(
|
||||||
MultiHashEmbed(
|
MultiHashEmbed(
|
||||||
|
|
|
@ -108,7 +108,7 @@ def test_serialize_subclassed_kb():
|
||||||
super().__init__(vocab, entity_vector_length)
|
super().__init__(vocab, entity_vector_length)
|
||||||
self.custom_field = custom_field
|
self.custom_field = custom_field
|
||||||
|
|
||||||
@registry.misc.register("spacy.CustomKB.v1")
|
@registry.misc("spacy.CustomKB.v1")
|
||||||
def custom_kb(
|
def custom_kb(
|
||||||
entity_vector_length: int, custom_field: int
|
entity_vector_length: int, custom_field: int
|
||||||
) -> Callable[["Vocab"], KnowledgeBase]:
|
) -> Callable[["Vocab"], KnowledgeBase]:
|
||||||
|
|
|
@ -4,12 +4,12 @@ from thinc.api import Linear
|
||||||
from catalogue import RegistryError
|
from catalogue import RegistryError
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("my_test_function")
|
|
||||||
def create_model(nr_in, nr_out):
|
|
||||||
return Linear(nr_in, nr_out)
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_architecture():
|
def test_get_architecture():
|
||||||
|
|
||||||
|
@registry.architectures("my_test_function")
|
||||||
|
def create_model(nr_in, nr_out):
|
||||||
|
return Linear(nr_in, nr_out)
|
||||||
|
|
||||||
arch = registry.architectures.get("my_test_function")
|
arch = registry.architectures.get("my_test_function")
|
||||||
assert arch is create_model
|
assert arch is create_model
|
||||||
with pytest.raises(RegistryError):
|
with pytest.raises(RegistryError):
|
||||||
|
|
|
@ -27,7 +27,7 @@ def test_readers():
|
||||||
factory = "textcat"
|
factory = "textcat"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@registry.readers.register("myreader.v1")
|
@registry.readers("myreader.v1")
|
||||||
def myreader() -> Dict[str, Callable[[Language, str], Iterable[Example]]]:
|
def myreader() -> Dict[str, Callable[[Language, str], Iterable[Example]]]:
|
||||||
annots = {"cats": {"POS": 1.0, "NEG": 0.0}}
|
annots = {"cats": {"POS": 1.0, "NEG": 0.0}}
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@ next: /usage/projects
|
||||||
> ```python
|
> ```python
|
||||||
> from thinc.api import Model, chain
|
> from thinc.api import Model, chain
|
||||||
>
|
>
|
||||||
> @spacy.registry.architectures.register("model.v1")
|
> @spacy.registry.architectures("model.v1")
|
||||||
> def build_model(width: int, classes: int) -> Model:
|
> def build_model(width: int, classes: int) -> Model:
|
||||||
> tok2vec = build_tok2vec(width)
|
> tok2vec = build_tok2vec(width)
|
||||||
> output_layer = build_output_layer(width, classes)
|
> output_layer = build_output_layer(width, classes)
|
||||||
|
@ -563,7 +563,7 @@ matrix** (~~Floats2d~~) of predictions:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
### The model architecture
|
### The model architecture
|
||||||
@spacy.registry.architectures.register("rel_model.v1")
|
@spacy.registry.architectures("rel_model.v1")
|
||||||
def create_relation_model(...) -> Model[List[Doc], Floats2d]:
|
def create_relation_model(...) -> Model[List[Doc], Floats2d]:
|
||||||
model = ... # 👈 model will go here
|
model = ... # 👈 model will go here
|
||||||
return model
|
return model
|
||||||
|
@ -589,7 +589,7 @@ transforms the instance tensor into a final tensor holding the predictions:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
### The model architecture {highlight="6"}
|
### The model architecture {highlight="6"}
|
||||||
@spacy.registry.architectures.register("rel_model.v1")
|
@spacy.registry.architectures("rel_model.v1")
|
||||||
def create_relation_model(
|
def create_relation_model(
|
||||||
create_instance_tensor: Model[List[Doc], Floats2d],
|
create_instance_tensor: Model[List[Doc], Floats2d],
|
||||||
classification_layer: Model[Floats2d, Floats2d],
|
classification_layer: Model[Floats2d, Floats2d],
|
||||||
|
@ -613,7 +613,7 @@ The `classification_layer` could be something like a
|
||||||
|
|
||||||
```python
|
```python
|
||||||
### The classification layer
|
### The classification layer
|
||||||
@spacy.registry.architectures.register("rel_classification_layer.v1")
|
@spacy.registry.architectures("rel_classification_layer.v1")
|
||||||
def create_classification_layer(
|
def create_classification_layer(
|
||||||
nO: int = None, nI: int = None
|
nO: int = None, nI: int = None
|
||||||
) -> Model[Floats2d, Floats2d]:
|
) -> Model[Floats2d, Floats2d]:
|
||||||
|
@ -650,7 +650,7 @@ that has the full implementation.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
### The layer that creates the instance tensor
|
### The layer that creates the instance tensor
|
||||||
@spacy.registry.architectures.register("rel_instance_tensor.v1")
|
@spacy.registry.architectures("rel_instance_tensor.v1")
|
||||||
def create_tensors(
|
def create_tensors(
|
||||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||||
pooling: Model[Ragged, Floats2d],
|
pooling: Model[Ragged, Floats2d],
|
||||||
|
@ -731,7 +731,7 @@ are within a **maximum distance** (in number of tokens) of each other:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
### Candidate generation
|
### Candidate generation
|
||||||
@spacy.registry.misc.register("rel_instance_generator.v1")
|
@spacy.registry.misc("rel_instance_generator.v1")
|
||||||
def create_instances(max_length: int) -> Callable[[Doc], List[Tuple[Span, Span]]]:
|
def create_instances(max_length: int) -> Callable[[Doc], List[Tuple[Span, Span]]]:
|
||||||
def get_candidates(doc: "Doc") -> List[Tuple[Span, Span]]:
|
def get_candidates(doc: "Doc") -> List[Tuple[Span, Span]]:
|
||||||
candidates = []
|
candidates = []
|
||||||
|
|
Loading…
Reference in New Issue
Block a user