mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-29 01:13:17 +03:00
Format
This commit is contained in:
parent
2aff3c4b5a
commit
0c17ea4c85
|
@ -23,7 +23,7 @@ def tok2vec_listener_v1(width, upstream="*"):
|
||||||
@registry.architectures.register("spacy.Tok2Vec.v1")
|
@registry.architectures.register("spacy.Tok2Vec.v1")
|
||||||
def Tok2Vec(
|
def Tok2Vec(
|
||||||
embed: Model[List[Doc], List[Floats2d]],
|
embed: Model[List[Doc], List[Floats2d]],
|
||||||
encode: Model[List[Floats2d], List[Floats2d]]
|
encode: Model[List[Floats2d], List[Floats2d]],
|
||||||
) -> Model[List[Doc], List[Floats2d]]:
|
) -> Model[List[Doc], List[Floats2d]]:
|
||||||
|
|
||||||
receptive_field = encode.attrs.get("receptive_field", 0)
|
receptive_field = encode.attrs.get("receptive_field", 0)
|
||||||
|
@ -36,14 +36,12 @@ def Tok2Vec(
|
||||||
|
|
||||||
@registry.architectures.register("spacy.MultiHashEmbed.v1")
|
@registry.architectures.register("spacy.MultiHashEmbed.v1")
|
||||||
def MultiHashEmbed(
|
def MultiHashEmbed(
|
||||||
width: int,
|
width: int, rows: int, also_embed_subwords: bool, also_use_static_vectors: bool
|
||||||
rows: int,
|
|
||||||
also_embed_subwords: bool,
|
|
||||||
also_use_static_vectors: bool
|
|
||||||
):
|
):
|
||||||
cols = [NORM, PREFIX, SUFFIX, SHAPE, ORTH]
|
cols = [NORM, PREFIX, SUFFIX, SHAPE, ORTH]
|
||||||
|
|
||||||
seed = 7
|
seed = 7
|
||||||
|
|
||||||
def make_hash_embed(feature):
|
def make_hash_embed(feature):
|
||||||
nonlocal seed
|
nonlocal seed
|
||||||
seed += 1
|
seed += 1
|
||||||
|
@ -52,7 +50,7 @@ def MultiHashEmbed(
|
||||||
rows if feature == NORM else rows // 2,
|
rows if feature == NORM else rows // 2,
|
||||||
column=cols.index(feature),
|
column=cols.index(feature),
|
||||||
seed=seed,
|
seed=seed,
|
||||||
dropout=0.0
|
dropout=0.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
if also_embed_subwords:
|
if also_embed_subwords:
|
||||||
|
@ -60,7 +58,7 @@ def MultiHashEmbed(
|
||||||
make_hash_embed(NORM),
|
make_hash_embed(NORM),
|
||||||
make_hash_embed(PREFIX),
|
make_hash_embed(PREFIX),
|
||||||
make_hash_embed(SUFFIX),
|
make_hash_embed(SUFFIX),
|
||||||
make_hash_embed(SHAPE)
|
make_hash_embed(SHAPE),
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
embeddings = [make_hash_embed(NORM)]
|
embeddings = [make_hash_embed(NORM)]
|
||||||
|
@ -71,22 +69,22 @@ def MultiHashEmbed(
|
||||||
chain(
|
chain(
|
||||||
FeatureExtractor(cols),
|
FeatureExtractor(cols),
|
||||||
list2ragged(),
|
list2ragged(),
|
||||||
with_array(concatenate(*embeddings))
|
with_array(concatenate(*embeddings)),
|
||||||
),
|
),
|
||||||
StaticVectors(width, dropout=0.0)
|
StaticVectors(width, dropout=0.0),
|
||||||
),
|
),
|
||||||
with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)),
|
with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)),
|
||||||
ragged2list()
|
ragged2list(),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model = chain(
|
model = chain(
|
||||||
chain(
|
chain(
|
||||||
FeatureExtractor(cols),
|
FeatureExtractor(cols),
|
||||||
list2ragged(),
|
list2ragged(),
|
||||||
with_array(concatenate(*embeddings))
|
with_array(concatenate(*embeddings)),
|
||||||
),
|
),
|
||||||
with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)),
|
with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)),
|
||||||
ragged2list()
|
ragged2list(),
|
||||||
)
|
)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
@ -137,6 +135,4 @@ def MishWindowEncoder(width, window_size, depth):
|
||||||
def BiLSTMEncoder(width, depth, dropout):
|
def BiLSTMEncoder(width, depth, dropout):
|
||||||
if depth == 0:
|
if depth == 0:
|
||||||
return noop()
|
return noop()
|
||||||
return with_padded(
|
return with_padded(PyTorchLSTM(width, width, bi=True, depth=depth, dropout=dropout))
|
||||||
PyTorchLSTM(width, width, bi=True, depth=depth, dropout=dropout)
|
|
||||||
)
|
|
||||||
|
|
|
@ -45,7 +45,7 @@ def forward(
|
||||||
)
|
)
|
||||||
output = Ragged(
|
output = Ragged(
|
||||||
model.ops.gemm(model.ops.as_contig(V[rows]), W, trans2=True),
|
model.ops.gemm(model.ops.as_contig(V[rows]), W, trans2=True),
|
||||||
model.ops.asarray([len(doc) for doc in docs], dtype="i")
|
model.ops.asarray([len(doc) for doc in docs], dtype="i"),
|
||||||
)
|
)
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
output.data *= mask
|
output.data *= mask
|
||||||
|
@ -55,11 +55,7 @@ def forward(
|
||||||
d_output.data *= mask
|
d_output.data *= mask
|
||||||
model.inc_grad(
|
model.inc_grad(
|
||||||
"W",
|
"W",
|
||||||
model.ops.gemm(
|
model.ops.gemm(d_output.data, model.ops.as_contig(V[rows]), trans1=True),
|
||||||
d_output.data,
|
|
||||||
model.ops.as_contig(V[rows]),
|
|
||||||
trans1=True
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
|
@ -190,10 +190,7 @@ def get_module_path(module: ModuleType) -> Path:
|
||||||
|
|
||||||
|
|
||||||
def load_vectors_into_model(
|
def load_vectors_into_model(
|
||||||
nlp: "Language",
|
nlp: "Language", name: Union[str, Path], *, add_strings=True
|
||||||
name: Union[str, Path],
|
|
||||||
*,
|
|
||||||
add_strings=True
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Load word vectors from an installed model or path into a model instance."""
|
"""Load word vectors from an installed model or path into a model instance."""
|
||||||
vectors_nlp = load_model(name)
|
vectors_nlp = load_model(name)
|
||||||
|
@ -1210,7 +1207,7 @@ def link_vectors_to_models(
|
||||||
vectors_name_attr="vectors_name",
|
vectors_name_attr="vectors_name",
|
||||||
vectors_attr="vectors",
|
vectors_attr="vectors",
|
||||||
key2row_attr="key2row",
|
key2row_attr="key2row",
|
||||||
default_vectors_name="spacy_pretrained_vectors"
|
default_vectors_name="spacy_pretrained_vectors",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Supply vectors data to models."""
|
"""Supply vectors data to models."""
|
||||||
vectors = vocab.vectors
|
vectors = vocab.vectors
|
||||||
|
|
Loading…
Reference in New Issue
Block a user