Merge pull request #8787 from adrianeboyd/chore/backport-v3.0.7

Backport bug fixes to v3.0.x
Commit 034ac0acf4 by Adriane Boyd, 2021-07-21 16:53:50 +02:00, committed via GitHub
33 changed files with 209 additions and 105 deletions

View File

@@ -8,3 +8,4 @@ recursive-exclude spacy/lang *.json
 recursive-include spacy/lang *.json.gz
 recursive-include spacy/cli *.json *.yml
 recursive-include licenses *
+recursive-exclude spacy *.cpp

View File

@@ -1,6 +1,6 @@
 # fmt: off
 __title__ = "spacy"
-__version__ = "3.0.6"
+__version__ = "3.0.7"
 __download_url__ = "https://github.com/explosion/spacy-models/releases/download"
 __compatibility__ = "https://raw.githubusercontent.com/explosion/spacy-models/master/compatibility.json"
 __projects__ = "https://github.com/explosion/projects"

View File

@@ -115,7 +115,8 @@ def convert(
     ner_map = srsly.read_json(ner_map) if ner_map is not None else None
     doc_files = []
     for input_loc in walk_directory(Path(input_path), converter):
-        input_data = input_loc.open("r", encoding="utf-8").read()
+        with input_loc.open("r", encoding="utf-8") as infile:
+            input_data = infile.read()
         # Use converter function to convert data
         func = CONVERTERS[converter]
         docs = func(

View File

@@ -18,7 +18,7 @@ def package_cli(
     output_dir: Path = Arg(..., help="Output parent directory", exists=True, file_okay=False),
     code_paths: str = Opt("", "--code", "-c", help="Comma-separated paths to Python file with additional code (registered functions) to be included in the package"),
     meta_path: Optional[Path] = Opt(None, "--meta-path", "--meta", "-m", help="Path to meta.json", exists=True, dir_okay=False),
-    create_meta: bool = Opt(False, "--create-meta", "-c", "-C", help="Create meta.json, even if one exists"),
+    create_meta: bool = Opt(False, "--create-meta", "-C", help="Create meta.json, even if one exists"),
     name: Optional[str] = Opt(None, "--name", "-n", help="Package name to override meta"),
     version: Optional[str] = Opt(None, "--version", "-v", help="Package version to override meta"),
     build: str = Opt("sdist", "--build", "-b", help="Comma-separated formats to build: sdist and/or wheel, or none."),
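The duplicate "-c" shorthand clashed with the "--code" option, so "-C" is now the only short flag for "--create-meta". A hedged command-line sketch using the flags visible in this hunk (directory and file names are placeholders, not taken from this PR):

    python -m spacy package ./my_pipeline ./packages --code ./custom_functions.py --create-meta --build sdist,wheel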

View File

@@ -418,7 +418,8 @@ compound = 1.001
 
 [initialize]
 {% if use_transformer or optimize == "efficiency" or not word_vectors -%}
-vectors = null
+vectors = ${paths.vectors}
 {% else -%}
 vectors = "{{ word_vectors }}"
 {% endif -%}

View File

@@ -518,6 +518,11 @@ class Errors:
     E202 = ("Unsupported alignment mode '{mode}'. Supported modes: {modes}.")
 
     # New errors added in v3.x
+    E867 = ("The 'textcat' component requires at least two labels because it "
+            "uses mutually exclusive classes where exactly one label is True "
+            "for each doc. For binary classification tasks, you can use two "
+            "labels with 'textcat' (LABEL / NOT_LABEL) or alternatively, you "
+            "can use the 'textcat_multilabel' component with one label.")
     E870 = ("Could not serialize the DocBin because it is too large. Consider "
             "splitting up your documents into several doc bins and serializing "
             "each separately. spacy.Corpus.v1 will search recursively for all "

View File

@@ -1,16 +1,11 @@
-from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS, TOKEN_MATCH
 from .stop_words import STOP_WORDS
-from .syntax_iterators import SYNTAX_ITERATORS
 from .lex_attrs import LEX_ATTRS
 from ...language import Language
 
 
 class AzerbaijaniDefaults(Language.Defaults):
-    tokenizer_exceptions = TOKENIZER_EXCEPTIONS
     lex_attr_getters = LEX_ATTRS
     stop_words = STOP_WORDS
-    token_match = TOKEN_MATCH
-    syntax_iterators = SYNTAX_ITERATORS
 
 
 class Azerbaijani(Language):

View File

@@ -57,6 +57,6 @@ class GreekLemmatizer(Lemmatizer):
             forms.extend(oov_forms)
         if not forms:
             forms.append(string)
-        forms = list(set(forms))
+        forms = list(dict.fromkeys(forms))
         self.cache[cache_key] = forms
         return forms
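Switching from set() to dict.fromkeys() keeps the first-seen order of the candidate forms, so the deduplicated (and cached) lemma list is deterministic across runs. A minimal, spaCy-independent illustration:

    forms = ["b", "a", "b", "c", "a"]
    print(list(set(forms)))            # order depends on hashing and can vary between runs
    print(list(dict.fromkeys(forms)))  # always ['b', 'a', 'c']: first occurrence wins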

View File

@@ -12,7 +12,6 @@ PUNCT_RULES = {"«": '"', "»": '"'}
 
 
 class RussianLemmatizer(Lemmatizer):
-    _morph = None
 
     def __init__(
         self,
@@ -31,8 +30,8 @@ class RussianLemmatizer(Lemmatizer):
                 "The Russian lemmatizer mode 'pymorphy2' requires the "
                 "pymorphy2 library. Install it with: pip install pymorphy2"
             ) from None
-        if RussianLemmatizer._morph is None:
-            RussianLemmatizer._morph = MorphAnalyzer()
+        if getattr(self, "_morph", None) is None:
+            self._morph = MorphAnalyzer()
         super().__init__(vocab, model, name, mode=mode, overwrite=overwrite)
 
     def pymorphy2_lemmatize(self, token: Token) -> List[str]:

View File

@@ -7,8 +7,6 @@ from ...vocab import Vocab
 
 
 class UkrainianLemmatizer(RussianLemmatizer):
-    _morph = None
-
     def __init__(
         self,
         vocab: Vocab,
@@ -27,6 +25,6 @@ class UkrainianLemmatizer(RussianLemmatizer):
                 "pymorphy2 library and dictionaries. Install them with: "
                 "pip install pymorphy2 pymorphy2-dicts-uk"
             ) from None
-        if UkrainianLemmatizer._morph is None:
-            UkrainianLemmatizer._morph = MorphAnalyzer(lang="uk")
+        if getattr(self, "_morph", None) is None:
+            self._morph = MorphAnalyzer(lang="uk")
         super().__init__(vocab, model, name, mode=mode, overwrite=overwrite)
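Both lemmatizers now cache the MorphAnalyzer per instance (via getattr) rather than on the class, so each lemmatizer object builds and owns its own analyzer. A generic sketch of the difference, with object() standing in for MorphAnalyzer (which would require pymorphy2 to be installed):

    class ClassLevelCache:
        _morph = None                              # one slot shared by every instance

        def __init__(self):
            if ClassLevelCache._morph is None:
                ClassLevelCache._morph = object()  # first instance wins, later ones reuse it

    class InstanceLevelCache:
        def __init__(self):
            if getattr(self, "_morph", None) is None:
                self._morph = object()             # each instance creates its own analyzer

    print(ClassLevelCache()._morph is ClassLevelCache()._morph)        # True
    print(InstanceLevelCache()._morph is InstanceLevelCache()._morph)  # False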

View File

@@ -50,6 +50,8 @@ cdef class PhraseMatcher:
         if isinstance(attr, (int, long)):
             self.attr = attr
         else:
+            if attr is None:
+                attr = "ORTH"
             attr = attr.upper()
             if attr == "TEXT":
                 attr = "ORTH"

View File

@@ -3,7 +3,7 @@ from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init, Mode
 from thinc.api import MultiSoftmax, list2array
 from thinc.api import to_categorical, CosineDistance, L2Distance
 
-from ...util import registry
+from ...util import registry, OOV_RANK
 from ...errors import Errors
 from ...attrs import ID
 
@@ -70,6 +70,7 @@ def get_vectors_loss(ops, docs, prediction, distance):
     # and look them up all at once. This prevents data copying.
     ids = ops.flatten([doc.to_array(ID).ravel() for doc in docs])
     target = docs[0].vocab.vectors.data[ids]
+    target[ids == OOV_RANK] = 0
     d_target, loss = distance(prediction, target)
     return loss, d_target
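The pretraining vectors objective now zeroes the target rows that belong to out-of-vocabulary tokens (identified by OOV_RANK from spacy.util) instead of training against stale table entries. The masking itself, as a small NumPy sketch with toy shapes and a toy OOV value:

    import numpy

    FAKE_OOV_RANK = 3                                    # stand-in for spaCy's real OOV_RANK constant
    vectors = numpy.arange(12, dtype="f").reshape(4, 3)  # toy vector table, one row per rank
    ids = numpy.asarray([2, 3, 1, 3])                    # one row id per token; rank 3 marks OOV here
    target = vectors[ids]                                # fancy indexing copies the rows
    target[ids == FAKE_OOV_RANK] = 0                     # OOV tokens get an all-zero target
    print(target)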

View File

@@ -481,7 +481,8 @@ class EntityLinker(TrainablePipe):
 
         def load_model(p):
             try:
-                self.model.from_bytes(p.open("rb").read())
+                with p.open("rb") as infile:
+                    self.model.from_bytes(infile.read())
             except AttributeError:
                 raise ValueError(Errors.E149) from None

View File

@@ -3,6 +3,7 @@ from typing import Optional, Union, List, Dict, Tuple, Iterable, Any, Callable,
 from collections import defaultdict
 from pathlib import Path
 import srsly
+import warnings
 
 from .pipe import Pipe
 from ..training import Example
@@ -102,17 +103,12 @@ class EntityRuler(Pipe):
         self.overwrite = overwrite_ents
         self.token_patterns = defaultdict(list)
         self.phrase_patterns = defaultdict(list)
+        self._validate = validate
         self.matcher = Matcher(nlp.vocab, validate=validate)
-        if phrase_matcher_attr is not None:
-            if phrase_matcher_attr.upper() == "TEXT":
-                phrase_matcher_attr = "ORTH"
-            self.phrase_matcher_attr = phrase_matcher_attr
-            self.phrase_matcher = PhraseMatcher(
-                nlp.vocab, attr=self.phrase_matcher_attr, validate=validate
-            )
-        else:
-            self.phrase_matcher_attr = None
-            self.phrase_matcher = PhraseMatcher(nlp.vocab, validate=validate)
+        self.phrase_matcher_attr = phrase_matcher_attr
+        self.phrase_matcher = PhraseMatcher(
+            nlp.vocab, attr=self.phrase_matcher_attr, validate=validate
+        )
         self.ent_id_sep = ent_id_sep
         self._ent_ids = defaultdict(dict)
         if patterns is not None:
@@ -146,6 +142,8 @@ class EntityRuler(Pipe):
 
     def match(self, doc: Doc):
         self._require_patterns()
-        matches = list(self.matcher(doc)) + list(self.phrase_matcher(doc))
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", message="\\[W036")
+            matches = list(self.matcher(doc)) + list(self.phrase_matcher(doc))
         matches = set(
             [(m_id, start, end) for m_id, start, end in matches if start != end]
@@ -281,7 +279,7 @@ class EntityRuler(Pipe):
                     current_index = i
                     break
             subsequent_pipes = [
-                pipe for pipe in self.nlp.pipe_names[current_index + 1 :]
+                pipe for pipe in self.nlp.pipe_names[current_index:]
             ]
         except ValueError:
             subsequent_pipes = []
@@ -317,20 +315,22 @@ class EntityRuler(Pipe):
                 pattern = entry["pattern"]
                 if isinstance(pattern, Doc):
                     self.phrase_patterns[label].append(pattern)
+                    self.phrase_matcher.add(label, [pattern])
                 elif isinstance(pattern, list):
                     self.token_patterns[label].append(pattern)
+                    self.matcher.add(label, [pattern])
                 else:
                     raise ValueError(Errors.E097.format(pattern=pattern))
-            for label, patterns in self.token_patterns.items():
-                self.matcher.add(label, patterns)
-            for label, patterns in self.phrase_patterns.items():
-                self.phrase_matcher.add(label, patterns)
 
     def clear(self) -> None:
         """Reset all patterns."""
         self.token_patterns = defaultdict(list)
         self.phrase_patterns = defaultdict(list)
         self._ent_ids = defaultdict(dict)
+        self.matcher = Matcher(self.nlp.vocab, validate=self._validate)
+        self.phrase_matcher = PhraseMatcher(
+            self.nlp.vocab, attr=self.phrase_matcher_attr, validate=self._validate
+        )
 
     def _require_patterns(self) -> None:
         """Raise a warning if this component has no patterns defined."""
@@ -381,7 +381,6 @@ class EntityRuler(Pipe):
             self.add_patterns(cfg.get("patterns", cfg))
             self.overwrite = cfg.get("overwrite", False)
             self.phrase_matcher_attr = cfg.get("phrase_matcher_attr", None)
-            if self.phrase_matcher_attr is not None:
-                self.phrase_matcher = PhraseMatcher(
-                    self.nlp.vocab, attr=self.phrase_matcher_attr
-                )
+            self.phrase_matcher = PhraseMatcher(
+                self.nlp.vocab, attr=self.phrase_matcher_attr
+            )
@@ -435,7 +434,6 @@ class EntityRuler(Pipe):
             self.phrase_matcher_attr = cfg.get("phrase_matcher_attr")
             self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)
-            if self.phrase_matcher_attr is not None:
-                self.phrase_matcher = PhraseMatcher(
-                    self.nlp.vocab, attr=self.phrase_matcher_attr
-                )
+            self.phrase_matcher = PhraseMatcher(
+                self.nlp.vocab, attr=self.phrase_matcher_attr
+            )
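Patterns are now added straight to the underlying Matcher and PhraseMatcher as they arrive, and clear() rebuilds both matchers, so clearing the ruler genuinely removes its patterns. A hedged round-trip sketch (assumes spaCy v3.0.7 and a blank English pipeline; the label and phrase are placeholders):

    import spacy

    nlp = spacy.blank("en")
    ruler = nlp.add_pipe("entity_ruler")
    ruler.add_patterns([{"label": "ORG", "pattern": "Acme Corp"}])
    print([ent.label_ for ent in nlp("I work at Acme Corp").ents])  # ['ORG']

    ruler.clear()             # now also resets the internal matchers
    print(len(ruler.labels))  # 0; running the pipeline again finds no entities (and warns)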

View File

@@ -332,6 +332,8 @@ class TextCategorizer(TrainablePipe):
         else:
             for label in labels:
                 self.add_label(label)
+        if len(self.labels) < 2:
+            raise ValueError(Errors.E867)
         if positive_label is not None:
             if positive_label not in self.labels:
                 err = Errors.E920.format(pos_label=positive_label, labels=self.labels)
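Because `textcat` assigns exactly one mutually exclusive class per document, initialization now raises E867 when fewer than two labels are defined. A hedged sketch of the two supported binary setups (label names are placeholders):

    import spacy

    # Option 1: exclusive classes, so a binary task needs both labels.
    nlp = spacy.blank("en")
    textcat = nlp.add_pipe("textcat")
    textcat.add_label("POSITIVE")
    textcat.add_label("NEGATIVE")   # with only one label, initialize() would raise E867
    nlp.initialize()

    # Option 2: a single label with the non-exclusive component.
    nlp_ml = spacy.blank("en")
    textcat_ml = nlp_ml.add_pipe("textcat_multilabel")
    textcat_ml.add_label("POSITIVE")
    nlp_ml.initialize()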

View File

@@ -324,7 +324,8 @@ cdef class TrainablePipe(Pipe):
 
         def load_model(p):
             try:
-                self.model.from_bytes(p.open("rb").read())
+                with open(p, "rb") as mfile:
+                    self.model.from_bytes(mfile.read())
             except AttributeError:
                 raise ValueError(Errors.E149) from None

View File

@@ -351,13 +351,21 @@ def test_doc_from_array_morph(en_vocab):
 
 @pytest.mark.usefixtures("clean_underscore")
 def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
-    en_texts = ["Merging the docs is fun.", "", "They don't think alike."]
+    en_texts = [
+        "Merging the docs is fun.",
+        "",
+        "They don't think alike. ",
+        "Another doc.",
+    ]
     en_texts_without_empty = [t for t in en_texts if len(t)]
     de_text = "Wie war die Frage?"
     en_docs = [en_tokenizer(text) for text in en_texts]
     en_docs[0].spans["group"] = [en_docs[0][1:4]]
     en_docs[2].spans["group"] = [en_docs[2][1:4]]
-    span_group_texts = sorted([en_docs[0][1:4].text, en_docs[2][1:4].text])
+    en_docs[3].spans["group"] = [en_docs[3][0:1]]
+    span_group_texts = sorted(
+        [en_docs[0][1:4].text, en_docs[2][1:4].text, en_docs[3][0:1].text]
+    )
     de_doc = de_tokenizer(de_text)
     Token.set_extension("is_ambiguous", default=False)
     en_docs[0][2]._.is_ambiguous = True  # docs
@@ -371,8 +379,8 @@ def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
     m_doc = Doc.from_docs(en_docs)
     assert len(en_texts_without_empty) == len(list(m_doc.sents))
-    assert len(str(m_doc)) > len(en_texts[0]) + len(en_texts[1])
-    assert str(m_doc) == " ".join(en_texts_without_empty)
+    assert len(m_doc.text) > len(en_texts[0]) + len(en_texts[1])
+    assert m_doc.text == " ".join([t.strip() for t in en_texts_without_empty])
     p_token = m_doc[len(en_docs[0]) - 1]
     assert p_token.text == "." and bool(p_token.whitespace_)
     en_docs_tokens = [t for doc in en_docs for t in doc]
@@ -384,11 +392,12 @@ def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
     assert not any([t._.is_ambiguous for t in m_doc[3:8]])
     assert "group" in m_doc.spans
     assert span_group_texts == sorted([s.text for s in m_doc.spans["group"]])
+    assert bool(m_doc[11].whitespace_)
 
     m_doc = Doc.from_docs(en_docs, ensure_whitespace=False)
     assert len(en_texts_without_empty) == len(list(m_doc.sents))
-    assert len(str(m_doc)) == sum(len(t) for t in en_texts)
-    assert str(m_doc) == "".join(en_texts)
+    assert len(m_doc.text) == sum(len(t) for t in en_texts)
+    assert m_doc.text == "".join(en_texts_without_empty)
     p_token = m_doc[len(en_docs[0]) - 1]
     assert p_token.text == "." and not bool(p_token.whitespace_)
     en_docs_tokens = [t for doc in en_docs for t in doc]
@@ -397,11 +406,12 @@ def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
     assert m_doc[9].idx == think_idx
     assert "group" in m_doc.spans
     assert span_group_texts == sorted([s.text for s in m_doc.spans["group"]])
+    assert bool(m_doc[11].whitespace_)
 
     m_doc = Doc.from_docs(en_docs, attrs=["lemma", "length", "pos"])
-    assert len(str(m_doc)) > len(en_texts[0]) + len(en_texts[1])
+    assert len(m_doc.text) > len(en_texts[0]) + len(en_texts[1])
     # space delimiter considered, although spacy attribute was missing
-    assert str(m_doc) == " ".join(en_texts_without_empty)
+    assert m_doc.text == " ".join([t.strip() for t in en_texts_without_empty])
     p_token = m_doc[len(en_docs[0]) - 1]
     assert p_token.text == "." and bool(p_token.whitespace_)
     en_docs_tokens = [t for doc in en_docs for t in doc]
@@ -414,6 +424,16 @@ def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
 
     # can merge empty docs
     doc = Doc.from_docs([en_tokenizer("")] * 10)
+
+    # empty but set spans keys are preserved
+    en_docs = [en_tokenizer(text) for text in en_texts]
+    m_doc = Doc.from_docs(en_docs)
+    assert "group" not in m_doc.spans
+    for doc in en_docs:
+        doc.spans["group"] = []
+    m_doc = Doc.from_docs(en_docs)
+    assert "group" in m_doc.spans
+    assert len(m_doc.spans["group"]) == 0
 
 def test_doc_api_from_docs_ents(en_tokenizer):
     texts = ["Merging the docs is fun.", "They don't think alike."]

View File

@@ -4,12 +4,13 @@ from spacy.util import get_lang_class
 # fmt: off
 # Only include languages with no external dependencies
-# excluded: ja, ru, th, uk, vi, zh
-LANGUAGES = ["af", "ar", "bg", "bn", "ca", "cs", "da", "de", "el", "en", "es",
-             "et", "fa", "fi", "fr", "ga", "he", "hi", "hr", "hu", "id", "is",
-             "it", "kn", "lt", "lv", "nb", "nl", "pl", "pt", "ro", "si", "sk",
-             "sl", "sq", "sr", "sv", "ta", "te", "tl", "tn", "tr", "tt", "ur",
-             "yo"]
+# excluded: ja, ko, th, vi, zh
+LANGUAGES = ["af", "am", "ar", "az", "bg", "bn", "ca", "cs", "da", "de", "el",
+             "en", "es", "et", "eu", "fa", "fi", "fr", "ga", "gu", "he", "hi",
+             "hr", "hu", "hy", "id", "is", "it", "kn", "ky", "lb", "lt", "lv",
+             "mk", "ml", "mr", "nb", "ne", "nl", "pl", "pt", "ro", "ru", "sa",
+             "si", "sk", "sl", "sq", "sr", "sv", "ta", "te", "ti", "tl", "tn",
+             "tr", "tt", "uk", "ur", "xx", "yo"]
 # fmt: on

View File

@@ -481,6 +481,7 @@ def test_matcher_schema_token_attributes(en_vocab, pattern, text):
     assert len(matches) == 1
 
 
+@pytest.mark.filterwarnings("ignore:\\[W036")
 def test_matcher_valid_callback(en_vocab):
     """Test that on_match can only be None or callable."""
     matcher = Matcher(en_vocab)

View File

@@ -180,6 +180,7 @@ def test_matcher_sets_return_correct_tokens(en_vocab):
     assert texts == ["zero", "one", "two"]
 
 
+@pytest.mark.filterwarnings("ignore:\\[W036")
 def test_matcher_remove():
     nlp = English()
     matcher = Matcher(nlp.vocab)

View File

@@ -252,12 +252,12 @@ def test_ruler_before_ner():
     # 1 : Entity Ruler - should set "this" to B and everything else to empty
     patterns = [{"label": "THING", "pattern": "This"}]
     ruler = nlp.add_pipe("entity_ruler")
-    ruler.add_patterns(patterns)
 
     # 2: untrained NER - should set everything else to O
     untrained_ner = nlp.add_pipe("ner")
     untrained_ner.add_label("MY_LABEL")
     nlp.initialize()
+    ruler.add_patterns(patterns)
     doc = nlp("This is Antti Korhonen speaking in Finland")
     expected_iobs = ["B", "O", "O", "O", "O", "O", "O"]
     expected_types = ["THING", "", "", "", "", "", ""]

View File

@@ -324,6 +324,7 @@ def test_append_alias(nlp):
     assert len(mykb.get_alias_candidates("douglas")) == 3
 
 
+@pytest.mark.filterwarnings("ignore:\\[W036")
 def test_append_invalid_alias(nlp):
     """Test that append an alias will throw an error if prior probs are exceeding 1"""
     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
@@ -342,6 +343,7 @@ def test_append_invalid_alias(nlp):
         mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.2)
 
 
+@pytest.mark.filterwarnings("ignore:\\[W036")
 def test_preserving_links_asdoc(nlp):
     """Test that Span.as_doc preserves the existing entity links"""
     vector_length = 1

View File

@@ -89,6 +89,20 @@ def test_entity_ruler_init_clear(nlp, patterns):
     assert len(ruler.labels) == 0
 
 
+def test_entity_ruler_clear(nlp, patterns):
+    """Test that initialization clears patterns."""
+    ruler = nlp.add_pipe("entity_ruler")
+    ruler.add_patterns(patterns)
+    assert len(ruler.labels) == 4
+    doc = nlp("hello world")
+    assert len(doc.ents) == 1
+    ruler.clear()
+    assert len(ruler.labels) == 0
+    with pytest.warns(UserWarning):
+        doc = nlp("hello world")
+    assert len(doc.ents) == 0
+
+
 def test_entity_ruler_existing(nlp, patterns):
     ruler = nlp.add_pipe("entity_ruler")
     ruler.add_patterns(patterns)

View File

@@ -334,24 +334,31 @@ def test_language_factories_invalid():
 
 @pytest.mark.parametrize(
-    "weights,expected",
+    "weights,override,expected",
     [
-        ([{"a": 1.0}, {"b": 1.0}, {"c": 1.0}], {"a": 0.33, "b": 0.33, "c": 0.33}),
-        ([{"a": 1.0}, {"b": 50}, {"c": 123}], {"a": 0.33, "b": 0.33, "c": 0.33}),
+        ([{"a": 1.0}, {"b": 1.0}, {"c": 1.0}], {}, {"a": 0.33, "b": 0.33, "c": 0.33}),
+        ([{"a": 1.0}, {"b": 50}, {"c": 100}], {}, {"a": 0.01, "b": 0.33, "c": 0.66}),
         (
             [{"a": 0.7, "b": 0.3}, {"c": 1.0}, {"d": 0.5, "e": 0.5}],
+            {},
             {"a": 0.23, "b": 0.1, "c": 0.33, "d": 0.17, "e": 0.17},
         ),
         (
-            [{"a": 100, "b": 400}, {"c": 0.5, "d": 0.5}],
-            {"a": 0.1, "b": 0.4, "c": 0.25, "d": 0.25},
+            [{"a": 100, "b": 300}, {"c": 50, "d": 50}],
+            {},
+            {"a": 0.2, "b": 0.6, "c": 0.1, "d": 0.1},
         ),
-        ([{"a": 0.5, "b": 0.5}, {"b": 1.0}], {"a": 0.25, "b": 0.75}),
-        ([{"a": 0.0, "b": 0.0}, {"c": 0.0}], {"a": 0.0, "b": 0.0, "c": 0.0}),
+        ([{"a": 0.5, "b": 0.5}, {"b": 1.0}], {}, {"a": 0.33, "b": 0.67}),
+        ([{"a": 0.5, "b": 0.0}], {}, {"a": 1.0, "b": 0.0}),
+        ([{"a": 0.5, "b": 0.5}, {"b": 1.0}], {"a": 0.0}, {"a": 0.0, "b": 1.0}),
+        ([{"a": 0.0, "b": 0.0}, {"c": 0.0}], {}, {"a": 0.0, "b": 0.0, "c": 0.0}),
+        ([{"a": 0.0, "b": 0.0}, {"c": 1.0}], {}, {"a": 0.0, "b": 0.0, "c": 1.0}),
+        ([{"a": 0.0, "b": 0.0}, {"c": 0.0}], {"c": 0.2}, {"a": 0.0, "b": 0.0, "c": 1.0}),
+        ([{"a": 0.5, "b": 0.5, "c": 1.0, "d": 1.0}], {"a": 0.0, "b": 0.0}, {"a": 0.0, "b": 0.0, "c": 0.5, "d": 0.5}),
     ],
 )
-def test_language_factories_combine_score_weights(weights, expected):
-    result = combine_score_weights(weights)
+def test_language_factories_combine_score_weights(weights, override, expected):
+    result = combine_score_weights(weights, override)
     assert sum(result.values()) in (0.99, 1.0, 0.0)
     assert result == expected
@@ -377,17 +384,17 @@ def test_language_factories_scores():
     # Test with custom defaults
     config = nlp.config.copy()
     config["training"]["score_weights"]["a1"] = 0.0
-    config["training"]["score_weights"]["b3"] = 1.0
+    config["training"]["score_weights"]["b3"] = 1.3
     nlp = English.from_config(config)
     score_weights = nlp.config["training"]["score_weights"]
-    expected = {"a1": 0.0, "a2": 0.5, "b1": 0.03, "b2": 0.12, "b3": 0.34}
+    expected = {"a1": 0.0, "a2": 0.12, "b1": 0.05, "b2": 0.17, "b3": 0.65}
     assert score_weights == expected
     # Test with null values
     config = nlp.config.copy()
     config["training"]["score_weights"]["a1"] = None
     nlp = English.from_config(config)
     score_weights = nlp.config["training"]["score_weights"]
-    expected = {"a1": None, "a2": 0.5, "b1": 0.03, "b2": 0.12, "b3": 0.35}
+    expected = {"a1": None, "a2": 0.12, "b1": 0.05, "b2": 0.17, "b3": 0.66}
     assert score_weights == expected

View File

@@ -108,6 +108,12 @@ def test_label_types(name):
     textcat.add_label("answer")
     with pytest.raises(ValueError):
        textcat.add_label(9)
+    # textcat requires at least two labels
+    if name == "textcat":
+        with pytest.raises(ValueError):
+            nlp.initialize()
+    else:
+        nlp.initialize()
 
 
 @pytest.mark.parametrize("name", ["textcat", "textcat_multilabel"])

View File

@@ -0,0 +1,34 @@
+import pytest
+from spacy import registry
+from spacy.language import Language
+from spacy.pipeline import EntityRuler
+
+
+@pytest.fixture
+def nlp():
+    return Language()
+
+
+@pytest.fixture
+@registry.misc("entity_ruler_patterns")
+def patterns():
+    return [
+        {"label": "HELLO", "pattern": "hello world"},
+        {"label": "BYE", "pattern": [{"LOWER": "bye"}, {"LOWER": "bye"}]},
+        {"label": "HELLO", "pattern": [{"ORTH": "HELLO"}]},
+        {"label": "COMPLEX", "pattern": [{"ORTH": "foo", "OP": "*"}]},
+        {"label": "TECH_ORG", "pattern": "Apple", "id": "a1"},
+        {"label": "TECH_ORG", "pattern": "Microsoft", "id": "a2"},
+    ]
+
+
+def test_entity_ruler_fix8216(nlp, patterns):
+    """Test that patterns don't get added excessively."""
+    ruler = nlp.add_pipe("entity_ruler", config={"validate": True})
+    ruler.add_patterns(patterns)
+    pattern_count = sum(len(mm) for mm in ruler.matcher._patterns.values())
+    assert pattern_count > 0
+    ruler.add_patterns([])
+    after_count = sum(len(mm) for mm in ruler.matcher._patterns.values())
+    assert after_count == pattern_count

View File

@@ -84,7 +84,8 @@ Phasellus tincidunt, augue quis porta finibus, massa sapien consectetur augue, n
 @pytest.mark.parametrize("file_name", ["sun.txt"])
 def test_tokenizer_handle_text_from_file(tokenizer, file_name):
     loc = ensure_path(__file__).parent / file_name
-    text = loc.open("r", encoding="utf8").read()
+    with loc.open("r", encoding="utf8") as infile:
+        text = infile.read()
     assert len(text) != 0
     tokens = tokenizer(text)
     assert len(tokens) > 100

View File

@@ -182,6 +182,27 @@ def test_Example_from_dict_with_entities(annots):
     assert example.reference[5].ent_type_ == "LOC"
 
 
+def test_Example_from_dict_with_empty_entities():
+    annots = {
+        "words": ["I", "like", "New", "York", "and", "Berlin", "."],
+        "entities": [],
+    }
+    vocab = Vocab()
+    predicted = Doc(vocab, words=annots["words"])
+    example = Example.from_dict(predicted, annots)
+    # entities as empty list sets everything to O
+    assert example.reference.has_annotation("ENT_IOB")
+    assert len(list(example.reference.ents)) == 0
+    assert all(token.ent_iob_ == "O" for token in example.reference)
+    # various unset/missing entities leaves entities unset
+    annots["entities"] = None
+    example = Example.from_dict(predicted, annots)
+    assert not example.reference.has_annotation("ENT_IOB")
+    annots.pop("entities", None)
+    example = Example.from_dict(predicted, annots)
+    assert not example.reference.has_annotation("ENT_IOB")
+
+
 @pytest.mark.parametrize(
     "annots",
     [

View File

@@ -1141,6 +1141,10 @@ cdef class Doc:
                 else:
                     warnings.warn(Warnings.W102.format(key=key, value=value))
             for key in doc.spans:
+                # if a spans key is in any doc, include it in the merged doc
+                # even if it is empty
+                if key not in concat_spans:
+                    concat_spans[key] = []
                 for span in doc.spans[key]:
                     concat_spans[key].append((
                         span.start_char + char_offset,
@@ -1150,7 +1154,7 @@ cdef class Doc:
                         span.text,  # included as a check
                     ))
             char_offset += len(doc.text)
-            if len(doc) > 0 and ensure_whitespace and not doc[-1].is_space:
+            if len(doc) > 0 and ensure_whitespace and not doc[-1].is_space and not bool(doc[-1].whitespace_):
                 char_offset += 1
 
         arrays = [doc.to_array(attrs) for doc in docs]
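Two behaviors change here: a spans key present on any input doc is carried over to the merged doc even when its group is empty, and no artificial space is appended after a doc whose last token already carries trailing whitespace. A hedged sketch (blank English pipeline; the texts are placeholders):

    import spacy
    from spacy.tokens import Doc

    nlp = spacy.blank("en")
    doc1, doc2 = nlp("First doc. "), nlp("Second doc.")
    doc1.spans["key"] = []                  # empty group, but the key should survive merging
    merged = Doc.from_docs([doc1, doc2])
    print("key" in merged.spans)            # True; empty keys were previously dropped
    print(merged.text)                      # "First doc. Second doc." with a single space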

View File

@@ -416,7 +416,7 @@ def _fix_legacy_dict_data(example_dict):
     token_dict = example_dict.get("token_annotation", {})
     doc_dict = example_dict.get("doc_annotation", {})
     for key, value in example_dict.items():
-        if value:
+        if value is not None:
             if key in ("token_annotation", "doc_annotation"):
                 pass
             elif key == "ids":
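The check moves from truthiness to an explicit None test so that present-but-empty annotations such as "entities": [] are still processed (an empty entity list should mark every token as O rather than leave the annotation unset, which is what the new test above exercises). The distinction in isolation:

    annots = {"entities": [], "links": None}
    kept_old = {k: v for k, v in annots.items() if v}              # drops both keys
    kept_new = {k: v for k, v in annots.items() if v is not None}  # keeps "entities": []
    print(kept_old, kept_new)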

View File

@@ -1370,32 +1370,14 @@ def combine_score_weights(
         should be preserved.
     RETURNS (Dict[str, float]): The combined and normalized weights.
     """
+    # We divide each weight by the total weight sum.
     # We first need to extract all None/null values for score weights that
     # shouldn't be shown in the table *or* be weighted
-    result = {}
-    all_weights = []
-    for w_dict in weights:
-        filtered_weights = {}
-        for key, value in w_dict.items():
-            value = overrides.get(key, value)
-            if value is None:
-                result[key] = None
-            else:
-                filtered_weights[key] = value
-        all_weights.append(filtered_weights)
-    for w_dict in all_weights:
-        # We need to account for weights that don't sum to 1.0 and normalize
-        # the score weights accordingly, then divide score by the number of
-        # components.
-        total = sum(w_dict.values())
-        for key, value in w_dict.items():
-            if total == 0:
-                weight = 0.0
-            else:
-                weight = round(value / total / len(all_weights), 2)
-            prev_weight = result.get(key, 0.0)
-            prev_weight = 0.0 if prev_weight is None else prev_weight
-            result[key] = prev_weight + weight
+    result = {key: overrides.get(key, value) for w_dict in weights for (key, value) in w_dict.items()}
+    weight_sum = sum([v if v else 0.0 for v in result.values()])
+    for key, value in result.items():
+        if value and weight_sum > 0:
+            result[key] = round(value / weight_sum, 2)
     return result
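The rewritten helper pools every component's weights into one dict, applies the overrides, and normalizes by the combined total in a single pass (rounded to two decimals) instead of normalizing per component and then summing. A worked sketch of the new math with toy weights and no overrides, matching one of the updated test cases:

    weights = [{"a": 1.0}, {"b": 50}, {"c": 100}]
    merged = {k: v for w in weights for k, v in w.items()}   # {"a": 1.0, "b": 50, "c": 100}
    total = sum(v for v in merged.values() if v)             # 151.0
    print({k: round(v / total, 2) if v else v for k, v in merged.items()})
    # {'a': 0.01, 'b': 0.33, 'c': 0.66}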

View File

@@ -10,11 +10,12 @@ api_trainable: true
 ---
 
 The text categorizer predicts **categories over a whole document**. and comes in
-two flavours: `textcat` and `textcat_multilabel`. When you need to predict
+two flavors: `textcat` and `textcat_multilabel`. When you need to predict
 exactly one true label per document, use the `textcat` which has mutually
 exclusive labels. If you want to perform multi-label classification and predict
-zero, one or more labels per document, use the `textcat_multilabel` component
-instead.
+zero, one or more true labels per document, use the `textcat_multilabel`
+component instead. For a binary classification task, you can use `textcat` with
+**two** labels or `textcat_multilabel` with **one** label.
 
 Both components are documented on this page.
@@ -189,7 +190,7 @@ This method was previously called `begin_training`.
 | _keyword-only_   |  |
 | `nlp`            | The current `nlp` object. Defaults to `None`. ~~Optional[Language]~~ |
 | `labels`         | The label information to add to the component, as provided by the [`label_data`](#label_data) property after initialization. To generate a reusable JSON file from your data, you should run the [`init labels`](/api/cli#init-labels) command. If no labels are provided, the `get_examples` callback is used to extract the labels from the data, which may be a lot slower. ~~Optional[Iterable[str]]~~ |
-| `positive_label` | The positive label for a binary task with exclusive classes, `None` otherwise and by default. This parameter is not available when using the `textcat_multilabel` component. ~~Optional[str]~~ |
+| `positive_label` | The positive label for a binary task with exclusive classes, `None` otherwise and by default. This parameter is only used during scoring. It is not available when using the `textcat_multilabel` component. ~~Optional[str]~~ |
 
 ## TextCategorizer.predict {#predict tag="method"}

View File

@@ -262,7 +262,12 @@
         },
         {
             "code": "mk",
-            "name": "Macedonian"
+            "name": "Macedonian",
+            "models": [
+                "mk_core_news_sm",
+                "mk_core_news_md",
+                "mk_core_news_lg"
+            ]
         },
         {
             "code": "ml",