mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 09:26:27 +03:00
Format
This commit is contained in:
parent
2aff3c4b5a
commit
0c17ea4c85
|
@ -23,7 +23,7 @@ def tok2vec_listener_v1(width, upstream="*"):
|
|||
@registry.architectures.register("spacy.Tok2Vec.v1")
|
||||
def Tok2Vec(
|
||||
embed: Model[List[Doc], List[Floats2d]],
|
||||
encode: Model[List[Floats2d], List[Floats2d]]
|
||||
encode: Model[List[Floats2d], List[Floats2d]],
|
||||
) -> Model[List[Doc], List[Floats2d]]:
|
||||
|
||||
receptive_field = encode.attrs.get("receptive_field", 0)
|
||||
|
@ -36,14 +36,12 @@ def Tok2Vec(
|
|||
|
||||
@registry.architectures.register("spacy.MultiHashEmbed.v1")
|
||||
def MultiHashEmbed(
|
||||
width: int,
|
||||
rows: int,
|
||||
also_embed_subwords: bool,
|
||||
also_use_static_vectors: bool
|
||||
width: int, rows: int, also_embed_subwords: bool, also_use_static_vectors: bool
|
||||
):
|
||||
cols = [NORM, PREFIX, SUFFIX, SHAPE, ORTH]
|
||||
|
||||
seed = 7
|
||||
|
||||
def make_hash_embed(feature):
|
||||
nonlocal seed
|
||||
seed += 1
|
||||
|
@ -52,7 +50,7 @@ def MultiHashEmbed(
|
|||
rows if feature == NORM else rows // 2,
|
||||
column=cols.index(feature),
|
||||
seed=seed,
|
||||
dropout=0.0
|
||||
dropout=0.0,
|
||||
)
|
||||
|
||||
if also_embed_subwords:
|
||||
|
@ -60,7 +58,7 @@ def MultiHashEmbed(
|
|||
make_hash_embed(NORM),
|
||||
make_hash_embed(PREFIX),
|
||||
make_hash_embed(SUFFIX),
|
||||
make_hash_embed(SHAPE)
|
||||
make_hash_embed(SHAPE),
|
||||
]
|
||||
else:
|
||||
embeddings = [make_hash_embed(NORM)]
|
||||
|
@ -71,22 +69,22 @@ def MultiHashEmbed(
|
|||
chain(
|
||||
FeatureExtractor(cols),
|
||||
list2ragged(),
|
||||
with_array(concatenate(*embeddings))
|
||||
with_array(concatenate(*embeddings)),
|
||||
),
|
||||
StaticVectors(width, dropout=0.0)
|
||||
StaticVectors(width, dropout=0.0),
|
||||
),
|
||||
with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)),
|
||||
ragged2list()
|
||||
ragged2list(),
|
||||
)
|
||||
else:
|
||||
model = chain(
|
||||
chain(
|
||||
FeatureExtractor(cols),
|
||||
list2ragged(),
|
||||
with_array(concatenate(*embeddings))
|
||||
with_array(concatenate(*embeddings)),
|
||||
),
|
||||
with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)),
|
||||
ragged2list()
|
||||
ragged2list(),
|
||||
)
|
||||
return model
|
||||
|
||||
|
@ -137,6 +135,4 @@ def MishWindowEncoder(width, window_size, depth):
|
|||
def BiLSTMEncoder(width, depth, dropout):
|
||||
if depth == 0:
|
||||
return noop()
|
||||
return with_padded(
|
||||
PyTorchLSTM(width, width, bi=True, depth=depth, dropout=dropout)
|
||||
)
|
||||
return with_padded(PyTorchLSTM(width, width, bi=True, depth=depth, dropout=dropout))
|
||||
|
|
|
@ -45,7 +45,7 @@ def forward(
|
|||
)
|
||||
output = Ragged(
|
||||
model.ops.gemm(model.ops.as_contig(V[rows]), W, trans2=True),
|
||||
model.ops.asarray([len(doc) for doc in docs], dtype="i")
|
||||
model.ops.asarray([len(doc) for doc in docs], dtype="i"),
|
||||
)
|
||||
if mask is not None:
|
||||
output.data *= mask
|
||||
|
@ -55,11 +55,7 @@ def forward(
|
|||
d_output.data *= mask
|
||||
model.inc_grad(
|
||||
"W",
|
||||
model.ops.gemm(
|
||||
d_output.data,
|
||||
model.ops.as_contig(V[rows]),
|
||||
trans1=True
|
||||
)
|
||||
model.ops.gemm(d_output.data, model.ops.as_contig(V[rows]), trans1=True),
|
||||
)
|
||||
return []
|
||||
|
||||
|
|
|
@ -190,10 +190,7 @@ def get_module_path(module: ModuleType) -> Path:
|
|||
|
||||
|
||||
def load_vectors_into_model(
|
||||
nlp: "Language",
|
||||
name: Union[str, Path],
|
||||
*,
|
||||
add_strings=True
|
||||
nlp: "Language", name: Union[str, Path], *, add_strings=True
|
||||
) -> None:
|
||||
"""Load word vectors from an installed model or path into a model instance."""
|
||||
vectors_nlp = load_model(name)
|
||||
|
@ -1210,7 +1207,7 @@ def link_vectors_to_models(
|
|||
vectors_name_attr="vectors_name",
|
||||
vectors_attr="vectors",
|
||||
key2row_attr="key2row",
|
||||
default_vectors_name="spacy_pretrained_vectors"
|
||||
default_vectors_name="spacy_pretrained_vectors",
|
||||
) -> None:
|
||||
"""Supply vectors data to models."""
|
||||
vectors = vocab.vectors
|
||||
|
|
Loading…
Reference in New Issue
Block a user