mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-13 01:32:32 +03:00
Fix type
This commit is contained in:
parent
b8db76ccbe
commit
5dbd02d560
|
@ -7,7 +7,9 @@ from ..tokens import Doc
|
|||
|
||||
|
||||
@registry.layers("spacy.FeatureExtractor.v1")
|
||||
def FeatureExtractor(columns: List[Union[int, str]]) -> Model[List[Doc], List[Ints2d]]:
|
||||
def FeatureExtractor(
|
||||
columns: Union[List[str], List[int], List[Union[int, str]]]
|
||||
) -> Model[List[Doc], List[Ints2d]]:
|
||||
return Model("extract_features", forward, attrs={"columns": columns})
|
||||
|
||||
|
||||
|
|
|
@ -122,7 +122,7 @@ def build_Tok2Vec_model(
|
|||
|
||||
def MultiHashEmbed(
|
||||
width: int,
|
||||
attrs: List[Union[str, int]],
|
||||
attrs: Union[List[str], List[int], List[Union[str, int]]],
|
||||
rows: List[int],
|
||||
include_static_vectors: bool,
|
||||
) -> Model[List[Doc], List[Floats2d]]:
|
||||
|
@ -188,7 +188,7 @@ def MultiHashEmbed(
|
|||
)
|
||||
else:
|
||||
model = chain(
|
||||
FeatureExtractor(list(attrs)),
|
||||
FeatureExtractor(attrs),
|
||||
cast(Model[List[Ints2d], Ragged], list2ragged()),
|
||||
with_array(concatenate(*embeddings)),
|
||||
max_out,
|
||||
|
|
Loading…
Reference in New Issue
Block a user