mirror of https://github.com/explosion/spaCy.git
synced 2025-07-13 01:32:32 +03:00

Simplify MultiHashEmbed signature

This commit is contained in:
parent 71e73ed0a6
commit 6dcc4a0ba6
@@ -55,13 +55,15 @@ def build_hash_embed_cnn_tok2vec(
     pretrained_vectors (bool): Whether to also use static vectors.
     """
     if subword_features:
-        attrs = {"NORM": 1.0, "PREFIX": 0.5, "SUFFIX": 0.5, "SHAPE": 0.5}
+        attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
+        row_sizes = [embed_size, embed_size//2, embed_size//2, embed_size//2]
     else:
-        attrs = {"NORM": 1.0}
+        attrs = ["NORM"]
+        row_sizes = [embed_size]
     return build_Tok2Vec_model(
         embed=MultiHashEmbed(
             width=width,
-            rows=embed_size,
+            rows=row_sizes,
             attrs=attrs,
             include_static_vectors=bool(pretrained_vectors),
         ),
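Note on the new pairing (not part of the diff itself): `attrs` and `row_sizes` are parallel lists, so each attribute gets its own explicitly sized table. A minimal sketch, with embed_size=2000 as an illustrative value:

embed_size = 2000  # illustrative value, not taken from the diff
attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
row_sizes = [embed_size, embed_size // 2, embed_size // 2, embed_size // 2]
# The lists are parallel: attrs[i] is embedded in a table with row_sizes[i] rows.
assert len(attrs) == len(row_sizes)
assert row_sizes == [2000, 1000, 1000, 1000]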
@@ -103,7 +105,7 @@ def MultiHashEmbed_v1(
     here as a temporary compatibility."""
     return MultiHashEmbed(
         width=width,
-        rows=rows,
+        rows=[rows, rows//2, rows//2, rows//2] if also_embed_subwords else [rows],
         attrs=[NORM, PREFIX, SUFFIX, SHAPE] if also_embed_subwords else [NORM],
         include_static_vectors=also_use_static_vectors,
     )
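For clarity, the v1 shim's translation of the old integer-plus-flag arguments into the new list arguments can be traced in isolation. A sketch assuming rows=2000:

rows = 2000  # illustrative value
also_embed_subwords = True
# The shim expands the single integer into one row count per attribute:
new_rows = [rows, rows // 2, rows // 2, rows // 2] if also_embed_subwords else [rows]
new_attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"] if also_embed_subwords else ["NORM"]
assert new_rows == [2000, 1000, 1000, 1000]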
@@ -112,8 +114,8 @@ def MultiHashEmbed_v1(
 @registry.architectures.register("spacy.MultiHashEmbed.v2")
 def MultiHashEmbed(
     width: int,
-    rows: int,
-    attrs: Union[List[Union[str, int]], Dict[Union[str, int], float]],
+    attrs: List[Union[str, int]],
+    rows: List[int],
     include_static_vectors: bool,
 ) -> Model[List[Doc], List[Floats2d]]:
     """Construct an embedding layer that separately embeds a number of lexical
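Under the v2 signature, callers pass the two parallel lists explicitly instead of a base row count plus optional factors. A hedged usage sketch (the import path is assumed from the file being edited, and the sizes are illustrative):

from spacy.ml.models.tok2vec import MultiHashEmbed  # assumed import path

embed = MultiHashEmbed(
    width=64,                    # illustrative width
    attrs=["NORM", "PREFIX"],    # one embedding table per attribute
    rows=[2000, 1000],           # must have the same length as attrs
    include_static_vectors=False,
)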
@@ -136,50 +138,38 @@ def MultiHashEmbed(
     The `rows` parameter controls the number of rows used by the `HashEmbed`
     tables. The HashEmbed layer needs surprisingly few rows, due to its use of
     the hashing trick. Generally between 2000 and 10000 rows is sufficient,
-    even for very large vocabularies. You can vary the number of rows per
-    attribute by specifying the attrs as a dict, mapping the keys to float
-    values which are interpreted as factors of `rows`. For instance,
-    attrs={"NORM": 1.0, PREFIX: 0.2} will use rows*1 for the NORM table and
-    rows*0.2 for the PREFIX table. If `attrs` is a list, factors of 1.0 are
-    assumed for all attributes.
+    even for very large vocabularies. A number of rows must be specified for each
+    table, so the `rows` list must be of the same length as the `attrs` parameter.

     width (int): The output width. Also used as the width of the embedding tables.
         Recommended values are between 64 and 300.
-    rows (int): The base number of rows for the embedding tables. Can be low, due
-        to the hashing trick. The rows can be varied per attribute by providing
-        a dictionary as the value of `attrs`.
-    attrs (dict or list of attr IDs): The token attributes to embed. A separate
-        embedding table will be constructed for each attribute. Attributes
-        can be specified as a list or as a dictionary, which lets you control
-        the number of rows used for each table.
+    attrs (list of attr IDs): The token attributes to embed. A separate
+        embedding table will be constructed for each attribute.
+    rows (List[int]): The number of rows in the embedding tables. Must have the
+        same length as attrs.
     include_static_vectors (bool): Whether to also use static word vectors.
         Requires a vectors table to be loaded in the Doc objects' vocab.
     """
-    if isinstance(attrs, dict):
-        # Exclude tables that would have 0 rows.
-        attrs = {key: value for key, value in attrs.items() if value > 0.0}
-    indices = {attr: i for i, attr in enumerate(attrs)}
     seed = 7

-    def make_hash_embed(feature):
+    def make_hash_embed(index):
         nonlocal seed
-        row_factor = attrs[feature] if isinstance(attrs, dict) else 1.0
         seed += 1
         return HashEmbed(
             width,
-            int(rows * row_factor),
-            column=indices[feature],
+            rows[index],
+            column=index,
             seed=seed,
             dropout=0.0,
         )

-    embeddings = [make_hash_embed(attr) for attr in attrs]
+    embeddings = [make_hash_embed(i) for i in range(len(attrs))]
     concat_size = width * (len(embeddings) + include_static_vectors)
     if include_static_vectors:
         model = chain(
             concatenate(
                 chain(
-                    FeatureExtractor(list(attrs)),
+                    FeatureExtractor(attrs),
                     list2ragged(),
                     with_array(concatenate(*embeddings)),
                 ),
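The simplified body drops the dict handling and row-factor arithmetic: each attribute index now maps directly to one HashEmbed table that reads its own feature column. A standalone sketch of that construction using thinc's public API (the sizes are illustrative, not from the diff):

from thinc.api import HashEmbed

width = 64
attrs = ["NORM", "PREFIX", "SHAPE"]
rows = [2000, 1000, 1000]

# One table per attribute; column=i selects attribute i's values from the
# (n_tokens, n_attrs) feature array produced by FeatureExtractor(attrs).
embeddings = [
    HashEmbed(width, rows[i], column=i, seed=7 + i, dropout=0.0)
    for i in range(len(attrs))
]
assert len(embeddings) == len(attrs)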
@@ -24,7 +24,7 @@ def test_empty_doc():
     tok2vec = build_Tok2Vec_model(
         MultiHashEmbed(
             width=width,
-            rows=embed_size,
+            rows=[embed_size, embed_size, embed_size, embed_size],
             include_static_vectors=False,
             attrs=["NORM", "PREFIX", "SUFFIX", "SHAPE"],
         ),
@@ -44,7 +44,7 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
     tok2vec = build_Tok2Vec_model(
         MultiHashEmbed(
             width=width,
-            rows=embed_size,
+            rows=[embed_size] * 4,
             include_static_vectors=False,
             attrs=["NORM", "PREFIX", "SUFFIX", "SHAPE"],
         ),
@@ -61,8 +61,8 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
 @pytest.mark.parametrize(
     "width,embed_arch,embed_config,encode_arch,encode_config",
     [
-        (8, MultiHashEmbed, {"rows": 100, "attrs": ["SHAPE", "LOWER"], "include_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 2}),
-        (8, MultiHashEmbed, {"rows": 100, "attrs": {"ORTH": 1.0, "PREFIX": 0.2}, "include_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 6}),
+        (8, MultiHashEmbed, {"rows": [100, 100], "attrs": ["SHAPE", "LOWER"], "include_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 2}),
+        (8, MultiHashEmbed, {"rows": [100, 20], "attrs": ["ORTH", "PREFIX"], "include_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 6}),
         (8, CharacterEmbed, {"rows": 100, "nM": 64, "nC": 8, "also_use_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 3}),
         (8, CharacterEmbed, {"rows": 100, "nM": 16, "nC": 2, "also_use_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 3}),
     ],
@@ -116,11 +116,11 @@ cfg_string = """
     @architectures = "spacy.Tok2Vec.v1"

     [components.tok2vec.model.embed]
-    @architectures = "spacy.MultiHashEmbed.v1"
+    @architectures = "spacy.MultiHashEmbed.v2"
     width = ${components.tok2vec.model.encode.width}
-    rows = 2000
-    also_embed_subwords = true
-    also_use_static_vectors = false
+    rows = [2000, 1000, 1000, 1000]
+    attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
+    include_static_vectors = false

     [components.tok2vec.model.encode]
     @architectures = "spacy.MaxoutWindowEncoder.v1"
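A hedged sketch of exercising the updated config block (assuming, as in the test file, that cfg_string is a complete pipeline config with an [nlp] section):

import spacy
from thinc.api import Config

# cfg_string as defined in the test module; the registry resolves
# spacy.MultiHashEmbed.v2 with the new list-valued rows and attrs.
config = Config().from_str(cfg_string)
nlp = spacy.util.load_model_from_config(config, auto_fill=True, validate=True)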
@@ -63,7 +63,7 @@ def get_tok2vec_kwargs():
     return {
         "embed": MultiHashEmbed(
             width=32,
-            rows=500,
+            rows=[500, 500, 500],
             attrs=["NORM", "PREFIX", "SHAPE"],
             include_static_vectors=False
         ),
@@ -80,7 +80,7 @@ def test_tok2vec():
 def test_multi_hash_embed():
     embed = MultiHashEmbed(
         width=32,
-        rows=500,
+        rows=[500, 500, 500],
         attrs=["NORM", "PREFIX", "SHAPE"],
         include_static_vectors=False
     )
@@ -95,8 +95,8 @@ def test_multi_hash_embed():
     # Now try with different row factors
     embed = MultiHashEmbed(
         width=32,
-        rows=500,
-        attrs={"NORM": 2.0, "PREFIX": 0.1, "SHAPE": 0.5},
+        rows=[1000, 50, 250],
+        attrs=["NORM", "PREFIX", "SHAPE"],
         include_static_vectors=False
     )
     hash_embeds = [node for node in embed.walk() if node.name == "hashembed"]
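With explicit row counts, the test can assert each table's size directly rather than deriving it from factors. A sketch of such a check ("nV" is thinc's row-count dimension on HashEmbed nodes):

hash_embeds = [node for node in embed.walk() if node.name == "hashembed"]
# Each table's number of rows should match the rows list, in attrs order.
assert [he.get_dim("nV") for he in hash_embeds] == [1000, 50, 250]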