Merge branch 'add/exclusive-spancat' of github.com:ljvmiranda921/spaCy into exclusive-spancat

This commit is contained in:
kadarakos 2023-03-08 11:36:45 +00:00
commit 95206efe95
21 changed files with 199 additions and 69 deletions

View File

@ -59,6 +59,11 @@ steps:
displayName: 'Test download CLI' displayName: 'Test download CLI'
condition: eq(variables['python_version'], '3.8') condition: eq(variables['python_version'], '3.8')
- script: |
python -W error -m spacy info ca_core_news_sm | grep -q download_url
displayName: 'Test download_url in info CLI'
condition: eq(variables['python_version'], '3.8')
- script: | - script: |
python -W error -c "import ca_core_news_sm; nlp = ca_core_news_sm.load(); doc=nlp('test')" python -W error -c "import ca_core_news_sm; nlp = ca_core_news_sm.load(); doc=nlp('test')"
displayName: 'Test no warnings on load (#11713)' displayName: 'Test no warnings on load (#11713)'

View File

@ -5,7 +5,7 @@ requires = [
"cymem>=2.0.2,<2.1.0", "cymem>=2.0.2,<2.1.0",
"preshed>=3.0.2,<3.1.0", "preshed>=3.0.2,<3.1.0",
"murmurhash>=0.28.0,<1.1.0", "murmurhash>=0.28.0,<1.1.0",
"thinc>=8.1.6,<8.2.0", "thinc>=8.1.8,<8.2.0",
"numpy>=1.15.0", "numpy>=1.15.0",
] ]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"

View File

@ -3,7 +3,7 @@ spacy-legacy>=3.0.11,<3.1.0
spacy-loggers>=1.0.0,<2.0.0 spacy-loggers>=1.0.0,<2.0.0
cymem>=2.0.2,<2.1.0 cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0 preshed>=3.0.2,<3.1.0
thinc>=8.1.6,<8.2.0 thinc>=8.1.8,<8.2.0
ml_datasets>=0.2.0,<0.3.0 ml_datasets>=0.2.0,<0.3.0
murmurhash>=0.28.0,<1.1.0 murmurhash>=0.28.0,<1.1.0
wasabi>=0.9.1,<1.2.0 wasabi>=0.9.1,<1.2.0

View File

@ -39,7 +39,7 @@ setup_requires =
cymem>=2.0.2,<2.1.0 cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0 preshed>=3.0.2,<3.1.0
murmurhash>=0.28.0,<1.1.0 murmurhash>=0.28.0,<1.1.0
thinc>=8.1.6,<8.2.0 thinc>=8.1.8,<8.2.0
install_requires = install_requires =
# Our libraries # Our libraries
spacy-legacy>=3.0.11,<3.1.0 spacy-legacy>=3.0.11,<3.1.0
@ -47,7 +47,7 @@ install_requires =
murmurhash>=0.28.0,<1.1.0 murmurhash>=0.28.0,<1.1.0
cymem>=2.0.2,<2.1.0 cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0 preshed>=3.0.2,<3.1.0
thinc>=8.1.6,<8.2.0 thinc>=8.1.8,<8.2.0
wasabi>=0.9.1,<1.2.0 wasabi>=0.9.1,<1.2.0
srsly>=2.4.3,<3.0.0 srsly>=2.4.3,<3.0.0
catalogue>=2.0.6,<2.1.0 catalogue>=2.0.6,<2.1.0

View File

@ -1,6 +1,5 @@
from typing import Optional, Dict, Any, Union, List from typing import Optional, Dict, Any, Union, List
import platform import platform
import pkg_resources
import json import json
from pathlib import Path from pathlib import Path
from wasabi import Printer, MarkdownRenderer from wasabi import Printer, MarkdownRenderer
@ -10,6 +9,7 @@ from ._util import app, Arg, Opt, string_to_list
from .download import get_model_filename, get_latest_version from .download import get_model_filename, get_latest_version
from .. import util from .. import util
from .. import about from .. import about
from ..compat import importlib_metadata
@app.command("info") @app.command("info")
@ -137,14 +137,13 @@ def info_installed_model_url(model: str) -> Optional[str]:
dist-info available. dist-info available.
""" """
try: try:
dist = pkg_resources.get_distribution(model) dist = importlib_metadata.distribution(model)
data = json.loads(dist.get_metadata("direct_url.json")) text = dist.read_text("direct_url.json")
if isinstance(text, str):
data = json.loads(text)
return data["url"] return data["url"]
except pkg_resources.DistributionNotFound:
# no such package
return None
except Exception: except Exception:
# something else, like no file or invalid JSON pass
return None return None

View File

@ -2,7 +2,6 @@ from typing import Optional, List, Dict, Sequence, Any, Iterable, Tuple
import os.path import os.path
from pathlib import Path from pathlib import Path
import pkg_resources
from wasabi import msg from wasabi import msg
from wasabi.util import locale_escape from wasabi.util import locale_escape
import sys import sys
@ -331,6 +330,7 @@ def _check_requirements(requirements: List[str]) -> Tuple[bool, bool]:
RETURNS (Tuple[bool, bool]): Whether (1) any packages couldn't be imported, (2) any packages with version conflicts RETURNS (Tuple[bool, bool]): Whether (1) any packages couldn't be imported, (2) any packages with version conflicts
exist. exist.
""" """
import pkg_resources
failed_pkgs_msgs: List[str] = [] failed_pkgs_msgs: List[str] = []
conflicting_pkgs_msgs: List[str] = [] conflicting_pkgs_msgs: List[str] = []

View File

@ -24,8 +24,11 @@ gpu_allocator = null
lang = "{{ lang }}" lang = "{{ lang }}"
{%- set has_textcat = ("textcat" in components or "textcat_multilabel" in components) -%} {%- set has_textcat = ("textcat" in components or "textcat_multilabel" in components) -%}
{%- set with_accuracy = optimize == "accuracy" -%} {%- set with_accuracy = optimize == "accuracy" -%}
{%- set has_accurate_textcat = has_textcat and with_accuracy -%} {# The BOW textcat doesn't need a source of features, so it can omit the
{%- if ("tagger" in components or "morphologizer" in components or "parser" in components or "ner" in components or "spancat" in components or "trainable_lemmatizer" in components or "entity_linker" in components or has_accurate_textcat) -%} tok2vec/transformer. #}
{%- set with_accuracy_or_transformer = (use_transformer or with_accuracy) -%}
{%- set textcat_needs_features = has_textcat and with_accuracy_or_transformer -%}
{%- if ("tagger" in components or "morphologizer" in components or "parser" in components or "ner" in components or "spancat" in components or "trainable_lemmatizer" in components or "entity_linker" in components or textcat_needs_features) -%}
{%- set full_pipeline = ["transformer" if use_transformer else "tok2vec"] + components -%} {%- set full_pipeline = ["transformer" if use_transformer else "tok2vec"] + components -%}
{%- else -%} {%- else -%}
{%- set full_pipeline = components -%} {%- set full_pipeline = components -%}
@ -221,10 +224,16 @@ no_output_layer = false
{% else -%} {% else -%}
[components.textcat.model] [components.textcat.model]
@architectures = "spacy.TextCatBOW.v2" @architectures = "spacy.TextCatCNN.v2"
exclusive_classes = true exclusive_classes = true
ngram_size = 1 nO = null
no_output_layer = false
[components.textcat.model.tok2vec]
@architectures = "spacy-transformers.TransformerListener.v1"
grad_factor = 1.0
[components.textcat.model.tok2vec.pooling]
@layers = "reduce_mean.v1"
{%- endif %} {%- endif %}
{%- endif %} {%- endif %}
@ -252,10 +261,16 @@ no_output_layer = false
{% else -%} {% else -%}
[components.textcat_multilabel.model] [components.textcat_multilabel.model]
@architectures = "spacy.TextCatBOW.v2" @architectures = "spacy.TextCatCNN.v2"
exclusive_classes = false exclusive_classes = false
ngram_size = 1 nO = null
no_output_layer = false
[components.textcat_multilabel.model.tok2vec]
@architectures = "spacy-transformers.TransformerListener.v1"
grad_factor = 1.0
[components.textcat_multilabel.model.tok2vec.pooling]
@layers = "reduce_mean.v1"
{%- endif %} {%- endif %}
{%- endif %} {%- endif %}

View File

@ -549,6 +549,8 @@ class Errors(metaclass=ErrorsWithCodes):
"during training, make sure to include it in 'annotating components'") "during training, make sure to include it in 'annotating components'")
# New errors added in v3.x # New errors added in v3.x
E850 = ("The PretrainVectors objective currently only supports default "
"vectors, not {mode} vectors.")
E851 = ("The 'textcat' component labels should only have values of 0 or 1, " E851 = ("The 'textcat' component labels should only have values of 0 or 1, "
"but found value of '{val}'.") "but found value of '{val}'.")
E852 = ("The tar file pulled from the remote attempted an unsafe path " E852 = ("The tar file pulled from the remote attempted an unsafe path "

View File

@ -25,7 +25,8 @@ class Lexeme:
def orth_(self) -> str: ... def orth_(self) -> str: ...
@property @property
def text(self) -> str: ... def text(self) -> str: ...
lower: str orth: int
lower: int
norm: int norm: int
shape: int shape: int
prefix: int prefix: int

View File

@ -199,7 +199,7 @@ cdef class Lexeme:
return self.orth_ return self.orth_
property lower: property lower:
"""RETURNS (str): Lowercase form of the lexeme.""" """RETURNS (uint64): Lowercase form of the lexeme."""
def __get__(self): def __get__(self):
return self.c.lower return self.c.lower

View File

@ -89,6 +89,14 @@ def load_kb(
return kb_from_file return kb_from_file
@registry.misc("spacy.EmptyKB.v2")
def empty_kb_for_config() -> Callable[[Vocab, int], KnowledgeBase]:
    """Provide a factory that builds an empty in-memory knowledge base.

    RETURNS (Callable[[Vocab, int], KnowledgeBase]): Function taking a vocab
        and an entity vector length and returning a new InMemoryLookupKB
        constructed from those two values.
    """

    def _build_empty_kb(vocab: Vocab, entity_vector_length: int):
        kb = InMemoryLookupKB(
            vocab=vocab,
            entity_vector_length=entity_vector_length,
        )
        return kb

    return _build_empty_kb
@registry.misc("spacy.EmptyKB.v1") @registry.misc("spacy.EmptyKB.v1")
def empty_kb( def empty_kb(
entity_vector_length: int, entity_vector_length: int,

View File

@ -8,6 +8,7 @@ from thinc.loss import Loss
from ...util import registry, OOV_RANK from ...util import registry, OOV_RANK
from ...errors import Errors from ...errors import Errors
from ...attrs import ID from ...attrs import ID
from ...vectors import Mode as VectorsMode
import numpy import numpy
from functools import partial from functools import partial
@ -23,6 +24,8 @@ def create_pretrain_vectors(
maxout_pieces: int, hidden_size: int, loss: str maxout_pieces: int, hidden_size: int, loss: str
) -> Callable[["Vocab", Model], Model]: ) -> Callable[["Vocab", Model], Model]:
def create_vectors_objective(vocab: "Vocab", tok2vec: Model) -> Model: def create_vectors_objective(vocab: "Vocab", tok2vec: Model) -> Model:
if vocab.vectors.mode != VectorsMode.default:
raise ValueError(Errors.E850.format(mode=vocab.vectors.mode))
if vocab.vectors.shape[1] == 0: if vocab.vectors.shape[1] == 0:
raise ValueError(Errors.E875) raise ValueError(Errors.E875)
model = build_cloze_multi_task_model( model = build_cloze_multi_task_model(

View File

@ -54,6 +54,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
"entity_vector_length": 64, "entity_vector_length": 64,
"get_candidates": {"@misc": "spacy.CandidateGenerator.v1"}, "get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
"get_candidates_batch": {"@misc": "spacy.CandidateBatchGenerator.v1"}, "get_candidates_batch": {"@misc": "spacy.CandidateBatchGenerator.v1"},
"generate_empty_kb": {"@misc": "spacy.EmptyKB.v2"},
"overwrite": True, "overwrite": True,
"scorer": {"@scorers": "spacy.entity_linker_scorer.v1"}, "scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
"use_gold_ents": True, "use_gold_ents": True,
@ -80,6 +81,7 @@ def make_entity_linker(
get_candidates_batch: Callable[ get_candidates_batch: Callable[
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
], ],
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
overwrite: bool, overwrite: bool,
scorer: Optional[Callable], scorer: Optional[Callable],
use_gold_ents: bool, use_gold_ents: bool,
@ -101,6 +103,7 @@ def make_entity_linker(
get_candidates_batch ( get_candidates_batch (
Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]] Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]]
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
scorer (Optional[Callable]): The scoring method. scorer (Optional[Callable]): The scoring method.
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
component must provide entity annotations. component must provide entity annotations.
@ -135,6 +138,7 @@ def make_entity_linker(
entity_vector_length=entity_vector_length, entity_vector_length=entity_vector_length,
get_candidates=get_candidates, get_candidates=get_candidates,
get_candidates_batch=get_candidates_batch, get_candidates_batch=get_candidates_batch,
generate_empty_kb=generate_empty_kb,
overwrite=overwrite, overwrite=overwrite,
scorer=scorer, scorer=scorer,
use_gold_ents=use_gold_ents, use_gold_ents=use_gold_ents,
@ -175,6 +179,7 @@ class EntityLinker(TrainablePipe):
get_candidates_batch: Callable[ get_candidates_batch: Callable[
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
], ],
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
overwrite: bool = BACKWARD_OVERWRITE, overwrite: bool = BACKWARD_OVERWRITE,
scorer: Optional[Callable] = entity_linker_score, scorer: Optional[Callable] = entity_linker_score,
use_gold_ents: bool, use_gold_ents: bool,
@ -198,6 +203,7 @@ class EntityLinker(TrainablePipe):
Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]] Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]]
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links. scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links.
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
component must provide entity annotations. component must provide entity annotations.
@ -220,6 +226,7 @@ class EntityLinker(TrainablePipe):
self.model = model self.model = model
self.name = name self.name = name
self.labels_discard = list(labels_discard) self.labels_discard = list(labels_discard)
# how many neighbour sentences to take into account
self.n_sents = n_sents self.n_sents = n_sents
self.incl_prior = incl_prior self.incl_prior = incl_prior
self.incl_context = incl_context self.incl_context = incl_context
@ -227,9 +234,7 @@ class EntityLinker(TrainablePipe):
self.get_candidates_batch = get_candidates_batch self.get_candidates_batch = get_candidates_batch
self.cfg: Dict[str, Any] = {"overwrite": overwrite} self.cfg: Dict[str, Any] = {"overwrite": overwrite}
self.distance = CosineDistance(normalize=False) self.distance = CosineDistance(normalize=False)
# how many neighbour sentences to take into account self.kb = generate_empty_kb(self.vocab, entity_vector_length)
# create an empty KB by default
self.kb = empty_kb(entity_vector_length)(self.vocab)
self.scorer = scorer self.scorer = scorer
self.use_gold_ents = use_gold_ents self.use_gold_ents = use_gold_ents
self.candidates_batch_size = candidates_batch_size self.candidates_batch_size = candidates_batch_size

View File

@ -9,6 +9,8 @@ from spacy.lang.en import English
from spacy.lang.it import Italian from spacy.lang.it import Italian
from spacy.language import Language from spacy.language import Language
from spacy.lookups import Lookups from spacy.lookups import Lookups
from spacy.pipeline import EntityRecognizer
from spacy.pipeline.ner import DEFAULT_NER_MODEL
from spacy.pipeline._parser_internals.ner import BiluoPushDown from spacy.pipeline._parser_internals.ner import BiluoPushDown
from spacy.training import Example, iob_to_biluo, split_bilu_label from spacy.training import Example, iob_to_biluo, split_bilu_label
from spacy.tokens import Doc, Span from spacy.tokens import Doc, Span
@ -16,8 +18,6 @@ from spacy.vocab import Vocab
import logging import logging
from ..util import make_tempdir from ..util import make_tempdir
from ...pipeline import EntityRecognizer
from ...pipeline.ner import DEFAULT_NER_MODEL
TRAIN_DATA = [ TRAIN_DATA = [
("Who is Shaka Khan?", {"entities": [(7, 17, "PERSON")]}), ("Who is Shaka Khan?", {"entities": [(7, 17, "PERSON")]}),

View File

@ -8,11 +8,11 @@ from spacy.lang.en import English
from spacy.tokens import Doc from spacy.tokens import Doc
from spacy.training import Example from spacy.training import Example
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.pipeline import DependencyParser
from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
from ...pipeline import DependencyParser
from ...pipeline.dep_parser import DEFAULT_PARSER_MODEL
from ..util import apply_transition_sequence, make_tempdir from ..util import apply_transition_sequence, make_tempdir
from ...pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
TRAIN_DATA = [ TRAIN_DATA = [
( (

View File

@ -1,7 +1,10 @@
from typing import Callable from pathlib import Path
from typing import Callable, Iterable, Any, Dict
from spacy import util import srsly
from spacy.util import ensure_path, registry, load_model_from_config
from spacy import util, Errors
from spacy.util import ensure_path, registry, load_model_from_config, SimpleFrozenList
from spacy.kb.kb_in_memory import InMemoryLookupKB from spacy.kb.kb_in_memory import InMemoryLookupKB
from spacy.vocab import Vocab from spacy.vocab import Vocab
from thinc.api import Config from thinc.api import Config
@ -92,6 +95,9 @@ def test_serialize_subclassed_kb():
[components.entity_linker] [components.entity_linker]
factory = "entity_linker" factory = "entity_linker"
[components.entity_linker.generate_empty_kb]
@misc = "kb_test.CustomEmptyKB.v1"
[initialize] [initialize]
[initialize.components] [initialize.components]
@ -99,7 +105,7 @@ def test_serialize_subclassed_kb():
[initialize.components.entity_linker] [initialize.components.entity_linker]
[initialize.components.entity_linker.kb_loader] [initialize.components.entity_linker.kb_loader]
@misc = "spacy.CustomKB.v1" @misc = "kb_test.CustomKB.v1"
entity_vector_length = 342 entity_vector_length = 342
custom_field = 666 custom_field = 666
""" """
@ -109,10 +115,57 @@ def test_serialize_subclassed_kb():
super().__init__(vocab, entity_vector_length) super().__init__(vocab, entity_vector_length)
self.custom_field = custom_field self.custom_field = custom_field
@registry.misc("spacy.CustomKB.v1") def to_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()):
"""We overwrite InMemoryLookupKB.to_disk() to ensure that self.custom_field is stored as well."""
path = ensure_path(path)
if not path.exists():
path.mkdir(parents=True)
if not path.is_dir():
raise ValueError(Errors.E928.format(loc=path))
def serialize_custom_fields(file_path: Path) -> None:
srsly.write_json(file_path, {"custom_field": self.custom_field})
serialize = {
"contents": lambda p: self.write_contents(p),
"strings.json": lambda p: self.vocab.strings.to_disk(p),
"custom_fields": lambda p: serialize_custom_fields(p),
}
util.to_disk(path, serialize, exclude)
def from_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()):
    """Load the KB from a directory, restoring self.custom_field as well.

    Overrides InMemoryLookupKB.from_disk() so the extra "custom_fields"
    entry is read in addition to the contents and the string store.

    path: Directory to read from; normalised with ensure_path().
    exclude (Iterable[str]): Entry names passed through to util.from_disk().
    RAISES (ValueError): E929 if the path does not exist, E928 if it is not
        a directory.
    """
    kb_dir = ensure_path(path)
    if not kb_dir.exists():
        raise ValueError(Errors.E929.format(loc=kb_dir))
    if not kb_dir.is_dir():
        raise ValueError(Errors.E928.format(loc=kb_dir))

    def _load_custom_fields(file_path: Path) -> None:
        stored = srsly.read_json(file_path)
        self.custom_field = stored["custom_field"]

    # One reader per on-disk entry; the keys are the entry names.
    readers: Dict[str, Callable[[Any], Any]] = {}
    readers["contents"] = lambda p: self.read_contents(p)
    readers["strings.json"] = lambda p: self.vocab.strings.from_disk(p)
    readers["custom_fields"] = lambda p: _load_custom_fields(p)
    util.from_disk(kb_dir, readers, exclude)
@registry.misc("kb_test.CustomEmptyKB.v1")
def empty_custom_kb() -> Callable[[Vocab, int], SubInMemoryLookupKB]:
    """Provide a factory for an empty SubInMemoryLookupKB.

    RETURNS (Callable[[Vocab, int], SubInMemoryLookupKB]): Function taking a
        vocab and an entity vector length and returning a new
        SubInMemoryLookupKB whose custom_field is 0.
    """

    def _make_empty_kb(vocab: Vocab, entity_vector_length: int):
        empty_kb = SubInMemoryLookupKB(
            vocab=vocab, entity_vector_length=entity_vector_length, custom_field=0
        )
        return empty_kb

    return _make_empty_kb
@registry.misc("kb_test.CustomKB.v1")
def custom_kb( def custom_kb(
entity_vector_length: int, custom_field: int entity_vector_length: int, custom_field: int
) -> Callable[[Vocab], InMemoryLookupKB]: ) -> Callable[[Vocab], SubInMemoryLookupKB]:
def custom_kb_factory(vocab): def custom_kb_factory(vocab):
kb = SubInMemoryLookupKB( kb = SubInMemoryLookupKB(
vocab=vocab, vocab=vocab,
@ -139,6 +192,6 @@ def test_serialize_subclassed_kb():
nlp2 = util.load_model_from_path(tmp_dir) nlp2 = util.load_model_from_path(tmp_dir)
entity_linker2 = nlp2.get_pipe("entity_linker") entity_linker2 = nlp2.get_pipe("entity_linker")
# After IO, the KB is the standard one # After IO, the KB is still the custom subclass
assert type(entity_linker2.kb) == InMemoryLookupKB assert type(entity_linker2.kb) == SubInMemoryLookupKB
assert entity_linker2.kb.entity_vector_length == 342 assert entity_linker2.kb.entity_vector_length == 342
assert not hasattr(entity_linker2.kb, "custom_field") assert entity_linker2.kb.custom_field == 666

View File

@ -2,7 +2,6 @@ import os
import math import math
from collections import Counter from collections import Counter
from typing import Tuple, List, Dict, Any from typing import Tuple, List, Dict, Any
import pkg_resources
import time import time
from pathlib import Path from pathlib import Path
@ -29,6 +28,7 @@ from spacy.cli.debug_data import _print_span_characteristics
from spacy.cli.debug_data import _get_spans_length_freq_dist from spacy.cli.debug_data import _get_spans_length_freq_dist
from spacy.cli.download import get_compatibility, get_version from spacy.cli.download import get_compatibility, get_version
from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config
from spacy.cli.init_pipeline import _init_labels
from spacy.cli.package import get_third_party_dependencies from spacy.cli.package import get_third_party_dependencies
from spacy.cli.package import _is_permitted_package_name from spacy.cli.package import _is_permitted_package_name
from spacy.cli.project.remote_storage import RemoteStorage from spacy.cli.project.remote_storage import RemoteStorage
@ -47,7 +47,6 @@ from spacy.training.converters import conll_ner_to_docs, conllu_to_docs
from spacy.training.converters import iob_to_docs from spacy.training.converters import iob_to_docs
from spacy.util import ENV_VARS, get_minor_version, load_model_from_config, load_config from spacy.util import ENV_VARS, get_minor_version, load_model_from_config, load_config
from ..cli.init_pipeline import _init_labels
from .util import make_tempdir from .util import make_tempdir
@ -1126,6 +1125,7 @@ def test_cli_find_threshold(capsys):
) )
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"reqs,output", "reqs,output",
[ [
@ -1158,6 +1158,8 @@ def test_cli_find_threshold(capsys):
], ],
) )
def test_project_check_requirements(reqs, output): def test_project_check_requirements(reqs, output):
import pkg_resources
# excessive guard against unlikely package name # excessive guard against unlikely package name
try: try:
pkg_resources.require("spacyunknowndoesnotexist12345") pkg_resources.require("spacyunknowndoesnotexist12345")

View File

@ -2,17 +2,19 @@ from pathlib import Path
import numpy as np import numpy as np
import pytest import pytest
import srsly import srsly
from spacy.vocab import Vocab from thinc.api import Config, get_current_ops
from thinc.api import Config
from spacy import util
from spacy.lang.en import English
from spacy.training.initialize import init_nlp
from spacy.training.loop import train
from spacy.training.pretrain import pretrain
from spacy.tokens import Doc, DocBin
from spacy.language import DEFAULT_CONFIG_PRETRAIN_PATH, DEFAULT_CONFIG_PATH
from spacy.ml.models.multi_task import create_pretrain_vectors
from spacy.vectors import Vectors
from spacy.vocab import Vocab
from ..util import make_tempdir from ..util import make_tempdir
from ... import util
from ...lang.en import English
from ...training.initialize import init_nlp
from ...training.loop import train
from ...training.pretrain import pretrain
from ...tokens import Doc, DocBin
from ...language import DEFAULT_CONFIG_PRETRAIN_PATH, DEFAULT_CONFIG_PATH
pretrain_string_listener = """ pretrain_string_listener = """
[nlp] [nlp]
@ -346,3 +348,30 @@ def write_vectors_model(tmp_dir):
nlp = English(vocab) nlp = English(vocab)
nlp.to_disk(nlp_path) nlp.to_disk(nlp_path)
return str(nlp_path) return str(nlp_path)
def test_pretrain_default_vectors():
    """Check which vector tables the PretrainVectors objective accepts.

    Default-mode vectors must be accepted, an empty vectors table must raise
    E875, and floret-mode vectors must raise E850.
    """
    nlp = English()
    nlp.add_pipe("tok2vec")
    nlp.initialize()
    tok2vec = nlp.get_pipe("tok2vec").model
    create_objective = create_pretrain_vectors(1, 1, "cosine")

    # default vectors are supported
    nlp.vocab.vectors = Vectors(shape=(10, 10))
    create_objective(nlp.vocab, tok2vec)

    # error for no vectors
    # The vectors are swapped in *before* entering pytest.raises so that only
    # the objective construction itself can satisfy the expectation.
    nlp.vocab.vectors = Vectors()
    with pytest.raises(ValueError, match="E875"):
        create_objective(nlp.vocab, tok2vec)

    # error for floret vectors
    ops = get_current_ops()
    nlp.vocab.vectors = Vectors(
        data=ops.xp.zeros((10, 10)), mode="floret", hash_count=1
    )
    with pytest.raises(ValueError, match="E850"):
        create_objective(nlp.vocab, tok2vec)

View File

@ -899,15 +899,21 @@ The `EntityLinker` model architecture is a Thinc `Model` with a
| `nO` | Output dimension, determined by the length of the vectors encoding each entity in the KB. If the `nO` dimension is not set, the entity linking component will set it when `initialize` is called. ~~Optional[int]~~ | | `nO` | Output dimension, determined by the length of the vectors encoding each entity in the KB. If the `nO` dimension is not set, the entity linking component will set it when `initialize` is called. ~~Optional[int]~~ |
| **CREATES** | The model using the architecture. ~~Model[List[Doc], Floats2d]~~ | | **CREATES** | The model using the architecture. ~~Model[List[Doc], Floats2d]~~ |
### spacy.EmptyKB.v1 {id="EmptyKB"} ### spacy.EmptyKB.v1 {id="EmptyKB.v1"}
A function that creates an empty `KnowledgeBase` from a [`Vocab`](/api/vocab) A function that creates an empty `KnowledgeBase` from a [`Vocab`](/api/vocab)
instance. This is the default when a new entity linker component is created. instance.
| Name | Description | | Name | Description |
| ---------------------- | ----------------------------------------------------------------------------------- | | ---------------------- | ----------------------------------------------------------------------------------- |
| `entity_vector_length` | The length of the vectors encoding each entity in the KB. Defaults to `64`. ~~int~~ | | `entity_vector_length` | The length of the vectors encoding each entity in the KB. Defaults to `64`. ~~int~~ |
### spacy.EmptyKB.v2 {id="EmptyKB"}
A function that creates an empty `KnowledgeBase` from a [`Vocab`](/api/vocab)
instance. This is the default when a new entity linker component is created. It
returns a `Callable[[Vocab, int], InMemoryLookupKB]`.
### spacy.KBFromFile.v1 {id="KBFromFile"} ### spacy.KBFromFile.v1 {id="KBFromFile"}
A function that reads an existing `KnowledgeBase` from file. A function that reads an existing `KnowledgeBase` from file.

View File

@ -54,7 +54,7 @@ architectures and their arguments and hyperparameters.
> ``` > ```
| Setting | Description | | Setting | Description |
| ---------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | --------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `labels_discard` | NER labels that will automatically get a "NIL" prediction. Defaults to `[]`. ~~Iterable[str]~~ | | `labels_discard` | NER labels that will automatically get a "NIL" prediction. Defaults to `[]`. ~~Iterable[str]~~ |
| `n_sents` | The number of neighbouring sentences to take into account. Defaults to 0. ~~int~~ | | `n_sents` | The number of neighbouring sentences to take into account. Defaults to 0. ~~int~~ |
| `incl_prior` | Whether or not to include prior probabilities from the KB in the model. Defaults to `True`. ~~bool~~ | | `incl_prior` | Whether or not to include prior probabilities from the KB in the model. Defaults to `True`. ~~bool~~ |
@ -63,6 +63,8 @@ architectures and their arguments and hyperparameters.
| `entity_vector_length` | Size of encoding vectors in the KB. Defaults to `64`. ~~int~~ | | `entity_vector_length` | Size of encoding vectors in the KB. Defaults to `64`. ~~int~~ |
| `use_gold_ents` | Whether to copy entities from the gold docs or not. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~bool~~ | | `use_gold_ents` | Whether to copy entities from the gold docs or not. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~bool~~ |
| `get_candidates` | Function that generates plausible candidates for a given `Span` object. Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ | | `get_candidates` | Function that generates plausible candidates for a given `Span` object. Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ |
| `get_candidates_batch` <Tag variant="new">3.5</Tag> | Function that generates plausible candidates for a given batch of `Span` objects. Defaults to [CandidateBatchGenerator](/api/architectures#CandidateBatchGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]]~~ |
| `generate_empty_kb` <Tag variant="new">3.6</Tag> | Function that generates an empty `KnowledgeBase` object. Defaults to [`spacy.EmptyKB.v2`](/api/architectures#EmptyKB), which generates an empty [`InMemoryLookupKB`](/api/inmemorylookupkb). ~~Callable[[Vocab, int], KnowledgeBase]~~ |
| `overwrite` <Tag variant="new">3.2</Tag> | Whether existing annotation is overwritten. Defaults to `True`. ~~bool~~ | | `overwrite` <Tag variant="new">3.2</Tag> | Whether existing annotation is overwritten. Defaults to `True`. ~~bool~~ |
| `scorer` <Tag variant="new">3.2</Tag> | The scoring method. Defaults to [`Scorer.score_links`](/api/scorer#score_links). ~~Optional[Callable]~~ | | `scorer` <Tag variant="new">3.2</Tag> | The scoring method. Defaults to [`Scorer.score_links`](/api/scorer#score_links). ~~Optional[Callable]~~ |
| `threshold` <Tag variant="new">3.4</Tag> | Confidence threshold for entity predictions. The default of `None` implies that all predictions are accepted, otherwise those with a score beneath the threshold are discarded. If there are no predictions with scores above the threshold, the linked entity is `NIL`. ~~Optional[float]~~ | | `threshold` <Tag variant="new">3.4</Tag> | Confidence threshold for entity predictions. The default of `None` implies that all predictions are accepted, otherwise those with a score beneath the threshold are discarded. If there are no predictions with scores above the threshold, the linked entity is `NIL`. ~~Optional[float]~~ |

View File

@ -68,11 +68,11 @@ architectures and their arguments and hyperparameters.
> "spans_key": "labeled_spans", > "spans_key": "labeled_spans",
> "model": DEFAULT_SPANCAT_SINGLELABEL_MODEL, > "model": DEFAULT_SPANCAT_SINGLELABEL_MODEL,
> "suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]}, > "suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]},
> # Additional spancat_exclusive parameters > # Additional spancat_singlelabel parameters
> "negative_weight": 0.8, > "negative_weight": 0.8,
> "allow_overlap": True, > "allow_overlap": True,
> } > }
> nlp.add_pipe("spancat_exclusive", config=config) > nlp.add_pipe("spancat_singlelabel", config=config)
> ``` > ```
| Setting | Description | | Setting | Description |