Tidy up and auto-format

This commit is contained in:
Ines Montani 2020-10-05 21:55:18 +02:00
parent 181039bd17
commit 9614e53b02

View File

@@ -1,4 +1,4 @@
from typing import Optional, List, Union, Dict from typing import Optional, List, Union
from thinc.types import Floats2d from thinc.types import Floats2d
from thinc.api import chain, clone, concatenate, with_array, with_padded from thinc.api import chain, clone, concatenate, with_array, with_padded
from thinc.api import Model, noop, list2ragged, ragged2list, HashEmbed from thinc.api import Model, noop, list2ragged, ragged2list, HashEmbed
@@ -11,7 +11,7 @@ from ...ml import _character_embed
from ..staticvectors import StaticVectors from ..staticvectors import StaticVectors
from ..featureextractor import FeatureExtractor from ..featureextractor import FeatureExtractor
from ...pipeline.tok2vec import Tok2VecListener from ...pipeline.tok2vec import Tok2VecListener
from ...attrs import ORTH, NORM, LOWER, PREFIX, SUFFIX, SHAPE, intify_attr from ...attrs import intify_attr
@registry.architectures.register("spacy.Tok2VecListener.v1") @registry.architectures.register("spacy.Tok2VecListener.v1")
@@ -29,7 +29,7 @@ def build_hash_embed_cnn_tok2vec(
window_size: int, window_size: int,
maxout_pieces: int, maxout_pieces: int,
subword_features: bool, subword_features: bool,
pretrained_vectors: Optional[bool] pretrained_vectors: Optional[bool],
) -> Model[List[Doc], List[Floats2d]]: ) -> Model[List[Doc], List[Floats2d]]:
"""Build spaCy's 'standard' tok2vec layer, which uses hash embedding """Build spaCy's 'standard' tok2vec layer, which uses hash embedding
with subword features and a CNN with layer-normalized maxout. with subword features and a CNN with layer-normalized maxout.
@@ -56,7 +56,7 @@ def build_hash_embed_cnn_tok2vec(
""" """
if subword_features: if subword_features:
attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"] attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
row_sizes = [embed_size, embed_size//2, embed_size//2, embed_size//2] row_sizes = [embed_size, embed_size // 2, embed_size // 2, embed_size // 2]
else: else:
attrs = ["NORM"] attrs = ["NORM"]
row_sizes = [embed_size] row_sizes = [embed_size]
@@ -143,13 +143,7 @@ def MultiHashEmbed(
def make_hash_embed(index): def make_hash_embed(index):
nonlocal seed nonlocal seed
seed += 1 seed += 1
return HashEmbed( return HashEmbed(width, rows[index], column=index, seed=seed, dropout=0.0)
width,
rows[index],
column=index,
seed=seed,
dropout=0.0,
)
embeddings = [make_hash_embed(i) for i in range(len(attrs))] embeddings = [make_hash_embed(i) for i in range(len(attrs))]
concat_size = width * (len(embeddings) + include_static_vectors) concat_size = width * (len(embeddings) + include_static_vectors)