mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-13 01:32:32 +03:00
Fix type
This commit is contained in:
parent
b8db76ccbe
commit
5dbd02d560
|
@ -7,7 +7,9 @@ from ..tokens import Doc
|
||||||
|
|
||||||
|
|
||||||
@registry.layers("spacy.FeatureExtractor.v1")
|
@registry.layers("spacy.FeatureExtractor.v1")
|
||||||
def FeatureExtractor(columns: List[Union[int, str]]) -> Model[List[Doc], List[Ints2d]]:
|
def FeatureExtractor(
|
||||||
|
columns: Union[List[str], List[int], List[Union[int, str]]]
|
||||||
|
) -> Model[List[Doc], List[Ints2d]]:
|
||||||
return Model("extract_features", forward, attrs={"columns": columns})
|
return Model("extract_features", forward, attrs={"columns": columns})
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -122,7 +122,7 @@ def build_Tok2Vec_model(
|
||||||
|
|
||||||
def MultiHashEmbed(
|
def MultiHashEmbed(
|
||||||
width: int,
|
width: int,
|
||||||
attrs: List[Union[str, int]],
|
attrs: Union[List[str], List[int], List[Union[str, int]]],
|
||||||
rows: List[int],
|
rows: List[int],
|
||||||
include_static_vectors: bool,
|
include_static_vectors: bool,
|
||||||
) -> Model[List[Doc], List[Floats2d]]:
|
) -> Model[List[Doc], List[Floats2d]]:
|
||||||
|
@ -188,7 +188,7 @@ def MultiHashEmbed(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model = chain(
|
model = chain(
|
||||||
FeatureExtractor(list(attrs)),
|
FeatureExtractor(attrs),
|
||||||
cast(Model[List[Ints2d], Ragged], list2ragged()),
|
cast(Model[List[Ints2d], Ragged], list2ragged()),
|
||||||
with_array(concatenate(*embeddings)),
|
with_array(concatenate(*embeddings)),
|
||||||
max_out,
|
max_out,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user