spaCy (mirror of https://github.com/explosion/spaCy.git)

Simplify MultiHashEmbed signature

commit 6dcc4a0ba6
parent 71e73ed0a6
@@ -55,13 +55,15 @@ def build_hash_embed_cnn_tok2vec(
     pretrained_vectors (bool): Whether to also use static vectors.
     """
     if subword_features:
-        attrs = {"NORM": 1.0, "PREFIX": 0.5, "SUFFIX": 0.5, "SHAPE": 0.5}
+        attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
+        row_sizes = [embed_size, embed_size//2, embed_size//2, embed_size//2]
     else:
-        attrs = {"NORM": 1.0}
+        attrs = ["NORM"]
+        row_sizes = [embed_size]
     return build_Tok2Vec_model(
         embed=MultiHashEmbed(
             width=width,
-            rows=embed_size,
+            rows=row_sizes,
            attrs=attrs,
            include_static_vectors=bool(pretrained_vectors),
        ),
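
To make the new per-table sizing concrete, here is a minimal sketch of the bookkeeping that build_hash_embed_cnn_tok2vec now does before calling MultiHashEmbed. The embed_size value is illustrative, not taken from the diff:

# Sketch of the attrs/row_sizes computation above, with an illustrative embed_size.
embed_size = 2000
subword_features = True

if subword_features:
    attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
    row_sizes = [embed_size, embed_size // 2, embed_size // 2, embed_size // 2]
else:
    attrs = ["NORM"]
    row_sizes = [embed_size]

# One embedding table per attribute, each with an explicit row count:
print(list(zip(attrs, row_sizes)))
# [('NORM', 2000), ('PREFIX', 1000), ('SUFFIX', 1000), ('SHAPE', 1000)]
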
@@ -103,7 +105,7 @@ def MultiHashEmbed_v1(
     here as a temporary compatibility."""
     return MultiHashEmbed(
         width=width,
-        rows=rows,
+        rows=[rows, rows//2, rows//2, rows//2] if also_embed_subwords else [rows],
         attrs=[NORM, PREFIX, SUFFIX, SHAPE] if also_embed_subwords else [NORM],
         include_static_vectors=also_use_static_vectors,
     )
@@ -112,8 +114,8 @@ def MultiHashEmbed_v1(
 @registry.architectures.register("spacy.MultiHashEmbed.v2")
 def MultiHashEmbed(
     width: int,
-    rows: int,
-    attrs: Union[List[Union[str, int]], Dict[Union[str, int], float]],
+    attrs: List[Union[str, int]],
+    rows: List[int],
     include_static_vectors: bool,
 ) -> Model[List[Doc], List[Floats2d]]:
     """Construct an embedding layer that separately embeds a number of lexical
@@ -136,50 +138,38 @@ def MultiHashEmbed(
     The `rows` parameter controls the number of rows used by the `HashEmbed`
     tables. The HashEmbed layer needs surprisingly few rows, due to its use of
     the hashing trick. Generally between 2000 and 10000 rows is sufficient,
-    even for very large vocabularies. You can vary the number of rows per
-    attribute by specifying the attrs as a dict, mapping the keys to float
-    values which are interpreted as factors of `rows`. For instance,
-    attrs={"NORM": 1.0, PREFIX: 0.2} will use rows*1 for the NORM table and
-    rows*0.2 for the PREFIX table. If `attrs` is a list, factors of 1.0 are
-    assumed for all attributes.
+    even for very large vocabularies. A number of rows must be specified for each
+    table, so the `rows` list must be of the same length as the `attrs` parameter.

     width (int): The output width. Also used as the width of the embedding tables.
         Recommended values are between 64 and 300.
-    rows (int): The base number of rows for the embedding tables. Can be low, due
-        to the hashing trick. The rows can be varied per attribute by providing
-        a dictionary as the value of `attrs`.
-    attrs (dict or list of attr IDs): The token attributes to embed. A separate
-        embedding table will be constructed for each attribute. Attributes
-        can be specified as a list or as a dictionary, which lets you control
-        the number of rows used for each table.
+    attrs (list of attr IDs): The token attributes to embed. A separate
+        embedding table will be constructed for each attribute.
+    rows (List[int]): The number of rows in the embedding tables. Must have the
+        same length as attrs.
     include_static_vectors (bool): Whether to also use static word vectors.
         Requires a vectors table to be loaded in the Doc objects' vocab.
     """
-    if isinstance(attrs, dict):
-        # Exclude tables that would have 0 rows.
-        attrs = {key: value for key, value in attrs.items() if value > 0.0}
-    indices = {attr: i for i, attr in enumerate(attrs)}
     seed = 7

-    def make_hash_embed(feature):
+    def make_hash_embed(index):
         nonlocal seed
-        row_factor = attrs[feature] if isinstance(attrs, dict) else 1.0
         seed += 1
         return HashEmbed(
             width,
-            int(rows * row_factor),
-            column=indices[feature],
+            rows[index],
+            column=index,
             seed=seed,
             dropout=0.0,
         )

-    embeddings = [make_hash_embed(attr) for attr in attrs]
+    embeddings = [make_hash_embed(i) for i in range(len(attrs))]
     concat_size = width * (len(embeddings) + include_static_vectors)
     if include_static_vectors:
         model = chain(
             concatenate(
                 chain(
-                    FeatureExtractor(list(attrs)),
+                    FeatureExtractor(attrs),
                     list2ragged(),
                     with_array(concatenate(*embeddings)),
                 ),
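
The updated docstring requires `rows` and `attrs` to be parallel lists, one embedding table per attribute. A minimal usage sketch, modelled on test_multi_hash_embed further down; the import path and the "nV" dimension check are assumptions based on the spaCy v3 source layout and tests, not part of this diff:

# Assumed import path for the v2 MultiHashEmbed factory.
from spacy.ml.models.tok2vec import MultiHashEmbed

# One HashEmbed table per attribute; rows[i] is the table size for attrs[i].
embed = MultiHashEmbed(
    width=32,
    rows=[1000, 50, 250],
    attrs=["NORM", "PREFIX", "SHAPE"],
    include_static_vectors=False,
)

# Each HashEmbed node should pick up its own row count (cf. the embed.walk()
# check in test_multi_hash_embed below).
hash_embeds = [node for node in embed.walk() if node.name == "hashembed"]
print([he.get_dim("nV") for he in hash_embeds])  # expected: [1000, 50, 250]
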
@@ -24,7 +24,7 @@ def test_empty_doc():
     tok2vec = build_Tok2Vec_model(
         MultiHashEmbed(
             width=width,
-            rows=embed_size,
+            rows=[embed_size, embed_size, embed_size, embed_size],
             include_static_vectors=False,
             attrs=["NORM", "PREFIX", "SUFFIX", "SHAPE"],
         ),
@@ -44,7 +44,7 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
     tok2vec = build_Tok2Vec_model(
         MultiHashEmbed(
             width=width,
-            rows=embed_size,
+            rows=[embed_size] * 4,
             include_static_vectors=False,
             attrs=["NORM", "PREFIX", "SUFFIX", "SHAPE"],
         ),
@@ -61,8 +61,8 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
 @pytest.mark.parametrize(
     "width,embed_arch,embed_config,encode_arch,encode_config",
     [
-        (8, MultiHashEmbed, {"rows": 100, "attrs": ["SHAPE", "LOWER"], "include_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 2}),
-        (8, MultiHashEmbed, {"rows": 100, "attrs": {"ORTH": 1.0, "PREFIX": 0.2}, "include_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 6}),
+        (8, MultiHashEmbed, {"rows": [100, 100], "attrs": ["SHAPE", "LOWER"], "include_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 2}),
+        (8, MultiHashEmbed, {"rows": [100, 20], "attrs": ["ORTH", "PREFIX"], "include_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 6}),
         (8, CharacterEmbed, {"rows": 100, "nM": 64, "nC": 8, "also_use_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 3}),
         (8, CharacterEmbed, {"rows": 100, "nM": 16, "nC": 2, "also_use_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 3}),
     ],
@@ -116,11 +116,11 @@ cfg_string = """
    @architectures = "spacy.Tok2Vec.v1"

    [components.tok2vec.model.embed]
-    @architectures = "spacy.MultiHashEmbed.v1"
+    @architectures = "spacy.MultiHashEmbed.v2"
    width = ${components.tok2vec.model.encode.width}
-    rows = 2000
-    also_embed_subwords = true
-    also_use_static_vectors = false
+    rows = [2000, 1000, 1000, 1000]
+    attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
+    include_static_vectors = false

    [components.tok2vec.model.encode]
    @architectures = "spacy.MaxoutWindowEncoder.v1"
@@ -63,7 +63,7 @@ def get_tok2vec_kwargs():
     return {
         "embed": MultiHashEmbed(
             width=32,
-            rows=500,
+            rows=[500, 500, 500],
             attrs=["NORM", "PREFIX", "SHAPE"],
             include_static_vectors=False
         ),
@@ -80,7 +80,7 @@ def test_tok2vec():
 def test_multi_hash_embed():
     embed = MultiHashEmbed(
         width=32,
-        rows=500,
+        rows=[500, 500, 500],
         attrs=["NORM", "PREFIX", "SHAPE"],
         include_static_vectors=False
     )
@@ -95,8 +95,8 @@ def test_multi_hash_embed():
     # Now try with different row factors
     embed = MultiHashEmbed(
         width=32,
-        rows=500,
-        attrs={"NORM": 2.0, "PREFIX": 0.1, "SHAPE": 0.5},
+        rows=[1000, 50, 250],
+        attrs=["NORM", "PREFIX", "SHAPE"],
         include_static_vectors=False
     )
     hash_embeds = [node for node in embed.walk() if node.name == "hashembed"]