Simplify MultiHashEmbed signature

Matthew Honnibal 2020-10-05 19:57:45 +02:00
parent 71e73ed0a6
commit 6dcc4a0ba6
3 changed files with 31 additions and 41 deletions

@@ -55,13 +55,15 @@ def build_hash_embed_cnn_tok2vec(
     pretrained_vectors (bool): Whether to also use static vectors.
     """
     if subword_features:
-        attrs = {"NORM": 1.0, "PREFIX": 0.5, "SUFFIX": 0.5, "SHAPE": 0.5}
+        attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
+        row_sizes = [embed_size, embed_size//2, embed_size//2, embed_size//2]
     else:
-        attrs = {"NORM": 1.0}
+        attrs = ["NORM"]
+        row_sizes = [embed_size]
     return build_Tok2Vec_model(
         embed=MultiHashEmbed(
             width=width,
-            rows=embed_size,
+            rows=row_sizes,
             attrs=attrs,
             include_static_vectors=bool(pretrained_vectors),
         ),
@@ -103,7 +105,7 @@ def MultiHashEmbed_v1(
     here as a temporary compatibility."""
     return MultiHashEmbed(
         width=width,
-        rows=rows,
+        rows=[rows, rows//2, rows//2, rows//2] if also_embed_subwords else [rows],
        attrs=[NORM, PREFIX, SUFFIX, SHAPE] if also_embed_subwords else [NORM],
        include_static_vectors=also_use_static_vectors,
    )
@@ -112,8 +114,8 @@ def MultiHashEmbed_v1(
 @registry.architectures.register("spacy.MultiHashEmbed.v2")
 def MultiHashEmbed(
     width: int,
-    rows: int,
-    attrs: Union[List[Union[str, int]], Dict[Union[str, int], float]],
+    attrs: List[Union[str, int]],
+    rows: List[int],
     include_static_vectors: bool,
 ) -> Model[List[Doc], List[Floats2d]]:
     """Construct an embedding layer that separately embeds a number of lexical
@@ -136,50 +138,38 @@ def MultiHashEmbed(
     The `rows` parameter controls the number of rows used by the `HashEmbed`
     tables. The HashEmbed layer needs surprisingly few rows, due to its use of
     the hashing trick. Generally between 2000 and 10000 rows is sufficient,
-    even for very large vocabularies. You can vary the number of rows per
-    attribute by specifying the attrs as a dict, mapping the keys to float
-    values which are interpreted as factors of `rows`. For instance,
-    attrs={"NORM": 1.0, PREFIX: 0.2} will use rows*1 for the NORM table and
-    rows*0.2 for the PREFIX table. If `attrs` is a list, factors of 1.0 are
-    assumed for all attributes.
+    even for very large vocabularies. A number of rows must be specified for each
+    table, so the `rows` list must be of the same length as the `attrs` parameter.
 
     width (int): The output width. Also used as the width of the embedding tables.
         Recommended values are between 64 and 300.
-    rows (int): The base number of rows for the embedding tables. Can be low, due
-        to the hashing trick. The rows can be varied per attribute by providing
-        a dictionary as the value of `attrs`.
-    attrs (dict or list of attr IDs): The token attributes to embed. A separate
-        embedding table will be constructed for each attribute. Attributes
-        can be specified as a list or as a dictionary, which lets you control
-        the number of rows used for each table.
+    attrs (list of attr IDs): The token attributes to embed. A separate
+        embedding table will be constructed for each attribute.
+    rows (List[int]): The number of rows in the embedding tables. Must have the
+        same length as attrs.
     include_static_vectors (bool): Whether to also use static word vectors.
         Requires a vectors table to be loaded in the Doc objects' vocab.
     """
-    if isinstance(attrs, dict):
-        # Exclude tables that would have 0 rows.
-        attrs = {key: value for key, value in attrs.items() if value > 0.0}
-    indices = {attr: i for i, attr in enumerate(attrs)}
     seed = 7
 
-    def make_hash_embed(feature):
+    def make_hash_embed(index):
         nonlocal seed
-        row_factor = attrs[feature] if isinstance(attrs, dict) else 1.0
         seed += 1
         return HashEmbed(
             width,
-            int(rows * row_factor),
-            column=indices[feature],
+            rows[index],
+            column=index,
             seed=seed,
             dropout=0.0,
         )
 
-    embeddings = [make_hash_embed(attr) for attr in attrs]
+    embeddings = [make_hash_embed(i) for i in range(len(attrs))]
     concat_size = width * (len(embeddings) + include_static_vectors)
     if include_static_vectors:
         model = chain(
             concatenate(
                 chain(
-                    FeatureExtractor(list(attrs)),
+                    FeatureExtractor(attrs),
                     list2ragged(),
                     with_array(concatenate(*embeddings)),
                 ),
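
For context, a minimal sketch of the new call signature described in the docstring above. The import path and the specific row counts are illustrative, not part of this diff: each attribute is paired with an explicit row count, so `rows` and `attrs` must have the same length.

# Illustrative usage of the simplified signature (not part of this commit).
from spacy.ml.models import MultiHashEmbed

embed = MultiHashEmbed(
    width=64,
    attrs=["NORM", "PREFIX", "SUFFIX", "SHAPE"],
    rows=[2000, 1000, 1000, 1000],  # one table size per attribute, same order as attrs
    include_static_vectors=False,
)

Under the old signature the same layout was expressed indirectly, as a base rows value plus a dict of per-attribute factors (e.g. attrs={"NORM": 1.0, "PREFIX": 0.5}); the explicit list makes the per-table sizes visible at the call site.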

@@ -24,7 +24,7 @@ def test_empty_doc():
     tok2vec = build_Tok2Vec_model(
         MultiHashEmbed(
             width=width,
-            rows=embed_size,
+            rows=[embed_size, embed_size, embed_size, embed_size],
             include_static_vectors=False,
             attrs=["NORM", "PREFIX", "SUFFIX", "SHAPE"],
         ),
@@ -44,7 +44,7 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
     tok2vec = build_Tok2Vec_model(
         MultiHashEmbed(
             width=width,
-            rows=embed_size,
+            rows=[embed_size] * 4,
             include_static_vectors=False,
             attrs=["NORM", "PREFIX", "SUFFIX", "SHAPE"],
         ),
@@ -61,8 +61,8 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
 @pytest.mark.parametrize(
     "width,embed_arch,embed_config,encode_arch,encode_config",
     [
-        (8, MultiHashEmbed, {"rows": 100, "attrs": ["SHAPE", "LOWER"], "include_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 2}),
-        (8, MultiHashEmbed, {"rows": 100, "attrs": {"ORTH": 1.0, "PREFIX": 0.2}, "include_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 6}),
+        (8, MultiHashEmbed, {"rows": [100, 100], "attrs": ["SHAPE", "LOWER"], "include_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 2}),
+        (8, MultiHashEmbed, {"rows": [100, 20], "attrs": ["ORTH", "PREFIX"], "include_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 6}),
         (8, CharacterEmbed, {"rows": 100, "nM": 64, "nC": 8, "also_use_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 3}),
         (8, CharacterEmbed, {"rows": 100, "nM": 16, "nC": 2, "also_use_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 3}),
     ],
@@ -116,11 +116,11 @@ cfg_string = """
     @architectures = "spacy.Tok2Vec.v1"
 
     [components.tok2vec.model.embed]
-    @architectures = "spacy.MultiHashEmbed.v1"
+    @architectures = "spacy.MultiHashEmbed.v2"
     width = ${components.tok2vec.model.encode.width}
-    rows = 2000
-    also_embed_subwords = true
-    also_use_static_vectors = false
+    rows = [2000, 1000, 1000, 1000]
+    attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
+    include_static_vectors = false
 
     [components.tok2vec.model.encode]
     @architectures = "spacy.MaxoutWindowEncoder.v1"

@@ -63,7 +63,7 @@ def get_tok2vec_kwargs():
     return {
         "embed": MultiHashEmbed(
             width=32,
-            rows=500,
+            rows=[500, 500, 500],
             attrs=["NORM", "PREFIX", "SHAPE"],
             include_static_vectors=False
         ),
@@ -80,7 +80,7 @@ def test_tok2vec():
 def test_multi_hash_embed():
     embed = MultiHashEmbed(
         width=32,
-        rows=500,
+        rows=[500, 500, 500],
         attrs=["NORM", "PREFIX", "SHAPE"],
         include_static_vectors=False
     )
@@ -95,8 +95,8 @@ def test_multi_hash_embed():
     # Now try with different row factors
     embed = MultiHashEmbed(
         width=32,
-        rows=500,
-        attrs={"NORM": 2.0, "PREFIX": 0.1, "SHAPE": 0.5},
+        rows=[1000, 50, 250],
+        attrs=["NORM", "PREFIX", "SHAPE"],
         include_static_vectors=False
     )
     hash_embeds = [node for node in embed.walk() if node.name == "hashembed"]
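
A hypothetical continuation of the check above, assuming Thinc's HashEmbed nodes expose their row count as the "nV" dimension and their column offset in `attrs` (as the existing walk-based test relies on):

# Each table should pick up the per-attribute row count, in attrs order,
# and the three tables should read from distinct feature columns.
assert [node.get_dim("nV") for node in hash_embeds] == [1000, 50, 250]
assert sorted(node.attrs["column"] for node in hash_embeds) == [0, 1, 2]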