Matthew Honnibal 2020-07-28 22:02:34 +02:00
parent 2aff3c4b5a
commit 0c17ea4c85
3 changed files with 22 additions and 33 deletions

View File

@@ -23,7 +23,7 @@ def tok2vec_listener_v1(width, upstream="*"):
 @registry.architectures.register("spacy.Tok2Vec.v1")
 def Tok2Vec(
     embed: Model[List[Doc], List[Floats2d]],
-    encode: Model[List[Floats2d], List[Floats2d]]
+    encode: Model[List[Floats2d], List[Floats2d]],
 ) -> Model[List[Doc], List[Floats2d]]:
     receptive_field = encode.attrs.get("receptive_field", 0)
@@ -36,14 +36,12 @@ def Tok2Vec(
 @registry.architectures.register("spacy.MultiHashEmbed.v1")
 def MultiHashEmbed(
-    width: int,
-    rows: int,
-    also_embed_subwords: bool,
-    also_use_static_vectors: bool
+    width: int, rows: int, also_embed_subwords: bool, also_use_static_vectors: bool
 ):
     cols = [NORM, PREFIX, SUFFIX, SHAPE, ORTH]
     seed = 7
     def make_hash_embed(feature):
         nonlocal seed
         seed += 1
@@ -52,15 +50,15 @@ def MultiHashEmbed(
             rows if feature == NORM else rows // 2,
             column=cols.index(feature),
             seed=seed,
-            dropout=0.0
+            dropout=0.0,
         )

     if also_embed_subwords:
         embeddings = [
             make_hash_embed(NORM),
             make_hash_embed(PREFIX),
             make_hash_embed(SUFFIX),
-            make_hash_embed(SHAPE)
+            make_hash_embed(SHAPE),
         ]
     else:
         embeddings = [make_hash_embed(NORM)]
@@ -71,25 +69,25 @@ def MultiHashEmbed(
                 chain(
                     FeatureExtractor(cols),
                     list2ragged(),
-                    with_array(concatenate(*embeddings))
+                    with_array(concatenate(*embeddings)),
                 ),
-                StaticVectors(width, dropout=0.0)
+                StaticVectors(width, dropout=0.0),
             ),
             with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)),
-            ragged2list()
+            ragged2list(),
         )
     else:
         model = chain(
             chain(
                 FeatureExtractor(cols),
                 list2ragged(),
-                with_array(concatenate(*embeddings))
+                with_array(concatenate(*embeddings)),
             ),
             with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)),
-            ragged2list()
+            ragged2list(),
         )
     return model


 @registry.architectures.register("spacy.CharacterEmbed.v1")
 def CharacterEmbed(columns, width, rows, nM, nC, features, dropout):
@@ -137,6 +135,4 @@ def MishWindowEncoder(width, window_size, depth):
 def BiLSTMEncoder(width, depth, dropout):
     if depth == 0:
         return noop()
-    return with_padded(
-        PyTorchLSTM(width, width, bi=True, depth=depth, dropout=dropout)
-    )
+    return with_padded(PyTorchLSTM(width, width, bi=True, depth=depth, dropout=dropout))
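
The hunks in this file only touch formatting, but they show the pieces of the default token-to-vector pipeline: an embedding layer, an encoder, and Tok2Vec to chain them. A minimal sketch (not part of this commit) of composing them directly follows; the module path spacy.ml.models.tok2vec, the MaxoutWindowEncoder layer assumed to live in the same module, and the hyperparameter values are assumptions, not taken from the diff.

# Sketch only: wire the architectures changed above into a tok2vec model.
# Assumes spaCy v3-style layers; the import path, MaxoutWindowEncoder, and
# the sizes below are placeholder assumptions.
from spacy.ml.models.tok2vec import Tok2Vec, MultiHashEmbed, MaxoutWindowEncoder

width = 96
embed = MultiHashEmbed(
    width=width, rows=2000, also_embed_subwords=True, also_use_static_vectors=False
)
encode = MaxoutWindowEncoder(width=width, window_size=1, maxout_pieces=3, depth=4)
tok2vec = Tok2Vec(embed, encode)  # Model[List[Doc], List[Floats2d]]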

View File

@@ -15,7 +15,7 @@ def StaticVectors(
     *,
     dropout: Optional[float] = None,
     init_W: Callable = glorot_uniform_init,
-    key_attr: str="ORTH"
+    key_attr: str = "ORTH"
 ) -> Model[List[Doc], Ragged]:
     """Embed Doc objects with their vocab's vectors table, applying a learned
     linear projection to control the dimensionality. If a dropout rate is
@@ -45,21 +45,17 @@ def forward(
     )
     output = Ragged(
         model.ops.gemm(model.ops.as_contig(V[rows]), W, trans2=True),
-        model.ops.asarray([len(doc) for doc in docs], dtype="i")
+        model.ops.asarray([len(doc) for doc in docs], dtype="i"),
     )
     if mask is not None:
         output.data *= mask

     def backprop(d_output: Ragged) -> List[Doc]:
         if mask is not None:
             d_output.data *= mask
         model.inc_grad(
             "W",
-            model.ops.gemm(
-                d_output.data,
-                model.ops.as_contig(V[rows]),
-                trans1=True
-            )
+            model.ops.gemm(d_output.data, model.ops.as_contig(V[rows]), trans1=True),
         )
         return []
@@ -78,7 +74,7 @@ def init(
         nM = X[0].vocab.vectors.data.shape[1]
     if Y is not None:
         nO = Y.data.shape[1]
     if nM is None:
         raise ValueError(
             "Cannot initialize StaticVectors layer: nM dimension unset. "

View File

@@ -190,10 +190,7 @@ def get_module_path(module: ModuleType) -> Path:
 def load_vectors_into_model(
-    nlp: "Language",
-    name: Union[str, Path],
-    *,
-    add_strings=True
+    nlp: "Language", name: Union[str, Path], *, add_strings=True
 ) -> None:
     """Load word vectors from an installed model or path into a model instance."""
     vectors_nlp = load_model(name)
@@ -1205,12 +1202,12 @@ class DummyTokenizer:
 def link_vectors_to_models(
     vocab: "Vocab",
-    models: List[Model]=[],
+    models: List[Model] = [],
     *,
     vectors_name_attr="vectors_name",
     vectors_attr="vectors",
     key2row_attr="key2row",
-    default_vectors_name="spacy_pretrained_vectors"
+    default_vectors_name="spacy_pretrained_vectors",
 ) -> None:
     """Supply vectors data to models."""
     vectors = vocab.vectors
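
The docstrings in this file describe helpers for moving vectors data into a pipeline. A minimal sketch (not part of this commit) of calling the reformatted load_vectors_into_model on a blank pipeline follows; the package name en_vectors_web_lg is a placeholder assumption.

# Sketch only: load vectors from an installed package or path into a fresh
# pipeline using the helper above. "en_vectors_web_lg" is a placeholder; any
# installed package or path with a vectors table would do.
import spacy
from spacy.util import load_vectors_into_model

nlp = spacy.blank("en")
load_vectors_into_model(nlp, "en_vectors_web_lg", add_strings=True)
print(nlp.vocab.vectors.shape)   # (number of vectors, vector width)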