diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3da2b63d8..6ee1b8af4 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -45,11 +45,12 @@ jobs: run: | python -m pip install flake8==5.0.4 python -m flake8 spacy --count --select=E901,E999,F821,F822,F823,W605 --show-source --statistics - - name: cython-lint - run: | - python -m pip install cython-lint -c requirements.txt - # E501: line too log, W291: trailing whitespace, E266: too many leading '#' for block comment - cython-lint spacy --ignore E501,W291,E266 + # Unfortunately cython-lint isn't working after the shift to Cython 3. + #- name: cython-lint + # run: | + # python -m pip install cython-lint -c requirements.txt + # # E501: line too long, W291: trailing whitespace, E266: too many leading '#' for block comment + # cython-lint spacy --ignore E501,W291,E266 tests: name: Test @@ -58,7 +59,7 @@ jobs: fail-fast: true matrix: os: [ubuntu-latest, windows-latest, macos-latest] - python_version: ["3.9", "3.12"] + python_version: ["3.9", "3.12", "3.13"] runs-on: ${{ matrix.os }} diff --git a/MANIFEST.in b/MANIFEST.in index 8ded6f808..1caf75846 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -4,5 +4,6 @@ include README.md include pyproject.toml include spacy/py.typed recursive-include spacy/cli *.yml +recursive-include spacy/tests *.json recursive-include licenses * recursive-exclude spacy *.cpp diff --git a/pyproject.toml b/pyproject.toml index edebbff52..06289ccab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [build-system] requires = [ "setuptools", - "cython>=0.25,<3.0", + "cython>=3.0,<4.0", "cymem>=2.0.2,<2.1.0", "preshed>=3.0.2,<3.1.0", "murmurhash>=0.28.0,<1.1.0", diff --git a/requirements.txt b/requirements.txt index bfdcf0d96..4f4383300 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ setuptools packaging>=20.0 # Development dependencies pre-commit>=2.13.0 -cython>=0.25,<3.0 +cython>=3.0,<4.0 pytest>=5.2.0,!=7.1.0 pytest-timeout>=1.3.0,<2.0.0 mock>=2.0.0,<3.0.0 diff --git a/setup.cfg b/setup.cfg index d43a782d5..bc7b6e9d7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,11 +30,11 @@ project_urls = [options] zip_safe = false include_package_data = true -python_requires = >=3.9,<3.13 +python_requires = >=3.9,<3.14 # NOTE: This section is superseded by pyproject.toml and will be removed in # spaCy v4 setup_requires = - cython>=0.25,<3.0 + cython>=3.0,<4.0 numpy>=2.0.0,<3.0.0; python_version < "3.9" numpy>=2.0.0,<3.0.0; python_version >= "3.9" # We also need our Cython packages here to compile against diff --git a/spacy/__init__.py b/spacy/__init__.py index 1a18ad0d5..8bb8b4949 100644 --- a/spacy/__init__.py +++ b/spacy/__init__.py @@ -17,6 +17,7 @@ from .cli.info import info # noqa: F401 from .errors import Errors from .glossary import explain # noqa: F401 from .language import Language +from .registrations import REGISTRY_POPULATED, populate_registry from .util import logger, registry # noqa: F401 from .vocab import Vocab diff --git a/spacy/lang/ja/__init__.py b/spacy/lang/ja/__init__.py index 0d5f97ac8..e21e85cd9 100644 --- a/spacy/lang/ja/__init__.py +++ b/spacy/lang/ja/__init__.py @@ -32,7 +32,6 @@ split_mode = null """ -@registry.tokenizers("spacy.ja.JapaneseTokenizer") def create_tokenizer(split_mode: Optional[str] = None): def japanese_tokenizer_factory(nlp): return JapaneseTokenizer(nlp.vocab, split_mode=split_mode) diff --git a/spacy/lang/ko/__init__.py b/spacy/lang/ko/__init__.py index e2c860f7d..3231e191a 100644 ---
a/spacy/lang/ko/__init__.py +++ b/spacy/lang/ko/__init__.py @@ -20,7 +20,6 @@ DEFAULT_CONFIG = """ """ -@registry.tokenizers("spacy.ko.KoreanTokenizer") def create_tokenizer(): def korean_tokenizer_factory(nlp): return KoreanTokenizer(nlp.vocab) diff --git a/spacy/lang/th/__init__.py b/spacy/lang/th/__init__.py index bd29d32a4..551f50eee 100644 --- a/spacy/lang/th/__init__.py +++ b/spacy/lang/th/__init__.py @@ -13,7 +13,6 @@ DEFAULT_CONFIG = """ """ -@registry.tokenizers("spacy.th.ThaiTokenizer") def create_thai_tokenizer(): def thai_tokenizer_factory(nlp): return ThaiTokenizer(nlp.vocab) diff --git a/spacy/lang/vi/__init__.py b/spacy/lang/vi/__init__.py index a621b8bfe..ae1fa469d 100644 --- a/spacy/lang/vi/__init__.py +++ b/spacy/lang/vi/__init__.py @@ -22,7 +22,6 @@ use_pyvi = true """ -@registry.tokenizers("spacy.vi.VietnameseTokenizer") def create_vietnamese_tokenizer(use_pyvi: bool = True): def vietnamese_tokenizer_factory(nlp): return VietnameseTokenizer(nlp.vocab, use_pyvi=use_pyvi) diff --git a/spacy/lang/zh/__init__.py b/spacy/lang/zh/__init__.py index f7bb09277..6ad044c60 100644 --- a/spacy/lang/zh/__init__.py +++ b/spacy/lang/zh/__init__.py @@ -46,7 +46,6 @@ class Segmenter(str, Enum): return list(cls.__members__.keys()) -@registry.tokenizers("spacy.zh.ChineseTokenizer") def create_chinese_tokenizer(segmenter: Segmenter = Segmenter.char): def chinese_tokenizer_factory(nlp): return ChineseTokenizer(nlp.vocab, segmenter=segmenter) diff --git a/spacy/language.py b/spacy/language.py index 93840c922..9cdd724f5 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -104,7 +104,6 @@ class BaseDefaults: writing_system = {"direction": "ltr", "has_case": True, "has_letters": True} -@registry.tokenizers("spacy.Tokenizer.v1") def create_tokenizer() -> Callable[["Language"], Tokenizer]: """Registered function to create a tokenizer. Returns a factory that takes the nlp object and returns a Tokenizer instance using the language detaults. @@ -130,7 +129,6 @@ def create_tokenizer() -> Callable[["Language"], Tokenizer]: return tokenizer_factory -@registry.misc("spacy.LookupsDataLoader.v1") def load_lookups_data(lang, tables): util.logger.debug("Loading lookups from spacy-lookups-data: %s", tables) lookups = load_lookups(lang=lang, tables=tables) @@ -185,6 +183,9 @@ class Language: DOCS: https://spacy.io/api/language#init """ + from .pipeline.factories import register_factories + + register_factories() # We're only calling this to import all factories provided via entry # points. The factory decorator applied to these functions takes care # of the rest. 
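The decorator removals above (spacy.ja.JapaneseTokenizer, spacy.ko.KoreanTokenizer, spacy.Tokenizer.v1, spacy.LookupsDataLoader.v1, and the many that follow) all apply one pattern: registration moves out of import-time decorators into an explicit, idempotent populate_registry() in the new spacy/registrations.py, which spacy/__init__.py now imports. A minimal sketch of that pattern, assuming catalogue's standard register(name, func=...) API; the body is illustrative, not the actual spaCy module:

    # Hypothetical miniature of spacy/registrations.py: register built-ins
    # once, on demand, instead of via decorators at module import time.
    from .util import registry

    REGISTRY_POPULATED = False

    def populate_registry() -> None:
        """Register built-in functions under their original names."""
        global REGISTRY_POPULATED
        if REGISTRY_POPULATED:
            return  # idempotent, so every entry point may call it freely
        # Imports are deferred so that `import spacy` stays cheap.
        from .lang.ja import create_tokenizer as create_ja_tokenizer
        from .language import create_tokenizer, load_lookups_data

        # Plain calls replace the removed @registry.* decorators.
        registry.tokenizers.register("spacy.ja.JapaneseTokenizer", func=create_ja_tokenizer)
        registry.tokenizers.register("spacy.Tokenizer.v1", func=create_tokenizer)
        registry.misc.register("spacy.LookupsDataLoader.v1", func=load_lookups_data)
        REGISTRY_POPULATED = True

Deferring the imports inside populate_registry() is what lets the package avoid pulling in every language and model module up front.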
diff --git a/spacy/lexeme.pxd b/spacy/lexeme.pxd index ff2e4f92e..a16a14f76 100644 --- a/spacy/lexeme.pxd +++ b/spacy/lexeme.pxd @@ -35,7 +35,7 @@ cdef class Lexeme: return self @staticmethod - cdef inline void set_struct_attr(LexemeC* lex, attr_id_t name, attr_t value) nogil: + cdef inline void set_struct_attr(LexemeC* lex, attr_id_t name, attr_t value) noexcept nogil: if name < (sizeof(flags_t) * 8): Lexeme.c_set_flag(lex, name, value) elif name == ID: @@ -54,7 +54,7 @@ cdef class Lexeme: lex.lang = value @staticmethod - cdef inline attr_t get_struct_attr(const LexemeC* lex, attr_id_t feat_name) nogil: + cdef inline attr_t get_struct_attr(const LexemeC* lex, attr_id_t feat_name) noexcept nogil: if feat_name < (sizeof(flags_t) * 8): if Lexeme.c_check_flag(lex, feat_name): return 1 @@ -82,7 +82,7 @@ cdef class Lexeme: return 0 @staticmethod - cdef inline bint c_check_flag(const LexemeC* lexeme, attr_id_t flag_id) nogil: + cdef inline bint c_check_flag(const LexemeC* lexeme, attr_id_t flag_id) noexcept nogil: cdef flags_t one = 1 if lexeme.flags & (one << flag_id): return True @@ -90,7 +90,7 @@ cdef class Lexeme: return False @staticmethod - cdef inline bint c_set_flag(LexemeC* lex, attr_id_t flag_id, bint value) nogil: + cdef inline bint c_set_flag(LexemeC* lex, attr_id_t flag_id, bint value) noexcept nogil: cdef flags_t one = 1 if value: lex.flags |= one << flag_id diff --git a/spacy/lexeme.pyx b/spacy/lexeme.pyx index 7a0c19bf3..8886dde01 100644 --- a/spacy/lexeme.pyx +++ b/spacy/lexeme.pyx @@ -70,7 +70,7 @@ cdef class Lexeme: if isinstance(other, Lexeme): a = self.orth b = other.orth - elif isinstance(other, long): + elif isinstance(other, int): a = self.orth b = other elif isinstance(other, str): @@ -104,7 +104,7 @@ cdef class Lexeme: # skip PROB, e.g. from lexemes.jsonl if isinstance(value, float): continue - elif isinstance(value, (int, long)): + elif isinstance(value, int): Lexeme.set_struct_attr(self.c, attr, value) else: Lexeme.set_struct_attr(self.c, attr, self.vocab.strings.add(value)) diff --git a/spacy/matcher/levenshtein.pyx b/spacy/matcher/levenshtein.pyx index e394f2cf4..1bafdbbcb 100644 --- a/spacy/matcher/levenshtein.pyx +++ b/spacy/matcher/levenshtein.pyx @@ -1,4 +1,4 @@ -# cython: binding=True, infer_types=True +# cython: binding=True, infer_types=True, language_level=3 from cpython.object cimport PyObject from libc.stdint cimport int64_t @@ -27,6 +27,5 @@ cpdef bint levenshtein_compare(input_text: str, pattern_text: str, fuzzy: int = return levenshtein(input_text, pattern_text, max_edits) <= max_edits -@registry.misc("spacy.levenshtein_compare.v1") def make_levenshtein_compare(): return levenshtein_compare diff --git a/spacy/matcher/matcher.pyx b/spacy/matcher/matcher.pyx index 9a9ed4212..64c26c82a 100644 --- a/spacy/matcher/matcher.pyx +++ b/spacy/matcher/matcher.pyx @@ -625,7 +625,7 @@ cdef action_t get_action( const TokenC * token, const attr_t * extra_attrs, const int8_t * predicate_matches -) nogil: +) noexcept nogil: """We need to consider: a) Does the token match the specification? [Yes, No] b) What's the quantifier? [1, 0+, ?] 
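The recurring nogil to noexcept nogil edit (here and throughout the parser internals below) tracks a Cython 3 semantics change: cdef functions with C return types now propagate exceptions by default, so every call site checks an error indicator, and raising from nogil code means re-acquiring the GIL. Marking the hot-path helpers noexcept restores the old 0.x contract of a plain C call. Likewise, under language_level=3 the Python 2 long builtin no longer exists, so isinstance(x, (int, long)) collapses to isinstance(x, int). A toy Cython contrast, not taken from this diff:

    # cython: language_level=3

    cdef int checked_div(int a, int b):
        # Cython 3 default: ZeroDivisionError can propagate, so the
        # generated C code tests an error indicator after each call.
        return a // b

    cdef int fast_scale(int x) noexcept nogil:
        # noexcept: exceptions are impossible here, so this compiles to
        # a bare C call that is legal without holding the GIL.
        return x * 2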
@@ -740,7 +740,7 @@ cdef int8_t get_is_match( const TokenC* token, const attr_t* extra_attrs, const int8_t* predicate_matches -) nogil: +) noexcept nogil: for i in range(state.pattern.nr_py): if predicate_matches[state.pattern.py_predicates[i]] == -1: return 0 @@ -755,14 +755,14 @@ cdef int8_t get_is_match( return True -cdef inline int8_t get_is_final(PatternStateC state) nogil: +cdef inline int8_t get_is_final(PatternStateC state) noexcept nogil: if state.pattern[1].quantifier == FINAL_ID: return 1 else: return 0 -cdef inline int8_t get_quantifier(PatternStateC state) nogil: +cdef inline int8_t get_quantifier(PatternStateC state) noexcept nogil: return state.pattern.quantifier @@ -805,7 +805,7 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs) return pattern -cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil: +cdef attr_t get_ent_id(const TokenPatternC* pattern) noexcept nogil: while pattern.quantifier != FINAL_ID: pattern += 1 id_attr = pattern[0].attrs[0] diff --git a/spacy/matcher/phrasematcher.pyx b/spacy/matcher/phrasematcher.pyx index 4efcdb05c..ccc830e35 100644 --- a/spacy/matcher/phrasematcher.pyx +++ b/spacy/matcher/phrasematcher.pyx @@ -47,7 +47,7 @@ cdef class PhraseMatcher: self._terminal_hash = 826361138722620965 map_init(self.mem, self.c_map, 8) - if isinstance(attr, (int, long)): + if isinstance(attr, int): self.attr = attr else: if attr is None: diff --git a/spacy/ml/_character_embed.py b/spacy/ml/_character_embed.py index 89c836144..fde73f35b 100644 --- a/spacy/ml/_character_embed.py +++ b/spacy/ml/_character_embed.py @@ -7,7 +7,6 @@ from ..tokens import Doc from ..util import registry -@registry.layers("spacy.CharEmbed.v1") def CharacterEmbed(nM: int, nC: int) -> Model[List[Doc], List[Floats2d]]: # nM: Number of dimensions per character. nC: Number of characters. 
return Model( diff --git a/spacy/ml/_precomputable_affine.py b/spacy/ml/_precomputable_affine.py index 1c20c622b..cdcac0c38 100644 --- a/spacy/ml/_precomputable_affine.py +++ b/spacy/ml/_precomputable_affine.py @@ -3,7 +3,6 @@ from thinc.api import Model, normal_init from ..util import registry -@registry.layers("spacy.PrecomputableAffine.v1") def PrecomputableAffine(nO, nI, nF, nP, dropout=0.1): model = Model( "precomputable_affine", diff --git a/spacy/ml/callbacks.py b/spacy/ml/callbacks.py index e2378a7ba..fefb170ba 100644 --- a/spacy/ml/callbacks.py +++ b/spacy/ml/callbacks.py @@ -50,7 +50,6 @@ def models_with_nvtx_range(nlp, forward_color: int, backprop_color: int): return nlp -@registry.callbacks("spacy.models_with_nvtx_range.v1") def create_models_with_nvtx_range( forward_color: int = -1, backprop_color: int = -1 ) -> Callable[["Language"], "Language"]: @@ -110,7 +109,6 @@ def pipes_with_nvtx_range( return nlp -@registry.callbacks("spacy.models_and_pipes_with_nvtx_range.v1") def create_models_and_pipes_with_nvtx_range( forward_color: int = -1, backprop_color: int = -1, diff --git a/spacy/ml/extract_ngrams.py b/spacy/ml/extract_ngrams.py index ce7c585cc..d57197312 100644 --- a/spacy/ml/extract_ngrams.py +++ b/spacy/ml/extract_ngrams.py @@ -4,7 +4,6 @@ from ..attrs import LOWER from ..util import registry -@registry.layers("spacy.extract_ngrams.v1") def extract_ngrams(ngram_size: int, attr: int = LOWER) -> Model: model: Model = Model("extract_ngrams", forward) model.attrs["ngram_size"] = ngram_size diff --git a/spacy/ml/extract_spans.py b/spacy/ml/extract_spans.py index ac0f5fa1b..d3456b705 100644 --- a/spacy/ml/extract_spans.py +++ b/spacy/ml/extract_spans.py @@ -6,7 +6,6 @@ from thinc.types import Ints1d, Ragged from ..util import registry -@registry.layers("spacy.extract_spans.v1") def extract_spans() -> Model[Tuple[Ragged, Ragged], Ragged]: """Extract spans from a sequence of source arrays, as specified by an array of (start, end) indices. 
The output is a ragged array of the diff --git a/spacy/ml/featureextractor.py b/spacy/ml/featureextractor.py index 06f1ff51a..2f869ad65 100644 --- a/spacy/ml/featureextractor.py +++ b/spacy/ml/featureextractor.py @@ -6,8 +6,9 @@ from thinc.types import Ints2d from ..tokens import Doc -@registry.layers("spacy.FeatureExtractor.v1") -def FeatureExtractor(columns: List[Union[int, str]]) -> Model[List[Doc], List[Ints2d]]: +def FeatureExtractor( + columns: Union[List[str], List[int], List[Union[int, str]]] +) -> Model[List[Doc], List[Ints2d]]: return Model("extract_features", forward, attrs={"columns": columns}) diff --git a/spacy/ml/models/entity_linker.py b/spacy/ml/models/entity_linker.py index b7100c00a..752d1c443 100644 --- a/spacy/ml/models/entity_linker.py +++ b/spacy/ml/models/entity_linker.py @@ -28,7 +28,6 @@ from ...vocab import Vocab from ..extract_spans import extract_spans -@registry.architectures("spacy.EntityLinker.v2") def build_nel_encoder( tok2vec: Model, nO: Optional[int] = None ) -> Model[List[Doc], Floats2d]: @@ -92,7 +91,6 @@ def span_maker_forward(model, docs: List[Doc], is_train) -> Tuple[Ragged, Callab return out, lambda x: [] -@registry.misc("spacy.KBFromFile.v1") def load_kb( kb_path: Path, ) -> Callable[[Vocab], KnowledgeBase]: @@ -104,7 +102,6 @@ def load_kb( return kb_from_file -@registry.misc("spacy.EmptyKB.v2") def empty_kb_for_config() -> Callable[[Vocab, int], KnowledgeBase]: def empty_kb_factory(vocab: Vocab, entity_vector_length: int): return InMemoryLookupKB(vocab=vocab, entity_vector_length=entity_vector_length) @@ -112,7 +109,6 @@ def empty_kb_for_config() -> Callable[[Vocab, int], KnowledgeBase]: return empty_kb_factory -@registry.misc("spacy.EmptyKB.v1") def empty_kb( entity_vector_length: int, ) -> Callable[[Vocab], KnowledgeBase]: @@ -122,12 +118,10 @@ def empty_kb( return empty_kb_factory -@registry.misc("spacy.CandidateGenerator.v1") def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]: return get_candidates -@registry.misc("spacy.CandidateBatchGenerator.v1") def create_candidates_batch() -> Callable[ [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] ]: diff --git a/spacy/ml/models/multi_task.py b/spacy/ml/models/multi_task.py index b7faf1cd7..7c68fe481 100644 --- a/spacy/ml/models/multi_task.py +++ b/spacy/ml/models/multi_task.py @@ -30,7 +30,6 @@ if TYPE_CHECKING: from ...vocab import Vocab # noqa: F401 -@registry.architectures("spacy.PretrainVectors.v1") def create_pretrain_vectors( maxout_pieces: int, hidden_size: int, loss: str ) -> Callable[["Vocab", Model], Model]: @@ -57,7 +56,6 @@ def create_pretrain_vectors( return create_vectors_objective -@registry.architectures("spacy.PretrainCharacters.v1") def create_pretrain_characters( maxout_pieces: int, hidden_size: int, n_characters: int ) -> Callable[["Vocab", Model], Model]: diff --git a/spacy/ml/models/parser.py b/spacy/ml/models/parser.py index f6c0e565d..9ff0ac8ba 100644 --- a/spacy/ml/models/parser.py +++ b/spacy/ml/models/parser.py @@ -11,7 +11,6 @@ from .._precomputable_affine import PrecomputableAffine from ..tb_framework import TransitionModel -@registry.architectures("spacy.TransitionBasedParser.v2") def build_tb_parser_model( tok2vec: Model[List[Doc], List[Floats2d]], state_type: Literal["parser", "ner"], diff --git a/spacy/ml/models/span_finder.py b/spacy/ml/models/span_finder.py index d327fc761..8081ed92b 100644 --- a/spacy/ml/models/span_finder.py +++ b/spacy/ml/models/span_finder.py @@ -10,7 +10,6 @@ InT = List[Doc] OutT = Floats2d 
-@registry.architectures("spacy.SpanFinder.v1") def build_finder_model( tok2vec: Model[InT, List[Floats2d]], scorer: Model[OutT, OutT] ) -> Model[InT, OutT]: diff --git a/spacy/ml/models/spancat.py b/spacy/ml/models/spancat.py index 140ec553a..91dfb41ed 100644 --- a/spacy/ml/models/spancat.py +++ b/spacy/ml/models/spancat.py @@ -22,7 +22,6 @@ from ...util import registry from ..extract_spans import extract_spans -@registry.layers("spacy.LinearLogistic.v1") def build_linear_logistic(nO=None, nI=None) -> Model[Floats2d, Floats2d]: """An output layer for multi-label classification. It uses a linear layer followed by a logistic activation. @@ -30,7 +29,6 @@ def build_linear_logistic(nO=None, nI=None) -> Model[Floats2d, Floats2d]: return chain(Linear(nO=nO, nI=nI, init_W=glorot_uniform_init), Logistic()) -@registry.layers("spacy.mean_max_reducer.v1") def build_mean_max_reducer(hidden_size: int) -> Model[Ragged, Floats2d]: """Reduce sequences by concatenating their mean and max pooled vectors, and then combine the concatenated vectors with a hidden layer. @@ -46,7 +44,6 @@ def build_mean_max_reducer(hidden_size: int) -> Model[Ragged, Floats2d]: ) -@registry.architectures("spacy.SpanCategorizer.v1") def build_spancat_model( tok2vec: Model[List[Doc], List[Floats2d]], reducer: Model[Ragged, Floats2d], diff --git a/spacy/ml/models/tagger.py b/spacy/ml/models/tagger.py index 8f1554fab..aec4276db 100644 --- a/spacy/ml/models/tagger.py +++ b/spacy/ml/models/tagger.py @@ -7,7 +7,6 @@ from ...tokens import Doc from ...util import registry -@registry.architectures("spacy.Tagger.v2") def build_tagger_model( tok2vec: Model[List[Doc], List[Floats2d]], nO: Optional[int] = None, normalize=False ) -> Model[List[Doc], List[Floats2d]]: diff --git a/spacy/ml/models/textcat.py b/spacy/ml/models/textcat.py index 601c94a7f..49c0dd707 100644 --- a/spacy/ml/models/textcat.py +++ b/spacy/ml/models/textcat.py @@ -44,7 +44,6 @@ from .tok2vec import get_tok2vec_width NEG_VALUE = -5000 -@registry.architectures("spacy.TextCatCNN.v2") def build_simple_cnn_text_classifier( tok2vec: Model, exclusive_classes: bool, nO: Optional[int] = None ) -> Model[List[Doc], Floats2d]: @@ -72,7 +71,6 @@ def resize_and_set_ref(model, new_nO, resizable_layer): return model -@registry.architectures("spacy.TextCatBOW.v2") def build_bow_text_classifier( exclusive_classes: bool, ngram_size: int, @@ -88,7 +86,6 @@ def build_bow_text_classifier( ) -@registry.architectures("spacy.TextCatBOW.v3") def build_bow_text_classifier_v3( exclusive_classes: bool, ngram_size: int, @@ -142,7 +139,6 @@ def _build_bow_text_classifier( return model -@registry.architectures("spacy.TextCatEnsemble.v2") def build_text_classifier_v2( tok2vec: Model[List[Doc], List[Floats2d]], linear_model: Model[List[Doc], Floats2d], @@ -200,7 +196,6 @@ def init_ensemble_textcat(model, X, Y) -> Model: return model -@registry.architectures("spacy.TextCatLowData.v1") def build_text_classifier_lowdata( width: int, dropout: Optional[float], nO: Optional[int] = None ) -> Model[List[Doc], Floats2d]: @@ -221,7 +216,6 @@ def build_text_classifier_lowdata( return model -@registry.architectures("spacy.TextCatParametricAttention.v1") def build_textcat_parametric_attention_v1( tok2vec: Model[List[Doc], List[Floats2d]], exclusive_classes: bool, @@ -294,7 +288,6 @@ def _init_parametric_attention_with_residual_nonlinear(model, X, Y) -> Model: return model -@registry.architectures("spacy.TextCatReduce.v1") def build_reduce_text_classifier( tok2vec: Model, exclusive_classes: bool, diff --git 
a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py index 0edc89991..b2b803b6e 100644 --- a/spacy/ml/models/tok2vec.py +++ b/spacy/ml/models/tok2vec.py @@ -29,7 +29,6 @@ from ..featureextractor import FeatureExtractor from ..staticvectors import StaticVectors -@registry.architectures("spacy.Tok2VecListener.v1") def tok2vec_listener_v1(width: int, upstream: str = "*"): tok2vec = Tok2VecListener(upstream_name=upstream, width=width) return tok2vec @@ -46,7 +45,6 @@ def get_tok2vec_width(model: Model): return nO -@registry.architectures("spacy.HashEmbedCNN.v2") def build_hash_embed_cnn_tok2vec( *, width: int, @@ -102,7 +100,6 @@ def build_hash_embed_cnn_tok2vec( ) -@registry.architectures("spacy.Tok2Vec.v2") def build_Tok2Vec_model( embed: Model[List[Doc], List[Floats2d]], encode: Model[List[Floats2d], List[Floats2d]], @@ -123,10 +120,9 @@ def build_Tok2Vec_model( return tok2vec -@registry.architectures("spacy.MultiHashEmbed.v2") def MultiHashEmbed( width: int, - attrs: List[Union[str, int]], + attrs: Union[List[str], List[int], List[Union[str, int]]], rows: List[int], include_static_vectors: bool, ) -> Model[List[Doc], List[Floats2d]]: @@ -192,7 +188,7 @@ def MultiHashEmbed( ) else: model = chain( - FeatureExtractor(list(attrs)), + FeatureExtractor(attrs), cast(Model[List[Ints2d], Ragged], list2ragged()), with_array(concatenate(*embeddings)), max_out, @@ -201,7 +197,6 @@ def MultiHashEmbed( return model -@registry.architectures("spacy.CharacterEmbed.v2") def CharacterEmbed( width: int, rows: int, @@ -278,7 +273,6 @@ def CharacterEmbed( return model -@registry.architectures("spacy.MaxoutWindowEncoder.v2") def MaxoutWindowEncoder( width: int, window_size: int, maxout_pieces: int, depth: int ) -> Model[List[Floats2d], List[Floats2d]]: @@ -310,7 +304,6 @@ def MaxoutWindowEncoder( return with_array(model, pad=receptive_field) -@registry.architectures("spacy.MishWindowEncoder.v2") def MishWindowEncoder( width: int, window_size: int, depth: int ) -> Model[List[Floats2d], List[Floats2d]]: @@ -333,7 +326,6 @@ def MishWindowEncoder( return with_array(model) -@registry.architectures("spacy.TorchBiLSTMEncoder.v1") def BiLSTMEncoder( width: int, depth: int, dropout: float ) -> Model[List[Floats2d], List[Floats2d]]: diff --git a/spacy/ml/parser_model.pyx b/spacy/ml/parser_model.pyx index f004c562e..96f2487ef 100644 --- a/spacy/ml/parser_model.pyx +++ b/spacy/ml/parser_model.pyx @@ -52,14 +52,14 @@ cdef SizesC get_c_sizes(model, int batch_size) except *: return output -cdef ActivationsC alloc_activations(SizesC n) nogil: +cdef ActivationsC alloc_activations(SizesC n) noexcept nogil: cdef ActivationsC A memset(&A, 0, sizeof(A)) resize_activations(&A, n) return A -cdef void free_activations(const ActivationsC* A) nogil: +cdef void free_activations(const ActivationsC* A) noexcept nogil: free(A.token_ids) free(A.scores) free(A.unmaxed) @@ -67,7 +67,7 @@ cdef void free_activations(const ActivationsC* A) nogil: free(A.is_valid) -cdef void resize_activations(ActivationsC* A, SizesC n) nogil: +cdef void resize_activations(ActivationsC* A, SizesC n) noexcept nogil: if n.states <= A._max_size: A._curr_size = n.states return @@ -100,7 +100,7 @@ cdef void resize_activations(ActivationsC* A, SizesC n) nogil: cdef void predict_states( CBlas cblas, ActivationsC* A, StateC** states, const WeightsC* W, SizesC n -) nogil: +) noexcept nogil: resize_activations(A, n) for i in range(n.states): states[i].set_context_tokens(&A.token_ids[i*n.feats], n.feats) @@ -159,7 +159,7 @@ cdef void sum_state_features( int B, int F, int 
O -) nogil: +) noexcept nogil: cdef int idx, b, f cdef const float* feature padding = cached @@ -183,7 +183,7 @@ cdef void cpu_log_loss( const int* is_valid, const float* scores, int O -) nogil: +) noexcept nogil: """Do multi-label log loss""" cdef double max_, gmax, Z, gZ best = arg_max_if_gold(scores, costs, is_valid, O) @@ -209,7 +209,7 @@ cdef void cpu_log_loss( cdef int arg_max_if_gold( const weight_t* scores, const weight_t* costs, const int* is_valid, int n -) nogil: +) noexcept nogil: # Find minimum cost cdef float cost = 1 for i in range(n): @@ -224,7 +224,7 @@ cdef int arg_max_if_gold( return best -cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil: +cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) noexcept nogil: cdef int best = -1 for i in range(n): if is_valid[i] >= 1: diff --git a/spacy/ml/staticvectors.py b/spacy/ml/staticvectors.py index 1a1b0a0ff..122ef3795 100644 --- a/spacy/ml/staticvectors.py +++ b/spacy/ml/staticvectors.py @@ -13,7 +13,6 @@ from ..vectors import Mode, Vectors from ..vocab import Vocab -@registry.layers("spacy.StaticVectors.v2") def StaticVectors( nO: Optional[int] = None, nM: Optional[int] = None, diff --git a/spacy/ml/tb_framework.py b/spacy/ml/tb_framework.py index e351ad4e5..16c894f6c 100644 --- a/spacy/ml/tb_framework.py +++ b/spacy/ml/tb_framework.py @@ -4,7 +4,6 @@ from ..util import registry from .parser_model import ParserStepModel -@registry.layers("spacy.TransitionModel.v1") def TransitionModel( tok2vec, lower, upper, resize_output, dropout=0.2, unseen_classes=set() ): diff --git a/spacy/parts_of_speech.pyx b/spacy/parts_of_speech.pyx index 98e3570ec..1e643c099 100644 --- a/spacy/parts_of_speech.pyx +++ b/spacy/parts_of_speech.pyx @@ -25,3 +25,8 @@ IDS = { NAMES = {value: key for key, value in IDS.items()} + +# As of Cython 3.1, the global Python namespace no longer has the enum +# contents by default. 
+globals().update(IDS) + diff --git a/spacy/pipeline/_parser_internals/_state.pxd b/spacy/pipeline/_parser_internals/_state.pxd index c063cf97c..ea1a7874b 100644 --- a/spacy/pipeline/_parser_internals/_state.pxd +++ b/spacy/pipeline/_parser_internals/_state.pxd @@ -17,7 +17,7 @@ from ...typedefs cimport attr_t from ...vocab cimport EMPTY_LEXEME -cdef inline bint is_space_token(const TokenC* token) nogil: +cdef inline bint is_space_token(const TokenC* token) noexcept nogil: return Lexeme.c_check_flag(token.lex, IS_SPACE) cdef struct ArcC: @@ -41,7 +41,7 @@ cdef cppclass StateC: int offset int _b_i - __init__(const TokenC* sent, int length) nogil: + inline __init__(const TokenC* sent, int length) noexcept nogil: this._sent = sent this._heads = calloc(length, sizeof(int)) if not (this._sent and this._heads): @@ -57,10 +57,10 @@ cdef cppclass StateC: memset(&this._empty_token, 0, sizeof(TokenC)) this._empty_token.lex = &EMPTY_LEXEME - __dealloc__(): + inline __dealloc__(): free(this._heads) - void set_context_tokens(int* ids, int n) nogil: + inline void set_context_tokens(int* ids, int n) noexcept nogil: cdef int i, j if n == 1: if this.B(0) >= 0: @@ -131,14 +131,14 @@ cdef cppclass StateC: else: ids[i] = -1 - int S(int i) nogil const: + inline int S(int i) noexcept nogil const: if i >= this._stack.size(): return -1 elif i < 0: return -1 return this._stack.at(this._stack.size() - (i+1)) - int B(int i) nogil const: + inline int B(int i) noexcept nogil const: if i < 0: return -1 elif i < this._rebuffer.size(): @@ -150,19 +150,19 @@ cdef cppclass StateC: else: return b_i - const TokenC* B_(int i) nogil const: + inline const TokenC* B_(int i) noexcept nogil const: return this.safe_get(this.B(i)) - const TokenC* E_(int i) nogil const: + inline const TokenC* E_(int i) noexcept nogil const: return this.safe_get(this.E(i)) - const TokenC* safe_get(int i) nogil const: + inline const TokenC* safe_get(int i) noexcept nogil const: if i < 0 or i >= this.length: return &this._empty_token else: return &this._sent[i] - void map_get_arcs(const unordered_map[int, vector[ArcC]] &heads_arcs, vector[ArcC]* out) nogil const: + inline void map_get_arcs(const unordered_map[int, vector[ArcC]] &heads_arcs, vector[ArcC]* out) noexcept nogil const: cdef const vector[ArcC]* arcs head_arcs_it = heads_arcs.const_begin() while head_arcs_it != heads_arcs.const_end(): @@ -175,23 +175,23 @@ cdef cppclass StateC: incr(arcs_it) incr(head_arcs_it) - void get_arcs(vector[ArcC]* out) nogil const: + inline void get_arcs(vector[ArcC]* out) noexcept nogil const: this.map_get_arcs(this._left_arcs, out) this.map_get_arcs(this._right_arcs, out) - int H(int child) nogil const: + inline int H(int child) noexcept nogil const: if child >= this.length or child < 0: return -1 else: return this._heads[child] - int E(int i) nogil const: + inline int E(int i) noexcept nogil const: if this._ents.size() == 0: return -1 else: return this._ents.back().start - int nth_child(const unordered_map[int, vector[ArcC]]& heads_arcs, int head, int idx) nogil const: + inline int nth_child(const unordered_map[int, vector[ArcC]]& heads_arcs, int head, int idx) noexcept nogil const: if idx < 1: return -1 @@ -215,22 +215,22 @@ cdef cppclass StateC: return -1 - int L(int head, int idx) nogil const: + inline int L(int head, int idx) noexcept nogil const: return this.nth_child(this._left_arcs, head, idx) - int R(int head, int idx) nogil const: + inline int R(int head, int idx) noexcept nogil const: return this.nth_child(this._right_arcs, head, idx) - bint empty() nogil 
const: + inline bint empty() noexcept nogil const: return this._stack.size() == 0 - bint eol() nogil const: + inline bint eol() noexcept nogil const: return this.buffer_length() == 0 - bint is_final() nogil const: + inline bint is_final() noexcept nogil const: return this.stack_depth() <= 0 and this.eol() - int cannot_sent_start(int word) nogil const: + inline int cannot_sent_start(int word) noexcept nogil const: if word < 0 or word >= this.length: return 0 elif this._sent[word].sent_start == -1: @@ -238,7 +238,7 @@ cdef cppclass StateC: else: return 0 - int is_sent_start(int word) nogil const: + inline int is_sent_start(int word) noexcept nogil const: if word < 0 or word >= this.length: return 0 elif this._sent[word].sent_start == 1: @@ -248,20 +248,20 @@ cdef cppclass StateC: else: return 0 - void set_sent_start(int word, int value) nogil: + inline void set_sent_start(int word, int value) noexcept nogil: if value >= 1: this._sent_starts.insert(word) - bint has_head(int child) nogil const: + inline bint has_head(int child) noexcept nogil const: return this._heads[child] >= 0 - int l_edge(int word) nogil const: + inline int l_edge(int word) noexcept nogil const: return word - int r_edge(int word) nogil const: + inline int r_edge(int word) noexcept nogil const: return word - int n_arcs(const unordered_map[int, vector[ArcC]] &heads_arcs, int head) nogil const: + inline int n_arcs(const unordered_map[int, vector[ArcC]] &heads_arcs, int head) noexcept nogil const: cdef int n = 0 head_arcs_it = heads_arcs.const_find(head) if head_arcs_it == heads_arcs.const_end(): @@ -277,28 +277,28 @@ cdef cppclass StateC: return n - int n_L(int head) nogil const: + inline int n_L(int head) noexcept nogil const: return n_arcs(this._left_arcs, head) - int n_R(int head) nogil const: + inline int n_R(int head) noexcept nogil const: return n_arcs(this._right_arcs, head) - bint stack_is_connected() nogil const: + inline bint stack_is_connected() noexcept nogil const: return False - bint entity_is_open() nogil const: + inline bint entity_is_open() noexcept nogil const: if this._ents.size() == 0: return False else: return this._ents.back().end == -1 - int stack_depth() nogil const: + inline int stack_depth() noexcept nogil const: return this._stack.size() - int buffer_length() nogil const: + inline int buffer_length() noexcept nogil const: return (this.length - this._b_i) + this._rebuffer.size() - void push() nogil: + inline void push() noexcept nogil: b0 = this.B(0) if this._rebuffer.size(): b0 = this._rebuffer.back() @@ -308,32 +308,32 @@ cdef cppclass StateC: this._b_i += 1 this._stack.push_back(b0) - void pop() nogil: + inline void pop() noexcept nogil: this._stack.pop_back() - void force_final() nogil: + inline void force_final() noexcept nogil: # This should only be used in desperate situations, as it may leave # the analysis in an unexpected state. 
this._stack.clear() this._b_i = this.length - void unshift() nogil: + inline void unshift() noexcept nogil: s0 = this._stack.back() this._unshiftable[s0] = 1 this._rebuffer.push_back(s0) this._stack.pop_back() - int is_unshiftable(int item) nogil const: + inline int is_unshiftable(int item) noexcept nogil const: if item >= this._unshiftable.size(): return 0 else: return this._unshiftable.at(item) - void set_reshiftable(int item) nogil: + inline void set_reshiftable(int item) noexcept nogil: if item < this._unshiftable.size(): this._unshiftable[item] = 0 - void add_arc(int head, int child, attr_t label) nogil: + inline void add_arc(int head, int child, attr_t label) noexcept nogil: if this.has_head(child): this.del_arc(this.H(child), child) cdef ArcC arc @@ -346,7 +346,7 @@ cdef cppclass StateC: this._right_arcs[arc.head].push_back(arc) this._heads[child] = head - void map_del_arc(unordered_map[int, vector[ArcC]]* heads_arcs, int h_i, int c_i) nogil: + inline void map_del_arc(unordered_map[int, vector[ArcC]]* heads_arcs, int h_i, int c_i) noexcept nogil: arcs_it = heads_arcs.find(h_i) if arcs_it == heads_arcs.end(): return @@ -367,13 +367,13 @@ cdef cppclass StateC: arc.label = 0 break - void del_arc(int h_i, int c_i) nogil: + inline void del_arc(int h_i, int c_i) noexcept nogil: if h_i > c_i: this.map_del_arc(&this._left_arcs, h_i, c_i) else: this.map_del_arc(&this._right_arcs, h_i, c_i) - SpanC get_ent() nogil const: + inline SpanC get_ent() noexcept nogil const: cdef SpanC ent if this._ents.size() == 0: ent.start = 0 @@ -383,17 +383,17 @@ cdef cppclass StateC: else: return this._ents.back() - void open_ent(attr_t label) nogil: + inline void open_ent(attr_t label) noexcept nogil: cdef SpanC ent ent.start = this.B(0) ent.label = label ent.end = -1 this._ents.push_back(ent) - void close_ent() nogil: + inline void close_ent() noexcept nogil: this._ents.back().end = this.B(0)+1 - void clone(const StateC* src) nogil: + inline void clone(const StateC* src) noexcept nogil: this.length = src.length this._sent = src._sent this._stack = src._stack diff --git a/spacy/pipeline/_parser_internals/arc_eager.pyx b/spacy/pipeline/_parser_internals/arc_eager.pyx index bedaaf9fe..cccd51fca 100644 --- a/spacy/pipeline/_parser_internals/arc_eager.pyx +++ b/spacy/pipeline/_parser_internals/arc_eager.pyx @@ -155,7 +155,7 @@ cdef GoldParseStateC create_gold_state( return gs -cdef void update_gold_state(GoldParseStateC* gs, const StateC* s) nogil: +cdef void update_gold_state(GoldParseStateC* gs, const StateC* s) noexcept nogil: for i in range(gs.length): gs.state_bits[i] = set_state_flag( gs.state_bits[i], @@ -239,12 +239,12 @@ def _get_aligned_sent_starts(example): return [None] * len(example.x) -cdef int check_state_gold(char state_bits, char flag) nogil: +cdef int check_state_gold(char state_bits, char flag) noexcept nogil: cdef char one = 1 return 1 if (state_bits & (one << flag)) else 0 -cdef int set_state_flag(char state_bits, char flag, int value) nogil: +cdef int set_state_flag(char state_bits, char flag, int value) noexcept nogil: cdef char one = 1 if value: return state_bits | (one << flag) @@ -252,27 +252,27 @@ cdef int set_state_flag(char state_bits, char flag, int value) nogil: return state_bits & ~(one << flag) -cdef int is_head_in_stack(const GoldParseStateC* gold, int i) nogil: +cdef int is_head_in_stack(const GoldParseStateC* gold, int i) noexcept nogil: return check_state_gold(gold.state_bits[i], HEAD_IN_STACK) -cdef int is_head_in_buffer(const GoldParseStateC* gold, int i) nogil: +cdef int 
is_head_in_buffer(const GoldParseStateC* gold, int i) noexcept nogil: return check_state_gold(gold.state_bits[i], HEAD_IN_BUFFER) -cdef int is_head_unknown(const GoldParseStateC* gold, int i) nogil: +cdef int is_head_unknown(const GoldParseStateC* gold, int i) noexcept nogil: return check_state_gold(gold.state_bits[i], HEAD_UNKNOWN) -cdef int is_sent_start(const GoldParseStateC* gold, int i) nogil: +cdef int is_sent_start(const GoldParseStateC* gold, int i) noexcept nogil: return check_state_gold(gold.state_bits[i], IS_SENT_START) -cdef int is_sent_start_unknown(const GoldParseStateC* gold, int i) nogil: +cdef int is_sent_start_unknown(const GoldParseStateC* gold, int i) noexcept nogil: return check_state_gold(gold.state_bits[i], SENT_START_UNKNOWN) # Helper functions for the arc-eager oracle -cdef weight_t push_cost(const StateC* state, const GoldParseStateC* gold) nogil: +cdef weight_t push_cost(const StateC* state, const GoldParseStateC* gold) noexcept nogil: cdef weight_t cost = 0 b0 = state.B(0) if b0 < 0: @@ -285,7 +285,7 @@ cdef weight_t push_cost(const StateC* state, const GoldParseStateC* gold) nogil: return cost -cdef weight_t pop_cost(const StateC* state, const GoldParseStateC* gold) nogil: +cdef weight_t pop_cost(const StateC* state, const GoldParseStateC* gold) noexcept nogil: cdef weight_t cost = 0 s0 = state.S(0) if s0 < 0: @@ -296,7 +296,7 @@ cdef weight_t pop_cost(const StateC* state, const GoldParseStateC* gold) nogil: return cost -cdef bint arc_is_gold(const GoldParseStateC* gold, int head, int child) nogil: +cdef bint arc_is_gold(const GoldParseStateC* gold, int head, int child) noexcept nogil: if is_head_unknown(gold, child): return True elif gold.heads[child] == head: @@ -305,7 +305,7 @@ cdef bint arc_is_gold(const GoldParseStateC* gold, int head, int child) nogil: return False -cdef bint label_is_gold(const GoldParseStateC* gold, int child, attr_t label) nogil: +cdef bint label_is_gold(const GoldParseStateC* gold, int child, attr_t label) noexcept nogil: if is_head_unknown(gold, child): return True elif label == 0: @@ -316,7 +316,7 @@ cdef bint label_is_gold(const GoldParseStateC* gold, int child, attr_t label) no return False -cdef bint _is_gold_root(const GoldParseStateC* gold, int word) nogil: +cdef bint _is_gold_root(const GoldParseStateC* gold, int word) noexcept nogil: return gold.heads[word] == word or is_head_unknown(gold, word) @@ -336,7 +336,7 @@ cdef class Shift: * Advance buffer """ @staticmethod - cdef bint is_valid(const StateC* st, attr_t label) nogil: + cdef bint is_valid(const StateC* st, attr_t label) noexcept nogil: if st.stack_depth() == 0: return 1 elif st.buffer_length() < 2: @@ -349,11 +349,11 @@ cdef class Shift: return 1 @staticmethod - cdef int transition(StateC* st, attr_t label) nogil: + cdef int transition(StateC* st, attr_t label) noexcept nogil: st.push() @staticmethod - cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil: + cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) noexcept nogil: gold = _gold return gold.push_cost @@ -375,7 +375,7 @@ cdef class Reduce: cost by those arcs. 
""" @staticmethod - cdef bint is_valid(const StateC* st, attr_t label) nogil: + cdef bint is_valid(const StateC* st, attr_t label) noexcept nogil: if st.stack_depth() == 0: return False elif st.buffer_length() == 0: @@ -386,14 +386,14 @@ cdef class Reduce: return True @staticmethod - cdef int transition(StateC* st, attr_t label) nogil: + cdef int transition(StateC* st, attr_t label) noexcept nogil: if st.has_head(st.S(0)) or st.stack_depth() == 1: st.pop() else: st.unshift() @staticmethod - cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil: + cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) noexcept nogil: gold = _gold if state.is_sent_start(state.B(0)): return 0 @@ -421,7 +421,7 @@ cdef class LeftArc: pop_cost - Arc(B[0], S[0], label) + (Arc(S[1], S[0]) if H(S[0]) else Arcs(S, S[0])) """ @staticmethod - cdef bint is_valid(const StateC* st, attr_t label) nogil: + cdef bint is_valid(const StateC* st, attr_t label) noexcept nogil: if st.stack_depth() == 0: return 0 elif st.buffer_length() == 0: @@ -434,7 +434,7 @@ cdef class LeftArc: return 1 @staticmethod - cdef int transition(StateC* st, attr_t label) nogil: + cdef int transition(StateC* st, attr_t label) noexcept nogil: st.add_arc(st.B(0), st.S(0), label) # If we change the stack, it's okay to remove the shifted mark, as # we can't get in an infinite loop this way. @@ -442,7 +442,7 @@ cdef class LeftArc: st.pop() @staticmethod - cdef inline weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil: + cdef inline weight_t cost(const StateC* state, const void* _gold, attr_t label) noexcept nogil: gold = _gold cdef weight_t cost = gold.pop_cost s0 = state.S(0) @@ -474,7 +474,7 @@ cdef class RightArc: push_cost + (not shifted[b0] and Arc(B[1:], B[0])) - Arc(S[0], B[0], label) """ @staticmethod - cdef bint is_valid(const StateC* st, attr_t label) nogil: + cdef bint is_valid(const StateC* st, attr_t label) noexcept nogil: if st.stack_depth() == 0: return 0 elif st.buffer_length() == 0: @@ -488,12 +488,12 @@ cdef class RightArc: return 1 @staticmethod - cdef int transition(StateC* st, attr_t label) nogil: + cdef int transition(StateC* st, attr_t label) noexcept nogil: st.add_arc(st.S(0), st.B(0), label) st.push() @staticmethod - cdef inline weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil: + cdef inline weight_t cost(const StateC* state, const void* _gold, attr_t label) noexcept nogil: gold = _gold cost = gold.push_cost s0 = state.S(0) @@ -525,7 +525,7 @@ cdef class Break: * Arcs between S and B[1] """ @staticmethod - cdef bint is_valid(const StateC* st, attr_t label) nogil: + cdef bint is_valid(const StateC* st, attr_t label) noexcept nogil: if st.buffer_length() < 2: return False elif st.B(1) != st.B(0) + 1: @@ -538,11 +538,11 @@ cdef class Break: return True @staticmethod - cdef int transition(StateC* st, attr_t label) nogil: + cdef int transition(StateC* st, attr_t label) noexcept nogil: st.set_sent_start(st.B(1), 1) @staticmethod - cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil: + cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) noexcept nogil: gold = _gold cdef int b0 = state.B(0) cdef int cost = 0 @@ -785,7 +785,7 @@ cdef class ArcEager(TransitionSystem): else: return False - cdef int set_valid(self, int* output, const StateC* st) nogil: + cdef int set_valid(self, int* output, const StateC* st) noexcept nogil: cdef int[N_MOVES] is_valid is_valid[SHIFT] = Shift.is_valid(st, 0) 
is_valid[REDUCE] = Reduce.is_valid(st, 0) diff --git a/spacy/pipeline/_parser_internals/ner.pyx b/spacy/pipeline/_parser_internals/ner.pyx index e4312bd2f..84d8ed220 100644 --- a/spacy/pipeline/_parser_internals/ner.pyx +++ b/spacy/pipeline/_parser_internals/ner.pyx @@ -110,7 +110,7 @@ cdef void update_gold_state(GoldNERStateC* gs, const StateC* state) except *: cdef do_func_t[N_MOVES] do_funcs -cdef bint _entity_is_sunk(const StateC* state, Transition* golds) nogil: +cdef bint _entity_is_sunk(const StateC* state, Transition* golds) noexcept nogil: if not state.entity_is_open(): return False @@ -238,7 +238,7 @@ cdef class BiluoPushDown(TransitionSystem): def add_action(self, int action, label_name, freq=None): cdef attr_t label_id - if not isinstance(label_name, (int, long)): + if not isinstance(label_name, int): label_id = self.strings.add(label_name) else: label_id = label_name @@ -347,21 +347,21 @@ cdef class BiluoPushDown(TransitionSystem): cdef class Missing: @staticmethod - cdef bint is_valid(const StateC* st, attr_t label) nogil: + cdef bint is_valid(const StateC* st, attr_t label) noexcept nogil: return False @staticmethod - cdef int transition(StateC* s, attr_t label) nogil: + cdef int transition(StateC* s, attr_t label) noexcept nogil: pass @staticmethod - cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil: + cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) noexcept nogil: return 9000 cdef class Begin: @staticmethod - cdef bint is_valid(const StateC* st, attr_t label) nogil: + cdef bint is_valid(const StateC* st, attr_t label) noexcept nogil: cdef int preset_ent_iob = st.B_(0).ent_iob cdef attr_t preset_ent_label = st.B_(0).ent_type if st.entity_is_open(): @@ -400,13 +400,13 @@ cdef class Begin: return True @staticmethod - cdef int transition(StateC* st, attr_t label) nogil: + cdef int transition(StateC* st, attr_t label) noexcept nogil: st.open_ent(label) st.push() st.pop() @staticmethod - cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil: + cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) noexcept nogil: gold = _gold b0 = s.B(0) cdef int cost = 0 @@ -439,7 +439,7 @@ cdef class Begin: cdef class In: @staticmethod - cdef bint is_valid(const StateC* st, attr_t label) nogil: + cdef bint is_valid(const StateC* st, attr_t label) noexcept nogil: if not st.entity_is_open(): return False if st.buffer_length() < 2: @@ -475,12 +475,12 @@ cdef class In: return True @staticmethod - cdef int transition(StateC* st, attr_t label) nogil: + cdef int transition(StateC* st, attr_t label) noexcept nogil: st.push() st.pop() @staticmethod - cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil: + cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) noexcept nogil: gold = _gold cdef int next_act = gold.ner[s.B(1)].move if s.B(1) >= 0 else OUT cdef int g_act = gold.ner[s.B(0)].move @@ -510,7 +510,7 @@ cdef class In: cdef class Last: @staticmethod - cdef bint is_valid(const StateC* st, attr_t label) nogil: + cdef bint is_valid(const StateC* st, attr_t label) noexcept nogil: cdef int preset_ent_iob = st.B_(0).ent_iob cdef attr_t preset_ent_label = st.B_(0).ent_type if label == 0: @@ -535,13 +535,13 @@ cdef class Last: return True @staticmethod - cdef int transition(StateC* st, attr_t label) nogil: + cdef int transition(StateC* st, attr_t label) noexcept nogil: st.close_ent() st.push() st.pop() @staticmethod - cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) 
nogil: + cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) noexcept nogil: gold = _gold b0 = s.B(0) ent_start = s.E(0) @@ -581,7 +581,7 @@ cdef class Last: cdef class Unit: @staticmethod - cdef bint is_valid(const StateC* st, attr_t label) nogil: + cdef bint is_valid(const StateC* st, attr_t label) noexcept nogil: cdef int preset_ent_iob = st.B_(0).ent_iob cdef attr_t preset_ent_label = st.B_(0).ent_type if label == 0: @@ -609,14 +609,14 @@ cdef class Unit: return True @staticmethod - cdef int transition(StateC* st, attr_t label) nogil: + cdef int transition(StateC* st, attr_t label) noexcept nogil: st.open_ent(label) st.close_ent() st.push() st.pop() @staticmethod - cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil: + cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) noexcept nogil: gold = _gold cdef int g_act = gold.ner[s.B(0)].move cdef attr_t g_tag = gold.ner[s.B(0)].label @@ -646,7 +646,7 @@ cdef class Unit: cdef class Out: @staticmethod - cdef bint is_valid(const StateC* st, attr_t label) nogil: + cdef bint is_valid(const StateC* st, attr_t label) noexcept nogil: cdef int preset_ent_iob = st.B_(0).ent_iob if st.entity_is_open(): return False @@ -658,12 +658,12 @@ cdef class Out: return True @staticmethod - cdef int transition(StateC* st, attr_t label) nogil: + cdef int transition(StateC* st, attr_t label) noexcept nogil: st.push() st.pop() @staticmethod - cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil: + cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) noexcept nogil: gold = _gold cdef int g_act = gold.ner[s.B(0)].move cdef weight_t cost = 0 diff --git a/spacy/pipeline/_parser_internals/nonproj.pyx b/spacy/pipeline/_parser_internals/nonproj.pyx index 9e3a21b81..016b8b487 100644 --- a/spacy/pipeline/_parser_internals/nonproj.pyx +++ b/spacy/pipeline/_parser_internals/nonproj.pyx @@ -94,7 +94,7 @@ cdef bool _has_head_as_ancestor(int tokenid, int head, const vector[int]& heads) return False -cdef string heads_to_string(const vector[int]& heads) nogil: +cdef string heads_to_string(const vector[int]& heads) noexcept nogil: cdef vector[int].const_iterator citer cdef string cycle_str diff --git a/spacy/pipeline/_parser_internals/transition_system.pxd b/spacy/pipeline/_parser_internals/transition_system.pxd index 04cd10d88..74fd48961 100644 --- a/spacy/pipeline/_parser_internals/transition_system.pxd +++ b/spacy/pipeline/_parser_internals/transition_system.pxd @@ -15,22 +15,22 @@ cdef struct Transition: weight_t score - bint (*is_valid)(const StateC* state, attr_t label) nogil - weight_t (*get_cost)(const StateC* state, const void* gold, attr_t label) nogil - int (*do)(StateC* state, attr_t label) nogil + bint (*is_valid)(const StateC* state, attr_t label) noexcept nogil + weight_t (*get_cost)(const StateC* state, const void* gold, attr_t label) noexcept nogil + int (*do)(StateC* state, attr_t label) noexcept nogil ctypedef weight_t (*get_cost_func_t)( const StateC* state, const void* gold, attr_tlabel -) nogil +) noexcept nogil ctypedef weight_t (*move_cost_func_t)( const StateC* state, const void* gold -) nogil +) noexcept nogil ctypedef weight_t (*label_cost_func_t)( const StateC* state, const void* gold, attr_t label -) nogil +) noexcept nogil -ctypedef int (*do_func_t)(StateC* state, attr_t label) nogil +ctypedef int (*do_func_t)(StateC* state, attr_t label) noexcept nogil ctypedef void* (*init_state_t)(Pool mem, int length, void* tokens) except NULL @@ -53,7 +53,7 @@ cdef 
class TransitionSystem: cdef Transition init_transition(self, int clas, int move, attr_t label) except * - cdef int set_valid(self, int* output, const StateC* st) nogil + cdef int set_valid(self, int* output, const StateC* st) noexcept nogil cdef int set_costs(self, int* is_valid, weight_t* costs, const StateC* state, gold) except -1 diff --git a/spacy/pipeline/_parser_internals/transition_system.pyx b/spacy/pipeline/_parser_internals/transition_system.pyx index e035053b3..c859135d9 100644 --- a/spacy/pipeline/_parser_internals/transition_system.pyx +++ b/spacy/pipeline/_parser_internals/transition_system.pyx @@ -149,7 +149,7 @@ cdef class TransitionSystem: action = self.lookup_transition(move_name) return action.is_valid(stcls.c, action.label) - cdef int set_valid(self, int* is_valid, const StateC* st) nogil: + cdef int set_valid(self, int* is_valid, const StateC* st) noexcept nogil: cdef int i for i in range(self.n_moves): is_valid[i] = self.c[i].is_valid(st, self.c[i].label) @@ -191,8 +191,7 @@ cdef class TransitionSystem: def add_action(self, int action, label_name): cdef attr_t label_id - if not isinstance(label_name, int) and \ - not isinstance(label_name, long): + if not isinstance(label_name, int): label_id = self.strings.add(label_name) else: label_id = label_name diff --git a/spacy/pipeline/attributeruler.py b/spacy/pipeline/attributeruler.py index 8ac74d92b..cc1e2e37a 100644 --- a/spacy/pipeline/attributeruler.py +++ b/spacy/pipeline/attributeruler.py @@ -1,3 +1,5 @@ +import importlib +import sys from pathlib import Path from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union @@ -22,19 +24,6 @@ TagMapType = Dict[str, Dict[Union[int, str], Union[int, str]]] MorphRulesType = Dict[str, Dict[str, Dict[Union[int, str], Union[int, str]]]] -@Language.factory( - "attribute_ruler", - default_config={ - "validate": False, - "scorer": {"@scorers": "spacy.attribute_ruler_scorer.v1"}, - }, -) -def make_attribute_ruler( - nlp: Language, name: str, validate: bool, scorer: Optional[Callable] -): - return AttributeRuler(nlp.vocab, name, validate=validate, scorer=scorer) - - def attribute_ruler_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: def morph_key_getter(token, attr): return getattr(token, attr).key @@ -54,7 +43,6 @@ def attribute_ruler_score(examples: Iterable[Example], **kwargs) -> Dict[str, An return results -@registry.scorers("spacy.attribute_ruler_scorer.v1") def make_attribute_ruler_scorer(): return attribute_ruler_score @@ -355,3 +343,11 @@ def _split_morph_attrs(attrs: dict) -> Tuple[dict, dict]: else: morph_attrs[k] = v return other_attrs, morph_attrs + + +# Setup backwards compatibility hook for factories +def __getattr__(name): + if name == "make_attribute_ruler": + module = importlib.import_module("spacy.pipeline.factories") + return module.make_attribute_ruler + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/spacy/pipeline/dep_parser.pyx b/spacy/pipeline/dep_parser.pyx index 18a220bd6..881ec2dc4 100644 --- a/spacy/pipeline/dep_parser.pyx +++ b/spacy/pipeline/dep_parser.pyx @@ -1,4 +1,6 @@ # cython: infer_types=True, binding=True +import importlib +import sys from collections import defaultdict from typing import Callable, Optional @@ -39,188 +41,6 @@ subword_features = true DEFAULT_PARSER_MODEL = Config().from_str(default_model_config)["model"] -@Language.factory( - "parser", - assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"], - default_config={ - "moves": None, - 
"update_with_oracle_cut_size": 100, - "learn_tokens": False, - "min_action_freq": 30, - "model": DEFAULT_PARSER_MODEL, - "scorer": {"@scorers": "spacy.parser_scorer.v1"}, - }, - default_score_weights={ - "dep_uas": 0.5, - "dep_las": 0.5, - "dep_las_per_type": None, - "sents_p": None, - "sents_r": None, - "sents_f": 0.0, - }, -) -def make_parser( - nlp: Language, - name: str, - model: Model, - moves: Optional[TransitionSystem], - update_with_oracle_cut_size: int, - learn_tokens: bool, - min_action_freq: int, - scorer: Optional[Callable], -): - """Create a transition-based DependencyParser component. The dependency parser - jointly learns sentence segmentation and labelled dependency parsing, and can - optionally learn to merge tokens that had been over-segmented by the tokenizer. - - The parser uses a variant of the non-monotonic arc-eager transition-system - described by Honnibal and Johnson (2014), with the addition of a "break" - transition to perform the sentence segmentation. Nivre's pseudo-projective - dependency transformation is used to allow the parser to predict - non-projective parses. - - The parser is trained using an imitation learning objective. The parser follows - the actions predicted by the current weights, and at each state, determines - which actions are compatible with the optimal parse that could be reached - from the current state. The weights such that the scores assigned to the - set of optimal actions is increased, while scores assigned to other - actions are decreased. Note that more than one action may be optimal for - a given state. - - model (Model): The model for the transition-based parser. The model needs - to have a specific substructure of named components --- see the - spacy.ml.tb_framework.TransitionModel for details. - moves (Optional[TransitionSystem]): This defines how the parse-state is created, - updated and evaluated. If 'moves' is None, a new instance is - created with `self.TransitionSystem()`. Defaults to `None`. - update_with_oracle_cut_size (int): During training, cut long sequences into - shorter segments by creating intermediate states based on the gold-standard - history. The model is not very sensitive to this parameter, so you usually - won't need to change it. 100 is a good default. - learn_tokens (bool): Whether to learn to merge subtokens that are split - relative to the gold standard. Experimental. - min_action_freq (int): The minimum frequency of labelled actions to retain. - Rarer labelled actions have their label backed-off to "dep". While this - primarily affects the label accuracy, it can also affect the attachment - structure, as the labels are used to represent the pseudo-projectivity - transformation. - scorer (Optional[Callable]): The scoring method. - """ - return DependencyParser( - nlp.vocab, - model, - name, - moves=moves, - update_with_oracle_cut_size=update_with_oracle_cut_size, - multitasks=[], - learn_tokens=learn_tokens, - min_action_freq=min_action_freq, - beam_width=1, - beam_density=0.0, - beam_update_prob=0.0, - # At some point in the future we can try to implement support for - # partial annotations, perhaps only in the beam objective. 
- incorrect_spans_key=None, - scorer=scorer, - ) - - -@Language.factory( - "beam_parser", - assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"], - default_config={ - "beam_width": 8, - "beam_density": 0.01, - "beam_update_prob": 0.5, - "moves": None, - "update_with_oracle_cut_size": 100, - "learn_tokens": False, - "min_action_freq": 30, - "model": DEFAULT_PARSER_MODEL, - "scorer": {"@scorers": "spacy.parser_scorer.v1"}, - }, - default_score_weights={ - "dep_uas": 0.5, - "dep_las": 0.5, - "dep_las_per_type": None, - "sents_p": None, - "sents_r": None, - "sents_f": 0.0, - }, -) -def make_beam_parser( - nlp: Language, - name: str, - model: Model, - moves: Optional[TransitionSystem], - update_with_oracle_cut_size: int, - learn_tokens: bool, - min_action_freq: int, - beam_width: int, - beam_density: float, - beam_update_prob: float, - scorer: Optional[Callable], -): - """Create a transition-based DependencyParser component that uses beam-search. - The dependency parser jointly learns sentence segmentation and labelled - dependency parsing, and can optionally learn to merge tokens that had been - over-segmented by the tokenizer. - - The parser uses a variant of the non-monotonic arc-eager transition-system - described by Honnibal and Johnson (2014), with the addition of a "break" - transition to perform the sentence segmentation. Nivre's pseudo-projective - dependency transformation is used to allow the parser to predict - non-projective parses. - - The parser is trained using a global objective. That is, it learns to assign - probabilities to whole parses. - - model (Model): The model for the transition-based parser. The model needs - to have a specific substructure of named components --- see the - spacy.ml.tb_framework.TransitionModel for details. - moves (Optional[TransitionSystem]): This defines how the parse-state is created, - updated and evaluated. If 'moves' is None, a new instance is - created with `self.TransitionSystem()`. Defaults to `None`. - update_with_oracle_cut_size (int): During training, cut long sequences into - shorter segments by creating intermediate states based on the gold-standard - history. The model is not very sensitive to this parameter, so you usually - won't need to change it. 100 is a good default. - beam_width (int): The number of candidate analyses to maintain. - beam_density (float): The minimum ratio between the scores of the first and - last candidates in the beam. This allows the parser to avoid exploring - candidates that are too far behind. This is mostly intended to improve - efficiency, but it can also improve accuracy as deeper search is not - always better. - beam_update_prob (float): The chance of making a beam update, instead of a - greedy update. Greedy updates are an approximation for the beam updates, - and are faster to compute. - learn_tokens (bool): Whether to learn to merge subtokens that are split - relative to the gold standard. Experimental. - min_action_freq (int): The minimum frequency of labelled actions to retain. - Rarer labelled actions have their label backed-off to "dep". While this - primarily affects the label accuracy, it can also affect the attachment - structure, as the labels are used to represent the pseudo-projectivity - transformation. 
- """ - return DependencyParser( - nlp.vocab, - model, - name, - moves=moves, - update_with_oracle_cut_size=update_with_oracle_cut_size, - beam_width=beam_width, - beam_density=beam_density, - beam_update_prob=beam_update_prob, - multitasks=[], - learn_tokens=learn_tokens, - min_action_freq=min_action_freq, - # At some point in the future we can try to implement support for - # partial annotations, perhaps only in the beam objective. - incorrect_spans_key=None, - scorer=scorer, - ) - - def parser_score(examples, **kwargs): """Score a batch of examples. @@ -246,7 +66,6 @@ def parser_score(examples, **kwargs): return results -@registry.scorers("spacy.parser_scorer.v1") def make_parser_scorer(): return parser_score @@ -346,3 +165,14 @@ cdef class DependencyParser(Parser): # because we instead have a label frequency cut-off and back off rare # labels to 'dep'. pass + + +# Setup backwards compatibility hook for factories +def __getattr__(name): + if name == "make_parser": + module = importlib.import_module("spacy.pipeline.factories") + return module.make_parser + elif name == "make_beam_parser": + module = importlib.import_module("spacy.pipeline.factories") + return module.make_beam_parser + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/spacy/pipeline/edit_tree_lemmatizer.py b/spacy/pipeline/edit_tree_lemmatizer.py index 4a6174bc3..6029ed313 100644 --- a/spacy/pipeline/edit_tree_lemmatizer.py +++ b/spacy/pipeline/edit_tree_lemmatizer.py @@ -1,3 +1,5 @@ +import importlib +import sys from collections import Counter from itertools import islice from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, cast @@ -39,43 +41,6 @@ subword_features = true DEFAULT_EDIT_TREE_LEMMATIZER_MODEL = Config().from_str(default_model_config)["model"] -@Language.factory( - "trainable_lemmatizer", - assigns=["token.lemma"], - requires=[], - default_config={ - "model": DEFAULT_EDIT_TREE_LEMMATIZER_MODEL, - "backoff": "orth", - "min_tree_freq": 3, - "overwrite": False, - "top_k": 1, - "scorer": {"@scorers": "spacy.lemmatizer_scorer.v1"}, - }, - default_score_weights={"lemma_acc": 1.0}, -) -def make_edit_tree_lemmatizer( - nlp: Language, - name: str, - model: Model, - backoff: Optional[str], - min_tree_freq: int, - overwrite: bool, - top_k: int, - scorer: Optional[Callable], -): - """Construct an EditTreeLemmatizer component.""" - return EditTreeLemmatizer( - nlp.vocab, - model, - name, - backoff=backoff, - min_tree_freq=min_tree_freq, - overwrite=overwrite, - top_k=top_k, - scorer=scorer, - ) - - class EditTreeLemmatizer(TrainablePipe): """ Lemmatizer that lemmatizes each word using a predicted edit tree. 
@@ -421,3 +386,11 @@ class EditTreeLemmatizer(TrainablePipe): self.tree2label[tree_id] = len(self.cfg["labels"]) self.cfg["labels"].append(tree_id) return self.tree2label[tree_id] + + +# Setup backwards compatibility hook for factories +def __getattr__(name): + if name == "make_edit_tree_lemmatizer": + module = importlib.import_module("spacy.pipeline.factories") + return module.make_edit_tree_lemmatizer + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 40a9c8a79..6a1ed11df 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -1,4 +1,6 @@ +import importlib import random +import sys from itertools import islice from pathlib import Path from typing import Any, Callable, Dict, Iterable, List, Optional, Union @@ -40,117 +42,10 @@ subword_features = true DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"] -@Language.factory( - "entity_linker", - requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"], - assigns=["token.ent_kb_id"], - default_config={ - "model": DEFAULT_NEL_MODEL, - "labels_discard": [], - "n_sents": 0, - "incl_prior": True, - "incl_context": True, - "entity_vector_length": 64, - "get_candidates": {"@misc": "spacy.CandidateGenerator.v1"}, - "get_candidates_batch": {"@misc": "spacy.CandidateBatchGenerator.v1"}, - "generate_empty_kb": {"@misc": "spacy.EmptyKB.v2"}, - "overwrite": True, - "scorer": {"@scorers": "spacy.entity_linker_scorer.v1"}, - "use_gold_ents": True, - "candidates_batch_size": 1, - "threshold": None, - }, - default_score_weights={ - "nel_micro_f": 1.0, - "nel_micro_r": None, - "nel_micro_p": None, - }, -) -def make_entity_linker( - nlp: Language, - name: str, - model: Model, - *, - labels_discard: Iterable[str], - n_sents: int, - incl_prior: bool, - incl_context: bool, - entity_vector_length: int, - get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], - get_candidates_batch: Callable[ - [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] - ], - generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], - overwrite: bool, - scorer: Optional[Callable], - use_gold_ents: bool, - candidates_batch_size: int, - threshold: Optional[float] = None, -): - """Construct an EntityLinker component. - - model (Model[List[Doc], Floats2d]): A model that learns document vector - representations. Given a batch of Doc objects, it should return a single - array, with one row per item in the batch. - labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction. - n_sents (int): The number of neighbouring sentences to take into account. - incl_prior (bool): Whether or not to include prior probabilities from the KB in the model. - incl_context (bool): Whether or not to include the local context in the model. - entity_vector_length (int): Size of encoding vectors in the KB. - get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that - produces a list of candidates, given a certain knowledge base and a textual mention. - get_candidates_batch ( - Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], Iterable[Candidate]] - ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. - generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase. - scorer (Optional[Callable]): The scoring method. 
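# The entity_linker factory being removed here (and re-registered from
# spacy/pipeline/factories.py) dispatches on a capability flag stored in the
# thinc model's attrs, as the lines below show: architectures without the
# newer span-maker substructure fall back to EntityLinker_v1, which does not
# accept use_gold_ents or threshold. A minimal sketch of the same attrs-based
# dispatch, using the flag name from this diff:
from thinc.api import Model

def noop_forward(model, X, is_train):
    return X, lambda dY: dY

def choose_linker_class(model: Model) -> str:
    if not model.attrs.get("include_span_maker", False):
        return "EntityLinker_v1"  # legacy architecture: reduced argument set
    return "EntityLinker"

legacy = Model("legacy-nel", noop_forward)
modern = Model("modern-nel", noop_forward, attrs={"include_span_maker": True})
assert choose_linker_class(legacy) == "EntityLinker_v1"
assert choose_linker_class(modern) == "EntityLinker"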
- use_gold_ents (bool): Whether to copy entities from gold docs during training or not. If false, another - component must provide entity annotations. - candidates_batch_size (int): Size of batches for entity candidate generation. - threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold, - prediction is discarded. If None, predictions are not filtered by any threshold. - """ - - if not model.attrs.get("include_span_maker", False): - # The only difference in arguments here is that use_gold_ents and threshold aren't available. - return EntityLinker_v1( - nlp.vocab, - model, - name, - labels_discard=labels_discard, - n_sents=n_sents, - incl_prior=incl_prior, - incl_context=incl_context, - entity_vector_length=entity_vector_length, - get_candidates=get_candidates, - overwrite=overwrite, - scorer=scorer, - ) - return EntityLinker( - nlp.vocab, - model, - name, - labels_discard=labels_discard, - n_sents=n_sents, - incl_prior=incl_prior, - incl_context=incl_context, - entity_vector_length=entity_vector_length, - get_candidates=get_candidates, - get_candidates_batch=get_candidates_batch, - generate_empty_kb=generate_empty_kb, - overwrite=overwrite, - scorer=scorer, - use_gold_ents=use_gold_ents, - candidates_batch_size=candidates_batch_size, - threshold=threshold, - ) - - def entity_linker_score(examples, **kwargs): return Scorer.score_links(examples, negative_labels=[EntityLinker.NIL], **kwargs) -@registry.scorers("spacy.entity_linker_scorer.v1") def make_entity_linker_scorer(): return entity_linker_score @@ -676,3 +571,11 @@ class EntityLinker(TrainablePipe): def add_label(self, label): raise NotImplementedError + + +# Setup backwards compatibility hook for factories +def __getattr__(name): + if name == "make_entity_linker": + module = importlib.import_module("spacy.pipeline.factories") + return module.make_entity_linker + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/spacy/pipeline/entityruler.py b/spacy/pipeline/entityruler.py index 3683cfc02..2b8c98307 100644 --- a/spacy/pipeline/entityruler.py +++ b/spacy/pipeline/entityruler.py @@ -1,3 +1,5 @@ +import importlib +import sys import warnings from collections import defaultdict from pathlib import Path @@ -19,51 +21,10 @@ DEFAULT_ENT_ID_SEP = "||" PatternType = Dict[str, Union[str, List[Dict[str, Any]]]] -@Language.factory( - "entity_ruler", - assigns=["doc.ents", "token.ent_type", "token.ent_iob"], - default_config={ - "phrase_matcher_attr": None, - "matcher_fuzzy_compare": {"@misc": "spacy.levenshtein_compare.v1"}, - "validate": False, - "overwrite_ents": False, - "ent_id_sep": DEFAULT_ENT_ID_SEP, - "scorer": {"@scorers": "spacy.entity_ruler_scorer.v1"}, - }, - default_score_weights={ - "ents_f": 1.0, - "ents_p": 0.0, - "ents_r": 0.0, - "ents_per_type": None, - }, -) -def make_entity_ruler( - nlp: Language, - name: str, - phrase_matcher_attr: Optional[Union[int, str]], - matcher_fuzzy_compare: Callable, - validate: bool, - overwrite_ents: bool, - ent_id_sep: str, - scorer: Optional[Callable], -): - return EntityRuler( - nlp, - name, - phrase_matcher_attr=phrase_matcher_attr, - matcher_fuzzy_compare=matcher_fuzzy_compare, - validate=validate, - overwrite_ents=overwrite_ents, - ent_id_sep=ent_id_sep, - scorer=scorer, - ) - - def entity_ruler_score(examples, **kwargs): return get_ner_prf(examples) -@registry.scorers("spacy.entity_ruler_scorer.v1") def make_entity_ruler_scorer(): return entity_ruler_score @@ -539,3 +500,11 @@ class EntityRuler(Pipe): 
srsly.write_jsonl(path, self.patterns) else: to_disk(path, serializers, {}) + + +# Setup backwards compatibility hook for factories +def __getattr__(name): + if name == "make_entity_ruler": + module = importlib.import_module("spacy.pipeline.factories") + return module.make_entity_ruler + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/spacy/pipeline/factories.py b/spacy/pipeline/factories.py new file mode 100644 index 000000000..f796f2dc8 --- /dev/null +++ b/spacy/pipeline/factories.py @@ -0,0 +1,929 @@ +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +from thinc.api import Model +from thinc.types import Floats2d, Ragged + +from ..kb import Candidate, KnowledgeBase +from ..language import Language +from ..pipeline._parser_internals.transition_system import TransitionSystem +from ..pipeline.attributeruler import AttributeRuler +from ..pipeline.dep_parser import DEFAULT_PARSER_MODEL, DependencyParser +from ..pipeline.edit_tree_lemmatizer import ( + DEFAULT_EDIT_TREE_LEMMATIZER_MODEL, + EditTreeLemmatizer, +) + +# Import factory default configurations +from ..pipeline.entity_linker import DEFAULT_NEL_MODEL, EntityLinker, EntityLinker_v1 +from ..pipeline.entityruler import DEFAULT_ENT_ID_SEP, EntityRuler +from ..pipeline.functions import DocCleaner, TokenSplitter +from ..pipeline.lemmatizer import Lemmatizer +from ..pipeline.morphologizer import DEFAULT_MORPH_MODEL, Morphologizer +from ..pipeline.multitask import DEFAULT_MT_MODEL, MultitaskObjective +from ..pipeline.ner import DEFAULT_NER_MODEL, EntityRecognizer +from ..pipeline.sentencizer import Sentencizer +from ..pipeline.senter import DEFAULT_SENTER_MODEL, SentenceRecognizer +from ..pipeline.span_finder import DEFAULT_SPAN_FINDER_MODEL, SpanFinder +from ..pipeline.span_ruler import DEFAULT_SPANS_KEY as SPAN_RULER_DEFAULT_SPANS_KEY +from ..pipeline.span_ruler import ( + SpanRuler, + prioritize_existing_ents_filter, + prioritize_new_ents_filter, +) +from ..pipeline.spancat import ( + DEFAULT_SPANCAT_MODEL, + DEFAULT_SPANCAT_SINGLELABEL_MODEL, + DEFAULT_SPANS_KEY, + SpanCategorizer, + Suggester, +) +from ..pipeline.tagger import DEFAULT_TAGGER_MODEL, Tagger +from ..pipeline.textcat import DEFAULT_SINGLE_TEXTCAT_MODEL, TextCategorizer +from ..pipeline.textcat_multilabel import ( + DEFAULT_MULTI_TEXTCAT_MODEL, + MultiLabel_TextCategorizer, +) +from ..pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL, Tok2Vec +from ..tokens.doc import Doc +from ..tokens.span import Span +from ..vocab import Vocab + +# Global flag to track if factories have been registered +FACTORIES_REGISTERED = False + + +def register_factories() -> None: + """Register all factories with the registry. + + This function registers all pipeline component factories, centralizing + the registrations that were previously done with @Language.factory decorators. 
+ """ + global FACTORIES_REGISTERED + + if FACTORIES_REGISTERED: + return + + # Register factories using the same pattern as Language.factory decorator + # We use Language.factory()() pattern which exactly mimics the decorator + + # attributeruler + Language.factory( + "attribute_ruler", + default_config={ + "validate": False, + "scorer": {"@scorers": "spacy.attribute_ruler_scorer.v1"}, + }, + )(make_attribute_ruler) + + # entity_linker + Language.factory( + "entity_linker", + requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"], + assigns=["token.ent_kb_id"], + default_config={ + "model": DEFAULT_NEL_MODEL, + "labels_discard": [], + "n_sents": 0, + "incl_prior": True, + "incl_context": True, + "entity_vector_length": 64, + "get_candidates": {"@misc": "spacy.CandidateGenerator.v1"}, + "get_candidates_batch": {"@misc": "spacy.CandidateBatchGenerator.v1"}, + "generate_empty_kb": {"@misc": "spacy.EmptyKB.v2"}, + "overwrite": True, + "scorer": {"@scorers": "spacy.entity_linker_scorer.v1"}, + "use_gold_ents": True, + "candidates_batch_size": 1, + "threshold": None, + }, + default_score_weights={ + "nel_micro_f": 1.0, + "nel_micro_r": None, + "nel_micro_p": None, + }, + )(make_entity_linker) + + # entity_ruler + Language.factory( + "entity_ruler", + assigns=["doc.ents", "token.ent_type", "token.ent_iob"], + default_config={ + "phrase_matcher_attr": None, + "matcher_fuzzy_compare": {"@misc": "spacy.levenshtein_compare.v1"}, + "validate": False, + "overwrite_ents": False, + "ent_id_sep": DEFAULT_ENT_ID_SEP, + "scorer": {"@scorers": "spacy.entity_ruler_scorer.v1"}, + }, + default_score_weights={ + "ents_f": 1.0, + "ents_p": 0.0, + "ents_r": 0.0, + "ents_per_type": None, + }, + )(make_entity_ruler) + + # lemmatizer + Language.factory( + "lemmatizer", + assigns=["token.lemma"], + default_config={ + "model": None, + "mode": "lookup", + "overwrite": False, + "scorer": {"@scorers": "spacy.lemmatizer_scorer.v1"}, + }, + default_score_weights={"lemma_acc": 1.0}, + )(make_lemmatizer) + + # textcat + Language.factory( + "textcat", + assigns=["doc.cats"], + default_config={ + "threshold": 0.0, + "model": DEFAULT_SINGLE_TEXTCAT_MODEL, + "scorer": {"@scorers": "spacy.textcat_scorer.v2"}, + }, + default_score_weights={ + "cats_score": 1.0, + "cats_score_desc": None, + "cats_micro_p": None, + "cats_micro_r": None, + "cats_micro_f": None, + "cats_macro_p": None, + "cats_macro_r": None, + "cats_macro_f": None, + "cats_macro_auc": None, + "cats_f_per_type": None, + }, + )(make_textcat) + + # token_splitter + Language.factory( + "token_splitter", + default_config={"min_length": 25, "split_length": 10}, + retokenizes=True, + )(make_token_splitter) + + # doc_cleaner + Language.factory( + "doc_cleaner", + default_config={"attrs": {"tensor": None, "_.trf_data": None}, "silent": True}, + )(make_doc_cleaner) + + # tok2vec + Language.factory( + "tok2vec", + assigns=["doc.tensor"], + default_config={"model": DEFAULT_TOK2VEC_MODEL}, + )(make_tok2vec) + + # senter + Language.factory( + "senter", + assigns=["token.is_sent_start"], + default_config={ + "model": DEFAULT_SENTER_MODEL, + "overwrite": False, + "scorer": {"@scorers": "spacy.senter_scorer.v1"}, + }, + default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0}, + )(make_senter) + + # morphologizer + Language.factory( + "morphologizer", + assigns=["token.morph", "token.pos"], + default_config={ + "model": DEFAULT_MORPH_MODEL, + "overwrite": True, + "extend": False, + "scorer": {"@scorers": "spacy.morphologizer_scorer.v1"}, + 
"label_smoothing": 0.0, + }, + default_score_weights={ + "pos_acc": 0.5, + "morph_acc": 0.5, + "morph_per_feat": None, + }, + )(make_morphologizer) + + # spancat + Language.factory( + "spancat", + assigns=["doc.spans"], + default_config={ + "threshold": 0.5, + "spans_key": DEFAULT_SPANS_KEY, + "max_positive": None, + "model": DEFAULT_SPANCAT_MODEL, + "suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]}, + "scorer": {"@scorers": "spacy.spancat_scorer.v1"}, + }, + default_score_weights={"spans_sc_f": 1.0, "spans_sc_p": 0.0, "spans_sc_r": 0.0}, + )(make_spancat) + + # spancat_singlelabel + Language.factory( + "spancat_singlelabel", + assigns=["doc.spans"], + default_config={ + "spans_key": DEFAULT_SPANS_KEY, + "model": DEFAULT_SPANCAT_SINGLELABEL_MODEL, + "negative_weight": 1.0, + "suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]}, + "scorer": {"@scorers": "spacy.spancat_scorer.v1"}, + "allow_overlap": True, + }, + default_score_weights={"spans_sc_f": 1.0, "spans_sc_p": 0.0, "spans_sc_r": 0.0}, + )(make_spancat_singlelabel) + + # future_entity_ruler + Language.factory( + "future_entity_ruler", + assigns=["doc.ents"], + default_config={ + "phrase_matcher_attr": None, + "validate": False, + "overwrite_ents": False, + "scorer": {"@scorers": "spacy.entity_ruler_scorer.v1"}, + "ent_id_sep": "__unused__", + "matcher_fuzzy_compare": {"@misc": "spacy.levenshtein_compare.v1"}, + }, + default_score_weights={ + "ents_f": 1.0, + "ents_p": 0.0, + "ents_r": 0.0, + "ents_per_type": None, + }, + )(make_future_entity_ruler) + + # span_ruler + Language.factory( + "span_ruler", + assigns=["doc.spans"], + default_config={ + "spans_key": SPAN_RULER_DEFAULT_SPANS_KEY, + "spans_filter": None, + "annotate_ents": False, + "ents_filter": {"@misc": "spacy.first_longest_spans_filter.v1"}, + "phrase_matcher_attr": None, + "matcher_fuzzy_compare": {"@misc": "spacy.levenshtein_compare.v1"}, + "validate": False, + "overwrite": True, + "scorer": { + "@scorers": "spacy.overlapping_labeled_spans_scorer.v1", + "spans_key": SPAN_RULER_DEFAULT_SPANS_KEY, + }, + }, + default_score_weights={ + f"spans_{SPAN_RULER_DEFAULT_SPANS_KEY}_f": 1.0, + f"spans_{SPAN_RULER_DEFAULT_SPANS_KEY}_p": 0.0, + f"spans_{SPAN_RULER_DEFAULT_SPANS_KEY}_r": 0.0, + f"spans_{SPAN_RULER_DEFAULT_SPANS_KEY}_per_type": None, + }, + )(make_span_ruler) + + # trainable_lemmatizer + Language.factory( + "trainable_lemmatizer", + assigns=["token.lemma"], + requires=[], + default_config={ + "model": DEFAULT_EDIT_TREE_LEMMATIZER_MODEL, + "backoff": "orth", + "min_tree_freq": 3, + "overwrite": False, + "top_k": 1, + "scorer": {"@scorers": "spacy.lemmatizer_scorer.v1"}, + }, + default_score_weights={"lemma_acc": 1.0}, + )(make_edit_tree_lemmatizer) + + # textcat_multilabel + Language.factory( + "textcat_multilabel", + assigns=["doc.cats"], + default_config={ + "threshold": 0.5, + "model": DEFAULT_MULTI_TEXTCAT_MODEL, + "scorer": {"@scorers": "spacy.textcat_multilabel_scorer.v2"}, + }, + default_score_weights={ + "cats_score": 1.0, + "cats_score_desc": None, + "cats_micro_p": None, + "cats_micro_r": None, + "cats_micro_f": None, + "cats_macro_p": None, + "cats_macro_r": None, + "cats_macro_f": None, + "cats_macro_auc": None, + "cats_f_per_type": None, + }, + )(make_multilabel_textcat) + + # span_finder + Language.factory( + "span_finder", + assigns=["doc.spans"], + default_config={ + "threshold": 0.5, + "model": DEFAULT_SPAN_FINDER_MODEL, + "spans_key": DEFAULT_SPANS_KEY, + "max_length": 25, + "min_length": None, + "scorer": 
{"@scorers": "spacy.span_finder_scorer.v1"}, + }, + default_score_weights={ + f"spans_{DEFAULT_SPANS_KEY}_f": 1.0, + f"spans_{DEFAULT_SPANS_KEY}_p": 0.0, + f"spans_{DEFAULT_SPANS_KEY}_r": 0.0, + }, + )(make_span_finder) + + # ner + Language.factory( + "ner", + assigns=["doc.ents", "token.ent_iob", "token.ent_type"], + default_config={ + "moves": None, + "update_with_oracle_cut_size": 100, + "model": DEFAULT_NER_MODEL, + "incorrect_spans_key": None, + "scorer": {"@scorers": "spacy.ner_scorer.v1"}, + }, + default_score_weights={ + "ents_f": 1.0, + "ents_p": 0.0, + "ents_r": 0.0, + "ents_per_type": None, + }, + )(make_ner) + + # beam_ner + Language.factory( + "beam_ner", + assigns=["doc.ents", "token.ent_iob", "token.ent_type"], + default_config={ + "moves": None, + "update_with_oracle_cut_size": 100, + "model": DEFAULT_NER_MODEL, + "beam_density": 0.01, + "beam_update_prob": 0.5, + "beam_width": 32, + "incorrect_spans_key": None, + "scorer": {"@scorers": "spacy.ner_scorer.v1"}, + }, + default_score_weights={ + "ents_f": 1.0, + "ents_p": 0.0, + "ents_r": 0.0, + "ents_per_type": None, + }, + )(make_beam_ner) + + # parser + Language.factory( + "parser", + assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"], + default_config={ + "moves": None, + "update_with_oracle_cut_size": 100, + "learn_tokens": False, + "min_action_freq": 30, + "model": DEFAULT_PARSER_MODEL, + "scorer": {"@scorers": "spacy.parser_scorer.v1"}, + }, + default_score_weights={ + "dep_uas": 0.5, + "dep_las": 0.5, + "dep_las_per_type": None, + "sents_p": None, + "sents_r": None, + "sents_f": 0.0, + }, + )(make_parser) + + # beam_parser + Language.factory( + "beam_parser", + assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"], + default_config={ + "moves": None, + "update_with_oracle_cut_size": 100, + "learn_tokens": False, + "min_action_freq": 30, + "beam_width": 8, + "beam_density": 0.0001, + "beam_update_prob": 0.5, + "model": DEFAULT_PARSER_MODEL, + "scorer": {"@scorers": "spacy.parser_scorer.v1"}, + }, + default_score_weights={ + "dep_uas": 0.5, + "dep_las": 0.5, + "dep_las_per_type": None, + "sents_p": None, + "sents_r": None, + "sents_f": 0.0, + }, + )(make_beam_parser) + + # tagger + Language.factory( + "tagger", + assigns=["token.tag"], + default_config={ + "model": DEFAULT_TAGGER_MODEL, + "overwrite": False, + "scorer": {"@scorers": "spacy.tagger_scorer.v1"}, + "neg_prefix": "!", + "label_smoothing": 0.0, + }, + default_score_weights={ + "tag_acc": 1.0, + "pos_acc": 0.0, + "tag_micro_p": None, + "tag_micro_r": None, + "tag_micro_f": None, + }, + )(make_tagger) + + # nn_labeller + Language.factory( + "nn_labeller", + default_config={ + "labels": None, + "target": "dep_tag_offset", + "model": DEFAULT_MT_MODEL, + }, + )(make_nn_labeller) + + # sentencizer + Language.factory( + "sentencizer", + assigns=["token.is_sent_start", "doc.sents"], + default_config={ + "punct_chars": None, + "overwrite": False, + "scorer": {"@scorers": "spacy.senter_scorer.v1"}, + }, + default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0}, + )(make_sentencizer) + + # Set the flag to indicate that all factories have been registered + FACTORIES_REGISTERED = True + + +# We can't have function implementations for these factories in Cython, because +# we need to build a Pydantic model for them dynamically, reading their argument +# structure from the signature. 
In Cython 3, this doesn't work because the +# from __future__ import annotations semantics are used, which means the types +# are stored as strings. +def make_sentencizer( + nlp: Language, + name: str, + punct_chars: Optional[List[str]], + overwrite: bool, + scorer: Optional[Callable], +): + return Sentencizer( + name, punct_chars=punct_chars, overwrite=overwrite, scorer=scorer + ) + + +def make_attribute_ruler( + nlp: Language, name: str, validate: bool, scorer: Optional[Callable] +): + return AttributeRuler(nlp.vocab, name, validate=validate, scorer=scorer) + + +def make_entity_linker( + nlp: Language, + name: str, + model: Model, + *, + labels_discard: Iterable[str], + n_sents: int, + incl_prior: bool, + incl_context: bool, + entity_vector_length: int, + get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], + get_candidates_batch: Callable[ + [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] + ], + generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], + overwrite: bool, + scorer: Optional[Callable], + use_gold_ents: bool, + candidates_batch_size: int, + threshold: Optional[float] = None, +): + + if not model.attrs.get("include_span_maker", False): + # The only difference in arguments here is that use_gold_ents and threshold aren't available. + return EntityLinker_v1( + nlp.vocab, + model, + name, + labels_discard=labels_discard, + n_sents=n_sents, + incl_prior=incl_prior, + incl_context=incl_context, + entity_vector_length=entity_vector_length, + get_candidates=get_candidates, + overwrite=overwrite, + scorer=scorer, + ) + return EntityLinker( + nlp.vocab, + model, + name, + labels_discard=labels_discard, + n_sents=n_sents, + incl_prior=incl_prior, + incl_context=incl_context, + entity_vector_length=entity_vector_length, + get_candidates=get_candidates, + get_candidates_batch=get_candidates_batch, + generate_empty_kb=generate_empty_kb, + overwrite=overwrite, + scorer=scorer, + use_gold_ents=use_gold_ents, + candidates_batch_size=candidates_batch_size, + threshold=threshold, + ) + + +def make_lemmatizer( + nlp: Language, + model: Optional[Model], + name: str, + mode: str, + overwrite: bool, + scorer: Optional[Callable], +): + return Lemmatizer( + nlp.vocab, model, name, mode=mode, overwrite=overwrite, scorer=scorer + ) + + +def make_textcat( + nlp: Language, + name: str, + model: Model[List[Doc], List[Floats2d]], + threshold: float, + scorer: Optional[Callable], +) -> TextCategorizer: + return TextCategorizer(nlp.vocab, model, name, threshold=threshold, scorer=scorer) + + +def make_token_splitter( + nlp: Language, name: str, *, min_length: int = 0, split_length: int = 0 +): + return TokenSplitter(min_length=min_length, split_length=split_length) + + +def make_doc_cleaner(nlp: Language, name: str, *, attrs: Dict[str, Any], silent: bool): + return DocCleaner(attrs, silent=silent) + + +def make_tok2vec(nlp: Language, name: str, model: Model) -> Tok2Vec: + return Tok2Vec(nlp.vocab, model, name) + + +def make_spancat( + nlp: Language, + name: str, + suggester: Suggester, + model: Model[Tuple[List[Doc], Ragged], Floats2d], + spans_key: str, + scorer: Optional[Callable], + threshold: float, + max_positive: Optional[int], +) -> SpanCategorizer: + return SpanCategorizer( + nlp.vocab, + model=model, + suggester=suggester, + name=name, + spans_key=spans_key, + negative_weight=None, + allow_overlap=True, + max_positive=max_positive, + threshold=threshold, + scorer=scorer, + add_negative_label=False, + ) + + +def make_spancat_singlelabel( + nlp: Language, + name: str, 
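# The comment above refers to postponed evaluation of annotations (PEP 563):
# under those semantics, __annotations__ holds strings rather than type
# objects, so tooling that builds a validation model from a function signature
# has to resolve every name first. A pure-Python illustration; the
# unresolvable "Language" name is chosen to mirror the factory signatures in
# this module:
from __future__ import annotations

import typing

def make_example(nlp: Language, name: str, top_k: int = 1):  # Language is undefined here
    pass

print(make_example.__annotations__)
# {'nlp': 'Language', 'name': 'str', 'top_k': 'int'} -- all plain strings
try:
    typing.get_type_hints(make_example)
except NameError as err:
    print(err)  # name 'Language' is not defined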
+ suggester: Suggester, + model: Model[Tuple[List[Doc], Ragged], Floats2d], + spans_key: str, + negative_weight: float, + allow_overlap: bool, + scorer: Optional[Callable], +) -> SpanCategorizer: + return SpanCategorizer( + nlp.vocab, + model=model, + suggester=suggester, + name=name, + spans_key=spans_key, + negative_weight=negative_weight, + allow_overlap=allow_overlap, + max_positive=1, + add_negative_label=True, + threshold=None, + scorer=scorer, + ) + + +def make_future_entity_ruler( + nlp: Language, + name: str, + phrase_matcher_attr: Optional[Union[int, str]], + matcher_fuzzy_compare: Callable, + validate: bool, + overwrite_ents: bool, + scorer: Optional[Callable], + ent_id_sep: str, +): + if overwrite_ents: + ents_filter = prioritize_new_ents_filter + else: + ents_filter = prioritize_existing_ents_filter + return SpanRuler( + nlp, + name, + spans_key=None, + spans_filter=None, + annotate_ents=True, + ents_filter=ents_filter, + phrase_matcher_attr=phrase_matcher_attr, + matcher_fuzzy_compare=matcher_fuzzy_compare, + validate=validate, + overwrite=False, + scorer=scorer, + ) + + +def make_entity_ruler( + nlp: Language, + name: str, + phrase_matcher_attr: Optional[Union[int, str]], + matcher_fuzzy_compare: Callable, + validate: bool, + overwrite_ents: bool, + ent_id_sep: str, + scorer: Optional[Callable], +): + return EntityRuler( + nlp, + name, + phrase_matcher_attr=phrase_matcher_attr, + matcher_fuzzy_compare=matcher_fuzzy_compare, + validate=validate, + overwrite_ents=overwrite_ents, + ent_id_sep=ent_id_sep, + scorer=scorer, + ) + + +def make_span_ruler( + nlp: Language, + name: str, + spans_key: Optional[str], + spans_filter: Optional[Callable[[Iterable[Span], Iterable[Span]], Iterable[Span]]], + annotate_ents: bool, + ents_filter: Callable[[Iterable[Span], Iterable[Span]], Iterable[Span]], + phrase_matcher_attr: Optional[Union[int, str]], + matcher_fuzzy_compare: Callable, + validate: bool, + overwrite: bool, + scorer: Optional[Callable], +): + return SpanRuler( + nlp, + name, + spans_key=spans_key, + spans_filter=spans_filter, + annotate_ents=annotate_ents, + ents_filter=ents_filter, + phrase_matcher_attr=phrase_matcher_attr, + matcher_fuzzy_compare=matcher_fuzzy_compare, + validate=validate, + overwrite=overwrite, + scorer=scorer, + ) + + +def make_edit_tree_lemmatizer( + nlp: Language, + name: str, + model: Model, + backoff: Optional[str], + min_tree_freq: int, + overwrite: bool, + top_k: int, + scorer: Optional[Callable], +): + return EditTreeLemmatizer( + nlp.vocab, + model, + name, + backoff=backoff, + min_tree_freq=min_tree_freq, + overwrite=overwrite, + top_k=top_k, + scorer=scorer, + ) + + +def make_multilabel_textcat( + nlp: Language, + name: str, + model: Model[List[Doc], List[Floats2d]], + threshold: float, + scorer: Optional[Callable], +) -> MultiLabel_TextCategorizer: + return MultiLabel_TextCategorizer( + nlp.vocab, model, name, threshold=threshold, scorer=scorer + ) + + +def make_span_finder( + nlp: Language, + name: str, + model: Model[Iterable[Doc], Floats2d], + spans_key: str, + threshold: float, + max_length: Optional[int], + min_length: Optional[int], + scorer: Optional[Callable], +) -> SpanFinder: + return SpanFinder( + nlp, + model=model, + threshold=threshold, + name=name, + scorer=scorer, + max_length=max_length, + min_length=min_length, + spans_key=spans_key, + ) + + +def make_ner( + nlp: Language, + name: str, + model: Model, + moves: Optional[TransitionSystem], + update_with_oracle_cut_size: int, + incorrect_spans_key: Optional[str], + scorer: 
Optional[Callable], +): + return EntityRecognizer( + nlp.vocab, + model, + name=name, + moves=moves, + update_with_oracle_cut_size=update_with_oracle_cut_size, + incorrect_spans_key=incorrect_spans_key, + scorer=scorer, + ) + + +def make_beam_ner( + nlp: Language, + name: str, + model: Model, + moves: Optional[TransitionSystem], + update_with_oracle_cut_size: int, + beam_width: int, + beam_density: float, + beam_update_prob: float, + incorrect_spans_key: Optional[str], + scorer: Optional[Callable], +): + return EntityRecognizer( + nlp.vocab, + model, + name=name, + moves=moves, + update_with_oracle_cut_size=update_with_oracle_cut_size, + beam_width=beam_width, + beam_density=beam_density, + beam_update_prob=beam_update_prob, + incorrect_spans_key=incorrect_spans_key, + scorer=scorer, + ) + + +def make_parser( + nlp: Language, + name: str, + model: Model, + moves: Optional[TransitionSystem], + update_with_oracle_cut_size: int, + learn_tokens: bool, + min_action_freq: int, + scorer: Optional[Callable], +): + return DependencyParser( + nlp.vocab, + model, + name=name, + moves=moves, + update_with_oracle_cut_size=update_with_oracle_cut_size, + learn_tokens=learn_tokens, + min_action_freq=min_action_freq, + scorer=scorer, + ) + + +def make_beam_parser( + nlp: Language, + name: str, + model: Model, + moves: Optional[TransitionSystem], + update_with_oracle_cut_size: int, + learn_tokens: bool, + min_action_freq: int, + beam_width: int, + beam_density: float, + beam_update_prob: float, + scorer: Optional[Callable], +): + return DependencyParser( + nlp.vocab, + model, + name=name, + moves=moves, + update_with_oracle_cut_size=update_with_oracle_cut_size, + learn_tokens=learn_tokens, + min_action_freq=min_action_freq, + beam_width=beam_width, + beam_density=beam_density, + beam_update_prob=beam_update_prob, + scorer=scorer, + ) + + +def make_tagger( + nlp: Language, + name: str, + model: Model, + overwrite: bool, + scorer: Optional[Callable], + neg_prefix: str, + label_smoothing: float, +): + return Tagger( + nlp.vocab, + model, + name=name, + overwrite=overwrite, + scorer=scorer, + neg_prefix=neg_prefix, + label_smoothing=label_smoothing, + ) + + +def make_nn_labeller( + nlp: Language, name: str, model: Model, labels: Optional[dict], target: str +): + return MultitaskObjective(nlp.vocab, model, name, target=target) + + +def make_morphologizer( + nlp: Language, + model: Model, + name: str, + overwrite: bool, + extend: bool, + label_smoothing: float, + scorer: Optional[Callable], +): + return Morphologizer( + nlp.vocab, + model, + name, + overwrite=overwrite, + extend=extend, + label_smoothing=label_smoothing, + scorer=scorer, + ) + + +def make_senter( + nlp: Language, name: str, model: Model, overwrite: bool, scorer: Optional[Callable] +): + return SentenceRecognizer( + nlp.vocab, model, name, overwrite=overwrite, scorer=scorer + ) diff --git a/spacy/pipeline/functions.py b/spacy/pipeline/functions.py index 2bf0437d5..e4a3d6d1d 100644 --- a/spacy/pipeline/functions.py +++ b/spacy/pipeline/functions.py @@ -1,3 +1,5 @@ +import importlib +import sys import warnings from typing import Any, Dict @@ -73,17 +75,6 @@ def merge_subtokens(doc: Doc, label: str = "subtok") -> Doc: return doc -@Language.factory( - "token_splitter", - default_config={"min_length": 25, "split_length": 10}, - retokenizes=True, -) -def make_token_splitter( - nlp: Language, name: str, *, min_length: int = 0, split_length: int = 0 -): - return TokenSplitter(min_length=min_length, split_length=split_length) - - class TokenSplitter: def 
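# register_factories() is guarded by the module-level FACTORIES_REGISTERED
# flag, so calls after the first are cheap no-ops and the function is safe to
# invoke from any code path that needs the built-in components. Once it has
# run, string factory names resolve exactly as before. A usage sketch,
# assuming the stable spaCy v3 API (model-free components only, so nothing
# has to be downloaded):
import spacy
from spacy.pipeline.factories import register_factories

register_factories()           # idempotent: a second call returns immediately
nlp = spacy.blank("en")
nlp.add_pipe("sentencizer")    # resolved via the registration made above
doc = nlp("Hello there. General Kenobi.")
print([sent.text for sent in doc.sents])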
__init__(self, min_length: int = 0, split_length: int = 0): self.min_length = min_length @@ -141,14 +132,6 @@ class TokenSplitter: util.from_disk(path, serializers, []) -@Language.factory( - "doc_cleaner", - default_config={"attrs": {"tensor": None, "_.trf_data": None}, "silent": True}, -) -def make_doc_cleaner(nlp: Language, name: str, *, attrs: Dict[str, Any], silent: bool): - return DocCleaner(attrs, silent=silent) - - class DocCleaner: def __init__(self, attrs: Dict[str, Any], *, silent: bool = True): self.cfg: Dict[str, Any] = {"attrs": dict(attrs), "silent": silent} @@ -201,3 +184,14 @@ class DocCleaner: "cfg": lambda p: self.cfg.update(srsly.read_json(p)), } util.from_disk(path, serializers, []) + + +# Setup backwards compatibility hook for factories +def __getattr__(name): + if name == "make_doc_cleaner": + module = importlib.import_module("spacy.pipeline.factories") + return module.make_doc_cleaner + elif name == "make_token_splitter": + module = importlib.import_module("spacy.pipeline.factories") + return module.make_token_splitter + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/spacy/pipeline/lemmatizer.py b/spacy/pipeline/lemmatizer.py index 09e501595..c08d59a3b 100644 --- a/spacy/pipeline/lemmatizer.py +++ b/spacy/pipeline/lemmatizer.py @@ -1,3 +1,5 @@ +import importlib +import sys import warnings from pathlib import Path from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union @@ -16,35 +18,10 @@ from ..vocab import Vocab from .pipe import Pipe -@Language.factory( - "lemmatizer", - assigns=["token.lemma"], - default_config={ - "model": None, - "mode": "lookup", - "overwrite": False, - "scorer": {"@scorers": "spacy.lemmatizer_scorer.v1"}, - }, - default_score_weights={"lemma_acc": 1.0}, -) -def make_lemmatizer( - nlp: Language, - model: Optional[Model], - name: str, - mode: str, - overwrite: bool, - scorer: Optional[Callable], -): - return Lemmatizer( - nlp.vocab, model, name, mode=mode, overwrite=overwrite, scorer=scorer - ) - - def lemmatizer_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: return Scorer.score_token_attr(examples, "lemma", **kwargs) -@registry.scorers("spacy.lemmatizer_scorer.v1") def make_lemmatizer_scorer(): return lemmatizer_score @@ -334,3 +311,11 @@ class Lemmatizer(Pipe): util.from_bytes(bytes_data, deserialize, exclude) self._validate_tables() return self + + +# Setup backwards compatibility hook for factories +def __getattr__(name): + if name == "make_lemmatizer": + module = importlib.import_module("spacy.pipeline.factories") + return module.make_lemmatizer + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/spacy/pipeline/morphologizer.pyx b/spacy/pipeline/morphologizer.pyx index d415ae43c..333f64d29 100644 --- a/spacy/pipeline/morphologizer.pyx +++ b/spacy/pipeline/morphologizer.pyx @@ -1,4 +1,6 @@ # cython: infer_types=True, binding=True +import importlib +import sys from itertools import islice from typing import Callable, Dict, Optional, Union @@ -47,25 +49,6 @@ maxout_pieces = 3 DEFAULT_MORPH_MODEL = Config().from_str(default_model_config)["model"] -@Language.factory( - "morphologizer", - assigns=["token.morph", "token.pos"], - default_config={"model": DEFAULT_MORPH_MODEL, "overwrite": True, "extend": False, - "scorer": {"@scorers": "spacy.morphologizer_scorer.v1"}, "label_smoothing": 0.0}, - default_score_weights={"pos_acc": 0.5, "morph_acc": 0.5, "morph_per_feat": None}, -) -def make_morphologizer( - nlp: Language, - model: Model, - name: str, 
- overwrite: bool, - extend: bool, - label_smoothing: float, - scorer: Optional[Callable], -): - return Morphologizer(nlp.vocab, model, name, overwrite=overwrite, extend=extend, label_smoothing=label_smoothing, scorer=scorer) - - def morphologizer_score(examples, **kwargs): def morph_key_getter(token, attr): return getattr(token, attr).key @@ -81,7 +64,6 @@ def morphologizer_score(examples, **kwargs): return results -@registry.scorers("spacy.morphologizer_scorer.v1") def make_morphologizer_scorer(): return morphologizer_score @@ -309,3 +291,11 @@ class Morphologizer(Tagger): if self.model.ops.xp.isnan(loss): raise ValueError(Errors.E910.format(name=self.name)) return float(loss), d_scores + + +# Setup backwards compatibility hook for factories +def __getattr__(name): + if name == "make_morphologizer": + module = importlib.import_module("spacy.pipeline.factories") + return module.make_morphologizer + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/spacy/pipeline/multitask.pyx b/spacy/pipeline/multitask.pyx index f33a90fde..1ba84b28e 100644 --- a/spacy/pipeline/multitask.pyx +++ b/spacy/pipeline/multitask.pyx @@ -1,4 +1,6 @@ # cython: infer_types=True, binding=True +import importlib +import sys from typing import Optional import numpy @@ -30,14 +32,6 @@ subword_features = true DEFAULT_MT_MODEL = Config().from_str(default_model_config)["model"] -@Language.factory( - "nn_labeller", - default_config={"labels": None, "target": "dep_tag_offset", "model": DEFAULT_MT_MODEL} -) -def make_nn_labeller(nlp: Language, name: str, model: Model, labels: Optional[dict], target: str): - return MultitaskObjective(nlp.vocab, model, name) - - class MultitaskObjective(Tagger): """Experimental: Assist training of a parser or tagger, by training a side-objective. @@ -213,3 +207,11 @@ class ClozeMultitask(TrainablePipe): def add_label(self, label): raise NotImplementedError + + +# Setup backwards compatibility hook for factories +def __getattr__(name): + if name == "make_nn_labeller": + module = importlib.import_module("spacy.pipeline.factories") + return module.make_nn_labeller + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/spacy/pipeline/ner.pyx b/spacy/pipeline/ner.pyx index bb009dc7a..1257a648a 100644 --- a/spacy/pipeline/ner.pyx +++ b/spacy/pipeline/ner.pyx @@ -1,4 +1,6 @@ # cython: infer_types=True, binding=True +import importlib +import sys from collections import defaultdict from typing import Callable, Optional @@ -36,154 +38,10 @@ subword_features = true DEFAULT_NER_MODEL = Config().from_str(default_model_config)["model"] -@Language.factory( - "ner", - assigns=["doc.ents", "token.ent_iob", "token.ent_type"], - default_config={ - "moves": None, - "update_with_oracle_cut_size": 100, - "model": DEFAULT_NER_MODEL, - "incorrect_spans_key": None, - "scorer": {"@scorers": "spacy.ner_scorer.v1"}, - }, - default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0, "ents_per_type": None}, - -) -def make_ner( - nlp: Language, - name: str, - model: Model, - moves: Optional[TransitionSystem], - update_with_oracle_cut_size: int, - incorrect_spans_key: Optional[str], - scorer: Optional[Callable], -): - """Create a transition-based EntityRecognizer component. The entity recognizer - identifies non-overlapping labelled spans of tokens. - - The transition-based algorithm used encodes certain assumptions that are - effective for "traditional" named entity recognition tasks, but may not be - a good fit for every span identification problem. 
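# Note a small behavioural change hiding in the "nn_labeller" move above: the
# removed factory dropped its `target` argument on the floor
# (`MultitaskObjective(nlp.vocab, model, name)`), whereas the relocated
# make_nn_labeller in factories.py forwards it
# (`MultitaskObjective(nlp.vocab, model, name, target=target)`), so a
# configured target now actually reaches the component constructor.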
Specifically, the loss - function optimizes for whole entity accuracy, so if your inter-annotator - agreement on boundary tokens is low, the component will likely perform poorly - on your problem. The transition-based algorithm also assumes that the most - decisive information about your entities will be close to their initial tokens. - If your entities are long and characterised by tokens in their middle, the - component will likely do poorly on your task. - - model (Model): The model for the transition-based parser. The model needs - to have a specific substructure of named components --- see the - spacy.ml.tb_framework.TransitionModel for details. - moves (Optional[TransitionSystem]): This defines how the parse-state is created, - updated and evaluated. If 'moves' is None, a new instance is - created with `self.TransitionSystem()`. Defaults to `None`. - update_with_oracle_cut_size (int): During training, cut long sequences into - shorter segments by creating intermediate states based on the gold-standard - history. The model is not very sensitive to this parameter, so you usually - won't need to change it. 100 is a good default. - incorrect_spans_key (Optional[str]): Identifies spans that are known - to be incorrect entity annotations. The incorrect entity annotations - can be stored in the span group, under this key. - scorer (Optional[Callable]): The scoring method. - """ - return EntityRecognizer( - nlp.vocab, - model, - name, - moves=moves, - update_with_oracle_cut_size=update_with_oracle_cut_size, - incorrect_spans_key=incorrect_spans_key, - multitasks=[], - beam_width=1, - beam_density=0.0, - beam_update_prob=0.0, - scorer=scorer, - ) - - -@Language.factory( - "beam_ner", - assigns=["doc.ents", "token.ent_iob", "token.ent_type"], - default_config={ - "moves": None, - "update_with_oracle_cut_size": 100, - "model": DEFAULT_NER_MODEL, - "beam_density": 0.01, - "beam_update_prob": 0.5, - "beam_width": 32, - "incorrect_spans_key": None, - "scorer": None, - }, - default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0, "ents_per_type": None}, -) -def make_beam_ner( - nlp: Language, - name: str, - model: Model, - moves: Optional[TransitionSystem], - update_with_oracle_cut_size: int, - beam_width: int, - beam_density: float, - beam_update_prob: float, - incorrect_spans_key: Optional[str], - scorer: Optional[Callable], -): - """Create a transition-based EntityRecognizer component that uses beam-search. - The entity recognizer identifies non-overlapping labelled spans of tokens. - - The transition-based algorithm used encodes certain assumptions that are - effective for "traditional" named entity recognition tasks, but may not be - a good fit for every span identification problem. Specifically, the loss - function optimizes for whole entity accuracy, so if your inter-annotator - agreement on boundary tokens is low, the component will likely perform poorly - on your problem. The transition-based algorithm also assumes that the most - decisive information about your entities will be close to their initial tokens. - If your entities are long and characterised by tokens in their middle, the - component will likely do poorly on your task. - - model (Model): The model for the transition-based parser. The model needs - to have a specific substructure of named components --- see the - spacy.ml.tb_framework.TransitionModel for details. - moves (Optional[TransitionSystem]): This defines how the parse-state is created, - updated and evaluated. 
If 'moves' is None, a new instance is - created with `self.TransitionSystem()`. Defaults to `None`. - update_with_oracle_cut_size (int): During training, cut long sequences into - shorter segments by creating intermediate states based on the gold-standard - history. The model is not very sensitive to this parameter, so you usually - won't need to change it. 100 is a good default. - beam_width (int): The number of candidate analyses to maintain. - beam_density (float): The minimum ratio between the scores of the first and - last candidates in the beam. This allows the parser to avoid exploring - candidates that are too far behind. This is mostly intended to improve - efficiency, but it can also improve accuracy as deeper search is not - always better. - beam_update_prob (float): The chance of making a beam update, instead of a - greedy update. Greedy updates are an approximation for the beam updates, - and are faster to compute. - incorrect_spans_key (Optional[str]): Optional key into span groups of - entities known to be non-entities. - scorer (Optional[Callable]): The scoring method. - """ - return EntityRecognizer( - nlp.vocab, - model, - name, - moves=moves, - update_with_oracle_cut_size=update_with_oracle_cut_size, - multitasks=[], - beam_width=beam_width, - beam_density=beam_density, - beam_update_prob=beam_update_prob, - incorrect_spans_key=incorrect_spans_key, - scorer=scorer, - ) - - def ner_score(examples, **kwargs): return get_ner_prf(examples, **kwargs) -@registry.scorers("spacy.ner_scorer.v1") def make_ner_scorer(): return ner_score @@ -261,3 +119,14 @@ cdef class EntityRecognizer(Parser): score_dict[(start, end, label)] += score entity_scores.append(score_dict) return entity_scores + + +# Setup backwards compatibility hook for factories +def __getattr__(name): + if name == "make_ner": + module = importlib.import_module("spacy.pipeline.factories") + return module.make_ner + elif name == "make_beam_ner": + module = importlib.import_module("spacy.pipeline.factories") + return module.make_beam_ner + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/spacy/pipeline/pipe.pyx b/spacy/pipeline/pipe.pyx index 72ea7e45a..ea5fc5253 100644 --- a/spacy/pipeline/pipe.pyx +++ b/spacy/pipeline/pipe.pyx @@ -21,13 +21,6 @@ cdef class Pipe: DOCS: https://spacy.io/api/pipe """ - @classmethod - def __init_subclass__(cls, **kwargs): - """Raise a warning if an inheriting class implements 'begin_training' - (from v2) instead of the new 'initialize' method (from v3)""" - if hasattr(cls, "begin_training"): - warnings.warn(Warnings.W088.format(name=cls.__name__)) - def __call__(self, Doc doc) -> Doc: """Apply the pipe to one document. The document is modified in place, and returned. 
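# The Pipe.__init_subclass__ deleted above was a v2-to-v3 migration aid: PEP
# 487 runs __init_subclass__ once at every subclass definition, which let
# spaCy emit warning W088 when a component still implemented the old
# begin_training hook instead of initialize. A minimal sketch of the pattern;
# the class names and warning text are illustrative:
import warnings

class Pipe:
    @classmethod
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if hasattr(cls, "begin_training"):
            warnings.warn(f"{cls.__name__}: use 'initialize', not 'begin_training'")

class OldStyle(Pipe):   # warns at class-creation time, before any instance exists
    def begin_training(self):
        pass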
This usually happens under the hood when the nlp object diff --git a/spacy/pipeline/sentencizer.pyx b/spacy/pipeline/sentencizer.pyx index 08ba9d989..d2b0a8d4a 100644 --- a/spacy/pipeline/sentencizer.pyx +++ b/spacy/pipeline/sentencizer.pyx @@ -1,4 +1,6 @@ # cython: infer_types=True, binding=True +import importlib +import sys from typing import Callable, List, Optional import srsly @@ -14,22 +16,6 @@ from .senter import senter_score BACKWARD_OVERWRITE = False -@Language.factory( - "sentencizer", - assigns=["token.is_sent_start", "doc.sents"], - default_config={"punct_chars": None, "overwrite": False, "scorer": {"@scorers": "spacy.senter_scorer.v1"}}, - default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0}, -) -def make_sentencizer( - nlp: Language, - name: str, - punct_chars: Optional[List[str]], - overwrite: bool, - scorer: Optional[Callable], -): - return Sentencizer(name, punct_chars=punct_chars, overwrite=overwrite, scorer=scorer) - - class Sentencizer(Pipe): """Segment the Doc into sentences using a rule-based strategy. @@ -181,3 +167,11 @@ class Sentencizer(Pipe): self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars)) self.overwrite = cfg.get("overwrite", self.overwrite) return self + + +# Setup backwards compatibility hook for factories +def __getattr__(name): + if name == "make_sentencizer": + module = importlib.import_module("spacy.pipeline.factories") + return module.make_sentencizer + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/spacy/pipeline/senter.pyx b/spacy/pipeline/senter.pyx index df093baa9..a5d85f438 100644 --- a/spacy/pipeline/senter.pyx +++ b/spacy/pipeline/senter.pyx @@ -1,4 +1,6 @@ # cython: infer_types=True, binding=True +import importlib +import sys from itertools import islice from typing import Callable, Optional @@ -34,16 +36,6 @@ subword_features = true DEFAULT_SENTER_MODEL = Config().from_str(default_model_config)["model"] -@Language.factory( - "senter", - assigns=["token.is_sent_start"], - default_config={"model": DEFAULT_SENTER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.senter_scorer.v1"}}, - default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0}, -) -def make_senter(nlp: Language, name: str, model: Model, overwrite: bool, scorer: Optional[Callable]): - return SentenceRecognizer(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer) - - def senter_score(examples, **kwargs): def has_sents(doc): return doc.has_annotation("SENT_START") @@ -53,7 +45,6 @@ def senter_score(examples, **kwargs): return results -@registry.scorers("spacy.senter_scorer.v1") def make_senter_scorer(): return senter_score @@ -185,3 +176,11 @@ class SentenceRecognizer(Tagger): def add_label(self, label, values=None): raise NotImplementedError + + +# Setup backwards compatibility hook for factories +def __getattr__(name): + if name == "make_senter": + module = importlib.import_module("spacy.pipeline.factories") + return module.make_senter + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/spacy/pipeline/span_finder.py b/spacy/pipeline/span_finder.py index a12d52911..26c9efb6a 100644 --- a/spacy/pipeline/span_finder.py +++ b/spacy/pipeline/span_finder.py @@ -1,3 +1,5 @@ +import importlib +import sys from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple from thinc.api import Config, Model, Optimizer, set_dropout_rate @@ -41,63 +43,6 @@ depth = 4 DEFAULT_SPAN_FINDER_MODEL = Config().from_str(span_finder_default_config)["model"] 
-@Language.factory( - "span_finder", - assigns=["doc.spans"], - default_config={ - "threshold": 0.5, - "model": DEFAULT_SPAN_FINDER_MODEL, - "spans_key": DEFAULT_SPANS_KEY, - "max_length": 25, - "min_length": None, - "scorer": {"@scorers": "spacy.span_finder_scorer.v1"}, - }, - default_score_weights={ - f"spans_{DEFAULT_SPANS_KEY}_f": 1.0, - f"spans_{DEFAULT_SPANS_KEY}_p": 0.0, - f"spans_{DEFAULT_SPANS_KEY}_r": 0.0, - }, -) -def make_span_finder( - nlp: Language, - name: str, - model: Model[Iterable[Doc], Floats2d], - spans_key: str, - threshold: float, - max_length: Optional[int], - min_length: Optional[int], - scorer: Optional[Callable], -) -> "SpanFinder": - """Create a SpanFinder component. The component predicts whether a token is - the start or the end of a potential span. - - model (Model[List[Doc], Floats2d]): A model instance that - is given a list of documents and predicts a probability for each token. - spans_key (str): Key of the doc.spans dict to save the spans under. During - initialization and training, the component will look for spans on the - reference document under the same key. - threshold (float): Minimum probability to consider a prediction positive. - max_length (Optional[int]): Maximum length of the produced spans, defaults - to None meaning unlimited length. - min_length (Optional[int]): Minimum length of the produced spans, defaults - to None meaning shortest span length is 1. - scorer (Optional[Callable]): The scoring method. Defaults to - Scorer.score_spans for the Doc.spans[spans_key] with overlapping - spans allowed. - """ - return SpanFinder( - nlp, - model=model, - threshold=threshold, - name=name, - scorer=scorer, - max_length=max_length, - min_length=min_length, - spans_key=spans_key, - ) - - -@registry.scorers("spacy.span_finder_scorer.v1") def make_span_finder_scorer(): return span_finder_score @@ -333,3 +278,11 @@ class SpanFinder(TrainablePipe): self.model.initialize(X=docs, Y=Y) else: self.model.initialize() + + +# Setup backwards compatibility hook for factories +def __getattr__(name): + if name == "make_span_finder": + module = importlib.import_module("spacy.pipeline.factories") + return module.make_span_finder + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/spacy/pipeline/span_ruler.py b/spacy/pipeline/span_ruler.py index 2a5e2179a..98287ba1d 100644 --- a/spacy/pipeline/span_ruler.py +++ b/spacy/pipeline/span_ruler.py @@ -1,3 +1,5 @@ +import importlib +import sys import warnings from functools import partial from pathlib import Path @@ -32,105 +34,6 @@ PatternType = Dict[str, Union[str, List[Dict[str, Any]]]] DEFAULT_SPANS_KEY = "ruler" -@Language.factory( - "future_entity_ruler", - assigns=["doc.ents"], - default_config={ - "phrase_matcher_attr": None, - "validate": False, - "overwrite_ents": False, - "scorer": {"@scorers": "spacy.entity_ruler_scorer.v1"}, - "ent_id_sep": "__unused__", - "matcher_fuzzy_compare": {"@misc": "spacy.levenshtein_compare.v1"}, - }, - default_score_weights={ - "ents_f": 1.0, - "ents_p": 0.0, - "ents_r": 0.0, - "ents_per_type": None, - }, -) -def make_entity_ruler( - nlp: Language, - name: str, - phrase_matcher_attr: Optional[Union[int, str]], - matcher_fuzzy_compare: Callable, - validate: bool, - overwrite_ents: bool, - scorer: Optional[Callable], - ent_id_sep: str, -): - if overwrite_ents: - ents_filter = prioritize_new_ents_filter - else: - ents_filter = prioritize_existing_ents_filter - return SpanRuler( - nlp, - name, - spans_key=None, - spans_filter=None, - annotate_ents=True, - 
ents_filter=ents_filter, - phrase_matcher_attr=phrase_matcher_attr, - matcher_fuzzy_compare=matcher_fuzzy_compare, - validate=validate, - overwrite=False, - scorer=scorer, - ) - - -@Language.factory( - "span_ruler", - assigns=["doc.spans"], - default_config={ - "spans_key": DEFAULT_SPANS_KEY, - "spans_filter": None, - "annotate_ents": False, - "ents_filter": {"@misc": "spacy.first_longest_spans_filter.v1"}, - "phrase_matcher_attr": None, - "matcher_fuzzy_compare": {"@misc": "spacy.levenshtein_compare.v1"}, - "validate": False, - "overwrite": True, - "scorer": { - "@scorers": "spacy.overlapping_labeled_spans_scorer.v1", - "spans_key": DEFAULT_SPANS_KEY, - }, - }, - default_score_weights={ - f"spans_{DEFAULT_SPANS_KEY}_f": 1.0, - f"spans_{DEFAULT_SPANS_KEY}_p": 0.0, - f"spans_{DEFAULT_SPANS_KEY}_r": 0.0, - f"spans_{DEFAULT_SPANS_KEY}_per_type": None, - }, -) -def make_span_ruler( - nlp: Language, - name: str, - spans_key: Optional[str], - spans_filter: Optional[Callable[[Iterable[Span], Iterable[Span]], Iterable[Span]]], - annotate_ents: bool, - ents_filter: Callable[[Iterable[Span], Iterable[Span]], Iterable[Span]], - phrase_matcher_attr: Optional[Union[int, str]], - matcher_fuzzy_compare: Callable, - validate: bool, - overwrite: bool, - scorer: Optional[Callable], -): - return SpanRuler( - nlp, - name, - spans_key=spans_key, - spans_filter=spans_filter, - annotate_ents=annotate_ents, - ents_filter=ents_filter, - phrase_matcher_attr=phrase_matcher_attr, - matcher_fuzzy_compare=matcher_fuzzy_compare, - validate=validate, - overwrite=overwrite, - scorer=scorer, - ) - - def prioritize_new_ents_filter( entities: Iterable[Span], spans: Iterable[Span] ) -> List[Span]: @@ -157,7 +60,6 @@ def prioritize_new_ents_filter( return entities + new_entities -@registry.misc("spacy.prioritize_new_ents_filter.v1") def make_prioritize_new_ents_filter(): return prioritize_new_ents_filter @@ -188,7 +90,6 @@ def prioritize_existing_ents_filter( return entities + new_entities -@registry.misc("spacy.prioritize_existing_ents_filter.v1") def make_preserve_existing_ents_filter(): return prioritize_existing_ents_filter @@ -208,7 +109,6 @@ def overlapping_labeled_spans_score( return Scorer.score_spans(examples, **kwargs) -@registry.scorers("spacy.overlapping_labeled_spans_scorer.v1") def make_overlapping_labeled_spans_scorer(spans_key: str = DEFAULT_SPANS_KEY): return partial(overlapping_labeled_spans_score, spans_key=spans_key) @@ -595,3 +495,14 @@ class SpanRuler(Pipe): "patterns": lambda p: srsly.write_jsonl(p, self.patterns), } util.to_disk(path, serializers, {}) + + +# Setup backwards compatibility hook for factories +def __getattr__(name): + if name == "make_span_ruler": + module = importlib.import_module("spacy.pipeline.factories") + return module.make_span_ruler + elif name == "make_entity_ruler": + module = importlib.import_module("spacy.pipeline.factories") + return module.make_future_entity_ruler + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index 08a5478a9..030572850 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -1,3 +1,5 @@ +import importlib +import sys from dataclasses import dataclass from functools import partial from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast @@ -134,7 +136,6 @@ def preset_spans_suggester( return output -@registry.misc("spacy.ngram_suggester.v1") def build_ngram_suggester(sizes: List[int]) -> Suggester: """Suggest all spans of the 
     array of integers. The array has two columns, indicating the start and end
@@ -143,7 +144,6 @@ def build_ngram_suggester(sizes: List[int]) -> Suggester:
     return partial(ngram_suggester, sizes=sizes)
 
 
-@registry.misc("spacy.ngram_range_suggester.v1")
 def build_ngram_range_suggester(min_size: int, max_size: int) -> Suggester:
     """Suggest all spans of the given lengths between a given min and max value
     - both inclusive. Spans are returned as a ragged array of integers. The array has two columns,
@@ -152,7 +152,6 @@ def build_ngram_range_suggester(min_size: int, max_size: int) -> Suggester:
     return build_ngram_suggester(sizes)
 
 
-@registry.misc("spacy.preset_spans_suggester.v1")
 def build_preset_spans_suggester(spans_key: str) -> Suggester:
     """Suggest all spans that are already stored in doc.spans[spans_key].
     This is useful when an upstream component is used to set the spans
@@ -160,136 +159,6 @@ def build_preset_spans_suggester(spans_key: str) -> Suggester:
     return partial(preset_spans_suggester, spans_key=spans_key)
 
 
-@Language.factory(
-    "spancat",
-    assigns=["doc.spans"],
-    default_config={
-        "threshold": 0.5,
-        "spans_key": DEFAULT_SPANS_KEY,
-        "max_positive": None,
-        "model": DEFAULT_SPANCAT_MODEL,
-        "suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]},
-        "scorer": {"@scorers": "spacy.spancat_scorer.v1"},
-    },
-    default_score_weights={"spans_sc_f": 1.0, "spans_sc_p": 0.0, "spans_sc_r": 0.0},
-)
-def make_spancat(
-    nlp: Language,
-    name: str,
-    suggester: Suggester,
-    model: Model[Tuple[List[Doc], Ragged], Floats2d],
-    spans_key: str,
-    scorer: Optional[Callable],
-    threshold: float,
-    max_positive: Optional[int],
-) -> "SpanCategorizer":
-    """Create a SpanCategorizer component and configure it for multi-label
-    classification to be able to assign multiple labels for each span.
-    The span categorizer consists of two
-    parts: a suggester function that proposes candidate spans, and a labeller
-    model that predicts one or more labels for each span.
-
-    name (str): The component instance name, used to add entries to the
-        losses during training.
-    suggester (Callable[[Iterable[Doc], Optional[Ops]], Ragged]): A function that suggests spans.
-        Spans are returned as a ragged array with two integer columns, for the
-        start and end positions.
-    model (Model[Tuple[List[Doc], Ragged], Floats2d]): A model instance that
-        is given a list of documents and (start, end) indices representing
-        candidate span offsets. The model predicts a probability for each category
-        for each span.
-    spans_key (str): Key of the doc.spans dict to save the spans under. During
-        initialization and training, the component will look for spans on the
-        reference document under the same key.
-    scorer (Optional[Callable]): The scoring method. Defaults to
-        Scorer.score_spans for the Doc.spans[spans_key] with overlapping
-        spans allowed.
-    threshold (float): Minimum probability to consider a prediction positive.
-        Spans with a positive prediction will be saved on the Doc. Defaults to
-        0.5.
-    max_positive (Optional[int]): Maximum number of labels to consider positive
-        per span. Defaults to None, indicating no limit.
-    """
-    return SpanCategorizer(
-        nlp.vocab,
-        model=model,
-        suggester=suggester,
-        name=name,
-        spans_key=spans_key,
-        negative_weight=None,
-        allow_overlap=True,
-        max_positive=max_positive,
-        threshold=threshold,
-        scorer=scorer,
-        add_negative_label=False,
-    )
-
-
-@Language.factory(
-    "spancat_singlelabel",
-    assigns=["doc.spans"],
-    default_config={
-        "spans_key": DEFAULT_SPANS_KEY,
-        "model": DEFAULT_SPANCAT_SINGLELABEL_MODEL,
-        "negative_weight": 1.0,
-        "suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]},
-        "scorer": {"@scorers": "spacy.spancat_scorer.v1"},
-        "allow_overlap": True,
-    },
-    default_score_weights={"spans_sc_f": 1.0, "spans_sc_p": 0.0, "spans_sc_r": 0.0},
-)
-def make_spancat_singlelabel(
-    nlp: Language,
-    name: str,
-    suggester: Suggester,
-    model: Model[Tuple[List[Doc], Ragged], Floats2d],
-    spans_key: str,
-    negative_weight: float,
-    allow_overlap: bool,
-    scorer: Optional[Callable],
-) -> "SpanCategorizer":
-    """Create a SpanCategorizer component and configure it for multi-class
-    classification. With this configuration each span can get at most one
-    label. The span categorizer consists of two
-    parts: a suggester function that proposes candidate spans, and a labeller
-    model that predicts one or more labels for each span.
-
-    name (str): The component instance name, used to add entries to the
-        losses during training.
-    suggester (Callable[[Iterable[Doc], Optional[Ops]], Ragged]): A function that suggests spans.
-        Spans are returned as a ragged array with two integer columns, for the
-        start and end positions.
-    model (Model[Tuple[List[Doc], Ragged], Floats2d]): A model instance that
-        is given a list of documents and (start, end) indices representing
-        candidate span offsets. The model predicts a probability for each category
-        for each span.
-    spans_key (str): Key of the doc.spans dict to save the spans under. During
-        initialization and training, the component will look for spans on the
-        reference document under the same key.
-    scorer (Optional[Callable]): The scoring method. Defaults to
-        Scorer.score_spans for the Doc.spans[spans_key] with overlapping
-        spans allowed.
-    negative_weight (float): Multiplier for the loss terms.
-        Can be used to downweight the negative samples if there are too many.
-    allow_overlap (bool): If True the data is assumed to contain overlapping spans.
-        Otherwise it produces non-overlapping spans greedily prioritizing
-        higher assigned label scores.
- """ - return SpanCategorizer( - nlp.vocab, - model=model, - suggester=suggester, - name=name, - spans_key=spans_key, - negative_weight=negative_weight, - allow_overlap=allow_overlap, - max_positive=1, - add_negative_label=True, - threshold=None, - scorer=scorer, - ) - - def spancat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: kwargs = dict(kwargs) attr_prefix = "spans_" @@ -303,7 +172,6 @@ def spancat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: return Scorer.score_spans(examples, **kwargs) -@registry.scorers("spacy.spancat_scorer.v1") def make_spancat_scorer(): return spancat_score @@ -785,3 +653,14 @@ class SpanCategorizer(TrainablePipe): spans.attrs["scores"] = numpy.array(attrs_scores) return spans + + +# Setup backwards compatibility hook for factories +def __getattr__(name): + if name == "make_spancat": + module = importlib.import_module("spacy.pipeline.factories") + return module.make_spancat + elif name == "make_spancat_singlelabel": + module = importlib.import_module("spacy.pipeline.factories") + return module.make_spancat_singlelabel + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/spacy/pipeline/tagger.pyx b/spacy/pipeline/tagger.pyx index 34e85d49c..f7a16e07b 100644 --- a/spacy/pipeline/tagger.pyx +++ b/spacy/pipeline/tagger.pyx @@ -1,4 +1,6 @@ # cython: infer_types=True, binding=True +import importlib +import sys from itertools import islice from typing import Callable, Optional @@ -35,36 +37,10 @@ subword_features = true DEFAULT_TAGGER_MODEL = Config().from_str(default_model_config)["model"] -@Language.factory( - "tagger", - assigns=["token.tag"], - default_config={"model": DEFAULT_TAGGER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.tagger_scorer.v1"}, "neg_prefix": "!", "label_smoothing": 0.0}, - default_score_weights={"tag_acc": 1.0}, -) -def make_tagger( - nlp: Language, - name: str, - model: Model, - overwrite: bool, - scorer: Optional[Callable], - neg_prefix: str, - label_smoothing: float, -): - """Construct a part-of-speech tagger component. - - model (Model[List[Doc], List[Floats2d]]): A model instance that predicts - the tag probabilities. The output vectors should match the number of tags - in size, and be normalized as probabilities (all scores between 0 and 1, - with the rows summing to 1). 
- """ - return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, neg_prefix=neg_prefix, label_smoothing=label_smoothing) - - def tagger_score(examples, **kwargs): return Scorer.score_token_attr(examples, "tag", **kwargs) -@registry.scorers("spacy.tagger_scorer.v1") def make_tagger_scorer(): return tagger_score @@ -317,3 +293,11 @@ class Tagger(TrainablePipe): self.cfg["labels"].append(label) self.vocab.strings.add(label) return 1 + + +# Setup backwards compatibility hook for factories +def __getattr__(name): + if name == "make_tagger": + module = importlib.import_module("spacy.pipeline.factories") + return module.make_tagger + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py index ae227017a..36b569edc 100644 --- a/spacy/pipeline/textcat.py +++ b/spacy/pipeline/textcat.py @@ -1,3 +1,5 @@ +import importlib +import sys from itertools import islice from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple @@ -74,46 +76,6 @@ subword_features = true """ -@Language.factory( - "textcat", - assigns=["doc.cats"], - default_config={ - "threshold": 0.0, - "model": DEFAULT_SINGLE_TEXTCAT_MODEL, - "scorer": {"@scorers": "spacy.textcat_scorer.v2"}, - }, - default_score_weights={ - "cats_score": 1.0, - "cats_score_desc": None, - "cats_micro_p": None, - "cats_micro_r": None, - "cats_micro_f": None, - "cats_macro_p": None, - "cats_macro_r": None, - "cats_macro_f": None, - "cats_macro_auc": None, - "cats_f_per_type": None, - }, -) -def make_textcat( - nlp: Language, - name: str, - model: Model[List[Doc], List[Floats2d]], - threshold: float, - scorer: Optional[Callable], -) -> "TextCategorizer": - """Create a TextCategorizer component. The text categorizer predicts categories - over a whole document. It can learn one or more labels, and the labels are considered - to be mutually exclusive (i.e. one true label per doc). - - model (Model[List[Doc], List[Floats2d]]): A model instance that predicts - scores for each category. - threshold (float): Cutoff to consider a prediction "positive". - scorer (Optional[Callable]): The scoring method. 
- """ - return TextCategorizer(nlp.vocab, model, name, threshold=threshold, scorer=scorer) - - def textcat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: return Scorer.score_cats( examples, @@ -123,7 +85,6 @@ def textcat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: ) -@registry.scorers("spacy.textcat_scorer.v2") def make_textcat_scorer(): return textcat_score @@ -412,3 +373,11 @@ class TextCategorizer(TrainablePipe): for val in vals: if not (val == 1.0 or val == 0.0): raise ValueError(Errors.E851.format(val=val)) + + +# Setup backwards compatibility hook for factories +def __getattr__(name): + if name == "make_textcat": + module = importlib.import_module("spacy.pipeline.factories") + return module.make_textcat + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/spacy/pipeline/textcat_multilabel.py b/spacy/pipeline/textcat_multilabel.py index 2f8d5e604..32845490d 100644 --- a/spacy/pipeline/textcat_multilabel.py +++ b/spacy/pipeline/textcat_multilabel.py @@ -1,3 +1,5 @@ +import importlib +import sys from itertools import islice from typing import Any, Callable, Dict, Iterable, List, Optional @@ -72,49 +74,6 @@ subword_features = true """ -@Language.factory( - "textcat_multilabel", - assigns=["doc.cats"], - default_config={ - "threshold": 0.5, - "model": DEFAULT_MULTI_TEXTCAT_MODEL, - "scorer": {"@scorers": "spacy.textcat_multilabel_scorer.v2"}, - }, - default_score_weights={ - "cats_score": 1.0, - "cats_score_desc": None, - "cats_micro_p": None, - "cats_micro_r": None, - "cats_micro_f": None, - "cats_macro_p": None, - "cats_macro_r": None, - "cats_macro_f": None, - "cats_macro_auc": None, - "cats_f_per_type": None, - }, -) -def make_multilabel_textcat( - nlp: Language, - name: str, - model: Model[List[Doc], List[Floats2d]], - threshold: float, - scorer: Optional[Callable], -) -> "MultiLabel_TextCategorizer": - """Create a MultiLabel_TextCategorizer component. The text categorizer predicts categories - over a whole document. It can learn one or more labels, and the labels are considered - to be non-mutually exclusive, which means that there can be zero or more labels - per doc). - - model (Model[List[Doc], List[Floats2d]]): A model instance that predicts - scores for each category. - threshold (float): Cutoff to consider a prediction "positive". - scorer (Optional[Callable]): The scoring method. 
- """ - return MultiLabel_TextCategorizer( - nlp.vocab, model, name, threshold=threshold, scorer=scorer - ) - - def textcat_multilabel_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: return Scorer.score_cats( examples, @@ -124,7 +83,6 @@ def textcat_multilabel_score(examples: Iterable[Example], **kwargs) -> Dict[str, ) -@registry.scorers("spacy.textcat_multilabel_scorer.v2") def make_textcat_multilabel_scorer(): return textcat_multilabel_score @@ -212,3 +170,11 @@ class MultiLabel_TextCategorizer(TextCategorizer): for val in ex.reference.cats.values(): if not (val == 1.0 or val == 0.0): raise ValueError(Errors.E851.format(val=val)) + + +# Setup backwards compatibility hook for factories +def __getattr__(name): + if name == "make_multilabel_textcat": + module = importlib.import_module("spacy.pipeline.factories") + return module.make_multilabel_textcat + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/spacy/pipeline/tok2vec.py b/spacy/pipeline/tok2vec.py index 677f5eec1..ce0296bf5 100644 --- a/spacy/pipeline/tok2vec.py +++ b/spacy/pipeline/tok2vec.py @@ -1,3 +1,5 @@ +import importlib +import sys from itertools import islice from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence @@ -24,13 +26,6 @@ subword_features = true DEFAULT_TOK2VEC_MODEL = Config().from_str(default_model_config)["model"] -@Language.factory( - "tok2vec", assigns=["doc.tensor"], default_config={"model": DEFAULT_TOK2VEC_MODEL} -) -def make_tok2vec(nlp: Language, name: str, model: Model) -> "Tok2Vec": - return Tok2Vec(nlp.vocab, model, name) - - class Tok2Vec(TrainablePipe): """Apply a "token-to-vector" model and set its outputs in the doc.tensor attribute. This is mostly useful to share a single subnetwork between multiple @@ -320,3 +315,11 @@ def forward(model: Tok2VecListener, inputs, is_train: bool): def _empty_backprop(dX): # for pickling return [] + + +# Setup backwards compatibility hook for factories +def __getattr__(name): + if name == "make_tok2vec": + module = importlib.import_module("spacy.pipeline.factories") + return module.make_tok2vec + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/spacy/pipeline/transition_parser.pxd b/spacy/pipeline/transition_parser.pxd index 7ddb91e01..62c2bfb56 100644 --- a/spacy/pipeline/transition_parser.pxd +++ b/spacy/pipeline/transition_parser.pxd @@ -19,7 +19,7 @@ cdef class Parser(TrainablePipe): StateC** states, WeightsC weights, SizesC sizes - ) nogil + ) noexcept nogil cdef void c_transition_batch( self, @@ -27,4 +27,4 @@ cdef class Parser(TrainablePipe): const float* scores, int nr_class, int batch_size - ) nogil + ) noexcept nogil diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx index 9a278fc13..24a5bc1d9 100644 --- a/spacy/pipeline/transition_parser.pyx +++ b/spacy/pipeline/transition_parser.pyx @@ -316,7 +316,7 @@ cdef class Parser(TrainablePipe): cdef void _parseC( self, CBlas cblas, StateC** states, WeightsC weights, SizesC sizes - ) nogil: + ) noexcept nogil: cdef int i cdef vector[StateC*] unfinished cdef ActivationsC activations = alloc_activations(sizes) @@ -359,7 +359,7 @@ cdef class Parser(TrainablePipe): const float* scores, int nr_class, int batch_size - ) nogil: + ) noexcept nogil: # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc with gil: assert self.moves.n_moves > 0, Errors.E924.format(name=self.name) diff --git a/spacy/registrations.py b/spacy/registrations.py new file mode 100644 
diff --git a/spacy/pipeline/transition_parser.pxd b/spacy/pipeline/transition_parser.pxd
index 7ddb91e01..62c2bfb56 100644
--- a/spacy/pipeline/transition_parser.pxd
+++ b/spacy/pipeline/transition_parser.pxd
@@ -19,7 +19,7 @@ cdef class Parser(TrainablePipe):
         StateC** states,
         WeightsC weights,
         SizesC sizes
-    ) nogil
+    ) noexcept nogil
 
     cdef void c_transition_batch(
         self,
@@ -27,4 +27,4 @@ cdef class Parser(TrainablePipe):
         const float* scores,
         int nr_class,
         int batch_size
-    ) nogil
+    ) noexcept nogil
diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx
index 9a278fc13..24a5bc1d9 100644
--- a/spacy/pipeline/transition_parser.pyx
+++ b/spacy/pipeline/transition_parser.pyx
@@ -316,7 +316,7 @@ cdef class Parser(TrainablePipe):
     cdef void _parseC(
         self, CBlas cblas, StateC** states, WeightsC weights, SizesC sizes
-    ) nogil:
+    ) noexcept nogil:
         cdef int i
         cdef vector[StateC*] unfinished
         cdef ActivationsC activations = alloc_activations(sizes)
@@ -359,7 +359,7 @@ cdef class Parser(TrainablePipe):
         const float* scores,
         int nr_class,
         int batch_size
-    ) nogil:
+    ) noexcept nogil:
        # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
         with gil:
             assert self.moves.n_moves > 0, Errors.E924.format(name=self.name)
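The `noexcept nogil` annotations here, and in the tokenizer/doc/token hunks later in the patch, track a Cython 3 semantics change: cdef functions now propagate exceptions by default, which adds per-call error checks and, from nogil code, a GIL re-acquisition. Marking the hot parsing and comparison routines `noexcept` restores the pre-3.0 contract that they cannot raise. An illustrative pair of declarations (a sketch, not from this patch):

# Illustrative Cython 3 snippet, not part of the diff:
cdef int checked_inc(int x) nogil:
    # Cython 3 default: assumed able to raise, so every call site
    # performs an error check (needing the GIL from nogil contexts).
    return x + 1

cdef int unchecked_inc(int x) noexcept nogil:
    # Declared unable to raise: no checks, safe to call fully GIL-free.
    return x + 1
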
registry.misc("spacy.CandidateBatchGenerator.v1")(create_candidates_batch) + registry.misc("spacy.LookupsDataLoader.v1")(load_lookups_data) + + # Need to get references to the existing functions in registry by importing the function that is there + # For the registry that was previously decorated + + # Import ML components that use registry + from .language import create_tokenizer + from .ml._precomputable_affine import PrecomputableAffine + from .ml.callbacks import ( + create_models_and_pipes_with_nvtx_range, + create_models_with_nvtx_range, + ) + from .ml.extract_ngrams import extract_ngrams + from .ml.extract_spans import extract_spans + + # Import decorator-removed ML components + from .ml.featureextractor import FeatureExtractor + from .ml.models.entity_linker import build_nel_encoder + from .ml.models.multi_task import ( + create_pretrain_characters, + create_pretrain_vectors, + ) + from .ml.models.parser import build_tb_parser_model + from .ml.models.span_finder import build_finder_model + from .ml.models.spancat import ( + build_linear_logistic, + build_mean_max_reducer, + build_spancat_model, + ) + from .ml.models.tagger import build_tagger_model + from .ml.models.textcat import ( + build_bow_text_classifier, + build_bow_text_classifier_v3, + build_reduce_text_classifier, + build_simple_cnn_text_classifier, + build_text_classifier_lowdata, + build_text_classifier_v2, + build_textcat_parametric_attention_v1, + ) + from .ml.models.tok2vec import ( + BiLSTMEncoder, + CharacterEmbed, + MaxoutWindowEncoder, + MishWindowEncoder, + MultiHashEmbed, + build_hash_embed_cnn_tok2vec, + build_Tok2Vec_model, + tok2vec_listener_v1, + ) + from .ml.staticvectors import StaticVectors + from .ml.tb_framework import TransitionModel + from .training.augment import ( + create_combined_augmenter, + create_lower_casing_augmenter, + create_orth_variants_augmenter, + ) + from .training.batchers import ( + configure_minibatch, + configure_minibatch_by_padded_size, + configure_minibatch_by_words, + ) + from .training.callbacks import create_copy_from_base_model + from .training.loggers import console_logger, console_logger_v3 + + # Register scorers + registry.scorers("spacy.tagger_scorer.v1")(make_tagger_scorer) + registry.scorers("spacy.ner_scorer.v1")(make_ner_scorer) + # span_ruler_scorer removed as it's not in span_ruler.py + registry.scorers("spacy.entity_ruler_scorer.v1")(make_entityruler_scorer) + registry.scorers("spacy.senter_scorer.v1")(make_senter_scorer) + registry.scorers("spacy.textcat_scorer.v1")(make_textcat_scorer) + registry.scorers("spacy.textcat_scorer.v2")(make_textcat_scorer) + registry.scorers("spacy.textcat_multilabel_scorer.v1")( + make_textcat_multilabel_scorer + ) + registry.scorers("spacy.textcat_multilabel_scorer.v2")( + make_textcat_multilabel_scorer + ) + registry.scorers("spacy.lemmatizer_scorer.v1")(make_lemmatizer_scorer) + registry.scorers("spacy.span_finder_scorer.v1")(make_span_finder_scorer) + registry.scorers("spacy.spancat_scorer.v1")(make_spancat_scorer) + registry.scorers("spacy.entity_linker_scorer.v1")(make_entity_linker_scorer) + registry.scorers("spacy.overlapping_labeled_spans_scorer.v1")( + make_overlapping_labeled_spans_scorer + ) + registry.scorers("spacy.attribute_ruler_scorer.v1")(make_attribute_ruler_scorer) + registry.scorers("spacy.parser_scorer.v1")(make_parser_scorer) + registry.scorers("spacy.morphologizer_scorer.v1")(make_morphologizer_scorer) + + # Register tokenizers + registry.tokenizers("spacy.Tokenizer.v1")(create_tokenizer) + 
registry.tokenizers("spacy.ja.JapaneseTokenizer")(create_japanese_tokenizer) + registry.tokenizers("spacy.zh.ChineseTokenizer")(create_chinese_tokenizer) + registry.tokenizers("spacy.ko.KoreanTokenizer")(create_korean_tokenizer) + registry.tokenizers("spacy.vi.VietnameseTokenizer")(create_vietnamese_tokenizer) + registry.tokenizers("spacy.th.ThaiTokenizer")(create_thai_tokenizer) + + # Register tok2vec architectures we've modified + registry.architectures("spacy.Tok2VecListener.v1")(tok2vec_listener_v1) + registry.architectures("spacy.HashEmbedCNN.v2")(build_hash_embed_cnn_tok2vec) + registry.architectures("spacy.Tok2Vec.v2")(build_Tok2Vec_model) + registry.architectures("spacy.MultiHashEmbed.v2")(MultiHashEmbed) + registry.architectures("spacy.CharacterEmbed.v2")(CharacterEmbed) + registry.architectures("spacy.MaxoutWindowEncoder.v2")(MaxoutWindowEncoder) + registry.architectures("spacy.MishWindowEncoder.v2")(MishWindowEncoder) + registry.architectures("spacy.TorchBiLSTMEncoder.v1")(BiLSTMEncoder) + registry.architectures("spacy.EntityLinker.v2")(build_nel_encoder) + registry.architectures("spacy.TextCatCNN.v2")(build_simple_cnn_text_classifier) + registry.architectures("spacy.TextCatBOW.v2")(build_bow_text_classifier) + registry.architectures("spacy.TextCatBOW.v3")(build_bow_text_classifier_v3) + registry.architectures("spacy.TextCatEnsemble.v2")(build_text_classifier_v2) + registry.architectures("spacy.TextCatLowData.v1")(build_text_classifier_lowdata) + registry.architectures("spacy.TextCatParametricAttention.v1")( + build_textcat_parametric_attention_v1 + ) + registry.architectures("spacy.TextCatReduce.v1")(build_reduce_text_classifier) + registry.architectures("spacy.SpanCategorizer.v1")(build_spancat_model) + registry.architectures("spacy.SpanFinder.v1")(build_finder_model) + registry.architectures("spacy.TransitionBasedParser.v2")(build_tb_parser_model) + registry.architectures("spacy.PretrainVectors.v1")(create_pretrain_vectors) + registry.architectures("spacy.PretrainCharacters.v1")(create_pretrain_characters) + registry.architectures("spacy.Tagger.v2")(build_tagger_model) + + # Register layers + registry.layers("spacy.FeatureExtractor.v1")(FeatureExtractor) + registry.layers("spacy.extract_spans.v1")(extract_spans) + registry.layers("spacy.extract_ngrams.v1")(extract_ngrams) + registry.layers("spacy.LinearLogistic.v1")(build_linear_logistic) + registry.layers("spacy.mean_max_reducer.v1")(build_mean_max_reducer) + registry.layers("spacy.StaticVectors.v2")(StaticVectors) + registry.layers("spacy.PrecomputableAffine.v1")(PrecomputableAffine) + registry.layers("spacy.CharEmbed.v1")(CharacterEmbed) + registry.layers("spacy.TransitionModel.v1")(TransitionModel) + + # Register callbacks + registry.callbacks("spacy.copy_from_base_model.v1")(create_copy_from_base_model) + registry.callbacks("spacy.models_with_nvtx_range.v1")(create_models_with_nvtx_range) + registry.callbacks("spacy.models_and_pipes_with_nvtx_range.v1")( + create_models_and_pipes_with_nvtx_range + ) + + # Register loggers + registry.loggers("spacy.ConsoleLogger.v2")(console_logger) + registry.loggers("spacy.ConsoleLogger.v3")(console_logger_v3) + + # Register batchers + registry.batchers("spacy.batch_by_padded.v1")(configure_minibatch_by_padded_size) + registry.batchers("spacy.batch_by_words.v1")(configure_minibatch_by_words) + registry.batchers("spacy.batch_by_sequence.v1")(configure_minibatch) + + # Register augmenters + registry.augmenters("spacy.combined_augmenter.v1")(create_combined_augmenter) + 
registry.augmenters("spacy.lower_case.v1")(create_lower_casing_augmenter) + registry.augmenters("spacy.orth_variants.v1")(create_orth_variants_augmenter) + + # Set the flag to indicate that the registry has been populated + REGISTRY_POPULATED = True diff --git a/spacy/symbols.pyx b/spacy/symbols.pyx index f7713577b..29c179df8 100644 --- a/spacy/symbols.pyx +++ b/spacy/symbols.pyx @@ -479,3 +479,4 @@ NAMES = [it[0] for it in sorted(IDS.items(), key=sort_nums)] # (which is generating an enormous amount of C++ in Cython 0.24+) # We keep the enum cdef, and just make sure the names are available to Python locals().update(IDS) + diff --git a/spacy/tests/factory_registrations.json b/spacy/tests/factory_registrations.json new file mode 100644 index 000000000..475e48020 --- /dev/null +++ b/spacy/tests/factory_registrations.json @@ -0,0 +1,132 @@ +{ + "attribute_ruler": { + "name": "attribute_ruler", + "module": "spacy.pipeline.attributeruler", + "function": "make_attribute_ruler" + }, + "beam_ner": { + "name": "beam_ner", + "module": "spacy.pipeline.ner", + "function": "make_beam_ner" + }, + "beam_parser": { + "name": "beam_parser", + "module": "spacy.pipeline.dep_parser", + "function": "make_beam_parser" + }, + "doc_cleaner": { + "name": "doc_cleaner", + "module": "spacy.pipeline.functions", + "function": "make_doc_cleaner" + }, + "entity_linker": { + "name": "entity_linker", + "module": "spacy.pipeline.entity_linker", + "function": "make_entity_linker" + }, + "entity_ruler": { + "name": "entity_ruler", + "module": "spacy.pipeline.entityruler", + "function": "make_entity_ruler" + }, + "future_entity_ruler": { + "name": "future_entity_ruler", + "module": "spacy.pipeline.span_ruler", + "function": "make_entity_ruler" + }, + "lemmatizer": { + "name": "lemmatizer", + "module": "spacy.pipeline.lemmatizer", + "function": "make_lemmatizer" + }, + "merge_entities": { + "name": "merge_entities", + "module": "spacy.language", + "function": "Language.component..add_component..factory_func" + }, + "merge_noun_chunks": { + "name": "merge_noun_chunks", + "module": "spacy.language", + "function": "Language.component..add_component..factory_func" + }, + "merge_subtokens": { + "name": "merge_subtokens", + "module": "spacy.language", + "function": "Language.component..add_component..factory_func" + }, + "morphologizer": { + "name": "morphologizer", + "module": "spacy.pipeline.morphologizer", + "function": "make_morphologizer" + }, + "ner": { + "name": "ner", + "module": "spacy.pipeline.ner", + "function": "make_ner" + }, + "parser": { + "name": "parser", + "module": "spacy.pipeline.dep_parser", + "function": "make_parser" + }, + "sentencizer": { + "name": "sentencizer", + "module": "spacy.pipeline.sentencizer", + "function": "make_sentencizer" + }, + "senter": { + "name": "senter", + "module": "spacy.pipeline.senter", + "function": "make_senter" + }, + "span_finder": { + "name": "span_finder", + "module": "spacy.pipeline.span_finder", + "function": "make_span_finder" + }, + "span_ruler": { + "name": "span_ruler", + "module": "spacy.pipeline.span_ruler", + "function": "make_span_ruler" + }, + "spancat": { + "name": "spancat", + "module": "spacy.pipeline.spancat", + "function": "make_spancat" + }, + "spancat_singlelabel": { + "name": "spancat_singlelabel", + "module": "spacy.pipeline.spancat", + "function": "make_spancat_singlelabel" + }, + "tagger": { + "name": "tagger", + "module": "spacy.pipeline.tagger", + "function": "make_tagger" + }, + "textcat": { + "name": "textcat", + "module": "spacy.pipeline.textcat", + 
"function": "make_textcat" + }, + "textcat_multilabel": { + "name": "textcat_multilabel", + "module": "spacy.pipeline.textcat_multilabel", + "function": "make_multilabel_textcat" + }, + "tok2vec": { + "name": "tok2vec", + "module": "spacy.pipeline.tok2vec", + "function": "make_tok2vec" + }, + "token_splitter": { + "name": "token_splitter", + "module": "spacy.pipeline.functions", + "function": "make_token_splitter" + }, + "trainable_lemmatizer": { + "name": "trainable_lemmatizer", + "module": "spacy.pipeline.edit_tree_lemmatizer", + "function": "make_edit_tree_lemmatizer" + } +} \ No newline at end of file diff --git a/spacy/tests/pipeline/test_pipe_methods.py b/spacy/tests/pipeline/test_pipe_methods.py index 4dd7bae16..9b9786f04 100644 --- a/spacy/tests/pipeline/test_pipe_methods.py +++ b/spacy/tests/pipeline/test_pipe_methods.py @@ -529,17 +529,6 @@ def test_pipe_label_data_no_labels(pipe): assert "labels" not in get_arg_names(initialize) -def test_warning_pipe_begin_training(): - with pytest.warns(UserWarning, match="begin_training"): - - class IncompatPipe(TrainablePipe): - def __init__(self): - ... - - def begin_training(*args, **kwargs): - ... - - def test_pipe_methods_initialize(): """Test that the [initialize] config reflects the components correctly.""" nlp = Language() diff --git a/spacy/tests/registry_contents.json b/spacy/tests/registry_contents.json new file mode 100644 index 000000000..1836d0328 --- /dev/null +++ b/spacy/tests/registry_contents.json @@ -0,0 +1,284 @@ +{ + "architectures": [ + "spacy-legacy.CharacterEmbed.v1", + "spacy-legacy.EntityLinker.v1", + "spacy-legacy.HashEmbedCNN.v1", + "spacy-legacy.MaxoutWindowEncoder.v1", + "spacy-legacy.MishWindowEncoder.v1", + "spacy-legacy.MultiHashEmbed.v1", + "spacy-legacy.Tagger.v1", + "spacy-legacy.TextCatBOW.v1", + "spacy-legacy.TextCatCNN.v1", + "spacy-legacy.TextCatEnsemble.v1", + "spacy-legacy.Tok2Vec.v1", + "spacy-legacy.TransitionBasedParser.v1", + "spacy.CharacterEmbed.v2", + "spacy.EntityLinker.v2", + "spacy.HashEmbedCNN.v2", + "spacy.MaxoutWindowEncoder.v2", + "spacy.MishWindowEncoder.v2", + "spacy.MultiHashEmbed.v2", + "spacy.PretrainCharacters.v1", + "spacy.PretrainVectors.v1", + "spacy.SpanCategorizer.v1", + "spacy.SpanFinder.v1", + "spacy.Tagger.v2", + "spacy.TextCatBOW.v2", + "spacy.TextCatBOW.v3", + "spacy.TextCatCNN.v2", + "spacy.TextCatEnsemble.v2", + "spacy.TextCatLowData.v1", + "spacy.TextCatParametricAttention.v1", + "spacy.TextCatReduce.v1", + "spacy.Tok2Vec.v2", + "spacy.Tok2VecListener.v1", + "spacy.TorchBiLSTMEncoder.v1", + "spacy.TransitionBasedParser.v2" + ], + "augmenters": [ + "spacy.combined_augmenter.v1", + "spacy.lower_case.v1", + "spacy.orth_variants.v1" + ], + "batchers": [ + "spacy.batch_by_padded.v1", + "spacy.batch_by_sequence.v1", + "spacy.batch_by_words.v1" + ], + "callbacks": [ + "spacy.copy_from_base_model.v1", + "spacy.models_and_pipes_with_nvtx_range.v1", + "spacy.models_with_nvtx_range.v1" + ], + "cli": [], + "datasets": [], + "displacy_colors": [], + "factories": [ + "attribute_ruler", + "beam_ner", + "beam_parser", + "doc_cleaner", + "entity_linker", + "entity_ruler", + "future_entity_ruler", + "lemmatizer", + "merge_entities", + "merge_noun_chunks", + "merge_subtokens", + "morphologizer", + "ner", + "parser", + "sentencizer", + "senter", + "span_finder", + "span_ruler", + "spancat", + "spancat_singlelabel", + "tagger", + "textcat", + "textcat_multilabel", + "tok2vec", + "token_splitter", + "trainable_lemmatizer" + ], + "initializers": [ + "glorot_normal_init.v1", + 
"glorot_uniform_init.v1", + "he_normal_init.v1", + "he_uniform_init.v1", + "lecun_normal_init.v1", + "lecun_uniform_init.v1", + "normal_init.v1", + "uniform_init.v1", + "zero_init.v1" + ], + "languages": [], + "layers": [ + "CauchySimilarity.v1", + "ClippedLinear.v1", + "Dish.v1", + "Dropout.v1", + "Embed.v1", + "Gelu.v1", + "HardSigmoid.v1", + "HardSwish.v1", + "HardSwishMobilenet.v1", + "HardTanh.v1", + "HashEmbed.v1", + "LSTM.v1", + "LayerNorm.v1", + "Linear.v1", + "Logistic.v1", + "MXNetWrapper.v1", + "Maxout.v1", + "Mish.v1", + "MultiSoftmax.v1", + "ParametricAttention.v1", + "ParametricAttention.v2", + "PyTorchLSTM.v1", + "PyTorchRNNWrapper.v1", + "PyTorchWrapper.v1", + "PyTorchWrapper.v2", + "PyTorchWrapper.v3", + "Relu.v1", + "ReluK.v1", + "Sigmoid.v1", + "Softmax.v1", + "Softmax.v2", + "SparseLinear.v1", + "SparseLinear.v2", + "Swish.v1", + "add.v1", + "bidirectional.v1", + "chain.v1", + "clone.v1", + "concatenate.v1", + "expand_window.v1", + "list2array.v1", + "list2padded.v1", + "list2ragged.v1", + "noop.v1", + "padded2list.v1", + "premap_ids.v1", + "ragged2list.v1", + "reduce_first.v1", + "reduce_last.v1", + "reduce_max.v1", + "reduce_mean.v1", + "reduce_sum.v1", + "remap_ids.v1", + "remap_ids.v2", + "residual.v1", + "resizable.v1", + "siamese.v1", + "sigmoid_activation.v1", + "softmax_activation.v1", + "spacy-legacy.StaticVectors.v1", + "spacy.CharEmbed.v1", + "spacy.FeatureExtractor.v1", + "spacy.LinearLogistic.v1", + "spacy.PrecomputableAffine.v1", + "spacy.StaticVectors.v2", + "spacy.TransitionModel.v1", + "spacy.extract_ngrams.v1", + "spacy.extract_spans.v1", + "spacy.mean_max_reducer.v1", + "strings2arrays.v1", + "tuplify.v1", + "uniqued.v1", + "with_array.v1", + "with_array2d.v1", + "with_cpu.v1", + "with_flatten.v1", + "with_flatten.v2", + "with_getitem.v1", + "with_list.v1", + "with_padded.v1", + "with_ragged.v1", + "with_reshape.v1" + ], + "lemmatizers": [], + "loggers": [ + "spacy-legacy.ConsoleLogger.v1", + "spacy-legacy.ConsoleLogger.v2", + "spacy-legacy.WandbLogger.v1", + "spacy.ChainLogger.v1", + "spacy.ClearMLLogger.v1", + "spacy.ClearMLLogger.v2", + "spacy.ConsoleLogger.v2", + "spacy.ConsoleLogger.v3", + "spacy.CupyLogger.v1", + "spacy.LookupLogger.v1", + "spacy.MLflowLogger.v1", + "spacy.MLflowLogger.v2", + "spacy.PyTorchLogger.v1", + "spacy.WandbLogger.v1", + "spacy.WandbLogger.v2", + "spacy.WandbLogger.v3", + "spacy.WandbLogger.v4", + "spacy.WandbLogger.v5" + ], + "lookups": [], + "losses": [ + "CategoricalCrossentropy.v1", + "CategoricalCrossentropy.v2", + "CategoricalCrossentropy.v3", + "CosineDistance.v1", + "L2Distance.v1", + "SequenceCategoricalCrossentropy.v1", + "SequenceCategoricalCrossentropy.v2", + "SequenceCategoricalCrossentropy.v3" + ], + "misc": [ + "spacy.CandidateBatchGenerator.v1", + "spacy.CandidateGenerator.v1", + "spacy.EmptyKB.v1", + "spacy.EmptyKB.v2", + "spacy.KBFromFile.v1", + "spacy.LookupsDataLoader.v1", + "spacy.first_longest_spans_filter.v1", + "spacy.levenshtein_compare.v1", + "spacy.ngram_range_suggester.v1", + "spacy.ngram_suggester.v1", + "spacy.preset_spans_suggester.v1", + "spacy.prioritize_existing_ents_filter.v1", + "spacy.prioritize_new_ents_filter.v1" + ], + "models": [], + "ops": [ + "CupyOps", + "MPSOps", + "NumpyOps" + ], + "optimizers": [ + "Adam.v1", + "RAdam.v1", + "SGD.v1" + ], + "readers": [ + "ml_datasets.cmu_movies.v1", + "ml_datasets.dbpedia.v1", + "ml_datasets.imdb_sentiment.v1", + "spacy.Corpus.v1", + "spacy.JsonlCorpus.v1", + "spacy.PlainTextCorpus.v1", + "spacy.read_labels.v1", + "srsly.read_json.v1", + 
"srsly.read_jsonl.v1", + "srsly.read_msgpack.v1", + "srsly.read_yaml.v1" + ], + "schedules": [ + "compounding.v1", + "constant.v1", + "constant_then.v1", + "cyclic_triangular.v1", + "decaying.v1", + "slanted_triangular.v1", + "warmup_linear.v1" + ], + "scorers": [ + "spacy-legacy.textcat_multilabel_scorer.v1", + "spacy-legacy.textcat_scorer.v1", + "spacy.attribute_ruler_scorer.v1", + "spacy.entity_linker_scorer.v1", + "spacy.entity_ruler_scorer.v1", + "spacy.lemmatizer_scorer.v1", + "spacy.morphologizer_scorer.v1", + "spacy.ner_scorer.v1", + "spacy.overlapping_labeled_spans_scorer.v1", + "spacy.parser_scorer.v1", + "spacy.senter_scorer.v1", + "spacy.span_finder_scorer.v1", + "spacy.spancat_scorer.v1", + "spacy.tagger_scorer.v1", + "spacy.textcat_multilabel_scorer.v2", + "spacy.textcat_scorer.v2" + ], + "tokenizers": [ + "spacy.Tokenizer.v1" + ], + "vectors": [ + "spacy.Vectors.v1" + ] +} diff --git a/spacy/tests/serialize/test_resource_warning.py b/spacy/tests/serialize/test_resource_warning.py index ab6e6e9ee..4cf0ac558 100644 --- a/spacy/tests/serialize/test_resource_warning.py +++ b/spacy/tests/serialize/test_resource_warning.py @@ -87,7 +87,7 @@ def entity_linker(): objects_to_test = ( - [nlp(), vectors(), custom_pipe(), tagger(), entity_linker()], + [nlp, vectors, custom_pipe, tagger, entity_linker], ["nlp", "vectors", "custom_pipe", "tagger", "entity_linker"], ) @@ -101,8 +101,9 @@ def write_obj_and_catch_warnings(obj): return list(filter(lambda x: isinstance(x, ResourceWarning), warnings_list)) -@pytest.mark.parametrize("obj", objects_to_test[0], ids=objects_to_test[1]) -def test_to_disk_resource_warning(obj): +@pytest.mark.parametrize("obj_factory", objects_to_test[0], ids=objects_to_test[1]) +def test_to_disk_resource_warning(obj_factory): + obj = obj_factory() warnings_list = write_obj_and_catch_warnings(obj) assert len(warnings_list) == 0 @@ -139,9 +140,11 @@ def test_save_and_load_knowledge_base(): class TestToDiskResourceWarningUnittest(TestCase): def test_resource_warning(self): - scenarios = zip(*objects_to_test) + items = [x() for x in objects_to_test[0]] + names = objects_to_test[1] + scenarios = zip(items, names) - for scenario in scenarios: - with self.subTest(msg=scenario[1]): - warnings_list = write_obj_and_catch_warnings(scenario[0]) + for item, name in scenarios: + with self.subTest(msg=name): + warnings_list = write_obj_and_catch_warnings(item) self.assertEqual(len(warnings_list), 0) diff --git a/spacy/tests/test_factory_imports.py b/spacy/tests/test_factory_imports.py new file mode 100644 index 000000000..a975af0bb --- /dev/null +++ b/spacy/tests/test_factory_imports.py @@ -0,0 +1,85 @@ +# coding: utf-8 +"""Test factory import compatibility from original and new locations.""" + +import importlib + +import pytest + + +@pytest.mark.parametrize( + "factory_name,original_module,compat_module", + [ + ("make_tagger", "spacy.pipeline.factories", "spacy.pipeline.tagger"), + ("make_sentencizer", "spacy.pipeline.factories", "spacy.pipeline.sentencizer"), + ("make_ner", "spacy.pipeline.factories", "spacy.pipeline.ner"), + ("make_parser", "spacy.pipeline.factories", "spacy.pipeline.dep_parser"), + ("make_tok2vec", "spacy.pipeline.factories", "spacy.pipeline.tok2vec"), + ("make_spancat", "spacy.pipeline.factories", "spacy.pipeline.spancat"), + ( + "make_spancat_singlelabel", + "spacy.pipeline.factories", + "spacy.pipeline.spancat", + ), + ("make_lemmatizer", "spacy.pipeline.factories", "spacy.pipeline.lemmatizer"), + ("make_entity_ruler", "spacy.pipeline.factories", 
"spacy.pipeline.entityruler"), + ("make_span_ruler", "spacy.pipeline.factories", "spacy.pipeline.span_ruler"), + ( + "make_edit_tree_lemmatizer", + "spacy.pipeline.factories", + "spacy.pipeline.edit_tree_lemmatizer", + ), + ( + "make_attribute_ruler", + "spacy.pipeline.factories", + "spacy.pipeline.attributeruler", + ), + ( + "make_entity_linker", + "spacy.pipeline.factories", + "spacy.pipeline.entity_linker", + ), + ("make_textcat", "spacy.pipeline.factories", "spacy.pipeline.textcat"), + ("make_token_splitter", "spacy.pipeline.factories", "spacy.pipeline.functions"), + ("make_doc_cleaner", "spacy.pipeline.factories", "spacy.pipeline.functions"), + ( + "make_morphologizer", + "spacy.pipeline.factories", + "spacy.pipeline.morphologizer", + ), + ("make_senter", "spacy.pipeline.factories", "spacy.pipeline.senter"), + ("make_span_finder", "spacy.pipeline.factories", "spacy.pipeline.span_finder"), + ( + "make_multilabel_textcat", + "spacy.pipeline.factories", + "spacy.pipeline.textcat_multilabel", + ), + ("make_beam_ner", "spacy.pipeline.factories", "spacy.pipeline.ner"), + ("make_beam_parser", "spacy.pipeline.factories", "spacy.pipeline.dep_parser"), + ("make_nn_labeller", "spacy.pipeline.factories", "spacy.pipeline.multitask"), + # This one's special because the function was named make_span_ruler, so + # the name in the registrations.py doesn't match the name we make the import hook + # point to. We could make a test just for this but shrug + # ("make_future_entity_ruler", "spacy.pipeline.factories", "spacy.pipeline.span_ruler"), + ], +) +def test_factory_import_compatibility(factory_name, original_module, compat_module): + """Test that factory functions can be imported from both original and compatibility locations.""" + # Import from the original module (registrations.py) + original_module_obj = importlib.import_module(original_module) + original_factory = getattr(original_module_obj, factory_name) + assert ( + original_factory is not None + ), f"Could not import {factory_name} from {original_module}" + + # Import from the compatibility module (component file) + compat_module_obj = importlib.import_module(compat_module) + compat_factory = getattr(compat_module_obj, factory_name) + assert ( + compat_factory is not None + ), f"Could not import {factory_name} from {compat_module}" + + # Test that they're the same function (identity) + assert original_factory is compat_factory, ( + f"Factory {factory_name} imported from {original_module} is not the same object " + f"as the one imported from {compat_module}" + ) diff --git a/spacy/tests/test_factory_registrations.py b/spacy/tests/test_factory_registrations.py new file mode 100644 index 000000000..8e93f54f0 --- /dev/null +++ b/spacy/tests/test_factory_registrations.py @@ -0,0 +1,97 @@ +import inspect +import json +from pathlib import Path + +import pytest + +from spacy.language import Language +from spacy.util import registry + +# Path to the reference factory registrations, relative to this file +REFERENCE_FILE = Path(__file__).parent / "factory_registrations.json" + +# Monkey patch the util.is_same_func to handle Cython functions +import inspect + +from spacy import util + +original_is_same_func = util.is_same_func + + +def patched_is_same_func(func1, func2): + # Handle Cython functions + try: + return original_is_same_func(func1, func2) + except TypeError: + # For Cython functions, just compare the string representation + return str(func1) == str(func2) + + +util.is_same_func = patched_is_same_func + + +@pytest.fixture +def 
reference_factory_registrations(): + """Load reference factory registrations from JSON file""" + if not REFERENCE_FILE.exists(): + pytest.fail( + f"Reference file {REFERENCE_FILE} not found. Run export_factory_registrations.py first." + ) + + with REFERENCE_FILE.open("r") as f: + return json.load(f) + + +def test_factory_registrations_preserved(reference_factory_registrations): + """Test that all factory registrations from the reference file are still present.""" + # Ensure the registry is populated + registry.ensure_populated() + + # Get all factory registrations + all_factories = registry.factories.get_all() + + # Initialize our data structure to store current factory registrations + current_registrations = {} + + # Process factory registrations + for name, func in all_factories.items(): + # Store information about each factory + try: + module_name = func.__module__ + except (AttributeError, TypeError): + # For Cython functions, just use a placeholder + module_name = str(func).split()[1].split(".")[0] + + try: + func_name = func.__qualname__ + except (AttributeError, TypeError): + # For Cython functions, use the function's name + func_name = ( + func.__name__ + if hasattr(func, "__name__") + else str(func).split()[1].split(".")[-1] + ) + + current_registrations[name] = { + "name": name, + "module": module_name, + "function": func_name, + } + + # Check for missing registrations + missing_registrations = set(reference_factory_registrations.keys()) - set( + current_registrations.keys() + ) + assert ( + not missing_registrations + ), f"Missing factory registrations: {', '.join(sorted(missing_registrations))}" + + # Check for new registrations (not an error, but informative) + new_registrations = set(current_registrations.keys()) - set( + reference_factory_registrations.keys() + ) + if new_registrations: + # This is not an error, just informative + print( + f"New factory registrations found: {', '.join(sorted(new_registrations))}" + ) diff --git a/spacy/tests/test_registry_population.py b/spacy/tests/test_registry_population.py new file mode 100644 index 000000000..592e74dd2 --- /dev/null +++ b/spacy/tests/test_registry_population.py @@ -0,0 +1,55 @@ +import json +import os +from pathlib import Path + +import pytest + +from spacy.util import registry + +# Path to the reference registry contents, relative to this file +REFERENCE_FILE = Path(__file__).parent / "registry_contents.json" + + +@pytest.fixture +def reference_registry(): + """Load reference registry contents from JSON file""" + if not REFERENCE_FILE.exists(): + pytest.fail(f"Reference file {REFERENCE_FILE} not found.") + + with REFERENCE_FILE.open("r") as f: + return json.load(f) + + +def test_registry_types(reference_registry): + """Test that all registry types match the reference""" + # Get current registry types + current_registry_types = set(registry.get_registry_names()) + expected_registry_types = set(reference_registry.keys()) + + # Check for missing registry types + missing_types = expected_registry_types - current_registry_types + assert not missing_types, f"Missing registry types: {', '.join(missing_types)}" + + +def test_registry_entries(reference_registry): + """Test that all registry entries are present""" + # Check each registry's entries + for registry_name, expected_entries in reference_registry.items(): + # Skip if this registry type doesn't exist + if not hasattr(registry, registry_name): + pytest.fail(f"Registry '{registry_name}' does not exist.") + + # Get current entries + reg = getattr(registry, registry_name) + 
current_entries = sorted(list(reg.get_all().keys())) + + # Compare entries + expected_set = set(expected_entries) + current_set = set(current_entries) + + # Check for missing entries - these would indicate our new registry population + # mechanism is missing something + missing_entries = expected_set - current_set + assert ( + not missing_entries + ), f"Registry '{registry_name}' missing entries: {', '.join(missing_entries)}" diff --git a/spacy/tests/training/test_readers.py b/spacy/tests/training/test_readers.py index 22cf75272..87b5343ed 100644 --- a/spacy/tests/training/test_readers.py +++ b/spacy/tests/training/test_readers.py @@ -101,7 +101,11 @@ def test_cat_readers(reader, additional_config): nlp = load_model_from_config(config, auto_fill=True) T = registry.resolve(nlp.config["training"], schema=ConfigSchemaTraining) dot_names = [T["train_corpus"], T["dev_corpus"]] + print("T", T) + print("dot names", dot_names) train_corpus, dev_corpus = resolve_dot_names(nlp.config, dot_names) + data = list(train_corpus(nlp)) + print(len(data)) optimizer = T["optimizer"] # simulate a training loop nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer) diff --git a/spacy/tokenizer.pyx b/spacy/tokenizer.pyx index 6ca170dd4..77718a75b 100644 --- a/spacy/tokenizer.pyx +++ b/spacy/tokenizer.pyx @@ -867,11 +867,11 @@ cdef extern from "" namespace "std" nogil: bint (*)(SpanC, SpanC)) -cdef bint len_start_cmp(SpanC a, SpanC b) nogil: +cdef bint len_start_cmp(SpanC a, SpanC b) noexcept nogil: if a.end - a.start == b.end - b.start: return b.start < a.start return a.end - a.start < b.end - b.start -cdef bint start_cmp(SpanC a, SpanC b) nogil: +cdef bint start_cmp(SpanC a, SpanC b) noexcept nogil: return a.start < b.start diff --git a/spacy/tokens/doc.pxd b/spacy/tokens/doc.pxd index d9719609c..454166056 100644 --- a/spacy/tokens/doc.pxd +++ b/spacy/tokens/doc.pxd @@ -7,8 +7,8 @@ from ..typedefs cimport attr_t from ..vocab cimport Vocab -cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil -cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name) nogil +cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) noexcept nogil +cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name) noexcept nogil ctypedef const LexemeC* const_Lexeme_ptr diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index 4d6249569..0a90a67d1 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -71,7 +71,7 @@ cdef int bounds_check(int i, int length, int padding) except -1: raise IndexError(Errors.E026.format(i=i, length=length)) -cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil: +cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) noexcept nogil: if feat_name == LEMMA: return token.lemma elif feat_name == NORM: @@ -106,7 +106,7 @@ cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil: return Lexeme.get_struct_attr(token.lex, feat_name) -cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name) nogil: +cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name) noexcept nogil: if feat_name == SENT_START: if token.sent_start == 1: return True diff --git a/spacy/tokens/token.pxd b/spacy/tokens/token.pxd index f4e4611df..3252fcdeb 100644 --- a/spacy/tokens/token.pxd +++ b/spacy/tokens/token.pxd @@ -33,7 +33,7 @@ cdef class Token: cpdef bint check_flag(self, attr_id_t flag_id) except -1 @staticmethod - cdef inline attr_t 
diff --git a/spacy/tokens/token.pxd b/spacy/tokens/token.pxd
index f4e4611df..3252fcdeb 100644
--- a/spacy/tokens/token.pxd
+++ b/spacy/tokens/token.pxd
@@ -33,7 +33,7 @@ cdef class Token:
     cpdef bint check_flag(self, attr_id_t flag_id) except -1
 
     @staticmethod
-    cdef inline attr_t get_struct_attr(const TokenC* token, attr_id_t feat_name) nogil:
+    cdef inline attr_t get_struct_attr(const TokenC* token, attr_id_t feat_name) noexcept nogil:
         if feat_name < (sizeof(flags_t) * 8):
             return Lexeme.c_check_flag(token.lex, feat_name)
         elif feat_name == LEMMA:
@@ -70,7 +70,7 @@ cdef class Token:
 
     @staticmethod
     cdef inline attr_t set_struct_attr(TokenC* token, attr_id_t feat_name,
-                                       attr_t value) nogil:
+                                       attr_t value) noexcept nogil:
         if feat_name == LEMMA:
             token.lemma = value
         elif feat_name == NORM:
@@ -99,9 +99,9 @@ cdef class Token:
             token.sent_start = value
 
     @staticmethod
-    cdef inline int missing_dep(const TokenC* token) nogil:
+    cdef inline int missing_dep(const TokenC* token) noexcept nogil:
         return token.dep == MISSING_DEP
 
     @staticmethod
-    cdef inline int missing_head(const TokenC* token) nogil:
+    cdef inline int missing_head(const TokenC* token) noexcept nogil:
         return Token.missing_dep(token)
diff --git a/spacy/training/augment.py b/spacy/training/augment.py
index 1ebd3313c..da5ae3d08 100644
--- a/spacy/training/augment.py
+++ b/spacy/training/augment.py
@@ -11,7 +11,6 @@ if TYPE_CHECKING:
     from ..language import Language  # noqa: F401
 
 
-@registry.augmenters("spacy.combined_augmenter.v1")
 def create_combined_augmenter(
     lower_level: float,
     orth_level: float,
@@ -84,7 +83,6 @@ def combined_augmenter(
         yield example
 
 
-@registry.augmenters("spacy.orth_variants.v1")
 def create_orth_variants_augmenter(
     level: float, lower: float, orth_variants: Dict[str, List[Dict]]
 ) -> Callable[["Language", Example], Iterator[Example]]:
@@ -102,7 +100,6 @@ def create_orth_variants_augmenter(
     )
 
 
-@registry.augmenters("spacy.lower_case.v1")
 def create_lower_casing_augmenter(
     level: float,
 ) -> Callable[["Language", Example], Iterator[Example]]:
diff --git a/spacy/training/batchers.py b/spacy/training/batchers.py
index 050c3351b..4a1dfa945 100644
--- a/spacy/training/batchers.py
+++ b/spacy/training/batchers.py
@@ -19,7 +19,6 @@ ItemT = TypeVar("ItemT")
 BatcherT = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
 
 
-@registry.batchers("spacy.batch_by_padded.v1")
 def configure_minibatch_by_padded_size(
     *,
     size: Sizing,
@@ -54,7 +53,6 @@ def configure_minibatch_by_padded_size(
     )
 
 
-@registry.batchers("spacy.batch_by_words.v1")
 def configure_minibatch_by_words(
     *,
     size: Sizing,
@@ -82,7 +80,6 @@ def configure_minibatch_by_words(
     )
 
 
-@registry.batchers("spacy.batch_by_sequence.v1")
 def configure_minibatch(
     size: Sizing, get_length: Optional[Callable[[ItemT], int]] = None
 ) -> BatcherT:
diff --git a/spacy/training/callbacks.py b/spacy/training/callbacks.py
index 21c3d56a1..714deea6d 100644
--- a/spacy/training/callbacks.py
+++ b/spacy/training/callbacks.py
@@ -7,7 +7,6 @@ if TYPE_CHECKING:
     from ..language import Language
 
 
-@registry.callbacks("spacy.copy_from_base_model.v1")
 def create_copy_from_base_model(
     tokenizer: Optional[str] = None,
     vocab: Optional[str] = None,
diff --git a/spacy/training/loggers.py b/spacy/training/loggers.py
index 1ec0b7b25..488ca4a71 100644
--- a/spacy/training/loggers.py
+++ b/spacy/training/loggers.py
@@ -29,7 +29,6 @@ def setup_table(
 
 # We cannot rename this method as it's directly imported
 # and used by external packages such as spacy-loggers.
-@registry.loggers("spacy.ConsoleLogger.v2")
 def console_logger(
     progress_bar: bool = False,
     console_output: bool = True,
@@ -47,7 +46,6 @@ def console_logger(
     )
 
 
-@registry.loggers("spacy.ConsoleLogger.v3")
 def console_logger_v3(
     progress_bar: Optional[str] = None,
     console_output: bool = True,
diff --git a/spacy/util.py b/spacy/util.py
index c127be03c..f1e68696b 100644
--- a/spacy/util.py
+++ b/spacy/util.py
@@ -132,9 +132,18 @@ class registry(thinc.registry):
     models = catalogue.create("spacy", "models", entry_points=True)
     cli = catalogue.create("spacy", "cli", entry_points=True)
 
+    @classmethod
+    def ensure_populated(cls) -> None:
+        """Ensure the registry is populated with all necessary components."""
+        from .registrations import REGISTRY_POPULATED, populate_registry
+
+        if not REGISTRY_POPULATED:
+            populate_registry()
+
     @classmethod
     def get_registry_names(cls) -> List[str]:
         """List all available registries."""
+        cls.ensure_populated()
         names = []
         for name, value in inspect.getmembers(cls):
             if not name.startswith("_") and isinstance(value, Registry):
@@ -144,6 +153,7 @@ class registry(thinc.registry):
     @classmethod
     def get(cls, registry_name: str, func_name: str) -> Callable:
         """Get a registered function from the registry."""
+        cls.ensure_populated()
         # We're overwriting this classmethod so we're able to provide more
         # specific error messages and implement a fallback to spacy-legacy.
         if not hasattr(cls, registry_name):
@@ -179,6 +189,7 @@ class registry(thinc.registry):
         func_name (str): Name of the registered function.
         RETURNS (Dict[str, Optional[Union[str, int]]]): The function info.
         """
+        cls.ensure_populated()
         # We're overwriting this classmethod so we're able to provide more
         # specific error messages and implement a fallback to spacy-legacy.
        if not hasattr(cls, registry_name):
@@ -205,6 +216,7 @@ class registry(thinc.registry):
     @classmethod
     def has(cls, registry_name: str, func_name: str) -> bool:
         """Check whether a function is available in a registry."""
+        cls.ensure_populated()
         if not hasattr(cls, registry_name):
             return False
         reg = getattr(cls, registry_name)
@@ -1323,7 +1335,6 @@ def filter_chain_spans(*spans: Iterable["Span"]) -> List["Span"]:
     return filter_spans(itertools.chain(*spans))
 
 
-@registry.misc("spacy.first_longest_spans_filter.v1")
 def make_first_longest_spans_filter():
     return filter_chain_spans
diff --git a/spacy/vectors.pyx b/spacy/vectors.pyx
index 6ff99bb59..d1fb9a747 100644
--- a/spacy/vectors.pyx
+++ b/spacy/vectors.pyx
@@ -177,7 +177,7 @@ cdef class Vectors(BaseVectors):
         self.hash_seed = hash_seed
         self.bow = bow
         self.eow = eow
-        if isinstance(attr, (int, long)):
+        if isinstance(attr, int):
             self.attr = attr
         else:
             attr = attr.upper()