mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
WIP on new StaticVectors
This commit is contained in:
parent
e257e66ab9
commit
cb9654e98c
|
@ -1,7 +1,9 @@
|
|||
from typing import Optional, List
|
||||
from thinc.api import chain, clone, concatenate, with_array, uniqued
|
||||
from thinc.api import Model, noop, with_padded, Maxout, expand_window
|
||||
from thinc.api import HashEmbed, StaticVectors, PyTorchLSTM
|
||||
from thinc.api import residual, LayerNorm, FeatureExtractor, Mish
|
||||
from thinc.types import Floats2d
|
||||
|
||||
from ... import util
|
||||
from ...util import registry
|
||||
|
@ -42,15 +44,15 @@ def Doc2Feats(columns):
|
|||
|
||||
@registry.architectures.register("spacy.HashEmbedCNN.v1")
|
||||
def hash_embed_cnn(
|
||||
pretrained_vectors,
|
||||
width,
|
||||
depth,
|
||||
embed_size,
|
||||
maxout_pieces,
|
||||
window_size,
|
||||
subword_features,
|
||||
dropout,
|
||||
):
|
||||
pretrained_vectors: str,
|
||||
width: int,
|
||||
depth: int,
|
||||
embed_size: int,
|
||||
maxout_pieces: int,
|
||||
window_size: int,
|
||||
subword_features: bool,
|
||||
dropout: float,
|
||||
) -> Model[List[Doc], List[Floats2d]:
|
||||
# Does not use character embeddings: set to False by default
|
||||
return build_Tok2Vec_model(
|
||||
width=width,
|
||||
|
@ -182,7 +184,7 @@ def MultiHashEmbed(
|
|||
|
||||
if pretrained_vectors:
|
||||
glove = StaticVectors(
|
||||
vectors=pretrained_vectors.data,
|
||||
vectors_name=pretrained_vectors,
|
||||
nO=width,
|
||||
column=columns.index(ID),
|
||||
dropout=dropout,
|
||||
|
@ -261,18 +263,18 @@ def TorchBiLSTMEncoder(width, depth):
|
|||
|
||||
|
||||
def build_Tok2Vec_model(
|
||||
width,
|
||||
embed_size,
|
||||
pretrained_vectors,
|
||||
window_size,
|
||||
maxout_pieces,
|
||||
subword_features,
|
||||
char_embed,
|
||||
nM,
|
||||
nC,
|
||||
conv_depth,
|
||||
bilstm_depth,
|
||||
dropout,
|
||||
width: int,
|
||||
embed_size: int,
|
||||
pretrained_vectors: Optional[str],
|
||||
window_size: int,
|
||||
maxout_pieces: int,
|
||||
subword_features: bool,
|
||||
char_embed: bool,
|
||||
nM: int,
|
||||
nC: int,
|
||||
conv_depth: int,
|
||||
bilstm_depth: int,
|
||||
dropout: float,
|
||||
) -> Model:
|
||||
if char_embed:
|
||||
subword_features = False
|
||||
|
|
|
@ -24,6 +24,8 @@ import tempfile
|
|||
import shutil
|
||||
import shlex
|
||||
import inspect
|
||||
from thinc.types import Unserializable
|
||||
|
||||
|
||||
try:
|
||||
import cupy.random
|
||||
|
@ -1184,20 +1186,27 @@ class DummyTokenizer:
|
|||
return self
|
||||
|
||||
|
||||
def link_vectors_to_models(vocab: "Vocab") -> None:
|
||||
def link_vectors_to_models(
|
||||
vocab: "Vocab",
|
||||
models: List[Model]=[],
|
||||
*,
|
||||
vectors_name_attr="vectors_name",
|
||||
vectors_attr="vectors",
|
||||
key2row_attr="key2row",
|
||||
default_vectors_name="spacy_pretrained_vectors"
|
||||
) -> None:
|
||||
"""Supply vectors data to models."""
|
||||
vectors = vocab.vectors
|
||||
if vectors.name is None:
|
||||
vectors.name = VECTORS_KEY
|
||||
vectors.name = default_vectors_name
|
||||
if vectors.data.size != 0:
|
||||
warnings.warn(Warnings.W020.format(shape=vectors.data.shape))
|
||||
for word in vocab:
|
||||
if word.orth in vectors.key2row:
|
||||
word.rank = vectors.key2row[word.orth]
|
||||
else:
|
||||
word.rank = 0
|
||||
|
||||
|
||||
VECTORS_KEY = "spacy_pretrained_vectors"
|
||||
for model in models:
|
||||
for node in model.walk():
|
||||
if node.attrs.get(vectors_name_attr) == vectors.name:
|
||||
node.attrs[vectors_attr] = Unserializable(vectors.data)
|
||||
node.attrs[key2row_attr] = Unserializable(vectors.key2row)
|
||||
|
||||
|
||||
def create_default_optimizer() -> Optimizer:
|
||||
|
|
Loading…
Reference in New Issue
Block a user