This commit is contained in:
Matthew Honnibal 2025-05-21 22:31:25 +02:00
parent b8db76ccbe
commit 5dbd02d560
2 changed files with 5 additions and 3 deletions

View File

@ -7,7 +7,9 @@ from ..tokens import Doc
@registry.layers("spacy.FeatureExtractor.v1") @registry.layers("spacy.FeatureExtractor.v1")
def FeatureExtractor(columns: List[Union[int, str]]) -> Model[List[Doc], List[Ints2d]]: def FeatureExtractor(
columns: Union[List[str], List[int], List[Union[int, str]]]
) -> Model[List[Doc], List[Ints2d]]:
return Model("extract_features", forward, attrs={"columns": columns}) return Model("extract_features", forward, attrs={"columns": columns})

View File

@ -122,7 +122,7 @@ def build_Tok2Vec_model(
def MultiHashEmbed( def MultiHashEmbed(
width: int, width: int,
attrs: List[Union[str, int]], attrs: Union[List[str], List[int], List[Union[str, int]]],
rows: List[int], rows: List[int],
include_static_vectors: bool, include_static_vectors: bool,
) -> Model[List[Doc], List[Floats2d]]: ) -> Model[List[Doc], List[Floats2d]]:
@ -188,7 +188,7 @@ def MultiHashEmbed(
) )
else: else:
model = chain( model = chain(
FeatureExtractor(list(attrs)), FeatureExtractor(attrs),
cast(Model[List[Ints2d], Ragged], list2ragged()), cast(Model[List[Ints2d], Ragged], list2ragged()),
with_array(concatenate(*embeddings)), with_array(concatenate(*embeddings)),
max_out, max_out,