Intermediate state

richardpaulhudson 2022-09-16 20:00:20 +02:00
parent d575b9f8d4
commit 6f42d79c1e
4 changed files with 210 additions and 2 deletions

spacy/errors.py

@@ -941,6 +941,8 @@ class Errors(metaclass=ErrorsWithCodes):
             "`{arg2}`={arg2_values} but these arguments are conflicting.")
    E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got "
             "{value}.")
+   E1044 = ("Special characters definition for '{label}' may not contain upper-case characters when case_sensitive==False.")
+   E1045 = ("Invalid affix group config '{label}'.")

    # Deprecated model shortcuts, only used in errors and warnings

spacy/ml/affixextractor.py (new file)

@@ -0,0 +1,61 @@
from typing import List, Optional, Callable, Tuple
from thinc.types import Ints2d
from thinc.api import Model, registry, get_current_ops
from ..tokens import Doc


@registry.layers("spacy.AffixExtractor.v1")
def AffixExtractor(
*,
suffs_not_prefs: bool,
case_sensitive: bool,
len_start: Optional[int],
len_end: Optional[int],
special_chars: Optional[str],
sc_len_start: Optional[int],
sc_len_end: Optional[int],
) -> Model[List[Doc], List[Ints2d]]:
return Model(
"extract_affixes",
forward,
attrs={
"suffs_not_prefs": suffs_not_prefs,
"case_sensitive": case_sensitive,
"len_start": len_start if len_start is not None else 0,
"len_end": len_end if len_end is not None else 0,
"special_chars": special_chars if special_chars is not None else "",
"sc_len_start": sc_len_start if sc_len_start is not None else 0,
"sc_len_end": sc_len_end if sc_len_end is not None else 0,
},
    )


def forward(
model: Model[List[Doc], List[Ints2d]], docs, is_train: bool
) -> Tuple[List[Ints2d], Callable]:
suffs_not_prefs: bool = model.attrs["suffs_not_prefs"]
case_sensitive: bool = model.attrs["case_sensitive"]
len_start: int = model.attrs["len_start"]
len_end: int = model.attrs["len_end"]
special_chars: str = model.attrs["special_chars"]
sc_len_start: int = model.attrs["sc_len_start"]
    sc_len_end: int = model.attrs["sc_len_end"]

    features: List[Ints2d] = []
for doc in docs:
features.append(
model.ops.asarray2i(
doc.get_affix_hashes(
suffs_not_prefs,
case_sensitive,
len_start,
len_end,
special_chars,
sc_len_start,
sc_len_end,
)
)
        )

    backprop: Callable[[List[Ints2d]], List] = lambda d_features: []
return features, backprop
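
A minimal usage sketch for the new layer (hedged: it assumes this branch, since Doc.get_affix_hashes is only added in this commit, and the length-range semantics are inferred from the forward pass above):

import spacy
from spacy.ml.affixextractor import AffixExtractor

nlp = spacy.blank("en")
# Suffix hashes for lengths 2-4 (case-insensitive, so tokens are lower-cased
# first), plus hashes over the special characters "aeiou" for lengths 1-2.
extractor = AffixExtractor(
    suffs_not_prefs=True,
    case_sensitive=False,
    len_start=2,
    len_end=5,
    special_chars="aeiou",
    sc_len_start=1,
    sc_len_end=3,
)
features, backprop = extractor([nlp("interesting affixes")], is_train=False)
# Expected: one Ints2d per Doc of shape (num_tokens, (5 - 2) + (3 - 1)).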

spacy/ml/models/tok2vec.py

@@ -1,4 +1,5 @@
from typing import Optional, List, Union, cast
+from ..affixextractor import AffixExtractor
from thinc.types import Floats2d, Ints2d, Ragged, Ints1d
from thinc.api import chain, clone, concatenate, with_array, with_padded
from thinc.api import Model, noop, list2ragged, ragged2list, HashEmbed

@@ -185,6 +186,150 @@ def MultiHashEmbed(
    )
    return model

def process_affix_config_group(
label: str,
start_len: Optional[int],
end_len: Optional[int],
rows: Optional[List[int]],
scs: Optional[str],
is_sc: bool,
) -> List[int]:
    # A group must be either fully unset (contributing no rows) or fully
    # specified: a length range [start_len, end_len) with one row count per
    # affix length, plus a special-character string when is_sc is True.
if start_len is not None or end_len is not None or rows is not None:
if start_len is None or end_len is None or rows is None:
raise ValueError(Errors.E1045.format(label=label))
if start_len < 0 or end_len < start_len + 1:
raise ValueError(Errors.E1045.format(label=label))
if is_sc and scs is None:
raise ValueError(Errors.E1045.format(label=label))
if scs is not None and scs != scs.lower():
raise ValueError(Errors.E1044.format(label=label))
if len(rows) != end_len - start_len:
raise ValueError(Errors.E1045.format(label=label))
elif scs is not None:
raise ValueError(Errors.E1045.format(label=label))
return rows if rows is not None else []
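
To illustrate the validation rules (a sketch; the row counts are hypothetical values):

# Fully specified suffix group: three row counts for affix lengths 2, 3 and 4.
assert process_affix_config_group("suffix", 2, 5, [2000, 2000, 2000], None, False) == [2000, 2000, 2000]
# Fully unset group: contributes no embedding rows.
assert process_affix_config_group("suffix", None, None, None, None, False) == []
# Two row counts for three lengths would raise E1045; upper-case characters
# in the special-character string would raise E1044.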
@registry.architectures("spacy.AffixMultiHashEmbed.v1")
def AffixMultiHashEmbed(
width: int,
attrs: List[Union[str, int]],
rows: List[int],
include_static_vectors: bool,
*,
affix_case_sensitive: bool,
suffix_start_len: Optional[int] = None,
suffix_end_len: Optional[int] = None,
suffix_rows: Optional[List[int]] = None,
suffix_scs: Optional[str] = None,
suffix_sc_start_len: Optional[int] = None,
suffix_sc_end_len: Optional[int] = None,
suffix_sc_rows: Optional[List[int]] = None,
prefix_start_len: Optional[int] = None,
prefix_end_len: Optional[int] = None,
prefix_rows: Optional[List[int]] = None,
prefix_scs: Optional[str] = None,
prefix_sc_start_len: Optional[int] = None,
prefix_sc_end_len: Optional[int] = None,
prefix_sc_rows: Optional[List[int]] = None,
) -> Model[List[Doc], List[Floats2d]]:
"""
TODO
"""
if len(rows) != len(attrs):
raise ValueError(f"Mismatched lengths: {len(rows)} vs {len(attrs)}")
rows.extend(
process_affix_config_group(
"prefix", prefix_start_len, prefix_end_len, prefix_rows, None, False
)
)
rows.extend(
process_affix_config_group(
"prefix_sc",
prefix_sc_start_len,
prefix_sc_end_len,
prefix_sc_rows,
prefix_scs,
True,
)
)
rows.extend(
process_affix_config_group(
"suffix", suffix_start_len, suffix_end_len, suffix_rows, None, False
)
)
rows.extend(
process_affix_config_group(
"suffix_sc",
suffix_sc_start_len,
suffix_sc_end_len,
suffix_sc_rows,
suffix_scs,
True,
)
)
embeddings = [
HashEmbed(width, row, column=i, seed=i + 7, dropout=0.0)
for i, row in enumerate(rows)
]
concat_size = width * (len(embeddings) + include_static_vectors)
max_out: Model[Ragged, Ragged] = with_array(
Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)
)
extractors = [FeatureExtractor(attrs)]
if prefix_start_len is not None or prefix_sc_start_len is not None:
extractors.append(
AffixExtractor(
suffs_not_prefs=False,
case_sensitive=affix_case_sensitive,
len_start=prefix_start_len,
len_end=prefix_end_len,
special_chars=prefix_scs,
sc_len_start=prefix_sc_start_len,
sc_len_end=prefix_sc_end_len,
)
)
if suffix_start_len is not None or suffix_sc_start_len is not None:
extractors.append(
AffixExtractor(
suffs_not_prefs=True,
case_sensitive=affix_case_sensitive,
len_start=suffix_start_len,
len_end=suffix_end_len,
special_chars=suffix_scs,
sc_len_start=suffix_sc_start_len,
sc_len_end=suffix_sc_end_len,
)
)
if include_static_vectors:
feature_extractor: Model[List[Doc], Ragged] = chain(
concatenate(*extractors),
cast(Model[List[Ints2d], Ragged], list2ragged()),
with_array(concatenate(*embeddings)),
)
model = chain(
concatenate(
feature_extractor,
StaticVectors(width, dropout=0.0),
),
max_out,
ragged2list(),
)
else:
model = chain(
concatenate(*extractors),
cast(Model[List[Ints2d], Ragged], list2ragged()),
with_array(concatenate(*embeddings)),
max_out,
ragged2list(),
)
return model
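
A hedged sketch of calling the new architecture directly (the width, attrs, and row counts are illustrative values, not defaults from this commit):

embed = AffixMultiHashEmbed(
    width=96,
    attrs=["NORM", "PREFIX", "SUFFIX", "SHAPE"],
    rows=[5000, 1000, 2500, 2500],
    include_static_vectors=False,
    affix_case_sensitive=False,
    # Suffix lengths 2-4: one embedding table per affix length.
    suffix_start_len=2,
    suffix_end_len=5,
    suffix_rows=[2000, 2000, 2000],
)
# Four attribute tables plus three suffix tables are concatenated and mixed
# by the Maxout layer; the suffix hashes come from the AffixExtractor above.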
@registry.architectures("spacy.CharacterEmbed.v2")
def CharacterEmbed(

spacy/tokens/doc.pyx

@@ -1735,7 +1735,7 @@ cdef class Doc:
            j += 1
        return output

-    def get_affix_hashes(self, bint suffs_not_prefs, bint lower_not_orth, unsigned int len_start, unsigned int len_end,
+    def get_affix_hashes(self, bint suffs_not_prefs, bint case_sensitive, unsigned int len_start, unsigned int len_end,
            str special_chars, unsigned int sc_len_start, unsigned int sc_len_end):
        """
        TODO

@@ -1746,7 +1746,7 @@ cdef class Doc:
        cdef np.ndarray[np.int64_t, ndim=2] output = numpy.empty((num_tokens, num_norm_hashes + num_spec_hashes), dtype="int64")
        for token_index in range(num_tokens):
-            token_string = self[token_index].orth_ if lower_not_orth else self[token_index].lower_
+            token_string = self[token_index].orth_ if case_sensitive else self[token_index].lower_
            if suffs_not_prefs:
                token_string = token_string[::-1]
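
For reference, a hedged example of how the Cython method is called on this branch (the argument semantics are inferred from the signature and from the AffixExtractor forward pass; the docstring itself is still TODO):

import spacy

nlp = spacy.blank("en")
doc = nlp("Affixes galore")
# Suffix hashes (case-insensitive) for lengths 2-4, plus hashes over the
# special characters "fx" for lengths 1-2.
hashes = doc.get_affix_hashes(True, False, 2, 5, "fx", 1, 3)
# Expected: int64 array of shape (2 tokens, (5 - 2) + (3 - 1)) == (2, 5).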