mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
* Add draft dfa matcher, in Python. Passing tests.
This commit is contained in:
parent
eb7138c761
commit
4c87a696b3
52
spacy/matcher.pyx
Normal file
52
spacy/matcher.pyx
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
class MatchState(object):
|
||||||
|
def __init__(self, token_spec, ext):
|
||||||
|
self.token_spec = token_spec
|
||||||
|
self.ext = ext
|
||||||
|
self.is_final = False
|
||||||
|
|
||||||
|
def match(self, token):
|
||||||
|
for attr, value in self.token_spec:
|
||||||
|
if getattr(token, attr) != value:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return '<spec %s>' % (self.token_spec)
|
||||||
|
|
||||||
|
|
||||||
|
class EndState(object):
|
||||||
|
def __init__(self, entity_type, length):
|
||||||
|
self.entity_type = entity_type
|
||||||
|
self.length = length
|
||||||
|
self.is_final = True
|
||||||
|
|
||||||
|
def __call__(self, token):
|
||||||
|
return (self.entity_type, ((token.i+1) - self.length), token.i+1)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return '<end %s>' % (self.entity_type)
|
||||||
|
|
||||||
|
|
||||||
|
class Matcher(object):
|
||||||
|
def __init__(self, patterns):
|
||||||
|
self.start_states = []
|
||||||
|
for token_specs, entity_type in patterns:
|
||||||
|
state = EndState(entity_type, len(token_specs))
|
||||||
|
for spec in reversed(token_specs):
|
||||||
|
state = MatchState(spec, state)
|
||||||
|
self.start_states.append(state)
|
||||||
|
|
||||||
|
def __call__(self, tokens):
|
||||||
|
queue = list(self.start_states)
|
||||||
|
matches = []
|
||||||
|
for token in tokens:
|
||||||
|
next_queue = list(self.start_states)
|
||||||
|
for pattern in queue:
|
||||||
|
if pattern.match(token):
|
||||||
|
if pattern.ext.is_final:
|
||||||
|
matches.append(pattern.ext(token))
|
||||||
|
else:
|
||||||
|
next_queue.append(pattern.ext)
|
||||||
|
queue = next_queue
|
||||||
|
return matches
|
52
tests/test_matcher.py
Normal file
52
tests/test_matcher.py
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from spacy.matcher import *
|
||||||
|
|
||||||
|
|
||||||
|
class MockToken(object):
|
||||||
|
def __init__(self, i, string):
|
||||||
|
self.i = i
|
||||||
|
self.orth_ = string
|
||||||
|
|
||||||
|
|
||||||
|
def make_tokens(string):
|
||||||
|
return [MockToken(i, s) for i, s in enumerate(string.split())]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def matcher():
|
||||||
|
specs = []
|
||||||
|
for string in ['JavaScript', 'Google Now', 'Java']:
|
||||||
|
spec = tuple([[('orth_', orth)] for orth in string.split()])
|
||||||
|
specs.append((spec, 'product'))
|
||||||
|
return Matcher(specs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_compile(matcher):
|
||||||
|
assert len(matcher.start_states) == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_match(matcher):
|
||||||
|
tokens = make_tokens('I like cheese')
|
||||||
|
assert matcher(tokens) == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_match_start(matcher):
|
||||||
|
tokens = make_tokens('JavaScript is good')
|
||||||
|
assert matcher(tokens) == [('product', 0, 1)]
|
||||||
|
|
||||||
|
|
||||||
|
def test_match_end(matcher):
|
||||||
|
tokens = make_tokens('I like Java')
|
||||||
|
assert matcher(tokens) == [('product', 2, 3)]
|
||||||
|
|
||||||
|
|
||||||
|
def test_match_middle(matcher):
|
||||||
|
tokens = make_tokens('I like Google Now best')
|
||||||
|
assert matcher(tokens) == [('product', 2, 4)]
|
||||||
|
|
||||||
|
|
||||||
|
def test_match_multi(matcher):
|
||||||
|
tokens = make_tokens('I like Google Now and Java best')
|
||||||
|
assert matcher(tokens) == [('product', 2, 4), ('product', 5, 6)]
|
Loading…
Reference in New Issue
Block a user