mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 09:26:27 +03:00
Allow CharacterEmbed to specify feature
This commit is contained in:
parent
0a8a124a6e
commit
684a77870b
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional, List
|
||||
from typing import Optional, List, Union
|
||||
from thinc.api import chain, clone, concatenate, with_array, with_padded
|
||||
from thinc.api import Model, noop, list2ragged, ragged2list
|
||||
from thinc.api import FeatureExtractor, HashEmbed
|
||||
|
@ -165,7 +165,8 @@ def MultiHashEmbed(
|
|||
|
||||
@registry.architectures.register("spacy.CharacterEmbed.v1")
|
||||
def CharacterEmbed(
|
||||
width: int, rows: int, nM: int, nC: int, also_use_static_vectors: bool
|
||||
width: int, rows: int, nM: int, nC: int, also_use_static_vectors: bool,
|
||||
feature: Union[int, str]="NORM"
|
||||
):
|
||||
"""Construct an embedded representation based on character embeddings, using
|
||||
a feed-forward network. A fixed number of UTF-8 byte characters are used for
|
||||
|
@ -183,7 +184,8 @@ def CharacterEmbed(
|
|||
also concatenated on, and the result is then passed through a feed-forward
|
||||
network to construct a single vector to represent the information.
|
||||
|
||||
width (int): The width of the output vector and the NORM hash embedding.
|
||||
feature (int or str): An attribute to embed, to concatenate with the characters.
|
||||
width (int): The width of the output vector and the feature embedding.
|
||||
rows (int): The number of rows in the NORM hash embedding table.
|
||||
nM (int): The dimensionality of the character embeddings. Recommended values
|
||||
are between 16 and 64.
|
||||
|
@ -193,12 +195,15 @@ def CharacterEmbed(
|
|||
also_use_static_vectors (bool): Whether to also use static word vectors.
|
||||
Requires a vectors table to be loaded in the Doc objects' vocab.
|
||||
"""
|
||||
feature = intify_attr(feature)
|
||||
if feature is None:
|
||||
raise ValueError("Invalid feature: Must be a token attribute.")
|
||||
if also_use_static_vectors:
|
||||
model = chain(
|
||||
concatenate(
|
||||
chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()),
|
||||
chain(
|
||||
FeatureExtractor([NORM]),
|
||||
FeatureExtractor([feature]),
|
||||
list2ragged(),
|
||||
with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
|
||||
),
|
||||
|
@ -214,7 +219,7 @@ def CharacterEmbed(
|
|||
concatenate(
|
||||
chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()),
|
||||
chain(
|
||||
FeatureExtractor([NORM]),
|
||||
FeatureExtractor([feature]),
|
||||
list2ragged(),
|
||||
with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
|
||||
),
|
||||
|
|
Loading…
Reference in New Issue
Block a user