Tidy up and auto-format

This commit is contained in:
Ines Montani 2020-10-05 21:55:18 +02:00
parent 181039bd17
commit 9614e53b02

View File

@ -1,4 +1,4 @@
from typing import Optional, List, Union, Dict
from typing import Optional, List, Union
from thinc.types import Floats2d
from thinc.api import chain, clone, concatenate, with_array, with_padded
from thinc.api import Model, noop, list2ragged, ragged2list, HashEmbed
@ -11,7 +11,7 @@ from ...ml import _character_embed
from ..staticvectors import StaticVectors
from ..featureextractor import FeatureExtractor
from ...pipeline.tok2vec import Tok2VecListener
from ...attrs import ORTH, NORM, LOWER, PREFIX, SUFFIX, SHAPE, intify_attr
from ...attrs import intify_attr
@registry.architectures.register("spacy.Tok2VecListener.v1")
@ -29,7 +29,7 @@ def build_hash_embed_cnn_tok2vec(
window_size: int,
maxout_pieces: int,
subword_features: bool,
pretrained_vectors: Optional[bool]
pretrained_vectors: Optional[bool],
) -> Model[List[Doc], List[Floats2d]]:
"""Build spaCy's 'standard' tok2vec layer, which uses hash embedding
with subword features and a CNN with layer-normalized maxout.
@ -56,7 +56,7 @@ def build_hash_embed_cnn_tok2vec(
"""
if subword_features:
attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
row_sizes = [embed_size, embed_size//2, embed_size//2, embed_size//2]
row_sizes = [embed_size, embed_size // 2, embed_size // 2, embed_size // 2]
else:
attrs = ["NORM"]
row_sizes = [embed_size]
@ -143,13 +143,7 @@ def MultiHashEmbed(
def make_hash_embed(index):
nonlocal seed
seed += 1
return HashEmbed(
width,
rows[index],
column=index,
seed=seed,
dropout=0.0,
)
return HashEmbed(width, rows[index], column=index, seed=seed, dropout=0.0)
embeddings = [make_hash_embed(i) for i in range(len(attrs))]
concat_size = width * (len(embeddings) + include_static_vectors)