diff --git a/spacy/ml/featureextractor.py b/spacy/ml/featureextractor.py index 06f1ff51a..a87b4156c 100644 --- a/spacy/ml/featureextractor.py +++ b/spacy/ml/featureextractor.py @@ -7,7 +7,9 @@ from ..tokens import Doc @registry.layers("spacy.FeatureExtractor.v1") -def FeatureExtractor(columns: List[Union[int, str]]) -> Model[List[Doc], List[Ints2d]]: +def FeatureExtractor( + columns: Union[List[str], List[int], List[Union[int, str]]] +) -> Model[List[Doc], List[Ints2d]]: return Model("extract_features", forward, attrs={"columns": columns}) diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py index 4dd23679e..b2b803b6e 100644 --- a/spacy/ml/models/tok2vec.py +++ b/spacy/ml/models/tok2vec.py @@ -122,7 +122,7 @@ def build_Tok2Vec_model( def MultiHashEmbed( width: int, - attrs: List[Union[str, int]], + attrs: Union[List[str], List[int], List[Union[str, int]]], rows: List[int], include_static_vectors: bool, ) -> Model[List[Doc], List[Floats2d]]: @@ -188,7 +188,7 @@ def MultiHashEmbed( ) else: model = chain( - FeatureExtractor(list(attrs)), + FeatureExtractor(attrs), cast(Model[List[Ints2d], Ragged], list2ragged()), with_array(concatenate(*embeddings)), max_out,