diff --git a/spacy/matcher.pyx b/spacy/matcher.pyx index f6f3f95d6..b1b77e162 100644 --- a/spacy/matcher.pyx +++ b/spacy/matcher.pyx @@ -97,4 +97,5 @@ cdef class Matcher: matches.append(get_entity(state, token, token_i)) else: partials.push_back(state + 1) + doc.ents = list(sorted(list(doc.ents) + matches)) return matches diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx index fbd580b29..8fa4a03d5 100644 --- a/spacy/syntax/ner.pyx +++ b/spacy/syntax/ner.pyx @@ -160,7 +160,17 @@ cdef class Missing: cdef class Begin: @staticmethod cdef bint is_valid(StateClass st, int label) nogil: - return label != 0 and not st.entity_is_open() + # Ensure we don't clobber preset entities. If no entity preset, + # ent_iob is 0 + cdef int preset_ent_iob = st.B_(0).ent_iob + if preset_ent_iob == 1: + return False + elif preset_ent_iob == 2: + return False + elif preset_ent_iob == 3 and st.B_(0).ent_type != label: + return False + else: + return label != 0 and not st.entity_is_open() @staticmethod cdef int transition(StateClass st, int label) nogil: @@ -190,6 +200,14 @@ cdef class Begin: cdef class In: @staticmethod cdef bint is_valid(StateClass st, int label) nogil: + cdef int preset_ent_iob = st.B_(0).ent_iob + if preset_ent_iob == 2: + return False + elif preset_ent_iob == 3: + return False + # TODO: Is this quite right? + elif st.B_(1).ent_iob != preset_ent_iob: + return False return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label @staticmethod @@ -230,6 +248,14 @@ cdef class In: cdef class Last: @staticmethod cdef bint is_valid(StateClass st, int label) nogil: + cdef int preset_ent_iob = st.B_(0).ent_iob + if preset_ent_iob == 2: + return False + elif preset_ent_iob == 3: + return False + elif st.B_(1).ent_iob == 1: + return False + return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label @staticmethod @@ -269,6 +295,13 @@ cdef class Last: cdef class Unit: @staticmethod cdef bint is_valid(StateClass st, int label) nogil: + cdef int preset_ent_iob = st.B_(0).ent_iob + if preset_ent_iob == 2: + return False + elif preset_ent_iob == 1: + return False + elif st.B_(1).ent_iob == 1: + return False return label != 0 and not st.entity_is_open() @staticmethod @@ -300,6 +333,11 @@ cdef class Unit: cdef class Out: @staticmethod cdef bint is_valid(StateClass st, int label) nogil: + cdef int preset_ent_iob = st.B_(0).ent_iob + if preset_ent_iob == 3: + return False + elif preset_ent_iob == 1: + return False return not st.entity_is_open() @staticmethod diff --git a/spacy/syntax/stateclass.pxd b/spacy/syntax/stateclass.pxd index 905d8cdde..888b01c32 100644 --- a/spacy/syntax/stateclass.pxd +++ b/spacy/syntax/stateclass.pxd @@ -125,7 +125,7 @@ cdef class StateClass: cdef void add_arc(self, int head, int child, int label) nogil cdef void del_arc(self, int head, int child) nogil - + cdef void open_ent(self, int label) nogil cdef void close_ent(self) nogil diff --git a/spacy/tokens/doc.pxd b/spacy/tokens/doc.pxd index 7de5e0bea..121018770 100644 --- a/spacy/tokens/doc.pxd +++ b/spacy/tokens/doc.pxd @@ -4,6 +4,11 @@ from preshed.counter cimport PreshCounter from ..vocab cimport Vocab from ..structs cimport TokenC, LexemeC +from ..typedefs cimport attr_t +from ..attrs cimport attr_id_t + + +cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil ctypedef const LexemeC* const_Lexeme_ptr diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index a3ae45733..6d0cd9a8b 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -119,40 +119,67 @@ cdef class Doc: def string(self): return u''.join([t.string for t in self]) - @property - def ents(self): - """Yields named-entity Span objects. + property ents: + def __get__(self): + """Yields named-entity Span objects. - Iterate over the span to get individual Token objects, or access the label: + Iterate over the span to get individual Token objects, or access the label: - >>> from spacy.en import English - >>> nlp = English() - >>> tokens = nlp(u'Mr. Best flew to New York on Saturday morning.') - >>> ents = list(tokens.ents) - >>> ents[0].label, ents[0].label_, ''.join(t.orth_ for t in ents[0]) - (112504, u'PERSON', u'Best ') - """ - cdef int i - cdef const TokenC* token - cdef int start = -1 - cdef int label = 0 - for i in range(self.length): - token = &self.data[i] - if token.ent_iob == 1: - assert start != -1 - pass - elif token.ent_iob == 2: - if start != -1: - yield Span(self, start, i, label=label) - start = -1 - label = 0 - elif token.ent_iob == 3: - if start != -1: - yield Span(self, start, i, label=label) - start = i - label = token.ent_type - if start != -1: - yield Span(self, start, self.length, label=label) + >>> from spacy.en import English + >>> nlp = English() + >>> tokens = nlp(u'Mr. Best flew to New York on Saturday morning.') + >>> ents = list(tokens.ents) + >>> ents[0].label, ents[0].label_, ''.join(t.orth_ for t in ents[0]) + (112504, u'PERSON', u'Best ') + """ + cdef int i + cdef const TokenC* token + cdef int start = -1 + cdef int label = 0 + output = [] + for i in range(self.length): + token = &self.data[i] + if token.ent_iob == 1: + assert start != -1 + elif token.ent_iob == 2 or token.ent_iob == 0: + if start != -1: + output.append(Span(self, start, i, label=label)) + start = -1 + label = 0 + elif token.ent_iob == 3: + if start != -1: + output.append(Span(self, start, i, label=label)) + start = i + label = token.ent_type + if start != -1: + output.append(Span(self, start, self.length, label=label)) + return tuple(output) + + def __set__(self, ents): + # TODO: + # 1. Allow negative matches + # 2. Ensure pre-set NERs are not over-written during statistical prediction + # 3. Test basic data-driven ORTH gazetteer + # 4. Test more nuanced date and currency regex + cdef int i + for i in range(self.length): + self.data[i].ent_type = 0 + self.data[i].ent_iob = 0 + cdef attr_t ent_type + cdef int start, end + for ent_type, start, end in ents: + if ent_type is None: + # Mark as O + for i in range(start, end): + self.data[i].ent_type = 0 + self.data[i].ent_iob = 2 + else: + # Mark (inside) as I + for i in range(start, end): + self.data[i].ent_type = ent_type + self.data[i].ent_iob = 1 + # Set start as B + self.data[start].ent_iob = 3 @property def noun_chunks(self): diff --git a/tests/spans/test_merge.py b/tests/spans/test_merge.py index 3bba13064..e225db043 100644 --- a/tests/spans/test_merge.py +++ b/tests/spans/test_merge.py @@ -1,7 +1,6 @@ from __future__ import unicode_literals import pytest - @pytest.mark.models def test_merge_tokens(EN): tokens = EN(u'Los Angeles start.') @@ -32,3 +31,21 @@ def test_merge_heads(EN): def test_issue_54(EN): text = u'Talks given by women had a slightly higher number of questions asked (3.2$\pm$0.2) than talks given by men (2.6$\pm$0.1).' tokens = EN(text, merge_mwes=True) + +@pytest.mark.models +def test_np_merges(EN): + text = u'displaCy is a parse tool built with Javascript' + tokens = EN(text) + assert tokens[4].head.i == 1 + tokens.merge(tokens[2].idx, tokens[4].idx + len(tokens[4]), u'NP', u'tool', u'O') + assert tokens[2].head.i == 1 + tokens = EN('displaCy is a lightweight and modern dependency parse tree visualization tool built with CSS3 and JavaScript.') + + ents = [(e[0].idx, e[-1].idx + len(e[-1]), e.label_, e.lemma_) + for e in tokens.ents] + for start, end, label, lemma in ents: + merged = tokens.merge(start, end, label, lemma, label) + assert merged != None, (start, end, label, lemma) + for tok in tokens: + print tok.orth_, tok.dep_, tok.head.orth_ + diff --git a/tests/test_matcher.py b/tests/test_matcher.py index fb3665623..06950253c 100644 --- a/tests/test_matcher.py +++ b/tests/test_matcher.py @@ -47,5 +47,14 @@ def test_match_multi(matcher, EN): assert matcher(tokens) == [(EN.vocab.strings['product'], 2, 4), (EN.vocab.strings['product'], 5, 6)] -def test_dummy(): - pass +def test_match_preserved(matcher, EN): + doc = EN.tokenizer('I like Java') + EN.tagger(doc) + EN.entity(doc) + assert len(doc.ents) == 0 + doc = EN.tokenizer('I like Java') + matcher(doc) + assert len(doc.ents) == 1 + EN.tagger(doc) + EN.entity(doc) + assert len(doc.ents) == 1 diff --git a/tests/tokens/test_tokens_api.py b/tests/tokens/test_tokens_api.py index b935bbce7..e1238373f 100644 --- a/tests/tokens/test_tokens_api.py +++ b/tests/tokens/test_tokens_api.py @@ -4,7 +4,6 @@ from spacy.tokens import Doc import pytest - @pytest.mark.models def test_getitem(EN): tokens = EN(u'Give it back! He pleaded.') @@ -32,3 +31,15 @@ def test_serialize_whitespace(EN): assert tokens.string == new_tokens.string assert [t.orth_ for t in tokens] == [t.orth_ for t in new_tokens] assert [t.orth for t in tokens] == [t.orth for t in new_tokens] + + +def test_set_ents(EN): + tokens = EN.tokenizer(u'I use goggle chrone to surf the web') + assert len(tokens.ents) == 0 + tokens.ents = [(EN.vocab.strings['PRODUCT'], 2, 4)] + assert len(list(tokens.ents)) == 1 + assert [t.ent_iob for t in tokens] == [0, 0, 3, 1, 0, 0, 0, 0] + ent = tokens.ents[0] + assert ent.label_ == 'PRODUCT' + assert ent.start == 2 + assert ent.end == 4