spaCy (mirror of https://github.com/explosion/spaCy.git)

WIP on new StaticVectors

commit cb9654e98c
parent e257e66ab9
spacy/ml/models/tok2vec.py
@@ -1,7 +1,9 @@
+from typing import Optional, List
 from thinc.api import chain, clone, concatenate, with_array, uniqued
 from thinc.api import Model, noop, with_padded, Maxout, expand_window
 from thinc.api import HashEmbed, StaticVectors, PyTorchLSTM
 from thinc.api import residual, LayerNorm, FeatureExtractor, Mish
+from thinc.types import Floats2d
 
 from ... import util
 from ...util import registry
@@ -42,15 +44,15 @@ def Doc2Feats(columns):
 
 @registry.architectures.register("spacy.HashEmbedCNN.v1")
 def hash_embed_cnn(
-    pretrained_vectors,
-    width,
-    depth,
-    embed_size,
-    maxout_pieces,
-    window_size,
-    subword_features,
-    dropout,
-):
+    pretrained_vectors: str,
+    width: int,
+    depth: int,
+    embed_size: int,
+    maxout_pieces: int,
+    window_size: int,
+    subword_features: bool,
+    dropout: float,
+) -> Model[List[Doc], List[Floats2d]]:
     # Does not use character embeddings: set to False by default
     return build_Tok2Vec_model(
         width=width,
@@ -182,7 +184,7 @@ def MultiHashEmbed(
 
     if pretrained_vectors:
         glove = StaticVectors(
-            vectors=pretrained_vectors.data,
+            vectors_name=pretrained_vectors,
            nO=width,
             column=columns.index(ID),
             dropout=dropout,
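Note: StaticVectors is now constructed from the name of the vectors table rather than the vectors data itself; the data is attached later by link_vectors_to_models (see spacy/util.py below). A minimal sketch of that contract, treating the keyword arguments as given in this diff; "my_vectors" and the other values are illustrative, not from this commit:

    from thinc.api import StaticVectors

    # The node only carries the table's name at construction time.
    glove = StaticVectors(vectors_name="my_vectors", nO=96, column=0, dropout=0.2)
    # The vectors data itself is filled in later by link_vectors_to_models,
    # which writes into glove.attrs (see spacy/util.py below).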
@@ -261,18 +263,18 @@ def TorchBiLSTMEncoder(width, depth):
 
 
 def build_Tok2Vec_model(
-    width,
-    embed_size,
-    pretrained_vectors,
-    window_size,
-    maxout_pieces,
-    subword_features,
-    char_embed,
-    nM,
-    nC,
-    conv_depth,
-    bilstm_depth,
-    dropout,
+    width: int,
+    embed_size: int,
+    pretrained_vectors: Optional[str],
+    window_size: int,
+    maxout_pieces: int,
+    subword_features: bool,
+    char_embed: bool,
+    nM: int,
+    nC: int,
+    conv_depth: int,
+    bilstm_depth: int,
+    dropout: float,
 ) -> Model:
     if char_embed:
         subword_features = False
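For context, a registered architecture such as spacy.HashEmbedCNN.v1 above is resolved by name from the registry and called with these now-typed keyword arguments. A usage sketch; the parameter values are illustrative and not taken from this commit:

    from spacy.util import registry

    # Look up the factory registered above and build a tok2vec model from it.
    make_tok2vec = registry.architectures.get("spacy.HashEmbedCNN.v1")
    tok2vec = make_tok2vec(
        pretrained_vectors="spacy_pretrained_vectors",  # illustrative name
        width=96,
        depth=4,
        embed_size=2000,
        maxout_pieces=3,
        window_size=1,
        subword_features=True,
        dropout=0.2,
    )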

spacy/util.py
@@ -24,6 +24,8 @@ import tempfile
 import shutil
 import shlex
 import inspect
+from thinc.types import Unserializable
+
 
 try:
     import cupy.random
@@ -1184,20 +1186,27 @@ class DummyTokenizer:
         return self
 
 
-def link_vectors_to_models(vocab: "Vocab") -> None:
+def link_vectors_to_models(
+    vocab: "Vocab",
+    models: List[Model] = [],
+    *,
+    vectors_name_attr="vectors_name",
+    vectors_attr="vectors",
+    key2row_attr="key2row",
+    default_vectors_name="spacy_pretrained_vectors",
+) -> None:
+    """Supply vectors data to models."""
     vectors = vocab.vectors
     if vectors.name is None:
-        vectors.name = VECTORS_KEY
+        vectors.name = default_vectors_name
         if vectors.data.size != 0:
             warnings.warn(Warnings.W020.format(shape=vectors.data.shape))
-    for word in vocab:
-        if word.orth in vectors.key2row:
-            word.rank = vectors.key2row[word.orth]
-        else:
-            word.rank = 0
-
-
-VECTORS_KEY = "spacy_pretrained_vectors"
+    for model in models:
+        for node in model.walk():
+            if node.attrs.get(vectors_name_attr) == vectors.name:
+                node.attrs[vectors_attr] = Unserializable(vectors.data)
+                node.attrs[key2row_attr] = Unserializable(vectors.key2row)
 
 
 def create_default_optimizer() -> Optimizer:
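The reworked link_vectors_to_models no longer writes word.rank into the vocab. Instead it walks each model graph and attaches the vectors table, wrapped in Unserializable so it is skipped when the model is serialized, to every node whose vectors_name attr matches. A minimal sketch of the intended call; nlp stands in for a loaded pipeline and is assumed here for illustration:

    # Gather the thinc models from a pipeline and link the vocab's vectors.
    models = [proc.model for _, proc in nlp.pipeline if hasattr(proc, "model")]
    link_vectors_to_models(nlp.vocab, models)
    # Every node created as StaticVectors(vectors_name=...) whose name matches
    # nlp.vocab.vectors.name now has attrs["vectors"] and attrs["key2row"] set.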