mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-27 18:36:36 +03:00
289 lines
9.7 KiB
Python
289 lines
9.7 KiB
Python
import pytest
|
|
from mock import Mock
|
|
from spacy.pipeline import EntityRuler
|
|
from spacy.matcher import DependencyMatcher
|
|
from spacy.tokens import Doc, Span, DocBin
|
|
from spacy.gold import Example
|
|
from spacy.gold.converters.conllu2docs import conllu2docs
|
|
from spacy.lang.en import English
|
|
from spacy.kb import KnowledgeBase
|
|
from spacy.vocab import Vocab
|
|
from spacy.language import Language
|
|
from spacy.util import ensure_path, load_model_from_path
|
|
import numpy
|
|
import pickle
|
|
|
|
from ..util import get_doc, make_tempdir
|
|
|
|
|
|
def test_issue4528(en_vocab):
    """Check that DocBin round-trips user_data when store_user_data=True.

    Both a plain string key and the tuple-shaped key used for extension
    attribute values must come back with their original value.
    """
    plain_key = "foo"
    # This is how extension attribute values are stored in the user data
    extension_key = ("._.", "foo", None, None)

    doc = Doc(en_vocab, words=["hello", "world"])
    doc.user_data[plain_key] = "bar"
    doc.user_data[extension_key] = "bar"

    source_bin = DocBin(store_user_data=True)
    source_bin.add(doc)
    serialized = source_bin.to_bytes()

    restored_bin = DocBin(store_user_data=True).from_bytes(serialized)
    restored_docs = list(restored_bin.get_docs(en_vocab))
    new_doc = restored_docs[0]

    assert new_doc.user_data[plain_key] == "bar"
    assert new_doc.user_data[extension_key] == "bar"
|
|
|
|
|
|
@pytest.mark.parametrize(
    "text,words",
    [
        ("A'B C", ["A", "'", "B", "C"]),
        ("A-B", ["A-B"]),
    ],
)
def test_gold_misaligned(en_tokenizer, text, words):
    """Example.from_dict must not raise for these gold word lists.

    The gold words may be segmented differently from the tokenizer's output.
    The test has no assertion; it passes as long as the call completes.
    """
    gold = {"words": words}
    Example.from_dict(en_tokenizer(text), gold)
|
|
|
|
|
|
def test_issue4590(en_vocab):
    """Test that matches param in on_match method are the same as matches run with no on_match method"""
    jumped_node = {"SPEC": {"NODE_NAME": "jumped"}, "PATTERN": {"ORTH": "jumped"}}
    fox_node = {
        "SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"},
        "PATTERN": {"ORTH": "fox"},
    }
    # NOTE(review): the node named "quick" matches ORTH "fox", which does not
    # line up with its name -- confirm this is intentional.
    quick_node = {
        "SPEC": {"NODE_NAME": "quick", "NBOR_RELOP": ".", "NBOR_NAME": "jumped"},
        "PATTERN": {"ORTH": "fox"},
    }
    pattern = [jumped_node, fox_node, quick_node]

    callback = Mock()
    matcher = DependencyMatcher(en_vocab)
    matcher.add("pattern", callback, pattern)

    words = "The quick brown fox jumped over the lazy fox".split()
    heads = [3, 2, 1, 1, 0, -1, 2, 1, -3]
    deps = ["det", "amod", "amod", "nsubj", "ROOT", "prep", "det", "amod", "pobj"]
    doc = get_doc(en_vocab, words, heads=heads, deps=deps)

    matches = matcher(doc)
    # call_args unpacks to (positional_args, keyword_args); the test compares
    # the callback's fourth positional argument against the returned matches.
    positional, _keywords = callback.call_args
    assert positional[3] == matches
|
|
|
|
|
|
def _entity_ruler_roundtrip_results(**ruler_kwargs):
    """Apply an EntityRuler before and after a to_disk/from_disk round trip.

    A ruler built with ``ruler_kwargs`` gets a single phrase pattern and is
    applied to a fixed text. It is then written with ``to_disk`` and read back
    with ``from_disk`` into a default-constructed EntityRuler on a fresh
    English pipeline, which is applied to the same text.

    Returns:
        A pair ``(before, after)``. Each element is a list of
        ``(text, label_, ent_id_)`` tuples, one per entity found.
    """
    text = "Spacy is a python library for nlp"
    patterns = [{"label": "PYTHON_LIB", "pattern": "spacy", "id": "spaCy"}]

    nlp = English()
    ruler = EntityRuler(nlp, **ruler_kwargs)
    ruler.add_patterns(patterns)
    nlp.add_pipe(ruler)
    before = [(ent.text, ent.label_, ent.ent_id_) for ent in nlp(text).ents]

    nlp_reloaded = English()
    with make_tempdir() as d:
        file_path = d / "entityruler"
        ruler.to_disk(file_path)
        # Deliberately constructed without ruler_kwargs: for the results to
        # match, any setting such as phrase_matcher_attr has to be restored
        # by from_disk.
        ruler_reloaded = EntityRuler(nlp_reloaded).from_disk(file_path)
        nlp_reloaded.add_pipe(ruler_reloaded)
        doc_reloaded = nlp_reloaded(text)
        after = [(ent.text, ent.label_, ent.ent_id_) for ent in doc_reloaded.ents]
    return before, after


def test_issue4651_with_phrase_matcher_attr():
    """Test that the EntityRuler PhraseMatcher is deserialized correctly using
    the method from_disk when the EntityRuler argument phrase_matcher_attr is
    specified.
    """
    res, res_reloaded = _entity_ruler_roundtrip_results(
        phrase_matcher_attr="LOWER"
    )
    assert res == res_reloaded


def test_issue4651_without_phrase_matcher_attr():
    """Test that the EntityRuler PhraseMatcher is deserialized correctly using
    the method from_disk when the EntityRuler argument phrase_matcher_attr is
    not specified.
    """
    res, res_reloaded = _entity_ruler_roundtrip_results()
    assert res == res_reloaded
|
|
|
|
|
|
def test_issue4665():
    """
    conllu2docs should not raise an exception if the HEAD column contains an
    underscore
    """
    # Each row is (ID, FORM, UPOS, XPOS, DEPREL). Every other CoNLL-U column,
    # including HEAD, is "_". Columns are joined with an explicit tab because
    # CoNLL-U fields are tab-separated, and tabs written inside a string
    # literal are easily turned into spaces by editors or copy/paste.
    rows = [
        ("1", "[", "PUNCT", "-LRB-", "punct"),
        ("2", "This", "DET", "DT", "det"),
        ("3", "killing", "NOUN", "NN", "nsubj"),
        ("4", "of", "ADP", "IN", "case"),
        ("5", "a", "DET", "DT", "det"),
        ("6", "respected", "ADJ", "JJ", "amod"),
        ("7", "cleric", "NOUN", "NN", "nmod"),
        ("8", "will", "AUX", "MD", "aux"),
        ("9", "be", "AUX", "VB", "aux"),
        ("10", "causing", "VERB", "VBG", "root"),
        ("11", "us", "PRON", "PRP", "iobj"),
        ("12", "trouble", "NOUN", "NN", "dobj"),
        ("13", "for", "ADP", "IN", "case"),
        ("14", "years", "NOUN", "NNS", "nmod"),
        ("15", "to", "PART", "TO", "mark"),
        ("16", "come", "VERB", "VB", "acl"),
        ("17", ".", "PUNCT", ".", "punct"),
        ("18", "]", "PUNCT", "-RRB-", "punct"),
    ]
    lines = [
        "\t".join([idx, form, "_", upos, xpos, "_", "_", deprel, "_", "_"])
        for idx, form, upos, xpos, deprel in rows
    ]
    input_data = "\n" + "\n".join(lines) + "\n"
    conllu2docs(input_data)
|
|
|
|
|
|
def test_issue4674():
    """Setting entities with a duplicated identifier must not break IO.

    The duplicate ID collapses to a single entity, and that count survives a
    dump/load_bulk round trip.
    """
    nlp = English()
    kb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
    vectors = [[0.9, 1.1, 1.01], [1.8, 2.25, 2.01]]

    # Registering "Q1" twice is expected to emit a UserWarning
    with pytest.warns(UserWarning):
        kb.set_entities(
            entity_list=["Q1", "Q1"], freq_list=[32, 111], vector_list=vectors,
        )
    assert kb.get_size_entities() == 1

    # dumping to file & loading back in
    with make_tempdir() as d:
        dir_path = ensure_path(d)
        if not dir_path.exists():
            dir_path.mkdir()
        kb_file = str(dir_path / "kb")
        kb.dump(kb_file)
        kb2 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3)
        kb2.load_bulk(kb_file)
        assert kb2.get_size_entities() == 1
|
|
|
|
|
|
def test_issue4707():
    """Check that components disabled at load time are also skipped by
    nlp.from_disk.

    This is tested when loading a saved pipeline with load_model_from_path.
    """
    nlp = English()
    for component_name in ("sentencizer", "entity_ruler"):
        nlp.add_pipe(nlp.create_pipe(component_name))
    assert nlp.pipe_names == ["sentencizer", "entity_ruler"]

    # The same names are excluded when saving and disabled when loading
    exclude = ["tokenizer", "sentencizer"]
    with make_tempdir() as tmpdir:
        nlp.to_disk(tmpdir, exclude=exclude)
        new_nlp = load_model_from_path(tmpdir, disable=exclude)
        assert "sentencizer" not in new_nlp.pipe_names
        assert "entity_ruler" in new_nlp.pipe_names
|
|
|
|
|
|
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_issue4725_1():
    """Ensure the pickling of the NER goes well.

    A config value passed to create_pipe must still be present in ``cfg``
    after a pickle dump/load round trip.
    """
    vocab = Vocab(vectors_name="test_vocab_add_vector")
    nlp = English(vocab=vocab)
    ner = nlp.create_pipe("ner", config={"min_action_freq": 342})
    with make_tempdir() as tmp_path:
        pickle_path = tmp_path / "ner.pkl"
        with pickle_path.open("wb") as out_file:
            pickle.dump(ner, out_file)
            assert ner.cfg["min_action_freq"] == 342

        with pickle_path.open("rb") as in_file:
            restored_ner = pickle.load(in_file)
            assert restored_ner.cfg["min_action_freq"] == 342
|
|
|
|
|
|
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_issue4725_2():
    """Run a vectors-backed NER pipeline with two worker processes.

    Ensures that this runs correctly and doesn't hang or crash because of the
    global vectors. If it does crash, it's usually because of calling 'spawn'
    for multiprocessing (e.g. on Windows).
    """
    vocab = Vocab(vectors_name="test_vocab_add_vector")
    data = numpy.ndarray((5, 3), dtype="f")
    # Only rows 0 and 1 are filled (broadcast across the 3 columns) and used
    data[:2] = [[1.0], [2.0]]
    for word, vector in zip(("cat", "dog"), data[:2]):
        vocab.set_vector(word, vector)

    nlp = English(vocab=vocab)
    nlp.add_pipe(nlp.create_pipe("ner"))
    nlp.begin_training()

    texts = ["Kurt is in London."] * 10
    # Drain the generator; only the absence of hangs/crashes is being checked
    for _doc in nlp.pipe(texts, batch_size=2, n_process=2):
        pass
|
|
|
|
|
|
def test_issue4849():
    """Both EntityRuler matches must carry a non-zero ent_id.

    This is checked with nlp.pipe running in a single process and in two
    processes.
    """
    nlp = English()
    patterns = [
        {"label": "PERSON", "pattern": "joe biden", "id": "joe-biden"},
        {"label": "PERSON", "pattern": "bernie sanders", "id": "bernie-sanders"},
    ]
    ruler = EntityRuler(nlp, patterns=patterns, phrase_matcher_attr="LOWER")
    nlp.add_pipe(ruler)
    text = """
The left is starting to take aim at Democratic front-runner Joe Biden.
Sen. Bernie Sanders joined in her criticism: "There is no 'middle ground' when it comes to climate policy."
"""
    # First with 1 process, then with 2 processes
    for n_process in (1, 2):
        count_ents = 0
        for doc in nlp.pipe([text], n_process=n_process):
            count_ents += sum(1 for ent in doc.ents if ent.ent_id > 0)
        assert count_ents == 2
|
|
|
|
|
|
class CustomPipe:
    """Pipeline component that annotates sentences and the Doc with ``my_ext``.

    For every sentence it stores the sentence's ``end`` index, as a string,
    in the sentence's ``my_ext`` value. The Doc's ``my_ext`` is set to those
    strings joined by newlines.
    """

    name = "my_pipe"

    def __init__(self):
        # Span.my_ext is backed by a getter; Doc.my_ext defaults to None.
        # NOTE(review): set_extension is called without force=True, so a second
        # instance would presumably fail to re-register -- confirm if reused.
        Span.set_extension("my_ext", getter=self._get_my_ext)
        Doc.set_extension("my_ext", default=None)

    def __call__(self, doc):
        per_sentence = []
        for sentence in doc.sents:
            value = self._get_my_ext(sentence)
            sentence._.set("my_ext", value)
            per_sentence.append(value)
        doc._.set("my_ext", "\n".join(per_sentence))
        return doc

    @staticmethod
    def _get_my_ext(span):
        """Return the span's end index as a string."""
        end_index = span.end
        return str(end_index)
|
|
|
|
|
|
def test_issue4903():
    """Ensure that this runs correctly and doesn't hang or crash on Windows /
    macOS.

    A custom component that sets extension values runs after the sentencizer,
    with nlp.pipe using two processes; the docs must come back in input order.
    """
    nlp = English()
    custom_component = CustomPipe()
    nlp.add_pipe(nlp.create_pipe("sentencizer"))
    nlp.add_pipe(custom_component, after="sentencizer")

    texts = ["I like bananas.", "Do you like them?", "No, I prefer wasabi."]
    docs = list(nlp.pipe(texts, n_process=2))
    for position, expected in enumerate(texts):
        assert docs[position].text == expected
|
|
|
|
|
|
def test_issue4924():
    """A bare Language must evaluate without raising on an Example.

    The Example is built from an empty Doc and an empty annotation dict.
    """
    nlp = Language()
    empty_doc = nlp.make_doc("")
    examples = [Example.from_dict(empty_doc, {})]
    nlp.evaluate(examples)
|