Fix NER when preset entities cross sentence boundaries (#3379)

💫 Fix NER when preset entities cross sentence boundaries
This commit is contained in:
Matthew Honnibal 2019-03-10 14:53:03 +01:00 committed by Ines Montani
parent 3fe5811fa7
commit a5b1f6dcec
5 changed files with 94 additions and 32 deletions

View File

@ -157,6 +157,10 @@ cdef void cpu_log_loss(float* d_scores,
cdef double max_, gmax, Z, gZ
best = arg_max_if_gold(scores, costs, is_valid, O)
guess = arg_max_if_valid(scores, is_valid, O)
if best == -1 or guess == -1:
# These shouldn't happen, but if they do, we want to make sure we don't
# cause an OOB access.
return
Z = 1e-10
gZ = 1e-10
max_ = scores[guess]

View File

@ -323,6 +323,12 @@ cdef cppclass StateC:
if this._s_i >= 1:
this._s_i -= 1
void force_final() nogil:
# This should only be used in desperate situations, as it may leave
# the analysis in an unexpected state.
this._s_i = 0
this._b_i = this.length
void unshift() nogil:
this._b_i -= 1
this._buffer[this._b_i] = this.S(0)

View File

@ -257,30 +257,42 @@ cdef class Missing:
cdef class Begin:
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
# Ensure we don't clobber preset entities. If no entity preset,
# ent_iob is 0
cdef int preset_ent_iob = st.B_(0).ent_iob
if preset_ent_iob == 1:
cdef int preset_ent_label = st.B_(0).ent_type
# If we're the last token of the input, we can't B -- must U or O.
if st.B(1) == -1:
return False
elif preset_ent_iob == 2:
elif st.entity_is_open():
return False
elif preset_ent_iob == 3 and st.B_(0).ent_type != label:
elif label == 0:
return False
# If the next word is B or O, we can't B now
elif preset_ent_iob == 1 or preset_ent_iob == 2:
# Ensure we don't clobber preset entities. If no entity preset,
# ent_iob is 0
return False
elif preset_ent_iob == 3:
# Okay, we're in a preset entity.
if label != preset_ent_label:
# If label isn't right, reject
return False
elif st.B_(1).ent_iob != 1:
# If next token isn't marked I, we need to make U, not B.
return False
else:
# Otherwise, force acceptance, even if we're across a sentence
# boundary or the token is whitespace.
return True
elif st.B_(1).ent_iob == 2 or st.B_(1).ent_iob == 3:
# If the next word is B or O, we can't B now
return False
# If the current word is B, and the next word isn't I, the current word
# is really U
elif preset_ent_iob == 3 and st.B_(1).ent_iob != 1:
return False
# Don't allow entities to extend across sentence boundaries
elif st.B_(1).sent_start == 1:
# Don't allow entities to extend across sentence boundaries
return False
# Don't allow entities to start on whitespace
elif Lexeme.get_struct_attr(st.B_(0).lex, IS_SPACE):
return False
else:
return label != 0 and not st.entity_is_open()
return True
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
@ -314,18 +326,27 @@ cdef class In:
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
cdef int preset_ent_iob = st.B_(0).ent_iob
if preset_ent_iob == 2:
if label == 0:
return False
elif st.E_(0).ent_type != label:
return False
elif not st.entity_is_open():
return False
elif st.B(1) == -1:
# If we're at the end, we can't I.
return False
elif preset_ent_iob == 2:
return False
elif preset_ent_iob == 3:
return False
# TODO: Is this quite right? I think it's supposed to be ensuring the
# gazetteer matches are maintained
elif st.B(1) != -1 and st.B_(1).ent_iob != preset_ent_iob:
elif st.B_(1).ent_iob == 2 or st.B_(1).ent_iob == 3:
# If we know the next word is B or O, we can't be I (must be L)
return False
# Don't allow entities to extend across sentence boundaries
elif st.B(1) != -1 and st.B_(1).sent_start == 1:
# Don't allow entities to extend across sentence boundaries
return False
return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label
else:
return True
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
@ -370,9 +391,17 @@ cdef class In:
cdef class Last:
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
if st.B_(1).ent_iob == 1:
if label == 0:
return False
return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label
elif not st.entity_is_open():
return False
elif st.E_(0).ent_type != label:
return False
elif st.B_(1).ent_iob == 1:
# If a preset entity has I next, we can't L here.
return False
else:
return True
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
@ -416,17 +445,29 @@ cdef class Unit:
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
cdef int preset_ent_iob = st.B_(0).ent_iob
if preset_ent_iob == 2:
cdef attr_t preset_ent_label = st.B_(0).ent_type
if label == 0:
return False
elif preset_ent_iob == 1:
elif st.entity_is_open():
return False
elif preset_ent_iob == 3 and st.B_(0).ent_type != label:
elif preset_ent_iob == 2:
# Don't clobber preset O
return False
elif st.B_(1).ent_iob == 1:
# If next token is In, we can't be Unit -- must be Begin
return False
elif preset_ent_iob == 3:
# Okay, there's a preset entity here
if label != preset_ent_label:
# Require labels to match
return False
else:
# Otherwise return True, ignoring the whitespace constraint.
return True
elif Lexeme.get_struct_attr(st.B_(0).lex, IS_SPACE):
return False
return label != 0 and not st.entity_is_open()
else:
return True
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
@ -461,11 +502,14 @@ cdef class Out:
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
cdef int preset_ent_iob = st.B_(0).ent_iob
if preset_ent_iob == 3:
if st.entity_is_open():
return False
elif preset_ent_iob == 3:
return False
elif preset_ent_iob == 1:
return False
return not st.entity_is_open()
else:
return True
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:

View File

@ -363,9 +363,14 @@ cdef class Parser:
for i in range(batch_size):
self.moves.set_valid(is_valid, states[i])
guess = arg_max_if_valid(&scores[i*nr_class], is_valid, nr_class)
action = self.moves.c[guess]
action.do(states[i], action.label)
states[i].push_hist(guess)
if guess == -1:
# This shouldn't happen, but it's hard to raise an error here,
# and we don't want to infinite loop. So, force to end state.
states[i].force_final()
else:
action = self.moves.c[guess]
action.do(states[i], action.label)
states[i].push_hist(guess)
free(is_valid)
def transition_beams(self, beams, float[:, ::1] scores):

View File

@ -1,5 +1,8 @@
"""Test interaction between preset entities and sentence boundaries in NER."""
import spacy
# coding: utf8
from __future__ import unicode_literals
import pytest
from spacy.lang.en import English
from spacy.tokens import Doc
from spacy.pipeline import EntityRuler, EntityRecognizer
@ -7,7 +10,7 @@ from spacy.pipeline import EntityRuler, EntityRecognizer
@pytest.mark.xfail
def test_issue3345():
"""Test case where preset entity crosses sentence boundary."""
nlp = spacy.blank("en")
nlp = English()
doc = Doc(nlp.vocab, words=["I", "live", "in", "New", "York"])
doc[4].is_sent_start = True