diff --git a/spacy/syntax/_parser_model.pyx b/spacy/syntax/_parser_model.pyx index 30d4b67d3..f664e6a2c 100644 --- a/spacy/syntax/_parser_model.pyx +++ b/spacy/syntax/_parser_model.pyx @@ -157,6 +157,10 @@ cdef void cpu_log_loss(float* d_scores, cdef double max_, gmax, Z, gZ best = arg_max_if_gold(scores, costs, is_valid, O) guess = arg_max_if_valid(scores, is_valid, O) + if best == -1 or guess == -1: + # These shouldn't happen, but if they do, we want to make sure we don't + # cause an OOB access. + return Z = 1e-10 gZ = 1e-10 max_ = scores[guess] diff --git a/spacy/syntax/_state.pxd b/spacy/syntax/_state.pxd index d082cee5c..204f723d8 100644 --- a/spacy/syntax/_state.pxd +++ b/spacy/syntax/_state.pxd @@ -323,6 +323,12 @@ cdef cppclass StateC: if this._s_i >= 1: this._s_i -= 1 + void force_final() nogil: + # This should only be used in desperate situations, as it may leave + # the analysis in an unexpected state. + this._s_i = 0 + this._b_i = this.length + void unshift() nogil: this._b_i -= 1 this._buffer[this._b_i] = this.S(0) diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx index b43a879d4..804167b0e 100644 --- a/spacy/syntax/ner.pyx +++ b/spacy/syntax/ner.pyx @@ -257,30 +257,42 @@ cdef class Missing: cdef class Begin: @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: - # Ensure we don't clobber preset entities. If no entity preset, - # ent_iob is 0 cdef int preset_ent_iob = st.B_(0).ent_iob - if preset_ent_iob == 1: + cdef int preset_ent_label = st.B_(0).ent_type + # If we're the last token of the input, we can't B -- must U or O. + if st.B(1) == -1: return False - elif preset_ent_iob == 2: + elif st.entity_is_open(): return False - elif preset_ent_iob == 3 and st.B_(0).ent_type != label: + elif label == 0: return False - # If the next word is B or O, we can't B now + elif preset_ent_iob == 1 or preset_ent_iob == 2: + # Ensure we don't clobber preset entities. If no entity preset, + # ent_iob is 0 + return False + elif preset_ent_iob == 3: + # Okay, we're in a preset entity. + if label != preset_ent_label: + # If label isn't right, reject + return False + elif st.B_(1).ent_iob != 1: + # If next token isn't marked I, we need to make U, not B. + return False + else: + # Otherwise, force acceptance, even if we're across a sentence + # boundary or the token is whitespace. + return True elif st.B_(1).ent_iob == 2 or st.B_(1).ent_iob == 3: + # If the next word is B or O, we can't B now return False - # If the current word is B, and the next word isn't I, the current word - # is really U - elif preset_ent_iob == 3 and st.B_(1).ent_iob != 1: - return False - # Don't allow entities to extend across sentence boundaries elif st.B_(1).sent_start == 1: + # Don't allow entities to extend across sentence boundaries return False # Don't allow entities to start on whitespace elif Lexeme.get_struct_attr(st.B_(0).lex, IS_SPACE): return False else: - return label != 0 and not st.entity_is_open() + return True @staticmethod cdef int transition(StateC* st, attr_t label) nogil: @@ -314,18 +326,27 @@ cdef class In: @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: cdef int preset_ent_iob = st.B_(0).ent_iob - if preset_ent_iob == 2: + if label == 0: + return False + elif st.E_(0).ent_type != label: + return False + elif not st.entity_is_open(): + return False + elif st.B(1) == -1: + # If we're at the end, we can't I. + return False + elif preset_ent_iob == 2: return False elif preset_ent_iob == 3: return False - # TODO: Is this quite right? I think it's supposed to be ensuring the - # gazetteer matches are maintained - elif st.B(1) != -1 and st.B_(1).ent_iob != preset_ent_iob: + elif st.B_(1).ent_iob == 2 or st.B_(1).ent_iob == 3: + # If we know the next word is B or O, we can't be I (must be L) return False - # Don't allow entities to extend across sentence boundaries elif st.B(1) != -1 and st.B_(1).sent_start == 1: + # Don't allow entities to extend across sentence boundaries return False - return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label + else: + return True @staticmethod cdef int transition(StateC* st, attr_t label) nogil: @@ -370,9 +391,17 @@ cdef class In: cdef class Last: @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: - if st.B_(1).ent_iob == 1: + if label == 0: return False - return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label + elif not st.entity_is_open(): + return False + elif st.E_(0).ent_type != label: + return False + elif st.B_(1).ent_iob == 1: + # If a preset entity has I next, we can't L here. + return False + else: + return True @staticmethod cdef int transition(StateC* st, attr_t label) nogil: @@ -416,17 +445,29 @@ cdef class Unit: @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: cdef int preset_ent_iob = st.B_(0).ent_iob - if preset_ent_iob == 2: + cdef attr_t preset_ent_label = st.B_(0).ent_type + if label == 0: return False - elif preset_ent_iob == 1: + elif st.entity_is_open(): return False - elif preset_ent_iob == 3 and st.B_(0).ent_type != label: + elif preset_ent_iob == 2: + # Don't clobber preset O return False elif st.B_(1).ent_iob == 1: + # If next token is In, we can't be Unit -- must be Begin return False + elif preset_ent_iob == 3: + # Okay, there's a preset entity here + if label != preset_ent_label: + # Require labels to match + return False + else: + # Otherwise return True, ignoring the whitespace constraint. + return True elif Lexeme.get_struct_attr(st.B_(0).lex, IS_SPACE): return False - return label != 0 and not st.entity_is_open() + else: + return True @staticmethod cdef int transition(StateC* st, attr_t label) nogil: @@ -461,11 +502,14 @@ cdef class Out: @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: cdef int preset_ent_iob = st.B_(0).ent_iob - if preset_ent_iob == 3: + if st.entity_is_open(): + return False + elif preset_ent_iob == 3: return False elif preset_ent_iob == 1: return False - return not st.entity_is_open() + else: + return True @staticmethod cdef int transition(StateC* st, attr_t label) nogil: diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index ee9d0ee7e..0009eba72 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -363,9 +363,14 @@ cdef class Parser: for i in range(batch_size): self.moves.set_valid(is_valid, states[i]) guess = arg_max_if_valid(&scores[i*nr_class], is_valid, nr_class) - action = self.moves.c[guess] - action.do(states[i], action.label) - states[i].push_hist(guess) + if guess == -1: + # This shouldn't happen, but it's hard to raise an error here, + # and we don't want to infinite loop. So, force to end state. + states[i].force_final() + else: + action = self.moves.c[guess] + action.do(states[i], action.label) + states[i].push_hist(guess) free(is_valid) def transition_beams(self, beams, float[:, ::1] scores): diff --git a/spacy/tests/regression/test_issue3345.py b/spacy/tests/regression/test_issue3345.py index 7b1f41fbf..8a7823d96 100644 --- a/spacy/tests/regression/test_issue3345.py +++ b/spacy/tests/regression/test_issue3345.py @@ -1,5 +1,8 @@ -"""Test interaction between preset entities and sentence boundaries in NER.""" -import spacy +# coding: utf8 +from __future__ import unicode_literals + +import pytest +from spacy.lang.en import English from spacy.tokens import Doc from spacy.pipeline import EntityRuler, EntityRecognizer @@ -7,7 +10,7 @@ from spacy.pipeline import EntityRuler, EntityRecognizer @pytest.mark.xfail def test_issue3345(): """Test case where preset entity crosses sentence boundary.""" - nlp = spacy.blank("en") + nlp = English() doc = Doc(nlp.vocab, words=["I", "live", "in", "New", "York"]) doc[4].is_sent_start = True