mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Initial, limited support for quantified patterns in Matcher, and tracking of ent_id attribute in Token and Span. The quantifiers need a lot more testing, and there are some known problems. The main known problem is that the zero-plus and one-plus quantifiers won't work if a token can match both the quantified pattern expression AND the tail of the match.
This commit is contained in:
parent
2735b6247b
commit
58e83fe34b
|
@ -12,9 +12,11 @@ from .lexeme cimport Lexeme
|
|||
from cymem.cymem cimport Pool
|
||||
from preshed.maps cimport PreshMap
|
||||
from libcpp.vector cimport vector
|
||||
from libcpp.pair cimport pair
|
||||
from murmurhash.mrmr cimport hash64
|
||||
from libc.stdint cimport int32_t
|
||||
|
||||
from .attrs cimport LENGTH, ENT_TYPE, ORTH, NORM, LEMMA, LOWER, SHAPE
|
||||
from .attrs cimport ID, LENGTH, ENT_TYPE, ORTH, NORM, LEMMA, LOWER, SHAPE
|
||||
from . import attrs
|
||||
from .tokens.doc cimport get_token_attr
|
||||
from .tokens.doc cimport Doc
|
||||
|
@ -59,58 +61,96 @@ except ImportError:
|
|||
import json
|
||||
|
||||
|
||||
cdef struct AttrValue:
|
||||
cpdef enum quantifier_t:
|
||||
_META
|
||||
ONE
|
||||
ZERO
|
||||
ZERO_ONE
|
||||
ZERO_PLUS
|
||||
|
||||
|
||||
cdef enum action_t:
|
||||
REJECT
|
||||
ADVANCE
|
||||
REPEAT
|
||||
ACCEPT
|
||||
ADVANCE_ZERO
|
||||
PANIC
|
||||
|
||||
|
||||
cdef struct AttrValueC:
|
||||
attr_id_t attr
|
||||
attr_t value
|
||||
|
||||
|
||||
cdef struct Pattern:
|
||||
AttrValue* spec
|
||||
int length
|
||||
cdef struct TokenPatternC:
|
||||
AttrValueC* attrs
|
||||
int32_t nr_attr
|
||||
quantifier_t quantifier
|
||||
|
||||
|
||||
cdef Pattern* init_pattern(Pool mem, object token_specs, attr_t entity_type) except NULL:
|
||||
pattern = <Pattern*>mem.alloc(len(token_specs) + 1, sizeof(Pattern))
|
||||
ctypedef TokenPatternC* TokenPatternC_ptr
|
||||
ctypedef pair[int, TokenPatternC_ptr] StateC
|
||||
|
||||
|
||||
cdef TokenPatternC* init_pattern(Pool mem, object token_specs, attr_t entity_id,
|
||||
attr_t entity_type) except NULL:
|
||||
pattern = <TokenPatternC*>mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC))
|
||||
cdef int i
|
||||
for i, spec in enumerate(token_specs):
|
||||
pattern[i].spec = <AttrValue*>mem.alloc(len(spec), sizeof(AttrValue))
|
||||
pattern[i].length = len(spec)
|
||||
for i, (quantifier, spec) in enumerate(token_specs):
|
||||
pattern[i].quantifier = quantifier
|
||||
pattern[i].attrs = <AttrValueC*>mem.alloc(len(spec), sizeof(AttrValueC))
|
||||
pattern[i].nr_attr = len(spec)
|
||||
for j, (attr, value) in enumerate(spec):
|
||||
pattern[i].spec[j].attr = attr
|
||||
pattern[i].spec[j].value = value
|
||||
pattern[i].attrs[j].attr = attr
|
||||
pattern[i].attrs[j].value = value
|
||||
i = len(token_specs)
|
||||
pattern[i].spec = <AttrValue*>mem.alloc(2, sizeof(AttrValue))
|
||||
pattern[i].spec[0].attr = ENT_TYPE
|
||||
pattern[i].spec[0].value = entity_type
|
||||
pattern[i].spec[1].attr = LENGTH
|
||||
pattern[i].spec[1].value = len(token_specs)
|
||||
pattern[i].length = 0
|
||||
pattern[i].attrs = <AttrValueC*>mem.alloc(3, sizeof(AttrValueC))
|
||||
pattern[i].attrs[0].attr = ID
|
||||
pattern[i].attrs[0].value = entity_id
|
||||
pattern[i].attrs[1].attr = ENT_TYPE
|
||||
pattern[i].attrs[1].value = entity_type
|
||||
pattern[i].nr_attr = 0
|
||||
return pattern
|
||||
|
||||
|
||||
cdef int match(const Pattern* pattern, const TokenC* token) except -1:
|
||||
cdef int i
|
||||
for i in range(pattern.length):
|
||||
if get_token_attr(token, pattern.spec[i].attr) != pattern.spec[i].value:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
cdef int is_final(const Pattern* pattern) except -1:
|
||||
return (pattern + 1).length == 0
|
||||
|
||||
|
||||
cdef object get_entity(const Pattern* pattern, const TokenC* tokens, int i):
|
||||
pattern += 1
|
||||
i += 1
|
||||
return (pattern.spec[0].value, i - pattern.spec[1].value, i)
|
||||
cdef int get_action(const TokenPatternC* pattern, const TokenC* token) nogil:
|
||||
for attr in pattern.attrs[:pattern.nr_attr]:
|
||||
if get_token_attr(token, attr.attr) != attr.value:
|
||||
if pattern.quantifier == ONE:
|
||||
return REJECT
|
||||
elif pattern.quantifier == ZERO:
|
||||
return ACCEPT if (pattern+1).nr_attr == 0 else ADVANCE
|
||||
elif pattern.quantifier in (ZERO_ONE, ZERO_PLUS):
|
||||
return ACCEPT if (pattern+1).nr_attr == 0 else ADVANCE_ZERO
|
||||
else:
|
||||
return PANIC
|
||||
if pattern.quantifier == ZERO:
|
||||
return REJECT
|
||||
elif pattern.quantifier in (ONE, ZERO_ONE):
|
||||
return ACCEPT if (pattern+1).nr_attr == 0 else ADVANCE
|
||||
elif pattern.quantifier == ZERO_PLUS:
|
||||
return REPEAT
|
||||
else:
|
||||
return PANIC
|
||||
|
||||
|
||||
def _convert_strings(token_specs, string_store):
|
||||
converted = []
|
||||
# Support 'syntactic sugar' operator '+', as combination of ONE, ZERO_PLUS
|
||||
operators = {'!': (ZERO,), '*': (ZERO_PLUS,), '+': (ONE, ZERO_PLUS),
|
||||
'?': (ZERO_ONE,)}
|
||||
tokens = []
|
||||
op = ONE
|
||||
for spec in token_specs:
|
||||
converted.append([])
|
||||
token = []
|
||||
ops = (ONE,)
|
||||
for attr, value in spec.items():
|
||||
if isinstance(attr, basestring) and attr.upper() == 'OP':
|
||||
if value in operators:
|
||||
ops = operators[value]
|
||||
else:
|
||||
raise KeyError(
|
||||
"Unknown operator. Options: %s" % ', '.join(operators.keys()))
|
||||
if isinstance(attr, basestring):
|
||||
attr = attrs.IDS.get(attr.upper())
|
||||
if isinstance(value, basestring):
|
||||
|
@ -118,8 +158,10 @@ def _convert_strings(token_specs, string_store):
|
|||
if isinstance(value, bool):
|
||||
value = int(value)
|
||||
if attr is not None:
|
||||
converted[-1].append((attr, value))
|
||||
return converted
|
||||
token.append((attr, value))
|
||||
for op in ops:
|
||||
tokens.append((op, token))
|
||||
return tokens
|
||||
|
||||
|
||||
def get_bilou(length):
|
||||
|
@ -150,7 +192,7 @@ def get_bilou(length):
|
|||
|
||||
cdef class Matcher:
|
||||
cdef Pool mem
|
||||
cdef vector[Pattern*] patterns
|
||||
cdef vector[TokenPatternC*] patterns
|
||||
cdef readonly Vocab vocab
|
||||
cdef object _patterns
|
||||
|
||||
|
@ -189,15 +231,15 @@ cdef class Matcher:
|
|||
# entity
|
||||
for spec in specs:
|
||||
spec = _convert_strings(spec, self.vocab.strings)
|
||||
self.patterns.push_back(init_pattern(self.mem, spec, etype))
|
||||
self.patterns.push_back(init_pattern(self.mem, spec, entity_key, etype))
|
||||
|
||||
def __call__(self, Doc doc, acceptor=None):
|
||||
cdef vector[Pattern*] partials
|
||||
cdef vector[StateC] partials
|
||||
cdef int n_partials = 0
|
||||
cdef int q = 0
|
||||
cdef int i, token_i
|
||||
cdef const TokenC* token
|
||||
cdef Pattern* state
|
||||
cdef StateC state
|
||||
matches = []
|
||||
for token_i in range(doc.length):
|
||||
token = &doc.c[token_i]
|
||||
|
@ -205,27 +247,57 @@ cdef class Matcher:
|
|||
# Go over the open matches, extending or finalizing if able. Otherwise,
|
||||
# we over-write them (q doesn't advance)
|
||||
for state in partials:
|
||||
if match(state, token):
|
||||
if is_final(state):
|
||||
label, start, end = get_entity(state, token, token_i)
|
||||
if acceptor is None or acceptor(doc, label, start, end):
|
||||
matches.append((label, start, end))
|
||||
else:
|
||||
partials[q] = state + 1
|
||||
q += 1
|
||||
action = get_action(state.second, token)
|
||||
while action == ADVANCE_ZERO:
|
||||
state.second += 1
|
||||
action = get_action(state.second, token)
|
||||
if action == REPEAT:
|
||||
# Leave the state in the queue, and advance to next slot
|
||||
# (i.e. we don't overwrite -- we want to greedily match more
|
||||
# pattern.
|
||||
q += 1
|
||||
elif action == REJECT:
|
||||
pass
|
||||
elif action == ADVANCE:
|
||||
partials[q].second += 1
|
||||
q += 1
|
||||
elif action == ACCEPT:
|
||||
# TODO: What to do about patterns starting with ZERO? Need to
|
||||
# adjust the start position.
|
||||
start = state.first
|
||||
end = token_i+1
|
||||
ent_id = state.second[1].attrs[0].value
|
||||
label = state.second[1].attrs[1].value
|
||||
if acceptor is None or acceptor(doc, ent_id, label, start, end):
|
||||
matches.append((ent_id, label, start, end))
|
||||
partials.resize(q)
|
||||
# Check whether we open any new patterns on this token
|
||||
for state in self.patterns:
|
||||
if match(state, token):
|
||||
if is_final(state):
|
||||
label, start, end = get_entity(state, token, token_i)
|
||||
if acceptor is None or acceptor(doc, label, start, end):
|
||||
matches.append((label, start, end))
|
||||
else:
|
||||
partials.push_back(state + 1)
|
||||
for pattern in self.patterns:
|
||||
action = get_action(pattern, token)
|
||||
while action == ADVANCE_ZERO:
|
||||
pattern += 1
|
||||
action = get_action(pattern, token)
|
||||
if action == REPEAT:
|
||||
state.first = token_i
|
||||
state.second = pattern
|
||||
partials.push_back(state)
|
||||
elif action == ADVANCE:
|
||||
# TODO: What to do about patterns starting with ZERO? Need to
|
||||
# adjust the start position.
|
||||
state.first = token_i
|
||||
state.second = pattern + 1
|
||||
partials.push_back(state)
|
||||
elif action == ACCEPT:
|
||||
start = token_i
|
||||
end = token_i+1
|
||||
ent_id = pattern[1].attrs[0].value
|
||||
label = pattern[1].attrs[1].value
|
||||
if acceptor is None or acceptor(doc, ent_id, label, start, end):
|
||||
matches.append((ent_id, label, start, end))
|
||||
seen = set()
|
||||
filtered = []
|
||||
for label, start, end in sorted(matches, key=lambda m: (m[1], -(m[1] - m[2]))):
|
||||
for ent_id, label, start, end in sorted(matches,
|
||||
key=lambda m: (m[2],-(m[2]-m[3]))):
|
||||
if all(i in seen for i in range(start, end)):
|
||||
continue
|
||||
else:
|
||||
|
|
|
@ -29,6 +29,7 @@ cdef struct LexemeC:
|
|||
|
||||
|
||||
cdef struct Entity:
|
||||
hash_t id
|
||||
int start
|
||||
int end
|
||||
int label
|
||||
|
@ -53,4 +54,5 @@ cdef struct TokenC:
|
|||
uint32_t r_edge
|
||||
|
||||
int ent_iob
|
||||
int ent_type
|
||||
int ent_type # TODO: Is there a better way to do this? Multiple sources of truth..
|
||||
hash_t ent_id
|
||||
|
|
|
@ -6,46 +6,85 @@ from spacy.matcher import *
|
|||
from spacy.attrs import LOWER
|
||||
from spacy.tokens.doc import Doc
|
||||
from spacy.vocab import Vocab
|
||||
from spacy.en import English
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def matcher(EN):
|
||||
def matcher():
|
||||
patterns = {
|
||||
'Javascript': ['PRODUCT', {}, [[{'ORTH': 'JavaScript'}]]],
|
||||
'JS': ['PRODUCT', {}, [[{'ORTH': 'JavaScript'}]]],
|
||||
'GoogleNow': ['PRODUCT', {}, [[{'ORTH': 'Google'}, {'ORTH': 'Now'}]]],
|
||||
'Java': ['PRODUCT', {}, [[{'LOWER': 'java'}]]],
|
||||
}
|
||||
return Matcher(EN.vocab, patterns)
|
||||
return Matcher(Vocab(get_lex_attr=English.default_lex_attrs()), patterns)
|
||||
|
||||
|
||||
def test_compile(matcher):
|
||||
assert matcher.n_patterns == 3
|
||||
|
||||
|
||||
def test_no_match(matcher, EN):
|
||||
tokens = EN('I like cheese')
|
||||
assert matcher(tokens) == []
|
||||
def test_no_match(matcher):
|
||||
doc = Doc(matcher.vocab, ['I', 'like', 'cheese', '.'])
|
||||
assert matcher(doc) == []
|
||||
|
||||
|
||||
def test_match_start(matcher, EN):
|
||||
tokens = EN('JavaScript is good')
|
||||
assert matcher(tokens) == [(EN.vocab.strings['PRODUCT'], 0, 1)]
|
||||
def test_match_start(matcher):
|
||||
doc = Doc(matcher.vocab, ['JavaScript', 'is', 'good'])
|
||||
assert matcher(doc) == [(matcher.vocab.strings['JS'],
|
||||
matcher.vocab.strings['PRODUCT'], 0, 1)]
|
||||
|
||||
|
||||
def test_match_end(matcher, EN):
|
||||
tokens = EN('I like java')
|
||||
assert matcher(tokens) == [(EN.vocab.strings['PRODUCT'], 2, 3)]
|
||||
def test_match_end(matcher):
|
||||
doc = Doc(matcher.vocab, ['I', 'like', 'java'])
|
||||
assert matcher(doc) == [(doc.vocab.strings['Java'],
|
||||
doc.vocab.strings['PRODUCT'], 2, 3)]
|
||||
|
||||
|
||||
def test_match_middle(matcher, EN):
|
||||
tokens = EN('I like Google Now best')
|
||||
assert matcher(tokens) == [(EN.vocab.strings['PRODUCT'], 2, 4)]
|
||||
def test_match_middle(matcher):
|
||||
doc = Doc(matcher.vocab, ['I', 'like', 'Google', 'Now', 'best'])
|
||||
assert matcher(doc) == [(doc.vocab.strings['GoogleNow'],
|
||||
doc.vocab.strings['PRODUCT'], 2, 4)]
|
||||
|
||||
|
||||
def test_match_multi(matcher, EN):
|
||||
tokens = EN('I like Google Now and java best')
|
||||
assert matcher(tokens) == [(EN.vocab.strings['PRODUCT'], 2, 4),
|
||||
(EN.vocab.strings['PRODUCT'], 5, 6)]
|
||||
def test_match_multi(matcher):
|
||||
doc = Doc(matcher.vocab, 'I like Google Now and java best'.split())
|
||||
assert matcher(doc) == [(doc.vocab.strings['GoogleNow'],
|
||||
doc.vocab.strings['PRODUCT'], 2, 4),
|
||||
(doc.vocab.strings['Java'],
|
||||
doc.vocab.strings['PRODUCT'], 5, 6)]
|
||||
|
||||
def test_match_zero(matcher):
|
||||
matcher.add('Quote', '', {}, [
|
||||
[
|
||||
{'ORTH': '"'},
|
||||
{'OP': '!', 'IS_PUNCT': True},
|
||||
{'OP': '!', 'IS_PUNCT': True},
|
||||
{'ORTH': '"'}
|
||||
]])
|
||||
doc = Doc(matcher.vocab, 'He said , " some words " ...'.split())
|
||||
assert len(matcher(doc)) == 1
|
||||
doc = Doc(matcher.vocab, 'He said , " some three words " ...'.split())
|
||||
assert len(matcher(doc)) == 0
|
||||
matcher.add('Quote', '', {}, [
|
||||
[
|
||||
{'ORTH': '"'},
|
||||
{'IS_PUNCT': True},
|
||||
{'IS_PUNCT': True},
|
||||
{'IS_PUNCT': True},
|
||||
{'ORTH': '"'}
|
||||
]])
|
||||
assert len(matcher(doc)) == 0
|
||||
|
||||
|
||||
def test_match_zero_plus(matcher):
|
||||
matcher.add('Quote', '', {}, [
|
||||
[
|
||||
{'ORTH': '"'},
|
||||
{'OP': '*', 'IS_PUNCT': False},
|
||||
{'ORTH': '"'}
|
||||
]])
|
||||
doc = Doc(matcher.vocab, 'He said , " some words " ...'.split())
|
||||
assert len(matcher(doc)) == 1
|
||||
|
||||
|
||||
@pytest.mark.models
|
||||
|
|
|
@ -241,6 +241,27 @@ cdef class Span:
|
|||
for word in self.rights:
|
||||
yield from word.subtree
|
||||
|
||||
property ent_id:
|
||||
'''An (integer) entity ID. Usually assigned by patterns in the Matcher.'''
|
||||
def __get__(self):
|
||||
return self.root.ent_id
|
||||
|
||||
def __set__(self, hash_t key):
|
||||
# TODO
|
||||
raise NotImplementedError(
|
||||
"Can't yet set ent_id from Span. Vote for this feature on the issue "
|
||||
"tracker: http://github.com/spacy-io/spaCy")
|
||||
property ent_id_:
|
||||
'''A (string) entity ID. Usually assigned by patterns in the Matcher.'''
|
||||
def __get__(self):
|
||||
return self.root.ent_id_
|
||||
|
||||
def __set__(self, hash_t key):
|
||||
# TODO
|
||||
raise NotImplementedError(
|
||||
"Can't yet set ent_id_ from Span. Vote for this feature on the issue "
|
||||
"tracker: http://github.com/spacy-io/spaCy")
|
||||
|
||||
property orth_:
|
||||
def __get__(self):
|
||||
return ''.join([t.string for t in self]).strip()
|
||||
|
|
|
@ -5,10 +5,9 @@ from .doc cimport Doc
|
|||
|
||||
|
||||
cdef class Token:
|
||||
cdef Vocab vocab
|
||||
cdef readonly Vocab vocab
|
||||
cdef TokenC* c
|
||||
cdef readonly int i
|
||||
cdef int array_len
|
||||
cdef readonly Doc doc
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -58,7 +58,6 @@ cdef class Token:
|
|||
self.doc = doc
|
||||
self.c = &self.doc.c[offset]
|
||||
self.i = offset
|
||||
self.array_len = doc.length
|
||||
|
||||
def __len__(self):
|
||||
return self.c.lex.length
|
||||
|
@ -410,6 +409,28 @@ cdef class Token:
|
|||
iob_strings = ('', 'I', 'O', 'B')
|
||||
return iob_strings[self.c.ent_iob]
|
||||
|
||||
property ent_id:
|
||||
'''An (integer) entity ID. Usually assigned by patterns in the Matcher.'''
|
||||
def __get__(self):
|
||||
return self.c.ent.ent_id
|
||||
|
||||
def __set__(self, hash_t key):
|
||||
# TODO
|
||||
raise NotImplementedError(
|
||||
"Can't yet set ent_id from Token. Vote for this feature on the issue "
|
||||
"tracker: http://github.com/spacy-io/spaCy")
|
||||
|
||||
property ent_id_:
|
||||
'''A (string) entity ID. Usually assigned by patterns in the Matcher.'''
|
||||
def __get__(self):
|
||||
return self.vocab.strings[self.c.ent_id]
|
||||
|
||||
def __set__(self, hash_t key):
|
||||
# TODO
|
||||
raise NotImplementedError(
|
||||
"Can't yet set ent_id_ from Token. Vote for this feature on the issue "
|
||||
"tracker: http://github.com/spacy-io/spaCy")
|
||||
|
||||
property whitespace_:
|
||||
def __get__(self):
|
||||
return ' ' if self.c.spacy else ''
|
||||
|
@ -507,3 +528,17 @@ cdef class Token:
|
|||
|
||||
property like_email:
|
||||
def __get__(self): return Lexeme.c_check_flag(self.c.lex, LIKE_EMAIL)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
doc = nlp('Google Now is a moribund project destined for closure.')
|
||||
|
||||
google_now = doc.ents[0] # Span instance
|
||||
|
||||
google_now.attrs['category'] == 'TECHNOLOGY'
|
||||
|
||||
ent_id = google_now.ent_id
|
||||
|
||||
attrs = nlp.matcher.get_attrs(ent_id)
|
||||
|
|
Loading…
Reference in New Issue
Block a user