mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Add previous HashEmbedCNN tok2vec to make transition easier
This commit is contained in:
parent
1784c95827
commit
c35d6282fc
|
@ -20,8 +20,37 @@ def tok2vec_listener_v1(width, upstream="*"):
|
|||
return tok2vec
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.HashEmbedCNN.v1")
|
||||
def build_hash_embed_cnn_tok2vec(
|
||||
*,
|
||||
width: int,
|
||||
depth: int,
|
||||
embed_size: int,
|
||||
window_size: int,
|
||||
maxout_pieces: int,
|
||||
subword_features: bool,
|
||||
dropout: Optional[float],
|
||||
pretrained_vectors: Optional[bool]
|
||||
) -> Model[List[Doc], List[Floats2d]]:
|
||||
"""Build spaCy's 'standard' tok2vec layer, which uses hash embedding
|
||||
with subword features and a CNN with layer-normalized maxout."""
|
||||
return build_Tok2Vec_model(
|
||||
embed=MultiHashEmbed(
|
||||
width=width,
|
||||
rows=embed_size,
|
||||
also_embed_subwords=subword_features,
|
||||
also_use_static_vectors=bool(pretrained_vectors),
|
||||
),
|
||||
encode=MaxoutWindowEncoder(
|
||||
width=width,
|
||||
depth=depth,
|
||||
window_size=window_size,
|
||||
maxout_pieces=maxout_pieces
|
||||
)
|
||||
)
|
||||
|
||||
@registry.architectures.register("spacy.Tok2Vec.v1")
|
||||
def Tok2Vec(
|
||||
def build_Tok2Vec_model(
|
||||
embed: Model[List[Doc], List[Floats2d]],
|
||||
encode: Model[List[Floats2d], List[Floats2d]],
|
||||
) -> Model[List[Doc], List[Floats2d]]:
|
||||
|
@ -62,7 +91,7 @@ def MultiHashEmbed(
|
|||
]
|
||||
else:
|
||||
embeddings = [make_hash_embed(NORM)]
|
||||
|
||||
concat_size = width * (len(embeddings) + also_use_static_vectors)
|
||||
if also_use_static_vectors:
|
||||
model = chain(
|
||||
concatenate(
|
||||
|
@ -73,7 +102,7 @@ def MultiHashEmbed(
|
|||
),
|
||||
StaticVectors(width, dropout=0.0),
|
||||
),
|
||||
with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)),
|
||||
with_array(Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)),
|
||||
ragged2list(),
|
||||
)
|
||||
else:
|
||||
|
@ -83,7 +112,7 @@ def MultiHashEmbed(
|
|||
list2ragged(),
|
||||
with_array(concatenate(*embeddings)),
|
||||
),
|
||||
with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)),
|
||||
with_array(Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)),
|
||||
ragged2list(),
|
||||
)
|
||||
return model
|
||||
|
|
Loading…
Reference in New Issue
Block a user