Intermediate state

richardpaulhudson 2022-09-16 20:00:20 +02:00
parent d575b9f8d4
commit 6f42d79c1e
4 changed files with 210 additions and 2 deletions


@@ -941,6 +941,8 @@ class Errors(metaclass=ErrorsWithCodes):
             "`{arg2}`={arg2_values} but these arguments are conflicting.")
    E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got "
             "{value}.")
    E1044 = ("Special characters definition for '{label}' may not contain upper-case characters when "
             "case_sensitive==False.")
    E1045 = ("Invalid affix group config '{label}': start/end lengths and rows must be set together and be "
             "consistent.")
    # Deprecated model shortcuts, only used in errors and warnings


@@ -0,0 +1,61 @@
from typing import List, Optional, Callable, Tuple
from thinc.types import Ints2d
from thinc.api import Model, registry, get_current_ops
from ..tokens import Doc


@registry.layers("spacy.AffixExtractor.v1")
def AffixExtractor(
    *,
    suffs_not_prefs: bool,
    case_sensitive: bool,
    len_start: Optional[int],
    len_end: Optional[int],
    special_chars: Optional[str],
    sc_len_start: Optional[int],
    sc_len_end: Optional[int],
) -> Model[List[Doc], List[Ints2d]]:
    return Model(
        "extract_affixes",
        forward,
        attrs={
            "suffs_not_prefs": suffs_not_prefs,
            "case_sensitive": case_sensitive,
            "len_start": len_start if len_start is not None else 0,
            "len_end": len_end if len_end is not None else 0,
            "special_chars": special_chars if special_chars is not None else "",
            "sc_len_start": sc_len_start if sc_len_start is not None else 0,
            "sc_len_end": sc_len_end if sc_len_end is not None else 0,
        },
    )


def forward(
    model: Model[List[Doc], List[Ints2d]], docs: List[Doc], is_train: bool
) -> Tuple[List[Ints2d], Callable]:
    suffs_not_prefs: bool = model.attrs["suffs_not_prefs"]
    case_sensitive: bool = model.attrs["case_sensitive"]
    len_start: int = model.attrs["len_start"]
    len_end: int = model.attrs["len_end"]
    special_chars: str = model.attrs["special_chars"]
    sc_len_start: int = model.attrs["sc_len_start"]
    sc_len_end: int = model.attrs["sc_len_end"]
    # One (n_tokens, n_hashes) array of affix hashes per doc.
    features: List[Ints2d] = []
    for doc in docs:
        features.append(
            model.ops.asarray2i(
                doc.get_affix_hashes(
                    suffs_not_prefs,
                    case_sensitive,
                    len_start,
                    len_end,
                    special_chars,
                    sc_len_start,
                    sc_len_end,
                )
            )
        )
    # The layer has no trainable weights, so the backward pass is empty.
    backprop: Callable[[List[Ints2d]], List] = lambda d_features: []
    return features, backprop
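A minimal usage sketch of the layer above (hypothetical values; assumes this branch's Doc.get_affix_hashes is available):

    import spacy
    from spacy.ml.affixextractor import AffixExtractor

    nlp = spacy.blank("en")
    docs = [nlp("affix hashing example")]
    # Hash suffixes of lengths 1-3 of each lower-cased token.
    extractor = AffixExtractor(
        suffs_not_prefs=True,
        case_sensitive=False,
        len_start=1,
        len_end=4,
        special_chars=None,
        sc_len_start=None,
        sc_len_end=None,
    )
    features, backprop = extractor(docs, is_train=False)
    # features[0] should have shape (3, 3): one row per token, one hash
    # per suffix length in [1, 4).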


@@ -1,4 +1,5 @@
from typing import Optional, List, Union, cast
from spacy.ml.affixextractor import AffixExtractor
from thinc.types import Floats2d, Ints2d, Ragged, Ints1d
from thinc.api import chain, clone, concatenate, with_array, with_padded
from thinc.api import Model, noop, list2ragged, ragged2list, HashEmbed
@@ -185,6 +186,150 @@ def MultiHashEmbed(
    )
    return model


def process_affix_config_group(
    label: str,
    start_len: Optional[int],
    end_len: Optional[int],
    rows: Optional[List[int]],
    scs: Optional[str],
    is_sc: bool,
) -> List[int]:
    # An affix group is either fully unset, or sets start_len, end_len and rows
    # together, with one row count per affix length in [start_len, end_len).
    if start_len is not None or end_len is not None or rows is not None:
        if start_len is None or end_len is None or rows is None:
            raise ValueError(Errors.E1045.format(label=label))
        if start_len < 0 or end_len < start_len + 1:
            raise ValueError(Errors.E1045.format(label=label))
        if is_sc and scs is None:
            raise ValueError(Errors.E1045.format(label=label))
        if scs is not None and scs != scs.lower():
            raise ValueError(Errors.E1044.format(label=label))
        if len(rows) != end_len - start_len:
            raise ValueError(Errors.E1045.format(label=label))
    elif scs is not None:
        raise ValueError(Errors.E1045.format(label=label))
    return rows if rows is not None else []
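A sketch of the validation contract (illustrative numbers): a group covering affix lengths [start_len, end_len) needs exactly one row count per length.

    # Valid: lengths 2, 3 and 4 are extracted, so three row counts are given.
    process_affix_config_group("suffix", 2, 5, [500, 500, 500], None, False)

    # Raises E1045: only two row counts for three lengths.
    process_affix_config_group("suffix", 2, 5, [500, 500], None, False)

    # Raises E1044: special characters must be given in lower case.
    process_affix_config_group("suffix_sc", 1, 3, [500, 500], "XY", True)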


@registry.architectures("spacy.AffixMultiHashEmbed.v1")
def AffixMultiHashEmbed(
    width: int,
    attrs: List[Union[str, int]],
    rows: List[int],
    include_static_vectors: bool,
    *,
    affix_case_sensitive: bool,
    suffix_start_len: Optional[int] = None,
    suffix_end_len: Optional[int] = None,
    suffix_rows: Optional[List[int]] = None,
    suffix_scs: Optional[str] = None,
    suffix_sc_start_len: Optional[int] = None,
    suffix_sc_end_len: Optional[int] = None,
    suffix_sc_rows: Optional[List[int]] = None,
    prefix_start_len: Optional[int] = None,
    prefix_end_len: Optional[int] = None,
    prefix_rows: Optional[List[int]] = None,
    prefix_scs: Optional[str] = None,
    prefix_sc_start_len: Optional[int] = None,
    prefix_sc_end_len: Optional[int] = None,
    prefix_sc_rows: Optional[List[int]] = None,
) -> Model[List[Doc], List[Floats2d]]:
    """Construct an embedding layer that, like MultiHashEmbed, embeds lexical
    attributes (optionally alongside static vectors), and additionally embeds
    hashed prefix/suffix substrings over configurable length ranges, optionally
    restricted to a set of special characters."""
    if len(rows) != len(attrs):
        raise ValueError(f"Mismatched lengths: {len(rows)} vs {len(attrs)}")
    # Append one row count per configured affix length, validating each group.
    rows.extend(
        process_affix_config_group(
            "prefix", prefix_start_len, prefix_end_len, prefix_rows, None, False
        )
    )
    rows.extend(
        process_affix_config_group(
            "prefix_sc",
            prefix_sc_start_len,
            prefix_sc_end_len,
            prefix_sc_rows,
            prefix_scs,
            True,
        )
    )
    rows.extend(
        process_affix_config_group(
            "suffix", suffix_start_len, suffix_end_len, suffix_rows, None, False
        )
    )
    rows.extend(
        process_affix_config_group(
            "suffix_sc",
            suffix_sc_start_len,
            suffix_sc_end_len,
            suffix_sc_rows,
            suffix_scs,
            True,
        )
    )
    # One hash table per feature column; affix columns follow the attr columns.
    embeddings = [
        HashEmbed(width, row, column=i, seed=i + 7, dropout=0.0)
        for i, row in enumerate(rows)
    ]
    concat_size = width * (len(embeddings) + include_static_vectors)
    max_out: Model[Ragged, Ragged] = with_array(
        Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)
    )
    # The plain attribute extractor always runs; affix extractors are optional.
    extractors = [FeatureExtractor(attrs)]
    if prefix_start_len is not None or prefix_sc_start_len is not None:
        extractors.append(
            AffixExtractor(
                suffs_not_prefs=False,
                case_sensitive=affix_case_sensitive,
                len_start=prefix_start_len,
                len_end=prefix_end_len,
                special_chars=prefix_scs,
                sc_len_start=prefix_sc_start_len,
                sc_len_end=prefix_sc_end_len,
            )
        )
    if suffix_start_len is not None or suffix_sc_start_len is not None:
        extractors.append(
            AffixExtractor(
                suffs_not_prefs=True,
                case_sensitive=affix_case_sensitive,
                len_start=suffix_start_len,
                len_end=suffix_end_len,
                special_chars=suffix_scs,
                sc_len_start=suffix_sc_start_len,
                sc_len_end=suffix_sc_end_len,
            )
        )
    if include_static_vectors:
        feature_extractor: Model[List[Doc], Ragged] = chain(
            concatenate(*extractors),
            cast(Model[List[Ints2d], Ragged], list2ragged()),
            with_array(concatenate(*embeddings)),
        )
        model = chain(
            concatenate(
                feature_extractor,
                StaticVectors(width, dropout=0.0),
            ),
            max_out,
            ragged2list(),
        )
    else:
        model = chain(
            concatenate(*extractors),
            cast(Model[List[Ints2d], Ragged], list2ragged()),
            with_array(concatenate(*embeddings)),
            max_out,
            ragged2list(),
        )
    return model
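A hypothetical example of building the architecture directly (row counts and lengths are illustrative, not recommendations):

    from spacy.util import registry

    build = registry.architectures.get("spacy.AffixMultiHashEmbed.v1")
    embed = build(
        width=96,
        attrs=["NORM", "PREFIX", "SUFFIX", "SHAPE"],
        rows=[5000, 1000, 2500, 2500],
        include_static_vectors=False,
        affix_case_sensitive=False,
        # Add hash tables for suffixes of lengths 2-4.
        suffix_start_len=2,
        suffix_end_len=5,
        suffix_rows=[500, 500, 500],
    )
    # embed maps List[Doc] -> List[Floats2d], one 96-dim row per token.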


@registry.architectures("spacy.CharacterEmbed.v2")
def CharacterEmbed(


@@ -1735,7 +1735,7 @@ cdef class Doc:
                j += 1
        return output

-    def get_affix_hashes(self, bint suffs_not_prefs, bint lower_not_orth, unsigned int len_start, unsigned int len_end,
+    def get_affix_hashes(self, bint suffs_not_prefs, bint case_sensitive, unsigned int len_start, unsigned int len_end,
            str special_chars, unsigned int sc_len_start, unsigned int sc_len_end):
        """
        Return a (num_tokens, num_norm_hashes + num_spec_hashes) array of affix
        hash features: hashes of each token's prefixes (or suffixes, if
        suffs_not_prefs) for lengths in [len_start, len_end), plus hashes for
        the special-character group for lengths in [sc_len_start, sc_len_end).
@@ -1746,7 +1746,7 @@ cdef class Doc:
        cdef np.ndarray[np.int64_t, ndim=2] output = numpy.empty((num_tokens, num_norm_hashes + num_spec_hashes), dtype="int64")
        for token_index in range(num_tokens):
-            token_string = self[token_index].orth_ if lower_not_orth else self[token_index].lower_
+            token_string = self[token_index].orth_ if case_sensitive else self[token_index].lower_
            if suffs_not_prefs:
                token_string = token_string[::-1]
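The reversal above lets suffixes be read off as prefixes of the reversed string; a plain-Python illustration:

    token_string = "walking"[::-1]  # "gniklaw"
    # Slicing prefixes of the reversed string yields the token's suffixes.
    [token_string[:n][::-1] for n in range(2, 5)]  # ["ng", "ing", "king"]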