mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-13 01:32:32 +03:00
Python 3.13 support (#13823)
In order to support Python 3.13, we had to migrate to Cython 3.0. This caused some tricky interaction with our Pydantic usage, because Cython 3 uses the from __future__ import annotations semantics, which causes type annotations to be saved as strings. The end result is that we can't have Language.factory decorated functions in Cython modules anymore, as the Language.factory decorator expects to inspect the signature of the functions and build a Pydantic model. If the function is implemented in Cython, an error is raised because the type is not resolved. To address this I've moved the factory functions into a new module, spacy.pipeline.factories. I've added __getattr__ importlib hooks to the previous locations, in case anyone was importing these functions directly. The change should have no backwards compatibility implications. Along the way I've also refactored the registration of functions for the config. Previously these ran as import-time side-effects, using the registry decorator. I've created instead a new module spacy.registrations. When the registry is accessed it calls a function ensure_populated(), which cases the registrations to occur. I've made a similar change to the Language.factory registrations in the new spacy.pipeline.factories module. I want to remove these import-time side-effects so that we can speed up the loading time of the library, which can be especially painful on the CLI. I also find that I'm often working to track down the implementations of functions referenced by strings in the config. Having the registrations all happen in one place will make this easier. With these changes I've fortunately avoided the need to migrate to Pydantic v2 properly --- we're still using the v1 compatibility shim. We might not be able to hold out forever though: Pydantic (reasonably) aren't actively supporting the v1 shims. I put a lot of work into v2 migration when investigating the 3.13 support, and it's definitely challenging. In any case, it's a relief that we don't have to do the v2 migration at the same time as the Cython 3.0/Python 3.13 support.
This commit is contained in:
parent
911539e9a4
commit
5bebbf7550
13
.github/workflows/tests.yml
vendored
13
.github/workflows/tests.yml
vendored
|
@ -45,11 +45,12 @@ jobs:
|
|||
run: |
|
||||
python -m pip install flake8==5.0.4
|
||||
python -m flake8 spacy --count --select=E901,E999,F821,F822,F823,W605 --show-source --statistics
|
||||
- name: cython-lint
|
||||
run: |
|
||||
python -m pip install cython-lint -c requirements.txt
|
||||
# E501: line too log, W291: trailing whitespace, E266: too many leading '#' for block comment
|
||||
cython-lint spacy --ignore E501,W291,E266
|
||||
# Unfortunately cython-lint isn't working after the shift to Cython 3.
|
||||
#- name: cython-lint
|
||||
# run: |
|
||||
# python -m pip install cython-lint -c requirements.txt
|
||||
# # E501: line too log, W291: trailing whitespace, E266: too many leading '#' for block comment
|
||||
# cython-lint spacy --ignore E501,W291,E266
|
||||
|
||||
tests:
|
||||
name: Test
|
||||
|
@ -58,7 +59,7 @@ jobs:
|
|||
fail-fast: true
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||
python_version: ["3.9", "3.12"]
|
||||
python_version: ["3.9", "3.12", "3.13"]
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
|
|
|
@ -4,5 +4,6 @@ include README.md
|
|||
include pyproject.toml
|
||||
include spacy/py.typed
|
||||
recursive-include spacy/cli *.yml
|
||||
recursive-include spacy/tests *.json
|
||||
recursive-include licenses *
|
||||
recursive-exclude spacy *.cpp
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
[build-system]
|
||||
requires = [
|
||||
"setuptools",
|
||||
"cython>=0.25,<3.0",
|
||||
"cython>=3.0,<4.0",
|
||||
"cymem>=2.0.2,<2.1.0",
|
||||
"preshed>=3.0.2,<3.1.0",
|
||||
"murmurhash>=0.28.0,<1.1.0",
|
||||
|
|
|
@ -23,7 +23,7 @@ setuptools
|
|||
packaging>=20.0
|
||||
# Development dependencies
|
||||
pre-commit>=2.13.0
|
||||
cython>=0.25,<3.0
|
||||
cython>=3.0,<4.0
|
||||
pytest>=5.2.0,!=7.1.0
|
||||
pytest-timeout>=1.3.0,<2.0.0
|
||||
mock>=2.0.0,<3.0.0
|
||||
|
|
|
@ -30,11 +30,11 @@ project_urls =
|
|||
[options]
|
||||
zip_safe = false
|
||||
include_package_data = true
|
||||
python_requires = >=3.9,<3.13
|
||||
python_requires = >=3.9,<3.14
|
||||
# NOTE: This section is superseded by pyproject.toml and will be removed in
|
||||
# spaCy v4
|
||||
setup_requires =
|
||||
cython>=0.25,<3.0
|
||||
cython>=3.0,<4.0
|
||||
numpy>=2.0.0,<3.0.0; python_version < "3.9"
|
||||
numpy>=2.0.0,<3.0.0; python_version >= "3.9"
|
||||
# We also need our Cython packages here to compile against
|
||||
|
|
|
@ -17,6 +17,7 @@ from .cli.info import info # noqa: F401
|
|||
from .errors import Errors
|
||||
from .glossary import explain # noqa: F401
|
||||
from .language import Language
|
||||
from .registrations import REGISTRY_POPULATED, populate_registry
|
||||
from .util import logger, registry # noqa: F401
|
||||
from .vocab import Vocab
|
||||
|
||||
|
|
|
@ -32,7 +32,6 @@ split_mode = null
|
|||
"""
|
||||
|
||||
|
||||
@registry.tokenizers("spacy.ja.JapaneseTokenizer")
|
||||
def create_tokenizer(split_mode: Optional[str] = None):
|
||||
def japanese_tokenizer_factory(nlp):
|
||||
return JapaneseTokenizer(nlp.vocab, split_mode=split_mode)
|
||||
|
|
|
@ -20,7 +20,6 @@ DEFAULT_CONFIG = """
|
|||
"""
|
||||
|
||||
|
||||
@registry.tokenizers("spacy.ko.KoreanTokenizer")
|
||||
def create_tokenizer():
|
||||
def korean_tokenizer_factory(nlp):
|
||||
return KoreanTokenizer(nlp.vocab)
|
||||
|
|
|
@ -13,7 +13,6 @@ DEFAULT_CONFIG = """
|
|||
"""
|
||||
|
||||
|
||||
@registry.tokenizers("spacy.th.ThaiTokenizer")
|
||||
def create_thai_tokenizer():
|
||||
def thai_tokenizer_factory(nlp):
|
||||
return ThaiTokenizer(nlp.vocab)
|
||||
|
|
|
@ -22,7 +22,6 @@ use_pyvi = true
|
|||
"""
|
||||
|
||||
|
||||
@registry.tokenizers("spacy.vi.VietnameseTokenizer")
|
||||
def create_vietnamese_tokenizer(use_pyvi: bool = True):
|
||||
def vietnamese_tokenizer_factory(nlp):
|
||||
return VietnameseTokenizer(nlp.vocab, use_pyvi=use_pyvi)
|
||||
|
|
|
@ -46,7 +46,6 @@ class Segmenter(str, Enum):
|
|||
return list(cls.__members__.keys())
|
||||
|
||||
|
||||
@registry.tokenizers("spacy.zh.ChineseTokenizer")
|
||||
def create_chinese_tokenizer(segmenter: Segmenter = Segmenter.char):
|
||||
def chinese_tokenizer_factory(nlp):
|
||||
return ChineseTokenizer(nlp.vocab, segmenter=segmenter)
|
||||
|
|
|
@ -104,7 +104,6 @@ class BaseDefaults:
|
|||
writing_system = {"direction": "ltr", "has_case": True, "has_letters": True}
|
||||
|
||||
|
||||
@registry.tokenizers("spacy.Tokenizer.v1")
|
||||
def create_tokenizer() -> Callable[["Language"], Tokenizer]:
|
||||
"""Registered function to create a tokenizer. Returns a factory that takes
|
||||
the nlp object and returns a Tokenizer instance using the language detaults.
|
||||
|
@ -130,7 +129,6 @@ def create_tokenizer() -> Callable[["Language"], Tokenizer]:
|
|||
return tokenizer_factory
|
||||
|
||||
|
||||
@registry.misc("spacy.LookupsDataLoader.v1")
|
||||
def load_lookups_data(lang, tables):
|
||||
util.logger.debug("Loading lookups from spacy-lookups-data: %s", tables)
|
||||
lookups = load_lookups(lang=lang, tables=tables)
|
||||
|
@ -185,6 +183,9 @@ class Language:
|
|||
|
||||
DOCS: https://spacy.io/api/language#init
|
||||
"""
|
||||
from .pipeline.factories import register_factories
|
||||
|
||||
register_factories()
|
||||
# We're only calling this to import all factories provided via entry
|
||||
# points. The factory decorator applied to these functions takes care
|
||||
# of the rest.
|
||||
|
|
|
@ -35,7 +35,7 @@ cdef class Lexeme:
|
|||
return self
|
||||
|
||||
@staticmethod
|
||||
cdef inline void set_struct_attr(LexemeC* lex, attr_id_t name, attr_t value) nogil:
|
||||
cdef inline void set_struct_attr(LexemeC* lex, attr_id_t name, attr_t value) noexcept nogil:
|
||||
if name < (sizeof(flags_t) * 8):
|
||||
Lexeme.c_set_flag(lex, name, value)
|
||||
elif name == ID:
|
||||
|
@ -54,7 +54,7 @@ cdef class Lexeme:
|
|||
lex.lang = value
|
||||
|
||||
@staticmethod
|
||||
cdef inline attr_t get_struct_attr(const LexemeC* lex, attr_id_t feat_name) nogil:
|
||||
cdef inline attr_t get_struct_attr(const LexemeC* lex, attr_id_t feat_name) noexcept nogil:
|
||||
if feat_name < (sizeof(flags_t) * 8):
|
||||
if Lexeme.c_check_flag(lex, feat_name):
|
||||
return 1
|
||||
|
@ -82,7 +82,7 @@ cdef class Lexeme:
|
|||
return 0
|
||||
|
||||
@staticmethod
|
||||
cdef inline bint c_check_flag(const LexemeC* lexeme, attr_id_t flag_id) nogil:
|
||||
cdef inline bint c_check_flag(const LexemeC* lexeme, attr_id_t flag_id) noexcept nogil:
|
||||
cdef flags_t one = 1
|
||||
if lexeme.flags & (one << flag_id):
|
||||
return True
|
||||
|
@ -90,7 +90,7 @@ cdef class Lexeme:
|
|||
return False
|
||||
|
||||
@staticmethod
|
||||
cdef inline bint c_set_flag(LexemeC* lex, attr_id_t flag_id, bint value) nogil:
|
||||
cdef inline bint c_set_flag(LexemeC* lex, attr_id_t flag_id, bint value) noexcept nogil:
|
||||
cdef flags_t one = 1
|
||||
if value:
|
||||
lex.flags |= one << flag_id
|
||||
|
|
|
@ -70,7 +70,7 @@ cdef class Lexeme:
|
|||
if isinstance(other, Lexeme):
|
||||
a = self.orth
|
||||
b = other.orth
|
||||
elif isinstance(other, long):
|
||||
elif isinstance(other, int):
|
||||
a = self.orth
|
||||
b = other
|
||||
elif isinstance(other, str):
|
||||
|
@ -104,7 +104,7 @@ cdef class Lexeme:
|
|||
# skip PROB, e.g. from lexemes.jsonl
|
||||
if isinstance(value, float):
|
||||
continue
|
||||
elif isinstance(value, (int, long)):
|
||||
elif isinstance(value, int):
|
||||
Lexeme.set_struct_attr(self.c, attr, value)
|
||||
else:
|
||||
Lexeme.set_struct_attr(self.c, attr, self.vocab.strings.add(value))
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# cython: binding=True, infer_types=True
|
||||
# cython: binding=True, infer_types=True, language_level=3
|
||||
from cpython.object cimport PyObject
|
||||
from libc.stdint cimport int64_t
|
||||
|
||||
|
@ -27,6 +27,5 @@ cpdef bint levenshtein_compare(input_text: str, pattern_text: str, fuzzy: int =
|
|||
return levenshtein(input_text, pattern_text, max_edits) <= max_edits
|
||||
|
||||
|
||||
@registry.misc("spacy.levenshtein_compare.v1")
|
||||
def make_levenshtein_compare():
|
||||
return levenshtein_compare
|
||||
|
|
|
@ -625,7 +625,7 @@ cdef action_t get_action(
|
|||
const TokenC * token,
|
||||
const attr_t * extra_attrs,
|
||||
const int8_t * predicate_matches
|
||||
) nogil:
|
||||
) noexcept nogil:
|
||||
"""We need to consider:
|
||||
a) Does the token match the specification? [Yes, No]
|
||||
b) What's the quantifier? [1, 0+, ?]
|
||||
|
@ -740,7 +740,7 @@ cdef int8_t get_is_match(
|
|||
const TokenC* token,
|
||||
const attr_t* extra_attrs,
|
||||
const int8_t* predicate_matches
|
||||
) nogil:
|
||||
) noexcept nogil:
|
||||
for i in range(state.pattern.nr_py):
|
||||
if predicate_matches[state.pattern.py_predicates[i]] == -1:
|
||||
return 0
|
||||
|
@ -755,14 +755,14 @@ cdef int8_t get_is_match(
|
|||
return True
|
||||
|
||||
|
||||
cdef inline int8_t get_is_final(PatternStateC state) nogil:
|
||||
cdef inline int8_t get_is_final(PatternStateC state) noexcept nogil:
|
||||
if state.pattern[1].quantifier == FINAL_ID:
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
cdef inline int8_t get_quantifier(PatternStateC state) nogil:
|
||||
cdef inline int8_t get_quantifier(PatternStateC state) noexcept nogil:
|
||||
return state.pattern.quantifier
|
||||
|
||||
|
||||
|
@ -805,7 +805,7 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs)
|
|||
return pattern
|
||||
|
||||
|
||||
cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil:
|
||||
cdef attr_t get_ent_id(const TokenPatternC* pattern) noexcept nogil:
|
||||
while pattern.quantifier != FINAL_ID:
|
||||
pattern += 1
|
||||
id_attr = pattern[0].attrs[0]
|
||||
|
|
|
@ -47,7 +47,7 @@ cdef class PhraseMatcher:
|
|||
self._terminal_hash = 826361138722620965
|
||||
map_init(self.mem, self.c_map, 8)
|
||||
|
||||
if isinstance(attr, (int, long)):
|
||||
if isinstance(attr, int):
|
||||
self.attr = attr
|
||||
else:
|
||||
if attr is None:
|
||||
|
|
|
@ -7,7 +7,6 @@ from ..tokens import Doc
|
|||
from ..util import registry
|
||||
|
||||
|
||||
@registry.layers("spacy.CharEmbed.v1")
|
||||
def CharacterEmbed(nM: int, nC: int) -> Model[List[Doc], List[Floats2d]]:
|
||||
# nM: Number of dimensions per character. nC: Number of characters.
|
||||
return Model(
|
||||
|
|
|
@ -3,7 +3,6 @@ from thinc.api import Model, normal_init
|
|||
from ..util import registry
|
||||
|
||||
|
||||
@registry.layers("spacy.PrecomputableAffine.v1")
|
||||
def PrecomputableAffine(nO, nI, nF, nP, dropout=0.1):
|
||||
model = Model(
|
||||
"precomputable_affine",
|
||||
|
|
|
@ -50,7 +50,6 @@ def models_with_nvtx_range(nlp, forward_color: int, backprop_color: int):
|
|||
return nlp
|
||||
|
||||
|
||||
@registry.callbacks("spacy.models_with_nvtx_range.v1")
|
||||
def create_models_with_nvtx_range(
|
||||
forward_color: int = -1, backprop_color: int = -1
|
||||
) -> Callable[["Language"], "Language"]:
|
||||
|
@ -110,7 +109,6 @@ def pipes_with_nvtx_range(
|
|||
return nlp
|
||||
|
||||
|
||||
@registry.callbacks("spacy.models_and_pipes_with_nvtx_range.v1")
|
||||
def create_models_and_pipes_with_nvtx_range(
|
||||
forward_color: int = -1,
|
||||
backprop_color: int = -1,
|
||||
|
|
|
@ -4,7 +4,6 @@ from ..attrs import LOWER
|
|||
from ..util import registry
|
||||
|
||||
|
||||
@registry.layers("spacy.extract_ngrams.v1")
|
||||
def extract_ngrams(ngram_size: int, attr: int = LOWER) -> Model:
|
||||
model: Model = Model("extract_ngrams", forward)
|
||||
model.attrs["ngram_size"] = ngram_size
|
||||
|
|
|
@ -6,7 +6,6 @@ from thinc.types import Ints1d, Ragged
|
|||
from ..util import registry
|
||||
|
||||
|
||||
@registry.layers("spacy.extract_spans.v1")
|
||||
def extract_spans() -> Model[Tuple[Ragged, Ragged], Ragged]:
|
||||
"""Extract spans from a sequence of source arrays, as specified by an array
|
||||
of (start, end) indices. The output is a ragged array of the
|
||||
|
|
|
@ -6,8 +6,9 @@ from thinc.types import Ints2d
|
|||
from ..tokens import Doc
|
||||
|
||||
|
||||
@registry.layers("spacy.FeatureExtractor.v1")
|
||||
def FeatureExtractor(columns: List[Union[int, str]]) -> Model[List[Doc], List[Ints2d]]:
|
||||
def FeatureExtractor(
|
||||
columns: Union[List[str], List[int], List[Union[int, str]]]
|
||||
) -> Model[List[Doc], List[Ints2d]]:
|
||||
return Model("extract_features", forward, attrs={"columns": columns})
|
||||
|
||||
|
||||
|
|
|
@ -28,7 +28,6 @@ from ...vocab import Vocab
|
|||
from ..extract_spans import extract_spans
|
||||
|
||||
|
||||
@registry.architectures("spacy.EntityLinker.v2")
|
||||
def build_nel_encoder(
|
||||
tok2vec: Model, nO: Optional[int] = None
|
||||
) -> Model[List[Doc], Floats2d]:
|
||||
|
@ -92,7 +91,6 @@ def span_maker_forward(model, docs: List[Doc], is_train) -> Tuple[Ragged, Callab
|
|||
return out, lambda x: []
|
||||
|
||||
|
||||
@registry.misc("spacy.KBFromFile.v1")
|
||||
def load_kb(
|
||||
kb_path: Path,
|
||||
) -> Callable[[Vocab], KnowledgeBase]:
|
||||
|
@ -104,7 +102,6 @@ def load_kb(
|
|||
return kb_from_file
|
||||
|
||||
|
||||
@registry.misc("spacy.EmptyKB.v2")
|
||||
def empty_kb_for_config() -> Callable[[Vocab, int], KnowledgeBase]:
|
||||
def empty_kb_factory(vocab: Vocab, entity_vector_length: int):
|
||||
return InMemoryLookupKB(vocab=vocab, entity_vector_length=entity_vector_length)
|
||||
|
@ -112,7 +109,6 @@ def empty_kb_for_config() -> Callable[[Vocab, int], KnowledgeBase]:
|
|||
return empty_kb_factory
|
||||
|
||||
|
||||
@registry.misc("spacy.EmptyKB.v1")
|
||||
def empty_kb(
|
||||
entity_vector_length: int,
|
||||
) -> Callable[[Vocab], KnowledgeBase]:
|
||||
|
@ -122,12 +118,10 @@ def empty_kb(
|
|||
return empty_kb_factory
|
||||
|
||||
|
||||
@registry.misc("spacy.CandidateGenerator.v1")
|
||||
def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]:
|
||||
return get_candidates
|
||||
|
||||
|
||||
@registry.misc("spacy.CandidateBatchGenerator.v1")
|
||||
def create_candidates_batch() -> Callable[
|
||||
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
|
||||
]:
|
||||
|
|
|
@ -30,7 +30,6 @@ if TYPE_CHECKING:
|
|||
from ...vocab import Vocab # noqa: F401
|
||||
|
||||
|
||||
@registry.architectures("spacy.PretrainVectors.v1")
|
||||
def create_pretrain_vectors(
|
||||
maxout_pieces: int, hidden_size: int, loss: str
|
||||
) -> Callable[["Vocab", Model], Model]:
|
||||
|
@ -57,7 +56,6 @@ def create_pretrain_vectors(
|
|||
return create_vectors_objective
|
||||
|
||||
|
||||
@registry.architectures("spacy.PretrainCharacters.v1")
|
||||
def create_pretrain_characters(
|
||||
maxout_pieces: int, hidden_size: int, n_characters: int
|
||||
) -> Callable[["Vocab", Model], Model]:
|
||||
|
|
|
@ -11,7 +11,6 @@ from .._precomputable_affine import PrecomputableAffine
|
|||
from ..tb_framework import TransitionModel
|
||||
|
||||
|
||||
@registry.architectures("spacy.TransitionBasedParser.v2")
|
||||
def build_tb_parser_model(
|
||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||
state_type: Literal["parser", "ner"],
|
||||
|
|
|
@ -10,7 +10,6 @@ InT = List[Doc]
|
|||
OutT = Floats2d
|
||||
|
||||
|
||||
@registry.architectures("spacy.SpanFinder.v1")
|
||||
def build_finder_model(
|
||||
tok2vec: Model[InT, List[Floats2d]], scorer: Model[OutT, OutT]
|
||||
) -> Model[InT, OutT]:
|
||||
|
|
|
@ -22,7 +22,6 @@ from ...util import registry
|
|||
from ..extract_spans import extract_spans
|
||||
|
||||
|
||||
@registry.layers("spacy.LinearLogistic.v1")
|
||||
def build_linear_logistic(nO=None, nI=None) -> Model[Floats2d, Floats2d]:
|
||||
"""An output layer for multi-label classification. It uses a linear layer
|
||||
followed by a logistic activation.
|
||||
|
@ -30,7 +29,6 @@ def build_linear_logistic(nO=None, nI=None) -> Model[Floats2d, Floats2d]:
|
|||
return chain(Linear(nO=nO, nI=nI, init_W=glorot_uniform_init), Logistic())
|
||||
|
||||
|
||||
@registry.layers("spacy.mean_max_reducer.v1")
|
||||
def build_mean_max_reducer(hidden_size: int) -> Model[Ragged, Floats2d]:
|
||||
"""Reduce sequences by concatenating their mean and max pooled vectors,
|
||||
and then combine the concatenated vectors with a hidden layer.
|
||||
|
@ -46,7 +44,6 @@ def build_mean_max_reducer(hidden_size: int) -> Model[Ragged, Floats2d]:
|
|||
)
|
||||
|
||||
|
||||
@registry.architectures("spacy.SpanCategorizer.v1")
|
||||
def build_spancat_model(
|
||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||
reducer: Model[Ragged, Floats2d],
|
||||
|
|
|
@ -7,7 +7,6 @@ from ...tokens import Doc
|
|||
from ...util import registry
|
||||
|
||||
|
||||
@registry.architectures("spacy.Tagger.v2")
|
||||
def build_tagger_model(
|
||||
tok2vec: Model[List[Doc], List[Floats2d]], nO: Optional[int] = None, normalize=False
|
||||
) -> Model[List[Doc], List[Floats2d]]:
|
||||
|
|
|
@ -44,7 +44,6 @@ from .tok2vec import get_tok2vec_width
|
|||
NEG_VALUE = -5000
|
||||
|
||||
|
||||
@registry.architectures("spacy.TextCatCNN.v2")
|
||||
def build_simple_cnn_text_classifier(
|
||||
tok2vec: Model, exclusive_classes: bool, nO: Optional[int] = None
|
||||
) -> Model[List[Doc], Floats2d]:
|
||||
|
@ -72,7 +71,6 @@ def resize_and_set_ref(model, new_nO, resizable_layer):
|
|||
return model
|
||||
|
||||
|
||||
@registry.architectures("spacy.TextCatBOW.v2")
|
||||
def build_bow_text_classifier(
|
||||
exclusive_classes: bool,
|
||||
ngram_size: int,
|
||||
|
@ -88,7 +86,6 @@ def build_bow_text_classifier(
|
|||
)
|
||||
|
||||
|
||||
@registry.architectures("spacy.TextCatBOW.v3")
|
||||
def build_bow_text_classifier_v3(
|
||||
exclusive_classes: bool,
|
||||
ngram_size: int,
|
||||
|
@ -142,7 +139,6 @@ def _build_bow_text_classifier(
|
|||
return model
|
||||
|
||||
|
||||
@registry.architectures("spacy.TextCatEnsemble.v2")
|
||||
def build_text_classifier_v2(
|
||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||
linear_model: Model[List[Doc], Floats2d],
|
||||
|
@ -200,7 +196,6 @@ def init_ensemble_textcat(model, X, Y) -> Model:
|
|||
return model
|
||||
|
||||
|
||||
@registry.architectures("spacy.TextCatLowData.v1")
|
||||
def build_text_classifier_lowdata(
|
||||
width: int, dropout: Optional[float], nO: Optional[int] = None
|
||||
) -> Model[List[Doc], Floats2d]:
|
||||
|
@ -221,7 +216,6 @@ def build_text_classifier_lowdata(
|
|||
return model
|
||||
|
||||
|
||||
@registry.architectures("spacy.TextCatParametricAttention.v1")
|
||||
def build_textcat_parametric_attention_v1(
|
||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||
exclusive_classes: bool,
|
||||
|
@ -294,7 +288,6 @@ def _init_parametric_attention_with_residual_nonlinear(model, X, Y) -> Model:
|
|||
return model
|
||||
|
||||
|
||||
@registry.architectures("spacy.TextCatReduce.v1")
|
||||
def build_reduce_text_classifier(
|
||||
tok2vec: Model,
|
||||
exclusive_classes: bool,
|
||||
|
|
|
@ -29,7 +29,6 @@ from ..featureextractor import FeatureExtractor
|
|||
from ..staticvectors import StaticVectors
|
||||
|
||||
|
||||
@registry.architectures("spacy.Tok2VecListener.v1")
|
||||
def tok2vec_listener_v1(width: int, upstream: str = "*"):
|
||||
tok2vec = Tok2VecListener(upstream_name=upstream, width=width)
|
||||
return tok2vec
|
||||
|
@ -46,7 +45,6 @@ def get_tok2vec_width(model: Model):
|
|||
return nO
|
||||
|
||||
|
||||
@registry.architectures("spacy.HashEmbedCNN.v2")
|
||||
def build_hash_embed_cnn_tok2vec(
|
||||
*,
|
||||
width: int,
|
||||
|
@ -102,7 +100,6 @@ def build_hash_embed_cnn_tok2vec(
|
|||
)
|
||||
|
||||
|
||||
@registry.architectures("spacy.Tok2Vec.v2")
|
||||
def build_Tok2Vec_model(
|
||||
embed: Model[List[Doc], List[Floats2d]],
|
||||
encode: Model[List[Floats2d], List[Floats2d]],
|
||||
|
@ -123,10 +120,9 @@ def build_Tok2Vec_model(
|
|||
return tok2vec
|
||||
|
||||
|
||||
@registry.architectures("spacy.MultiHashEmbed.v2")
|
||||
def MultiHashEmbed(
|
||||
width: int,
|
||||
attrs: List[Union[str, int]],
|
||||
attrs: Union[List[str], List[int], List[Union[str, int]]],
|
||||
rows: List[int],
|
||||
include_static_vectors: bool,
|
||||
) -> Model[List[Doc], List[Floats2d]]:
|
||||
|
@ -192,7 +188,7 @@ def MultiHashEmbed(
|
|||
)
|
||||
else:
|
||||
model = chain(
|
||||
FeatureExtractor(list(attrs)),
|
||||
FeatureExtractor(attrs),
|
||||
cast(Model[List[Ints2d], Ragged], list2ragged()),
|
||||
with_array(concatenate(*embeddings)),
|
||||
max_out,
|
||||
|
@ -201,7 +197,6 @@ def MultiHashEmbed(
|
|||
return model
|
||||
|
||||
|
||||
@registry.architectures("spacy.CharacterEmbed.v2")
|
||||
def CharacterEmbed(
|
||||
width: int,
|
||||
rows: int,
|
||||
|
@ -278,7 +273,6 @@ def CharacterEmbed(
|
|||
return model
|
||||
|
||||
|
||||
@registry.architectures("spacy.MaxoutWindowEncoder.v2")
|
||||
def MaxoutWindowEncoder(
|
||||
width: int, window_size: int, maxout_pieces: int, depth: int
|
||||
) -> Model[List[Floats2d], List[Floats2d]]:
|
||||
|
@ -310,7 +304,6 @@ def MaxoutWindowEncoder(
|
|||
return with_array(model, pad=receptive_field)
|
||||
|
||||
|
||||
@registry.architectures("spacy.MishWindowEncoder.v2")
|
||||
def MishWindowEncoder(
|
||||
width: int, window_size: int, depth: int
|
||||
) -> Model[List[Floats2d], List[Floats2d]]:
|
||||
|
@ -333,7 +326,6 @@ def MishWindowEncoder(
|
|||
return with_array(model)
|
||||
|
||||
|
||||
@registry.architectures("spacy.TorchBiLSTMEncoder.v1")
|
||||
def BiLSTMEncoder(
|
||||
width: int, depth: int, dropout: float
|
||||
) -> Model[List[Floats2d], List[Floats2d]]:
|
||||
|
|
|
@ -52,14 +52,14 @@ cdef SizesC get_c_sizes(model, int batch_size) except *:
|
|||
return output
|
||||
|
||||
|
||||
cdef ActivationsC alloc_activations(SizesC n) nogil:
|
||||
cdef ActivationsC alloc_activations(SizesC n) noexcept nogil:
|
||||
cdef ActivationsC A
|
||||
memset(&A, 0, sizeof(A))
|
||||
resize_activations(&A, n)
|
||||
return A
|
||||
|
||||
|
||||
cdef void free_activations(const ActivationsC* A) nogil:
|
||||
cdef void free_activations(const ActivationsC* A) noexcept nogil:
|
||||
free(A.token_ids)
|
||||
free(A.scores)
|
||||
free(A.unmaxed)
|
||||
|
@ -67,7 +67,7 @@ cdef void free_activations(const ActivationsC* A) nogil:
|
|||
free(A.is_valid)
|
||||
|
||||
|
||||
cdef void resize_activations(ActivationsC* A, SizesC n) nogil:
|
||||
cdef void resize_activations(ActivationsC* A, SizesC n) noexcept nogil:
|
||||
if n.states <= A._max_size:
|
||||
A._curr_size = n.states
|
||||
return
|
||||
|
@ -100,7 +100,7 @@ cdef void resize_activations(ActivationsC* A, SizesC n) nogil:
|
|||
|
||||
cdef void predict_states(
|
||||
CBlas cblas, ActivationsC* A, StateC** states, const WeightsC* W, SizesC n
|
||||
) nogil:
|
||||
) noexcept nogil:
|
||||
resize_activations(A, n)
|
||||
for i in range(n.states):
|
||||
states[i].set_context_tokens(&A.token_ids[i*n.feats], n.feats)
|
||||
|
@ -159,7 +159,7 @@ cdef void sum_state_features(
|
|||
int B,
|
||||
int F,
|
||||
int O
|
||||
) nogil:
|
||||
) noexcept nogil:
|
||||
cdef int idx, b, f
|
||||
cdef const float* feature
|
||||
padding = cached
|
||||
|
@ -183,7 +183,7 @@ cdef void cpu_log_loss(
|
|||
const int* is_valid,
|
||||
const float* scores,
|
||||
int O
|
||||
) nogil:
|
||||
) noexcept nogil:
|
||||
"""Do multi-label log loss"""
|
||||
cdef double max_, gmax, Z, gZ
|
||||
best = arg_max_if_gold(scores, costs, is_valid, O)
|
||||
|
@ -209,7 +209,7 @@ cdef void cpu_log_loss(
|
|||
|
||||
cdef int arg_max_if_gold(
|
||||
const weight_t* scores, const weight_t* costs, const int* is_valid, int n
|
||||
) nogil:
|
||||
) noexcept nogil:
|
||||
# Find minimum cost
|
||||
cdef float cost = 1
|
||||
for i in range(n):
|
||||
|
@ -224,7 +224,7 @@ cdef int arg_max_if_gold(
|
|||
return best
|
||||
|
||||
|
||||
cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil:
|
||||
cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) noexcept nogil:
|
||||
cdef int best = -1
|
||||
for i in range(n):
|
||||
if is_valid[i] >= 1:
|
||||
|
|
|
@ -13,7 +13,6 @@ from ..vectors import Mode, Vectors
|
|||
from ..vocab import Vocab
|
||||
|
||||
|
||||
@registry.layers("spacy.StaticVectors.v2")
|
||||
def StaticVectors(
|
||||
nO: Optional[int] = None,
|
||||
nM: Optional[int] = None,
|
||||
|
|
|
@ -4,7 +4,6 @@ from ..util import registry
|
|||
from .parser_model import ParserStepModel
|
||||
|
||||
|
||||
@registry.layers("spacy.TransitionModel.v1")
|
||||
def TransitionModel(
|
||||
tok2vec, lower, upper, resize_output, dropout=0.2, unseen_classes=set()
|
||||
):
|
||||
|
|
|
@ -25,3 +25,8 @@ IDS = {
|
|||
|
||||
|
||||
NAMES = {value: key for key, value in IDS.items()}
|
||||
|
||||
# As of Cython 3.1, the global Python namespace no longer has the enum
|
||||
# contents by default.
|
||||
globals().update(IDS)
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ from ...typedefs cimport attr_t
|
|||
from ...vocab cimport EMPTY_LEXEME
|
||||
|
||||
|
||||
cdef inline bint is_space_token(const TokenC* token) nogil:
|
||||
cdef inline bint is_space_token(const TokenC* token) noexcept nogil:
|
||||
return Lexeme.c_check_flag(token.lex, IS_SPACE)
|
||||
|
||||
cdef struct ArcC:
|
||||
|
@ -41,7 +41,7 @@ cdef cppclass StateC:
|
|||
int offset
|
||||
int _b_i
|
||||
|
||||
__init__(const TokenC* sent, int length) nogil:
|
||||
inline __init__(const TokenC* sent, int length) noexcept nogil:
|
||||
this._sent = sent
|
||||
this._heads = <int*>calloc(length, sizeof(int))
|
||||
if not (this._sent and this._heads):
|
||||
|
@ -57,10 +57,10 @@ cdef cppclass StateC:
|
|||
memset(&this._empty_token, 0, sizeof(TokenC))
|
||||
this._empty_token.lex = &EMPTY_LEXEME
|
||||
|
||||
__dealloc__():
|
||||
inline __dealloc__():
|
||||
free(this._heads)
|
||||
|
||||
void set_context_tokens(int* ids, int n) nogil:
|
||||
inline void set_context_tokens(int* ids, int n) noexcept nogil:
|
||||
cdef int i, j
|
||||
if n == 1:
|
||||
if this.B(0) >= 0:
|
||||
|
@ -131,14 +131,14 @@ cdef cppclass StateC:
|
|||
else:
|
||||
ids[i] = -1
|
||||
|
||||
int S(int i) nogil const:
|
||||
inline int S(int i) noexcept nogil const:
|
||||
if i >= this._stack.size():
|
||||
return -1
|
||||
elif i < 0:
|
||||
return -1
|
||||
return this._stack.at(this._stack.size() - (i+1))
|
||||
|
||||
int B(int i) nogil const:
|
||||
inline int B(int i) noexcept nogil const:
|
||||
if i < 0:
|
||||
return -1
|
||||
elif i < this._rebuffer.size():
|
||||
|
@ -150,19 +150,19 @@ cdef cppclass StateC:
|
|||
else:
|
||||
return b_i
|
||||
|
||||
const TokenC* B_(int i) nogil const:
|
||||
inline const TokenC* B_(int i) noexcept nogil const:
|
||||
return this.safe_get(this.B(i))
|
||||
|
||||
const TokenC* E_(int i) nogil const:
|
||||
inline const TokenC* E_(int i) noexcept nogil const:
|
||||
return this.safe_get(this.E(i))
|
||||
|
||||
const TokenC* safe_get(int i) nogil const:
|
||||
inline const TokenC* safe_get(int i) noexcept nogil const:
|
||||
if i < 0 or i >= this.length:
|
||||
return &this._empty_token
|
||||
else:
|
||||
return &this._sent[i]
|
||||
|
||||
void map_get_arcs(const unordered_map[int, vector[ArcC]] &heads_arcs, vector[ArcC]* out) nogil const:
|
||||
inline void map_get_arcs(const unordered_map[int, vector[ArcC]] &heads_arcs, vector[ArcC]* out) noexcept nogil const:
|
||||
cdef const vector[ArcC]* arcs
|
||||
head_arcs_it = heads_arcs.const_begin()
|
||||
while head_arcs_it != heads_arcs.const_end():
|
||||
|
@ -175,23 +175,23 @@ cdef cppclass StateC:
|
|||
incr(arcs_it)
|
||||
incr(head_arcs_it)
|
||||
|
||||
void get_arcs(vector[ArcC]* out) nogil const:
|
||||
inline void get_arcs(vector[ArcC]* out) noexcept nogil const:
|
||||
this.map_get_arcs(this._left_arcs, out)
|
||||
this.map_get_arcs(this._right_arcs, out)
|
||||
|
||||
int H(int child) nogil const:
|
||||
inline int H(int child) noexcept nogil const:
|
||||
if child >= this.length or child < 0:
|
||||
return -1
|
||||
else:
|
||||
return this._heads[child]
|
||||
|
||||
int E(int i) nogil const:
|
||||
inline int E(int i) noexcept nogil const:
|
||||
if this._ents.size() == 0:
|
||||
return -1
|
||||
else:
|
||||
return this._ents.back().start
|
||||
|
||||
int nth_child(const unordered_map[int, vector[ArcC]]& heads_arcs, int head, int idx) nogil const:
|
||||
inline int nth_child(const unordered_map[int, vector[ArcC]]& heads_arcs, int head, int idx) noexcept nogil const:
|
||||
if idx < 1:
|
||||
return -1
|
||||
|
||||
|
@ -215,22 +215,22 @@ cdef cppclass StateC:
|
|||
|
||||
return -1
|
||||
|
||||
int L(int head, int idx) nogil const:
|
||||
inline int L(int head, int idx) noexcept nogil const:
|
||||
return this.nth_child(this._left_arcs, head, idx)
|
||||
|
||||
int R(int head, int idx) nogil const:
|
||||
inline int R(int head, int idx) noexcept nogil const:
|
||||
return this.nth_child(this._right_arcs, head, idx)
|
||||
|
||||
bint empty() nogil const:
|
||||
inline bint empty() noexcept nogil const:
|
||||
return this._stack.size() == 0
|
||||
|
||||
bint eol() nogil const:
|
||||
inline bint eol() noexcept nogil const:
|
||||
return this.buffer_length() == 0
|
||||
|
||||
bint is_final() nogil const:
|
||||
inline bint is_final() noexcept nogil const:
|
||||
return this.stack_depth() <= 0 and this.eol()
|
||||
|
||||
int cannot_sent_start(int word) nogil const:
|
||||
inline int cannot_sent_start(int word) noexcept nogil const:
|
||||
if word < 0 or word >= this.length:
|
||||
return 0
|
||||
elif this._sent[word].sent_start == -1:
|
||||
|
@ -238,7 +238,7 @@ cdef cppclass StateC:
|
|||
else:
|
||||
return 0
|
||||
|
||||
int is_sent_start(int word) nogil const:
|
||||
inline int is_sent_start(int word) noexcept nogil const:
|
||||
if word < 0 or word >= this.length:
|
||||
return 0
|
||||
elif this._sent[word].sent_start == 1:
|
||||
|
@ -248,20 +248,20 @@ cdef cppclass StateC:
|
|||
else:
|
||||
return 0
|
||||
|
||||
void set_sent_start(int word, int value) nogil:
|
||||
inline void set_sent_start(int word, int value) noexcept nogil:
|
||||
if value >= 1:
|
||||
this._sent_starts.insert(word)
|
||||
|
||||
bint has_head(int child) nogil const:
|
||||
inline bint has_head(int child) noexcept nogil const:
|
||||
return this._heads[child] >= 0
|
||||
|
||||
int l_edge(int word) nogil const:
|
||||
inline int l_edge(int word) noexcept nogil const:
|
||||
return word
|
||||
|
||||
int r_edge(int word) nogil const:
|
||||
inline int r_edge(int word) noexcept nogil const:
|
||||
return word
|
||||
|
||||
int n_arcs(const unordered_map[int, vector[ArcC]] &heads_arcs, int head) nogil const:
|
||||
inline int n_arcs(const unordered_map[int, vector[ArcC]] &heads_arcs, int head) noexcept nogil const:
|
||||
cdef int n = 0
|
||||
head_arcs_it = heads_arcs.const_find(head)
|
||||
if head_arcs_it == heads_arcs.const_end():
|
||||
|
@ -277,28 +277,28 @@ cdef cppclass StateC:
|
|||
|
||||
return n
|
||||
|
||||
int n_L(int head) nogil const:
|
||||
inline int n_L(int head) noexcept nogil const:
|
||||
return n_arcs(this._left_arcs, head)
|
||||
|
||||
int n_R(int head) nogil const:
|
||||
inline int n_R(int head) noexcept nogil const:
|
||||
return n_arcs(this._right_arcs, head)
|
||||
|
||||
bint stack_is_connected() nogil const:
|
||||
inline bint stack_is_connected() noexcept nogil const:
|
||||
return False
|
||||
|
||||
bint entity_is_open() nogil const:
|
||||
inline bint entity_is_open() noexcept nogil const:
|
||||
if this._ents.size() == 0:
|
||||
return False
|
||||
else:
|
||||
return this._ents.back().end == -1
|
||||
|
||||
int stack_depth() nogil const:
|
||||
inline int stack_depth() noexcept nogil const:
|
||||
return this._stack.size()
|
||||
|
||||
int buffer_length() nogil const:
|
||||
inline int buffer_length() noexcept nogil const:
|
||||
return (this.length - this._b_i) + this._rebuffer.size()
|
||||
|
||||
void push() nogil:
|
||||
inline void push() noexcept nogil:
|
||||
b0 = this.B(0)
|
||||
if this._rebuffer.size():
|
||||
b0 = this._rebuffer.back()
|
||||
|
@ -308,32 +308,32 @@ cdef cppclass StateC:
|
|||
this._b_i += 1
|
||||
this._stack.push_back(b0)
|
||||
|
||||
void pop() nogil:
|
||||
inline void pop() noexcept nogil:
|
||||
this._stack.pop_back()
|
||||
|
||||
void force_final() nogil:
|
||||
inline void force_final() noexcept nogil:
|
||||
# This should only be used in desperate situations, as it may leave
|
||||
# the analysis in an unexpected state.
|
||||
this._stack.clear()
|
||||
this._b_i = this.length
|
||||
|
||||
void unshift() nogil:
|
||||
inline void unshift() noexcept nogil:
|
||||
s0 = this._stack.back()
|
||||
this._unshiftable[s0] = 1
|
||||
this._rebuffer.push_back(s0)
|
||||
this._stack.pop_back()
|
||||
|
||||
int is_unshiftable(int item) nogil const:
|
||||
inline int is_unshiftable(int item) noexcept nogil const:
|
||||
if item >= this._unshiftable.size():
|
||||
return 0
|
||||
else:
|
||||
return this._unshiftable.at(item)
|
||||
|
||||
void set_reshiftable(int item) nogil:
|
||||
inline void set_reshiftable(int item) noexcept nogil:
|
||||
if item < this._unshiftable.size():
|
||||
this._unshiftable[item] = 0
|
||||
|
||||
void add_arc(int head, int child, attr_t label) nogil:
|
||||
inline void add_arc(int head, int child, attr_t label) noexcept nogil:
|
||||
if this.has_head(child):
|
||||
this.del_arc(this.H(child), child)
|
||||
cdef ArcC arc
|
||||
|
@ -346,7 +346,7 @@ cdef cppclass StateC:
|
|||
this._right_arcs[arc.head].push_back(arc)
|
||||
this._heads[child] = head
|
||||
|
||||
void map_del_arc(unordered_map[int, vector[ArcC]]* heads_arcs, int h_i, int c_i) nogil:
|
||||
inline void map_del_arc(unordered_map[int, vector[ArcC]]* heads_arcs, int h_i, int c_i) noexcept nogil:
|
||||
arcs_it = heads_arcs.find(h_i)
|
||||
if arcs_it == heads_arcs.end():
|
||||
return
|
||||
|
@ -367,13 +367,13 @@ cdef cppclass StateC:
|
|||
arc.label = 0
|
||||
break
|
||||
|
||||
void del_arc(int h_i, int c_i) nogil:
|
||||
inline void del_arc(int h_i, int c_i) noexcept nogil:
|
||||
if h_i > c_i:
|
||||
this.map_del_arc(&this._left_arcs, h_i, c_i)
|
||||
else:
|
||||
this.map_del_arc(&this._right_arcs, h_i, c_i)
|
||||
|
||||
SpanC get_ent() nogil const:
|
||||
inline SpanC get_ent() noexcept nogil const:
|
||||
cdef SpanC ent
|
||||
if this._ents.size() == 0:
|
||||
ent.start = 0
|
||||
|
@ -383,17 +383,17 @@ cdef cppclass StateC:
|
|||
else:
|
||||
return this._ents.back()
|
||||
|
||||
void open_ent(attr_t label) nogil:
|
||||
inline void open_ent(attr_t label) noexcept nogil:
|
||||
cdef SpanC ent
|
||||
ent.start = this.B(0)
|
||||
ent.label = label
|
||||
ent.end = -1
|
||||
this._ents.push_back(ent)
|
||||
|
||||
void close_ent() nogil:
|
||||
inline void close_ent() noexcept nogil:
|
||||
this._ents.back().end = this.B(0)+1
|
||||
|
||||
void clone(const StateC* src) nogil:
|
||||
inline void clone(const StateC* src) noexcept nogil:
|
||||
this.length = src.length
|
||||
this._sent = src._sent
|
||||
this._stack = src._stack
|
||||
|
|
|
@ -155,7 +155,7 @@ cdef GoldParseStateC create_gold_state(
|
|||
return gs
|
||||
|
||||
|
||||
cdef void update_gold_state(GoldParseStateC* gs, const StateC* s) nogil:
|
||||
cdef void update_gold_state(GoldParseStateC* gs, const StateC* s) noexcept nogil:
|
||||
for i in range(gs.length):
|
||||
gs.state_bits[i] = set_state_flag(
|
||||
gs.state_bits[i],
|
||||
|
@ -239,12 +239,12 @@ def _get_aligned_sent_starts(example):
|
|||
return [None] * len(example.x)
|
||||
|
||||
|
||||
cdef int check_state_gold(char state_bits, char flag) nogil:
|
||||
cdef int check_state_gold(char state_bits, char flag) noexcept nogil:
|
||||
cdef char one = 1
|
||||
return 1 if (state_bits & (one << flag)) else 0
|
||||
|
||||
|
||||
cdef int set_state_flag(char state_bits, char flag, int value) nogil:
|
||||
cdef int set_state_flag(char state_bits, char flag, int value) noexcept nogil:
|
||||
cdef char one = 1
|
||||
if value:
|
||||
return state_bits | (one << flag)
|
||||
|
@ -252,27 +252,27 @@ cdef int set_state_flag(char state_bits, char flag, int value) nogil:
|
|||
return state_bits & ~(one << flag)
|
||||
|
||||
|
||||
cdef int is_head_in_stack(const GoldParseStateC* gold, int i) nogil:
|
||||
cdef int is_head_in_stack(const GoldParseStateC* gold, int i) noexcept nogil:
|
||||
return check_state_gold(gold.state_bits[i], HEAD_IN_STACK)
|
||||
|
||||
|
||||
cdef int is_head_in_buffer(const GoldParseStateC* gold, int i) nogil:
|
||||
cdef int is_head_in_buffer(const GoldParseStateC* gold, int i) noexcept nogil:
|
||||
return check_state_gold(gold.state_bits[i], HEAD_IN_BUFFER)
|
||||
|
||||
|
||||
cdef int is_head_unknown(const GoldParseStateC* gold, int i) nogil:
|
||||
cdef int is_head_unknown(const GoldParseStateC* gold, int i) noexcept nogil:
|
||||
return check_state_gold(gold.state_bits[i], HEAD_UNKNOWN)
|
||||
|
||||
cdef int is_sent_start(const GoldParseStateC* gold, int i) nogil:
|
||||
cdef int is_sent_start(const GoldParseStateC* gold, int i) noexcept nogil:
|
||||
return check_state_gold(gold.state_bits[i], IS_SENT_START)
|
||||
|
||||
cdef int is_sent_start_unknown(const GoldParseStateC* gold, int i) nogil:
|
||||
cdef int is_sent_start_unknown(const GoldParseStateC* gold, int i) noexcept nogil:
|
||||
return check_state_gold(gold.state_bits[i], SENT_START_UNKNOWN)
|
||||
|
||||
|
||||
# Helper functions for the arc-eager oracle
|
||||
|
||||
cdef weight_t push_cost(const StateC* state, const GoldParseStateC* gold) nogil:
|
||||
cdef weight_t push_cost(const StateC* state, const GoldParseStateC* gold) noexcept nogil:
|
||||
cdef weight_t cost = 0
|
||||
b0 = state.B(0)
|
||||
if b0 < 0:
|
||||
|
@ -285,7 +285,7 @@ cdef weight_t push_cost(const StateC* state, const GoldParseStateC* gold) nogil:
|
|||
return cost
|
||||
|
||||
|
||||
cdef weight_t pop_cost(const StateC* state, const GoldParseStateC* gold) nogil:
|
||||
cdef weight_t pop_cost(const StateC* state, const GoldParseStateC* gold) noexcept nogil:
|
||||
cdef weight_t cost = 0
|
||||
s0 = state.S(0)
|
||||
if s0 < 0:
|
||||
|
@ -296,7 +296,7 @@ cdef weight_t pop_cost(const StateC* state, const GoldParseStateC* gold) nogil:
|
|||
return cost
|
||||
|
||||
|
||||
cdef bint arc_is_gold(const GoldParseStateC* gold, int head, int child) nogil:
|
||||
cdef bint arc_is_gold(const GoldParseStateC* gold, int head, int child) noexcept nogil:
|
||||
if is_head_unknown(gold, child):
|
||||
return True
|
||||
elif gold.heads[child] == head:
|
||||
|
@ -305,7 +305,7 @@ cdef bint arc_is_gold(const GoldParseStateC* gold, int head, int child) nogil:
|
|||
return False
|
||||
|
||||
|
||||
cdef bint label_is_gold(const GoldParseStateC* gold, int child, attr_t label) nogil:
|
||||
cdef bint label_is_gold(const GoldParseStateC* gold, int child, attr_t label) noexcept nogil:
|
||||
if is_head_unknown(gold, child):
|
||||
return True
|
||||
elif label == 0:
|
||||
|
@ -316,7 +316,7 @@ cdef bint label_is_gold(const GoldParseStateC* gold, int child, attr_t label) no
|
|||
return False
|
||||
|
||||
|
||||
cdef bint _is_gold_root(const GoldParseStateC* gold, int word) nogil:
|
||||
cdef bint _is_gold_root(const GoldParseStateC* gold, int word) noexcept nogil:
|
||||
return gold.heads[word] == word or is_head_unknown(gold, word)
|
||||
|
||||
|
||||
|
@ -336,7 +336,7 @@ cdef class Shift:
|
|||
* Advance buffer
|
||||
"""
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
cdef bint is_valid(const StateC* st, attr_t label) noexcept nogil:
|
||||
if st.stack_depth() == 0:
|
||||
return 1
|
||||
elif st.buffer_length() < 2:
|
||||
|
@ -349,11 +349,11 @@ cdef class Shift:
|
|||
return 1
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
cdef int transition(StateC* st, attr_t label) noexcept nogil:
|
||||
st.push()
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
|
||||
cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) noexcept nogil:
|
||||
gold = <const GoldParseStateC*>_gold
|
||||
return gold.push_cost
|
||||
|
||||
|
@ -375,7 +375,7 @@ cdef class Reduce:
|
|||
cost by those arcs.
|
||||
"""
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
cdef bint is_valid(const StateC* st, attr_t label) noexcept nogil:
|
||||
if st.stack_depth() == 0:
|
||||
return False
|
||||
elif st.buffer_length() == 0:
|
||||
|
@ -386,14 +386,14 @@ cdef class Reduce:
|
|||
return True
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
cdef int transition(StateC* st, attr_t label) noexcept nogil:
|
||||
if st.has_head(st.S(0)) or st.stack_depth() == 1:
|
||||
st.pop()
|
||||
else:
|
||||
st.unshift()
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
|
||||
cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) noexcept nogil:
|
||||
gold = <const GoldParseStateC*>_gold
|
||||
if state.is_sent_start(state.B(0)):
|
||||
return 0
|
||||
|
@ -421,7 +421,7 @@ cdef class LeftArc:
|
|||
pop_cost - Arc(B[0], S[0], label) + (Arc(S[1], S[0]) if H(S[0]) else Arcs(S, S[0]))
|
||||
"""
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
cdef bint is_valid(const StateC* st, attr_t label) noexcept nogil:
|
||||
if st.stack_depth() == 0:
|
||||
return 0
|
||||
elif st.buffer_length() == 0:
|
||||
|
@ -434,7 +434,7 @@ cdef class LeftArc:
|
|||
return 1
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
cdef int transition(StateC* st, attr_t label) noexcept nogil:
|
||||
st.add_arc(st.B(0), st.S(0), label)
|
||||
# If we change the stack, it's okay to remove the shifted mark, as
|
||||
# we can't get in an infinite loop this way.
|
||||
|
@ -442,7 +442,7 @@ cdef class LeftArc:
|
|||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef inline weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
|
||||
cdef inline weight_t cost(const StateC* state, const void* _gold, attr_t label) noexcept nogil:
|
||||
gold = <const GoldParseStateC*>_gold
|
||||
cdef weight_t cost = gold.pop_cost
|
||||
s0 = state.S(0)
|
||||
|
@ -474,7 +474,7 @@ cdef class RightArc:
|
|||
push_cost + (not shifted[b0] and Arc(B[1:], B[0])) - Arc(S[0], B[0], label)
|
||||
"""
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
cdef bint is_valid(const StateC* st, attr_t label) noexcept nogil:
|
||||
if st.stack_depth() == 0:
|
||||
return 0
|
||||
elif st.buffer_length() == 0:
|
||||
|
@ -488,12 +488,12 @@ cdef class RightArc:
|
|||
return 1
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
cdef int transition(StateC* st, attr_t label) noexcept nogil:
|
||||
st.add_arc(st.S(0), st.B(0), label)
|
||||
st.push()
|
||||
|
||||
@staticmethod
|
||||
cdef inline weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
|
||||
cdef inline weight_t cost(const StateC* state, const void* _gold, attr_t label) noexcept nogil:
|
||||
gold = <const GoldParseStateC*>_gold
|
||||
cost = gold.push_cost
|
||||
s0 = state.S(0)
|
||||
|
@ -525,7 +525,7 @@ cdef class Break:
|
|||
* Arcs between S and B[1]
|
||||
"""
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
cdef bint is_valid(const StateC* st, attr_t label) noexcept nogil:
|
||||
if st.buffer_length() < 2:
|
||||
return False
|
||||
elif st.B(1) != st.B(0) + 1:
|
||||
|
@ -538,11 +538,11 @@ cdef class Break:
|
|||
return True
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
cdef int transition(StateC* st, attr_t label) noexcept nogil:
|
||||
st.set_sent_start(st.B(1), 1)
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
|
||||
cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) noexcept nogil:
|
||||
gold = <const GoldParseStateC*>_gold
|
||||
cdef int b0 = state.B(0)
|
||||
cdef int cost = 0
|
||||
|
@ -785,7 +785,7 @@ cdef class ArcEager(TransitionSystem):
|
|||
else:
|
||||
return False
|
||||
|
||||
cdef int set_valid(self, int* output, const StateC* st) nogil:
|
||||
cdef int set_valid(self, int* output, const StateC* st) noexcept nogil:
|
||||
cdef int[N_MOVES] is_valid
|
||||
is_valid[SHIFT] = Shift.is_valid(st, 0)
|
||||
is_valid[REDUCE] = Reduce.is_valid(st, 0)
|
||||
|
|
|
@ -110,7 +110,7 @@ cdef void update_gold_state(GoldNERStateC* gs, const StateC* state) except *:
|
|||
cdef do_func_t[N_MOVES] do_funcs
|
||||
|
||||
|
||||
cdef bint _entity_is_sunk(const StateC* state, Transition* golds) nogil:
|
||||
cdef bint _entity_is_sunk(const StateC* state, Transition* golds) noexcept nogil:
|
||||
if not state.entity_is_open():
|
||||
return False
|
||||
|
||||
|
@ -238,7 +238,7 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
|
||||
def add_action(self, int action, label_name, freq=None):
|
||||
cdef attr_t label_id
|
||||
if not isinstance(label_name, (int, long)):
|
||||
if not isinstance(label_name, int):
|
||||
label_id = self.strings.add(label_name)
|
||||
else:
|
||||
label_id = label_name
|
||||
|
@ -347,21 +347,21 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
|
||||
cdef class Missing:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
cdef bint is_valid(const StateC* st, attr_t label) noexcept nogil:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* s, attr_t label) nogil:
|
||||
cdef int transition(StateC* s, attr_t label) noexcept nogil:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
||||
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) noexcept nogil:
|
||||
return 9000
|
||||
|
||||
|
||||
cdef class Begin:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
cdef bint is_valid(const StateC* st, attr_t label) noexcept nogil:
|
||||
cdef int preset_ent_iob = st.B_(0).ent_iob
|
||||
cdef attr_t preset_ent_label = st.B_(0).ent_type
|
||||
if st.entity_is_open():
|
||||
|
@ -400,13 +400,13 @@ cdef class Begin:
|
|||
return True
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
cdef int transition(StateC* st, attr_t label) noexcept nogil:
|
||||
st.open_ent(label)
|
||||
st.push()
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
||||
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) noexcept nogil:
|
||||
gold = <GoldNERStateC*>_gold
|
||||
b0 = s.B(0)
|
||||
cdef int cost = 0
|
||||
|
@ -439,7 +439,7 @@ cdef class Begin:
|
|||
|
||||
cdef class In:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
cdef bint is_valid(const StateC* st, attr_t label) noexcept nogil:
|
||||
if not st.entity_is_open():
|
||||
return False
|
||||
if st.buffer_length() < 2:
|
||||
|
@ -475,12 +475,12 @@ cdef class In:
|
|||
return True
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
cdef int transition(StateC* st, attr_t label) noexcept nogil:
|
||||
st.push()
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
||||
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) noexcept nogil:
|
||||
gold = <GoldNERStateC*>_gold
|
||||
cdef int next_act = gold.ner[s.B(1)].move if s.B(1) >= 0 else OUT
|
||||
cdef int g_act = gold.ner[s.B(0)].move
|
||||
|
@ -510,7 +510,7 @@ cdef class In:
|
|||
|
||||
cdef class Last:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
cdef bint is_valid(const StateC* st, attr_t label) noexcept nogil:
|
||||
cdef int preset_ent_iob = st.B_(0).ent_iob
|
||||
cdef attr_t preset_ent_label = st.B_(0).ent_type
|
||||
if label == 0:
|
||||
|
@ -535,13 +535,13 @@ cdef class Last:
|
|||
return True
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
cdef int transition(StateC* st, attr_t label) noexcept nogil:
|
||||
st.close_ent()
|
||||
st.push()
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
||||
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) noexcept nogil:
|
||||
gold = <GoldNERStateC*>_gold
|
||||
b0 = s.B(0)
|
||||
ent_start = s.E(0)
|
||||
|
@ -581,7 +581,7 @@ cdef class Last:
|
|||
|
||||
cdef class Unit:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
cdef bint is_valid(const StateC* st, attr_t label) noexcept nogil:
|
||||
cdef int preset_ent_iob = st.B_(0).ent_iob
|
||||
cdef attr_t preset_ent_label = st.B_(0).ent_type
|
||||
if label == 0:
|
||||
|
@ -609,14 +609,14 @@ cdef class Unit:
|
|||
return True
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
cdef int transition(StateC* st, attr_t label) noexcept nogil:
|
||||
st.open_ent(label)
|
||||
st.close_ent()
|
||||
st.push()
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
||||
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) noexcept nogil:
|
||||
gold = <GoldNERStateC*>_gold
|
||||
cdef int g_act = gold.ner[s.B(0)].move
|
||||
cdef attr_t g_tag = gold.ner[s.B(0)].label
|
||||
|
@ -646,7 +646,7 @@ cdef class Unit:
|
|||
|
||||
cdef class Out:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
cdef bint is_valid(const StateC* st, attr_t label) noexcept nogil:
|
||||
cdef int preset_ent_iob = st.B_(0).ent_iob
|
||||
if st.entity_is_open():
|
||||
return False
|
||||
|
@ -658,12 +658,12 @@ cdef class Out:
|
|||
return True
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
cdef int transition(StateC* st, attr_t label) noexcept nogil:
|
||||
st.push()
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
||||
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) noexcept nogil:
|
||||
gold = <GoldNERStateC*>_gold
|
||||
cdef int g_act = gold.ner[s.B(0)].move
|
||||
cdef weight_t cost = 0
|
||||
|
|
|
@ -94,7 +94,7 @@ cdef bool _has_head_as_ancestor(int tokenid, int head, const vector[int]& heads)
|
|||
return False
|
||||
|
||||
|
||||
cdef string heads_to_string(const vector[int]& heads) nogil:
|
||||
cdef string heads_to_string(const vector[int]& heads) noexcept nogil:
|
||||
cdef vector[int].const_iterator citer
|
||||
cdef string cycle_str
|
||||
|
||||
|
|
|
@ -15,22 +15,22 @@ cdef struct Transition:
|
|||
|
||||
weight_t score
|
||||
|
||||
bint (*is_valid)(const StateC* state, attr_t label) nogil
|
||||
weight_t (*get_cost)(const StateC* state, const void* gold, attr_t label) nogil
|
||||
int (*do)(StateC* state, attr_t label) nogil
|
||||
bint (*is_valid)(const StateC* state, attr_t label) noexcept nogil
|
||||
weight_t (*get_cost)(const StateC* state, const void* gold, attr_t label) noexcept nogil
|
||||
int (*do)(StateC* state, attr_t label) noexcept nogil
|
||||
|
||||
|
||||
ctypedef weight_t (*get_cost_func_t)(
|
||||
const StateC* state, const void* gold, attr_tlabel
|
||||
) nogil
|
||||
) noexcept nogil
|
||||
ctypedef weight_t (*move_cost_func_t)(
|
||||
const StateC* state, const void* gold
|
||||
) nogil
|
||||
) noexcept nogil
|
||||
ctypedef weight_t (*label_cost_func_t)(
|
||||
const StateC* state, const void* gold, attr_t label
|
||||
) nogil
|
||||
) noexcept nogil
|
||||
|
||||
ctypedef int (*do_func_t)(StateC* state, attr_t label) nogil
|
||||
ctypedef int (*do_func_t)(StateC* state, attr_t label) noexcept nogil
|
||||
|
||||
ctypedef void* (*init_state_t)(Pool mem, int length, void* tokens) except NULL
|
||||
|
||||
|
@ -53,7 +53,7 @@ cdef class TransitionSystem:
|
|||
|
||||
cdef Transition init_transition(self, int clas, int move, attr_t label) except *
|
||||
|
||||
cdef int set_valid(self, int* output, const StateC* st) nogil
|
||||
cdef int set_valid(self, int* output, const StateC* st) noexcept nogil
|
||||
|
||||
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
||||
const StateC* state, gold) except -1
|
||||
|
|
|
@ -149,7 +149,7 @@ cdef class TransitionSystem:
|
|||
action = self.lookup_transition(move_name)
|
||||
return action.is_valid(stcls.c, action.label)
|
||||
|
||||
cdef int set_valid(self, int* is_valid, const StateC* st) nogil:
|
||||
cdef int set_valid(self, int* is_valid, const StateC* st) noexcept nogil:
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
is_valid[i] = self.c[i].is_valid(st, self.c[i].label)
|
||||
|
@ -191,8 +191,7 @@ cdef class TransitionSystem:
|
|||
|
||||
def add_action(self, int action, label_name):
|
||||
cdef attr_t label_id
|
||||
if not isinstance(label_name, int) and \
|
||||
not isinstance(label_name, long):
|
||||
if not isinstance(label_name, int):
|
||||
label_id = self.strings.add(label_name)
|
||||
else:
|
||||
label_id = label_name
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import importlib
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
|
@ -22,19 +24,6 @@ TagMapType = Dict[str, Dict[Union[int, str], Union[int, str]]]
|
|||
MorphRulesType = Dict[str, Dict[str, Dict[Union[int, str], Union[int, str]]]]
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"attribute_ruler",
|
||||
default_config={
|
||||
"validate": False,
|
||||
"scorer": {"@scorers": "spacy.attribute_ruler_scorer.v1"},
|
||||
},
|
||||
)
|
||||
def make_attribute_ruler(
|
||||
nlp: Language, name: str, validate: bool, scorer: Optional[Callable]
|
||||
):
|
||||
return AttributeRuler(nlp.vocab, name, validate=validate, scorer=scorer)
|
||||
|
||||
|
||||
def attribute_ruler_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
||||
def morph_key_getter(token, attr):
|
||||
return getattr(token, attr).key
|
||||
|
@ -54,7 +43,6 @@ def attribute_ruler_score(examples: Iterable[Example], **kwargs) -> Dict[str, An
|
|||
return results
|
||||
|
||||
|
||||
@registry.scorers("spacy.attribute_ruler_scorer.v1")
|
||||
def make_attribute_ruler_scorer():
|
||||
return attribute_ruler_score
|
||||
|
||||
|
@ -355,3 +343,11 @@ def _split_morph_attrs(attrs: dict) -> Tuple[dict, dict]:
|
|||
else:
|
||||
morph_attrs[k] = v
|
||||
return other_attrs, morph_attrs
|
||||
|
||||
|
||||
# Setup backwards compatibility hook for factories
|
||||
def __getattr__(name):
|
||||
if name == "make_attribute_ruler":
|
||||
module = importlib.import_module("spacy.pipeline.factories")
|
||||
return module.make_attribute_ruler
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
# cython: infer_types=True, binding=True
|
||||
import importlib
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from typing import Callable, Optional
|
||||
|
||||
|
@ -39,188 +41,6 @@ subword_features = true
|
|||
DEFAULT_PARSER_MODEL = Config().from_str(default_model_config)["model"]
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"parser",
|
||||
assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"],
|
||||
default_config={
|
||||
"moves": None,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
"learn_tokens": False,
|
||||
"min_action_freq": 30,
|
||||
"model": DEFAULT_PARSER_MODEL,
|
||||
"scorer": {"@scorers": "spacy.parser_scorer.v1"},
|
||||
},
|
||||
default_score_weights={
|
||||
"dep_uas": 0.5,
|
||||
"dep_las": 0.5,
|
||||
"dep_las_per_type": None,
|
||||
"sents_p": None,
|
||||
"sents_r": None,
|
||||
"sents_f": 0.0,
|
||||
},
|
||||
)
|
||||
def make_parser(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
model: Model,
|
||||
moves: Optional[TransitionSystem],
|
||||
update_with_oracle_cut_size: int,
|
||||
learn_tokens: bool,
|
||||
min_action_freq: int,
|
||||
scorer: Optional[Callable],
|
||||
):
|
||||
"""Create a transition-based DependencyParser component. The dependency parser
|
||||
jointly learns sentence segmentation and labelled dependency parsing, and can
|
||||
optionally learn to merge tokens that had been over-segmented by the tokenizer.
|
||||
|
||||
The parser uses a variant of the non-monotonic arc-eager transition-system
|
||||
described by Honnibal and Johnson (2014), with the addition of a "break"
|
||||
transition to perform the sentence segmentation. Nivre's pseudo-projective
|
||||
dependency transformation is used to allow the parser to predict
|
||||
non-projective parses.
|
||||
|
||||
The parser is trained using an imitation learning objective. The parser follows
|
||||
the actions predicted by the current weights, and at each state, determines
|
||||
which actions are compatible with the optimal parse that could be reached
|
||||
from the current state. The weights such that the scores assigned to the
|
||||
set of optimal actions is increased, while scores assigned to other
|
||||
actions are decreased. Note that more than one action may be optimal for
|
||||
a given state.
|
||||
|
||||
model (Model): The model for the transition-based parser. The model needs
|
||||
to have a specific substructure of named components --- see the
|
||||
spacy.ml.tb_framework.TransitionModel for details.
|
||||
moves (Optional[TransitionSystem]): This defines how the parse-state is created,
|
||||
updated and evaluated. If 'moves' is None, a new instance is
|
||||
created with `self.TransitionSystem()`. Defaults to `None`.
|
||||
update_with_oracle_cut_size (int): During training, cut long sequences into
|
||||
shorter segments by creating intermediate states based on the gold-standard
|
||||
history. The model is not very sensitive to this parameter, so you usually
|
||||
won't need to change it. 100 is a good default.
|
||||
learn_tokens (bool): Whether to learn to merge subtokens that are split
|
||||
relative to the gold standard. Experimental.
|
||||
min_action_freq (int): The minimum frequency of labelled actions to retain.
|
||||
Rarer labelled actions have their label backed-off to "dep". While this
|
||||
primarily affects the label accuracy, it can also affect the attachment
|
||||
structure, as the labels are used to represent the pseudo-projectivity
|
||||
transformation.
|
||||
scorer (Optional[Callable]): The scoring method.
|
||||
"""
|
||||
return DependencyParser(
|
||||
nlp.vocab,
|
||||
model,
|
||||
name,
|
||||
moves=moves,
|
||||
update_with_oracle_cut_size=update_with_oracle_cut_size,
|
||||
multitasks=[],
|
||||
learn_tokens=learn_tokens,
|
||||
min_action_freq=min_action_freq,
|
||||
beam_width=1,
|
||||
beam_density=0.0,
|
||||
beam_update_prob=0.0,
|
||||
# At some point in the future we can try to implement support for
|
||||
# partial annotations, perhaps only in the beam objective.
|
||||
incorrect_spans_key=None,
|
||||
scorer=scorer,
|
||||
)
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"beam_parser",
|
||||
assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"],
|
||||
default_config={
|
||||
"beam_width": 8,
|
||||
"beam_density": 0.01,
|
||||
"beam_update_prob": 0.5,
|
||||
"moves": None,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
"learn_tokens": False,
|
||||
"min_action_freq": 30,
|
||||
"model": DEFAULT_PARSER_MODEL,
|
||||
"scorer": {"@scorers": "spacy.parser_scorer.v1"},
|
||||
},
|
||||
default_score_weights={
|
||||
"dep_uas": 0.5,
|
||||
"dep_las": 0.5,
|
||||
"dep_las_per_type": None,
|
||||
"sents_p": None,
|
||||
"sents_r": None,
|
||||
"sents_f": 0.0,
|
||||
},
|
||||
)
|
||||
def make_beam_parser(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
model: Model,
|
||||
moves: Optional[TransitionSystem],
|
||||
update_with_oracle_cut_size: int,
|
||||
learn_tokens: bool,
|
||||
min_action_freq: int,
|
||||
beam_width: int,
|
||||
beam_density: float,
|
||||
beam_update_prob: float,
|
||||
scorer: Optional[Callable],
|
||||
):
|
||||
"""Create a transition-based DependencyParser component that uses beam-search.
|
||||
The dependency parser jointly learns sentence segmentation and labelled
|
||||
dependency parsing, and can optionally learn to merge tokens that had been
|
||||
over-segmented by the tokenizer.
|
||||
|
||||
The parser uses a variant of the non-monotonic arc-eager transition-system
|
||||
described by Honnibal and Johnson (2014), with the addition of a "break"
|
||||
transition to perform the sentence segmentation. Nivre's pseudo-projective
|
||||
dependency transformation is used to allow the parser to predict
|
||||
non-projective parses.
|
||||
|
||||
The parser is trained using a global objective. That is, it learns to assign
|
||||
probabilities to whole parses.
|
||||
|
||||
model (Model): The model for the transition-based parser. The model needs
|
||||
to have a specific substructure of named components --- see the
|
||||
spacy.ml.tb_framework.TransitionModel for details.
|
||||
moves (Optional[TransitionSystem]): This defines how the parse-state is created,
|
||||
updated and evaluated. If 'moves' is None, a new instance is
|
||||
created with `self.TransitionSystem()`. Defaults to `None`.
|
||||
update_with_oracle_cut_size (int): During training, cut long sequences into
|
||||
shorter segments by creating intermediate states based on the gold-standard
|
||||
history. The model is not very sensitive to this parameter, so you usually
|
||||
won't need to change it. 100 is a good default.
|
||||
beam_width (int): The number of candidate analyses to maintain.
|
||||
beam_density (float): The minimum ratio between the scores of the first and
|
||||
last candidates in the beam. This allows the parser to avoid exploring
|
||||
candidates that are too far behind. This is mostly intended to improve
|
||||
efficiency, but it can also improve accuracy as deeper search is not
|
||||
always better.
|
||||
beam_update_prob (float): The chance of making a beam update, instead of a
|
||||
greedy update. Greedy updates are an approximation for the beam updates,
|
||||
and are faster to compute.
|
||||
learn_tokens (bool): Whether to learn to merge subtokens that are split
|
||||
relative to the gold standard. Experimental.
|
||||
min_action_freq (int): The minimum frequency of labelled actions to retain.
|
||||
Rarer labelled actions have their label backed-off to "dep". While this
|
||||
primarily affects the label accuracy, it can also affect the attachment
|
||||
structure, as the labels are used to represent the pseudo-projectivity
|
||||
transformation.
|
||||
"""
|
||||
return DependencyParser(
|
||||
nlp.vocab,
|
||||
model,
|
||||
name,
|
||||
moves=moves,
|
||||
update_with_oracle_cut_size=update_with_oracle_cut_size,
|
||||
beam_width=beam_width,
|
||||
beam_density=beam_density,
|
||||
beam_update_prob=beam_update_prob,
|
||||
multitasks=[],
|
||||
learn_tokens=learn_tokens,
|
||||
min_action_freq=min_action_freq,
|
||||
# At some point in the future we can try to implement support for
|
||||
# partial annotations, perhaps only in the beam objective.
|
||||
incorrect_spans_key=None,
|
||||
scorer=scorer,
|
||||
)
|
||||
|
||||
|
||||
def parser_score(examples, **kwargs):
|
||||
"""Score a batch of examples.
|
||||
|
||||
|
@ -246,7 +66,6 @@ def parser_score(examples, **kwargs):
|
|||
return results
|
||||
|
||||
|
||||
@registry.scorers("spacy.parser_scorer.v1")
|
||||
def make_parser_scorer():
|
||||
return parser_score
|
||||
|
||||
|
@ -346,3 +165,14 @@ cdef class DependencyParser(Parser):
|
|||
# because we instead have a label frequency cut-off and back off rare
|
||||
# labels to 'dep'.
|
||||
pass
|
||||
|
||||
|
||||
# Setup backwards compatibility hook for factories
|
||||
def __getattr__(name):
|
||||
if name == "make_parser":
|
||||
module = importlib.import_module("spacy.pipeline.factories")
|
||||
return module.make_parser
|
||||
elif name == "make_beam_parser":
|
||||
module = importlib.import_module("spacy.pipeline.factories")
|
||||
return module.make_beam_parser
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import importlib
|
||||
import sys
|
||||
from collections import Counter
|
||||
from itertools import islice
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, cast
|
||||
|
@ -39,43 +41,6 @@ subword_features = true
|
|||
DEFAULT_EDIT_TREE_LEMMATIZER_MODEL = Config().from_str(default_model_config)["model"]
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"trainable_lemmatizer",
|
||||
assigns=["token.lemma"],
|
||||
requires=[],
|
||||
default_config={
|
||||
"model": DEFAULT_EDIT_TREE_LEMMATIZER_MODEL,
|
||||
"backoff": "orth",
|
||||
"min_tree_freq": 3,
|
||||
"overwrite": False,
|
||||
"top_k": 1,
|
||||
"scorer": {"@scorers": "spacy.lemmatizer_scorer.v1"},
|
||||
},
|
||||
default_score_weights={"lemma_acc": 1.0},
|
||||
)
|
||||
def make_edit_tree_lemmatizer(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
model: Model,
|
||||
backoff: Optional[str],
|
||||
min_tree_freq: int,
|
||||
overwrite: bool,
|
||||
top_k: int,
|
||||
scorer: Optional[Callable],
|
||||
):
|
||||
"""Construct an EditTreeLemmatizer component."""
|
||||
return EditTreeLemmatizer(
|
||||
nlp.vocab,
|
||||
model,
|
||||
name,
|
||||
backoff=backoff,
|
||||
min_tree_freq=min_tree_freq,
|
||||
overwrite=overwrite,
|
||||
top_k=top_k,
|
||||
scorer=scorer,
|
||||
)
|
||||
|
||||
|
||||
class EditTreeLemmatizer(TrainablePipe):
|
||||
"""
|
||||
Lemmatizer that lemmatizes each word using a predicted edit tree.
|
||||
|
@ -421,3 +386,11 @@ class EditTreeLemmatizer(TrainablePipe):
|
|||
self.tree2label[tree_id] = len(self.cfg["labels"])
|
||||
self.cfg["labels"].append(tree_id)
|
||||
return self.tree2label[tree_id]
|
||||
|
||||
|
||||
# Setup backwards compatibility hook for factories
|
||||
def __getattr__(name):
|
||||
if name == "make_edit_tree_lemmatizer":
|
||||
module = importlib.import_module("spacy.pipeline.factories")
|
||||
return module.make_edit_tree_lemmatizer
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
import importlib
|
||||
import random
|
||||
import sys
|
||||
from itertools import islice
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
|
||||
|
@ -40,117 +42,10 @@ subword_features = true
|
|||
DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"entity_linker",
|
||||
requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"],
|
||||
assigns=["token.ent_kb_id"],
|
||||
default_config={
|
||||
"model": DEFAULT_NEL_MODEL,
|
||||
"labels_discard": [],
|
||||
"n_sents": 0,
|
||||
"incl_prior": True,
|
||||
"incl_context": True,
|
||||
"entity_vector_length": 64,
|
||||
"get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
|
||||
"get_candidates_batch": {"@misc": "spacy.CandidateBatchGenerator.v1"},
|
||||
"generate_empty_kb": {"@misc": "spacy.EmptyKB.v2"},
|
||||
"overwrite": True,
|
||||
"scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
|
||||
"use_gold_ents": True,
|
||||
"candidates_batch_size": 1,
|
||||
"threshold": None,
|
||||
},
|
||||
default_score_weights={
|
||||
"nel_micro_f": 1.0,
|
||||
"nel_micro_r": None,
|
||||
"nel_micro_p": None,
|
||||
},
|
||||
)
|
||||
def make_entity_linker(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
model: Model,
|
||||
*,
|
||||
labels_discard: Iterable[str],
|
||||
n_sents: int,
|
||||
incl_prior: bool,
|
||||
incl_context: bool,
|
||||
entity_vector_length: int,
|
||||
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
|
||||
get_candidates_batch: Callable[
|
||||
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
|
||||
],
|
||||
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
|
||||
overwrite: bool,
|
||||
scorer: Optional[Callable],
|
||||
use_gold_ents: bool,
|
||||
candidates_batch_size: int,
|
||||
threshold: Optional[float] = None,
|
||||
):
|
||||
"""Construct an EntityLinker component.
|
||||
|
||||
model (Model[List[Doc], Floats2d]): A model that learns document vector
|
||||
representations. Given a batch of Doc objects, it should return a single
|
||||
array, with one row per item in the batch.
|
||||
labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
|
||||
n_sents (int): The number of neighbouring sentences to take into account.
|
||||
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
|
||||
incl_context (bool): Whether or not to include the local context in the model.
|
||||
entity_vector_length (int): Size of encoding vectors in the KB.
|
||||
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that
|
||||
produces a list of candidates, given a certain knowledge base and a textual mention.
|
||||
get_candidates_batch (
|
||||
Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], Iterable[Candidate]]
|
||||
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
|
||||
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
|
||||
scorer (Optional[Callable]): The scoring method.
|
||||
use_gold_ents (bool): Whether to copy entities from gold docs during training or not. If false, another
|
||||
component must provide entity annotations.
|
||||
candidates_batch_size (int): Size of batches for entity candidate generation.
|
||||
threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold,
|
||||
prediction is discarded. If None, predictions are not filtered by any threshold.
|
||||
"""
|
||||
|
||||
if not model.attrs.get("include_span_maker", False):
|
||||
# The only difference in arguments here is that use_gold_ents and threshold aren't available.
|
||||
return EntityLinker_v1(
|
||||
nlp.vocab,
|
||||
model,
|
||||
name,
|
||||
labels_discard=labels_discard,
|
||||
n_sents=n_sents,
|
||||
incl_prior=incl_prior,
|
||||
incl_context=incl_context,
|
||||
entity_vector_length=entity_vector_length,
|
||||
get_candidates=get_candidates,
|
||||
overwrite=overwrite,
|
||||
scorer=scorer,
|
||||
)
|
||||
return EntityLinker(
|
||||
nlp.vocab,
|
||||
model,
|
||||
name,
|
||||
labels_discard=labels_discard,
|
||||
n_sents=n_sents,
|
||||
incl_prior=incl_prior,
|
||||
incl_context=incl_context,
|
||||
entity_vector_length=entity_vector_length,
|
||||
get_candidates=get_candidates,
|
||||
get_candidates_batch=get_candidates_batch,
|
||||
generate_empty_kb=generate_empty_kb,
|
||||
overwrite=overwrite,
|
||||
scorer=scorer,
|
||||
use_gold_ents=use_gold_ents,
|
||||
candidates_batch_size=candidates_batch_size,
|
||||
threshold=threshold,
|
||||
)
|
||||
|
||||
|
||||
def entity_linker_score(examples, **kwargs):
|
||||
return Scorer.score_links(examples, negative_labels=[EntityLinker.NIL], **kwargs)
|
||||
|
||||
|
||||
@registry.scorers("spacy.entity_linker_scorer.v1")
|
||||
def make_entity_linker_scorer():
|
||||
return entity_linker_score
|
||||
|
||||
|
@ -676,3 +571,11 @@ class EntityLinker(TrainablePipe):
|
|||
|
||||
def add_label(self, label):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Setup backwards compatibility hook for factories
|
||||
def __getattr__(name):
|
||||
if name == "make_entity_linker":
|
||||
module = importlib.import_module("spacy.pipeline.factories")
|
||||
return module.make_entity_linker
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import importlib
|
||||
import sys
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
@ -19,51 +21,10 @@ DEFAULT_ENT_ID_SEP = "||"
|
|||
PatternType = Dict[str, Union[str, List[Dict[str, Any]]]]
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"entity_ruler",
|
||||
assigns=["doc.ents", "token.ent_type", "token.ent_iob"],
|
||||
default_config={
|
||||
"phrase_matcher_attr": None,
|
||||
"matcher_fuzzy_compare": {"@misc": "spacy.levenshtein_compare.v1"},
|
||||
"validate": False,
|
||||
"overwrite_ents": False,
|
||||
"ent_id_sep": DEFAULT_ENT_ID_SEP,
|
||||
"scorer": {"@scorers": "spacy.entity_ruler_scorer.v1"},
|
||||
},
|
||||
default_score_weights={
|
||||
"ents_f": 1.0,
|
||||
"ents_p": 0.0,
|
||||
"ents_r": 0.0,
|
||||
"ents_per_type": None,
|
||||
},
|
||||
)
|
||||
def make_entity_ruler(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
phrase_matcher_attr: Optional[Union[int, str]],
|
||||
matcher_fuzzy_compare: Callable,
|
||||
validate: bool,
|
||||
overwrite_ents: bool,
|
||||
ent_id_sep: str,
|
||||
scorer: Optional[Callable],
|
||||
):
|
||||
return EntityRuler(
|
||||
nlp,
|
||||
name,
|
||||
phrase_matcher_attr=phrase_matcher_attr,
|
||||
matcher_fuzzy_compare=matcher_fuzzy_compare,
|
||||
validate=validate,
|
||||
overwrite_ents=overwrite_ents,
|
||||
ent_id_sep=ent_id_sep,
|
||||
scorer=scorer,
|
||||
)
|
||||
|
||||
|
||||
def entity_ruler_score(examples, **kwargs):
|
||||
return get_ner_prf(examples)
|
||||
|
||||
|
||||
@registry.scorers("spacy.entity_ruler_scorer.v1")
|
||||
def make_entity_ruler_scorer():
|
||||
return entity_ruler_score
|
||||
|
||||
|
@ -539,3 +500,11 @@ class EntityRuler(Pipe):
|
|||
srsly.write_jsonl(path, self.patterns)
|
||||
else:
|
||||
to_disk(path, serializers, {})
|
||||
|
||||
|
||||
# Setup backwards compatibility hook for factories
|
||||
def __getattr__(name):
|
||||
if name == "make_entity_ruler":
|
||||
module = importlib.import_module("spacy.pipeline.factories")
|
||||
return module.make_entity_ruler
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
|
929
spacy/pipeline/factories.py
Normal file
929
spacy/pipeline/factories.py
Normal file
|
@ -0,0 +1,929 @@
|
|||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
from thinc.api import Model
|
||||
from thinc.types import Floats2d, Ragged
|
||||
|
||||
from ..kb import Candidate, KnowledgeBase
|
||||
from ..language import Language
|
||||
from ..pipeline._parser_internals.transition_system import TransitionSystem
|
||||
from ..pipeline.attributeruler import AttributeRuler
|
||||
from ..pipeline.dep_parser import DEFAULT_PARSER_MODEL, DependencyParser
|
||||
from ..pipeline.edit_tree_lemmatizer import (
|
||||
DEFAULT_EDIT_TREE_LEMMATIZER_MODEL,
|
||||
EditTreeLemmatizer,
|
||||
)
|
||||
|
||||
# Import factory default configurations
|
||||
from ..pipeline.entity_linker import DEFAULT_NEL_MODEL, EntityLinker, EntityLinker_v1
|
||||
from ..pipeline.entityruler import DEFAULT_ENT_ID_SEP, EntityRuler
|
||||
from ..pipeline.functions import DocCleaner, TokenSplitter
|
||||
from ..pipeline.lemmatizer import Lemmatizer
|
||||
from ..pipeline.morphologizer import DEFAULT_MORPH_MODEL, Morphologizer
|
||||
from ..pipeline.multitask import DEFAULT_MT_MODEL, MultitaskObjective
|
||||
from ..pipeline.ner import DEFAULT_NER_MODEL, EntityRecognizer
|
||||
from ..pipeline.sentencizer import Sentencizer
|
||||
from ..pipeline.senter import DEFAULT_SENTER_MODEL, SentenceRecognizer
|
||||
from ..pipeline.span_finder import DEFAULT_SPAN_FINDER_MODEL, SpanFinder
|
||||
from ..pipeline.span_ruler import DEFAULT_SPANS_KEY as SPAN_RULER_DEFAULT_SPANS_KEY
|
||||
from ..pipeline.span_ruler import (
|
||||
SpanRuler,
|
||||
prioritize_existing_ents_filter,
|
||||
prioritize_new_ents_filter,
|
||||
)
|
||||
from ..pipeline.spancat import (
|
||||
DEFAULT_SPANCAT_MODEL,
|
||||
DEFAULT_SPANCAT_SINGLELABEL_MODEL,
|
||||
DEFAULT_SPANS_KEY,
|
||||
SpanCategorizer,
|
||||
Suggester,
|
||||
)
|
||||
from ..pipeline.tagger import DEFAULT_TAGGER_MODEL, Tagger
|
||||
from ..pipeline.textcat import DEFAULT_SINGLE_TEXTCAT_MODEL, TextCategorizer
|
||||
from ..pipeline.textcat_multilabel import (
|
||||
DEFAULT_MULTI_TEXTCAT_MODEL,
|
||||
MultiLabel_TextCategorizer,
|
||||
)
|
||||
from ..pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL, Tok2Vec
|
||||
from ..tokens.doc import Doc
|
||||
from ..tokens.span import Span
|
||||
from ..vocab import Vocab
|
||||
|
||||
# Global flag to track if factories have been registered
|
||||
FACTORIES_REGISTERED = False
|
||||
|
||||
|
||||
def register_factories() -> None:
|
||||
"""Register all factories with the registry.
|
||||
|
||||
This function registers all pipeline component factories, centralizing
|
||||
the registrations that were previously done with @Language.factory decorators.
|
||||
"""
|
||||
global FACTORIES_REGISTERED
|
||||
|
||||
if FACTORIES_REGISTERED:
|
||||
return
|
||||
|
||||
# Register factories using the same pattern as Language.factory decorator
|
||||
# We use Language.factory()() pattern which exactly mimics the decorator
|
||||
|
||||
# attributeruler
|
||||
Language.factory(
|
||||
"attribute_ruler",
|
||||
default_config={
|
||||
"validate": False,
|
||||
"scorer": {"@scorers": "spacy.attribute_ruler_scorer.v1"},
|
||||
},
|
||||
)(make_attribute_ruler)
|
||||
|
||||
# entity_linker
|
||||
Language.factory(
|
||||
"entity_linker",
|
||||
requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"],
|
||||
assigns=["token.ent_kb_id"],
|
||||
default_config={
|
||||
"model": DEFAULT_NEL_MODEL,
|
||||
"labels_discard": [],
|
||||
"n_sents": 0,
|
||||
"incl_prior": True,
|
||||
"incl_context": True,
|
||||
"entity_vector_length": 64,
|
||||
"get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
|
||||
"get_candidates_batch": {"@misc": "spacy.CandidateBatchGenerator.v1"},
|
||||
"generate_empty_kb": {"@misc": "spacy.EmptyKB.v2"},
|
||||
"overwrite": True,
|
||||
"scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
|
||||
"use_gold_ents": True,
|
||||
"candidates_batch_size": 1,
|
||||
"threshold": None,
|
||||
},
|
||||
default_score_weights={
|
||||
"nel_micro_f": 1.0,
|
||||
"nel_micro_r": None,
|
||||
"nel_micro_p": None,
|
||||
},
|
||||
)(make_entity_linker)
|
||||
|
||||
# entity_ruler
|
||||
Language.factory(
|
||||
"entity_ruler",
|
||||
assigns=["doc.ents", "token.ent_type", "token.ent_iob"],
|
||||
default_config={
|
||||
"phrase_matcher_attr": None,
|
||||
"matcher_fuzzy_compare": {"@misc": "spacy.levenshtein_compare.v1"},
|
||||
"validate": False,
|
||||
"overwrite_ents": False,
|
||||
"ent_id_sep": DEFAULT_ENT_ID_SEP,
|
||||
"scorer": {"@scorers": "spacy.entity_ruler_scorer.v1"},
|
||||
},
|
||||
default_score_weights={
|
||||
"ents_f": 1.0,
|
||||
"ents_p": 0.0,
|
||||
"ents_r": 0.0,
|
||||
"ents_per_type": None,
|
||||
},
|
||||
)(make_entity_ruler)
|
||||
|
||||
# lemmatizer
|
||||
Language.factory(
|
||||
"lemmatizer",
|
||||
assigns=["token.lemma"],
|
||||
default_config={
|
||||
"model": None,
|
||||
"mode": "lookup",
|
||||
"overwrite": False,
|
||||
"scorer": {"@scorers": "spacy.lemmatizer_scorer.v1"},
|
||||
},
|
||||
default_score_weights={"lemma_acc": 1.0},
|
||||
)(make_lemmatizer)
|
||||
|
||||
# textcat
|
||||
Language.factory(
|
||||
"textcat",
|
||||
assigns=["doc.cats"],
|
||||
default_config={
|
||||
"threshold": 0.0,
|
||||
"model": DEFAULT_SINGLE_TEXTCAT_MODEL,
|
||||
"scorer": {"@scorers": "spacy.textcat_scorer.v2"},
|
||||
},
|
||||
default_score_weights={
|
||||
"cats_score": 1.0,
|
||||
"cats_score_desc": None,
|
||||
"cats_micro_p": None,
|
||||
"cats_micro_r": None,
|
||||
"cats_micro_f": None,
|
||||
"cats_macro_p": None,
|
||||
"cats_macro_r": None,
|
||||
"cats_macro_f": None,
|
||||
"cats_macro_auc": None,
|
||||
"cats_f_per_type": None,
|
||||
},
|
||||
)(make_textcat)
|
||||
|
||||
# token_splitter
|
||||
Language.factory(
|
||||
"token_splitter",
|
||||
default_config={"min_length": 25, "split_length": 10},
|
||||
retokenizes=True,
|
||||
)(make_token_splitter)
|
||||
|
||||
# doc_cleaner
|
||||
Language.factory(
|
||||
"doc_cleaner",
|
||||
default_config={"attrs": {"tensor": None, "_.trf_data": None}, "silent": True},
|
||||
)(make_doc_cleaner)
|
||||
|
||||
# tok2vec
|
||||
Language.factory(
|
||||
"tok2vec",
|
||||
assigns=["doc.tensor"],
|
||||
default_config={"model": DEFAULT_TOK2VEC_MODEL},
|
||||
)(make_tok2vec)
|
||||
|
||||
# senter
|
||||
Language.factory(
|
||||
"senter",
|
||||
assigns=["token.is_sent_start"],
|
||||
default_config={
|
||||
"model": DEFAULT_SENTER_MODEL,
|
||||
"overwrite": False,
|
||||
"scorer": {"@scorers": "spacy.senter_scorer.v1"},
|
||||
},
|
||||
default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
|
||||
)(make_senter)
|
||||
|
||||
# morphologizer
|
||||
Language.factory(
|
||||
"morphologizer",
|
||||
assigns=["token.morph", "token.pos"],
|
||||
default_config={
|
||||
"model": DEFAULT_MORPH_MODEL,
|
||||
"overwrite": True,
|
||||
"extend": False,
|
||||
"scorer": {"@scorers": "spacy.morphologizer_scorer.v1"},
|
||||
"label_smoothing": 0.0,
|
||||
},
|
||||
default_score_weights={
|
||||
"pos_acc": 0.5,
|
||||
"morph_acc": 0.5,
|
||||
"morph_per_feat": None,
|
||||
},
|
||||
)(make_morphologizer)
|
||||
|
||||
# spancat
|
||||
Language.factory(
|
||||
"spancat",
|
||||
assigns=["doc.spans"],
|
||||
default_config={
|
||||
"threshold": 0.5,
|
||||
"spans_key": DEFAULT_SPANS_KEY,
|
||||
"max_positive": None,
|
||||
"model": DEFAULT_SPANCAT_MODEL,
|
||||
"suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]},
|
||||
"scorer": {"@scorers": "spacy.spancat_scorer.v1"},
|
||||
},
|
||||
default_score_weights={"spans_sc_f": 1.0, "spans_sc_p": 0.0, "spans_sc_r": 0.0},
|
||||
)(make_spancat)
|
||||
|
||||
# spancat_singlelabel
|
||||
Language.factory(
|
||||
"spancat_singlelabel",
|
||||
assigns=["doc.spans"],
|
||||
default_config={
|
||||
"spans_key": DEFAULT_SPANS_KEY,
|
||||
"model": DEFAULT_SPANCAT_SINGLELABEL_MODEL,
|
||||
"negative_weight": 1.0,
|
||||
"suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]},
|
||||
"scorer": {"@scorers": "spacy.spancat_scorer.v1"},
|
||||
"allow_overlap": True,
|
||||
},
|
||||
default_score_weights={"spans_sc_f": 1.0, "spans_sc_p": 0.0, "spans_sc_r": 0.0},
|
||||
)(make_spancat_singlelabel)
|
||||
|
||||
# future_entity_ruler
|
||||
Language.factory(
|
||||
"future_entity_ruler",
|
||||
assigns=["doc.ents"],
|
||||
default_config={
|
||||
"phrase_matcher_attr": None,
|
||||
"validate": False,
|
||||
"overwrite_ents": False,
|
||||
"scorer": {"@scorers": "spacy.entity_ruler_scorer.v1"},
|
||||
"ent_id_sep": "__unused__",
|
||||
"matcher_fuzzy_compare": {"@misc": "spacy.levenshtein_compare.v1"},
|
||||
},
|
||||
default_score_weights={
|
||||
"ents_f": 1.0,
|
||||
"ents_p": 0.0,
|
||||
"ents_r": 0.0,
|
||||
"ents_per_type": None,
|
||||
},
|
||||
)(make_future_entity_ruler)
|
||||
|
||||
# span_ruler
|
||||
Language.factory(
|
||||
"span_ruler",
|
||||
assigns=["doc.spans"],
|
||||
default_config={
|
||||
"spans_key": SPAN_RULER_DEFAULT_SPANS_KEY,
|
||||
"spans_filter": None,
|
||||
"annotate_ents": False,
|
||||
"ents_filter": {"@misc": "spacy.first_longest_spans_filter.v1"},
|
||||
"phrase_matcher_attr": None,
|
||||
"matcher_fuzzy_compare": {"@misc": "spacy.levenshtein_compare.v1"},
|
||||
"validate": False,
|
||||
"overwrite": True,
|
||||
"scorer": {
|
||||
"@scorers": "spacy.overlapping_labeled_spans_scorer.v1",
|
||||
"spans_key": SPAN_RULER_DEFAULT_SPANS_KEY,
|
||||
},
|
||||
},
|
||||
default_score_weights={
|
||||
f"spans_{SPAN_RULER_DEFAULT_SPANS_KEY}_f": 1.0,
|
||||
f"spans_{SPAN_RULER_DEFAULT_SPANS_KEY}_p": 0.0,
|
||||
f"spans_{SPAN_RULER_DEFAULT_SPANS_KEY}_r": 0.0,
|
||||
f"spans_{SPAN_RULER_DEFAULT_SPANS_KEY}_per_type": None,
|
||||
},
|
||||
)(make_span_ruler)
|
||||
|
||||
# trainable_lemmatizer
|
||||
Language.factory(
|
||||
"trainable_lemmatizer",
|
||||
assigns=["token.lemma"],
|
||||
requires=[],
|
||||
default_config={
|
||||
"model": DEFAULT_EDIT_TREE_LEMMATIZER_MODEL,
|
||||
"backoff": "orth",
|
||||
"min_tree_freq": 3,
|
||||
"overwrite": False,
|
||||
"top_k": 1,
|
||||
"scorer": {"@scorers": "spacy.lemmatizer_scorer.v1"},
|
||||
},
|
||||
default_score_weights={"lemma_acc": 1.0},
|
||||
)(make_edit_tree_lemmatizer)
|
||||
|
||||
# textcat_multilabel
|
||||
Language.factory(
|
||||
"textcat_multilabel",
|
||||
assigns=["doc.cats"],
|
||||
default_config={
|
||||
"threshold": 0.5,
|
||||
"model": DEFAULT_MULTI_TEXTCAT_MODEL,
|
||||
"scorer": {"@scorers": "spacy.textcat_multilabel_scorer.v2"},
|
||||
},
|
||||
default_score_weights={
|
||||
"cats_score": 1.0,
|
||||
"cats_score_desc": None,
|
||||
"cats_micro_p": None,
|
||||
"cats_micro_r": None,
|
||||
"cats_micro_f": None,
|
||||
"cats_macro_p": None,
|
||||
"cats_macro_r": None,
|
||||
"cats_macro_f": None,
|
||||
"cats_macro_auc": None,
|
||||
"cats_f_per_type": None,
|
||||
},
|
||||
)(make_multilabel_textcat)
|
||||
|
||||
# span_finder
|
||||
Language.factory(
|
||||
"span_finder",
|
||||
assigns=["doc.spans"],
|
||||
default_config={
|
||||
"threshold": 0.5,
|
||||
"model": DEFAULT_SPAN_FINDER_MODEL,
|
||||
"spans_key": DEFAULT_SPANS_KEY,
|
||||
"max_length": 25,
|
||||
"min_length": None,
|
||||
"scorer": {"@scorers": "spacy.span_finder_scorer.v1"},
|
||||
},
|
||||
default_score_weights={
|
||||
f"spans_{DEFAULT_SPANS_KEY}_f": 1.0,
|
||||
f"spans_{DEFAULT_SPANS_KEY}_p": 0.0,
|
||||
f"spans_{DEFAULT_SPANS_KEY}_r": 0.0,
|
||||
},
|
||||
)(make_span_finder)
|
||||
|
||||
# ner
|
||||
Language.factory(
|
||||
"ner",
|
||||
assigns=["doc.ents", "token.ent_iob", "token.ent_type"],
|
||||
default_config={
|
||||
"moves": None,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
"model": DEFAULT_NER_MODEL,
|
||||
"incorrect_spans_key": None,
|
||||
"scorer": {"@scorers": "spacy.ner_scorer.v1"},
|
||||
},
|
||||
default_score_weights={
|
||||
"ents_f": 1.0,
|
||||
"ents_p": 0.0,
|
||||
"ents_r": 0.0,
|
||||
"ents_per_type": None,
|
||||
},
|
||||
)(make_ner)
|
||||
|
||||
# beam_ner
|
||||
Language.factory(
|
||||
"beam_ner",
|
||||
assigns=["doc.ents", "token.ent_iob", "token.ent_type"],
|
||||
default_config={
|
||||
"moves": None,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
"model": DEFAULT_NER_MODEL,
|
||||
"beam_density": 0.01,
|
||||
"beam_update_prob": 0.5,
|
||||
"beam_width": 32,
|
||||
"incorrect_spans_key": None,
|
||||
"scorer": {"@scorers": "spacy.ner_scorer.v1"},
|
||||
},
|
||||
default_score_weights={
|
||||
"ents_f": 1.0,
|
||||
"ents_p": 0.0,
|
||||
"ents_r": 0.0,
|
||||
"ents_per_type": None,
|
||||
},
|
||||
)(make_beam_ner)
|
||||
|
||||
# parser
|
||||
Language.factory(
|
||||
"parser",
|
||||
assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"],
|
||||
default_config={
|
||||
"moves": None,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
"learn_tokens": False,
|
||||
"min_action_freq": 30,
|
||||
"model": DEFAULT_PARSER_MODEL,
|
||||
"scorer": {"@scorers": "spacy.parser_scorer.v1"},
|
||||
},
|
||||
default_score_weights={
|
||||
"dep_uas": 0.5,
|
||||
"dep_las": 0.5,
|
||||
"dep_las_per_type": None,
|
||||
"sents_p": None,
|
||||
"sents_r": None,
|
||||
"sents_f": 0.0,
|
||||
},
|
||||
)(make_parser)
|
||||
|
||||
# beam_parser
|
||||
Language.factory(
|
||||
"beam_parser",
|
||||
assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"],
|
||||
default_config={
|
||||
"moves": None,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
"learn_tokens": False,
|
||||
"min_action_freq": 30,
|
||||
"beam_width": 8,
|
||||
"beam_density": 0.0001,
|
||||
"beam_update_prob": 0.5,
|
||||
"model": DEFAULT_PARSER_MODEL,
|
||||
"scorer": {"@scorers": "spacy.parser_scorer.v1"},
|
||||
},
|
||||
default_score_weights={
|
||||
"dep_uas": 0.5,
|
||||
"dep_las": 0.5,
|
||||
"dep_las_per_type": None,
|
||||
"sents_p": None,
|
||||
"sents_r": None,
|
||||
"sents_f": 0.0,
|
||||
},
|
||||
)(make_beam_parser)
|
||||
|
||||
# tagger
|
||||
Language.factory(
|
||||
"tagger",
|
||||
assigns=["token.tag"],
|
||||
default_config={
|
||||
"model": DEFAULT_TAGGER_MODEL,
|
||||
"overwrite": False,
|
||||
"scorer": {"@scorers": "spacy.tagger_scorer.v1"},
|
||||
"neg_prefix": "!",
|
||||
"label_smoothing": 0.0,
|
||||
},
|
||||
default_score_weights={
|
||||
"tag_acc": 1.0,
|
||||
"pos_acc": 0.0,
|
||||
"tag_micro_p": None,
|
||||
"tag_micro_r": None,
|
||||
"tag_micro_f": None,
|
||||
},
|
||||
)(make_tagger)
|
||||
|
||||
# nn_labeller
|
||||
Language.factory(
|
||||
"nn_labeller",
|
||||
default_config={
|
||||
"labels": None,
|
||||
"target": "dep_tag_offset",
|
||||
"model": DEFAULT_MT_MODEL,
|
||||
},
|
||||
)(make_nn_labeller)
|
||||
|
||||
# sentencizer
|
||||
Language.factory(
|
||||
"sentencizer",
|
||||
assigns=["token.is_sent_start", "doc.sents"],
|
||||
default_config={
|
||||
"punct_chars": None,
|
||||
"overwrite": False,
|
||||
"scorer": {"@scorers": "spacy.senter_scorer.v1"},
|
||||
},
|
||||
default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
|
||||
)(make_sentencizer)
|
||||
|
||||
# Set the flag to indicate that all factories have been registered
|
||||
FACTORIES_REGISTERED = True
|
||||
|
||||
|
||||
# We can't have function implementations for these factories in Cython, because
|
||||
# we need to build a Pydantic model for them dynamically, reading their argument
|
||||
# structure from the signature. In Cython 3, this doesn't work because the
|
||||
# from __future__ import annotations semantics are used, which means the types
|
||||
# are stored as strings.
|
||||
def make_sentencizer(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
punct_chars: Optional[List[str]],
|
||||
overwrite: bool,
|
||||
scorer: Optional[Callable],
|
||||
):
|
||||
return Sentencizer(
|
||||
name, punct_chars=punct_chars, overwrite=overwrite, scorer=scorer
|
||||
)
|
||||
|
||||
|
||||
def make_attribute_ruler(
|
||||
nlp: Language, name: str, validate: bool, scorer: Optional[Callable]
|
||||
):
|
||||
return AttributeRuler(nlp.vocab, name, validate=validate, scorer=scorer)
|
||||
|
||||
|
||||
def make_entity_linker(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
model: Model,
|
||||
*,
|
||||
labels_discard: Iterable[str],
|
||||
n_sents: int,
|
||||
incl_prior: bool,
|
||||
incl_context: bool,
|
||||
entity_vector_length: int,
|
||||
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
|
||||
get_candidates_batch: Callable[
|
||||
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
|
||||
],
|
||||
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
|
||||
overwrite: bool,
|
||||
scorer: Optional[Callable],
|
||||
use_gold_ents: bool,
|
||||
candidates_batch_size: int,
|
||||
threshold: Optional[float] = None,
|
||||
):
|
||||
|
||||
if not model.attrs.get("include_span_maker", False):
|
||||
# The only difference in arguments here is that use_gold_ents and threshold aren't available.
|
||||
return EntityLinker_v1(
|
||||
nlp.vocab,
|
||||
model,
|
||||
name,
|
||||
labels_discard=labels_discard,
|
||||
n_sents=n_sents,
|
||||
incl_prior=incl_prior,
|
||||
incl_context=incl_context,
|
||||
entity_vector_length=entity_vector_length,
|
||||
get_candidates=get_candidates,
|
||||
overwrite=overwrite,
|
||||
scorer=scorer,
|
||||
)
|
||||
return EntityLinker(
|
||||
nlp.vocab,
|
||||
model,
|
||||
name,
|
||||
labels_discard=labels_discard,
|
||||
n_sents=n_sents,
|
||||
incl_prior=incl_prior,
|
||||
incl_context=incl_context,
|
||||
entity_vector_length=entity_vector_length,
|
||||
get_candidates=get_candidates,
|
||||
get_candidates_batch=get_candidates_batch,
|
||||
generate_empty_kb=generate_empty_kb,
|
||||
overwrite=overwrite,
|
||||
scorer=scorer,
|
||||
use_gold_ents=use_gold_ents,
|
||||
candidates_batch_size=candidates_batch_size,
|
||||
threshold=threshold,
|
||||
)
|
||||
|
||||
|
||||
def make_lemmatizer(
|
||||
nlp: Language,
|
||||
model: Optional[Model],
|
||||
name: str,
|
||||
mode: str,
|
||||
overwrite: bool,
|
||||
scorer: Optional[Callable],
|
||||
):
|
||||
return Lemmatizer(
|
||||
nlp.vocab, model, name, mode=mode, overwrite=overwrite, scorer=scorer
|
||||
)
|
||||
|
||||
|
||||
def make_textcat(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
model: Model[List[Doc], List[Floats2d]],
|
||||
threshold: float,
|
||||
scorer: Optional[Callable],
|
||||
) -> TextCategorizer:
|
||||
return TextCategorizer(nlp.vocab, model, name, threshold=threshold, scorer=scorer)
|
||||
|
||||
|
||||
def make_token_splitter(
|
||||
nlp: Language, name: str, *, min_length: int = 0, split_length: int = 0
|
||||
):
|
||||
return TokenSplitter(min_length=min_length, split_length=split_length)
|
||||
|
||||
|
||||
def make_doc_cleaner(nlp: Language, name: str, *, attrs: Dict[str, Any], silent: bool):
|
||||
return DocCleaner(attrs, silent=silent)
|
||||
|
||||
|
||||
def make_tok2vec(nlp: Language, name: str, model: Model) -> Tok2Vec:
|
||||
return Tok2Vec(nlp.vocab, model, name)
|
||||
|
||||
|
||||
def make_spancat(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
suggester: Suggester,
|
||||
model: Model[Tuple[List[Doc], Ragged], Floats2d],
|
||||
spans_key: str,
|
||||
scorer: Optional[Callable],
|
||||
threshold: float,
|
||||
max_positive: Optional[int],
|
||||
) -> SpanCategorizer:
|
||||
return SpanCategorizer(
|
||||
nlp.vocab,
|
||||
model=model,
|
||||
suggester=suggester,
|
||||
name=name,
|
||||
spans_key=spans_key,
|
||||
negative_weight=None,
|
||||
allow_overlap=True,
|
||||
max_positive=max_positive,
|
||||
threshold=threshold,
|
||||
scorer=scorer,
|
||||
add_negative_label=False,
|
||||
)
|
||||
|
||||
|
||||
def make_spancat_singlelabel(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
suggester: Suggester,
|
||||
model: Model[Tuple[List[Doc], Ragged], Floats2d],
|
||||
spans_key: str,
|
||||
negative_weight: float,
|
||||
allow_overlap: bool,
|
||||
scorer: Optional[Callable],
|
||||
) -> SpanCategorizer:
|
||||
return SpanCategorizer(
|
||||
nlp.vocab,
|
||||
model=model,
|
||||
suggester=suggester,
|
||||
name=name,
|
||||
spans_key=spans_key,
|
||||
negative_weight=negative_weight,
|
||||
allow_overlap=allow_overlap,
|
||||
max_positive=1,
|
||||
add_negative_label=True,
|
||||
threshold=None,
|
||||
scorer=scorer,
|
||||
)
|
||||
|
||||
|
||||
def make_future_entity_ruler(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
phrase_matcher_attr: Optional[Union[int, str]],
|
||||
matcher_fuzzy_compare: Callable,
|
||||
validate: bool,
|
||||
overwrite_ents: bool,
|
||||
scorer: Optional[Callable],
|
||||
ent_id_sep: str,
|
||||
):
|
||||
if overwrite_ents:
|
||||
ents_filter = prioritize_new_ents_filter
|
||||
else:
|
||||
ents_filter = prioritize_existing_ents_filter
|
||||
return SpanRuler(
|
||||
nlp,
|
||||
name,
|
||||
spans_key=None,
|
||||
spans_filter=None,
|
||||
annotate_ents=True,
|
||||
ents_filter=ents_filter,
|
||||
phrase_matcher_attr=phrase_matcher_attr,
|
||||
matcher_fuzzy_compare=matcher_fuzzy_compare,
|
||||
validate=validate,
|
||||
overwrite=False,
|
||||
scorer=scorer,
|
||||
)
|
||||
|
||||
|
||||
def make_entity_ruler(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
phrase_matcher_attr: Optional[Union[int, str]],
|
||||
matcher_fuzzy_compare: Callable,
|
||||
validate: bool,
|
||||
overwrite_ents: bool,
|
||||
ent_id_sep: str,
|
||||
scorer: Optional[Callable],
|
||||
):
|
||||
return EntityRuler(
|
||||
nlp,
|
||||
name,
|
||||
phrase_matcher_attr=phrase_matcher_attr,
|
||||
matcher_fuzzy_compare=matcher_fuzzy_compare,
|
||||
validate=validate,
|
||||
overwrite_ents=overwrite_ents,
|
||||
ent_id_sep=ent_id_sep,
|
||||
scorer=scorer,
|
||||
)
|
||||
|
||||
|
||||
def make_span_ruler(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
spans_key: Optional[str],
|
||||
spans_filter: Optional[Callable[[Iterable[Span], Iterable[Span]], Iterable[Span]]],
|
||||
annotate_ents: bool,
|
||||
ents_filter: Callable[[Iterable[Span], Iterable[Span]], Iterable[Span]],
|
||||
phrase_matcher_attr: Optional[Union[int, str]],
|
||||
matcher_fuzzy_compare: Callable,
|
||||
validate: bool,
|
||||
overwrite: bool,
|
||||
scorer: Optional[Callable],
|
||||
):
|
||||
return SpanRuler(
|
||||
nlp,
|
||||
name,
|
||||
spans_key=spans_key,
|
||||
spans_filter=spans_filter,
|
||||
annotate_ents=annotate_ents,
|
||||
ents_filter=ents_filter,
|
||||
phrase_matcher_attr=phrase_matcher_attr,
|
||||
matcher_fuzzy_compare=matcher_fuzzy_compare,
|
||||
validate=validate,
|
||||
overwrite=overwrite,
|
||||
scorer=scorer,
|
||||
)
|
||||
|
||||
|
||||
def make_edit_tree_lemmatizer(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
model: Model,
|
||||
backoff: Optional[str],
|
||||
min_tree_freq: int,
|
||||
overwrite: bool,
|
||||
top_k: int,
|
||||
scorer: Optional[Callable],
|
||||
):
|
||||
return EditTreeLemmatizer(
|
||||
nlp.vocab,
|
||||
model,
|
||||
name,
|
||||
backoff=backoff,
|
||||
min_tree_freq=min_tree_freq,
|
||||
overwrite=overwrite,
|
||||
top_k=top_k,
|
||||
scorer=scorer,
|
||||
)
|
||||
|
||||
|
||||
def make_multilabel_textcat(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
model: Model[List[Doc], List[Floats2d]],
|
||||
threshold: float,
|
||||
scorer: Optional[Callable],
|
||||
) -> MultiLabel_TextCategorizer:
|
||||
return MultiLabel_TextCategorizer(
|
||||
nlp.vocab, model, name, threshold=threshold, scorer=scorer
|
||||
)
|
||||
|
||||
|
||||
def make_span_finder(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
model: Model[Iterable[Doc], Floats2d],
|
||||
spans_key: str,
|
||||
threshold: float,
|
||||
max_length: Optional[int],
|
||||
min_length: Optional[int],
|
||||
scorer: Optional[Callable],
|
||||
) -> SpanFinder:
|
||||
return SpanFinder(
|
||||
nlp,
|
||||
model=model,
|
||||
threshold=threshold,
|
||||
name=name,
|
||||
scorer=scorer,
|
||||
max_length=max_length,
|
||||
min_length=min_length,
|
||||
spans_key=spans_key,
|
||||
)
|
||||
|
||||
|
||||
def make_ner(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
model: Model,
|
||||
moves: Optional[TransitionSystem],
|
||||
update_with_oracle_cut_size: int,
|
||||
incorrect_spans_key: Optional[str],
|
||||
scorer: Optional[Callable],
|
||||
):
|
||||
return EntityRecognizer(
|
||||
nlp.vocab,
|
||||
model,
|
||||
name=name,
|
||||
moves=moves,
|
||||
update_with_oracle_cut_size=update_with_oracle_cut_size,
|
||||
incorrect_spans_key=incorrect_spans_key,
|
||||
scorer=scorer,
|
||||
)
|
||||
|
||||
|
||||
def make_beam_ner(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
model: Model,
|
||||
moves: Optional[TransitionSystem],
|
||||
update_with_oracle_cut_size: int,
|
||||
beam_width: int,
|
||||
beam_density: float,
|
||||
beam_update_prob: float,
|
||||
incorrect_spans_key: Optional[str],
|
||||
scorer: Optional[Callable],
|
||||
):
|
||||
return EntityRecognizer(
|
||||
nlp.vocab,
|
||||
model,
|
||||
name=name,
|
||||
moves=moves,
|
||||
update_with_oracle_cut_size=update_with_oracle_cut_size,
|
||||
beam_width=beam_width,
|
||||
beam_density=beam_density,
|
||||
beam_update_prob=beam_update_prob,
|
||||
incorrect_spans_key=incorrect_spans_key,
|
||||
scorer=scorer,
|
||||
)
|
||||
|
||||
|
||||
def make_parser(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
model: Model,
|
||||
moves: Optional[TransitionSystem],
|
||||
update_with_oracle_cut_size: int,
|
||||
learn_tokens: bool,
|
||||
min_action_freq: int,
|
||||
scorer: Optional[Callable],
|
||||
):
|
||||
return DependencyParser(
|
||||
nlp.vocab,
|
||||
model,
|
||||
name=name,
|
||||
moves=moves,
|
||||
update_with_oracle_cut_size=update_with_oracle_cut_size,
|
||||
learn_tokens=learn_tokens,
|
||||
min_action_freq=min_action_freq,
|
||||
scorer=scorer,
|
||||
)
|
||||
|
||||
|
||||
def make_beam_parser(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
model: Model,
|
||||
moves: Optional[TransitionSystem],
|
||||
update_with_oracle_cut_size: int,
|
||||
learn_tokens: bool,
|
||||
min_action_freq: int,
|
||||
beam_width: int,
|
||||
beam_density: float,
|
||||
beam_update_prob: float,
|
||||
scorer: Optional[Callable],
|
||||
):
|
||||
return DependencyParser(
|
||||
nlp.vocab,
|
||||
model,
|
||||
name=name,
|
||||
moves=moves,
|
||||
update_with_oracle_cut_size=update_with_oracle_cut_size,
|
||||
learn_tokens=learn_tokens,
|
||||
min_action_freq=min_action_freq,
|
||||
beam_width=beam_width,
|
||||
beam_density=beam_density,
|
||||
beam_update_prob=beam_update_prob,
|
||||
scorer=scorer,
|
||||
)
|
||||
|
||||
|
||||
def make_tagger(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
model: Model,
|
||||
overwrite: bool,
|
||||
scorer: Optional[Callable],
|
||||
neg_prefix: str,
|
||||
label_smoothing: float,
|
||||
):
|
||||
return Tagger(
|
||||
nlp.vocab,
|
||||
model,
|
||||
name=name,
|
||||
overwrite=overwrite,
|
||||
scorer=scorer,
|
||||
neg_prefix=neg_prefix,
|
||||
label_smoothing=label_smoothing,
|
||||
)
|
||||
|
||||
|
||||
def make_nn_labeller(
|
||||
nlp: Language, name: str, model: Model, labels: Optional[dict], target: str
|
||||
):
|
||||
return MultitaskObjective(nlp.vocab, model, name, target=target)
|
||||
|
||||
|
||||
def make_morphologizer(
|
||||
nlp: Language,
|
||||
model: Model,
|
||||
name: str,
|
||||
overwrite: bool,
|
||||
extend: bool,
|
||||
label_smoothing: float,
|
||||
scorer: Optional[Callable],
|
||||
):
|
||||
return Morphologizer(
|
||||
nlp.vocab,
|
||||
model,
|
||||
name,
|
||||
overwrite=overwrite,
|
||||
extend=extend,
|
||||
label_smoothing=label_smoothing,
|
||||
scorer=scorer,
|
||||
)
|
||||
|
||||
|
||||
def make_senter(
|
||||
nlp: Language, name: str, model: Model, overwrite: bool, scorer: Optional[Callable]
|
||||
):
|
||||
return SentenceRecognizer(
|
||||
nlp.vocab, model, name, overwrite=overwrite, scorer=scorer
|
||||
)
|
|
@ -1,3 +1,5 @@
|
|||
import importlib
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Any, Dict
|
||||
|
||||
|
@ -73,17 +75,6 @@ def merge_subtokens(doc: Doc, label: str = "subtok") -> Doc:
|
|||
return doc
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"token_splitter",
|
||||
default_config={"min_length": 25, "split_length": 10},
|
||||
retokenizes=True,
|
||||
)
|
||||
def make_token_splitter(
|
||||
nlp: Language, name: str, *, min_length: int = 0, split_length: int = 0
|
||||
):
|
||||
return TokenSplitter(min_length=min_length, split_length=split_length)
|
||||
|
||||
|
||||
class TokenSplitter:
|
||||
def __init__(self, min_length: int = 0, split_length: int = 0):
|
||||
self.min_length = min_length
|
||||
|
@ -141,14 +132,6 @@ class TokenSplitter:
|
|||
util.from_disk(path, serializers, [])
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"doc_cleaner",
|
||||
default_config={"attrs": {"tensor": None, "_.trf_data": None}, "silent": True},
|
||||
)
|
||||
def make_doc_cleaner(nlp: Language, name: str, *, attrs: Dict[str, Any], silent: bool):
|
||||
return DocCleaner(attrs, silent=silent)
|
||||
|
||||
|
||||
class DocCleaner:
|
||||
def __init__(self, attrs: Dict[str, Any], *, silent: bool = True):
|
||||
self.cfg: Dict[str, Any] = {"attrs": dict(attrs), "silent": silent}
|
||||
|
@ -201,3 +184,14 @@ class DocCleaner:
|
|||
"cfg": lambda p: self.cfg.update(srsly.read_json(p)),
|
||||
}
|
||||
util.from_disk(path, serializers, [])
|
||||
|
||||
|
||||
# Setup backwards compatibility hook for factories
|
||||
def __getattr__(name):
|
||||
if name == "make_doc_cleaner":
|
||||
module = importlib.import_module("spacy.pipeline.factories")
|
||||
return module.make_doc_cleaner
|
||||
elif name == "make_token_splitter":
|
||||
module = importlib.import_module("spacy.pipeline.factories")
|
||||
return module.make_token_splitter
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import importlib
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
@ -16,35 +18,10 @@ from ..vocab import Vocab
|
|||
from .pipe import Pipe
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"lemmatizer",
|
||||
assigns=["token.lemma"],
|
||||
default_config={
|
||||
"model": None,
|
||||
"mode": "lookup",
|
||||
"overwrite": False,
|
||||
"scorer": {"@scorers": "spacy.lemmatizer_scorer.v1"},
|
||||
},
|
||||
default_score_weights={"lemma_acc": 1.0},
|
||||
)
|
||||
def make_lemmatizer(
|
||||
nlp: Language,
|
||||
model: Optional[Model],
|
||||
name: str,
|
||||
mode: str,
|
||||
overwrite: bool,
|
||||
scorer: Optional[Callable],
|
||||
):
|
||||
return Lemmatizer(
|
||||
nlp.vocab, model, name, mode=mode, overwrite=overwrite, scorer=scorer
|
||||
)
|
||||
|
||||
|
||||
def lemmatizer_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
||||
return Scorer.score_token_attr(examples, "lemma", **kwargs)
|
||||
|
||||
|
||||
@registry.scorers("spacy.lemmatizer_scorer.v1")
|
||||
def make_lemmatizer_scorer():
|
||||
return lemmatizer_score
|
||||
|
||||
|
@ -334,3 +311,11 @@ class Lemmatizer(Pipe):
|
|||
util.from_bytes(bytes_data, deserialize, exclude)
|
||||
self._validate_tables()
|
||||
return self
|
||||
|
||||
|
||||
# Setup backwards compatibility hook for factories
|
||||
def __getattr__(name):
|
||||
if name == "make_lemmatizer":
|
||||
module = importlib.import_module("spacy.pipeline.factories")
|
||||
return module.make_lemmatizer
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
# cython: infer_types=True, binding=True
|
||||
import importlib
|
||||
import sys
|
||||
from itertools import islice
|
||||
from typing import Callable, Dict, Optional, Union
|
||||
|
||||
|
@ -47,25 +49,6 @@ maxout_pieces = 3
|
|||
DEFAULT_MORPH_MODEL = Config().from_str(default_model_config)["model"]
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"morphologizer",
|
||||
assigns=["token.morph", "token.pos"],
|
||||
default_config={"model": DEFAULT_MORPH_MODEL, "overwrite": True, "extend": False,
|
||||
"scorer": {"@scorers": "spacy.morphologizer_scorer.v1"}, "label_smoothing": 0.0},
|
||||
default_score_weights={"pos_acc": 0.5, "morph_acc": 0.5, "morph_per_feat": None},
|
||||
)
|
||||
def make_morphologizer(
|
||||
nlp: Language,
|
||||
model: Model,
|
||||
name: str,
|
||||
overwrite: bool,
|
||||
extend: bool,
|
||||
label_smoothing: float,
|
||||
scorer: Optional[Callable],
|
||||
):
|
||||
return Morphologizer(nlp.vocab, model, name, overwrite=overwrite, extend=extend, label_smoothing=label_smoothing, scorer=scorer)
|
||||
|
||||
|
||||
def morphologizer_score(examples, **kwargs):
|
||||
def morph_key_getter(token, attr):
|
||||
return getattr(token, attr).key
|
||||
|
@ -81,7 +64,6 @@ def morphologizer_score(examples, **kwargs):
|
|||
return results
|
||||
|
||||
|
||||
@registry.scorers("spacy.morphologizer_scorer.v1")
|
||||
def make_morphologizer_scorer():
|
||||
return morphologizer_score
|
||||
|
||||
|
@ -309,3 +291,11 @@ class Morphologizer(Tagger):
|
|||
if self.model.ops.xp.isnan(loss):
|
||||
raise ValueError(Errors.E910.format(name=self.name))
|
||||
return float(loss), d_scores
|
||||
|
||||
|
||||
# Setup backwards compatibility hook for factories
|
||||
def __getattr__(name):
|
||||
if name == "make_morphologizer":
|
||||
module = importlib.import_module("spacy.pipeline.factories")
|
||||
return module.make_morphologizer
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
# cython: infer_types=True, binding=True
|
||||
import importlib
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
import numpy
|
||||
|
@ -30,14 +32,6 @@ subword_features = true
|
|||
DEFAULT_MT_MODEL = Config().from_str(default_model_config)["model"]
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"nn_labeller",
|
||||
default_config={"labels": None, "target": "dep_tag_offset", "model": DEFAULT_MT_MODEL}
|
||||
)
|
||||
def make_nn_labeller(nlp: Language, name: str, model: Model, labels: Optional[dict], target: str):
|
||||
return MultitaskObjective(nlp.vocab, model, name)
|
||||
|
||||
|
||||
class MultitaskObjective(Tagger):
|
||||
"""Experimental: Assist training of a parser or tagger, by training a
|
||||
side-objective.
|
||||
|
@ -213,3 +207,11 @@ class ClozeMultitask(TrainablePipe):
|
|||
|
||||
def add_label(self, label):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Setup backwards compatibility hook for factories
|
||||
def __getattr__(name):
|
||||
if name == "make_nn_labeller":
|
||||
module = importlib.import_module("spacy.pipeline.factories")
|
||||
return module.make_nn_labeller
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
# cython: infer_types=True, binding=True
|
||||
import importlib
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from typing import Callable, Optional
|
||||
|
||||
|
@ -36,154 +38,10 @@ subword_features = true
|
|||
DEFAULT_NER_MODEL = Config().from_str(default_model_config)["model"]
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"ner",
|
||||
assigns=["doc.ents", "token.ent_iob", "token.ent_type"],
|
||||
default_config={
|
||||
"moves": None,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
"model": DEFAULT_NER_MODEL,
|
||||
"incorrect_spans_key": None,
|
||||
"scorer": {"@scorers": "spacy.ner_scorer.v1"},
|
||||
},
|
||||
default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0, "ents_per_type": None},
|
||||
|
||||
)
|
||||
def make_ner(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
model: Model,
|
||||
moves: Optional[TransitionSystem],
|
||||
update_with_oracle_cut_size: int,
|
||||
incorrect_spans_key: Optional[str],
|
||||
scorer: Optional[Callable],
|
||||
):
|
||||
"""Create a transition-based EntityRecognizer component. The entity recognizer
|
||||
identifies non-overlapping labelled spans of tokens.
|
||||
|
||||
The transition-based algorithm used encodes certain assumptions that are
|
||||
effective for "traditional" named entity recognition tasks, but may not be
|
||||
a good fit for every span identification problem. Specifically, the loss
|
||||
function optimizes for whole entity accuracy, so if your inter-annotator
|
||||
agreement on boundary tokens is low, the component will likely perform poorly
|
||||
on your problem. The transition-based algorithm also assumes that the most
|
||||
decisive information about your entities will be close to their initial tokens.
|
||||
If your entities are long and characterised by tokens in their middle, the
|
||||
component will likely do poorly on your task.
|
||||
|
||||
model (Model): The model for the transition-based parser. The model needs
|
||||
to have a specific substructure of named components --- see the
|
||||
spacy.ml.tb_framework.TransitionModel for details.
|
||||
moves (Optional[TransitionSystem]): This defines how the parse-state is created,
|
||||
updated and evaluated. If 'moves' is None, a new instance is
|
||||
created with `self.TransitionSystem()`. Defaults to `None`.
|
||||
update_with_oracle_cut_size (int): During training, cut long sequences into
|
||||
shorter segments by creating intermediate states based on the gold-standard
|
||||
history. The model is not very sensitive to this parameter, so you usually
|
||||
won't need to change it. 100 is a good default.
|
||||
incorrect_spans_key (Optional[str]): Identifies spans that are known
|
||||
to be incorrect entity annotations. The incorrect entity annotations
|
||||
can be stored in the span group, under this key.
|
||||
scorer (Optional[Callable]): The scoring method.
|
||||
"""
|
||||
return EntityRecognizer(
|
||||
nlp.vocab,
|
||||
model,
|
||||
name,
|
||||
moves=moves,
|
||||
update_with_oracle_cut_size=update_with_oracle_cut_size,
|
||||
incorrect_spans_key=incorrect_spans_key,
|
||||
multitasks=[],
|
||||
beam_width=1,
|
||||
beam_density=0.0,
|
||||
beam_update_prob=0.0,
|
||||
scorer=scorer,
|
||||
)
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"beam_ner",
|
||||
assigns=["doc.ents", "token.ent_iob", "token.ent_type"],
|
||||
default_config={
|
||||
"moves": None,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
"model": DEFAULT_NER_MODEL,
|
||||
"beam_density": 0.01,
|
||||
"beam_update_prob": 0.5,
|
||||
"beam_width": 32,
|
||||
"incorrect_spans_key": None,
|
||||
"scorer": None,
|
||||
},
|
||||
default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0, "ents_per_type": None},
|
||||
)
|
||||
def make_beam_ner(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
model: Model,
|
||||
moves: Optional[TransitionSystem],
|
||||
update_with_oracle_cut_size: int,
|
||||
beam_width: int,
|
||||
beam_density: float,
|
||||
beam_update_prob: float,
|
||||
incorrect_spans_key: Optional[str],
|
||||
scorer: Optional[Callable],
|
||||
):
|
||||
"""Create a transition-based EntityRecognizer component that uses beam-search.
|
||||
The entity recognizer identifies non-overlapping labelled spans of tokens.
|
||||
|
||||
The transition-based algorithm used encodes certain assumptions that are
|
||||
effective for "traditional" named entity recognition tasks, but may not be
|
||||
a good fit for every span identification problem. Specifically, the loss
|
||||
function optimizes for whole entity accuracy, so if your inter-annotator
|
||||
agreement on boundary tokens is low, the component will likely perform poorly
|
||||
on your problem. The transition-based algorithm also assumes that the most
|
||||
decisive information about your entities will be close to their initial tokens.
|
||||
If your entities are long and characterised by tokens in their middle, the
|
||||
component will likely do poorly on your task.
|
||||
|
||||
model (Model): The model for the transition-based parser. The model needs
|
||||
to have a specific substructure of named components --- see the
|
||||
spacy.ml.tb_framework.TransitionModel for details.
|
||||
moves (Optional[TransitionSystem]): This defines how the parse-state is created,
|
||||
updated and evaluated. If 'moves' is None, a new instance is
|
||||
created with `self.TransitionSystem()`. Defaults to `None`.
|
||||
update_with_oracle_cut_size (int): During training, cut long sequences into
|
||||
shorter segments by creating intermediate states based on the gold-standard
|
||||
history. The model is not very sensitive to this parameter, so you usually
|
||||
won't need to change it. 100 is a good default.
|
||||
beam_width (int): The number of candidate analyses to maintain.
|
||||
beam_density (float): The minimum ratio between the scores of the first and
|
||||
last candidates in the beam. This allows the parser to avoid exploring
|
||||
candidates that are too far behind. This is mostly intended to improve
|
||||
efficiency, but it can also improve accuracy as deeper search is not
|
||||
always better.
|
||||
beam_update_prob (float): The chance of making a beam update, instead of a
|
||||
greedy update. Greedy updates are an approximation for the beam updates,
|
||||
and are faster to compute.
|
||||
incorrect_spans_key (Optional[str]): Optional key into span groups of
|
||||
entities known to be non-entities.
|
||||
scorer (Optional[Callable]): The scoring method.
|
||||
"""
|
||||
return EntityRecognizer(
|
||||
nlp.vocab,
|
||||
model,
|
||||
name,
|
||||
moves=moves,
|
||||
update_with_oracle_cut_size=update_with_oracle_cut_size,
|
||||
multitasks=[],
|
||||
beam_width=beam_width,
|
||||
beam_density=beam_density,
|
||||
beam_update_prob=beam_update_prob,
|
||||
incorrect_spans_key=incorrect_spans_key,
|
||||
scorer=scorer,
|
||||
)
|
||||
|
||||
|
||||
def ner_score(examples, **kwargs):
|
||||
return get_ner_prf(examples, **kwargs)
|
||||
|
||||
|
||||
@registry.scorers("spacy.ner_scorer.v1")
|
||||
def make_ner_scorer():
|
||||
return ner_score
|
||||
|
||||
|
@ -261,3 +119,14 @@ cdef class EntityRecognizer(Parser):
|
|||
score_dict[(start, end, label)] += score
|
||||
entity_scores.append(score_dict)
|
||||
return entity_scores
|
||||
|
||||
|
||||
# Setup backwards compatibility hook for factories
|
||||
def __getattr__(name):
|
||||
if name == "make_ner":
|
||||
module = importlib.import_module("spacy.pipeline.factories")
|
||||
return module.make_ner
|
||||
elif name == "make_beam_ner":
|
||||
module = importlib.import_module("spacy.pipeline.factories")
|
||||
return module.make_beam_ner
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
|
|
@ -21,13 +21,6 @@ cdef class Pipe:
|
|||
DOCS: https://spacy.io/api/pipe
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
"""Raise a warning if an inheriting class implements 'begin_training'
|
||||
(from v2) instead of the new 'initialize' method (from v3)"""
|
||||
if hasattr(cls, "begin_training"):
|
||||
warnings.warn(Warnings.W088.format(name=cls.__name__))
|
||||
|
||||
def __call__(self, Doc doc) -> Doc:
|
||||
"""Apply the pipe to one document. The document is modified in place,
|
||||
and returned. This usually happens under the hood when the nlp object
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
# cython: infer_types=True, binding=True
|
||||
import importlib
|
||||
import sys
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import srsly
|
||||
|
@ -14,22 +16,6 @@ from .senter import senter_score
|
|||
BACKWARD_OVERWRITE = False
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"sentencizer",
|
||||
assigns=["token.is_sent_start", "doc.sents"],
|
||||
default_config={"punct_chars": None, "overwrite": False, "scorer": {"@scorers": "spacy.senter_scorer.v1"}},
|
||||
default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
|
||||
)
|
||||
def make_sentencizer(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
punct_chars: Optional[List[str]],
|
||||
overwrite: bool,
|
||||
scorer: Optional[Callable],
|
||||
):
|
||||
return Sentencizer(name, punct_chars=punct_chars, overwrite=overwrite, scorer=scorer)
|
||||
|
||||
|
||||
class Sentencizer(Pipe):
|
||||
"""Segment the Doc into sentences using a rule-based strategy.
|
||||
|
||||
|
@ -181,3 +167,11 @@ class Sentencizer(Pipe):
|
|||
self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars))
|
||||
self.overwrite = cfg.get("overwrite", self.overwrite)
|
||||
return self
|
||||
|
||||
|
||||
# Setup backwards compatibility hook for factories
|
||||
def __getattr__(name):
|
||||
if name == "make_sentencizer":
|
||||
module = importlib.import_module("spacy.pipeline.factories")
|
||||
return module.make_sentencizer
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
# cython: infer_types=True, binding=True
|
||||
import importlib
|
||||
import sys
|
||||
from itertools import islice
|
||||
from typing import Callable, Optional
|
||||
|
||||
|
@ -34,16 +36,6 @@ subword_features = true
|
|||
DEFAULT_SENTER_MODEL = Config().from_str(default_model_config)["model"]
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"senter",
|
||||
assigns=["token.is_sent_start"],
|
||||
default_config={"model": DEFAULT_SENTER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.senter_scorer.v1"}},
|
||||
default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
|
||||
)
|
||||
def make_senter(nlp: Language, name: str, model: Model, overwrite: bool, scorer: Optional[Callable]):
|
||||
return SentenceRecognizer(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer)
|
||||
|
||||
|
||||
def senter_score(examples, **kwargs):
|
||||
def has_sents(doc):
|
||||
return doc.has_annotation("SENT_START")
|
||||
|
@ -53,7 +45,6 @@ def senter_score(examples, **kwargs):
|
|||
return results
|
||||
|
||||
|
||||
@registry.scorers("spacy.senter_scorer.v1")
|
||||
def make_senter_scorer():
|
||||
return senter_score
|
||||
|
||||
|
@ -185,3 +176,11 @@ class SentenceRecognizer(Tagger):
|
|||
|
||||
def add_label(self, label, values=None):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Setup backwards compatibility hook for factories
|
||||
def __getattr__(name):
|
||||
if name == "make_senter":
|
||||
module = importlib.import_module("spacy.pipeline.factories")
|
||||
return module.make_senter
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import importlib
|
||||
import sys
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
from thinc.api import Config, Model, Optimizer, set_dropout_rate
|
||||
|
@ -41,63 +43,6 @@ depth = 4
|
|||
DEFAULT_SPAN_FINDER_MODEL = Config().from_str(span_finder_default_config)["model"]
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"span_finder",
|
||||
assigns=["doc.spans"],
|
||||
default_config={
|
||||
"threshold": 0.5,
|
||||
"model": DEFAULT_SPAN_FINDER_MODEL,
|
||||
"spans_key": DEFAULT_SPANS_KEY,
|
||||
"max_length": 25,
|
||||
"min_length": None,
|
||||
"scorer": {"@scorers": "spacy.span_finder_scorer.v1"},
|
||||
},
|
||||
default_score_weights={
|
||||
f"spans_{DEFAULT_SPANS_KEY}_f": 1.0,
|
||||
f"spans_{DEFAULT_SPANS_KEY}_p": 0.0,
|
||||
f"spans_{DEFAULT_SPANS_KEY}_r": 0.0,
|
||||
},
|
||||
)
|
||||
def make_span_finder(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
model: Model[Iterable[Doc], Floats2d],
|
||||
spans_key: str,
|
||||
threshold: float,
|
||||
max_length: Optional[int],
|
||||
min_length: Optional[int],
|
||||
scorer: Optional[Callable],
|
||||
) -> "SpanFinder":
|
||||
"""Create a SpanFinder component. The component predicts whether a token is
|
||||
the start or the end of a potential span.
|
||||
|
||||
model (Model[List[Doc], Floats2d]): A model instance that
|
||||
is given a list of documents and predicts a probability for each token.
|
||||
spans_key (str): Key of the doc.spans dict to save the spans under. During
|
||||
initialization and training, the component will look for spans on the
|
||||
reference document under the same key.
|
||||
threshold (float): Minimum probability to consider a prediction positive.
|
||||
max_length (Optional[int]): Maximum length of the produced spans, defaults
|
||||
to None meaning unlimited length.
|
||||
min_length (Optional[int]): Minimum length of the produced spans, defaults
|
||||
to None meaning shortest span length is 1.
|
||||
scorer (Optional[Callable]): The scoring method. Defaults to
|
||||
Scorer.score_spans for the Doc.spans[spans_key] with overlapping
|
||||
spans allowed.
|
||||
"""
|
||||
return SpanFinder(
|
||||
nlp,
|
||||
model=model,
|
||||
threshold=threshold,
|
||||
name=name,
|
||||
scorer=scorer,
|
||||
max_length=max_length,
|
||||
min_length=min_length,
|
||||
spans_key=spans_key,
|
||||
)
|
||||
|
||||
|
||||
@registry.scorers("spacy.span_finder_scorer.v1")
|
||||
def make_span_finder_scorer():
|
||||
return span_finder_score
|
||||
|
||||
|
@ -333,3 +278,11 @@ class SpanFinder(TrainablePipe):
|
|||
self.model.initialize(X=docs, Y=Y)
|
||||
else:
|
||||
self.model.initialize()
|
||||
|
||||
|
||||
# Setup backwards compatibility hook for factories
|
||||
def __getattr__(name):
|
||||
if name == "make_span_finder":
|
||||
module = importlib.import_module("spacy.pipeline.factories")
|
||||
return module.make_span_finder
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import importlib
|
||||
import sys
|
||||
import warnings
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
@ -32,105 +34,6 @@ PatternType = Dict[str, Union[str, List[Dict[str, Any]]]]
|
|||
DEFAULT_SPANS_KEY = "ruler"
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"future_entity_ruler",
|
||||
assigns=["doc.ents"],
|
||||
default_config={
|
||||
"phrase_matcher_attr": None,
|
||||
"validate": False,
|
||||
"overwrite_ents": False,
|
||||
"scorer": {"@scorers": "spacy.entity_ruler_scorer.v1"},
|
||||
"ent_id_sep": "__unused__",
|
||||
"matcher_fuzzy_compare": {"@misc": "spacy.levenshtein_compare.v1"},
|
||||
},
|
||||
default_score_weights={
|
||||
"ents_f": 1.0,
|
||||
"ents_p": 0.0,
|
||||
"ents_r": 0.0,
|
||||
"ents_per_type": None,
|
||||
},
|
||||
)
|
||||
def make_entity_ruler(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
phrase_matcher_attr: Optional[Union[int, str]],
|
||||
matcher_fuzzy_compare: Callable,
|
||||
validate: bool,
|
||||
overwrite_ents: bool,
|
||||
scorer: Optional[Callable],
|
||||
ent_id_sep: str,
|
||||
):
|
||||
if overwrite_ents:
|
||||
ents_filter = prioritize_new_ents_filter
|
||||
else:
|
||||
ents_filter = prioritize_existing_ents_filter
|
||||
return SpanRuler(
|
||||
nlp,
|
||||
name,
|
||||
spans_key=None,
|
||||
spans_filter=None,
|
||||
annotate_ents=True,
|
||||
ents_filter=ents_filter,
|
||||
phrase_matcher_attr=phrase_matcher_attr,
|
||||
matcher_fuzzy_compare=matcher_fuzzy_compare,
|
||||
validate=validate,
|
||||
overwrite=False,
|
||||
scorer=scorer,
|
||||
)
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"span_ruler",
|
||||
assigns=["doc.spans"],
|
||||
default_config={
|
||||
"spans_key": DEFAULT_SPANS_KEY,
|
||||
"spans_filter": None,
|
||||
"annotate_ents": False,
|
||||
"ents_filter": {"@misc": "spacy.first_longest_spans_filter.v1"},
|
||||
"phrase_matcher_attr": None,
|
||||
"matcher_fuzzy_compare": {"@misc": "spacy.levenshtein_compare.v1"},
|
||||
"validate": False,
|
||||
"overwrite": True,
|
||||
"scorer": {
|
||||
"@scorers": "spacy.overlapping_labeled_spans_scorer.v1",
|
||||
"spans_key": DEFAULT_SPANS_KEY,
|
||||
},
|
||||
},
|
||||
default_score_weights={
|
||||
f"spans_{DEFAULT_SPANS_KEY}_f": 1.0,
|
||||
f"spans_{DEFAULT_SPANS_KEY}_p": 0.0,
|
||||
f"spans_{DEFAULT_SPANS_KEY}_r": 0.0,
|
||||
f"spans_{DEFAULT_SPANS_KEY}_per_type": None,
|
||||
},
|
||||
)
|
||||
def make_span_ruler(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
spans_key: Optional[str],
|
||||
spans_filter: Optional[Callable[[Iterable[Span], Iterable[Span]], Iterable[Span]]],
|
||||
annotate_ents: bool,
|
||||
ents_filter: Callable[[Iterable[Span], Iterable[Span]], Iterable[Span]],
|
||||
phrase_matcher_attr: Optional[Union[int, str]],
|
||||
matcher_fuzzy_compare: Callable,
|
||||
validate: bool,
|
||||
overwrite: bool,
|
||||
scorer: Optional[Callable],
|
||||
):
|
||||
return SpanRuler(
|
||||
nlp,
|
||||
name,
|
||||
spans_key=spans_key,
|
||||
spans_filter=spans_filter,
|
||||
annotate_ents=annotate_ents,
|
||||
ents_filter=ents_filter,
|
||||
phrase_matcher_attr=phrase_matcher_attr,
|
||||
matcher_fuzzy_compare=matcher_fuzzy_compare,
|
||||
validate=validate,
|
||||
overwrite=overwrite,
|
||||
scorer=scorer,
|
||||
)
|
||||
|
||||
|
||||
def prioritize_new_ents_filter(
|
||||
entities: Iterable[Span], spans: Iterable[Span]
|
||||
) -> List[Span]:
|
||||
|
@ -157,7 +60,6 @@ def prioritize_new_ents_filter(
|
|||
return entities + new_entities
|
||||
|
||||
|
||||
@registry.misc("spacy.prioritize_new_ents_filter.v1")
|
||||
def make_prioritize_new_ents_filter():
|
||||
return prioritize_new_ents_filter
|
||||
|
||||
|
@ -188,7 +90,6 @@ def prioritize_existing_ents_filter(
|
|||
return entities + new_entities
|
||||
|
||||
|
||||
@registry.misc("spacy.prioritize_existing_ents_filter.v1")
|
||||
def make_preserve_existing_ents_filter():
|
||||
return prioritize_existing_ents_filter
|
||||
|
||||
|
@ -208,7 +109,6 @@ def overlapping_labeled_spans_score(
|
|||
return Scorer.score_spans(examples, **kwargs)
|
||||
|
||||
|
||||
@registry.scorers("spacy.overlapping_labeled_spans_scorer.v1")
|
||||
def make_overlapping_labeled_spans_scorer(spans_key: str = DEFAULT_SPANS_KEY):
|
||||
return partial(overlapping_labeled_spans_score, spans_key=spans_key)
|
||||
|
||||
|
@ -595,3 +495,14 @@ class SpanRuler(Pipe):
|
|||
"patterns": lambda p: srsly.write_jsonl(p, self.patterns),
|
||||
}
|
||||
util.to_disk(path, serializers, {})
|
||||
|
||||
|
||||
# Setup backwards compatibility hook for factories
|
||||
def __getattr__(name):
|
||||
if name == "make_span_ruler":
|
||||
module = importlib.import_module("spacy.pipeline.factories")
|
||||
return module.make_span_ruler
|
||||
elif name == "make_entity_ruler":
|
||||
module = importlib.import_module("spacy.pipeline.factories")
|
||||
return module.make_future_entity_ruler
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import importlib
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast
|
||||
|
@ -134,7 +136,6 @@ def preset_spans_suggester(
|
|||
return output
|
||||
|
||||
|
||||
@registry.misc("spacy.ngram_suggester.v1")
|
||||
def build_ngram_suggester(sizes: List[int]) -> Suggester:
|
||||
"""Suggest all spans of the given lengths. Spans are returned as a ragged
|
||||
array of integers. The array has two columns, indicating the start and end
|
||||
|
@ -143,7 +144,6 @@ def build_ngram_suggester(sizes: List[int]) -> Suggester:
|
|||
return partial(ngram_suggester, sizes=sizes)
|
||||
|
||||
|
||||
@registry.misc("spacy.ngram_range_suggester.v1")
|
||||
def build_ngram_range_suggester(min_size: int, max_size: int) -> Suggester:
|
||||
"""Suggest all spans of the given lengths between a given min and max value - both inclusive.
|
||||
Spans are returned as a ragged array of integers. The array has two columns,
|
||||
|
@ -152,7 +152,6 @@ def build_ngram_range_suggester(min_size: int, max_size: int) -> Suggester:
|
|||
return build_ngram_suggester(sizes)
|
||||
|
||||
|
||||
@registry.misc("spacy.preset_spans_suggester.v1")
|
||||
def build_preset_spans_suggester(spans_key: str) -> Suggester:
|
||||
"""Suggest all spans that are already stored in doc.spans[spans_key].
|
||||
This is useful when an upstream component is used to set the spans
|
||||
|
@ -160,136 +159,6 @@ def build_preset_spans_suggester(spans_key: str) -> Suggester:
|
|||
return partial(preset_spans_suggester, spans_key=spans_key)
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"spancat",
|
||||
assigns=["doc.spans"],
|
||||
default_config={
|
||||
"threshold": 0.5,
|
||||
"spans_key": DEFAULT_SPANS_KEY,
|
||||
"max_positive": None,
|
||||
"model": DEFAULT_SPANCAT_MODEL,
|
||||
"suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]},
|
||||
"scorer": {"@scorers": "spacy.spancat_scorer.v1"},
|
||||
},
|
||||
default_score_weights={"spans_sc_f": 1.0, "spans_sc_p": 0.0, "spans_sc_r": 0.0},
|
||||
)
|
||||
def make_spancat(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
suggester: Suggester,
|
||||
model: Model[Tuple[List[Doc], Ragged], Floats2d],
|
||||
spans_key: str,
|
||||
scorer: Optional[Callable],
|
||||
threshold: float,
|
||||
max_positive: Optional[int],
|
||||
) -> "SpanCategorizer":
|
||||
"""Create a SpanCategorizer component and configure it for multi-label
|
||||
classification to be able to assign multiple labels for each span.
|
||||
The span categorizer consists of two
|
||||
parts: a suggester function that proposes candidate spans, and a labeller
|
||||
model that predicts one or more labels for each span.
|
||||
|
||||
name (str): The component instance name, used to add entries to the
|
||||
losses during training.
|
||||
suggester (Callable[[Iterable[Doc], Optional[Ops]], Ragged]): A function that suggests spans.
|
||||
Spans are returned as a ragged array with two integer columns, for the
|
||||
start and end positions.
|
||||
model (Model[Tuple[List[Doc], Ragged], Floats2d]): A model instance that
|
||||
is given a list of documents and (start, end) indices representing
|
||||
candidate span offsets. The model predicts a probability for each category
|
||||
for each span.
|
||||
spans_key (str): Key of the doc.spans dict to save the spans under. During
|
||||
initialization and training, the component will look for spans on the
|
||||
reference document under the same key.
|
||||
scorer (Optional[Callable]): The scoring method. Defaults to
|
||||
Scorer.score_spans for the Doc.spans[spans_key] with overlapping
|
||||
spans allowed.
|
||||
threshold (float): Minimum probability to consider a prediction positive.
|
||||
Spans with a positive prediction will be saved on the Doc. Defaults to
|
||||
0.5.
|
||||
max_positive (Optional[int]): Maximum number of labels to consider positive
|
||||
per span. Defaults to None, indicating no limit.
|
||||
"""
|
||||
return SpanCategorizer(
|
||||
nlp.vocab,
|
||||
model=model,
|
||||
suggester=suggester,
|
||||
name=name,
|
||||
spans_key=spans_key,
|
||||
negative_weight=None,
|
||||
allow_overlap=True,
|
||||
max_positive=max_positive,
|
||||
threshold=threshold,
|
||||
scorer=scorer,
|
||||
add_negative_label=False,
|
||||
)
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"spancat_singlelabel",
|
||||
assigns=["doc.spans"],
|
||||
default_config={
|
||||
"spans_key": DEFAULT_SPANS_KEY,
|
||||
"model": DEFAULT_SPANCAT_SINGLELABEL_MODEL,
|
||||
"negative_weight": 1.0,
|
||||
"suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]},
|
||||
"scorer": {"@scorers": "spacy.spancat_scorer.v1"},
|
||||
"allow_overlap": True,
|
||||
},
|
||||
default_score_weights={"spans_sc_f": 1.0, "spans_sc_p": 0.0, "spans_sc_r": 0.0},
|
||||
)
|
||||
def make_spancat_singlelabel(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
suggester: Suggester,
|
||||
model: Model[Tuple[List[Doc], Ragged], Floats2d],
|
||||
spans_key: str,
|
||||
negative_weight: float,
|
||||
allow_overlap: bool,
|
||||
scorer: Optional[Callable],
|
||||
) -> "SpanCategorizer":
|
||||
"""Create a SpanCategorizer component and configure it for multi-class
|
||||
classification. With this configuration each span can get at most one
|
||||
label. The span categorizer consists of two
|
||||
parts: a suggester function that proposes candidate spans, and a labeller
|
||||
model that predicts one or more labels for each span.
|
||||
|
||||
name (str): The component instance name, used to add entries to the
|
||||
losses during training.
|
||||
suggester (Callable[[Iterable[Doc], Optional[Ops]], Ragged]): A function that suggests spans.
|
||||
Spans are returned as a ragged array with two integer columns, for the
|
||||
start and end positions.
|
||||
model (Model[Tuple[List[Doc], Ragged], Floats2d]): A model instance that
|
||||
is given a list of documents and (start, end) indices representing
|
||||
candidate span offsets. The model predicts a probability for each category
|
||||
for each span.
|
||||
spans_key (str): Key of the doc.spans dict to save the spans under. During
|
||||
initialization and training, the component will look for spans on the
|
||||
reference document under the same key.
|
||||
scorer (Optional[Callable]): The scoring method. Defaults to
|
||||
Scorer.score_spans for the Doc.spans[spans_key] with overlapping
|
||||
spans allowed.
|
||||
negative_weight (float): Multiplier for the loss terms.
|
||||
Can be used to downweight the negative samples if there are too many.
|
||||
allow_overlap (bool): If True the data is assumed to contain overlapping spans.
|
||||
Otherwise it produces non-overlapping spans greedily prioritizing
|
||||
higher assigned label scores.
|
||||
"""
|
||||
return SpanCategorizer(
|
||||
nlp.vocab,
|
||||
model=model,
|
||||
suggester=suggester,
|
||||
name=name,
|
||||
spans_key=spans_key,
|
||||
negative_weight=negative_weight,
|
||||
allow_overlap=allow_overlap,
|
||||
max_positive=1,
|
||||
add_negative_label=True,
|
||||
threshold=None,
|
||||
scorer=scorer,
|
||||
)
|
||||
|
||||
|
||||
def spancat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
||||
kwargs = dict(kwargs)
|
||||
attr_prefix = "spans_"
|
||||
|
@ -303,7 +172,6 @@ def spancat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
|||
return Scorer.score_spans(examples, **kwargs)
|
||||
|
||||
|
||||
@registry.scorers("spacy.spancat_scorer.v1")
|
||||
def make_spancat_scorer():
|
||||
return spancat_score
|
||||
|
||||
|
@ -785,3 +653,14 @@ class SpanCategorizer(TrainablePipe):
|
|||
|
||||
spans.attrs["scores"] = numpy.array(attrs_scores)
|
||||
return spans
|
||||
|
||||
|
||||
# Setup backwards compatibility hook for factories
|
||||
def __getattr__(name):
|
||||
if name == "make_spancat":
|
||||
module = importlib.import_module("spacy.pipeline.factories")
|
||||
return module.make_spancat
|
||||
elif name == "make_spancat_singlelabel":
|
||||
module = importlib.import_module("spacy.pipeline.factories")
|
||||
return module.make_spancat_singlelabel
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
# cython: infer_types=True, binding=True
|
||||
import importlib
|
||||
import sys
|
||||
from itertools import islice
|
||||
from typing import Callable, Optional
|
||||
|
||||
|
@ -35,36 +37,10 @@ subword_features = true
|
|||
DEFAULT_TAGGER_MODEL = Config().from_str(default_model_config)["model"]
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"tagger",
|
||||
assigns=["token.tag"],
|
||||
default_config={"model": DEFAULT_TAGGER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.tagger_scorer.v1"}, "neg_prefix": "!", "label_smoothing": 0.0},
|
||||
default_score_weights={"tag_acc": 1.0},
|
||||
)
|
||||
def make_tagger(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
model: Model,
|
||||
overwrite: bool,
|
||||
scorer: Optional[Callable],
|
||||
neg_prefix: str,
|
||||
label_smoothing: float,
|
||||
):
|
||||
"""Construct a part-of-speech tagger component.
|
||||
|
||||
model (Model[List[Doc], List[Floats2d]]): A model instance that predicts
|
||||
the tag probabilities. The output vectors should match the number of tags
|
||||
in size, and be normalized as probabilities (all scores between 0 and 1,
|
||||
with the rows summing to 1).
|
||||
"""
|
||||
return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, neg_prefix=neg_prefix, label_smoothing=label_smoothing)
|
||||
|
||||
|
||||
def tagger_score(examples, **kwargs):
|
||||
return Scorer.score_token_attr(examples, "tag", **kwargs)
|
||||
|
||||
|
||||
@registry.scorers("spacy.tagger_scorer.v1")
|
||||
def make_tagger_scorer():
|
||||
return tagger_score
|
||||
|
||||
|
@ -317,3 +293,11 @@ class Tagger(TrainablePipe):
|
|||
self.cfg["labels"].append(label)
|
||||
self.vocab.strings.add(label)
|
||||
return 1
|
||||
|
||||
|
||||
# Setup backwards compatibility hook for factories
|
||||
def __getattr__(name):
|
||||
if name == "make_tagger":
|
||||
module = importlib.import_module("spacy.pipeline.factories")
|
||||
return module.make_tagger
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import importlib
|
||||
import sys
|
||||
from itertools import islice
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
|
@ -74,46 +76,6 @@ subword_features = true
|
|||
"""
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"textcat",
|
||||
assigns=["doc.cats"],
|
||||
default_config={
|
||||
"threshold": 0.0,
|
||||
"model": DEFAULT_SINGLE_TEXTCAT_MODEL,
|
||||
"scorer": {"@scorers": "spacy.textcat_scorer.v2"},
|
||||
},
|
||||
default_score_weights={
|
||||
"cats_score": 1.0,
|
||||
"cats_score_desc": None,
|
||||
"cats_micro_p": None,
|
||||
"cats_micro_r": None,
|
||||
"cats_micro_f": None,
|
||||
"cats_macro_p": None,
|
||||
"cats_macro_r": None,
|
||||
"cats_macro_f": None,
|
||||
"cats_macro_auc": None,
|
||||
"cats_f_per_type": None,
|
||||
},
|
||||
)
|
||||
def make_textcat(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
model: Model[List[Doc], List[Floats2d]],
|
||||
threshold: float,
|
||||
scorer: Optional[Callable],
|
||||
) -> "TextCategorizer":
|
||||
"""Create a TextCategorizer component. The text categorizer predicts categories
|
||||
over a whole document. It can learn one or more labels, and the labels are considered
|
||||
to be mutually exclusive (i.e. one true label per doc).
|
||||
|
||||
model (Model[List[Doc], List[Floats2d]]): A model instance that predicts
|
||||
scores for each category.
|
||||
threshold (float): Cutoff to consider a prediction "positive".
|
||||
scorer (Optional[Callable]): The scoring method.
|
||||
"""
|
||||
return TextCategorizer(nlp.vocab, model, name, threshold=threshold, scorer=scorer)
|
||||
|
||||
|
||||
def textcat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
||||
return Scorer.score_cats(
|
||||
examples,
|
||||
|
@ -123,7 +85,6 @@ def textcat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
|||
)
|
||||
|
||||
|
||||
@registry.scorers("spacy.textcat_scorer.v2")
|
||||
def make_textcat_scorer():
|
||||
return textcat_score
|
||||
|
||||
|
@ -412,3 +373,11 @@ class TextCategorizer(TrainablePipe):
|
|||
for val in vals:
|
||||
if not (val == 1.0 or val == 0.0):
|
||||
raise ValueError(Errors.E851.format(val=val))
|
||||
|
||||
|
||||
# Setup backwards compatibility hook for factories
|
||||
def __getattr__(name):
|
||||
if name == "make_textcat":
|
||||
module = importlib.import_module("spacy.pipeline.factories")
|
||||
return module.make_textcat
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import importlib
|
||||
import sys
|
||||
from itertools import islice
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional
|
||||
|
||||
|
@ -72,49 +74,6 @@ subword_features = true
|
|||
"""
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"textcat_multilabel",
|
||||
assigns=["doc.cats"],
|
||||
default_config={
|
||||
"threshold": 0.5,
|
||||
"model": DEFAULT_MULTI_TEXTCAT_MODEL,
|
||||
"scorer": {"@scorers": "spacy.textcat_multilabel_scorer.v2"},
|
||||
},
|
||||
default_score_weights={
|
||||
"cats_score": 1.0,
|
||||
"cats_score_desc": None,
|
||||
"cats_micro_p": None,
|
||||
"cats_micro_r": None,
|
||||
"cats_micro_f": None,
|
||||
"cats_macro_p": None,
|
||||
"cats_macro_r": None,
|
||||
"cats_macro_f": None,
|
||||
"cats_macro_auc": None,
|
||||
"cats_f_per_type": None,
|
||||
},
|
||||
)
|
||||
def make_multilabel_textcat(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
model: Model[List[Doc], List[Floats2d]],
|
||||
threshold: float,
|
||||
scorer: Optional[Callable],
|
||||
) -> "MultiLabel_TextCategorizer":
|
||||
"""Create a MultiLabel_TextCategorizer component. The text categorizer predicts categories
|
||||
over a whole document. It can learn one or more labels, and the labels are considered
|
||||
to be non-mutually exclusive, which means that there can be zero or more labels
|
||||
per doc).
|
||||
|
||||
model (Model[List[Doc], List[Floats2d]]): A model instance that predicts
|
||||
scores for each category.
|
||||
threshold (float): Cutoff to consider a prediction "positive".
|
||||
scorer (Optional[Callable]): The scoring method.
|
||||
"""
|
||||
return MultiLabel_TextCategorizer(
|
||||
nlp.vocab, model, name, threshold=threshold, scorer=scorer
|
||||
)
|
||||
|
||||
|
||||
def textcat_multilabel_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
||||
return Scorer.score_cats(
|
||||
examples,
|
||||
|
@ -124,7 +83,6 @@ def textcat_multilabel_score(examples: Iterable[Example], **kwargs) -> Dict[str,
|
|||
)
|
||||
|
||||
|
||||
@registry.scorers("spacy.textcat_multilabel_scorer.v2")
|
||||
def make_textcat_multilabel_scorer():
|
||||
return textcat_multilabel_score
|
||||
|
||||
|
@ -212,3 +170,11 @@ class MultiLabel_TextCategorizer(TextCategorizer):
|
|||
for val in ex.reference.cats.values():
|
||||
if not (val == 1.0 or val == 0.0):
|
||||
raise ValueError(Errors.E851.format(val=val))
|
||||
|
||||
|
||||
# Setup backwards compatibility hook for factories
|
||||
def __getattr__(name):
|
||||
if name == "make_multilabel_textcat":
|
||||
module = importlib.import_module("spacy.pipeline.factories")
|
||||
return module.make_multilabel_textcat
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import importlib
|
||||
import sys
|
||||
from itertools import islice
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence
|
||||
|
||||
|
@ -24,13 +26,6 @@ subword_features = true
|
|||
DEFAULT_TOK2VEC_MODEL = Config().from_str(default_model_config)["model"]
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"tok2vec", assigns=["doc.tensor"], default_config={"model": DEFAULT_TOK2VEC_MODEL}
|
||||
)
|
||||
def make_tok2vec(nlp: Language, name: str, model: Model) -> "Tok2Vec":
|
||||
return Tok2Vec(nlp.vocab, model, name)
|
||||
|
||||
|
||||
class Tok2Vec(TrainablePipe):
|
||||
"""Apply a "token-to-vector" model and set its outputs in the doc.tensor
|
||||
attribute. This is mostly useful to share a single subnetwork between multiple
|
||||
|
@ -320,3 +315,11 @@ def forward(model: Tok2VecListener, inputs, is_train: bool):
|
|||
|
||||
def _empty_backprop(dX): # for pickling
|
||||
return []
|
||||
|
||||
|
||||
# Setup backwards compatibility hook for factories
|
||||
def __getattr__(name):
|
||||
if name == "make_tok2vec":
|
||||
module = importlib.import_module("spacy.pipeline.factories")
|
||||
return module.make_tok2vec
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
|
|
@ -19,7 +19,7 @@ cdef class Parser(TrainablePipe):
|
|||
StateC** states,
|
||||
WeightsC weights,
|
||||
SizesC sizes
|
||||
) nogil
|
||||
) noexcept nogil
|
||||
|
||||
cdef void c_transition_batch(
|
||||
self,
|
||||
|
@ -27,4 +27,4 @@ cdef class Parser(TrainablePipe):
|
|||
const float* scores,
|
||||
int nr_class,
|
||||
int batch_size
|
||||
) nogil
|
||||
) noexcept nogil
|
||||
|
|
|
@ -316,7 +316,7 @@ cdef class Parser(TrainablePipe):
|
|||
|
||||
cdef void _parseC(
|
||||
self, CBlas cblas, StateC** states, WeightsC weights, SizesC sizes
|
||||
) nogil:
|
||||
) noexcept nogil:
|
||||
cdef int i
|
||||
cdef vector[StateC*] unfinished
|
||||
cdef ActivationsC activations = alloc_activations(sizes)
|
||||
|
@ -359,7 +359,7 @@ cdef class Parser(TrainablePipe):
|
|||
const float* scores,
|
||||
int nr_class,
|
||||
int batch_size
|
||||
) nogil:
|
||||
) noexcept nogil:
|
||||
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
||||
with gil:
|
||||
assert self.moves.n_moves > 0, Errors.E924.format(name=self.name)
|
||||
|
|
245
spacy/registrations.py
Normal file
245
spacy/registrations.py
Normal file
|
@ -0,0 +1,245 @@
|
|||
"""Centralized registry population for spaCy config
|
||||
|
||||
This module centralizes registry decorations to prevent circular import issues
|
||||
with Cython annotation changes from __future__ import annotations. Functions
|
||||
remain in their original locations, but decoration is moved here.
|
||||
|
||||
Component definitions and registrations are in spacy/pipeline/factories.py
|
||||
"""
|
||||
# Global flag to track if registry has been populated
|
||||
REGISTRY_POPULATED = False
|
||||
|
||||
|
||||
def populate_registry() -> None:
|
||||
"""Populate the registry with all necessary components.
|
||||
|
||||
This function should be called before accessing the registry, to ensure
|
||||
it's populated. The function uses a global flag to prevent repopulation.
|
||||
"""
|
||||
global REGISTRY_POPULATED
|
||||
if REGISTRY_POPULATED:
|
||||
return
|
||||
|
||||
# Import all necessary modules
|
||||
from .lang.ja import create_tokenizer as create_japanese_tokenizer
|
||||
from .lang.ko import create_tokenizer as create_korean_tokenizer
|
||||
from .lang.th import create_thai_tokenizer
|
||||
from .lang.vi import create_vietnamese_tokenizer
|
||||
from .lang.zh import create_chinese_tokenizer
|
||||
from .language import load_lookups_data
|
||||
from .matcher.levenshtein import make_levenshtein_compare
|
||||
from .ml.models.entity_linker import (
|
||||
create_candidates,
|
||||
create_candidates_batch,
|
||||
empty_kb,
|
||||
empty_kb_for_config,
|
||||
load_kb,
|
||||
)
|
||||
from .pipeline.attributeruler import make_attribute_ruler_scorer
|
||||
from .pipeline.dep_parser import make_parser_scorer
|
||||
|
||||
# Import the functions we refactored by removing direct registry decorators
|
||||
from .pipeline.entity_linker import make_entity_linker_scorer
|
||||
from .pipeline.entityruler import (
|
||||
make_entity_ruler_scorer as make_entityruler_scorer,
|
||||
)
|
||||
from .pipeline.lemmatizer import make_lemmatizer_scorer
|
||||
from .pipeline.morphologizer import make_morphologizer_scorer
|
||||
from .pipeline.ner import make_ner_scorer
|
||||
from .pipeline.senter import make_senter_scorer
|
||||
from .pipeline.span_finder import make_span_finder_scorer
|
||||
from .pipeline.span_ruler import (
|
||||
make_overlapping_labeled_spans_scorer,
|
||||
make_preserve_existing_ents_filter,
|
||||
make_prioritize_new_ents_filter,
|
||||
)
|
||||
from .pipeline.spancat import (
|
||||
build_ngram_range_suggester,
|
||||
build_ngram_suggester,
|
||||
build_preset_spans_suggester,
|
||||
make_spancat_scorer,
|
||||
)
|
||||
|
||||
# Import all pipeline components that were using registry decorators
|
||||
from .pipeline.tagger import make_tagger_scorer
|
||||
from .pipeline.textcat import make_textcat_scorer
|
||||
from .pipeline.textcat_multilabel import make_textcat_multilabel_scorer
|
||||
from .util import make_first_longest_spans_filter, registry
|
||||
|
||||
# Register miscellaneous components
|
||||
registry.misc("spacy.first_longest_spans_filter.v1")(
|
||||
make_first_longest_spans_filter
|
||||
)
|
||||
registry.misc("spacy.ngram_suggester.v1")(build_ngram_suggester)
|
||||
registry.misc("spacy.ngram_range_suggester.v1")(build_ngram_range_suggester)
|
||||
registry.misc("spacy.preset_spans_suggester.v1")(build_preset_spans_suggester)
|
||||
registry.misc("spacy.prioritize_new_ents_filter.v1")(
|
||||
make_prioritize_new_ents_filter
|
||||
)
|
||||
registry.misc("spacy.prioritize_existing_ents_filter.v1")(
|
||||
make_preserve_existing_ents_filter
|
||||
)
|
||||
registry.misc("spacy.levenshtein_compare.v1")(make_levenshtein_compare)
|
||||
# KB-related registrations
|
||||
registry.misc("spacy.KBFromFile.v1")(load_kb)
|
||||
registry.misc("spacy.EmptyKB.v2")(empty_kb_for_config)
|
||||
registry.misc("spacy.EmptyKB.v1")(empty_kb)
|
||||
registry.misc("spacy.CandidateGenerator.v1")(create_candidates)
|
||||
registry.misc("spacy.CandidateBatchGenerator.v1")(create_candidates_batch)
|
||||
registry.misc("spacy.LookupsDataLoader.v1")(load_lookups_data)
|
||||
|
||||
# Need to get references to the existing functions in registry by importing the function that is there
|
||||
# For the registry that was previously decorated
|
||||
|
||||
# Import ML components that use registry
|
||||
from .language import create_tokenizer
|
||||
from .ml._precomputable_affine import PrecomputableAffine
|
||||
from .ml.callbacks import (
|
||||
create_models_and_pipes_with_nvtx_range,
|
||||
create_models_with_nvtx_range,
|
||||
)
|
||||
from .ml.extract_ngrams import extract_ngrams
|
||||
from .ml.extract_spans import extract_spans
|
||||
|
||||
# Import decorator-removed ML components
|
||||
from .ml.featureextractor import FeatureExtractor
|
||||
from .ml.models.entity_linker import build_nel_encoder
|
||||
from .ml.models.multi_task import (
|
||||
create_pretrain_characters,
|
||||
create_pretrain_vectors,
|
||||
)
|
||||
from .ml.models.parser import build_tb_parser_model
|
||||
from .ml.models.span_finder import build_finder_model
|
||||
from .ml.models.spancat import (
|
||||
build_linear_logistic,
|
||||
build_mean_max_reducer,
|
||||
build_spancat_model,
|
||||
)
|
||||
from .ml.models.tagger import build_tagger_model
|
||||
from .ml.models.textcat import (
|
||||
build_bow_text_classifier,
|
||||
build_bow_text_classifier_v3,
|
||||
build_reduce_text_classifier,
|
||||
build_simple_cnn_text_classifier,
|
||||
build_text_classifier_lowdata,
|
||||
build_text_classifier_v2,
|
||||
build_textcat_parametric_attention_v1,
|
||||
)
|
||||
from .ml.models.tok2vec import (
|
||||
BiLSTMEncoder,
|
||||
CharacterEmbed,
|
||||
MaxoutWindowEncoder,
|
||||
MishWindowEncoder,
|
||||
MultiHashEmbed,
|
||||
build_hash_embed_cnn_tok2vec,
|
||||
build_Tok2Vec_model,
|
||||
tok2vec_listener_v1,
|
||||
)
|
||||
from .ml.staticvectors import StaticVectors
|
||||
from .ml.tb_framework import TransitionModel
|
||||
from .training.augment import (
|
||||
create_combined_augmenter,
|
||||
create_lower_casing_augmenter,
|
||||
create_orth_variants_augmenter,
|
||||
)
|
||||
from .training.batchers import (
|
||||
configure_minibatch,
|
||||
configure_minibatch_by_padded_size,
|
||||
configure_minibatch_by_words,
|
||||
)
|
||||
from .training.callbacks import create_copy_from_base_model
|
||||
from .training.loggers import console_logger, console_logger_v3
|
||||
|
||||
# Register scorers
|
||||
registry.scorers("spacy.tagger_scorer.v1")(make_tagger_scorer)
|
||||
registry.scorers("spacy.ner_scorer.v1")(make_ner_scorer)
|
||||
# span_ruler_scorer removed as it's not in span_ruler.py
|
||||
registry.scorers("spacy.entity_ruler_scorer.v1")(make_entityruler_scorer)
|
||||
registry.scorers("spacy.senter_scorer.v1")(make_senter_scorer)
|
||||
registry.scorers("spacy.textcat_scorer.v1")(make_textcat_scorer)
|
||||
registry.scorers("spacy.textcat_scorer.v2")(make_textcat_scorer)
|
||||
registry.scorers("spacy.textcat_multilabel_scorer.v1")(
|
||||
make_textcat_multilabel_scorer
|
||||
)
|
||||
registry.scorers("spacy.textcat_multilabel_scorer.v2")(
|
||||
make_textcat_multilabel_scorer
|
||||
)
|
||||
registry.scorers("spacy.lemmatizer_scorer.v1")(make_lemmatizer_scorer)
|
||||
registry.scorers("spacy.span_finder_scorer.v1")(make_span_finder_scorer)
|
||||
registry.scorers("spacy.spancat_scorer.v1")(make_spancat_scorer)
|
||||
registry.scorers("spacy.entity_linker_scorer.v1")(make_entity_linker_scorer)
|
||||
registry.scorers("spacy.overlapping_labeled_spans_scorer.v1")(
|
||||
make_overlapping_labeled_spans_scorer
|
||||
)
|
||||
registry.scorers("spacy.attribute_ruler_scorer.v1")(make_attribute_ruler_scorer)
|
||||
registry.scorers("spacy.parser_scorer.v1")(make_parser_scorer)
|
||||
registry.scorers("spacy.morphologizer_scorer.v1")(make_morphologizer_scorer)
|
||||
|
||||
# Register tokenizers
|
||||
registry.tokenizers("spacy.Tokenizer.v1")(create_tokenizer)
|
||||
registry.tokenizers("spacy.ja.JapaneseTokenizer")(create_japanese_tokenizer)
|
||||
registry.tokenizers("spacy.zh.ChineseTokenizer")(create_chinese_tokenizer)
|
||||
registry.tokenizers("spacy.ko.KoreanTokenizer")(create_korean_tokenizer)
|
||||
registry.tokenizers("spacy.vi.VietnameseTokenizer")(create_vietnamese_tokenizer)
|
||||
registry.tokenizers("spacy.th.ThaiTokenizer")(create_thai_tokenizer)
|
||||
|
||||
# Register tok2vec architectures we've modified
|
||||
registry.architectures("spacy.Tok2VecListener.v1")(tok2vec_listener_v1)
|
||||
registry.architectures("spacy.HashEmbedCNN.v2")(build_hash_embed_cnn_tok2vec)
|
||||
registry.architectures("spacy.Tok2Vec.v2")(build_Tok2Vec_model)
|
||||
registry.architectures("spacy.MultiHashEmbed.v2")(MultiHashEmbed)
|
||||
registry.architectures("spacy.CharacterEmbed.v2")(CharacterEmbed)
|
||||
registry.architectures("spacy.MaxoutWindowEncoder.v2")(MaxoutWindowEncoder)
|
||||
registry.architectures("spacy.MishWindowEncoder.v2")(MishWindowEncoder)
|
||||
registry.architectures("spacy.TorchBiLSTMEncoder.v1")(BiLSTMEncoder)
|
||||
registry.architectures("spacy.EntityLinker.v2")(build_nel_encoder)
|
||||
registry.architectures("spacy.TextCatCNN.v2")(build_simple_cnn_text_classifier)
|
||||
registry.architectures("spacy.TextCatBOW.v2")(build_bow_text_classifier)
|
||||
registry.architectures("spacy.TextCatBOW.v3")(build_bow_text_classifier_v3)
|
||||
registry.architectures("spacy.TextCatEnsemble.v2")(build_text_classifier_v2)
|
||||
registry.architectures("spacy.TextCatLowData.v1")(build_text_classifier_lowdata)
|
||||
registry.architectures("spacy.TextCatParametricAttention.v1")(
|
||||
build_textcat_parametric_attention_v1
|
||||
)
|
||||
registry.architectures("spacy.TextCatReduce.v1")(build_reduce_text_classifier)
|
||||
registry.architectures("spacy.SpanCategorizer.v1")(build_spancat_model)
|
||||
registry.architectures("spacy.SpanFinder.v1")(build_finder_model)
|
||||
registry.architectures("spacy.TransitionBasedParser.v2")(build_tb_parser_model)
|
||||
registry.architectures("spacy.PretrainVectors.v1")(create_pretrain_vectors)
|
||||
registry.architectures("spacy.PretrainCharacters.v1")(create_pretrain_characters)
|
||||
registry.architectures("spacy.Tagger.v2")(build_tagger_model)
|
||||
|
||||
# Register layers
|
||||
registry.layers("spacy.FeatureExtractor.v1")(FeatureExtractor)
|
||||
registry.layers("spacy.extract_spans.v1")(extract_spans)
|
||||
registry.layers("spacy.extract_ngrams.v1")(extract_ngrams)
|
||||
registry.layers("spacy.LinearLogistic.v1")(build_linear_logistic)
|
||||
registry.layers("spacy.mean_max_reducer.v1")(build_mean_max_reducer)
|
||||
registry.layers("spacy.StaticVectors.v2")(StaticVectors)
|
||||
registry.layers("spacy.PrecomputableAffine.v1")(PrecomputableAffine)
|
||||
registry.layers("spacy.CharEmbed.v1")(CharacterEmbed)
|
||||
registry.layers("spacy.TransitionModel.v1")(TransitionModel)
|
||||
|
||||
# Register callbacks
|
||||
registry.callbacks("spacy.copy_from_base_model.v1")(create_copy_from_base_model)
|
||||
registry.callbacks("spacy.models_with_nvtx_range.v1")(create_models_with_nvtx_range)
|
||||
registry.callbacks("spacy.models_and_pipes_with_nvtx_range.v1")(
|
||||
create_models_and_pipes_with_nvtx_range
|
||||
)
|
||||
|
||||
# Register loggers
|
||||
registry.loggers("spacy.ConsoleLogger.v2")(console_logger)
|
||||
registry.loggers("spacy.ConsoleLogger.v3")(console_logger_v3)
|
||||
|
||||
# Register batchers
|
||||
registry.batchers("spacy.batch_by_padded.v1")(configure_minibatch_by_padded_size)
|
||||
registry.batchers("spacy.batch_by_words.v1")(configure_minibatch_by_words)
|
||||
registry.batchers("spacy.batch_by_sequence.v1")(configure_minibatch)
|
||||
|
||||
# Register augmenters
|
||||
registry.augmenters("spacy.combined_augmenter.v1")(create_combined_augmenter)
|
||||
registry.augmenters("spacy.lower_case.v1")(create_lower_casing_augmenter)
|
||||
registry.augmenters("spacy.orth_variants.v1")(create_orth_variants_augmenter)
|
||||
|
||||
# Set the flag to indicate that the registry has been populated
|
||||
REGISTRY_POPULATED = True
|
|
@ -479,3 +479,4 @@ NAMES = [it[0] for it in sorted(IDS.items(), key=sort_nums)]
|
|||
# (which is generating an enormous amount of C++ in Cython 0.24+)
|
||||
# We keep the enum cdef, and just make sure the names are available to Python
|
||||
locals().update(IDS)
|
||||
|
||||
|
|
132
spacy/tests/factory_registrations.json
Normal file
132
spacy/tests/factory_registrations.json
Normal file
|
@ -0,0 +1,132 @@
|
|||
{
|
||||
"attribute_ruler": {
|
||||
"name": "attribute_ruler",
|
||||
"module": "spacy.pipeline.attributeruler",
|
||||
"function": "make_attribute_ruler"
|
||||
},
|
||||
"beam_ner": {
|
||||
"name": "beam_ner",
|
||||
"module": "spacy.pipeline.ner",
|
||||
"function": "make_beam_ner"
|
||||
},
|
||||
"beam_parser": {
|
||||
"name": "beam_parser",
|
||||
"module": "spacy.pipeline.dep_parser",
|
||||
"function": "make_beam_parser"
|
||||
},
|
||||
"doc_cleaner": {
|
||||
"name": "doc_cleaner",
|
||||
"module": "spacy.pipeline.functions",
|
||||
"function": "make_doc_cleaner"
|
||||
},
|
||||
"entity_linker": {
|
||||
"name": "entity_linker",
|
||||
"module": "spacy.pipeline.entity_linker",
|
||||
"function": "make_entity_linker"
|
||||
},
|
||||
"entity_ruler": {
|
||||
"name": "entity_ruler",
|
||||
"module": "spacy.pipeline.entityruler",
|
||||
"function": "make_entity_ruler"
|
||||
},
|
||||
"future_entity_ruler": {
|
||||
"name": "future_entity_ruler",
|
||||
"module": "spacy.pipeline.span_ruler",
|
||||
"function": "make_entity_ruler"
|
||||
},
|
||||
"lemmatizer": {
|
||||
"name": "lemmatizer",
|
||||
"module": "spacy.pipeline.lemmatizer",
|
||||
"function": "make_lemmatizer"
|
||||
},
|
||||
"merge_entities": {
|
||||
"name": "merge_entities",
|
||||
"module": "spacy.language",
|
||||
"function": "Language.component.<locals>.add_component.<locals>.factory_func"
|
||||
},
|
||||
"merge_noun_chunks": {
|
||||
"name": "merge_noun_chunks",
|
||||
"module": "spacy.language",
|
||||
"function": "Language.component.<locals>.add_component.<locals>.factory_func"
|
||||
},
|
||||
"merge_subtokens": {
|
||||
"name": "merge_subtokens",
|
||||
"module": "spacy.language",
|
||||
"function": "Language.component.<locals>.add_component.<locals>.factory_func"
|
||||
},
|
||||
"morphologizer": {
|
||||
"name": "morphologizer",
|
||||
"module": "spacy.pipeline.morphologizer",
|
||||
"function": "make_morphologizer"
|
||||
},
|
||||
"ner": {
|
||||
"name": "ner",
|
||||
"module": "spacy.pipeline.ner",
|
||||
"function": "make_ner"
|
||||
},
|
||||
"parser": {
|
||||
"name": "parser",
|
||||
"module": "spacy.pipeline.dep_parser",
|
||||
"function": "make_parser"
|
||||
},
|
||||
"sentencizer": {
|
||||
"name": "sentencizer",
|
||||
"module": "spacy.pipeline.sentencizer",
|
||||
"function": "make_sentencizer"
|
||||
},
|
||||
"senter": {
|
||||
"name": "senter",
|
||||
"module": "spacy.pipeline.senter",
|
||||
"function": "make_senter"
|
||||
},
|
||||
"span_finder": {
|
||||
"name": "span_finder",
|
||||
"module": "spacy.pipeline.span_finder",
|
||||
"function": "make_span_finder"
|
||||
},
|
||||
"span_ruler": {
|
||||
"name": "span_ruler",
|
||||
"module": "spacy.pipeline.span_ruler",
|
||||
"function": "make_span_ruler"
|
||||
},
|
||||
"spancat": {
|
||||
"name": "spancat",
|
||||
"module": "spacy.pipeline.spancat",
|
||||
"function": "make_spancat"
|
||||
},
|
||||
"spancat_singlelabel": {
|
||||
"name": "spancat_singlelabel",
|
||||
"module": "spacy.pipeline.spancat",
|
||||
"function": "make_spancat_singlelabel"
|
||||
},
|
||||
"tagger": {
|
||||
"name": "tagger",
|
||||
"module": "spacy.pipeline.tagger",
|
||||
"function": "make_tagger"
|
||||
},
|
||||
"textcat": {
|
||||
"name": "textcat",
|
||||
"module": "spacy.pipeline.textcat",
|
||||
"function": "make_textcat"
|
||||
},
|
||||
"textcat_multilabel": {
|
||||
"name": "textcat_multilabel",
|
||||
"module": "spacy.pipeline.textcat_multilabel",
|
||||
"function": "make_multilabel_textcat"
|
||||
},
|
||||
"tok2vec": {
|
||||
"name": "tok2vec",
|
||||
"module": "spacy.pipeline.tok2vec",
|
||||
"function": "make_tok2vec"
|
||||
},
|
||||
"token_splitter": {
|
||||
"name": "token_splitter",
|
||||
"module": "spacy.pipeline.functions",
|
||||
"function": "make_token_splitter"
|
||||
},
|
||||
"trainable_lemmatizer": {
|
||||
"name": "trainable_lemmatizer",
|
||||
"module": "spacy.pipeline.edit_tree_lemmatizer",
|
||||
"function": "make_edit_tree_lemmatizer"
|
||||
}
|
||||
}
|
|
@ -529,17 +529,6 @@ def test_pipe_label_data_no_labels(pipe):
|
|||
assert "labels" not in get_arg_names(initialize)
|
||||
|
||||
|
||||
def test_warning_pipe_begin_training():
|
||||
with pytest.warns(UserWarning, match="begin_training"):
|
||||
|
||||
class IncompatPipe(TrainablePipe):
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
def begin_training(*args, **kwargs):
|
||||
...
|
||||
|
||||
|
||||
def test_pipe_methods_initialize():
|
||||
"""Test that the [initialize] config reflects the components correctly."""
|
||||
nlp = Language()
|
||||
|
|
284
spacy/tests/registry_contents.json
Normal file
284
spacy/tests/registry_contents.json
Normal file
|
@ -0,0 +1,284 @@
|
|||
{
|
||||
"architectures": [
|
||||
"spacy-legacy.CharacterEmbed.v1",
|
||||
"spacy-legacy.EntityLinker.v1",
|
||||
"spacy-legacy.HashEmbedCNN.v1",
|
||||
"spacy-legacy.MaxoutWindowEncoder.v1",
|
||||
"spacy-legacy.MishWindowEncoder.v1",
|
||||
"spacy-legacy.MultiHashEmbed.v1",
|
||||
"spacy-legacy.Tagger.v1",
|
||||
"spacy-legacy.TextCatBOW.v1",
|
||||
"spacy-legacy.TextCatCNN.v1",
|
||||
"spacy-legacy.TextCatEnsemble.v1",
|
||||
"spacy-legacy.Tok2Vec.v1",
|
||||
"spacy-legacy.TransitionBasedParser.v1",
|
||||
"spacy.CharacterEmbed.v2",
|
||||
"spacy.EntityLinker.v2",
|
||||
"spacy.HashEmbedCNN.v2",
|
||||
"spacy.MaxoutWindowEncoder.v2",
|
||||
"spacy.MishWindowEncoder.v2",
|
||||
"spacy.MultiHashEmbed.v2",
|
||||
"spacy.PretrainCharacters.v1",
|
||||
"spacy.PretrainVectors.v1",
|
||||
"spacy.SpanCategorizer.v1",
|
||||
"spacy.SpanFinder.v1",
|
||||
"spacy.Tagger.v2",
|
||||
"spacy.TextCatBOW.v2",
|
||||
"spacy.TextCatBOW.v3",
|
||||
"spacy.TextCatCNN.v2",
|
||||
"spacy.TextCatEnsemble.v2",
|
||||
"spacy.TextCatLowData.v1",
|
||||
"spacy.TextCatParametricAttention.v1",
|
||||
"spacy.TextCatReduce.v1",
|
||||
"spacy.Tok2Vec.v2",
|
||||
"spacy.Tok2VecListener.v1",
|
||||
"spacy.TorchBiLSTMEncoder.v1",
|
||||
"spacy.TransitionBasedParser.v2"
|
||||
],
|
||||
"augmenters": [
|
||||
"spacy.combined_augmenter.v1",
|
||||
"spacy.lower_case.v1",
|
||||
"spacy.orth_variants.v1"
|
||||
],
|
||||
"batchers": [
|
||||
"spacy.batch_by_padded.v1",
|
||||
"spacy.batch_by_sequence.v1",
|
||||
"spacy.batch_by_words.v1"
|
||||
],
|
||||
"callbacks": [
|
||||
"spacy.copy_from_base_model.v1",
|
||||
"spacy.models_and_pipes_with_nvtx_range.v1",
|
||||
"spacy.models_with_nvtx_range.v1"
|
||||
],
|
||||
"cli": [],
|
||||
"datasets": [],
|
||||
"displacy_colors": [],
|
||||
"factories": [
|
||||
"attribute_ruler",
|
||||
"beam_ner",
|
||||
"beam_parser",
|
||||
"doc_cleaner",
|
||||
"entity_linker",
|
||||
"entity_ruler",
|
||||
"future_entity_ruler",
|
||||
"lemmatizer",
|
||||
"merge_entities",
|
||||
"merge_noun_chunks",
|
||||
"merge_subtokens",
|
||||
"morphologizer",
|
||||
"ner",
|
||||
"parser",
|
||||
"sentencizer",
|
||||
"senter",
|
||||
"span_finder",
|
||||
"span_ruler",
|
||||
"spancat",
|
||||
"spancat_singlelabel",
|
||||
"tagger",
|
||||
"textcat",
|
||||
"textcat_multilabel",
|
||||
"tok2vec",
|
||||
"token_splitter",
|
||||
"trainable_lemmatizer"
|
||||
],
|
||||
"initializers": [
|
||||
"glorot_normal_init.v1",
|
||||
"glorot_uniform_init.v1",
|
||||
"he_normal_init.v1",
|
||||
"he_uniform_init.v1",
|
||||
"lecun_normal_init.v1",
|
||||
"lecun_uniform_init.v1",
|
||||
"normal_init.v1",
|
||||
"uniform_init.v1",
|
||||
"zero_init.v1"
|
||||
],
|
||||
"languages": [],
|
||||
"layers": [
|
||||
"CauchySimilarity.v1",
|
||||
"ClippedLinear.v1",
|
||||
"Dish.v1",
|
||||
"Dropout.v1",
|
||||
"Embed.v1",
|
||||
"Gelu.v1",
|
||||
"HardSigmoid.v1",
|
||||
"HardSwish.v1",
|
||||
"HardSwishMobilenet.v1",
|
||||
"HardTanh.v1",
|
||||
"HashEmbed.v1",
|
||||
"LSTM.v1",
|
||||
"LayerNorm.v1",
|
||||
"Linear.v1",
|
||||
"Logistic.v1",
|
||||
"MXNetWrapper.v1",
|
||||
"Maxout.v1",
|
||||
"Mish.v1",
|
||||
"MultiSoftmax.v1",
|
||||
"ParametricAttention.v1",
|
||||
"ParametricAttention.v2",
|
||||
"PyTorchLSTM.v1",
|
||||
"PyTorchRNNWrapper.v1",
|
||||
"PyTorchWrapper.v1",
|
||||
"PyTorchWrapper.v2",
|
||||
"PyTorchWrapper.v3",
|
||||
"Relu.v1",
|
||||
"ReluK.v1",
|
||||
"Sigmoid.v1",
|
||||
"Softmax.v1",
|
||||
"Softmax.v2",
|
||||
"SparseLinear.v1",
|
||||
"SparseLinear.v2",
|
||||
"Swish.v1",
|
||||
"add.v1",
|
||||
"bidirectional.v1",
|
||||
"chain.v1",
|
||||
"clone.v1",
|
||||
"concatenate.v1",
|
||||
"expand_window.v1",
|
||||
"list2array.v1",
|
||||
"list2padded.v1",
|
||||
"list2ragged.v1",
|
||||
"noop.v1",
|
||||
"padded2list.v1",
|
||||
"premap_ids.v1",
|
||||
"ragged2list.v1",
|
||||
"reduce_first.v1",
|
||||
"reduce_last.v1",
|
||||
"reduce_max.v1",
|
||||
"reduce_mean.v1",
|
||||
"reduce_sum.v1",
|
||||
"remap_ids.v1",
|
||||
"remap_ids.v2",
|
||||
"residual.v1",
|
||||
"resizable.v1",
|
||||
"siamese.v1",
|
||||
"sigmoid_activation.v1",
|
||||
"softmax_activation.v1",
|
||||
"spacy-legacy.StaticVectors.v1",
|
||||
"spacy.CharEmbed.v1",
|
||||
"spacy.FeatureExtractor.v1",
|
||||
"spacy.LinearLogistic.v1",
|
||||
"spacy.PrecomputableAffine.v1",
|
||||
"spacy.StaticVectors.v2",
|
||||
"spacy.TransitionModel.v1",
|
||||
"spacy.extract_ngrams.v1",
|
||||
"spacy.extract_spans.v1",
|
||||
"spacy.mean_max_reducer.v1",
|
||||
"strings2arrays.v1",
|
||||
"tuplify.v1",
|
||||
"uniqued.v1",
|
||||
"with_array.v1",
|
||||
"with_array2d.v1",
|
||||
"with_cpu.v1",
|
||||
"with_flatten.v1",
|
||||
"with_flatten.v2",
|
||||
"with_getitem.v1",
|
||||
"with_list.v1",
|
||||
"with_padded.v1",
|
||||
"with_ragged.v1",
|
||||
"with_reshape.v1"
|
||||
],
|
||||
"lemmatizers": [],
|
||||
"loggers": [
|
||||
"spacy-legacy.ConsoleLogger.v1",
|
||||
"spacy-legacy.ConsoleLogger.v2",
|
||||
"spacy-legacy.WandbLogger.v1",
|
||||
"spacy.ChainLogger.v1",
|
||||
"spacy.ClearMLLogger.v1",
|
||||
"spacy.ClearMLLogger.v2",
|
||||
"spacy.ConsoleLogger.v2",
|
||||
"spacy.ConsoleLogger.v3",
|
||||
"spacy.CupyLogger.v1",
|
||||
"spacy.LookupLogger.v1",
|
||||
"spacy.MLflowLogger.v1",
|
||||
"spacy.MLflowLogger.v2",
|
||||
"spacy.PyTorchLogger.v1",
|
||||
"spacy.WandbLogger.v1",
|
||||
"spacy.WandbLogger.v2",
|
||||
"spacy.WandbLogger.v3",
|
||||
"spacy.WandbLogger.v4",
|
||||
"spacy.WandbLogger.v5"
|
||||
],
|
||||
"lookups": [],
|
||||
"losses": [
|
||||
"CategoricalCrossentropy.v1",
|
||||
"CategoricalCrossentropy.v2",
|
||||
"CategoricalCrossentropy.v3",
|
||||
"CosineDistance.v1",
|
||||
"L2Distance.v1",
|
||||
"SequenceCategoricalCrossentropy.v1",
|
||||
"SequenceCategoricalCrossentropy.v2",
|
||||
"SequenceCategoricalCrossentropy.v3"
|
||||
],
|
||||
"misc": [
|
||||
"spacy.CandidateBatchGenerator.v1",
|
||||
"spacy.CandidateGenerator.v1",
|
||||
"spacy.EmptyKB.v1",
|
||||
"spacy.EmptyKB.v2",
|
||||
"spacy.KBFromFile.v1",
|
||||
"spacy.LookupsDataLoader.v1",
|
||||
"spacy.first_longest_spans_filter.v1",
|
||||
"spacy.levenshtein_compare.v1",
|
||||
"spacy.ngram_range_suggester.v1",
|
||||
"spacy.ngram_suggester.v1",
|
||||
"spacy.preset_spans_suggester.v1",
|
||||
"spacy.prioritize_existing_ents_filter.v1",
|
||||
"spacy.prioritize_new_ents_filter.v1"
|
||||
],
|
||||
"models": [],
|
||||
"ops": [
|
||||
"CupyOps",
|
||||
"MPSOps",
|
||||
"NumpyOps"
|
||||
],
|
||||
"optimizers": [
|
||||
"Adam.v1",
|
||||
"RAdam.v1",
|
||||
"SGD.v1"
|
||||
],
|
||||
"readers": [
|
||||
"ml_datasets.cmu_movies.v1",
|
||||
"ml_datasets.dbpedia.v1",
|
||||
"ml_datasets.imdb_sentiment.v1",
|
||||
"spacy.Corpus.v1",
|
||||
"spacy.JsonlCorpus.v1",
|
||||
"spacy.PlainTextCorpus.v1",
|
||||
"spacy.read_labels.v1",
|
||||
"srsly.read_json.v1",
|
||||
"srsly.read_jsonl.v1",
|
||||
"srsly.read_msgpack.v1",
|
||||
"srsly.read_yaml.v1"
|
||||
],
|
||||
"schedules": [
|
||||
"compounding.v1",
|
||||
"constant.v1",
|
||||
"constant_then.v1",
|
||||
"cyclic_triangular.v1",
|
||||
"decaying.v1",
|
||||
"slanted_triangular.v1",
|
||||
"warmup_linear.v1"
|
||||
],
|
||||
"scorers": [
|
||||
"spacy-legacy.textcat_multilabel_scorer.v1",
|
||||
"spacy-legacy.textcat_scorer.v1",
|
||||
"spacy.attribute_ruler_scorer.v1",
|
||||
"spacy.entity_linker_scorer.v1",
|
||||
"spacy.entity_ruler_scorer.v1",
|
||||
"spacy.lemmatizer_scorer.v1",
|
||||
"spacy.morphologizer_scorer.v1",
|
||||
"spacy.ner_scorer.v1",
|
||||
"spacy.overlapping_labeled_spans_scorer.v1",
|
||||
"spacy.parser_scorer.v1",
|
||||
"spacy.senter_scorer.v1",
|
||||
"spacy.span_finder_scorer.v1",
|
||||
"spacy.spancat_scorer.v1",
|
||||
"spacy.tagger_scorer.v1",
|
||||
"spacy.textcat_multilabel_scorer.v2",
|
||||
"spacy.textcat_scorer.v2"
|
||||
],
|
||||
"tokenizers": [
|
||||
"spacy.Tokenizer.v1"
|
||||
],
|
||||
"vectors": [
|
||||
"spacy.Vectors.v1"
|
||||
]
|
||||
}
|
|
@ -87,7 +87,7 @@ def entity_linker():
|
|||
|
||||
|
||||
objects_to_test = (
|
||||
[nlp(), vectors(), custom_pipe(), tagger(), entity_linker()],
|
||||
[nlp, vectors, custom_pipe, tagger, entity_linker],
|
||||
["nlp", "vectors", "custom_pipe", "tagger", "entity_linker"],
|
||||
)
|
||||
|
||||
|
@ -101,8 +101,9 @@ def write_obj_and_catch_warnings(obj):
|
|||
return list(filter(lambda x: isinstance(x, ResourceWarning), warnings_list))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("obj", objects_to_test[0], ids=objects_to_test[1])
|
||||
def test_to_disk_resource_warning(obj):
|
||||
@pytest.mark.parametrize("obj_factory", objects_to_test[0], ids=objects_to_test[1])
|
||||
def test_to_disk_resource_warning(obj_factory):
|
||||
obj = obj_factory()
|
||||
warnings_list = write_obj_and_catch_warnings(obj)
|
||||
assert len(warnings_list) == 0
|
||||
|
||||
|
@ -139,9 +140,11 @@ def test_save_and_load_knowledge_base():
|
|||
|
||||
class TestToDiskResourceWarningUnittest(TestCase):
|
||||
def test_resource_warning(self):
|
||||
scenarios = zip(*objects_to_test)
|
||||
items = [x() for x in objects_to_test[0]]
|
||||
names = objects_to_test[1]
|
||||
scenarios = zip(items, names)
|
||||
|
||||
for scenario in scenarios:
|
||||
with self.subTest(msg=scenario[1]):
|
||||
warnings_list = write_obj_and_catch_warnings(scenario[0])
|
||||
for item, name in scenarios:
|
||||
with self.subTest(msg=name):
|
||||
warnings_list = write_obj_and_catch_warnings(item)
|
||||
self.assertEqual(len(warnings_list), 0)
|
||||
|
|
85
spacy/tests/test_factory_imports.py
Normal file
85
spacy/tests/test_factory_imports.py
Normal file
|
@ -0,0 +1,85 @@
|
|||
# coding: utf-8
|
||||
"""Test factory import compatibility from original and new locations."""
|
||||
|
||||
import importlib
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"factory_name,original_module,compat_module",
|
||||
[
|
||||
("make_tagger", "spacy.pipeline.factories", "spacy.pipeline.tagger"),
|
||||
("make_sentencizer", "spacy.pipeline.factories", "spacy.pipeline.sentencizer"),
|
||||
("make_ner", "spacy.pipeline.factories", "spacy.pipeline.ner"),
|
||||
("make_parser", "spacy.pipeline.factories", "spacy.pipeline.dep_parser"),
|
||||
("make_tok2vec", "spacy.pipeline.factories", "spacy.pipeline.tok2vec"),
|
||||
("make_spancat", "spacy.pipeline.factories", "spacy.pipeline.spancat"),
|
||||
(
|
||||
"make_spancat_singlelabel",
|
||||
"spacy.pipeline.factories",
|
||||
"spacy.pipeline.spancat",
|
||||
),
|
||||
("make_lemmatizer", "spacy.pipeline.factories", "spacy.pipeline.lemmatizer"),
|
||||
("make_entity_ruler", "spacy.pipeline.factories", "spacy.pipeline.entityruler"),
|
||||
("make_span_ruler", "spacy.pipeline.factories", "spacy.pipeline.span_ruler"),
|
||||
(
|
||||
"make_edit_tree_lemmatizer",
|
||||
"spacy.pipeline.factories",
|
||||
"spacy.pipeline.edit_tree_lemmatizer",
|
||||
),
|
||||
(
|
||||
"make_attribute_ruler",
|
||||
"spacy.pipeline.factories",
|
||||
"spacy.pipeline.attributeruler",
|
||||
),
|
||||
(
|
||||
"make_entity_linker",
|
||||
"spacy.pipeline.factories",
|
||||
"spacy.pipeline.entity_linker",
|
||||
),
|
||||
("make_textcat", "spacy.pipeline.factories", "spacy.pipeline.textcat"),
|
||||
("make_token_splitter", "spacy.pipeline.factories", "spacy.pipeline.functions"),
|
||||
("make_doc_cleaner", "spacy.pipeline.factories", "spacy.pipeline.functions"),
|
||||
(
|
||||
"make_morphologizer",
|
||||
"spacy.pipeline.factories",
|
||||
"spacy.pipeline.morphologizer",
|
||||
),
|
||||
("make_senter", "spacy.pipeline.factories", "spacy.pipeline.senter"),
|
||||
("make_span_finder", "spacy.pipeline.factories", "spacy.pipeline.span_finder"),
|
||||
(
|
||||
"make_multilabel_textcat",
|
||||
"spacy.pipeline.factories",
|
||||
"spacy.pipeline.textcat_multilabel",
|
||||
),
|
||||
("make_beam_ner", "spacy.pipeline.factories", "spacy.pipeline.ner"),
|
||||
("make_beam_parser", "spacy.pipeline.factories", "spacy.pipeline.dep_parser"),
|
||||
("make_nn_labeller", "spacy.pipeline.factories", "spacy.pipeline.multitask"),
|
||||
# This one's special because the function was named make_span_ruler, so
|
||||
# the name in the registrations.py doesn't match the name we make the import hook
|
||||
# point to. We could make a test just for this but shrug
|
||||
# ("make_future_entity_ruler", "spacy.pipeline.factories", "spacy.pipeline.span_ruler"),
|
||||
],
|
||||
)
|
||||
def test_factory_import_compatibility(factory_name, original_module, compat_module):
|
||||
"""Test that factory functions can be imported from both original and compatibility locations."""
|
||||
# Import from the original module (registrations.py)
|
||||
original_module_obj = importlib.import_module(original_module)
|
||||
original_factory = getattr(original_module_obj, factory_name)
|
||||
assert (
|
||||
original_factory is not None
|
||||
), f"Could not import {factory_name} from {original_module}"
|
||||
|
||||
# Import from the compatibility module (component file)
|
||||
compat_module_obj = importlib.import_module(compat_module)
|
||||
compat_factory = getattr(compat_module_obj, factory_name)
|
||||
assert (
|
||||
compat_factory is not None
|
||||
), f"Could not import {factory_name} from {compat_module}"
|
||||
|
||||
# Test that they're the same function (identity)
|
||||
assert original_factory is compat_factory, (
|
||||
f"Factory {factory_name} imported from {original_module} is not the same object "
|
||||
f"as the one imported from {compat_module}"
|
||||
)
|
97
spacy/tests/test_factory_registrations.py
Normal file
97
spacy/tests/test_factory_registrations.py
Normal file
|
@ -0,0 +1,97 @@
|
|||
import inspect
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from spacy.language import Language
|
||||
from spacy.util import registry
|
||||
|
||||
# Path to the reference factory registrations, relative to this file
|
||||
REFERENCE_FILE = Path(__file__).parent / "factory_registrations.json"
|
||||
|
||||
# Monkey patch the util.is_same_func to handle Cython functions
|
||||
import inspect
|
||||
|
||||
from spacy import util
|
||||
|
||||
original_is_same_func = util.is_same_func
|
||||
|
||||
|
||||
def patched_is_same_func(func1, func2):
|
||||
# Handle Cython functions
|
||||
try:
|
||||
return original_is_same_func(func1, func2)
|
||||
except TypeError:
|
||||
# For Cython functions, just compare the string representation
|
||||
return str(func1) == str(func2)
|
||||
|
||||
|
||||
util.is_same_func = patched_is_same_func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reference_factory_registrations():
|
||||
"""Load reference factory registrations from JSON file"""
|
||||
if not REFERENCE_FILE.exists():
|
||||
pytest.fail(
|
||||
f"Reference file {REFERENCE_FILE} not found. Run export_factory_registrations.py first."
|
||||
)
|
||||
|
||||
with REFERENCE_FILE.open("r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def test_factory_registrations_preserved(reference_factory_registrations):
|
||||
"""Test that all factory registrations from the reference file are still present."""
|
||||
# Ensure the registry is populated
|
||||
registry.ensure_populated()
|
||||
|
||||
# Get all factory registrations
|
||||
all_factories = registry.factories.get_all()
|
||||
|
||||
# Initialize our data structure to store current factory registrations
|
||||
current_registrations = {}
|
||||
|
||||
# Process factory registrations
|
||||
for name, func in all_factories.items():
|
||||
# Store information about each factory
|
||||
try:
|
||||
module_name = func.__module__
|
||||
except (AttributeError, TypeError):
|
||||
# For Cython functions, just use a placeholder
|
||||
module_name = str(func).split()[1].split(".")[0]
|
||||
|
||||
try:
|
||||
func_name = func.__qualname__
|
||||
except (AttributeError, TypeError):
|
||||
# For Cython functions, use the function's name
|
||||
func_name = (
|
||||
func.__name__
|
||||
if hasattr(func, "__name__")
|
||||
else str(func).split()[1].split(".")[-1]
|
||||
)
|
||||
|
||||
current_registrations[name] = {
|
||||
"name": name,
|
||||
"module": module_name,
|
||||
"function": func_name,
|
||||
}
|
||||
|
||||
# Check for missing registrations
|
||||
missing_registrations = set(reference_factory_registrations.keys()) - set(
|
||||
current_registrations.keys()
|
||||
)
|
||||
assert (
|
||||
not missing_registrations
|
||||
), f"Missing factory registrations: {', '.join(sorted(missing_registrations))}"
|
||||
|
||||
# Check for new registrations (not an error, but informative)
|
||||
new_registrations = set(current_registrations.keys()) - set(
|
||||
reference_factory_registrations.keys()
|
||||
)
|
||||
if new_registrations:
|
||||
# This is not an error, just informative
|
||||
print(
|
||||
f"New factory registrations found: {', '.join(sorted(new_registrations))}"
|
||||
)
|
55
spacy/tests/test_registry_population.py
Normal file
55
spacy/tests/test_registry_population.py
Normal file
|
@ -0,0 +1,55 @@
|
|||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from spacy.util import registry
|
||||
|
||||
# Path to the reference registry contents, relative to this file
|
||||
REFERENCE_FILE = Path(__file__).parent / "registry_contents.json"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reference_registry():
|
||||
"""Load reference registry contents from JSON file"""
|
||||
if not REFERENCE_FILE.exists():
|
||||
pytest.fail(f"Reference file {REFERENCE_FILE} not found.")
|
||||
|
||||
with REFERENCE_FILE.open("r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def test_registry_types(reference_registry):
|
||||
"""Test that all registry types match the reference"""
|
||||
# Get current registry types
|
||||
current_registry_types = set(registry.get_registry_names())
|
||||
expected_registry_types = set(reference_registry.keys())
|
||||
|
||||
# Check for missing registry types
|
||||
missing_types = expected_registry_types - current_registry_types
|
||||
assert not missing_types, f"Missing registry types: {', '.join(missing_types)}"
|
||||
|
||||
|
||||
def test_registry_entries(reference_registry):
|
||||
"""Test that all registry entries are present"""
|
||||
# Check each registry's entries
|
||||
for registry_name, expected_entries in reference_registry.items():
|
||||
# Skip if this registry type doesn't exist
|
||||
if not hasattr(registry, registry_name):
|
||||
pytest.fail(f"Registry '{registry_name}' does not exist.")
|
||||
|
||||
# Get current entries
|
||||
reg = getattr(registry, registry_name)
|
||||
current_entries = sorted(list(reg.get_all().keys()))
|
||||
|
||||
# Compare entries
|
||||
expected_set = set(expected_entries)
|
||||
current_set = set(current_entries)
|
||||
|
||||
# Check for missing entries - these would indicate our new registry population
|
||||
# mechanism is missing something
|
||||
missing_entries = expected_set - current_set
|
||||
assert (
|
||||
not missing_entries
|
||||
), f"Registry '{registry_name}' missing entries: {', '.join(missing_entries)}"
|
|
@ -101,7 +101,11 @@ def test_cat_readers(reader, additional_config):
|
|||
nlp = load_model_from_config(config, auto_fill=True)
|
||||
T = registry.resolve(nlp.config["training"], schema=ConfigSchemaTraining)
|
||||
dot_names = [T["train_corpus"], T["dev_corpus"]]
|
||||
print("T", T)
|
||||
print("dot names", dot_names)
|
||||
train_corpus, dev_corpus = resolve_dot_names(nlp.config, dot_names)
|
||||
data = list(train_corpus(nlp))
|
||||
print(len(data))
|
||||
optimizer = T["optimizer"]
|
||||
# simulate a training loop
|
||||
nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer)
|
||||
|
|
|
@ -867,11 +867,11 @@ cdef extern from "<algorithm>" namespace "std" nogil:
|
|||
bint (*)(SpanC, SpanC))
|
||||
|
||||
|
||||
cdef bint len_start_cmp(SpanC a, SpanC b) nogil:
|
||||
cdef bint len_start_cmp(SpanC a, SpanC b) noexcept nogil:
|
||||
if a.end - a.start == b.end - b.start:
|
||||
return b.start < a.start
|
||||
return a.end - a.start < b.end - b.start
|
||||
|
||||
|
||||
cdef bint start_cmp(SpanC a, SpanC b) nogil:
|
||||
cdef bint start_cmp(SpanC a, SpanC b) noexcept nogil:
|
||||
return a.start < b.start
|
||||
|
|
|
@ -7,8 +7,8 @@ from ..typedefs cimport attr_t
|
|||
from ..vocab cimport Vocab
|
||||
|
||||
|
||||
cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil
|
||||
cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name) nogil
|
||||
cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) noexcept nogil
|
||||
cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name) noexcept nogil
|
||||
|
||||
|
||||
ctypedef const LexemeC* const_Lexeme_ptr
|
||||
|
|
|
@ -71,7 +71,7 @@ cdef int bounds_check(int i, int length, int padding) except -1:
|
|||
raise IndexError(Errors.E026.format(i=i, length=length))
|
||||
|
||||
|
||||
cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil:
|
||||
cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) noexcept nogil:
|
||||
if feat_name == LEMMA:
|
||||
return token.lemma
|
||||
elif feat_name == NORM:
|
||||
|
@ -106,7 +106,7 @@ cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil:
|
|||
return Lexeme.get_struct_attr(token.lex, feat_name)
|
||||
|
||||
|
||||
cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name) nogil:
|
||||
cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name) noexcept nogil:
|
||||
if feat_name == SENT_START:
|
||||
if token.sent_start == 1:
|
||||
return True
|
||||
|
|
|
@ -33,7 +33,7 @@ cdef class Token:
|
|||
cpdef bint check_flag(self, attr_id_t flag_id) except -1
|
||||
|
||||
@staticmethod
|
||||
cdef inline attr_t get_struct_attr(const TokenC* token, attr_id_t feat_name) nogil:
|
||||
cdef inline attr_t get_struct_attr(const TokenC* token, attr_id_t feat_name) noexcept nogil:
|
||||
if feat_name < (sizeof(flags_t) * 8):
|
||||
return Lexeme.c_check_flag(token.lex, feat_name)
|
||||
elif feat_name == LEMMA:
|
||||
|
@ -70,7 +70,7 @@ cdef class Token:
|
|||
|
||||
@staticmethod
|
||||
cdef inline attr_t set_struct_attr(TokenC* token, attr_id_t feat_name,
|
||||
attr_t value) nogil:
|
||||
attr_t value) noexcept nogil:
|
||||
if feat_name == LEMMA:
|
||||
token.lemma = value
|
||||
elif feat_name == NORM:
|
||||
|
@ -99,9 +99,9 @@ cdef class Token:
|
|||
token.sent_start = value
|
||||
|
||||
@staticmethod
|
||||
cdef inline int missing_dep(const TokenC* token) nogil:
|
||||
cdef inline int missing_dep(const TokenC* token) noexcept nogil:
|
||||
return token.dep == MISSING_DEP
|
||||
|
||||
@staticmethod
|
||||
cdef inline int missing_head(const TokenC* token) nogil:
|
||||
cdef inline int missing_head(const TokenC* token) noexcept nogil:
|
||||
return Token.missing_dep(token)
|
||||
|
|
|
@ -11,7 +11,6 @@ if TYPE_CHECKING:
|
|||
from ..language import Language # noqa: F401
|
||||
|
||||
|
||||
@registry.augmenters("spacy.combined_augmenter.v1")
|
||||
def create_combined_augmenter(
|
||||
lower_level: float,
|
||||
orth_level: float,
|
||||
|
@ -84,7 +83,6 @@ def combined_augmenter(
|
|||
yield example
|
||||
|
||||
|
||||
@registry.augmenters("spacy.orth_variants.v1")
|
||||
def create_orth_variants_augmenter(
|
||||
level: float, lower: float, orth_variants: Dict[str, List[Dict]]
|
||||
) -> Callable[["Language", Example], Iterator[Example]]:
|
||||
|
@ -102,7 +100,6 @@ def create_orth_variants_augmenter(
|
|||
)
|
||||
|
||||
|
||||
@registry.augmenters("spacy.lower_case.v1")
|
||||
def create_lower_casing_augmenter(
|
||||
level: float,
|
||||
) -> Callable[["Language", Example], Iterator[Example]]:
|
||||
|
|
|
@ -19,7 +19,6 @@ ItemT = TypeVar("ItemT")
|
|||
BatcherT = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
|
||||
|
||||
|
||||
@registry.batchers("spacy.batch_by_padded.v1")
|
||||
def configure_minibatch_by_padded_size(
|
||||
*,
|
||||
size: Sizing,
|
||||
|
@ -54,7 +53,6 @@ def configure_minibatch_by_padded_size(
|
|||
)
|
||||
|
||||
|
||||
@registry.batchers("spacy.batch_by_words.v1")
|
||||
def configure_minibatch_by_words(
|
||||
*,
|
||||
size: Sizing,
|
||||
|
@ -82,7 +80,6 @@ def configure_minibatch_by_words(
|
|||
)
|
||||
|
||||
|
||||
@registry.batchers("spacy.batch_by_sequence.v1")
|
||||
def configure_minibatch(
|
||||
size: Sizing, get_length: Optional[Callable[[ItemT], int]] = None
|
||||
) -> BatcherT:
|
||||
|
|
|
@ -7,7 +7,6 @@ if TYPE_CHECKING:
|
|||
from ..language import Language
|
||||
|
||||
|
||||
@registry.callbacks("spacy.copy_from_base_model.v1")
|
||||
def create_copy_from_base_model(
|
||||
tokenizer: Optional[str] = None,
|
||||
vocab: Optional[str] = None,
|
||||
|
|
|
@ -29,7 +29,6 @@ def setup_table(
|
|||
|
||||
# We cannot rename this method as it's directly imported
|
||||
# and used by external packages such as spacy-loggers.
|
||||
@registry.loggers("spacy.ConsoleLogger.v2")
|
||||
def console_logger(
|
||||
progress_bar: bool = False,
|
||||
console_output: bool = True,
|
||||
|
@ -47,7 +46,6 @@ def console_logger(
|
|||
)
|
||||
|
||||
|
||||
@registry.loggers("spacy.ConsoleLogger.v3")
|
||||
def console_logger_v3(
|
||||
progress_bar: Optional[str] = None,
|
||||
console_output: bool = True,
|
||||
|
|
|
@ -132,9 +132,18 @@ class registry(thinc.registry):
|
|||
models = catalogue.create("spacy", "models", entry_points=True)
|
||||
cli = catalogue.create("spacy", "cli", entry_points=True)
|
||||
|
||||
@classmethod
|
||||
def ensure_populated(cls) -> None:
|
||||
"""Ensure the registry is populated with all necessary components."""
|
||||
from .registrations import REGISTRY_POPULATED, populate_registry
|
||||
|
||||
if not REGISTRY_POPULATED:
|
||||
populate_registry()
|
||||
|
||||
@classmethod
|
||||
def get_registry_names(cls) -> List[str]:
|
||||
"""List all available registries."""
|
||||
cls.ensure_populated()
|
||||
names = []
|
||||
for name, value in inspect.getmembers(cls):
|
||||
if not name.startswith("_") and isinstance(value, Registry):
|
||||
|
@ -144,6 +153,7 @@ class registry(thinc.registry):
|
|||
@classmethod
|
||||
def get(cls, registry_name: str, func_name: str) -> Callable:
|
||||
"""Get a registered function from the registry."""
|
||||
cls.ensure_populated()
|
||||
# We're overwriting this classmethod so we're able to provide more
|
||||
# specific error messages and implement a fallback to spacy-legacy.
|
||||
if not hasattr(cls, registry_name):
|
||||
|
@ -179,6 +189,7 @@ class registry(thinc.registry):
|
|||
func_name (str): Name of the registered function.
|
||||
RETURNS (Dict[str, Optional[Union[str, int]]]): The function info.
|
||||
"""
|
||||
cls.ensure_populated()
|
||||
# We're overwriting this classmethod so we're able to provide more
|
||||
# specific error messages and implement a fallback to spacy-legacy.
|
||||
if not hasattr(cls, registry_name):
|
||||
|
@ -205,6 +216,7 @@ class registry(thinc.registry):
|
|||
@classmethod
|
||||
def has(cls, registry_name: str, func_name: str) -> bool:
|
||||
"""Check whether a function is available in a registry."""
|
||||
cls.ensure_populated()
|
||||
if not hasattr(cls, registry_name):
|
||||
return False
|
||||
reg = getattr(cls, registry_name)
|
||||
|
@ -1323,7 +1335,6 @@ def filter_chain_spans(*spans: Iterable["Span"]) -> List["Span"]:
|
|||
return filter_spans(itertools.chain(*spans))
|
||||
|
||||
|
||||
@registry.misc("spacy.first_longest_spans_filter.v1")
|
||||
def make_first_longest_spans_filter():
|
||||
return filter_chain_spans
|
||||
|
||||
|
|
|
@ -177,7 +177,7 @@ cdef class Vectors(BaseVectors):
|
|||
self.hash_seed = hash_seed
|
||||
self.bow = bow
|
||||
self.eow = eow
|
||||
if isinstance(attr, (int, long)):
|
||||
if isinstance(attr, int):
|
||||
self.attr = attr
|
||||
else:
|
||||
attr = attr.upper()
|
||||
|
|
Loading…
Reference in New Issue
Block a user