Mirror of https://github.com/explosion/spaCy.git (synced 2025-01-12 18:26:30 +03:00)
Allow configuration of MultiHashEmbed features
Update the arguments to the MultiHashEmbed layer so that the embedded attributes can be controlled. A somewhat tricky scheme allows the rows to be specified optionally, per attribute; I think it strikes a reasonable balance between flexibility and convenience.
This commit is contained in:
parent 6a9d14e35a · commit 8ec79ad3fa
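In practice, the new `attrs` argument accepts either a list of attribute names (one HashEmbed table per attribute, each with `rows` rows) or a dict mapping attributes to row factors. A minimal sketch of both calling conventions, using illustrative `width` and `rows` values and the import path from the tests below:

    from spacy.ml.models import MultiHashEmbed

    # List form: one HashEmbed table per attribute, each using `rows` rows.
    embed = MultiHashEmbed(
        width=96,
        rows=2000,
        attrs=["NORM", "PREFIX", "SUFFIX", "SHAPE"],
        include_static_vectors=False,
    )

    # Dict form: values are factors of `rows`, so NORM gets 2000 rows
    # and PREFIX gets int(2000 * 0.2) == 400.
    embed = MultiHashEmbed(
        width=96,
        rows=2000,
        attrs={"NORM": 1.0, "PREFIX": 0.2},
        include_static_vectors=False,
    )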
spacy/ml/models/tok2vec.py
@@ -1,4 +1,4 @@
-from typing import Optional, List, Union
+from typing import Optional, List, Union, Dict
 from thinc.types import Floats2d
 from thinc.api import chain, clone, concatenate, with_array, with_padded
 from thinc.api import Model, noop, list2ragged, ragged2list, HashEmbed
@@ -11,7 +11,7 @@ from ...ml import _character_embed
 from ..staticvectors import StaticVectors
 from ..featureextractor import FeatureExtractor
 from ...pipeline.tok2vec import Tok2VecListener
-from ...attrs import ORTH, LOWER, PREFIX, SUFFIX, SHAPE, intify_attr
+from ...attrs import ORTH, NORM, LOWER, PREFIX, SUFFIX, SHAPE, intify_attr
 
 
 @registry.architectures.register("spacy.Tok2VecListener.v1")
@@ -54,12 +54,16 @@ def build_hash_embed_cnn_tok2vec(
         a language such as Chinese.
     pretrained_vectors (bool): Whether to also use static vectors.
     """
+    if subword_features:
+        attrs = {"NORM": 1.0, "PREFIX": 0.5, "SUFFIX": 0.5, "SHAPE": 0.5}
+    else:
+        attrs = {"NORM": 1.0}
     return build_Tok2Vec_model(
         embed=MultiHashEmbed(
             width=width,
             rows=embed_size,
-            also_embed_subwords=subword_features,
-            also_use_static_vectors=bool(pretrained_vectors),
+            attrs=attrs,
+            include_static_vectors=bool(pretrained_vectors),
         ),
         encode=MaxoutWindowEncoder(
             width=width,
@@ -92,59 +96,89 @@ def build_Tok2Vec_model(
 
 
 @registry.architectures.register("spacy.MultiHashEmbed.v1")
-def MultiHashEmbed(
+def MultiHashEmbed_v1(
     width: int, rows: int, also_embed_subwords: bool, also_use_static_vectors: bool
+) -> Model[List[Doc], List[Floats2d]]:
+    """Previous interface for MultiHashEmbed. This should be removed; it's only
+    here for temporary backwards compatibility."""
+    return MultiHashEmbed(
+        width=width,
+        rows=rows,
+        attrs=[NORM, PREFIX, SUFFIX, SHAPE] if also_embed_subwords else [NORM],
+        include_static_vectors=also_use_static_vectors
+    )
+
+
+@registry.architectures.register("spacy.MultiHashEmbed.v2")
+def MultiHashEmbed(
+    width: int,
+    rows: int,
+    attrs: Union[List[Union[str, int]], Dict[Union[str, int], float]],
+    include_static_vectors: bool
 ) -> Model[List[Doc], List[Floats2d]]:
"""Construct an embedding layer that separately embeds a number of lexical
|
"""Construct an embedding layer that separately embeds a number of lexical
|
||||||
attributes using hash embedding, concatenates the results, and passes it
|
attributes using hash embedding, concatenates the results, and passes it
|
||||||
through a feed-forward subnetwork to build a mixed representations.
|
through a feed-forward subnetwork to build a mixed representations.
|
||||||
|
|
||||||
The features used are the LOWER, PREFIX, SUFFIX and SHAPE, which can have
|
The features used can be configured with the 'attrs' argument. The suggested
|
||||||
varying definitions depending on the Vocab of the Doc object passed in.
|
attributes are NORM, PREFIX, SUFFIX and SHAPE. This lets the model take into
|
||||||
Vectors from pretrained static vectors can also be incorporated into the
|
account some subword information, without contruction a fully character-based
|
||||||
concatenated representation.
|
representation. If pretrained vectors are available, they can be included in
|
||||||
|
the representation as well, with the vectors table will be kept static
|
||||||
|
(i.e. it's not updated).
|
||||||
|
|
||||||
|
The `width` parameter specifices the output width of the layer and the widths
|
||||||
|
of all embedding tables. If static vectors are included, a learned linear
|
||||||
|
layer is used to map the vectors to the specified width before concatenating
|
||||||
|
it with the other embedding outputs. A single Maxout layer is then used to
|
||||||
|
reduce the concatenated vectors to the final width.
|
||||||
|
|
||||||
|
The `rows` parameter controls the number of rows used by the `HashEmbed`
|
||||||
|
tables. The HashEmbed layer needs surprisingly few rows, due to its use of
|
||||||
|
the hashing trick. Generally between 2000 and 10000 rows is sufficient,
|
||||||
|
even for very large vocabularies. You can vary the number of rows per
|
||||||
|
attribute by specifying the attrs as a dict, mapping the keys to float
|
||||||
|
values which are interpreted as factors of `rows`. For instance,
|
||||||
|
attrs={"NORM": 1.0, PREFIX: 0.2} will use rows*1 for the NORM table and
|
||||||
|
rows*0.2 for the PREFIX table. If `attrs` is a list, factors of 1.0 are
|
||||||
|
assumed for all attributes.
|
||||||
|
|
||||||
width (int): The output width. Also used as the width of the embedding tables.
|
width (int): The output width. Also used as the width of the embedding tables.
|
||||||
Recommended values are between 64 and 300.
|
Recommended values are between 64 and 300.
|
||||||
rows (int): The number of rows for the embedding tables. Can be low, due
|
rows (int): The base number of rows for the embedding tables. Can be low, due
|
||||||
to the hashing trick. Embeddings for prefix, suffix and word shape
|
to the hashing trick. The rows can be varied per attribute by providing
|
||||||
use half as many rows. Recommended values are between 2000 and 10000.
|
a dictionary as the value of `attrs`.
|
||||||
also_embed_subwords (bool): Whether to use the PREFIX, SUFFIX and SHAPE
|
attrs (dict or list of attr IDs): The token attributes to embed. A separate
|
||||||
features in the embeddings. If not using these, you may need more
|
embedding table will be constructed for each attribute. Attributes
|
||||||
rows in your hash embeddings, as there will be increased chance of
|
can be specified as a list or as a dictionary, which lets you control
|
||||||
collisions.
|
the number of rows used for each table.
|
||||||
also_use_static_vectors (bool): Whether to also use static word vectors.
|
include_static_vectors (bool): Whether to also use static word vectors.
|
||||||
Requires a vectors table to be loaded in the Doc objects' vocab.
|
Requires a vectors table to be loaded in the Doc objects' vocab.
|
||||||
"""
|
"""
|
||||||
-    cols = [LOWER, PREFIX, SUFFIX, SHAPE, ORTH]
+    if isinstance(attrs, dict):
+        # Exclude tables that would have 0 rows.
+        attrs = {key: value for key, value in attrs.items() if value > 0.0}
+    indices = {attr: i for i, attr in enumerate(attrs)}
     seed = 7
 
     def make_hash_embed(feature):
         nonlocal seed
+        row_factor = attrs[feature] if isinstance(attrs, dict) else 1.0
         seed += 1
         return HashEmbed(
             width,
-            rows if feature == LOWER else rows // 2,
-            column=cols.index(feature),
+            int(rows * row_factor),
+            column=indices[feature],
             seed=seed,
             dropout=0.0,
         )
 
-    if also_embed_subwords:
-        embeddings = [
-            make_hash_embed(LOWER),
-            make_hash_embed(PREFIX),
-            make_hash_embed(SUFFIX),
-            make_hash_embed(SHAPE),
-        ]
-    else:
-        embeddings = [make_hash_embed(LOWER)]
-    concat_size = width * (len(embeddings) + also_use_static_vectors)
-    if also_use_static_vectors:
+    embeddings = [make_hash_embed(attr) for attr in attrs]
+    concat_size = width * (len(embeddings) + include_static_vectors)
+    if include_static_vectors:
         model = chain(
             concatenate(
                 chain(
-                    FeatureExtractor(cols),
+                    FeatureExtractor(list(attrs)),
                     list2ragged(),
                     with_array(concatenate(*embeddings)),
                 ),
@@ -155,7 +189,7 @@ def MultiHashEmbed(
         )
     else:
         model = chain(
-            FeatureExtractor(cols),
+            FeatureExtractor(list(attrs)),
             list2ragged(),
             with_array(concatenate(*embeddings)),
             with_array(Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)),
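The row arithmetic described in the new docstring can be checked directly: factors are applied as int(rows * factor), and zero-valued entries are dropped before any tables are built. A small sanity check in plain Python, mirroring the two comprehensions in the diff above:

    rows = 500
    attrs = {"NORM": 2.0, "PREFIX": 0.1, "SHAPE": 0.5, "SUFFIX": 0.0}

    # Exclude tables that would have 0 rows, as MultiHashEmbed does.
    attrs = {key: value for key, value in attrs.items() if value > 0.0}

    # One table per remaining attribute, with int(rows * factor) rows each.
    table_rows = {key: int(rows * factor) for key, factor in attrs.items()}
    print(table_rows)  # {'NORM': 1000, 'PREFIX': 50, 'SHAPE': 250}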
spacy/tests/ml/test_models.py
@@ -6,6 +6,7 @@ from numpy.testing import assert_array_equal
 import numpy
 
 from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder
+from spacy.ml.models import MultiHashEmbed_v1
 from spacy.ml.models import build_text_classifier, build_simple_cnn_text_classifier
 from spacy.lang.en import English
 from spacy.lang.en.examples import sentences as EN_SENTENCES
@@ -61,7 +62,10 @@ def get_tok2vec_kwargs():
     # This actually creates models, so seems best to put it in a function.
     return {
         "embed": MultiHashEmbed(
-            width=32, rows=500, also_embed_subwords=True, also_use_static_vectors=False
+            width=32,
+            rows=500,
+            attrs=["NORM", "PREFIX", "SHAPE"],
+            include_static_vectors=False
         ),
         "encode": MaxoutWindowEncoder(
             width=32, depth=2, maxout_pieces=2, window_size=1
@@ -73,6 +77,32 @@ def test_tok2vec():
     return build_Tok2Vec_model(**get_tok2vec_kwargs())
 
 
+def test_multi_hash_embed():
+    embed = MultiHashEmbed(
+        width=32,
+        rows=500,
+        attrs=["NORM", "PREFIX", "SHAPE"],
+        include_static_vectors=False
+    )
+    hash_embeds = [node for node in embed.walk() if node.name == "hashembed"]
+    assert len(hash_embeds) == 3
+    # Check they look at different columns.
+    assert list(sorted(he.attrs["column"] for he in hash_embeds)) == [0, 1, 2]
+    # Check they use different seeds
+    assert len(set(he.attrs["seed"] for he in hash_embeds)) == 3
+    # Check they all have the same number of rows
+    assert [he.get_dim("nV") for he in hash_embeds] == [500, 500, 500]
+    # Now try with different row factors
+    embed = MultiHashEmbed(
+        width=32,
+        rows=500,
+        attrs={"NORM": 2.0, "PREFIX": 0.1, "SHAPE": 0.5},
+        include_static_vectors=False
+    )
+    hash_embeds = [node for node in embed.walk() if node.name == "hashembed"]
+    assert [he.get_dim("nV") for he in hash_embeds] == [1000, 50, 250]
+
+
 @pytest.mark.parametrize(
     "seed,model_func,kwargs",
     [
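For callers of the old interface, the v1 wrapper above spells out the mapping: `also_embed_subwords` selects the attribute list, and `also_use_static_vectors` is renamed to `include_static_vectors`. A sketch of the equivalence, assuming string attribute names are accepted interchangeably with the attr ID constants the wrapper itself passes:

    from spacy.ml.models import MultiHashEmbed, MultiHashEmbed_v1

    # Old interface, routed through the temporary v1 wrapper...
    old_style = MultiHashEmbed_v1(
        width=32, rows=500, also_embed_subwords=True, also_use_static_vectors=False
    )

    # ...builds the same architecture as the v2 call it delegates to.
    new_style = MultiHashEmbed(
        width=32,
        rows=500,
        attrs=["NORM", "PREFIX", "SUFFIX", "SHAPE"],
        include_static_vectors=False,
    )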