Option for returning only greedy matches (#5771)

* add "greedy" option for match pattern

* distinction between greedy FIRST or LONGEST

* check for proper values, throw custom warning otherwise

* unxfail one more test

* add comment in docstring

* add test that LONGEST also prefers first match if equal length

* use c arrays for more efficient processing

* rename 'greediness' to 'greedy'
This commit is contained in:
Sofie Van Landeghem 2020-07-29 11:04:43 +02:00 committed by GitHub
parent 191a12d75f
commit 40c995b1be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 121 additions and 52 deletions

View File

@ -432,12 +432,12 @@ class Errors:
"Current DocBin: {current}\nOther DocBin: {other}") "Current DocBin: {current}\nOther DocBin: {other}")
E169 = ("Can't find module: {module}") E169 = ("Can't find module: {module}")
E170 = ("Cannot apply transition {name}: invalid for the current state.") E170 = ("Cannot apply transition {name}: invalid for the current state.")
E171 = ("Matcher.add received invalid on_match callback argument: expected " E171 = ("Matcher.add received invalid 'on_match' callback argument: expected "
"callable or None, but got: {arg_type}") "callable or None, but got: {arg_type}")
E175 = ("Can't remove rule for unknown match pattern ID: {key}") E175 = ("Can't remove rule for unknown match pattern ID: {key}")
E176 = ("Alias '{alias}' is not defined in the Knowledge Base.") E176 = ("Alias '{alias}' is not defined in the Knowledge Base.")
E177 = ("Ill-formed IOB input detected: {tag}") E177 = ("Ill-formed IOB input detected: {tag}")
E178 = ("Invalid pattern. Expected list of dicts but got: {pat}. Maybe you " E178 = ("Each pattern should be a list of dicts, but got: {pat}. Maybe you "
"accidentally passed a single pattern to Matcher.add instead of a " "accidentally passed a single pattern to Matcher.add instead of a "
"list of patterns? If you only want to add one pattern, make sure " "list of patterns? If you only want to add one pattern, make sure "
"to wrap it in a list. For example: matcher.add('{key}', [pattern])") "to wrap it in a list. For example: matcher.add('{key}', [pattern])")
@ -483,6 +483,10 @@ class Errors:
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].") E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
# TODO: fix numbering after merging develop into master # TODO: fix numbering after merging develop into master
E947 = ("Matcher.add received invalid 'greedy' argument: expected "
"a string value from {expected} but got: '{arg}'")
E948 = ("Matcher.add received invalid 'patterns' argument: expected "
"a List, but got: {arg_type}")
E952 = ("The section '{name}' is not a valid section in the provided config.") E952 = ("The section '{name}' is not a valid section in the provided config.")
E953 = ("Mismatched IDs received by the Tok2Vec listener: {id1} vs. {id2}") E953 = ("Mismatched IDs received by the Tok2Vec listener: {id1} vs. {id2}")
E954 = ("The Tok2Vec listener did not receive a valid input.") E954 = ("The Tok2Vec listener did not receive a valid input.")

View File

@ -66,6 +66,7 @@ cdef class Matcher:
cdef public object validate cdef public object validate
cdef public object _patterns cdef public object _patterns
cdef public object _callbacks cdef public object _callbacks
cdef public object _filter
cdef public object _extensions cdef public object _extensions
cdef public object _extra_predicates cdef public object _extra_predicates
cdef public object _seen_attrs cdef public object _seen_attrs

View File

@ -1,6 +1,9 @@
# cython: infer_types=True, cython: profile=True # cython: infer_types=True, cython: profile=True
from typing import List
from libcpp.vector cimport vector from libcpp.vector cimport vector
from libc.stdint cimport int32_t from libc.stdint cimport int32_t
from libc.string cimport memset, memcmp
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from murmurhash.mrmr cimport hash64 from murmurhash.mrmr cimport hash64
@ -42,6 +45,7 @@ cdef class Matcher:
self._extra_predicates = [] self._extra_predicates = []
self._patterns = {} self._patterns = {}
self._callbacks = {} self._callbacks = {}
self._filter = {}
self._extensions = {} self._extensions = {}
self._seen_attrs = set() self._seen_attrs = set()
self.vocab = vocab self.vocab = vocab
@ -69,7 +73,7 @@ cdef class Matcher:
""" """
return self._normalize_key(key) in self._patterns return self._normalize_key(key) in self._patterns
def add(self, key, patterns, *_patterns, on_match=None): def add(self, key, patterns, *, on_match=None, greedy: str=None):
"""Add a match-rule to the matcher. A match-rule consists of: an ID """Add a match-rule to the matcher. A match-rule consists of: an ID
key, an on_match callback, and one or more patterns. key, an on_match callback, and one or more patterns.
@ -87,11 +91,10 @@ cdef class Matcher:
'+': Require the pattern to match 1 or more times. '+': Require the pattern to match 1 or more times.
'*': Allow the pattern to zero or more times. '*': Allow the pattern to zero or more times.
The + and * operators are usually interpretted "greedily", i.e. longer The + and * operators return all possible matches (not just the greedy
matches are returned where possible. However, if you specify two '+' ones). However, the "greedy" argument can filter the final matches
and '*' patterns in a row and their matches overlap, the first by returning a non-overlapping set per key, either taking preference to
operator will behave non-greedily. This quirk in the semantics makes the first greedy match ("FIRST"), or the longest ("LONGEST").
the matcher more efficient, by avoiding the need for back-tracking.
As of spaCy v2.2.2, Matcher.add supports the future API, which makes As of spaCy v2.2.2, Matcher.add supports the future API, which makes
the patterns the second argument and a list (instead of a variable the patterns the second argument and a list (instead of a variable
@ -101,16 +104,15 @@ cdef class Matcher:
key (str): The match ID. key (str): The match ID.
patterns (list): The patterns to add for the given key. patterns (list): The patterns to add for the given key.
on_match (callable): Optional callback executed on match. on_match (callable): Optional callback executed on match.
*_patterns (list): For backwards compatibility: list of patterns to add greedy (str): Optional filter: "FIRST" or "LONGEST".
as variable arguments. Will be ignored if a list of patterns is
provided as the second argument.
""" """
errors = {} errors = {}
if on_match is not None and not hasattr(on_match, "__call__"): if on_match is not None and not hasattr(on_match, "__call__"):
raise ValueError(Errors.E171.format(arg_type=type(on_match))) raise ValueError(Errors.E171.format(arg_type=type(on_match)))
if patterns is None or hasattr(patterns, "__call__"): # old API if patterns is None or not isinstance(patterns, List): # old API
on_match = patterns raise ValueError(Errors.E948.format(arg_type=type(patterns)))
patterns = _patterns if greedy is not None and greedy not in ["FIRST", "LONGEST"]:
raise ValueError(Errors.E947.format(expected=["FIRST", "LONGEST"], arg=greedy))
for i, pattern in enumerate(patterns): for i, pattern in enumerate(patterns):
if len(pattern) == 0: if len(pattern) == 0:
raise ValueError(Errors.E012.format(key=key)) raise ValueError(Errors.E012.format(key=key))
@ -133,6 +135,7 @@ cdef class Matcher:
raise ValueError(Errors.E154.format()) raise ValueError(Errors.E154.format())
self._patterns.setdefault(key, []) self._patterns.setdefault(key, [])
self._callbacks[key] = on_match self._callbacks[key] = on_match
self._filter[key] = greedy
self._patterns[key].extend(patterns) self._patterns[key].extend(patterns)
def remove(self, key): def remove(self, key):
@ -218,6 +221,7 @@ cdef class Matcher:
length = doclike.end - doclike.start length = doclike.end - doclike.start
else: else:
raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doclike).__name__)) raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doclike).__name__))
cdef Pool tmp_pool = Pool()
if len(set([LEMMA, POS, TAG]) & self._seen_attrs) > 0 \ if len(set([LEMMA, POS, TAG]) & self._seen_attrs) > 0 \
and not doc.is_tagged: and not doc.is_tagged:
raise ValueError(Errors.E155.format()) raise ValueError(Errors.E155.format())
@ -225,11 +229,42 @@ cdef class Matcher:
raise ValueError(Errors.E156.format()) raise ValueError(Errors.E156.format())
matches = find_matches(&self.patterns[0], self.patterns.size(), doclike, length, matches = find_matches(&self.patterns[0], self.patterns.size(), doclike, length,
extensions=self._extensions, predicates=self._extra_predicates) extensions=self._extensions, predicates=self._extra_predicates)
for i, (key, start, end) in enumerate(matches): final_matches = []
pairs_by_id = {}
# For each key, either add all matches, or only the filtered, non-overlapping ones
for (key, start, end) in matches:
span_filter = self._filter.get(key)
if span_filter is not None:
pairs = pairs_by_id.get(key, [])
pairs.append((start,end))
pairs_by_id[key] = pairs
else:
final_matches.append((key, start, end))
matched = <char*>tmp_pool.alloc(length, sizeof(char))
empty = <char*>tmp_pool.alloc(length, sizeof(char))
for key, pairs in pairs_by_id.items():
memset(matched, 0, length * sizeof(matched[0]))
span_filter = self._filter.get(key)
if span_filter == "FIRST":
sorted_pairs = sorted(pairs, key=lambda x: (x[0], -x[1]), reverse=False) # sort by start
elif span_filter == "LONGEST":
sorted_pairs = sorted(pairs, key=lambda x: (x[1]-x[0], -x[0]), reverse=True) # reverse sort by length
else:
raise ValueError(Errors.E947.format(expected=["FIRST", "LONGEST"], arg=span_filter))
for (start, end) in sorted_pairs:
assert 0 <= start < end # Defend against segfaults
span_len = end-start
# If no tokens in the span have matched
if memcmp(&matched[start], &empty[start], span_len * sizeof(matched[0])) == 0:
final_matches.append((key, start, end))
# Mark tokens that have matched
memset(&matched[start], 1, span_len * sizeof(matched[0]))
# perform the callbacks on the filtered set of results
for i, (key, start, end) in enumerate(final_matches):
on_match = self._callbacks.get(key, None) on_match = self._callbacks.get(key, None)
if on_match is not None: if on_match is not None:
on_match(self, doc, i, matches) on_match(self, doc, i, final_matches)
return matches return final_matches
def _normalize_key(self, key): def _normalize_key(self, key):
if isinstance(key, basestring): if isinstance(key, basestring):
@ -240,9 +275,9 @@ cdef class Matcher:
def unpickle_matcher(vocab, patterns, callbacks): def unpickle_matcher(vocab, patterns, callbacks):
matcher = Matcher(vocab) matcher = Matcher(vocab)
for key, specs in patterns.items(): for key, pattern in patterns.items():
callback = callbacks.get(key, None) callback = callbacks.get(key, None)
matcher.add(key, callback, *specs) matcher.add(key, pattern, on_match=callback)
return matcher return matcher

View File

@ -58,7 +58,7 @@ def merge_subtokens(doc: Doc, label: str = "subtok") -> Doc:
""" """
# TODO: make stateful component with "label" config # TODO: make stateful component with "label" config
merger = Matcher(doc.vocab) merger = Matcher(doc.vocab)
merger.add("SUBTOK", None, [{"DEP": label, "op": "+"}]) merger.add("SUBTOK", [[{"DEP": label, "op": "+"}]])
matches = merger(doc) matches = merger(doc)
spans = filter_spans([doc[start : end + 1] for _, start, end in matches]) spans = filter_spans([doc[start : end + 1] for _, start, end in matches])
with doc.retokenize() as retokenizer: with doc.retokenize() as retokenizer:

View File

@ -63,18 +63,11 @@ def test_matcher_len_contains(matcher):
assert "TEST2" not in matcher assert "TEST2" not in matcher
def test_matcher_add_new_old_api(en_vocab): def test_matcher_add_new_api(en_vocab):
doc = Doc(en_vocab, words=["a", "b"]) doc = Doc(en_vocab, words=["a", "b"])
patterns = [[{"TEXT": "a"}], [{"TEXT": "a"}, {"TEXT": "b"}]] patterns = [[{"TEXT": "a"}], [{"TEXT": "a"}, {"TEXT": "b"}]]
matcher = Matcher(en_vocab) matcher = Matcher(en_vocab)
matcher.add("OLD_API", None, *patterns)
assert len(matcher(doc)) == 2
matcher = Matcher(en_vocab)
on_match = Mock() on_match = Mock()
matcher.add("OLD_API_CALLBACK", on_match, *patterns)
assert len(matcher(doc)) == 2
assert on_match.call_count == 2
# New API: add(key: str, patterns: List[List[dict]], on_match: Callable)
matcher = Matcher(en_vocab) matcher = Matcher(en_vocab)
matcher.add("NEW_API", patterns) matcher.add("NEW_API", patterns)
assert len(matcher(doc)) == 2 assert len(matcher(doc)) == 2
@ -176,7 +169,7 @@ def test_matcher_match_zero_plus(matcher):
def test_matcher_match_one_plus(matcher): def test_matcher_match_one_plus(matcher):
control = Matcher(matcher.vocab) control = Matcher(matcher.vocab)
control.add("BasicPhilippe", None, [{"ORTH": "Philippe"}]) control.add("BasicPhilippe", [[{"ORTH": "Philippe"}]])
doc = Doc(control.vocab, words=["Philippe", "Philippe"]) doc = Doc(control.vocab, words=["Philippe", "Philippe"])
m = control(doc) m = control(doc)
assert len(m) == 2 assert len(m) == 2

View File

@ -7,18 +7,10 @@ from spacy.tokens import Doc, Span
pattern1 = [{"ORTH": "A"}, {"ORTH": "A", "OP": "*"}] pattern1 = [{"ORTH": "A"}, {"ORTH": "A", "OP": "*"}]
pattern2 = [{"ORTH": "A"}, {"ORTH": "A"}] pattern2 = [{"ORTH": "A", "OP": "*"}, {"ORTH": "A"}]
pattern3 = [{"ORTH": "A"}, {"ORTH": "A"}] pattern3 = [{"ORTH": "A"}, {"ORTH": "A"}]
pattern4 = [ pattern4 = [{"ORTH": "B"}, {"ORTH": "A", "OP": "*"}, {"ORTH": "B"}]
{"ORTH": "B"}, pattern5 = [{"ORTH": "B", "OP": "*"}, {"ORTH": "A", "OP": "*"}, {"ORTH": "B"}]
{"ORTH": "A", "OP": "*"},
{"ORTH": "B"},
]
pattern5 = [
{"ORTH": "B", "OP": "*"},
{"ORTH": "A", "OP": "*"},
{"ORTH": "B"},
]
re_pattern1 = "AA*" re_pattern1 = "AA*"
re_pattern2 = "A*A" re_pattern2 = "A*A"
@ -26,10 +18,16 @@ re_pattern3 = "AA"
re_pattern4 = "BA*B" re_pattern4 = "BA*B"
re_pattern5 = "B*A*B" re_pattern5 = "B*A*B"
longest1 = "A A A A A"
longest2 = "A A A A A"
longest3 = "A A"
longest4 = "B A A A A A B" # "FIRST" would be "B B"
longest5 = "B B A A A A A B"
@pytest.fixture @pytest.fixture
def text(): def text():
return "(ABBAAAAAB)." return "(BBAAAAAB)."
@pytest.fixture @pytest.fixture
@ -41,25 +39,63 @@ def doc(en_tokenizer, text):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"pattern,re_pattern", "pattern,re_pattern",
[ [
pytest.param(pattern1, re_pattern1, marks=pytest.mark.xfail()), (pattern1, re_pattern1),
pytest.param(pattern2, re_pattern2, marks=pytest.mark.xfail()), (pattern2, re_pattern2),
pytest.param(pattern3, re_pattern3, marks=pytest.mark.xfail()), (pattern3, re_pattern3),
(pattern4, re_pattern4), (pattern4, re_pattern4),
pytest.param(pattern5, re_pattern5, marks=pytest.mark.xfail()), (pattern5, re_pattern5),
], ],
) )
def test_greedy_matching(doc, text, pattern, re_pattern): def test_greedy_matching_first(doc, text, pattern, re_pattern):
"""Test that the greedy matching behavior of the * op is consistant with """Test that the greedy matching behavior "FIRST" is consistent with
other re implementations.""" other re implementations."""
matcher = Matcher(doc.vocab) matcher = Matcher(doc.vocab)
matcher.add(re_pattern, [pattern]) matcher.add(re_pattern, [pattern], greedy="FIRST")
matches = matcher(doc) matches = matcher(doc)
re_matches = [m.span() for m in re.finditer(re_pattern, text)] re_matches = [m.span() for m in re.finditer(re_pattern, text)]
for match, re_match in zip(matches, re_matches): for (key, m_s, m_e), (re_s, re_e) in zip(matches, re_matches):
assert match[1:] == re_match # matching the string, not the exact position
assert doc[m_s:m_e].text == doc[re_s:re_e].text
@pytest.mark.parametrize(
"pattern,longest",
[
(pattern1, longest1),
(pattern2, longest2),
(pattern3, longest3),
(pattern4, longest4),
(pattern5, longest5),
],
)
def test_greedy_matching_longest(doc, text, pattern, longest):
"""Test the "LONGEST" greedy matching behavior"""
matcher = Matcher(doc.vocab)
matcher.add("RULE", [pattern], greedy="LONGEST")
matches = matcher(doc)
for (key, s, e) in matches:
assert doc[s:e].text == longest
def test_greedy_matching_longest_first(en_tokenizer):
"""Test that "LONGEST" matching prefers the first of two equally long matches"""
doc = en_tokenizer(" ".join("CCC"))
matcher = Matcher(doc.vocab)
pattern = [{"ORTH": "C"}, {"ORTH": "C"}]
matcher.add("RULE", [pattern], greedy="LONGEST")
matches = matcher(doc)
# out of 0-2 and 1-3, the first should be picked
assert len(matches) == 1
assert matches[0][1] == 0
assert matches[0][2] == 2
def test_invalid_greediness(doc, text):
matcher = Matcher(doc.vocab)
with pytest.raises(ValueError):
matcher.add("RULE", [pattern1], greedy="GREEDY")
@pytest.mark.xfail
@pytest.mark.parametrize( @pytest.mark.parametrize(
"pattern,re_pattern", "pattern,re_pattern",
[ [
@ -74,7 +110,7 @@ def test_match_consuming(doc, text, pattern, re_pattern):
"""Test that matcher.__call__ consumes tokens on a match similar to """Test that matcher.__call__ consumes tokens on a match similar to
re.findall.""" re.findall."""
matcher = Matcher(doc.vocab) matcher = Matcher(doc.vocab)
matcher.add(re_pattern, [pattern]) matcher.add(re_pattern, [pattern], greedy="FIRST")
matches = matcher(doc) matches = matcher(doc)
re_matches = [m.span() for m in re.finditer(re_pattern, text)] re_matches = [m.span() for m in re.finditer(re_pattern, text)]
assert len(matches) == len(re_matches) assert len(matches) == len(re_matches)