# Mirror of https://github.com/explosion/spaCy.git
# (synced 2024-11-11 04:08:09 +03:00; 85 lines, 2.6 KiB, Python)
import pytest
|
|
|
|
from spacy import registry
|
|
from spacy.pipeline import EntityRecognizer
|
|
from spacy.pipeline.ner import DEFAULT_NER_MODEL
|
|
from spacy.tokens import Doc, Span
|
|
from spacy.training import Example
|
|
|
|
|
|
def _ner_example(ner):
    """Build one annotated training example for initializing ``ner``.

    The gold annotation uses character offsets into the space-joined
    sentence "Joe loves visiting London during the weekend":
    0-3 is "Joe" (PERSON) and 19-25 is "London" (LOC).

    Args:
        ner: The pipe being initialized; only its ``vocab`` is read here.

    Returns:
        An ``Example`` pairing an unannotated ``Doc`` with the gold entities.
    """
    tokens = ["Joe", "loves", "visiting", "London", "during", "the", "weekend"]
    annotations = {
        "entities": [
            (0, 3, "PERSON"),
            (19, 25, "LOC"),
        ]
    }
    predicted = Doc(ner.vocab, words=tokens)
    return Example.from_dict(predicted, annotations)
|
|
|
|
|
|
def test_doc_add_entities_set_ents_iob(en_vocab):
    """Assigning ``doc.ents`` after the NER pipe ran rewrites the IOB tags.

    Each assignment replaces the previous entities entirely, so a token that
    belonged to an earlier entity reads as "O" afterwards.
    """
    doc = Doc(en_vocab, words=["This", "is", "a", "lion"])
    resolved = registry.resolve({"model": DEFAULT_NER_MODEL}, validate=True)
    ner = EntityRecognizer(en_vocab, resolved["model"])
    ner.initialize(lambda: [_ner_example(ner)])
    ner(doc)

    # Entities given as (label, start, end) tuples: one token on "lion".
    doc.ents = [("ANIMAL", 3, 4)]
    iob_tags = [token.ent_iob_ for token in doc]
    assert iob_tags == ["O", "O", "O", "B"]

    # A two-token entity at the start; "lion" is no longer in any entity.
    doc.ents = [("WORD", 0, 2)]
    iob_tags = [token.ent_iob_ for token in doc]
    assert iob_tags == ["B", "I", "O", "O"]
|
|
|
|
|
|
def test_ents_reset(en_vocab):
    """Re-assigning ``doc.ents`` to its own current spans keeps the IOB tags."""
    doc = Doc(en_vocab, words=["This", "is", "a", "lion"])
    resolved = registry.resolve({"model": DEFAULT_NER_MODEL}, validate=True)
    ner = EntityRecognizer(en_vocab, resolved["model"])
    ner.initialize(lambda: [_ner_example(ner)])
    ner(doc)

    # Snapshot whatever tags the pipe produced, write the same spans back,
    # and check the round-trip is a no-op.
    before = [token.ent_iob_ for token in doc]
    doc.ents = list(doc.ents)
    after = [token.ent_iob_ for token in doc]
    assert after == before
|
|
|
|
|
|
def _assert_ents_cleared(doc, ent_iob):
    """Assert every token in ``doc`` has IOB code ``ent_iob`` and no entity data.

    ``ent_type``, ``ent_id`` and ``ent_kb_id`` must all be reset to 0.
    """
    for token in doc:
        assert token.ent_iob == ent_iob
        assert token.ent_type == 0
        assert token.ent_id == 0
        assert token.ent_kb_id == 0


def test_ents_clear(en_vocab):
    """Ensure that removing entities clears token attributes.

    Each clearing strategy is applied to a doc that currently holds a
    labelled entity with a span ID. Every check therefore really exercises
    the reset of the per-token entity attributes, rather than inspecting a
    doc that an earlier step already cleared.

    spaCy's integer ``Token.ent_iob`` codes are 0 = no tag set,
    2 = outside and 3 = begin.
    """
    text = ["Louisiana", "Office", "of", "Conservation"]
    doc = Doc(en_vocab, words=text)
    # Whole-doc entity with an integer label and a span ID, so ent_type and
    # ent_id are non-zero before each clearing step.
    entity = Span(doc, 0, 4, label=391, span_id="TEST")

    # Assigning an empty list marks every token as outside (code 2).
    doc.ents = [entity]
    doc.ents = []
    _assert_ents_cleared(doc, 2)

    # default="missing" leaves every token without an entity tag (code 0).
    doc.ents = [entity]
    doc.set_ents([], default="missing")
    _assert_ents_cleared(doc, 0)

    # default="blocked" marks every token as blocked (code 3). Re-populate
    # first so this step also starts from a doc that has an entity.
    doc.ents = [entity]
    doc.set_ents([], default="blocked")
    _assert_ents_cleared(doc, 3)
|
|
|
|
|
|
def test_add_overlapping_entities(en_vocab):
    """Setting entities that overlap an existing one raises ``ValueError``."""
    doc = Doc(en_vocab, words=["Louisiana", "Office", "of", "Conservation"])
    doc.ents = [Span(doc, 0, 4, label=391)]

    # Shares token 0 with the whole-doc entity, so the combined list conflicts.
    overlapping = Span(doc, 0, 1, label=392)
    with pytest.raises(ValueError):
        doc.ents = [*doc.ents, overlapping]
|