mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-02 11:20:19 +03:00
Intermediate state
This commit is contained in:
parent
d575b9f8d4
commit
6f42d79c1e
|
@ -941,6 +941,8 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
"`{arg2}`={arg2_values} but these arguments are conflicting.")
|
||||
E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got "
|
||||
"{value}.")
|
||||
E1044 = ("Special characters definition for '{label}' may not contain upper-case chars where case_sensitive==False.")
|
||||
E1045 = ("Invalid affix group config '{label}'.")
|
||||
|
||||
|
||||
# Deprecated model shortcuts, only used in errors and warnings
|
||||
|
|
61
spacy/ml/affixextractor.py
Normal file
61
spacy/ml/affixextractor.py
Normal file
|
@ -0,0 +1,61 @@
|
|||
from typing import List, Optional, Callable, Tuple
|
||||
from thinc.types import Ints2d
|
||||
from thinc.api import Model, registry, get_current_ops
|
||||
|
||||
from ..tokens import Doc
|
||||
|
||||
|
||||
@registry.layers("spacy.AffixExtractor.v1")
def AffixExtractor(
    *,
    suffs_not_prefs: bool,
    case_sensitive: bool,
    len_start: Optional[int],
    len_end: Optional[int],
    special_chars: Optional[str],
    sc_len_start: Optional[int],
    sc_len_end: Optional[int],
) -> Model[List[Doc], List[Ints2d]]:
    """Construct a non-trainable layer that maps docs to per-token affix
    hash features via `Doc.get_affix_hashes`.

    suffs_not_prefs (bool): Extract suffixes if True, prefixes otherwise.
    case_sensitive (bool): Whether extraction preserves case.
    len_start / len_end (Optional[int]): Affix length range; None disables.
    special_chars (Optional[str]): Special characters for the `sc_*` range.
    sc_len_start / sc_len_end (Optional[int]): Special-char affix length
        range; None disables.
    RETURNS (Model[List[Doc], List[Ints2d]]): The extraction model.
    """
    # None means "feature disabled"; normalize to neutral defaults (0 / "")
    # so that forward() can treat every attr as a concrete value.
    attrs = {
        "suffs_not_prefs": suffs_not_prefs,
        "case_sensitive": case_sensitive,
        "len_start": 0 if len_start is None else len_start,
        "len_end": 0 if len_end is None else len_end,
        "special_chars": "" if special_chars is None else special_chars,
        "sc_len_start": 0 if sc_len_start is None else sc_len_start,
        "sc_len_end": 0 if sc_len_end is None else sc_len_end,
    }
    return Model("extract_affixes", forward, attrs=attrs)
|
||||
|
||||
|
||||
def forward(
    model: Model[List[Doc], List[Ints2d]], docs, is_train: bool
) -> Tuple[List[Ints2d], Callable]:
    """Compute affix hash features for each doc.

    The layer has no trainable parameters, so the backprop callback
    discards incoming gradients and returns an empty list.
    """
    attrs = model.attrs
    # Argument order must match the Doc.get_affix_hashes signature.
    affix_args = (
        attrs["suffs_not_prefs"],
        attrs["case_sensitive"],
        attrs["len_start"],
        attrs["len_end"],
        attrs["special_chars"],
        attrs["sc_len_start"],
        attrs["sc_len_end"],
    )
    features: List[Ints2d] = [
        model.ops.asarray2i(doc.get_affix_hashes(*affix_args)) for doc in docs
    ]

    def backprop(d_features: List[Ints2d]) -> List:
        # Nothing to propagate: the extraction is not differentiable.
        return []

    return features, backprop
|
|
@ -1,4 +1,5 @@
|
|||
from typing import Optional, List, Union, cast
|
||||
from spacy.ml.affixextractor import AffixExtractor
|
||||
from thinc.types import Floats2d, Ints2d, Ragged, Ints1d
|
||||
from thinc.api import chain, clone, concatenate, with_array, with_padded
|
||||
from thinc.api import Model, noop, list2ragged, ragged2list, HashEmbed
|
||||
|
@ -185,6 +186,150 @@ def MultiHashEmbed(
|
|||
)
|
||||
return model
|
||||
|
||||
def process_affix_config_group(
    label: str,
    start_len: Optional[int],
    end_len: Optional[int],
    rows: Optional[List[int]],
    scs: Optional[str],
    is_sc: bool,
) -> List[int]:
    """Validate one affix configuration group and return its embedding rows.

    A group must be configured either completely (start length, end length
    and one row count per affix length) or not at all; partial or
    inconsistent settings raise E1045. Special characters must be
    lower-case (E1044).

    label (str): Group name used in error messages (e.g. "prefix_sc").
    start_len / end_len (Optional[int]): Affix length range [start, end).
    rows (Optional[List[int]]): Embedding table sizes, one per length.
    scs (Optional[str]): Special characters; required when `is_sc` is True.
    is_sc (bool): Whether this is a special-characters group.
    RETURNS (List[int]): `rows`, or an empty list for an unset group.
    """
    group_is_set = (
        start_len is not None or end_len is not None or rows is not None
    )
    if not group_is_set:
        # Special chars without the rest of the group is inconsistent.
        if scs is not None:
            raise ValueError(Errors.E1045.format(label=label))
        return []
    # All three core settings must be present together.
    if start_len is None or end_len is None or rows is None:
        raise ValueError(Errors.E1045.format(label=label))
    # The range must be non-negative and contain at least one length.
    if start_len < 0 or end_len < start_len + 1:
        raise ValueError(Errors.E1045.format(label=label))
    if is_sc and scs is None:
        raise ValueError(Errors.E1045.format(label=label))
    if scs is not None and scs != scs.lower():
        raise ValueError(Errors.E1044.format(label=label))
    # One row count per affix length in [start_len, end_len).
    if len(rows) != end_len - start_len:
        raise ValueError(Errors.E1045.format(label=label))
    return rows
|
||||
|
||||
@registry.architectures("spacy.AffixMultiHashEmbed.v1")
def AffixMultiHashEmbed(
    width: int,
    attrs: List[Union[str, int]],
    rows: List[int],
    include_static_vectors: bool,
    *,
    affix_case_sensitive: bool,
    suffix_start_len: Optional[int] = None,
    suffix_end_len: Optional[int] = None,
    suffix_rows: Optional[List[int]] = None,
    suffix_scs: Optional[str] = None,
    suffix_sc_start_len: Optional[int] = None,
    suffix_sc_end_len: Optional[int] = None,
    suffix_sc_rows: Optional[List[int]] = None,
    prefix_start_len: Optional[int] = None,
    prefix_end_len: Optional[int] = None,
    prefix_rows: Optional[List[int]] = None,
    prefix_scs: Optional[str] = None,
    prefix_sc_start_len: Optional[int] = None,
    prefix_sc_end_len: Optional[int] = None,
    prefix_sc_rows: Optional[List[int]] = None,
) -> Model[List[Doc], List[Floats2d]]:
    """A MultiHashEmbed-style embedding layer that can additionally embed
    hashed prefix/suffix character n-grams ("affixes") of each token.

    width (int): Output width, and the width of every embedding table.
    attrs (List[Union[str, int]]): Token attributes to embed.
    rows (List[int]): Embedding table sizes for `attrs`; same length as
        `attrs`.
    include_static_vectors (bool): Also concatenate static word vectors.
    affix_case_sensitive (bool): Whether affix extraction preserves case.
    suffix_* / prefix_* (optional): Configuration for the four affix groups
        ("prefix", "prefix_sc", "suffix", "suffix_sc"): a start/end length
        range, one embedding-row count per length, and — for the `_sc`
        groups — the special characters involved. Each group must be given
        completely or not at all; `process_affix_config_group` validates
        this. (NOTE(review): `_sc` semantics are defined by
        `Doc.get_affix_hashes` — presumably extraction restricted to the
        given special characters; confirm there.)
    RETURNS (Model[List[Doc], List[Floats2d]]): The embedding model.
    """
    if len(rows) != len(attrs):
        raise ValueError(f"Mismatched lengths: {len(rows)} vs {len(attrs)}")

    # Work on a copy: the caller's `rows` list (typically owned by the
    # config system) must not be mutated, or constructing the model a
    # second time from the same config would append the affix rows twice.
    rows = list(rows)
    # Appended order must match the column order the extractors emit:
    # attrs first, then prefix, prefix_sc, suffix, suffix_sc.
    rows.extend(
        process_affix_config_group(
            "prefix", prefix_start_len, prefix_end_len, prefix_rows, None, False
        )
    )
    rows.extend(
        process_affix_config_group(
            "prefix_sc",
            prefix_sc_start_len,
            prefix_sc_end_len,
            prefix_sc_rows,
            prefix_scs,
            True,
        )
    )
    rows.extend(
        process_affix_config_group(
            "suffix", suffix_start_len, suffix_end_len, suffix_rows, None, False
        )
    )
    rows.extend(
        process_affix_config_group(
            "suffix_sc",
            suffix_sc_start_len,
            suffix_sc_end_len,
            suffix_sc_rows,
            suffix_scs,
            True,
        )
    )

    # One hash-embedding table per feature column.
    embeddings = [
        HashEmbed(width, row, column=i, seed=i + 7, dropout=0.0)
        for i, row in enumerate(rows)
    ]
    # `include_static_vectors` contributes one extra width-sized slot.
    concat_size = width * (len(embeddings) + include_static_vectors)
    max_out: Model[Ragged, Ragged] = with_array(
        Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)
    )

    extractors = [FeatureExtractor(attrs)]
    if prefix_start_len is not None or prefix_sc_start_len is not None:
        extractors.append(
            AffixExtractor(
                suffs_not_prefs=False,
                case_sensitive=affix_case_sensitive,
                len_start=prefix_start_len,
                len_end=prefix_end_len,
                special_chars=prefix_scs,
                sc_len_start=prefix_sc_start_len,
                sc_len_end=prefix_sc_end_len,
            )
        )
    if suffix_start_len is not None or suffix_sc_start_len is not None:
        extractors.append(
            AffixExtractor(
                suffs_not_prefs=True,
                case_sensitive=affix_case_sensitive,
                len_start=suffix_start_len,
                len_end=suffix_end_len,
                special_chars=suffix_scs,
                sc_len_start=suffix_sc_start_len,
                sc_len_end=suffix_sc_end_len,
            )
        )

    if include_static_vectors:
        feature_extractor: Model[List[Doc], Ragged] = chain(
            concatenate(*extractors),
            cast(Model[List[Ints2d], Ragged], list2ragged()),
            with_array(concatenate(*embeddings)),
        )
        model = chain(
            concatenate(
                feature_extractor,
                StaticVectors(width, dropout=0.0),
            ),
            max_out,
            ragged2list(),
        )
    else:
        model = chain(
            concatenate(*extractors),
            cast(Model[List[Ints2d], Ragged], list2ragged()),
            with_array(concatenate(*embeddings)),
            max_out,
            ragged2list(),
        )
    return model
|
||||
|
||||
|
||||
@registry.architectures("spacy.CharacterEmbed.v2")
|
||||
def CharacterEmbed(
|
||||
|
|
|
@ -1735,7 +1735,7 @@ cdef class Doc:
|
|||
j += 1
|
||||
return output
|
||||
|
||||
def get_affix_hashes(self, bint suffs_not_prefs, bint lower_not_orth, unsigned int len_start, unsigned int len_end,
|
||||
def get_affix_hashes(self, bint suffs_not_prefs, bint case_sensitive, unsigned int len_start, unsigned int len_end,
|
||||
str special_chars, unsigned int sc_len_start, unsigned int sc_len_end):
|
||||
"""
|
||||
TODO
|
||||
|
@ -1746,7 +1746,7 @@ cdef class Doc:
|
|||
cdef np.ndarray[np.int64_t, ndim=2] output = numpy.empty((num_tokens, num_norm_hashes + num_spec_hashes), dtype="int64")
|
||||
|
||||
for token_index in range(num_tokens):
|
||||
token_string = self[token_index].orth_ if lower_not_orth else self[token_index].lower_
|
||||
token_string = self[token_index].orth_ if case_sensitive else self[token_index].lower_
|
||||
if suffs_not_prefs:
|
||||
token_string = token_string[::-1]
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user