From 5dbd02d560bd869addf551e69fd0e25b58531e7e Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 21 May 2025 22:31:25 +0200 Subject: [PATCH] Fix type --- spacy/ml/featureextractor.py | 4 +++- spacy/ml/models/tok2vec.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/spacy/ml/featureextractor.py b/spacy/ml/featureextractor.py index 06f1ff51a..a87b4156c 100644 --- a/spacy/ml/featureextractor.py +++ b/spacy/ml/featureextractor.py @@ -7,7 +7,9 @@ from ..tokens import Doc @registry.layers("spacy.FeatureExtractor.v1") -def FeatureExtractor(columns: List[Union[int, str]]) -> Model[List[Doc], List[Ints2d]]: +def FeatureExtractor( + columns: Union[List[str], List[int], List[Union[int, str]]] +) -> Model[List[Doc], List[Ints2d]]: return Model("extract_features", forward, attrs={"columns": columns}) diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py index 4dd23679e..b2b803b6e 100644 --- a/spacy/ml/models/tok2vec.py +++ b/spacy/ml/models/tok2vec.py @@ -122,7 +122,7 @@ def build_Tok2Vec_model( def MultiHashEmbed( width: int, - attrs: List[Union[str, int]], + attrs: Union[List[str], List[int], List[Union[str, int]]], rows: List[int], include_static_vectors: bool, ) -> Model[List[Doc], List[Floats2d]]: @@ -188,7 +188,7 @@ def MultiHashEmbed( ) else: model = chain( - FeatureExtractor(list(attrs)), + FeatureExtractor(attrs), cast(Model[List[Ints2d], Ragged], list2ragged()), with_array(concatenate(*embeddings)), max_out,