* Reimplement matching in Cython, instead of Python.

This commit is contained in:
Matthew Honnibal 2015-08-05 01:05:54 +02:00
parent 4c87a696b3
commit 5bc0e83f9a
2 changed files with 117 additions and 70 deletions

View File

@ -1,52 +1,100 @@
# Cython matcher internals: pattern states are plain C structs walked over a
# Doc's TokenC array, replacing the old pure-Python MatchState/EndState
# objects that this commit removes.
from .typedefs cimport attr_t
from .attrs cimport attr_id_t
from .structs cimport TokenC

from cymem.cymem cimport Pool
from libcpp.vector cimport vector

from .attrs cimport LENGTH, ENT_TYPE
from .tokens.doc cimport get_token_attr
from .tokens.doc cimport Doc
from .vocab cimport Vocab


cdef struct AttrValue:
    # One (attribute id, required value) constraint on a single token.
    attr_id_t attr
    attr_t value


cdef struct Pattern:
    # One step of a compiled pattern: all `length` constraints in `spec`
    # must hold for the current token. A step with length == 0 is the
    # sentinel that terminates the pattern (see is_final / init_pattern).
    AttrValue* spec
    int length
cdef Pattern* init_pattern(Pool mem, object token_specs, attr_t entity_type) except NULL:
    """Compile a sequence of token specs into a C Pattern array.

    token_specs is a sequence with one spec per token; each spec is a
    sequence of (attr_id, value) pairs that must all match that token.
    The array is terminated by a sentinel entry with length == 0, whose
    spec stores the entity type (spec[0].value) and the pattern length
    (spec[1].value). All memory is owned by `mem`.
    """
    pattern = <Pattern*>mem.alloc(len(token_specs) + 1, sizeof(Pattern))
    cdef int i
    for i, spec in enumerate(token_specs):
        pattern[i].spec = <AttrValue*>mem.alloc(len(spec), sizeof(AttrValue))
        pattern[i].length = len(spec)
        for j, (attr, value) in enumerate(spec):
            pattern[i].spec[j].attr = attr
            pattern[i].spec[j].value = value
    # Sentinel entry. BUG FIX: two AttrValues are written below (entity
    # type and length) but only one was allocated, so the spec[1] writes
    # overran the buffer — allocate 2.
    i = len(token_specs)
    pattern[i].spec = <AttrValue*>mem.alloc(2, sizeof(AttrValue))
    pattern[i].spec[0].attr = ENT_TYPE
    pattern[i].spec[0].value = entity_type
    pattern[i].spec[1].attr = LENGTH
    pattern[i].spec[1].value = len(token_specs)
    pattern[i].length = 0  # length == 0 marks the terminal state
    return pattern
cdef int match(const Pattern* pattern, const TokenC* token) except -1:
    """Return True if the token satisfies every attribute constraint of
    this pattern step, else False."""
    cdef int i
    for i in range(pattern.length):
        if get_token_attr(token, pattern.spec[i].attr) != pattern.spec[i].value:
            return False
    # The original used a redundant for/else here; the loop can only fall
    # through when every constraint matched, so return True directly.
    return True
cdef int is_final(const Pattern* pattern) except -1:
    """True if the step after this one is the sentinel (length == 0),
    i.e. matching the current step completes the whole pattern."""
    return (pattern + 1).length == 0
cdef object get_entity(const Pattern* pattern, const TokenC* tokens, int i):
    """Build the (entity_type, start, end) tuple for a pattern whose final
    step matched at token index i.

    Reads the sentinel that follows the final step: entity type in
    spec[0].value, pattern length in spec[1].value (see init_pattern).
    """
    pattern += 1  # advance onto the sentinel entry
    i += 1        # inclusive last-token index -> exclusive end offset
    return (pattern.spec[0].value, i - pattern.spec[1].value, i)
cdef class Matcher:
    """Match token-attribute patterns against a Doc.

    Constructed from (token_specs, entity_type) pairs; calling the matcher
    on a Doc returns a list of (entity_type, start, end) tuples, one per
    completed pattern occurrence.
    """
    cdef Pool mem                 # owns all Pattern/AttrValue allocations
    cdef Pattern** patterns       # one compiled pattern per input spec
    cdef readonly int n_patterns

    def __init__(self, patterns):
        self.mem = Pool()
        self.patterns = <Pattern**>self.mem.alloc(len(patterns), sizeof(Pattern*))
        for i, (token_specs, entity_type) in enumerate(patterns):
            self.patterns[i] = init_pattern(self.mem, token_specs, entity_type)
        self.n_patterns = len(patterns)

    def __call__(self, Doc doc):
        """Scan doc left to right and return all matches."""
        cdef vector[Pattern*] partials
        cdef int q = 0
        cdef int i, token_i
        cdef const TokenC* token
        cdef Pattern* state
        matches = []
        for token_i in range(doc.length):
            token = &doc.data[token_i]
            # Advance in-progress partial matches, compacting the
            # survivors into the first q slots of the vector.
            q = 0
            for i in range(partials.size()):
                state = partials.at(i)
                if match(state, token):
                    if is_final(state):
                        matches.append(get_entity(state, token, token_i))
                    else:
                        partials[q] = state + 1
                        q += 1
            partials.resize(q)
            # Try to start every pattern at the current token.
            for i in range(self.n_patterns):
                state = self.patterns[i]
                if match(state, token):
                    if is_final(state):
                        matches.append(get_entity(state, token, token_i))
                    else:
                        partials.push_back(state + 1)
        return matches

View File

@ -1,52 +1,51 @@
# Tests for the Cython Matcher. The old MockToken/make_tokens helpers are
# gone: patterns and expected matches are now expressed as vocabulary
# string ids, so the tests drive a real pipeline via the EN fixture.
from __future__ import unicode_literals
import pytest

from spacy.strings import StringStore
from spacy.matcher import *
from spacy.attrs import ORTH
from spacy.tokens.doc import Doc
from spacy.vocab import Vocab
@pytest.fixture
def matcher(EN):
    """A Matcher with three 'product' patterns: JavaScript, Google Now, Java.

    Each token spec is a list of (ORTH, string_id) constraints, with ids
    encoded through the shared English vocabulary's string store.
    """
    specs = []
    for string in ['JavaScript', 'Google Now', 'Java']:
        spec = []
        for orth_ in string.split():
            spec.append([(ORTH, EN.vocab.strings[orth_])])
        specs.append((spec, EN.vocab.strings['product']))
    return Matcher(specs)
def test_compile(matcher):
    # The fixture registers exactly three patterns.
    assert matcher.n_patterns == 3
def test_no_match(matcher, EN):
    # None of the pattern strings occur in this sentence.
    tokens = EN('I like cheese')
    assert matcher(tokens) == []
def test_match_start(matcher, EN):
    # Single-token pattern matching at the very first position.
    tokens = EN('JavaScript is good')
    assert matcher(tokens) == [(EN.vocab.strings['product'], 0, 1)]
def test_match_end(matcher, EN):
    # Single-token pattern matching at the final position.
    tokens = EN('I like Java')
    assert matcher(tokens) == [(EN.vocab.strings['product'], 2, 3)]
def test_match_middle(matcher, EN):
    # Two-token pattern ('Google Now') matching mid-sentence; end offset
    # is exclusive, so the span is (2, 4).
    tokens = EN('I like Google Now best')
    assert matcher(tokens) == [(EN.vocab.strings['product'], 2, 4)]
def test_match_multi(matcher, EN):
    # Two distinct patterns matching in one sentence, reported in
    # left-to-right order.
    tokens = EN('I like Google Now and Java best')
    assert matcher(tokens) == [(EN.vocab.strings['product'], 2, 4),
                               (EN.vocab.strings['product'], 5, 6)]
def test_dummy():
    # NOTE(review): intentionally empty — appears to be a placeholder so the
    # module always has a collectable test; confirm it is still needed.
    pass