mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Preserve existing ENT_KB_ID annotation in NER (#7988)
* Preserve existing ENT_KB_ID annotation in NER

  Preserve `ent_kb_id` annotation on existing entity spans, which is not preserved by the transition system.

* Simplify kb_id assignment
* Simplify further
This commit is contained in:
parent
02a6a5fea0
commit
6788d90f61
|
@@ -247,7 +247,7 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
for i in range(state.c._ents.size()):
|
||||
ent = state.c._ents.at(i)
|
||||
if ent.start != -1 and ent.end != -1:
|
||||
ents.append(Span(doc, ent.start, ent.end, label=ent.label))
|
||||
ents.append(Span(doc, ent.start, ent.end, label=ent.label, kb_id=doc.c[ent.start].ent_kb_id))
|
||||
doc.set_ents(ents, default="unmodified")
|
||||
# Set non-blocked tokens to O
|
||||
for i in range(doc.length):
|
||||
|
|
|
@@ -8,7 +8,7 @@ from spacy.language import Language
|
|||
from spacy.lookups import Lookups
|
||||
from spacy.pipeline._parser_internals.ner import BiluoPushDown
|
||||
from spacy.training import Example
|
||||
from spacy.tokens import Doc
|
||||
from spacy.tokens import Doc, Span
|
||||
from spacy.vocab import Vocab
|
||||
import logging
|
||||
|
||||
|
@@ -358,6 +358,26 @@ def test_overfitting_IO(use_upper):
|
|||
assert_equal(batch_deps_1, batch_deps_2)
|
||||
assert_equal(batch_deps_1, no_batch_deps)
|
||||
|
||||
# test that kb_id is preserved
|
||||
test_text = "I like London and London."
|
||||
doc = nlp.make_doc(test_text)
|
||||
doc.ents = [Span(doc, 2, 3, label="LOC", kb_id=1234)]
|
||||
ents = doc.ents
|
||||
assert len(ents) == 1
|
||||
assert ents[0].text == "London"
|
||||
assert ents[0].label_ == "LOC"
|
||||
assert ents[0].kb_id == 1234
|
||||
doc = nlp.get_pipe("ner")(doc)
|
||||
ents = doc.ents
|
||||
assert len(ents) == 2
|
||||
assert ents[0].text == "London"
|
||||
assert ents[0].label_ == "LOC"
|
||||
assert ents[0].kb_id == 1234
|
||||
# ent added by ner has kb_id == 0
|
||||
assert ents[1].text == "London"
|
||||
assert ents[1].label_ == "LOC"
|
||||
assert ents[1].kb_id == 0
|
||||
|
||||
|
||||
def test_beam_ner_scores():
|
||||
# Test that we can get confidence values out of the beam_ner pipe
|
||||
|
|
Loading…
Reference in New Issue
Block a user