mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-25 00:34:20 +03:00
Allow configuration of MultiHashEmbed features
Update arguments to MultiHashEmbed layer so that the attributes can be controlled. A kind of tricky scheme is used to allow optional specification of the rows. I think it's an okay balance between flexibility and convenience.
This commit is contained in:
parent
6a9d14e35a
commit
8ec79ad3fa
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional, List, Union
|
||||
from typing import Optional, List, Union, Dict
|
||||
from thinc.types import Floats2d
|
||||
from thinc.api import chain, clone, concatenate, with_array, with_padded
|
||||
from thinc.api import Model, noop, list2ragged, ragged2list, HashEmbed
|
||||
|
@ -11,7 +11,7 @@ from ...ml import _character_embed
|
|||
from ..staticvectors import StaticVectors
|
||||
from ..featureextractor import FeatureExtractor
|
||||
from ...pipeline.tok2vec import Tok2VecListener
|
||||
from ...attrs import ORTH, LOWER, PREFIX, SUFFIX, SHAPE, intify_attr
|
||||
from ...attrs import ORTH, NORM, LOWER, PREFIX, SUFFIX, SHAPE, intify_attr
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.Tok2VecListener.v1")
|
||||
|
@ -54,12 +54,16 @@ def build_hash_embed_cnn_tok2vec(
|
|||
a language such as Chinese.
|
||||
pretrained_vectors (bool): Whether to also use static vectors.
|
||||
"""
|
||||
if subword_features:
|
||||
attrs = {"NORM": 1.0, "PREFIX": 0.5, "SUFFIX": 0.5, "SHAPE": 0.5}
|
||||
else:
|
||||
attrs = {"NORM": 1.0}
|
||||
return build_Tok2Vec_model(
|
||||
embed=MultiHashEmbed(
|
||||
width=width,
|
||||
rows=embed_size,
|
||||
also_embed_subwords=subword_features,
|
||||
also_use_static_vectors=bool(pretrained_vectors),
|
||||
attrs=attrs,
|
||||
include_static_vectors=bool(pretrained_vectors),
|
||||
),
|
||||
encode=MaxoutWindowEncoder(
|
||||
width=width,
|
||||
|
@ -92,59 +96,89 @@ def build_Tok2Vec_model(
|
|||
|
||||
|
||||
@registry.architectures.register("spacy.MultiHashEmbed.v1")
|
||||
def MultiHashEmbed(
|
||||
def MultiHashEmbed_v1(
|
||||
width: int, rows: int, also_embed_subwords: bool, also_use_static_vectors: bool
|
||||
) -> Model[List[Doc], List[Floats2d]]:
|
||||
"""Previous interface for MultiHashEmbed. This should be removed, it's only
|
||||
here as a temporary compatibility."""
|
||||
return MultiHashEmbed(
|
||||
width=width,
|
||||
rows=rows,
|
||||
attrs=[NORM, PREFIX, SUFFIX, SHAPE] if also_embed_subwords else [NORM],
|
||||
include_static_vectors=also_use_static_vectors
|
||||
)
|
||||
|
||||
@registry.architectures.register("spacy.MultiHashEmbed.v2")
|
||||
def MultiHashEmbed(
|
||||
width: int,
|
||||
rows: int,
|
||||
attrs: Union[List[Union[str, int]], Dict[Union[str, int], float]],
|
||||
include_static_vectors: bool
|
||||
) -> Model[List[Doc], List[Floats2d]]:
|
||||
"""Construct an embedding layer that separately embeds a number of lexical
|
||||
attributes using hash embedding, concatenates the results, and passes it
|
||||
through a feed-forward subnetwork to build a mixed representations.
|
||||
|
||||
The features used are the LOWER, PREFIX, SUFFIX and SHAPE, which can have
|
||||
varying definitions depending on the Vocab of the Doc object passed in.
|
||||
Vectors from pretrained static vectors can also be incorporated into the
|
||||
concatenated representation.
|
||||
The features used can be configured with the 'attrs' argument. The suggested
|
||||
attributes are NORM, PREFIX, SUFFIX and SHAPE. This lets the model take into
|
||||
account some subword information, without contruction a fully character-based
|
||||
representation. If pretrained vectors are available, they can be included in
|
||||
the representation as well, with the vectors table will be kept static
|
||||
(i.e. it's not updated).
|
||||
|
||||
The `width` parameter specifices the output width of the layer and the widths
|
||||
of all embedding tables. If static vectors are included, a learned linear
|
||||
layer is used to map the vectors to the specified width before concatenating
|
||||
it with the other embedding outputs. A single Maxout layer is then used to
|
||||
reduce the concatenated vectors to the final width.
|
||||
|
||||
The `rows` parameter controls the number of rows used by the `HashEmbed`
|
||||
tables. The HashEmbed layer needs surprisingly few rows, due to its use of
|
||||
the hashing trick. Generally between 2000 and 10000 rows is sufficient,
|
||||
even for very large vocabularies. You can vary the number of rows per
|
||||
attribute by specifying the attrs as a dict, mapping the keys to float
|
||||
values which are interpreted as factors of `rows`. For instance,
|
||||
attrs={"NORM": 1.0, PREFIX: 0.2} will use rows*1 for the NORM table and
|
||||
rows*0.2 for the PREFIX table. If `attrs` is a list, factors of 1.0 are
|
||||
assumed for all attributes.
|
||||
|
||||
width (int): The output width. Also used as the width of the embedding tables.
|
||||
Recommended values are between 64 and 300.
|
||||
rows (int): The number of rows for the embedding tables. Can be low, due
|
||||
to the hashing trick. Embeddings for prefix, suffix and word shape
|
||||
use half as many rows. Recommended values are between 2000 and 10000.
|
||||
also_embed_subwords (bool): Whether to use the PREFIX, SUFFIX and SHAPE
|
||||
features in the embeddings. If not using these, you may need more
|
||||
rows in your hash embeddings, as there will be increased chance of
|
||||
collisions.
|
||||
also_use_static_vectors (bool): Whether to also use static word vectors.
|
||||
rows (int): The base number of rows for the embedding tables. Can be low, due
|
||||
to the hashing trick. The rows can be varied per attribute by providing
|
||||
a dictionary as the value of `attrs`.
|
||||
attrs (dict or list of attr IDs): The token attributes to embed. A separate
|
||||
embedding table will be constructed for each attribute. Attributes
|
||||
can be specified as a list or as a dictionary, which lets you control
|
||||
the number of rows used for each table.
|
||||
include_static_vectors (bool): Whether to also use static word vectors.
|
||||
Requires a vectors table to be loaded in the Doc objects' vocab.
|
||||
"""
|
||||
cols = [LOWER, PREFIX, SUFFIX, SHAPE, ORTH]
|
||||
if isinstance(attrs, dict):
|
||||
# Exclude tables that would have 0 rows.
|
||||
attrs = {key: value for key, value in attrs.items() if value > 0.0}
|
||||
indices = {attr: i for i, attr in enumerate(attrs)}
|
||||
seed = 7
|
||||
|
||||
def make_hash_embed(feature):
|
||||
nonlocal seed
|
||||
row_factor = attrs[feature] if isinstance(attrs, dict) else 1.0
|
||||
seed += 1
|
||||
return HashEmbed(
|
||||
width,
|
||||
rows if feature == LOWER else rows // 2,
|
||||
column=cols.index(feature),
|
||||
int(rows * row_factor),
|
||||
column=indices[feature],
|
||||
seed=seed,
|
||||
dropout=0.0,
|
||||
)
|
||||
|
||||
if also_embed_subwords:
|
||||
embeddings = [
|
||||
make_hash_embed(LOWER),
|
||||
make_hash_embed(PREFIX),
|
||||
make_hash_embed(SUFFIX),
|
||||
make_hash_embed(SHAPE),
|
||||
]
|
||||
else:
|
||||
embeddings = [make_hash_embed(LOWER)]
|
||||
concat_size = width * (len(embeddings) + also_use_static_vectors)
|
||||
if also_use_static_vectors:
|
||||
embeddings = [make_hash_embed(attr) for attr in attrs]
|
||||
concat_size = width * (len(embeddings) + include_static_vectors)
|
||||
if include_static_vectors:
|
||||
model = chain(
|
||||
concatenate(
|
||||
chain(
|
||||
FeatureExtractor(cols),
|
||||
FeatureExtractor(list(attrs)),
|
||||
list2ragged(),
|
||||
with_array(concatenate(*embeddings)),
|
||||
),
|
||||
|
@ -155,7 +189,7 @@ def MultiHashEmbed(
|
|||
)
|
||||
else:
|
||||
model = chain(
|
||||
FeatureExtractor(cols),
|
||||
FeatureExtractor(list(attrs)),
|
||||
list2ragged(),
|
||||
with_array(concatenate(*embeddings)),
|
||||
with_array(Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)),
|
||||
|
|
|
@ -6,6 +6,7 @@ from numpy.testing import assert_array_equal
|
|||
import numpy
|
||||
|
||||
from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder
|
||||
from spacy.ml.models import MultiHashEmbed_v1
|
||||
from spacy.ml.models import build_text_classifier, build_simple_cnn_text_classifier
|
||||
from spacy.lang.en import English
|
||||
from spacy.lang.en.examples import sentences as EN_SENTENCES
|
||||
|
@ -61,7 +62,10 @@ def get_tok2vec_kwargs():
|
|||
# This actually creates models, so seems best to put it in a function.
|
||||
return {
|
||||
"embed": MultiHashEmbed(
|
||||
width=32, rows=500, also_embed_subwords=True, also_use_static_vectors=False
|
||||
width=32,
|
||||
rows=500,
|
||||
attrs=["NORM", "PREFIX", "SHAPE"],
|
||||
include_static_vectors=False
|
||||
),
|
||||
"encode": MaxoutWindowEncoder(
|
||||
width=32, depth=2, maxout_pieces=2, window_size=1
|
||||
|
@ -73,6 +77,32 @@ def test_tok2vec():
|
|||
return build_Tok2Vec_model(**get_tok2vec_kwargs())
|
||||
|
||||
|
||||
def test_multi_hash_embed():
|
||||
embed = MultiHashEmbed(
|
||||
width=32,
|
||||
rows=500,
|
||||
attrs=["NORM", "PREFIX", "SHAPE"],
|
||||
include_static_vectors=False
|
||||
)
|
||||
hash_embeds = [node for node in embed.walk() if node.name == "hashembed"]
|
||||
assert len(hash_embeds) == 3
|
||||
# Check they look at different columns.
|
||||
assert list(sorted(he.attrs["column"] for he in hash_embeds)) == [0, 1, 2]
|
||||
# Check they use different seeds
|
||||
assert len(set(he.attrs["seed"] for he in hash_embeds)) == 3
|
||||
# Check they all have the same number of rows
|
||||
assert [he.get_dim("nV") for he in hash_embeds] == [500, 500, 500]
|
||||
# Now try with different row factors
|
||||
embed = MultiHashEmbed(
|
||||
width=32,
|
||||
rows=500,
|
||||
attrs={"NORM": 2.0, "PREFIX": 0.1, "SHAPE": 0.5},
|
||||
include_static_vectors=False
|
||||
)
|
||||
hash_embeds = [node for node in embed.walk() if node.name == "hashembed"]
|
||||
assert [he.get_dim("nV") for he in hash_embeds] == [1000, 50, 250]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"seed,model_func,kwargs",
|
||||
[
|
||||
|
|
Loading…
Reference in New Issue
Block a user