Simplify MultiHashEmbed signature

Matthew Honnibal 2020-10-05 19:57:45 +02:00
parent 71e73ed0a6
commit 6dcc4a0ba6
3 changed files with 31 additions and 41 deletions

View File

@@ -55,13 +55,15 @@ def build_hash_embed_cnn_tok2vec(
     pretrained_vectors (bool): Whether to also use static vectors.
     """
     if subword_features:
-        attrs = {"NORM": 1.0, "PREFIX": 0.5, "SUFFIX": 0.5, "SHAPE": 0.5}
+        attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
+        row_sizes = [embed_size, embed_size//2, embed_size//2, embed_size//2]
     else:
-        attrs = {"NORM": 1.0}
+        attrs = ["NORM"]
+        row_sizes = [embed_size]
     return build_Tok2Vec_model(
         embed=MultiHashEmbed(
             width=width,
-            rows=embed_size,
+            rows=row_sizes,
             attrs=attrs,
             include_static_vectors=bool(pretrained_vectors),
         ),
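To make the new row arithmetic concrete, here is a small illustration (the numbers are mine, chosen to match the 2000-row value used in the config further down, not something this hunk hard-codes): with subword features enabled, the NORM table keeps the full embed_size and each subword table gets half of it.

    # Illustration only, not part of the commit.
    embed_size = 2000
    attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
    row_sizes = [embed_size, embed_size // 2, embed_size // 2, embed_size // 2]
    assert row_sizes == [2000, 1000, 1000, 1000]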
@@ -103,7 +105,7 @@ def MultiHashEmbed_v1(
     here as a temporary compatibility."""
     return MultiHashEmbed(
         width=width,
-        rows=rows,
+        rows=[rows, rows//2, rows//2, rows//2] if also_embed_subwords else [rows],
         attrs=[NORM, PREFIX, SUFFIX, SHAPE] if also_embed_subwords else [NORM],
         include_static_vectors=also_use_static_vectors,
     )
@@ -112,8 +114,8 @@ def MultiHashEmbed_v1(
 @registry.architectures.register("spacy.MultiHashEmbed.v2")
 def MultiHashEmbed(
     width: int,
-    rows: int,
-    attrs: Union[List[Union[str, int]], Dict[Union[str, int], float]],
+    attrs: List[Union[str, int]],
+    rows: List[int],
     include_static_vectors: bool,
 ) -> Model[List[Doc], List[Floats2d]]:
     """Construct an embedding layer that separately embeds a number of lexical
@@ -136,50 +138,38 @@ def MultiHashEmbed(
     The `rows` parameter controls the number of rows used by the `HashEmbed`
     tables. The HashEmbed layer needs surprisingly few rows, due to its use of
     the hashing trick. Generally between 2000 and 10000 rows is sufficient,
-    even for very large vocabularies. You can vary the number of rows per
-    attribute by specifying the attrs as a dict, mapping the keys to float
-    values which are interpreted as factors of `rows`. For instance,
-    attrs={"NORM": 1.0, PREFIX: 0.2} will use rows*1 for the NORM table and
-    rows*0.2 for the PREFIX table. If `attrs` is a list, factors of 1.0 are
-    assumed for all attributes.
+    even for very large vocabularies. A number of rows must be specified for each
+    table, so the `rows` list must be of the same length as the `attrs` parameter.

     width (int): The output width. Also used as the width of the embedding tables.
         Recommended values are between 64 and 300.
-    rows (int): The base number of rows for the embedding tables. Can be low, due
-        to the hashing trick. The rows can be varied per attribute by providing
-        a dictionary as the value of `attrs`.
-    attrs (dict or list of attr IDs): The token attributes to embed. A separate
-        embedding table will be constructed for each attribute. Attributes
-        can be specified as a list or as a dictionary, which lets you control
-        the number of rows used for each table.
+    attrs (list of attr IDs): The token attributes to embed. A separate
+        embedding table will be constructed for each attribute.
+    rows (List[int]): The number of rows in the embedding tables. Must have the
+        same length as attrs.
     include_static_vectors (bool): Whether to also use static word vectors.
         Requires a vectors table to be loaded in the Doc objects' vocab.
     """
-    if isinstance(attrs, dict):
-        # Exclude tables that would have 0 rows.
-        attrs = {key: value for key, value in attrs.items() if value > 0.0}
-    indices = {attr: i for i, attr in enumerate(attrs)}
     seed = 7

-    def make_hash_embed(feature):
+    def make_hash_embed(index):
         nonlocal seed
-        row_factor = attrs[feature] if isinstance(attrs, dict) else 1.0
         seed += 1
         return HashEmbed(
             width,
-            int(rows * row_factor),
-            column=indices[feature],
+            rows[index],
+            column=index,
             seed=seed,
             dropout=0.0,
         )

-    embeddings = [make_hash_embed(attr) for attr in attrs]
+    embeddings = [make_hash_embed(i) for i in range(len(attrs))]
     concat_size = width * (len(embeddings) + include_static_vectors)
     if include_static_vectors:
         model = chain(
             concatenate(
                 chain(
-                    FeatureExtractor(list(attrs)),
+                    FeatureExtractor(attrs),
                     list2ragged(),
                     with_array(concatenate(*embeddings)),
                 ),
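A minimal usage sketch of the simplified signature, for orientation. The import path and the concrete width/row values are assumptions for the example, not taken from the commit; per the new docstring, `rows` and `attrs` must have the same length, giving one HashEmbed table per attribute.

    # Sketch only: assumed import location and illustrative sizes.
    from spacy.ml.models import MultiHashEmbed

    embed = MultiHashEmbed(
        width=64,
        attrs=["NORM", "PREFIX", "SHAPE"],
        rows=[2000, 1000, 1000],  # one row count per attribute
        include_static_vectors=False,
    )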

View File

@@ -24,7 +24,7 @@ def test_empty_doc():
     tok2vec = build_Tok2Vec_model(
         MultiHashEmbed(
             width=width,
-            rows=embed_size,
+            rows=[embed_size, embed_size, embed_size, embed_size],
             include_static_vectors=False,
             attrs=["NORM", "PREFIX", "SUFFIX", "SHAPE"],
         ),
@@ -44,7 +44,7 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
     tok2vec = build_Tok2Vec_model(
         MultiHashEmbed(
             width=width,
-            rows=embed_size,
+            rows=[embed_size] * 4,
             include_static_vectors=False,
             attrs=["NORM", "PREFIX", "SUFFIX", "SHAPE"],
         ),
@@ -61,8 +61,8 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
 @pytest.mark.parametrize(
     "width,embed_arch,embed_config,encode_arch,encode_config",
     [
-        (8, MultiHashEmbed, {"rows": 100, "attrs": ["SHAPE", "LOWER"], "include_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 2}),
-        (8, MultiHashEmbed, {"rows": 100, "attrs": {"ORTH": 1.0, "PREFIX": 0.2}, "include_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 6}),
+        (8, MultiHashEmbed, {"rows": [100, 100], "attrs": ["SHAPE", "LOWER"], "include_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 2}),
+        (8, MultiHashEmbed, {"rows": [100, 20], "attrs": ["ORTH", "PREFIX"], "include_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 6}),
         (8, CharacterEmbed, {"rows": 100, "nM": 64, "nC": 8, "also_use_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 3}),
         (8, CharacterEmbed, {"rows": 100, "nM": 16, "nC": 2, "also_use_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 3}),
     ],
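The second MultiHashEmbed case keeps the old table geometry: under the previous dict form, factors of 1.0 and 0.2 over a base of 100 rows work out to the explicit counts now listed.

    # Old factor form vs. new explicit rows (illustration only).
    assert [int(100 * factor) for factor in (1.0, 0.2)] == [100, 20]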
@@ -116,11 +116,11 @@ cfg_string = """
 @architectures = "spacy.Tok2Vec.v1"

 [components.tok2vec.model.embed]
-@architectures = "spacy.MultiHashEmbed.v1"
+@architectures = "spacy.MultiHashEmbed.v2"
 width = ${components.tok2vec.model.encode.width}
-rows = 2000
-also_embed_subwords = true
-also_use_static_vectors = false
+rows = [2000, 1000, 1000, 1000]
+attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
+include_static_vectors = false

 [components.tok2vec.model.encode]
 @architectures = "spacy.MaxoutWindowEncoder.v1"

View File

@@ -63,7 +63,7 @@ def get_tok2vec_kwargs():
     return {
         "embed": MultiHashEmbed(
             width=32,
-            rows=500,
+            rows=[500, 500, 500],
             attrs=["NORM", "PREFIX", "SHAPE"],
             include_static_vectors=False
         ),
@@ -80,7 +80,7 @@ def test_tok2vec():
 def test_multi_hash_embed():
     embed = MultiHashEmbed(
         width=32,
-        rows=500,
+        rows=[500, 500, 500],
         attrs=["NORM", "PREFIX", "SHAPE"],
         include_static_vectors=False
     )
@@ -95,8 +95,8 @@ def test_multi_hash_embed():
     # Now try with different row factors
     embed = MultiHashEmbed(
         width=32,
-        rows=500,
-        attrs={"NORM": 2.0, "PREFIX": 0.1, "SHAPE": 0.5},
+        rows=[1000, 50, 250],
+        attrs=["NORM", "PREFIX", "SHAPE"],
         include_static_vectors=False
     )
     hash_embeds = [node for node in embed.walk() if node.name == "hashembed"]
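A hedged sketch of how the differing row counts could then be checked, assuming Thinc's HashEmbed exposes its number of rows as the "nV" dimension (the walk()/name filtering is taken from the test line above):

    # Sketch only; reuses the hash_embeds list built in the line above.
    assert [he.get_dim("nV") for he in hash_embeds] == [1000, 50, 250]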