Add temporary #type:ignore s

This commit is contained in:
richardpaulhudson 2022-10-05 19:15:18 +02:00
parent ed76c89968
commit 2a6c1cf63c

View File

@ -186,6 +186,7 @@ def MultiHashEmbed(
) )
return model return model
def process_affix_config_group( def process_affix_config_group(
label: str, label: str,
start_len: Optional[int], start_len: Optional[int],
@ -209,6 +210,7 @@ def process_affix_config_group(
raise ValueError(Errors.E1045.format(label=label)) raise ValueError(Errors.E1045.format(label=label))
return rows if rows is not None else [] return rows if rows is not None else []
@registry.architectures("spacy.AffixMultiHashEmbed.v1") @registry.architectures("spacy.AffixMultiHashEmbed.v1")
def AffixMultiHashEmbed( def AffixMultiHashEmbed(
width: int, width: int,
@ -237,7 +239,6 @@ def AffixMultiHashEmbed(
TODO TODO
""" """
if len(rows) != len(attrs): if len(rows) != len(attrs):
raise ValueError(f"Mismatched lengths: {len(rows)} vs {len(attrs)}") raise ValueError(f"Mismatched lengths: {len(rows)} vs {len(attrs)}")
@ -272,7 +273,7 @@ def AffixMultiHashEmbed(
) )
) )
embeddings = [ embeddings = [ # type:ignore
HashEmbed(width, row, column=i, seed=i + 7, dropout=0.0) HashEmbed(width, row, column=i, seed=i + 7, dropout=0.0)
for i, row in enumerate(rows) for i, row in enumerate(rows)
] ]
@ -306,8 +307,8 @@ def AffixMultiHashEmbed(
) )
) )
if include_static_vectors: if include_static_vectors:
feature_extractor: Model[List[Doc], Ragged] = chain( feature_extractor: Model[List[Doc], Ragged] = chain( # type: ignore
concatenate(*extractors), concatenate(*extractors),
cast(Model[List[Ints2d], Ragged], list2ragged()), cast(Model[List[Ints2d], Ragged], list2ragged()),
with_array(concatenate(*embeddings)), with_array(concatenate(*embeddings)),
@ -321,7 +322,7 @@ def AffixMultiHashEmbed(
ragged2list(), ragged2list(),
) )
else: else:
model = chain( model = chain( # type: ignore
concatenate(*extractors), concatenate(*extractors),
cast(Model[List[Ints2d], Ragged], list2ragged()), cast(Model[List[Ints2d], Ragged], list2ragged()),
with_array(concatenate(*embeddings)), with_array(concatenate(*embeddings)),