This commit is contained in:
Matthew Honnibal 2020-07-28 22:02:34 +02:00
parent 2aff3c4b5a
commit 0c17ea4c85
3 changed files with 22 additions and 33 deletions

View File

@ -23,7 +23,7 @@ def tok2vec_listener_v1(width, upstream="*"):
@registry.architectures.register("spacy.Tok2Vec.v1") @registry.architectures.register("spacy.Tok2Vec.v1")
def Tok2Vec( def Tok2Vec(
embed: Model[List[Doc], List[Floats2d]], embed: Model[List[Doc], List[Floats2d]],
encode: Model[List[Floats2d], List[Floats2d]] encode: Model[List[Floats2d], List[Floats2d]],
) -> Model[List[Doc], List[Floats2d]]: ) -> Model[List[Doc], List[Floats2d]]:
receptive_field = encode.attrs.get("receptive_field", 0) receptive_field = encode.attrs.get("receptive_field", 0)
@ -36,14 +36,12 @@ def Tok2Vec(
@registry.architectures.register("spacy.MultiHashEmbed.v1") @registry.architectures.register("spacy.MultiHashEmbed.v1")
def MultiHashEmbed( def MultiHashEmbed(
width: int, width: int, rows: int, also_embed_subwords: bool, also_use_static_vectors: bool
rows: int,
also_embed_subwords: bool,
also_use_static_vectors: bool
): ):
cols = [NORM, PREFIX, SUFFIX, SHAPE, ORTH] cols = [NORM, PREFIX, SUFFIX, SHAPE, ORTH]
seed = 7 seed = 7
def make_hash_embed(feature): def make_hash_embed(feature):
nonlocal seed nonlocal seed
seed += 1 seed += 1
@ -52,7 +50,7 @@ def MultiHashEmbed(
rows if feature == NORM else rows // 2, rows if feature == NORM else rows // 2,
column=cols.index(feature), column=cols.index(feature),
seed=seed, seed=seed,
dropout=0.0 dropout=0.0,
) )
if also_embed_subwords: if also_embed_subwords:
@ -60,7 +58,7 @@ def MultiHashEmbed(
make_hash_embed(NORM), make_hash_embed(NORM),
make_hash_embed(PREFIX), make_hash_embed(PREFIX),
make_hash_embed(SUFFIX), make_hash_embed(SUFFIX),
make_hash_embed(SHAPE) make_hash_embed(SHAPE),
] ]
else: else:
embeddings = [make_hash_embed(NORM)] embeddings = [make_hash_embed(NORM)]
@ -71,22 +69,22 @@ def MultiHashEmbed(
chain( chain(
FeatureExtractor(cols), FeatureExtractor(cols),
list2ragged(), list2ragged(),
with_array(concatenate(*embeddings)) with_array(concatenate(*embeddings)),
), ),
StaticVectors(width, dropout=0.0) StaticVectors(width, dropout=0.0),
), ),
with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)), with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)),
ragged2list() ragged2list(),
) )
else: else:
model = chain( model = chain(
chain( chain(
FeatureExtractor(cols), FeatureExtractor(cols),
list2ragged(), list2ragged(),
with_array(concatenate(*embeddings)) with_array(concatenate(*embeddings)),
), ),
with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)), with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)),
ragged2list() ragged2list(),
) )
return model return model
@ -137,6 +135,4 @@ def MishWindowEncoder(width, window_size, depth):
def BiLSTMEncoder(width, depth, dropout): def BiLSTMEncoder(width, depth, dropout):
if depth == 0: if depth == 0:
return noop() return noop()
return with_padded( return with_padded(PyTorchLSTM(width, width, bi=True, depth=depth, dropout=dropout))
PyTorchLSTM(width, width, bi=True, depth=depth, dropout=dropout)
)

View File

@ -15,7 +15,7 @@ def StaticVectors(
*, *,
dropout: Optional[float] = None, dropout: Optional[float] = None,
init_W: Callable = glorot_uniform_init, init_W: Callable = glorot_uniform_init,
key_attr: str="ORTH" key_attr: str = "ORTH"
) -> Model[List[Doc], Ragged]: ) -> Model[List[Doc], Ragged]:
"""Embed Doc objects with their vocab's vectors table, applying a learned """Embed Doc objects with their vocab's vectors table, applying a learned
linear projection to control the dimensionality. If a dropout rate is linear projection to control the dimensionality. If a dropout rate is
@ -45,7 +45,7 @@ def forward(
) )
output = Ragged( output = Ragged(
model.ops.gemm(model.ops.as_contig(V[rows]), W, trans2=True), model.ops.gemm(model.ops.as_contig(V[rows]), W, trans2=True),
model.ops.asarray([len(doc) for doc in docs], dtype="i") model.ops.asarray([len(doc) for doc in docs], dtype="i"),
) )
if mask is not None: if mask is not None:
output.data *= mask output.data *= mask
@ -55,11 +55,7 @@ def forward(
d_output.data *= mask d_output.data *= mask
model.inc_grad( model.inc_grad(
"W", "W",
model.ops.gemm( model.ops.gemm(d_output.data, model.ops.as_contig(V[rows]), trans1=True),
d_output.data,
model.ops.as_contig(V[rows]),
trans1=True
)
) )
return [] return []

View File

@ -190,10 +190,7 @@ def get_module_path(module: ModuleType) -> Path:
def load_vectors_into_model( def load_vectors_into_model(
nlp: "Language", nlp: "Language", name: Union[str, Path], *, add_strings=True
name: Union[str, Path],
*,
add_strings=True
) -> None: ) -> None:
"""Load word vectors from an installed model or path into a model instance.""" """Load word vectors from an installed model or path into a model instance."""
vectors_nlp = load_model(name) vectors_nlp = load_model(name)
@ -1205,12 +1202,12 @@ class DummyTokenizer:
def link_vectors_to_models( def link_vectors_to_models(
vocab: "Vocab", vocab: "Vocab",
models: List[Model]=[], models: List[Model] = [],
*, *,
vectors_name_attr="vectors_name", vectors_name_attr="vectors_name",
vectors_attr="vectors", vectors_attr="vectors",
key2row_attr="key2row", key2row_attr="key2row",
default_vectors_name="spacy_pretrained_vectors" default_vectors_name="spacy_pretrained_vectors",
) -> None: ) -> None:
"""Supply vectors data to models.""" """Supply vectors data to models."""
vectors = vocab.vectors vectors = vocab.vectors