diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py index 9fb0e8b7f..3ee76b2d3 100644 --- a/spacy/ml/models/tok2vec.py +++ b/spacy/ml/models/tok2vec.py @@ -186,6 +186,7 @@ def MultiHashEmbed( ) return model + def process_affix_config_group( label: str, start_len: Optional[int], @@ -209,6 +210,7 @@ def process_affix_config_group( raise ValueError(Errors.E1045.format(label=label)) return rows if rows is not None else [] + @registry.architectures("spacy.AffixMultiHashEmbed.v1") def AffixMultiHashEmbed( width: int, @@ -237,7 +239,6 @@ def AffixMultiHashEmbed( TODO """ - if len(rows) != len(attrs): raise ValueError(f"Mismatched lengths: {len(rows)} vs {len(attrs)}") @@ -272,7 +273,7 @@ def AffixMultiHashEmbed( ) ) - embeddings = [ + embeddings = [ # type:ignore HashEmbed(width, row, column=i, seed=i + 7, dropout=0.0) for i, row in enumerate(rows) ] @@ -306,8 +307,8 @@ def AffixMultiHashEmbed( ) ) - if include_static_vectors: - feature_extractor: Model[List[Doc], Ragged] = chain( + if include_static_vectors: + feature_extractor: Model[List[Doc], Ragged] = chain( # type: ignore concatenate(*extractors), cast(Model[List[Ints2d], Ragged], list2ragged()), with_array(concatenate(*embeddings)), @@ -321,7 +322,7 @@ def AffixMultiHashEmbed( ragged2list(), ) else: - model = chain( + model = chain( # type: ignore concatenate(*extractors), cast(Model[List[Ints2d], Ragged], list2ragged()), with_array(concatenate(*embeddings)),