Mirror of https://github.com/explosion/spaCy.git
Tidy up and auto-format

commit 9614e53b02
parent 181039bd17
@@ -1,4 +1,4 @@
-from typing import Optional, List, Union, Dict
+from typing import Optional, List, Union
 from thinc.types import Floats2d
 from thinc.api import chain, clone, concatenate, with_array, with_padded
 from thinc.api import Model, noop, list2ragged, ragged2list, HashEmbed
@@ -11,7 +11,7 @@ from ...ml import _character_embed
 from ..staticvectors import StaticVectors
 from ..featureextractor import FeatureExtractor
 from ...pipeline.tok2vec import Tok2VecListener
-from ...attrs import ORTH, NORM, LOWER, PREFIX, SUFFIX, SHAPE, intify_attr
+from ...attrs import intify_attr


 @registry.architectures.register("spacy.Tok2VecListener.v1")
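The registration line in the context above is what makes the listener constructible by name from a training config. A minimal sketch of resolving it through the registry, assuming spaCy v3's catalogue-backed registry API; the width and upstream values are illustrative, not taken from this diff:

from spacy import registry

# Look up the factory registered under this name and build a listener.
# "upstream" names the Tok2Vec component to connect to ("*" = any).
make_listener = registry.architectures.get("spacy.Tok2VecListener.v1")
listener = make_listener(width=96, upstream="*")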
@@ -29,7 +29,7 @@ def build_hash_embed_cnn_tok2vec(
     window_size: int,
     maxout_pieces: int,
     subword_features: bool,
-    pretrained_vectors: Optional[bool]
+    pretrained_vectors: Optional[bool],
 ) -> Model[List[Doc], List[Floats2d]]:
     """Build spaCy's 'standard' tok2vec layer, which uses hash embedding
     with subword features and a CNN with layer-normalized maxout.
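For context, this builder is the architecture registered in this era of the codebase under a name like "spacy.HashEmbedCNN.v1" (later releases bump the version suffix). A minimal sketch of calling it directly; width=96 and depth=4 are illustrative values, not taken from this diff:

from spacy.ml.models.tok2vec import build_hash_embed_cnn_tok2vec

# Build the 'standard' tok2vec layer: hash-embedding tables feeding a
# CNN encoder with layer-normalized maxout activations.
tok2vec = build_hash_embed_cnn_tok2vec(
    width=96,                 # output vector width (illustrative)
    depth=4,                  # number of CNN layers (illustrative)
    embed_size=2000,          # rows in the NORM hash table
    window_size=1,            # tokens of context on each side
    maxout_pieces=3,          # pieces per maxout unit
    subword_features=True,    # also embed PREFIX/SUFFIX/SHAPE
    pretrained_vectors=None,  # no static vectors in this sketch
)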
@@ -56,7 +56,7 @@ def build_hash_embed_cnn_tok2vec(
     """
     if subword_features:
         attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
-        row_sizes = [embed_size, embed_size//2, embed_size//2, embed_size//2]
+        row_sizes = [embed_size, embed_size // 2, embed_size // 2, embed_size // 2]
     else:
         attrs = ["NORM"]
         row_sizes = [embed_size]
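The reformatted line only adds spaces around the floor division; the split itself is unchanged: the NORM table keeps the full embed_size while each subword table gets half. A quick illustration, assuming embed_size=2000:

# row_sizes arithmetic from the branch above; embed_size=2000 is assumed.
embed_size = 2000
attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
row_sizes = [embed_size, embed_size // 2, embed_size // 2, embed_size // 2]
print(dict(zip(attrs, row_sizes)))
# {'NORM': 2000, 'PREFIX': 1000, 'SUFFIX': 1000, 'SHAPE': 1000}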
@@ -120,7 +120,7 @@ def MultiHashEmbed(
     layer is used to map the vectors to the specified width before concatenating
     it with the other embedding outputs. A single Maxout layer is then used to
     reduce the concatenated vectors to the final width.

     The `rows` parameter controls the number of rows used by the `HashEmbed`
     tables. The HashEmbed layer needs surprisingly few rows, due to its use of
     the hashing trick. Generally between 2000 and 10000 rows is sufficient,
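The reason a few thousand rows is enough: HashEmbed does not index its table by symbol ID but hashes each ID into the fixed table, so the vocabulary it can embed is effectively unbounded. A standalone sketch using thinc's HashEmbed directly; the IDs and sizes here are arbitrary:

import numpy
from thinc.api import HashEmbed

# 500 rows still accepts arbitrarily large IDs: each one is hashed
# into the fixed table instead of indexing it directly.
embed = HashEmbed(nO=96, nV=500, column=0, seed=1, dropout=0.0)
embed.initialize()
ids = numpy.array([[7], [123456789]], dtype="uint64")  # IDs far beyond 500
print(embed.predict(ids).shape)  # (2, 96)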
@@ -143,13 +143,7 @@ def MultiHashEmbed(
     def make_hash_embed(index):
         nonlocal seed
         seed += 1
-        return HashEmbed(
-            width,
-            rows[index],
-            column=index,
-            seed=seed,
-            dropout=0.0,
-        )
+        return HashEmbed(width, rows[index], column=index, seed=seed, dropout=0.0)

     embeddings = [make_hash_embed(i) for i in range(len(attrs))]
     concat_size = width * (len(embeddings) + include_static_vectors)
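The collapse of the HashEmbed call to one line is pure formatting; the surrounding logic still bumps the shared seed per table (so each attribute column hashes differently) and sizes the concatenation with one extra width slot when static vectors are included. The arithmetic, with illustrative values:

# concat_size arithmetic from the hunk above; width=96 is assumed.
width = 96
attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
include_static_vectors = True

# A Python bool adds as 0/1, reserving one extra width-sized slot
# for the StaticVectors output when it is enabled.
concat_size = width * (len(attrs) + include_static_vectors)
print(concat_size)  # 96 * 5 = 480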