diff --git a/spacy/matcher.pyx b/spacy/matcher.pyx index c19fd8919..f6f3f95d6 100644 --- a/spacy/matcher.pyx +++ b/spacy/matcher.pyx @@ -1,52 +1,100 @@ -class MatchState(object): - def __init__(self, token_spec, ext): - self.token_spec = token_spec - self.ext = ext - self.is_final = False +from .typedefs cimport attr_t +from .attrs cimport attr_id_t +from .structs cimport TokenC - def match(self, token): - for attr, value in self.token_spec: - if getattr(token, attr) != value: - return False - else: - return True +from cymem.cymem cimport Pool +from libcpp.vector cimport vector - def __repr__(self): - return '' % (self.token_spec) +from .attrs cimport LENGTH, ENT_TYPE +from .tokens.doc cimport get_token_attr +from .tokens.doc cimport Doc +from .vocab cimport Vocab -class EndState(object): - def __init__(self, entity_type, length): - self.entity_type = entity_type - self.length = length - self.is_final = True - - def __call__(self, token): - return (self.entity_type, ((token.i+1) - self.length), token.i+1) - - def __repr__(self): - return '' % (self.entity_type) +cdef struct AttrValue: + attr_id_t attr + attr_t value -class Matcher(object): +cdef struct Pattern: + AttrValue* spec + int length + + +cdef Pattern* init_pattern(Pool mem, object token_specs, attr_t entity_type) except NULL: + pattern = mem.alloc(len(token_specs) + 1, sizeof(Pattern)) + cdef int i + for i, spec in enumerate(token_specs): + pattern[i].spec = mem.alloc(len(spec), sizeof(AttrValue)) + pattern[i].length = len(spec) + for j, (attr, value) in enumerate(spec): + pattern[i].spec[j].attr = attr + pattern[i].spec[j].value = value + i = len(token_specs) + pattern[i].spec = mem.alloc(1, sizeof(AttrValue)) + pattern[i].spec[0].attr = ENT_TYPE + pattern[i].spec[0].value = entity_type + pattern[i].spec[1].attr = LENGTH + pattern[i].spec[1].value = len(token_specs) + pattern[i].length = 0 + return pattern + + +cdef int match(const Pattern* pattern, const TokenC* token) except -1: + cdef int i + for i in range(pattern.length): + if get_token_attr(token, pattern.spec[i].attr) != pattern.spec[i].value: + return False + return True + + +cdef int is_final(const Pattern* pattern) except -1: + return (pattern + 1).length == 0 + + +cdef object get_entity(const Pattern* pattern, const TokenC* tokens, int i): + pattern += 1 + i += 1 + return (pattern.spec[0].value, i - pattern.spec[1].value, i) + + +cdef class Matcher: + cdef Pool mem + cdef Pattern** patterns + cdef readonly int n_patterns + def __init__(self, patterns): - self.start_states = [] - for token_specs, entity_type in patterns: - state = EndState(entity_type, len(token_specs)) - for spec in reversed(token_specs): - state = MatchState(spec, state) - self.start_states.append(state) + self.mem = Pool() + self.patterns = self.mem.alloc(len(patterns), sizeof(Pattern*)) + for i, (token_specs, entity_type) in enumerate(patterns): + self.patterns[i] = init_pattern(self.mem, token_specs, entity_type) + self.n_patterns = len(patterns) - def __call__(self, tokens): - queue = list(self.start_states) + def __call__(self, Doc doc): + cdef vector[Pattern*] partials + cdef int n_partials = 0 + cdef int q = 0 + cdef int i, token_i + cdef const TokenC* token + cdef Pattern* state matches = [] - for token in tokens: - next_queue = list(self.start_states) - for pattern in queue: - if pattern.match(token): - if pattern.ext.is_final: - matches.append(pattern.ext(token)) + for token_i in range(doc.length): + token = &doc.data[token_i] + q = 0 + for i in range(partials.size()): + state = partials.at(i) + if match(state, token): + if is_final(state): + matches.append(get_entity(state, token, token_i)) else: - next_queue.append(pattern.ext) - queue = next_queue + partials[q] = state + 1 + q += 1 + partials.resize(q) + for i in range(self.n_patterns): + state = self.patterns[i] + if match(state, token): + if is_final(state): + matches.append(get_entity(state, token, token_i)) + else: + partials.push_back(state + 1) return matches diff --git a/tests/test_matcher.py b/tests/test_matcher.py index 391d9526c..fb3665623 100644 --- a/tests/test_matcher.py +++ b/tests/test_matcher.py @@ -1,52 +1,51 @@ from __future__ import unicode_literals import pytest +from spacy.strings import StringStore from spacy.matcher import * - - -class MockToken(object): - def __init__(self, i, string): - self.i = i - self.orth_ = string - - -def make_tokens(string): - return [MockToken(i, s) for i, s in enumerate(string.split())] +from spacy.attrs import ORTH +from spacy.tokens.doc import Doc +from spacy.vocab import Vocab @pytest.fixture -def matcher(): +def matcher(EN): specs = [] for string in ['JavaScript', 'Google Now', 'Java']: - spec = tuple([[('orth_', orth)] for orth in string.split()]) - specs.append((spec, 'product')) + spec = [] + for orth_ in string.split(): + spec.append([(ORTH, EN.vocab.strings[orth_])]) + specs.append((spec, EN.vocab.strings['product'])) return Matcher(specs) def test_compile(matcher): - assert len(matcher.start_states) == 3 + assert matcher.n_patterns == 3 - -def test_no_match(matcher): - tokens = make_tokens('I like cheese') +def test_no_match(matcher, EN): + tokens = EN('I like cheese') assert matcher(tokens) == [] -def test_match_start(matcher): - tokens = make_tokens('JavaScript is good') - assert matcher(tokens) == [('product', 0, 1)] +def test_match_start(matcher, EN): + tokens = EN('JavaScript is good') + assert matcher(tokens) == [(EN.vocab.strings['product'], 0, 1)] -def test_match_end(matcher): - tokens = make_tokens('I like Java') - assert matcher(tokens) == [('product', 2, 3)] +def test_match_end(matcher, EN): + tokens = EN('I like Java') + assert matcher(tokens) == [(EN.vocab.strings['product'], 2, 3)] -def test_match_middle(matcher): - tokens = make_tokens('I like Google Now best') - assert matcher(tokens) == [('product', 2, 4)] +def test_match_middle(matcher, EN): + tokens = EN('I like Google Now best') + assert matcher(tokens) == [(EN.vocab.strings['product'], 2, 4)] -def test_match_multi(matcher): - tokens = make_tokens('I like Google Now and Java best') - assert matcher(tokens) == [('product', 2, 4), ('product', 5, 6)] +def test_match_multi(matcher, EN): + tokens = EN('I like Google Now and Java best') + assert matcher(tokens) == [(EN.vocab.strings['product'], 2, 4), + (EN.vocab.strings['product'], 5, 6)] + +def test_dummy(): + pass