mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
Format
This commit is contained in:
parent
234c52a91e
commit
473504d837
|
@ -1,11 +1,22 @@
|
||||||
from typing import Optional
|
from typing import Optional, List
|
||||||
from thinc.api import zero_init, with_array, Softmax, chain, Model
|
from thinc.api import zero_init, with_array, Softmax, chain, Model
|
||||||
|
from thinc.types import Floats2d
|
||||||
|
|
||||||
from ...util import registry
|
from ...util import registry
|
||||||
|
from ..tokens import Doc
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.Tagger.v1")
|
@registry.architectures.register("spacy.Tagger.v1")
|
||||||
def build_tagger_model(tok2vec: Model, nO: Optional[int] = None) -> Model:
|
def build_tagger_model(
|
||||||
|
tok2vec: Model[List[Doc], List[Floats2d]], nO: Optional[int] = None
|
||||||
|
) -> Model[List[Doc], List[Floats2d]]:
|
||||||
|
"""Build a tagger model, using a provided token-to-vector component. The tagger
|
||||||
|
model simply adds a linear layer with softmax activation to predict scores
|
||||||
|
given the token vectors.
|
||||||
|
|
||||||
|
tok2vec (Model[List[Doc], List[Floats2d]]): The token-to-vector subnetwork.
|
||||||
|
nO (int or None): The number of tags to output. Inferred from the data if None.
|
||||||
|
"""
|
||||||
# TODO: glorot_uniform_init seems to work a bit better than zero_init here?!
|
# TODO: glorot_uniform_init seems to work a bit better than zero_init here?!
|
||||||
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
|
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
|
||||||
output_layer = Softmax(nO, t2v_width, init_W=zero_init)
|
output_layer = Softmax(nO, t2v_width, init_W=zero_init)
|
||||||
|
|
|
@ -205,7 +205,9 @@ def CharacterEmbed(width: int, rows: int, nM: int, nC: int):
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.MaxoutWindowEncoder.v1")
|
@registry.architectures.register("spacy.MaxoutWindowEncoder.v1")
|
||||||
def MaxoutWindowEncoder(width: int, window_size: int, maxout_pieces: int, depth: int) -> Model[List[Floats2d], List[Floats2d]]:
|
def MaxoutWindowEncoder(
|
||||||
|
width: int, window_size: int, maxout_pieces: int, depth: int
|
||||||
|
) -> Model[List[Floats2d], List[Floats2d]]:
|
||||||
"""Encode context using convolutions with maxout activation, layer
|
"""Encode context using convolutions with maxout activation, layer
|
||||||
normalization and residual connections.
|
normalization and residual connections.
|
||||||
|
|
||||||
|
@ -235,7 +237,9 @@ def MaxoutWindowEncoder(width: int, window_size: int, maxout_pieces: int, depth:
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.MishWindowEncoder.v1")
|
@registry.architectures.register("spacy.MishWindowEncoder.v1")
|
||||||
def MishWindowEncoder(width: int, window_size: int, depth: int) -> Model[List[Floats2d], List[Floats2d]]:
|
def MishWindowEncoder(
|
||||||
|
width: int, window_size: int, depth: int
|
||||||
|
) -> Model[List[Floats2d], List[Floats2d]]:
|
||||||
"""Encode context using convolutions with mish activation, layer
|
"""Encode context using convolutions with mish activation, layer
|
||||||
normalization and residual connections.
|
normalization and residual connections.
|
||||||
|
|
||||||
|
@ -256,7 +260,9 @@ def MishWindowEncoder(width: int, window_size: int, depth: int) -> Model[List[Fl
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.TorchBiLSTMEncoder.v1")
|
@registry.architectures.register("spacy.TorchBiLSTMEncoder.v1")
|
||||||
def BiLSTMEncoder(width: int, depth: int, dropout: float) -> Model[List[Floats2d], List[Floats2d]]:
|
def BiLSTMEncoder(
|
||||||
|
width: int, depth: int, dropout: float
|
||||||
|
) -> Model[List[Floats2d], List[Floats2d]]:
|
||||||
"""Encode context using bidirectonal LSTM layers. Requires PyTorch.
|
"""Encode context using bidirectonal LSTM layers. Requires PyTorch.
|
||||||
|
|
||||||
width (int): The input and output width. These are required to be the same,
|
width (int): The input and output width. These are required to be the same,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user