mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Avoid relying on NORM in default v3 models (#6176)
* Allow CharacterEmbed to specify feature * Default to LOWER in character embed * Update tok2vec * Use LOWER, not NORM
This commit is contained in:
parent
50162b8726
commit
300e5a9928
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional, List
|
||||
from typing import Optional, List, Union
|
||||
from thinc.types import Floats2d
|
||||
from thinc.api import chain, clone, concatenate, with_array, with_padded
|
||||
from thinc.api import Model, noop, list2ragged, ragged2list, HashEmbed
|
||||
|
@ -10,7 +10,7 @@ from ...ml import _character_embed
|
|||
from ..staticvectors import StaticVectors
|
||||
from ..featureextractor import FeatureExtractor
|
||||
from ...pipeline.tok2vec import Tok2VecListener
|
||||
from ...attrs import ORTH, NORM, PREFIX, SUFFIX, SHAPE
|
||||
from ...attrs import ORTH, LOWER, PREFIX, SUFFIX, SHAPE, intify_attr
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.Tok2VecListener.v1")
|
||||
|
@ -98,7 +98,7 @@ def MultiHashEmbed(
|
|||
attributes using hash embedding, concatenates the results, and passes it
|
||||
through a feed-forward subnetwork to build a mixed representations.
|
||||
|
||||
The features used are the NORM, PREFIX, SUFFIX and SHAPE, which can have
|
||||
The features used are the LOWER, PREFIX, SUFFIX and SHAPE, which can have
|
||||
varying definitions depending on the Vocab of the Doc object passed in.
|
||||
Vectors from pretrained static vectors can also be incorporated into the
|
||||
concatenated representation.
|
||||
|
@ -115,7 +115,7 @@ def MultiHashEmbed(
|
|||
also_use_static_vectors (bool): Whether to also use static word vectors.
|
||||
Requires a vectors table to be loaded in the Doc objects' vocab.
|
||||
"""
|
||||
cols = [NORM, PREFIX, SUFFIX, SHAPE, ORTH]
|
||||
cols = [LOWER, PREFIX, SUFFIX, SHAPE, ORTH]
|
||||
seed = 7
|
||||
|
||||
def make_hash_embed(feature):
|
||||
|
@ -123,7 +123,7 @@ def MultiHashEmbed(
|
|||
seed += 1
|
||||
return HashEmbed(
|
||||
width,
|
||||
rows if feature == NORM else rows // 2,
|
||||
rows if feature == LOWER else rows // 2,
|
||||
column=cols.index(feature),
|
||||
seed=seed,
|
||||
dropout=0.0,
|
||||
|
@ -131,13 +131,13 @@ def MultiHashEmbed(
|
|||
|
||||
if also_embed_subwords:
|
||||
embeddings = [
|
||||
make_hash_embed(NORM),
|
||||
make_hash_embed(LOWER),
|
||||
make_hash_embed(PREFIX),
|
||||
make_hash_embed(SUFFIX),
|
||||
make_hash_embed(SHAPE),
|
||||
]
|
||||
else:
|
||||
embeddings = [make_hash_embed(NORM)]
|
||||
embeddings = [make_hash_embed(LOWER)]
|
||||
concat_size = width * (len(embeddings) + also_use_static_vectors)
|
||||
if also_use_static_vectors:
|
||||
model = chain(
|
||||
|
@ -165,7 +165,8 @@ def MultiHashEmbed(
|
|||
|
||||
@registry.architectures.register("spacy.CharacterEmbed.v1")
|
||||
def CharacterEmbed(
|
||||
width: int, rows: int, nM: int, nC: int, also_use_static_vectors: bool
|
||||
width: int, rows: int, nM: int, nC: int, also_use_static_vectors: bool,
|
||||
feature: Union[int, str]="LOWER"
|
||||
):
|
||||
"""Construct an embedded representation based on character embeddings, using
|
||||
a feed-forward network. A fixed number of UTF-8 byte characters are used for
|
||||
|
@ -179,12 +180,13 @@ def CharacterEmbed(
|
|||
of being in an arbitrary position depending on the word length.
|
||||
|
||||
The characters are embedded in a embedding table with a given number of rows,
|
||||
and the vectors concatenated. A hash-embedded vector of the NORM of the word is
|
||||
and the vectors concatenated. A hash-embedded vector of the LOWER of the word is
|
||||
also concatenated on, and the result is then passed through a feed-forward
|
||||
network to construct a single vector to represent the information.
|
||||
|
||||
width (int): The width of the output vector and the NORM hash embedding.
|
||||
rows (int): The number of rows in the NORM hash embedding table.
|
||||
feature (int or str): An attribute to embed, to concatenate with the characters.
|
||||
width (int): The width of the output vector and the feature embedding.
|
||||
rows (int): The number of rows in the LOWER hash embedding table.
|
||||
nM (int): The dimensionality of the character embeddings. Recommended values
|
||||
are between 16 and 64.
|
||||
nC (int): The number of UTF-8 bytes to embed per word. Recommended values
|
||||
|
@ -193,12 +195,15 @@ def CharacterEmbed(
|
|||
also_use_static_vectors (bool): Whether to also use static word vectors.
|
||||
Requires a vectors table to be loaded in the Doc objects' vocab.
|
||||
"""
|
||||
feature = intify_attr(feature)
|
||||
if feature is None:
|
||||
raise ValueError("Invalid feature: Must be a token attribute.")
|
||||
if also_use_static_vectors:
|
||||
model = chain(
|
||||
concatenate(
|
||||
chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()),
|
||||
chain(
|
||||
FeatureExtractor([NORM]),
|
||||
FeatureExtractor([feature]),
|
||||
list2ragged(),
|
||||
with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
|
||||
),
|
||||
|
@ -214,7 +219,7 @@ def CharacterEmbed(
|
|||
concatenate(
|
||||
chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()),
|
||||
chain(
|
||||
FeatureExtractor([NORM]),
|
||||
FeatureExtractor([feature]),
|
||||
list2ragged(),
|
||||
with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
|
||||
),
|
||||
|
|
Loading…
Reference in New Issue
Block a user