Mirror of https://github.com/explosion/spaCy.git
Add previous HashEmbedCNN tok2vec to make transition easier
commit c35d6282fc (parent 1784c95827)
@@ -20,8 +20,37 @@ def tok2vec_listener_v1(width, upstream="*"):
     return tok2vec
 
 
+@registry.architectures.register("spacy.HashEmbedCNN.v1")
+def build_hash_embed_cnn_tok2vec(
+    *,
+    width: int,
+    depth: int,
+    embed_size: int,
+    window_size: int,
+    maxout_pieces: int,
+    subword_features: bool,
+    dropout: Optional[float],
+    pretrained_vectors: Optional[bool]
+) -> Model[List[Doc], List[Floats2d]]:
+    """Build spaCy's 'standard' tok2vec layer, which uses hash embedding
+    with subword features and a CNN with layer-normalized maxout."""
+    return build_Tok2Vec_model(
+        embed=MultiHashEmbed(
+            width=width,
+            rows=embed_size,
+            also_embed_subwords=subword_features,
+            also_use_static_vectors=bool(pretrained_vectors),
+        ),
+        encode=MaxoutWindowEncoder(
+            width=width,
+            depth=depth,
+            window_size=window_size,
+            maxout_pieces=maxout_pieces
+        )
+    )
+
+
 @registry.architectures.register("spacy.Tok2Vec.v1")
-def Tok2Vec(
+def build_Tok2Vec_model(
     embed: Model[List[Doc], List[Floats2d]],
     encode: Model[List[Floats2d], List[Floats2d]],
 ) -> Model[List[Doc], List[Floats2d]]:
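
Note: the restored builder can be looked up and called like any other registered architecture. The sketch below assumes spaCy v3's catalogue-backed registry exposes the entry via registry.architectures.get; only the registered name and the keyword arguments come from the diff above, and the argument values are illustrative rather than defaults taken from this commit.

    from spacy import registry

    # Resolve the re-registered legacy architecture by name and build a
    # tok2vec model from it. All values below are illustrative.
    build_cnn_tok2vec = registry.architectures.get("spacy.HashEmbedCNN.v1")
    tok2vec = build_cnn_tok2vec(
        width=96,
        depth=4,
        embed_size=2000,
        window_size=1,
        maxout_pieces=3,
        subword_features=True,
        dropout=None,
        pretrained_vectors=None,
    )
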
@@ -62,7 +91,7 @@ def MultiHashEmbed(
         ]
     else:
         embeddings = [make_hash_embed(NORM)]
-
+    concat_size = width * (len(embeddings) + also_use_static_vectors)
     if also_use_static_vectors:
         model = chain(
             concatenate(
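
Note on the added line: concat_size is the width of the concatenated embedding output that feeds the Maxout layer changed in the next two hunks. The boolean also_use_static_vectors counts as 1 when true, adding one more width-sized block for the static vectors. A small worked example with illustrative values (the attribute list mirrors the hash-embed tables this file typically builds when subwords are embedded):

    # Illustrative values only.
    width = 96
    embeddings = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]  # one hash-embed table per attribute
    also_use_static_vectors = True  # True counts as 1 in the sum below

    concat_size = width * (len(embeddings) + also_use_static_vectors)
    assert concat_size == 96 * 5 == 480
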
@@ -73,7 +102,7 @@ def MultiHashEmbed(
                 ),
                 StaticVectors(width, dropout=0.0),
             ),
-            with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)),
+            with_array(Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)),
             ragged2list(),
         )
     else:
@@ -83,7 +112,7 @@ def MultiHashEmbed(
                 list2ragged(),
                 with_array(concatenate(*embeddings)),
             ),
-            with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)),
+            with_array(Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)),
             ragged2list(),
         )
     return model
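
Note on the Maxout change in the last two hunks: in Thinc, the second positional argument to Maxout is nI, the expected input width, so passing concat_size declares the layer's input dimension up front rather than leaving it to be inferred later. A minimal sketch with illustrative sizes (not values from this commit), assuming Thinc's standard Maxout factory:

    import numpy
    from thinc.api import Maxout

    width, concat_size = 96, 480  # illustrative sizes
    layer = Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)
    # With nO and nI both set, the weights can be allocated at initialization
    # without inferring shapes from example data.
    layer.initialize()
    Y = layer.predict(numpy.zeros((2, concat_size), dtype="f"))
    assert Y.shape == (2, width)
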