* Gazetteer stuff working, now need to wire up to API

This commit is contained in:
Matthew Honnibal 2015-08-06 00:35:40 +02:00
parent 47db3067a0
commit 9c1724ecae
8 changed files with 146 additions and 38 deletions

View File

@ -97,4 +97,5 @@ cdef class Matcher:
matches.append(get_entity(state, token, token_i))
else:
partials.push_back(state + 1)
doc.ents = list(sorted(list(doc.ents) + matches))
return matches

View File

@ -160,7 +160,17 @@ cdef class Missing:
cdef class Begin:
@staticmethod
cdef bint is_valid(StateClass st, int label) nogil:
    # Ensure we don't clobber preset entities. If no entity is preset,
    # ent_iob on the buffer token is 0.
    # NOTE: the flattened diff left the old one-line body
    # (`return label != 0 and not st.entity_is_open()`) ABOVE these checks,
    # which made them unreachable; the early return has been removed.
    cdef int preset_ent_iob = st.B_(0).ent_iob
    if preset_ent_iob == 1:
        # Preset as I (inside an entity): cannot begin a new one here.
        return False
    elif preset_ent_iob == 2:
        # Preset as O (outside): cannot begin an entity here.
        return False
    elif preset_ent_iob == 3 and st.B_(0).ent_type != label:
        # Preset as B with a different label: don't override it.
        return False
    else:
        return label != 0 and not st.entity_is_open()
@staticmethod
cdef int transition(StateClass st, int label) nogil:
@ -190,6 +200,14 @@ cdef class Begin:
cdef class In:
@staticmethod
cdef bint is_valid(StateClass st, int label) nogil:
    # Continuing an entity is forbidden when the buffer token is preset
    # as O (2) or as the beginning of a new entity (3).
    cdef int iob = st.B_(0).ent_iob
    if iob == 2 or iob == 3:
        return False
    # TODO: Is this quite right?
    if st.B_(1).ent_iob != iob:
        return False
    if not st.entity_is_open():
        return False
    return label != 0 and st.E_(0).ent_type == label
@staticmethod
@ -230,6 +248,14 @@ cdef class In:
cdef class Last:
@staticmethod
cdef bint is_valid(StateClass st, int label) nogil:
    # Closing the entity here is invalid if the buffer token is preset as
    # O (2) or B (3), or if the following token is preset as I (1) —
    # a trailing I would require the entity to continue past this token.
    cdef int iob = st.B_(0).ent_iob
    if iob == 2 or iob == 3:
        return False
    if st.B_(1).ent_iob == 1:
        return False
    if not (st.entity_is_open() and label != 0):
        return False
    return st.E_(0).ent_type == label
@staticmethod
@ -269,6 +295,13 @@ cdef class Last:
cdef class Unit:
@staticmethod
cdef bint is_valid(StateClass st, int label) nogil:
    # A single-token entity conflicts with a preset O (2) or I (1) on the
    # buffer token, and with a preset I (1) on the next token (which would
    # require a multi-token entity starting here).
    cdef int iob = st.B_(0).ent_iob
    if iob == 2 or iob == 1:
        return False
    if st.B_(1).ent_iob == 1:
        return False
    if label == 0:
        return False
    return not st.entity_is_open()
@staticmethod
@ -300,6 +333,11 @@ cdef class Unit:
cdef class Out:
@staticmethod
cdef bint is_valid(StateClass st, int label) nogil:
    # O is invalid on a token preset as part of an entity (B == 3 or I == 1).
    cdef int iob = st.B_(0).ent_iob
    if iob == 3 or iob == 1:
        return False
    return not st.entity_is_open()
@staticmethod

View File

@ -125,7 +125,7 @@ cdef class StateClass:
cdef void add_arc(self, int head, int child, int label) nogil
cdef void del_arc(self, int head, int child) nogil
cdef void open_ent(self, int label) nogil
cdef void close_ent(self) nogil

View File

@ -4,6 +4,11 @@ from preshed.counter cimport PreshCounter
from ..vocab cimport Vocab
from ..structs cimport TokenC, LexemeC
from ..typedefs cimport attr_t
from ..attrs cimport attr_id_t
cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil
ctypedef const LexemeC* const_Lexeme_ptr

View File

@ -119,40 +119,67 @@ cdef class Doc:
def string(self):
return u''.join([t.string for t in self])
property ents:
    # NOTE: the flattened diff interleaved the removed `@property def ents`
    # generator implementation with this added property block; the stale
    # removed lines (duplicate docstring and yield-based body) are dropped
    # here and only the post-change version is kept.
    def __get__(self):
        """Yields named-entity Span objects.

        Iterate over the span to get individual Token objects, or access the label:

        >>> from spacy.en import English
        >>> nlp = English()
        >>> tokens = nlp(u'Mr. Best flew to New York on Saturday morning.')
        >>> ents = list(tokens.ents)
        >>> ents[0].label, ents[0].label_, ''.join(t.orth_ for t in ents[0])
        (112504, u'PERSON', u'Best ')
        """
        cdef int i
        cdef const TokenC* token
        cdef int start = -1
        cdef int label = 0
        output = []
        for i in range(self.length):
            token = &self.data[i]
            if token.ent_iob == 1:
                # Inside an entity: a span must already be open.
                assert start != -1
            elif token.ent_iob == 2 or token.ent_iob == 0:
                # Outside (2) or unset (0): close any open span.
                if start != -1:
                    output.append(Span(self, start, i, label=label))
                start = -1
                label = 0
            elif token.ent_iob == 3:
                # Begin: close any open span, then open a new one here.
                if start != -1:
                    output.append(Span(self, start, i, label=label))
                start = i
                label = token.ent_type
        if start != -1:
            output.append(Span(self, start, self.length, label=label))
        return tuple(output)

    def __set__(self, ents):
        # TODO:
        # 1. Allow negative matches
        # 2. Ensure pre-set NERs are not over-written during statistical prediction
        # 3. Test basic data-driven ORTH gazetteer
        # 4. Test more nuanced date and currency regex
        cdef int i
        # Clear all existing annotation before applying the new spans.
        for i in range(self.length):
            self.data[i].ent_type = 0
            self.data[i].ent_iob = 0
        cdef attr_t ent_type
        cdef int start, end
        for ent_type, start, end in ents:
            if ent_type is None:
                # Mark as O
                for i in range(start, end):
                    self.data[i].ent_type = 0
                    self.data[i].ent_iob = 2
            else:
                # Mark (inside) as I
                for i in range(start, end):
                    self.data[i].ent_type = ent_type
                    self.data[i].ent_iob = 1
                # Set start as B
                self.data[start].ent_iob = 3
@property
def noun_chunks(self):

View File

@ -1,7 +1,6 @@
from __future__ import unicode_literals
import pytest
@pytest.mark.models
def test_merge_tokens(EN):
tokens = EN(u'Los Angeles start.')
@ -32,3 +31,21 @@ def test_merge_heads(EN):
def test_issue_54(EN):
    # Smoke test: processing text containing TeX-style "$\pm$" sequences with
    # merge_mwes enabled must complete without raising. No value assertions —
    # presumably a regression test for issue #54; confirm against the tracker.
    text = u'Talks given by women had a slightly higher number of questions asked (3.2$\pm$0.2) than talks given by men (2.6$\pm$0.1).'
    tokens = EN(text, merge_mwes=True)
@pytest.mark.models
def test_np_merges(EN):
    # Merging an NP span should leave the merged token attached to the
    # original head of the span.
    text = u'displaCy is a parse tool built with Javascript'
    tokens = EN(text)
    assert tokens[4].head.i == 1
    tokens.merge(tokens[2].idx, tokens[4].idx + len(tokens[4]), u'NP', u'tool', u'O')
    assert tokens[2].head.i == 1

    # Merging every entity span in turn should succeed for each span.
    tokens = EN('displaCy is a lightweight and modern dependency parse tree visualization tool built with CSS3 and JavaScript.')
    ents = [(e[0].idx, e[-1].idx + len(e[-1]), e.label_, e.lemma_)
            for e in tokens.ents]
    for start, end, label, lemma in ents:
        merged = tokens.merge(start, end, label, lemma, label)
        # Use `is not None` (PEP 8) rather than `!= None`; the trailing
        # py2-only `print` debug loop has been removed.
        assert merged is not None, (start, end, label, lemma)

View File

@ -47,5 +47,14 @@ def test_match_multi(matcher, EN):
assert matcher(tokens) == [(EN.vocab.strings['product'], 2, 4),
(EN.vocab.strings['product'], 5, 6)]
def test_dummy():
    # No-op placeholder test; intent not recorded in this file.
    pass
def test_match_preserved(matcher, EN):
    # Baseline: the statistical pipeline alone finds no entities in this text.
    doc = EN.tokenizer('I like Java')
    EN.tagger(doc)
    EN.entity(doc)
    assert len(doc.ents) == 0
    # The matcher sets exactly one entity on a fresh doc...
    doc = EN.tokenizer('I like Java')
    matcher(doc)
    assert len(doc.ents) == 1
    # ...and running the tagger and entity recognizer afterwards must
    # preserve that preset entity rather than clobber it.
    EN.tagger(doc)
    EN.entity(doc)
    assert len(doc.ents) == 1

View File

@ -4,7 +4,6 @@ from spacy.tokens import Doc
import pytest
@pytest.mark.models
def test_getitem(EN):
tokens = EN(u'Give it back! He pleaded.')
@ -32,3 +31,15 @@ def test_serialize_whitespace(EN):
assert tokens.string == new_tokens.string
assert [t.orth_ for t in tokens] == [t.orth_ for t in new_tokens]
assert [t.orth for t in tokens] == [t.orth for t in new_tokens]
def test_set_ents(EN):
    # Assigning to doc.ents should write B/I ent_iob codes onto the tokens
    # and round-trip through the .ents property.
    tokens = EN.tokenizer(u'I use goggle chrone to surf the web')
    assert len(tokens.ents) == 0
    tokens.ents = [(EN.vocab.strings['PRODUCT'], 2, 4)]
    # Dropped the redundant list() wrapper — .ents already has a length,
    # as the assertion two lines above relies on.
    assert len(tokens.ents) == 1
    # Tokens 2-3 are marked B (3) then I (1); the rest stay unset (0).
    assert [t.ent_iob for t in tokens] == [0, 0, 3, 1, 0, 0, 0, 0]
    ent = tokens.ents[0]
    assert ent.label_ == 'PRODUCT'
    assert ent.start == 2
    assert ent.end == 4