From 6f42d79c1ec70a90b3a80c40e6ae7f29ab6580cb Mon Sep 17 00:00:00 2001
From: richardpaulhudson
Date: Fri, 16 Sep 2022 20:00:20 +0200
Subject: [PATCH] Intermediate state

---
 spacy/errors.py            |   2 +
 spacy/ml/affixextractor.py |  74 ++++++++++++++++
 spacy/ml/models/tok2vec.py | 171 ++++++++++++++++++++++++++++++++++++++
 spacy/tokens/doc.pyx       |   4 +-
 4 files changed, 249 insertions(+), 2 deletions(-)
 create mode 100644 spacy/ml/affixextractor.py

diff --git a/spacy/errors.py b/spacy/errors.py
index 5ee1476c2..f171b0037 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -941,6 +941,8 @@ class Errors(metaclass=ErrorsWithCodes):
             "`{arg2}`={arg2_values} but these arguments are conflicting.")
     E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got "
              "{value}.")
+    E1044 = ("The special character set for '{label}' may not contain upper-case characters when case_sensitive is False.")
+    E1045 = ("Invalid affix group config '{label}': start_len, end_len and rows must be set together, with len(rows) == end_len - start_len.")


 # Deprecated model shortcuts, only used in errors and warnings
diff --git a/spacy/ml/affixextractor.py b/spacy/ml/affixextractor.py
new file mode 100644
index 000000000..6258fc66d
--- /dev/null
+++ b/spacy/ml/affixextractor.py
@@ -0,0 +1,74 @@
+from typing import List, Optional, Callable, Tuple
+from thinc.types import Ints2d
+from thinc.api import Model, registry
+
+from ..tokens import Doc
+
+
+@registry.layers("spacy.AffixExtractor.v1")
+def AffixExtractor(
+    *,
+    suffs_not_prefs: bool,
+    case_sensitive: bool,
+    len_start: Optional[int],
+    len_end: Optional[int],
+    special_chars: Optional[str],
+    sc_len_start: Optional[int],
+    sc_len_end: Optional[int],
+) -> Model[List[Doc], List[Ints2d]]:
+    """Construct a non-trainable layer that maps each `Doc` to an array of
+    affix hashes with one row per token.
+
+    suffs_not_prefs (bool): extract suffixes rather than prefixes.
+    case_sensitive (bool): hash the verbatim rather than the lower-cased text.
+    len_start/len_end (Optional[int]): half-open range of affix lengths.
+    special_chars (Optional[str]): special characters for which additional
+        affix hashes are extracted, over [sc_len_start, sc_len_end).
+    """
+    return Model(
+        "extract_affixes",
+        forward,
+        attrs={
+            "suffs_not_prefs": suffs_not_prefs,
+            "case_sensitive": case_sensitive,
+            "len_start": len_start if len_start is not None else 0,
+            "len_end": len_end if len_end is not None else 0,
+            "special_chars": special_chars if special_chars is not None else "",
+            "sc_len_start": sc_len_start if sc_len_start is not None else 0,
+            "sc_len_end": sc_len_end if sc_len_end is not None else 0,
+        },
+    )
+
+
+def forward(
+    model: Model[List[Doc], List[Ints2d]], docs, is_train: bool
+) -> Tuple[List[Ints2d], Callable]:
+    suffs_not_prefs: bool = model.attrs["suffs_not_prefs"]
+    case_sensitive: bool = model.attrs["case_sensitive"]
+    len_start: int = model.attrs["len_start"]
+    len_end: int = model.attrs["len_end"]
+    special_chars: str = model.attrs["special_chars"]
+    sc_len_start: int = model.attrs["sc_len_start"]
+    sc_len_end: int = model.attrs["sc_len_end"]
+    features: List[Ints2d] = []
+    for doc in docs:
+        features.append(
+            model.ops.asarray2i(
+                doc.get_affix_hashes(
+                    suffs_not_prefs,
+                    case_sensitive,
+                    len_start,
+                    len_end,
+                    special_chars,
+                    sc_len_start,
+                    sc_len_end,
+                )
+            )
+        )
+
+    def backprop(d_features: List[Ints2d]) -> List:
+        # The layer has no parameters and its inputs are not differentiable,
+        # so there is no gradient to pass back.
+        return []
+
+    return features, backprop
diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py
index 30c7360ff..9fb0e8b7f 100644
--- a/spacy/ml/models/tok2vec.py
+++ b/spacy/ml/models/tok2vec.py
@@ -1,4 +1,5 @@
 from typing import Optional, List, Union, cast
+from ..affixextractor import AffixExtractor
 from thinc.types import Floats2d, Ints2d, Ragged, Ints1d
 from thinc.api import chain, clone, concatenate, with_array, with_padded
 from thinc.api import Model, noop, list2ragged, ragged2list, HashEmbed
@@ -185,6 +186,176 @@ def MultiHashEmbed(
     )
     return model
 
+
+def process_affix_config_group(
+    label: str,
+    start_len: Optional[int],
+    end_len: Optional[int],
+    rows: Optional[List[int]],
+    scs: Optional[str],
+    is_sc: bool,
+    case_sensitive: bool,
+) -> List[int]:
+    """Validate one affix config group and return the embedding rows it
+    contributes. A group must be either fully specified or fully unset,
+    and `rows` must contain one entry for each affix length in the
+    half-open range [start_len, end_len).
+    """
+    if start_len is not None or end_len is not None or rows is not None:
+        if start_len is None or end_len is None or rows is None:
+            raise ValueError(Errors.E1045.format(label=label))
+        if start_len < 0 or end_len < start_len + 1:
+            raise ValueError(Errors.E1045.format(label=label))
+        if is_sc and scs is None:
+            raise ValueError(Errors.E1045.format(label=label))
+        if not case_sensitive and scs is not None and scs != scs.lower():
+            raise ValueError(Errors.E1044.format(label=label))
+        if len(rows) != end_len - start_len:
+            raise ValueError(Errors.E1045.format(label=label))
+    elif scs is not None:
+        raise ValueError(Errors.E1045.format(label=label))
+    return rows if rows is not None else []
+
+
+@registry.architectures("spacy.AffixMultiHashEmbed.v1")
+def AffixMultiHashEmbed(
+    width: int,
+    attrs: List[Union[str, int]],
+    rows: List[int],
+    include_static_vectors: bool,
+    *,
+    affix_case_sensitive: bool,
+    suffix_start_len: Optional[int] = None,
+    suffix_end_len: Optional[int] = None,
+    suffix_rows: Optional[List[int]] = None,
+    suffix_scs: Optional[str] = None,
+    suffix_sc_start_len: Optional[int] = None,
+    suffix_sc_end_len: Optional[int] = None,
+    suffix_sc_rows: Optional[List[int]] = None,
+    prefix_start_len: Optional[int] = None,
+    prefix_end_len: Optional[int] = None,
+    prefix_rows: Optional[List[int]] = None,
+    prefix_scs: Optional[str] = None,
+    prefix_sc_start_len: Optional[int] = None,
+    prefix_sc_end_len: Optional[int] = None,
+    prefix_sc_rows: Optional[List[int]] = None,
+) -> Model[List[Doc], List[Floats2d]]:
+    """Like MultiHashEmbed, but additionally hashes and embeds prefixes
+    and/or suffixes of each token. For each configured affix group, the
+    affix lengths in the half-open range [start_len, end_len) are hashed,
+    and each length is embedded using the table size given by the
+    corresponding entry of the group's rows list. The `*_scs` settings
+    define a set of special characters for which separate affix hashes
+    are extracted over the range [sc_start_len, sc_end_len).
+    """
+    if len(rows) != len(attrs):
+        raise ValueError(f"Mismatched lengths: {len(rows)} vs {len(attrs)}")
+
+    # Copy so that affix rows are not appended to the config-provided list.
+    rows = list(rows)
+    rows.extend(
+        process_affix_config_group(
+            "prefix",
+            prefix_start_len,
+            prefix_end_len,
+            prefix_rows,
+            None,
+            False,
+            affix_case_sensitive,
+        )
+    )
+    rows.extend(
+        process_affix_config_group(
+            "prefix_sc",
+            prefix_sc_start_len,
+            prefix_sc_end_len,
+            prefix_sc_rows,
+            prefix_scs,
+            True,
+            affix_case_sensitive,
+        )
+    )
+    rows.extend(
+        process_affix_config_group(
+            "suffix",
+            suffix_start_len,
+            suffix_end_len,
+            suffix_rows,
+            None,
+            False,
+            affix_case_sensitive,
+        )
+    )
+    rows.extend(
+        process_affix_config_group(
+            "suffix_sc",
+            suffix_sc_start_len,
+            suffix_sc_end_len,
+            suffix_sc_rows,
+            suffix_scs,
+            True,
+            affix_case_sensitive,
+        )
+    )
+
+    embeddings = [
+        HashEmbed(width, row, column=i, seed=i + 7, dropout=0.0)
+        for i, row in enumerate(rows)
+    ]
+    concat_size = width * (len(embeddings) + include_static_vectors)
+    max_out: Model[Ragged, Ragged] = with_array(
+        Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)
+    )
+    extractors = [FeatureExtractor(attrs)]
+    if prefix_start_len is not None or prefix_sc_start_len is not None:
+        extractors.append(
+            AffixExtractor(
+                suffs_not_prefs=False,
+                case_sensitive=affix_case_sensitive,
+                len_start=prefix_start_len,
+                len_end=prefix_end_len,
+                special_chars=prefix_scs,
+                sc_len_start=prefix_sc_start_len,
+                sc_len_end=prefix_sc_end_len,
+            )
+        )
+    if suffix_start_len is not None or suffix_sc_start_len is not None:
+        extractors.append(
+            AffixExtractor(
+                suffs_not_prefs=True,
+                case_sensitive=affix_case_sensitive,
+                len_start=suffix_start_len,
+                len_end=suffix_end_len,
+                special_chars=suffix_scs,
+                sc_len_start=suffix_sc_start_len,
+                sc_len_end=suffix_sc_end_len,
+            )
+        )
+
+    if include_static_vectors:
+        feature_extractor: Model[List[Doc], Ragged] = chain(
+            concatenate(*extractors),
+            cast(Model[List[Ints2d], Ragged], list2ragged()),
+            with_array(concatenate(*embeddings)),
+        )
+        model = chain(
+            concatenate(
+                feature_extractor,
+                StaticVectors(width, dropout=0.0),
+            ),
+            max_out,
+            ragged2list(),
+        )
+    else:
+        model = chain(
+            concatenate(*extractors),
+            cast(Model[List[Ints2d], Ragged], list2ragged()),
+            with_array(concatenate(*embeddings)),
+            max_out,
+            ragged2list(),
+        )
+    return model
 
 
 @registry.architectures("spacy.CharacterEmbed.v2")
 def CharacterEmbed(
diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx
index 9bf0b733d..4c29cf1e9 100644
--- a/spacy/tokens/doc.pyx
+++ b/spacy/tokens/doc.pyx
@@ -1735,7 +1735,7 @@ cdef class Doc:
             j += 1
         return output
 
-    def get_affix_hashes(self, bint suffs_not_prefs, bint lower_not_orth, unsigned int len_start, unsigned int len_end,
+    def get_affix_hashes(self, bint suffs_not_prefs, bint case_sensitive, unsigned int len_start, unsigned int len_end,
             str special_chars, unsigned int sc_len_start, unsigned int sc_len_end):
         """
         TODO
@@ -1746,7 +1746,7 @@
         cdef np.ndarray[np.int64_t, ndim=2] output = numpy.empty(
             (num_tokens, num_norm_hashes + num_spec_hashes), dtype="int64")
         for token_index in range(num_tokens):
-            token_string = self[token_index].orth_ if lower_not_orth else self[token_index].lower_
+            token_string = self[token_index].orth_ if case_sensitive else self[token_index].lower_
             if suffs_not_prefs:
                 token_string = token_string[::-1]
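
---

Usage sketch (illustrative, not part of the patch): a minimal smoke test for
the registered `spacy.AffixExtractor.v1` layer, assuming the patch has been
applied to a spaCy checkout. The sentence and the length values are arbitrary.

    import spacy
    from spacy.ml.affixextractor import AffixExtractor

    nlp = spacy.blank("en")
    doc = nlp("Interesting suffixes carry signal")

    # Hash case-insensitive suffixes of lengths 2 and 3; as in
    # process_affix_config_group, the length range is half-open.
    extractor = AffixExtractor(
        suffs_not_prefs=True,
        case_sensitive=False,
        len_start=2,
        len_end=4,
        special_chars=None,
        sc_len_start=None,
        sc_len_end=None,
    )
    features, backprop = extractor([doc], is_train=False)
    # One row per token; the column count depends on the configured length
    # ranges (here: two suffix lengths, no special characters).
    print(features[0].shape)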
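The new architecture can be constructed directly in the same way. The width,
attributes and row counts below are illustrative values, not defaults taken
from the patch:

    from spacy.ml.models.tok2vec import AffixMultiHashEmbed

    embed = AffixMultiHashEmbed(
        width=96,
        attrs=["NORM", "PREFIX", "SUFFIX"],
        rows=[5000, 1000, 2500],
        include_static_vectors=False,
        affix_case_sensitive=False,
        # Embed suffixes of lengths 2 and 3, one hash table per length:
        suffix_start_len=2,
        suffix_end_len=4,
        suffix_rows=[2500, 2500],
    )

Each configured affix length contributes one extra HashEmbed column after the
`attrs` columns, which is why the function extends `rows` rather than
replacing it.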
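Validation sketch: a partially specified affix group is rejected with E1045
(again assuming the patch is applied; the argument values are arbitrary):

    from spacy.ml.models.tok2vec import AffixMultiHashEmbed

    try:
        AffixMultiHashEmbed(
            width=96,
            attrs=["NORM"],
            rows=[5000],
            include_static_vectors=False,
            affix_case_sensitive=False,
            suffix_start_len=2,  # suffix_end_len and suffix_rows missing
        )
    except ValueError as err:
        print(err)  # Invalid affix group config 'suffix': ...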