Option for returning only greedy matches (#5771)

* add "greedy" option for match pattern

* distinction between greedy FIRST or LONGEST

* check for proper values, throw custom warning otherwise

* unxfail one more test

* add comment in docstring

* add test that LONGEST also prefers first match if equal length

* use c arrays for more efficient processing

* rename 'greediness' to 'greedy'
This commit is contained in:
Sofie Van Landeghem 2020-07-29 11:04:43 +02:00 committed by GitHub
parent 191a12d75f
commit 40c995b1be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 121 additions and 52 deletions

View File

@ -432,12 +432,12 @@ class Errors:
"Current DocBin: {current}\nOther DocBin: {other}")
E169 = ("Can't find module: {module}")
E170 = ("Cannot apply transition {name}: invalid for the current state.")
E171 = ("Matcher.add received invalid on_match callback argument: expected "
E171 = ("Matcher.add received invalid 'on_match' callback argument: expected "
"callable or None, but got: {arg_type}")
E175 = ("Can't remove rule for unknown match pattern ID: {key}")
E176 = ("Alias '{alias}' is not defined in the Knowledge Base.")
E177 = ("Ill-formed IOB input detected: {tag}")
E178 = ("Invalid pattern. Expected list of dicts but got: {pat}. Maybe you "
E178 = ("Each pattern should be a list of dicts, but got: {pat}. Maybe you "
"accidentally passed a single pattern to Matcher.add instead of a "
"list of patterns? If you only want to add one pattern, make sure "
"to wrap it in a list. For example: matcher.add('{key}', [pattern])")
@ -483,6 +483,10 @@ class Errors:
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
# TODO: fix numbering after merging develop into master
E947 = ("Matcher.add received invalid 'greedy' argument: expected "
"a string value from {expected} but got: '{arg}'")
E948 = ("Matcher.add received invalid 'patterns' argument: expected "
"a List, but got: {arg_type}")
E952 = ("The section '{name}' is not a valid section in the provided config.")
E953 = ("Mismatched IDs received by the Tok2Vec listener: {id1} vs. {id2}")
E954 = ("The Tok2Vec listener did not receive a valid input.")

View File

@ -66,6 +66,7 @@ cdef class Matcher:
cdef public object validate
cdef public object _patterns
cdef public object _callbacks
cdef public object _filter
cdef public object _extensions
cdef public object _extra_predicates
cdef public object _seen_attrs

View File

@ -1,6 +1,9 @@
# cython: infer_types=True, cython: profile=True
from typing import List
from libcpp.vector cimport vector
from libc.stdint cimport int32_t
from libc.string cimport memset, memcmp
from cymem.cymem cimport Pool
from murmurhash.mrmr cimport hash64
@ -42,6 +45,7 @@ cdef class Matcher:
self._extra_predicates = []
self._patterns = {}
self._callbacks = {}
self._filter = {}
self._extensions = {}
self._seen_attrs = set()
self.vocab = vocab
@ -69,7 +73,7 @@ cdef class Matcher:
"""
return self._normalize_key(key) in self._patterns
def add(self, key, patterns, *_patterns, on_match=None):
def add(self, key, patterns, *, on_match=None, greedy: str=None):
"""Add a match-rule to the matcher. A match-rule consists of: an ID
key, an on_match callback, and one or more patterns.
@ -87,11 +91,10 @@ cdef class Matcher:
'+': Require the pattern to match 1 or more times.
'*': Allow the pattern to zero or more times.
The + and * operators are usually interpretted "greedily", i.e. longer
matches are returned where possible. However, if you specify two '+'
and '*' patterns in a row and their matches overlap, the first
operator will behave non-greedily. This quirk in the semantics makes
the matcher more efficient, by avoiding the need for back-tracking.
The + and * operators return all possible matches (not just the greedy
ones). However, the "greedy" argument can filter the final matches
by returning a non-overlapping set per key, either taking preference to
the first greedy match ("FIRST"), or the longest ("LONGEST").
As of spaCy v2.2.2, Matcher.add supports the future API, which makes
the patterns the second argument and a list (instead of a variable
@ -101,16 +104,15 @@ cdef class Matcher:
key (str): The match ID.
patterns (list): The patterns to add for the given key.
on_match (callable): Optional callback executed on match.
*_patterns (list): For backwards compatibility: list of patterns to add
as variable arguments. Will be ignored if a list of patterns is
provided as the second argument.
greedy (str): Optional filter: "FIRST" or "LONGEST".
"""
errors = {}
if on_match is not None and not hasattr(on_match, "__call__"):
raise ValueError(Errors.E171.format(arg_type=type(on_match)))
if patterns is None or hasattr(patterns, "__call__"): # old API
on_match = patterns
patterns = _patterns
if patterns is None or not isinstance(patterns, List): # old API
raise ValueError(Errors.E948.format(arg_type=type(patterns)))
if greedy is not None and greedy not in ["FIRST", "LONGEST"]:
raise ValueError(Errors.E947.format(expected=["FIRST", "LONGEST"], arg=greedy))
for i, pattern in enumerate(patterns):
if len(pattern) == 0:
raise ValueError(Errors.E012.format(key=key))
@ -133,6 +135,7 @@ cdef class Matcher:
raise ValueError(Errors.E154.format())
self._patterns.setdefault(key, [])
self._callbacks[key] = on_match
self._filter[key] = greedy
self._patterns[key].extend(patterns)
def remove(self, key):
@ -218,6 +221,7 @@ cdef class Matcher:
length = doclike.end - doclike.start
else:
raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doclike).__name__))
cdef Pool tmp_pool = Pool()
if len(set([LEMMA, POS, TAG]) & self._seen_attrs) > 0 \
and not doc.is_tagged:
raise ValueError(Errors.E155.format())
@ -225,11 +229,42 @@ cdef class Matcher:
raise ValueError(Errors.E156.format())
matches = find_matches(&self.patterns[0], self.patterns.size(), doclike, length,
extensions=self._extensions, predicates=self._extra_predicates)
for i, (key, start, end) in enumerate(matches):
final_matches = []
pairs_by_id = {}
# For each key, either add all matches, or only the filtered, non-overlapping ones
for (key, start, end) in matches:
span_filter = self._filter.get(key)
if span_filter is not None:
pairs = pairs_by_id.get(key, [])
pairs.append((start,end))
pairs_by_id[key] = pairs
else:
final_matches.append((key, start, end))
matched = <char*>tmp_pool.alloc(length, sizeof(char))
empty = <char*>tmp_pool.alloc(length, sizeof(char))
for key, pairs in pairs_by_id.items():
memset(matched, 0, length * sizeof(matched[0]))
span_filter = self._filter.get(key)
if span_filter == "FIRST":
sorted_pairs = sorted(pairs, key=lambda x: (x[0], -x[1]), reverse=False) # sort by start
elif span_filter == "LONGEST":
sorted_pairs = sorted(pairs, key=lambda x: (x[1]-x[0], -x[0]), reverse=True) # reverse sort by length
else:
raise ValueError(Errors.E947.format(expected=["FIRST", "LONGEST"], arg=span_filter))
for (start, end) in sorted_pairs:
assert 0 <= start < end # Defend against segfaults
span_len = end-start
# If no tokens in the span have matched
if memcmp(&matched[start], &empty[start], span_len * sizeof(matched[0])) == 0:
final_matches.append((key, start, end))
# Mark tokens that have matched
memset(&matched[start], 1, span_len * sizeof(matched[0]))
# perform the callbacks on the filtered set of results
for i, (key, start, end) in enumerate(final_matches):
on_match = self._callbacks.get(key, None)
if on_match is not None:
on_match(self, doc, i, matches)
return matches
on_match(self, doc, i, final_matches)
return final_matches
def _normalize_key(self, key):
if isinstance(key, basestring):
@ -240,9 +275,9 @@ cdef class Matcher:
def unpickle_matcher(vocab, patterns, callbacks):
matcher = Matcher(vocab)
for key, specs in patterns.items():
for key, pattern in patterns.items():
callback = callbacks.get(key, None)
matcher.add(key, callback, *specs)
matcher.add(key, pattern, on_match=callback)
return matcher

View File

@ -58,7 +58,7 @@ def merge_subtokens(doc: Doc, label: str = "subtok") -> Doc:
"""
# TODO: make stateful component with "label" config
merger = Matcher(doc.vocab)
merger.add("SUBTOK", None, [{"DEP": label, "op": "+"}])
merger.add("SUBTOK", [[{"DEP": label, "op": "+"}]])
matches = merger(doc)
spans = filter_spans([doc[start : end + 1] for _, start, end in matches])
with doc.retokenize() as retokenizer:

View File

@ -63,18 +63,11 @@ def test_matcher_len_contains(matcher):
assert "TEST2" not in matcher
def test_matcher_add_new_old_api(en_vocab):
def test_matcher_add_new_api(en_vocab):
doc = Doc(en_vocab, words=["a", "b"])
patterns = [[{"TEXT": "a"}], [{"TEXT": "a"}, {"TEXT": "b"}]]
matcher = Matcher(en_vocab)
matcher.add("OLD_API", None, *patterns)
assert len(matcher(doc)) == 2
matcher = Matcher(en_vocab)
on_match = Mock()
matcher.add("OLD_API_CALLBACK", on_match, *patterns)
assert len(matcher(doc)) == 2
assert on_match.call_count == 2
# New API: add(key: str, patterns: List[List[dict]], on_match: Callable)
matcher = Matcher(en_vocab)
matcher.add("NEW_API", patterns)
assert len(matcher(doc)) == 2
@ -176,7 +169,7 @@ def test_matcher_match_zero_plus(matcher):
def test_matcher_match_one_plus(matcher):
control = Matcher(matcher.vocab)
control.add("BasicPhilippe", None, [{"ORTH": "Philippe"}])
control.add("BasicPhilippe", [[{"ORTH": "Philippe"}]])
doc = Doc(control.vocab, words=["Philippe", "Philippe"])
m = control(doc)
assert len(m) == 2

View File

@ -7,18 +7,10 @@ from spacy.tokens import Doc, Span
pattern1 = [{"ORTH": "A"}, {"ORTH": "A", "OP": "*"}]
pattern2 = [{"ORTH": "A"}, {"ORTH": "A"}]
pattern2 = [{"ORTH": "A", "OP": "*"}, {"ORTH": "A"}]
pattern3 = [{"ORTH": "A"}, {"ORTH": "A"}]
pattern4 = [
{"ORTH": "B"},
{"ORTH": "A", "OP": "*"},
{"ORTH": "B"},
]
pattern5 = [
{"ORTH": "B", "OP": "*"},
{"ORTH": "A", "OP": "*"},
{"ORTH": "B"},
]
pattern4 = [{"ORTH": "B"}, {"ORTH": "A", "OP": "*"}, {"ORTH": "B"}]
pattern5 = [{"ORTH": "B", "OP": "*"}, {"ORTH": "A", "OP": "*"}, {"ORTH": "B"}]
re_pattern1 = "AA*"
re_pattern2 = "A*A"
@ -26,10 +18,16 @@ re_pattern3 = "AA"
re_pattern4 = "BA*B"
re_pattern5 = "B*A*B"
longest1 = "A A A A A"
longest2 = "A A A A A"
longest3 = "A A"
longest4 = "B A A A A A B" # "FIRST" would be "B B"
longest5 = "B B A A A A A B"
@pytest.fixture
def text():
return "(ABBAAAAAB)."
return "(BBAAAAAB)."
@pytest.fixture
@ -41,25 +39,63 @@ def doc(en_tokenizer, text):
@pytest.mark.parametrize(
"pattern,re_pattern",
[
pytest.param(pattern1, re_pattern1, marks=pytest.mark.xfail()),
pytest.param(pattern2, re_pattern2, marks=pytest.mark.xfail()),
pytest.param(pattern3, re_pattern3, marks=pytest.mark.xfail()),
(pattern1, re_pattern1),
(pattern2, re_pattern2),
(pattern3, re_pattern3),
(pattern4, re_pattern4),
pytest.param(pattern5, re_pattern5, marks=pytest.mark.xfail()),
(pattern5, re_pattern5),
],
)
def test_greedy_matching(doc, text, pattern, re_pattern):
"""Test that the greedy matching behavior of the * op is consistant with
def test_greedy_matching_first(doc, text, pattern, re_pattern):
"""Test that the greedy matching behavior "FIRST" is consistent with
other re implementations."""
matcher = Matcher(doc.vocab)
matcher.add(re_pattern, [pattern])
matcher.add(re_pattern, [pattern], greedy="FIRST")
matches = matcher(doc)
re_matches = [m.span() for m in re.finditer(re_pattern, text)]
for match, re_match in zip(matches, re_matches):
assert match[1:] == re_match
for (key, m_s, m_e), (re_s, re_e) in zip(matches, re_matches):
# matching the string, not the exact position
assert doc[m_s:m_e].text == doc[re_s:re_e].text
@pytest.mark.parametrize(
"pattern,longest",
[
(pattern1, longest1),
(pattern2, longest2),
(pattern3, longest3),
(pattern4, longest4),
(pattern5, longest5),
],
)
def test_greedy_matching_longest(doc, text, pattern, longest):
"""Test the "LONGEST" greedy matching behavior"""
matcher = Matcher(doc.vocab)
matcher.add("RULE", [pattern], greedy="LONGEST")
matches = matcher(doc)
for (key, s, e) in matches:
assert doc[s:e].text == longest
def test_greedy_matching_longest_first(en_tokenizer):
"""Test that "LONGEST" matching prefers the first of two equally long matches"""
doc = en_tokenizer(" ".join("CCC"))
matcher = Matcher(doc.vocab)
pattern = [{"ORTH": "C"}, {"ORTH": "C"}]
matcher.add("RULE", [pattern], greedy="LONGEST")
matches = matcher(doc)
# out of 0-2 and 1-3, the first should be picked
assert len(matches) == 1
assert matches[0][1] == 0
assert matches[0][2] == 2
def test_invalid_greediness(doc, text):
matcher = Matcher(doc.vocab)
with pytest.raises(ValueError):
matcher.add("RULE", [pattern1], greedy="GREEDY")
@pytest.mark.xfail
@pytest.mark.parametrize(
"pattern,re_pattern",
[
@ -74,7 +110,7 @@ def test_match_consuming(doc, text, pattern, re_pattern):
"""Test that matcher.__call__ consumes tokens on a match similar to
re.findall."""
matcher = Matcher(doc.vocab)
matcher.add(re_pattern, [pattern])
matcher.add(re_pattern, [pattern], greedy="FIRST")
matches = matcher(doc)
re_matches = [m.span() for m in re.finditer(re_pattern, text)]
assert len(matches) == len(re_matches)