mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
* Reimplement matching in Cython, instead of Python.
This commit is contained in:
parent
4c87a696b3
commit
5bc0e83f9a
|
@ -1,52 +1,100 @@
|
||||||
class MatchState(object):
|
from .typedefs cimport attr_t
|
||||||
def __init__(self, token_spec, ext):
|
from .attrs cimport attr_id_t
|
||||||
self.token_spec = token_spec
|
from .structs cimport TokenC
|
||||||
self.ext = ext
|
|
||||||
self.is_final = False
|
|
||||||
|
|
||||||
def match(self, token):
|
from cymem.cymem cimport Pool
|
||||||
for attr, value in self.token_spec:
|
from libcpp.vector cimport vector
|
||||||
if getattr(token, attr) != value:
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def __repr__(self):
|
from .attrs cimport LENGTH, ENT_TYPE
|
||||||
return '<spec %s>' % (self.token_spec)
|
from .tokens.doc cimport get_token_attr
|
||||||
|
from .tokens.doc cimport Doc
|
||||||
|
from .vocab cimport Vocab
|
||||||
|
|
||||||
|
|
||||||
class EndState(object):
|
cdef struct AttrValue:
|
||||||
def __init__(self, entity_type, length):
|
attr_id_t attr
|
||||||
self.entity_type = entity_type
|
attr_t value
|
||||||
self.length = length
|
|
||||||
self.is_final = True
|
|
||||||
|
|
||||||
def __call__(self, token):
|
|
||||||
return (self.entity_type, ((token.i+1) - self.length), token.i+1)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return '<end %s>' % (self.entity_type)
|
|
||||||
|
|
||||||
|
|
||||||
class Matcher(object):
|
cdef struct Pattern:
|
||||||
|
AttrValue* spec
|
||||||
|
int length
|
||||||
|
|
||||||
|
|
||||||
|
cdef Pattern* init_pattern(Pool mem, object token_specs, attr_t entity_type) except NULL:
|
||||||
|
pattern = <Pattern*>mem.alloc(len(token_specs) + 1, sizeof(Pattern))
|
||||||
|
cdef int i
|
||||||
|
for i, spec in enumerate(token_specs):
|
||||||
|
pattern[i].spec = <AttrValue*>mem.alloc(len(spec), sizeof(AttrValue))
|
||||||
|
pattern[i].length = len(spec)
|
||||||
|
for j, (attr, value) in enumerate(spec):
|
||||||
|
pattern[i].spec[j].attr = attr
|
||||||
|
pattern[i].spec[j].value = value
|
||||||
|
i = len(token_specs)
|
||||||
|
pattern[i].spec = <AttrValue*>mem.alloc(1, sizeof(AttrValue))
|
||||||
|
pattern[i].spec[0].attr = ENT_TYPE
|
||||||
|
pattern[i].spec[0].value = entity_type
|
||||||
|
pattern[i].spec[1].attr = LENGTH
|
||||||
|
pattern[i].spec[1].value = len(token_specs)
|
||||||
|
pattern[i].length = 0
|
||||||
|
return pattern
|
||||||
|
|
||||||
|
|
||||||
|
cdef int match(const Pattern* pattern, const TokenC* token) except -1:
|
||||||
|
cdef int i
|
||||||
|
for i in range(pattern.length):
|
||||||
|
if get_token_attr(token, pattern.spec[i].attr) != pattern.spec[i].value:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
cdef int is_final(const Pattern* pattern) except -1:
|
||||||
|
return (pattern + 1).length == 0
|
||||||
|
|
||||||
|
|
||||||
|
cdef object get_entity(const Pattern* pattern, const TokenC* tokens, int i):
|
||||||
|
pattern += 1
|
||||||
|
i += 1
|
||||||
|
return (pattern.spec[0].value, i - pattern.spec[1].value, i)
|
||||||
|
|
||||||
|
|
||||||
|
cdef class Matcher:
|
||||||
|
cdef Pool mem
|
||||||
|
cdef Pattern** patterns
|
||||||
|
cdef readonly int n_patterns
|
||||||
|
|
||||||
def __init__(self, patterns):
|
def __init__(self, patterns):
|
||||||
self.start_states = []
|
self.mem = Pool()
|
||||||
for token_specs, entity_type in patterns:
|
self.patterns = <Pattern**>self.mem.alloc(len(patterns), sizeof(Pattern*))
|
||||||
state = EndState(entity_type, len(token_specs))
|
for i, (token_specs, entity_type) in enumerate(patterns):
|
||||||
for spec in reversed(token_specs):
|
self.patterns[i] = init_pattern(self.mem, token_specs, entity_type)
|
||||||
state = MatchState(spec, state)
|
self.n_patterns = len(patterns)
|
||||||
self.start_states.append(state)
|
|
||||||
|
|
||||||
def __call__(self, tokens):
|
def __call__(self, Doc doc):
|
||||||
queue = list(self.start_states)
|
cdef vector[Pattern*] partials
|
||||||
|
cdef int n_partials = 0
|
||||||
|
cdef int q = 0
|
||||||
|
cdef int i, token_i
|
||||||
|
cdef const TokenC* token
|
||||||
|
cdef Pattern* state
|
||||||
matches = []
|
matches = []
|
||||||
for token in tokens:
|
for token_i in range(doc.length):
|
||||||
next_queue = list(self.start_states)
|
token = &doc.data[token_i]
|
||||||
for pattern in queue:
|
q = 0
|
||||||
if pattern.match(token):
|
for i in range(partials.size()):
|
||||||
if pattern.ext.is_final:
|
state = partials.at(i)
|
||||||
matches.append(pattern.ext(token))
|
if match(state, token):
|
||||||
|
if is_final(state):
|
||||||
|
matches.append(get_entity(state, token, token_i))
|
||||||
else:
|
else:
|
||||||
next_queue.append(pattern.ext)
|
partials[q] = state + 1
|
||||||
queue = next_queue
|
q += 1
|
||||||
|
partials.resize(q)
|
||||||
|
for i in range(self.n_patterns):
|
||||||
|
state = self.patterns[i]
|
||||||
|
if match(state, token):
|
||||||
|
if is_final(state):
|
||||||
|
matches.append(get_entity(state, token, token_i))
|
||||||
|
else:
|
||||||
|
partials.push_back(state + 1)
|
||||||
return matches
|
return matches
|
||||||
|
|
|
@ -1,52 +1,51 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from spacy.strings import StringStore
|
||||||
from spacy.matcher import *
|
from spacy.matcher import *
|
||||||
|
from spacy.attrs import ORTH
|
||||||
|
from spacy.tokens.doc import Doc
|
||||||
class MockToken(object):
|
from spacy.vocab import Vocab
|
||||||
def __init__(self, i, string):
|
|
||||||
self.i = i
|
|
||||||
self.orth_ = string
|
|
||||||
|
|
||||||
|
|
||||||
def make_tokens(string):
|
|
||||||
return [MockToken(i, s) for i, s in enumerate(string.split())]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def matcher():
|
def matcher(EN):
|
||||||
specs = []
|
specs = []
|
||||||
for string in ['JavaScript', 'Google Now', 'Java']:
|
for string in ['JavaScript', 'Google Now', 'Java']:
|
||||||
spec = tuple([[('orth_', orth)] for orth in string.split()])
|
spec = []
|
||||||
specs.append((spec, 'product'))
|
for orth_ in string.split():
|
||||||
|
spec.append([(ORTH, EN.vocab.strings[orth_])])
|
||||||
|
specs.append((spec, EN.vocab.strings['product']))
|
||||||
return Matcher(specs)
|
return Matcher(specs)
|
||||||
|
|
||||||
|
|
||||||
def test_compile(matcher):
|
def test_compile(matcher):
|
||||||
assert len(matcher.start_states) == 3
|
assert matcher.n_patterns == 3
|
||||||
|
|
||||||
|
def test_no_match(matcher, EN):
|
||||||
def test_no_match(matcher):
|
tokens = EN('I like cheese')
|
||||||
tokens = make_tokens('I like cheese')
|
|
||||||
assert matcher(tokens) == []
|
assert matcher(tokens) == []
|
||||||
|
|
||||||
|
|
||||||
def test_match_start(matcher):
|
def test_match_start(matcher, EN):
|
||||||
tokens = make_tokens('JavaScript is good')
|
tokens = EN('JavaScript is good')
|
||||||
assert matcher(tokens) == [('product', 0, 1)]
|
assert matcher(tokens) == [(EN.vocab.strings['product'], 0, 1)]
|
||||||
|
|
||||||
|
|
||||||
def test_match_end(matcher):
|
def test_match_end(matcher, EN):
|
||||||
tokens = make_tokens('I like Java')
|
tokens = EN('I like Java')
|
||||||
assert matcher(tokens) == [('product', 2, 3)]
|
assert matcher(tokens) == [(EN.vocab.strings['product'], 2, 3)]
|
||||||
|
|
||||||
|
|
||||||
def test_match_middle(matcher):
|
def test_match_middle(matcher, EN):
|
||||||
tokens = make_tokens('I like Google Now best')
|
tokens = EN('I like Google Now best')
|
||||||
assert matcher(tokens) == [('product', 2, 4)]
|
assert matcher(tokens) == [(EN.vocab.strings['product'], 2, 4)]
|
||||||
|
|
||||||
|
|
||||||
def test_match_multi(matcher):
|
def test_match_multi(matcher, EN):
|
||||||
tokens = make_tokens('I like Google Now and Java best')
|
tokens = EN('I like Google Now and Java best')
|
||||||
assert matcher(tokens) == [('product', 2, 4), ('product', 5, 6)]
|
assert matcher(tokens) == [(EN.vocab.strings['product'], 2, 4),
|
||||||
|
(EN.vocab.strings['product'], 5, 6)]
|
||||||
|
|
||||||
|
def test_dummy():
|
||||||
|
pass
|
||||||
|
|
Loading…
Reference in New Issue
Block a user